diff --git a/.bazelrc b/.bazelrc index a635862b43a43c..c21cf6e6e15d5d 100644 --- a/.bazelrc +++ b/.bazelrc @@ -255,6 +255,14 @@ build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-1 build:cuda_clang_official --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:cuda_clang_official --crosstool_top="@sigbuild-r2.16-clang_config_cuda//crosstool:toolchain" +# Build with nvcc for CUDA and clang for host +build:nvcc_clang --config=cuda +# Unfortunately, cuda_configure.bzl demands this for using nvcc + clang +build:nvcc_clang --action_env=TF_CUDA_CLANG="1" +build:nvcc_clang --action_env=TF_NVCC_CLANG="1" +build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc + + # Debug config build:dbg -c dbg # Only include debug info for files under tensorflow/, excluding kernels, to @@ -527,8 +535,8 @@ build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.16-clang_confi test:rbe_linux_cuda --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda +build:rbe_linux_cuda_nvcc --config=nvcc_clang build:rbe_linux_cuda_nvcc --repo_env TF_NCCL_USE_STUB=1 -build:rbe_linux_cuda_nvcc --action_env=TF_NVCC_CLANG="1" # TODO(kanglan): Remove rbe_win and rbe_win_py3* after b/289091160 is fixed build:rbe_win --config=rbe_base @@ -577,6 +585,7 @@ build:elinux_armhf --copt -mfp16-format=ieee # Load rc file written by ./configure. try-import %workspace%/.tf_configure.bazelrc +try-import %workspace%/xla_configure.bazelrc # Load rc file with user-specific options. try-import %workspace%/.bazelrc.user @@ -777,28 +786,38 @@ test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-os test:linux_cuda_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # ARM64 PYCPP -test:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only -test:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only -test:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --flaky_test_attempts=3 +# In Linux Arm64 presubmit/continuous build, we cross-compile the binaries on +# Linux x86 so that we can use RBE. Since tests still need to run on the single +# host Arm64 machine, the build becomes too slow (~30 min) to be a presubmit. +# For testing purposes, we want to see the runtime performance of an +# experimental job that is build-only, i.e, we only build the test targets and +# do not run them. By prefixing the configs with "build", we can run both +# `bazel build` and `bazel test` commands with the same config as test configs +# inherit from build. +build:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +build:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +build:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --flaky_test_attempts=3 # TODO(michaelhudgins): Why do we need to specifically omit go and java here? -test:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test -//tensorflow/python/tools:aot_compiled_test +build:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test -//tensorflow/python/tools:aot_compiled_test # CROSS-COMPILE ARM64 PYCPP -test:cross_compile_linux_arm64_pycpp_test --config=linux_arm64_pycpp_test +build:cross_compile_linux_arm64_pycpp_test --config=linux_arm64_pycpp_test # Tests that fail only when cross-compiled -test:cross_compile_linux_arm64_pycpp_test -//tensorflow/compiler/mlir/quantization/stablehlo:convert_tf_quant_to_mhlo_int_test +build:cross_compile_linux_arm64_pycpp_test -//tensorflow/compiler/mlir/quantization/stablehlo:convert_tf_quant_to_mhlo_int_test # MACOS ARM64 PYCPP test:macos_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... -//tensorflow/core/kernels/image:resize_bicubic_op_test # MACOS X86 PYCPP -test:macos_x86_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test -test:macos_x86_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test -test:macos_x86_pycpp_test_filters --keep_going --test_lang_filters=cc,py --test_size_filters=small,medium -test:macos_x86_pycpp_test --config=macos_x86_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/python/integration_testing/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... +# These are defined as build configs so that we can run a build only job. See +# the note under "ARM64 PYCPP" for more details. +build:macos_x86_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +build:macos_x86_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +build:macos_x86_pycpp_test_filters --keep_going --test_lang_filters=cc,py --test_size_filters=small,medium +build:macos_x86_pycpp_test --config=macos_x86_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/python/integration_testing/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... # CROSS-COMPILE MACOS X86 PYCPP -test:cross_compile_macos_x86_pycpp_test --config=macos_x86_pycpp_test -test:cross_compile_macos_x86_pycpp_test -//tensorflow/core/kernels:quantized_conv_ops_test -//tensorflow/core/kernels:quantized_matmul_op_test -//tensorflow/python/ops:quantized_conv_ops_test -//tensorflow/tools/graph_transforms:transforms_test -//tensorflow/python/tools:aot_compiled_test +build:cross_compile_macos_x86_pycpp_test --config=macos_x86_pycpp_test +build:cross_compile_macos_x86_pycpp_test -//tensorflow/core/kernels:quantized_conv_ops_test -//tensorflow/core/kernels:quantized_matmul_op_test -//tensorflow/python/ops:quantized_conv_ops_test -//tensorflow/tools/graph_transforms:transforms_test -//tensorflow/python/tools:aot_compiled_test # END TF TEST SUITE OPTIONS # START CROSS-COMPILE CONFIGS @@ -855,8 +874,12 @@ build:cross_compile_macos_x86 --platform_mappings=tensorflow/tools/toolchains/cr # RBE cross-compile configs for Darwin x86 build:rbe_cross_compile_macos_x86 --config=cross_compile_macos_x86 build:rbe_cross_compile_macos_x86 --config=rbe_cross_compile_base +build:rbe_cross_compile_macos_x86 --bes_upload_mode=nowait_for_upload_complete test:rbe_cross_compile_macos_x86 --config=rbe_cross_compile_base # Increase the test timeout as tests often take longer on mac. test:rbe_cross_compile_macos_x86 --test_timeout=300,450,1200,3600 +# Limit jobs to 100 to avoid running into "out of memory" issues (b/316266643) +build:rbe_cross_compile_macos_x86 --jobs=100 +test:rbe_cross_compile_macos_x86 --jobs=100 # END MACOS CROSS-COMPILE CONFIGS # END CROSS-COMPILE CONFIGS diff --git a/.github/workflows/arm-cd.yml b/.github/workflows/arm-cd.yml index 15433f8f14be32..ddcc1b373e5c14 100644 --- a/.github/workflows/arm-cd.yml +++ b/.github/workflows/arm-cd.yml @@ -57,12 +57,12 @@ jobs: run: find /home/ubuntu/actions-runner/_work/tensorflow/tensorflow/. -name . -o -prune -exec sudo rm -rf -- {} + || true - name: Checkout repository for nightly (skipped for releases) if: ${{ github.event_name == 'schedule' }} - uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: ref: 'nightly' - name: Checkout repository for releases (skipped for nightly) if: ${{ github.event_name == 'push' }} - uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - name: Build and test pip wheel shell: bash run: | diff --git a/.github/workflows/arm-ci-extended-cpp.yml b/.github/workflows/arm-ci-extended-cpp.yml index e648297d37e789..2f9c67fb81eede 100644 --- a/.github/workflows/arm-ci-extended-cpp.yml +++ b/.github/workflows/arm-ci-extended-cpp.yml @@ -50,12 +50,12 @@ jobs: run: find /home/ubuntu/actions-runner/_work/tensorflow/tensorflow/. -name . -o -prune -exec sudo rm -rf -- {} + || true - name: Checkout repository for nightly (skipped for releases) if: ${{ github.event_name == 'schedule' }} - uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: ref: 'nightly' - name: Checkout repository if: ${{ github.event_name == 'push' }} - uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - name: Build binary and run C++ tests shell: bash run: | diff --git a/.github/workflows/arm-ci-extended.yml b/.github/workflows/arm-ci-extended.yml index 01ce70ba82ecfa..db782d3cf35f30 100644 --- a/.github/workflows/arm-ci-extended.yml +++ b/.github/workflows/arm-ci-extended.yml @@ -51,12 +51,12 @@ jobs: run: find /home/ubuntu/actions-runner/_work/tensorflow/tensorflow/. -name . -o -prune -exec sudo rm -rf -- {} + || true - name: Checkout repository for nightly (skipped for releases) if: ${{ github.event_name == 'schedule' }} - uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: ref: 'nightly' - name: Checkout repository if: ${{ github.event_name == 'push' }} - uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - name: Build binary and run python tests on nightly for all python versions shell: bash run: | diff --git a/.github/workflows/arm-ci.yml b/.github/workflows/arm-ci.yml index 3b07683008391d..7b3e8c6f24df49 100644 --- a/.github/workflows/arm-ci.yml +++ b/.github/workflows/arm-ci.yml @@ -47,7 +47,7 @@ jobs: shell: bash run: find /home/ubuntu/actions-runner/_work/tensorflow/tensorflow/. -name . -o -prune -exec sudo rm -rf -- {} + || true - name: Checkout repository - uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - name: Build binary and run python tests shell: bash run: | diff --git a/.github/workflows/osv-scanner-scheduled.yml b/.github/workflows/osv-scanner-scheduled.yml index fb7366768436c5..a471d68b4fd2d7 100644 --- a/.github/workflows/osv-scanner-scheduled.yml +++ b/.github/workflows/osv-scanner-scheduled.yml @@ -28,7 +28,7 @@ permissions: jobs: scan-scheduled: if: github.repository == 'tensorflow/tensorflow' - uses: "google/osv-scanner/.github/workflows/osv-scanner-reusable.yml@main" + uses: "google/osv-scanner-action/.github/workflows/osv-scanner-reusable.yml@v1.6.2-beta1" with: scan-args: |- --lockfile=requirements.txt:./requirements_lock_3_9.txt diff --git a/.github/workflows/update-rbe.yml b/.github/workflows/update-rbe.yml index 1b421effec8198..bdce23b94d02f1 100644 --- a/.github/workflows/update-rbe.yml +++ b/.github/workflows/update-rbe.yml @@ -106,13 +106,13 @@ jobs: map sigbuild-r2.14-clang-python3.10 2.14-python3.10 map sigbuild-r2.14-clang-python3.11 2.14-python3.11 # TF 2.16 - map sigbuild-r2.16 2.16-python3.9 + map sigbuild-r2.16 2.16-python3.11 map sigbuild-r2.16-python3.9 2.16-python3.9 map sigbuild-r2.16-python3.10 2.16-python3.10 map sigbuild-r2.16-python3.11 2.16-python3.11 map sigbuild-r2.16-python3.12 2.16-python3.12 # TF 2.16 + Clang (containers are the same, but env vars in configs.bzl are different) - map sigbuild-r2.16-clang 2.16-python3.9 + map sigbuild-r2.16-clang 2.16-python3.11 map sigbuild-r2.16-clang-python3.9 2.16-python3.9 map sigbuild-r2.16-clang-python3.10 2.16-python3.10 map sigbuild-r2.16-clang-python3.11 2.16-python3.11 diff --git a/.gitignore b/.gitignore index cebef4f590d47e..614cde3446a16f 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ node_modules /.bazelrc.user /.tf_configure.bazelrc +/xla_configure.bazelrc /bazel-* /bazel_pip /tools/python_bin_path.sh diff --git a/RELEASE.md b/RELEASE.md index 784e2ac28ceea7..e75ca35b589d73 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,4 +1,4 @@ -# Release 2.16.0 +# Release 2.17.0 ## TensorFlow @@ -9,11 +9,31 @@ * * -* `tf.summary.trace_on` now takes a `profiler_outdir` argument. This must be set - if `profiler` arg is set to `True`. - * `tf.summary.trace_export`'s `profiler_outdir` arg is now a no-op. Enabling - the profiler now requires setting `profiler_outdir` in `trace_on`. +### Known Caveats +* +* +* + +### Major Features and Improvements + +* +* + +### Bug Fixes and Other Changes + +* +* +* + +## Keras + + + +### Breaking Changes + +* +* ### Known Caveats @@ -26,6 +46,101 @@ * * +### Bug Fixes and Other Changes + +* +* +* + +* `tf.lite` + * Quantization for `FullyConnected` layer is switched from per-tensor to + per-channel scales for dynamic range quantization use case (`float32` + inputs / outputs and `int8` weights). The change enables new quantization + schema globally in the converter and inference engine. The new behaviour + can be disabled via experimental + flag `converter._experimental_disable_per_channel_quantization_for_dense_layers = True`. + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +, , , , , + +# Release 2.16.0 + +## TensorFlow + + + +* TensorFlow Windows Build: + + * Clang is now the default compiler to build TensorFlow CPU wheels on the + Windows Platform starting with this release. The currently supported + version is LLVM/clang 17. The official Wheels-published on PyPI will be + based on Clang; however, users retain the option to build wheels using + the MSVC compiler following the steps mentioned in + https://www.tensorflow.org/install/source_windows as has been the case + before + +### Breaking Changes + +* +* + +* `tf.summary.trace_on` now takes a `profiler_outdir` argument. This must be + set if `profiler` arg is set to `True`. + + * `tf.summary.trace_export`'s `profiler_outdir` arg is now a no-op. + Enabling the profiler now requires setting `profiler_outdir` in + `trace_on`. + +* `tf.estimator` + + * The tf.estimator API is removed. + +* Keras 3.0 will be the default Keras version. You may need to update your + script to use Keras 3.0. + +* Please refer to the new Keras documentation for Keras 3.0 + (https://keras.io/keras_3). + +* To continue using Keras 2.0, do the following. + +* 1. Install tf-keras via pip install tf-keras~=2.16 + + 1. To switch tf.keras to use Keras 2 (tf-keras), set the environment + variable TF_USE_LEGACY_KERAS=1 directly or in your python program by + import os;os.environ["TF_USE_LEGACY_KERAS"]=1. Please note that this + will set it for all packages in your Python runtime program + +* 1. Change import of keras from tensorflow as follows +* import tensorflow.keras as keras and import keras to import tf_keras as + keras +* **Apple Silicon users:** If you previously installed TensorFlow using + `pip install tensorflow-macos`, please update your installation method. Use + `pip install tensorflow` from now on. Starting with TF 2.17, the + `tensorflow-macos` package will no longer receive updates. + +### Known Caveats + +* +* +* + +* Full aarch64 Linux and Arm64 macOS wheels are now published to the + `tensorflow` pypi repository and no longer redirect to a separate package. + +### Major Features and Improvements + +* +* + +* Support for Python 3.12 has been added. +* [tensorflow-tpu](https://pypi.org/project/tensorflow-tpu/) package is now + available for easier TPU based installs. +* TensorFlow pip packages are now built with CUDA 12.3 and cuDNN 8.9.7 + + ### Bug Fixes and Other Changes * @@ -54,6 +169,21 @@ * Added `experimental_skip_slot_variables` (a boolean option) to skip restoring of optimizer slot variables in a checkpoint. +* `tf.saved_model.SaveOptions` + + * `SaveOptions` now takes a new argument called + `experimental_debug_stripper`. When enabled, this strips the debug nodes + from both the node defs and the function defs of the graph. Note that + this currently only strips the `Assert` nodes from the graph and + converts them into `NoOp`s instead. + +* `tf.data` + + * `tf.data` now has an `autotune_options.initial_parallelism` option to + control the initial parallelism setting used by autotune before the data + pipeline has started running. The default is 16. A lower value reduces + initial memory usage, while a higher value improves startup time. + ## Keras * `keras.layers.experimental.DynamicEmbedding` diff --git a/ci/official/README.md b/ci/official/README.md index d070af86cd8090..3c0181c5384392 100644 --- a/ci/official/README.md +++ b/ci/official/README.md @@ -45,7 +45,7 @@ cd tensorflow-git-dir # Here is a single-line example of running a script on Linux to build the # GPU version of TensorFlow for Python 3.12, using the public TF bazel cache and # a local build cache: -TFCI=py312,linux_x86_cuda,multicache ci/official/wheel.sh +TFCI=py312,linux_x86_cuda,public_cache,disk_cache ci/official/wheel.sh # First, set your TFCI variable to choose the environment settings. # TFCI is a comma-separated list of filenames from the envs directory, which @@ -57,9 +57,10 @@ TFCI=py312,linux_x86_cuda,multicache ci/official/wheel.sh # value in the "env_vars" list that you can choose to copy that environment. # Ex. 1: TFCI=py311,linux_x86_cuda,nightly_upload (nightly job) # Ex. 2: TFCI=py39,linux_x86,rbe (continuous job) -# Non-Googlers should replace "nightly_upload" or "rbe" with "multicache". -# Googlers should replace "nightly_upload" with "multicache" or "rbe", if -# you have set up your system to use RBE (see further below). +# Non-Googlers should replace "nightly_upload" or "rbe" with +# "public_cache,disk_cache". +# Googlers should replace "nightly_upload" with "public_cache,disk_cache" or +# "rbe", if you have set up your system to use RBE (see further below). # # Here is how to choose your TFCI value: # 1. A Python version must come first, because other scripts reference it. @@ -74,7 +75,9 @@ TFCI=py312,linux_x86_cuda,multicache ci/official/wheel.sh # Ex. linux_x86_cuda -- x86_64 Linux platform, with Nvidia CUDA support # Ex. macos_arm64 -- arm64 MacOS platform # 3. Add modifiers. Some modifiers for local execution are: -# Ex. multicache -- Use a local cache combined with TF's public cache +# Ex. disk_cache -- Use a local cache +# Ex. public_cache -- Use TF's public cache (read-only) +# Ex. public_cache_push -- Use TF's public cache (read and write, Googlers only) # Ex. rbe -- Use RBE for faster builds (Googlers only; see below) # Ex. no_docker -- Disable docker on enabled platforms # See full examples below for more details on these. Some other modifiers are: @@ -94,7 +97,7 @@ TFCI=py312,linux_x86_cuda,multicache ci/official/wheel.sh # or tests passing incorrectly. # - Automatic LLVM updates are known to extend build time even with # the cache; this is unavoidable. -export TFCI=py311,linux_x86,multicache +export TFCI=py311,linux_x86,public_cache,disk_cache # Recommended: Configure Docker. (Linux only) # @@ -127,7 +130,7 @@ export TFCI=py311,linux_x86,multicache # it is only available to a limited set of internal TensorFlow developers. # # RBE is incompatible with local caching, so you must remove -# ci/official/envs/local_multicache from your $TFCI file. +# disk_cache, public_cache, and public_cache_push from your $TFCI file. # # To use RBE, you must first run `gcloud auth application-default login`, then: export TFCI=py311,linux_x86,rbe diff --git a/ci/official/any.sh b/ci/official/any.sh index 980bb3cfdf403a..dc1484b64dc9ea 100755 --- a/ci/official/any.sh +++ b/ci/official/any.sh @@ -29,7 +29,7 @@ # ./any.sh # # 3. DO THE SAME WITH A LOCAL CACHE OR RBE: -# export TF_ANY_EXTRA_ENV=ci/official/envs/local_multicache +# export TF_ANY_EXTRA_ENV=ci/official/envs/public_cache,ci/official/envs/disk_cache # ... # ./any.sh # or @@ -39,8 +39,8 @@ set -euxo pipefail cd "$(dirname "$0")/../../" # tensorflow/ # Any request that includes "nightly_upload" should just use the -# local multi-cache instead. -export TFCI="$(echo $TFCI | sed 's/,nightly_upload/,multicache/')" +# local multi-cache (public read-only cache + disk cache) instead. +export TFCI="$(echo $TFCI | sed 's/,nightly_upload/,public_cache,disk_cache/')" if [[ -n "${TF_ANY_EXTRA_ENV:-}" ]]; then export TFCI="$TFCI,$TF_ANY_EXTRA_ENV" fi diff --git a/ci/official/bisect.sh b/ci/official/bisect.sh index 4076a73b867e7a..7f18dd1460ff5b 100755 --- a/ci/official/bisect.sh +++ b/ci/official/bisect.sh @@ -34,6 +34,6 @@ # export TF_ANY_MODE=test set -euxo pipefail cd "$(dirname "$0")/../../" # tensorflow/ -export TFCI="$(echo $TFCI | sed 's/,nightly_upload/,multicache/')" +export TFCI="$(echo $TFCI | sed 's/,nightly_upload/,public_cache,disk_cache/')" git bisect start "$TF_BISECT_BAD" "$TF_BISECT_GOOD" git bisect run $TF_BISECT_SCRIPT diff --git a/ci/official/code_check_full.sh b/ci/official/code_check_full.sh index 5b1370b4f31e06..448fb82bf288b9 100755 --- a/ci/official/code_check_full.sh +++ b/ci/official/code_check_full.sh @@ -15,4 +15,4 @@ # ============================================================================== source "${BASH_SOURCE%/*}/utilities/setup.sh" -tfrun bats ./ci/official/utilities/code_check_full.bats --timing --output "$TFCI_OUTPUT_DIR" +tfrun bats ./ci/official/utilities/code_check_full.bats --timing --output "$TFCI_OUTPUT_DIR" \ No newline at end of file diff --git a/ci/official/containers/linux_arm64/devel.packages.txt b/ci/official/containers/linux_arm64/devel.packages.txt index efbae80eefacee..a8a9cb442c8b0b 100644 --- a/ci/official/containers/linux_arm64/devel.packages.txt +++ b/ci/official/containers/linux_arm64/devel.packages.txt @@ -3,8 +3,6 @@ autoconf automake build-essential ca-certificates -# TODO(b/308399490) Remove CMake once dm-tree (Keras dependency) has 3.12 wheels -cmake llvm-17 clang-17 clang-format-12 diff --git a/ci/official/containers/linux_arm64/jax.requirements.txt b/ci/official/containers/linux_arm64/jax.requirements.txt index 878d229d0f237e..6211528986fdc0 100644 --- a/ci/official/containers/linux_arm64/jax.requirements.txt +++ b/ci/official/containers/linux_arm64/jax.requirements.txt @@ -24,4 +24,6 @@ scipy==1.9.2;python_version=="3.11" scipy==1.7.3;python_version<"3.11" ml_dtypes>=0.2.0 +# For using Python 3.11 with Bazel 6 (b/286090018) +lit ~= 17.0.2 diff --git a/ci/official/envs/ci_default b/ci/official/envs/ci_default index d080a4566efe16..96d87423392541 100644 --- a/ci/official/envs/ci_default +++ b/ci/official/envs/ci_default @@ -58,6 +58,7 @@ TFCI_MACOS_UPGRADE_PYENV_ENABLE= TFCI_NIGHTLY_UPDATE_VERSION_ENABLE= TFCI_NVIDIA_SMI_ENABLE= TFCI_OUTPUT_DIR= +TFCI_PYCPP_SWAP_TO_BUILD_ENABLE= TFCI_PYTHON_VERIFY_PIP_INSTALL_ARGS= TFCI_PYTHON_VERSION= TFCI_WHL_AUDIT_ENABLE= diff --git a/third_party/xla/xla/python/tpu_driver/platform/external/tools.bzl b/ci/official/envs/disk_cache similarity index 51% rename from third_party/xla/xla/python/tpu_driver/platform/external/tools.bzl rename to ci/official/envs/disk_cache index 1c420ccb7f5039..bd8ccfa0d95692 100644 --- a/third_party/xla/xla/python/tpu_driver/platform/external/tools.bzl +++ b/ci/official/envs/disk_cache @@ -1,4 +1,4 @@ -# Copyright 2019 The OpenXLA Authors. +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,21 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -""" -Build dependencies and utilities for the TPU driver interface. -""" - -def go_grpc_library(**_kwargs): - # A dummy macro placeholder for compatibility reason. - pass - -def external_deps(): - return [ - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:span", - ] +# Sourcing this enables local disk cache +if [[ $(uname -s) == "Darwin" ]]; then + echo "Please note that using disk cache on macOS is not recommended because the" + echo "cache can end up being pretty big and make the build process inefficient." +fi +TFCI_BAZEL_COMMON_ARGS="$TFCI_BAZEL_COMMON_ARGS --disk_cache=$TFCI_OUTPUT_DIR/cache" diff --git a/ci/official/envs/enable_pycpp_build b/ci/official/envs/enable_pycpp_build new file mode 100644 index 00000000000000..d7e0e5ea8065c3 --- /dev/null +++ b/ci/official/envs/enable_pycpp_build @@ -0,0 +1,20 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Changes the behavior in pycpp.sh from "run all tests" to "verify that all +# tests can compile." Used in some CI jobs (macOS and Linux Arm64) where test +# execution is too expensive. +TFCI_PYCPP_SWAP_TO_BUILD_ENABLE=1 +TFCI_BAZEL_COMMON_ARGS="$TFCI_BAZEL_COMMON_ARGS --build_tests_only" \ No newline at end of file diff --git a/ci/official/envs/linux_arm64 b/ci/official/envs/linux_arm64 index 7c4270408dd68d..161b0e2e803822 100644 --- a/ci/official/envs/linux_arm64 +++ b/ci/official/envs/linux_arm64 @@ -19,7 +19,7 @@ TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_arm64 # despite lacking Nvidia CUDA support. TFCI_BUILD_PIP_PACKAGE_ARGS="--repo_env=WHEEL_NAME=tensorflow" TFCI_DOCKER_ENABLE=1 -TFCI_DOCKER_IMAGE=gcr.io/tensorflow-sigs/build-arm64:tf-latest-multi-python +TFCI_DOCKER_IMAGE=gcr.io/tensorflow-sigs/build-arm64:tf-2-16-multi-python TFCI_DOCKER_PULL_ENABLE=1 TFCI_DOCKER_REBUILD_ARGS="--target=tf ci/official/containers/linux_arm64" TFCI_INDEX_HTML_ENABLE=1 diff --git a/ci/official/envs/linux_arm64_onednn b/ci/official/envs/linux_arm64_onednn new file mode 100644 index 00000000000000..0d4b7cbd03bbaa --- /dev/null +++ b/ci/official/envs/linux_arm64_onednn @@ -0,0 +1,16 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +source ci/official/envs/linux_arm64 +TFCI_BAZEL_COMMON_ARGS="$TFCI_BAZEL_COMMON_ARGS --test_env=TF_ENABLE_ONEDNN_OPTS=1" diff --git a/ci/official/envs/linux_x86 b/ci/official/envs/linux_x86 index 97fe9956f14ee1..2efc0fac466b00 100644 --- a/ci/official/envs/linux_x86 +++ b/ci/official/envs/linux_x86 @@ -16,7 +16,7 @@ TFCI_BAZEL_COMMON_ARGS="--repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION --conf TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cpu TFCI_BUILD_PIP_PACKAGE_ARGS="--repo_env=WHEEL_NAME=tensorflow_cpu" TFCI_DOCKER_ENABLE=1 -TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} +TFCI_DOCKER_IMAGE=tensorflow/build:2.16-python${TFCI_PYTHON_VERSION} TFCI_DOCKER_PULL_ENABLE=1 TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_INDEX_HTML_ENABLE=1 diff --git a/ci/official/envs/linux_x86_tpu b/ci/official/envs/linux_x86_tpu index 8a3cbe271c5b64..3c7d61b2ac3794 100644 --- a/ci/official/envs/linux_x86_tpu +++ b/ci/official/envs/linux_x86_tpu @@ -19,4 +19,4 @@ TFCI_BUILD_PIP_PACKAGE_ARGS="--repo_env=WHEEL_NAME=tensorflow_tpu" TFCI_LIB_SUFFIX="-tpu-linux-x86_64" TFCI_WHL_BAZEL_TEST_ENABLE=0 TFCI_WHL_SIZE_LIMIT=580M -TFCI_PYTHON_VERIFY_PIP_INSTALL_ARGS="-f https://storage.googleapis.com/libtpu-releases/index.html" +TFCI_PYTHON_VERIFY_PIP_INSTALL_ARGS="-f https://storage.googleapis.com/libtpu-tf-releases/index.html" diff --git a/ci/official/envs/macos_x86 b/ci/official/envs/macos_x86 index 3959830535628b..56166a0d0d4309 100644 --- a/ci/official/envs/macos_x86 +++ b/ci/official/envs/macos_x86 @@ -22,8 +22,8 @@ TFCI_MACOS_BAZEL_TEST_DIR_PATH="/System/Volumes/Data/bazel_output" TFCI_MACOS_INSTALL_BAZELISK_ENABLE=1 TFCI_MACOS_INSTALL_BAZELISK_URL="https://github.com/bazelbuild/bazelisk/releases/download/v1.11.0/bazelisk-darwin-amd64" TFCI_MACOS_TWINE_INSTALL_ENABLE=1 -TFCI_MACOS_UPGRADE_PYENV_ENABLE=1 TFCI_OUTPUT_DIR=build_output +TFCI_WHL_BAZEL_TEST_ENABLE=1 TFCI_WHL_SIZE_LIMIT=255M TFCI_WHL_SIZE_LIMIT_ENABLE=1 diff --git a/ci/official/envs/macos_x86_cross_compile b/ci/official/envs/macos_x86_cross_compile index 79f717156ea939..3a9dd2557faa1c 100644 --- a/ci/official/envs/macos_x86_cross_compile +++ b/ci/official/envs/macos_x86_cross_compile @@ -13,8 +13,7 @@ # limitations under the License. # ============================================================================== source ci/official/envs/macos_x86 -# Limit jobs to 100 to avoid running into "out of memory" issues (b/316266643) -TFCI_BAZEL_COMMON_ARGS="--repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION --jobs=100 --config cross_compile_macos_x86" +TFCI_BAZEL_COMMON_ARGS="--repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config cross_compile_macos_x86" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=cross_compile_macos_x86 TFCI_MACOS_CROSS_COMPILE_ENABLE=1 TFCI_MACOS_CROSS_COMPILE_SDK_DEST="tensorflow/tools/toolchains/cross_compile/cc/MacOSX.sdk" diff --git a/ci/official/envs/multicache b/ci/official/envs/public_cache similarity index 85% rename from ci/official/envs/multicache rename to ci/official/envs/public_cache index eb5c58e68e646f..ec57aad869ca47 100644 --- a/ci/official/envs/multicache +++ b/ci/official/envs/public_cache @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -# Combine TF public build cache and local disk cache +# Sourcing this enables Bazel remote cache (public, read-only) # The cache configs are different for MacOS and Linux if [[ $(uname -s) == "Darwin" ]]; then - TFCI_BAZEL_COMMON_ARGS="$TFCI_BAZEL_COMMON_ARGS --config tf_public_macos_cache --disk_cache=$TFCI_OUTPUT_DIR/cache" + TFCI_BAZEL_COMMON_ARGS="$TFCI_BAZEL_COMMON_ARGS --config tf_public_macos_cache" else - TFCI_BAZEL_COMMON_ARGS="$TFCI_BAZEL_COMMON_ARGS --config tf_public_cache --disk_cache=$TFCI_OUTPUT_DIR/cache" + TFCI_BAZEL_COMMON_ARGS="$TFCI_BAZEL_COMMON_ARGS --config tf_public_cache" fi diff --git a/ci/official/envs/public_cache_push b/ci/official/envs/public_cache_push new file mode 100644 index 00000000000000..e686a0aac5d5ce --- /dev/null +++ b/ci/official/envs/public_cache_push @@ -0,0 +1,24 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Sourcing this enables Bazel remote cache (read and write) +# Note that "_push" cache configs write to GCS buckets and require +# authentication. If you are not a Googler, source "public_cache" to enable the +# public read-only cache. +# The cache configs are different for MacOS and Linux +if [[ $(uname -s) == "Darwin" ]]; then + TFCI_BAZEL_COMMON_ARGS="$TFCI_BAZEL_COMMON_ARGS --config tf_public_macos_cache_push" +else + TFCI_BAZEL_COMMON_ARGS="$TFCI_BAZEL_COMMON_ARGS --config tf_public_cache_push" +fi diff --git a/ci/official/pycpp.sh b/ci/official/pycpp.sh index 34294fe8a107f6..cf346007949c1e 100755 --- a/ci/official/pycpp.sh +++ b/ci/official/pycpp.sh @@ -15,7 +15,11 @@ # ============================================================================== source "${BASH_SOURCE%/*}/utilities/setup.sh" -tfrun bazel test $TFCI_BAZEL_COMMON_ARGS --profile "$TFCI_OUTPUT_DIR/profile.json.gz" --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_pycpp_test" +if [[ $TFCI_PYCPP_SWAP_TO_BUILD_ENABLE == 1 ]]; then + tfrun bazel build $TFCI_BAZEL_COMMON_ARGS --profile "$TFCI_OUTPUT_DIR/profile.json.gz" --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_pycpp_test" +else + tfrun bazel test $TFCI_BAZEL_COMMON_ARGS --profile "$TFCI_OUTPUT_DIR/profile.json.gz" --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_pycpp_test" +fi # Note: the profile can be viewed by visiting chrome://tracing in a Chrome browser. # See https://docs.bazel.build/versions/main/skylark/performance.html#performance-profiling diff --git a/ci/official/requirements_updater/.bazelversion b/ci/official/requirements_updater/.bazelversion new file mode 100644 index 00000000000000..f22d756da39d4c --- /dev/null +++ b/ci/official/requirements_updater/.bazelversion @@ -0,0 +1 @@ +6.5.0 diff --git a/ci/official/requirements_updater/release_updater.sh b/ci/official/requirements_updater/release_updater.sh index 88d54666eb21db..3d47199c7187af 100644 --- a/ci/official/requirements_updater/release_updater.sh +++ b/ci/official/requirements_updater/release_updater.sh @@ -25,7 +25,10 @@ SUPPORTED_VERSIONS=("3_9" "3_10" "3_11" "3_12") for VERSION in "${SUPPORTED_VERSIONS[@]}" do cp ../../../requirements_lock_"$VERSION".txt "requirements_lock_"$VERSION".txt" - bazel run --experimental_convenience_symlinks=ignore //:requirements_"$VERSION"_release.update + bazel run \ + --experimental_convenience_symlinks=ignore \ + --enable_bzlmod=false \ + //:requirements_"$VERSION"_release.update sed -i '/^#/d' requirements_lock_"$VERSION".txt mv "requirements_lock_"$VERSION".txt" ../../../requirements_lock_"$VERSION".txt done diff --git a/ci/official/requirements_updater/requirements.in b/ci/official/requirements_updater/requirements.in index 46b5a532d5bb17..364134fcf7c39b 100644 --- a/ci/official/requirements_updater/requirements.in +++ b/ci/official/requirements_updater/requirements.in @@ -11,7 +11,7 @@ astor == 0.7.1 typing_extensions == 4.8.0 gast == 0.4.0 termcolor == 2.3.0 -wrapt == 1.14.1 +wrapt == 1.16.0 tblib == 2.0.0 # Install tensorboard, and keras @@ -19,7 +19,7 @@ tblib == 2.0.0 # Note that we must use nightly here as these are used in nightly jobs # For release jobs, we will pin these on the release branch keras-nightly ~= 3.0.0.dev -tb-nightly ~= 2.15.0.a +tb-nightly ~= 2.17.0.a # Test dependencies grpcio >= 1.24.3, < 2.0 diff --git a/ci/official/requirements_updater/updater.sh b/ci/official/requirements_updater/updater.sh index 898151dab1b599..95c67322966d11 100755 --- a/ci/official/requirements_updater/updater.sh +++ b/ci/official/requirements_updater/updater.sh @@ -28,7 +28,10 @@ SUPPORTED_VERSIONS=("3_9" "3_10" "3_11" "3_12") for VERSION in "${SUPPORTED_VERSIONS[@]}" do touch "requirements_lock_$VERSION.txt" - bazel run --experimental_convenience_symlinks=ignore //:requirements_"$VERSION".update + bazel run \ + --experimental_convenience_symlinks=ignore \ + --enable_bzlmod=false \ + //:requirements_"$VERSION".update sed -i '/^#/d' requirements_lock_"$VERSION".txt mv requirements_lock_"$VERSION".txt ../../../requirements_lock_"$VERSION".txt done diff --git a/ci/official/utilities/code_check_full.bats b/ci/official/utilities/code_check_full.bats index 78dd88f1d56be6..8dacee0875535d 100644 --- a/ci/official/utilities/code_check_full.bats +++ b/ci/official/utilities/code_check_full.bats @@ -306,6 +306,12 @@ EOF echo "Look at the instructions for ':api_compatibility_test -- --update_goldens=True'" } +# See b/279852433 (internal). +# TODO(b/279852433) Replace deps(//tensorflow/...) with deps(//...) +@test "Verify that it's possible to query every TensorFlow target without BUILD errors" { + bazel query "deps(//tensorflow/...)" > /dev/null +} + teardown_file() { bazel shutdown } diff --git a/ci/official/utilities/setup_macos.sh b/ci/official/utilities/setup_macos.sh index a35a01788700d4..8a63d318c6e18e 100644 --- a/ci/official/utilities/setup_macos.sh +++ b/ci/official/utilities/setup_macos.sh @@ -81,13 +81,6 @@ if [[ "${TFCI_MACOS_PYENV_INSTALL_ENABLE}" == 1 ]]; then python --version fi -if [[ "$TFCI_PYTHON_VERSION" == "3.12" ]]; then - # dm-tree (Keras v3 dependency) doesn't have pre-built wheels for 3.12 yet. - # Having CMake allows building them. - # Once the wheels are added, this should be removed - b/308399490. - brew install cmake -fi - # TFCI Mac VM images do not have twine installed by default so we need to # install it manually. We use Twine in nightly builds to upload Python packages # to PyPI. diff --git a/requirements_lock_3_10.txt b/requirements_lock_3_10.txt index 434b38d603df80..2335d295d0faf6 100644 --- a/requirements_lock_3_10.txt +++ b/requirements_lock_3_10.txt @@ -1,6 +1,6 @@ -absl-py==2.0.0 \ - --hash=sha256:9a28abb62774ae4e8edbe2dd4c49ffcd45a6a848952a5eccc6a49f3f0fc1e2f3 \ - --hash=sha256:d9690211c5fcfefcdd1a45470ac2b5c5acd45241c3af71eed96bc5441746c0d5 +absl-py==2.1.0 \ + --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ + --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff # via # keras-nightly # tb-nightly @@ -12,105 +12,101 @@ astunparse==1.6.3 \ --hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \ --hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8 # via -r requirements.in -cachetools==5.3.2 \ - --hash=sha256:086ee420196f7b2ab9ca2db2520aca326318b68fe5ba8bc4d49cca91add450f2 \ - --hash=sha256:861f35a13a451f94e301ce2bec7cac63e881232ccce7ed67fab9b5df4d3beaa1 - # via google-auth -certifi==2023.7.22 \ - --hash=sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082 \ - --hash=sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9 +certifi==2024.2.2 \ + --hash=sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f \ + --hash=sha256:dc383c07b76109f368f6106eee2b593b04a011ea4d55f652c6ca24a754d1cdd1 # via requests -charset-normalizer==3.3.1 \ - --hash=sha256:06cf46bdff72f58645434d467bf5228080801298fbba19fe268a01b4534467f5 \ - --hash=sha256:0c8c61fb505c7dad1d251c284e712d4e0372cef3b067f7ddf82a7fa82e1e9a93 \ - --hash=sha256:10b8dd31e10f32410751b3430996f9807fc4d1587ca69772e2aa940a82ab571a \ - --hash=sha256:1171ef1fc5ab4693c5d151ae0fdad7f7349920eabbaca6271f95969fa0756c2d \ - --hash=sha256:17a866d61259c7de1bdadef418a37755050ddb4b922df8b356503234fff7932c \ - --hash=sha256:1d6bfc32a68bc0933819cfdfe45f9abc3cae3877e1d90aac7259d57e6e0f85b1 \ - --hash=sha256:1ec937546cad86d0dce5396748bf392bb7b62a9eeb8c66efac60e947697f0e58 \ - --hash=sha256:223b4d54561c01048f657fa6ce41461d5ad8ff128b9678cfe8b2ecd951e3f8a2 \ - --hash=sha256:2465aa50c9299d615d757c1c888bc6fef384b7c4aec81c05a0172b4400f98557 \ - --hash=sha256:28f512b9a33235545fbbdac6a330a510b63be278a50071a336afc1b78781b147 \ - --hash=sha256:2c092be3885a1b7899cd85ce24acedc1034199d6fca1483fa2c3a35c86e43041 \ - --hash=sha256:2c4c99f98fc3a1835af8179dcc9013f93594d0670e2fa80c83aa36346ee763d2 \ - --hash=sha256:31445f38053476a0c4e6d12b047b08ced81e2c7c712e5a1ad97bc913256f91b2 \ - --hash=sha256:31bbaba7218904d2eabecf4feec0d07469284e952a27400f23b6628439439fa7 \ - --hash=sha256:34d95638ff3613849f473afc33f65c401a89f3b9528d0d213c7037c398a51296 \ - --hash=sha256:352a88c3df0d1fa886562384b86f9a9e27563d4704ee0e9d56ec6fcd270ea690 \ - --hash=sha256:39b70a6f88eebe239fa775190796d55a33cfb6d36b9ffdd37843f7c4c1b5dc67 \ - --hash=sha256:3c66df3f41abee950d6638adc7eac4730a306b022570f71dd0bd6ba53503ab57 \ - --hash=sha256:3f70fd716855cd3b855316b226a1ac8bdb3caf4f7ea96edcccc6f484217c9597 \ - --hash=sha256:3f9bc2ce123637a60ebe819f9fccc614da1bcc05798bbbaf2dd4ec91f3e08846 \ - --hash=sha256:3fb765362688821404ad6cf86772fc54993ec11577cd5a92ac44b4c2ba52155b \ - --hash=sha256:45f053a0ece92c734d874861ffe6e3cc92150e32136dd59ab1fb070575189c97 \ - --hash=sha256:46fb9970aa5eeca547d7aa0de5d4b124a288b42eaefac677bde805013c95725c \ - --hash=sha256:4cb50a0335382aac15c31b61d8531bc9bb657cfd848b1d7158009472189f3d62 \ - --hash=sha256:4e12f8ee80aa35e746230a2af83e81bd6b52daa92a8afaef4fea4a2ce9b9f4fa \ - --hash=sha256:4f3100d86dcd03c03f7e9c3fdb23d92e32abbca07e7c13ebd7ddfbcb06f5991f \ - --hash=sha256:4f6e2a839f83a6a76854d12dbebde50e4b1afa63e27761549d006fa53e9aa80e \ - --hash=sha256:4f861d94c2a450b974b86093c6c027888627b8082f1299dfd5a4bae8e2292821 \ - --hash=sha256:501adc5eb6cd5f40a6f77fbd90e5ab915c8fd6e8c614af2db5561e16c600d6f3 \ - --hash=sha256:520b7a142d2524f999447b3a0cf95115df81c4f33003c51a6ab637cbda9d0bf4 \ - --hash=sha256:548eefad783ed787b38cb6f9a574bd8664468cc76d1538215d510a3cd41406cb \ - --hash=sha256:555fe186da0068d3354cdf4bbcbc609b0ecae4d04c921cc13e209eece7720727 \ - --hash=sha256:55602981b2dbf8184c098bc10287e8c245e351cd4fdcad050bd7199d5a8bf514 \ - --hash=sha256:58e875eb7016fd014c0eea46c6fa92b87b62c0cb31b9feae25cbbe62c919f54d \ - --hash=sha256:5a3580a4fdc4ac05f9e53c57f965e3594b2f99796231380adb2baaab96e22761 \ - --hash=sha256:5b70bab78accbc672f50e878a5b73ca692f45f5b5e25c8066d748c09405e6a55 \ - --hash=sha256:5ceca5876032362ae73b83347be8b5dbd2d1faf3358deb38c9c88776779b2e2f \ - --hash=sha256:61f1e3fb621f5420523abb71f5771a204b33c21d31e7d9d86881b2cffe92c47c \ - --hash=sha256:633968254f8d421e70f91c6ebe71ed0ab140220469cf87a9857e21c16687c034 \ - --hash=sha256:63a6f59e2d01310f754c270e4a257426fe5a591dc487f1983b3bbe793cf6bac6 \ - --hash=sha256:63accd11149c0f9a99e3bc095bbdb5a464862d77a7e309ad5938fbc8721235ae \ - --hash=sha256:6db3cfb9b4fcecb4390db154e75b49578c87a3b9979b40cdf90d7e4b945656e1 \ - --hash=sha256:71ef3b9be10070360f289aea4838c784f8b851be3ba58cf796262b57775c2f14 \ - --hash=sha256:7ae8e5142dcc7a49168f4055255dbcced01dc1714a90a21f87448dc8d90617d1 \ - --hash=sha256:7b6cefa579e1237ce198619b76eaa148b71894fb0d6bcf9024460f9bf30fd228 \ - --hash=sha256:800561453acdecedaac137bf09cd719c7a440b6800ec182f077bb8e7025fb708 \ - --hash=sha256:82ca51ff0fc5b641a2d4e1cc8c5ff108699b7a56d7f3ad6f6da9dbb6f0145b48 \ - --hash=sha256:851cf693fb3aaef71031237cd68699dded198657ec1e76a76eb8be58c03a5d1f \ - --hash=sha256:854cc74367180beb327ab9d00f964f6d91da06450b0855cbbb09187bcdb02de5 \ - --hash=sha256:87071618d3d8ec8b186d53cb6e66955ef2a0e4fa63ccd3709c0c90ac5a43520f \ - --hash=sha256:871d045d6ccc181fd863a3cd66ee8e395523ebfbc57f85f91f035f50cee8e3d4 \ - --hash=sha256:8aee051c89e13565c6bd366813c386939f8e928af93c29fda4af86d25b73d8f8 \ - --hash=sha256:8af5a8917b8af42295e86b64903156b4f110a30dca5f3b5aedea123fbd638bff \ - --hash=sha256:8ec8ef42c6cd5856a7613dcd1eaf21e5573b2185263d87d27c8edcae33b62a61 \ - --hash=sha256:91e43805ccafa0a91831f9cd5443aa34528c0c3f2cc48c4cb3d9a7721053874b \ - --hash=sha256:9505dc359edb6a330efcd2be825fdb73ee3e628d9010597aa1aee5aa63442e97 \ - --hash=sha256:985c7965f62f6f32bf432e2681173db41336a9c2611693247069288bcb0c7f8b \ - --hash=sha256:9a74041ba0bfa9bc9b9bb2cd3238a6ab3b7618e759b41bd15b5f6ad958d17605 \ - --hash=sha256:9edbe6a5bf8b56a4a84533ba2b2f489d0046e755c29616ef8830f9e7d9cf5728 \ - --hash=sha256:a15c1fe6d26e83fd2e5972425a772cca158eae58b05d4a25a4e474c221053e2d \ - --hash=sha256:a66bcdf19c1a523e41b8e9d53d0cedbfbac2e93c649a2e9502cb26c014d0980c \ - --hash=sha256:ae4070f741f8d809075ef697877fd350ecf0b7c5837ed68738607ee0a2c572cf \ - --hash=sha256:ae55d592b02c4349525b6ed8f74c692509e5adffa842e582c0f861751701a673 \ - --hash=sha256:b578cbe580e3b41ad17b1c428f382c814b32a6ce90f2d8e39e2e635d49e498d1 \ - --hash=sha256:b891a2f68e09c5ef989007fac11476ed33c5c9994449a4e2c3386529d703dc8b \ - --hash=sha256:baec8148d6b8bd5cee1ae138ba658c71f5b03e0d69d5907703e3e1df96db5e41 \ - --hash=sha256:bb06098d019766ca16fc915ecaa455c1f1cd594204e7f840cd6258237b5079a8 \ - --hash=sha256:bc791ec3fd0c4309a753f95bb6c749ef0d8ea3aea91f07ee1cf06b7b02118f2f \ - --hash=sha256:bd28b31730f0e982ace8663d108e01199098432a30a4c410d06fe08fdb9e93f4 \ - --hash=sha256:be4d9c2770044a59715eb57c1144dedea7c5d5ae80c68fb9959515037cde2008 \ - --hash=sha256:c0c72d34e7de5604df0fde3644cc079feee5e55464967d10b24b1de268deceb9 \ - --hash=sha256:c0e842112fe3f1a4ffcf64b06dc4c61a88441c2f02f373367f7b4c1aa9be2ad5 \ - --hash=sha256:c15070ebf11b8b7fd1bfff7217e9324963c82dbdf6182ff7050519e350e7ad9f \ - --hash=sha256:c2000c54c395d9e5e44c99dc7c20a64dc371f777faf8bae4919ad3e99ce5253e \ - --hash=sha256:c30187840d36d0ba2893bc3271a36a517a717f9fd383a98e2697ee890a37c273 \ - --hash=sha256:cb7cd68814308aade9d0c93c5bd2ade9f9441666f8ba5aa9c2d4b389cb5e2a45 \ - --hash=sha256:cd805513198304026bd379d1d516afbf6c3c13f4382134a2c526b8b854da1c2e \ - --hash=sha256:d0bf89afcbcf4d1bb2652f6580e5e55a840fdf87384f6063c4a4f0c95e378656 \ - --hash=sha256:d9137a876020661972ca6eec0766d81aef8a5627df628b664b234b73396e727e \ - --hash=sha256:dbd95e300367aa0827496fe75a1766d198d34385a58f97683fe6e07f89ca3e3c \ - --hash=sha256:dced27917823df984fe0c80a5c4ad75cf58df0fbfae890bc08004cd3888922a2 \ - --hash=sha256:de0b4caa1c8a21394e8ce971997614a17648f94e1cd0640fbd6b4d14cab13a72 \ - --hash=sha256:debb633f3f7856f95ad957d9b9c781f8e2c6303ef21724ec94bea2ce2fcbd056 \ - --hash=sha256:e372d7dfd154009142631de2d316adad3cc1c36c32a38b16a4751ba78da2a397 \ - --hash=sha256:ecd26be9f112c4f96718290c10f4caea6cc798459a3a76636b817a0ed7874e42 \ - --hash=sha256:edc0202099ea1d82844316604e17d2b175044f9bcb6b398aab781eba957224bd \ - --hash=sha256:f194cce575e59ffe442c10a360182a986535fd90b57f7debfaa5c845c409ecc3 \ - --hash=sha256:f5fb672c396d826ca16a022ac04c9dce74e00a1c344f6ad1a0fdc1ba1f332213 \ - --hash=sha256:f6a02a3c7950cafaadcd46a226ad9e12fc9744652cc69f9e5534f98b47f3bbcf \ - --hash=sha256:fe81b35c33772e56f4b6cf62cf4aedc1762ef7162a31e6ac7fe5e40d0149eb67 +charset-normalizer==3.3.2 \ + --hash=sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027 \ + --hash=sha256:06a81e93cd441c56a9b65d8e1d043daeb97a3d0856d177d5c90ba85acb3db087 \ + --hash=sha256:0a55554a2fa0d408816b3b5cedf0045f4b8e1a6065aec45849de2d6f3f8e9786 \ + --hash=sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8 \ + --hash=sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09 \ + --hash=sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185 \ + --hash=sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574 \ + --hash=sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e \ + --hash=sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519 \ + --hash=sha256:2127566c664442652f024c837091890cb1942c30937add288223dc895793f898 \ + --hash=sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269 \ + --hash=sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3 \ + --hash=sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f \ + --hash=sha256:3287761bc4ee9e33561a7e058c72ac0938c4f57fe49a09eae428fd88aafe7bb6 \ + --hash=sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8 \ + --hash=sha256:37e55c8e51c236f95b033f6fb391d7d7970ba5fe7ff453dad675e88cf303377a \ + --hash=sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73 \ + --hash=sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc \ + --hash=sha256:42cb296636fcc8b0644486d15c12376cb9fa75443e00fb25de0b8602e64c1714 \ + --hash=sha256:45485e01ff4d3630ec0d9617310448a8702f70e9c01906b0d0118bdf9d124cf2 \ + --hash=sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc \ + --hash=sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce \ + --hash=sha256:4d0d1650369165a14e14e1e47b372cfcb31d6ab44e6e33cb2d4e57265290044d \ + --hash=sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e \ + --hash=sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6 \ + --hash=sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269 \ + --hash=sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96 \ + --hash=sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d \ + --hash=sha256:6463effa3186ea09411d50efc7d85360b38d5f09b870c48e4600f63af490e56a \ + --hash=sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4 \ + --hash=sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77 \ + --hash=sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d \ + --hash=sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0 \ + --hash=sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed \ + --hash=sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068 \ + --hash=sha256:6c4caeef8fa63d06bd437cd4bdcf3ffefe6738fb1b25951440d80dc7df8c03ac \ + --hash=sha256:6ef1d82a3af9d3eecdba2321dc1b3c238245d890843e040e41e470ffa64c3e25 \ + --hash=sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8 \ + --hash=sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab \ + --hash=sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26 \ + --hash=sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2 \ + --hash=sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db \ + --hash=sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f \ + --hash=sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5 \ + --hash=sha256:86216b5cee4b06df986d214f664305142d9c76df9b6512be2738aa72a2048f99 \ + --hash=sha256:87d1351268731db79e0f8e745d92493ee2841c974128ef629dc518b937d9194c \ + --hash=sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d \ + --hash=sha256:8c622a5fe39a48f78944a87d4fb8a53ee07344641b0562c540d840748571b811 \ + --hash=sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa \ + --hash=sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a \ + --hash=sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03 \ + --hash=sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b \ + --hash=sha256:923c0c831b7cfcb071580d3f46c4baf50f174be571576556269530f4bbd79d04 \ + --hash=sha256:95f2a5796329323b8f0512e09dbb7a1860c46a39da62ecb2324f116fa8fdc85c \ + --hash=sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001 \ + --hash=sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458 \ + --hash=sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389 \ + --hash=sha256:a50aebfa173e157099939b17f18600f72f84eed3049e743b68ad15bd69b6bf99 \ + --hash=sha256:a981a536974bbc7a512cf44ed14938cf01030a99e9b3a06dd59578882f06f985 \ + --hash=sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537 \ + --hash=sha256:ae5f4161f18c61806f411a13b0310bea87f987c7d2ecdbdaad0e94eb2e404238 \ + --hash=sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f \ + --hash=sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d \ + --hash=sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796 \ + --hash=sha256:b2b0a0c0517616b6869869f8c581d4eb2dd83a4d79e0ebcb7d373ef9956aeb0a \ + --hash=sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143 \ + --hash=sha256:bd8f7df7d12c2db9fab40bdd87a7c09b1530128315d047a086fa3ae3435cb3a8 \ + --hash=sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c \ + --hash=sha256:c002b4ffc0be611f0d9da932eb0f704fe2602a9a949d1f738e4c34c75b0863d5 \ + --hash=sha256:c083af607d2515612056a31f0a8d9e0fcb5876b7bfc0abad3ecd275bc4ebc2d5 \ + --hash=sha256:c180f51afb394e165eafe4ac2936a14bee3eb10debc9d9e4db8958fe36afe711 \ + --hash=sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4 \ + --hash=sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6 \ + --hash=sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c \ + --hash=sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7 \ + --hash=sha256:db364eca23f876da6f9e16c9da0df51aa4f104a972735574842618b8c6d999d4 \ + --hash=sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b \ + --hash=sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae \ + --hash=sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12 \ + --hash=sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c \ + --hash=sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae \ + --hash=sha256:eb00ed941194665c332bf8e078baf037d6c35d7c4f3102ea2d4f16ca94a26dc8 \ + --hash=sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887 \ + --hash=sha256:eb8821e09e916165e160797a6c17edda0679379a4be5c716c260e836e122f54b \ + --hash=sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4 \ + --hash=sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f \ + --hash=sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5 \ + --hash=sha256:fb69256e180cb6c8a894fee62b3afebae785babc1ee98b81cdf68bbca1987f33 \ + --hash=sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519 \ + --hash=sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561 # via requests dill==0.3.7 \ --hash=sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e \ @@ -133,6 +129,7 @@ dm-tree==0.1.8 \ --hash=sha256:35cc164a79336bfcfafb47e5f297898359123bbd3330c1967f0c4994f9cf9f60 \ --hash=sha256:378cc8ad93c5fe3590f405a309980721f021c790ca1bdf9b15bb1d59daec57f5 \ --hash=sha256:39070ba268c0491af9fe7a58644d99e8b4f2cde6e5884ba3380bddc84ed43d5f \ + --hash=sha256:435227cf3c5dc63f4de054cf3d00183790bd9ead4c3623138c74dde7f67f521b \ --hash=sha256:5483dca4d7eb1a0d65fe86d3b6a53ae717face83c1f17e0887b1a4a64ae5c410 \ --hash=sha256:694c3654cfd2a81552c08ec66bb5c4a3d48fa292b9a181880fb081c36c5b9134 \ --hash=sha256:803bfc53b4659f447ac694dbd04235f94a73ef7c1fd1e0df7c84ac41e0bc963b \ @@ -140,6 +137,8 @@ dm-tree==0.1.8 \ --hash=sha256:83b7764de0d855338abefc6e3ee9fe40d301668310aa3baea3f778ff051f4393 \ --hash=sha256:8c60a7eadab64c2278861f56bca320b2720f163dca9d7558103c3b77f2416571 \ --hash=sha256:8ed3564abed97c806db122c2d3e1a2b64c74a63debe9903aad795167cc301368 \ + --hash=sha256:94d3f0826311f45ee19b75f5b48c99466e4218a0489e81c0f0167bda50cacf22 \ + --hash=sha256:96a548a406a6fb15fe58f6a30a57ff2f2aafbf25f05afab00c8f5e5977b6c715 \ --hash=sha256:a5d819c38c03f0bb5b3b3703c60e4b170355a0fc6b5819325bf3d4ceb3ae7e80 \ --hash=sha256:ad16ceba90a56ec47cf45b21856d14962ac314787975ef786efb5e6e9ca75ec7 \ --hash=sha256:af4b3d372f2477dcd89a6e717e4a575ca35ccc20cc4454a8a4b6f8838a00672d \ @@ -147,6 +146,7 @@ dm-tree==0.1.8 \ --hash=sha256:b9bd9b9ccb59409d33d51d84b7668010c04c2af7d4a371632874c1ca356cff3d \ --hash=sha256:b9f89a454e98806b44fe9d40ec9eee61f848388f7e79ac2371a55679bd5a3ac6 \ --hash=sha256:bb2d109f42190225112da899b9f3d46d0d5f26aef501c61e43529fe9322530b5 \ + --hash=sha256:c0a94aba18a35457a1b5cd716fd7b46c5dafdc4cf7869b4bae665b91c4682a8e \ --hash=sha256:c5c8c12e3fda754ef6af94161bacdaeda816d941995fac415d6855c6c386af68 \ --hash=sha256:d1612fcaecd79023dbc6a6ae48d51a80beb5c385d6f3f6d71688e57bc8d07de8 \ --hash=sha256:d16e1f2a073604cfcc09f7131ae8d534674f43c3aef4c25742eae295bc60d04f \ @@ -154,6 +154,7 @@ dm-tree==0.1.8 \ --hash=sha256:d40fa4106ca6edc66760246a08f500ec0c85ef55c762fb4a363f6ee739ba02ee \ --hash=sha256:de287fabc464b8734be251e46e06aa9aa1001f34198da2b6ce07bd197172b9cb \ --hash=sha256:e4d714371bb08839e4e5e29024fc95832d9affe129825ef38836b143028bd144 \ + --hash=sha256:ea9e59e0451e7d29aece402d9f908f2e2a80922bcde2ebfd5dcb07750fcbfee8 \ --hash=sha256:f7ac31b9aecccb2c6e1ab29706f6ded3eba0c2c69c770322c9c685929c3d6afb \ --hash=sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d # via keras-nightly @@ -161,71 +162,61 @@ gast==0.4.0 \ --hash=sha256:40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1 \ --hash=sha256:b7adcdd5adbebf1adf17378da5ba3f543684dbec47b1cda1f3997e573cd542c4 # via -r requirements.in -google-auth==2.23.3 \ - --hash=sha256:6864247895eea5d13b9c57c9e03abb49cb94ce2dc7c58e91cba3248c7477c9e3 \ - --hash=sha256:a8f4608e65c244ead9e0538f181a96c6e11199ec114d41f1d7b1bffa96937bda - # via - # google-auth-oauthlib - # tb-nightly -google-auth-oauthlib==1.1.0 \ - --hash=sha256:089c6e587d36f4803ac7e0720c045c6a8b1fd1790088b8424975b90d0ee61c12 \ - --hash=sha256:83ea8c3b0881e453790baff4448e8a6112ac8778d1de9da0b68010b843937afb - # via tb-nightly -grpcio==1.59.2 \ - --hash=sha256:023088764012411affe7db183d1ada3ad9daf2e23ddc719ff46d7061de661340 \ - --hash=sha256:08d77e682f2bf730a4961eea330e56d2f423c6a9b91ca222e5b1eb24a357b19f \ - --hash=sha256:0a4a3833c0e067f3558538727235cd8a49709bff1003200bbdefa2f09334e4b1 \ - --hash=sha256:0a754aff9e3af63bdc4c75c234b86b9d14e14a28a30c4e324aed1a9b873d755f \ - --hash=sha256:11168ef43e4a43ff1b1a65859f3e0ef1a173e277349e7fb16923ff108160a8cd \ - --hash=sha256:128e20f57c5f27cb0157e73756d1586b83c1b513ebecc83ea0ac37e4b0e4e758 \ - --hash=sha256:1f9524d1d701e399462d2c90ba7c193e49d1711cf429c0d3d97c966856e03d00 \ - --hash=sha256:1ff16d68bf453275466a9a46739061a63584d92f18a0f5b33d19fc97eb69867c \ - --hash=sha256:2067274c88bc6de89c278a672a652b4247d088811ece781a4858b09bdf8448e3 \ - --hash=sha256:2171c39f355ba5b551c5d5928d65aa6c69807fae195b86ef4a7d125bcdb860a9 \ - --hash=sha256:242adc47725b9a499ee77c6a2e36688fa6c96484611f33b1be4c57ab075a92dd \ - --hash=sha256:27f879ae604a7fcf371e59fba6f3ff4635a4c2a64768bd83ff0cac503142fef4 \ - --hash=sha256:2b230028a008ae1d0f430acb227d323ff8a619017415cf334c38b457f814119f \ - --hash=sha256:3059668df17627f0e0fa680e9ef8c995c946c792612e9518f5cc1503be14e90b \ - --hash=sha256:31176aa88f36020055ace9adff2405a33c8bdbfa72a9c4980e25d91b2f196873 \ - --hash=sha256:36f53c2b3449c015880e7d55a89c992c357f176327b0d2873cdaaf9628a37c69 \ - --hash=sha256:3b4368b33908f683a363f376dfb747d40af3463a6e5044afee07cf9436addf96 \ - --hash=sha256:3c61d641d4f409c5ae46bfdd89ea42ce5ea233dcf69e74ce9ba32b503c727e29 \ - --hash=sha256:4abb717e320e74959517dc8e84a9f48fbe90e9abe19c248541e9418b1ce60acd \ - --hash=sha256:4c93f4abbb54321ee6471e04a00139c80c754eda51064187963ddf98f5cf36a4 \ - --hash=sha256:535561990e075fa6bd4b16c4c3c1096b9581b7bb35d96fac4650f1181e428268 \ - --hash=sha256:53c9aa5ddd6857c0a1cd0287225a2a25873a8e09727c2e95c4aebb1be83a766a \ - --hash=sha256:5d573e70a6fe77555fb6143c12d3a7d3fa306632a3034b4e7c59ca09721546f8 \ - --hash=sha256:6009386a2df66159f64ac9f20425ae25229b29b9dd0e1d3dd60043f037e2ad7e \ - --hash=sha256:686e975a5d16602dc0982c7c703948d17184bd1397e16c8ee03511ecb8c4cdda \ - --hash=sha256:6959fb07e8351e20501ffb8cc4074c39a0b7ef123e1c850a7f8f3afdc3a3da01 \ - --hash=sha256:6b25ed37c27e652db01be341af93fbcea03d296c024d8a0e680017a268eb85dd \ - --hash=sha256:6da6dea3a1bacf99b3c2187e296db9a83029ed9c38fd4c52b7c9b7326d13c828 \ - --hash=sha256:72ca2399097c0b758198f2ff30f7178d680de8a5cfcf3d9b73a63cf87455532e \ - --hash=sha256:73abb8584b0cf74d37f5ef61c10722adc7275502ab71789a8fe3cb7ef04cf6e2 \ - --hash=sha256:74100fecaec8a535e380cf5f2fb556ff84957d481c13e54051c52e5baac70541 \ - --hash=sha256:75c6ecb70e809cf1504465174343113f51f24bc61e22a80ae1c859f3f7034c6d \ - --hash=sha256:7cf05053242f61ba94014dd3a986e11a083400a32664058f80bf4cf817c0b3a1 \ - --hash=sha256:9411e24328a2302e279e70cae6e479f1fddde79629fcb14e03e6d94b3956eabf \ - --hash=sha256:a213acfbf186b9f35803b52e4ca9addb153fc0b67f82a48f961be7000ecf6721 \ - --hash=sha256:bb7e0fe6ad73b7f06d7e2b689c19a71cf5cc48f0c2bf8608469e51ffe0bd2867 \ - --hash=sha256:c2504eed520958a5b77cc99458297cb7906308cb92327f35fb7fbbad4e9b2188 \ - --hash=sha256:c35aa9657f5d5116d23b934568e0956bd50c615127810fffe3ac356a914c176a \ - --hash=sha256:c5f09cffa619adfb44799fa4a81c2a1ad77c887187613fb0a8f201ab38d89ba1 \ - --hash=sha256:c978f864b35f2261e0819f5cd88b9830b04dc51bcf055aac3c601e525a10d2ba \ - --hash=sha256:cbe946b3e6e60a7b4618f091e62a029cb082b109a9d6b53962dd305087c6e4fd \ - --hash=sha256:cc3e4cd087f07758b16bef8f31d88dbb1b5da5671d2f03685ab52dece3d7a16e \ - --hash=sha256:cf0dead5a2c5a3347af2cfec7131d4f2a2e03c934af28989c9078f8241a491fa \ - --hash=sha256:d2794f0e68b3085d99b4f6ff9c089f6fdd02b32b9d3efdfbb55beac1bf22d516 \ - --hash=sha256:d2fa68a96a30dd240be80bbad838a0ac81a61770611ff7952b889485970c4c71 \ - --hash=sha256:d6f70406695e3220f09cd7a2f879333279d91aa4a8a1d34303b56d61a8180137 \ - --hash=sha256:d8f9cd4ad1be90b0cf350a2f04a38a36e44a026cac1e036ac593dc48efe91d52 \ - --hash=sha256:da2d94c15f88cd40d7e67f7919d4f60110d2b9d5b1e08cf354c2be773ab13479 \ - --hash=sha256:e1727c1c0e394096bb9af185c6923e8ea55a5095b8af44f06903bcc0e06800a2 \ - --hash=sha256:e420ced29b5904cdf9ee5545e23f9406189d8acb6750916c2db4793dada065c6 \ - --hash=sha256:e82c5cf1495244adf5252f925ac5932e5fd288b3e5ab6b70bec5593074b7236c \ - --hash=sha256:f1ef0d39bc1feb420caf549b3c657c871cad4ebbcf0580c4d03816b0590de0cf \ - --hash=sha256:f8753a6c88d1d0ba64302309eecf20f70d2770f65ca02d83c2452279085bfcd3 \ - --hash=sha256:f93dbf58f03146164048be5426ffde298b237a5e059144847e4940f5b80172c3 +grpcio==1.60.1 \ + --hash=sha256:0250a7a70b14000fa311de04b169cc7480be6c1a769b190769d347939d3232a8 \ + --hash=sha256:069fe2aeee02dfd2135d562d0663fe70fbb69d5eed6eb3389042a7e963b54de8 \ + --hash=sha256:082081e6a36b6eb5cf0fd9a897fe777dbb3802176ffd08e3ec6567edd85bc104 \ + --hash=sha256:0c5807e9152eff15f1d48f6b9ad3749196f79a4a050469d99eecb679be592acc \ + --hash=sha256:14e8f2c84c0832773fb3958240c69def72357bc11392571f87b2d7b91e0bb092 \ + --hash=sha256:2a6087f234cb570008a6041c8ffd1b7d657b397fdd6d26e83d72283dae3527b1 \ + --hash=sha256:2bb2a2911b028f01c8c64d126f6b632fcd8a9ac975aa1b3855766c94e4107180 \ + --hash=sha256:2f44c32aef186bbba254129cea1df08a20be414144ac3bdf0e84b24e3f3b2e05 \ + --hash=sha256:30e980cd6db1088c144b92fe376747328d5554bc7960ce583ec7b7d81cd47287 \ + --hash=sha256:33aed0a431f5befeffd9d346b0fa44b2c01aa4aeae5ea5b2c03d3e25e0071216 \ + --hash=sha256:33bdea30dcfd4f87b045d404388469eb48a48c33a6195a043d116ed1b9a0196c \ + --hash=sha256:39aa848794b887120b1d35b1b994e445cc028ff602ef267f87c38122c1add50d \ + --hash=sha256:4216e67ad9a4769117433814956031cb300f85edc855252a645a9a724b3b6594 \ + --hash=sha256:49c9b6a510e3ed8df5f6f4f3c34d7fbf2d2cae048ee90a45cd7415abab72912c \ + --hash=sha256:4eec8b8c1c2c9b7125508ff7c89d5701bf933c99d3910e446ed531cd16ad5d87 \ + --hash=sha256:50d56280b482875d1f9128ce596e59031a226a8b84bec88cb2bf76c289f5d0de \ + --hash=sha256:53b69e79d00f78c81eecfb38f4516080dc7f36a198b6b37b928f1c13b3c063e9 \ + --hash=sha256:55ccb7db5a665079d68b5c7c86359ebd5ebf31a19bc1a91c982fd622f1e31ff2 \ + --hash=sha256:5a1ebbae7e2214f51b1f23b57bf98eeed2cf1ba84e4d523c48c36d5b2f8829ff \ + --hash=sha256:61b7199cd2a55e62e45bfb629a35b71fc2c0cb88f686a047f25b1112d3810904 \ + --hash=sha256:660fc6b9c2a9ea3bb2a7e64ba878c98339abaf1811edca904ac85e9e662f1d73 \ + --hash=sha256:6d140bdeb26cad8b93c1455fa00573c05592793c32053d6e0016ce05ba267549 \ + --hash=sha256:6e490fa5f7f5326222cb9f0b78f207a2b218a14edf39602e083d5f617354306f \ + --hash=sha256:6ecf21d20d02d1733e9c820fb5c114c749d888704a7ec824b545c12e78734d1c \ + --hash=sha256:70c83bb530572917be20c21f3b6be92cd86b9aecb44b0c18b1d3b2cc3ae47df0 \ + --hash=sha256:72153a0d2e425f45b884540a61c6639436ddafa1829a42056aa5764b84108b8e \ + --hash=sha256:73e14acd3d4247169955fae8fb103a2b900cfad21d0c35f0dcd0fdd54cd60367 \ + --hash=sha256:76eaaba891083fcbe167aa0f03363311a9f12da975b025d30e94b93ac7a765fc \ + --hash=sha256:79ae0dc785504cb1e1788758c588c711f4e4a0195d70dff53db203c95a0bd303 \ + --hash=sha256:7d142bcd604166417929b071cd396aa13c565749a4c840d6c702727a59d835eb \ + --hash=sha256:8c9554ca8e26241dabe7951aa1fa03a1ba0856688ecd7e7bdbdd286ebc272e4c \ + --hash=sha256:8d488fbdbf04283f0d20742b64968d44825617aa6717b07c006168ed16488804 \ + --hash=sha256:91422ba785a8e7a18725b1dc40fbd88f08a5bb4c7f1b3e8739cab24b04fa8a03 \ + --hash=sha256:9a66f4d2a005bc78e61d805ed95dedfcb35efa84b7bba0403c6d60d13a3de2d6 \ + --hash=sha256:9b106bc52e7f28170e624ba61cc7dc6829566e535a6ec68528f8e1afbed1c41f \ + --hash=sha256:9b54577032d4f235452f77a83169b6527bf4b77d73aeada97d45b2aaf1bf5ce0 \ + --hash=sha256:a09506eb48fa5493c58f946c46754ef22f3ec0df64f2b5149373ff31fb67f3dd \ + --hash=sha256:a212e5dea1a4182e40cd3e4067ee46be9d10418092ce3627475e995cca95de21 \ + --hash=sha256:a731ac5cffc34dac62053e0da90f0c0b8560396a19f69d9703e88240c8f05858 \ + --hash=sha256:af5ef6cfaf0d023c00002ba25d0751e5995fa0e4c9eec6cd263c30352662cbce \ + --hash=sha256:b58b855d0071575ea9c7bc0d84a06d2edfbfccec52e9657864386381a7ce1ae9 \ + --hash=sha256:bc808924470643b82b14fe121923c30ec211d8c693e747eba8a7414bc4351a23 \ + --hash=sha256:c557e94e91a983e5b1e9c60076a8fd79fea1e7e06848eb2e48d0ccfb30f6e073 \ + --hash=sha256:c71be3f86d67d8d1311c6076a4ba3b75ba5703c0b856b4e691c9097f9b1e8bd2 \ + --hash=sha256:c8754c75f55781515a3005063d9a05878b2cfb3cb7e41d5401ad0cf19de14872 \ + --hash=sha256:cb0af13433dbbd1c806e671d81ec75bd324af6ef75171fd7815ca3074fe32bfe \ + --hash=sha256:cba6209c96828711cb7c8fcb45ecef8c8859238baf15119daa1bef0f6c84bfe7 \ + --hash=sha256:cf77f8cf2a651fbd869fbdcb4a1931464189cd210abc4cfad357f1cacc8642a6 \ + --hash=sha256:d7404cebcdb11bb5bd40bf94131faf7e9a7c10a6c60358580fe83913f360f929 \ + --hash=sha256:dd1d3a8d1d2e50ad9b59e10aa7f07c7d1be2b367f3f2d33c5fade96ed5460962 \ + --hash=sha256:e5d97c65ea7e097056f3d1ead77040ebc236feaf7f71489383d20f3b4c28412a \ + --hash=sha256:f1c3dc536b3ee124e8b24feb7533e5c70b9f2ef833e3b2e5513b2897fd46763a \ + --hash=sha256:f2212796593ad1d0235068c79836861f2201fc7137a99aa2fea7beeb3b101177 \ + --hash=sha256:fead980fbc68512dfd4e0c7b1f5754c2a8e5015a04dea454b9cada54a8423525 # via # -r requirements.in # tb-nightly @@ -258,113 +249,115 @@ h5py==3.10.0 \ # via # -r requirements.in # keras-nightly -idna==3.4 \ - --hash=sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4 \ - --hash=sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2 +idna==3.6 \ + --hash=sha256:9ecdbbd083b06798ae1e86adcbfe8ab1479cf864e4ee30fe4e46a003d12491ca \ + --hash=sha256:c05567e9c24a6b9faaa835c4821bad0590fbb9d5779e7caa6e1cc4978e7eb24f # via requests jax==0.4.7 \ --hash=sha256:5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8 # via -r requirements.in -keras-nightly==3.0.0.dev2023103103 \ - --hash=sha256:25a6a9030d7e067d535ec41ae6700636c90cabada80fbddb3bb8775006807860 \ - --hash=sha256:c39fccebb3d4cc9d838371981c4d3eef08fc06eadf6cffb39b356cb625bab50f +keras-nightly==3.0.4.dev2024021403 \ + --hash=sha256:24ce69d29d582771685bf4235f59663723405b5a5b16f3eaff2657e52e74663a \ + --hash=sha256:9f416e66b820ef833779d219d255b346b8b90a72fdbd0b2f1e90a43ad142a03d # via -r requirements.in -lit==17.0.4 \ - --hash=sha256:ee2e180128e770abc6aed3a02de2daf09d81b7d30225e315205d3599c311d304 +lit==17.0.6 \ + --hash=sha256:dfa9af9b55fc4509a56be7bf2346f079d7f4a242d583b9f2e0b078fd0abae31b # via -r requirements.in -markdown==3.5 \ - --hash=sha256:4afb124395ce5fc34e6d9886dab977fd9ae987fc6e85689f08278cf0c69d4bf3 \ - --hash=sha256:a807eb2e4778d9156c8f07876c6e4d50b5494c5665c4834f67b06459dfd877b3 +markdown==3.5.2 \ + --hash=sha256:d43323865d89fc0cb9b20c75fc8ad313af307cc087e84b657d9eec768eddeadd \ + --hash=sha256:e1ac7b3dc550ee80e602e71c1d168002f062e49f1b11e26a36264dafd4df2ef8 # via tb-nightly markdown-it-py==3.0.0 \ --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb # via rich -markupsafe==2.1.3 \ - --hash=sha256:05fb21170423db021895e1ea1e1f3ab3adb85d1c2333cbc2310f2a26bc77272e \ - --hash=sha256:0a4e4a1aff6c7ac4cd55792abf96c915634c2b97e3cc1c7129578aa68ebd754e \ - --hash=sha256:10bbfe99883db80bdbaff2dcf681dfc6533a614f700da1287707e8a5d78a8431 \ - --hash=sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686 \ - --hash=sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c \ - --hash=sha256:1577735524cdad32f9f694208aa75e422adba74f1baee7551620e43a3141f559 \ - --hash=sha256:1b40069d487e7edb2676d3fbdb2b0829ffa2cd63a2ec26c4938b2d34391b4ecc \ - --hash=sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb \ - --hash=sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939 \ - --hash=sha256:282c2cb35b5b673bbcadb33a585408104df04f14b2d9b01d4c345a3b92861c2c \ - --hash=sha256:2c1b19b3aaacc6e57b7e25710ff571c24d6c3613a45e905b1fde04d691b98ee0 \ - --hash=sha256:2ef12179d3a291be237280175b542c07a36e7f60718296278d8593d21ca937d4 \ - --hash=sha256:338ae27d6b8745585f87218a3f23f1512dbf52c26c28e322dbe54bcede54ccb9 \ - --hash=sha256:3c0fae6c3be832a0a0473ac912810b2877c8cb9d76ca48de1ed31e1c68386575 \ - --hash=sha256:3fd4abcb888d15a94f32b75d8fd18ee162ca0c064f35b11134be77050296d6ba \ - --hash=sha256:42de32b22b6b804f42c5d98be4f7e5e977ecdd9ee9b660fda1a3edf03b11792d \ - --hash=sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd \ - --hash=sha256:504b320cd4b7eff6f968eddf81127112db685e81f7e36e75f9f84f0df46041c3 \ - --hash=sha256:525808b8019e36eb524b8c68acdd63a37e75714eac50e988180b169d64480a00 \ - --hash=sha256:56d9f2ecac662ca1611d183feb03a3fa4406469dafe241673d521dd5ae92a155 \ - --hash=sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac \ - --hash=sha256:65c1a9bcdadc6c28eecee2c119465aebff8f7a584dd719facdd9e825ec61ab52 \ - --hash=sha256:68e78619a61ecf91e76aa3e6e8e33fc4894a2bebe93410754bd28fce0a8a4f9f \ - --hash=sha256:69c0f17e9f5a7afdf2cc9fb2d1ce6aabdb3bafb7f38017c0b77862bcec2bbad8 \ - --hash=sha256:6b2b56950d93e41f33b4223ead100ea0fe11f8e6ee5f641eb753ce4b77a7042b \ - --hash=sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007 \ - --hash=sha256:787003c0ddb00500e49a10f2844fac87aa6ce977b90b0feaaf9de23c22508b24 \ - --hash=sha256:7ef3cb2ebbf91e330e3bb937efada0edd9003683db6b57bb108c4001f37a02ea \ - --hash=sha256:8023faf4e01efadfa183e863fefde0046de576c6f14659e8782065bcece22198 \ - --hash=sha256:8758846a7e80910096950b67071243da3e5a20ed2546e6392603c096778d48e0 \ - --hash=sha256:8afafd99945ead6e075b973fefa56379c5b5c53fd8937dad92c662da5d8fd5ee \ - --hash=sha256:8c41976a29d078bb235fea9b2ecd3da465df42a562910f9022f1a03107bd02be \ - --hash=sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2 \ - --hash=sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1 \ - --hash=sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707 \ - --hash=sha256:962f82a3086483f5e5f64dbad880d31038b698494799b097bc59c2edf392fce6 \ - --hash=sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c \ - --hash=sha256:9dcdfd0eaf283af041973bff14a2e143b8bd64e069f4c383416ecd79a81aab58 \ - --hash=sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823 \ - --hash=sha256:aa7bd130efab1c280bed0f45501b7c8795f9fdbeb02e965371bbef3523627779 \ - --hash=sha256:ab4a0df41e7c16a1392727727e7998a467472d0ad65f3ad5e6e765015df08636 \ - --hash=sha256:ad9e82fb8f09ade1c3e1b996a6337afac2b8b9e365f926f5a61aacc71adc5b3c \ - --hash=sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad \ - --hash=sha256:b076b6226fb84157e3f7c971a47ff3a679d837cf338547532ab866c57930dbee \ - --hash=sha256:b7ff0f54cb4ff66dd38bebd335a38e2c22c41a8ee45aa608efc890ac3e3931bc \ - --hash=sha256:bfce63a9e7834b12b87c64d6b155fdd9b3b96191b6bd334bf37db7ff1fe457f2 \ - --hash=sha256:c011a4149cfbcf9f03994ec2edffcb8b1dc2d2aede7ca243746df97a5d41ce48 \ - --hash=sha256:c9c804664ebe8f83a211cace637506669e7890fec1b4195b505c214e50dd4eb7 \ - --hash=sha256:ca379055a47383d02a5400cb0d110cef0a776fc644cda797db0c5696cfd7e18e \ - --hash=sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b \ - --hash=sha256:cd0f502fe016460680cd20aaa5a76d241d6f35a1c3350c474bac1273803893fa \ - --hash=sha256:ceb01949af7121f9fc39f7d27f91be8546f3fb112c608bc4029aef0bab86a2a5 \ - --hash=sha256:d080e0a5eb2529460b30190fcfcc4199bd7f827663f858a226a81bc27beaa97e \ - --hash=sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb \ - --hash=sha256:df0be2b576a7abbf737b1575f048c23fb1d769f267ec4358296f31c2479db8f9 \ - --hash=sha256:e09031c87a1e51556fdcb46e5bd4f59dfb743061cf93c4d6831bf894f125eb57 \ - --hash=sha256:e4dd52d80b8c83fdce44e12478ad2e85c64ea965e75d66dbeafb0a3e77308fcc \ - --hash=sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc \ - --hash=sha256:fec21693218efe39aa7f8599346e90c705afa52c5b31ae019b2e57e8f6542bb2 \ - --hash=sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11 +markupsafe==2.1.5 \ + --hash=sha256:00e046b6dd71aa03a41079792f8473dc494d564611a8f89bbbd7cb93295ebdcf \ + --hash=sha256:075202fa5b72c86ad32dc7d0b56024ebdbcf2048c0ba09f1cde31bfdd57bcfff \ + --hash=sha256:0e397ac966fdf721b2c528cf028494e86172b4feba51d65f81ffd65c63798f3f \ + --hash=sha256:17b950fccb810b3293638215058e432159d2b71005c74371d784862b7e4683f3 \ + --hash=sha256:1f3fbcb7ef1f16e48246f704ab79d79da8a46891e2da03f8783a5b6fa41a9532 \ + --hash=sha256:2174c595a0d73a3080ca3257b40096db99799265e1c27cc5a610743acd86d62f \ + --hash=sha256:2b7c57a4dfc4f16f7142221afe5ba4e093e09e728ca65c51f5620c9aaeb9a617 \ + --hash=sha256:2d2d793e36e230fd32babe143b04cec8a8b3eb8a3122d2aceb4a371e6b09b8df \ + --hash=sha256:30b600cf0a7ac9234b2638fbc0fb6158ba5bdcdf46aeb631ead21248b9affbc4 \ + --hash=sha256:397081c1a0bfb5124355710fe79478cdbeb39626492b15d399526ae53422b906 \ + --hash=sha256:3a57fdd7ce31c7ff06cdfbf31dafa96cc533c21e443d57f5b1ecc6cdc668ec7f \ + --hash=sha256:3c6b973f22eb18a789b1460b4b91bf04ae3f0c4234a0a6aa6b0a92f6f7b951d4 \ + --hash=sha256:3e53af139f8579a6d5f7b76549125f0d94d7e630761a2111bc431fd820e163b8 \ + --hash=sha256:4096e9de5c6fdf43fb4f04c26fb114f61ef0bf2e5604b6ee3019d51b69e8c371 \ + --hash=sha256:4275d846e41ecefa46e2015117a9f491e57a71ddd59bbead77e904dc02b1bed2 \ + --hash=sha256:4c31f53cdae6ecfa91a77820e8b151dba54ab528ba65dfd235c80b086d68a465 \ + --hash=sha256:4f11aa001c540f62c6166c7726f71f7573b52c68c31f014c25cc7901deea0b52 \ + --hash=sha256:5049256f536511ee3f7e1b3f87d1d1209d327e818e6ae1365e8653d7e3abb6a6 \ + --hash=sha256:58c98fee265677f63a4385256a6d7683ab1832f3ddd1e66fe948d5880c21a169 \ + --hash=sha256:598e3276b64aff0e7b3451b72e94fa3c238d452e7ddcd893c3ab324717456bad \ + --hash=sha256:5b7b716f97b52c5a14bffdf688f971b2d5ef4029127f1ad7a513973cfd818df2 \ + --hash=sha256:5dedb4db619ba5a2787a94d877bc8ffc0566f92a01c0ef214865e54ecc9ee5e0 \ + --hash=sha256:619bc166c4f2de5caa5a633b8b7326fbe98e0ccbfacabd87268a2b15ff73a029 \ + --hash=sha256:629ddd2ca402ae6dbedfceeba9c46d5f7b2a61d9749597d4307f943ef198fc1f \ + --hash=sha256:656f7526c69fac7f600bd1f400991cc282b417d17539a1b228617081106feb4a \ + --hash=sha256:6ec585f69cec0aa07d945b20805be741395e28ac1627333b1c5b0105962ffced \ + --hash=sha256:72b6be590cc35924b02c78ef34b467da4ba07e4e0f0454a2c5907f473fc50ce5 \ + --hash=sha256:7502934a33b54030eaf1194c21c692a534196063db72176b0c4028e140f8f32c \ + --hash=sha256:7a68b554d356a91cce1236aa7682dc01df0edba8d043fd1ce607c49dd3c1edcf \ + --hash=sha256:7b2e5a267c855eea6b4283940daa6e88a285f5f2a67f2220203786dfa59b37e9 \ + --hash=sha256:823b65d8706e32ad2df51ed89496147a42a2a6e01c13cfb6ffb8b1e92bc910bb \ + --hash=sha256:8590b4ae07a35970728874632fed7bd57b26b0102df2d2b233b6d9d82f6c62ad \ + --hash=sha256:8dd717634f5a044f860435c1d8c16a270ddf0ef8588d4887037c5028b859b0c3 \ + --hash=sha256:8dec4936e9c3100156f8a2dc89c4b88d5c435175ff03413b443469c7c8c5f4d1 \ + --hash=sha256:97cafb1f3cbcd3fd2b6fbfb99ae11cdb14deea0736fc2b0952ee177f2b813a46 \ + --hash=sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc \ + --hash=sha256:a549b9c31bec33820e885335b451286e2969a2d9e24879f83fe904a5ce59d70a \ + --hash=sha256:ac07bad82163452a6884fe8fa0963fb98c2346ba78d779ec06bd7a6262132aee \ + --hash=sha256:ae2ad8ae6ebee9d2d94b17fb62763125f3f374c25618198f40cbb8b525411900 \ + --hash=sha256:b91c037585eba9095565a3556f611e3cbfaa42ca1e865f7b8015fe5c7336d5a5 \ + --hash=sha256:bc1667f8b83f48511b94671e0e441401371dfd0f0a795c7daa4a3cd1dde55bea \ + --hash=sha256:bec0a414d016ac1a18862a519e54b2fd0fc8bbfd6890376898a6c0891dd82e9f \ + --hash=sha256:bf50cd79a75d181c9181df03572cdce0fbb75cc353bc350712073108cba98de5 \ + --hash=sha256:bff1b4290a66b490a2f4719358c0cdcd9bafb6b8f061e45c7a2460866bf50c2e \ + --hash=sha256:c061bb86a71b42465156a3ee7bd58c8c2ceacdbeb95d05a99893e08b8467359a \ + --hash=sha256:c8b29db45f8fe46ad280a7294f5c3ec36dbac9491f2d1c17345be8e69cc5928f \ + --hash=sha256:ce409136744f6521e39fd8e2a24c53fa18ad67aa5bc7c2cf83645cce5b5c4e50 \ + --hash=sha256:d050b3361367a06d752db6ead6e7edeb0009be66bc3bae0ee9d97fb326badc2a \ + --hash=sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b \ + --hash=sha256:d9fad5155d72433c921b782e58892377c44bd6252b5af2f67f16b194987338a4 \ + --hash=sha256:daa4ee5a243f0f20d528d939d06670a298dd39b1ad5f8a72a4275124a7819eff \ + --hash=sha256:db0b55e0f3cc0be60c1f19efdde9a637c32740486004f20d1cff53c3c0ece4d2 \ + --hash=sha256:e61659ba32cf2cf1481e575d0462554625196a1f2fc06a1c777d3f48e8865d46 \ + --hash=sha256:ea3d8a3d18833cf4304cd2fc9cbb1efe188ca9b5efef2bdac7adc20594a0e46b \ + --hash=sha256:ec6a563cff360b50eed26f13adc43e61bc0c04d94b8be985e6fb24b81f6dcfdf \ + --hash=sha256:f5dfb42c4604dddc8e4305050aa6deb084540643ed5804d7455b5df8fe16f5e5 \ + --hash=sha256:fa173ec60341d6bb97a89f5ea19c85c5643c1e7dedebc22f5181eb73573142c5 \ + --hash=sha256:fa9db3f79de01457b03d4f01b34cf91bc0048eb2c3846ff26f66687c2f6d16ab \ + --hash=sha256:fce659a462a1be54d2ffcacea5e3ba2d74daa74f30f5f143fe0c58636e355fdd \ + --hash=sha256:ffee1f21e5ef0d712f9033568f8344d5da8cc2869dbd08d87c84656e6a2d2f68 # via werkzeug mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba # via markdown-it-py -ml-dtypes==0.3.1 \ - --hash=sha256:3d8ca0acbd377082792d8b97081ba580abdad67c6afb7f827012c675b052f058 \ - --hash=sha256:42a8980afd8b7c8e270e8b5c260237286b5b26acd276fcb758d13cd7cb567e99 \ - --hash=sha256:438437e2e614a3c91d75581653b6c40ec890e8b5994d7190a90c931740151c95 \ - --hash=sha256:4828b62fa3bf1ae35faa40f3db9a38ec72fbce02f328a1d14c3a9da4606af364 \ - --hash=sha256:4d94b2d1bed77284694f7fd0479640fa7aa5d96433dca3cbcec407a5ef752e77 \ - --hash=sha256:510d249a91face47211762eb294d6fe64f325356b965fb6388c1bf51bd339267 \ - --hash=sha256:5727effa7650f7ab10906542d137cfb3244fdc3b2b519beff42f82def8ba59be \ - --hash=sha256:5e0b0b6bb07fa5ad11bb61d174667176bee5e05857225067aabfc5adc1b51d23 \ - --hash=sha256:60778f99194b4c4f36ba42da200b35ef851ce4d4af698aaf70f5b91fe70fc611 \ - --hash=sha256:70984b473db6489ec1d8c79b082a1322105155193049d08a3b0c515094e9777b \ - --hash=sha256:979d7d196d9a17e0135ae22878f74241fbd3522cef58d7b292f1fd5b32282201 \ - --hash=sha256:a777928dcba8865ab4a8157eeb25d23aed7bc82e5fd74e1d5eca821d3f148b39 \ - --hash=sha256:cb0c404e0dd3e25b56362c1c1e5de0ef717f727dde59fa721af4ce6ab2acca44 \ - --hash=sha256:d1a8dc3bac1da2a17d0e2e4cba36ee89721d0bd33ea4765af2eefb5f41409e0f \ - --hash=sha256:da274599e4950a9b488d21571061f49a185537cc77f2d3f8121151d58a9e9f16 \ - --hash=sha256:f83ff080df8910c0f987f615b03e4f8198638e0c00c6e679ea8892dda909763b \ - --hash=sha256:fcae2c69715410d96906e1dfe8f017d9f78a0d10e0df91aae52e91f51fdfe45e - # via jax +ml-dtypes==0.3.2 \ + --hash=sha256:2c34f2ba9660b21fe1034b608308a01be82bbef2a92fb8199f24dc6bad0d5226 \ + --hash=sha256:3a17ef2322e60858d93584e9c52a5be7dd6236b056b7fa1ec57f1bb6ba043e33 \ + --hash=sha256:533059bc5f1764fac071ef54598db358c167c51a718f68f5bb55e3dee79d2967 \ + --hash=sha256:6604877d567a29bfe7cc02969ae0f2425260e5335505cf5e7fefc3e5465f5655 \ + --hash=sha256:6b35c4e8ca957c877ac35c79ffa77724ecc3702a1e4b18b08306c03feae597bb \ + --hash=sha256:763697ab8a88d47443997a7cdf3aac7340049aed45f7521f6b0ec8a0594821fe \ + --hash=sha256:7a4c3fcbf86fa52d0204f07cfd23947ef05b4ad743a1a988e163caa34a201e5e \ + --hash=sha256:7afde548890a92b41c0fed3a6c525f1200a5727205f73dc21181a2726571bb53 \ + --hash=sha256:7ba8e1fafc7fff3e643f453bffa7d082df1678a73286ce8187d3e825e776eb94 \ + --hash=sha256:91f8783fd1f2c23fd3b9ee5ad66b785dafa58ba3cdb050c4458021fa4d1eb226 \ + --hash=sha256:93b78f53431c93953f7850bb1b925a17f0ab5d97527e38a7e865b5b4bc5cfc18 \ + --hash=sha256:961134ea44c7b8ca63eda902a44b58cd8bd670e21d62e255c81fba0a8e70d9b7 \ + --hash=sha256:b89b194e9501a92d289c1ffd411380baf5daafb9818109a4f49b0a1b6dce4462 \ + --hash=sha256:c7b3fb3d4f6b39bcd4f6c4b98f406291f0d681a895490ee29a0f95bab850d53c \ + --hash=sha256:d1a746fe5fb9cd974a91070174258f0be129c592b93f9ce7df6cc336416c3fbd \ + --hash=sha256:e8505946df1665db01332d885c2020b4cb9e84a8b1241eb4ba69d59591f65855 \ + --hash=sha256:f47619d978ab1ae7dfdc4052ea97c636c6263e1f19bd1be0e42c346b98d15ff4 + # via + # jax + # keras-nightly namex==0.0.7 \ --hash=sha256:84ba65bc4d22bd909e3d26bf2ffb4b9529b608cb3f9a4336f776b04204ced69b \ --hash=sha256:8a4f062945f405d77cb66b907f16aa2fd83681945e998be840eb6c4154d40108 @@ -407,10 +400,6 @@ numpy==1.23.5 ; python_version <= "3.11" \ # opt-einsum # scipy # tb-nightly -oauthlib==3.2.2 \ - --hash=sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca \ - --hash=sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918 - # via requests-oauthlib opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 @@ -440,57 +429,36 @@ protobuf==4.23.4 \ --hash=sha256:effeac51ab79332d44fba74660d40ae79985901ac21bca408f8dc335a81aa597 \ --hash=sha256:fee88269a090ada09ca63551bf2f573eb2424035bcf2cb1b121895b01a46594a # via tb-nightly -psutil==5.9.6 \ - --hash=sha256:10e8c17b4f898d64b121149afb136c53ea8b68c7531155147867b7b1ac9e7e28 \ - --hash=sha256:18cd22c5db486f33998f37e2bb054cc62fd06646995285e02a51b1e08da97017 \ - --hash=sha256:3ebf2158c16cc69db777e3c7decb3c0f43a7af94a60d72e87b2823aebac3d602 \ - --hash=sha256:51dc3d54607c73148f63732c727856f5febec1c7c336f8f41fcbd6315cce76ac \ - --hash=sha256:6e5fb8dc711a514da83098bc5234264e551ad980cec5f85dabf4d38ed6f15e9a \ - --hash=sha256:70cb3beb98bc3fd5ac9ac617a327af7e7f826373ee64c80efd4eb2856e5051e9 \ - --hash=sha256:748c9dd2583ed86347ed65d0035f45fa8c851e8d90354c122ab72319b5f366f4 \ - --hash=sha256:91ecd2d9c00db9817a4b4192107cf6954addb5d9d67a969a4f436dbc9200f88c \ - --hash=sha256:92e0cc43c524834af53e9d3369245e6cc3b130e78e26100d1f63cdb0abeb3d3c \ - --hash=sha256:a6f01f03bf1843280f4ad16f4bde26b817847b4c1a0db59bf6419807bc5ce05c \ - --hash=sha256:c69596f9fc2f8acd574a12d5f8b7b1ba3765a641ea5d60fb4736bf3c08a8214a \ - --hash=sha256:ca2780f5e038379e520281e4c032dddd086906ddff9ef0d1b9dcf00710e5071c \ - --hash=sha256:daecbcbd29b289aac14ece28eca6a3e60aa361754cf6da3dfb20d4d32b6c7f57 \ - --hash=sha256:e4b92ddcd7dd4cdd3f900180ea1e104932c7bce234fb88976e2a3b296441225a \ - --hash=sha256:fb8a697f11b0f5994550555fcfe3e69799e5b060c8ecf9e2f75c69302cc35c0d \ - --hash=sha256:ff18b8d1a784b810df0b0fff3bcb50ab941c3b8e2c8de5726f9c71c601c611aa +psutil==5.9.8 \ + --hash=sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d \ + --hash=sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73 \ + --hash=sha256:26bd09967ae00920df88e0352a91cff1a78f8d69b3ecabbfe733610c0af486c8 \ + --hash=sha256:27cc40c3493bb10de1be4b3f07cae4c010ce715290a5be22b98493509c6299e2 \ + --hash=sha256:36f435891adb138ed3c9e58c6af3e2e6ca9ac2f365efe1f9cfef2794e6c93b4e \ + --hash=sha256:50187900d73c1381ba1454cf40308c2bf6f34268518b3f36a9b663ca87e65e36 \ + --hash=sha256:611052c4bc70432ec770d5d54f64206aa7203a101ec273a0cd82418c86503bb7 \ + --hash=sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c \ + --hash=sha256:7d79560ad97af658a0f6adfef8b834b53f64746d45b403f225b85c5c2c140eee \ + --hash=sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421 \ + --hash=sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf \ + --hash=sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81 \ + --hash=sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0 \ + --hash=sha256:bd1184ceb3f87651a67b2708d4c3338e9b10c5df903f2e3776b62303b26cb631 \ + --hash=sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4 \ + --hash=sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8 # via portpicker -pyasn1==0.5.0 \ - --hash=sha256:87a2121042a1ac9358cabcaf1d07680ff97ee6404333bacca15f76aa8ad01a57 \ - --hash=sha256:97b7290ca68e62a832558ec3976f15cbf911bf5d7c7039d8b861c2a0ece69fde - # via - # pyasn1-modules - # rsa -pyasn1-modules==0.3.0 \ - --hash=sha256:5bd01446b736eb9d31512a30d46c1ac3395d676c6f3cafa4c03eb54b9925631c \ - --hash=sha256:d3ccd6ed470d9ffbc716be08bd90efbd44d0734bc9303818f7336070984a162d - # via google-auth -pygments==2.16.1 \ - --hash=sha256:13fc09fa63bc8d8671a6d247e1eb303c4b343eaee81d861f3404db2935653692 \ - --hash=sha256:1daff0494820c69bc8941e407aa20f577374ee88364ee10a98fdbe0aece96e29 +pygments==2.17.2 \ + --hash=sha256:b27c2826c47d0f3219f29554824c30c5e8945175d888647acd804ddd04af846c \ + --hash=sha256:da46cec9fd2de5be3a8a784f434e4c4ab670b4ff54d605c4c2717e9d49c4c367 # via rich requests==2.31.0 \ --hash=sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f \ --hash=sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1 - # via - # -r requirements.in - # requests-oauthlib - # tb-nightly -requests-oauthlib==1.3.1 \ - --hash=sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5 \ - --hash=sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a - # via google-auth-oauthlib -rich==13.6.0 \ - --hash=sha256:2b38e2fe9ca72c9a00170a1a2d20c63c790d0e10ef1fe35eba76e1e7b1d7d245 \ - --hash=sha256:5c14d22737e6d5084ef4771b62d5d4363165b403455a30a1c8ca39dc7b644bef + # via -r requirements.in +rich==13.7.0 \ + --hash=sha256:5cb5123b5cf9ee70584244246816e9114227e0b98ad9176eede6ad54bf5403fa \ + --hash=sha256:6da14c108c4866ee9520bbffa71f6fe3962e193b7da68720583850cd4548e235 # via keras-nightly -rsa==4.9 \ - --hash=sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7 \ - --hash=sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21 - # via google-auth scipy==1.11.3 \ --hash=sha256:00f325434b6424952fbb636506f0567898dca7b0f7654d48f1c382ea338ce9a3 \ --hash=sha256:033c3fd95d55012dd1148b201b72ae854d5086d25e7c316ec9850de4fe776929 \ @@ -526,8 +494,8 @@ six==1.16.0 \ # via # astunparse # tb-nightly -tb-nightly==2.15.0a20231023 \ - --hash=sha256:8990a52985e3296aa18a6825efc017bcded9e2fb2cbcdd2d01f5c62d4fdc9825 +tb-nightly==2.17.0a20240214 \ + --hash=sha256:dbd59bcfd9b028e6199050e36cbb4ee1db4a94e69a7f8c00865f1498112a2b83 # via -r requirements.in tblib==2.0.0 \ --hash=sha256:9100bfa016b047d5b980d66e7efed952fbd20bd85b56110aaf473cb97d18709a \ @@ -542,13 +510,17 @@ termcolor==2.3.0 \ --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \ --hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a # via -r requirements.in +tf-keras-nightly==2.16.0.dev2024021410 \ + --hash=sha256:05a4c19c6a795ec9ed6f4dac69d443601b9206b910524d8b19bad6e362d49f47 \ + --hash=sha256:83d5a1dd3979b42164a2fac4bdd2ed5981ef0c2a0514ca64e4bc67e369caef5c + # via tb-nightly typing-extensions==4.8.0 \ --hash=sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0 \ --hash=sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef # via -r requirements.in -urllib3==2.0.7 \ - --hash=sha256:c97dfde1f7bd43a71c8d2a58e369e9b2bf692d1334ea9f9cae55add7d0dd0f84 \ - --hash=sha256:fdb6d215c776278489906c2f8916e6e7d4f5a9b602ccbcfdf7f016fc8da0596e +urllib3==2.2.0 \ + --hash=sha256:051d961ad0c62a94e50ecf1af379c3aba230c66c710493493560c0c223c49f20 \ + --hash=sha256:ce3711610ddce217e6d113a2732fafad960a03fd0318c91faa79481e35c11224 # via requests werkzeug==3.0.1 \ --hash=sha256:507e811ecea72b18a404947aded4b3390e1db8f826b494d76550ef45bb3b1dcc \ @@ -560,81 +532,77 @@ wheel==0.41.3 \ # via # -r requirements.in # astunparse -wrapt==1.14.1 \ - --hash=sha256:00b6d4ea20a906c0ca56d84f93065b398ab74b927a7a3dbd470f6fc503f95dc3 \ - --hash=sha256:01c205616a89d09827986bc4e859bcabd64f5a0662a7fe95e0d359424e0e071b \ - --hash=sha256:02b41b633c6261feff8ddd8d11c711df6842aba629fdd3da10249a53211a72c4 \ - --hash=sha256:07f7a7d0f388028b2df1d916e94bbb40624c59b48ecc6cbc232546706fac74c2 \ - --hash=sha256:11871514607b15cfeb87c547a49bca19fde402f32e2b1c24a632506c0a756656 \ - --hash=sha256:1b376b3f4896e7930f1f772ac4b064ac12598d1c38d04907e696cc4d794b43d3 \ - --hash=sha256:2020f391008ef874c6d9e208b24f28e31bcb85ccff4f335f15a3251d222b92d9 \ - --hash=sha256:21ac0156c4b089b330b7666db40feee30a5d52634cc4560e1905d6529a3897ff \ - --hash=sha256:240b1686f38ae665d1b15475966fe0472f78e71b1b4903c143a842659c8e4cb9 \ - --hash=sha256:257fd78c513e0fb5cdbe058c27a0624c9884e735bbd131935fd49e9fe719d310 \ - --hash=sha256:26046cd03936ae745a502abf44dac702a5e6880b2b01c29aea8ddf3353b68224 \ - --hash=sha256:2b39d38039a1fdad98c87279b48bc5dce2c0ca0d73483b12cb72aa9609278e8a \ - --hash=sha256:2cf71233a0ed05ccdabe209c606fe0bac7379fdcf687f39b944420d2a09fdb57 \ - --hash=sha256:2fe803deacd09a233e4762a1adcea5db5d31e6be577a43352936179d14d90069 \ - --hash=sha256:2feecf86e1f7a86517cab34ae6c2f081fd2d0dac860cb0c0ded96d799d20b335 \ - --hash=sha256:3232822c7d98d23895ccc443bbdf57c7412c5a65996c30442ebe6ed3df335383 \ - --hash=sha256:34aa51c45f28ba7f12accd624225e2b1e5a3a45206aa191f6f9aac931d9d56fe \ - --hash=sha256:358fe87cc899c6bb0ddc185bf3dbfa4ba646f05b1b0b9b5a27c2cb92c2cea204 \ - --hash=sha256:36f582d0c6bc99d5f39cd3ac2a9062e57f3cf606ade29a0a0d6b323462f4dd87 \ - --hash=sha256:380a85cf89e0e69b7cfbe2ea9f765f004ff419f34194018a6827ac0e3edfed4d \ - --hash=sha256:40e7bc81c9e2b2734ea4bc1aceb8a8f0ceaac7c5299bc5d69e37c44d9081d43b \ - --hash=sha256:43ca3bbbe97af00f49efb06e352eae40434ca9d915906f77def219b88e85d907 \ - --hash=sha256:49ef582b7a1152ae2766557f0550a9fcbf7bbd76f43fbdc94dd3bf07cc7168be \ - --hash=sha256:4fcc4649dc762cddacd193e6b55bc02edca674067f5f98166d7713b193932b7f \ - --hash=sha256:5a0f54ce2c092aaf439813735584b9537cad479575a09892b8352fea5e988dc0 \ - --hash=sha256:5a9a0d155deafd9448baff28c08e150d9b24ff010e899311ddd63c45c2445e28 \ - --hash=sha256:5b02d65b9ccf0ef6c34cba6cf5bf2aab1bb2f49c6090bafeecc9cd81ad4ea1c1 \ - --hash=sha256:60db23fa423575eeb65ea430cee741acb7c26a1365d103f7b0f6ec412b893853 \ - --hash=sha256:642c2e7a804fcf18c222e1060df25fc210b9c58db7c91416fb055897fc27e8cc \ - --hash=sha256:6447e9f3ba72f8e2b985a1da758767698efa72723d5b59accefd716e9e8272bf \ - --hash=sha256:6a9a25751acb379b466ff6be78a315e2b439d4c94c1e99cb7266d40a537995d3 \ - --hash=sha256:6b1a564e6cb69922c7fe3a678b9f9a3c54e72b469875aa8018f18b4d1dd1adf3 \ - --hash=sha256:6d323e1554b3d22cfc03cd3243b5bb815a51f5249fdcbb86fda4bf62bab9e164 \ - --hash=sha256:6e743de5e9c3d1b7185870f480587b75b1cb604832e380d64f9504a0535912d1 \ - --hash=sha256:709fe01086a55cf79d20f741f39325018f4df051ef39fe921b1ebe780a66184c \ - --hash=sha256:7b7c050ae976e286906dd3f26009e117eb000fb2cf3533398c5ad9ccc86867b1 \ - --hash=sha256:7d2872609603cb35ca513d7404a94d6d608fc13211563571117046c9d2bcc3d7 \ - --hash=sha256:7ef58fb89674095bfc57c4069e95d7a31cfdc0939e2a579882ac7d55aadfd2a1 \ - --hash=sha256:80bb5c256f1415f747011dc3604b59bc1f91c6e7150bd7db03b19170ee06b320 \ - --hash=sha256:81b19725065dcb43df02b37e03278c011a09e49757287dca60c5aecdd5a0b8ed \ - --hash=sha256:833b58d5d0b7e5b9832869f039203389ac7cbf01765639c7309fd50ef619e0b1 \ - --hash=sha256:88bd7b6bd70a5b6803c1abf6bca012f7ed963e58c68d76ee20b9d751c74a3248 \ - --hash=sha256:8ad85f7f4e20964db4daadcab70b47ab05c7c1cf2a7c1e51087bfaa83831854c \ - --hash=sha256:8c0ce1e99116d5ab21355d8ebe53d9460366704ea38ae4d9f6933188f327b456 \ - --hash=sha256:8d649d616e5c6a678b26d15ece345354f7c2286acd6db868e65fcc5ff7c24a77 \ - --hash=sha256:903500616422a40a98a5a3c4ff4ed9d0066f3b4c951fa286018ecdf0750194ef \ - --hash=sha256:9736af4641846491aedb3c3f56b9bc5568d92b0692303b5a305301a95dfd38b1 \ - --hash=sha256:988635d122aaf2bdcef9e795435662bcd65b02f4f4c1ae37fbee7401c440b3a7 \ - --hash=sha256:9cca3c2cdadb362116235fdbd411735de4328c61425b0aa9f872fd76d02c4e86 \ - --hash=sha256:9e0fd32e0148dd5dea6af5fee42beb949098564cc23211a88d799e434255a1f4 \ - --hash=sha256:9f3e6f9e05148ff90002b884fbc2a86bd303ae847e472f44ecc06c2cd2fcdb2d \ - --hash=sha256:a85d2b46be66a71bedde836d9e41859879cc54a2a04fad1191eb50c2066f6e9d \ - --hash=sha256:a9008dad07d71f68487c91e96579c8567c98ca4c3881b9b113bc7b33e9fd78b8 \ - --hash=sha256:a9a52172be0b5aae932bef82a79ec0a0ce87288c7d132946d645eba03f0ad8a8 \ - --hash=sha256:aa31fdcc33fef9eb2552cbcbfee7773d5a6792c137b359e82879c101e98584c5 \ - --hash=sha256:acae32e13a4153809db37405f5eba5bac5fbe2e2ba61ab227926a22901051c0a \ - --hash=sha256:b014c23646a467558be7da3d6b9fa409b2c567d2110599b7cf9a0c5992b3b471 \ - --hash=sha256:b21bb4c09ffabfa0e85e3a6b623e19b80e7acd709b9f91452b8297ace2a8ab00 \ - --hash=sha256:b5901a312f4d14c59918c221323068fad0540e34324925c8475263841dbdfe68 \ - --hash=sha256:b9b7a708dd92306328117d8c4b62e2194d00c365f18eff11a9b53c6f923b01e3 \ - --hash=sha256:d1967f46ea8f2db647c786e78d8cc7e4313dbd1b0aca360592d8027b8508e24d \ - --hash=sha256:d52a25136894c63de15a35bc0bdc5adb4b0e173b9c0d07a2be9d3ca64a332735 \ - --hash=sha256:d77c85fedff92cf788face9bfa3ebaa364448ebb1d765302e9af11bf449ca36d \ - --hash=sha256:d79d7d5dc8a32b7093e81e97dad755127ff77bcc899e845f41bf71747af0c569 \ - --hash=sha256:dbcda74c67263139358f4d188ae5faae95c30929281bc6866d00573783c422b7 \ - --hash=sha256:ddaea91abf8b0d13443f6dac52e89051a5063c7d014710dcb4d4abb2ff811a59 \ - --hash=sha256:dee0ce50c6a2dd9056c20db781e9c1cfd33e77d2d569f5d1d9321c641bb903d5 \ - --hash=sha256:dee60e1de1898bde3b238f18340eec6148986da0455d8ba7848d50470a7a32fb \ - --hash=sha256:e2f83e18fe2f4c9e7db597e988f72712c0c3676d337d8b101f6758107c42425b \ - --hash=sha256:e3fb1677c720409d5f671e39bac6c9e0e422584e5f518bfd50aa4cbbea02433f \ - --hash=sha256:ecee4132c6cd2ce5308e21672015ddfed1ff975ad0ac8d27168ea82e71413f55 \ - --hash=sha256:ee2b1b1769f6707a8a445162ea16dddf74285c3964f605877a20e38545c3c462 \ - --hash=sha256:ee6acae74a2b91865910eef5e7de37dc6895ad96fa23603d1d27ea69df545015 \ - --hash=sha256:ef3f72c9666bba2bab70d2a8b79f2c6d2c1a42a7f7e2b0ec83bb2f9e383950af +wrapt==1.16.0 \ + --hash=sha256:0d2691979e93d06a95a26257adb7bfd0c93818e89b1406f5a28f36e0d8c1e1fc \ + --hash=sha256:14d7dc606219cdd7405133c713f2c218d4252f2a469003f8c46bb92d5d095d81 \ + --hash=sha256:1a5db485fe2de4403f13fafdc231b0dbae5eca4359232d2efc79025527375b09 \ + --hash=sha256:1acd723ee2a8826f3d53910255643e33673e1d11db84ce5880675954183ec47e \ + --hash=sha256:1ca9b6085e4f866bd584fb135a041bfc32cab916e69f714a7d1d397f8c4891ca \ + --hash=sha256:1dd50a2696ff89f57bd8847647a1c363b687d3d796dc30d4dd4a9d1689a706f0 \ + --hash=sha256:2076fad65c6736184e77d7d4729b63a6d1ae0b70da4868adeec40989858eb3fb \ + --hash=sha256:2a88e6010048489cda82b1326889ec075a8c856c2e6a256072b28eaee3ccf487 \ + --hash=sha256:3ebf019be5c09d400cf7b024aa52b1f3aeebeff51550d007e92c3c1c4afc2a40 \ + --hash=sha256:418abb18146475c310d7a6dc71143d6f7adec5b004ac9ce08dc7a34e2babdc5c \ + --hash=sha256:43aa59eadec7890d9958748db829df269f0368521ba6dc68cc172d5d03ed8060 \ + --hash=sha256:44a2754372e32ab315734c6c73b24351d06e77ffff6ae27d2ecf14cf3d229202 \ + --hash=sha256:490b0ee15c1a55be9c1bd8609b8cecd60e325f0575fc98f50058eae366e01f41 \ + --hash=sha256:49aac49dc4782cb04f58986e81ea0b4768e4ff197b57324dcbd7699c5dfb40b9 \ + --hash=sha256:5eb404d89131ec9b4f748fa5cfb5346802e5ee8836f57d516576e61f304f3b7b \ + --hash=sha256:5f15814a33e42b04e3de432e573aa557f9f0f56458745c2074952f564c50e664 \ + --hash=sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d \ + --hash=sha256:66027d667efe95cc4fa945af59f92c5a02c6f5bb6012bff9e60542c74c75c362 \ + --hash=sha256:66dfbaa7cfa3eb707bbfcd46dab2bc6207b005cbc9caa2199bcbc81d95071a00 \ + --hash=sha256:685f568fa5e627e93f3b52fda002c7ed2fa1800b50ce51f6ed1d572d8ab3e7fc \ + --hash=sha256:6906c4100a8fcbf2fa735f6059214bb13b97f75b1a61777fcf6432121ef12ef1 \ + --hash=sha256:6a42cd0cfa8ffc1915aef79cb4284f6383d8a3e9dcca70c445dcfdd639d51267 \ + --hash=sha256:6dcfcffe73710be01d90cae08c3e548d90932d37b39ef83969ae135d36ef3956 \ + --hash=sha256:6f6eac2360f2d543cc875a0e5efd413b6cbd483cb3ad7ebf888884a6e0d2e966 \ + --hash=sha256:72554a23c78a8e7aa02abbd699d129eead8b147a23c56e08d08dfc29cfdddca1 \ + --hash=sha256:73870c364c11f03ed072dda68ff7aea6d2a3a5c3fe250d917a429c7432e15228 \ + --hash=sha256:73aa7d98215d39b8455f103de64391cb79dfcad601701a3aa0dddacf74911d72 \ + --hash=sha256:75ea7d0ee2a15733684badb16de6794894ed9c55aa5e9903260922f0482e687d \ + --hash=sha256:7bd2d7ff69a2cac767fbf7a2b206add2e9a210e57947dd7ce03e25d03d2de292 \ + --hash=sha256:807cc8543a477ab7422f1120a217054f958a66ef7314f76dd9e77d3f02cdccd0 \ + --hash=sha256:8e9723528b9f787dc59168369e42ae1c3b0d3fadb2f1a71de14531d321ee05b0 \ + --hash=sha256:9090c9e676d5236a6948330e83cb89969f433b1943a558968f659ead07cb3b36 \ + --hash=sha256:9153ed35fc5e4fa3b2fe97bddaa7cbec0ed22412b85bcdaf54aeba92ea37428c \ + --hash=sha256:9159485323798c8dc530a224bd3ffcf76659319ccc7bbd52e01e73bd0241a0c5 \ + --hash=sha256:941988b89b4fd6b41c3f0bfb20e92bd23746579736b7343283297c4c8cbae68f \ + --hash=sha256:94265b00870aa407bd0cbcfd536f17ecde43b94fb8d228560a1e9d3041462d73 \ + --hash=sha256:98b5e1f498a8ca1858a1cdbffb023bfd954da4e3fa2c0cb5853d40014557248b \ + --hash=sha256:9b201ae332c3637a42f02d1045e1d0cccfdc41f1f2f801dafbaa7e9b4797bfc2 \ + --hash=sha256:a0ea261ce52b5952bf669684a251a66df239ec6d441ccb59ec7afa882265d593 \ + --hash=sha256:a33a747400b94b6d6b8a165e4480264a64a78c8a4c734b62136062e9a248dd39 \ + --hash=sha256:a452f9ca3e3267cd4d0fcf2edd0d035b1934ac2bd7e0e57ac91ad6b95c0c6389 \ + --hash=sha256:a86373cf37cd7764f2201b76496aba58a52e76dedfaa698ef9e9688bfd9e41cf \ + --hash=sha256:ac83a914ebaf589b69f7d0a1277602ff494e21f4c2f743313414378f8f50a4cf \ + --hash=sha256:aefbc4cb0a54f91af643660a0a150ce2c090d3652cf4052a5397fb2de549cd89 \ + --hash=sha256:b3646eefa23daeba62643a58aac816945cadc0afaf21800a1421eeba5f6cfb9c \ + --hash=sha256:b47cfad9e9bbbed2339081f4e346c93ecd7ab504299403320bf85f7f85c7d46c \ + --hash=sha256:b935ae30c6e7400022b50f8d359c03ed233d45b725cfdd299462f41ee5ffba6f \ + --hash=sha256:bb2dee3874a500de01c93d5c71415fcaef1d858370d405824783e7a8ef5db440 \ + --hash=sha256:bc57efac2da352a51cc4658878a68d2b1b67dbe9d33c36cb826ca449d80a8465 \ + --hash=sha256:bf5703fdeb350e36885f2875d853ce13172ae281c56e509f4e6eca049bdfb136 \ + --hash=sha256:c31f72b1b6624c9d863fc095da460802f43a7c6868c5dda140f51da24fd47d7b \ + --hash=sha256:c5cd603b575ebceca7da5a3a251e69561bec509e0b46e4993e1cac402b7247b8 \ + --hash=sha256:d2efee35b4b0a347e0d99d28e884dfd82797852d62fcd7ebdeee26f3ceb72cf3 \ + --hash=sha256:d462f28826f4657968ae51d2181a074dfe03c200d6131690b7d65d55b0f360f8 \ + --hash=sha256:d5e49454f19ef621089e204f862388d29e6e8d8b162efce05208913dde5b9ad6 \ + --hash=sha256:da4813f751142436b075ed7aa012a8778aa43a99f7b36afe9b742d3ed8bdc95e \ + --hash=sha256:db2e408d983b0e61e238cf579c09ef7020560441906ca990fe8412153e3b291f \ + --hash=sha256:db98ad84a55eb09b3c32a96c576476777e87c520a34e2519d3e59c44710c002c \ + --hash=sha256:dbed418ba5c3dce92619656802cc5355cb679e58d0d89b50f116e4a9d5a9603e \ + --hash=sha256:dcdba5c86e368442528f7060039eda390cc4091bfd1dca41e8046af7c910dda8 \ + --hash=sha256:decbfa2f618fa8ed81c95ee18a387ff973143c656ef800c9f24fb7e9c16054e2 \ + --hash=sha256:e4fdb9275308292e880dcbeb12546df7f3e0f96c6b41197e0cf37d2826359020 \ + --hash=sha256:eb1b046be06b0fce7249f1d025cd359b4b80fc1c3e24ad9eca33e0dcdb2e4a35 \ + --hash=sha256:eb6e651000a19c96f452c85132811d25e9264d836951022d6e81df2fff38337d \ + --hash=sha256:ed867c42c268f876097248e05b6117a65bcd1e63b779e916fe2e33cd6fd0d3c3 \ + --hash=sha256:edfad1d29c73f9b863ebe7082ae9321374ccb10879eeabc84ba3b69f2579d537 \ + --hash=sha256:f2058f813d4f2b5e3a9eb2eb3faf8f1d99b81c3e51aeda4b168406443e8ba809 \ + --hash=sha256:f6b2d0c6703c988d334f297aa5df18c45e97b0af3679bb75059e0e0bd8b1069d \ + --hash=sha256:f8212564d49c50eb4565e502814f694e240c55551a5f1bc841d4fcaabb0a9b8a \ + --hash=sha256:ffa565331890b90056c01db69c0fe634a776f8019c143a5ae265f9c6bc4bd6d4 # via -r requirements.in setuptools==68.2.2 \ diff --git a/requirements_lock_3_11.txt b/requirements_lock_3_11.txt index 434b38d603df80..2335d295d0faf6 100644 --- a/requirements_lock_3_11.txt +++ b/requirements_lock_3_11.txt @@ -1,6 +1,6 @@ -absl-py==2.0.0 \ - --hash=sha256:9a28abb62774ae4e8edbe2dd4c49ffcd45a6a848952a5eccc6a49f3f0fc1e2f3 \ - --hash=sha256:d9690211c5fcfefcdd1a45470ac2b5c5acd45241c3af71eed96bc5441746c0d5 +absl-py==2.1.0 \ + --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ + --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff # via # keras-nightly # tb-nightly @@ -12,105 +12,101 @@ astunparse==1.6.3 \ --hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \ --hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8 # via -r requirements.in -cachetools==5.3.2 \ - --hash=sha256:086ee420196f7b2ab9ca2db2520aca326318b68fe5ba8bc4d49cca91add450f2 \ - --hash=sha256:861f35a13a451f94e301ce2bec7cac63e881232ccce7ed67fab9b5df4d3beaa1 - # via google-auth -certifi==2023.7.22 \ - --hash=sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082 \ - --hash=sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9 +certifi==2024.2.2 \ + --hash=sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f \ + --hash=sha256:dc383c07b76109f368f6106eee2b593b04a011ea4d55f652c6ca24a754d1cdd1 # via requests -charset-normalizer==3.3.1 \ - --hash=sha256:06cf46bdff72f58645434d467bf5228080801298fbba19fe268a01b4534467f5 \ - --hash=sha256:0c8c61fb505c7dad1d251c284e712d4e0372cef3b067f7ddf82a7fa82e1e9a93 \ - --hash=sha256:10b8dd31e10f32410751b3430996f9807fc4d1587ca69772e2aa940a82ab571a \ - --hash=sha256:1171ef1fc5ab4693c5d151ae0fdad7f7349920eabbaca6271f95969fa0756c2d \ - --hash=sha256:17a866d61259c7de1bdadef418a37755050ddb4b922df8b356503234fff7932c \ - --hash=sha256:1d6bfc32a68bc0933819cfdfe45f9abc3cae3877e1d90aac7259d57e6e0f85b1 \ - --hash=sha256:1ec937546cad86d0dce5396748bf392bb7b62a9eeb8c66efac60e947697f0e58 \ - --hash=sha256:223b4d54561c01048f657fa6ce41461d5ad8ff128b9678cfe8b2ecd951e3f8a2 \ - --hash=sha256:2465aa50c9299d615d757c1c888bc6fef384b7c4aec81c05a0172b4400f98557 \ - --hash=sha256:28f512b9a33235545fbbdac6a330a510b63be278a50071a336afc1b78781b147 \ - --hash=sha256:2c092be3885a1b7899cd85ce24acedc1034199d6fca1483fa2c3a35c86e43041 \ - --hash=sha256:2c4c99f98fc3a1835af8179dcc9013f93594d0670e2fa80c83aa36346ee763d2 \ - --hash=sha256:31445f38053476a0c4e6d12b047b08ced81e2c7c712e5a1ad97bc913256f91b2 \ - --hash=sha256:31bbaba7218904d2eabecf4feec0d07469284e952a27400f23b6628439439fa7 \ - --hash=sha256:34d95638ff3613849f473afc33f65c401a89f3b9528d0d213c7037c398a51296 \ - --hash=sha256:352a88c3df0d1fa886562384b86f9a9e27563d4704ee0e9d56ec6fcd270ea690 \ - --hash=sha256:39b70a6f88eebe239fa775190796d55a33cfb6d36b9ffdd37843f7c4c1b5dc67 \ - --hash=sha256:3c66df3f41abee950d6638adc7eac4730a306b022570f71dd0bd6ba53503ab57 \ - --hash=sha256:3f70fd716855cd3b855316b226a1ac8bdb3caf4f7ea96edcccc6f484217c9597 \ - --hash=sha256:3f9bc2ce123637a60ebe819f9fccc614da1bcc05798bbbaf2dd4ec91f3e08846 \ - --hash=sha256:3fb765362688821404ad6cf86772fc54993ec11577cd5a92ac44b4c2ba52155b \ - --hash=sha256:45f053a0ece92c734d874861ffe6e3cc92150e32136dd59ab1fb070575189c97 \ - --hash=sha256:46fb9970aa5eeca547d7aa0de5d4b124a288b42eaefac677bde805013c95725c \ - --hash=sha256:4cb50a0335382aac15c31b61d8531bc9bb657cfd848b1d7158009472189f3d62 \ - --hash=sha256:4e12f8ee80aa35e746230a2af83e81bd6b52daa92a8afaef4fea4a2ce9b9f4fa \ - --hash=sha256:4f3100d86dcd03c03f7e9c3fdb23d92e32abbca07e7c13ebd7ddfbcb06f5991f \ - --hash=sha256:4f6e2a839f83a6a76854d12dbebde50e4b1afa63e27761549d006fa53e9aa80e \ - --hash=sha256:4f861d94c2a450b974b86093c6c027888627b8082f1299dfd5a4bae8e2292821 \ - --hash=sha256:501adc5eb6cd5f40a6f77fbd90e5ab915c8fd6e8c614af2db5561e16c600d6f3 \ - --hash=sha256:520b7a142d2524f999447b3a0cf95115df81c4f33003c51a6ab637cbda9d0bf4 \ - --hash=sha256:548eefad783ed787b38cb6f9a574bd8664468cc76d1538215d510a3cd41406cb \ - --hash=sha256:555fe186da0068d3354cdf4bbcbc609b0ecae4d04c921cc13e209eece7720727 \ - --hash=sha256:55602981b2dbf8184c098bc10287e8c245e351cd4fdcad050bd7199d5a8bf514 \ - --hash=sha256:58e875eb7016fd014c0eea46c6fa92b87b62c0cb31b9feae25cbbe62c919f54d \ - --hash=sha256:5a3580a4fdc4ac05f9e53c57f965e3594b2f99796231380adb2baaab96e22761 \ - --hash=sha256:5b70bab78accbc672f50e878a5b73ca692f45f5b5e25c8066d748c09405e6a55 \ - --hash=sha256:5ceca5876032362ae73b83347be8b5dbd2d1faf3358deb38c9c88776779b2e2f \ - --hash=sha256:61f1e3fb621f5420523abb71f5771a204b33c21d31e7d9d86881b2cffe92c47c \ - --hash=sha256:633968254f8d421e70f91c6ebe71ed0ab140220469cf87a9857e21c16687c034 \ - --hash=sha256:63a6f59e2d01310f754c270e4a257426fe5a591dc487f1983b3bbe793cf6bac6 \ - --hash=sha256:63accd11149c0f9a99e3bc095bbdb5a464862d77a7e309ad5938fbc8721235ae \ - --hash=sha256:6db3cfb9b4fcecb4390db154e75b49578c87a3b9979b40cdf90d7e4b945656e1 \ - --hash=sha256:71ef3b9be10070360f289aea4838c784f8b851be3ba58cf796262b57775c2f14 \ - --hash=sha256:7ae8e5142dcc7a49168f4055255dbcced01dc1714a90a21f87448dc8d90617d1 \ - --hash=sha256:7b6cefa579e1237ce198619b76eaa148b71894fb0d6bcf9024460f9bf30fd228 \ - --hash=sha256:800561453acdecedaac137bf09cd719c7a440b6800ec182f077bb8e7025fb708 \ - --hash=sha256:82ca51ff0fc5b641a2d4e1cc8c5ff108699b7a56d7f3ad6f6da9dbb6f0145b48 \ - --hash=sha256:851cf693fb3aaef71031237cd68699dded198657ec1e76a76eb8be58c03a5d1f \ - --hash=sha256:854cc74367180beb327ab9d00f964f6d91da06450b0855cbbb09187bcdb02de5 \ - --hash=sha256:87071618d3d8ec8b186d53cb6e66955ef2a0e4fa63ccd3709c0c90ac5a43520f \ - --hash=sha256:871d045d6ccc181fd863a3cd66ee8e395523ebfbc57f85f91f035f50cee8e3d4 \ - --hash=sha256:8aee051c89e13565c6bd366813c386939f8e928af93c29fda4af86d25b73d8f8 \ - --hash=sha256:8af5a8917b8af42295e86b64903156b4f110a30dca5f3b5aedea123fbd638bff \ - --hash=sha256:8ec8ef42c6cd5856a7613dcd1eaf21e5573b2185263d87d27c8edcae33b62a61 \ - --hash=sha256:91e43805ccafa0a91831f9cd5443aa34528c0c3f2cc48c4cb3d9a7721053874b \ - --hash=sha256:9505dc359edb6a330efcd2be825fdb73ee3e628d9010597aa1aee5aa63442e97 \ - --hash=sha256:985c7965f62f6f32bf432e2681173db41336a9c2611693247069288bcb0c7f8b \ - --hash=sha256:9a74041ba0bfa9bc9b9bb2cd3238a6ab3b7618e759b41bd15b5f6ad958d17605 \ - --hash=sha256:9edbe6a5bf8b56a4a84533ba2b2f489d0046e755c29616ef8830f9e7d9cf5728 \ - --hash=sha256:a15c1fe6d26e83fd2e5972425a772cca158eae58b05d4a25a4e474c221053e2d \ - --hash=sha256:a66bcdf19c1a523e41b8e9d53d0cedbfbac2e93c649a2e9502cb26c014d0980c \ - --hash=sha256:ae4070f741f8d809075ef697877fd350ecf0b7c5837ed68738607ee0a2c572cf \ - --hash=sha256:ae55d592b02c4349525b6ed8f74c692509e5adffa842e582c0f861751701a673 \ - --hash=sha256:b578cbe580e3b41ad17b1c428f382c814b32a6ce90f2d8e39e2e635d49e498d1 \ - --hash=sha256:b891a2f68e09c5ef989007fac11476ed33c5c9994449a4e2c3386529d703dc8b \ - --hash=sha256:baec8148d6b8bd5cee1ae138ba658c71f5b03e0d69d5907703e3e1df96db5e41 \ - --hash=sha256:bb06098d019766ca16fc915ecaa455c1f1cd594204e7f840cd6258237b5079a8 \ - --hash=sha256:bc791ec3fd0c4309a753f95bb6c749ef0d8ea3aea91f07ee1cf06b7b02118f2f \ - --hash=sha256:bd28b31730f0e982ace8663d108e01199098432a30a4c410d06fe08fdb9e93f4 \ - --hash=sha256:be4d9c2770044a59715eb57c1144dedea7c5d5ae80c68fb9959515037cde2008 \ - --hash=sha256:c0c72d34e7de5604df0fde3644cc079feee5e55464967d10b24b1de268deceb9 \ - --hash=sha256:c0e842112fe3f1a4ffcf64b06dc4c61a88441c2f02f373367f7b4c1aa9be2ad5 \ - --hash=sha256:c15070ebf11b8b7fd1bfff7217e9324963c82dbdf6182ff7050519e350e7ad9f \ - --hash=sha256:c2000c54c395d9e5e44c99dc7c20a64dc371f777faf8bae4919ad3e99ce5253e \ - --hash=sha256:c30187840d36d0ba2893bc3271a36a517a717f9fd383a98e2697ee890a37c273 \ - --hash=sha256:cb7cd68814308aade9d0c93c5bd2ade9f9441666f8ba5aa9c2d4b389cb5e2a45 \ - --hash=sha256:cd805513198304026bd379d1d516afbf6c3c13f4382134a2c526b8b854da1c2e \ - --hash=sha256:d0bf89afcbcf4d1bb2652f6580e5e55a840fdf87384f6063c4a4f0c95e378656 \ - --hash=sha256:d9137a876020661972ca6eec0766d81aef8a5627df628b664b234b73396e727e \ - --hash=sha256:dbd95e300367aa0827496fe75a1766d198d34385a58f97683fe6e07f89ca3e3c \ - --hash=sha256:dced27917823df984fe0c80a5c4ad75cf58df0fbfae890bc08004cd3888922a2 \ - --hash=sha256:de0b4caa1c8a21394e8ce971997614a17648f94e1cd0640fbd6b4d14cab13a72 \ - --hash=sha256:debb633f3f7856f95ad957d9b9c781f8e2c6303ef21724ec94bea2ce2fcbd056 \ - --hash=sha256:e372d7dfd154009142631de2d316adad3cc1c36c32a38b16a4751ba78da2a397 \ - --hash=sha256:ecd26be9f112c4f96718290c10f4caea6cc798459a3a76636b817a0ed7874e42 \ - --hash=sha256:edc0202099ea1d82844316604e17d2b175044f9bcb6b398aab781eba957224bd \ - --hash=sha256:f194cce575e59ffe442c10a360182a986535fd90b57f7debfaa5c845c409ecc3 \ - --hash=sha256:f5fb672c396d826ca16a022ac04c9dce74e00a1c344f6ad1a0fdc1ba1f332213 \ - --hash=sha256:f6a02a3c7950cafaadcd46a226ad9e12fc9744652cc69f9e5534f98b47f3bbcf \ - --hash=sha256:fe81b35c33772e56f4b6cf62cf4aedc1762ef7162a31e6ac7fe5e40d0149eb67 +charset-normalizer==3.3.2 \ + --hash=sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027 \ + --hash=sha256:06a81e93cd441c56a9b65d8e1d043daeb97a3d0856d177d5c90ba85acb3db087 \ + --hash=sha256:0a55554a2fa0d408816b3b5cedf0045f4b8e1a6065aec45849de2d6f3f8e9786 \ + --hash=sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8 \ + --hash=sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09 \ + --hash=sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185 \ + --hash=sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574 \ + --hash=sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e \ + --hash=sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519 \ + --hash=sha256:2127566c664442652f024c837091890cb1942c30937add288223dc895793f898 \ + --hash=sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269 \ + --hash=sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3 \ + --hash=sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f \ + --hash=sha256:3287761bc4ee9e33561a7e058c72ac0938c4f57fe49a09eae428fd88aafe7bb6 \ + --hash=sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8 \ + --hash=sha256:37e55c8e51c236f95b033f6fb391d7d7970ba5fe7ff453dad675e88cf303377a \ + --hash=sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73 \ + --hash=sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc \ + --hash=sha256:42cb296636fcc8b0644486d15c12376cb9fa75443e00fb25de0b8602e64c1714 \ + --hash=sha256:45485e01ff4d3630ec0d9617310448a8702f70e9c01906b0d0118bdf9d124cf2 \ + --hash=sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc \ + --hash=sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce \ + --hash=sha256:4d0d1650369165a14e14e1e47b372cfcb31d6ab44e6e33cb2d4e57265290044d \ + --hash=sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e \ + --hash=sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6 \ + --hash=sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269 \ + --hash=sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96 \ + --hash=sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d \ + --hash=sha256:6463effa3186ea09411d50efc7d85360b38d5f09b870c48e4600f63af490e56a \ + --hash=sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4 \ + --hash=sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77 \ + --hash=sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d \ + --hash=sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0 \ + --hash=sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed \ + --hash=sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068 \ + --hash=sha256:6c4caeef8fa63d06bd437cd4bdcf3ffefe6738fb1b25951440d80dc7df8c03ac \ + --hash=sha256:6ef1d82a3af9d3eecdba2321dc1b3c238245d890843e040e41e470ffa64c3e25 \ + --hash=sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8 \ + --hash=sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab \ + --hash=sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26 \ + --hash=sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2 \ + --hash=sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db \ + --hash=sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f \ + --hash=sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5 \ + --hash=sha256:86216b5cee4b06df986d214f664305142d9c76df9b6512be2738aa72a2048f99 \ + --hash=sha256:87d1351268731db79e0f8e745d92493ee2841c974128ef629dc518b937d9194c \ + --hash=sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d \ + --hash=sha256:8c622a5fe39a48f78944a87d4fb8a53ee07344641b0562c540d840748571b811 \ + --hash=sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa \ + --hash=sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a \ + --hash=sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03 \ + --hash=sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b \ + --hash=sha256:923c0c831b7cfcb071580d3f46c4baf50f174be571576556269530f4bbd79d04 \ + --hash=sha256:95f2a5796329323b8f0512e09dbb7a1860c46a39da62ecb2324f116fa8fdc85c \ + --hash=sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001 \ + --hash=sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458 \ + --hash=sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389 \ + --hash=sha256:a50aebfa173e157099939b17f18600f72f84eed3049e743b68ad15bd69b6bf99 \ + --hash=sha256:a981a536974bbc7a512cf44ed14938cf01030a99e9b3a06dd59578882f06f985 \ + --hash=sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537 \ + --hash=sha256:ae5f4161f18c61806f411a13b0310bea87f987c7d2ecdbdaad0e94eb2e404238 \ + --hash=sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f \ + --hash=sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d \ + --hash=sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796 \ + --hash=sha256:b2b0a0c0517616b6869869f8c581d4eb2dd83a4d79e0ebcb7d373ef9956aeb0a \ + --hash=sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143 \ + --hash=sha256:bd8f7df7d12c2db9fab40bdd87a7c09b1530128315d047a086fa3ae3435cb3a8 \ + --hash=sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c \ + --hash=sha256:c002b4ffc0be611f0d9da932eb0f704fe2602a9a949d1f738e4c34c75b0863d5 \ + --hash=sha256:c083af607d2515612056a31f0a8d9e0fcb5876b7bfc0abad3ecd275bc4ebc2d5 \ + --hash=sha256:c180f51afb394e165eafe4ac2936a14bee3eb10debc9d9e4db8958fe36afe711 \ + --hash=sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4 \ + --hash=sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6 \ + --hash=sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c \ + --hash=sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7 \ + --hash=sha256:db364eca23f876da6f9e16c9da0df51aa4f104a972735574842618b8c6d999d4 \ + --hash=sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b \ + --hash=sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae \ + --hash=sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12 \ + --hash=sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c \ + --hash=sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae \ + --hash=sha256:eb00ed941194665c332bf8e078baf037d6c35d7c4f3102ea2d4f16ca94a26dc8 \ + --hash=sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887 \ + --hash=sha256:eb8821e09e916165e160797a6c17edda0679379a4be5c716c260e836e122f54b \ + --hash=sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4 \ + --hash=sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f \ + --hash=sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5 \ + --hash=sha256:fb69256e180cb6c8a894fee62b3afebae785babc1ee98b81cdf68bbca1987f33 \ + --hash=sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519 \ + --hash=sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561 # via requests dill==0.3.7 \ --hash=sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e \ @@ -133,6 +129,7 @@ dm-tree==0.1.8 \ --hash=sha256:35cc164a79336bfcfafb47e5f297898359123bbd3330c1967f0c4994f9cf9f60 \ --hash=sha256:378cc8ad93c5fe3590f405a309980721f021c790ca1bdf9b15bb1d59daec57f5 \ --hash=sha256:39070ba268c0491af9fe7a58644d99e8b4f2cde6e5884ba3380bddc84ed43d5f \ + --hash=sha256:435227cf3c5dc63f4de054cf3d00183790bd9ead4c3623138c74dde7f67f521b \ --hash=sha256:5483dca4d7eb1a0d65fe86d3b6a53ae717face83c1f17e0887b1a4a64ae5c410 \ --hash=sha256:694c3654cfd2a81552c08ec66bb5c4a3d48fa292b9a181880fb081c36c5b9134 \ --hash=sha256:803bfc53b4659f447ac694dbd04235f94a73ef7c1fd1e0df7c84ac41e0bc963b \ @@ -140,6 +137,8 @@ dm-tree==0.1.8 \ --hash=sha256:83b7764de0d855338abefc6e3ee9fe40d301668310aa3baea3f778ff051f4393 \ --hash=sha256:8c60a7eadab64c2278861f56bca320b2720f163dca9d7558103c3b77f2416571 \ --hash=sha256:8ed3564abed97c806db122c2d3e1a2b64c74a63debe9903aad795167cc301368 \ + --hash=sha256:94d3f0826311f45ee19b75f5b48c99466e4218a0489e81c0f0167bda50cacf22 \ + --hash=sha256:96a548a406a6fb15fe58f6a30a57ff2f2aafbf25f05afab00c8f5e5977b6c715 \ --hash=sha256:a5d819c38c03f0bb5b3b3703c60e4b170355a0fc6b5819325bf3d4ceb3ae7e80 \ --hash=sha256:ad16ceba90a56ec47cf45b21856d14962ac314787975ef786efb5e6e9ca75ec7 \ --hash=sha256:af4b3d372f2477dcd89a6e717e4a575ca35ccc20cc4454a8a4b6f8838a00672d \ @@ -147,6 +146,7 @@ dm-tree==0.1.8 \ --hash=sha256:b9bd9b9ccb59409d33d51d84b7668010c04c2af7d4a371632874c1ca356cff3d \ --hash=sha256:b9f89a454e98806b44fe9d40ec9eee61f848388f7e79ac2371a55679bd5a3ac6 \ --hash=sha256:bb2d109f42190225112da899b9f3d46d0d5f26aef501c61e43529fe9322530b5 \ + --hash=sha256:c0a94aba18a35457a1b5cd716fd7b46c5dafdc4cf7869b4bae665b91c4682a8e \ --hash=sha256:c5c8c12e3fda754ef6af94161bacdaeda816d941995fac415d6855c6c386af68 \ --hash=sha256:d1612fcaecd79023dbc6a6ae48d51a80beb5c385d6f3f6d71688e57bc8d07de8 \ --hash=sha256:d16e1f2a073604cfcc09f7131ae8d534674f43c3aef4c25742eae295bc60d04f \ @@ -154,6 +154,7 @@ dm-tree==0.1.8 \ --hash=sha256:d40fa4106ca6edc66760246a08f500ec0c85ef55c762fb4a363f6ee739ba02ee \ --hash=sha256:de287fabc464b8734be251e46e06aa9aa1001f34198da2b6ce07bd197172b9cb \ --hash=sha256:e4d714371bb08839e4e5e29024fc95832d9affe129825ef38836b143028bd144 \ + --hash=sha256:ea9e59e0451e7d29aece402d9f908f2e2a80922bcde2ebfd5dcb07750fcbfee8 \ --hash=sha256:f7ac31b9aecccb2c6e1ab29706f6ded3eba0c2c69c770322c9c685929c3d6afb \ --hash=sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d # via keras-nightly @@ -161,71 +162,61 @@ gast==0.4.0 \ --hash=sha256:40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1 \ --hash=sha256:b7adcdd5adbebf1adf17378da5ba3f543684dbec47b1cda1f3997e573cd542c4 # via -r requirements.in -google-auth==2.23.3 \ - --hash=sha256:6864247895eea5d13b9c57c9e03abb49cb94ce2dc7c58e91cba3248c7477c9e3 \ - --hash=sha256:a8f4608e65c244ead9e0538f181a96c6e11199ec114d41f1d7b1bffa96937bda - # via - # google-auth-oauthlib - # tb-nightly -google-auth-oauthlib==1.1.0 \ - --hash=sha256:089c6e587d36f4803ac7e0720c045c6a8b1fd1790088b8424975b90d0ee61c12 \ - --hash=sha256:83ea8c3b0881e453790baff4448e8a6112ac8778d1de9da0b68010b843937afb - # via tb-nightly -grpcio==1.59.2 \ - --hash=sha256:023088764012411affe7db183d1ada3ad9daf2e23ddc719ff46d7061de661340 \ - --hash=sha256:08d77e682f2bf730a4961eea330e56d2f423c6a9b91ca222e5b1eb24a357b19f \ - --hash=sha256:0a4a3833c0e067f3558538727235cd8a49709bff1003200bbdefa2f09334e4b1 \ - --hash=sha256:0a754aff9e3af63bdc4c75c234b86b9d14e14a28a30c4e324aed1a9b873d755f \ - --hash=sha256:11168ef43e4a43ff1b1a65859f3e0ef1a173e277349e7fb16923ff108160a8cd \ - --hash=sha256:128e20f57c5f27cb0157e73756d1586b83c1b513ebecc83ea0ac37e4b0e4e758 \ - --hash=sha256:1f9524d1d701e399462d2c90ba7c193e49d1711cf429c0d3d97c966856e03d00 \ - --hash=sha256:1ff16d68bf453275466a9a46739061a63584d92f18a0f5b33d19fc97eb69867c \ - --hash=sha256:2067274c88bc6de89c278a672a652b4247d088811ece781a4858b09bdf8448e3 \ - --hash=sha256:2171c39f355ba5b551c5d5928d65aa6c69807fae195b86ef4a7d125bcdb860a9 \ - --hash=sha256:242adc47725b9a499ee77c6a2e36688fa6c96484611f33b1be4c57ab075a92dd \ - --hash=sha256:27f879ae604a7fcf371e59fba6f3ff4635a4c2a64768bd83ff0cac503142fef4 \ - --hash=sha256:2b230028a008ae1d0f430acb227d323ff8a619017415cf334c38b457f814119f \ - --hash=sha256:3059668df17627f0e0fa680e9ef8c995c946c792612e9518f5cc1503be14e90b \ - --hash=sha256:31176aa88f36020055ace9adff2405a33c8bdbfa72a9c4980e25d91b2f196873 \ - --hash=sha256:36f53c2b3449c015880e7d55a89c992c357f176327b0d2873cdaaf9628a37c69 \ - --hash=sha256:3b4368b33908f683a363f376dfb747d40af3463a6e5044afee07cf9436addf96 \ - --hash=sha256:3c61d641d4f409c5ae46bfdd89ea42ce5ea233dcf69e74ce9ba32b503c727e29 \ - --hash=sha256:4abb717e320e74959517dc8e84a9f48fbe90e9abe19c248541e9418b1ce60acd \ - --hash=sha256:4c93f4abbb54321ee6471e04a00139c80c754eda51064187963ddf98f5cf36a4 \ - --hash=sha256:535561990e075fa6bd4b16c4c3c1096b9581b7bb35d96fac4650f1181e428268 \ - --hash=sha256:53c9aa5ddd6857c0a1cd0287225a2a25873a8e09727c2e95c4aebb1be83a766a \ - --hash=sha256:5d573e70a6fe77555fb6143c12d3a7d3fa306632a3034b4e7c59ca09721546f8 \ - --hash=sha256:6009386a2df66159f64ac9f20425ae25229b29b9dd0e1d3dd60043f037e2ad7e \ - --hash=sha256:686e975a5d16602dc0982c7c703948d17184bd1397e16c8ee03511ecb8c4cdda \ - --hash=sha256:6959fb07e8351e20501ffb8cc4074c39a0b7ef123e1c850a7f8f3afdc3a3da01 \ - --hash=sha256:6b25ed37c27e652db01be341af93fbcea03d296c024d8a0e680017a268eb85dd \ - --hash=sha256:6da6dea3a1bacf99b3c2187e296db9a83029ed9c38fd4c52b7c9b7326d13c828 \ - --hash=sha256:72ca2399097c0b758198f2ff30f7178d680de8a5cfcf3d9b73a63cf87455532e \ - --hash=sha256:73abb8584b0cf74d37f5ef61c10722adc7275502ab71789a8fe3cb7ef04cf6e2 \ - --hash=sha256:74100fecaec8a535e380cf5f2fb556ff84957d481c13e54051c52e5baac70541 \ - --hash=sha256:75c6ecb70e809cf1504465174343113f51f24bc61e22a80ae1c859f3f7034c6d \ - --hash=sha256:7cf05053242f61ba94014dd3a986e11a083400a32664058f80bf4cf817c0b3a1 \ - --hash=sha256:9411e24328a2302e279e70cae6e479f1fddde79629fcb14e03e6d94b3956eabf \ - --hash=sha256:a213acfbf186b9f35803b52e4ca9addb153fc0b67f82a48f961be7000ecf6721 \ - --hash=sha256:bb7e0fe6ad73b7f06d7e2b689c19a71cf5cc48f0c2bf8608469e51ffe0bd2867 \ - --hash=sha256:c2504eed520958a5b77cc99458297cb7906308cb92327f35fb7fbbad4e9b2188 \ - --hash=sha256:c35aa9657f5d5116d23b934568e0956bd50c615127810fffe3ac356a914c176a \ - --hash=sha256:c5f09cffa619adfb44799fa4a81c2a1ad77c887187613fb0a8f201ab38d89ba1 \ - --hash=sha256:c978f864b35f2261e0819f5cd88b9830b04dc51bcf055aac3c601e525a10d2ba \ - --hash=sha256:cbe946b3e6e60a7b4618f091e62a029cb082b109a9d6b53962dd305087c6e4fd \ - --hash=sha256:cc3e4cd087f07758b16bef8f31d88dbb1b5da5671d2f03685ab52dece3d7a16e \ - --hash=sha256:cf0dead5a2c5a3347af2cfec7131d4f2a2e03c934af28989c9078f8241a491fa \ - --hash=sha256:d2794f0e68b3085d99b4f6ff9c089f6fdd02b32b9d3efdfbb55beac1bf22d516 \ - --hash=sha256:d2fa68a96a30dd240be80bbad838a0ac81a61770611ff7952b889485970c4c71 \ - --hash=sha256:d6f70406695e3220f09cd7a2f879333279d91aa4a8a1d34303b56d61a8180137 \ - --hash=sha256:d8f9cd4ad1be90b0cf350a2f04a38a36e44a026cac1e036ac593dc48efe91d52 \ - --hash=sha256:da2d94c15f88cd40d7e67f7919d4f60110d2b9d5b1e08cf354c2be773ab13479 \ - --hash=sha256:e1727c1c0e394096bb9af185c6923e8ea55a5095b8af44f06903bcc0e06800a2 \ - --hash=sha256:e420ced29b5904cdf9ee5545e23f9406189d8acb6750916c2db4793dada065c6 \ - --hash=sha256:e82c5cf1495244adf5252f925ac5932e5fd288b3e5ab6b70bec5593074b7236c \ - --hash=sha256:f1ef0d39bc1feb420caf549b3c657c871cad4ebbcf0580c4d03816b0590de0cf \ - --hash=sha256:f8753a6c88d1d0ba64302309eecf20f70d2770f65ca02d83c2452279085bfcd3 \ - --hash=sha256:f93dbf58f03146164048be5426ffde298b237a5e059144847e4940f5b80172c3 +grpcio==1.60.1 \ + --hash=sha256:0250a7a70b14000fa311de04b169cc7480be6c1a769b190769d347939d3232a8 \ + --hash=sha256:069fe2aeee02dfd2135d562d0663fe70fbb69d5eed6eb3389042a7e963b54de8 \ + --hash=sha256:082081e6a36b6eb5cf0fd9a897fe777dbb3802176ffd08e3ec6567edd85bc104 \ + --hash=sha256:0c5807e9152eff15f1d48f6b9ad3749196f79a4a050469d99eecb679be592acc \ + --hash=sha256:14e8f2c84c0832773fb3958240c69def72357bc11392571f87b2d7b91e0bb092 \ + --hash=sha256:2a6087f234cb570008a6041c8ffd1b7d657b397fdd6d26e83d72283dae3527b1 \ + --hash=sha256:2bb2a2911b028f01c8c64d126f6b632fcd8a9ac975aa1b3855766c94e4107180 \ + --hash=sha256:2f44c32aef186bbba254129cea1df08a20be414144ac3bdf0e84b24e3f3b2e05 \ + --hash=sha256:30e980cd6db1088c144b92fe376747328d5554bc7960ce583ec7b7d81cd47287 \ + --hash=sha256:33aed0a431f5befeffd9d346b0fa44b2c01aa4aeae5ea5b2c03d3e25e0071216 \ + --hash=sha256:33bdea30dcfd4f87b045d404388469eb48a48c33a6195a043d116ed1b9a0196c \ + --hash=sha256:39aa848794b887120b1d35b1b994e445cc028ff602ef267f87c38122c1add50d \ + --hash=sha256:4216e67ad9a4769117433814956031cb300f85edc855252a645a9a724b3b6594 \ + --hash=sha256:49c9b6a510e3ed8df5f6f4f3c34d7fbf2d2cae048ee90a45cd7415abab72912c \ + --hash=sha256:4eec8b8c1c2c9b7125508ff7c89d5701bf933c99d3910e446ed531cd16ad5d87 \ + --hash=sha256:50d56280b482875d1f9128ce596e59031a226a8b84bec88cb2bf76c289f5d0de \ + --hash=sha256:53b69e79d00f78c81eecfb38f4516080dc7f36a198b6b37b928f1c13b3c063e9 \ + --hash=sha256:55ccb7db5a665079d68b5c7c86359ebd5ebf31a19bc1a91c982fd622f1e31ff2 \ + --hash=sha256:5a1ebbae7e2214f51b1f23b57bf98eeed2cf1ba84e4d523c48c36d5b2f8829ff \ + --hash=sha256:61b7199cd2a55e62e45bfb629a35b71fc2c0cb88f686a047f25b1112d3810904 \ + --hash=sha256:660fc6b9c2a9ea3bb2a7e64ba878c98339abaf1811edca904ac85e9e662f1d73 \ + --hash=sha256:6d140bdeb26cad8b93c1455fa00573c05592793c32053d6e0016ce05ba267549 \ + --hash=sha256:6e490fa5f7f5326222cb9f0b78f207a2b218a14edf39602e083d5f617354306f \ + --hash=sha256:6ecf21d20d02d1733e9c820fb5c114c749d888704a7ec824b545c12e78734d1c \ + --hash=sha256:70c83bb530572917be20c21f3b6be92cd86b9aecb44b0c18b1d3b2cc3ae47df0 \ + --hash=sha256:72153a0d2e425f45b884540a61c6639436ddafa1829a42056aa5764b84108b8e \ + --hash=sha256:73e14acd3d4247169955fae8fb103a2b900cfad21d0c35f0dcd0fdd54cd60367 \ + --hash=sha256:76eaaba891083fcbe167aa0f03363311a9f12da975b025d30e94b93ac7a765fc \ + --hash=sha256:79ae0dc785504cb1e1788758c588c711f4e4a0195d70dff53db203c95a0bd303 \ + --hash=sha256:7d142bcd604166417929b071cd396aa13c565749a4c840d6c702727a59d835eb \ + --hash=sha256:8c9554ca8e26241dabe7951aa1fa03a1ba0856688ecd7e7bdbdd286ebc272e4c \ + --hash=sha256:8d488fbdbf04283f0d20742b64968d44825617aa6717b07c006168ed16488804 \ + --hash=sha256:91422ba785a8e7a18725b1dc40fbd88f08a5bb4c7f1b3e8739cab24b04fa8a03 \ + --hash=sha256:9a66f4d2a005bc78e61d805ed95dedfcb35efa84b7bba0403c6d60d13a3de2d6 \ + --hash=sha256:9b106bc52e7f28170e624ba61cc7dc6829566e535a6ec68528f8e1afbed1c41f \ + --hash=sha256:9b54577032d4f235452f77a83169b6527bf4b77d73aeada97d45b2aaf1bf5ce0 \ + --hash=sha256:a09506eb48fa5493c58f946c46754ef22f3ec0df64f2b5149373ff31fb67f3dd \ + --hash=sha256:a212e5dea1a4182e40cd3e4067ee46be9d10418092ce3627475e995cca95de21 \ + --hash=sha256:a731ac5cffc34dac62053e0da90f0c0b8560396a19f69d9703e88240c8f05858 \ + --hash=sha256:af5ef6cfaf0d023c00002ba25d0751e5995fa0e4c9eec6cd263c30352662cbce \ + --hash=sha256:b58b855d0071575ea9c7bc0d84a06d2edfbfccec52e9657864386381a7ce1ae9 \ + --hash=sha256:bc808924470643b82b14fe121923c30ec211d8c693e747eba8a7414bc4351a23 \ + --hash=sha256:c557e94e91a983e5b1e9c60076a8fd79fea1e7e06848eb2e48d0ccfb30f6e073 \ + --hash=sha256:c71be3f86d67d8d1311c6076a4ba3b75ba5703c0b856b4e691c9097f9b1e8bd2 \ + --hash=sha256:c8754c75f55781515a3005063d9a05878b2cfb3cb7e41d5401ad0cf19de14872 \ + --hash=sha256:cb0af13433dbbd1c806e671d81ec75bd324af6ef75171fd7815ca3074fe32bfe \ + --hash=sha256:cba6209c96828711cb7c8fcb45ecef8c8859238baf15119daa1bef0f6c84bfe7 \ + --hash=sha256:cf77f8cf2a651fbd869fbdcb4a1931464189cd210abc4cfad357f1cacc8642a6 \ + --hash=sha256:d7404cebcdb11bb5bd40bf94131faf7e9a7c10a6c60358580fe83913f360f929 \ + --hash=sha256:dd1d3a8d1d2e50ad9b59e10aa7f07c7d1be2b367f3f2d33c5fade96ed5460962 \ + --hash=sha256:e5d97c65ea7e097056f3d1ead77040ebc236feaf7f71489383d20f3b4c28412a \ + --hash=sha256:f1c3dc536b3ee124e8b24feb7533e5c70b9f2ef833e3b2e5513b2897fd46763a \ + --hash=sha256:f2212796593ad1d0235068c79836861f2201fc7137a99aa2fea7beeb3b101177 \ + --hash=sha256:fead980fbc68512dfd4e0c7b1f5754c2a8e5015a04dea454b9cada54a8423525 # via # -r requirements.in # tb-nightly @@ -258,113 +249,115 @@ h5py==3.10.0 \ # via # -r requirements.in # keras-nightly -idna==3.4 \ - --hash=sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4 \ - --hash=sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2 +idna==3.6 \ + --hash=sha256:9ecdbbd083b06798ae1e86adcbfe8ab1479cf864e4ee30fe4e46a003d12491ca \ + --hash=sha256:c05567e9c24a6b9faaa835c4821bad0590fbb9d5779e7caa6e1cc4978e7eb24f # via requests jax==0.4.7 \ --hash=sha256:5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8 # via -r requirements.in -keras-nightly==3.0.0.dev2023103103 \ - --hash=sha256:25a6a9030d7e067d535ec41ae6700636c90cabada80fbddb3bb8775006807860 \ - --hash=sha256:c39fccebb3d4cc9d838371981c4d3eef08fc06eadf6cffb39b356cb625bab50f +keras-nightly==3.0.4.dev2024021403 \ + --hash=sha256:24ce69d29d582771685bf4235f59663723405b5a5b16f3eaff2657e52e74663a \ + --hash=sha256:9f416e66b820ef833779d219d255b346b8b90a72fdbd0b2f1e90a43ad142a03d # via -r requirements.in -lit==17.0.4 \ - --hash=sha256:ee2e180128e770abc6aed3a02de2daf09d81b7d30225e315205d3599c311d304 +lit==17.0.6 \ + --hash=sha256:dfa9af9b55fc4509a56be7bf2346f079d7f4a242d583b9f2e0b078fd0abae31b # via -r requirements.in -markdown==3.5 \ - --hash=sha256:4afb124395ce5fc34e6d9886dab977fd9ae987fc6e85689f08278cf0c69d4bf3 \ - --hash=sha256:a807eb2e4778d9156c8f07876c6e4d50b5494c5665c4834f67b06459dfd877b3 +markdown==3.5.2 \ + --hash=sha256:d43323865d89fc0cb9b20c75fc8ad313af307cc087e84b657d9eec768eddeadd \ + --hash=sha256:e1ac7b3dc550ee80e602e71c1d168002f062e49f1b11e26a36264dafd4df2ef8 # via tb-nightly markdown-it-py==3.0.0 \ --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb # via rich -markupsafe==2.1.3 \ - --hash=sha256:05fb21170423db021895e1ea1e1f3ab3adb85d1c2333cbc2310f2a26bc77272e \ - --hash=sha256:0a4e4a1aff6c7ac4cd55792abf96c915634c2b97e3cc1c7129578aa68ebd754e \ - --hash=sha256:10bbfe99883db80bdbaff2dcf681dfc6533a614f700da1287707e8a5d78a8431 \ - --hash=sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686 \ - --hash=sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c \ - --hash=sha256:1577735524cdad32f9f694208aa75e422adba74f1baee7551620e43a3141f559 \ - --hash=sha256:1b40069d487e7edb2676d3fbdb2b0829ffa2cd63a2ec26c4938b2d34391b4ecc \ - --hash=sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb \ - --hash=sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939 \ - --hash=sha256:282c2cb35b5b673bbcadb33a585408104df04f14b2d9b01d4c345a3b92861c2c \ - --hash=sha256:2c1b19b3aaacc6e57b7e25710ff571c24d6c3613a45e905b1fde04d691b98ee0 \ - --hash=sha256:2ef12179d3a291be237280175b542c07a36e7f60718296278d8593d21ca937d4 \ - --hash=sha256:338ae27d6b8745585f87218a3f23f1512dbf52c26c28e322dbe54bcede54ccb9 \ - --hash=sha256:3c0fae6c3be832a0a0473ac912810b2877c8cb9d76ca48de1ed31e1c68386575 \ - --hash=sha256:3fd4abcb888d15a94f32b75d8fd18ee162ca0c064f35b11134be77050296d6ba \ - --hash=sha256:42de32b22b6b804f42c5d98be4f7e5e977ecdd9ee9b660fda1a3edf03b11792d \ - --hash=sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd \ - --hash=sha256:504b320cd4b7eff6f968eddf81127112db685e81f7e36e75f9f84f0df46041c3 \ - --hash=sha256:525808b8019e36eb524b8c68acdd63a37e75714eac50e988180b169d64480a00 \ - --hash=sha256:56d9f2ecac662ca1611d183feb03a3fa4406469dafe241673d521dd5ae92a155 \ - --hash=sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac \ - --hash=sha256:65c1a9bcdadc6c28eecee2c119465aebff8f7a584dd719facdd9e825ec61ab52 \ - --hash=sha256:68e78619a61ecf91e76aa3e6e8e33fc4894a2bebe93410754bd28fce0a8a4f9f \ - --hash=sha256:69c0f17e9f5a7afdf2cc9fb2d1ce6aabdb3bafb7f38017c0b77862bcec2bbad8 \ - --hash=sha256:6b2b56950d93e41f33b4223ead100ea0fe11f8e6ee5f641eb753ce4b77a7042b \ - --hash=sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007 \ - --hash=sha256:787003c0ddb00500e49a10f2844fac87aa6ce977b90b0feaaf9de23c22508b24 \ - --hash=sha256:7ef3cb2ebbf91e330e3bb937efada0edd9003683db6b57bb108c4001f37a02ea \ - --hash=sha256:8023faf4e01efadfa183e863fefde0046de576c6f14659e8782065bcece22198 \ - --hash=sha256:8758846a7e80910096950b67071243da3e5a20ed2546e6392603c096778d48e0 \ - --hash=sha256:8afafd99945ead6e075b973fefa56379c5b5c53fd8937dad92c662da5d8fd5ee \ - --hash=sha256:8c41976a29d078bb235fea9b2ecd3da465df42a562910f9022f1a03107bd02be \ - --hash=sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2 \ - --hash=sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1 \ - --hash=sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707 \ - --hash=sha256:962f82a3086483f5e5f64dbad880d31038b698494799b097bc59c2edf392fce6 \ - --hash=sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c \ - --hash=sha256:9dcdfd0eaf283af041973bff14a2e143b8bd64e069f4c383416ecd79a81aab58 \ - --hash=sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823 \ - --hash=sha256:aa7bd130efab1c280bed0f45501b7c8795f9fdbeb02e965371bbef3523627779 \ - --hash=sha256:ab4a0df41e7c16a1392727727e7998a467472d0ad65f3ad5e6e765015df08636 \ - --hash=sha256:ad9e82fb8f09ade1c3e1b996a6337afac2b8b9e365f926f5a61aacc71adc5b3c \ - --hash=sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad \ - --hash=sha256:b076b6226fb84157e3f7c971a47ff3a679d837cf338547532ab866c57930dbee \ - --hash=sha256:b7ff0f54cb4ff66dd38bebd335a38e2c22c41a8ee45aa608efc890ac3e3931bc \ - --hash=sha256:bfce63a9e7834b12b87c64d6b155fdd9b3b96191b6bd334bf37db7ff1fe457f2 \ - --hash=sha256:c011a4149cfbcf9f03994ec2edffcb8b1dc2d2aede7ca243746df97a5d41ce48 \ - --hash=sha256:c9c804664ebe8f83a211cace637506669e7890fec1b4195b505c214e50dd4eb7 \ - --hash=sha256:ca379055a47383d02a5400cb0d110cef0a776fc644cda797db0c5696cfd7e18e \ - --hash=sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b \ - --hash=sha256:cd0f502fe016460680cd20aaa5a76d241d6f35a1c3350c474bac1273803893fa \ - --hash=sha256:ceb01949af7121f9fc39f7d27f91be8546f3fb112c608bc4029aef0bab86a2a5 \ - --hash=sha256:d080e0a5eb2529460b30190fcfcc4199bd7f827663f858a226a81bc27beaa97e \ - --hash=sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb \ - --hash=sha256:df0be2b576a7abbf737b1575f048c23fb1d769f267ec4358296f31c2479db8f9 \ - --hash=sha256:e09031c87a1e51556fdcb46e5bd4f59dfb743061cf93c4d6831bf894f125eb57 \ - --hash=sha256:e4dd52d80b8c83fdce44e12478ad2e85c64ea965e75d66dbeafb0a3e77308fcc \ - --hash=sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc \ - --hash=sha256:fec21693218efe39aa7f8599346e90c705afa52c5b31ae019b2e57e8f6542bb2 \ - --hash=sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11 +markupsafe==2.1.5 \ + --hash=sha256:00e046b6dd71aa03a41079792f8473dc494d564611a8f89bbbd7cb93295ebdcf \ + --hash=sha256:075202fa5b72c86ad32dc7d0b56024ebdbcf2048c0ba09f1cde31bfdd57bcfff \ + --hash=sha256:0e397ac966fdf721b2c528cf028494e86172b4feba51d65f81ffd65c63798f3f \ + --hash=sha256:17b950fccb810b3293638215058e432159d2b71005c74371d784862b7e4683f3 \ + --hash=sha256:1f3fbcb7ef1f16e48246f704ab79d79da8a46891e2da03f8783a5b6fa41a9532 \ + --hash=sha256:2174c595a0d73a3080ca3257b40096db99799265e1c27cc5a610743acd86d62f \ + --hash=sha256:2b7c57a4dfc4f16f7142221afe5ba4e093e09e728ca65c51f5620c9aaeb9a617 \ + --hash=sha256:2d2d793e36e230fd32babe143b04cec8a8b3eb8a3122d2aceb4a371e6b09b8df \ + --hash=sha256:30b600cf0a7ac9234b2638fbc0fb6158ba5bdcdf46aeb631ead21248b9affbc4 \ + --hash=sha256:397081c1a0bfb5124355710fe79478cdbeb39626492b15d399526ae53422b906 \ + --hash=sha256:3a57fdd7ce31c7ff06cdfbf31dafa96cc533c21e443d57f5b1ecc6cdc668ec7f \ + --hash=sha256:3c6b973f22eb18a789b1460b4b91bf04ae3f0c4234a0a6aa6b0a92f6f7b951d4 \ + --hash=sha256:3e53af139f8579a6d5f7b76549125f0d94d7e630761a2111bc431fd820e163b8 \ + --hash=sha256:4096e9de5c6fdf43fb4f04c26fb114f61ef0bf2e5604b6ee3019d51b69e8c371 \ + --hash=sha256:4275d846e41ecefa46e2015117a9f491e57a71ddd59bbead77e904dc02b1bed2 \ + --hash=sha256:4c31f53cdae6ecfa91a77820e8b151dba54ab528ba65dfd235c80b086d68a465 \ + --hash=sha256:4f11aa001c540f62c6166c7726f71f7573b52c68c31f014c25cc7901deea0b52 \ + --hash=sha256:5049256f536511ee3f7e1b3f87d1d1209d327e818e6ae1365e8653d7e3abb6a6 \ + --hash=sha256:58c98fee265677f63a4385256a6d7683ab1832f3ddd1e66fe948d5880c21a169 \ + --hash=sha256:598e3276b64aff0e7b3451b72e94fa3c238d452e7ddcd893c3ab324717456bad \ + --hash=sha256:5b7b716f97b52c5a14bffdf688f971b2d5ef4029127f1ad7a513973cfd818df2 \ + --hash=sha256:5dedb4db619ba5a2787a94d877bc8ffc0566f92a01c0ef214865e54ecc9ee5e0 \ + --hash=sha256:619bc166c4f2de5caa5a633b8b7326fbe98e0ccbfacabd87268a2b15ff73a029 \ + --hash=sha256:629ddd2ca402ae6dbedfceeba9c46d5f7b2a61d9749597d4307f943ef198fc1f \ + --hash=sha256:656f7526c69fac7f600bd1f400991cc282b417d17539a1b228617081106feb4a \ + --hash=sha256:6ec585f69cec0aa07d945b20805be741395e28ac1627333b1c5b0105962ffced \ + --hash=sha256:72b6be590cc35924b02c78ef34b467da4ba07e4e0f0454a2c5907f473fc50ce5 \ + --hash=sha256:7502934a33b54030eaf1194c21c692a534196063db72176b0c4028e140f8f32c \ + --hash=sha256:7a68b554d356a91cce1236aa7682dc01df0edba8d043fd1ce607c49dd3c1edcf \ + --hash=sha256:7b2e5a267c855eea6b4283940daa6e88a285f5f2a67f2220203786dfa59b37e9 \ + --hash=sha256:823b65d8706e32ad2df51ed89496147a42a2a6e01c13cfb6ffb8b1e92bc910bb \ + --hash=sha256:8590b4ae07a35970728874632fed7bd57b26b0102df2d2b233b6d9d82f6c62ad \ + --hash=sha256:8dd717634f5a044f860435c1d8c16a270ddf0ef8588d4887037c5028b859b0c3 \ + --hash=sha256:8dec4936e9c3100156f8a2dc89c4b88d5c435175ff03413b443469c7c8c5f4d1 \ + --hash=sha256:97cafb1f3cbcd3fd2b6fbfb99ae11cdb14deea0736fc2b0952ee177f2b813a46 \ + --hash=sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc \ + --hash=sha256:a549b9c31bec33820e885335b451286e2969a2d9e24879f83fe904a5ce59d70a \ + --hash=sha256:ac07bad82163452a6884fe8fa0963fb98c2346ba78d779ec06bd7a6262132aee \ + --hash=sha256:ae2ad8ae6ebee9d2d94b17fb62763125f3f374c25618198f40cbb8b525411900 \ + --hash=sha256:b91c037585eba9095565a3556f611e3cbfaa42ca1e865f7b8015fe5c7336d5a5 \ + --hash=sha256:bc1667f8b83f48511b94671e0e441401371dfd0f0a795c7daa4a3cd1dde55bea \ + --hash=sha256:bec0a414d016ac1a18862a519e54b2fd0fc8bbfd6890376898a6c0891dd82e9f \ + --hash=sha256:bf50cd79a75d181c9181df03572cdce0fbb75cc353bc350712073108cba98de5 \ + --hash=sha256:bff1b4290a66b490a2f4719358c0cdcd9bafb6b8f061e45c7a2460866bf50c2e \ + --hash=sha256:c061bb86a71b42465156a3ee7bd58c8c2ceacdbeb95d05a99893e08b8467359a \ + --hash=sha256:c8b29db45f8fe46ad280a7294f5c3ec36dbac9491f2d1c17345be8e69cc5928f \ + --hash=sha256:ce409136744f6521e39fd8e2a24c53fa18ad67aa5bc7c2cf83645cce5b5c4e50 \ + --hash=sha256:d050b3361367a06d752db6ead6e7edeb0009be66bc3bae0ee9d97fb326badc2a \ + --hash=sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b \ + --hash=sha256:d9fad5155d72433c921b782e58892377c44bd6252b5af2f67f16b194987338a4 \ + --hash=sha256:daa4ee5a243f0f20d528d939d06670a298dd39b1ad5f8a72a4275124a7819eff \ + --hash=sha256:db0b55e0f3cc0be60c1f19efdde9a637c32740486004f20d1cff53c3c0ece4d2 \ + --hash=sha256:e61659ba32cf2cf1481e575d0462554625196a1f2fc06a1c777d3f48e8865d46 \ + --hash=sha256:ea3d8a3d18833cf4304cd2fc9cbb1efe188ca9b5efef2bdac7adc20594a0e46b \ + --hash=sha256:ec6a563cff360b50eed26f13adc43e61bc0c04d94b8be985e6fb24b81f6dcfdf \ + --hash=sha256:f5dfb42c4604dddc8e4305050aa6deb084540643ed5804d7455b5df8fe16f5e5 \ + --hash=sha256:fa173ec60341d6bb97a89f5ea19c85c5643c1e7dedebc22f5181eb73573142c5 \ + --hash=sha256:fa9db3f79de01457b03d4f01b34cf91bc0048eb2c3846ff26f66687c2f6d16ab \ + --hash=sha256:fce659a462a1be54d2ffcacea5e3ba2d74daa74f30f5f143fe0c58636e355fdd \ + --hash=sha256:ffee1f21e5ef0d712f9033568f8344d5da8cc2869dbd08d87c84656e6a2d2f68 # via werkzeug mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba # via markdown-it-py -ml-dtypes==0.3.1 \ - --hash=sha256:3d8ca0acbd377082792d8b97081ba580abdad67c6afb7f827012c675b052f058 \ - --hash=sha256:42a8980afd8b7c8e270e8b5c260237286b5b26acd276fcb758d13cd7cb567e99 \ - --hash=sha256:438437e2e614a3c91d75581653b6c40ec890e8b5994d7190a90c931740151c95 \ - --hash=sha256:4828b62fa3bf1ae35faa40f3db9a38ec72fbce02f328a1d14c3a9da4606af364 \ - --hash=sha256:4d94b2d1bed77284694f7fd0479640fa7aa5d96433dca3cbcec407a5ef752e77 \ - --hash=sha256:510d249a91face47211762eb294d6fe64f325356b965fb6388c1bf51bd339267 \ - --hash=sha256:5727effa7650f7ab10906542d137cfb3244fdc3b2b519beff42f82def8ba59be \ - --hash=sha256:5e0b0b6bb07fa5ad11bb61d174667176bee5e05857225067aabfc5adc1b51d23 \ - --hash=sha256:60778f99194b4c4f36ba42da200b35ef851ce4d4af698aaf70f5b91fe70fc611 \ - --hash=sha256:70984b473db6489ec1d8c79b082a1322105155193049d08a3b0c515094e9777b \ - --hash=sha256:979d7d196d9a17e0135ae22878f74241fbd3522cef58d7b292f1fd5b32282201 \ - --hash=sha256:a777928dcba8865ab4a8157eeb25d23aed7bc82e5fd74e1d5eca821d3f148b39 \ - --hash=sha256:cb0c404e0dd3e25b56362c1c1e5de0ef717f727dde59fa721af4ce6ab2acca44 \ - --hash=sha256:d1a8dc3bac1da2a17d0e2e4cba36ee89721d0bd33ea4765af2eefb5f41409e0f \ - --hash=sha256:da274599e4950a9b488d21571061f49a185537cc77f2d3f8121151d58a9e9f16 \ - --hash=sha256:f83ff080df8910c0f987f615b03e4f8198638e0c00c6e679ea8892dda909763b \ - --hash=sha256:fcae2c69715410d96906e1dfe8f017d9f78a0d10e0df91aae52e91f51fdfe45e - # via jax +ml-dtypes==0.3.2 \ + --hash=sha256:2c34f2ba9660b21fe1034b608308a01be82bbef2a92fb8199f24dc6bad0d5226 \ + --hash=sha256:3a17ef2322e60858d93584e9c52a5be7dd6236b056b7fa1ec57f1bb6ba043e33 \ + --hash=sha256:533059bc5f1764fac071ef54598db358c167c51a718f68f5bb55e3dee79d2967 \ + --hash=sha256:6604877d567a29bfe7cc02969ae0f2425260e5335505cf5e7fefc3e5465f5655 \ + --hash=sha256:6b35c4e8ca957c877ac35c79ffa77724ecc3702a1e4b18b08306c03feae597bb \ + --hash=sha256:763697ab8a88d47443997a7cdf3aac7340049aed45f7521f6b0ec8a0594821fe \ + --hash=sha256:7a4c3fcbf86fa52d0204f07cfd23947ef05b4ad743a1a988e163caa34a201e5e \ + --hash=sha256:7afde548890a92b41c0fed3a6c525f1200a5727205f73dc21181a2726571bb53 \ + --hash=sha256:7ba8e1fafc7fff3e643f453bffa7d082df1678a73286ce8187d3e825e776eb94 \ + --hash=sha256:91f8783fd1f2c23fd3b9ee5ad66b785dafa58ba3cdb050c4458021fa4d1eb226 \ + --hash=sha256:93b78f53431c93953f7850bb1b925a17f0ab5d97527e38a7e865b5b4bc5cfc18 \ + --hash=sha256:961134ea44c7b8ca63eda902a44b58cd8bd670e21d62e255c81fba0a8e70d9b7 \ + --hash=sha256:b89b194e9501a92d289c1ffd411380baf5daafb9818109a4f49b0a1b6dce4462 \ + --hash=sha256:c7b3fb3d4f6b39bcd4f6c4b98f406291f0d681a895490ee29a0f95bab850d53c \ + --hash=sha256:d1a746fe5fb9cd974a91070174258f0be129c592b93f9ce7df6cc336416c3fbd \ + --hash=sha256:e8505946df1665db01332d885c2020b4cb9e84a8b1241eb4ba69d59591f65855 \ + --hash=sha256:f47619d978ab1ae7dfdc4052ea97c636c6263e1f19bd1be0e42c346b98d15ff4 + # via + # jax + # keras-nightly namex==0.0.7 \ --hash=sha256:84ba65bc4d22bd909e3d26bf2ffb4b9529b608cb3f9a4336f776b04204ced69b \ --hash=sha256:8a4f062945f405d77cb66b907f16aa2fd83681945e998be840eb6c4154d40108 @@ -407,10 +400,6 @@ numpy==1.23.5 ; python_version <= "3.11" \ # opt-einsum # scipy # tb-nightly -oauthlib==3.2.2 \ - --hash=sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca \ - --hash=sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918 - # via requests-oauthlib opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 @@ -440,57 +429,36 @@ protobuf==4.23.4 \ --hash=sha256:effeac51ab79332d44fba74660d40ae79985901ac21bca408f8dc335a81aa597 \ --hash=sha256:fee88269a090ada09ca63551bf2f573eb2424035bcf2cb1b121895b01a46594a # via tb-nightly -psutil==5.9.6 \ - --hash=sha256:10e8c17b4f898d64b121149afb136c53ea8b68c7531155147867b7b1ac9e7e28 \ - --hash=sha256:18cd22c5db486f33998f37e2bb054cc62fd06646995285e02a51b1e08da97017 \ - --hash=sha256:3ebf2158c16cc69db777e3c7decb3c0f43a7af94a60d72e87b2823aebac3d602 \ - --hash=sha256:51dc3d54607c73148f63732c727856f5febec1c7c336f8f41fcbd6315cce76ac \ - --hash=sha256:6e5fb8dc711a514da83098bc5234264e551ad980cec5f85dabf4d38ed6f15e9a \ - --hash=sha256:70cb3beb98bc3fd5ac9ac617a327af7e7f826373ee64c80efd4eb2856e5051e9 \ - --hash=sha256:748c9dd2583ed86347ed65d0035f45fa8c851e8d90354c122ab72319b5f366f4 \ - --hash=sha256:91ecd2d9c00db9817a4b4192107cf6954addb5d9d67a969a4f436dbc9200f88c \ - --hash=sha256:92e0cc43c524834af53e9d3369245e6cc3b130e78e26100d1f63cdb0abeb3d3c \ - --hash=sha256:a6f01f03bf1843280f4ad16f4bde26b817847b4c1a0db59bf6419807bc5ce05c \ - --hash=sha256:c69596f9fc2f8acd574a12d5f8b7b1ba3765a641ea5d60fb4736bf3c08a8214a \ - --hash=sha256:ca2780f5e038379e520281e4c032dddd086906ddff9ef0d1b9dcf00710e5071c \ - --hash=sha256:daecbcbd29b289aac14ece28eca6a3e60aa361754cf6da3dfb20d4d32b6c7f57 \ - --hash=sha256:e4b92ddcd7dd4cdd3f900180ea1e104932c7bce234fb88976e2a3b296441225a \ - --hash=sha256:fb8a697f11b0f5994550555fcfe3e69799e5b060c8ecf9e2f75c69302cc35c0d \ - --hash=sha256:ff18b8d1a784b810df0b0fff3bcb50ab941c3b8e2c8de5726f9c71c601c611aa +psutil==5.9.8 \ + --hash=sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d \ + --hash=sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73 \ + --hash=sha256:26bd09967ae00920df88e0352a91cff1a78f8d69b3ecabbfe733610c0af486c8 \ + --hash=sha256:27cc40c3493bb10de1be4b3f07cae4c010ce715290a5be22b98493509c6299e2 \ + --hash=sha256:36f435891adb138ed3c9e58c6af3e2e6ca9ac2f365efe1f9cfef2794e6c93b4e \ + --hash=sha256:50187900d73c1381ba1454cf40308c2bf6f34268518b3f36a9b663ca87e65e36 \ + --hash=sha256:611052c4bc70432ec770d5d54f64206aa7203a101ec273a0cd82418c86503bb7 \ + --hash=sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c \ + --hash=sha256:7d79560ad97af658a0f6adfef8b834b53f64746d45b403f225b85c5c2c140eee \ + --hash=sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421 \ + --hash=sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf \ + --hash=sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81 \ + --hash=sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0 \ + --hash=sha256:bd1184ceb3f87651a67b2708d4c3338e9b10c5df903f2e3776b62303b26cb631 \ + --hash=sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4 \ + --hash=sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8 # via portpicker -pyasn1==0.5.0 \ - --hash=sha256:87a2121042a1ac9358cabcaf1d07680ff97ee6404333bacca15f76aa8ad01a57 \ - --hash=sha256:97b7290ca68e62a832558ec3976f15cbf911bf5d7c7039d8b861c2a0ece69fde - # via - # pyasn1-modules - # rsa -pyasn1-modules==0.3.0 \ - --hash=sha256:5bd01446b736eb9d31512a30d46c1ac3395d676c6f3cafa4c03eb54b9925631c \ - --hash=sha256:d3ccd6ed470d9ffbc716be08bd90efbd44d0734bc9303818f7336070984a162d - # via google-auth -pygments==2.16.1 \ - --hash=sha256:13fc09fa63bc8d8671a6d247e1eb303c4b343eaee81d861f3404db2935653692 \ - --hash=sha256:1daff0494820c69bc8941e407aa20f577374ee88364ee10a98fdbe0aece96e29 +pygments==2.17.2 \ + --hash=sha256:b27c2826c47d0f3219f29554824c30c5e8945175d888647acd804ddd04af846c \ + --hash=sha256:da46cec9fd2de5be3a8a784f434e4c4ab670b4ff54d605c4c2717e9d49c4c367 # via rich requests==2.31.0 \ --hash=sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f \ --hash=sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1 - # via - # -r requirements.in - # requests-oauthlib - # tb-nightly -requests-oauthlib==1.3.1 \ - --hash=sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5 \ - --hash=sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a - # via google-auth-oauthlib -rich==13.6.0 \ - --hash=sha256:2b38e2fe9ca72c9a00170a1a2d20c63c790d0e10ef1fe35eba76e1e7b1d7d245 \ - --hash=sha256:5c14d22737e6d5084ef4771b62d5d4363165b403455a30a1c8ca39dc7b644bef + # via -r requirements.in +rich==13.7.0 \ + --hash=sha256:5cb5123b5cf9ee70584244246816e9114227e0b98ad9176eede6ad54bf5403fa \ + --hash=sha256:6da14c108c4866ee9520bbffa71f6fe3962e193b7da68720583850cd4548e235 # via keras-nightly -rsa==4.9 \ - --hash=sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7 \ - --hash=sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21 - # via google-auth scipy==1.11.3 \ --hash=sha256:00f325434b6424952fbb636506f0567898dca7b0f7654d48f1c382ea338ce9a3 \ --hash=sha256:033c3fd95d55012dd1148b201b72ae854d5086d25e7c316ec9850de4fe776929 \ @@ -526,8 +494,8 @@ six==1.16.0 \ # via # astunparse # tb-nightly -tb-nightly==2.15.0a20231023 \ - --hash=sha256:8990a52985e3296aa18a6825efc017bcded9e2fb2cbcdd2d01f5c62d4fdc9825 +tb-nightly==2.17.0a20240214 \ + --hash=sha256:dbd59bcfd9b028e6199050e36cbb4ee1db4a94e69a7f8c00865f1498112a2b83 # via -r requirements.in tblib==2.0.0 \ --hash=sha256:9100bfa016b047d5b980d66e7efed952fbd20bd85b56110aaf473cb97d18709a \ @@ -542,13 +510,17 @@ termcolor==2.3.0 \ --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \ --hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a # via -r requirements.in +tf-keras-nightly==2.16.0.dev2024021410 \ + --hash=sha256:05a4c19c6a795ec9ed6f4dac69d443601b9206b910524d8b19bad6e362d49f47 \ + --hash=sha256:83d5a1dd3979b42164a2fac4bdd2ed5981ef0c2a0514ca64e4bc67e369caef5c + # via tb-nightly typing-extensions==4.8.0 \ --hash=sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0 \ --hash=sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef # via -r requirements.in -urllib3==2.0.7 \ - --hash=sha256:c97dfde1f7bd43a71c8d2a58e369e9b2bf692d1334ea9f9cae55add7d0dd0f84 \ - --hash=sha256:fdb6d215c776278489906c2f8916e6e7d4f5a9b602ccbcfdf7f016fc8da0596e +urllib3==2.2.0 \ + --hash=sha256:051d961ad0c62a94e50ecf1af379c3aba230c66c710493493560c0c223c49f20 \ + --hash=sha256:ce3711610ddce217e6d113a2732fafad960a03fd0318c91faa79481e35c11224 # via requests werkzeug==3.0.1 \ --hash=sha256:507e811ecea72b18a404947aded4b3390e1db8f826b494d76550ef45bb3b1dcc \ @@ -560,81 +532,77 @@ wheel==0.41.3 \ # via # -r requirements.in # astunparse -wrapt==1.14.1 \ - --hash=sha256:00b6d4ea20a906c0ca56d84f93065b398ab74b927a7a3dbd470f6fc503f95dc3 \ - --hash=sha256:01c205616a89d09827986bc4e859bcabd64f5a0662a7fe95e0d359424e0e071b \ - --hash=sha256:02b41b633c6261feff8ddd8d11c711df6842aba629fdd3da10249a53211a72c4 \ - --hash=sha256:07f7a7d0f388028b2df1d916e94bbb40624c59b48ecc6cbc232546706fac74c2 \ - --hash=sha256:11871514607b15cfeb87c547a49bca19fde402f32e2b1c24a632506c0a756656 \ - --hash=sha256:1b376b3f4896e7930f1f772ac4b064ac12598d1c38d04907e696cc4d794b43d3 \ - --hash=sha256:2020f391008ef874c6d9e208b24f28e31bcb85ccff4f335f15a3251d222b92d9 \ - --hash=sha256:21ac0156c4b089b330b7666db40feee30a5d52634cc4560e1905d6529a3897ff \ - --hash=sha256:240b1686f38ae665d1b15475966fe0472f78e71b1b4903c143a842659c8e4cb9 \ - --hash=sha256:257fd78c513e0fb5cdbe058c27a0624c9884e735bbd131935fd49e9fe719d310 \ - --hash=sha256:26046cd03936ae745a502abf44dac702a5e6880b2b01c29aea8ddf3353b68224 \ - --hash=sha256:2b39d38039a1fdad98c87279b48bc5dce2c0ca0d73483b12cb72aa9609278e8a \ - --hash=sha256:2cf71233a0ed05ccdabe209c606fe0bac7379fdcf687f39b944420d2a09fdb57 \ - --hash=sha256:2fe803deacd09a233e4762a1adcea5db5d31e6be577a43352936179d14d90069 \ - --hash=sha256:2feecf86e1f7a86517cab34ae6c2f081fd2d0dac860cb0c0ded96d799d20b335 \ - --hash=sha256:3232822c7d98d23895ccc443bbdf57c7412c5a65996c30442ebe6ed3df335383 \ - --hash=sha256:34aa51c45f28ba7f12accd624225e2b1e5a3a45206aa191f6f9aac931d9d56fe \ - --hash=sha256:358fe87cc899c6bb0ddc185bf3dbfa4ba646f05b1b0b9b5a27c2cb92c2cea204 \ - --hash=sha256:36f582d0c6bc99d5f39cd3ac2a9062e57f3cf606ade29a0a0d6b323462f4dd87 \ - --hash=sha256:380a85cf89e0e69b7cfbe2ea9f765f004ff419f34194018a6827ac0e3edfed4d \ - --hash=sha256:40e7bc81c9e2b2734ea4bc1aceb8a8f0ceaac7c5299bc5d69e37c44d9081d43b \ - --hash=sha256:43ca3bbbe97af00f49efb06e352eae40434ca9d915906f77def219b88e85d907 \ - --hash=sha256:49ef582b7a1152ae2766557f0550a9fcbf7bbd76f43fbdc94dd3bf07cc7168be \ - --hash=sha256:4fcc4649dc762cddacd193e6b55bc02edca674067f5f98166d7713b193932b7f \ - --hash=sha256:5a0f54ce2c092aaf439813735584b9537cad479575a09892b8352fea5e988dc0 \ - --hash=sha256:5a9a0d155deafd9448baff28c08e150d9b24ff010e899311ddd63c45c2445e28 \ - --hash=sha256:5b02d65b9ccf0ef6c34cba6cf5bf2aab1bb2f49c6090bafeecc9cd81ad4ea1c1 \ - --hash=sha256:60db23fa423575eeb65ea430cee741acb7c26a1365d103f7b0f6ec412b893853 \ - --hash=sha256:642c2e7a804fcf18c222e1060df25fc210b9c58db7c91416fb055897fc27e8cc \ - --hash=sha256:6447e9f3ba72f8e2b985a1da758767698efa72723d5b59accefd716e9e8272bf \ - --hash=sha256:6a9a25751acb379b466ff6be78a315e2b439d4c94c1e99cb7266d40a537995d3 \ - --hash=sha256:6b1a564e6cb69922c7fe3a678b9f9a3c54e72b469875aa8018f18b4d1dd1adf3 \ - --hash=sha256:6d323e1554b3d22cfc03cd3243b5bb815a51f5249fdcbb86fda4bf62bab9e164 \ - --hash=sha256:6e743de5e9c3d1b7185870f480587b75b1cb604832e380d64f9504a0535912d1 \ - --hash=sha256:709fe01086a55cf79d20f741f39325018f4df051ef39fe921b1ebe780a66184c \ - --hash=sha256:7b7c050ae976e286906dd3f26009e117eb000fb2cf3533398c5ad9ccc86867b1 \ - --hash=sha256:7d2872609603cb35ca513d7404a94d6d608fc13211563571117046c9d2bcc3d7 \ - --hash=sha256:7ef58fb89674095bfc57c4069e95d7a31cfdc0939e2a579882ac7d55aadfd2a1 \ - --hash=sha256:80bb5c256f1415f747011dc3604b59bc1f91c6e7150bd7db03b19170ee06b320 \ - --hash=sha256:81b19725065dcb43df02b37e03278c011a09e49757287dca60c5aecdd5a0b8ed \ - --hash=sha256:833b58d5d0b7e5b9832869f039203389ac7cbf01765639c7309fd50ef619e0b1 \ - --hash=sha256:88bd7b6bd70a5b6803c1abf6bca012f7ed963e58c68d76ee20b9d751c74a3248 \ - --hash=sha256:8ad85f7f4e20964db4daadcab70b47ab05c7c1cf2a7c1e51087bfaa83831854c \ - --hash=sha256:8c0ce1e99116d5ab21355d8ebe53d9460366704ea38ae4d9f6933188f327b456 \ - --hash=sha256:8d649d616e5c6a678b26d15ece345354f7c2286acd6db868e65fcc5ff7c24a77 \ - --hash=sha256:903500616422a40a98a5a3c4ff4ed9d0066f3b4c951fa286018ecdf0750194ef \ - --hash=sha256:9736af4641846491aedb3c3f56b9bc5568d92b0692303b5a305301a95dfd38b1 \ - --hash=sha256:988635d122aaf2bdcef9e795435662bcd65b02f4f4c1ae37fbee7401c440b3a7 \ - --hash=sha256:9cca3c2cdadb362116235fdbd411735de4328c61425b0aa9f872fd76d02c4e86 \ - --hash=sha256:9e0fd32e0148dd5dea6af5fee42beb949098564cc23211a88d799e434255a1f4 \ - --hash=sha256:9f3e6f9e05148ff90002b884fbc2a86bd303ae847e472f44ecc06c2cd2fcdb2d \ - --hash=sha256:a85d2b46be66a71bedde836d9e41859879cc54a2a04fad1191eb50c2066f6e9d \ - --hash=sha256:a9008dad07d71f68487c91e96579c8567c98ca4c3881b9b113bc7b33e9fd78b8 \ - --hash=sha256:a9a52172be0b5aae932bef82a79ec0a0ce87288c7d132946d645eba03f0ad8a8 \ - --hash=sha256:aa31fdcc33fef9eb2552cbcbfee7773d5a6792c137b359e82879c101e98584c5 \ - --hash=sha256:acae32e13a4153809db37405f5eba5bac5fbe2e2ba61ab227926a22901051c0a \ - --hash=sha256:b014c23646a467558be7da3d6b9fa409b2c567d2110599b7cf9a0c5992b3b471 \ - --hash=sha256:b21bb4c09ffabfa0e85e3a6b623e19b80e7acd709b9f91452b8297ace2a8ab00 \ - --hash=sha256:b5901a312f4d14c59918c221323068fad0540e34324925c8475263841dbdfe68 \ - --hash=sha256:b9b7a708dd92306328117d8c4b62e2194d00c365f18eff11a9b53c6f923b01e3 \ - --hash=sha256:d1967f46ea8f2db647c786e78d8cc7e4313dbd1b0aca360592d8027b8508e24d \ - --hash=sha256:d52a25136894c63de15a35bc0bdc5adb4b0e173b9c0d07a2be9d3ca64a332735 \ - --hash=sha256:d77c85fedff92cf788face9bfa3ebaa364448ebb1d765302e9af11bf449ca36d \ - --hash=sha256:d79d7d5dc8a32b7093e81e97dad755127ff77bcc899e845f41bf71747af0c569 \ - --hash=sha256:dbcda74c67263139358f4d188ae5faae95c30929281bc6866d00573783c422b7 \ - --hash=sha256:ddaea91abf8b0d13443f6dac52e89051a5063c7d014710dcb4d4abb2ff811a59 \ - --hash=sha256:dee0ce50c6a2dd9056c20db781e9c1cfd33e77d2d569f5d1d9321c641bb903d5 \ - --hash=sha256:dee60e1de1898bde3b238f18340eec6148986da0455d8ba7848d50470a7a32fb \ - --hash=sha256:e2f83e18fe2f4c9e7db597e988f72712c0c3676d337d8b101f6758107c42425b \ - --hash=sha256:e3fb1677c720409d5f671e39bac6c9e0e422584e5f518bfd50aa4cbbea02433f \ - --hash=sha256:ecee4132c6cd2ce5308e21672015ddfed1ff975ad0ac8d27168ea82e71413f55 \ - --hash=sha256:ee2b1b1769f6707a8a445162ea16dddf74285c3964f605877a20e38545c3c462 \ - --hash=sha256:ee6acae74a2b91865910eef5e7de37dc6895ad96fa23603d1d27ea69df545015 \ - --hash=sha256:ef3f72c9666bba2bab70d2a8b79f2c6d2c1a42a7f7e2b0ec83bb2f9e383950af +wrapt==1.16.0 \ + --hash=sha256:0d2691979e93d06a95a26257adb7bfd0c93818e89b1406f5a28f36e0d8c1e1fc \ + --hash=sha256:14d7dc606219cdd7405133c713f2c218d4252f2a469003f8c46bb92d5d095d81 \ + --hash=sha256:1a5db485fe2de4403f13fafdc231b0dbae5eca4359232d2efc79025527375b09 \ + --hash=sha256:1acd723ee2a8826f3d53910255643e33673e1d11db84ce5880675954183ec47e \ + --hash=sha256:1ca9b6085e4f866bd584fb135a041bfc32cab916e69f714a7d1d397f8c4891ca \ + --hash=sha256:1dd50a2696ff89f57bd8847647a1c363b687d3d796dc30d4dd4a9d1689a706f0 \ + --hash=sha256:2076fad65c6736184e77d7d4729b63a6d1ae0b70da4868adeec40989858eb3fb \ + --hash=sha256:2a88e6010048489cda82b1326889ec075a8c856c2e6a256072b28eaee3ccf487 \ + --hash=sha256:3ebf019be5c09d400cf7b024aa52b1f3aeebeff51550d007e92c3c1c4afc2a40 \ + --hash=sha256:418abb18146475c310d7a6dc71143d6f7adec5b004ac9ce08dc7a34e2babdc5c \ + --hash=sha256:43aa59eadec7890d9958748db829df269f0368521ba6dc68cc172d5d03ed8060 \ + --hash=sha256:44a2754372e32ab315734c6c73b24351d06e77ffff6ae27d2ecf14cf3d229202 \ + --hash=sha256:490b0ee15c1a55be9c1bd8609b8cecd60e325f0575fc98f50058eae366e01f41 \ + --hash=sha256:49aac49dc4782cb04f58986e81ea0b4768e4ff197b57324dcbd7699c5dfb40b9 \ + --hash=sha256:5eb404d89131ec9b4f748fa5cfb5346802e5ee8836f57d516576e61f304f3b7b \ + --hash=sha256:5f15814a33e42b04e3de432e573aa557f9f0f56458745c2074952f564c50e664 \ + --hash=sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d \ + --hash=sha256:66027d667efe95cc4fa945af59f92c5a02c6f5bb6012bff9e60542c74c75c362 \ + --hash=sha256:66dfbaa7cfa3eb707bbfcd46dab2bc6207b005cbc9caa2199bcbc81d95071a00 \ + --hash=sha256:685f568fa5e627e93f3b52fda002c7ed2fa1800b50ce51f6ed1d572d8ab3e7fc \ + --hash=sha256:6906c4100a8fcbf2fa735f6059214bb13b97f75b1a61777fcf6432121ef12ef1 \ + --hash=sha256:6a42cd0cfa8ffc1915aef79cb4284f6383d8a3e9dcca70c445dcfdd639d51267 \ + --hash=sha256:6dcfcffe73710be01d90cae08c3e548d90932d37b39ef83969ae135d36ef3956 \ + --hash=sha256:6f6eac2360f2d543cc875a0e5efd413b6cbd483cb3ad7ebf888884a6e0d2e966 \ + --hash=sha256:72554a23c78a8e7aa02abbd699d129eead8b147a23c56e08d08dfc29cfdddca1 \ + --hash=sha256:73870c364c11f03ed072dda68ff7aea6d2a3a5c3fe250d917a429c7432e15228 \ + --hash=sha256:73aa7d98215d39b8455f103de64391cb79dfcad601701a3aa0dddacf74911d72 \ + --hash=sha256:75ea7d0ee2a15733684badb16de6794894ed9c55aa5e9903260922f0482e687d \ + --hash=sha256:7bd2d7ff69a2cac767fbf7a2b206add2e9a210e57947dd7ce03e25d03d2de292 \ + --hash=sha256:807cc8543a477ab7422f1120a217054f958a66ef7314f76dd9e77d3f02cdccd0 \ + --hash=sha256:8e9723528b9f787dc59168369e42ae1c3b0d3fadb2f1a71de14531d321ee05b0 \ + --hash=sha256:9090c9e676d5236a6948330e83cb89969f433b1943a558968f659ead07cb3b36 \ + --hash=sha256:9153ed35fc5e4fa3b2fe97bddaa7cbec0ed22412b85bcdaf54aeba92ea37428c \ + --hash=sha256:9159485323798c8dc530a224bd3ffcf76659319ccc7bbd52e01e73bd0241a0c5 \ + --hash=sha256:941988b89b4fd6b41c3f0bfb20e92bd23746579736b7343283297c4c8cbae68f \ + --hash=sha256:94265b00870aa407bd0cbcfd536f17ecde43b94fb8d228560a1e9d3041462d73 \ + --hash=sha256:98b5e1f498a8ca1858a1cdbffb023bfd954da4e3fa2c0cb5853d40014557248b \ + --hash=sha256:9b201ae332c3637a42f02d1045e1d0cccfdc41f1f2f801dafbaa7e9b4797bfc2 \ + --hash=sha256:a0ea261ce52b5952bf669684a251a66df239ec6d441ccb59ec7afa882265d593 \ + --hash=sha256:a33a747400b94b6d6b8a165e4480264a64a78c8a4c734b62136062e9a248dd39 \ + --hash=sha256:a452f9ca3e3267cd4d0fcf2edd0d035b1934ac2bd7e0e57ac91ad6b95c0c6389 \ + --hash=sha256:a86373cf37cd7764f2201b76496aba58a52e76dedfaa698ef9e9688bfd9e41cf \ + --hash=sha256:ac83a914ebaf589b69f7d0a1277602ff494e21f4c2f743313414378f8f50a4cf \ + --hash=sha256:aefbc4cb0a54f91af643660a0a150ce2c090d3652cf4052a5397fb2de549cd89 \ + --hash=sha256:b3646eefa23daeba62643a58aac816945cadc0afaf21800a1421eeba5f6cfb9c \ + --hash=sha256:b47cfad9e9bbbed2339081f4e346c93ecd7ab504299403320bf85f7f85c7d46c \ + --hash=sha256:b935ae30c6e7400022b50f8d359c03ed233d45b725cfdd299462f41ee5ffba6f \ + --hash=sha256:bb2dee3874a500de01c93d5c71415fcaef1d858370d405824783e7a8ef5db440 \ + --hash=sha256:bc57efac2da352a51cc4658878a68d2b1b67dbe9d33c36cb826ca449d80a8465 \ + --hash=sha256:bf5703fdeb350e36885f2875d853ce13172ae281c56e509f4e6eca049bdfb136 \ + --hash=sha256:c31f72b1b6624c9d863fc095da460802f43a7c6868c5dda140f51da24fd47d7b \ + --hash=sha256:c5cd603b575ebceca7da5a3a251e69561bec509e0b46e4993e1cac402b7247b8 \ + --hash=sha256:d2efee35b4b0a347e0d99d28e884dfd82797852d62fcd7ebdeee26f3ceb72cf3 \ + --hash=sha256:d462f28826f4657968ae51d2181a074dfe03c200d6131690b7d65d55b0f360f8 \ + --hash=sha256:d5e49454f19ef621089e204f862388d29e6e8d8b162efce05208913dde5b9ad6 \ + --hash=sha256:da4813f751142436b075ed7aa012a8778aa43a99f7b36afe9b742d3ed8bdc95e \ + --hash=sha256:db2e408d983b0e61e238cf579c09ef7020560441906ca990fe8412153e3b291f \ + --hash=sha256:db98ad84a55eb09b3c32a96c576476777e87c520a34e2519d3e59c44710c002c \ + --hash=sha256:dbed418ba5c3dce92619656802cc5355cb679e58d0d89b50f116e4a9d5a9603e \ + --hash=sha256:dcdba5c86e368442528f7060039eda390cc4091bfd1dca41e8046af7c910dda8 \ + --hash=sha256:decbfa2f618fa8ed81c95ee18a387ff973143c656ef800c9f24fb7e9c16054e2 \ + --hash=sha256:e4fdb9275308292e880dcbeb12546df7f3e0f96c6b41197e0cf37d2826359020 \ + --hash=sha256:eb1b046be06b0fce7249f1d025cd359b4b80fc1c3e24ad9eca33e0dcdb2e4a35 \ + --hash=sha256:eb6e651000a19c96f452c85132811d25e9264d836951022d6e81df2fff38337d \ + --hash=sha256:ed867c42c268f876097248e05b6117a65bcd1e63b779e916fe2e33cd6fd0d3c3 \ + --hash=sha256:edfad1d29c73f9b863ebe7082ae9321374ccb10879eeabc84ba3b69f2579d537 \ + --hash=sha256:f2058f813d4f2b5e3a9eb2eb3faf8f1d99b81c3e51aeda4b168406443e8ba809 \ + --hash=sha256:f6b2d0c6703c988d334f297aa5df18c45e97b0af3679bb75059e0e0bd8b1069d \ + --hash=sha256:f8212564d49c50eb4565e502814f694e240c55551a5f1bc841d4fcaabb0a9b8a \ + --hash=sha256:ffa565331890b90056c01db69c0fe634a776f8019c143a5ae265f9c6bc4bd6d4 # via -r requirements.in setuptools==68.2.2 \ diff --git a/requirements_lock_3_12.txt b/requirements_lock_3_12.txt index 4697b1849dc273..9bc6eff7313ec3 100644 --- a/requirements_lock_3_12.txt +++ b/requirements_lock_3_12.txt @@ -1,6 +1,6 @@ -absl-py==2.0.0 \ - --hash=sha256:9a28abb62774ae4e8edbe2dd4c49ffcd45a6a848952a5eccc6a49f3f0fc1e2f3 \ - --hash=sha256:d9690211c5fcfefcdd1a45470ac2b5c5acd45241c3af71eed96bc5441746c0d5 +absl-py==2.1.0 \ + --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ + --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff # via # keras-nightly # tb-nightly @@ -12,105 +12,101 @@ astunparse==1.6.3 \ --hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \ --hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8 # via -r requirements.in -cachetools==5.3.2 \ - --hash=sha256:086ee420196f7b2ab9ca2db2520aca326318b68fe5ba8bc4d49cca91add450f2 \ - --hash=sha256:861f35a13a451f94e301ce2bec7cac63e881232ccce7ed67fab9b5df4d3beaa1 - # via google-auth -certifi==2023.7.22 \ - --hash=sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082 \ - --hash=sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9 +certifi==2024.2.2 \ + --hash=sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f \ + --hash=sha256:dc383c07b76109f368f6106eee2b593b04a011ea4d55f652c6ca24a754d1cdd1 # via requests -charset-normalizer==3.3.1 \ - --hash=sha256:06cf46bdff72f58645434d467bf5228080801298fbba19fe268a01b4534467f5 \ - --hash=sha256:0c8c61fb505c7dad1d251c284e712d4e0372cef3b067f7ddf82a7fa82e1e9a93 \ - --hash=sha256:10b8dd31e10f32410751b3430996f9807fc4d1587ca69772e2aa940a82ab571a \ - --hash=sha256:1171ef1fc5ab4693c5d151ae0fdad7f7349920eabbaca6271f95969fa0756c2d \ - --hash=sha256:17a866d61259c7de1bdadef418a37755050ddb4b922df8b356503234fff7932c \ - --hash=sha256:1d6bfc32a68bc0933819cfdfe45f9abc3cae3877e1d90aac7259d57e6e0f85b1 \ - --hash=sha256:1ec937546cad86d0dce5396748bf392bb7b62a9eeb8c66efac60e947697f0e58 \ - --hash=sha256:223b4d54561c01048f657fa6ce41461d5ad8ff128b9678cfe8b2ecd951e3f8a2 \ - --hash=sha256:2465aa50c9299d615d757c1c888bc6fef384b7c4aec81c05a0172b4400f98557 \ - --hash=sha256:28f512b9a33235545fbbdac6a330a510b63be278a50071a336afc1b78781b147 \ - --hash=sha256:2c092be3885a1b7899cd85ce24acedc1034199d6fca1483fa2c3a35c86e43041 \ - --hash=sha256:2c4c99f98fc3a1835af8179dcc9013f93594d0670e2fa80c83aa36346ee763d2 \ - --hash=sha256:31445f38053476a0c4e6d12b047b08ced81e2c7c712e5a1ad97bc913256f91b2 \ - --hash=sha256:31bbaba7218904d2eabecf4feec0d07469284e952a27400f23b6628439439fa7 \ - --hash=sha256:34d95638ff3613849f473afc33f65c401a89f3b9528d0d213c7037c398a51296 \ - --hash=sha256:352a88c3df0d1fa886562384b86f9a9e27563d4704ee0e9d56ec6fcd270ea690 \ - --hash=sha256:39b70a6f88eebe239fa775190796d55a33cfb6d36b9ffdd37843f7c4c1b5dc67 \ - --hash=sha256:3c66df3f41abee950d6638adc7eac4730a306b022570f71dd0bd6ba53503ab57 \ - --hash=sha256:3f70fd716855cd3b855316b226a1ac8bdb3caf4f7ea96edcccc6f484217c9597 \ - --hash=sha256:3f9bc2ce123637a60ebe819f9fccc614da1bcc05798bbbaf2dd4ec91f3e08846 \ - --hash=sha256:3fb765362688821404ad6cf86772fc54993ec11577cd5a92ac44b4c2ba52155b \ - --hash=sha256:45f053a0ece92c734d874861ffe6e3cc92150e32136dd59ab1fb070575189c97 \ - --hash=sha256:46fb9970aa5eeca547d7aa0de5d4b124a288b42eaefac677bde805013c95725c \ - --hash=sha256:4cb50a0335382aac15c31b61d8531bc9bb657cfd848b1d7158009472189f3d62 \ - --hash=sha256:4e12f8ee80aa35e746230a2af83e81bd6b52daa92a8afaef4fea4a2ce9b9f4fa \ - --hash=sha256:4f3100d86dcd03c03f7e9c3fdb23d92e32abbca07e7c13ebd7ddfbcb06f5991f \ - --hash=sha256:4f6e2a839f83a6a76854d12dbebde50e4b1afa63e27761549d006fa53e9aa80e \ - --hash=sha256:4f861d94c2a450b974b86093c6c027888627b8082f1299dfd5a4bae8e2292821 \ - --hash=sha256:501adc5eb6cd5f40a6f77fbd90e5ab915c8fd6e8c614af2db5561e16c600d6f3 \ - --hash=sha256:520b7a142d2524f999447b3a0cf95115df81c4f33003c51a6ab637cbda9d0bf4 \ - --hash=sha256:548eefad783ed787b38cb6f9a574bd8664468cc76d1538215d510a3cd41406cb \ - --hash=sha256:555fe186da0068d3354cdf4bbcbc609b0ecae4d04c921cc13e209eece7720727 \ - --hash=sha256:55602981b2dbf8184c098bc10287e8c245e351cd4fdcad050bd7199d5a8bf514 \ - --hash=sha256:58e875eb7016fd014c0eea46c6fa92b87b62c0cb31b9feae25cbbe62c919f54d \ - --hash=sha256:5a3580a4fdc4ac05f9e53c57f965e3594b2f99796231380adb2baaab96e22761 \ - --hash=sha256:5b70bab78accbc672f50e878a5b73ca692f45f5b5e25c8066d748c09405e6a55 \ - --hash=sha256:5ceca5876032362ae73b83347be8b5dbd2d1faf3358deb38c9c88776779b2e2f \ - --hash=sha256:61f1e3fb621f5420523abb71f5771a204b33c21d31e7d9d86881b2cffe92c47c \ - --hash=sha256:633968254f8d421e70f91c6ebe71ed0ab140220469cf87a9857e21c16687c034 \ - --hash=sha256:63a6f59e2d01310f754c270e4a257426fe5a591dc487f1983b3bbe793cf6bac6 \ - --hash=sha256:63accd11149c0f9a99e3bc095bbdb5a464862d77a7e309ad5938fbc8721235ae \ - --hash=sha256:6db3cfb9b4fcecb4390db154e75b49578c87a3b9979b40cdf90d7e4b945656e1 \ - --hash=sha256:71ef3b9be10070360f289aea4838c784f8b851be3ba58cf796262b57775c2f14 \ - --hash=sha256:7ae8e5142dcc7a49168f4055255dbcced01dc1714a90a21f87448dc8d90617d1 \ - --hash=sha256:7b6cefa579e1237ce198619b76eaa148b71894fb0d6bcf9024460f9bf30fd228 \ - --hash=sha256:800561453acdecedaac137bf09cd719c7a440b6800ec182f077bb8e7025fb708 \ - --hash=sha256:82ca51ff0fc5b641a2d4e1cc8c5ff108699b7a56d7f3ad6f6da9dbb6f0145b48 \ - --hash=sha256:851cf693fb3aaef71031237cd68699dded198657ec1e76a76eb8be58c03a5d1f \ - --hash=sha256:854cc74367180beb327ab9d00f964f6d91da06450b0855cbbb09187bcdb02de5 \ - --hash=sha256:87071618d3d8ec8b186d53cb6e66955ef2a0e4fa63ccd3709c0c90ac5a43520f \ - --hash=sha256:871d045d6ccc181fd863a3cd66ee8e395523ebfbc57f85f91f035f50cee8e3d4 \ - --hash=sha256:8aee051c89e13565c6bd366813c386939f8e928af93c29fda4af86d25b73d8f8 \ - --hash=sha256:8af5a8917b8af42295e86b64903156b4f110a30dca5f3b5aedea123fbd638bff \ - --hash=sha256:8ec8ef42c6cd5856a7613dcd1eaf21e5573b2185263d87d27c8edcae33b62a61 \ - --hash=sha256:91e43805ccafa0a91831f9cd5443aa34528c0c3f2cc48c4cb3d9a7721053874b \ - --hash=sha256:9505dc359edb6a330efcd2be825fdb73ee3e628d9010597aa1aee5aa63442e97 \ - --hash=sha256:985c7965f62f6f32bf432e2681173db41336a9c2611693247069288bcb0c7f8b \ - --hash=sha256:9a74041ba0bfa9bc9b9bb2cd3238a6ab3b7618e759b41bd15b5f6ad958d17605 \ - --hash=sha256:9edbe6a5bf8b56a4a84533ba2b2f489d0046e755c29616ef8830f9e7d9cf5728 \ - --hash=sha256:a15c1fe6d26e83fd2e5972425a772cca158eae58b05d4a25a4e474c221053e2d \ - --hash=sha256:a66bcdf19c1a523e41b8e9d53d0cedbfbac2e93c649a2e9502cb26c014d0980c \ - --hash=sha256:ae4070f741f8d809075ef697877fd350ecf0b7c5837ed68738607ee0a2c572cf \ - --hash=sha256:ae55d592b02c4349525b6ed8f74c692509e5adffa842e582c0f861751701a673 \ - --hash=sha256:b578cbe580e3b41ad17b1c428f382c814b32a6ce90f2d8e39e2e635d49e498d1 \ - --hash=sha256:b891a2f68e09c5ef989007fac11476ed33c5c9994449a4e2c3386529d703dc8b \ - --hash=sha256:baec8148d6b8bd5cee1ae138ba658c71f5b03e0d69d5907703e3e1df96db5e41 \ - --hash=sha256:bb06098d019766ca16fc915ecaa455c1f1cd594204e7f840cd6258237b5079a8 \ - --hash=sha256:bc791ec3fd0c4309a753f95bb6c749ef0d8ea3aea91f07ee1cf06b7b02118f2f \ - --hash=sha256:bd28b31730f0e982ace8663d108e01199098432a30a4c410d06fe08fdb9e93f4 \ - --hash=sha256:be4d9c2770044a59715eb57c1144dedea7c5d5ae80c68fb9959515037cde2008 \ - --hash=sha256:c0c72d34e7de5604df0fde3644cc079feee5e55464967d10b24b1de268deceb9 \ - --hash=sha256:c0e842112fe3f1a4ffcf64b06dc4c61a88441c2f02f373367f7b4c1aa9be2ad5 \ - --hash=sha256:c15070ebf11b8b7fd1bfff7217e9324963c82dbdf6182ff7050519e350e7ad9f \ - --hash=sha256:c2000c54c395d9e5e44c99dc7c20a64dc371f777faf8bae4919ad3e99ce5253e \ - --hash=sha256:c30187840d36d0ba2893bc3271a36a517a717f9fd383a98e2697ee890a37c273 \ - --hash=sha256:cb7cd68814308aade9d0c93c5bd2ade9f9441666f8ba5aa9c2d4b389cb5e2a45 \ - --hash=sha256:cd805513198304026bd379d1d516afbf6c3c13f4382134a2c526b8b854da1c2e \ - --hash=sha256:d0bf89afcbcf4d1bb2652f6580e5e55a840fdf87384f6063c4a4f0c95e378656 \ - --hash=sha256:d9137a876020661972ca6eec0766d81aef8a5627df628b664b234b73396e727e \ - --hash=sha256:dbd95e300367aa0827496fe75a1766d198d34385a58f97683fe6e07f89ca3e3c \ - --hash=sha256:dced27917823df984fe0c80a5c4ad75cf58df0fbfae890bc08004cd3888922a2 \ - --hash=sha256:de0b4caa1c8a21394e8ce971997614a17648f94e1cd0640fbd6b4d14cab13a72 \ - --hash=sha256:debb633f3f7856f95ad957d9b9c781f8e2c6303ef21724ec94bea2ce2fcbd056 \ - --hash=sha256:e372d7dfd154009142631de2d316adad3cc1c36c32a38b16a4751ba78da2a397 \ - --hash=sha256:ecd26be9f112c4f96718290c10f4caea6cc798459a3a76636b817a0ed7874e42 \ - --hash=sha256:edc0202099ea1d82844316604e17d2b175044f9bcb6b398aab781eba957224bd \ - --hash=sha256:f194cce575e59ffe442c10a360182a986535fd90b57f7debfaa5c845c409ecc3 \ - --hash=sha256:f5fb672c396d826ca16a022ac04c9dce74e00a1c344f6ad1a0fdc1ba1f332213 \ - --hash=sha256:f6a02a3c7950cafaadcd46a226ad9e12fc9744652cc69f9e5534f98b47f3bbcf \ - --hash=sha256:fe81b35c33772e56f4b6cf62cf4aedc1762ef7162a31e6ac7fe5e40d0149eb67 +charset-normalizer==3.3.2 \ + --hash=sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027 \ + --hash=sha256:06a81e93cd441c56a9b65d8e1d043daeb97a3d0856d177d5c90ba85acb3db087 \ + --hash=sha256:0a55554a2fa0d408816b3b5cedf0045f4b8e1a6065aec45849de2d6f3f8e9786 \ + --hash=sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8 \ + --hash=sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09 \ + --hash=sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185 \ + --hash=sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574 \ + --hash=sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e \ + --hash=sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519 \ + --hash=sha256:2127566c664442652f024c837091890cb1942c30937add288223dc895793f898 \ + --hash=sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269 \ + --hash=sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3 \ + --hash=sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f \ + --hash=sha256:3287761bc4ee9e33561a7e058c72ac0938c4f57fe49a09eae428fd88aafe7bb6 \ + --hash=sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8 \ + --hash=sha256:37e55c8e51c236f95b033f6fb391d7d7970ba5fe7ff453dad675e88cf303377a \ + --hash=sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73 \ + --hash=sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc \ + --hash=sha256:42cb296636fcc8b0644486d15c12376cb9fa75443e00fb25de0b8602e64c1714 \ + --hash=sha256:45485e01ff4d3630ec0d9617310448a8702f70e9c01906b0d0118bdf9d124cf2 \ + --hash=sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc \ + --hash=sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce \ + --hash=sha256:4d0d1650369165a14e14e1e47b372cfcb31d6ab44e6e33cb2d4e57265290044d \ + --hash=sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e \ + --hash=sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6 \ + --hash=sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269 \ + --hash=sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96 \ + --hash=sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d \ + --hash=sha256:6463effa3186ea09411d50efc7d85360b38d5f09b870c48e4600f63af490e56a \ + --hash=sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4 \ + --hash=sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77 \ + --hash=sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d \ + --hash=sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0 \ + --hash=sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed \ + --hash=sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068 \ + --hash=sha256:6c4caeef8fa63d06bd437cd4bdcf3ffefe6738fb1b25951440d80dc7df8c03ac \ + --hash=sha256:6ef1d82a3af9d3eecdba2321dc1b3c238245d890843e040e41e470ffa64c3e25 \ + --hash=sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8 \ + --hash=sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab \ + --hash=sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26 \ + --hash=sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2 \ + --hash=sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db \ + --hash=sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f \ + --hash=sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5 \ + --hash=sha256:86216b5cee4b06df986d214f664305142d9c76df9b6512be2738aa72a2048f99 \ + --hash=sha256:87d1351268731db79e0f8e745d92493ee2841c974128ef629dc518b937d9194c \ + --hash=sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d \ + --hash=sha256:8c622a5fe39a48f78944a87d4fb8a53ee07344641b0562c540d840748571b811 \ + --hash=sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa \ + --hash=sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a \ + --hash=sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03 \ + --hash=sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b \ + --hash=sha256:923c0c831b7cfcb071580d3f46c4baf50f174be571576556269530f4bbd79d04 \ + --hash=sha256:95f2a5796329323b8f0512e09dbb7a1860c46a39da62ecb2324f116fa8fdc85c \ + --hash=sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001 \ + --hash=sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458 \ + --hash=sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389 \ + --hash=sha256:a50aebfa173e157099939b17f18600f72f84eed3049e743b68ad15bd69b6bf99 \ + --hash=sha256:a981a536974bbc7a512cf44ed14938cf01030a99e9b3a06dd59578882f06f985 \ + --hash=sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537 \ + --hash=sha256:ae5f4161f18c61806f411a13b0310bea87f987c7d2ecdbdaad0e94eb2e404238 \ + --hash=sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f \ + --hash=sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d \ + --hash=sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796 \ + --hash=sha256:b2b0a0c0517616b6869869f8c581d4eb2dd83a4d79e0ebcb7d373ef9956aeb0a \ + --hash=sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143 \ + --hash=sha256:bd8f7df7d12c2db9fab40bdd87a7c09b1530128315d047a086fa3ae3435cb3a8 \ + --hash=sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c \ + --hash=sha256:c002b4ffc0be611f0d9da932eb0f704fe2602a9a949d1f738e4c34c75b0863d5 \ + --hash=sha256:c083af607d2515612056a31f0a8d9e0fcb5876b7bfc0abad3ecd275bc4ebc2d5 \ + --hash=sha256:c180f51afb394e165eafe4ac2936a14bee3eb10debc9d9e4db8958fe36afe711 \ + --hash=sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4 \ + --hash=sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6 \ + --hash=sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c \ + --hash=sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7 \ + --hash=sha256:db364eca23f876da6f9e16c9da0df51aa4f104a972735574842618b8c6d999d4 \ + --hash=sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b \ + --hash=sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae \ + --hash=sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12 \ + --hash=sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c \ + --hash=sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae \ + --hash=sha256:eb00ed941194665c332bf8e078baf037d6c35d7c4f3102ea2d4f16ca94a26dc8 \ + --hash=sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887 \ + --hash=sha256:eb8821e09e916165e160797a6c17edda0679379a4be5c716c260e836e122f54b \ + --hash=sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4 \ + --hash=sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f \ + --hash=sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5 \ + --hash=sha256:fb69256e180cb6c8a894fee62b3afebae785babc1ee98b81cdf68bbca1987f33 \ + --hash=sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519 \ + --hash=sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561 # via requests dill==0.3.7 \ --hash=sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e \ @@ -133,6 +129,7 @@ dm-tree==0.1.8 \ --hash=sha256:35cc164a79336bfcfafb47e5f297898359123bbd3330c1967f0c4994f9cf9f60 \ --hash=sha256:378cc8ad93c5fe3590f405a309980721f021c790ca1bdf9b15bb1d59daec57f5 \ --hash=sha256:39070ba268c0491af9fe7a58644d99e8b4f2cde6e5884ba3380bddc84ed43d5f \ + --hash=sha256:435227cf3c5dc63f4de054cf3d00183790bd9ead4c3623138c74dde7f67f521b \ --hash=sha256:5483dca4d7eb1a0d65fe86d3b6a53ae717face83c1f17e0887b1a4a64ae5c410 \ --hash=sha256:694c3654cfd2a81552c08ec66bb5c4a3d48fa292b9a181880fb081c36c5b9134 \ --hash=sha256:803bfc53b4659f447ac694dbd04235f94a73ef7c1fd1e0df7c84ac41e0bc963b \ @@ -140,6 +137,8 @@ dm-tree==0.1.8 \ --hash=sha256:83b7764de0d855338abefc6e3ee9fe40d301668310aa3baea3f778ff051f4393 \ --hash=sha256:8c60a7eadab64c2278861f56bca320b2720f163dca9d7558103c3b77f2416571 \ --hash=sha256:8ed3564abed97c806db122c2d3e1a2b64c74a63debe9903aad795167cc301368 \ + --hash=sha256:94d3f0826311f45ee19b75f5b48c99466e4218a0489e81c0f0167bda50cacf22 \ + --hash=sha256:96a548a406a6fb15fe58f6a30a57ff2f2aafbf25f05afab00c8f5e5977b6c715 \ --hash=sha256:a5d819c38c03f0bb5b3b3703c60e4b170355a0fc6b5819325bf3d4ceb3ae7e80 \ --hash=sha256:ad16ceba90a56ec47cf45b21856d14962ac314787975ef786efb5e6e9ca75ec7 \ --hash=sha256:af4b3d372f2477dcd89a6e717e4a575ca35ccc20cc4454a8a4b6f8838a00672d \ @@ -147,6 +146,7 @@ dm-tree==0.1.8 \ --hash=sha256:b9bd9b9ccb59409d33d51d84b7668010c04c2af7d4a371632874c1ca356cff3d \ --hash=sha256:b9f89a454e98806b44fe9d40ec9eee61f848388f7e79ac2371a55679bd5a3ac6 \ --hash=sha256:bb2d109f42190225112da899b9f3d46d0d5f26aef501c61e43529fe9322530b5 \ + --hash=sha256:c0a94aba18a35457a1b5cd716fd7b46c5dafdc4cf7869b4bae665b91c4682a8e \ --hash=sha256:c5c8c12e3fda754ef6af94161bacdaeda816d941995fac415d6855c6c386af68 \ --hash=sha256:d1612fcaecd79023dbc6a6ae48d51a80beb5c385d6f3f6d71688e57bc8d07de8 \ --hash=sha256:d16e1f2a073604cfcc09f7131ae8d534674f43c3aef4c25742eae295bc60d04f \ @@ -154,6 +154,7 @@ dm-tree==0.1.8 \ --hash=sha256:d40fa4106ca6edc66760246a08f500ec0c85ef55c762fb4a363f6ee739ba02ee \ --hash=sha256:de287fabc464b8734be251e46e06aa9aa1001f34198da2b6ce07bd197172b9cb \ --hash=sha256:e4d714371bb08839e4e5e29024fc95832d9affe129825ef38836b143028bd144 \ + --hash=sha256:ea9e59e0451e7d29aece402d9f908f2e2a80922bcde2ebfd5dcb07750fcbfee8 \ --hash=sha256:f7ac31b9aecccb2c6e1ab29706f6ded3eba0c2c69c770322c9c685929c3d6afb \ --hash=sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d # via keras-nightly @@ -161,71 +162,61 @@ gast==0.4.0 \ --hash=sha256:40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1 \ --hash=sha256:b7adcdd5adbebf1adf17378da5ba3f543684dbec47b1cda1f3997e573cd542c4 # via -r requirements.in -google-auth==2.23.3 \ - --hash=sha256:6864247895eea5d13b9c57c9e03abb49cb94ce2dc7c58e91cba3248c7477c9e3 \ - --hash=sha256:a8f4608e65c244ead9e0538f181a96c6e11199ec114d41f1d7b1bffa96937bda - # via - # google-auth-oauthlib - # tb-nightly -google-auth-oauthlib==1.1.0 \ - --hash=sha256:089c6e587d36f4803ac7e0720c045c6a8b1fd1790088b8424975b90d0ee61c12 \ - --hash=sha256:83ea8c3b0881e453790baff4448e8a6112ac8778d1de9da0b68010b843937afb - # via tb-nightly -grpcio==1.59.2 \ - --hash=sha256:023088764012411affe7db183d1ada3ad9daf2e23ddc719ff46d7061de661340 \ - --hash=sha256:08d77e682f2bf730a4961eea330e56d2f423c6a9b91ca222e5b1eb24a357b19f \ - --hash=sha256:0a4a3833c0e067f3558538727235cd8a49709bff1003200bbdefa2f09334e4b1 \ - --hash=sha256:0a754aff9e3af63bdc4c75c234b86b9d14e14a28a30c4e324aed1a9b873d755f \ - --hash=sha256:11168ef43e4a43ff1b1a65859f3e0ef1a173e277349e7fb16923ff108160a8cd \ - --hash=sha256:128e20f57c5f27cb0157e73756d1586b83c1b513ebecc83ea0ac37e4b0e4e758 \ - --hash=sha256:1f9524d1d701e399462d2c90ba7c193e49d1711cf429c0d3d97c966856e03d00 \ - --hash=sha256:1ff16d68bf453275466a9a46739061a63584d92f18a0f5b33d19fc97eb69867c \ - --hash=sha256:2067274c88bc6de89c278a672a652b4247d088811ece781a4858b09bdf8448e3 \ - --hash=sha256:2171c39f355ba5b551c5d5928d65aa6c69807fae195b86ef4a7d125bcdb860a9 \ - --hash=sha256:242adc47725b9a499ee77c6a2e36688fa6c96484611f33b1be4c57ab075a92dd \ - --hash=sha256:27f879ae604a7fcf371e59fba6f3ff4635a4c2a64768bd83ff0cac503142fef4 \ - --hash=sha256:2b230028a008ae1d0f430acb227d323ff8a619017415cf334c38b457f814119f \ - --hash=sha256:3059668df17627f0e0fa680e9ef8c995c946c792612e9518f5cc1503be14e90b \ - --hash=sha256:31176aa88f36020055ace9adff2405a33c8bdbfa72a9c4980e25d91b2f196873 \ - --hash=sha256:36f53c2b3449c015880e7d55a89c992c357f176327b0d2873cdaaf9628a37c69 \ - --hash=sha256:3b4368b33908f683a363f376dfb747d40af3463a6e5044afee07cf9436addf96 \ - --hash=sha256:3c61d641d4f409c5ae46bfdd89ea42ce5ea233dcf69e74ce9ba32b503c727e29 \ - --hash=sha256:4abb717e320e74959517dc8e84a9f48fbe90e9abe19c248541e9418b1ce60acd \ - --hash=sha256:4c93f4abbb54321ee6471e04a00139c80c754eda51064187963ddf98f5cf36a4 \ - --hash=sha256:535561990e075fa6bd4b16c4c3c1096b9581b7bb35d96fac4650f1181e428268 \ - --hash=sha256:53c9aa5ddd6857c0a1cd0287225a2a25873a8e09727c2e95c4aebb1be83a766a \ - --hash=sha256:5d573e70a6fe77555fb6143c12d3a7d3fa306632a3034b4e7c59ca09721546f8 \ - --hash=sha256:6009386a2df66159f64ac9f20425ae25229b29b9dd0e1d3dd60043f037e2ad7e \ - --hash=sha256:686e975a5d16602dc0982c7c703948d17184bd1397e16c8ee03511ecb8c4cdda \ - --hash=sha256:6959fb07e8351e20501ffb8cc4074c39a0b7ef123e1c850a7f8f3afdc3a3da01 \ - --hash=sha256:6b25ed37c27e652db01be341af93fbcea03d296c024d8a0e680017a268eb85dd \ - --hash=sha256:6da6dea3a1bacf99b3c2187e296db9a83029ed9c38fd4c52b7c9b7326d13c828 \ - --hash=sha256:72ca2399097c0b758198f2ff30f7178d680de8a5cfcf3d9b73a63cf87455532e \ - --hash=sha256:73abb8584b0cf74d37f5ef61c10722adc7275502ab71789a8fe3cb7ef04cf6e2 \ - --hash=sha256:74100fecaec8a535e380cf5f2fb556ff84957d481c13e54051c52e5baac70541 \ - --hash=sha256:75c6ecb70e809cf1504465174343113f51f24bc61e22a80ae1c859f3f7034c6d \ - --hash=sha256:7cf05053242f61ba94014dd3a986e11a083400a32664058f80bf4cf817c0b3a1 \ - --hash=sha256:9411e24328a2302e279e70cae6e479f1fddde79629fcb14e03e6d94b3956eabf \ - --hash=sha256:a213acfbf186b9f35803b52e4ca9addb153fc0b67f82a48f961be7000ecf6721 \ - --hash=sha256:bb7e0fe6ad73b7f06d7e2b689c19a71cf5cc48f0c2bf8608469e51ffe0bd2867 \ - --hash=sha256:c2504eed520958a5b77cc99458297cb7906308cb92327f35fb7fbbad4e9b2188 \ - --hash=sha256:c35aa9657f5d5116d23b934568e0956bd50c615127810fffe3ac356a914c176a \ - --hash=sha256:c5f09cffa619adfb44799fa4a81c2a1ad77c887187613fb0a8f201ab38d89ba1 \ - --hash=sha256:c978f864b35f2261e0819f5cd88b9830b04dc51bcf055aac3c601e525a10d2ba \ - --hash=sha256:cbe946b3e6e60a7b4618f091e62a029cb082b109a9d6b53962dd305087c6e4fd \ - --hash=sha256:cc3e4cd087f07758b16bef8f31d88dbb1b5da5671d2f03685ab52dece3d7a16e \ - --hash=sha256:cf0dead5a2c5a3347af2cfec7131d4f2a2e03c934af28989c9078f8241a491fa \ - --hash=sha256:d2794f0e68b3085d99b4f6ff9c089f6fdd02b32b9d3efdfbb55beac1bf22d516 \ - --hash=sha256:d2fa68a96a30dd240be80bbad838a0ac81a61770611ff7952b889485970c4c71 \ - --hash=sha256:d6f70406695e3220f09cd7a2f879333279d91aa4a8a1d34303b56d61a8180137 \ - --hash=sha256:d8f9cd4ad1be90b0cf350a2f04a38a36e44a026cac1e036ac593dc48efe91d52 \ - --hash=sha256:da2d94c15f88cd40d7e67f7919d4f60110d2b9d5b1e08cf354c2be773ab13479 \ - --hash=sha256:e1727c1c0e394096bb9af185c6923e8ea55a5095b8af44f06903bcc0e06800a2 \ - --hash=sha256:e420ced29b5904cdf9ee5545e23f9406189d8acb6750916c2db4793dada065c6 \ - --hash=sha256:e82c5cf1495244adf5252f925ac5932e5fd288b3e5ab6b70bec5593074b7236c \ - --hash=sha256:f1ef0d39bc1feb420caf549b3c657c871cad4ebbcf0580c4d03816b0590de0cf \ - --hash=sha256:f8753a6c88d1d0ba64302309eecf20f70d2770f65ca02d83c2452279085bfcd3 \ - --hash=sha256:f93dbf58f03146164048be5426ffde298b237a5e059144847e4940f5b80172c3 +grpcio==1.60.1 \ + --hash=sha256:0250a7a70b14000fa311de04b169cc7480be6c1a769b190769d347939d3232a8 \ + --hash=sha256:069fe2aeee02dfd2135d562d0663fe70fbb69d5eed6eb3389042a7e963b54de8 \ + --hash=sha256:082081e6a36b6eb5cf0fd9a897fe777dbb3802176ffd08e3ec6567edd85bc104 \ + --hash=sha256:0c5807e9152eff15f1d48f6b9ad3749196f79a4a050469d99eecb679be592acc \ + --hash=sha256:14e8f2c84c0832773fb3958240c69def72357bc11392571f87b2d7b91e0bb092 \ + --hash=sha256:2a6087f234cb570008a6041c8ffd1b7d657b397fdd6d26e83d72283dae3527b1 \ + --hash=sha256:2bb2a2911b028f01c8c64d126f6b632fcd8a9ac975aa1b3855766c94e4107180 \ + --hash=sha256:2f44c32aef186bbba254129cea1df08a20be414144ac3bdf0e84b24e3f3b2e05 \ + --hash=sha256:30e980cd6db1088c144b92fe376747328d5554bc7960ce583ec7b7d81cd47287 \ + --hash=sha256:33aed0a431f5befeffd9d346b0fa44b2c01aa4aeae5ea5b2c03d3e25e0071216 \ + --hash=sha256:33bdea30dcfd4f87b045d404388469eb48a48c33a6195a043d116ed1b9a0196c \ + --hash=sha256:39aa848794b887120b1d35b1b994e445cc028ff602ef267f87c38122c1add50d \ + --hash=sha256:4216e67ad9a4769117433814956031cb300f85edc855252a645a9a724b3b6594 \ + --hash=sha256:49c9b6a510e3ed8df5f6f4f3c34d7fbf2d2cae048ee90a45cd7415abab72912c \ + --hash=sha256:4eec8b8c1c2c9b7125508ff7c89d5701bf933c99d3910e446ed531cd16ad5d87 \ + --hash=sha256:50d56280b482875d1f9128ce596e59031a226a8b84bec88cb2bf76c289f5d0de \ + --hash=sha256:53b69e79d00f78c81eecfb38f4516080dc7f36a198b6b37b928f1c13b3c063e9 \ + --hash=sha256:55ccb7db5a665079d68b5c7c86359ebd5ebf31a19bc1a91c982fd622f1e31ff2 \ + --hash=sha256:5a1ebbae7e2214f51b1f23b57bf98eeed2cf1ba84e4d523c48c36d5b2f8829ff \ + --hash=sha256:61b7199cd2a55e62e45bfb629a35b71fc2c0cb88f686a047f25b1112d3810904 \ + --hash=sha256:660fc6b9c2a9ea3bb2a7e64ba878c98339abaf1811edca904ac85e9e662f1d73 \ + --hash=sha256:6d140bdeb26cad8b93c1455fa00573c05592793c32053d6e0016ce05ba267549 \ + --hash=sha256:6e490fa5f7f5326222cb9f0b78f207a2b218a14edf39602e083d5f617354306f \ + --hash=sha256:6ecf21d20d02d1733e9c820fb5c114c749d888704a7ec824b545c12e78734d1c \ + --hash=sha256:70c83bb530572917be20c21f3b6be92cd86b9aecb44b0c18b1d3b2cc3ae47df0 \ + --hash=sha256:72153a0d2e425f45b884540a61c6639436ddafa1829a42056aa5764b84108b8e \ + --hash=sha256:73e14acd3d4247169955fae8fb103a2b900cfad21d0c35f0dcd0fdd54cd60367 \ + --hash=sha256:76eaaba891083fcbe167aa0f03363311a9f12da975b025d30e94b93ac7a765fc \ + --hash=sha256:79ae0dc785504cb1e1788758c588c711f4e4a0195d70dff53db203c95a0bd303 \ + --hash=sha256:7d142bcd604166417929b071cd396aa13c565749a4c840d6c702727a59d835eb \ + --hash=sha256:8c9554ca8e26241dabe7951aa1fa03a1ba0856688ecd7e7bdbdd286ebc272e4c \ + --hash=sha256:8d488fbdbf04283f0d20742b64968d44825617aa6717b07c006168ed16488804 \ + --hash=sha256:91422ba785a8e7a18725b1dc40fbd88f08a5bb4c7f1b3e8739cab24b04fa8a03 \ + --hash=sha256:9a66f4d2a005bc78e61d805ed95dedfcb35efa84b7bba0403c6d60d13a3de2d6 \ + --hash=sha256:9b106bc52e7f28170e624ba61cc7dc6829566e535a6ec68528f8e1afbed1c41f \ + --hash=sha256:9b54577032d4f235452f77a83169b6527bf4b77d73aeada97d45b2aaf1bf5ce0 \ + --hash=sha256:a09506eb48fa5493c58f946c46754ef22f3ec0df64f2b5149373ff31fb67f3dd \ + --hash=sha256:a212e5dea1a4182e40cd3e4067ee46be9d10418092ce3627475e995cca95de21 \ + --hash=sha256:a731ac5cffc34dac62053e0da90f0c0b8560396a19f69d9703e88240c8f05858 \ + --hash=sha256:af5ef6cfaf0d023c00002ba25d0751e5995fa0e4c9eec6cd263c30352662cbce \ + --hash=sha256:b58b855d0071575ea9c7bc0d84a06d2edfbfccec52e9657864386381a7ce1ae9 \ + --hash=sha256:bc808924470643b82b14fe121923c30ec211d8c693e747eba8a7414bc4351a23 \ + --hash=sha256:c557e94e91a983e5b1e9c60076a8fd79fea1e7e06848eb2e48d0ccfb30f6e073 \ + --hash=sha256:c71be3f86d67d8d1311c6076a4ba3b75ba5703c0b856b4e691c9097f9b1e8bd2 \ + --hash=sha256:c8754c75f55781515a3005063d9a05878b2cfb3cb7e41d5401ad0cf19de14872 \ + --hash=sha256:cb0af13433dbbd1c806e671d81ec75bd324af6ef75171fd7815ca3074fe32bfe \ + --hash=sha256:cba6209c96828711cb7c8fcb45ecef8c8859238baf15119daa1bef0f6c84bfe7 \ + --hash=sha256:cf77f8cf2a651fbd869fbdcb4a1931464189cd210abc4cfad357f1cacc8642a6 \ + --hash=sha256:d7404cebcdb11bb5bd40bf94131faf7e9a7c10a6c60358580fe83913f360f929 \ + --hash=sha256:dd1d3a8d1d2e50ad9b59e10aa7f07c7d1be2b367f3f2d33c5fade96ed5460962 \ + --hash=sha256:e5d97c65ea7e097056f3d1ead77040ebc236feaf7f71489383d20f3b4c28412a \ + --hash=sha256:f1c3dc536b3ee124e8b24feb7533e5c70b9f2ef833e3b2e5513b2897fd46763a \ + --hash=sha256:f2212796593ad1d0235068c79836861f2201fc7137a99aa2fea7beeb3b101177 \ + --hash=sha256:fead980fbc68512dfd4e0c7b1f5754c2a8e5015a04dea454b9cada54a8423525 # via # -r requirements.in # tb-nightly @@ -258,150 +249,156 @@ h5py==3.10.0 \ # via # -r requirements.in # keras-nightly -idna==3.4 \ - --hash=sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4 \ - --hash=sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2 +idna==3.6 \ + --hash=sha256:9ecdbbd083b06798ae1e86adcbfe8ab1479cf864e4ee30fe4e46a003d12491ca \ + --hash=sha256:c05567e9c24a6b9faaa835c4821bad0590fbb9d5779e7caa6e1cc4978e7eb24f # via requests jax==0.4.7 \ --hash=sha256:5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8 # via -r requirements.in -keras-nightly==3.0.0.dev2023103103 \ - --hash=sha256:25a6a9030d7e067d535ec41ae6700636c90cabada80fbddb3bb8775006807860 \ - --hash=sha256:c39fccebb3d4cc9d838371981c4d3eef08fc06eadf6cffb39b356cb625bab50f +keras-nightly==3.0.4.dev2024021403 \ + --hash=sha256:24ce69d29d582771685bf4235f59663723405b5a5b16f3eaff2657e52e74663a \ + --hash=sha256:9f416e66b820ef833779d219d255b346b8b90a72fdbd0b2f1e90a43ad142a03d # via -r requirements.in -lit==17.0.4 \ - --hash=sha256:ee2e180128e770abc6aed3a02de2daf09d81b7d30225e315205d3599c311d304 +lit==17.0.6 \ + --hash=sha256:dfa9af9b55fc4509a56be7bf2346f079d7f4a242d583b9f2e0b078fd0abae31b # via -r requirements.in -markdown==3.5 \ - --hash=sha256:4afb124395ce5fc34e6d9886dab977fd9ae987fc6e85689f08278cf0c69d4bf3 \ - --hash=sha256:a807eb2e4778d9156c8f07876c6e4d50b5494c5665c4834f67b06459dfd877b3 +markdown==3.5.2 \ + --hash=sha256:d43323865d89fc0cb9b20c75fc8ad313af307cc087e84b657d9eec768eddeadd \ + --hash=sha256:e1ac7b3dc550ee80e602e71c1d168002f062e49f1b11e26a36264dafd4df2ef8 # via tb-nightly markdown-it-py==3.0.0 \ --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb # via rich -markupsafe==2.1.3 \ - --hash=sha256:05fb21170423db021895e1ea1e1f3ab3adb85d1c2333cbc2310f2a26bc77272e \ - --hash=sha256:0a4e4a1aff6c7ac4cd55792abf96c915634c2b97e3cc1c7129578aa68ebd754e \ - --hash=sha256:10bbfe99883db80bdbaff2dcf681dfc6533a614f700da1287707e8a5d78a8431 \ - --hash=sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686 \ - --hash=sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c \ - --hash=sha256:1577735524cdad32f9f694208aa75e422adba74f1baee7551620e43a3141f559 \ - --hash=sha256:1b40069d487e7edb2676d3fbdb2b0829ffa2cd63a2ec26c4938b2d34391b4ecc \ - --hash=sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb \ - --hash=sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939 \ - --hash=sha256:282c2cb35b5b673bbcadb33a585408104df04f14b2d9b01d4c345a3b92861c2c \ - --hash=sha256:2c1b19b3aaacc6e57b7e25710ff571c24d6c3613a45e905b1fde04d691b98ee0 \ - --hash=sha256:2ef12179d3a291be237280175b542c07a36e7f60718296278d8593d21ca937d4 \ - --hash=sha256:338ae27d6b8745585f87218a3f23f1512dbf52c26c28e322dbe54bcede54ccb9 \ - --hash=sha256:3c0fae6c3be832a0a0473ac912810b2877c8cb9d76ca48de1ed31e1c68386575 \ - --hash=sha256:3fd4abcb888d15a94f32b75d8fd18ee162ca0c064f35b11134be77050296d6ba \ - --hash=sha256:42de32b22b6b804f42c5d98be4f7e5e977ecdd9ee9b660fda1a3edf03b11792d \ - --hash=sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd \ - --hash=sha256:504b320cd4b7eff6f968eddf81127112db685e81f7e36e75f9f84f0df46041c3 \ - --hash=sha256:525808b8019e36eb524b8c68acdd63a37e75714eac50e988180b169d64480a00 \ - --hash=sha256:56d9f2ecac662ca1611d183feb03a3fa4406469dafe241673d521dd5ae92a155 \ - --hash=sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac \ - --hash=sha256:65c1a9bcdadc6c28eecee2c119465aebff8f7a584dd719facdd9e825ec61ab52 \ - --hash=sha256:68e78619a61ecf91e76aa3e6e8e33fc4894a2bebe93410754bd28fce0a8a4f9f \ - --hash=sha256:69c0f17e9f5a7afdf2cc9fb2d1ce6aabdb3bafb7f38017c0b77862bcec2bbad8 \ - --hash=sha256:6b2b56950d93e41f33b4223ead100ea0fe11f8e6ee5f641eb753ce4b77a7042b \ - --hash=sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007 \ - --hash=sha256:787003c0ddb00500e49a10f2844fac87aa6ce977b90b0feaaf9de23c22508b24 \ - --hash=sha256:7ef3cb2ebbf91e330e3bb937efada0edd9003683db6b57bb108c4001f37a02ea \ - --hash=sha256:8023faf4e01efadfa183e863fefde0046de576c6f14659e8782065bcece22198 \ - --hash=sha256:8758846a7e80910096950b67071243da3e5a20ed2546e6392603c096778d48e0 \ - --hash=sha256:8afafd99945ead6e075b973fefa56379c5b5c53fd8937dad92c662da5d8fd5ee \ - --hash=sha256:8c41976a29d078bb235fea9b2ecd3da465df42a562910f9022f1a03107bd02be \ - --hash=sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2 \ - --hash=sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1 \ - --hash=sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707 \ - --hash=sha256:962f82a3086483f5e5f64dbad880d31038b698494799b097bc59c2edf392fce6 \ - --hash=sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c \ - --hash=sha256:9dcdfd0eaf283af041973bff14a2e143b8bd64e069f4c383416ecd79a81aab58 \ - --hash=sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823 \ - --hash=sha256:aa7bd130efab1c280bed0f45501b7c8795f9fdbeb02e965371bbef3523627779 \ - --hash=sha256:ab4a0df41e7c16a1392727727e7998a467472d0ad65f3ad5e6e765015df08636 \ - --hash=sha256:ad9e82fb8f09ade1c3e1b996a6337afac2b8b9e365f926f5a61aacc71adc5b3c \ - --hash=sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad \ - --hash=sha256:b076b6226fb84157e3f7c971a47ff3a679d837cf338547532ab866c57930dbee \ - --hash=sha256:b7ff0f54cb4ff66dd38bebd335a38e2c22c41a8ee45aa608efc890ac3e3931bc \ - --hash=sha256:bfce63a9e7834b12b87c64d6b155fdd9b3b96191b6bd334bf37db7ff1fe457f2 \ - --hash=sha256:c011a4149cfbcf9f03994ec2edffcb8b1dc2d2aede7ca243746df97a5d41ce48 \ - --hash=sha256:c9c804664ebe8f83a211cace637506669e7890fec1b4195b505c214e50dd4eb7 \ - --hash=sha256:ca379055a47383d02a5400cb0d110cef0a776fc644cda797db0c5696cfd7e18e \ - --hash=sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b \ - --hash=sha256:cd0f502fe016460680cd20aaa5a76d241d6f35a1c3350c474bac1273803893fa \ - --hash=sha256:ceb01949af7121f9fc39f7d27f91be8546f3fb112c608bc4029aef0bab86a2a5 \ - --hash=sha256:d080e0a5eb2529460b30190fcfcc4199bd7f827663f858a226a81bc27beaa97e \ - --hash=sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb \ - --hash=sha256:df0be2b576a7abbf737b1575f048c23fb1d769f267ec4358296f31c2479db8f9 \ - --hash=sha256:e09031c87a1e51556fdcb46e5bd4f59dfb743061cf93c4d6831bf894f125eb57 \ - --hash=sha256:e4dd52d80b8c83fdce44e12478ad2e85c64ea965e75d66dbeafb0a3e77308fcc \ - --hash=sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc \ - --hash=sha256:fec21693218efe39aa7f8599346e90c705afa52c5b31ae019b2e57e8f6542bb2 \ - --hash=sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11 +markupsafe==2.1.5 \ + --hash=sha256:00e046b6dd71aa03a41079792f8473dc494d564611a8f89bbbd7cb93295ebdcf \ + --hash=sha256:075202fa5b72c86ad32dc7d0b56024ebdbcf2048c0ba09f1cde31bfdd57bcfff \ + --hash=sha256:0e397ac966fdf721b2c528cf028494e86172b4feba51d65f81ffd65c63798f3f \ + --hash=sha256:17b950fccb810b3293638215058e432159d2b71005c74371d784862b7e4683f3 \ + --hash=sha256:1f3fbcb7ef1f16e48246f704ab79d79da8a46891e2da03f8783a5b6fa41a9532 \ + --hash=sha256:2174c595a0d73a3080ca3257b40096db99799265e1c27cc5a610743acd86d62f \ + --hash=sha256:2b7c57a4dfc4f16f7142221afe5ba4e093e09e728ca65c51f5620c9aaeb9a617 \ + --hash=sha256:2d2d793e36e230fd32babe143b04cec8a8b3eb8a3122d2aceb4a371e6b09b8df \ + --hash=sha256:30b600cf0a7ac9234b2638fbc0fb6158ba5bdcdf46aeb631ead21248b9affbc4 \ + --hash=sha256:397081c1a0bfb5124355710fe79478cdbeb39626492b15d399526ae53422b906 \ + --hash=sha256:3a57fdd7ce31c7ff06cdfbf31dafa96cc533c21e443d57f5b1ecc6cdc668ec7f \ + --hash=sha256:3c6b973f22eb18a789b1460b4b91bf04ae3f0c4234a0a6aa6b0a92f6f7b951d4 \ + --hash=sha256:3e53af139f8579a6d5f7b76549125f0d94d7e630761a2111bc431fd820e163b8 \ + --hash=sha256:4096e9de5c6fdf43fb4f04c26fb114f61ef0bf2e5604b6ee3019d51b69e8c371 \ + --hash=sha256:4275d846e41ecefa46e2015117a9f491e57a71ddd59bbead77e904dc02b1bed2 \ + --hash=sha256:4c31f53cdae6ecfa91a77820e8b151dba54ab528ba65dfd235c80b086d68a465 \ + --hash=sha256:4f11aa001c540f62c6166c7726f71f7573b52c68c31f014c25cc7901deea0b52 \ + --hash=sha256:5049256f536511ee3f7e1b3f87d1d1209d327e818e6ae1365e8653d7e3abb6a6 \ + --hash=sha256:58c98fee265677f63a4385256a6d7683ab1832f3ddd1e66fe948d5880c21a169 \ + --hash=sha256:598e3276b64aff0e7b3451b72e94fa3c238d452e7ddcd893c3ab324717456bad \ + --hash=sha256:5b7b716f97b52c5a14bffdf688f971b2d5ef4029127f1ad7a513973cfd818df2 \ + --hash=sha256:5dedb4db619ba5a2787a94d877bc8ffc0566f92a01c0ef214865e54ecc9ee5e0 \ + --hash=sha256:619bc166c4f2de5caa5a633b8b7326fbe98e0ccbfacabd87268a2b15ff73a029 \ + --hash=sha256:629ddd2ca402ae6dbedfceeba9c46d5f7b2a61d9749597d4307f943ef198fc1f \ + --hash=sha256:656f7526c69fac7f600bd1f400991cc282b417d17539a1b228617081106feb4a \ + --hash=sha256:6ec585f69cec0aa07d945b20805be741395e28ac1627333b1c5b0105962ffced \ + --hash=sha256:72b6be590cc35924b02c78ef34b467da4ba07e4e0f0454a2c5907f473fc50ce5 \ + --hash=sha256:7502934a33b54030eaf1194c21c692a534196063db72176b0c4028e140f8f32c \ + --hash=sha256:7a68b554d356a91cce1236aa7682dc01df0edba8d043fd1ce607c49dd3c1edcf \ + --hash=sha256:7b2e5a267c855eea6b4283940daa6e88a285f5f2a67f2220203786dfa59b37e9 \ + --hash=sha256:823b65d8706e32ad2df51ed89496147a42a2a6e01c13cfb6ffb8b1e92bc910bb \ + --hash=sha256:8590b4ae07a35970728874632fed7bd57b26b0102df2d2b233b6d9d82f6c62ad \ + --hash=sha256:8dd717634f5a044f860435c1d8c16a270ddf0ef8588d4887037c5028b859b0c3 \ + --hash=sha256:8dec4936e9c3100156f8a2dc89c4b88d5c435175ff03413b443469c7c8c5f4d1 \ + --hash=sha256:97cafb1f3cbcd3fd2b6fbfb99ae11cdb14deea0736fc2b0952ee177f2b813a46 \ + --hash=sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc \ + --hash=sha256:a549b9c31bec33820e885335b451286e2969a2d9e24879f83fe904a5ce59d70a \ + --hash=sha256:ac07bad82163452a6884fe8fa0963fb98c2346ba78d779ec06bd7a6262132aee \ + --hash=sha256:ae2ad8ae6ebee9d2d94b17fb62763125f3f374c25618198f40cbb8b525411900 \ + --hash=sha256:b91c037585eba9095565a3556f611e3cbfaa42ca1e865f7b8015fe5c7336d5a5 \ + --hash=sha256:bc1667f8b83f48511b94671e0e441401371dfd0f0a795c7daa4a3cd1dde55bea \ + --hash=sha256:bec0a414d016ac1a18862a519e54b2fd0fc8bbfd6890376898a6c0891dd82e9f \ + --hash=sha256:bf50cd79a75d181c9181df03572cdce0fbb75cc353bc350712073108cba98de5 \ + --hash=sha256:bff1b4290a66b490a2f4719358c0cdcd9bafb6b8f061e45c7a2460866bf50c2e \ + --hash=sha256:c061bb86a71b42465156a3ee7bd58c8c2ceacdbeb95d05a99893e08b8467359a \ + --hash=sha256:c8b29db45f8fe46ad280a7294f5c3ec36dbac9491f2d1c17345be8e69cc5928f \ + --hash=sha256:ce409136744f6521e39fd8e2a24c53fa18ad67aa5bc7c2cf83645cce5b5c4e50 \ + --hash=sha256:d050b3361367a06d752db6ead6e7edeb0009be66bc3bae0ee9d97fb326badc2a \ + --hash=sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b \ + --hash=sha256:d9fad5155d72433c921b782e58892377c44bd6252b5af2f67f16b194987338a4 \ + --hash=sha256:daa4ee5a243f0f20d528d939d06670a298dd39b1ad5f8a72a4275124a7819eff \ + --hash=sha256:db0b55e0f3cc0be60c1f19efdde9a637c32740486004f20d1cff53c3c0ece4d2 \ + --hash=sha256:e61659ba32cf2cf1481e575d0462554625196a1f2fc06a1c777d3f48e8865d46 \ + --hash=sha256:ea3d8a3d18833cf4304cd2fc9cbb1efe188ca9b5efef2bdac7adc20594a0e46b \ + --hash=sha256:ec6a563cff360b50eed26f13adc43e61bc0c04d94b8be985e6fb24b81f6dcfdf \ + --hash=sha256:f5dfb42c4604dddc8e4305050aa6deb084540643ed5804d7455b5df8fe16f5e5 \ + --hash=sha256:fa173ec60341d6bb97a89f5ea19c85c5643c1e7dedebc22f5181eb73573142c5 \ + --hash=sha256:fa9db3f79de01457b03d4f01b34cf91bc0048eb2c3846ff26f66687c2f6d16ab \ + --hash=sha256:fce659a462a1be54d2ffcacea5e3ba2d74daa74f30f5f143fe0c58636e355fdd \ + --hash=sha256:ffee1f21e5ef0d712f9033568f8344d5da8cc2869dbd08d87c84656e6a2d2f68 # via werkzeug mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba # via markdown-it-py -ml-dtypes==0.3.1 \ - --hash=sha256:3d8ca0acbd377082792d8b97081ba580abdad67c6afb7f827012c675b052f058 \ - --hash=sha256:42a8980afd8b7c8e270e8b5c260237286b5b26acd276fcb758d13cd7cb567e99 \ - --hash=sha256:438437e2e614a3c91d75581653b6c40ec890e8b5994d7190a90c931740151c95 \ - --hash=sha256:4828b62fa3bf1ae35faa40f3db9a38ec72fbce02f328a1d14c3a9da4606af364 \ - --hash=sha256:4d94b2d1bed77284694f7fd0479640fa7aa5d96433dca3cbcec407a5ef752e77 \ - --hash=sha256:510d249a91face47211762eb294d6fe64f325356b965fb6388c1bf51bd339267 \ - --hash=sha256:5727effa7650f7ab10906542d137cfb3244fdc3b2b519beff42f82def8ba59be \ - --hash=sha256:5e0b0b6bb07fa5ad11bb61d174667176bee5e05857225067aabfc5adc1b51d23 \ - --hash=sha256:60778f99194b4c4f36ba42da200b35ef851ce4d4af698aaf70f5b91fe70fc611 \ - --hash=sha256:70984b473db6489ec1d8c79b082a1322105155193049d08a3b0c515094e9777b \ - --hash=sha256:979d7d196d9a17e0135ae22878f74241fbd3522cef58d7b292f1fd5b32282201 \ - --hash=sha256:a777928dcba8865ab4a8157eeb25d23aed7bc82e5fd74e1d5eca821d3f148b39 \ - --hash=sha256:cb0c404e0dd3e25b56362c1c1e5de0ef717f727dde59fa721af4ce6ab2acca44 \ - --hash=sha256:d1a8dc3bac1da2a17d0e2e4cba36ee89721d0bd33ea4765af2eefb5f41409e0f \ - --hash=sha256:da274599e4950a9b488d21571061f49a185537cc77f2d3f8121151d58a9e9f16 \ - --hash=sha256:f83ff080df8910c0f987f615b03e4f8198638e0c00c6e679ea8892dda909763b \ - --hash=sha256:fcae2c69715410d96906e1dfe8f017d9f78a0d10e0df91aae52e91f51fdfe45e - # via jax +ml-dtypes==0.3.2 \ + --hash=sha256:2c34f2ba9660b21fe1034b608308a01be82bbef2a92fb8199f24dc6bad0d5226 \ + --hash=sha256:3a17ef2322e60858d93584e9c52a5be7dd6236b056b7fa1ec57f1bb6ba043e33 \ + --hash=sha256:533059bc5f1764fac071ef54598db358c167c51a718f68f5bb55e3dee79d2967 \ + --hash=sha256:6604877d567a29bfe7cc02969ae0f2425260e5335505cf5e7fefc3e5465f5655 \ + --hash=sha256:6b35c4e8ca957c877ac35c79ffa77724ecc3702a1e4b18b08306c03feae597bb \ + --hash=sha256:763697ab8a88d47443997a7cdf3aac7340049aed45f7521f6b0ec8a0594821fe \ + --hash=sha256:7a4c3fcbf86fa52d0204f07cfd23947ef05b4ad743a1a988e163caa34a201e5e \ + --hash=sha256:7afde548890a92b41c0fed3a6c525f1200a5727205f73dc21181a2726571bb53 \ + --hash=sha256:7ba8e1fafc7fff3e643f453bffa7d082df1678a73286ce8187d3e825e776eb94 \ + --hash=sha256:91f8783fd1f2c23fd3b9ee5ad66b785dafa58ba3cdb050c4458021fa4d1eb226 \ + --hash=sha256:93b78f53431c93953f7850bb1b925a17f0ab5d97527e38a7e865b5b4bc5cfc18 \ + --hash=sha256:961134ea44c7b8ca63eda902a44b58cd8bd670e21d62e255c81fba0a8e70d9b7 \ + --hash=sha256:b89b194e9501a92d289c1ffd411380baf5daafb9818109a4f49b0a1b6dce4462 \ + --hash=sha256:c7b3fb3d4f6b39bcd4f6c4b98f406291f0d681a895490ee29a0f95bab850d53c \ + --hash=sha256:d1a746fe5fb9cd974a91070174258f0be129c592b93f9ce7df6cc336416c3fbd \ + --hash=sha256:e8505946df1665db01332d885c2020b4cb9e84a8b1241eb4ba69d59591f65855 \ + --hash=sha256:f47619d978ab1ae7dfdc4052ea97c636c6263e1f19bd1be0e42c346b98d15ff4 + # via + # jax + # keras-nightly namex==0.0.7 \ --hash=sha256:84ba65bc4d22bd909e3d26bf2ffb4b9529b608cb3f9a4336f776b04204ced69b \ --hash=sha256:8a4f062945f405d77cb66b907f16aa2fd83681945e998be840eb6c4154d40108 # via keras-nightly -numpy==1.26.1 ; python_version >= "3.12" \ - --hash=sha256:06934e1a22c54636a059215d6da99e23286424f316fddd979f5071093b648668 \ - --hash=sha256:1c59c046c31a43310ad0199d6299e59f57a289e22f0f36951ced1c9eac3665b9 \ - --hash=sha256:1d1bd82d539607951cac963388534da3b7ea0e18b149a53cf883d8f699178c0f \ - --hash=sha256:1e11668d6f756ca5ef534b5be8653d16c5352cbb210a5c2a79ff288e937010d5 \ - --hash=sha256:3649d566e2fc067597125428db15d60eb42a4e0897fc48d28cb75dc2e0454e53 \ - --hash=sha256:59227c981d43425ca5e5c01094d59eb14e8772ce6975d4b2fc1e106a833d5ae2 \ - --hash=sha256:6081aed64714a18c72b168a9276095ef9155dd7888b9e74b5987808f0dd0a974 \ - --hash=sha256:6965888d65d2848e8768824ca8288db0a81263c1efccec881cb35a0d805fcd2f \ - --hash=sha256:76ff661a867d9272cd2a99eed002470f46dbe0943a5ffd140f49be84f68ffc42 \ - --hash=sha256:78ca54b2f9daffa5f323f34cdf21e1d9779a54073f0018a3094ab907938331a2 \ - --hash=sha256:82e871307a6331b5f09efda3c22e03c095d957f04bf6bc1804f30048d0e5e7af \ - --hash=sha256:8ab9163ca8aeb7fd32fe93866490654d2f7dda4e61bc6297bf72ce07fdc02f67 \ - --hash=sha256:9696aa2e35cc41e398a6d42d147cf326f8f9d81befcb399bc1ed7ffea339b64e \ - --hash=sha256:97e5d6a9f0702c2863aaabf19f0d1b6c2628fbe476438ce0b5ce06e83085064c \ - --hash=sha256:9f42284ebf91bdf32fafac29d29d4c07e5e9d1af862ea73686581773ef9e73a7 \ - --hash=sha256:a03fb25610ef560a6201ff06df4f8105292ba56e7cdd196ea350d123fc32e24e \ - --hash=sha256:a5b411040beead47a228bde3b2241100454a6abde9df139ed087bd73fc0a4908 \ - --hash=sha256:af22f3d8e228d84d1c0c44c1fbdeb80f97a15a0abe4f080960393a00db733b66 \ - --hash=sha256:afd5ced4e5a96dac6725daeb5242a35494243f2239244fad10a90ce58b071d24 \ - --hash=sha256:b9d45d1dbb9de84894cc50efece5b09939752a2d75aab3a8b0cef6f3a35ecd6b \ - --hash=sha256:bb894accfd16b867d8643fc2ba6c8617c78ba2828051e9a69511644ce86ce83e \ - --hash=sha256:c8c6c72d4a9f831f328efb1312642a1cafafaa88981d9ab76368d50d07d93cbe \ - --hash=sha256:cd7837b2b734ca72959a1caf3309457a318c934abef7a43a14bb984e574bbb9a \ - --hash=sha256:cdd9ec98f0063d93baeb01aad472a1a0840dee302842a2746a7a8e92968f9575 \ - --hash=sha256:d1cfc92db6af1fd37a7bb58e55c8383b4aa1ba23d012bdbba26b4bcca45ac297 \ - --hash=sha256:d1d2c6b7dd618c41e202c59c1413ef9b2c8e8a15f5039e344af64195459e3104 \ - --hash=sha256:d2984cb6caaf05294b8466966627e80bf6c7afd273279077679cb010acb0e5ab \ - --hash=sha256:d58e8c51a7cf43090d124d5073bc29ab2755822181fcad978b12e144e5e5a4b3 \ - --hash=sha256:d78f269e0c4fd365fc2992c00353e4530d274ba68f15e968d8bc3c69ce5f5244 \ - --hash=sha256:dcfaf015b79d1f9f9c9fd0731a907407dc3e45769262d657d754c3a028586124 \ - --hash=sha256:e44ccb93f30c75dfc0c3aa3ce38f33486a75ec9abadabd4e59f114994a9c4617 \ - --hash=sha256:e509cbc488c735b43b5ffea175235cec24bbc57b227ef1acc691725beb230d1c +numpy==1.26.4 ; python_version >= "3.12" \ + --hash=sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b \ + --hash=sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818 \ + --hash=sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20 \ + --hash=sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0 \ + --hash=sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010 \ + --hash=sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a \ + --hash=sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea \ + --hash=sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c \ + --hash=sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71 \ + --hash=sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110 \ + --hash=sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be \ + --hash=sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a \ + --hash=sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a \ + --hash=sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5 \ + --hash=sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed \ + --hash=sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd \ + --hash=sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c \ + --hash=sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e \ + --hash=sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0 \ + --hash=sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c \ + --hash=sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a \ + --hash=sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b \ + --hash=sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0 \ + --hash=sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6 \ + --hash=sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2 \ + --hash=sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a \ + --hash=sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30 \ + --hash=sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218 \ + --hash=sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5 \ + --hash=sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07 \ + --hash=sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2 \ + --hash=sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4 \ + --hash=sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764 \ + --hash=sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef \ + --hash=sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3 \ + --hash=sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f # via # -r requirements.in # h5py @@ -411,10 +408,6 @@ numpy==1.26.1 ; python_version >= "3.12" \ # opt-einsum # scipy # tb-nightly -oauthlib==3.2.2 \ - --hash=sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca \ - --hash=sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918 - # via requests-oauthlib opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 @@ -444,57 +437,36 @@ protobuf==4.23.4 \ --hash=sha256:effeac51ab79332d44fba74660d40ae79985901ac21bca408f8dc335a81aa597 \ --hash=sha256:fee88269a090ada09ca63551bf2f573eb2424035bcf2cb1b121895b01a46594a # via tb-nightly -psutil==5.9.6 \ - --hash=sha256:10e8c17b4f898d64b121149afb136c53ea8b68c7531155147867b7b1ac9e7e28 \ - --hash=sha256:18cd22c5db486f33998f37e2bb054cc62fd06646995285e02a51b1e08da97017 \ - --hash=sha256:3ebf2158c16cc69db777e3c7decb3c0f43a7af94a60d72e87b2823aebac3d602 \ - --hash=sha256:51dc3d54607c73148f63732c727856f5febec1c7c336f8f41fcbd6315cce76ac \ - --hash=sha256:6e5fb8dc711a514da83098bc5234264e551ad980cec5f85dabf4d38ed6f15e9a \ - --hash=sha256:70cb3beb98bc3fd5ac9ac617a327af7e7f826373ee64c80efd4eb2856e5051e9 \ - --hash=sha256:748c9dd2583ed86347ed65d0035f45fa8c851e8d90354c122ab72319b5f366f4 \ - --hash=sha256:91ecd2d9c00db9817a4b4192107cf6954addb5d9d67a969a4f436dbc9200f88c \ - --hash=sha256:92e0cc43c524834af53e9d3369245e6cc3b130e78e26100d1f63cdb0abeb3d3c \ - --hash=sha256:a6f01f03bf1843280f4ad16f4bde26b817847b4c1a0db59bf6419807bc5ce05c \ - --hash=sha256:c69596f9fc2f8acd574a12d5f8b7b1ba3765a641ea5d60fb4736bf3c08a8214a \ - --hash=sha256:ca2780f5e038379e520281e4c032dddd086906ddff9ef0d1b9dcf00710e5071c \ - --hash=sha256:daecbcbd29b289aac14ece28eca6a3e60aa361754cf6da3dfb20d4d32b6c7f57 \ - --hash=sha256:e4b92ddcd7dd4cdd3f900180ea1e104932c7bce234fb88976e2a3b296441225a \ - --hash=sha256:fb8a697f11b0f5994550555fcfe3e69799e5b060c8ecf9e2f75c69302cc35c0d \ - --hash=sha256:ff18b8d1a784b810df0b0fff3bcb50ab941c3b8e2c8de5726f9c71c601c611aa +psutil==5.9.8 \ + --hash=sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d \ + --hash=sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73 \ + --hash=sha256:26bd09967ae00920df88e0352a91cff1a78f8d69b3ecabbfe733610c0af486c8 \ + --hash=sha256:27cc40c3493bb10de1be4b3f07cae4c010ce715290a5be22b98493509c6299e2 \ + --hash=sha256:36f435891adb138ed3c9e58c6af3e2e6ca9ac2f365efe1f9cfef2794e6c93b4e \ + --hash=sha256:50187900d73c1381ba1454cf40308c2bf6f34268518b3f36a9b663ca87e65e36 \ + --hash=sha256:611052c4bc70432ec770d5d54f64206aa7203a101ec273a0cd82418c86503bb7 \ + --hash=sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c \ + --hash=sha256:7d79560ad97af658a0f6adfef8b834b53f64746d45b403f225b85c5c2c140eee \ + --hash=sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421 \ + --hash=sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf \ + --hash=sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81 \ + --hash=sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0 \ + --hash=sha256:bd1184ceb3f87651a67b2708d4c3338e9b10c5df903f2e3776b62303b26cb631 \ + --hash=sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4 \ + --hash=sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8 # via portpicker -pyasn1==0.5.0 \ - --hash=sha256:87a2121042a1ac9358cabcaf1d07680ff97ee6404333bacca15f76aa8ad01a57 \ - --hash=sha256:97b7290ca68e62a832558ec3976f15cbf911bf5d7c7039d8b861c2a0ece69fde - # via - # pyasn1-modules - # rsa -pyasn1-modules==0.3.0 \ - --hash=sha256:5bd01446b736eb9d31512a30d46c1ac3395d676c6f3cafa4c03eb54b9925631c \ - --hash=sha256:d3ccd6ed470d9ffbc716be08bd90efbd44d0734bc9303818f7336070984a162d - # via google-auth -pygments==2.16.1 \ - --hash=sha256:13fc09fa63bc8d8671a6d247e1eb303c4b343eaee81d861f3404db2935653692 \ - --hash=sha256:1daff0494820c69bc8941e407aa20f577374ee88364ee10a98fdbe0aece96e29 +pygments==2.17.2 \ + --hash=sha256:b27c2826c47d0f3219f29554824c30c5e8945175d888647acd804ddd04af846c \ + --hash=sha256:da46cec9fd2de5be3a8a784f434e4c4ab670b4ff54d605c4c2717e9d49c4c367 # via rich requests==2.31.0 \ --hash=sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f \ --hash=sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1 - # via - # -r requirements.in - # requests-oauthlib - # tb-nightly -requests-oauthlib==1.3.1 \ - --hash=sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5 \ - --hash=sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a - # via google-auth-oauthlib -rich==13.6.0 \ - --hash=sha256:2b38e2fe9ca72c9a00170a1a2d20c63c790d0e10ef1fe35eba76e1e7b1d7d245 \ - --hash=sha256:5c14d22737e6d5084ef4771b62d5d4363165b403455a30a1c8ca39dc7b644bef + # via -r requirements.in +rich==13.7.0 \ + --hash=sha256:5cb5123b5cf9ee70584244246816e9114227e0b98ad9176eede6ad54bf5403fa \ + --hash=sha256:6da14c108c4866ee9520bbffa71f6fe3962e193b7da68720583850cd4548e235 # via keras-nightly -rsa==4.9 \ - --hash=sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7 \ - --hash=sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21 - # via google-auth scipy==1.11.3 \ --hash=sha256:00f325434b6424952fbb636506f0567898dca7b0f7654d48f1c382ea338ce9a3 \ --hash=sha256:033c3fd95d55012dd1148b201b72ae854d5086d25e7c316ec9850de4fe776929 \ @@ -530,8 +502,8 @@ six==1.16.0 \ # via # astunparse # tb-nightly -tb-nightly==2.15.0a20231023 \ - --hash=sha256:8990a52985e3296aa18a6825efc017bcded9e2fb2cbcdd2d01f5c62d4fdc9825 +tb-nightly==2.17.0a20240214 \ + --hash=sha256:dbd59bcfd9b028e6199050e36cbb4ee1db4a94e69a7f8c00865f1498112a2b83 # via -r requirements.in tblib==2.0.0 \ --hash=sha256:9100bfa016b047d5b980d66e7efed952fbd20bd85b56110aaf473cb97d18709a \ @@ -546,13 +518,17 @@ termcolor==2.3.0 \ --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \ --hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a # via -r requirements.in +tf-keras-nightly==2.16.0.dev2024021410 \ + --hash=sha256:05a4c19c6a795ec9ed6f4dac69d443601b9206b910524d8b19bad6e362d49f47 \ + --hash=sha256:83d5a1dd3979b42164a2fac4bdd2ed5981ef0c2a0514ca64e4bc67e369caef5c + # via tb-nightly typing-extensions==4.8.0 \ --hash=sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0 \ --hash=sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef # via -r requirements.in -urllib3==2.0.7 \ - --hash=sha256:c97dfde1f7bd43a71c8d2a58e369e9b2bf692d1334ea9f9cae55add7d0dd0f84 \ - --hash=sha256:fdb6d215c776278489906c2f8916e6e7d4f5a9b602ccbcfdf7f016fc8da0596e +urllib3==2.2.0 \ + --hash=sha256:051d961ad0c62a94e50ecf1af379c3aba230c66c710493493560c0c223c49f20 \ + --hash=sha256:ce3711610ddce217e6d113a2732fafad960a03fd0318c91faa79481e35c11224 # via requests werkzeug==3.0.1 \ --hash=sha256:507e811ecea72b18a404947aded4b3390e1db8f826b494d76550ef45bb3b1dcc \ @@ -564,81 +540,77 @@ wheel==0.41.3 \ # via # -r requirements.in # astunparse -wrapt==1.14.1 \ - --hash=sha256:00b6d4ea20a906c0ca56d84f93065b398ab74b927a7a3dbd470f6fc503f95dc3 \ - --hash=sha256:01c205616a89d09827986bc4e859bcabd64f5a0662a7fe95e0d359424e0e071b \ - --hash=sha256:02b41b633c6261feff8ddd8d11c711df6842aba629fdd3da10249a53211a72c4 \ - --hash=sha256:07f7a7d0f388028b2df1d916e94bbb40624c59b48ecc6cbc232546706fac74c2 \ - --hash=sha256:11871514607b15cfeb87c547a49bca19fde402f32e2b1c24a632506c0a756656 \ - --hash=sha256:1b376b3f4896e7930f1f772ac4b064ac12598d1c38d04907e696cc4d794b43d3 \ - --hash=sha256:2020f391008ef874c6d9e208b24f28e31bcb85ccff4f335f15a3251d222b92d9 \ - --hash=sha256:21ac0156c4b089b330b7666db40feee30a5d52634cc4560e1905d6529a3897ff \ - --hash=sha256:240b1686f38ae665d1b15475966fe0472f78e71b1b4903c143a842659c8e4cb9 \ - --hash=sha256:257fd78c513e0fb5cdbe058c27a0624c9884e735bbd131935fd49e9fe719d310 \ - --hash=sha256:26046cd03936ae745a502abf44dac702a5e6880b2b01c29aea8ddf3353b68224 \ - --hash=sha256:2b39d38039a1fdad98c87279b48bc5dce2c0ca0d73483b12cb72aa9609278e8a \ - --hash=sha256:2cf71233a0ed05ccdabe209c606fe0bac7379fdcf687f39b944420d2a09fdb57 \ - --hash=sha256:2fe803deacd09a233e4762a1adcea5db5d31e6be577a43352936179d14d90069 \ - --hash=sha256:2feecf86e1f7a86517cab34ae6c2f081fd2d0dac860cb0c0ded96d799d20b335 \ - --hash=sha256:3232822c7d98d23895ccc443bbdf57c7412c5a65996c30442ebe6ed3df335383 \ - --hash=sha256:34aa51c45f28ba7f12accd624225e2b1e5a3a45206aa191f6f9aac931d9d56fe \ - --hash=sha256:358fe87cc899c6bb0ddc185bf3dbfa4ba646f05b1b0b9b5a27c2cb92c2cea204 \ - --hash=sha256:36f582d0c6bc99d5f39cd3ac2a9062e57f3cf606ade29a0a0d6b323462f4dd87 \ - --hash=sha256:380a85cf89e0e69b7cfbe2ea9f765f004ff419f34194018a6827ac0e3edfed4d \ - --hash=sha256:40e7bc81c9e2b2734ea4bc1aceb8a8f0ceaac7c5299bc5d69e37c44d9081d43b \ - --hash=sha256:43ca3bbbe97af00f49efb06e352eae40434ca9d915906f77def219b88e85d907 \ - --hash=sha256:49ef582b7a1152ae2766557f0550a9fcbf7bbd76f43fbdc94dd3bf07cc7168be \ - --hash=sha256:4fcc4649dc762cddacd193e6b55bc02edca674067f5f98166d7713b193932b7f \ - --hash=sha256:5a0f54ce2c092aaf439813735584b9537cad479575a09892b8352fea5e988dc0 \ - --hash=sha256:5a9a0d155deafd9448baff28c08e150d9b24ff010e899311ddd63c45c2445e28 \ - --hash=sha256:5b02d65b9ccf0ef6c34cba6cf5bf2aab1bb2f49c6090bafeecc9cd81ad4ea1c1 \ - --hash=sha256:60db23fa423575eeb65ea430cee741acb7c26a1365d103f7b0f6ec412b893853 \ - --hash=sha256:642c2e7a804fcf18c222e1060df25fc210b9c58db7c91416fb055897fc27e8cc \ - --hash=sha256:6447e9f3ba72f8e2b985a1da758767698efa72723d5b59accefd716e9e8272bf \ - --hash=sha256:6a9a25751acb379b466ff6be78a315e2b439d4c94c1e99cb7266d40a537995d3 \ - --hash=sha256:6b1a564e6cb69922c7fe3a678b9f9a3c54e72b469875aa8018f18b4d1dd1adf3 \ - --hash=sha256:6d323e1554b3d22cfc03cd3243b5bb815a51f5249fdcbb86fda4bf62bab9e164 \ - --hash=sha256:6e743de5e9c3d1b7185870f480587b75b1cb604832e380d64f9504a0535912d1 \ - --hash=sha256:709fe01086a55cf79d20f741f39325018f4df051ef39fe921b1ebe780a66184c \ - --hash=sha256:7b7c050ae976e286906dd3f26009e117eb000fb2cf3533398c5ad9ccc86867b1 \ - --hash=sha256:7d2872609603cb35ca513d7404a94d6d608fc13211563571117046c9d2bcc3d7 \ - --hash=sha256:7ef58fb89674095bfc57c4069e95d7a31cfdc0939e2a579882ac7d55aadfd2a1 \ - --hash=sha256:80bb5c256f1415f747011dc3604b59bc1f91c6e7150bd7db03b19170ee06b320 \ - --hash=sha256:81b19725065dcb43df02b37e03278c011a09e49757287dca60c5aecdd5a0b8ed \ - --hash=sha256:833b58d5d0b7e5b9832869f039203389ac7cbf01765639c7309fd50ef619e0b1 \ - --hash=sha256:88bd7b6bd70a5b6803c1abf6bca012f7ed963e58c68d76ee20b9d751c74a3248 \ - --hash=sha256:8ad85f7f4e20964db4daadcab70b47ab05c7c1cf2a7c1e51087bfaa83831854c \ - --hash=sha256:8c0ce1e99116d5ab21355d8ebe53d9460366704ea38ae4d9f6933188f327b456 \ - --hash=sha256:8d649d616e5c6a678b26d15ece345354f7c2286acd6db868e65fcc5ff7c24a77 \ - --hash=sha256:903500616422a40a98a5a3c4ff4ed9d0066f3b4c951fa286018ecdf0750194ef \ - --hash=sha256:9736af4641846491aedb3c3f56b9bc5568d92b0692303b5a305301a95dfd38b1 \ - --hash=sha256:988635d122aaf2bdcef9e795435662bcd65b02f4f4c1ae37fbee7401c440b3a7 \ - --hash=sha256:9cca3c2cdadb362116235fdbd411735de4328c61425b0aa9f872fd76d02c4e86 \ - --hash=sha256:9e0fd32e0148dd5dea6af5fee42beb949098564cc23211a88d799e434255a1f4 \ - --hash=sha256:9f3e6f9e05148ff90002b884fbc2a86bd303ae847e472f44ecc06c2cd2fcdb2d \ - --hash=sha256:a85d2b46be66a71bedde836d9e41859879cc54a2a04fad1191eb50c2066f6e9d \ - --hash=sha256:a9008dad07d71f68487c91e96579c8567c98ca4c3881b9b113bc7b33e9fd78b8 \ - --hash=sha256:a9a52172be0b5aae932bef82a79ec0a0ce87288c7d132946d645eba03f0ad8a8 \ - --hash=sha256:aa31fdcc33fef9eb2552cbcbfee7773d5a6792c137b359e82879c101e98584c5 \ - --hash=sha256:acae32e13a4153809db37405f5eba5bac5fbe2e2ba61ab227926a22901051c0a \ - --hash=sha256:b014c23646a467558be7da3d6b9fa409b2c567d2110599b7cf9a0c5992b3b471 \ - --hash=sha256:b21bb4c09ffabfa0e85e3a6b623e19b80e7acd709b9f91452b8297ace2a8ab00 \ - --hash=sha256:b5901a312f4d14c59918c221323068fad0540e34324925c8475263841dbdfe68 \ - --hash=sha256:b9b7a708dd92306328117d8c4b62e2194d00c365f18eff11a9b53c6f923b01e3 \ - --hash=sha256:d1967f46ea8f2db647c786e78d8cc7e4313dbd1b0aca360592d8027b8508e24d \ - --hash=sha256:d52a25136894c63de15a35bc0bdc5adb4b0e173b9c0d07a2be9d3ca64a332735 \ - --hash=sha256:d77c85fedff92cf788face9bfa3ebaa364448ebb1d765302e9af11bf449ca36d \ - --hash=sha256:d79d7d5dc8a32b7093e81e97dad755127ff77bcc899e845f41bf71747af0c569 \ - --hash=sha256:dbcda74c67263139358f4d188ae5faae95c30929281bc6866d00573783c422b7 \ - --hash=sha256:ddaea91abf8b0d13443f6dac52e89051a5063c7d014710dcb4d4abb2ff811a59 \ - --hash=sha256:dee0ce50c6a2dd9056c20db781e9c1cfd33e77d2d569f5d1d9321c641bb903d5 \ - --hash=sha256:dee60e1de1898bde3b238f18340eec6148986da0455d8ba7848d50470a7a32fb \ - --hash=sha256:e2f83e18fe2f4c9e7db597e988f72712c0c3676d337d8b101f6758107c42425b \ - --hash=sha256:e3fb1677c720409d5f671e39bac6c9e0e422584e5f518bfd50aa4cbbea02433f \ - --hash=sha256:ecee4132c6cd2ce5308e21672015ddfed1ff975ad0ac8d27168ea82e71413f55 \ - --hash=sha256:ee2b1b1769f6707a8a445162ea16dddf74285c3964f605877a20e38545c3c462 \ - --hash=sha256:ee6acae74a2b91865910eef5e7de37dc6895ad96fa23603d1d27ea69df545015 \ - --hash=sha256:ef3f72c9666bba2bab70d2a8b79f2c6d2c1a42a7f7e2b0ec83bb2f9e383950af +wrapt==1.16.0 \ + --hash=sha256:0d2691979e93d06a95a26257adb7bfd0c93818e89b1406f5a28f36e0d8c1e1fc \ + --hash=sha256:14d7dc606219cdd7405133c713f2c218d4252f2a469003f8c46bb92d5d095d81 \ + --hash=sha256:1a5db485fe2de4403f13fafdc231b0dbae5eca4359232d2efc79025527375b09 \ + --hash=sha256:1acd723ee2a8826f3d53910255643e33673e1d11db84ce5880675954183ec47e \ + --hash=sha256:1ca9b6085e4f866bd584fb135a041bfc32cab916e69f714a7d1d397f8c4891ca \ + --hash=sha256:1dd50a2696ff89f57bd8847647a1c363b687d3d796dc30d4dd4a9d1689a706f0 \ + --hash=sha256:2076fad65c6736184e77d7d4729b63a6d1ae0b70da4868adeec40989858eb3fb \ + --hash=sha256:2a88e6010048489cda82b1326889ec075a8c856c2e6a256072b28eaee3ccf487 \ + --hash=sha256:3ebf019be5c09d400cf7b024aa52b1f3aeebeff51550d007e92c3c1c4afc2a40 \ + --hash=sha256:418abb18146475c310d7a6dc71143d6f7adec5b004ac9ce08dc7a34e2babdc5c \ + --hash=sha256:43aa59eadec7890d9958748db829df269f0368521ba6dc68cc172d5d03ed8060 \ + --hash=sha256:44a2754372e32ab315734c6c73b24351d06e77ffff6ae27d2ecf14cf3d229202 \ + --hash=sha256:490b0ee15c1a55be9c1bd8609b8cecd60e325f0575fc98f50058eae366e01f41 \ + --hash=sha256:49aac49dc4782cb04f58986e81ea0b4768e4ff197b57324dcbd7699c5dfb40b9 \ + --hash=sha256:5eb404d89131ec9b4f748fa5cfb5346802e5ee8836f57d516576e61f304f3b7b \ + --hash=sha256:5f15814a33e42b04e3de432e573aa557f9f0f56458745c2074952f564c50e664 \ + --hash=sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d \ + --hash=sha256:66027d667efe95cc4fa945af59f92c5a02c6f5bb6012bff9e60542c74c75c362 \ + --hash=sha256:66dfbaa7cfa3eb707bbfcd46dab2bc6207b005cbc9caa2199bcbc81d95071a00 \ + --hash=sha256:685f568fa5e627e93f3b52fda002c7ed2fa1800b50ce51f6ed1d572d8ab3e7fc \ + --hash=sha256:6906c4100a8fcbf2fa735f6059214bb13b97f75b1a61777fcf6432121ef12ef1 \ + --hash=sha256:6a42cd0cfa8ffc1915aef79cb4284f6383d8a3e9dcca70c445dcfdd639d51267 \ + --hash=sha256:6dcfcffe73710be01d90cae08c3e548d90932d37b39ef83969ae135d36ef3956 \ + --hash=sha256:6f6eac2360f2d543cc875a0e5efd413b6cbd483cb3ad7ebf888884a6e0d2e966 \ + --hash=sha256:72554a23c78a8e7aa02abbd699d129eead8b147a23c56e08d08dfc29cfdddca1 \ + --hash=sha256:73870c364c11f03ed072dda68ff7aea6d2a3a5c3fe250d917a429c7432e15228 \ + --hash=sha256:73aa7d98215d39b8455f103de64391cb79dfcad601701a3aa0dddacf74911d72 \ + --hash=sha256:75ea7d0ee2a15733684badb16de6794894ed9c55aa5e9903260922f0482e687d \ + --hash=sha256:7bd2d7ff69a2cac767fbf7a2b206add2e9a210e57947dd7ce03e25d03d2de292 \ + --hash=sha256:807cc8543a477ab7422f1120a217054f958a66ef7314f76dd9e77d3f02cdccd0 \ + --hash=sha256:8e9723528b9f787dc59168369e42ae1c3b0d3fadb2f1a71de14531d321ee05b0 \ + --hash=sha256:9090c9e676d5236a6948330e83cb89969f433b1943a558968f659ead07cb3b36 \ + --hash=sha256:9153ed35fc5e4fa3b2fe97bddaa7cbec0ed22412b85bcdaf54aeba92ea37428c \ + --hash=sha256:9159485323798c8dc530a224bd3ffcf76659319ccc7bbd52e01e73bd0241a0c5 \ + --hash=sha256:941988b89b4fd6b41c3f0bfb20e92bd23746579736b7343283297c4c8cbae68f \ + --hash=sha256:94265b00870aa407bd0cbcfd536f17ecde43b94fb8d228560a1e9d3041462d73 \ + --hash=sha256:98b5e1f498a8ca1858a1cdbffb023bfd954da4e3fa2c0cb5853d40014557248b \ + --hash=sha256:9b201ae332c3637a42f02d1045e1d0cccfdc41f1f2f801dafbaa7e9b4797bfc2 \ + --hash=sha256:a0ea261ce52b5952bf669684a251a66df239ec6d441ccb59ec7afa882265d593 \ + --hash=sha256:a33a747400b94b6d6b8a165e4480264a64a78c8a4c734b62136062e9a248dd39 \ + --hash=sha256:a452f9ca3e3267cd4d0fcf2edd0d035b1934ac2bd7e0e57ac91ad6b95c0c6389 \ + --hash=sha256:a86373cf37cd7764f2201b76496aba58a52e76dedfaa698ef9e9688bfd9e41cf \ + --hash=sha256:ac83a914ebaf589b69f7d0a1277602ff494e21f4c2f743313414378f8f50a4cf \ + --hash=sha256:aefbc4cb0a54f91af643660a0a150ce2c090d3652cf4052a5397fb2de549cd89 \ + --hash=sha256:b3646eefa23daeba62643a58aac816945cadc0afaf21800a1421eeba5f6cfb9c \ + --hash=sha256:b47cfad9e9bbbed2339081f4e346c93ecd7ab504299403320bf85f7f85c7d46c \ + --hash=sha256:b935ae30c6e7400022b50f8d359c03ed233d45b725cfdd299462f41ee5ffba6f \ + --hash=sha256:bb2dee3874a500de01c93d5c71415fcaef1d858370d405824783e7a8ef5db440 \ + --hash=sha256:bc57efac2da352a51cc4658878a68d2b1b67dbe9d33c36cb826ca449d80a8465 \ + --hash=sha256:bf5703fdeb350e36885f2875d853ce13172ae281c56e509f4e6eca049bdfb136 \ + --hash=sha256:c31f72b1b6624c9d863fc095da460802f43a7c6868c5dda140f51da24fd47d7b \ + --hash=sha256:c5cd603b575ebceca7da5a3a251e69561bec509e0b46e4993e1cac402b7247b8 \ + --hash=sha256:d2efee35b4b0a347e0d99d28e884dfd82797852d62fcd7ebdeee26f3ceb72cf3 \ + --hash=sha256:d462f28826f4657968ae51d2181a074dfe03c200d6131690b7d65d55b0f360f8 \ + --hash=sha256:d5e49454f19ef621089e204f862388d29e6e8d8b162efce05208913dde5b9ad6 \ + --hash=sha256:da4813f751142436b075ed7aa012a8778aa43a99f7b36afe9b742d3ed8bdc95e \ + --hash=sha256:db2e408d983b0e61e238cf579c09ef7020560441906ca990fe8412153e3b291f \ + --hash=sha256:db98ad84a55eb09b3c32a96c576476777e87c520a34e2519d3e59c44710c002c \ + --hash=sha256:dbed418ba5c3dce92619656802cc5355cb679e58d0d89b50f116e4a9d5a9603e \ + --hash=sha256:dcdba5c86e368442528f7060039eda390cc4091bfd1dca41e8046af7c910dda8 \ + --hash=sha256:decbfa2f618fa8ed81c95ee18a387ff973143c656ef800c9f24fb7e9c16054e2 \ + --hash=sha256:e4fdb9275308292e880dcbeb12546df7f3e0f96c6b41197e0cf37d2826359020 \ + --hash=sha256:eb1b046be06b0fce7249f1d025cd359b4b80fc1c3e24ad9eca33e0dcdb2e4a35 \ + --hash=sha256:eb6e651000a19c96f452c85132811d25e9264d836951022d6e81df2fff38337d \ + --hash=sha256:ed867c42c268f876097248e05b6117a65bcd1e63b779e916fe2e33cd6fd0d3c3 \ + --hash=sha256:edfad1d29c73f9b863ebe7082ae9321374ccb10879eeabc84ba3b69f2579d537 \ + --hash=sha256:f2058f813d4f2b5e3a9eb2eb3faf8f1d99b81c3e51aeda4b168406443e8ba809 \ + --hash=sha256:f6b2d0c6703c988d334f297aa5df18c45e97b0af3679bb75059e0e0bd8b1069d \ + --hash=sha256:f8212564d49c50eb4565e502814f694e240c55551a5f1bc841d4fcaabb0a9b8a \ + --hash=sha256:ffa565331890b90056c01db69c0fe634a776f8019c143a5ae265f9c6bc4bd6d4 # via -r requirements.in setuptools==68.2.2 \ diff --git a/requirements_lock_3_9.txt b/requirements_lock_3_9.txt index 0fb35480a8f886..9d9e85aceda9c7 100644 --- a/requirements_lock_3_9.txt +++ b/requirements_lock_3_9.txt @@ -1,6 +1,6 @@ -absl-py==2.0.0 \ - --hash=sha256:9a28abb62774ae4e8edbe2dd4c49ffcd45a6a848952a5eccc6a49f3f0fc1e2f3 \ - --hash=sha256:d9690211c5fcfefcdd1a45470ac2b5c5acd45241c3af71eed96bc5441746c0d5 +absl-py==2.1.0 \ + --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ + --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff # via # keras-nightly # tb-nightly @@ -12,105 +12,101 @@ astunparse==1.6.3 \ --hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \ --hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8 # via -r requirements.in -cachetools==5.3.2 \ - --hash=sha256:086ee420196f7b2ab9ca2db2520aca326318b68fe5ba8bc4d49cca91add450f2 \ - --hash=sha256:861f35a13a451f94e301ce2bec7cac63e881232ccce7ed67fab9b5df4d3beaa1 - # via google-auth -certifi==2023.7.22 \ - --hash=sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082 \ - --hash=sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9 +certifi==2024.2.2 \ + --hash=sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f \ + --hash=sha256:dc383c07b76109f368f6106eee2b593b04a011ea4d55f652c6ca24a754d1cdd1 # via requests -charset-normalizer==3.3.1 \ - --hash=sha256:06cf46bdff72f58645434d467bf5228080801298fbba19fe268a01b4534467f5 \ - --hash=sha256:0c8c61fb505c7dad1d251c284e712d4e0372cef3b067f7ddf82a7fa82e1e9a93 \ - --hash=sha256:10b8dd31e10f32410751b3430996f9807fc4d1587ca69772e2aa940a82ab571a \ - --hash=sha256:1171ef1fc5ab4693c5d151ae0fdad7f7349920eabbaca6271f95969fa0756c2d \ - --hash=sha256:17a866d61259c7de1bdadef418a37755050ddb4b922df8b356503234fff7932c \ - --hash=sha256:1d6bfc32a68bc0933819cfdfe45f9abc3cae3877e1d90aac7259d57e6e0f85b1 \ - --hash=sha256:1ec937546cad86d0dce5396748bf392bb7b62a9eeb8c66efac60e947697f0e58 \ - --hash=sha256:223b4d54561c01048f657fa6ce41461d5ad8ff128b9678cfe8b2ecd951e3f8a2 \ - --hash=sha256:2465aa50c9299d615d757c1c888bc6fef384b7c4aec81c05a0172b4400f98557 \ - --hash=sha256:28f512b9a33235545fbbdac6a330a510b63be278a50071a336afc1b78781b147 \ - --hash=sha256:2c092be3885a1b7899cd85ce24acedc1034199d6fca1483fa2c3a35c86e43041 \ - --hash=sha256:2c4c99f98fc3a1835af8179dcc9013f93594d0670e2fa80c83aa36346ee763d2 \ - --hash=sha256:31445f38053476a0c4e6d12b047b08ced81e2c7c712e5a1ad97bc913256f91b2 \ - --hash=sha256:31bbaba7218904d2eabecf4feec0d07469284e952a27400f23b6628439439fa7 \ - --hash=sha256:34d95638ff3613849f473afc33f65c401a89f3b9528d0d213c7037c398a51296 \ - --hash=sha256:352a88c3df0d1fa886562384b86f9a9e27563d4704ee0e9d56ec6fcd270ea690 \ - --hash=sha256:39b70a6f88eebe239fa775190796d55a33cfb6d36b9ffdd37843f7c4c1b5dc67 \ - --hash=sha256:3c66df3f41abee950d6638adc7eac4730a306b022570f71dd0bd6ba53503ab57 \ - --hash=sha256:3f70fd716855cd3b855316b226a1ac8bdb3caf4f7ea96edcccc6f484217c9597 \ - --hash=sha256:3f9bc2ce123637a60ebe819f9fccc614da1bcc05798bbbaf2dd4ec91f3e08846 \ - --hash=sha256:3fb765362688821404ad6cf86772fc54993ec11577cd5a92ac44b4c2ba52155b \ - --hash=sha256:45f053a0ece92c734d874861ffe6e3cc92150e32136dd59ab1fb070575189c97 \ - --hash=sha256:46fb9970aa5eeca547d7aa0de5d4b124a288b42eaefac677bde805013c95725c \ - --hash=sha256:4cb50a0335382aac15c31b61d8531bc9bb657cfd848b1d7158009472189f3d62 \ - --hash=sha256:4e12f8ee80aa35e746230a2af83e81bd6b52daa92a8afaef4fea4a2ce9b9f4fa \ - --hash=sha256:4f3100d86dcd03c03f7e9c3fdb23d92e32abbca07e7c13ebd7ddfbcb06f5991f \ - --hash=sha256:4f6e2a839f83a6a76854d12dbebde50e4b1afa63e27761549d006fa53e9aa80e \ - --hash=sha256:4f861d94c2a450b974b86093c6c027888627b8082f1299dfd5a4bae8e2292821 \ - --hash=sha256:501adc5eb6cd5f40a6f77fbd90e5ab915c8fd6e8c614af2db5561e16c600d6f3 \ - --hash=sha256:520b7a142d2524f999447b3a0cf95115df81c4f33003c51a6ab637cbda9d0bf4 \ - --hash=sha256:548eefad783ed787b38cb6f9a574bd8664468cc76d1538215d510a3cd41406cb \ - --hash=sha256:555fe186da0068d3354cdf4bbcbc609b0ecae4d04c921cc13e209eece7720727 \ - --hash=sha256:55602981b2dbf8184c098bc10287e8c245e351cd4fdcad050bd7199d5a8bf514 \ - --hash=sha256:58e875eb7016fd014c0eea46c6fa92b87b62c0cb31b9feae25cbbe62c919f54d \ - --hash=sha256:5a3580a4fdc4ac05f9e53c57f965e3594b2f99796231380adb2baaab96e22761 \ - --hash=sha256:5b70bab78accbc672f50e878a5b73ca692f45f5b5e25c8066d748c09405e6a55 \ - --hash=sha256:5ceca5876032362ae73b83347be8b5dbd2d1faf3358deb38c9c88776779b2e2f \ - --hash=sha256:61f1e3fb621f5420523abb71f5771a204b33c21d31e7d9d86881b2cffe92c47c \ - --hash=sha256:633968254f8d421e70f91c6ebe71ed0ab140220469cf87a9857e21c16687c034 \ - --hash=sha256:63a6f59e2d01310f754c270e4a257426fe5a591dc487f1983b3bbe793cf6bac6 \ - --hash=sha256:63accd11149c0f9a99e3bc095bbdb5a464862d77a7e309ad5938fbc8721235ae \ - --hash=sha256:6db3cfb9b4fcecb4390db154e75b49578c87a3b9979b40cdf90d7e4b945656e1 \ - --hash=sha256:71ef3b9be10070360f289aea4838c784f8b851be3ba58cf796262b57775c2f14 \ - --hash=sha256:7ae8e5142dcc7a49168f4055255dbcced01dc1714a90a21f87448dc8d90617d1 \ - --hash=sha256:7b6cefa579e1237ce198619b76eaa148b71894fb0d6bcf9024460f9bf30fd228 \ - --hash=sha256:800561453acdecedaac137bf09cd719c7a440b6800ec182f077bb8e7025fb708 \ - --hash=sha256:82ca51ff0fc5b641a2d4e1cc8c5ff108699b7a56d7f3ad6f6da9dbb6f0145b48 \ - --hash=sha256:851cf693fb3aaef71031237cd68699dded198657ec1e76a76eb8be58c03a5d1f \ - --hash=sha256:854cc74367180beb327ab9d00f964f6d91da06450b0855cbbb09187bcdb02de5 \ - --hash=sha256:87071618d3d8ec8b186d53cb6e66955ef2a0e4fa63ccd3709c0c90ac5a43520f \ - --hash=sha256:871d045d6ccc181fd863a3cd66ee8e395523ebfbc57f85f91f035f50cee8e3d4 \ - --hash=sha256:8aee051c89e13565c6bd366813c386939f8e928af93c29fda4af86d25b73d8f8 \ - --hash=sha256:8af5a8917b8af42295e86b64903156b4f110a30dca5f3b5aedea123fbd638bff \ - --hash=sha256:8ec8ef42c6cd5856a7613dcd1eaf21e5573b2185263d87d27c8edcae33b62a61 \ - --hash=sha256:91e43805ccafa0a91831f9cd5443aa34528c0c3f2cc48c4cb3d9a7721053874b \ - --hash=sha256:9505dc359edb6a330efcd2be825fdb73ee3e628d9010597aa1aee5aa63442e97 \ - --hash=sha256:985c7965f62f6f32bf432e2681173db41336a9c2611693247069288bcb0c7f8b \ - --hash=sha256:9a74041ba0bfa9bc9b9bb2cd3238a6ab3b7618e759b41bd15b5f6ad958d17605 \ - --hash=sha256:9edbe6a5bf8b56a4a84533ba2b2f489d0046e755c29616ef8830f9e7d9cf5728 \ - --hash=sha256:a15c1fe6d26e83fd2e5972425a772cca158eae58b05d4a25a4e474c221053e2d \ - --hash=sha256:a66bcdf19c1a523e41b8e9d53d0cedbfbac2e93c649a2e9502cb26c014d0980c \ - --hash=sha256:ae4070f741f8d809075ef697877fd350ecf0b7c5837ed68738607ee0a2c572cf \ - --hash=sha256:ae55d592b02c4349525b6ed8f74c692509e5adffa842e582c0f861751701a673 \ - --hash=sha256:b578cbe580e3b41ad17b1c428f382c814b32a6ce90f2d8e39e2e635d49e498d1 \ - --hash=sha256:b891a2f68e09c5ef989007fac11476ed33c5c9994449a4e2c3386529d703dc8b \ - --hash=sha256:baec8148d6b8bd5cee1ae138ba658c71f5b03e0d69d5907703e3e1df96db5e41 \ - --hash=sha256:bb06098d019766ca16fc915ecaa455c1f1cd594204e7f840cd6258237b5079a8 \ - --hash=sha256:bc791ec3fd0c4309a753f95bb6c749ef0d8ea3aea91f07ee1cf06b7b02118f2f \ - --hash=sha256:bd28b31730f0e982ace8663d108e01199098432a30a4c410d06fe08fdb9e93f4 \ - --hash=sha256:be4d9c2770044a59715eb57c1144dedea7c5d5ae80c68fb9959515037cde2008 \ - --hash=sha256:c0c72d34e7de5604df0fde3644cc079feee5e55464967d10b24b1de268deceb9 \ - --hash=sha256:c0e842112fe3f1a4ffcf64b06dc4c61a88441c2f02f373367f7b4c1aa9be2ad5 \ - --hash=sha256:c15070ebf11b8b7fd1bfff7217e9324963c82dbdf6182ff7050519e350e7ad9f \ - --hash=sha256:c2000c54c395d9e5e44c99dc7c20a64dc371f777faf8bae4919ad3e99ce5253e \ - --hash=sha256:c30187840d36d0ba2893bc3271a36a517a717f9fd383a98e2697ee890a37c273 \ - --hash=sha256:cb7cd68814308aade9d0c93c5bd2ade9f9441666f8ba5aa9c2d4b389cb5e2a45 \ - --hash=sha256:cd805513198304026bd379d1d516afbf6c3c13f4382134a2c526b8b854da1c2e \ - --hash=sha256:d0bf89afcbcf4d1bb2652f6580e5e55a840fdf87384f6063c4a4f0c95e378656 \ - --hash=sha256:d9137a876020661972ca6eec0766d81aef8a5627df628b664b234b73396e727e \ - --hash=sha256:dbd95e300367aa0827496fe75a1766d198d34385a58f97683fe6e07f89ca3e3c \ - --hash=sha256:dced27917823df984fe0c80a5c4ad75cf58df0fbfae890bc08004cd3888922a2 \ - --hash=sha256:de0b4caa1c8a21394e8ce971997614a17648f94e1cd0640fbd6b4d14cab13a72 \ - --hash=sha256:debb633f3f7856f95ad957d9b9c781f8e2c6303ef21724ec94bea2ce2fcbd056 \ - --hash=sha256:e372d7dfd154009142631de2d316adad3cc1c36c32a38b16a4751ba78da2a397 \ - --hash=sha256:ecd26be9f112c4f96718290c10f4caea6cc798459a3a76636b817a0ed7874e42 \ - --hash=sha256:edc0202099ea1d82844316604e17d2b175044f9bcb6b398aab781eba957224bd \ - --hash=sha256:f194cce575e59ffe442c10a360182a986535fd90b57f7debfaa5c845c409ecc3 \ - --hash=sha256:f5fb672c396d826ca16a022ac04c9dce74e00a1c344f6ad1a0fdc1ba1f332213 \ - --hash=sha256:f6a02a3c7950cafaadcd46a226ad9e12fc9744652cc69f9e5534f98b47f3bbcf \ - --hash=sha256:fe81b35c33772e56f4b6cf62cf4aedc1762ef7162a31e6ac7fe5e40d0149eb67 +charset-normalizer==3.3.2 \ + --hash=sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027 \ + --hash=sha256:06a81e93cd441c56a9b65d8e1d043daeb97a3d0856d177d5c90ba85acb3db087 \ + --hash=sha256:0a55554a2fa0d408816b3b5cedf0045f4b8e1a6065aec45849de2d6f3f8e9786 \ + --hash=sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8 \ + --hash=sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09 \ + --hash=sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185 \ + --hash=sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574 \ + --hash=sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e \ + --hash=sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519 \ + --hash=sha256:2127566c664442652f024c837091890cb1942c30937add288223dc895793f898 \ + --hash=sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269 \ + --hash=sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3 \ + --hash=sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f \ + --hash=sha256:3287761bc4ee9e33561a7e058c72ac0938c4f57fe49a09eae428fd88aafe7bb6 \ + --hash=sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8 \ + --hash=sha256:37e55c8e51c236f95b033f6fb391d7d7970ba5fe7ff453dad675e88cf303377a \ + --hash=sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73 \ + --hash=sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc \ + --hash=sha256:42cb296636fcc8b0644486d15c12376cb9fa75443e00fb25de0b8602e64c1714 \ + --hash=sha256:45485e01ff4d3630ec0d9617310448a8702f70e9c01906b0d0118bdf9d124cf2 \ + --hash=sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc \ + --hash=sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce \ + --hash=sha256:4d0d1650369165a14e14e1e47b372cfcb31d6ab44e6e33cb2d4e57265290044d \ + --hash=sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e \ + --hash=sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6 \ + --hash=sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269 \ + --hash=sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96 \ + --hash=sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d \ + --hash=sha256:6463effa3186ea09411d50efc7d85360b38d5f09b870c48e4600f63af490e56a \ + --hash=sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4 \ + --hash=sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77 \ + --hash=sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d \ + --hash=sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0 \ + --hash=sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed \ + --hash=sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068 \ + --hash=sha256:6c4caeef8fa63d06bd437cd4bdcf3ffefe6738fb1b25951440d80dc7df8c03ac \ + --hash=sha256:6ef1d82a3af9d3eecdba2321dc1b3c238245d890843e040e41e470ffa64c3e25 \ + --hash=sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8 \ + --hash=sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab \ + --hash=sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26 \ + --hash=sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2 \ + --hash=sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db \ + --hash=sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f \ + --hash=sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5 \ + --hash=sha256:86216b5cee4b06df986d214f664305142d9c76df9b6512be2738aa72a2048f99 \ + --hash=sha256:87d1351268731db79e0f8e745d92493ee2841c974128ef629dc518b937d9194c \ + --hash=sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d \ + --hash=sha256:8c622a5fe39a48f78944a87d4fb8a53ee07344641b0562c540d840748571b811 \ + --hash=sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa \ + --hash=sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a \ + --hash=sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03 \ + --hash=sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b \ + --hash=sha256:923c0c831b7cfcb071580d3f46c4baf50f174be571576556269530f4bbd79d04 \ + --hash=sha256:95f2a5796329323b8f0512e09dbb7a1860c46a39da62ecb2324f116fa8fdc85c \ + --hash=sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001 \ + --hash=sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458 \ + --hash=sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389 \ + --hash=sha256:a50aebfa173e157099939b17f18600f72f84eed3049e743b68ad15bd69b6bf99 \ + --hash=sha256:a981a536974bbc7a512cf44ed14938cf01030a99e9b3a06dd59578882f06f985 \ + --hash=sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537 \ + --hash=sha256:ae5f4161f18c61806f411a13b0310bea87f987c7d2ecdbdaad0e94eb2e404238 \ + --hash=sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f \ + --hash=sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d \ + --hash=sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796 \ + --hash=sha256:b2b0a0c0517616b6869869f8c581d4eb2dd83a4d79e0ebcb7d373ef9956aeb0a \ + --hash=sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143 \ + --hash=sha256:bd8f7df7d12c2db9fab40bdd87a7c09b1530128315d047a086fa3ae3435cb3a8 \ + --hash=sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c \ + --hash=sha256:c002b4ffc0be611f0d9da932eb0f704fe2602a9a949d1f738e4c34c75b0863d5 \ + --hash=sha256:c083af607d2515612056a31f0a8d9e0fcb5876b7bfc0abad3ecd275bc4ebc2d5 \ + --hash=sha256:c180f51afb394e165eafe4ac2936a14bee3eb10debc9d9e4db8958fe36afe711 \ + --hash=sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4 \ + --hash=sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6 \ + --hash=sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c \ + --hash=sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7 \ + --hash=sha256:db364eca23f876da6f9e16c9da0df51aa4f104a972735574842618b8c6d999d4 \ + --hash=sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b \ + --hash=sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae \ + --hash=sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12 \ + --hash=sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c \ + --hash=sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae \ + --hash=sha256:eb00ed941194665c332bf8e078baf037d6c35d7c4f3102ea2d4f16ca94a26dc8 \ + --hash=sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887 \ + --hash=sha256:eb8821e09e916165e160797a6c17edda0679379a4be5c716c260e836e122f54b \ + --hash=sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4 \ + --hash=sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f \ + --hash=sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5 \ + --hash=sha256:fb69256e180cb6c8a894fee62b3afebae785babc1ee98b81cdf68bbca1987f33 \ + --hash=sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519 \ + --hash=sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561 # via requests dill==0.3.7 \ --hash=sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e \ @@ -133,6 +129,7 @@ dm-tree==0.1.8 \ --hash=sha256:35cc164a79336bfcfafb47e5f297898359123bbd3330c1967f0c4994f9cf9f60 \ --hash=sha256:378cc8ad93c5fe3590f405a309980721f021c790ca1bdf9b15bb1d59daec57f5 \ --hash=sha256:39070ba268c0491af9fe7a58644d99e8b4f2cde6e5884ba3380bddc84ed43d5f \ + --hash=sha256:435227cf3c5dc63f4de054cf3d00183790bd9ead4c3623138c74dde7f67f521b \ --hash=sha256:5483dca4d7eb1a0d65fe86d3b6a53ae717face83c1f17e0887b1a4a64ae5c410 \ --hash=sha256:694c3654cfd2a81552c08ec66bb5c4a3d48fa292b9a181880fb081c36c5b9134 \ --hash=sha256:803bfc53b4659f447ac694dbd04235f94a73ef7c1fd1e0df7c84ac41e0bc963b \ @@ -140,6 +137,8 @@ dm-tree==0.1.8 \ --hash=sha256:83b7764de0d855338abefc6e3ee9fe40d301668310aa3baea3f778ff051f4393 \ --hash=sha256:8c60a7eadab64c2278861f56bca320b2720f163dca9d7558103c3b77f2416571 \ --hash=sha256:8ed3564abed97c806db122c2d3e1a2b64c74a63debe9903aad795167cc301368 \ + --hash=sha256:94d3f0826311f45ee19b75f5b48c99466e4218a0489e81c0f0167bda50cacf22 \ + --hash=sha256:96a548a406a6fb15fe58f6a30a57ff2f2aafbf25f05afab00c8f5e5977b6c715 \ --hash=sha256:a5d819c38c03f0bb5b3b3703c60e4b170355a0fc6b5819325bf3d4ceb3ae7e80 \ --hash=sha256:ad16ceba90a56ec47cf45b21856d14962ac314787975ef786efb5e6e9ca75ec7 \ --hash=sha256:af4b3d372f2477dcd89a6e717e4a575ca35ccc20cc4454a8a4b6f8838a00672d \ @@ -147,6 +146,7 @@ dm-tree==0.1.8 \ --hash=sha256:b9bd9b9ccb59409d33d51d84b7668010c04c2af7d4a371632874c1ca356cff3d \ --hash=sha256:b9f89a454e98806b44fe9d40ec9eee61f848388f7e79ac2371a55679bd5a3ac6 \ --hash=sha256:bb2d109f42190225112da899b9f3d46d0d5f26aef501c61e43529fe9322530b5 \ + --hash=sha256:c0a94aba18a35457a1b5cd716fd7b46c5dafdc4cf7869b4bae665b91c4682a8e \ --hash=sha256:c5c8c12e3fda754ef6af94161bacdaeda816d941995fac415d6855c6c386af68 \ --hash=sha256:d1612fcaecd79023dbc6a6ae48d51a80beb5c385d6f3f6d71688e57bc8d07de8 \ --hash=sha256:d16e1f2a073604cfcc09f7131ae8d534674f43c3aef4c25742eae295bc60d04f \ @@ -154,6 +154,7 @@ dm-tree==0.1.8 \ --hash=sha256:d40fa4106ca6edc66760246a08f500ec0c85ef55c762fb4a363f6ee739ba02ee \ --hash=sha256:de287fabc464b8734be251e46e06aa9aa1001f34198da2b6ce07bd197172b9cb \ --hash=sha256:e4d714371bb08839e4e5e29024fc95832d9affe129825ef38836b143028bd144 \ + --hash=sha256:ea9e59e0451e7d29aece402d9f908f2e2a80922bcde2ebfd5dcb07750fcbfee8 \ --hash=sha256:f7ac31b9aecccb2c6e1ab29706f6ded3eba0c2c69c770322c9c685929c3d6afb \ --hash=sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d # via keras-nightly @@ -161,71 +162,61 @@ gast==0.4.0 \ --hash=sha256:40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1 \ --hash=sha256:b7adcdd5adbebf1adf17378da5ba3f543684dbec47b1cda1f3997e573cd542c4 # via -r requirements.in -google-auth==2.23.3 \ - --hash=sha256:6864247895eea5d13b9c57c9e03abb49cb94ce2dc7c58e91cba3248c7477c9e3 \ - --hash=sha256:a8f4608e65c244ead9e0538f181a96c6e11199ec114d41f1d7b1bffa96937bda - # via - # google-auth-oauthlib - # tb-nightly -google-auth-oauthlib==1.1.0 \ - --hash=sha256:089c6e587d36f4803ac7e0720c045c6a8b1fd1790088b8424975b90d0ee61c12 \ - --hash=sha256:83ea8c3b0881e453790baff4448e8a6112ac8778d1de9da0b68010b843937afb - # via tb-nightly -grpcio==1.59.2 \ - --hash=sha256:023088764012411affe7db183d1ada3ad9daf2e23ddc719ff46d7061de661340 \ - --hash=sha256:08d77e682f2bf730a4961eea330e56d2f423c6a9b91ca222e5b1eb24a357b19f \ - --hash=sha256:0a4a3833c0e067f3558538727235cd8a49709bff1003200bbdefa2f09334e4b1 \ - --hash=sha256:0a754aff9e3af63bdc4c75c234b86b9d14e14a28a30c4e324aed1a9b873d755f \ - --hash=sha256:11168ef43e4a43ff1b1a65859f3e0ef1a173e277349e7fb16923ff108160a8cd \ - --hash=sha256:128e20f57c5f27cb0157e73756d1586b83c1b513ebecc83ea0ac37e4b0e4e758 \ - --hash=sha256:1f9524d1d701e399462d2c90ba7c193e49d1711cf429c0d3d97c966856e03d00 \ - --hash=sha256:1ff16d68bf453275466a9a46739061a63584d92f18a0f5b33d19fc97eb69867c \ - --hash=sha256:2067274c88bc6de89c278a672a652b4247d088811ece781a4858b09bdf8448e3 \ - --hash=sha256:2171c39f355ba5b551c5d5928d65aa6c69807fae195b86ef4a7d125bcdb860a9 \ - --hash=sha256:242adc47725b9a499ee77c6a2e36688fa6c96484611f33b1be4c57ab075a92dd \ - --hash=sha256:27f879ae604a7fcf371e59fba6f3ff4635a4c2a64768bd83ff0cac503142fef4 \ - --hash=sha256:2b230028a008ae1d0f430acb227d323ff8a619017415cf334c38b457f814119f \ - --hash=sha256:3059668df17627f0e0fa680e9ef8c995c946c792612e9518f5cc1503be14e90b \ - --hash=sha256:31176aa88f36020055ace9adff2405a33c8bdbfa72a9c4980e25d91b2f196873 \ - --hash=sha256:36f53c2b3449c015880e7d55a89c992c357f176327b0d2873cdaaf9628a37c69 \ - --hash=sha256:3b4368b33908f683a363f376dfb747d40af3463a6e5044afee07cf9436addf96 \ - --hash=sha256:3c61d641d4f409c5ae46bfdd89ea42ce5ea233dcf69e74ce9ba32b503c727e29 \ - --hash=sha256:4abb717e320e74959517dc8e84a9f48fbe90e9abe19c248541e9418b1ce60acd \ - --hash=sha256:4c93f4abbb54321ee6471e04a00139c80c754eda51064187963ddf98f5cf36a4 \ - --hash=sha256:535561990e075fa6bd4b16c4c3c1096b9581b7bb35d96fac4650f1181e428268 \ - --hash=sha256:53c9aa5ddd6857c0a1cd0287225a2a25873a8e09727c2e95c4aebb1be83a766a \ - --hash=sha256:5d573e70a6fe77555fb6143c12d3a7d3fa306632a3034b4e7c59ca09721546f8 \ - --hash=sha256:6009386a2df66159f64ac9f20425ae25229b29b9dd0e1d3dd60043f037e2ad7e \ - --hash=sha256:686e975a5d16602dc0982c7c703948d17184bd1397e16c8ee03511ecb8c4cdda \ - --hash=sha256:6959fb07e8351e20501ffb8cc4074c39a0b7ef123e1c850a7f8f3afdc3a3da01 \ - --hash=sha256:6b25ed37c27e652db01be341af93fbcea03d296c024d8a0e680017a268eb85dd \ - --hash=sha256:6da6dea3a1bacf99b3c2187e296db9a83029ed9c38fd4c52b7c9b7326d13c828 \ - --hash=sha256:72ca2399097c0b758198f2ff30f7178d680de8a5cfcf3d9b73a63cf87455532e \ - --hash=sha256:73abb8584b0cf74d37f5ef61c10722adc7275502ab71789a8fe3cb7ef04cf6e2 \ - --hash=sha256:74100fecaec8a535e380cf5f2fb556ff84957d481c13e54051c52e5baac70541 \ - --hash=sha256:75c6ecb70e809cf1504465174343113f51f24bc61e22a80ae1c859f3f7034c6d \ - --hash=sha256:7cf05053242f61ba94014dd3a986e11a083400a32664058f80bf4cf817c0b3a1 \ - --hash=sha256:9411e24328a2302e279e70cae6e479f1fddde79629fcb14e03e6d94b3956eabf \ - --hash=sha256:a213acfbf186b9f35803b52e4ca9addb153fc0b67f82a48f961be7000ecf6721 \ - --hash=sha256:bb7e0fe6ad73b7f06d7e2b689c19a71cf5cc48f0c2bf8608469e51ffe0bd2867 \ - --hash=sha256:c2504eed520958a5b77cc99458297cb7906308cb92327f35fb7fbbad4e9b2188 \ - --hash=sha256:c35aa9657f5d5116d23b934568e0956bd50c615127810fffe3ac356a914c176a \ - --hash=sha256:c5f09cffa619adfb44799fa4a81c2a1ad77c887187613fb0a8f201ab38d89ba1 \ - --hash=sha256:c978f864b35f2261e0819f5cd88b9830b04dc51bcf055aac3c601e525a10d2ba \ - --hash=sha256:cbe946b3e6e60a7b4618f091e62a029cb082b109a9d6b53962dd305087c6e4fd \ - --hash=sha256:cc3e4cd087f07758b16bef8f31d88dbb1b5da5671d2f03685ab52dece3d7a16e \ - --hash=sha256:cf0dead5a2c5a3347af2cfec7131d4f2a2e03c934af28989c9078f8241a491fa \ - --hash=sha256:d2794f0e68b3085d99b4f6ff9c089f6fdd02b32b9d3efdfbb55beac1bf22d516 \ - --hash=sha256:d2fa68a96a30dd240be80bbad838a0ac81a61770611ff7952b889485970c4c71 \ - --hash=sha256:d6f70406695e3220f09cd7a2f879333279d91aa4a8a1d34303b56d61a8180137 \ - --hash=sha256:d8f9cd4ad1be90b0cf350a2f04a38a36e44a026cac1e036ac593dc48efe91d52 \ - --hash=sha256:da2d94c15f88cd40d7e67f7919d4f60110d2b9d5b1e08cf354c2be773ab13479 \ - --hash=sha256:e1727c1c0e394096bb9af185c6923e8ea55a5095b8af44f06903bcc0e06800a2 \ - --hash=sha256:e420ced29b5904cdf9ee5545e23f9406189d8acb6750916c2db4793dada065c6 \ - --hash=sha256:e82c5cf1495244adf5252f925ac5932e5fd288b3e5ab6b70bec5593074b7236c \ - --hash=sha256:f1ef0d39bc1feb420caf549b3c657c871cad4ebbcf0580c4d03816b0590de0cf \ - --hash=sha256:f8753a6c88d1d0ba64302309eecf20f70d2770f65ca02d83c2452279085bfcd3 \ - --hash=sha256:f93dbf58f03146164048be5426ffde298b237a5e059144847e4940f5b80172c3 +grpcio==1.60.1 \ + --hash=sha256:0250a7a70b14000fa311de04b169cc7480be6c1a769b190769d347939d3232a8 \ + --hash=sha256:069fe2aeee02dfd2135d562d0663fe70fbb69d5eed6eb3389042a7e963b54de8 \ + --hash=sha256:082081e6a36b6eb5cf0fd9a897fe777dbb3802176ffd08e3ec6567edd85bc104 \ + --hash=sha256:0c5807e9152eff15f1d48f6b9ad3749196f79a4a050469d99eecb679be592acc \ + --hash=sha256:14e8f2c84c0832773fb3958240c69def72357bc11392571f87b2d7b91e0bb092 \ + --hash=sha256:2a6087f234cb570008a6041c8ffd1b7d657b397fdd6d26e83d72283dae3527b1 \ + --hash=sha256:2bb2a2911b028f01c8c64d126f6b632fcd8a9ac975aa1b3855766c94e4107180 \ + --hash=sha256:2f44c32aef186bbba254129cea1df08a20be414144ac3bdf0e84b24e3f3b2e05 \ + --hash=sha256:30e980cd6db1088c144b92fe376747328d5554bc7960ce583ec7b7d81cd47287 \ + --hash=sha256:33aed0a431f5befeffd9d346b0fa44b2c01aa4aeae5ea5b2c03d3e25e0071216 \ + --hash=sha256:33bdea30dcfd4f87b045d404388469eb48a48c33a6195a043d116ed1b9a0196c \ + --hash=sha256:39aa848794b887120b1d35b1b994e445cc028ff602ef267f87c38122c1add50d \ + --hash=sha256:4216e67ad9a4769117433814956031cb300f85edc855252a645a9a724b3b6594 \ + --hash=sha256:49c9b6a510e3ed8df5f6f4f3c34d7fbf2d2cae048ee90a45cd7415abab72912c \ + --hash=sha256:4eec8b8c1c2c9b7125508ff7c89d5701bf933c99d3910e446ed531cd16ad5d87 \ + --hash=sha256:50d56280b482875d1f9128ce596e59031a226a8b84bec88cb2bf76c289f5d0de \ + --hash=sha256:53b69e79d00f78c81eecfb38f4516080dc7f36a198b6b37b928f1c13b3c063e9 \ + --hash=sha256:55ccb7db5a665079d68b5c7c86359ebd5ebf31a19bc1a91c982fd622f1e31ff2 \ + --hash=sha256:5a1ebbae7e2214f51b1f23b57bf98eeed2cf1ba84e4d523c48c36d5b2f8829ff \ + --hash=sha256:61b7199cd2a55e62e45bfb629a35b71fc2c0cb88f686a047f25b1112d3810904 \ + --hash=sha256:660fc6b9c2a9ea3bb2a7e64ba878c98339abaf1811edca904ac85e9e662f1d73 \ + --hash=sha256:6d140bdeb26cad8b93c1455fa00573c05592793c32053d6e0016ce05ba267549 \ + --hash=sha256:6e490fa5f7f5326222cb9f0b78f207a2b218a14edf39602e083d5f617354306f \ + --hash=sha256:6ecf21d20d02d1733e9c820fb5c114c749d888704a7ec824b545c12e78734d1c \ + --hash=sha256:70c83bb530572917be20c21f3b6be92cd86b9aecb44b0c18b1d3b2cc3ae47df0 \ + --hash=sha256:72153a0d2e425f45b884540a61c6639436ddafa1829a42056aa5764b84108b8e \ + --hash=sha256:73e14acd3d4247169955fae8fb103a2b900cfad21d0c35f0dcd0fdd54cd60367 \ + --hash=sha256:76eaaba891083fcbe167aa0f03363311a9f12da975b025d30e94b93ac7a765fc \ + --hash=sha256:79ae0dc785504cb1e1788758c588c711f4e4a0195d70dff53db203c95a0bd303 \ + --hash=sha256:7d142bcd604166417929b071cd396aa13c565749a4c840d6c702727a59d835eb \ + --hash=sha256:8c9554ca8e26241dabe7951aa1fa03a1ba0856688ecd7e7bdbdd286ebc272e4c \ + --hash=sha256:8d488fbdbf04283f0d20742b64968d44825617aa6717b07c006168ed16488804 \ + --hash=sha256:91422ba785a8e7a18725b1dc40fbd88f08a5bb4c7f1b3e8739cab24b04fa8a03 \ + --hash=sha256:9a66f4d2a005bc78e61d805ed95dedfcb35efa84b7bba0403c6d60d13a3de2d6 \ + --hash=sha256:9b106bc52e7f28170e624ba61cc7dc6829566e535a6ec68528f8e1afbed1c41f \ + --hash=sha256:9b54577032d4f235452f77a83169b6527bf4b77d73aeada97d45b2aaf1bf5ce0 \ + --hash=sha256:a09506eb48fa5493c58f946c46754ef22f3ec0df64f2b5149373ff31fb67f3dd \ + --hash=sha256:a212e5dea1a4182e40cd3e4067ee46be9d10418092ce3627475e995cca95de21 \ + --hash=sha256:a731ac5cffc34dac62053e0da90f0c0b8560396a19f69d9703e88240c8f05858 \ + --hash=sha256:af5ef6cfaf0d023c00002ba25d0751e5995fa0e4c9eec6cd263c30352662cbce \ + --hash=sha256:b58b855d0071575ea9c7bc0d84a06d2edfbfccec52e9657864386381a7ce1ae9 \ + --hash=sha256:bc808924470643b82b14fe121923c30ec211d8c693e747eba8a7414bc4351a23 \ + --hash=sha256:c557e94e91a983e5b1e9c60076a8fd79fea1e7e06848eb2e48d0ccfb30f6e073 \ + --hash=sha256:c71be3f86d67d8d1311c6076a4ba3b75ba5703c0b856b4e691c9097f9b1e8bd2 \ + --hash=sha256:c8754c75f55781515a3005063d9a05878b2cfb3cb7e41d5401ad0cf19de14872 \ + --hash=sha256:cb0af13433dbbd1c806e671d81ec75bd324af6ef75171fd7815ca3074fe32bfe \ + --hash=sha256:cba6209c96828711cb7c8fcb45ecef8c8859238baf15119daa1bef0f6c84bfe7 \ + --hash=sha256:cf77f8cf2a651fbd869fbdcb4a1931464189cd210abc4cfad357f1cacc8642a6 \ + --hash=sha256:d7404cebcdb11bb5bd40bf94131faf7e9a7c10a6c60358580fe83913f360f929 \ + --hash=sha256:dd1d3a8d1d2e50ad9b59e10aa7f07c7d1be2b367f3f2d33c5fade96ed5460962 \ + --hash=sha256:e5d97c65ea7e097056f3d1ead77040ebc236feaf7f71489383d20f3b4c28412a \ + --hash=sha256:f1c3dc536b3ee124e8b24feb7533e5c70b9f2ef833e3b2e5513b2897fd46763a \ + --hash=sha256:f2212796593ad1d0235068c79836861f2201fc7137a99aa2fea7beeb3b101177 \ + --hash=sha256:fead980fbc68512dfd4e0c7b1f5754c2a8e5015a04dea454b9cada54a8423525 # via # -r requirements.in # tb-nightly @@ -258,117 +249,119 @@ h5py==3.10.0 \ # via # -r requirements.in # keras-nightly -idna==3.4 \ - --hash=sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4 \ - --hash=sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2 +idna==3.6 \ + --hash=sha256:9ecdbbd083b06798ae1e86adcbfe8ab1479cf864e4ee30fe4e46a003d12491ca \ + --hash=sha256:c05567e9c24a6b9faaa835c4821bad0590fbb9d5779e7caa6e1cc4978e7eb24f # via requests -importlib-metadata==6.8.0 \ - --hash=sha256:3ebb78df84a805d7698245025b975d9d67053cd94c79245ba4b3eb694abe68bb \ - --hash=sha256:dbace7892d8c0c4ac1ad096662232f831d4e64f4c4545bd53016a3e9d4654743 +importlib-metadata==7.0.1 \ + --hash=sha256:4805911c3a4ec7c3966410053e9ec6a1fecd629117df5adee56dfc9432a1081e \ + --hash=sha256:f238736bb06590ae52ac1fab06a3a9ef1d8dce2b7a35b5ab329371d6c8f5d2cc # via markdown jax==0.4.7 \ --hash=sha256:5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8 # via -r requirements.in -keras-nightly==3.0.0.dev2023103103 \ - --hash=sha256:25a6a9030d7e067d535ec41ae6700636c90cabada80fbddb3bb8775006807860 \ - --hash=sha256:c39fccebb3d4cc9d838371981c4d3eef08fc06eadf6cffb39b356cb625bab50f +keras-nightly==3.0.4.dev2024021403 \ + --hash=sha256:24ce69d29d582771685bf4235f59663723405b5a5b16f3eaff2657e52e74663a \ + --hash=sha256:9f416e66b820ef833779d219d255b346b8b90a72fdbd0b2f1e90a43ad142a03d # via -r requirements.in -lit==17.0.4 \ - --hash=sha256:ee2e180128e770abc6aed3a02de2daf09d81b7d30225e315205d3599c311d304 +lit==17.0.6 \ + --hash=sha256:dfa9af9b55fc4509a56be7bf2346f079d7f4a242d583b9f2e0b078fd0abae31b # via -r requirements.in -markdown==3.5 \ - --hash=sha256:4afb124395ce5fc34e6d9886dab977fd9ae987fc6e85689f08278cf0c69d4bf3 \ - --hash=sha256:a807eb2e4778d9156c8f07876c6e4d50b5494c5665c4834f67b06459dfd877b3 +markdown==3.5.2 \ + --hash=sha256:d43323865d89fc0cb9b20c75fc8ad313af307cc087e84b657d9eec768eddeadd \ + --hash=sha256:e1ac7b3dc550ee80e602e71c1d168002f062e49f1b11e26a36264dafd4df2ef8 # via tb-nightly markdown-it-py==3.0.0 \ --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb # via rich -markupsafe==2.1.3 \ - --hash=sha256:05fb21170423db021895e1ea1e1f3ab3adb85d1c2333cbc2310f2a26bc77272e \ - --hash=sha256:0a4e4a1aff6c7ac4cd55792abf96c915634c2b97e3cc1c7129578aa68ebd754e \ - --hash=sha256:10bbfe99883db80bdbaff2dcf681dfc6533a614f700da1287707e8a5d78a8431 \ - --hash=sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686 \ - --hash=sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c \ - --hash=sha256:1577735524cdad32f9f694208aa75e422adba74f1baee7551620e43a3141f559 \ - --hash=sha256:1b40069d487e7edb2676d3fbdb2b0829ffa2cd63a2ec26c4938b2d34391b4ecc \ - --hash=sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb \ - --hash=sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939 \ - --hash=sha256:282c2cb35b5b673bbcadb33a585408104df04f14b2d9b01d4c345a3b92861c2c \ - --hash=sha256:2c1b19b3aaacc6e57b7e25710ff571c24d6c3613a45e905b1fde04d691b98ee0 \ - --hash=sha256:2ef12179d3a291be237280175b542c07a36e7f60718296278d8593d21ca937d4 \ - --hash=sha256:338ae27d6b8745585f87218a3f23f1512dbf52c26c28e322dbe54bcede54ccb9 \ - --hash=sha256:3c0fae6c3be832a0a0473ac912810b2877c8cb9d76ca48de1ed31e1c68386575 \ - --hash=sha256:3fd4abcb888d15a94f32b75d8fd18ee162ca0c064f35b11134be77050296d6ba \ - --hash=sha256:42de32b22b6b804f42c5d98be4f7e5e977ecdd9ee9b660fda1a3edf03b11792d \ - --hash=sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd \ - --hash=sha256:504b320cd4b7eff6f968eddf81127112db685e81f7e36e75f9f84f0df46041c3 \ - --hash=sha256:525808b8019e36eb524b8c68acdd63a37e75714eac50e988180b169d64480a00 \ - --hash=sha256:56d9f2ecac662ca1611d183feb03a3fa4406469dafe241673d521dd5ae92a155 \ - --hash=sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac \ - --hash=sha256:65c1a9bcdadc6c28eecee2c119465aebff8f7a584dd719facdd9e825ec61ab52 \ - --hash=sha256:68e78619a61ecf91e76aa3e6e8e33fc4894a2bebe93410754bd28fce0a8a4f9f \ - --hash=sha256:69c0f17e9f5a7afdf2cc9fb2d1ce6aabdb3bafb7f38017c0b77862bcec2bbad8 \ - --hash=sha256:6b2b56950d93e41f33b4223ead100ea0fe11f8e6ee5f641eb753ce4b77a7042b \ - --hash=sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007 \ - --hash=sha256:787003c0ddb00500e49a10f2844fac87aa6ce977b90b0feaaf9de23c22508b24 \ - --hash=sha256:7ef3cb2ebbf91e330e3bb937efada0edd9003683db6b57bb108c4001f37a02ea \ - --hash=sha256:8023faf4e01efadfa183e863fefde0046de576c6f14659e8782065bcece22198 \ - --hash=sha256:8758846a7e80910096950b67071243da3e5a20ed2546e6392603c096778d48e0 \ - --hash=sha256:8afafd99945ead6e075b973fefa56379c5b5c53fd8937dad92c662da5d8fd5ee \ - --hash=sha256:8c41976a29d078bb235fea9b2ecd3da465df42a562910f9022f1a03107bd02be \ - --hash=sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2 \ - --hash=sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1 \ - --hash=sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707 \ - --hash=sha256:962f82a3086483f5e5f64dbad880d31038b698494799b097bc59c2edf392fce6 \ - --hash=sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c \ - --hash=sha256:9dcdfd0eaf283af041973bff14a2e143b8bd64e069f4c383416ecd79a81aab58 \ - --hash=sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823 \ - --hash=sha256:aa7bd130efab1c280bed0f45501b7c8795f9fdbeb02e965371bbef3523627779 \ - --hash=sha256:ab4a0df41e7c16a1392727727e7998a467472d0ad65f3ad5e6e765015df08636 \ - --hash=sha256:ad9e82fb8f09ade1c3e1b996a6337afac2b8b9e365f926f5a61aacc71adc5b3c \ - --hash=sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad \ - --hash=sha256:b076b6226fb84157e3f7c971a47ff3a679d837cf338547532ab866c57930dbee \ - --hash=sha256:b7ff0f54cb4ff66dd38bebd335a38e2c22c41a8ee45aa608efc890ac3e3931bc \ - --hash=sha256:bfce63a9e7834b12b87c64d6b155fdd9b3b96191b6bd334bf37db7ff1fe457f2 \ - --hash=sha256:c011a4149cfbcf9f03994ec2edffcb8b1dc2d2aede7ca243746df97a5d41ce48 \ - --hash=sha256:c9c804664ebe8f83a211cace637506669e7890fec1b4195b505c214e50dd4eb7 \ - --hash=sha256:ca379055a47383d02a5400cb0d110cef0a776fc644cda797db0c5696cfd7e18e \ - --hash=sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b \ - --hash=sha256:cd0f502fe016460680cd20aaa5a76d241d6f35a1c3350c474bac1273803893fa \ - --hash=sha256:ceb01949af7121f9fc39f7d27f91be8546f3fb112c608bc4029aef0bab86a2a5 \ - --hash=sha256:d080e0a5eb2529460b30190fcfcc4199bd7f827663f858a226a81bc27beaa97e \ - --hash=sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb \ - --hash=sha256:df0be2b576a7abbf737b1575f048c23fb1d769f267ec4358296f31c2479db8f9 \ - --hash=sha256:e09031c87a1e51556fdcb46e5bd4f59dfb743061cf93c4d6831bf894f125eb57 \ - --hash=sha256:e4dd52d80b8c83fdce44e12478ad2e85c64ea965e75d66dbeafb0a3e77308fcc \ - --hash=sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc \ - --hash=sha256:fec21693218efe39aa7f8599346e90c705afa52c5b31ae019b2e57e8f6542bb2 \ - --hash=sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11 +markupsafe==2.1.5 \ + --hash=sha256:00e046b6dd71aa03a41079792f8473dc494d564611a8f89bbbd7cb93295ebdcf \ + --hash=sha256:075202fa5b72c86ad32dc7d0b56024ebdbcf2048c0ba09f1cde31bfdd57bcfff \ + --hash=sha256:0e397ac966fdf721b2c528cf028494e86172b4feba51d65f81ffd65c63798f3f \ + --hash=sha256:17b950fccb810b3293638215058e432159d2b71005c74371d784862b7e4683f3 \ + --hash=sha256:1f3fbcb7ef1f16e48246f704ab79d79da8a46891e2da03f8783a5b6fa41a9532 \ + --hash=sha256:2174c595a0d73a3080ca3257b40096db99799265e1c27cc5a610743acd86d62f \ + --hash=sha256:2b7c57a4dfc4f16f7142221afe5ba4e093e09e728ca65c51f5620c9aaeb9a617 \ + --hash=sha256:2d2d793e36e230fd32babe143b04cec8a8b3eb8a3122d2aceb4a371e6b09b8df \ + --hash=sha256:30b600cf0a7ac9234b2638fbc0fb6158ba5bdcdf46aeb631ead21248b9affbc4 \ + --hash=sha256:397081c1a0bfb5124355710fe79478cdbeb39626492b15d399526ae53422b906 \ + --hash=sha256:3a57fdd7ce31c7ff06cdfbf31dafa96cc533c21e443d57f5b1ecc6cdc668ec7f \ + --hash=sha256:3c6b973f22eb18a789b1460b4b91bf04ae3f0c4234a0a6aa6b0a92f6f7b951d4 \ + --hash=sha256:3e53af139f8579a6d5f7b76549125f0d94d7e630761a2111bc431fd820e163b8 \ + --hash=sha256:4096e9de5c6fdf43fb4f04c26fb114f61ef0bf2e5604b6ee3019d51b69e8c371 \ + --hash=sha256:4275d846e41ecefa46e2015117a9f491e57a71ddd59bbead77e904dc02b1bed2 \ + --hash=sha256:4c31f53cdae6ecfa91a77820e8b151dba54ab528ba65dfd235c80b086d68a465 \ + --hash=sha256:4f11aa001c540f62c6166c7726f71f7573b52c68c31f014c25cc7901deea0b52 \ + --hash=sha256:5049256f536511ee3f7e1b3f87d1d1209d327e818e6ae1365e8653d7e3abb6a6 \ + --hash=sha256:58c98fee265677f63a4385256a6d7683ab1832f3ddd1e66fe948d5880c21a169 \ + --hash=sha256:598e3276b64aff0e7b3451b72e94fa3c238d452e7ddcd893c3ab324717456bad \ + --hash=sha256:5b7b716f97b52c5a14bffdf688f971b2d5ef4029127f1ad7a513973cfd818df2 \ + --hash=sha256:5dedb4db619ba5a2787a94d877bc8ffc0566f92a01c0ef214865e54ecc9ee5e0 \ + --hash=sha256:619bc166c4f2de5caa5a633b8b7326fbe98e0ccbfacabd87268a2b15ff73a029 \ + --hash=sha256:629ddd2ca402ae6dbedfceeba9c46d5f7b2a61d9749597d4307f943ef198fc1f \ + --hash=sha256:656f7526c69fac7f600bd1f400991cc282b417d17539a1b228617081106feb4a \ + --hash=sha256:6ec585f69cec0aa07d945b20805be741395e28ac1627333b1c5b0105962ffced \ + --hash=sha256:72b6be590cc35924b02c78ef34b467da4ba07e4e0f0454a2c5907f473fc50ce5 \ + --hash=sha256:7502934a33b54030eaf1194c21c692a534196063db72176b0c4028e140f8f32c \ + --hash=sha256:7a68b554d356a91cce1236aa7682dc01df0edba8d043fd1ce607c49dd3c1edcf \ + --hash=sha256:7b2e5a267c855eea6b4283940daa6e88a285f5f2a67f2220203786dfa59b37e9 \ + --hash=sha256:823b65d8706e32ad2df51ed89496147a42a2a6e01c13cfb6ffb8b1e92bc910bb \ + --hash=sha256:8590b4ae07a35970728874632fed7bd57b26b0102df2d2b233b6d9d82f6c62ad \ + --hash=sha256:8dd717634f5a044f860435c1d8c16a270ddf0ef8588d4887037c5028b859b0c3 \ + --hash=sha256:8dec4936e9c3100156f8a2dc89c4b88d5c435175ff03413b443469c7c8c5f4d1 \ + --hash=sha256:97cafb1f3cbcd3fd2b6fbfb99ae11cdb14deea0736fc2b0952ee177f2b813a46 \ + --hash=sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc \ + --hash=sha256:a549b9c31bec33820e885335b451286e2969a2d9e24879f83fe904a5ce59d70a \ + --hash=sha256:ac07bad82163452a6884fe8fa0963fb98c2346ba78d779ec06bd7a6262132aee \ + --hash=sha256:ae2ad8ae6ebee9d2d94b17fb62763125f3f374c25618198f40cbb8b525411900 \ + --hash=sha256:b91c037585eba9095565a3556f611e3cbfaa42ca1e865f7b8015fe5c7336d5a5 \ + --hash=sha256:bc1667f8b83f48511b94671e0e441401371dfd0f0a795c7daa4a3cd1dde55bea \ + --hash=sha256:bec0a414d016ac1a18862a519e54b2fd0fc8bbfd6890376898a6c0891dd82e9f \ + --hash=sha256:bf50cd79a75d181c9181df03572cdce0fbb75cc353bc350712073108cba98de5 \ + --hash=sha256:bff1b4290a66b490a2f4719358c0cdcd9bafb6b8f061e45c7a2460866bf50c2e \ + --hash=sha256:c061bb86a71b42465156a3ee7bd58c8c2ceacdbeb95d05a99893e08b8467359a \ + --hash=sha256:c8b29db45f8fe46ad280a7294f5c3ec36dbac9491f2d1c17345be8e69cc5928f \ + --hash=sha256:ce409136744f6521e39fd8e2a24c53fa18ad67aa5bc7c2cf83645cce5b5c4e50 \ + --hash=sha256:d050b3361367a06d752db6ead6e7edeb0009be66bc3bae0ee9d97fb326badc2a \ + --hash=sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b \ + --hash=sha256:d9fad5155d72433c921b782e58892377c44bd6252b5af2f67f16b194987338a4 \ + --hash=sha256:daa4ee5a243f0f20d528d939d06670a298dd39b1ad5f8a72a4275124a7819eff \ + --hash=sha256:db0b55e0f3cc0be60c1f19efdde9a637c32740486004f20d1cff53c3c0ece4d2 \ + --hash=sha256:e61659ba32cf2cf1481e575d0462554625196a1f2fc06a1c777d3f48e8865d46 \ + --hash=sha256:ea3d8a3d18833cf4304cd2fc9cbb1efe188ca9b5efef2bdac7adc20594a0e46b \ + --hash=sha256:ec6a563cff360b50eed26f13adc43e61bc0c04d94b8be985e6fb24b81f6dcfdf \ + --hash=sha256:f5dfb42c4604dddc8e4305050aa6deb084540643ed5804d7455b5df8fe16f5e5 \ + --hash=sha256:fa173ec60341d6bb97a89f5ea19c85c5643c1e7dedebc22f5181eb73573142c5 \ + --hash=sha256:fa9db3f79de01457b03d4f01b34cf91bc0048eb2c3846ff26f66687c2f6d16ab \ + --hash=sha256:fce659a462a1be54d2ffcacea5e3ba2d74daa74f30f5f143fe0c58636e355fdd \ + --hash=sha256:ffee1f21e5ef0d712f9033568f8344d5da8cc2869dbd08d87c84656e6a2d2f68 # via werkzeug mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba # via markdown-it-py -ml-dtypes==0.3.1 \ - --hash=sha256:3d8ca0acbd377082792d8b97081ba580abdad67c6afb7f827012c675b052f058 \ - --hash=sha256:42a8980afd8b7c8e270e8b5c260237286b5b26acd276fcb758d13cd7cb567e99 \ - --hash=sha256:438437e2e614a3c91d75581653b6c40ec890e8b5994d7190a90c931740151c95 \ - --hash=sha256:4828b62fa3bf1ae35faa40f3db9a38ec72fbce02f328a1d14c3a9da4606af364 \ - --hash=sha256:4d94b2d1bed77284694f7fd0479640fa7aa5d96433dca3cbcec407a5ef752e77 \ - --hash=sha256:510d249a91face47211762eb294d6fe64f325356b965fb6388c1bf51bd339267 \ - --hash=sha256:5727effa7650f7ab10906542d137cfb3244fdc3b2b519beff42f82def8ba59be \ - --hash=sha256:5e0b0b6bb07fa5ad11bb61d174667176bee5e05857225067aabfc5adc1b51d23 \ - --hash=sha256:60778f99194b4c4f36ba42da200b35ef851ce4d4af698aaf70f5b91fe70fc611 \ - --hash=sha256:70984b473db6489ec1d8c79b082a1322105155193049d08a3b0c515094e9777b \ - --hash=sha256:979d7d196d9a17e0135ae22878f74241fbd3522cef58d7b292f1fd5b32282201 \ - --hash=sha256:a777928dcba8865ab4a8157eeb25d23aed7bc82e5fd74e1d5eca821d3f148b39 \ - --hash=sha256:cb0c404e0dd3e25b56362c1c1e5de0ef717f727dde59fa721af4ce6ab2acca44 \ - --hash=sha256:d1a8dc3bac1da2a17d0e2e4cba36ee89721d0bd33ea4765af2eefb5f41409e0f \ - --hash=sha256:da274599e4950a9b488d21571061f49a185537cc77f2d3f8121151d58a9e9f16 \ - --hash=sha256:f83ff080df8910c0f987f615b03e4f8198638e0c00c6e679ea8892dda909763b \ - --hash=sha256:fcae2c69715410d96906e1dfe8f017d9f78a0d10e0df91aae52e91f51fdfe45e - # via jax +ml-dtypes==0.3.2 \ + --hash=sha256:2c34f2ba9660b21fe1034b608308a01be82bbef2a92fb8199f24dc6bad0d5226 \ + --hash=sha256:3a17ef2322e60858d93584e9c52a5be7dd6236b056b7fa1ec57f1bb6ba043e33 \ + --hash=sha256:533059bc5f1764fac071ef54598db358c167c51a718f68f5bb55e3dee79d2967 \ + --hash=sha256:6604877d567a29bfe7cc02969ae0f2425260e5335505cf5e7fefc3e5465f5655 \ + --hash=sha256:6b35c4e8ca957c877ac35c79ffa77724ecc3702a1e4b18b08306c03feae597bb \ + --hash=sha256:763697ab8a88d47443997a7cdf3aac7340049aed45f7521f6b0ec8a0594821fe \ + --hash=sha256:7a4c3fcbf86fa52d0204f07cfd23947ef05b4ad743a1a988e163caa34a201e5e \ + --hash=sha256:7afde548890a92b41c0fed3a6c525f1200a5727205f73dc21181a2726571bb53 \ + --hash=sha256:7ba8e1fafc7fff3e643f453bffa7d082df1678a73286ce8187d3e825e776eb94 \ + --hash=sha256:91f8783fd1f2c23fd3b9ee5ad66b785dafa58ba3cdb050c4458021fa4d1eb226 \ + --hash=sha256:93b78f53431c93953f7850bb1b925a17f0ab5d97527e38a7e865b5b4bc5cfc18 \ + --hash=sha256:961134ea44c7b8ca63eda902a44b58cd8bd670e21d62e255c81fba0a8e70d9b7 \ + --hash=sha256:b89b194e9501a92d289c1ffd411380baf5daafb9818109a4f49b0a1b6dce4462 \ + --hash=sha256:c7b3fb3d4f6b39bcd4f6c4b98f406291f0d681a895490ee29a0f95bab850d53c \ + --hash=sha256:d1a746fe5fb9cd974a91070174258f0be129c592b93f9ce7df6cc336416c3fbd \ + --hash=sha256:e8505946df1665db01332d885c2020b4cb9e84a8b1241eb4ba69d59591f65855 \ + --hash=sha256:f47619d978ab1ae7dfdc4052ea97c636c6263e1f19bd1be0e42c346b98d15ff4 + # via + # jax + # keras-nightly namex==0.0.7 \ --hash=sha256:84ba65bc4d22bd909e3d26bf2ffb4b9529b608cb3f9a4336f776b04204ced69b \ --hash=sha256:8a4f062945f405d77cb66b907f16aa2fd83681945e998be840eb6c4154d40108 @@ -411,10 +404,6 @@ numpy==1.23.5 ; python_version <= "3.11" \ # opt-einsum # scipy # tb-nightly -oauthlib==3.2.2 \ - --hash=sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca \ - --hash=sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918 - # via requests-oauthlib opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 @@ -444,57 +433,36 @@ protobuf==4.23.4 \ --hash=sha256:effeac51ab79332d44fba74660d40ae79985901ac21bca408f8dc335a81aa597 \ --hash=sha256:fee88269a090ada09ca63551bf2f573eb2424035bcf2cb1b121895b01a46594a # via tb-nightly -psutil==5.9.6 \ - --hash=sha256:10e8c17b4f898d64b121149afb136c53ea8b68c7531155147867b7b1ac9e7e28 \ - --hash=sha256:18cd22c5db486f33998f37e2bb054cc62fd06646995285e02a51b1e08da97017 \ - --hash=sha256:3ebf2158c16cc69db777e3c7decb3c0f43a7af94a60d72e87b2823aebac3d602 \ - --hash=sha256:51dc3d54607c73148f63732c727856f5febec1c7c336f8f41fcbd6315cce76ac \ - --hash=sha256:6e5fb8dc711a514da83098bc5234264e551ad980cec5f85dabf4d38ed6f15e9a \ - --hash=sha256:70cb3beb98bc3fd5ac9ac617a327af7e7f826373ee64c80efd4eb2856e5051e9 \ - --hash=sha256:748c9dd2583ed86347ed65d0035f45fa8c851e8d90354c122ab72319b5f366f4 \ - --hash=sha256:91ecd2d9c00db9817a4b4192107cf6954addb5d9d67a969a4f436dbc9200f88c \ - --hash=sha256:92e0cc43c524834af53e9d3369245e6cc3b130e78e26100d1f63cdb0abeb3d3c \ - --hash=sha256:a6f01f03bf1843280f4ad16f4bde26b817847b4c1a0db59bf6419807bc5ce05c \ - --hash=sha256:c69596f9fc2f8acd574a12d5f8b7b1ba3765a641ea5d60fb4736bf3c08a8214a \ - --hash=sha256:ca2780f5e038379e520281e4c032dddd086906ddff9ef0d1b9dcf00710e5071c \ - --hash=sha256:daecbcbd29b289aac14ece28eca6a3e60aa361754cf6da3dfb20d4d32b6c7f57 \ - --hash=sha256:e4b92ddcd7dd4cdd3f900180ea1e104932c7bce234fb88976e2a3b296441225a \ - --hash=sha256:fb8a697f11b0f5994550555fcfe3e69799e5b060c8ecf9e2f75c69302cc35c0d \ - --hash=sha256:ff18b8d1a784b810df0b0fff3bcb50ab941c3b8e2c8de5726f9c71c601c611aa +psutil==5.9.8 \ + --hash=sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d \ + --hash=sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73 \ + --hash=sha256:26bd09967ae00920df88e0352a91cff1a78f8d69b3ecabbfe733610c0af486c8 \ + --hash=sha256:27cc40c3493bb10de1be4b3f07cae4c010ce715290a5be22b98493509c6299e2 \ + --hash=sha256:36f435891adb138ed3c9e58c6af3e2e6ca9ac2f365efe1f9cfef2794e6c93b4e \ + --hash=sha256:50187900d73c1381ba1454cf40308c2bf6f34268518b3f36a9b663ca87e65e36 \ + --hash=sha256:611052c4bc70432ec770d5d54f64206aa7203a101ec273a0cd82418c86503bb7 \ + --hash=sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c \ + --hash=sha256:7d79560ad97af658a0f6adfef8b834b53f64746d45b403f225b85c5c2c140eee \ + --hash=sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421 \ + --hash=sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf \ + --hash=sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81 \ + --hash=sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0 \ + --hash=sha256:bd1184ceb3f87651a67b2708d4c3338e9b10c5df903f2e3776b62303b26cb631 \ + --hash=sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4 \ + --hash=sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8 # via portpicker -pyasn1==0.5.0 \ - --hash=sha256:87a2121042a1ac9358cabcaf1d07680ff97ee6404333bacca15f76aa8ad01a57 \ - --hash=sha256:97b7290ca68e62a832558ec3976f15cbf911bf5d7c7039d8b861c2a0ece69fde - # via - # pyasn1-modules - # rsa -pyasn1-modules==0.3.0 \ - --hash=sha256:5bd01446b736eb9d31512a30d46c1ac3395d676c6f3cafa4c03eb54b9925631c \ - --hash=sha256:d3ccd6ed470d9ffbc716be08bd90efbd44d0734bc9303818f7336070984a162d - # via google-auth -pygments==2.16.1 \ - --hash=sha256:13fc09fa63bc8d8671a6d247e1eb303c4b343eaee81d861f3404db2935653692 \ - --hash=sha256:1daff0494820c69bc8941e407aa20f577374ee88364ee10a98fdbe0aece96e29 +pygments==2.17.2 \ + --hash=sha256:b27c2826c47d0f3219f29554824c30c5e8945175d888647acd804ddd04af846c \ + --hash=sha256:da46cec9fd2de5be3a8a784f434e4c4ab670b4ff54d605c4c2717e9d49c4c367 # via rich requests==2.31.0 \ --hash=sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f \ --hash=sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1 - # via - # -r requirements.in - # requests-oauthlib - # tb-nightly -requests-oauthlib==1.3.1 \ - --hash=sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5 \ - --hash=sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a - # via google-auth-oauthlib -rich==13.6.0 \ - --hash=sha256:2b38e2fe9ca72c9a00170a1a2d20c63c790d0e10ef1fe35eba76e1e7b1d7d245 \ - --hash=sha256:5c14d22737e6d5084ef4771b62d5d4363165b403455a30a1c8ca39dc7b644bef + # via -r requirements.in +rich==13.7.0 \ + --hash=sha256:5cb5123b5cf9ee70584244246816e9114227e0b98ad9176eede6ad54bf5403fa \ + --hash=sha256:6da14c108c4866ee9520bbffa71f6fe3962e193b7da68720583850cd4548e235 # via keras-nightly -rsa==4.9 \ - --hash=sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7 \ - --hash=sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21 - # via google-auth scipy==1.11.3 \ --hash=sha256:00f325434b6424952fbb636506f0567898dca7b0f7654d48f1c382ea338ce9a3 \ --hash=sha256:033c3fd95d55012dd1148b201b72ae854d5086d25e7c316ec9850de4fe776929 \ @@ -530,8 +498,8 @@ six==1.16.0 \ # via # astunparse # tb-nightly -tb-nightly==2.15.0a20231023 \ - --hash=sha256:8990a52985e3296aa18a6825efc017bcded9e2fb2cbcdd2d01f5c62d4fdc9825 +tb-nightly==2.17.0a20240214 \ + --hash=sha256:dbd59bcfd9b028e6199050e36cbb4ee1db4a94e69a7f8c00865f1498112a2b83 # via -r requirements.in tblib==2.0.0 \ --hash=sha256:9100bfa016b047d5b980d66e7efed952fbd20bd85b56110aaf473cb97d18709a \ @@ -546,13 +514,17 @@ termcolor==2.3.0 \ --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \ --hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a # via -r requirements.in +tf-keras-nightly==2.16.0.dev2024021410 \ + --hash=sha256:05a4c19c6a795ec9ed6f4dac69d443601b9206b910524d8b19bad6e362d49f47 \ + --hash=sha256:83d5a1dd3979b42164a2fac4bdd2ed5981ef0c2a0514ca64e4bc67e369caef5c + # via tb-nightly typing-extensions==4.8.0 \ --hash=sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0 \ --hash=sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef # via -r requirements.in -urllib3==2.0.7 \ - --hash=sha256:c97dfde1f7bd43a71c8d2a58e369e9b2bf692d1334ea9f9cae55add7d0dd0f84 \ - --hash=sha256:fdb6d215c776278489906c2f8916e6e7d4f5a9b602ccbcfdf7f016fc8da0596e +urllib3==2.2.0 \ + --hash=sha256:051d961ad0c62a94e50ecf1af379c3aba230c66c710493493560c0c223c49f20 \ + --hash=sha256:ce3711610ddce217e6d113a2732fafad960a03fd0318c91faa79481e35c11224 # via requests werkzeug==3.0.1 \ --hash=sha256:507e811ecea72b18a404947aded4b3390e1db8f826b494d76550ef45bb3b1dcc \ @@ -564,81 +536,77 @@ wheel==0.41.3 \ # via # -r requirements.in # astunparse -wrapt==1.14.1 \ - --hash=sha256:00b6d4ea20a906c0ca56d84f93065b398ab74b927a7a3dbd470f6fc503f95dc3 \ - --hash=sha256:01c205616a89d09827986bc4e859bcabd64f5a0662a7fe95e0d359424e0e071b \ - --hash=sha256:02b41b633c6261feff8ddd8d11c711df6842aba629fdd3da10249a53211a72c4 \ - --hash=sha256:07f7a7d0f388028b2df1d916e94bbb40624c59b48ecc6cbc232546706fac74c2 \ - --hash=sha256:11871514607b15cfeb87c547a49bca19fde402f32e2b1c24a632506c0a756656 \ - --hash=sha256:1b376b3f4896e7930f1f772ac4b064ac12598d1c38d04907e696cc4d794b43d3 \ - --hash=sha256:2020f391008ef874c6d9e208b24f28e31bcb85ccff4f335f15a3251d222b92d9 \ - --hash=sha256:21ac0156c4b089b330b7666db40feee30a5d52634cc4560e1905d6529a3897ff \ - --hash=sha256:240b1686f38ae665d1b15475966fe0472f78e71b1b4903c143a842659c8e4cb9 \ - --hash=sha256:257fd78c513e0fb5cdbe058c27a0624c9884e735bbd131935fd49e9fe719d310 \ - --hash=sha256:26046cd03936ae745a502abf44dac702a5e6880b2b01c29aea8ddf3353b68224 \ - --hash=sha256:2b39d38039a1fdad98c87279b48bc5dce2c0ca0d73483b12cb72aa9609278e8a \ - --hash=sha256:2cf71233a0ed05ccdabe209c606fe0bac7379fdcf687f39b944420d2a09fdb57 \ - --hash=sha256:2fe803deacd09a233e4762a1adcea5db5d31e6be577a43352936179d14d90069 \ - --hash=sha256:2feecf86e1f7a86517cab34ae6c2f081fd2d0dac860cb0c0ded96d799d20b335 \ - --hash=sha256:3232822c7d98d23895ccc443bbdf57c7412c5a65996c30442ebe6ed3df335383 \ - --hash=sha256:34aa51c45f28ba7f12accd624225e2b1e5a3a45206aa191f6f9aac931d9d56fe \ - --hash=sha256:358fe87cc899c6bb0ddc185bf3dbfa4ba646f05b1b0b9b5a27c2cb92c2cea204 \ - --hash=sha256:36f582d0c6bc99d5f39cd3ac2a9062e57f3cf606ade29a0a0d6b323462f4dd87 \ - --hash=sha256:380a85cf89e0e69b7cfbe2ea9f765f004ff419f34194018a6827ac0e3edfed4d \ - --hash=sha256:40e7bc81c9e2b2734ea4bc1aceb8a8f0ceaac7c5299bc5d69e37c44d9081d43b \ - --hash=sha256:43ca3bbbe97af00f49efb06e352eae40434ca9d915906f77def219b88e85d907 \ - --hash=sha256:49ef582b7a1152ae2766557f0550a9fcbf7bbd76f43fbdc94dd3bf07cc7168be \ - --hash=sha256:4fcc4649dc762cddacd193e6b55bc02edca674067f5f98166d7713b193932b7f \ - --hash=sha256:5a0f54ce2c092aaf439813735584b9537cad479575a09892b8352fea5e988dc0 \ - --hash=sha256:5a9a0d155deafd9448baff28c08e150d9b24ff010e899311ddd63c45c2445e28 \ - --hash=sha256:5b02d65b9ccf0ef6c34cba6cf5bf2aab1bb2f49c6090bafeecc9cd81ad4ea1c1 \ - --hash=sha256:60db23fa423575eeb65ea430cee741acb7c26a1365d103f7b0f6ec412b893853 \ - --hash=sha256:642c2e7a804fcf18c222e1060df25fc210b9c58db7c91416fb055897fc27e8cc \ - --hash=sha256:6447e9f3ba72f8e2b985a1da758767698efa72723d5b59accefd716e9e8272bf \ - --hash=sha256:6a9a25751acb379b466ff6be78a315e2b439d4c94c1e99cb7266d40a537995d3 \ - --hash=sha256:6b1a564e6cb69922c7fe3a678b9f9a3c54e72b469875aa8018f18b4d1dd1adf3 \ - --hash=sha256:6d323e1554b3d22cfc03cd3243b5bb815a51f5249fdcbb86fda4bf62bab9e164 \ - --hash=sha256:6e743de5e9c3d1b7185870f480587b75b1cb604832e380d64f9504a0535912d1 \ - --hash=sha256:709fe01086a55cf79d20f741f39325018f4df051ef39fe921b1ebe780a66184c \ - --hash=sha256:7b7c050ae976e286906dd3f26009e117eb000fb2cf3533398c5ad9ccc86867b1 \ - --hash=sha256:7d2872609603cb35ca513d7404a94d6d608fc13211563571117046c9d2bcc3d7 \ - --hash=sha256:7ef58fb89674095bfc57c4069e95d7a31cfdc0939e2a579882ac7d55aadfd2a1 \ - --hash=sha256:80bb5c256f1415f747011dc3604b59bc1f91c6e7150bd7db03b19170ee06b320 \ - --hash=sha256:81b19725065dcb43df02b37e03278c011a09e49757287dca60c5aecdd5a0b8ed \ - --hash=sha256:833b58d5d0b7e5b9832869f039203389ac7cbf01765639c7309fd50ef619e0b1 \ - --hash=sha256:88bd7b6bd70a5b6803c1abf6bca012f7ed963e58c68d76ee20b9d751c74a3248 \ - --hash=sha256:8ad85f7f4e20964db4daadcab70b47ab05c7c1cf2a7c1e51087bfaa83831854c \ - --hash=sha256:8c0ce1e99116d5ab21355d8ebe53d9460366704ea38ae4d9f6933188f327b456 \ - --hash=sha256:8d649d616e5c6a678b26d15ece345354f7c2286acd6db868e65fcc5ff7c24a77 \ - --hash=sha256:903500616422a40a98a5a3c4ff4ed9d0066f3b4c951fa286018ecdf0750194ef \ - --hash=sha256:9736af4641846491aedb3c3f56b9bc5568d92b0692303b5a305301a95dfd38b1 \ - --hash=sha256:988635d122aaf2bdcef9e795435662bcd65b02f4f4c1ae37fbee7401c440b3a7 \ - --hash=sha256:9cca3c2cdadb362116235fdbd411735de4328c61425b0aa9f872fd76d02c4e86 \ - --hash=sha256:9e0fd32e0148dd5dea6af5fee42beb949098564cc23211a88d799e434255a1f4 \ - --hash=sha256:9f3e6f9e05148ff90002b884fbc2a86bd303ae847e472f44ecc06c2cd2fcdb2d \ - --hash=sha256:a85d2b46be66a71bedde836d9e41859879cc54a2a04fad1191eb50c2066f6e9d \ - --hash=sha256:a9008dad07d71f68487c91e96579c8567c98ca4c3881b9b113bc7b33e9fd78b8 \ - --hash=sha256:a9a52172be0b5aae932bef82a79ec0a0ce87288c7d132946d645eba03f0ad8a8 \ - --hash=sha256:aa31fdcc33fef9eb2552cbcbfee7773d5a6792c137b359e82879c101e98584c5 \ - --hash=sha256:acae32e13a4153809db37405f5eba5bac5fbe2e2ba61ab227926a22901051c0a \ - --hash=sha256:b014c23646a467558be7da3d6b9fa409b2c567d2110599b7cf9a0c5992b3b471 \ - --hash=sha256:b21bb4c09ffabfa0e85e3a6b623e19b80e7acd709b9f91452b8297ace2a8ab00 \ - --hash=sha256:b5901a312f4d14c59918c221323068fad0540e34324925c8475263841dbdfe68 \ - --hash=sha256:b9b7a708dd92306328117d8c4b62e2194d00c365f18eff11a9b53c6f923b01e3 \ - --hash=sha256:d1967f46ea8f2db647c786e78d8cc7e4313dbd1b0aca360592d8027b8508e24d \ - --hash=sha256:d52a25136894c63de15a35bc0bdc5adb4b0e173b9c0d07a2be9d3ca64a332735 \ - --hash=sha256:d77c85fedff92cf788face9bfa3ebaa364448ebb1d765302e9af11bf449ca36d \ - --hash=sha256:d79d7d5dc8a32b7093e81e97dad755127ff77bcc899e845f41bf71747af0c569 \ - --hash=sha256:dbcda74c67263139358f4d188ae5faae95c30929281bc6866d00573783c422b7 \ - --hash=sha256:ddaea91abf8b0d13443f6dac52e89051a5063c7d014710dcb4d4abb2ff811a59 \ - --hash=sha256:dee0ce50c6a2dd9056c20db781e9c1cfd33e77d2d569f5d1d9321c641bb903d5 \ - --hash=sha256:dee60e1de1898bde3b238f18340eec6148986da0455d8ba7848d50470a7a32fb \ - --hash=sha256:e2f83e18fe2f4c9e7db597e988f72712c0c3676d337d8b101f6758107c42425b \ - --hash=sha256:e3fb1677c720409d5f671e39bac6c9e0e422584e5f518bfd50aa4cbbea02433f \ - --hash=sha256:ecee4132c6cd2ce5308e21672015ddfed1ff975ad0ac8d27168ea82e71413f55 \ - --hash=sha256:ee2b1b1769f6707a8a445162ea16dddf74285c3964f605877a20e38545c3c462 \ - --hash=sha256:ee6acae74a2b91865910eef5e7de37dc6895ad96fa23603d1d27ea69df545015 \ - --hash=sha256:ef3f72c9666bba2bab70d2a8b79f2c6d2c1a42a7f7e2b0ec83bb2f9e383950af +wrapt==1.16.0 \ + --hash=sha256:0d2691979e93d06a95a26257adb7bfd0c93818e89b1406f5a28f36e0d8c1e1fc \ + --hash=sha256:14d7dc606219cdd7405133c713f2c218d4252f2a469003f8c46bb92d5d095d81 \ + --hash=sha256:1a5db485fe2de4403f13fafdc231b0dbae5eca4359232d2efc79025527375b09 \ + --hash=sha256:1acd723ee2a8826f3d53910255643e33673e1d11db84ce5880675954183ec47e \ + --hash=sha256:1ca9b6085e4f866bd584fb135a041bfc32cab916e69f714a7d1d397f8c4891ca \ + --hash=sha256:1dd50a2696ff89f57bd8847647a1c363b687d3d796dc30d4dd4a9d1689a706f0 \ + --hash=sha256:2076fad65c6736184e77d7d4729b63a6d1ae0b70da4868adeec40989858eb3fb \ + --hash=sha256:2a88e6010048489cda82b1326889ec075a8c856c2e6a256072b28eaee3ccf487 \ + --hash=sha256:3ebf019be5c09d400cf7b024aa52b1f3aeebeff51550d007e92c3c1c4afc2a40 \ + --hash=sha256:418abb18146475c310d7a6dc71143d6f7adec5b004ac9ce08dc7a34e2babdc5c \ + --hash=sha256:43aa59eadec7890d9958748db829df269f0368521ba6dc68cc172d5d03ed8060 \ + --hash=sha256:44a2754372e32ab315734c6c73b24351d06e77ffff6ae27d2ecf14cf3d229202 \ + --hash=sha256:490b0ee15c1a55be9c1bd8609b8cecd60e325f0575fc98f50058eae366e01f41 \ + --hash=sha256:49aac49dc4782cb04f58986e81ea0b4768e4ff197b57324dcbd7699c5dfb40b9 \ + --hash=sha256:5eb404d89131ec9b4f748fa5cfb5346802e5ee8836f57d516576e61f304f3b7b \ + --hash=sha256:5f15814a33e42b04e3de432e573aa557f9f0f56458745c2074952f564c50e664 \ + --hash=sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d \ + --hash=sha256:66027d667efe95cc4fa945af59f92c5a02c6f5bb6012bff9e60542c74c75c362 \ + --hash=sha256:66dfbaa7cfa3eb707bbfcd46dab2bc6207b005cbc9caa2199bcbc81d95071a00 \ + --hash=sha256:685f568fa5e627e93f3b52fda002c7ed2fa1800b50ce51f6ed1d572d8ab3e7fc \ + --hash=sha256:6906c4100a8fcbf2fa735f6059214bb13b97f75b1a61777fcf6432121ef12ef1 \ + --hash=sha256:6a42cd0cfa8ffc1915aef79cb4284f6383d8a3e9dcca70c445dcfdd639d51267 \ + --hash=sha256:6dcfcffe73710be01d90cae08c3e548d90932d37b39ef83969ae135d36ef3956 \ + --hash=sha256:6f6eac2360f2d543cc875a0e5efd413b6cbd483cb3ad7ebf888884a6e0d2e966 \ + --hash=sha256:72554a23c78a8e7aa02abbd699d129eead8b147a23c56e08d08dfc29cfdddca1 \ + --hash=sha256:73870c364c11f03ed072dda68ff7aea6d2a3a5c3fe250d917a429c7432e15228 \ + --hash=sha256:73aa7d98215d39b8455f103de64391cb79dfcad601701a3aa0dddacf74911d72 \ + --hash=sha256:75ea7d0ee2a15733684badb16de6794894ed9c55aa5e9903260922f0482e687d \ + --hash=sha256:7bd2d7ff69a2cac767fbf7a2b206add2e9a210e57947dd7ce03e25d03d2de292 \ + --hash=sha256:807cc8543a477ab7422f1120a217054f958a66ef7314f76dd9e77d3f02cdccd0 \ + --hash=sha256:8e9723528b9f787dc59168369e42ae1c3b0d3fadb2f1a71de14531d321ee05b0 \ + --hash=sha256:9090c9e676d5236a6948330e83cb89969f433b1943a558968f659ead07cb3b36 \ + --hash=sha256:9153ed35fc5e4fa3b2fe97bddaa7cbec0ed22412b85bcdaf54aeba92ea37428c \ + --hash=sha256:9159485323798c8dc530a224bd3ffcf76659319ccc7bbd52e01e73bd0241a0c5 \ + --hash=sha256:941988b89b4fd6b41c3f0bfb20e92bd23746579736b7343283297c4c8cbae68f \ + --hash=sha256:94265b00870aa407bd0cbcfd536f17ecde43b94fb8d228560a1e9d3041462d73 \ + --hash=sha256:98b5e1f498a8ca1858a1cdbffb023bfd954da4e3fa2c0cb5853d40014557248b \ + --hash=sha256:9b201ae332c3637a42f02d1045e1d0cccfdc41f1f2f801dafbaa7e9b4797bfc2 \ + --hash=sha256:a0ea261ce52b5952bf669684a251a66df239ec6d441ccb59ec7afa882265d593 \ + --hash=sha256:a33a747400b94b6d6b8a165e4480264a64a78c8a4c734b62136062e9a248dd39 \ + --hash=sha256:a452f9ca3e3267cd4d0fcf2edd0d035b1934ac2bd7e0e57ac91ad6b95c0c6389 \ + --hash=sha256:a86373cf37cd7764f2201b76496aba58a52e76dedfaa698ef9e9688bfd9e41cf \ + --hash=sha256:ac83a914ebaf589b69f7d0a1277602ff494e21f4c2f743313414378f8f50a4cf \ + --hash=sha256:aefbc4cb0a54f91af643660a0a150ce2c090d3652cf4052a5397fb2de549cd89 \ + --hash=sha256:b3646eefa23daeba62643a58aac816945cadc0afaf21800a1421eeba5f6cfb9c \ + --hash=sha256:b47cfad9e9bbbed2339081f4e346c93ecd7ab504299403320bf85f7f85c7d46c \ + --hash=sha256:b935ae30c6e7400022b50f8d359c03ed233d45b725cfdd299462f41ee5ffba6f \ + --hash=sha256:bb2dee3874a500de01c93d5c71415fcaef1d858370d405824783e7a8ef5db440 \ + --hash=sha256:bc57efac2da352a51cc4658878a68d2b1b67dbe9d33c36cb826ca449d80a8465 \ + --hash=sha256:bf5703fdeb350e36885f2875d853ce13172ae281c56e509f4e6eca049bdfb136 \ + --hash=sha256:c31f72b1b6624c9d863fc095da460802f43a7c6868c5dda140f51da24fd47d7b \ + --hash=sha256:c5cd603b575ebceca7da5a3a251e69561bec509e0b46e4993e1cac402b7247b8 \ + --hash=sha256:d2efee35b4b0a347e0d99d28e884dfd82797852d62fcd7ebdeee26f3ceb72cf3 \ + --hash=sha256:d462f28826f4657968ae51d2181a074dfe03c200d6131690b7d65d55b0f360f8 \ + --hash=sha256:d5e49454f19ef621089e204f862388d29e6e8d8b162efce05208913dde5b9ad6 \ + --hash=sha256:da4813f751142436b075ed7aa012a8778aa43a99f7b36afe9b742d3ed8bdc95e \ + --hash=sha256:db2e408d983b0e61e238cf579c09ef7020560441906ca990fe8412153e3b291f \ + --hash=sha256:db98ad84a55eb09b3c32a96c576476777e87c520a34e2519d3e59c44710c002c \ + --hash=sha256:dbed418ba5c3dce92619656802cc5355cb679e58d0d89b50f116e4a9d5a9603e \ + --hash=sha256:dcdba5c86e368442528f7060039eda390cc4091bfd1dca41e8046af7c910dda8 \ + --hash=sha256:decbfa2f618fa8ed81c95ee18a387ff973143c656ef800c9f24fb7e9c16054e2 \ + --hash=sha256:e4fdb9275308292e880dcbeb12546df7f3e0f96c6b41197e0cf37d2826359020 \ + --hash=sha256:eb1b046be06b0fce7249f1d025cd359b4b80fc1c3e24ad9eca33e0dcdb2e4a35 \ + --hash=sha256:eb6e651000a19c96f452c85132811d25e9264d836951022d6e81df2fff38337d \ + --hash=sha256:ed867c42c268f876097248e05b6117a65bcd1e63b779e916fe2e33cd6fd0d3c3 \ + --hash=sha256:edfad1d29c73f9b863ebe7082ae9321374ccb10879eeabc84ba3b69f2579d537 \ + --hash=sha256:f2058f813d4f2b5e3a9eb2eb3faf8f1d99b81c3e51aeda4b168406443e8ba809 \ + --hash=sha256:f6b2d0c6703c988d334f297aa5df18c45e97b0af3679bb75059e0e0bd8b1069d \ + --hash=sha256:f8212564d49c50eb4565e502814f694e240c55551a5f1bc841d4fcaabb0a9b8a \ + --hash=sha256:ffa565331890b90056c01db69c0fe634a776f8019c143a5ae265f9c6bc4bd6d4 # via -r requirements.in zipp==3.17.0 \ --hash=sha256:0e923e726174922dce09c53c59ad483ff7bbb8e572e00c7f7c46b88556409f31 \ diff --git a/tensorflow/BUILD b/tensorflow/BUILD index acc8468d6168e4..9eb036f01e0614 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -44,7 +44,7 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") # # buildifier: disable=out-of-order-load # load("//devtools/build_cleaner/skylark:action_config_test.bzl", "action_config_test") # load("//devtools/copybara/rules:copybara.bzl", "copybara_config_test") -# load("//tools/build_defs/license:license.bzl", "license") +# load("@rules_license//rules:license.bzl", "license") # # buildifier: enable=out-of-order-load # copybara:uncomment_end @@ -1631,6 +1631,7 @@ genrule( d="$${d#*external/farmhash_archive/src}" d="$${d#*external/$${extname}/}" + d="$${d#_virtual_includes/*/}" fi mkdir -p "$@/$${d}" diff --git a/tensorflow/c/eager/abstract_tensor_handle.cc b/tensorflow/c/eager/abstract_tensor_handle.cc index 8a4438e2b9e75a..e04a9810638f61 100644 --- a/tensorflow/c/eager/abstract_tensor_handle.cc +++ b/tensorflow/c/eager/abstract_tensor_handle.cc @@ -34,7 +34,7 @@ std::string AbstractTensorHandle::DebugString() const { Status AbstractTensorHandle::TensorHandleStatus() const { // Tensor handles in current runtime don't carry error info and this method // should always return OK status. - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 29301e4e37f754..ed3b9ae7015dfa 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -170,16 +170,17 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, TF_Status* status) { TFE_ContextSetServerDefWithTimeoutAndRetries( ctx, keep_alive_secs, proto, proto_len, /*init_timeout_in_ms=*/0, - /*retries=*/0, status); + /*retries=*/0, status, /*clear_existing_contexts=*/false); } // Set server def with timeout. TF_CAPI_EXPORT extern void TFE_ContextSetServerDefWithTimeout( TFE_Context* ctx, int keep_alive_secs, const void* proto, size_t proto_len, - int64_t init_timeout_in_ms, TF_Status* status) { - TFE_ContextSetServerDefWithTimeoutAndRetries(ctx, keep_alive_secs, proto, - proto_len, init_timeout_in_ms, - /*retries=*/0, status); + int64_t init_timeout_in_ms, TF_Status* status, + bool clear_existing_contexts) { + TFE_ContextSetServerDefWithTimeoutAndRetries( + ctx, keep_alive_secs, proto, proto_len, init_timeout_in_ms, + /*retries=*/0, status, clear_existing_contexts); } // Set server_def on the context, possibly updating it. @@ -190,7 +191,8 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDefWithTimeout( // ParameterServerStrategy initialization to be robust to worker preemption. TF_CAPI_EXPORT extern void TFE_ContextSetServerDefWithTimeoutAndRetries( TFE_Context* ctx, int keep_alive_secs, const void* proto, size_t proto_len, - int64_t init_timeout_in_ms, int retries, TF_Status* status) { + int64_t init_timeout_in_ms, int retries, TF_Status* status, + bool clear_existing_contexts) { #if defined(IS_MOBILE_PLATFORM) status->status = tensorflow::errors::Unimplemented( "TFE_ContextSetServerDef not supported on mobile"); @@ -204,7 +206,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDefWithTimeoutAndRetries( status->status = tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef( server_def, /*reset_context=*/true, keep_alive_secs, - init_timeout_in_ms, retries); + init_timeout_in_ms, retries, clear_existing_contexts); #endif // !IS_MOBILE_PLATFORM } @@ -495,7 +497,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice { return status.status; } - tensorflow::StatusOr ShallPinToThisDevice( + absl::StatusOr ShallPinToThisDevice( const ImmediateExecutionOperation* op) override { TF_Status status; // Let this custom device choose the device to pin this op on if it @@ -555,7 +557,7 @@ class CAPICustomDeviceTensorHandle } summary = std::string(reinterpret_cast(summary_buffer->data), summary_buffer->length); - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/c/eager/c_api_distributed_test.cc b/tensorflow/c/eager/c_api_distributed_test.cc index 13b688889a4567..3cb7a5d0fa5f1a 100644 --- a/tensorflow/c/eager/c_api_distributed_test.cc +++ b/tensorflow/c/eager/c_api_distributed_test.cc @@ -315,11 +315,11 @@ class GraphErrorInjectionPass : public tensorflow::GraphOptimizationPass { tensorflow::Status Run( const tensorflow::GraphOptimizationPassOptions& options) override { if (!enabled_) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (first_call_) { first_call_ = false; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } return tensorflow::errors::Internal("Graph pass runs for more than once!"); } @@ -447,7 +447,7 @@ class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass { return tensorflow::errors::Internal("Injected graph pass error."); } } - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index d52e938c047a6d..c2f8125ddbb76a 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -606,7 +606,7 @@ void TFE_OpSetCancellationManager(TFE_Op* op, TF_Status* status) { tensorflow::unwrap(op)->SetCancellationManager( tensorflow::unwrap(cancellation_manager)); - status->status = ::tensorflow::OkStatus(); + status->status = absl::OkStatus(); } TFE_Executor* TFE_NewExecutor(bool is_async, bool enable_streaming_enqueue, @@ -667,7 +667,7 @@ void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name, buf->data_deallocator = [](void* data, size_t length) { tensorflow::port::Free(data); }; - status->status = ::tensorflow::OkStatus(); + status->status = absl::OkStatus(); } void TFE_ContextGetGraphDebugInfo(TFE_Context* ctx, const char* function_name, @@ -691,7 +691,7 @@ void TFE_ContextGetGraphDebugInfo(TFE_Context* ctx, const char* function_name, buf->data_deallocator = [](void* data, size_t length) { tensorflow::port::Free(data); }; - status->status = ::tensorflow::OkStatus(); + status->status = absl::OkStatus(); } TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype, @@ -817,7 +817,7 @@ void TFE_GetExecutedOpNames(TFE_Context* ctx, TF_Buffer* buf, buf->data_deallocator = [](void* data, size_t length) { tensorflow::port::Free(data); }; - status->status = ::tensorflow::OkStatus(); + status->status = absl::OkStatus(); } void TFE_SetLogicalCpuDevices(TFE_Context* ctx, int num_cpus, @@ -960,7 +960,7 @@ void TFE_GetTaskStates(TFE_Context* ctx, const TF_Buffer& tasks, void* states, *state_iter = std::move(s); ++state_iter; } - status->status = tensorflow::OkStatus(); + status->status = absl::OkStatus(); } void TFE_WaitAtBarrier(TFE_Context* ctx, const char* barrier_id, diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index 604565ba120153..ab50b4701d8f36 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -416,7 +416,8 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDefWithTimeout( // This API is for experimental usage and may be subject to change. TF_CAPI_EXPORT extern void TFE_ContextSetServerDefWithTimeout( TFE_Context* ctx, int keep_alive_secs, const void* proto, size_t proto_len, - int64_t init_timeout_in_ms, TF_Status* status); + int64_t init_timeout_in_ms, TF_Status* status, + bool clear_existing_contexts); // Set server def with retries and timeout. This is helpful for fault-tolerant // initial connection in high-preemption environments, such as @@ -424,7 +425,8 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDefWithTimeout( // This API is for experimental usage and may be subject to change. TF_CAPI_EXPORT extern void TFE_ContextSetServerDefWithTimeoutAndRetries( TFE_Context* ctx, int keep_alive_secs, const void* proto, size_t proto_len, - int64_t init_timeout_in_ms, int retries, TF_Status* status); + int64_t init_timeout_in_ms, int retries, TF_Status* status, + bool clear_existing_contexts); // Checks whether a remote worker is alive or not. This will return true even if // the context doesn't exist on the remote worker. diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc index cb401237d9f4dd..fe0e373664d818 100644 --- a/tensorflow/c/eager/c_api_test_util.cc +++ b/tensorflow/c/eager/c_api_test_util.cc @@ -565,7 +565,8 @@ TFE_Context* CreateContext(const std::string& serialized_server_def, EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); TFE_ContextSetServerDefWithTimeout(ctx, 0, serialized_server_def.data(), serialized_server_def.size(), - init_timeout_in_ms, status); + init_timeout_in_ms, status, + /*clear_existing_contexts=*/false); EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); TFE_DeleteContextOptions(opts); TF_DeleteStatus(status); diff --git a/tensorflow/c/eager/c_api_unified_experimental.cc b/tensorflow/c/eager/c_api_unified_experimental.cc index 53f340ee2aa450..8422459c21b529 100644 --- a/tensorflow/c/eager/c_api_unified_experimental.cc +++ b/tensorflow/c/eager/c_api_unified_experimental.cc @@ -52,7 +52,7 @@ Status SetDefaultTracingEngine(const char* name) { auto entry = GetFactories().find(name); if (entry != GetFactories().end()) { default_factory = GetFactories().find(name)->second; - return OkStatus(); + return absl::OkStatus(); } string msg = absl::StrCat( "No tracing engine factory has been registered with the key '", name, diff --git a/tensorflow/c/eager/c_api_unified_experimental_graph.cc b/tensorflow/c/eager/c_api_unified_experimental_graph.cc index 5e804dca267a0d..0c9d4830850bb7 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_graph.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_graph.cc @@ -71,7 +71,7 @@ class GraphTensor : public TracingTensorHandle { DCHECK_GE(num_dims, -1); TF_RETURN_IF_ERROR(StatusFromTF_Status(&status)); if (num_dims == kUnknownRank) { - return OkStatus(); + return absl::OkStatus(); } std::vector dims(num_dims, kUnknownDim); @@ -81,7 +81,7 @@ class GraphTensor : public TracingTensorHandle { TF_RETURN_IF_ERROR(StatusFromTF_Status(&status)); TF_RETURN_IF_ERROR(tensorflow::TensorShapeUtils::MakeShape(dims, shape)); - return OkStatus(); + return absl::OkStatus(); } tensorflow::FullTypeDef FullType() const override { @@ -119,7 +119,7 @@ class GraphOperation : public TracingOperation { device_name_ = raw_device_name; } op_type_ = op; - return OkStatus(); + return absl::OkStatus(); } Status SetOpName(const char* const op_name) override { if (op_) { @@ -135,7 +135,7 @@ class GraphOperation : public TracingOperation { mutex_lock l(g_->mu); op_.reset(new TF_OperationDescription(g_, op_type_.c_str(), g_->graph.NewName(op_name).c_str())); - return OkStatus(); + return absl::OkStatus(); } const string& Name() const override { return op_type_; } const string& DeviceName() const override { return device_name_; } @@ -143,7 +143,7 @@ class GraphOperation : public TracingOperation { Status SetDeviceName(const char* name) override { // TODO(srbs): Implement this. device_name_ = name; - return OkStatus(); + return absl::OkStatus(); } Status AddInput(AbstractTensorHandle* input) override { @@ -153,7 +153,7 @@ class GraphOperation : public TracingOperation { "Unable to cast input to GraphTensor"); } TF_AddInput(op_.get(), t->output_); - return OkStatus(); + return absl::OkStatus(); } Status AddInputList(absl::Span inputs) override { std::vector tf_outputs(inputs.size()); @@ -166,7 +166,7 @@ class GraphOperation : public TracingOperation { tf_outputs[i] = t->output_; } TF_AddInputList(op_.get(), tf_outputs.data(), tf_outputs.size()); - return OkStatus(); + return absl::OkStatus(); } Status Execute(absl::Span retvals, int* num_retvals) override { @@ -182,26 +182,26 @@ class GraphOperation : public TracingOperation { for (int i = 0; i < *num_retvals; ++i) { retvals[i] = new GraphTensor({operation, i}, g_); } - return OkStatus(); + return absl::OkStatus(); } Status SetAttrString(const char* attr_name, const char* data, size_t length) override { tensorflow::StringPiece s(data, length); op_->node_builder.Attr(attr_name, s); - return OkStatus(); + return absl::OkStatus(); } Status SetAttrInt(const char* attr_name, int64_t value) override { op_->node_builder.Attr(attr_name, static_cast(value)); - return OkStatus(); + return absl::OkStatus(); } Status SetAttrFloat(const char* attr_name, float value) override { op_->node_builder.Attr(attr_name, value); - return OkStatus(); + return absl::OkStatus(); } Status SetAttrBool(const char* attr_name, bool value) override { op_->node_builder.Attr(attr_name, value); - return OkStatus(); + return absl::OkStatus(); } Status SetAttrType(const char* const attr_name, DataType value) override { if (!op_) { @@ -210,7 +210,7 @@ class GraphOperation : public TracingOperation { "op_type and op_name must be specified before specifying attrs."); } op_->node_builder.Attr(attr_name, value); - return OkStatus(); + return absl::OkStatus(); } Status SetAttrShape(const char* attr_name, const int64_t* dims, const int num_dims) override { @@ -220,7 +220,7 @@ class GraphOperation : public TracingOperation { reinterpret_cast(dims), num_dims)); } op_->node_builder.Attr(attr_name, shape); - return OkStatus(); + return absl::OkStatus(); } Status SetAttrFunction(const char* attr_name, const AbstractOperation* value) override { @@ -232,7 +232,7 @@ class GraphOperation : public TracingOperation { tensorflow::NameAttrList func_name; func_name.set_name(string(value, value + length)); op_->node_builder.Attr(attr_name, func_name); - return OkStatus(); + return absl::OkStatus(); } Status SetAttrTensor(const char* attr_name, AbstractTensorInterface* tensor) override { @@ -255,26 +255,26 @@ class GraphOperation : public TracingOperation { } op_->node_builder.Attr(attr_name, v); } - return OkStatus(); + return absl::OkStatus(); } Status SetAttrFloatList(const char* attr_name, const float* values, int num_values) override { op_->node_builder.Attr(attr_name, ArraySlice(values, num_values)); - return OkStatus(); + return absl::OkStatus(); } Status SetAttrIntList(const char* attr_name, const int64_t* values, int num_values) override { op_->node_builder.Attr( attr_name, ArraySlice( reinterpret_cast(values), num_values)); - return OkStatus(); + return absl::OkStatus(); } Status SetAttrTypeList(const char* attr_name, const DataType* values, int num_values) override { op_->node_builder.Attr(attr_name, ArraySlice(values, num_values)); - return OkStatus(); + return absl::OkStatus(); } Status SetAttrBoolList(const char* attr_name, const unsigned char* values, int num_values) override { @@ -285,7 +285,7 @@ class GraphOperation : public TracingOperation { op_->node_builder.Attr(attr_name, ArraySlice(b.get(), num_values)); - return OkStatus(); + return absl::OkStatus(); } Status SetAttrShapeList(const char* attr_name, const int64_t** dims, const int* num_dims, int num_values) override { @@ -300,7 +300,7 @@ class GraphOperation : public TracingOperation { } } op_->node_builder.Attr(attr_name, shapes); - return OkStatus(); + return absl::OkStatus(); } Status SetAttrFunctionList( const char* attr_name, @@ -368,7 +368,7 @@ class GraphContext : public TracingContext { } inputs_.push_back(t->output_); *output = tensorflow::down_cast(outputs[0]); - return OkStatus(); + return absl::OkStatus(); } Status Finalize(OutputList* outputs, AbstractFunction** f) override { @@ -393,7 +393,7 @@ class GraphContext : public TracingContext { TF_DeleteFunction(func); TF_RETURN_IF_ERROR(StatusFromTF_Status(s)); TF_DeleteStatus(s); - return OkStatus(); + return absl::OkStatus(); } Status RegisterFunction(AbstractFunction* func) override { diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index ca033cd2266b01..f8de31aadbaa6f 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -169,57 +169,57 @@ Status TfDataTypeFormDlDataType(const DLDataType& dtype, "Only DLPack bools of bitwidth 8 are supported, got: ", dtype.bits); } *tf_dtype = TF_DataType::TF_BOOL; - return OkStatus(); + return absl::OkStatus(); case DLDataTypeCode::kDLUInt: switch (dtype.bits) { case 8: *tf_dtype = TF_DataType::TF_UINT8; - return OkStatus(); + return absl::OkStatus(); case 16: *tf_dtype = TF_DataType::TF_UINT16; - return OkStatus(); + return absl::OkStatus(); case 32: *tf_dtype = TF_DataType::TF_UINT32; - return OkStatus(); + return absl::OkStatus(); case 64: *tf_dtype = TF_DataType::TF_UINT64; - return OkStatus(); + return absl::OkStatus(); default: return tensorflow::errors::InvalidArgument("Unsupported UInt bits: ", dtype.bits); } - return OkStatus(); + return absl::OkStatus(); case DLDataTypeCode::kDLInt: switch (dtype.bits) { case 8: *tf_dtype = TF_DataType::TF_INT8; - return OkStatus(); + return absl::OkStatus(); case 16: *tf_dtype = TF_DataType::TF_INT16; - return OkStatus(); + return absl::OkStatus(); case 32: *tf_dtype = TF_DataType::TF_INT32; - return OkStatus(); + return absl::OkStatus(); case 64: *tf_dtype = TF_DataType::TF_INT64; - return OkStatus(); + return absl::OkStatus(); default: return tensorflow::errors::InvalidArgument("Unsupported Int bits: ", dtype.bits); } - return OkStatus(); + return absl::OkStatus(); case DLDataTypeCode::kDLFloat: switch (dtype.bits) { case 16: *tf_dtype = TF_DataType::TF_HALF; - return OkStatus(); + return absl::OkStatus(); case 32: *tf_dtype = TF_DataType::TF_FLOAT; - return OkStatus(); + return absl::OkStatus(); case 64: *tf_dtype = TF_DataType::TF_DOUBLE; - return OkStatus(); + return absl::OkStatus(); default: return tensorflow::errors::InvalidArgument("Unsupported Float bits: ", dtype.bits); @@ -229,7 +229,7 @@ Status TfDataTypeFormDlDataType(const DLDataType& dtype, switch (dtype.bits) { case 16: *tf_dtype = TF_DataType::TF_BFLOAT16; - return OkStatus(); + return absl::OkStatus(); default: return tensorflow::errors::InvalidArgument( "Unsupported BFloat bits: ", dtype.bits); @@ -239,10 +239,10 @@ Status TfDataTypeFormDlDataType(const DLDataType& dtype, switch (dtype.bits) { case 64: *tf_dtype = TF_DataType::TF_COMPLEX64; - return OkStatus(); + return absl::OkStatus(); case 128: *tf_dtype = TF_DataType::TF_COMPLEX128; - return OkStatus(); + return absl::OkStatus(); default: return tensorflow::errors::InvalidArgument( "Unsupported Complex bits: ", dtype.bits); diff --git a/tensorflow/c/eager/gradient_checker.cc b/tensorflow/c/eager/gradient_checker.cc index 2042c857e8b211..2fcaee07b37f50 100644 --- a/tensorflow/c/eager/gradient_checker.cc +++ b/tensorflow/c/eager/gradient_checker.cc @@ -65,7 +65,7 @@ Status RunAndMaybeSum(AbstractContext* ctx, Model forward, // If the output is a scalar, then return the scalar output if (num_dims_out == 0) { outputs[0] = model_out.release(); - return OkStatus(); + return absl::OkStatus(); } // Else, reduce sum the output to get a scalar @@ -85,7 +85,7 @@ Status RunAndMaybeSum(AbstractContext* ctx, Model forward, // Reduce sum the output on all dimensions. TF_RETURN_IF_ERROR(ops::Sum(ctx, model_out.get(), sum_dims.get(), &outputs[0], /*keep_dims=*/false, "sum_output")); - return OkStatus(); + return absl::OkStatus(); } // ========================= End Helper Functions============================== @@ -198,7 +198,7 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward, TF_RETURN_IF_ERROR(TestTensorHandleWithDims( ctx, dtheta_approx.data(), theta_dims.data(), num_dims, numerical_grad)); TF_DeleteTensor(theta_tensor); - return OkStatus(); + return absl::OkStatus(); } } // namespace gradients diff --git a/tensorflow/c/eager/gradients.cc b/tensorflow/c/eager/gradients.cc index 6f4ab3016beb63..326a9e8cb829d4 100644 --- a/tensorflow/c/eager/gradients.cc +++ b/tensorflow/c/eager/gradients.cc @@ -47,7 +47,7 @@ Status ZerosLike(AbstractContext* ctx, AbstractTensorHandle* t, TF_RETURN_IF_ERROR( op->Execute(absl::Span(outputs), &num_outputs)); *result = outputs[0]; - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -59,7 +59,7 @@ Status GradientRegistry::Register( return errors::AlreadyExists(error_msg); } registry_.insert({op_name, gradient_function_factory}); - return OkStatus(); + return absl::OkStatus(); } Status GradientRegistry::Lookup( const ForwardOperation& op, @@ -70,7 +70,7 @@ Status GradientRegistry::Lookup( return errors::NotFound(error_msg); } gradient_function->reset(iter->second(op)); - return OkStatus(); + return absl::OkStatus(); } TapeTensor::TapeTensor(AbstractTensorHandle* handle) : handle_(handle) { @@ -200,7 +200,7 @@ Status TapeVSpace::BuildOnesLike(const TapeTensor& t, TF_RETURN_IF_ERROR( op->Execute(absl::Span(outputs), &num_outputs)); *result = outputs[0]; - return OkStatus(); + return absl::OkStatus(); } // Looks up the ID of a Gradient. @@ -292,7 +292,7 @@ Status Tape::ComputeGradient( TF_RETURN_IF_ERROR(GradientTape::ComputeGradient( vspace, target_tensor_ids, source_tensor_ids, sources_that_are_targets, output_gradients, result, /*build_default_zeros_grads*/ false)); - return OkStatus(); + return absl::OkStatus(); } // Helper functions which delegate to `AbstractOperation`, update @@ -309,7 +309,7 @@ Status AddInput(AbstractOperation* op_, AbstractTensorHandle* input, ForwardOperation* forward_op_) { TF_RETURN_IF_ERROR(op_->AddInput(input)); forward_op_->inputs.push_back(input); - return OkStatus(); + return absl::OkStatus(); } Status AddInputList(AbstractOperation* op_, absl::Span inputs, @@ -318,7 +318,7 @@ Status AddInputList(AbstractOperation* op_, for (auto input : inputs) { forward_op_->inputs.push_back(input); } - return OkStatus(); + return absl::OkStatus(); } Status SetAttrString(AbstractOperation* op_, const char* attr_name, @@ -482,7 +482,7 @@ Status Execute(AbstractOperation* op_, AbstractContext* ctx, TF_RETURN_IF_ERROR(registry.Lookup(*forward_op_, &gradient_fn)); tape->RecordOperation(forward_op_->inputs, retvals, gradient_fn.release(), op_->Name()); - return OkStatus(); + return absl::OkStatus(); } } // namespace internal diff --git a/tensorflow/c/eager/gradients_test.cc b/tensorflow/c/eager/gradients_test.cc index a345240e8c3e4f..9df16f10290d0b 100644 --- a/tensorflow/c/eager/gradients_test.cc +++ b/tensorflow/c/eager/gradients_test.cc @@ -59,7 +59,7 @@ class CppGradients Status RegisterGradients(GradientRegistry* registry) { TF_RETURN_IF_ERROR(RegisterNotDifferentiable(registry, "CheckNumerics")); - return OkStatus(); + return absl::OkStatus(); } TEST_P(CppGradients, TestSetAttrString) { diff --git a/tensorflow/c/eager/graph_function.cc b/tensorflow/c/eager/graph_function.cc index 3f4430bb614ea1..bf45feb34afb0f 100644 --- a/tensorflow/c/eager/graph_function.cc +++ b/tensorflow/c/eager/graph_function.cc @@ -22,7 +22,7 @@ GraphFunction::GraphFunction(FunctionDef fdef) GraphFunction::~GraphFunction() {} Status GraphFunction::GetFunctionDef(FunctionDef** fdef) { *fdef = &fdef_; - return OkStatus(); + return absl::OkStatus(); } } // namespace graph } // namespace tracing diff --git a/tensorflow/c/eager/immediate_execution_distributed_manager.h b/tensorflow/c/eager/immediate_execution_distributed_manager.h index cb927c36929b75..09011464a8b3d7 100644 --- a/tensorflow/c/eager/immediate_execution_distributed_manager.h +++ b/tensorflow/c/eager/immediate_execution_distributed_manager.h @@ -43,8 +43,8 @@ class ImmediateExecutionDistributedManager { // `keep_alive_secs` of inactivity. virtual Status SetOrUpdateServerDef(const ServerDef& server_def, bool reset_context, int keep_alive_secs, - int64_t init_timeout_in_ms, - int retries) = 0; + int64_t init_timeout_in_ms, int retries, + bool clear_existing_contexts = false) = 0; // Initializes context for the local worker and no contexts will be created // for remote workers. Currently this only works for resetting context. diff --git a/tensorflow/c/eager/immediate_execution_tensor_handle.cc b/tensorflow/c/eager/immediate_execution_tensor_handle.cc index d8cb9e165495c1..c99a270f0cb804 100644 --- a/tensorflow/c/eager/immediate_execution_tensor_handle.cc +++ b/tensorflow/c/eager/immediate_execution_tensor_handle.cc @@ -55,7 +55,7 @@ Status ImmediateExecutionTensorHandle::SummarizeValue( return status; } summary = resolved->SummarizeValue(); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc index 0522ad3b73072f..e5b1ee97a2e802 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc @@ -604,7 +604,7 @@ Status ParallelTensor::Shape(const std::vector** shape) const { shape_ = std::vector(dim_sizes.begin(), dim_sizes.end()); } *shape = &*shape_; - return OkStatus(); + return absl::OkStatus(); } Status ParallelTensor::SummarizeValue(std::string& summary) { @@ -624,7 +624,7 @@ Status ParallelTensor::SummarizeValue(std::string& summary) { "\": ", component_summary); } summary += "}"; - return OkStatus(); + return absl::OkStatus(); } } // namespace parallel_device diff --git a/tensorflow/c/experimental/stream_executor/BUILD b/tensorflow/c/experimental/stream_executor/BUILD index 48bdcb2c9a26bf..c0b62760cd4207 100644 --- a/tensorflow/c/experimental/stream_executor/BUILD +++ b/tensorflow/c/experimental/stream_executor/BUILD @@ -80,6 +80,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/protobuf:error_codes_proto_impl_cc", + "@local_tsl//tsl/platform:statusor", "@local_xla//xla/stream_executor", ], ) diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.cc b/tensorflow/c/experimental/stream_executor/stream_executor.cc index 12391143a4d9e0..0a15db6c0a1b94 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor.cc @@ -30,8 +30,8 @@ limitations under the License. #include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h" #include "tensorflow/c/tf_status_helper.h" #include "xla/stream_executor/executor_cache.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "tensorflow/core/common_runtime/device/device_utils.h" @@ -573,11 +573,6 @@ class CStreamExecutor : public internal::StreamExecutorInterface { return std::unique_ptr( new CEvent(&device_, stream_executor_)); } - std::unique_ptr CreateKernelImplementation() - override { - LOG(FATAL) - << "CreateKernelImplementation is not supported by pluggable device."; - } std::unique_ptr GetStreamImplementation() override { return std::unique_ptr( @@ -735,8 +730,8 @@ tsl::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn, std::move(platform), params.destroy_platform, std::move(platform_fns), params.destroy_platform_fns, std::move(device_fns), std::move(se), std::move(timer_fns))); - TF_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform( - std::move(cplatform))); + TF_CHECK_OK( + stream_executor::PlatformManager::RegisterPlatform(std::move(cplatform))); // TODO(annarev): Return `use_bfc_allocator` value in some way so that it is // available in `PluggableDeviceProcessState` once the latter is checked in. return ::tensorflow::OkStatus(); diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc index 0f3e2e76aa4ebe..17a9371f4e3fef 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc @@ -20,12 +20,13 @@ limitations under the License. #include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h" #include "tensorflow/c/experimental/stream_executor/stream_executor_test_util.h" #include "xla/stream_executor/event.h" -#include "xla/stream_executor/multi_platform_manager.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/error_codes.pb.h" +#include "tsl/platform/statusor.h" namespace stream_executor { namespace { @@ -42,7 +43,7 @@ TEST(StreamExecutor, SuccessfulRegistration) { InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name); TF_ASSERT_OK(status); tsl::StatusOr maybe_platform = - MultiPlatformManager::PlatformWithName("MY_DEVICE"); + PlatformManager::PlatformWithName("MY_DEVICE"); TF_ASSERT_OK(maybe_platform.status()); Platform* platform = std::move(maybe_platform).value(); ASSERT_EQ(platform->Name(), test_util::kDeviceName); @@ -200,11 +201,11 @@ TEST_F(StreamExecutorTest, HostMemoryAllocate) { }; StreamExecutor* executor = GetExecutor(0); ASSERT_FALSE(allocate_called); - void* mem = executor->HostMemoryAllocate(8); - ASSERT_NE(mem, nullptr); + TF_ASSERT_OK_AND_ASSIGN(auto mem, executor->HostMemoryAllocate(8)); + ASSERT_NE(mem->opaque(), nullptr); ASSERT_TRUE(allocate_called); ASSERT_FALSE(deallocate_called); - executor->HostMemoryDeallocate(mem); + mem.reset(); ASSERT_TRUE(deallocate_called); } @@ -300,11 +301,11 @@ TEST_F(StreamExecutorTest, CreateStreamDependency) { StreamExecutor* executor = GetExecutor(0); Stream dependent(executor); - dependent.Init(); + TF_ASSERT_OK(dependent.Initialize()); Stream other(executor); - other.Init(); + TF_ASSERT_OK(other.Initialize()); ASSERT_FALSE(create_stream_dependency_called); - dependent.ThenWaitFor(&other); + TF_ASSERT_OK(dependent.WaitFor(&other)); ASSERT_TRUE(create_stream_dependency_called); } @@ -321,7 +322,7 @@ TEST_F(StreamExecutorTest, StreamStatus) { StreamExecutor* executor = GetExecutor(0); Stream stream(executor); - stream.Init(); + TF_ASSERT_OK(stream.Initialize()); ASSERT_TRUE(stream.ok()); TF_ASSERT_OK(stream.RefreshStatus()); status_ok = false; @@ -412,12 +413,12 @@ TEST_F(StreamExecutorTest, RecordAndWaitForEvent) { Event event(executor); event.Init(); Stream stream(executor); - stream.Init(); + TF_ASSERT_OK(stream.Initialize()); ASSERT_FALSE(record_called); - stream.ThenRecordEvent(&event); + TF_ASSERT_OK(stream.RecordEvent(&event)); ASSERT_TRUE(record_called); ASSERT_FALSE(wait_called); - stream.ThenWaitFor(&event); + TF_ASSERT_OK(stream.WaitFor(&event)); ASSERT_TRUE(wait_called); } @@ -440,14 +441,13 @@ TEST_F(StreamExecutorTest, MemcpyToHost) { StreamExecutor* executor = GetExecutor(0); Stream stream(executor); - stream.Init(); + TF_ASSERT_OK(stream.Initialize()); size_t size = sizeof(int); int src_data = 34; int dst_data = 2; DeviceMemoryBase device_src(&src_data, size); - Stream& stream_ref = stream.ThenMemcpy(&dst_data, device_src, size); + TF_ASSERT_OK(stream.Memcpy(&dst_data, device_src, size)); ASSERT_EQ(dst_data, 34); - ASSERT_EQ(stream_ref.implementation(), stream.implementation()); } TEST_F(StreamExecutorTest, MemcpyFromHost) { @@ -461,12 +461,12 @@ TEST_F(StreamExecutorTest, MemcpyFromHost) { StreamExecutor* executor = GetExecutor(0); Stream stream(executor); - stream.Init(); + TF_ASSERT_OK(stream.Initialize()); size_t size = sizeof(int); int src_data = 18; int dst_data = 0; DeviceMemoryBase device_dst(&dst_data, size); - stream.ThenMemcpy(&device_dst, &src_data, size); + TF_ASSERT_OK(stream.Memcpy(&device_dst, &src_data, size)); ASSERT_EQ(dst_data, 18); } @@ -481,13 +481,13 @@ TEST_F(StreamExecutorTest, MemcpyDeviceToDevice) { StreamExecutor* executor = GetExecutor(0); Stream stream(executor); - stream.Init(); + TF_ASSERT_OK(stream.Initialize()); size_t size = sizeof(int); int src_data = 18; int dst_data = 0; DeviceMemoryBase device_dst(&dst_data, size); DeviceMemoryBase device_src(&src_data, size); - stream.ThenMemcpy(&device_dst, device_src, size); + TF_ASSERT_OK(stream.Memcpy(&device_dst, device_src, size)); ASSERT_EQ(dst_data, 18); } @@ -562,7 +562,7 @@ TEST_F(StreamExecutorTest, BlockHostForEvent) { StreamExecutor* executor = GetExecutor(0); Stream stream(executor); - stream.Init(); + TF_ASSERT_OK(stream.Initialize()); ASSERT_FALSE(block_host_for_event_called); TF_ASSERT_OK(stream.BlockHostUntilDone()); ASSERT_TRUE(block_host_for_event_called); @@ -587,7 +587,7 @@ TEST_F(StreamExecutorTest, BlockHostUntilDone) { StreamExecutor* executor = GetExecutor(0); Stream stream(executor); - stream.Init(); + TF_ASSERT_OK(stream.Initialize()); ASSERT_FALSE(block_host_until_done_called); TF_ASSERT_OK(stream.BlockHostUntilDone()); ASSERT_TRUE(block_host_until_done_called); @@ -619,12 +619,11 @@ TEST_F(StreamExecutorTest, HostCallbackOk) { }; StreamExecutor* executor = GetExecutor(0); Stream stream(executor); - stream.Init(); + TF_ASSERT_OK(stream.Initialize()); std::function callback = []() -> absl::Status { return absl::OkStatus(); }; - stream.ThenDoHostCallbackWithStatus(callback); - ASSERT_TRUE(stream.ok()); + TF_ASSERT_OK(stream.DoHostCallbackWithStatus(callback)); } TEST_F(StreamExecutorTest, HostCallbackError) { @@ -639,12 +638,11 @@ TEST_F(StreamExecutorTest, HostCallbackError) { }; StreamExecutor* executor = GetExecutor(0); Stream stream(executor); - stream.Init(); + TF_ASSERT_OK(stream.Initialize()); std::function callback = []() -> tsl::Status { return tsl::errors::Unimplemented("Unimplemented"); }; - stream.ThenDoHostCallbackWithStatus(callback); - ASSERT_FALSE(stream.ok()); + ASSERT_FALSE(stream.DoHostCallbackWithStatus(callback).ok()); } TEST_F(StreamExecutorTest, DeviceDescription) { @@ -718,13 +716,12 @@ TEST_F(StreamExecutorTest, MemZero) { StreamExecutor* executor = GetExecutor(0); Stream stream(executor); - stream.Init(); + TF_ASSERT_OK(stream.Initialize()); size_t size = sizeof(int); int data = 2; DeviceMemoryBase device_data(&data, size); - Stream& stream_ref = stream.ThenMemZero(&device_data, size); + TF_ASSERT_OK(stream.MemZero(&device_data, size)); ASSERT_EQ(data, 0); - ASSERT_EQ(stream_ref.implementation(), stream.implementation()); } TEST_F(StreamExecutorTest, Memset32) { @@ -749,13 +746,12 @@ TEST_F(StreamExecutorTest, Memset32) { StreamExecutor* executor = GetExecutor(0); Stream stream(executor); - stream.Init(); + TF_ASSERT_OK(stream.Initialize()); size_t size = sizeof(int); int data = 2; DeviceMemoryBase device_data(&data, size); - Stream& stream_ref = stream.ThenMemset32(&device_data, 18, size); + TF_ASSERT_OK(stream.Memset32(&device_data, 18, size)); ASSERT_EQ(data, 18); - ASSERT_EQ(stream_ref.implementation(), stream.implementation()); } } // namespace diff --git a/tensorflow/c/kernels/ops/bitcast.cc b/tensorflow/c/kernels/ops/bitcast.cc index 6bb658eb25fd7a..d56b2897c89c3d 100644 --- a/tensorflow/c/kernels/ops/bitcast.cc +++ b/tensorflow/c/kernels/ops/bitcast.cc @@ -17,6 +17,8 @@ limitations under the License. #include #include "tensorflow/c/ops.h" +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_status.h" #include "tensorflow/core/framework/registration/registration.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 935b37b37aa5c0..4d8d5dfa11da87 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -109,7 +109,7 @@ cc_library( "//tensorflow/core:direct_session", "//tensorflow/core:all_kernels", ] + if_google( - ["@local_tsl//tsl/platform/default/build_config:tensorflow_platform_specific"], + ["//tensorflow/core/platform/default/build_config:tensorflow_platform_specific"], [], )) + if_not_mobile([ "//tensorflow/core:core_cpu", diff --git a/tensorflow/cc/saved_model/fingerprinting_utils.cc b/tensorflow/cc/saved_model/fingerprinting_utils.cc index f51a7c02bd0f98..b5248562c3f490 100644 --- a/tensorflow/cc/saved_model/fingerprinting_utils.cc +++ b/tensorflow/cc/saved_model/fingerprinting_utils.cc @@ -53,11 +53,11 @@ limitations under the License. namespace tensorflow::saved_model::fingerprinting { -using ::proto_splitter::ChunkedField; -using ::proto_splitter::ChunkedMessage; -using ::proto_splitter::ChunkInfo; -using ::proto_splitter::ChunkMetadata; -using ::proto_splitter::FieldIndex; +using ::tensorflow::proto_splitter::ChunkedField; +using ::tensorflow::proto_splitter::ChunkedMessage; +using ::tensorflow::proto_splitter::ChunkInfo; +using ::tensorflow::proto_splitter::ChunkMetadata; +using ::tensorflow::proto_splitter::FieldIndex; using tools::proto_splitter::Field; using tools::proto_splitter::FieldType; using tools::proto_splitter::GetChunkMetadata; @@ -83,54 +83,61 @@ absl::StatusOr fieldTagMatches(const RepeatedPtrField& a, int matches = 0; for (int i = 0; i == matches && i < a.size() && i < b.size(); i++) { switch (b[i].kind_case()) { - case ::proto_splitter::FieldIndex::KindCase::kField: + case ::tensorflow::proto_splitter::FieldIndex::KindCase::kField: if (a.at(i).has_field() && a.at(i).field() == b.at(i).field()) { matches += 1; } break; - case ::proto_splitter::FieldIndex::KindCase::kIndex: + case ::tensorflow::proto_splitter::FieldIndex::KindCase::kIndex: if (a.at(i).has_index() && a.at(i).index() == b.at(i).index()) { matches += 1; } break; - case ::proto_splitter::FieldIndex::KindCase::kMapKey: + case ::tensorflow::proto_splitter::FieldIndex::KindCase::kMapKey: if (a.at(i).has_map_key()) { - const ::proto_splitter::FieldIndex_MapKey& key = b.at(i).map_key(); - const ::proto_splitter::FieldIndex_MapKey& chunked_key = + const ::tensorflow::proto_splitter::FieldIndex_MapKey& key = + b.at(i).map_key(); + const ::tensorflow::proto_splitter::FieldIndex_MapKey& chunked_key = a.at(i).map_key(); switch (key.type_case()) { - case ::proto_splitter::FieldIndex::MapKey::TypeCase::kS: + case ::tensorflow::proto_splitter::FieldIndex::MapKey::TypeCase::kS: if (chunked_key.has_s() && chunked_key.s() == key.s()) { matches += 1; } break; - case ::proto_splitter::FieldIndex::MapKey::TypeCase::kBoolean: + case ::tensorflow::proto_splitter::FieldIndex::MapKey::TypeCase:: + kBoolean: if (chunked_key.has_boolean() && chunked_key.boolean() == key.boolean()) { matches += 1; } break; - case ::proto_splitter::FieldIndex::MapKey::TypeCase::kUi32: + case ::tensorflow::proto_splitter::FieldIndex::MapKey::TypeCase:: + kUi32: if (chunked_key.has_ui32() && chunked_key.ui32() == key.ui32()) { matches += 1; } break; - case ::proto_splitter::FieldIndex::MapKey::TypeCase::kUi64: + case ::tensorflow::proto_splitter::FieldIndex::MapKey::TypeCase:: + kUi64: if (chunked_key.has_ui64() && chunked_key.ui64() == key.ui64()) { matches += 1; } break; - case ::proto_splitter::FieldIndex::MapKey::TypeCase::kI32: + case ::tensorflow::proto_splitter::FieldIndex::MapKey::TypeCase:: + kI32: if (chunked_key.has_i32() && chunked_key.i32() == key.i32()) { matches += 1; } break; - case ::proto_splitter::FieldIndex::MapKey::TypeCase::kI64: + case ::tensorflow::proto_splitter::FieldIndex::MapKey::TypeCase:: + kI64: if (chunked_key.has_i64() && chunked_key.i64() == key.i64()) { matches += 1; } break; - case ::proto_splitter::FieldIndex::MapKey::TypeCase::TYPE_NOT_SET: + case ::tensorflow::proto_splitter::FieldIndex::MapKey::TypeCase:: + TYPE_NOT_SET: default: return absl::FailedPreconditionError( "Encountered unknown field_tag.map_key type."); @@ -146,12 +153,13 @@ absl::StatusOr fieldTagMatches(const RepeatedPtrField& a, return matches; } -absl::StatusOr<::proto_splitter::ChunkedMessage> PruneChunkedMessage( - const ::proto_splitter::ChunkedMessage& chunked_message, +absl::StatusOr<::tensorflow::proto_splitter::ChunkedMessage> +PruneChunkedMessage( + const ::tensorflow::proto_splitter::ChunkedMessage& chunked_message, riegeli::RecordReader>& reader, std::vector chunks_info, std::vector> target_fields_list) { - ::proto_splitter::ChunkedMessage pruned_chunked_message; + ::tensorflow::proto_splitter::ChunkedMessage pruned_chunked_message; if (chunked_message.has_chunk_index()) { pruned_chunked_message.set_chunk_index(chunked_message.chunk_index()); } diff --git a/tensorflow/cc/saved_model/fingerprinting_utils.h b/tensorflow/cc/saved_model/fingerprinting_utils.h index f63203a231c47e..306abec8acdfd3 100644 --- a/tensorflow/cc/saved_model/fingerprinting_utils.h +++ b/tensorflow/cc/saved_model/fingerprinting_utils.h @@ -44,17 +44,18 @@ using ::tensorflow::protobuf::RepeatedPtrField; // subsequence.) // Example: `a = {4, 2}`, `b = {4, 2, 1, 3}`, `fieldTagMatches(a, b) == 2` absl::StatusOr fieldTagMatches( - const RepeatedPtrField<::proto_splitter::FieldIndex>& a, - const RepeatedPtrField<::proto_splitter::FieldIndex>& b); + const RepeatedPtrField<::tensorflow::proto_splitter::FieldIndex>& a, + const RepeatedPtrField<::tensorflow::proto_splitter::FieldIndex>& b); // Pull out the relevant data within `chunked_message`. A `chunked_field` is // relevant if its `field_tags` are an initial subsequence any of the // `target_fields` in the provided `target_fields_list`. -absl::StatusOr<::proto_splitter::ChunkedMessage> PruneChunkedMessage( - const ::proto_splitter::ChunkedMessage& chunked_message, +absl::StatusOr<::tensorflow::proto_splitter::ChunkedMessage> +PruneChunkedMessage( + const ::tensorflow::proto_splitter::ChunkedMessage& chunked_message, riegeli::RecordReader>& reader, - std::vector<::proto_splitter::ChunkInfo> chunks_info, - std::vector> + std::vector<::tensorflow::proto_splitter::ChunkInfo> chunks_info, + std::vector> target_fields_list); // Deterministically serializes the proto `message`. @@ -63,20 +64,23 @@ std::string SerializeProto(const Message& message); // Uses metadata contained in `chunked_message` to hash fields within the // data accessed by the `reader` using `chunks_info`. absl::StatusOr HashFields( - const ::proto_splitter::ChunkedMessage& chunked_message, + const ::tensorflow::proto_splitter::ChunkedMessage& chunked_message, riegeli::RecordReader>& reader, - const std::vector<::proto_splitter::ChunkInfo>& chunks_info, - const RepeatedPtrField<::proto_splitter::FieldIndex>& field_tags, + const std::vector<::tensorflow::proto_splitter::ChunkInfo>& chunks_info, + const RepeatedPtrField<::tensorflow::proto_splitter::FieldIndex>& + field_tags, Message* merged_message); -// Gets the field tags for `graph_def`. -inline RepeatedPtrField<::proto_splitter::FieldIndex> GraphDefFieldTags(); +// Gets the field tags for `graph_def`.::tensorflow +inline RepeatedPtrField<::tensorflow::proto_splitter::FieldIndex> +GraphDefFieldTags(); // Gets the field tags for `signature_def`. -inline RepeatedPtrField<::proto_splitter::FieldIndex> SignatureDefFieldTags(); +inline RepeatedPtrField<::tensorflow::proto_splitter::FieldIndex> +SignatureDefFieldTags(); // Gets the field tags for `saved_object_graph`. -inline RepeatedPtrField<::proto_splitter::FieldIndex> +inline RepeatedPtrField<::tensorflow::proto_splitter::FieldIndex> SavedObjectGraphFieldTags(); // Returns a `SavedModel` containing only fields (up to those) specified by @@ -85,36 +89,38 @@ SavedObjectGraphFieldTags(); absl::StatusOr PrunedSavedModel( absl::string_view export_dir, riegeli::RecordReader>& reader, - const std::vector<::proto_splitter::ChunkInfo>& chunks_info, - ::proto_splitter::ChunkMetadata& chunk_metadata); + const std::vector<::tensorflow::proto_splitter::ChunkInfo>& chunks_info, + ::tensorflow::proto_splitter::ChunkMetadata& chunk_metadata); // Hashes the contents of `message` specified by `field_tags`. absl::StatusOr HashMessage( - Message* message, const ::proto_splitter::ChunkedMessage& chunked_message, + Message* message, + const ::tensorflow::proto_splitter::ChunkedMessage& chunked_message, riegeli::RecordReader>& reader, - const std::vector<::proto_splitter::ChunkInfo>& chunks_info, - const RepeatedPtrField<::proto_splitter::FieldIndex>& field_tags); + const std::vector<::tensorflow::proto_splitter::ChunkInfo>& chunks_info, + const RepeatedPtrField<::tensorflow::proto_splitter::FieldIndex>& + field_tags); // Hashes the contents of `graph_def`. absl::StatusOr HashGraphDef( tensorflow::GraphDef* graph_def, - const ::proto_splitter::ChunkedMessage& chunked_message, + const ::tensorflow::proto_splitter::ChunkedMessage& chunked_message, riegeli::RecordReader>& reader, - const std::vector<::proto_splitter::ChunkInfo>& chunks_info); + const std::vector<::tensorflow::proto_splitter::ChunkInfo>& chunks_info); // Hashes the contents of `signature_def`. absl::StatusOr HashSignatureDef( const Map& signature_def_map, - const ::proto_splitter::ChunkedMessage& chunked_message, + const ::tensorflow::proto_splitter::ChunkedMessage& chunked_message, riegeli::RecordReader>& reader, - const std::vector<::proto_splitter::ChunkInfo>& chunks_info); + const std::vector<::tensorflow::proto_splitter::ChunkInfo>& chunks_info); // Hashes the contents of `saved_object_graph`. absl::StatusOr HashSavedObjectGraph( tensorflow::SavedObjectGraph* saved_object_graph, - const ::proto_splitter::ChunkedMessage& chunked_message, + const ::tensorflow::proto_splitter::ChunkedMessage& chunked_message, riegeli::RecordReader>& reader, - const std::vector<::proto_splitter::ChunkInfo>& chunks_info); + const std::vector<::tensorflow::proto_splitter::ChunkInfo>& chunks_info); } // namespace fingerprinting_utils_internal diff --git a/tensorflow/cc/saved_model/fingerprinting_utils_test.cc b/tensorflow/cc/saved_model/fingerprinting_utils_test.cc index f535ed65068a5c..1f6b0e150850e4 100644 --- a/tensorflow/cc/saved_model/fingerprinting_utils_test.cc +++ b/tensorflow/cc/saved_model/fingerprinting_utils_test.cc @@ -51,12 +51,12 @@ using fingerprinting_utils_internal::HashSavedObjectGraph; using fingerprinting_utils_internal::HashSignatureDef; using fingerprinting_utils_internal::PruneChunkedMessage; using fingerprinting_utils_internal::SerializeProto; -using ::proto_splitter::ChunkedField; -using ::proto_splitter::ChunkedMessage; -using ::proto_splitter::ChunkInfo; -using ::proto_splitter::ChunkMetadata; -using ::proto_splitter::FieldIndex; -using ::proto_splitter_testdata::ManyFields; +using ::tensorflow::proto_splitter::ChunkedField; +using ::tensorflow::proto_splitter::ChunkedMessage; +using ::tensorflow::proto_splitter::ChunkInfo; +using ::tensorflow::proto_splitter::ChunkMetadata; +using ::tensorflow::proto_splitter::FieldIndex; +using ::tensorflow::proto_splitter_testdata::ManyFields; using ::tensorflow::protobuf::Message; using ::tensorflow::protobuf::RepeatedPtrField; using ::tensorflow::protobuf::TextFormat; @@ -82,8 +82,8 @@ absl::Status ParseTextProto(absl::string_view text_proto, absl::StrCat("Could not parse text proto: ", text_proto)); } -absl::StatusOr> ExtractFieldTags( - absl::string_view chunked_field_text_proto) { +absl::StatusOr> +ExtractFieldTags(absl::string_view chunked_field_text_proto) { ChunkedField chunked_field; TF_RETURN_IF_ERROR(ParseTextProto(chunked_field_text_proto, &chunked_field)); return chunked_field.field_tag(); diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index a245bf59a1f187..ae63fdab2fa32c 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -104,7 +104,7 @@ static Status ValidateNode(const NodeDef& node) { "Saved model contains node \"", node.name(), "\" which is a constant tensor but no value has been provided")); } - return OkStatus(); + return absl::OkStatus(); } static Status ValidateFunctionNotRecursive(const FunctionDef& function) { @@ -117,7 +117,7 @@ static Status ValidateFunctionNotRecursive(const FunctionDef& function) { } } - return OkStatus(); + return absl::OkStatus(); } static Status ValidateSavedTensors(const GraphDef& graph_def) { @@ -137,7 +137,7 @@ static Status ValidateSavedTensors(const GraphDef& graph_def) { } } - return OkStatus(); + return absl::OkStatus(); } Tensor CreateStringTensor(const string& value) { @@ -223,7 +223,7 @@ Status RunInitOp(const RunOptions& run_options, const string& export_dir, return RunOnce(run_options, inputs, {}, {init_op_name}, nullptr /* outputs */, &run_metadata, session); } - return OkStatus(); + return absl::OkStatus(); } Status RunRestore(const RunOptions& run_options, const string& export_dir, @@ -247,7 +247,7 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir, LOG(INFO) << "The specified SavedModel has no variables; no checkpoints " "were restored. File does not exist: " << variables_index_path; - return OkStatus(); + return absl::OkStatus(); } const string variables_path = io::JoinPath(variables_directory, kSavedModelVariablesFilename); @@ -293,7 +293,7 @@ Status LoadSavedModelInternal(const SessionOptions& session_options, session_options, bundle->meta_graph_def, &bundle->session)); TF_RETURN_IF_ERROR(RestoreSession(run_options, bundle->meta_graph_def, export_dir, &bundle->session)); - return OkStatus(); + return absl::OkStatus(); } Status LoadSavedModel(const SessionOptions& session_options, @@ -469,7 +469,7 @@ Status RestoreSession(const RunOptions& run_options, // Record wall time spent in init op. load_latency_by_stage->GetCell(export_dir, "init_graph") ->Add(GetLatencyMicroseconds(graph_init_start_microseconds)); - return OkStatus(); + return absl::OkStatus(); } Status LoadSavedModel(const SessionOptions& session_options, @@ -494,7 +494,7 @@ Status LoadSavedModel(const SessionOptions& session_options, *bundle = SavedModelBundleLite( std::make_unique(std::move(legacy_bundle.session)), std::move(*legacy_bundle.meta_graph_def.mutable_signature_def())); - return OkStatus(); + return absl::OkStatus(); } bool MaybeSavedModelDirectory(const string& export_dir) { diff --git a/tensorflow/cc/saved_model/loader_util.cc b/tensorflow/cc/saved_model/loader_util.cc index e17f2ed4abb690..3a984bf31b3cd9 100644 --- a/tensorflow/cc/saved_model/loader_util.cc +++ b/tensorflow/cc/saved_model/loader_util.cc @@ -42,7 +42,7 @@ Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def, kSavedModelInitOpSignatureKey); } *init_op_name = sig_def_outputs_it->second.name(); - return OkStatus(); + return absl::OkStatus(); } const auto& collection_def_map = meta_graph_def.collection_def(); @@ -62,7 +62,7 @@ Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def, } *init_op_name = init_op_it->second.node_list().value(0); } - return OkStatus(); + return absl::OkStatus(); } Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def, @@ -73,13 +73,13 @@ Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def, for (const auto& asset : meta_graph_def.asset_file_def()) { asset_file_defs->push_back(asset); } - return OkStatus(); + return absl::OkStatus(); } // Fall back to read from collection to be backward compatible with v1. const auto& collection_def_map = meta_graph_def.collection_def(); const auto assets_it = collection_def_map.find(kSavedModelAssetsKey); if (assets_it == collection_def_map.end()) { - return OkStatus(); + return absl::OkStatus(); } const auto& any_assets = assets_it->second.any_list().value(); for (const auto& any_asset : any_assets) { @@ -88,7 +88,7 @@ Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def, ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef")); asset_file_defs->push_back(asset_file_def); } - return OkStatus(); + return absl::OkStatus(); } } // namespace internal diff --git a/tensorflow/cc/saved_model/reader.cc b/tensorflow/cc/saved_model/reader.cc index 5563439f290391..b90f84438b3abf 100644 --- a/tensorflow/cc/saved_model/reader.cc +++ b/tensorflow/cc/saved_model/reader.cc @@ -63,7 +63,7 @@ Status FindMetaGraphDef(const std::unordered_set& tags, if (!port::kLittleEndian) { TF_RETURN_IF_ERROR(ByteSwapTensorContentInMetaGraphDef(meta_graph_def)); } - return OkStatus(); + return absl::OkStatus(); } } return Status( @@ -137,7 +137,7 @@ Status ReadMetaGraphDefFromSavedModel(const string& export_dir, TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto)); TF_RETURN_IF_ERROR( FindMetaGraphDef(tags, &saved_model_proto, meta_graph_def)); - return OkStatus(); + return absl::OkStatus(); } Status ReadSavedModelDebugInfoIfPresent( @@ -156,7 +156,7 @@ Status ReadSavedModelDebugInfoIfPresent( ReadBinaryProto(Env::Default(), debug_info_pb_path, &debug_info)); *debug_info_proto = std::make_unique(std::move(debug_info)); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/cc/saved_model/util.cc b/tensorflow/cc/saved_model/util.cc index 3e0b1eb27026bb..b474f1ef3ed0f3 100644 --- a/tensorflow/cc/saved_model/util.cc +++ b/tensorflow/cc/saved_model/util.cc @@ -86,7 +86,7 @@ Status GetInputValues( absl::StrJoin(seen_request_inputs, ","), ", request input: ", absl::StrJoin(GetMapKeys(request_inputs), ","))); } - return OkStatus(); + return absl::OkStatus(); } } // namespace saved_model diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 8eb99e2b46fa09..5143188651791f 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -1,7 +1,7 @@ -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") -load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") load("//tensorflow:tensorflow.bzl", "if_google", "if_oss", "tf_cc_binary", "tf_cc_test") +load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") load("//tensorflow/core/platform:build_config.bzl", "if_llvm_aarch64_available", "if_llvm_system_z_available") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -18,7 +18,7 @@ cc_library( deps = if_oss([ "//tensorflow/core:test_main", ]) + if_google([ - "@local_tsl//tsl/platform/default/build_config:test_main", + "//tensorflow/core/platform/default/build_config:test_main", ]), ) @@ -84,6 +84,7 @@ cc_library( "@local_xla//xla/service:compiler", "@local_xla//xla/service/cpu:buffer_info_util", "@local_xla//xla/service/cpu:cpu_compiler", + "@local_xla//xla/stream_executor:platform_manager", ], ) diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index 6ea99f2fc4d7c7..f7273c091a4d37 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -230,7 +230,7 @@ TEST(CodegenTest, Golden) { /*result_param_number=*/1), BufferInfo::MakeResultParameter(/*size=*/5 * 4, /*result_param_number=*/2)}, - 0, {})); + 0, nullptr, {})); compile_result.program_shape = xla::ShapeUtil::MakeProgramShape( { diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index 6bcea32f897e85..a0f0c20c1b6ad8 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -33,6 +33,7 @@ limitations under the License. #include "xla/client/xla_computation.h" #include "xla/service/cpu/cpu_compiler.h" #include "xla/statusor.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/graph.pb.h" @@ -110,7 +111,7 @@ Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config, // computation. // TODO(toddw): Should we let the user pick the XLA cpu vs. gpu client? se::Platform* cpu_platform = - se::MultiPlatformManager::PlatformWithName("Host").value(); + se::PlatformManager::PlatformWithName("Host").value(); xla::CompileOnlyClient* client = xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform).value(); xla::XlaComputation computation; diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 9e3fd27b8f6e86..932a44d2762e6b 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -131,6 +131,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@local_xla//xla/stream_executor:platform_manager", ] + if_libtpu( if_false = [ "@local_xla//xla/service:cpu_plugin", # buildcleaner: keep @@ -160,6 +161,7 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@local_xla//xla/stream_executor:platform_manager", "@local_xla//xla/stream_executor/gpu:gpu_init", ] + if_libtpu( if_false = [ @@ -364,6 +366,7 @@ cc_library( "@local_xla//xla/pjrt:tf_pjrt_client", "@local_xla//xla/service:compiler", "@local_xla//xla/service:executable", + "@local_xla//xla/stream_executor:platform_manager", ], alwayslink = 1, ) @@ -598,6 +601,7 @@ cc_library( "@local_xla//xla/pjrt:tracked_device_buffer", "@local_xla//xla/service:shaped_buffer", "@local_xla//xla/stream_executor:device_memory_allocator", + "@local_xla//xla/stream_executor:platform_manager", ], ) @@ -1734,7 +1738,7 @@ tf_cuda_only_cc_test( "@local_xla//xla:shape_util", "@local_xla//xla/stream_executor", "@local_xla//xla/stream_executor:device_memory", - "@local_xla//xla/stream_executor:multi_platform_manager", + "@local_xla//xla/stream_executor:platform_manager", ], ) @@ -1842,6 +1846,7 @@ tf_cuda_cc_test( "//tensorflow/core/platform:statusor", "@com_google_googletest//:gtest_main", "@local_xla//xla/client:client_library", + "@local_xla//xla/stream_executor:platform_manager", ], ) diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc index 95f6e4964fe09d..a8c327a204bda3 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc @@ -260,7 +260,7 @@ Status GetXlaClusterInfo(Node* n, XlaClusterInfo* result) { result->function.set_name(n->type_string()); *result->function.mutable_attr() = n->def().attr(); - return OkStatus(); + return absl::OkStatus(); } Status CopyIncomingControlEdges(Graph* g, Node* from, Node* to) { @@ -270,7 +270,7 @@ Status CopyIncomingControlEdges(Graph* g, Node* from, Node* to) { } } - return OkStatus(); + return absl::OkStatus(); } void RemoveAllIncomingControlEdges(Graph* g, Node* n) { @@ -289,7 +289,7 @@ Status DeviceRequiresCompilation(const jit::DeviceInfoCache& device_info_cache, device_info_cache.GetCompilationDevice(device); *result = registration->autoclustering_policy == XlaOpRegistry::AutoclusteringPolicy::kAlways; - return OkStatus(); + return absl::OkStatus(); } // Replaces `n` with a `PartitionedCall` op that calls the same function. @@ -442,7 +442,7 @@ Status PredicateInt32Inputs(const Scope& root, Node* n, } if (int32_inputs.empty()) { - return OkStatus(); + return absl::OkStatus(); } // Create a single IdentityN that is dead if and only if @@ -460,7 +460,7 @@ Status PredicateInt32Inputs(const Scope& root, Node* n, int32_inputs_input_idxs[i])); } - return OkStatus(); + return absl::OkStatus(); } Status ReplaceNodeWithXlaCompileAndXlaRun( @@ -564,7 +564,7 @@ Status ReplaceNodeWithXlaCompileAndXlaRun( PredicateInt32Inputs(root, pco, inverse_predicate_as_control)); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -614,6 +614,6 @@ Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) { DumpGraphToFile("build_xla_ops", *graph, options.flib_def); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc index 15707cc7f348ed..8e2324f68ddd34 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc @@ -82,7 +82,7 @@ Status BuildXlaOps(const Scope& s, const FunctionDefLibrary& fdef_lib, TF_RETURN_IF_ERROR(pass.Run(opt_options)); VLOG(3) << graph->ToGraphDefDebug().DebugString(); *result = std::move(graph); - return OkStatus(); + return absl::OkStatus(); } Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name, @@ -95,7 +95,7 @@ Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name, AddNodeAttr(kXlaNumConstantArgsAttr, num_constant_args, &call_node); AddNodeAttr(kXlaNumResourceArgsAttr, num_resource_args, &call_node); TF_ASSIGN_OR_RETURN(*result, graph->AddNode(call_node)); - return OkStatus(); + return absl::OkStatus(); } Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name, diff --git a/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc b/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc index dc446c6baf64ca..0c6359148625bf 100644 --- a/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc +++ b/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc @@ -137,7 +137,7 @@ Status CloneConstantsForBetterClusteringPassImpl::CloneSmallConstantInputs( } } } - return OkStatus(); + return absl::OkStatus(); } Status CloneConstantsForBetterClusteringPassImpl::Run() { @@ -185,7 +185,7 @@ Status CloneConstantsForBetterClusteringPassImpl::Run() { // operation only modifies Const/clone_2 in place. if (IsInPlaceOp(n->type_string())) { - return OkStatus(); + return absl::OkStatus(); } nodes.push_back(n); } @@ -195,13 +195,13 @@ Status CloneConstantsForBetterClusteringPassImpl::Run() { for (Node* n : nodes) { TF_RETURN_IF_ERROR(CloneSmallConstantInputs(name_set, n)); } - return OkStatus(); + return absl::OkStatus(); } Status CloneConstantsForBetterClusteringPass::Run( const GraphOptimizationPassOptions& options) { if (GetGlobalJitLevelForGraph(options) == OptimizerOptions::OFF) { - return OkStatus(); + return absl::OkStatus(); } Graph* g = options.graph->get(); @@ -216,7 +216,7 @@ Status CloneConstantsForBetterClusteringPass::Run( DumpGraphToFile("after_clone_constants_for_better_clustering", *g); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/clone_constants_for_better_clustering_test.cc b/tensorflow/compiler/jit/clone_constants_for_better_clustering_test.cc index 468b1eab82b036..db5ebaac54b3ba 100644 --- a/tensorflow/compiler/jit/clone_constants_for_better_clustering_test.cc +++ b/tensorflow/compiler/jit/clone_constants_for_better_clustering_test.cc @@ -55,7 +55,7 @@ Status CloneConstantsForBetterClustering(const Scope& s, CloneConstantsForBetterClusteringPass rewriter; TF_RETURN_IF_ERROR(rewriter.Run(options)); *result = std::move(graph); - return OkStatus(); + return absl::OkStatus(); } const char* kCPU = "/job:localhost/replica:0/task:0/device:CPU:0"; diff --git a/tensorflow/compiler/jit/cluster_scoping_pass.cc b/tensorflow/compiler/jit/cluster_scoping_pass.cc index 9000fc0ae0abc6..edf72f83861e54 100644 --- a/tensorflow/compiler/jit/cluster_scoping_pass.cc +++ b/tensorflow/compiler/jit/cluster_scoping_pass.cc @@ -142,12 +142,12 @@ Status ClusterScopingPassImpl::ScopingForPipelineStages() { } } - return OkStatus(); + return absl::OkStatus(); } Status ClusterScopingPassImpl::Run() { if (global_jit_level_ == OptimizerOptions::OFF) { - return OkStatus(); + return absl::OkStatus(); } return ScopingForPipelineStages(); diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc index 89798a05a1e0c6..cc7df74f78c7ac 100644 --- a/tensorflow/compiler/jit/compilability_check_util.cc +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -86,7 +86,7 @@ Status MakeCallNodeFromAttribute(const Node& node, const std::string& attr_name, TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), attr_name, &name_attr)); node_def->set_op(name_attr->name()); *(node_def->mutable_attr()) = name_attr->attr(); - return OkStatus(); + return absl::OkStatus(); } StatusOr> MakeCallNodesFromAttribute( @@ -660,7 +660,7 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, } } - return OkStatus(); + return absl::OkStatus(); } tensorflow::MemoryTypeVector GetInputMemoryTypes( diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index 7251da75203f6b..377c97edf24696 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -430,7 +430,7 @@ class PredicateFactory { TF_RET_CHECK(tensor.FromProto(*proto)); *predicate = tensor.scalar()() ? MakeTrue() : MakeFalse(); - return OkStatus(); + return absl::OkStatus(); } SignatureForSymbol signature = {tensor_id, must_be_true}; @@ -446,7 +446,7 @@ class PredicateFactory { *predicate = it->second.get(); } - return OkStatus(); + return absl::OkStatus(); } Status MakeSymbolPredicate(Node* node, int output_idx, @@ -465,7 +465,7 @@ class PredicateFactory { *predicate = tensor.scalar()() == *must_have_value ? MakeTrue() : MakeFalse(); - return OkStatus(); + return absl::OkStatus(); } SignatureForIntSymbol signature = {tensor_id, must_have_value}; auto it = interned_int_symbol_instances_.find(signature); @@ -480,7 +480,7 @@ class PredicateFactory { *predicate = it->second.get(); } - return OkStatus(); + return absl::OkStatus(); } Predicate* MakeTrue() { return MakeAndPredicate({}); } @@ -924,7 +924,7 @@ Status DeadnessAnalysisImpl::GetInputPreds( result->push_back(it->second); } } - return OkStatus(); + return absl::OkStatus(); } Status DeadnessAnalysisImpl::HandleSwitch(Node* n, @@ -977,7 +977,7 @@ Status DeadnessAnalysisImpl::HandleSwitch(Node* n, predicate_factory_.MakeAndPredicate(input_preds), should_revisit); - return OkStatus(); + return absl::OkStatus(); } namespace { @@ -1005,7 +1005,7 @@ Status FindUniqueBackedge(Node* merge, const Edge** result) { *result = e; } } - return OkStatus(); + return absl::OkStatus(); } // If `backedge_predicate` is equal to `symbolic_predicate` & Step where Step @@ -1070,7 +1070,7 @@ Status GetFullFrame(const Node* n, absl::Span cfi_infos, } } - return OkStatus(); + return absl::OkStatus(); } // If the node is inside some frames, get the name of the outermost non-empty @@ -1091,7 +1091,7 @@ Status GetRootFrame(const Node* n, absl::Span cfi_infos, } *frame = cfi_iter->frame_name; - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -1135,7 +1135,7 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n, SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred, should_revisit); - return OkStatus(); + return absl::OkStatus(); } std::vector input_preds; @@ -1146,7 +1146,7 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n, predicate_factory_.MakeOrPredicate(input_preds); SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred, should_revisit); - return OkStatus(); + return absl::OkStatus(); } if (it->second->kind() == Predicate::Kind::kSymbol) { @@ -1178,11 +1178,11 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n, Predicate* and_rec = predicate_factory_.MakeAndRecurrencePredicate( start, step, std::move(frame)); SetPredicate(n, {0, 1, Graph::kControlSlot}, and_rec, should_revisit); - return OkStatus(); + return absl::OkStatus(); } } } - return OkStatus(); + return absl::OkStatus(); } Status DeadnessAnalysisImpl::HandleRecv(Node* n, @@ -1198,7 +1198,7 @@ Status DeadnessAnalysisImpl::HandleRecv(Node* n, SetPredicate(n, {0, Graph::kControlSlot}, predicate_factory_.MakeAndPredicate(input_preds), should_revisit); - return OkStatus(); + return absl::OkStatus(); } Status DeadnessAnalysisImpl::HandleGeneric(Node* n, @@ -1211,7 +1211,7 @@ Status DeadnessAnalysisImpl::HandleGeneric(Node* n, SetPredicate(n, output_idx, pred, should_revisit); } SetPredicate(n, Graph::kControlSlot, pred, should_revisit); - return OkStatus(); + return absl::OkStatus(); } Status DeadnessAnalysisImpl::HandleNode(Node* n, @@ -1231,7 +1231,7 @@ Status DeadnessAnalysisImpl::HandleNode(Node* n, } else { TF_RETURN_IF_ERROR(HandleGeneric(n, should_revisit)); } - return OkStatus(); + return absl::OkStatus(); } // Compute a special topological order for the Graph, where nodes having the @@ -1341,7 +1341,7 @@ Status DeadnessAnalysisImpl::GetFrameBasedTopologicalOrder( "Some enters/exits have never been visited in the traversal." " Most probably the input graph is malformed."); } - return OkStatus(); + return absl::OkStatus(); } // We populate the nodes along a special topological order where nodes having @@ -1415,7 +1415,7 @@ Status DeadnessAnalysisImpl::Populate(bool enable_optimistic) { << (success ? "optimistic" : "pessimistic") << " mode."; } - return OkStatus(); + return absl::OkStatus(); } Status DeadnessAnalysisImpl::PopulateFrame(absl::Span topo, @@ -1535,7 +1535,7 @@ Status DeadnessAnalysisImpl::PopulateFrame(absl::Span topo, } } - return OkStatus(); + return absl::OkStatus(); } StatusOr @@ -1578,7 +1578,7 @@ DeadnessAnalysis::~DeadnessAnalysis() {} } *result = std::move(analysis); - return OkStatus(); + return absl::OkStatus(); } absl::flat_hash_map @@ -1596,7 +1596,7 @@ Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map, DeadnessAnalysisImpl impl(&graph); TF_RETURN_IF_ERROR(impl.Populate(enable_optimistic)); *out_predicate_map = impl.PredicateMapAsString(); - return OkStatus(); + return absl::OkStatus(); } } // namespace deadness_analysis_internal diff --git a/tensorflow/compiler/jit/deadness_analysis.h b/tensorflow/compiler/jit/deadness_analysis.h index 72b446f165a6e9..fe00acb866f179 100644 --- a/tensorflow/compiler/jit/deadness_analysis.h +++ b/tensorflow/compiler/jit/deadness_analysis.h @@ -73,8 +73,8 @@ class DeadnessAnalysis { friend class DeadnessAnalysis; }; - virtual tsl::StatusOr GetPredicateFor(Node* n, - int oidx) const = 0; + virtual absl::StatusOr GetPredicateFor(Node* n, + int oidx) const = 0; // Prints out the internal state of this instance. For debugging purposes // only. diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index 33cb716623fe3c..e2db6c0acca490 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -36,7 +36,7 @@ limitations under the License. namespace tensorflow { namespace { -tsl::StatusOr HasInputsWithMismatchingDeadness( +absl::StatusOr HasInputsWithMismatchingDeadness( const DeadnessAnalysis& deadness_analysis, const Node& n) { std::optional pred; for (const Edge* edge : n.in_edges()) { diff --git a/tensorflow/compiler/jit/device_compilation_cache_test.cc b/tensorflow/compiler/jit/device_compilation_cache_test.cc index 47c2eabf900557..194aa2ea9a99f0 100644 --- a/tensorflow/compiler/jit/device_compilation_cache_test.cc +++ b/tensorflow/compiler/jit/device_compilation_cache_test.cc @@ -96,7 +96,7 @@ TEST(DeviceCompilationCacheTest, RequestCountUnchangedOnStore) { EXPECT_EQ(cache_value.request_count, 3); auto compilation_result = std::make_unique(); - cache->Store(key, DeviceCompileState::kCompiled, OkStatus(), + cache->Store(key, DeviceCompileState::kCompiled, absl::OkStatus(), std::move(compilation_result), std::nullopt); cache_value = cache->LookupOrCreate(key); @@ -109,7 +109,7 @@ TEST(DeviceCompilationCacheTest, StoreLookup) { TF_ASSERT_OK_AND_ASSIGN(auto key, BuildSampleSignature("foo")); auto compilation_result = std::make_unique(); auto executable = std::make_unique("foo_exe"); - cache->Store(key, DeviceCompileState::kCompiled, OkStatus(), + cache->Store(key, DeviceCompileState::kCompiled, absl::OkStatus(), std::move(compilation_result), std::move(executable)); auto cache_value = cache->Lookup(key); @@ -127,7 +127,7 @@ TEST(DeviceCompilationCacheTest, StoreLookupOrCreate) { TF_ASSERT_OK_AND_ASSIGN(auto key, BuildSampleSignature("foo")); auto compilation_result = std::make_unique(); auto executable = std::make_unique("foo_exe"); - cache->Store(key, DeviceCompileState::kCompiled, OkStatus(), + cache->Store(key, DeviceCompileState::kCompiled, absl::OkStatus(), std::move(compilation_result), std::move(executable)); auto cache_value = cache->LookupOrCreate(key); @@ -198,7 +198,7 @@ TEST(DeviceCompilationCacheTest, StoreMultipleEntries) { cache->Store(key1, DeviceCompileState::kCompiled, errors::InvalidArgument("Invalid argument."), std::move(compilation_result1), std::move(executable1)); - cache->Store(key2, DeviceCompileState::kCompiling, OkStatus(), + cache->Store(key2, DeviceCompileState::kCompiling, absl::OkStatus(), std::move(compilation_result2), std::move(executable2)); auto cache_value_1 = cache->Lookup(key1); auto cache_value_2 = cache->Lookup(key2); diff --git a/tensorflow/compiler/jit/device_compilation_cluster_signature.cc b/tensorflow/compiler/jit/device_compilation_cluster_signature.cc index 8ca571b104b5ce..81902a28532dbf 100644 --- a/tensorflow/compiler/jit/device_compilation_cluster_signature.cc +++ b/tensorflow/compiler/jit/device_compilation_cluster_signature.cc @@ -110,7 +110,7 @@ uint64 Signature::Hash::operator()(const Signature& signature) const { return h; } -StatusOr Signature::Build( +absl::StatusOr Signature::Build( const NameAttrList& function, absl::Span args) { Signature signature; diff --git a/tensorflow/compiler/jit/device_compilation_cluster_signature.h b/tensorflow/compiler/jit/device_compilation_cluster_signature.h index 76a8daa0d95ab4..4acea2a03c2cb4 100644 --- a/tensorflow/compiler/jit/device_compilation_cluster_signature.h +++ b/tensorflow/compiler/jit/device_compilation_cluster_signature.h @@ -46,7 +46,7 @@ struct DeviceCompilationClusterSignature { string HumanString() const; // Builds the signature for a compilation. - static StatusOr Build( + static absl::StatusOr Build( const NameAttrList& function, absl::Span args); }; diff --git a/tensorflow/compiler/jit/device_compilation_profiler.cc b/tensorflow/compiler/jit/device_compilation_profiler.cc index 53df0c811ff826..b2da1959c98f72 100644 --- a/tensorflow/compiler/jit/device_compilation_profiler.cc +++ b/tensorflow/compiler/jit/device_compilation_profiler.cc @@ -72,7 +72,7 @@ DeviceCompilationProfiler::~DeviceCompilationProfiler() { cluster_compile_stats_.clear(); } -StatusOr +absl::StatusOr DeviceCompilationProfiler::GetCompileStats(const NameAttrList& function) const { mutex_lock lock(mu_); diff --git a/tensorflow/compiler/jit/device_compilation_profiler.h b/tensorflow/compiler/jit/device_compilation_profiler.h index 1f855f0be6bf2d..2057e1adc12dee 100644 --- a/tensorflow/compiler/jit/device_compilation_profiler.h +++ b/tensorflow/compiler/jit/device_compilation_profiler.h @@ -56,7 +56,7 @@ class DeviceCompilationProfiler : public ResourceBase { }; // Returns the compilation statistics for the given cluster. - StatusOr GetCompileStats( + absl::StatusOr GetCompileStats( const NameAttrList& function) const; // Determines whether the cluster should be compiled. Creates and inserts an diff --git a/tensorflow/compiler/jit/device_compiler.h b/tensorflow/compiler/jit/device_compiler.h index fc562d8277a77e..da2809cb7b62a1 100644 --- a/tensorflow/compiler/jit/device_compiler.h +++ b/tensorflow/compiler/jit/device_compiler.h @@ -201,7 +201,7 @@ inline Status EligibleToPersist(DeviceCompileState compile_state, return errors::FailedPrecondition( "LocalExecutable not found for cache entry to serialize."); } - return OkStatus(); + return absl::OkStatus(); } } // namespace device_compiler_internal @@ -391,7 +391,7 @@ Status DeviceCompiler::CompileAsynchronous( std::nullopt); } }); - return OkStatus(); + return absl::OkStatus(); } template @@ -469,14 +469,14 @@ Status DeviceCompiler::CompileImpl( if (!profiler->ShouldCompileCluster(function, compile_mode, current_request_count)) { VLOG(2) << "Not compiling for signature: " << human_signature; - return OkStatus(); + return absl::OkStatus(); } else if (compile_mode == DeviceCompileMode::kAsync) { VLOG(2) << "Queueing asynchronous compilation for signature: " << human_signature; TF_RETURN_IF_ERROR(CompileAsynchronous(signature, compile_options, options, args, function, scope, ctx, profiler)); - return OkStatus(); + return absl::OkStatus(); } else { VLOG(2) << "Instantly compiling for signature: " << human_signature; TF_ASSIGN_OR_RETURN( @@ -487,7 +487,7 @@ Status DeviceCompiler::CompileImpl( } else if (state == DeviceCompileState::kCompiling) { VLOG(2) << "Ongoing asynchronous compilation for signature: " << human_signature; - return OkStatus(); + return absl::OkStatus(); } else if (state == DeviceCompileState::kCompiled) { VLOG(2) << "Already Compiled for signature: " << human_signature; } @@ -495,7 +495,7 @@ Status DeviceCompiler::CompileImpl( TF_RETURN_IF_ERROR(cache_value.compilation_status); *out_compilation_result = cache_value.compilation_result; *out_executable = cache_value.executable; - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/device_compiler_client.h b/tensorflow/compiler/jit/device_compiler_client.h index 582bbbc4c102c6..358cb9236f492d 100644 --- a/tensorflow/compiler/jit/device_compiler_client.h +++ b/tensorflow/compiler/jit/device_compiler_client.h @@ -39,13 +39,13 @@ class DeviceCompilerClient { // Serializes an available `executable` to string using `ClientType` and // returns it. - virtual StatusOr SerializeExecutable( + virtual absl::StatusOr SerializeExecutable( const ExecutableType& executable) = 0; // Compiles `result` (HLO) to a serializable executable (eg. // xla::AotCompilationResult) using `ClientType`, serializes it to string and // returns it. - virtual StatusOr BuildSerializedExecutable( + virtual absl::StatusOr BuildSerializedExecutable( const XlaCompiler::Options& options, const XlaCompiler::CompilationResult& result) = 0; diff --git a/tensorflow/compiler/jit/device_compiler_test.cc b/tensorflow/compiler/jit/device_compiler_test.cc index 091257436b8c8c..70dad7031abd18 100644 --- a/tensorflow/compiler/jit/device_compiler_test.cc +++ b/tensorflow/compiler/jit/device_compiler_test.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/jit/tests/device_compiler_test_helper.h" #include "tensorflow/compiler/jit/xla_device_compiler_client.h" #include "xla/client/client_library.h" +#include "xla/stream_executor/platform_manager.h" #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" @@ -58,7 +59,7 @@ using Signature = DeviceCompilationClusterSignature; xla::LocalClient* GetLocalClient() { // TODO(b/255826209): Figure out how to run this test with the CPU client as // well. - auto platform = se::MultiPlatformManager::PlatformWithName("cuda").value(); + auto platform = se::PlatformManager::PlatformWithName("cuda").value(); return xla::ClientLibrary::GetOrCreateLocalClient(platform).value(); } @@ -284,7 +285,7 @@ TEST_F(DeviceCompilerTest, CompileAsyncSuccess) { EXPECT_CALL(*mock_profiler_, RegisterCompilation(_, _, false)) .WillOnce([&done] { done.Notify(); - return OkStatus(); + return absl::OkStatus(); }); auto args = SampleArgsForAddXY(); diff --git a/tensorflow/compiler/jit/device_executable_persistor.h b/tensorflow/compiler/jit/device_executable_persistor.h index 692f7ccb226a5e..037822821c111b 100644 --- a/tensorflow/compiler/jit/device_executable_persistor.h +++ b/tensorflow/compiler/jit/device_executable_persistor.h @@ -271,7 +271,7 @@ DeviceExecutablePersistor::VerifyLoadedCacheEntry( if (entry.executable().empty()) { return errors::InvalidArgument("No binary found in serialized entry."); } - return OkStatus(); + return absl::OkStatus(); } template @@ -392,7 +392,7 @@ DeviceExecutablePersistor::TryToPersistExecutable( persistent_cache_directory_read_only_) { VLOG(1) << "Not persisting executable. No `persistent_cache_directory` " "provided or cache is read-only."; - return OkStatus(); + return absl::OkStatus(); } XLA_SCOPED_LOGGING_TIMER( @@ -401,7 +401,7 @@ DeviceExecutablePersistor::TryToPersistExecutable( SerializeEntry(signature_hash, options, compilation_result, executable, client)); TF_RETURN_IF_ERROR(SaveSerializedEntry(std::move(serialized_entry))); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/device_executable_persistor_test.cc b/tensorflow/compiler/jit/device_executable_persistor_test.cc index ba85bc59e0e6f8..3796105af40fd9 100644 --- a/tensorflow/compiler/jit/device_executable_persistor_test.cc +++ b/tensorflow/compiler/jit/device_executable_persistor_test.cc @@ -143,7 +143,7 @@ class DeviceExecutionPersistorTest : public ::testing::Test { GetOrCreatePjRtClient(DeviceType(DEVICE_CPU_XLA_JIT))); pjrt_compiler_client_ = std::make_unique(pjrt_client); - return OkStatus(); + return absl::OkStatus(); } std::unique_ptr flib_def_; diff --git a/tensorflow/compiler/jit/device_util.cc b/tensorflow/compiler/jit/device_util.cc index 6d82fe0f488988..7755db2e1afdfc 100644 --- a/tensorflow/compiler/jit/device_util.cc +++ b/tensorflow/compiler/jit/device_util.cc @@ -93,7 +93,7 @@ Status DeviceNameToDeviceType(const string& device, DeviceType* device_type) { return errors::Internal("Malformed assigned device '", device, "'"); } *device_type = DeviceType(parsed.type); - return OkStatus(); + return absl::OkStatus(); } StatusOr> PickDeviceForXlaImpl( diff --git a/tensorflow/compiler/jit/device_util.h b/tensorflow/compiler/jit/device_util.h index 90852c096f8997..df3b7d04fbfe7b 100644 --- a/tensorflow/compiler/jit/device_util.h +++ b/tensorflow/compiler/jit/device_util.h @@ -118,7 +118,7 @@ class DeviceInfoCache { return names_[device.id()]; } - StatusOr GetIdFor(absl::string_view name); + absl::StatusOr GetIdFor(absl::string_view name); using DeviceRegistration = const XlaOpRegistry::DeviceRegistration; @@ -126,7 +126,8 @@ class DeviceInfoCache { return id_to_compilation_device_[device.id()]; } - StatusOr GetCompilationDevice(absl::string_view name) { + absl::StatusOr GetCompilationDevice( + absl::string_view name) { TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(name)); return GetCompilationDevice(device_id); } @@ -137,7 +138,8 @@ class DeviceInfoCache { using DeviceTypeConstRef = std::reference_wrapper; - StatusOr GetDeviceTypeFor(absl::string_view device_name) { + absl::StatusOr GetDeviceTypeFor( + absl::string_view device_name) { TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(device_name)); return std::cref(*id_to_device_type_[device_id.id()]); } @@ -196,7 +198,7 @@ Status DeviceNameToDeviceType(const string& device, DeviceType* device_type); // case it is the responsibility of the optimization pass that injected the // CPU nodes into the cluster to ensure that these nodes can be compiled by // the unknown XLA backend. -StatusOr PickDeviceForXla( +absl::StatusOr PickDeviceForXla( const jit::DeviceInfoCache& device_info_cache, const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu); @@ -205,7 +207,7 @@ StatusOr PickDeviceForXla( // // We return a failing Status for errors unrelated to the device choice // algorithm itself. -StatusOr> MaybePickDeviceForXla( +absl::StatusOr> MaybePickDeviceForXla( const jit::DeviceInfoCache& device_info_cache, const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/device_util_test.cc b/tensorflow/compiler/jit/device_util_test.cc index 5a3f90d0fe1292..c63aa6683b9544 100644 --- a/tensorflow/compiler/jit/device_util_test.cc +++ b/tensorflow/compiler/jit/device_util_test.cc @@ -35,7 +35,7 @@ Status PickDeviceHelper(bool allow_mixing_unknown_and_cpu, jit::DeviceId result_id, PickDeviceForXla(cache, device_set, allow_mixing_unknown_and_cpu)); *result = string(cache.GetNameFor(result_id)); - return OkStatus(); + return absl::OkStatus(); } void CheckPickDeviceResult(absl::string_view expected_result, diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 0ae10eb6fae42b..47034a6a791f77 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -479,7 +479,7 @@ Status Encapsulator::Subgraph::RecordArg( int dst_slot = edge->dst_input(); args_by_dst_[InputTensor(dst_node, dst_slot)] = arg_index; graph_->AddEdge(args_[arg_index], 0, dst_image, dst_slot); - return OkStatus(); + return absl::OkStatus(); } Status Encapsulator::Subgraph::RecordControlResult( @@ -488,7 +488,7 @@ Status Encapsulator::Subgraph::RecordControlResult( Node* src_node = edge->src(); Node* src_image = node_images.at(src_node); control_output_nodes_.insert(src_image->name()); - return OkStatus(); + return absl::OkStatus(); } Status Encapsulator::Subgraph::RecordResult( @@ -516,7 +516,7 @@ Status Encapsulator::Subgraph::RecordResult( TF_ASSIGN_OR_RETURN(Node * ret, graph_->AddNode(ret_def)); graph_->AddEdge(src_image, src_slot, ret, 0); } - return OkStatus(); + return absl::OkStatus(); } Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name, @@ -532,7 +532,7 @@ Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name, TF_ASSIGN_OR_RETURN(sequencer_, graph_out->AddNode(seq_def)); } - return OkStatus(); + return absl::OkStatus(); } void Encapsulator::Subgraph::ConnectSequencerToCallNode(Graph* graph_out) { @@ -617,7 +617,7 @@ Status Encapsulator::Subgraph::BuildFunctionDef( } else if (!FunctionDefsEqual(*original_fdef, fdef)) { TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef)); } - return OkStatus(); + return absl::OkStatus(); } Status Encapsulator::Subgraph::ReplaceFunctionDef( @@ -636,7 +636,7 @@ Status Encapsulator::Subgraph::ReplaceFunctionDef( } TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef)); - return OkStatus(); + return absl::OkStatus(); } Status Encapsulator::Subgraph::AddFunctionCallNode( @@ -647,7 +647,7 @@ Status Encapsulator::Subgraph::AddFunctionCallNode( // Copy the assigned device and the key_annotation over. call_node_->set_assigned_device_name(device_); - return OkStatus(); + return absl::OkStatus(); } Status Encapsulator::GetFunctionNameAttr(Node const* node, string* attr) const { @@ -660,7 +660,7 @@ Status Encapsulator::GetFunctionNameAttr(Node const* node, string* attr) const { break; } } - return OkStatus(); + return absl::OkStatus(); } bool IsInSubgraph(const string& func_id) { return !func_id.empty(); } @@ -677,7 +677,7 @@ Status Encapsulator::CopySubgraphNodes( image->ClearAttr(group_attribute_); (*node_images)[node] = image; } - return OkStatus(); + return absl::OkStatus(); } Status Encapsulator::CopySubgraphEdges( @@ -748,7 +748,7 @@ Status Encapsulator::CopySubgraphEdges( } } } - return OkStatus(); + return absl::OkStatus(); } Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) { @@ -793,7 +793,7 @@ Status Encapsulator::BuildFunctionDefs( TF_RETURN_IF_ERROR(subgraph.BuildFunctionDef( name, rewrite_subgraph_fn, reuse_existing_functions, library)); } - return OkStatus(); + return absl::OkStatus(); } Status Encapsulator::CopyNodesToOutputGraph( @@ -810,7 +810,7 @@ Status Encapsulator::CopyNodesToOutputGraph( } (*node_images)[graph_in_->source_node()] = graph_out->source_node(); (*node_images)[graph_in_->sink_node()] = graph_out->sink_node(); - return OkStatus(); + return absl::OkStatus(); } Status Encapsulator::AddFunctionCallNodes( @@ -820,7 +820,7 @@ Status Encapsulator::AddFunctionCallNodes( TF_RETURN_IF_ERROR( subgraph_entry.second.AddFunctionCallNode(node_images, graph_out)); } - return OkStatus(); + return absl::OkStatus(); } Status Encapsulator::FindOutputImageOfEdgeSrc( @@ -836,7 +836,7 @@ Status Encapsulator::FindOutputImageOfEdgeSrc( // the output graph. *src_image = node_images.at(original_src_node); } - return OkStatus(); + return absl::OkStatus(); } int Encapsulator::FindOutputSlotOfEdgeSrc(const string& src_func_id, @@ -867,7 +867,7 @@ Status Encapsulator::FindOutputImageOfEdgeDst( // in the output graph. *dst_image = node_images.at(original_dst_node); } - return OkStatus(); + return absl::OkStatus(); } int Encapsulator::FindOutputSlotOfEdgeDst(const string& src_func_id, @@ -910,7 +910,7 @@ Status Encapsulator::CopyEdgeToOutputGraph( /* allow_duplicates= */ true); } - return OkStatus(); + return absl::OkStatus(); } int src_output = FindOutputSlotOfEdgeSrc(src_func_id, dst_func_id, edge); @@ -924,7 +924,7 @@ Status Encapsulator::CopyEdgeToOutputGraph( .second) { graph_out->AddEdge(src_image, src_output, dst_image, dst_input); } - return OkStatus(); + return absl::OkStatus(); } Status Encapsulator::AddEdgesToOutputGraph( @@ -961,7 +961,7 @@ Status Encapsulator::AddEdgesToOutputGraph( subgraph.ConnectSequencerToCallNode(graph_out); } - return OkStatus(); + return absl::OkStatus(); } namespace { @@ -1068,7 +1068,7 @@ Status Encapsulator::MakePrunedGraphCopyAndInline( fbody.get(), inline_opts)); } - return OkStatus(); + return absl::OkStatus(); } Status Encapsulator::BuildOutputGraph(Graph* graph_out, @@ -1080,7 +1080,7 @@ Status Encapsulator::BuildOutputGraph(Graph* graph_out, TF_RETURN_IF_ERROR(AddFunctionCallNodes(node_images, graph_out)); TF_RETURN_IF_ERROR(AddEdgesToOutputGraph(node_images, graph_out)); - return OkStatus(); + return absl::OkStatus(); } } // anonymous namespace @@ -1101,7 +1101,7 @@ Status EncapsulateSubgraphsInFunctions( TF_RETURN_IF_ERROR(encapsulator.BuildOutputGraph(out.get(), library)); *graph_out = std::move(out); - return OkStatus(); + return absl::OkStatus(); } // Finds the types of the _Arg nodes, indexed by position. @@ -1117,7 +1117,7 @@ static Status GetArgTypes(const Graph& graph, DataTypeVector* types) { (*types)[index] = n->output_type(0); } } - return OkStatus(); + return absl::OkStatus(); } // Renumber the indices of _Arg nodes in a graph, according to @@ -1135,7 +1135,7 @@ static Status RenumberArguments(Graph* graph, n->AddAttr("index", permutation[index]); } } - return OkStatus(); + return absl::OkStatus(); } Status EncapsulateSubgraphsPass::Run( @@ -1154,7 +1154,7 @@ Status EncapsulateSubgraphsPass::Run( // and doesn't require auto clustering. if (n->type_string() == "TPUExecute" || n->type_string() == "TPUExecuteAndUpdateVariables") { - return OkStatus(); + return absl::OkStatus(); } } @@ -1288,7 +1288,7 @@ Status EncapsulateSubgraphsPass::Run( AddNodeAttr(kXlaCompiledKernelAttr, true, node); AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node); AddNodeAttr(kXlaNumResourceArgsAttr, num_resources, node); - return OkStatus(); + return absl::OkStatus(); }; TF_RETURN_WITH_CONTEXT_IF_ERROR( @@ -1311,7 +1311,7 @@ Status EncapsulateSubgraphsPass::Run( VLOG(3) << "Has ref vars = " << has_ref_vars << ", node: " << node->def().DebugString(); } - return OkStatus(); + return absl::OkStatus(); } bool IsXlaCompiledKernel(const Node& node) { diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 5d722b8c751ffa..e634982c85db8a 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -60,7 +60,7 @@ Status AddGraphDefToFunctionLibrary(const GraphDefBuilder& graphdef_builder, *graph, absl::StrCat("_outside_compilation_shape_inference_", name_suffix), fdef)); - return OkStatus(); + return absl::OkStatus(); } template @@ -299,14 +299,14 @@ REGISTER_OP("InputTest") .Output("o: float") .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { c->set_output(0, c->UnknownShape()); - return OkStatus(); + return absl::OkStatus(); }); REGISTER_OP("InputTestShaped") .Output("o: float") .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { c->set_output(0, c->Vector(2)); - return OkStatus(); + return absl::OkStatus(); }); REGISTER_OP("UnaryTest") @@ -316,7 +316,7 @@ REGISTER_OP("UnaryTest") ::tensorflow::shape_inference::ShapeHandle o; TF_RETURN_IF_ERROR(c->Merge(c->UnknownShape(), c->input(0), &o)); c->set_output(0, o); - return OkStatus(); + return absl::OkStatus(); }); REGISTER_OP("BinaryTest") .Input("a: float") @@ -326,7 +326,7 @@ REGISTER_OP("BinaryTest") ::tensorflow::shape_inference::ShapeHandle o; TF_RETURN_IF_ERROR(c->Merge(c->UnknownShape(), c->input(0), &o)); c->set_output(0, o); - return OkStatus(); + return absl::OkStatus(); }); REGISTER_OP("BinaryTest2") .Input("a: float") @@ -807,7 +807,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) { EXPECT_FALSE(HasGuaranteeConstAttr(*n)); } } - return OkStatus(); + return absl::OkStatus(); }, /*reuse_existing_functions=*/false, &graph_after, &library)); EXPECT_EQ(2, guaranteed_consts); @@ -852,7 +852,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) { EXPECT_FALSE(HasGuaranteeConstAttr(*n)); } } - return OkStatus(); + return absl::OkStatus(); }, /*reuse_existing_functions=*/false, &graph_after, &library)); // Only 1 runtime const, which is const_guarantee_add1. Add2 has one const diff --git a/tensorflow/compiler/jit/encapsulate_util.cc b/tensorflow/compiler/jit/encapsulate_util.cc index d761483a1f45f0..d97995f8b8974a 100644 --- a/tensorflow/compiler/jit/encapsulate_util.cc +++ b/tensorflow/compiler/jit/encapsulate_util.cc @@ -57,7 +57,7 @@ Status AppendToListAttr(Node* n, const string& attr_name, const string& value) { n->ClearAttr(attr_name); attr_value.push_back(value); n->AddAttr(attr_name, attr_value); - return OkStatus(); + return absl::OkStatus(); } // Replaces attribute value. @@ -104,7 +104,7 @@ Status PreprocessControlEdgesBetweenOutsideCompilations( for (auto e : edges_to_remove) { g->RemoveEdge(e); } - return OkStatus(); + return absl::OkStatus(); } // Step 2 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of @@ -188,7 +188,7 @@ Status PreprocessDataEdgesBetweenOutsideCompilations( } } } - return OkStatus(); + return absl::OkStatus(); } // Step 1 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of @@ -264,7 +264,7 @@ Status PostprocessDataEdgesBetweenOutsideCompilations( // Remove placeholder node. g->RemoveNode(n); } - return OkStatus(); + return absl::OkStatus(); } // Step 2 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of @@ -297,7 +297,7 @@ Status PostprocessControlEdgesBetweenOutsideCompilations( } } } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -337,7 +337,7 @@ Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g) { n->AddAttr(kXlaInferredShapesAttrName, output_shapes); } - return OkStatus(); + return absl::OkStatus(); } StatusOr>>> @@ -400,7 +400,7 @@ Status PreprocessEdgesBetweenOutsideCompilations( g, outside_compilation_attr_name)); TF_RETURN_IF_ERROR(PreprocessDataEdgesBetweenOutsideCompilations( g, outside_compilation_attr_name)); - return OkStatus(); + return absl::OkStatus(); } Status PostprocessEdgesBetweenOutsideCompilations( @@ -409,7 +409,7 @@ Status PostprocessEdgesBetweenOutsideCompilations( g, outside_compilation_attr_name)); TF_RETURN_IF_ERROR(PostprocessControlEdgesBetweenOutsideCompilations( g, outside_compilation_attr_name)); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_util.h b/tensorflow/compiler/jit/encapsulate_util.h index 304e317ee2521d..ee31751a45cafd 100644 --- a/tensorflow/compiler/jit/encapsulate_util.h +++ b/tensorflow/compiler/jit/encapsulate_util.h @@ -116,7 +116,8 @@ struct XlaClusterInfo { // dependencies and control dependencies. cluster_deps maps the name name of an // outside compilation cluster to a set of names of outside compilation clusters // that it depends on. -tsl::StatusOr>>> +absl::StatusOr< + std::unique_ptr>>> OutsideCompilationClusterDependencies( const Graph* g, const string& outside_compilation_attr_name); diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc index 048ca0f48c4e05..9a482bf918ec91 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -77,7 +77,7 @@ Status GetIndexAttr(const Node& n, int num_args, int* index) { return errors::InvalidArgument("Invalid ", n.type_string(), " number ", *index); } - return OkStatus(); + return absl::OkStatus(); } // Returns the data type of the destination of an edge. @@ -189,7 +189,7 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, TF_ASSIGN_OR_RETURN(uint64 fingerprint, FingerprintGraph(*graph)); VLOG(1) << "Subgraph fingerprint:" << fingerprint; call_def->set_op(absl::StrCat(call_def->op(), "_", fingerprint)); - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -223,7 +223,7 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, /*reuse_existing_functions=*/true, &output, flib_def), "EncapsulateXlaComputationsPass failed"); graph->swap(output); - return OkStatus(); + return absl::OkStatus(); } /*static*/ Status EncapsulateXlaComputationsPass::BuildXlaLaunchOps( @@ -355,7 +355,7 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, graph->AddControlEdge(xla_launch, n); } } - return OkStatus(); + return absl::OkStatus(); } /*static*/ Status EncapsulateXlaComputationsPass::BuildXlaLaunchOps( @@ -399,7 +399,7 @@ Status EncapsulateXlaComputationsPass::Run( VLOG(1) << "EncapsulateXlaComputations() finished: " << DumpGraphToFile("encapsulate_xla_computations_after", **options.graph, options.flib_def); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h index e4ae139121abe5..b6af1277976f44 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h @@ -73,8 +73,9 @@ class EncapsulateXlaComputationsPass : public GraphOptimizationPass { // XlaLaunch -> NodeA static Status BuildXlaLaunchOps( Graph* graph, - const std::function(const Node&)>& is_xla_launch_node, - const std::function(const Node&)>& + const std::function(const Node&)>& + is_xla_launch_node, + const std::function(const Node&)>& get_xla_function_info, bool add_edges_to_output_of_downstream_nodes); }; diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc index e18832ea1b443a..ce139296bb74ce 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc @@ -100,7 +100,7 @@ Status GetArgDataTypes(const std::vector& arg_nodes, return errors::Internal("Cannot get datatype for input ", i); } } - return OkStatus(); + return absl::OkStatus(); } // Builds XlaRecvAtHost node. @@ -200,7 +200,7 @@ Status GetRetDataTypes(const std::vector& ret_nodes, return errors::Internal("Cannot get datatype for output ", i); } } - return OkStatus(); + return absl::OkStatus(); } // Builds XlaSendFromHost node. @@ -431,7 +431,7 @@ Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) { n->DebugString()); } } - return OkStatus(); + return absl::OkStatus(); } // Cheap check to tell whether FunctionDef contains a lifted argument. @@ -536,7 +536,7 @@ Status AddMatchingRetvalNode(const FunctionBody& function_body, TF_ASSIGN_OR_RETURN(Node * ret_node, function_body.graph->AddNode(ret_def)); function_body.graph->AddEdge(arg_node, 0, ret_node, 0); - return OkStatus(); + return absl::OkStatus(); } void ReplaceLiftedArgNodePlaceholderWithArg( @@ -571,7 +571,7 @@ Status AddFunctionWithNewName(const std::string& new_name, func_attr->set_name(new_name); callsite_node->ClearAttr(func_attr_name); callsite_node->AddAttr(func_attr_name, *func_attr); - return OkStatus(); + return absl::OkStatus(); } // Reconnect outside compilation lifted arguments in a functional While node to @@ -588,7 +588,7 @@ Status PostprocessLiftedArgsForWhile( TF_RET_CHECK(body_function_def); if (!HasLiftedArgs(*body_function_def)) { - return OkStatus(); + return absl::OkStatus(); } // Gather all lifted args. @@ -681,7 +681,7 @@ Status PostprocessLiftedArgsForWhile( TF_RETURN_IF_ERROR(AddFunctionWithNewName(new_cond_function_name, "cond", rewritten_cond_function_def, &cond_func, n, fld)); - return OkStatus(); + return absl::OkStatus(); } Status PostprocessLiftedArgsForIf( @@ -704,7 +704,7 @@ Status PostprocessLiftedArgsForIf( // Nothing to do if neither branch contains any lifted arguments. if (!HasLiftedArgs(*then_branch_function_def) && !HasLiftedArgs(*else_branch_function_def)) { - return OkStatus(); + return absl::OkStatus(); } std::unique_ptr then_branch_function_body; @@ -820,7 +820,7 @@ Status PostprocessLiftedArgsForIf( TF_RETURN_IF_ERROR(AddFunctionWithNewName( new_else_function_name, "else_branch", rewritten_else_branch_function_def, &else_branch_func, n, fld)); - return OkStatus(); + return absl::OkStatus(); } Status PostprocessLiftedArgsForCall( @@ -831,7 +831,7 @@ Status PostprocessLiftedArgsForCall( // Nothing to do if the function does not contain any lifted arguments. if (!HasLiftedArgs(*fdef)) { - return OkStatus(); + return absl::OkStatus(); } std::unique_ptr fbody; @@ -917,7 +917,7 @@ Status PostprocessLiftedArgsForCall( data_types, outside_compilation_nodes, g, n); - return OkStatus(); + return absl::OkStatus(); } // Creates a mapping from outside compilation cluster name to lifted argument @@ -974,7 +974,7 @@ Status PostprocessLiftedArgs(Graph* g, FunctionLibraryDefinition* fld) { outside_compilation_attr_to_node, g, n, fld)); } - return OkStatus(); + return absl::OkStatus(); } // For an XLA computation, builds host side graph given all outside compilation @@ -1116,7 +1116,7 @@ Status ConstructHostGraph( **host_graph, fld); } - return OkStatus(); + return absl::OkStatus(); } // Expand XLA computation's outside compilation host side graph into main graph. @@ -1154,7 +1154,7 @@ Status ExpandHostGraphIntoMainGraph(Graph* main_graph, node_map[host_graph->source_node()] = main_graph->source_node(); } node_map[host_graph->sink_node()] = main_graph->sink_node(); - Status s = OkStatus(); + Status s = absl::OkStatus(); auto copy_node_fn = [&](const Node* n) { if (!s.ok()) { return; @@ -1323,7 +1323,7 @@ Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name, TF_RETURN_IF_ERROR( fld->ReplaceFunction(shape_inference_graph_name, fdef_replace)); - return OkStatus(); + return absl::OkStatus(); } void SetMaximalSharding(NodeDefBuilder& node_builder) { @@ -1397,7 +1397,7 @@ Status ReplaceKeyPlaceholderWithArgNode(const string& xla_cluster_name, TF_RETURN_IF_ERROR(GraphToFunctionDef( *g, func_name, HostGraphControlRetMapping, &replace_fdef)); TF_RETURN_IF_ERROR(fld->ReplaceFunction(func_name, replace_fdef)); - return OkStatus(); + return absl::OkStatus(); } // Builds host side graph for If node. @@ -1477,7 +1477,7 @@ TF_ATTRIBUTE_NOINLINE Status BuildHostGraphForIfNode( TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef)); } - return OkStatus(); + return absl::OkStatus(); } // Rewrites loop cond to add a node which sends loop cond to host. @@ -1553,7 +1553,7 @@ TF_ATTRIBUTE_NOINLINE Status AddSendLoopPredToLoopCond( while_node->AddAttr("cond", *loop_cond_func); } - return OkStatus(); + return absl::OkStatus(); } // Rewrites while loop cond function for host. @@ -1627,7 +1627,7 @@ Status RewriteHostWhileLoopCond( TF_RETURN_IF_ERROR( fld->ReplaceFunction(cond_host_func_name, cond_replace_fdef)); - return OkStatus(); + return absl::OkStatus(); } // Rewrites while loop body function for host. @@ -1685,7 +1685,7 @@ Status RewriteHostWhileLoopBody( TF_RETURN_IF_ERROR( fld->ReplaceFunction(body_host_func_name, body_replace_fdef)); - return OkStatus(); + return absl::OkStatus(); } // Builds host side graph for while node. @@ -1752,7 +1752,7 @@ TF_ATTRIBUTE_NOINLINE Status BuildHostGraphForWhileNode( TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef)); } - return OkStatus(); + return absl::OkStatus(); } // Builds host graph for func call nodes. @@ -1801,7 +1801,7 @@ Status BuildHostGraphForFuncCallNode( TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef)); } - return OkStatus(); + return absl::OkStatus(); } TF_ATTRIBUTE_NOINLINE Status ExtractOutsideCompilationForFuncCallNode( @@ -1843,7 +1843,7 @@ TF_ATTRIBUTE_NOINLINE Status ExtractOutsideCompilationForFuncCallNode( // If the function call does not have outside compilation, nothing to do. if (!func_has_outside_compilation) { - return OkStatus(); + return absl::OkStatus(); } *has_outside_compilation = true; @@ -1887,7 +1887,7 @@ TF_ATTRIBUTE_NOINLINE Status ExtractOutsideCompilationForFuncCallNode( // Record the host graph. host_graphs->push_back(oc_host_graph_name); - return OkStatus(); + return absl::OkStatus(); } Status ExtractOutsideCompilationForIfNode( @@ -1926,7 +1926,7 @@ Status ExtractOutsideCompilationForIfNode( // If then/else branch do not have outside compilation, nothing to do. if (!then_branch_has_outside_compilation && !else_branch_has_outside_compilation) { - return OkStatus(); + return absl::OkStatus(); } *has_outside_compilation = true; @@ -2006,7 +2006,7 @@ Status ExtractOutsideCompilationForIfNode( then_branch_host_func_name, else_branch_host_func_name)); host_graphs->push_back(oc_host_graph_name); - return OkStatus(); + return absl::OkStatus(); } Status ExtractOutsideCompilationForWhileNode( @@ -2040,7 +2040,7 @@ Status ExtractOutsideCompilationForWhileNode( // If cond/body do not have outside compilation, nothing to do. if (!cond_has_outside_compilation && !body_has_outside_compilation) { - return OkStatus(); + return absl::OkStatus(); } *has_outside_compilation = true; @@ -2107,7 +2107,7 @@ Status ExtractOutsideCompilationForWhileNode( cond_host_func_name, body_host_func_name)); host_graphs->push_back(oc_host_graph_name); - return OkStatus(); + return absl::OkStatus(); } Status ExtractOutsideCompilationForNodesWithAssociatedFunctions( @@ -2149,7 +2149,7 @@ Status ExtractOutsideCompilationForNodesWithAssociatedFunctions( has_outside_compilation)); } - return OkStatus(); + return absl::OkStatus(); } Status CopyOutsideCompilationConstNodes( @@ -2194,7 +2194,7 @@ Status CopyOutsideCompilationConstNodes( } } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -2300,7 +2300,7 @@ Status RewriteOutsideCompilationSubgraphFn::operator()( AddNodeAttr("Toutputs", send_from_host_dtypes, node_def); AddNodeAttr("key", absl::StrCat("host_compute_channel_", new_name), node_def); - return OkStatus(); + return absl::OkStatus(); } Status ExtractOutsideCompilationForFunction( @@ -2316,7 +2316,7 @@ Status ExtractOutsideCompilationForFunction( FunctionLibraryRuntime::Handle handle; TF_RETURN_IF_ERROR( flr->Instantiate(func_name, AttrSlice(&func_name_attrs.attr()), &handle)); - Status ret_status = OkStatus(); + Status ret_status = absl::OkStatus(); auto cleanup_handle = gtl::MakeCleanup([&]() { auto s = flr->ReleaseHandle(handle); if (!s.ok()) { @@ -2543,7 +2543,7 @@ Status ExtractOutsideCompilation( if (VLOG_IS_ON(4)) { DumpGraphToFile("extract_outside_compilation_after", *g, fld); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index 82ed25767b90de..f85fd5fde4c1fa 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -282,6 +282,7 @@ void AllocateAndParseFlags() { bool enable_mlir_bridge_is_explicit = false; bool enable_mlir_merge_control_flow_pass = true; bool enable_mlir_convert_control_to_data_outputs_pass = false; + bool enable_mlir_composite_tpuexecute_side_effects = false; bool enable_mlir_strict_clusters = false; bool enable_mlir_multiple_local_cpu_devices = false; // Dump graphs in TFG dialect. @@ -376,6 +377,10 @@ void AllocateAndParseFlags() { &enable_mlir_convert_control_to_data_outputs_pass, "Enables `tf-executor-convert-control-to-data-outputs` pass for " "MLIR-Based TensorFlow Compiler Bridge."), + Flag("tf_mlir_composite_tpuexecute_side_effects", + &enable_mlir_composite_tpuexecute_side_effects, + "Enables certain TPUExecute ops to run in parallel if they only " + "operate on resources that live on composite devices."), Flag("tf_mlir_enable_strict_clusters", &enable_mlir_strict_clusters, "Do not allow clusters that have cyclic control dependencies."), Flag("tf_mlir_enable_multiple_local_cpu_devices", @@ -414,6 +419,8 @@ void AllocateAndParseFlags() { enable_mlir_merge_control_flow_pass; mlir_flags->tf_mlir_enable_convert_control_to_data_outputs_pass = enable_mlir_convert_control_to_data_outputs_pass; + mlir_flags->tf_mlir_enable_composite_tpuexecute_side_effects = + enable_mlir_composite_tpuexecute_side_effects; mlir_flags->tf_mlir_enable_strict_clusters = enable_mlir_strict_clusters; mlir_flags->tf_mlir_enable_generic_outside_compilation = enable_mlir_generic_outside_compilation; diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index 45a4c83a614afd..d2c078a617b258 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -288,6 +288,7 @@ struct MlirCommonFlags { bool tf_mlir_enable_merge_control_flow_pass; bool tf_mlir_enable_convert_control_to_data_outputs_pass; + bool tf_mlir_enable_composite_tpuexecute_side_effects; bool tf_mlir_enable_strict_clusters; bool tf_mlir_enable_generic_outside_compilation; bool tf_mlir_enable_tpu_variable_runtime_reformatting_pass; diff --git a/tensorflow/compiler/jit/force_xla_constants_on_host_pass.cc b/tensorflow/compiler/jit/force_xla_constants_on_host_pass.cc index 8e054b08984732..d407f5d15d904d 100644 --- a/tensorflow/compiler/jit/force_xla_constants_on_host_pass.cc +++ b/tensorflow/compiler/jit/force_xla_constants_on_host_pass.cc @@ -50,7 +50,7 @@ Status ForceXlaConstantsOnHostPass::Run( node->AddAttr("_input_hostmem", constant_arg_indices); } } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/force_xla_constants_on_host_pass_test.cc b/tensorflow/compiler/jit/force_xla_constants_on_host_pass_test.cc index 7e44d1a1c297bc..e4b937551838f9 100644 --- a/tensorflow/compiler/jit/force_xla_constants_on_host_pass_test.cc +++ b/tensorflow/compiler/jit/force_xla_constants_on_host_pass_test.cc @@ -51,7 +51,7 @@ Status ForceXlaConstantsOnHost(const Scope& s, ForceXlaConstantsOnHostPass rewriter; TF_RETURN_IF_ERROR(rewriter.Run(options)); *result = std::move(graph); - return OkStatus(); + return absl::OkStatus(); } TEST(ForceXlaConstantsOnHostPassTest, Simple) { diff --git a/tensorflow/compiler/jit/get_compiler_ir.cc b/tensorflow/compiler/jit/get_compiler_ir.cc index 0809594816106d..2f99bc5357c2af 100644 --- a/tensorflow/compiler/jit/get_compiler_ir.cc +++ b/tensorflow/compiler/jit/get_compiler_ir.cc @@ -59,7 +59,7 @@ limitations under the License. namespace tensorflow { -static StatusOr> BuildExecutable( +static absl::StatusOr> BuildExecutable( xla::LocalClient* local_client, const XlaCompiler::CompilationResult& result, const XlaCompiler::Options& options, @@ -93,7 +93,7 @@ static StatusOr> BuildExecutable( return std::move(executables[0]); } -static StatusOr BuildHLOString( +static absl::StatusOr BuildHLOString( IrExportStage stage, const XlaCompiler::CompilationResult& result, xla::LocalClient* local_client, const XlaCompiler::Options& options) { switch (stage) { @@ -138,7 +138,7 @@ static StatusOr BuildHLOString( case IrExportStage::OPTIMIZED_HLO_DOT: { TF_ASSIGN_OR_RETURN(std::unique_ptr executable, BuildExecutable(local_client, result, options)); - StatusOr graph = xla::RenderGraph( + absl::StatusOr graph = xla::RenderGraph( *executable->executable()->module().entry_computation(), "Visualization", /*debug_options=*/{}, xla::RenderedGraphFormat::kDot, @@ -149,7 +149,7 @@ static StatusOr BuildHLOString( } } -static StatusOr> +static absl::StatusOr> BuildXlaCompilerArgumentFromTensorSpec( const FunctionBody* fbody, absl::Span must_be_constant_idxs, absl::Span inputs, @@ -328,7 +328,7 @@ absl::StatusOr CompileAndBuildHLOString( * - `input_handles`: Contains all concrete_fn inputs tensors, including * captured inputs. */ -StatusOr GetCompilerIr( +absl::StatusOr GetCompilerIr( IrExportStage stage, ProcessFunctionLibraryRuntime* pflr, absl::string_view func_name, Device* dev, EagerContext* context, absl::Span input_arg_shape_and_dtype, @@ -386,7 +386,7 @@ StatusOr GetCompilerIr( function, args); } -StatusOr GetCompilerIr( +absl::StatusOr GetCompilerIr( IrExportStage stage, ProcessFunctionLibraryRuntime* pflr, absl::string_view func_name, absl::string_view platform_name, EagerContext* context, diff --git a/tensorflow/compiler/jit/get_compiler_ir.h b/tensorflow/compiler/jit/get_compiler_ir.h index b2485d00878bbd..079a7af7ad3cd9 100644 --- a/tensorflow/compiler/jit/get_compiler_ir.h +++ b/tensorflow/compiler/jit/get_compiler_ir.h @@ -54,7 +54,7 @@ enum class CompilerArgSource { // Returns the IR format of the selected stage for a given function `func_name` // using library runtime `runtime` on a device `dev` with given // `inputs_arg_shape_and_dtype` and `input_handles`. -StatusOr GetCompilerIr( +absl::StatusOr GetCompilerIr( IrExportStage stage, ProcessFunctionLibraryRuntime* pflr, absl::string_view func_name, Device* dev, EagerContext* context, absl::Span input_arg_shape_and_dtype, @@ -64,7 +64,7 @@ StatusOr GetCompilerIr( // Returns the IR format of the selected stage for a given function `func_name` // using library runtime `runtime` on a platform `platform_name` with given // `inputs_arg_shape_and_dtype` and `input_handles`. -StatusOr GetCompilerIr( +absl::StatusOr GetCompilerIr( IrExportStage stage, ProcessFunctionLibraryRuntime* pflr, absl::string_view func_name, absl::string_view platform_name, EagerContext* context, diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc index 7be2dcef3ea7cf..bbe0e1a5425956 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc @@ -184,7 +184,7 @@ Status ComputeSliceSize(const Scope& host_scope, if (absl::c_all_of(slice_inputs.size_as_vector, [](int64_t i) { return i >= 0; })) { *size = slice_inputs.size; - return OkStatus(); + return absl::OkStatus(); } Output input_shape = @@ -227,7 +227,7 @@ Status ComputeSliceSize(const Scope& host_scope, *size = ops::Concat(host_scope.WithOpName("slice_size"), slice_size, concat_axis); } - return OkStatus(); + return absl::OkStatus(); } // Terminology: "static sized" slice is a slice with the @@ -310,7 +310,7 @@ Status RewriteSlice(Graph* g, Node* slice, const SliceInputs& slice_inputs, TF_RETURN_IF_ERROR(ConvertTensorFlowSliceToStaticShapedSlice( g, slice, slice_inputs, cluster_name, &static_shaped_slice)); ReplaceTensorFlowSliceWithStaticShapedSlice(g, slice, static_shaped_slice); - return OkStatus(); + return absl::OkStatus(); } // Return true if `n` is a slice we should rewrite to have a static shape @@ -368,7 +368,7 @@ Status FindAndRewriteSlices(Graph* g, bool* changed) { *changed = !slices_to_rewrite.empty(); - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -387,7 +387,7 @@ Status IncreaseDynamismForAutoJitPass::Run( options.flib_def); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc index ba4eb975268f90..e864ef1dd12ae9 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc @@ -94,7 +94,7 @@ Status IncreaseDynamismForAutoJit(const Scope& s, IncreaseDynamismForAutoJitPass rewriter; TF_RETURN_IF_ERROR(rewriter.Run(options)); *result = std::move(graph); - return OkStatus(); + return absl::OkStatus(); } TEST(SliceToDynamicSliceRewriteTest, Basic) { diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 435b63d8f5dbe9..6ba7afe9884e91 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -212,7 +212,7 @@ Status GetTaskName(const std::string_view device_name, std::string* task_name) { device_name); } - return OkStatus(); + return absl::OkStatus(); } // Provide SendDeviceMemoryFunction for XLA host callbacks. This callback @@ -400,7 +400,7 @@ Status CompileToLocalExecutable( rm->default_container(), "device_compilation_profiler", &profiler, [](DeviceCompilationProfiler** profiler) { *profiler = new DeviceCompilationProfiler(); - return OkStatus(); + return absl::OkStatus(); })); // Hold the reference to the XLA device compiler and profiler during // evaluation. (We could probably free them sooner because the ResourceMgr @@ -899,7 +899,7 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { closure.client(), closure.executable(), ctx)); } - OP_REQUIRES_OK(ctx, OkStatus()); + OP_REQUIRES_OK(ctx, absl::OkStatus()); return; } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 526059c22fde8b..0883cff150b995 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -282,7 +282,7 @@ class MarkForCompilationPassImpl { // If this returns false then Initialize exited early (either because there is // nothing to do or we saw a graph that we can't handle) and not all the // fields in this MarkForCompilationPassImpl instance are set up. - StatusOr Initialize(); + absl::StatusOr Initialize(); // Runs through the entire cluster graph in post-order and calls `fn(from, // to)` on each edge. `fn(from, to)` is expected to return true if it was @@ -290,7 +290,7 @@ class MarkForCompilationPassImpl { // // Returns true if `fn` returned true for any edge. template - StatusOr ForEachEdgeInPostOrder(FnTy fn); + absl::StatusOr ForEachEdgeInPostOrder(FnTy fn); // Contracts as many edges as possible to create XLA clusters. After this // finishes the clustering decisions made are implicitly stored in @@ -319,7 +319,7 @@ class MarkForCompilationPassImpl { // Tries to contract the edge from cluster `from` to cluster `to`. Returns // true if successful. - StatusOr TryToContractEdge(Cluster* from, Cluster* to); + absl::StatusOr TryToContractEdge(Cluster* from, Cluster* to); // Nodes that XLA can compile are put in `compilation_candidates_`. Status FindCompilationCandidates(); @@ -329,11 +329,11 @@ class MarkForCompilationPassImpl { // Populates `clusters_`. Status BuildInitialClusterSet(); - StatusOr ShouldCompileClusterImpl(const Cluster& cluster); + absl::StatusOr ShouldCompileClusterImpl(const Cluster& cluster); - StatusOr ShouldCompileCluster(const Cluster& cluster); + absl::StatusOr ShouldCompileCluster(const Cluster& cluster); - StatusOr ClusteringWillIntroduceInterDeviceDependency( + absl::StatusOr ClusteringWillIntroduceInterDeviceDependency( const Cluster& from, const Cluster& to); bool ShouldCompile(bool is_xla_compile_attr_true, @@ -352,8 +352,8 @@ class MarkForCompilationPassImpl { // Returns true if the devices in `cluster_a` and `cluster_b` are compatible // and therefore not a hindrance for combining the two clusters into a larger // cluster. - StatusOr AreDevicesCompatible(const Cluster& cluster_a, - const Cluster& cluster_b); + absl::StatusOr AreDevicesCompatible(const Cluster& cluster_a, + const Cluster& cluster_b); void DumpPostClusteringGraphs(); void VLogClusteringSummary(); @@ -637,7 +637,7 @@ Status IgnoreResourceOpForSafetyAnalysis( if (n.assigned_device_name().empty()) { *ignore = false; - return OkStatus(); + return absl::OkStatus(); } TF_ASSIGN_OR_RETURN( @@ -649,10 +649,10 @@ Status IgnoreResourceOpForSafetyAnalysis( } else { *ignore = registration->cluster_resource_variable_ops_unsafely; } - return OkStatus(); + return absl::OkStatus(); } -StatusOr MarkForCompilationPassImpl::Initialize() { +absl::StatusOr MarkForCompilationPassImpl::Initialize() { TF_RET_CHECK(!initialized_ && !edges_contracted_ && !clusters_created_); initialized_ = true; @@ -690,7 +690,8 @@ StatusOr MarkForCompilationPassImpl::Initialize() { } template -StatusOr MarkForCompilationPassImpl::ForEachEdgeInPostOrder(FnTy fn) { +absl::StatusOr MarkForCompilationPassImpl::ForEachEdgeInPostOrder( + FnTy fn) { bool changed = false; for (int32_t node : cycles_graph_.AllNodesInPostOrder()) { Cluster* cluster_from = GetClusterForCyclesGraphNode(node); @@ -799,7 +800,8 @@ Status MarkForCompilationPassImpl::RunEdgeContractionLoop() { VLOG(4) << "Running phase 0"; TF_RETURN_IF_ERROR( - ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) -> StatusOr { + ForEachEdgeInPostOrder([&](Cluster* from, + Cluster* to) -> absl::StatusOr { // Shape consuming operations are desirable to cluster with their // operands because they return a small set of scalar values after // consuming a large amount of data. For example, given a graph X -> Y @@ -822,7 +824,8 @@ Status MarkForCompilationPassImpl::RunEdgeContractionLoop() { VLOG(4) << "Running phase 1"; TF_RETURN_IF_ERROR( - ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) -> StatusOr { + ForEachEdgeInPostOrder([&](Cluster* from, + Cluster* to) -> absl::StatusOr { // We split out this phase to get good clustering in the presence of a // specific pattern seen in some graphs: // @@ -892,7 +895,7 @@ Status MarkForCompilationPassImpl::RunEdgeContractionLoop() { })); TF_RET_CHECK(!changed); - return OkStatus(); + return absl::OkStatus(); } Status MarkForCompilationPassImpl::DeclusterNodes() { @@ -922,7 +925,7 @@ Status MarkForCompilationPassImpl::DeclusterNodes() { } } - return OkStatus(); + return absl::OkStatus(); } // Tracks monotonic sequence numbers for graphs. @@ -1010,7 +1013,7 @@ Status MarkForCompilationPassImpl::CreateClusters() { } } - return OkStatus(); + return absl::OkStatus(); } Status MarkForCompilationPassImpl::DumpDebugInfo() { @@ -1022,10 +1025,10 @@ Status MarkForCompilationPassImpl::DumpDebugInfo() { VLogClusteringSummary(); - return OkStatus(); + return absl::OkStatus(); } -StatusOr +absl::StatusOr MarkForCompilationPassImpl::ClusteringWillIntroduceInterDeviceDependency( const Cluster& cluster_from, const Cluster& cluster_to) { // If any of the consumer's producers are on a different device, do not @@ -1181,10 +1184,10 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet() { cluster_for_node_[node->id()].Get() = new_cluster; } - return OkStatus(); + return absl::OkStatus(); } -StatusOr IsIdentityDrivingConstsInLoop(Node* node) { +absl::StatusOr IsIdentityDrivingConstsInLoop(Node* node) { if (!node->IsIdentity()) { return false; } @@ -1475,7 +1478,7 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() { VLOG(2) << "compilation_candidates_.size() = " << compilation_candidates_.size(); - return OkStatus(); + return absl::OkStatus(); } bool MarkForCompilationPassImpl::CompilationDisallowedByXlaCompileAttr( @@ -1513,8 +1516,8 @@ bool MarkForCompilationPassImpl::LogNotContractableAndReturnFalse( return false; } -StatusOr MarkForCompilationPassImpl::TryToContractEdge(Cluster* from, - Cluster* to) { +absl::StatusOr MarkForCompilationPassImpl::TryToContractEdge( + Cluster* from, Cluster* to) { DCHECK(from->deadness_predicate().has_value() == to->deadness_predicate().has_value()); if (from->deadness_predicate() != to->deadness_predicate()) { @@ -1596,7 +1599,7 @@ Status MarkForCompilationPassImpl::Run() { if (!initialized) { // Initialization exited early which means this instance of // MarkForCompilationPassImpl is not set up to run the subsequent phases. - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR(RunEdgeContractionLoop()); @@ -1604,7 +1607,7 @@ Status MarkForCompilationPassImpl::Run() { TF_RETURN_IF_ERROR(CreateClusters()); TF_RETURN_IF_ERROR(DumpDebugInfo()); - return OkStatus(); + return absl::OkStatus(); } void MarkForCompilationPassImpl::DumpPostClusteringGraphs() { @@ -1756,7 +1759,7 @@ void MarkForCompilationPassImpl::VLogClusteringSummary() { } } -StatusOr MarkForCompilationPassImpl::AreDevicesCompatible( +absl::StatusOr MarkForCompilationPassImpl::AreDevicesCompatible( const Cluster& cluster_a, const Cluster& cluster_b) { DeviceSet devices = cluster_a.devices(); devices.UnionWith(cluster_b.devices()); @@ -1787,7 +1790,7 @@ StatusOr MarkForCompilationPassImpl::AreDevicesCompatible( } // Returns `true` iff we should compile `cluster`. -StatusOr MarkForCompilationPassImpl::ShouldCompileClusterImpl( +absl::StatusOr MarkForCompilationPassImpl::ShouldCompileClusterImpl( const Cluster& cluster) { TF_ASSIGN_OR_RETURN(DeviceId chosen_device, PickDeviceForXla(device_info_cache_, cluster.devices(), @@ -1838,7 +1841,7 @@ proper command-line flag, not via TF_XLA_FLAGS).)"; return should_compile; } -StatusOr MarkForCompilationPassImpl::ShouldCompileCluster( +absl::StatusOr MarkForCompilationPassImpl::ShouldCompileCluster( const Cluster& cluster) { auto it = should_compile_cluster_cache_.find(&cluster); if (it != should_compile_cluster_cache_.end()) { @@ -1864,14 +1867,14 @@ Status MarkForCompilation( for (Node* n : graph->nodes()) { // See explanation on `kXlaAlreadyClustered`. if (n->attrs().Find(kXlaAlreadyClustered)) { - return OkStatus(); + return absl::OkStatus(); } // Skip the pass if we found TPUExecute or TPUExecuteAndUpdateVariables ops // in the graph, which indicates the graph is produced by TPU TF-XLA bridge // and doesn't require auto clustering. if (n->type_string() == "TPUExecute" || n->type_string() == "TPUExecuteAndUpdateVariables") { - return OkStatus(); + return absl::OkStatus(); } } @@ -2277,6 +2280,7 @@ absl::flat_hash_set GetKnownXLAAllowlistOp() { "VariableShape", "Where", "While", + "XlaAllReduce", "XlaBroadcastHelper", "XlaCallModule", "XlaConcatND", @@ -2298,6 +2302,7 @@ absl::flat_hash_set GetKnownXLAAllowlistOp() { "XlaRecv", "XlaReduce", "XlaReducePrecision", + "XlaReduceScatter", "XlaReduceWindow", "XlaRemoveDynamicDimensionSize", "XlaReplicaId", diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index b9527ae9ec56b0..aabedf61202d3f 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -390,7 +390,7 @@ static Status GradForUnaryCwise(FunctionDef* g, {}, // Nodes nodes); - return OkStatus(); + return absl::OkStatus(); } // A gradient containing only supported operators @@ -1816,7 +1816,7 @@ TEST(XlaCompilationTest, DeterministicClusterNames) { " rhs: ", rhs_cluster_name); } - return OkStatus(); + return absl::OkStatus(); }; testing::ResetClusterSequenceNumber(); diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc index cf0e03bc9f64e5..7c370e46dec63f 100644 --- a/tensorflow/compiler/jit/ops/xla_ops.cc +++ b/tensorflow/compiler/jit/ops/xla_ops.cc @@ -55,7 +55,7 @@ REGISTER_OP("XlaClusterOutput") for (int i = 0; i < c->num_outputs(); ++i) { c->set_output(i, c->input(0)); } - return OkStatus(); + return absl::OkStatus(); }) .Doc( "Operator that connects the output of an XLA computation to other " @@ -112,7 +112,7 @@ REGISTER_OP("_XlaMerge") .Attr("T: type") .SetShapeFn([](InferenceContext* c) { c->set_output(0, c->input(0)); - return OkStatus(); + return absl::OkStatus(); }) .Doc(R"(XLA Merge Op. For use by the XLA JIT only. diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index 715cc1b31738b5..eb66a8d905cc8c 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -113,7 +113,7 @@ Status FindNodesToDecluster(const Graph& graph, } } } - return OkStatus(); + return absl::OkStatus(); } Status PartiallyDeclusterNode(Graph* graph, Node* n) { @@ -156,7 +156,7 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) { graph->RemoveNode(n); } - return OkStatus(); + return absl::OkStatus(); } // Clones nodes to outside their cluster to avoid device-to-host copies. For @@ -221,7 +221,7 @@ Status PartiallyDeclusterGraph(Graph* graph) { FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order)); CHECK(nodes_to_partially_decluster.empty()); - return OkStatus(); + return absl::OkStatus(); } } // namespace reduce_device_to_host_copies @@ -251,12 +251,12 @@ Status MustCompileNode(const Node* n, bool* must_compile) { if (IsMustCompileDevice(device_type)) { *must_compile = true; - return OkStatus(); + return absl::OkStatus(); } // We must compile `n` if it does not have a TensorFlow kernel. *must_compile = !FindKernelDef(device_type, n->def(), nullptr, nullptr).ok(); - return OkStatus(); + return absl::OkStatus(); } // Declusters nodes to reduce the number of times we think we need to recompile @@ -363,7 +363,7 @@ Status PartiallyDeclusterGraph(Graph* graph, } } - return OkStatus(); + return absl::OkStatus(); } } // namespace reduce_recompilation @@ -397,7 +397,7 @@ Status PartiallyDeclusterGraph(Graph* graph) { << " because it is a root shape consumer"; RemoveFromXlaCluster(n); } - return OkStatus(); + return absl::OkStatus(); } } // namespace decluster_root_shape_consumers } // namespace @@ -430,6 +430,6 @@ Status PartiallyDeclusterPass::Run( TF_RETURN_IF_ERROR( decluster_root_shape_consumers::PartiallyDeclusterGraph(graph)); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/pjrt_base_device.cc b/tensorflow/compiler/jit/pjrt_base_device.cc index d7c12921c7131c..ce7ed954575040 100644 --- a/tensorflow/compiler/jit/pjrt_base_device.cc +++ b/tensorflow/compiler/jit/pjrt_base_device.cc @@ -43,7 +43,7 @@ PjRtBaseDevice::PjRtBaseDevice(const SessionOptions& session_options, << " device_name: " << name(); } -/*static*/ StatusOr +/*static*/ absl::StatusOr PjRtBaseDevice::GetMetadataFromDevice(DeviceBase* device) { PjRtBaseDevice* pjrt_device = dynamic_cast(device->UnderlyingDevice()); diff --git a/tensorflow/compiler/jit/pjrt_base_device.h b/tensorflow/compiler/jit/pjrt_base_device.h index e8f4ac7cbe7061..b21357455c283b 100644 --- a/tensorflow/compiler/jit/pjrt_base_device.h +++ b/tensorflow/compiler/jit/pjrt_base_device.h @@ -99,7 +99,7 @@ class PjRtBaseDevice : public LocalDevice { // Creates a new PJRT base device. PjRtBaseDevice(const SessionOptions& session_options, const Options& options); - static StatusOr GetMetadataFromDevice( + static absl::StatusOr GetMetadataFromDevice( DeviceBase* device); private: diff --git a/tensorflow/compiler/jit/pjrt_device_compiler_client.cc b/tensorflow/compiler/jit/pjrt_device_compiler_client.cc index 1ce1f57f923708..f64468fd2d255b 100644 --- a/tensorflow/compiler/jit/pjrt_device_compiler_client.cc +++ b/tensorflow/compiler/jit/pjrt_device_compiler_client.cc @@ -39,7 +39,7 @@ xla::CompileOptions GetPjRtCompileOptions( return pjrt_compile_options; } -StatusOr> +absl::StatusOr> PjRtDeviceCompilerClient::BuildExecutable( const XlaCompiler::Options& options, const XlaCompiler::CompilationResult& result) { @@ -56,13 +56,13 @@ PjRtDeviceCompilerClient::BuildExecutable( return std::move(executable); } -StatusOr PjRtDeviceCompilerClient::SerializeExecutable( +absl::StatusOr PjRtDeviceCompilerClient::SerializeExecutable( const xla::PjRtLoadedExecutable& executable) { VLOG(1) << "Serializing xla::PjRtLoadedExecutable to string."; return executable.SerializeExecutable(); } -StatusOr PjRtDeviceCompilerClient::BuildSerializedExecutable( +absl::StatusOr PjRtDeviceCompilerClient::BuildSerializedExecutable( const XlaCompiler::Options& options, const XlaCompiler::CompilationResult& result) { VLOG(1) << "PJRT currently doesn't support AOT compilation. Compiling to " @@ -71,7 +71,7 @@ StatusOr PjRtDeviceCompilerClient::BuildSerializedExecutable( return executable->SerializeExecutable(); } -StatusOr> +absl::StatusOr> PjRtDeviceCompilerClient::LoadExecutable( const XlaCompiler::Options& options, const XlaCompiler::CompilationResult& result, diff --git a/tensorflow/compiler/jit/pjrt_device_compiler_client.h b/tensorflow/compiler/jit/pjrt_device_compiler_client.h index 73b898361f969a..8c590b57d702a0 100644 --- a/tensorflow/compiler/jit/pjrt_device_compiler_client.h +++ b/tensorflow/compiler/jit/pjrt_device_compiler_client.h @@ -33,19 +33,19 @@ class PjRtDeviceCompilerClient explicit PjRtDeviceCompilerClient(xla::PjRtClient* client) : client_(client) {} - StatusOr> BuildExecutable( + absl::StatusOr> BuildExecutable( const XlaCompiler::Options& options, const XlaCompiler::CompilationResult& result) override; // Returns a platform-specific serialization of `executable`. The // serialization is not guaranteed to be stable over time. `executable` must // have been produced by this client. - StatusOr SerializeExecutable( + absl::StatusOr SerializeExecutable( const xla::PjRtLoadedExecutable& executable) override; // PjRt doesn't support AOT compilation yet. Builds a PjRtLoadedExecutable and // serializes it to string. - StatusOr BuildSerializedExecutable( + absl::StatusOr BuildSerializedExecutable( const XlaCompiler::Options& options, const XlaCompiler::CompilationResult& result) override; @@ -57,7 +57,7 @@ class PjRtDeviceCompilerClient // is currently only implemented for TfrtTpuPjrtClient and hence, this // function doesn't use PjRtClient::LoadSerializedExecutable() and uses // PjRtClient::DeserializeExecutable() instead. - StatusOr> LoadExecutable( + absl::StatusOr> LoadExecutable( const XlaCompiler::Options& options, const XlaCompiler::CompilationResult& result, const std::string& serialized_executable) override; diff --git a/tensorflow/compiler/jit/pjrt_device_context.cc b/tensorflow/compiler/jit/pjrt_device_context.cc index cd8d231af19f5b..3065a34e5a10b7 100644 --- a/tensorflow/compiler/jit/pjrt_device_context.cc +++ b/tensorflow/compiler/jit/pjrt_device_context.cc @@ -38,7 +38,7 @@ limitations under the License. namespace tensorflow { namespace { -StatusOr> HostTensorToPjRtBuffer( +absl::StatusOr> HostTensorToPjRtBuffer( const tensorflow::Tensor* cpu_tensor, tensorflow::Device* device, xla::PjRtClient* pjrt_client, const XlaShapeLayoutHelpers::ShapeDeterminationFns @@ -96,7 +96,7 @@ void PjRtDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, profiler::TraceMe traceme("PjRtDeviceContext::CopyDeviceTensorToCPU"); if (device_tensor->NumElements() == 0) { VLOG(2) << "CopyDeviceTensorToCPU empty tensor"; - done(OkStatus()); + done(absl::OkStatus()); return; } auto literal = std::make_unique(); @@ -149,19 +149,20 @@ void PjRtDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, profiler::TraceMe traceme("PjRtDeviceContext::CopyCPUTensorToDevice"); if (cpu_tensor->NumElements() == 0) { VLOG(2) << "CopyCPUTensorToDevice empty tensor"; - done(OkStatus()); + done(absl::OkStatus()); return; } // TODO(b/252887149): figure out how to cache PJRT client. - StatusOr pjrt_client = + absl::StatusOr pjrt_client = GetOrCreatePjRtClient(DeviceType(device->device_type())); if (!pjrt_client.ok()) { done(pjrt_client.status()); return; } - StatusOr> buffer_or = HostTensorToPjRtBuffer( - cpu_tensor, device, *pjrt_client, shape_determination_fns_); + absl::StatusOr> buffer_or = + HostTensorToPjRtBuffer(cpu_tensor, device, *pjrt_client, + shape_determination_fns_); if (!buffer_or.ok()) { done(buffer_or.status()); return; @@ -171,7 +172,7 @@ void PjRtDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, if (use_pjrt_tensor_buffer_) { // Copy the newly created tensor with PjRtTensorBuffer to output device // tensor. - StatusOr t = MakeTensorFromPjRtBuffer( + absl::StatusOr t = MakeTensorFromPjRtBuffer( device_tensor->dtype(), device_tensor->shape(), std::move(*buffer_or)); if (!t.ok()) { done(t.status()); @@ -203,13 +204,15 @@ void PjRtDeviceContext::CopyTensorInSameDevice(const Tensor* input_tensor, } // TODO(b/288585098): consider whether to support same device copy in PJRT // API. - StatusOr c_src_buffer = GetPjRtCBufferFromTensor(input_tensor); + absl::StatusOr c_src_buffer = + GetPjRtCBufferFromTensor(input_tensor); if (!c_src_buffer.ok()) { done(c_src_buffer.status()); return; } - StatusOr c_api_client = tensorflow::GetPjRtCApiClient( - tensorflow::DeviceType(device->device_type())); + absl::StatusOr c_api_client = + tensorflow::GetPjRtCApiClient( + tensorflow::DeviceType(device->device_type())); if (!c_api_client.ok()) { done(c_api_client.status()); return; @@ -243,11 +246,11 @@ void PjRtDeviceToDeviceCopy(DeviceContext* send_dev_context, profiler::TraceMe traceme("PjRtDevice_DeviceToDeviceCopy"); if (input->NumElements() == 0) { VLOG(2) << "PjRtDevice_DeviceToDeviceCopy empty tensor"; - done(OkStatus()); + done(absl::OkStatus()); return; } - StatusOr pjrt_dst_client = + absl::StatusOr pjrt_dst_client = GetOrCreatePjRtClient(DeviceType(dst->device_type())); if (!pjrt_dst_client.ok()) { @@ -267,7 +270,7 @@ void PjRtDeviceToDeviceCopy(DeviceContext* send_dev_context, xla::PjRtDevice* pjrt_dst_device = (*pjrt_dst_client)->LookupAddressableDevice(pjrt_dst_device_id).value(); - StatusOr> buffer_or = + absl::StatusOr> buffer_or = src_device_buffer->CopyToDevice(pjrt_dst_device); if (!buffer_or.ok()) { done(buffer_or.status()); @@ -280,7 +283,7 @@ void PjRtDeviceToDeviceCopy(DeviceContext* send_dev_context, ->use_pjrt_tensor_buffer()) { // Copy the newly created tensor with PjRtTensorBuffer to output device // tensor. - StatusOr t = MakeTensorFromPjRtBuffer( + absl::StatusOr t = MakeTensorFromPjRtBuffer( output->dtype(), output->shape(), std::move(*buffer_or)); if (!t.ok()) { done(t.status()); diff --git a/tensorflow/compiler/jit/rearrange_function_argument_pass_test.cc b/tensorflow/compiler/jit/rearrange_function_argument_pass_test.cc index 9cf49b0666ddd5..ffbcef3371ae81 100644 --- a/tensorflow/compiler/jit/rearrange_function_argument_pass_test.cc +++ b/tensorflow/compiler/jit/rearrange_function_argument_pass_test.cc @@ -118,7 +118,7 @@ TEST(RearrangeFunctionArgumentForFunctionTest, Basic) { &fld, &new_fbody)); *fbody = new_fbody.get(); fbodies.push_back(std::move(new_fbody)); - return OkStatus(); + return absl::OkStatus(); }, g.get(), &fld)); @@ -229,7 +229,7 @@ TEST(RearrangeFunctionArgumentForFunctionTest, &fld, &new_fbody)); *fbody = new_fbody.get(); fbodies.push_back(std::move(new_fbody)); - return OkStatus(); + return absl::OkStatus(); }, g.get(), &fld); EXPECT_EQ(status.code(), error::UNIMPLEMENTED); diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc index 650369863c4d05..92f79dde874217 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc @@ -106,13 +106,13 @@ Status XlaResourceOpKindForNode( } if (should_ignore) { *out_resource_op_kind = std::nullopt; - return OkStatus(); + return absl::OkStatus(); } const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n.type_string()); if (op_info) { *out_resource_op_kind = op_info->kind(); - return OkStatus(); + return absl::OkStatus(); } // We conservatively assume that functions will both read and write resource @@ -124,7 +124,7 @@ Status XlaResourceOpKindForNode( *out_resource_op_kind = std::nullopt; } - return OkStatus(); + return absl::OkStatus(); } // Returns true if a control or data dependence from a TensorFlow operation of @@ -314,6 +314,6 @@ Status ComputeIncompatibleResourceOperationPairs( std::sort(result->begin(), result->end()); CHECK(std::unique(result->begin(), result->end()) == result->end()); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/shape_inference.cc b/tensorflow/compiler/jit/shape_inference.cc index b9cbd1e3105b13..1848998da26116 100644 --- a/tensorflow/compiler/jit/shape_inference.cc +++ b/tensorflow/compiler/jit/shape_inference.cc @@ -33,7 +33,7 @@ Status ShapeHandleToTensorShape(shape_inference::InferenceContext* context, const shape_inference::ShapeHandle& handle, PartialTensorShape* shape) { // The default is already unknown - if (!context->RankKnown(handle)) return OkStatus(); + if (!context->RankKnown(handle)) return absl::OkStatus(); std::vector dims(context->Rank(handle)); for (int32_t i = 0, end = dims.size(); i < end; ++i) { @@ -199,7 +199,7 @@ Status PropagateShapes(Graph* graph, } } } - return OkStatus(); + return absl::OkStatus(); } // Store the shapes of the output tensors in a map @@ -235,7 +235,7 @@ Status StoreOutputShapes(const Graph& graph, const ShapeRefiner& shape_refiner, << output.handle_shape.DebugString(); } } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -267,8 +267,8 @@ Status InferShapes(Graph* graph, const std::map& arg_shapes, return StoreOutputShapes(*graph, shape_refiner, shape_info); } -StatusOr MergeInferredShapes(const InferredShape& a, - const InferredShape& b) { +absl::StatusOr MergeInferredShapes(const InferredShape& a, + const InferredShape& b) { InferredShape result; TF_RETURN_IF_ERROR(a.shape.MergeWith(b.shape, &result.shape)); diff --git a/tensorflow/compiler/jit/shape_inference.h b/tensorflow/compiler/jit/shape_inference.h index 48a9cfd79d8497..2d6322644b9e12 100644 --- a/tensorflow/compiler/jit/shape_inference.h +++ b/tensorflow/compiler/jit/shape_inference.h @@ -46,8 +46,8 @@ Status InferShapes(Graph* graph, const std::map& arg_shapes, // Merges two InferredShapes. Return an error if the two shapes cannot be // merged. -StatusOr MergeInferredShapes(const InferredShape& a, - const InferredShape& b); +absl::StatusOr MergeInferredShapes(const InferredShape& a, + const InferredShape& b); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/shape_inference_helpers.cc b/tensorflow/compiler/jit/shape_inference_helpers.cc index f3dd0c7ec78453..9290861d48f0bc 100644 --- a/tensorflow/compiler/jit/shape_inference_helpers.cc +++ b/tensorflow/compiler/jit/shape_inference_helpers.cc @@ -41,7 +41,7 @@ Status BackEdgeHelper::Remove(Graph* graph) { for (const BackEdge& be : back_edges_) { graph_->RemoveEdge(be.edge); } - return OkStatus(); + return absl::OkStatus(); } const std::vector& BackEdgeHelper::RemovedEdges() @@ -60,7 +60,7 @@ Status BackEdgeHelper::Replace() { for (const BackEdge& be : back_edges_) { graph_->AddEdge(be.src, be.src_output, be.dst, be.dst_input); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/test_util.cc b/tensorflow/compiler/jit/test_util.cc index 41afd63cca3b1e..f073902bc03d4a 100644 --- a/tensorflow/compiler/jit/test_util.cc +++ b/tensorflow/compiler/jit/test_util.cc @@ -58,7 +58,7 @@ Status ShapeAnnotationsMatch( return errors::InvalidArgument("Missing shapes for nodes: ", absl::StrJoin(missing, ",")); } - return OkStatus(); + return absl::OkStatus(); } void DeviceSetup::AddDevicesAndSetUp( diff --git a/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc b/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc index e93af0df217e84..2a761c3e5a57c6 100644 --- a/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc +++ b/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc @@ -33,7 +33,8 @@ limitations under the License. namespace tensorflow { namespace { -StatusOr SummarizeClustering(const GraphDef& auto_clustered_graph_def) { +absl::StatusOr SummarizeClustering( + const GraphDef& auto_clustered_graph_def) { testing::ResetClusterSequenceNumber(); Graph graph(OpRegistry::Global()); GraphConstructorOptions graph_opts; @@ -95,7 +96,7 @@ Status AssertGraphDefIsUnclustered(const GraphDef& graphdef) { } } - return OkStatus(); + return absl::OkStatus(); } Status ReadTextProtoFromString(Env* env, const string& data, @@ -103,7 +104,7 @@ Status ReadTextProtoFromString(Env* env, const string& data, if (!::tensorflow::protobuf::TextFormat::ParseFromString(data, proto)) { return errors::DataLoss("Can't parse input data as text proto"); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -121,7 +122,7 @@ Status AutoClusteringTest::RunAutoClusteringTestImpl( LOG(INFO) << "Not running " << ::testing::UnitTest::GetInstance()->current_test_info()->name() << " since test was not built with --config=cuda"; - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR(AssertGraphDefIsUnclustered(graphdef)); @@ -158,7 +159,7 @@ Status AutoClusteringTest::RunAutoClusteringTestImpl( EXPECT_EQ(golden_file_contents, clustering_summary); - return OkStatus(); + return absl::OkStatus(); } Status AutoClusteringTest::RunAutoClusteringTestWithPbtxt( @@ -221,7 +222,7 @@ Status BenchmarkMarkForCompilation(absl::string_view graph_def_path, std::move(graph_def_copy), &result)); } - return OkStatus(); + return absl::OkStatus(); } #endif // PLATFORM_GOOGLE diff --git a/tensorflow/compiler/jit/tests/device_compiler_test_helper.cc b/tensorflow/compiler/jit/tests/device_compiler_test_helper.cc index c1ae143ee94508..4eb93e85819651 100644 --- a/tensorflow/compiler/jit/tests/device_compiler_test_helper.cc +++ b/tensorflow/compiler/jit/tests/device_compiler_test_helper.cc @@ -131,7 +131,7 @@ Status DeviceCompilerSerializeTest::ExecuteWithBatch(const GraphDef& graph, EXPECT_NEAR(golden_output_tensors[0].flat()(i), output_tensors[0].flat()(i), 1e-3); } - return OkStatus(); + return absl::OkStatus(); } Status DeviceCompilerSerializeTest::AlterPersistentCacheEntryHloModuleNames( @@ -160,7 +160,7 @@ Status DeviceCompilerSerializeTest::AlterPersistentCacheEntryHloModuleNames( return errors::NotFound( "Did not find any persistent XLA compilation cache entries to alter."); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/tests/device_compiler_test_helper.h b/tensorflow/compiler/jit/tests/device_compiler_test_helper.h index e8ae70928d17d4..9cf36d0cbc6cb1 100644 --- a/tensorflow/compiler/jit/tests/device_compiler_test_helper.h +++ b/tensorflow/compiler/jit/tests/device_compiler_test_helper.h @@ -33,17 +33,17 @@ class JitCompilationListener : public XlaActivityListener { public: Status Listen( const XlaAutoClusteringActivity& auto_clustering_activity) override { - return OkStatus(); + return absl::OkStatus(); } Status Listen( const XlaJitCompilationActivity& jit_compilation_activity) override { activity_history_.push_back(jit_compilation_activity); - return OkStatus(); + return absl::OkStatus(); } Status Listen(const XlaOptimizationRemark& optimization_remark) override { - return OkStatus(); + return absl::OkStatus(); } ~JitCompilationListener() override = default; @@ -55,7 +55,7 @@ class JitCompilationListener : public XlaActivityListener { return absl::FailedPreconditionError("Unexpected listener history."); } } - return OkStatus(); + return absl::OkStatus(); } std::vector GetListenerHistory() { diff --git a/tensorflow/compiler/jit/variable_info_util.cc b/tensorflow/compiler/jit/variable_info_util.cc index 7e5d7076fd330d..315d5d63c73fc7 100644 --- a/tensorflow/compiler/jit/variable_info_util.cc +++ b/tensorflow/compiler/jit/variable_info_util.cc @@ -73,7 +73,7 @@ Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev, handle.container(), handle.name(), &variable, [](Var** ptr) { // This var is uninitialized for now. *ptr = new Var(DT_INVALID); - return OkStatus(); + return absl::OkStatus(); })); VariableInfo& variable_info = result->emplace_back( var_idx, handle.name(), variable, handle.definition_stack_trace()); @@ -82,7 +82,7 @@ Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev, variable_info.set_read_only(); } } - return OkStatus(); + return absl::OkStatus(); } Status LockVariables(absl::Span variables) { @@ -134,7 +134,7 @@ Status LockVariables(absl::Span variables) { prev = mu; } VLOG(4) << "Finished acquiring variable locks."; - return OkStatus(); + return absl::OkStatus(); } Status LockVariables(absl::Span variables) { @@ -155,7 +155,7 @@ Status SnapshotResourceVariables(OpKernelContext* ctx, (*result)[variable_indices[i]] = var ? std::make_optional(*var->tensor()) : std::nullopt; } - return OkStatus(); + return absl::OkStatus(); } std::vector GetResourceVariableIndicesFromContext(OpKernelContext* ctx) { @@ -179,7 +179,7 @@ Status CreateVariableInfoLookup( } variable_info_lookup.emplace(info.index(), &info); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_activity_listener.cc b/tensorflow/compiler/jit/xla_activity_listener.cc index 476ed344b5fc4a..c3df741fbb08e2 100644 --- a/tensorflow/compiler/jit/xla_activity_listener.cc +++ b/tensorflow/compiler/jit/xla_activity_listener.cc @@ -48,13 +48,13 @@ Status ForEachListener(FnTy fn) { TF_RETURN_IF_ERROR(fn(listener.get())); } - return OkStatus(); + return absl::OkStatus(); } void FlushAllListeners() { Status s = ForEachListener([](XlaActivityListener* listener) { listener->Flush(); - return OkStatus(); + return absl::OkStatus(); }); CHECK(s.ok()); } diff --git a/tensorflow/compiler/jit/xla_activity_listener_test.cc b/tensorflow/compiler/jit/xla_activity_listener_test.cc index adf672b619a498..ee58c280d66d80 100644 --- a/tensorflow/compiler/jit/xla_activity_listener_test.cc +++ b/tensorflow/compiler/jit/xla_activity_listener_test.cc @@ -34,17 +34,17 @@ class TestListener : public XlaActivityListener { Status Listen( const XlaAutoClusteringActivity& auto_clustering_activity) override { auto_clustering_activity_ = auto_clustering_activity; - return OkStatus(); + return absl::OkStatus(); } Status Listen( const XlaJitCompilationActivity& jit_compilation_activity) override { jit_compilation_activity_ = jit_compilation_activity; - return OkStatus(); + return absl::OkStatus(); } Status Listen(const XlaOptimizationRemark& optimization_remark) override { - return OkStatus(); + return absl::OkStatus(); } ~TestListener() override {} diff --git a/tensorflow/compiler/jit/xla_activity_logging_listener.cc b/tensorflow/compiler/jit/xla_activity_logging_listener.cc index 021d60ab77f511..20262548e8bc2b 100644 --- a/tensorflow/compiler/jit/xla_activity_logging_listener.cc +++ b/tensorflow/compiler/jit/xla_activity_logging_listener.cc @@ -27,29 +27,29 @@ class XlaActivityLoggingListener final : public XlaActivityListener { const XlaAutoClusteringActivity& auto_clustering_activity) override { if (!IsEnabled()) { VLOG(3) << "Logging XlaAutoClusteringActivity disabled"; - return OkStatus(); + return absl::OkStatus(); } - return OkStatus(); + return absl::OkStatus(); } Status Listen( const XlaJitCompilationActivity& jit_compilation_activity) override { if (!IsEnabled()) { VLOG(3) << "Logging XlaJitCompilationActivity disabled"; - return OkStatus(); + return absl::OkStatus(); } - return OkStatus(); + return absl::OkStatus(); } Status Listen(const XlaOptimizationRemark& optimization_remark) override { if (!IsEnabled()) { VLOG(3) << "Logging XlaJitCompilationActivity disabled"; - return OkStatus(); + return absl::OkStatus(); } - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index e5c499b7b81e83..7bbaf158d7f368 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -557,7 +557,7 @@ Status GetNodesRelatedToRefVariablesInDirection( VLOG(2) << "# iterations = " << iterations; - return OkStatus(); + return absl::OkStatus(); } // Sorts control inputs of a graphdef so that they are deterministically diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index 22dd2662b79984..241786fefbe108 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -55,8 +55,8 @@ bool HasForwardedRefInput(const Node& node); // // Returns true for success and false for valid graphs that we can't handle yet // (b/127521408). -StatusOr CreateCycleDetectionGraph(const Graph* graph, - GraphCycles* cycles); +absl::StatusOr CreateCycleDetectionGraph(const Graph* graph, + GraphCycles* cycles); // Returns the XLA cluster in which `node` is placed if it is in an XLA cluster, // otherwise returns nullopt. @@ -97,16 +97,16 @@ XlaAutoClusteringSummary GetXlaAutoClusteringSummary(const Graph& graph); // // We assume each node has a trivial path to itself so the returned set includes // all of the nodes that have ref variables as input or output. -StatusOr> GetNodesRelatedToRefVariables( +absl::StatusOr> GetNodesRelatedToRefVariables( const Graph& graph, FunctionLibraryRuntime* lib_runtime); // Deterministically serialized the graph to a byte string. -StatusOr SerializeGraphDeterministic(const Graph& graph); +absl::StatusOr SerializeGraphDeterministic(const Graph& graph); // Computes a fingerprint of the given `graph`. The fingerprint can use used to // check if two graphs are likely the same but should not be relied on // determining if the graphs are identical. -StatusOr FingerprintGraph(const Graph& graph); +absl::StatusOr FingerprintGraph(const Graph& graph); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cluster_util_test.cc b/tensorflow/compiler/jit/xla_cluster_util_test.cc index 863e92a86b97bb..44ce0d68365bd2 100644 --- a/tensorflow/compiler/jit/xla_cluster_util_test.cc +++ b/tensorflow/compiler/jit/xla_cluster_util_test.cc @@ -135,7 +135,7 @@ TEST(IsSingleGpuGraph, ReturnsFalseForMultiGpuGraph) { EXPECT_FALSE(IsSingleGpuGraph(*root.graph())); } -StatusOr> GetNodesRelatedToRefVarsSorted( +absl::StatusOr> GetNodesRelatedToRefVarsSorted( const Scope& scope, FunctionLibraryDefinition* flib_def = nullptr) { FunctionDefLibrary flib; FunctionLibraryDefinition flib_def_local(OpRegistry::Global(), flib); diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index 097b1dbd1b9317..05b1cec191a56b 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -99,7 +99,7 @@ Status GetAndLockVariablesAndBuildXlaCompilerArguments( XlaComputationLaunchContext::BuildXlaCompilerArguments( constant_indices, inputs, *variables, static_cast(ctx.device()))); - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -164,7 +164,7 @@ Status XlaCompileOnDemandOp::Run(const ResourceVarsSnapshot& variable_args, ctx, result, execution_output.ConsumeResult(), /*missing_ctx_input_prefix=*/0, absl::MakeSpan(*variable_infos), input_output_alias, snapshot_ptrs)); - return OkStatus(); + return absl::OkStatus(); } Status XlaCompileOnDemandOp::Compile( @@ -215,7 +215,7 @@ Status XlaCompileOnDemandOp::Compile( rm->default_container(), "device_compilation_profiler", profiler, [](DeviceCompilationProfiler** profiler) { *profiler = new DeviceCompilationProfiler(); - return OkStatus(); + return absl::OkStatus(); })); XlaCompiler::Options options = GenerateCompilerOptions( diff --git a/tensorflow/compiler/jit/xla_compile_util.cc b/tensorflow/compiler/jit/xla_compile_util.cc index 52c541b5c1cee7..05feb2c8f36769 100644 --- a/tensorflow/compiler/jit/xla_compile_util.cc +++ b/tensorflow/compiler/jit/xla_compile_util.cc @@ -36,7 +36,7 @@ constexpr const char* kPjRtDeviceCompilationProfilerResourceName = "pjrt_device_compilation_profiler"; } // namespace -StatusOr> CreateSingleOpGraph( +absl::StatusOr> CreateSingleOpGraph( const NodeDef& node_def, absl::Span args, absl::Span result_types) { // TODO(b/74182462): We implement this by creating a new dummy Graph including @@ -95,7 +95,7 @@ std::string GetPjRtDeviceCompilationProfilerResourceName( device_type.type_string()); } -StatusOr GetResourceMgrForDeviceCompiler( +absl::StatusOr GetResourceMgrForDeviceCompiler( const OpKernelContext& ctx, const DeviceType& device_type) { // We store information about the JIT-compiled XLA computation in the // ResourceMgr. The DeviceCompiler (which contains the DeviceCompilationCache) diff --git a/tensorflow/compiler/jit/xla_compile_util.h b/tensorflow/compiler/jit/xla_compile_util.h index 5fa1b40f821c48..d722ba8e784e61 100644 --- a/tensorflow/compiler/jit/xla_compile_util.h +++ b/tensorflow/compiler/jit/xla_compile_util.h @@ -41,7 +41,7 @@ enum class DeviceCompileState { // Creates a single-node graph using the specified `node_def` as the only op // apart from the arg and retval nodes corresponding to `args` and // `result_types` respectively. -StatusOr> CreateSingleOpGraph( +absl::StatusOr> CreateSingleOpGraph( const NodeDef& node_def, absl::Span args, absl::Span result_types); @@ -59,7 +59,7 @@ std::string GetPjRtDeviceCompilationProfilerResourceName( // Gets the ResourceMgr where the DeviceCompiler is/should be stored for the // given `device_type`. -StatusOr GetResourceMgrForDeviceCompiler( +absl::StatusOr GetResourceMgrForDeviceCompiler( const OpKernelContext& ctx, const DeviceType& device_type); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compiler_options_util.cc b/tensorflow/compiler/jit/xla_compiler_options_util.cc index e3750061f07aa3..7b3b63723e12b4 100644 --- a/tensorflow/compiler/jit/xla_compiler_options_util.cc +++ b/tensorflow/compiler/jit/xla_compiler_options_util.cc @@ -102,7 +102,7 @@ XlaCompiler::Options GenerateCompilerOptionsForPjRt( const XlaPlatformInfo& platform_info, const PjRtDeviceCompiler* pjrt_device_compiler) { XlaCompiler::Options options; - StatusOr platform_device_id = + absl::StatusOr platform_device_id = tsl::GetPlatformDeviceIdFromDeviceParsedName( device_base->parsed_name(), DeviceType(tensorflow::down_cast(device_base) diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index 2cd0498a10fb29..4c41805d034a0e 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_device_ops.h" #include "tensorflow/compiler/tf2xla/layout_util.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "xla/stream_executor/platform_manager.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/lib/core/status.h" @@ -45,11 +46,11 @@ Status XlaCpuDeviceFactory::ListPhysicalDevices(std::vector* devices) { if (!flags->tf_xla_enable_xla_devices && !XlaDevicesCreationRequired()) { VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set " "and XLA device creation not requested"; - return OkStatus(); + return absl::OkStatus(); } devices->push_back(absl::StrCat("/physical_device:", DEVICE_XLA_CPU, ":0")); - return OkStatus(); + return absl::OkStatus(); } Status XlaCpuDeviceFactory::CreateDevices( @@ -58,7 +59,7 @@ Status XlaCpuDeviceFactory::CreateDevices( XlaDeviceFlags* flags = GetXlaDeviceFlags(); if (!flags->tf_xla_enable_xla_devices && !XlaDevicesCreationRequired()) { VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set"; - return OkStatus(); + return absl::OkStatus(); } bool compile_on_demand = flags->tf_xla_compile_on_demand; @@ -84,7 +85,7 @@ Status XlaCpuDeviceFactory::CreateDevices( (void)registrations; TF_ASSIGN_OR_RETURN(auto platform, - se::MultiPlatformManager::PlatformWithName("Host")); + se::PlatformManager::PlatformWithName("Host")); XlaDevice::Options options; options.platform = platform; @@ -109,7 +110,7 @@ Status XlaCpuDeviceFactory::CreateDevices( return status; } devices->push_back(std::move(device)); - return OkStatus(); + return absl::OkStatus(); } REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory); diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 6f71814bbe7b6a..c621143f8d9127 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -75,7 +75,7 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) { const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer(); *shape = shaped_buffer.on_device_shape(); - return OkStatus(); + return absl::OkStatus(); } // Caches a XlaDeviceAllocator per pair. A @@ -182,7 +182,7 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { "placed on the wrong device."); } *metadata = &(xla_device->xla_metadata_); - return OkStatus(); + return absl::OkStatus(); } /* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx, @@ -305,7 +305,7 @@ Status XlaDevice::EnsureStreamOkLocked(xla::Backend* backend, << (*stream)->DebugStreamPointers(); *stream_was_changed = true; } - return OkStatus(); + return absl::OkStatus(); } StatusOr> XlaDevice::GetDeviceContextLocked() { @@ -448,7 +448,7 @@ Status XlaDevice::TryGetDeviceContext(DeviceContext** out_context) { TF_ASSIGN_OR_RETURN(auto device_context, GetDeviceContextDefault()); device_context->Ref(); *out_context = device_context; - return OkStatus(); + return absl::OkStatus(); } // Warn about XLA_CPU/XLA_GPU exactly once. @@ -490,7 +490,7 @@ Status XlaDevice::Sync() { mutex_lock lock(mu_); stream = stream_; } - if (!stream) return OkStatus(); + if (!stream) return absl::OkStatus(); Status status = stream->BlockHostUntilDone(); TF_RETURN_IF_ERROR(status); @@ -498,39 +498,7 @@ Status XlaDevice::Sync() { return errors::Internal("XlaDevice::Sync() failed."); } VLOG(1) << "XlaDevice::Sync completed"; - return OkStatus(); -} - -// TODO(b/112409994): This is no longer necessary. Consolidate it with the -// synchronous version. -void XlaDevice::Sync(const DoneCallback& done) { - VLOG(1) << "XlaDevice::Sync (asynchronous)"; - std::shared_ptr stream; - { - mutex_lock lock(mu_); - stream = stream_; - } - if (!stream) { - done(OkStatus()); - return; - } - - // The call to ThenEnqueueOnBackgroundThread below enqueues a host callback at - // the end of the stream, after everything that has already been enqueued - // there at this moment. When the host callback is called, everything before - // it must have already finished, and the host callback will then place the - // task below onto a background thread. (See the implementation of - // ThenEnqueueOnBackgroundThread for details.) Therefore, when the done - // callback is finally called from that background thread, we know for sure - // that everything enqueued onto the stream (i.e., the device) at this very - // moment--when ThenEnqueueOnBackgroundThread is called--will have finished. - // This achieves a device-wide sync. - stream->ThenEnqueueOnBackgroundThread([stream, done](se::StreamExecutor*) { - profiler::TraceMe activity("XlaDevice::Sync::Callback", - profiler::TraceMeLevel::kInfo); - done(stream->ok() ? OkStatus() - : errors::Internal("XlaDevice::Sync() failed.")); - }); + return absl::OkStatus(); } Status XlaDevice::MakeTensorFromProto(DeviceContext* device_context, @@ -594,7 +562,7 @@ Status XlaDevice::HandleDeviceError() { if (local_device_error_callback != nullptr) { return local_device_error_callback(); } - return OkStatus(); + return absl::OkStatus(); } Status XlaDevice::RefreshStatus() { @@ -604,7 +572,7 @@ Status XlaDevice::RefreshStatus() { stream = stream_; } if (!stream) { - return OkStatus(); + return absl::OkStatus(); } Status status = stream->RefreshStatus(); if (!status.ok()) { diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index aeff4501af480d..64f3abbeca7a45 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -158,7 +158,6 @@ class XlaDevice : public LocalDevice { void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) override; Status Sync() override; - void Sync(const DoneCallback& done) override; Status TryGetDeviceContext(DeviceContext** out_context) override TF_LOCKS_EXCLUDED(mu_); @@ -185,9 +184,9 @@ class XlaDevice : public LocalDevice { // Two convenient methods to get the underlying device context. // Get the default device context, created by the first // shape_representation_fn. - StatusOr GetDeviceContextDefault(); + absl::StatusOr GetDeviceContextDefault(); // Get the device context given the index. - StatusOr GetDeviceContextWithIndex(int index); + absl::StatusOr GetDeviceContextWithIndex(int index); // Instructs this XlaDevice to set a AcceleratorDeviceInfo, which holds extra // information for GPU and TPU devices. @@ -205,7 +204,7 @@ class XlaDevice : public LocalDevice { Status RefreshStatus() override TF_LOCKS_EXCLUDED(mu_); private: - StatusOr GetOrCreateClient() const; + absl::StatusOr GetOrCreateClient() const; Allocator* GetAllocatorLocked(AllocatorAttributes attr) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); Status EnsureStreamOkLocked(xla::Backend* backend, const string& name, @@ -215,7 +214,7 @@ class XlaDevice : public LocalDevice { // Return a vector of device context, ordered by the sequence in the given // shape_representation_fns. - StatusOr> GetDeviceContextLocked() + absl::StatusOr> GetDeviceContextLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Handles error when RefreshStatus sees !status.ok(). diff --git a/tensorflow/compiler/jit/xla_device_compiler_client.cc b/tensorflow/compiler/jit/xla_device_compiler_client.cc index 69436e6a4cabfb..71be1f7ec6b25d 100644 --- a/tensorflow/compiler/jit/xla_device_compiler_client.cc +++ b/tensorflow/compiler/jit/xla_device_compiler_client.cc @@ -35,7 +35,7 @@ std::vector GetShapePointers( } } // namespace -StatusOr> +absl::StatusOr> XlaDeviceCompilerClient::BuildExecutable( const XlaCompiler::Options& options, const XlaCompiler::CompilationResult& result) { @@ -52,7 +52,7 @@ XlaDeviceCompilerClient::BuildExecutable( return std::move(executables[0]); } -StatusOr XlaDeviceCompilerClient::SerializeExecutable( +absl::StatusOr XlaDeviceCompilerClient::SerializeExecutable( const xla::LocalExecutable& executable) { if (executable.executable() == nullptr) { return errors::FailedPrecondition( @@ -71,7 +71,7 @@ StatusOr XlaDeviceCompilerClient::SerializeExecutable( return exported.status(); } -StatusOr XlaDeviceCompilerClient::BuildSerializedExecutable( +absl::StatusOr XlaDeviceCompilerClient::BuildSerializedExecutable( const XlaCompiler::Options& options, const XlaCompiler::CompilationResult& result) { VLOG(2) << "Compiling to xla::AotCompilationResult and serializing it"; @@ -88,7 +88,7 @@ StatusOr XlaDeviceCompilerClient::BuildSerializedExecutable( return aot_results[0]->SerializeAsString(); } -StatusOr> +absl::StatusOr> XlaDeviceCompilerClient::LoadExecutable( const XlaCompiler::Options& options, const XlaCompiler::CompilationResult& result, diff --git a/tensorflow/compiler/jit/xla_device_compiler_client.h b/tensorflow/compiler/jit/xla_device_compiler_client.h index 4459574f935ed1..3967897ccf7441 100644 --- a/tensorflow/compiler/jit/xla_device_compiler_client.h +++ b/tensorflow/compiler/jit/xla_device_compiler_client.h @@ -31,24 +31,24 @@ class XlaDeviceCompilerClient explicit XlaDeviceCompilerClient(xla::LocalClient* client) : client_(client) {} - StatusOr> BuildExecutable( + absl::StatusOr> BuildExecutable( const XlaCompiler::Options& options, const XlaCompiler::CompilationResult& result) override; // Returns a serialized AOT result obtained by exporting the available // `executable` using the XlaCompiler. - StatusOr SerializeExecutable( + absl::StatusOr SerializeExecutable( const xla::LocalExecutable& executable) override; // Returns a serialized AOT result obtained by compiling `result` into an AOT // result. - StatusOr BuildSerializedExecutable( + absl::StatusOr BuildSerializedExecutable( const XlaCompiler::Options& options, const XlaCompiler::CompilationResult& result) override; // Loads a serialized AOT result (`serialized_executable`) into an // xla::LocalExecutable and returns it. - StatusOr> LoadExecutable( + absl::StatusOr> LoadExecutable( const XlaCompiler::Options& options, const XlaCompiler::CompilationResult& result, const std::string& serialized_executable) override; diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 1c02dbe9a166f2..77437408e2f36e 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -117,7 +117,7 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, bool sync_dst_compute) const { if (cpu_tensor->NumElements() == 0) { VLOG(2) << "CopyCPUTensorToDevice empty tensor"; - done(OkStatus()); + done(absl::OkStatus()); return; } @@ -178,7 +178,7 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, host_to_device_stream_.get()); } - return OkStatus(); + return absl::OkStatus(); }(); if (!status.ok()) { done(status); @@ -200,7 +200,7 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, } else { host_to_device_stream_->ThenDoHostCallback([ref, done]() { ref.Unref(); - done(OkStatus()); + done(absl::OkStatus()); }); } } @@ -211,7 +211,7 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, StatusCallback done) { if (device_tensor->NumElements() == 0) { VLOG(2) << "CopyDeviceTensorToCPU empty tensor"; - done(OkStatus()); + done(absl::OkStatus()); return; } VLOG(2) << "CopyDeviceTensorToCPU " @@ -300,7 +300,7 @@ Status XlaDeviceContext::ThenExecute(Device* device, std::function func) { VLOG(2) << "XlaDeviceContext::ThenExecute"; stream->ThenDoHostCallback(std::move(func)); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_device_ops.cc b/tensorflow/compiler/jit/xla_device_ops.cc index 9305de9e47db60..f22d72e1111013 100644 --- a/tensorflow/compiler/jit/xla_device_ops.cc +++ b/tensorflow/compiler/jit/xla_device_ops.cc @@ -55,7 +55,7 @@ void XlaAssignVariableOp::Compute(OpKernelContext* context) { *ptr = new Var(dtype_); *(*ptr)->tensor() = value; (*ptr)->is_initialized = true; - return OkStatus(); + return absl::OkStatus(); })); mutex_lock ml(*variable->mu()); OP_REQUIRES( diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index f8cfe4e7ecaf20..a16415ececc035 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/layout_util.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/stream_executor/gpu/gpu_init.h" +#include "xla/stream_executor/platform_manager.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/lib/core/status.h" @@ -48,20 +49,19 @@ Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector* devices) { if (!flags->tf_xla_enable_xla_devices && !XlaDevicesCreationRequired()) { VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set " "and XLA devices creation not required"; - return OkStatus(); + return absl::OkStatus(); } - auto platform = - se::MultiPlatformManager::PlatformWithName(se::GpuPlatformName()); + auto platform = se::PlatformManager::PlatformWithName(se::GpuPlatformName()); if (!platform.ok()) { // Treat failures as non-fatal; there might not be a GPU in the machine. VLOG(1) << "Failed to create XLA_GPU device: " << platform.status(); - return OkStatus(); + return absl::OkStatus(); } int device_count = platform.value()->VisibleDeviceCount(); if (device_count <= 0) { - return OkStatus(); + return absl::OkStatus(); } for (int i = 0; i < device_count; ++i) { @@ -69,7 +69,7 @@ Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector* devices) { absl::StrCat("/physical_device:", DEVICE_XLA_GPU, ":", i)); } - return OkStatus(); + return absl::OkStatus(); } Status XlaGpuDeviceFactory::CreateDevices( @@ -78,7 +78,7 @@ Status XlaGpuDeviceFactory::CreateDevices( XlaDeviceFlags* flags = GetXlaDeviceFlags(); if (!flags->tf_xla_enable_xla_devices && !XlaDevicesCreationRequired()) { VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set"; - return OkStatus(); + return absl::OkStatus(); } XlaOpRegistry::DeviceRegistration registration; @@ -100,19 +100,18 @@ Status XlaGpuDeviceFactory::CreateDevices( RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT); (void)registrations; - auto platform = - se::MultiPlatformManager::PlatformWithName(se::GpuPlatformName()); + auto platform = se::PlatformManager::PlatformWithName(se::GpuPlatformName()); if (!platform.ok()) { // Treat failures as non-fatal; there might not be a GPU in the machine. VLOG(1) << "Failed to create XLA_GPU device: " << platform.status(); - return OkStatus(); + return absl::OkStatus(); } auto iter = session_options.config.device_count().find("GPU"); if (iter != session_options.config.device_count().end() && iter->second == 0) { // Device count for GPU is 0. - return OkStatus(); + return absl::OkStatus(); } string allowed_gpus = @@ -149,7 +148,7 @@ Status XlaGpuDeviceFactory::CreateDevices( devices->push_back(std::move(device)); } - return OkStatus(); + return absl::OkStatus(); } REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory); diff --git a/tensorflow/compiler/jit/xla_host_recv_device_context.cc b/tensorflow/compiler/jit/xla_host_recv_device_context.cc index b634ac88739cf7..486d554ff75bc3 100644 --- a/tensorflow/compiler/jit/xla_host_recv_device_context.cc +++ b/tensorflow/compiler/jit/xla_host_recv_device_context.cc @@ -43,7 +43,7 @@ void XlaHostRecvDeviceContext::CopyDeviceTensorToCPU( } done_event_.SetStateConcrete(); - done(OkStatus()); + done(absl::OkStatus()); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_host_send_device_context.cc b/tensorflow/compiler/jit/xla_host_send_device_context.cc index 1c30ef022a81e6..084c6b28cb3b5a 100644 --- a/tensorflow/compiler/jit/xla_host_send_device_context.cc +++ b/tensorflow/compiler/jit/xla_host_send_device_context.cc @@ -33,7 +33,7 @@ void XlaHostSendDeviceContext::CopyCPUTensorToDevice( } done_event_.SetStateConcrete(); - done(OkStatus()); + done(absl::OkStatus()); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_host_send_recv_device_context_test.cc b/tensorflow/compiler/jit/xla_host_send_recv_device_context_test.cc index a6f8801cf5bd5a..98ba1ffb12e00a 100644 --- a/tensorflow/compiler/jit/xla_host_send_recv_device_context_test.cc +++ b/tensorflow/compiler/jit/xla_host_send_recv_device_context_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/multi_platform_manager.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "tensorflow/core/framework/tensor_testutil.h" @@ -65,7 +65,7 @@ TEST_F(XlaHostSendRecvDeviceContextTest, CopyDeviceTensorToCPU) { Tensor dest_cpu_tensor(host_allocator_, DT_FLOAT, TensorShape({2, 2})); stream_executor::Platform* platform = - stream_executor::MultiPlatformManager::PlatformWithName("CUDA").value(); + stream_executor::PlatformManager::PlatformWithName("CUDA").value(); stream_executor::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); stream_executor::Stream stream(executor); @@ -100,7 +100,7 @@ TEST_F(XlaHostSendRecvDeviceContextTest, CopyCPUTensorToDevice) { Tensor dest_cpu_tensor(host_allocator_, DT_FLOAT, TensorShape({2, 2})); stream_executor::Platform* platform = - stream_executor::MultiPlatformManager::PlatformWithName("CUDA").value(); + stream_executor::PlatformManager::PlatformWithName("CUDA").value(); stream_executor::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); stream_executor::Stream stream(executor); @@ -135,7 +135,7 @@ TEST_F(XlaHostSendRecvDeviceContextTest, RoundTrip) { Tensor dest_cpu_tensor(host_allocator_, DT_FLOAT, TensorShape({2, 2})); stream_executor::Platform* platform = - stream_executor::MultiPlatformManager::PlatformWithName("CUDA").value(); + stream_executor::PlatformManager::PlatformWithName("CUDA").value(); stream_executor::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); stream_executor::Stream stream(executor); diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 288344eb341d23..3714ccf89837ce 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -41,6 +41,7 @@ limitations under the License. #include "xla/pjrt/tracked_device_buffer.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/stream_executor/platform_manager.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/common_runtime/gpu_device_context.h" #include "tensorflow/core/framework/allocator.h" @@ -315,7 +316,7 @@ Status SetOutputForConstant( ctx->set_output(output_num, const_tensor); output_tensor = ctx->mutable_output(output_num); } - return OkStatus(); + return absl::OkStatus(); } static StatusOr GetOrCreateResourceVar( @@ -325,7 +326,7 @@ static StatusOr GetOrCreateResourceVar( TF_RETURN_IF_ERROR( LookupOrCreateResource(ctx, handle, &variable, [&write](Var** ptr) { *ptr = new Var(write.type); - return OkStatus(); + return absl::OkStatus(); })); return variable; } @@ -411,7 +412,7 @@ Status XlaComputationLaunchContext::PopulateOutputs( } else { // Stream is not set for the host platform. TF_ASSIGN_OR_RETURN(platform, - se::MultiPlatformManager::PlatformWithId( + se::PlatformManager::PlatformWithId( XlaPlatformInfoFromDevice(ctx->device()))); } TF_ASSIGN_OR_RETURN(auto transfer_manager, @@ -505,7 +506,7 @@ Status XlaComputationLaunchContext::PopulateOutputs( *var->tensor() = output_tensor; ++output_num; } - return OkStatus(); + return absl::OkStatus(); } StatusOr> @@ -673,7 +674,6 @@ Status PreparePjRtExecutableArguments( } } else { if (av_tensor->GetBuffer() == nullptr) { - // TODO(b/260799971): verify size 0 argument is supported. CHECK_EQ(tensor->NumElements(), 0); // Crash OK continue; } @@ -684,7 +684,7 @@ Status PreparePjRtExecutableArguments( non_donatable_input_indices->insert(args->size() - 1); } } - return OkStatus(); + return absl::OkStatus(); } // TODO(b/289002708) Create a unit test to cover use_pjrt_tensor_buffer=true. @@ -794,7 +794,7 @@ Status PopulateCtxOutputsFromPjRtExecutableOutputs( var->is_initialized |= write.modified; ++output_num; } - return OkStatus(); + return absl::OkStatus(); } xla::ExecuteOptions GetPjRtExecuteOptions( @@ -873,7 +873,7 @@ Status RunPjRtExecutable( TF_RETURN_IF_ERROR(PopulateCtxOutputsFromPjRtExecutableOutputs( num_missing_prefix_ctx_inputs, inputs, updated_variables, compilation_result, use_pjrt_tensor_buffer, execute_outputs, ctx)); - return OkStatus(); + return absl::OkStatus(); } StatusOr>> RunPjRtExecutable( diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 97e26307bafded..32f277ff13b0b8 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -38,7 +38,7 @@ limitations under the License. namespace tensorflow { // Creates a list of updated resource variables. -StatusOr> GatherVariableInfo( +absl::StatusOr> GatherVariableInfo( OpKernelContext* ctx, const XlaCompiler::CompilationResult& compilation_result, int missing_ctx_input_prefix); @@ -46,7 +46,7 @@ StatusOr> GatherVariableInfo( // Returns pointers to inputs stored in `ctx`. std::vector InputsFromContext(OpKernelContext* ctx); -StatusOr> GetConstantInputIndicesFromContext( +absl::StatusOr> GetConstantInputIndicesFromContext( OpKernelContext* ctx); Status SetOutputForConstant( @@ -143,7 +143,7 @@ Status RunPjRtExecutable( // Similar to the above function but it does not take an OpKernelContext, and // it returns the output in PjRtBuffers, instead of populating results into // OpKernelContext. -StatusOr>> RunPjRtExecutable( +absl::StatusOr>> RunPjRtExecutable( int num_missing_prefix_ctx_inputs, const std::vector& inputs, const absl::flat_hash_map& variable_snapshots, const std::vector& updated_variables, @@ -172,10 +172,11 @@ class XlaComputationLaunchContext { // Builds a XlaCompiler::Argument vector from the arguments to an XlaLaunch // op. // Precondition: variables in `variable_args` are locked. - static StatusOr> BuildXlaCompilerArguments( - absl::Span must_be_constant_idxs, - absl::Span inputs, - absl::Span variable_args, Device* device); + static absl::StatusOr> + BuildXlaCompilerArguments(absl::Span must_be_constant_idxs, + absl::Span inputs, + absl::Span variable_args, + Device* device); // Add all inputs within `ctx` as XLA arguments (returned by arguments()). // `variables` is a map from TensorFlow argument number to resource variable. @@ -184,7 +185,7 @@ class XlaComputationLaunchContext { // missing and adjusts input indices accordingly. All elements in kernel's // input_mapping must be greater than or equal to `missing_ctx_input_prefix` // (in other words, no inputs actually required by the kernel can be missing). - StatusOr> PopulateInputs( + absl::StatusOr> PopulateInputs( OpKernelContext* ctx, const XlaCompiler::CompilationResult* compilation_result, const std::map& resource_vars, diff --git a/tensorflow/compiler/jit/xla_launch_util_test.cc b/tensorflow/compiler/jit/xla_launch_util_test.cc index 2206d5a6c144bb..08db00e21690e2 100644 --- a/tensorflow/compiler/jit/xla_launch_util_test.cc +++ b/tensorflow/compiler/jit/xla_launch_util_test.cc @@ -186,7 +186,7 @@ class PjRtExecutionUtilTest : public OpsTestBase { // Runs a PjRtLoadedExecutable with the given inputs, variables. Requires the // XlaCompiler::CompilationResult that was used to build the executable. - StatusOr>> RunExecutable( + absl::StatusOr>> RunExecutable( const std::vector& inputs, const std::vector& variables, const XlaCompiler::CompilationResult* result, diff --git a/tensorflow/compiler/jit/xla_platform_info.cc b/tensorflow/compiler/jit/xla_platform_info.cc index d750fbc42039e6..4f8c9f2902ca91 100644 --- a/tensorflow/compiler/jit/xla_platform_info.cc +++ b/tensorflow/compiler/jit/xla_platform_info.cc @@ -35,6 +35,7 @@ limitations under the License. #include "xla/client/local_client.h" #include "xla/pjrt/pjrt_client.h" #include "xla/service/compiler.h" +#include "xla/stream_executor/platform_manager.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" @@ -109,7 +110,7 @@ Status GetCompilationDeviceTypeAndPjRtClient( *compilation_device_type = platform_info.xla_device_metadata()->jit_device_type(); TF_ASSIGN_OR_RETURN(*pjrt_client, GetOrCreatePjRtClient(device_type)); - return OkStatus(); + return absl::OkStatus(); } if (platform_info.pjrt_device_metadata()) { @@ -119,7 +120,7 @@ Status GetCompilationDeviceTypeAndPjRtClient( *compilation_device_type = platform_info.pjrt_device_metadata()->jit_device_type(); TF_ASSIGN_OR_RETURN(*pjrt_client, GetOrCreatePjRtClient(device_type)); - return OkStatus(); + return absl::OkStatus(); } // TFRT-TPU is used if device_type is `DEVICE_TPU` and platform_info does not @@ -127,7 +128,7 @@ Status GetCompilationDeviceTypeAndPjRtClient( if (device_type == DEVICE_TPU) { *compilation_device_type = DeviceType(DEVICE_TPU_XLA_JIT); TF_ASSIGN_OR_RETURN(*pjrt_client, GetOrCreatePjRtClient(device_type)); - return OkStatus(); + return absl::OkStatus(); } VLOG(2) << "platform_info.xla_device_metadata not found and " @@ -148,7 +149,7 @@ Status GetCompilationDeviceTypeAndPjRtClient( TF_ASSIGN_OR_RETURN(*pjrt_client, GetOrCreatePjRtClient(device_type, allowed_gpus)); - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -214,7 +215,7 @@ Status BuildXlaDeviceCompiler(DeviceBase* device, FunctionLibraryRuntime* flr, // cross platform lowering. *xla_device_compiler = new XlaDeviceCompiler(/*persistor=*/nullptr, /*compiler_client=*/nullptr); - return OkStatus(); + return absl::OkStatus(); } std::string persistent_cache_directory = GetPersistentCacheDirectory(platform_info.device_type()); @@ -230,7 +231,7 @@ Status BuildXlaDeviceCompiler(DeviceBase* device, FunctionLibraryRuntime* flr, persistor_config, platform_info.xla_device_metadata()->jit_device_type(), platform_info.xla_device_metadata()->client()); - return OkStatus(); + return absl::OkStatus(); } // TFRT-TPU is used if device type is `DEVICE_TPU` and platform_info does not @@ -241,14 +242,14 @@ Status BuildXlaDeviceCompiler(DeviceBase* device, FunctionLibraryRuntime* flr, if (platform_info.device_type() == DEVICE_TPU) { *xla_device_compiler = CreateXlaDeviceCompiler( persistor_config, DeviceType(DEVICE_TPU_XLA_JIT), nullptr); - return OkStatus(); + return absl::OkStatus(); } if (platform_info.platform_id() == nullptr) { return errors::InvalidArgument("platform_id is null."); } auto platform = - se::MultiPlatformManager::PlatformWithId(platform_info.platform_id()); + se::PlatformManager::PlatformWithId(platform_info.platform_id()); if (!platform.ok()) { return platform.status(); } @@ -291,7 +292,7 @@ Status BuildXlaDeviceCompiler(DeviceBase* device, FunctionLibraryRuntime* flr, *xla_device_compiler = CreateXlaDeviceCompiler( persistor_config, compilation_device_type, client); - return OkStatus(); + return absl::OkStatus(); } Status GetOrCreatePjRtDeviceCompilerAndProfiler( @@ -337,7 +338,7 @@ Status GetOrCreatePjRtDeviceCompilerAndProfiler( [&](PjRtDeviceCompiler** pjrt_device_compiler) { *pjrt_device_compiler = CreatePjRtDeviceCompiler(compilation_device_type, pjrt_client); - return OkStatus(); + return absl::OkStatus(); })); } @@ -345,10 +346,10 @@ Status GetOrCreatePjRtDeviceCompilerAndProfiler( rm->default_container(), profiler_name, profiler, [](DeviceCompilationProfiler** profiler) { *profiler = new DeviceCompilationProfiler(); - return OkStatus(); + return absl::OkStatus(); })); - return OkStatus(); + return absl::OkStatus(); } Status GetOrCreatePjRtDeviceCompilerAndProfiler( @@ -412,7 +413,7 @@ std::shared_ptr GetAllocator( if (!stream) { // Stream is not set for the host platform. se::Platform* platform = - se::MultiPlatformManager::PlatformWithId(platform_info.platform_id()) + se::PlatformManager::PlatformWithId(platform_info.platform_id()) .value(); return std::make_shared(alloc, platform); } diff --git a/tensorflow/compiler/jit/xla_platform_info.h b/tensorflow/compiler/jit/xla_platform_info.h index 754536c3e18b7b..94764e4d3dd7fe 100644 --- a/tensorflow/compiler/jit/xla_platform_info.h +++ b/tensorflow/compiler/jit/xla_platform_info.h @@ -103,12 +103,12 @@ class XlaPlatformInfo { // Returns a set containing the device ids contained in visible_device_list or // nullopt if it is empty. It returns error in case of malformed configuration // string. -StatusOr>> ParseVisibleDeviceList( +absl::StatusOr>> ParseVisibleDeviceList( absl::string_view visible_device_list); // Returns the device type for building a DeviceCompiler from the given platform // type. -StatusOr GetCompilationDeviceType( +absl::StatusOr GetCompilationDeviceType( const DeviceType& platform_device_type); // Builds a DeviceCompiler that uses xla::LocalClient using `platform_info` and diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc index 871a8de1c68e47..3a7ea396e61862 100644 --- a/tensorflow/compiler/jit/xla_tensor.cc +++ b/tensorflow/compiler/jit/xla_tensor.cc @@ -66,7 +66,7 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, VLOG(4) << shaped_buffer.ToString(); set_shaped_buffer(std::move(shaped_buffer)); - return OkStatus(); + return absl::OkStatus(); } void XlaTensor::WaitForDefinitionEventOnStream(se::Stream* stream) { @@ -83,7 +83,7 @@ void XlaTensor::WaitForDefinitionEventOnStream(se::Stream* stream) { return; } - stream->ThenWaitFor(definition_event_.get()); + stream->WaitFor(definition_event_.get()).IgnoreError(); streams_defined_on_.push_back(stream); } diff --git a/tensorflow/compiler/jit/xla_tpu_device.cc b/tensorflow/compiler/jit/xla_tpu_device.cc index eeb809f62f083b..77f17c6c48cb30 100644 --- a/tensorflow/compiler/jit/xla_tpu_device.cc +++ b/tensorflow/compiler/jit/xla_tpu_device.cc @@ -100,7 +100,7 @@ Status TpuPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) { return status.status(); } *shape = tpu_shape.AsCpp(); - return OkStatus(); + return absl::OkStatus(); } // Check if TPU has been initialized. TPU initialization is not necessary @@ -111,7 +111,7 @@ Status CheckIfTPUInitialized() { return errors::FailedPrecondition( "The TPU system has not been initialized."); } - return OkStatus(); + return absl::OkStatus(); } // Implementation of TPU->TPU device copies that copies over the dedicated TPU @@ -140,13 +140,13 @@ void TpuDeviceToDeviceCopy(DeviceContext* src_dev_context, Status s = CheckIfTPUInitialized(); if (!s.ok()) { done(s); - return OkStatus(); + return absl::OkStatus(); } } if (input->shape().num_elements() == 0) { // Zero-element tensors have no backing buffers. - done(OkStatus()); - return OkStatus(); + done(absl::OkStatus()); + return absl::OkStatus(); } se::Stream* const src_compute_stream = src_xla_context->stream(); @@ -167,8 +167,8 @@ void TpuDeviceToDeviceCopy(DeviceContext* src_dev_context, dst_compute_stream_impl)) { // Surprisingly, this path does get triggered in practice. *output = *input; - done(OkStatus()); - return OkStatus(); + done(absl::OkStatus()); + return absl::OkStatus(); } // To avoid stream exhaustion, we pick a substream from a pool if enabled. @@ -177,7 +177,8 @@ void TpuDeviceToDeviceCopy(DeviceContext* src_dev_context, : nullptr; se::Stream* const dst_device_to_device_stream = should_use_substream - ? device_to_device_master_stream->GetOrCreateSubStream() + ? device_to_device_master_stream->GetOrCreateSubStream().value_or( + nullptr) : dst_xla_context->GetDeviceToDeviceStream(); TF_RET_CHECK(dst_device_to_device_stream != nullptr); auto return_substream = gtl::MakeCleanup( @@ -296,10 +297,10 @@ void TpuDeviceToDeviceCopy(DeviceContext* src_dev_context, dst_device_to_device_stream); } input_reference.Unref(); - done(OkStatus()); + done(absl::OkStatus()); }); - return OkStatus(); + return absl::OkStatus(); }; Status status = impl(); if (!status.ok()) { @@ -319,7 +320,7 @@ Status TpuNodeDeviceFactory::ListPhysicalDevices(std::vector* devices) { tpu::TpuPlatformInterface::GetRegisteredPlatform(); if (platform == nullptr) { // If we don't have a platform registered, then we have no devices. - return OkStatus(); + return absl::OkStatus(); } int device_count = platform->VisibleDeviceCount(); @@ -329,7 +330,7 @@ Status TpuNodeDeviceFactory::ListPhysicalDevices(std::vector* devices) { devices->push_back(device_name); } - return OkStatus(); + return absl::OkStatus(); } Status TpuNodeDeviceFactory::CreateDevices( @@ -339,7 +340,7 @@ Status TpuNodeDeviceFactory::CreateDevices( tpu::TpuPlatformInterface::GetRegisteredPlatform(); if (platform == nullptr) { // If we don't have a platform registered, then we should not create any. - return OkStatus(); + return absl::OkStatus(); } if (platform != nullptr && platform->ShouldRegisterTpuDeviceToDeviceCopy()) { @@ -406,7 +407,7 @@ Status TpuNodeDeviceFactory::CreateDevices( devices->push_back(std::move(device)); } - return OkStatus(); + return absl::OkStatus(); } class TpuSystemDeviceFactory : public DeviceFactory { @@ -422,12 +423,12 @@ Status TpuSystemDeviceFactory::ListPhysicalDevices( TF_RETURN_IF_ERROR(tpu::TpuPlatform::TpusPerHost(&device_count)); if (device_count == 0) { VLOG(1) << "Host has no TPUs, not creating a TPU_SYSTEM device"; - return OkStatus(); + return absl::OkStatus(); } devices->push_back("/physical_device:TPU_SYSTEM:0"); - return OkStatus(); + return absl::OkStatus(); } Status TpuSystemDeviceFactory::CreateDevices( @@ -437,7 +438,7 @@ Status TpuSystemDeviceFactory::CreateDevices( TF_RETURN_IF_ERROR(tpu::TpuPlatform::TpusPerHost(&device_count)); if (device_count == 0) { VLOG(1) << "Host has no TPUs, not creating a TPU_SYSTEM device"; - return OkStatus(); + return absl::OkStatus(); } int64_t memory_limit; @@ -452,7 +453,7 @@ Status TpuSystemDeviceFactory::CreateDevices( VLOG(1) << "Created TPU_SYSTEM device. This host has " << device_count << " TPUs"; - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index 742f97d902522a..b30f08a1bfe1b4 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -74,7 +74,6 @@ cc_library( "@local_xla//xla/mlir/framework/transforms:passes", "@local_xla//xla/mlir_hlo:all_passes", "@local_xla//xla/service/cpu:hlo_xla_runtime_pipeline", - "@local_xla//xla/translate/mhlo_to_lhlo_with_xla", ], ) @@ -204,7 +203,6 @@ cc_library( "@local_xla//xla/mlir/framework/ir:xla_framework", "@local_xla//xla/mlir_hlo:hlo_dialect_registration", "@local_xla//xla/service/cpu:hlo_xla_runtime_pipeline", - "@local_xla//xla/translate/mhlo_to_lhlo_with_xla", "@stablehlo//:register", ], ) @@ -239,7 +237,6 @@ tf_cc_binary( "@llvm-project//mlir:TranslateLib", "@local_xla//xla/translate/hlo_to_mhlo:translate_registration", "@local_xla//xla/translate/mhlo_to_hlo:translate_registration", - "@local_xla//xla/translate/mhlo_to_lhlo_with_xla:translate_registration", ], ) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 3ebbd3b1b81942..38610dcef42b52 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -5,7 +5,7 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup", "get_compatible_with_po load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:LICENSE"], default_visibility = ["//visibility:public"], licenses = ["notice"], ) @@ -452,14 +452,15 @@ cc_library( "utils/constant_utils.h", ], deps = [ - "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:mangling_util", "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:status", + "@com_google_absl//absl/status", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:statusor", ], ) @@ -754,6 +755,7 @@ cc_library( "transforms/passes.h", ], deps = [ + ":constant_utils", ":convert_type", ":tensorflow_lite", ":tensorflow_lite_passes_inc_gen", @@ -1267,24 +1269,29 @@ tf_cc_binary( ":tf_to_tfl_flatbuffer", "//tensorflow/cc/saved_model:loader", "//tensorflow/compiler/mlir:init_mlir", + "//tensorflow/compiler/mlir/lite/quantization:quantization_config", + "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibrator_singleton_impl", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/compiler/mlir/tensorflow:translate_cl_options", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:core_cpu_base", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:errors", "//tensorflow/lite:framework", "//tensorflow/lite/schema:schema_fbs", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", - "@local_tsl//tsl/platform:statusor", "@local_xla//xla/translate/hlo_to_mhlo:translate", "@stablehlo//:stablehlo_ops", ], @@ -1338,23 +1345,17 @@ cc_library( "//tensorflow/compiler/mlir/lite/stablehlo:tfl_legalize_hlo", # buildcleaner: keep "//tensorflow/compiler/mlir/lite/stablehlo:transforms", "//tensorflow/compiler/mlir/lite/stablehlo:uniform_quantized_stablehlo_to_tfl_pass", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", - "//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", - "//tensorflow/compiler/mlir/tensorflow/transforms:tf_graph_optimization_pass", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_passes", "//tensorflow/core:core_cpu_base", - "//tensorflow/lite/toco:model_flags_proto_cc", "//tensorflow/lite/toco:toco_flags_proto_cc", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", - "@local_xla//xla/mlir_hlo", - "@local_xla//xla/mlir_hlo:all_passes", "@local_xla//xla/mlir_hlo:mhlo_passes", + "@stablehlo//stablehlo/experimental:experimental_stablehlo_passes", ], ) @@ -1370,45 +1371,54 @@ cc_library( ":tensorflow_lite", ":tf_tfl_passes", "//tensorflow/cc/saved_model:loader", + "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/lite/debug", + "//tensorflow/compiler/mlir/lite/metrics:error_collector", "//tensorflow/compiler/mlir/lite/metrics:error_collector_inst", "//tensorflow/compiler/mlir/lite/quantization:quantization_config", + "//tensorflow/compiler/mlir/lite/quantization/stablehlo:quantization", "//tensorflow/compiler/mlir/lite/stablehlo:op_stat_pass", "//tensorflow/compiler/mlir/lite/stablehlo:stablehlo_tfl", "//tensorflow/compiler/mlir/lite/stablehlo:stablehlo_util", "//tensorflow/compiler/mlir/lite/stablehlo:transforms", - "//tensorflow/compiler/mlir/lite/stablehlo/serializer:flatbuffer_export", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", "//tensorflow/compiler/mlir/quantization/stablehlo:quantize_passes", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:static_range_ptq_impl", # buildcleaner: keep; prevents undefined reference "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantize_passes", "//tensorflow/compiler/mlir/quantization/tensorflow:quantize_preprocess", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow:mlir_import_options", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib", "//tensorflow/compiler/mlir/tensorflow:translate_lib", - "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_dialect_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_freeze_variables", - "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_passes", + "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/platform:status", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/experimental/remat:metadata_util", "//tensorflow/lite/python/metrics:converter_error_data_proto_cc", + "//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/toco:toco_flags_proto_cc", "//tensorflow/lite/tools/optimize:quantize_weights", "//tensorflow/lite/tools/optimize:reduced_precision_support", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@flatbuffers//:runtime_cc", "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", - "@llvm-project//mlir:Transforms", + "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:statusor", ], ) diff --git a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h index 9dd57f2bdea429..39f81c7a6a770d 100644 --- a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h +++ b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h @@ -95,6 +95,10 @@ struct PassConfig { // ops and to convert kernels to quantized kernels wherever appropriate. quant::QDQConversionMode qdq_conversion_mode = quant::QDQConversionMode::kQDQNone; + + // When set to true, StableHLO Quantizer is run. The full configuration for + // the quantizer is at `TocoFlags::quantization_config`. + bool enable_stablehlo_quantizer = false; }; inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os, diff --git a/tensorflow/compiler/mlir/lite/debug/debug_test.cc b/tensorflow/compiler/mlir/lite/debug/debug_test.cc index 1bcb86de9a4a94..127d485b842f94 100644 --- a/tensorflow/compiler/mlir/lite/debug/debug_test.cc +++ b/tensorflow/compiler/mlir/lite/debug/debug_test.cc @@ -53,11 +53,7 @@ limitations under the License. #include "tsl/platform/status.h" namespace tensorflow { -namespace { - -using ::testing::HasSubstr; -using ::testing::IsEmpty; -using ::testing::Not; +namespace debug_test { class NopPass : public mlir::PassWrapper> { public: @@ -84,6 +80,15 @@ class AlwaysFailPass void runOnOperation() override { signalPassFailure(); } }; +} // namespace debug_test + +namespace { + +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::Not; +using namespace tensorflow::debug_test; + class InitPassManagerTest : public testing::Test { protected: InitPassManagerTest() @@ -179,8 +184,7 @@ TEST_F(InitPassManagerTest, DumpToDir) { TF_ASSERT_OK(tsl::ReadFileToString( tsl::Env::Default(), tsl::io::JoinPath( - dump_dir, - "00000000.main.tensorflow_anonymous_namespace_NopPass_after.mlir"), + dump_dir, "00000000.main.tensorflow_debug_test_NopPass_after.mlir"), &mlir_dump)); EXPECT_THAT(mlir_dump, Not(IsEmpty())); } @@ -190,7 +194,7 @@ TEST_F(InitPassManagerTest, DumpToDir) { tsl::Env::Default(), tsl::io::JoinPath( dump_dir, - "00000000.main.tensorflow_anonymous_namespace_NopPass_before.mlir"), + "00000000.main.tensorflow_debug_test_NopPass_before.mlir"), &mlir_dump)); EXPECT_THAT(mlir_dump, Not(IsEmpty())); } @@ -207,12 +211,10 @@ TEST_F(InitPassManagerTest, PrintIRBeforeEverything) { pm.addPass(std::make_unique()); ASSERT_TRUE(mlir::succeeded(pm.run(*module_))); - EXPECT_THAT( - captured_out, - HasSubstr("IR Dump Before tensorflow::(anonymous namespace)::NopPass")); EXPECT_THAT(captured_out, - Not(HasSubstr( - "IR Dump After tensorflow::(anonymous namespace)::NopPass"))); + HasSubstr("IR Dump Before tensorflow::debug_test::NopPass")); + EXPECT_THAT(captured_out, + Not(HasSubstr("IR Dump After tensorflow::debug_test::NopPass"))); } TEST_F(InitPassManagerTest, PrintIRAfterEverything) { @@ -226,13 +228,11 @@ TEST_F(InitPassManagerTest, PrintIRAfterEverything) { pm.addPass(std::make_unique()); ASSERT_TRUE(mlir::succeeded(pm.run(*module_))); + EXPECT_THAT(captured_out, + HasSubstr("IR Dump After tensorflow::debug_test::MutatePass")); EXPECT_THAT( captured_out, - HasSubstr("IR Dump After tensorflow::(anonymous namespace)::MutatePass")); - EXPECT_THAT( - captured_out, - Not(HasSubstr( - "IR Dump Before tensorflow::(anonymous namespace)::MutatePass"))); + Not(HasSubstr("IR Dump Before tensorflow::debug_test::MutatePass"))); } TEST_F(InitPassManagerTest, PrintIRBeforeAndAfterEverything) { @@ -247,13 +247,10 @@ TEST_F(InitPassManagerTest, PrintIRBeforeAndAfterEverything) { pm.addPass(std::make_unique()); ASSERT_TRUE(mlir::succeeded(pm.run(*module_))); - EXPECT_THAT( - captured_out, - HasSubstr("IR Dump After tensorflow::(anonymous namespace)::MutatePass")); - EXPECT_THAT( - captured_out, - HasSubstr( - "IR Dump Before tensorflow::(anonymous namespace)::MutatePass")); + EXPECT_THAT(captured_out, + HasSubstr("IR Dump After tensorflow::debug_test::MutatePass")); + EXPECT_THAT(captured_out, + HasSubstr("IR Dump Before tensorflow::debug_test::MutatePass")); } TEST_F(InitPassManagerTest, ElideLargeElementAttrs) { diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index 04a92f3412ba82..25f62ce0981b6f 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -2632,8 +2632,7 @@ Translator::CreateMetadataVector() { } else { module_.emitError( "all values in tfl.metadata's dictionary key-value pairs should " - "be " - "string attributes"); + "be string attributes"); return std::nullopt; } } diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 26ae90d95a97ea..bd912797d44820 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -40,6 +40,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSet.h" #include "llvm/Analysis/AssumeBundleQueries.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" @@ -1495,6 +1496,8 @@ OwningOpRef tflite::FlatBufferToMlir( bool use_stablehlo_constant = false; + llvm::SmallVector metadata_attrs; + mlir::StringSet<> seen_attr; for (const auto& metadata : model->metadata) { if (metadata->name == tflite::kModelControlDependenciesMetadataKey) { const std::vector& data = model->buffers[metadata->buffer]->data; @@ -1502,15 +1505,28 @@ OwningOpRef tflite::FlatBufferToMlir( reinterpret_cast(data.data()), data.size(), &model_control_dependencies)) { return emitError(base_loc, - "Invalid model_control_dependencies metadata"), + "invalid model_control_dependencies metadata"), nullptr; } - break; + continue; } + + // Skip already seen attributes. Ideally there should be no duplicates here. + if (!seen_attr.try_emplace(metadata->name).second) continue; + // check if the model is serialized using stablehlo constant tensor if (metadata->name == tflite::kModelUseStablehloTensorKey) { use_stablehlo_constant = true; + metadata_attrs.emplace_back(builder.getStringAttr(metadata->name), + builder.getStringAttr("true")); + continue; } + + std::vector buffer = model->buffers[metadata->buffer]->data; + metadata_attrs.emplace_back( + builder.getStringAttr(metadata->name), + builder.getStringAttr(llvm::StringRef( + reinterpret_cast(buffer.data()), buffer.size()))); } std::vector func_names; @@ -1528,18 +1544,15 @@ OwningOpRef tflite::FlatBufferToMlir( builder.getStringAttr(model->description)); } + if (!metadata_attrs.empty()) { + module->setAttr("tfl.metadata", builder.getDictionaryAttr(metadata_attrs)); + } + if (!model->signature_defs.empty()) { module->setAttr("tf_saved_model.semantics", mlir::UnitAttr::get(builder.getContext())); } - if (use_stablehlo_constant) { - module->setAttr("tfl.metadata", - builder.getDictionaryAttr(builder.getNamedAttr( - tflite::kModelUseStablehloTensorKey, - builder.getStringAttr("true")))); - } - absl::flat_hash_map subgraph_to_signature_map; for (int i = 0; i < model->signature_defs.size(); i++) { diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index a29de738c5de81..e8f9787947d6eb 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -1061,7 +1061,7 @@ def TFL_DepthwiseConv2DOp : def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [ Pure, AccumulatorUniformScale<2, 0, 1>, AffineQuantizedOpInterface, - AffineOpCoefficient<-1, 1>, + AffineOpCoefficient<0, 1>, TFL_SparseOp, DeclareOpInterfaceMethods, QuantizableResult, @@ -1097,7 +1097,7 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [ let extraClassDeclaration = [{ // AffineQuantizedOpInterface: int GetChannelDimIndex() { return 0; } - int GetQuantizationDimIndex() { return -1; } + int GetQuantizationDimIndex() { return 0; } // SparseOpInterface: std::vector GetSparseOperands() { return {1}; } std::vector> GetFloatBlockSize() { return {{1, 4}}; } @@ -5399,12 +5399,12 @@ subsequent operation and then be optimized away, however.) }]; let arguments = (ins - TFL_TensorOf<[F32, I32, I1, TFL_I4, I8, QI8, UI8, QUI8, I16, QI16, I64, Complex>]>:$input, + TFL_TensorOf<[F32, I32, I1, TFL_I4, I8, QI8, UI8, UI32, QUI8, I16, QI16, I64, Complex>]>:$input, TFL_I32OrI64Tensor:$shape ); let results = (outs - TFL_TensorOf<[F32, I32, I1, TFL_I4, I8, QI8, UI8, QUI8, I16, QI16, I64, Complex>]>:$output + TFL_TensorOf<[F32, I32, I1, TFL_I4, I8, QI8, UI8, UI32, QUI8, I16, QI16, I64, Complex>]>:$output ); let hasCanonicalizer = 1; diff --git a/tensorflow/compiler/mlir/lite/python/BUILD b/tensorflow/compiler/mlir/lite/python/BUILD index a911d438d20368..0ac8bc0ff65117 100644 --- a/tensorflow/compiler/mlir/lite/python/BUILD +++ b/tensorflow/compiler/mlir/lite/python/BUILD @@ -23,21 +23,25 @@ cc_library( srcs = ["tf_tfl_flatbuffer_helpers.cc"], hdrs = ["tf_tfl_flatbuffer_helpers.h"], deps = [ + "//tensorflow/cc/saved_model:loader", "//tensorflow/compiler/mlir/lite:common", "//tensorflow/compiler/mlir/lite:tensorflow_lite", - "//tensorflow/compiler/mlir/lite:tf_tfl_passes", "//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer", - "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", - "//tensorflow/compiler/mlir/tensorflow:import_model", - "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "//tensorflow/compiler/mlir/lite/quantization:quantization_config", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/lite/toco:model_flags_proto_cc", "//tensorflow/lite/toco:toco_flags_proto_cc", "//tensorflow/lite/toco:types_proto_cc", "//tensorflow/lite/tools/optimize:reduced_precision_support", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -56,10 +60,8 @@ cc_library( deps = [ ":tf_tfl_flatbuffer_helpers", "//tensorflow/compiler/mlir/lite:common", - "//tensorflow/compiler/mlir/lite:tensorflow_lite", - "//tensorflow/compiler/mlir/lite:tf_tfl_passes", - "//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer", - "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/lite/quantization:quantization_config", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", "//tensorflow/compiler/mlir/tensorflow:import_model", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/core:lib", @@ -68,12 +70,8 @@ cc_library( "//tensorflow/lite/toco:toco_flags_proto_cc", "//tensorflow/lite/toco:types_proto_cc", "//tensorflow/lite/tools/optimize:reduced_precision_support", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", + "@com_google_absl//absl/status", "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:Transforms", ], ) @@ -90,6 +88,7 @@ cc_library( "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer", "//tensorflow/compiler/mlir/lite/quantization:quantization_config", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -114,23 +113,19 @@ cc_library( ":tf_tfl_flatbuffer_helpers", "//tensorflow/compiler/mlir/lite:common", "//tensorflow/compiler/mlir/lite:tensorflow_lite", - "//tensorflow/compiler/mlir/lite:tf_tfl_passes", - "//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "//tensorflow/compiler/mlir/lite/quantization:quantization_config", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/lite/toco:model_flags_proto_cc", "//tensorflow/lite/toco:toco_flags_proto_cc", "//tensorflow/lite/toco:types_proto_cc", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", - "@llvm-project//mlir:Transforms", "@local_xla//xla/service:hlo_parser", "@local_xla//xla/service:hlo_proto_cc", "@local_xla//xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo", diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc index 9409856cd4c864..f678daf32f234c 100644 --- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc @@ -16,49 +16,41 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h" #include -#include #include #include #include -#include "llvm/Support/ToolOutputFile.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "absl/status/status.h" #include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Support/FileUtilities.h" // from @llvm-project -#include "mlir/Transforms/ViewOpGraph.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h" -#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" -#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" -#include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_flags.pb.h" #include "tensorflow/lite/toco/types.pb.h" +#include "tensorflow/lite/tools/optimize/reduced_precision_support.h" +#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace tensorflow { -Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, - const toco::TocoFlags& toco_flags, - const GraphDebugInfo& debug_info, - const GraphDef& input, - string* result) { + +absl::Status ConvertGraphDefToTFLiteFlatBuffer( + const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags, + const GraphDebugInfo& debug_info, const GraphDef& input, + std::string* result) { using ::tflite::optimize::ReducedPrecisionSupport; mlir::MLIRContext context; GraphImportConfig specs; mlir::quant::QuantizationSpecs quant_specs; // Parse input arrays. - std::vector node_names; - std::vector node_dtypes; + std::vector node_names; + std::vector node_dtypes; std::vector>> node_shapes; std::vector> node_mins; std::vector> node_maxs; @@ -68,21 +60,20 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, model_flags, toco_flags, &quant_specs, &node_names, &node_dtypes, &node_shapes, &node_mins, &node_maxs)); - TF_RETURN_IF_ERROR(tensorflow::ParseInputArrayInfo( - node_names, node_dtypes, node_shapes, &specs.inputs)); + TF_RETURN_IF_ERROR( + ParseInputArrayInfo(node_names, node_dtypes, node_shapes, &specs.inputs)); // Parse output arrays. - std::vector output_arrays(model_flags.output_arrays().begin(), - model_flags.output_arrays().end()); - TF_RETURN_IF_ERROR( - tensorflow::ParseOutputArrayInfo(output_arrays, &specs.outputs)); + std::vector output_arrays(model_flags.output_arrays().begin(), + model_flags.output_arrays().end()); + TF_RETURN_IF_ERROR(ParseOutputArrayInfo(output_arrays, &specs.outputs)); // Parse control output arrays. - std::vector control_output_arrays( + std::vector control_output_arrays( model_flags.control_output_arrays().begin(), model_flags.control_output_arrays().end()); - TF_RETURN_IF_ERROR(tensorflow::ParseOutputArrayInfo(control_output_arrays, - &specs.control_outputs)); + TF_RETURN_IF_ERROR( + ParseOutputArrayInfo(control_output_arrays, &specs.control_outputs)); specs.prune_unused_nodes = true; specs.convert_legacy_fed_inputs = true; @@ -118,10 +109,12 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, toco_flags.guarantee_all_funcs_one_use(); pass_config.enable_stablehlo_conversion = toco_flags.convert_to_stablehlo(); + // StableHLO Quantizer is not supported for GraphDef inputs, so + // quantization_py_function_lib is set to nullptr. return internal::ConvertMLIRToTFLiteFlatBuffer( model_flags, toco_flags, std::move(module), pass_config, - /*saved_model_tags=*/{}, result, - /*session=*/std::nullopt); + /*saved_model_tags=*/{}, result, /*saved_model_bundle=*/nullptr, + /*quantization_py_function_lib=*/nullptr); } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h index e69d3c718d9b37..54f8a996e8883c 100644 --- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h @@ -15,9 +15,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_GRAPHDEF_TO_TFL_FLATBUFFER_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_GRAPHDEF_TO_TFL_FLATBUFFER_H_ +#include + +#include "absl/status/status.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_debug_info.pb.h" -#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_flags.pb.h" @@ -26,10 +28,10 @@ namespace tensorflow { // Converts the given GraphDef to a TF Lite FlatBuffer string according to the // given model flags, toco flags and debug information. Returns error status if // it fails to convert the input. -Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, - const toco::TocoFlags& toco_flags, - const GraphDebugInfo& debug_info, - const GraphDef& input, string* result); +absl::Status ConvertGraphDefToTFLiteFlatBuffer( + const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags, + const GraphDebugInfo& debug_info, const GraphDef& input, + std::string* result); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc index f81b5e8b5da6a7..b25040827ccbae 100644 --- a/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc @@ -20,27 +20,23 @@ limitations under the License. #include #include +#include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/strings/str_join.h" -#include "absl/types/span.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringSet.h" -#include "llvm/Support/ToolOutputFile.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/FileUtilities.h" // from @llvm-project -#include "mlir/Transforms/ViewOpGraph.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h" -#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" -#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_parser.h" #include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" @@ -49,23 +45,24 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_flags.pb.h" #include "tensorflow/lite/toco/types.pb.h" -#include "tsl/platform/statusor.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep namespace tensorflow { namespace { // Error collector that simply ignores errors reported. -class NoOpErrorCollector : public tensorflow::protobuf::io::ErrorCollector { +class NoOpErrorCollector : public protobuf::io::ErrorCollector { public: - void AddError(int line, int column, const string& message) override {} + void AddError(int line, int column, const std::string& message) override {} }; bool LoadHloProto(const std::string& contents, xla::HloProto* hlo_proto) { - tensorflow::protobuf::TextFormat::Parser parser; + // NOLINTNEXTLINE: Use tsl::protobuf to be compatible with OSS. + tsl::protobuf::TextFormat::Parser parser; NoOpErrorCollector collector; parser.RecordErrorsTo(&collector); return hlo_proto->ParseFromString(contents) || @@ -75,10 +72,10 @@ bool LoadHloProto(const std::string& contents, xla::HloProto* hlo_proto) { } mlir::OwningOpRef HloToMlirHloTranslateFunction( - llvm::StringRef input, mlir::MLIRContext* context, + mlir::StringRef input, mlir::MLIRContext* context, bool import_all_computations) { xla::HloProto hlo_proto; - string content(input.data(), input.size()); + std::string content(input.data(), input.size()); if (!LoadHloProto(content, &hlo_proto)) { LOG(ERROR) << "Failed to load proto"; return nullptr; @@ -100,7 +97,7 @@ mlir::OwningOpRef HloTextToMlirHloTranslateFunction( llvm::StringRef input, mlir::MLIRContext* context, bool import_all_computations) { xla::HloProto hlo_proto; - string content(input.data(), input.size()); + std::string content(input.data(), input.size()); auto hlo_module_error = xla::ParseAndReturnUnverifiedModule(content); if (!hlo_module_error.ok()) { @@ -122,16 +119,16 @@ mlir::OwningOpRef HloTextToMlirHloTranslateFunction( } } // namespace -Status ConvertJaxToTFLiteFlatBuffer(const std::string& input, - const toco::ModelFlags& model_flags, - const toco::TocoFlags& toco_flags, - string* result) { +absl::Status ConvertJaxToTFLiteFlatBuffer(const std::string& input, + const toco::ModelFlags& model_flags, + const toco::TocoFlags& toco_flags, + std::string* result) { mlir::MLIRContext context; mlir::quant::QuantizationSpecs quant_specs; // Parse input arrays. - std::vector node_names; - std::vector node_dtypes; + std::vector node_names; + std::vector node_dtypes; std::vector>> node_shapes; std::vector> node_mins; std::vector> node_maxs; @@ -191,10 +188,12 @@ Status ConvertJaxToTFLiteFlatBuffer(const std::string& input, // phase. main_func->setAttr("tf.entry_function", builder.getDictionaryAttr(attrs)); + // StableHLO Quantizer is not supported for JAX input models, so + // quantization_py_function_lib is set to nullptr. auto status = internal::ConvertMLIRToTFLiteFlatBuffer( model_flags, toco_flags, std::move(module), pass_config, - /*saved_model_tags=*/{}, result, - /*session=*/std::nullopt); + /*saved_model_tags=*/{}, result, /*saved_model_bundle=*/nullptr, + /*quantization_py_function_lib=*/nullptr); return status; } diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index 917eeb400f41a8..57550c9f5b0f9d 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_debug_info.pb.h" @@ -48,6 +49,8 @@ limitations under the License. namespace tensorflow { +using tensorflow::quantization::PyFunctionLibrary; + Status HandleInputOutputArraysWithModule( const toco::ModelFlags& model_flags, mlir::OwningOpRef* module) { @@ -124,9 +127,10 @@ Status HandleInputOutputArraysWithModule( return OkStatus(); } -Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, - const toco::TocoFlags& toco_flags, - string* result) { +Status ConvertSavedModelToTFLiteFlatBuffer( + const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags, + std::string* result, + const PyFunctionLibrary* quantization_py_function_lib) { mlir::MLIRContext context; mlir::quant::QuantizationSpecs quant_specs; @@ -199,6 +203,7 @@ Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, pass_config.enable_stablehlo_conversion = toco_flags.convert_to_stablehlo(); pass_config.legalize_custom_tensor_list_ops = toco_flags.legalize_custom_tensor_list_ops(); + pass_config.enable_stablehlo_quantizer = toco_flags.has_quantization_config(); if (toco_flags.qdq_conversion_mode() == "STATIC") { pass_config.quant_specs.qdq_conversion_mode = @@ -228,7 +233,7 @@ Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, // TODO(b/153507667): Pass the session object when importing logic is removed. auto status = internal::ConvertMLIRToTFLiteFlatBuffer( model_flags, toco_flags, std::move(module), pass_config, tags, result, - bundle ? bundle->GetSession() : nullptr); + bundle.get(), quantization_py_function_lib); return status; } diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h index 362e9e39ae54c8..50d61dbd4f873b 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_SAVED_MODEL_TO_TFL_FLATBUFFER_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_SAVED_MODEL_TO_TFL_FLATBUFFER_H_ +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/lib/core/errors.h" @@ -28,7 +29,8 @@ namespace tensorflow { // status if it fails to convert the input. Status ConvertSavedModelToTFLiteFlatBuffer( const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags, - string* result); + string* result, + const quantization::PyFunctionLibrary* quantization_py_function_lib); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc index 980d74d6a47aa5..795485328d4779 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc @@ -15,48 +15,55 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h" #include -#include #include #include #include #include +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "llvm/Support/ToolOutputFile.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/ViewOpGraph.h" // from @llvm-project +#include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" -#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_debug_info.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/public/session.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_flags.pb.h" #include "tensorflow/lite/toco/types.pb.h" #include "tensorflow/lite/tools/optimize/reduced_precision_support.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep #include "tsl/platform/statusor.h" -using tsl::StatusOr; - namespace tensorflow { namespace internal { namespace { using ::mlir::quant::ReducedPrecisionSupport; +using ::tensorflow::quantization::PyFunctionLibrary; // Op def string for TFLite_Detection_PostProcess Op. -const char kDetectionPostProcessOp[] = +constexpr mlir::StringRef kDetectionPostProcessOp = "name: 'TFLite_Detection_PostProcess' input_arg: { name: " "'raw_outputs/box_encodings' type: DT_FLOAT } input_arg: { name: " "'raw_outputs/class_predictions' type: DT_FLOAT } input_arg: { name: " @@ -74,7 +81,7 @@ const char kDetectionPostProcessOp[] = "'detections_per_class' type: 'int' default_value { i : 100 }} attr { " "name: 'use_regular_nms' type: 'bool' default_value { b : false }}"; -const char kUnidirectionalSequenceLstmOp[] = +constexpr mlir::StringRef kUnidirectionalSequenceLstmOp = "name: 'UnidirectionalSequenceLstm' input_arg: {name: 'Input' type: " "DT_FLOAT} input_arg: { name: 'InputToInputWeights' type: DT_FLOAT } " "input_arg: { name: 'InputToForgetWeights' type: DT_FLOAT } input_arg: { " @@ -98,7 +105,7 @@ const char kUnidirectionalSequenceLstmOp[] = "'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: DT_FLOAT} " "attr : { name: '_tflite_input_indices' type: 'list(int)'}"; -const char kUnidirectionalSequenceRnnOp[] = +constexpr mlir::StringRef kUnidirectionalSequenceRnnOp = "name: 'UnidirectionalSequenceRnn' input_arg: {name: 'Input' type: " "DT_FLOAT} input_arg: { name: 'Weights' type: DT_FLOAT } " "input_arg: { name: 'RecurrentWeights' type: DT_FLOAT } input_arg: { " @@ -158,8 +165,9 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) { } } -StatusOr> InputStatsToMinMax(double mean, double std, - DataType type) { +absl::StatusOr> InputStatsToMinMax(double mean, + double std, + DataType type) { // Only qint8 and quint8 are considered here. double qmin, qmax; if (type == DT_QUINT8) { @@ -169,58 +177,59 @@ StatusOr> InputStatsToMinMax(double mean, double std, qmin = -128.0; qmax = 127.0; } else { - return errors::InvalidArgument("Only int8 and uint8 are considered."); + return absl::InvalidArgumentError("Only int8 and uint8 are considered."); } return std::make_pair((qmin - mean) / std, (qmax - mean) / std); } -Status RegisterCustomBuiltinOps(const std::vector extra_tf_opdefs) { +absl::Status RegisterCustomBuiltinOps( + const std::vector extra_tf_opdefs) { for (const auto& tf_opdefs_string : extra_tf_opdefs) { - tensorflow::OpDef opdef; - if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string, - &opdef)) { + OpDef opdef; + // NOLINTNEXTLINE: Use tsl::protobuf to be compatible with OSS. + if (!tsl::protobuf::TextFormat::ParseFromString(tf_opdefs_string, &opdef)) { return errors::InvalidArgument("fail to parse extra OpDef"); } // Make sure the op is not already registered. If registered continue. const OpRegistrationData* op_reg = - tensorflow::OpRegistry::Global()->LookUp(opdef.name()); + OpRegistry::Global()->LookUp(opdef.name()); if (op_reg) continue; - tensorflow::OpRegistry::Global()->Register( - [opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status { - *op_reg_data = tensorflow::OpRegistrationData(opdef); - return OkStatus(); + OpRegistry::Global()->Register( + [opdef](OpRegistrationData* op_reg_data) -> absl::Status { + *op_reg_data = OpRegistrationData(opdef); + return absl::OkStatus(); }); } - return OkStatus(); + return absl::OkStatus(); } } // namespace -Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags) { +absl::Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags) { // Register any custom OpDefs. - std::vector extra_tf_opdefs(toco_flags.custom_opdefs().begin(), - toco_flags.custom_opdefs().end()); - extra_tf_opdefs.push_back(kDetectionPostProcessOp); - extra_tf_opdefs.push_back(kUnidirectionalSequenceLstmOp); - extra_tf_opdefs.push_back(kUnidirectionalSequenceRnnOp); + std::vector extra_tf_opdefs(toco_flags.custom_opdefs().begin(), + toco_flags.custom_opdefs().end()); + extra_tf_opdefs.push_back(kDetectionPostProcessOp.str()); + extra_tf_opdefs.push_back(kUnidirectionalSequenceLstmOp.str()); + extra_tf_opdefs.push_back(kUnidirectionalSequenceRnnOp.str()); return RegisterCustomBuiltinOps(extra_tf_opdefs); } -Status PopulateQuantizationSpecs( +absl::Status PopulateQuantizationSpecs( const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags, mlir::quant::QuantizationSpecs* quant_specs, - std::vector* node_names, std::vector* node_dtypes, + std::vector* node_names, std::vector* node_dtypes, std::vector>>* node_shapes, std::vector>* node_mins, std::vector>* node_maxs) { quant_specs->inference_input_type = ConvertIODataTypeToDataType(toco_flags.inference_input_type()); - tensorflow::DataType inference_type = + DataType inference_type = ConvertIODataTypeToDataType(toco_flags.inference_type()); // Use non-float flag `inference_input_type` to override the `inference_type` // because we have to apply quantization to satisfy that. - if (quant_specs->inference_input_type != tensorflow::DT_FLOAT) { + if (quant_specs->inference_input_type != DT_FLOAT) { inference_type = quant_specs->inference_input_type; } @@ -270,11 +279,11 @@ Status PopulateQuantizationSpecs( quant_specs->disable_per_channel = toco_flags.disable_per_channel_quantization(); if (toco_flags.quantize_to_float16()) { - quant_specs->inference_type = tensorflow::DT_HALF; - quant_specs->inference_input_type = tensorflow::DT_HALF; + quant_specs->inference_type = DT_HALF; + quant_specs->inference_input_type = DT_HALF; } else { - quant_specs->inference_type = tensorflow::DT_QINT8; - quant_specs->inference_input_type = tensorflow::DT_QINT8; + quant_specs->inference_type = DT_QINT8; + quant_specs->inference_input_type = DT_QINT8; } } else { // These flags are incompatible with post_training_quantize() as only @@ -313,11 +322,14 @@ Status PopulateQuantizationSpecs( toco_flags.enable_mlir_dynamic_range_quantizer(); quant_specs->enable_mlir_variable_quantization = toco_flags.enable_mlir_variable_quantization(); - return OkStatus(); + quant_specs->disable_per_channel_for_dense_layers = + toco_flags.disable_per_channel_quantization_for_dense_layers(); + return absl::OkStatus(); } // Dumps the op graph of the `module` to `filename` in DOT format. -Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) { +absl::Status DumpOpGraphToFile(mlir::ModuleOp module, + const std::string& filename) { std::string error_message; auto output = mlir::openOutputFile(filename, &error_message); if (!error_message.empty()) { @@ -329,15 +341,16 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) { return errors::Unknown("Failed to dump Op Graph from MLIR module."); } output->keep(); - return OkStatus(); + return absl::OkStatus(); } -Status ConvertMLIRToTFLiteFlatBuffer( +absl::Status ConvertMLIRToTFLiteFlatBuffer( const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags, mlir::OwningOpRef module, const mlir::TFL::PassConfig& pass_config, - const std::unordered_set& saved_model_tags, string* result, - std::optional session) { + const std::unordered_set& saved_model_tags, + std::string* result, SavedModelBundle* saved_model_bundle, + const PyFunctionLibrary* quantization_py_function_lib) { if (toco_flags.has_dump_graphviz_dir()) { TF_RETURN_IF_ERROR(DumpOpGraphToFile( module.get(), @@ -361,7 +374,8 @@ Status ConvertMLIRToTFLiteFlatBuffer( auto status = ConvertTFExecutorToTFLOrFlatbuffer( module.get(), /*export_to_mlir=*/false, toco_flags, pass_config_copy, - saved_model_tags, model_flags.saved_model_dir(), session, result); + saved_model_tags, model_flags.saved_model_dir(), saved_model_bundle, + result, /*serialize_stablehlo_ops=*/false, quantization_py_function_lib); if (toco_flags.has_dump_graphviz_dir()) { TF_RETURN_IF_ERROR(DumpOpGraphToFile( // rename once we enable the new converter feature flag. diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h index 8d7eeb2912a7b6..039e56672ddadc 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h @@ -24,8 +24,10 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/core/public/session.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_flags.pb.h" @@ -55,7 +57,8 @@ Status ConvertMLIRToTFLiteFlatBuffer( mlir::OwningOpRef module, const mlir::TFL::PassConfig& pass_config, const std::unordered_set& saved_model_tags, string* result, - std::optional session); + SavedModelBundle* saved_model_bundle, + const quantization::PyFunctionLibrary* quantization_py_function_lib); // Give a warning for any unused flags that have been specified. void WarningUnusedFlags(const toco::ModelFlags& model_flags, diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD index 356801e7fd38a0..cf437c27e2ec4e 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD @@ -2,7 +2,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:LICENSE"], default_visibility = [ ":friends", "//tensorflow:__pkg__", diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc index 4545aa412686b7..f7ffc7d71d02f9 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc @@ -58,7 +58,8 @@ TfLiteStatus QuantizeModel( bool whole_model_verify, bool legacy_float_scale, const absl::flat_hash_set& denylisted_ops, const absl::flat_hash_set& denylisted_nodes, - const bool enable_variable_quantization) { + const bool enable_variable_quantization, + bool disable_per_channel_for_dense_layers) { // Translate TFLite names to mlir op names. absl::flat_hash_set denylisted_mlir_op_names; for (const auto& entry : denylisted_ops) { @@ -84,6 +85,8 @@ TfLiteStatus QuantizeModel( quant_specs.inference_type = tflite::TflTypeToTfType(inference_type); quant_specs.post_training_quantization = true; quant_specs.disable_per_channel = disable_per_channel; + quant_specs.disable_per_channel_for_dense_layers = + disable_per_channel_for_dense_layers; quant_specs.verify_numeric = verify_numeric; quant_specs.whole_model_verify = whole_model_verify; quant_specs.legacy_float_scale = legacy_float_scale; diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h index d85aba47811675..50b397ba0206d2 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h @@ -54,7 +54,8 @@ TfLiteStatus QuantizeModel( bool whole_model_verify = false, bool legacy_float_scale = true, const absl::flat_hash_set& denylisted_ops = {}, const absl::flat_hash_set& denylisted_nodes = {}, - bool enable_variable_quantization = false); + bool enable_variable_quantization = false, + bool disable_per_channel_for_dense_layers = false); } // namespace lite } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc index 5898c9e54234a5..696a2545d7097a 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc @@ -72,7 +72,8 @@ TfLiteStatus QuantizeModel( const TensorType& activations_type, ErrorReporter* error_reporter, std::string& output_buffer, const bool disable_per_channel = false, const absl::flat_hash_set& blocked_ops = {}, - const absl::flat_hash_set& blocked_nodes = {}) { + const absl::flat_hash_set& blocked_nodes = {}, + const bool disable_per_channel_for_dense_layers = false) { TensorType inference_tensor_type = activations_type; const bool fully_quantize = !allow_float; @@ -87,7 +88,10 @@ TfLiteStatus QuantizeModel( input_buffer, input_type, output_type, inference_tensor_type, /*operator_names=*/{}, disable_per_channel, fully_quantize, output_buffer, error_reporter, /*verify_numeric=*/false, /*whole_model_verify=*/false, - /*legacy_float_scale=*/true, blocked_ops, blocked_nodes); + /*legacy_float_scale=*/true, blocked_ops, blocked_nodes, + /*enable_variable_quantization=*/false, + /*disable_per_channel_for_dense_layers=*/ + disable_per_channel_for_dense_layers); if (status != kTfLiteOk) { return status; } @@ -140,6 +144,21 @@ TfLiteStatus QuantizeModelAllOperators( output_buffer); } +TfLiteStatus QuantizeModelAllOperators( + ModelT* model, const TensorType& input_type, const TensorType& output_type, + bool allow_float, const TensorType& activations_type, + ErrorReporter* error_reporter, std::string& output_buffer, + bool disable_per_channel_for_dense_layers) { + return QuantizeModel(model, input_type, output_type, allow_float, + /*operator_names=*/{}, activations_type, error_reporter, + output_buffer, + /*disable_per_channel=*/false, + /* blocked_ops=*/{}, + /*blocked_nodes=*/{}, + /*disable_per_channel_for_dense_layers=*/ + disable_per_channel_for_dense_layers); +} + std::unique_ptr ReadModel(const string& model_name) { auto model_path = tensorflow::io::JoinPath(*g_test_model_dir, model_name); return FlatBufferModel::BuildFromFile(model_path.c_str()); @@ -1118,16 +1137,20 @@ TEST_F(QuantizeSVDFTest, VerifySVDF) { ExpectSameModels(model_, expected_model); } -class QuantizeFCTest : public QuantizeModelTest { +class QuantizeFCTest : public QuantizeModelTest, + public testing::WithParamInterface { protected: QuantizeFCTest() { + disable_per_channel_quantization_for_dense_ = GetParam(); input_model_ = ReadModel(internal::kModelWithFCOp); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } + + bool disable_per_channel_quantization_for_dense_; }; -TEST_F(QuantizeFCTest, VerifyFC8x8) { +TEST_P(QuantizeFCTest, VerifyFC8x8) { auto status = QuantizeModelAllOperators( &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, TensorType_INT8, &error_reporter_, output_buffer_); @@ -1180,7 +1203,7 @@ TEST_F(QuantizeFCTest, VerifyFC8x8) { /*bit_num=*/8, /*symmetric=*/false); } -TEST_F(QuantizeFCTest, VerifyFCFor16x8) { +TEST_P(QuantizeFCTest, VerifyFCFor16x8) { auto status = QuantizeModelAllOperators( &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, TensorType_INT16, &error_reporter_, output_buffer_); @@ -1195,7 +1218,7 @@ TEST_F(QuantizeFCTest, VerifyFCFor16x8) { ASSERT_THAT(op->outputs, SizeIs(1)); const SubGraph* float_graph = readonly_model_->subgraphs()->Get(0); - // Verify FC input tesnor and weight are int16 and int8 quantized. + // Verify FC input tensor and weight are int16 and int8 quantized. const Operator* float_op = float_graph->operators()->Get(0); ASSERT_THAT(float_graph->tensors()->Get(float_op->inputs()->Get(0))->type(), Eq(TensorType_FLOAT32)); @@ -1235,6 +1258,136 @@ TEST_F(QuantizeFCTest, VerifyFCFor16x8) { /*bit_num=*/16, /*symmetric=*/true); } +TEST_P(QuantizeFCTest, VerifyDisablePerChannelQuantization) { + auto status = QuantizeModelAllOperators( + &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, + TensorType_INT8, &error_reporter_, output_buffer_, + /*disable_per_channel_for_dense_layers=*/ + disable_per_channel_quantization_for_dense_); + ASSERT_THAT(status, Eq(kTfLiteOk)); + const auto& subgraph = model_.subgraphs[0]; + auto fc_op = subgraph->operators[0].get(); + + ASSERT_THAT(fc_op->inputs, SizeIs(3)); + ASSERT_THAT(fc_op->outputs, SizeIs(1)); + + const int input_tensor_idx = 0; + const int weights_tensor_idx = 1; + const int bias_tensor_index = 2; + const int output_tensor_idx = 0; + const auto bias_tensor = + subgraph->tensors[fc_op->inputs[bias_tensor_index]].get(); + const auto input_tensor = + subgraph->tensors[fc_op->inputs[input_tensor_idx]].get(); + const auto weights_tensor = + subgraph->tensors[fc_op->inputs[weights_tensor_idx]].get(); + const auto output_tensor = + subgraph->tensors[fc_op->outputs[output_tensor_idx]].get(); + + EXPECT_THAT(bias_tensor->type, Eq(TensorType_INT32)); + EXPECT_THAT(input_tensor->type, Eq(TensorType_INT8)); + EXPECT_THAT(weights_tensor->type, Eq(TensorType_INT8)); + EXPECT_THAT(output_tensor->type, Eq(TensorType_INT8)); + + ASSERT_TRUE(weights_tensor->quantization); + ASSERT_TRUE(bias_tensor->quantization); + ASSERT_TRUE(weights_tensor->quantization); + const std::vector& bias_scales = bias_tensor->quantization->scale; + const std::vector& weights_scales = + weights_tensor->quantization->scale; + const std::vector& weights_zero_points = + weights_tensor->quantization->zero_point; + + const int out_channel_size = 2; + ASSERT_THAT(bias_scales, SizeIs(disable_per_channel_quantization_for_dense_ + ? 1 + : out_channel_size)); + ASSERT_THAT(weights_scales, SizeIs(disable_per_channel_quantization_for_dense_ + ? 1 + : out_channel_size)); + ASSERT_THAT( + weights_zero_points, + SizeIs(disable_per_channel_quantization_for_dense_ ? 1 + : out_channel_size)); + ASSERT_THAT(input_tensor->quantization->scale, SizeIs(1)); + ASSERT_THAT(output_tensor->quantization->scale, SizeIs(1)); + + const float eps = 1e-7; + + // Bias scale should be input * per_channel_weight_scale. + for (size_t i = 0; i < out_channel_size; i++) { + EXPECT_THAT((disable_per_channel_quantization_for_dense_ ? bias_scales[0] + : bias_scales[i]), + FloatNear(input_tensor->quantization->scale[0] * + (disable_per_channel_quantization_for_dense_ + ? weights_scales[0] + : weights_scales[i]), + eps)); + } + + const auto bias_buffer = model_.buffers[bias_tensor->buffer].get(); + auto control_size = sizeof(int32_t) * bias_tensor->shape[0]; + + ASSERT_THAT(bias_buffer->data, SizeIs(control_size)); + const auto float_op = + readonly_model_->subgraphs()->Get(0)->operators()->Get(0); + const auto original_bias_tensor = + readonly_model_->subgraphs()->Get(0)->tensors()->Get( + float_op->inputs()->Get(2)); + ASSERT_THAT(bias_buffer->data, SizeIs(control_size)); + const auto original_bias_buffer = + readonly_model_->buffers()->Get(original_bias_tensor->buffer()); + const float* bias_float_buffer = + reinterpret_cast(original_bias_buffer->data()->data()); + + int32_t* bias_values = reinterpret_cast(bias_buffer->data.data()); + for (size_t i = 0; i < out_channel_size; i++) { + const float bias_scale = disable_per_channel_quantization_for_dense_ + ? bias_scales[0] + : bias_scales[i]; + auto dequantized_value = bias_values[i] * bias_scale; + EXPECT_THAT(dequantized_value, + FloatNear(bias_float_buffer[i], bias_scale / 2)); + } + + const auto weights_buffer = model_.buffers[weights_tensor->buffer].get(); + const auto original_weights_tensor = + readonly_model_->subgraphs()->Get(0)->tensors()->Get( + float_op->inputs()->Get(1)); + const auto original_weights_buffer = + readonly_model_->buffers()->Get(original_weights_tensor->buffer()); + const int8_t* weight_values = + reinterpret_cast(weights_buffer->data.data()); + const float* weights_float_buffer = + reinterpret_cast(original_weights_buffer->data()->data()); + ASSERT_THAT(sizeof(float) * weights_buffer->data.size(), + Eq(original_weights_buffer->data()->size())); + int num_values_in_channel = weights_buffer->data.size() / out_channel_size; + for (size_t channel_idx = 0; channel_idx < out_channel_size; channel_idx++) { + for (size_t j = 0; j < num_values_in_channel; j++) { + size_t element_idx = channel_idx * num_values_in_channel + j; + auto scale = disable_per_channel_quantization_for_dense_ + ? weights_scales[0] + : weights_scales[channel_idx]; + auto zero_point = disable_per_channel_quantization_for_dense_ + ? weights_zero_points[0] + : weights_zero_points[channel_idx]; + auto dequantized_value = weight_values[element_idx] * scale; + EXPECT_THAT(dequantized_value, + FloatNear(weights_float_buffer[element_idx], scale / 2)); + EXPECT_THAT(zero_point, Eq(0)); + } + } + + // check op and versioning. + EXPECT_THAT(model_.operator_codes, SizeIs(1)); + EXPECT_THAT(GetBuiltinCode(model_.operator_codes[0].get()), + Eq(BuiltinOperator_FULLY_CONNECTED)); + ASSERT_THAT(model_.operator_codes[0]->version, 5); +} + +INSTANTIATE_TEST_SUITE_P(QuantizeFCTestInst, QuantizeFCTest, testing::Bool()); + class QuantizeCustomOpTest : public QuantizeModelTest, public ::testing::WithParamInterface { diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_config.h b/tensorflow/compiler/mlir/lite/quantization/quantization_config.h index 0361163aeee112..66175aabf394e4 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_config.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_config.h @@ -86,6 +86,11 @@ struct QuantizationSpecs { // weight FakeQuant). bool disable_per_channel = false; + // Disables per channel weights quantization for Dense layers and enables + // legacy per tensor quantization. The legacy quantization for Dense layers is + // inconsistent with Conv 1x1 which always performs per channel quantization. + bool disable_per_channel_for_dense_layers = false; + // When set to true, the fixed output ranges of the activation ops (tanh, // sigmoid, etc.) and the weight constants are not inferred. Then, to quantize // these ops, quantization emulation ops should be placed after the ops in the diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc index 62c2733d2b510c..408540cd84a146 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include #include +#include #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" @@ -32,6 +33,7 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project @@ -109,7 +111,8 @@ class QuantizationDriver { bool disable_per_channel, OpQuantSpecGetter op_quant_spec_getter, OpQuantScaleSpecGetter op_quant_scale_spec_getter, - bool infer_tensor_range, bool legacy_float_scale) + bool infer_tensor_range, bool legacy_float_scale, + bool is_qdq_conversion) : fn_(fn), builder_(fn.getBody()), is_signed_(is_signed), @@ -118,7 +121,8 @@ class QuantizationDriver { op_quant_spec_getter_(op_quant_spec_getter), op_quant_scale_spec_getter_(op_quant_scale_spec_getter), infer_tensor_range_(infer_tensor_range), - legacy_float_scale_(legacy_float_scale) {} + legacy_float_scale_(legacy_float_scale), + is_qdq_conversion_(is_qdq_conversion) {} // The entry point of the quantization parameters propagation. void Run(); @@ -198,7 +202,7 @@ class QuantizationDriver { // Returns the quantization params for the bias input from the non-bias // operands which have their indexes in the `non_biases` vector. The returned // parameters are calculated by `func`. - QuantParams GetBiasParams(Operation *op, int bias, + QuantParams GetBiasParams(Operation *op, int bias_index, const std::vector &non_biases, AccumulatorScaleFunc func); @@ -429,6 +433,10 @@ class QuantizationDriver { // Calculate scales in float instead of double, so that the scales and // quantized values are exactly the same with the TOCO quantizer. bool legacy_float_scale_; + + // If true, the model is a floating point graph with QDQ ops to be eliminated + // and fused into quantized kernels. + bool is_qdq_conversion_; }; } // namespace @@ -518,20 +526,35 @@ bool QuantizationDriver::SetResultParams(Operation *op, int res_index, } QuantParams QuantizationDriver::GetBiasParams( - Operation *op, int bias, const std::vector &non_biases, + Operation *op, const int bias_index, const std::vector &non_biases, AccumulatorScaleFunc func) { - auto &bias_state = GetOperandQuantState(op, bias); + QuantState &bias_state = GetOperandQuantState(op, bias_index); if (!bias_state.IsEmpty()) { return bias_state.params; } std::vector op_types; op_types.reserve(non_biases.size()); + int adjusted_quant_dim = -1; + if (op->getNumOperands() > bias_index) { + // Some kernels allow 1D bias, broadcasting it inside the kernel. In this + // case, the `quantizedDimension=0` when quantizing per-channel. + // However, for some kernels which require bias to be already broadcasted + // to match the accumulation shape, the very last index should be used. + Operation *bias_op = op->getOperand(bias_index).getDefiningOp(); + if (bias_op != nullptr) { + Type bias_type = bias_op->getResult(0).getType(); + if (bias_type != builder_.getNoneType()) { + int bias_rank = bias_type.dyn_cast().getRank(); + adjusted_quant_dim = bias_rank > 1 ? bias_rank - 1 : 0; + } + } + } + for (auto non_bias : non_biases) { auto &non_bias_type = GetOperandQuantState(op, non_bias); op_types.push_back(non_bias_type.params); } - if (op_types.empty()) return {}; - return func(op_types, legacy_float_scale_); + return func(op_types, adjusted_quant_dim, legacy_float_scale_); } bool QuantizationDriver::SetOperandParams(Operation *op, int index, @@ -956,7 +979,10 @@ bool QuantizationDriver::PropagateParams() { } } - if (scale_spec->has_fixed_output_range && infer_tensor_range_) { + // If the model already contains immutable QDQs, require upstream to + // explicitly fix output range instead. + if (scale_spec->has_fixed_output_range && infer_tensor_range_ && + !is_qdq_conversion_) { // Infer ranges from the activation ops. This is usually required for // the post-training quantization workflow. // TODO(fengliuai): different result can have different fixed range. @@ -1182,20 +1208,22 @@ void ApplyQuantizationParamsPropagation(mlir::func::FuncOp func, bool is_signed, int bit_width, bool disable_per_channel, OpQuantSpecGetter op_quant_spec_getter, bool infer_tensor_ranges, - bool legacy_float_scale) { + bool legacy_float_scale, + bool is_qdq_conversion) { ApplyQuantizationParamsPropagation( func, is_signed, bit_width, disable_per_channel, op_quant_spec_getter, - GetDefaultQuantScaleSpec, infer_tensor_ranges, legacy_float_scale); + GetDefaultQuantScaleSpec, infer_tensor_ranges, legacy_float_scale, + is_qdq_conversion); } void ApplyQuantizationParamsPropagation( mlir::func::FuncOp func, bool is_signed, int bit_width, bool disable_per_channel, OpQuantSpecGetter op_quant_spec_getter, OpQuantScaleSpecGetter op_quant_scale_spec_getter, bool infer_tensor_ranges, - bool legacy_float_scale) { + bool legacy_float_scale, bool is_qdq_conversion) { QuantizationDriver(func, is_signed, bit_width, disable_per_channel, op_quant_spec_getter, op_quant_scale_spec_getter, - infer_tensor_ranges, legacy_float_scale) + infer_tensor_ranges, legacy_float_scale, is_qdq_conversion) .Run(); } diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc index 9a151a80e8f48b..53f8024c7900dd 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -29,13 +30,20 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h" @@ -43,7 +51,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/UniformSupport.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h" -#include "tensorflow/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h" #include "tensorflow/lite/tools/optimize/quantization_utils.h" namespace mlir { @@ -469,7 +477,7 @@ Type GetUniformQuantizedPerAxisTypeForWeight(ElementsAttr attr, int quant_dim, quant::QuantizedType GetUniformQuantizedTypeForBias( const std::vector& op_types, - bool legacy_float_scale) { + const int adjusted_quant_dim, const bool legacy_float_scale) { if (op_types.empty()) return {}; size_t axis_size = 1; @@ -531,13 +539,14 @@ quant::QuantizedType GetUniformQuantizedTypeForBias( /*zeroPoint=*/0, storage_type_min, storage_type_max); } else { llvm::SmallVector zero_points(axis_size, 0); - // Assume the bias is a 1-D tensor, and set the quantization dim to the last - // dimension, which is 0. If the bias rank is larger than 1, this returned - // quantized type couldn't be used to quantize the bias. + // If the bias is a 1-D tensor, set the `quantizedDimension` to 0. + // If the bias rank is larger than 1 because it was already broadcasted + // to match the output shape, use the last index. return quant::UniformQuantizedPerAxisType::getChecked( builder.getUnknownLoc(), /*flags=*/true, storage_type, expressed_type, scales, zero_points, - /*quantizedDimension=*/0, storage_type_min, storage_type_max); + /*quantizedDimension=*/std::max(adjusted_quant_dim, 0), + storage_type_min, storage_type_max); } } @@ -598,7 +607,7 @@ ElementsAttr QuantizeLegacy(Attribute real_value, Type tensor_type) { return DenseElementsAttr::get(new_dense_type, quantized_attr); } else if (width == 8) { // This can be a state tensor, or an actual constant tensor with - // asymmetric range. For a state tensor, assigining correct quantization + // asymmetric range. For a state tensor, assigning correct quantization // parameters is sufficient, and for constants with asymmetric range it's // not correctly quantized by legacy quantizer so call the new Quantize. return Quantize(real_value, tensor_type); @@ -643,7 +652,7 @@ ElementsAttr Quantize(Attribute real_value, Type tensor_type) { quant::QuantizedType::getQuantizedElementType(tensor_type)) { Type converted_type; return quantfork::quantizeAttr(real_value, q_type, converted_type) - .dyn_cast(); + .dyn_cast_or_null(); } return {}; } @@ -816,7 +825,7 @@ bool RemoveRedundantStatsOps( } } - // Step 2: backward pass: For the ops skiped in the forward pass, propagate + // Step 2: backward pass: For the ops skipped in the forward pass, propagate // its results scale backwards as far as possible. func.walk([&](quantfork::StatisticsOp stats_op) { if (redundant_stats_ops.find(stats_op) == redundant_stats_ops.end()) { diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h index 1113bb868fa3e8..e1b697e3be67d1 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h @@ -66,29 +66,29 @@ namespace quant { // A unit attribute can be attached to the quantize/dequantize ops which are // added by the quantization passes. These ops can be removed erased without // losing accuracy. -constexpr char kVolatileOpAttrName[] = "volatile"; +inline constexpr char kVolatileOpAttrName[] = "volatile"; // Following attributes are used to mark ops that are not quantizable during // debug model generation process for whole-model verify mode. If these // attributes are attached, the upstream float/quantized ops know which ops to // connect to, and it also prevents these ops from being copied again. -constexpr char kDebugModeOpFloatAttrName[] = "debug_float"; -constexpr char kDebugModeOpQuantAttrName[] = "debug_quant"; +inline constexpr char kDebugModeOpFloatAttrName[] = "debug_float"; +inline constexpr char kDebugModeOpQuantAttrName[] = "debug_quant"; // Used to annotate custom ops if they are quantizable. -constexpr char kQuantTraitAttrName[] = "_tfl_quant_trait"; +inline constexpr char kQuantTraitAttrName[] = "_tfl_quant_trait"; enum QuantizationTrait { FullyQuantizable = 0, NotQuantizable = 1 }; -constexpr absl::string_view QuantTraitValues[] = {"fully_quantizable", - "not_quantizable"}; +inline constexpr absl::string_view QuantTraitValues[] = {"fully_quantizable", + "not_quantizable"}; -constexpr double kNearZeroTolerance = 1.0e-6; +inline constexpr double kNearZeroTolerance = 1.0e-6; using QuantParams = QuantizedType; using QuantSpec = QuantizationSpecs; using SignedInteger = std::pair; // bitwidth and sign using QuantParamsForResults = llvm::SmallVector; using AccumulatorScaleFunc = - std::function&, bool)>; + std::function&, int, bool)>; using BiasParamsMap = std::unordered_map, AccumulatorScaleFunc>>; // UniformQuantizedType GetFixedOutputRange(bool sign, int bit_width) @@ -890,7 +890,7 @@ Type GetUniformQuantizedPerAxisTypeForWeight( // other operands which are multiply-accumulated (the bias is added to the // accumulated value). quant::QuantizedType GetUniformQuantizedTypeForBias( - const std::vector& op_types, + const std::vector& op_types, int adjusted_quant_dim, bool legacy_float_scale = false); // Propagates quantization parameters across ops in this function and satisfy @@ -906,13 +906,14 @@ void ApplyQuantizationParamsPropagation(mlir::func::FuncOp func, bool is_signed, int bit_width, bool disable_per_channel, OpQuantSpecGetter op_quant_spec_getter, bool infer_tensor_ranges, - bool legacy_float_scale = false); + bool legacy_float_scale = false, + bool is_qdq_conversion = false); void ApplyQuantizationParamsPropagation( mlir::func::FuncOp func, bool is_signed, int bit_width, bool disable_per_channel, OpQuantSpecGetter op_quant_spec_getter, OpQuantScaleSpecGetter op_quant_scale_spec_getter, bool infer_tensor_ranges, - bool legacy_float_scale = false); + bool legacy_float_scale = false, bool is_qdq_conversion = false); // Gets quantization scale specs (e.g. fixed output range, same result and // operand scales) from the default quantization interfaces. The op should diff --git a/tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD new file mode 100644 index 00000000000000..7f6b74431a95b7 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD @@ -0,0 +1,48 @@ +load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/compiler/mlir/lite:__subpackages__"], + licenses = ["notice"], +) + +cc_library( + name = "quantization", + srcs = ["quantization.cc"], + hdrs = ["quantization.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/cc/saved_model:constants", + "//tensorflow/cc/saved_model:loader", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:static_range_ptq", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", + "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_freeze_variables", + "//tensorflow/core/protobuf:for_core_protos_cc", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +tf_cc_test( + name = "quantization_test", + srcs = ["quantization_test.cc"], + deps = [ + ":quantization", + "//tensorflow/cc/saved_model:loader", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:io", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:static_range_ptq_impl", # buildcleaner: keep; prevents undefined reference + "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibrator_singleton_impl", # buildcleaner: keep; prevents undefined reference + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:status_matchers", + ], +) diff --git a/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.cc b/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.cc new file mode 100644 index 00000000000000..929634164fba3b --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.cc @@ -0,0 +1,118 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.h" + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace tensorflow { +namespace { + +using ::mlir::quant::stablehlo::StaticRangePtqComponent; +using ::stablehlo::quantization::QuantizationConfig; +using ::tensorflow::SignatureDef; +using ::tensorflow::quantization::PyFunctionLibrary; + +// Returns signature key -> `SignatureDef` mapping, excluding the signature for +// initialization op, which is only used during initialization. +// TODO: b/314124142 - Remove the need for this function. +absl::flat_hash_map GetSignatureDefMapFromBundle( + const SavedModelBundle& saved_model_bundle) { + // Translate protobuf::Map -> absl::flat_hash_map. + const protobuf::Map& signatures = + saved_model_bundle.GetSignatures(); + absl::flat_hash_map signature_def_map( + signatures.begin(), signatures.end()); + + // Init op is only used during initialization and it's not a target for + // quantization. + signature_def_map.erase(kSavedModelInitOpSignatureKey); + return signature_def_map; +} + +// Retrieves the function name -> function alias mapping from the +// `SavedModelBundle`. +// TODO: b/314124142 - Remove the need for this function. +absl::flat_hash_map GetFunctionAliases( + const SavedModelBundle& saved_model_bundle) { + const protobuf::Map& function_aliases = + saved_model_bundle.meta_graph_def.meta_info_def().function_aliases(); + return absl::flat_hash_map(function_aliases.begin(), + function_aliases.end()); +} + +} // namespace + +absl::StatusOr RunQuantization( + const SavedModelBundle* saved_model_bundle, + const absl::string_view saved_model_dir, + const std::unordered_set& saved_model_tags, + const QuantizationConfig& quantization_config, + const PyFunctionLibrary* quantization_py_function_lib, + mlir::ModuleOp module_op) { + if (saved_model_bundle == nullptr) { + return absl::InvalidArgumentError( + "Failed to run quantization. `saved_model_bundle` should not be " + "nullptr."); + } + + if (quantization_py_function_lib == nullptr) { + return absl::InvalidArgumentError( + "Failed to run quantization. `quantization_py_function_lib` should not " + "be nullptr."); + } + + const absl::flat_hash_map signature_def_map = + GetSignatureDefMapFromBundle(*saved_model_bundle); + + std::vector exported_names; + for (const auto& [key, value_unused] : signature_def_map) { + exported_names.push_back(key); + } + + if (failed(mlir::tf_saved_model::FreezeVariables( + module_op, saved_model_bundle->GetSession()))) { + return absl::InternalError("Failed to freeze variables."); + } + + StaticRangePtqComponent static_range_ptq_component( + module_op.getContext(), quantization_py_function_lib, saved_model_dir, + /*signature_keys=*/exported_names, saved_model_tags, signature_def_map, + GetFunctionAliases(*saved_model_bundle)); + const absl::StatusOr quantized_module_op = + static_range_ptq_component.Run(module_op, quantization_config); + if (!quantized_module_op.ok()) { + return absl::InternalError("Failed to run quantization. Status msg: " + + quantized_module_op.status().ToString()); + } + return quantized_module_op; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.h b/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.h new file mode 100644 index 00000000000000..c55d59cad0f1a0 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.h @@ -0,0 +1,60 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Adaptor functions for StableHLO Quantizer. +// Provides simpler interfaces when integrating StableHLO Quantizer into TFLite +// Converter. + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_STABLEHLO_QUANTIZATION_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_STABLEHLO_QUANTIZATION_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" + +namespace tensorflow { + +// Runs quantization on `module_op`. `saved_model_bundle` is required to +// retrieve information about the original model (e.g. signature def mapping) +// because quantization requires exporting the intermediate `ModuleOp` back to +// SavedModel for calibration. Similarly, `saved_model_dir` is required to +// access the assets of the original model. `saved_model_tags` uniquely +// identifies the `MetaGraphDef`. `quantization_config` determines the behavior +// of StableHLO Quantizer. `quantization_py_function_lib` contains python +// implementations of certain APIs that are required for calibration. +// `module_op` is the input graph to be quantized and it should contain +// StableHLO ops. +// +// Returns a quantized `ModuleOp` in StableHLO, potentially wrapped inside a +// XlaCallModuleOp. Returns a non-OK status if quantization fails, or any of +// `saved_model_bundle` or `quantization_py_function_lib` is a nullptr. +absl::StatusOr RunQuantization( + const SavedModelBundle* saved_model_bundle, + absl::string_view saved_model_dir, + const std::unordered_set& saved_model_tags, + const stablehlo::quantization::QuantizationConfig& quantization_config, + const tensorflow::quantization::PyFunctionLibrary* + quantization_py_function_lib, + mlir::ModuleOp module_op); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_STABLEHLO_QUANTIZATION_H_ diff --git a/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization_test.cc b/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization_test.cc new file mode 100644 index 00000000000000..3cbc9e6ea47864 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization_test.cc @@ -0,0 +1,79 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Test cases for the StableHLO Quantizer adaptor functions. + +#include "tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.h" + +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tsl/platform/status_matchers.h" + +namespace tensorflow { +namespace { + +using ::stablehlo::quantization::QuantizationConfig; +using ::stablehlo::quantization::io::CreateTmpDir; +using ::testing::HasSubstr; +using ::tsl::testing::IsOk; +using ::tsl::testing::StatusIs; + +// Test cases for `RunQuantization` mainly tests for error cases because testing +// for successful cases require passing python implementation to +// `quantization_py_function_lib`, which requires testing from the python level. +// Internal integration tests exist for testing successful quantization. + +TEST(RunQuantizationTest, + WhenSavedModelBundleIsNullptrReturnsInvalidArgumentError) { + const absl::StatusOr tmp_saved_model_dir = CreateTmpDir(); + ASSERT_THAT(tmp_saved_model_dir, IsOk()); + + const absl::StatusOr quantized_module_op = RunQuantization( + /*saved_model_bundle=*/nullptr, *tmp_saved_model_dir, + /*saved_model_tags=*/{}, QuantizationConfig(), + /*quantization_py_function_lib=*/nullptr, /*module_op=*/{}); + EXPECT_THAT( + quantized_module_op, + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("`saved_model_bundle` should not be nullptr"))); +} + +TEST(RunQuantizationTest, + WhenPyFunctionLibIsNullptrReturnsInvalidArgumentError) { + const absl::StatusOr tmp_saved_model_dir = CreateTmpDir(); + ASSERT_THAT(tmp_saved_model_dir, IsOk()); + + // Dummy SavedModelBundle to pass a non-nullptr argument. + SavedModelBundle bundle{}; + const absl::StatusOr quantized_module_op = RunQuantization( + /*saved_model_bundle=*/&bundle, *tmp_saved_model_dir, + /*saved_model_tags=*/{}, QuantizationConfig(), + /*quantization_py_function_lib=*/nullptr, /*module_op=*/{}); + EXPECT_THAT( + quantized_module_op, + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr("`quantization_py_function_lib` should not be nullptr"))); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD index cdefbdb1e28a4e..5110947df307e3 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD @@ -58,17 +58,23 @@ cc_library( "passes.h", ], deps = [ + ":ptq_fallback_to_flex_inc_gen", "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:statusor", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/strings:string_view", "@flatbuffers", + "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", ], alwayslink = 1, diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/fallback_to_flex_ops.cc b/tensorflow/compiler/mlir/lite/quantization/tensorflow/fallback_to_flex_ops.cc index 759893401e69d7..dbd0b9ea524cb5 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/fallback_to_flex_ops.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/fallback_to_flex_ops.cc @@ -18,13 +18,32 @@ limitations under the License. #include #include "absl/base/attributes.h" +#include "absl/strings/string_view.h" #include "flatbuffers/flexbuffers.h" // from @flatbuffers +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/platform/statusor.h" namespace mlir { namespace TF { diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc index 9dfe5166033277..cc8543613162aa 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc @@ -15,8 +15,21 @@ limitations under the License. #include #include +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" diff --git a/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc b/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc index ad4112a05ad4a9..47440b4c4c0beb 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc @@ -15,7 +15,10 @@ limitations under the License. #include +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/PrettyStackTrace.h" @@ -24,6 +27,7 @@ limitations under the License. #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" #include "mlir/TableGen/Operator.h" // from @llvm-project +#include "mlir/TableGen/Trait.h" // from @llvm-project using llvm::LessRecord; using llvm::raw_ostream; @@ -50,7 +54,8 @@ static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) { llvm::sort(defs, LessRecord()); OUT(0) << "static std::unique_ptr " - "GetOpQuantSpec(mlir::Operation *op) {\n"; + "GetOpQuantSpec(mlir::Operation *op, bool " + "disable_per_channel_for_dense_layers = false) {\n"; // TODO(b/176258587): Move to OpTrait if this should be generalized. // Add special handling for LSTM. OUT(2) << "if (auto lstm_op = llvm::dyn_cast(op)) {\n"; @@ -94,7 +99,9 @@ static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) { // There is a "QuantChannelDim" trait, set the quantization dimension. if (coeff_index_trait_regex.match(trait_str, &matches)) { OUT(4) << "spec->coeff_op_quant_dim[tfl.GetCoefficientOperandIndex()" - << "] = tfl.GetQuantizationDim();\n"; + << "] = llvm::dyn_cast(op) && " + "disable_per_channel_for_dense_layers ? -1 : " + "tfl.GetQuantizationDim();\n"; matches.clear(); } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index 5ec6c735df2963..2468648d166f6b 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -423,6 +423,47 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "legalize_stablehlo_to_vhlo_pass", + srcs = [ + "transforms/legalize_stablehlo_to_vhlo.cc", + ], + hdrs = [ + "transforms/passes.h", + "transforms/passes.h.inc", + ], + copts = [ + "-Ithird_party", + ], + deps = [ + "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:ReconcileUnrealizedCasts", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@stablehlo//:base", + "@stablehlo//:stablehlo_ops", + "@stablehlo//:stablehlo_ops_inc_gen", + "@stablehlo//:stablehlo_pass_inc_gen", + "@stablehlo//:stablehlo_passes", + "@stablehlo//:stablehlo_portable_api", + "@stablehlo//:stablehlo_type_inference", + "@stablehlo//:version", + "@stablehlo//:vhlo_ops", + "@stablehlo//:vhlo_types", + ], + alwayslink = 1, +) + cc_library( name = "optimize", srcs = [ @@ -457,6 +498,7 @@ cc_library( deps = [ ":passes_inc_gen", "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", "//tensorflow/compiler/mlir/quantization/common:uniform_quantized_types", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log:check", @@ -602,8 +644,10 @@ tf_cc_binary( "//tensorflow/compiler/mlir/lite:flatbuffer_export", "//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer", "//tensorflow/compiler/mlir/lite/stablehlo/serializer:flatbuffer_export", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:static_range_ptq_impl", "//tensorflow/compiler/mlir/quantization/tensorflow:quantize_preprocess", "//tensorflow/compiler/mlir/quantization/tensorflow:tf_quant_ops", + "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibrator_singleton_impl", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_graph_optimization_pass", @@ -634,6 +678,7 @@ tf_cc_binary( ":compose_uniform_quantized_type_pass", ":fold_broadcast_pass", ":fuse_convolution_pass", + ":legalize_stablehlo_to_vhlo_pass", ":legalize_tf_xla_call_module_to_stablehlo_pass", ":optimize", ":passes_inc_gen", diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/call_xla_module_to_stablehlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/call_xla_module_to_stablehlo.mlir index 795840247cab93..292802ec92e5e1 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/call_xla_module_to_stablehlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/call_xla_module_to_stablehlo.mlir @@ -17,7 +17,8 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr } } -// CHECK: module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {keep_stablehlo_constant = "true"}, tfl.schema_version = 3 : i32} { +// CHECK: module attributes +// CHECK-SAME: tfl.metadata = {{{.*}}keep_stablehlo_constant = "true"{{.*}}} // CHECK-NEXT: func.func @main(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> attributes {tf.entry_function = {inputs = "args_tf_0", outputs = "Identity"}} { // CHECK-NEXT: %0 = stablehlo.custom_call @Sharding(%arg0) {mhlo.sharding = ""} : (tensor<2x3xi32>) -> tensor<2x3xi32> // CHECK-NEXT: %1 = stablehlo.multiply %0, %0 : tensor<2x3xi32> diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tf-fb-tf.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tf-fb-tf.mlir index f50c399deb5579..d0da1f09fa5ae1 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tf-fb-tf.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tf-fb-tf.mlir @@ -8,10 +8,8 @@ func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { } } -// CHECK: module attributes {tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} { -// CHECK-NEXT: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.custom1"}} { +// CHECK: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.custom1"}} { // CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "stablehlo.add", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> // CHECK-NEXT: %1 = "tfl.custom"(%0, %arg0) {custom_code = "stablehlo.subtract", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> // CHECK-NEXT: return %1 : tensor<2xi32> // CHECK-NEXT: } -// CHECK-NEXT:} diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-vhlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-vhlo.mlir new file mode 100644 index 00000000000000..fed52f5473911a --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-vhlo.mlir @@ -0,0 +1,91 @@ +// RUN: odml-to-stablehlo-opt %s --stablehlo-legalize-vhlo -split-input-file | FileCheck %s +// RUN: odml-to-stablehlo-opt --stablehlo-legalize-vhlo %s | odml-to-stablehlo-opt --vhlo-legalize-stablehlo > %t.0 +// RUN: odml-to-stablehlo-opt %s > %t.1 +// RUN: diff %t.0 %t.1 + +// CHECK-LABEL: op_tfl +func.func @op_tfl(%arg0 : tensor) -> (tensor) { + // CHECK: %0 = tfl.add %arg0, %arg0 {fused_activation_function = "NONE"} : tensor + %0 = tfl.add %arg0, %arg0 {fused_activation_function = "NONE"} : tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: op_shlo +func.func @op_shlo(%arg0 : tensor) -> (tensor) { + // CHECK: %0 = "vhlo.add_v1"(%arg0, %arg0) : (tensor, tensor) -> tensor + %0 = stablehlo.add %arg0, %arg0 : tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: mixed_shlo_tfl_shlo +func.func @mixed_shlo_tfl_shlo(%arg0 : tensor) -> (tensor) { + // CHECK: %0 = "vhlo.abs_v1"(%arg0) : (tensor) -> tensor + // CHECK-NEXT: %1 = tfl.add %0, %arg0 {fused_activation_function = "NONE"} : tensor + // CHECK-NEXT: %2 = "vhlo.abs_v1"(%1) : (tensor) -> tensor + %0 = stablehlo.abs %arg0 : tensor + %1 = tfl.add %0, %arg0 {fused_activation_function = "NONE"} : tensor + %2 = stablehlo.abs %1 : tensor + return %2 : tensor +} + +// ----- + +// CHECK-LABEL: mixed_tfl_shlo_tfl +func.func @mixed_tfl_shlo_tfl(%arg0 : tensor) -> (tensor) { + %0 = "tfl.abs"(%arg0) {fused_activation_function = "NONE"} : (tensor) -> tensor + // CHECK: %1 = "vhlo.add_v1"(%0, %arg0) : (tensor, tensor) -> tensor + %1 = stablehlo.add %0, %arg0 : tensor + %2 = "tfl.abs"(%1) {fused_activation_function = "NONE"} : (tensor) -> tensor + return %2 : tensor +} + +// ----- + +// CHECK-LABEL: op_with_region +func.func @op_with_region(%arg0: tensor<1x16x16x320xf32>, %arg1: tensor) -> tensor<1x320xf32> { + // CHECK: %0 = "vhlo.reduce_v1"(%arg0, %arg1) <{{.*}}> ({ + // CHECK-NEXT: ^bb0(%arg2: tensor, %arg3: tensor): + // CHECK-NEXT: %1 = "vhlo.add_v1"(%arg2, %arg3) : (tensor, tensor) -> tensor + // CHECK-NEXT: "vhlo.return_v1"(%1) : (tensor) -> () + // CHECK-NEXT: }) : (tensor<1x16x16x320xf32>, tensor) -> tensor<1x320xf32> + %0 = stablehlo.reduce(%arg0 init: %arg1) applies stablehlo.add across dimensions = [1, 2] : (tensor<1x16x16x320xf32>, tensor) -> tensor<1x320xf32> + return %0 : tensor<1x320xf32> +} + +// ----- + +// CHECK-LABEL: op_with_region_mixed_tfl_shlo_tfl +func.func @op_with_region_mixed_tfl_shlo_tfl(%arg0: tensor<7x5xf32>, %arg1 : tensor<5xf32>) -> tensor<5xf32> { + %0 = "stablehlo.reduce"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor<5xf32>, %arg3: tensor<5xf32>): + // CHECK: %1 = "tfl.abs"(%arg2) {fused_activation_function = "NONE"} : (tensor<5xf32>) -> tensor<5xf32> + // CHECK-NEXT: %2 = "vhlo.add_v1"(%1, %arg2) : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32> + // CHECK-NEXT: %3 = "tfl.abs"(%2) {fused_activation_function = "NONE"} : (tensor<5xf32>) -> tensor<5xf32> + %1 = "tfl.abs"(%arg2) {fused_activation_function = "NONE"} : (tensor<5xf32>) -> tensor<5xf32> + %2 = stablehlo.add %1, %arg2 : tensor<5xf32> + %3 = "tfl.abs"(%2) {fused_activation_function = "NONE"} : (tensor<5xf32>) -> tensor<5xf32> + "stablehlo.return"(%3) : (tensor<5xf32>) -> () + }) {dimensions = array} : (tensor<7x5xf32>, tensor<5xf32>) -> tensor<5xf32> + func.return %0: tensor<5xf32> +} + +// ----- + +// CHECK-LABEL: op_with_region_mixed_shlo_tfl_shlo +func.func @op_with_region_mixed_shlo_tfl_shlo(%arg0: tensor<7x5xf32>, %arg1 : tensor<5xf32>) -> tensor<5xf32> { + %0 = "stablehlo.reduce"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor<5xf32>, %arg3: tensor<5xf32> ): + // CHECK: %1 = "vhlo.abs_v1"(%arg2) : (tensor<5xf32>) -> tensor<5xf32> + // CHECK-NEXT: %2 = tfl.add %1, %arg2 {fused_activation_function = "NONE"} : tensor<5xf32> + // CHECK-NEXT: %3 = "vhlo.abs_v1"(%2) : (tensor<5xf32>) -> tensor<5xf32> + %1 = stablehlo.abs %arg2 : tensor<5xf32> + %2 = tfl.add %1, %arg2 {fused_activation_function = "NONE"} : tensor<5xf32> + %3 = stablehlo.abs %2 : tensor<5xf32> + "stablehlo.return"(%3) : (tensor<5xf32>) -> () + }) {dimensions = array} : (tensor<7x5xf32>, tensor<5xf32>) -> tensor<5xf32> + func.return %0: tensor<5xf32> +} diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir index d9d86dac6782e6..4152a1b785daa6 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir @@ -8,7 +8,7 @@ // CHECK: return %[[VAL_2]] : tensor<1x32x10x32xi32> // CHECK: } func.func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> func.return %0 : tensor<1x32x10x32xi32> } @@ -19,7 +19,7 @@ func.func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> // CHECK: return %[[VAL_2]] : tensor<1x32x10x32xi32> // CHECK: } func.func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> func.return %0 : tensor<1x32x10x32xi32> } @@ -30,7 +30,7 @@ func.func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> // CHECK: return %[[VAL_2]] : tensor // CHECK: } func.func @biasAdd_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor, tensor) -> tensor + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor, tensor) -> tensor func.return %0 : tensor } @@ -68,7 +68,7 @@ func.func @broadcast_add(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1000xf32>) -> ( // CHECK: return %[[VAL_2]] : tensor<1x2xi32> // CHECK: } func.func @broadcast_add_chlo(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> func.return %0 : tensor<1x2xi32> } @@ -79,14 +79,14 @@ func.func @broadcast_add_chlo(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> t // CHECK: return %[[VAL_2]] : tensor<4x4x4x4xi32> // CHECK: } func.func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { - %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> func.return %0 : tensor<4x4x4x4xi32> } // CHECK-LABEL: func @unsupported_broadcast_add // CHECK: chlo.broadcast_add func.func @unsupported_broadcast_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { - %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> func.return %0 : tensor<4x4x4x4xi32> } @@ -122,7 +122,7 @@ func.func @broadcast_div(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1000xf32>) -> ( // CHECK: return %[[VAL_2]] : tensor<1x2xi32> // CHECK: } func.func @broadcast_div_chlo(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> func.return %0 : tensor<1x2xi32> } @@ -159,7 +159,7 @@ func.func @broadcast_shift_left(%arg0: tensor<1xi32>, %arg1: tensor<4xi32>) -> ( // CHECK: return %[[VAL_2]] : tensor // CHECK: } func.func @div_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor) -> tensor + %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor, tensor) -> tensor func.return %0 : tensor } @@ -247,7 +247,7 @@ func.func @broadcast_mul(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1000xf32>) -> ( // CHECK: return %[[VAL_2]] : tensor<1x2xi32> // CHECK: } func.func @broadcast_mul_chlo(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "chlo.broadcast_multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "chlo.broadcast_multiply"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> func.return %0 : tensor<1x2xi32> } @@ -268,7 +268,7 @@ func.func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> { // CHECK: return %[[VAL_2]] : tensor<1x2xi32> // CHECK: } func.func @broadcast_real_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> func.return %0 : tensor<1x2xi32> } @@ -304,7 +304,7 @@ func.func @broadcast_sub(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1000xf32>) -> ( // CHECK: return %[[VAL_2]] : tensor<1x2xi32> // CHECK: } func.func @broadcast_sub_chlo(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "chlo.broadcast_subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "chlo.broadcast_subtract"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> func.return %0 : tensor<1x2xi32> } @@ -341,7 +341,7 @@ func.func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi // CHECK: return %[[VAL_2]] : tensor<2x4xi32> // CHECK: } func.func @broadcast_shift_right(%arg0: tensor<4xi32>, %arg1: tensor<2x4xi32>) -> tensor<2x4xi32> { - %0 = "chlo.broadcast_shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> + %0 = "chlo.broadcast_shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> func.return %0 : tensor<2x4xi32> } @@ -363,7 +363,7 @@ func.func @and(%arg0: tensor<2xi1>, %arg1: tensor<2xi1>) -> tensor<2xi1> { // CHECK: return %[[VAL_2]] : tensor<1x2xi1> // CHECK: } func.func @and_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> { - %0 = "chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> func.return %0 : tensor<1x2xi1> } @@ -396,7 +396,7 @@ func.func @or(%arg0: tensor<2xi1>, %arg1: tensor<2xi1>) -> tensor<2xi1> { // CHECK: return %[[VAL_2]] : tensor<1x2xi1> // CHECK: } func.func @or_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> { - %0 = "chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> func.return %0 : tensor<1x2xi1> } @@ -441,7 +441,7 @@ func.func @bitwise_or_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> t // CHECK: return %[[VAL_2]] : tensor<1x4xi8> // CHECK: } func.func @bitwise_or_broadcast_chlo(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { - %0 = "chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> + %0 = "chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> func.return %0 : tensor<1x4xi8> } @@ -474,7 +474,7 @@ func.func @bitwise_xor(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi // CHECK: return %[[VAL_2]] : tensor<1x4xi8> // CHECK: } func.func @bitwise_xor_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<1xi8>) -> tensor<1x4xi8> + %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>) -> tensor<1x4xi8> %1 = mhlo.xor %0, %arg1 : tensor<1x4xi8> func.return %1 : tensor<1x4xi8> } @@ -509,7 +509,7 @@ func.func @bitwise_and_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> // CHECK: return %[[VAL_2]] : tensor<1x4xi8> // CHECK: } func.func @bitwise_and_broadcast_chlo(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { - %0 = "chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> + %0 = "chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> func.return %0 : tensor<1x4xi8> } @@ -584,16 +584,16 @@ func.func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) %1 = "chlo.broadcast_compare"(%arg0, %0) {comparison_direction = #chlo} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> %2 = mhlo.constant dense<0> : tensor<3xi32> %3 = "chlo.broadcast_compare"(%arg1, %2) {comparison_direction = #chlo} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> - %4 = "chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1> - %5 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + %4 = "chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = array, comparison_direction = #chlo} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1> + %5 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> %6 = "mhlo.abs"(%arg0) : (tensor<2x3xi32>) -> tensor<2x3xi32> %7 = "mhlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32> %8 = mhlo.constant dense<1> : tensor<3xi32> %9 = mhlo.subtract %7, %8 : tensor<3xi32> - %10 = "chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + %10 = "chlo.broadcast_add"(%6, %9) {broadcast_dimensions = array} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> %11 = "mhlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32> %12 = "mhlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32> - %13 = "chlo.broadcast_divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + %13 = "chlo.broadcast_divide"(%11, %12) {broadcast_dimensions = array} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> %14 = "mhlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> func.return %14 : tensor<2x3xi32> } @@ -623,13 +623,13 @@ func.func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x %1 = "mhlo.compare"(%arg0, %0) {comparison_direction = #mhlo} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> %2 = mhlo.constant dense<0> : tensor<2x3xi32> %3 = "chlo.broadcast_compare"(%arg1, %2) {comparison_direction = #chlo} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> - %4 = "chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1> - %5 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %4 = "chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = array, comparison_direction = #chlo} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1> + %5 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> %6 = "mhlo.abs"(%arg0) : (tensor<3xi32>) -> tensor<3xi32> %7 = "mhlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32> %8 = mhlo.constant dense<1> : tensor<2x3xi32> %9 = mhlo.subtract %7, %8 : tensor<2x3xi32> - %10 = "chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %10 = "chlo.broadcast_add"(%6, %9) {broadcast_dimensions = array} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> %11 = "mhlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32> %12 = "mhlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32> %13 = mhlo.divide %11, %12 : tensor<2x3xi32> @@ -660,8 +660,8 @@ func.func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK: return %[[VAL_4]] : tensor<2x3xf16> // CHECK: } func.func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> { - %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> - %1 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> + %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> + %1 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> %2 = "mhlo.floor"(%1) : (tensor<2x3xf16>) -> tensor<2x3xf16> func.return %2 : tensor<2x3xf16> } @@ -707,7 +707,7 @@ func.func @equal_broadcast(%arg0: tensor<1x1xi32>, %arg1: tensor<1x2xi32>) -> te // CHECK: return %[[VAL_2]] : tensor<1x2xi1> // CHECK: } func.func @equal_broadcast_chlo(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = array, comparison_direction = #chlo} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> func.return %0 : tensor<1x2xi1> } @@ -718,7 +718,7 @@ func.func @equal_broadcast_chlo(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> // CHECK: return %[[VAL_2]] : tensor<1x2xi1> // CHECK: } func.func @equal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = array, comparison_direction = #chlo} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> func.return %0 : tensor<1x2xi1> } @@ -736,7 +736,7 @@ func.func @equal_incompatible_shape_broadcastable(%arg0: tensor, %arg1: t // CHECK-LABEL: func @equal_unsupported_compare_type func.func @equal_unsupported_compare_type(%arg0: tensor<1xf32>, %arg1: tensor<1x2xf32>) -> tensor<1x2xi1> { // CHECK: chlo.broadcast_compare - %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, compare_type = #chlo, comparison_direction = #chlo} : (tensor<1xf32>, tensor<1x2xf32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = array, compare_type = #chlo, comparison_direction = #chlo} : (tensor<1xf32>, tensor<1x2xf32>) -> tensor<1x2xi1> func.return %0 : tensor<1x2xi1> } @@ -770,7 +770,7 @@ func.func @notequal_broadcast(%arg0: tensor<1x1xi32>, %arg1: tensor<1x2xi32>) -> // CHECK: return %[[VAL_2]] : tensor<1x2xi1> // CHECK: } func.func @notequal_broadcast_chlo(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = array, comparison_direction = #chlo} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> func.return %0 : tensor<1x2xi1> } @@ -781,7 +781,7 @@ func.func @notequal_broadcast_chlo(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) // CHECK: return %[[VAL_2]] : tensor<1x2xi1> // CHECK: } func.func @notequal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = array, comparison_direction = #chlo} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> func.return %0 : tensor<1x2xi1> } @@ -826,7 +826,7 @@ func.func @broadcast_greater(%arg0: tensor<1x1xi32>, %arg1: tensor<1x2xi32>) -> // CHECK: return %[[VAL_2]] : tensor<1x2xi1> // CHECK: } func.func @broadcast_greater_chlo(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = array, comparison_direction = #chlo} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> func.return %0 : tensor<1x2xi1> } @@ -867,7 +867,7 @@ func.func @broadcast_greater_equal(%arg0: tensor<1x1xi32>, %arg1: tensor<1x2xi32 // CHECK: return %[[VAL_2]] : tensor<1x2xi1> // CHECK: } func.func @broadcast_greater_equal_chlo(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = array, comparison_direction = #chlo} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> func.return %0 : tensor<1x2xi1> } @@ -901,7 +901,7 @@ func.func @broadcast_less(%arg0: tensor<1x1xi32>, %arg1: tensor<1x2xi32>) -> ten // CHECK: return %[[VAL_2]] : tensor<1x2xi1> // CHECK: } func.func @broadcast_less_chlo(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = array, comparison_direction = #chlo} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> func.return %0 : tensor<1x2xi1> } @@ -935,7 +935,7 @@ func.func @broadcast_less_equal(%arg0: tensor<1x1xi32>, %arg1: tensor<1x2xi32>) // CHECK: return %[[VAL_2]] : tensor<1x2xi1> // CHECK: } func.func @broadcast_less_equal_chlo(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = array, comparison_direction = #chlo} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> func.return %0 : tensor<1x2xi1> } @@ -980,7 +980,7 @@ func.func @const() -> tensor<2xi32> { // CHECK: } func.func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { %0 = mhlo.constant dense<0> : tensor - %1 = "chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor<1xi32> + %1 = "chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = array} : (tensor, tensor<1xi32>) -> tensor<1xi32> func.return %1 : tensor<1xi32> } @@ -992,7 +992,7 @@ func.func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK: } func.func @relu_unranked(%arg0: tensor) -> tensor { %0 = mhlo.constant dense<0> : tensor - %1 = "chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor + %1 = "chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = array} : (tensor, tensor) -> tensor func.return %1 : tensor } @@ -1007,8 +1007,8 @@ func.func @relu_unranked(%arg0: tensor) -> tensor { func.func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> { %0 = mhlo.constant dense<0> : tensor %1 = mhlo.constant dense<6> : tensor - %2 = "chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor) -> tensor<1xi32> - %3 = "chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor) -> tensor<1xi32> + %2 = "chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = array} : (tensor<1xi32>, tensor) -> tensor<1xi32> + %3 = "chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = array} : (tensor<1xi32>, tensor) -> tensor<1xi32> func.return %3 : tensor<1xi32> } @@ -1023,8 +1023,8 @@ func.func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> { func.func @relu6_unranked(%arg0: tensor) -> tensor { %0 = mhlo.constant dense<0> : tensor %1 = mhlo.constant dense<6> : tensor - %2 = "chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor - %3 = "chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor + %2 = "chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = array} : (tensor, tensor) -> tensor + %3 = "chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = array} : (tensor, tensor) -> tensor func.return %3 : tensor } @@ -1039,7 +1039,7 @@ func.func @relu6_unranked(%arg0: tensor) -> tensor { // CHECK: } func.func @relu_grad(%arg0: tensor<4x8xf32>, %arg1: tensor) -> tensor<4x8xf32> { %0 = mhlo.constant dense<0.000000e+00> : tensor - %1 = "chlo.broadcast_compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = #chlo} : (tensor, tensor) -> tensor + %1 = "chlo.broadcast_compare"(%arg1, %0) {broadcast_dimensions = array, comparison_direction = #chlo} : (tensor, tensor) -> tensor %2 = mhlo.constant dense<0.000000e+00> : tensor<4x8xf32> %3 = "mhlo.select"(%1, %arg0, %2) : (tensor, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> func.return %3 : tensor<4x8xf32> @@ -2140,6 +2140,8 @@ func.func @convert_dot_general_dynamic_contracting_dim(%arg0: tensor<4x4x?xf32>, func.return %0 : tensor<4x4x256xf32> } + + // CHECK-LABEL: func.func @convert_conv1d( // CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32x256xbf16>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> { @@ -2204,7 +2206,29 @@ func.func @convert_conv1d_dynamic_batch(%arg0: tensor, %arg1: ten func.return %0 : tensor } - +// CHECK-LABEL: convert_dynamic_1d_group_conv +func.func private @convert_dynamic_1d_group_conv(%arg1: tensor, %arg2: tensor<768x48x128xf32>) -> (tensor) { + %0 = mhlo.convolution(%arg1, %arg2) + dim_numbers = [b, f, 0]x[o, i, 0]->[b, f, 0], + window = {pad = [[64, 64]]} + {batch_group_count = 1 : i64, feature_group_count = 16 : i64} + : (tensor, tensor<768x48x128xf32>) -> tensor + return %0 : tensor +// CHECK: %cst = arith.constant dense<[-9223372036854775808, 768, 2, 1]> : tensor<4xi64> +// CHECK: %0 = "tf.Reshape"(%arg0, %cst) : (tensor, tensor<4xi64>) -> tensor +// CHECK: %cst_0 = "tf.Const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi64>}> : () -> tensor<4xi64> +// CHECK: %1 = "tf.Transpose"(%0, %cst_0) : (tensor, tensor<4xi64>) -> tensor +// CHECK: %cst_1 = arith.constant dense<[768, 48, 128, 1]> : tensor<4xi64> +// CHECK: %2 = "tf.Reshape"(%arg1, %cst_1) : (tensor<768x48x128xf32>, tensor<4xi64>) -> tensor<768x48x128x1xf32> +// CHECK: %cst_2 = "tf.Const"() <{value = dense<[2, 3, 1, 0]> : tensor<4xi64>}> : () -> tensor<4xi64> +// CHECK: %3 = "tf.Transpose"(%2, %cst_2) : (tensor<768x48x128x1xf32>, tensor<4xi64>) -> tensor<128x1x48x768xf32> +// CHECK: %4 = "tf.Conv2D"(%1, %3) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [0, 0, 64, 64, 0, 0, 0, 0], padding = "EXPLICIT", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor, tensor<128x1x48x768xf32>) -> tensor +// CHECK: %cst_3 = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64> +// CHECK: %5 = "tf.Transpose"(%4, %cst_3) : (tensor, tensor<4xi64>) -> tensor +// CHECK: %cst_4 = arith.constant dense<[-9223372036854775808, 768, 3]> : tensor<3xi64> +// CHECK: %6 = "tf.Reshape"(%5, %cst_4) : (tensor, tensor<3xi64>) -> tensor +// CHECK: return %6 : tensor +} // CHECK-LABEL: func.func @convert_conv1d_no_lhs_dil_rhs_dil_precision_conf( // CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32x256xbf16>, @@ -2318,13 +2342,22 @@ func.func @no_convert_conv1d_dynamic(%arg0: tensor<16x?x256xbf16>, %arg1: tensor func.return %0 : tensor<16x?x256xbf16> } -// CHECK-LABEL: func.func @no_convert_conv1d_feature_group_gt_1( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32x256xbf16>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x128x128xbf16>) -> tensor<16x32x128xbf16> { -// CHECK: %[[VAL_2:.*]] = mhlo.convolution(%[[VAL_0]], %[[VAL_1]]) dim_numbers = [b, 0, f]x[0, i, o]->[b, 0, f], window = {stride = [1], pad = {{\[\[}}0, 0]], lhs_dilate = [1], rhs_dilate = [1]} {batch_group_count = 1 : i64, feature_group_count = 2 : i64, precision_config = [#mhlo, #mhlo]} : (tensor<16x32x256xbf16>, tensor<1x128x128xbf16>) -> tensor<16x32x128xbf16> -// CHECK: return %[[VAL_2]] : tensor<16x32x128xbf16> -// CHECK: } -func.func @no_convert_conv1d_feature_group_gt_1(%arg0: tensor<16x32x256xbf16>, %arg1: tensor<1x128x128xbf16>) -> tensor<16x32x128xbf16> { +// CHECK-LABEL: func.func @convert_conv1d_feature_group_gt_1( +// CHECK: %cst = arith.constant dense<[16, 32, 256, 1]> : tensor<4xi64> +// CHECK: %0 = "tf.Reshape"(%arg0, %cst) : (tensor<16x32x256xbf16>, tensor<4xi64>) -> tensor<16x32x256x1xbf16> +// CHECK: %cst_0 = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64> +// CHECK: %1 = "tf.Transpose"(%0, %cst_0) : (tensor<16x32x256x1xbf16>, tensor<4xi64>) -> tensor<16x32x1x256xbf16> +// CHECK: %cst_1 = arith.constant dense<[1, 128, 128, 1]> : tensor<4xi64> +// CHECK: %2 = "tf.Reshape"(%arg1, %cst_1) : (tensor<1x128x128xbf16>, tensor<4xi64>) -> tensor<1x128x128x1xbf16> +// CHECK: %cst_2 = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64> +// CHECK: %3 = "tf.Transpose"(%2, %cst_2) : (tensor<1x128x128x1xbf16>, tensor<4xi64>) -> tensor<1x1x128x128xbf16> +// CHECK: %4 = "tf.Conv2D"(%1, %3) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<16x32x1x256xbf16>, tensor<1x1x128x128xbf16>) -> tensor<16x32x1x128xbf16> +// CHECK: %cst_3 = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64> +// CHECK: %5 = "tf.Transpose"(%4, %cst_3) : (tensor<16x32x1x128xbf16>, tensor<4xi64>) -> tensor<16x32x128x1xbf16> +// CHECK: %cst_4 = arith.constant dense<[16, 32, 128]> : tensor<3xi64> +// CHECK: %6 = "tf.Reshape"(%5, %cst_4) : (tensor<16x32x128x1xbf16>, tensor<3xi64>) -> tensor<16x32x128xbf16> +// CHECK: return %6 : tensor<16x32x128xbf16> +func.func @convert_conv1d_feature_group_gt_1(%arg0: tensor<16x32x256xbf16>, %arg1: tensor<1x128x128xbf16>) -> tensor<16x32x128xbf16> { %0 = "mhlo.convolution"(%arg0, %arg1) { batch_group_count = 1 : i64, dimension_numbers = #mhlo.conv<[b, 0, f]x[0, i, o]->[b, 0, f]>, diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/tf-tfl-translate-serialize-stablehlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/tf-tfl-translate-serialize-stablehlo.mlir index 4e12ffd931c5f2..27e22cb524b8af 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/tf-tfl-translate-serialize-stablehlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/tf-tfl-translate-serialize-stablehlo.mlir @@ -11,12 +11,12 @@ func.func @tfInplaceUpdate(%arg0: tensor<2x1x2xf32>) -> tensor<2x1x2xf32> { } } -//CHECK: module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {keep_stablehlo_constant = "true"}, tfl.schema_version = 3 : i32} { +//CHECK: module attributes +//CHECK-SAME: keep_stablehlo_constant = "true" //CHECK-NEXT: func.func @main(%arg0: tensor<2x1x2xf32>) -> tensor<2x1x2xf32> attributes {tf.entry_function = {inputs = "arg0", outputs = "stablehlo.dynamic_update_slice"}} { //CHECK-DAG: %0 = stablehlo.constant dense<2.000000e+00> : tensor<1x1x2xf32> //CHECK-DAG: %1 = stablehlo.constant dense<1> : tensor //CHECK-DAG: %2 = stablehlo.constant dense<0> : tensor //CHECK-NEXT: %3 = stablehlo.dynamic_update_slice %arg0, %0, %1, %2, %2 : (tensor<2x1x2xf32>, tensor<1x1x2xf32>, tensor, tensor, tensor) -> tensor<2x1x2xf32> //CHECK-NEXT: return %3 : tensor<2x1x2xf32> -//CHECK-NEXT: } -//CHECK-NEXT:} \ No newline at end of file +//CHECK-NEXT: } \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir index 3cf031d6a5dbad..6f96d6637821d5 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir @@ -1,12 +1,19 @@ // RUN: odml-to-stablehlo-opt --uniform-quantized-stablehlo-to-tfl \ // RUN: --split-input-file --verify-diagnostics %s | FileCheck %s -// CHECK-LABEL: uniform_quantize_op +// ============================================================================ +// The following functions tests example quantization patterns outputted from +// JAX Quantizer. JAX Quantizer should output integer types, which are +// composed into `UniformQuantized{|PerAxis}Type` via +// `compose_uniform_quantized_type_pass.cc`. +// ============================================================================ + func.func @uniform_quantize_op(%arg: tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> { %0 = stablehlo.uniform_quantize %arg : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> return %0 : tensor<2x2x!quant.uniform> } -// CHECK: %[[QUANT:.*]] = "tfl.quantize"({{.*}}) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> +// CHECK-LABEL: uniform_quantize_op +// CHECK: %[[QUANT:.+]] = "tfl.quantize"({{.*}}) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> // CHECK: return %[[QUANT]] // ----- @@ -14,11 +21,11 @@ func.func @uniform_quantize_op(%arg: tensor<2x2xf32>) -> tensor<2x2x!quant.unifo // Tests that the pattern doesn't match when the input tensor's type is a // quantized type. -// CHECK-LABEL: uniform_quantize_op_quantized_input func.func @uniform_quantize_op_quantized_input(%arg: tensor<2x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> { %0 = stablehlo.uniform_quantize %arg : (tensor<2x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> return %0 : tensor<2x2x!quant.uniform> } +// CHECK-LABEL: uniform_quantize_op_quantized_input // CHECK: stablehlo.uniform_quantize // CHECK-NOT: tfl.quantize @@ -28,11 +35,11 @@ func.func @uniform_quantize_op_quantized_input(%arg: tensor<2x2x!quant.uniform) -> tensor<2x2x!quant.uniform> { %0 = stablehlo.uniform_quantize %arg : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> return %0 : tensor<2x2x!quant.uniform> } +// CHECK-LABEL: uniform_quantize_op_uint16_output // CHECK: stablehlo.uniform_quantize // CHECK-NOT: tfl.quantize @@ -42,22 +49,22 @@ func.func @uniform_quantize_op_uint16_output(%arg: tensor<2x2xf32>) -> tensor<2x // is i32. i32 storage type for quantized type is not compatible with // `tfl.quantize`. -// CHECK-LABEL: uniform_quantize_op_i32_output func.func @uniform_quantize_op_i32_output(%arg: tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> { %0 = stablehlo.uniform_quantize %arg : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> return %0 : tensor<2x2x!quant.uniform> } +// CHECK-LABEL: uniform_quantize_op_i32_output // CHECK: stablehlo.uniform_quantize // CHECK-NOT: tfl.quantize // ----- -// CHECK-LABEL: uniform_dequantize_op func.func @uniform_dequantize_op(%arg: tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> { %0 = stablehlo.uniform_dequantize %arg : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } -// CHECK: %[[DEQUANT:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> +// CHECK-LABEL: uniform_dequantize_op +// CHECK: %[[DEQUANT:.+]] = "tfl.dequantize"({{.*}}) : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> // CHECK: return %[[DEQUANT]] // ----- @@ -66,11 +73,11 @@ func.func @uniform_dequantize_op(%arg: tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> { %0 = stablehlo.uniform_dequantize %arg : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } +// CHECK-LABEL: uniform_dequantize_op_ui16_storage_input // CHECK: stablehlo.uniform_dequantize // CHECK-NOT: tfl.dequantize @@ -80,11 +87,11 @@ func.func @uniform_dequantize_op_ui16_storage_input(%arg: tensor<2x2x!quant.unif // storage type is i32. i32 storage type is not compatible with // `tfl.dequantize`. -// CHECK-LABEL: uniform_dequantize_op_i32_storage_input func.func @uniform_dequantize_op_i32_storage_input(%arg: tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> { %0 = stablehlo.uniform_dequantize %arg : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } +// CHECK-LABEL: uniform_dequantize_op_i32_storage_input // CHECK: stablehlo.uniform_dequantize // CHECK-NOT: tfl.dequantize @@ -94,109 +101,104 @@ func.func @uniform_dequantize_op_i32_storage_input(%arg: tensor<2x2x!quant.unifo // storage type is i32. i32 storage type is not compatible with // `tfl.dequantize`. -// CHECK-LABEL: uniform_dequantize_op_return_f64 func.func @uniform_dequantize_op_return_f64(%arg: tensor<2x2x!quant.uniform>) -> tensor<2x2xf64> { %0 = stablehlo.uniform_dequantize %arg : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf64> return %0 : tensor<2x2xf64> } +// CHECK-LABEL: uniform_dequantize_op_return_f64 // CHECK: stablehlo.uniform_dequantize // CHECK-NOT: tfl.dequantize // ----- -// CHECK-LABEL: convolution_upstream_full_integer -func.func @convolution_upstream_full_integer(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> { +func.func @convolution_upstream_same_padding_srq(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> { %0 = stablehlo.constant() {value = dense<3> : tensor<3x3x4x2xi8>} : () -> tensor<3x3x4x2x!quant.uniform> %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4x!quant.uniform>, tensor<3x3x4x2x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> return %1 : tensor<1x3x3x2x!quant.uniform> } -// CHECK-SAME: %[[ARG:.*]]: tensor<1x3x3x4x!quant.uniform> -// CHECK-DAG: %[[CONST_0:.*]] = "tfl.pseudo_const"(){{.*}}dense<{{\[\[0, 0\], \[1, 1\], \[1, 1\], \[0, 0\]\]}}> : tensor<4x2xi32> // Note that the quantized dimension is 0, and the shape has been transposed // to (2, 3, 3, 4). -// CHECK-DAG: %[[QCONST_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, value = dense<3> : tensor<2x3x3x4xi8>} : () -> tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>> -// CHECK-DAG: %[[QCONST_1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> -// Explicit tfl.pad op to reflect explicit padding attribute. -// CHECK: %[[PAD:.*]] = "tfl.pad"(%[[ARG]], %[[CONST_0]]) : (tensor<1x3x3x4x!quant.uniform>, tensor<4x2xi32>) -> tensor<1x5x5x4x!quant.uniform> -// CHECK: %[[CONV2D:.*]] = "tfl.conv_2d"(%[[PAD]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x5x5x4x!quant.uniform>, tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, tensor<2x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> +// CHECK-LABEL: convolution_upstream_same_padding_srq +// CHECK-SAME: %[[ARG:.+]]: tensor<1x3x3x4x!quant.uniform> +// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, value = dense<3> : tensor<2x3x3x4xi8>} : () -> tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>> +// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> +// CHECK: %[[CONV2D:.+]] = "tfl.conv_2d"(%[[ARG]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x3x3x4x!quant.uniform>, tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, tensor<2x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> // CHECK: return %[[CONV2D]] : tensor<1x3x3x2x!quant.uniform> // ----- -// CHECK-LABEL: convolution_upstream_full_integer_non_const_filter -func.func @convolution_upstream_full_integer_non_const_filter(%arg0: tensor<1x3x3x4x!quant.uniform>, %arg1: tensor<3x3x4x2x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> { +func.func @convolution_upstream_srq_non_const_filter(%arg0: tensor<1x3x3x4x!quant.uniform>, %arg1: tensor<3x3x4x2x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> { %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4x!quant.uniform>, tensor<3x3x4x2x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> return %0 : tensor<1x3x3x2x!quant.uniform> } -// CHECK-SAME: %[[ARG:.*]]: tensor<1x3x3x4x!quant.uniform> - // Confirm that the `stablehlo.convolution` is not converted to `tfl.conv_2d`. +// CHECK-LABEL: convolution_upstream_srq_non_const_filter +// CHECK-SAME: %[[ARG:.+]]: tensor<1x3x3x4x!quant.uniform> // CHECK: stablehlo.convolution // CHECK-NOT: tfl.conv_2d // ----- -// Test that if the window padding contains values of 0, tfl.pad op is not +// Tests that if the window padding contains values of 0, tfl.pad op is not // created and the `padding` attribute is set as "VALID". -// CHECK-LABEL: convolution_upstream_full_integer_valid_padding -func.func @convolution_upstream_full_integer_valid_padding(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> { +func.func @convolution_upstream_srq_valid_padding(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> { %0 = stablehlo.constant() {value = dense<3> : tensor<3x3x4x2xi8>} : () -> tensor<3x3x4x2x!quant.uniform> %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 0], [0, 0]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4x!quant.uniform>, tensor<3x3x4x2x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> return %1 : tensor<1x1x1x2x!quant.uniform> } -// CHECK-SAME: %[[ARG:.*]]: tensor<1x3x3x4x!quant.uniform> -// CHECK: %[[QCONST_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, value = dense<3> : tensor<2x3x3x4xi8>} : () -> tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>> -// CHECK: %[[QCONST_1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> +// CHECK-LABEL: convolution_upstream_srq_valid_padding +// CHECK-SAME: %[[ARG:.+]]: tensor<1x3x3x4x!quant.uniform> +// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, value = dense<3> : tensor<2x3x3x4xi8>} : () -> tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>> +// CHECK: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> // CHECK-NOT: tfl.pad -// CHECK: %[[CONV2D:.*]] = "tfl.conv_2d"(%[[ARG]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x3x3x4x!quant.uniform>, tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, tensor<2x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> +// CHECK: %[[CONV2D:.+]] = "tfl.conv_2d"(%[[ARG]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x3x3x4x!quant.uniform>, tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, tensor<2x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> // CHECK: return %[[CONV2D]] : tensor<1x1x1x2x!quant.uniform> // ----- -// Test that if the window padding value is missing, tfl.pad op is not +// Tests that if the window padding value is missing, tfl.pad op is not // created and the `padding` attribute is set as "VALID". -// CHECK-LABEL: convolution_upstream_full_integer_valid_padding -func.func @convolution_upstream_full_integer_valid_padding(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> { +func.func @convolution_upstream_srq_valid_padding(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> { %0 = stablehlo.constant() {value = dense<3> : tensor<3x3x4x2xi8>} : () -> tensor<3x3x4x2x!quant.uniform> // The `window` attribute is empty. %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4x!quant.uniform>, tensor<3x3x4x2x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> return %1 : tensor<1x1x1x2x!quant.uniform> } -// CHECK-SAME: %[[ARG:.*]]: tensor<1x3x3x4x!quant.uniform> -// CHECK: %[[QCONST_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, value = dense<3> : tensor<2x3x3x4xi8>} : () -> tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>> -// CHECK: %[[QCONST_1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> -// CHECK: %[[CONV2D:.*]] = "tfl.conv_2d"(%[[ARG]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x3x3x4x!quant.uniform>, tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, tensor<2x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> +// CHECK-LABEL: convolution_upstream_srq_valid_padding +// CHECK-SAME: %[[ARG:.+]]: tensor<1x3x3x4x!quant.uniform> +// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, value = dense<3> : tensor<2x3x3x4xi8>} : () -> tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>> +// CHECK: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> +// CHECK: %[[CONV2D:.+]] = "tfl.conv_2d"(%[[ARG]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x3x3x4x!quant.uniform>, tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, tensor<2x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> // CHECK: return %[[CONV2D]] : tensor<1x1x1x2x!quant.uniform> // ----- -// Test that if the window stride value is explicitly set, the attribute +// Tests that if the window stride value is explicitly set, the attribute // value is transferred to tfl.conv_2d's stridw_h and stride_w values. -// CHECK-LABEL: convolution_upstream_full_integer_strides -func.func @convolution_upstream_full_integer_strides(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x3x2x2x!quant.uniform> { +func.func @convolution_upstream_srq_strides(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x3x2x2x!quant.uniform> { %0 = stablehlo.constant() {value = dense<3> : tensor<3x3x4x2xi8>} : () -> tensor<3x3x4x2x!quant.uniform> // The stride value is explicitly set to [1, 2]. %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 2], pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4x!quant.uniform>, tensor<3x3x4x2x!quant.uniform>) -> tensor<1x3x2x2x!quant.uniform> return %1 : tensor<1x3x2x2x!quant.uniform> } -// CHECK-SAME: %[[ARG:.*]]: tensor<1x3x3x4x!quant.uniform> -// CHECK-DAG: %[[CONST:.*]] = "tfl.pseudo_const"(){{.*}}dense<{{\[\[0, 0\], \[1, 1\], \[1, 1\], \[0, 0\]\]}}> : tensor<4x2xi32> -// CHECK-DAG: %[[QCONST_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, value = dense<3> : tensor<2x3x3x4xi8>} : () -> tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>> -// CHECK-DAG: %[[QCONST_1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> -// CHECK: %[[PAD:.*]] = "tfl.pad"(%arg0, %[[CONST]]) : (tensor<1x3x3x4x!quant.uniform>, tensor<4x2xi32>) -> tensor<1x5x5x4x!quant.uniform> +// CHECK-LABEL: convolution_upstream_srq_strides +// CHECK-SAME: %[[ARG:.+]]: tensor<1x3x3x4x!quant.uniform> +// CHECK-DAG: %[[CONST:.+]] = "tfl.pseudo_const"(){{.*}}dense<{{\[\[0, 0\], \[1, 1\], \[1, 1\], \[0, 0\]\]}}> : tensor<4x2xi32> +// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, value = dense<3> : tensor<2x3x3x4xi8>} : () -> tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>> +// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> +// CHECK: %[[PAD:.+]] = "tfl.pad"(%arg0, %[[CONST]]) : (tensor<1x3x3x4x!quant.uniform>, tensor<4x2xi32>) -> tensor<1x5x5x4x!quant.uniform> // Tests that the stride_w is set to 2. -// CHECK: %[[CONV2D:.*]] = "tfl.conv_2d"(%[[PAD]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 2 : i32} : (tensor<1x5x5x4x!quant.uniform>, tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, tensor<2x!quant.uniform>) -> tensor<1x3x2x2x!quant.uniform> +// CHECK: %[[CONV2D:.+]] = "tfl.conv_2d"(%[[PAD]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 2 : i32} : (tensor<1x5x5x4x!quant.uniform>, tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, tensor<2x!quant.uniform>) -> tensor<1x3x2x2x!quant.uniform> // CHECK: return %[[CONV2D]] : tensor<1x3x2x2x!quant.uniform> // ----- -// Test full integer quantized dot_general with asymmetric quantized input. +// Tests static range quantized dot_general with asymmetric quantized input. -// CHECK-LABEL: dot_general_upstream_full_integer_asym_input -func.func @dot_general_upstream_full_integer_asym_input(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +func.func @dot_general_upstream_srq_asym_input(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) { dot_dimension_numbers = #stablehlo.dot< @@ -208,16 +210,16 @@ func.func @dot_general_upstream_full_integer_asym_input(%arg0: tensor<1x2x3x4x!q } : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> return %1 : tensor<1x2x3x5x!quant.uniform> } -// CHECK-SAME: %[[ARG:.*]]: tensor<1x2x3x4x!quant.uniform> -// CHECK: %[[QCONST_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<1x2x4x5x!quant.uniform>, value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> -// CHECK: %[[BMM:.*]] = "tfl.batch_matmul"(%[[ARG]], %[[QCONST_0]]) {adj_x = false, adj_y = false} : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> +// CHECK-LABEL: dot_general_upstream_srq_asym_input +// CHECK-SAME: %[[ARG:.+]]: tensor<1x2x3x4x!quant.uniform> +// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<1x2x4x5x!quant.uniform>, value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> +// CHECK: %[[BMM:.+]] = "tfl.batch_matmul"(%[[ARG]], %[[QCONST_0]]) {adj_x = false, adj_y = false} : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> // ----- -// Test full integer quantized dot_general with symmetric quantized input. +// Tests static range quantized dot_general with symmetric quantized input. -// CHECK-LABEL: dot_general_upstream_full_integer_sym_input -func.func @dot_general_upstream_full_integer_sym_input(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +func.func @dot_general_upstream_srq_sym_input(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) { dot_dimension_numbers = #stablehlo.dot< @@ -230,41 +232,16 @@ func.func @dot_general_upstream_full_integer_sym_input(%arg0: tensor<1x2x3x4x!qu } : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> return %1 : tensor<1x2x3x5x!quant.uniform> } - -// CHECK-SAME: %[[ARG:.*]]: tensor<1x2x3x4x!quant.uniform> -// CHECK: %[[QCONST_0:.*]] = "tfl.pseudo_qconst"() +// CHECK-LABEL: dot_general_upstream_srq_sym_input +// CHECK-SAME: %[[ARG:.+]]: tensor<1x2x3x4x!quant.uniform> +// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() // CHECK: "tfl.batch_matmul"(%[[ARG]], %[[QCONST_0]]) {adj_x = false, adj_y = false} // ----- -// Tests that the pattern does not match when the output tensor's storage -// type is i32. Currently we support qi8, qi8 -> qi8 only for GEMM ops that -// are quantized upstream. Other cases should be handled by regular quantized -// stablehlo.dot_general case. - -// CHECK-LABEL: dot_general_upstream_full_integer_i32_output -func.func @dot_general_upstream_full_integer_i32_output(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { - %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> - %1 = "stablehlo.dot_general"(%arg0, %0) { - dot_dimension_numbers = #stablehlo.dot< - lhs_batching_dimensions = [0, 1], - rhs_batching_dimensions = [0, 1], - lhs_contracting_dimensions = [3], - rhs_contracting_dimensions = [2] - >, - precision_config = [#stablehlo, #stablehlo] - } : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> - return %1 : tensor<1x2x3x5x!quant.uniform> -} -// CHECK: stablehlo.dot_general -// CHECK-NOT: tfl.quantize - -// ----- - -// Test full integer quantized dot_general with activation as RHS +// Tests static range quantized dot_general with activation as RHS -// CHECK-LABEL: dot_general_upstream_full_integer_activation_rhs -func.func @dot_general_upstream_full_integer_activation_rhs(%arg0: tensor<1x2x3x4x!quant.uniform>, %arg1: tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +func.func @dot_general_upstream_srq_activation_rhs(%arg0: tensor<1x2x3x4x!quant.uniform>, %arg1: tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = "stablehlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #stablehlo.dot< lhs_batching_dimensions = [0, 1], @@ -276,14 +253,15 @@ func.func @dot_general_upstream_full_integer_activation_rhs(%arg0: tensor<1x2x3x } : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> return %0 : tensor<1x2x3x5x!quant.uniform> } +// CHECK-LABEL: dot_general_upstream_srq_activation_rhs // CHECK: "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> // ----- -// Test full integer quantized dot_general with adj_x +// Tests static range quantized dot_general with adj_x -// CHECK-LABEL: dot_general_upstream_full_integer_adj_x -func.func @dot_general_upstream_full_integer_adj_x(%arg0: tensor<1x2x4x3x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_srq_adj_x +func.func @dot_general_upstream_srq_adj_x(%arg0: tensor<1x2x4x3x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) { dot_dimension_numbers = #stablehlo.dot< @@ -297,17 +275,15 @@ func.func @dot_general_upstream_full_integer_adj_x(%arg0: tensor<1x2x4x3x!quant. } : (tensor<1x2x4x3x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> return %1 : tensor<1x2x3x5x!quant.uniform> } - -// CHECK-SAME: %[[ARG:.*]]: tensor<1x2x4x3x!quant.uniform> -// CHECK: %[[QCONST_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<1x2x4x5x!quant.uniform>, value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> +// CHECK-SAME: %[[ARG:.+]]: tensor<1x2x4x3x!quant.uniform> +// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<1x2x4x5x!quant.uniform>, value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> // CHECK: "tfl.batch_matmul"(%[[ARG]], %[[QCONST_0]]) {adj_x = true, adj_y = false} // ----- -// Test full integer quantized dot_general with adj_y +// Tests static range quantized dot_general with adj_y -// CHECK-LABEL: dot_general_upstream_full_integer_adj_y -func.func @dot_general_upstream_full_integer_adj_y(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +func.func @dot_general_upstream_srq_adj_y(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x5x4xi8>} : () -> tensor<1x2x5x4x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) { dot_dimension_numbers = #stablehlo.dot< @@ -321,17 +297,16 @@ func.func @dot_general_upstream_full_integer_adj_y(%arg0: tensor<1x2x3x4x!quant. } : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x5x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> return %1 : tensor<1x2x3x5x!quant.uniform> } - -// CHECK-SAME: %[[ARG:.*]]: tensor<1x2x3x4x!quant.uniform> -// CHECK: %[[QCONST_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<1x2x5x4x!quant.uniform>, value = dense<1> : tensor<1x2x5x4xi8>} : () -> tensor<1x2x5x4x!quant.uniform> +// CHECK-LABEL: dot_general_upstream_srq_adj_y +// CHECK-SAME: %[[ARG:.+]]: tensor<1x2x3x4x!quant.uniform> +// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<1x2x5x4x!quant.uniform>, value = dense<1> : tensor<1x2x5x4xi8>} : () -> tensor<1x2x5x4x!quant.uniform> // CHECK: "tfl.batch_matmul"(%[[ARG]], %[[QCONST_0]]) {adj_x = false, adj_y = true} // ----- -// Test full integer quantized dot_general with wrong batch dims +// Tests static range quantized dot_general with wrong batch dims -// CHECK-LABEL: dot_general_upstream_full_integer_too_many_batches -func.func @dot_general_upstream_full_integer_too_many_batches(%arg0: tensor<1x1x1x2x3x4x!quant.uniform>) -> tensor<1x1x1x2x3x5x!quant.uniform> { +func.func @dot_general_upstream_srq_too_many_batches(%arg0: tensor<1x1x1x2x3x4x!quant.uniform>) -> tensor<1x1x1x2x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x1x1x2x4x5xi8>} : () -> tensor<1x1x1x2x4x5x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) { dot_dimension_numbers = #stablehlo.dot< @@ -345,15 +320,15 @@ func.func @dot_general_upstream_full_integer_too_many_batches(%arg0: tensor<1x1x return %1 : tensor<1x1x1x2x3x5x!quant.uniform> } // Only support size(batching_dimensions) <= 3 +// CHECK-LABEL: dot_general_upstream_srq_too_many_batches // CHECK: stablehlo.dot_general // CHECK-NOT: tfl.batch_matmul // ----- -// Test full integer quantized dot_general with too many contracting dimension +// Tests static range quantized dot_general with too many contracting dimension -// CHECK-LABEL: dot_general_upstream_full_integer_too_many_contractions -func.func @dot_general_upstream_full_integer_too_many_contractions(%arg0: tensor<1x2x3x4x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +func.func @dot_general_upstream_srq_too_many_contractions(%arg0: tensor<1x2x3x4x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x4x5xi8>} : () -> tensor<1x2x4x4x5x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) { dot_dimension_numbers = #stablehlo.dot< @@ -367,15 +342,15 @@ func.func @dot_general_upstream_full_integer_too_many_contractions(%arg0: tensor return %1 : tensor<1x2x3x5x!quant.uniform> } // Only support size(contracting_dimensions) == 1 +// CHECK-LABEL: dot_general_upstream_srq_too_many_contractions // CHECK: stablehlo.dot_general // CHECK-NOT: tfl.batch_matmul // ----- -// Test full integer quantized dot_general with unsupported contracting dim +// Tests static range quantized dot_general with unsupported contracting dim -// CHECK-LABEL: dot_general_upstream_full_integer_wrong_contracting -func.func @dot_general_upstream_full_integer_wrong_contracting(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x4x3x5x!quant.uniform> { +func.func @dot_general_upstream_srq_wrong_contracting(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x4x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) { dot_dimension_numbers = #stablehlo.dot< @@ -388,17 +363,17 @@ func.func @dot_general_upstream_full_integer_wrong_contracting(%arg0: tensor<1x2 } : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x4x3x5x!quant.uniform> return %1 : tensor<1x4x3x5x!quant.uniform> } - // Contracting dimension must be the last two dimension +// CHECK-LABEL: dot_general_upstream_srq_wrong_contracting // CHECK: stablehlo.dot_general // CHECK-NOT: tfl.batch_matmul // ----- -// Test full integer quantized dot_general with float operands +// Tests static range quantized dot_general with float operands -// CHECK-LABEL: dot_general_upstream_full_integer_float_operands -func.func @dot_general_upstream_full_integer_float_operands(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1x2x4x5xf32>) -> tensor<1x2x3x5xf32> { +// CHECK-LABEL: dot_general_upstream_srq_float_operands +func.func @dot_general_upstream_srq_float_operands(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1x2x4x5xf32>) -> tensor<1x2x3x5xf32> { %0 = "stablehlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #stablehlo.dot< lhs_batching_dimensions = [0, 1], @@ -416,44 +391,44 @@ func.func @dot_general_upstream_full_integer_float_operands(%arg0: tensor<1x2x3x // ----- -// Test full integer quantized dot_general with asymmetric weight (rhs). +// Tests static range quantized dot_general with asymmetric weight (rhs). -// CHECK-LABEL: dot_general_upstream_full_integer_asym_weight -func.func @dot_general_upstream_full_integer_asym_weight(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_srq_asym_weight +func.func @dot_general_upstream_srq_asym_weight(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) {dot_dimension_numbers = #stablehlo.dot, precision_config = [#stablehlo, #stablehlo]} : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> return %1 : tensor<1x2x3x5x!quant.uniform> } -// CHECK-SAME: %[[ARG:.*]]: tensor<1x2x3x4x!quant.uniform> -// CHECK: %[[QCONST_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<1x2x4x5x!quant.uniform>, value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> -// CHECK: %[[BMM:.*]] = "tfl.batch_matmul"(%[[ARG]], %[[QCONST_0]]) {adj_x = false, adj_y = false} : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> +// CHECK-SAME: %[[ARG:.+]]: tensor<1x2x3x4x!quant.uniform> +// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<1x2x4x5x!quant.uniform>, value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> +// CHECK: %[[BMM:.+]] = "tfl.batch_matmul"(%[[ARG]], %[[QCONST_0]]) {adj_x = false, adj_y = false} : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> // ----- -// Test that when the weight tensor for `stablehlo.dot_general` is per-axis +// Tests that when the weight tensor for `stablehlo.dot_general` is per-axis // quantized, it is converted to `tfl.fully_connected` op. -// CHECK-LABEL: dot_general_upstream_full_integer_per_axis_quantized_filter -func.func @dot_general_upstream_full_integer_per_axis_quantized_filter(%arg0: tensor<1x3x!quant.uniform>) -> tensor<1x2x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_srq_per_axis_quantized_filter +func.func @dot_general_upstream_srq_per_axis_quantized_filter(%arg0: tensor<1x3x!quant.uniform>) -> tensor<1x2x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<3x2xi8>} : () -> tensor<3x2x!quant.uniform> %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0] : (tensor<1x3x!quant.uniform>, tensor<3x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> return %1 : tensor<1x2x!quant.uniform> } -// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x3x!quant.uniform> +// CHECK-SAME: %[[ARG_0:.+]]: tensor<1x3x!quant.uniform> // Weight tensor is transposed, as tfl.fully_connected accepts a [o, i] matrix. -// CHECK-DAG: %[[QCONST_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, value = dense<1> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>> -// CHECK-DAG: %[[QCONST_1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> +// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, value = dense<1> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>> +// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> // Bias tensor's scale is input scale * filter scale. -// CHECK: %[[FC:.*]] = "tfl.fully_connected"(%[[ARG_0]], %[[QCONST_0]], %[[QCONST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x3x!quant.uniform>, tensor<2x3x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, tensor<2x!quant.uniform>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[FC:.+]] = "tfl.fully_connected"(%[[ARG_0]], %[[QCONST_0]], %[[QCONST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x3x!quant.uniform>, tensor<2x3x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, tensor<2x!quant.uniform>) -> tensor<1x2x!quant.uniform> // CHECK-NEXT: return %[[FC]] : tensor<1x2x!quant.uniform> // ----- -// Test that when the weight tensor for `stablehlo.dot_general` is per-axis +// Tests that when the weight tensor for `stablehlo.dot_general` is per-axis // quantized but has a batch dimension, it is not converted. -// CHECK-LABEL: dot_general_upstream_full_integer_per_axis_quantized_filter_with_batch_dim -func.func @dot_general_upstream_full_integer_per_axis_quantized_filter_with_batch_dim(%arg0: tensor<1x1x3x!quant.uniform>) -> tensor<1x1x2x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_srq_per_axis_quantized_filter_with_batch_dim +func.func @dot_general_upstream_srq_per_axis_quantized_filter_with_batch_dim(%arg0: tensor<1x1x3x!quant.uniform>) -> tensor<1x1x2x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x3x2xi8>} : () -> tensor<1x3x2x!quant.uniform> %1 = stablehlo.dot_general %arg0, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<1x1x3x!quant.uniform>, tensor<1x3x2x!quant.uniform>) -> tensor<1x1x2x!quant.uniform> return %1 : tensor<1x1x2x!quant.uniform> @@ -465,11 +440,11 @@ func.func @dot_general_upstream_full_integer_per_axis_quantized_filter_with_batc // ----- -// Test that when the weight tensor for `stablehlo.dot_general` is per-axis +// Tests that when the weight tensor for `stablehlo.dot_general` is per-axis // quantized but has a batch dim > 1, it is not converted. -// CHECK-LABEL: dot_general_upstream_full_integer_per_axis_quantized_filter_multibatch -func.func @dot_general_upstream_full_integer_per_axis_quantized_filter_multibatch(%arg0: tensor<3x1x3x!quant.uniform>) -> tensor<3x1x2x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_srq_per_axis_quantized_filter_multibatch +func.func @dot_general_upstream_srq_per_axis_quantized_filter_multibatch(%arg0: tensor<3x1x3x!quant.uniform>) -> tensor<3x1x2x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<3x3x2xi8>} : () -> tensor<3x3x2x!quant.uniform> %1 = stablehlo.dot_general %arg0, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<3x1x3x!quant.uniform>, tensor<3x3x2x!quant.uniform>) -> tensor<3x1x2x!quant.uniform> return %1 : tensor<3x1x2x!quant.uniform> @@ -481,11 +456,11 @@ func.func @dot_general_upstream_full_integer_per_axis_quantized_filter_multibatc // ----- -// Test that when the weight tensor for `stablehlo.dot_general` is per-axis +// Tests that when the weight tensor for `stablehlo.dot_general` is per-axis // quantized but has more than one contracting dimension, it is not converted. -// CHECK-LABEL: dot_general_upstream_full_integer_per_axis_quantized_filter_with_multiple_contracting_dims -func.func @dot_general_upstream_full_integer_per_axis_quantized_filter_with_multiple_contracting_dims(%arg0: tensor<1x2x3x!quant.uniform>) -> tensor<1x1x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_srq_per_axis_quantized_filter_with_multiple_contracting_dims +func.func @dot_general_upstream_srq_per_axis_quantized_filter_with_multiple_contracting_dims(%arg0: tensor<1x2x3x!quant.uniform>) -> tensor<1x1x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x3x2xi8>} : () -> tensor<1x3x2x!quant.uniform> %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [1, 2] x [2, 1] : (tensor<1x2x3x!quant.uniform>, tensor<1x3x2x!quant.uniform>) -> tensor<1x1x!quant.uniform> return %1 : tensor<1x1x!quant.uniform> @@ -497,113 +472,258 @@ func.func @dot_general_upstream_full_integer_per_axis_quantized_filter_with_mult // ----- -// Test that a simple per-tensor quantized stablehlo.dot_general is properly -// fused with a subsequent requantize (qi32->qi8) op then legalized. -// Supports the following format: (lhs: qi8, rhs: qi8) -> result: qi32 +// ============================================================================ +// The following functions tests example quantization patterns outputted from +// StableHLO Quantizer. These patterns should be legalized early directly +// to fused tflite ops. +// ============================================================================ + +// Tests that a simple per-tensor quantized `stablehlo.dot_general` is properly +// lowered to fused `tfl.fully_connected`. +// This case covers for the following quantization patterns because +// activation clipping ranges take affect in scale and zp of the final +// `stablehlo.uniform_quantize`. See more details in b/319168201. +// * dot_general_fn +// * dot_general_with_relu_fn +// * dot_general_with_relu6_fn + +func.func @dot_general_srq(%arg0: tensor<1x1024x!quant.uniform>) -> (tensor<1x3x!quant.uniform>) { + %0 = stablehlo.constant() {value = dense<1> : tensor<1024x3xi8>} : () -> tensor<1024x3x!quant.uniform:f32, 2.000000e+0:0>> + %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0] : (tensor<1x1024x!quant.uniform>, tensor<1024x3x!quant.uniform:f32, 2.000000e+0:0>>) -> tensor<1x3x!quant.uniform> + %2 = stablehlo.uniform_quantize %1 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %2 : tensor<1x3x!quant.uniform> +} +// CHECK-LABEL: dot_general_srq +// CHECK-SAME: (%[[ARG_1:.+]]: tensor<1x1024x!quant.uniform) -> tensor<1x3x!quant.uniform> +// CHECK-NOT: stablehlo.dot_general +// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<3x1024x!quant.uniform:f32, 2.000000e+00>>, value = dense<1> : tensor<3x1024xi8>} : () -> tensor<3x1024x!quant.uniform:f32, 2.000000e+00>> +// CHECK: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<3x!quant.uniform>, value = dense<0> : tensor<3xi32>} : () -> tensor<3x!quant.uniform> +// CHECK: %[[FULLY_CONNECTED:.+]] = "tfl.fully_connected"(%[[ARG_1]], %[[QCONST_0]], %[[QCONST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x1024x!quant.uniform>, tensor<3x1024x!quant.uniform:f32, 2.000000e+00>>, tensor<3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK-NOT: tfl.batch_matmul +// CHECK: return %[[FULLY_CONNECTED]] -// CHECK-LABEL: dot_general_full_integer -// CHECK-SAME: (%[[ARG_1:.*]]: tensor<1x1024x!quant.uniform - func.func @dot_general_full_integer(%arg0: tensor<1x1024x!quant.uniform> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["output"]}) { - %0 = stablehlo.constant() {value = dense<1> : tensor<1024x3xi8>} : () -> tensor<1024x3x!quant.uniform:f32, 2.000000e+0:0>> - %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0] : (tensor<1x1024x!quant.uniform>, tensor<1024x3x!quant.uniform:f32, 2.000000e+0:0>>) -> tensor<1x3x!quant.uniform> - %2 = stablehlo.uniform_quantize %1 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> - %3 = stablehlo.uniform_dequantize %2 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> - return %3 : tensor<1x3xf32> - } +// ----- +// Tests that a fused per-tensor quantized `stablehlo.dot_general` is properly +// lowered to fused `tfl.fully_connected`. +// TODO: b/309896242 - Add more support for dynamic bias fusion cases. + +func.func @dot_general_with_bias_same_shape_srq(%arg0: tensor<1x1024x!quant.uniform>) -> (tensor<1x3x!quant.uniform>) { + %0 = stablehlo.constant() {value = dense<1> : tensor<1024x3xi8>} : () -> tensor<1024x3x!quant.uniform:f32, 2.000000e+0:0>> + %1 = stablehlo.constant() {value = dense<2> : tensor<1x3xi32>} : () -> tensor<1x3x!quant.uniform> + %2 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0] : (tensor<1x1024x!quant.uniform>, tensor<1024x3x!quant.uniform:f32, 2.000000e+0:0>>) -> tensor<1x3x!quant.uniform> + %3 = stablehlo.add %2, %1 : tensor<1x3x!quant.uniform> + %4 = stablehlo.uniform_quantize %3 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %4 : tensor<1x3x!quant.uniform> +} +// CHECK-LABEL: dot_general_with_bias_same_shape +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x1024x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<3x1024x!quant.uniform:f32, 2.000000e+00>>, value = dense<1> : tensor<3x1024xi8>} : () -> tensor<3x1024x!quant.uniform:f32, 2.000000e+00>> +// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<3x!quant.uniform>, value = dense<2> : tensor<1x3xi32>} : () -> tensor<3x!quant.uniform> +// CHECK: %[[FULLY_CONNECTED:.+]] = "tfl.fully_connected"(%[[ARG_0]], %[[QCONST_0]], %[[QCONST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x1024x!quant.uniform>, tensor<3x1024x!quant.uniform:f32, 2.000000e+00>>, tensor<3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK: return %[[FULLY_CONNECTED]] + +// ----- + +// Tests static range quantized dot_general with qi32 -> qi8 requantization is +// properly lowered to `tfl.batch_matmul`. + +func.func @dot_general_srq_to_batch_matmul(%arg0: tensor<1x2x3x4x!quant.uniform>, %arg1: tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { + %0 = "stablehlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [0, 1], + rhs_batching_dimensions = [0, 1], + lhs_contracting_dimensions = [3], + rhs_contracting_dimensions = [2] + >, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> + %1 = stablehlo.uniform_quantize %0 : (tensor<1x2x3x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> + return %1 : tensor<1x2x3x5x!quant.uniform> +} + +// CHECK-LABEL: dot_general_srq_to_batch_matmul +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x2x3x4x!quant.uniform>, %[[ARG_1:.+]]: tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> +// CHECK: %[[BMM:.+]] = "tfl.batch_matmul"(%[[ARG_0]], %[[ARG_1]]) {adj_x = false, adj_y = false} : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> // CHECK-NOT: stablehlo.dot_general -// CHECK: %[[QCONST_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<3x1024x!quant.uniform:f32, 2.000000e+00>>, value = dense<1> : tensor<3x1024xi8>} : () -> tensor<3x1024x!quant.uniform:f32, 2.000000e+00>> -// CHECK: %[[QCONST_1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<3x!quant.uniform>, value = dense<0> : tensor<3xi32>} : () -> tensor<3x!quant.uniform> -// CHECK: "tfl.fully_connected"(%[[ARG_1]], %[[QCONST_0]], %[[QCONST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x1024x!quant.uniform>, tensor<3x1024x!quant.uniform:f32, 2.000000e+00>>, tensor<3x!quant.uniform>) -> tensor<1x3x!quant.uniform> -// CHECK-NOT: tfl.batch_matmul +// CHECK-NOT: stablehlo.uniform_quantize +// CHECK-NOT: tfl.fully_connected +// CHECK-NOT: tfl.quantize +// CHECK: return %[[BMM]] // ----- -// Test that a `stablehlo.dot_general` with an i32 output remains unchanged when -// it is not followed by a requantization (`stablehlo.quantize`). +// Tests static range quantized dot_general with qi32 -> qi8 requantization is +// not converted to `tfl.batch_matmul` when there are multiple use of the +// intermediate result. -// CHECK-LABEL: dot_general_no_requantize -func.func @dot_general_no_requantize(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> { - %0 = stablehlo.constant() {value = dense<5> : tensor<4x3xi8>} : () -> tensor<4x3x!quant.uniform:f32, 2.000000e+00>> - %1 = stablehlo.uniform_quantize %arg0 : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> - %2 = stablehlo.dot_general %1, %0, contracting_dims = [1] x [0] : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32, 2.000000e+00>>) -> tensor<1x3x!quant.uniform> - %3 = stablehlo.uniform_dequantize %2 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> - return %3 : tensor<1x3xf32> +func.func @dot_general_srq_multiple_use_of_intermediate_result(%arg0: tensor<1x2x3x4x!quant.uniform>, %arg1: tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { + %0 = "stablehlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [0, 1], + rhs_batching_dimensions = [0, 1], + lhs_contracting_dimensions = [3], + rhs_contracting_dimensions = [2] + >, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> + %1 = stablehlo.uniform_quantize %0 : (tensor<1x2x3x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> + %2 = stablehlo.uniform_quantize %0 : (tensor<1x2x3x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> + %3 = stablehlo.add %1, %2 : tensor<1x2x3x5x!quant.uniform> + return %3 : tensor<1x2x3x5x!quant.uniform> } -// CHECK: "tfl.quantize" -// CHECK: stablehlo.dot_general + +// CHECK-LABEL: dot_general_srq_multiple_use_of_intermediate_result // CHECK-NOT: tfl.fully_connected // CHECK-NOT: tfl.batch_matmul -// CHECK: stablehlo.uniform_dequantize +// CHECK: stablehlo.dot_general // ----- -// Test that a quantized stablehlo.transpose is converted to tfl.transpose. +// Tests that a simple per-channel quantized `stablehlo.convolution` is properly +// lowered to fused `tfl.conv_2d`. +// This case covers for the following quantization patterns because +// activation clipping ranges take affect in scale and zp of the final +// `stablehlo.uniform_quantize`. See more details in b/319168201. +// * conv_fn +// * conv_with_relu_fn +// * conv_with_relu6_fn + +func.func @conv_srq(%arg0: tensor<1x5x5x2x!quant.uniform>) -> (tensor<1x6x6x4x!quant.uniform>) { + %0 = stablehlo.constant() {value = dense<3> : tensor<2x2x2x4xi8>} : () -> tensor<2x2x2x4x!quant.uniform> + %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x5x5x2x!quant.uniform>, tensor<2x2x2x4x!quant.uniform>) -> tensor<1x6x6x4x!quant.uniform> + %2 = stablehlo.uniform_quantize %1 : (tensor<1x6x6x4x!quant.uniform>) -> tensor<1x6x6x4x!quant.uniform> + return %2 : tensor<1x6x6x4x!quant.uniform> +} +// CHECK-LABEL: func.func @conv_srq +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x5x5x2x!quant.uniform>) -> tensor<1x6x6x4x!quant.uniform> +// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_const"() {value = dense<{{\[\[0, 0\], \[1, 1\], \[1, 1\], \[0, 0\]\]}}> : tensor<4x2xi32>} : () -> tensor<4x2xi32> +// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<4x2x2x2x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00,3.000000e+00,3.000000e+00}>>, value = dense<3> : tensor<4x2x2x2xi8>} : () -> tensor<4x2x2x2x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00,3.000000e+00,3.000000e+00}>> +// CHECK-DAG: %[[QCONST_2:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<4x!quant.uniform>, value = dense<0> : tensor<4xi32>} : () -> tensor<4x!quant.uniform> +// CHECK: %[[PAD:.+]] = "tfl.pad"(%[[ARG_0]], %[[QCONST_0]]) : (tensor<1x5x5x2x!quant.uniform>, tensor<4x2xi32>) -> tensor<1x7x7x2x!quant.uniform> +// CHECK: %[[CONV_2D:.+]] = "tfl.conv_2d"(%[[PAD]], %[[QCONST_1]], %[[QCONST_2]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x7x7x2x!quant.uniform>, tensor<4x2x2x2x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00,3.000000e+00,3.000000e+00}>>, tensor<4x!quant.uniform>) -> tensor<1x6x6x4x!quant.uniform> +// CHECK: return %[[CONV_2D]] + +func.func @conv_same_padding_srq(%arg0: tensor<1x32x32x3x!quant.uniform>) -> (tensor<1x32x32x2x!quant.uniform>) { + %0 = stablehlo.constant() {value = dense<3> : tensor<3x3x3x2xi8>} : () -> tensor<3x3x3x2x!quant.uniform> + %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x32x32x3x!quant.uniform>, tensor<3x3x3x2x!quant.uniform>) -> tensor<1x32x32x2x!quant.uniform> + %2 = stablehlo.uniform_quantize %1 : (tensor<1x32x32x2x!quant.uniform>) -> tensor<1x32x32x2x!quant.uniform> + return %2 : tensor<1x32x32x2x!quant.uniform> +} +// CHECK-LABEL: func.func @conv_same_padding_srq +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x32x32x3x!quant.uniform>) -> tensor<1x32x32x2x!quant.uniform> +// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x3x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00}>>, value = dense<3> : tensor<2x3x3x3xi8>} : () -> tensor<2x3x3x3x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00}>> +// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> +// CHECK: %[[CONV_2D:.+]] = "tfl.conv_2d"(%[[ARG_0]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x3x!quant.uniform>, tensor<2x3x3x3x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00}>>, tensor<2x!quant.uniform>) -> tensor<1x32x32x2x!quant.uniform> +// CHECK: return %[[CONV_2D]] : tensor<1x32x32x2x!quant.uniform> + +// ----- + +// Tests that a fused per-channel quantized `stablehlo.convolution` is properly +// lowered to fused `tfl.conv_2d`. +// This case covers for the following quantization patterns because +// activation clipping ranges take affect in scale and zp of the final +// `stablehlo.uniform_quantize`. See more details in b/319168201. +// * conv_with_bias_fn +// * conv_with_bias_and_relu_fn +// * conv_with_bias_and_relu6_fn + +func.func @conv_with_bias_and_relu_srq(%arg0: tensor<1x5x5x2x!quant.uniform>) -> (tensor<1x6x6x4x!quant.uniform>) { + %0 = stablehlo.constant() {value = dense<5> : tensor<1x1x1x4xi32>} : () -> tensor<1x1x1x4x!quant.uniform> + %1 = stablehlo.constant() {value = dense<3> : tensor<2x2x2x4xi8>} : () -> tensor<2x2x2x4x!quant.uniform> + %2 = stablehlo.broadcast_in_dim %0, dims = [0, 1, 2, 3] : (tensor<1x1x1x4x!quant.uniform>) -> tensor<1x6x6x4x!quant.uniform> + %3 = stablehlo.convolution(%arg0, %1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x5x5x2x!quant.uniform>, tensor<2x2x2x4x!quant.uniform>) -> tensor<1x6x6x4x!quant.uniform> + %4 = stablehlo.add %3, %2 : tensor<1x6x6x4x!quant.uniform> + %5 = stablehlo.uniform_quantize %4 : (tensor<1x6x6x4x!quant.uniform>) -> tensor<1x6x6x4x!quant.uniform> + return %5 : tensor<1x6x6x4x!quant.uniform> + } +// CHECK-LABEL: func.func @conv_with_bias_and_relu_srq +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x5x5x2x!quant.uniform>) -> tensor<1x6x6x4x!quant.uniform> +// CHECK-DAG: %[[CONST_0:.+]] = "tfl.pseudo_const"() {value = dense<{{\[\[0, 0\], \[1, 1\], \[1, 1\], \[0, 0\]\]}}> : tensor<4x2xi32>} : () -> tensor<4x2xi32> +// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<4x2x2x2x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00,3.000000e+00,3.000000e+00}>>, value = dense<3> : tensor<4x2x2x2xi8>} : () -> tensor<4x2x2x2x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00,3.000000e+00,3.000000e+00}>> +// CHECK-DAG: %[[QCONST_2:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<4x!quant.uniform>, value = dense<5> : tensor<1x1x1x4xi32>} : () -> tensor<4x!quant.uniform> +// CHECK: %[[PAD:.+]] = "tfl.pad"(%[[ARG_0]], %[[CONST_0]]) : (tensor<1x5x5x2x!quant.uniform>, tensor<4x2xi32>) -> tensor<1x7x7x2x!quant.uniform> +// CHECK: %[[CONV_2D:.+]] = "tfl.conv_2d"(%[[PAD]], %[[QCONST_1]], %[[QCONST_2]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x7x7x2x!quant.uniform>, tensor<4x2x2x2x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00,3.000000e+00,3.000000e+00}>>, tensor<4x!quant.uniform>) -> tensor<1x6x6x4x!quant.uniform> +// CHECK: return %[[CONV_2D]] + +func.func @conv_with_bias_same_padding_srq(%arg0: tensor<1x32x32x3x!quant.uniform>) -> (tensor<1x32x32x2x!quant.uniform>) { + %0 = stablehlo.constant() {value = dense<3> : tensor<3x3x3x2xi8>} : () -> tensor<3x3x3x2x!quant.uniform> + %1 = stablehlo.constant() {value = dense<5> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform> + %2 = stablehlo.broadcast_in_dim %1, dims = [0, 1, 2, 3] : (tensor<1x1x1x2x!quant.uniform>) -> tensor<1x32x32x2x!quant.uniform> + %3 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x32x32x3x!quant.uniform>, tensor<3x3x3x2x!quant.uniform>) -> tensor<1x32x32x2x!quant.uniform> + %4 = stablehlo.add %3, %2 : tensor<1x32x32x2x!quant.uniform> + %5 = stablehlo.uniform_quantize %4 : (tensor<1x32x32x2x!quant.uniform>) -> tensor<1x32x32x2x!quant.uniform> + return %5 : tensor<1x32x32x2x!quant.uniform> +} +// CHECK-LABEL: func.func @conv_with_bias_same_padding_srq +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x32x32x3x!quant.uniform>) -> tensor<1x32x32x2x!quant.uniform> +// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x3x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00}>>, value = dense<3> : tensor<2x3x3x3xi8>} : () -> tensor<2x3x3x3x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00}>> +// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<5> : tensor<1x1x1x2xi32>} : () -> tensor<2x!quant.uniform> +// CHECK: %[[CONV_2D:.+]] = "tfl.conv_2d"(%[[ARG_0]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x3x!quant.uniform>, tensor<2x3x3x3x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00}>>, tensor<2x!quant.uniform>) -> tensor<1x32x32x2x!quant.uniform> +// CHECK: return %[[CONV_2D]] + +// ----- + +// Tests that a quantized stablehlo.transpose is converted to tfl.transpose. -// CHECK-LABEL: transpose -// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3x4x!quant.uniform> func.func @transpose( %arg0: tensor<2x3x4x!quant.uniform> ) -> tensor<4x3x2x!quant.uniform> { %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<2x3x4x!quant.uniform>) -> tensor<4x3x2x!quant.uniform> return %0 : tensor<4x3x2x!quant.uniform> } - +// CHECK-LABEL: transpose +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x!quant.uniform> // CHECK-NOT: stablehlo.transpose -// CHECK: %[[CST:.*]] = arith.constant dense<[2, 1, 0]> : tensor<3xi32> -// CHECK: %[[TRANSPOSE:.*]] = "tfl.transpose"(%[[ARG0]], %[[CST]]) : (tensor<2x3x4x!quant.uniform>, tensor<3xi32>) -> tensor<4x3x2x!quant.uniform> +// CHECK: %[[CST:.+]] = arith.constant dense<[2, 1, 0]> : tensor<3xi32> +// CHECK: %[[TRANSPOSE:.+]] = "tfl.transpose"(%[[ARG0]], %[[CST]]) : (tensor<2x3x4x!quant.uniform>, tensor<3xi32>) -> tensor<4x3x2x!quant.uniform> // CHECK: return %[[TRANSPOSE]] // ----- -// Test that a float stablehlo.transpose is not converted to tfl.transpose. +// Tests that a float stablehlo.transpose is not converted to tfl.transpose. -// CHECK-LABEL: float_transpose func.func @float_transpose(%arg0: tensor<2x3x4xf32>) -> tensor<4x3x2xf32> { %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<2x3x4xf32>) -> tensor<4x3x2xf32> return %0 : tensor<4x3x2xf32> } - +// CHECK-LABEL: float_transpose // CHECK-NOT: tfl.transpose // CHECK: stablehlo.transpose // ----- -// Test that a quantized stablehlo.reshape is converted to tfl.reshape. +// Tests that a quantized stablehlo.reshape is converted to tfl.reshape. -// CHECK-LABEL: reshape -// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3x4x!quant.uniform> func.func @reshape( %arg0: tensor<2x3x4x!quant.uniform> ) -> tensor<6x4x!quant.uniform> { %0 = stablehlo.reshape %arg0 : (tensor<2x3x4x!quant.uniform>) -> tensor<6x4x!quant.uniform> return %0 : tensor<6x4x!quant.uniform> } - +// CHECK-LABEL: reshape +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x!quant.uniform> // CHECK-NOT: stablehlo.reshape -// CHECK: %[[CST:.*]] = arith.constant dense<[6, 4]> : tensor<2xi32> -// CHECK: %[[RESHAPE:.*]] = "tfl.reshape"(%[[ARG0]], %[[CST]]) : (tensor<2x3x4x!quant.uniform>, tensor<2xi32>) -> tensor<6x4x!quant.uniform> +// CHECK: %[[CST:.+]] = arith.constant dense<[6, 4]> : tensor<2xi32> +// CHECK: %[[RESHAPE:.+]] = "tfl.reshape"(%[[ARG0]], %[[CST]]) : (tensor<2x3x4x!quant.uniform>, tensor<2xi32>) -> tensor<6x4x!quant.uniform> // CHECK: return %[[RESHAPE]] // ----- -// Test that a float stablehlo.reshape is not converted to tfl.reshape. +// Tests that a float stablehlo.reshape is not converted to tfl.reshape. -// CHECK-LABEL: float_reshape func.func @float_reshape(%arg0: tensor<2x3x4xf32>) -> tensor<6x4xf32> { %0 = stablehlo.reshape %arg0 : (tensor<2x3x4xf32>) -> tensor<6x4xf32> return %0 : tensor<6x4xf32> } - +// CHECK-LABEL: float_reshape // CHECK-NOT: tfl.reshape // CHECK: stablehlo.reshape // ----- -// Test that a quantized stablehlo.select is converted to tfl.select_v2. +// Tests that a quantized stablehlo.select is converted to tfl.select_v2. -// CHECK-LABEL: select -// CHECK-SAME: %[[ARG0:.*]]: tensor<1x3xi1>, %[[ARG1:.*]]: tensor<1x3x!quant.uniform>, %[[ARG2:.*]]: tensor<1x3x!quant.uniform> func.func @select( %arg0: tensor<1x3xi1>, %arg1: tensor<1x3x!quant.uniform>, @@ -616,31 +736,28 @@ func.func @select( ) -> tensor<1x3x!quant.uniform> return %0 : tensor<1x3x!quant.uniform> } - +// CHECK-LABEL: select +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x3xi1>, %[[ARG1:.+]]: tensor<1x3x!quant.uniform>, %[[ARG2:.+]]: tensor<1x3x!quant.uniform> // CHECK-NOT: stablehlo.select -// CHECK: %[[SELECT:.*]] = "tfl.select_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (tensor<1x3xi1>, tensor<1x3x!quant.uniform>, tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK: %[[SELECT:.+]] = "tfl.select_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (tensor<1x3xi1>, tensor<1x3x!quant.uniform>, tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> // CHECK: return %[[SELECT]] // ----- -// Test that a float stablehlo.select is not converted to tfl.select_v2. +// Tests that a float stablehlo.select is not converted to tfl.select_v2. - -// CHECK-LABEL: float_select func.func @float_select(%arg0: tensor<1x3xi1>, %arg1: tensor<1x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> { %0 = "stablehlo.select"(%arg0, %arg1, %arg2) : (tensor<1x3xi1>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> return %0 : tensor<1x3xf32> } - +// CHECK-LABEL: float_select // CHECK-NOT: tfl.select_v2 // CHECK: stablehlo.select // ----- -// Test that a quantized stablehlo.concatenate is converted to tfl.concatenation. +// Tests that a quantized stablehlo.concatenate is converted to tfl.concatenation. -// CHECK-LABEL: concatenate -// CHECK-SAME: %[[ARG0:.*]]: tensor<3x2x!quant.uniform>, %[[ARG1:.*]]: tensor<1x2x!quant.uniform> func.func @concatenate( %arg0: tensor<3x2x!quant.uniform>, %arg1: tensor<1x2x!quant.uniform> @@ -651,32 +768,29 @@ func.func @concatenate( ) -> tensor<4x2x!quant.uniform> return %0 : tensor<4x2x!quant.uniform> } - +// CHECK-LABEL: concatenate +// CHECK-SAME: %[[ARG0:.+]]: tensor<3x2x!quant.uniform>, %[[ARG1:.+]]: tensor<1x2x!quant.uniform> // CHECK-NOT: stablehlo.concatenate -// CHECK: %[[CONCAT:.*]] = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<3x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> tensor<4x2x!quant.uniform> +// CHECK: %[[CONCAT:.+]] = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<3x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> tensor<4x2x!quant.uniform> // CHECK: return %[[CONCAT]] // ----- -// Test that a float stablehlo.concatenate is not converted to tfl.concatenation. +// Tests that a float stablehlo.concatenate is not converted to tfl.concatenation. -// CHECK-LABEL: float_concatenate func.func @float_concatenate(%arg0: tensor<3x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<4x2xf32> { %0 = "stablehlo.concatenate"(%arg0, %arg1) {dimension = 0 : i64} : (tensor<3x2xf32>, tensor<1x2xf32>) -> tensor<4x2xf32> return %0 : tensor<4x2xf32> } - +// CHECK-LABEL: float_concatenate // CHECK-NOT: tfl.concatenation // CHECK: stablehlo.concatenate // ----- -// Test that a quantized stablehlo.pad without interior padding is converted to +// Tests that a quantized stablehlo.pad without interior padding is converted to // tfl.padv2. -// CHECK-LABEL: pad_without_interior_padding -// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3x!quant.uniform> -// CHECK-SAME: %[[ARG1:.*]]: tensor> func.func @pad_without_interior_padding( %arg0: tensor<2x3x!quant.uniform>, %arg1: tensor> @@ -687,20 +801,19 @@ func.func @pad_without_interior_padding( ) -> tensor<4x5x!quant.uniform> return %0 : tensor<4x5x!quant.uniform> } - -// CHECK: %[[PADDING:.*]] = arith.constant +// CHECK-LABEL: pad_without_interior_padding +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x!quant.uniform> +// CHECK-SAME: %[[ARG1:.+]]: tensor> +// CHECK: %[[PADDING:.+]] = arith.constant // CHECK{LITERAL}: dense<[[0, 2], [1, 1]]> : tensor<2x2xi32> -// CHECK: %[[PAD:.*]] = "tfl.padv2"(%[[ARG0]], %[[PADDING]], %[[ARG1]]) : (tensor<2x3x!quant.uniform>, tensor<2x2xi32>, tensor>) -> tensor<4x5x!quant.uniform> +// CHECK: %[[PAD:.+]] = "tfl.padv2"(%[[ARG0]], %[[PADDING]], %[[ARG1]]) : (tensor<2x3x!quant.uniform>, tensor<2x2xi32>, tensor>) -> tensor<4x5x!quant.uniform> // CHECK: return %[[PAD]] // ----- -// Test that a quantized stablehlo.pad with interior padding is converted to +// Tests that a quantized stablehlo.pad with interior padding is converted to // tfl.dilate and tfl.padv2. -// CHECK-LABEL: pad_with_interior_padding -// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3x!quant.uniform> -// CHECK-SAME: %[[ARG1:.*]]: tensor> func.func @pad_with_interior_padding( %arg0: tensor<2x3x!quant.uniform>, %arg1: tensor> @@ -711,24 +824,295 @@ func.func @pad_with_interior_padding( ) -> tensor<5x9x!quant.uniform> return %0 : tensor<5x9x!quant.uniform> } - -// CHECK: %[[PADDING:.*]] = arith.constant +// CHECK-LABEL: pad_with_interior_padding +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x!quant.uniform> +// CHECK-SAME: %[[ARG1:.+]]: tensor> +// CHECK: %[[PADDING:.+]] = arith.constant // CHECK{LITERAL}: dense<[[0, 2], [1, 1]]> : tensor<2x2xi32> -// CHECK: %[[INTERIOR:.*]] = arith.constant +// CHECK: %[[INTERIOR:.+]] = arith.constant // CHECK{LITERAL}: dense<[1, 2]> : tensor<2xi32> -// CHECK: %[[DILATE:.*]] = "tfl.dilate"(%[[ARG0]], %[[INTERIOR]], %[[ARG1]]) : (tensor<2x3x!quant.uniform>, tensor<2xi32>, tensor>) -> tensor<3x7x!quant.uniform> -// CHECK: %[[PAD:.*]] = "tfl.padv2"(%[[DILATE]], %[[PADDING]], %[[ARG1]]) : (tensor<3x7x!quant.uniform>, tensor<2x2xi32>, tensor>) -> tensor<5x9x!quant.uniform> +// CHECK: %[[DILATE:.+]] = "tfl.dilate"(%[[ARG0]], %[[INTERIOR]], %[[ARG1]]) : (tensor<2x3x!quant.uniform>, tensor<2xi32>, tensor>) -> tensor<3x7x!quant.uniform> +// CHECK: %[[PAD:.+]] = "tfl.padv2"(%[[DILATE]], %[[PADDING]], %[[ARG1]]) : (tensor<3x7x!quant.uniform>, tensor<2x2xi32>, tensor>) -> tensor<5x9x!quant.uniform> // CHECK: return %[[PAD]] // ----- -// Test that a float stablehlo.pad is not converted to tfl.padv2. +// Tests that a float stablehlo.pad is not converted to tfl.padv2. -// CHECK-LABEL: float_pad func.func @float_pad(%arg0: tensor<2x3xf32>, %arg1: tensor) -> tensor<4x5xf32> { %0 = stablehlo.pad %arg0, %arg1, low = [0, 1], high = [2, 1], interior = [0, 0] : (tensor<2x3xf32>, tensor) -> tensor<4x5xf32> return %0 : tensor<4x5xf32> } - +// CHECK-LABEL: float_pad // CHECK-NOT: tfl.padv2 // CHECK: stablehlo.pad + +// ----- + +// Tests that a quantized stablehlo.slice is converted to tfl.slice when stride +// is 1. + +func.func @slice( + %arg0: tensor<3x4x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> { + %0 = "stablehlo.slice"(%arg0) { + start_indices = array, + limit_indices = array, + strides = array + } : ( + tensor<3x4x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} +// CHECK-LABEL: slice +// CHECK-SAME: %[[ARG0:.+]]: tensor<3x4x!quant.uniform> +// CHECK-DAG: %[[START:.+]] = arith.constant dense<{{\[1, 2\]}}> : tensor<2xi32> +// CHECK-DAG: %[[SIZE:.+]] = arith.constant dense<2> : tensor<2xi32> +// CHECK: %[[SLICE:.+]] = "tfl.slice"(%[[ARG0]], %[[START]], %[[SIZE]]) : (tensor<3x4x!quant.uniform>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x2x!quant.uniform> +// CHECK: return %[[SLICE]] + +// ----- + +// Tests that a quantized stablehlo.slice is converted to tfl.strided_slice when +// stride is not 1. + +func.func @strided_slice( + %arg0: tensor<3x6x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> { + %0 = "stablehlo.slice"(%arg0) { + start_indices = array, + limit_indices = array, + strides = array + } : ( + tensor<3x6x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} +// CHECK-LABEL: strided_slice +// CHECK-SAME: %[[ARG0:.+]]: tensor<3x6x!quant.uniform> +// CHECK: %[[START:.+]] = arith.constant +// CHECK{LITERAL}: dense<[0, 2]> : tensor<2xi32> +// CHECK: %[[SIZE:.+]] = arith.constant +// CHECK{LITERAL}: dense<[3, 4]> : tensor<2xi32> +// CHECK: %[[STRIDE:.+]] = arith.constant +// CHECK{LITERAL}: dense<[2, 3]> : tensor<2xi32> +// CHECK: %[[SLICE:.+]] = "tfl.strided_slice"(%[[ARG0]], %[[START]], %[[SIZE]], %[[STRIDE]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<3x6x!quant.uniform>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x2x!quant.uniform> +// CHECK: return %[[SLICE]] + +// ----- + +// Tests that a float stablehlo.slice is not converted to tfl.slice. + +func.func @float_slice(%arg0: tensor<3x4xf32>) -> tensor<2x2xf32> { + %0 = "stablehlo.slice"(%arg0) { + start_indices = array, + limit_indices = array, + strides = array + } : (tensor<3x4xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} +// CHECK-LABEL: float_slice +// CHECK-NOT: tfl.slice +// CHECK-NOT: tfl.strided_slice +// CHECK: stablehlo.slice + +// ----- + +// Tests that a quantized stablehlo.broadcast_in_dim is converted to +// tfl.broadcast_to. + +func.func @broadcast_in_dim( + %arg0: tensor<1x2x!quant.uniform> + ) -> tensor<3x2x!quant.uniform> { + %0 = "stablehlo.broadcast_in_dim"(%arg0) { + broadcast_dimensions = array + } : (tensor<1x2x!quant.uniform>) -> tensor<3x2x!quant.uniform> + return %0 : tensor<3x2x!quant.uniform> +} +// CHECK-LABEL: broadcast_in_dim +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2x!quant.uniform> +// CHECK: %[[SHAPE:.+]] = arith.constant +// CHECK{LITERAL}: dense<[3, 2]> : tensor<2xi32> +// CHECK: %[[BROADCAST:.+]] = "tfl.broadcast_to"(%[[ARG0]], %[[SHAPE]]) : (tensor<1x2x!quant.uniform>, tensor<2xi32>) -> tensor<3x2x!quant.uniform> +// CHECK: return %[[BROADCAST]] + +// ----- + +// Tests that a quantized stablehlo.broadcast_in_dim is converted to +// tfl.transpose and tfl.broadcast_to when broadcast_dimensions is not in +// ascending order. + +func.func @broadcast_in_dim_with_transpose( + %arg0: tensor<1x2x!quant.uniform> + ) -> tensor<2x3x!quant.uniform> { + %0 = "stablehlo.broadcast_in_dim"(%arg0) { + broadcast_dimensions = array + } : (tensor<1x2x!quant.uniform>) -> tensor<2x3x!quant.uniform> + return %0 : tensor<2x3x!quant.uniform> +} +// CHECK-LABEL: broadcast_in_dim_with_transpose +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2x!quant.uniform> +// CHECK: %[[BROADCAST_DIM:.+]] = arith.constant +// CHECK{LITERAL}: dense<[2, 3]> : tensor<2xi32> +// CHECK: %[[PERM:.+]] = arith.constant +// CHECK{LITERAL}: dense<[1, 0]> : tensor<2xi32> +// CHECK: %[[TRANSPOSE:.+]] = "tfl.transpose"(%[[ARG0]], %[[PERM]]) : (tensor<1x2x!quant.uniform>, tensor<2xi32>) -> tensor<2x1x!quant.uniform> +// CHECK: %[[BROADCAST:.+]] = "tfl.broadcast_to"(%[[TRANSPOSE]], %[[BROADCAST_DIM]]) : (tensor<2x1x!quant.uniform>, tensor<2xi32>) -> tensor<2x3x!quant.uniform> +// CHECK: return %[[BROADCAST]] + +// ----- + +// Tests that a quantized stablehlo.broadcast_in_dim is converted to +// tfl.expand_dims and tfl.broadcast_to when input rank is smaller than output +// rank. + +func.func @broadcast_in_dim_with_expand( + %arg0: tensor<1x2x!quant.uniform> + ) -> tensor<3x2x1x1x!quant.uniform> { + %0 = "stablehlo.broadcast_in_dim"(%arg0) { + broadcast_dimensions = array + } : (tensor<1x2x!quant.uniform>) -> tensor<3x2x1x1x!quant.uniform> + return %0 : tensor<3x2x1x1x!quant.uniform> +} +// CHECK-LABEL: broadcast_in_dim_with_expand +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2x!quant.uniform> +// CHECK-DAG: %[[BROADCAST_DIM:.+]] = arith.constant dense<{{\[3, 2, 1, 1\]}}> : tensor<4xi32> +// CHECK-DAG: %[[EXPAND_DIM1:.+]] = arith.constant dense<3> : tensor<1xi32> +// CHECK-DAG: %[[EXPAND_DIM0:.+]] = arith.constant dense<2> : tensor<1xi32> +// CHECK: %[[EXPAND0:.+]] = "tfl.expand_dims"(%[[ARG0]], %[[EXPAND_DIM0]]) : (tensor<1x2x!quant.uniform>, tensor<1xi32>) -> tensor<1x2x1x!quant.uniform> +// CHECK: %[[EXPAND1:.+]] = "tfl.expand_dims"(%[[EXPAND0]], %[[EXPAND_DIM1]]) : (tensor<1x2x1x!quant.uniform>, tensor<1xi32>) -> tensor<1x2x1x1x!quant.uniform> +// CHECK: %[[BROADCAST:.+]] = "tfl.broadcast_to"(%[[EXPAND1]], %[[BROADCAST_DIM]]) : (tensor<1x2x1x1x!quant.uniform>, tensor<4xi32>) -> tensor<3x2x1x1x!quant.uniform> +// CHECK: return %[[BROADCAST]] + +// ----- + +// Tests that a quantized stablehlo.broadcast_in_dim is converted to +// tfl.transpose, tfl.expand_dims and tfl.broadcast_to when broadcast_dimensions +// is not in ascending order and input rank is smaller than output rank. + +func.func @broadcast_in_dim_with_transpose_and_expand( + %arg0: tensor<2x3x4x!quant.uniform> + ) -> tensor<3x2x1x1x4x!quant.uniform> { + %0 = "stablehlo.broadcast_in_dim"(%arg0) { + broadcast_dimensions = array + } : (tensor<2x3x4x!quant.uniform>) -> tensor<3x2x1x1x4x!quant.uniform> + return %0 : tensor<3x2x1x1x4x!quant.uniform> +} +// CHECK-LABEL: broadcast_in_dim_with_transpose_and_expand +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x!quant.uniform> +// CHECK-DAG: %[[BROADCAST_DIM:.+]] = arith.constant dense<{{\[3, 2, 1, 1, 4\]}}> : tensor<5xi32> +// CHECK-DAG: %[[EXPAND_DIM1:.+]] = arith.constant dense<3> : tensor<1xi32> +// CHECK-DAG: %[[EXPAND_DIM0:.+]] = arith.constant dense<2> : tensor<1xi32> +// CHECK-DAG: %[[PERM:.+]] = arith.constant dense<{{\[1, 0, 2\]}}> : tensor<3xi32> +// CHECK: %[[TRANSPOSE:.+]] = "tfl.transpose"(%[[ARG0]], %[[PERM]]) : (tensor<2x3x4x!quant.uniform>, tensor<3xi32>) -> tensor<3x2x4x!quant.uniform> +// CHECK: %[[EXPAND0:.+]] = "tfl.expand_dims"(%[[TRANSPOSE]], %[[EXPAND_DIM0]]) : (tensor<3x2x4x!quant.uniform>, tensor<1xi32>) -> tensor<3x2x1x4x!quant.uniform> +// CHECK: %[[EXPAND1:.+]] = "tfl.expand_dims"(%[[EXPAND0]], %[[EXPAND_DIM1]]) : (tensor<3x2x1x4x!quant.uniform>, tensor<1xi32>) -> tensor<3x2x1x1x4x!quant.uniform> +// CHECK: %[[BROADCAST:.+]] = "tfl.broadcast_to"(%[[EXPAND1]], %[[BROADCAST_DIM]]) : (tensor<3x2x1x1x4x!quant.uniform>, tensor<5xi32>) -> tensor<3x2x1x1x4x!quant.uniform> +// CHECK: return %[[BROADCAST]] + +// ----- + +// Tests that a float stablehlo.broadcast_in_dim is not converted to tfl.broadcast_to. + +func.func @float_broadcast_in_dim(%arg0: tensor<1x2xf32>) -> tensor<3x2xf32> { + %0 = "stablehlo.broadcast_in_dim"(%arg0) { + broadcast_dimensions = array + } : (tensor<1x2xf32>) -> tensor<3x2xf32> + return %0 : tensor<3x2xf32> +} +// CHECK-LABEL: float_broadcast_in_dim +// CHECK-NOT: tfl.broadcast_to +// CHECK-NOT: tfl.transpose +// CHECK-NOT: tfl.expand_dims +// CHECK: stablehlo.broadcast_in_dim + +// ----- + +// Test that a quantized stablehlo.reduce_window with max is converted to +// tfl.max_pool_2d. + +func.func @reduce_window_with_max( + %arg0: tensor<2x9x10x3x!quant.uniform>, + %arg1: tensor> +) -> tensor<2x4x3x3x!quant.uniform> { + %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor>, %arg3: tensor>): + %1 = stablehlo.maximum %arg2, %arg3 : tensor> + stablehlo.return %1 : tensor> + }) {window_dimensions = array, window_strides = array} : (tensor<2x9x10x3x!quant.uniform>, tensor>) -> tensor<2x4x3x3x!quant.uniform> + return %0 : tensor<2x4x3x3x!quant.uniform> +} + +// CHECK-LABEL: reduce_window_with_max +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x9x10x3x!quant.uniform> +// CHECK-SAME: %[[ARG1:.*]]: tensor> +// CHECK: %[[MAX_POOL:.*]] = "tfl.max_pool_2d"(%[[ARG0]]) +// CHECK-SAME: {filter_height = 3 : i32, filter_width = 4 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 3 : i32} +// CHECK-SAME: (tensor<2x9x10x3x!quant.uniform>) -> tensor<2x4x3x3x!quant.uniform> +// CHECK: return %[[MAX_POOL]] + +// ----- + +// Test that a quantized stablehlo.reduce_window with max whose rank is not 4 +// is not converted to tfl.max_pool_2d. + +func.func @reduce_window_not_4d( + %arg0: tensor<3x2x9x10x3x!quant.uniform>, + %arg1: tensor> +) -> tensor<3x2x4x3x3x!quant.uniform> { + %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor>, %arg3: tensor>): + %1 = stablehlo.maximum %arg2, %arg3 : tensor> + stablehlo.return %1 : tensor> + }) {window_dimensions = array, window_strides = array} : (tensor<3x2x9x10x3x!quant.uniform>, tensor>) -> tensor<3x2x4x3x3x!quant.uniform> + return %0 : tensor<3x2x4x3x3x!quant.uniform> +} + +// CHECK-LABEL: reduce_window_not_4d +// CHECK: stablehlo.reduce_window +// CHECK-NOT: tfl.max_pool_2d + +// ----- + +// Test that a quantized stablehlo.reduce_window with max that takes multiple +// inputs is not converted to tfl.max_pool_2d. + +func.func @reduce_window_not_binary( + %arg0: tensor<3x2x9x10x3x!quant.uniform>, + %arg1: tensor<3x2x9x10x3x!quant.uniform>, + %arg2: tensor>, + %arg3: tensor> +) -> tensor<3x2x4x3x3x!quant.uniform> { + %0, %1 = "stablehlo.reduce_window"(%arg0, %arg1, %arg2, %arg3) ({ + ^bb0(%arg4: tensor>, %arg5: tensor>, %arg6: tensor>, %arg7: tensor>): + %2 = stablehlo.maximum %arg4, %arg5 : tensor> + %3 = stablehlo.maximum %arg6, %arg7 : tensor> + stablehlo.return %2, %3 : tensor>, tensor> + }) {window_dimensions = array, window_strides = array} : (tensor<3x2x9x10x3x!quant.uniform>, tensor<3x2x9x10x3x!quant.uniform>, tensor>, tensor>) -> (tensor<3x2x4x3x3x!quant.uniform>, tensor<3x2x4x3x3x!quant.uniform>) + return %0 : tensor<3x2x4x3x3x!quant.uniform> +} + +// CHECK-LABEL: reduce_window_not_binary +// CHECK: stablehlo.reduce_window +// CHECK-NOT: tfl.max_pool_2d + +// ----- + +// Test that a float stablehlo.reduce_window with max is not converted to +// tfl.max_pool_2d. + +func.func @float_reduce_window_with_max( + %arg0: tensor<2x9x10x3xf32>, + %arg1: tensor +) -> tensor<2x4x3x3xf32> { + %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = stablehlo.maximum %arg2, %arg3 : tensor + stablehlo.return %1 : tensor + }) {window_dimensions = array, window_strides = array} : (tensor<2x9x10x3xf32>, tensor) -> tensor<2x4x3x3xf32> + return %0 : tensor<2x4x3x3xf32> +} + +// CHECK-LABEL: float_reduce_window_with_max +// CHECK: stablehlo.reduce_window +// CHECK-NOT: tfl.max_pool_2d diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc index e0218a504da76e..000f88639240f3 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc @@ -482,10 +482,15 @@ class Convert1DConvOp : public OpConversionPattern { const int64_t input_channels = conv_op.getLhs().getType().cast().getDimSize( input_feature_dimension); + const int kernel_input_feature_dimension = + dnums.getKernelInputFeatureDimension(); + const int kernel_input_channels = + conv_op.getRhs().getType().cast().getDimSize( + kernel_input_feature_dimension); const int64_t feature_group_count = conv_op.getFeatureGroupCount(); - if (feature_group_count != 1 && feature_group_count != input_channels) - return rewriter.notifyMatchFailure(conv_op, - "Group convolution is not supported,"); + if (feature_group_count != input_channels / kernel_input_channels || + input_channels % kernel_input_channels != 0) + return failure(); // // Transpose and reshape the input and kernel @@ -498,6 +503,7 @@ class Convert1DConvOp : public OpConversionPattern { image_2d_shape.push_back(1); auto image_2d_type = RankedTensorType::get(image_2d_shape, image_type.getElementType()); + auto loc = conv_op.getLoc(); auto image_2d_op = rewriter.create( conv_op.getLoc(), image_2d_type, conv_op.getLhs()); @@ -509,8 +515,8 @@ class Convert1DConvOp : public OpConversionPattern { auto image_permutation_and_shape = GetPermutationAndTransposedShape( image_permutation, image_2d_type, rewriter); auto transposed_image_2d_op = rewriter.create( - conv_op.getLoc(), image_permutation_and_shape.shape, - image_2d_op->getResult(0), image_permutation_and_shape.permutation); + loc, image_permutation_and_shape.shape, image_2d_op->getResult(0), + image_permutation_and_shape.permutation); // Reshape kernel to add a new spatial dimension. auto kernel_type = conv_op.getRhs().getType().cast(); @@ -521,8 +527,8 @@ class Convert1DConvOp : public OpConversionPattern { kernel_2d_shape.push_back(1); auto kernel_2d_type = RankedTensorType::get(kernel_2d_shape, kernel_type.getElementType()); - auto kernel_2d_op = rewriter.create( - conv_op.getLoc(), kernel_2d_type, conv_op.getRhs()); + auto kernel_2d_op = + rewriter.create(loc, kernel_2d_type, conv_op.getRhs()); // Transpose kernel to get it into WHIO form (where H is the added dim). SmallVector kernel_permutation = { @@ -533,8 +539,8 @@ class Convert1DConvOp : public OpConversionPattern { auto kernel_permutation_and_shape = GetPermutationAndTransposedShape( kernel_permutation, kernel_2d_type, rewriter); auto transposed_kernel_2d_op = rewriter.create( - conv_op.getLoc(), kernel_permutation_and_shape.shape, - kernel_2d_op->getResult(0), kernel_permutation_and_shape.permutation); + loc, kernel_permutation_and_shape.shape, kernel_2d_op->getResult(0), + kernel_permutation_and_shape.permutation); // // Create 2d equivalents for 1d convolution attributes. @@ -624,11 +630,11 @@ class Convert1DConvOp : public OpConversionPattern { .shape; auto conv2d_op = rewriter.create( - conv_op.getLoc(), transposed_output_2d_shape, - transposed_image_2d_op.getResult(), transposed_kernel_2d_op.getResult(), - window_strides_2d, padding_2d, lhs_dilation_2d, rhs_dilation_2d, - window_reversal_2d, dnums_2d, conv_op.getFeatureGroupCount(), - conv_op.getBatchGroupCount(), conv_op.getPrecisionConfigAttr()); + loc, transposed_output_2d_shape, transposed_image_2d_op.getResult(), + transposed_kernel_2d_op.getResult(), window_strides_2d, padding_2d, + lhs_dilation_2d, rhs_dilation_2d, window_reversal_2d, dnums_2d, + conv_op.getFeatureGroupCount(), conv_op.getBatchGroupCount(), + conv_op.getPrecisionConfigAttr()); OpResult conv2d_output = conv2d_op->getResult(0); auto conv2d_output_type = conv2d_output.getType().cast(); @@ -642,7 +648,7 @@ class Convert1DConvOp : public OpConversionPattern { auto output_permutation_and_shape = GetInversePermutationAndShape( output_permutation, conv2d_output_type, rewriter); auto transposed_output_2d_op = rewriter.create( - conv_op.getLoc(), output_permutation_and_shape.shape, conv2d_output, + loc, output_permutation_and_shape.shape, conv2d_output, output_permutation_and_shape.permutation); // Drop the trailing spatial dimension from the output. diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td index 1607343ab15879..fe988ba9b20265 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td @@ -23,7 +23,7 @@ include "mhlo/IR/hlo_ops.td" // Check if broadcasting is compatible with TF ops. def IsLegalNumpyRankedBroadcast : - Constraint, + Constraint{})">, "broadcasting should be compatible with TF ops">; // Return a constant op that carries the shape of the given value. diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc new file mode 100644 index 00000000000000..d11ec0738a1b14 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc @@ -0,0 +1,317 @@ +/* Copyright 2022 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "llvm/Support/Debug.h" +#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/api/PortableApi.h" // from @stablehlo +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "stablehlo/dialect/VhloOps.h" // from @stablehlo +#include "stablehlo/dialect/VhloTypes.h" // from @stablehlo +#include "stablehlo/transforms/Passes.h" // from @stablehlo +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" + +#define DEBUG_TYPE "compat-passes" + +namespace mlir { +namespace odml { + +#define GEN_PASS_DEF_LEGALIZESTABLEHLOTOVHLOPASS +#define GEN_PASS_DEF_LEGALIZEVHLOTOSTABLEHLOPASS +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" + +namespace { + +//===----------------------------------------------------------------------===// +// StableHLO --> VHLO types +//===----------------------------------------------------------------------===// + +std::optional MaterializeIllegalCast(OpBuilder &builder, Type type, + ValueRange inputs, Location loc) { + return builder.create(loc, type, inputs) + ->getResult(0); +} + +class StablehloToOdmlTypeConverter : public vhlo::VhloTypeConverter { + public: + StablehloToOdmlTypeConverter() : vhlo::VhloTypeConverter() { + addConversion([](Type type) { + if (type.getDialect().getNamespace() == + vhlo::VhloDialect::getDialectNamespace()) { + return type; + } + LLVM_DEBUG(llvm::dbgs() << "Invalid type: " << type << '\n'); + return Type(); + }); + addConversion([](stablehlo::TokenType token) { + return vhlo::TokenV1Type::get(token.getContext()); + }); + addBuiltinToVhloConversions(); + + addArgumentMaterialization(MaterializeIllegalCast); + addSourceMaterialization(MaterializeIllegalCast); + addTargetMaterialization(MaterializeIllegalCast); + } + + Attribute convertEncoding(Attribute attr) const final { + LLVM_DEBUG(llvm::dbgs() << "Converting encoding.\n" << attr << '\n'); + // Must be VHLO encoding, or convertible to VHLO encoding. + if (attr.getDialect().getNamespace() == + vhlo::VhloDialect::getDialectNamespace()) + return attr; + + if (auto stablehlo_attr = + attr.dyn_cast_or_null()) { + return vhlo::TypeExtensionsV1Attr::get(stablehlo_attr.getContext(), + stablehlo_attr.getBounds()); + } + + // Was not VHLO encoding, or convertible. + return {}; + } +}; + +class VhloToStablehloTypeConverter : public vhlo::VhloTypeConverter { + public: + VhloToStablehloTypeConverter() : vhlo::VhloTypeConverter() { + addConversion([](Type type) { return type; }); + addConversion([](vhlo::TokenV1Type token) { + LLVM_DEBUG(llvm::dbgs() << "Converting TokenType\n"); + return stablehlo::TokenType::get(token.getContext()); + }); + addVhloToBuiltinConversions(); + + addArgumentMaterialization(MaterializeIllegalCast); + addSourceMaterialization(MaterializeIllegalCast); + addTargetMaterialization(MaterializeIllegalCast); + } + + Attribute convertEncoding(Attribute attr) const final { + if (auto vhlo_attr = attr.dyn_cast_or_null()) { + return stablehlo::TypeExtensionsAttr::get(vhlo_attr.getContext(), + vhlo_attr.getBounds()); + } + // All encodings supported in StableHLO. + return attr; + } +}; + +//===----------------------------------------------------------------------===// +// StableHLO+TFL --> VHLO+TFL Ops +//===----------------------------------------------------------------------===// + +// Wrap op result uses in an unrealized cast to create a cast to buffer +// any type changes to result, and apply type converter to result: +// result = op(V0) +// V1 = op2(result) +// ==> +// result = op(V0) +// V1 = unrealized_cast(result) +// V2 = op2(V1) +void ConvertAndWrapUsesInUnrealizedCast(Value result, TypeConverter &converter, + IRRewriter &rewriter) { + auto type = result.getType(); + result.setType(converter.convertType(result.getType())); + auto new_value = converter.materializeArgumentConversion( + rewriter, result.getLoc(), type, {result}); + rewriter.replaceAllUsesExcept(result, new_value, new_value.getDefiningOp()); +} + +// Wrap operands in an an unrealized cast to create a cast to buffer any type +// changes to the operand, and apply type converter to operands: +// V0 = op(operand) +// ==> +// V0 = unrealized_cast(operand) +// V1 = op(V0) +void WrapOperandsInUnrealizedCastAndConvert(Operation *op, + TypeConverter &converter, + IRRewriter &rewriter) { + for (int i = 0; i < op->getNumOperands(); ++i) { + auto operand = op->getOperand(i); + auto new_operand = converter.materializeArgumentConversion( + rewriter, op->getLoc(), converter.convertType(operand.getType()), + {operand}); + op->setOperand(i, new_operand); + } +} + +// vhlo.op %1 : vhlo.tensor<...> +// ==> +// vhlo.op %1 : tensor<...> +// +// TODO: There's likely a way to make MLIR manage the unrealized cast +// conversions using a specific rewriter. +LogicalResult ApplyTypeConverter(ModuleOp op, TypeConverter &converter) { + IRRewriter rewriter(op->getContext()); + + op->walk([&](Operation *op) { + if (op->getDialect()->getNamespace() != "vhlo") return; + + // Convert operands + rewriter.modifyOpInPlace(op, [&]() { + rewriter.setInsertionPoint(op); + WrapOperandsInUnrealizedCastAndConvert(op, converter, rewriter); + + // Convert op types + for (auto value : op->getResults()) { + rewriter.setInsertionPointAfter(value.getDefiningOp()); + ConvertAndWrapUsesInUnrealizedCast(value, converter, rewriter); + } + + // Convert block arguments + for (auto ®ion : op->getRegions()) { + for (auto &block : region.getBlocks()) { + rewriter.setInsertionPointToStart(&block); + for (auto arg : block.getArguments()) { + ConvertAndWrapUsesInUnrealizedCast(arg, converter, rewriter); + } + } + } + }); + }); + return success(); +} + +// Legalize StableHLO portion of program to VHLO, leaves TFL untouched +LogicalResult ApplyStablehloToVhloPatterns(ModuleOp module, + bool is_func_legal) { + MLIRContext *context = module.getContext(); + ConversionTarget target(*context); + target.addIllegalDialect(); + target.addDynamicallyLegalDialect( + [&](auto) { return is_func_legal; }); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + + StablehloToOdmlTypeConverter converter; + RewritePatternSet patterns(context); + stablehlo::populateStablehloToVhloPatterns(&patterns, &converter, context); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + return module->emitError("Failed partial conversion to VHLO"); + } + return success(); +} + +LogicalResult ApplyVhloToVersionPatterns(ModuleOp module, + const std::string &version) { + PassManager pm(module.getContext()); + pm.addPass(stablehlo::createVhloToVersionPass({version})); + if (failed(pm.run(module))) { + return module->emitError("Failed VHLO to version") << version; + } + return success(); +} + +// Legalize VHLO portion of program to StableHLO, leaves TFL untouched. +LogicalResult ApplyVhloToStablehloPatterns(ModuleOp module) { + MLIRContext *context = module.getContext(); + ConversionTarget target(*context); + target.addIllegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + + VhloToStablehloTypeConverter converter; + RewritePatternSet patterns(context); + stablehlo::populateVhloToStablehloPatterns(&patterns, &converter, context); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + return module->emitError("Failed partial conversion to StableHLO"); + } + return success(); +} + +LogicalResult ApplyUnrealizedCastCanonicalization(ModuleOp module) { + RewritePatternSet patterns(module->getContext()); + populateReconcileUnrealizedCastsPatterns(patterns); + if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) { + return module->emitError("Failed to fold unrealized cast"); + } + return success(); +} + +} // namespace + +struct LegalizeStablehloToVhloPass + : public impl::LegalizeStablehloToVhloPassBase< + LegalizeStablehloToVhloPass> { + void runOnOperation() override { + ModuleOp module = getOperation(); + std::string target_version = "0.14.0"; + VhloToStablehloTypeConverter to_builtin_converter; + + // StableHLO --> VHLO (allow funcs) + // VHLO -> Downgrade to 0.14.0 + // VHLO Tensor --> Builtin Tensor + // Remove cast(tensor->vhlo) -> cast(vhlo->tensor) pattern + if (failed(ApplyStablehloToVhloPatterns(module, + /*is_func_legal=*/true)) || + failed(ApplyVhloToVersionPatterns(module, target_version)) || + failed(ApplyTypeConverter(module, to_builtin_converter)) || + failed(ApplyUnrealizedCastCanonicalization(module))) + return signalPassFailure(); + } +}; + +struct LegalizeVhloToStablehloPass + : public impl::LegalizeVhloToStablehloPassBase< + LegalizeVhloToStablehloPass> { + void runOnOperation() override { + // Revert the tensor types to VHLO + auto module = getOperation(); + StablehloToOdmlTypeConverter to_vhlo_converter; + + // Builtin Tensor --> VHLO Tensor + // StableHLO --> VHLO + // VHLO --> Upgrade to current + // VHLO --> StableHLO + // Remove cast(tensor->vhlo) -> cast(vhlo->tensor) pattern + if (failed(ApplyTypeConverter(module, to_vhlo_converter)) || + failed(ApplyStablehloToVhloPatterns(module, + /*is_func_legal=*/false)) || + failed(ApplyVhloToVersionPatterns(module, + stablehlo::getCurrentVersion())) || + failed(ApplyVhloToStablehloPatterns(module)) || + failed(ApplyUnrealizedCastCanonicalization(module))) + return signalPassFailure(); + } +}; + +static PassRegistration pass_s2v; +static PassRegistration pass_v2s; + +} // namespace odml +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h index 211019b70524f5..066c3d6037f726 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h @@ -47,7 +47,7 @@ CreateComposeUniformQuantizedTypePass(); // quantized typed tensors and converts them to equivalent ops in the TFLite // dialect. std::unique_ptr> -CreateUniformQuantizedStablehloToTflPass(); +CreateUniformQuantizedStableHloToTflPass(); // Create a pass that legalizes MHLO to TF dialect. std::unique_ptr> CreateLegalizeHloToTfPass(); @@ -63,6 +63,8 @@ std::unique_ptr> CreateLegalizeHloToTfLitePass(); void PopulateLegalizeHloToTfPatterns(RewritePatternSet* patterns, MLIRContext* context); +#define GEN_PASS_DECL_LEGALIZESTABLEHLOTOVHLOPASS +#define GEN_PASS_DECL_LEGALIZEVHLOTOSTABLEHLOPASS #define GEN_PASS_REGISTRATION #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.td index ca49787a715bd5..002990601a9efb 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.td @@ -58,10 +58,10 @@ def ComposeUniformQuantizedTypePass : Pass<"compose-uniform-quantized-type", "Mo ]; } -def UniformQuantizedStablehloToTflPass +def UniformQuantizedStableHloToTflPass : Pass<"uniform-quantized-stablehlo-to-tfl", "mlir::func::FuncOp"> { let summary = "Converts StableHLO ops using uniform quantized types to equivalent TFL ops."; - let constructor = "mlir::odml::CreateUniformQuantizedStablehloToTflPass()"; + let constructor = "mlir::odml::CreateUniformQuantizedStableHloToTflPass()"; let description = [{ Converts StableHLO ops that accept or return uniform quantized types to equivalent ops in the TFLite dialect. @@ -85,6 +85,17 @@ def LegalizeHloToTfLitePass : Pass<"tfl-legalize-hlo", "mlir::ModuleOp"> { let constructor = "mlir::odml::CreateLegalizeHloToTfLitePass()"; } +def LegalizeStablehloToVhloPass : Pass<"stablehlo-legalize-vhlo", "ModuleOp"> { + let summary = "Legalize StableHLO to VHLO for ODML."; + let dependentDialects = ["mlir::vhlo::VhloDialect"]; +} + +def LegalizeVhloToStablehloPass : Pass<"vhlo-legalize-stablehlo", "ModuleOp"> { + let summary = "Legalize VHLO to StableHLO for ODML."; + let dependentDialects = ["mlir::stablehlo::StablehloDialect"]; +} + + def UnfoldSplatConstantPass : Pass<"unfold-splat-constant-pass", "ModuleOp"> { let summary = "Replaces a splat constant tensor with a BroadcastInDim op."; let constructor = "mlir::odml::CreateUnfoldSplatConstantPass()"; diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc index df3a5f62e8ff59..7e01a70ec2de63 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc @@ -38,7 +38,7 @@ void AddTFToStablehloPasses(OpPassManager& pm, bool skip_resize, // if the input is a call_xla_module, then unwrap the content pm.addPass(mlir::odml::CreateLegalizeTFXlaCallModuleToStablehloPass()); - // TODO(b/230572023): Consider improving shape inference for While op instead + // TODO: b/230572023 - Consider improving shape inference for While op instead // of dropping the attribute. This need not be correct for models not trained // on TPU. @@ -85,11 +85,18 @@ void AddTFToStablehloPasses(OpPassManager& pm, bool skip_resize, } } -void AddMhloOptimizationPasses(OpPassManager& pm) { +void AddMhloOptimizationPasses(OpPassManager& pm, + const bool add_fold_broadcast_pass) { // Rewrites some patterns for better performance. pm.addNestedPass(createUnfuseBatchNormPass()); pm.addNestedPass(createFuseConvolutionPass()); pm.addNestedPass(createOptimizePass()); + // Conditionally enable below pass because this causes unfused convolutions + // described in b/293149194. This problem is not replicated in + // StableHLO Quantizer. + if (add_fold_broadcast_pass) { + pm.addNestedPass(createFoldBroadcastPass()); + } // Rewrites legacy StableHLO ops. pm.addNestedPass(mhlo::createLegalizeEinsumToDotGeneralPass()); @@ -109,8 +116,8 @@ void AddStablehloOptimizationPasses(OpPassManager& pm) { // StableHLO -> MHLO legalization. pm.addPass(mhlo::createStablehloLegalizeToHloPass()); - AddMhloOptimizationPasses(pm); - // TODO(b/293149194) Add `createFoldBroadcastPass` back to + AddMhloOptimizationPasses(pm, /*enable_stablehlo_quantizer=*/false); + // TODO: b/293149194 - Add `createFoldBroadcastPass` back to // `AddMhloOptimizationPasses` pm.addNestedPass(createFoldBroadcastPass()); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.h index 02d5d527901c9a..abcdd8276ca903 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.h @@ -36,7 +36,7 @@ void AddTFToStablehloPasses(OpPassManager& pm, bool skip_resize, void AddStablehloOptimizationPasses(OpPassManager& pm); // Adds all the backend-agonstic stableHLO optimization passes -void AddMhloOptimizationPasses(OpPassManager& pm); +void AddMhloOptimizationPasses(OpPassManager& pm, bool add_fold_broadcast_pass); } // namespace odml } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc index 3cca33c89280c4..d1245ec4d2653f 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc @@ -14,20 +14,21 @@ limitations under the License. ==============================================================================*/ #include #include -#include #include +#include #include #include "absl/algorithm/container.h" #include "absl/log/check.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project // NOLINT: Required to register quantization dialect. #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -40,9 +41,9 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "stablehlo/dialect/Base.h" // from @stablehlo #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h" #define DEBUG_TYPE "uniform-quantized-stablehlo-to-tfl" @@ -54,10 +55,14 @@ namespace { // TODO: b/311029361: Add e2e test for verifying this legalization once // StableHLO Quantizer API migration is complete. +using ::mlir::quant::CastI64ArrayToI32; +using ::mlir::quant::CastI64ToI32; using ::mlir::quant::CreateI32F32UniformQuantizedPerAxisType; using ::mlir::quant::CreateI32F32UniformQuantizedType; using ::mlir::quant::CreateI8F32UniformQuantizedPerAxisType; using ::mlir::quant::CreateI8F32UniformQuantizedType; +using ::mlir::quant::FindUserOfType; +using ::mlir::quant::IsI32F32UniformQuantizedPerAxisType; using ::mlir::quant::IsI32F32UniformQuantizedType; using ::mlir::quant::IsI8F32UniformQuantizedPerAxisType; using ::mlir::quant::IsI8F32UniformQuantizedType; @@ -68,16 +73,20 @@ using ::mlir::quant::QuantizedType; using ::mlir::quant::UniformQuantizedPerAxisType; using ::mlir::quant::UniformQuantizedType; +const char* kPaddingSame = "SAME"; +const char* kPaddingValid = "VALID"; + #define GEN_PASS_DEF_UNIFORMQUANTIZEDSTABLEHLOTOTFLPASS #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" -class UniformQuantizedStablehloToTflPass - : public impl::UniformQuantizedStablehloToTflPassBase< - UniformQuantizedStablehloToTflPass> { +class UniformQuantizedStableHloToTflPass + : public impl::UniformQuantizedStableHloToTflPassBase< + UniformQuantizedStableHloToTflPass> { private: void runOnOperation() override; }; +// TODO: b/323645515 - Refactor reference functions. // Bias scales for matmul-like ops should be input scale * filter scale. Here it // is assumed that the input is per-tensor quantized and filter is per-channel // quantized. @@ -103,7 +112,7 @@ double GetBiasScale(const double input_scale, const double filter_scale) { // whereas `tfl.fully_connected` accepts an OI format. TFL::QConstOp CreateTflConstOpForFilter( stablehlo::ConstantOp filter_constant_op, PatternRewriter& rewriter, - bool is_per_axis) { + bool is_per_channel) { const auto filter_values = filter_constant_op.getValue() .cast() .getValues(); @@ -132,7 +141,7 @@ TFL::QConstOp CreateTflConstOpForFilter( Type new_filter_quantized_type; - if (is_per_axis) { + if (is_per_channel) { auto filter_quantized_type = filter_constant_op.getResult() .getType() .cast() @@ -173,17 +182,16 @@ TFL::QConstOp CreateTflConstOpForFilter( // transformation). The quantization scale for the bias is input scale * // filter scale. `filter_const_op` is used to retrieve the filter scales and // the size of the bias constant. -// TODO - b/309896242: Support bias fusion legalization. -TFL::QConstOp CreateTflConstOpForDummyBias(const Location loc, - const double input_scale, - TFL::QConstOp filter_const_op, - PatternRewriter& rewriter, - bool is_per_axis, MLIRContext& ctx) { +// TODO - b/309896242: Support bias fusion legalization and spatial dimension +// check when `stride` is not 1. +TFL::QConstOp CreateTflConstOpForDummyBias( + const Location loc, const double input_scale, TFL::QConstOp filter_const_op, + PatternRewriter& rewriter, bool is_per_channel, MLIRContext& ctx) { const ArrayRef filter_shape = filter_const_op.getResult().getType().getShape(); Type bias_quantized_type; - if (is_per_axis) { + if (is_per_channel) { const auto filter_quantized_element_type = filter_const_op.getResult() .getType() @@ -239,7 +247,8 @@ class RewriteUniformQuantizeOp const Type input_element_type = op.getOperand().getType().cast().getElementType(); if (!(input_element_type.isa() || - IsI32F32UniformQuantizedType(input_element_type))) { + IsI32F32UniformQuantizedType(input_element_type) || + IsI32F32UniformQuantizedPerAxisType(input_element_type))) { LLVM_DEBUG(llvm::dbgs() << "Uniform quantize op's input should be a " "float type or int32. Got: " << input_element_type << ".\n"); @@ -319,94 +328,493 @@ class RewriteUniformDequantizeOp } }; -// Rewrites `stablehlo.convolution` -> `tfl.conv_2d` when it accepts uniform -// quantized tensors. +// Rewrites `stablehlo.dot_general` to `tfl.fully_connected` or +// `tfl.batch_matmul` when it accepts uniform quantized tensors. // -// Conditions for the conversion: -// * Input and output tensors are per-tensor uniform quantized (i8->f32) +// StableHLO Quantizer output: +// * input: per-tensor qi8 +// * filter: per-tensor qi8 +// * output: per-tensor qi32 +// JAX Quantizer output: +// * input: per-tensor qi8 +// * filter: per-channel qi8 +// * output: per-tensor qi8 +// +// Conditions for the `tfl.batch_matmul` conversion: +// * size(batching_dimensions) <= 3 (TFLite support restriction) +// * size(contracting_dimensions) = 1 +// * Input tensors are per-tensor uniform quantized (i8->f32) +// tensors (full integer) with shape [..., r_x, c_x] or [..., c_x, r_x]. +// * The filter tensor is a per-tensor uniform quantized (i8->f32) tensor +// (constant or activation) with shape [..., r_y, c_y] or [..., c_y, r_y]. +// * Output tensors are per-tensor uniform quantized (i8->f32) or +// per-channel uniform quantized (i32->f32) tensors. +// +// Conditions for `tfl.fully_connected` conversion: +// * Input tensors are per-tensor uniform quantized (i8->f32) // tensors. -// * The filter tensor is constant a per-channel uniform quantized (i8->f32) -// tensor. -// * Convolution is a 2D convolution op and both the input's and filter's -// shape is 4 dimensional. -// * The filter tensor's format is `[0, 1, i, o]`. -// * Not a depthwise convolution. -// * Does not consider bias add fusion. -// TODO: b/294771704 - Support bias quantization. -// TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. -class RewriteUpstreamQuantizedConvolutionOp - : public OpRewritePattern { +// * The filter tensor is constant a per-tensor uniform quantized (i8->f32) +// tensor. The quantization dimension should be 1 (the non-contracting +// dimension). +// * Output tensors are per-tensor uniform quantized (i8->f32) or +// per-channel uniform quantized (i32->f32) tensors. +// * The input tensor's rank is either 2 or 3. The last dimension of the input +// tensor should be the contracting dimension, i.e. [..., c_x, r_x]. +// * The filter tensor's rank is 2. The contracting dimension should be the +// first dimension (dim 0), i.e. [c_y, r_y] where c_y == r_x. +// TODO: b/309896242 - Add support for fused op case. Add support for +// per-channel quantization. +// TODO: b/295264927 - `stablehlo.dot_general` with per-axis quantized operands +// is not specified in the StableHLO dialect. Update the spec to allow this. +class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + public: - using OpRewritePattern::OpRewritePattern; + LogicalResult match(stablehlo::DotGeneralOp op) const override { + const stablehlo::DotDimensionNumbersAttr dot_dimension_nums = + op.getDotDimensionNumbers(); + const bool is_batch_matmul = + !dot_dimension_nums.getLhsBatchingDimensions().empty(); + const bool has_i32_output = IsI32F32UniformQuantizedType( + op.getResult().getType().cast().getElementType()); - static LogicalResult MatchInput(Value input) { - auto input_type = input.getType().cast(); - if (input_type.getRank() != 4) { - LLVM_DEBUG(llvm::dbgs() << "Only 2D convolution op is supported. " - "Expected input rank of 4. Got: " - << input_type.getRank() << ".\n"); + if (failed(MatchInputDotGeneralCommonPattern(op.getLhs()))) { + LLVM_DEBUG(llvm::dbgs() + << "Failed to match input for quantized dot_general.\n"); + return failure(); + } + if (failed(MatchFilterCommonPattern(op.getRhs()))) { + LLVM_DEBUG(llvm::dbgs() + << "Failed to match filter for quantized dot_general.\n"); + return failure(); + } + if (failed(MatchOutput(op.getResult(), has_i32_output))) { + LLVM_DEBUG(llvm::dbgs() + << "Failed to match output for quantized dot_general.\n"); return failure(); } - if (const auto input_element_type = input_type.getElementType(); - !IsI8F32UniformQuantizedType(input_element_type)) { + if (is_batch_matmul) { + return MatchDotGeneralToTflBatchMatmulOp(op, dot_dimension_nums, + has_i32_output); + } + return MatchDotGeneralToTflFullyConnectedOp(op, dot_dimension_nums, + has_i32_output); + } + + void rewrite(stablehlo::DotGeneralOp op, + PatternRewriter& rewriter) const override { + const bool has_i32_output = IsI32F32UniformQuantizedType( + op.getResult().getType().cast().getElementType()); + const stablehlo::DotDimensionNumbersAttr dot_dimension_nums = + op.getDotDimensionNumbers(); + const bool is_batch_matmul = + !dot_dimension_nums.getLhsBatchingDimensions().empty(); + + if (is_batch_matmul) { + RewriteDotGeneralToTflBatchMatmulOp(op, rewriter, dot_dimension_nums, + has_i32_output); + } else { + RewriteDotGeneralToTflFullyConnectedOp(op, rewriter, dot_dimension_nums, + has_i32_output); + } + } + + private: + static LogicalResult MatchDotGeneralToTflBatchMatmulOp( + stablehlo::DotGeneralOp op, + const stablehlo::DotDimensionNumbersAttr dot_dimension_nums, + const bool has_i32_output) { + if (has_i32_output && !HasOneUseByQuantizeOp(op)) { LLVM_DEBUG(llvm::dbgs() - << "Expected an i8->f32 uniform quantized type. Got: " - << input_element_type << ".\n"); + << "When output type of dot_general is qi32, it should have " + "only one use of requantization.\n"); + return failure(); + } + + const int num_lhs_batching_dims = + dot_dimension_nums.getLhsBatchingDimensions().size(); + const int num_lhs_contracting_dims = + dot_dimension_nums.getLhsContractingDimensions().size(); + if (num_lhs_batching_dims > 3) { + LLVM_DEBUG(llvm::dbgs() << "Failed to match batching dimension for " + "quantized dot_general.\n"); + return failure(); + } + // Checking one side is enough since + // (C1) size(lhs_batching_dimensions) = size(rhs_batching_dimensions). + if (num_lhs_contracting_dims != 1) { + // Check one side is enough since + // (C2) size(lhs_contracting_dimensions) = + // size(rhs_contracting_dimensions). + LLVM_DEBUG(llvm::dbgs() << "Failed to match contract dimension for " + "quantized dot_general.\n"); + return failure(); + } + const auto input_type = op.getLhs().getType().cast(); + const int input_rank = input_type.getRank(); + const auto input_contracting_dim = + dot_dimension_nums.getLhsContractingDimensions()[0]; + if ((input_contracting_dim != input_rank - 1) && + (input_contracting_dim != input_rank - 2)) { + LLVM_DEBUG(llvm::dbgs() + << "Failed to match input contracting dimensions.\n"); return failure(); } + const auto filter_type = op.getRhs().getType().cast(); + const Type filter_element_type = filter_type.getElementType(); + if (!IsI8F32UniformQuantizedType(filter_element_type)) { + LLVM_DEBUG(llvm::dbgs() + << "Expected a per-tensor uniform " + "quantized (i8->f32) weight for dot_general. Got: " + << filter_type << "\n"); + return failure(); + } + const int rhs_rank = filter_type.cast().getRank(); + const auto rhs_contracting_dim = + dot_dimension_nums.getRhsContractingDimensions()[0]; + if ((rhs_contracting_dim != rhs_rank - 1) && + (rhs_contracting_dim != rhs_rank - 2)) { + LLVM_DEBUG(llvm::dbgs() + << "Not supported rhs contracting dim for dot_general.\n"); + return failure(); + } return success(); } - static LogicalResult MatchFilter(Value filter) { - auto filter_type = filter.getType().cast(); - if (filter_type.getRank() != 4) { - LLVM_DEBUG(llvm::dbgs() << "Only 2D convolution op is supported. " - "Expected filter rank of 4. Got: " - << filter_type.getRank() << ".\n"); + static LogicalResult MatchDotGeneralToTflFullyConnectedOp( + stablehlo::DotGeneralOp op, + const stablehlo::DotDimensionNumbersAttr dot_dimension_nums, + const bool has_i32_output) { + const int num_lhs_contracting_dims = + dot_dimension_nums.getLhsContractingDimensions().size(); + const int num_rhs_contracting_dims = + dot_dimension_nums.getRhsContractingDimensions().size(); + if (num_lhs_contracting_dims != 1 || num_rhs_contracting_dims != 1) { + LLVM_DEBUG(llvm::dbgs() + << "Expected number of contracting dimensions to be 1. Got: " + << num_rhs_contracting_dims << ".\n"); return failure(); } - const Type filter_element_type = filter_type.getElementType(); - if (!IsI8F32UniformQuantizedPerAxisType(filter_type.getElementType())) { - LLVM_DEBUG( - llvm::dbgs() - << "Expected a per-channel uniform quantized (i8->f32) type. Got: " - << filter_element_type << "\n"); + const auto input_type = op.getLhs().getType().cast(); + if (!(input_type.getRank() == 2 || input_type.getRank() == 3)) { + LLVM_DEBUG(llvm::dbgs() << "Input expected to have rank of 2 or 3. Got: " + << input_type << ".\n"); return failure(); } - if (filter_element_type.cast() - .getQuantizedDimension() != 3) { - LLVM_DEBUG(llvm::dbgs() << "Quantized dimension should be 3. Got: " - << filter_element_type << "\n"); + const auto filter_type = op.getRhs().getType().cast(); + if (filter_type.getRank() != 2) { + LLVM_DEBUG(llvm::dbgs() + << "Filter tensor expected to have a tensor rank of 2. Got: " + << filter_type << ".\n"); return failure(); } + if (has_i32_output) { + if (!IsI8F32UniformQuantizedType(filter_type.getElementType())) { + LLVM_DEBUG(llvm::dbgs() << "Expected a per-channel uniform quantized " + "(i8->f32) type. Got: " + << filter_type.getElementType() << "\n"); + return failure(); + } + } else { + if (!IsI8F32UniformQuantizedPerAxisType(filter_type.getElementType())) { + LLVM_DEBUG(llvm::dbgs() << "Expected a per-channel uniform quantized " + "(i8->f32) type. Got: " + << filter_type.getElementType() << "\n"); + return failure(); + } + } - if (Operation* filter_op = filter.getDefiningOp(); - filter_op == nullptr || !isa(filter_op)) { - LLVM_DEBUG(llvm::dbgs() << "Filter should be a constant.\n"); + // If the op has a fusible bias, make sure the bias is a constant. + if (auto add_op = FindUserOfType(op); + add_op != nullptr && + !isa(add_op->getOperand(1).getDefiningOp())) { + LLVM_DEBUG(llvm::dbgs() << "Expected a `stablehlo.constant` as the " + << "rhs of `stablehlo.add`.\n"); + } + + return success(); + } + + static LogicalResult MatchInputDotGeneralCommonPattern(const Value input) { + const auto input_type = input.getType().cast(); + if (const auto input_element_type = input_type.getElementType(); + !IsI8F32UniformQuantizedType(input_element_type)) { + LLVM_DEBUG(llvm::dbgs() + << "Expected an i8->f32 uniform quantized type. Got: " + << input_element_type << ".\n"); return failure(); } + if (!input_type.hasRank()) { + LLVM_DEBUG(llvm::dbgs() << "Expected input_type to have rank.\n"); + return failure(); + } return success(); } - static LogicalResult MatchOutput(Value output) { + static LogicalResult MatchFilterCommonPattern(const Value filter) { + auto filter_type = filter.getType().cast(); + if (!filter_type.hasRank()) { + LLVM_DEBUG(llvm::dbgs() << "Expected rhs of dot_general has rank. Got: " + << filter.getType() << "\n"); + return failure(); + } + return success(); + } + + static LogicalResult MatchOutput(const Value output, + const bool has_i32_output) { const Type output_element_type = output.getType().cast().getElementType(); + if (has_i32_output) { + if (!IsI32F32UniformQuantizedType(output_element_type)) { + LLVM_DEBUG( + llvm::dbgs() + << "Expected a per-tensor uniform quantized (i32->f32) type. Got: " + << output_element_type << ".\n"); + return failure(); + } + return success(); + } if (!IsI8F32UniformQuantizedType(output_element_type)) { - LLVM_DEBUG(llvm::dbgs() - << "Expected a uniform quantized (i8->f32) type. Got: " - << output_element_type << ".\n"); + LLVM_DEBUG( + llvm::dbgs() + << "Expected a per-tensor uniform quantized (i8->f32) type. Got: " + << output_element_type << ".\n"); return failure(); } - return success(); } + static void RewriteDotGeneralToTflBatchMatmulOp( + stablehlo::DotGeneralOp op, PatternRewriter& rewriter, + const stablehlo::DotDimensionNumbersAttr dot_dimension_nums, + const bool has_i32_output) { + const auto rhs_contracting_dims = + dot_dimension_nums.getRhsContractingDimensions(); + const auto lhs_contracting_dims = + dot_dimension_nums.getLhsContractingDimensions(); + + const Value rhs_value = op.getRhs(); + const Value lhs_value = op.getLhs(); + + Operation* rhs_op = rhs_value.getDefiningOp(); + auto filter_constant_op = dyn_cast_or_null(rhs_op); + + // Set to `nullptr` because this attribute only matters when the input is + // dynamic-range quantized. + const BoolAttr asymmetric_quantize_inputs = nullptr; + + const int lhs_rank = lhs_value.getType().cast().getRank(); + const BoolAttr adj_x = + (lhs_contracting_dims[0] == lhs_rank - 2 ? rewriter.getBoolAttr(true) + : rewriter.getBoolAttr(false)); + const int rhs_rank = rhs_value.getType().cast().getRank(); + const BoolAttr adj_y = + (rhs_contracting_dims[0] == rhs_rank - 1 ? rewriter.getBoolAttr(true) + : rewriter.getBoolAttr(false)); + + Value result = op.getResult(); + Operation* result_user_op = *op->getUsers().begin(); + if (isa(result_user_op) || + isa(result_user_op)) { + result = result_user_op->getResult(0); + } + + // Create BMM assuming rhs is activation. + auto tfl_batchmatmul_op = rewriter.create( + op.getLoc(), /*output=*/result.getType(), + /*input=*/lhs_value, + /*filter=*/rhs_value, adj_x, adj_y, asymmetric_quantize_inputs); + + // Update BMM if rhs is a constant. + if (filter_constant_op != nullptr) { + const auto rhs_uniform_quantized_type = + rhs_value.getType().cast(); + const auto rhs_constant_value_attr = + cast(filter_constant_op.getValue()); + auto rhs_constant_op = rewriter.create( + rhs_op->getLoc(), + /*output=*/TypeAttr::get(rhs_uniform_quantized_type), + rhs_constant_value_attr); + tfl_batchmatmul_op = rewriter.create( + op.getLoc(), /*output=*/result.getType(), + /*input=*/lhs_value, /*filter=*/rhs_constant_op.getResult(), adj_x, + adj_y, asymmetric_quantize_inputs); + } + + rewriter.replaceAllUsesWith(result, tfl_batchmatmul_op.getResult()); + } + + static void RewriteDotGeneralToTflFullyConnectedOp( + stablehlo::DotGeneralOp op, PatternRewriter& rewriter, + const stablehlo::DotDimensionNumbersAttr dot_dimension_nums, + const bool has_i32_output) { + const Value rhs_value = op.getRhs(); + const Value lhs_value = op.getLhs(); + + Operation* rhs_op = rhs_value.getDefiningOp(); + const auto filter_constant_op = + dyn_cast_or_null(rhs_op); + + // Set to `nullptr` because this attribute only matters when the input is + // dynamic-range quantized. + const BoolAttr asymmetric_quantize_inputs = nullptr; + + // Checks for `tfl.fully_connected` condition. + + // StableHLO Quantizer does not yet support per-channel quantization of + // dot_general. + const bool is_per_channel = !has_i32_output; + // Create the new filter constant - transpose filter value + // from [i, o] -> [o, i]. This is because we assume `[i, o]` format for + // `stablehlo.dot_general` (i.e. contracting dimension == 1) whereas + // `tfl.fully_connected` accepts an OI format. + TFL::QConstOp new_filter_constant_op = + CreateTflConstOpForFilter(filter_constant_op, rewriter, is_per_channel); + + const double input_scale = lhs_value.getType() + .cast() + .getElementType() + .cast() + .getScale(); + TFL::QConstOp bias_tfl_op; + bool fuse_bias_constant = + FindUserOfType(op) && has_i32_output; + // Get the desired output type and extract any existing fusible bias + // as `TFL::QConstOp` so that it can be fused with TFL::FullyConnectedOp`. + TensorType output_type = GetOutputTypeAndOptionallyUpdateBias( + op, rewriter, &bias_tfl_op, has_i32_output, fuse_bias_constant); + + // If there is no explicit bias, create a dummy value filled with zeroes. + if (!fuse_bias_constant) { + bias_tfl_op = CreateTflConstOpForDummyBias( + op.getLoc(), input_scale, new_filter_constant_op, rewriter, + is_per_channel, *op.getContext()); + } + rewriter.replaceOpWithNewOp( + op, /*output=*/output_type, + /*input=*/lhs_value, + /*filter=*/new_filter_constant_op.getResult(), + /*bias=*/bias_tfl_op.getResult(), + /*fused_activation_function=*/rewriter.getStringAttr("NONE"), + /*weights_format=*/rewriter.getStringAttr("DEFAULT"), + /*keep_num_dims=*/rewriter.getBoolAttr(false), + asymmetric_quantize_inputs); + } + + static TensorType GetOutputTypeAndOptionallyUpdateBias( + Operation* op, PatternRewriter& rewriter, TFL::QConstOp* bias_tfl_op, + const bool has_i32_output, const bool fuse_bias_constant) { + TensorType output_type; + if (has_i32_output) { + Operation* uniform_quantize_op; + if (fuse_bias_constant) { + Operation* add_op = FindUserOfType(op); + uniform_quantize_op = FindUserOfType(add_op); + auto filter_quantized_type = op->getOperand(1) + .getType() + .cast() + .getElementType() + .cast(); + double bias_scale = GetBiasScale( + /*input_scale=*/op->getOperand(0) + .getType() + .cast() + .getElementType() + .cast() + .getScale(), + /*filter_scale=*/filter_quantized_type.getScale()); + ArrayRef output_shape = + op->getResult(0).getType().cast().getShape(); + const SmallVector bias_shape = { + output_shape[output_shape.size() - 1]}; + auto bias_quantized_type = CreateI32F32UniformQuantizedType( + op->getLoc(), *op->getContext(), std::move(bias_scale), + op->getResult(0) + .getType() + .cast() + .getElementType() + .cast() + .getZeroPoint()); + Operation* stablehlo_bias_op = add_op->getOperand(1).getDefiningOp(); + auto bias_type = RankedTensorType::getChecked(op->getLoc(), bias_shape, + bias_quantized_type); + auto bias_value = cast( + cast(stablehlo_bias_op).getValue()); + + *bias_tfl_op = rewriter.create( + op->getLoc(), + /*output=*/TypeAttr::get(bias_type), /*value=*/bias_value); + } else { + uniform_quantize_op = FindUserOfType(op); + } + + auto result_quantized_type = uniform_quantize_op->getResult(0) + .getType() + .cast() + .getElementType() + .cast(); + auto new_result_quantized_type = CreateI8F32UniformQuantizedType( + uniform_quantize_op->getLoc(), *rewriter.getContext(), + result_quantized_type.getScale(), + result_quantized_type.getZeroPoint()); + output_type = op->getResult(0).getType().cast().clone( + new_result_quantized_type); + // Omit any bias and requantize ops as `tfl.fully_connected` outputs a + // fused `qi8` type. + FindUserOfType<>(uniform_quantize_op)->setOperand(0, op->getResult(0)); + } else { + output_type = op->getResult(0).getType().cast(); + } + return output_type; + } + + static bool HasOneUseByQuantizeOp(Operation* op) { + return op->hasOneUse() && + (FindUserOfType(op) != nullptr || + FindUserOfType(op) != nullptr); + } +}; + +// Rewrites `stablehlo.convolution` into fused `tfl.conv_2d`. +// If available, fuse bias and activation adjacent to `stablehlo.convolution`. +// This RewritePattern rewrites both the following into `tfl.conv_2d` op: +// +// StableHLO Quantizer output: +// * input: per-tensor qi8 +// * filter: per-channel qi8 (`quantization_dimension` = 3) +// * output: per-channel qi32 (`quantization_dimension` = 3) +// JAX Quantizer output: +// * input: per-tensor qi8 +// * filter: per-channel qi8 (`quantization_dimension` = 3) +// * output: per-tensor qi8 +// +// Conditions for the conversion: +// * Input tensors are per-tensor uniform quantized (i8->f32) +// tensors. +// * The filter tensor is constant a per-channel uniform quantized (i8->f32) +// tensor. +// * Output tensors are per-tensor uniform quantized (i8->f32) or +// per-channel uniform quantized (i32->f32) tensors. +// * Convolution is a 2D convolution op and both the input's and filter's +// shape is 4 dimensional. +// * The filter tensor's format is `[0, 1, i, o]`. +// * Not a depthwise convolution. +class RewriteQuantizedConvolutionOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; LogicalResult match(stablehlo::ConvolutionOp op) const override { + const bool has_i32_output = IsI32F32UniformQuantizedPerAxisType( + op.getResult().getType().cast().getElementType()); + const bool fuse_bias_constant = + FindUserOfType(op) && has_i32_output; stablehlo::ConvDimensionNumbersAttr dimension_numbers = op.getDimensionNumbers(); @@ -444,14 +852,39 @@ class RewriteUpstreamQuantizedConvolutionOp return failure(); } + // TODO: b/309896242 - Lift the assumptions on adjacent ops below + // as we cover more dynamic fused pattern legalization. + if (fuse_bias_constant) { + Operation* add_op = FindUserOfType(op); + if (add_op == nullptr) { + LLVM_DEBUG(llvm::dbgs() << "Failed to find AddOp for bias fusion.\n"); + return failure(); + } + Operation* broadcast_in_dim_op = add_op->getOperand(1).getDefiningOp(); + if (!isa(broadcast_in_dim_op)) { + LLVM_DEBUG(llvm::dbgs() << "Failed to find broadcasted bias.\n"); + return failure(); + } + Operation* bias_const_op = + broadcast_in_dim_op->getOperand(0).getDefiningOp(); + if (!isa(bias_const_op)) { + LLVM_DEBUG(llvm::dbgs() << "Failed to find bias constant.\n"); + return failure(); + } + } + return success(); } void rewrite(stablehlo::ConvolutionOp op, PatternRewriter& rewriter) const override { + const bool has_i32_output = IsI32F32UniformQuantizedPerAxisType( + op.getResult().getType().cast().getElementType()); + stablehlo::ConvDimensionNumbersAttr dimension_numbers = + op.getDimensionNumbers(); + Value filter_value = op.getOperand(1); Operation* filter_op = filter_value.getDefiningOp(); - auto filter_uniform_quantized_type = filter_value.getType() .cast() @@ -489,43 +922,33 @@ class RewriteUpstreamQuantizedConvolutionOp filter_op->getLoc(), /*output=*/TypeAttr::get(new_filter_result_type), new_filter_value_attr); - SmallVector bias_scales = - GetBiasScales(/*input_scale=*/op.getOperand(0) - .getType() - .cast() - .getElementType() - .cast() - .getScale(), - /*filter_scales=*/new_filter_quantized_type.getScales()); + Operation* uniform_quantize_op; + const bool fuse_bias_constant = + FindUserOfType(op) && has_i32_output; + if (has_i32_output) { + if (fuse_bias_constant) { + Operation* add_op = FindUserOfType(op); + uniform_quantize_op = FindUserOfType(add_op); + } else { + uniform_quantize_op = FindUserOfType(op); + } + } - // Create a bias filled with zeros. Mimics the behavior of no bias add. const int64_t num_output_features = new_filter_result_type.getShape()[0]; const SmallVector bias_shape = {num_output_features}; - auto bias_quantized_type = CreateI32F32UniformQuantizedPerAxisType( - op.getLoc(), *op.getContext(), std::move(bias_scales), - new_filter_quantized_type.getZeroPoints(), - /*quantization_dimension=*/0); - auto bias_type = RankedTensorType::getChecked(op.getLoc(), bias_shape, - bias_quantized_type); - - // Create a bias constant. It should have values of 0. - auto bias_value_type = RankedTensorType::getChecked(op.getLoc(), bias_shape, - rewriter.getI32Type()); - auto bias_value = DenseIntElementsAttr::get( - bias_value_type, APInt(/*numBits=*/32, /*value=*/0, /*isSigned=*/true)); - auto bias = rewriter.create( - op.getLoc(), /*output=*/TypeAttr::get(bias_type), - /*value=*/bias_value); + + TFL::QConstOp bias = GetBiasOp(op, rewriter, new_filter_result_type, + new_filter_quantized_type, bias_shape, + has_i32_output, fuse_bias_constant); // Determine the attributes for the TFL::Conv2DOp. - // TODO: b/294808863 - Use `padding = "SAME"` if the padding attribute - // matches the semantics. + Value input_value = op.getOperand(0); if (const DenseIntElementsAttr padding_attr = op.getPaddingAttr(); - !IsPaddingValid(padding_attr)) { + !HasProperPadding(op, dimension_numbers, padding_attr)) { // Add an extra tfl.pad_op if there are explicit padding values. This - // extra pad op will allow us to always set the `padding` attribute of the - // newly created tfl.conv_2d op as "VALID". + // extra pad op will allow us to always set the `padding` attribute of + // the newly created tfl.conv_2d op as "VALID". TFL::PadOp pad_op = CreateTflPadOp(op.getLoc(), padding_attr, input_value, rewriter); input_value = pad_op.getResult(); @@ -534,465 +957,50 @@ class RewriteUpstreamQuantizedConvolutionOp const auto [stride_h, stride_w] = GetStrides(op); const auto [dilation_h_factor, dilation_w_factor] = GetDilationFactors(op); - auto tfl_conv2d_op = rewriter.create( - op.getLoc(), /*output=*/op.getResult().getType(), /*input=*/input_value, + Type output_type; + if (has_i32_output) { + // StableHLO Quantizer outputs an i32 type. Rewrite to i8 type result + // to meet TFLite op requirement. + auto result_quantized_type = uniform_quantize_op->getResult(0) + .getType() + .cast() + .getElementType() + .cast(); + auto new_result_quantized_type = CreateI8F32UniformQuantizedType( + uniform_quantize_op->getLoc(), *rewriter.getContext(), + result_quantized_type.getScale(), + result_quantized_type.getZeroPoint()); + output_type = op.getResult().getType().cast().clone( + new_result_quantized_type); + // Omit any bias and requantize ops as `tfl.fully_connected` outputs a + // fused `qi8` type. + FindUserOfType<>(uniform_quantize_op)->setOperand(0, op->getResult(0)); + } else { + output_type = op.getResult().getType(); + } + rewriter.replaceOpWithNewOp( + // op result should be recasted to desired quantized type. + op, output_type, + /*input=*/input_value, /*filter=*/new_filter_constant_op, /*bias=*/bias.getResult(), /*dilation_h_factor=*/rewriter.getI32IntegerAttr(dilation_h_factor), /*dilation_w_factor=*/rewriter.getI32IntegerAttr(dilation_w_factor), /*fused_activation_function=*/rewriter.getStringAttr("NONE"), - /*padding=*/rewriter.getStringAttr("VALID"), + /*padding=*/ + rewriter.getStringAttr(UseSamePadding(op, dimension_numbers) + ? kPaddingSame + : kPaddingValid), /*stride_h=*/rewriter.getI32IntegerAttr(stride_h), /*stride_w=*/rewriter.getI32IntegerAttr(stride_w)); - - rewriter.replaceAllUsesWith(op.getResult(), tfl_conv2d_op.getResult()); - rewriter.eraseOp(op); - } - - private: - // Create a `tfl.pad` op to apply explicit padding to the input tensor that - // correspond to the `padding` attribute from the `stablehlo.convolution` op. - TFL::PadOp CreateTflPadOp(Location loc, - const DenseIntElementsAttr& padding_attr, - Value input_value, - PatternRewriter& rewriter) const { - auto padding_values = padding_attr.getValues(); - // [[h_l, h_r], [w_l, w_r]]. - DCHECK_EQ(padding_attr.size(), 4); - - // In StableHLO the padding attribute doesn't include the padding values for - // input and output feature dimensions (because they are 0 anyways). In - // TFLite, padding values for input and output feature dimensions should be - // explicitly set to 0s. Note that TFLite's input tensor is formatted as - // OHWI. The resulting pad values becomes: [[0, 0], [h_l, h_r], [w_l, w_r], - // [0, 0]] - SmallVector tfl_pad_values = {0, 0}; // For output feature dim. - for (const int64_t padding_value : padding_values) { - tfl_pad_values.push_back(static_cast(padding_value)); - } - // For input feature dim. - tfl_pad_values.push_back(0); - tfl_pad_values.push_back(0); - - const auto input_tensor_type = - input_value.getType().cast(); - const int64_t rank = input_tensor_type.getRank(); - - SmallVector padded_output_tensor_shape = - InferPaddedTensorShape(input_tensor_type.getShape(), tfl_pad_values); - - auto padded_output_tensor_type = RankedTensorType::get( - padded_output_tensor_shape, input_tensor_type.getElementType()); - - // The pad values is provided as a const op. - auto pad_value_const_op = rewriter.create( - loc, /*value=*/DenseIntElementsAttr::get( - RankedTensorType::get({rank, 2}, rewriter.getIntegerType(32)), - tfl_pad_values)); - - return rewriter.create( - loc, /*output=*/padded_output_tensor_type, input_value, - /*padding=*/pad_value_const_op.getResult()); - } - - // Infers the output tensor's shape after padding `tfl_pad_values` to the - // `tensor_shape`. `tfl_pad_values` should be formatted as `[[l_0, r_0], [l_1, - // r_1], ..., [l_n, r_n]]`, where `l_x` and `r_x` are the left and paddings - // for the x-th dimension, respectively. - SmallVector InferPaddedTensorShape( - const ArrayRef tensor_shape, - const ArrayRef tfl_pad_values) const { - SmallVector padded_shape(tensor_shape.begin(), tensor_shape.end()); - for (int i = 0; i < padded_shape.size(); ++i) { - // Left padding + right padding. - const int32_t padded = tfl_pad_values[i * 2] + tfl_pad_values[i * 2 + 1]; - padded_shape[i] += padded; - } - - return padded_shape; - } - - // Transposes the filter tensor to match the filter tensor format for - // `tfl.conv_2d`. This function performs the following index permutation - // only: (3, 0, 1, 2). The filter value is assumed to be of `[0, 1, i, o]` - // format. The `tfl.conv_2d` accepts the filter of `[o, 0, 1, i]`. - // TODO: b/291598373 - Lift the assumption about the filter tensor's format - // and generalize the transpose. - DenseIntElementsAttr TransposeFilterValue( - Location loc, PatternRewriter& rewriter, - const DenseIntElementsAttr& filter_value_attr) const { - ArrayRef filter_shape = - filter_value_attr.getShapedType().getShape(); - SmallVector filter_constant_values; - for (const auto filter_val : filter_value_attr.getValues()) { - filter_constant_values.push_back(filter_val); - } - - SmallVector new_filter_constant_values( - filter_constant_values.size(), 0); - - SmallVector new_filter_shape; - SmallVector transpose_dims = {3, 0, 1, 2}; - for (int i = 0; i < filter_shape.size(); ++i) { - new_filter_shape.push_back(filter_shape[transpose_dims[i]]); - } - - auto get_array_idx = [](ArrayRef shape, const int i, const int j, - const int k, const int l) -> int64_t { - return (i * shape[1] * shape[2] * shape[3]) + (j * shape[2] * shape[3]) + - (k * shape[3]) + l; - }; - - // Transpose the filter value. - for (int i = 0; i < filter_shape[0]; ++i) { - for (int j = 0; j < filter_shape[1]; ++j) { - for (int k = 0; k < filter_shape[2]; ++k) { - for (int l = 0; l < filter_shape[3]; ++l) { - // [i][j][k][l] -> [l][i][j][k] - const int old_idx = get_array_idx(filter_shape, i, j, k, l); - const int new_idx = get_array_idx(new_filter_shape, l, i, j, k); - - new_filter_constant_values[new_idx] = - filter_constant_values[old_idx]; - } - } - } - } - - // Create the new filter constant. - auto new_filter_value_attr_type = - RankedTensorType::getChecked(loc, new_filter_shape, - /*elementType=*/rewriter.getI8Type()); - auto new_filter_constant_value_attr = DenseIntElementsAttr::get( - new_filter_value_attr_type, new_filter_constant_values); - - return new_filter_constant_value_attr; - } - - // Determines if the padding attribute corresponds to "VALID" - // (https://www.tensorflow.org/api_docs/python/tf/nn). - bool IsPaddingValid(const DenseIntElementsAttr& padding_attr) const { - // If padding_attr is empty, it defaults to splat 0s. - return !padding_attr || (padding_attr.isSplat() && - padding_attr.getSplatValue() == 0); - } - - // Returns the stride amount for the height and width, respectively. - std::pair GetStrides(stablehlo::ConvolutionOp op) const { - const Attribute window_strides_attr = op.getWindowStridesAttr(); - if (!window_strides_attr) { - return {1, 1}; // Default values. - } - - const auto window_strides_attr_value = - hlo::getI64Array(window_strides_attr); - // It is guaranteed from the spec that it has two values: - // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convolution. - return {window_strides_attr_value[0], window_strides_attr_value[1]}; - } - - // Returns the dilation amount for the height and width, respectively. - std::pair GetDilationFactors( - stablehlo::ConvolutionOp op) const { - const Attribute lhs_dilation_attr = op.getLhsDilationAttr(); - if (!lhs_dilation_attr) { - return {1, 1}; // Default values. - } - - const auto lhs_dilation_attr_value = hlo::getI64Array(lhs_dilation_attr); - // It is guaranteed from the spec that it has two values: - // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convolution. - return {lhs_dilation_attr_value[0], lhs_dilation_attr_value[1]}; - } -}; - -// Rewrites full-integer quantized `stablehlo.dot_general` ->`tfl.batch_matmul` -// when it accepts uniform quantized tensors. -// -// Since transpose and reshape of quantized tensors are not natively supported -// at the moment, the conversion condition is relatively strict, following -// (https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-mat-mul-v3) -// -// Conditions for the conversion : -// * size(batching_dimensions) <= 3 (TFLite support restriction) -// * size(contracting_dimensions) = 1 -// * Input (lhs) and output tensors are per-tensor uniform quantized (i8->f32) -// tensors (full integer) with shape [..., r_x, c_x] or [..., c_x, r_x]. -// * The rhs tensor is a per-tensor uniform quantized (i8->f32) tensor -// (constant or activation) with shape [..., r_y, c_y] or [..., c_y, r_y]. -// -// TODO: b/293650675 - Relax the conversion condition to support dot_general in -// general. -// TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. -class RewriteUpstreamQuantizedDotGeneralOpToBatchMatmulOp - : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - static LogicalResult MatchLhs( - Value lhs, stablehlo::DotDimensionNumbersAttr dimension_numbers) { - auto lhs_type = lhs.getType().cast(); - if (!IsI8F32UniformQuantizedType(lhs_type.getElementType())) { - LLVM_DEBUG(llvm::dbgs() - << "Expected a per-tensor uniform " - "quantized (i8->f32) input for dot_general. Got: " - << lhs_type << "\n"); - return failure(); - } - if (!lhs_type.hasRank()) { - LLVM_DEBUG(llvm::dbgs() << "Expected lhs of dot_general has rank. Got: " - << lhs_type << "\n"); - return failure(); - } - const int lhs_rank = lhs_type.getRank(); - auto lhs_contracting_dim = - dimension_numbers.getLhsContractingDimensions()[0]; - if ((lhs_contracting_dim != lhs_rank - 1) && - (lhs_contracting_dim != lhs_rank - 2)) { - LLVM_DEBUG(llvm::dbgs() - << "Not supported lhs contracting dim for dot_general.\n"); - return failure(); - } - return success(); - } - - static LogicalResult MatchRhs( - Value rhs, stablehlo::DotDimensionNumbersAttr dimension_numbers) { - if (!rhs.getType().cast().hasRank()) { - LLVM_DEBUG(llvm::dbgs() << "Expected rhs of dot_general has rank. Got: " - << rhs.getType() << "\n"); - return failure(); - } - const int rhs_rank = rhs.getType().cast().getRank(); - auto rhs_contracting_dim = - dimension_numbers.getRhsContractingDimensions()[0]; - if ((rhs_contracting_dim != rhs_rank - 1) && - (rhs_contracting_dim != rhs_rank - 2)) { - LLVM_DEBUG(llvm::dbgs() - << "Not supported rhs contracting dim for dot_general.\n"); - return failure(); - } - - auto rhs_type = rhs.getType().cast(); - if (!IsI8F32UniformQuantizedType(rhs_type.getElementType())) { - LLVM_DEBUG(llvm::dbgs() - << "Expected a per-tensor uniform " - "quantized (i8->f32) weight for dot_general. Got: " - << rhs_type << "\n"); - return failure(); - } - return success(); - } - - static LogicalResult MatchOutput( - Value output, stablehlo::DotDimensionNumbersAttr dimension_numbers) { - auto output_type = output.getType().cast(); - if (!IsI8F32UniformQuantizedType(output_type.getElementType())) { - LLVM_DEBUG(llvm::dbgs() - << "Expected a per-tensor uniform " - "quantized (i8->f32) output for dot_general. Got: " - << output_type << "\n"); - return failure(); - } - return success(); - } - - LogicalResult match(stablehlo::DotGeneralOp op) const override { - stablehlo::DotDimensionNumbersAttr dimension_numbers = - op.getDotDimensionNumbers(); - - // Check one side is enough since - // (C1) size(lhs_batching_dimensions) = size(rhs_batching_dimensions). - if (dimension_numbers.getLhsBatchingDimensions().size() > 3) { - LLVM_DEBUG( - llvm::dbgs() - << "Failed to match batch dimention for quantized dot_general.\n"); - return failure(); - } - // Check one side is enough since - // (C2) size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions). - if (dimension_numbers.getLhsContractingDimensions().size() != 1) { - LLVM_DEBUG( - llvm::dbgs() - << "Failed to match contract dimention for quantized dot_general.\n"); - return failure(); - } - - if (failed(MatchLhs(op.getLhs(), dimension_numbers))) { - LLVM_DEBUG(llvm::dbgs() - << "Failed to match input for quantized dot_general.\n"); - return failure(); - } - if (failed(MatchRhs(op.getRhs(), dimension_numbers))) { - LLVM_DEBUG(llvm::dbgs() - << "Failed to match weight for quantized dot_general.\n"); - return failure(); - } - - if (failed(MatchOutput(op.getResult(), dimension_numbers))) { - LLVM_DEBUG(llvm::dbgs() - << "Failed to match output for quantized dot_general.\n"); - return failure(); - } - - return success(); - } - - void rewrite(stablehlo::DotGeneralOp op, - PatternRewriter& rewriter) const override { - Value rhs_value = op.getRhs(); - Operation* rhs_op = rhs_value.getDefiningOp(); - - stablehlo::DotDimensionNumbersAttr dimension_numbers = - op.getDotDimensionNumbers(); - Value input_value = op.getLhs(); - const int lhs_rank = input_value.getType().cast().getRank(); - auto lhs_contracting_dim = - dimension_numbers.getLhsContractingDimensions()[0]; - BoolAttr adj_x = - (lhs_contracting_dim == lhs_rank - 2 ? rewriter.getBoolAttr(true) - : rewriter.getBoolAttr(false)); - auto rhs_contracting_dim = - dimension_numbers.getRhsContractingDimensions()[0]; - const int rhs_rank = rhs_value.getType().cast().getRank(); - BoolAttr adj_y = - (rhs_contracting_dim == rhs_rank - 1 ? rewriter.getBoolAttr(true) - : rewriter.getBoolAttr(false)); - - // Set to `nullptr` because this attribute only matters when the input is - // dynamic-range quantized. - BoolAttr asymmetric_quantize_inputs = nullptr; - - // Create BMM assuming rhs is activation. - auto tfl_batchmatmul_op = rewriter.create( - op.getLoc(), /*output=*/op.getResult().getType(), /*input=*/input_value, - /*filter=*/rhs_value, adj_x, adj_y, asymmetric_quantize_inputs); - - // Update BMM if rhs is a constant. - auto const_rhs = dyn_cast_or_null(rhs_op); - if (const_rhs) { - auto rhs_uniform_quantized_type = rhs_value.getType().cast(); - auto rhs_constant_value_attr = - cast(const_rhs.getValue()); - auto rhs_constant_op = rewriter.create( - rhs_op->getLoc(), - /*output=*/TypeAttr::get(rhs_uniform_quantized_type), - rhs_constant_value_attr); - tfl_batchmatmul_op = rewriter.create( - op.getLoc(), /*output=*/op.getResult().getType(), - /*input=*/input_value, /*filter=*/rhs_constant_op.getResult(), adj_x, - adj_y, asymmetric_quantize_inputs); - } - - rewriter.replaceAllUsesWith(op.getResult(), tfl_batchmatmul_op.getResult()); - } -}; - -// Rewrites `stablehlo.dot_general` -> `tfl.fully_connected` when it accepts -// uniform quantized tensors with per-axis quantized filter tensor (rhs). -// -// Conditions for the conversion: -// * Input and output tensors are per-tensor uniform quantized (i8->f32) -// tensors. -// * The filter tensor is constant a per-channel uniform quantized (i8->f32) -// tensor. The quantization dimension should be 1 (the non-contracting -// dimension). -// * The input tensor's rank is either 2 or 3. The last dimension of the input -// tensor should be the contracting dimension, i.e. [..., c_x, r_x]. -// * The filter tensor's rank is 2. The contracting dimension should be the -// first dimension (dim 0), i.e. [c_y, r_y] where c_y == r_x. -// * Does not consider activation fusion. -// * Does not consider bias add fusion. -// -// TODO: b/294983811 - Merge this pattern into -// `RewriteFullIntegerQuantizedDotGeneralOp`. -// TODO: b/295264927 - `stablehlo.dot_general` with per-axis quantized operands -// is not specified in the StableHLO dialect. Update the spec to allow this. -// TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. -class RewriteUpstreamQuantizedDotGeneralOpToTflFullyConnectedOp - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - public: - LogicalResult match(stablehlo::DotGeneralOp op) const override { - const stablehlo::DotDimensionNumbersAttr dot_dimension_nums = - op.getDotDimensionNumbers(); - if (const int num_rhs_contracting_dims = - dot_dimension_nums.getRhsContractingDimensions().size(); - num_rhs_contracting_dims != 1) { - LLVM_DEBUG(llvm::dbgs() - << "Expected number of contracting dimensions to be 1. Got: " - << num_rhs_contracting_dims << ".\n"); - return failure(); - } - - if (failed(MatchInput(op.getOperand(0)))) { - LLVM_DEBUG(llvm::dbgs() - << "Failed to match input for quantized dot_general op.\n"); - return failure(); - } - - if (failed(MatchFilter(op.getOperand(1)))) { - LLVM_DEBUG(llvm::dbgs() - << "Failed to match filter for quantized dot_general op.\n"); - return failure(); - } - - if (failed(MatchOutput(op.getResult()))) { - LLVM_DEBUG(llvm::dbgs() - << "Failed to match output for quantized dot_general op.\n"); - return failure(); - } - - return success(); - } - - void rewrite(stablehlo::DotGeneralOp op, - PatternRewriter& rewriter) const override { - // Create the new filter constant - transpose filter value - // from [i, o] -> [o, i]. This is because we assume `[i, o]` format for - // `stablehlo.dot_general` (i.e. contracting dimension == 1) whereas - // `tfl.fully_connected` accepts an OI format. - auto filter_constant_op = - cast(op.getOperand(1).getDefiningOp()); - - TFL::QConstOp new_filter_constant_op = - CreateTflConstOpForFilter(filter_constant_op, rewriter, - /*is_per_axis=*/true); - const Value input_value = op.getOperand(0); - const double input_scale = input_value.getType() - .cast() - .getElementType() - .cast() - .getScale(); - TFL::QConstOp bias_constant_op = CreateTflConstOpForDummyBias( - op.getLoc(), input_scale, new_filter_constant_op, rewriter, - /*is_per_axis=*/true, *op.getContext()); - - const Value result_value = op.getResult(); - // Set to `nullptr` because this attribute only matters when the input is - // dynamic-range quantized. - const BoolAttr asymmetric_quantize_inputs = nullptr; - auto tfl_fully_connected_op = rewriter.create( - op.getLoc(), /*output=*/result_value.getType(), - /*input=*/input_value, /*filter=*/new_filter_constant_op.getResult(), - /*bias=*/bias_constant_op.getResult(), - /*fused_activation_function=*/rewriter.getStringAttr("NONE"), - /*weights_format=*/rewriter.getStringAttr("DEFAULT"), - /*keep_num_dims=*/rewriter.getBoolAttr(false), - asymmetric_quantize_inputs); - - rewriter.replaceAllUsesWith(result_value, - tfl_fully_connected_op.getResult(0)); - rewriter.eraseOp(op); } private: static LogicalResult MatchInput(Value input) { auto input_type = input.getType().cast(); - if (!input_type.hasRank() || - !(input_type.getRank() == 2 || input_type.getRank() == 3)) { - LLVM_DEBUG(llvm::dbgs() << "Input expected to have rank of 2 or 3. Got: " - << input_type << ".\n"); + if (input_type.getRank() != 4) { + LLVM_DEBUG(llvm::dbgs() << "Only 2D convolution op is supported. " + "Expected input rank of 4. Got: " + << input_type.getRank() << ".\n"); return failure(); } @@ -1009,10 +1017,10 @@ class RewriteUpstreamQuantizedDotGeneralOpToTflFullyConnectedOp static LogicalResult MatchFilter(Value filter) { auto filter_type = filter.getType().cast(); - if (!filter_type.hasRank() || filter_type.getRank() != 2) { - LLVM_DEBUG(llvm::dbgs() - << "Filter tensor expected to have a tensor rank of 2. Got: " - << filter_type << ".\n"); + if (filter_type.getRank() != 4) { + LLVM_DEBUG(llvm::dbgs() << "Only 2D convolution op is supported. " + "Expected filter rank of 4. Got: " + << filter_type.getRank() << ".\n"); return failure(); } @@ -1026,8 +1034,8 @@ class RewriteUpstreamQuantizedDotGeneralOpToTflFullyConnectedOp } if (filter_element_type.cast() - .getQuantizedDimension() != 1) { - LLVM_DEBUG(llvm::dbgs() << "Quantized dimension should be 1. Got: " + .getQuantizedDimension() != 3) { + LLVM_DEBUG(llvm::dbgs() << "Quantized dimension should be 3. Got: " << filter_element_type << "\n"); return failure(); } @@ -1037,234 +1045,253 @@ class RewriteUpstreamQuantizedDotGeneralOpToTflFullyConnectedOp LLVM_DEBUG(llvm::dbgs() << "Filter should be a constant.\n"); return failure(); } - return success(); } static LogicalResult MatchOutput(Value output) { const Type output_element_type = output.getType().cast().getElementType(); - if (!IsI8F32UniformQuantizedType(output_element_type)) { - LLVM_DEBUG(llvm::dbgs() - << "Expected a uniform quantized (i8->f32) type. Got: " - << output_element_type << ".\n"); + if (!IsI32F32UniformQuantizedPerAxisType(output_element_type) && + !IsI8F32UniformQuantizedType(output_element_type)) { + LLVM_DEBUG( + llvm::dbgs() + << "Expected a per-channel uniform quantized (i32->f32) type or " + << "per-tensor uniform quantized (i8->f32) type. Got: " + << output_element_type << ".\n"); return failure(); } - return success(); } -}; - -// Rewrites `stablehlo.dot_general` to `tfl.fully_connected` or -// `tfl.batch_matmul` when it accepts uniform quantized tensors. -// -// Conditions for `tfl.fully_connected` conversion: -// * Input and output tensors are per-tensor uniform quantized (i8->f32) -// tensors. -// * The filter tensor is constant a per-tensor uniform quantized (i8->f32) -// tensor. The quantization dimension should be 1 (the non-contracting -// dimension). -// * The input tensor's rank is either 2 or 3. The last dimension of the input -// tensor should be the contracting dimension, i.e. [..., c_x, r_x]. -// * The filter tensor's rank is 2. The contracting dimension should be the -// first dimension (dim 0), i.e. [c_y, r_y] where c_y == r_x. -// * Does not consider activation fusion. -// * Does not consider bias add fusion. -// TODO: b/580909703 - Include conversion conditions for `tfl.batch_matmul` op. -// -// TODO: b/295264927 - `stablehlo.dot_general` with per-axis quantized operands -// is not specified in the StableHLO dialect. Update the spec to allow this. -// TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. -class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + // Create a `tfl.pad` op to apply explicit padding to the input tensor that + // correspond to the `padding` attribute from the `stablehlo.convolution` op. + TFL::PadOp CreateTflPadOp(Location loc, + const DenseIntElementsAttr& padding_attr, + Value input_value, + PatternRewriter& rewriter) const { + auto padding_values = padding_attr.getValues(); + // [[h_l, h_r], [w_l, w_r]]. + DCHECK_EQ(padding_attr.size(), 4); - public: - LogicalResult match(stablehlo::DotGeneralOp op) const override { - const stablehlo::DotDimensionNumbersAttr dot_dimension_nums = - op.getDotDimensionNumbers(); - if (const int num_rhs_contracting_dims = - dot_dimension_nums.getRhsContractingDimensions().size(); - num_rhs_contracting_dims != 1) { - LLVM_DEBUG(llvm::dbgs() - << "Expected number of contracting dimensions to be 1. Got: " - << num_rhs_contracting_dims << ".\n"); - return failure(); + // In StableHLO the padding attribute doesn't include the padding values for + // input and output feature dimensions (because they are 0 anyways). In + // TFLite, padding values for input and output feature dimensions should be + // explicitly set to 0s. Note that TFLite's input tensor is formatted as + // OHWI. The resulting pad values becomes: [[0, 0], [h_l, h_r], [w_l, w_r], + // [0, 0]] + SmallVector tfl_pad_values = {0, 0}; // For output feature dim. + for (const int64_t padding_value : padding_values) { + tfl_pad_values.push_back(CastI64ToI32(padding_value).value()); } + // For input feature dim. + tfl_pad_values.push_back(0); + tfl_pad_values.push_back(0); - if (failed(MatchInput(op.getOperand(0)))) { - LLVM_DEBUG(llvm::dbgs() - << "Failed to match input for quantized dot_general op.\n"); - return failure(); - } + const auto input_tensor_type = + input_value.getType().cast(); + const int64_t rank = input_tensor_type.getRank(); - if (failed(MatchFilter(op.getOperand(1)))) { - LLVM_DEBUG(llvm::dbgs() - << "Failed to match filter for quantized dot_general op.\n"); - return failure(); - } + SmallVector padded_output_tensor_shape = + InferPaddedTensorShape(input_tensor_type.getShape(), tfl_pad_values); - if (failed(MatchOutput(op.getResult()))) { - LLVM_DEBUG(llvm::dbgs() - << "Failed to match output for quantized dot_general op.\n"); - return failure(); - } + auto padded_output_tensor_type = RankedTensorType::get( + padded_output_tensor_shape, input_tensor_type.getElementType()); - if (failed(MatchUsers(op.getResult()))) { - LLVM_DEBUG(llvm::dbgs() << "Failed to match subsequent requantize for " - "quantized dot_general op.\n"); - return failure(); - } + // The pad values is provided as a const op. + auto pad_value_const_op = rewriter.create( + loc, /*value=*/DenseIntElementsAttr::get( + RankedTensorType::get({rank, 2}, rewriter.getIntegerType(32)), + tfl_pad_values)); - return success(); + return rewriter.create( + loc, /*output=*/padded_output_tensor_type, input_value, + /*padding=*/pad_value_const_op.getResult()); } - void rewrite(stablehlo::DotGeneralOp op, - PatternRewriter& rewriter) const override { - // Create the new filter constant - transpose filter value - // from [i, o] -> [o, i]. This is because we assume `[i, o]` format for - // `stablehlo.dot_general` (i.e. contracting dimension == 1) whereas - // `tfl.fully_connected` accepts an OI format. - auto filter_constant_op = - cast(op.getOperand(1).getDefiningOp()); - - TFL::QConstOp new_filter_constant_op = CreateTflConstOpForFilter( - filter_constant_op, rewriter, /*is_per_axis=*/false); - const Value input_value = op.getOperand(0); - const double input_scale = input_value.getType() - .cast() - .getElementType() - .cast() - .getScale(); - TFL::QConstOp bias_constant_op = CreateTflConstOpForDummyBias( - op.getLoc(), input_scale, new_filter_constant_op, rewriter, - /*is_per_axis=*/false, *op.getContext()); + // Infers the output tensor's shape after padding `tfl_pad_values` to the + // `tensor_shape`. `tfl_pad_values` should be formatted as `[[l_0, r_0], [l_1, + // r_1], ..., [l_n, r_n]]`, where `l_x` and `r_x` are the left and paddings + // for the x-th dimension, respectively. + SmallVector InferPaddedTensorShape( + const ArrayRef tensor_shape, + const ArrayRef tfl_pad_values) const { + SmallVector padded_shape(tensor_shape.begin(), tensor_shape.end()); + for (int i = 0; i < padded_shape.size(); ++i) { + // Left padding + right padding. + const int32_t padded = tfl_pad_values[i * 2] + tfl_pad_values[i * 2 + 1]; + padded_shape[i] += padded; + } - auto output_op = op.getResult().getDefiningOp(); - Operation* requantize_op = *output_op->getResult(0).getUsers().begin(); - Operation* dequantize_op = *requantize_op->getResult(0).getUsers().begin(); + return padded_shape; + } - // Set to `nullptr` because this attribute only matters when the input is - // dynamic-range quantized. - const BoolAttr asymmetric_quantize_inputs = nullptr; - auto tfl_fully_connected_op = rewriter.create( - op.getLoc(), - /*output=*/ - requantize_op->getResult(0).getType(), // result_value.getType(), - /*input=*/input_value, /*filter=*/new_filter_constant_op.getResult(), - /*bias=*/bias_constant_op.getResult(), - /*fused_activation_function=*/rewriter.getStringAttr("NONE"), - /*weights_format=*/rewriter.getStringAttr("DEFAULT"), - /*keep_num_dims=*/rewriter.getBoolAttr(false), - asymmetric_quantize_inputs); + // Transposes the filter tensor to match the filter tensor format for + // `tfl.conv_2d`. This function performs the following index permutation + // only: (3, 0, 1, 2). The filter value is assumed to be of `[0, 1, i, o]` + // format. The `tfl.conv_2d` accepts the filter of `[o, 0, 1, i]`. + // TODO: b/291598373 - Lift the assumption about the filter tensor's format + // and generalize the transpose. + DenseIntElementsAttr TransposeFilterValue( + Location loc, PatternRewriter& rewriter, + const DenseIntElementsAttr& filter_value_attr) const { + ArrayRef filter_shape = + filter_value_attr.getShapedType().getShape(); + SmallVector filter_constant_values; + for (auto filter_val : filter_value_attr.getValues()) { + filter_constant_values.push_back(filter_val); + } - auto tfl_dequantize_op = rewriter.create( - op.getLoc(), dequantize_op->getResult(0).getType(), - tfl_fully_connected_op->getResult(0)); + SmallVector new_filter_constant_values( + filter_constant_values.size(), 0); - rewriter.replaceAllUsesWith(dequantize_op->getResult(0), - tfl_dequantize_op->getResult(0)); + SmallVector new_filter_shape; + SmallVector transpose_dims = {3, 0, 1, 2}; + for (int i = 0; i < filter_shape.size(); ++i) { + new_filter_shape.push_back(filter_shape[transpose_dims[i]]); + } - rewriter.replaceAllUsesWith(op.getResult(), - tfl_fully_connected_op.getResult(0)); + auto get_array_idx = [](ArrayRef shape, const int i, const int j, + const int k, const int l) -> int64_t { + return (i * shape[1] * shape[2] * shape[3]) + (j * shape[2] * shape[3]) + + (k * shape[3]) + l; + }; - rewriter.eraseOp(op); - } + // Transpose the filter value. + for (int i = 0; i < filter_shape[0]; ++i) { + for (int j = 0; j < filter_shape[1]; ++j) { + for (int k = 0; k < filter_shape[2]; ++k) { + for (int l = 0; l < filter_shape[3]; ++l) { + // [i][j][k][l] -> [l][i][j][k] + int old_idx = get_array_idx(filter_shape, i, j, k, l); + int new_idx = get_array_idx(new_filter_shape, l, i, j, k); - private: - static LogicalResult MatchInput(Value input) { - auto input_type = input.getType().cast(); - if (!input_type.hasRank() || - !(input_type.getRank() == 2 || input_type.getRank() == 3)) { - LLVM_DEBUG(llvm::dbgs() << "Input expected to have rank of 2 or 3. Got: " - << input_type << ".\n"); - return failure(); + new_filter_constant_values[new_idx] = + filter_constant_values[old_idx]; + } + } + } } - if (const auto input_element_type = input_type.getElementType(); - !IsI8F32UniformQuantizedType(input_element_type)) { - LLVM_DEBUG(llvm::dbgs() - << "Expected an i8->f32 uniform quantized type. Got: " - << input_element_type << ".\n"); - return failure(); - } + // Create the new filter constant. + auto new_filter_value_attr_type = + RankedTensorType::getChecked(loc, new_filter_shape, + /*elementType=*/rewriter.getI8Type()); + auto new_filter_constant_value_attr = DenseIntElementsAttr::get( + new_filter_value_attr_type, new_filter_constant_values); - return success(); + return new_filter_constant_value_attr; } - static LogicalResult MatchFilter(Value filter) { - auto filter_type = filter.getType().cast(); - if (!filter_type.hasRank() || filter_type.getRank() != 2) { - LLVM_DEBUG(llvm::dbgs() - << "Filter tensor expected to have a tensor rank of 2. Got: " - << filter_type << ".\n"); - return failure(); - } + bool UseSamePadding( + Operation* op, + stablehlo::ConvDimensionNumbersAttr dimension_numbers) const { + // TODO: b/294808863 - Account for dynamic shapes. + const ArrayRef input_shape = + op->getOperand(0).getType().cast().getShape(); + const ArrayRef output_shape = + op->getResult(0).getType().cast().getShape(); + const ArrayRef input_spatial_dim_inds = + dimension_numbers.getInputSpatialDimensions(); + const ArrayRef output_spatial_dim_inds = + dimension_numbers.getOutputSpatialDimensions(); + return (input_shape[input_spatial_dim_inds[0]] == + output_shape[output_spatial_dim_inds[0]] && + input_shape[input_spatial_dim_inds[1]] == + output_shape[output_spatial_dim_inds[1]]); + } - const Type filter_element_type = filter_type.getElementType(); - if (!IsI8F32UniformQuantizedType(filter_element_type)) { - LLVM_DEBUG(llvm::dbgs() - << "Expected a uniform quantized (i8->f32) type. Got: " - << filter_element_type << "\n"); - return failure(); - } + // Determines if the padding attribute corresponds to "VALID" or "SAME". + // If not, the input's shape should be adjusted with explicit `tfl.pad` op. + // (https://www.tensorflow.org/api_docs/python/tf/nn). + bool HasProperPadding(Operation* op, + stablehlo::ConvDimensionNumbersAttr dimension_numbers, + const DenseIntElementsAttr& padding_attr) const { + // If padding_attr is empty, it defaults to splat 0s. + return UseSamePadding(op, dimension_numbers) || + (!padding_attr || (padding_attr.isSplat() && + padding_attr.getSplatValue() == 0)); + } - if (Operation* filter_op = filter.getDefiningOp(); - filter_op == nullptr || !isa(filter_op)) { - LLVM_DEBUG(llvm::dbgs() << "Filter should be a constant.\n"); - return failure(); + // Returns the stride amount for the height and width, respectively. + std::pair GetStrides(stablehlo::ConvolutionOp op) const { + DenseI64ArrayAttr window_strides_attr = op.getWindowStridesAttr(); + if (!window_strides_attr) { + return {1, 1}; // Default values. } - return success(); + auto window_strides_attr_value = window_strides_attr.asArrayRef(); + // It is guaranteed from the spec that it has two values: + // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convolution. + return {window_strides_attr_value[0], window_strides_attr_value[1]}; } - static LogicalResult MatchOutput(Value output) { - const Type output_element_type = - output.getType().cast().getElementType(); - if (!IsI32F32UniformQuantizedType(output_element_type)) { - LLVM_DEBUG(llvm::dbgs() - << "Expected a uniform quantized (i32->f32) type. Got: " - << output_element_type << ".\n"); - return failure(); + // Returns the dilation amount for the height and width, respectively. + std::pair GetDilationFactors( + stablehlo::ConvolutionOp op) const { + DenseI64ArrayAttr lhs_dilation_attr = op.getLhsDilationAttr(); + if (!lhs_dilation_attr) { + return {1, 1}; // Default values. } - return success(); + + auto lhs_dilation_attr_value = lhs_dilation_attr.asArrayRef(); + // It is guaranteed from the spec that it has two values: + // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convolution. + return {lhs_dilation_attr_value[0], lhs_dilation_attr_value[1]}; } - static LogicalResult MatchUsers(Value output) { - auto output_op = output.getDefiningOp(); + TFL::QConstOp GetBiasOp( + stablehlo::ConvolutionOp op, PatternRewriter& rewriter, + const RankedTensorType new_filter_result_type, + const UniformQuantizedPerAxisType new_filter_quantized_type, + const SmallVector bias_shape, const bool has_i32_output, + const bool fuse_bias_constant) const { + const SmallVector bias_scales = GetBiasScales( + /*input_scale=*/op.getOperand(0) + .getType() + .cast() + .getElementType() + .cast() + .getScale(), + /*filter_scales=*/new_filter_quantized_type.getScales()); - if (!output_op->hasOneUse()) { - LLVM_DEBUG(llvm::dbgs() << "Expected output to be used only once.\n"); - return failure(); - } - // TODO: b/309896242 - Add support for fused op case. - if (Operation* requantize_op = dyn_cast_or_null( - *output_op->getResult(0).getUsers().begin())) { - const Type requantize_element_type = requantize_op->getResult(0) - .getType() - .cast() - .getElementType(); - if (!IsI8F32UniformQuantizedType(requantize_element_type)) { - LLVM_DEBUG(llvm::dbgs() << "Expected a quantize (i8->f32) type. Got: " - << requantize_element_type << ".\n"); - return failure(); - } - if (!isa( - *requantize_op->getResult(0).getUsers().begin())) { - LLVM_DEBUG(llvm::dbgs() << "Expected a dequantize type.\n"); - return failure(); - } + const auto bias_quantized_type = CreateI32F32UniformQuantizedPerAxisType( + op.getLoc(), *op.getContext(), std::move(bias_scales), + new_filter_quantized_type.getZeroPoints(), + /*quantization_dimension=*/0); + const auto bias_type = RankedTensorType::getChecked(op.getLoc(), bias_shape, + bias_quantized_type); + TFL::QConstOp bias; + if (fuse_bias_constant && has_i32_output) { + Operation* add_op = FindUserOfType(op); + // TODO: b/309896242 - Lift the assumptions on adjacent ops below + // as we cover more dynamic fused pattern legalization. + Operation* broadcast_in_dim_op = add_op->getOperand(1).getDefiningOp(); + Operation* bias_const_op = + broadcast_in_dim_op->getOperand(0).getDefiningOp(); + const ElementsAttr bias_constant_value = + cast(bias_const_op).getValue(); + bias = rewriter.create(op.getLoc(), + /*output=*/TypeAttr::get(bias_type), + /*value=*/bias_constant_value); } else { - // Op not followed by a requantization is not supported. - return failure(); - } - return success(); + // Create a bias constant. It should have values of 0. + const auto bias_value_type = RankedTensorType::getChecked( + op.getLoc(), bias_shape, rewriter.getI32Type()); + // Create a bias filled with zeros. Mimics the behavior of no bias add. + const auto bias_value = DenseIntElementsAttr::get( + bias_value_type, + APInt(/*numBits=*/32, /*value=*/0, /*isSigned=*/true)); + bias = rewriter.create(op.getLoc(), + /*output=*/TypeAttr::get(bias_type), + /*value=*/bias_value); + } + return bias; } }; // Rewrites quantized stablehlo.transpose to tfl.transpose. // TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. -class RewriteTransposeOp : public OpRewritePattern { +class RewriteQuantizedTransposeOp + : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -1281,11 +1308,8 @@ class RewriteTransposeOp : public OpRewritePattern { operand_type.cloneWith(shape, rewriter.getI32Type()); // Cast permutation attribute from i64 to i32 as they are required to be i32 // in TFLite. - SmallVector permutation_i32; - for (int64_t dim : op.getPermutation()) { - permutation_i32.push_back(static_cast(dim)); - } - + SmallVector permutation_i32 = + CastI64ArrayToI32(op.getPermutation()).value(); auto permutation_attr = DenseIntElementsAttr::get(permutation_type, permutation_i32); auto permutation = @@ -1297,7 +1321,8 @@ class RewriteTransposeOp : public OpRewritePattern { // Rewrites quantized stablehlo.reshape to tfl.reshape. // TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. -class RewriteReshapeOp : public OpRewritePattern { +class RewriteQuantizedReshapeOp + : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -1310,10 +1335,8 @@ class RewriteReshapeOp : public OpRewritePattern { auto result_type = op->getResult(0).getType().cast(); // Cast result shapes from i64 to i32 as they are required to be i32 in // TFLite. - SmallVector shape_i32; - for (int64_t dim : result_type.getShape()) { - shape_i32.push_back(static_cast(dim)); - } + SmallVector shape_i32 = + CastI64ArrayToI32(result_type.getShape()).value(); const int64_t shape_length = shape_i32.size(); ArrayRef shape(shape_length); @@ -1327,7 +1350,7 @@ class RewriteReshapeOp : public OpRewritePattern { // Rewrites quantized stablehlo.select to tfl.select_v2. // TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. -class RewriteSelectOp : public OpRewritePattern { +class RewriteQuantizedSelectOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -1355,7 +1378,8 @@ class RewriteSelectOp : public OpRewritePattern { // Rewrites quantized stablehlo.concatenate to tfl.concatenation. // TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. -class RewriteConcatenateOp : public OpRewritePattern { +class RewriteQuantizedConcatenateOp + : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -1366,7 +1390,7 @@ class RewriteConcatenateOp : public OpRewritePattern { void rewrite(stablehlo::ConcatenateOp op, PatternRewriter& rewriter) const override { Type output_type = op.getResult().getType(); - uint32_t axis = static_cast(op.getDimension()); + uint32_t axis = CastI64ToI32(op.getDimension()).value(); rewriter.replaceOpWithNewOp( op, output_type, op.getOperands(), axis, /*fused_activation_function=*/rewriter.getStringAttr("NONE")); @@ -1376,7 +1400,7 @@ class RewriteConcatenateOp : public OpRewritePattern { // Rewrites quantized stablehlo.pad to tfl.padv2. // tfl.dilate is introduced in between when interior padding exists. // TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. -class RewritePadOp : public OpRewritePattern { +class RewriteQuantizedPadOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -1404,8 +1428,8 @@ class RewritePadOp : public OpRewritePattern { ArrayRef padding_high = op.getEdgePaddingHigh(); SmallVector padding_value; for (int i = 0; i < rank; ++i) { - padding_value.push_back(static_cast(padding_low[i])); - padding_value.push_back(static_cast(padding_high[i])); + padding_value.push_back(CastI64ToI32(padding_low[i]).value()); + padding_value.push_back(CastI64ToI32(padding_high[i]).value()); } TensorType output_type = op.getResult().getType().cast(); @@ -1426,10 +1450,8 @@ class RewritePadOp : public OpRewritePattern { TensorType dilate_type = operand_type.cloneWith(dilate_shape, rewriter.getI32Type()); ArrayRef interior_padding_i64 = op.getInteriorPadding(); - SmallVector interior_padding_i32; - for (int64_t pad : interior_padding_i64) { - interior_padding_i32.push_back(static_cast(pad)); - } + SmallVector interior_padding_i32 = + CastI64ArrayToI32(interior_padding_i64).value(); auto dilate_attr = DenseIntElementsAttr::get(dilate_type, interior_padding_i32); auto dilate = rewriter.create(op.getLoc(), dilate_attr); @@ -1450,18 +1472,265 @@ class RewritePadOp : public OpRewritePattern { } }; -void UniformQuantizedStablehloToTflPass::runOnOperation() { +// Rewrites quantized stablehlo.slice to tfl.slice or tfl.strided_slice. +// TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. +class RewriteQuantizedSliceOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult match(stablehlo::SliceOp op) const override { + return success(IsOpFullyQuantized(op)); + } + + void rewrite(stablehlo::SliceOp op, + PatternRewriter& rewriter) const override { + auto operand_type = op.getOperand().getType().cast(); + Type output_type = op.getResult().getType(); + const int64_t rank = operand_type.getRank(); + + ArrayRef idx_shape(rank); + TensorType idx_type = + operand_type.cloneWith(idx_shape, rewriter.getI32Type()); + + ArrayRef start_idx_i64 = op.getStartIndices(); + ArrayRef limit_idx_i64 = op.getLimitIndices(); + + SmallVector start_idx_i32 = + CastI64ArrayToI32(start_idx_i64).value(); + auto start_idx_attr = DenseIntElementsAttr::get(idx_type, start_idx_i32); + auto start_idx = + rewriter.create(op.getLoc(), start_idx_attr); + + SmallVector slice_size_i32(rank); + for (int i = 0; i < rank; ++i) { + slice_size_i32[i] = + CastI64ToI32(limit_idx_i64[i] - start_idx_i64[i]).value(); + } + auto slice_size_attr = DenseIntElementsAttr::get(idx_type, slice_size_i32); + auto slice_size = + rewriter.create(op.getLoc(), slice_size_attr); + + ArrayRef strides = op.getStrides(); + // If stride of every dimension is 1, create tfl.slice and return early. + // Otherwise, create tfl.strided_slice instead. + if (llvm::all_of(strides, [](int64_t stride) { return stride == 1; })) { + rewriter.replaceOpWithNewOp( + op, output_type, op.getOperand(), start_idx, slice_size); + return; + } + + SmallVector stride_i32 = CastI64ArrayToI32(strides).value(); + auto stride_attr = DenseIntElementsAttr::get(idx_type, stride_i32); + auto stride = rewriter.create(op.getLoc(), stride_attr); + rewriter.replaceOpWithNewOp( + op, output_type, op.getOperand(), start_idx, slice_size, stride, + /*begin_mask=*/0, /*end_mask=*/0, + /*ellipsis_mask=*/0, /*new_axis_mask=*/0, /*shrink_axis_mask=*/0, + /*offset=*/false); + } +}; + +// Rewrites quantized stablehlo.broadcast_in_dim to tfl.broadcast_to. +// tfl.transpose is introduced when broadcast_dimensions is not in ascending +// order. Also, tfl.expand_dims is introduced when input rank is smaller than +// output rank. +// TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. +class RewriteQuantizedBroadcastInDimOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult match(stablehlo::BroadcastInDimOp op) const override { + return success(IsOpFullyQuantized(op)); + } + + void rewrite(stablehlo::BroadcastInDimOp op, + PatternRewriter& rewriter) const override { + auto operand_type = op.getOperand().getType().cast(); + auto output_type = op.getResult().getType().cast(); + Value input = op.getOperand(); + + // If broadcast_dimensions is not in ascending order, transpose first. + if (!llvm::is_sorted(op.getBroadcastDimensions())) { + input = InsertTransposeOp(op, rewriter); + } + + // If rank of operand is smaller than that of the output, expand dimensions + // before broadcasting. + if (operand_type.getRank() < output_type.getRank()) { + input = InsertExpandDimsOp(op, rewriter, input, output_type.getRank()); + } + + SmallVector broadcast_shape = + CastI64ArrayToI32(output_type.getShape()).value(); + TensorType broadcast_shape_type = + output_type.cloneWith({output_type.getRank()}, rewriter.getI32Type()); + auto broadcast_shape_attr = + DenseIntElementsAttr::get(broadcast_shape_type, broadcast_shape); + auto shape = + rewriter.create(op.getLoc(), broadcast_shape_attr); + + rewriter.replaceOpWithNewOp(op, output_type, input, + shape); + } + + Value InsertTransposeOp(stablehlo::BroadcastInDimOp op, + PatternRewriter& rewriter) const { + SmallVector sorted_dims = + llvm::to_vector(op.getBroadcastDimensions()); + llvm::sort(sorted_dims); + auto broadcast_dims = op.getBroadcastDimensions(); + SmallVector permutation( + llvm::map_range(broadcast_dims, [sorted_dims](int64_t dim) { + return static_cast(llvm::find(sorted_dims, dim) - + sorted_dims.begin()); + })); + auto operand_type = op.getOperand().getType().cast(); + TensorType perm_type = operand_type.cloneWith( + {static_cast(permutation.size())}, rewriter.getI32Type()); + auto perm_attr = DenseIntElementsAttr::get(perm_type, permutation); + auto perm = rewriter.create(op.getLoc(), perm_attr); + Value input = op.getOperand(); + + return rewriter.create(op.getLoc(), input, perm); + } + + Value InsertExpandDimsOp(stablehlo::BroadcastInDimOp op, + PatternRewriter& rewriter, Value input, + int64_t output_rank) const { + auto input_type = input.getType().cast(); + SmallVector input_shape(input_type.getShape()); + SmallVector input_dims = + llvm::to_vector(op.getBroadcastDimensions()); + + while (input_dims.size() < output_rank) { + int32_t dim_to_expand = 0; + for (int32_t i = 0; i < output_rank; ++i) { + if (!llvm::is_contained(input_dims, i)) { + dim_to_expand = i; + break; + } + } + + TensorType dim_type = input_type.cloneWith({static_cast(1)}, + rewriter.getI32Type()); + ArrayRef dims(dim_to_expand); + auto dim_attr = DenseIntElementsAttr::get(dim_type, dims); + auto dim = rewriter.create(op.getLoc(), dim_attr); + + input_shape.insert(input_shape.begin() + dim_to_expand, 1); + TensorType expanded_type = input_type.clone(input_shape); + input = rewriter.create(op.getLoc(), expanded_type, + input, dim); + + // Update expanded dimension in the input dimensions for the next + // iteration. + input_dims.push_back(static_cast(dim_to_expand)); + } + return input; + } +}; + +// Rewrites quantized stablehlo.reduce_window with max to tfl.max_pool_2d. +// TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. +class RewriteQuantizedReduceWindowOpWithMax + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult MatchBinaryReduceFunction(Region& function) const { + Block& body = function.front(); + if (body.getNumArguments() != 2) return failure(); + + auto return_op = dyn_cast(body.back()); + if (!return_op) return failure(); + if (return_op.getNumOperands() != 1) return failure(); + + auto reduce_op = dyn_cast_or_null( + return_op.getOperands().front().getDefiningOp()); + if (!reduce_op) return failure(); + return success(reduce_op.getLhs() == body.getArgument(0) && + reduce_op.getRhs() == body.getArgument(1)); + } + + LogicalResult match(stablehlo::ReduceWindowOp op) const override { + // Check that the reduce-window is a max-reduce-window. + if (failed(MatchBinaryReduceFunction(op.getBody()))) { + return failure(); + } + + // Only 2d pooling is supported in TFLite. + if (op.getWindowDimensions().size() != 4) { + return failure(); + } + + // reduce_window op with dilations or padding will supported later. + // TODO: b/321099943 - Support reduce_window op with dilations and padding. + if (op.getBaseDilations().has_value() || + op.getWindowDilations().has_value() || op.getPadding().has_value()) { + return failure(); + } + + // Window_dimensions and window_strides should have batch and channel + // dimension of 1 as they cannot be specified in tfl.max_pool_2d. + ArrayRef window_dims = op.getWindowDimensions(); + if (window_dims[0] != 1 || window_dims[3] != 1) { + return failure(); + } + std::optional> window_strides = op.getWindowStrides(); + if (window_strides.has_value()) { + if ((*window_strides)[0] != 1 || (*window_strides)[3] != 1) { + return failure(); + } + } + + return success(IsOpFullyQuantized(op)); + } + + void rewrite(stablehlo::ReduceWindowOp op, + PatternRewriter& rewriter) const override { + Type result_type = op.getResult(0).getType(); + Value input = op.getOperand(0); + // Ops with padding is rejected in matching function, so we can use the + // padding to be 'VALID'. + StringAttr padding = rewriter.getStringAttr("VALID"); + + // Use NHWC format. + int32_t stride_h = 1; + int32_t stride_w = 1; + std::optional> window_strides = op.getWindowStrides(); + if (window_strides.has_value()) { + stride_h = CastI64ToI32((*window_strides)[1]).value(); + stride_w = CastI64ToI32((*window_strides)[2]).value(); + } + auto stride_h_attr = IntegerAttr::get(rewriter.getI32Type(), stride_h); + auto stride_w_attr = IntegerAttr::get(rewriter.getI32Type(), stride_w); + + ArrayRef window_dims = op.getWindowDimensions(); + auto window_w_attr = IntegerAttr::get(rewriter.getI32Type(), + CastI64ToI32(window_dims[2]).value()); + auto window_h_attr = IntegerAttr::get(rewriter.getI32Type(), + CastI64ToI32(window_dims[1]).value()); + StringAttr activation_function = rewriter.getStringAttr("NONE"); + + rewriter.replaceOpWithNewOp( + op, result_type, input, padding, stride_w_attr, stride_h_attr, + window_w_attr, window_h_attr, activation_function); + } +}; + +void UniformQuantizedStableHloToTflPass::runOnOperation() { func::FuncOp func_op = getOperation(); MLIRContext& ctx = getContext(); RewritePatternSet patterns(&ctx); patterns.add(&ctx); + RewriteQuantizedConvolutionOp, RewriteQuantizedTransposeOp, + RewriteQuantizedReshapeOp, RewriteQuantizedSelectOp, + RewriteQuantizedConcatenateOp, RewriteQuantizedPadOp, + RewriteQuantizedSliceOp, RewriteQuantizedBroadcastInDimOp, + RewriteQuantizedReduceWindowOpWithMax>(&ctx); if (failed(applyPatternsAndFoldGreedily(func_op, std::move(patterns)))) { func_op.emitError() << "Failed to convert stablehlo ops with uniform " @@ -1473,11 +1742,11 @@ void UniformQuantizedStablehloToTflPass::runOnOperation() { } // namespace std::unique_ptr> -CreateUniformQuantizedStablehloToTflPass() { - return std::make_unique(); +CreateUniformQuantizedStableHloToTflPass() { + return std::make_unique(); } -static PassRegistration pass; +static PassRegistration pass; } // namespace odml } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/mix_tflite_stablehlo.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/mix_tflite_stablehlo.mlir index cbf7c3dd6cebfe..3c70fd1dfacf42 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/mix_tflite_stablehlo.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/mix_tflite_stablehlo.mlir @@ -9,10 +9,8 @@ func.func @main(%arg0: tensor<1x1x1x96xf32>) -> tensor<1x1x1x96xf32> { } } -// CHECK:module attributes {tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} { -// CHECK-NEXT: func.func @main(%arg0: tensor<1x1x1x96xf32>) -> tensor<1x1x1x96xf32> attributes {tf.entry_function = {inputs = "arg0", outputs = "exp"}} { +// CHECK: func.func @main(%arg0: tensor<1x1x1x96xf32>) -> tensor<1x1x1x96xf32> attributes {tf.entry_function = {inputs = "arg0", outputs = "exp"}} { // CHECK-NEXT: %0 = stablehlo.logistic %arg0 : tensor<1x1x1x96xf32> // CHECK-NEXT: %1 = "tfl.exp"(%0) : (tensor<1x1x1x96xf32>) -> tensor<1x1x1x96xf32> // CHECK-NEXT: return %1 : tensor<1x1x1x96xf32> -// CHECK-NEXT: } -// CHECK-NEXT:} \ No newline at end of file +// CHECK-NEXT: } \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/simple.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/simple.mlir index 75474e7ec8b268..76f778bcebec20 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/simple.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/simple.mlir @@ -5,7 +5,9 @@ func.func @main(tensor<3x2xi32>) -> tensor<3x2xi32> { ^bb0(%arg0: tensor<3x2xi32>): - // CHECK: module attributes {tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} + // CHECK: module attributes + // CHECK-SAME: tfl.description = "MLIR Converted." + // CHECK-SAME: tfl.schema_version = 3 : i32 // CHECK: %{{.*}} = "tfl.pseudo_const"() {value = dense<{{\[\[1, 2\], \[3, 4\], \[5, 6\]\]}}> : tensor<3x2xi32>} // CHECK-NEXT: [[SUB:%.*]] = tfl.sub %{{.*}}, %{{.*}} {fused_activation_function = "RELU6"} : tensor<3x2xi32> diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/stablehlo_const.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/stablehlo_const.mlir index f3c64f67fc5f9b..e9708e0f14a877 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/stablehlo_const.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/stablehlo_const.mlir @@ -8,9 +8,12 @@ module attributes {tfl.metadata = {"keep_stablehlo_constant" = "true"}} { } } -//CHECK:module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {keep_stablehlo_constant = "true"}, tfl.schema_version = 3 : i32} { -//CHECK-NEXT: func.func @main() -> tensor<1x1x1x96xf32> attributes {tf.entry_function = {outputs = "stablehlo.constant"}} { -//CHECK-NEXT: %0 = stablehlo.constant dense<0.000000e+00> : tensor<1x1x1x96xf32> -//CHECK-NEXT: return %0 : tensor<1x1x1x96xf32> -//CHECK-NEXT: } -//CHECK-NEXT:} \ No newline at end of file +// CHECK: module attributes { +// CHECK-SAME: tfl.metadata +// CHECK-SAME: keep_stablehlo_constant = "true" + +// CHECK-NEXT: func.func @main() -> tensor<1x1x1x96xf32> attributes {tf.entry_function = {outputs = "stablehlo.constant"}} { +// CHECK-NEXT: %0 = stablehlo.constant dense<0.000000e+00> : tensor<1x1x1x96xf32> +// CHECK-NEXT: return %0 : tensor<1x1x1x96xf32> +// CHECK-NEXT: } +// CHECK-NEXT:} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/tf_variant_type.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/tf_variant_type.mlir index 5d1566cf121590..07738c1102f767 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/tf_variant_type.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/tf_variant_type.mlir @@ -4,6 +4,5 @@ func.func @main(%arg0 : tensor>>, %arg1: tensor>> } -// CHECK: module attributes {tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} // CHECK: func.func @main(%[[ARG0:.*]]: tensor>>, %[[ARG1:.*]]: tensor>>) -> tensor>> // CHECK-NEXT: return %[[ARG0]] : tensor>> diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir index 40db57950cc30f..46ff509b7cc46e 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -tfl-prepare-tf -tfl-legalize-tf='run-tfl-runtime-verification=false' | FileCheck %s +// RUN: tf-opt %s -tfl-prepare-tf -tfl-legalize-tf='run-tfl-runtime-verification=false' -tfl-optimize | FileCheck %s func.func @broadcast_to_bf16(%arg0: tensor<3xbf16>, %arg1: tensor<2xi64>) -> tensor<3x3xbf16> { %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xbf16>, tensor<2xi64>) -> tensor<3x3xbf16> diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index cbdb85ca94d5ab..9da0e13c0471ac 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -784,26 +784,6 @@ func.func @FuseReshapeAroundBMMRHS(%arg0: tensor<1x3x6x5x1024xf32>) -> tensor<1x // CHECK: return %0 : tensor<1x3x6x5x8192xf32> } -// CHECK-LABEL: @FuseTransposeIntoBMM_RHS -func.func @FuseTransposeIntoBMM_RHS(%arg0: tensor<1x4x1440x256xf32>, %arg1: tensor<1x1440x256xf32>) -> tensor<1x4x1440x1440xf32> { - %cst_1 = arith.constant dense<[0, 2, 1]> : tensor<3xi32> - %32 = "tfl.transpose"(%arg1, %cst_1) : (tensor<1x1440x256xf32>, tensor<3xi32>) -> tensor<1x256x1440xf32> - %33 = "tfl.batch_matmul"(%arg0, %32) {adj_x = false, adj_y = false} : (tensor<1x4x1440x256xf32>, tensor<1x256x1440xf32>) -> tensor<1x4x1440x1440xf32> - return %33 : tensor<1x4x1440x1440xf32> - // CHECK: %0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = true} : (tensor<1x4x1440x256xf32>, tensor<1x1440x256xf32>) -> tensor<1x4x1440x1440xf32> - // CHECK: return %0 : tensor<1x4x1440x1440xf32> -} - -// CHECK-LABEL: @FuseTransposeIntoBMM_RHS2 -func.func @FuseTransposeIntoBMM_RHS2(%arg0: tensor, %arg1: tensor) -> tensor { - %cst_1 = arith.constant dense<[0, 2, 1]> : tensor<3xi32> - %32 = "tfl.transpose"(%arg1, %cst_1) : (tensor, tensor<3xi32>) -> tensor - %33 = "tfl.batch_matmul"(%arg0, %32) {adj_x = false, adj_y = false} : (tensor, tensor) -> tensor - return %33 : tensor - // CHECK: %0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = true} : (tensor, tensor) -> tensor - // CHECK: return %0 : tensor -} - // CHECK-LABEL: @FuseTransposeIntoBMM_LHS func.func @FuseTransposeIntoBMM_LHS(%arg0: tensor<1x4x1440x256xf32>, %arg1: tensor<1x1440x256xf32>) -> tensor<1x4x256x256xf32> { %cst_1 = arith.constant dense<[0, 2, 1]> : tensor<3xi32> @@ -3919,3 +3899,163 @@ func.func @NoReorderNCHWTransposeAddNotBias(%arg0: tensor<1x40x40x1xf32>, %filte // CHECK: %[[add:.*]] = tfl.add %[[transpose]], // CHECK: return %[[add]] } + +// CHECK-LABEL: @ConvertStridedSliceToSlice +func.func @ConvertStridedSliceToSlice(%arg0: tensor<2x3872x1x128xf32>) -> tensor<1x3872x1x128xf32> { + %44 = arith.constant dense<0> : tensor<4xi32> + %45 = arith.constant dense<[1, 3872, 1, 128]> : tensor<4xi32> + %46 = arith.constant dense<1> : tensor<4xi32> + %47 = "tfl.strided_slice"(%arg0, %44, %45, %46) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<2x3872x1x128xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x3872x1x128xf32> + func.return %47 : tensor<1x3872x1x128xf32> + + // CHECK: %[[slice:.*]] = "tfl.slice" + // CHECK: return %[[slice]] +} + +// CHECK-LABEL: @FuseExcessBroadcastingOnReshapes +func.func @FuseExcessBroadcastingOnReshapes(%arg0: tensor<1x8xf32>) -> tensor<1x1x1x128xf32> { + %cst = arith.constant dense<[1, 1, 1, 8, 1, 1]> : tensor<6xi32> + %cst_0 = arith.constant dense<[1, 1, 1, 8, 16, 1]> : tensor<6xi32> + %cst_1 = arith.constant dense<[1, 1, 1, 128]> : tensor<4xi32> + %0 = "tfl.reshape"(%arg0, %cst) : (tensor<1x8xf32>, tensor<6xi32>) -> tensor<1x1x1x8x1x1xf32> + %1 = "tfl.broadcast_to"(%0, %cst_0) : (tensor<1x1x1x8x1x1xf32>, tensor<6xi32>) -> tensor<1x1x1x8x16x1xf32> + %2 = "tfl.reshape"(%1, %cst_1) : (tensor<1x1x1x8x16x1xf32>, tensor<4xi32>) -> tensor<1x1x1x128xf32> + return %2 : tensor<1x1x1x128xf32> + // CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<8x16xf32> + // CHECK: %cst_0 = arith.constant dense<[1, 1, 1, 128]> : tensor<4xi32> + // CHECK: %cst_1 = arith.constant dense<[8, 1]> : tensor<2xi32> + // CHECK: %0 = "tfl.reshape"(%arg0, %cst_1) : (tensor<1x8xf32>, tensor<2xi32>) -> tensor<8x1xf32> + // CHECK: %1 = tfl.mul(%0, %cst) {fused_activation_function = "NONE"} : (tensor<8x1xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK: %2 = "tfl.reshape"(%1, %cst_0) : (tensor<8x16xf32>, tensor<4xi32>) -> tensor<1x1x1x128xf32> + // CHECK: return %2 : tensor<1x1x1x128xf32> +} + +// CHECK-LABEL: @FuseExcessBroadcastingOnReshapesDynamicShapes +func.func @FuseExcessBroadcastingOnReshapesDynamicShapes(%arg0: tensor, %arg1: tensor<6xi32>, %arg2: tensor<6xi32>, %arg3: tensor<2xi32>) -> tensor { + %1196 = "tfl.reshape"(%arg0, %arg1) : (tensor, tensor<6xi32>) -> tensor<1x?x1x10x1x1xf32> + %1197 = "tfl.broadcast_to"(%1196, %arg2) : (tensor<1x?x1x10x1x1xf32>, tensor<6xi32>) -> tensor<1x?x1x10x5x1xf32> + %1198 = "tfl.reshape"(%1197, %arg3) : (tensor<1x?x1x10x5x1xf32>, tensor<2xi32>) -> tensor + return %1198 : tensor + + // CHECK: %0 = "tfl.reshape"(%arg0, %arg1) : (tensor, tensor<6xi32>) -> tensor<1x?x1x10x1x1xf32> + // CHECK: %1 = "tfl.broadcast_to"(%0, %arg2) : (tensor<1x?x1x10x1x1xf32>, tensor<6xi32>) -> tensor<1x?x1x10x5x1xf32> + // CHECK: %2 = "tfl.reshape"(%1, %arg3) : (tensor<1x?x1x10x5x1xf32>, tensor<2xi32>) -> tensor + // CHECK: return %2 : tensor +} + +// CHECK-LABEL: @broadcast_to_f32_low_dim +func.func @broadcast_to_f32_low_dim(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> { + %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32> + return %0 : tensor<3x3xf32> + // CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<3x3xf32> + // CHECK: %0 = tfl.mul(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + // CHECK: return %0 : tensor<3x3xf32> +} + +// CHECK-LABEL: @broadcast_to_i32_low_dim +func.func @broadcast_to_i32_low_dim(%arg0: tensor<3xi32>, %arg1: tensor<2xi32>) -> tensor<3x3xi32> { + %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32> + return %0 : tensor<3x3xi32> + // CHECK: %cst = arith.constant dense<1> : tensor<3x3xi32> + // CHECK: %0 = tfl.mul(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32> + // CHECK: return %0 : tensor<3x3xi32> +} + +// CHECK-LABEL: @broadcast_to_low_dim_with_unknown_shape +func.func @broadcast_to_low_dim_with_unknown_shape(%arg0: tensor<3xf32>, %arg1: tensor<*xi32>) -> tensor<3x3xf32> { + %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xf32>, tensor<*xi32>) -> tensor<3x3xf32> + return %0 : tensor<3x3xf32> + // CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<3x3xf32> + // CHECK: %0 = tfl.mul(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + // CHECK: return %0 : tensor<3x3xf32> +} + +// CHECK-LABEL: @broadcast_to_i16_low_dim +func.func @broadcast_to_i16_low_dim(%arg0: tensor<3xi16>, %arg1: tensor<2xi32>) -> tensor<3x3xi16> { + %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi16>, tensor<2xi32>) -> tensor<3x3xi16> + return %0 : tensor<3x3xi16> + // CHECK: %cst = arith.constant dense<1> : tensor<3x3xi16> + // CHECK: %0 = tfl.mul(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<3xi16>, tensor<3x3xi16>) -> tensor<3x3xi16> + // CHECK: return %0 : tensor<3x3xi16> +} + +// CHECK-LABEL: @broadcast_to_i32_low_dim_with_unknown_output +func.func @broadcast_to_i32_low_dim_with_unknown_output(%arg0: tensor<3xi32>, %arg1: tensor<2xi32>) -> tensor<*xi32> { + %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<*xi32> + return %0 : tensor<*xi32> + // CHECK: %cst = arith.constant dense<1> : tensor + // CHECK: %0 = "tfl.fill"(%arg1, %cst) : (tensor<2xi32>, tensor) -> tensor<*xi32> + // CHECK: %1 = tfl.mul(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<*xi32>) -> tensor<*xi32> + // CHECK: return %1 : tensor<*xi32> +} + +// CHECK-LABEL: @broadcast_to_ui32 +func.func @broadcast_to_ui32(%arg0: tensor, %arg1: tensor<1xi64>) -> tensor<10xui32> { + %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor, tensor<1xi64>) -> tensor<10xui32> + return %0 : tensor<10xui32> + // CHECK: %cst = arith.constant dense<1> : tensor<10xui32> + // CHECK: %0 = tfl.mul(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor, tensor<10xui32>) -> tensor<10xui32> + // CHECK: return %0 : tensor<10xui32> +} + +// CHECK-LABEL: @broadcast_to_f32 +func.func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> { + %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32> + return %0 : tensor<3x3xf32> + // CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<3x3xf32> + // CHECK: %0 = tfl.mul(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + // CHECK: return %0 : tensor<3x3xf32> +} + +// CHECK-LABEL: @broadcast_to_i32 +func.func @broadcast_to_i32(%arg0: tensor<3xi32>, %arg1: tensor<2xi32>) -> tensor<3x3xi32> { + %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32> + return %0 : tensor<3x3xi32> + // CHECK: %cst = arith.constant dense<1> : tensor<3x3xi32> + // CHECK: %0 = tfl.mul(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32> + // CHECK: return %0 : tensor<3x3xi32> +} + +// CHECK-LABEL: @broadcast_to_i32_with_dynamic_shape_and_output +func.func @broadcast_to_i32_with_dynamic_shape_and_output(%arg0: tensor<3xi32>, %arg1: tensor<2xi32>) -> tensor<3x?xi32> { + %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x?xi32> + return %0 : tensor<3x?xi32> + // CHECK: %cst = arith.constant dense<1> : tensor + // CHECK: %0 = "tfl.fill"(%arg1, %cst) : (tensor<2xi32>, tensor) -> tensor<3x?xi32> + // CHECK: %1 = tfl.mul(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x?xi32>) -> tensor<3x?xi32> + // CHECK: return %1 : tensor<3x?xi32> +} + +// CHECK-LABEL: @broadcast_to_ui32_with_dynamic_output +func.func @broadcast_to_ui32_with_dynamic_output(%arg0: tensor<1xi32>) -> tensor { + %cst = arith.constant dense<0> : tensor<1xui32> + %0 = "tfl.broadcast_to"(%cst, %arg0) : (tensor<1xui32>, tensor<1xi32>) -> tensor + return %0 : tensor + + // CHECK: %cst = arith.constant dense<0> : tensor<1xui32> + // CHECK: %0 = "tfl.broadcast_to"(%cst, %arg0) : (tensor<1xui32>, tensor<1xi32>) -> tensor + // CHECK: return %0 : tensor +} + + +// CHECK-LABEL: @ConvertStridedSliceToSliceNeg +func.func @ConvertStridedSliceToSliceNeg(%arg0: tensor<5x5x5x5xf32>) -> tensor<*xf32> { + %44 = arith.constant dense<[5, 5, 5, 5]> : tensor<4xi32> + %45 = arith.constant dense<[1, 1, 1, 1]> : tensor<4xi32> + %46 = arith.constant dense<1> : tensor<4xi32> + %47 = "tfl.strided_slice"(%arg0, %44, %45, %46) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<5x5x5x5xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<*xf32> + func.return %47 : tensor<*xf32> + + // CHECK-NOT: %[[slice:.*]] = "tfl.slice" +} + +// CHECK-LABEL: @StridedSliceToSliceBeginNeg +func.func @StridedSliceToSliceBeginNeg(%arg0: tensor<5x5x5x5xf32>) -> tensor<*xf32> { + %44 = arith.constant dense<[-5, 0, 0, 0]> : tensor<4xi32> + %45 = arith.constant dense<[1, 1, 1, 1]> : tensor<4xi32> + %46 = arith.constant dense<1> : tensor<4xi32> + %47 = "tfl.strided_slice"(%arg0, %44, %45, %46) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<5x5x5x5xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<*xf32> + func.return %47 : tensor<*xf32> + + // CHECK-NOT: %[[slice:.*]] = "tfl.slice" +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-dynamic-range.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-dynamic-range.mlir index 0c9f058c1912c9..b35355524127dc 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-dynamic-range.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-dynamic-range.mlir @@ -103,8 +103,8 @@ func.func @QuantizeFullyConnected(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x11 func.return %fc : tensor<1x112x112x512xf32> // CHECK-DAG: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<512x12xf32> -// CHECK-DAG: %[[q_w:.*]] = "tfl.quantize"(%[[w]]) {qtype = tensor<512x12x!quant.uniform:f32, 1.000000e+00>>} -// CHECK-DAG: %[[dq_w:.*]] = "tfl.dequantize"(%[[q_w]]) : (tensor<512x12x!quant.uniform:f32, 1.000000e+00>>) -> tensor<512x12xf32> +// CHECK-DAG: %[[q_w:.*]] = "tfl.quantize"(%[[w]]) {qtype = tensor<512x12x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00, +// CHECK-DAG: %[[dq_w:.*]] = "tfl.dequantize"(%[[q_w]]) : (tensor<512x12x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00, // CHECK-DAG: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<512xf32> // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %[[dq_w]], %[[b]]) { // CHECK-NOT: fused_activation_function = "NONE" diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training-16bits.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training-16bits.mlir index d2e04734e0e2e6..15ede0019e12d6 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training-16bits.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training-16bits.mlir @@ -298,8 +298,8 @@ func.func @QuantizeFullyConnectedOp(%arg0: tensor<1x3xf32>) -> (tensor<1x1xf32>) // CHECK-NEXT: %[[q1:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<1x!quant.uniform>, volatile} : (tensor<1xf32>) -> tensor<1x!quant.uniform> // CHECK-NEXT: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]]) : (tensor<1x!quant.uniform>) -> tensor<1xf32> // CHECK-NEXT: %[[cst_0:.*]] = arith.constant dense<{{.*}}> : tensor<1x3xf32> -// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[cst_0]]) {qtype = tensor<1x3x!quant.uniform:f32, {{.*}}>>, volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform:f32, {{.*}}>> -// CHECK-NEXT: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]]) : (tensor<1x3x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3xf32> +// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[cst_0]]) {qtype = tensor<1x3x!quant.uniform:f32:0, {{.*}}>>, volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform:f32:0, {{.*}}>> +// CHECK-NEXT: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]]) : (tensor<1x3x!quant.uniform:f32:0, {{.*}}>>) -> tensor<1x3xf32> // CHECK-NEXT: %[[q3:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x3x!quant.uniform>, volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> // CHECK-NEXT: %[[dq3:.*]] = "tfl.dequantize"(%[[q3]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> // CHECK-NEXT: %[[fc:.*]] = "tfl.fully_connected"(%[[dq3]], %[[dq2]], %[[dq1]]) {{{.*}}} : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1xf32>) -> tensor<1x1xf32> @@ -324,8 +324,8 @@ func.func @QuantizeReshapeAndFullyConnectedOp(%arg0: tensor<1x1x3xf32>) -> (tens // CHECK-NEXT: %[[q1:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<1x!quant.uniform>, volatile} : (tensor<1xf32>) -> tensor<1x!quant.uniform> // CHECK-NEXT: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]]) : (tensor<1x!quant.uniform>) -> tensor<1xf32> // CHECK-NEXT: %[[cst_0:.*]] = arith.constant dense<{{.*}}> : tensor<1x3xf32> -// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[cst_0]]) {qtype = tensor<1x3x!quant.uniform:f32, {{.*}}>>, volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform:f32, {{.*}}>> -// CHECK-NEXT: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]]) : (tensor<1x3x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3xf32> +// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[cst_0]]) {qtype = tensor<1x3x!quant.uniform:f32:0, {{.*}}>>, volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform:f32:0, {{.*}}>> +// CHECK-NEXT: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]]) : (tensor<1x3x!quant.uniform:f32:0, {{.*}}>>) -> tensor<1x3xf32> // CHECK-NEXT: %[[cst_1:.*]] = arith.constant dense<[-1, 3]> : tensor<2xi32> // CHECK-NEXT: %[[q3:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x1x3x!quant.uniform>, volatile} : (tensor<1x1x3xf32>) -> tensor<1x1x3x!quant.uniform> // CHECK-NEXT: %[[dq3:.*]] = "tfl.dequantize"(%[[q3]]) : (tensor<1x1x3x!quant.uniform>) -> tensor<1x1x3xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir index 6e9ca99e11f492..882b335135cf74 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir @@ -166,20 +166,20 @@ func.func @prepareDepthwiseConv2D(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x11 // CHECK-LABEL: QuantizeFullyConnected // PerTensor-LABEL: QuantizeFullyConnected -func.func @QuantizeFullyConnected(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> { - %w = arith.constant dense<127.0> : tensor<32x12xf32> - %b = arith.constant dense<0.0> : tensor<32xf32> - %fc = "tfl.fully_connected"(%arg0, %w, %b) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x224x224x3xf32>, tensor<32x12xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32> - func.return %fc : tensor<1x112x112x32xf32> - -// CHECK: %[[cst:.*]] = arith.constant dense<1.270000e+02> : tensor<32x12xf32> -// CHECK: %[[q:.*]] = "tfl.quantize"(%cst) {qtype = tensor<32x12x!quant.uniform:f32, 1.000000e+00>>, volatile} -// CHECK: %[[dq:.*]] = "tfl.dequantize"(%0) : (tensor<32x12x!quant.uniform:f32, 1.000000e+00>>) -> tensor<32x12xf32> +func.func @QuantizeFullyConnected(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112x4xf32> { + %w = arith.constant dense<127.0> : tensor<4x12xf32> + %b = arith.constant dense<0.0> : tensor<4xf32> + %fc = "tfl.fully_connected"(%arg0, %w, %b) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x224x224x3xf32>, tensor<4x12xf32>, tensor<4xf32>) -> tensor<1x112x112x4xf32> + func.return %fc : tensor<1x112x112x4xf32> + +// CHECK: %[[cst:.*]] = arith.constant dense<1.270000e+02> : tensor<4x12xf32> +// CHECK: %[[q:.*]] = "tfl.quantize"(%cst) {qtype = tensor<4x12x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00}>>, volatile} +// CHECK: %[[dq:.*]] = "tfl.dequantize"(%0) : (tensor<4x12x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00}>>) -> tensor<4x12xf32> // CHECK: "tfl.fully_connected"(%arg0, %[[dq]] -// PerTensor: %[[cst:.*]] = arith.constant dense<1.270000e+02> : tensor<32x12xf32> -// PerTensor: %[[q:.*]] = "tfl.quantize"(%cst) {qtype = tensor<32x12x!quant.uniform:f32, 1.000000e+00>>, volatile} -// PerTensor: %[[dq:.*]] = "tfl.dequantize"(%0) : (tensor<32x12x!quant.uniform:f32, 1.000000e+00>>) -> tensor<32x12xf32> +// PerTensor: %[[cst:.*]] = arith.constant dense<1.270000e+02> : tensor<4x12xf32> +// PerTensor: %[[q:.*]] = "tfl.quantize"(%cst) {qtype = tensor<4x12x!quant.uniform:f32, 1.000000e+00>>, volatile} +// PerTensor: %[[dq:.*]] = "tfl.dequantize"(%0) : (tensor<4x12x!quant.uniform:f32, 1.000000e+00>>) -> tensor<4x12xf32> // PerTensor: "tfl.fully_connected"(%arg0, %[[dq]] } @@ -215,8 +215,8 @@ func.func @bias_adjust_pertensor(%arg0: tensor<1x2xf32>) -> (tensor<1x2xf32>) { func.return %fc : tensor<1x2xf32> // CHECK-DAG: %[[weight:.*]] = arith.constant dense<{{\[\[}}0.000000e+00, 1.000000e+00] // CHECK-DAG: %[[bias:.*]] = arith.constant dense<[0.000000e+00, 2147364.75]> -// CHECK-DAG: %[[b_q:.*]] = "tfl.quantize"(%[[bias]]){{.*}}quant.uniform> -// CHECK-DAG: %[[w_q:.*]] = "tfl.quantize"(%[[weight]]){{.*}}quant.uniform:f32, 19998.892343977564>> +// CHECK-DAG: %[[b_q:.*]] = "tfl.quantize"(%[[bias]]){{.*}}quant.uniform> +// CHECK-DAG: %[[w_q:.*]] = "tfl.quantize"(%[[weight]]){{.*}}quant.uniform:f32:0, {0.0078740157480314959,19998.892343977564}>> // CHECK-DAG: %[[b_dq:.*]] = "tfl.dequantize"(%[[b_q]]) // CHECK-DAG: %[[w_dq:.*]] = "tfl.dequantize"(%[[w_q]]) // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%[[input:.*]], %[[w_dq]], %[[b_dq]]) diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir index 2a4b2af88f5319..cce986eb8f1a8e 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir @@ -1,5 +1,6 @@ // RUN: tf-opt %s -tfl-prepare-quantize="quantize-allowlist=quantize_float_placeholder_only,not_reset_input" | FileCheck %s // RUN: tf-opt %s -tfl-prepare-quantize="disable-set-input-nodes-quantization-params=true" | FileCheck --check-prefix=MixedPrecision %s +// RUN: tf-opt %s -tfl-prepare-quantize="is-qdq-conversion=true" | FileCheck --check-prefix=QDQ %s // CHECK-LABEL: main // Uses `main` function to match the default target function of QuantSpecs and @@ -394,6 +395,32 @@ func.func @NotRescaleLogistic(%arg0: tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> { +^bb0(%arg0: tensor<1x6x6x16x!quant.uniform>): + %0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> + %1 = "tfl.logistic"(%0) : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32> + func.return %1 : tensor<1x6x6x16xf32> + +// QDQ: %0 = "tfl.dequantize"(%arg0) +// QDQ: %1 = "tfl.logistic"(%0) : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32> +// QDQ-NOT:"tfl.quantize" +// QDQ: return %1 : tensor<1x6x6x16xf32> +} + +// QDQ-LABEL: QDQNoQuantizeSoftmax +func.func @QDQNoQuantizeSoftmax(tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> { +^bb0(%arg0: tensor<1x6x6x16x!quant.uniform>): + %0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> + %1 = "tfl.softmax"(%0) {beta = 1.000000e+00 : f32} : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32> + func.return %1 : tensor<1x6x6x16xf32> + +// QDQ: %0 = "tfl.dequantize"(%arg0) +// QDQ: %1 = "tfl.softmax"(%0) {beta = 1.000000e+00 : f32} : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32> +// QDQ-NOT: "tfl.quantize" +// QDQ: return %1 : tensor<1x6x6x16xf32> +} + // CHECK-LABEL: QuantizeL2Norm func.func @QuantizeL2Norm(%arg0: tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> { %0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir index 0769e768507ee7..bfbcbd573cb0e0 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir @@ -440,34 +440,13 @@ func.func @StridedSliceShrinkAxisAndNewAxisMaskBothSet(%arg0: tensor<6x7x8xf32>) // CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%[[RESHAPE]], %[[BEGIN]], %[[END]], %[[STEP]]) <{begin_mask = 26 : i64, ellipsis_mask = 0 : i64, end_mask = 26 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64}> : (tensor<6x1x7x1x8xf32>, tensor<5xi32>, tensor<5xi32>, tensor<5xi32>) -> tensor<1x4x1x8xf32> } -func.func @broadcast_to_f32_low_dim(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> { - %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32> - func.return %0: tensor<3x3xf32> - -// CHECK-LABEL: broadcast_to_f32_low_dim -// CHECK: [[CST:%.*]] = arith.constant dense<1.000000e+00> : tensor<3x3xf32> -// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> -// CHECK: return [[MUL]] : tensor<3x3xf32> -} - -func.func @broadcast_to_i32_low_dim(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> { - %0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32> - func.return %0: tensor<3x3xi32> - -// CHECK-LABEL: broadcast_to_i32_low_dim -// CHECK: [[CST:%.*]] = arith.constant dense<1> : tensor<3x3xi32> -// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32> -// CHECK: return [[MUL]] : tensor<3x3xi32> -} - func.func @broadcast_to_i16_low_dim(%input: tensor<3xi16>, %shape: tensor<2xi32>) -> tensor<3x3xi16> { %0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi16>, tensor<2xi32>) -> tensor<3x3xi16> func.return %0: tensor<3x3xi16> // CHECK-LABEL: broadcast_to_i16_low_dim -// CHECK: [[CST:%.*]] = arith.constant dense<1> : tensor<3x3xi16> -// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xi16>, tensor<3x3xi16>) -> tensor<3x3xi16> -// CHECK: return [[MUL]] : tensor<3x3xi16> +// CHECK: %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xi16>, tensor<2xi32>) -> tensor<3x3xi16> +// CHECK: return %0 : tensor<3x3xi16> } func.func @broadcast_to_low_dim_with_unknown_shape(%arg0: tensor<3xf32>, %arg1: tensor<*xi32>) -> tensor<3x3xf32> { @@ -475,9 +454,8 @@ func.func @broadcast_to_low_dim_with_unknown_shape(%arg0: tensor<3xf32>, %arg1: func.return %0: tensor<3x3xf32> // CHECK-LABEL: broadcast_to_low_dim_with_unknown_shape -// CHECK: [[CST:%.*]] = arith.constant dense<1.000000e+00> : tensor<3x3xf32> -// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> -// CHECK: return [[MUL]] : tensor<3x3xf32> +// CHECK: %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<*xi32>) -> tensor<3x3xf32> +// CHECK: return %0 : tensor<3x3xf32> } func.func @broadcast_to_i32_low_dim_with_unknown_output(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<*xi32> { @@ -485,10 +463,8 @@ func.func @broadcast_to_i32_low_dim_with_unknown_output(%input: tensor<3xi32>, % func.return %0: tensor<*xi32> // CHECK-LABEL: broadcast_to_i32_low_dim_with_unknown_output -// CHECK: [[CST:%.*]] = arith.constant dense<1> : tensor -// CHECK: [[FILL:%.*]] = "tf.Fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor) -> tensor<*xi32> -// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[FILL]]) : (tensor<3xi32>, tensor<*xi32>) -> tensor<*xi32> -// CHECK: return [[MUL]] : tensor<*xi32> +// CHECK: %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<*xi32> +// CHECK: return %0 : tensor<*xi32> } func.func @broadcast_to_high_dim_with_unknown_shape(%arg0: tensor<1x2x3x4x5x6xf32>, %arg1: tensor<*xi32>) -> tensor<7x8x1x2x3x4x5x6xf32> { @@ -517,16 +493,6 @@ func.func @broadcast_to_with_unknown_shape_and_output(%arg0: tensor<1x2x3x4x5x6x // CHECK: "tf.BroadcastTo"(%arg0, %arg1) } -func.func @broadcast_to_ui32(%arg0: tensor, %arg1: tensor<1xi64>) -> tensor<10xui32> { - %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor, tensor<1xi64>) -> tensor<10xui32> - func.return %0: tensor<10xui32> - -// CHECK-LABEL: broadcast_to_ui32 -// CHECK: [[CST:%.*]] = arith.constant dense<1> : tensor<10xui32> -// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor, tensor<10xui32>) -> tensor<10xui32> -// CHECK: return [[MUL]] : tensor<10xui32> -} - // CHECK-LABEL: xla_conv_v2 func.func @xla_conv_v2(%arg0: tensor<4x8x8x16xf32>) -> tensor<4x8x8x16xf32> { %0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<3x3x16x16xf32>} : () -> tensor<3x3x16x16xf32> loc("Const_1") @@ -541,26 +507,6 @@ func.func @xla_conv_v2(%arg0: tensor<4x8x8x16xf32>) -> tensor<4x8x8x16xf32> { // CHECK: return %[[RES]] } -func.func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> { - %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32> - func.return %0: tensor<3x3xf32> - -// CHECK-LABEL: broadcast_to_f32 -// CHECK: [[CST:%.*]] = arith.constant dense<1.000000e+00> : tensor<3x3xf32> -// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> -// CHECK: return [[MUL]] : tensor<3x3xf32> -} - -func.func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> { - %0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32> - func.return %0: tensor<3x3xi32> - -// CHECK-LABEL: broadcast_to_i32 -// CHECK: [[CST:%.*]] = arith.constant dense<1> : tensor<3x3xi32> -// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32> -// CHECK: return [[MUL]] : tensor<3x3xi32> -} - // CHECK-LABEL: lower_rfft_to_rfft2d func.func @lower_rfft_to_rfft2d(%input: tensor<10x20x30xf32>, %fft_len: tensor<1xi32>) -> tensor<10x20x30xcomplex> { %0 = "tf.RFFT"(%input, %fft_len) : (tensor<10x20x30xf32>, tensor<1xi32>) -> tensor<10x20x30xcomplex> diff --git a/tensorflow/compiler/mlir/lite/tests/quantize-dynamic-range.mlir b/tensorflow/compiler/mlir/lite/tests/quantize-dynamic-range.mlir index 15684bc4bd2204..ad4ff5a129f4a2 100644 --- a/tensorflow/compiler/mlir/lite/tests/quantize-dynamic-range.mlir +++ b/tensorflow/compiler/mlir/lite/tests/quantize-dynamic-range.mlir @@ -88,7 +88,7 @@ func.func @QuantizeFullyConnected(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x11 func.return %fc : tensor<1x112x112x512xf32> // CHECK: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<512xf32> -// CHECK: %[[w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<512x12x!quant.uniform:f32, 1.000000e+00>> +// CHECK: %[[w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<512x12x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00 // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %[[w]], %[[b]]) { // CHECK-NOT: fused_activation_function = "NONE", // CHECK-SAME: asymmetric_quantize_inputs = true, @@ -102,8 +102,8 @@ func.func @QuantizeFullyConnected(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x11 // PerTensor: return %[[fc:.*]] // PerChannelWeightOnly: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<512xf32> -// PerChannelWeightOnly: %[[w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<512x12x!quant.uniform:f32, 1.000000e+00>> -// PerChannelWeightOnly: %[[dq_w:.*]] = "tfl.dequantize"(%[[w]]) : (tensor<512x12x!quant.uniform:f32, 1.000000e+00>> +// PerChannelWeightOnly: %[[w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<512x12x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00 +// PerChannelWeightOnly: %[[dq_w:.*]] = "tfl.dequantize"(%[[w]]) : (tensor<512x12x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00 // PerChannelWeightOnly: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %[[dq_w]], %[[b]]) { // PerChannelWeightOnly-NOT: fused_activation_function = "NONE", // PerChannelWeightOnly-SAME: asymmetric_quantize_inputs = true, diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 520154ab595ccd..e71c5117bd2d51 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -19,21 +19,22 @@ limitations under the License. #include #include +#include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project +#include "stablehlo/experimental/transforms/Passes.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h" #include "tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" @@ -47,7 +48,7 @@ CreateTFExecutorToControlDialectConversion(); namespace tensorflow { namespace { // Data layout supported by TFLite. -const char kTFLiteDataLayout[] = "NHWC"; +constexpr mlir::StringRef kTFLiteDataLayout = "NHWC"; } // namespace void AddQuantizationPasses(const mlir::TFL::PassConfig& pass_config, @@ -136,12 +137,26 @@ void AddDynamicRangeQuantizationPasses(const mlir::TFL::PassConfig& pass_config, mlir::TFL::CreateOptimizePass(/*enable_canonicalization=*/true)); } -void AddConvertHloToTfPass(std::string entry_function_name, - const mlir::TFL::PassConfig& pass_config, - mlir::OpPassManager* pass_manager) { - pass_manager->addPass( +void AddPreQuantizationStableHloToTfPasses( + const mlir::StringRef entry_function_name, + const mlir::TFL::PassConfig& pass_config, + mlir::OpPassManager& pass_manager) { + pass_manager.addPass( mlir::odml::CreateLegalizeTFXlaCallModuleToStablehloPass()); + // Add CHLO to StableHLO Decompositions: + // This is needed since we are relying on XlaCallModule uses MHLO + // specific features like mhlo::ErfOp which aren't supported + // in StableHLO, but we have CHLO->StableHLO decompositions to legalize. + pass_manager.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); + pass_manager.addPass( + mlir::stablehlo::experimental::createChloRecomposeOpsPass()); + pass_manager.addNestedPass( + mlir::mhlo::createChloLegalizeToHloBasisOpsPass()); + pass_manager.addNestedPass( + mlir::mhlo::createChloLegalizeToHloPass()); + pass_manager.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); + // The following two passes find specific uniform quantization patterns in // StableHLO and converts them to TFLite ops that accept or produce uniform // quantized types. They only target a specific set of models that contain @@ -153,65 +168,94 @@ void AddConvertHloToTfPass(std::string entry_function_name, // There are future plans to make the framework to directly produce StableHLO // uniform quantized ops and deprecate `ComposeUniformQuantizedTypePass`. If // no quantization patterns are found, it is a no-op. - pass_manager->addPass(mlir::odml::CreateComposeUniformQuantizedTypePass()); - pass_manager->addNestedPass( - mlir::odml::CreateUniformQuantizedStablehloToTflPass()); + pass_manager.addPass(mlir::odml::CreateComposeUniformQuantizedTypePass()); + pass_manager.addNestedPass( + mlir::odml::CreateUniformQuantizedStableHloToTflPass()); - pass_manager->addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); + pass_manager.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); // Legalize jax random to tflite custom op. // The CreateLegalizeJaxRandom Pass has to stay at because we need to replace // the random function body before being inlined. - pass_manager->addNestedPass( + pass_manager.addNestedPass( mlir::TFL::CreateLegalizeJaxRandomPass()); // Canonicalize, CSE etc. - pass_manager->addNestedPass( + pass_manager.addNestedPass( mlir::createCanonicalizerPass()); - pass_manager->addNestedPass(mlir::createCSEPass()); + pass_manager.addNestedPass(mlir::createCSEPass()); // DCE for private symbols. - pass_manager->addPass(mlir::createSymbolDCEPass()); + pass_manager.addPass(mlir::createSymbolDCEPass()); - pass_manager->addPass(mlir::TF::CreateStripNoinlineAttributePass()); + pass_manager.addPass(mlir::TF::CreateStripNoinlineAttributePass()); // Add inline pass. - pass_manager->addPass(mlir::createInlinerPass()); + pass_manager.addPass(mlir::createInlinerPass()); // Expands mhlo.tuple ops. - pass_manager->addPass( - mlir::mhlo::createExpandHloTuplesPass(entry_function_name)); + pass_manager.addPass( + mlir::mhlo::createExpandHloTuplesPass(entry_function_name.str())); // Flatten tuples for control flows. - pass_manager->addNestedPass( + pass_manager.addNestedPass( mlir::mhlo::createFlattenTuplePass()); - mlir::odml::AddMhloOptimizationPasses(*pass_manager); + mlir::odml::AddMhloOptimizationPasses( + pass_manager, + /*add_fold_broadcast_pass=*/pass_config.enable_stablehlo_quantizer); // Undo the MHLO::BroadcastInDimOp folding pattern on splat constants. This // pass must be added right before the legalization because pattern rewriter // driver applies folding by default. - // TODO(b/295966255): Remove this pass after moving MHLO folders to a separate - // pass. - pass_manager->addPass(mlir::odml::CreateUnfoldSplatConstantPass()); + // TODO: b/295966255 - Remove this pass after moving MHLO folders to a + // separate pass. + pass_manager.addPass(mlir::odml::CreateUnfoldSplatConstantPass()); + + if (pass_config.enable_stablehlo_quantizer) { + // When using StableHLO Quantizer, MHLO ops should be transformed back into + // StableHLO because the quantizer takes StableHLO dialect as its input. + pass_manager.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); + } +} + +void AddPostQuantizationStableHloToTfPasses( + const mlir::TFL::PassConfig& pass_config, + mlir::OpPassManager& pass_manager) { + if (pass_config.enable_stablehlo_quantizer) { + // StableHLO Quantizer emits quantized StableHLO module serialized within a + // XlaCallModule op. Add this pass to extract StableHLO module from the + // XlaCallModuleOp. + pass_manager.addPass( + mlir::odml::CreateLegalizeTFXlaCallModuleToStablehloPass()); + + // Convert StableHLO -> TFLite for fused quantization patterns early so that + // quantized types do not go through the TF dialect which doesn't support + // quantized types. + pass_manager.addNestedPass( + mlir::odml::CreateUniformQuantizedStableHloToTflPass()); + + // StableHLO -> MHLO + pass_manager.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); + } // TFLite dialect passes. if (!pass_config.disable_hlo_to_tfl_conversion) { - pass_manager->addPass(mlir::odml::CreateLegalizeHloToTfLitePass()); + pass_manager.addPass(mlir::odml::CreateLegalizeHloToTfLitePass()); } // TF dialect passes - pass_manager->addPass(mlir::odml::CreateLegalizeHloToTfPass()); + pass_manager.addPass(mlir::odml::CreateLegalizeHloToTfPass()); // folds tf.BroadcastTo ops with subsequent ops if they have built in // broadcasting support. This needs to be run immediately after HLO->TF // legalization; otherwise other passes like `ConvertTFBroadcastTo` will // constant fold the newly generated TF broadcast ops and materialize the // weights. - pass_manager->addNestedPass( + pass_manager.addNestedPass( mlir::TF::CreateBroadcastFoldPass()); // Canonicalization after TF legalization. - pass_manager->addNestedPass( + pass_manager.addNestedPass( mlir::createCanonicalizerPass()); // Legalize all remaining mhlo ops to stableHLO - pass_manager->addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); + pass_manager.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); } // This is the early part of the conversion in isolation. This enables a caller @@ -220,11 +264,6 @@ void AddConvertHloToTfPass(std::string entry_function_name, void AddPreVariableFreezingTFToTFLConversionPasses( const mlir::TFL::PassConfig& pass_config, mlir::OpPassManager* pass_manager) { - if (pass_config.enable_hlo_to_tf_conversion) { - // TODO(b/194747383): We need to valid that indeed the "main" func is - // presented. - AddConvertHloToTfPass("main", pass_config, pass_manager); - } // This pass wraps all the tf.FakeQuant ops in a custom op so they are not // folded before being converted to tfl.quantize and tfl.dequantize ops. auto wrapped_ops = mlir::TFL::AllTfFakeQuantOps(); @@ -266,7 +305,7 @@ void AddPreVariableFreezingTFToTFLConversionPasses( // This decomposes resource ops like ResourceGather into read-variable op // followed by gather. This is used when the saved model import path is used - // during which resources dont get frozen in the python layer. + // during which resources don't get frozen in the python layer. pass_manager->addNestedPass( mlir::TFDevice::CreateDecomposeResourceOpsPass()); @@ -375,7 +414,7 @@ void AddPostVariableFreezingTFToTFLConversionPasses( // Force layout supported by TFLite, this will transpose the data // to match 'kTFLiteDataLayout' mlir::TF::LayoutOptimizationPipelineOptions layout_optimization_options; - layout_optimization_options.force_data_format = kTFLiteDataLayout; + layout_optimization_options.force_data_format = kTFLiteDataLayout.str(); layout_optimization_options.skip_fold_transpose_in_ops = true; mlir::TF::CreateLayoutOptimizationPipeline( pass_manager->nest(), layout_optimization_options); diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.h b/tensorflow/compiler/mlir/lite/tf_tfl_passes.h index 50bf75023a808a..8de2142f0ebd83 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.h +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_PASSES_H_ #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/lite/toco/toco_flags.pb.h" @@ -32,6 +33,25 @@ void AddTFToTFLConversionPasses(llvm::StringRef saved_model_dir, const mlir::TFL::PassConfig& pass_config, mlir::OpPassManager* pass_manager); +// Adds the first portion of StableHLO->TF passes happening before quantization. +// The `pass_manager` that runs on a `mlir::ModuleOp` expects a graph containing +// a `mlir::TF::XlaCallModuleOp` with serialized StableHLO module. The resulting +// `mlir::ModuleOp` after running these passes will be an MHLO module, or a +// StableHLO module if `pass_config.enable_stablehlo_quantizer` is `true`. This +// is because StableHLO Quantizer accepts StableHLO modules. +void AddPreQuantizationStableHloToTfPasses( + mlir::StringRef entry_function_name, + const mlir::TFL::PassConfig& pass_config, + mlir::OpPassManager& pass_manager); + +// Adds the second portion of StableHlo->TF passes happening after quantization. +// The input module is expected to be an MHLO module, or a quantized StableHLO +// graph (expressed as `mlir::TF::XlaCallModuleOp`s) if +// `pass_config.enable_stablehlo_quantizer` is `true`. +void AddPostQuantizationStableHloToTfPasses( + const mlir::TFL::PassConfig& pass_config, + mlir::OpPassManager& pass_manager); + // This is the early part of the conversion in isolation. This enables a caller // to inject more information in the middle of the conversion before resuming it // (like freezing variables for example). diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index dc0ae41ba49a2b..892eed27385035 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -17,56 +17,55 @@ limitations under the License. #include #include #include +#include #include #include +#include "absl/status/statusor.h" #include "absl/strings/str_split.h" +#include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/InitLLVM.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ToolOutputFile.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/AsmState.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/FileUtilities.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" -#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h" -#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h" #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h" #include "xla/translate/hlo_to_mhlo/translate.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/errors.h" -#include "tensorflow/lite/model.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/lite/model_builder.h" #include "tensorflow/lite/schema/schema_generated.h" -#include "tsl/platform/statusor.h" using mlir::MLIRContext; using mlir::ModuleOp; using mlir::func::FuncOp; -using tsl::StatusOr; // Debugging flag to print function mapping in the flatbuffer. // NOLINTNEXTLINE @@ -170,7 +169,7 @@ int main(int argc, char **argv) { context.appendDialectRegistry(registry); } - StatusOr> module; + absl::StatusOr> module; std::unordered_set tags; tensorflow::GraphImportConfig specs; @@ -321,7 +320,7 @@ int main(int argc, char **argv) { if (bundle) session = bundle->GetSession(); auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer( module.value().get(), output_mlir, toco_flags, pass_config, tags, - /*saved_model_dir=*/"", session, &result, serialize_stablehlo_ops); + /*saved_model_dir=*/"", bundle.get(), &result, serialize_stablehlo_ops); if (!status.ok()) { llvm::errs() << status.message() << '\n'; return kTrFailure; diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index 1b3c5d21ba3dfb..b6a08b35d69445 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" +#include #include #include #include @@ -23,60 +24,75 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/debug/debug.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" +#include "tensorflow/compiler/mlir/lite/metrics/error_collector.h" #include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/serializer/flatbuffer_export.h" +#include "tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_tfl_pass.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantize_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/experimental/remat/metadata_util.h" #include "tensorflow/lite/python/metrics/converter_error_data.pb.h" +#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/tools/optimize/quantize_weights.h" #include "tensorflow/lite/tools/optimize/reduced_precision_support.h" -#include "tsl/platform/statusor.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep namespace tensorflow { namespace { + using mlir::MLIRContext; using mlir::ModuleOp; using mlir::Operation; using mlir::OwningOpRef; -using tsl::StatusOr; +using ::stablehlo::quantization::QuantizationConfig; +using ::tensorflow::quantization::PyFunctionLibrary; bool IsControlFlowV1Op(Operation* op) { return mlir::isa extra_tf_opdefs) { +absl::Status RegisterExtraTfOpDefs( + absl::Span extra_tf_opdefs) { for (const auto& tf_opdefs_string : extra_tf_opdefs) { - tensorflow::OpDef opdef; - if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string, - &opdef)) { + OpDef opdef; + // NOLINTNEXTLINE: Use tsl::protobuf to be compatible with OSS. + if (!tsl::protobuf::TextFormat::ParseFromString(tf_opdefs_string, &opdef)) { LOG(ERROR) << "OpDef parsing failed for: " << tf_opdefs_string; - return errors::InvalidArgument("fail to parse extra OpDef"); + return absl::InvalidArgumentError("fail to parse extra OpDef"); } // Register extra opdefs. - // TODO(b/133770952): Support shape functions. - tensorflow::OpRegistry::Global()->Register( - [opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status { - *op_reg_data = tensorflow::OpRegistrationData(opdef); - return OkStatus(); + // TODO: b/133770952 - Support shape functions. + OpRegistry::Global()->Register( + [opdef](OpRegistrationData* op_reg_data) -> absl::Status { + *op_reg_data = OpRegistrationData(opdef); + return absl::OkStatus(); }); } - return OkStatus(); + return absl::OkStatus(); +} + +// The hlo->tf conversion is done in three steps; pre-quantization, +// quantization, and post-quantization. Quantization is optional, enabled only +// when `pass_config.enable_stablehlo_quantizer` is `true`. If quantization is +// not run, it only performs the hlo->tf conversion. +// +// All parameters except for `pass_config`, `pass_manager`, `status_handler`, +// and `module` are only required for quantization. See the comments of +// `RunQuantization` for details. If quantization is not performed, they will be +// ignored. +// +// Returns a failure status when any of the three steps fail. `pass_manager` +// will be cleared before returning. +mlir::LogicalResult RunHloToTfConversion( + const mlir::TFL::PassConfig& pass_config, + const absl::string_view saved_model_dir, + const std::unordered_set& saved_model_tags, + const QuantizationConfig& quantization_config, + const PyFunctionLibrary* quantization_py_function_lib, + const SavedModelBundle* saved_model_bundle, mlir::PassManager& pass_manager, + mlir::StatusScopedDiagnosticHandler& status_handler, ModuleOp& module) { + // TODO: b/194747383 - We need to valid that indeed the "main" func is + // presented. + AddPreQuantizationStableHloToTfPasses(/*entry_function_name=*/"main", + pass_config, pass_manager); + if (failed(pass_manager.run(module))) { + return mlir::failure(); + } + pass_manager.clear(); + + if (pass_config.enable_stablehlo_quantizer) { + const absl::StatusOr quantized_module_op = RunQuantization( + saved_model_bundle, saved_model_dir, saved_model_tags, + quantization_config, quantization_py_function_lib, module); + if (!quantized_module_op.ok()) { + LOG(ERROR) << "Failed to run quantization: " + << quantized_module_op.status(); + return mlir::failure(); + } + module = *quantized_module_op; + } + + AddPostQuantizationStableHloToTfPasses(pass_config, pass_manager); + if (failed(pass_manager.run(module))) { + return mlir::failure(); + } + pass_manager.clear(); + + return mlir::success(); } + } // namespace -StatusOr> LoadFromGraphdefOrMlirSource( +absl::StatusOr> LoadFromGraphdefOrMlirSource( const std::string& input_filename, bool input_mlir, bool use_splatted_constant, const std::vector& extra_tf_opdefs, const GraphImportConfig& specs, absl::string_view debug_info_file, @@ -156,8 +224,8 @@ StatusOr> LoadFromGraphdefOrMlirSource( std::string error_message; auto file = mlir::openInputFile(input_filename, &error_message); if (!file) { - llvm::errs() << error_message << "\n"; - return errors::InvalidArgument("fail to open input file"); + return absl::InvalidArgumentError( + absl::StrCat("Failed to open input file: ", error_message)); } if (input_mlir) { @@ -170,7 +238,7 @@ StatusOr> LoadFromGraphdefOrMlirSource( auto extra_opdefs_status = RegisterExtraTfOpDefs(extra_tf_opdefs); if (!extra_opdefs_status.ok()) return extra_opdefs_status; - ::tensorflow::GraphdefToMlirOptions graphdef_conversion_options{ + GraphdefToMlirOptions graphdef_conversion_options{ std::string(debug_info_file), /*xla_compile_device_type=*/"", /*prune_unused_nodes=*/specs.prune_unused_nodes, @@ -182,21 +250,21 @@ StatusOr> LoadFromGraphdefOrMlirSource( /*enable_soft_placement=*/false}; if (use_splatted_constant) { - return tensorflow::GraphdefToSplattedMlirTranslateFunction( + return GraphdefToSplattedMlirTranslateFunction( file->getBuffer(), input_arrays, input_dtypes, input_shapes, output_arrays, control_output_arrays, graphdef_conversion_options, context); } - return tensorflow::GraphdefToMlirTranslateFunction( - file->getBuffer(), input_arrays, input_dtypes, input_shapes, - output_arrays, control_output_arrays, graphdef_conversion_options, - context); + return GraphdefToMlirTranslateFunction(file->getBuffer(), input_arrays, + input_dtypes, input_shapes, + output_arrays, control_output_arrays, + graphdef_conversion_options, context); } // Applying post-training dynamic range quantization from the old TOCO quantizer // on the translated_result using quant_specs and saving the final output in // result. -Status ApplyDynamicRangeQuantizationFromOldQuantizer( +absl::Status ApplyDynamicRangeQuantizationFromOldQuantizer( const mlir::quant::QuantizationSpecs& quant_specs, std::string translated_result, std::string* result) { flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240); @@ -206,14 +274,14 @@ Status ApplyDynamicRangeQuantizationFromOldQuantizer( ::tflite::optimize::BufferType quantized_type; switch (quant_specs.inference_type) { - case tensorflow::DT_QINT8: + case DT_QINT8: quantized_type = ::tflite::optimize::BufferType::QUANTIZED_INT8; break; - case tensorflow::DT_HALF: + case DT_HALF: quantized_type = ::tflite::optimize::BufferType::QUANTIZED_FLOAT16; break; default: - return errors::InvalidArgument("Quantized type not supported"); + return absl::InvalidArgumentError("Quantized type not supported"); break; } @@ -221,59 +289,59 @@ Status ApplyDynamicRangeQuantizationFromOldQuantizer( if (::tflite::optimize::QuantizeWeights( &q_builder, input_model, quantized_type, use_updated_hybrid_scheme, ::tflite::optimize::QuantizerType::OLD_QUANTIZER) != kTfLiteOk) { - return errors::InvalidArgument("Quantize weights transformation failed."); + return absl::InvalidArgumentError( + "Quantize weights transformation failed."); } const uint8_t* q_buffer = q_builder.GetBufferPointer(); *result = - string(reinterpret_cast(q_buffer), q_builder.GetSize()); + std::string(reinterpret_cast(q_buffer), q_builder.GetSize()); - return OkStatus(); + return absl::OkStatus(); } -Status ConvertTFExecutorToStablehloFlatbuffer( +absl::Status ConvertTFExecutorToStablehloFlatbuffer( mlir::PassManager& pass_manager, mlir::ModuleOp module, bool export_to_mlir, - mlir::StatusScopedDiagnosticHandler& statusHandler, + mlir::StatusScopedDiagnosticHandler& status_handler, const toco::TocoFlags& toco_flags, const mlir::TFL::PassConfig& pass_config, - std::optional session, std::string* result, + std::optional session, std::string* result, const std::unordered_set& saved_model_tags) { // Currently, TF quantization only support dynamic range quant, as such // when toco flag post training quantization is specified with converting to // stablehlo, we automatically enable dynamic range quantization if (toco_flags.post_training_quantize()) { - const auto status = tensorflow::quantization::PreprocessAndFreezeGraph( + const auto status = quantization::PreprocessAndFreezeGraph( module, module.getContext(), session); if (!status.ok()) { - return errors::Aborted("Failed to preprocess & freeze TF graph"); + return status_handler.Combine( + absl::InternalError("Failed to preprocess & freeze TF graph.")); } - // TODO(b/264218457): Refactor the component below once StableHLO Quantizer + // TODO: b/264218457 - Refactor the component below once StableHLO Quantizer // can run DRQ. Temporarily using TF Quantization for StableHLO DRQ. if (!toco_flags.has_quantization_options()) { // The default minimum number of elements a weights array must have to be // quantized by this transformation. const int kWeightsMinNumElementsDefault = 1024; - tensorflow::quantization::QuantizationOptions quantization_options; + quantization::QuantizationOptions quantization_options; quantization_options.mutable_quantization_method()->set_preset_method( - tensorflow::quantization::QuantizationMethod:: - METHOD_DYNAMIC_RANGE_INT8); - quantization_options.set_op_set( - tensorflow::quantization::UNIFORM_QUANTIZED); + quantization::QuantizationMethod::METHOD_DYNAMIC_RANGE_INT8); + quantization_options.set_op_set(quantization::UNIFORM_QUANTIZED); quantization_options.set_min_num_elements_for_weights( kWeightsMinNumElementsDefault); - tensorflow::quantization::AddQuantizePtqDynamicRangePasses( - pass_manager, quantization_options); + quantization::AddQuantizePtqDynamicRangePasses(pass_manager, + quantization_options); } if (failed(pass_manager.run(module))) { - return statusHandler.ConsumeStatus(); + return status_handler.ConsumeStatus(); } } pass_manager.clear(); - mlir::odml::AddTFToStablehloPasses(pass_manager, /*skip_resize*/ true, - /*smuggle_disallowed_ops*/ true); + mlir::odml::AddTFToStablehloPasses(pass_manager, /*skip_resize=*/true, + /*smuggle_disallowed_ops=*/true); // Print out a detailed report of non-converted stats. pass_manager.addPass(mlir::odml::createPrintOpStatsPass( mlir::odml::GetAcceptedStableHLODialects())); @@ -283,13 +351,13 @@ Status ConvertTFExecutorToStablehloFlatbuffer( pass_manager, toco_flags.quantization_options()); } if (failed(pass_manager.run(module))) { - return statusHandler.ConsumeStatus(); + return status_handler.ConsumeStatus(); } if (export_to_mlir) { llvm::raw_string_ostream os(*result); module.print(os); - return statusHandler.ConsumeStatus(); + return status_handler.ConsumeStatus(); } // Write MLIR Stablehlo dialect into FlatBuffer @@ -301,24 +369,20 @@ Status ConvertTFExecutorToStablehloFlatbuffer( options.metadata[tflite::kModelUseStablehloTensorKey] = "true"; if (!tflite::MlirToFlatBufferTranslateFunction(module, options, result, true)) { - auto s = statusHandler.ConsumeStatus(); - std::string message = "Could not translate MLIR to FlatBuffer."; - if (!s.ok()) { - absl::StrAppend(&message, " ", s.ToString()); - } - return absl::UnknownError(message); + return status_handler.Combine( + absl::InternalError("Could not translate MLIR to FlatBuffer.")); } - return OkStatus(); + return absl::OkStatus(); } -Status ConvertTFExecutorToTFLOrFlatbuffer( +absl::Status ConvertTFExecutorToTFLOrFlatbuffer( mlir::ModuleOp module, bool export_to_mlir, const toco::TocoFlags& toco_flags, const mlir::TFL::PassConfig& pass_config, const std::unordered_set& saved_model_tags, - llvm::StringRef saved_model_dir, - std::optional session, std::string* result, - bool serialize_stablehlo_ops) { + llvm::StringRef saved_model_dir, SavedModelBundle* saved_model_bundle, + std::string* result, bool serialize_stablehlo_ops, + const PyFunctionLibrary* quantization_py_function_lib) { // Explicitly disable dumping Op details on failures. module.getContext()->printOpOnDiagnostic(false); @@ -326,92 +390,77 @@ Status ConvertTFExecutorToTFLOrFlatbuffer( mlir::func::registerAllExtensions(registry); module.getContext()->appendDialectRegistry(registry); - // Register a warning handler only log to std out. - mlir::ScopedDiagnosticHandler s( - module.getContext(), [](mlir::Diagnostic& diag) { - if (diag.getSeverity() == mlir::DiagnosticSeverity::Warning) { - for (auto& note : diag.getNotes()) { - std::cout << note.str() << "\n"; - LOG(WARNING) << note.str() << "\n"; - } - } - return mlir::failure(); - }); - - mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(), - /*propagate=*/true); - - if (failed(IsValidGraph(module))) { - return statusHandler.ConsumeStatus(); - } + mlir::StatusScopedDiagnosticHandler status_handler(module.getContext(), + /*propagate=*/true); mlir::PassManager pass_manager(module.getContext()); mlir::registerPassManagerCLOptions(); if (mlir::failed(mlir::applyPassManagerCLOptions(pass_manager))) { - return absl::UnknownError("failed to apply MLIR pass manager CL options"); + return absl::InternalError("Failed to apply MLIR pass manager CL options."); } + InitPassManager(pass_manager, toco_flags.debug_options()); + pass_manager.addInstrumentation( std::make_unique( pass_manager.getContext())); - InitPassManager(pass_manager, toco_flags.debug_options()); + if (failed(IsValidGraph(module))) { + return status_handler.ConsumeStatus(); + } + + Session* session = saved_model_bundle == nullptr + ? nullptr + : saved_model_bundle->GetSession(); if (pass_config.enable_stablehlo_conversion) { + // `ConvertTFExecutorToStablehloFlatbuffer` expects a `std::nullopt` if the + // `Session*` is a nullptr. + std::optional session_opt = + session == nullptr ? std::nullopt : std::make_optional(session); + // return to avoid adding TFL converter path return ConvertTFExecutorToStablehloFlatbuffer( - pass_manager, module, export_to_mlir, statusHandler, toco_flags, - pass_config, session, result, saved_model_tags); + pass_manager, module, export_to_mlir, status_handler, toco_flags, + pass_config, std::move(session_opt), result, saved_model_tags); } - tensorflow::AddPreVariableFreezingTFToTFLConversionPasses(pass_config, - &pass_manager); + if (pass_config.enable_hlo_to_tf_conversion) { + if (failed(RunHloToTfConversion( + pass_config, saved_model_dir, saved_model_tags, + toco_flags.quantization_config(), quantization_py_function_lib, + saved_model_bundle, pass_manager, status_handler, module))) { + return status_handler.ConsumeStatus(); + } + } + + AddPreVariableFreezingTFToTFLConversionPasses(pass_config, &pass_manager); if (failed(pass_manager.run(module))) { - return statusHandler.ConsumeStatus(); + return status_handler.ConsumeStatus(); } + // Freeze variables if a session is provided. - if (session.has_value()) { - mlir::TFL::ErrorCollectorInstrumentation collector(module.getContext()); - if (failed( - mlir::tf_saved_model::FreezeVariables(module, session.value()))) { - auto status = statusHandler.ConsumeStatus(); - mlir::TFL::ErrorCollector* collector = - mlir::TFL::ErrorCollector::GetErrorCollector(); - if (!collector->CollectedErrors().empty()) { - // LINT.IfChange - return errors::InvalidArgument( - "Variable constant folding is failed. Please consider using " - "enabling `experimental_enable_resource_variables` flag in the " - "TFLite converter object. For example, " - "converter.experimental_enable_resource_variables = True"); - // LINT.ThenChange(//tensorflow/lite/python/lite_v2_test.py) - } - return status; - } + if (session != nullptr && + failed(mlir::tf_saved_model::FreezeVariables(module, session))) { + return status_handler.Combine(absl::InvalidArgumentError( + "Variable constant folding is failed. Please consider using " + "enabling `experimental_enable_resource_variables` flag in the " + "TFLite converter object. For example, " + "converter.experimental_enable_resource_variables = True")); } pass_manager.clear(); - tensorflow::AddPostVariableFreezingTFToTFLConversionPasses( - saved_model_dir, toco_flags, pass_config, &pass_manager); + AddPostVariableFreezingTFToTFLConversionPasses(saved_model_dir, toco_flags, + pass_config, &pass_manager); if (failed(pass_manager.run(module))) { - auto status = statusHandler.ConsumeStatus(); - mlir::TFL::ErrorCollector* collector = - mlir::TFL::ErrorCollector::GetErrorCollector(); - for (const auto& error_data : collector->CollectedErrors()) { - if (error_data.subcomponent() == "FreezeGlobalTensorsPass") { - // LINT.IfChange - return errors::InvalidArgument( - "Variable constant folding is failed. Please consider using " - "enabling `experimental_enable_resource_variables` flag in the " - "TFLite converter object. For example, " - "converter.experimental_enable_resource_variables = True"); - // LINT.ThenChange(//tensorflow/lite/python/lite_v2_test.py) - } - } - return status; + return status_handler.Combine(absl::InvalidArgumentError( + "Variable constant folding is failed. Please consider using " + "enabling `experimental_enable_resource_variables` flag in the " + "TFLite converter object. For example, " + "converter.experimental_enable_resource_variables = True")); } if (failed(GraphContainsStatefulPartitionedOp(module))) { - return statusHandler.ConsumeStatus(); + return status_handler.ConsumeStatus(); } if (export_to_mlir) { @@ -420,12 +469,12 @@ Status ConvertTFExecutorToTFLOrFlatbuffer( pass_manager.addPass(mlir::odml::createPrintOpStatsPass( mlir::odml::GetAcceptedTFLiteDialects())); if (failed(pass_manager.run(module))) { - return statusHandler.ConsumeStatus(); + return status_handler.ConsumeStatus(); } llvm::raw_string_ostream os(*result); module.print(os); - return statusHandler.ConsumeStatus(); + return status_handler.ConsumeStatus(); } // Write MLIR TFLite dialect into FlatBuffer @@ -443,15 +492,11 @@ Status ConvertTFExecutorToTFLOrFlatbuffer( } if (!tflite::MlirToFlatBufferTranslateFunction( module, options, &translated_result, serialize_stablehlo_ops)) { - auto s = statusHandler.ConsumeStatus(); - std::string message = "Could not translate MLIR to FlatBuffer."; - if (!s.ok()) { - absl::StrAppend(&message, " ", s.ToString()); - } - return absl::UnknownError(message); + return status_handler.Combine( + absl::InternalError("Could not translate MLIR to FlatBuffer.")); } - // TODO(b/176267167): Quantize flex fallback in the MLIR pipeline + // TODO: b/176267167 - Quantize flex fallback in the MLIR pipeline if (quant_specs.weight_quantization && (!quant_specs.RunAndRewriteDynamicRangeQuantizationPasses() || !pass_config.emit_builtin_tflite_ops)) { @@ -460,30 +505,33 @@ Status ConvertTFExecutorToTFLOrFlatbuffer( // statement. auto status = ApplyDynamicRangeQuantizationFromOldQuantizer( quant_specs, translated_result, result); - if (!status.ok()) return status; + if (!status.ok()) { + return status_handler.Combine(status); + } } else { *result = translated_result; } if (mlir::failed(module.verifyInvariants())) { - return tensorflow::errors::Unknown("Final module is invalid"); + return status_handler.Combine( + absl::InternalError("Final module is invalid.")); } - return OkStatus(); + return absl::OkStatus(); } -StatusOr> ImportSavedModel( +absl::StatusOr> ImportSavedModel( const std::string& input_filename, const int saved_model_version, const std::unordered_set& tags, absl::Span extra_tf_opdefs, absl::Span exported_names, const GraphImportConfig& specs, bool enable_variable_lifting, mlir::MLIRContext* context, - std::unique_ptr* saved_model_bundle) { + std::unique_ptr* saved_model_bundle) { // Register extra TF ops passed as OpDef. auto extra_opdefs_status = RegisterExtraTfOpDefs(extra_tf_opdefs); if (!extra_opdefs_status.ok()) return extra_opdefs_status; if (saved_model_version == 2) { - auto module_or = tensorflow::SavedModelObjectGraphToMlirImport( + auto module_or = SavedModelObjectGraphToMlirImport( input_filename, tags, exported_names, context, /*unconditionally_use_set_output_shapes=*/true); if (!module_or.status().ok()) return module_or.status(); @@ -493,15 +541,14 @@ StatusOr> ImportSavedModel( options.upgrade_legacy = specs.upgrade_legacy; options.unconditionally_use_set_output_shapes = true; options.lift_variables = enable_variable_lifting; - auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport( + auto module_or = SavedModelSignatureDefsToMlirImport( input_filename, tags, exported_names, context, options, saved_model_bundle); if (!module_or.status().ok()) return module_or.status(); return std::move(module_or).value(); } else { - return tensorflow::errors::InvalidArgument( - "Should be either saved model v1 or v2"); + return absl::InvalidArgumentError("Should be either saved model v1 or v2."); } } diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h index 8ef8813c6c19fc..a82afbeee7eda6 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h @@ -22,15 +22,17 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/Support/SourceMgr.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/lite/toco/toco_flags.pb.h" @@ -84,9 +86,10 @@ Status ConvertTFExecutorToTFLOrFlatbuffer( mlir::ModuleOp module, bool export_to_mlir, const toco::TocoFlags& toco_flags, const mlir::TFL::PassConfig& pass_config, const std::unordered_set& saved_model_tags, - llvm::StringRef saved_model_dir, - std::optional session, std::string* result, - bool serialize_stablehlo_ops = false); + llvm::StringRef saved_model_dir, SavedModelBundle* saved_model_bundle, + std::string* result, bool serialize_stablehlo_ops = false, + const quantization::PyFunctionLibrary* quantization_py_function_lib = + nullptr); } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_ diff --git a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc index bf4224c7631dd0..f620995ea2ecfd 100644 --- a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc +++ b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc @@ -14,25 +14,24 @@ limitations under the License. ==============================================================================*/ #include +#include #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" -#include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "absl/memory/memory.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/StringSwitch.h" -#include "mlir/IR/Location.h" // from @llvm-project +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h" -#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" -#include "tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h" +#include "tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/utils/utils.h" //===----------------------------------------------------------------------===// // The Pass to add default quantization parameters for the activations which @@ -215,7 +214,8 @@ quant::QuantParams DefaultQuantParamsPass::GetQuantParamsForBias( // The non-bias hasn't been quantized, let's skip this bias. if (non_bias_types.size() != non_biases.size()) return {}; - return func(non_bias_types, false); + return func(/*op_types=*/non_bias_types, /*adjusted_quant_dim=*/-1, + /*legacy_float_scale=*/false); } quant::QuantParams DefaultQuantParamsPass::GetDefaultQuantParams( diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index a2ea10fe199736..cfe9bc754d8077 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -70,6 +70,9 @@ def ExtractSingleElementAsInt32 : NativeCodeCall< def CreateTFCastToInt32Op : NativeCodeCall< "CreateCastToInt32($0, $_loc, $_builder)">; +def CreateInt32ConstOrCast : NativeCodeCall< + "CreateInt32ConstOrCast($0, $_loc, $_builder)">; + def CreateNoneValue : NativeCodeCall< "$_builder.create($0.getLoc(), $_builder.getUnitAttr())">; @@ -587,10 +590,7 @@ def LegalizeCumsum : Pat< def LegalizeReshape : Pat< (TF_ReshapeOp $input, $shape), - (TFL_ReshapeOp $input, (CreateTFCastToInt32Op $shape))>; - -def ZeroIntAttr - : AttrConstraint().getInt() == 0">>; + (TFL_ReshapeOp $input, (CreateInt32ConstOrCast $shape))>; def LegalizeStridedSlice : Pat< (TF_StridedSliceOp diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index 1ae84c64ddea7f..cc4e7e46b71b99 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -33,13 +33,17 @@ limitations under the License. #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Threading.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project @@ -106,6 +110,34 @@ Value CreateCastToInt32(Value val, Location loc, PatternRewriter& rewriter) { rewriter.getBoolAttr(false)); } +// Utility function to- +// 1. Create a tfl.const op with an int32_t values, from an MLIR Value, if the +// `Value` can be matched to a Constant DenseIntElementsAttr. +// This will make sure the dynamic dimensions are asigned to be `-1` +// 2. In the default case, cast the `Value` to an int32_t. +Value CreateInt32ConstOrCast(Value val, Location loc, + PatternRewriter& rewriter) { + if (val.getType().cast().hasStaticShape()) { + DenseElementsAttr shape_value_attr; + if (matchPattern(val, m_Constant(&shape_value_attr))) { + SmallVector new_shape_array_i32; + auto shape_value_array = shape_value_attr.getValues(); + for (int32_t idx = 0; idx < shape_value_array.size(); ++idx) { + auto size = shape_value_array[idx].getSExtValue(); + new_shape_array_i32.push_back( + ShapedType::isDynamic(size) ? -1 : static_cast(size)); + } + return rewriter.create( + loc, DenseIntElementsAttr::get( + RankedTensorType::get(new_shape_array_i32.size(), + rewriter.getIntegerType(32)), + new_shape_array_i32)); + } + } + + return CreateCastToInt32(val, loc, rewriter); +} + // Get shape of an operand or result, support both dynamic and static shape. Value GetShape(Value input, Location loc, PatternRewriter& rewriter) { auto shaped_type = input.getType().cast(); diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index 8af3bbd7e812fe..13703233f6259f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -60,6 +60,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" +#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/lite/utils/utils.h" #include "tensorflow/compiler/mlir/lite/utils/validators.h" @@ -704,6 +705,252 @@ bool IsPermutationNCHW(Value perm) { #include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc" +// Returns 1D 32-bit dense elements attribute with the given values. +static DenseIntElementsAttr GetI32ElementsAttr(ArrayRef values, + Builder *builder) { + RankedTensorType ty = mlir::RankedTensorType::get( + {static_cast(values.size())}, builder->getIntegerType(32)); + return DenseIntElementsAttr::get(ty, values); +} + +DenseIntElementsAttr GetI64ElementsAttr(ArrayRef values, + Builder *builder) { + RankedTensorType ty = RankedTensorType::get( + {static_cast(values.size())}, builder->getIntegerType(64)); + return DenseIntElementsAttr::get(ty, values); +} + +// Get the number of leading 1s in the shape of the given input. +// Ex. input_shape = [1 x 1 x 1 x 1 x 2 x 1] => 4 +// returns 0 if the input shape is not static. +int GetNumLeadingOnes(ShapedType input_type) { + if (!input_type.hasStaticShape()) return 0; + auto input_shape = input_type.getShape(); + int num_leading_broadcast_dims = 0; + for (int i = 0; i < input_shape.size(); ++i) { + if (input_shape[i] == 1) { + ++num_leading_broadcast_dims; + } else { + break; + } + } + return num_leading_broadcast_dims; +} + +// Return the number of trailing 1s in the shape of the given input. +// Ex. input_shape = [1 x 1 x 2 x 1] => 1 +// returns 0 if the input shape is not static. +int GetNumTrailingOnes(ShapedType input_type) { + if (!input_type.hasStaticShape()) return 0; + auto input_shape = input_type.getShape(); + int num_trailing_broadcast_dims = 0; + for (int i = input_shape.size() - 1; i >= 0; --i) { + if (input_shape[i] == 1) { + ++num_trailing_broadcast_dims; + } else { + break; + } + } + return num_trailing_broadcast_dims; +} + +// Consider as Reshape( +// Broadcast( +// Reshape(input, // input_shape=[1 x n] +// inner_shape), // inner_shape=[1 x 1 x 1 x n x 1 x 1] +// broadcast_shape), // broadcast_shape=[1 x 1 x 1 x n x m x 1] +// outer_shape))) // outer_shape=[1 x 1 x n*m] +// Here the broadcast operation is used to create `m` repetetions of the `n` +// elements in the origiginal tensor, making a total of `m*n` number of elements +// in the final tensor that will then be reshaped to form something like +// [1 x 1 x 1 x m*n] by the outermost reshape_op. +// problem: The inefficiency here is that the innermost reshape_op and the +// broadcast_op are introducing unnecessary leading and trailing 1s'. +// fix: Remove the unnecessary 1s' in the inner reshape_op and broadcast_op. +struct SqueezeReshapesAroundBroadcastOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TFL::BroadcastToOp tfl_broadcast_to_op, + PatternRewriter &rewriter) const override { + auto loc = tfl_broadcast_to_op->getLoc(); + + // Match the + // Reshape( + // Broadcast( + // Reshape(input,inner_shape), + // broadcast_shape), + // outer_shape))) pattern. + if (!llvm::dyn_cast_or_null( + tfl_broadcast_to_op.getInput().getDefiningOp()) || + // Check that the broadcast_to op has only one use. + !tfl_broadcast_to_op.getOutput().hasOneUse() || + !llvm::dyn_cast_or_null( + *tfl_broadcast_to_op.getOutput().getUsers().begin())) { + return rewriter.notifyMatchFailure( + loc, "No Reshape->BroadcastTo->Reshape pattern found"); + } + + // Pattern is applied only if the broadcast_to shape has more than 5 + // dimensions. + if (tfl_broadcast_to_op.getShape() + .getType() + .cast() + .getNumElements() < 6) { + return rewriter.notifyMatchFailure(loc, + "Not supported broadcast_to shape"); + } + auto inner_reshape_op = llvm::dyn_cast_or_null( + tfl_broadcast_to_op.getInput().getDefiningOp()); + auto inner_reshape_input = inner_reshape_op.getInput(); + auto outer_reshape_op = llvm::dyn_cast_or_null( + *tfl_broadcast_to_op.getOutput().getUsers().begin()); + + // Check that the outermost reshape_op in the pattern does not add + // additional elements to the final output tensor. + // TODO: b/323217483. This code needs to generalized to additional cases. + // For example- inner-shape = [1, 1, 1, 8, 1, 10], + // broadcast_shape = [1, 1, 1, 8, 16, 10] & outer_shape = [1, 1, 1, 1280, 1] + // And extend the pettern to handle dynamic shapes. + if (!inner_reshape_op.getOutput().getType().hasStaticShape() || + !tfl_broadcast_to_op.getOutput().getType().hasStaticShape() || + !outer_reshape_op.getOutput().getType().hasStaticShape()) { + return rewriter.notifyMatchFailure( + loc, "Unsupported shapes. Currely only static shapes are supported"); + } + + if (!IsLastDimEqualToNumElements(inner_reshape_input.getType(), + inner_reshape_op.getOutput().getType()) || + !IsLastDimEqualToNumElements( + outer_reshape_op.getOutput().getType(), + tfl_broadcast_to_op.getOutput().getType())) { + return rewriter.notifyMatchFailure( + loc, "Not supported Reshape->BroadcastTo->Reshape pattern"); + } + + // Calculate the number of extra leading and trailing 1s in the + // broadcast_op output. + auto broadcast_output_shapetype = + tfl_broadcast_to_op.getOutput().getType().cast(); + int num_leading_broadcast_dims = + GetNumLeadingOnes(broadcast_output_shapetype); + int num_trailing_broadcast_dims = + GetNumTrailingOnes(broadcast_output_shapetype); + + // Get the new shape for the inner reshape_op after removing the extra 1s. + llvm::SmallVector new_reshape_shape_i32{ + inner_reshape_op.getOutput() + .getType() + .cast() + .getShape() + .drop_back(num_trailing_broadcast_dims) + .drop_front(num_leading_broadcast_dims)}; + + Value new_reshape_shape_value = rewriter.create( + inner_reshape_op->getLoc(), + GetI32ElementsAttr(new_reshape_shape_i32, &rewriter)); + + auto new_inner_reshape_op = rewriter.create( + inner_reshape_op->getLoc(), + inner_reshape_input, new_reshape_shape_value); + + // Create a new reshape_op to replace the old inner reshape_op. + rewriter.replaceOp(inner_reshape_op, new_inner_reshape_op.getResult()); + + // Get the new shape for the broadcast_op after removing the extra 1s. + llvm::SmallVector new_broadcast_shape{ + broadcast_output_shapetype.getShape() + .drop_back(num_trailing_broadcast_dims) + .drop_front(num_leading_broadcast_dims)}; + + Value new_broadcast_shape_value = rewriter.create( + loc, GetI64ElementsAttr(new_broadcast_shape, &rewriter)); + + auto new_broadcast_to_op = rewriter.create( + loc, RankedTensorType::get(new_broadcast_shape, rewriter.getF32Type()), + new_inner_reshape_op.getOutput(), new_broadcast_shape_value); + + // Create a new broadcast_op to replace the old broadcast_op. + rewriter.replaceOp(tfl_broadcast_to_op, new_broadcast_to_op.getResult()); + + return success(); + } +}; + +// This pattern matches TFL::BroadcastToOp WITH TENSOR RANK <= 4 and replaces +// it with a MulOp that multiplies the tensor by a splat constant with 1s. +struct ConvertTFLBroadcastToMulOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TFL::BroadcastToOp tfl_broadcast_to_op, + PatternRewriter &rewriter) const override { + auto input_type = + tfl_broadcast_to_op.getInput().getType().cast(); + auto output_type = + tfl_broadcast_to_op.getOutput().getType().cast(); + auto shape_type = + tfl_broadcast_to_op.getShape().getType().cast(); + Type element_type = input_type.getElementType(); + + auto loc = tfl_broadcast_to_op->getLoc(); + + // Check that the output type is not dynamic and is less-than-equal to 4D or + // the shape type is static, 1D and has less-than-equal to 4 elements. + bool is_output_shape_dynamic = + (!output_type.hasRank() || (output_type.getRank() > 4) || + (output_type.getNumDynamicDims() > 0)); + bool is_broadcast_shape_dynamic = + (!shape_type.hasStaticShape() || (shape_type.getRank() != 1) || + (shape_type.getDimSize(0) > 4)); + if (is_output_shape_dynamic && is_broadcast_shape_dynamic) + return rewriter.notifyMatchFailure( + loc, "output_rank or broadcast_to shape not supported"); + + // Allow lowering when the input's elements type is F32, BFloat16, I32 or + // I16. + if (!(element_type.isa() || + element_type.isInteger(32) || element_type.isInteger(16))) + return rewriter.notifyMatchFailure(loc, "element_type_not_supported"); + + // TFL_FillOp is created only if is_output_shape_dynamic is true, otherwise + // a Arith.ConstOp is created. + if (is_output_shape_dynamic && + output_type.getElementType().isUnsignedInteger()) { + return rewriter.notifyMatchFailure( + loc, + "Unsigned broadcast_to output with dynamic shape is not supported"); + } + + Value mul_rhs_value; + if (!output_type.hasRank() || (output_type.getNumDynamicDims() > 0)) { + auto status_or_const_op = + CreateConstOpWithSingleValue(&rewriter, loc, input_type, 1); + if (!status_or_const_op.ok()) { + return failure(); + } + + mul_rhs_value = rewriter.create( + loc, output_type, tfl_broadcast_to_op.getShape(), + status_or_const_op.value()); + } else { + auto status_or_const_op = + CreateConstOpWithVectorValue(&rewriter, loc, output_type, 1); + if (!status_or_const_op.ok()) { + return failure(); + } + + mul_rhs_value = status_or_const_op.value(); + } + + auto mul_op = rewriter.create( + loc, output_type, tfl_broadcast_to_op.getInput(), mul_rhs_value, + rewriter.getStringAttr("NONE")); + rewriter.replaceOp(tfl_broadcast_to_op, mul_op.getResult()); + return success(); + } +}; + struct FuseAddAndStridedSlice : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -886,14 +1133,9 @@ struct Convert2DUpscalingToResizeNearestNeighor SmallVector reshape_shape_in_int64( {1, image_size, image_size, feature_size}); - auto reshape_shape_type = - RankedTensorType::get({static_cast(reshape_shape.size())}, - rewriter.getIntegerType(32)); - auto reshape_shape_attr = - DenseIntElementsAttr::get(reshape_shape_type, reshape_shape); - auto reshape_shape_const_op = rewriter.create( - gather_nd_first->getLoc(), reshape_shape_attr); + gather_nd_first->getLoc(), + GetI32ElementsAttr(reshape_shape, &rewriter)); auto reshape_op = rewriter.create( gather_nd_first->getLoc(), @@ -903,12 +1145,8 @@ struct Convert2DUpscalingToResizeNearestNeighor // Add TFL::resize_nearest_neighor op for 2x upscaling. SmallVector size_vec = {image_size * 2, image_size * 2}; - auto size_type = mlir::RankedTensorType::get( - {static_cast(size_vec.size())}, rewriter.getIntegerType(32)); - auto size_attr = mlir::DenseIntElementsAttr::get(size_type, size_vec); - - auto size_const_op = - rewriter.create(gather_nd_first->getLoc(), size_attr); + auto size_const_op = rewriter.create( + gather_nd_first->getLoc(), GetI32ElementsAttr(size_vec, &rewriter)); auto resize = rewriter.create( gather_nd_first->getLoc(), transpose_second.getResult().getType(), @@ -1765,11 +2003,9 @@ struct ConvertTrivialTransposeOpToReshapeOp output_shape_values.push_back( ShapedType::isDynamic(dim) ? -1 : static_cast(dim)); } - auto type = mlir::RankedTensorType::get(output_shape_values.size(), - rewriter.getIntegerType(32)); - auto new_shape_attr = - mlir::DenseIntElementsAttr::get(type, output_shape_values); - auto new_shape = rewriter.create(loc, new_shape_attr); + + auto new_shape = rewriter.create( + loc, GetI32ElementsAttr(output_shape_values, &rewriter)); rewriter.replaceOpWithNewOp( transpose_op, transpose_op.getOutput().getType(), @@ -1938,11 +2174,7 @@ struct FuseUnpackAndConcatToReshape ShapedType::isDynamic(size) ? -1 : static_cast(size)); } auto new_shape = rewriter.create( - concat_op.getLoc(), - DenseIntElementsAttr::get( - RankedTensorType::get(new_shape_array_i32.size(), - rewriter.getIntegerType(32)), - new_shape_array_i32)); + concat_op.getLoc(), GetI32ElementsAttr(new_shape_array_i32, &rewriter)); rewriter.replaceOpWithNewOp( concat_op, output_type, unpack_op.getInput(), new_shape); @@ -2132,9 +2364,7 @@ struct FuseReshapeAndTransposeAroundBatchMatmul transpose_input.getType().getShape().begin() + 2, transpose_input.getType().getShape().end(), 1, std::multiplies()))}; auto shape_constant = rewriter.create( - batch_matmul.getLoc(), - DenseIntElementsAttr::get( - RankedTensorType::get(3, rewriter.getI32Type()), new_shape)); + batch_matmul.getLoc(), GetI32ElementsAttr(new_shape, &rewriter)); auto reshaped_input = rewriter.create( batch_matmul.getLoc(), transpose_op.getInput(), shape_constant); rewriter.replaceOpWithNewOp( @@ -2196,10 +2426,7 @@ struct FuseTransposeReshapeIntoBatchMatmul reshape_op.getType().getShape().drop_front().end()); new_shape.push_back(reshape_op.getType().getDimSize(0)); auto shape_constant = rewriter.create( - op.getLoc(), DenseIntElementsAttr::get( - RankedTensorType::get(reshape_op.getType().getRank(), - rewriter.getI32Type()), - new_shape)); + op.getLoc(), GetI32ElementsAttr(new_shape, &rewriter)); auto new_reshape = rewriter.create( op.getLoc(), transpose_op.getInput(), shape_constant); rewriter.replaceOpWithNewOp( @@ -2427,8 +2654,8 @@ void OptimizePass::runOnOperation() { // binary ops. RewritePatternSet phase_0_patterns(&getContext()); phase_0_patterns - .add( - ctx); + .add(ctx); (void)applyPatternsAndFoldGreedily(func, std::move(phase_0_patterns)); // Potentially the binary ops might be fused together, like hard_swish, thus diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 114a01492ff16b..008decb62b0d55 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -1129,6 +1129,26 @@ def OptimizeSliceOp : Pat< (replaceWithValue $input), [(CanOptimizeIdentitySliceOp $input, $begin, $size)]>; +// Convert the StridedSliceOp to a SliceOp when possible. This will enable other +// optimizations on SliceOp to run. +def OptimizeStridedSlice : Pat< + (TFL_StridedSliceOp $input, + (Arith_ConstantOp $begin), + (Arith_ConstantOp $end), + (Arith_ConstantOp $stride), + ZeroIntAttr:$_, + ZeroIntAttr:$_, + ZeroIntAttr:$_, + ZeroIntAttr:$_, + ZeroIntAttr:$_, + ConstBoolAttrFalse), + (TFL_SliceOp $input, + (Arith_ConstantOp $begin), + (Arith_ConstantOp (GetOffSet $begin, $end))), + [(IsAllOnesConstant $stride), + (HasNonNegativeValues $begin), + (HasNonNegativeOffset $begin, $end)]>; + def GetNumElementsOrOne: NativeCodeCall<"GetNumElementsOrOne($0.getType())">; def ReshapeValueDroppingLastDim : NativeCodeCall< @@ -1510,14 +1530,7 @@ def FuseReshapesAroundBatchMatMulLHS1: Pat< (BroadcastDimsProductEqual<1> $final_shape_change, $bmm_tmp_output), (AreTensorSubSectionShapesEqual<1, 1> $input, $final_shape_change)]>; -// Fuse redundant TFL_TransposeOp into TFL_BatchMatMulOp -def FuseTransposeIntoBatchMatMulRHS: Pat< - (TFL_BatchMatMulOp $lhs, - (TFL_TransposeOp:$transposed_value $input, (Arith_ConstantOp:$perm_value $p0)), - $adj_x, $adj_y, $asymmetric_quantize_inputs), - (TFL_BatchMatMulOp $lhs, $input, $adj_x, ConstBoolAttrTrue, $asymmetric_quantize_inputs), - [(AreLastTwoDimsTransposed $perm_value), - (IsBoolAttrEqual<"false"> $adj_y)]>; + // Fuse redundant TFL_TransposeOp into TFL_BatchMatMulOp def FuseTransposeIntoBatchMatMulLHS: Pat< diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.td b/tensorflow/compiler/mlir/lite/transforms/passes.td index 4dfea319b336d7..45428d2648a43f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.td +++ b/tensorflow/compiler/mlir/lite/transforms/passes.td @@ -309,6 +309,8 @@ def PrepareQuantizePass : Pass<"tfl-prepare-quantize", "mlir::func::FuncOp"> { "disable-set-input-nodes-quantization-params", "bool", "false", "Whether disable set input nodes quantization parameters.">, + Option<"is_qdq_conversion_", "is-qdq-conversion", "bool", "false", + "Whether the source graph is a QDQ model intended for conversion only.">, ]; } @@ -323,6 +325,9 @@ def PrepareDynamicRangeQuantizePass : Pass<"tfl-prepare-quantize-dynamic-range", Option<"enable_dynamic_range_per_channel_quantization_", "enable-dynamic-range-per-channel-quantization", "bool", "true", "Whether enable per-channel quantized weights.">, + Option<"enable_dynamic_range_per_channel_quantization_for_dense_layers_", + "enable-dynamic-range-per-channel-quantization-for-dense-layers", "bool", + "true", "Whether enable per-channel quantized weights for Fully Connected layers (default is per tensor).">, Option<"min_elements_for_weights_", "min-elements-for-weights", "int64_t", "1024", "The minimum number of elements in a weights array required to apply quantization.">, diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc index 807d1c2dfaa2b5..fb613d74bbfaa2 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc @@ -14,7 +14,9 @@ limitations under the License. ==============================================================================*/ // This transformation pass applies quantization propagation on TFLite dialect. +#include #include +#include #include #include #include @@ -217,7 +219,9 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(func::FuncOp func) { #include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc" bool PrepareQuantizePass::RemoveRedundantStats(func::FuncOp func) { - return RemoveRedundantStatsOps(func, GetOpQuantSpec); + return RemoveRedundantStatsOps( + func, std::bind(GetOpQuantSpec, std::placeholders::_1, + quant_specs_.disable_per_channel_for_dense_layers)); } static Value Quantized(Operation* user) { @@ -402,12 +406,21 @@ void PrepareQuantizePass::runOnOperation() { SanityCheckAndAdjustment(func); + // Bind the getter with the fixed configuration parameter for the correct + // quantization settings of the ops. + std::function(Operation*)> + op_quant_spec_getter = + std::bind(GetOpQuantSpec, std::placeholders::_1, + quant_specs_.disable_per_channel_for_dense_layers); + // Finally, the quantization parameters can be propagated to the rest of the // values (tensors). ApplyQuantizationParamsPropagation( func, is_signed, bit_width, - disable_per_channel_ || quant_specs_.disable_per_channel, GetOpQuantSpec, - infer_tensor_range, quant_specs_.legacy_float_scale); + disable_per_channel_ || quant_specs_.disable_per_channel, + op_quant_spec_getter, infer_tensor_range, quant_specs_.legacy_float_scale, + (is_qdq_conversion_ || + quant_specs_.qdq_conversion_mode != quant::QDQConversionMode::kQDQNone)); } } // namespace diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc index 951748b31273f3..a60ebe57212f9e 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc @@ -72,6 +72,8 @@ class PrepareDynamicRangeQuantizePass : quant_specs_(quant_specs) { enable_dynamic_range_per_channel_quantization_ = !quant_specs_.disable_per_channel; + enable_dynamic_range_per_channel_quantization_for_dense_layers_ = + !quant_specs_.disable_per_channel_for_dense_layers; min_elements_for_weights_ = quant_specs_.minimum_elements_for_weights; } @@ -275,6 +277,10 @@ class PrepareDynamicRangeQuantizableOp op_with_per_axis_support = op_with_narrow_range && affine_user.GetQuantizationDimIndex() != -1 && !quant_specs_.disable_per_channel; + if (dyn_cast(quantize_op)) { + op_with_per_axis_support &= + !quant_specs_.disable_per_channel_for_dense_layers; + } } QuantizedType quant_type = nullptr; @@ -473,6 +479,8 @@ void PrepareDynamicRangeQuantizePass::runOnOperation() { quant_specs_.disable_per_channel = !enable_dynamic_range_per_channel_quantization_; + quant_specs_.disable_per_channel_for_dense_layers = + !enable_dynamic_range_per_channel_quantization_for_dense_layers_; quant_specs_.minimum_elements_for_weights = min_elements_for_weights_; if (!enable_custom_op_quantization_.empty()) { diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h index 90a48d577ef669..216c6756ab67db 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h @@ -505,9 +505,10 @@ class ConvertLstmStatsToQDQs : public ConvertOpStatsToQDQs { inline quant::AccumulatorScaleFunc GetUniformQuantizedTypeForBiasWithScale( double scale) { return [=](const std::vector& quant_params, - bool legacy_float_scale) -> quant::QuantParams { - if (auto qtype = quant::GetUniformQuantizedTypeForBias(quant_params, - legacy_float_scale) + const int adjusted_quant_dim, + const bool legacy_float_scale) -> quant::QuantParams { + if (auto qtype = quant::GetUniformQuantizedTypeForBias( + quant_params, legacy_float_scale, adjusted_quant_dim) .dyn_cast_or_null()) { return quant::UniformQuantizedType::get( qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(), diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index c80be89c567e09..2e920595819f84 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -39,7 +39,6 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSwitch.h" -#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" // from @llvm-project #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project @@ -780,48 +779,6 @@ struct ConvertTFStridedSlice : public RewritePattern { } }; -struct ConvertTFBroadcastTo : public RewritePattern { - explicit ConvertTFBroadcastTo(MLIRContext *context) - : RewritePattern(TF::BroadcastToOp::getOperationName(), 1, context) {} - - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { - auto tf_broadcast_to_op = cast(op); - auto input_type = - tf_broadcast_to_op.getInput().getType().cast(); - auto output_type = - tf_broadcast_to_op.getOutput().getType().cast(); - auto shape_type = - tf_broadcast_to_op.getShape().getType().cast(); - Type element_type = input_type.getElementType(); - - // Allow lowering when low dimension inputs are given and its type is F32 or - // I32. - if (!((output_type.hasRank() && output_type.getRank() <= 4) || - (shape_type.hasStaticShape() && shape_type.getRank() == 1 && - shape_type.getDimSize(0) <= 4))) - return failure(); - - if (!(element_type.isa() || - element_type.isInteger(32) || element_type.isInteger(16))) - return failure(); - - auto status_or_const_op = - CreateConstOpWithSingleValue(&rewriter, op->getLoc(), input_type, 1); - if (!status_or_const_op.ok()) { - return failure(); - } - - auto tf_fill_op = rewriter.create(op->getLoc(), output_type, - tf_broadcast_to_op.getShape(), - status_or_const_op.value()); - - auto mul_op = rewriter.create( - op->getLoc(), output_type, tf_broadcast_to_op.getInput(), tf_fill_op); - rewriter.replaceOp(op, mul_op.getResult()); - return success(); - } -}; // The below pattern is equivalent to the DRR rule below // The checks are dependent on generated values, so we can't add @@ -1591,9 +1548,8 @@ void PrepareTFPass::runOnOperation() { if (unfold_batch_matmul_) { TF::PopulateUnrollTfBatchMatMul(ctx, phase_2_patterns); } - phase_2_patterns - .add(ctx); + phase_2_patterns.add(ctx); phase_2_patterns.add( ctx, allow_bf16_and_f16_type_legalization_); // Remove redundant reshape ops. diff --git a/tensorflow/compiler/mlir/lite/utils/constant_utils.cc b/tensorflow/compiler/mlir/lite/utils/constant_utils.cc index 86d0509ceb7e65..0e5d10e9e7469a 100644 --- a/tensorflow/compiler/mlir/lite/utils/constant_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/constant_utils.cc @@ -15,39 +15,47 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/utils/constant_utils.h" +#include +#include #include #include +#include "absl/status/status.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/platform/status.h" +#include "tsl/platform/statusor.h" namespace mlir { namespace TFL { -tsl::StatusOr CreateConstOpWithSingleValue( - PatternRewriter* rewriter, Location loc, ShapedType shaped_type, - int value) { +tsl::StatusOr CreateTypedAttr(ShapedType shaped_type, int value) { Type element_type = shaped_type.getElementType(); - ShapedType scalar_type = RankedTensorType::get({}, element_type); - TypedAttr attr; if (element_type.isF16()) { auto floatType = mlir::FloatType::getF16(element_type.getContext()); auto floatAttr = mlir::FloatAttr::get(floatType, static_cast(value)); std::vector floatValues({floatAttr}); - attr = DenseElementsAttr::get(scalar_type, floatValues); + return DenseElementsAttr::get(shaped_type, floatValues); } else if (element_type.isBF16()) { auto floatType = mlir::FloatType::getBF16(element_type.getContext()); auto floatAttr = mlir::FloatAttr::get(floatType, static_cast(value)); std::vector floatValues({floatAttr}); - attr = DenseElementsAttr::get(scalar_type, floatValues); + return DenseElementsAttr::get(shaped_type, floatValues); } else if (element_type.isF32()) { - attr = - DenseElementsAttr::get(scalar_type, static_cast(value)); + return DenseElementsAttr::get(shaped_type, + static_cast(value)); } else if (auto complex_type = element_type.dyn_cast()) { auto etype = complex_type.getElementType(); if (etype.isF32()) { @@ -64,7 +72,7 @@ tsl::StatusOr CreateConstOpWithSingleValue( repr.set_tensor_content(content); std::string mangled = tensorflow::mangling_util::MangleTensor(repr); - attr = mlir::TF::TensorProtoAttr::get(scalar_type, mangled); + return mlir::TF::TensorProtoAttr::get(shaped_type, mangled); } else { return tensorflow::Status(absl::StatusCode::kInvalidArgument, "Unsupported type"); @@ -73,19 +81,19 @@ tsl::StatusOr CreateConstOpWithSingleValue( if (element_type.isSignedInteger()) { switch (itype.getWidth()) { case 8: - attr = DenseElementsAttr::get(scalar_type, + return DenseElementsAttr::get(shaped_type, static_cast(value)); break; case 16: - attr = DenseElementsAttr::get(scalar_type, + return DenseElementsAttr::get(shaped_type, static_cast(value)); break; case 32: - attr = DenseElementsAttr::get(scalar_type, + return DenseElementsAttr::get(shaped_type, static_cast(value)); break; case 64: - attr = DenseElementsAttr::get(scalar_type, + return DenseElementsAttr::get(shaped_type, static_cast(value)); break; default: @@ -95,19 +103,19 @@ tsl::StatusOr CreateConstOpWithSingleValue( } else { switch (itype.getWidth()) { case 8: - attr = DenseElementsAttr::get(scalar_type, + return DenseElementsAttr::get(shaped_type, static_cast(value)); break; case 16: - attr = DenseElementsAttr::get(scalar_type, + return DenseElementsAttr::get(shaped_type, static_cast(value)); break; case 32: - attr = DenseElementsAttr::get(scalar_type, + return DenseElementsAttr::get(shaped_type, static_cast(value)); break; case 64: - attr = DenseElementsAttr::get(scalar_type, + return DenseElementsAttr::get(shaped_type, static_cast(value)); break; default: @@ -119,8 +127,29 @@ tsl::StatusOr CreateConstOpWithSingleValue( return tensorflow::Status(absl::StatusCode::kInvalidArgument, "Unsupported type"); } +} + +// Returns a Constant op with a splat vector value. +tsl::StatusOr CreateConstOpWithVectorValue( + PatternRewriter* rewriter, Location loc, ShapedType shaped_type, + int value) { + ShapedType dense_type = RankedTensorType::get(shaped_type.getShape(), + shaped_type.getElementType()); + auto attr = CreateTypedAttr(dense_type, value); + + return rewriter->create(loc, dense_type, + cast(*attr)); +} + +tsl::StatusOr CreateConstOpWithSingleValue( + PatternRewriter* rewriter, Location loc, ShapedType shaped_type, + int value) { + ShapedType scalar_type = + RankedTensorType::get({}, shaped_type.getElementType()); + auto attr = CreateTypedAttr(scalar_type, value); + return rewriter->create(loc, scalar_type, - cast(attr)); + cast(*attr)); } } // namespace TFL diff --git a/tensorflow/compiler/mlir/lite/utils/constant_utils.h b/tensorflow/compiler/mlir/lite/utils/constant_utils.h index 1a71bd55a85e8a..f062e31a557d17 100644 --- a/tensorflow/compiler/mlir/lite/utils/constant_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/constant_utils.h @@ -31,6 +31,10 @@ namespace TFL { tsl::StatusOr CreateConstOpWithSingleValue( PatternRewriter* rewriter, Location loc, ShapedType shaped_type, int value); +// Returns a Constant op with a splat vector value. +tsl::StatusOr CreateConstOpWithVectorValue( + PatternRewriter* rewriter, Location loc, ShapedType shaped_type, int value); + } // namespace TFL } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_ diff --git a/tensorflow/compiler/mlir/lite/utils/utils.h b/tensorflow/compiler/mlir/lite/utils/utils.h index 6130bab6531ba2..9fce1bc44387c3 100644 --- a/tensorflow/compiler/mlir/lite/utils/utils.h +++ b/tensorflow/compiler/mlir/lite/utils/utils.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,16 +16,20 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_UTILS_H_ +#include #include #include #include #include "llvm/ADT/ArrayRef.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { @@ -58,6 +62,44 @@ inline bool OpHasSameStaticShapes(Operation* op) { return true; } +// Checks if all elements in the constant attribute value are 1. +inline bool IsAllOnesConstant(Attribute value) { + auto values = value.cast().getValues(); + return !std::any_of(values.begin(), values.end(), + [](int32_t element_value) { return element_value != 1; }); +} + +// Checks if all elements in the constant attribute value are non-negative. +inline bool HasNonNegativeValues(Attribute value) { + auto values = value.cast().getValues(); + return !std::any_of( + values.begin(), values.end(), + [](const APInt& element_value) { return element_value.isNegative(); }); +} + +// Utility function to get the offset between two dense attribute values. +inline TypedAttr GetOffSet(Attribute begin, Attribute end) { + auto begin_values = begin.cast().getValues(); + auto end_values = end.cast().getValues(); + + SmallVector offsets; + if (begin_values.size() == end_values.size()) { + for (size_t i = 0; i < begin_values.size(); ++i) { + offsets.push_back(end_values[i] - begin_values[i]); + } + } + + return mlir::DenseElementsAttr::get( + RankedTensorType::get({static_cast(offsets.size())}, + mlir::IntegerType::get(begin.getContext(), 32)), + llvm::ArrayRef(offsets)); +} + +// Check if the offset between two dense attribute values is non-negative. +inline bool HasNonNegativeOffset(Attribute begin, Attribute end) { + return HasNonNegativeValues(GetOffSet(begin, end)); +} + // Return true if the permutation value only swaps the last two dimensions inline bool AreLastTwoDimsTransposed(Value permutation) { if (!permutation) return false; diff --git a/tensorflow/compiler/mlir/lite/utils/utils.td b/tensorflow/compiler/mlir/lite/utils/utils.td index e64b591ae78eda..42af8c67b2a7ce 100644 --- a/tensorflow/compiler/mlir/lite/utils/utils.td +++ b/tensorflow/compiler/mlir/lite/utils/utils.td @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,6 +23,23 @@ include "mlir/IR/PatternBase.td" // if called without a ranked tensor it will fail. def GetShape: NativeCodeCall<"GetShape($0)">; +// Constraint that values in list attribute are all ones. +def IsAllOnesConstant : Constraint>; + +// Constraint that checks if all values in offset between two +// attributes are non-negative. +def HasNonNegativeOffset : Constraint>; + +// Constraint that checks if all values in list attribute are non-negative. +def HasNonNegativeValues : Constraint>; + +// Utility function to get the offset between two dense attribute values. +def GetOffSet : NativeCodeCall<"TFL::GetOffSet($0, $1)">; + +// Attribute Constraint that checks if the attribute value is zero. +def ZeroIntAttr + : AttrConstraint().getInt() == 0">>; + // Checks if the value has rank at most 'n'. class HasRankAtLeast : Constraint< CPred<"$0.getType().cast().hasRank() && " diff --git a/tensorflow/compiler/mlir/quantization/common/BUILD b/tensorflow/compiler/mlir/quantization/common/BUILD index e15f729ff814e5..421b3df68642b2 100644 --- a/tensorflow/compiler/mlir/quantization/common/BUILD +++ b/tensorflow/compiler/mlir/quantization/common/BUILD @@ -52,6 +52,7 @@ tf_cc_test( name = "lift_as_function_call_test", srcs = ["lift_as_function_call_test.cc"], deps = [ + ":func", ":lift_as_function_call", ":test_base", "//tensorflow/compiler/mlir/tensorflow", @@ -65,6 +66,35 @@ tf_cc_test( ], ) +cc_library( + name = "func", + srcs = ["func.cc"], + hdrs = ["func.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/cc/saved_model:signature_constants", + "//tensorflow/compiler/mlir/tensorflow:import_model", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +tf_cc_test( + name = "func_test", + srcs = ["func_test.cc"], + compatible_with = get_compatible_with_portable(), + deps = [ + ":func", + ":test_base", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + ], +) + cc_library( name = "test_base", testonly = 1, @@ -101,9 +131,9 @@ cc_library( "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:xla_call_module_attrs", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", - "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:Support", ], ) @@ -113,11 +143,12 @@ tf_cc_test( srcs = ["attrs_and_constraints_test.cc"], deps = [ ":attrs_and_constraints", + ":func", ":test_base", - "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/tensorflow", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", diff --git a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc index 9c3192d5345f28..1d2ccbdaaf4d2b 100644 --- a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc +++ b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc @@ -14,15 +14,18 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" +#include + #include "llvm/ADT/STLExtras.h" -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" // IWYU pragma: keep namespace mlir::quant { @@ -59,4 +62,30 @@ SmallVector CloneOpWithReplacedOperands( return builder.clone(*op, mapping)->getResults(); } +FailureOr CastI64ToI32(const int64_t value) { + if (!llvm::isInt<32>(value)) { + DEBUG_WITH_TYPE( + "mlir-quant-attrs-and-constraints", + llvm::dbgs() + << "Tried to cast " << value + << "from int64 to int32, but lies out of range of int32.\n"); + return failure(); + } + return static_cast(value); +} + +FailureOr> CastI64ArrayToI32( + const ArrayRef int64_array) { + SmallVector int32_array{}; + int32_array.reserve(int64_array.size()); + + for (const int64_t i64 : int64_array) { + FailureOr cast_i32 = CastI64ToI32(i64); + if (failed(cast_i32)) return failure(); + + int32_array.push_back(*cast_i32); + } + return int32_array; +} + } // namespace mlir::quant diff --git a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h index 8f298b56ec947e..2600a0547c563d 100644 --- a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h +++ b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h" namespace mlir::quant { @@ -150,6 +151,36 @@ FailureOr TryCast(Operation *op, const StringRef name) { } } +FailureOr CastI64ToI32(int64_t value); + +// Tries to cast an array of int64 to int32. If any of the element in the +// array is not in the range of int32, returns failure(). +FailureOr> CastI64ArrayToI32( + ArrayRef int64_array); + +// Returns the first user of the given operation, optionally of the given +// type if provided. If there is no user or user of type, return nullptr. +template +Operation *FindUserOfType(Operation *op) { + for (Operation *user : op->getUsers()) { + if (isa(user)) { + return user; + } + } + return nullptr; +} + +// Returns the function attribute for the given call op which is lifted for +// quantization. +inline FlatSymbolRefAttr GetFuncAttr(TF::PartitionedCallOp call_op) { + return call_op.getFAttr().template dyn_cast(); +} + +inline FlatSymbolRefAttr GetFuncAttr(TF::XlaCallModuleOp call_op) { + return call_op->getAttrOfType( + TF::kStablehloEntryFunctionAttrName); +} + } // namespace mlir::quant #endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_ATTRS_AND_CONSTRAINTS_H_ diff --git a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td index 09bcfe4533d61b..2a540fafe9aa0b 100644 --- a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td +++ b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td @@ -69,6 +69,10 @@ def IsUniformQuantizedType : Constraint< def AreTheSameElementType : Constraint< CPred<"$0.getType() == $1.getType()">>; +// Checks if the given two values are the same. +def AreTheSameValue : Constraint< + CPred<"$0 == $1">>; + // Checks if the value has rank. def HasRank : Constraint< CPred<"$0.getType().cast().hasRank()">>; diff --git a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc index afa7d44e1b595e..2c466b4415818b 100644 --- a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc +++ b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc @@ -15,8 +15,12 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" +#include + +#include #include #include "absl/strings/string_view.h" +#include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project @@ -25,20 +29,26 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/func.h" #include "tensorflow/compiler/mlir/quantization/common/test_base.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" namespace mlir::quant { namespace { using ::mlir::quant::QuantizationTestBase; using ::mlir::stablehlo::AddOp; +using ::mlir::stablehlo::ConvolutionOp; using ::mlir::stablehlo::DotGeneralOp; +using ::mlir::stablehlo::SubtractOp; +using ::testing::ElementsAreArray; +using ::testing::NotNull; class AttrsAndConstraintsTest : public QuantizationTestBase {}; constexpr absl::string_view kModuleStatic = R"mlir( module { - func.func private @main(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + func.func @main(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> return %0 : tensor<1x3xf32> } @@ -47,16 +57,57 @@ constexpr absl::string_view kModuleStatic = R"mlir( constexpr absl::string_view kModuleDynamic = R"mlir( module { - func.func private @main(%arg0: tensor, %arg1: tensor<1024x3xf32>) -> tensor attributes {_from_xla_call_module} { + func.func @main(%arg0: tensor, %arg1: tensor<1024x3xf32>) -> tensor attributes {_from_xla_call_module} { %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [] : (tensor, tensor<1024x3xf32>) -> tensor return %0 : tensor } } )mlir"; +constexpr absl::string_view kModuleMultipleUses = R"mlir( + module { + func.func @main(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + %1 = stablehlo.subtract %0, %arg2 : tensor<1x3xf32> + %2 = stablehlo.add %0, %arg2 : tensor<1x3xf32> + return %2 : tensor<1x3xf32> + } + } +)mlir"; + +constexpr absl::string_view kModuleXlaCallModule = R"mlir( + module { + func.func @main(%arg0: tensor {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor) { + %0 = stablehlo.constant dense<[-0.211145893, -0.708605706]> : tensor<2xf32> + %1 = stablehlo.constant dense<[[-0.630731344, 0.54962182], [0.180364341, -0.764542698]]> : tensor<2x2xf32> + %2 = "tf.XlaCallModule"(%arg0, %1, %0) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_fn_1, _original_entry_function = "composite_fn_1", _tfl_quant_trait = "fully_quantizable"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor + return %2 : tensor + } + func.func private @composite_fn_1(%arg0: tensor, %arg1: tensor<2x2xf32>, %arg2: tensor<2xf32>) -> tensor attributes {_from_xla_call_module, tf_quant.composite_function} { + return %arg0 : tensor + } + } +)mlir"; + +constexpr absl::string_view kModulePartitionedCall = R"mlir( + module { + func.func @main(%arg0: tensor<2x2xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<2x2xf32>) { + %cst = "tf.Const"() {device = "", value = dense<[[-0.630731344, 0.54962182], [0.180364341, -0.764542698]]> : tensor<2x2xf32>} : () -> tensor<2x2xf32> + %0 = "tf.PartitionedCall"(%arg0, %cst) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_fn_1} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> loc(callsite("test@main"("MatMul") at "QuantizationUnit(\12\06MatMul\1a\07main)")) + return %0 : tensor<2x2xf32> + } + func.func private @composite_fn_1(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) {attr_map = "0:transpose_a,1:transpose_b", device = "", transpose_a = false, transpose_b = false} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> + } + } +)mlir"; + TEST_F(AttrsAndConstraintsTest, HasStaticShapeSucceedsWithStaticShapes) { OwningOpRef module_op_ref = ParseModuleOpString(kModuleStatic); - func::FuncOp main_fn = GetFunctionFromModule(*module_op_ref, "main"); + func::FuncOp main_fn = FindMainFuncOp(*module_op_ref); + ASSERT_THAT(main_fn, NotNull()); + Value dot_general_result = FindOperationOfType(main_fn)->getResult(0); EXPECT_TRUE(HasStaticShape(dot_general_result)); @@ -66,7 +117,9 @@ TEST_F(AttrsAndConstraintsTest, HasStaticShapeSucceedsWithStaticShapes) { TEST_F(AttrsAndConstraintsTest, HasStaticShapeFailsWithDynamicShapes) { OwningOpRef module_op_ref = ParseModuleOpString(kModuleDynamic); - func::FuncOp main_fn = GetFunctionFromModule(*module_op_ref, "main"); + func::FuncOp main_fn = FindMainFuncOp(*module_op_ref); + ASSERT_THAT(main_fn, NotNull()); + Value dot_general_result = FindOperationOfType(main_fn)->getResult(0); EXPECT_FALSE(HasStaticShape(dot_general_result)); @@ -76,7 +129,9 @@ TEST_F(AttrsAndConstraintsTest, HasStaticShapeFailsWithDynamicShapes) { TEST_F(AttrsAndConstraintsTest, TryCastSucceeds) { OwningOpRef module_op_ref = ParseModuleOpString(kModuleStatic); - func::FuncOp main_fn = GetFunctionFromModule(*module_op_ref, "main"); + func::FuncOp main_fn = FindMainFuncOp(*module_op_ref); + ASSERT_THAT(main_fn, NotNull()); + Operation* dot_general_op = FindOperationOfType(main_fn); EXPECT_TRUE(succeeded( TryCast(dot_general_op, /*name=*/"dot_general_op"))); @@ -84,7 +139,9 @@ TEST_F(AttrsAndConstraintsTest, TryCastSucceeds) { TEST_F(AttrsAndConstraintsTest, TryCastFailsOnWrongType) { OwningOpRef module_op_ref = ParseModuleOpString(kModuleStatic); - func::FuncOp main_fn = GetFunctionFromModule(*module_op_ref, "main"); + func::FuncOp main_fn = FindMainFuncOp(*module_op_ref); + ASSERT_THAT(main_fn, NotNull()); + Operation* dot_general_op = FindOperationOfType(main_fn); EXPECT_TRUE( failed(TryCast(dot_general_op, /*name=*/"dot_general_op"))); @@ -92,7 +149,9 @@ TEST_F(AttrsAndConstraintsTest, TryCastFailsOnWrongType) { TEST_F(AttrsAndConstraintsTest, TryCastFailsOnNullPtr) { OwningOpRef module_op_ref = ParseModuleOpString(kModuleStatic); - func::FuncOp main_fn = GetFunctionFromModule(*module_op_ref, "main"); + func::FuncOp main_fn = FindMainFuncOp(*module_op_ref); + ASSERT_THAT(main_fn, NotNull()); + Operation* op_nullptr = FindOperationOfType(main_fn)->getNextNode()->getNextNode(); // getNextNode() returns a nullptr if at the very last node. @@ -101,5 +160,74 @@ TEST_F(AttrsAndConstraintsTest, TryCastFailsOnNullPtr) { EXPECT_TRUE(failed(TryCast(nullptr, /*name=*/"nullptr"))); } +TEST_F(AttrsAndConstraintsTest, I64ValueInI32RangeAreCastedCorrectly) { + EXPECT_TRUE(succeeded(CastI64ToI32(llvm::minIntN(32)))); + EXPECT_TRUE(succeeded(CastI64ToI32(llvm::maxIntN(32)))); +} + +TEST_F(AttrsAndConstraintsTest, CastingFailsForI64ValueOutOfI32Range) { + EXPECT_TRUE(failed(CastI64ToI32(llvm::minIntN(32) - 10))); + EXPECT_TRUE(failed(CastI64ToI32(llvm::maxIntN(32) + 10))); +} + +TEST_F(AttrsAndConstraintsTest, I64ArrayInI32RangeAreCastedCorrectly) { + const SmallVector array_i64 = {llvm::minIntN(32), -2, -1, 0, 1, 2, + llvm::maxIntN(32)}; + + FailureOr> array_i32 = CastI64ArrayToI32(array_i64); + EXPECT_TRUE(succeeded(array_i32)); + EXPECT_THAT( + *array_i32, + ElementsAreArray({static_cast(llvm::minIntN(32)), -2, -1, 0, 1, + 2, static_cast(llvm::maxIntN(32))})); +} + +TEST_F(AttrsAndConstraintsTest, CastingFailsForI64ArrayUnderI32Range) { + const int64_t under_min_i32 = -2147483658; + ArrayRef array_i64{under_min_i32}; + EXPECT_EQ(under_min_i32, llvm::minIntN(32) - 10); + EXPECT_TRUE(failed(CastI64ArrayToI32(array_i64))); +} + +TEST_F(AttrsAndConstraintsTest, CastingFailsForI64ArrayAboveI32Range) { + const int64_t below_max_i32 = 2147483657; + ArrayRef array_i64{below_max_i32}; + EXPECT_EQ(below_max_i32, llvm::maxIntN(32) + 10); + EXPECT_TRUE(failed(CastI64ArrayToI32(array_i64))); +} + +TEST_F(AttrsAndConstraintsTest, FindUserOfDifferentTypes) { + OwningOpRef module_op_ref = + ParseModuleOpString(kModuleMultipleUses); + func::FuncOp main_fn = FindMainFuncOp(*module_op_ref); + ASSERT_THAT(main_fn, NotNull()); + + Operation* dot_general_op = FindOperationOfType(main_fn); + ASSERT_NE(FindUserOfType(dot_general_op), nullptr); + ASSERT_NE(FindUserOfType(dot_general_op), nullptr); + ASSERT_NE(FindUserOfType<>(dot_general_op), nullptr); + ASSERT_EQ(FindUserOfType(dot_general_op), nullptr); +} + +TEST_F(AttrsAndConstraintsTest, CallGetFuncAttr) { + OwningOpRef xla_module_op_ref = + ParseModuleOpString(kModuleXlaCallModule); + func::FuncOp xml_main_fn = FindMainFuncOp(*xla_module_op_ref); + Operation* xla_op = FindOperationOfType(xml_main_fn); + auto xla_call_op = dyn_cast_or_null(*xla_op); + FlatSymbolRefAttr xla_call_op_attr = GetFuncAttr(xla_call_op); + EXPECT_EQ(xla_call_op_attr.getValue(), "composite_fn_1"); + + OwningOpRef partitioned_module_op_ref = + ParseModuleOpString(kModulePartitionedCall); + func::FuncOp partitioned_main_fn = FindMainFuncOp(*partitioned_module_op_ref); + Operation* partitioned_op = + FindOperationOfType(partitioned_main_fn); + auto partitioned_call_op = + dyn_cast_or_null(*partitioned_op); + FlatSymbolRefAttr partitioned_call_op_attr = GetFuncAttr(partitioned_call_op); + EXPECT_EQ(partitioned_call_op_attr.getValue(), "composite_fn_1"); +} + } // namespace } // namespace mlir::quant diff --git a/tensorflow/compiler/mlir/quantization/common/func.cc b/tensorflow/compiler/mlir/quantization/common/func.cc new file mode 100644 index 00000000000000..5849289e6d7ebd --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/func.cc @@ -0,0 +1,55 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/common/func.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/cc/saved_model/signature_constants.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" + +namespace mlir::quant { +namespace { + +using ::tensorflow::kDefaultServingSignatureDefKey; +using ::tensorflow::kImportModelDefaultGraphFuncName; + +// Returns true iff the function's symbol is public. +bool IsPublicFuncOp(func::FuncOp func_op) { + return SymbolTable::getSymbolVisibility(&*func_op) == + SymbolTable::Visibility::Public; +} + +} // namespace + +func::FuncOp FindMainFuncOp(ModuleOp module_op) { + if (const auto main_func_op = module_op.lookupSymbol( + kImportModelDefaultGraphFuncName); + main_func_op != nullptr && IsPublicFuncOp(main_func_op)) { + return main_func_op; + } + + if (const auto serving_default_func_op = + module_op.lookupSymbol(kDefaultServingSignatureDefKey); + serving_default_func_op != nullptr && + IsPublicFuncOp(serving_default_func_op)) { + return serving_default_func_op; + } + + return nullptr; +} + +} // namespace mlir::quant diff --git a/tensorflow/compiler/mlir/quantization/common/func.h b/tensorflow/compiler/mlir/quantization/common/func.h new file mode 100644 index 00000000000000..ade7bcfc71027b --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/func.h @@ -0,0 +1,31 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_FUNC_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_FUNC_H_ + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project + +namespace mlir::quant { + +// Returns a public `func::FuncOp` in `module_op` whose name matches either +// `main` or `serving_default`. If `func::FuncOps` with both names exist, the +// function with name "main" takes precedence. Returns null if no such a +// function exists. +func::FuncOp FindMainFuncOp(ModuleOp module_op); + +} // namespace mlir::quant + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_FUNC_H_ diff --git a/tensorflow/compiler/mlir/quantization/common/func_test.cc b/tensorflow/compiler/mlir/quantization/common/func_test.cc new file mode 100644 index 00000000000000..8555da63b71feb --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/func_test.cc @@ -0,0 +1,113 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/common/func.h" + +#include +#include +#include "absl/strings/string_view.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/test_base.h" + +namespace mlir::quant { +namespace { + +using ::testing::IsNull; +using ::testing::NotNull; + +class FindMainFuncOpTest : public QuantizationTestBase {}; + +TEST_F(FindMainFuncOpTest, ReturnsMainFuncOp) { + constexpr absl::string_view kModuleWithMainFunc = R"mlir( + module { + func.func @main() -> () { + return + } + } + )mlir"; + + OwningOpRef module_op = ParseModuleOpString(kModuleWithMainFunc); + EXPECT_THAT(*module_op, NotNull()); + + func::FuncOp main_func_op = FindMainFuncOp(*module_op); + EXPECT_THAT(main_func_op, NotNull()); +} + +TEST_F(FindMainFuncOpTest, ReturnsNullWhenMainFuncOpIsPrivate) { + constexpr absl::string_view kModuleWithPrivateMainFunc = R"mlir( + module { + func.func private @main() -> () { + return + } + } + )mlir"; + + OwningOpRef module_op = + ParseModuleOpString(kModuleWithPrivateMainFunc); + EXPECT_THAT(*module_op, NotNull()); + + EXPECT_THAT(FindMainFuncOp(*module_op), IsNull()); +} + +TEST_F(FindMainFuncOpTest, ReturnsServingDefaultFuncOp) { + constexpr absl::string_view kModuleWithServingDefaultFunc = R"mlir( + module { + func.func @serving_default() -> () { + return + } + } + )mlir"; + + OwningOpRef module_op = + ParseModuleOpString(kModuleWithServingDefaultFunc); + EXPECT_THAT(*module_op, NotNull()); + + EXPECT_THAT(FindMainFuncOp(*module_op), NotNull()); +} + +TEST_F(FindMainFuncOpTest, ReturnsNullWhenServingDefaultFuncOpIsPrivate) { + constexpr absl::string_view kModuleWithPrivateServingDefaultFunc = R"mlir( + module { + func.func private @serving_default() -> () { + return + } + } + )mlir"; + + OwningOpRef module_op = + ParseModuleOpString(kModuleWithPrivateServingDefaultFunc); + EXPECT_THAT(*module_op, NotNull()); + + EXPECT_THAT(FindMainFuncOp(*module_op), IsNull()); +} + +TEST_F(FindMainFuncOpTest, ReturnsNullWhenMainFuncNotFound) { + constexpr absl::string_view kModuleWithNoMainFunc = R"mlir( + module { + func.func @foo() -> () { + return + } + } + )mlir"; + + OwningOpRef module_op = ParseModuleOpString(kModuleWithNoMainFunc); + EXPECT_THAT(*module_op, NotNull()); + + EXPECT_THAT(FindMainFuncOp(*module_op), IsNull()); +} + +} // namespace +} // namespace mlir::quant diff --git a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc index 9140c372f68f55..2a1b10cc3163a2 100644 --- a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc +++ b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h" +#include #include #include "absl/strings/string_view.h" #include "llvm/ADT/SmallVector.h" @@ -28,13 +29,14 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/func.h" #include "tensorflow/compiler/mlir/quantization/common/test_base.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" namespace mlir::quant { namespace { -using ::mlir::quant::QuantizationTestBase; +using ::testing::NotNull; class LiftAsFunctionCallTest : public QuantizationTestBase {}; @@ -49,8 +51,10 @@ constexpr absl::string_view kModuleLifted = R"mlir( TEST_F(LiftAsFunctionCallTest, LiftedFunctionSucceeds) { OwningOpRef module_op_ref = ParseModuleOpString(kModuleLifted); - func::FuncOp composite_dot_general_fn = - GetFunctionFromModule(*module_op_ref, "composite_dot_general_fn_1"); + auto composite_dot_general_fn = + module_op_ref->lookupSymbol("composite_dot_general_fn_1"); + ASSERT_THAT(composite_dot_general_fn, NotNull()); + Operation* dot_general_op = FindOperationOfType( composite_dot_general_fn); @@ -59,7 +63,7 @@ TEST_F(LiftAsFunctionCallTest, LiftedFunctionSucceeds) { constexpr absl::string_view kModuleStableHlo = R"mlir( module { - func.func private @main(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + func.func @main(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> return %0 : tensor<1x3xf32> } @@ -68,7 +72,9 @@ constexpr absl::string_view kModuleStableHlo = R"mlir( TEST_F(LiftAsFunctionCallTest, FunctionLiftedAsXlaCallModuleOp) { OwningOpRef module_op_ref = ParseModuleOpString(kModuleStableHlo); - func::FuncOp main_fn = GetFunctionFromModule(*module_op_ref, "main"); + func::FuncOp main_fn = FindMainFuncOp(*module_op_ref); + ASSERT_THAT(main_fn, NotNull()); + Operation* dot_general_op = FindOperationOfType(main_fn); @@ -106,7 +112,9 @@ TEST_F(LiftAsFunctionCallTest, FunctionLiftedAsXlaCallModuleOp) { TEST_F(LiftAsFunctionCallTest, FunctionNoAttrLiftedAsXlaCallModuleOp) { OwningOpRef module_op_ref = ParseModuleOpString(kModuleStableHlo); - func::FuncOp main_fn = GetFunctionFromModule(*module_op_ref, "main"); + func::FuncOp main_fn = FindMainFuncOp(*module_op_ref); + ASSERT_THAT(main_fn, NotNull()); + Operation* dot_general_op = FindOperationOfType(main_fn); Operation* lifted_op = diff --git a/tensorflow/compiler/mlir/quantization/common/test_base.h b/tensorflow/compiler/mlir/quantization/common/test_base.h index f1c78a960995e8..1068a42a615027 100644 --- a/tensorflow/compiler/mlir/quantization/common/test_base.h +++ b/tensorflow/compiler/mlir/quantization/common/test_base.h @@ -27,7 +27,6 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" @@ -63,13 +62,6 @@ class QuantizationTestBase : public Test { return module_op_ref; } - // Gets the function with the given name from the module. - func::FuncOp GetFunctionFromModule(ModuleOp module, - absl::string_view function_name) { - SymbolTable symbol_table(module); - return symbol_table.lookup(function_name); - } - // Returns the first operation with the given type in the function. template OpType FindOperationOfType(func::FuncOp function) { diff --git a/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.cc b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.cc index 09f0fa78f1508e..5cee0692080b2a 100644 --- a/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.cc +++ b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.cc @@ -169,6 +169,30 @@ bool IsI32F32UniformQuantizedType(const Type type) { return true; } +bool IsI32F32UniformQuantizedPerAxisType(const Type type) { + const UniformQuantizedPerAxisType quantized_per_axis_type = + type.dyn_cast_or_null(); + if (!quantized_per_axis_type) { + LLVM_DEBUG(llvm::dbgs() + << "Expected a uniform quantized type. Got: " << type << ".\n"); + return false; + } + + if (!IsStorageTypeI32(quantized_per_axis_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an i32 storage type. Got: " + << quantized_per_axis_type << ".\n"); + return false; + } + + if (!IsExpressedTypeF32(quantized_per_axis_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an f32 expressed type. Got: " + << quantized_per_axis_type << ".\n"); + return false; + } + + return true; +} + // Determines whether the storage type of a quantized type is supported by // `tfl.quantize` or `tfl.dequantize` ops. ui8, i8 and i16 are supported. bool IsSupportedByTfliteQuantizeOrDequantizeOps(IntegerType storage_type) { diff --git a/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h index 816002b5d5b942..f1c94302d816b3 100644 --- a/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h +++ b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h @@ -90,6 +90,10 @@ bool IsI8F32UniformQuantizedPerAxisType(Type type); // 32-bit integer and expressed type is f32. bool IsI32F32UniformQuantizedType(Type type); +// Returns true iff `type` is a uniform quantized per-axis (per-channel) type +// whose storage type is 32-bit integer and expressed type is f32. +bool IsI32F32UniformQuantizedPerAxisType(Type type); + // Determines whether the storage type of a quantized type is supported by // `tfl.quantize` or `tfl.dequantize` ops. ui8, i8 and i16 are supported. bool IsSupportedByTfliteQuantizeOrDequantizeOps(IntegerType storage_type); diff --git a/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types_test.cc b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types_test.cc index 26583c9b348c14..10499526873f2c 100644 --- a/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types_test.cc +++ b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types_test.cc @@ -31,6 +31,7 @@ namespace quant { namespace { using ::testing::ElementsAreArray; +using ::testing::IsNull; using ::testing::NotNull; using ::testing::Test; @@ -47,7 +48,7 @@ TEST_F(CreateI8F32UniformQuantizedTypeTest, I8StorageTypeSucceeds) { const UniformQuantizedType quantized_type = CreateI8F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_, /*scale=*/1.0, /*zero_point=*/0); - + // Storage type of `i8` is currently verifiable as `unsigned` in `Types.cpp`. EXPECT_TRUE(quantized_type.getStorageType().isSignlessInteger(8)); } @@ -108,6 +109,7 @@ TEST_F(CreateI32F32UniformQuantizedTypeTest, I32StorageTypeSucceeds) { CreateI32F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_, /*scale=*/1.0, /*zero_point=*/0); + // Storage type of `i32` is currently verifiable as `unsigned` in `Types.cpp`. EXPECT_TRUE(quantized_type.getStorageType().isSignlessInteger(32)); } @@ -165,6 +167,7 @@ TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, I8StorageTypeSucceeds) { /*zero_points=*/SmallVector{0, 0}, /*quantization_dimension=*/0); + // Storage type of `i8` is currently verifiable as `unsigned` in `Types.cpp`. EXPECT_TRUE(quantized_type.getStorageType().isSignlessInteger(8)); } @@ -242,62 +245,139 @@ TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, EXPECT_THAT(quantized_type.getZeroPoints(), ElementsAreArray({98, 99})); } +class CreateI32F32UniformQuantizedPerAxisTypeTest : public Test { + protected: + CreateI32F32UniformQuantizedPerAxisTypeTest() : ctx_() { + ctx_.loadDialect(); + } + + MLIRContext ctx_; +}; + +TEST_F(CreateI32F32UniformQuantizedPerAxisTypeTest, I32StorageTypeSucceeds) { + const UniformQuantizedPerAxisType quantized_type = + CreateI32F32UniformQuantizedPerAxisType( + UnknownLoc::get(&ctx_), ctx_, + /*scales=*/SmallVector{1.0, 1.0}, + /*zero_points=*/SmallVector{0, 0}, + /*quantization_dimension=*/0); + + // Storage type of `i32` is currently verifiable as `unsigned` in `Types.cpp`. + EXPECT_TRUE(quantized_type.getStorageType().isSignlessInteger(32)); +} + +TEST_F(CreateI32F32UniformQuantizedPerAxisTypeTest, F32ExpressedTypeSucceeds) { + const UniformQuantizedPerAxisType quantized_type = + CreateI32F32UniformQuantizedPerAxisType( + UnknownLoc::get(&ctx_), ctx_, + /*scales=*/SmallVector{1.0, 1.0}, + /*zero_points=*/SmallVector{0, 0}, + /*quantization_dimension=*/0); + + EXPECT_TRUE(quantized_type.getExpressedType().isF32()); +} + +TEST_F(CreateI32F32UniformQuantizedPerAxisTypeTest, + StorageTypeMinMaxEqualToI32MinMax) { + const UniformQuantizedPerAxisType quantized_type = + CreateI32F32UniformQuantizedPerAxisType( + UnknownLoc::get(&ctx_), ctx_, + /*scales=*/SmallVector{1.0, 1.0}, + /*zero_points=*/SmallVector{0, 0}, + /*quantization_dimension=*/0); + + EXPECT_EQ(quantized_type.getStorageTypeMin(), + std::numeric_limits::min()); + EXPECT_EQ(quantized_type.getStorageTypeMax(), + std::numeric_limits::max()); +} + +TEST_F(CreateI32F32UniformQuantizedPerAxisTypeTest, + HasQuantizationDimensionProperlySet) { + const UniformQuantizedPerAxisType quantized_type = + CreateI32F32UniformQuantizedPerAxisType( + UnknownLoc::get(&ctx_), ctx_, + /*scales=*/SmallVector{1.0, 1.0}, + /*zero_points=*/SmallVector{0, 0}, + /*quantization_dimension=*/3); + + EXPECT_EQ(quantized_type.getQuantizedDimension(), 3); +} + +TEST_F(CreateI32F32UniformQuantizedPerAxisTypeTest, + HasScaleAndZeroPointProperlySet) { + const UniformQuantizedPerAxisType quantized_type = + CreateI32F32UniformQuantizedPerAxisType( + UnknownLoc::get(&ctx_), ctx_, + /*scales=*/SmallVector{8.0, 9.0}, + /*zero_points=*/SmallVector{98, 99}, + /*quantization_dimension=*/0); + + EXPECT_THAT(quantized_type.getScales(), ElementsAreArray({8.0, 9.0})); + EXPECT_THAT(quantized_type.getZeroPoints(), ElementsAreArray({98, 99})); +} + class IsI8F32UniformQuantizedTypeTest : public Test { protected: - IsI8F32UniformQuantizedTypeTest() { + IsI8F32UniformQuantizedTypeTest() : builder_(&ctx_) { ctx_.loadDialect(); } MLIRContext ctx_; - OpBuilder builder_{&ctx_}; + OpBuilder builder_; }; TEST_F(IsI8F32UniformQuantizedTypeTest, I8F32UniformQuantizedTypeSucceeds) { const UniformQuantizedType qi8_type = quant::UniformQuantizedType::get( - /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), /*scale=*/1.0, - /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + /*flags=*/QuantizationFlags::Signed, builder_.getI8Type(), + builder_.getF32Type(), /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/-128, /*storageTypeMax=*/127); EXPECT_TRUE(IsI8F32UniformQuantizedType(qi8_type)); } TEST_F(IsI8F32UniformQuantizedTypeTest, UniformQuantizedTypeSucceeds) { const UniformQuantizedType qi8_type = quant::UniformQuantizedType::get( - /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), /*scale=*/1.0, - /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + /*flags=*/QuantizationFlags::Signed, builder_.getI8Type(), + builder_.getF32Type(), /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/-128, /*storageTypeMax=*/127); EXPECT_THAT(qi8_type.dyn_cast_or_null(), NotNull()); } TEST_F(IsI8F32UniformQuantizedTypeTest, StorageTypeI8Succeeds) { const UniformQuantizedType qi8_type = quant::UniformQuantizedType::get( - /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), /*scale=*/1.0, - /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + /*flags=*/QuantizationFlags::Signed, builder_.getI8Type(), + builder_.getF32Type(), /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/-128, /*storageTypeMax=*/127); EXPECT_TRUE(IsStorageTypeI8(qi8_type)); } TEST_F(IsI8F32UniformQuantizedTypeTest, ExpressedTypeF32Succeeds) { const UniformQuantizedType qi8_type = quant::UniformQuantizedType::get( - /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), /*scale=*/1.0, - /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + /*flags=*/QuantizationFlags::Signed, builder_.getI8Type(), + builder_.getF32Type(), /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/-128, /*storageTypeMax=*/127); EXPECT_TRUE(IsExpressedTypeF32(qi8_type)); } class IsI8F32UniformQuantizedPerAxisTypeTest : public Test { protected: - IsI8F32UniformQuantizedPerAxisTypeTest() { + IsI8F32UniformQuantizedPerAxisTypeTest() : builder_(&ctx_) { ctx_.loadDialect(); } MLIRContext ctx_; - OpBuilder builder_{&ctx_}; + OpBuilder builder_; }; TEST_F(IsI8F32UniformQuantizedPerAxisTypeTest, I8F32UniformQuantizedPerAxisTypeSucceeds) { const UniformQuantizedPerAxisType qi8_per_axis_type = quant::UniformQuantizedPerAxisType::get( - /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), + /*flags=*/QuantizationFlags::Signed, builder_.getI8Type(), + builder_.getF32Type(), /*scales=*/{1.0}, - /*zeroPoints=*/{0}, /*quantizedDimension=*/0, /*storageTypeMin=*/0, - /*storageTypeMax=*/255); + /*zeroPoints=*/{0}, /*quantizedDimension=*/0, /*storageTypeMin=*/-128, + /*storageTypeMax=*/127); EXPECT_TRUE(IsI8F32UniformQuantizedPerAxisType(qi8_per_axis_type)); EXPECT_FALSE(IsI8F32UniformQuantizedType(qi8_per_axis_type)); } @@ -305,10 +385,11 @@ TEST_F(IsI8F32UniformQuantizedPerAxisTypeTest, TEST_F(IsI8F32UniformQuantizedTypeTest, UniformQuantizedPerAxisTypeSucceeds) { const UniformQuantizedPerAxisType qi8_per_axis_type = quant::UniformQuantizedPerAxisType::get( - /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), + /*flags=*/QuantizationFlags::Signed, builder_.getI8Type(), + builder_.getF32Type(), /*scales=*/{1.0}, - /*zeroPoints=*/{0}, /*quantizedDimension=*/0, /*storageTypeMin=*/0, - /*storageTypeMax=*/255); + /*zeroPoints=*/{0}, /*quantizedDimension=*/0, /*storageTypeMin=*/-128, + /*storageTypeMax=*/127); EXPECT_THAT(qi8_per_axis_type.dyn_cast_or_null(), NotNull()); } @@ -316,167 +397,187 @@ TEST_F(IsI8F32UniformQuantizedTypeTest, UniformQuantizedPerAxisTypeSucceeds) { TEST_F(IsI8F32UniformQuantizedPerAxisTypeTest, StorageTypeI8Succeeds) { const UniformQuantizedPerAxisType qi8_per_axis_type = quant::UniformQuantizedPerAxisType::get( - /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), + /*flags=*/QuantizationFlags::Signed, builder_.getI8Type(), + builder_.getF32Type(), /*scales=*/{1.0}, - /*zeroPoints=*/{0}, /*quantizedDimension=*/0, /*storageTypeMin=*/0, - /*storageTypeMax=*/255); + /*zeroPoints=*/{0}, /*quantizedDimension=*/0, /*storageTypeMin=*/-128, + /*storageTypeMax=*/127); EXPECT_TRUE(IsStorageTypeI8(qi8_per_axis_type)); } TEST_F(IsI8F32UniformQuantizedPerAxisTypeTest, ExpressedTypeF32Succeeds) { const UniformQuantizedPerAxisType qi8_per_axis_type = quant::UniformQuantizedPerAxisType::get( - /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), + /*flags=*/QuantizationFlags::Signed, builder_.getI8Type(), + builder_.getF32Type(), /*scales=*/{1.0}, - /*zeroPoints=*/{0}, /*quantizedDimension=*/0, /*storageTypeMin=*/0, - /*storageTypeMax=*/255); + /*zeroPoints=*/{0}, /*quantizedDimension=*/0, /*storageTypeMin=*/-128, + /*storageTypeMax=*/127); EXPECT_TRUE(IsExpressedTypeF32(qi8_per_axis_type)); } class IsI32F32UniformQuantizedTypeTest : public Test { protected: - IsI32F32UniformQuantizedTypeTest() { + IsI32F32UniformQuantizedTypeTest() : builder_(&ctx_) { ctx_.loadDialect(); } MLIRContext ctx_; - OpBuilder builder_{&ctx_}; + OpBuilder builder_; }; TEST_F(IsI32F32UniformQuantizedTypeTest, I32F32UniformQuantizedTypeSucceeds) { const UniformQuantizedType qi32_type = quant::UniformQuantizedType::get( - /*flags=*/0, builder_.getI32Type(), builder_.getF32Type(), /*scale=*/1.0, - /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + /*flags=*/QuantizationFlags::Signed, builder_.getI32Type(), + builder_.getF32Type(), + /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/-2147483647, + /*storageTypeMax=*/2147483646); EXPECT_TRUE(IsI32F32UniformQuantizedType(qi32_type)); } TEST_F(IsI32F32UniformQuantizedTypeTest, UniformQuantizedTypeSucceeds) { const UniformQuantizedType qi32_type = quant::UniformQuantizedType::get( - /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), /*scale=*/1.0, - /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + /*flags=*/QuantizationFlags::Signed, builder_.getI32Type(), + builder_.getF32Type(), + /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/-2147483647, + /*storageTypeMax=*/2147483646); + EXPECT_TRUE(IsI32F32UniformQuantizedType(qi32_type)); EXPECT_THAT(qi32_type.dyn_cast_or_null(), NotNull()); } TEST_F(IsI32F32UniformQuantizedTypeTest, StorageTypeI32Succeeds) { const UniformQuantizedType qi32_type = quant::UniformQuantizedType::get( - /*flags=*/0, builder_.getI32Type(), builder_.getF32Type(), /*scale=*/1.0, - /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + /*flags=*/QuantizationFlags::Signed, builder_.getI32Type(), + builder_.getF32Type(), + /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/-2147483647, + /*storageTypeMax=*/2147483646); + EXPECT_TRUE(IsI32F32UniformQuantizedType(qi32_type)); EXPECT_TRUE(IsStorageTypeI32(qi32_type)); } TEST_F(IsI32F32UniformQuantizedTypeTest, ExpressedTypeF32Succeeds) { const UniformQuantizedType qi32_per_axis_type = quant::UniformQuantizedType::get( - /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), + /*flags=*/QuantizationFlags::Signed, builder_.getI32Type(), + builder_.getF32Type(), /*scale=*/1.0, - /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + /*zeroPoint=*/0, /*storageTypeMin=*/-2147483647, + /*storageTypeMax=*/2147483646); EXPECT_TRUE(IsExpressedTypeF32(qi32_per_axis_type)); } -class CreateI32F32UniformQuantizedPerAxisTypeTest : public Test { +class IsI32F32UniformQuantizedPerAxisTypeTest : public Test { protected: - CreateI32F32UniformQuantizedPerAxisTypeTest() : ctx_() { + IsI32F32UniformQuantizedPerAxisTypeTest() : builder_(&ctx_) { ctx_.loadDialect(); } MLIRContext ctx_; + OpBuilder builder_; }; -TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, I32StorageTypeSucceeds) { - const UniformQuantizedPerAxisType quantized_type = - CreateI32F32UniformQuantizedPerAxisType( - UnknownLoc::get(&ctx_), ctx_, - /*scales=*/SmallVector{1.0, 1.0}, - /*zero_points=*/SmallVector{0, 0}, - /*quantization_dimension=*/0); - - EXPECT_TRUE(quantized_type.getStorageType().isSignlessInteger(32)); -} - -TEST_F(CreateI32F32UniformQuantizedPerAxisTypeTest, F32ExpressedTypeSucceeds) { - const UniformQuantizedPerAxisType quantized_type = - CreateI32F32UniformQuantizedPerAxisType( - UnknownLoc::get(&ctx_), ctx_, - /*scales=*/SmallVector{1.0, 1.0}, - /*zero_points=*/SmallVector{0, 0}, - /*quantization_dimension=*/0); - - EXPECT_TRUE(quantized_type.getExpressedType().isF32()); +TEST_F(IsI32F32UniformQuantizedPerAxisTypeTest, + I32F32UniformQuantizedPerAxisTypeSucceeds) { + const UniformQuantizedPerAxisType qi32_per_axis_type = + quant::UniformQuantizedPerAxisType::get( + /*flags=*/QuantizationFlags::Signed, builder_.getI32Type(), + builder_.getF32Type(), + /*scales=*/{1.0}, + /*zeroPoints=*/{0}, /*quantizedDimension=*/0, + /*storageTypeMin=*/-2147483647, /*storageTypeMax=*/2147483646); + EXPECT_TRUE(IsI32F32UniformQuantizedPerAxisType(qi32_per_axis_type)); + EXPECT_FALSE(IsI32F32UniformQuantizedType(qi32_per_axis_type)); } -TEST_F(CreateI32F32UniformQuantizedPerAxisTypeTest, - StorageTypeMinMaxEqualToI32MinMax) { - const UniformQuantizedPerAxisType quantized_type = - CreateI32F32UniformQuantizedPerAxisType( - UnknownLoc::get(&ctx_), ctx_, - /*scales=*/SmallVector{1.0, 1.0}, - /*zero_points=*/SmallVector{0, 0}, - /*quantization_dimension=*/0); +TEST_F(IsI32F32UniformQuantizedPerAxisTypeTest, + I8F32UniformQuantizedTypeFails) { + const UniformQuantizedType qi8_type = quant::UniformQuantizedType::get( + /*flags=*/QuantizationFlags::Signed, builder_.getI8Type(), + builder_.getF32Type(), + /*scale=*/1.0, /*zeroPoint=*/0, /*storageTypeMin=*/-128, + /*storageTypeMax=*/127); + EXPECT_FALSE(IsI32F32UniformQuantizedPerAxisType(qi8_type)); + EXPECT_FALSE(IsStorageTypeI32(qi8_type)); + EXPECT_THAT(qi8_type.dyn_cast_or_null(), + IsNull()); +} + +TEST_F(IsI32F32UniformQuantizedTypeTest, UniformQuantizedPerAxisTypeSucceeds) { + const UniformQuantizedPerAxisType qi32_per_axis_type = + quant::UniformQuantizedPerAxisType::get( + /*flags=*/QuantizationFlags::Signed, builder_.getI32Type(), + builder_.getF32Type(), + /*scales=*/{1.0}, + /*zeroPoints=*/{0}, /*quantizedDimension=*/0, + /*storageTypeMin=*/-2147483647, /*storageTypeMax=*/2147483646); - EXPECT_EQ(quantized_type.getStorageTypeMin(), - std::numeric_limits::min()); - EXPECT_EQ(quantized_type.getStorageTypeMax(), - std::numeric_limits::max()); + EXPECT_THAT( + qi32_per_axis_type.dyn_cast_or_null(), + NotNull()); } -TEST_F(CreateI32F32UniformQuantizedPerAxisTypeTest, - HasQuantizationDimensionProperlySet) { - const UniformQuantizedPerAxisType quantized_type = - CreateI32F32UniformQuantizedPerAxisType( - UnknownLoc::get(&ctx_), ctx_, - /*scales=*/SmallVector{1.0, 1.0}, - /*zero_points=*/SmallVector{0, 0}, - /*quantization_dimension=*/3); +TEST_F(IsI32F32UniformQuantizedPerAxisTypeTest, StorageTypeI8Succeeds) { + const UniformQuantizedPerAxisType qi32_per_axis_type = + quant::UniformQuantizedPerAxisType::get( + /*flags=*/QuantizationFlags::Signed, builder_.getI32Type(), + builder_.getF32Type(), + /*scales=*/{1.0}, + /*zeroPoints=*/{0}, /*quantizedDimension=*/0, + /*storageTypeMin=*/-2147483647, /*storageTypeMax=*/2147483646); - EXPECT_EQ(quantized_type.getQuantizedDimension(), 3); + EXPECT_TRUE(IsStorageTypeI32(qi32_per_axis_type)); } -TEST_F(CreateI32F32UniformQuantizedPerAxisTypeTest, - HasScaleAndZeroPointProperlySet) { - const UniformQuantizedPerAxisType quantized_type = - CreateI32F32UniformQuantizedPerAxisType( - UnknownLoc::get(&ctx_), ctx_, - /*scales=*/SmallVector{8.0, 9.0}, - /*zero_points=*/SmallVector{98, 99}, - /*quantization_dimension=*/0); - - EXPECT_THAT(quantized_type.getScales(), ElementsAreArray({8.0, 9.0})); - EXPECT_THAT(quantized_type.getZeroPoints(), ElementsAreArray({98, 99})); +TEST_F(IsI32F32UniformQuantizedPerAxisTypeTest, ExpressedTypeF32Succeeds) { + const UniformQuantizedPerAxisType qi32_per_axis_type = + quant::UniformQuantizedPerAxisType::get( + /*flags=*/QuantizationFlags::Signed, builder_.getI32Type(), + builder_.getF32Type(), + /*scales=*/{1.0}, + /*zeroPoints=*/{0}, /*quantizedDimension=*/0, + /*storageTypeMin=*/-2147483647, /*storageTypeMax=*/2147483646); + EXPECT_TRUE(IsExpressedTypeF32(qi32_per_axis_type)); } class IsSupportedByTfliteQuantizeOrDequantizeOpsTest : public Test { protected: - IsSupportedByTfliteQuantizeOrDequantizeOpsTest() { + IsSupportedByTfliteQuantizeOrDequantizeOpsTest() : builder_(&ctx_) { ctx_.loadDialect(); } MLIRContext ctx_; - OpBuilder builder_{&ctx_}; + OpBuilder builder_; }; TEST_F(IsSupportedByTfliteQuantizeOrDequantizeOpsTest, StorageTypeI8Succeeds) { auto qi8_type = quant::UniformQuantizedType::get( - /*flags=*/0, builder_.getIntegerType(8, /*isSigned=*/true), - builder_.getF32Type(), /*scale=*/1.0, - /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + /*flags=*/QuantizationFlags::Signed, builder_.getI8Type(), + builder_.getF32Type(), + /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/-128, /*storageTypeMax=*/127); EXPECT_TRUE(IsSupportedByTfliteQuantizeOrDequantizeOps( dyn_cast_or_null(qi8_type.getStorageType()))); } TEST_F(IsSupportedByTfliteQuantizeOrDequantizeOpsTest, StorageTypeI16Succeeds) { auto qi16_type = quant::UniformQuantizedType::get( - /*flags=*/0, builder_.getIntegerType(16, /*isSigned=*/true), - builder_.getF32Type(), /*scale=*/1.0, - /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + /*flags=*/QuantizationFlags::Signed, builder_.getI8Type(), + builder_.getF32Type(), + /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/-128, /*storageTypeMax=*/127); EXPECT_TRUE(IsSupportedByTfliteQuantizeOrDequantizeOps( dyn_cast_or_null(qi16_type.getStorageType()))); } TEST_F(IsSupportedByTfliteQuantizeOrDequantizeOpsTest, StorageTypeUI8Succeeds) { auto qi8_type = quant::UniformQuantizedType::get( - /*flags=*/0, builder_.getIntegerType(8, /*isSigned=*/false), - builder_.getF32Type(), /*scale=*/1.0, - /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + /*flags=*/QuantizationFlags::Signed, builder_.getI8Type(), + builder_.getF32Type(), + /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/-128, /*storageTypeMax=*/127); EXPECT_TRUE(IsSupportedByTfliteQuantizeOrDequantizeOps( dyn_cast_or_null(qi8_type.getStorageType()))); } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index 13b87ba39e7817..439e8542cf7d0e 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -52,7 +52,6 @@ cc_library( "passes/lift_quantizable_spots_as_functions_fusion.inc", "passes/lift_quantizable_spots_as_functions_simple.inc", "passes/optimize_graph.cc", - "passes/populate_shape.cc", "passes/post_quantize.cc", "passes/prepare_quantize.cc", "passes/quantize.cc", @@ -74,6 +73,7 @@ cc_library( ":lift_quantizable_spots_as_functions_fusion_inc_gen", ":lift_quantizable_spots_as_functions_simple_inc_gen", ":optimize_graph_inc_gen", + ":quantization_config_proto_cc", ":quantization_options_proto_cc", ":quantization_patterns", ":stablehlo_passes_inc_gen", @@ -83,6 +83,7 @@ cc_library( "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", + "//tensorflow/compiler/mlir/quantization/common:func", "//tensorflow/compiler/mlir/quantization/common:lift_as_function_call", "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/stablehlo/ops:stablehlo_op_quant_spec", @@ -124,6 +125,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", + "@local_tsl//tsl/platform:regexp", "@local_tsl//tsl/platform:str_util", "@local_tsl//tsl/protobuf:protos_all_cc", "@local_xla//xla/mlir_hlo", @@ -152,7 +154,6 @@ cc_library( "//tensorflow/compiler/mlir/quantization/tensorflow:passes", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:xla_call_module_attrs", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:path", "@com_google_absl//absl/algorithm:container", @@ -474,28 +475,38 @@ gentbl_cc_library( cc_library( name = "test_passes", srcs = [ + "passes/testing/test_lift_quantizable_spots_as_functions_with_quantization_specs.cc", "passes/testing/test_post_calibration_component.cc", "passes/testing/test_pre_calibration_component.cc", + "passes/testing/test_tf_to_stablehlo_pass.cc", ], hdrs = [ "passes/testing/passes.h", ], compatible_with = get_compatible_with_portable(), deps = [ + ":passes", ":quantization_config_proto_cc", ":stablehlo_test_passes_inc_gen", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:post_calibration", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:pre_calibration", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantize_preprocess", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:SparseTensorDialect", "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:protobuf", "@local_xla//xla/mlir_hlo", "@stablehlo//:chlo_ops", "@stablehlo//:stablehlo_ops", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD index 34fb52a0e5bec6..10d4b020166552 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD @@ -7,6 +7,8 @@ load( package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ + # For TFLite Converter integration. + "//tensorflow/compiler/mlir/lite:__subpackages__", "//tensorflow/compiler/mlir/quantization:__subpackages__", ], licenses = ["notice"], @@ -31,6 +33,29 @@ cc_library( ], ) +# OSS: This is a header-only target. Do NOT directly depend on `config_impl` unless it is necessary +# (e.g. undefined symbol error), to avoid ODR violation. +cc_library( + name = "config", + hdrs = ["config.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + ], +) + +# OSS: This is a impl target corresponding to `config`. Do NOT directly depend on `config_impl` +# unless it is necessary (e.g. undefined symbol error), to avoid ODR violation. +cc_library( + name = "config_impl", + srcs = ["config.cc"], + hdrs = ["config.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + ], +) + cc_library( name = "io", srcs = ["io.cc"], @@ -186,6 +211,9 @@ cc_library( deps = [ "//tensorflow/compiler/mlir/quantization/stablehlo:bridge_passes", "//tensorflow/compiler/mlir/quantization/stablehlo:passes", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:passes", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:Pass", @@ -209,7 +237,6 @@ cc_library( deps = [ ":component", ":pass_pipeline", - "//tensorflow/compiler/mlir/quantization/stablehlo:passes", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:passes", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", @@ -218,7 +245,6 @@ cc_library( "@com_google_absl//absl/log:die_if_null", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", - "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@local_tsl//tsl/platform:errors", @@ -263,7 +289,6 @@ cc_library( deps = [ ":component", ":pass_pipeline", - "//tensorflow/compiler/mlir/quantization/stablehlo:passes", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:passes", "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", @@ -271,7 +296,6 @@ cc_library( "@com_google_absl//absl/log:die_if_null", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", - "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@local_tsl//tsl/platform:errors", @@ -310,6 +334,11 @@ cc_library( compatible_with = get_compatible_with_portable(), visibility = [ "//tensorflow:__pkg__", + "//tensorflow/compiler/mlir/lite:__pkg__", # For tf_tfl_translate binary. + # For odml_to_stablehlo binary. + "//tensorflow/compiler/mlir/lite/stablehlo:__pkg__", + # For StableHLO Quantizer adapter functionalities within TFLite. Testonly. + "//tensorflow/compiler/mlir/lite/quantization/stablehlo:__pkg__", "//tensorflow/python:__pkg__", ], deps = [ diff --git a/third_party/xla/third_party/tsl/tsl/platform/jpeg.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc similarity index 52% rename from third_party/xla/third_party/tsl/tsl/platform/jpeg.h rename to tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc index a7b640db03943f..679e1f8754be9b 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/jpeg.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,18 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h" -#ifndef TENSORFLOW_TSL_PLATFORM_JPEG_H_ -#define TENSORFLOW_TSL_PLATFORM_JPEG_H_ +namespace stablehlo::quantization { -#include -#include -#include -#include +QuantizationConfig PopulateDefaults( + const QuantizationConfig& user_provided_config) { + QuantizationConfig config = user_provided_config; -extern "C" { -#include "jerror.h" // from @libjpeg_turbo // IWYU pragma: export -#include "jpeglib.h" // from @libjpeg_turbo // IWYU pragma: export + PipelineConfig& pipeline_config = *config.mutable_pipeline_config(); + if (!pipeline_config.has_unpack_quantized_types()) { + pipeline_config.set_unpack_quantized_types(true); + } + + return config; } -#endif // TENSORFLOW_TSL_PLATFORM_JPEG_H_ +} // namespace stablehlo::quantization diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h new file mode 100644 index 00000000000000..20b9efa4a60fa0 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h @@ -0,0 +1,29 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CONFIG_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CONFIG_H_ + +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" + +namespace stablehlo::quantization { + +// Returns a copy of `user_provided_config` with default values populated where +// the user did not explicitly specify. +QuantizationConfig PopulateDefaults( + const QuantizationConfig& user_provided_config); + +} // namespace stablehlo::quantization + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CONFIG_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/io_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/io_test.cc index b5cee2fc492f85..f4f1c5c16589e4 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/io_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/io_test.cc @@ -71,12 +71,12 @@ class TestEnvBrokenFileSystem : public tsl::Env { absl::Status LoadDynamicLibrary(const char* library_filename, void** handle) override { - return tsl::OkStatus(); + return absl::OkStatus(); } absl::Status GetSymbolFromLibrary(void* handle, const char* symbol_name, void** symbol) override { - return tsl::OkStatus(); + return absl::OkStatus(); } tsl::string FormatLibraryFileName(const tsl::string& name, diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc index f8df4095d0e636..31c67f2d20c4ff 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc @@ -20,11 +20,45 @@ limitations under the License. #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" namespace mlir::quant::stablehlo { +using ::stablehlo::quantization::PipelineConfig; +using ::stablehlo::quantization::QuantizationSpecs; +using ::stablehlo::quantization::StaticRangePtqPreset; +using ::tensorflow::quantization::CalibrationOptions; + +void AddPreCalibrationPasses(OpPassManager& pm, + const CalibrationOptions& calibration_options, + const QuantizationSpecs& quantization_specs) { + pm.addPass(CreateLiftQuantizableSpotsAsFunctionsPass(quantization_specs)); + pm.addNestedPass( + CreateInsertCustomAggregationOpsPass(calibration_options)); + pm.addPass(CreateIssueIDsOfCustomAggregationOpsPass()); + // StableHLO Quantizer currently uses TF's calibration passes. Serialize + // the StableHLO module as tf.XlaCallModule to run calibration. + AddCallModuleSerializationPasses(pm); +} + +void AddPostCalibrationPasses( + OpPassManager& pm, const PipelineConfig& pipeline_config, + const StaticRangePtqPreset& static_range_ptq_preset) { + QuantizeCompositeFunctionsPassOptions options; + options.enable_per_channel_quantized_weight_ = + static_range_ptq_preset.enable_per_channel_quantized_weight(); + pm.addNestedPass( + CreateConvertCustomAggregationOpToQuantStatsPass()); + pm.addPass(createQuantizeCompositeFunctionsPass(options)); + if (pipeline_config.unpack_quantized_types()) { + AddStablehloQuantToIntPasses(pm); + } + AddCallModuleSerializationPasses(pm); +} + void AddXlaCallModuleOpDeserializationPasses(OpPassManager& pm) { pm.addPass(TF::CreateXlaCallModuleDeserializationPass()); pm.addPass(createRestoreFunctionNamePass()); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h index e5272732400365..5920619bd3fb8d 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h @@ -16,9 +16,26 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_PASS_PIPELINE_H_ #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" namespace mlir::quant::stablehlo { +// Adds passes for static-range quantization pre-calibration. Inserts ops +// required to collect tensor statistics. +void AddPreCalibrationPasses( + OpPassManager& pm, + const ::tensorflow::quantization::CalibrationOptions& calibration_options, + const ::stablehlo::quantization::QuantizationSpecs& specs); + +// Adds passes for static-range quantization post-calibration. Utilizes tensor +// statistics collected from the calibration step and performs quantization. +void AddPostCalibrationPasses( + OpPassManager& pm, + const ::stablehlo::quantization::PipelineConfig& pipeline_config, + const ::stablehlo::quantization::StaticRangePtqPreset& + static_range_ptq_preset); + // Deserializes StableHLO functions serialized and embedded in XlaCallModuleOps. void AddXlaCallModuleOpDeserializationPasses(OpPassManager& pm); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.cc index 6d869bafd87788..6f5f10b48f41f5 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.cc @@ -17,14 +17,11 @@ limitations under the License. #include "absl/base/nullability.h" #include "absl/log/die_if_null.h" #include "absl/status/statusor.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project // IWYU: keep #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h" -#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "tsl/platform/errors.h" @@ -32,6 +29,7 @@ namespace mlir::quant::stablehlo { using ::stablehlo::quantization::PipelineConfig; using ::stablehlo::quantization::QuantizationConfig; +using ::stablehlo::quantization::StaticRangePtqPreset; using ::tensorflow::quantization::RunPasses; PostCalibrationComponent::PostCalibrationComponent( @@ -43,21 +41,17 @@ absl::StatusOr PostCalibrationComponent::Run( TF_RETURN_IF_ERROR(RunPasses( kName, /*add_passes_func=*/ [&config, this](PassManager& pm) { - AddPasses(pm, config.pipeline_config()); + AddPostCalibrationPasses(pm, config.pipeline_config(), + config.static_range_ptq_preset()); }, *ctx_, module_op)); return module_op; } void PostCalibrationComponent::AddPasses( - OpPassManager& pm, const PipelineConfig& pipeline_config) const { - pm.addNestedPass( - CreateConvertCustomAggregationOpToQuantStatsPass()); - pm.addPass(createQuantizeCompositeFunctionsPass()); - if (pipeline_config.unpack_quantized_types()) { - AddStablehloQuantToIntPasses(pm); - } - AddCallModuleSerializationPasses(pm); + OpPassManager& pm, const StaticRangePtqPreset& static_range_ptq_preset, + const PipelineConfig& pipeline_config) const { + AddPostCalibrationPasses(pm, pipeline_config, static_range_ptq_preset); } } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.h index 9383e744046e7b..3c218c9f857524 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.h @@ -45,11 +45,10 @@ class PostCalibrationComponent : public Component { ModuleOp module_op, const ::stablehlo::quantization::QuantizationConfig& config) override; - // Adds MLIR passes to the pass manager. `Run` will essentially run these - // passes on the module op. `pipeline_config` configures the behavior of the - // passes. void AddPasses( OpPassManager& pm, + const ::stablehlo::quantization::StaticRangePtqPreset& + static_range_ptq_preset, const ::stablehlo::quantization::PipelineConfig& pipeline_config) const; private: diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.cc index 2f9882417420dd..f54f947990866b 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.cc @@ -19,14 +19,11 @@ limitations under the License. #include "absl/base/nullability.h" #include "absl/log/die_if_null.h" #include "absl/status/statusor.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h" -#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tsl/platform/errors.h" @@ -44,15 +41,9 @@ PreCalibrationComponent::PreCalibrationComponent( absl::StatusOr PreCalibrationComponent::Run( ModuleOp module_op, const QuantizationConfig& config) { TF_RETURN_IF_ERROR(RunPasses( - /*name=*/kName, /*add_passes_func=*/ - [this](PassManager& pm) { - pm.addPass(createLiftQuantizableSpotsAsFunctionsPass()); - pm.addNestedPass( - CreateInsertCustomAggregationOpsPass(calibration_options_)); - pm.addPass(CreateIssueIDsOfCustomAggregationOpsPass()); - // StableHLO Quantizer currently uses TF's calibration passes. Serialize - // the StableHLO module as tf.XlaCallModule to run calibration. - AddCallModuleSerializationPasses(pm); + kName, /*add_passes_func=*/ + [&config, this](PassManager& pm) { + AddPreCalibrationPasses(pm, calibration_options_, config.specs()); }, *ctx_, module_op)); return module_op; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD index 2138a04d2bb389..b7d75897cddd07 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( @@ -17,6 +18,7 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", + "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/compiler/mlir/tensorflow", "@llvm-project//llvm:Support", @@ -26,3 +28,20 @@ cc_library( "@stablehlo//:stablehlo_ops", ], ) + +tf_cc_test( + name = "stablehlo_op_quant_spec_test", + srcs = ["stablehlo_op_quant_spec_test.cc"], + deps = [ + ":stablehlo_op_quant_spec", + "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/common:test_base", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/core:test", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@stablehlo//:stablehlo_ops", + ], +) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc index 9e832650c24a23..bbcff2dcdbe6d2 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -59,8 +60,8 @@ std::unique_ptr GetStableHloOpQuantSpec(Operation* op) { quant::GetUniformQuantizedTypeForBias}; } } - for (auto quantizable_operand : spec->coeff_op_quant_dim) { - spec->quantizable_operands.insert(quantizable_operand.first); + for (const auto [operand_idx, per_channel_dim] : spec->coeff_op_quant_dim) { + spec->quantizable_operands.insert(operand_idx); } } return spec; @@ -69,18 +70,17 @@ std::unique_ptr GetStableHloOpQuantSpec(Operation* op) { std::unique_ptr GetStableHloQuantScaleSpec(Operation* op) { auto scale_spec = std::make_unique(); if (llvm::isa(op)) { + mlir::stablehlo::ConcatenateOp, mlir::stablehlo::GatherOp, + mlir::stablehlo::PadOp, mlir::stablehlo::ReduceWindowOp, + mlir::stablehlo::ReshapeOp, mlir::stablehlo::SelectOp, + mlir::stablehlo::SliceOp, mlir::stablehlo::TransposeOp>(op)) { scale_spec->has_same_scale_requirement = true; } return scale_spec; } bool IsOpQuantizableStableHlo(Operation* op) { - if (mlir::isa(op)) { + if (isa(op)) { // Constant ops do not have QuantizableResult attribute but can be // quantized. return true; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/stablehlo_op_quant_spec_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec_test.cc similarity index 56% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/stablehlo_op_quant_spec_test.cc rename to tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec_test.cc index a469bd7c349c99..80a6d2fa451e6c 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/stablehlo_op_quant_spec_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec_test.cc @@ -15,33 +15,28 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h" +#include #include #include "absl/strings/string_view.h" -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/Parser/Parser.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/quantization/common/test_base.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/platform/test.h" namespace mlir::quant::stablehlo { namespace { -using ::mlir::quant::QuantizationTestBase; +using ::testing::NotNull; class IsOpQuantizableStableHloTest : public QuantizationTestBase {}; // Quantizable ops: constants // Non-quantizable ops: normal StableHLO ops and terminators -constexpr absl::string_view module_constant_add = R"mlir( +constexpr absl::string_view kModuleConstantAdd = R"mlir( module { func.func @constant_add() -> (tensor<3x2xf32>) { %cst1 = stablehlo.constant dense<2.4> : tensor<3x2xf32> @@ -55,7 +50,7 @@ constexpr absl::string_view module_constant_add = R"mlir( // Quantizable ops: XlaCallModule op with "fully_quantizable" attribute and // same-scale StableHLO ops // Non-quantizable ops: quantize/dequantize ops -constexpr absl::string_view module_composite_same_scale = R"mlir( +constexpr absl::string_view kModuleCompositeSameScale = R"mlir( module { func.func @same_scale_after_composite() -> tensor<3x1xf32> { %0 = "tf.XlaCallModule"() {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : () -> tensor<1x3xf32> @@ -70,7 +65,7 @@ constexpr absl::string_view module_composite_same_scale = R"mlir( )mlir"; // Non-quantizable ops: XlaCallModule op without "fully_quantizable" attribute -constexpr absl::string_view module_composite_no_attr = R"mlir( +constexpr absl::string_view kModuleCompositeNoAttr = R"mlir( module { func.func @composite_without_attr() -> tensor<1x3xf32> { %0 = "tf.XlaCallModule"() {Sout = [#tf_type.shape<1x3>], _entry_function = @non_quantizable_composite, _original_entry_function = "non_quantizable_composite", _stablehlo_module_attrs = {}, device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : () -> tensor<1x3xf32> @@ -80,97 +75,79 @@ constexpr absl::string_view module_composite_no_attr = R"mlir( )mlir"; TEST_F(IsOpQuantizableStableHloTest, ConstantOpQuantizable) { - OwningOpRef module_op_ref = - ParseModuleOpString(module_constant_add); - func::FuncOp test_func = - GetFunctionFromModule(*module_op_ref, "constant_add"); - Operation* constant_op = - FindOperationOfType(test_func); - bool is_constant_quantizable = - mlir::quant::stablehlo::IsOpQuantizableStableHlo(constant_op); + OwningOpRef module_op_ref = ParseModuleOpString(kModuleConstantAdd); + auto test_func = module_op_ref->lookupSymbol("constant_add"); + ASSERT_THAT(test_func, NotNull()); - EXPECT_TRUE(is_constant_quantizable); + auto constant_op = + FindOperationOfType(test_func); + EXPECT_TRUE(IsOpQuantizableStableHlo(constant_op)); } TEST_F(IsOpQuantizableStableHloTest, TerminatorOpNotQuantizable) { - OwningOpRef module_op_ref = - ParseModuleOpString(module_constant_add); - func::FuncOp test_func = - GetFunctionFromModule(*module_op_ref, "constant_add"); - Operation* return_op = FindOperationOfType(test_func); - bool is_return_quantizable = - mlir::quant::stablehlo::IsOpQuantizableStableHlo(return_op); - - EXPECT_FALSE(is_return_quantizable); + OwningOpRef module_op_ref = ParseModuleOpString(kModuleConstantAdd); + auto test_func = module_op_ref->lookupSymbol("constant_add"); + ASSERT_THAT(test_func, NotNull()); + + auto return_op = FindOperationOfType(test_func); + EXPECT_FALSE(IsOpQuantizableStableHlo(return_op)); } TEST_F(IsOpQuantizableStableHloTest, SameScaleOpQuantizable) { OwningOpRef module_op_ref = - ParseModuleOpString(module_composite_same_scale); - func::FuncOp test_func = - GetFunctionFromModule(*module_op_ref, "same_scale_after_composite"); - Operation* reshape_op = - FindOperationOfType(test_func); - bool is_reshape_quantizable = - mlir::quant::stablehlo::IsOpQuantizableStableHlo(reshape_op); - - EXPECT_TRUE(is_reshape_quantizable); + ParseModuleOpString(kModuleCompositeSameScale); + auto test_func = + module_op_ref->lookupSymbol("same_scale_after_composite"); + ASSERT_THAT(test_func, NotNull()); + + auto reshape_op = FindOperationOfType(test_func); + EXPECT_TRUE(IsOpQuantizableStableHlo(reshape_op)); } TEST_F(IsOpQuantizableStableHloTest, NonSameScaleOpNotQuantizable) { - OwningOpRef module_op_ref = - ParseModuleOpString(module_constant_add); - func::FuncOp test_func = - GetFunctionFromModule(*module_op_ref, "constant_add"); - Operation* add_op = FindOperationOfType(test_func); - bool is_add_quantizable = - mlir::quant::stablehlo::IsOpQuantizableStableHlo(add_op); - - EXPECT_FALSE(is_add_quantizable); + OwningOpRef module_op_ref = ParseModuleOpString(kModuleConstantAdd); + auto test_func = module_op_ref->lookupSymbol("constant_add"); + ASSERT_THAT(test_func, NotNull()); + + auto add_op = FindOperationOfType(test_func); + EXPECT_FALSE(IsOpQuantizableStableHlo(add_op)); } TEST_F(IsOpQuantizableStableHloTest, ValidXlaCallModuleOpQuantizable) { OwningOpRef module_op_ref = - ParseModuleOpString(module_composite_same_scale); - func::FuncOp test_func = - GetFunctionFromModule(*module_op_ref, "same_scale_after_composite"); - Operation* xla_call_module_op = - FindOperationOfType(test_func); - bool is_xla_call_module_quantizable = - mlir::quant::stablehlo::IsOpQuantizableStableHlo(xla_call_module_op); - - EXPECT_TRUE(is_xla_call_module_quantizable); + ParseModuleOpString(kModuleCompositeSameScale); + auto test_func = + module_op_ref->lookupSymbol("same_scale_after_composite"); + ASSERT_THAT(test_func, NotNull()); + + auto xla_call_module_op = FindOperationOfType(test_func); + EXPECT_TRUE(IsOpQuantizableStableHlo(xla_call_module_op)); } TEST_F(IsOpQuantizableStableHloTest, InvalidXlaCallModuleOpNotQuantizable) { OwningOpRef module_op_ref = - ParseModuleOpString(module_composite_no_attr); - func::FuncOp test_func = - GetFunctionFromModule(*module_op_ref, "composite_without_attr"); - Operation* xla_call_module_op = - FindOperationOfType(test_func); - bool is_xla_call_module_quantizable = - mlir::quant::stablehlo::IsOpQuantizableStableHlo(xla_call_module_op); - - EXPECT_FALSE(is_xla_call_module_quantizable); + ParseModuleOpString(kModuleCompositeNoAttr); + auto test_func = + module_op_ref->lookupSymbol("composite_without_attr"); + ASSERT_THAT(test_func, NotNull()); + + auto xla_call_module_op = FindOperationOfType(test_func); + EXPECT_FALSE(IsOpQuantizableStableHlo(xla_call_module_op)); } TEST_F(IsOpQuantizableStableHloTest, QuantizeDequantizeOpNotQuantizable) { OwningOpRef module_op_ref = - ParseModuleOpString(module_composite_same_scale); - func::FuncOp test_func = - GetFunctionFromModule(*module_op_ref, "same_scale_after_composite"); - Operation* quantize_op = - FindOperationOfType(test_func); - Operation* dequantize_op = - FindOperationOfType(test_func); - bool is_quantize_quantizable = - mlir::quant::stablehlo::IsOpQuantizableStableHlo(quantize_op); - bool is_dequantize_quantizable = - mlir::quant::stablehlo::IsOpQuantizableStableHlo(dequantize_op); + ParseModuleOpString(kModuleCompositeSameScale); + auto test_func = + module_op_ref->lookupSymbol("same_scale_after_composite"); + ASSERT_THAT(test_func, NotNull()); + + auto quantize_op = FindOperationOfType(test_func); + EXPECT_FALSE(IsOpQuantizableStableHlo(quantize_op)); - EXPECT_FALSE(is_quantize_quantizable); - EXPECT_FALSE(is_dequantize_quantizable); + auto dequantize_op = + FindOperationOfType(test_func); + EXPECT_FALSE(IsOpQuantizableStableHlo(dequantize_op)); } } // namespace diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc index 633c193ab57c7c..f572a0795e3b77 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc @@ -89,7 +89,7 @@ UniformQuantizedPerAxisType GetPerChannelType(QuantType quant_type) { void GetQuantizationParams(OpBuilder &builder, Location loc, QuantType quant_type, Value &scales, Value &zero_points, bool output_zero_point_in_fp, - DenseIntElementsAttr &broadcast_dims) { + DenseI64ArrayAttr &broadcast_dims) { // Get scales/zero points for per-tensor and per-axis quantization cases. if (auto *quant_per_tensor_type = std::get_if(&quant_type)) { @@ -140,8 +140,8 @@ void GetQuantizationParams(OpBuilder &builder, Location loc, builder.getI32Type()), zero_points_vec)); } - broadcast_dims = DenseIntElementsAttr::get( - RankedTensorType::get({1}, builder.getI64Type()), + broadcast_dims = DenseI64ArrayAttr::get( + builder.getContext(), {static_cast(quant_per_channel_type.getQuantizedDimension())}); } } @@ -256,9 +256,8 @@ Value ApplyMergedScalesAndZps(OpBuilder &builder, Location loc, merged_scale_double.end()), merged_zp_float(merged_zp_double.begin(), merged_zp_double.end()); - auto broadcast_dims = DenseIntElementsAttr::get( - RankedTensorType::get({1}, builder.getI64Type()), - {quantized_dimension}); + auto broadcast_dims = + DenseI64ArrayAttr::get(builder.getContext(), {quantized_dimension}); Value merged_scale = builder.create( loc, DenseFPElementsAttr::get( RankedTensorType::get({channel_size}, builder.getF32Type()), @@ -367,7 +366,7 @@ class ConvertUniformQuantizeOp ConversionPatternRewriter &rewriter, QuantType quant_type) const { Value scales, zero_points; - DenseIntElementsAttr broadcast_dims; + DenseI64ArrayAttr broadcast_dims; GetQuantizationParams(rewriter, op->getLoc(), quant_type, scales, zero_points, /*output_zero_point_in_fp=*/true, broadcast_dims); @@ -425,7 +424,7 @@ class ConvertUniformDequantizeOp return failure(); } Value scales, zero_points; - DenseIntElementsAttr broadcast_dims; + DenseI64ArrayAttr broadcast_dims; GetQuantizationParams(rewriter, op->getLoc(), *quant_type, scales, zero_points, /*output_zero_point_in_fp=*/false, broadcast_dims); @@ -465,15 +464,41 @@ class ConvertUniformQuantizedAddOp : public OpConversionPattern { // We only handle cases where lhs, rhs and results all have quantized // element type. - if (failed(lhs_quant_type) || IsPerChannelType(*lhs_quant_type) || - failed(rhs_quant_type) || IsPerChannelType(*rhs_quant_type) || - failed(res_quant_type) || IsPerChannelType(*res_quant_type)) { + if (failed(lhs_quant_type) || failed(rhs_quant_type) || + failed(res_quant_type)) { op->emitError( - "AddOp requires the same quantized element type for all operands and " + "AddOp requires the quantized element type for all operands and " "results"); return failure(); } + if (IsPerChannelType(*lhs_quant_type) || + IsPerChannelType(*rhs_quant_type) || + IsPerChannelType(*res_quant_type)) { + // Handle Per-Channel Quantized Types. We only support lhs/rhs/result with + // exact same per-channel quantized types with I32 storage type. + if (!IsPerChannelType(*lhs_quant_type) || + !IsPerChannelType(*rhs_quant_type) || + !IsPerChannelType(*res_quant_type) || + GetPerChannelType(*lhs_quant_type) != + GetPerChannelType(*rhs_quant_type) || + GetPerChannelType(*lhs_quant_type) != + GetPerChannelType(*res_quant_type)) { + op->emitError( + "Per-channel quantized AddOp requires the same quantized element " + "type for all operands and results"); + return failure(); + } + if (!GetPerChannelType(*lhs_quant_type).getStorageType().isInteger(32)) { + // For server-side StableHLO Quantization, add is quantized only when + // fused with conv/dot ops, whose output must be i32. + op->emitError("Per-channel quantized AddOp requires i32 storage type"); + return failure(); + } + return matchAndRewritePerChannel(op, adaptor, rewriter, + GetPerChannelType(*lhs_quant_type)); + } + // TODO: b/260280919 - Consider avoiding conversion to int32. auto res_int32_tensor_type = op.getResult().getType().clone(rewriter.getI32Type()); @@ -536,6 +561,33 @@ class ConvertUniformQuantizedAddOp : public OpConversionPattern { return success(); } + + LogicalResult matchAndRewritePerChannel( + mhlo::AddOp op, mhlo::AddOpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + UniformQuantizedPerAxisType quant_type) const { + // We assume lhs/rhs/result have the same quantized type with i32 storage. + Value add_result = rewriter.create( + op->getLoc(), adaptor.getLhs(), adaptor.getRhs()); + // Add zp contribution if it is non-zero for any channel. + if (llvm::any_of(quant_type.getZeroPoints(), + [](int64_t zp) { return zp != 0; })) { + SmallVector zps_vec(quant_type.getZeroPoints().begin(), + quant_type.getZeroPoints().end()); + Value zps = rewriter.create( + op->getLoc(), + DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(zps_vec.size())}, + rewriter.getI32Type()), + zps_vec)); + add_result = rewriter.create( + op->getLoc(), add_result, zps, + rewriter.getDenseI64ArrayAttr( + {static_cast(quant_type.getQuantizedDimension())})); + } + rewriter.replaceOp(op, add_result); + return success(); + } }; // This is a convenient struct for holding dimension numbers for dot-like ops diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_ops_to_mhlo.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_ops_to_mhlo.cc index a24db896d79f1d..a6312e067af50f 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_ops_to_mhlo.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_ops_to_mhlo.cc @@ -596,7 +596,7 @@ class ConvertUniformQuantizedAddOp // rhs (bias) is always 1D that broadcasts to the last dim of lhs. auto broadcast_dims = - mhlo::GetI64ElementsAttr({lhs_type.getRank() - 1}, &rewriter); + rewriter.getDenseI64ArrayAttr({lhs_type.getRank() - 1}); auto rhs_type = GetUniformQuantizedType( op, op.getRhs().getType(), op.getRhsScales(), op.getRhsZeroPoints(), @@ -651,8 +651,7 @@ class ConvertUniformQuantizedClipByValueOp if (quantization_axis >= 0) { broadcast_dims_values.push_back(quantization_axis); } - auto broadcast_dims = - mhlo::GetI64ElementsAttr(broadcast_dims_values, &rewriter); + auto broadcast_dims = rewriter.getDenseI64ArrayAttr(broadcast_dims_values); auto min_max_type = GetUniformQuantizedType( op, op.getMin().getType(), op.getScales(), op.getZeroPoints(), diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/legalize_tf_quant_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/legalize_tf_quant_test.cc index 4c20b6bebdcdad..93946fdc320a97 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/legalize_tf_quant_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/legalize_tf_quant_test.cc @@ -23,8 +23,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.h" #include "xla/client/client_library.h" #include "xla/shape.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" @@ -45,7 +45,7 @@ class LegalizeTFQuantTest : public Test { tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED; mlir_to_hlo_args.mlir_module = mlir_module_string; tensorflow::se::Platform* platform = - tensorflow::se::MultiPlatformManager::PlatformWithName("Host").value(); + tensorflow::se::PlatformManager::PlatformWithName("Host").value(); auto client = xla::ClientLibrary::GetOrCreateCompileOnlyClient(platform).value(); tensorflow::tpu::TPUCompileMetadataProto metadata_proto; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_func_to_bfloat16.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_func_to_bfloat16.cc index 3e5b7e1f8d5ace..17300611e356ce 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_func_to_bfloat16.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_func_to_bfloat16.cc @@ -47,7 +47,29 @@ class BFloat16TypeConverter : public TypeConverter { } }; -// An Op is illegal iff it is non-UQ op and it contains qint types. +// This helper function makes legality check easier. Both convert ops in the +// patterns below are considered legal: +// - BitcastConvertOp(i32 -> f32) + ConvertOp(f32 -> bf16) +// - ConvertOp(bf16 -> f32) -> BitcastConvertOp(f32 -> i32) +template +bool IsConvertOpLegal(ConvertOp convert_op, BFloat16TypeConverter &converter) { + if (!converter.isLegal(convert_op.getOperand().getType())) { + auto other_convert_op = dyn_cast_or_null( + convert_op.getOperand().getDefiningOp()); + return other_convert_op && + converter.isLegal(other_convert_op.getOperand().getType()); + } else if (!converter.isLegal(convert_op.getResult().getType())) { + if (!convert_op.getResult().hasOneUse()) { + return false; + } + auto other_convert_op = dyn_cast_or_null( + *convert_op.getResult().getUsers().begin()); + return other_convert_op && + converter.isLegal(other_convert_op.getResult().getType()); + } + return true; +} + class BFloat16TypeConversionTarget : public ConversionTarget { public: explicit BFloat16TypeConversionTarget(MLIRContext &ctx, @@ -58,6 +80,15 @@ class BFloat16TypeConversionTarget : public ConversionTarget { // types do not contain. if (auto func = dyn_cast(op)) { if (!converter_.isSignatureLegal(func.getFunctionType())) return false; + } else if (auto bitcast_convert_op = + dyn_cast(op)) { + return IsConvertOpLegal(bitcast_convert_op, + converter_); + } else if (auto convert_op = dyn_cast(op)) { + return IsConvertOpLegal(convert_op, + converter_); } return converter_.isLegal(op); }); @@ -69,7 +100,7 @@ class BFloat16TypeConversionTarget : public ConversionTarget { class BFloat16TypePattern : public ConversionPattern { public: - BFloat16TypePattern(MLIRContext *ctx, TypeConverter &converter) + BFloat16TypePattern(TypeConverter &converter, MLIRContext *ctx) : ConversionPattern(converter, MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {} LogicalResult matchAndRewrite( @@ -78,6 +109,10 @@ class BFloat16TypePattern : public ConversionPattern { if (getTypeConverter()->isLegal(op)) { return failure(); } + if (isa(op)) { + // Skip BitcastConvertOp, which is handled by the other pattern. + return failure(); + } // Update the results. SmallVector new_results; @@ -118,6 +153,42 @@ class BFloat16TypePattern : public ConversionPattern { return success(); } }; + +class BitcastConvertOpPattern + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mlir::stablehlo::BitcastConvertOp op, + mlir::stablehlo::BitcastConvertOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + bool is_input_legal = + getTypeConverter()->isLegal(op.getOperand().getType()); + bool is_output_legal = + getTypeConverter()->isLegal(op.getResult().getType()); + if (is_input_legal && is_output_legal) { + return failure(); + } else if (is_input_legal) { + // output is f32, we bitcast_convert to f32 and then convert to bf16. + Value output = rewriter.create( + op->getLoc(), op.getResult().getType(), adaptor.getOperand()); + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getResult().getType()), + output); + } else if (is_output_legal) { + // input is f32, we convert from bf16 and then bitcast_convert. + Value output = rewriter.create( + op->getLoc(), op.getOperand().getType(), adaptor.getOperand()); + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), output); + } else { + // Both input/output are f32. Convert to no-op. + rewriter.replaceOp(op, adaptor.getOperand()); + } + return success(); + } +}; } // namespace #define GEN_PASS_DEF_CONVERTFUNCTOBFLOAT16PASS @@ -140,7 +211,8 @@ void ConvertFuncToBfloat16Pass::runOnOperation() { RewritePatternSet patterns(context); BFloat16TypeConverter converter; - patterns.add(context, converter); + patterns.add(converter, + context); populateFunctionOpInterfaceTypeConversionPattern(patterns, converter); BFloat16TypeConversionTarget target(*context, converter); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc index 6f13634b317aa4..dbe88208ae7b03 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc @@ -12,9 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project @@ -31,6 +35,11 @@ limitations under the License. #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep #include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tsl/platform/regexp.h" // IWYU pragma: keep + +#define DEBUG_TYPE "lift_quantizable_spots_as_functions" namespace mlir::quant::stablehlo { @@ -39,13 +48,16 @@ namespace mlir::quant::stablehlo { namespace { +using ::stablehlo::quantization::FunctionNameMatcherSpec; +using ::stablehlo::quantization::Method; +using ::stablehlo::quantization::QuantizationSpec; +using ::stablehlo::quantization::QuantizationSpecs; + // TODO - b/303543789: Move the helper functions below to a separate util. // Fetches the default or null attribute, used for pattern matching. Attribute DefaultOrNullAttr(OpBuilder& builder, const Attribute& attr) { - if (!attr) { - return builder.getStringAttr(kNullAttributeValue); - } - return attr; + if (attr) return attr; + return builder.getStringAttr(kNullAttributeValue); } // Checks whether the value of a constant equals the given float, regardless @@ -62,6 +74,12 @@ bool FloatValueEquals(const Attribute& attr, const double value) { }); } +// Lifts quantizable units as separate functions, thereby identifying the +// boundaries of quantizable subgraphs. `QuantizationSpecs` influences how +// quantizable units are lifted. +// +// FileCheck test cases using various `QuantizationSpecs` can be seen at +// `TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass`. class LiftQuantizableSpotsAsFunctionsPass : public impl::LiftQuantizableSpotsAsFunctionsPassBase< LiftQuantizableSpotsAsFunctionsPass> { @@ -69,10 +87,19 @@ class LiftQuantizableSpotsAsFunctionsPass MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( LiftQuantizableSpotsAsFunctionsPass) - explicit LiftQuantizableSpotsAsFunctionsPass() = default; + LiftQuantizableSpotsAsFunctionsPass() = default; + + // Constructor with explicit user-provided `QuantizationSpecs`. + explicit LiftQuantizableSpotsAsFunctionsPass( + QuantizationSpecs quantization_specs) + : quantization_specs_(std::move(quantization_specs)) {} private: void runOnOperation() override; + + // No explicit quantization spec is specified by default. Implicitly this + // means that all quantizable units will be identified and lifted. + QuantizationSpecs quantization_specs_{}; }; namespace simple_patterns { @@ -83,6 +110,91 @@ namespace fusion_patterns { #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.inc" } +// Returns a `func::FuncOp` in `module_op` (not nested) whose name matches +// `name`. Returns null if no such a function exists. +// TODO: b/307620778 - Factor out "FindMainFuncOp" functionality. +func::FuncOp FindFuncOp(ModuleOp module_op, const StringRef name) { + auto func_ops = module_op.getOps(); + auto func_itr = llvm::find_if(func_ops, [name](func::FuncOp func_op) { + return func_op.getName() == name; + }); + + if (func_itr == func_ops.end()) return {}; + return *func_itr; +} + +// Quantizable Unit matcher that uses lifted function's name for matching. +class FunctionNameMatcher { + public: + explicit FunctionNameMatcher(const FunctionNameMatcherSpec& spec) + : match_regex_(GetMatchRegex(spec)) {} + + // Returns `true` when matched with the entry function of + // `xla_call_module_op`. + bool Match(TF::XlaCallModuleOp xla_call_module_op) const { + if (match_regex_ == nullptr) return false; + + const std::string lifted_func_name = + xla_call_module_op->getAttrOfType("_entry_function") + .getValue() + .str(); + + return RE2::FullMatch(lifted_func_name, *match_regex_); // NOLINT + } + + private: + // Returns an owned `RE2` object that corresponds to the `spec`. Returns + // `nullptr` if the `spec` is invalid. + // NOLINTNEXTLINE - RE2 included via TSL regexp.h + std::unique_ptr GetMatchRegex(const FunctionNameMatcherSpec& spec) { + const std::string& regex = spec.regex(); + if (regex.empty()) return nullptr; + + return std::make_unique(regex); // NOLINT + } + + // Regex object used for matching against a lifted function's name. + std::unique_ptr match_regex_; // NOLINT +}; + +// Applies quantization spec to all matched lifted functions. At this point only +// denylisting (`NoQuantization`) will be applied if specs is nonempty. +// TODO: b/307620778 - Support more advanced selective quantization methods. +LogicalResult ApplyQuantizationSpec(const QuantizationSpec& spec, + ModuleOp module_op) { + func::FuncOp main_func = FindFuncOp(module_op, "main"); + if (!main_func) return failure(); + + const Method& quantization_method = spec.method(); + if (!quantization_method.has_no_quantization()) { + module_op->emitError() << "Unsupported quantization method: " + << quantization_method.DebugString() << "\n"; + return failure(); + } + + const FunctionNameMatcher matcher(spec.matcher().function_name()); + for (auto xla_call_module_op : main_func.getOps()) { + if (!matcher.Match(xla_call_module_op)) continue; + + // Disable quantization when matched. + const std::string lifted_func_name = + xla_call_module_op->getAttrOfType("_entry_function") + .getValue() + .str(); + func::FuncOp lifted_func = FindFuncOp(module_op, lifted_func_name); + + // Remove relevant attributes that enable quantization. This essentially + // disables quantization for the matched `xla_call_module_op`. + xla_call_module_op->removeAttr("_original_entry_function"); + xla_call_module_op->removeAttr("_tfl_quant_trait"); + lifted_func->removeAttr("tf_quant.composite_function"); + + LLVM_DEBUG(llvm::dbgs() << "Disabled quantization for quantizable unit: " + << lifted_func_name << "\n"); + } + return success(); +} + void LiftQuantizableSpotsAsFunctionsPass::runOnOperation() { MLIRContext* ctx = &getContext(); RewritePatternSet patterns(ctx); @@ -101,8 +213,26 @@ void LiftQuantizableSpotsAsFunctionsPass::runOnOperation() { // Remove all attr_map attributes. module_op.walk([](Operation* op) { op->removeAttr(kAttrMapAttribute); }); + + // Perform selective quantization. Iterates over the quantization specs and + // applies quantization methods to each matched lifted function. + for (const QuantizationSpec& spec : quantization_specs_.specs()) { + if (failed(ApplyQuantizationSpec(spec, module_op))) { + signalPassFailure(); + return; + } + } } } // namespace +// Creates `LiftQuantizableSpotsAsFunctionsPass` with user-defined +// `QuantizationSpecs`. +std::unique_ptr> +CreateLiftQuantizableSpotsAsFunctionsPass( + const QuantizationSpecs& quantization_specs) { + return std::make_unique( + quantization_specs); +} + } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.td index a0bc228397465b..6377740bf6018e 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.td @@ -76,12 +76,12 @@ def LiftDotGeneralWithBias : Pat< def LiftConvWithBiasDynamic : Pat< (StableHLO_AddOp:$res - (StableHLO_ConvolutionOp $lhs, $rhs, $window_strides, $padding, + (StableHLO_ConvolutionOp:$conv_0 $lhs, $rhs, $window_strides, $padding, $lhs_dilation, $rhs_dilation, $window_reversal, $dimension_numbers, $feature_group_count, $batch_group_count, $precision_config), (StableHLO_DynamicBroadcastInDimOp $bias, - (Shape_ShapeOfOp $conv), $_, $_, $_)), + (Shape_ShapeOfOp $conv_1), $_, $_, $_)), (LiftAsTFXlaCallModule<"composite_conv_with_bias_dynamic_fn"> (ArgumentList $lhs, $rhs, $bias), (ResultList $res), @@ -95,21 +95,21 @@ def LiftConvWithBiasDynamic : Pat< (NamedAttr<"feature_group_count"> $feature_group_count), (NamedAttr<"batch_group_count"> $batch_group_count), (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), - [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $bias)], [], (addBenefit 10)>; + [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $bias), (AreTheSameValue $conv_0, $conv_1)], [], (addBenefit 10)>; def LiftDotGeneralWithBiasDynamic : Pat< (StableHLO_AddOp:$res - (StableHLO_DotGeneralOp $lhs, $rhs, $dot_dimension_numbers, $precision_config), + (StableHLO_DotGeneralOp:$dot_general_0 $lhs, $rhs, $dot_dimension_numbers, $precision_config), (StableHLO_DynamicBroadcastInDimOp $bias, - (Shape_ShapeOfOp $dot_general), $_, $_, $_)), + (Shape_ShapeOfOp $dot_general_1), $_, $_, $_)), (LiftAsTFXlaCallModule<"composite_dot_general_with_bias_dynamic_fn"> (ArgumentList $lhs, $rhs, $bias), (ResultList $res), (NamedAttributeList (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), - [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $bias)], [], (addBenefit 10)>; + [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $bias), (AreTheSameValue $dot_general_0, $dot_general_1)], [], (addBenefit 10)>; //===----------------------------------------------------------------------===// // Pattern rules for lifting ops with activation as functions @@ -152,12 +152,12 @@ def LiftDotGeneralWithRelu : Pat< def LiftConvWithReluDynamic : Pat< (StableHLO_MaxOp:$res - (StableHLO_ConvolutionOp $lhs, $rhs, $window_strides, $padding, + (StableHLO_ConvolutionOp:$conv_0 $lhs, $rhs, $window_strides, $padding, $lhs_dilation, $rhs_dilation, $window_reversal, $dimension_numbers, $feature_group_count, $batch_group_count, $precision_config), (StableHLO_DynamicBroadcastInDimOp (StableHLO_ConstantOp $cst), - (Shape_ShapeOfOp $conv), $_, $_, $_)), + (Shape_ShapeOfOp $conv_1), $_, $_, $_)), (LiftAsTFXlaCallModule<"composite_conv_with_relu_dynamic_fn"> (ArgumentList $lhs, $rhs), (ResultList $res), @@ -172,14 +172,14 @@ def LiftConvWithReluDynamic : Pat< (NamedAttr<"batch_group_count"> $batch_group_count), (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), [(IsNotInLiftedFunc $res), - (FloatValueEquals<"0"> $cst)], [], (addBenefit 15)>; + (FloatValueEquals<"0"> $cst), (AreTheSameValue $conv_0, $conv_1)], [], (addBenefit 15)>; def LiftDotGeneralWithReluDynamic : Pat< (StableHLO_MaxOp:$res - (StableHLO_DotGeneralOp $lhs, $rhs, $dot_dimension_numbers, $precision_config), + (StableHLO_DotGeneralOp:$dot_general_0 $lhs, $rhs, $dot_dimension_numbers, $precision_config), (StableHLO_DynamicBroadcastInDimOp (StableHLO_ConstantOp $cst), - (Shape_ShapeOfOp $dot_general), $_, $_, $_)), + (Shape_ShapeOfOp $dot_general_1), $_, $_, $_)), (LiftAsTFXlaCallModule<"composite_dot_general_with_relu_dynamic_fn"> (ArgumentList $lhs, $rhs), (ResultList $res), @@ -187,7 +187,7 @@ def LiftDotGeneralWithReluDynamic : Pat< (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), [(IsNotInLiftedFunc $res), - (FloatValueEquals<"0"> $cst)], [], (addBenefit 15)>; + (FloatValueEquals<"0"> $cst), (AreTheSameValue $dot_general_0, $dot_general_1)], [], (addBenefit 15)>; def LiftConvWithRelu6 : Pat< (StableHLO_ClampOp:$res @@ -287,16 +287,16 @@ def LiftDotGeneralWithBiasAndRelu : Pat< def LiftConvWithBiasAndReluDynamic : Pat< (StableHLO_MaxOp:$res - (StableHLO_AddOp - (StableHLO_ConvolutionOp $lhs, $rhs, $window_strides, $padding, + (StableHLO_AddOp:$add_0 + (StableHLO_ConvolutionOp:$conv_0 $lhs, $rhs, $window_strides, $padding, $lhs_dilation, $rhs_dilation, $window_reversal, $dimension_numbers, $feature_group_count, $batch_group_count, $precision_config), (StableHLO_DynamicBroadcastInDimOp $bias, - (Shape_ShapeOfOp $conv), $_, $_, $_)), + (Shape_ShapeOfOp $conv_1), $_, $_, $_)), (StableHLO_DynamicBroadcastInDimOp (StableHLO_ConstantOp $cst), - (Shape_ShapeOfOp $add), $_, $_, $_)), + (Shape_ShapeOfOp $add_1), $_, $_, $_)), (LiftAsTFXlaCallModule<"composite_conv_with_bias_and_relu_dynamic_fn"> (ArgumentList $lhs, $rhs, $bias), (ResultList $res), @@ -311,18 +311,18 @@ def LiftConvWithBiasAndReluDynamic : Pat< (NamedAttr<"batch_group_count"> $batch_group_count), (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), [(IsNotInLiftedFunc $res), - (FloatValueEquals<"0"> $cst), (IsStableHLOConstantOp $bias)], [], (addBenefit 15)>; + (FloatValueEquals<"0"> $cst), (IsStableHLOConstantOp $bias), (AreTheSameValue $conv_0, $conv_1), (AreTheSameValue $add_0, $add_1)], [], (addBenefit 15)>; def LiftDotGeneralWithBiasAndReluDynamic : Pat< (StableHLO_MaxOp:$res - (StableHLO_AddOp - (StableHLO_DotGeneralOp $lhs, $rhs, $dot_dimension_numbers, $precision_config), + (StableHLO_AddOp:$add_0 + (StableHLO_DotGeneralOp:$dot_general_0 $lhs, $rhs, $dot_dimension_numbers, $precision_config), (StableHLO_DynamicBroadcastInDimOp $bias, - (Shape_ShapeOfOp $dot_general), $_, $_, $_)), + (Shape_ShapeOfOp $dot_general_1), $_, $_, $_)), (StableHLO_DynamicBroadcastInDimOp (StableHLO_ConstantOp $cst), - (Shape_ShapeOfOp $add), $_, $_, $_)), + (Shape_ShapeOfOp $add_1), $_, $_, $_)), (LiftAsTFXlaCallModule<"composite_dot_general_with_bias_and_relu_dynamic_fn"> (ArgumentList $lhs, $rhs, $bias), (ResultList $res), @@ -330,7 +330,7 @@ def LiftDotGeneralWithBiasAndReluDynamic : Pat< (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), [(IsNotInLiftedFunc $res), - (FloatValueEquals<"0"> $cst), (IsStableHLOConstantOp $bias)], [], (addBenefit 15)>; + (FloatValueEquals<"0"> $cst), (IsStableHLOConstantOp $bias), (AreTheSameValue $dot_general_0, $dot_general_1), (AreTheSameValue $add_0, $add_1)], [], (addBenefit 15)>; def LiftDotGeneralWithBiasSameShapeAndRelu6 : Pat< (StableHLO_ClampOp:$res @@ -392,12 +392,12 @@ def LiftConvWithBiasAndRelu6Dynamic : Pat< (StableHLO_ClampOp:$res (StableHLO_ConstantOp $cst_0), (StableHLO_AddOp - (StableHLO_ConvolutionOp $lhs, $rhs, $window_strides, $padding, + (StableHLO_ConvolutionOp:$conv_0 $lhs, $rhs, $window_strides, $padding, $lhs_dilation, $rhs_dilation, $window_reversal, $dimension_numbers, $feature_group_count, $batch_group_count, $precision_config), (StableHLO_DynamicBroadcastInDimOp $bias, - (Shape_ShapeOfOp $conv), $_, $_, $_)), + (Shape_ShapeOfOp $conv_1), $_, $_, $_)), (StableHLO_ConstantOp $cst_1)), (LiftAsTFXlaCallModule<"composite_conv_with_bias_and_relu6_dynamic_fn"> (ArgumentList $lhs, $rhs, $bias), @@ -412,17 +412,17 @@ def LiftConvWithBiasAndRelu6Dynamic : Pat< (NamedAttr<"feature_group_count"> $feature_group_count), (NamedAttr<"batch_group_count"> $batch_group_count), (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), - [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $bias), (FloatValueEquals<"0"> $cst_0), (FloatValueEquals<"6"> $cst_1)], [], (addBenefit 15)>; + [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $bias), (FloatValueEquals<"0"> $cst_0), (FloatValueEquals<"6"> $cst_1), (AreTheSameValue $conv_0, $conv_1)], [], (addBenefit 15)>; def LiftDotGeneralWithBiasAndRelu6Dynamic : Pat< (StableHLO_ClampOp:$res (StableHLO_ConstantOp $cst_0), (StableHLO_AddOp - (StableHLO_DotGeneralOp + (StableHLO_DotGeneralOp:$dot_general_0 $lhs, $rhs, $dot_dimension_numbers, $precision_config), (StableHLO_DynamicBroadcastInDimOp $bias, - (Shape_ShapeOfOp $dot_general), $_, $_, $_)), + (Shape_ShapeOfOp $dot_general_1), $_, $_, $_)), (StableHLO_ConstantOp $cst_1)), (LiftAsTFXlaCallModule<"composite_dot_general_with_bias_and_relu6_dynamic_fn"> (ArgumentList $lhs, $rhs, $bias), @@ -430,4 +430,4 @@ def LiftDotGeneralWithBiasAndRelu6Dynamic : Pat< (NamedAttributeList (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), - [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $bias), (FloatValueEquals<"0"> $cst_0), (FloatValueEquals<"6"> $cst_1)], [], (addBenefit 15)>; + [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $bias), (FloatValueEquals<"0"> $cst_0), (FloatValueEquals<"6"> $cst_1), (AreTheSameValue $dot_general_0, $dot_general_1)], [], (addBenefit 15)>; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h index 53f5e84640bd31..5bb3e58be01d58 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.pb.h" namespace mlir::quant::stablehlo { @@ -32,23 +33,23 @@ namespace mlir::quant::stablehlo { // Creates a `QuantizePass` that quantizes ops according to surrounding qcast / // dcast ops. std::unique_ptr> CreateQuantizePass( - const quant::QuantizationSpecs& quantization_specs); + const quant::QuantizationSpecs& quantization_specs, + bool enable_per_channel_quantized_weight = true); // Creates a pass that quantizes weight component of StableHLO graph. std::unique_ptr> CreateQuantizeWeightPass( const ::stablehlo::quantization::QuantizationComponentSpec& quantization_component_spec = {}); -// Creates an instance of the StableHLO dialect PrepareQuantize pass without any -// arguments. Preset method of SRQ is set to the quantization option by default. -std::unique_ptr> CreatePrepareQuantizePass( - bool enable_per_channel_quantization = false, int bit_width = 8); - // Converts a serialized StableHLO module to bfloat16 and output serialized // module. absl::StatusOr ConvertSerializedStableHloModuleToBfloat16( StringRef serialized_stablehlo_module); +std::unique_ptr> +CreateLiftQuantizableSpotsAsFunctionsPass( + const ::stablehlo::quantization::QuantizationSpecs& quantization_specs); + // Adds generated pass default constructors or options definitions. #define GEN_PASS_DECL // Adds generated pass registration functions. diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td index 3e1c431b6e6fd5..dcebf70eb5c1e9 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td @@ -21,25 +21,6 @@ def QuantizeWeightPass : Pass<"stablehlo-quantize-weight", "mlir::func::FuncOp"> let constructor = "mlir::quant::stablehlo::CreateQuantizeWeightPass()"; } -def PrepareQuantizePass : Pass<"stablehlo-prepare-quantize", "mlir::func::FuncOp"> { - let summary = "Prepare StableHLO dialect for static range quantization."; - let options = [ - Option<"enable_per_channel_quantization_", - "enable-per-channel-quantization", - "bool", /*default=*/"true", - "Whether enable per-channel quantized weights.">, - Option<"bit_width_", "bit-width", "int", /*default=*/"8", - "Bitwidth of quantized integer"> - ]; - let constructor = "mlir::quant::stablehlo::CreatePrepareQuantizePass()"; - let dependentDialects = [ - "mlir::stablehlo::StablehloDialect", - "mlir::quant::QuantizationDialect", - "mlir::quantfork::QuantizationForkDialect", - "mlir::arith::ArithDialect", - ]; -} - def UnfuseMhloBatchNormPass : Pass<"stablehlo-unfuse-mhlo-batch-norm", "mlir::func::FuncOp"> { let summary = "Unfuses batch normalization into arithmetic ops."; } @@ -53,6 +34,7 @@ def LiftQuantizableSpotsAsFunctionsPass : Pass<"stablehlo-lift-quantizable-spots that disperse values. (ex: convolution, dot_general) }]; let dependentDialects = [ + "mlir::func::FuncDialect", "mlir::stablehlo::StablehloDialect", "TF::TensorFlowDialect", ]; @@ -67,40 +49,62 @@ def ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass : Pass<"stablehlo- }]; } -def QuantizePass : Pass<"stablehlo-quantize", "mlir::ModuleOp"> { - let summary = "Applies static-range quantization on ops."; +def RestoreFunctionNamePass : Pass<"stablehlo-restore-function-name", "ModuleOp"> { + let summary = "Restores function name from XlaCallModule op."; +} + +def QuantizeCompositeFunctionsPass : Pass<"stablehlo-quantize-composite-functions", "ModuleOp"> { + let summary = "Quantize composite functions with QDQ input / outputs."; + let options = [ + Option<"enable_per_channel_quantized_weight_", + "enable-per-channel-quantized-weight", + "bool", /*default=*/"true", + "Whether to enable per-channel quantized weights.">, + Option<"mlir_dump_file_name_", "mlir-dump-file-name", + "std::optional", /*default=*/"std::nullopt", + "MLIR dump file name."> + ]; let dependentDialects = [ + "mlir::arith::ArithDialect", "mlir::stablehlo::StablehloDialect", "mlir::quant::QuantizationDialect", "mlir::quantfork::QuantizationForkDialect", + "TF::TensorFlowDialect", ]; } -def RestoreFunctionNamePass : Pass<"stablehlo-restore-function-name", "ModuleOp"> { - let summary = "Restores function name from XlaCallModule op."; +def PrepareQuantizePass : Pass<"stablehlo-prepare-quantize", "mlir::func::FuncOp"> { + let summary = "Prepare StableHLO dialect for static range quantization by converting quantfork.stats into quantfork.qcast and dcast ops."; + let options = [ + Option<"enable_per_channel_quantized_weight_", + "enable-per-channel-quantized-weight", + "bool", /*default=*/"true", + "Whether to enable per-channel quantized weights.">, + Option<"bit_width_", "bit-width", "int", /*default=*/"8", + "Bitwidth of quantized integer"> + ]; + let dependentDialects = [ + "mlir::stablehlo::StablehloDialect", + "mlir::quant::QuantizationDialect", + "mlir::quantfork::QuantizationForkDialect", + "mlir::arith::ArithDialect", + ]; } -def PostQuantizePass : Pass<"stablehlo-post-quantize", "mlir::func::FuncOp"> { - let summary = "Apply clean-up after quantization."; +def QuantizePass : Pass<"stablehlo-quantize", "mlir::ModuleOp"> { + let summary = "Applies static-range quantization on ops by converting quantfork.qcast, quantfork.dcast, and float op into uniform quantized ops ."; let dependentDialects = [ "mlir::stablehlo::StablehloDialect", + "mlir::quant::QuantizationDialect", "mlir::quantfork::QuantizationForkDialect", ]; } -def QuantizeCompositeFunctionsPass : Pass<"stablehlo-quantize-composite-functions", "ModuleOp"> { - let summary = "Quantize composite functions with QDQ input / outputs."; - let options = [ - Option<"mlir_dump_file_name_", "mlir-dump-file-name", - "std::optional", /*default=*/"std::nullopt", - "MLIR dump file name."> - ]; +def PostQuantizePass : Pass<"stablehlo-post-quantize", "mlir::func::FuncOp"> { + let summary = "Apply clean-up after quantization."; let dependentDialects = [ - "mlir::arith::ArithDialect", "mlir::stablehlo::StablehloDialect", - "mlir::quant::QuantizationDialect", "mlir::quantfork::QuantizationForkDialect", - "TF::TensorFlowDialect", ]; } @@ -124,11 +128,6 @@ def ConvertXlaCallModuleOpToBfloat16Pass : Pass<"stablehlo-convert-xla-call-modu ]; } -def PopulateShapePass : Pass<"populate-shape", "ModuleOp"> { - let summary = "Populate output shape with known information for CustomAggregatorOp and XlaCallModuleOp."; - let dependentDialects = ["TF::TensorFlowDialect"]; -} - def OptimizeGraphPass : Pass<"optimize-graph", "ModuleOp"> { let summary = "Optimize the sub-optimal patterns after quantization."; let dependentDialects = ["mlir::stablehlo::StablehloDialect",]; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/populate_shape.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/populate_shape.cc deleted file mode 100644 index 0d4f0594f5c7d8..00000000000000 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/populate_shape.cc +++ /dev/null @@ -1,144 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include - -#include "llvm/Support/Casting.h" -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/TypeRange.h" // from @llvm-project -#include "mlir/IR/TypeUtilities.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Support/TypeID.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" // IWYU pragma: keep -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" -#include "tensorflow/core/ir/types/dialect.h" - -namespace mlir::quant::stablehlo { - -#define GEN_PASS_DEF_POPULATESHAPEPASS -#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h.inc" - -namespace { - -class PopulateShapeForCustomAggregatorOp - : public OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - TF::CustomAggregatorOp op, TF::CustomAggregatorOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto input_shape_type = op.getInput().getType().dyn_cast(); - auto output_shape_type = op.getOutput().getType(); - - if (!input_shape_type.isa()) { - input_shape_type = adaptor.getInput().getType(); - } - - if (input_shape_type.isa() && - !output_shape_type.isa() && - TF::HasCompatibleElementTypes(input_shape_type, output_shape_type)) { - auto new_op = rewriter.create( - op->getLoc(), /*output=*/input_shape_type, - /*args=*/adaptor.getInput(), - /*Id=*/op.getId()); - new_op->setAttrs(op->getAttrs()); - rewriter.replaceOp(op, new_op); - return success(); - } - return failure(); - } -}; - -class PopulateShapeForXlaCallModuleOp - : public OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - TF::XlaCallModuleOp op, TF::XlaCallModuleOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (op->getNumResults() != 1) { - op->emitError("XlaCallModuleOp doesn't have 1 output."); - return failure(); - } - // Assume XlaCallModuleOp only has 1 output. - auto output_shape_type = op->getResultTypes()[0]; - if (!output_shape_type.isa()) { - auto output_shape_attr = op.getSout()[0].dyn_cast(); - if (!output_shape_attr.hasRank()) { - return failure(); - } - auto new_output_shape_type = tensorflow::GetTypeFromTFTensorShape( - output_shape_attr.getShape(), - getElementTypeOrSelf(op.getResultTypes()[0])); - auto new_op = rewriter.create( - op->getLoc(), /*output=*/new_output_shape_type, - /*args=*/adaptor.getOperands(), - /*version=*/op.getVersionAttr(), - /*module=*/op.getModuleAttr(), - /*Sout=*/op.getSoutAttr()); - new_op->setAttrs(op->getAttrs()); - rewriter.replaceOp(op, new_op); - return success(); - } - return failure(); - } -}; - -class PopulateShapePass - : public impl::PopulateShapePassBase { - public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PopulateShapePass) - - explicit PopulateShapePass() = default; - - private: - void runOnOperation() override; -}; - -void PopulateShapePass::runOnOperation() { - Operation *op = getOperation(); - MLIRContext *context = op->getContext(); - RewritePatternSet patterns(context); - ConversionTarget target(*context); - target.addDynamicallyLegalOp([](Operation *op) { - auto custom_aggregator_op = llvm::dyn_cast(op); - return custom_aggregator_op.getInput().getType().isa() && - custom_aggregator_op.getOutput().getType().isa(); - }); - target.addDynamicallyLegalOp([](Operation *op) { - if (op->getNumResults() != 1) return true; - return op->getResultTypes()[0].isa(); - }); - - patterns - .add( - context); - - if (failed(applyPartialConversion(op, target, std::move(patterns)))) { - return signalPassFailure(); - } -} -} // namespace - -} // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize.cc index 1291b0f7aa83eb..688e21b7d898dc 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize.cc @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" @@ -40,11 +41,11 @@ namespace mlir { namespace quant { namespace stablehlo { -namespace { - #define GEN_PASS_DEF_PREPAREQUANTIZEPASS #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h.inc" +namespace { + // Applies prepare quantization on the model in TF dialect. This pass runs // before the quantization pass and propagate the quantization parameters // across ops. This step is necessary for post-training quantization and also @@ -53,12 +54,14 @@ namespace { class PrepareQuantizePass : public impl::PrepareQuantizePassBase { public: - PrepareQuantizePass() = default; - PrepareQuantizePass(const PrepareQuantizePass&) = default; + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrepareQuantizePass) + + using impl::PrepareQuantizePassBase< + PrepareQuantizePass>::PrepareQuantizePassBase; - explicit PrepareQuantizePass(bool enable_per_channel_quantization, + explicit PrepareQuantizePass(bool enable_per_channel_quantized_weight, int bit_width) { - enable_per_channel_quantization_ = enable_per_channel_quantization; + enable_per_channel_quantized_weight_ = enable_per_channel_quantized_weight; bit_width_ = bit_width; } @@ -162,9 +165,11 @@ void PrepareQuantizePass::runOnOperation() { // Finally, the quantization parameters can be propagated to the rest of the // values (tensors). ApplyQuantizationParamsPropagation( - func, /*is_signed=*/true, bit_width_, !enable_per_channel_quantization_, - GetStableHloOpQuantSpec, GetStableHloQuantScaleSpec, - /*infer_tensor_ranges=*/true, /*legacy_float_scale=*/false); + func, /*is_signed=*/true, bit_width_, + !enable_per_channel_quantized_weight_, GetStableHloOpQuantSpec, + GetStableHloQuantScaleSpec, + /*infer_tensor_ranges=*/true, /*legacy_float_scale=*/false, + /*is_qdq_conversion=*/false); // Restore constants as stablehlo::ConstantOp. RewritePatternSet patterns_2(ctx); @@ -180,9 +185,9 @@ void PrepareQuantizePass::runOnOperation() { // Creates an instance of the TensorFlow dialect PrepareQuantize pass. std::unique_ptr> CreatePrepareQuantizePass( - bool enable_per_channel_quantization, int bit_width) { - return std::make_unique(enable_per_channel_quantization, - bit_width); + bool enable_per_channel_quantized_weight, int bit_width) { + return std::make_unique( + enable_per_channel_quantized_weight, bit_width); } } // namespace stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc index 969336d65a13a4..76430bec75e4ce 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include "absl/algorithm/container.h" #include "llvm/ADT/STLExtras.h" @@ -59,6 +60,7 @@ namespace mlir::quant::stablehlo { namespace { +using ::mlir::quant::FindUserOfType; using ::mlir::quant::TryCast; using ::mlir::stablehlo::AddOp; using ::mlir::stablehlo::BroadcastInDimOp; @@ -66,6 +68,7 @@ using ::mlir::stablehlo::ConcatenateOp; using ::mlir::stablehlo::ConvolutionOp; using ::mlir::stablehlo::DotGeneralOp; using ::mlir::stablehlo::DynamicBroadcastInDimOp; +using ::mlir::stablehlo::GatherOp; using ::mlir::stablehlo::GetDimensionSizeOp; using ::mlir::stablehlo::ReshapeOp; using ::mlir::stablehlo::UniformQuantizeOp; @@ -104,7 +107,7 @@ bool IsQuantizedTensorType(const Type type) { // %6 = stablehlo.concatenate %5, %0, %1, %2, dim = 0 : // (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) // -> tensor<4xi32> -// %7 = stablehlo.dynamic_broadcast_in_dims %arg2, %6 +// %7 = stablehlo.dynamic_broadcast_in_dim %arg2, %6 // %8 = stablehlo.add %3, %7 // ``` // @@ -112,54 +115,36 @@ bool IsQuantizedTensorType(const Type type) { // ``` // %3 = stablehlo.convolution(%%arg0, %%arg1) : // (tensor, tensor<2x3x3x2xf32>) -> tensor -// %4 = stablehlo.broadcast_in_dims %arg2, %3 +// %4 = stablehlo.broadcast_in_dim %arg2, %3 // %5 = stablehlo.add %3, %4 // ``` template Operation* GetBroadcastedUserOp(Operation* op) { // Broadcast bias for known input shape. - auto broadcast_in_dims_op = - TryCast(op->getNextNode(), - /*name=*/"broadcast_in_dims_op"); - if (succeeded(broadcast_in_dims_op)) { - auto target_op = TryCast((*broadcast_in_dims_op)->getNextNode(), - /*name=*/"target_op"); - if (succeeded(target_op)) { - return *target_op; - } + auto broadcast_in_dim_op = FindUserOfType(op); + if (broadcast_in_dim_op != nullptr) { + auto target_op = FindUserOfType(broadcast_in_dim_op); + if (target_op != nullptr) return target_op; } // Broadcast bias for unknown input shape. - FailureOr get_dimension_size_op = - TryCast(op->getNextNode(), - /*name=*/"get_dimension_size_op"); - if (failed(get_dimension_size_op)) { - return nullptr; - } - auto reshape_op = TryCast((*get_dimension_size_op)->getNextNode(), - /*name=*/"reshape_op"); - if (failed(reshape_op)) { - return nullptr; - } - auto concatenate_op = TryCast((*reshape_op)->getNextNode(), - /*name=*/"concatenate_op"); - if (failed(concatenate_op)) { - return nullptr; - } + auto get_dimension_size_op = FindUserOfType(op); + if (get_dimension_size_op == nullptr) return nullptr; + + auto reshape_op = FindUserOfType(get_dimension_size_op); + if (reshape_op == nullptr) return nullptr; + + auto concatenate_op = FindUserOfType(reshape_op); + if (concatenate_op == nullptr) return nullptr; + auto dynamic_broadcast_in_dim_op = - TryCast((*concatenate_op)->getNextNode(), - /*name=*/"dynamic_broadcast_in_dim_op"); - if (failed(dynamic_broadcast_in_dim_op)) { - return nullptr; - } - auto target_op = TryCast((*dynamic_broadcast_in_dim_op)->getNextNode(), - /*name=*/"target_op"); - if (failed(target_op)) { - return nullptr; - } - return *target_op; + FindUserOfType(concatenate_op); + if (dynamic_broadcast_in_dim_op == nullptr) return nullptr; + + auto target_op = FindUserOfType(dynamic_broadcast_in_dim_op); + return target_op; } -// Checks if all inputs and outputs are quantized. +// Checks if one of the inputs and outputs are quantized. bool HasQuantizedOperandOrOutput(Operation* call_op) { SmallVector arg_types; for (const Value arg : call_op->getOperands()) { @@ -171,8 +156,8 @@ bool HasQuantizedOperandOrOutput(Operation* call_op) { output_types.push_back(output.getType()); } - return absl::c_all_of(arg_types, IsQuantizedTensorType) && - absl::c_all_of(output_types, IsQuantizedTensorType); + return absl::c_any_of(arg_types, IsQuantizedTensorType) && + absl::c_any_of(output_types, IsQuantizedTensorType); } // Gets the corresponding quantized function name from the given function name. @@ -185,7 +170,7 @@ std::string GetQuantizedFunctionName(const StringRef func_name) { // Returns true if `xla_call_module_op` is quantized. To be considered // quantized, it should meet three conditions: -// 1. At least one of the inputs or outputs should be a uniform quantized type. +// 1. At least one of the inputs and outputs should be a uniform quantized type. // 2. `xla_call_module_op` should have the `kQuantTraitAttrName` attribute. // 3. It should also have the `kEntryFuncAttrName` attribute, which points to // the function that `xla_call_module_op` represents. @@ -254,24 +239,25 @@ template void CreateAndReturnQuantizedBiasPattern( Operation* op, PatternRewriter& rewriter, func::FuncOp entry_func_op, const Type func_result_type, const Type accumulation_quantized_element_type, - GemmStyleOp gemm_style_op, double result_scale) { + GemmStyleOp gemm_style_op) { Value bias_op = op->getOperand(1); Value add_op_result = op->getResult(0); // Broadcast bias value if unmatched with output shape. auto bcast_op = TryCast(bias_op.getDefiningOp(), - /*name=*/"broadcast_in_dims_op"); + /*name=*/"broadcast_in_dim_op"); + if (failed(bcast_op)) { bcast_op = TryCast( bias_op.getDefiningOp(), - /*name=*/"dynamic_broadcast_in_dims_op"); + /*name=*/"dynamic_broadcast_in_dim_op"); } + // Update the bias type for both static and dynamic broadcasts. if (succeeded(bcast_op)) { Value bcast_op_result = (*bcast_op)->getResult(0); auto bcast_op_result_type = bcast_op_result.getType().cast(); const ArrayRef bcast_shape = bcast_op_result_type.getShape(); - const TensorType new_bcast_op_result_type = bcast_op_result_type.cloneWith( bcast_shape, accumulation_quantized_element_type); bcast_op_result.setType(new_bcast_op_result_type); @@ -315,12 +301,14 @@ template LogicalResult MatchGemmStyleOp(func::FuncOp entry_func_op) { auto op_iterator_range = entry_func_op.getOps(); if (op_iterator_range.empty()) { - LLVM_DEBUG(llvm::dbgs() << "Function does not have GemmStyle op.\n"); + LLVM_DEBUG(llvm::dbgs() << "Function does not have " + << GemmStyleOp::getOperationName() << " op.\n"); return failure(); } if (!isa( (*op_iterator_range.begin()).getResult().getType())) { - LLVM_DEBUG(llvm::dbgs() << "GemmStyle op must have ranked tensor type.\n"); + LLVM_DEBUG(llvm::dbgs() << GemmStyleOp::getOperationName() + << " op must have ranked tensor type.\n"); return failure(); } @@ -328,8 +316,8 @@ LogicalResult MatchGemmStyleOp(func::FuncOp entry_func_op) { entry_func_op.getBody().getArguments(); // Function must have input, filter, and optionally bias. if (operands.size() != 2 && operands.size() != 3) { - LLVM_DEBUG(llvm::dbgs() - << "GemmStyle op function should have 2 or 3 operands.\n"); + LLVM_DEBUG(llvm::dbgs() << GemmStyleOp::getOperationName() + << " op function should have 2 or 3 operands.\n"); return failure(); } return success(); @@ -337,42 +325,70 @@ LogicalResult MatchGemmStyleOp(func::FuncOp entry_func_op) { // Gemm Style Op: glossary/gemm. template -void RewriteGemmStyleOp(func::FuncOp entry_func_op, PatternRewriter& rewriter) { - // Update the output type of the gemm_style op. - GemmStyleOp gemm_style_op = *entry_func_op.getOps().begin(); +void RewriteGemmStyleOp(func::FuncOp entry_func_op, PatternRewriter& rewriter, + bool enable_per_channel_quantized_weight) { + const GemmStyleOp gemm_style_op = + *entry_func_op.getOps().begin(); const Type input_type = entry_func_op.getArgumentTypes()[0]; const Type filter_type = entry_func_op.getArgumentTypes()[1]; const Type func_result_type = entry_func_op.getResultTypes()[0]; - const double input_scale = - getElementTypeOrSelf(input_type).cast().getScale(); - const double filter_scale = - getElementTypeOrSelf(filter_type).cast().getScale(); - const double result_scale = input_scale * filter_scale; - - // Define the intermediate output type, which is an i32 quantized type. - // This is intermediate because the final output type of the entry_func_op - // should be an i8 quantized type. - const UniformQuantizedType accumulation_quantized_element_type = - CreateI32F32UniformQuantizedType(gemm_style_op->getLoc(), - *rewriter.getContext(), result_scale, - /*zero_point=*/0); - Value gemm_style_op_result = gemm_style_op->getResult(0); - auto gemm_style_op_result_type = + const auto gemm_style_op_result_type = gemm_style_op_result.getType().cast(); const ArrayRef gemm_style_shape = gemm_style_op_result_type.getShape(); - const TensorType new_gemm_style_op_result_type = - gemm_style_op_result_type.cloneWith(gemm_style_shape, - accumulation_quantized_element_type); + Type accumulation_quantized_element_type; + TensorType new_gemm_style_op_result_type; + + const double input_scale = + getElementTypeOrSelf(input_type).cast().getScale(); + + if (enable_per_channel_quantized_weight) { + ArrayRef filter_scales = getElementTypeOrSelf(filter_type) + .cast() + .getScales(); + std::vector result_scales; + result_scales.reserve(filter_scales.size()); + + for (double filter_scale : filter_scales) { + result_scales.push_back(input_scale * filter_scale); + } + + const ArrayRef zero_points = + getElementTypeOrSelf(filter_type) + .cast() + .getZeroPoints(); + + // [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] + accumulation_quantized_element_type = + CreateI32F32UniformQuantizedPerAxisType( + gemm_style_op->getLoc(), *rewriter.getContext(), result_scales, + zero_points, /*quantization_dimension=*/3); + + new_gemm_style_op_result_type = gemm_style_op_result_type.cloneWith( + gemm_style_shape, accumulation_quantized_element_type); + } else { + const double filter_scale = getElementTypeOrSelf(filter_type) + .cast() + .getScale(); + double result_scale = input_scale * filter_scale; + + accumulation_quantized_element_type = CreateI32F32UniformQuantizedType( + gemm_style_op->getLoc(), *rewriter.getContext(), result_scale, + /*zero_point=*/0); + + new_gemm_style_op_result_type = gemm_style_op_result_type.cloneWith( + gemm_style_shape, accumulation_quantized_element_type); + } + gemm_style_op_result.setType(new_gemm_style_op_result_type); rewriter.setInsertionPointAfter(gemm_style_op); - Operation* next_op = gemm_style_op->getNextNode(); + Operation* next_op = FindUserOfType<>(gemm_style_op); // If activation exists, omit clipping op. // Since out_scale and out_zp are computed based on clipped range, @@ -381,27 +397,72 @@ void RewriteGemmStyleOp(func::FuncOp entry_func_op, PatternRewriter& rewriter) { // bias fusion CreateAndReturnQuantizedBiasPattern( next_op, rewriter, entry_func_op, func_result_type, - accumulation_quantized_element_type, gemm_style_op, result_scale); + accumulation_quantized_element_type, gemm_style_op); } else if (auto add_op = cast_or_null( GetBroadcastedUserOp(gemm_style_op))) { - // dynamic bias fusion + // broadcasted bias fusion rewriter.setInsertionPointAfter(add_op); CreateAndReturnQuantizedBiasPattern( add_op, rewriter, entry_func_op, func_result_type, - accumulation_quantized_element_type, gemm_style_op, result_scale); + accumulation_quantized_element_type, gemm_style_op); } else { // Non fusible op - // If an op is used multiple times and is not a dynamic shape case, do not - // apply quantization of fused patterns to prevent removal of dependee ops. + // If an op is used multiple times and is not a broadcasted shape case, + // do not apply quantization of fused patterns to prevent removal of + // dependee ops. CreateAndReturnUniformQuantizeOp(rewriter, *gemm_style_op, entry_func_op, func_result_type); } } +template +// Match for tensor manipulation op. +LogicalResult MatchSingularOp(func::FuncOp entry_func_op) { + auto op_iterator_range = entry_func_op.getOps(); + if (op_iterator_range.empty()) { + LLVM_DEBUG(llvm::dbgs() << "Function does not have " + << SingularOp::getOperationName() << " op.\n"); + return failure(); + } + if (!isa( + (*op_iterator_range.begin()).getResult().getType())) { + LLVM_DEBUG(llvm::dbgs() << SingularOp::getOperationName() + << " op must have ranked tensor type.\n"); + return failure(); + } + return success(); +} + +template +void RewriteSingularOp(func::FuncOp entry_func_op, PatternRewriter& rewriter) { + SingularOp singular_op = *entry_func_op.getOps().begin(); + + const Type operand_type = entry_func_op.getArgumentTypes()[0]; + const Type func_result_type = entry_func_op.getResultTypes()[0]; + + // Get the quantized tensor manipulation op's output type and update. + Value singular_op_result = singular_op.getResult(); + auto singular_op_result_type = + singular_op_result.getType().cast(); + const ArrayRef singular_op_shape = + singular_op_result_type.getShape(); + const TensorType new_singular_op_result_type = + singular_op_result_type.cloneWith( + singular_op_shape, + getElementTypeOrSelf(operand_type).cast()); + singular_op_result.setType(new_singular_op_result_type); + + // Create requantization op and return. + rewriter.setInsertionPointAfter(singular_op); + CreateAndReturnUniformQuantizeOp(rewriter, *singular_op, entry_func_op, + func_result_type); +} + // Quantizes the entry function's body containing a `DotGeneralOp`. class QuantizeDotGeneralOpPattern : public EntryFuncBodyQuantizationPattern { public: - explicit QuantizeDotGeneralOpPattern() = default; + explicit QuantizeDotGeneralOpPattern( + bool enable_per_channel_quantized_weight) {} LogicalResult match(func::FuncOp entry_func_op) const override { return MatchGemmStyleOp(entry_func_op); @@ -409,14 +470,19 @@ class QuantizeDotGeneralOpPattern : public EntryFuncBodyQuantizationPattern { void rewrite(func::FuncOp entry_func_op, PatternRewriter& rewriter) const override { - RewriteGemmStyleOp(entry_func_op, rewriter); + RewriteGemmStyleOp( + entry_func_op, rewriter, + /*enable_per_channel_quantized_weight=*/false); } }; // Quantizes the entry function's body containing a `ConvolutionOp`. class QuantizeConvolutionOpPattern : public EntryFuncBodyQuantizationPattern { public: - explicit QuantizeConvolutionOpPattern() = default; + explicit QuantizeConvolutionOpPattern( + bool enable_per_channel_quantized_weight) + : enable_per_channel_quantized_weight_( + enable_per_channel_quantized_weight) {} LogicalResult match(func::FuncOp entry_func_op) const override { return MatchGemmStyleOp(entry_func_op); @@ -424,8 +490,32 @@ class QuantizeConvolutionOpPattern : public EntryFuncBodyQuantizationPattern { void rewrite(func::FuncOp entry_func_op, PatternRewriter& rewriter) const override { - RewriteGemmStyleOp(entry_func_op, rewriter); + RewriteGemmStyleOp(entry_func_op, rewriter, + enable_per_channel_quantized_weight_); + } + + private: + bool enable_per_channel_quantized_weight_; +}; + +// Quantizes the entry function's body containing a `GatherOp`. +class QuantizeGatherOpPattern : public EntryFuncBodyQuantizationPattern { + public: + explicit QuantizeGatherOpPattern(bool enable_per_channel_quantized_weight) + : enable_per_channel_quantized_weight_( + enable_per_channel_quantized_weight) {} + + LogicalResult match(func::FuncOp entry_func_op) const override { + return MatchSingularOp(entry_func_op); + } + + void rewrite(func::FuncOp entry_func_op, + PatternRewriter& rewriter) const override { + RewriteSingularOp(entry_func_op, rewriter); } + + private: + bool enable_per_channel_quantized_weight_; }; // Converts `entry_func_op` to be quantized according to the respective @@ -484,8 +574,11 @@ template >> class XlaCallModuleOpToCallOp : public OpRewritePattern { public: - explicit XlaCallModuleOpToCallOp(MLIRContext& ctx) - : OpRewritePattern(&ctx) {} + explicit XlaCallModuleOpToCallOp(MLIRContext& ctx, + bool enable_per_channel_quantized_weight) + : OpRewritePattern(&ctx), + enable_per_channel_quantized_weight_( + enable_per_channel_quantized_weight) {} LogicalResult match(TF::XlaCallModuleOp op) const override { ModuleOp module_op = op->getParentOfType(); @@ -499,15 +592,19 @@ class XlaCallModuleOpToCallOp : public OpRewritePattern { op->emitError("Failed to find a valid entry function."); return failure(); } - return FuncBodyRewritePatternT().match(entry_func_op); + return FuncBodyRewritePatternT(enable_per_channel_quantized_weight_) + .match(entry_func_op); } void rewrite(TF::XlaCallModuleOp xla_call_module_op, PatternRewriter& rewriter) const override { ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( *rewriter.getContext(), rewriter, xla_call_module_op, - FuncBodyRewritePatternT()); + FuncBodyRewritePatternT(enable_per_channel_quantized_weight_)); } + + private: + bool enable_per_channel_quantized_weight_; }; // Quantizes op with regions such as stablehlo.reduce_window op. @@ -701,7 +798,8 @@ bool IsQuantizedCompositeFunction(func::CallOp call_op) { if (type.getElementType().isa()) { return false; } - if (type.getElementType().isa()) { + if (type.getElementType() + .isa()) { has_quantized_types = true; } } @@ -711,7 +809,8 @@ bool IsQuantizedCompositeFunction(func::CallOp call_op) { if (type.getElementType().isa()) { return false; } - if (type.getElementType().isa()) { + if (type.getElementType() + .isa()) { has_quantized_types = true; } } @@ -780,9 +879,17 @@ bool IsConnectedWithQuantizedCompsiteFunction(Operation* same_scale_op) { // TODO: b/307620428 - Increase fused op coverage for static range quantization. void PopulateFusedGemmStylePatterns(MLIRContext& ctx, - RewritePatternSet& patterns) { - patterns.add, - XlaCallModuleOpToCallOp>(ctx); + RewritePatternSet& patterns, + bool enable_per_channel_quantized_weight) { + patterns.add>( + ctx, enable_per_channel_quantized_weight); + // By default, we set `enable_per_channel_quantized_weight` to true for + // passes to ensure per-channel quantization for all supported ops. + // For ops that do not yet support per-channel quantization, explicitly + // mark as false like below. We will soon add support for per-channel + // quantization of the following ops. + patterns.add>( + ctx, /*enable_per_channel_quantized_weight=*/false); } void PopulateQuantizeOpWithRegionPattern(MLIRContext& ctx, @@ -790,4 +897,10 @@ void PopulateQuantizeOpWithRegionPattern(MLIRContext& ctx, patterns.add(ctx); } +void PopulateQuantizeSingularOpPatterns(MLIRContext& ctx, + RewritePatternSet& patterns) { + // TODO: b/307620772 - Per-channel quantization for gather. + patterns.add>( + ctx, /*enable_per_channel_quantized_weight=*/false); +} } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h index 91170115ce2baa..9922e5bd69eb49 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h @@ -124,11 +124,20 @@ class StableHloQuantizationPattern : public RewritePattern { // Const-> QuantizeOp pattern will be handled separately. return failure(); } - if (Operation* quantizing_op = quantize_operand.getDefiningOp()) { + if (Operation* quantizing_op = quantize_operand.getDefiningOp(); + quantizing_op != nullptr) { quantizing_ops.push_back(quantizing_op); + } else { + // When `QuantizeOpT`'s operand does not have a defining op, it means it + // is a `BlockArgument`. The pattern does not match if there is no op to + // quantize. + return failure(); } } + // Safeguard check to ensure that there is at least one quantizable op. + if (quantizing_ops.empty()) return failure(); + absl::flat_hash_set ops_blocklist = quant_params_.quant_spec.ops_blocklist; absl::flat_hash_set nodes_blocklist = @@ -276,15 +285,19 @@ class StableHloQuantizationPattern : public RewritePattern { }; // Gemm Style Op: glossary/gemm. -// Populates conversion patterns to unfuse batch normalization operations. void PopulateFusedGemmStylePatterns(MLIRContext& ctx, - RewritePatternSet& patterns); + RewritePatternSet& patterns, + bool enable_per_channel_quantized_weight); // Populates pattern for quantization of ops with regions such as // stablehlo.reduce_window op. void PopulateQuantizeOpWithRegionPattern(MLIRContext& ctx, RewritePatternSet& patterns); +// Populates conversion patterns for unary data movement ops. +void PopulateQuantizeSingularOpPatterns(MLIRContext& ctx, + RewritePatternSet& patterns); + } // namespace mlir::quant::stablehlo #endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_QUANTIZATION_PATTERNS_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc index 8d321d9269345c..fd5898d686be96 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc @@ -44,7 +44,7 @@ namespace mlir::quant::stablehlo { namespace { // Base struct for quantization. -template +template struct StableHloQuantizationBase : public StableHloQuantizationPattern { explicit QuantizePass() = default; - explicit QuantizePass(const QuantizationSpecs& quant_specs) - : quant_specs_(quant_specs) {} + explicit QuantizePass(const QuantizationSpecs& quant_specs, + bool enable_per_channel_quantized_weight) + : quant_specs_(quant_specs), + enable_per_channel_quantized_weight_( + enable_per_channel_quantized_weight) {} - QuantizePass(const QuantizePass& other) : quant_specs_(other.quant_specs_) {} + QuantizePass(const QuantizePass& other) + : quant_specs_(other.quant_specs_), + enable_per_channel_quantized_weight_( + other.enable_per_channel_quantized_weight_) {} private: void runOnOperation() override; QuantizationSpecs quant_specs_; + bool enable_per_channel_quantized_weight_; }; void QuantizePass::runOnOperation() { @@ -131,14 +138,14 @@ void QuantizePass::runOnOperation() { patterns.add( &ctx, quant_params); PopulateQuantizeOpWithRegionPattern(ctx, patterns); - PopulateFusedGemmStylePatterns(ctx, patterns); + PopulateFusedGemmStylePatterns(ctx, patterns, + enable_per_channel_quantized_weight_); + PopulateQuantizeSingularOpPatterns(ctx, patterns); if (failed(applyPatternsAndFoldGreedily(module_op, std::move(patterns)))) { // There are cases where no rewrites happen even if a pattern matches, // causing this to result in a convergence failure. Consider this as a // best-effort. - // TODO: b/305469508 - Make QuantizationPattern converge if there are no - // patterns that are rewritable. module_op.emitWarning("Failed to converge pattern at QuantizePass."); } } @@ -146,8 +153,10 @@ void QuantizePass::runOnOperation() { } // namespace std::unique_ptr> CreateQuantizePass( - const QuantizationSpecs& quantization_specs) { - return std::make_unique(quantization_specs); + const QuantizationSpecs& quantization_specs, + bool enable_per_channel_quantized_weight) { + return std::make_unique(quantization_specs, + enable_per_channel_quantized_weight); } } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc index 026c0742615128..0d491a8cb66404 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "absl/status/status.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project // IWYU pragma: keep #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -25,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" // IWYU pragma: keep @@ -50,6 +53,11 @@ class QuantizeCompositeFunctionsPass using impl::QuantizeCompositeFunctionsPassBase< QuantizeCompositeFunctionsPass>::QuantizeCompositeFunctionsPassBase; + explicit QuantizeCompositeFunctionsPass( + bool enable_per_channel_quantized_weight) { + enable_per_channel_quantized_weight_ = enable_per_channel_quantized_weight; + } + private: void runOnOperation() override; }; @@ -65,11 +73,16 @@ void QuantizeCompositeFunctionsPass::runOnOperation() { // (XlaCallModuleOps) with quantized input and output types, which are not // allowed in the TF dialect. pm.enableVerifier(false); - - pm.addNestedPass(CreatePrepareQuantizePass()); + PrepareQuantizePassOptions options; + options.enable_per_channel_quantized_weight_ = + enable_per_channel_quantized_weight_; + // Change this to user-given bit width once we have custom configuration. + options.bit_width_ = 8; + pm.addNestedPass(createPrepareQuantizePass(options)); // QuantizePass modifies FuncOps referenced outside of its given scope // and therefore requires a module-level context. - pm.addPass(CreateQuantizePass(quant_specs)); + pm.addPass( + CreateQuantizePass(quant_specs, enable_per_channel_quantized_weight_)); pm.addNestedPass(createPostQuantizePass()); ModuleOp module_op = getOperation(); @@ -79,7 +92,13 @@ void QuantizeCompositeFunctionsPass::runOnOperation() { signalPassFailure(); } } - } // namespace +// Creates an instance of the TensorFlow dialect QuantizeCompositeFunctionsPass. +std::unique_ptr> CreateQuantizeCompositeFunctionsPass( + bool enable_per_channel_quantized_weight) { + return std::make_unique( + enable_per_channel_quantized_weight); +} + } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc index aa43b64b97b10b..6a152843e3278a 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc @@ -32,6 +32,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/func.h" #include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/utils/stablehlo_type_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -45,7 +46,6 @@ namespace mlir::quant::stablehlo { namespace { -constexpr StringRef kQuantizeTargetOpAttr = "tf_quant.composite_function"; constexpr StringRef kStablehloModuleAttrsAttrName = "_stablehlo_module_attrs"; constexpr StringRef kUsesShapePolymorphismAttr = "jax.uses_shape_polymorphism"; @@ -73,19 +73,6 @@ class ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass void runOnOperation() override; }; -// Finds the main function from module_op. Returns nullptr if not found. -// The model's signature keys will contain "@serving_default" as default TF -// Model signature, or "@main" if it is in being exported from MLIR module to -// GraphDef. -func::FuncOp GetMainFunc(ModuleOp module_op) { - for (auto func_op : module_op.getOps()) { - if (func_op.getSymName().equals("main") || - func_op.getSymName().equals("serving_default")) - return func_op; - } - return nullptr; -} - // Creates a unique stablehlo function name based on op order. std::string CreateStablehloFunctionName(const int id) { return Twine("_stablehlo_main_").concat(std::to_string(id)).str(); @@ -447,7 +434,7 @@ void ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass:: runOnOperation() { ModuleOp module_op = getOperation(); - func::FuncOp main_func = GetMainFunc(module_op); + func::FuncOp main_func = FindMainFuncOp(module_op); if (!main_func) return; DuplicateSmallConstantOps(module_op, main_func); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.td index ab22c3b3d47ae9..38d60e94f97e9a 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.td @@ -48,3 +48,30 @@ def TestPostCalibrationComponentPass : Pass<"stablehlo-test-post-calibration-com "mlir::quantfork::QuantizationForkDialect", ]; } + +def TestTFToStablehloPass : Pass<"stablehlo-test-tf-to-stablehlo", "mlir::ModuleOp"> { + let summary = "Test-only pass to test TFToStablehloPasses."; + let description = [{ + Runs the TFToStablehloPasses. + }]; + let dependentDialects = [ + "mlir::stablehlo::StablehloDialect", "mlir::TF::TensorFlowDialect", + "mlir::chlo::ChloDialect", "mlir::quant::QuantizationDialect", + "mlir::mhlo::MhloDialect", "mlir::shape::ShapeDialect", + "mlir::sparse_tensor::SparseTensorDialect", "mlir::vhlo::VhloDialect", + ]; +} + +def TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass : + Pass<"stablehlo-test-lift-quantizable-spots-as-functions-with-quantization-specs", "mlir::ModuleOp"> { + let summary = "Test-only pass for testing the LiftQuantizableSpotsAsFunctionsPass with a predefined QuantizationSpecs."; + let description = [{ + This test-only pass is the same as `LiftQuantizableSpotsAsFunctionsPass` but + has predefined `QuantizationSpecs` to make FileCheck testing easier. + }]; + let dependentDialects = [ + "mlir::func::FuncDialect", + "mlir::stablehlo::StablehloDialect", + "TF::TensorFlowDialect", + ]; +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_lift_quantizable_spots_as_functions_with_quantization_specs.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_lift_quantizable_spots_as_functions_with_quantization_specs.cc new file mode 100644 index 00000000000000..e8cb185cb7b55d --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_lift_quantizable_spots_as_functions_with_quantization_specs.cc @@ -0,0 +1,101 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" // IWYU pragma: keep +#include "tsl/platform/protobuf.h" // IWYU pragma: keep + +namespace mlir::quant::stablehlo::testing { + +// NOLINTNEXTLINE - Automatically generated. +#define GEN_PASS_DEF_TESTLIFTQUANTIZABLESPOTSASFUNCTIONSWITHQUANTIZATIONSPECSPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.h.inc" + +namespace { + +using ::stablehlo::quantization::QuantizationSpecs; +using ::tsl::protobuf::TextFormat; +// NOLINTNEXTLINE(misc-include-cleaner) - Required for OSS. +using ::tsl::protobuf::io::ArrayInputStream; + +// Configure `QuantizationSpecs` to disable quantization for all dot_general +// quantizable units. +constexpr absl::string_view kSpecsDisableAllDotGeneralByFuncName = + R"pb(specs + [ { + matcher { function_name { regex: "composite_dot_general_.*" } } + method { no_quantization {} } + }])pb"; + +class TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass + : public impl:: + TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPassBase< + TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass> { + public: + using impl::TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPassBase< + TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass>:: + TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPassBase; + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass) + + private: + void runOnOperation() override; +}; + +// Parses a text proto into a `QuantizationSpecs` proto. Returns +// `InvalidArgumentError` if `text_proto` is invalid. +absl::StatusOr ParseQuantizationSpecsTextProto( + const absl::string_view text_proto) { + QuantizationSpecs quantization_specs; + TextFormat::Parser parser; + ArrayInputStream input_stream(text_proto.data(), text_proto.size()); + if (parser.Parse(&input_stream, &quantization_specs)) { + return quantization_specs; + } + return absl::InvalidArgumentError("Could not parse text proto."); +} + +void TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass:: + runOnOperation() { + PassManager pass_manager{&getContext()}; + + const absl::StatusOr quantization_specs = + ParseQuantizationSpecsTextProto(kSpecsDisableAllDotGeneralByFuncName); + if (!quantization_specs.ok()) { + signalPassFailure(); + return; + } + + pass_manager.addPass( + CreateLiftQuantizableSpotsAsFunctionsPass(*quantization_specs)); + + if (failed(pass_manager.run(getOperation()))) { + signalPassFailure(); + return; + } +} + +} // namespace +} // namespace mlir::quant::stablehlo::testing diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_post_calibration_component.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_post_calibration_component.cc index de7a53e6d22431..88fa9e59b4977d 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_post_calibration_component.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_post_calibration_component.cc @@ -40,6 +40,7 @@ namespace mlir::quant::stablehlo::testing { namespace { using ::stablehlo::quantization::PipelineConfig; +using ::stablehlo::quantization::StaticRangePtqPreset; class TestPostCalibrationComponentPass : public impl::TestPostCalibrationComponentPassBase< @@ -60,11 +61,12 @@ void TestPostCalibrationComponentPass::runOnOperation() { OpPassManager pm(ModuleOp::getOperationName()); + StaticRangePtqPreset static_range_ptq_preset; PipelineConfig pipeline_config; pipeline_config.set_unpack_quantized_types(unpack_quantized_types_); PostCalibrationComponent component(&ctx); - component.AddPasses(pm, pipeline_config); + component.AddPasses(pm, static_range_ptq_preset, pipeline_config); // Adds a XlaCallModuleOp deserialization pass for easier testing by // inspecting the contents of serialized StableHLO function. diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_tf_to_stablehlo_pass.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_tf_to_stablehlo_pass.cc new file mode 100644 index 00000000000000..3af53a213b0064 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_tf_to_stablehlo_pass.cc @@ -0,0 +1,69 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "stablehlo/dialect/ChloOps.h" // from @stablehlo // IWYU pragma: keep +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "stablehlo/dialect/VhloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" // IWYU pragma: keep +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" // IWYU pragma: keep + +namespace mlir::quant::stablehlo::testing { + +#define GEN_PASS_DEF_TESTTFTOSTABLEHLOPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.h.inc" + +namespace { + +using ::tensorflow::quantization::AddTFToStablehloPasses; +using ::tensorflow::quantization::RunPassesOnModuleOp; + +class TestTFToStablehloPass + : public impl::TestTFToStablehloPassBase { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTFToStablehloPass) + + private: + void runOnOperation() override; +}; + +void TestTFToStablehloPass::runOnOperation() { + ModuleOp module_op = getOperation(); + MLIRContext* ctx = &getContext(); + mlir::PassManager pm(ctx); + + AddTFToStablehloPasses(pm); + if (!RunPassesOnModuleOp( + /*mlir_dump_file_name=*/"test_tf_to_stablehlo_pass", pm, module_op) + .ok()) { + return signalPassFailure(); + } +} + +} // namespace +} // namespace mlir::quant::stablehlo::testing diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD index 8716775299b5ca..a91ceec6e151f5 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD @@ -67,17 +67,25 @@ pytype_strict_library( tf_py_strict_test( name = "quantize_model_test", srcs = ["integration_test/quantize_model_test.py"], + shard_count = 50, # Parallelize the test to avoid timeouts. deps = [ ":quantization", ":quantize_model_test_base", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_py", "//tensorflow/compiler/mlir/quantization/tensorflow/python:representative_dataset", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:test_lib", + "//tensorflow/python/module", + "//tensorflow/python/ops:math_ops", "//tensorflow/python/ops:nn_ops", "//tensorflow/python/platform:client_testlib", "//tensorflow/python/saved_model:load", + "//tensorflow/python/saved_model:save", "//tensorflow/python/saved_model:tag_constants", + "//tensorflow/python/types:core", "@absl_py//absl/testing:parameterized", ], ) @@ -87,6 +95,7 @@ tf_python_pybind_extension( srcs = ["pywrap_quantization.cc"], pytype_srcs = ["pywrap_quantization.pyi"], deps = [ + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:config_impl", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:static_range_ptq", "//tensorflow/compiler/mlir/quantization/tensorflow/python:type_casters", "@pybind11", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py index 5acbf1c2953a8b..f359981eaf89d7 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== import itertools -from typing import Optional, Sequence +from typing import Mapping, Optional, Sequence from absl.testing import parameterized import numpy as np @@ -22,12 +22,19 @@ from tensorflow.compiler.mlir.quantization.stablehlo.python import quantization from tensorflow.compiler.mlir.quantization.stablehlo.python.integration_test import quantize_model_test_base from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as repr_dataset +from tensorflow.python.eager import def_function +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util +from tensorflow.python.module import module +from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.platform import test from tensorflow.python.saved_model import load +from tensorflow.python.saved_model import save from tensorflow.python.saved_model import tag_constants +from tensorflow.python.types import core def parameter_combinations(test_parameters): @@ -64,6 +71,7 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): ([10, 1, 1024], [10, 1024, 3]), ([2, 3, 1, 1024], [2, 3, 1024, 3]), ), + 'rng_seed': (1230, 1231, 1232, 1233), }]) ) @test_util.run_in_graph_and_eager_modes @@ -72,6 +80,7 @@ def test_matmul_ptq_model( bias_fn: Optional[ops.Operation], activation_fn: Optional[ops.Operation], dim_sizes: Sequence[int], + rng_seed: int, ): lhs_dim_size, rhs_dim_size = dim_sizes input_shape = (*lhs_dim_size,) @@ -85,7 +94,7 @@ def test_matmul_ptq_model( activation_fn, ) - rng = np.random.default_rng(seed=1235) + rng = np.random.default_rng(rng_seed) input_data = ops.convert_to_tensor( rng.uniform(low=0.0, high=1.0, size=static_input_shape).astype( np.float32 @@ -136,7 +145,7 @@ def data_gen() -> repr_dataset.RepresentativeDataset: @parameterized.parameters( parameter_combinations([{ - 'same_scale_op': [ + 'same_scale_op': ( 'concatenate', 'gather', 'max_pool', @@ -145,13 +154,15 @@ def data_gen() -> repr_dataset.RepresentativeDataset: 'select', 'slice', 'transpose', - ], + ), + 'rng_seed': (0, 11, 222, 3333), }]) ) @test_util.run_in_graph_and_eager_modes def test_matmul_and_same_scale_ptq_model( self, same_scale_op: str, + rng_seed: int, ): input_shape = (2, 3, 1, 1024) filter_shape = (2, 3, 1024, 3) @@ -164,7 +175,7 @@ def test_matmul_and_same_scale_ptq_model( same_scale_op, ) - rng = np.random.default_rng(seed=1235) + rng = np.random.default_rng(rng_seed) input_data = ops.convert_to_tensor( rng.uniform(low=0.0, high=1.0, size=static_input_shape).astype( np.float32 @@ -229,7 +240,11 @@ def data_gen() -> repr_dataset.RepresentativeDataset: False, True, ), - 'enable_per_channel_quantization': (False,), + 'enable_per_channel_quantized_weight': ( + False, + True, + ), + 'rng_seed': (10, 11, 12, 13), }]) ) @test_util.run_in_graph_and_eager_modes @@ -239,7 +254,8 @@ def test_conv_ptq_model( activation_fn: Optional[ops.Operation], has_batch_norm: bool, input_shape_dynamic: bool, - enable_per_channel_quantization: bool, + enable_per_channel_quantized_weight: bool, + rng_seed: int, dilations: Sequence[int] = None, ): input_shape = (None, 3, 4, 3) if input_shape_dynamic else (1, 3, 4, 3) @@ -257,7 +273,7 @@ def test_conv_ptq_model( ) # Generate model input data. - rng = np.random.default_rng(seed=1224) + rng = np.random.default_rng(rng_seed) static_input_shape = [dim if dim is not None else 2 for dim in input_shape] input_data = ops.convert_to_tensor( rng.uniform(low=0.0, high=1.0, size=static_input_shape).astype( @@ -285,7 +301,8 @@ def data_gen() -> repr_dataset.RepresentativeDataset: qc.RepresentativeDatasetConfig( tf_record=qc.TfRecordFile(path=dataset_path) ) - ] + ], + enable_per_channel_quantized_weight=enable_per_channel_quantized_weight, ), tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]), ) @@ -307,10 +324,19 @@ def data_gen() -> repr_dataset.RepresentativeDataset: # values are arbitrary. self.assertAllClose(new_outputs, expected_outputs, rtol=0.02, atol=0.05) - @parameterized.parameters(('abc,cde->abde',), ('abc,dce->abde',)) + @parameterized.parameters( + parameter_combinations([{ + 'equation': ( + 'abc,cde->abde', + 'abc,dce->abde', + ), + 'rng_seed': (82, 82732, 4444, 14), + }]) + ) def test_einsum_ptq_model( self, equation: str, + rng_seed: int, ): _, y_shape, bias_shape, x_signature, y_signature = ( self._prepare_sample_einsum_datashapes(equation, use_bias=True) @@ -326,7 +352,7 @@ def test_einsum_ptq_model( ) # Generate model input data. - rng = np.random.default_rng(seed=1231) + rng = np.random.default_rng(rng_seed) input_data = ops.convert_to_tensor( rng.uniform(low=0.0, high=1.0, size=x_signature).astype('f4') ) @@ -390,6 +416,239 @@ def test_when_preset_not_srq_raises_error(self): config, ) + @test_util.run_in_graph_and_eager_modes + def test_ptq_denylist_basic(self): + """Tests that the op is not quantized when no quantization is enabled.""" + input_shape = (1, 2) + model = self._create_matmul_model( + input_shape, + weight_shape=(2, 3), + saved_model_path=self._input_saved_model_path, + ) + + rng = np.random.default_rng(1230) + random_tensor_gen_fn = lambda: rng.uniform( + low=0.0, high=1.0, size=input_shape + ).astype(np.float32) + + def data_gen() -> repr_dataset.RepresentativeDataset: + for _ in range(50): + yield {'input_tensor': random_tensor_gen_fn()} + + dataset_path = self.create_tempfile('tfrecord').full_path + path_map = {'serving_default': dataset_path} + repr_dataset.TfRecordRepresentativeDatasetSaver(path_map).save( + {'serving_default': data_gen()} + ) + + config = qc.QuantizationConfig( + static_range_ptq_preset=qc.StaticRangePtqPreset( + representative_datasets=[ + qc.RepresentativeDatasetConfig( + tf_record=qc.TfRecordFile(path=dataset_path) + ) + ] + ), + tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]), + # Disable quantization for the quantizable unit (lifted function) whose + # function name starts with "composite_dot_general". + specs=qc.QuantizationSpecs( + specs=[ + qc.QuantizationSpec( + matcher=qc.MatcherSpec( + function_name=qc.FunctionNameMatcherSpec( + regex='composite_dot_general.*' + ) + ), + method=qc.Method(no_quantization={}), + ) + ] + ), + ) + quantization.quantize_saved_model( + self._input_saved_model_path, + self._output_saved_model_path, + config, + ) + + input_data = ops.convert_to_tensor(random_tensor_gen_fn()) + expected_outputs = model.matmul(input_data) + + root = load.load(self._output_saved_model_path) + self.assertCountEqual(root.signatures.keys(), {'serving_default'}) + + new_outputs = root.signatures['serving_default']( + input_tensor=ops.convert_to_tensor(input_data) + ) + + # Indirectly tests that the model is not quantized by asserting that there + # are negligible numeric difference. + self.assertAllClose(new_outputs, expected_outputs, rtol=0.000001) + + @test_util.run_in_graph_and_eager_modes + def test_ptq_selective_denylist(self): + """Tests that the op is not quantized when no quantization is enabled.""" + + rng = np.random.default_rng(1230) + random_tensor_gen_fn = lambda shape: rng.uniform( + low=-1.0, high=1.0, size=shape + ).astype(np.float32) + + class TwoMatmulModel(module.Module): + """A model with two matmul ops.""" + + @def_function.function + def matmul(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]: + """Performs a matrix multiplication. + + Args: + input_tensor: Input tensor to matmul with the filter. + + Returns: + A 'output' -> output tensor mapping + """ + out = math_ops.matmul(input_tensor, random_tensor_gen_fn((2, 3))) + out = math_ops.matmul(out, random_tensor_gen_fn((3, 4))) + return {'output': out} + + model = TwoMatmulModel() + input_shape = (1, 2) + + save.save( + model, + self._input_saved_model_path, + signatures=model.matmul.get_concrete_function( + tensor_spec.TensorSpec( + shape=input_shape, dtype=dtypes.float32, name='input_tensor' + ) + ), + ) + + def data_gen() -> repr_dataset.RepresentativeDataset: + for _ in range(50): + yield {'input_tensor': random_tensor_gen_fn(input_shape)} + + dataset_path = self.create_tempfile('tfrecord').full_path + path_map = {'serving_default': dataset_path} + repr_dataset.TfRecordRepresentativeDatasetSaver(path_map).save( + {'serving_default': data_gen()} + ) + + config = qc.QuantizationConfig( + static_range_ptq_preset=qc.StaticRangePtqPreset( + representative_datasets=[ + qc.RepresentativeDatasetConfig( + tf_record=qc.TfRecordFile(path=dataset_path) + ), + ], + ), + tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]), + # Disable quantization for the quantizable unit (lifted function) whose + # function name matches "composite_dot_general_fn_1". + # "composite_dot_general_fn_2" will be quantized. + specs=qc.QuantizationSpecs( + specs=[ + qc.QuantizationSpec( + matcher=qc.MatcherSpec( + function_name=qc.FunctionNameMatcherSpec( + regex='composite_dot_general_fn_1' + ) + ), + method=qc.Method(no_quantization={}), + ) + ] + ), + ) + + quantization.quantize_saved_model( + self._input_saved_model_path, + self._output_saved_model_path, + config, + ) + + input_data = ops.convert_to_tensor(random_tensor_gen_fn(input_shape)) + expected_outputs = model.matmul(input_data) + + root = load.load(self._output_saved_model_path) + self.assertCountEqual(root.signatures.keys(), {'serving_default'}) + + new_outputs = root.signatures['serving_default']( + input_tensor=ops.convert_to_tensor(input_data) + ) + + # Indirectly tests that the model is only partially quantized. + self.assertAllClose(new_outputs, expected_outputs, rtol=0.011) + + @test_util.run_in_graph_and_eager_modes + def test_ptq_quantization_method_not_applied_when_matcher_mismatch(self): + """Tests that quantization method is not applied to unmatched units.""" + input_shape = (1, 2) + model = self._create_matmul_model( + input_shape, + weight_shape=(2, 3), + saved_model_path=self._input_saved_model_path, + ) + + rng = np.random.default_rng(1230) + random_tensor_gen_fn = lambda: rng.uniform( + low=0.0, high=1.0, size=input_shape + ).astype(np.float32) + + def data_gen() -> repr_dataset.RepresentativeDataset: + for _ in range(50): + yield {'input_tensor': random_tensor_gen_fn()} + + dataset_path = self.create_tempfile('tfrecord').full_path + path_map = {'serving_default': dataset_path} + repr_dataset.TfRecordRepresentativeDatasetSaver(path_map).save( + {'serving_default': data_gen()} + ) + + config = qc.QuantizationConfig( + static_range_ptq_preset=qc.StaticRangePtqPreset( + representative_datasets=[ + qc.RepresentativeDatasetConfig( + tf_record=qc.TfRecordFile(path=dataset_path) + ) + ] + ), + tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]), + specs=qc.QuantizationSpecs( + specs=[ + qc.QuantizationSpec( + # Provide a regex that wouldn't match any quantizable units. + matcher=qc.MatcherSpec( + function_name=qc.FunctionNameMatcherSpec( + regex='.*invalid_function_name.*' + ), + ), + method=qc.Method(no_quantization={}), + ), + ], + ), + ) + quantization.quantize_saved_model( + self._input_saved_model_path, + self._output_saved_model_path, + config, + ) + + input_data = ops.convert_to_tensor(random_tensor_gen_fn()) + expected_outputs = model.matmul(input_data) + + root = load.load(self._output_saved_model_path) + self.assertCountEqual(root.signatures.keys(), {'serving_default'}) + + new_outputs = root.signatures['serving_default']( + input_tensor=ops.convert_to_tensor(input_data) + ) + + # Tests that the quantized graph outputs similar values. They also shouldn't + # be exactly the same. Indirectly proves that the `FunctionNameMatcherSpec` + # with regex '.*invalid_function_name.*' did not match the quantizable unit. + self.assertAllClose(new_outputs, expected_outputs, rtol=0.04) + self.assertNotAllClose(new_outputs, expected_outputs, rtol=0.00001) + if __name__ == '__main__': test.main() diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc index 7aae8224242001..bfecd82e21e56f 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc @@ -20,6 +20,7 @@ limitations under the License. #include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil // IWYU pragma: keep #include "pybind11_abseil/import_status_module.h" // from @pybind11_abseil #include "pybind11_abseil/status_casters.h" // from @pybind11_abseil // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h" // IWYU pragma: keep @@ -28,8 +29,9 @@ namespace py = pybind11; namespace { using ::mlir::quant::stablehlo::QuantizeStaticRangePtq; +using ::stablehlo::quantization::PopulateDefaults; -} +} // namespace PYBIND11_MODULE(pywrap_quantization, m) { // Supports absl::Status type conversions. @@ -62,4 +64,17 @@ PYBIND11_MODULE(pywrap_quantization, m) { py::arg("signature_keys"), py::arg("signature_def_map_serialized"), py::arg("function_aliases"), py::arg("py_function_library")); // LINT.ThenChange(pywrap_quantization.pyi:static_range_ptq) + + // If the function signature changes, likely its corresponding .pyi type + // hinting should also change. + // LINT.IfChange(populate_default_configs) + m.def("populate_default_configs", &PopulateDefaults, + R"pbdoc( + Populates `QuantizationConfig` with default values. + + Returns an updated `QuantizationConfig` (serialized) after populating + default values to fields that the user did not explicitly specify. + )pbdoc", + py::arg("user_provided_config_serialized")); + // LINT.ThenChange(pywrap_quantization.pyi:static_range_ptq) } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.pyi b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.pyi index 513c5431d6e8cb..b3d016465004e6 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.pyi +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.pyi @@ -30,3 +30,10 @@ def static_range_ptq( ) -> Any: ... # Status # LINT.ThenChange() + +# LINT.IfChange(populate_default_configs) +def populate_default_configs( + user_provided_quantization_config_serialized: bytes, +) -> bytes: ... # QuantizationConfig + +# LINT.ThenChange() diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/quantization.py b/tensorflow/compiler/mlir/quantization/stablehlo/python/quantization.py index ea33602ed8e725..5e1ce4e7d65ba8 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/quantization.py +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/quantization.py @@ -89,7 +89,9 @@ def quantize_saved_model( ' single signature.' ) - config = _populate_default_quantization_config(config) + config = qc.QuantizationConfig.FromString( + pywrap_quantization.populate_default_configs(config.SerializeToString()) + ) signature_def_map = save_model.get_signatures_from_saved_model( src_saved_model_path, diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto b/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto index 49c6a493d355a4..25623ad4497655 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto +++ b/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto @@ -42,7 +42,7 @@ message StaticRangePtqPreset { // this field once available. // If set true, enable channel-wise quantization for all supported ops. // This value is true by default. - bool enable_per_channel_quantization = 2; + bool enable_per_channel_quantized_weight = 2; } // Metadata specific to the input TensorFlow SavedModel, which may be required @@ -63,10 +63,68 @@ message PipelineConfig { optional bool unpack_quantized_types = 1; } +// A quantization method representing "do not quantize". Mostly used for +// denylisting quantizable units from quantization. +message NoQuantization {} + +// Represents a matching method that matches quantizable units by lifted +// functions' names. +message FunctionNameMatcherSpec { + // Regular expression to match lifted functions' names. Underlying regex + // engine uses re2, which accepts a subset of PCRE. See + // https://github.com/google/re2/wiki/Syntax for details. + string regex = 1; +} + +// Matcher specification for identifying quantizable units. +message MatcherSpec { + // Matches lifted functions by their names. + FunctionNameMatcherSpec function_name = 1; +} + +// Specifies how to quantize matched quantizable units. +message Method { + NoQuantization no_quantization = 1; +} + +// A QuantizationSpec is essentially a (matcher spec, quantization method) pair, +// where the matcher spec is used to identify quantizable units and the +// quantization method specifies what type of quantization to apply on the +// matched quantizable units. +// Next ID: 3 +message QuantizationSpec { + // Configures matchers for identifying quantizable units. Matched quantizable + // units will be quantized according to `method`. + MatcherSpec matcher = 1; + + // Specifies how to quantize the matched quantizable units. + Method method = 2; +} + +// Quantization specifications. A simple wrapper around a sequence of +// `QuantizationSpec`s so that specs can be easily passed around or represented +// as a textproto. +// Next ID: 2 +message QuantizationSpecs { + // List of `QuantizationSpec`s. Later spec in the sequence takes precedence. + // + // NOTE: Tie-breaking mechanism is not yet supported. Providing multiple + // `QuantizationSpec` with conflicting quantizable units may result in + // undefined behavior. + // TODO: b/307620778 - Support tie-breaking for conflicting specs. + repeated QuantizationSpec specs = 1; +} + // Quantization configuration for StableHLO Quantizer. This is the primary // message containing all configurable options. -// Next ID: 4 +// Next ID: 5 message QuantizationConfig { + // Config presets provide predefined popular or common quantization specs. + // Lightweight users may choose one of the presets for quick experiments. Each + // preset is completely represented by `QuantizationSpecs`. When extra entries + // in `QuantizationSpecs` are provided along with a preset, then the preset + // will be overridden for the quantizable units matched by those additional + // `QuantizationSpec`s. oneof preset { // Performs best-effort static-range post-training quantization (PTQ). StaticRangePtqPreset static_range_ptq_preset = 1; @@ -77,4 +135,6 @@ message QuantizationConfig { // Configures the graph transformation pipeline for quantization. PipelineConfig pipeline_config = 3; + + QuantizationSpecs specs = 4; } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/tests/BUILD index 6fc15864fb0f8b..db4bc1a92483c1 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/BUILD @@ -46,24 +46,3 @@ tf_cc_test( "@local_tsl//tsl/platform:protobuf", ], ) - -tf_cc_test( - name = "stablehlo_op_quant_spec_test", - srcs = ["stablehlo_op_quant_spec_test.cc"], - deps = [ - "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", - "//tensorflow/compiler/mlir/quantization/common:test_base", - "//tensorflow/compiler/mlir/quantization/stablehlo/ops:stablehlo_op_quant_spec", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", - "//tensorflow/core:test", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", - "@llvm-project//mlir:QuantOps", - "@stablehlo//:stablehlo_ops", - ], -) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir index 266b9735224e79..cba7d378fcc190 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir @@ -117,10 +117,10 @@ func.func @quantize_per_channel(%arg0: tensor<26x26x3x2xf32> // CHECK-DAG: %[[QMIN:.*]] = mhlo.constant dense<-2.14748365E+9> : tensor // CHECK-DAG: %[[QMAX:.*]] = mhlo.constant dense<2.14748365E+9> : tensor // CHECK: %[[DIVIDE:.*]] = chlo.broadcast_divide %arg0, %[[SCALES]] - // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} + // CHECK-SAME: {broadcast_dimensions = array} // CHECK-SAME: (tensor<26x26x3x2xf32>, tensor<2xf32>) -> tensor<26x26x3x2xf32> // CHECK: %[[ADD:.*]] = chlo.broadcast_add %[[DIVIDE]], %[[ZPS]] - // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} + // CHECK-SAME: {broadcast_dimensions = array} // CHECK-SAME: (tensor<26x26x3x2xf32>, tensor<2xf32>) -> tensor<26x26x3x2xf32> // CHECK: %[[CLAMP:.*]] = mhlo.clamp %[[QMIN]], %[[ADD]], %[[QMAX]] // CHECK: %[[ROUND:.*]] = mhlo.round_nearest_even %[[CLAMP]] @@ -141,12 +141,12 @@ func.func @dequantize_per_channel( // CHECK-DAG: %[[ZPS:.*]] = mhlo.constant dense<[-10, 2]> : tensor<2xi32> // CHECK: %[[SUBTRACT:.*]] = chlo.broadcast_subtract // CHECK-SAME: %[[INPUT:.*]], %[[ZPS]] - // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} + // CHECK-SAME: {broadcast_dimensions = array} // CHECK-SAME: (tensor<26x26x3x2xi32>, tensor<2xi32>) -> tensor<26x26x3x2xi32> // CHECK: %[[FLOAT:.*]] = mhlo.convert %[[SUBTRACT]] // CHECK: %[[RESULT:.*]] = chlo.broadcast_multiply // CHECK-SAME: %[[FLOAT]], %[[SCALES]] - // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} + // CHECK-SAME: {broadcast_dimensions = array} // CHECK-SAME: (tensor<26x26x3x2xf32>, tensor<2xf32>) -> tensor<26x26x3x2xf32> %0 = mhlo.uniform_dequantize %arg0 : ( tensor<26x26x3x2x!quant.uniform> @@ -304,6 +304,78 @@ func.func @add_different_res_type( // ----- +// CHECK-LABEL: func @add_per_channel +func.func @add_per_channel( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // CHECK: %[[ADD:.*]] = mhlo.add {{.*}} : tensor + // CHECK: %[[ZPS:.*]] = mhlo.constant dense<[3, 2]> : tensor<2xi32> + // CHECK: %[[BCAST_SUB:.*]] = chlo.broadcast_subtract %[[ADD]], %[[ZPS]] + // CHECK-SAME: {broadcast_dimensions = array} + // CHECK-SAME: (tensor, tensor<2xi32>) -> tensor + // CHECK: return %[[BCAST_SUB]] : tensor + %11 = mhlo.add %arg0, %arg1 : tensor> + return %11 : tensor> +} + +// ----- + +// CHECK-LABEL: func @add_per_channel_no_zp +func.func @add_per_channel_no_zp( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // CHECK: %[[ADD:.*]] = mhlo.add {{.*}} : tensor + // CHECK: return %[[ADD]] : tensor + %11 = mhlo.add %arg0, %arg1 : tensor> + return %11 : tensor> +} + +// ----- + +func.func @add_per_channel_i8( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // expected-error@+2 {{Per-channel quantized AddOp requires i32 storage type}} + // expected-error@+1 {{failed to legalize operation 'mhlo.add' that was explicitly marked illegal}} + %11 = mhlo.add %arg0, %arg1 : tensor> + return %11 : tensor> +} + +// ----- + +func.func @add_per_channel_different_quant_types( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // expected-error@+2 {{Per-channel quantized AddOp requires the same quantized element type for all operands and results}} + // expected-error@+1 {{failed to legalize operation 'mhlo.add' that was explicitly marked illegal}} + %11 = mhlo.add %arg0, %arg1 : ( + tensor>, + tensor> + ) -> tensor> + return %11 : tensor> +} + +// ----- + +func.func @add_per_channel_per_tensor_mix( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // expected-error@+2 {{Per-channel quantized AddOp requires the same quantized element type for all operands and results}} + // expected-error@+1 {{failed to legalize operation 'mhlo.add' that was explicitly marked illegal}} + %11 = mhlo.add %arg0, %arg1 : ( + tensor>, + tensor> + ) -> tensor> + return %11 : tensor> +} + +// ----- + // CHECK-LABEL: func @requantize func.func @requantize( %arg0: tensor> @@ -351,10 +423,10 @@ func.func @requantize_per_channel( // CHECK-DAG: %[[VAL1:.*]] = mhlo.convert %arg0 : (tensor<2x2xi8>) -> tensor<2x2xf32> // CHECK-DAG: %[[MERGED_SCALE:.*]] = mhlo.constant dense<[2.000000e+00, 5.000000e-01]> : tensor<2xf32> // CHECK: %[[VAL2:.*]] = chlo.broadcast_multiply %[[VAL1]], %[[MERGED_SCALE]] - // CHECK-SAME: broadcast_dimensions = dense<1> : tensor<1xi64> + // CHECK-SAME: broadcast_dimensions = array // CHECK-DAG: %[[MERGED_ZP:.*]] = mhlo.constant dense<[-5.000000e+00, -2.000000e+00]> : tensor<2xf32> // CHECK: %[[VAL3:.*]] = chlo.broadcast_add %[[VAL2]], %[[MERGED_ZP]] - // CHECK-SAME: broadcast_dimensions = dense<1> : tensor<1xi64> + // CHECK-SAME: broadcast_dimensions = array // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-1.280000e+02> : tensor // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<1.270000e+02> : tensor // CHECK: %[[VAL4:.*]] = mhlo.clamp %[[QUANT_MIN]], %[[VAL3]], %[[QUANT_MAX]] @@ -375,10 +447,10 @@ func.func @requantize_per_channel_to_per_tensor( // CHECK-DAG: %[[VAL1:.*]] = mhlo.convert %arg0 : (tensor<2x2xi8>) -> tensor<2x2xf32> // CHECK-DAG: %[[MERGED_SCALE:.*]] = mhlo.constant dense<[2.000000e+00, 1.000000e+00]> : tensor<2xf32> // CHECK: %[[VAL2:.*]] = chlo.broadcast_multiply %[[VAL1]], %[[MERGED_SCALE]] - // CHECK-SAME: broadcast_dimensions = dense<1> : tensor<1xi64> + // CHECK-SAME: broadcast_dimensions = array // CHECK-DAG: %[[MERGED_ZP:.*]] = mhlo.constant dense<[-5.000000e+00, -1.000000e+00]> : tensor<2xf32> // CHECK: %[[VAL3:.*]] = chlo.broadcast_add %[[VAL2]], %[[MERGED_ZP]] - // CHECK-SAME: broadcast_dimensions = dense<1> : tensor<1xi64> + // CHECK-SAME: broadcast_dimensions = array // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-1.280000e+02> : tensor // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<1.270000e+02> : tensor // CHECK: %[[VAL4:.*]] = mhlo.clamp %[[QUANT_MIN]], %[[VAL3]], %[[QUANT_MAX]] @@ -399,10 +471,10 @@ func.func @requantize_per_tensor_to_per_channel( // CHECK-DAG: %[[VAL1:.*]] = mhlo.convert %arg0 : (tensor<2x2xi8>) -> tensor<2x2xf32> // CHECK-DAG: %[[MERGED_SCALE:.*]] = mhlo.constant dense<[1.000000e+00, 5.000000e-01]> : tensor<2xf32> // CHECK: %[[VAL2:.*]] = chlo.broadcast_multiply %[[VAL1]], %[[MERGED_SCALE]] - // CHECK-SAME: broadcast_dimensions = dense<1> : tensor<1xi64> + // CHECK-SAME: broadcast_dimensions = array // CHECK-DAG: %[[MERGED_ZP:.*]] = mhlo.constant dense<[-1.000000e+00, -2.000000e+00]> : tensor<2xf32> // CHECK: %[[VAL3:.*]] = chlo.broadcast_add %[[VAL2]], %[[MERGED_ZP]] - // CHECK-SAME: broadcast_dimensions = dense<1> : tensor<1xi64> + // CHECK-SAME: broadcast_dimensions = array // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-1.280000e+02> : tensor // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<1.270000e+02> : tensor // CHECK: %[[VAL4:.*]] = mhlo.clamp %[[QUANT_MIN]], %[[VAL3]], %[[QUANT_MAX]] diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert_tf_quant_ops_to_mhlo.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert_tf_quant_ops_to_mhlo.mlir index 7a568425415170..7ca87df8587758 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert_tf_quant_ops_to_mhlo.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert_tf_quant_ops_to_mhlo.mlir @@ -1,7 +1,7 @@ // RUN: stablehlo-quant-opt %s -quant-convert-tf-quant-ops-to-mhlo | FileCheck %s // CHECK-LABEL: func @quantized_matmul_fn -func.func @quantized_matmul_fn(%input: tensor<*xf32>) -> tensor<*xf32> { +func.func @quantized_matmul_fn(%input: tensor) -> tensor { %weight = "tf.Const"() { value = #tf_type : tensor<2x2x!tf_type.qint8> } : () -> tensor<2x2x!tf_type.qint8> %weight_scales = "tf.Const"() { value = dense<1.0> : tensor } : () -> tensor %weight_zps = "tf.Const"() { value = dense<3> : tensor } : () -> tensor @@ -9,12 +9,12 @@ func.func @quantized_matmul_fn(%input: tensor<*xf32>) -> tensor<*xf32> { // CHECK: "tf.AddV2" // CHECK: mhlo.constant // CHECK-SAME{LITERAL}: dense<[[1, 2], [3, 4]]> : tensor<2x2xi8> - %0 = "tf.AddV2"(%input, %input) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %0 = "tf.AddV2"(%input, %input) : (tensor, tensor) -> tensor // CHECK: "mhlo.dot" - // CHECK-SAME: (tensor<*xf32>, tensor<2x2x!quant.uniform>) -> tensor<*xf32> - %1 = "tf.UniformQuantizedDotHybrid"(%0, %weight, %weight_scales, %weight_zps) {rhs_quantization_axis = -1 : i64, rhs_quantization_min_val = -128 : i64, rhs_quantization_max_val = 127 : i64} : (tensor<*xf32>, tensor<2x2x!tf_type.qint8>, tensor, tensor) -> tensor<*xf32> - func.return %1 : tensor<*xf32> + // CHECK-SAME: (tensor, tensor<2x2x!quant.uniform>) -> tensor + %1 = "tf.UniformQuantizedDotHybrid"(%0, %weight, %weight_scales, %weight_zps) {rhs_quantization_axis = -1 : i64, rhs_quantization_min_val = -128 : i64, rhs_quantization_max_val = 127 : i64} : (tensor, tensor<2x2x!tf_type.qint8>, tensor, tensor) -> tensor + func.return %1 : tensor } // ----- @@ -40,7 +40,7 @@ func.func @uniform_quantized_add(%input: tensor<3x2xf32>) -> tensor<3x2xf32> { quantization_axis = -1 : i64, quantization_min_val = -2147483648 : i64, quantization_max_val = 2147483647 : i64 } : (tensor<3x2xf32>, tensor, tensor) -> tensor<3x2x!tf_type.qint32> - // CHECK: chlo.broadcast_add %[[LHS2]], %[[RHS]] {broadcast_dimensions = dense<1> : tensor<1xi64>} : + // CHECK: chlo.broadcast_add %[[LHS2]], %[[RHS]] {broadcast_dimensions = array} : // CHECK-SAME: (tensor<3x2x!quant.uniform>, tensor<2x!quant.uniform>) // CHECK-SAME: -> tensor<3x2x!quant.uniform> %1 = "tf.UniformQuantizedAdd"( @@ -85,7 +85,7 @@ func.func @uniform_quantized_add_bias_not_const(%input1: tensor<3x2xi32>, %input %input1_qint = "tf.Cast"(%input1) {Truncate = false} : (tensor<3x2xi32>) -> tensor<3x2x!tf_type.qint32> %input2_qint = "tf.Cast"(%input2) {Truncate = false} : (tensor<2xi32>) -> tensor<2x!tf_type.qint32> - // CHECK: %[[RES:.*]] = chlo.broadcast_add %[[LHS_2]], %[[RHS_2]] {broadcast_dimensions = dense<1> : tensor<1xi64>} : + // CHECK: %[[RES:.*]] = chlo.broadcast_add %[[LHS_2]], %[[RHS_2]] {broadcast_dimensions = array} : // CHECK-SAME: (tensor<3x2x!quant.uniform>, tensor<2x!quant.uniform>) // CHECK-SAME: -> tensor<3x2x!quant.uniform> %result = "tf.UniformQuantizedAdd"( diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/post_calibration_component.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/post_calibration_component.mlir similarity index 100% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/post_calibration_component.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/components/post_calibration_component.mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/pre_calibration_component.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/pre_calibration_component.mlir similarity index 100% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/pre_calibration_component.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/components/pre_calibration_component.mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/tf_to_stablehlo.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/tf_to_stablehlo.mlir new file mode 100644 index 00000000000000..09afb528f602aa --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/tf_to_stablehlo.mlir @@ -0,0 +1,59 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -verify-diagnostics -stablehlo-test-tf-to-stablehlo | FileCheck %s + +func.func @fused_batchnorm_no_training(%arg0: tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>) { + %cst_0 = "tf.Const"() {value = dense<[0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2]> : tensor<8xf32>} : () -> tensor<8xf32> + %cst_1 = "tf.Const"() {value = dense<[0.3, 0.4, 0.3, 0.4, 0.3, 0.4, 0.3, 0.4]> : tensor<8xf32>} : () -> tensor<8xf32> + %0:6 = "tf.FusedBatchNormV3"(%arg0, %cst_0, %cst_1, %cst_0, %cst_1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> +} +// CHECK: func.func @main(%[[ARG_0:.+]]: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<{{.*}}> : tensor<8xf32> +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<{{.*}}> : tensor<8xf32> +// CHECK: %[[BROADCAST_0:.*]] = stablehlo.broadcast_in_dim %[[CONST_0]], dims = [3] : (tensor<8xf32>) -> tensor<8x8x8x8xf32> +// CHECK: %[[BROADCAST_1:.*]] = stablehlo.broadcast_in_dim %[[CONST_1]], dims = [3] : (tensor<8xf32>) -> tensor<8x8x8x8xf32> +// CHECK: %[[MUL:.*]] = stablehlo.multiply %arg0, %[[BROADCAST_0]] : tensor<8x8x8x8xf32> +// CHECK: %[[ADD:.*]] = stablehlo.add %[[MUL]], %[[BROADCAST_1]] : tensor<8x8x8x8xf32> +// CHECK: return %[[ADD]] : tensor<8x8x8x8xf32> + +// ----- + +func.func @fuse_conv_batchnorm(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst_0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_1 = "tf.Const"() {value = dense<[0.1, 0.2]> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_2 = "tf.Const"() {value = dense<[0.3, 0.4]> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.Conv2D"(%arg0, %cst_0) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %1:6 = "tf.FusedBatchNormV3"(%0, %cst_1, %cst_2, %cst_1, %cst_2) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<1x3x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<1x3x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) + func.return %1#0 : tensor<1x3x2x2xf32> +} +// CHECK: func.func @main(%[[ARG:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x2x2xf32> { +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<[{{.*}}]> : tensor<2xf32> +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<[{{.*}}]> : tensor<2xf32> +// CHECK: %[[BROADCAST_0:.*]] = stablehlo.broadcast_in_dim %[[CONST_0]], dims = [3] : (tensor<2xf32>) -> tensor<1x3x2x2xf32> +// CHECK: %[[BROADCAST_1:.*]] = stablehlo.broadcast_in_dim %[[CONST_1]], dims = [3] : (tensor<2xf32>) -> tensor<2x3x3x2xf32> +// CHECK: %[[CONV:.*]] = stablehlo.convolution(%[[ARG]], %[[BROADCAST_1]]) {{.*}} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> +// CHECK: %[[ADD:.*]] = stablehlo.add %[[CONV]], %[[BROADCAST_0]] : tensor<1x3x2x2xf32> +// CHECK: return %[[ADD]] : tensor<1x3x2x2xf32> + +// ----- + +func.func @func_conv_batchnorm_relu6(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst_0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_1 = "tf.Const"() {value = dense<[0.1, 0.2]> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_2 = "tf.Const"() {value = dense<[0.3, 0.4]> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.Conv2D"(%arg0, %cst_0) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %1:6 = "tf.FusedBatchNormV3"(%0, %cst_1, %cst_2, %cst_1, %cst_2) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<1x3x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<1x3x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) + %2 = "tf.Relu6"(%1#0) : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + func.return %2 : tensor<1x3x2x2xf32> +} +// CHECK: func.func @main(%[[ARG:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x2x2xf32> { +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<[{{.*}}]> : tensor<2xf32> +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<[{{.*}}]> : tensor<2xf32> +// CHECK-DAG: %[[CONST_2:.*]] = stablehlo.constant dense<6.000000e+00> : tensor +// CHECK-DAG: %[[CONST_3:.*]] = stablehlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[BROADCAST_0:.*]] = stablehlo.broadcast_in_dim %[[CONST_0]], dims = [3] : (tensor<2xf32>) -> tensor<1x3x2x2xf32> +// CHECK: %[[BROADCAST_1:.*]] = stablehlo.broadcast_in_dim %[[CONST_1]], dims = [3] : (tensor<2xf32>) -> tensor<2x3x3x2xf32> +// CHECK: %[[CONV:.*]] = stablehlo.convolution(%[[ARG]], %[[BROADCAST_1]]) {{.*}} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> +// CHECK: %[[ADD:.*]] = stablehlo.add %[[CONV]], %[[BROADCAST_0]] : tensor<1x3x2x2xf32> +// CHECK: %[[RELU6:.*]] = stablehlo.clamp %[[CONST_3]], %[[ADD]], %[[CONST_2]] : (tensor, tensor<1x3x2x2xf32>, tensor) -> tensor<1x3x2x2xf32> +// CHECK: return %[[RELU6]] : tensor<1x3x2x2xf32> + diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/convert_func_to_bfloat16.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/convert_func_to_bfloat16.mlir deleted file mode 100644 index 317f51ea3fce68..00000000000000 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/convert_func_to_bfloat16.mlir +++ /dev/null @@ -1,55 +0,0 @@ -// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-convert-func-to-bfloat16 -verify-diagnostics | FileCheck %s - -// CHECK-LABEL: @add_f32(%arg0: tensor<3x3xbf16>, %arg1: tensor<3x3xbf16>) -> tensor<3x3xbf16> -func.func @add_f32(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x3xf32> { - // CHECK-NOT: f32 - // CHECK: stablehlo.add - %0 = stablehlo.add %arg0, %arg1: (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> - return %0 : tensor<3x3xf32> -} - -// ----- - -// CHECK-LABEL: @add_f64(%arg0: tensor<3x3xbf16>, %arg1: tensor<3x3xbf16>) -> tensor<3x3xbf16> -func.func @add_f64(%arg0: tensor<3x3xf64>, %arg1: tensor<3x3xf64>) -> tensor<3x3xf64> { - // CHECK-NOT: f64 - // CHECK: stablehlo.add - %0 = stablehlo.add %arg0, %arg1: (tensor<3x3xf64>, tensor<3x3xf64>) -> tensor<3x3xf64> - return %0 : tensor<3x3xf64> -} - -// ----- - -// CHECK-LABEL: @constant_f32() -> tensor<2x2xbf16> -func.func @constant_f32() -> tensor<2x2xf32> { - // CHECK-NOT: f32 - // CHECK{LITERAL}: stablehlo.constant dense<[[1.398440e+00, 0.000000e+00], [3.093750e+00, -2.001950e-01]]> : tensor<2x2xbf16> - %0 = stablehlo.constant dense<[[1.4, 0.0], [3.1, -0.2]]> : tensor<2x2xf32> - return %0 : tensor<2x2xf32> -} - -// ----- - -func.func @constant_elided() -> tensor<2x2xf32> { - // expected-error @+1 {{failed to legalize operation 'stablehlo.constant' that was explicitly marked illegal}} - %0 = stablehlo.constant dense_resource<__elided__> : tensor<2x2xf32> - return %0 : tensor<2x2xf32> -} - -// ----- - -// CHECK-LABEL: @reduce_window_f32(%arg0: tensor<2x3x1x3xbf16>) -> tensor<2x3x1x3xbf16> -func.func @reduce_window_f32(%arg0: tensor<2x3x1x3xf32>) -> tensor<2x3x1x3xf32> { - // CHECK-NOT: f32 - // CHECK: stablehlo.reduce_window - %0 = stablehlo.constant dense<0.0> : tensor - %1 = "stablehlo.reduce_window"(%arg0, %0) ({ - ^bb0(%arg1: tensor, %arg2: tensor): - %2 = stablehlo.maximum %arg1, %arg2 : tensor - stablehlo.return %2 : tensor - }) {padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>, window_dimensions = array} : (tensor<2x3x1x3xf32>, tensor) -> tensor<2x3x1x3xf32> - return %1 : tensor<2x3x1x3xf32> -} - -// ----- - diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/convert_func_to_bfloat16.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/convert_func_to_bfloat16.mlir new file mode 100644 index 00000000000000..fdb5860eb1bd23 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/convert_func_to_bfloat16.mlir @@ -0,0 +1,128 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-convert-func-to-bfloat16 -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: @add_f32(%arg0: tensor<3x3xbf16>, %arg1: tensor<3x3xbf16>) -> tensor<3x3xbf16> +func.func @add_f32(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x3xf32> { + // CHECK-NOT: f32 + // CHECK: stablehlo.add + %0 = stablehlo.add %arg0, %arg1: (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + return %0 : tensor<3x3xf32> +} + +// ----- + +// CHECK-LABEL: @add_f64(%arg0: tensor<3x3xbf16>, %arg1: tensor<3x3xbf16>) -> tensor<3x3xbf16> +func.func @add_f64(%arg0: tensor<3x3xf64>, %arg1: tensor<3x3xf64>) -> tensor<3x3xf64> { + // CHECK-NOT: f64 + // CHECK: stablehlo.add + %0 = stablehlo.add %arg0, %arg1: (tensor<3x3xf64>, tensor<3x3xf64>) -> tensor<3x3xf64> + return %0 : tensor<3x3xf64> +} + +// ----- + +// CHECK-LABEL: @constant_f32() -> tensor<2x2xbf16> +func.func @constant_f32() -> tensor<2x2xf32> { + // CHECK-NOT: f32 + // CHECK{LITERAL}: stablehlo.constant dense<[[1.398440e+00, 0.000000e+00], [3.093750e+00, -2.001950e-01]]> : tensor<2x2xbf16> + %0 = stablehlo.constant dense<[[1.4, 0.0], [3.1, -0.2]]> : tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + +func.func @constant_elided() -> tensor<2x2xf32> { + // expected-error @+1 {{failed to legalize operation 'stablehlo.constant' that was explicitly marked illegal}} + %0 = stablehlo.constant dense_resource<__elided__> : tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: @reduce_window_f32(%arg0: tensor<2x3x1x3xbf16>) -> tensor<2x3x1x3xbf16> +func.func @reduce_window_f32(%arg0: tensor<2x3x1x3xf32>) -> tensor<2x3x1x3xf32> { + // CHECK-NOT: f32 + // CHECK: stablehlo.reduce_window + %0 = stablehlo.constant dense<0.0> : tensor + %1 = "stablehlo.reduce_window"(%arg0, %0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %2 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %2 : tensor + }) {padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>, window_dimensions = array} : (tensor<2x3x1x3xf32>, tensor) -> tensor<2x3x1x3xf32> + return %1 : tensor<2x3x1x3xf32> +} + +// ----- + +// CHECK-LABEL: @bitcast_convert_i32_f32(%arg0: tensor<1x256128xi32>) -> tensor<1x256128xbf16> +func.func @bitcast_convert_i32_f32(%arg0: tensor<1x256128xi32>) -> tensor<1x256128xf32> { + // CHECK: %[[BITCAST:.*]] = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xi32>) -> tensor<1x256128xf32> + // CHECK: %[[CONVERT:.*]] = stablehlo.convert %[[BITCAST]] : (tensor<1x256128xf32>) -> tensor<1x256128xbf16> + // CHECK: return %[[CONVERT]] : tensor<1x256128xbf16> + %20 = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xi32>) -> tensor<1x256128xf32> + return %20 : tensor<1x256128xf32> +} + +// ----- + +// CHECK-LABEL: @bitcast_convert_f32_i32(%arg0: tensor<1x256128xbf16>) -> tensor<1x256128xi32> +func.func @bitcast_convert_f32_i32(%arg0: tensor<1x256128xf32>) -> tensor<1x256128xi32> { + // CHECK: %[[CONVERT:.*]] = stablehlo.convert %arg0 : (tensor<1x256128xbf16>) -> tensor<1x256128xf32> + // CHECK: %[[BITCAST:.*]] = stablehlo.bitcast_convert %[[CONVERT]] : (tensor<1x256128xf32>) -> tensor<1x256128xi32> + // CHECK: return %[[BITCAST]] : tensor<1x256128xi32> + %20 = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xf32>) -> tensor<1x256128xi32> + return %20 : tensor<1x256128xi32> +} + +// ----- + +// CHECK-LABEL: @bitcast_convert_ui32_f32(%arg0: tensor<1x256128xui32>) -> tensor<1x256128xbf16> +func.func @bitcast_convert_ui32_f32(%arg0: tensor<1x256128xui32>) -> tensor<1x256128xf32> { + // CHECK: %[[BITCAST:.*]] = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xui32>) -> tensor<1x256128xf32> + // CHECK: %[[CONVERT:.*]] = stablehlo.convert %[[BITCAST]] : (tensor<1x256128xf32>) -> tensor<1x256128xbf16> + // CHECK: return %[[CONVERT]] : tensor<1x256128xbf16> + %20 = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xui32>) -> tensor<1x256128xf32> + return %20 : tensor<1x256128xf32> +} + +// ----- + +// CHECK-LABEL: @bitcast_convert_f32_ui32(%arg0: tensor<1x256128xbf16>) -> tensor<1x256128xui32> +func.func @bitcast_convert_f32_ui32(%arg0: tensor<1x256128xf32>) -> tensor<1x256128xui32> { + // CHECK: %[[CONVERT:.*]] = stablehlo.convert %arg0 : (tensor<1x256128xbf16>) -> tensor<1x256128xf32> + // CHECK: %[[BITCAST:.*]] = stablehlo.bitcast_convert %[[CONVERT]] : (tensor<1x256128xf32>) -> tensor<1x256128xui32> + // CHECK: return %[[BITCAST]] : tensor<1x256128xui32> + %20 = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xf32>) -> tensor<1x256128xui32> + return %20 : tensor<1x256128xui32> +} + +// ----- + +// CHECK-LABEL: @bitcast_convert_f32_f32(%arg0: tensor<1x256128xbf16>) -> tensor<1x256128xbf16> +func.func @bitcast_convert_f32_f32(%arg0: tensor<1x256128xf32>) -> tensor<1x256128xf32> { + // Convert bitcast_convert to no-op for f32->f32. + // CHECK: return %arg0 : tensor<1x256128xbf16> + %20 = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xf32>) -> tensor<1x256128xf32> + return %20 : tensor<1x256128xf32> +} + +// ----- + +// CHECK-LABEL: @bitcast_convert_i32_ui32(%arg0: tensor<1x256128xi32>) -> tensor<1x256128xui32> +func.func @bitcast_convert_i32_ui32(%arg0: tensor<1x256128xi32>) -> tensor<1x256128xui32> { + // Do not convert bitcast_convert for legal types. + // CHECK: %[[BITCAST:.*]] = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xi32>) -> tensor<1x256128xui32> + // CHECK: return %[[BITCAST]] : tensor<1x256128xui32> + %20 = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xi32>) -> tensor<1x256128xui32> + return %20 : tensor<1x256128xui32> +} + +// ----- + +// CHECK-LABEL: @bitcast_convert_bf16_bf16(%arg0: tensor<1x256128xbf16>) -> tensor<1x256128xbf16> +func.func @bitcast_convert_bf16_bf16(%arg0: tensor<1x256128xbf16>) -> tensor<1x256128xbf16> { + // Do not convert bitcast_convert for legal types. + // CHECK: %[[BITCAST:.*]] = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xbf16>) -> tensor<1x256128xbf16> + // CHECK: return %[[BITCAST]] : tensor<1x256128xbf16> + %20 = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xbf16>) -> tensor<1x256128xbf16> + return %20 : tensor<1x256128xbf16> +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/convert_xla_call_module_op_to_bfloat16.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/convert_xla_call_module_op_to_bfloat16.mlir similarity index 100% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/convert_xla_call_module_op_to_bfloat16.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/convert_xla_call_module_op_to_bfloat16.mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/lift_quantizable_spots_as_functions.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/lift_quantizable_spots_as_functions.mlir similarity index 85% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/lift_quantizable_spots_as_functions.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/lift_quantizable_spots_as_functions.mlir index e743a19dc0a822..fa722c2fc71c88 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/lift_quantizable_spots_as_functions.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/lift_quantizable_spots_as_functions.mlir @@ -138,6 +138,25 @@ func.func @conv_with_bias_dynamic_fn(%arg0: tensor) -> tensor +func.func @conv_with_bias_dynamic_shape_not_same_op_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x1x16xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<16xf32> + %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %3 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %4 = shape.shape_of %3 : tensor -> tensor<4xindex> + %5 = stablehlo.dynamic_broadcast_in_dim %1, %4, dims = [3] : (tensor<16xf32>, tensor<4xindex>) -> tensor + %6 = stablehlo.add %2, %5 : tensor + func.return %6: tensor +} +// CHECK-NOT: @composite_conv_with_bias_dynamic_fn_1 + +// ----- + // CHECK-LABEL: @dot_general_with_bias_dynamic_fn( // CHECK-SAME: %[[ARG_0:.*]]: tensor func.func @dot_general_with_bias_dynamic_fn(%arg0: tensor) -> tensor { @@ -238,6 +257,25 @@ func.func @conv_with_relu_dynamic_fn(%arg0: tensor) -> tensor +func.func @conv_with_relu_dynamic_shape_not_same_op_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x1x16xf32> + %1 = stablehlo.constant dense<0.000000e+00> : tensor + %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %3 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %4 = shape.shape_of %3 : tensor -> tensor<4xindex> + %5 = stablehlo.dynamic_broadcast_in_dim %1, %4, dims = [] : (tensor, tensor<4xindex>) -> tensor + %6 = stablehlo.maximum %2, %5 : tensor + func.return %6: tensor +} +// CHECK-NOT: private @composite_conv_with_relu_dynamic_fn_1 + +// ----- + // CHECK-LABEL: @dot_general_with_relu_dynamic_fn( // CHECK-SAME: %[[ARG_0:.*]]: tensor func.func @dot_general_with_relu_dynamic_fn(%arg0: tensor) -> tensor { @@ -508,6 +546,29 @@ func.func @conv_with_bias_and_relu_dynamic_fn(%arg0: tensor) -> t // ----- +// Because the operand of shape_of is other than the target conv, +// should not match conv bias relu dynamic pattern. + +// CHECK-LABEL: @conv_with_bias_and_relu_dynamic_shape_not_same_op_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor +func.func @conv_with_bias_and_relu_dynamic_shape_not_same_op_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x1x16xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<16xf32> + %2 = stablehlo.constant dense<0.000000e+00> : tensor + %3 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %4 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %5 = shape.shape_of %4 : tensor -> tensor<4xindex> + %6 = stablehlo.dynamic_broadcast_in_dim %1, %5, dims = [3] : (tensor<16xf32>, tensor<4xindex>) -> tensor + %7 = stablehlo.add %3, %6 : tensor + %8 = shape.shape_of %7 : tensor -> tensor<4xindex> + %9 = stablehlo.dynamic_broadcast_in_dim %2, %8, dims = [] : (tensor, tensor<4xindex>) -> tensor + %10 = stablehlo.maximum %7, %9 : tensor + func.return %10: tensor +} +// CHECK-NOT: private @composite_conv_with_bias_and_relu_dynamic_fn_1 + +// ----- + // CHECK-LABEL: @dot_general_with_bias_and_relu_dynamic_fn( // CHECK-SAME: %[[ARG_0:.*]]: tensor func.func @dot_general_with_bias_and_relu_dynamic_fn(%arg0: tensor) -> tensor { @@ -667,6 +728,28 @@ func.func @conv_with_bias_and_relu6_dynamic_fn(%arg0: tensor) -> // ----- +// Because the operand of shape_of is other than the target conv, +// should not match conv bias relu6 dynamic pattern. + +// CHECK-LABEL: @conv_with_bias_and_relu6_dynamic_shape_not_same_op_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor +func.func @conv_with_bias_and_relu6_dynamic_shape_not_same_op_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x1x16xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<16xf32> + %2 = stablehlo.constant dense<0.000000e+00> : tensor + %3 = stablehlo.constant dense<6.000000e+00> : tensor + %4 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %5 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %6 = shape.shape_of %5 : tensor -> tensor<4xindex> + %7 = stablehlo.dynamic_broadcast_in_dim %1, %6, dims = [3] : (tensor<16xf32>, tensor<4xindex>) -> tensor + %8 = stablehlo.add %4, %7 : tensor + %9 = stablehlo.clamp %2, %8, %3 : (tensor, tensor, tensor) -> tensor + func.return %9: tensor +} +// CHECK-NOT: private @composite_conv_with_bias_and_relu6_dynamic_fn_1 + +// ----- + // CHECK-LABEL: @dot_general_with_bias_and_relu6_dynamic_fn( // CHECK-SAME: %[[ARG_0:.*]]: tensor func.func @dot_general_with_bias_and_relu6_dynamic_fn(%arg0: tensor) -> tensor { diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/lift_quantizable_spots_as_functions_with_quantization_specs.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/lift_quantizable_spots_as_functions_with_quantization_specs.mlir new file mode 100644 index 00000000000000..00b3dd3b5e57a4 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/lift_quantizable_spots_as_functions_with_quantization_specs.mlir @@ -0,0 +1,25 @@ +// RUN: stablehlo-quant-opt %s -stablehlo-test-lift-quantizable-spots-as-functions-with-quantization-specs \ +// RUN: -split-input-file | FileCheck %s + +// CHECK: @main +func.func @main(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32> + %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32> + return %1 : tensor<1x1x64xf32> +} +// Tests that `composite_dot_general_fn_1` and its corresponding XlaCallModuleOp +// is missing attributes required for quantization. + +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) +// CHECK-SAME: {_entry_function = @composite_dot_general_fn_1, {{.*}}} +// CHECK-NOT: _original_entry_function +// CHECK-NOT: _tfl_quant_trait +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_fn_1 +// CHECK-NOT: tf_quant.composite_function +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK: return %[[DOT_GENERAL:.*]] : tensor<1x1x64xf32> +// CHECK: } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/optimize_graph.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/optimize_graph.mlir similarity index 100% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/optimize_graph.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/optimize_graph.mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/post_quantize.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/post_quantize.mlir similarity index 100% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/post_quantize.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/post_quantize.mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/prepare_quantize.mlir similarity index 84% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/prepare_quantize.mlir index 6688ba63a68146..aec15c3d5a5ded 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/prepare_quantize.mlir @@ -1,4 +1,4 @@ -// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-prepare-quantize=enable-per-channel-quantization=false -verify-diagnostics | FileCheck %s +// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-prepare-quantize=enable-per-channel-quantized-weight=false -verify-diagnostics | FileCheck %s // ----- @@ -74,36 +74,36 @@ func.func @dot_redundant_stats(%arg0: tensor) -> tensor { // ----- -// CHECK-LABEL: func @convert_same_scale_propagate -func.func @convert_same_scale_propagate(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { +// CHECK-LABEL: func @reshape_same_scale_propagate +func.func @reshape_same_scale_propagate(%arg0: tensor<2x3xf32>) -> tensor<6xf32> { // CHECK: %[[dq:.*]] = "quantfork.dcast" // CHECK-SAME: (tensor<2x3x!quant.uniform>) %0 = "quantfork.stats"(%arg0) {bitsNum = 8 : i64, layerStats = dense<[-0.999415695, 0.99998933]> : tensor<2xf32>, narrowRange = false} : (tensor<2x3xf32>) -> tensor<2x3xf32> - // CHECK: %[[convert:.*]] = stablehlo.convert %[[dq]] - %1 = stablehlo.convert %0 : (tensor<2x3xf32>) -> (tensor<2x3xf32>) - // CHECK: %[[q:.*]] = "quantfork.qcast"(%[[convert]]) - // CHECK-SAME: -> tensor<2x3x!quant.uniform> - %2 = "quantfork.stats"(%1) {bitsNum = 8 : i64, layerStats = dense<[-2.0, 2.0]> : tensor<2xf32>, narrowRange = false} : (tensor<2x3xf32>) -> tensor<2x3xf32> - func.return %2 : tensor<2x3xf32> + // CHECK: %[[reshape:.*]] = stablehlo.reshape %[[dq]] + %1 = stablehlo.reshape %0 : (tensor<2x3xf32>) -> (tensor<6xf32>) + // CHECK: %[[q:.*]] = "quantfork.qcast"(%[[reshape]]) + // CHECK-SAME: -> tensor<6x!quant.uniform> + %2 = "quantfork.stats"(%1) {bitsNum = 8 : i64, layerStats = dense<[-2.0, 2.0]> : tensor<2xf32>, narrowRange = false} : (tensor<6xf32>) -> tensor<6xf32> + func.return %2 : tensor<6xf32> } // ----- // CHECK-LABEL: func @merge_consecutive_qcast -// CHECK-SAME: (%[[ARG_0:.*]]: tensor<*xf32>, %[[ARG_1:.*]]: tensor<*xf32>, %[[ARG_2:.*]]: tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) -func.func @merge_consecutive_qcast(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { +// CHECK-SAME: (%[[ARG_0:.*]]: tensor, %[[ARG_1:.*]]: tensor, %[[ARG_2:.*]]: tensor) -> (tensor, tensor) +func.func @merge_consecutive_qcast(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor, tensor) { // CHECK: "quantfork.qcast"(%[[ARG_1]]) - // CHECK-SAME: -> tensor<*x!quant.uniform> + // CHECK-SAME: -> tensor> // CHECK: "quantfork.qcast"(%[[ARG_1]]) - // CHECK-SAME: -> tensor<*x!quant.uniform> - %0 = "quantfork.stats"(%arg0) {layerStats = dense<[-0.83811146, 2.4960899]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32> - %1 = "quantfork.stats"(%arg1) {layerStats = dense<[-0.835039615, 1.000000e+00]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32> - %2 = "stablehlo.concatenate"(%0, %1) {dimension = 0 : i64} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - %3 = "quantfork.stats"(%2) {layerStats = dense<[-0.83811146, 2.4960899]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32> - %4 = "quantfork.stats"(%arg2) {layerStats = dense<[-1.5726943, 1.07351148]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32> - %5 = "stablehlo.concatenate"(%4, %1) {dimension = 0 : i64} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - %6 = "quantfork.stats"(%5) {layerStats = dense<[-1.5726943, 4.6875381]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32> - func.return %3, %6 : tensor<*xf32>, tensor<*xf32> + // CHECK-SAME: -> tensor> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[-0.83811146, 2.4960899]> : tensor<2xf32>} : (tensor) -> tensor + %1 = "quantfork.stats"(%arg1) {layerStats = dense<[-0.835039615, 1.000000e+00]> : tensor<2xf32>} : (tensor) -> tensor + %2 = "stablehlo.concatenate"(%0, %1) {dimension = 0 : i64} : (tensor, tensor) -> tensor + %3 = "quantfork.stats"(%2) {layerStats = dense<[-0.83811146, 2.4960899]> : tensor<2xf32>} : (tensor) -> tensor + %4 = "quantfork.stats"(%arg2) {layerStats = dense<[-1.5726943, 1.07351148]> : tensor<2xf32>} : (tensor) -> tensor + %5 = "stablehlo.concatenate"(%4, %1) {dimension = 0 : i64} : (tensor, tensor) -> tensor + %6 = "quantfork.stats"(%5) {layerStats = dense<[-1.5726943, 4.6875381]> : tensor<2xf32>} : (tensor) -> tensor + func.return %3, %6 : tensor, tensor } // ----- diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize_int4.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/prepare_quantize_int4.mlir similarity index 100% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize_int4.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/prepare_quantize_int4.mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize_per_channel.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/prepare_quantize_per_channel.mlir similarity index 98% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize_per_channel.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/prepare_quantize_per_channel.mlir index f509ffce05863d..a6159c1dd62b4b 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize_per_channel.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/prepare_quantize_per_channel.mlir @@ -1,4 +1,4 @@ -// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-prepare-quantize=enable-per-channel-quantization=true -verify-diagnostics | FileCheck %s +// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-prepare-quantize=enable-per-channel-quantized-weight=true -verify-diagnostics | FileCheck %s // ----- diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize.mlir similarity index 100% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize.mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_op_with_region.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_op_with_region.mlir similarity index 100% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_op_with_region.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_op_with_region.mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_same_scale.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_same_scale.mlir similarity index 90% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_same_scale.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_same_scale.mlir index 8831839bdeaebe..f437016ed2c6f2 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_same_scale.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_same_scale.mlir @@ -154,46 +154,6 @@ module attributes {tf_saved_model.semantics} { // ----- -module attributes {tf_saved_model.semantics} { - // CHECK-LABEL: composite_and_convert - // CHECK-SAME: %[[ARG0:.*]]: tensor<1x2xf32> - // CHECK-SAME: %[[ARG1:.*]]: tensor<2x3xf32> - func.func private @composite_and_convert(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> { - // CHECK: %[[Q1:.*]] = "quantfork.qcast"(%[[ARG0]]) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> - // CHECK: %[[Q2:.*]] = "quantfork.qcast"(%[[ARG1]]) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32, 6.000000e-03:13>> - // CHECK: %[[CALL:.*]] = call @quantized_dot_general_fn_1(%[[Q1]], %[[Q2]]) - // CHECK-SAME: (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, 6.000000e-03:13>>) -> tensor<1x3x!quant.uniform> - // CHECK: %[[CONVERT:.*]] = stablehlo.convert %[[CALL]] : tensor<1x3x!quant.uniform> - // CHECK: %[[DQ:.*]] = "quantfork.dcast"(%[[CONVERT]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> - // CHECK: return %[[DQ]] - %0 = "quantfork.qcast"(%arg0) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> - %1 = "quantfork.dcast"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> - %2 = "quantfork.qcast"(%arg1) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32, 6.000000e-03:13>> - %3 = "quantfork.dcast"(%2) : (tensor<2x3x!quant.uniform:f32, 6.000000e-03:13>>) -> tensor<2x3xf32> - %4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> - %5 = "quantfork.qcast"(%4) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> - %6 = "quantfork.dcast"(%5) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> - %7 = stablehlo.convert %6 : (tensor<1x3xf32>) -> tensor<1x3xf32> - %8 = "quantfork.qcast"(%7) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> - %9 = "quantfork.dcast"(%8) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> - return %9 : tensor<1x3xf32> - } - - // CHECK: quantized_dot_general_fn_1 - // CHECK-SAME: %[[ARG2:.*]]: tensor<1x2x!quant.uniform> - // CHECK-SAME: %[[ARG3:.*]]: tensor<2x3x!quant.uniform:f32, 6.000000e-03:13>> - func.func private @composite_dot_general_fn_1(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { - // CHECK: %[[DOT:.*]] = stablehlo.dot_general %[[ARG2]], %[[ARG3]] - // CHECK-SAME: (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, 6.000000e-03:13>>) -> tensor<1x3x!quant.uniform> - // CHECK: %[[Q3:.*]] = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> - // CHECK: return %[[Q3]] - %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> - return %0 : tensor<1x3xf32> - } -} - -// ----- - module attributes {tf_saved_model.semantics} { // CHECK-LABEL: composite_and_pad // CHECK-SAME: %[[ARG0:.*]]: tensor<1x2xf32> diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions.mlir new file mode 100644 index 00000000000000..ff7bfafe654099 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions.mlir @@ -0,0 +1,567 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -verify-diagnostics \ +// RUN: -stablehlo-quantize-composite-functions | FileCheck %s +// RUN: stablehlo-quant-opt %s -split-input-file -verify-diagnostics \ +// RUN: -stablehlo-quantize-composite-functions=enable-per-channel-quantized-weight=false | FileCheck --check-prefix=CHECK-PER-TENSOR %s + +// Tests that basic dot_general is properly quantized. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_dot_general_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %2 : tensor<1x3xf32> + } +// Checks that the quantized XlaCallModule has been replaced by a CallOp, which +// calls the quantized entry function. + +// CHECK-PER-TENSOR: func.func private @quantize_dot_general_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x!quant.uniform> +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform) -> tensor<1x3xf32> +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32> + + func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +// Checks that the entry function is quantized for dot_general. Quantized +// dot_general outputs an i32 quantized tensor, followed by requantization to +// i8 quantized tensor. + +// CHECK-PER-TENSOR: func.func private @quantized_dot_general_fn(%[[ARG_1:.+]]: tensor<1x2x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[ARG_1]], %[[ARG_2]], contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x!quant.uniform> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[DOT_GENERAL_0]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x!quant.uniform> +} + +// ----- + +// Tests that fused pattern for dot_general + bias is properly quantized. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_dot_general_with_bias_same_shape_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x3xf32>} : () -> tensor<1x3xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_with_bias_same_shape_fn, _original_entry_function = "composite_dot_general_with_bias_same_shape_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %2 : tensor<1x3xf32> + } +// CHECK-PER-TENSOR: func.func private @quantize_dot_general_with_bias_same_shape_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x3xi32>} : () -> tensor<1x3x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_dot_general_with_bias_same_shape_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>, tensor<1x3x!quant.uniform) -> tensor<1x3x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform) -> tensor<1x3xf32> +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32> + + func.func private @composite_dot_general_with_bias_same_shape_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %1 = stablehlo.add %0, %arg2 : tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } +// CHECK-PER-TENSOR: func.func private @quantized_dot_general_with_bias_same_shape_fn(%[[ARG_1:.+]]: tensor<1x2x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[ARG_1]], %[[ARG_2]], contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x!quant.uniform> +// CHECK-PER-TENSOR: %[[ADD_0:.+]] = stablehlo.add %[[DOT_GENERAL_0]], %[[ARG_3]] : tensor<1x3x!quant.uniform> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x!quant.uniform> + +} + +// ----- + +// Tests that fused pattern for dot_general + bias with dynamic batch dimension +// is properly quantized. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_dot_general_with_bias_dynamic_fn(%arg0: tensor) -> tensor attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<3xf32>} : () -> tensor<3xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape], _entry_function = @composite_dot_general_with_bias_dynamic_fn, _original_entry_function = "composite_dot_general_with_bias_dynamic_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor, tensor<2x3xf32>, tensor<3xf32>) -> tensor + %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + return %2 : tensor + } +// CHECK-PER-TENSOR: func.func private @quantize_dot_general_with_bias_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<3xi32>} : () -> tensor<3x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_dot_general_with_bias_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor>, tensor<2x3x!quant.uniform:f32, {{.*}}>, tensor<3x!quant.uniform) -> tensor +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor + + func.func private @composite_dot_general_with_bias_dynamic_fn(%arg0: tensor, %arg1: tensor<2x3xf32>, %arg2: tensor<3xf32>) -> tensor attributes {_from_xla_call_module} { + %cst_0 = stablehlo.constant dense<2> : tensor<1xi32> + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor, tensor<2x3xf32>) -> tensor + %1 = stablehlo.get_dimension_size %0, dim = 0 : (tensor) -> tensor + %2 = stablehlo.reshape %1 : (tensor) -> tensor<1xi32> + %3 = stablehlo.concatenate %2, %cst_0, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %4 = stablehlo.dynamic_broadcast_in_dim %arg2, %3, dims = [1] : (tensor<3xf32>, tensor<2xi32>) -> tensor + %5 = stablehlo.add %0, %4 : tensor + return %5 : tensor + } +} +// CHECK-PER-TENSOR: func.func private @quantized_dot_general_with_bias_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<3x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR: %[[CONST_2:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK-PER-TENSOR: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[ARG_1]], %[[ARG_2]], contracting_dims = [1] x [0] : (tensor>, tensor<2x3x!quant.uniform:f32, {{.*}}>>) -> tensor> +// CHECK-PER-TENSOR: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[DOT_GENERAL_0]], dim = 0 : (tensor) +// CHECK-PER-TENSOR: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> +// CHECK-PER-TENSOR: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> +// CHECK-PER-TENSOR: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [1] : (tensor<3x!quant.uniform>, tensor<2xi32>) -> tensor> +// CHECK-PER-TENSOR: %[[ADD_0:.+]] = stablehlo.add %[[DOT_GENERAL_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_1]] : tensor> + + +// ----- + +// Tests that basic convolution is properly quantized. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_fn(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> + %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64, _entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> + return %2 : tensor<1x3x4x2xf32> + } +// Check that the quantized XlaCallModule has been replaced by a CallOp, which +// calls the quantized entry function. + +// CHECK: func.func private @quantize_conv_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} +// CHECK: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}> +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + +// CHECK-PER-TENSOR: func.func private @quantize_conv_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + + func.func private @composite_conv_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %0 : tensor<1x3x4x2xf32> + } +// Checks that the entry function is quantized for convolution. Quantized +// convolution outputs an i32 quantized tensor, followed by requantization to +// i8 quantized tensor. + +// CHECK: func.func private @quantized_conv_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} +// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[CONVOLUTION_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> + +// CHECK-PER-TENSOR: func.func private @quantized_conv_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[CONVOLUTION_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> +} + +// ----- + +// Tests that fused pattern for convolution + bias is properly quantized. + +// Checks that fused functions with 1D bias is properly quantized. +// The 1D bias should be broadcasted in dims [3], where it initially has +// `quantizedDimension=0`, but has `quantizedDimension=3` after broadcasting. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_with_bias_1d_fn(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3x4x2>], _entry_function = @composite_conv_with_bias_1d_fn, _original_entry_function = "composite_conv_with_bias_1d_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>, tensor<2xf32>) -> tensor<1x3x4x2xf32> + %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> + return %2 : tensor<1x3x4x2xf32> + } +// CHECK: func.func private @quantize_conv_with_bias_1d_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}> +// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<47978> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_1d_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>, tensor<2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + +// CHECK-PER-TENSOR: func.func private @quantize_conv_with_bias_1d_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2xi32>} : () -> tensor<2x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_1d_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>, tensor<2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + + func.func private @composite_conv_with_bias_1d_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.broadcast_in_dim %arg2, dims = [3] : (tensor<2xf32>) -> tensor<1x3x4x2xf32> + %1 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + %2 = stablehlo.add %1, %0 : tensor<1x3x4x2xf32> + return %2 : tensor<1x3x4x2xf32> + } +// CHECK: func.func private @quantized_conv_with_bias_1d_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>>, %[[ARG_3:.+]]: tensor<2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} +// CHECK: %[[BROADCAST_IN_DIM:.+]] = stablehlo.broadcast_in_dim %arg2, dims = [3] : (tensor<2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[BROADCAST_IN_DIM]] : tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> + +// CHECK-PER-TENSOR: func.func private @quantized_conv_with_bias_1d_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR: %[[BROADCAST_IN_DIM:.+]] = stablehlo.broadcast_in_dim %[[ARG_3]] +// CHECK-PER-TENSOR: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[BROADCAST_IN_DIM]] : tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> +} + +// ----- + +// Checks that fused functions with 4D bias is properly quantized. +// The 4D bias should be braoadcasted in dims [0, 1, 2, 3], where it +// already has `quantizedDimension=3`. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_with_bias_fn(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3x4x2>], _entry_function = @composite_conv_with_bias_fn, _original_entry_function = "composite_conv_with_bias_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>, tensor<1x1x1x2xf32>) -> tensor<1x3x4x2xf32> + %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> + return %2 : tensor<1x3x4x2xf32> + } +// CHECK: func.func private @quantize_conv_with_bias_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}> +// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>, tensor<1x1x1x2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + +// CHECK-PER-TENSOR: func.func private @quantize_conv_with_bias_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>, tensor<1x1x1x2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + + func.func private @composite_conv_with_bias_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<1x1x1x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.broadcast_in_dim %arg2, dims = [0, 1, 2, 3] : (tensor<1x1x1x2xf32>) -> tensor<1x3x4x2xf32> + %1 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + %2 = stablehlo.add %1, %0 : tensor<1x3x4x2xf32> + return %2 : tensor<1x3x4x2xf32> + } +// CHECK: func.func private @quantized_conv_with_bias_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} +// CHECK: %[[BROADCAST_IN_DIM:.+]] = stablehlo.broadcast_in_dim %arg2, dims = [0, 1, 2, 3] : (tensor<1x1x1x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[BROADCAST_IN_DIM]] : tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> + +// CHECK-PER-TENSOR: func.func private @quantized_conv_with_bias_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR: %[[BROADCAST_IN_DIM:.+]] = stablehlo.broadcast_in_dim %arg2 +// CHECK-PER-TENSOR: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[BROADCAST_IN_DIM]] : tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> +} + +// ----- + +// Tests that fused pattern for convolution + bias with dynamic batch dimension +// is properly quantized. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_with_bias_dynamic_fn(%arg0: tensor) -> tensor attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3x4x2>], _entry_function = @composite_conv_with_bias_dynamic_fn, _original_entry_function = "composite_conv_with_bias_dynamic_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor, tensor<2x3x3x2xf32>, tensor<1x1x1x2xf32>) -> tensor + %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + return %2 : tensor + } +// CHECK: func.func private @quantize_conv_with_bias_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}> +// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>, tensor<1x1x1x2x!quant.uniform) -> tensor +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor + +// CHECK-PER-TENSOR: func.func private @quantize_conv_with_bias_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>, tensor<1x1x1x2x!quant.uniform) -> tensor +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor + + func.func private @composite_conv_with_bias_dynamic_fn(%arg0: tensor, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<1x1x1x2xf32>) -> tensor attributes {_from_xla_call_module} { + %cst_0 = stablehlo.constant dense<3> : tensor<1xi32> + %cst_1 = stablehlo.constant dense<4> : tensor<1xi32> + %cst_2 = stablehlo.constant dense<2> : tensor<1xi32> + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor, tensor<2x3x3x2xf32>) -> tensor + %1 = stablehlo.get_dimension_size %0, dim = 0 : (tensor) -> tensor + %2 = stablehlo.reshape %1 : (tensor) -> tensor<1xi32> + %3 = stablehlo.concatenate %2, %cst_0, %cst_1, %cst_2, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + %4 = stablehlo.dynamic_broadcast_in_dim %arg2, %3, dims = [0, 1, 2, 3] : (tensor<1x1x1x2xf32>, tensor<4xi32>) -> tensor + %5 = stablehlo.add %0, %4 : tensor + return %5 : tensor + } +} +// CHECK: func.func private @quantized_conv_with_bias_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} +// CHECK-DAG: %[[CONST_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32> +// CHECK-DAG: %[[CONST_3:.+]] = stablehlo.constant dense<4> : tensor<1xi32> +// CHECK-DAG: %[[CONST_4:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>) -> tensor> +// CHECK: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[CONVOLUTION_0]], dim = 0 : (tensor) +// CHECK: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> +// CHECK: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [0, 1, 2, 3] : (tensor<1x1x1x2x!quant.uniform>, tensor<4xi32>) -> tensor> +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) +// CHECK: return %[[UNIFORM_QUANTIZE_0]] : tensor> + +// CHECK-PER-TENSOR: func.func private @quantized_conv_with_bias_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR-DAG: %[[CONST_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32> +// CHECK-PER-TENSOR-DAG: %[[CONST_3:.+]] = stablehlo.constant dense<4> : tensor<1xi32> +// CHECK-PER-TENSOR-DAG: %[[CONST_4:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK-PER-TENSOR: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>) -> tensor> +// CHECK-PER-TENSOR: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[CONVOLUTION_0]], dim = 0 : (tensor) +// CHECK-PER-TENSOR: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> +// CHECK-PER-TENSOR: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK-PER-TENSOR: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [0, 1, 2, 3] : (tensor<1x1x1x2x!quant.uniform>, tensor<4xi32>) -> tensor> +// CHECK-PER-TENSOR: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_0]] : tensor> + +// ----- + +// Tests that fused pattern for convolution + bias + relu with +// dynamic batch dimension is properly quantized. + +// Note that this checks for identical condition as +// quantize_conv_with_bias_dynamic_fn, omitting stablehlo.maximum. +// This is because activation clipping which includes 0.0f can be simply +// omitted from the graph as the lifted function's out_scale and out_zp are +// already calculated based on the clipped distribution. +// Note that the resulting scale and zero point should be calculated based on +// clipped range [0, r_max]. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_with_bias_and_relu_dynamic_fn(%arg0: tensor) -> tensor attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3x4x2>], _entry_function = @composite_conv_with_bias_and_relu_dynamic_fn, _original_entry_function = "composite_conv_with_bias_and_relu_dynamic_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor, tensor<2x3x3x2xf32>, tensor<1x1x1x2xf32>) -> tensor + %2 = "quantfork.stats"(%1) {layerStats = dense<[0.00000000e-6, 8.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + return %2 : tensor + } +// CHECK: func.func private @quantize_conv_with_bias_and_relu_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}> +// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_and_relu_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>, tensor<1x1x1x2x!quant.uniform>) -> tensor> +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor + +// CHECK-PER-TENSOR: func.func private @quantize_conv_with_bias_and_relu_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_and_relu_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>, tensor<1x1x1x2x!quant.uniform>) -> tensor> +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor + + func.func private @composite_conv_with_bias_and_relu_dynamic_fn(%arg0: tensor, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<1x1x1x2xf32>) -> tensor attributes {_from_xla_call_module} { + %cst_0 = stablehlo.constant dense<3> : tensor<1xi32> + %cst_1 = stablehlo.constant dense<4> : tensor<1xi32> + %cst_2 = stablehlo.constant dense<2> : tensor<1xi32> + %cst_3 = stablehlo.constant dense<0.000000e+00> : tensor + %cst_4 = stablehlo.constant dense<6.000000e+00> : tensor + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor, tensor<2x3x3x2xf32>) -> tensor + %1 = stablehlo.get_dimension_size %0, dim = 0 : (tensor) -> tensor + %2 = stablehlo.reshape %1 : (tensor) -> tensor<1xi32> + %3 = stablehlo.concatenate %2, %cst_0, %cst_1, %cst_2, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + %4 = stablehlo.dynamic_broadcast_in_dim %arg2, %3, dims = [0, 1, 2, 3] : (tensor<1x1x1x2xf32>, tensor<4xi32>) -> tensor + %5 = stablehlo.add %0, %4 : tensor + %6 = stablehlo.clamp %cst_3, %5, %cst_4 : (tensor, tensor, tensor) -> tensor + return %6 : tensor + } +} +// CHECK: func.func private @quantized_conv_with_bias_and_relu_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} +// CHECK-DAG: %[[CONST_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32> +// CHECK-DAG: %[[CONST_3:.+]] = stablehlo.constant dense<4> : tensor<1xi32> +// CHECK-DAG: %[[CONST_4:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>) -> tensor> +// CHECK: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[CONVOLUTION_0]], dim = 0 : (tensor) +// CHECK: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> +// CHECK: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [0, 1, 2, 3] : (tensor<1x1x1x2x!quant.uniform>, tensor<4xi32>) -> tensor> +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) -> tensor> +// CHECK: return %[[UNIFORM_QUANTIZE_0]] : tensor> + +// CHECK-PER-TENSOR: func.func private @quantized_conv_with_bias_and_relu_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR-DAG: %[[CONST_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32> +// CHECK-PER-TENSOR-DAG: %[[CONST_3:.+]] = stablehlo.constant dense<4> : tensor<1xi32> +// CHECK-PER-TENSOR-DAG: %[[CONST_4:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK-PER-TENSOR: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor> +// CHECK-PER-TENSOR: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[CONVOLUTION_0]], dim = 0 : (tensor) +// CHECK-PER-TENSOR: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> +// CHECK-PER-TENSOR: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK-PER-TENSOR: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [0, 1, 2, 3] : (tensor<1x1x1x2x!quant.uniform>, tensor<4xi32>) -> tensor> +// CHECK-PER-TENSOR: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) -> tensor> +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_0]] : tensor> + +// ----- + +// Tests that fused pattern for convolution + bias + relu6 with +// dynamic batch dimension is properly quantized. + +// Note that this checks for identical condition as +// quantize_conv_with_bias_dynamic_fn, omitting stablehlo.clamp. +// This is because activation clipping which includes 0.0f can be simply +// omitted from the graph as the lifted function's out_scale and out_zp are +// already calculated based on the clipped distribution. +// Note that the resulting scale and zero point should be calculated based on +// clipped range [0, r_max]. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_with_bias_and_relu6_dynamic_fn(%arg0: tensor) -> tensor attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3x4x2>], _entry_function = @composite_conv_with_bias_and_relu6_dynamic_fn, _original_entry_function = "composite_conv_with_bias_and_relu6_dynamic_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor, tensor<2x3x3x2xf32>, tensor<1x1x1x2xf32>) -> tensor + %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 6.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + return %2 : tensor + } +// CHECK: func.func private @quantize_conv_with_bias_and_relu6_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}> +// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_and_relu6_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>, tensor<1x1x1x2x!quant.uniform>) -> tensor> +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor + +// CHECK-PER-TENSOR: func.func private @quantize_conv_with_bias_and_relu6_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_and_relu6_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>, tensor<1x1x1x2x!quant.uniform>) -> tensor> +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor + + func.func private @composite_conv_with_bias_and_relu6_dynamic_fn(%arg0: tensor, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<1x1x1x2xf32>) -> tensor attributes {_from_xla_call_module} { + %cst_0 = stablehlo.constant dense<3> : tensor<1xi32> + %cst_1 = stablehlo.constant dense<4> : tensor<1xi32> + %cst_2 = stablehlo.constant dense<2> : tensor<1xi32> + %cst_3 = stablehlo.constant dense<0.000000e+00> : tensor + %cst_4 = stablehlo.constant dense<6.000000e+00> : tensor + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor, tensor<2x3x3x2xf32>) -> tensor + %1 = stablehlo.get_dimension_size %0, dim = 0 : (tensor) -> tensor + %2 = stablehlo.reshape %1 : (tensor) -> tensor<1xi32> + %3 = stablehlo.concatenate %2, %cst_0, %cst_1, %cst_2, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + %4 = stablehlo.dynamic_broadcast_in_dim %arg2, %3, dims = [0, 1, 2, 3] : (tensor<1x1x1x2xf32>, tensor<4xi32>) -> tensor + %5 = stablehlo.add %0, %4 : tensor + %6 = stablehlo.clamp %cst_3, %5, %cst_4 : (tensor, tensor, tensor) -> tensor + return %6 : tensor + } +} +// CHECK: func.func private @quantized_conv_with_bias_and_relu6_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} +// CHECK-DAG: %[[CONST_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32> +// CHECK-DAG: %[[CONST_3:.+]] = stablehlo.constant dense<4> : tensor<1xi32> +// CHECK-DAG: %[[CONST_4:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>) -> tensor> +// CHECK: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[CONVOLUTION_0]], dim = 0 : (tensor) +// CHECK: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> +// CHECK: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [0, 1, 2, 3] : (tensor<1x1x1x2x!quant.uniform>, tensor<4xi32>) -> tensor> +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) -> tensor> +// CHECK: return %[[UNIFORM_QUANTIZE_0]] : tensor> + +// CHECK-PER-TENSOR: func.func private @quantized_conv_with_bias_and_relu6_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR-DAG: %[[CONST_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32> +// CHECK-PER-TENSOR-DAG: %[[CONST_3:.+]] = stablehlo.constant dense<4> : tensor<1xi32> +// CHECK-PER-TENSOR-DAG: %[[CONST_4:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK-PER-TENSOR: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor> +// CHECK-PER-TENSOR: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[CONVOLUTION_0]], dim = 0 : (tensor) +// CHECK-PER-TENSOR: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> +// CHECK-PER-TENSOR: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK-PER-TENSOR: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [0, 1, 2, 3] : (tensor<1x1x1x2x!quant.uniform>, tensor<4xi32>) -> tensor> +// CHECK-PER-TENSOR: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) -> tensor> +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_0]] : tensor> + +// ----- + +// Tests that XlaCallModule op is not quantized without the quantfork.stats ops. + +module attributes {tf_saved_model.semantics} { + func.func private @not_quantized_without_stats_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %1 = "tf.XlaCallModule"(%arg0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } +// Check that "tf.Const" is converted to stablehlo.constant. XlaCallModule is +// not quantized. + +// CHECK: func.func private @not_quantized_without_stats_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} +// CHECK: %[[CONST_0:.+]] = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32> +// CHECK: %[[XLA_CALL_MODULE_0:.+]] = "tf.XlaCallModule"(%[[ARG_0]], %[[CONST_0]]) <{{{.*}}}> {{{.*_entry_function = @composite_dot_general_fn.*}}} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> +// CHECK: return %[[XLA_CALL_MODULE_0]] + + func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +// CHECK: func.func private @composite_dot_general_fn(%[[ARG_1:.+]]: tensor<1x2xf32>, %[[ARG_2:.+]]: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} +// Check that the composite_dot_general_fn is untouched. +// CHECK: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[ARG_1]], %[[ARG_2]] +// CHECK: return %[[DOT_GENERAL_0]] +} + +// ----- + +// Tests that basic gather is properly quantized. + +module attributes {tf_saved_model.semantics} { +// CHECK: func.func private @quantize_gather_fn(%[[ARG:.+]]: tensor<3x4x2xf32>) -> tensor<2x3x2x2xf32> attributes {tf._original_func_name = "main_0"} + func.func private @quantize_gather_fn(%arg: tensor<3x4x2xf32>) -> tensor<2x3x2x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<1> : tensor<2x3x2xi32>} : () -> tensor<2x3x2xi32> + %0 = "quantfork.stats"(%arg) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<3x4x2xf32>) -> tensor<3x4x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<2x3x2x2>], _entry_function = @composite_gather_fn, _original_entry_function = "composite_gather_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<3x4x2xf32>, tensor<2x3x2xi32>) -> tensor<2x3x2x2xf32> + %2 = "quantfork.stats"(%1) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<2x3x2x2xf32>) -> tensor<2x3x2x2xf32> + return %2 : tensor<2x3x2x2xf32> + } +// Checks that the quantized XlaCallModule has been replaced by a CallOp, which +// calls the quantized entry function. +// CHECK: %[[CONST:.+]] = stablehlo.constant dense<{{.*}}> : tensor<2x3x2xi32> +// CHECK: %[[UNIFORM_QUANTIZE:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<3x4x2xf32>) -> tensor<3x4x2x!quant.uniform> +// CHECK: %[[CALL:.+]] = call @quantized_gather_fn(%[[UNIFORM_QUANTIZE]], %[[CONST]]) : (tensor<3x4x2x!quant.uniform>, tensor<2x3x2xi32>) -> tensor<2x3x2x2x!quant.uniform> +// CHECK: %[[UNIFORM_DEQUANTIZE:.+]] = stablehlo.uniform_dequantize %[[CALL]] : (tensor<2x3x2x2x!quant.uniform) -> tensor<2x3x2x2xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE]] : tensor<2x3x2x2xf32> + +// CHECK: func.func private @quantized_gather_fn(%[[ARG_0:.+]]: tensor<3x4x2x!quant.uniform>, %[[ARG_1:.+]]: tensor<2x3x2xi32>) -> tensor<2x3x2x2x!quant.uniform> attributes {_from_xla_call_module} + func.func private @composite_gather_fn(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x3x2xi32>) -> tensor<2x3x2x2xf32> attributes {_from_xla_call_module} { + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2, 3], + collapsed_slice_dims = [0], + start_index_map = [1, 0], + index_vector_dim = 2>, + slice_sizes = array, + indices_are_sorted = false + } : (tensor<3x4x2xf32>, tensor<2x3x2xi32>) -> tensor<2x3x2x2xf32> + return %0 : tensor<2x3x2x2xf32> + } +// CHECK: %[[GATHER:.+]] = "stablehlo.gather"(%[[ARG_0]], %[[ARG_1]]) {{.*}} : (tensor<3x4x2x!quant.uniform>, tensor<2x3x2xi32>) -> tensor<2x3x2x2x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE:.+]] = stablehlo.uniform_quantize %[[GATHER]] : tensor<2x3x2x2x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE]] : tensor<2x3x2x2x!quant.uniform> +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir similarity index 100% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/restore_function_name.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/restore_function_name.mlir similarity index 100% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/restore_function_name.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/restore_function_name.mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/unfuse_mhlo_batch_norm.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/unfuse_mhlo_batch_norm.mlir similarity index 100% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/unfuse_mhlo_batch_norm.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/unfuse_mhlo_batch_norm.mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/unwrap_xla_call_module_op.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/unwrap_xla_call_module_op.mlir similarity index 100% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/unwrap_xla_call_module_op.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/unwrap_xla_call_module_op.mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/populate_shape.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/populate_shape.mlir deleted file mode 100644 index 779ef786714fb7..00000000000000 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/populate_shape.mlir +++ /dev/null @@ -1,44 +0,0 @@ -// RUN: stablehlo-quant-opt %s -split-input-file -populate-shape | FileCheck %s - -// CHECK-LABEL: @populate_shape_for_custom_aggregator -func.func @populate_shape_for_custom_aggregator(%input: tensor) { - // CHECK: %[[OUTPUT:.*]] = "tf.CustomAggregator"(%[[INPUT:.*]]) <{id = "49d53b0"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 6.000000e+00 : f32, max_percentile = 0.000000e+00 : f32, min = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor) -> tensor - %0 = "tf.CustomAggregator"(%input) <{id = "49d53b0"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 6.000000e+00 : f32, max_percentile = 0.000000e+00 : f32, min = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor) -> tensor<*xf32> - func.return -} - -// ---- - -// CHECK-LABEL: @populate_shape_for_xla_call_module -func.func @populate_shape_for_xla_call_module(%input: tensor) { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<1x1x64x256xf32>} : () -> tensor<1x1x64x256xf32> - // CHECK: %[[OUTPUT:.*]] = "tf.XlaCallModule"(%[[INPUT:.*]], %[[CST:.*]]) <{Sout = [#tf_type.shape], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @main_9, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor, tensor<1x1x64x256xf32>) -> tensor - %0 = "tf.XlaCallModule"(%input, %cst) <{Sout = [#tf_type.shape], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @main_9, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor, tensor<1x1x64x256xf32>) -> tensor<*xf32> - func.return -} - -// ---- - -// CHECK-LABEL: @populate_shape_for_chain_of_ops -func.func @populate_shape_for_chain_of_ops(%input: tensor) { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<1x1x64x256xf32>} : () -> tensor<1x1x64x256xf32> - // CHECK: %[[VAL_0:.*]] = "tf.CustomAggregator"(%[[INPUT:.*]]) <{id = "49d53b0"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 6.000000e+00 : f32, max_percentile = 0.000000e+00 : f32, min = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor) -> tensor - // CHECK: %[[VAL_1:.*]] = "tf.XlaCallModule"(%[[VAL_0:.*]], %[[CST:.*]]) <{Sout = [#tf_type.shape], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @main_9, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor, tensor<1x1x64x256xf32>) -> tensor - // CHECK: %[[VAL_2:.*]] = "tf.CustomAggregator"(%[[VAL_1:.*]]) <{id = "49d53b1"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 6.000000e+00 : f32, max_percentile = 0.000000e+00 : f32, min = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor) -> tensor - %0 = "tf.CustomAggregator"(%input) <{id = "49d53b0"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 6.000000e+00 : f32, max_percentile = 0.000000e+00 : f32, min = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor) -> tensor<*xf32> - %1 = "tf.XlaCallModule"(%0, %cst) <{Sout = [#tf_type.shape], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @main_9, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<*xf32>, tensor<1x1x64x256xf32>) -> tensor<*xf32> - %2 = "tf.CustomAggregator"(%1) <{id = "49d53b1"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 6.000000e+00 : f32, max_percentile = 0.000000e+00 : f32, min = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> - func.return -} - -// ---- - -// CHECK-LABEL: @populate_shape_for_xla_call_module_failure_not_single_output -func.func @populate_shape_for_xla_call_module_failure_not_single_output(%input: tensor) { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<1x1x64x256xf32>} : () -> tensor<1x1x64x256xf32> - // expected-error @+2 {{XlaCallModuleOp doesn't have 1 output.}} - %0, %1 = "tf.XlaCallModule"(%input, %cst) <{Sout = [#tf_type.shape, #tf_type.shape], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @main_9, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor, tensor<1x1x64x256xf32>) -> (tensor<*xf32>, tensor<*xf32>) - // expected-error @+1 {{XlaCallModuleOp doesn't have 1 output.}} - "tf.XlaCallModule"(%input, %cst) <{Sout = [], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @main_9, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor, tensor<1x1x64x256xf32>) -> () - func.return -} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_composite_functions.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_composite_functions.mlir deleted file mode 100644 index 8c99fe0a345c38..00000000000000 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_composite_functions.mlir +++ /dev/null @@ -1,403 +0,0 @@ -// RUN: stablehlo-quant-opt %s -split-input-file -verify-diagnostics \ -// RUN: -stablehlo-quantize-composite-functions | FileCheck %s - - -// Tests that basic dot_general is properly quantized. - -// The following pattern does not converge because of a bug in QuantizePass. -// TODO - b/305469508: Fix the QuantizePass to avoid this warning. -// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} -module attributes {tf_saved_model.semantics} { -// CHECK: func.func private @quantize_dot_general_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} - func.func private @quantize_dot_general_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> - %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> - %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> - %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> - return %2 : tensor<1x3xf32> - } -// Checks that the quantized XlaCallModule has been replaced by a CallOp, which -// calls the quantized entry function. -// CHECK: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32, {{.*}}> -// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> -// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x!quant.uniform> -// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform) -> tensor<1x3xf32> -// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32> - -// CHECK: func.func private @quantized_dot_general_fn(%[[ARG_1:.+]]: tensor<1x2x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} - func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { - %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> - return %0 : tensor<1x3xf32> - } -// Checks that the entry function is quantized for dot_general. Quantized -// dot_general outputs an i32 quantized tensor, followed by requantization to -// i8 quantized tensor. -// CHECK: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[ARG_1]], %[[ARG_2]], contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x!quant.uniform> -// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[DOT_GENERAL_0]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> -// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x!quant.uniform> -} - -// ----- - -// Tests that fused pattern for dot_general + bias is properly quantized. - -// The following pattern does not converge because of a bug in QuantizePass. -// TODO - b/305469508: Fix the QuantizePass to avoid this warning. -// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} -module attributes {tf_saved_model.semantics} { -// CHECK: func.func private @quantize_dot_general_with_bias_same_shape_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} - func.func private @quantize_dot_general_with_bias_same_shape_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> - %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x3xf32>} : () -> tensor<1x3xf32> - %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> - %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_with_bias_same_shape_fn, _original_entry_function = "composite_dot_general_with_bias_same_shape_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> - %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> - return %2 : tensor<1x3xf32> - } -// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32, {{.*}}> -// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x3xi32>} : () -> tensor<1x3x!quant.uniform -// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> -// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_with_bias_same_shape_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>, tensor<1x3x!quant.uniform) -> tensor<1x3x!quant.uniform -// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform) -> tensor<1x3xf32> -// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32> - -// CHECK: func.func private @quantized_dot_general_with_bias_same_shape_fn(%[[ARG_1:.+]]: tensor<1x2x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} - func.func private @composite_dot_general_with_bias_same_shape_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { - %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> - %1 = stablehlo.add %0, %arg2 : tensor<1x3xf32> - return %1 : tensor<1x3xf32> - } -// CHECK: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[ARG_1]], %[[ARG_2]], contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x!quant.uniform> -// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[DOT_GENERAL_0]], %[[ARG_3]] : tensor<1x3x!quant.uniform> -// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> -// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x!quant.uniform> - -} - -// ----- - -// Tests that fused pattern for dot_general + bias with dynamic batch dimension -// is properly quantized. - -// The following pattern does not converge because of a bug in QuantizePass. -// TODO - b/305469508: Fix the QuantizePass to avoid this warning. -// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} -module attributes {tf_saved_model.semantics} { -// CHECK: func.func private @quantize_dot_general_with_bias_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} - func.func private @quantize_dot_general_with_bias_dynamic_fn(%arg0: tensor) -> tensor attributes {tf._original_func_name = "main_0"} { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> - %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<3xf32>} : () -> tensor<3xf32> - %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor - %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape], _entry_function = @composite_dot_general_with_bias_dynamic_fn, _original_entry_function = "composite_dot_general_with_bias_dynamic_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor, tensor<2x3xf32>, tensor<3xf32>) -> tensor - %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor - return %2 : tensor - } -// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32, {{.*}}> -// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<3xi32>} : () -> tensor<3x!quant.uniform -// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> -// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_with_bias_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor>, tensor<2x3x!quant.uniform:f32, {{.*}}>, tensor<3x!quant.uniform) -> tensor -// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor -// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor - -// CHECK: func.func private @quantized_dot_general_with_bias_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<3x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} - func.func private @composite_dot_general_with_bias_dynamic_fn(%arg0: tensor, %arg1: tensor<2x3xf32>, %arg2: tensor<3xf32>) -> tensor attributes {_from_xla_call_module} { - %cst_0 = stablehlo.constant dense<2> : tensor<1xi32> - %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor, tensor<2x3xf32>) -> tensor - %1 = stablehlo.get_dimension_size %0, dim = 0 : (tensor) -> tensor - %2 = stablehlo.reshape %1 : (tensor) -> tensor<1xi32> - %3 = stablehlo.concatenate %2, %cst_0, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - %4 = stablehlo.dynamic_broadcast_in_dim %arg2, %3, dims = [1] : (tensor<3xf32>, tensor<2xi32>) -> tensor - %5 = stablehlo.add %0, %4 : tensor - return %5 : tensor - } -} -// CHECK: %[[CONST_2:.+]] = stablehlo.constant dense<2> : tensor<1xi32> -// CHECK: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[ARG_1]], %[[ARG_2]], contracting_dims = [1] x [0] : (tensor>, tensor<2x3x!quant.uniform:f32, {{.*}}>>) -> tensor> -// CHECK: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[DOT_GENERAL_0]], dim = 0 : (tensor) -// CHECK: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> -// CHECK: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> -// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [1] : (tensor<3x!quant.uniform>, tensor<2xi32>) -> tensor> -// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[DOT_GENERAL_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> -// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) -// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor> - - -// ----- - -// Tests that basic convolution is properly quantized. - -// The following pattern does not converge because of a bug in QuantizePass. -// TODO - b/305469508: Fix the QuantizePass to avoid this warning. -// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} -module attributes {tf_saved_model.semantics} { -// CHECK: func.func private @quantize_conv_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} - func.func private @quantize_conv_fn(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> - %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> - %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64, _entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> - %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> - return %2 : tensor<1x3x4x2xf32> - } -// Check that the quantized XlaCallModule has been replaced by a CallOp, which -// calls the quantized entry function. -// CHECK: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> -// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> -// CHECK: %[[CALL_0:.+]] = call @quantized_conv_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> -// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> -// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> - -// CHECK: func.func private @quantized_conv_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} - func.func private @composite_conv_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { - %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> - return %0 : tensor<1x3x4x2xf32> - } -// Checks that the entry function is quantized for convolution. Quantized -// convolution outputs an i32 quantized tensor, followed by requantization to -// i8 quantized tensor. -// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x4x2x!quant.uniform> -// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[CONVOLUTION_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> -// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> -} - -// ----- - -// Tests that fused pattern for convolution + bias is properly quantized. - -// The following pattern does not converge because of a bug in QuantizePass. -// TODO - b/305469508: Fix the QuantizePass to avoid this warning. -// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} -module attributes {tf_saved_model.semantics} { -// CHECK: func.func private @quantize_conv_with_bias_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} - func.func private @quantize_conv_with_bias_fn(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> - %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<2xf32>} : () -> tensor<2xf32> - %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> - %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3x4x2>], _entry_function = @composite_conv_with_bias_fn, _original_entry_function = "composite_conv_with_bias_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>, tensor<2xf32>) -> tensor<1x3x4x2xf32> - %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> - return %2 : tensor<1x3x4x2xf32> - } -// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> -// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2xi32>} : () -> tensor<2x!quant.uniform -// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> -// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>, tensor<2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform -// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> -// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> - -// CHECK: func.func private @quantized_conv_with_bias_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} - func.func private @composite_conv_with_bias_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { - %0 = stablehlo.broadcast_in_dim %arg2, dims = [3] : (tensor<2xf32>) -> tensor<1x3x4x2xf32> - %1 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> - %2 = stablehlo.add %1, %0 : tensor<1x3x4x2xf32> - return %2 : tensor<1x3x4x2xf32> - } -// CHECK: %[[BROADCAST_IN_DIM:.+]] = stablehlo.broadcast_in_dim %arg2 -// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> -// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[BROADCAST_IN_DIM]] : tensor<1x3x4x2x!quant.uniform> -// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> -// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> -} - -// ----- - -// Tests that fused pattern for convolution + bias with dynamic batch dimension -// is properly quantized. - -// The following pattern does not converge because of a bug in QuantizePass. -// TODO - b/305469508: Fix the QuantizePass to avoid this warning. -// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} -module attributes {tf_saved_model.semantics} { -// CHECK: func.func private @quantize_conv_with_bias_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} - func.func private @quantize_conv_with_bias_dynamic_fn(%arg0: tensor) -> tensor attributes {tf._original_func_name = "main_0"} { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> - %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<2xf32>} : () -> tensor<2xf32> - %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor - %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3x4x2>], _entry_function = @composite_conv_with_bias_dynamic_fn, _original_entry_function = "composite_conv_with_bias_dynamic_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor, tensor<2x3x3x2xf32>, tensor<2xf32>) -> tensor - %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor - return %2 : tensor - } - -// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> -// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2xi32>} : () -> tensor<2x!quant.uniform -// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> -// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>, tensor<2x!quant.uniform) -> tensor -// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor -// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor - -// CHECK: func.func private @quantized_conv_with_bias_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<2x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} - func.func private @composite_conv_with_bias_dynamic_fn(%arg0: tensor, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<2xf32>) -> tensor attributes {_from_xla_call_module} { - %cst_0 = stablehlo.constant dense<3> : tensor<1xi32> - %cst_1 = stablehlo.constant dense<4> : tensor<1xi32> - %cst_2 = stablehlo.constant dense<2> : tensor<1xi32> - %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor, tensor<2x3x3x2xf32>) -> tensor - %1 = stablehlo.get_dimension_size %0, dim = 0 : (tensor) -> tensor - %2 = stablehlo.reshape %1 : (tensor) -> tensor<1xi32> - %3 = stablehlo.concatenate %2, %cst_0, %cst_1, %cst_2, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> - %4 = stablehlo.dynamic_broadcast_in_dim %arg2, %3, dims = [3] : (tensor<2xf32>, tensor<4xi32>) -> tensor - %5 = stablehlo.add %0, %4 : tensor - return %5 : tensor - } -} -// CHECK-DAG: %[[CONST_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32> -// CHECK-DAG: %[[CONST_3:.+]] = stablehlo.constant dense<4> : tensor<1xi32> -// CHECK-DAG: %[[CONST_4:.+]] = stablehlo.constant dense<2> : tensor<1xi32> -// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>) -> tensor> -// CHECK: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[CONVOLUTION_0]], dim = 0 : (tensor) -// CHECK: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> -// CHECK: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> -// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [3] : (tensor<2x!quant.uniform>, tensor<4xi32>) -> tensor> -// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> -// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) -// CHECK: return %[[UNIFORM_QUANTIZE_0]] : tensor> - -// ----- - -// Tests that fused pattern for convolution + bias + relu with -// dynamic batch dimension is properly quantized. - -// Note that this checks for identical condition as -// quantize_conv_with_bias_dynamic_fn, omitting stablehlo.maximum. -// This is because activation clipping which includes 0.0f can be simply -// omitted from the graph as the lifted function's out_scale and out_zp are -// already calculated based on the clipped distribution. -// Note that the resulting scale and zero point should be calculated based on -// clipped range [0, r_max]. - -// The following pattern does not converge because of a bug in QuantizePass. -// TODO - b/305469508: Fix the QuantizePass to avoid this warning. -// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} -module attributes {tf_saved_model.semantics} { -// CHECK: func.func private @quantize_conv_with_bias_and_relu_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} - func.func private @quantize_conv_with_bias_and_relu_dynamic_fn(%arg0: tensor) -> tensor attributes {tf._original_func_name = "main_0"} { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> - %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<2xf32>} : () -> tensor<2xf32> - %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor - %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3x4x2>], _entry_function = @composite_conv_with_bias_and_relu_dynamic_fn, _original_entry_function = "composite_conv_with_bias_and_relu_dynamic_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor, tensor<2x3x3x2xf32>, tensor<2xf32>) -> tensor - %2 = "quantfork.stats"(%1) {layerStats = dense<[0.00000000e-6, 8.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor - return %2 : tensor - } -// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> -// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2xi32>} : () -> tensor<2x!quant.uniform -// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> -// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_and_relu_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>, tensor<2x!quant.uniform>) -> tensor> -// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor -// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor - -// CHECK: func.func private @quantized_conv_with_bias_and_relu_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<2x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} - func.func private @composite_conv_with_bias_and_relu_dynamic_fn(%arg0: tensor, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<2xf32>) -> tensor attributes {_from_xla_call_module} { - %cst_0 = stablehlo.constant dense<3> : tensor<1xi32> - %cst_1 = stablehlo.constant dense<4> : tensor<1xi32> - %cst_2 = stablehlo.constant dense<2> : tensor<1xi32> - %cst_3 = stablehlo.constant dense<0.000000e+00> : tensor - %cst_4 = stablehlo.constant dense<6.000000e+00> : tensor - %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor, tensor<2x3x3x2xf32>) -> tensor - %1 = stablehlo.get_dimension_size %0, dim = 0 : (tensor) -> tensor - %2 = stablehlo.reshape %1 : (tensor) -> tensor<1xi32> - %3 = stablehlo.concatenate %2, %cst_0, %cst_1, %cst_2, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> - %4 = stablehlo.dynamic_broadcast_in_dim %arg2, %3, dims = [3] : (tensor<2xf32>, tensor<4xi32>) -> tensor - %5 = stablehlo.add %0, %4 : tensor - %6 = stablehlo.clamp %cst_3, %5, %cst_4 : (tensor, tensor, tensor) -> tensor - return %6 : tensor - } -} -// CHECK-DAG: %[[CONST_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32> -// CHECK-DAG: %[[CONST_3:.+]] = stablehlo.constant dense<4> : tensor<1xi32> -// CHECK-DAG: %[[CONST_4:.+]] = stablehlo.constant dense<2> : tensor<1xi32> -// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor> -// CHECK: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[CONVOLUTION_0]], dim = 0 : (tensor) -// CHECK: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> -// CHECK: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> -// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [3] : (tensor<2x!quant.uniform>, tensor<4xi32>) -> tensor> -// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> -// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) -> tensor> -// CHECK: return %[[UNIFORM_QUANTIZE_0]] : tensor> - -// ----- - -// Tests that fused pattern for convolution + bias + relu6 with -// dynamic batch dimension is properly quantized. - -// Note that this checks for identical condition as -// quantize_conv_with_bias_dynamic_fn, omitting stablehlo.clamp. -// This is because activation clipping which includes 0.0f can be simply -// omitted from the graph as the lifted function's out_scale and out_zp are -// already calculated based on the clipped distribution. -// Note that the resulting scale and zero point should be calculated based on -// clipped range [0, r_max]. - -// The following pattern does not converge because of a bug in QuantizePass. -// TODO - b/305469508: Fix the QuantizePass to avoid this warning. -// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} -module attributes {tf_saved_model.semantics} { -// CHECK: func.func private @quantize_conv_with_bias_and_relu6_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} - func.func private @quantize_conv_with_bias_and_relu6_dynamic_fn(%arg0: tensor) -> tensor attributes {tf._original_func_name = "main_0"} { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> - %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<2xf32>} : () -> tensor<2xf32> - %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor - %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3x4x2>], _entry_function = @composite_conv_with_bias_and_relu6_dynamic_fn, _original_entry_function = "composite_conv_with_bias_and_relu6_dynamic_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor, tensor<2x3x3x2xf32>, tensor<2xf32>) -> tensor - %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 6.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor - return %2 : tensor - } -// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> -// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2xi32>} : () -> tensor<2x!quant.uniform -// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> -// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_and_relu6_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>, tensor<2x!quant.uniform>) -> tensor> -// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor -// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor - -// CHECK: func.func private @quantized_conv_with_bias_and_relu6_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<2x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} - func.func private @composite_conv_with_bias_and_relu6_dynamic_fn(%arg0: tensor, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<2xf32>) -> tensor attributes {_from_xla_call_module} { - %cst_0 = stablehlo.constant dense<3> : tensor<1xi32> - %cst_1 = stablehlo.constant dense<4> : tensor<1xi32> - %cst_2 = stablehlo.constant dense<2> : tensor<1xi32> - %cst_3 = stablehlo.constant dense<0.000000e+00> : tensor - %cst_4 = stablehlo.constant dense<6.000000e+00> : tensor - %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor, tensor<2x3x3x2xf32>) -> tensor - %1 = stablehlo.get_dimension_size %0, dim = 0 : (tensor) -> tensor - %2 = stablehlo.reshape %1 : (tensor) -> tensor<1xi32> - %3 = stablehlo.concatenate %2, %cst_0, %cst_1, %cst_2, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> - %4 = stablehlo.dynamic_broadcast_in_dim %arg2, %3, dims = [3] : (tensor<2xf32>, tensor<4xi32>) -> tensor - %5 = stablehlo.add %0, %4 : tensor - %6 = stablehlo.clamp %cst_3, %5, %cst_4 : (tensor, tensor, tensor) -> tensor - return %6 : tensor - } -} -// CHECK-DAG: %[[CONST_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32> -// CHECK-DAG: %[[CONST_3:.+]] = stablehlo.constant dense<4> : tensor<1xi32> -// CHECK-DAG: %[[CONST_4:.+]] = stablehlo.constant dense<2> : tensor<1xi32> -// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor> -// CHECK: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[CONVOLUTION_0]], dim = 0 : (tensor) -// CHECK: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> -// CHECK: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> -// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [3] : (tensor<2x!quant.uniform>, tensor<4xi32>) -> tensor> -// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> -// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) -> tensor> -// CHECK: return %[[UNIFORM_QUANTIZE_0]] : tensor> - -// ----- - -// Tests that XlaCallModule op is not quantized without the quantfork.stats ops. - -module attributes {tf_saved_model.semantics} { -// CHECK: func.func private @not_quantized_without_stats_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} - func.func private @not_quantized_without_stats_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> - %1 = "tf.XlaCallModule"(%arg0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> - return %1 : tensor<1x3xf32> - } - -// Check that "tf.Const" is converted to stablehlo.constant. XlaCallModule is -// not quantized. -// CHECK: %[[CONST_0:.+]] = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32> -// CHECK: %[[XLA_CALL_MODULE_0:.+]] = "tf.XlaCallModule"(%[[ARG_0]], %[[CONST_0]]) <{{{.*}}}> {{{.*_entry_function = @composite_dot_general_fn.*}}} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> -// CHECK: return %[[XLA_CALL_MODULE_0]] - -// CHECK: func.func private @composite_dot_general_fn(%[[ARG_1:.+]]: tensor<1x2xf32>, %[[ARG_2:.+]]: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} - func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { - %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> - return %0 : tensor<1x3xf32> - } - -// Check that the composite_dot_general_fn is untouched. -// CHECK: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[ARG_1]], %[[ARG_2]] -// CHECK: return %[[DOT_GENERAL_0]] -} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD index 56161914b45dd4..1d2608599bd93b 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD @@ -406,6 +406,7 @@ cc_library( "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", + "//tensorflow/compiler/mlir/quantization/common:func", "//tensorflow/compiler/mlir/quantization/common:lift_as_function_call", "//tensorflow/compiler/mlir/quantization/tensorflow/cc:const_op_size", "//tensorflow/compiler/mlir/quantization/tensorflow/cc:constant_fold", @@ -479,9 +480,11 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ ":passes", + "//tensorflow/compiler/mlir/lite/stablehlo:fuse_convolution_pass", "//tensorflow/compiler/mlir/lite/stablehlo:legalize_tf_xla_call_module_to_stablehlo_pass", "//tensorflow/compiler/mlir/lite/stablehlo:rename_entrypoint_to_main", "//tensorflow/compiler/mlir/lite/stablehlo:tf_stablehlo", + "//tensorflow/compiler/mlir/lite/stablehlo:unfuse_batch_norm_pass", "//tensorflow/compiler/mlir/quantization/stablehlo:bridge_passes", "//tensorflow/compiler/mlir/quantization/stablehlo:passes", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:pass_pipeline", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD index 01d7a7b37e1907..62a6f27c8ad5f1 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD @@ -99,6 +99,7 @@ cc_library( hdrs = ["convert_asset_args.h"], compatible_with = get_compatible_with_portable(), deps = [ + "//tensorflow/compiler/mlir/quantization/common:func", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:import_model", "//tensorflow/core/protobuf:for_core_protos_cc", @@ -139,9 +140,9 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:error_util", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", - "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], ) @@ -172,13 +173,16 @@ tf_cc_test( srcs = ["constant_fold_test.cc"], deps = [ ":constant_fold", + "//tensorflow/compiler/mlir/quantization/common:test_base", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/core:tensorflow", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", ], ) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold_test.cc index f2a7942323f5f5..5122563c235193 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold_test.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold_test.cc @@ -14,12 +14,19 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold.h" +#include + +#include +#include "absl/strings/string_view.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project -#include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/quantization/common/test_base.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/platform/test.h" @@ -30,38 +37,7 @@ namespace { using ::testing::NotNull; using ::testing::SizeIs; -class ConstantFoldingTest : public ::testing::Test { - protected: - ConstantFoldingTest() { - ctx_.loadDialect(); - } - - // Parses `module_op_str` to create a `ModuleOp`. Checks whether the created - // module op is valid. - OwningOpRef ParseModuleOpString(absl::string_view module_op_str) { - auto module_op_ref = parseSourceString(module_op_str, &ctx_); - EXPECT_TRUE(module_op_ref); - return module_op_ref; - } - - // Gets the function with the given name from the module. - func::FuncOp GetFunctionFromModule(ModuleOp module, - absl::string_view function_name) { - SymbolTable symbol_table(module); - return symbol_table.lookup(function_name); - } - - // Returns the first operation with the given type in the function. - template - OpType FindOperationOfType(func::FuncOp function) { - for (auto op : function.getBody().getOps()) { - return op; - } - return nullptr; - } - - MLIRContext ctx_{}; -}; +class ConstantFoldingTest : public QuantizationTestBase {}; TEST_F(ConstantFoldingTest, FoldLargeConstant) { constexpr absl::string_view kModuleCode = R"mlir( @@ -80,8 +56,10 @@ TEST_F(ConstantFoldingTest, FoldLargeConstant) { )mlir"; OwningOpRef module_op_ref = ParseModuleOpString(kModuleCode); - func::FuncOp test_func = - GetFunctionFromModule(*module_op_ref, "test_fold_constant"); + const auto test_func = + module_op_ref->lookupSymbol("test_fold_constant"); + ASSERT_THAT(test_func, NotNull()); + Operation* mul_op = FindOperationOfType(test_func); SmallVector results = ConstantFoldOpIfPossible(mul_op); EXPECT_THAT(results, SizeIs(1)); @@ -106,8 +84,10 @@ TEST_F(ConstantFoldingTest, NotFoldingIdentity) { )mlir"; OwningOpRef module_op_ref = ParseModuleOpString(kModuleCode); - func::FuncOp test_func = - GetFunctionFromModule(*module_op_ref, "test_fold_constant"); + const auto test_func = + module_op_ref->lookupSymbol("test_fold_constant"); + ASSERT_THAT(test_func, NotNull()); + Operation* op_to_fold = FindOperationOfType(test_func); SmallVector results = ConstantFoldOpIfPossible(op_to_fold); EXPECT_THAT(results, SizeIs(1)); @@ -135,8 +115,10 @@ TEST_F(ConstantFoldingTest, NotFoldingArgument) { )mlir"; OwningOpRef module_op_ref = ParseModuleOpString(kModuleCode); - func::FuncOp test_func = - GetFunctionFromModule(*module_op_ref, "test_fold_constant"); + const auto test_func = + module_op_ref->lookupSymbol("test_fold_constant"); + ASSERT_THAT(test_func, NotNull()); + Operation* op_to_fold = FindOperationOfType(test_func); SmallVector results = ConstantFoldOpIfPossible(op_to_fold); EXPECT_THAT(results, SizeIs(1)); @@ -166,11 +148,12 @@ TEST_F(ConstantFoldingTest, FoldDepthwiseConvWeight) { )mlir"; OwningOpRef module_op_ref = ParseModuleOpString(kModuleCode); - func::FuncOp test_func = - GetFunctionFromModule(*module_op_ref, "test_fold_constant"); + const auto test_func = + module_op_ref->lookupSymbol("test_fold_constant"); + ASSERT_THAT(test_func, NotNull()); - RewritePatternSet patterns(&ctx_); - patterns.add(&ctx_); + RewritePatternSet patterns(ctx_.get()); + patterns.add(ctx_.get()); EXPECT_TRUE( succeeded(applyPatternsAndFoldGreedily(test_func, std::move(patterns)))); @@ -198,11 +181,12 @@ TEST_F(ConstantFoldingTest, DepthwiseConvWeightNotFoldable) { )mlir"; OwningOpRef module_op_ref = ParseModuleOpString(kModuleCode); - func::FuncOp test_func = - GetFunctionFromModule(*module_op_ref, "test_fold_constant"); + const auto test_func = + module_op_ref->lookupSymbol("test_fold_constant"); + ASSERT_THAT(test_func, NotNull()); - RewritePatternSet patterns(&ctx_); - patterns.add(&ctx_); + RewritePatternSet patterns(ctx_.get()); + patterns.add(ctx_.get()); EXPECT_TRUE( succeeded(applyPatternsAndFoldGreedily(test_func, std::move(patterns)))); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.cc index d88a1fe42cc555..8e2b3537b34d85 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/func.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" @@ -38,20 +39,6 @@ using ::mlir::tf_saved_model::LookupBoundInputOfType; using ::tensorflow::AssetFileDef; using ::tensorflow::kImportModelDefaultGraphFuncName; -// Gets the "main" function from the module. Returns an empty op iff it doesn't -// exist. -func::FuncOp GetMainFunction(ModuleOp module_op) { - const auto main_func_id = - StringAttr::get(module_op.getContext(), kImportModelDefaultGraphFuncName); - auto func_ops = module_op.getOps(); - auto main_func_itr = absl::c_find_if(func_ops, [&main_func_id](auto func_op) { - return func_op.getName() == main_func_id; - }); - - if (main_func_itr == func_ops.end()) return {}; - return *main_func_itr; -} - // Given argument attributes `arg_attrs`, returns a new set of argument // attributes where the "tf_saved_model.bound_input" attribute has been replaced // with the "tf_saved_model.index_path" attribute. `index_path` is the element @@ -130,7 +117,7 @@ void ConvertMainArgAttrs(func::FuncOp main_func_op, const int arg_idx, } // namespace FailureOr> ConvertAssetArgs(ModuleOp module_op) { - func::FuncOp main_func_op = GetMainFunction(module_op); + func::FuncOp main_func_op = FindMainFuncOp(module_op); if (!main_func_op) return failure(); SmallVector input_names = GetEntryFunctionInputs(main_func_op); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.cc index ef63e75e52e8c9..ee7ae1a4c6d90a 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.cc @@ -21,6 +21,7 @@ limitations under the License. #include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tsl/platform/errors.h" namespace tensorflow { namespace quantization { @@ -33,11 +34,8 @@ absl::Status RunPassesOnModuleOp( absl::StatusOr> dump_file; if (mlir_dump_file_name) { - dump_file = tensorflow::quantization::MaybeEnableIrPrinting( - pass_manager, mlir_dump_file_name.value()); - if (!dump_file.ok()) { - return dump_file.status(); - } + TF_RETURN_IF_ERROR(tensorflow::quantization::MaybeEnableIrPrinting( + pass_manager, mlir_dump_file_name.value())); } if (failed(pass_manager.run(module_op))) { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h b/tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h index 35f243a9f9626d..06db2acb7b057f 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace tensorflow { @@ -47,8 +48,7 @@ absl::Status RunPasses(const absl::string_view name, FuncT add_passes_func, add_passes_func(pm); mlir::StatusScopedDiagnosticHandler diagnostic_handler{&ctx}; - TF_ASSIGN_OR_RETURN(const std::unique_ptr out_dump_file, - MaybeEnableIrPrinting(pm, name)); + TF_RETURN_IF_ERROR(MaybeEnableIrPrinting(pm, name)); if (failed(pm.run(module_op))) { return absl::InternalError( diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD index fa55ce0ba1e391..b465fe15e8d57c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD @@ -19,16 +19,21 @@ cc_library( hdrs = ["mlir_dump.h"], compatible_with = get_compatible_with_portable(), deps = [ + "@com_google_absl//absl/log", + "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:stringpiece", ], ) @@ -39,14 +44,19 @@ tf_cc_test( deps = [ ":mlir_dump", "@com_google_absl//absl/cleanup", - "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", + "@stablehlo//:stablehlo_ops", ], ) @@ -57,7 +67,7 @@ tf_kernel_library( deps = [ "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/core:framework", - "//tensorflow/core:portable_gif_internal", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:path", "@com_google_absl//absl/log", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/dump_tensor_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/dump_tensor_op.cc index ce52e488774f60..098adcb4586853 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/dump_tensor_op.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/dump_tensor_op.cc @@ -25,6 +25,8 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/io/compression.h" +#include "tensorflow/core/lib/io/record_writer.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/types.h" #include "tsl/platform/env.h" @@ -72,7 +74,18 @@ class DumpTensorOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("node_name", &node_name)); OP_REQUIRES_OK(ctx, ctx->env()->RecursivelyCreateDir(log_dir_path)); - tensor_data_path_ = io::JoinPath(log_dir_path, file_name); + std::string tensor_data_path = io::JoinPath(log_dir_path, file_name); + OP_REQUIRES_OK( + ctx, ctx->env()->NewWritableFile(tensor_data_path, &tensor_data_file_)); + + // Turn on Zlib compression. + io::RecordWriterOptions options = + io::RecordWriterOptions::CreateRecordWriterOptions( + io::compression::kZlib); + tensor_data_writer_ = + std::make_unique(tensor_data_file_.get(), options); + OP_REQUIRES(ctx, tensor_data_writer_ != nullptr, + absl::AbortedError("Could not create record writer")); // Fetch func_name and node_name from attributes and save as proto. quantization::UnitWiseQuantizationSpec::QuantizationUnit quant_unit_proto; @@ -80,28 +93,33 @@ class DumpTensorOp : public OpKernel { quant_unit_proto.set_node_name(node_name); string quant_unit_path = io::JoinPath(log_dir_path, "quant_unit.pb"); - OP_REQUIRES_OK( ctx, SaveSerializedProtoToFile(quant_unit_proto.SerializeAsString(), quant_unit_path, ctx->env())); } + ~DumpTensorOp() override { + (void)tensor_data_writer_->Flush(); + (void)tensor_data_writer_->Close(); + (void)tensor_data_file_->Close(); + } + void Compute(OpKernelContext* ctx) override { - if (enabled_) { - const Tensor& tensor_data = ctx->input(0); + if (!enabled_) return; + + const Tensor& tensor_data = ctx->input(0); - TensorProto tensor_proto; - tensor_data.AsProtoTensorContent(&tensor_proto); + TensorProto tensor_proto; + tensor_data.AsProtoTensorContent(&tensor_proto); - OP_REQUIRES_OK(ctx, - SaveSerializedProtoToFile(tensor_proto.SerializeAsString(), - tensor_data_path_, ctx->env())); - } + OP_REQUIRES_OK(ctx, tensor_data_writer_->WriteRecord( + tensor_proto.SerializeAsString())); } private: - std::string tensor_data_path_; bool enabled_; + std::unique_ptr tensor_data_file_; + std::unique_ptr tensor_data_writer_; }; REGISTER_KERNEL_BUILDER(Name("DumpTensor").Device(DEVICE_CPU), DumpTensorOp); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.cc b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.cc index 1d1a54e5a88db5..51984bcd000feb 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.cc @@ -14,22 +14,34 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h" +#include #include #include #include +#include +#include "absl/log/log.h" +#include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_format.h" +#include "absl/strings/match.h" #include "absl/strings/string_view.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "tsl/platform/env.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/file_system.h" #include "tsl/platform/path.h" #include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/stringpiece.h" namespace tensorflow { namespace quantization { @@ -63,39 +75,138 @@ absl::StatusOr GetMlirDumpDir() { return dump_dir; } +// A simple wrapper of tsl::WritableFile so that mlir Pass infra can use it. +class WritableFileWrapper : public llvm::raw_ostream { + public: + ~WritableFileWrapper() override { flush(); } + static absl::StatusOr> Create( + const std::string& filepath) { + std::unique_ptr file; + TF_RETURN_IF_ERROR(tsl::Env::Default()->NewWritableFile(filepath, &file)); + return absl::WrapUnique(new WritableFileWrapper(std::move(file))); + } + + private: + explicit WritableFileWrapper(std::unique_ptr file) + : file_(std::move(file)) { + SetBuffered(); + } + + uint64_t current_pos() const override { + int64_t position; + if (file_->Tell(&position).ok()) { + return position; + } else { + return -1; + } + } + + void write_impl(const char* ptr, size_t size) override { + if (file_ && !file_->Append(tsl::StringPiece(ptr, size)).ok()) { + file_ = nullptr; + } + } + + std::unique_ptr file_; +}; + // Creates a new file to dump the intermediate MLIRs by prefixing the // `dump_file_name` with the value of the TF_QUANT_MLIR_DUMP_PREFIX env // variable. Returns absl::FailedPreconditionError if the env variable is not // set or set to an empty string. -absl::StatusOr> CreateMlirDumpFile( +absl::StatusOr> CreateMlirDumpFile( const absl::string_view dump_file_name) { const absl::StatusOr dump_dir = GetMlirDumpDir(); if (!dump_dir.ok()) { return dump_dir.status(); } - auto *env = tsl::Env::Default(); - const tsl::Status status = env->RecursivelyCreateDir(*dump_dir); - if (!status.ok()) { - return status; - } + auto* env = tsl::Env::Default(); + TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(*dump_dir)); - std::error_code ec{}; // NOLINT: Required to create llvm::raw_fd_ostream const std::string dump_file_path = tsl::io::JoinPath(*dump_dir, dump_file_name); - auto dump_file = std::make_unique(dump_file_path, ec); - if (ec) { - return absl::InternalError(absl::StrFormat( - "Unable to open file: %s, error: %s", dump_file_path, ec.message())); - } + TF_ASSIGN_OR_RETURN(std::unique_ptr file, + WritableFileWrapper::Create(dump_file_path)); LOG(INFO) << "IR dump file created: " << dump_file_path; - return dump_file; + return file; } +class PrinterConfig : public mlir::PassManager::IRPrinterConfig { + public: + explicit PrinterConfig( + absl::string_view dump_file_prefix, bool print_module_scope = false, + bool print_after_only_on_change = true, + mlir::OpPrintingFlags op_printing_flags = mlir::OpPrintingFlags()) + : mlir::PassManager::IRPrinterConfig( + print_module_scope, print_after_only_on_change, + /*printAfterOnlyOnFailure=*/false, op_printing_flags), + mlir_pass_count_(1), + dump_file_prefix_(dump_file_prefix) {} + + void printBeforeIfEnabled(mlir::Pass* pass, mlir::Operation* op, + PrintCallbackFn print_callback) override { + Dump(pass, print_callback, /*is_before=*/true); + } + + void printAfterIfEnabled(mlir::Pass* pass, mlir::Operation* op, + PrintCallbackFn print_callback) override { + Dump(pass, print_callback, /*is_before=*/false); + } + + private: + int64_t mlir_pass_count_; + absl::string_view dump_file_prefix_; + // Map from pass ptr to dump files and pass number. + // + // Each pass has unique and stable pointer, even for passes with the same + // name. E.g. a PassManager could have multiple Canonicalizer passes. + // We use this property to uniquely determine a Pass in a PassManager. + // + // If multiple consecutive func passes are applied to a Module. PassManager + // will iterate over the func in the outer loop and apply the passes in the + // inner loop. This may cause passes to run out-of-order. But the 1st runs of + // each pass are still in-order. So we use pass_to_number_map_ to keep track + // of the number for each pass. + llvm::DenseMap> + pass_to_dump_file_before_map_; + llvm::DenseMap> + pass_to_dump_file_after_map_; + llvm::DenseMap pass_to_number_map_; + + // Get the unique number for each pass. + int64_t GetPassNumber(mlir::Pass* pass) { + if (!pass_to_number_map_.contains(pass)) { + pass_to_number_map_[pass] = mlir_pass_count_++; + } + return pass_to_number_map_[pass]; + } + + void Dump(mlir::Pass* pass, PrintCallbackFn print_callback, bool is_before) { + auto& pass_to_dump_file_map = is_before ? pass_to_dump_file_before_map_ + : pass_to_dump_file_after_map_; + if (!pass_to_dump_file_map.contains(pass)) { + std::string filename = llvm::formatv( + "{0}_{1,0+4}_{2}_{3}.mlir", dump_file_prefix_, GetPassNumber(pass), + pass->getName().str(), is_before ? "before" : "after"); + absl::StatusOr> dump_file = + CreateMlirDumpFile(filename); + if (!dump_file.ok()) { + LOG(WARNING) << "Failed to dump MLIR module to " << filename; + return; + } + pass_to_dump_file_map[pass] = std::move(*dump_file); + } + + return print_callback(*(pass_to_dump_file_map[pass])); + } +}; + } // namespace -void EnableIrPrinting(llvm::raw_ostream &out_stream, mlir::PassManager &pm) { +void EnableIrPrinting(mlir::PassManager& pm, + absl::string_view file_name_prefix) { mlir::OpPrintingFlags flag{}; flag.useLocalScope().elideLargeElementsAttrs().enableDebugInfo(); @@ -112,39 +223,23 @@ void EnableIrPrinting(llvm::raw_ostream &out_stream, mlir::PassManager &pm) { // `PassManager::enableIRPrinting`, except for the `printModuleScope` // parameter, which is true by default. It is set to false to avoid the dump // file size becoming too large when the passes are running on a large model. - pm.enableIRPrinting( - /*shouldPrintBeforePass=*/[](mlir::Pass *, - mlir::Operation *) { return true; }, - /*shouldPrintAfterPass=*/ - [](mlir::Pass *, mlir::Operation *) { return true; }, - /*printModuleScope=*/false, /*printAfterOnlyOnChange=*/true, - /*printAfterOnlyOnFailure=*/false, out_stream, flag); - - LOG(INFO) << "IR dump for TensorFlow quantization pipeline enabled."; + pm.enableIRPrinting(std::make_unique( + file_name_prefix, /*print_module_scope=*/false, + /*print_after_only_on_change=*/true, flag)); } // TODO(b/259374854): Create tests for MaybeEnableIrPrinting. -absl::StatusOr> MaybeEnableIrPrinting( - mlir::PassManager &pm, const absl::string_view name) { +absl::Status MaybeEnableIrPrinting(mlir::PassManager& pm, + absl::string_view file_name_prefix) { if (!VLOG_IS_ON(1)) { LOG(INFO) << "Verbosity level too low to enable IR printing."; - return nullptr; + return absl::OkStatus(); } - absl::StatusOr> dump_file = - CreateMlirDumpFile(/*dump_file_name=*/absl::StrCat(name, ".mlir")); - if (absl::IsFailedPrecondition(dump_file.status())) { - // Requirements for enabling IR dump are not met. IR printing will not be - // enabled. - LOG(WARNING) << dump_file.status(); - return nullptr; - } else if (!dump_file.ok()) { - return dump_file.status(); - } - - EnableIrPrinting(**dump_file, pm); + EnableIrPrinting(pm, file_name_prefix); - return dump_file; + LOG(INFO) << "IR dump for TensorFlow quantization pipeline enabled."; + return absl::OkStatus(); } } // namespace quantization diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h index 803cd39a0a5bae..38a9c4fae4f912 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h @@ -15,27 +15,29 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_DEBUGGING_MLIR_DUMP_H_ #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_DEBUGGING_MLIR_DUMP_H_ -#include - -#include "absl/status/statusor.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" -#include "llvm/Support/raw_ostream.h" #include "mlir/Pass/PassManager.h" // from @llvm-project namespace tensorflow { namespace quantization { -// Enables IR printing for `pm`. When the passes are run, the IRs will be dumped -// to `out_stream`. -void EnableIrPrinting(llvm::raw_ostream &out_stream, mlir::PassManager &pm); +// Enables IR printing for `pm`. When the passes are run, each pass will dump to +// its own file with prefix `file_name_prefix`. +void EnableIrPrinting(mlir::PassManager &pm, + absl::string_view file_name_prefix); // If verbosity level >= 1, this will dump intermediate IRs of passes to a file. -// The file path is given by prefixing `name`.mlir with the value of the -// TF_QUANT_MLIR_DUMP_PREFIX env variable. Returns `nullptr` iff the verbosity -// level < 1 or TF_QUANT_MLIR_DUMP_PREFIX is not set or set to an empty string. -// The returned ostream instance should live until the pass run is complete. -absl::StatusOr> MaybeEnableIrPrinting( - mlir::PassManager &pm, absl::string_view name); +// The dumped mlir files with be under a directory determined by +// the TF_QUANT_MLIR_DUMP_PREFIX env variable. The PassManager will dump to a +// new file for each pass. The file name will have the format +// {file_name_prefix}_{pass_number}_{pass_name}_{before|after}.mlir. +// * `file_name_prefix` is from input. +// * `pass_number` increments from 1 for each pass. +// * `pass_name` is the name of the pass. +// * `before|after` indicates whether the dump occurs before or after the pass. +absl::Status MaybeEnableIrPrinting(mlir::PassManager &pm, + absl::string_view file_name_prefix); } // namespace quantization } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump_test.cc index a7162d9a05a4ff..c3034f4294b13d 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump_test.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump_test.cc @@ -16,23 +16,30 @@ limitations under the License. #include #include +#include #include "absl/cleanup/cleanup.h" -#include "absl/status/statusor.h" #include "llvm/ADT/StringRef.h" -#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinDialect.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/env.h" #include "tsl/platform/path.h" #include "tsl/platform/test.h" namespace tensorflow { namespace quantization { -namespace { +namespace mlir_dump_test { class NoOpPass : public mlir::PassWrapper> { @@ -69,12 +76,7 @@ class ParentPass pm.addPass(CreateNoOpPass()); - std::error_code ec{}; // NOLINT: Required to create llvm::raw_fd_ostream - const std::string tmp_dump_filename = - tsl::io::GetTempFilename(/*extension=*/".mlir"); - llvm::raw_fd_ostream dump_file{tmp_dump_filename, ec}; - - EnableIrPrinting(dump_file, pm); + EnableIrPrinting(pm, "dump2"); if (failed(pm.run(module_op))) { signalPassFailure(); @@ -86,41 +88,88 @@ std::unique_ptr> CreateParentPass() { return std::make_unique(); } -TEST(EnableIrPrintingTest, PassSuccessfullyRuns) { - mlir::MLIRContext ctx{}; +} // namespace mlir_dump_test - mlir::PassManager pm = {&ctx}; - pm.addPass(CreateNoOpPass()); +namespace { - std::error_code ec{}; // NOLINT: Required to create llvm::raw_fd_ostream - const std::string tmp_dump_filename = - tsl::io::GetTempFilename(/*extension=*/".mlir"); - llvm::raw_fd_ostream dump_file{tmp_dump_filename, ec}; +using namespace tensorflow::quantization::mlir_dump_test; - EnableIrPrinting(dump_file, pm); +class EnableIrPrintingTest : public ::testing::Test { + protected: + EnableIrPrintingTest() : env_(tsl::Env::Default()) { + if (!tsl::io::GetTestUndeclaredOutputsDir(&test_dir_)) { + test_dir_ = tsl::testing::TmpDir(); + } + } - mlir::OpBuilder builder(&ctx); - auto module_op = builder.create(builder.getUnknownLoc()); - // Destroy by calling destroy() to avoid memory leak since it is allocated - // with malloc(). - const absl::Cleanup module_op_cleanup = [module_op] { module_op->destroy(); }; + void SetUp() override { + tsl::setenv("TF_QUANT_MLIR_DUMP_PREFIX", test_dir_.c_str(), 1); - const mlir::LogicalResult result = pm.run(module_op); + mlir::DialectRegistry dialects; + dialects.insert(); + ctx_ = std::make_unique(dialects); + ctx_->loadAllAvailableDialects(); + } + + void TearDown() override { + // Delete files in the test directory. + std::vector files; + TF_ASSERT_OK( + env_->GetMatchingPaths(tsl::io::JoinPath(test_dir_, "*"), &files)); + for (const std::string& file : files) { + TF_ASSERT_OK(env_->DeleteFile(file)); + } + } + + tsl::Env* env_; + std::string test_dir_; + std::unique_ptr ctx_; +}; + +TEST_F(EnableIrPrintingTest, PassSuccessfullyRuns) { + mlir::PassManager pm = {ctx_.get()}; + pm.addPass(CreateNoOpPass()); + pm.addNestedPass(mlir::createCanonicalizerPass()); + pm.addNestedPass(mlir::createCanonicalizerPass()); + + EnableIrPrinting(pm, "dump"); + + constexpr absl::string_view program = R"mlir( +module{ + func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> { + return %arg0 : tensor<10xf32> + } + func.func @func1(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -> tensor<10xf32> { + %0 = stablehlo.add %arg0, %arg1 : tensor<10xf32> + %1 = stablehlo.add %arg0, %arg1 : tensor<10xf32> + return %0 : tensor<10xf32> + } +})mlir"; + auto module_op = mlir::parseSourceString(program, ctx_.get()); + + const mlir::LogicalResult result = pm.run(module_op.get()); EXPECT_FALSE(failed(result)); + + TF_EXPECT_OK(tsl::Env::Default()->FileExists( + tsl::io::JoinPath(test_dir_, + "dump_0001_tensorflow::quantization::mlir_dump_test" + "::NoOpPass_before.mlir"))); + TF_EXPECT_OK(tsl::Env::Default()->FileExists( + tsl::io::JoinPath(test_dir_, "dump_0002_Canonicalizer_before.mlir"))); + TF_EXPECT_OK(tsl::Env::Default()->FileExists( + tsl::io::JoinPath(test_dir_, "dump_0002_Canonicalizer_after.mlir"))); + TF_EXPECT_OK(tsl::Env::Default()->FileExists( + tsl::io::JoinPath(test_dir_, "dump_0003_Canonicalizer_before.mlir"))); } -TEST(EnableNestedIrPrintingTest, PassSuccessfullyRuns) { +TEST_F(EnableIrPrintingTest, NestedPassSuccessfullyRuns) { mlir::MLIRContext ctx{}; mlir::PassManager pm = {&ctx}; pm.addPass(CreateParentPass()); - std::error_code ec{}; // NOLINT: Required to create llvm::raw_fd_ostream - const std::string tmp_dump_filename = - tsl::io::GetTempFilename(/*extension=*/".mlir"); - llvm::raw_fd_ostream dump_file{tmp_dump_filename, ec}; - - EnableIrPrinting(dump_file, pm); + EnableIrPrinting(pm, "dump"); mlir::OpBuilder builder(&ctx); auto module_op = builder.create(builder.getUnknownLoc()); @@ -130,6 +179,15 @@ TEST(EnableNestedIrPrintingTest, PassSuccessfullyRuns) { const mlir::LogicalResult result = pm.run(module_op); EXPECT_FALSE(failed(result)); + + TF_EXPECT_OK(tsl::Env::Default()->FileExists( + tsl::io::JoinPath(test_dir_, + "dump_0001_tensorflow::quantization::mlir_dump_test" + "::ParentPass_before.mlir"))); + TF_EXPECT_OK(tsl::Env::Default()->FileExists( + tsl::io::JoinPath(test_dir_, + "dump2_0001_tensorflow::quantization::mlir_dump_test" + "::NoOpPass_before.mlir"))); } } // namespace } // namespace quantization diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/ops/BUILD index fa201ff6a716bc..7042b6c5b17cdb 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/ops/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/BUILD @@ -20,8 +20,8 @@ cc_library( "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "@com_google_absl//absl/container:flat_hash_set", + "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", ], diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc index 3a81ff0dfd2c91..fb13da8489a81c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc @@ -14,17 +14,17 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h" -#include -#include #include #include #include #include "absl/container/flat_hash_set.h" +#include "llvm/Support/Casting.h" #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc index 595534433849ee..daf256734110ee 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc @@ -39,6 +39,7 @@ limitations under the License. #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h" @@ -54,7 +55,56 @@ namespace { using DebuggerType = tensorflow::quantization::DebuggerOptions::DebuggerType; using DebuggerOptions = tensorflow::quantization::DebuggerOptions; +constexpr StringRef kEntryFuncAttrName = "_entry_function"; +constexpr StringRef kOriginalEntryFuncAttrName = "_original_entry_function"; constexpr StringRef kCompositeFuncPrefix = "composite_"; +constexpr StringRef kEmptyNodeName = "_empty_node"; + +// Returns a pair: `func_name` and `node_name` for the lifted function. In TF +// quantizer, both are filled. For StableHLO quantizer, the func_name is only +// filled and node_name is always set to "_empty_node". +std::pair GetFuncNameAndNodeName( + TF::PartitionedCallOp call_op, const FlatSymbolRefAttr &f_attr) { + std::optional quant_unit = + FindQuantizationUnitFromLoc(call_op->getLoc()); + return std::make_pair(quant_unit->func_name(), quant_unit->node_name()); +} + +std::pair GetFuncNameAndNodeName( + TF::XlaCallModuleOp call_op, const FlatSymbolRefAttr &f_attr) { + return std::make_pair(f_attr.getValue().str(), kEmptyNodeName.str()); +} + +Operation *DuplicateOp(TF::PartitionedCallOp call_op, PatternRewriter &rewriter, + const StringAttr &new_ref_func_name) { + // Create PartitionedCallOp to the copied composite function. This + // PartitionedCallOp does not have kQuantTraitAttrName, and therefore won't + // get quantized. + auto new_call_op = rewriter.create( + call_op.getLoc(), call_op.getResultTypes(), call_op.getOperands(), + FlatSymbolRefAttr::get(new_ref_func_name)); + return new_call_op; +} + +Operation *DuplicateOp(TF::XlaCallModuleOp call_op, PatternRewriter &rewriter, + const StringAttr &new_ref_func_name) { + // Create XlaCallModuleOp to the copied composite function. This + // XlaCallModuleOp does not have kQuantTraitAttrName, and therefore won't get + // quantized. + auto new_call_op = rewriter.create( + call_op.getLoc(), call_op.getResultTypes(), call_op.getOperands(), + call_op.getVersionAttr(), call_op.getModuleAttr(), call_op.getSoutAttr()); + new_call_op->setAttr(kEntryFuncAttrName, + rewriter.getStringAttr(new_ref_func_name.getValue())); + new_call_op->setAttrs(call_op->getAttrs()); + new_call_op->removeAttr(rewriter.getStringAttr(kQuantTraitAttrName)); + + FlatSymbolRefAttr new_func_name_attr = + FlatSymbolRefAttr::get(rewriter.getContext(), new_ref_func_name); + new_call_op->setAttr(kEntryFuncAttrName, new_func_name_attr); + new_call_op->setAttr(kOriginalEntryFuncAttrName, new_ref_func_name); + return new_call_op; +} // AddDumpTensorOp pass adds DumpTensorOp - which saves entire value of its // input into a file - to quantizable layer's output. @@ -110,49 +160,66 @@ class AddDumpTensorOpPass std::string log_dir_path_ = "/tmp/dumps"; }; -class AddDumpTensorOp : public OpRewritePattern { +template +class AddDumpTensorOp : public OpRewritePattern { public: // Does not take ownership of context, which must refer to a valid value that // outlives this object. explicit AddDumpTensorOp(MLIRContext *context, DebuggerType debugger_type, std::string log_dir_path) - : OpRewritePattern(context), + : OpRewritePattern(context), debugger_type_(debugger_type), log_dir_path_(std::move(log_dir_path)) {} private: - DebuggerType debugger_type_; - std::string log_dir_path_; + SmallVector CreateDumpAttributes( + PatternRewriter &rewriter, const StringRef folder_name, + const StringRef file_name, const bool enabled, const StringRef func_name, + const StringRef node_name) const { + SmallVector dump_attributes{ + rewriter.getNamedAttr("log_dir_path", + rewriter.getStringAttr(folder_name)), + rewriter.getNamedAttr("file_name", rewriter.getStringAttr(file_name)), + // The op is disabled by default. Otherwise, values will be saved + // during calibration. + rewriter.getNamedAttr("enabled", rewriter.getBoolAttr(false)), + rewriter.getNamedAttr("func_name", rewriter.getStringAttr(func_name)), + rewriter.getNamedAttr("node_name", rewriter.getStringAttr(node_name)), + }; + return dump_attributes; + } - LogicalResult matchAndRewrite(TF::PartitionedCallOp call_op, - PatternRewriter &rewriter) const override { - const auto f_attr = call_op.getFAttr().dyn_cast(); - if (!call_op->hasAttr(kQuantTraitAttrName)) { - return failure(); - } - if (!f_attr.getValue().starts_with(kCompositeFuncPrefix)) { - return failure(); - } + StringAttr DuplicateFunction(Operation *op, + const FlatSymbolRefAttr &f_attr) const { + ModuleOp module = op->getParentOfType(); + SymbolTable symbol_table(module); - // For now, only support ops with 1 results - if (call_op->getNumResults() != 1) return failure(); + const func::FuncOp ref_func = + dyn_cast_or_null(symbol_table.lookup(f_attr.getValue())); + func::FuncOp new_ref_func = dyn_cast(ref_func->clone()); + return symbol_table.insert(new_ref_func); + } - Value result = call_op->getResult(0); + LogicalResult match(LiftedOpT op) const override { + if (!op->hasAttr(kQuantTraitAttrName) || op->getNumResults() != 1) { + return failure(); + } - // If one of the user is DumpTensorOp, do nothing + Value result = op->getResult(0); for (auto user : result.getUsers()) { if (dyn_cast_or_null(user)) return failure(); } - rewriter.setInsertionPointAfterValue(result); - - std::optional quant_unit = - FindQuantizationUnitFromLoc(call_op->getLoc()); + const FlatSymbolRefAttr f_attr = GetFuncAttr(op); + if (!f_attr.getValue().starts_with(kCompositeFuncPrefix)) return failure(); + return success(); + } - if (!quant_unit.has_value()) return failure(); + void rewrite(LiftedOpT op, PatternRewriter &rewriter) const override { + // Only support ops with 1 results + Value result = op->getResult(0); + rewriter.setInsertionPointAfterValue(result); - auto folder_name = - tensorflow::io::JoinPath(log_dir_path_, f_attr.getValue()); // In Whole model, we first need to set file_name as // unquantized_tensor_data.pb as it is used by unquantized dump model. // After saving unquantized dump model, the file name will be changed to @@ -161,77 +228,56 @@ class AddDumpTensorOp : public OpRewritePattern { // as quantized_tensor_data.pb here. // TODO: b/296933893 - Refactor the debugger code when no quantize option // is added - auto file_name = + std::string file_name = debugger_type_ == DebuggerOptions::DEBUGGER_TYPE_WHOLE_MODEL ? "unquantized_tensor_data.pb" : "quantized_tensor_data.pb"; - SmallVector dump_attributes{ - rewriter.getNamedAttr("log_dir_path", - rewriter.getStringAttr(folder_name)), - rewriter.getNamedAttr("file_name", rewriter.getStringAttr(file_name)), - // The op is disabled by default. Otherwise, values will be saved - // during calibration. - rewriter.getNamedAttr("enabled", rewriter.getBoolAttr(false)), - rewriter.getNamedAttr("func_name", - rewriter.getStringAttr(quant_unit->func_name())), - rewriter.getNamedAttr("node_name", - rewriter.getStringAttr(quant_unit->node_name())), - }; + const FlatSymbolRefAttr f_attr = GetFuncAttr(op); - rewriter.create(call_op->getLoc(), TypeRange{}, result, + // In TF::PartitionedCallOp case, func_name and node_name are filled. + // But in TF::XlaCallModuleOp case, node_name is `kEmptyNodeName` since + // debugging and selective quantization of StableHLO Quantizer only uses + // func_name for op matching. + auto [func_name, node_name] = GetFuncNameAndNodeName(op, f_attr); + std::string folder_name = + tensorflow::io::JoinPath(log_dir_path_, f_attr.getValue()); + + // Attach DumpTensorOp to its output layer. + SmallVector dump_attributes = + CreateDumpAttributes(rewriter, folder_name, file_name, + /*enabled=*/false, func_name, node_name); + rewriter.create(op->getLoc(), TypeRange{}, result, dump_attributes); // Per-layer mode. if (debugger_type_ == DebuggerOptions::DEBUGGER_TYPE_INT_PER_LAYER || debugger_type_ == DebuggerOptions::DEBUGGER_TYPE_FLOAT_PER_LAYER) { - auto module = call_op->getParentOfType(); - SymbolTable symbol_table(module); - - // Copy composite function of quantizable layer. - const mlir::func::FuncOp ref_func = dyn_cast_or_null( - symbol_table.lookup(f_attr.getValue())); - mlir::func::FuncOp new_ref_func = - dyn_cast(ref_func->clone()); - const StringAttr new_ref_func_name = symbol_table.insert(new_ref_func); - - // Create PartitionedCallOp to the copied composite function. - // This PartitionedCallOp does not have kQuantTraitAttrName, and therefore - // won't get quantized. - auto ref_call_op = rewriter.create( - call_op.getLoc(), call_op.getResultTypes(), call_op.getOperands(), - FlatSymbolRefAttr::get(new_ref_func_name)); - - // Attach DumpTensorOp to its output unquantized layer. - SmallVector dump_attributes{ - rewriter.getNamedAttr("log_dir_path", - rewriter.getStringAttr(folder_name)), - rewriter.getNamedAttr("file_name", rewriter.getStringAttr( - "unquantized_tensor_data.pb")), - rewriter.getNamedAttr("enabled", rewriter.getBoolAttr(false)), - rewriter.getNamedAttr( - "func_name", rewriter.getStringAttr(quant_unit->func_name())), - rewriter.getNamedAttr( - "node_name", rewriter.getStringAttr(quant_unit->node_name())), - }; - - rewriter.create(call_op->getLoc(), TypeRange{}, - ref_call_op.getResult(0), - dump_attributes); + // Duplicate composite function and op of quantizable layer for creating + // unquantized layer. + StringAttr new_ref_func_name = DuplicateFunction(op, f_attr); + Operation *new_op = DuplicateOp(op, rewriter, new_ref_func_name); + + // Attach second DumpTensorOp to its output unquantized layer. + SmallVector dump_attributes = CreateDumpAttributes( + rewriter, folder_name, /*file_name=*/"unquantized_tensor_data.pb", + /*enabled=*/false, func_name, node_name); + rewriter.create(op.getLoc(), TypeRange{}, + new_op->getResult(0), dump_attributes); if (debugger_type_ == DebuggerOptions::DEBUGGER_TYPE_FLOAT_PER_LAYER) { // Swap all uses between call_op and ref_call_op, except for the // particular use that owns DumpTensor. rewriter.replaceUsesWithIf( - call_op.getResult(0), ref_call_op.getResult(0), - [](OpOperand &use) -> bool { + op.getResult(0), new_op->getResult(0), [](OpOperand &use) -> bool { return !isa(use.getOwner()); }); } } - - return success(); } + + DebuggerType debugger_type_; + std::string log_dir_path_; }; static PassRegistration pass; @@ -241,7 +287,10 @@ void AddDumpTensorOpPass::runOnOperation() { RewritePatternSet patterns(ctx); ModuleOp module = getOperation(); - patterns.add(ctx, debugger_type_, log_dir_path_); + patterns.add, + AddDumpTensorOp>(ctx, debugger_type_, + log_dir_path_); + if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) { module.emitError() << "quant-add-dump-tensor-op failed."; signalPassFailure(); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_initializer_function_ops_to_main.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_initializer_function_ops_to_main.cc index 6e94beb6b0a057..f1f65a1a183371 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_initializer_function_ops_to_main.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_initializer_function_ops_to_main.cc @@ -35,6 +35,7 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/func.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/manipulate_model_attr.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -107,20 +108,6 @@ class MergeInitializerFunctionOpsToMainPass } }; -// Gets the "main" function from the module. Returns an empty op iff it doesn't -// exist. -func::FuncOp GetMainFunction(ModuleOp module_op) { - const auto main_func_id = - StringAttr::get(module_op.getContext(), kImportModelDefaultGraphFuncName); - auto func_ops = module_op.getOps(); - auto main_func_itr = absl::c_find_if(func_ops, [&main_func_id](auto func_op) { - return func_op.getName() == main_func_id; - }); - - if (main_func_itr == func_ops.end()) return {}; - return *main_func_itr; -} - // Returns true iff func_op has either no Region or the body has no Blocks. bool IsFuncOpEmpty(func::FuncOp func_op) { return func_op->getNumRegions() == 0 || func_op.getBody().empty(); @@ -336,7 +323,7 @@ void MergeInitializerFunctionOpsToMainPass::runOnOperation() { ModuleOp module_op = getOperation(); MLIRContext* ctx = module_op.getContext(); - func::FuncOp main_func_op = GetMainFunction(module_op); + func::FuncOp main_func_op = FindMainFuncOp(module_op); if (!main_func_op) { module_op.emitError("Main function op not found."); return signalPassFailure(); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc index 886a27011b1825..ebdd374288a065 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/DialectRegistry.h" // from @llvm-project @@ -120,6 +121,27 @@ bool HasEqualElementSize(Value val1, Value val2, ArrayRef val1_indices, return val1_result == val2_result; } +// Checks if a shape has dim sizes of all ones except the right most dim. +bool ReshapableTo1DTensor(ShapedType rhs_shape) { + for (auto rank = 0; rank < rhs_shape.getRank() - 1; rank++) { + if (rhs_shape.getDimSize(rank) != 1) { + return false; + } + } + return true; +} + +Value ReshapeTo1DTensor(OpBuilder& builder, Location loc, Value value) { + auto shape = value.getType().cast(); + if (shape.getRank() != 1) { + SmallVector new_shape; + new_shape.push_back(shape.getNumElements()); + value = builder.create( + loc, value, Create1DConstValue(builder, loc, new_shape)); + } + return ConstantFoldOpIfPossible(value.getDefiningOp()).front(); +} + // Matches convolution op with "NHWC" data format or matmul op with false adj_y. // The list of supported ops in this function is: // - Conv2DOp diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td index 30e298dd6e7048..d75a01be7d2182 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td @@ -82,6 +82,13 @@ class HasEqualElementSize shape_1, list shape_2> : Constraint< "llvm::ArrayRef({" # !interleave(shape_2, ", ") # "}))">, "Checks if the given dimensions contain the same number of elements.">; +def ReshapableTo1DTensor : Constraint< + CPred<"quant::ReshapableTo1DTensor($0.getType().cast())">, + "Checks if the value dims are all ones except the right most dim">; + +def ReshapeTo1DTensor : NativeCodeCall< + "quant::ReshapeTo1DTensor($_builder, $_loc, $0)">; + def HasEqualShape : Constraint().hasRank() && " "$1.getType().cast().hasRank() && " @@ -112,7 +119,29 @@ def ConvertAddToBiasAdd : Pat< (TF_ConstOp:$add_rhs IsFloatElementsAttr:$add_rhs_value)), (TF_BiasAddOp $conv_out, $add_rhs, (CreateStringAttr<"NHWC">)), [(HasRankOf<1> $add_rhs_value), - (HasEqualElementSize<[-1], [0]> $conv_out, $add_rhs)]>; + (HasEqualElementSize<[-1], [0]> $conv_out, $add_rhs)], [], (addBenefit -1)>; + +// Convert conv+sub+mul pattern to conv+mul+add. +// (conv - sub) * mul -> conv * mul + (-sub) * mul +// +// This is needed to support Conv+BatchNorm pattern from Jax models converted +// using jax2tf w/o native serialization. Note that Jax2tf patterns always +// extend bias shapes to a rank of 4, e.g. 1x1x1x5. +def ConvertSubMulToMulAdd : Pat< + (TF_MulOp + (TF_SubOp + (SupportedAffineOpMatcher $conv_out, $input, $weight), + (TF_ConstOp:$sub_rhs IsFloatElementsAttr:$sub_rhs_value)), + (TF_ConstOp:$mul_rhs IsFloatElementsAttr:$mul_rhs_value)), + (TF_AddV2Op + (TF_MulOp $conv_out, (ReshapeTo1DTensor $mul_rhs)), + (TF_MulOp + (TF_NegOp (ReshapeTo1DTensor $sub_rhs)), + (ReshapeTo1DTensor $mul_rhs))), + [(ReshapableTo1DTensor $mul_rhs), + (ReshapableTo1DTensor $sub_rhs), + (HasEqualElementSize<[-1], [-1]> $conv_out, $mul_rhs), + (HasEqualElementSize<[-1], [-1]> $conv_out, $sub_rhs)]>; // TODO(b/278493977): Create generic implementation of lifting any fused op // with any reshaping op @@ -128,6 +157,7 @@ def ConvertAddWithReshapeToBiasAddWithReshape : Pat< (HasEqualElementSize<[-1], [0]> $reshape_out, $add_rhs)]>; // Fuse consecutive BiasAddOp and an AddV2Op. +// We also handle the case where add_rhs has rank 4. def FuseBiasAndAddV2 : Pat< (TF_AddV2Op (TF_BiasAddOp:$bias_add @@ -135,9 +165,10 @@ def FuseBiasAndAddV2 : Pat< (TF_ConstOp:$bias IsFloatElementsAttr:$bias_value), $data_format), (TF_ConstOp:$add_rhs IsFloatElementsAttr:$add_rhs_value)), (TF_BiasAddOp - $conv_out, (TF_AddV2Op $bias, $add_rhs), $data_format), + $conv_out, (TF_AddV2Op $bias, (ReshapeTo1DTensor $add_rhs)), $data_format), [(HasOneUse $bias_add), - (HasEqualShape $bias_value, $add_rhs_value)]>; + (ReshapableTo1DTensor $add_rhs), + (HasEqualElementSize<[-1], [-1]> $bias, $add_rhs)]>; // Fuse AffineOp followed by an MulOp patterns. def FuseAffineOpAndMul : Pat< diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc index b5fb96396f7ef9..20ffa2adcfa969 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc @@ -414,7 +414,7 @@ void PrepareQuantizePass::runOnOperation() { ApplyQuantizationParamsPropagation( func, is_signed, /*bit_width=*/8, !enable_per_channel_quantization_, GetTFOpQuantSpec, GetTfQuantScaleSpec, infer_tensor_range, - quant_specs_.legacy_float_scale); + quant_specs_.legacy_float_scale, /*is_qdq_conversion=*/false); RewritePatternSet patterns2(ctx); patterns2.add(ctx); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD index 5248d95c9f9e10..490cc1fc889b91 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD @@ -11,6 +11,7 @@ package( default_visibility = [ "//tensorflow/compiler/mlir/quantization/stablehlo:__subpackages__", "//tensorflow/compiler/mlir/quantization/tensorflow:internal_visibility_allowlist_package", + "//tensorflow/lite:__subpackages__", "//tensorflow/python:__subpackages__", "//tensorflow/tools/pip_package/v2:__subpackages__", ], @@ -345,6 +346,7 @@ tf_py_strict_test( "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:test_lib", "//tensorflow/python/lib/io:file_io", + "//tensorflow/python/lib/io:tf_record", "//tensorflow/python/module", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:math_ops", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py index a67a6451341e4e..ca7027dbe312bd 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py @@ -42,6 +42,7 @@ from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.lib.io import file_io +from tensorflow.python.lib.io import tf_record from tensorflow.python.module import module from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -1790,6 +1791,15 @@ def gen_data() -> repr_dataset.RepresentativeDataset: 'enable_per_channel_quantization': True, 'dilations': [1, 2, 2, 1], }, + { + 'testcase_name': 'with_bias_and_relu6_to_stablehlo_per_channel', + 'activation_fn': nn_ops.relu6, + 'has_bias': True, + 'has_batch_norm': False, + 'target_opset': quant_opts_pb2.STABLEHLO, + 'input_shape_dynamic': False, + 'enable_per_channel_quantization': True, + }, ) @test_util.run_in_graph_and_eager_modes def test_conv_ptq_model( @@ -1950,6 +1960,10 @@ def data_gen() -> repr_dataset.RepresentativeDataset: ), ) self.assertFalse(self._contains_op(output_graphdef, 'Conv2D')) + elif target_opset == quant_opts_pb2.STABLEHLO: + # This is to verify the invocation of StableHLO quantizer works. More + # thorough functional tests are in StableHLO quantizer directory. + self.assertTrue(self._contains_op(output_graphdef, 'XlaCallModule')) else: self.assertTrue(self._contains_quantized_function_call(output_graphdef)) self.assertFalse(self._contains_op(output_graphdef, 'FusedBatchNormV3')) @@ -5831,7 +5845,7 @@ def test_while_op_model( class DebuggerTest(quantize_model_test_base.QuantizedModelTest): - def _run_model_in_sess(self, model_dir, tags, signature_key, sample_input): + def _run_model_in_sess(self, model_dir, tags, signature_key, sample_inputs): with tensorflow.compat.v1.Session(graph=tensorflow.Graph()) as sess: meta_graph = saved_model_loader.load(sess, tags, export_dir=model_dir) signature_def = meta_graph.signature_def[signature_key] @@ -5843,13 +5857,26 @@ def _run_model_in_sess(self, model_dir, tags, signature_key, sample_input): for output_tensor_info in signature_def.outputs.values() ] - feed_dict = {} - for input_key, input_value in sample_input.items(): - input_tensor_name = signature_def.inputs[input_key].name - feed_dict[input_tensor_name] = input_value + output_values = [] + for sample_input in sample_inputs: + feed_dict = {} + for input_key, input_value in sample_input.items(): + input_tensor_name = signature_def.inputs[input_key].name + feed_dict[input_tensor_name] = input_value - # Obtain the output of the model. - return sess.run(output_tensor_names, feed_dict=feed_dict)[0] + # Obtain the output of the model. + output_values.append( + sess.run(output_tensor_names, feed_dict=feed_dict)[0] + ) + return output_values + + def _read_tensor_array_file(self, file_path): + tensor_protos = [] + for raw_record in tf_record.tf_record_iterator(file_path, options='ZLIB'): + tensor_protos.append( + tensorflow.make_ndarray(tensor_pb2.TensorProto.FromString(raw_record)) + ) + return np.array(tensor_protos) @parameterized.named_parameters( { @@ -5926,9 +5953,10 @@ def data_gen() -> repr_dataset.RepresentativeDataset: converted_model.signatures._signatures.keys(), {'serving_default'} ) - sample_input = { - 'input_tensor': np.random.uniform(low=0, high=1, size=(16, 3, 4, 3)) - } + sample_inputs = [ + {'input_tensor': np.random.uniform(low=0, high=1, size=(16, 3, 4, 3))}, + {'input_tensor': np.random.uniform(low=0, high=1, size=(16, 3, 4, 3))}, + ] # Check if output of the model and value saved by DumpTensorOp matches. # Verify for both unquantized model and quantized model. @@ -5936,24 +5964,19 @@ def data_gen() -> repr_dataset.RepresentativeDataset: [unquantized_dump_model_path, 'unquantized_tensor_data.pb'], [self._output_saved_model_path, 'quantized_tensor_data.pb'], ]: - output_value = self._run_model_in_sess( - model_path, tags, 'serving_default', sample_input + output_values = self._run_model_in_sess( + model_path, tags, 'serving_default', sample_inputs ) # Find the dump file and parse it. folder = os.path.join(log_dir_path, os.listdir(log_dir_path)[0]) dump_file_path = os.path.join(log_dir_path, folder, file_name) - - dump_file_proto = tensor_pb2.TensorProto.FromString( - open(dump_file_path, 'rb').read() - ) - - dump_file_numpy = tensorflow.make_ndarray(dump_file_proto) + dump_file_numpy = self._read_tensor_array_file(dump_file_path) # Since the model only has one conv2d and its output is directly used as # the output of the model, output of the model and conv2d's dump value # should be the same. - self.assertAllEqual(output_value, dump_file_numpy) + self.assertAllEqual(output_values, dump_file_numpy) # Verify if quant_unit.pb file was created correctly. quant_unit_file_path = os.path.join(log_dir_path, folder, 'quant_unit.pb') @@ -6070,15 +6093,16 @@ def data_gen() -> repr_dataset.RepresentativeDataset: converted_model.signatures._signatures.keys(), {'serving_default'} ) - sample_input = { - 'input_tensor': np.random.uniform(low=0, high=1, size=(16, 3, 4, 3)) - } + sample_inputs = [ + {'input_tensor': np.random.uniform(low=0, high=1, size=(16, 3, 4, 3))}, + {'input_tensor': np.random.uniform(low=0, high=1, size=(16, 3, 4, 3))}, + ] output_value_from_original_model = self._run_model_in_sess( - self._input_saved_model_path, tags, 'serving_default', sample_input + self._input_saved_model_path, tags, 'serving_default', sample_inputs ) output_value_from_debugging_model = self._run_model_in_sess( - self._output_saved_model_path, tags, 'serving_default', sample_input + self._output_saved_model_path, tags, 'serving_default', sample_inputs ) # Find the both quantized and unquantized dump file. @@ -6090,18 +6114,11 @@ def data_gen() -> repr_dataset.RepresentativeDataset: log_dir_path, folder, 'quantized_tensor_data.pb' ) - unquantized_dump_file_proto = tensor_pb2.TensorProto.FromString( - open(unquantized_dump_file_path, 'rb').read() - ) - quantized_dump_file_proto = tensor_pb2.TensorProto.FromString( - open(quantized_dump_file_path, 'rb').read() - ) - - unquantized_dump_file_numpy = tensorflow.make_ndarray( - unquantized_dump_file_proto + unquantized_dump_file_numpy = self._read_tensor_array_file( + unquantized_dump_file_path ) - quantized_dump_file_numpy = tensorflow.make_ndarray( - quantized_dump_file_proto + quantized_dump_file_numpy = self._read_tensor_array_file( + quantized_dump_file_path ) # Since the model only has one conv2d and its output is directly used as @@ -6143,169 +6160,46 @@ class CalibrationOptionsTest(quantize_model_test_base.QuantizedModelTest): (default in TF2) to ensure support for when TF2 is disabled. """ - @parameterized.named_parameters( - { - 'testcase_name': 'with_min_max', - 'target_opset': quant_opts_pb2.TF, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_MIN_MAX - ), - }, - { - 'testcase_name': 'with_min_max_to_xla', - 'target_opset': quant_opts_pb2.XLA, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_MIN_MAX - ), - }, - { - 'testcase_name': 'with_min_max_to_uq', - 'target_opset': quant_opts_pb2.UNIFORM_QUANTIZED, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_MIN_MAX - ), - }, - { - 'testcase_name': 'with_average_min_max', - 'target_opset': quant_opts_pb2.TF, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_AVERAGE_MIN_MAX - ), - }, - { - 'testcase_name': 'with_average_min_max_to_xla', - 'target_opset': quant_opts_pb2.XLA, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_AVERAGE_MIN_MAX - ), - }, - { - 'testcase_name': 'with_average_min_max_to_uq', - 'target_opset': quant_opts_pb2.UNIFORM_QUANTIZED, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_AVERAGE_MIN_MAX - ), - }, - { - 'testcase_name': 'with_histogram_percentile', - 'target_opset': quant_opts_pb2.TF, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_PERCENTILE, - calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, - ), - ), - }, - { - 'testcase_name': 'with_histogram_percentile_to_xla', - 'target_opset': quant_opts_pb2.XLA, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_PERCENTILE, - calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, - ), - ), - }, - { - 'testcase_name': 'with_histogram_percentile_to_uq', - 'target_opset': quant_opts_pb2.UNIFORM_QUANTIZED, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_PERCENTILE, - calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, - ), - ), - }, - { - 'testcase_name': 'with_histogram_mse_bruteforce', - 'target_opset': quant_opts_pb2.TF, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE, - calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, - ), - ), - }, - { - 'testcase_name': 'with_histogram_mse_bruteforce_to_xla', - 'target_opset': quant_opts_pb2.XLA, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE, - calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, - ), - ), - }, - { - 'testcase_name': 'with_histogram_mse_bruteforce_to_uq', - 'target_opset': quant_opts_pb2.UNIFORM_QUANTIZED, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE, - calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, - ), - ), - }, - { - 'testcase_name': 'with_histogram_mse_max_frequency', - 'target_opset': quant_opts_pb2.TF, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY, - calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, + @parameterized.parameters( + parameter_combinations([{ + 'target_opset': [ + quant_opts_pb2.TF, + quant_opts_pb2.XLA, + quant_opts_pb2.UNIFORM_QUANTIZED, + ], + 'calibration_options': [ + quant_opts_pb2.CalibrationOptions( + calibration_method=_CalibrationMethod.CALIBRATION_METHOD_MIN_MAX ), - ), - }, - { - 'testcase_name': 'with_histogram_mse_max_frequency_to_xla', - 'target_opset': quant_opts_pb2.XLA, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY, - calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, + quant_opts_pb2.CalibrationOptions( + calibration_method=_CalibrationMethod.CALIBRATION_METHOD_AVERAGE_MIN_MAX ), - ), - }, - { - 'testcase_name': 'with_histogram_mse_max_frequency_to_uq', - 'target_opset': quant_opts_pb2.UNIFORM_QUANTIZED, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY, - calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, + quant_opts_pb2.CalibrationOptions( + calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_PERCENTILE, + calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( + initial_num_bins=10, + ), ), - ), - }, - { - 'testcase_name': 'with_histogram_mse_symmetric', - 'target_opset': quant_opts_pb2.TF, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC, - calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, + quant_opts_pb2.CalibrationOptions( + calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE, + calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( + initial_num_bins=10, + ), ), - ), - }, - { - 'testcase_name': 'with_histogram_mse_symmetric_to_xla', - 'target_opset': quant_opts_pb2.XLA, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC, - calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, + quant_opts_pb2.CalibrationOptions( + calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY, + calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( + initial_num_bins=10, + ), ), - ), - }, - { - 'testcase_name': 'with_histogram_mse_symmetric_to_uq', - 'target_opset': quant_opts_pb2.UNIFORM_QUANTIZED, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC, - calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, + quant_opts_pb2.CalibrationOptions( + calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC, + calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( + initial_num_bins=10, + ), ), - ), - }, + ], + }]) ) @test_util.run_in_graph_and_eager_modes def test_conv_ptq_model_by_calibration_options( diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.py index d9f8c9781fc4ca..902ee3e5c94e2e 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.py @@ -22,7 +22,6 @@ from tensorflow.compiler.mlir.quantization.tensorflow import quantization_options_pb2 from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import calibration_algorithm from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import calibration_statistics_pb2 -from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import pywrap_calibration from tensorflow.compiler.mlir.quantization.tensorflow.python import pywrap_function_lib from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as rd from tensorflow.compiler.mlir.quantization.tensorflow.python import save_model @@ -500,31 +499,6 @@ def _run_graph_for_calibration( logging.info('Calibration step complete.') -def _get_min_max_from_calibrator( - node_id: bytes, - calib_opts: quantization_options_pb2.CalibrationOptions, -) -> tuple[float, float]: - """Calculate min and max from statistics using calibration options. - - Args: - node_id: bytes of node id. - calib_opts: Calibration options used for calculating min and max. - - Returns: - (min_value, max_value): Min and max calculated using calib_opts. - - Raises: - ValueError: Unsupported calibration method is given. - """ - statistics: calibration_statistics_pb2.CalibrationStatistics = ( - pywrap_calibration.get_statistics_from_calibrator(node_id) - ) - min_value, max_value = calibration_algorithm.get_min_max_value( - statistics, calib_opts - ) - return min_value, max_value - - class PyFunctionLibrary(pywrap_function_lib.PyFunctionLibrary): """Wrapper class for overridden python method definitions. diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc index db054766f1559c..512102d0a5f53c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc @@ -329,6 +329,9 @@ absl::StatusOr QuantizePtqModelPostCalibration( // Use StableHLO Quantizer option if opset is specified. if (is_stablehlo) { QuantizationConfig quantization_config{}; + quantization_config.mutable_static_range_ptq_preset() + ->set_enable_per_channel_quantized_weight( + quantization_options.enable_per_channel_quantization()); // When targeting server TPUs quantized types should be unpacked into // integer ops. quantization_config.mutable_pipeline_config()->set_unpack_quantized_types( diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py index 13db5fff7a8cdc..556222ef4797d5 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py @@ -682,15 +682,16 @@ def _populate_quantization_options_default_values( == _PresetMethod.METHOD_STATIC_RANGE_WEIGHT_ONLY_INT8 ) or ( - quantization_options.op_set == quant_opts_pb2.OpSet.XLA + quantization_options.op_set + in (quant_opts_pb2.OpSet.XLA, quant_opts_pb2.OpSet.STABLEHLO) and quantization_options.quantization_method.preset_method == _PresetMethod.METHOD_STATIC_RANGE_INT8 ) ): raise ValueError( 'Currently, per-channel quantization is supported for Uniform Quantized' - ' opset, weight only quantization, or XLA opset with static range' - ' quantization.' + ' opset, weight only quantization, or XLA/StableHLO opset with static' + ' range quantization.' ) if ( diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc index 136aab9f583030..2ca81b72aa71aa 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc @@ -30,6 +30,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h" @@ -56,8 +57,10 @@ void AddUnfuseMhloOpsPasses(mlir::PassManager& pm) { mlir::mhlo::createLegalizeEinsumToDotGeneralPass()); pm.addNestedPass( mlir::mhlo::createLegalizeDotToDotGeneralPass()); - pm.addNestedPass( - mlir::quant::stablehlo::createUnfuseMhloBatchNormPass()); + // Unfuse mhlo BatchNorm to primitive ops. + pm.addNestedPass(mlir::odml::createUnfuseBatchNormPass()); + // Fuse Conv + Mul to Conv. + pm.addNestedPass(mlir::odml::createFuseConvolutionPass()); pm.addNestedPass( mlir::mhlo::createLegalizeTorchIndexSelectToGatherPass()); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h index c7e191796031f4..740dca6c7b106b 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h @@ -23,6 +23,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/core/public/session.h" namespace tensorflow { @@ -58,6 +59,8 @@ inline absl::Status PreprocessAndFreezeGraph(mlir::ModuleOp module_op, /*deserialize_xla_call_module=*/false); } +void AddTFToStablehloPasses(mlir::PassManager& pm); + } // namespace quantization } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/add_dump_tensor_op_stablehlo.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/add_dump_tensor_op_stablehlo.mlir new file mode 100644 index 00000000000000..f4ef2e0f1d26f8 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/add_dump_tensor_op_stablehlo.mlir @@ -0,0 +1,76 @@ +// RUN: tf-quant-opt %s -split-input-file -quant-add-dump-tensor-op='debugger_type=int_per_layer' | FileCheck --check-prefix=IntPerLayer %s + +module { + func.func @matmul2(%arg0: tensor {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor) { + %0 = stablehlo.constant dense<[-0.211145893, -0.708605706]> : tensor<2xf32> + %1 = stablehlo.constant dense<[[-0.630731344, 0.54962182], [0.180364341, -0.764542698]]> : tensor<2x2xf32> + %2 = "tf.XlaCallModule"(%arg0, %1, %0) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_and_relu6_dynamic_fn_2, _original_entry_function = "composite_dot_general_with_bias_and_relu6_dynamic_fn_2", _tfl_quant_trait = "fully_quantizable"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor + %3 = "tf.XlaCallModule"(%2, %1, %0) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_and_relu6_dynamic_fn_1, _original_entry_function = "composite_dot_general_with_bias_and_relu6_dynamic_fn_1", _tfl_quant_trait = "fully_quantizable"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor + return %3 : tensor + } + func.func private @composite_dot_general_with_bias_and_relu6_dynamic_fn_2(%arg0: tensor, %arg1: tensor<2x2xf32>, %arg2: tensor<2xf32>) -> tensor attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.constant dense<0.000000e+00> : tensor + %1 = stablehlo.constant dense<6.000000e+00> : tensor + %2 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor, tensor<2x2xf32>) -> tensor + %3 = shape.shape_of %2 : tensor -> tensor<2xindex> + %4 = stablehlo.dynamic_broadcast_in_dim %arg2, %3, dims = [1] : (tensor<2xf32>, tensor<2xindex>) -> tensor + %5 = stablehlo.add %2, %4 : tensor + %6 = stablehlo.clamp %0, %5, %1 : (tensor, tensor, tensor) -> tensor + return %6 : tensor + } + func.func private @composite_dot_general_with_bias_and_relu6_dynamic_fn_1(%arg0: tensor, %arg1: tensor<2x2xf32>, %arg2: tensor<2xf32>) -> tensor attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.constant dense<0.000000e+00> : tensor + %1 = stablehlo.constant dense<6.000000e+00> : tensor + %2 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor, tensor<2x2xf32>) -> tensor + %3 = shape.shape_of %2 : tensor -> tensor<2xindex> + %4 = stablehlo.dynamic_broadcast_in_dim %arg2, %3, dims = [1] : (tensor<2xf32>, tensor<2xindex>) -> tensor + %5 = stablehlo.add %2, %4 : tensor + %6 = stablehlo.clamp %0, %5, %1 : (tensor, tensor, tensor) -> tensor + return %6 : tensor + } + +// IntPerLayer-LABEL: func @matmul2 +// IntPerLayer-DAG: %[[b0:.*]] = stablehlo.constant dense<[-0.211145893 +// IntPerLayer-DAG: %[[w0:.*]] = stablehlo.constant dense<{{\[\[}}-0.630731344 +// IntPerLayer-DAG: %[[matmul0_q:.*]] = "tf.XlaCallModule"(%arg0, %[[w0]], %[[b0]]) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_and_relu6_dynamic_fn_2, _original_entry_function = "composite_dot_general_with_bias_and_relu6_dynamic_fn_2", _tfl_quant_trait = "fully_quantizable"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor +// IntPerLayer-DAG: "tf.DumpTensor"(%[[matmul0_q]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "composite_dot_general_with_bias_and_relu6_dynamic_fn_2", log_dir_path = "/tmp/dumps/composite_dot_general_with_bias_and_relu6_dynamic_fn_2", node_name = "_empty_node"}> : (tensor) -> () +// IntPerLayer-DAG: %[[matmul0_uq:.*]] = "tf.XlaCallModule"(%arg0, %[[w0]], %[[b0]]) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_and_relu6_dynamic_fn_2_0, _original_entry_function = "composite_dot_general_with_bias_and_relu6_dynamic_fn_2_0"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor +// IntPerLayer-DAG: "tf.DumpTensor"(%[[matmul0_uq]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "composite_dot_general_with_bias_and_relu6_dynamic_fn_2", log_dir_path = "/tmp/dumps/composite_dot_general_with_bias_and_relu6_dynamic_fn_2", node_name = "_empty_node"}> : (tensor) -> () +// IntPerLayer-DAG: %[[matmul1_q:.*]] = "tf.XlaCallModule"(%[[matmul0_q]], %[[w0]], %[[b0]]) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_and_relu6_dynamic_fn_1, _original_entry_function = "composite_dot_general_with_bias_and_relu6_dynamic_fn_1", _tfl_quant_trait = "fully_quantizable"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor +// IntPerLayer-DAG: "tf.DumpTensor"(%[[matmul1_q]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "composite_dot_general_with_bias_and_relu6_dynamic_fn_1", log_dir_path = "/tmp/dumps/composite_dot_general_with_bias_and_relu6_dynamic_fn_1", node_name = "_empty_node"}> : (tensor) -> () +// IntPerLayer-DAG: %[[matmul1_uq:.*]] = "tf.XlaCallModule"(%[[matmul0_q]], %[[w0]], %[[b0]]) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_and_relu6_dynamic_fn_1_0, _original_entry_function = "composite_dot_general_with_bias_and_relu6_dynamic_fn_1_0"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor +// IntPerLayer-DAG: "tf.DumpTensor"(%[[matmul1_uq]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "composite_dot_general_with_bias_and_relu6_dynamic_fn_1", log_dir_path = "/tmp/dumps/composite_dot_general_with_bias_and_relu6_dynamic_fn_1", node_name = "_empty_node"}> : (tensor) -> () +// IntPerLayer-DAG: return %[[matmul1_q]] : tensor +// IntPerLayer-DAG: func.func private @composite_dot_general_with_bias_and_relu6_dynamic_fn_2 +// IntPerLayer-DAG: func.func private @composite_dot_general_with_bias_and_relu6_dynamic_fn_1 +// IntPerLayer-DAG: func.func private @composite_dot_general_with_bias_and_relu6_dynamic_fn_2_0 +// IntPerLayer-DAG: func.func private @composite_dot_general_with_bias_and_relu6_dynamic_fn_1_0 +} + +// ----- + +module { + func.func @matmul_concat(%arg0: tensor<1x2xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<2x3xf32>) { + %0 = stablehlo.constant dense<[[-0.630731344, 0.54962182, 0.180364341], [-0.764542698, -0.211145893, -0.708605706]]> : tensor<2x3xf32> + %1 = stablehlo.constant dense<1.000000e+00> : tensor<1x3xf32> + %2 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3>], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable"} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %3 = stablehlo.concatenate %2, %1, dim = 0 : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<2x3xf32> + return %3 : tensor<2x3xf32> + } + func.func private @composite_dot_general_fn_1(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } + +// IntPerLayer-LABEL: func @matmul_concat +// IntPerLayer-DAG: %[[w0:.*]] = stablehlo.constant dense<{{\[\[}}-0.630731344 +// IntPerLayer-DAG: %[[c0:.*]] = stablehlo.constant dense<1.000000e+00 +// IntPerLayer-DAG: %[[matmul0_q:.*]] = "tf.XlaCallModule"(%arg0, %[[w0]]) <{Sout = [#tf_type.shape<1x3>], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable"} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> +// IntPerLayer-DAG: "tf.DumpTensor"(%[[matmul0_q]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "composite_dot_general_fn_1", log_dir_path = "/tmp/dumps/composite_dot_general_fn_1", node_name = "_empty_node"}> : (tensor<1x3xf32>) -> () +// IntPerLayer-DAG: %[[matmul0_uq:.*]] = "tf.XlaCallModule"(%arg0, %[[w0]]) <{Sout = [#tf_type.shape<1x3>], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1_0, _original_entry_function = "composite_dot_general_fn_1_0", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> +// IntPerLayer-DAG: "tf.DumpTensor"(%[[matmul0_uq]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "composite_dot_general_fn_1", log_dir_path = "/tmp/dumps/composite_dot_general_fn_1", node_name = "_empty_node"}> : (tensor<1x3xf32>) -> () +// IntPerLayer-DAG: %[[concat:.*]] = stablehlo.concatenate %[[matmul0_q]], %[[c0]], dim = 0 : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<2x3xf32> +// IntPerLayer-DAG: return %[[concat]] : tensor<2x3xf32> +// IntPerLayer-DAG: func.func private @composite_dot_general_fn_1 +// IntPerLayer-DAG: func.func private @composite_dot_general_fn_1_0 +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_lifting.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_lifting.mlir index 1e771e2586a61e..772b38f56e242b 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_lifting.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_lifting.mlir @@ -359,3 +359,43 @@ func.func @depthwise_conv2d_with_large_weight_and_add(%arg0: tensor<*xf32>) -> ( // CHECK-NEXT: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) // CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[DEPTHWISE_CONV2D]], %[[CONST_0]]) // CHECK-NEXT: return %[[BIASADD]] + +// ---- + +func.func @fuse_conv2d_with_sub_and_mul(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %cst_1 = "tf.Const"() {value = dense<0.200000e+00> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %0 = "tf.Conv2D"(%arg0, %cst) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %1 = "tf.Sub"(%0, %cst_0) {data_format = "NHWC"} : (tensor<1x3x2x2xf32>, tensor<1x1x1x2xf32>) -> tensor<1x3x2x2xf32> + %2 = "tf.Mul"(%1, %cst_1) : (tensor<1x3x2x2xf32>, tensor<1x1x1x2xf32>) -> tensor<1x3x2x2xf32> + func.return %2 : tensor<1x3x2x2xf32> +} + +// CHECK: func @fuse_conv2d_with_sub_and_mul +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<-0.0800000056> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<4.000000e-01> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK-NEXT: %[[CONV:.*]] = "tf.Conv2D"(%arg0, %[[CONST_0]]) +// CHECK-NEXT: %[[BIAS_ADD:.*]] = "tf.BiasAdd"(%[[CONV]], %[[CONST]]) +// CHECK-NEXT: return %[[BIAS_ADD]] + +// ----- + +func.func @fuse_conv2d_with_sub_mul_addv2(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %cst_1 = "tf.Const"() {value = dense<0.200000e+00> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %cst_2 = "tf.Const"() {value = dense<0.300000e+00> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %0 = "tf.Conv2D"(%arg0, %cst) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %1 = "tf.Sub"(%0, %cst_0) {data_format = "NHWC"} : (tensor<1x3x2x2xf32>, tensor<1x1x1x2xf32>) -> tensor<1x3x2x2xf32> + %2 = "tf.Mul"(%1, %cst_1) : (tensor<1x3x2x2xf32>, tensor<1x1x1x2xf32>) -> tensor<1x3x2x2xf32> + %3 = "tf.AddV2"(%2, %cst_2) : (tensor<1x3x2x2xf32>, tensor<1x1x1x2xf32>) -> tensor<1x3x2x2xf32> + func.return %3 : tensor<1x3x2x2xf32> +} + +// CHECK: func @fuse_conv2d_with_sub_mul_addv2 +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<2.200000e-01> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<4.000000e-01> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK-NEXT: %[[CONV:.*]] = "tf.Conv2D"(%arg0, %[[CONST_0]]) +// CHECK-NEXT: %[[BIAS_ADD:.*]] = "tf.BiasAdd"(%[[CONV]], %[[CONST]]) +// CHECK-NEXT: return %[[BIAS_ADD]] diff --git a/tensorflow/compiler/mlir/register_common_dialects.cc b/tensorflow/compiler/mlir/register_common_dialects.cc index 4cda39bdbb6745..b089bd9a1eb787 100644 --- a/tensorflow/compiler/mlir/register_common_dialects.cc +++ b/tensorflow/compiler/mlir/register_common_dialects.cc @@ -29,7 +29,6 @@ limitations under the License. #include "xla/mlir/framework/ir/xla_framework.h" #include "xla/mlir_hlo/mhlo/IR/register.h" #include "xla/service/cpu/hlo_xla_runtime_pipeline.h" -#include "xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h index e8eb89fea13f32..7afec29bc5df75 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h @@ -59,6 +59,14 @@ class ResourceAliasAnalysisInfo { // `IsUnknownResource(resource) == false`. llvm::SmallSetVector GetResourceAliases(Value resource) const; + llvm::SmallSetVector GetValuesForResourceId(int64_t id) const { + auto it = id_to_resource_values_.find(id); + if (it == id_to_resource_values_.end()) { + return {}; // return empty set + } + return it->getSecond(); + } + // Returns true iff given resource is allocated by op with // `UniqueResourceAllocation` trait. This can be utilized for while-loop // parallelization. diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.cc b/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.cc index 5ce3ccf0cc2c8d..5d9c1d32a92a10 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.cc @@ -31,6 +31,7 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -40,6 +41,10 @@ limitations under the License. namespace mlir { namespace TF { +namespace { +constexpr char kCompositeDevice[] = "tf._composite_device"; +} // namespace + ResourceConstructingOps::ResourceConstructingOps(Operation *op) { if (op) ops.insert(op); } @@ -57,7 +62,11 @@ ResourceConstructingOps ResourceConstructingOps::getPessimisticValueState( auto global_tensor = tf_saved_model::LookupBoundInputOfType< tf_saved_model::GlobalTensorOp>(func, barg.getArgNumber(), symbol_table); - return ResourceConstructingOps(global_tensor); + ResourceConstructingOps result(global_tensor); + if (func.getArgAttr(barg.getArgNumber(), kCompositeDevice)) { + result.is_on_composite_device = true; + } + return result; } } else if (auto vh = dyn_cast(value.getDefiningOp())) { return ResourceConstructingOps(vh); @@ -74,17 +83,24 @@ ResourceConstructingOps ResourceConstructingOps::join( ResourceConstructingOps ret; ret.ops.insert(lhs.ops.begin(), lhs.ops.end()); ret.ops.insert(rhs.ops.begin(), rhs.ops.end()); + ret.is_on_composite_device = + lhs.is_on_composite_device || rhs.is_on_composite_device; return ret; } void ResourceConstructingOps::print(raw_ostream &os) const { - llvm::interleaveComma(ops, os << "["), os << "]"; + llvm::interleaveComma(ops, os << "["); + if (is_on_composite_device) { + os << " COMPOSITE"; + } + os << "]"; } void ResourceDataflowAnalysis::visitOperation(Operation *op, ArrayRef operands, ArrayRef results) { LLVM_DEBUG(llvm::dbgs() << "ResAn: Visiting operation: " << *op << "\n"); + if (auto cast = dyn_cast(op)) { join(results[0], *operands[0]); } else if (auto while_op = dyn_cast(op)) { @@ -94,6 +110,30 @@ void ResourceDataflowAnalysis::visitOperation(Operation *op, join(getLatticeElement(arg), *getLatticeElement(value)); } } + } else if (auto while_op = dyn_cast(op)) { + func::FuncOp cond = SymbolTable::lookupNearestSymbolFrom( + while_op, while_op.getCondAttr()); + func::FuncOp body = SymbolTable::lookupNearestSymbolFrom( + while_op, while_op.getBodyAttr()); + for (auto &arg : while_op->getOpOperands()) { + BlockArgument cond_arg = cond.getArgument(arg.getOperandNumber()); + join(getLatticeElement(cond_arg), *getLatticeElement(arg.get())); + BlockArgument body_arg = body.getArgument(arg.getOperandNumber()); + join(getLatticeElement(body_arg), *getLatticeElement(arg.get())); + } + } else if (auto graph = dyn_cast(op)) { + for (auto &arg : graph.GetFetch()->getOpOperands()) { + if (arg.getOperandNumber() < graph.getNumResults()) { + auto result = graph.getResult(arg.getOperandNumber()); + join(getLatticeElement(result), *getLatticeElement(arg.get())); + } + } + } else if (auto island = dyn_cast(op)) { + for (auto &arg : island.GetYield()->getOpOperands()) { + auto result = island.getResult(arg.getOperandNumber()); + join(getLatticeElement(result), *getLatticeElement(arg.get())); + // getLatticeElement(arg.get())->print(llvm::errs()); + } } else { setAllToEntryStates(results); } diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h b/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h index 68f6fa2d44c763..61fdb0c39f0693 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h @@ -42,7 +42,8 @@ struct ResourceConstructingOps { static ResourceConstructingOps getPessimisticValueState(MLIRContext *context); static ResourceConstructingOps getPessimisticValueState(Value value); bool operator==(const ResourceConstructingOps &rhs) const { - return ops == rhs.ops; + return ops == rhs.ops && + is_on_composite_device == rhs.is_on_composite_device; } static ResourceConstructingOps join(const ResourceConstructingOps &lhs, @@ -52,6 +53,8 @@ struct ResourceConstructingOps { // The operation(s) which created the resource value. // IR constructs (i.e., GlobalTensorOp) are not const-correct. mutable DenseSet ops; + + bool is_on_composite_device = false; }; class ResourceDataflowAnalysis diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc index b0d730898316d5..c95dd020497385 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc @@ -194,7 +194,7 @@ void CategorizeParallelIdsMap( groups_different_branch = 0; groups_from_only = 0; groups_to_only = 0; - for (auto [group, branch] : from) { + for (const auto& [group, branch] : from) { auto to_iter = to.find(group); if (to_iter == to.end()) { ++groups_from_only; @@ -207,7 +207,7 @@ void CategorizeParallelIdsMap( } } } - for (auto [group, _] : to) { + for (const auto& [group, _] : to) { auto from_iter = from.find(group); if (from_iter == from.end()) { ++groups_to_only; @@ -246,13 +246,13 @@ void SideEffectAnalysisInfo::SetLastWrites( void SideEffectAnalysisInfo::Enter() { per_resource_access_info_.clear(); - for (auto [resource, last_writes] : stack_down_.back()) { + for (const auto& [resource, last_writes] : stack_down_.back()) { SetLastWrites(resource, last_writes); } } void SideEffectAnalysisInfo::Exit() { - for (auto [resource, _] : per_resource_access_info_) { + for (const auto& [resource, _] : per_resource_access_info_) { absl::flat_hash_set last_writes = GetLastWrites(resource); auto& resource_to_operations = stack_up_.back(); resource_to_operations.try_emplace(resource); @@ -265,7 +265,7 @@ void SideEffectAnalysisInfo::Exit() { void SideEffectAnalysisInfo::Down() { stack_down_.emplace_back(); stack_up_.emplace_back(); - for (auto [resource, _] : per_resource_access_info_) { + for (const auto& [resource, _] : per_resource_access_info_) { absl::flat_hash_set last_writes = GetLastWrites(resource); stack_down_.back()[resource] = last_writes; } @@ -279,7 +279,7 @@ void SideEffectAnalysisInfo::Lateral() { void SideEffectAnalysisInfo::Up() { Exit(); - for (auto [resource, last_writes] : stack_up_.back()) { + for (const auto& [resource, last_writes] : stack_up_.back()) { SetLastWrites(resource, last_writes); } stack_down_.pop_back(); diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h index faa6f7ef89eb59..97fcd30d36d02f 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h @@ -125,6 +125,10 @@ class SideEffectAnalysisInfo { return alias_analysis_.IsUniqueResourceAllocationId(resource_id); } + const TF::ResourceAliasAnalysis::Info& GetAliasAnalysis() const { + return alias_analysis_; + } + private: // Runs the analysis and populates `sorted_control_predecessors_` and // `sorted_control_successors_` for `func_op`. Clears `control_predecessors_`. diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 71cf2490e911f4..5e0e58c279e358 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -12463,6 +12463,18 @@ of the tensor. Rank is also known as "order", "degree", or "ndims." let hasFolder = 1; } +def TF_ReadFileOp : TF_Op<"ReadFile", [Pure, TF_NoConstantFold]> { + let summary = "Reads and outputs the entire contents of the input filename."; + + let arguments = (ins + TF_StrTensor:$filename + ); + + let results = (outs + TF_StrTensor:$contents + ); +} + def TF_ReadVariableOp : TF_Op<"ReadVariableOp", []> { let summary = "Reads the value of a variable."; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc index 879aa62ab28f03..b2ae51a1189686 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc @@ -90,7 +90,7 @@ mlir::LogicalResult PwStreamResultsOp::verify() { } //===----------------------------------------------------------------------===// -// IfrtProgramCall +// IfrtCall //===----------------------------------------------------------------------===// mlir::LogicalResult IfrtCallOp::verify() { @@ -115,6 +115,26 @@ mlir::LogicalResult IfrtCallOp::verify() { } } + // Verify variable_arg_indices is sorted in ascending order. + int64_t prev_index = -1; + for (auto arg_index_attr : getVariableArgIndicesAttr()) { + if (!arg_index_attr.isa_and_nonnull()) { + return emitOpError() << "variable_arg_indices must be an integer"; + } + + int64_t index = + arg_index_attr.dyn_cast().getValue().getSExtValue(); + if (index < 0) { + return emitOpError() << "variable_arg_indices must be positive"; + } + + if (index <= prev_index) { + return emitOpError() + << "variable_arg_indices must be sorted in ascending order"; + } + prev_index = index; + } + return mlir::success(); } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.td index af4bdcea69182e..fe230904c241be 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.td @@ -63,7 +63,7 @@ def TF__TfrtGetResourceOp : TF_Op<"_TfrtGetResource", let hasVerifier = 1; } -def TF_IfrtLoadVariableOp : TF_Op<"IfrtLoadVariable", []> { +def TF_IfrtLoadVariableOp : TF_Op<"IfrtLoadVariable", [Pure]> { let summary = "Loads a variable tensor as an IFRT array"; let description = [{ @@ -77,6 +77,9 @@ def TF_IfrtLoadVariableOp : TF_Op<"IfrtLoadVariable", []> { `tf.IfrtLoadVariableOp` converts the tensor into an IFRT array based on device and sharding configuration specified in `VariableDeviceShardingConfigProto`. + + This op returns a scalar string tensor containing the loaded variable name, which can be + used as a key to look for the loaded IFRT array in runtime. }]; let arguments = (ins @@ -85,7 +88,12 @@ def TF_IfrtLoadVariableOp : TF_Op<"IfrtLoadVariable", []> { DefaultValuedAttr:$name ); + let results = (outs + TF_StrTensor:$array_key + ); + TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>; + TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>; } @@ -101,14 +109,13 @@ def TF_IfrtCallOp : TF_Op<"IfrtCall", []> { that the outlined function is compiled into an executable and is available for lookup from `IfrtCall` TF ops. - This op also takes `variable_names` attribute to bind the variables (weights) - by names. + `variable_arg_indices` is a sorted (ascending order) array and indicates which + element of `args` is a key to a loaded array corresponding to a variable. }]; let arguments = (ins Variadic : $args, I64Attr : $program_id, - StrArrayAttr : $variable_names, I32ArrayAttr : $variable_arg_indices ); diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir index be0ab858484acd..05e5105883f94b 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir @@ -752,3 +752,14 @@ func.func @testGlobalIterIdNotFolded() -> (tensor) { // CHECK: return %[[X]] func.return %0: tensor } + +// ----- + +// CHECK-LABEL: func @testReadFileOpNotFolded +func.func @testReadFileOpNotFolded() -> (tensor) { + %0 = "tf.Const"() { value = dense<"filepath"> : tensor } : () -> tensor + // CHECK: %[[X:.*]] = "tf.ReadFile" + %1 = "tf.ReadFile"(%0) : (tensor) -> tensor + // CHECK: return %[[X]] + func.return %1: tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/convert_control_to_data_outputs.mlir b/tensorflow/compiler/mlir/tensorflow/tests/convert_control_to_data_outputs.mlir index d473c1a7d7b67e..5f59e35498151e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/convert_control_to_data_outputs.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/convert_control_to_data_outputs.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt -tf-executor-convert-control-to-data-outputs -split-input-file %s | FileCheck %s +// RUN: tf-opt %s -pass-pipeline='builtin.module(tf-executor-convert-control-to-data-outputs{composite-tpuexecute-side-effects})' -split-input-file -verify-diagnostics | FileCheck %s !tf_res = tensor>> @@ -574,3 +574,497 @@ func.func @unconnected(%arg0: !tf_res, %arg1: !tf_res, %arg2: tensor) { } func.return } + +// ----- + +!tf_res = tensor>> +!tf_str = tensor<3x!tf_type.string> + +// CHECK-LABEL: @tpu_execute_while_body +func.func @tpu_execute_while_body(%arg0: !tf_res, %arg1: !tf_res, + %arg2: tensor) + -> (!tf_res, !tf_res, tensor) { + %graph:3 = tf_executor.graph { + %key, %key_control = tf_executor.island wraps "tf.Const"() {value = dense<"">: !tf_str} : () -> !tf_str + // CHECK: [[exe:%.*]] = tf_executor.island({{.*}}) wraps "tf.TPUExecuteAndUpdateVariables" + %exe_control = tf_executor.island wraps "tf.TPUExecuteAndUpdateVariables"(%arg0, %arg0, %key) { + device_var_reads_indices = [0, 1], + device_var_updates_indices = [0, 1], + device = "task:0" + } : (!tf_res, !tf_res, !tf_str) -> () + + %assign_control_0 = tf_executor.island wraps "tf.AssignVariableOp"(%arg0, %arg2) : (tensor>>, tensor) -> () + %assign_control_1 = tf_executor.island wraps "tf.AssignVariableOp"(%arg1, %arg2) : (tensor>>, tensor) -> () + %add, %add_control = tf_executor.island wraps "tf.Add"(%arg2, %arg2) : (tensor, tensor) -> tensor + %mul, %mul_control = tf_executor.island wraps "tf.Mul"(%arg2, %arg2) : (tensor, tensor) -> tensor + %control_barrier = tf_executor.island(%assign_control_0, %assign_control_1, %add_control, %exe_control) wraps "tf.NoOp"() : () -> () + // CHECK-DAG: [[exe]]{{.*}}"tf.Identity"(%arg3) + // CHECK-DAG: "tf.Identity"(%arg4) + // CHECK: tf_executor.fetch + tf_executor.fetch %arg0, %arg1, %add, %control_barrier, %mul_control : tensor>>, tensor>>, tensor, !tf_executor.control, !tf_executor.control + } + func.return %graph#0, %graph#1, %graph#2 : !tf_res, !tf_res, tensor +} + +// CHECK-LABEL: @tpu_execute_while_cond +func.func @tpu_execute_while_cond(%arg0: !tf_res, %arg1: !tf_res, %arg2: tensor) -> (tensor) { + %graph = tf_executor.graph { + %island, %ctrl = tf_executor.island { + %pred = "tf.SomeOp"(%arg2) : (tensor) -> tensor + tf_executor.yield %pred : tensor + } + tf_executor.fetch %island : tensor + } + func.return %graph : tensor +} + +// CHECK-LABEL: @tpu_execute +func.func @tpu_execute(%arg0: !tf_res, %arg1: !tf_res, %arg2: tensor) { + tf_executor.graph { + // CHECK: "tf.Const"{{.*}}tensor + %while_out:3, %control_while = tf_executor.island wraps "tf.While"(%arg0, %arg1, %arg2) + {body = @tpu_execute_while_body, + cond = @tpu_execute_while_cond, is_stateless = false} + : (tensor>>, tensor>>, tensor) + -> (tensor>>, tensor>>, tensor) + tf_executor.fetch + } + func.return +} + +// ----- + +!tf_res = tensor>> +!tf_str = tensor<3x!tf_type.string> + +// CHECK-LABEL: @incomplete_composite_devices_while_body +func.func @incomplete_composite_devices_while_body(%arg0: !tf_res, %arg1: !tf_res, + %arg2: tensor) + -> (!tf_res, !tf_res, tensor) { + %graph:3 = tf_executor.graph { + %key, %key_control = tf_executor.island wraps "tf.Const"() {value = dense<"">: !tf_str} : () -> !tf_str + // CHECK: [[exe:%.*]] = tf_executor.island({{.*}}) wraps "tf.TPUExecuteAndUpdateVariables" + %exe_control = tf_executor.island wraps "tf.TPUExecuteAndUpdateVariables"(%arg0, %arg1, %key) { + device_var_reads_indices = [0, 1], + device_var_updates_indices = [0, 1], + device = "task:0" + } : (!tf_res, !tf_res, !tf_str) -> () + + %assign_control_0 = tf_executor.island wraps "tf.AssignVariableOp"(%arg0, %arg2) : (tensor>>, tensor) -> () + %assign_control_1 = tf_executor.island wraps "tf.AssignVariableOp"(%arg1, %arg2) : (tensor>>, tensor) -> () + %add, %add_control = tf_executor.island wraps "tf.Add"(%arg2, %arg2) : (tensor, tensor) -> tensor + %mul, %mul_control = tf_executor.island wraps "tf.Mul"(%arg2, %arg2) : (tensor, tensor) -> tensor + %control_barrier = tf_executor.island(%assign_control_0, %assign_control_1, %add_control, %exe_control) wraps "tf.NoOp"() : () -> () + // CHECK: [[exe]]{{.*}}"tf.Identity" + // CHECK-NOT: "tf.Identity" + // CHECK: tf_executor.fetch + tf_executor.fetch %arg0, %arg1, %add, %control_barrier, %mul_control : tensor>>, tensor>>, tensor, !tf_executor.control, !tf_executor.control + } + func.return %graph#0, %graph#1, %graph#2 : !tf_res, !tf_res, tensor +} + +// CHECK-LABEL: @incomplete_composite_devices_while_cond +func.func @incomplete_composite_devices_while_cond(%arg0: !tf_res, %arg1: !tf_res, %arg2: tensor) -> (tensor) { + %graph = tf_executor.graph { + %island, %ctrl = tf_executor.island { + %pred = "tf.SomeOp"(%arg2) : (tensor) -> tensor + tf_executor.yield %pred : tensor + } + tf_executor.fetch %island : tensor + } + func.return %graph : tensor +} + +// CHECK-LABEL: @incomplete_composite_devices +func.func @incomplete_composite_devices(%arg0: !tf_res {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}, + %arg1: !tf_res, %arg2: tensor) { + tf_executor.graph { + // CHECK: "tf.Const"{{.*}}tensor + %while_out:3, %control_while = tf_executor.island wraps "tf.While"(%arg0, %arg1, %arg2) + {body = @incomplete_composite_devices_while_body, + cond = @incomplete_composite_devices_while_cond, is_stateless = false} + : (tensor>>, tensor>>, tensor) + -> (tensor>>, tensor>>, tensor) + tf_executor.fetch + } + func.return +} + +// ----- + +!tf_res = tensor>> +!tf_str = tensor<3x!tf_type.string> + +// CHECK-LABEL: @complete_composite_devices_while_body +func.func @complete_composite_devices_while_body(%arg0: !tf_res, %arg1: !tf_res, + %arg2: tensor) + -> (!tf_res, !tf_res, tensor) { + %graph:3 = tf_executor.graph { + %key, %key_control = tf_executor.island wraps "tf.Const"() {value = dense<"">: !tf_str} : () -> !tf_str + // CJHECK: [[exe:%.*]] = tf_executor.island wraps "tf.TPUExecuteAndUpdateVariables" + %exe_control = tf_executor.island wraps "tf.TPUExecuteAndUpdateVariables"(%arg0, %arg1, %key) { + device_var_reads_indices = [0, 1], + device_var_updates_indices = [0, 1], + device = "task:0" + } : (!tf_res, !tf_res, !tf_str) -> () + + %assign_control_0 = tf_executor.island wraps "tf.AssignVariableOp"(%arg0, %arg2) : (tensor>>, tensor) -> () + %assign_control_1 = tf_executor.island wraps "tf.AssignVariableOp"(%arg1, %arg2) : (tensor>>, tensor) -> () + %add, %add_control = tf_executor.island wraps "tf.Add"(%arg2, %arg2) : (tensor, tensor) -> tensor + %mul, %mul_control = tf_executor.island wraps "tf.Mul"(%arg2, %arg2) : (tensor, tensor) -> tensor + %control_barrier = tf_executor.island(%assign_control_0, %assign_control_1, %add_control, %exe_control) wraps "tf.NoOp"() : () -> () + // CHECK: "tf.Identity"(%arg3) + // CHECK: "tf.Identity"(%arg4) + // CHECK: tf_executor.fetch + tf_executor.fetch %arg0, %arg1, %add, %control_barrier, %mul_control : tensor>>, tensor>>, tensor, !tf_executor.control, !tf_executor.control + } + func.return %graph#0, %graph#1, %graph#2 : !tf_res, !tf_res, tensor +} + +// CHECK-LABEL: @complete_composite_devices_while_cond +func.func @complete_composite_devices_while_cond(%arg0: !tf_res, %arg1: !tf_res, %arg2: tensor) -> (tensor) { + %graph = tf_executor.graph { + %island, %ctrl = tf_executor.island { + %pred = "tf.SomeOp"(%arg2) : (tensor) -> tensor + tf_executor.yield %pred : tensor + } + tf_executor.fetch %island : tensor + } + func.return %graph : tensor +} + +// CHECK-LABEL: @complete_composite_devices +func.func @complete_composite_devices( + %arg0: !tf_res {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}, + %arg1: !tf_res {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}, %arg2: tensor) { + tf_executor.graph { + // CHECK: "tf.Const"{{.*}}tensor + %while_out:3, %control_while = tf_executor.island wraps "tf.While"(%arg0, %arg1, %arg2) + {body = @complete_composite_devices_while_body, + cond = @complete_composite_devices_while_cond, is_stateless = false} + : (tensor>>, tensor>>, tensor) + -> (tensor>>, tensor>>, tensor) + tf_executor.fetch + } + func.return +} + +// ----- + +!tf_res = tensor>> +!tf_str = tensor<3x!tf_type.string> + +// CHECK-LABEL: @tpu_execute_with_non_resource_operands_while_body +func.func @tpu_execute_with_non_resource_operands_while_body(%arg0: !tf_res, %arg1: !tf_res, + %arg2: tensor) + -> (!tf_res, !tf_res, tensor) { + %graph:3 = tf_executor.graph { + %key, %key_control = tf_executor.island wraps "tf.Const"() {value = dense<"">: !tf_str} : () -> !tf_str + // CHECK: [[exe:%.*]] = tf_executor.island({{[^)]*}}) wraps "tf.TPUExecuteAndUpdateVariables" + %exe_control = tf_executor.island wraps "tf.TPUExecuteAndUpdateVariables"(%arg2, %arg0, %arg1, %key) { + device_var_reads_indices = [1, 2], + device_var_updates_indices = [1, 2], + device = "task:0" + } : (tensor, !tf_res, !tf_res, !tf_str) -> () + + %assign_control_0 = tf_executor.island wraps "tf.AssignVariableOp"(%arg0, %arg2) : (tensor>>, tensor) -> () + %assign_control_1 = tf_executor.island wraps "tf.AssignVariableOp"(%arg1, %arg2) : (tensor>>, tensor) -> () + %add, %add_control = tf_executor.island wraps "tf.Add"(%arg2, %arg2) : (tensor, tensor) -> tensor + %mul, %mul_control = tf_executor.island wraps "tf.Mul"(%arg2, %arg2) : (tensor, tensor) -> tensor + %control_barrier = tf_executor.island(%assign_control_0, %assign_control_1, %add_control, %exe_control) wraps "tf.NoOp"() : () -> () + // CHECK: "tf.Identity"(%arg3) + // CHECK: "tf.Identity"(%arg4) + // CHECK: tf_executor.fetch + tf_executor.fetch %arg0, %arg1, %add, %control_barrier, %mul_control : tensor>>, tensor>>, tensor, !tf_executor.control, !tf_executor.control + } + func.return %graph#0, %graph#1, %graph#2 : !tf_res, !tf_res, tensor +} + +// CHECK-LABEL: @tpu_execute_with_non_resource_operands_while_cond +func.func @tpu_execute_with_non_resource_operands_while_cond(%arg0: !tf_res, %arg1: !tf_res, %arg2: tensor) -> (tensor) { + %graph = tf_executor.graph { + %island, %ctrl = tf_executor.island { + %pred = "tf.SomeOp"(%arg2) : (tensor) -> tensor + tf_executor.yield %pred : tensor + } + tf_executor.fetch %island : tensor + } + func.return %graph : tensor +} + +// CHECK-LABEL: @tpu_execute_with_non_resource_operands +func.func @tpu_execute_with_non_resource_operands(%arg0: !tf_res {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}, + %arg1: !tf_res {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}, %arg2: tensor) { + tf_executor.graph { + // CHECK: "tf.Const"{{.*}}tensor + %while_out:3, %control_while = tf_executor.island wraps "tf.While"(%arg0, %arg1, %arg2) + {body = @tpu_execute_with_non_resource_operands_while_body, + cond = @tpu_execute_with_non_resource_operands_while_cond, is_stateless = false} + : (tensor>>, tensor>>, tensor) + -> (tensor>>, tensor>>, tensor) + tf_executor.fetch + } + func.return +} + +// ----- + +!tf_res = tensor>> +!tf_str = tensor<3x!tf_type.string> + +// CHECK-LABEL: @double_tpu_execute_while_body +func.func @double_tpu_execute_while_body(%arg0: !tf_res, %arg1: !tf_res, + %arg2: tensor) + -> (!tf_res, !tf_res, tensor) { + // CHECK: "tf.Identity" + %graph:3 = tf_executor.graph { + // CHECK: {{.*}}, [[ctrl1:%.*]] = tf_executor.island wraps "tf.Identity" + // CHECK: {{.*}}, [[ctrl2:%.*]] = tf_executor.island wraps "tf.Identity" + // CHECK: "tf.Identity" + %key, %key_control = tf_executor.island wraps "tf.Const"() {value = dense<"">: !tf_str} : () -> !tf_str + // CHECK: [[exe_ctrl1:%.*]] = tf_executor.island([[ctrl1]]) wraps "tf.TPUExecuteAndUpdateVariables" + %exe_control1 = tf_executor.island wraps "tf.TPUExecuteAndUpdateVariables"(%arg2, %arg0, %arg1, %key) { + device_var_reads_indices = [1, 2], + device_var_updates_indices = [1, 2], + device = "task:0" + } : (tensor, !tf_res, !tf_res, !tf_str) -> () + + // CHECK: [[exe_ctrl2:%.*]] = tf_executor.island([[ctrl2]]) wraps "tf.TPUExecuteAndUpdateVariables" + %exe_control2 = tf_executor.island wraps "tf.TPUExecuteAndUpdateVariables"(%arg2, %arg0, %arg1, %key) { + device_var_reads_indices = [1, 2], + device_var_updates_indices = [1, 2], + device = "task:1" + } : (tensor, !tf_res, !tf_res, !tf_str) -> () + + %assign_control_0 = tf_executor.island wraps "tf.AssignVariableOp"(%arg0, %arg2) : (tensor>>, tensor) -> () + %assign_control_1 = tf_executor.island wraps "tf.AssignVariableOp"(%arg1, %arg2) : (tensor>>, tensor) -> () + %add, %add_control = tf_executor.island wraps "tf.Add"(%arg2, %arg2) : (tensor, tensor) -> tensor + %mul, %mul_control = tf_executor.island wraps "tf.Mul"(%arg2, %arg2) : (tensor, tensor) -> tensor + %control_barrier = tf_executor.island(%assign_control_0, %assign_control_1, %add_control, + %exe_control1, %exe_control2) wraps "tf.NoOp"() : () -> () + // CHECK: tf_executor.island([[exe_ctrl1]]) wraps "tf.Identity" + // CHECK: tf_executor.island([[exe_ctrl2]]) wraps "tf.Identity" + // CHECK: tf_executor.fetch + tf_executor.fetch %arg0, %arg1, %add, %control_barrier, %mul_control : tensor>>, tensor>>, tensor, !tf_executor.control, !tf_executor.control + } + func.return %graph#0, %graph#1, %graph#2 : !tf_res, !tf_res, tensor +} + +// CHECK-LABEL: @double_tpu_execute_while_cond +func.func @double_tpu_execute_while_cond(%arg0: !tf_res, %arg1: !tf_res, %arg2: tensor) -> (tensor) { + %graph = tf_executor.graph { + %island, %ctrl = tf_executor.island { + %pred = "tf.SomeOp"(%arg2) : (tensor) -> tensor + tf_executor.yield %pred : tensor + } + tf_executor.fetch %island : tensor + } + func.return %graph : tensor +} + +// CHECK-LABEL: @double_tpu_execute +func.func @double_tpu_execute(%arg0: !tf_res {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}, + %arg1: !tf_res {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}, %arg2: tensor) { + tf_executor.graph { + // CHECK: "tf.Const"{{.*}}tensor + %while_out:3, %control_while = tf_executor.island wraps "tf.While"(%arg0, %arg1, %arg2) + {body = @double_tpu_execute_while_body, + cond = @double_tpu_execute_while_cond, is_stateless = false} + : (tensor>>, tensor>>, tensor) + -> (tensor>>, tensor>>, tensor) + tf_executor.fetch + } + func.return +} + +// ----- + +!tf_res = tensor>> +!tf_str = tensor<3x!tf_type.string> + +// CHECK-LABEL: @tpu_executes_on_same_device_while_body +func.func @tpu_executes_on_same_device_while_body(%arg0: !tf_res, %arg1: !tf_res, + %arg2: tensor) + -> (!tf_res, !tf_res, tensor) { + %graph:3 = tf_executor.graph { + // CHECK: "tf.Identity" + // CHECK: {{.*}}, [[id_ctrl:%.*]] = tf_executor.island wraps "tf.Identity" + // CHECK: "tf.Identity" + %key, %key_control = tf_executor.island wraps "tf.Const"() {value = dense<"">: !tf_str} : () -> !tf_str + // CHECK: [[exe_ctrl1:%.*]] = tf_executor.island([[id_ctrl]]) wraps "tf.TPUExecuteAndUpdateVariables" + %exe_control1 = tf_executor.island wraps "tf.TPUExecuteAndUpdateVariables"(%arg2, %arg0, %arg1, %key) { + device_var_reads_indices = [1, 2], + device_var_updates_indices = [1, 2], + device = "task:0" + } : (tensor, !tf_res, !tf_res, !tf_str) -> () + + // CHECK: [[exe_ctrl2:%.*]] = tf_executor.island([[id_ctrl]]) wraps "tf.TPUExecuteAndUpdateVariables" + %exe_control2 = tf_executor.island wraps "tf.TPUExecuteAndUpdateVariables"(%arg2, %arg0, %arg1, %key) { + device_var_reads_indices = [1, 2], + device_var_updates_indices = [1, 2], + device = "task:0" + } : (tensor, !tf_res, !tf_res, !tf_str) -> () + + %assign_control_0 = tf_executor.island wraps "tf.AssignVariableOp"(%arg0, %arg2) : (tensor>>, tensor) -> () + %assign_control_1 = tf_executor.island wraps "tf.AssignVariableOp"(%arg1, %arg2) : (tensor>>, tensor) -> () + %add, %add_control = tf_executor.island wraps "tf.Add"(%arg2, %arg2) : (tensor, tensor) -> tensor + %mul, %mul_control = tf_executor.island wraps "tf.Mul"(%arg2, %arg2) : (tensor, tensor) -> tensor + %control_barrier = tf_executor.island(%assign_control_0, %assign_control_1, %add_control, + %exe_control1, %exe_control2) wraps "tf.NoOp"() : () -> () + // CHECK: "tf.Identity"(%arg3) + // CHECK: tf_executor.island([[exe_ctrl1]], [[exe_ctrl2]]) wraps "tf.Identity" + // CHECK: "tf.Identity"(%arg5) + // CHECK-NEXT: tf_executor.fetch + tf_executor.fetch %arg0, %arg1, %add, %control_barrier, %mul_control : tensor>>, tensor>>, tensor, !tf_executor.control, !tf_executor.control + } + func.return %graph#0, %graph#1, %graph#2 : !tf_res, !tf_res, tensor +} + +// CHECK-LABEL: @tpu_executes_on_same_device_while_cond +func.func @tpu_executes_on_same_device_while_cond(%arg0: !tf_res, %arg1: !tf_res, %arg2: tensor) -> (tensor) { + %graph = tf_executor.graph { + %island, %ctrl = tf_executor.island { + %pred = "tf.SomeOp"(%arg2) : (tensor) -> tensor + tf_executor.yield %pred : tensor + } + tf_executor.fetch %island : tensor + } + func.return %graph : tensor +} + +// CHECK-LABEL: @tpu_executes_on_same_device +func.func @tpu_executes_on_same_device(%arg0: !tf_res {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}, + %arg1: !tf_res {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}, %arg2: tensor) { + tf_executor.graph { + // CHECK: "tf.Const"{{.*}}tensor + %while_out:3, %control_while = tf_executor.island wraps "tf.While"(%arg0, %arg1, %arg2) + {body = @tpu_executes_on_same_device_while_body, + cond = @tpu_executes_on_same_device_while_cond, is_stateless = false} + : (tensor>>, tensor>>, tensor) + -> (tensor>>, tensor>>, tensor) + tf_executor.fetch + } + func.return +} + +// ----- + +!tf_res = tensor>> +!tf_str = tensor<3x!tf_type.string> + +// CHECK-LABEL: @tpu_execute_and_assign_variable_while_body +func.func @tpu_execute_and_assign_variable_while_body(%arg0: !tf_res, %arg1: !tf_res, + %arg2: tensor) + -> (!tf_res, !tf_res, tensor) { + %graph:3 = tf_executor.graph { + // CHECK: tf.Identity + // CHECK-NOT: tf.Identity + // CHECK: TPUExecuteAndUpdate + %key, %key_control = tf_executor.island wraps "tf.Const"() {value = dense<"">: !tf_str} : () -> !tf_str + %exe_control = tf_executor.island wraps "tf.TPUExecuteAndUpdateVariables"(%arg0, %arg1, %key) { + device_var_reads_indices = [1, 2], + device_var_updates_indices = [1, 2], + device = "task:0" + } : (!tf_res, !tf_res, !tf_str) -> () + + // CHECK: AssignVariableOp + %assign_control = tf_executor.island wraps "tf.AssignVariableOp"(%arg0, %arg2) { + device = "task:0" + } : (tensor>>, tensor) -> () + // CHECK: tf.Identity + // CHECK-NOT: tf.Identity + %control_barrier = tf_executor.island(%assign_control, %exe_control) wraps "tf.NoOp"() : () -> () + // CHECK: fetch + tf_executor.fetch %arg0, %arg1, %arg2, %control_barrier : !tf_res, !tf_res, tensor, !tf_executor.control + } + func.return %graph#0, %graph#1, %graph#2 : !tf_res, !tf_res, tensor +} + +// CHECK-LABEL: @tpu_execute_and_assign_variable_while_cond +func.func @tpu_execute_and_assign_variable_while_cond(%arg0: !tf_res, %arg1: !tf_res, %arg2: tensor) -> (tensor) { + %graph = tf_executor.graph { + %island, %ctrl = tf_executor.island { + %pred = "tf.SomeOp"(%arg2) : (tensor) -> tensor + tf_executor.yield %pred : tensor + } + tf_executor.fetch %island : tensor + } + func.return %graph : tensor +} + +// CHECK-LABEL: @tpu_execute_and_assign_variable +func.func @tpu_execute_and_assign_variable(%arg0: !tf_res {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}, + %arg1: !tf_res {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}, %arg2: tensor) { + tf_executor.graph { + // CHECK: "tf.Const"{{.*}}tensor + %while_out:3, %control_while = tf_executor.island wraps "tf.While"(%arg0, %arg1, %arg2) + {body = @tpu_execute_and_assign_variable_while_body, + cond = @tpu_execute_and_assign_variable_while_cond, is_stateless = false} + : (tensor>>, tensor>>, tensor) + -> (tensor>>, tensor>>, tensor) + tf_executor.fetch + } + func.return +} + +// ----- + +!tf_res = tensor>> +!tf_str = tensor<3x!tf_type.string> + +// CHECK-LABEL: @tpu_execute_and_assign_variable_on_different_devices_while_body +func.func @tpu_execute_and_assign_variable_on_different_devices_while_body(%arg0: !tf_res, %arg1: !tf_res, + %arg2: tensor) + -> (!tf_res, !tf_res, tensor) { + %graph:3 = tf_executor.graph { + // CHECK: {{.*}}, [[ctrl1:%.*]] = tf_executor.island wraps "tf.Identity" + // CHECK: {{.*}}, [[ctrl2:%.*]] = tf_executor.island wraps "tf.Identity" + // CHECK-NOT: tf.Identity + // CHECK: [[exe_ctrl:%.*]] = tf_executor.island([[ctrl1]]) wraps "tf.TPUExecuteAndUpdateVariables" + %key, %key_control = tf_executor.island wraps "tf.Const"() {value = dense<"">: !tf_str} : () -> !tf_str + %exe_control = tf_executor.island wraps "tf.TPUExecuteAndUpdateVariables"(%arg0, %arg1, %key) { + device_var_reads_indices = [1, 2], + device_var_updates_indices = [1, 2], + device = "task:0" + } : (!tf_res, !tf_res, !tf_str) -> () + + // CHECK: [[assign_ctrl:%.*]] = tf_executor.island([[ctrl2]]) wraps "tf.AssignVariableOp" + %assign_control = tf_executor.island wraps "tf.AssignVariableOp"(%arg0, %arg2) { + device = "task:1" + } : (tensor>>, tensor) -> () + // CHECK-DAG: tf_executor.island([[exe_ctrl]]) wraps "tf.Identity" + // CHECK-DAG: tf_executor.island([[assign_ctrl]]) wraps "tf.Identity" + // CHECK-NOT: tf.Identity + %control_barrier = tf_executor.island(%assign_control, %exe_control) wraps "tf.NoOp"() : () -> () + // CHECK: fetch + tf_executor.fetch %arg0, %arg1, %arg2, %control_barrier : !tf_res, !tf_res, tensor, !tf_executor.control + } + func.return %graph#0, %graph#1, %graph#2 : !tf_res, !tf_res, tensor +} + +// CHECK-LABEL: @tpu_execute_and_assign_variable_on_different_devices_while_cond +func.func @tpu_execute_and_assign_variable_on_different_devices_while_cond(%arg0: !tf_res, %arg1: !tf_res, %arg2: tensor) -> (tensor) { + %graph = tf_executor.graph { + %island, %ctrl = tf_executor.island { + %pred = "tf.SomeOp"(%arg2) : (tensor) -> tensor + tf_executor.yield %pred : tensor + } + tf_executor.fetch %island : tensor + } + func.return %graph : tensor +} + +// CHECK-LABEL: @tpu_execute_and_assign_variable_on_different_devices +func.func @tpu_execute_and_assign_variable_on_different_devices(%arg0: !tf_res {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}, + %arg1: !tf_res {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}, %arg2: tensor) { + tf_executor.graph { + // CHECK: "tf.Const"{{.*}}tensor + %while_out:3, %control_while = tf_executor.island wraps "tf.While"(%arg0, %arg1, %arg2) + {body = @tpu_execute_and_assign_variable_on_different_devices_while_body, + cond = @tpu_execute_and_assign_variable_on_different_devices_while_cond, is_stateless = false} + : (tensor>>, tensor>>, tensor) + -> (tensor>>, tensor>>, tensor) + tf_executor.fetch + } + func.return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/hoist_broadcast_read.mlir b/tensorflow/compiler/mlir/tensorflow/tests/hoist_broadcast_read.mlir new file mode 100644 index 00000000000000..c88310443ed0d1 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/hoist_broadcast_read.mlir @@ -0,0 +1,68 @@ +// RUN: tf-opt %s -split-input-file -tf-hoist-broadcast-read | FileCheck %s + +// The read should be hoisted. + +// CHECK-LABEL: func @hoist_cpu +func.func @hoist_cpu(%arg0: tensor<*x!tf_type.resource>> {tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}) -> () { + // CHECK: %[[READ:.*]] = "tf.ReadVariableOp" + // CHECK-NEXT: tf_device.replicate + // CHECK-NEXT: "tf.OpA"(%[[READ]]) + tf_device.replicate {n = 2 : i32} { + %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource>>) -> tensor + "tf.OpA"(%0) : (tensor) -> () + } + func.return +} + +// ----- + +// The read should not be hoisted because the resource does not have device type CPU. + +// CHECK-LABEL: func @only_hoist_cpu +func.func @only_hoist_cpu(%arg0: tensor<*x!tf_type.resource>>) -> () { + // CHECK: tf_device.replicate + // CHECK-NEXT: "tf.ReadVariableOp" + tf_device.replicate {n = 2 : i32} { + %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource>>) -> tensor + "tf.OpA"(%0) : (tensor) -> () + } + func.return +} + +// ----- + +// The read should not be hoisted because it follows a write. + +// CHECK-LABEL: func @skip_read_after_write +func.func @skip_read_after_write(%arg0: tensor<*x!tf_type.resource>> {tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}) -> () { + // CHECK: tf_device.replicate + // CHECK: "tf.AssignVariableOp" + // CHECK-NEXT: "tf.ReadVariableOp" + tf_device.replicate {n = 2 : i32} { + %0 = "tf.OpA"() : () -> tensor + "tf.AssignVariableOp"(%arg0, %0) : (tensor<*x!tf_type.resource>>, tensor) -> () + %1 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource>>) -> tensor + "tf.OpB"(%1) : (tensor) -> () + } + func.return +} + +// ----- + +// Check that hoisting preserves read order. + +// CHECK-LABEL: func @order_preserved +func.func @order_preserved(%arg0: tensor<*x!tf_type.resource>> {tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg1: tensor<*x!tf_type.resource>>, %arg2: tensor<*x!tf_type.resource>> {tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}) -> () { + // CHECK: %[[READ0:.*]] = "tf.ReadVariableOp"(%arg0) + // CHECK-NEXT: %[[READ2:.*]] = "tf.ReadVariableOp"(%arg2) + // CHECK-NEXT: tf_device.replicate + // CHECK-NEXT: %[[READ1:.*]] = "tf.ReadVariableOp"(%arg1) + // CHECK-NEXT: "tf.OpA"(%[[READ0]], %[[READ1]], %[[READ2]]) + tf_device.replicate {n = 2 : i32} { + %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource>>) -> tensor + %1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf_type.resource>>) -> tensor + %2 = "tf.ReadVariableOp"(%arg2) : (tensor<*x!tf_type.resource>>) -> tensor + "tf.OpA"(%0, %1, %2) : (tensor, tensor, tensor) -> () + } + func.return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlprogram.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlprogram.mlir index 6ded59b51ad8d4..72d7584b94e5d8 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlprogram.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlprogram.mlir @@ -2,7 +2,7 @@ module attributes {tf_saved_model.semantics} { // CHECK-LABEL: func @lowers_to_stablehlo - func.func @lowers_to_stablehlo(%arg0: tensor {tf_saved_model.index_path = []}) -> (tensor<*xi32> {tf_saved_model.index_path = []}) + func.func @lowers_to_stablehlo(%arg0: tensor {tf_saved_model.index_path = []}) -> (tensor {tf_saved_model.index_path = []}) attributes {tf_saved_model.exported_names = ["lowers_to_stablehlo"]} { // CHECK-DAG: [[one:%.*]] = stablehlo.constant dense<1> @@ -18,12 +18,12 @@ module attributes {tf_saved_model.semantics} { %outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", value = dense<1> : tensor} : () -> tensor %outputs_0, %control_1 = tf_executor.island wraps "tf.Const"() {device = "", value = dense<20> : tensor} : () -> tensor %outputs_2, %control_3 = tf_executor.island wraps "tf.Const"() {device = "", value = dense<0> : tensor} : () -> tensor - %outputs_4, %control_5 = tf_executor.island wraps "tf.Range"(%outputs_2, %outputs_0, %outputs) {device = ""} : (tensor, tensor, tensor) -> tensor<*xi32> + %outputs_4, %control_5 = tf_executor.island wraps "tf.Range"(%outputs_2, %outputs_0, %outputs) {device = ""} : (tensor, tensor, tensor) -> tensor %outputs_6, %control_7 = tf_executor.island wraps "tf.Sub"(%outputs_0, %arg0) {device = ""} : (tensor, tensor) -> tensor - %outputs_8, %control_9 = tf_executor.island wraps "tf.FloorDiv"(%outputs_6, %outputs) {device = ""} : (tensor, tensor) -> tensor<*xi32> - tf_executor.fetch %outputs_8 : tensor<*xi32> + %outputs_8, %control_9 = tf_executor.island wraps "tf.FloorDiv"(%outputs_6, %outputs) {device = ""} : (tensor, tensor) -> tensor + tf_executor.fetch %outputs_8 : tensor } - func.return %0 : tensor<*xi32> + func.return %0 : tensor } } @@ -31,7 +31,7 @@ module attributes {tf_saved_model.semantics} { module attributes {tf_saved_model.semantics} { // CHECK-LABEL: func @removes_dead_code - func.func @removes_dead_code(%arg0: tensor<*x!tf_type.resource> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0", tf_saved_model.index_path = []}) + func.func @removes_dead_code(%arg0: tensor {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0", tf_saved_model.index_path = []}) attributes {tf_saved_model.exported_names = ["removes_dead_code"]} { // CHECK-NEXT: return @@ -39,9 +39,9 @@ module attributes {tf_saved_model.semantics} { %outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", value = dense<1> : tensor} : () -> tensor %outputs_0, %control_1 = tf_executor.island wraps "tf.Const"() {device = "", value = dense<20> : tensor} : () -> tensor %outputs_2, %control_3 = tf_executor.island wraps "tf.Const"() {device = "", value = dense<0> : tensor} : () -> tensor - %outputs_4, %control_5 = tf_executor.island wraps "tf.Range"(%outputs_2, %outputs_0, %outputs) {device = ""} : (tensor, tensor, tensor) -> tensor<*xi32> - %outputs_6, %control_7 = tf_executor.island wraps "tf.Sub"(%outputs_0, %outputs_2) {device = ""} : (tensor, tensor) -> tensor<*xi32> - %outputs_8, %control_9 = tf_executor.island wraps "tf.FloorDiv"(%outputs_6, %outputs) {device = ""} : (tensor<*xi32>, tensor) -> tensor<*xi32> + %outputs_4, %control_5 = tf_executor.island wraps "tf.Range"(%outputs_2, %outputs_0, %outputs) {device = ""} : (tensor, tensor, tensor) -> tensor + %outputs_6, %control_7 = tf_executor.island wraps "tf.Sub"(%outputs_0, %outputs_2) {device = ""} : (tensor, tensor) -> tensor + %outputs_8, %control_9 = tf_executor.island wraps "tf.FloorDiv"(%outputs_6, %outputs) {device = ""} : (tensor, tensor) -> tensor tf_executor.fetch %control_9 : !tf_executor.control } return diff --git a/tensorflow/compiler/mlir/tensorflow/tests/prepare_tpu_computation_for_tf_export.mlir b/tensorflow/compiler/mlir/tensorflow/tests/prepare_tpu_computation_for_tf_export.mlir index 021cad3b78be8f..389a682d3afe46 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/prepare_tpu_computation_for_tf_export.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/prepare_tpu_computation_for_tf_export.mlir @@ -167,9 +167,25 @@ func.func @UnsupportedOp(%arg0: tensor) -> tensor { // _XlaHostComputeMlir with manual_sharding should not fall back to // XlaHostCompute, because XlaHostCompute does not support manual_sharding. +// Instead, it is skipped and the MlirXlaOpKernel is expected to handle it. func.func @HostComputeManualNoFallback(%arg0: tensor) -> () { - // expected-error @+1 {{manual_sharding not supported with fallback}} + // CHECK: "tf._XlaHostComputeMlir" %1 = "tf._XlaHostComputeMlir"(%arg0) {recv_key = "host_compute_channel_recv1", send_key = "host_compute_channel_send1", host_mlir_module = "", manual_sharding = true} : (tensor) -> (tensor) func.return } + +// ----- + +// CHECK-LABEL: test_xla_call_module_with_host_communicative_subcomputation +func.func @test_xla_call_module_with_host_communicative_subcomputation() { + "tf.XlaCallModule"() {Sout = [], device = "", dim_args_spec = [], function_list = [@callee], module = "", platforms = [], version = 4 : i64} : () -> () + func.return +} + +// CHECK-LABEL: callee +func.func private @callee(%arg0: tensor) { + "tf.XlaHostCompute"(%arg0) <{ancestors = [], key = "@host_func", recv_key = "", send_key = "", shapes = []}> {_xla_original_oc_node_name = "hcb0", _xla_token_input_nodes = ["_xla_token_arg_node"]} : (tensor) -> () + return + } + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir b/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir index 322800014ffd16..479df14b6546f4 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir @@ -990,3 +990,29 @@ func.func @testNameCollision(%arg0: tensor) { }) : (tensor, tensor) -> (tensor, tensor) return } + +// ----- + +func.func private @my_cond(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = builtin.unrealized_conversion_cast to tensor + return %0 : tensor +} +func.func private @my_body(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + return %arg0, %arg1 : tensor, tensor +} +// CHECK-LABEL: testConditionWithPassthroughArgs +func.func @testConditionWithPassthroughArgs(%arg1: tensor, %arg2: tensor) { + // CHECK: "tf.While" + // CHECK-SAME: body = @my_body + // CHECK-SAME: cond = @my_cond + %3:2 = "tf.WhileRegion"(%arg1, %arg2) <{is_stateless = false}> ({ + ^bb0(%barg1: tensor, %barg2: tensor): + %8 = func.call @my_cond(%barg1, %barg2) : (tensor, tensor) -> tensor + "tf.Yield"(%8, %barg1, %barg2) : (tensor, tensor, tensor) -> () + }, { + ^bb0(%barg1: tensor, %barg2: tensor): + %r1, %r2 = func.call @my_body(%barg1, %barg2) : (tensor, tensor) -> (tensor, tensor) + "tf.Yield"(%r1, %r2) : (tensor, tensor) -> () + }) : (tensor, tensor) -> (tensor, tensor) + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/remove_unused_arguments.mlir b/tensorflow/compiler/mlir/tensorflow/tests/remove_unused_arguments.mlir index e4c1941a1c202f..c3608a2fb13145 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/remove_unused_arguments.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/remove_unused_arguments.mlir @@ -175,11 +175,13 @@ func.func @can_remove_all_results(%arg0: f32) -> (f32, f32) { // CHECK-LABEL: @has_inner_function func.func private @has_inner_function(%arg0: f32) -> (f32, f32) { - func.func private @inner() -> (tensor, tensor) { - %0, %1 = "some_constant"() : () -> (tensor, tensor) - // CHECK: return - // CHECK-SAME: tensor, tensor - return %0, %1 : tensor, tensor + builtin.module { + func.func private @inner() -> (tensor, tensor) { + %0, %1 = "some_constant"() : () -> (tensor, tensor) + // CHECK: return + // CHECK-SAME: tensor, tensor + return %0, %1 : tensor, tensor + } } // CHECK: return // CHECK-NOT: arg diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir index 9c0f9b2eddb29b..3e7f029d8ff864 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir @@ -161,3 +161,70 @@ module attributes {tf_saved_model.semantics} { // tf_saved_model.semantics. // CHECK-LABEL: module module {} + +// ----- + +!tf_res = tensor>> +!tf_str = tensor<3x!tf_type.string> + +// CHECK-LABEL: module attributes +module attributes {tf_saved_model.semantics} { + +"tf_saved_model.global_tensor"() {sym_name = "v1", type = tensor, value = dense<3.0> : tensor } : () -> () +"tf_saved_model.global_tensor"() {sym_name = "v2", type = tensor, value = dense<2.0> : tensor } : () -> () + +// CHECK-LABEL: @body +func.func private @body(%arg0: !tf_res, %arg1: !tf_res) -> (!tf_res, !tf_res) { + %graph:2 = tf_executor.graph { + %value, %value_control = tf_executor.island wraps "tf.GetKey"() : () -> tensor + %ret0, %ret0_control = tf_executor.island wraps "tf.SomeOp"() : () -> !tf_res + %ret1, %ret1_control = tf_executor.island wraps "tf.SomeOp"() : () -> !tf_res + %control_unknown = tf_executor.island wraps "tf.UnknownOp"() : () -> () + %key, %key_control = tf_executor.island wraps "tf.GetKey"() : () -> !tf_str + // CHECK: "tf.ReadVariableOp"(%arg0) + %read1, %read1_control = tf_executor.island wraps "tf.ReadVariableOp"(%arg0) : (!tf_res) -> tensor + // CHECK: "tf.ReadVariableOp"(%arg1) + %read2, %read2_control = tf_executor.island wraps "tf.ReadVariableOp"(%arg1) : (!tf_res) -> tensor + tf_executor.island wraps "tf.TPUExecuteAndUpdateVariables"(%ret0, %ret1, %key) { + device_var_reads_indices = [0, 1], + device_var_updates_indices = [0, 1]} : (!tf_res, !tf_res, !tf_str) -> () + tf_executor.fetch %ret0, %ret1: !tf_res, !tf_res + } + func.return %graph#0, %graph#1 : !tf_res, !tf_res +} + +// CHECK-LABEL: @cond +func.func private @cond(%arg0: !tf_res, %arg1: !tf_res) -> (tensor) { + %graph = tf_executor.graph { + %island, %ctrl = tf_executor.island { + %pred = "tf.SomeOp"() : () -> tensor + tf_executor.yield %pred : tensor + } + tf_executor.fetch %island : tensor + } + func.return %graph : tensor +} + +// CHECK-LABEL: @test_while_loop +func.func @test_while_loop(%arg0: !tf_res {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0", tf_saved_model.bound_input = @v1}, + %arg1: !tf_res {tf_saved_model.bound_input = @v2}) + attributes {tf_saved_model.exported_names = ["test_while_loop"]} { + // CHECK-DAG: Const{{.*}}2.0 + // CHECK-DAG: Const{{.*}}3.0 + %read1 = "tf.ReadVariableOp"(%arg0) : (!tf_res) -> tensor + %read2 = "tf.ReadVariableOp"(%arg1) : (!tf_res) -> tensor + // CHECK: tf_executor.graph + tf_executor.graph { + %handle0, %handle0_control = tf_executor.island wraps "tf.SomeOp"() : () -> !tf_res + %handle1, %handle1_control = tf_executor.island wraps "tf.SomeOp"() : () -> !tf_res + %control_A = tf_executor.island wraps "tf.OpA"() : () -> () + %while_out:2, %while_control = tf_executor.island(%control_A) wraps "tf.While"( + %handle0, %handle1) { + body = @body, cond = @cond, is_stateless = false + } : (tensor>>, tensor>>) -> (tensor>>, tensor>>) + %control_B = tf_executor.island(%while_control) wraps "tf.OpB"() : () -> () + tf_executor.fetch + } + func.return +} +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tfrt_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tfrt_ops.mlir index 3fb11e56172276..3bd1677bc30baa 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tfrt_ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tfrt_ops.mlir @@ -13,3 +13,15 @@ func.func @testPwStreamResults(%arg0: tensor, %arg1: tensor) { } // ----- +// CHECK-LABEL: func @test_ifrt_call +func.func @test_ifrt_call(%arg0: tensor, %arg1: tensor) { + %result = "tf.IfrtCall"(%arg0, %arg1) <{program_id = 1234 : i64, variable_arg_indices = [0 : i32, 1 : i32], variable_names = ["a", "b"]}> : (tensor, tensor) -> (tensor<1x1xf32>) + func.return +} + +// ----- +func.func @test_ifrt_call_fail_unsorted_variable_arg_indices(%arg0: tensor, %arg1: tensor) { + // expected-error@below {{variable_arg_indices must be sorted in ascending order}} + %result = "tf.IfrtCall"(%arg0, %arg1) <{program_id = 1234 : i64, variable_arg_indices = [1 : i32, 0 : i32], variable_names = ["a", "b"]}> : (tensor, tensor) -> (tensor<1x1xf32>) + func.return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir index c7443b4ec00f62..42ebe73490818f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir @@ -174,6 +174,104 @@ func.func @func_with_sharding_after_read_variable(%arg0: tensor<*x!tf_type.resou // ----- +// Tests sharding propagation in while region body. +// CHECK-LABEL: func @check_sharding_for_read_variable_inside_while_body +func.func @check_sharding_for_read_variable_inside_while_body(%arg0 : tensor, %arg1: tensor<*x!tf_type.resource>>) { + %0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf_type.resource>>) -> tensor<128x1024xf32> + %1:1 = "tf_device.cluster_func"(%arg0, %0) {func = @func_with_sharding_inside_while_body, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", num_cores_per_replica = 2 : i64, use_spmd_for_xla_partitioning = true, use_tpu = true} : (tensor, tensor<128x1024xf32>) -> (tensor<128x1024xf32>) + // CHECK: input_sharding_configuration + // CHECK-SAME: ["", "\0A\0B\0C"] + // CHECK: output_sharding_configuration + // CHECK-SAME: ["\0D\0E\0F"] + func.return +} + +// CHECK-LABEL: func @func_with_sharding_inside_while_body +// CHECK-SAME: (%{{[a-z0-9]+}}: tensor {mhlo.sharding = ""}, %{{[a-z0-9]+}}: tensor<128x1024xf32> {mhlo.sharding = "\0A\0B\0C"}) +// CHECK-SAME: -> (tensor<128x1024xf32> {mhlo.sharding = "\0D\0E\0F"}) +func.func @func_with_sharding_inside_while_body(%arg0: tensor, %arg1: tensor<128x1024xf32>) -> (tensor<128x1024xf32>) { + %cst = "tf.Const"() <{value = dense<0> : tensor}> {device = ""} : () -> tensor + %0:2 = "tf.WhileRegion"(%cst, %arg1) <{is_stateless = false, parallel_iterations = 1 : i64}> ({ + ^bb0(%arg2: tensor, %arg3: tensor<128x1024xf32>): + %1 = "tf.Less"(%arg2, %arg0) : (tensor, tensor) -> tensor + "tf.Yield"(%1) : (tensor) -> () + }, { + ^bb0(%arg2: tensor, %arg3: tensor<128x1024xf32>): + %1 = "tf.XlaSharding"(%arg3) <{_XlaSharding = "\0A\0B\0C", sharding = "\0A\0B\0C"}> {unspecified_dims = []} : (tensor<128x1024xf32>) -> tensor<128x1024xf32> + %2 = "tf.Square"(%1) : (tensor<128x1024xf32>) -> tensor<128x1024xf32> + "tf.Yield"(%arg2, %2) : (tensor, tensor<128x1024xf32>) -> () + }) {_num_original_outputs = 1 : i64, _read_only_resource_inputs = [1], _xla_propagate_compile_time_consts = true} : (tensor, tensor<128x1024xf32>) -> (tensor, tensor<128x1024xf32>) + %1 = "tf.XlaSharding"(%0#1) { _XlaSharding = "\0D\0E\0F", sharding = "\0D\0E\0F" } : (tensor<128x1024xf32>) -> tensor<128x1024xf32> + func.return %1 : tensor<128x1024xf32> +} + +// ----- + +// Tests sharding propagation in while region condition. +// CHECK-LABEL: func @check_sharding_for_read_variable_inside_while_cond +func.func @check_sharding_for_read_variable_inside_while_cond(%arg0 : tensor, %arg1: tensor<*x!tf_type.resource>>) { + %0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf_type.resource>>) -> tensor<128x1024xf32> + %1:1 = "tf_device.cluster_func"(%arg0, %0) {func = @func_with_sharding_inside_while_cond, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", num_cores_per_replica = 2 : i64, use_spmd_for_xla_partitioning = true, use_tpu = true} : (tensor, tensor<128x1024xf32>) -> (tensor<128x1024xf32>) + // CHECK: input_sharding_configuration + // CHECK-SAME: ["", "\0A\0B\0C"] + // CHECK: output_sharding_configuration + // CHECK-SAME: ["\0D\0E\0F"] + func.return +} + +// CHECK-LABEL: func @func_with_sharding_inside_while_cond +// CHECK-SAME: (%{{[a-z0-9]+}}: tensor {mhlo.sharding = ""}, %{{[a-z0-9]+}}: tensor<128x1024xf32> {mhlo.sharding = "\0A\0B\0C"}) +// CHECK-SAME: -> (tensor<128x1024xf32> {mhlo.sharding = "\0D\0E\0F"}) +func.func @func_with_sharding_inside_while_cond(%arg0: tensor, %arg1: tensor<128x1024xf32>) -> (tensor<128x1024xf32>) { + %cst = "tf.Const"() <{value = dense<0> : tensor}> {device = ""} : () -> tensor + %0:2 = "tf.WhileRegion"(%cst, %arg1) <{is_stateless = false, parallel_iterations = 1 : i64}> ({ + ^bb0(%arg2: tensor, %arg3: tensor<128x1024xf32>): + %1 = "tf.XlaSharding"(%arg3) <{_XlaSharding = "\0A\0B\0C", sharding = "\0A\0B\0C"}> {unspecified_dims = []} : (tensor<128x1024xf32>) -> tensor<128x1024xf32> + %2 = "tf.Less"(%arg2, %arg0) : (tensor, tensor) -> tensor + "tf.Yield"(%2) : (tensor) -> () + }, { + ^bb0(%arg2: tensor, %arg3: tensor<128x1024xf32>): + %1 = "tf.Square"(%arg3) : (tensor<128x1024xf32>) -> tensor<128x1024xf32> + "tf.Yield"(%arg2, %1) : (tensor, tensor<128x1024xf32>) -> () + }) {_num_original_outputs = 1 : i64, _read_only_resource_inputs = [1], _xla_propagate_compile_time_consts = true} : (tensor, tensor<128x1024xf32>) -> (tensor, tensor<128x1024xf32>) + %1 = "tf.XlaSharding"(%0#1) { _XlaSharding = "\0D\0E\0F", sharding = "\0D\0E\0F" } : (tensor<128x1024xf32>) -> tensor<128x1024xf32> + func.return %1 : tensor<128x1024xf32> +} + +// ----- + +// Tests output sharding propagation in while region body. +// CHECK-LABEL: func @check_output_sharding_for_while_region_op +func.func @check_output_sharding_for_while_region_op(%arg0 : tensor, %arg1: tensor<*x!tf_type.resource>>) { + %0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf_type.resource>>) -> tensor<128x1024xf32> + %1:1 = "tf_device.cluster_func"(%arg0, %0) {func = @func_with_sharded_while_region_op_output, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", num_cores_per_replica = 2 : i64, use_spmd_for_xla_partitioning = true, use_tpu = true} : (tensor, tensor<128x1024xf32>) -> (tensor<128x1024xf32>) + // CHECK: input_sharding_configuration + // CHECK-SAME: ["", ""] + // CHECK: output_sharding_configuration + // CHECK-SAME: ["\0D\0E\0F"] + func.return +} + +// CHECK-LABEL: func @func_with_sharded_while_region_op_output +// CHECK-SAME: (%{{[a-z0-9]+}}: tensor {mhlo.sharding = ""}, %{{[a-z0-9]+}}: tensor<128x1024xf32> {mhlo.sharding = ""}) +// CHECK-SAME: -> (tensor<128x1024xf32> {mhlo.sharding = "\0D\0E\0F"}) +func.func @func_with_sharded_while_region_op_output(%arg0: tensor, %arg1: tensor<128x1024xf32>) -> (tensor<128x1024xf32>) { + %cst = "tf.Const"() <{value = dense<0> : tensor}> {device = ""} : () -> tensor + %0:2 = "tf.WhileRegion"(%cst, %arg1) <{is_stateless = false, parallel_iterations = 1 : i64}> ({ + ^bb0(%arg2: tensor, %arg3: tensor<128x1024xf32>): + %1 = "tf.Less"(%arg2, %arg0) : (tensor, tensor) -> tensor + "tf.Yield"(%1) : (tensor) -> () + }, { + ^bb0(%arg2: tensor, %arg3: tensor<128x1024xf32>): + %1 = "tf.Square"(%arg3) : (tensor<128x1024xf32>) -> tensor<128x1024xf32> + %2 = "tf.XlaSharding"(%1) { _XlaSharding = "\0D\0E\0F", sharding = "\0D\0E\0F" } : (tensor<128x1024xf32>) -> tensor<128x1024xf32> + "tf.Yield"(%arg2, %2) : (tensor, tensor<128x1024xf32>) -> () + }) {_num_original_outputs = 1 : i64, _read_only_resource_inputs = [1], _xla_propagate_compile_time_consts = true} : (tensor, tensor<128x1024xf32>) -> (tensor, tensor<128x1024xf32>) + func.return %0#1 : tensor<128x1024xf32> +} + +// ----- + // Tests with input sharding following an identity op and cast op. // CHECK-LABEL: func @check_sharding_after_cast_op func.func @check_sharding_after_cast_op(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) { @@ -537,6 +635,33 @@ func.func @func(%arg0: tensor<*xi32> {tf.aliasing_output = 1 : i64}, // ----- +// CHECK-LABEL: func @check_symmetric_alias_propagation +func.func @check_symmetric_alias_propagation(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) { + // CHECK: tf_device.cluster_func + // CHECK-SAME: input_sharding_configuration = ["\01\02\03", "\04\05\06"] + // CHECK-SAME: output_sharding_configuration = ["\01\02\03", "\04\05\06"] + "tf_device.cluster_func"(%arg0, %arg1) { + func = @func, + use_spmd_for_xla_partitioning = false, num_cores_per_replica = 1 : i64 + } : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>) + func.return +} + +// CHECK-LABEL: func @func +// CHECK-SAME: %arg0: tensor<*xi32> {mhlo.sharding = "\01\02\03" +// CHECK-SAME: %arg1: tensor<*xi32> {mhlo.sharding = "\04\05\06" +// CHECK-SAME: ->{{.*}}mhlo.sharding = "\01\02\03"{{.*}}mhlo.sharding = "\04\05\06" +func.func @func(%arg0: tensor<*xi32> {tf.aliasing_output = 0 : i64}, + %arg1: tensor<*xi32> {tf.aliasing_output = 1 : i64}) -> (tensor<*xi32>, tensor<*xi32>) { + %0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03"} : (tensor<*xi32>) -> tensor<*xi32> + %1 = "tf.A"(%0) : (tensor<*xi32>) -> (tensor<*xi32>) + %2 = "tf.B"(%arg1) : (tensor<*xi32>) -> (tensor<*xi32>) + %3 = "tf.XlaSharding"(%2) { _XlaSharding = "\04\05\06"} : (tensor<*xi32>) -> tensor<*xi32> + func.return %2, %3 : tensor<*xi32>, tensor<*xi32> +} + +// ----- + // Partial tiled inputs using XlaSharding ops identified as REPLICATED should keep the sharding configuration. // The following xla.OpSharding is used: // Proto debug string: diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_validate_inputs.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_validate_inputs.mlir index 502bb0d8fb0398..4af8cdf06f727f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_validate_inputs.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_validate_inputs.mlir @@ -201,4 +201,30 @@ func.func @invalid_TUPLE_sharding_arity(%arg0: tensor) -> tensor { } return %0 : tensor } +// ----- + +// Serialized string: +// "\08\02\2a\08\08\01\1a\01\01\22\01\00\2a\08\08\01\1a\01\01\22\01\01" +// Proto debug string: +// type: TUPLE +// tuple_shardings { +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 0 +// } +// tuple_shardings { +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 1 +// } + +func.func @outfeed_enqueue_tuple_sharding_exception(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = tf_executor.graph { + %control = tf_executor.island wraps "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_cores_per_replica = 2 : i64, num_replicas = 1 : i64, topology = "topology"} : () -> () + %0, %c0 = tf_executor.island wraps "tf.AddV2"(%arg0, %arg1) {_tpu_replicate = "cluster"} : (tensor, tensor) -> tensor + %c1 = tf_executor.island wraps "tf.OutfeedEnqueueTuple"(%arg0, %arg1) {_tpu_replicate = "cluster", _XlaSharding = "\08\02\2a\08\08\01\1a\01\01\22\01\00\2a\08\08\01\1a\01\01\22\01\01"} : (tensor, tensor) -> () + tf_executor.fetch %0 : tensor + } + return %0 : tensor +} // ----- \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD index 88bf3d1635e653..66b81b0af9bb9a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD @@ -595,6 +595,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", "//tensorflow/compiler/mlir/tensorflow:tensorflow_side_effects", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_traits", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/mlir/tensorflow:tf_ops_layout_helper", "//tensorflow/compiler/mlir/tensorflow:tf_saved_model_inc_gen", @@ -627,11 +628,13 @@ cc_library( "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", "//tensorflow/core/protobuf/tpu:dynamic_padding_proto_cc", "//tensorflow/core/tpu:tpu_embedding_optimization_parameters_utils", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:variant", diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc b/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc index e16082fad89c4c..b528df7d9a5172 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc @@ -14,20 +14,43 @@ limitations under the License. ==============================================================================*/ #include +#include +#include +#include #include +#include +#include +#include +#include "absl/log/log.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/EquivalenceClasses.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/raw_ostream.h" +#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" // from @llvm-project +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" // from @llvm-project +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project +#include "mlir/Analysis/DataFlowFramework.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h" +#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h" #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -40,20 +63,36 @@ namespace tf_executor { namespace { using TF::ResourceId; +using ResourceAndDevice = std::pair; static constexpr ResourceId kUnknownResourceId = TF::detail::ResourceAliasAnalysisInfo::kUnknownResourceId; static constexpr ResourceId kInvalidResourceId = TF::detail::ResourceAliasAnalysisInfo::kInvalidResourceId; using OperationSetTy = SmallPtrSet; -using ResourceToOpsMapTy = DenseMap; +using ResourceToOpsMapTy = DenseMap; +using DeviceMap = DenseMap; +constexpr int64_t kAnyDevice = 0; +constexpr ResourceAndDevice kInvalidResourceAndDevice{kInvalidResourceId, + kAnyDevice}; +constexpr ResourceAndDevice kUnknownResourceAndDevice{kUnknownResourceId, + kAnyDevice}; + +constexpr char kDeviceAttr[] = "device"; #define GEN_PASS_DEF_EXECUTORCONVERTCONTROLTODATAOUTPUTSPASS +#define GEN_PASS_DECL_EXECUTORCONVERTCONTROLTODATAOUTPUTSPASS #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" -class ConvertControlToDataOutputsPass +struct ConvertControlToDataOutputsPass : public impl::ExecutorConvertControlToDataOutputsPassBase< ConvertControlToDataOutputsPass> { - public: + ConvertControlToDataOutputsPass() = default; + explicit ConvertControlToDataOutputsPass( + bool composite_tpuexecute_side_effects) + : ExecutorConvertControlToDataOutputsPassBase( + ExecutorConvertControlToDataOutputsPassOptions{ + composite_tpuexecute_side_effects}) {} + void runOnOperation() override; }; @@ -74,14 +113,61 @@ SmallVector GetWhileCallers(func::FuncOp func, return while_callers; } +bool IsResourceType(Type type) { + return getElementTypeOrSelf(type).isa(); +} + +bool OnlyOperatesOnCompositeDevices( + TF::TPUExecuteAndUpdateVariablesOp& op, + const TF::SideEffectAnalysis::Info& side_effect_analysis, + const DataFlowSolver& solver) { + auto& alias_analysis = side_effect_analysis.GetAliasAnalysis(); + llvm::SmallSet read_array; + for (const Attribute& attr : op.getDeviceVarReadsIndices()) { + read_array.insert(attr.cast().getInt()); + } + llvm::SmallSet update_array; + for (const Attribute& attr : op.getDeviceVarUpdatesIndices()) { + update_array.insert(attr.cast().getInt()); + } + + for (auto& arg : op->getOpOperands()) { + Value v = arg.get(); + if (!IsResourceType(arg.get().getType())) continue; + if (alias_analysis.IsUnknownResource(v)) continue; + for (auto id : alias_analysis.GetResourceUniqueIds(v)) { + (void)id; + } + } + + for (auto& arg : op->getOpOperands()) { + if (!IsResourceType(arg.get().getType())) { + continue; + } + auto lattice = + solver.lookupState(arg.get()) + ->getValue(); + bool is_read = read_array.contains(arg.getOperandNumber()); + bool is_update = update_array.contains(arg.getOperandNumber()); + // We want the resource operands that are on composite devices to be the + // exact same set as the resource operands that are read or updated. + if ((is_read || is_update) != lattice.is_on_composite_device) { + return false; + } + } + return true; +} + // Populates `chain_resource_to_ops_map`, the map from all resources that need // to be chained to the set of operations that access the resource, and // `resource_equivalence_classes`. Resources are equivalent if they are accessed // by a common op, and equivalent resources will be assigned to the same chain. void CollectChainResources( func::FuncOp func, ResourceToOpsMapTy& chain_resource_to_ops_map, - llvm::EquivalenceClasses& resource_equivalence_classes, - const TF::SideEffectAnalysis::Info& side_effect_analysis) { + llvm::EquivalenceClasses& resource_equivalence_classes, + DeviceMap& devices, + const TF::SideEffectAnalysis::Info& side_effect_analysis, + const DataFlowSolver& solver, bool composite_tpuexecute_side_effects) { auto graph_op = cast(func.front().front()); // For each op in the graph, get the resources it uses and update the access @@ -93,24 +179,78 @@ void CollectChainResources( assert(island.WrapsSingleOp()); Operation& op = island.GetBody().front(); - ResourceId prev_resource_id = kInvalidResourceId; + // If the op only operates on resources stored on devices that are + // "COMPOSITE", then this op is defined to work in parallel with other + // TPUExecute* ops. So we can make all ResourceIds device-specific below. + // (Even the per-op "resource ids", like ResourceEffects::TPUExecute.) + bool op_only_operates_on_composite_devices = false; + if (auto execute = llvm::dyn_cast(op)) { + if (OnlyOperatesOnCompositeDevices(execute, side_effect_analysis, + solver)) { + op_only_operates_on_composite_devices = true; + } + } + + auto device_attr = op.getAttrOfType(kDeviceAttr); + int64_t device_id; + if (!device_attr) { + device_id = kAnyDevice; + } else if (devices.find(device_attr) != devices.end()) { + device_id = devices[device_attr]; + } else { + device_id = 1 + devices.size(); + devices[device_attr] = device_id; + } + + auto& alias_analysis = side_effect_analysis.GetAliasAnalysis(); + + ResourceAndDevice prev_resource_and_device = kInvalidResourceAndDevice; for (auto resource_id_read_only_pair : side_effect_analysis.GetResourceIds(&op)) { - ResourceId resource_id = resource_id_read_only_pair.first; + auto resource_id = resource_id_read_only_pair.first; + // If alias analysis knows about a resource (as evidenced by the fact that + // GetValuesForResourceId isn't empty), and dataflow tells us that this + // stems from a function argument that was annotated as + // "tf._composite_device", then we can treat this resource as + // device-specific (see below). + bool resource_is_on_composite_device = false; + for (Value value : alias_analysis.GetValuesForResourceId(resource_id)) { + auto lattice = + solver.lookupState(value); + if (lattice) { + resource_is_on_composite_device |= + lattice->getValue().is_on_composite_device; + } + } + + // A device-specific resource identifier creates an edge only between ops + // on the same device, thus preventing ops on different devices from + // blocking each other, even if they access the same resource. + ResourceAndDevice resource_and_device; + if (composite_tpuexecute_side_effects && + (op_only_operates_on_composite_devices || + resource_is_on_composite_device)) { + resource_and_device = std::make_pair(resource_id, device_id); + } else { + resource_and_device = std::make_pair(resource_id, kAnyDevice); + } + // If the resource was allocated by an op with `UniqueResourceAllocation` // trait, then we don't need to chain resource ops accessing this resource // between iterations: Every iteration will create a new independent // resource. This enables more parallelism across iterations. - if (!side_effect_analysis.IsUniqueResourceAllocationId(resource_id)) { - chain_resource_to_ops_map[resource_id].insert(&op); - if (prev_resource_id != kInvalidResourceId) { + if (!side_effect_analysis.IsUniqueResourceAllocationId( + resource_and_device.first)) { + chain_resource_to_ops_map[resource_and_device].insert(&op); + if (prev_resource_and_device != kInvalidResourceAndDevice) { // Merge class of current ID with class of previous ID since both // resources are accessed by `op`. - resource_equivalence_classes.unionSets(prev_resource_id, resource_id); + resource_equivalence_classes.unionSets(prev_resource_and_device, + resource_and_device); } else { - resource_equivalence_classes.insert(resource_id); + resource_equivalence_classes.insert(resource_and_device); } - prev_resource_id = resource_id; + prev_resource_and_device = resource_and_device; } } }); @@ -269,7 +409,7 @@ IslandOp CreateIsland(Operation* sub_op, ValueRange control_inputs, // read/write to a resource of the class to the chain_sink operation. void ChainResourceOps( func::FuncOp func, ResourceToOpsMapTy& chain_resource_to_ops_map, - llvm::EquivalenceClasses& resource_equivalence_classes, + llvm::EquivalenceClasses& resource_equivalence_classes, SmallPtrSet ops_connected_to_fetch, int num_old_outputs) { assert(num_old_outputs + resource_equivalence_classes.getNumClasses() == func.getNumArguments()); @@ -310,8 +450,8 @@ void ChainResourceOps( resource_equivalence_classes.member_begin(class_iter); member_iter != resource_equivalence_classes.member_end(); ++member_iter) { - ResourceId resource_id = *member_iter; - auto map_iter = chain_resource_to_ops_map.find(resource_id); + ResourceAndDevice resource_and_device = *member_iter; + auto map_iter = chain_resource_to_ops_map.find(resource_and_device); if (map_iter == chain_resource_to_ops_map.end()) continue; OperationSetTy& resource_ops = map_iter->getSecond(); @@ -386,24 +526,27 @@ TF::WhileOp RewriteWhileOp(TF::WhileOp while_op, int num_resource_inputs, void ConvertControlToDataOutputs( func::FuncOp while_body, SmallVectorImpl& while_callers, OperationSetTy& recompute_analysis_for_funcs, - const TF::SideEffectAnalysis::Info& side_effect_analysis) { + const TF::SideEffectAnalysis::Info& side_effect_analysis, + const DataFlowSolver& solver, bool composite_tpuexecute_side_effects) { if (while_callers.empty()) return; // Collect access information for each resource in the while body that needs // to be chained, along with equivalence classes (resources in one class will // use the same chain). ResourceToOpsMapTy chain_resource_to_ops_map; - llvm::EquivalenceClasses resource_equivalence_classes; - CollectChainResources(while_body, chain_resource_to_ops_map, - resource_equivalence_classes, side_effect_analysis); + llvm::EquivalenceClasses resource_equivalence_classes; + DeviceMap devices; + CollectChainResources( + while_body, chain_resource_to_ops_map, resource_equivalence_classes, + devices, side_effect_analysis, solver, composite_tpuexecute_side_effects); // Check for presence of unknown side-effecting ops within the while loop // body. These ops act as barriers and the optimization would not yield much // inter iteration parallelism for this while loop body. So return with // warning. - if (chain_resource_to_ops_map.count(kUnknownResourceId) > 0) { + if (chain_resource_to_ops_map.count(kUnknownResourceAndDevice) > 0) { std::set blocking_ops; - for (Operation* op : chain_resource_to_ops_map[kUnknownResourceId]) { + for (Operation* op : chain_resource_to_ops_map[kUnknownResourceAndDevice]) { std::string op_name = op->getName().getStringRef().str(); if (blocking_ops.insert(op_name).second) { LOG(WARNING) @@ -459,6 +602,13 @@ void ConvertControlToDataOutputs( void ConvertControlToDataOutputsPass::runOnOperation() { ModuleOp module = getOperation(); + + DataFlowSolver solver; + solver.load(); + solver.load(); + solver.load(); + if (failed(solver.initializeAndRun(module))) return signalPassFailure(); + // This pass assumes that all functions are suitable for export i.e., each // function has a single tf_executor.graph op and all islands wrap the // internal op perfectly. Verify that in the beginning once. @@ -500,7 +650,8 @@ void ConvertControlToDataOutputsPass::runOnOperation() { } ConvertControlToDataOutputs( while_body, while_callers, recompute_analysis_for_funcs, - side_effect_analysis.GetAnalysisForFunc(while_body)); + side_effect_analysis.GetAnalysisForFunc(while_body), solver, + composite_tpuexecute_side_effects_); } } @@ -511,5 +662,12 @@ CreateTFExecutorConvertControlToDataOutputsPass() { return std::make_unique(); } +std::unique_ptr> +CreateTFExecutorConvertControlToDataOutputsPass( + bool composite_tpuexecute_side_effects) { + return std::make_unique( + composite_tpuexecute_side_effects); +} + } // namespace tf_executor } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 9261d66377edde..9c475f1f9f5281 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -341,6 +341,9 @@ namespace tf_executor { // Creates a pass to chain control outputs of while loop body. std::unique_ptr> CreateTFExecutorConvertControlToDataOutputsPass(); +std::unique_ptr> +CreateTFExecutorConvertControlToDataOutputsPass( + bool composite_tpuexecute_side_effects); std::unique_ptr> CreateTFExecutorCheckControlDependenciesPass(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc b/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc index 56fcdd761999b1..661dafe2a2f327 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc @@ -47,8 +47,8 @@ bool IsCommunicationOp(Operation* op) { // subcomputation in the TF/XLA bridge. bool SupportsCommunicationComputation(Operation* op) { return isa(op); + TF::XlaCallModuleOp, TF::StatefulPartitionedCallOp, + TF::PartitionedCallOp, TF::LegacyCallOp>(op); } #define GEN_PASS_DEF_PREPARETPUCOMPUTATIONFORTFEXPORTPASS @@ -65,14 +65,17 @@ class RewriteXlaHostComputeMlir public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TF::_XlaHostComputeMlirOp op, - PatternRewriter& rewriter) const override { + LogicalResult match(TF::_XlaHostComputeMlirOp op) const override { if (op.getManualSharding()) { - op.emitOpError() << "manual_sharding not supported with fallback of " - "phase 2 legalize TF/XLA bridge. manual_sharding is " - "used by map_outside_compilation"; + // This rewrite does not support manual_sharding. It is expected that the + // _XlaHostComputeMlirOp registered as an MlirXlaOpKernel will handle this + // case later once the XlaBuilder graph reaches it. return failure(); } + return success(); + } + void rewrite(TF::_XlaHostComputeMlirOp op, + PatternRewriter& rewriter) const override { llvm::SmallVector shape_attrs; shape_attrs.reserve(op.getNumResults()); for (Type ty : op.getResultTypes()) { @@ -132,7 +135,6 @@ class RewriteXlaHostComputeMlir op.getRecvKeyAttr(), /*cost_estimate_ns=*/rewriter.getI64IntegerAttr(kDefaultCostEstimate), /*tpu_core=*/rewriter.getI64IntegerAttr(0)); - return success(); } }; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc index 3924cd799b4cc6..a669276e35a175 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc @@ -247,9 +247,10 @@ StringRef ExtractSingleBlockRegion( } // Returns call for region with single call whose result feeds into the -// terminator of the region. if `allow_to_bool` is true, also allows a single -// ToBoolOp between the region yield and the call. Returns none if the region -// does not conform to this pattern. +// terminator of the region. If `allow_to_bool` is true, it allows patterns used +// in the condition of While ops, i.e. it allows a single bool (possibly passed +// through a ToBoolOp) between the region yield and the call. Returns none if +// the region does not conform to this pattern. std::optional IsSingleCallRegion(Region& region, bool allow_to_bool = false) { if (!llvm::hasSingleElement(region)) return std::nullopt; @@ -276,10 +277,23 @@ std::optional IsSingleCallRegion(Region& region, func::CallOp call = dyn_cast(*it++); if (!call) return std::nullopt; - // All call results should feed into expected consumer - // All results of the call should feed into the yield. - if (call.getNumResults() != call_consumer->getNumOperands()) - return std::nullopt; + if (allow_to_bool && call.getNumResults() == 1 && + yield->getNumOperands() != 1) { + // Allow patterns of the form + // %cond = call(...) + // yield %cond, [...passthrough args...] + if (yield->getNumOperands() != block.getNumArguments() + 1) + return std::nullopt; + for (auto [yield_operand, block_arg] : + llvm::zip(yield->getOperands().drop_front(1), block.getArguments())) { + if (yield_operand != block_arg) return std::nullopt; + } + } else { + // All call results should feed into expected consumer + // All results of the call should feed into the yield. + if (call.getNumResults() != call_consumer->getNumOperands()) + return std::nullopt; + } for (auto res_it : llvm::zip(call.getResults(), call_consumer->getOperands())) if (std::get<0>(res_it) != std::get<1>(res_it)) return std::nullopt; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td index f50286123f2478..b00e70eb73c4cc 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td @@ -516,6 +516,13 @@ def ExecutorConvertControlToDataOutputsPass : Pass<"tf-executor-convert-control- }]; let constructor = "tf_executor::CreateTFExecutorConvertControlToDataOutputsPass()"; + + let options = [ + Option<"composite_tpuexecute_side_effects_", "composite-tpuexecute-side-effects", "bool", + /*default=*/"false", + "Enables certain TPUExecute ops to run in parallel if they only " + "operate on resources that live on composite devices."> + ]; } def ExecutorUpdateControlDependenciesPass : Pass<"tf-executor-update-control-dependencies", "ModuleOp"> { @@ -844,48 +851,48 @@ def DecomposeReduceDatasetPass : Pass<"tf-decompose-reduce-dataset", "mlir::func f = @__reduce_func_1, f._tf_data_function = true, output_shapes = [#tf_type.shape<>], output_types = [i64], use_inter_op_parallelism = true, _xla_compile_device_type="TPU"} : - (tensor, tensor) -> (tensor) + (tensor, tensor) -> (tensor) func.return - } - ``` + } + ``` - with the following reduction function: + with the following reduction function: - ```mlir - func.func private @__reduce_func_1(%arg0: tensor {tf._user_specified_name = "args_0"}, - %arg1: tensor<32xf32> {tf._user_specified_name = "args_1"}) -> (tensor) - attributes {tf._tf_data_function = true, tf.signature.is_stateful} { - %0 = "tf.JustPretend"(%arg1) : (tensor<32xf32>) -> (tensor) - func.return %0 : tensor - } - ``` + ```mlir + func.func private @__reduce_func_1(%arg0: tensor {tf._user_specified_name = "args_0"}, + %arg1: tensor<32xf32> {tf._user_specified_name = "args_1"}) -> (tensor) + attributes {tf._tf_data_function = true, tf.signature.is_stateful} { + %0 = "tf.JustPretend"(%arg1) : (tensor<32xf32>) -> (tensor) + func.return %0 : tensor + } + ``` - will be transformed into: + will be transformed into: - ```mlir - func.func @single_state_single_dataset_type_no_arguments(%arg0: tensor, %arg1: tensor) { - %0 = "tf.AnonymousIteratorV3"() {output_shapes = [#tf_type.shape<32>], output_types = [f32]} : () -> tensor - "tf.MakeIterator"(%arg0, %0) : (tensor, tensor) -> () - %cst = "tf.Const"() {value = dense : tensor} : () -> tensor - %1:2 = "tf.WhileRegion"(%cst, %arg1) ({ - ^bb0(%arg2: tensor, %arg3: tensor): - "tf.Yield"(%arg2) : (tensor) -> () - }, { - ^bb0(%arg2: tensor, %arg3: tensor): - %2 = "tf.IteratorGetNextAsOptional"(%0) {output_shapes = [#tf_type.shape<32>], output_types = [f32]} : (tensor) -> tensor - %3 = "tf.OptionalHasValue"(%2) : (tensor) -> tensor - %4 = "tf.IfRegion"(%3) ({ - %5 = "tf.OptionalGetValue"(%2) : (tensor) -> tensor<32xf32> - %6 = func.call @__reduce_func_1(%arg3, %5) {_xla_compile_device_type = "TPU"} : (tensor, tensor<32xf32>) -> tensor - "tf.Yield"(%6) : (tensor) -> () + ```mlir + func.func @single_state_single_dataset_type_no_arguments(%arg0: tensor, %arg1: tensor) { + %0 = "tf.AnonymousIteratorV3"() {output_shapes = [#tf_type.shape<32>], output_types = [f32]} : () -> tensor + "tf.MakeIterator"(%arg0, %0) : (tensor, tensor) -> () + %cst = "tf.Const"() {value = dense : tensor} : () -> tensor + %1:2 = "tf.WhileRegion"(%cst, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + "tf.Yield"(%arg2) : (tensor) -> () }, { - "tf.Yield"(%arg3) : (tensor) -> () - }) {_lower_using_switch_merge = true, is_stateless = false} : (tensor) -> tensor - "tf.Yield"(%3, %4) : (tensor, tensor) -> () - }) {_lower_using_switch_merge = true, is_stateless = false, parallel_iterations = 10 : i64} : (tensor, tensor) -> (tensor, tensor) - return - } - ``` + ^bb0(%arg2: tensor, %arg3: tensor): + %2 = "tf.IteratorGetNextAsOptional"(%0) {output_shapes = [#tf_type.shape<32>], output_types = [f32]} : (tensor) -> tensor + %3 = "tf.OptionalHasValue"(%2) : (tensor) -> tensor + %4 = "tf.IfRegion"(%3) ({ + %5 = "tf.OptionalGetValue"(%2) : (tensor) -> tensor<32xf32> + %6 = func.call @__reduce_func_1(%arg3, %5) {_xla_compile_device_type = "TPU"} : (tensor, tensor<32xf32>) -> tensor + "tf.Yield"(%6) : (tensor) -> () + }, { + "tf.Yield"(%arg3) : (tensor) -> () + }) {_lower_using_switch_merge = true, is_stateless = false} : (tensor) -> tensor + "tf.Yield"(%3, %4) : (tensor, tensor) -> () + }) {_lower_using_switch_merge = true, is_stateless = false, parallel_iterations = 10 : i64} : (tensor, tensor) -> (tensor, tensor) + return + } + ``` }]; let constructor = "TF::CreateDecomposeReduceDatasetPass()"; @@ -2199,14 +2206,14 @@ def HoistLoopInvariantPass : Pass<"tf-hoist-loop-invariant", "mlir::func::FuncOp brevity) ```mlir func.func @hoist_loop_invariant(%arg0, %arg1) { - %var = "tf.VarHandleOp"() {container="", shared_name="var_name", device = "/device:CPU:0"} + %var = "tf.VarHandleOp"() {container="", shared_name="var_name", device = "/device:CPU:0"} %results:2 = "tf.WhileRegion"(%arg0, %arg1) ({ ^bb0(%arg2, %arg3): %0 = "tf.OpA"() {is_stateless = true} "tf.Yield"(%0) }, { ^bb0(%arg2, %arg3): - %1 = "tf.ReadVariableOp"(%var) + %1 = "tf.ReadVariableOp"(%var) %2 = "tf.OpB"(%1) {is_stateless = true} %3 = "tf.OpC"(%arg2, %2) {is_stateless = true} %4 = "tf.OpD"(%arg3, %2) {is_stateless = true} @@ -2218,8 +2225,8 @@ def HoistLoopInvariantPass : Pass<"tf-hoist-loop-invariant", "mlir::func::FuncOp would be transformed to ```mlir func.func @hoist_loop_invariant(%arg0, %arg1) { - %var = "tf.VarHandleOp"() {container="", shared_name="var_name", device = "/device:CPU:0"} - %1 = "tf.ReadVariableOp"(%var) + %var = "tf.VarHandleOp"() {container="", shared_name="var_name", device = "/device:CPU:0"} + %1 = "tf.ReadVariableOp"(%var) %2 = "tf.OpB"(%1) {is_stateless = true} %results:2 = "tf.WhileRegion"(%arg0, %arg1) ({ ^bb0(%arg2, %arg3): diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc index 621e2a156492f7..fb2588f50631e8 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc @@ -17,8 +17,13 @@ limitations under the License. #include #include #include +#include #include +#include "absl/algorithm/container.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -31,6 +36,7 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project @@ -41,9 +47,11 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" #include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" #include "xla/client/sharding_builder.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" namespace mlir { namespace TFTPU { @@ -51,6 +59,8 @@ namespace { using OpShardingVariant = std::variant; using OpShardingVector = llvm::SmallVector; +using OptionalOpShardingVector = + llvm::SmallVector, 8>; constexpr char kReplicateSharding[] = ""; constexpr char kShardingAttr[] = "mhlo.sharding"; @@ -238,7 +248,7 @@ std::optional AssignLogicalDeviceFromTPUReplicatedCoreAttr( // Cast op may be added right after the input. // // TODO(hongjunchoi): Add logic to parse XlaSharding op inside control flow (If, -// Case, While) ops and Caller return values. +// Case) ops and Caller return values. // TODO(hongjunchoi): Consider explicitly checking op patterns to detect sharded // inputs. std::optional GetXlaShardingFromArg( @@ -260,6 +270,15 @@ std::optional GetXlaShardingFromArg( return logical_device; } + if (auto while_op = llvm::dyn_cast(owner)) { + const int operand_number = use.getOperandNumber(); + next_values_to_visit.push_back( + while_op.getCond().front().getArgument(operand_number)); + next_values_to_visit.push_back( + while_op.getBody().front().getArgument(operand_number)); + continue; + } + if (llvm::isa(owner)) { next_values_to_visit.push_back(use.getOwner()->getResult(0)); continue; @@ -281,15 +300,16 @@ std::optional GetXlaShardingFromArg( return std::nullopt; } -// Extracts sharding configurations for all inputs by parsing XlaSharding/ -// TPUPartitionedInput op connected to the operands/arguments. If argument to -// the `cluster_func` directly feeds into another function call op, then -// recursively walk the function definition to find the connected XlaSharding -// op. +// Tries to extract sharding configurations for all inputs by parsing +// XlaSharding/ TPUPartitionedInput op connected to the operands/arguments. If +// argument to the `cluster_func` directly feeds into another function call op, +// then recursively walk the function definition to find the connected +// XlaSharding op. void IdentifyXlaShardingForComputationInputs( - const llvm::SmallVector& logical_device_vec, bool use_spmd, + const llvm::SmallVector& logical_device_vec, bool infer_from_computation, tf_device::ClusterFuncOp cluster_func, - func::FuncOp func, Builder* builder, OpShardingVector& sharding_for_args) { + func::FuncOp func, Builder* builder, + OptionalOpShardingVector& sharding_for_args) { // Look up function definition from module. Block& function_block = func.front(); @@ -300,8 +320,6 @@ void IdentifyXlaShardingForComputationInputs( // 1) a TPUPartitionedInput Op if the input has a non-resource type; // 2) a ReadVariableOp else. // - // Replicate sharding is used if `use_spmd` is set. - // // Iterate through input arguments to the entry block of // tf_device.ClusterFunc. For input ops, look for XlaSharding ops. // XlaSharding ops can: @@ -330,17 +348,7 @@ void IdentifyXlaShardingForComputationInputs( } } - if (use_spmd) { - // If XLA SPMD is enabled, host variables or non-variable per-replica - // inputs should take on replicate sharding, so that every device gets the - // whole tensor(s) (and can slice them up later). Exclude device - // variables, which always should take maximal sharding. - sharding_for_args.push_back(kReplicateSharding); - continue; - } - - // Otherwise, default to maximal sharding core 0. - sharding_for_args.push_back(logical_device_vec[0]); + sharding_for_args.push_back(std::nullopt); } } @@ -364,31 +372,32 @@ mlir::Operation* GetXlaShardingFromResult(Value value) { return nullptr; } -// Looks up arg->retval aliases for every argument, and builds a reverse map. -void ExtractAliases(func::FuncOp func, llvm::SmallVectorImpl& aliases) { - aliases.resize(func.getNumResults(), -1); - for (int i = 0; i < func.getNumArguments(); i++) { - if (auto v = func.getArgAttrOfType(i, kAliasingAttr)) { - int retval_index = v.getInt(); - if (retval_index >= 0 && retval_index < aliases.size()) { - aliases[retval_index] = i; +absl::Status DetermineShardingFromAlias( + func::FuncOp func, OptionalOpShardingVector& input_shardings, + OptionalOpShardingVector& output_shardings) { + for (int arg_idx = 0; arg_idx < func.getNumArguments(); ++arg_idx) { + if (auto v = + func.getArgAttrOfType(arg_idx, kAliasingAttr)) { + if (int retval_idx = v.getInt(); + retval_idx >= 0 && retval_idx < func.getNumResults()) { + auto& input_sharding = input_shardings[arg_idx]; + auto& output_sharding = output_shardings[retval_idx]; + + if (input_sharding.has_value() && output_sharding.has_value() && + input_sharding.value() != output_sharding.value()) { + return absl::InvalidArgumentError(absl::StrCat( + "arg#", arg_idx, " is aliased to retval#", retval_idx, + " but their sharding configurations don't match.")); + } else if (input_sharding.has_value() && !output_sharding.has_value()) { + output_sharding = input_sharding; + } else if (!input_sharding.has_value() && output_sharding.has_value()) { + input_sharding = output_sharding; + } } } } -} -// Returns XLA sharding from argument connected via tf.aliasing_output. -std::optional GetXlaShardingFromAlias( - Value value, llvm::SmallVectorImpl& aliases, - const OpShardingVector& sharding_for_args) { - int retval_index = value.cast().getResultNumber(); - if (retval_index >= 0 && retval_index < aliases.size()) { - int arg_index = aliases[retval_index]; - if (arg_index >= 0 && arg_index < sharding_for_args.size()) { - return GetShardingStringFromVariant(sharding_for_args[arg_index]); - } - } - return std::nullopt; + return absl::OkStatus(); } // Returns XLA sharding from XlaSharding op connected to a result value. @@ -397,7 +406,7 @@ std::optional GetXlaShardingFromAlias( // used, we might see a Cast op. // // TODO(hongjunchoi): Add logic to parse XlaSharding op inside control flow (If, -// Case, While) ops and Caller argument values. +// Case) ops and Caller argument values. // TODO(hongjunchoi): Consider explicitly checking op patterns to detect sharded // inputs. std::optional GetXlaShardingFromRetval( @@ -456,31 +465,36 @@ std::optional GetXlaShardingFromRetval( values_to_visit.push_back(value_to_visit); continue; } + + if (auto while_op = llvm::dyn_cast(def)) { + if (auto op_result = value_to_visit.cast()) { + int result_idx = op_result.getResultNumber(); + if (auto yield_op = llvm::dyn_cast( + while_op.getBody().front().getTerminator())) { + values_to_visit.push_back(yield_op.getOperand(result_idx)); + } + } + continue; + } } return std::nullopt; } -// Extracts sharding configurations for all outputs by parsing XlaSharding/ -// TPUPartitionedOutput op connected to the retvals/results. +// Tries to extract sharding configurations for all outputs by parsing +// XlaSharding/ TPUPartitionedOutput op connected to the retvals/results. void IdentifyXlaShardingForComputationOutputs( - const llvm::SmallVector& logical_device_vec, bool use_spmd, + const llvm::SmallVector& logical_device_vec, bool infer_from_computation, tf_device::ClusterFuncOp cluster_func, func::FuncOp func, Builder* builder, - const OpShardingVector& sharding_for_args, - OpShardingVector& sharding_for_rets) { + OptionalOpShardingVector& sharding_for_rets) { Block& function_block = func.front(); Operation* terminator = function_block.getTerminator(); sharding_for_rets.reserve(terminator->getNumOperands()); - llvm::SmallVector aliases; // maps return value index to arg index - ExtractAliases(func, aliases); - // Iterate through results of `cluster_func`. For output ops, look for // TPUPartitionedOutput ops. // - // Replicate sharding is used if `use_spmd` is set. - // // Iterate through operands of the terminator. If the preceding op is // XlaShardingOp, then the provided sharding configuration is added to the // tf_device.ClusterFunc as an attribute and the function as a result @@ -495,12 +509,6 @@ void IdentifyXlaShardingForComputationOutputs( continue; } - if (auto from_alias = - GetXlaShardingFromAlias(result, aliases, sharding_for_args)) { - sharding_for_rets.push_back(from_alias.value()); - continue; - } - if (infer_from_computation) { if (auto retval_sharding = GetXlaShardingFromRetval(retval.get(), logical_device_vec)) { @@ -509,18 +517,76 @@ void IdentifyXlaShardingForComputationOutputs( } } - if (use_spmd) { - // If XLA SPMD is enabled, we default to replicate sharding. This way, - // all devices get the whole tensor(s), but if there's an XlaSharding op - // deeper in the function, they can use dynamic-slice to slice off their - // part of the computation. - sharding_for_rets.push_back(kReplicateSharding); - continue; + sharding_for_rets.push_back(std::nullopt); + } +} + +void SetReplicatedOrMaximalShardingIfNoShardingFound( + const llvm::SmallVector& logical_device_vec, bool use_spmd, + OptionalOpShardingVector& shardings) { + for (auto& sharding : shardings) { + if (sharding == std::nullopt) { + // If we haven't found sharding, default to either replicated or maximal + // sharding depending on whether XLA SPMD is enabled. + if (use_spmd) { + // If XLA SPMD is enabled, host variables or non-variable per-replica + // inputs, and outputs should take on replicate sharding, so that every + // device gets the whole tensor(s) (and can slice them up later eg. + // using dynamic-slice). + sharding = kReplicateSharding; + } else { + // Otherwise, default to maximal sharding core 0. + sharding = logical_device_vec[0]; + } + } + } +} + +// Moves shardings from `optional_shardings` to `shardings`. +absl::Status MoveSharding(OptionalOpShardingVector& optional_shardings, + OpShardingVector& shardings) { + shardings.clear(); + for (auto& sharding : optional_shardings) { + if (!sharding) { + return absl::InternalError( + "Couldn't find/assign sharding for an input/output. All shardings " + "should have been identified by this point."); } - // Otherwise, default to maximal sharding core 0. - sharding_for_rets.push_back(logical_device_vec[0]); + shardings.push_back(std::move(sharding.value())); } + + return absl::OkStatus(); +} + +// Determines XlaSharding for inputs and outputs. If there are aliased +// inputs/outputs for which no sharding was found directly, the corresponding +// output/input sharding is used (if it exists). If we still don't find sharding +// for some inputs/outputs, we default to replicated or maximal sharding +// depending on `use_spmd`. +absl::Status IdentifyXlaShardingForInputsAndOutputs( + const llvm::SmallVector& logical_device_vec, bool use_spmd, + bool infer_from_computation, tf_device::ClusterFuncOp cluster_func, + func::FuncOp func, Builder* builder, OpShardingVector& input_sharding, + OpShardingVector& output_sharding) { + OptionalOpShardingVector optional_input_sharding; + OptionalOpShardingVector optional_output_sharding; + IdentifyXlaShardingForComputationInputs( + logical_device_vec, infer_from_computation, cluster_func, func, builder, + optional_input_sharding); + IdentifyXlaShardingForComputationOutputs( + logical_device_vec, infer_from_computation, cluster_func, func, builder, + optional_output_sharding); + TF_RETURN_IF_ERROR(DetermineShardingFromAlias(func, optional_input_sharding, + optional_output_sharding)); + SetReplicatedOrMaximalShardingIfNoShardingFound(logical_device_vec, use_spmd, + optional_input_sharding); + SetReplicatedOrMaximalShardingIfNoShardingFound(logical_device_vec, use_spmd, + optional_output_sharding); + TF_RETURN_IF_ERROR(MoveSharding(optional_input_sharding, input_sharding)); + TF_RETURN_IF_ERROR(MoveSharding(optional_output_sharding, output_sharding)); + + return absl::OkStatus(); } // Extracts input/output sharding configuration of `cluster_func` by parsing @@ -551,15 +617,15 @@ LogicalResult IdentifyXlaShardingForTPUComputation( } OpShardingVector sharding_for_args; - IdentifyXlaShardingForComputationInputs(logical_device_vec, use_spmd, - /*infer_from_computation=*/true, - cluster_func, func, builder, - sharding_for_args); - OpShardingVector sharding_for_rets; - IdentifyXlaShardingForComputationOutputs( - logical_device_vec, use_spmd, /*infer_from_computation=*/true, - cluster_func, func, builder, sharding_for_args, sharding_for_rets); + if (auto status = IdentifyXlaShardingForInputsAndOutputs( + logical_device_vec, use_spmd, + /*infer_from_computation=*/true, cluster_func, func, builder, + sharding_for_args, sharding_for_rets); + !status.ok()) { + LOG(ERROR) << status; + return failure(); + }; auto has_maximal_sharding = [](const OpShardingVariant& sharding_or_op) -> bool { @@ -582,14 +648,14 @@ LogicalResult IdentifyXlaShardingForTPUComputation( sharding_for_rets.clear(); cluster_func->setAttr(kUseSpmdAttr, builder->getBoolAttr(false)); - IdentifyXlaShardingForComputationInputs( - logical_device_vec, - /*use_spmd=*/false, /*infer_from_computation=*/false, cluster_func, - func, builder, sharding_for_args); - IdentifyXlaShardingForComputationOutputs( - logical_device_vec, - /*use_spmd=*/false, /*infer_from_computation=*/false, cluster_func, - func, builder, sharding_for_args, sharding_for_rets); + if (auto status = IdentifyXlaShardingForInputsAndOutputs( + logical_device_vec, /*use_spmd=*/false, + /*infer_from_computation=*/false, cluster_func, func, builder, + sharding_for_args, sharding_for_rets); + !status.ok()) { + LOG(ERROR) << status; + return failure(); + } } // Update sharding on function arguments and returns. @@ -635,7 +701,7 @@ void TPUShardingIdentificationPass::runOnOperation() { if (result.wasInterrupted()) return signalPassFailure(); } -} // anonymous namespace +} // namespace std::unique_ptr> CreateTPUShardingIdentificationPass() { return std::make_unique(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_validate_inputs.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_validate_inputs.cc index dea95b6e9db021..4dc9daa6c705ee 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_validate_inputs.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_validate_inputs.cc @@ -74,6 +74,7 @@ bool IsTpuRegularOp(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), }; return ops_set; }(); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 4160c31515cf62..95c67e6084d90a 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -182,6 +182,17 @@ void LoadImporterDialects(mlir::MLIRContext& context) { context.getOrLoadDialect(name); } +absl::StatusOr GetDenseTensorNameFromTensorInfo( + const TensorInfo& tensor_info) { + // TODO(b/184675681): Support other encoding cases. + // + // TODO(b/184679394): Add unit test for this check. + TF_RET_CHECK(tensor_info.encoding_case() == tensorflow::TensorInfo::kName) + << "Only dense tensor is supported, but got encoding case " + << tensor_info.encoding_case(); + return tensor_info.name(); +} + // This class is used to generate new MLIR function name strings that are both // unique in the TF function library `flib_` and unique among the name strings // generated by the class object during its lifetime. @@ -3945,7 +3956,11 @@ SavedModelSignatureDefImporterLite::ConvertGraph( specs.graph_func_name = name; specs.prune_unused_nodes = true; TF_ASSIGN_OR_RETURN(specs.inputs, ParseInputArrays(inputs)); - for (auto& output : outputs) specs.outputs.push_back(output.second.name()); + for (auto& output : outputs) { + TF_ASSIGN_OR_RETURN(std::string name, + GetDenseTensorNameFromTensorInfo(output.second)); + specs.outputs.push_back(std::move(name)); + } specs.control_outputs = control_outputs; specs.enable_shape_inference = false; specs.unconditionally_use_set_output_shapes = @@ -4031,12 +4046,8 @@ SavedModelSignatureDefImporterLite::ParseInputArrays( for (const auto& iter : inputs) { const auto& tensor_info = iter.second; - // TODO(b/184675681): Support other encoding cases. - // - // TODO(b/184679394): Add unit test for this check. - TF_RET_CHECK(tensor_info.encoding_case() == tensorflow::TensorInfo::kName) - << "Only dense tensor is supported, but got encoding case " - << tensor_info.encoding_case(); + TF_ASSIGN_OR_RETURN(std::string name, + GetDenseTensorNameFromTensorInfo(tensor_info)); VLOG(1) << "Importing Signature Input: input_name = " << iter.first << ", tensor_info = " << tensor_info.DebugString(); @@ -4052,7 +4063,7 @@ SavedModelSignatureDefImporterLite::ParseInputArrays( array_info.shape.set_unknown_rank(true); } - results.insert(std::pair(tensor_info.name(), + results.insert(std::pair(std::move(name), std::move(array_info))); } return results; diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc index f1617dc7d83787..57b0d0e2ff2389 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc @@ -31,7 +31,7 @@ limitations under the License. #include "xla/client/client_library.h" #include "xla/client/compile_only_client.h" #include "xla/stream_executor/host/host_platform_id.h" -#include "xla/stream_executor/multi_platform_manager.h" +#include "xla/stream_executor/platform_manager.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" @@ -136,7 +136,7 @@ static LogicalResult MlirToGraphTranslateFunction(ModuleOp module, } // Use Host platform, which should always exist, to make sure graphs compile. - auto platform = stream_executor::MultiPlatformManager::PlatformWithId( + auto platform = stream_executor::PlatformManager::PlatformWithId( stream_executor::host::kHostPlatformId); auto client = xla::ClientLibrary::GetOrCreateCompileOnlyClient(platform.value()); diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD index cc593a0d49db09..4971d62ea6f388 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD @@ -144,15 +144,16 @@ tf_cc_test( name = "compile_tf_graph_test", testonly = 1, srcs = ["compile_tf_graph_test.cc"], + data = [ + "testdata/prepare_to_library.mlir", + ], linkstatic = 1, deps = [ ":compile_tf_graph", "//tensorflow/compiler/jit", "//tensorflow/compiler/jit:xla_tpu_device", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:convert_type", - "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils", "//tensorflow/compiler/mlir/tf2xla/internal:test_matchers", + "//tensorflow/compiler/mlir/tf2xla/internal/utils:test_metadata_config", "//tensorflow/compiler/tf2xla:layout_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_helpers", @@ -163,17 +164,15 @@ tf_cc_test( "//tensorflow/core/lib/monitoring:cell_reader", "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", "//tensorflow/core/tpu/kernels:tpu_compile_op_support", - "@com_google_absl//absl/status", + "//tensorflow/core/tpu/kernels/xla:host_compute_ops", "@com_google_googletest//:gtest", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/lib/monitoring:test_utils", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", "@local_xla//xla:shape_util", "@local_xla//xla/client:client_library", - "@local_xla//xla/stream_executor", + "@local_xla//xla/stream_executor:platform_manager", "@local_xla//xla/translate/mhlo_to_hlo:type_to_shape", ], ) diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc index b93431fb8e607a..0355204506068c 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc @@ -173,6 +173,7 @@ Status PrepareAndExportToLibrary(mlir::ModuleOp module, applyTensorflowAndCLOptions(manager); manager.addPass(mlir::TF::CreatePrepareTpuComputationForTfExportPass()); manager.addPass(mlir::TF::CreateTFRegionControlFlowToFunctional()); + manager.addPass(mlir::TF::CreateTFShapeInferencePass()); manager.addNestedPass( mlir::CreateFunctionalToExecutorDialectConversionPass()); manager.addPass(mlir::CreateBreakUpIslandsPass()); diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc index e3b2339eedbba5..fdff5122c3516e 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc @@ -15,26 +15,19 @@ limitations under the License. #include "tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.h" +#include #include #include #include #include -#include "absl/status/status.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/DialectRegistry.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/OwningOpRef.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" #include "tensorflow/compiler/mlir/tf2xla/internal/test_matchers.h" +#include "tensorflow/compiler/mlir/tf2xla/internal/utils/test_metadata_config.h" #include "tensorflow/compiler/tf2xla/layout_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "xla/client/client_library.h" #include "xla/shape.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/translate/mhlo_to_hlo/type_to_shape.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/monitoring/cell_reader.h" @@ -82,64 +75,14 @@ MlirToHloArgs CreateTestMlirToHloArgs(const char* module_str = kMlirModuleStr) { } class CompileTFGraphTest : public ::testing::Test { - private: - absl::Status SetupArguments(mlir::ModuleOp module, - std::vector& arg_shapes, - tpu::TPUCompileMetadataProto& metadata_proto) { - auto main_fn = module.lookupSymbol(kEntryFuncName); - if (!main_fn) { - return absl::InternalError( - "Could not find main function in MLIR Module."); - } - - mlir::FunctionType func_type = main_fn.getFunctionType(); - for (auto input_type : func_type.getInputs()) { - tensorflow::TensorShape tensor_shape; - xla::Shape xla_shape = xla::TypeToShape(input_type); - TF_RETURN_IF_ERROR(tensorflow::TensorShape::BuildTensorShape( - xla_shape.dimensions(), &tensor_shape)); - arg_shapes.emplace_back(tensor_shape); - - DataType dtype; - TF_RETURN_IF_ERROR(ConvertToDataType(input_type, &dtype)); - - auto metadata_arg = metadata_proto.add_args(); - metadata_arg->set_kind(tpu::TPUCompileMetadataProto::Arg::PARAMETER); - metadata_arg->set_dtype(dtype); - } - - return absl::OkStatus(); - } - - absl::Status SetupReturnValues(mlir::ModuleOp module, - tpu::TPUCompileMetadataProto& metadata_proto) { - auto main_fn = module.lookupSymbol(kEntryFuncName); - if (!main_fn) { - return absl::InternalError( - "Could not find main function in MLIR Module."); - } - - int func_results = main_fn.getFunctionType().getNumResults(); - for (int i = 0; i < func_results; i++) { - metadata_proto.add_retvals(); - } - - return absl::OkStatus(); - } - public: tsl::StatusOr CompileWithComputation( const std::variant computation) { - mlir::DialectRegistry registry; - mlir::RegisterAllTensorFlowDialects(registry); - mlir::MLIRContext context(registry); - mlir::OwningOpRef mlir_module; - XlaCompilationResult compilation_result; se::Platform* platform = - se::MultiPlatformManager::PlatformWithName(kPlatformName).value(); + se::PlatformManager::PlatformWithName(kPlatformName).value(); auto client = xla::ClientLibrary::GetOrCreateCompileOnlyClient(platform).value(); @@ -150,11 +93,8 @@ class CompileTFGraphTest : public ::testing::Test { tpu::TPUCompileMetadataProto metadata_proto; std::vector arg_shapes; if (computation.index() == 0) { - TF_RETURN_IF_ERROR(DeserializeMlirModule( - std::get<0>(computation).mlir_module, &context, &mlir_module)); - TF_RETURN_IF_ERROR(SetupReturnValues(*mlir_module, metadata_proto)); - TF_RETURN_IF_ERROR( - SetupArguments(*mlir_module, arg_shapes, metadata_proto)); + TF_RETURN_IF_ERROR(tensorflow::tf2xla::internal::ConfigureMetadata( + std::get<0>(computation).mlir_module, arg_shapes, metadata_proto)); } XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns; @@ -213,9 +153,9 @@ TEST_F(CompileTFGraphTest, RecordsStreamzForFunctionToHlo) { EXPECT_EQ(compilation_status.Delta("kOldBridgeNoMlirSuccess"), 1); } -TEST_F(CompileTFGraphTest, CatchesErrorMissedByPassManagerRun) { +TEST_F(CompileTFGraphTest, SuccessfullyCompilesWithManualSharding) { // MLIR module from failing test. - constexpr char kUnsupportedManualSharding[] = R"( + constexpr char kSupportedManualSharding[] = R"( module @module___inference_tpu_function_41 attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1617 : i32}} { func.func @main(%arg0: tensor<2x2xf32>) -> (tensor<2x2xf32> {mhlo.sharding = "\08\03\1A\02\02\01\22\02\00\01"}) { %0 = tf_executor.graph { @@ -223,7 +163,7 @@ TEST_F(CompileTFGraphTest, CatchesErrorMissedByPassManagerRun) { %outputs_0, %control_1 = tf_executor.island wraps "tf.XlaSharding"(%outputs) {_XlaSharding = "\08\03\1A\02\02\01\22\02\00\01", sharding = "\08\03\1A\02\02\01\22\02\00\01", unspecified_dims = []} : (tensor<2x2xf32>) -> tensor<2x2xf32> %outputs_2, %control_3 = tf_executor.island wraps "tf.XlaSpmdFullToShardShape"(%outputs_0) {dim = -1 : i64, manual_sharding = "\08\03\1A\02\02\01\22\02\00\01", unspecified_dims = []} : (tensor<2x2xf32>) -> tensor<1x2xf32> %control_4 = tf_executor.island wraps "tf._XlaHostComputeMlir"(%outputs_2) {host_mlir_module = "", manual_sharding = true, recv_key = "host_compute_channel_0_retvals", send_key = "host_compute_channel_0_args"} : (tensor<1x2xf32>) -> () - %outputs_5, %control_6 = tf_executor.island(%control_4) wraps "tf._XlaHostComputeMlir"() {host_mlir_module = "", manual_sharding = true, recv_key = "host_compute_channel_1_retvals", send_key = "host_compute_channel_1_args"} : () -> tensor<1x2xf32> + %outputs_5, %control_6 = tf_executor.island(%control_4) wraps "tf._XlaHostComputeMlir"() {host_mlir_module = "module {\0A func.func @host_func() -> tensor<1x2xf32> {\0A %0 = \22tf.Const\22() {value = dense<0.1> : tensor<1x2xf32>} : () -> tensor<1x2xf32> \0A return %0 : tensor<1x2xf32>}}", manual_sharding = true, recv_key = "host_compute_channel_1_retvals", send_key = "host_compute_channel_1_args"} : () -> tensor<1x2xf32> %outputs_7, %control_8 = tf_executor.island wraps "tf.XlaSpmdShardToFullShape"(%outputs_5) {dim = -1 : i64, full_shape = #tf_type.shape<2x2>, manual_sharding = "\08\03\1A\02\02\01\22\02\00\01", unspecified_dims = []} : (tensor<1x2xf32>) -> tensor<2x2xf32> %outputs_9, %control_10 = tf_executor.island wraps "tf.XlaSharding"(%outputs_7) {_XlaSharding = "\08\03\1A\02\02\01\22\02\00\01", sharding = "\08\03\1A\02\02\01\22\02\00\01", unspecified_dims = []} : (tensor<2x2xf32>) -> tensor<2x2xf32> tf_executor.fetch %outputs_9 : tensor<2x2xf32> @@ -232,13 +172,11 @@ TEST_F(CompileTFGraphTest, CatchesErrorMissedByPassManagerRun) { } } )"; - auto mlir_to_hlo_args = CreateTestMlirToHloArgs(kUnsupportedManualSharding); + auto mlir_to_hlo_args = CreateTestMlirToHloArgs(kSupportedManualSharding); auto result = CompileWithComputation(mlir_to_hlo_args); - ASSERT_THAT(result.ok(), false); - EXPECT_THAT(result.status().message(), - testing::ContainsRegex("op manual_sharding")); + EXPECT_TRUE(result.ok()); } TEST_F(CompileTFGraphTest, DoesNotInlineStatelessRandomOps) { @@ -261,6 +199,24 @@ TEST_F(CompileTFGraphTest, DoesNotInlineStatelessRandomOps) { ComputationProtoContains("tf.StatelessRandomNormal")); } +TEST_F(CompileTFGraphTest, TestRunsShapeInference) { + static constexpr char kShapeInferenceModule[] = R"( + module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { + func.func @main() -> () { + %0 = "tf.Const"() <{value = dense<-1> : tensor<3360x8xi32>}> : () -> tensor<3360x8xi32> + %cst_33 = "tf.Const"() <{value = dense<[1120, -1]> : tensor<2xi32>}> : () -> tensor<2xi32> + %cst_34 = "tf.Const"() <{value = dense<[3, 1120, -1]> : tensor<3xi32>}> : () -> tensor<3xi32> + %cst_63 = "tf.Const"() <{value = dense<0> : tensor}> : () -> tensor + %1965:4 = "tf._XlaHostComputeMlir"(%0, %cst_34, %cst_63, %cst_33) <{host_mlir_module = "#loc1 = loc(\22Reshape:\22)\0A#loc2 = loc(\22Reshape_4\22)\0A#loc3 = loc(\22Reshape\22)\0A#loc9 = loc(fused[#loc1, #loc2, #loc3])\0Amodule {\0A func.func @host_func(%arg0: tensor<3360x?xi32> loc(fused[#loc1, #loc2, #loc3]), %arg1: tensor<3xi32> loc(fused[#loc1, #loc2, #loc3]), %arg2: tensor loc(fused[#loc1, #loc2, #loc3]), %arg3: tensor<2xi32> loc(fused[#loc1, #loc2, #loc3])) -> (tensor<1x1120x?xi32>, tensor<1x1120x?xi32>, tensor<1120x?xi32>, tensor<2xi32>) {\0A %0 = \22tf.Reshape\22(%arg0, %arg1) {_xla_outside_compilation = \220\22} : (tensor<3360x?xi32>, tensor<3xi32>) -> tensor<3x1120x?xi32> loc(#loc9)\0A %1:3 = \22tf.Split\22(%arg2, %0) {_xla_outside_compilation = \220\22} : (tensor, tensor<3x1120x?xi32>) -> (tensor<1x1120x?xi32>, tensor<1x1120x?xi32>, tensor<1x1120x?xi32>) loc(#loc10)\0A %2 = \22tf.Reshape\22(%1#0, %arg3) {_xla_outside_compilation = \220\22} : (tensor<1x1120x?xi32>, tensor<2xi32>) -> tensor<1120x?xi32> loc(#loc11)\0A %3 = \22tf.Shape\22(%2) {_xla_outside_compilation = \220\22} : (tensor<1120x?xi32>) -> tensor<2xi32> loc(#loc12)\0A return %1#1, %1#2, %2, %3 : tensor<1x1120x?xi32>, tensor<1x1120x?xi32>, tensor<1120x?xi32>, tensor<2xi32> loc(#loc9)\0A } loc(#loc9)\0A} loc(#loc)\0A#loc = loc(unknown)\0A#loc4 = loc(\22Split:\22)\0A#loc5 = loc(\22split\22)\0A#loc6 = loc(\22Reshape_5\22)\0A#loc7 = loc(\22Shape:\22)\0A#loc8 = loc(\22Shape_4\22)\0A#loc10 = loc(fused[#loc4, #loc5])\0A#loc11 = loc(fused[#loc1, #loc6])\0A#loc12 = loc(fused[#loc7, #loc8])\0A", recv_key = "host_compute_channel_0_retvals", send_key = "host_compute_channel_0_args"}> : (tensor<3360x8xi32>, tensor<3xi32>, tensor, tensor<2xi32>) -> (tensor<1x1120x?xi32>, tensor<1x1120x?xi32>, tensor<1120x?xi32>, tensor<2xi32>) + return + } + } + )"; + + auto compilation_result = + CompileWithComputation(CreateTestMlirToHloArgs(kShapeInferenceModule)); + EXPECT_TRUE(compilation_result.ok()); +} } // namespace } // namespace v1 } // namespace tf2xla diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/testdata/prepare_to_library.mlir b/tensorflow/compiler/mlir/tf2xla/api/v1/testdata/prepare_to_library.mlir new file mode 100644 index 00000000000000..42e145effa742f --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/testdata/prepare_to_library.mlir @@ -0,0 +1,10 @@ +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { + func.func @main() -> () { + %0 = "tf.Const"() <{value = dense<-1> : tensor<3360x8xi32>}> : () -> tensor<3360x8xi32> + %cst_33 = "tf.Const"() <{value = dense<[1120, -1]> : tensor<2xi32>}> : () -> tensor<2xi32> + %cst_34 = "tf.Const"() <{value = dense<[3, 1120, -1]> : tensor<3xi32>}> : () -> tensor<3xi32> + %cst_63 = "tf.Const"() <{value = dense<0> : tensor}> : () -> tensor + %1965:4 = "tf._XlaHostComputeMlir"(%0, %cst_34, %cst_63, %cst_33) <{host_mlir_module = "#loc1 = loc(\22Reshape:\22)\0A#loc2 = loc(\22Reshape_4\22)\0A#loc3 = loc(\22Reshape\22)\0A#loc9 = loc(fused[#loc1, #loc2, #loc3])\0Amodule {\0A func.func @host_func(%arg0: tensor<3360x?xi32> loc(fused[#loc1, #loc2, #loc3]), %arg1: tensor<3xi32> loc(fused[#loc1, #loc2, #loc3]), %arg2: tensor loc(fused[#loc1, #loc2, #loc3]), %arg3: tensor<2xi32> loc(fused[#loc1, #loc2, #loc3])) -> (tensor<1x1120x?xi32>, tensor<1x1120x?xi32>, tensor<1120x?xi32>, tensor<2xi32>) {\0A %0 = \22tf.Reshape\22(%arg0, %arg1) {_xla_outside_compilation = \220\22} : (tensor<3360x?xi32>, tensor<3xi32>) -> tensor<3x1120x?xi32> loc(#loc9)\0A %1:3 = \22tf.Split\22(%arg2, %0) {_xla_outside_compilation = \220\22} : (tensor, tensor<3x1120x?xi32>) -> (tensor<1x1120x?xi32>, tensor<1x1120x?xi32>, tensor<1x1120x?xi32>) loc(#loc10)\0A %2 = \22tf.Reshape\22(%1#0, %arg3) {_xla_outside_compilation = \220\22} : (tensor<1x1120x?xi32>, tensor<2xi32>) -> tensor<1120x?xi32> loc(#loc11)\0A %3 = \22tf.Shape\22(%2) {_xla_outside_compilation = \220\22} : (tensor<1120x?xi32>) -> tensor<2xi32> loc(#loc12)\0A return %1#1, %1#2, %2, %3 : tensor<1x1120x?xi32>, tensor<1x1120x?xi32>, tensor<1120x?xi32>, tensor<2xi32> loc(#loc9)\0A } loc(#loc9)\0A} loc(#loc)\0A#loc = loc(unknown)\0A#loc4 = loc(\22Split:\22)\0A#loc5 = loc(\22split\22)\0A#loc6 = loc(\22Reshape_5\22)\0A#loc7 = loc(\22Shape:\22)\0A#loc8 = loc(\22Shape_4\22)\0A#loc10 = loc(fused[#loc4, #loc5])\0A#loc11 = loc(fused[#loc1, #loc6])\0A#loc12 = loc(fused[#loc7, #loc8])\0A", recv_key = "host_compute_channel_0_retvals", send_key = "host_compute_channel_0_args"}> : (tensor<3360x8xi32>, tensor<3xi32>, tensor, tensor<2xi32>) -> (tensor<1x1120x?xi32>, tensor<1x1120x?xi32>, tensor<1120x?xi32>, tensor<2xi32>) + return + } +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.cc index 9d0b884ebbe85d..aa9a73215b7284 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.cc @@ -91,8 +91,12 @@ void AddTfDialectToExecutorPasses(OpPassManager &pm) { pm.addPass(mlir::createSymbolDCEPass()); if (tensorflow::GetMlirCommonFlags() ->tf_mlir_enable_convert_control_to_data_outputs_pass) { + bool composite_tpuexecute_side_effects = + tensorflow::GetMlirCommonFlags() + ->tf_mlir_enable_composite_tpuexecute_side_effects; pm.addPass( - mlir::tf_executor::CreateTFExecutorConvertControlToDataOutputsPass()); + mlir::tf_executor::CreateTFExecutorConvertControlToDataOutputsPass( + composite_tpuexecute_side_effects)); } pm.addPass(mlir::TF::CreateVerifySuitableForExportPass()); } diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD index b89695f1fa37ca..a92239e8dbba69 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD @@ -35,6 +35,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tf2xla:compile_mlir_util", "//tensorflow/compiler/mlir/tf2xla/api/v1:compile_tf_graph", + "//tensorflow/compiler/mlir/tf2xla/internal:compilation_timer", "//tensorflow/compiler/mlir/tf2xla/internal:legalize_tf_mlir", "//tensorflow/compiler/mlir/tf2xla/internal:legalize_tf_to_hlo", "//tensorflow/compiler/tf2xla:layout_util", @@ -52,6 +53,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:variant", + "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@local_tsl//tsl/platform:error_logging", @@ -74,6 +76,8 @@ tf_cc_test( ":legalize_tf", "//tensorflow/compiler/jit", "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tf2xla/internal:test_matchers", + "//tensorflow/compiler/mlir/tf2xla/internal/utils:test_metadata_config", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_helpers", "//tensorflow/core:framework", @@ -89,6 +93,7 @@ tf_cc_test( "@local_tsl//tsl/lib/monitoring:test_utils", "@local_tsl//tsl/platform:statusor", "@local_xla//xla/client:client_library", + "@local_xla//xla/stream_executor:platform_manager", ], ) diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc index 4fa9b9bb98da3f..9729f755c611ce 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc @@ -160,19 +160,18 @@ void CreateTPUClusteringPipelineV2(OpPassManager &pm) { } tensorflow::Status RunFunctionTf2xlaClusteringBridge( - ModuleOp module, DeviceType device_type, bool is_in_fallback_enabled_mode, - llvm::StringRef module_name) { - bool is_replicated = device_type == DeviceType::XLA_TPU_JIT; + ModuleOp module, bool is_supported_by_replicated_brige, + bool is_in_fallback_enabled_mode, llvm::StringRef module_name) { std::string device_type_filter = - device_type == DeviceType::XLA_TPU_JIT ? "tpu" : "cpu/gpu"; + is_supported_by_replicated_brige ? "tpu" : "cpu/gpu"; VLOG(2) - << (is_replicated ? "Replicated" : "NonReplicated") + << (is_supported_by_replicated_brige ? "Replicated" : "NonReplicated") << " Bridge called stack trace is " << "(NOTE: this is not an error; rather the stack trace for debugging) : " << tensorflow::CurrentStackTrace(); Status clustering_status = - is_replicated + is_supported_by_replicated_brige ? RunTFXLABridge( module, [module_name](OpPassManager &pm) { @@ -187,12 +186,12 @@ tensorflow::Status RunFunctionTf2xlaClusteringBridge( }, module_name, /*dump_prefix=*/"tf_xla_bridge_v2_nonreplicated"); - // TODO(b/317798386): add is_replicated as a filter. + // TODO(b/317798386): add is_supported_by_replicated_brige as a filter. TF_RETURN_IF_ERROR(RecordIfErrorStatus( /*error_prefix=*/"clustering_v2", is_in_fallback_enabled_mode, device_type_filter, clustering_status)); - // TODO(b/317798386): add is_replicated as a filter. + // TODO(b/317798386): add is_supported_by_replicated_brige as a filter. tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter( device_type_filter, /*bridge_version=*/"v2", /*fallback_enabled=*/is_in_fallback_enabled_mode, diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.h b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.h index e1298ac53560d3..8963fe7b126b9e 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.h +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.h @@ -39,17 +39,19 @@ namespace v2 { // Inputs: // module - The MLIR Module that will be clustered. Expected to be in TF // Executor Dialect or TF Functional Dialect. Will convert to TF Functional. -// . device_type - The device type to cluster for. -// is_in_fallback_enabled_mode - Whether this was called with fallback to the -// non-MLIR Bridge. This is just for logging purposes and doesn't affect -// logic. -// module_name - What the input module name is for debugging help. +// is_supported_by_replicated_brige - If the graph targets the replicated +// bridge. Set it to true for replicated/partitioned graphs. e.g. replicated +// and single-core TPU graphs. Set this to false if the graph is not +// replicated, e.g. CPU/GPU graphs. is_in_fallback_enabled_mode - Whether this +// was called with fallback to the non-MLIR Bridge. This is just for logging +// purposes and doesn't affect logic. module_name - What the input module name +// is for debugging help. // // Output: Modifies the input module in place with clustered operations. // status - Whether the transformation to cluster the input MLIR module was // successful. tensorflow::Status RunFunctionTf2xlaClusteringBridge( - mlir::ModuleOp module, DeviceType device_type, + mlir::ModuleOp module, bool is_supported_by_replicated_brige, bool is_in_fallback_enabled_mode, llvm::StringRef module_name = llvm::StringRef()); } // namespace v2 diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc index d00d8b43d9e790..c4a96702533c49 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc @@ -82,14 +82,14 @@ class FunctionClusterTensorflowDialectTest : public ::testing::Test { OwningOpRef mlir_module_; }; -TEST_F(FunctionClusterTensorflowDialectTest, ClustersTfTPU) { +TEST_F(FunctionClusterTensorflowDialectTest, ClustersTfReplicatedBridge) { CellReader compilation_status(kCompilationStreamz); TF_ASSERT_OK(CreateMlirModule("empty_func.mlir")); - TF_EXPECT_OK( - RunFunctionTf2xlaClusteringBridge(*mlir_module_, DeviceType::XLA_TPU_JIT, - /*is_in_fallback_enabled_mode=*/false)); + TF_EXPECT_OK(RunFunctionTf2xlaClusteringBridge( + *mlir_module_, /*is_supported_by_replicated_brige*/ true, + /*is_in_fallback_enabled_mode=*/false)); FuncOp main = mlir_module_->lookupSymbol("main"); ASSERT_TRUE(main); @@ -98,14 +98,15 @@ TEST_F(FunctionClusterTensorflowDialectTest, ClustersTfTPU) { compilation_status.Delta("tpu", "v2", "fallback_disabled", "success"), 1); } -TEST_F(FunctionClusterTensorflowDialectTest, RunsOutsideCompilationTPU) { +TEST_F(FunctionClusterTensorflowDialectTest, + RunsOutsideCompilationReplicatedBridge) { CellReader compilation_status(kCompilationStreamz); TF_ASSERT_OK(CreateMlirModule("outside_compilation.mlir")); - TF_EXPECT_OK( - RunFunctionTf2xlaClusteringBridge(*mlir_module_, DeviceType::XLA_TPU_JIT, - /*is_in_fallback_enabled_mode=*/false)); + TF_EXPECT_OK(RunFunctionTf2xlaClusteringBridge( + *mlir_module_, /*is_supported_by_replicated_brige*/ true, + /*is_in_fallback_enabled_mode=*/false)); FuncOp main = mlir_module_->lookupSymbol("main"); ASSERT_TRUE(main); @@ -121,31 +122,14 @@ TEST_F(FunctionClusterTensorflowDialectTest, RunsOutsideCompilationTPU) { compilation_status.Delta("tpu", "v2", "fallback_disabled", "success"), 1); } -TEST_F(FunctionClusterTensorflowDialectTest, ClustersTFCPU) { +TEST_F(FunctionClusterTensorflowDialectTest, ClustersTFNonReplicatedBridge) { CellReader compilation_status(kCompilationStreamz); TF_ASSERT_OK(CreateMlirModule("empty_func.mlir")); - TF_EXPECT_OK( - RunFunctionTf2xlaClusteringBridge(*mlir_module_, DeviceType::XLA_CPU_JIT, - /*is_in_fallback_enabled_mode=*/false)); - - FuncOp main = mlir_module_->lookupSymbol("main"); - ASSERT_TRUE(main); - - EXPECT_EQ( - compilation_status.Delta("cpu/gpu", "v2", "fallback_disabled", "success"), - 1); -} - -TEST_F(FunctionClusterTensorflowDialectTest, ClustersTFGPU) { - CellReader compilation_status(kCompilationStreamz); - - TF_ASSERT_OK(CreateMlirModule("empty_func.mlir")); - - TF_EXPECT_OK( - RunFunctionTf2xlaClusteringBridge(*mlir_module_, DeviceType::XLA_GPU_JIT, - /*is_in_fallback_enabled_mode=*/false)); + TF_EXPECT_OK(RunFunctionTf2xlaClusteringBridge( + *mlir_module_, /*is_supported_by_replicated_brige*/ false, + /*is_in_fallback_enabled_mode=*/false)); FuncOp main = mlir_module_->lookupSymbol("main"); ASSERT_TRUE(main); @@ -160,9 +144,9 @@ TEST_F(FunctionClusterTensorflowDialectTest, LogsFallbackMode) { TF_ASSERT_OK(CreateMlirModule("empty_func.mlir")); - TF_EXPECT_OK( - RunFunctionTf2xlaClusteringBridge(*mlir_module_, DeviceType::XLA_TPU_JIT, - /*is_in_fallback_enabled_mode=*/true)); + TF_EXPECT_OK(RunFunctionTf2xlaClusteringBridge( + *mlir_module_, /*is_supported_by_replicated_brige*/ true, + /*is_in_fallback_enabled_mode=*/true)); EXPECT_EQ( compilation_status.Delta("tpu", "v2", "fallback_enabled", "success"), 1); diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.cc index 6fad56a2e7999b..d297e45b70e0bb 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.cc @@ -25,9 +25,11 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/types/variant.h" +#include "llvm/ADT/ScopeExit.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h" #include "tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.h" +#include "tensorflow/compiler/mlir/tf2xla/internal/compilation_timer.h" #include "tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.h" #include "tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo.h" #include "tensorflow/compiler/tf2xla/layout_util.h" @@ -40,6 +42,7 @@ limitations under the License. #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" #include "tensorflow/core/util/debug_data_dumper.h" #include "tensorflow/core/util/dump_graph.h" +#include "tsl/lib/monitoring/sampler.h" #include "tsl/platform/error_logging.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -54,9 +57,17 @@ using tpu::FunctionToHloArgs; using tpu::MlirToHloArgs; using tpu::ShardingAndIndex; +auto* phase2_bridge_compilation_time = tsl::monitoring::Sampler<1>::New( + {"/tensorflow/core/tf2xla/api/v2/phase2_compilation_time", + "The wall-clock time spent on executing graphs in milliseconds.", + "configuration"}, + // Power of 1.5 with bucket count 45 (> 23 hours) + {tsl::monitoring::Buckets::Exponential(1, 1.5, 45)}); + // Name of component for error logging. This name is fixed and required to // enable logging. constexpr char kBridgeComponent[] = "TFXLABridge"; +constexpr char kFullBridge[] = "full_bridge"; namespace { @@ -134,6 +145,12 @@ tsl::StatusOr LegalizeMlirToHlo( std::vector* arg_core_mapping, std::vector>* per_core_arg_shapes, xla::CompileOnlyClient* client) { + CompilationTimer timer; + auto record_time = llvm::make_scope_exit([&timer] { + phase2_bridge_compilation_time->GetCell(kFullBridge) + ->Add(timer.ElapsedCyclesInMilliseconds()); + }); + auto compilation_result = std::make_unique(); DumpComputationInput(computation); diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc index e1f119d91213ab..81b3b5a180eb93 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc @@ -23,12 +23,14 @@ limitations under the License. #include #include #include "absl/strings/str_format.h" -#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" -#include "tensorflow/compiler/mlir/tf2xla/api/v2/device_type.pb.h" +#include "tensorflow/compiler/mlir/tf2xla/internal/test_matchers.h" +#include "tensorflow/compiler/mlir/tf2xla/internal/utils/test_metadata_config.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "xla/client/client_library.h" +#include "xla/stream_executor/platform_manager.h" #include "tensorflow/core/lib/monitoring/cell_reader.h" +#include "tensorflow/core/lib/monitoring/test_utils.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/config.pb.h" @@ -49,10 +51,10 @@ using tpu::FunctionToHloArgs; using tpu::MlirToHloArgs; using tpu::ShardingAndIndex; using tpu::TPUCompileMetadataProto; -using ::tsl::monitoring::testing::Histogram; static constexpr char kCompilationTimeStreamzName[] = "/tensorflow/core/tf2xla/api/v2/phase2_compilation_time"; +static constexpr char kFullBridge[] = "full_bridge"; static constexpr char kCompilationStatusStreamzName[] = "/tensorflow/core/tf2xla/api/v2/phase2_compilation_status"; static const char kMlirWithFallbackModeSuccess[] = @@ -79,7 +81,7 @@ static constexpr char kMlirModuleStr[] = R"( } })"; -// MLIR which should legalize at all +// MLIR which should not legalize at all static constexpr char kBadMlirModuleStr[] = R"( module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { func.func @main() -> () { @@ -107,13 +109,18 @@ tsl::StatusOr CompileMlirModule( mlir_to_hlo_args.mlir_module = mlir_module_str; se::Platform* platform = - se::MultiPlatformManager::PlatformWithName("Host").value(); + se::PlatformManager::PlatformWithName("Host").value(); auto client = xla::ClientLibrary::GetOrCreateCompileOnlyClient(platform).value(); std::vector arg_shapes; TPUCompileMetadataProto metadata_proto; - metadata_proto.add_retvals(); + // Configure metadata requires parsing the module and if we are testing a + // failure, we ignore this particular set up error assuming we'll not get + // far enough to need valid metadata. + tensorflow::tf2xla::internal::ConfigureMetadata(mlir_module_str, arg_shapes, + metadata_proto) + .IgnoreError(); bool use_tuple_args = true; std::vector arg_core_mapping; std::vector> per_core_arg_shapes; @@ -272,7 +279,7 @@ TEST(LegalizeTFTest, RecordsStreamzForNoMlirFallback) { {&guaranteed_constants}}; se::Platform* cpu_platform = - se::MultiPlatformManager::PlatformWithName("Host").value(); + se::PlatformManager::PlatformWithName("Host").value(); auto client = xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform).value(); @@ -294,6 +301,39 @@ TEST(LegalizeTFTest, RecordsStreamzForNoMlirFallback) { EXPECT_FALSE(compile_result.ok()); } +TEST(LegalizeTFTest, RecordsCompilationTimeForSuccessfulCompilation) { + CellReader compilation_time( + kCompilationTimeStreamzName); + + TF_ASSERT_OK_AND_ASSIGN( + XlaCompiler::CompilationResult result, + CompileMlirModule( + kMlirModuleStr, + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED)); + + // Compilation time should have been updated. + EXPECT_GT(compilation_time.Delta(kFullBridge).num(), 0); +} + +TEST(LegalizeTFTest, SuccessfullyCompilesModulesWithReturnValues) { + static constexpr char kHasReturnValuesAndNoMetadataRetvals[] = R"( + module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { + func.func @main() -> (tensor<2xi32>) { + %cst = "tf.Const"() {value = dense<[524170, 523952]> : tensor<2xi32>} : () -> tensor<2xi32> + return %cst : tensor<2xi32> + } + })"; + + auto compilation_result = CompileMlirModule( + kHasReturnValuesAndNoMetadataRetvals, + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED); + EXPECT_TRUE(compilation_result.ok()); + + // Ensure that the compilation result contains a constant. + EXPECT_THAT(compilation_result, + ComputationProtoContains("opcode:.*constant")); +} + } // namespace v2 } // namespace tf2xla } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.cc index 455a59d6607c49..9befa9b7714a27 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.cc @@ -90,8 +90,12 @@ void AddTfDialectToExecutorPasses(OpPassManager &pm) { pm.addPass(mlir::createSymbolDCEPass()); if (tensorflow::GetMlirCommonFlags() ->tf_mlir_enable_convert_control_to_data_outputs_pass) { + bool composite_tpuexecute_side_effects = + tensorflow::GetMlirCommonFlags() + ->tf_mlir_enable_composite_tpuexecute_side_effects; pm.addPass( - mlir::tf_executor::CreateTFExecutorConvertControlToDataOutputsPass()); + mlir::tf_executor::CreateTFExecutorConvertControlToDataOutputsPass( + composite_tpuexecute_side_effects)); } pm.addPass(mlir::TF::CreateVerifySuitableForExportPass()); } diff --git a/tensorflow/compiler/mlir/tf2xla/internal/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/BUILD index 2625b4b83b57f1..1f96ecb120795d 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/internal/BUILD @@ -256,6 +256,8 @@ cc_library( "//tensorflow/core/common_runtime:function_body", "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:status", ], @@ -271,6 +273,7 @@ tf_cc_test( "//tensorflow/cc:functional_ops", "//tensorflow/cc:ops", "//tensorflow/cc:scope", + "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/tf2xla:tf2xla_defs", "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", @@ -278,6 +281,9 @@ tf_cc_test( "//tensorflow/core:testlib", "//tensorflow/core/platform:enable_tf2_utils", "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", "@local_tsl//tsl/lib/core:status_test_util", ], ) diff --git a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc index ed04adc6a394ff..f00df1513215cc 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc @@ -143,6 +143,8 @@ void AddReplicatedBridgeClusteringPipelinePasses(OpPassManager& pm, pm.addNestedPass(mlir::TFDevice::CreateClusterConstantSinkingPass()); pm.addPass(mlir::TF::CreateResourceDeviceInferencePass()); + pm.addNestedPass( + tensorflow::tf2xla::internal::CreateHoistBroadcastReadPass()); pm.addPass(mlir::TFDevice::CreateClusterOutliningPass()); pm.addPass(mlir::TFTPU::CreateTPUResourceReadForWritePass()); pm.addPass(mlir::TFDevice::CreateMarkInputOutputAliasesPass()); diff --git a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc index c9cc5a4d1df16c..756ec42f31268c 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc @@ -28,7 +28,7 @@ TEST(ClusteringBridgePassesTest, AddsBridgePasses) { OpPassManager pass_manager; AddReplicatedBridgeClusteringPipelinePasses(pass_manager); - EXPECT_EQ(pass_manager.size(), 43); + EXPECT_EQ(pass_manager.size(), 44); } TEST(ClusteringBridgePassesTest, AddsNonTPUBridgePasses) { diff --git a/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.cc b/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.cc index 3b11eaaf2287d6..4fd0c21d68331f 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.cc @@ -44,7 +44,6 @@ limitations under the License. #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" #include "tensorflow/core/tpu/tpu_compile.h" -#include "tsl/lib/monitoring/sampler.h" #include "tsl/platform/error_logging.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -52,12 +51,6 @@ limitations under the License. namespace tensorflow { namespace tf2xla { namespace internal { -auto* phase2_bridge_compilation_time = tsl::monitoring::Sampler<1>::New( - {"/tensorflow/core/tf2xla/api/v2/phase2_compilation_time", - "The wall-clock time spent on executing graphs in milliseconds.", - "configuration"}, - // Power of 1.5 with bucket count 45 (> 23 hours) - {tsl::monitoring::Buckets::Exponential(1, 1.5, 45)}); // Name of component for error logging. This name is fixed and required to // enable logging. @@ -126,20 +119,11 @@ tsl::StatusOr LegalizeWithMlirBridge( // Enabling op fallback also enables whole graph fallback if op by op // fallback failed. - tsl::StatusOr mlir_bridge_status; - { - CompilationTimer timer; - const std::string kMlirBridgeFallback = "mlir_bridge_op_fallback_enabled"; - - mlir_bridge_status = CompileFromMlirToXlaHlo( - /*lower_to_xla_hlo=*/true, computation, metadata, device_type, - shape_determination_fns, use_tuple_args, compilation_result, - custom_legalization_passes, arg_shapes, arg_core_mapping, - per_core_arg_shapes); - - phase2_bridge_compilation_time->GetCell(kMlirBridgeFallback) - ->Add(timer.ElapsedCyclesInMilliseconds()); - } + tsl::StatusOr mlir_bridge_status = CompileFromMlirToXlaHlo( + /*lower_to_xla_hlo=*/true, computation, metadata, device_type, + shape_determination_fns, use_tuple_args, compilation_result, + custom_legalization_passes, arg_shapes, arg_core_mapping, + per_core_arg_shapes); if (mlir_bridge_status.ok()) { VLOG(1) << "Successfully compiled MLIR computation to XLA HLO using MLIR " diff --git a/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo.cc b/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo.cc index 8d7531f87eed3e..2d321246463494 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo.h" #include +#include #include #include "absl/log/log.h" @@ -33,19 +34,11 @@ limitations under the License. #include "tensorflow/core/platform/status.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" -#include "tsl/lib/monitoring/sampler.h" #include "tsl/platform/statusor.h" namespace tensorflow { namespace tf2xla { namespace internal { -auto* phase2_combined_bridge_compilation_time = - tsl::monitoring::Sampler<1>::New( - {"/tensorflow/core/tf2xla/api/v2/phase2_combined_compilation_time", - "The wall-clock time spent on combined graphs in milliseconds.", - "configuration"}, - // Power of 1.5 with bucket count 45 (> 23 hours) - {tsl::monitoring::Buckets::Exponential(1, 1.5, 45)}); using metrics::IncrementTfMlirBridgeSecondPhaseCounter; using metrics::MlirBridgeSecondPhaseMetric; @@ -63,14 +56,14 @@ tsl::StatusOr LegalizeTfToHlo( xla::CompileOnlyClient* client, XlaCompilationResult* compilation_result) { LOG_FIRST_N(INFO, 1) << "Compiling MLIR computation to XLA HLO using the " "Combined MLIR Tf2Xla Bridge."; - CompilationTimer timer; - constexpr char kCombinedBridgeTimer[] = "combined_bridge"; - auto mlir_compilation = internal::CompileFromMlirToXlaHlo( - /*lower_to_xla_hlo=*/false, computation, metadata, device_type, - shape_determination_fns, use_tuple_args, compilation_result, - custom_legalization_passes, arg_shapes, arg_core_mapping, - per_core_arg_shapes); + tsl::StatusOr mlir_compilation + + = internal::CompileFromMlirToXlaHlo( + /*lower_to_xla_hlo=*/false, computation, metadata, device_type, + shape_determination_fns, use_tuple_args, compilation_result, + custom_legalization_passes, arg_shapes, arg_core_mapping, + per_core_arg_shapes); if (!mlir_compilation.ok()) { IncrementTfMlirBridgeSecondPhaseCounter( @@ -94,8 +87,6 @@ tsl::StatusOr LegalizeTfToHlo( IncrementTfMlirBridgeSecondPhaseCounter( MlirBridgeSecondPhaseMetric::kMlirCombinedOldSuccess); - phase2_combined_bridge_compilation_time->GetCell(kCombinedBridgeTimer) - ->Add(timer.ElapsedCyclesInMilliseconds()); return *compilation_result; } diff --git a/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo_test.cc index 67de8464a9c587..fef88d10777cad 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo_test.cc @@ -27,8 +27,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "xla/client/client_library.h" #include "xla/shape.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/monitoring/cell_reader.h" @@ -80,7 +80,7 @@ tsl::StatusOr CompileMlirModule( mlir_to_hlo_args.mlir_module = module_str; se::Platform* platform = - se::MultiPlatformManager::PlatformWithName("Host").value(); + se::PlatformManager::PlatformWithName("Host").value(); auto client = xla::ClientLibrary::GetOrCreateCompileOnlyClient(platform).value(); diff --git a/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.cc b/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.cc index f49e4469f550d7..3672b0bbe0f47a 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.cc @@ -21,6 +21,11 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/tf2xla/tf2xla_defs.h" #include "tensorflow/core/common_runtime/function_body.h" @@ -36,7 +41,7 @@ using ::mlir::success; namespace { LogicalResult HasAttr( - const Graph& graph, const FunctionLibraryDefinition& function_library, + const Graph& graph, const FunctionLibraryDefinition* function_library, const std::function& predicate) { if (predicate(graph)) { return success(); @@ -45,12 +50,13 @@ LogicalResult HasAttr( // Check if any reachable functions from the graph has the target attribute. GraphDef graph_def; graph.ToGraphDef(&graph_def); + if (!function_library) return failure(); for (const std::string& func_name : - function_library.ReachableDefinitions(graph_def).ListFunctionNames()) { - const FunctionDef* func_def = function_library.Find(func_name); + function_library->ReachableDefinitions(graph_def).ListFunctionNames()) { + const FunctionDef* func_def = function_library->Find(func_name); std::unique_ptr func_body; absl::Status status = FunctionDefToBodyHelper( - *func_def, AttrSlice(&func_def->attr()), &function_library, &func_body); + *func_def, AttrSlice(&func_def->attr()), function_library, &func_body); // This is not expected to happen in practice if (!status.ok()) { LOG(ERROR) << "Failed to parse " << func_name << ": " @@ -63,41 +69,9 @@ LogicalResult HasAttr( } return failure(); } -} // namespace - -bool HasTpuReplicateAttr(const Graph& graph, - const FunctionLibraryDefinition& function_library) { - auto predicate = [](const Graph& graph) { - for (const Node* node : graph.nodes()) { - // _tpu_replicate is used in replicated TPU graphs. It will be converted - // to_replication_info and _xla_compile_device_type in phase 1 pipelines. - if (node->attrs().FindByString(std::string(kTpuReplicateAttr))) { - return true; - } - } - return false; - }; - return HasAttr(graph, function_library, predicate).succeeded(); -} - -bool HasCompileDeviceTypeAttr( - const Graph& graph, const FunctionLibraryDefinition& function_library) { - auto predicate = [](const Graph& graph) { - for (const Node* node : graph.nodes()) { - // _xla_compile_device_type is found in CPU/GPU graphs with top-level - // compilation markers or single-core TPU graphs. - if (auto attr = - node->attrs().FindByString(std::string(kCompileDeviceTypeAttr))) { - return true; - } - } - return false; - }; - return HasAttr(graph, function_library, predicate).succeeded(); -} bool IsNonReplicatedGraph(const Graph& graph, - const FunctionLibraryDefinition& function_library) { + const FunctionLibraryDefinition* function_library) { auto predicate = [](const Graph& graph) { const std::string kStatefulPartitionedCallOp = "StatefulPartitionedCall"; for (const Node* node : graph.nodes()) { @@ -116,8 +90,23 @@ bool IsNonReplicatedGraph(const Graph& graph, return HasAttr(graph, function_library, predicate).succeeded(); } +bool IsReplicatedGraph(const Graph& graph, + const FunctionLibraryDefinition* function_library) { + auto predicate = [](const Graph& graph) { + for (const Node* node : graph.nodes()) { + // _tpu_replicate is used in replicated TPU graphs. It will be converted + // to_replication_info and _xla_compile_device_type in phase 1 pipelines. + if (node->attrs().FindByString(std::string(kTpuReplicateAttr))) { + return true; + } + } + return false; + }; + return HasAttr(graph, function_library, predicate).succeeded(); +} + bool IsSingleCoreTpuGraph(const Graph& graph, - const FunctionLibraryDefinition& function_library) { + const FunctionLibraryDefinition* function_library) { auto predicate = [](const Graph& graph) { for (const Node* node : graph.nodes()) { // _xla_compile_device_type=TPU is found in single-core TPU graphs. @@ -132,4 +121,57 @@ bool IsSingleCoreTpuGraph(const Graph& graph, return HasAttr(graph, function_library, predicate).succeeded(); } +bool IsReplicatedGraph(mlir::ModuleOp module) { + auto walk_result = module.walk([&](mlir::Operation* op) { + // TODO(b/223677572): Once the scope for new compilation and replication + // markers is expanded beyond bridge we can remove this check for + // `kTPUReplicateAttr`, we will then always have a `kCompileDeviceTypeAttr` + // in such cases (see above). + // TODO(b/229028654): Remove string conversion once we have C++17. + const llvm::StringRef tpu_replicate_attr_name(kTpuReplicateAttr.data(), + kTpuReplicateAttr.size()); + auto replicate_attr = + op->getAttrOfType(tpu_replicate_attr_name); + if (replicate_attr) return mlir::WalkResult::interrupt(); + return mlir::WalkResult::advance(); + }); + return walk_result.wasInterrupted(); +} + +bool IsSingleCoreTPUGraph(mlir::ModuleOp module) { + auto walk_result = module.walk([&](mlir::Operation* op) { + // Check for ops with compile device type "TPU". This allows us to support + // TPU compilation without replication. Note that currently the compile + // device type is not set by default before bridge, only if eager context + // attribute `jit_compile_rewrite` is true. + // TODO(b/229028654): Remove string conversion once we have C++17. + const llvm::StringRef compile_device_type_attr_name( + kCompileDeviceTypeAttr.data(), kCompileDeviceTypeAttr.size()); + auto compilation_attr = + op->getAttrOfType(compile_device_type_attr_name); + if (compilation_attr && compilation_attr.getValue().str() == kTpuDevice) { + return mlir::WalkResult::interrupt(); + } + return mlir::WalkResult::advance(); + }); + return walk_result.wasInterrupted(); +} + +} // namespace + +bool IsSupportedByNonReplicatedBridge( + const Graph& graph, const FunctionLibraryDefinition* function_library) { + return IsNonReplicatedGraph(graph, function_library); +} + +bool IsSupportedByReplicatedBridge( + const Graph& graph, const FunctionLibraryDefinition* function_library) { + return IsReplicatedGraph(graph, function_library) || + IsSingleCoreTpuGraph(graph, function_library); +} + +bool IsSupportedByReplicatedBridge(mlir::ModuleOp module) { + return IsReplicatedGraph(module) || IsSingleCoreTPUGraph(module); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.h b/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.h index 421f6c20bec9b8..5ea0bdc71ea0d0 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.h +++ b/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.h @@ -16,37 +16,27 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_MLIR_BRIDGE_PASS_UTIL_H_ #define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_MLIR_BRIDGE_PASS_UTIL_H_ -#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "tensorflow/core/framework/function.h" namespace tensorflow { -using ::mlir::LogicalResult; - -// Checks if a graph or reachable functions in the library have any ops with -// _xla_compile_device_type attribute. -bool HasCompileDeviceTypeAttr( - const Graph& graph, const FunctionLibraryDefinition& function_library); - -// Checks if a graph or reachable functions in the library have any ops with -// _XlaMustCompile attribute. -bool HasMustCompileAttr(const Graph& graph, - const FunctionLibraryDefinition& function_library); +// Checks if a graph or reachable functions in the library have any +// StatefulPartitionedOps with _XlaMustCompile=true. The function library will +// be skipped if nullptr is provided. +bool IsSupportedByNonReplicatedBridge( + const Graph& graph, const FunctionLibraryDefinition* function_library); // Checks if a graph or reachable functions in the library have any ops with -// _tpu_replicate attribute. -bool HasTpuReplicateAttr(const Graph& graph, - const FunctionLibraryDefinition& function_library); +// _tpu_replicate or _xla_compile_device_type=TPU. The function library will be +// skipped if nullptr is provided. -// Checks if a graph or reachable functions in the library have any -// StatefulPartitionedCall ops with _XlaMustCompile attribute. -bool IsNonReplicatedGraph(const Graph& graph, - const FunctionLibraryDefinition& function_library); +bool IsSupportedByReplicatedBridge( + const Graph& graph, const FunctionLibraryDefinition* function_library); -// Checks if a graph or reachable functions in the library have any ops with +// Check if an MLIR module has any ops with _tpu_replicate or // _xla_compile_device_type=TPU. -bool IsSingleCoreTpuGraph(const Graph& graph, - const FunctionLibraryDefinition& function_library); +bool IsSupportedByReplicatedBridge(mlir::ModuleOp module); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util_test.cc index 8132eb66932f33..8ce19560f1f349 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util_test.cc @@ -18,9 +18,15 @@ limitations under the License. #include #include +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/tf2xla/tf2xla_defs.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function_testlib.h" @@ -54,7 +60,7 @@ FunctionDef OuterXTimesTwo() { {std::string(kMustCompileAttr), true}}}}); } -TEST(HasCompileDeviceTypeAttr, GraphWithXlaClusters) { +TEST(IsSupportedByNonReplicatedBridge, NonReplicatedGraph) { const FunctionDef& fd = test::function::XTimesTwo(); FunctionDefLibrary flib; *flib.add_function() = fd; @@ -79,21 +85,21 @@ TEST(HasCompileDeviceTypeAttr, GraphWithXlaClusters) { .Attr("Tout", {DT_FLOAT}) .Attr("f", f_name_attr) .Finalize(root.graph(), &call)); - call->AddAttr(std::string(kCompileDeviceTypeAttr), kGpuDevice); + call->AddAttr(std::string(kMustCompileAttr), true); TF_ASSERT_OK(root.ToGraph(&graph)); - FunctionLibraryDefinition empty_flib_def(OpRegistry::Global()); EXPECT_TRUE( - HasCompileDeviceTypeAttr(graph, /*function_library=*/empty_flib_def)); + IsSupportedByNonReplicatedBridge(graph, /*function_library=*/nullptr)); } -TEST(HasTpuReplicateAttr, GraphWithXlaClusters) { - const FunctionDef& fd = test::function::XTimesTwo(); +// Checks that HasAttr actually goes through function library. +TEST(IsSupportedByNonReplicatedBridge, NonReplicatedFunctionLibrary) { + const FunctionDef& fd = OuterXTimesTwo(); FunctionDefLibrary flib; *flib.add_function() = fd; FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); - Graph graph(flib_def); + Graph graph(OpRegistry::Global()); graph.SetConstructionContext(ConstructionContext::kEagerRuntime); tensorflow::set_tf2_execution(true); @@ -103,6 +109,8 @@ TEST(HasTpuReplicateAttr, GraphWithXlaClusters) { Output a = ops::_Arg(root.WithOpName("A"), DT_FLOAT, 0); std::vector inputs({NodeBuilder::NodeOut(a.node())}); + // Builds a call without compilation markers that calls a function with Xla + // clusters. Node* call; NameAttrList f_name_attr; f_name_attr.set_name(fd.signature().name()); @@ -113,15 +121,13 @@ TEST(HasTpuReplicateAttr, GraphWithXlaClusters) { .Attr("Tout", {DT_FLOAT}) .Attr("f", f_name_attr) .Finalize(root.graph(), &call)); - call->AddAttr(std::string(kTpuReplicateAttr), "cluster"); TF_ASSERT_OK(root.ToGraph(&graph)); - - FunctionLibraryDefinition empty_flib_def(OpRegistry::Global()); - EXPECT_TRUE(HasTpuReplicateAttr(graph, /*function_library=*/empty_flib_def)); + EXPECT_TRUE( + IsSupportedByNonReplicatedBridge(graph, /*function_library=*/&flib_def)); } -TEST(IsNonReplicatedGraph, GraphWithXlaClusters) { +TEST(IsSupportedByReplicatedBridge, ReplicatedGraph) { const FunctionDef& fd = test::function::XTimesTwo(); FunctionDefLibrary flib; *flib.add_function() = fd; @@ -146,48 +152,15 @@ TEST(IsNonReplicatedGraph, GraphWithXlaClusters) { .Attr("Tout", {DT_FLOAT}) .Attr("f", f_name_attr) .Finalize(root.graph(), &call)); - call->AddAttr(std::string(kMustCompileAttr), true); + call->AddAttr(std::string(kTpuReplicateAttr), "cluster"); TF_ASSERT_OK(root.ToGraph(&graph)); - FunctionLibraryDefinition empty_flib_def(OpRegistry::Global()); - EXPECT_TRUE(IsNonReplicatedGraph(graph, /*function_library=*/empty_flib_def)); -} - -// Checks that HasAttr actually goes through function library. -TEST(IsNonReplicatedGraph, FunctionLibraryWithXlaClusters) { - const FunctionDef& fd = OuterXTimesTwo(); - FunctionDefLibrary flib; - *flib.add_function() = fd; - FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); - Graph graph(OpRegistry::Global()); - graph.SetConstructionContext(ConstructionContext::kEagerRuntime); - tensorflow::set_tf2_execution(true); - - ConfigProto config = ConfigProto(); - Scope root = Scope::NewRootScope().ExitOnError(); - - Output a = ops::_Arg(root.WithOpName("A"), DT_FLOAT, 0); - std::vector inputs({NodeBuilder::NodeOut(a.node())}); - - // Builds a call without compilation markers that calls a function with Xla - // clusters. - Node* call; - NameAttrList f_name_attr; - f_name_attr.set_name(fd.signature().name()); - TF_ASSERT_OK( - NodeBuilder("B", "StatefulPartitionedCall", &root.graph()->flib_def()) - .Input(inputs) - .Attr("Tin", {DT_FLOAT}) - .Attr("Tout", {DT_FLOAT}) - .Attr("f", f_name_attr) - .Finalize(root.graph(), &call)); - - TF_ASSERT_OK(root.ToGraph(&graph)); - EXPECT_TRUE(IsNonReplicatedGraph(graph, /*function_library=*/flib_def)); + EXPECT_TRUE( + IsSupportedByReplicatedBridge(graph, /*function_library=*/nullptr)); } -TEST(IsSingleCoreTpuGraph, GraphWithXlaClusters) { +TEST(IsSupportedByReplicatedBridge, SingleCoreTpuGraph) { const FunctionDef& fd = test::function::XTimesTwo(); FunctionDefLibrary flib; *flib.add_function() = fd; @@ -216,8 +189,38 @@ TEST(IsSingleCoreTpuGraph, GraphWithXlaClusters) { TF_ASSERT_OK(root.ToGraph(&graph)); - FunctionLibraryDefinition empty_flib_def(OpRegistry::Global()); - EXPECT_TRUE(IsSingleCoreTpuGraph(graph, /*function_library=*/empty_flib_def)); + EXPECT_TRUE( + IsSupportedByReplicatedBridge(graph, /*function_library=*/nullptr)); +} + +TEST(IsSupportedByReplicatedBridge, ReplicatedModule) { + const char* const code = R"mlir( +func.func @entry_func_1(%arg0: tensor) -> tensor attributes {tf.entry_function = {}} { + %0 = "tf.Identity"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> (tensor) + func.return %0 : tensor +} +)mlir"; + mlir::MLIRContext context; + context.loadDialect(); + mlir::OwningOpRef module = + mlir::parseSourceString(code, &context); + ASSERT_TRUE(module); + EXPECT_TRUE(IsSupportedByReplicatedBridge(*module)); +} + +TEST(IsSupportedByReplicatedBridge, SingleCoreTpuModule) { + const char* const code = R"mlir( +func.func @entry_func_1(%arg0: tensor) -> tensor attributes {tf.entry_function = {}} { + %0 = "tf.Identity"(%arg0) {_xla_compile_device_type = "TPU"} : (tensor) -> (tensor) + func.return %0 : tensor +} +)mlir"; + mlir::MLIRContext context; + context.loadDialect(); + mlir::OwningOpRef module = + mlir::parseSourceString(code, &context); + ASSERT_TRUE(module); + EXPECT_TRUE(IsSupportedByReplicatedBridge(*module)); } } // namespace diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD index 6250f2cf0ca7c5..2b215023c13f15 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD @@ -26,6 +26,7 @@ cc_library( deps = [ ":extract_head_tail_outside_compilation", ":extract_outside_compilation", + ":hoist_broadcast_read", ":mark_ops_for_outside_compilation", ":tpu_cluster_formation", ":verify_clustering_pass", @@ -351,6 +352,41 @@ cc_library( ], ) +cc_library( + name = "hoist_broadcast_read", + srcs = ["hoist_broadcast_read.cc"], + textual_hdrs = [ + "clustering_passes.h.inc", + ], + deps = [ + ":clustering_passes_inc_gen", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", + "//tensorflow/compiler/mlir/tensorflow:string_util", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/compiler/mlir/tensorflow:tpu_rewrite_device_util", + "//tensorflow/compiler/mlir/tensorflow/transforms:lower_tf_lib", + "//tensorflow/compiler/mlir/tensorflow/transforms:tf_pass_inc_gen", + "//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass", + "//tensorflow/compiler/mlir/tf2xla/transforms:legalization_op_config", + "//tensorflow/compiler/mlir/tf2xla/transforms:legalize_tf", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Rewrite", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], +) + tf_cc_test( name = "tpu_cluster_formation_test", srcs = ["tpu_cluster_formation_test.cc"], @@ -366,3 +402,66 @@ tf_cc_test( "@local_tsl//tsl/platform:statusor", ], ) + +cc_library( + name = "lowering_passes", + hdrs = [ + "lowering_passes.h", + ], + textual_hdrs = [ + "lowering_passes.h.inc", + ], + deps = [ + ":input_metrics_lowering_pass", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + ], +) + +gentbl_cc_library( + name = "lowering_passes_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=TFXLABridgeLowering", + ], + "lowering_passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "lowering_passes.td", + deps = [ + "@llvm-project//mlir:PassBaseTdFiles", + ], +) + +cc_library( + name = "input_metrics_lowering_pass", + srcs = [ + "input_lowering_metrics_pass.cc", + ], + deps = [ + ":lowering_passes_inc_gen", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + ], +) + +tf_cc_test( + name = "input_metrics_lowering_pass_test", + srcs = ["input_lowering_metrics_pass_test.cc"], + deps = [ + ":lowering_passes", + "//tensorflow/compiler/mlir/tf2xla/transforms:test_utils", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:statusor", + ], +) diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h index ea6187a2309205..3ccf990a4b5272 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h @@ -56,6 +56,11 @@ CreateXlaOutlineEntryFunctionsPass(); std::unique_ptr> CreateMarkOpsForOutsideCompilationPass(); +// Creates a pass that hoists reads out of a replicate that are on a variable +// whose value is broacast to all replicas. +std::unique_ptr> +CreateHoistBroadcastReadPass(); + #define GEN_PASS_REGISTRATION #define GEN_PASS_DECL_MARKOPSFOROUTSIDECOMPILATIONPASS #define GEN_PASS_DECL_TPUCLUSTERFORMATIONPASS diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td index c219c35842c401..90d2e962bc9b1c 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td @@ -349,3 +349,42 @@ def MarkOpsForOutsideCompilationPass : Pass<"tf-mark-ops-for-outside-compilation let constructor = "tensorflow::tf2xla::internal::CreateMarkOpsForOutsideCompilationPass()"; } + +def HoistBroadcastReadPass : Pass<"tf-hoist-broadcast-read", "mlir::func::FuncOp"> { + let summary = "Hoist reads out of a replicate that are on a resource that is broacast to all replicas."; + + let description = [{ + Some `ReadVariableOp`s that are within a `tf_device.replicate` read the same + value across all replicas. These reads can be hoisted out of the + `tf_device.replicate` so there's one read for all replicas, and each replica + depends on the result of the read. This transform enables the + xla-broadcast-pass to optimize the broadcast value. + + For example, the following: + + ```mlir + tf_device.replicate {n = 2 : i32} { + %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource>>) -> tensor + "tf.OpA"(%0) : (tensor) -> () + } + ``` + + will be transformed into: + + ``mlir + %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource>>) -> tensor + tf_device.replicate {n = 2 : i32} { + "tf.OpA"(%0) : (tensor) -> () + } + ``` + + We must ensure that there is a single underlying resource that not + distributed across replicas. There is a single underlying resource when the + resource device type is CPU, so we cautiously only apply in this case. + + To be cautious we never hoist a read that comes after a write to the same + resource. + }]; + + let constructor = "tensorflow::tf2xla::internal::CreateHoistBroadcastReadPass()"; +} diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/hoist_broadcast_read.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/hoist_broadcast_read.cc new file mode 100644 index 00000000000000..732bae8c67b018 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/hoist_broadcast_read.cc @@ -0,0 +1,154 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { +namespace tf2xla { +namespace internal { + +namespace { + +using mlir::BlockArgument; +using mlir::failure; +using mlir::LogicalResult; +using mlir::Operation; +using mlir::OperationPass; +using mlir::OpOperand; +using mlir::StringAttr; +using mlir::success; +using mlir::Value; +using mlir::WalkResult; +using mlir::func::FuncOp; +using mlir::TF::ReadVariableOp; +using mlir::tf_device::ReplicateOp; + +#define GEN_PASS_DEF_HOISTBROADCASTREADPASS +#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h.inc" + +constexpr char kFuncDeviceAttr[] = "tf.device"; +constexpr char kCpuDeviceType[] = "CPU"; + +struct HoistBroadcastRead + : public impl::HoistBroadcastReadPassBase { + void runOnOperation() override; +}; + +// Get the ancestor of `descendant` that is a direct child of `ancestor`. +Operation* GetAncestorBelow(Operation* descendant, Operation* ancestor) { + Operation* parent = descendant->getParentOp(); + if (!parent) return nullptr; + if (parent == ancestor) return descendant; + return GetAncestorBelow(parent, ancestor); +} + +// `is_cpu_read` is set to `true` iff `read` is on a resource with device type +// CPU. +LogicalResult IsCpuRead(FuncOp func, ReadVariableOp read, bool& is_cpu_read) { + if (auto arg = read->getOperand(0).dyn_cast()) { + if (arg.getOwner() != &(func.front())) { + is_cpu_read = false; + return success(); + } + if (auto attr = func.getArgAttrOfType(arg.getArgNumber(), + kFuncDeviceAttr)) { + std::string device = attr.getValue().str(); + tensorflow::DeviceNameUtils::ParsedName parsed_name; + if (!tensorflow::DeviceNameUtils::ParseFullName(device, &parsed_name)) { + return read->emitOpError() << "invalid device '" << device << "'"; + } + is_cpu_read = parsed_name.type == kCpuDeviceType; + return success(); + } + } + is_cpu_read = false; + return success(); +} + +// Get the reads to hoist in the `replicate`. +LogicalResult GetReads(FuncOp func, ReplicateOp replicate, + llvm::SmallVector& reads) { + for (Operation& op : replicate.getBody().front()) { + if (auto read = llvm::dyn_cast(&op)) { + bool is_cpu_read; + if (failed(IsCpuRead(func, read, is_cpu_read))) return failure(); + if (is_cpu_read) reads.push_back(read); + } + } + return success(); +} + +// Move reads above the `replicate`. Skip reads that come after a write to the +// same resource. +void MoveReads(ReplicateOp replicate, + llvm::SmallVector& reads) { + for (ReadVariableOp read : reads) { + Value res = read.getResource(); + Operation* scope = res.getParentBlock()->getParentOp(); + if (!scope->isProperAncestor(replicate)) continue; + bool has_conflicting_write = false; + for (OpOperand& use : res.getUses()) { + Operation* using_op = use.getOwner(); + if (using_op == read) continue; + if (!replicate->isProperAncestor(using_op)) continue; + Operation* peer = GetAncestorBelow(using_op, replicate); + if (read->isBeforeInBlock(peer)) continue; + if (llvm::isa(peer)) continue; + has_conflicting_write = true; + } + if (has_conflicting_write) continue; + read->moveBefore(replicate); + } +} + +// Hoist `ReadVariableOp`s above the `tf_device.replicate`s. +void HoistBroadcastRead::runOnOperation() { + FuncOp func = getOperation(); + + auto result = func.walk([&](ReplicateOp replicate) { + llvm::SmallVector reads; + if (failed(GetReads(func, replicate, reads))) + return WalkResult::interrupt(); + MoveReads(replicate, reads); + return WalkResult::advance(); + }); + + if (result.wasInterrupted()) return signalPassFailure(); +} + +} // namespace + +std::unique_ptr> CreateHoistBroadcastReadPass() { + return std::make_unique(); +} + +} // namespace internal +} // namespace tf2xla +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/input_lowering_metrics_pass.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/input_lowering_metrics_pass.cc new file mode 100644 index 00000000000000..8dcfde07c013be --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/input_lowering_metrics_pass.cc @@ -0,0 +1,46 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace tensorflow { +namespace tf2xla { +namespace internal { + +namespace { + +#define GEN_PASS_DEF_INPUTLOWERINGMETRICSPASS +#include "tensorflow/compiler/mlir/tf2xla/internal/passes/lowering_passes.h.inc" + +class InputMetricsLoweringPass + : public impl::InputLoweringMetricsPassBase { + public: + void runOnOperation() override; +}; + +void InputMetricsLoweringPass::runOnOperation() {} +} // namespace + +std::unique_ptr> +CreateInputLoweringMetricsPass() { + return std::make_unique(); +} + +} // namespace internal +} // namespace tf2xla +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/input_lowering_metrics_pass_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/input_lowering_metrics_pass_test.cc new file mode 100644 index 00000000000000..5fe9e58f9f23da --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/input_lowering_metrics_pass_test.cc @@ -0,0 +1,87 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tf2xla/internal/passes/lowering_passes.h" +#include "tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h" +#include "tsl/platform/statusor.h" + +namespace tensorflow { +namespace tf2xla { +namespace internal { + +namespace { + +using mlir::LogicalResult; +using mlir::ModuleOp; +using mlir::mhlo::test::GetMlirModuleFromString; + +class InputLoweringMetricsPassTest : public testing::Test { + protected: + void CreateModule(const char* module_string) { + TF_ASSERT_OK_AND_ASSIGN(module_, + GetMlirModuleFromString(module_string, &context_)); + pm_ = std::make_unique(&context_); + pm_->addNestedPass(CreateInputLoweringMetricsPass()); + } + + bool ModulesEqual(const ModuleOp& module_before, + const ModuleOp& module_after) { + return mlir::OperationEquivalence::isEquivalentTo( + module_before, module_after, mlir::OperationEquivalence::None); + } + + mlir::LogicalResult Run() { + mlir::OwningOpRef module_before = module_->clone(); + LogicalResult run_result = pm_->run(module_.get()); + + EXPECT_TRUE(ModulesEqual(*module_before, *module_)); + return run_result; + } + + private: + mlir::MLIRContext context_; + mlir::OwningOpRef module_; + std::unique_ptr pm_; +}; + +TEST_F(InputLoweringMetricsPassTest, RunsSuccessfully) { + static constexpr char kMlirModuleStr[] = R"( + module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { + func.func @main() -> tensor<1xi32> { + %0 = "tf.Const"() {value = dense<1000> : tensor<1xi32>} : () -> tensor<1xi32> + return %0 : tensor<1xi32> + } + })"; + CreateModule(kMlirModuleStr); + + auto result = Run(); + + EXPECT_TRUE(result.succeeded()); +} + +} // namespace +} // namespace internal +} // namespace tf2xla +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/lowering_passes.h b/tensorflow/compiler/mlir/tf2xla/internal/passes/lowering_passes.h new file mode 100644 index 00000000000000..0be689c6637ba6 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/lowering_passes.h @@ -0,0 +1,38 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_LOWERING_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_LOWERING_PASSES_H_ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace tensorflow { +namespace tf2xla { +namespace internal { + +// Create a pass that just collects metrics about the input MLIR. Does not +// logically transform the program. +std::unique_ptr> +CreateInputLoweringMetricsPass(); + +#define GEN_PASS_REGISTRATION +#define GEN_PASS_DECL_INPUTLOWERINGMETRICSPASS +#include "tensorflow/compiler/mlir/tf2xla/internal/passes/lowering_passes.h.inc" + +} // namespace internal +} // namespace tf2xla +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_LOWERING_PASSES_H_ diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/lowering_passes.td b/tensorflow/compiler/mlir/tf2xla/internal/passes/lowering_passes.td new file mode 100644 index 00000000000000..3247145bce0200 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/lowering_passes.td @@ -0,0 +1,25 @@ +/* Copyright 2024 TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +include "mlir/Pass/PassBase.td" + +def InputLoweringMetricsPass : Pass<"input-lowering-metrics-pass", "mlir::func::FuncOp"> { + + let summary = "Collects various metrics about the input to the lowering " + "portion of the bridge. This is a logical no-op."; + + let description = [{ + Gathers metrics about the input MLIR to Phase 2 of the TFXLA Bridge, which + does a strict semantic lowering from Tensorflow ops to XLA HLO. + }]; + + let constructor = "tensorflow::tf2xla::internal::CreateInputLoweringMetricsPass()"; +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass_test.mlir b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass_test.mlir index 7ba98798c126df..c098bf494272fd 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass_test.mlir +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass_test.mlir @@ -3,7 +3,7 @@ func.func @testNotTfDialect(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { // expected-error@below {{op is in dialect chlo not in tf functional dialect}} - %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> func.return %0 : tensor<1x32x10x32xi32> } diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_input_dialect_to_executor_pass_test.mlir b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_input_dialect_to_executor_pass_test.mlir index 9d8b3e2e690f52..7d88228447e4af 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_input_dialect_to_executor_pass_test.mlir +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_input_dialect_to_executor_pass_test.mlir @@ -33,6 +33,6 @@ func.func @testTFDialect(%arg0: tensor<4x?x!tf_type.stringref>) -> tensor<4x2x!t func.func @testNotTfDialect(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { // expected-error@below {{op is in dialect chlo which is not an accepted dialect}} - %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> func.return %0 : tensor<1x32x10x32xi32> } diff --git a/tensorflow/compiler/mlir/tf2xla/internal/utils/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/utils/BUILD index a67178be9d770a..890b2277397557 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/utils/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/internal/utils/BUILD @@ -43,3 +43,36 @@ tf_cc_test( "@stablehlo//:chlo_ops", ], ) + +cc_library( + name = "test_metadata_config", + testonly = True, + srcs = ["test_metadata_config.cc"], + hdrs = ["test_metadata_config.h"], + visibility = [ + "//tensorflow/compiler/mlir/tf2xla/api:__subpackages__", + ], + deps = [ + "//tensorflow/compiler/jit", + "//tensorflow/compiler/jit:xla_tpu_device", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:convert_type", + "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils", + "//tensorflow/compiler/tf2xla:layout_util", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/compiler/tf2xla:xla_tpu_backend_registration", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:framework", + "//tensorflow/core:test_main", + "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", + "//tensorflow/core/tpu/kernels/xla:host_compute_ops", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:errors", + "@local_xla//xla:shape_util", + "@local_xla//xla/translate/mhlo_to_hlo:type_to_shape", + ], +) diff --git a/tensorflow/compiler/mlir/tf2xla/internal/utils/test_metadata_config.cc b/tensorflow/compiler/mlir/tf2xla/internal/utils/test_metadata_config.cc new file mode 100644 index 00000000000000..a74af203e52fc8 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/utils/test_metadata_config.cc @@ -0,0 +1,105 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tf2xla/internal/utils/test_metadata_config.h" + +#include + +#include "absl/status/status.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" +#include "xla/shape.h" +#include "xla/translate/mhlo_to_hlo/type_to_shape.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" +#include "tsl/platform/errors.h" + +namespace tensorflow { +namespace tf2xla { +namespace internal { +namespace { + +constexpr char kEntryFuncName[] = "main"; + +absl::Status SetupArguments(mlir::ModuleOp module, + std::vector& arg_shapes, + tpu::TPUCompileMetadataProto& metadata_proto) { + auto main_fn = module.lookupSymbol(kEntryFuncName); + if (!main_fn) { + return absl::InternalError("Could not find main function in MLIR Module."); + } + + mlir::FunctionType func_type = main_fn.getFunctionType(); + for (auto input_type : func_type.getInputs()) { + tensorflow::TensorShape tensor_shape; + xla::Shape xla_shape = xla::TypeToShape(input_type); + TF_RETURN_IF_ERROR(tensorflow::TensorShape::BuildTensorShape( + xla_shape.dimensions(), &tensor_shape)); + arg_shapes.emplace_back(tensor_shape); + + DataType dtype; + TF_RETURN_IF_ERROR(ConvertToDataType(input_type, &dtype)); + + auto metadata_arg = metadata_proto.add_args(); + metadata_arg->set_kind(tpu::TPUCompileMetadataProto::Arg::PARAMETER); + metadata_arg->set_dtype(dtype); + } + + return absl::OkStatus(); +} + +absl::Status SetupReturnValues(mlir::ModuleOp module, + tpu::TPUCompileMetadataProto& metadata_proto) { + auto main_fn = module.lookupSymbol(kEntryFuncName); + if (!main_fn) { + return absl::InternalError("Could not find main function in MLIR Module."); + } + + int func_results = main_fn.getFunctionType().getNumResults(); + for (int i = 0; i < func_results; i++) { + metadata_proto.add_retvals(); + } + + return absl::OkStatus(); +} + +} // namespace + +absl::Status ConfigureMetadata(absl::string_view mlir_module_str, + std::vector& arg_shapes, + tpu::TPUCompileMetadataProto& metadata_proto) { + mlir::DialectRegistry registry; + mlir::RegisterAllTensorFlowDialects(registry); + mlir::MLIRContext context(registry); + mlir::OwningOpRef mlir_module; + + TF_RETURN_IF_ERROR( + DeserializeMlirModule(mlir_module_str, &context, &mlir_module)); + TF_RETURN_IF_ERROR(SetupReturnValues(*mlir_module, metadata_proto)); + TF_RETURN_IF_ERROR(SetupArguments(*mlir_module, arg_shapes, metadata_proto)); + + return absl::OkStatus(); +} + +} // namespace internal +} // namespace tf2xla +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/internal/utils/test_metadata_config.h b/tensorflow/compiler/mlir/tf2xla/internal/utils/test_metadata_config.h new file mode 100644 index 00000000000000..6b9ec5244c0bee --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/utils/test_metadata_config.h @@ -0,0 +1,40 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_UTILS_TEST_METADATA_CONFIG_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_UTILS_TEST_METADATA_CONFIG_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" + +namespace tensorflow { +namespace tf2xla { +namespace internal { + +// Fills in arg_shapes and metadata_proto with appropriate values based on the +// input mlir module. +absl::Status ConfigureMetadata(absl::string_view mlir_module_str, + std::vector& arg_shapes, + tpu::TPUCompileMetadataProto& metadata_proto); + +} // namespace internal +} // namespace tf2xla +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_UTILS_TEST_METADATA_CONFIG_H_ diff --git a/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.cc b/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.cc index 3f35813744c60d..f00e12690d380b 100644 --- a/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.cc +++ b/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.cc @@ -24,7 +24,8 @@ namespace tensorflow { MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy( const tensorflow::Graph& graph, const FunctionLibraryDefinition* function_library, - std::optional config_proto, bool run_tpu_bridge, + std::optional config_proto, + bool is_supported_by_replicated_brige, bool uses_uninitialized_resource_args, bool is_v1_compat, bool record_stats) { switch (GetMlirBridgeRolloutState(config_proto)) { diff --git a/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h b/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h index 5c7f47a219e10e..66a68ae53f535f 100644 --- a/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h +++ b/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h @@ -53,7 +53,8 @@ enum class MlirBridgeRolloutPolicy { MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy( const tensorflow::Graph& graph, const FunctionLibraryDefinition* function_library, - std::optional config_proto, bool run_tpu_bridge, + std::optional config_proto, + bool is_supported_by_replicated_brige, bool uses_uninitialized_resource_args, bool is_v1_compat, bool record_stats); diff --git a/tensorflow/compiler/mlir/tf2xla/tests/BUILD b/tensorflow/compiler/mlir/tf2xla/tests/BUILD index c68c485954de1b..97bb01c30d1855 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/tests/BUILD @@ -1,5 +1,5 @@ -load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow:tensorflow.default.bzl", "filegroup") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-communication.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-communication.mlir index 11cfffc24eaa33..d79804a6d38cc6 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-communication.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-communication.mlir @@ -119,7 +119,9 @@ func.func @send_to_host(%arg0: tensor) { // CHECK: [[INIT_TOKEN:%.*]] = mhlo.create_token // CHECK: "mhlo.send"([[ARG0]], [[INIT_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 2 // CHECK-SAME: is_host_transfer = true // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "send_key_dtoh_0"} // CHECK-SAME: mhlo.sharding = "\08\01\1A\01\01\22\01\00" @@ -137,7 +139,9 @@ func.func @recv_from_host() -> tensor { // CHECK: [[INIT_TOKEN:%.*]] = mhlo.create_token // CHECK: [[RECV_TUPLE:%.*]]:2 = "mhlo.recv"([[INIT_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 3 // CHECK-SAME: is_host_transfer = true // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "recv_key_htod_0"} // CHECK-SAME: mhlo.sharding = "\08\01\1A\01\01\22\01\00" @@ -158,21 +162,29 @@ func.func @multiple_consecutive_ops(%arg0: tensor) -> tensor { // CHECK: [[INIT_TOKEN:%.*]] = mhlo.create_token // CHECK: [[SEND0_ARG0_TOKEN:%.*]] = "mhlo.send"([[ARG0]], [[INIT_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 2 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "send0_dtoh_0"} // CHECK: [[RECV0_RETVAL0_TUPLE:%.*]]:2 = "mhlo.recv"([[SEND0_ARG0_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 3 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "recv0_htod_0"} %0 = "tf._XlaHostComputeMlir"(%arg0) {recv_key = "recv0", send_key = "send0", host_mlir_module = ""} : (tensor) -> tensor // CHECK: [[SEND1_ARG0_TOKEN:%.*]] = "mhlo.send"([[RECV0_RETVAL0_TUPLE]]#0, [[RECV0_RETVAL0_TUPLE]]#1) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 2 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "send1_dtoh_0"} // CHECK: [[RECV1_RETVAL0_TUPLE:%.*]]:2 = "mhlo.recv"([[SEND1_ARG0_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 3 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "recv1_htod_0"} %1 = "tf._XlaHostComputeMlir"(%0) {recv_key = "recv1", send_key = "send1", host_mlir_module = ""} : (tensor) -> tensor @@ -376,11 +388,15 @@ func.func @if_both_branches(%arg0: tensor, %arg1: tensor, %arg2: tensor // CHECK: [[IF:%.*]]:2 = "mhlo.if"([[ARG0]]) %0 = "mhlo.if"(%arg0) ({ // CHECK: [[TRUE_SEND_TOKEN:%.*]] = "mhlo.send"([[ARG1]], [[INIT_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 2 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "send_if_true_dtoh_0"} // CHECK: [[TRUE_RECV_TUPLE:%.*]]:2 = "mhlo.recv"([[TRUE_SEND_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 3 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "recv_if_true_htod_0"} %1 = "tf._XlaHostComputeMlir"(%arg1) {recv_key = "recv_if_true", send_key = "send_if_true", host_mlir_module = ""} : (tensor) -> tensor @@ -388,11 +404,15 @@ func.func @if_both_branches(%arg0: tensor, %arg1: tensor, %arg2: tensor "mhlo.return"(%1) : (tensor) -> () }, { // CHECK: [[FALSE_SEND_TOKEN:%.*]] = "mhlo.send"([[ARG2]], [[INIT_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 2 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "send_if_false_dtoh_0"} // CHECK: [[FALSE_RECV_TUPLE:%.*]]:2 = "mhlo.recv"([[FALSE_SEND_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 3 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "recv_if_false_htod_0"} %1 = "tf._XlaHostComputeMlir"(%arg2) {recv_key = "recv_if_false", send_key = "send_if_false", host_mlir_module = ""} : (tensor) -> tensor @@ -419,11 +439,15 @@ func.func @if_true_branch(%arg0: tensor, %arg1: tensor, %arg2: tensor + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 2 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "send_if_true_dtoh_0"} // CHECK: [[TRUE_RECV_TUPLE:%.*]]:2 = "mhlo.recv"([[TRUE_SEND_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 3 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "recv_if_true_htod_0"} %1 = "tf._XlaHostComputeMlir"(%arg1) {recv_key = "recv_if_true", send_key = "send_if_true", host_mlir_module = ""} : (tensor) -> tensor @@ -456,11 +480,15 @@ func.func @if_false_branch(%arg0: tensor, %arg1: tensor, %arg2: tensor< "mhlo.return"(%arg1) : (tensor) -> () }, { // CHECK: [[FALSE_SEND_TOKEN:%.*]] = "mhlo.send"([[ARG2]], [[INIT_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 2 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "send_if_false_dtoh_0"} // CHECK: [[FALSE_RECV_TUPLE:%.*]]:2 = "mhlo.recv"([[FALSE_SEND_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 3 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "recv_if_false_htod_0"} %1 = "tf._XlaHostComputeMlir"(%arg2) {recv_key = "recv_if_false", send_key = "send_if_false", host_mlir_module = ""} : (tensor) -> tensor @@ -681,11 +709,15 @@ func.func @while_cond_body(%arg0: tensor) -> tensor { %0 = "mhlo.while"(%arg0) ({ ^bb0(%arg1: tensor): // CHECK: [[COND_SEND_TOKEN:%.*]] = "mhlo.send"([[ITER_ARG_VALUE]], [[ITER_ARG_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 2 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "send_while_cond_dtoh_0"} // CHECK: [[COND_RECV_TUPLE:%.*]]:2 = "mhlo.recv"([[COND_SEND_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 3 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "recv_while_cond_htod_0"} %1 = "tf._XlaHostComputeMlir"(%arg1) {recv_key = "recv_while_cond", send_key = "send_while_cond", host_mlir_module = ""} : (tensor) -> tensor @@ -697,11 +729,15 @@ func.func @while_cond_body(%arg0: tensor) -> tensor { }, { ^bb0(%arg1: tensor): // CHECK: [[BODY_SEND_TOKEN:%.*]] = "mhlo.send"([[ITER_ARG_VALUE]], [[ITER_ARG_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 2 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "send_while_body_dtoh_0"} // CHECK: [[BODY_RECV_TUPLE:%.*]]:2 = "mhlo.recv"([[BODY_SEND_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 3 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "recv_while_body_htod_0"} %1 = "tf._XlaHostComputeMlir"(%arg1) {recv_key = "recv_while_body", send_key = "send_while_body", host_mlir_module = ""} : (tensor) -> tensor @@ -727,11 +763,15 @@ func.func @while_cond(%arg0: tensor) -> tensor { ^bb0(%arg1: tensor): // CHECK: [[COND_SEND_TOKEN:%.*]] = "mhlo.send"([[ITER_ARG_VALUE]], [[ITER_ARG_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 2 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "send_while_cond_dtoh_0"} // CHECK: [[COND_RECV_TUPLE:%.*]]:2 = "mhlo.recv"([[COND_SEND_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 3 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "recv_while_cond_htod_0"} %1 = "tf._XlaHostComputeMlir"(%arg1) {recv_key = "recv_while_cond", send_key = "send_while_cond", host_mlir_module = ""} : (tensor) -> tensor @@ -772,11 +812,15 @@ func.func @while_body(%arg0: tensor) -> tensor { ^bb0(%arg1: tensor): // CHECK: [[BODY_SEND_TOKEN:%.*]] = "mhlo.send"([[ITER_ARG_VALUE]], [[ITER_ARG_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 2 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "send_while_body_dtoh_0"} // CHECK: [[BODY_RECV_TUPLE:%.*]]:2 = "mhlo.recv"([[BODY_SEND_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 3 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "recv_while_body_htod_0"} %1 = "tf._XlaHostComputeMlir"(%arg1) {recv_key = "recv_while_body", send_key = "send_while_body", host_mlir_module = ""} : (tensor) -> tensor diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-quant.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-quant.mlir index 44997584147a0f..8e2876f5707deb 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-quant.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-quant.mlir @@ -368,7 +368,7 @@ func.func @uniform_quantized_add(%arg0: tensor<3x2x!tf_type.qint32>) -> tensor<3 // CHECK-DAG: %[[LHS:.*]] = mhlo.convert %arg0 : (tensor<3x2xi32>) -> tensor<3x2x!quant.uniform> // CHECK-DAG: %[[RHS:.*]] = mhlo.constant() {value = dense<127> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> - // CHECK: %[[RES:.*]] = chlo.broadcast_add %[[LHS]], %[[RHS]] {broadcast_dimensions = dense<1> : tensor<1xi64>} : + // CHECK: %[[RES:.*]] = chlo.broadcast_add %[[LHS]], %[[RHS]] {broadcast_dimensions = array} : // CHECK-SAME: (tensor<3x2x!quant.uniform>, tensor<2x!quant.uniform>) // CHECK-SAME: -> tensor<3x2x!quant.uniform> // CHECK: %[[RES_INT:.*]] = mhlo.convert %[[RES]] : (tensor<3x2x!quant.uniform>) -> tensor<3x2xi32> @@ -418,10 +418,10 @@ func.func @uniform_quantized_clip_by_value(%input: tensor<3x2xf32>) -> tensor<3x // CHECK-DAG: %[[CONVERT_1:.*]] = mhlo.convert %[[OPERAND]] : (tensor<3x2x!quant.uniform>) -> tensor<3x2xi32> // CHECK-DAG: %[[CONVERT_2:.*]] = mhlo.convert %[[CONVERT_1]] : (tensor<3x2xi32>) -> tensor<3x2x!quant.uniform> - // CHECK: %[[MIN_CLIPPED:.*]] = chlo.broadcast_maximum %[[CONVERT_2]], %[[MIN_MAX]] {broadcast_dimensions = dense<1> : tensor<1xi64>} : + // CHECK: %[[MIN_CLIPPED:.*]] = chlo.broadcast_maximum %[[CONVERT_2]], %[[MIN_MAX]] {broadcast_dimensions = array} : // CHECK-SAME: (tensor<3x2x!quant.uniform>, tensor<2x!quant.uniform>) // CHECK-SAME: -> tensor<3x2x!quant.uniform> - // CHECK: chlo.broadcast_minimum %[[MIN_CLIPPED]], %[[MIN_MAX]] {broadcast_dimensions = dense<1> : tensor<1xi64>} : + // CHECK: chlo.broadcast_minimum %[[MIN_CLIPPED]], %[[MIN_MAX]] {broadcast_dimensions = array} : // CHECK-SAME: (tensor<3x2x!quant.uniform>, tensor<2x!quant.uniform>) // CHECK-SAME: -> tensor<3x2x!quant.uniform> %1 = "tf.UniformQuantizedClipByValue"(%0, %min, %max, %scales, %zps) { @@ -557,4 +557,4 @@ func.func @while_region_with_quant_two_args(%arg0: tensor<2x2xf32>, %arg1: tenso // return %[[RESULT0]], %[[RESULT1]] func.return %3, %4 : tensor<2x?xf32>, tensor -} \ No newline at end of file +} diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir index b8552d1b6bdd10..aabc9d471f8385 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir @@ -562,7 +562,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // CHECK-NEXT: %[[cmul:.*]] = mhlo.convert %[[mul]] : tensor<8x8x8x8xf32> // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK-NEXT: %[[convert_init:.*]] = mhlo.convert %[[init]] : tensor - // CHECK: %[[red1:.*]] = mhlo.reduce(%[[cmul]] init: %[[convert_init]]) across dimensions = [0, 1, 2] : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> + // CHECK: %[[red1:.*]] = mhlo.reduce(%[[cmul]] init: %[[convert_init]]) applies mhlo.add across dimensions = [0, 1, 2] : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> // CHECK: %[[scr2:.*]] = mhlo.convert %[[red1]] : tensor<8xf32> // CHECK: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32> @@ -575,7 +575,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // CHECK: %[[cgrad:.*]] = mhlo.convert %[[grad]] : tensor<8x8x8x8xf32> // CHECK: %[[init2:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK-NEXT: %[[convert_init2:.*]] = mhlo.convert %[[init2]] : tensor - // CHECK: %[[red2:.*]] = mhlo.reduce(%[[cgrad]] init: %[[convert_init2]]) across dimensions = [0, 1, 2] : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> + // CHECK: %[[red2:.*]] = mhlo.reduce(%[[cgrad]] init: %[[convert_init2]]) applies mhlo.add across dimensions = [0, 1, 2] : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> // CHECK: %[[offset_backprop:.*]] = mhlo.convert %[[red2]] : tensor<8xf32> // CHECK: %[[x_backprop:.*]] = mhlo.convert %[[mul3]] : tensor<8x8x8x8xf32> diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir index ba0f70e369f3bf..018046d66d57a2 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir @@ -166,7 +166,7 @@ func.func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tens // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32> // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor - // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = array} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = mhlo.rsqrt %[[add]] : tensor<8xf32> // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> @@ -218,7 +218,7 @@ func.func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: te // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32> // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor - // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = array} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = mhlo.rsqrt %[[add]] : tensor<8xf32> // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> @@ -299,7 +299,7 @@ func.func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: te // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32> // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor - // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = array} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = mhlo.rsqrt %[[add]] : tensor<8xf32> // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> @@ -381,7 +381,7 @@ func.func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32> // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor - // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = array} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = mhlo.rsqrt %[[add]] : tensor<8xf32> // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> @@ -739,7 +739,7 @@ func.func @matrix_diag_part_align_7d(%arg0: tensor<3x5x7x9x11x13x17xf32>) -> ten // CHECK-LABEL: func @erf func.func @erf(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { - // CHECK: chlo.erf %arg0 : tensor<2x3xf32> + // CHECK: mhlo.erf %arg0 : tensor<2x3xf32> %0 = "tf.Erf"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> func.return %0 : tensor<2x3xf32> } @@ -787,14 +787,14 @@ func.func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> { // CHECK-LABEL: func @floordiv_broadcast_i32 func.func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { - // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]], %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = array} + // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]], %arg1 {broadcast_dimensions = array} // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {comparison_direction = #chlo} // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0> // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = #chlo} // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0> // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = #chlo} - // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} + // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = array, comparison_direction = #chlo} // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1> // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]] @@ -808,14 +808,14 @@ func.func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) // CHECK-LABEL: func @floordiv_reverse_broadcast_i32 func.func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { - // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = array} // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]] - // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} + // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {broadcast_dimensions = array, comparison_direction = #chlo} // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0> // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = #chlo} // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0> // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = #chlo} - // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} + // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = array, comparison_direction = #chlo} // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1> // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]] @@ -865,14 +865,14 @@ func.func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) // CHECK-LABEL: func @floordiv_dynamic func.func @floordiv_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]], %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = array} + // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]], %arg1 {broadcast_dimensions = array} // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {comparison_direction = #chlo} // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0> // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = #chlo} // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0> // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = #chlo} - // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} + // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = array, comparison_direction = #chlo} // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1> // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]] @@ -886,7 +886,7 @@ func.func @floordiv_dynamic(%arg0: tensor, %arg1: tensor) -> ten // CHECK-LABEL: func @floordiv_unsigned func.func @floordiv_unsigned(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = array} // CHECK: return [[DIV]] %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor, tensor) -> tensor func.return %0: tensor @@ -926,12 +926,12 @@ func.func @floordiv_int(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*x // CHECK-LABEL: func @floormod_broadcast_numerator func.func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { - // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = array} // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0> // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = #chlo} // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0> // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = #chlo} - // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = #chlo} + // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = array, comparison_direction = #chlo} // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = #chlo} // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] @@ -945,15 +945,15 @@ func.func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3x // CHECK-LABEL: func @floormod_broadcast_denominator func.func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { - // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = array} // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0> // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = #chlo} // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0> // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = #chlo} - // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = #chlo} - // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} + // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = array, comparison_direction = #chlo} + // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = array, comparison_direction = #chlo} // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] - // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = array} // CHECK-DAG: [[SELECT:%.+]] = mhlo.select [[AND]], [[ADD]], [[REM]] // CHECK-NEXT: return [[SELECT]] %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> @@ -964,7 +964,7 @@ func.func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor< // CHECK-LABEL: func @floormod_unsigned_broadcast_denominator func.func @floormod_unsigned_broadcast_denominator(%arg0: tensor<2x3xui32>, %arg1: tensor<3xui32>) -> tensor<2x3xui32> { - // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = array} // CHECK-NEXT: return [[REM]] %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<2x3xui32>, tensor<3xui32>) -> tensor<2x3xui32> func.return %0: tensor<2x3xui32> @@ -974,15 +974,15 @@ func.func @floormod_unsigned_broadcast_denominator(%arg0: tensor<2x3xui32>, %arg // CHECK-LABEL: func @floormod_dynamic_broadcast_numerator func.func @floormod_dynamic_broadcast_numerator_(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = array} // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0> // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = #chlo} // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0> // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = #chlo} - // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = #chlo} - // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} + // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = array, comparison_direction = #chlo} + // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = array, comparison_direction = #chlo} // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] - // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = array} // CHECK-DAG: [[SELECT:%.+]] = mhlo.select [[AND]], [[ADD]], [[REM]] // CHECK-NEXT: return [[SELECT]] %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor, tensor) -> tensor @@ -994,12 +994,12 @@ func.func @floormod_dynamic_broadcast_numerator_(%arg0: tensor, %arg1: // CHECK-LABEL: func @floormod_dynamic_broadcast_denominator func.func @floormod_dynamic_broadcast_denominator_(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK-NOT: tf.FloorMod - // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor, tensor) -> tensor + // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = array} : (tensor, tensor) -> tensor // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = #chlo} : (tensor, tensor) -> tensor // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = #chlo} : (tensor, tensor) -> tensor - // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = #chlo} : (tensor, tensor) -> tensor + // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = array, comparison_direction = #chlo} : (tensor, tensor) -> tensor // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = #chlo} : (tensor, tensor) -> tensor // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] : (tensor, tensor) -> tensor // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] : (tensor, tensor) -> tensor @@ -1839,8 +1839,8 @@ func.func @elu_unranked(%arg0: tensor) -> tensor { func.func @elu_grad(%gradients: tensor<4x8xf32>, %features: tensor) -> tensor<4x8xf32> { // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor - // CHECK-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FEATURES]], %[[ZERO]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = #chlo} - // CHECK-DAG: %[[ADD1:.*]] = chlo.broadcast_add %[[FEATURES]], %[[ONE]] {broadcast_dimensions = dense<> : tensor<0xi64>} + // CHECK-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FEATURES]], %[[ZERO]] {broadcast_dimensions = array, comparison_direction = #chlo} + // CHECK-DAG: %[[ADD1:.*]] = chlo.broadcast_add %[[FEATURES]], %[[ONE]] {broadcast_dimensions = array} // CHECK-DAG: %[[MULGRAD:.*]] = mhlo.multiply %[[GRADIENTS]], %[[ADD1]] : (tensor<4x8xf32>, tensor) -> tensor<4x8xf32> // CHECK: %[[RESULT:.*]] = mhlo.select %[[PRED]], %[[GRADIENTS]], %[[MULGRAD]] // CHECK: return %[[RESULT]] @@ -1857,7 +1857,7 @@ func.func @elu_grad(%gradients: tensor<4x8xf32>, %features: tensor) -> // CHECK-LABEL: func @relu func.func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor - // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor<1xi32> + // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = array} : (tensor, tensor<1xi32>) -> tensor<1xi32> %0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> func.return %0: tensor<1xi32> } @@ -1867,7 +1867,7 @@ func.func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK-LABEL: func @relu_unranked func.func @relu_unranked(%arg0: tensor) -> tensor { // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor - // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor) -> tensor + // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = array} : (tensor, tensor) -> tensor %0 = "tf.Relu"(%arg0) : (tensor) -> tensor func.return %0: tensor } @@ -1877,7 +1877,7 @@ func.func @relu_unranked(%arg0: tensor) -> tensor { // CHECK-LABEL: func @relu_unsigned func.func @relu_unsigned(%arg0: tensor) -> tensor { // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor - // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor) -> tensor + // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = array} : (tensor, tensor) -> tensor %0 = "tf.Relu"(%arg0) : (tensor) -> tensor func.return %0: tensor } @@ -2017,7 +2017,7 @@ func.func @softsign_grad(%arg0: tensor<4x10xf32>, %arg1: tensor<4x10xf32>) -> te // CHECK-NEXT: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor // CHECK-NEXT: %[[ABS:.*]] = mhlo.abs %{{.*}} : tensor<4x10xf32> - // CHECK-NEXT: %[[BROADCAST_ADD:.*]] = chlo.broadcast_add %[[ONE]], %[[ABS]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<4x10xf32>) -> tensor<4x10xf32> + // CHECK-NEXT: %[[BROADCAST_ADD:.*]] = chlo.broadcast_add %[[ONE]], %[[ABS]] {broadcast_dimensions = array} : (tensor, tensor<4x10xf32>) -> tensor<4x10xf32> // CHECK-NEXT: %[[MUL:.*]] = mhlo.multiply %[[BROADCAST_ADD]], %[[BROADCAST_ADD]] : tensor<4x10xf32> // CHECK-NEXT: %[[BROADCAST_DIV:.*]] = chlo.broadcast_divide %{{.*}}, %[[MUL]] : (tensor<4x10xf32>, tensor<4x10xf32>) -> tensor<4x10xf32> // CHECK-NEXT: return %[[BROADCAST_DIV]] : tensor<4x10xf32> @@ -2775,7 +2775,7 @@ func.func @sigmoid_grad_complex(%arg0: tensor<2xcomplex>, %arg1: tensor<2xc // CHECK-LABEL: @sigmoid_grad_dynamic func.func @sigmoid_grad_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: chlo.broadcast_multiply {{.*}} : (tensor, tensor) -> tensor - // CHECK: chlo.broadcast_subtract {{.*}} {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor) -> tensor + // CHECK: chlo.broadcast_subtract {{.*}} {broadcast_dimensions = array} : (tensor, tensor) -> tensor // CHECK: chlo.broadcast_multiply {{.*}} : (tensor, tensor) -> tensor %0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor, tensor) -> tensor func.return %0 : tensor @@ -3662,7 +3662,7 @@ func.func @mean(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : (tensor<4x8xf16>) -> tensor<4x8xf32> // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<-0.000000e+00> : tensor // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.add across dimensions = [1] : (tensor<4x8xf32>, tensor) -> tensor<4xf32> - // CHECK: %[[MEAN:.*]] = chlo.broadcast_divide %[[REDUCED]], %{{.*}} {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: %[[MEAN:.*]] = chlo.broadcast_divide %[[REDUCED]], %{{.*}} {broadcast_dimensions = array} : (tensor<4xf32>, tensor) -> tensor<4xf32> // CHECK: %[[CAST_BACK:.*]] = mhlo.convert %[[MEAN]] : (tensor<4xf32>) -> tensor<4xf16> // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16> // CHECK: return %[[RESULT]] : tensor<4x1xf16> @@ -3698,7 +3698,7 @@ func.func @mean_dynamic(%arg0: tensor) -> tensor { // CHECK: %[[INDEX_CAST:.*]] = arith.index_cast %[[MUL]] : index to i64 // CHECK: %[[TENSOR:.*]] = tensor.from_elements %[[INDEX_CAST]] : tensor // CHECK: %[[CONVERT:.*]] = mhlo.convert %[[TENSOR]] : (tensor) -> tensor - // CHECK: %[[MEAN:.*]] = chlo.broadcast_divide %[[REDUCED]], %[[CONVERT]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor) -> tensor + // CHECK: %[[MEAN:.*]] = chlo.broadcast_divide %[[REDUCED]], %[[CONVERT]] {broadcast_dimensions = array} : (tensor, tensor) -> tensor // CHECK: %[[MEAN_CONVERTED:.*]] = mhlo.convert %[[MEAN]] : (tensor) -> tensor // CHECK: %[[SHAPE1:.*]] = shape.shape_of %[[MEAN_CONVERTED]] : tensor -> tensor<1xindex> // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index @@ -4120,8 +4120,8 @@ func.func @rng_std_normal(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { func.func @range(%arg0: tensor, %arg1: tensor) -> tensor<5xf32> { %1 = "tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "range/limit", value = dense<5.000000e+00> : tensor} : () -> tensor // CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota" - // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[DELTA]] {broadcast_dimensions = dense<> : tensor<0xi64>} - // CHECK: chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<> : tensor<0xi64>} + // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[DELTA]] {broadcast_dimensions = array} + // CHECK: chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = array} %3 = "tf.Range"(%arg0, %1, %arg1) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor, tensor, tensor) -> tensor<5xf32> func.return %3 : tensor<5xf32> } @@ -4142,8 +4142,8 @@ func.func @range_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor : tensor<0xi64>} - // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT_3]] {broadcast_dimensions = dense<> : tensor<0xi64>} + // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT_4]] {broadcast_dimensions = array} + // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT_3]] {broadcast_dimensions = array} %2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor, tensor, tensor) -> tensor // CHECK: return [[ADD]] @@ -4166,8 +4166,8 @@ func.func @range_int_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tens // CHECK-DAG: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[RESHAPE]]) {iota_dimension = 0 : i64} // CHECK-DAG: [[CONVERT_3:%.+]] = mhlo.convert %arg0 // CHECK-DAG: [[CONVERT_4:%.+]] = mhlo.convert %arg2 - // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT_4]] {broadcast_dimensions = dense<> : tensor<0xi64>} - // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT_3]] {broadcast_dimensions = dense<> : tensor<0xi64>} + // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT_4]] {broadcast_dimensions = array} + // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT_3]] {broadcast_dimensions = array} %2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor, tensor, tensor) -> tensor // CHECK: return [[ADD]] @@ -4186,8 +4186,8 @@ func.func @linspace_static(%arg0: tensor, %arg1: tensor) -> tensor<4xf // CHECK-DAG: [[STEP_NUMERATOR:%.*]] = chlo.broadcast_subtract [[STOP]], [[START]] // CHECK-DAG: [[STEP:%.*]] = chlo.broadcast_divide [[STEP_NUMERATOR]], [[STEP_DENOMINATOR]] // CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} - // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[STEP]] {broadcast_dimensions = dense<> : tensor<0xi64>} - // CHECK-DAG: [[LINSPACE:%.*]] = chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<> : tensor<0xi64>} + // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[STEP]] {broadcast_dimensions = array} + // CHECK-DAG: [[LINSPACE:%.*]] = chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = array} // CHECK: return [[LINSPACE]] %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<4> : tensor} : () -> tensor %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor, tensor, tensor) -> tensor<4xf32> @@ -5334,7 +5334,7 @@ func.func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> { // CHECK-SAME: -> tensor<2x3x5x7xf32> // CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor // CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] -// CHECK-SAME: broadcast_dimensions = dense<> +// CHECK-SAME: broadcast_dimensions = array // CHECK-SAME: -> tensor<2x3x5x7xf32> // CHECK: [[CONV16:%.+]] = mhlo.convert [[DIV_RESULT]] // CHECK-SAME: -> tensor<2x3x5x7xf16> @@ -5360,7 +5360,7 @@ func.func @avgpool_valid_padding(%arg0: tensor<2x12x21x7xf16>) -> tensor<2x3x5x7 // CHECK-SAME: -> tensor<2x4x3x5x7xf32> // CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor // CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] -// CHECK-SAME: broadcast_dimensions = dense<> +// CHECK-SAME: broadcast_dimensions = array // CHECK-SAME: -> tensor<2x4x3x5x7xf32> // CHECK: [[CONV16:%.+]] = mhlo.convert [[DIV_RESULT]] // CHECK-SAME: -> tensor<2x4x3x5x7xf16> @@ -5386,7 +5386,7 @@ func.func @avgpool_3d_valid_padding(%arg0: tensor<2x4x12x21x7xf16>) -> tensor<2x // CHECK-SAME: -> tensor<2x7x3x5xf32> // CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor // CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] -// CHECK-SAME: broadcast_dimensions = dense<> +// CHECK-SAME: broadcast_dimensions = array // CHECK-SAME: -> tensor<2x7x3x5xf32> // CHECK: [[CONV16:%.+]] = mhlo.convert [[DIV_RESULT]] // CHECK-SAME: -> tensor<2x7x3x5xf16> @@ -5412,7 +5412,7 @@ func.func @avgpool_nchw_format(%arg0: tensor<2x7x12x21xf16>) -> tensor<2x7x3x5xf // CHECK-SAME: -> tensor<2x7x4x3x5xf32> // CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor // CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] -// CHECK-SAME: broadcast_dimensions = dense<> +// CHECK-SAME: broadcast_dimensions = array // CHECK-SAME: -> tensor<2x7x4x3x5xf32> // CHECK: [[CONV16:%.+]] = mhlo.convert [[DIV_RESULT]] // CHECK-SAME: -> tensor<2x7x4x3x5xf16> @@ -5497,7 +5497,7 @@ func.func @avgpool_3d_same_padding(%arg0: tensor<2x4x12x21x7xf32>) -> tensor<2x4 // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor // CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] -// CHECK-SAME: broadcast_dimensions = dense<> +// CHECK-SAME: broadcast_dimensions = array // CHECK-SAME: -> tensor<10x12x16x64xf32> // CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) // CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 0]> @@ -5530,7 +5530,7 @@ func.func @avgpool_grad_valid_padding(%grad: tensor<10x12x16x64xf32>) -> tensor< // CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> { // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor -// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<10x8x12x16x64xf32>, tensor) -> tensor<10x8x12x16x64xf32> +// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] {broadcast_dimensions = array} : (tensor<10x8x12x16x64xf32>, tensor) -> tensor<10x8x12x16x64xf32> // CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) // CHECK-SAME: edge_padding_high = dense<[0, 0, 1, 1, 0]> // CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1, 0]> @@ -5724,7 +5724,7 @@ func.func @avgpool_3d_grad_ncdwh_format(%grad: tensor<2x9x8x4x7xf32>) -> tensor< // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor // CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] -// CHECK-SAME: broadcast_dimensions = dense<> +// CHECK-SAME: broadcast_dimensions = array // CHECK-SAME: -> tensor<10x12x16x64xbf16> // CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) // CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 0]> diff --git a/tensorflow/compiler/mlir/tf2xla/tests/verify-tfxla-legalization-no-chlo.mlir b/tensorflow/compiler/mlir/tf2xla/tests/verify-tfxla-legalization-no-chlo.mlir index 09374ca8006a2f..673e6c9ffd329a 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/verify-tfxla-legalization-no-chlo.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/verify-tfxla-legalization-no-chlo.mlir @@ -6,6 +6,6 @@ // CHECK-LABEL: allows_chlo func.func @allows_chlo(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> func.return %0 : tensor<1x32x10x32xi32> } diff --git a/tensorflow/compiler/mlir/tf2xla/tests/verify-tfxla-legalization.mlir b/tensorflow/compiler/mlir/tf2xla/tests/verify-tfxla-legalization.mlir index c91c7c9da8ac77..e6623350380fcb 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/verify-tfxla-legalization.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/verify-tfxla-legalization.mlir @@ -31,7 +31,7 @@ func.func @invalid_mixed_mhlo() -> (tensor<8x64x128xcomplex> {mhlo.sharding func.func @fails_chlo(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { // expected-error @+1 {{Could not legalize op: chlo.broadcast_add}} - %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> func.return %0 : tensor<1x32x10x32xi32> } diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD index 0f0b1182e50bb7..28a459ccff2eac 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD @@ -270,7 +270,9 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/util/quantization:uniform_quant_ops_params", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc index b8818bb9d3824b..4d96bb24acf152 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc @@ -132,7 +132,7 @@ TEST_F(LegalizationOpConfigTest, CountLoweringsSet) { // a new op, we should expect these to change too. EXPECT_EQ(mlir_lowering_count, 67); EXPECT_EQ(tf2xla_fallback_count, 316); - EXPECT_EQ(non_categorized_count, 422); + EXPECT_EQ(non_categorized_count, 423); } // Just a counter test to see which ops have duplicate lowerings. This isn't a @@ -224,7 +224,7 @@ TEST_F(LegalizationOpConfigTest, MlirLoweringWithoutXlaKernel) { } } - EXPECT_EQ(mlir_without_xla_count, 14); + EXPECT_EQ(mlir_without_xla_count, 13); } } // namespace mhlo diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc index 70370bffc41f20..76056458079964 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc @@ -39,6 +39,7 @@ limitations under the License. #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/Dialect/Traits.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project @@ -497,7 +498,7 @@ static void CreateWhile32(Location loc, int num_iterations, // Increment the loop induction variable by one. auto one = builder->create(loc, builder->getI32IntegerAttr(1)); - auto scalar_broadcast_dims = GetI64ElementsAttr({}, builder); + auto scalar_broadcast_dims = builder->getDenseI64ArrayAttr({}); auto plus_one = builder->create( loc, block->getArgument(0), one, scalar_broadcast_dims); // Prepend with the updated loop induction variable. @@ -2172,7 +2173,7 @@ class ConvertFusedBatchNormGradBase non_feature_dims.push_back(i); } auto reduce_dims = GetI64ElementsAttr(non_feature_dims, &rewriter); - auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter); + auto scalar_broadcast_dims = rewriter.getDenseI64ArrayAttr({}); // scratch1 = rsqrt(var + epsilon) RankedTensorType scalar_float = @@ -2315,7 +2316,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { Value corrected_variance = rewriter.create( op.getLoc(), batch_variance.getType(), batch_variance, - factor_const_op, /*broadcast_dimensions=*/DenseIntElementsAttr()); + factor_const_op, /*broadcast_dimensions=*/DenseI64ArrayAttr()); // Convert back to input type to stay aligned with expected output type // for TF op. @@ -2335,24 +2336,24 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { // new_running_mean = alpha * old_mean + beta * batch_mean. auto alpha_mul_old_mean = rewriter.create( op.getLoc(), op.getMean().getType(), alpha, op.getMean(), - /*broadcast_dimensions=*/DenseIntElementsAttr()); + /*broadcast_dimensions=*/DenseI64ArrayAttr()); auto beta_mul_batch_mean = rewriter.create( op.getLoc(), batch_mean.getType(), beta, batch_mean, - /*broadcast_dimensions=*/DenseIntElementsAttr()); + /*broadcast_dimensions=*/DenseI64ArrayAttr()); batch_mean = rewriter.create( op.getLoc(), alpha_mul_old_mean, beta_mul_batch_mean, - /*broadcast_dimensions=*/DenseIntElementsAttr()); + /*broadcast_dimensions=*/DenseI64ArrayAttr()); // new_running_variance = alpha * old_variance + beta * batch_variance. auto alpha_mul_old_variance = rewriter.create( op.getLoc(), op.getVariance().getType(), alpha, op.getVariance(), - /*broadcast_dimensions=*/DenseIntElementsAttr()); + /*broadcast_dimensions=*/DenseI64ArrayAttr()); auto beta_mul_batch_variance = rewriter.create( op.getLoc(), corrected_variance.getType(), beta, corrected_variance, - /*broadcast_dimensions=*/DenseIntElementsAttr()); + /*broadcast_dimensions=*/DenseI64ArrayAttr()); corrected_variance = rewriter.create( op.getLoc(), alpha_mul_old_variance, beta_mul_batch_variance, - /*broadcast_dimensions=*/DenseIntElementsAttr()); + /*broadcast_dimensions=*/DenseI64ArrayAttr()); } if (std::is_same::value) { @@ -2522,7 +2523,7 @@ Operation *AvgPoolDivideByCount( // Divide `pooled` by window counts. Value divisor = GetScalarConstOfType(element_type, loc, window_count, &rewriter); - auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter); + auto scalar_broadcast_dims = rewriter.getDenseI64ArrayAttr({}); result = rewriter.create( loc, pooled_type, pooled, divisor, scalar_broadcast_dims); } else { @@ -4091,7 +4092,7 @@ class GenericConvertReductionOp : public OpRewritePattern { Value divisor = rewriter.create( loc, tensorflow::GetTypeFromTFTensorShape({}, reduce_element_type), divisor_tensor); - auto broadcast_dims = GetI64ElementsAttr({}, &rewriter); + auto broadcast_dims = rewriter.getDenseI64ArrayAttr({}); result = rewriter.create(loc, result, divisor, broadcast_dims); } @@ -6103,7 +6104,7 @@ class ConvertXlaReduceScatterOp if (replica_group_size == 0) return failure(); auto divisor = GetScalarConstOfType(element_type, loc, replica_group_size, &rewriter); - auto broadcast_dims = GetI64ElementsAttr({}, &rewriter); + auto broadcast_dims = rewriter.getDenseI64ArrayAttr({}); result = rewriter.create( loc, result, divisor.getResult(), broadcast_dims); } diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_collective.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_collective.cc index d7937cce42ff24..54bd5812644488 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_collective.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_collective.cc @@ -176,7 +176,7 @@ LogicalResult ConvertAllReduce(OpBuilder& builder, int64_t channel_id, } auto divisor = GetScalarConstOfType(element_type, loc, replica_group_size, &builder); - auto broadcast_dims = GetI64ElementsAttr({}, &builder); + auto broadcast_dims = builder.getDenseI64ArrayAttr({}); result = builder.create( loc, all_reduce.getResult(0), divisor.getResult(), broadcast_dims); } else if (final_op != "Id") { diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc index 763e94734f6d01..3e8dd5b58ed2f1 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc @@ -16,6 +16,9 @@ limitations under the License. // This file implements logic for lowering TensorFlow dialect's communication // ops (TF/XLA) to the HLO dialect. +#include +#include +#include #include #include #include @@ -31,6 +34,7 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project @@ -69,6 +73,28 @@ class LegalizeTFCommunication void runOnOperation() override; }; +// A generator to serve out unique channel ids. +class ChannelIdGenerator { + public: + ChannelIdGenerator() = default; + ChannelIdGenerator(const ChannelIdGenerator&) = delete; + ChannelIdGenerator& operator=(const ChannelIdGenerator&) = delete; + ChannelIdGenerator(ChannelIdGenerator&&) = delete; + ChannelIdGenerator& operator=(ChannelIdGenerator&&) = delete; + int64_t operator++(int) { return next(); } + int64_t next() { return channel_id_.fetch_add(1, std::memory_order_relaxed); } + + private: + // All usage code expects positive int64_t values so we can't use uint64_t + // and will just have to limit ourselves to half the number space. + std::atomic channel_id_ = 1; +}; + +int64_t GetNextChannelId() { + static ChannelIdGenerator* channel_id = new ChannelIdGenerator(); + return channel_id->next(); +} + // Checks if an op is a TF/XLA communication op. bool IsCommunicationOp(Operation* op) { return isa( loc, token.getType(), operand, token, channel_handle, @@ -273,12 +299,12 @@ Value CreateSendOp(OpBuilder& builder, int64_t& channel_id, Location loc, } // Creates a `mhlo.recv` op for receiving a value. -Value CreateRecvOp(OpBuilder& builder, int64_t& channel_id, Location loc, - Value result, StringRef key, size_t index, Value token, +Value CreateRecvOp(OpBuilder& builder, Location loc, Value result, + StringRef key, size_t index, Value token, StringRef host_handler_name, bool manual_sharding) { // type 3 == HOST_TO_DEVICE auto channel_handle = ChannelHandleAttr::get(builder.getContext(), - /*handle=*/channel_id++, + /*handle=*/GetNextChannelId(), /*type=*/3); auto result_type = result.getType(); SmallVector recv_result_type = {result_type, token.getType()}; @@ -315,7 +341,7 @@ Value CreateSinkToken(OpBuilder& builder, Location loc, ArrayRef tokens, // ops per operand and result. Unique Channel IDs are assigned per transfer. // Sink tokens are created across all `mhlo.send` ops first and then by // all `mhlo.recv` ops. -Value RewriteHostComputeOp(OpBuilder& builder, int64_t& channel_id, +Value RewriteHostComputeOp(OpBuilder& builder, TF::_XlaHostComputeMlirOp host_compute, Value token) { builder.setInsertionPoint(host_compute); @@ -325,7 +351,7 @@ Value RewriteHostComputeOp(OpBuilder& builder, int64_t& channel_id, SmallVector send_tokens; for (auto operand : llvm::enumerate(host_compute.getInputs())) { auto send_token = CreateSendOp( - builder, channel_id, loc, operand.value(), host_compute.getSendKey(), + builder, loc, operand.value(), host_compute.getSendKey(), operand.index(), token, xla::kXlaHostTransferTfRendezvousHandlerName, manual_sharding); send_tokens.push_back(send_token); @@ -335,9 +361,8 @@ Value RewriteHostComputeOp(OpBuilder& builder, int64_t& channel_id, SmallVector recv_tokens; for (auto result : llvm::enumerate(host_compute.getOutputs())) { auto recv_token = CreateRecvOp( - builder, channel_id, loc, result.value(), host_compute.getRecvKey(), - result.index(), token, xla::kXlaHostTransferTfRendezvousHandlerName, - manual_sharding); + builder, loc, result.value(), host_compute.getRecvKey(), result.index(), + token, xla::kXlaHostTransferTfRendezvousHandlerName, manual_sharding); recv_tokens.push_back(recv_token); } token = CreateSinkToken(builder, loc, recv_tokens, token); @@ -347,11 +372,11 @@ Value RewriteHostComputeOp(OpBuilder& builder, int64_t& channel_id, } // Replaces `tf.XlaSendToHost` with a `mhlo.send`. -Value RewriteSendToHostOp(OpBuilder& builder, int64_t& channel_id, - TF::XlaSendToHostOp send_to_host, Value token) { +Value RewriteSendToHostOp(OpBuilder& builder, TF::XlaSendToHostOp send_to_host, + Value token) { builder.setInsertionPoint(send_to_host); - token = CreateSendOp(builder, channel_id, send_to_host.getLoc(), - send_to_host.getInput(), send_to_host.getKey(), + token = CreateSendOp(builder, send_to_host.getLoc(), send_to_host.getInput(), + send_to_host.getKey(), /*index=*/0, token, xla::kXlaHostTransferTfRendezvousHandlerName, /*manual_sharding=*/false); @@ -361,10 +386,10 @@ Value RewriteSendToHostOp(OpBuilder& builder, int64_t& channel_id, } // Replaces `tf.XlaRecvFromHost` with a `mhlo.recv`. -Value RewriteRecvFromHostOp(OpBuilder& builder, int64_t& channel_id, +Value RewriteRecvFromHostOp(OpBuilder& builder, TF::XlaRecvFromHostOp recv_from_host, Value token) { builder.setInsertionPoint(recv_from_host); - token = CreateRecvOp(builder, channel_id, recv_from_host.getLoc(), + token = CreateRecvOp(builder, recv_from_host.getLoc(), recv_from_host.getOutput(), recv_from_host.getKey(), /*index=*/0, token, xla::kXlaHostTransferTfRendezvousHandlerName, @@ -795,7 +820,7 @@ void RewriteFunctionTerminator(OpBuilder& builder, // rewritten to create a token or take in and return a token, depending on its // visibility and if there are any callers. LogicalResult RewriteFunction( - OpBuilder& builder, int64_t& channel_id, ModuleOp module, FuncOp func, + OpBuilder& builder, ModuleOp module, FuncOp func, const llvm::SmallDenseMap& funcs, const llvm::SmallPtrSetImpl& control_flow_ops, const llvm::SmallPtrSetImpl& control_flow_blocks, bool is_clone) { @@ -832,11 +857,11 @@ LogicalResult RewriteFunction( Operation* next_op = curr_op->getNextNode(); if (auto host_compute = dyn_cast(curr_op)) { - token = RewriteHostComputeOp(builder, channel_id, host_compute, token); + token = RewriteHostComputeOp(builder, host_compute, token); } else if (auto send_to_host = dyn_cast(curr_op)) { - token = RewriteSendToHostOp(builder, channel_id, send_to_host, token); + token = RewriteSendToHostOp(builder, send_to_host, token); } else if (auto recv_from_host = dyn_cast(curr_op)) { - token = RewriteRecvFromHostOp(builder, channel_id, recv_from_host, token); + token = RewriteRecvFromHostOp(builder, recv_from_host, token); } else if (auto call = dyn_cast(curr_op)) { // Only `mlir::func::CallOp` is supported as this requires knowing how to // rewrite arguments and results to a function. @@ -929,14 +954,11 @@ void LegalizeTFCommunication::runOnOperation() { if (failed(GetFunctionsToRewrite(module, funcs_to_rewrite))) return signalPassFailure(); - // Module level counter to make sure Channel IDs are unique. - int64_t channel_id = 1; OpBuilder builder(&getContext()); for (const auto& func_and_name : funcs_to_rewrite) { const auto& func_to_rewrite = func_and_name.getSecond(); func::FuncOp func = func_to_rewrite.original; - if (failed(RewriteFunction(builder, channel_id, module, func, - funcs_to_rewrite, + if (failed(RewriteFunction(builder, module, func, funcs_to_rewrite, func_to_rewrite.control_flow_ops, func_to_rewrite.control_flow_blocks, /*is_clone=*/false))) @@ -949,8 +971,8 @@ void LegalizeTFCommunication::runOnOperation() { GetCommunicationControlFlowOps(clone, funcs_to_rewrite, clone_control_flow_ops, clone_control_flow_blocks); - if (failed(RewriteFunction(builder, channel_id, module, clone, - funcs_to_rewrite, clone_control_flow_ops, + if (failed(RewriteFunction(builder, module, clone, funcs_to_rewrite, + clone_control_flow_ops, clone_control_flow_blocks, /*is_clone=*/true))) llvm_unreachable( diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td index 0ee5d1dee5925d..108b1bf6e6bc86 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td @@ -75,7 +75,7 @@ def : Pat<(TF_ApproximateEqualOp:$result $x, $y, $tolerance), (CHLO_BroadcastCompareOp (MHLO_AbsOp:$abs (MHLO_SubtractOp $x, $y)), (CastValueToElementType $result, (MHLO_ConstantOp $tolerance), $abs), - (NullDenseIntElementsAttr), + (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"LT">, (CHLO_DEFAULT_COMPARISON_TYPE))>; @@ -158,18 +158,18 @@ def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r), (CHLO_BroadcastCompareOp (CHLO_BroadcastCompareOp:$l_cmp $l, (MHLO_ConstantOp:$l_zeros (GetScalarOfType<0> $l)), - (NullDenseIntElementsAttr), CHLO_ComparisonDirectionValue<"LT">, + (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"LT">, (CHLO_DEFAULT_COMPARISON_TYPE)), (CHLO_BroadcastCompareOp:$r_cmp $r, (MHLO_ConstantOp:$r_zeros (GetScalarOfType<0> $r)), - (NullDenseIntElementsAttr), CHLO_ComparisonDirectionValue<"LT">, + (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"LT">, (CHLO_DEFAULT_COMPARISON_TYPE)), (BinBroadcastDimensions $l_cmp, $r_cmp), CHLO_ComparisonDirectionValue<"NE">, (CHLO_DEFAULT_COMPARISON_TYPE)), - (NullDenseIntElementsAttr)), + (NullDenseI64ArrayAttr)), (CHLO_BroadcastSubOp $div, (MHLO_ConstantOp:$ones (GetScalarOfType<1> $div)), - (NullDenseIntElementsAttr)), $div), + (NullDenseI64ArrayAttr)), $div), [(SignedIntTensor $l)]>; // FloorDiv of unsigned is just div. @@ -189,19 +189,19 @@ def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r), (CHLO_BroadcastCompareOp (CHLO_BroadcastRemOp:$rem $l, $r, (BinBroadcastDimensions $l, $r)), (MHLO_ConstantOp:$l_zeros (GetScalarOfType<0> $l)), - (NullDenseIntElementsAttr), CHLO_ComparisonDirectionValue<"NE">, + (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"NE">, (CHLO_DEFAULT_COMPARISON_TYPE)), (CHLO_BroadcastCompareOp (CHLO_BroadcastCompareOp:$r_cmp $r, (MHLO_ConstantOp:$r_zeros (GetScalarOfType<0> $r)), - (NullDenseIntElementsAttr), CHLO_ComparisonDirectionValue<"LT">, + (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"LT">, (CHLO_DEFAULT_COMPARISON_TYPE)), (CHLO_BroadcastCompareOp:$rem_cmp $rem, $r_zeros, (BinBroadcastDimensions $rem, $r_zeros), CHLO_ComparisonDirectionValue<"LT">, (CHLO_DEFAULT_COMPARISON_TYPE)), (BinBroadcastDimensions $r_cmp, $rem_cmp), CHLO_ComparisonDirectionValue<"NE">, (CHLO_DEFAULT_COMPARISON_TYPE)), - (NullDenseIntElementsAttr)), + (NullDenseI64ArrayAttr)), (CHLO_BroadcastAddOp $r, $rem, (BinBroadcastDimensions $r, $rem)), $rem), [(TensorOf<[I8, I16, I32, I64, F16, F32, F64]> $l)]>; @@ -580,7 +580,7 @@ foreach Mapping = [ [TF_DigammaOp, CHLO_DigammaOp], [TF_ExpOp, MHLO_ExpOp], [TF_Expm1Op, MHLO_Expm1Op], - [TF_ErfOp, CHLO_ErfOp], + [TF_ErfOp, MHLO_ErfOp], [TF_ErfcOp, CHLO_ErfcOp], [TF_FloorOp, MHLO_FloorOp], [TF_ImagOp, MHLO_ImagOp], @@ -694,13 +694,13 @@ def : Pattern<(TF_SoftplusOp AnyTensor:$features), (CHLO_BroadcastAddOp:$threshold (MHLO_LogOp (MHLO_ConstantOp (EpsilonValue $features))), (MHLO_ConstantOp (GetScalarOfType<2> $features)), - (NullDenseIntElementsAttr) + (NullDenseI64ArrayAttr) ), (MHLO_SelectOp:$output (CHLO_BroadcastCompareOp $features, (MHLO_NegOp $threshold), - (NullDenseIntElementsAttr), + (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"GT">, (CHLO_DEFAULT_COMPARISON_TYPE) ), @@ -709,7 +709,7 @@ def : Pattern<(TF_SoftplusOp AnyTensor:$features), (CHLO_BroadcastCompareOp $features, $threshold, - (NullDenseIntElementsAttr), + (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"LT">, (CHLO_DEFAULT_COMPARISON_TYPE) ), diff --git a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc index 122a9084771d88..2c49198be7bad8 100644 --- a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc @@ -40,7 +40,6 @@ limitations under the License. #include "xla/mlir_hlo/lhlo/transforms/passes.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/service/cpu/hlo_xla_runtime_pipeline.h" -#include "xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h" int main(int argc, char **argv) { tensorflow::InitMlir y(&argc, &argv); diff --git a/tensorflow/compiler/mlir/tfr/BUILD b/tensorflow/compiler/mlir/tfr/BUILD index 10862b22c4f8d6..04cd4282e5c451 100644 --- a/tensorflow/compiler/mlir/tfr/BUILD +++ b/tensorflow/compiler/mlir/tfr/BUILD @@ -344,7 +344,6 @@ tf_python_pybind_extension( ":tfr", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/python/lib/core:pybind11_lib", - "//tensorflow/python/lib/core:pybind11_status", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", @@ -352,6 +351,7 @@ tf_python_pybind_extension( "@llvm-project//mlir:Parser", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:Support", "@pybind11", ], ) diff --git a/tensorflow/compiler/mlir/tfr/python/tfr_wrapper.cc b/tensorflow/compiler/mlir/tfr/python/tfr_wrapper.cc index 760ddab974c7fd..5d572f8278684b 100644 --- a/tensorflow/compiler/mlir/tfr/python/tfr_wrapper.cc +++ b/tensorflow/compiler/mlir/tfr/python/tfr_wrapper.cc @@ -15,24 +15,25 @@ limitations under the License. #include +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SMLoc.h" #include "llvm/Support/SourceMgr.h" -#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/IR/AsmState.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Verifier.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "pybind11/pybind11.h" // from @pybind11 #include "pybind11/stl.h" // from @pybind11 -#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h" #include "tensorflow/python/lib/core/pybind11_lib.h" -#include "tensorflow/python/lib/core/pybind11_status.h" PYBIND11_MODULE(tfr_wrapper, m) { m.def("verify", [](std::string input) { diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD index 745d355867d1c4..315d9fca646d1d 100644 --- a/tensorflow/compiler/mlir/tfrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/BUILD @@ -703,6 +703,7 @@ cc_library( cc_library( name = "backend_compiler", + srcs = ["backend_compiler.cc"], hdrs = ["backend_compiler.h"], deps = [ "//tensorflow/core/tfrt/runtime", diff --git a/third_party/xla/third_party/tsl/tsl/platform/gif.h b/tensorflow/compiler/mlir/tfrt/backend_compiler.cc similarity index 73% rename from third_party/xla/third_party/tsl/tsl/platform/gif.h rename to tensorflow/compiler/mlir/tfrt/backend_compiler.cc index 865b6f201e66fe..7c04c778fda8da 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/gif.h +++ b/tensorflow/compiler/mlir/tfrt/backend_compiler.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,9 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PLATFORM_GIF_H_ -#define TENSORFLOW_TSL_PLATFORM_GIF_H_ +#include "tensorflow/compiler/mlir/tfrt/backend_compiler.h" -#include "gif_lib.h" // from @gif +namespace tensorflow { -#endif // TENSORFLOW_TSL_PLATFORM_GIF_H_ +BackendCompiler::~BackendCompiler() = default; + +} diff --git a/tensorflow/compiler/mlir/tfrt/backend_compiler.h b/tensorflow/compiler/mlir/tfrt/backend_compiler.h index 827dc92bd72f2e..0e959f04f43554 100644 --- a/tensorflow/compiler/mlir/tfrt/backend_compiler.h +++ b/tensorflow/compiler/mlir/tfrt/backend_compiler.h @@ -17,13 +17,16 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_TFRT_BACKEND_COMPILER_H_ #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project #include "tensorflow/core/tfrt/runtime/runtime.h" namespace tensorflow { class BackendCompiler { public: - virtual ~BackendCompiler() = default; + virtual ~BackendCompiler(); + + virtual void GetDependentDialects(mlir::DialectRegistry& registry) const {} // Compile the `module` in TF dialect. The result module should be also in TF // dialect. diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.td b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.td index 87ee9fc91ca57e..7fbc42ad3db93f 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.td +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.td @@ -443,6 +443,8 @@ def IfrtLoadVariableOp: TensorflowMlrt_Op<"ifrt_load_variable", []> { `tf.IfrtLoadVariableOp` converts the tensor into an IFRT array based on device and sharding configuration specified in `VariableDeviceShardingConfigProto`. + + This op returns a scalar string tensor as a key for user to look for the loaded array. }]; let arguments = (ins @@ -450,6 +452,10 @@ def IfrtLoadVariableOp: TensorflowMlrt_Op<"ifrt_load_variable", []> { StrAttr:$device_sharding_config_proto_text, StrAttr:$name ); + + let results = (outs + TFTensorType:$array_key + ); } diff --git a/tensorflow/compiler/mlir/tfrt/tests/ifrt/rewrite_cluster_to_ifrt_call.mlir b/tensorflow/compiler/mlir/tfrt/tests/ifrt/rewrite_cluster_to_ifrt_call.mlir index a23fafec92c028..dbb77732a3d6f6 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/ifrt/rewrite_cluster_to_ifrt_call.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/ifrt/rewrite_cluster_to_ifrt_call.mlir @@ -1,17 +1,17 @@ -// RUN: tf-tfrt-opt -split-input-file -rewrite-cluster-to-ifrt-call=tpu-compile-metadata-debug %s | FileCheck %s +// RUN: tf-tfrt-opt -split-input-file -rewrite-cluster-to-ifrt-call %s | FileCheck %s // TODO(b/316226111): the printer may not guarantee the same order of fields. Rewrite the checks to be less sensitive to proto serialization formats. // ----- // Non-SPMD: one input and one output // // CHECK-LABEL: func.func @serving_default(%arg0: tensor<1x3xf32>) -> tensor<1x3xf32> { // CHECK-NEXT: "tf.IfrtCall"(%arg0) -// CHECK-SAME: {program_id = [[PROGRAM_ID:.*]] : i64, variable_arg_indices = [], variable_names = []} +// CHECK-SAME: {program_id = [[PROGRAM_ID:.*]] : i64, variable_arg_indices = []} // CHECK-SAME: (tensor<1x3xf32>) -> tensor<1x3xf32> // CHECK: return // -// CHECK: func.func private @_ifrt_program__func(%arg0: tensor<1x3xf32>) +// CHECK: func.func @_ifrt_program__func(%arg0: tensor<1x3xf32>) // CHECK-SAME: __tpu_compile_metadata_text = "args { dtype: DT_FLOAT shape { dim { size: 1 } dim { size: 3 } } kind: PARAMETER sharding { } is_bounded_dynamic_dim: false } retvals { sharding { } } num_replicas: 1 num_cores_per_replica: 1 " -// CHECK-SAME: tfrt_ifrt_serving.program_id = [[PROGRAM_ID]] : i64, tpu_compile_metadata = +// CHECK-SAME: tfrt_ifrt_serving.program_id = [[PROGRAM_ID]] : i64 // CHECK: return module attributes {tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1"], tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1704 : i32}} { @@ -33,13 +33,13 @@ func.func private @_func(%arg0: tensor<1x3xf32>) -> (tensor<1x3xf32>) { // // CHECK-LABEL: func.func @serving_default(%arg0: tensor<1x3xf32>) { // CHECK-NEXT: "tf.IfrtCall"(%arg0) -// CHECK-SAME: {program_id = [[PROGRAM_ID:.*]] : i64, variable_arg_indices = [], variable_names = []} +// CHECK-SAME: {program_id = [[PROGRAM_ID:.*]] : i64, variable_arg_indices = []} // CHECK-SAME: (tensor<1x3xf32>) -> () // CHECK: return // -// CHECK: func.func private @_ifrt_program__func(%arg0: tensor<1x3xf32>) +// CHECK: func.func @_ifrt_program__func(%arg0: tensor<1x3xf32>) // CHECK-SAME: __tpu_compile_metadata_text = "args { dtype: DT_FLOAT shape { dim { size: 1 } dim { size: 3 } } kind: PARAMETER sharding { type: OTHER tile_assignment_dimensions: 2 tile_assignment_dimensions: 1 tile_assignment_devices: 0 tile_assignment_devices: 1 } is_bounded_dynamic_dim: false } num_replicas: 1 num_cores_per_replica: 2 device_assignment { replica_count: 1 computation_count: 2 computation_devices { replica_device_ids: 0 } computation_devices { replica_device_ids: 1 } } use_spmd_for_xla_partitioning: true " -// CHECK-SAME: tfrt_ifrt_serving.program_id = [[PROGRAM_ID]] : i64, tpu_compile_metadata = +// CHECK-SAME: tfrt_ifrt_serving.program_id = [[PROGRAM_ID]] : i64 // CHECK: return module attributes {tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1"], tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1704 : i32}} { @@ -60,17 +60,17 @@ func.func private @_func(%arg0: tensor<1x3xf32>) -> () { // CHECK-LABEL: func.func @serving_default(%arg0: tensor<3x1xf32>, %arg1: tensor<1x3xf32>) -> tensor<1x1xf32> { // CHECK-NEXT: %0 = "tf.IfrtCall"(%arg1, %arg0) -// CHECK-SAME: {program_id = [[PROGRAM_ID:.*]] : i64, variable_arg_indices = [], variable_names = []} +// CHECK-SAME: {program_id = [[PROGRAM_ID:.*]] : i64, variable_arg_indices = [] // CHECK-SAME: (tensor<1x3xf32>, tensor<3x1xf32>) -> tensor<1x1xf32> // CHECK-NEXT: %1 = "tf.Identity"(%arg1) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> // CHECK-NEXT: %2 = "tf.IfrtCall"(%1, %arg0) -// CHECK-SAME: {program_id = [[PROGRAM_ID]] : i64, variable_arg_indices = [], variable_names = [] +// CHECK-SAME: {program_id = [[PROGRAM_ID]] : i64, variable_arg_indices = [] // CHECK-SAME: (tensor<1x3xf32>, tensor<3x1xf32>) -> tensor<1x1xf32> // CHECK-NEXT: %3 = "tf.add"(%0, %2) : (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32> // CHECK: return // -// CHECK: func.func private @_ifrt_program__func(%arg0: tensor<1x3xf32>, %arg1: tensor<3x1xf32>) -> tensor<1x1xf32> -// CHECK-SAME: tfrt_ifrt_serving.program_id = [[PROGRAM_ID]] : i64, tpu_compile_metadata = +// CHECK: func.func @_ifrt_program__func(%arg0: tensor<1x3xf32>, %arg1: tensor<3x1xf32>) -> tensor<1x1xf32> +// CHECK-SAME: tfrt_ifrt_serving.program_id = [[PROGRAM_ID]] : i64 // CHECK-NEXT: %0 = "tf.MatMul"(%arg0, %arg1) // CHECK: return @@ -97,12 +97,12 @@ func.func private @_func(%arg0: tensor<1x3xf32>, %arg1: tensor<3x1xf32>) -> (ten // CHECK-LABEL: func.func @serving_default(%arg0: tensor<3x1xf32>, %arg1: tensor<1x3xf32>) -> tensor<1x1xf32> { // CHECK-NEXT: %0 = "tf.IfrtCall"(%arg1, %arg0) -// CHECK-SAME: {program_id = [[PROGRAM_ID:.*]] : i64, variable_arg_indices = [], variable_names = []} +// CHECK-SAME: {program_id = [[PROGRAM_ID:.*]] : i64, variable_arg_indices = [] // CHECK-SAME: (tensor<1x3xf32>, tensor<3x1xf32>) -> tensor<1x1xf32> // CHECK: return // -// CHECK: func.func private @_ifrt_program__func(%arg0: tensor<1x3xf32>, %arg1: tensor<3x1xf32>) -> tensor<1x1xf32> -// CHECK-SAME: tfrt_ifrt_serving.program_id = [[PROGRAM_ID]] : i64, tpu_compile_metadata = +// CHECK: func.func @_ifrt_program__func(%arg0: tensor<1x3xf32>, %arg1: tensor<3x1xf32>) -> tensor<1x1xf32> +// CHECK-SAME: tfrt_ifrt_serving.program_id = [[PROGRAM_ID]] : i64 // CHECK-NEXT: %0 = "tf.MatMul"(%arg0, %arg1) // CHECK: return diff --git a/tensorflow/compiler/mlir/tfrt/tests/ifrt/sink_variable_as_named_array.mlir b/tensorflow/compiler/mlir/tfrt/tests/ifrt/sink_variable_as_named_array.mlir new file mode 100644 index 00000000000000..ba644948c6b06d --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/tests/ifrt/sink_variable_as_named_array.mlir @@ -0,0 +1,44 @@ +// RUN: tf-tfrt-opt -split-input-file -sink-variable-as-named-array %s | FileCheck %s + +// ----- +// Basic test: all variables tensors are for devices and sinked as named ifrt arrays +// +// +// CHECK-LABEL: func.func @serving_default(%arg0: tensor<1x3xf32>) -> tensor<1x1xf32> { +// CHECK-NEXT: [[HANDLE2:%.*]] = "tf.VarHandleOp" +// CHECK-NEXT: [[KEY:%.*]] = "tf.IfrtLoadVariable"([[HANDLE2]]) +// CHECK-SAME: device_sharding_config_proto_text = "sharding { type: OTHER tile_assignment_dimensions: 2 tile_assignment_dimensions: 1 tile_assignment_devices: 0 tile_assignment_devices: 1 } device_ids: 0 device_ids: 1 " +// CHECK-SAME: name = "__y" +// CHECK-NEXT: [[RES:%.*]] = "tf.IfrtCall"([[KEY]], %arg0) <{program_id = 6515870160938153680 : i64, variable_arg_indices = [0 : i32]}> +// CHECK-SAME: : (tensor, tensor<1x3xf32>) -> tensor<1x1xf32> +// CHECK-NEXT: return [[RES]] : tensor<1x1xf32> +// +module { + func.func @serving_default(%arg0: tensor<1x3xf32>) -> tensor<1x1xf32> { + %0 = "tf.VarHandleOp"() <{container = "", shared_name = "y"}> : () -> tensor>> + %2 = "tf.ReadVariableOp"(%0) : (tensor>>) -> tensor<3x1xf32> + %result = "tf.IfrtCall"(%2, %arg0) <{program_id = 6515870160938153680 : i64, variable_arg_indices = []}> { __tpu_compile_metadata_text = "args { dtype: DT_FLOAT shape { dim { size: 3 } dim { size: 1 } } kind: PARAMETER sharding { type: OTHER tile_assignment_dimensions: 2 tile_assignment_dimensions: 1 tile_assignment_devices: 0 tile_assignment_devices: 1 } is_bounded_dynamic_dim: false } args { dtype: DT_FLOAT shape { dim { size: 3 } dim { size: 1 } } kind: PARAMETER sharding { } is_bounded_dynamic_dim: false } retvals { sharding { } } num_replicas: 1 num_cores_per_replica: 2 device_assignment { replica_count: 1 computation_count: 2 computation_devices { replica_device_ids: 0 } computation_devices { replica_device_ids: 1 } } use_spmd_for_xla_partitioning: true "} : (tensor<3x1xf32>, tensor<1x3xf32>) -> (tensor<1x1xf32>) + return %result : tensor<1x1xf32> + } +} + +// ----- +// Variable tensor for host can still be used. +// +// CHECK-LABEL: func.func @serving_default(%arg0: tensor<1x3xf32>) -> tensor<1x1xf32> { +// CHECK: "tf.VarHandleOp" +// CHECK-NEXT: [[VARIABLE:%.*]] = "tf.ReadVariableOp" +// CHECK-NEXT: [[KEY:%.*]] = "tf.IfrtLoadVariable" +// CHECK-NEXT: "tf.MatMul"(%arg0, [[VARIABLE]]) +// CHECK-NEXT: [[RES:%.*]] = "tf.IfrtCall"(%arg0, [[KEY]]) <{program_id = 6515870160938153680 : i64, variable_arg_indices = [1 : i32]}> +// CHECK-NEXT: return [[RES]] : tensor<1x1xf32> +// +module { + func.func @serving_default(%arg0: tensor<1x3xf32>) -> tensor<1x1xf32> { + %0 = "tf.VarHandleOp"() <{container = "", shared_name = "y"}> : () -> tensor>> + %2 = "tf.ReadVariableOp"(%0) : (tensor>>) -> tensor<3x1xf32> + %3 = "tf.MatMul"(%arg0, %2) : (tensor<1x3xf32>, tensor<3x1xf32>) -> tensor<1x1xf32> + %result = "tf.IfrtCall"(%arg0, %2) <{program_id = 6515870160938153680 : i64, variable_arg_indices = []}> { __tpu_compile_metadata_text = "args { dtype: DT_FLOAT shape { dim { size: 1 } dim { size: 3 } } kind: PARAMETER sharding { type: OTHER tile_assignment_dimensions: 2 tile_assignment_dimensions: 1 tile_assignment_devices: 0 tile_assignment_devices: 1 } is_bounded_dynamic_dim: false } args { dtype: DT_FLOAT shape { dim { size: 3 } dim { size: 1 } } kind: PARAMETER sharding { } is_bounded_dynamic_dim: false } retvals { sharding { } } num_replicas: 1 num_cores_per_replica: 2 device_assignment { replica_count: 1 computation_count: 2 computation_devices { replica_device_ids: 0 } computation_devices { replica_device_ids: 1 } } use_spmd_for_xla_partitioning: true "} : (tensor<1x3xf32>, tensor<3x1xf32>) -> (tensor<1x1xf32>) + return %result : tensor<1x1xf32> + } +} diff --git a/tensorflow/compiler/mlir/tfrt/tests/ifrt/tf_restore_merging.mlir b/tensorflow/compiler/mlir/tfrt/tests/ifrt/tf_restore_merging.mlir new file mode 100644 index 00000000000000..cd2ee741ba86bf --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/tests/ifrt/tf_restore_merging.mlir @@ -0,0 +1,21 @@ +// RUN: tf-tfrt-opt %s -tf-restore-merging | FileCheck %s + +// CHECK-LABEL: func @single_restore_group +// CHECK-SAME: (%[[ARG0:.*]]: {{.*}}) +func.func @single_restore_group(%arg0: tensor) -> (tensor<*xf32>, tensor<*xi32>) { + %0 = "tf.Const"() {value = dense<"foo"> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %1 = "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %2 = "tf.RestoreV2"(%arg0, %0, %1) : (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>) -> tensor<*xf32> + + %3 = "tf.Const"() {value = dense<"bar"> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %4 = "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %5 = "tf.RestoreV2"(%arg0, %3, %4) : (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>) -> tensor<*xi32> + + // CHECK: %[[NAMES:.*]] = "tf.Const"() <{value = dense<["foo", "bar"]> : tensor<2x!tf_type.string>}> + // CHECK-NEXT: %[[SHAPES:.*]] = "tf.Const"() <{value = dense<""> : tensor<2x!tf_type.string>}> + // CHECK-NEXT: %[[TENSORS:.*]]:2 = "tf.RestoreV2"(%[[ARG0]], %[[NAMES]], %[[SHAPES]]) + // CHECK-SAME: -> (tensor<*xf32>, tensor<*xi32>) + + // CHECK: return %[[TENSORS]]#0, %[[TENSORS]]#1 + func.return %2, %5 : tensor<*xf32>, tensor<*xi32> +} diff --git a/tensorflow/compiler/mlir/tfrt/tests/ifrt/tf_restore_splitting.mlir b/tensorflow/compiler/mlir/tfrt/tests/ifrt/tf_restore_splitting.mlir new file mode 100644 index 00000000000000..1aafed888aa9b9 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/tests/ifrt/tf_restore_splitting.mlir @@ -0,0 +1,18 @@ +// RUN: tf-tfrt-opt %s -tf-restore-splitting | FileCheck %s + +// CHECK-LABEL: func @single_restore +// CHECK-SAME: (%[[ARG0:.*]]: {{.*}}) +func.func @single_restore(%arg0: tensor) -> (tensor<*xf32>, tensor<*xi32>) { + %0 = "tf.Const"() {value = dense<["foo", "bar"]> : tensor<2x!tf_type.string>} : () -> tensor<2x!tf_type.string> + %1 = "tf.Const"() {value = dense<""> : tensor<2x!tf_type.string>} : () -> tensor<2x!tf_type.string> + %2:2 = "tf.RestoreV2"(%arg0, %0, %1) : (tensor, tensor<2x!tf_type.string>, tensor<2x!tf_type.string>) -> (tensor<*xf32>, tensor<*xi32>) + + // CHECK: %[[FOO_NAME:.*]] = "tf.Const"() <{value = dense<"foo"> : tensor<1x!tf_type.string>}> + // CHECK: %[[FOO:.*]] = "tf.RestoreV2"(%[[ARG0]], %[[FOO_NAME]], {{.*}}) + + // CHECK: %[[BAR_NAME:.*]] = "tf.Const"() <{value = dense<"bar"> : tensor<1x!tf_type.string>}> + // CHECK: %[[BAR:.*]] = "tf.RestoreV2"(%[[ARG0]], %[[BAR_NAME]], {{.*}}) + + // CHECK: return %[[FOO]], %[[BAR]] + func.return %2#0, %2#1 : tensor<*xf32>, tensor<*xi32> +} diff --git a/tensorflow/compiler/mlir/tfrt/tests/mlrt/tf_to_mlrt.mlir b/tensorflow/compiler/mlir/tfrt/tests/mlrt/tf_to_mlrt.mlir index 94a28091c7235f..eb2e0587364d6e 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/mlrt/tf_to_mlrt.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/mlrt/tf_to_mlrt.mlir @@ -458,3 +458,21 @@ func.func @xla_func(%arg0: tensor<1x3xf32>) -> tensor<*xf32> attributes {tf.entr %2 = "tf.XlaLaunch"(%arg0, %1) {__op_key = 3: i32, _noinline = true, _xla_compile_device_type = "GPU", device = "/device:GPU:0", function = @xla_func_0, operandSegmentSizes = array} : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<*xf32> func.return %2 : tensor<*xf32> } + +// ----- + +// Test lowering of IfrtLoadVariableOp + +// CHECK-LABEL: func @ifrt_load_variable_test +func.func @ifrt_load_variable_test() -> () { + // CHECK: [[HANDLE:%.*]] = tf_mlrt.executeop() + // CHECK-SAME: VarHandleOp + %0 = "tf.VarHandleOp"() {__op_key = 1: i32, device = "/device:CPU:0", container = "", shared_name = "variable"} : () -> tensor>> + // CHECK-NEXT: "tf_mlrt.ifrt_load_variable"([[HANDLE]]) + // CHECK-SAME: device_sharding_config_proto_text + // CHECK-SAME: name = "__variable" + %1 = "tf.IfrtLoadVariable"(%0) <{device_sharding_config_proto_text = "sharding { } device_ids: 0 device_ids: 1 ", name = "__variable"}> {__op_key = 2: i32, device = "/device:CPU:0"} : (tensor>>) -> (tensor) + // CHECK-NEXT: return + func.return +} + diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD index c3a2a757f3776b..d120d1fe4f5005 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD @@ -15,6 +15,8 @@ package_group( "//tensorflow/core/tfrt/saved_model/tests/...", ] + if_google([ "//learning/brain/tfrt/cpp_tests/...", + "//learning/pathways/serving/runtime/...", + "//learning/pathways/serving/tests/...", # Allow visibility from the mlir language server. "//learning/brain/mlir/mlir_lsp_server/...", ]), @@ -51,7 +53,10 @@ cc_library( name = "tf_ifrt_passes", srcs = [ "rewrite_cluster_to_ifrt_call.cc", + "sink_variable_as_named_array.cc", "tf_ifrt_passes.cc", + "tf_restore_merging.cc", + "tf_restore_splitting.cc", ], hdrs = [ "tf_ifrt_passes.h", @@ -66,22 +71,31 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/mlir/tensorflow:tpu_rewrite_device_util", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:tpu_metadata_utils", + "//tensorflow/compiler/mlir/tfrt:tf_to_tfrt", "//tensorflow/core:framework", + "//tensorflow/core/platform:protobuf", "//tensorflow/core/platform:random", "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", + "//tensorflow/core/tfrt/ifrt:ifrt_config_proto_cc", "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", "@local_tsl//tsl/platform:protobuf", + "@local_xla//xla:xla_data_proto_cc", + "@local_xla//xla/service:computation_placer_hdr", ], ) @@ -104,6 +118,7 @@ cc_library( "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", "//tensorflow/core/protobuf/tpu:topology_proto_cc", "//tensorflow/core/tpu/kernels:tpu_compile_op_support", + "//tensorflow/core/tpu/kernels/xla:host_compute_ops", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -115,7 +130,6 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:statusor", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", @@ -141,7 +155,6 @@ tf_cc_test( ":tf2hlo", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/tf2xla:xla_helpers", - "//tensorflow/core:framework", "//tensorflow/core:test", "//tensorflow/core/platform:resource_loader", "@com_google_absl//absl/log", diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc index 412e86ef4e39f3..b0815b8f5c0272 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc @@ -92,6 +92,7 @@ CompileAndRegisterIfrtPrograms(absl::string_view model_name, model_name, entry_function_name.str(), *std::move(submodule), ifrt_model_context.GetClient(), &ifrt_model_context.GetThreadPoolDevice(), + &ifrt_model_context.GetLoadedVariableRegistry(), ifrt_model_context.GetShapeRepresentationFn()); // Register the Ifrt program to `ServingExecutableRegistry` so that @@ -159,7 +160,7 @@ absl::Status IfrtBackendCompiler::CompileTensorflow( // Use bridge for cluster formation. TF_RETURN_IF_ERROR(tensorflow::tf2xla::v2::RunFunctionTf2xlaClusteringBridge( - module, tensorflow::tf2xla::v2::DeviceType::XLA_TPU_JIT, + module, /*is_supported_by_replicated_brige*/ true, /*is_in_fallback_enabled_mode=*/false)); if (VLOG_IS_ON(1)) { diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_constants.h b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_constants.h index 4cd5bf2cbfc3bb..3e4971826c67b6 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_constants.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_constants.h @@ -21,14 +21,19 @@ limitations under the License. namespace tensorflow { namespace ifrt_serving { -// Attribute name of a serialized TpuCompileMetadataProto. This is backward -// compatible. -inline constexpr absl::string_view kMetadataAttrName = "tpu_compile_metadata"; // Attribute name of a text TpuCompileMetadataProto. Note that the text proto is -// not backward compatible and only used for debug. +// not backward compatible and shall not be serialized. inline constexpr absl::string_view kMetadataTextAttrName = "__tpu_compile_metadata_text"; +// Name of a variable as loaded IFRT array . +inline constexpr absl::string_view kVariableArrayNameAttr = + "__variable_array_name"; + +// Attribute of a text `VariableDeviceShardingConfigProto`. +inline constexpr absl::string_view kVariableShardingConfigTextAttr = + "__variable_sharding_config_text"; + } // namespace ifrt_serving } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.td b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.td index a79e91f0422983..c725aa85b0157b 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.td +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.td @@ -26,11 +26,52 @@ def RewriteClusterToIfrtCallPass: Pass<"rewrite-cluster-to-ifrt-call", "mlir::Mo let dependentDialects = ["mlir::tf_device::TensorFlowDeviceDialect"]; let constructor = "CreateRewriteClusterToIfrtCallPass()"; + } - let options = [ - Option<"tpu_compile_metadata_debug_", "tpu-compile-metadata-debug", "bool", "false", - "if enabled, output compile metadata as readable string in " - "an extra __tpu_compile_metadata_debug attribute for debug">, - ]; +def SinkVariableAsNamedArrayPass: Pass<"sink-variable-as-named-array", "mlir::ModuleOp"> { + let summary = "Sink variable tensor for tpu device as named IFRT array for tf.IfrtCall"; + let description = [{ + This pass sinks variable tensor argument to `tf.IfrtCall` as variable_arg_indices + and variable_names attributes and also lowers `tf.ReadVariableOp` to + `tf.IfrtLoadVariableOp`. + + The runtime ensures that `tf.IfrtCall` kernel can bind the IFRT array by + its name as input to the TPU program. + + }]; + + let constructor = "CreateSinkVariableAsNamedArrayPass()"; } + +def TfRestoreSplittingPass + : Pass<"tf-restore-splitting", "mlir::func::FuncOp"> { + let summary = "Splits `tf.RestoreV2` ops"; + + let description = [{ + This pass splits each `tf.RestoreV2` op so that one restore op handles one + variable only. This pass can split restore ops only if the tensor names and + the shape/slices arguments are constants, which is usually the case. + + Splitting monolithic restore ops into per-tensor restore ops makes it easier + to shard SavedModel initialization across multiple clusters. + }]; + + let constructor = "CreateTfRestoreSplittingPass()"; +} + +def TfRestoreMergingPass : Pass<"tf-restore-merging", "mlir::func::FuncOp"> { + let summary = "Merges `tf.RestoreV2` ops"; + + let description = [{ + This pass merges multiple `tf.RestoreV2` ops into one `tf.RestoreV2` op + using variadic results. The current implementation merges restore ops only + if they have the same `prefix` and have constant tensor names and + shape/slice arguments. + + This pass is run in order to undo `tf-restore-splitting` after cluster + formation and reduce the op dispatch overhead. + }]; + + let constructor = "CreateTfRestoreMergingPass()"; +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/rewrite_cluster_to_ifrt_call.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/rewrite_cluster_to_ifrt_call.cc index 89428db2dd148c..b8bd4685919d4c 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/rewrite_cluster_to_ifrt_call.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/rewrite_cluster_to_ifrt_call.cc @@ -190,9 +190,15 @@ class RewriteClusterToIfrtCallPass return signalPassFailure(); } + auto metadata_attr = + ifrt_program->getAttrOfType(kMetadataTextAttrName); + if (!metadata_attr) { + return signalPassFailure(); + } + ifrt_call_op->setAttr(kMetadataTextAttrName, metadata_attr); + // TODO(b/304839793): populate variable names after adding a variable // hoisting pass. - ifrt_call_op.setVariableNamesAttr(builder.getArrayAttr({})); ifrt_call_op.setVariableArgIndicesAttr(builder.getI32ArrayAttr({})); ifrt_call_op.setProgramId(program_id); @@ -214,25 +220,24 @@ class RewriteClusterToIfrtCallPass if (mlir::failed(GetTpuCompileMetadata(cluster_func, devices, &metadata))) { return signalPassFailure(); } + std::string serialized_metadata; + tsl::protobuf::TextFormat::Printer printer; + printer.SetSingleLineMode(true); + printer.PrintToString(metadata, &serialized_metadata); - cloned_ifrt_program->setAttr( - kMetadataAttrName, builder.getStringAttr(metadata.SerializeAsString())); + cloned_ifrt_program->setAttr(kMetadataTextAttrName, + builder.getStringAttr(serialized_metadata)); - if (tpu_compile_metadata_debug_) { - std::string serialized_metadata; - tsl::protobuf::TextFormat::Printer printer; - printer.SetSingleLineMode(true); - printer.PrintToString(metadata, &serialized_metadata); - - cloned_ifrt_program->setAttr(kMetadataTextAttrName, - builder.getStringAttr(serialized_metadata)); - } cloned_ifrt_program.setName(ifrt_program_name); int64_t program_id = NewProgramId(); cloned_ifrt_program->setAttr("tfrt_ifrt_serving.program_id", builder.getI64IntegerAttr(program_id)); + // Make clonet ifrt program public so that it does not get dropped by + // inliner. + cloned_ifrt_program.setPublic(); + builder.setInsertionPoint(cluster_func); mlir::TF::IfrtCallOp ifrt_call_op = builder.create( @@ -241,9 +246,12 @@ class RewriteClusterToIfrtCallPass // TODO(b/304839793): populate variable names after adding a variable // hoisting pass. - ifrt_call_op.setVariableNamesAttr(builder.getArrayAttr({})); ifrt_call_op.setVariableArgIndicesAttr(builder.getI32ArrayAttr({})); ifrt_call_op.setProgramId(program_id); + // Additionally attach tpu_compile_metadata to IfrtCallOp. Some subsequent + // pass such as SinkVariableAsNamedArrayPass relies on this attribute. + ifrt_call_op->setAttr(kMetadataTextAttrName, + builder.getStringAttr(serialized_metadata)); cluster_func->replaceAllUsesWith(ifrt_call_op.getResults()); cluster_func->erase(); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/sink_variable_as_named_array.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/sink_variable_as_named_array.cc new file mode 100644 index 00000000000000..450c441ef4e3f7 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/sink_variable_as_named_array.cc @@ -0,0 +1,359 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_constants.h" +#include "xla/service/computation_placer.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/platform/protobuf.h" // IWYU pragma: keep +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h" + +namespace tensorflow { +namespace ifrt_serving { +namespace { + +#define GEN_PASS_DEF_SINKVARIABLEASNAMEDARRAYPASS +#define GEN_PASS_DECL_SINKVARIABLEASNAMEDARRAYPASS +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.h.inc" // IWYU pragma: keep + +class SinkVariableAsNamedArrayPass + : public impl::SinkVariableAsNamedArrayPassBase< + SinkVariableAsNamedArrayPass> { + public: + void runOnOperation() override { + mlir::ModuleOp module = getOperation(); + mlir::OpBuilder builder(&getContext()); + + absl::flat_hash_map variable_config_by_name; + llvm::SmallDenseMap + ifrt_call_argument_configs; + + // First, we backtrack from IFRT call to collect variable tensors that needs + // to converted to loaded ifrt arrays and their associated information such + // as their name and defining ops. + std::vector ifrt_call_ops; + module.walk([&ifrt_call_ops](mlir::TF::IfrtCallOp call) { + ifrt_call_ops.push_back(call); + }); + for (const auto& call : ifrt_call_ops) { + if (mlir::failed(CollectVariablesUsedByDevice( + call, variable_config_by_name, ifrt_call_argument_configs))) { + return signalPassFailure(); + } + } + + // Rewrite ReadVariableOp with IfrtLoadVariableOp + llvm::SmallDenseMap + read_to_load; + for (auto& [name, variable_config] : variable_config_by_name) { + for (auto& read_variable_op : variable_config.read_variable_op) { + builder.setInsertionPointAfter(read_variable_op); + // TODO(b/319045348): consider use resource alias analysis for this. + auto var_handle = GetDefiningOp( + read_variable_op.getResource()); + + if (!var_handle) { + read_variable_op->emitError( + "ReadVariableOp has no defining VarHandleOp."); + return signalPassFailure(); + } + + auto load_variable_op = builder.create( + read_variable_op->getLoc(), + mlir::RankedTensorType::get( + {}, builder.getType()), + var_handle.getResult(), + builder.getStringAttr(variable_config.device_sharding_config), + builder.getStringAttr(name)); + read_to_load[read_variable_op] = load_variable_op; + } + } + + // Rewrite ifrt call: variable tensors are sunk as attribute. + // The runtime guarantees the binding of corresponding loaded ifrt array + // based on attributes. + for (auto& call : ifrt_call_ops) { + if (!call.getVariableArgIndicesAttr().empty()) { + call->emitError() << "Expect empty " + << call.getVariableArgIndicesAttrName().str() + << " attributes, but got " + << call.getVariableArgIndicesAttr().size() + << " elements"; + return signalPassFailure(); + } + if (call->getOpOperands().size() != + ifrt_call_argument_configs[call].size()) { + call->emitError() << "IfrtCallOp got " << call->getOpOperands().size() + << " operands, but expects " + << ifrt_call_argument_configs[call].size(); + return signalPassFailure(); + } + llvm::SmallVector variable_arg_indices; + llvm::SmallVector variable_arg_names; + llvm::SmallVector updated_args; + + for (const auto& [arg_idx, arg] : + llvm::enumerate(ifrt_call_argument_configs[call])) { + if (arg.is_variable) { + variable_arg_names.push_back( + builder.getStringAttr(arg.variable_name)); + variable_arg_indices.push_back(arg_idx); + // Variable use the key from IfrtLoadVariable. + updated_args.push_back( + read_to_load[arg.read_variable_op].getResult()); + } else { + // non variable + updated_args.push_back(call->getOperand(arg_idx)); + } + } + + builder.setInsertionPointAfter(call); + auto updated_ifrt_call = builder.create( + call->getLoc(), call.getResultTypes(), updated_args); + + updated_ifrt_call->setAttrs(call->getAttrs()); + // Update variable_arg_indices attribute. + updated_ifrt_call.setVariableArgIndicesAttr( + builder.getI32ArrayAttr(variable_arg_indices)); + + call.replaceAllUsesWith(updated_ifrt_call); + call.erase(); + } + + // Delete all ReadVariableOps that are not used. + for (auto& [name, variable_config] : variable_config_by_name) { + for (auto& read_variable_op : variable_config.read_variable_op) { + if (read_variable_op.use_empty()) { + read_variable_op.erase(); + } + } + } + } + + private: + struct VariableConfig { + // VariableDeviceShardingConfig text proto. + std::string device_sharding_config; + // All ReadVariableOps that returns this named variable. + std::vector read_variable_op; + }; + struct IfrtArgConfig { + bool is_variable; + std::string variable_name; + mlir::TF::ReadVariableOp read_variable_op; + }; + using IfrtArgConfigList = llvm::SmallVector; + + // Find defining ReadVariableOps and also build argument configuration map of + // a IfrtCallOp. + mlir::LogicalResult CollectVariablesUsedByDevice( + mlir::TF::IfrtCallOp call, + absl::flat_hash_map& variable_config_by_name, + llvm::SmallDenseMap& + ifrt_call_argument_configs) { + IfrtArgConfigList& args = ifrt_call_argument_configs[call]; + + tensorflow::tpu::TPUCompileMetadataProto metadata; + + // TODO(b/319045348): remove the usage kMetadataAttrName. + auto metadata_attr = + call->getAttrOfType(kMetadataTextAttrName); + if (metadata_attr && !metadata_attr.empty()) { + if (!tensorflow::protobuf::TextFormat::ParseFromString( + metadata_attr.getValue().str(), &metadata)) { + return call.emitError() + << "Failed to parse TPUCompileMetadataProto from attr :" + << metadata_attr.getValue().str(); + } + } else { + return call.emitError() + << "Failed to Get TPUCompileMetadataProto from attr"; + } + + for (const auto& [arg_idx, input] : llvm::enumerate(call->getOperands())) { + // Assuming the nested function calls are inlined. + if (auto read_variable_op = + GetDefiningOp(input)) { + mlir::FailureOr variable_tensor_name = + GetVariableTensorName(read_variable_op); + + if (mlir::failed(variable_tensor_name)) { + return mlir::failure(); + } + + absl::StatusOr device_sharding_config = + GetVariableShardingConfig(metadata, arg_idx); + if (!device_sharding_config.ok()) { + return call->emitError() + << "Fail to get device sharding config for argument index " + << arg_idx; + } + VariableConfig& variable_config = + variable_config_by_name[*variable_tensor_name]; + if (!variable_config.read_variable_op.empty()) { + if (variable_config.device_sharding_config != + *device_sharding_config) { + return call->emitError() + << "A variable tensor has different sharding config: " + << variable_config.device_sharding_config << " vs " + << *device_sharding_config; + } + } else { + variable_config.device_sharding_config = *device_sharding_config; + } + + variable_config.read_variable_op.push_back(read_variable_op); + args.push_back({.is_variable = true, + .variable_name = *variable_tensor_name, + .read_variable_op = read_variable_op}); + } else { + args.push_back({.is_variable = false}); + } + } + + return mlir::success(); + } + + // The returned variable tensor name is used both as an internal hash key, + // and as the binding name between the tensor and the array in the + // runtime. + std::string GetVariableTensorName(mlir::TF::VarHandleOp var_handle) { + return absl::StrCat(absl::string_view(var_handle.getContainer()), "__", + absl::string_view(var_handle.getSharedName())); + } + + mlir::FailureOr GetVariableTensorName( + mlir::TF::ReadVariableOp read_variable_op) { + mlir::Value variable_definition = read_variable_op.getResource(); + auto var_handle = GetDefiningOp(variable_definition); + + if (!var_handle) { + return read_variable_op->emitError("ReadVariableOp has no defining op."); + } + + return GetVariableTensorName(var_handle); + } + + absl::StatusOr GetVariableShardingConfig( + const tensorflow::tpu::TPUCompileMetadataProto& metadata, int arg_idx) { + tensorflow::ifrt_serving::VariableDeviceShardingConfigProto + device_sharding_config; + std::vector device_ids; + + if (metadata.has_device_assignment()) { + absl::StatusOr> da = + xla::DeviceAssignment::Deserialize(metadata.device_assignment()); + + if (!da.ok()) { + return da.status(); + } + if (metadata.num_replicas() != (*da)->replica_count() || + metadata.num_cores_per_replica() != (*da)->computation_count()) { + return absl::FailedPreconditionError(absl::StrCat( + "Device assignment has different replica count: ", + metadata.num_replicas(), " vs ", (*da)->replica_count(), + " or computation count: ", metadata.num_cores_per_replica(), " vs ", + (*da)->computation_count(), ".")); + } + + device_ids.reserve(metadata.num_replicas() * + metadata.num_cores_per_replica()); + for (int i = 0; i < (*da)->replica_count(); ++i) { + for (int j = 0; j < (*da)->computation_count(); ++j) { + device_ids.push_back((**da)(i, j)); + } + } + } else { + // Default use first N devices. + device_ids.resize(metadata.num_replicas() * + metadata.num_cores_per_replica()); + std::iota(device_ids.begin(), device_ids.end(), 0); + } + + device_sharding_config.mutable_device_ids()->Assign(device_ids.begin(), + device_ids.end()); + + if (metadata.args_size() > 0) { + *device_sharding_config.mutable_sharding() = + metadata.args(arg_idx).sharding(); + } + + std::string proto_text; + tsl::protobuf::TextFormat::Printer printer; + printer.SetSingleLineMode(true); + printer.PrintToString(device_sharding_config, &proto_text); + + return proto_text; + } + + template + OpT GetDefiningOp(const mlir::Value& value) { + mlir::Operation* op = value.getDefiningOp(); + + while (op && !llvm::isa(op)) { + if (llvm::isa(op)) { + op = op->getOperand(0).getDefiningOp(); + } else { + return nullptr; + } + } + + if (op != nullptr) { + return llvm::dyn_cast(op); + } else { + return nullptr; + } + } +}; + +} // namespace + +std::unique_ptr> +CreateSinkVariableAsNamedArrayPass() { + return std::make_unique(); +} + +} // namespace ifrt_serving +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata/xla_call_host_callback.mlir b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata/xla_call_host_callback.mlir new file mode 100644 index 00000000000000..4eff0866ba7a66 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata/xla_call_host_callback.mlir @@ -0,0 +1,23 @@ +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1758 : i32}} { + + func.func private @callee(%arg0: tensor, %arg1: tensor<*xi32>) { + "tf.XlaHostCompute"(%arg0, %arg1) <{ancestors = [], key = "@test_callee", recv_key = "", send_key = "", shapes = []}> {_xla_original_oc_node_name = "hcb0", _xla_token_input_nodes = ["_xla_token_arg_node"]} : (tensor, tensor<*xi32>) -> () + return + } + + // The mlir module in XlaCallModule is serialized from: + // + // func.func private @_stablehlo_main_0(%arg0: tensor, %arg1: tensor<*xi32>) -> () attributes {_from_xla_call_module} { + // stablehlo.custom_call @tf.call_tf_function(%arg0, %arg1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {called_func = @callee}} : (tensor, tensor<*xi32>) -> () + // return + // } + // + // func.func @main(%arg0: tensor, %arg1: tensor<*xi32>) -> () { + // "tf.XlaCallModule"(%arg0, %arg1) {Sout = [#tf_type.shape], dim_args_spec = [], _entry_function = @_stablehlo_main_0, _stablehlo_module_attrs = { mhlo.num_partitions = 1 }, module = "", platforms = [], version = 5 : i64} : (tensor, tensor<*xi32>) -> () + // func.return + // } + func.func @main(%arg0: tensor, %arg1: tensor<*xi32>) attributes {tfrt_ifrt_serving.program_id = -2372940092539171444 : i64, __tpu_compile_metadata_text = "args { dtype: DT_INT32 kind: PARAMETER sharding { } } args { dtype: DT_INT32 kind: PARAMETER sharding { } } num_replicas: 1 num_cores_per_replica: 1 use_spmd_for_xla_partitioning: true compile_options { }"} { + "tf.XlaCallModule"(%arg0, %arg1) <{Sout = [], dim_args_spec = [], function_list = [@callee], module = "ML\EFR\0DStableHLO_v0.17.6\00\01\19\05\01\05\09\01\03\0B\03\07\0F\13\17\03M-\0D\01\19\0B\13\0B\0F\13\13\13\13\13\0B\13\13\03\15\0B\0B\0B\0B\13\0B\0F\0B\0B\0B\01\03\0F\03\0B3\07\0B\17\07\02\B1\05\0D\03\03\05\07\05\0F\11\01\05\17\01A\0B\17\01!\07\17\01!Q\17\01!}\03\03\13!\05\11\17\01#\0B\17\01%\0B\03\01\1D\13#\09\1D\15\0D\03#%\1D\17\13\0B\01\0B\05\1D\19\05\03\01\02\04)\03\00\FF\FF\FF\FF\FF\FF\FF\FF\05\1B3\05\11\05\03\07\01\1D\04O\05\01Q\09\03\01\07\04=\03\01\05\03P\0B\03\07\04)\03\05\0B\05\07\0D\0F\0F\00\05E\15\11\05\05\01\03\07\00\17\06\03\01\05\01\00j\03\1B)\1B\0B\03%)\95\15\1F\11\0F\0B\11builtin\00vhlo\00module\00func_v1\00custom_call_v1\00return_v1\00experimental/users/deqiangc/mira/testdata/xla_call_module_serialized.mlir\00mhlo.num_partitions\00tf.backend_config\00\00main\00called_index\00tf.call_tf_function\00\08'\07\05\01\01\0B\19\1D\19\1F\1B\11'\1B)\19+\19\19\19", platforms = [], version = 5 : i64}> : (tensor, tensor<*xi32>) -> () + return + } +} diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc index 2b555fe42281e3..312761a3ba06d7 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc @@ -48,7 +48,7 @@ limitations under the License. #include "xla/service/computation_placer.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/shape.h" -#include "xla/stream_executor/multi_platform_manager.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/function.pb.h" @@ -59,7 +59,6 @@ limitations under the License. #include "tensorflow/core/protobuf/tpu/topology.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" #include "tsl/platform/errors.h" -#include "tsl/platform/protobuf.h" #include "tsl/platform/statusor.h" namespace tensorflow { @@ -68,24 +67,14 @@ namespace { static constexpr absl::string_view kEntryFuncName = "main"; absl::StatusOr GetCompileMetadata( - mlir::func::FuncOp op, absl::Span inputs, + mlir::func::FuncOp op, absl::Span inputs, const xla::ifrt::Client& ifrt_client) { tensorflow::tpu::TPUCompileMetadataProto metadata; - auto metadata_attr = op->getAttrOfType(kMetadataAttrName); auto metadata_text_attr = op->getAttrOfType(kMetadataTextAttrName); - if (metadata_attr && !metadata_attr.getValue().empty()) { - // tpu_compile_metadata takes priority if exists. - VLOG(1) << "Parsing from attribute " << kMetadataAttrName << " : " - << metadata_attr.getValue().str(); - if (!metadata.ParseFromString(metadata_attr.getValue().str())) { - return absl::InternalError( - absl::StrCat("Failed to parse tpu_compile_metadata attribute:", - metadata_attr.getValue().str())); - } - } else if (metadata_text_attr && !metadata_text_attr.getValue().empty()) { + if (metadata_text_attr && !metadata_text_attr.getValue().empty()) { // Try __tpu_compile_metadata_text attribute. This only for debugging // purpose. VLOG(1) << "Parsing from attribute " << kMetadataTextAttrName @@ -97,12 +86,11 @@ absl::StatusOr GetCompileMetadata( metadata_text_attr.getValue().str(), " cannot be parsed")); } } else { - return absl::InvalidArgumentError(absl::StrCat( - "Missing ", kMetadataAttrName, " and ", kMetadataTextAttrName)); + return absl::InvalidArgumentError( + absl::StrCat("Missing ", kMetadataTextAttrName)); } - VLOG(3) << "TpuCompileMetadata before shape is populated " - << metadata.DebugString(); + VLOG(3) << "TpuCompileMetadata before shape is populated " << metadata; if (metadata.num_replicas() < 1 || metadata.num_cores_per_replica() < 1) { return absl::InternalError( absl::StrCat("Number of replicas ", metadata.num_replicas(), @@ -127,14 +115,14 @@ absl::StatusOr GetCompileMetadata( "Only support PARAMETER, but got ", metadata.args(i).kind())); } - if (metadata.args(i).dtype() != inputs[i].dtype()) { + if (metadata.args(i).dtype() != inputs[i].dtype) { return absl::InternalError(absl::StrCat("Dtype mismatched! Expected ", metadata.args(i).dtype(), " got ", - inputs[i].dtype())); + inputs[i].dtype)); } // Update shape. - *metadata.mutable_args(i)->mutable_shape() = inputs[i].shape().AsProto(); + *metadata.mutable_args(i)->mutable_shape() = inputs[i].shape.AsProto(); } // Create a default device assignment if one is not given by the model. @@ -155,7 +143,7 @@ absl::StatusOr GetCompileMetadata( } // namespace absl::StatusOr CompileTfToHlo( - mlir::ModuleOp module, absl::Span inputs, + mlir::ModuleOp module, absl::Span inputs, absl::string_view entry_function_name, const xla::ifrt::Client& ifrt_client, tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn) { if (VLOG_IS_ON(1)) { @@ -171,7 +159,7 @@ absl::StatusOr CompileTfToHlo( TF_ASSIGN_OR_RETURN( auto* platform, - stream_executor::MultiPlatformManager::PlatformWithName("Host")); + stream_executor::PlatformManager::PlatformWithName("Host")); TF_ASSIGN_OR_RETURN( auto* client, xla::ClientLibrary::GetOrCreateCompileOnlyClient(platform)); @@ -189,11 +177,11 @@ absl::StatusOr CompileTfToHlo( TF_ASSIGN_OR_RETURN(tensorflow::tpu::TPUCompileMetadataProto compile_metadata, GetCompileMetadata(entry_fn, inputs, ifrt_client)); - VLOG(1) << "Compilation metadata: " << compile_metadata.DebugString(); + VLOG(1) << "Compilation metadata: " << compile_metadata; std::vector arg_shapes; for (const auto& input : inputs) { - arg_shapes.push_back(input.shape()); + arg_shapes.push_back(input.shape); } bool use_tuple_args = false; @@ -223,6 +211,7 @@ absl::StatusOr CompileTfToHlo( Tf2HloResult result; result.mlir_hlo_module = xla::llvm_ir::CreateMlirModuleOp(module->getLoc()); result.compile_metadata = std::move(compile_metadata); + result.host_compute_metadata = compilation_result.host_compute_metadata; TF_RETURN_IF_ERROR(xla::ConvertHloToMlirHlo( *result.mlir_hlo_module, &compilation_result.computation->proto())); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h index f170a5fc5d9265..74fa271401f547 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h @@ -24,20 +24,28 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "xla/python/ifrt/client.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" namespace tensorflow { namespace ifrt_serving { +struct DtypeAndShape { + tensorflow::DataType dtype; + tensorflow::TensorShape shape; +}; + struct Tf2HloResult { mlir::OwningOpRef mlir_hlo_module; tensorflow::tpu::TPUCompileMetadataProto compile_metadata; + tf2xla::HostComputeMetadata host_compute_metadata; }; // A class that convert tf module to hlo // TODO(b/304839793): provide wrap persistent compilation cache. absl::StatusOr CompileTfToHlo( - mlir::ModuleOp module, absl::Span inputs, + mlir::ModuleOp module, absl::Span inputs, absl::string_view entry_function_name, const xla::ifrt::Client& ifrt_client, tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc index 89f353c68eafdd..7ee1c450426b20 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc @@ -35,8 +35,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/test_util.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/resource_loader.h" #include "tensorflow/core/platform/test.h" #include "tsl/lib/core/status_test_util.h" @@ -124,13 +122,12 @@ TEST(Tf2HloTest, Tuple) { TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr client, xla::ifrt::test_util::GetClient()); - std::vector tensors; - tensorflow::Tensor x(DT_FLOAT, tensorflow::TensorShape({1, 3})); - tensorflow::Tensor y(DT_FLOAT, tensorflow::TensorShape({3, 1})); - tensors.push_back(x); - tensors.push_back(y); - auto result = CompileTfToHlo(mlir_module.get(), tensors, "main", *client, - tensorflow::IdentityShapeRepresentationFn()); + std::vector dtype_and_shapes; + dtype_and_shapes.push_back(DtypeAndShape{DT_FLOAT, {1, 3}}); + dtype_and_shapes.push_back(DtypeAndShape{DT_FLOAT, {3, 1}}); + auto result = + CompileTfToHlo(mlir_module.get(), dtype_and_shapes, "main", *client, + tensorflow::IdentityShapeRepresentationFn()); TF_ASSERT_OK(result.status()); } @@ -158,12 +155,11 @@ TEST(Tf2HloTest, Spmd) { TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr client, xla::ifrt::test_util::GetClient()); - std::vector tensors; - tensorflow::Tensor x(DT_FLOAT, tensorflow::TensorShape({4, 64})); - tensors.push_back(x); - - auto result = CompileTfToHlo(mlir_module.get(), tensors, "main", *client, - tensorflow::IdentityShapeRepresentationFn()); + std::vector dtype_and_shapes; + dtype_and_shapes.push_back(DtypeAndShape{DT_FLOAT, {4, 64}}); + auto result = + CompileTfToHlo(mlir_module.get(), dtype_and_shapes, "main", *client, + tensorflow::IdentityShapeRepresentationFn()); LOG(INFO) << result->compile_metadata; TF_ASSERT_OK(result.status()); @@ -227,16 +223,13 @@ TEST(Tf2HloTest, UsingDefaultDeviceAssignment) { TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr client, xla::ifrt::test_util::GetClient()); - std::vector tensors; - tensorflow::Tensor x(DT_FLOAT, tensorflow::TensorShape({4, 64})); - tensorflow::Tensor y(DT_FLOAT, tensorflow::TensorShape({64, 10})); - tensorflow::Tensor z(DT_FLOAT, tensorflow::TensorShape({1, 4})); - tensors.push_back(x); - tensors.push_back(y); - tensors.push_back(z); - - auto result = CompileTfToHlo(mlir_module.get(), tensors, "main", *client, - tensorflow::IdentityShapeRepresentationFn()); + std::vector dtype_and_shapes; + dtype_and_shapes.push_back(DtypeAndShape{DT_FLOAT, {4, 64}}); + dtype_and_shapes.push_back(DtypeAndShape{DT_FLOAT, {64, 10}}); + dtype_and_shapes.push_back(DtypeAndShape{DT_FLOAT, {1, 4}}); + auto result = + CompileTfToHlo(mlir_module.get(), dtype_and_shapes, "main", *client, + tensorflow::IdentityShapeRepresentationFn()); LOG(INFO) << result->compile_metadata; TF_ASSERT_OK(result.status()); @@ -302,6 +295,47 @@ TEST(Tf2HloTest, UsingDefaultDeviceAssignment) { EXPECT_THAT(result->compile_metadata, EqualsProto(expected_compile_metadata)); } +// Multiple input and multiple out. +TEST(Tf2HloTest, XlaCallHostCallback) { + // Create test input module + constexpr absl::string_view kDataDirectory = + "tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata"; + std::string mlir_module_path = tensorflow::GetDataDependencyFilepath( + absl::StrCat(kDataDirectory, "/xla_call_host_callback.mlir")); + + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + mlir::RegisterAllTensorFlowDialects(registry); + + mlir::MLIRContext context(registry); + + mlir::OwningOpRef mlir_module = + mlir::parseSourceFile(mlir_module_path, + mlir::ParserConfig(&context)); + + ASSERT_TRUE(mlir_module); + ASSERT_TRUE(mlir_module.get() != nullptr); + + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr client, + xla::ifrt::test_util::GetClient()); + + std::vector dtype_and_shapes; + dtype_and_shapes.push_back(DtypeAndShape{DT_INT32, {1}}); + dtype_and_shapes.push_back(DtypeAndShape{DT_INT32, {1}}); + + auto result = + CompileTfToHlo(mlir_module.get(), dtype_and_shapes, "main", *client, + tensorflow::IdentityShapeRepresentationFn()); + + TF_ASSERT_OK(result.status()); + + ASSERT_EQ((*result).host_compute_metadata.device_to_host().size(), 1); + ASSERT_EQ( + (*result).host_compute_metadata.device_to_host().begin()->metadata_size(), + 2); + ASSERT_EQ((*result).host_compute_metadata.host_to_device().size(), 0); +} + } // namespace } // namespace ifrt_serving } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc index 47fe04925d1881..53bd55cc0d2799 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc @@ -27,10 +27,12 @@ limitations under the License. #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" #include "tensorflow/core/util/debug_data_dumper.h" namespace tensorflow { @@ -69,6 +71,20 @@ void AddClusterToIfrtRuntimeOpsPassPipeline(OpPassManager& pm, mlir::TF::CreateCanonicalizeCompileAndReplicateAttributesPass()); pm.addPass(CreateRewriteClusterToIfrtCallPass()); + + // Sink VarHandle with ReadVariableOp: subsequent SinkVariableAsNamedArrayPass + // rely on the co-existence of VarHandle and ReadVariable in the same + // function. + // First, we inline all the function calls. This will sink VarHandle + // with ReadVariable in most cases. Then SinkInvariantOpsPass will sink + // VarHandle to a few special Ops that inliner does not handle. + // TODO(b/319045348): the bridge before this pipeline already does some + // inlining. Consider removing this inliner. + pm.addPass(mlir::createInlinerPass()); + pm.addPass(::tensorflow::CreateSinkInInvariantOpsPass()); + + // Sink variable tensor as named array in IFRT. + pm.addPass(CreateSinkVariableAsNamedArrayPass()); } } // namespace @@ -111,9 +127,7 @@ absl::Status RunClusterToIfrtRuntimeOpsPassPipeline( } // Register all IfrtPass -void RegisterTfIfrtPasses() { - mlir::registerPass([]() { return CreateRewriteClusterToIfrtCallPass(); }); -} +void RegisterTfIfrtPasses() { registerTfrtIfrtServingPasses(); } } // namespace ifrt_serving } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h index b04c61eb5c76d8..084b170fa5b9a4 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h @@ -20,6 +20,7 @@ limitations under the License. #include "absl/status/status.h" #include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project @@ -30,6 +31,19 @@ namespace ifrt_serving { std::unique_ptr> CreateRewriteClusterToIfrtCallPass(); +// Creates a pass that sinks variable tensor argument to `tf.IfrtCall` as named +// arrays and lowers `tf.ReadVariableOp` to `tf.IfrtLoadVariableOp`. +std::unique_ptr> +CreateSinkVariableAsNamedArrayPass(); + +// Creates a pass that splits `tf.RestoreV2` ops. +std::unique_ptr> +CreateTfRestoreSplittingPass(); + +// Creates a pass that merges `tf.RestoreV2` ops. +std::unique_ptr> +CreateTfRestoreMergingPass(); + #define GEN_PASS_REGISTRATION #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.h.inc" // IWYU pragma: keep diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_restore_merging.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_restore_merging.cc new file mode 100644 index 00000000000000..5220824d3f716a --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_restore_merging.cc @@ -0,0 +1,164 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace tensorflow { +namespace ifrt_serving { +namespace { + +#define GEN_PASS_DEF_TFRESTOREMERGINGPASS +#define GEN_PASS_DECL_TFRESTOREMERGINGPASS +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.h.inc" // IWYU pragma: keep + +class TfRestoreMergingPass + : public impl::TfRestoreMergingPassBase { + public: + void runOnOperation() override { + mlir::func::FuncOp func = getOperation(); + for (mlir::Block& block : func) { + // Group `tf.RestoreV2` ops by prefixes and merge each group. + llvm::SmallDenseMap> + restore_groups; + for (auto restore : block.getOps()) { + restore_groups[restore.getPrefix()].push_back(restore); + } + for (const auto& restores : llvm::make_second_range(restore_groups)) { + if (mlir::failed(MergeRestores(restores))) { + return signalPassFailure(); + } + } + } + } + + private: + mlir::DenseStringElementsAttr GetStringTensorAttr( + llvm::ArrayRef values) { + const int size = values.size(); + const auto type = mlir::RankedTensorType::get( + {size}, mlir::TF::StringType::get(&getContext())); + return mlir::DenseStringElementsAttr::get(type, values); + } + + // Merges `tf.RestoreV2` ops with the same prefix. Ignores restore ops with + // non-constant `tensor_names` and/or `shape_and_slices`. + mlir::LogicalResult MergeRestores( + llvm::ArrayRef restores) { + if (restores.size() <= 1) { + return mlir::success(); + } + + // All restore ops must have the same prefix. + const mlir::Value prefix = + mlir::TF::RestoreV2Op(restores.front()).getPrefix(); + + std::vector restores_to_merge; + std::vector values_to_replace; + std::vector merged_tensor_names; + std::vector merged_shape_and_slices; + + std::vector restore_locs; + std::vector tensor_names_locs; + std::vector shape_and_slices_locs; + + for (mlir::TF::RestoreV2Op restore : restores) { + mlir::DenseStringElementsAttr tensor_names; + mlir::DenseStringElementsAttr shape_and_slices; + if (!mlir::matchPattern(restore, + mlir::m_Op( + mlir::matchers::m_Val(prefix), + mlir::m_Constant(&tensor_names), + mlir::m_Constant(&shape_and_slices)))) { + continue; + } + if (tensor_names.size() != restore.getNumResults() || + shape_and_slices.size() != restore.getNumResults()) { + return restore.emitOpError() + << "returns an inconsistent number of results"; + } + + restores_to_merge.push_back(restore); + llvm::append_range(values_to_replace, restore.getTensors()); + llvm::append_range(merged_tensor_names, + tensor_names.getValues()); + llvm::append_range(merged_shape_and_slices, + shape_and_slices.getValues()); + + restore_locs.push_back(restore.getLoc()); + tensor_names_locs.push_back(restore.getTensorNames().getLoc()); + shape_and_slices_locs.push_back(restore.getShapeAndSlices().getLoc()); + } + if (restores_to_merge.size() <= 1) { + return mlir::success(); + } + + // Insert the merged restore op right before the first restore op to be + // merged in order to keep the dominance property. + mlir::OpBuilder builder(restores_to_merge.front()); + + auto new_tensor_names = builder.create( + builder.getFusedLoc(tensor_names_locs), + GetStringTensorAttr(merged_tensor_names)); + auto new_shape_and_slices = builder.create( + builder.getFusedLoc(shape_and_slices_locs), + GetStringTensorAttr(merged_shape_and_slices)); + + auto new_restore = builder.create( + builder.getFusedLoc(restore_locs), + mlir::TypeRange(mlir::ValueRange(values_to_replace)), prefix, + new_tensor_names, new_shape_and_slices); + for (auto [old_value, new_value] : + llvm::zip(values_to_replace, new_restore.getTensors())) { + old_value.replaceAllUsesWith(new_value); + } + + for (mlir::TF::RestoreV2Op restore : restores_to_merge) { + restore.erase(); + } + return mlir::success(); + } +}; + +} // namespace + +std::unique_ptr> +CreateTfRestoreMergingPass() { + return std::make_unique(); +} + +} // namespace ifrt_serving +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_restore_splitting.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_restore_splitting.cc new file mode 100644 index 00000000000000..130ca0a2e90b74 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_restore_splitting.cc @@ -0,0 +1,122 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace tensorflow { +namespace ifrt_serving { +namespace { + +#define GEN_PASS_DEF_TFRESTORESPLITTINGPASS +#define GEN_PASS_DECL_TFRESTORESPLITTINGPASS +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.h.inc" // IWYU pragma: keep + +class TfRestoreSplittingPass + : public impl::TfRestoreSplittingPassBase { + public: + void runOnOperation() override { + mlir::func::FuncOp func = getOperation(); + const mlir::WalkResult result = + func.walk([&](mlir::TF::RestoreV2Op restore) { + if (mlir::failed(SplitRestore(restore))) { + return mlir::WalkResult::interrupt(); + } + return mlir::WalkResult::advance(); + }); + if (result.wasInterrupted()) { + return signalPassFailure(); + } + } + + private: + mlir::DenseStringElementsAttr GetStringTensorAttr( + llvm::ArrayRef values) { + const int size = values.size(); + const auto type = mlir::RankedTensorType::get( + {size}, mlir::TF::StringType::get(&getContext())); + return mlir::DenseStringElementsAttr::get(type, values); + } + + // Splits the `tf.RestoreV2` op into per-variable restore ops if its + // `tensor_name` and `shape_and_slices` are constant. + mlir::LogicalResult SplitRestore(mlir::TF::RestoreV2Op restore) { + mlir::DenseStringElementsAttr tensor_names; + mlir::DenseStringElementsAttr shape_and_slices; + if (!mlir::matchPattern(restore, + mlir::m_Op( + /*prefix=*/mlir::matchers::m_Any(), + mlir::m_Constant(&tensor_names), + mlir::m_Constant(&shape_and_slices)))) { + return mlir::success(); + } + if (tensor_names.size() != restore.getNumResults() || + shape_and_slices.size() != restore.getNumResults()) { + return restore.emitOpError() + << "returns an inconsistent number of results"; + } + + mlir::OpBuilder builder(restore); + for (auto [tensor_name, shape_and_slice, result] : + llvm::zip(tensor_names.getValues(), + shape_and_slices.getValues(), + restore.getTensors())) { + auto new_tensor_names = + builder.create(restore.getTensorNames().getLoc(), + GetStringTensorAttr({tensor_name})); + + auto new_shape_and_slices = builder.create( + restore.getShapeAndSlices().getLoc(), + GetStringTensorAttr({shape_and_slice})); + + auto new_restore = builder.create( + restore.getLoc(), mlir::TypeRange({result.getType()}), + restore.getPrefix(), new_tensor_names, new_shape_and_slices); + result.replaceAllUsesWith(new_restore.getTensors()[0]); + } + + restore.erase(); + return mlir::success(); + } +}; + +} // namespace + +std::unique_ptr> +CreateTfRestoreSplittingPass() { + return std::make_unique(); +} + +} // namespace ifrt_serving +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD index eb61561911f2b4..bec0d45d6b1525 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD @@ -51,6 +51,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_a_m_inc_gen", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_n_z_inc_gen", "//tensorflow/compiler/mlir/tensorflow:tensorflow_tfrt_ops_inc_gen", + "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", ], ) @@ -65,6 +66,7 @@ cc_library( ":util", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_tfrt_ops_inc_gen", "//tensorflow/compiler/mlir/tfrt:constants", "//tensorflow/compiler/mlir/tfrt:tfrt_pipeline_options", "//tensorflow/compiler/mlir/tfrt:transform_utils", diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc index be81e54ece871f..a9dca8adca598c 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc @@ -37,6 +37,7 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.h.inc" #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" #include "tensorflow/compiler/mlir/tfrt/constants.h" #include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.h" @@ -322,6 +323,26 @@ class GetResourceOpConversion final } }; +// Convert tf.IfrtLoadVariableOp to tf_mlrt.IfrtLoadVariableOp +class IfrtLoadVariableOpConversion + : public mlir::OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult matchAndRewrite( + mlir::TF::IfrtLoadVariableOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + llvm::SmallVector result_types( + op->getNumResults(), rewriter.getType()); + auto new_op = rewriter.create( + op.getLoc(), result_types, adaptor.getOperands()[0], + op.getDeviceShardingConfigProtoTextAttr(), op.getNameAttr()); + rewriter.replaceOp(op, new_op); + + return mlir::success(); + } +}; + std::optional DecodeLongName(mlir::Location loc) { if (auto name_loc = loc.dyn_cast()) { return name_loc.getName().str(); @@ -1167,8 +1188,8 @@ class TfToMlrtConversionPass // Order the list of added ops alphabetically. patterns.add(&context, &type_converter_, &symbol_table); patterns.add(&context); + SetResourceOpConversion, IfrtLoadVariableOpConversion, + TFAwaitOpConversion, TFPromiseOpConversion>(&context); patterns.add(type_converter_, &context); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.cc index fb110fb01f2ef1..69c8b08dcbc0b1 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.h" +#include "llvm/Support/Casting.h" #include "mlir/IR/Operation.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -29,13 +30,15 @@ bool UseFallback(mlir::Operation *op) { // TODO(b/173017701): have a centralized place to hold the information // whether a TF op should be lowered to FallbackExecute op. + // TODO(b/319045348): Define trait to reflect that IfrtLoadVariableOp has no + // TF kernels so that we don't need to check every op here. // LINT.IfChange(fallback_allow_list) - return !llvm::isa(op); + return !llvm::isa< + mlir::TF::_TfrtSetResourceOp, mlir::TF::_TfrtGetResourceOp, + mlir::TF::BatchFunctionOp, mlir::TF::CaseOp, mlir::TF::IfrtLoadVariableOp, + mlir::TF::StatefulPartitionedCallOp, mlir::TF::PartitionedCallOp, + mlir::TF::LegacyCallOp, mlir::TF::IfOp, mlir::TF::WhileOp, + mlir::TF::TPUCompileMlirAndExecuteOp>(op); // LINT.ThenChange(tf_to_mlrt.cc:fallback_allow_list) } diff --git a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc index 974bc1a56c938f..66aee10db7e050 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc @@ -238,7 +238,7 @@ Status ConvertTfMlirToRuntimeExecutable( TF_RETURN_IF_ERROR( tensorflow::tf2xla::v2::RunFunctionTf2xlaClusteringBridge( - module, tf2xla::v2::DeviceType::XLA_TPU_JIT, + module, /*is_supported_by_replicated_brige*/ true, /*is_in_fallback_enabled_mode=*/VLOG_IS_ON(1))); TF_RETURN_IF_ERROR( @@ -257,7 +257,7 @@ Status ConvertTfMlirToRuntimeExecutable( } else if (options.device_target == TfrtDeviceInfraTarget::kGpu) { TF_RETURN_IF_ERROR( tensorflow::tf2xla::v2::RunFunctionTf2xlaClusteringBridge( - module, tf2xla::v2::DeviceType::XLA_GPU_JIT, + module, /*is_supported_by_replicated_brige*/ false, /*is_in_fallback_enabled_mode=*/false)); TF_RETURN_IF_ERROR( diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD index 44300ff2209f0d..6cfe2be883e6ca 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD @@ -48,6 +48,7 @@ cc_library( "//tensorflow/compiler/mlir/tools/kernel_gen/transforms:gpu_passes", # fixdeps: keep "//tensorflow/compiler/mlir/tools/kernel_gen/transforms:passes", "//tensorflow/core:lib", + "@com_google_absl//absl/status", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineToStandard", "@llvm-project//mlir:ArithDialect", @@ -56,8 +57,8 @@ cc_library( "@llvm-project//mlir:BuiltinToLLVMIRTranslation", "@llvm-project//mlir:ComplexDialect", "@llvm-project//mlir:ComplexToStandard", + "@llvm-project//mlir:ControlFlowDialect", "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:FuncToLLVM", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:GPUToGPURuntimeTransforms", "@llvm-project//mlir:GPUToLLVMIRTranslation", @@ -67,6 +68,7 @@ cc_library( "@llvm-project//mlir:LLVMIRTransforms", "@llvm-project//mlir:LLVMToLLVMIRTranslation", "@llvm-project//mlir:LinalgTransforms", + "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:MemRefTransforms", "@llvm-project//mlir:NVVMToLLVMIRTranslation", "@llvm-project//mlir:Parser", @@ -76,9 +78,10 @@ cc_library( "@llvm-project//mlir:SCFToControlFlow", "@llvm-project//mlir:SCFToGPU", "@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:ShapeToStandard", + "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", - "@llvm-project//mlir:VectorToLLVM", "@local_xla//xla/mlir_hlo", "@local_xla//xla/mlir_hlo:all_passes", # fixdeps: keep "@local_xla//xla/mlir_hlo:mhlo_passes", @@ -162,11 +165,21 @@ cc_library( "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core/platform:refcount", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:JITLink", + "@llvm-project//llvm:OrcShared", "@llvm-project//llvm:Support", "@llvm-project//mlir:ExecutionEngine", "@llvm-project//mlir:ExecutionEngineUtils", + "@llvm-project//mlir:IR", "@llvm-project//mlir:MemRefTransforms", "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:mlir_c_runner_utils", "@llvm-project//mlir:mlir_runner_utils", "@local_xla//xla/stream_executor", ], @@ -178,8 +191,13 @@ cc_library( hdrs = ["tf_jit_cache.h"], deps = [ "//tensorflow/core:framework", + "//tensorflow/core:framework_lite", + "//tensorflow/core/framework:resource_base", + "//tensorflow/core/platform:status", "@com_google_absl//absl/container:flat_hash_map", + "@llvm-project//llvm:Support", "@llvm-project//mlir:ExecutionEngine", + "@local_tsl//tsl/platform:thread_annotations", ], ) @@ -197,6 +215,7 @@ cc_library( ]), deps = [ "//tensorflow/core:framework", + "//tensorflow/core/framework:resource_base", "//tensorflow/core/platform:logging", "//tensorflow/core/platform:mutex", "//tensorflow/core/platform:refcount", @@ -206,6 +225,7 @@ cc_library( "@com_google_absl//absl/strings", "@llvm-project//mlir:mlir_runner_utils", "@local_tsl//tsl/platform:hash", + "@local_tsl//tsl/platform:thread_annotations", "@local_xla//xla/stream_executor", ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/hlo_to_kernel.cc b/tensorflow/compiler/mlir/tools/kernel_gen/hlo_to_kernel.cc index b5537741529d06..4c05366e42af26 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/hlo_to_kernel.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/hlo_to_kernel.cc @@ -119,8 +119,8 @@ Status Run(llvm::StringRef input_file, llvm::StringRef output_file, llvm::StringRef host_triple, llvm::ArrayRef architectures, llvm::ArrayRef tile_sizes, - llvm::ArrayRef unroll_factors, int64_t max_supported_rank, - bool print_ptx, bool print_llvmir, bool enable_ftz, bool index_64bit, + llvm::ArrayRef unroll_factors, bool print_ptx, + bool print_llvmir, bool enable_ftz, bool index_64bit, bool jit_compile, bool jit_i64_indexed_for_large_tensors) { // Read TF code. std::string hlo_code; @@ -138,9 +138,9 @@ Status Run(llvm::StringRef input_file, llvm::StringRef output_file, TF_ASSIGN_OR_RETURN( mlir::OwningOpRef module, GenerateKernelForHloCode(context, hlo_code, architectures, tile_sizes, - unroll_factors, max_supported_rank, print_ptx, - print_llvmir, enable_ftz, index_64bit, - jit_compile, jit_i64_indexed_for_large_tensors, + unroll_factors, print_ptx, print_llvmir, + enable_ftz, index_64bit, jit_compile, + jit_i64_indexed_for_large_tensors, /*apply_cl_options=*/true)); // Get binary. @@ -186,11 +186,6 @@ int main(int argc, char** argv) { llvm::cl::list architectures( "arch", llvm::cl::desc("target architectures (e.g. sm_70 or compute_75)"), llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated); - llvm::cl::opt max_supported_rank( - "max-supported-rank", - llvm::cl::desc("maximum supported rank to be guaranteed by rank " - "specialization lowering"), - llvm::cl::init(5)); llvm::cl::list tile_sizes( "tile_sizes", llvm::cl::desc("tile sizes to use"), llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated); @@ -222,8 +217,8 @@ int main(int argc, char** argv) { auto status = tensorflow::kernel_gen::Run( input_file, output_file, host_triple, architectures, tile_sizes, - unroll_factors, max_supported_rank, print_ptx, print_llvmir, enable_ftz, - index_64bit, jit_compile, jit_i64_indexed_for_large_tensors); + unroll_factors, print_ptx, print_llvmir, enable_ftz, index_64bit, + jit_compile, jit_i64_indexed_for_large_tensors); if (!status.ok()) { LOG(ERROR) << status; return 1; diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td index 8856ae73b0b4bf..d8e7617cc352ba 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td @@ -284,7 +284,6 @@ def TFFramework_JITCompileFromStrOp : TFFramework_Op<"jit_compile_from_str", StrAttr:$code, I64ArrayAttr:$tileSizes, I64ArrayAttr:$unrollFactors, - I64Attr:$maxSupportedRank, BoolAttr:$enableFtz, BoolAttr:$index64Bit, BoolAttr:$cpuCodegen diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc index 8bf068241f83fa..c8969f429e1805 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc @@ -23,31 +23,36 @@ limitations under the License. #include +#include "absl/status/status.h" +#include "llvm/ADT/ArrayRef.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" // from @llvm-project -#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" // from @llvm-project #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" // from @llvm-project #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h" // from @llvm-project #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" // from @llvm-project -#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" // from @llvm-project #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Arith/Transforms/Passes.h" // from @llvm-project #include "mlir/Dialect/Bufferization/Transforms/Passes.h" // from @llvm-project #include "mlir/Dialect/Complex/IR/Complex.h" // from @llvm-project +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project #include "mlir/Dialect/GPU/Transforms/Passes.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h" // from @llvm-project #include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project +#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project #include "mlir/Dialect/MemRef/Transforms/Passes.h" // from @llvm-project #include "mlir/Dialect/SCF/Transforms/Passes.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Interfaces/DataLayoutInterfaces.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" // from @llvm-project @@ -64,7 +69,10 @@ limitations under the License. #include "xla/mlir_hlo/transforms/gpu_passes.h" #include "xla/mlir_hlo/transforms/passes.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/statusor.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace tensorflow { namespace kernel_gen { @@ -113,8 +121,7 @@ bool IsSmallAlloc(Value alloc) { Status LowerHloToJITInvocation(mlir::ModuleOp module, llvm::ArrayRef tile_sizes, llvm::ArrayRef unroll_factors, - int64_t max_supported_rank, bool enable_ftz, - bool index_64bit, + bool enable_ftz, bool index_64bit, bool jit_i64_indexed_for_large_tensors, bool apply_cl_options) { mlir::PassManager pm(module.getContext()); @@ -122,8 +129,7 @@ Status LowerHloToJITInvocation(mlir::ModuleOp module, pm.addNestedPass( mlir::kernel_gen::transforms::CreateFuncToJITInvocationPass( - tile_sizes, unroll_factors, max_supported_rank, enable_ftz, - index_64bit, + tile_sizes, unroll_factors, enable_ftz, index_64bit, /*cpu_codegen=*/false, jit_i64_indexed_for_large_tensors)); pm.addPass(mlir::kernel_gen::tf_framework::CreateEmbedTFFrameworkPass()); pm.addNestedPass( @@ -143,8 +149,7 @@ Status LowerHloToJITInvocation(mlir::ModuleOp module, Status LowerHlotoLoops(mlir::ModuleOp module, llvm::ArrayRef tile_sizes, - llvm::ArrayRef unroll_factors, - int64_t max_supported_rank, bool enable_ftz, + llvm::ArrayRef unroll_factors, bool enable_ftz, bool index_64bit, bool jit_i64_indexed_for_large_tensors, bool apply_cl_options) { mlir::PassManager pm(module.getContext()); @@ -152,14 +157,11 @@ Status LowerHlotoLoops(mlir::ModuleOp module, if (jit_i64_indexed_for_large_tensors) { pm.addNestedPass( mlir::kernel_gen::transforms::CreateFuncToJITInvocationPass( - tile_sizes, unroll_factors, max_supported_rank, enable_ftz, - index_64bit, + tile_sizes, unroll_factors, enable_ftz, index_64bit, /*cpu_codegen=*/false, /*jit_i64_indexed_for_large_tensors=*/true)); } - pm.addNestedPass(mlir::mhlo::createRankSpecializationClusterPass()); - pm.addNestedPass( - mlir::mhlo::createRankSpecializationToSCFPass(max_supported_rank)); + pm.addNestedPass(mlir::mhlo::createChloLegalizeToHloPass()); pm.addNestedPass(mlir::createCanonicalizerPass()); @@ -409,6 +411,8 @@ StatusOr> SetupContextAndParseModule( mlir::MLIRContext& context, llvm::StringRef tf_code) { mlir::DialectRegistry registry; registry.insert(); + registry.insert(); registry.insert(); mlir::registerBuiltinDialectTranslation(registry); mlir::registerGPUDialectTranslation(registry); @@ -418,9 +422,10 @@ StatusOr> SetupContextAndParseModule( context.appendDialectRegistry(registry); mlir::OwningOpRef module = mlir::parseSourceString(tf_code, &context); - if (!module) + if (!module) { return tensorflow::Status(absl::StatusCode::kInvalidArgument, "invalid kernel IR"); + } return module; } @@ -428,9 +433,9 @@ StatusOr> GenerateKernelForHloCode( mlir::MLIRContext& context, llvm::StringRef tf_code, llvm::ArrayRef architectures, llvm::ArrayRef tile_sizes, llvm::ArrayRef unroll_factors, - int64_t max_supported_rank, bool print_ptx, bool print_llvmir, - bool enable_ftz, bool index_64bit, bool jit_compile, - bool jit_i64_indexed_for_large_tensors, bool apply_cl_options) { + bool print_ptx, bool print_llvmir, bool enable_ftz, bool index_64bit, + bool jit_compile, bool jit_i64_indexed_for_large_tensors, + bool apply_cl_options) { if (jit_compile && jit_i64_indexed_for_large_tensors) { return tensorflow::Status( absl::StatusCode::kInvalidArgument, @@ -446,14 +451,12 @@ StatusOr> GenerateKernelForHloCode( assert(!jit_i64_indexed_for_large_tensors && "expect to have reported an error earlier"); TF_RETURN_IF_ERROR(LowerHloToJITInvocation( - module.get(), tile_sizes, unroll_factors, max_supported_rank, - enable_ftz, index_64bit, + module.get(), tile_sizes, unroll_factors, enable_ftz, index_64bit, /*jit_i64_indexed_for_large_tensors=*/false, apply_cl_options)); } else { - TF_RETURN_IF_ERROR( - LowerHlotoLoops(module.get(), tile_sizes, unroll_factors, - max_supported_rank, enable_ftz, index_64bit, - jit_i64_indexed_for_large_tensors, apply_cl_options)); + TF_RETURN_IF_ERROR(LowerHlotoLoops( + module.get(), tile_sizes, unroll_factors, enable_ftz, index_64bit, + jit_i64_indexed_for_large_tensors, apply_cl_options)); TF_RETURN_IF_ERROR( LowerLoopsToGPU(module.get(), index_64bit, apply_cl_options)); TF_RETURN_IF_ERROR( diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h index ac8666874224aa..f92ff42405db38 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "tensorflow/core/platform/statusor.h" namespace tensorflow { @@ -44,9 +45,9 @@ StatusOr> GenerateKernelForHloCode( mlir::MLIRContext& context, llvm::StringRef tf_code, llvm::ArrayRef architectures, llvm::ArrayRef tile_sizes, llvm::ArrayRef unroll_factors, - int64_t max_supported_rank, bool print_ptx, bool print_llvmir, - bool enable_ftz, bool index_64bit, bool jit_compile, - bool jit_i64_indexed_for_large_tensors, bool apply_cl_options); + bool print_ptx, bool print_llvmir, bool enable_ftz, bool index_64bit, + bool jit_compile, bool jit_i64_indexed_for_large_tensors, + bool apply_cl_options); } // namespace kernel_gen } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/embed_tf_framework.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/embed_tf_framework.mlir index d6a7c2698f15ea..e6404a84319c0d 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/embed_tf_framework.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/embed_tf_framework.mlir @@ -80,7 +80,7 @@ func.func @jit_compile_from_str(%ctx : !tf_framework.op_kernel_context) // CHECK: return %[[RES]] %0 = tf_framework.jit_compile_from_str "placeholder" { architectures = ["sm_123", "sm_456"], tileSizes = [1, 2, 3], - unrollFactors = [4], maxSupportedRank = 3 : i64, enableFtz = false, + unrollFactors = [4], enableFtz = false, index64Bit = false, cpuCodegen = false } func.return %0 : !tf_framework.jit_callable } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/func_to_jit_invocations.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/func_to_jit_invocations.mlir index 7da7a482f63774..e89d5e196bde94 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/func_to_jit_invocations.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/func_to_jit_invocations.mlir @@ -1,25 +1,25 @@ // RUN: kernel-gen-opt %s --split-input-file \ // RUN: --func-to-jit-invocation="tile-sizes=1,2,3 unroll-factors=3,2,1 \ -// RUN: max-supported-rank=32 enable-ftz=false cpu-codegen=false" | \ +// RUN: enable-ftz=false cpu-codegen=false" | \ // RUN: FileCheck %s // RUN: kernel-gen-opt %s --split-input-file \ // RUN: --func-to-jit-invocation="tile-sizes=1,2,3 unroll-factors=3,2,1 \ -// RUN: max-supported-rank=32 enable-ftz=false cpu-codegen=false \ +// RUN: enable-ftz=false cpu-codegen=false \ // RUN: jit_i64_indexed_for_large_tensors=true" | \ // RUN: FileCheck %s --check-prefix=CHECK-JFLT -func.func @unary_tanh(%arg : tensor<*xf32>) -> tensor<*xf32> { - %0 = mhlo.tanh %arg : tensor<*xf32> - func.return %0 : tensor<*xf32> +func.func @unary_tanh(%arg : tensor) -> tensor { + %0 = mhlo.tanh %arg : tensor + func.return %0 : tensor } // CHECK-LABEL: @unary_tanh -// CHECK-SAME: %[[ARG:.*]]: tensor<*xf32> +// CHECK-SAME: %[[ARG:.*]]: tensor // CHECK: %[[CALLABLE:.*]] = tf_framework.jit_compile_from_str // CHECK-SAME: " // CHECK-SAME: module { -// CHECK-SAME: func @main(%[[ARG_JIT:.*]]: tensor<*xf32>) -> tensor<*xf32> +// CHECK-SAME: func @main(%[[ARG_JIT:.*]]: tensor) -> tensor // CHECK-SAME: attributes {tf_entry} // CHECK-SAME: { // CHECK-SAME: %[[RES_JIT:.*]] = mhlo.tanh %[[ARG_JIT]] @@ -30,7 +30,6 @@ func.func @unary_tanh(%arg : tensor<*xf32>) -> tensor<*xf32> { // CHECK-SAME: { // CHECK-SAME: cpuCodegen = false // CHECK-SAME: enableFtz = false -// CHECK-SAME: maxSupportedRank = 32 : i64 // CHECK-SAME: tileSizes = [1, 2, 3] // CHECK-SAME: unrollFactors = [3, 2, 1] // CHECK-SAME: } @@ -38,7 +37,7 @@ func.func @unary_tanh(%arg : tensor<*xf32>) -> tensor<*xf32> { // CHECK: return %[[RES]] // CHECK-JFLT-LABEL: @unary_tanh -// CHECK-JFLT-SAME: %[[ARG0:.*]]: tensor<*xf32> +// CHECK-JFLT-SAME: %[[ARG0:.*]]: tensor // CHECK-JFLT: %[[SHAPE:.*]] = shape.shape_of %[[ARG0]] // CHECK-JFLT: %[[NUM:.*]] = shape.num_elements %[[SHAPE]] // CHECK-JFLT: %[[LIMIT:.*]] = arith.constant 2147483647 @@ -49,7 +48,6 @@ func.func @unary_tanh(%arg : tensor<*xf32>) -> tensor<*xf32> { // CHECK-JFLT-SAME: cpuCodegen = false // CHECK-JFLT-SAME: enableFtz = false // CHECK-JFLT-SAME: index64Bit = true -// CHECK-JFLT-SAME: maxSupportedRank = 32 // CHECK-JFLT-SAME: tileSizes = [1, 2, 3] // CHECK-JFLT-SAME: unrollFactors = [3, 2, 1] // CHECK-JFLT: %[[JIT_0:.*]] = tf_framework.jit_execute %[[JIT]](%[[ARG0]]) @@ -82,7 +80,6 @@ func.func @binary_sub(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*x // CHECK-SAME: { // CHECK-SAME: cpuCodegen = false // CHECK-SAME: enableFtz = false -// CHECK-SAME: maxSupportedRank = 32 : i64 // CHECK-SAME: tileSizes = [1, 2, 3] // CHECK-SAME: unrollFactors = [3, 2, 1] // CHECK-SAME: } @@ -114,7 +111,6 @@ func.func @binary_sub(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*x // CHECK-JFLT-SAME: { // CHECK-JFLT-SAME: cpuCodegen = false // CHECK-JFLT-SAME: enableFtz = false -// CHECK-JFLT-SAME: maxSupportedRank = 32 : i64 // CHECK-JFLT-SAME: tileSizes = [1, 2, 3] // CHECK-JFLT-SAME: unrollFactors = [3, 2, 1] // CHECK-JFLT-SAME: } @@ -149,7 +145,6 @@ func.func @reciprocal(%arg0: tensor<*xf32>) // CHECK-SAME: cpuCodegen = false, // CHECK-SAME: enableFtz = false, // CHECK-SAME: index64Bit = false, -// CHECK-SAME: maxSupportedRank = 32 : i64, // CHECK-SAME: tileSizes = [1, 2, 3], // CHECK-SAME: unrollFactors = [3, 2, 1] // CHECK-SAME: } @@ -168,7 +163,6 @@ func.func @reciprocal(%arg0: tensor<*xf32>) // CHECK-JFLT-SAME: cpuCodegen = false // CHECK-JFLT-SAME: enableFtz = false // CHECK-JFLT-SAME: index64Bit = true -// CHECK-JFLT-SAME: maxSupportedRank = 32 // CHECK-JFLT-SAME: tileSizes = [1, 2, 3] // CHECK-JFLT-SAME: unrollFactors = [3, 2, 1] // CHECK-JFLT: %[[JIT_0:.*]] = tf_framework.jit_execute %[[JIT]](%[[ARG0]]) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/add_v2.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/add_v2.mlir index a1b35ccecc993e..686b34e0d138db 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/add_v2.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/add_v2.mlir @@ -1,7 +1,131 @@ // RUN: hlo_to_kernel --input=%s --output=%t --unroll_factors=4 --tile_sizes=256 --arch=sm_70 -func.func @AddV2(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) - -> tensor<*xf32> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - return %0 : tensor<*xf32> +func.func @AddV2(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xf32> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xf32> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xf32>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xf32>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xf32>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_add %15, %16 : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xf32> + scf.yield %cast : tensor<*xf32> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xf32>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xf32>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xf32>) -> tensor + %20 = chlo.broadcast_add %18, %19 : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xf32> + scf.yield %cast : tensor<*xf32> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xf32>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xf32>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xf32>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_add %21, %22 : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xf32> + scf.yield %cast : tensor<*xf32> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xf32>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xf32>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xf32>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_add %27, %29 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xf32> + scf.yield %cast_1 : tensor<*xf32> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xf32>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xf32>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xf32>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_add %29, %31 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xf32> + scf.yield %cast_1 : tensor<*xf32> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xf32>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xf32>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xf32>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_add %31, %33 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xf32> + scf.yield %cast_1 : tensor<*xf32> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xf32>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xf32>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xf32>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_add %33, %35 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xf32> + scf.yield %cast_1 : tensor<*xf32> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xf32>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xf32>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_add %34, %36 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xf32> + scf.yield %cast_1 : tensor<*xf32> + } + scf.yield %31 : tensor<*xf32> + } + scf.yield %29 : tensor<*xf32> + } + scf.yield %27 : tensor<*xf32> + } + scf.yield %25 : tensor<*xf32> + } + scf.yield %18 : tensor<*xf32> + } + scf.yield %16 : tensor<*xf32> + } + %10 = shape.shape_of %arg0 : tensor<*xf32> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xf32> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xf32>, tensor) -> tensor<*xf32> + return %13 : tensor<*xf32> } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/add_v2_unsigned.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/add_v2_unsigned.mlir index ee01fe543eaac4..f38a2dca1bc8cd 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/add_v2_unsigned.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/add_v2_unsigned.mlir @@ -1,7 +1,131 @@ // RUN: hlo_to_kernel --input=%s --output=%t --unroll_factors=4 --tile_sizes=256 --arch=sm_70 -func.func @AddV2(%arg0: tensor<*xui32>, %arg1: tensor<*xui32>) - -> tensor<*xui32> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xui32>, tensor<*xui32>) -> tensor<*xui32> - return %0 : tensor<*xui32> +func.func @AddV2(%arg0: tensor<*xui32>, %arg1: tensor<*xui32>) -> tensor<*xui32> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xui32> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xui32> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xui32>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xui32>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xui32>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_add %15, %16 : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xui32> + scf.yield %cast : tensor<*xui32> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xui32>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xui32>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xui32>) -> tensor + %20 = chlo.broadcast_add %18, %19 : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xui32> + scf.yield %cast : tensor<*xui32> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xui32>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xui32>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xui32>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_add %21, %22 : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xui32> + scf.yield %cast : tensor<*xui32> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xui32>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xui32>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xui32>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_add %27, %29 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xui32> + scf.yield %cast_1 : tensor<*xui32> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xui32>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xui32>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xui32>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_add %29, %31 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xui32> + scf.yield %cast_1 : tensor<*xui32> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xui32>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xui32>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xui32>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_add %31, %33 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xui32> + scf.yield %cast_1 : tensor<*xui32> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xui32>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xui32>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xui32>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_add %33, %35 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xui32> + scf.yield %cast_1 : tensor<*xui32> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xui32>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xui32>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_add %34, %36 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xui32> + scf.yield %cast_1 : tensor<*xui32> + } + scf.yield %31 : tensor<*xui32> + } + scf.yield %29 : tensor<*xui32> + } + scf.yield %27 : tensor<*xui32> + } + scf.yield %25 : tensor<*xui32> + } + scf.yield %18 : tensor<*xui32> + } + scf.yield %16 : tensor<*xui32> + } + %10 = shape.shape_of %arg0 : tensor<*xui32> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xui32> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xui32>, tensor) -> tensor<*xui32> + return %13 : tensor<*xui32> } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/minimum.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/minimum.mlir new file mode 100644 index 00000000000000..1facc06ee500e9 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/minimum.mlir @@ -0,0 +1,131 @@ +// RUN: hlo_to_kernel --input=%s --output=%t --unroll_factors=4 --tile_sizes=256 --arch=sm_70 + +func.func @Minimum_GPU_DT_UINT32_DT_UINT32(%arg0: tensor<*xui32>, %arg1: tensor<*xui32>) -> tensor<*xui32> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xui32> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xui32> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xui32>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xui32>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xui32>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_minimum %15, %16 : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xui32> + scf.yield %cast : tensor<*xui32> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xui32>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xui32>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xui32>) -> tensor + %20 = chlo.broadcast_minimum %18, %19 : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xui32> + scf.yield %cast : tensor<*xui32> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xui32>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xui32>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xui32>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_minimum %21, %22 : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xui32> + scf.yield %cast : tensor<*xui32> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xui32>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xui32>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xui32>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_minimum %27, %29 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xui32> + scf.yield %cast_1 : tensor<*xui32> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xui32>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xui32>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xui32>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_minimum %29, %31 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xui32> + scf.yield %cast_1 : tensor<*xui32> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xui32>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xui32>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xui32>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_minimum %31, %33 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xui32> + scf.yield %cast_1 : tensor<*xui32> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xui32>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xui32>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xui32>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_minimum %33, %35 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xui32> + scf.yield %cast_1 : tensor<*xui32> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xui32>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xui32>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_minimum %34, %36 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xui32> + scf.yield %cast_1 : tensor<*xui32> + } + scf.yield %31 : tensor<*xui32> + } + scf.yield %29 : tensor<*xui32> + } + scf.yield %27 : tensor<*xui32> + } + scf.yield %25 : tensor<*xui32> + } + scf.yield %18 : tensor<*xui32> + } + scf.yield %16 : tensor<*xui32> + } + %10 = shape.shape_of %arg0 : tensor<*xui32> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xui32> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xui32>, tensor) -> tensor<*xui32> + return %13 : tensor<*xui32> +} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/tanh.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/tanh.mlir index 8685ab17faee7a..2d3c8e6f5b9ef7 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/tanh.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/tanh.mlir @@ -1,6 +1,11 @@ // RUN: hlo_to_kernel --input=%s --output=%t --unroll_factors=4 --tile_sizes=256 --arch=sm_70 -func.func @tanh(%arg: tensor<*xf32>) -> tensor<*xf32> attributes {tf_entry} { - %0 = mhlo.tanh %arg : tensor<*xf32> - return %0 : tensor<*xf32> +func.func @tanh(%arg0: tensor<*xf32>) -> tensor<*xf32> attributes {tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xf32> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xf32>, tensor<1xindex>) -> tensor + %3 = mhlo.tanh %2 : tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xf32> + return %4 : tensor<*xf32> } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/isinf.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/isinf.mlir index 22c9572dfd810c..da3ca471a857e4 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/isinf.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/isinf.mlir @@ -1,6 +1,6 @@ // RUN: tf-opt %s --test-tf-lower-tf --xla-legalize-tf | \ -// RUN: mlir-hlo-opt --mhlo-rank-specialization-cluster \ -// RUN: --mhlo-rank-specialization-to-scf --hlo-legalize-to-linalg \ +// RUN: mlir-hlo-opt \ +// RUN: --hlo-legalize-to-linalg \ // RUN: --empty-tensor-to-alloc-tensor \ // RUN: --computeop-and-func-bufferize --canonicalize | \ // RUN: kernel-gen-opt -allow-unregistered-dialect \ diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir index ccd7a901fd0513..b5055e9ba4c8ab 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir @@ -82,7 +82,7 @@ func.func @jit_compile(%ctx : !tf_framework.op_kernel_context) func.func @jit_compile_from_str_wo_ctx() -> !tf_framework.jit_callable { %callable = tf_framework.jit_compile_from_str "placeholder" { architectures = ["sm_123", "sm_456"], tileSizes = [1, 2, 3], - unrollFactors = [4], maxSupportedRank = 3 : i64, enableFtz = false, + unrollFactors = [4], enableFtz = false, index64Bit = false, cpuCodegen = false } func.return %callable : !tf_framework.jit_callable } @@ -92,7 +92,7 @@ func.func @jit_compile_from_str(%ctx : !tf_framework.op_kernel_context) -> !tf_framework.jit_callable { %callable = tf_framework.jit_compile_from_str %ctx , "placeholder" { architectures = ["sm_123", "sm_456"], tileSizes = [1, 2, 3], - unrollFactors = [4], maxSupportedRank = 3 : i64, enableFtz = false, + unrollFactors = [4], enableFtz = false, index64Bit = false, cpuCodegen = false } func.return %callable : !tf_framework.jit_callable } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tanh.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tanh.mlir deleted file mode 100644 index b8a227688c9dcb..00000000000000 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tanh.mlir +++ /dev/null @@ -1,26 +0,0 @@ -// RUN: tf-opt %s --test-tf-lower-tf --xla-legalize-tf | \ -// RUN: mlir-hlo-opt --mhlo-rank-specialization-cluster \ -// RUN: --mhlo-rank-specialization-to-scf --hlo-legalize-to-linalg \ -// RUN: --empty-tensor-to-alloc-tensor \ -// RUN: --computeop-and-func-bufferize --canonicalize | \ -// RUN: kernel-gen-opt -allow-unregistered-dialect \ -// RUN: --shape-to-descriptors \ -// RUN: --canonicalize --kernelgen-final-bufferize | \ -// RUN: FileCheck %s - -// Test whether all shape computations required for tanh can be lowered to -// the standard dialect, scf and descriptors. We check for a sparse pattern here, -// as each lowering pattern is already tested and we just care for the -// integration. -// TODO: Expand this pattern once things have stabilized. -// CHECK-LABEL: @tanh -func.func @tanh(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: alloc - // CHECK: scf.for - // CHECK: memref.reshape - // CHECK: alloc - // CHECK: linalg.generic - // CHECK: memref.reshape - %0 = "tf.Tanh"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf-legalize-to-lmhlo.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf-legalize-to-lmhlo.mlir deleted file mode 100644 index c30b21d8cd3162..00000000000000 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf-legalize-to-lmhlo.mlir +++ /dev/null @@ -1,38 +0,0 @@ -// RUN: tf-opt %s --xla-legalize-tf='legalize-chlo=false' | \ -// RUN: mlir-hlo-opt --mhlo-rank-specialization-cluster \ -// RUN: --mhlo-rank-specialization-to-scf --chlo-legalize-to-hlo \ -// RUN: --hlo-legalize-to-linalg --empty-tensor-to-alloc-tensor \ -// RUN: --computeop-and-func-bufferize | \ -// RUN: kernel-gen-opt --shape-to-descriptors \ -// RUN: --canonicalize --kernelgen-final-bufferize - -func.func @acos(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "tf.Acos"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -func.func @tan(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "tf.Tan"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -func.func @tanh(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "tf.Tanh"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -func.func @sin(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "tf.Sin"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -func.func @sinh(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "tf.Sinh"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -func.func @erf(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "tf.Erf"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir index ed65470eb72728..4da5051a7bd2ac 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir @@ -151,7 +151,7 @@ func.func @is_valid_memref(%buf: memref) -> i1 { // ----- -// CHECK-LABEL: llvm.func @_mlir_ciface_tf_jit_compile(!llvm.ptr, !llvm.ptr, i64, !llvm.ptr, i64, !llvm.ptr, i64, i1, i1, i1) -> !llvm.ptr +// CHECK-LABEL: llvm.func @_mlir_ciface_tf_jit_compile(!llvm.ptr, !llvm.ptr, i64, !llvm.ptr, i64, !llvm.ptr, i1, i1, i1) -> !llvm.ptr // CHECK: llvm.mlir.global internal constant @[[CODE:jit_module_code_[0-9]+]]("placeholder\00") // CHECK: @jit_compile_from_str(%[[CTX:.*]]: !llvm.ptr) @@ -184,17 +184,16 @@ func.func @jit_compile_from_str(%ctx: !tf_framework.op_kernel_context) // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i64) // CHECK: llvm.store %[[C4]], %[[PTR]] - // CHECK-DAG: %[[MAX_RANK:.*]] = llvm.mlir.constant(3 : i64) // CHECK-DAG: %[[ENABLE_FTZ:.*]] = llvm.mlir.constant(false) // CHECK-DAG: %[[CPU_CODEGEN:.*]] = llvm.mlir.constant(false) // CHECK: %[[RES:.*]] = llvm.call @_mlir_ciface_tf_jit_compile // CHECK-SAME: %[[CTX]], %[[CODE_PTR]], // CHECK-SAME: %[[NUM_TILE_SIZES]], %[[TILE_SIZES]], // CHECK-SAME: %[[NUM_UNROLL_FACTORS]], %[[UNROLL_FACTORS]], - // CHECK-SAME: %[[MAX_RANK]], %[[ENABLE_FTZ]], %[[CPU_CODEGEN]] + // CHECK-SAME: %[[ENABLE_FTZ]], %[[CPU_CODEGEN]] // CHECK: llvm.return %[[RES]] %0 = tf_framework.jit_compile_from_str %ctx, "placeholder" { - tileSizes = [1, 2, 3], unrollFactors = [4], maxSupportedRank = 3 : i64, + tileSizes = [1, 2, 3], unrollFactors = [4], enableFtz = false, index64Bit = false, cpuCodegen = false } func.return %0 : !tf_framework.jit_callable } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc index 34cbb4069ffeac..b2f717046ce1d1 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc @@ -19,22 +19,44 @@ limitations under the License. #include #include +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/Hashing.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/Mangling.h" +#include "llvm/ExecutionEngine/Orc/Shared/ExecutorAddress.h" +#include "llvm/Support/Error.h" #include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h" // from @llvm-project +#include "mlir/ExecutionEngine/CRunnerUtils.h" // from @llvm-project #include "mlir/ExecutionEngine/ExecutionEngine.h" // from @llvm-project #include "mlir/ExecutionEngine/OptUtils.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tools/kernel_gen/compile_cache_item.pb.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.h" +#include "xla/stream_executor/device_description.h" #include "xla/stream_executor/stream.h" #include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/statusor.h" +#include "tsl/framework/allocator.h" #if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM) #include @@ -162,8 +184,8 @@ void InitializeLlvmCompiler() { llvm::Expected> Compile( const std::string code, llvm::SmallVectorImpl& architectures, llvm::SmallVectorImpl& tile_sizes, - llvm::SmallVectorImpl& unroll_factors, int64_t max_supported_rank, - bool enable_ftz, bool index_64bit) { + llvm::SmallVectorImpl& unroll_factors, bool enable_ftz, + bool index_64bit) { std::string cache_dir; if (const char* dir = getenv(kTFJitCacheDirEnvVar.data())) { cache_dir = dir; @@ -197,7 +219,6 @@ llvm::Expected> Compile( tensorflow::StatusOr> status_or_module = tensorflow::kernel_gen::GenerateKernelForHloCode( context, code, architectures, tile_sizes, unroll_factors, - max_supported_rank, /*print_ptx=*/false, /*print_llvmir=*/false, enable_ftz, index_64bit, /*jit_compile=*/false, @@ -261,8 +282,7 @@ llvm::SmallVector SmallVectorFromCArray(int64_t num_elements, extern "C" void* _mlir_ciface_tf_jit_compile( void* op_kernel_ctx, char* code, int64_t num_tile_sizes, int64_t* tile_sizes_ptr, int64_t num_unroll_factors, - int64_t* unroll_factors_ptr, int64_t max_supported_rank, bool enable_ftz, - bool index_64bit) { + int64_t* unroll_factors_ptr, bool enable_ftz, bool index_64bit) { // Get the resource manager. auto* ctx = static_cast(op_kernel_ctx); tensorflow::ResourceMgr* rm = ctx->resource_manager(); @@ -303,8 +323,8 @@ extern "C" void* _mlir_ciface_tf_jit_compile( // Lookup or compile the execution module. ExecutionEngine* engine = jit_cache->LookupOrCompile(code, [&]() { - return Compile(code, architectures, tile_sizes, unroll_factors, - max_supported_rank, enable_ftz, index_64bit); + return Compile(code, architectures, tile_sizes, unroll_factors, enable_ftz, + index_64bit); }); if (engine == nullptr) { ReportError(op_kernel_ctx, ErrorCode::UNKNOWN, "JIT compilation failed."); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h index 3a7c879d2c64c9..a62dc2c7020ab7 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h @@ -36,8 +36,7 @@ extern "C" MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_tf_report_error( extern "C" MLIR_RUNNERUTILS_EXPORT void* _mlir_ciface_tf_jit_compile( void* op_kernel_ctx, char* code, int64_t num_tile_sizes, int64_t* tile_sizes_ptr, int64_t num_unroll_factors, - int64_t* unroll_factors_ptr, int64_t max_supported_rank, bool enable_ftz, - bool index_64bit); + int64_t* unroll_factors_ptr, bool enable_ftz, bool index_64bit); extern "C" MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_tf_jit_execute( void* op_kernel_ctx, void* callable, void* result, int64_t num_args, diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_gpu_runtime_wrappers.h b/tensorflow/compiler/mlir/tools/kernel_gen/tf_gpu_runtime_wrappers.h index 79e2c0cc530552..be1325c5a4dc16 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_gpu_runtime_wrappers.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_gpu_runtime_wrappers.h @@ -18,9 +18,12 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "mlir/ExecutionEngine/RunnerUtils.h" // from @llvm-project +#include "tensorflow/core/framework/resource_base.h" #include "tensorflow/core/framework/resource_op_kernel.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" #include "tsl/platform/hash.h" +#include "tsl/platform/thread_annotations.h" #if GOOGLE_CUDA #include "third_party/gpus/cuda/include/cuda.h" diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.cc index 4d69d754a0a8b8..b2e3cc19b581c5 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.cc @@ -19,6 +19,11 @@ limitations under the License. #include #include +#include "llvm/Support/Error.h" +#include "mlir/ExecutionEngine/ExecutionEngine.h" // from @llvm-project +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" + namespace mlir { namespace kernel_gen { namespace tf_framework { diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.h b/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.h index 414e7954be271e..f03e5778f6c8c5 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.h @@ -21,7 +21,11 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" // from @llvm-project +#include "tensorflow/core/framework/resource_base.h" #include "tensorflow/core/framework/resource_op_kernel.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" +#include "tsl/platform/thread_annotations.h" namespace mlir { namespace kernel_gen { diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/copy_cleanup_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/copy_cleanup_pass.cc index af0943ded1f8d8..32faed506e52b4 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/copy_cleanup_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/copy_cleanup_pass.cc @@ -60,7 +60,7 @@ void RemoveCopyIfTargetOnlyRead(func::FuncOp func) { } continue; } - if (auto effect_interface = cast(user)) { + if (auto effect_interface = dyn_cast(user)) { if (reader) { at_most_one_read = false; } else { diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/func_to_jit_invocations.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/func_to_jit_invocations.cc index d87271e161529b..60f8876f109553 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/func_to_jit_invocations.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/func_to_jit_invocations.cc @@ -176,8 +176,7 @@ LogicalResult RewriteToLargeSizeJit(FuncOp op) { void PackJITCompileOp(tf_framework::JITCompileOp op, llvm::ArrayRef tile_sizes, - llvm::ArrayRef unroll_factors, - int64_t max_supported_rank, bool enable_ftz, + llvm::ArrayRef unroll_factors, bool enable_ftz, bool index_64bit, bool cpu_codegen) { IRRewriter rewriter(op.getContext()); Block *body = op.SingleBlock::getBody(); @@ -219,7 +218,6 @@ void PackJITCompileOp(tf_framework::JITCompileOp op, op, op->getResultTypes(), op.getCtx(), rewriter.getStringAttr(code), rewriter.getI64ArrayAttr(tile_sizes), rewriter.getI64ArrayAttr(unroll_factors), - rewriter.getI64IntegerAttr(max_supported_rank), rewriter.getBoolAttr(enable_ftz), rewriter.getBoolAttr(index_64bit), rewriter.getBoolAttr(cpu_codegen)); } @@ -231,12 +229,11 @@ struct FuncToJITInvocationPass : public impl::FuncToJITInvocationPassBase { explicit FuncToJITInvocationPass(llvm::ArrayRef tile_sizes, llvm::ArrayRef unroll_factors, - int64_t max_supported_rank, bool enable_ftz, - bool index_64bit, bool cpu_codegen, + bool enable_ftz, bool index_64bit, + bool cpu_codegen, bool jit_i64_indexed_for_large_tensors) { tile_sizes_ = tile_sizes; unroll_factors_ = unroll_factors; - max_supported_rank_ = max_supported_rank; enable_ftz_ = enable_ftz; index_64bit_ = index_64bit; cpu_codegen_ = cpu_codegen; @@ -255,9 +252,9 @@ struct FuncToJITInvocationPass } getOperation().walk([&](tf_framework::JITCompileOp op) { - PackJITCompileOp( - op, tile_sizes_, unroll_factors_, max_supported_rank_, enable_ftz_, - index_64bit_ || jit_i64_indexed_for_large_tensors_, cpu_codegen_); + PackJITCompileOp(op, tile_sizes_, unroll_factors_, enable_ftz_, + index_64bit_ || jit_i64_indexed_for_large_tensors_, + cpu_codegen_); }); } }; @@ -266,11 +263,11 @@ struct FuncToJITInvocationPass std::unique_ptr> CreateFuncToJITInvocationPass( llvm::ArrayRef tile_sizes, llvm::ArrayRef unroll_factors, - int64_t max_supported_rank, bool enable_ftz, bool index_64bit, - bool cpu_codegen, bool jit_i64_indexed_for_large_tensors) { + bool enable_ftz, bool index_64bit, bool cpu_codegen, + bool jit_i64_indexed_for_large_tensors) { return std::make_unique( - tile_sizes, unroll_factors, max_supported_rank, enable_ftz, index_64bit, - cpu_codegen, jit_i64_indexed_for_large_tensors); + tile_sizes, unroll_factors, enable_ftz, index_64bit, cpu_codegen, + jit_i64_indexed_for_large_tensors); } } // namespace transforms diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h index 355030d6009f18..45e248ceb904ff 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h @@ -66,8 +66,8 @@ std::unique_ptr> CreateBufferReusePass(); // framework. std::unique_ptr> CreateFuncToJITInvocationPass( llvm::ArrayRef tile_sizes = {}, - llvm::ArrayRef unroll_factors = {}, int64_t max_supported_rank = 5, - bool enable_ftz = false, bool index_64bit = false, bool cpu_codegen = false, + llvm::ArrayRef unroll_factors = {}, bool enable_ftz = false, + bool index_64bit = false, bool cpu_codegen = false, bool jit_i64_indexed_for_large_tensors = false); // Pass for applying LLVM legalization patterns. diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td index 0aa80747ef5ebc..4f92be70d25397 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td @@ -52,8 +52,6 @@ def FuncToJITInvocationPass : Pass<"func-to-jit-invocation", "mlir::func::FuncOp "llvm::cl::ZeroOrMore">, ListOption<"unroll_factors_", "unroll-factors", "int64_t", "Unrolling in each tile dimension", "llvm::cl::ZeroOrMore">, - Option<"max_supported_rank_", "max-supported-rank", "int64_t", - /*default=*/"", "Max rank that this kernel supports">, Option<"enable_ftz_", "enable-ftz", "bool", /*default=*/"", "Enable the denormal flush to zero mode when generating code">, Option<"index_64bit_", "index_64bit", "bool", /*default=*/"", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc index fe649e1edeb723..cffa5e7b44691e 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc @@ -271,8 +271,6 @@ class JITCompileFromStrOpConverter ConvertIntegerArrayAttrToStackAllocatedArray( loc, rewriter.getI64Type(), rewriter.getI64Type(), op.getUnrollFactors(), &rewriter); - Value max_supported_rank = rewriter.create( - loc, rewriter.getI64Type(), op.getMaxSupportedRankAttr()); Value enable_ftz = rewriter.create( loc, rewriter.getI1Type(), op.getEnableFtzAttr()); Value index_64bit = rewriter.create( @@ -285,8 +283,8 @@ class JITCompileFromStrOpConverter op, getVoidPtrType(), tf_func_ref, llvm::ArrayRef({adaptor.getCtx(), jit_module_code, tile_sizes.first, tile_sizes.second, unroll_factors.first, - unroll_factors.second, max_supported_rank, enable_ftz, - index_64bit, cpu_codegen})); + unroll_factors.second, enable_ftz, index_64bit, + cpu_codegen})); return success(); } @@ -304,7 +302,6 @@ class JITCompileFromStrOpConverter /*int64_t* tile_sizes_ptr*/ ptr_ty, /*int64_t num_unroll_factors*/ i64_ty, /*int64_t* unroll_factors_ptr*/ ptr_ty, - /*int64_t max_supported_rank*/ i64_ty, /*bool enable_ftz*/ i1_ty, /*bool index_64bit*/ i1_ty, /*bool cpu_codegen*/ i1_ty}); diff --git a/tensorflow/compiler/tests/xla_call_module_test.py b/tensorflow/compiler/tests/xla_call_module_test.py index 76a76a4ab82747..4a1c8edddc1e54 100644 --- a/tensorflow/compiler/tests/xla_call_module_test.py +++ b/tensorflow/compiler/tests/xla_call_module_test.py @@ -247,27 +247,6 @@ def f(x): # x: f32[2, b] self._assertOpOutputMatchesExpected(f, (x,), (np.sin(x), x.shape[1])) - def test_poly_unranked(self): - x = np.arange(6, dtype=np.float32).reshape((2, 3)) - - def f(x): # x: f32[2, b] - # sin(x) - module, version = serialize(""" -module @jit_f.0 attributes {jax.uses_shape_polymorphism = true} { - func.func public @main(%arg1: tensor<*xf32>) -> tensor<*xf32> { - %0 = stablehlo.sine %arg1 : tensor<*xf32> - return %0 : tensor<*xf32> - } -} -""") - return xla.call_module([x], - module=module, version=version, - Tout=[x.dtype], - Sout=[(None, None),], - platforms=[self.testing_platform()],) - - self._assertOpOutputMatchesExpected(f, (x,), (np.sin(x),)) - def test_wrong_actual_args_errors(self): x = np.arange(6, dtype=np.float32).reshape((3, 2)) y = np.arange(6, dtype=np.int32).reshape((2, 3)) diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index c565cef1489532..4f6706c11b3234 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -944,8 +944,8 @@ tf_cc_test( "@local_xla//xla/client:local_client", "@local_xla//xla/service:compiler", "@local_xla//xla/service:platform_util", - "@local_xla//xla/stream_executor:multi_platform_manager", "@local_xla//xla/stream_executor:platform", + "@local_xla//xla/stream_executor:platform_manager", ], ) @@ -1171,14 +1171,15 @@ cc_library( "//tensorflow/compiler/mlir:mlir_graph_optimization_pass", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:device_util", - "//tensorflow/compiler/mlir/tensorflow/transforms:bridge", "//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:lower_cluster_to_runtime_ops", "//tensorflow/compiler/mlir/tf2xla:mlir_bridge_rollout_policy", "//tensorflow/compiler/mlir/tf2xla/api/v1:cluster_tf", "//tensorflow/compiler/mlir/tf2xla/api/v1:tf_dialect_to_executor", "//tensorflow/compiler/mlir/tf2xla/api/v2:cluster_tf", "//tensorflow/compiler/mlir/tf2xla/api/v2:tf_dialect_to_executor", + "//tensorflow/compiler/mlir/tf2xla/internal:mlir_bridge_pass_util", "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/common_runtime:device_set", diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index b39d6f8f5d2bed..b8d91294ca7b18 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -42,7 +42,7 @@ Status GetFunctionBody(FunctionLibraryRuntime* flib_runtime, TF_RETURN_IF_ERROR(flib_runtime->Instantiate( name_attr_list.name(), AttrSlice(&name_attr_list.attr()), &func_handle)); *fbody = flib_runtime->GetFunctionBody(func_handle); - return OkStatus(); + return absl::OkStatus(); } Status GetFunctionBodies(FunctionLibraryRuntime* flib_runtime, @@ -57,7 +57,7 @@ Status GetFunctionBodies(FunctionLibraryRuntime* flib_runtime, &func_handle)); fbodies->push_back(flib_runtime->GetFunctionBody(func_handle)); } - return OkStatus(); + return absl::OkStatus(); } Status CondConstInputIndices( @@ -84,7 +84,7 @@ Status CondConstInputIndices( const_input_idxs->push_back(i + 1); } } - return OkStatus(); + return absl::OkStatus(); } Status GetCompileTimeConstInputs(const NodeDef& node, const OpKernel* op_kernel, @@ -133,7 +133,7 @@ Status GetCompileTimeConstInputs(const NodeDef& node, const OpKernel* op_kernel, } } } - return OkStatus(); + return absl::OkStatus(); } else if (node.op() == "If" || node.op() == "StatelessIf") { const FunctionBody* fthen = nullptr; const FunctionBody* felse = nullptr; @@ -162,7 +162,7 @@ Status GetCompileTimeConstInputs(const NodeDef& node, const OpKernel* op_kernel, const_input_idxs->push_back(i); } } - return OkStatus(); + return absl::OkStatus(); } else if (op_def != nullptr) { return XlaOpRegistry::CompileTimeConstantInputs(node, *op_def, const_input_idxs); @@ -193,7 +193,7 @@ Status BackwardsConstAnalysis( !edge_filter_input) { VLOG(5) << "Using cached argument indices on graph " << &g; *compile_time_const_arg_indices = g.GetConstArgIndicesCache().value(); - return OkStatus(); + return absl::OkStatus(); } auto edge_filter = [&](const Edge& e) { return edge_filter_input ? edge_filter_input(e) : true; diff --git a/tensorflow/compiler/tf2xla/frontend_attributes_util.cc b/tensorflow/compiler/tf2xla/frontend_attributes_util.cc index bd33022634c61c..b17f636eb94b0b 100644 --- a/tensorflow/compiler/tf2xla/frontend_attributes_util.cc +++ b/tensorflow/compiler/tf2xla/frontend_attributes_util.cc @@ -21,11 +21,11 @@ limitations under the License. namespace tensorflow { -StatusOr> +absl::StatusOr> GetFrontendAttributesFromAttrSlice(const AttrSlice& attrs) { const AttrValue* attr = attrs.Find(kXlaFrontendAttributesAttrName); if (attr == nullptr) { - return StatusOr>(std::nullopt); + return absl::StatusOr>(std::nullopt); } xla::FrontendAttributes attributes; if (!attributes.ParseFromString(attr->s())) { diff --git a/tensorflow/compiler/tf2xla/frontend_attributes_util.h b/tensorflow/compiler/tf2xla/frontend_attributes_util.h index 6113a9dd172615..2f8436fad402bd 100644 --- a/tensorflow/compiler/tf2xla/frontend_attributes_util.h +++ b/tensorflow/compiler/tf2xla/frontend_attributes_util.h @@ -28,7 +28,7 @@ namespace tensorflow { // // Return an InvalidArgument error if some attributes are present but // cannot be parsed. -StatusOr> +absl::StatusOr> GetFrontendAttributesFromAttrSlice(const AttrSlice& attrs); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index 1d3b6d3fe873b0..577cbd9126e62b 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -128,14 +128,14 @@ Status GetSwitchPredicate(const Node& switch_node, OutputTensor* pred) { TF_RETURN_IF_ERROR(pred_edge->src()->input_edge(0, &pred_edge)); } *pred = OutputTensor(pred_edge->src(), pred_edge->src_output()); - return OkStatus(); + return absl::OkStatus(); } Status GetSwitchValue(const Node& switch_node, OutputTensor* val) { const Edge* val_edge; TF_RETURN_IF_ERROR(switch_node.input_edge(0, &val_edge)); *val = OutputTensor(val_edge->src(), val_edge->src_output()); - return OkStatus(); + return absl::OkStatus(); } bool StateMap::OutputTensorLess::operator()(const OutputTensor& lhs, @@ -394,7 +394,7 @@ Conditional::Conditional(OutputTensor predicate, FunctionalizeCond* parent, Status Conditional::AddMerge(Node* m) { merges_.insert(m); - return OkStatus(); + return absl::OkStatus(); } Status Conditional::AddSwitch(Node* s) { @@ -410,7 +410,7 @@ Status Conditional::AddSwitch(Node* s) { } switches_.insert(s); parent_->AddSwitchId(s->id()); - return OkStatus(); + return absl::OkStatus(); } Status Conditional::BuildArgumentNodes() { @@ -492,7 +492,7 @@ Status Conditional::BuildArgumentNodes() { } } - return OkStatus(); + return absl::OkStatus(); } Status Conditional::AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch, @@ -741,7 +741,7 @@ Status Conditional::ExtractBodies(Graph* graph) { } } } - return OkStatus(); + return absl::OkStatus(); } Status Conditional::BuildIfNode(Graph* graph, @@ -834,7 +834,7 @@ Status Conditional::BuildIfNode(Graph* graph, TF_ASSIGN_OR_RETURN(if_node_, parent_->AddIfNode(if_def, *merges_.begin(), predicate_)); - return OkStatus(); + return absl::OkStatus(); } Status Conditional::AddInputEdges( @@ -871,7 +871,7 @@ Status Conditional::AddInputEdges( for (Node* n : external_control_inputs_) { graph->AddControlEdge(n, if_node_); } - return OkStatus(); + return absl::OkStatus(); } Status Conditional::AddOutputEdges( @@ -910,7 +910,7 @@ Status Conditional::AddOutputEdges( graph->AddControlEdge(if_node_, n); } - return OkStatus(); + return absl::OkStatus(); } Status Conditional::BuildAndReplace( @@ -918,7 +918,7 @@ Status Conditional::BuildAndReplace( std::unordered_map* merge_to_replacement) { VLOG(1) << "Build If and replace merge nodes " << NodesToString(this->merges_); - if (replaced_) return OkStatus(); + if (replaced_) return absl::OkStatus(); TF_RETURN_IF_ERROR(ExtractBodies(graph)); TF_RETURN_IF_ERROR(BuildArgumentNodes()); @@ -944,7 +944,7 @@ Status Conditional::BuildAndReplace( "Converting to If failed."); replaced_ = true; - return OkStatus(); + return absl::OkStatus(); } string Conditional::name() const { @@ -966,12 +966,11 @@ Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node, TF_RETURN_IF_ERROR(id_builder.Finalize(graph_, &id)); state_map_.ResetCondId(id, state_map_.LookupCondId(if_node)); state_map_.ResetAncestorId(id, state_map_.LookupAncestorId(if_node)); - return OkStatus(); + return absl::OkStatus(); } -StatusOr FunctionalizeCond::AddIfNode(const NodeDef& def, - const Node* replacee, - const OutputTensor& predicate) { +absl::StatusOr FunctionalizeCond::AddIfNode( + const NodeDef& def, const Node* replacee, const OutputTensor& predicate) { TF_ASSIGN_OR_RETURN(Node * ret, graph_->AddNode(def)); VLOG(1) << "Adding If for " << replacee->name(); StateMap::CondId id = state_map_.LookupCondId(replacee); @@ -1018,7 +1017,7 @@ Status FunctionalizeCond::PropagateUpdatedState(const Node* replacee) { changed.erase(n); } } - return OkStatus(); + return absl::OkStatus(); } // Returns the most restrictive branch of two branches or neither. This is the @@ -1040,7 +1039,7 @@ BranchType StateMap::FindBranchOf(CondId id, OutputTensor predicate) const { return it->second; } -StatusOr FunctionalizeCond::JoinCondStatesNonMerge( +absl::StatusOr FunctionalizeCond::JoinCondStatesNonMerge( StateMap::CondId src, StateMap::CondId dst) { VLOG(5) << "Joining src=" << DebugString(src) << " [" << src << "] and dst=" << DebugString(dst) << " [" << dst << "]"; @@ -1076,7 +1075,7 @@ StatusOr FunctionalizeCond::JoinCondStatesNonMerge( return state_map_.GetCondId(both); } -StatusOr FunctionalizeCond::JoinCondStatesMerge( +absl::StatusOr FunctionalizeCond::JoinCondStatesMerge( Node* merge, StateMap::CondId src, StateMap::CondId dst) { // Determine the flow state when joining two states for a merge // node. Combining the two states for a merge node is effectively performing a @@ -1160,7 +1159,7 @@ Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) { // Only Merge nodes with two inputs are supported, but if this is a redundant // merge, then the dead edge may already have been removed (if due to a // switch) and so the input count would be incorrect. - if (state_map_.IsDead(state_map_.LookupCondId(dst))) return OkStatus(); + if (state_map_.IsDead(state_map_.LookupCondId(dst))) return absl::OkStatus(); int data_inputs = 0; for (auto e : dst->in_edges()) { @@ -1183,7 +1182,7 @@ Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) { dst->name(), " only has ", data_inputs, " inputs, while only merge nodes with two inputs supported."); } - return OkStatus(); + return absl::OkStatus(); } Status FunctionalizeCond::DetermineCondStateNonMerge(Node* dst) { @@ -1201,13 +1200,14 @@ Status FunctionalizeCond::DetermineCondStateNonMerge(Node* dst) { FormatNodeForError(*dst)); state_map_.ResetCondId(dst, id_or.value()); } - return OkStatus(); + return absl::OkStatus(); } Status FunctionalizeCond::RemoveRedundantMerge(Node* node) { // Handle redundant merge nodes. A merge node is considered redundant if // one input edge is dead while the other has a value. - if (!state_map_.IsDead(state_map_.LookupCondId(node))) return OkStatus(); + if (!state_map_.IsDead(state_map_.LookupCondId(node))) + return absl::OkStatus(); const Edge* non_dead_edge = nullptr; for (auto e : node->in_edges()) { @@ -1239,7 +1239,7 @@ Status FunctionalizeCond::RemoveRedundantMerge(Node* node) { : non_dead_edge->src_output(), dst_node, dst_port); } - return OkStatus(); + return absl::OkStatus(); } Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) { @@ -1251,7 +1251,7 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) { // (rather than boolean equivalence) and aimed at redundant switches as // currently generated by gradient code. StateMap::CondId dst_id = state_map_.LookupCondId(node); - if (state_map_.IsDead(dst_id)) return OkStatus(); + if (state_map_.IsDead(dst_id)) return absl::OkStatus(); BranchType b; OutputTensor pred; @@ -1272,7 +1272,7 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) { } b = state_map_.FindBranchOf(dst_id, val); if (b != BranchType::kThenBranch && b != BranchType::kElseBranch) - return OkStatus(); + return absl::OkStatus(); } VLOG(5) << "Redundant switch " << node->name() << " " << Branch_Name(b) << " " @@ -1309,7 +1309,7 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) { switch_branch == Graph::kControlSlot ? Graph::kControlSlot : val_port, dst_node, dst_input); } - return OkStatus(); + return absl::OkStatus(); } Status FunctionalizeCond::DetermineStates(std::vector rev_topo_order) { @@ -1325,7 +1325,7 @@ Status FunctionalizeCond::DetermineStates(std::vector rev_topo_order) { << " @ " << state_map_.AncestorStateToString(dst); if (VLOG_IS_ON(10)) DumpGraphWithCondState("it"); } - return OkStatus(); + return absl::OkStatus(); } Status FunctionalizeCond::DetermineAncestorState(Node* dst) { @@ -1359,7 +1359,7 @@ Status FunctionalizeCond::DetermineAncestorState(Node* dst) { id = insert(id, src); } state_map_.ResetAncestorId(dst, id); - return OkStatus(); + return absl::OkStatus(); } void FunctionalizeCond::DeleteReachableAndDeadNodes( @@ -1504,7 +1504,7 @@ Status FunctionalizeCond::FunctionalizeInternal() { // No merges mean no switch values consumed (as only considering values // fetchable as output of merge); DeleteReachableAndDeadNodes(merge_order); - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR(DetermineStates(std::move(rev_topo_order))); @@ -1574,7 +1574,7 @@ Status FunctionalizeCond::FunctionalizeInternal() { DeleteReachableAndDeadNodes(merge_order); - return OkStatus(); + return absl::OkStatus(); } void FunctionalizeCond::DumpGraphWithCondState(const string& name) { diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.h b/tensorflow/compiler/tf2xla/functionalize_cond.h index 1841a495bac601..23b2acb56978d0 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.h +++ b/tensorflow/compiler/tf2xla/functionalize_cond.h @@ -193,8 +193,8 @@ class FunctionalizeCond { // Add a If node to the graph defined by def that will, amongst other, replace // replacee in the graph. - StatusOr AddIfNode(const NodeDef& def, const Node* replacee, - const OutputTensor& predicate); + absl::StatusOr AddIfNode(const NodeDef& def, const Node* replacee, + const OutputTensor& predicate); // Propagates the state of a newly inserted node. Status PropagateUpdatedState(const Node* replacee); @@ -238,11 +238,11 @@ class FunctionalizeCond { // Determines the dst node's CondState by joining the src and dst's CondState // where either the dst node is a merge or not. // These may modify state_map_. - StatusOr JoinCondStatesMerge(Node* merge, - StateMap::CondId src, - StateMap::CondId dst); - StatusOr JoinCondStatesNonMerge(StateMap::CondId src, - StateMap::CondId dst); + absl::StatusOr JoinCondStatesMerge(Node* merge, + StateMap::CondId src, + StateMap::CondId dst); + absl::StatusOr JoinCondStatesNonMerge(StateMap::CondId src, + StateMap::CondId dst); // Determines which switch/merge nodes are ancestors of this node. Status DetermineAncestorState(Node* dst); diff --git a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc index 7b37017d165658..d8015ce6835d09 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc @@ -52,13 +52,14 @@ class FunctionalizeCondTest : public ::testing::Test { return fc_->state_map_.CondStateToString(id); } - StatusOr JoinCondStatesNonMerge(StateMap::CondId src, - StateMap::CondId dst) { + absl::StatusOr JoinCondStatesNonMerge( + StateMap::CondId src, StateMap::CondId dst) { return fc_->JoinCondStatesNonMerge(src, dst); } - StatusOr JoinCondStatesMerge(Node* n, StateMap::CondId src, - StateMap::CondId dst) { + absl::StatusOr JoinCondStatesMerge(Node* n, + StateMap::CondId src, + StateMap::CondId dst) { return fc_->JoinCondStatesMerge(n, src, dst); } diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 1d13170daa43e6..2bad3b58d34761 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -117,7 +117,8 @@ Status AddFunctionDefToGraphLibrary( // `graph->flib_def().default_registry()` which is done in the following line // (we have to use `LookUp` instead of `Contains` or `Find` because the latter // both don't check the default registry). - if (graph->flib_def().LookUp(func_name, &op_reg_data).ok()) return OkStatus(); + if (graph->flib_def().LookUp(func_name, &op_reg_data).ok()) + return absl::OkStatus(); const FunctionDef* new_fdef = fld->Find(func_name); DCHECK(new_fdef != nullptr); @@ -197,7 +198,7 @@ Status FunctionalizeControlFlowForNodeAssociatedFunctions( } } } - return OkStatus(); + return absl::OkStatus(); } Status FunctionalizeControlFlowForFunction( @@ -210,7 +211,7 @@ Status FunctionalizeControlFlowForFunction( // Convert the function to a graph. FunctionLibraryRuntime::Handle handle; TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle)); - Status ret_status = OkStatus(); + Status ret_status = absl::OkStatus(); auto cleanup_handle = gtl::MakeCleanup([&]() { auto s = flr->ReleaseHandle(handle); if (!s.ok()) { @@ -304,7 +305,7 @@ Status FunctionalizeControlFlow(Graph* graph, VLOG(2) << "FunctionalizeControlFlow (final): " << DumpGraphToFile("functionalize_final", *graph, library); - return OkStatus(); + return absl::OkStatus(); } Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def, @@ -319,7 +320,7 @@ Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def, include_functions)); graph.ToGraphDef(graph_def); std::swap(*graph_def->mutable_library(), function_lib); - return OkStatus(); + return absl::OkStatus(); } Status FunctionalizeControlFlowForXlaPass::Run( @@ -388,7 +389,7 @@ Status FunctionalizeControlFlowForXlaPass::Run( DumpGraphToFile("functionalize_control_flow_after", *graph, options.flib_def); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index a116f905097989..25a08224c8b946 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -56,7 +56,7 @@ Status FindIfThenAndElse(const GraphDef& graph, string* op_name, *then_fn = *result; TF_RETURN_IF_ERROR(GetNodeAttr(node, "else_branch", &result)); *else_fn = *result; - return OkStatus(); + return absl::OkStatus(); } } return errors::NotFound("No If node found in graph"); @@ -317,7 +317,7 @@ Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond, *cond = *result; TF_RETURN_IF_ERROR(GetNodeAttr(node, "body", &result)); *body = *result; - return OkStatus(); + return absl::OkStatus(); } } return errors::NotFound("No While node found in graph"); diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc index 71e3f9f69e4445..c2bc42b5c24e14 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc @@ -30,7 +30,7 @@ bool NodeCmpByNameResourcesLast::operator()(const Node* lhs, std::tie(rhs_is_resource, rhs->name()); } -StatusOr BuildRetvalNode(Graph* graph, DataType type, int index) { +absl::StatusOr BuildRetvalNode(Graph* graph, DataType type, int index) { const char* const kRetValOp = "_Retval"; NodeDef ret_def; ret_def.set_op(kRetValOp); @@ -78,7 +78,7 @@ Status ExtractWhileLoopFrames( } } - return OkStatus(); + return absl::OkStatus(); } // Check that the graph has no cycle containing the given node. @@ -99,7 +99,7 @@ Status CheckNodeNotInCycle(const Node* node, const int num_nodes) { } } } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h index 888839455db8b2..5d7ce5618fe252 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h @@ -90,10 +90,10 @@ struct NodeCmpByNameResourcesLast { }; // Returns the Node* created from the NodeDef in the Graph. -StatusOr AddNodeDefToGraph(const NodeDef& node_def, Graph* graph); +absl::StatusOr AddNodeDefToGraph(const NodeDef& node_def, Graph* graph); // Build a retval node of given type and index. -StatusOr BuildRetvalNode(Graph* graph, DataType type, int index); +absl::StatusOr BuildRetvalNode(Graph* graph, DataType type, int index); // Returns a textual representation of the names of the nodes in the input. template diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc index 0294c018e512db..70f98b3e88daec 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.cc +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -103,10 +103,10 @@ Status CopySubgraph(const Graph& graph, const WhileLoopFrame* frame, output->AddEdge(src_copy, src_output, dst_copy, e->dst_input()); } } - return OkStatus(); + return absl::OkStatus(); } -StatusOr BuildArgNode(Graph* graph, DataType type, int index) { +absl::StatusOr BuildArgNode(Graph* graph, DataType type, int index) { const char* const kArgOp = "_Arg"; NodeDef arg_def; NodeDefBuilder builder(absl::StrCat(kArgOp, index), kArgOp); @@ -206,7 +206,7 @@ Status BuildLoopBody(const Graph& graph, WhileLoopFrame* frame, TF_RETURN_IF_ERROR(CopySubgraph(graph, frame, std::move(next_iterations), squash_src_outputs, &node_map, output)); - return OkStatus(); + return absl::OkStatus(); } Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame, @@ -216,7 +216,7 @@ Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame, VLOG(2) << "Skipping functionalization for frame " << frame->name << " because it has control flow nodes that are filtered out by " "the specified node filter."; - return OkStatus(); + return absl::OkStatus(); } VLOG(2) << "Frame " << frame->name << " before: " << DumpGraphToFile("functionalize_before", *graph, library); @@ -501,7 +501,7 @@ Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame, VLOG(2) << "Frame " << frame->name << " after: " << DumpGraphToFile("functionalize_after", *graph, library); - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -565,7 +565,7 @@ Status FunctionalizeWhileLoop(Graph* graph, FunctionLibraryDefinition* library, } } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc b/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc index 0179168be93bcd..70c09bc84ac275 100644 --- a/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc +++ b/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc @@ -60,7 +60,7 @@ Status GetTestDevice(Session* session, string* test_device) { *test_device = found_gpu ? "GPU" : "CPU"; VLOG(2) << "Using test device " << *test_device; - return OkStatus(); + return absl::OkStatus(); } void FillZeros(Tensor* tensor) { diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 4914efdba55a07..23eb33224dc24b 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -114,7 +114,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, return errors::InvalidArgument("Invalid function argument"); } } - return OkStatus(); + return absl::OkStatus(); } } // namespace Status GraphCompiler::Compile() { @@ -204,7 +204,7 @@ Status GraphCompiler::Compile() { } } } - return OkStatus(); + return absl::OkStatus(); } namespace { @@ -221,7 +221,7 @@ Status GetFunctionNameAndAttr(const FunctionLibraryRuntime& flib, " does not have 'func' field set"); } *func = attr_value->func(); - return OkStatus(); + return absl::OkStatus(); } if (flib.GetFunctionLibraryDefinition()->Find(node.def().op())) { @@ -230,7 +230,7 @@ Status GetFunctionNameAndAttr(const FunctionLibraryRuntime& flib, func->set_name(FunctionLibraryDefinition::kGradientOp); } *func->mutable_attr() = node.def().attr(); - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/compiler/tf2xla/graph_compiler_util.cc b/tensorflow/compiler/tf2xla/graph_compiler_util.cc index 68c576a52cba73..ac064805f1a470 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler_util.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler_util.cc @@ -106,7 +106,7 @@ Status AddArgNodes(Graph* graph, const NodeMap& node_map, graph->RemoveEdge(edge); } } - return OkStatus(); + return absl::OkStatus(); } // Each fetch id identifies the positional output of some node. For each fetch @@ -138,7 +138,7 @@ Status AddRetvalNodes(Graph* graph, const NodeMap& node_map, .Finalize(graph, &retval_node)); retval_nodes->insert(retval_node); } - return OkStatus(); + return absl::OkStatus(); } // RewriteAndPruneGraph identifies input and output edges (named by the feed and @@ -192,7 +192,7 @@ Status RewriteAndPruneGraph( ", missing feeds: ", absl::StrJoin(missing_feeds, ", "), ", missing fetches: ", absl::StrJoin(missing_fetches, ", ")); } - return OkStatus(); + return absl::OkStatus(); } // CollectArgNodes collects _Arg nodes from the graph, and performs basic @@ -224,7 +224,7 @@ Status CollectArgNodes(const Graph& graph, std::vector* arg_nodes) { } arg_nodes->push_back(index_node.second); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -243,7 +243,7 @@ Status CreateXlaArgs(const Graph& graph, TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name)); xla_args->push_back(arg); } - return OkStatus(); + return absl::OkStatus(); } void PopulateXlaArgs(const tf2xla::Config& config, @@ -306,7 +306,7 @@ Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config, TF_RETURN_IF_ERROR(g->AddFunctionLibrary(flib_def.ToProto())); *graph = std::move(g); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/host_compute_metadata.proto b/tensorflow/compiler/tf2xla/host_compute_metadata.proto index 43ab371a217e6c..9e6eec2cddc99e 100644 --- a/tensorflow/compiler/tf2xla/host_compute_metadata.proto +++ b/tensorflow/compiler/tf2xla/host_compute_metadata.proto @@ -1,19 +1,21 @@ syntax = "proto3"; package tensorflow.tf2xla; + +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; + option cc_enable_arenas = true; option java_outer_classname = "Tf2XlaProtos"; option java_multiple_files = true; option java_package = "org.tensorflow.tf2xla"; -import "tensorflow/core/framework/tensor_shape.proto"; -import "tensorflow/core/framework/types.proto"; - // TensorMetadata indicates the type and shape of a Tensor that is // part of a host compute transfer. message TensorMetadata { DataType type = 1; TensorShapeProto shape = 2; + int64 channel_id = 3; } // HostTransferMetadata describes a transfer either from host to device diff --git a/tensorflow/compiler/tf2xla/kernels/all_reduce_op.cc b/tensorflow/compiler/tf2xla/kernels/all_reduce_op.cc index 80a43e2026d875..ba99f9b1297542 100644 --- a/tensorflow/compiler/tf2xla/kernels/all_reduce_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/all_reduce_op.cc @@ -96,7 +96,6 @@ class CollectiveReduceV2Op : public XlaOpKernel { void operator=(const CollectiveReduceV2Op&) = delete; }; - REGISTER_XLA_OP(Name("CollectiveReduceV2") .CompileTimeConstantInput("group_key") .CompileTimeConstantInput("group_size"), @@ -106,4 +105,13 @@ REGISTER_XLA_OP(Name("CollectiveAssignGroupV2") .CompileTimeConstantInput("group_assignment"), MlirXlaOpKernel); +REGISTER_XLA_OP(Name("XlaReduceScatter") + .CompileTimeConstantInput("group_assignment") + .CompileTimeConstantInput("scatter_dimension"), + MlirXlaOpKernel); + +REGISTER_XLA_OP( + Name("XlaAllReduce").CompileTimeConstantInput("group_assignment"), + MlirXlaOpKernel); + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index c2bf7f7f606432..f44d800481fe6a 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -148,7 +148,7 @@ Status CheckConvAttrs(const ConvOpAttrs& attrs) { attrs.dilations[input_dim]); } } - return OkStatus(); + return absl::OkStatus(); } // Wrapper around ConvBackpropComputeDimensions that converts from XLA shapes diff --git a/tensorflow/compiler/tf2xla/kernels/fused_conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/fused_conv_ops.cc index 2146284bd69864..92138c9663c556 100644 --- a/tensorflow/compiler/tf2xla/kernels/fused_conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fused_conv_ops.cc @@ -277,7 +277,7 @@ class FusedConv2DInt8Op : public XlaOpKernel { } ctx->SetOutput(0, result); - return OkStatus(); + return absl::OkStatus(); } void Compile(XlaOpKernelContext* ctx) override { diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 59d196c351c99b..5877aea0269643 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -86,7 +86,7 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, *gather_output = xla::Broadcast(XlaHelpers::Zero(builder, dtype), out_shape.dim_sizes()); - return OkStatus(); + return absl::OkStatus(); } for (int64_t i = 0; i < num_index_dims; ++i) { @@ -152,7 +152,7 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, } *gather_output = xla::Gather(input, indices, dim_numbers, slice_sizes); - return OkStatus(); + return absl::OkStatus(); } Status XlaGatherWithBatchDimsOpImpl(XlaOpKernelContext* context, @@ -236,7 +236,7 @@ Status XlaGatherWithBatchDimsOpImpl(XlaOpKernelContext* context, /*indices_are_nd=*/false, context->expected_output_dtype(0), index_type, context->builder(), gather_output)); } - return OkStatus(); + return absl::OkStatus(); } class GatherOp : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index 945ae96b46c327..b2d0b2e1d418f2 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -181,7 +181,7 @@ static Status ValidateShapes(XlaOpKernelContext* ctx, "Mismatch in resource of then and else branch for resource ", i); } } - return OkStatus(); + return absl::OkStatus(); } // TODO(b/35949885): There is duplication here with the handling of the diff --git a/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc b/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc index fbdcecdcf95dfd..60ef289567fe8f 100644 --- a/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc +++ b/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc @@ -53,8 +53,8 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" #include "xla/util.h" #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" @@ -374,7 +374,7 @@ Status LightOutsideCompilationOp::CompileToCustomCallCallingTfKernel( output_shape.IsTuple() ? xla::GetTupleElement(out, i) : out); } - return OkStatus(); + return absl::OkStatus(); } namespace { @@ -552,7 +552,7 @@ Status PopulateMetadataBufferIfNeeded(OpKernelContext& ctx, num_dimensions * sizeof(int32_t)); } } - return OkStatus(); + return absl::OkStatus(); } class FakeDeviceContext : public DeviceContext { @@ -569,8 +569,7 @@ Status CallTfKernel(void* stream_handle, void** buffers, const char* opaque, // Look up the platform only once, for a small performance gain. static Status* platform_status = nullptr; static se::Platform* platform = [&]() -> se::Platform* { - StatusOr p = - se::MultiPlatformManager::PlatformWithName("CUDA"); + StatusOr p = se::PlatformManager::PlatformWithName("CUDA"); if (!p.ok()) { platform_status = new Status(p.status()); return nullptr; @@ -708,7 +707,7 @@ Status CallTfKernel(void* stream_handle, void** buffers, const char* opaque, } TF_RETURN_IF_ERROR(ctx.status()); - return OkStatus(); + return absl::OkStatus(); } void GenericTfCallback(void* stream_handle, void** buffers, const char* opaque, diff --git a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc index 9a86b26fb8a623..258e0b0d47f6a0 100644 --- a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc @@ -98,7 +98,7 @@ class ListDiffOp : public XlaOpKernel { xla::ConstantR1(context->builder(), val_output)); context->SetOutput(1, xla::ConstantR1(context->builder(), idx_output)); - return OkStatus(); + return absl::OkStatus(); } template diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index fdd38e2f6beb32..a82cd2e3b85db1 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -55,7 +55,7 @@ static Status ValidateKernelSizes(const T& ksizes) { " must be positive but is ", ksizes[i]); } } - return OkStatus(); + return absl::OkStatus(); } template @@ -67,7 +67,7 @@ static Status ValidateStrides(const T& strides) { " must be positive but is ", strides[i]); } } - return OkStatus(); + return absl::OkStatus(); } // Superclass of pooling ops. diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index ce20763c6146eb..81282578bb2ee6 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -101,7 +101,7 @@ class MaxOp : public XlaReductionOp { "Unsupported PrimitiveType in MaxOp: '", xla::PrimitiveType_Name(xla_reduction_type), "'"); } else { - return OkStatus(); + return absl::OkStatus(); } } diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc index 77062741c0f91f..55eaf3db8a7570 100644 --- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc @@ -58,7 +58,7 @@ Status ValidateUpdateShape(const TensorShape& buffer_shape, }; if (updates_shape.dims() == 0 && broadcast_scalar_update) { - return OkStatus(); + return absl::OkStatus(); } if (updates_shape.dims() < batch_dim) return shape_err(); @@ -81,7 +81,7 @@ Status ValidateUpdateShape(const TensorShape& buffer_shape, return shape_err(); } } - return OkStatus(); + return absl::OkStatus(); } class ScatterNdOp : public XlaOpKernel { diff --git a/tensorflow/compiler/tf2xla/kernels/shape_util.cc b/tensorflow/compiler/tf2xla/kernels/shape_util.cc index 06f7160e392456..1cd296b349a9dc 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_util.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_util.cc @@ -42,7 +42,7 @@ Status TensorShapeToConstant(const TensorShape& input_shape, vec(i) = dim_size; } } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/sharding_util_ops.cc b/tensorflow/compiler/tf2xla/kernels/sharding_util_ops.cc index c175a5584e8f2d..3c1ca648769d95 100644 --- a/tensorflow/compiler/tf2xla/kernels/sharding_util_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sharding_util_ops.cc @@ -91,7 +91,7 @@ Status GetAndValidateAttributes(OpKernelConstruction* ctx, paddings.assign(expected_rank, 0); } - return OkStatus(); + return absl::OkStatus(); } std::vector GetSliceIndices(absl::Span num_partitions, @@ -174,10 +174,10 @@ class XlaSplitNDBaseOp : public XlaOpKernel { xla::Pad(input, xla::ConstantR0WithType(ctx->builder(), type, /*value=*/0), padding_config)); - return OkStatus(); + return absl::OkStatus(); } else if (num_slices_ == 1) { ctx->SetOutput(/*index=*/0, input); - return OkStatus(); + return absl::OkStatus(); } // Slice shape with optional padding. @@ -242,7 +242,7 @@ class XlaSplitNDBaseOp : public XlaOpKernel { slice_limit_indices, slice_strides)); } } - return OkStatus(); + return absl::OkStatus(); } private: @@ -426,7 +426,7 @@ class XlaConcatNDBaseOp : public XlaOpKernel { output_shape.push_back(max_dim_size - paddings_[dim]); } - return OkStatus(); + return absl::OkStatus(); } std::vector num_concats_; diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index e74381aa6f24b4..8131769503086c 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -85,7 +85,7 @@ Status MaybeInitializeStack(xla::XlaBuilder* builder, XlaResource* resource, actual_shape.DebugString()); } } - return OkStatus(); + return absl::OkStatus(); } class StackOp : public XlaOpKernel { diff --git a/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc index 6fbb413b46bc3b..d5746a5bfd729b 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc @@ -142,7 +142,7 @@ Status CheckStateShape(xla::RandomAlgorithm alg, const TensorShape& shape) { return errors::InvalidArgument("The size of the state must be at least ", min_state_size, "; got ", state_size); } - return OkStatus(); + return absl::OkStatus(); } StatusOr ResolveAlg(int alg_id) { @@ -227,7 +227,7 @@ Status CompileImpl( var = BitcastConvertType(var, state_element_type); TF_RETURN_IF_ERROR( ctx->AssignVariable(state_input_idx, STATE_ELEMENT_DTYPE, var)); - return OkStatus(); + return absl::OkStatus(); } class StatefulUniformOp : public XlaOpKernel { diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 22614039459a90..aca9973e118c14 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -84,7 +84,7 @@ Status MaybeInitializeTensorArray(xla::XlaBuilder* builder, shape.DebugString()); } } - return OkStatus(); + return absl::OkStatus(); } // Checks that the TensorArray 'resource' has been initialized, and has type @@ -106,14 +106,14 @@ Status CheckTensorArrayIsInitialized(const string& op_name, " but op has dtype ", DataTypeString(dtype), "."); } - return OkStatus(); + return absl::OkStatus(); } Status GetTensorArrayShape(const XlaResource* resource, xla::XlaBuilder* builder, TensorShape* shape) { *shape = resource->shape(); shape->InsertDim(0, resource->max_array_size()); - return OkStatus(); + return absl::OkStatus(); } // Like XlaBuilder::DynamicUpdateSlice, but adds 'update' to the diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index 7e4c16ca189de3..8f8e6bd90ca448 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -119,19 +119,19 @@ Status TryGetElementShapeFromInput(XlaOpKernelContext* ctx, xla::XlaOp input, bool is_compile_time_constant = is_compile_time_constant_or.value(); if (!is_compile_time_constant) { *got_shape = false; - return OkStatus(); + return absl::OkStatus(); } PartialTensorShape partial_shape; TF_RETURN_IF_ERROR(ctx->ConstantInputAsPartialShape(0, &partial_shape)); if (!partial_shape.IsFullyDefined()) { *got_shape = false; - return OkStatus(); + return absl::OkStatus(); } *shape = xla::ShapeUtil::MakeShape(dtype, partial_shape.dim_sizes()); *got_shape = true; - return OkStatus(); + return absl::OkStatus(); } class TensorListReserveOp : public XlaOpKernel { diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc index 575a3f400d899b..d7a1b5f970561a 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc @@ -117,7 +117,7 @@ bool IsTensorListInput(XlaOpKernelContext* ctx, int index) { Status IsTensorListInitialized(xla::XlaOp list, bool* is_initialized) { TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list)); *is_initialized = list_shape.IsTuple(); - return OkStatus(); + return absl::OkStatus(); } Status IsNestedTensorList(xla::XlaOp list, bool* is_nested_list) { @@ -128,14 +128,14 @@ Status IsNestedTensorList(xla::XlaOp list, bool* is_nested_list) { } TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list)); *is_nested_list = (xla::ShapeUtil::TupleElementCount(list_shape) > 2); - return OkStatus(); + return absl::OkStatus(); } Status BuildNonNestedTensorList(xla::XlaOp buffer, xla::XlaOp push_index, xla::XlaOp* output_list) { TF_RET_CHECK(buffer.builder()); *output_list = xla::Tuple(buffer.builder(), {buffer, push_index}); - return OkStatus(); + return absl::OkStatus(); } Status GetTensorListBufferShape(xla::XlaOp list, xla::Shape* buffer_shape) { @@ -146,7 +146,7 @@ Status GetTensorListBufferShape(xla::XlaOp list, xla::Shape* buffer_shape) { } TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list)); *buffer_shape = xla::ShapeUtil::GetTupleElementShape(list_shape, 0); - return OkStatus(); + return absl::OkStatus(); } Status GetTensorListBuffer(xla::XlaOp list, xla::XlaOp* buffer) { @@ -156,7 +156,7 @@ Status GetTensorListBuffer(xla::XlaOp list, xla::XlaOp* buffer) { return errors::InvalidArgument("TensorList is not initialized"); } *buffer = xla::GetTupleElement(list, 0); - return OkStatus(); + return absl::OkStatus(); } Status GetTensorListPushIndex(xla::XlaOp list, xla::XlaOp* push_index) { @@ -168,7 +168,7 @@ Status GetTensorListPushIndex(xla::XlaOp list, xla::XlaOp* push_index) { TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list)); int tuple_size = xla::ShapeUtil::TupleElementCount(list_shape); *push_index = xla::GetTupleElement(list, tuple_size - 1); - return OkStatus(); + return absl::OkStatus(); } Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index, @@ -187,7 +187,7 @@ Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index, } result_parts.push_back(push_index); *result = xla::Tuple(list.builder(), result_parts); - return OkStatus(); + return absl::OkStatus(); } xla::XlaOp BuildUninitializedTensorList(xla::XlaBuilder* b, @@ -222,7 +222,7 @@ Status GetLeadingDimForTensorList(xla::XlaOp list, int64_t* leading_dim, *leading_dim = list_shape.dimensions(0); *leading_dim_dynamic_size = xla::GetDimensionSize(list, 0); } - return OkStatus(); + return absl::OkStatus(); } Status GetTensorListShapeFromElementTensorListShape( @@ -244,7 +244,7 @@ Status GetTensorListShapeFromElementTensorListShape( shapes.push_back(xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, std::vector{})); *tensor_list_shape = xla::ShapeUtil::MakeTupleShape(shapes); - return OkStatus(); + return absl::OkStatus(); } Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape, @@ -267,7 +267,7 @@ Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape, shapes.push_back(xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, std::vector{})); *tensor_list_shape = xla::ShapeUtil::MakeTupleShape(shapes); - return OkStatus(); + return absl::OkStatus(); } Status CreateZerosTensorListWithShape( @@ -296,7 +296,7 @@ Status CreateZerosTensorListWithShape( .element_type() == xla::S32); elements.push_back(xla::ConstantLiteral(b, xla::LiteralUtil::Zero(xla::S32))); *list = xla::Tuple(b, elements); - return OkStatus(); + return absl::OkStatus(); } Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element, @@ -330,7 +330,7 @@ Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element, ", expected: ", list_shape.DebugString()); } *initialized_list = list; - return OkStatus(); + return absl::OkStatus(); } else { // Prepare dynamic dimension dimensions for zero tensor list. The dynamic // sizes are created by reading the dynamic dimension size of sub-elements. @@ -414,7 +414,7 @@ Status ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element, result_parts.push_back(updated_push_index); *result = xla::Tuple(b, result_parts); - return OkStatus(); + return absl::OkStatus(); } Status ExecuteTensorListPopBack(xla::XlaOp list, xla::XlaOp* list_result, @@ -463,7 +463,7 @@ Status ExecuteTensorListPopBack(xla::XlaOp list, xla::XlaOp* list_result, *element_result = element_result_parts[0]; } - return OkStatus(); + return absl::OkStatus(); } Status ExecuteTensorListSetItem(xla::XlaOp list, xla::XlaOp index, @@ -499,7 +499,7 @@ Status ExecuteTensorListSetItem(xla::XlaOp list, xla::XlaOp index, result_parts.push_back(updated_list_part); result_parts.push_back(xla::GetTupleElement(list, 1)); *result = xla::Tuple(b, result_parts); - return OkStatus(); + return absl::OkStatus(); } Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index, @@ -541,7 +541,7 @@ Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index, } slice_shape.erase(slice_shape.begin()); *result = xla::Reshape(read, slice_shape); - return OkStatus(); + return absl::OkStatus(); } Status ExecuteTensorListFromTensor(int push_index, xla::XlaOp tensor, @@ -558,7 +558,7 @@ Status ExecuteTensorListFromTensor(int push_index, xla::XlaOp tensor, std::vector result_parts{tensor, xla::ConstantR0(b, push_index)}; *result = xla::Tuple(b, result_parts); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/to_bool_op.cc b/tensorflow/compiler/tf2xla/kernels/to_bool_op.cc index 015a0ce40e80d3..e7ba8d13082849 100644 --- a/tensorflow/compiler/tf2xla/kernels/to_bool_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/to_bool_op.cc @@ -42,7 +42,7 @@ class ToBoolOp : public XlaOpKernel { if (shape.rank() == 0) { auto result = xla::Ne(ctx->Input(0), xla::ZerosLike(input)); ctx->SetOutput(0, result); - return OkStatus(); + return absl::OkStatus(); } // Otherwise, any input tensor with elements returns True. Input tensor @@ -54,7 +54,7 @@ class ToBoolOp : public XlaOpKernel { auto result = xla::Ne(num_elements, xla::ZerosLike(num_elements)); ctx->SetOutput(0, result); - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index 7cdf30594581c6..27cfc014bacb04 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -41,7 +41,7 @@ Status ValidateAssignUpdateVariableOpShapes(XlaOpKernelContext* ctx) { ctx->GetVariableTypeAndShape(0, &variable_dtype, &variable_shape)); TF_RETURN_IF_ERROR( ValidateAssignUpdateVariableOpShapes(variable_shape, value_shape)); - return OkStatus(); + return absl::OkStatus(); } class VarIsInitializedOp : public XlaOpKernel { diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index c3370dcde64b70..d6685cc1e1d965 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -77,7 +77,7 @@ Status VerifyResourceArgsGroupedAtEnd(XlaOpKernelContext* ctx, } } } - return OkStatus(); + return absl::OkStatus(); } // Builds XlaCompiler argument descriptions `args` from `ctx`. @@ -128,7 +128,7 @@ Status MakeXlaCompilerArgumentsFromInputs( } } } - return OkStatus(); + return absl::OkStatus(); } // Populates loop invariant indices to true in `loop_invariants`. @@ -186,7 +186,7 @@ Status ConvertLoopInvariantsToConst( compile_time_const_arg_indices->at(arg_idx) = true; (*num_compile_time_const_args)++; } - return OkStatus(); + return absl::OkStatus(); } Status VerifyBodyInputAndOutputShapeMatch( @@ -213,7 +213,7 @@ Status VerifyBodyInputAndOutputShapeMatch( xla::ShapeUtil::HumanString(body_input_shape), " vs. ", xla::ShapeUtil::HumanString(body_output_shape)); } - return OkStatus(); + return absl::OkStatus(); } StatusOr BuildWrappedCond( diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc index 589b9daec8772e..68737e43c7d8f6 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc @@ -28,7 +28,9 @@ limitations under the License. #include "absl/strings/string_view.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -37,10 +39,12 @@ limitations under the License. #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project #include "mlir/IR/Verifier.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project @@ -105,6 +109,9 @@ bool IsShapeAssertionsCheckDisabled( constexpr llvm::StringRef kUsesShapePolymorphismAttr = "jax.uses_shape_polymorphism"; +constexpr llvm::StringRef kCustomCallShimTarget = + "stablehlo.shape_refinement_operand_wrapper"; + } // namespace bool IsTokenType(mlir::Type type) { @@ -124,7 +131,7 @@ tsl::StatusOr> XlaCallModuleLoader::Create( return loader; } -tsl::Status XlaCallModuleLoader::SetPlatformIndex( +absl::Status XlaCallModuleLoader::SetPlatformIndex( absl::string_view compilation_platform) { int platform_index = -1; if (!platforms_.empty()) { @@ -186,7 +193,25 @@ tsl::Status XlaCallModuleLoader::SetPlatformIndex( return tsl::OkStatus(); } -tsl::Status XlaCallModuleLoader::RefineDynamicShapes( +static mlir::stablehlo::CustomCallOp MakeShapeRefinementOperandWrapper( + mlir::OpBuilder op_builder, mlir::Value operand, + llvm::ArrayRef shape) { + auto constant = op_builder.create( + operand.getLoc(), op_builder.getI64TensorAttr(shape)); + return op_builder.create( + operand.getLoc(), operand.getType(), mlir::ValueRange{operand, constant}, + llvm::SmallVector{ + op_builder.getNamedAttr( + "call_target_name", + op_builder.getStringAttr(kCustomCallShimTarget)), + op_builder.getNamedAttr("indices_of_shape_operands", + op_builder.getI64TensorAttr({1})), + op_builder.getNamedAttr("has_side_effect", + op_builder.getBoolAttr(false)), + }); +} + +absl::Status XlaCallModuleLoader::RefineDynamicShapes( llvm::ArrayRef input_shapes) { // Skip shape refinement for new versions if USES_SHAPE_POLYMORPHISM_ATTR=1 if (version_ >= kVersionStartSupportUsesShapePolymorphismAttr) { @@ -240,6 +265,7 @@ tsl::Status XlaCallModuleLoader::RefineDynamicShapes( ")")); } + // Derive static input types to use for main. mlir::Block &main_body = main_.front(); mlir::Builder builder(module_->getContext()); std::vector static_array_input_types(nr_inputs); @@ -264,7 +290,7 @@ tsl::Status XlaCallModuleLoader::RefineDynamicShapes( ConvertPrimitiveTypeToMLIRType(xla_shape.element_type(), builder)); mlir::RankedTensorType type = mlir::RankedTensorType::get(xla_dimensions, element_type); - // TODO(burmako): This fails with an obscure compilation error. + // TODO(burmako): This fails with an obscure compilation error on Windows. // TF_ASSIGN_OR_RETURN( // mlir::Type type, // ConvertShapeToType(xla_shape, builder)); @@ -301,39 +327,73 @@ tsl::Status XlaCallModuleLoader::RefineDynamicShapes( } } - // Refine 'main' argument types to use static input types instead. The main - // arguments may occur as return values, or as inputs to called functions, - // and changing their types may invalidate the module. To prevent this - // we insert dummy conversion ops as the sole uses of the main arguments, for - // the arguments that are not tokens and have dynamic shape. - // If we use stablehlo.convert, we end up with "convert 3xf32 -> *xf32" - // after we set the static shapes for the main arguments. The "convert" - // op does not support unranked result for ranked inputs. So, we use - // "bitcast_convert", which is more flexible in the relationship between - // the input and the result. + // Insert custom_call ops as shims to maintain the validity of the module when + // main's input types are changed later. This is a workaround to allow shape + // refinement to be applied; the custom_calls are removed before returning. + // Arguments to main may occur as return values, or as inputs to called + // functions, and changing their types may invalidate the module due to type + // mismatches. To prevent this, for each argument that is a dynamically-shaped + // tensor, we insert a custom_call op that takes the argument as an input and + // replace uses of the argument with the custom_call's result. custom_call + // is used as it allows its inputs and outputs to be unranked. + // + // Example: + // + // The below main function returns its argument directly: + // + // func.func @main(%arg0: tensor<*xf32>) -> tensor<*xf32> { + // return %arg0 : tensor<*xf32> + // } + // + // Changing the argument's type invalidates the IR (type mismatch): + // + // func.func @main(%arg0: tensor<2x3xf32>) -> tensor<*xf32> { + // return %arg0 : tensor<*xf32> + // } + // + // Inserting a custom_call allows the IR to remain valid: + // + // func.func @main(%arg0: tensor<2x3xf32>) -> tensor<*xf32> { + // %0 = stablehlo.constant dense<[2, 3]> : tensor<2xi64> + // %1 = stablehlo.custom_call + // @stablehlo.shape_refinement_operand_wrapper(%arg0, %0) + // {indices_of_shape_operands = dense<1> : tensor<1xi64>} : + // (tensor<2x3xf32>, tensor<2xi64>) -> tensor<*xf32> + // return %1 : tensor<*xf32> + // } + // + // After shapes are refined and the custom_calls are removed, we get: + // + // func.func @main(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { + // return %arg0 : tensor<2x3xf32> + // } + // mlir::OpBuilder op_builder(module_->getBodyRegion()); op_builder.setInsertionPointToStart(&main_body); for (auto i = 0; i < main_body.getNumArguments(); ++i) { mlir::BlockArgument arg = main_body.getArgument(i); mlir::Type arg_type = arg.getType(); - if (IsTokenType(arg_type)) { + bool is_input_refined = arg_type == static_array_input_types[i]; + if (IsTokenType(arg_type) || is_input_refined) { continue; } auto ranked_arg_type = arg_type.dyn_cast(); if (!ranked_arg_type || !ranked_arg_type.hasStaticShape()) { - auto convert_op = op_builder.create( - arg.getLoc(), arg_type, arg); - arg.replaceAllUsesExcept(convert_op, convert_op); + auto type = static_array_input_types[i].cast(); + auto custom_call = + MakeShapeRefinementOperandWrapper(op_builder, arg, type.getShape()); + auto call_result = custom_call.getResult(0); + arg.replaceAllUsesExcept(call_result, custom_call); } } - auto static_array_output_types = llvm::to_vector(main_.getResultTypes()); + // Actually update main's input types. for (auto i = 0; i < main_body.getNumArguments(); ++i) { auto arg = main_body.getArgument(i); arg.setType(static_array_input_types[i]); } main_.setType(builder.getFunctionType(static_array_input_types, - static_array_output_types)); + main_.getResultTypes())); if (VLOG_IS_ON(5)) { DumpMlirOpToFile("xla_call_module.after_refined_input_types", *module_); } @@ -343,13 +403,34 @@ tsl::Status XlaCallModuleLoader::RefineDynamicShapes( TF_RETURN_IF_ERROR( xla::RefinePolymorphicShapes(*module_, enable_shape_assertions)); + // Clean up custom_call shims. + for (auto call : llvm::make_early_inc_range( + main_body.getOps())) { + if (call->getAttr("call_target_name").cast().strref() == + kCustomCallShimTarget) { + auto operand = call->getOperand(0); + auto result = call->getResult(0); + if (operand.getType() != result.getType()) { + std::string s; + llvm::raw_string_ostream os(s); + os << "custom_call shim shape refinement failed, input type does not " + "match output type: " + << operand.getType() << " != " << result.getType(); + return absl::InvalidArgumentError(os.str()); + } + call->getResult(0).replaceAllUsesExcept(call->getOperand(0), call); + call.erase(); + } + } + if (VLOG_IS_ON(3)) { DumpMlirOpToFile("xla_call_module.after_shape_refinement", *module_); } + return tsl::OkStatus(); } -tsl::Status XlaCallModuleLoader::LoadModule( +absl::Status XlaCallModuleLoader::LoadModule( mlir::MLIRContext *context, int version, std::string module_str, std::vector disabled_checks, std::vector platforms, int num_invocation_args, @@ -446,7 +527,7 @@ tsl::Status XlaCallModuleLoader::LoadModule( return tsl::OkStatus(); } -tsl::Status XlaCallModuleLoader::ValidateDialect() { +absl::Status XlaCallModuleLoader::ValidateDialect() { mlir::StatusScopedDiagnosticHandler diag_handler(module_->getContext()); bool moduleHasUnsupportedDialects = false; diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h index e77ce0effcf92c..3e9627ebcc29a2 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h @@ -52,7 +52,7 @@ class XlaCallModuleLoader { // Sets the platform index argument, if the module is compiled for multiple // platforms, and then erases the argument. - tsl::Status SetPlatformIndex(absl::string_view compilation_platform); + absl::Status SetPlatformIndex(absl::string_view compilation_platform); // Refines the dynamic module arguments based on the static argument shapes. // This assumes that the module has a "main" function without dimension args, @@ -71,10 +71,10 @@ class XlaCallModuleLoader { // cause lifetime issues. // The input_shapes includes only the non-token and the non-platform-index // arguments. - tsl::Status RefineDynamicShapes(llvm::ArrayRef input_shapes); + absl::Status RefineDynamicShapes(llvm::ArrayRef input_shapes); // Validates that the module only contains ops from valid dialects. - tsl::Status ValidateDialect(); + absl::Status ValidateDialect(); // Validates that the module represents a statically-shaped StableHLO program, // otherwise all sorts of weirdness might happen in the HLO exporter which is @@ -97,16 +97,16 @@ class XlaCallModuleLoader { XlaCallModuleLoader() = default; // Initializes the loader with the given serialized module string. - tsl::Status LoadModule(mlir::MLIRContext* context, int version, - std::string module_str, - std::vector disabled_checks, - std::vector platforms, - int num_invocation_args, - bool main_has_token_input_output); + absl::Status LoadModule(mlir::MLIRContext* context, int version, + std::string module_str, + std::vector disabled_checks, + std::vector platforms, + int num_invocation_args, + bool main_has_token_input_output); // Adds a wrapper for the "main" function to compute the platform index and // the dimension arguments. - tsl::Status AddMainWrapper(); + absl::Status AddMainWrapper(); mlir::MLIRContext* context_; int version_; diff --git a/tensorflow/compiler/tf2xla/kernels/xla_custom_call_v2_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_custom_call_v2_op.cc index cb36059f62051a..2f1883a289cd3c 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_custom_call_v2_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_custom_call_v2_op.cc @@ -93,7 +93,7 @@ class XlaCustomCallV2Op : public XlaOpKernel { } } - return OkStatus(); + return absl::OkStatus(); } std::string call_target_name_; diff --git a/tensorflow/compiler/tf2xla/layout_util.cc b/tensorflow/compiler/tf2xla/layout_util.cc index eb4526a7ea3cc4..2a3379e98b55b7 100644 --- a/tensorflow/compiler/tf2xla/layout_util.cc +++ b/tensorflow/compiler/tf2xla/layout_util.cc @@ -74,12 +74,12 @@ Status RewriteLayoutWithShardedShape( layout_preference)); *xla_shape->mutable_layout() = per_device_xla_shape.layout(); } - return OkStatus(); + return absl::OkStatus(); } // There is a shape_representation_fn or sharding for an output, this function // uses a reshape to fix the layout. -StatusOr ReshapeWithCorrectRepresentationAndSharding( +absl::StatusOr ReshapeWithCorrectRepresentationAndSharding( xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape, XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, std::optional sharding, bool fast_mem) { diff --git a/tensorflow/compiler/tf2xla/layout_util.h b/tensorflow/compiler/tf2xla/layout_util.h index 4a4a652c1e83ae..835728ce0c8027 100644 --- a/tensorflow/compiler/tf2xla/layout_util.h +++ b/tensorflow/compiler/tf2xla/layout_util.h @@ -67,7 +67,7 @@ Status RewriteLayoutWithShardedShape( // Adds reshapes to fix the layout of an output, if a shape_representation_fn or // sharding is present. -StatusOr ReshapeWithCorrectRepresentationAndSharding( +absl::StatusOr ReshapeWithCorrectRepresentationAndSharding( xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape, XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, std::optional sharding, bool fast_mem); diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.cc b/tensorflow/compiler/tf2xla/lib/broadcast.cc index ef70161fdcdcb2..da1b4182004e8a 100644 --- a/tensorflow/compiler/tf2xla/lib/broadcast.cc +++ b/tensorflow/compiler/tf2xla/lib/broadcast.cc @@ -51,7 +51,7 @@ Status BroadcastOpsToSame(xla::XlaOp* lhs, xla::XlaOp* rhs) { TF_ASSIGN_OR_RETURN(*lhs, xla::BroadcastTo(*lhs, bcast.output_shape())); TF_ASSIGN_OR_RETURN(*rhs, xla::BroadcastTo(*rhs, bcast.output_shape())); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/light_outside_compilation_kernels_for_test.cc b/tensorflow/compiler/tf2xla/light_outside_compilation_kernels_for_test.cc index e13333ca28b53b..788ef0fc95bf9d 100644 --- a/tensorflow/compiler/tf2xla/light_outside_compilation_kernels_for_test.cc +++ b/tensorflow/compiler/tf2xla/light_outside_compilation_kernels_for_test.cc @@ -49,10 +49,11 @@ class TestStaticTfOp : public OpKernel { se::DeviceMemoryBase gpu_dst{out_tensor->data(), size}; se::Stream* stream = ctx->op_device_context()->stream(); - stream->ThenMemcpyD2D( - /*gpu_dst=*/&gpu_dst, - /*gpu_src=*/se::DeviceMemoryBase{input.data(), size}, - /*size=*/input.AllocatedBytes()); + OP_REQUIRES_OK(ctx, + stream->MemcpyD2D( + /*gpu_dst=*/&gpu_dst, + /*gpu_src=*/se::DeviceMemoryBase{input.data(), size}, + /*size=*/input.AllocatedBytes())); } }; @@ -91,14 +92,16 @@ class TestStaticMultipleOutputTfOp : public OpKernel { se::Stream* stream = ctx->device()->tensorflow_accelerator_device_info()->stream; - stream->ThenMemcpyD2D( - /*gpu_dst=*/&gpu_dst1, - /*gpu_src=*/se::DeviceMemoryBase{input.data(), size}, - /*size=*/input.AllocatedBytes()); - stream->ThenMemcpyD2D( - /*gpu_dst=*/&gpu_dst2, - /*gpu_src=*/se::DeviceMemoryBase{input.data(), size}, - /*size=*/input.AllocatedBytes()); + OP_REQUIRES_OK(ctx, + stream->MemcpyD2D( + /*gpu_dst=*/&gpu_dst1, + /*gpu_src=*/se::DeviceMemoryBase{input.data(), size}, + /*size=*/input.AllocatedBytes())); + OP_REQUIRES_OK(ctx, + stream->MemcpyD2D( + /*gpu_dst=*/&gpu_dst2, + /*gpu_src=*/se::DeviceMemoryBase{input.data(), size}, + /*size=*/input.AllocatedBytes())); } }; @@ -145,11 +148,12 @@ class TestDynamicTfOp : public OpKernel { ctx->device()->tensorflow_accelerator_device_info()->stream; se::DeviceMemoryBase gpu_dst{out_tensor->data(), size_to_cpy}; - stream->ThenMemcpyD2D( - /*gpu_dst=*/&gpu_dst, - /*gpu_src=*/ - se::DeviceMemoryBase{input.data(), static_cast(size)}, - /*size=*/size_to_cpy); + OP_REQUIRES_OK(ctx, stream->MemcpyD2D( + /*gpu_dst=*/&gpu_dst, + /*gpu_src=*/ + se::DeviceMemoryBase{input.data(), + static_cast(size)}, + /*size=*/size_to_cpy)); } private: @@ -208,9 +212,9 @@ class DynamicMultidimOp : public OpKernel { se::Stream* stream = ctx->device()->tensorflow_accelerator_device_info()->stream; - stream->ThenMemcpy( - /*gpu_dst=*/&gpu_dst, /*host_src=*/host_data.data(), - /*size=*/num_elements * sizeof(float)); + OP_REQUIRES_OK(ctx, stream->Memcpy( + /*gpu_dst=*/&gpu_dst, /*host_src=*/host_data.data(), + /*size=*/num_elements * sizeof(float))); } }; @@ -280,9 +284,10 @@ class TestTfMustBeConstantOp : public OpKernel { TF_CHECK_OK(ctx->allocate_temp(input.dtype(), input.shape(), &tmp, pinned_alloc_attrs)); - stream->ThenMemcpy(tmp.data(), - se::DeviceMemoryBase{input.data(), allocated_size}, - allocated_size); + OP_REQUIRES_OK( + ctx, stream->Memcpy(tmp.data(), + se::DeviceMemoryBase{input.data(), allocated_size}, + allocated_size)); OP_REQUIRES_OK(ctx, stream->BlockHostUntilDone()); @@ -295,7 +300,7 @@ class TestTfMustBeConstantOp : public OpKernel { &out_tensor)); se::DeviceMemoryBase gpu_dst{out_tensor->data(), static_cast(allocated_size)}; - stream->ThenMemcpy(&gpu_dst, tmp.data(), allocated_size); + OP_REQUIRES_OK(ctx, stream->Memcpy(&gpu_dst, tmp.data(), allocated_size)); } }; @@ -339,10 +344,11 @@ class TestDynamicTfWithBoundOp : public OpKernel { se::Stream* stream = ctx->device()->tensorflow_accelerator_device_info()->stream; se::DeviceMemoryBase gpu_dst{out_tensor->data(), size_to_cpy}; - stream->ThenMemcpyD2D( - /*gpu_dst=*/&gpu_dst, - /*gpu_src=*/se::DeviceMemoryBase{input.data(), size_to_cpy}, - /*size=*/size_to_cpy); + OP_REQUIRES_OK( + ctx, stream->MemcpyD2D( + /*gpu_dst=*/&gpu_dst, + /*gpu_src=*/se::DeviceMemoryBase{input.data(), size_to_cpy}, + /*size=*/size_to_cpy)); } private: diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc index 1ea2e98903d69e..bd8287dd3f7a95 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/mlir_bridge_pass.h" +#include #include #include "tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h" @@ -33,6 +34,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.h" #include "tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.h" #include "tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.h" +#include "tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_defs.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/common_runtime/device_set.h" @@ -58,14 +60,6 @@ auto* mlir_bridge_gauge_v2 = monitoring::Gauge::New( "/tensorflow/config/experimental/enable_mlir_bridge_gauge_v2", "Tracks usage of the MLIR-based TF2XLA bridge among TF2 models"); -auto* replicated_graphs_without_device_type_counter = - tensorflow::monitoring::Counter<1>::New( - /* metric name */ - "/tensorflow/core/tf2xla/replicated_graphs_without_device_type_count", - /* metric description */ - "Tracks if any replicated graphs are without device type", - /* metric field */ "version"); - namespace { using ::mlir::ModuleOp; @@ -79,22 +73,6 @@ bool HasTPUDevice(const DeviceSet& device_set) { return false; } -// Check that graph has tf.StatefulPartitionedCall op with _XlaMustCompile. -bool RunNonReplicatedBridge(const Graph& graph) { - const std::string kStatefulPartitionedCallOp = "StatefulPartitionedCall"; - const std::string kXlaMustCompile = "_XlaMustCompile"; - for (const Node* node : graph.nodes()) { - auto node_op = node->type_string(); - if (node_op == kStatefulPartitionedCallOp) { - auto attr = node->attrs().FindByString(kXlaMustCompile); - if (attr != nullptr && attr->b() == true) { - return true; - } - } - } - return false; -} - bool HasTPUDevice(mlir::ModuleOp module) { mlir::TF::RuntimeDevices devices; if (failed(GetDevicesFromOp(module.getOperation(), &devices))) return false; @@ -105,49 +83,10 @@ bool HasTPUDevice(mlir::ModuleOp module) { }); } -bool IsReplicatedGraph(mlir::ModuleOp module) { - auto walk_result = module.walk([&](mlir::Operation* op) { - // TODO(b/223677572): Once the scope for new compilation and replication - // markers is expanded beyond bridge we can remove this check for - // `kTPUReplicateAttr`, we will then always have a `kCompileDeviceTypeAttr` - // in such cases (see above). - // TODO(b/229028654): Remove string conversion once we have C++17. - const llvm::StringRef tpu_replicate_attr_name(kTpuReplicateAttr.data(), - kTpuReplicateAttr.size()); - auto replicate_attr = - op->getAttrOfType(tpu_replicate_attr_name); - if (replicate_attr) return mlir::WalkResult::interrupt(); - return mlir::WalkResult::advance(); - }); - return walk_result.wasInterrupted(); -} - -bool IsReplicatedGraphWithoutDeviceType(mlir::ModuleOp module) { - return !HasTPUDevice(module) && IsReplicatedGraph(module); -} - -bool IsSingleCoreTPUGraph(mlir::ModuleOp module) { - auto walk_result = module.walk([&](mlir::Operation* op) { - // Check for ops with compile device type "TPU". This allows us to support - // TPU compilation without replication. Note that currently the compile - // device type is not set by default before bridge, only if eager context - // attribute `jit_compile_rewrite` is true. - // TODO(b/229028654): Remove string conversion once we have C++17. - const llvm::StringRef compile_device_type_attr_name( - kCompileDeviceTypeAttr.data(), kCompileDeviceTypeAttr.size()); - auto compilation_attr = - op->getAttrOfType(compile_device_type_attr_name); - if (compilation_attr && compilation_attr.getValue().str() == kTpuDevice) { - return mlir::WalkResult::interrupt(); - } - return mlir::WalkResult::advance(); - }); - return walk_result.wasInterrupted(); -} - -bool RunReplicatedBridge(mlir::ModuleOp module) { - if (HasTPUDevice(module) && IsReplicatedGraph(module)) return true; - return IsSingleCoreTPUGraph(module); +bool HasDevice(mlir::ModuleOp module) { + mlir::TF::RuntimeDevices devices; + if (failed(GetDevicesFromOp(module.getOperation(), &devices))) return false; + return !devices.device_names().empty(); } bool HasTPUPartitionedCallOpInModule(mlir::ModuleOp module) { @@ -201,10 +140,11 @@ absl::Status RunLowerToRuntimeOpsOnSubmodule(ModuleOp parent_module, // The config_proto param is a required input for all TF1 graphs but it is // redundant for TF2 graphs. MlirOptimizationPassState GetPassStateImpl( - bool run_replicated_bridge, const ConfigProto& config_proto, + bool is_supported_by_replicated_brige, const ConfigProto& config_proto, const Graph& graph, const FunctionLibraryDefinition& function_library) { // Skip MLIR TF/XLA Bridge if no XLA-compilable ops are found. - if (!run_replicated_bridge && !RunNonReplicatedBridge(graph)) { + if (!is_supported_by_replicated_brige && + !IsSupportedByNonReplicatedBridge(graph, &function_library)) { VLOG(3) << "Skipping MLIR Bridge, graph is not qualified to run the bridge"; return MlirOptimizationPassState::Disabled; } @@ -214,58 +154,43 @@ MlirOptimizationPassState GetPassStateImpl( // GetMlirBridgeRolloutPolicy will analyze a TPU graph if users have not // explicltly requested a policy. MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy( - graph, &function_library, config_proto, - /*run_replicated_bridge*/ run_replicated_bridge, + graph, &function_library, config_proto, is_supported_by_replicated_brige, /*uses_uninitialized_resource_args=*/false, /*is_v1_compat=*/false, /*record_stats=*/false); // GetPassState is called once before MlirBridgePass starts, and the pass // gets skipped if it is disabled. Log such cases in this function. The cases // where the pass is enabled will only be logged during their execution to // prevent them from being counted twice. - if (run_replicated_bridge) { - switch (policy) { - case MlirBridgeRolloutPolicy::kEnabledByUser: - return MlirOptimizationPassState::Enabled; - case MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis: - return MlirOptimizationPassState::FallbackEnabled; - case MlirBridgeRolloutPolicy::kDisabledByUser: - VLOG(1) << "Skipping MLIR TPU Bridge, disabled by user. " - "Old bridge will evaluate."; - metrics::UpdateTfMlirBridgeFirstPhaseCounter("tpu", "v2", true, - "disabled_by_user"); - return MlirOptimizationPassState::Disabled; - case MlirBridgeRolloutPolicy::kDisabledAfterGraphAnalysis: - VLOG(1) << "Skipping MLIR TPU Bridge, disabled because " - "graph has unsupported features. Old bridge will evaluate."; - metrics::UpdateTfMlirBridgeFirstPhaseCounter("tpu", "v2", true, - "invalid_graph"); - // We set `uses_uninitialized_resource_args` to false here because the - // first phase of the bridge is not affected by uninitialized resource - // args. - // For Invalid Graph Analysis we need to log here because Run will not - // be called. - LogGraphFeatures(graph, &function_library, config_proto, - /*uses_uninitialized_resource_args=*/false, - /*is_v1_compat=*/false); - return MlirOptimizationPassState::Disabled; - } - } - // TODO(b/277112519): Have uniform behavior for GPU/CPU and TPU switch (policy) { case MlirBridgeRolloutPolicy::kEnabledByUser: return MlirOptimizationPassState::Enabled; case MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis: return MlirOptimizationPassState::FallbackEnabled; - case MlirBridgeRolloutPolicy::kDisabledByUser: - VLOG(1) << "Skipping MLIR CPU/GPU Bridge, disabled by user."; - metrics::UpdateTfMlirBridgeFirstPhaseCounter("cpu/gpu", "v2", false, - "disabled_by_user"); + case MlirBridgeRolloutPolicy::kDisabledByUser: { + VLOG(1) << "Skipping MLIR " + << (is_supported_by_replicated_brige ? "Replicated" + : "Non-Replicated") + << " Bridge, disabled by user. " + "The fallback will evaluate."; + metrics::UpdateTfMlirBridgeFirstPhaseCounter( + is_supported_by_replicated_brige ? "tpu" : "cpu/gpu", "v2", true, + "disabled_by_user"); return MlirOptimizationPassState::Disabled; - default: - // This case should never be hit. Added here to be consistent with OSS - // implementation. - metrics::UpdateTfMlirBridgeFirstPhaseCounter("cpu/gpu", "v2", false, + } + case MlirBridgeRolloutPolicy::kDisabledAfterGraphAnalysis: + // Graph analysis only runs on TPU graph. + VLOG(1) << "Skipping MLIR TPU Bridge, disabled because the " + "graph has unsupported features. The fallback will evaluate."; + metrics::UpdateTfMlirBridgeFirstPhaseCounter("tpu", "v2", true, "invalid_graph"); + // We set `uses_uninitialized_resource_args` to false here because the + // first phase of the bridge is not affected by uninitialized resource + // args. + // For Invalid Graph Analysis we need to log here because Run will not + // be called. + LogGraphFeatures(graph, &function_library, config_proto, + /*uses_uninitialized_resource_args=*/false, + /*is_v1_compat=*/false); return MlirOptimizationPassState::Disabled; } } @@ -274,14 +199,18 @@ MlirOptimizationPassState MlirBridgePass::GetPassState( const DeviceSet* device_set, const ConfigProto& config_proto, const Graph& graph, const FunctionLibraryDefinition& function_library) const { + // While we do not use device type information to choose which pass pipeline + // to execute, it's needed for successful execution. if (!device_set) { // This is not expected in practice. VLOG(1) << "Device set is empty!"; return MlirOptimizationPassState::Disabled; } - return GetPassStateImpl(/*run_replicated_bridge*/ HasTPUDevice(*device_set), - config_proto, graph, function_library); + return GetPassStateImpl( + /*is_supported_by_replicated_brige*/ IsSupportedByReplicatedBridge( + graph, &function_library), + config_proto, graph, function_library); } // This runs the first phase of the "bridge", transforming the graph in a form @@ -297,18 +226,9 @@ Status MlirBridgePass::Run(const std::string& function_name, static absl::once_flag flag; absl::call_once(flag, UpdateLogVerbosityIfDefined, "TF_DEBUG_LOG_VERBOSITY"); - // Check if it's possible for a replicated graph to not have a device type. - if (IsReplicatedGraphWithoutDeviceType(module)) { - replicated_graphs_without_device_type_counter->GetCell("v2")->IncrementBy( - 1); - } - - // Check if the graph has any XLA-compilable ops. - // This check needs to precede GetPassState for instrumentation purposes. - bool run_replicated_bridge = RunReplicatedBridge(module); - if (!run_replicated_bridge && !RunNonReplicatedBridge(graph)) { - VLOG(1) << "Skipping MLIR TF2XLA Bridge, no XLA-compilable ops found."; - return OkStatus(); + if (!HasDevice(module)) { + LOG(INFO) << "No devices in " << function_name << "\n"; + return absl::OkStatus(); } if (HasTPUPartitionedCallOpInModule(module)) { @@ -320,8 +240,9 @@ Status MlirBridgePass::Run(const std::string& function_name, // TODO(b/241853328): Add caching of pass state and call logging/metrics // related to graph analysis from here. - auto pass_state = GetPassStateImpl(run_replicated_bridge, config_proto, graph, - function_library); + bool is_supported_by_replicated_brige = IsSupportedByReplicatedBridge(module); + auto pass_state = GetPassStateImpl(is_supported_by_replicated_brige, + config_proto, graph, function_library); if (pass_state == MlirOptimizationPassState::Disabled) { // GetPassState is called before run() and run() will only be called if the @@ -333,7 +254,7 @@ Status MlirBridgePass::Run(const std::string& function_name, } bool fallback_enabled = false; - if (run_replicated_bridge) { + if (is_supported_by_replicated_brige) { if (pass_state == MlirOptimizationPassState::FallbackEnabled) { // We set `uses_uninitialized_resource_args` to false here because the // first phase of the bridge is not affected by uninitialized resource @@ -350,7 +271,7 @@ Status MlirBridgePass::Run(const std::string& function_name, TF_RETURN_IF_ERROR( tensorflow::tf2xla::v2::RunFunctionTf2xlaClusteringBridge( - module, tf2xla::v2::DeviceType::XLA_TPU_JIT, fallback_enabled, + module, /*is_supported_by_replicated_brige*/ true, fallback_enabled, function_name)); TF_RETURN_IF_ERROR( @@ -360,8 +281,8 @@ Status MlirBridgePass::Run(const std::string& function_name, VLOG(1) << "Running GPU/CPU Bridge"; TF_RETURN_IF_ERROR( tensorflow::tf2xla::v2::RunFunctionTf2xlaClusteringBridge( - module, tf2xla::v2::DeviceType::XLA_GPU_JIT, fallback_enabled, - function_name)); + module, /*is_supported_by_replicated_brige*/ false, + fallback_enabled, function_name)); TF_RETURN_IF_ERROR( tensorflow::tfrt_compiler::RunLowerClusterToRuntimeOpsPassPipeline( @@ -376,14 +297,14 @@ MlirOptimizationPassState MlirBridgeV1CompatPass::GetPassState( const DeviceSet* device_set, const ConfigProto& config_proto, const Graph& graph, const FunctionLibraryDefinition& function_library) const { - // Skip MLIR TPU Bridge if no TPU devices found. - if (device_set && !HasTPUDevice(*device_set)) + // Skip MLIR Bridge if no potential XLA clusters are found. + if (!IsSupportedByReplicatedBridge(graph, &function_library)) return MlirOptimizationPassState::Disabled; // We set `uses_uninitialized_resource_args` to false here because the first // phase of the bridge is not affected by uninitialized resource args. MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy( graph, /*function_library=*/&function_library, config_proto, - /*run_replicated_bridge*/ true, + /*is_supported_by_replicated_brige*/ true, /*uses_uninitialized_resource_args=*/false, /*is_v1_compat=*/true, /*record_stats=*/false); switch (policy) { @@ -423,14 +344,8 @@ Status MlirBridgeV1CompatPass::Run(const GraphOptimizationPassOptions& options, // Skip function graphs as MlirBridgePass will be used instead. if (options.is_function_graph) return OkStatus(); - // Check if it's possible for a replicated graph to not have a device type. - if (IsReplicatedGraphWithoutDeviceType(module)) { - replicated_graphs_without_device_type_counter->GetCell("v1")->IncrementBy( - 1); - } - // Skip MLIR TPU Bridge if no TPU devices or TPU ops found. - if (!RunReplicatedBridge(module)) { + if (!IsSupportedByReplicatedBridge(module)) { VLOG(1) << "Skipping MLIR TPU Bridge V1 Compat, no TPU devices or TPU ops " "found"; return OkStatus(); diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index edb2a40f4d332b..27a534296921cd 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -42,7 +42,7 @@ Status UnchangedRank(shape_inference::InferenceContext* c) { } else { c->set_output(0, c->input(0)); } - return OkStatus(); + return absl::OkStatus(); } REGISTER_OP("XlaBroadcastHelper") @@ -294,7 +294,7 @@ static Status XlaDotShapeFunction(shape_inference::InferenceContext* c) { } c->set_output(0, c->MakeShape(output_dims)); - return OkStatus(); + return absl::OkStatus(); } REGISTER_OP("XlaDot") @@ -398,7 +398,7 @@ REGISTER_OP("XlaDynamicSlice") return UnchangedRank(c); } c->set_output(0, size_indices_value); - return OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Wraps the XLA DynamicSlice operator, documented at @@ -556,7 +556,7 @@ REGISTER_OP("XlaPad") } c->set_output(0, c->MakeShape(output_dims)); - return OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Wraps the XLA Pad operator, documented at @@ -587,7 +587,7 @@ REGISTER_OP("XlaRecv") shape_inference::ShapeHandle s; TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s)); c->set_output(0, s); - return OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Receives the named tensor from another XLA computation. Wraps the XLA Recv @@ -630,7 +630,7 @@ REGISTER_OP("XlaReduce") } else { c->set_output(0, c->input(0)); } - return OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Wraps the XLA Reduce operator, documented at @@ -684,7 +684,7 @@ REGISTER_OP("XlaVariadicReduce") c->set_output(i, c->input(i)); } } - return OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Wraps the variadic XLA Reduce operator. @@ -768,7 +768,7 @@ REGISTER_OP("XlaVariadicReduceV2") for (int i = 0; i < nr_inputs; ++i) { c->set_output(i, output_shape); } - return OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Wraps the variadic XLA Reduce operator. @@ -828,7 +828,7 @@ REGISTER_OP("XlaRngBitGenerator") shape_inference::ShapeHandle output; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &output)); c->set_output(1, output); - return OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Stateless PRNG bit generator. @@ -912,7 +912,7 @@ REGISTER_OP("XlaKeyValueSort") .SetShapeFn([](shape_inference::InferenceContext* c) { c->set_output(0, c->input(0)); c->set_output(1, c->input(1)); - return OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Wraps the XLA Sort operator, documented at @@ -938,7 +938,7 @@ REGISTER_OP("XlaVariadicSort") std::vector input_shapes; TF_RETURN_IF_ERROR(c->input("inputs", &input_shapes)); TF_RETURN_IF_ERROR(c->set_output("outputs", input_shapes)); - return OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Wraps the XLA Sort operator, documented at @@ -1066,7 +1066,7 @@ REGISTER_OP("XlaSpmdFullToShardShape") dims.push_back(c->MakeDim(dim)); } c->set_output(0, c->MakeShape(dims)); - return OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( An op used by XLA SPMD partitioner to switch from automatic partitioning to @@ -1092,7 +1092,7 @@ REGISTER_OP("XlaSpmdShardToFullShape") shape_inference::ShapeHandle s; TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s)); c->set_output(0, s); - return OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( An op used by XLA SPMD partitioner to switch from manual partitioning to @@ -1119,7 +1119,7 @@ REGISTER_OP("XlaReplicaId") .Output("id: int32") .SetShapeFn([](shape_inference::InferenceContext* context) { context->set_output(0, context->MakeShape({})); - return OkStatus(); + return absl::OkStatus(); }) .Doc("Replica ID."); @@ -1212,7 +1212,7 @@ Status OptimizationBarrierShape(shape_inference::InferenceContext* c) { for (int i = 0; i < c->num_inputs(); ++i) { c->set_output(i, c->input(i)); } - return OkStatus(); + return absl::OkStatus(); } REGISTER_OP("XlaOptimizationBarrier") @@ -1258,7 +1258,7 @@ REGISTER_OP("XlaCustomCall") shape_inference::ShapeHandle s; TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s)); c->set_output(0, s); - return OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Wraps the XLA CustomCall operator @@ -1293,7 +1293,7 @@ REGISTER_OP("XlaCustomCallV2") TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shapes[i], &shape)); c->set_output(i, shape); } - return OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Emits an HLO `CustomCall` operation with multiple outputs. @@ -1346,7 +1346,7 @@ REGISTER_OP("XlaCallModule") << "] : " << c->DebugString(s); c->set_output(i, s); } - return OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Invokes a StableHLO module. @@ -1365,7 +1365,7 @@ version: Tracks changes the semantics of the op, to support backwards the op carries a StableHLO module with compatibility guarantees. From version 5, XLACallModule can include `stablehlo.custom_call` op to execute tf functions. From version 6 the op supports the `disabled_checks` attribute. - See more versioning details at https://github.com/search?q=repo%3Atensorflow%2Ftensorflow+path%3Axla_call_module+%22int+VERSION_MAXIMUM_SUPPORTED%22&type=code. + See more versioning details at https://github.com/search?q=repo%3Atensorflow%2Ftensorflow+path%3Axla_call_module+%22int+kVersionMaximumSupported%22&type=code. module: A serialized computation, a text or bytecode representation of an mlir.Module. The return type must be a tuple if and only if the `Sout` is a list with 0 or more than 1 elements. The length of `Tout` and diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 27940b7fb92c17..5846013128611c 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -667,7 +667,7 @@ def call_module_maximum_supported_version(): """Maximum version of XlaCallModule op supported. See versioning details documentation for the XlaCallModule op at: - https://github.com/search?q=repo%3Atensorflow%2Ftensorflow+path%3Axla_call_module+%22int+VERSION_MAXIMUM_SUPPORTED%22&type=code + https://github.com/search?q=repo%3Atensorflow%2Ftensorflow+path%3Axla_call_module+%22int+kVersionMaximumSupported%22&type=code """ return 9 diff --git a/tensorflow/compiler/tf2xla/rearrange_function_argument.cc b/tensorflow/compiler/tf2xla/rearrange_function_argument.cc index b29d73adffb0a4..a081fa18891ba2 100644 --- a/tensorflow/compiler/tf2xla/rearrange_function_argument.cc +++ b/tensorflow/compiler/tf2xla/rearrange_function_argument.cc @@ -66,7 +66,7 @@ Status InputTypesNeedsRearrange(const std::vector& in_types, if (first_resource_index == -1) { // No resource input. No need to rewrite. *need_rewrite = false; - return OkStatus(); + return absl::OkStatus(); } *need_rewrite = false; @@ -77,7 +77,7 @@ Status InputTypesNeedsRearrange(const std::vector& in_types, } } if (!*need_rewrite) { - return OkStatus(); + return absl::OkStatus(); } *resource_input_count = 0; @@ -100,7 +100,7 @@ Status InputTypesNeedsRearrange(const std::vector& in_types, } } - return OkStatus(); + return absl::OkStatus(); } // Given mapping between original input index and rearranged input index, @@ -122,7 +122,7 @@ Status ReorderInputEdges(Graph* g, Node* n, g->RemoveEdge(e); g->AddEdge(src, src_output, n, new_dst_input)->DebugString(); } - return OkStatus(); + return absl::OkStatus(); } // For While node, given mapping between original input index and rearranged @@ -154,7 +154,7 @@ Status ReorderOutputEdges(Graph* g, Node* n, int input_count, g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input); } } - return OkStatus(); + return absl::OkStatus(); } // Given mapping between original input index and rearranged input index, change @@ -203,7 +203,7 @@ Status CalculateRetvalRearrange( TF_RETURN_IF_ERROR(GetNodeAttr(arg->def(), "index", &src_index)); resource_retval_to_arg->insert(std::make_pair(i, src_index)); } - return OkStatus(); + return absl::OkStatus(); } // Given original output types and return value index mapping, return the new @@ -252,7 +252,7 @@ Status RearrangeOutputEdges(Node* n, Graph* g, g->AddEdge(n, iter->second, dst, dst_input); } } - return OkStatus(); + return absl::OkStatus(); } // Given mapping between original output index and rearranged output index, @@ -287,7 +287,7 @@ Status MaybeRewriteWhileNode( types, &input_need_rearrange, &resource_input_count, &index_mapping)); if (!input_need_rearrange) { *node_rewritten = false; - return OkStatus(); + return absl::OkStatus(); } *node_rewritten = true; @@ -379,7 +379,7 @@ Status MaybeRewriteWhileNode( n->ClearAttr(attr_name); n->AddAttr(attr_name, attr_value); } - return OkStatus(); + return absl::OkStatus(); } Status MaybeRewriteIfNode( @@ -403,7 +403,7 @@ Status MaybeRewriteIfNode( DT_RESOURCE) != out_types.end(); if (!input_need_rearrange && !has_resource_output) { *node_rewritten = false; - return OkStatus(); + return absl::OkStatus(); } *node_rewritten = true; @@ -514,7 +514,7 @@ Status MaybeRewriteIfNode( n->ClearAttr("Tout"); n->AddAttr("Tout", new_out_types); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -557,7 +557,7 @@ Status RearrangeFunctionArguments( } } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/resource_util.cc b/tensorflow/compiler/tf2xla/resource_util.cc index e91ce07e6d6983..4180b8f1330bcd 100644 --- a/tensorflow/compiler/tf2xla/resource_util.cc +++ b/tensorflow/compiler/tf2xla/resource_util.cc @@ -104,7 +104,7 @@ Status PropagateFromArgOp( int index; TF_RETURN_IF_ERROR(GetNodeAttr(n.attrs(), "index", &index)); - if (!resource_arg_indices.contains(index)) return OkStatus(); + if (!resource_arg_indices.contains(index)) return absl::OkStatus(); TF_RET_CHECK(function_name.has_value()) << "ResourceUsageAnalysis does not support analyzing _Arg nodes " @@ -122,7 +122,7 @@ Status PropagateFromArgOp( (*user_to_source)[o] = src_node_info; } - return OkStatus(); + return absl::OkStatus(); } Status UpdateResourceUsageFromFunctionBodyAnalysis( @@ -176,7 +176,7 @@ Status UpdateResourceUsageFromFunctionBodyAnalysis( } } - return OkStatus(); + return absl::OkStatus(); } Status PropagateThroughCallOp( @@ -219,7 +219,7 @@ Status PropagateThroughCallOp( TF_RETURN_IF_ERROR(UpdateResourceUsageFromFunctionBodyAnalysis( n, function_name, *fbody, called_function_source_to_path, user_to_source, source_to_path)); - return OkStatus(); + return absl::OkStatus(); } // Analyzes pass through values for Identity and IdentityN ops. @@ -246,7 +246,7 @@ Status PropagateThroughIdentityOp( } } - return OkStatus(); + return absl::OkStatus(); } Status AnalyzeResourceUsage( @@ -313,7 +313,7 @@ Status AnalyzeResourceUsage( it.first->dst()->type_string()); } - return OkStatus(); + return absl::OkStatus(); } } // anonymous namespace diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc index b9093c2105cd0a..e01d1b919f9699 100644 --- a/tensorflow/compiler/tf2xla/shape_util.cc +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -45,7 +45,7 @@ Status PopulateInfeedLayoutVector(const xla::Shape& shape, } else { layouts->insert(layouts->end(), shape.rank(), -1); } - return OkStatus(); + return absl::OkStatus(); } // Populate the output layout unless the minor_to_major array contains all -1 @@ -83,7 +83,7 @@ Status AssignLayout( layout = layout_func(*shape); } *shape->mutable_layout() = layout; - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -100,7 +100,7 @@ Status XLAShapeToTensorShape(const xla::Shape& shape, for (int i = 0; i < shape.rank(); ++i) { TF_RETURN_IF_ERROR(tensor_shape->AddDimWithStatus(shape.dimensions(i))); } - return OkStatus(); + return absl::OkStatus(); } // Convert a TensorShape into the equivalent XLA Shape proto. @@ -110,7 +110,7 @@ Status TensorShapeToXLAShape(DataType dtype, xla::PrimitiveType type; TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type)); *shape = TensorShapeToXLAShape(type, tensor_shape); - return OkStatus(); + return absl::OkStatus(); } Status TensorShapeToBoundedXLAShape(DataType dtype, @@ -122,7 +122,7 @@ Status TensorShapeToBoundedXLAShape(DataType dtype, if (tensor_shape.unknown_rank()) { // For unknown shape, create a rank 1 size 0 tensor. *shape = xla::ShapeUtil::MakeShapeWithDenseLayout(type, {0}, {0}); - return OkStatus(); + return absl::OkStatus(); } if (tensor_shape.dims() != bound.dims()) { @@ -157,7 +157,7 @@ Status TensorShapeToBoundedXLAShape(DataType dtype, } } *shape = result; - return OkStatus(); + return absl::OkStatus(); } xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, @@ -190,7 +190,7 @@ Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape, xla::PrimitiveType type; TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type)); *shape = TensorShapeToXLAShape(type, tensor_shape); - return OkStatus(); + return absl::OkStatus(); } StatusOr TensorShapeToXLAShape(DataType dtype, @@ -272,7 +272,7 @@ Status GetShapeWithLayout( VLOG(4) << "Shape[] = " << xla::ShapeUtil::HumanStringWithLayout(*output_shape); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc index 1a62857c537cc1..9446e4b4adadb9 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.cc +++ b/tensorflow/compiler/tf2xla/side_effect_util.cc @@ -74,7 +74,7 @@ Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal) { return errors::Internal("Unknown node type to set 'device_ordinal': ", node->DebugString()); } - return OkStatus(); + return absl::OkStatus(); } std::set CalculateTokenInputsForOutputToken(const Graph& g) { @@ -143,7 +143,7 @@ Status ParseHostComputeCoreList(absl::Span list_from_attr, } (*host_compute_core)[parts[0]] = core; } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/test_util.cc b/tensorflow/compiler/tf2xla/test_util.cc index 0fbf983058ca32..3fb8523ce71e0c 100644 --- a/tensorflow/compiler/tf2xla/test_util.cc +++ b/tensorflow/compiler/tf2xla/test_util.cc @@ -37,7 +37,7 @@ Status InstantiateFunctionForTest(const string& name, for (NodeDef& n : inst.nodes) { *result->gdef.add_node() = std::move(n); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 61768b0ff5557e..ef87b320cdcd0c 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -125,7 +125,7 @@ Status ConvertGraphToXla(std::unique_ptr graph, ++input_index; } } - return OkStatus(); + return absl::OkStatus(); } Status ConvertVarHandlesToAotVarHandles(GraphDef* graph_def) { @@ -141,7 +141,7 @@ Status ConvertVarHandlesToAotVarHandles(GraphDef* graph_def) { node.mutable_attr()->erase("allowed_devices"); } } - return OkStatus(); + return absl::OkStatus(); }; for (auto& node : *graph_def->mutable_node()) { TF_RETURN_IF_ERROR(update_var_handle_op_node(node)); @@ -151,7 +151,7 @@ Status ConvertVarHandlesToAotVarHandles(GraphDef* graph_def) { TF_RETURN_IF_ERROR(update_var_handle_op_node(node)); } } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -164,7 +164,7 @@ Status ConvertGraphDefToXla(GraphDef graph_def, const tf2xla::Config& config, TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph)); TF_RETURN_IF_ERROR( ConvertGraphToXla(std::move(graph), config, client, computation)); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index a43ad91c20d828..ad5de83c814d4c 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -55,7 +55,7 @@ Status ValidateTensorId(const tf2xla::TensorId& id) { if (id.output_index() < 0) { return errors::InvalidArgument("TensorId output_index must be positive"); } - return OkStatus(); + return absl::OkStatus(); } Status CheckNameDuplicates(const string& kind, const string& name, @@ -65,7 +65,7 @@ Status CheckNameDuplicates(const string& kind, const string& name, return errors::InvalidArgument("duplicate ", kind, " name: ", name); } } - return OkStatus(); + return absl::OkStatus(); } Status CheckFeedFetchNameConflicts(const string& kind, @@ -79,7 +79,7 @@ Status CheckFeedFetchNameConflicts(const string& kind, " and ", name_data); } } - return OkStatus(); + return absl::OkStatus(); } // For graph `g`, copy all function call nodes' FunctionDef from `lookup_fld` to @@ -108,7 +108,7 @@ Status CopyAssociatedFunctions(Graph* g, } } } - return OkStatus(); + return absl::OkStatus(); } // Replaces the single edge feeding into {dst,dst_input} with a new @@ -162,7 +162,7 @@ Status ReplaceSrcOutputUsageWithNode(Graph* g, Node* src, int src_output, } } } - return OkStatus(); + return absl::OkStatus(); } // For graph `g`, replaces _Arg nodes whose "index" attribute is in @@ -190,7 +190,7 @@ Status ReplaceArgUsageWithConstNode( TF_RETURN_IF_ERROR( ReplaceSrcOutputUsageWithNode(g, arg_node, 0, const_node)); } - return OkStatus(); + return absl::OkStatus(); } // Replaces the single input to _Retval nodes with an index in the keys of @@ -220,7 +220,7 @@ Status ReplaceRetvalInputWithArg( ReplaceEdge(g, ret_nodes[arg_index], 0, arg_nodes[arg_index], 0) .status()); } - return OkStatus(); + return absl::OkStatus(); } // For a node's function attr (e.g. then/else branch for "If" nodes), rewrites @@ -276,7 +276,7 @@ Status PropagateConstIntoFuncAttr( // Copy associated functions. TF_RETURN_IF_ERROR(CopyAssociatedFunctions(func_graph, lookup_fld, fld)); - return OkStatus(); + return absl::OkStatus(); } // For an "If" node in graph `g`, if it has Const node inputs, rewrite its @@ -295,7 +295,7 @@ Status PropagateConstIntoIfNode(Graph* g, Node* if_node, } } if (const_input_index_to_node.empty()) { - return OkStatus(); + return absl::OkStatus(); } // Rewrite "then_branch" and "else_branch" function, replace usage of those @@ -306,7 +306,7 @@ Status PropagateConstIntoIfNode(Graph* g, Node* if_node, if_node, attr_name, const_input_index_to_node, lookup_fld, fld)); } - return OkStatus(); + return absl::OkStatus(); } using GraphCache = absl::flat_hash_map>; @@ -456,7 +456,7 @@ Status PropagateConstIntoAndAroundWhileNode( const_input_index_to_node[i] = input_edge->src(); } if (const_input_index_to_node.empty()) { - return OkStatus(); + return absl::OkStatus(); } // Rewrite "cond" and "body" function, replace usage of those _Arg nodes with @@ -473,7 +473,7 @@ Status PropagateConstIntoAndAroundWhileNode( TF_RETURN_IF_ERROR( ReplaceSrcOutputUsageWithNode(g, while_node, it.first, it.second)); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -502,7 +502,7 @@ Status ValidateConfig(const tf2xla::Config& config) { if (config.fetch().empty()) { return errors::InvalidArgument("fetches must be specified"); } - return OkStatus(); + return absl::OkStatus(); } Status AddPlaceholdersForFeeds( @@ -599,7 +599,7 @@ Status AddPlaceholdersForFeeds( } } - return OkStatus(); + return absl::OkStatus(); } Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in, @@ -664,7 +664,7 @@ Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in, *out->add_node() = node; } } - return OkStatus(); + return absl::OkStatus(); } string TensorIdToString(const tf2xla::TensorId& id) { @@ -695,7 +695,7 @@ Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) { n->set_assigned_device_name(matching_node->assigned_device_name()); n->set_requested_device(matching_node->requested_device()); } - return OkStatus(); + return absl::OkStatus(); } void AddDtypeToKernelDefConstraint(absl::string_view name, DataType dtype, @@ -858,7 +858,7 @@ Status RewriteAssociatedFunction( } } - return OkStatus(); + return absl::OkStatus(); } Status CachedFunctionHandles::GetOrInstantiate( @@ -868,12 +868,12 @@ Status CachedFunctionHandles::GetOrInstantiate( auto iter = handles_.find(canonicalized_name); if (iter != handles_.end()) { *handle = iter->second; - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR(flr_->Instantiate(func_name, attrs, handle)); handles_[canonicalized_name] = *handle; - return OkStatus(); + return absl::OkStatus(); } Status CachedFunctionHandles::ReleaseAllHandles() { @@ -965,7 +965,7 @@ Status PropagateConstIntoFunctionalNodes( } } } - return OkStatus(); + return absl::OkStatus(); } Status PruneUnreachableFunctionsFromGraph(const Graph& g, @@ -979,7 +979,7 @@ Status PruneUnreachableFunctionsFromGraph(const Graph& g, TF_RETURN_IF_ERROR(fld->RemoveFunction(func_name)); } } - return OkStatus(); + return absl::OkStatus(); } Status RewriteTensorListWithConstElement(Graph* g, @@ -1116,7 +1116,7 @@ Status RewriteTensorListWithConstElement(Graph* g, bwd_while->ClearAttr("body"); bwd_while->AddAttr("body", bwd_body_attr); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index d04dbd314a4931..d1ea22324c7e8c 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -41,7 +41,7 @@ namespace tensorflow { namespace { void ExpectErrorContains(const Status& status, absl::string_view str) { - EXPECT_NE(OkStatus(), status); + EXPECT_NE(absl::OkStatus(), status); EXPECT_TRUE(absl::StrContains(status.message(), str)) << "expected error: " << status.message() << " to contain: " << str; } diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc index 335cfdf37b1605..8221aa28b27b0b 100644 --- a/tensorflow/compiler/tf2xla/type_util.cc +++ b/tensorflow/compiler/tf2xla/type_util.cc @@ -28,66 +28,66 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { switch (data_type) { case tensorflow::DT_BOOL: *type = xla::PRED; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_INT4: *type = xla::S4; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_INT8: case tensorflow::DT_QINT8: *type = xla::S8; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_INT16: case tensorflow::DT_QINT16: *type = xla::S16; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_INT32: case tensorflow::DT_QINT32: *type = xla::S32; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_INT64: *type = xla::S64; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_UINT4: *type = xla::U4; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_UINT8: case tensorflow::DT_QUINT8: *type = xla::U8; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_UINT16: case tensorflow::DT_QUINT16: *type = xla::U16; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_UINT32: *type = xla::U32; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_UINT64: *type = xla::U64; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_FLOAT8_E5M2: *type = xla::F8E5M2; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_FLOAT8_E4M3FN: *type = xla::F8E4M3FN; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_BFLOAT16: *type = xla::BF16; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_HALF: *type = xla::F16; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_FLOAT: *type = xla::F32; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_DOUBLE: *type = xla::F64; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_COMPLEX64: *type = xla::C64; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_COMPLEX128: *type = xla::C128; - return OkStatus(); + return absl::OkStatus(); default: return errors::InvalidArgument( "Unsupported type in DataTypeToPrimitiveType: '", diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index 240e4bb1a78ceb..bbdf5c7d2c74fa 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -140,7 +140,7 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel, VLOG(4) << "Done"; } -Status XlaCompilationDevice::Sync() { return OkStatus(); } +Status XlaCompilationDevice::Sync() { return absl::OkStatus(); } Status XlaCompilationDevice::MakeTensorFromProto( const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 5b40e7c5b968a4..1ab4f2fd3a9a81 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -100,7 +100,7 @@ Status CheckSignature(const DataTypeVector& types, " but function parameter has type ", DataTypeString(types[i])); } } - return OkStatus(); + return absl::OkStatus(); } // Uses the _Arg and _Retval nodes in the graph to determine an OpSharding for @@ -445,7 +445,7 @@ Status BuildComputation( TF_ASSIGN_OR_RETURN(auto program_shape, computation->GetProgramShape()); *output_shape = program_shape.result(); - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -519,7 +519,7 @@ string XlaCompiler::Argument::ShapeHumanString() const { XlaCompiler::XlaCompiler(XlaCompiler::Options options) : options_(options), - initialization_status_(OkStatus()), + initialization_status_(absl::OkStatus()), next_step_id_(1), device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)), device_mgr_(absl::WrapUnique(device_)) { @@ -572,7 +572,7 @@ static Status GetFunctionBody(const NameAttrList& function, *fbody = flib_runtime->GetFunctionBody(handle); TF_RET_CHECK(*fbody); - return OkStatus(); + return absl::OkStatus(); } Status XlaCompiler::FindFunctionBody(const NameAttrList& function, @@ -599,7 +599,7 @@ Status XlaCompiler::FindFunctionBody(const NameAttrList& function, } VLOG(4) << "Function " << function.name() << " in local_flib_runtime_"; } - return OkStatus(); + return absl::OkStatus(); } std::unique_ptr XlaCompiler::GetGraph(const FunctionBody* fbody) { @@ -816,7 +816,7 @@ Status XlaCompiler::CompileFunction( auto it = cache_.find({function_id, arg_vector}); if (it != cache_.end()) { *result = it->second; - return OkStatus(); + return absl::OkStatus(); } const FunctionBody* fbody; @@ -928,7 +928,7 @@ Status XlaCompiler::CompileFunction( VLOG(1) << "===================================================="; cache_[{function_id, arg_vector}] = *result; - return OkStatus(); + return absl::OkStatus(); } // Computes the XLA shape for argument 'arg'. @@ -976,12 +976,12 @@ Status XlaCompiler::XLAShapeForArgument( arg.type, std::get(arg.shape), xla_shape)); } } - return OkStatus(); + return absl::OkStatus(); } case XlaCompiler::Argument::kTensorList: { TF_RET_CHECK(absl::holds_alternative(arg.shape)); *xla_shape = std::get(arg.shape); - return OkStatus(); + return absl::OkStatus(); } case XlaCompiler::Argument::kConstantResource: case XlaCompiler::Argument::kResource: { @@ -1001,7 +1001,7 @@ Status XlaCompiler::XLAShapeForArgument( TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape( arg_sharding, arg.fast_mem, options_.shape_determination_fns, xla_shape)); - return OkStatus(); + return absl::OkStatus(); } case XlaResource::kTensorArray: { if (arg.max_array_size < 0) { @@ -1019,7 +1019,7 @@ Status XlaCompiler::XLAShapeForArgument( arg.tensor_array_gradients.size() + 1, *xla_shape); *xla_shape = xla::ShapeUtil::MakeTupleShape(tuple_shape); } - return OkStatus(); + return absl::OkStatus(); } case XlaResource::kStack: { if (arg.max_array_size < 0) { @@ -1035,7 +1035,7 @@ Status XlaCompiler::XLAShapeForArgument( TensorShapeToXLAShape(arg.type, shape, &buffer_shape)); *xla_shape = xla::ShapeUtil::MakeTupleShape( {buffer_shape, xla::ShapeUtil::MakeShape(xla::S32, {})}); - return OkStatus(); + return absl::OkStatus(); } case XlaResource::kInvalid: @@ -1045,7 +1045,7 @@ Status XlaCompiler::XLAShapeForArgument( } case XlaCompiler::Argument::kToken: { *xla_shape = xla::ShapeUtil::MakeTokenShape(); - return OkStatus(); + return absl::OkStatus(); } case XlaCompiler::Argument::kInvalid: return errors::Internal("Invalid argument type in XLAShapeForArgument()"); @@ -1144,7 +1144,7 @@ Status XlaCompiler::BuildArguments( } if (input_to_args->empty() && !use_tuple_arg) { - return OkStatus(); + return absl::OkStatus(); } // `arg_to_inputs[c] = d` means that the c'th original arg index corresponds @@ -1305,7 +1305,7 @@ Status XlaCompiler::BuildArguments( } } - return OkStatus(); + return absl::OkStatus(); } namespace { @@ -1321,7 +1321,7 @@ Status ValidateFunctionDef(const FunctionDef* fdef, const OpDef* op_def; TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(op, &op_def)); } - return OkStatus(); + return absl::OkStatus(); } // If node is PartitionedCall or StatefulPartitionedCall, returns the @@ -1340,10 +1340,10 @@ Status GetPotentialFunctionName(const Node& node, const string** name) { " does not have 'func' field set"); } *name = &attr_value->func().name(); - return OkStatus(); + return absl::OkStatus(); } *name = &node.type_string(); - return OkStatus(); + return absl::OkStatus(); } // Check that the graph doesn't have any invalid nodes (e.g. incompatible with @@ -1379,7 +1379,7 @@ Status ValidateGraph(const Graph* graph, return errors::InvalidArgument(errmsg); } - return OkStatus(); + return absl::OkStatus(); }; for (const Node* node : graph->nodes()) { @@ -1402,7 +1402,7 @@ Status ValidateGraph(const Graph* graph, s = FindKernelDef(device_type, node->def(), nullptr, nullptr); TF_RETURN_IF_ERROR(maybe_error(node, s)); } - return OkStatus(); + return absl::OkStatus(); } void ConvertConstantsToExpressions(xla::XlaBuilder* builder, @@ -1603,10 +1603,28 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, result->input_mapping)); for (const auto& [key, send] : host_compute_sends_) { - *result->host_compute_metadata.add_device_to_host() = send; + auto* d2h = result->host_compute_metadata.add_device_to_host(); + *d2h = send; + + for (int i = 0; i < d2h->metadata_size(); ++i) { + const std::string channel_name = + GetDeviceToHostChannelName(d2h->key(), i); + xla::ChannelHandle handle; + TF_RETURN_IF_ERROR(GetDeviceToHostChannelHandle(channel_name, &handle)); + d2h->mutable_metadata(i)->set_channel_id(handle.handle()); + } } for (const auto& [key, recv] : host_compute_recvs_) { - *result->host_compute_metadata.add_host_to_device() = recv; + auto* h2d = result->host_compute_metadata.add_host_to_device(); + *h2d = recv; + + for (int i = 0; i < h2d->metadata_size(); ++i) { + const std::string channel_name = + GetHostToDeviceChannelName(h2d->key(), i); + xla::ChannelHandle handle; + TF_RETURN_IF_ERROR(GetHostToDeviceChannelHandle(channel_name, &handle)); + h2d->mutable_metadata(i)->set_channel_id(handle.handle()); + } } if (!tsl::tensor_float_32_execution_enabled()) { @@ -1618,7 +1636,7 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, VLOG(2) << "XLA output shape: " << xla::ShapeUtil::HumanStringWithLayout(result->xla_output_shape); result->collective_info = context->GetCollectiveInfo(); - return OkStatus(); + return absl::OkStatus(); } Status XlaCompiler::GetChannelHandle(const string& key, @@ -1629,7 +1647,7 @@ Status XlaCompiler::GetChannelHandle(const string& key, } *channel = result.first->second; VLOG(1) << "Channel: " << key << " " << channel->DebugString(); - return OkStatus(); + return absl::OkStatus(); } Status XlaCompiler::GetHostToDeviceChannelHandle(const string& key, @@ -1641,7 +1659,7 @@ Status XlaCompiler::GetHostToDeviceChannelHandle(const string& key, } *channel = result.first->second; VLOG(1) << "Host to device channel: " << key << " " << channel->DebugString(); - return OkStatus(); + return absl::OkStatus(); } Status XlaCompiler::GetDeviceToHostChannelHandle(const string& key, @@ -1653,7 +1671,7 @@ Status XlaCompiler::GetDeviceToHostChannelHandle(const string& key, } *channel = result.first->second; VLOG(1) << "Device to host channel: " << key << " " << channel->DebugString(); - return OkStatus(); + return absl::OkStatus(); } namespace { @@ -1680,7 +1698,7 @@ Status XlaCompiler::SetDeviceToHostMetadata( tf2xla::HostTransferMetadata new_transfer; SetTransfer(key, types, shapes, &new_transfer); if (xla::protobuf_util::ProtobufEquals(existing_transfer, new_transfer)) { - return OkStatus(); + return absl::OkStatus(); } else { return errors::InvalidArgument( "Duplicate calls to SetDeviceToHostMetadata with key ", key); @@ -1688,7 +1706,7 @@ Status XlaCompiler::SetDeviceToHostMetadata( } tf2xla::HostTransferMetadata& transfer = host_compute_sends_[key]; SetTransfer(key, types, shapes, &transfer); - return OkStatus(); + return absl::OkStatus(); } Status XlaCompiler::GetDeviceToHostShapes( @@ -1703,7 +1721,7 @@ Status XlaCompiler::GetDeviceToHostShapes( TensorShape shape(iter->second.metadata(i).shape()); shapes->push_back(shape); } - return OkStatus(); + return absl::OkStatus(); } Status XlaCompiler::SetHostToDeviceMetadata( @@ -1714,7 +1732,7 @@ Status XlaCompiler::SetHostToDeviceMetadata( tf2xla::HostTransferMetadata new_transfer; SetTransfer(key, types, shapes, &new_transfer); if (xla::protobuf_util::ProtobufEquals(existing_transfer, new_transfer)) { - return OkStatus(); + return absl::OkStatus(); } else { return errors::InvalidArgument( "Duplicate calls to SetHostToDeviceMetadata with key ", key); @@ -1722,7 +1740,7 @@ Status XlaCompiler::SetHostToDeviceMetadata( } tf2xla::HostTransferMetadata& transfer = host_compute_recvs_[key]; SetTransfer(key, types, shapes, &transfer); - return OkStatus(); + return absl::OkStatus(); } Status XlaCompiler::GetHostComputeControlDependency( @@ -1735,7 +1753,7 @@ Status XlaCompiler::GetHostComputeControlDependency( } else { *handle = iter->second; } - return OkStatus(); + return absl::OkStatus(); } Status XlaCompiler::SetHostComputeControlDependency( @@ -1747,7 +1765,7 @@ Status XlaCompiler::SetHostComputeControlDependency( host_compute_name); } host_compute_control_output_[host_compute_name] = handle; - return OkStatus(); + return absl::OkStatus(); } void XlaCompiler::PushNodeTokenMapping() { @@ -1761,7 +1779,7 @@ Status XlaCompiler::PopNodeTokenMapping() { "empty."); } node_token_mapping_stack_.pop(); - return OkStatus(); + return absl::OkStatus(); } Status XlaCompiler::SetNodeToken(const string& node_name, const xla::XlaOp op) { @@ -1775,7 +1793,7 @@ Status XlaCompiler::SetNodeToken(const string& node_name, const xla::XlaOp op) { return errors::FailedPrecondition("Token mapping already exists for node ", node_name); } - return OkStatus(); + return absl::OkStatus(); } StatusOr XlaCompiler::GetNodeToken(const string& node_name) { diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 04f12c7ca575d4..ff444efe752640 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -195,7 +195,7 @@ Status XlaContext::RecordCollectiveInfoFromNestedCompilationResult( result.collective_info->group_size) .status(); } - return OkStatus(); + return absl::OkStatus(); } StatusOr XlaContext::RecordCollectiveInfo(int group_key, diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index dc40bb47e8f8fe..5f99e7f284e26c 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -86,7 +86,7 @@ xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type, *output = input.Clone(); output->mutable_shape_do_not_use()->Swap(&shape); - return OkStatus(); + return absl::OkStatus(); } Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64_t depth, int axis, @@ -110,7 +110,7 @@ Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64_t depth, int axis, xla::Eq(indices, xla::Iota(builder, iota_shape, axis), broadcast_dims), xla::Broadcast(on_value, output_shape.dim_sizes()), xla::Broadcast(off_value, output_shape.dim_sizes())); - return OkStatus(); + return absl::OkStatus(); } DataType XlaHelpers::SumAccumulationType(const DataType& dtype) { @@ -253,7 +253,7 @@ Status ResolveDeviceAssignment( }); run_options.set_device_assignment(&device_assignment); run_options.set_gpu_executable_run_options(&gpu_options); - return OkStatus(); + return absl::OkStatus(); } } // end namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index 0bebf471ecfbe9..ee483af794e26d 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -18,6 +18,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_ #define TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_ +#include + #include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" @@ -33,6 +35,15 @@ namespace tensorflow { using XlaLayoutPreference = mlir::XlaLayoutPreference; +inline std::string GetDeviceToHostChannelName(absl::string_view channel_key, + int index) { + return absl::StrCat(channel_key, "_dtoh_", index); +} +inline std::string GetHostToDeviceChannelName(absl::string_view channel_key, + int index) { + return absl::StrCat(channel_key, "_htod_", index); +} + // Helper methods for building XLA computations. class XlaHelpers { public: diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc index ad65c1708794fd..8e8d3f28d8a47a 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc @@ -25,8 +25,8 @@ limitations under the License. #include "xla/service/platform_util.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/test.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/attr_value.pb.h" @@ -340,8 +340,8 @@ TEST(XlaJitCompiledCpuFunction, CanCompileWithAdditionalPlatform) { string name_; }; - TF_EXPECT_OK(se::MultiPlatformManager::RegisterPlatform( - std::make_unique())); + TF_EXPECT_OK( + se::PlatformManager::RegisterPlatform(std::make_unique())); xla::Compiler::RegisterCompilerFactory(kFakePlatformId, []() { return std::unique_ptr(nullptr); }); diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 237d115aa298e5..f70a7df612d149 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -196,7 +196,7 @@ Status XlaOpKernelContext::ConstantInputReshaped( } TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp)); - return OkStatus(); + return absl::OkStatus(); } // Converts an int16, int32 or int64 scalar literal to an int64. @@ -214,7 +214,7 @@ static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal, } else { return errors::InvalidArgument("value must be int16, int32, or int64"); } - return OkStatus(); + return absl::OkStatus(); } // Converts an float32 or float64 scalar literal to a float64. @@ -230,7 +230,7 @@ static Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal, } else { return errors::InvalidArgument("value must be either float32 or float64"); } - return OkStatus(); + return absl::OkStatus(); } Status XlaOpKernelContext::ConstantInputAsIntScalar( @@ -273,7 +273,7 @@ static Status LiteralToPredVector(const xla::LiteralSlice& literal, for (int64_t i = 0; i < size; ++i) { out->push_back(literal.Get({i})); } - return OkStatus(); + return absl::OkStatus(); } Status XlaOpKernelContext::ResolveInputDynamismIntoPred(int index, bool* out) { @@ -288,7 +288,7 @@ Status XlaOpKernelContext::ResolveInputDynamismIntoPred(int index, bool* out) { // TODO(b/176993339): Support resolving dynamism across computations so // resolving dynamism will not fail in those cases. *out = true; - return OkStatus(); + return absl::OkStatus(); } Tensor dynamism = dynamism_or_status.value(); @@ -302,7 +302,7 @@ Status XlaOpKernelContext::ResolveInputDynamismIntoPred(int index, bool* out) { TF_ASSIGN_OR_RETURN(literal, HostTensorToLiteral(temp)); *out = literal.Get({}); - return OkStatus(); + return absl::OkStatus(); } Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector( @@ -332,7 +332,7 @@ Status XlaOpKernelContext::ResolveInputDynamismReshaped( .Broadcast(xla::ShapeUtil::MakeShape(xla::PRED, new_dims), {}) .value(); - return OkStatus(); + return absl::OkStatus(); } Tensor dynamism = dynamism_or_status.value(); @@ -346,7 +346,7 @@ Status XlaOpKernelContext::ResolveInputDynamismReshaped( } TF_ASSIGN_OR_RETURN(*dynamism_literal, HostTensorToLiteral(temp)); - return OkStatus(); + return absl::OkStatus(); } Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector( @@ -377,7 +377,7 @@ static Status LiteralToInt64Vector(const xla::LiteralSlice& literal, } else { return errors::InvalidArgument("value must be either int32 or int64"); } - return OkStatus(); + return absl::OkStatus(); } Status XlaOpKernelContext::ConstantInputAsIntVector( @@ -424,11 +424,11 @@ Status XlaOpKernelContext::ConstantInputAsInt64Literal( for (int64_t i = 0; i < src_data.size(); ++i) { out->data()[i] = src_data[i]; } - return OkStatus(); + return absl::OkStatus(); } case xla::S64: *out = std::move(literal); - return OkStatus(); + return absl::OkStatus(); default: return errors::InvalidArgument( @@ -462,7 +462,7 @@ Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape, ", result: ", num_elements); } *shape = TensorShape(dims); - return OkStatus(); + return absl::OkStatus(); } Status XlaOpKernelContext::ConstantInputAsPartialShape( @@ -478,12 +478,12 @@ Status XlaOpKernelContext::ConstantInputAsPartialShape( "Cannot convert value to PartialTensorShape: ", shape_val); } *shape = PartialTensorShape(); // Shape with unknown rank. - return OkStatus(); + return absl::OkStatus(); } std::vector dims; TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims)); *shape = PartialTensorShape(dims); - return OkStatus(); + return absl::OkStatus(); } Status XlaOpKernelContext::InputList(absl::string_view name, @@ -498,7 +498,7 @@ Status XlaOpKernelContext::InputList(absl::string_view name, XlaExpression::CastExpressionFromTensor(input)->AsXlaOp(builder())); shapes->push_back(input.shape()); } - return OkStatus(); + return absl::OkStatus(); } Status XlaOpKernelContext::ConstantInputList(absl::string_view name, @@ -510,7 +510,7 @@ Status XlaOpKernelContext::ConstantInputList(absl::string_view name, for (int i = start; i < stop; ++i) { TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i], mode)); } - return OkStatus(); + return absl::OkStatus(); } StatusOr XlaOpKernelContext::ConstantInputTensor( @@ -571,7 +571,7 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type, TF_ASSIGN_OR_RETURN(xla::Literal literal, HostTensorToLiteral(*expression->constant_value())); *value = xla::ConstantLiteral(ctx->builder(), literal); - return OkStatus(); + return absl::OkStatus(); } auto shape_determination_fns = ctx->compiler()->options().shape_determination_fns; @@ -590,7 +590,7 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type, } else { *value = xla::Reshape(variable->value(), variable->shape().dim_sizes()); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -625,7 +625,7 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, } *type = variable->type(); *shape = variable->shape(); - return OkStatus(); + return absl::OkStatus(); } void XlaOpKernelContext::SetOutputExpression(int index, @@ -656,7 +656,7 @@ void XlaOpKernelContext::SetOutputExpression(int index, } XlaExpression::AssignExpressionToTensor(expression, context_->mutable_output(index)); - return OkStatus(); + return absl::OkStatus(); }(); if (!status.ok()) { SetStatus(status); @@ -697,7 +697,7 @@ Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { XlaExpression::CastExpressionFromTensor(context_->input(index)); TF_RET_CHECK(expression->resource() != nullptr); *resource = expression->resource(); - return OkStatus(); + return absl::OkStatus(); } namespace { diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 986388ab32b9f2..0109f6a3f07ef3 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -53,7 +53,7 @@ static Status LaunchOpHasKernelForDevice(const DeviceType& device_type) { &kernel_class_name)); VLOG(1) << "LaunchOpHasKernelForDevice" << " kernel_class_name: " << kernel_class_name; - return OkStatus(); + return absl::OkStatus(); } XlaOpRegistry::XlaOpRegistry() = default; @@ -437,7 +437,7 @@ XlaOpRegistry::CompileTimeConstantInputArgNames(const string& op) { compile_time_constant_inputs = CompileTimeConstantInputArgNames(node_def.op()); if (compile_time_constant_inputs->empty()) { - return OkStatus(); + return absl::OkStatus(); } } @@ -470,7 +470,7 @@ XlaOpRegistry::CompileTimeConstantInputArgNames(const string& op) { } absl::c_sort(*result); - return OkStatus(); + return absl::OkStatus(); } /*static*/ bool XlaOpRegistry::IsMetadataOp(const string& op) { diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index d48c97d35c31f5..0e1d33a0c1c718 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -109,7 +109,7 @@ Status XlaResource::SetTypeAndShape(DataType type, const TensorShape& shape) { } type_ = type; shape_ = shape; - return OkStatus(); + return absl::OkStatus(); } Status XlaResource::SetValue(const xla::XlaOp& value) { @@ -120,7 +120,7 @@ Status XlaResource::SetValue(const xla::XlaOp& value) { } value_ = value; is_overwritten_ = true; - return OkStatus(); + return absl::OkStatus(); } Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) { @@ -159,7 +159,7 @@ Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) { default: LOG(FATAL) << "Invalid resource type"; } - return OkStatus(); + return absl::OkStatus(); } Status XlaResource::GetOrCreateTensorArrayGradient(const string& source, @@ -183,7 +183,7 @@ Status XlaResource::GetOrCreateTensorArrayGradient(const string& source, /*tensor_array_multiple_writes_aggregate=*/true)); } *gradient_out = gradient.get(); - return OkStatus(); + return absl::OkStatus(); } Status XlaResource::Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const { @@ -198,7 +198,7 @@ Status XlaResource::Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const { } *pack = xla::Tuple(builder, elems); } - return OkStatus(); + return absl::OkStatus(); } Status XlaResource::SetFromPack(const std::set& gradient_sources, @@ -229,7 +229,7 @@ Status XlaResource::SetFromPack(const std::set& gradient_sources, gradient->value_ = v; } } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 7d51991254c9e6..dde60c6079d704 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -336,10 +336,8 @@ cc_library( ":lib_internal", ":protos_all_cc", "//tensorflow/core/kernels:required", - "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", - "@local_tsl//tsl/platform/default/build_config:gtest", ] + tf_additional_test_deps(), ) @@ -499,7 +497,6 @@ cc_library( "//tensorflow/core/platform:platform_port", "//tensorflow/core/platform:thread_annotations", "//tensorflow/core/platform:mutex", - "@local_tsl//tsl/platform/default/build_config:minimal", "@local_tsl//tsl/framework:fixedpoint_types", "//tensorflow/core/platform:types", ], @@ -647,7 +644,6 @@ cc_library( "//tensorflow/core/kernels/mkl:mkl_batch_matmul_op", "//tensorflow/core/kernels/mkl:mkl_einsum_op", "//tensorflow/core/kernels/mkl:mkl_matmul_op", - "//tensorflow/core/kernels/mkl:mkl_sparse_matrix_matmul_op", "//tensorflow/core/kernels/mkl:mkl_tmp_ops", "//tensorflow/core/kernels/mkl:mkl_deprecated_ops", ]) + if_cuda_or_rocm([ @@ -712,7 +708,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":tensorflow_opensource", - "@local_tsl//tsl/platform/default/build_config:tensorflow_platform_specific", + "//tensorflow/core/platform/default/build_config:tensorflow_platform_specific", ], ) @@ -1143,8 +1139,8 @@ cc_library( deps = [ ":portable_tensorflow_lib", "//tensorflow/core/kernels:portable_tensorflow_kernels", + "@com_google_googletest//:gtest", "@eigen_archive//:eigen3", - "@local_tsl//tsl/platform/default/build_config:gtest", ], ) diff --git a/tensorflow/core/api_def/base_api/api_def_GlobalShuffleDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_GlobalShuffleDataset.pbtxt new file mode 100644 index 00000000000000..8e4a1bdf8ea190 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_GlobalShuffleDataset.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "GlobalShuffleDataset" + visibility: HIDDEN +} diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD index 550f20269532b2..15acd9db5df20f 100644 --- a/tensorflow/core/common_runtime/BUILD +++ b/tensorflow/core/common_runtime/BUILD @@ -254,6 +254,7 @@ filegroup( "collective_param_resolver_local.h", "collective_rma_local.h", "collective_util.h", + "colocate_predecessor_trees_pass.h", "colocation_graph.h", "constant_folding.h", "copy_tensor.h", @@ -1180,6 +1181,28 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "colocate_predecessor_trees_pass", + srcs = ["colocate_predecessor_trees_pass.cc"], + hdrs = ["colocate_predecessor_trees_pass.h"], + copts = tf_copts(), + deps = [ + ":optimization_registry", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:portable_gif_internal", + "//tensorflow/core/framework:node_def_util", + "//tensorflow/core/framework:tensor_proto_cc", + "//tensorflow/core/framework:tensor_shape_proto_cc", + "//tensorflow/core/platform:status", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + cc_library( name = "local_device", srcs = ["local_device.cc"], @@ -1945,6 +1968,7 @@ tf_cuda_library( ":collective_param_resolver_local", ":collective_rma_local", ":collective_util", + ":colocate_predecessor_trees_pass", ":composite_device", ":copy_tensor", ":costmodel_manager", @@ -2340,6 +2364,7 @@ tf_cc_tests( "buf_rendezvous_test.cc", "collective_executor_mgr_test.cc", "collective_rma_local_test.cc", + "colocate_predecessor_trees_pass_test.cc", "device_mgr_test.cc", "device_resolver_local_test.cc", "device_set_test.cc", @@ -3476,28 +3501,6 @@ tf_cc_fuzz_test( ], ) -cc_library( - name = "serving_device_selector", - srcs = ["serving_device_selector.cc"], - hdrs = ["serving_device_selector.h"], - copts = tf_copts(), - deps = [ - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "serving_device_selector_policies", - srcs = ["serving_device_selector_policies.cc"], - hdrs = ["serving_device_selector_policies.h"], - copts = tf_copts(), - features = ["-layering_check"], - deps = [ - ":serving_device_selector", - ], -) - cc_library( name = "device_id_utils", hdrs = ["device_id_utils.h"], diff --git a/tensorflow/core/common_runtime/accumulate_n_optimizer.cc b/tensorflow/core/common_runtime/accumulate_n_optimizer.cc index e399b90428a81d..9b7930ad1ef590 100644 --- a/tensorflow/core/common_runtime/accumulate_n_optimizer.cc +++ b/tensorflow/core/common_runtime/accumulate_n_optimizer.cc @@ -53,7 +53,7 @@ class AccumulateNV2RemovePass : public GraphOptimizationPass { if (options.graph == nullptr) { // TODO(apassos) returning OK feels weird here as we can't do anything // without a graph, but some tests require this. - return OkStatus(); + return absl::OkStatus(); } Graph* g = options.graph->get(); @@ -70,7 +70,7 @@ class AccumulateNV2RemovePass : public GraphOptimizationPass { matches.push_back(n); } } - if (matches.empty()) return OkStatus(); + if (matches.empty()) return absl::OkStatus(); std::vector control_flow_info; TF_RETURN_IF_ERROR(BuildControlFlowInfo(g, &control_flow_info)); @@ -98,7 +98,7 @@ class AccumulateNV2RemovePass : public GraphOptimizationPass { TF_RETURN_IF_ERROR(RewriteIntoTempVariable(n, g)); } } - return OkStatus(); + return absl::OkStatus(); } Status RewriteIntoTempVariable(Node* n, Graph* g) { @@ -226,7 +226,7 @@ class AccumulateNV2RemovePass : public GraphOptimizationPass { // using its incoming/outgoing edge sets. g->RemoveNode(n); - return OkStatus(); + return absl::OkStatus(); } Status RewriteIntoAddN(Node* n, Graph* g) { @@ -281,7 +281,7 @@ class AccumulateNV2RemovePass : public GraphOptimizationPass { // using its incoming/outgoing edge sets. g->RemoveNode(n); - return OkStatus(); + return absl::OkStatus(); } }; REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 10, diff --git a/tensorflow/core/common_runtime/all_to_all.h b/tensorflow/core/common_runtime/all_to_all.h index 0c9716da6a108b..38bfd3ddc2058a 100644 --- a/tensorflow/core/common_runtime/all_to_all.h +++ b/tensorflow/core/common_runtime/all_to_all.h @@ -34,7 +34,7 @@ class AllToAll : public CollectiveImplementationInterface { void Run(StatusCallback done) override; Status InitializeCollectiveParams(CollectiveParams* col_params) override { - return OkStatus(); + return absl::OkStatus(); } // Initializes members of CollectiveContext not yet initialized, i.e. device diff --git a/tensorflow/core/common_runtime/arg_ret_placement.cc b/tensorflow/core/common_runtime/arg_ret_placement.cc index 0467a36456fd82..a995564c8c2964 100644 --- a/tensorflow/core/common_runtime/arg_ret_placement.cc +++ b/tensorflow/core/common_runtime/arg_ret_placement.cc @@ -73,7 +73,7 @@ Status CheckMemoryType(bool use_host_memory, const FullTypeDef& ft) { " but full type information is\n", ft.DebugString()); } - return OkStatus(); + return absl::OkStatus(); } // Note that ints_on_device is only true for single device functions @@ -152,7 +152,7 @@ static Status SetMemoryTypeForNode( aa.set_on_host(mt_from_dtype == HOST_MEMORY); alloc_attrs->push_back(aa); } - return OkStatus(); + return absl::OkStatus(); } // This helper function takes a list of nodes. @@ -169,7 +169,7 @@ static Status SetMemoryTypeHelper( weak_flag, /*ints_on_device=*/false, memory_types, alloc_attrs)); } - return OkStatus(); + return absl::OkStatus(); } // This helper function takes a list of pairs that contain an arg node. @@ -192,7 +192,7 @@ static Status SetMemoryTypeHelper( arg.first, dtype, /*is_arg=*/true, weak_flag, ints_on_device, /*memory_types=*/nullptr, alloc_attrs)); } - return OkStatus(); + return absl::OkStatus(); } // This helper function takes a list of pairs that contain a ret node. @@ -214,7 +214,7 @@ static Status SetMemoryTypeHelper( ret.first, dtype, /*is_arg=*/false, weak_flag, ints_on_device, /*memory_types=*/nullptr, alloc_attrs)); } - return OkStatus(); + return absl::OkStatus(); } Status SetMemoryTypeForArgs(const gtl::InlinedVector& nodes, diff --git a/tensorflow/core/common_runtime/buf_rendezvous.cc b/tensorflow/core/common_runtime/buf_rendezvous.cc index bd0d21dd2cf95d..8bd8a2c1a10ec6 100644 --- a/tensorflow/core/common_runtime/buf_rendezvous.cc +++ b/tensorflow/core/common_runtime/buf_rendezvous.cc @@ -148,7 +148,7 @@ void BufRendezvous::ProvideBuf(const string& key, Device* dev, DVLOG(4) << "ProvideBuf: key = " << key << ": calling cons_cb" << h->DebugString(); DeregisterCancellation(h); - h->cons_cb(OkStatus(), h); + h->cons_cb(absl::OkStatus(), h); } if (!providebuf_status.ok()) { done(providebuf_status); @@ -224,7 +224,7 @@ void BufRendezvous::ConsumeBuf(const string& key, const string& device_name, DVLOG(4) << "ConsumeBuf: key = " << key << ": calling cons_cb" << existing_hook->DebugString(); DeregisterCancellation(existing_hook); - existing_hook->cons_cb(OkStatus(), existing_hook); + existing_hook->cons_cb(absl::OkStatus(), existing_hook); return; } if (!consumebuf_status.ok()) { @@ -257,7 +257,7 @@ void BufRendezvous::CancelHook(const string& key) { /*static*/ void BufRendezvous::DoneWithHook(Hook* h) { - h->prod_cb(OkStatus()); + h->prod_cb(absl::OkStatus()); delete h; } diff --git a/tensorflow/core/common_runtime/buf_rendezvous_test.cc b/tensorflow/core/common_runtime/buf_rendezvous_test.cc index 5e2c6172aa4408..1bbb828f32c4e6 100644 --- a/tensorflow/core/common_runtime/buf_rendezvous_test.cc +++ b/tensorflow/core/common_runtime/buf_rendezvous_test.cc @@ -35,7 +35,7 @@ class BufRendezvousTest : public ::testing::Test { public: explicit FakeDevice(const DeviceAttributes& attrs) : Device(nullptr, attrs) {} - Status Sync() override { return OkStatus(); } + Status Sync() override { return absl::OkStatus(); } Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; } }; DeviceAttributes attrs; diff --git a/tensorflow/core/common_runtime/collective_executor_mgr.cc b/tensorflow/core/common_runtime/collective_executor_mgr.cc index f776fa5e8960ec..7a10db2649d4d8 100644 --- a/tensorflow/core/common_runtime/collective_executor_mgr.cc +++ b/tensorflow/core/common_runtime/collective_executor_mgr.cc @@ -81,6 +81,17 @@ void CollectiveExecutorMgr::Cleanup(int64_t step_id) { if (ce) ce->Unref(); } +void CollectiveExecutorMgr::CleanupAll() { + gtl::FlatMap executor_table; + { + mutex_lock l(exec_mu_); + std::swap(executor_table, executor_table_); + } + for (auto iter : executor_table) { + iter.second->Unref(); + } +} + void CollectiveExecutorMgr::GetStepSequenceAsync( const GetStepSequenceRequest* request, GetStepSequenceResponse* response, const StatusCallback& done) { diff --git a/tensorflow/core/common_runtime/collective_executor_mgr.h b/tensorflow/core/common_runtime/collective_executor_mgr.h index d480db03a46609..dddaa7ae9c942a 100644 --- a/tensorflow/core/common_runtime/collective_executor_mgr.h +++ b/tensorflow/core/common_runtime/collective_executor_mgr.h @@ -37,6 +37,8 @@ class CollectiveExecutorMgr : public CollectiveExecutorMgrInterface { void Cleanup(int64_t step_id) override; + void CleanupAll() override; + ParamResolverInterface* GetParamResolver() const override { return param_resolver_.get(); } diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc index cebcee53ba35a7..46de7b68645064 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc +++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc @@ -127,7 +127,7 @@ Status CheckUserSpecifiedRanks(const std::vector members) { "Duplicate ranks specified for group members. Received ranks: ", received_ranks); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -644,7 +644,7 @@ Status CollectiveParamResolverLocal::LookupGroup(int32_t group_key, group_rec->second->status.ToString()); } *group = group_rec->second->group; - return OkStatus(); + return absl::OkStatus(); } void CollectiveParamResolverLocal::CompleteParamsAsync( diff --git a/tensorflow/core/common_runtime/collective_rma_local.cc b/tensorflow/core/common_runtime/collective_rma_local.cc index d0d6c4f35324a2..4c968f703af615 100644 --- a/tensorflow/core/common_runtime/collective_rma_local.cc +++ b/tensorflow/core/common_runtime/collective_rma_local.cc @@ -157,7 +157,7 @@ void CollectiveRemoteAccessLocal::MemCpyAsync( int64_t bytes = src->TotalBytes(); DCHECK_EQ(dst->TotalBytes(), bytes); memcpy(DMAHelper::base(dst), DMAHelper::base(src), bytes); - done(OkStatus()); + done(absl::OkStatus()); } } diff --git a/tensorflow/core/common_runtime/colocate_predecessor_trees_pass.cc b/tensorflow/core/common_runtime/colocate_predecessor_trees_pass.cc new file mode 100644 index 00000000000000..54382f3d384738 --- /dev/null +++ b/tensorflow/core/common_runtime/colocate_predecessor_trees_pass.cc @@ -0,0 +1,198 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/colocate_predecessor_trees_pass.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/util/device_name_utils.h" +#include "tensorflow/core/util/dump_graph.h" +#include "tsl/util/device_name_utils.h" + +namespace tensorflow { +namespace { + +constexpr absl::string_view kClassAttr = "_class"; + +// Check if the node is a valid tree node. Noticed this node must not be the +// root of the tree. We find root of the tree in other place. +// For a valid tree node, it must +// 1. not a arg node +// 2. not have device attr +// 3. not have colocation attr +// 4. must register for CPU +// 5. only have one output node +bool IsValidTreeNode(const Node& node, bool in_node_mode) { + if (node.IsArg()) { + return false; + } + if (node.has_assigned_device_name()) { + return false; + } + if (HasNodeAttr(node.def(), kClassAttr)) { + return false; + } + if (!KernelDefAvailable(DeviceType(DEVICE_CPU), node.def())) { + return false; + } + + int num_parents_to_tree_nodes = 0; + auto parent_nodes = in_node_mode ? node.out_nodes() : node.in_nodes(); + for (auto parent_node : parent_nodes) { + if (in_node_mode && (parent_node->IsExit() || parent_node->IsSink())) + continue; + if (!in_node_mode && parent_node->IsSource()) continue; + num_parents_to_tree_nodes++; + } + if (num_parents_to_tree_nodes != 1) return false; + return true; +} + +// Check if the node is potential root node. For a valid root node, it must +// 1. have device attr +// 2. not a arg node +// 3. must register for CPU has device type must be CPU +// 4. the output node can only be exit or sink node +bool IsPotentialRootNode(const Node& node) { + if (!node.has_assigned_device_name()) { + return false; + } + auto device_name = node.assigned_device_name(); + DeviceNameUtils::ParsedName parsed_device_name; + DeviceNameUtils::ParseFullName(device_name, &parsed_device_name); + if (parsed_device_name.type != DEVICE_CPU) { + return false; + } + if (node.IsArg()) { + return false; + } + if (!KernelDefAvailable(DeviceType(DEVICE_CPU), node.def())) { + return false; + } + return true; +} + +// Find all tree nodes for the root node. Otherwise, return false. +std::optional> FindTreeNodes(Node* potential_root) { + absl::flat_hash_set tree_nodes; + tree_nodes.insert(potential_root); + + auto seek_tree_nodes = [&](bool in_node_mode) { + std::queue pending_nodes; + auto nodes_to_potential_nodes = + in_node_mode ? potential_root->in_nodes() : potential_root->out_nodes(); + for (Node* node : nodes_to_potential_nodes) { + if (in_node_mode && node->IsSource()) continue; + if (!in_node_mode && (node->IsSink() || node->IsExit())) continue; + pending_nodes.push(node); + } + while (!pending_nodes.empty()) { + Node* node = pending_nodes.front(); + pending_nodes.pop(); + if (tree_nodes.find(node) != tree_nodes.end()) { + return false; + } + if (!IsValidTreeNode(*node, in_node_mode)) { + return false; + } + tree_nodes.insert(node); + auto nodes_to_potential_node = + in_node_mode ? node->in_nodes() : node->out_nodes(); + for (Node* node : nodes_to_potential_node) { + if (in_node_mode && node->IsSource()) continue; + if (!in_node_mode && (node->IsSink() || node->IsExit())) continue; + pending_nodes.push(node); + } + } + return true; + }; + + if (!seek_tree_nodes(/*in_node_mode=*/true) || + !seek_tree_nodes(/*in_node_mode=*/false)) { + return std::nullopt; + } + return tree_nodes; +} + +// Propagate colocation info from root node to each tree nodes. +void PropagateColocationInfo(Node* root_node, + absl::flat_hash_set& tree_nodes) { + std::string colocation_prefix = "loc:@"; + std::string node_name = root_node->name(); + for (auto node : tree_nodes) { + node->AddAttr(std::string(kClassAttr), + absl::StrCat(colocation_prefix, node_name)); + } +} + +} // namespace + +Status ColocatePredecessorTreesPass::Run( + const GraphOptimizationPassOptions& options) { + // find all potential node. + if (options.graph == nullptr) { + VLOG(1) << "No graph in colocate_predecessor_trees_pass.\n"; + return absl::OkStatus(); + } + Graph* graph = options.graph->get(); + if (VLOG_IS_ON(1)) { + VLOG(1) << DumpGraphToFile("before_colocate_predecessor_trees", *graph, + options.flib_def); + } + + absl::flat_hash_map> tree_nodes_map; + for (Node* node : graph->nodes()) { + if (IsPotentialRootNode(*node)) { + std::optional> nodes = FindTreeNodes(node); + if (nodes.has_value()) { + tree_nodes_map[node] = *std::move(nodes); + } + } + } + + for (auto& [root_node, tree_nodes] : tree_nodes_map) { + PropagateColocationInfo(root_node, tree_nodes); + } + + if (VLOG_IS_ON(1)) { + VLOG(1) << DumpGraphToFile("after_colocate_predecessor_trees", *graph, + options.flib_def); + } + + return absl::OkStatus(); +} + +// TODO(b/325245805): Fix the bug then register the pass again. +// REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 50, +// ColocatePredecessorTreesPass); + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/colocate_predecessor_trees_pass.h b/tensorflow/core/common_runtime/colocate_predecessor_trees_pass.h new file mode 100644 index 00000000000000..2f1e21aea50b42 --- /dev/null +++ b/tensorflow/core/common_runtime/colocate_predecessor_trees_pass.h @@ -0,0 +1,116 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLOCATE_PREDECESSOR_TREES_PASS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_COLOCATE_PREDECESSOR_TREES_PASS_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" + +// Colocate a tree of unplaced constants with its placed root. Identify a +// dangling tree of ops whose root op is assigned but rest of ops are not +// assigned. Then it should colocate the rest of the ops with the root op. +// +// For example, the graph before pass is: +// +// node { +// name: "const0" +// op: "Const" +// } +// node { +// name: "const1" +// op: "Const" +// } +// node { +// name: "fill0" +// op: "Fill" +// input: "const1" +// input: "const0" +// } +// node { +// name: "id0" +// op: "Identity" +// input: "fill0" +// device: "/job:worker/replica:0/task:2/device:CPU:0" +// } +// +// The graph after pass is: +// +// node { +// name: "const0" +// op: "Const" +// attr { +// key: "_class" +// value { +// list { +// s: "loc:@id0" +// } +// } +// } +// } +// node { +// name: "const1" +// op: "Const" +// attr { +// key: "_class" +// value { +// list { +// s: "loc:@id0" +// } +// } +// } +// } +// node { +// name: "fill0" +// op: "Fill" +// input: "const1" +// input: "const0" +// attr { +// key: "_class" +// value { +// list { +// s: "loc:@id0" +// } +// } +// } +// } +// node { +// name: "id0" +// op: "Identity" +// input: "fill0" +// device: "/job:worker/replica:0/task:2/device:CPU:0" +// attr { +// key: "_class" +// value { +// list { +// s: "loc:@id0" +// } +// } +// } +// } + +namespace tensorflow { + +// This pass can place each tree of unassigned nodes with its root, when the +// root is already assigned to a device. Placement is instructed here with the +// colocation class attribute _class. This is a good heuristic because it +// reduces number of cut edges and tends to load balance. +class ColocatePredecessorTreesPass : public GraphOptimizationPass { + public: + Status Run(const GraphOptimizationPassOptions& options) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_COLOCATE_PREDECESSOR_TREES_PASS_H_ diff --git a/tensorflow/core/common_runtime/colocate_predecessor_trees_pass_test.cc b/tensorflow/core/common_runtime/colocate_predecessor_trees_pass_test.cc new file mode 100644 index 00000000000000..5871b095143db4 --- /dev/null +++ b/tensorflow/core/common_runtime/colocate_predecessor_trees_pass_test.cc @@ -0,0 +1,331 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/colocate_predecessor_trees_pass.h" + +#include +#include + +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/core/common_runtime/graph_def_builder_util.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/platform/test.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/test.h" + +namespace tensorflow { + +const char kCpu0[] = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"; +const char kCpu1[] = "/job:tpu_host_worker/replica:0/task:0/device:CPU:1"; +const char kClassAttr[] = "_class"; + +// Return the node with name `name`. +Node* GetNode(const Graph& graph, const std::string& name) { + for (Node* node : graph.nodes()) { + if (node->name() == name) return node; + } + return nullptr; +} + +// Test a simple colocate predecessor tree example. +TEST(ColocatePredecessorTreesPassTest, SimpleExample) { + auto graph = std::make_unique(OpRegistry::Global()); + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* const_0 = ops::SourceOp("Const", builder.opts() + .WithName("const_0") + .WithAttr("dtype", DT_INT32) + .WithAttr("value", Tensor(1.0))); + Node* const_1 = ops::SourceOp("Const", builder.opts() + .WithName("const_1") + .WithAttr("dtype", DT_INT32) + .WithAttr("value", Tensor(2.0))); + Node* fill = + ops::BinaryOp("Fill", const_0, const_1, builder.opts().WithName("fill")); + ops::UnaryOp("Identity", fill, builder.opts().WithName("identity")); + + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + GetNode(*graph, "identity")->set_assigned_device_name(kCpu0); + + GraphDef before; + graph->ToGraphDef(&before); + GraphOptimizationPassOptions options; + options.graph = &graph; + ColocatePredecessorTreesPass pass; + TF_ASSERT_OK(pass.Run(options)); + + EXPECT_TRUE(HasNodeAttr(GetNode(*graph, "const_0")->def(), kClassAttr)); + EXPECT_TRUE(HasNodeAttr(GetNode(*graph, "const_1")->def(), kClassAttr)); + EXPECT_TRUE(HasNodeAttr(GetNode(*graph, "fill")->def(), kClassAttr)); + EXPECT_TRUE(HasNodeAttr(GetNode(*graph, "identity")->def(), kClassAttr)); + + std::string expected_colocation_info = "loc:@identity"; + const AttrValue* input_value; + TF_EXPECT_OK( + GetNode(*graph, "const_0")->attrs().Find(kClassAttr, &input_value)); + EXPECT_EQ(input_value->s(), expected_colocation_info); + TF_EXPECT_OK( + GetNode(*graph, "const_1")->attrs().Find(kClassAttr, &input_value)); + EXPECT_EQ(input_value->s(), expected_colocation_info); + TF_EXPECT_OK(GetNode(*graph, "fill")->attrs().Find(kClassAttr, &input_value)); + EXPECT_EQ(input_value->s(), expected_colocation_info); + TF_EXPECT_OK( + GetNode(*graph, "identity")->attrs().Find(kClassAttr, &input_value)); + EXPECT_EQ(input_value->s(), expected_colocation_info); +} + +// Test colocate two predecessor trees case. +TEST(ColocatePredecessorTreesPassTest, PropagateTwoTrees) { + auto graph = std::make_unique(OpRegistry::Global()); + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* const_0 = ops::SourceOp("Const", builder.opts() + .WithName("const_0") + .WithAttr("dtype", DT_INT32) + .WithAttr("value", Tensor(1.0))); + Node* const_1 = ops::SourceOp("Const", builder.opts() + .WithName("const_1") + .WithAttr("dtype", DT_INT32) + .WithAttr("value", Tensor(2.0))); + Node* fill = + ops::BinaryOp("Fill", const_0, const_1, builder.opts().WithName("fill")); + ops::UnaryOp("Identity", fill, builder.opts().WithName("identity")); + + Node* const_2 = ops::SourceOp("Const", builder.opts() + .WithName("const_2") + .WithAttr("dtype", DT_INT32) + .WithAttr("value", Tensor(1.0))); + Node* const_3 = ops::SourceOp("Const", builder.opts() + .WithName("const_3") + .WithAttr("dtype", DT_INT32) + .WithAttr("value", Tensor(2.0))); + Node* fill_1 = ops::BinaryOp("Fill", const_2, const_3, + builder.opts().WithName("fill_1")); + ops::UnaryOp("Identity", fill_1, builder.opts().WithName("identity_1")); + + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + GetNode(*graph, "identity")->set_assigned_device_name(kCpu0); + GetNode(*graph, "identity_1")->set_assigned_device_name(kCpu0); + + GraphDef before; + graph->ToGraphDef(&before); + GraphOptimizationPassOptions options; + options.graph = &graph; + ColocatePredecessorTreesPass pass; + TF_ASSERT_OK(pass.Run(options)); + + EXPECT_TRUE(HasNodeAttr(GetNode(*graph, "const_0")->def(), kClassAttr)); + EXPECT_TRUE(HasNodeAttr(GetNode(*graph, "const_1")->def(), kClassAttr)); + EXPECT_TRUE(HasNodeAttr(GetNode(*graph, "fill")->def(), kClassAttr)); + EXPECT_TRUE(HasNodeAttr(GetNode(*graph, "identity")->def(), kClassAttr)); + + std::string expected_colocation_info = "loc:@identity"; + const AttrValue* input_value; + TF_EXPECT_OK( + GetNode(*graph, "const_0")->attrs().Find(kClassAttr, &input_value)); + EXPECT_EQ(input_value->s(), expected_colocation_info); + TF_EXPECT_OK( + GetNode(*graph, "const_1")->attrs().Find(kClassAttr, &input_value)); + EXPECT_EQ(input_value->s(), expected_colocation_info); + TF_EXPECT_OK(GetNode(*graph, "fill")->attrs().Find(kClassAttr, &input_value)); + EXPECT_EQ(input_value->s(), expected_colocation_info); + TF_EXPECT_OK( + GetNode(*graph, "identity")->attrs().Find(kClassAttr, &input_value)); + EXPECT_EQ(input_value->s(), expected_colocation_info); + + EXPECT_TRUE(HasNodeAttr(GetNode(*graph, "const_2")->def(), kClassAttr)); + EXPECT_TRUE(HasNodeAttr(GetNode(*graph, "const_3")->def(), kClassAttr)); + EXPECT_TRUE(HasNodeAttr(GetNode(*graph, "fill_1")->def(), kClassAttr)); + EXPECT_TRUE(HasNodeAttr(GetNode(*graph, "identity_1")->def(), kClassAttr)); + + std::string expected_colocation_info_1 = "loc:@identity_1"; + TF_EXPECT_OK( + GetNode(*graph, "const_2")->attrs().Find(kClassAttr, &input_value)); + EXPECT_EQ(input_value->s(), expected_colocation_info_1); + TF_EXPECT_OK( + GetNode(*graph, "const_3")->attrs().Find(kClassAttr, &input_value)); + EXPECT_EQ(input_value->s(), expected_colocation_info_1); + TF_EXPECT_OK( + GetNode(*graph, "fill_1")->attrs().Find(kClassAttr, &input_value)); + EXPECT_EQ(input_value->s(), expected_colocation_info_1); + TF_EXPECT_OK( + GetNode(*graph, "identity_1")->attrs().Find(kClassAttr, &input_value)); + EXPECT_EQ(input_value->s(), expected_colocation_info_1); +} + +// Test a simple colocate predecessor tree example. +TEST(ColocatePredecessorTreesPassTest, RootHasMultipleOutputs) { + auto graph = std::make_unique(OpRegistry::Global()); + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* const_0 = ops::SourceOp("Const", builder.opts() + .WithName("const_0") + .WithAttr("dtype", DT_INT32) + .WithAttr("value", Tensor(1.0))); + Node* const_1 = ops::SourceOp("Const", builder.opts() + .WithName("const_1") + .WithAttr("dtype", DT_INT32) + .WithAttr("value", Tensor(2.0))); + Node* fill = + ops::BinaryOp("Fill", const_0, const_1, builder.opts().WithName("fill")); + Node* identity = + ops::UnaryOp("Identity", fill, builder.opts().WithName("identity")); + ops::UnaryOp("Identity", identity, builder.opts().WithName("identity_1")); + ops::UnaryOp("Identity", identity, builder.opts().WithName("identity_2")); + + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + GetNode(*graph, "identity")->set_assigned_device_name(kCpu0); + + GraphDef before; + graph->ToGraphDef(&before); + GraphOptimizationPassOptions options; + options.graph = &graph; + ColocatePredecessorTreesPass pass; + TF_ASSERT_OK(pass.Run(options)); + + EXPECT_TRUE(HasNodeAttr(GetNode(*graph, "const_0")->def(), kClassAttr)); + EXPECT_TRUE(HasNodeAttr(GetNode(*graph, "const_1")->def(), kClassAttr)); + EXPECT_TRUE(HasNodeAttr(GetNode(*graph, "fill")->def(), kClassAttr)); + EXPECT_TRUE(HasNodeAttr(GetNode(*graph, "identity")->def(), kClassAttr)); + EXPECT_TRUE(HasNodeAttr(GetNode(*graph, "identity_1")->def(), kClassAttr)); + EXPECT_TRUE(HasNodeAttr(GetNode(*graph, "identity_1")->def(), kClassAttr)); + + std::string expected_colocation_info = "loc:@identity"; + const AttrValue* input_value; + TF_EXPECT_OK( + GetNode(*graph, "const_0")->attrs().Find(kClassAttr, &input_value)); + EXPECT_EQ(input_value->s(), expected_colocation_info); + TF_EXPECT_OK( + GetNode(*graph, "const_1")->attrs().Find(kClassAttr, &input_value)); + EXPECT_EQ(input_value->s(), expected_colocation_info); + TF_EXPECT_OK(GetNode(*graph, "fill")->attrs().Find(kClassAttr, &input_value)); + EXPECT_EQ(input_value->s(), expected_colocation_info); + TF_EXPECT_OK( + GetNode(*graph, "identity")->attrs().Find(kClassAttr, &input_value)); + EXPECT_EQ(input_value->s(), expected_colocation_info); + TF_EXPECT_OK( + GetNode(*graph, "identity_1")->attrs().Find(kClassAttr, &input_value)); + EXPECT_EQ(input_value->s(), expected_colocation_info); + TF_EXPECT_OK( + GetNode(*graph, "identity_2")->attrs().Find(kClassAttr, &input_value)); + EXPECT_EQ(input_value->s(), expected_colocation_info); +} + +// Test that a const op has device attr, no colocation info is propagated. +TEST(ColocatePredecessorTreesPassTest, ConstHasDeviceAttr) { + auto graph = std::make_unique(OpRegistry::Global()); + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* const_0 = ops::SourceOp("Const", builder.opts() + .WithName("const_0") + .WithAttr("dtype", DT_INT32) + .WithAttr("value", Tensor(1.0))); + Node* const_1 = ops::SourceOp("Const", builder.opts() + .WithName("const_1") + .WithAttr("dtype", DT_INT32) + .WithAttr("value", Tensor(2.0))); + Node* fill = + ops::BinaryOp("Fill", const_0, const_1, builder.opts().WithName("fill")); + + ops::UnaryOp("Identity", fill, builder.opts().WithName("identity")); + + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + GetNode(*graph, "identity")->set_assigned_device_name(kCpu0); + GetNode(*graph, "const_0")->set_assigned_device_name(kCpu1); + + GraphDef before; + graph->ToGraphDef(&before); + GraphOptimizationPassOptions options; + options.graph = &graph; + ColocatePredecessorTreesPass pass; + TF_ASSERT_OK(pass.Run(options)); + + EXPECT_FALSE(HasNodeAttr(GetNode(*graph, "const_0")->def(), kClassAttr)); + EXPECT_FALSE(HasNodeAttr(GetNode(*graph, "const_1")->def(), kClassAttr)); + EXPECT_FALSE(HasNodeAttr(GetNode(*graph, "fill")->def(), kClassAttr)); + EXPECT_FALSE(HasNodeAttr(GetNode(*graph, "identity")->def(), kClassAttr)); +} + +// Test that a const op has colocation info, no colocation info is propagated. +TEST(ColocatePredecessorTreesPassTest, ConstHasColocationInfo) { + auto graph = std::make_unique(OpRegistry::Global()); + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* const_0 = + ops::SourceOp("Const", builder.opts() + .WithName("const_0") + .WithAttr("dtype", DT_INT32) + .WithAttr("value", Tensor(1.0)) + .WithAttr("_class", {"loc:@fill"})); + Node* const_1 = ops::SourceOp("Const", builder.opts() + .WithName("const_1") + .WithAttr("dtype", DT_INT32) + .WithAttr("value", Tensor(2.0))); + Node* fill = + ops::BinaryOp("Fill", const_0, const_1, builder.opts().WithName("fill")); + Node* identity = + ops::UnaryOp("Identity", fill, builder.opts().WithName("identity")); + + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + GetNode(*graph, "identity")->set_assigned_device_name(kCpu0); + + GraphDef before; + graph->ToGraphDef(&before); + + GraphOptimizationPassOptions options; + options.graph = &graph; + ColocatePredecessorTreesPass pass; + TF_ASSERT_OK(pass.Run(options)); + + EXPECT_TRUE(HasNodeAttr(const_0->def(), kClassAttr)); + EXPECT_FALSE(HasNodeAttr(const_1->def(), kClassAttr)); + EXPECT_FALSE(HasNodeAttr(fill->def(), kClassAttr)); + EXPECT_FALSE(HasNodeAttr(identity->def(), kClassAttr)); +} + +// Test that one input is Arg, no colocation info is propagated. +TEST(ColocatePredecessorTreesPassTest, InputArg) { + auto graph = std::make_unique(OpRegistry::Global()); + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* arg_0 = ops::SourceOp("_Arg", builder.opts() + .WithName("arg_0") + .WithAttr("T", DT_INT32) + .WithAttr("index", 0)); + Node* const_0 = ops::SourceOp("Const", builder.opts() + .WithName("const_0") + .WithAttr("dtype", DT_INT32) + .WithAttr("value", Tensor(2.0))); + Node* fill = + ops::BinaryOp("Fill", arg_0, const_0, builder.opts().WithName("fill")); + + ops::UnaryOp("Identity", fill, builder.opts().WithName("identity")); + + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + GetNode(*graph, "identity")->set_assigned_device_name(kCpu0); + + GraphDef before; + graph->ToGraphDef(&before); + GraphOptimizationPassOptions options; + options.graph = &graph; + ColocatePredecessorTreesPass pass; + TF_ASSERT_OK(pass.Run(options)); + + EXPECT_FALSE(HasNodeAttr(GetNode(*graph, "arg_0")->def(), kClassAttr)); + EXPECT_FALSE(HasNodeAttr(GetNode(*graph, "const_0")->def(), kClassAttr)); + EXPECT_FALSE(HasNodeAttr(GetNode(*graph, "fill")->def(), kClassAttr)); + EXPECT_FALSE(HasNodeAttr(GetNode(*graph, "identity")->def(), kClassAttr)); +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/colocation_graph.cc b/tensorflow/core/common_runtime/colocation_graph.cc index 795f8fe4678f05..2edb3d7fd7fc7c 100644 --- a/tensorflow/core/common_runtime/colocation_graph.cc +++ b/tensorflow/core/common_runtime/colocation_graph.cc @@ -188,7 +188,7 @@ Status Member::SetAssignedDeviceName(const string& device_name) { // Set requested device to assigned_device to maintain the invariant that // requested is a specialization of assigned. requested_device_name_ = assigned_device_name_; - return OkStatus(); + return absl::OkStatus(); } Status Member::SetResourceDeviceName(const Node& node) { @@ -208,7 +208,7 @@ Status Member::SetResourceDeviceName(const Node& node) { // Set requested device to resource device to maintain the invariant that // requested is a specialization of resource. requested_device_name_ = resource_device_name_; - return OkStatus(); + return absl::OkStatus(); } Status Member::SetRequestedDeviceName(const Node& node) { @@ -228,7 +228,7 @@ Status Member::SetRequestedDeviceName(const Node& node) { node.requested_device(), "' in node: ", node.DebugString()); } - return OkStatus(); + return absl::OkStatus(); } Status Member::FillPossibleDevices(PossibleDevices* possible_device) const { @@ -242,7 +242,7 @@ Status Member::FillPossibleDevices(PossibleDevices* possible_device) const { possible_device->requested_device_name = requested_device_name_; possible_device->resource_device_name = resource_device_name_; possible_device->device_types = supported_device_types_; - return OkStatus(); + return absl::OkStatus(); } bool Member::IsEdgeFromCompositeDeviceToPhysicalDevice( @@ -295,7 +295,7 @@ Status Member::EnsureCompatibilityAcrossResourceEdge( if (DeviceNameUtils::AreCompatibleDevNames(src_root.requested_device_name_, requested_device_name_)) { - return OkStatus(); + return absl::OkStatus(); } // If we are here, assigned and resource devices are compatible but requested @@ -317,7 +317,7 @@ Status Member::EnsureCompatibilityAcrossResourceEdge( assigned_device_name_); DeviceNameUtils::EnsureSpecification(&requested_device_name_, resource_device_name_); - return OkStatus(); + return absl::OkStatus(); } void Member::Merge(std::vector* tree, int x_root, int y_root, @@ -416,7 +416,7 @@ Status Member::MergeDeviceNames(const Member& other, assigned_device_name_ = std::move(assigned_device_name_copy); resource_device_name_ = std::move(resource_device_name_copy); requested_device_name_ = std::move(requested_device_name_copy); - return OkStatus(); + return absl::OkStatus(); } // Updates this to contain the intersection of the device types in @@ -489,7 +489,7 @@ bool Member::MergeSupportedDevices( Status Member::AssignDevice(const Node& node) { if (node.assigned_device_name_index() == assigned_device_name_index_) { - return OkStatus(); + return absl::OkStatus(); } DeviceNameUtils::ParsedName parsed; @@ -525,7 +525,7 @@ Status Member::AssignDevice(const Node& node) { assigned_device_name_index_ = node.assigned_device_name_index(); // Clear cached possible_devices, if any. possible_devices_.clear(); - return OkStatus(); + return absl::OkStatus(); } void Member::MaybeExcludeXlaDevices() { @@ -561,7 +561,7 @@ Status Member::LimitToPossibleDevices(const PossibleDevices& devices, TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames( &resource_device_name_, devices.resource_device_name)); MergeSupportedDevices(devices.device_types); - return OkStatus(); + return absl::OkStatus(); } string Member::DebugString() const { @@ -701,7 +701,7 @@ Status ColocationGraph::ColocateAllNodes() { ColocateNodeToGroup(&colocation_group_root, node, node->name())); } - return OkStatus(); + return absl::OkStatus(); } Status ColocationGraph::ColocateResourceOrRefEdge(const Node* src, @@ -717,7 +717,7 @@ Status ColocationGraph::ColocateResourceOrRefEdge(const Node* src, // If the src root is assigned to a composite device and the dst root is // assigned to a physical device, don't colocate the dst root with the src // root. - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR(dst_root.EnsureCompatibilityAcrossResourceEdge( *src, src_root, *dst, log_device_placement_)); @@ -731,7 +731,7 @@ Status ColocationGraph::ColocateResourceOrRefEdge(const Node* src, status.message()), *dst); } - return OkStatus(); + return absl::OkStatus(); } Status ColocationGraph::ColocateResourceAndRefEdges( @@ -781,7 +781,7 @@ Status ColocationGraph::ColocateResourceAndRefEdges( } } - return OkStatus(); + return absl::OkStatus(); } namespace { @@ -910,7 +910,7 @@ Status ColocationGraph::AddHostOnlyDataTypesConstraints() { } } - return OkStatus(); + return absl::OkStatus(); } Status ColocationGraph::AddInspectionConstraints( @@ -923,7 +923,7 @@ Status ColocationGraph::AddInspectionConstraints( << ":\n\t" << groups.DebugString(); TF_RETURN_IF_ERROR(ApplyIOColocationGroups(groups, *node)); } - return OkStatus(); + return absl::OkStatus(); } Status ColocationGraph::Initialize() { @@ -940,7 +940,7 @@ Status ColocationGraph::Initialize() { members_[root_id].MaybeExcludeXlaDevices(); } - return OkStatus(); + return absl::OkStatus(); } // pair containing a node and whether this node has a resource input @@ -995,7 +995,7 @@ Status GetGroupNodes(const IOColocationGroups& groups, const Node& node, << "]"; } } - return OkStatus(); + return absl::OkStatus(); } // Returns whether the device_type in `device_attributes` is supported. @@ -1065,7 +1065,7 @@ Status ColocationGraph::ApplyIOColocationGroups( TF_RETURN_IF_ERROR(LimitToPossibleDevices(*group_node, possible_devices)); } - return OkStatus(); + return absl::OkStatus(); } Status ColocationGraph::ColocateNodeToGroup( @@ -1094,7 +1094,7 @@ Status ColocationGraph::ColocateNodeToGroup( } } } - return OkStatus(); + return absl::OkStatus(); } // Merge the (possibly disjoint) sets containing nodes "x" and @@ -1115,7 +1115,7 @@ Status ColocationGraph::ColocateNodes(const Node& x, const Node& y) { Status ColocationGraph::ColocateNodes(const Node& x, int x_root, const Node& y, int y_root) { if (x_root == y_root) { - return OkStatus(); + return absl::OkStatus(); } Member* new_root_member; @@ -1155,7 +1155,7 @@ Status ColocationGraph::ColocateNodes(const Node& x, int x_root, const Node& y, // All error checks are done, merge the colocation graphs. Member::Merge(&members_, x_root, y_root, &new_root_member, &old_root_member, /*dry_run=*/false); - return OkStatus(); + return absl::OkStatus(); } Status ColocationGraph::LimitToAssignedDevice(const Node& node) { @@ -1234,7 +1234,7 @@ Status ColocationGraph::GetDevicesForNode( const int node_root = FindAndUpdateRoot(node->id()); if (!members_[node_root].possible_devices().empty()) { *possible_devices = &members_[node_root].possible_devices(); - return OkStatus(); + return absl::OkStatus(); } Member& root_member = members_[node_root]; @@ -1361,7 +1361,7 @@ Status ColocationGraph::GetDevicesForNode( // Cache the result of the possible devices for this node group. root_member.set_possible_devices(std::move(devices)); *possible_devices = &root_member.possible_devices(); - return OkStatus(); + return absl::OkStatus(); } Status ColocationGraph::InitializeMembers() { @@ -1371,7 +1371,7 @@ Status ColocationGraph::InitializeMembers() { return AttachDef(status, *node); } } - return OkStatus(); + return absl::OkStatus(); } string ColocationGraph::DebugString() const { @@ -1481,7 +1481,7 @@ Status ColocationGraph::InitializeMemberWithAssignedDevice( for (const auto& d : member->supported_device_types()) { if (IsSupportedDeviceType(assigned_device->attributes(), d.first)) { - return OkStatus(); + return absl::OkStatus(); } } @@ -1537,7 +1537,7 @@ Status ColocationGraph::InitializeMember(const Node& node, Member* member) { } } } - return OkStatus(); + return absl::OkStatus(); } // Returns a list of devices having type in supported_device_types. The diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index e469527fdefd49..cf79d58555568a 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -644,7 +644,7 @@ Status ConstantFold(const ConstantFoldingOptions& opts, VLOG(1) << "No constant foldable nodes found"; *was_mutated = false; // This is not an error, so return the status as OK. - return OkStatus(); + return absl::OkStatus(); } std::map tensors_to_fetch; @@ -657,7 +657,7 @@ Status ConstantFold(const ConstantFoldingOptions& opts, VLOG(1) << "No constant nodes found that feed into the original graph."; *was_mutated = false; // This is not an error, so return the status as OK. - return OkStatus(); + return absl::OkStatus(); } VLOG(1) << "Constant foldable " << constant_graph->num_node_ids() << " : " << graph->num_node_ids(); @@ -714,7 +714,7 @@ Status ConstantFold(const ConstantFoldingOptions& opts, DumpGraph("After", graph); *was_mutated = (num_nodes_replaced > 0); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/constant_folding_test.cc b/tensorflow/core/common_runtime/constant_folding_test.cc index f65926fc3921a4..bf518b59ac7234 100644 --- a/tensorflow/core/common_runtime/constant_folding_test.cc +++ b/tensorflow/core/common_runtime/constant_folding_test.cc @@ -169,7 +169,7 @@ TEST_F(ConstantFoldingTest, DeterministicFolding) { opt.generate_new_name = generate_new_name; TF_CHECK_OK( ConstantFold(opt, nullptr, Env::Default(), nullptr, &g, &was_mutated)); - return OkStatus(); + return absl::OkStatus(); }; Graph g1(OpRegistry::Global()); @@ -691,7 +691,7 @@ class TestTFFileSystem : public ::tensorflow::NullFileSystem { const ::tensorflow::StringPiece sp = data_tensor_.tensor_data(); *result = std::unique_ptr<::tensorflow::ReadOnlyMemoryRegion>( new TestReadOnlyMemoryRegion(sp.data(), sp.size())); - return OkStatus(); + return absl::OkStatus(); } protected: @@ -708,7 +708,7 @@ class TestTFEnvironment : public ::tensorflow::EnvWrapper { was_used_ = true; if (fname == "test://test") { *result = &test_filesystem_; - return OkStatus(); + return absl::OkStatus(); } return tf_base::GetFileSystemForFile(fname, result); } diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc index 29ad1493c3fbcc..a4712a5c83a742 100644 --- a/tensorflow/core/common_runtime/copy_tensor.cc +++ b/tensorflow/core/common_runtime/copy_tensor.cc @@ -74,7 +74,7 @@ void CopyHostToDevice(const Tensor* input, Allocator* cpu_allocator, status_cb->Ref(); CopyHostToDevice(&from, cpu_allocator, out_allocator, edge_name, dst, to, recv_dev_context, wrapped_done, sync_dst_compute); - return OkStatus(); + return absl::OkStatus(); } else { if (!DMAHelper::CanUseDMA(&from)) { Status err = errors::InvalidArgument( @@ -89,7 +89,7 @@ void CopyHostToDevice(const Tensor* input, Allocator* cpu_allocator, *to = Tensor(out_allocator, from.dtype(), from.shape()); recv_dev_context->CopyCPUTensorToDevice(&from, dst, to, wrapped_done, sync_dst_compute); - return OkStatus(); + return absl::OkStatus(); } else { return status_cb->status(); } @@ -112,7 +112,7 @@ void CopyHostToDevice(const Tensor* input, Allocator* cpu_allocator, } } else if (input->dtype() == DT_RESOURCE) { *output = *input; - done(OkStatus()); + done(absl::OkStatus()); } else { recv_dev_context->CopyCPUTensorToDevice(input, dst, output, std::move(done), sync_dst_compute); @@ -148,7 +148,7 @@ void CopyDeviceToDevice(CopyTensor::CopyFunction copy_function, send_dev_context, recv_dev_context, src, dst, src_alloc_attr, dst_alloc_attr, &from, to, dev_to_dev_stream_index, wrapped_done); - return OkStatus(); + return absl::OkStatus(); } else { if (!DMAHelper::CanUseDMA(&from)) { Status err = errors::InvalidArgument( @@ -164,7 +164,7 @@ void CopyDeviceToDevice(CopyTensor::CopyFunction copy_function, copy_function(send_dev_context, recv_dev_context, src, dst, src_alloc_attr, dst_alloc_attr, &from, to, dev_to_dev_stream_index, wrapped_done); - return OkStatus(); + return absl::OkStatus(); } else { return status_cb->status(); } @@ -188,7 +188,7 @@ void CopyDeviceToDevice(CopyTensor::CopyFunction copy_function, } } else if (input->dtype() == DT_RESOURCE) { *output = *input; - done(OkStatus()); + done(absl::OkStatus()); } else { copy_function(send_dev_context, recv_dev_context, src, dst, src_alloc_attr, dst_alloc_attr, input, output, dev_to_dev_stream_index, @@ -297,7 +297,7 @@ void CopyTensor::ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context, // cpu -> cpu CHECK(!non_cpu_src && !non_cpu_dst); *output = *input; - done(OkStatus()); + done(absl::OkStatus()); } // static @@ -308,7 +308,7 @@ Status CopyTensor::Register(DeviceType sender_device_type, std::vector* registry = MutableRegistry(); registry->emplace_back(sender_device_type, receiver_device_type, copy_function, is_pluggable_device); - return OkStatus(); + return absl::OkStatus(); } namespace { @@ -324,7 +324,7 @@ static Status WrappedTensorDeviceCopy( *to = from; } - return OkStatus(); + return absl::OkStatus(); } #define REGISTER_WRAPPED_TENSOR_COPY(DIRECTION) \ @@ -357,7 +357,7 @@ void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator, status_cb->Ref(); CopyDeviceToHost(&from, cpu_allocator, out_allocator, edge_name, src, to, send_dev_context, wrapped_done); - return OkStatus(); + return absl::OkStatus(); } else { if (!DMAHelper::CanUseDMA(&from)) { Status err = errors::InvalidArgument( @@ -372,7 +372,7 @@ void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator, *to = Tensor(out_allocator, from.dtype(), from.shape()); send_dev_context->CopyDeviceTensorToCPU(&from, edge_name, src, to, wrapped_done); - return OkStatus(); + return absl::OkStatus(); } else { return status_cb->status(); } @@ -395,7 +395,7 @@ void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator, } } else if (input->dtype() == DT_RESOURCE) { *output = *input; - done(OkStatus()); + done(absl::OkStatus()); } else { send_dev_context->CopyDeviceTensorToCPU(input, edge_name, src, output, std::move(done)); diff --git a/tensorflow/core/common_runtime/costmodel_manager.cc b/tensorflow/core/common_runtime/costmodel_manager.cc index d029b6c89c4c13..36ef7d08933b84 100644 --- a/tensorflow/core/common_runtime/costmodel_manager.cc +++ b/tensorflow/core/common_runtime/costmodel_manager.cc @@ -63,7 +63,7 @@ Status CostModelManager::AddToCostGraphDef(const Graph* graph, } CostModel* cost_model = it->second; cost_model->AddToCostGraphDef(graph, cost_graph); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/debugger_state_interface.cc b/tensorflow/core/common_runtime/debugger_state_interface.cc index a6581a2166cb0f..a9626069a79926 100644 --- a/tensorflow/core/common_runtime/debugger_state_interface.cc +++ b/tensorflow/core/common_runtime/debugger_state_interface.cc @@ -69,7 +69,7 @@ Status DebuggerStateRegistry::CreateState( "It appears that TFDBG is not linked in this TensorFlow build."); } else { *state = (*factory_)(debug_options); - return OkStatus(); + return absl::OkStatus(); } } @@ -90,7 +90,7 @@ Status DebugGraphDecoratorRegistry::CreateDecorator( "It appears that TFDBG is not linked in this TensorFlow build."); } else { *decorator = (*factory_)(options); - return OkStatus(); + return absl::OkStatus(); } } diff --git a/tensorflow/core/common_runtime/device/device_utils.cc b/tensorflow/core/common_runtime/device/device_utils.cc index 5203b4c3181597..e95f95cb8dfa8e 100644 --- a/tensorflow/core/common_runtime/device/device_utils.cc +++ b/tensorflow/core/common_runtime/device/device_utils.cc @@ -30,7 +30,7 @@ Status ValidateDeviceType(StringPiece type) { strings::StrCat("Device name/type '", type, "' must match ", kTfDeviceTypeRegEx->pattern(), ".")); } - return OkStatus(); + return absl::OkStatus(); } } // namespace device_utils diff --git a/tensorflow/core/common_runtime/device_mgr_test.cc b/tensorflow/core/common_runtime/device_mgr_test.cc index 5d4777018ba529..6cf9be8959599d 100644 --- a/tensorflow/core/common_runtime/device_mgr_test.cc +++ b/tensorflow/core/common_runtime/device_mgr_test.cc @@ -30,7 +30,7 @@ static Device* CreateDevice(const char* type, const char* name) { class FakeDevice : public Device { public: explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {} - Status Sync() override { return OkStatus(); } + Status Sync() override { return absl::OkStatus(); } Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; } }; DeviceAttributes attr; diff --git a/tensorflow/core/common_runtime/device_resolver_local.cc b/tensorflow/core/common_runtime/device_resolver_local.cc index 4c2394d80da484..1fca35b15ef012 100644 --- a/tensorflow/core/common_runtime/device_resolver_local.cc +++ b/tensorflow/core/common_runtime/device_resolver_local.cc @@ -31,7 +31,7 @@ Status DeviceResolverLocal::GetDeviceAttributes(const string& device, return s; } *attributes = dev->attributes(); - return OkStatus(); + return absl::OkStatus(); } Status DeviceResolverLocal::GetAllDeviceAttributes( diff --git a/tensorflow/core/common_runtime/device_set_test.cc b/tensorflow/core/common_runtime/device_set_test.cc index 630754643ce9cd..20141e80f2a458 100644 --- a/tensorflow/core/common_runtime/device_set_test.cc +++ b/tensorflow/core/common_runtime/device_set_test.cc @@ -29,7 +29,7 @@ static Device* Dev(const char* type, const char* name) { class FakeDevice : public Device { public: explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {} - Status Sync() override { return OkStatus(); } + Status Sync() override { return absl::OkStatus(); } Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; } }; DeviceAttributes attr; @@ -61,11 +61,11 @@ class DeviceSetTest : public ::testing::Test { class DummyFactory : public DeviceFactory { public: Status ListPhysicalDevices(std::vector* devices) override { - return OkStatus(); + return absl::OkStatus(); } Status CreateDevices(const SessionOptions& options, const string& name_prefix, std::vector>* devices) override { - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 587137c308d77a..36faa488aa34ab 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -108,7 +108,7 @@ Status NewThreadPoolFromThreadPoolOptions( num_threads, !options.config.experimental().disable_thread_spinning(), /*allocator=*/nullptr); *owned = true; - return OkStatus(); + return absl::OkStatus(); } // Global, named threadpool. @@ -135,7 +135,7 @@ Status NewThreadPoolFromThreadPoolOptions( } *owned = false; *pool = mvalue->second; - return OkStatus(); + return absl::OkStatus(); } // Function to create a global thread pool for sessions. The thread number is @@ -206,7 +206,7 @@ class DirectSessionFactory : public SessionFactory { sessions_.push_back(session); } *out_session = session; - return OkStatus(); + return absl::OkStatus(); } Status Reset(const SessionOptions& options, @@ -430,7 +430,7 @@ Status DirectSession::Create(GraphDef&& graph) { } return ExtendLocked(std::move(graph)); } - return OkStatus(); + return absl::OkStatus(); } Status DirectSession::Extend(const GraphDef& graph) { @@ -471,7 +471,7 @@ Status DirectSession::ExtendLocked(GraphDef&& graph) { execution_state_.swap(state); TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library())); } - return OkStatus(); + return absl::OkStatus(); } Status DirectSession::Run(const NamedTensorList& inputs, @@ -499,7 +499,7 @@ Status DirectSession::CreateDebuggerState( TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata( global_step, session_run_index, executor_step_index, input_names, output_names, target_names)); - return OkStatus(); + return absl::OkStatus(); } Status DirectSession::DecorateAndPublishGraphForDebug( @@ -510,7 +510,7 @@ Status DirectSession::DecorateAndPublishGraphForDebug( TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device)); TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name())); - return OkStatus(); + return absl::OkStatus(); } Status DirectSession::RunInternal( @@ -843,7 +843,7 @@ Status DirectSession::RunInternal( } metrics::UpdateGraphExecTime(options_.env->NowMicros() - start_time_usecs); - return OkStatus(); + return absl::OkStatus(); } Status DirectSession::Run(const RunOptions& run_options, @@ -964,7 +964,7 @@ Status DirectSession::Run(const RunOptions& run_options, metrics::RecordGraphOutputTensors(output_size); } - return OkStatus(); + return absl::OkStatus(); } Status DirectSession::PRunSetup(const std::vector& input_names, @@ -1043,7 +1043,7 @@ Status DirectSession::PRunSetup(const std::vector& input_names, } *handle = run_state_args.handle; - return OkStatus(); + return absl::OkStatus(); } Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs, @@ -1209,7 +1209,7 @@ Status DirectSession::SendPRunInputs(const NamedTensorList& inputs, return s; } } - return OkStatus(); + return absl::OkStatus(); } Status DirectSession::RecvPRunOutputs( @@ -1254,7 +1254,7 @@ Status DirectSession::RecvPRunOutputs( (*outputs)[output_offset] = output_tensor; } - return OkStatus(); + return absl::OkStatus(); } Status DirectSession::CheckFetch(const NamedTensorList& feeds, @@ -1315,7 +1315,7 @@ Status DirectSession::CheckFetch(const NamedTensorList& feeds, } } } - return OkStatus(); + return absl::OkStatus(); } Status DirectSession::CreateExecutors( @@ -1380,7 +1380,7 @@ Status DirectSession::CreateExecutors( tsl::core::RefCountPtr* r) { *r = tsl::core::RefCountPtr( new IntraProcessRendezvous(device_mgr)); - return OkStatus(); + return absl::OkStatus(); }})); GraphOptimizer optimizer(optimizer_opts); @@ -1489,7 +1489,7 @@ Status DirectSession::CreateExecutors( *out_executors_and_keys = std::move(ek); *out_func_info = std::move(func_info); - return OkStatus(); + return absl::OkStatus(); } Status DirectSession::GetOrCreateExecutors( @@ -1524,7 +1524,7 @@ Status DirectSession::GetOrCreateExecutors( auto it = executors_.find(key); if (it != executors_.end()) { *executors_and_keys = it->second.get(); - return OkStatus(); + return absl::OkStatus(); } } @@ -1557,7 +1557,7 @@ Status DirectSession::GetOrCreateExecutors( auto it = executors_.find(sorted_key); if (it != executors_.end()) { *executors_and_keys = it->second.get(); - return OkStatus(); + return absl::OkStatus(); } } @@ -1603,7 +1603,7 @@ Status DirectSession::GetOrCreateExecutors( executors_.emplace(key, insert_result.first->second); *executors_and_keys = insert_result.first->second.get(); - return OkStatus(); + return absl::OkStatus(); } Status DirectSession::CreateGraphs( @@ -1775,24 +1775,24 @@ ::tensorflow::Status DirectSession::ListDevices( const DeviceAttributes& attrs = d->attributes(); response->emplace_back(attrs); } - return OkStatus(); + return absl::OkStatus(); } ::tensorflow::Status DirectSession::Reset( const std::vector& containers) { device_mgr_->ClearContainers(containers); - return OkStatus(); + return absl::OkStatus(); } ::tensorflow::Status DirectSession::Close() { cancellation_manager_->StartCancel(); { mutex_lock l(closed_lock_); - if (closed_) return OkStatus(); + if (closed_) return absl::OkStatus(); closed_ = true; } if (factory_ != nullptr) factory_->Deregister(this); - return OkStatus(); + return absl::OkStatus(); } DirectSession::RunState::RunState(int64_t step_id, @@ -1868,7 +1868,7 @@ ::tensorflow::Status DirectSession::WaitForNotification( } else { notification->WaitForNotification(); } - return OkStatus(); + return absl::OkStatus(); } Status DirectSession::MakeCallable(const CallableOptions& callable_options, @@ -1886,7 +1886,7 @@ Status DirectSession::MakeCallable(const CallableOptions& callable_options, *out_handle = next_callable_handle_++; callables_[*out_handle] = {std::move(ek), std::move(func_info)}; } - return OkStatus(); + return absl::OkStatus(); } class DirectSession::RunCallableCallFrame : public CallFrameInterface { @@ -1913,7 +1913,7 @@ class DirectSession::RunCallableCallFrame : public CallFrameInterface { } else { *val = &(*feed_tensors_)[index]; } - return OkStatus(); + return absl::OkStatus(); } Status SetRetval(int index, const Tensor& val) override { @@ -1921,7 +1921,7 @@ class DirectSession::RunCallableCallFrame : public CallFrameInterface { return errors::Internal("RetVal index out of bounds: ", index); } (*fetch_tensors_)[index] = val; - return OkStatus(); + return absl::OkStatus(); } private: @@ -2032,7 +2032,7 @@ ::tensorflow::Status DirectSession::RunCallable( metrics::RecordGraphOutputTensors(output_size); } - return OkStatus(); + return absl::OkStatus(); } ::tensorflow::Status DirectSession::ReleaseCallable(CallableHandle handle) { @@ -2041,7 +2041,7 @@ ::tensorflow::Status DirectSession::ReleaseCallable(CallableHandle handle) { return errors::InvalidArgument("No such callable handle: ", handle); } callables_.erase(handle); - return OkStatus(); + return absl::OkStatus(); } Status DirectSession::Finalize() { @@ -2055,7 +2055,7 @@ Status DirectSession::Finalize() { execution_state_.reset(); flib_def_.reset(); finalized_ = true; - return OkStatus(); + return absl::OkStatus(); } DirectSession::Callable::~Callable() { diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h index a81a3079345117..6421862af65af0 100644 --- a/tensorflow/core/common_runtime/direct_session.h +++ b/tensorflow/core/common_runtime/direct_session.h @@ -111,7 +111,7 @@ class DirectSession : public Session { ::tensorflow::Status Close() override; ::tensorflow::Status LocalDeviceManager(const DeviceMgr** output) override { *output = device_mgr_.get(); - return OkStatus(); + return absl::OkStatus(); } void ExportCostModels(CostModelManager::CostModelMap* cost_models) { @@ -313,7 +313,7 @@ class DirectSession : public Session { ::tensorflow::Status CheckNotClosed() { mutex_lock l(closed_lock_); if (closed_) return errors::Cancelled("Session has been closed."); - return OkStatus(); + return absl::OkStatus(); } ::tensorflow::Status CheckGraphCreated(const char* method) { @@ -322,7 +322,7 @@ class DirectSession : public Session { return errors::InvalidArgument( "Session was not created with a graph before ", method, "!"); } - return OkStatus(); + return absl::OkStatus(); } ::tensorflow::Status CreateDebuggerState( diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc index a4777ea364b307..ab40c0f99eef0f 100644 --- a/tensorflow/core/common_runtime/direct_session_test.cc +++ b/tensorflow/core/common_runtime/direct_session_test.cc @@ -2757,7 +2757,7 @@ class DirectSessionCollectiveTest : public ::testing::Test { mutex_lock l(direct_session->collective_graph_key_lock_); *collective_graph_key = direct_session->collective_graph_key_; } - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/common_runtime/dynamic_device_mgr.cc b/tensorflow/core/common_runtime/dynamic_device_mgr.cc index 854329f19ea90c..325bbfd97b9849 100644 --- a/tensorflow/core/common_runtime/dynamic_device_mgr.cc +++ b/tensorflow/core/common_runtime/dynamic_device_mgr.cc @@ -118,7 +118,7 @@ Status DynamicDeviceMgr::LookupDevice(StringPiece name, Device** device) const { return errors::InvalidArgument(name, " unknown device."); } *device = iter->second; - return OkStatus(); + return absl::OkStatus(); } bool DynamicDeviceMgr::ContainsDevice(int64_t device_incarnation) const { @@ -181,7 +181,7 @@ Status DynamicDeviceMgr::AddDevices( device_incarnation_set_.insert(d->attributes().incarnation()); dynamic_devices_.emplace(d.get(), std::move(d)); } - return OkStatus(); + return absl::OkStatus(); } Status DynamicDeviceMgr::RemoveDevices(const std::vector& devices) { @@ -221,7 +221,7 @@ Status DynamicDeviceMgr::RemoveDevices(const std::vector& devices) { stale_devices_.add(std::move(it->second)); dynamic_devices_.erase(it); } - return OkStatus(); + return absl::OkStatus(); } Status DynamicDeviceMgr::RemoveDevicesByName( diff --git a/tensorflow/core/common_runtime/dynamic_device_mgr_test.cc b/tensorflow/core/common_runtime/dynamic_device_mgr_test.cc index 4749d254c3a1e8..90c9782ab1a3e0 100644 --- a/tensorflow/core/common_runtime/dynamic_device_mgr_test.cc +++ b/tensorflow/core/common_runtime/dynamic_device_mgr_test.cc @@ -34,7 +34,7 @@ static Device* CreateDevice(const char* type, const char* name, class FakeDevice : public Device { public: explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {} - Status Sync() override { return OkStatus(); } + Status Sync() override { return absl::OkStatus(); } Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; } }; diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index e788afc3eee409..5296563337aa63 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -147,7 +147,9 @@ EagerContext::EagerContext( pin_small_ops_to_cpu_(ReadBoolFromEnvVar( "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING", false)), run_eager_op_as_function_(run_eager_op_as_function), - jit_compile_rewrite_(jit_compile_rewrite) { + jit_compile_rewrite_(jit_compile_rewrite), + register_abstract_functions_local_only_(ReadBoolFromEnvVar( + "TF_EAGER_REGISTER_ABSTRACT_FUNCTIONS_LOCAL_ONLY", false)) { ResetPFLR(device_mgr, opts.env, &opts.config, TF_GRAPH_DEF_VERSION, &func_lib_def_, opts.config.graph_options().optimizer_options(), thread_pool_.get(), cluster_flr); @@ -708,7 +710,8 @@ Status EagerContext::RegisterFunction(AbstractFunction* f) { if (!fdef) { return errors::InvalidArgument("GetFunctionDef returned nullptr."); } - return AddFunctionDef(*fdef); + return AddFunctionDef(*fdef, FunctionDefLibrary(), + register_abstract_functions_local_only_); } bool EagerContext::UsesTFRT() { return false; } diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 9accb3749e252a..78edf11aa98346 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -896,6 +896,32 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { std::function resource_deallocator_ = nullptr; bool run_eager_op_as_function_; bool jit_compile_rewrite_; + + // Controls the behavior of + // `EagerContext::RegisterFunction(AbstractFunction*)` in distributed + // settings. + // + // By default, each abstract function will be registered on all workers in + // a cluster. If the environment variable + // `TF_EAGER_REGISTER_ABSTRACT_FUNCTIONS_LOCAL_ONLY=1` is set, each abstract + // function will be registered on the local worker only. + // + // In the common case that all functions are initially dispatched to + // a local device, the `ProcessFunctionLibraryRuntime` + // will ensure that the precise dependencies of that function are shipped to + // the remote device. Since PFLR instantiation often involves optimization, + // passes such as lowering control flow and inlining function calls, this will + // result in (1) sending a substantially smaller set of functions to each + // worker, and (2) the unoptimized functions never being called. + // + // Therefore setting `TF_EAGER_REGISTER_ABSTRACT_FUNCTIONS_LOCAL_ONLY=1` can + // significantly reduce both the startup time and the memory footprint on + // remote workers by avoiding the shipping of unneeded functions. + // + // TODO(b/326251557): Infer automatically when it is necessary to register a + // function or its dependencies on remote hosts; then remove the environment + // variable. + bool register_abstract_functions_local_only_; }; inline EagerContext* ContextFromInterface(ImmediateExecutionContext* context) { diff --git a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc index 502e164c2c205b..d4760ccbfb2129 100644 --- a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc +++ b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc @@ -230,14 +230,12 @@ Status GetReplacedFromExistingWorkers( return OkStatus(); } -Status CreateRemoteContexts(EagerContext* context, - const std::vector& remote_workers, - uint64 context_id, uint64 context_view_id, - int keep_alive_secs, const ServerDef& server_def, - eager::EagerClientCache* remote_eager_workers, - bool async, - const eager::CreateContextRequest& base_request, - int64_t init_timeout_in_ms, int retries) { +Status CreateRemoteContexts( + EagerContext* context, const std::vector& remote_workers, + uint64 context_id, uint64 context_view_id, int keep_alive_secs, + const ServerDef& server_def, eager::EagerClientCache* remote_eager_workers, + bool async, const eager::CreateContextRequest& base_request, + int64_t init_timeout_in_ms, int retries, bool clear_existing_contexts) { int num_remote_workers = remote_workers.size(); BlockingCounter counter(num_remote_workers); std::vector statuses(num_remote_workers); @@ -271,6 +269,7 @@ Status CreateRemoteContexts(EagerContext* context, request.mutable_server_def()->set_task_index(parsed_name.task); request.mutable_server_def()->mutable_default_session_config()->MergeFrom( server_def.default_session_config()); + request.set_clear_existing_contexts(clear_existing_contexts); std::vector filtered_device_mask; context->FilterDevicesForRemoteWorkers( @@ -415,7 +414,8 @@ Status UpdateRemoteContexts(EagerContext* context, Status UpdateContextWithServerDef(EagerContext* context, const ServerDef& server_def, bool reset_context, int keep_alive_secs, - int64_t init_timeout_in_ms, int retries) { + int64_t init_timeout_in_ms, int retries, + bool clear_existing_contexts = false) { // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the // server object (which currently CHECK-fails) and we miss the error, instead, // we log the error, and then return to allow the user to see the error @@ -578,7 +578,7 @@ Status UpdateContextWithServerDef(EagerContext* context, reset_context_status = CreateRemoteContexts( context, remote_workers, context_id, context_view_id, keep_alive_secs, server_def, remote_eager_workers.get(), context->Executor().Async(), - base_request, init_timeout_in_ms, retries); + base_request, init_timeout_in_ms, retries, clear_existing_contexts); // NOTE: the remote tasks could fail after `GetAllRemoteDevices` and cause // the CreateRemoteContexts to fail. We currently only log instead of // directly returning the error, since returning here will cause the server @@ -604,7 +604,7 @@ Status UpdateContextWithServerDef(EagerContext* context, context, added_workers, context_id, context_view_id + 1, keep_alive_secs, server_def, remote_eager_workers.get(), context->Executor().Async(), base_request, init_timeout_in_ms, - /*retries=*/0)); + /*retries=*/0, /*clear_existing_contexts=*/false)); } if (!existing_workers.empty()) { if (VLOG_IS_ON(1)) { @@ -672,7 +672,7 @@ Status UpdateContextWithServerDef(EagerContext* context, Status EagerContextDistributedManager::SetOrUpdateServerDef( const ServerDef& server_def, bool reset_context, int keep_alive_secs, - int64_t init_timeout_in_ms, int retries) { + int64_t init_timeout_in_ms, int retries, bool clear_existing_contexts) { if (server_def.has_cluster_device_filters()) { if (reset_context) { const auto& cdf = server_def.cluster_device_filters(); @@ -696,9 +696,9 @@ Status EagerContextDistributedManager::SetOrUpdateServerDef( "when updating the server def."; } } - Status s = - UpdateContextWithServerDef(context_, server_def, reset_context, - keep_alive_secs, init_timeout_in_ms, retries); + Status s = UpdateContextWithServerDef(context_, server_def, reset_context, + keep_alive_secs, init_timeout_in_ms, + retries, clear_existing_contexts); // If context is reset, make sure pointer is set to the new agent. coordination_service_agent_ = context_->GetServer() diff --git a/tensorflow/core/common_runtime/eager/context_distributed_manager.h b/tensorflow/core/common_runtime/eager/context_distributed_manager.h index b6344827f44ba4..c5ba7ba4c5197a 100644 --- a/tensorflow/core/common_runtime/eager/context_distributed_manager.h +++ b/tensorflow/core/common_runtime/eager/context_distributed_manager.h @@ -46,7 +46,8 @@ class EagerContextDistributedManager Status SetOrUpdateServerDef(const ServerDef& server_def, bool reset_context, int keep_alive_secs, int64_t init_timeout_in_ms, - int retries) override; + int retries, + bool clear_existing_contexts = false) override; Status InitializeLocalOnlyContext(const ServerDef& server_def, int keep_alive_secs) override; diff --git a/tensorflow/core/common_runtime/eager/custom_device.h b/tensorflow/core/common_runtime/eager/custom_device.h index ed8716fb4b298a..6ab6cbe8283793 100644 --- a/tensorflow/core/common_runtime/eager/custom_device.h +++ b/tensorflow/core/common_runtime/eager/custom_device.h @@ -57,7 +57,7 @@ class CustomDevice { // Returns true signifying to pin to the current custom device. // Returns false to pin to the physical device. - virtual StatusOr ShallPinToThisDevice( + virtual absl::StatusOr ShallPinToThisDevice( const ImmediateExecutionOperation* op) = 0; }; diff --git a/tensorflow/core/common_runtime/eager/execute_test.cc b/tensorflow/core/common_runtime/eager/execute_test.cc index 7f53263deced35..74e5b88a64578f 100644 --- a/tensorflow/core/common_runtime/eager/execute_test.cc +++ b/tensorflow/core/common_runtime/eager/execute_test.cc @@ -347,7 +347,7 @@ TEST(ExecuteTest, MultipleNestedCompiledFunction) { TF_ASSERT_OK(ctx->AddFunctionDef(x_times_two)); const string call_function_name = "FunctionCall"; - const FunctionDef function_call = FunctionDefHelper::Define( + FunctionDef function_call = FunctionDefHelper::Define( // Name call_function_name, // Args @@ -362,12 +362,21 @@ TEST(ExecuteTest, MultipleNestedCompiledFunction) { "StatefulPartitionedCall", {"x"}, {{"_XlaMustCompile", true}, - {"device", "/job:localhost/replica:0/task:0/device:CPU:0"}, + {"_device", "/job:localhost/replica:0/task:0/device:CPU:0"}, {"Tin", DataTypeSlice({DT_INT64})}, {"Tout", DataTypeSlice({DT_INT64})}, {"f", tensorflow::FunctionDefHelper::FunctionRef( "XTimesTwo", {{"T", DT_INT64}})}}}, }); + + // Set user requested device for the StatefulPartitionedCall node, as + // FunctionDefHelper::Define cannot do that. + for (auto& node_def : *function_call.mutable_node_def()) { + if (node_def.op() == "StatefulPartitionedCall") { + node_def.set_device("/job:localhost/replica:0/task:0/device:CPU:0"); + } + } + TF_ASSERT_OK(ctx->AddFunctionDef(function_call)); const string call_function_name2 = "FunctionCall2"; diff --git a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc index 83d0b66b72e3cd..104ac476d9e75e 100644 --- a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc +++ b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc @@ -1,4 +1,4 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -53,9 +53,6 @@ class MklEagerOpRewrite : public EagerOpRewrite { // Rewrite rule for Conv2D, Conv2DBackpropInput and Conv2DBackpropFilter. static bool RewriteConv2D(EagerOperation* op); - // Rewrite rule for MklSparseMatrixMatMul. - static bool RewriteSparseMatrixMatMul(EagerOperation* op); - // Rewrite rule for FusedBatchNormV3 and FusedBatchNormGradV3 static bool RewriteFusedBatchNormV3(EagerOperation* op); @@ -115,10 +112,6 @@ MklEagerOpRewrite::MklEagerOpRewrite(string name, string file, string line) InsertMKLEagerOps( {"FusedBatchNormV3", RewriteFusedBatchNormV3, CreateGenericMklOp}); InsertMKLEagerOps({"MatMul", AlwaysRewrite, CreateGenericMklOp}); -#ifdef ENABLE_ONEDNN_V3 - InsertMKLEagerOps( - {"SparseMatrixMatMul", RewriteSparseMatrixMatMul, CreateGenericMklOp}); -#endif // ENABLE_ONEDNN_V3 // TODO(Intel-tf): Support MaxPool, MaxPool3D rewrite, handle workspace. // Note: MaxPoolGrad, MaxPool3DGrad rewrite cannot be supported in eager // mode due to workspace restriction @@ -239,42 +232,6 @@ bool MklEagerOpRewrite::RewriteConv2D(EagerOperation* op) { return (padding != "EXPLICIT"); } -bool MklEagerOpRewrite::RewriteSparseMatrixMatMul(EagerOperation* op) { - const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef(); - DataType T; - const TensorProto* proto = nullptr; - Tensor tensor; - bool adjoint_a, adjoint_b, transpose_a, transpose_b, transpose_out; - - // Check the datatype. - TF_CHECK_OK(GetNodeAttr(ndef, "T", &T)); - if (T != DT_FLOAT) { - VLOG(1) << "_MklSparseMatrixMatMul only supports DT_FLOAT"; - return false; - } - - // Check for adjointing. - TF_CHECK_OK(GetNodeAttr(ndef, "adjoint_a", &adjoint_a)); - TF_CHECK_OK(GetNodeAttr(ndef, "adjoint_b", &adjoint_b)); - if (adjoint_a || adjoint_b) { - VLOG(1) - << "_MklNativeSparseMatrixMatMul doesn't support adjointing matrices"; - return false; - } - - // Check for transposing. - TF_CHECK_OK(GetNodeAttr(ndef, "transpose_a", &transpose_a)); - TF_CHECK_OK(GetNodeAttr(ndef, "transpose_b", &transpose_b)); - TF_CHECK_OK(GetNodeAttr(ndef, "transpose_output", &transpose_out)); - if (transpose_a || transpose_b || transpose_out) { - VLOG(1) - << "_MklNativeSparseMatrixMatMul doesn't support transposing matrices"; - return false; - } - - return true; -} - bool MklEagerOpRewrite::RewriteFusedBatchNormV3(EagerOperation* op) { const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef(); if (Check5DFormat(ndef)) { diff --git a/tensorflow/core/common_runtime/eval_const_tensor.cc b/tensorflow/core/common_runtime/eval_const_tensor.cc index a204ccd5d801e9..b6131c0df03945 100644 --- a/tensorflow/core/common_runtime/eval_const_tensor.cc +++ b/tensorflow/core/common_runtime/eval_const_tensor.cc @@ -112,7 +112,7 @@ std::optional GetSliceIndex(const Node& node, const int node_output) { // `tf.unstack(tf.shape(tensor))[ix]`, // and the result can be inferred from shape metadata, returns the result. // Otherwise, returns null. -StatusOr> TryInferFromShapes( +absl::StatusOr> TryInferFromShapes( const Node& node, const int node_output, const ShapeRefiner& refiner) { std::optional result; if (node.num_inputs() == 0 || node_output >= node.num_outputs()) { @@ -252,7 +252,7 @@ std::string OutputName(const NodeOutput& output) { // Assuming that the subgraph ending at `target_node` is constant-foldable, // returns it along with all constant inputs necessary for evaluation. // Otherwise, returns null. -StatusOr> ExtractConstantSubgraph( +absl::StatusOr> ExtractConstantSubgraph( const Node& target_node, const ShapeRefiner& refiner, const absl::FunctionRef(const Node&, int)> lookup, const OpRegistryInterface* op_registry, const int32_t graph_def_version) { @@ -356,7 +356,7 @@ StatusOr> ExtractConstantSubgraph( } // namespace -StatusOr> EvaluateConstantTensor( +absl::StatusOr> EvaluateConstantTensor( const Node& node, const int node_output, const ShapeRefiner& refiner, const absl::FunctionRef(const Node&, int)> lookup, const std::optional runner) { diff --git a/tensorflow/core/common_runtime/eval_const_tensor.h b/tensorflow/core/common_runtime/eval_const_tensor.h index e46c6ffc76aa51..049a3e9fb857b8 100644 --- a/tensorflow/core/common_runtime/eval_const_tensor.h +++ b/tensorflow/core/common_runtime/eval_const_tensor.h @@ -47,7 +47,7 @@ struct EvaluateConstantTensorRunner { // // When the evaluation is successful, the function returns a tensor, otherwise // it returns std::nullopt. -StatusOr> EvaluateConstantTensor( +absl::StatusOr> EvaluateConstantTensor( // The tensor to be evaluated. const Node& node, int node_output, // Used to fetch inference contexts for nodes in the graph. diff --git a/tensorflow/core/common_runtime/eval_const_tensor_test.cc b/tensorflow/core/common_runtime/eval_const_tensor_test.cc index f48e1c3d176395..ecd65071ce7ec9 100644 --- a/tensorflow/core/common_runtime/eval_const_tensor_test.cc +++ b/tensorflow/core/common_runtime/eval_const_tensor_test.cc @@ -57,7 +57,7 @@ class EvaluateConstantTensorTest : public ::testing::Test { return *this; } - StatusOr> Run(const Output& output) { + absl::StatusOr> Run(const Output& output) { TF_RETURN_IF_ERROR(scope_.status()); const auto& graph = *scope_.graph(); ShapeRefiner refiner(graph.versions(), graph.op_registry()); diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index a539a6eaa9ede2..b241a7e60c14e8 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -152,7 +152,7 @@ class ExecutorImpl : public Executor { Status Initialize(const Graph& graph) { TF_RETURN_IF_ERROR(immutable_state_.Initialize(graph)); kernel_stats_.Initialize(immutable_state_.graph_view()); - return OkStatus(); + return absl::OkStatus(); } private: @@ -507,7 +507,7 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) { num_outstanding_ops_ = ready.size(); if (ready.empty()) { delete this; - done(OkStatus()); + done(absl::OkStatus()); } else { done_cb_ = std::move(done); // Schedule to run all the ready ops in thread pool. @@ -1077,7 +1077,7 @@ Status ExecutorState::PrepareInputs( } } } - return OkStatus(); + return absl::OkStatus(); } template @@ -1565,7 +1565,7 @@ class DefaultExecutorRegistrar { Executor* ret = nullptr; TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(graph), &ret)); out_executor->reset(ret); - return OkStatus(); + return absl::OkStatus(); } }; }; diff --git a/tensorflow/core/common_runtime/executor_factory.cc b/tensorflow/core/common_runtime/executor_factory.cc index 1bf582808024bc..470bfb7548fbd9 100644 --- a/tensorflow/core/common_runtime/executor_factory.cc +++ b/tensorflow/core/common_runtime/executor_factory.cc @@ -70,7 +70,7 @@ Status ExecutorFactory::GetFactory(const string& executor_type, } *out_factory = iter->second; - return OkStatus(); + return absl::OkStatus(); } Status NewExecutor(const string& executor_type, diff --git a/tensorflow/core/common_runtime/executor_test.cc b/tensorflow/core/common_runtime/executor_test.cc index 648a2e80acf297..6a847b3f1b2f9f 100644 --- a/tensorflow/core/common_runtime/executor_test.cc +++ b/tensorflow/core/common_runtime/executor_test.cc @@ -584,7 +584,7 @@ Status ReplaceEdgeWithSendRecv(Graph* g, const Edge* edge, const string& tensor, g->AddControlEdge(edge->src(), recv); g->RemoveEdge(edge); - return OkStatus(); + return absl::OkStatus(); } // Defines a graph to perform the following computation: diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 227821a71e783f..844d6ebd682107 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -605,7 +605,7 @@ Status FunctionLibraryRuntimeImpl::GetRetTypes(Handle h, } const FunctionBody* fbody = GetFunctionBody(h); *ret_types = fbody->ret_types; - return OkStatus(); + return absl::OkStatus(); } Status FunctionLibraryRuntimeImpl::CreateKernel( @@ -732,7 +732,7 @@ Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient( CHECK_NOTNULL(f_body); *g_body = SymbolicGradient(*f_body); } - return OkStatus(); + return absl::OkStatus(); } bool FunctionLibraryRuntimeImpl::IsLocalTarget( @@ -792,7 +792,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate( " not found in items."); } ++item_handle->second->instantiation_counter; - return OkStatus(); + return absl::OkStatus(); } } @@ -859,7 +859,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate( TF_RETURN_IF_ERROR(GetOrCreateItem(local_handle, &item)); } - return OkStatus(); + return absl::OkStatus(); } Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) { @@ -872,7 +872,7 @@ Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) { { mutex_lock l(mu_); // Return directly if all items has already been released. - if (items_ == nullptr) return OkStatus(); + if (items_ == nullptr) return absl::OkStatus(); auto it = items_->find(h); if (it == items_->end()) { @@ -1005,7 +1005,7 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Item** item) { (*item)->exec = exec.release(); } } - return OkStatus(); + return absl::OkStatus(); } Status FunctionLibraryRuntimeImpl::GetOrCreateItem(LocalHandle local_handle, @@ -1019,7 +1019,7 @@ Status FunctionLibraryRuntimeImpl::GetOrCreateItem(LocalHandle local_handle, } *item = iter->second.get(); if ((*item)->exec != nullptr) { - return OkStatus(); + return absl::OkStatus(); } } // NOTE: We need to call CreateItem out of mu_ because creating an @@ -1307,7 +1307,7 @@ Status FunctionLibraryRuntimeImpl::PrepareRunSync( device_name_, handle, /*include_multi_device=*/true); if (local_handle == kInvalidLocalHandle) { *out_item = nullptr; - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR(GetOrCreateItem(local_handle, out_item)); @@ -1317,7 +1317,7 @@ Status FunctionLibraryRuntimeImpl::PrepareRunSync( } DCHECK(run_opts->runner != nullptr); - return OkStatus(); + return absl::OkStatus(); } Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle, @@ -1384,7 +1384,7 @@ Status FunctionLibraryRuntimeImpl::Clone( skip_flib_def)); *out_flr = (*out_pflr)->GetFLR(device_->name()); if (*out_flr != nullptr) { - return OkStatus(); + return absl::OkStatus(); } else { return errors::Internal("Cloning FunctionLibraryRuntime failed."); } diff --git a/tensorflow/core/common_runtime/function_def_utils.cc b/tensorflow/core/common_runtime/function_def_utils.cc index 1d7b5209ec116c..c4d87114291ac5 100644 --- a/tensorflow/core/common_runtime/function_def_utils.cc +++ b/tensorflow/core/common_runtime/function_def_utils.cc @@ -80,7 +80,7 @@ Status FunctionDefToBodyHelper( *fbody = std::make_unique(std::move(record), result.arg_types, result.ret_types, graph.release()); - return OkStatus(); + return absl::OkStatus(); } Status FunctionDefToBodyHelper(core::RefCountPtr&& record, diff --git a/tensorflow/core/common_runtime/function_optimization_registration_test.cc b/tensorflow/core/common_runtime/function_optimization_registration_test.cc index 1a8dddc21f960b..7963119e082786 100644 --- a/tensorflow/core/common_runtime/function_optimization_registration_test.cc +++ b/tensorflow/core/common_runtime/function_optimization_registration_test.cc @@ -35,7 +35,7 @@ class TestFunctionPass : public FunctionOptimizationPass { std::vector* control_ret_node_names, bool* control_rets_updated) override { ran_ = true; - return OkStatus(); + return absl::OkStatus(); } }; @@ -56,7 +56,7 @@ TEST(FunctionOptimizationPassRegistry, RegisteredPass) { /*flib_def=*/nullptr, /*control_ret_node_names=*/nullptr, /*control_rets_updated=*/nullptr); - EXPECT_EQ(status, OkStatus()); + EXPECT_EQ(status, absl::OkStatus()); EXPECT_TRUE(TestFunctionPass::ran_); } diff --git a/tensorflow/core/common_runtime/function_optimization_registry.cc b/tensorflow/core/common_runtime/function_optimization_registry.cc index 9a6ff8040d9982..3d6427ce372b05 100644 --- a/tensorflow/core/common_runtime/function_optimization_registry.cc +++ b/tensorflow/core/common_runtime/function_optimization_registry.cc @@ -34,7 +34,7 @@ Status FunctionOptimizationPassRegistry::Run( std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, std::vector* control_ret_node_names, bool* control_rets_updated) { - if (!pass_) return OkStatus(); + if (!pass_) return absl::OkStatus(); tensorflow::metrics::ScopedCounter<2> timings( tensorflow::metrics::GetGraphOptimizationCounter(), diff --git a/tensorflow/core/common_runtime/function_optimization_registry_no_pass_test.cc b/tensorflow/core/common_runtime/function_optimization_registry_no_pass_test.cc index 49bb5c8f06d683..8a1b5716473ca7 100644 --- a/tensorflow/core/common_runtime/function_optimization_registry_no_pass_test.cc +++ b/tensorflow/core/common_runtime/function_optimization_registry_no_pass_test.cc @@ -36,7 +36,7 @@ TEST(FunctionOptimizationPassRegistry, NoPassSet) { /*flib_def=*/nullptr, /*control_ret_node_names=*/nullptr, /*control_rets_updated=*/nullptr); - EXPECT_EQ(status, OkStatus()); + EXPECT_EQ(status, absl::OkStatus()); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/function_optimization_registry_test.cc b/tensorflow/core/common_runtime/function_optimization_registry_test.cc index d23aa4e8b1b77f..af28a7c24261a5 100644 --- a/tensorflow/core/common_runtime/function_optimization_registry_test.cc +++ b/tensorflow/core/common_runtime/function_optimization_registry_test.cc @@ -37,7 +37,7 @@ class PassingFunctionPass : public FunctionOptimizationPass { std::vector* control_ret_node_names, bool* control_rets_updated) override { ran_ = true; - return OkStatus(); + return absl::OkStatus(); } }; @@ -57,7 +57,7 @@ TEST(FunctionOptimizationPassRegistry, PassNoError) { /*flib_def=*/nullptr, /*control_ret_node_names=*/nullptr, /*control_rets_updated=*/nullptr); - EXPECT_EQ(status, OkStatus()); + EXPECT_EQ(status, absl::OkStatus()); EXPECT_TRUE(PassingFunctionPass::ran_); } diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index 2369a7e0358e03..aba35de1602ddd 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -179,7 +179,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { tsl::core::RefCountPtr* r) { *r = tsl::core::RefCountPtr( new IntraProcessRendezvous(device_mgr)); - return OkStatus(); + return absl::OkStatus(); }})); flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); flr1_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:1"); @@ -210,7 +210,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { for (size_t i = 0; i < rets.size(); ++i) { *rets[i] = out[i]; } - return OkStatus(); + return absl::OkStatus(); } Status Instantiate(FunctionLibraryRuntime* flr, const string& name, @@ -280,7 +280,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { return status; } - return OkStatus(); + return absl::OkStatus(); } Status InstantiateAndRunViaCallFrameInterface(FunctionLibraryRuntime* flr, @@ -449,7 +449,7 @@ class ConsumeArgumentCallFrame : public CallFrameInterface { Status SetRetval(int index, const Tensor& val) override { CHECK_EQ(index, 0); *retval_ = val; - return OkStatus(); + return absl::OkStatus(); } private: @@ -1128,7 +1128,7 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsAndKeepCallerNode) { auto a = ops::_Arg(s.WithOpName("a"), DT_FLOAT, 0); auto b = test::function::Call(&s, "b", "AddAndMul", {a}); TF_RETURN_IF_ERROR(s.ToGraph(g->get())); - return OkStatus(); + return absl::OkStatus(); }; const string input_node = "Func/b/input/_0"; @@ -1216,7 +1216,7 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsAndPlaceInlinedNodes) { for (Node* node : (*g)->op_nodes()) { if (node->name() == "b") node->set_requested_device(call_device); } - return OkStatus(); + return absl::OkStatus(); }; const string input_node = "Func/b/input/_0"; diff --git a/tensorflow/core/common_runtime/function_threadpool_test.cc b/tensorflow/core/common_runtime/function_threadpool_test.cc index 77890a829c29c5..9454b25f7899ed 100644 --- a/tensorflow/core/common_runtime/function_threadpool_test.cc +++ b/tensorflow/core/common_runtime/function_threadpool_test.cc @@ -70,7 +70,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { tsl::core::RefCountPtr* r) { *r = tsl::core::RefCountPtr( new IntraProcessRendezvous(device_mgr)); - return OkStatus(); + return absl::OkStatus(); }})); flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); } @@ -110,7 +110,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { EXPECT_GE(call_count, 1); // Test runner is used. } - return OkStatus(); + return absl::OkStatus(); } Status Instantiate(FunctionLibraryRuntime* flr, const string& name, @@ -191,7 +191,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { EXPECT_GE(call_count, 1); // Test runner is used. } - return OkStatus(); + return absl::OkStatus(); } FunctionLibraryRuntime* flr0_; diff --git a/tensorflow/core/common_runtime/function_utils.cc b/tensorflow/core/common_runtime/function_utils.cc index bb6b9008c357eb..facd31481c05ed 100644 --- a/tensorflow/core/common_runtime/function_utils.cc +++ b/tensorflow/core/common_runtime/function_utils.cc @@ -268,7 +268,7 @@ Status NameAndAttrsFromFunctionCall(const NodeDef& call_def, function->set_name(call_def.op()); *function->mutable_attr() = call_def.attr(); } - return OkStatus(); + return absl::OkStatus(); } Status InstantiateFunctionCall(const NodeDef& call_def, diff --git a/tensorflow/core/common_runtime/gpu/BUILD b/tensorflow/core/common_runtime/gpu/BUILD index 1920ee32f98303..f35c59767e0e3e 100644 --- a/tensorflow/core/common_runtime/gpu/BUILD +++ b/tensorflow/core/common_runtime/gpu/BUILD @@ -264,34 +264,17 @@ tf_cuda_library( tf_cuda_library( name = "gpu_virtual_mem_allocator", - srcs = [ - "gpu_virtual_mem_allocator.cc", - ], hdrs = [ "gpu_virtual_mem_allocator.h", ], copts = tf_copts(), - cuda_deps = [ - "@local_xla//xla/stream_executor/gpu:gpu_driver_header", - "@local_xla//xla/stream_executor/gpu:gpu_types_header", - ], features = [ "-layering_check", "parse_headers", ], visibility = ["//visibility:public"], deps = [ - ":gpu_id", - "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/framework:allocator", - "//tensorflow/core/platform:stream_executor", - "//tensorflow/core/profiler/lib:traceme", - "@com_google_absl//absl/strings:str_format", - "@local_xla//xla/stream_executor", - "@local_xla//xla/stream_executor:platform", + "@local_xla//xla/stream_executor/integrations:gpu_virtual_mem_allocator", ], ) @@ -335,6 +318,7 @@ tf_cuda_cc_test( "//tensorflow/core/common_runtime:direct_session_internal", "//tensorflow/core/kernels:ops_util", "@local_xla//xla/stream_executor/integrations:device_mem_allocator", + "@local_xla//xla/stream_executor/integrations:gpu_virtual_mem_allocator", ], ) @@ -390,6 +374,7 @@ tf_cuda_cc_test( "//tensorflow/core/common_runtime:core_cpu_internal", "//tensorflow/core/common_runtime:direct_session_internal", "//tensorflow/core/kernels:ops_util", + "@local_xla//xla/stream_executor:platform_manager", "@local_xla//xla/stream_executor/gpu:gpu_init", ], ) @@ -474,34 +459,21 @@ tf_cuda_cc_test( ], ) -tf_cc_test( - name = "gpu_virtual_mem_allocator_test", - size = "small", - srcs = ["gpu_virtual_mem_allocator_test.cc"], - features = ["-layering_check"], - tags = tf_cuda_tests_tags(), - deps = [ - ":gpu_virtual_mem_allocator", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - "//tensorflow/core/framework:allocator", - "//tensorflow/core/platform:stream_executor", - "@local_xla//xla/stream_executor/gpu:gpu_init", - "@local_xla//xla/stream_executor/integrations:device_mem_allocator", - ], -) - cc_library( name = "gpu_serving_device_selector", srcs = ["gpu_serving_device_selector.cc"], hdrs = ["gpu_serving_device_selector.h"], features = ["-layering_check"], deps = [ - "//tensorflow/core/common_runtime:serving_device_selector", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:fixed_array", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@local_tsl//tsl/framework:serving_device_selector", ], ) @@ -511,8 +483,9 @@ tf_cc_test( srcs = ["gpu_serving_device_selector_test.cc"], deps = [ ":gpu_serving_device_selector", - "//tensorflow/core/common_runtime:serving_device_selector", - "//tensorflow/core/common_runtime:serving_device_selector_policies", + "@com_google_absl//absl/time", "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/framework:serving_device_selector", + "@local_tsl//tsl/framework:serving_device_selector_policies", ], ) diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc index b1570a7746d6b8..4756b43132b8cf 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc @@ -47,7 +47,7 @@ bool CheckMask(se::StreamExecutor* exec, void* ptr, int64_t* mask) { se::DeviceMemory gpu_ptr{se::DeviceMemoryBase{ptr, MASK_BYTES}}; int64_t tmp[MASK_WORDS]; - tsl::Status result = exec->SynchronousMemcpyD2H(gpu_ptr, MASK_BYTES, tmp); + absl::Status result = exec->SynchronousMemcpyD2H(gpu_ptr, MASK_BYTES, tmp); if (!result.ok()) { LOG(FATAL) << "Could not copy debug mask, " << result; } @@ -67,7 +67,7 @@ bool CheckMask(se::StreamExecutor* exec, void* ptr, int64_t* mask) { void InitMask(se::StreamExecutor* exec, void* ptr, int64_t* mask) { se::DeviceMemory gpu_ptr{se::DeviceMemoryBase{ptr, MASK_BYTES}}; - tsl::Status result = exec->SynchronousMemcpyH2D(mask, MASK_BYTES, &gpu_ptr); + absl::Status result = exec->SynchronousMemcpyH2D(mask, MASK_BYTES, &gpu_ptr); if (!result.ok()) { LOG(FATAL) << "Could not copy debug mask, " << result; } @@ -178,7 +178,7 @@ void* GPUNanResetAllocator::AllocateRaw(size_t alignment, size_t num_bytes) { se::DeviceMemory nan_ptr{ se::DeviceMemoryBase{static_cast(allocated_ptr), req_size}}; - tsl::Status result = + absl::Status result = stream_exec_->SynchronousMemcpyH2D(&nans[0], req_size, &nan_ptr); if (!result.ok()) { LOG(ERROR) << "Could not initialize to NaNs, " << result; @@ -194,7 +194,7 @@ void GPUNanResetAllocator::DeallocateRaw(void* ptr) { std::nanf("")); se::DeviceMemory nan_ptr{ se::DeviceMemoryBase{static_cast(ptr), req_size}}; - tsl::Status result = + absl::Status result = stream_exec_->SynchronousMemcpyH2D(&nans[0], req_size, &nan_ptr); if (!result.ok()) { LOG(ERROR) << "Could not initialize to NaNs, " << result; diff --git a/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.cc b/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.cc index e1122b1f9a2514..42d786c499c742 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.cc @@ -14,17 +14,32 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.h" +#include +#include #include #include +#include "absl/base/attributes.h" #include "absl/container/fixed_array.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" -#include "tensorflow/core/common_runtime/serving_device_selector.h" +#include "absl/time/clock.h" +#include "tsl/framework/serving_device_selector.h" namespace tensorflow { namespace gpu { +// A default estimate of execution time for an enqueued program that this host +// has never finished executing. We currently set it to 1 ns (so that for all +// empty queues it still affects the decision) until we have better way to +// estimate this, as this penalty is chip-dependent and program-dependent. +constexpr int64_t kDefaultEstimateNs = 1; +ABSL_CONST_INIT int64_t (*NowNs)() = +[]() -> int64_t { + return absl::GetCurrentTimeNanos(); +}; + +using DeviceStates = GpuServingDeviceSelector::DeviceStates; GpuServingDeviceSelector::GpuServingDeviceSelector( const int num_devices, @@ -33,29 +48,73 @@ GpuServingDeviceSelector::GpuServingDeviceSelector( device_selector_policy_(std::move(device_selector_policy)), req_id_counter_(0) {} -DeviceReservation GpuServingDeviceSelector::ReserveDevice( +tsl::DeviceReservation GpuServingDeviceSelector::ReserveDevice( absl::string_view program_fingerprint) { absl::MutexLock lock(&mu_); DeviceStates device_states; device_states.states = absl::Span(device_states_); + auto [it, emplaced] = + execution_info_.try_emplace(program_fingerprint, ExecutionInfo()); const int device_index = device_selector_policy_->SelectDevice(program_fingerprint, device_states); - DeviceState::ProgramInfo program_info; - program_info.fingerprint = program_fingerprint; - program_info.req_id = ++req_id_counter_; - device_states_[device_index].scheduled_programs.push_back(program_info); + ServingDeviceSelector::EnqueueHelper( + device_states_.at(device_index), device_index, it->second, + program_fingerprint, /*priority=*/0, req_id_counter_++, + /*priority_queue_count=*/1, /*prefetch_results=*/0, NowNs()); - return DeviceReservation(device_index, this); + return tsl::DeviceReservation(device_index, this); } void GpuServingDeviceSelector::FreeDeviceReservation( - const DeviceReservation& reservation) { + const tsl::DeviceReservation& reservation) { + Completed(reservation.device_index()); +} + +void GpuServingDeviceSelector::Enqueue(int32_t index_on_host, + absl::string_view fingerprint) { + if (fingerprint.empty()) { + LOG(ERROR) << "Empty fingerprint."; + return; + } + absl::MutexLock lock(&mu_); - auto& scheduled_programs = - device_states_.at(reservation.device_index()).scheduled_programs; - DCHECK(!scheduled_programs.empty()); - scheduled_programs.pop_front(); + auto [it, emplaced] = + execution_info_.try_emplace(fingerprint, ExecutionInfo()); + + DeviceState& device_state = device_states_.at(index_on_host); + ServingDeviceSelector::EnqueueHelper(device_state, index_on_host, it->second, + fingerprint, + /*priority=*/0, /*req_id=*/-1, + /*priority_queue_count=*/1, + /*prefetch_results=*/0, NowNs()); + + // TODO(xiangll): Metric estimated execution time. +} + +void GpuServingDeviceSelector::Completed(int32_t index_on_host, + bool had_error) { + absl::MutexLock lock(&mu_); + DeviceState& device_state = device_states_.at(index_on_host); + ServingDeviceSelector::CompletedHelper(device_state, index_on_host, 0, + min_exec_time_, had_error, NowNs()); + + // TODO(xiangll): Metric estimated execution time. +} + +int64_t GpuServingDeviceSelector::TotalGpuLoadNsForTest() { + absl::MutexLock lock(&mu_); + int64_t total_gpu_load_ns = 0; + for (const auto& device_state : device_states_) { + total_gpu_load_ns += ServingDeviceSelector::EstimateTimeTillIdleNs( + device_state, 0, min_exec_time_.value_or(kDefaultEstimateNs), NowNs()); + } + return total_gpu_load_ns; +} + +/*static*/ void GpuServingDeviceSelector::OverwriteNowNsFunctionForTest( + int64_t (*now_ns)()) { + NowNs = now_ns; } } // namespace gpu diff --git a/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.h b/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.h index 55651e60d790f7..3055605780fe99 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.h +++ b/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.h @@ -15,32 +15,54 @@ limitations under the License. #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_SERVING_DEVICE_SELECTOR_H_ #define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_SERVING_DEVICE_SELECTOR_H_ +#include #include +#include +#include "absl/base/thread_annotations.h" #include "absl/container/fixed_array.h" +#include "absl/container/node_hash_map.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" -#include "tensorflow/core/common_runtime/serving_device_selector.h" +#include "tsl/framework/serving_device_selector.h" namespace tensorflow { namespace gpu { -class GpuServingDeviceSelector : public ServingDeviceSelector { +class GpuServingDeviceSelector : public tsl::ServingDeviceSelector { public: GpuServingDeviceSelector( int num_devices, std::unique_ptr device_selector_policy); - DeviceReservation ReserveDevice( + tsl::DeviceReservation ReserveDevice( absl::string_view program_fingerprint) override; + // Enqueues the program on the stream of index `index_on_host`. + void Enqueue(int32_t index_on_host, absl::string_view fingerprint); + + // Marks the completion of a program on the given stream. + // If `had_error` is true, this function doesn't update program's execution + // time stats to avoid incorrect estimates. + void Completed(int32_t index_on_host, bool had_error = false); + + int64_t TotalGpuLoadNsForTest(); + private: - void FreeDeviceReservation(const DeviceReservation& reservation) override; + friend class ServingDeviceSelectorTestHelper; + static void OverwriteNowNsFunctionForTest(int64_t (*now_ns)()); + + void FreeDeviceReservation( + const tsl::DeviceReservation& reservation) override; absl::Mutex mu_; absl::FixedArray device_states_ ABSL_GUARDED_BY(mu_); std::unique_ptr device_selector_policy_; int64_t req_id_counter_ ABSL_GUARDED_BY(mu_); + // Map from program fingerprint to execution info. + absl::node_hash_map execution_info_ + ABSL_GUARDED_BY(mu_); + std::optional min_exec_time_ ABSL_GUARDED_BY(mu_); }; } // namespace gpu diff --git a/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector_test.cc b/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector_test.cc index 42f9361fc2278b..b486706dcec50d 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector_test.cc @@ -14,24 +14,49 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.h" +#include #include #include +#include #include -#include "tensorflow/core/common_runtime/serving_device_selector.h" -#include "tensorflow/core/common_runtime/serving_device_selector_policies.h" +#include "absl/time/clock.h" +#include "tsl/framework/serving_device_selector.h" +#include "tsl/framework/serving_device_selector_policies.h" namespace tensorflow { namespace gpu { +class ServingDeviceSelectorTestHelper { + public: + ServingDeviceSelectorTestHelper() { + GpuServingDeviceSelector::OverwriteNowNsFunctionForTest(NowNs); + now_ns_ = 0; + } + + ~ServingDeviceSelectorTestHelper() { + GpuServingDeviceSelector::OverwriteNowNsFunctionForTest( + absl::GetCurrentTimeNanos); + } + + static void ElapseNs(int64_t ns) { now_ns_ += ns; } + + static int64_t NowNs() { return now_ns_; } + + private: + static int64_t now_ns_; +}; + +int64_t ServingDeviceSelectorTestHelper::now_ns_ = 0; namespace { TEST(GpuServingDeviceSelector, Basic) { // Create a selector with two devices and round-robin policy. GpuServingDeviceSelector selector(/*num_devices=*/2, - std::make_unique()); + std::make_unique()); const std::string program_fingerprint = "TensorFlow"; - DeviceReservation reservation = selector.ReserveDevice(program_fingerprint); + tsl::DeviceReservation reservation = + selector.ReserveDevice(program_fingerprint); EXPECT_EQ(reservation.device_index(), 0); reservation = selector.ReserveDevice(program_fingerprint); @@ -41,6 +66,58 @@ TEST(GpuServingDeviceSelector, Basic) { EXPECT_EQ(reservation.device_index(), 0); } +TEST(GpuServingDeviceSelector, DefaultPolicyOnlyEnqueueCall) { + ServingDeviceSelectorTestHelper helper; + auto policy = std::make_unique(); + auto serving_device_selector = + std::make_unique( + 4, std::move(policy)); + serving_device_selector->Enqueue(3, "16ms"); + serving_device_selector->Enqueue(2, "8ms"); + serving_device_selector->Enqueue(1, "4ms"); + serving_device_selector->Enqueue(0, "2ms"); + // Nothing is completed yet, we don't have any estimated execution time, and + // we don't know what programs we are enqueueing. + serving_device_selector->Enqueue(3, "16ms"); + serving_device_selector->Enqueue(2, "8ms"); + serving_device_selector->Enqueue(1, "4ms"); + serving_device_selector->Enqueue(0, "2ms"); + helper.ElapseNs(2e6); + serving_device_selector->Completed(0); + helper.ElapseNs(2e6); + serving_device_selector->Completed(0); + serving_device_selector->Completed(1); + helper.ElapseNs(4e6); + serving_device_selector->Completed(1); + serving_device_selector->Completed(2); + helper.ElapseNs(8e6); + serving_device_selector->Completed(2); + serving_device_selector->Completed(3); + helper.ElapseNs(16e6); + serving_device_selector->Completed(3); + + serving_device_selector->Enqueue(3, "16ms"); + EXPECT_EQ(serving_device_selector->TotalGpuLoadNsForTest(), 16e6); + serving_device_selector->Enqueue(2, "8ms"); + EXPECT_EQ(serving_device_selector->TotalGpuLoadNsForTest(), 24e6); + serving_device_selector->Enqueue(1, "4ms"); + EXPECT_EQ(serving_device_selector->TotalGpuLoadNsForTest(), 28e6); + serving_device_selector->Enqueue(0, "2ms"); + EXPECT_EQ(serving_device_selector->TotalGpuLoadNsForTest(), 30e6); + helper.ElapseNs(2e6); + serving_device_selector->Completed(0); + EXPECT_EQ(serving_device_selector->TotalGpuLoadNsForTest(), 22e6); + helper.ElapseNs(2e6); + serving_device_selector->Completed(1); + EXPECT_EQ(serving_device_selector->TotalGpuLoadNsForTest(), 16e6); + helper.ElapseNs(4e6); + serving_device_selector->Completed(2); + EXPECT_EQ(serving_device_selector->TotalGpuLoadNsForTest(), 8e6); + helper.ElapseNs(8e6); + serving_device_selector->Completed(3); + EXPECT_EQ(serving_device_selector->TotalGpuLoadNsForTest(), 0e6); +} + } // namespace } // namespace gpu } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.cc b/tensorflow/core/common_runtime/gpu/gpu_util.cc index b1284203215750..3ae56b821df2c0 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_util.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_util.cc @@ -121,7 +121,7 @@ Status PrepareCopy(Device* device, const DeviceContext* ctx, const Tensor& src, return errors::Internal("GPU copy from non-DMA ", DataTypeString(src.dtype()), " tensor"); } - return OkStatus(); + return absl::OkStatus(); } void* GetBase(const Tensor* src) { @@ -201,7 +201,7 @@ void GPUUtil::SetProtoFromGPU(const Tensor& tensor, Device* dev, } alloc->DeallocateRaw(buf); } - done(OkStatus()); + done(absl::OkStatus()); }); } @@ -268,7 +268,7 @@ void GPUUtil::DeviceToDeviceCopy( if (!send_device_to_device_stream->ok()) { LOG(FATAL) << "GPU->GPU Memcpy failed"; } - done(OkStatus()); + done(absl::OkStatus()); }); send_dev_context->MaintainLifetimeOnStream(input, send_device_to_device_stream); @@ -352,7 +352,7 @@ void GPUUtil::CopyGPUTensorToCPU(Device* gpu_device, LOG(FATAL) << "GPU->CPU Memcpy failed"; // Crash OK } input_ref.Unref(); - done(OkStatus()); + done(absl::OkStatus()); }); } @@ -433,7 +433,7 @@ void GPUUtil::CopyCPUTensorToGPU(const Tensor* cpu_tensor, if (!recv_host_to_device_stream->ok()) { LOG(FATAL) << "CPU->GPU Memcpy failed"; } - done(OkStatus()); + done(absl::OkStatus()); }); } @@ -456,7 +456,7 @@ Status GPUUtil::SyncAll(Device* gpu_device) { !dev_info->stream->ok()) { return errors::Internal("GPU sync failed"); } - return OkStatus(); + return absl::OkStatus(); } string GPUUtil::MemoryDebugString(const Device* device, Tensor* tensor) { @@ -537,7 +537,7 @@ void GPUUtil::CopyGPUTensorToSameGPU(Device* gpu_device, send_stream->ThenMemcpy(&gpu_dst_ptr, gpu_src_ptr, total_bytes); } - done(OkStatus()); + done(absl::OkStatus()); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc b/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc index fffb504f2f7afa..783597123a3100 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc @@ -52,7 +52,7 @@ Status GPUDeviceContext::ThenExecute(Device* device, se::Stream* stream, const DeviceBase::AcceleratorDeviceInfo* gpu_info = device->tensorflow_accelerator_device_info(); gpu_info->event_mgr->ThenExecute(stream, func); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/gpu/gpu_virtual_mem_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_virtual_mem_allocator.h index 20a592bc23ad92..64b5e7b567e3c1 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_virtual_mem_allocator.h +++ b/tensorflow/core/common_runtime/gpu/gpu_virtual_mem_allocator.h @@ -13,106 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// CUDA virtual memory API is only available in CUDA versions greater than 10.2. - #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_VIRTUAL_MEM_ALLOCATOR_H_ #define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_VIRTUAL_MEM_ALLOCATOR_H_ -#include -#include - -#include "xla/stream_executor/stream_executor.h" -#include "tsl/framework/allocator.h" -#include "tsl/framework/device_id.h" -#include "tsl/platform/statusor.h" +#include "xla/stream_executor/integrations/gpu_virtual_mem_allocator.h" // IWYU pragma: keep #if GOOGLE_CUDA -#include "xla/stream_executor/gpu/gpu_driver.h" -#include "xla/stream_executor/gpu/gpu_types.h" -#endif - -#if CUDA_VERSION >= 10020 - namespace tensorflow { - -// GpuVirtualMemAllocator is a SubAllocator for use with BFCAllocator which -// provides contiguous allocations with each call to Alloc. This is done by -// reserving a large chunk of virtual addresses at construction and then mapping -// physical memory pages to this virtual address range as requested. -// -// This class is not thread-safe. -class GpuVirtualMemAllocator : public tsl::SubAllocator { - public: - static tsl::StatusOr> Create( - const std::vector& alloc_visitors, - const std::vector& free_visitors, - stream_executor::gpu::GpuContext& gpu_context, - tsl::PlatformDeviceId gpu_id, size_t virtual_address_space_size, - const std::vector& peer_gpu_ids); - ~GpuVirtualMemAllocator() override; - - // Allocates memory at least as large as requested by num_bytes. Will be - // aligned to the min allocation granularity (typically 2MiB). - // alignment is ignored by this allocator. - void* Alloc(size_t alignment, size_t num_bytes, - size_t* bytes_received) override; - - // Frees should only happen at the end of the contiguous memory allocations or - // else we introduce pointless fragmentation...But, this is supported. If the - // allocation happens at the end, then the next_alloc_offset_ is moved back, - // otherwise a hole is created. - // - // Holes are not re-used, all allocations continue to come at the end of the - // next_alloc_offset_. To accommodate this, the virtual_address_space_size - // should be much larger than the max physical size of the allocator. - // - // In practice, since the BFC allocator coalesces adjacent AllocationRegions, - // this free function should never be invoked. - void Free(void* ptr, size_t num_bytes) override; - - bool SupportsCoalescing() const override { return true; } - - private: - GpuVirtualMemAllocator( - const std::vector& alloc_visitors, - const std::vector& free_visitors, - stream_executor::gpu::GpuContext& gpu_context, - tsl::PlatformDeviceId gpu_id, - std::vector access_device_handles, - stream_executor::gpu::GpuDriver::VmemSpan vmem, size_t granularity); - - stream_executor::gpu::GpuContext& gpu_context_; - tsl::PlatformDeviceId gpu_id_; - - // Peer access is configured at mmap time so the allocator must be aware of - // all gpus that may want to read the memory. This list also includes the - // above gpu_id_ to facilitate the invocation of the GpuDriver::MapMemory - // function. - const std::vector access_gpu_handles_; - - // The virtual memory span held by this allocator. - stream_executor::gpu::GpuDriver::VmemSpan vmem_; - // The next offset from the vmem base address that will be allocated. This - // corresponds to the size of physically pinned memory if holes haven't been - // created with "free". - size_t next_alloc_offset_ = 0; - - // Smallest allocation as determined by CUDA. - const size_t granularity_; - - struct Mapping { - stream_executor::gpu::GpuDevicePtr va; - stream_executor::gpu::GpuDriver::GenericMemoryHandle physical; - }; - // List of mappings, sorted by va. - std::vector mappings_; - - GpuVirtualMemAllocator(const GpuVirtualMemAllocator&) = delete; - void operator=(const GpuVirtualMemAllocator&) = delete; -}; - +using stream_executor::GpuVirtualMemAllocator; } // namespace tensorflow -#endif // CUDA_VERSION >= 10200 +#endif // GOOGLE_CUDA #endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_VIRTUAL_MEM_ALLOCATOR_H_ diff --git a/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc b/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc index 0c194b1addeacc..74ff893a5bef3d 100644 --- a/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc +++ b/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/pool_allocator.h" #include "xla/stream_executor/gpu/gpu_init.h" +#include "xla/stream_executor/platform_manager.h" #include "tensorflow/core/common_runtime/device/device_host_allocator.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/test.h" @@ -26,7 +27,7 @@ namespace { TEST(PoolAllocatorTest, ZeroSizeBuffers) { se::Platform* platform = - se::MultiPlatformManager::PlatformWithName(se::GpuPlatformName()).value(); + se::PlatformManager::PlatformWithName(se::GpuPlatformName()).value(); PoolAllocator pool( 2 /*pool_size_limit*/, false /*auto_resize*/, new DeviceHostAllocator( @@ -45,7 +46,7 @@ TEST(PoolAllocatorTest, ZeroSizeBuffers) { TEST(PoolAllocatorTest, ZeroSizePool) { se::Platform* platform = - se::MultiPlatformManager::PlatformWithName(se::GpuPlatformName()).value(); + se::PlatformManager::PlatformWithName(se::GpuPlatformName()).value(); PoolAllocator pool( 0 /*pool_size_limit*/, false /*auto_resize*/, new DeviceHostAllocator( @@ -79,7 +80,7 @@ TEST(PoolAllocatorTest, ZeroSizePool) { TEST(PoolAllocatorTest, Alignment) { se::Platform* platform = - se::MultiPlatformManager::PlatformWithName(se::GpuPlatformName()).value(); + se::PlatformManager::PlatformWithName(se::GpuPlatformName()).value(); PoolAllocator pool( 0 /*pool_size_limit*/, false /*auto_resize*/, new DeviceHostAllocator( @@ -141,7 +142,7 @@ TEST(PoolAllocatorTest, CudaHostAllocator) { free_size += size; }; se::Platform* platform = - se::MultiPlatformManager::PlatformWithName(se::GpuPlatformName()).value(); + se::PlatformManager::PlatformWithName(se::GpuPlatformName()).value(); DeviceHostAllocator* sub_allocator = new DeviceHostAllocator( platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0)).value(), 0 /*numa_node*/, {alloc_visitor}, {free_visitor}); @@ -243,7 +244,7 @@ TEST(PoolAllocatorTest, Pow2Rounder) { TEST(PoolAllocatorTest, Name) { se::Platform* platform = - se::MultiPlatformManager::PlatformWithName(se::GpuPlatformName()).value(); + se::PlatformManager::PlatformWithName(se::GpuPlatformName()).value(); PoolAllocator pool( 2 /*pool_size_limit*/, false /*auto_resize*/, new DeviceHostAllocator( diff --git a/tensorflow/core/common_runtime/gradients.cc b/tensorflow/core/common_runtime/gradients.cc index 1e8b2318f3402a..9f65429171a25c 100644 --- a/tensorflow/core/common_runtime/gradients.cc +++ b/tensorflow/core/common_runtime/gradients.cc @@ -402,7 +402,7 @@ Status SymbolicGradientBuilder::Compute() { (*x_grad_node_outputs_)[i] = SumGradients(x_node_outputs_[i]); } - return OkStatus(); + return absl::OkStatus(); } Status AddSymbolicGradients(gtl::ArraySlice y_node_outputs, diff --git a/tensorflow/core/common_runtime/graph_constructor.cc b/tensorflow/core/common_runtime/graph_constructor.cc index 576b143c9b3f9c..0e1f20909eec17 100644 --- a/tensorflow/core/common_runtime/graph_constructor.cc +++ b/tensorflow/core/common_runtime/graph_constructor.cc @@ -214,7 +214,7 @@ class GraphConstructor { TF_RETURN_IF_ERROR(PopulateMissingUnusedInputMapKeys()); UpdateUniquifiedColocationNames(); FixupSourceAndSinkEdges(g_); - return OkStatus(); + return absl::OkStatus(); } private: @@ -643,7 +643,7 @@ Status GraphConstructor::EnsureNoNameCollisions() { prefix_ = strings::StrCat(FindUniqueName(prefix_no_slash), "/"); } } - return OkStatus(); + return absl::OkStatus(); } Status GraphConstructor::ValidateInputMapAndControlDependencies() { @@ -670,7 +670,7 @@ Status GraphConstructor::ValidateInputMapAndControlDependencies() { "graph"); } } - return OkStatus(); + return absl::OkStatus(); } Status GraphConstructor::BuildNodeIndex() { @@ -714,7 +714,7 @@ Status GraphConstructor::BuildNodeIndex() { // Update gdef_prefixes_. AddPrefixes(node_def.name(), &gdef_prefixes_); } - return OkStatus(); + return absl::OkStatus(); } Status GraphConstructor::InitFromEdges() { @@ -781,15 +781,15 @@ Status GraphConstructor::InitFromEdges() { } pending_count_.push_back(pending_count); } - return OkStatus(); + return absl::OkStatus(); } Status GraphConstructor::ValidateColocationConstraints( const NodeDef& node_def) { if (!opts_.validate_colocation_constraints || !opts_.importing) - return OkStatus(); + return absl::OkStatus(); const auto iter = node_def.attr().find(kColocationAttrName); - if (iter == node_def.attr().end()) return OkStatus(); + if (iter == node_def.attr().end()) return absl::OkStatus(); for (const string& c : iter->second.list().s()) { StringPiece s(c); if (absl::ConsumePrefix(&s, kColocationGroupPrefix) && @@ -799,7 +799,7 @@ Status GraphConstructor::ValidateColocationConstraints( "' expects to be colocated with unknown node '", s, "'"); } } - return OkStatus(); + return absl::OkStatus(); } Status GraphConstructor::MakeNode(NodeDef&& node_def, Node** node) { @@ -811,18 +811,18 @@ Status GraphConstructor::MakeNode(NodeDef&& node_def, Node** node) { (opts_.propagate_device_spec && !(*node)->def().device().empty())) { (*node)->set_assigned_device_name((*node)->def().device()); } - return OkStatus(); + return absl::OkStatus(); } Status GraphConstructor::ValidateShape(Node* node) { - if (!opts_.importing || !opts_.validate_shape) return OkStatus(); + if (!opts_.importing || !opts_.validate_shape) return absl::OkStatus(); TF_RETURN_IF_ERROR(refiner_->AddNode(node)); // For nodes with the _output_shapes attribute, override the shape. std::vector shape_attrs; const char* kAttrName = "_output_shapes"; if (!TryGetNodeAttr(node->attrs(), kAttrName, &shape_attrs)) { // No _output_shapes attribute, the AddNode call above was sufficient. - return OkStatus(); + return absl::OkStatus(); } auto* ic = refiner_->GetContext(node); DCHECK(ic != nullptr) @@ -860,7 +860,7 @@ Status GraphConstructor::ValidateShape(Node* node) { } } node->ClearAttr(kAttrName); - return OkStatus(); + return absl::OkStatus(); } Status GraphConstructor::ModifyNodeDefForImport(NodeDef* node_def) { @@ -871,7 +871,7 @@ Status GraphConstructor::ModifyNodeDefForImport(NodeDef* node_def) { if (versions()) { TF_RETURN_IF_ERROR(CheckOpDeprecation(*op_def, versions()->producer())); } - return OkStatus(); + return absl::OkStatus(); } void RemoveInputs(const std::vector& inputs_to_remove, NodeDef* node_def, @@ -1083,11 +1083,11 @@ Status GraphConstructor::IsNodeFullyMapped(const NodeDef& node_def, for (int i = 0; i < op_def->output_arg_size(); ++i) { if (opts_.input_map.find({node_def.name(), i}) == opts_.input_map.end()) { *is_node_mapped = false; - return OkStatus(); + return absl::OkStatus(); } } *is_node_mapped = true; - return OkStatus(); + return absl::OkStatus(); } void GraphConstructor::DFS(int cur_node, std::vector* cur_branch, @@ -1361,7 +1361,7 @@ Status GraphConstructor::Convert() { " nodes in a cycle"); } - return OkStatus(); + return absl::OkStatus(); } Status GraphConstructor::AddBackEdges() { @@ -1378,15 +1378,15 @@ Status GraphConstructor::AddBackEdges() { VLOG(2) << "Add back edge: " << src_node->name() << " -> " << e.dst_node->name(); } - return OkStatus(); + return absl::OkStatus(); } Status GraphConstructor::UpdateVersionDef() { - if (versions() == nullptr) return OkStatus(); + if (versions() == nullptr) return absl::OkStatus(); if (!opts_.importing) { g_->set_versions(*versions()); - return OkStatus(); + return absl::OkStatus(); } VersionDef g_versions = g_->versions(); g_versions.set_producer( @@ -1404,11 +1404,11 @@ Status GraphConstructor::UpdateVersionDef() { } } g_->set_versions(g_versions); - return OkStatus(); + return absl::OkStatus(); } Status GraphConstructor::PopulateReturnTensors() { - if (opts_.return_tensors.empty()) return OkStatus(); + if (opts_.return_tensors.empty()) return absl::OkStatus(); for (const TensorId& id : opts_.return_tensors) { auto iter = opts_.input_map.find(id); if (iter == opts_.input_map.end()) { @@ -1435,11 +1435,11 @@ Status GraphConstructor::PopulateReturnTensors() { return_tensors_->push_back({node, remapped_id.second}); } } - return OkStatus(); + return absl::OkStatus(); } Status GraphConstructor::PopulateReturnNodes() { - if (opts_.return_nodes.empty()) return OkStatus(); + if (opts_.return_nodes.empty()) return absl::OkStatus(); for (StringPiece name : opts_.return_nodes) { auto iter = gdef_nodes_.find(name); if (iter == gdef_nodes_.end()) { @@ -1448,11 +1448,11 @@ Status GraphConstructor::PopulateReturnNodes() { } return_nodes_->push_back(iter->second.node); } - return OkStatus(); + return absl::OkStatus(); } Status GraphConstructor::PopulateMissingUnusedInputMapKeys() { - if (missing_unused_input_map_keys_ == nullptr) return OkStatus(); + if (missing_unused_input_map_keys_ == nullptr) return absl::OkStatus(); for (const auto& input_map_pair : opts_.input_map) { TensorId key = input_map_pair.first; if (used_input_map_keys_.count(key) > 0) continue; @@ -1477,7 +1477,7 @@ Status GraphConstructor::PopulateMissingUnusedInputMapKeys() { missing_unused_input_map_keys_->push_back(key); } } - return OkStatus(); + return absl::OkStatus(); } void GraphConstructor::Undo() { @@ -1511,7 +1511,7 @@ Status GraphConstructor::MakeEdge(Node* src, int output_index, Node* dst, " incompatible with expected ", DataTypeString(dst_in), "."); } g_->AddEdge(src, output_index, dst, input_index); - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/core/common_runtime/graph_constructor_test.cc b/tensorflow/core/common_runtime/graph_constructor_test.cc index 2af347b29b3a62..2cae8ca92c3c81 100644 --- a/tensorflow/core/common_runtime/graph_constructor_test.cc +++ b/tensorflow/core/common_runtime/graph_constructor_test.cc @@ -109,7 +109,7 @@ class GraphConstructorTest : public ::testing::Test { ImportGraphDefResults* results = nullptr) { Convert(gdef_ascii); Status s = ImportGraphDef(opts, gdef_, &graph_, refiner, results); - EXPECT_EQ(OkStatus(), s) << s; + EXPECT_EQ(absl::OkStatus(), s) << s; } void ExpectVersions(int min_consumer, int producer) { @@ -184,7 +184,7 @@ Status Scalars(shape_inference::InferenceContext* c) { for (int i = 0; i < c->num_outputs(); ++i) { c->set_output(i, c->Scalar()); } - return OkStatus(); + return absl::OkStatus(); } REGISTER_OP("ABC"); @@ -812,7 +812,7 @@ versions { ImportGraphDefOptions opts; auto s = ImportGraphDef(opts, def, &graph_, nullptr); - ASSERT_EQ(OkStatus(), s) << s; + ASSERT_EQ(absl::OkStatus(), s) << s; } TEST_F(GraphConstructorTest, TypeMismatch) { @@ -933,7 +933,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef) { // Importing an empty graph is fine. Status s = ImportGraphDef(opts, def, &graph_, nullptr); - ASSERT_EQ(OkStatus(), s) << s; + ASSERT_EQ(absl::OkStatus(), s) << s; EXPECT_EQ(2, graph_.num_nodes()); EXPECT_TRUE(HasControlEdge(source, sink)); EXPECT_EQ(1, graph_.num_edges()); @@ -968,7 +968,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef) { // First import should work out fine. s = ImportGraphDef(opts, def, &graph_, nullptr); - ASSERT_EQ(OkStatus(), s) << s; + ASSERT_EQ(absl::OkStatus(), s) << s; EXPECT_EQ(5 + 2, graph_.num_nodes()); // Added nodes + source and sink EXPECT_EQ("A", ColocationGroup("B")); EXPECT_TRUE(HasEdge("A", 0, "B", 0)); @@ -989,7 +989,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef) { // But succeed if a unique prefix is provided. opts.prefix = "import"; s = ImportGraphDef(opts, def, &graph_, nullptr); - ASSERT_EQ(OkStatus(), s) << s; + ASSERT_EQ(absl::OkStatus(), s) << s; EXPECT_EQ( 10 + 2, graph_.num_nodes()); // Added nodes + original nodes + source and sink @@ -1020,7 +1020,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_DefaultAttrs) { ASSERT_TRUE(protobuf::TextFormat::ParseFromString( "node{ name:'A' op:'TestDefaultAttr'}", &def)); Status s = ImportGraphDef(ImportGraphDefOptions(), def, &graph_, nullptr); - ASSERT_EQ(OkStatus(), s) << s; + ASSERT_EQ(absl::OkStatus(), s) << s; Node* a = nullptr; for (Node* n : graph_.nodes()) { if (n->name() == "A") { @@ -1031,7 +1031,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_DefaultAttrs) { ASSERT_TRUE(a != nullptr); int value = 0; s = GetNodeAttr(a->attrs(), "default_int", &value); - ASSERT_EQ(OkStatus(), s) << s << " -- " << a->def().DebugString(); + ASSERT_EQ(absl::OkStatus(), s) << s << " -- " << a->def().DebugString(); EXPECT_EQ(31415, value); } @@ -1056,14 +1056,14 @@ TEST_F(GraphConstructorTest, ImportGraphDef_Versioning) { def.mutable_versions()->Clear(); graph_.ToGraphDef(&def); s = ImportGraphDef(opts, def, &graph_, nullptr); - EXPECT_EQ(OkStatus(), s) << s; + EXPECT_EQ(absl::OkStatus(), s) << s; def.Clear(); const int original_min_consumer = graph_.versions().min_consumer(); def.mutable_versions()->set_min_consumer(original_min_consumer + 2); def.mutable_versions()->add_bad_consumers(TF_GRAPH_DEF_VERSION - 1); s = ImportGraphDef(opts, def, &graph_, nullptr); - EXPECT_EQ(OkStatus(), s) << s; + EXPECT_EQ(absl::OkStatus(), s) << s; EXPECT_EQ(original_min_consumer + 2, graph_.versions().min_consumer()); ASSERT_EQ(1, graph_.versions().bad_consumers_size()); EXPECT_EQ(TF_GRAPH_DEF_VERSION - 1, graph_.versions().bad_consumers(0)); @@ -1162,7 +1162,7 @@ node { &def); ASSERT_TRUE(parsed); Status s = ImportGraphDef(ImportGraphDefOptions(), def, &graph_, nullptr); - EXPECT_EQ(OkStatus(), s) << s; + EXPECT_EQ(absl::OkStatus(), s) << s; Graph g2(OpRegistry::Global()); def.mutable_versions()->set_producer(10); @@ -2256,7 +2256,7 @@ versions { &def); ASSERT_TRUE(parsed); Status s = ImportGraphDef(ImportGraphDefOptions(), def, &graph_, nullptr); - EXPECT_EQ(OkStatus(), s) << s; + EXPECT_EQ(absl::OkStatus(), s) << s; } TEST_F(GraphConstructorTest, ImportGraphDef_ControlDeps) { @@ -2444,7 +2444,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ErrorsDoNoChangeTheGraph) { const string& sink = graph_.FindNodeId(Graph::kSinkId)->name(); Status s = ImportGraphDef(opts, def, &graph_, nullptr); - ASSERT_EQ(OkStatus(), s) << s; + ASSERT_EQ(absl::OkStatus(), s) << s; EXPECT_EQ(3, graph_.num_nodes()); // 'scope/A', source and sink EXPECT_TRUE(HasControlEdge(source, sink)); EXPECT_TRUE(HasControlEdge(source, "scope/A")); diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc index 69c25f0623ea28..c2e115b80c3cb9 100644 --- a/tensorflow/core/common_runtime/graph_execution_state.cc +++ b/tensorflow/core/common_runtime/graph_execution_state.cc @@ -127,7 +127,7 @@ GraphExecutionState::~GraphExecutionState() { TF_RETURN_IF_ERROR(ret->InitBaseGraph(std::move(base_graph))); *out_state = std::move(ret); } - return OkStatus(); + return absl::OkStatus(); } /* static */ Status GraphExecutionState::MakeForPrunedGraph( @@ -177,7 +177,7 @@ GraphExecutionState::~GraphExecutionState() { TF_RETURN_IF_ERROR(ret->InitBaseGraph(std::move(base_graph))); TF_RETURN_IF_ERROR(ret->BuildGraph(subgraph_options, out_client_graph)); *out_state = std::move(ret); - return OkStatus(); + return absl::OkStatus(); } Status GraphExecutionState::Extend( @@ -285,7 +285,7 @@ Status GraphExecutionState::Extend( // NOTE(mrry): Extend() is likely to be used for non-throughput-sensitive // interactive workloads, but in future we may want to transfer other // parts of the placement and/or cost model. - return OkStatus(); + return absl::OkStatus(); } void GraphExecutionState::SaveStatefulNodes(Graph* graph) { @@ -344,7 +344,7 @@ class TensorConnectionPruneRewrite : public subgraph::PruneRewrite { (*out_node)->set_assigned_device_name( feed_tensor.node->assigned_device_name()); - return OkStatus(); + return absl::OkStatus(); } private: @@ -358,12 +358,12 @@ Status LookupDevice(const DeviceSet& device_set, const string& tensor_name, *out_device_attrs = nullptr; if (tensor2device.empty()) { *out_device_attrs = &device_set.client_device()->attributes(); - return OkStatus(); + return absl::OkStatus(); } const auto it = tensor2device.find(tensor_name); if (it == tensor2device.end()) { *out_device_attrs = &device_set.client_device()->attributes(); - return OkStatus(); + return absl::OkStatus(); } DeviceNameUtils::ParsedName parsed_name; if (!DeviceNameUtils::ParseFullName(it->second, &parsed_name)) { @@ -379,7 +379,7 @@ Status LookupDevice(const DeviceSet& device_set, const string& tensor_name, "' in CallableOptions does not exist"); } *out_device_attrs = &device->attributes(); - return OkStatus(); + return absl::OkStatus(); } struct TensorAndDevice { @@ -433,7 +433,7 @@ bool IsFeedAndFetchSupported(DataType dtype, const string& device_type) { Status ValidateFeedAndFetchDevices( const Graph& graph, const std::vector& tensors_and_devices) { - if (tensors_and_devices.empty()) return OkStatus(); + if (tensors_and_devices.empty()) return absl::OkStatus(); std::vector found(tensors_and_devices.size(), false); for (const Node* node : graph.nodes()) { // Linearly looping through all nodes and then all feed+fetch tensors isn't @@ -465,7 +465,7 @@ Status ValidateFeedAndFetchDevices( "in the Graph"); } } - return OkStatus(); + return absl::OkStatus(); } Status GetFeedShapeAndTypeFromAttribute(const NodeDef& node, @@ -500,7 +500,7 @@ Status GetFeedShapeAndTypeFromAttribute(const NodeDef& node, return errors::InvalidArgument("Could not determine shape for feed node: ", node.name(), " of type ", node.op()); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -608,7 +608,7 @@ Status GraphExecutionState::PruneGraph( for (int i = 0; i < options.callable_options.tensor_connection_size(); ++i) { out_rewrite_metadata->feed_types.pop_back(); } - return OkStatus(); + return absl::OkStatus(); } Status GraphExecutionState::InitBaseGraph(std::unique_ptr&& new_graph) { @@ -646,7 +646,7 @@ Status GraphExecutionState::InitBaseGraph(std::unique_ptr&& new_graph) { SaveStatefulNodes(new_graph.get()); graph_ = new_graph.release(); - return OkStatus(); + return absl::OkStatus(); } Status GraphExecutionState::OptimizeGraph( @@ -849,7 +849,7 @@ Status GraphExecutionState::OptimizeGraph( for (Node* node : optimized_graph->get()->nodes()) { node->set_assigned_device_name(node->requested_device()); } - return OkStatus(); + return absl::OkStatus(); } else { return errors::InvalidArgument("Meta Optimizer disabled"); } @@ -974,7 +974,7 @@ Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options, // TODO(vrv): We should check invariants of the graph here. metrics::UpdateGraphBuildTime(Env::Default()->NowMicros() - start_time_usecs); *out = std::move(dense_copy); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc index 87de162a87455b..f834ec226df365 100644 --- a/tensorflow/core/common_runtime/graph_runner.cc +++ b/tensorflow/core/common_runtime/graph_runner.cc @@ -63,13 +63,13 @@ class SimpleRendezvous : public RendezvousInterface { return errors::Internal("Send of an already sent tensor"); } table_[edge_name] = val; - return OkStatus(); + return absl::OkStatus(); } void RecvAsync(const ParsedKey& parsed, const Args& recv_args, DoneCallback done) override { Tensor tensor; - Status status = OkStatus(); + Status status = absl::OkStatus(); { string key(parsed.edge_name); mutex_lock l(mu_); @@ -205,7 +205,7 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library, (*outputs)[i] = tensor::DeepCopy(output_tensor); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/graph_view.cc b/tensorflow/core/common_runtime/graph_view.cc index 48aa6d2c80bed9..29458524bd5051 100644 --- a/tensorflow/core/common_runtime/graph_view.cc +++ b/tensorflow/core/common_runtime/graph_view.cc @@ -285,7 +285,7 @@ Status GraphView::Initialize(const Graph* g) { } } CHECK_EQ(ptr, space_ + total_bytes); - return OkStatus(); + return absl::OkStatus(); } namespace { diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc index ed3ae517a979cd..1556290485c17e 100644 --- a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc +++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc @@ -182,7 +182,7 @@ Status HierarchicalTreeBroadcaster::InitializeCollectiveParams( } VLOG(2) << collective_util::SubdivPermDebugString(*col_params); - return OkStatus(); + return absl::OkStatus(); } Status HierarchicalTreeBroadcaster::InitializeCollectiveContext( diff --git a/tensorflow/core/common_runtime/immutable_executor_state.cc b/tensorflow/core/common_runtime/immutable_executor_state.cc index ea9b5cc8d9f97d..e3a2435505e041 100644 --- a/tensorflow/core/common_runtime/immutable_executor_state.cc +++ b/tensorflow/core/common_runtime/immutable_executor_state.cc @@ -347,7 +347,7 @@ Status ImmutableExecutorState::BuildControlFlowInfo(const Graph* g, } } - return OkStatus(); + return absl::OkStatus(); } void ImmutableExecutorState::InitializePending(const Graph* graph, diff --git a/tensorflow/core/common_runtime/inline_function_utils.cc b/tensorflow/core/common_runtime/inline_function_utils.cc index 625940689c6eb7..67250c6abbfc13 100644 --- a/tensorflow/core/common_runtime/inline_function_utils.cc +++ b/tensorflow/core/common_runtime/inline_function_utils.cc @@ -278,7 +278,7 @@ Status ValidateNoInline(const FunctionBody* fbody) { return errors::InvalidArgument( "Can't inline function marked with '_noinline'"); } - return OkStatus(); + return absl::OkStatus(); } using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource; @@ -394,7 +394,7 @@ Status ValidateInlining(const Node* node, const FunctionBody* fbody, TF_RETURN_IF_ERROR(ValidateNoInline(fbody)); } - return OkStatus(); + return absl::OkStatus(); } // Function inlining must preserve function execution semantics with regards to @@ -857,7 +857,7 @@ Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g, VLOG(4) << "Final graph: " << g->ToGraphDefDebug().DebugString(); - return OkStatus(); + return absl::OkStatus(); } bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph, diff --git a/tensorflow/core/common_runtime/inspecting_placer.cc b/tensorflow/core/common_runtime/inspecting_placer.cc index ea6a1d46b3ecc5..a84cd700874d8c 100644 --- a/tensorflow/core/common_runtime/inspecting_placer.cc +++ b/tensorflow/core/common_runtime/inspecting_placer.cc @@ -102,7 +102,7 @@ class ColocationGraphToIOColocationGroups { const Member& member = colocation_graph_->members()[it.first]; TF_RETURN_IF_ERROR(member.FillPossibleDevices(&possible_devices)); } - return OkStatus(); + return absl::OkStatus(); } private: @@ -156,7 +156,7 @@ Status InspectingPlacer::ComputeIOColocationGroups(const Node& node, converter.AssignGroups(fbody->arg_nodes, &groups->input_groups); converter.AssignGroups(fbody->ret_nodes, &groups->output_groups); TF_RETURN_IF_ERROR(converter.FillGroups(&groups->group_devices)); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/int32_fulltype.cc b/tensorflow/core/common_runtime/int32_fulltype.cc index 9971a6cf3b9fbb..ab2ef6867d122b 100644 --- a/tensorflow/core/common_runtime/int32_fulltype.cc +++ b/tensorflow/core/common_runtime/int32_fulltype.cc @@ -58,7 +58,7 @@ Status Int32FulltypePass::Int32FullTypeForTensor(DataType dtype, tensor_t->set_type_id(TFT_SHAPE_TENSOR); (*tensor_t->add_args()) = data_t; } - return OkStatus(); + return absl::OkStatus(); } static bool is_host_memory_int32(MemoryType mtype, DataType dtype) { @@ -135,7 +135,7 @@ Status Int32FulltypePass::ProcessGraph(Graph* graph, bool ints_on_device) { << t.DebugString(); } } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/int32_fulltype_test.cc b/tensorflow/core/common_runtime/int32_fulltype_test.cc index 6e816cf88339fc..e6ead597aea23e 100644 --- a/tensorflow/core/common_runtime/int32_fulltype_test.cc +++ b/tensorflow/core/common_runtime/int32_fulltype_test.cc @@ -69,7 +69,7 @@ class Int32FulltypeTest : public ::testing::Test { Status BuildGraph(const GraphDefBuilder& builder, Graph* out_graph) { TF_RETURN_IF_ERROR(GraphDefBuilderToGraph(builder, out_graph)); RebuildNodeNameMap(*out_graph); - return OkStatus(); + return absl::OkStatus(); } void AddTensorFT(FullTypeDef& t, tensorflow::FullTypeId out_t_id, diff --git a/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass.cc b/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass.cc index 698481c8ca946b..7945b5cc3cf1c6 100644 --- a/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass.cc +++ b/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass.cc @@ -28,7 +28,7 @@ Status IsolatePlacerInspectionRequiredOpsPass::Run( if (options.graph == nullptr) { VLOG(1) << "Not running IsolatePlacerInspectionRequiredOpsPass because no " "graph is provided"; - return OkStatus(); + return absl::OkStatus(); } VLOG(1) << "IsolatePlacerInspectionRequiredOpsPass::Run"; diff --git a/tensorflow/core/common_runtime/lower_case_op.cc b/tensorflow/core/common_runtime/lower_case_op.cc index e0e10fb7e43448..4727c941ec6efc 100644 --- a/tensorflow/core/common_runtime/lower_case_op.cc +++ b/tensorflow/core/common_runtime/lower_case_op.cc @@ -136,7 +136,7 @@ Status CaseBuilder::CreatePivotNodes() { .Device(case_op_->requested_device()) .Finalize(graph_, &pivots_[b])); } - return OkStatus(); + return absl::OkStatus(); } string CaseBuilder::NewName(const string& infix) { @@ -166,7 +166,7 @@ Status CaseBuilder::AddInput(Node* src, int src_output) { for (int b = 0; b < num_branches_; b++) { branch_call_builders_[b].Input(input, b); } - return OkStatus(); + return absl::OkStatus(); } Status CaseBuilder::AddInputs() { @@ -184,7 +184,7 @@ Status CaseBuilder::AddInputs() { graph_->AddControlEdge(e->src(), control_predecessor_); } } - return OkStatus(); + return absl::OkStatus(); } Status CaseBuilder::AddOutputs() { @@ -247,7 +247,7 @@ Status CaseBuilder::AddOutputs() { graph_->AddEdge(merges[e->src_output()], 0, e->dst(), e->dst_input()); } } - return OkStatus(); + return absl::OkStatus(); } Status CaseBuilder::BuildLoweredCaseOutput() { @@ -287,7 +287,7 @@ Status RewriteCaseNode(Node* n, Graph* g, bool keep_node_fetchable) { TF_RETURN_IF_ERROR(cb.AddOutputs()); g->RemoveNode(n); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/lower_function_call_op.cc b/tensorflow/core/common_runtime/lower_function_call_op.cc index b5077da20d2398..b05509e5246e9e 100644 --- a/tensorflow/core/common_runtime/lower_function_call_op.cc +++ b/tensorflow/core/common_runtime/lower_function_call_op.cc @@ -74,7 +74,7 @@ Status RewriteFunctionCallNode(Node* n, Graph* g, fdef = flib_def.FindRecord(func.name()); } else if (n->type_string() == FunctionLibraryDefinition::kGradientOp) { VLOG(2) << "Skip SymbolicGradient lowering"; - return OkStatus(); + return absl::OkStatus(); } else { fdef = flib_def.FindRecord(n->type_string()); } @@ -97,7 +97,7 @@ Status RewriteFunctionCallNode(Node* n, Graph* g, << can_inline_function_call.message(); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/lower_functional_ops.cc b/tensorflow/core/common_runtime/lower_functional_ops.cc index 6894e677fdd112..62c712a0d15360 100644 --- a/tensorflow/core/common_runtime/lower_functional_ops.cc +++ b/tensorflow/core/common_runtime/lower_functional_ops.cc @@ -105,7 +105,7 @@ Status LowerFunctionalOpsPass::Run( "Lowering If/While ops should happen before partitioning."); } if (options.graph == nullptr) { - return OkStatus(); + return absl::OkStatus(); } Graph* g = options.graph->get(); @@ -209,7 +209,7 @@ Status LowerFunctionalOpsPass::Run( }, IsPropagatableDevice, g); - return OkStatus(); + return absl::OkStatus(); } REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 10, diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc index 17799a50f90fa0..a2875c7c823b52 100644 --- a/tensorflow/core/common_runtime/lower_if_op.cc +++ b/tensorflow/core/common_runtime/lower_if_op.cc @@ -169,7 +169,7 @@ Status CondBuilder::CreatePivotNodes() { .Input(switch_pred, kThenBranch) .Device(if_op_->requested_device()), graph_, &pivot_t_)); - return OkStatus(); + return absl::OkStatus(); } string CondBuilder::NewName(const string& infix) { @@ -202,7 +202,7 @@ Status CondBuilder::AddInput(Node* src, int src_output) { .Finalize(graph_, &input)); then_call_builder_.Input(input, kThenBranch); else_call_builder_.Input(input, kElseBranch); - return OkStatus(); + return absl::OkStatus(); } Status CondBuilder::AddInputs() { @@ -220,7 +220,7 @@ Status CondBuilder::AddInputs() { graph_->AddControlEdge(e->src(), control_predecessor_); } } - return OkStatus(); + return absl::OkStatus(); } Status CondBuilder::AddOutputs() { @@ -276,7 +276,7 @@ Status CondBuilder::AddOutputs() { } } - return OkStatus(); + return absl::OkStatus(); } Status CondBuilder::BuildLoweredIfOutput() { @@ -317,7 +317,7 @@ Status RewriteIfNode(Node* n, Graph* g, bool keep_node_fetchable) { TF_RETURN_IF_ERROR(cb.AddOutputs()); g->RemoveNode(n); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/lower_while_op.cc b/tensorflow/core/common_runtime/lower_while_op.cc index 16e4ac7b0d0a95..67cab4576b44ab 100644 --- a/tensorflow/core/common_runtime/lower_while_op.cc +++ b/tensorflow/core/common_runtime/lower_while_op.cc @@ -236,7 +236,7 @@ Status LowerWhileHelper::RunInternal() { TF_RETURN_IF_ERROR(CreateNextIterationNodes()); TF_RETURN_IF_ERROR(UpdateMergeNodes()); TF_RETURN_IF_ERROR(UpdateConsumers()); - return OkStatus(); + return absl::OkStatus(); } void LowerWhileHelper::InitializeInputOutputToLoweredNodeMap() { @@ -301,7 +301,7 @@ Status LowerWhileHelper::CreateEnterNodes() { graph_->AddControlEdge(incoming_control_node, n); } } - return OkStatus(); + return absl::OkStatus(); } Status LowerWhileHelper::CreateMergeNodes() { @@ -325,7 +325,7 @@ Status LowerWhileHelper::CreateMergeNodes() { TF_RETURN_IF_ERROR(builder.Finalize(graph_, &merge_node)); merge_nodes_.emplace_back(merge_node); } - return OkStatus(); + return absl::OkStatus(); } Status LowerWhileHelper::CreateCondFuncCallNode() { @@ -354,7 +354,7 @@ Status LowerWhileHelper::CreateCondFuncCallNode() { } } TF_RETURN_IF_ERROR(builder.Finalize(graph_, &loop_cond_node_)); - return OkStatus(); + return absl::OkStatus(); } Status LowerWhileHelper::CreateSwitchNodes() { @@ -390,7 +390,7 @@ Status LowerWhileHelper::CreateSwitchNodes() { TF_RETURN_IF_ERROR(builder.Finalize(graph_, &switch_node)); switch_nodes_.emplace_back(switch_node); } - return OkStatus(); + return absl::OkStatus(); } Status LowerWhileHelper::CreateBodyFuncCallNode() { @@ -428,7 +428,7 @@ Status LowerWhileHelper::CreateBodyFuncCallNode() { } TF_RETURN_IF_ERROR(builder.Finalize(graph_, &body_control_node_)); graph_->AddControlEdge(body_control_node_, body_call_node_); - return OkStatus(); + return absl::OkStatus(); } Status LowerWhileHelper::CreateExitNodes() { @@ -504,7 +504,7 @@ Status LowerWhileHelper::CreateExitNodes() { .Finalize(graph_, &lowered_while_output_)); } - return OkStatus(); + return absl::OkStatus(); } Status LowerWhileHelper::CreateNextIterationNodes() { @@ -530,7 +530,7 @@ Status LowerWhileHelper::CreateNextIterationNodes() { TF_RETURN_IF_ERROR(builder.Finalize(graph_, &next_iteration)); next_iterations_nodes_.emplace_back(next_iteration); } - return OkStatus(); + return absl::OkStatus(); } Status LowerWhileHelper::UpdateMergeNodes() { @@ -538,7 +538,7 @@ Status LowerWhileHelper::UpdateMergeNodes() { TF_RETURN_IF_ERROR( graph_->UpdateEdge(next_iterations_nodes_[i], 0, merge_nodes_[i], 1)); } - return OkStatus(); + return absl::OkStatus(); } Status LowerWhileHelper::UpdateConsumers() { @@ -565,7 +565,7 @@ Status LowerWhileHelper::UpdateConsumers() { } } } - return OkStatus(); + return absl::OkStatus(); } string LowerWhileHelper::NewName(const string& infix) { @@ -615,7 +615,7 @@ Status RewriteWhileNode(Node* n, Graph* g, flib_def, keep_node_fetchable)); g->RemoveNode(n); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/memory_types.cc b/tensorflow/core/common_runtime/memory_types.cc index 7f555718947be5..d90981a6e883fc 100644 --- a/tensorflow/core/common_runtime/memory_types.cc +++ b/tensorflow/core/common_runtime/memory_types.cc @@ -52,7 +52,7 @@ static Status ProcessMemoryTypes( if (device_type != DEVICE_GPU && !DeviceFactory::IsPluggableDevice(device_type.type_string())) { // On non-GPU devices, HOST_MEMORY and DEVICE_MEMORY are always compatible. - return OkStatus(); + return absl::OkStatus(); } // For GPU, HOST_MEMORY and DEVICE_MEMORY is not compatible. I.e., a // conversion/transfer must be done. @@ -89,14 +89,14 @@ static Status ProcessMemoryTypes( << dm; TF_RETURN_IF_ERROR(fn(e, sm, dm)); } - return OkStatus(); + return absl::OkStatus(); } Status ValidateMemoryTypes(const DeviceType& device_type, const Graph* g) { return ProcessMemoryTypes( device_type, g, [](const Edge* e, MemoryType sm, MemoryType dm) { if (sm == dm) { - return OkStatus(); + return absl::OkStatus(); } return errors::Internal("Memory type mismatch (", sm, " ", dm, ") between :", e->src()->id(), ":", @@ -164,12 +164,12 @@ Status EnsureMemoryTypes(const DeviceType& device_type, TF_RETURN_IF_ERROR(ProcessMemoryTypes( device_type, g, [&edges](const Edge* e, MemoryType sm, MemoryType dm) { if (sm == dm) { - return OkStatus(); + return absl::OkStatus(); } if (((sm == HOST_MEMORY) && (dm == DEVICE_MEMORY)) || ((sm == DEVICE_MEMORY) && (dm == HOST_MEMORY))) { edges.push_back({e, sm, dm}); - return OkStatus(); + return absl::OkStatus(); } return errors::Internal("Unexpected memory type pair on an edge: ", sm, " vs. ", dm); @@ -226,7 +226,7 @@ Status MemoryTypeForOutput(const DeviceType& device_type, const Graph* g, " that has only ", out_mvec.size(), " outputs"); } *memory_type = out_mvec[index]; - return OkStatus(); + return absl::OkStatus(); } } // end namespace tensorflow diff --git a/tensorflow/core/common_runtime/mkl_layout_pass.cc b/tensorflow/core/common_runtime/mkl_layout_pass.cc index da1ad8dc2c089a..02c62ddd32245e 100644 --- a/tensorflow/core/common_runtime/mkl_layout_pass.cc +++ b/tensorflow/core/common_runtime/mkl_layout_pass.cc @@ -1,4 +1,4 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -46,7 +46,6 @@ limitations under the License. #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/mkl_heuristics.h" -#include "tensorflow/core/util/onednn_env_vars.h" #include "tensorflow/core/util/tensor_format.h" #include "tensorflow/core/util/util.h" @@ -371,7 +370,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.tanh = "Tanh"; csinfo_.tanh_grad = "TanhGrad"; csinfo_.reshape = "Reshape"; - csinfo_.sparse_matrix_matmul = "SparseMatrixMatMul"; csinfo_.slice = "Slice"; csinfo_.softmax = "Softmax"; csinfo_.split = "Split"; @@ -542,12 +540,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back({csinfo_.matmul, mkl_op_registry::GetMklOpName(csinfo_.matmul), CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange}); -#ifdef ENABLE_ONEDNN_V3 - rinfo_.push_back( - {csinfo_.sparse_matrix_matmul, - mkl_op_registry::GetMklOpName(csinfo_.sparse_matrix_matmul), - CopyAttrsAll, SparseMatrixMatMulRewrite, kRewriteForOpNameChange}); -#endif rinfo_.push_back({csinfo_.leakyrelu, mkl_op_registry::GetMklOpName(csinfo_.leakyrelu), CopyAttrsAll, LeakyReluRewrite, GetRewriteCause()}); @@ -993,7 +985,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { string tanh_grad; string transpose; string reshape; - string sparse_matrix_matmul; string slice; string softmax; string split; @@ -1550,50 +1541,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { } return false; } - - static bool SparseMatrixMatMulRewrite(const Node* n) { - DataType T; - const TensorProto* proto = nullptr; - Tensor tensor; - bool adjoint_a, adjoint_b, transpose_a, transpose_b, transpose_out; - - // Check the environment variable. - if (!UseOnednnSpmm()) { - VLOG(2) << "TF_ENABLE_ONEDNN_SPMM is disabled"; - return false; - } else { - VLOG(2) << "TF_ENABLE_ONEDNN_SPMM is enabled"; - } - - // Check the datatype. - TF_CHECK_OK(GetNodeAttr(n->def(), "T", &T)); - if (T != DT_FLOAT) { - VLOG(2) << "_MklSparseMatrixMatMul only supports DT_FLOAT"; - return false; - } - - // Check for adjointing. - TF_CHECK_OK(GetNodeAttr(n->def(), "adjoint_a", &adjoint_a)); - TF_CHECK_OK(GetNodeAttr(n->def(), "adjoint_b", &adjoint_b)); - if (adjoint_a || adjoint_b) { - VLOG(2) - << "_MklNativeSparseMatrixMatMul doesn't support adjointing matrices"; - return false; - } - - // Check for transposing. - TF_CHECK_OK(GetNodeAttr(n->def(), "transpose_a", &transpose_a)); - TF_CHECK_OK(GetNodeAttr(n->def(), "transpose_b", &transpose_b)); - TF_CHECK_OK(GetNodeAttr(n->def(), "transpose_output", &transpose_out)); - if (transpose_a || transpose_b || transpose_out) { - VLOG(2) << "_MklNativeSparseMatrixMatMul doesn't support transposing " - "matrices"; - return false; - } - - return true; - } - // For oneDNN, only int32 is supported for axis data type static bool ConcatV2Rewrite(const Node* n) { DataType T; diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_helper.cc b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_helper.cc index f1d715ccf5c079..0750d3b6205946 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_helper.cc +++ b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_helper.cc @@ -101,9 +101,9 @@ using DoneCallbackParamPtr = std::unique_ptr; SendParamDeleter MakeSendParamDeleter(); -StatusOr SendParamsToC(const RendezvousInterface::ParsedKey& key, - const RendezvousInterface::Args& args, - const Tensor& tensor, bool is_dead); +absl::StatusOr SendParamsToC( + const RendezvousInterface::ParsedKey& key, + const RendezvousInterface::Args& args, const Tensor& tensor, bool is_dead); void RendezvousCallbackThunk(void* context, TF_RendezvousDoneCallback_Params* params) { @@ -168,9 +168,10 @@ SendParamDeleter MakeSendParamDeleter() { }; } -StatusOr SendParamsToC(const RendezvousInterface::ParsedKey& key, - const RendezvousInterface::Args& args, - const Tensor& tensor, const bool is_dead) { +absl::StatusOr SendParamsToC( + const RendezvousInterface::ParsedKey& key, + const RendezvousInterface::Args& args, const Tensor& tensor, + const bool is_dead) { TF_RendezvousSend_Params* params = new TF_RendezvousSend_Params(); params->key = new TF_RendezvousParsedKey(ToC(key)); params->args = new TF_RendezvousArgsStruct(ToC(args)); diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_internal.cc b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_internal.cc index b63803e3012c15..87b6a7dc8c652e 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_internal.cc +++ b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_internal.cc @@ -113,7 +113,7 @@ DoneCallbackParamDeleter MakeDoneCallbackParamDeleter() { }; } -StatusOr DoneCallbackParamsToC( +absl::StatusOr DoneCallbackParamsToC( const Status& status, const RendezvousInterface::Args& sender_args, const RendezvousInterface::Args& recver_args, const Tensor& tensor, const bool is_dead) { diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_test.cc b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_test.cc index 50ee7d5d641a9d..3b391d00bceab9 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_test.cc +++ b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_test.cc @@ -74,7 +74,7 @@ class FakeAllocator : public Allocator { class FakeDevice : public Device { public: explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {} - Status Sync() override { return OkStatus(); } + Status Sync() override { return absl::OkStatus(); } Allocator* GetAllocator(AllocatorAttributes) override { return allocator_.get(); } diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent_test.cc b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent_test.cc index c6f351d59ba608..9070a7e2dcf146 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent_test.cc +++ b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent_test.cc @@ -208,7 +208,7 @@ TEST_F(CPluginCoordinationServiceAgentTest, GetKeyValue_Simple_Success) { kv->set_value(test_value); ON_CALL(*GetClient(), GetKeyValueAsync(_, _, _, _)) .WillByDefault(DoAll(SetArgPointee<2>(mocked_response), - InvokeArgument<3>(OkStatus()))); + InvokeArgument<3>(absl::OkStatus()))); // Initialize coordination agent. InitializeAgent(); @@ -228,7 +228,7 @@ TEST_F(CPluginCoordinationServiceAgentTest, GetKeyValue_WithTimeout_Success) { kv->set_value(test_value); ON_CALL(*GetClient(), GetKeyValueAsync(_, _, _, _)) .WillByDefault(DoAll(SetArgPointee<2>(mocked_response), - InvokeArgument<3>(OkStatus()))); + InvokeArgument<3>(absl::OkStatus()))); // Initialize coordination agent. InitializeAgent(); @@ -282,7 +282,7 @@ TEST_F(CPluginCoordinationServiceAgentTest, InsertKeyValue_Success) { EXPECT_CALL(*GetClient(), InsertKeyValueAsync(Pointee(EqualsProto(expected_input)), _, _)) - .WillOnce(InvokeArgument<2>(OkStatus())); + .WillOnce(InvokeArgument<2>(absl::OkStatus())); InitializeAgent(); TF_ASSERT_OK(agent_->InsertKeyValue(test_key, test_value)); @@ -296,7 +296,7 @@ TEST_F(CPluginCoordinationServiceAgentTest, DeleteKeyValue_Success) { EXPECT_CALL(*GetClient(), DeleteKeyValueAsync(Pointee(EqualsProto(expected_input)), _, _)) - .WillOnce(InvokeArgument<2>(OkStatus())); + .WillOnce(InvokeArgument<2>(absl::OkStatus())); InitializeAgent(); TF_ASSERT_OK(agent_->DeleteKeyValue(test_key)); @@ -312,7 +312,7 @@ TEST_F(CPluginCoordinationServiceAgentTest, TryGetKeyValue_Simple_Success) { kv->set_value(test_value); ON_CALL(*GetClient(), TryGetKeyValueAsync(_, _, _)) .WillByDefault(DoAll(SetArgPointee<1>(mocked_response), - InvokeArgument<2>(OkStatus()))); + InvokeArgument<2>(absl::OkStatus()))); // Initialize coordination agent. InitializeAgent(); diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_variable.cc b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_variable.cc index 63a713f86ecefa..9ffe65ba8f7243 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_variable.cc +++ b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_variable.cc @@ -27,11 +27,11 @@ namespace tensorflow { CPluginVariable::~CPluginVariable() { TF_DeleteVariableInfo(var_info_); } -tsl::Status CPluginVariable::GetTensorInternal() { +absl::Status CPluginVariable::GetTensorInternal() { // Note: we assume once a variable is initialized, it's underlying tensor // won't change during it's lifecycle. if (tensor_obtained_) { - return tsl::OkStatus(); + return absl::OkStatus(); } TF_StatusPtr c_status_ptr(TF_NewStatus()); TF_Tensor* c_tensor = @@ -42,21 +42,21 @@ tsl::Status CPluginVariable::GetTensorInternal() { } TF_RETURN_IF_ERROR(TF_TensorToTensor(c_tensor_ptr.get(), &tensor_)); tensor_obtained_ = true; - return tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status CPluginVariable::GetTensor(const Tensor** result_tensor) { +absl::Status CPluginVariable::GetTensor(const Tensor** result_tensor) { TF_RETURN_IF_ERROR(GetTensorInternal()); *result_tensor = &tensor_; - return tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status CPluginVariable::GetMutableTensor(Tensor** result_tensor) { +absl::Status CPluginVariable::GetMutableTensor(Tensor** result_tensor) { // Note: we assume once a variable is initialized, it's underlying tensor // won't change during it's lifecycle. TF_RETURN_IF_ERROR(GetTensorInternal()); *result_tensor = &tensor_; - return tsl::OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_variable.h b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_variable.h index 7cca7481e4d7ff..55b6de0adc6585 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_variable.h +++ b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_variable.h @@ -29,16 +29,16 @@ class CPluginVariable : public PluginVariable { ~CPluginVariable() override; explicit CPluginVariable(TF_VariableInfo* var_info) : var_info_(var_info) {} - tsl::Status GetTensor(const Tensor** result_tensor) override; + absl::Status GetTensor(const Tensor** result_tensor) override; - tsl::Status GetMutableTensor(Tensor** result_tensor) override; + absl::Status GetMutableTensor(Tensor** result_tensor) override; TF_VariableInfo* GetVariableInfo() { return var_info_; } friend class CPluginOpKernelContext; private: - tsl::Status GetTensorInternal(); + absl::Status GetTensorInternal(); TF_VariableInfo* var_info_; // Owned. Cleared by destructor. bool tensor_obtained_ = false; diff --git a/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_op_kernel.cc b/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_op_kernel.cc index feedccee815bcf..63a147facf9928 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_op_kernel.cc +++ b/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_op_kernel.cc @@ -74,7 +74,7 @@ Status DirectPluginOpKernelContext::CreatePluginVariable( TF_RETURN_IF_ERROR(LookupResource(ctx_, handle, &var)); *variable = new DirectPluginVariable(index, handle.name(), var); - return tsl::OkStatus(); + return absl::OkStatus(); } Status DirectPluginOpKernelContext::AllocateTempForPluginVariable( @@ -112,16 +112,16 @@ Status DirectPluginOpKernelContext::LookupOrCreateResource( void* opaque_plugin_resource = create_func(create_func_args); *new_resource = new tensorflow::PluginResource( opaque_plugin_resource, plugin_resource_name, delete_func); - return tensorflow::OkStatus(); + return absl::OkStatus(); })); tf_plugin_resource_ptr.reset(tf_plugin_resource); *result_plugin_resource = tf_plugin_resource_ptr->GetOpaquePluginResource(); - return OkStatus(); + return absl::OkStatus(); } Status DirectPluginOpKernelContext::GetInput(int index, Tensor* tensor) const { *tensor = ctx_->input(index); - return OkStatus(); + return absl::OkStatus(); } Status DirectPluginOpKernelContext::GetInput(const char* name, diff --git a/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_op_kernel.h b/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_op_kernel.h index 7fd23cdf87dc00..4abb6e627fb6d6 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_op_kernel.h +++ b/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_op_kernel.h @@ -120,7 +120,7 @@ class DirectPluginOpKernelContext : public PluginOpKernelContext { Status GetConfigProto(const ConfigProto** config_proto) const override { *config_proto = ctx_->function_library()->config_proto(); - return OkStatus(); + return absl::OkStatus(); } void MaybeDeleteConfigProto(const ConfigProto* config_proto) const override { @@ -131,7 +131,7 @@ class DirectPluginOpKernelContext : public PluginOpKernelContext { Status GetFunctionLibraryDefinition( const FunctionLibraryDefinition** flib_def) const override { *flib_def = ctx_->function_library()->GetFunctionLibraryDefinition(); - return OkStatus(); + return absl::OkStatus(); } void MaybeDeleteFunctionLibraryDefinition( @@ -143,7 +143,7 @@ class DirectPluginOpKernelContext : public PluginOpKernelContext { Status GetResourceHandle(int index, const ResourceHandle** handle) const override { *handle = &HandleFromInput(ctx_, index); - return OkStatus(); + return absl::OkStatus(); } void MaybeDeleteResourceHandle(const ResourceHandle* handle) const override { @@ -162,7 +162,7 @@ class DirectPluginOpKernelContext : public PluginOpKernelContext { Status SetOutput(int index, const Tensor& tensor) override { ctx_->set_output(index, tensor); - return OkStatus(); + return absl::OkStatus(); } void CtxFailure(const Status& status) override { ctx_->CtxFailure(status); } diff --git a/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_variable.h b/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_variable.h index 525cea97b5f234..7f1c071bdc4abb 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_variable.h +++ b/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_variable.h @@ -29,14 +29,14 @@ class DirectPluginOpKernelContext; class DirectPluginVariable : public PluginVariable { public: DirectPluginVariable(int index, const std::string& name, Var* var); - tsl::Status GetTensor(const Tensor** result_tensor) override { + absl::Status GetTensor(const Tensor** result_tensor) override { *result_tensor = var_info_.var()->tensor(); - return tsl::OkStatus(); + return absl::OkStatus(); } - tsl::Status GetMutableTensor(Tensor** result_tensor) override { + absl::Status GetMutableTensor(Tensor** result_tensor) override { *result_tensor = var_info_.var()->tensor(); - return tsl::OkStatus(); + return absl::OkStatus(); } VariableInfo* GetVariableInfo() { return &var_info_; } diff --git a/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device.cc b/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device.cc index ab0baa29025dc5..09a2fff9f37177 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device.cc +++ b/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device.cc @@ -105,7 +105,7 @@ void NextPluggableDevice::ComputeAsync(AsyncOpKernel* op_kernel, } // TODO(chuanhao): implement NextPluggableDevice::Sync(). -Status NextPluggableDevice::Sync() { return OkStatus(); } +Status NextPluggableDevice::Sync() { return absl::OkStatus(); } // TODO(chuanhao): implement NextPluggableDevice::Sync(). void NextPluggableDevice::Sync(const DoneCallback& done) { done(Sync()); } @@ -113,7 +113,7 @@ void NextPluggableDevice::Sync(const DoneCallback& done) { done(Sync()); } Status NextPluggableDevice::TryGetDeviceContext(DeviceContext** out_context) { *out_context = device_context_.get(); (*out_context)->Ref(); - return OkStatus(); + return absl::OkStatus(); } Status NextPluggableDevice::MakeTensorFromProto( @@ -195,7 +195,7 @@ Status NextPluggableDevice::MakeTensorFromProto( device_context_->CopyCPUTensorToDevice(&from, this, copy_dst, std::move(wrapped_done), true /*sync_dst_compute*/); - return OkStatus(); + return absl::OkStatus(); }; Status s; diff --git a/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_factory.cc b/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_factory.cc index 4c2c8633b6ae46..ac7bc33925df8c 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_factory.cc +++ b/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_factory.cc @@ -67,7 +67,7 @@ Status NextPluggableDeviceFactory::ListPhysicalDevices( devices->push_back(device_name); } - return OkStatus(); + return absl::OkStatus(); } Status NextPluggableDeviceFactory::CreateDevices( @@ -84,7 +84,7 @@ Status NextPluggableDeviceFactory::CreateDevices( TF_DeleteStatus(c_status); if (visible_device_count <= 0) { - return OkStatus(); + return absl::OkStatus(); } const absl::flat_hash_map device_count_map( session_options.config.device_count().begin(), @@ -118,7 +118,7 @@ Status NextPluggableDeviceFactory::CreateDevices( LOG(INFO) << "Created " << num_tf_devices << " TensorFlow NextPluggableDevices. " << "Physical device type: " << device_type_; - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/next_pluggable_device/plugin_variable.h b/tensorflow/core/common_runtime/next_pluggable_device/plugin_variable.h index 21692601cb1db7..ab2ec9a29fd353 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/plugin_variable.h +++ b/tensorflow/core/common_runtime/next_pluggable_device/plugin_variable.h @@ -36,9 +36,9 @@ class PluginVariable { // `result_tensor` will point to the tensor possessed by the variable if // status is ok. - virtual tsl::Status GetTensor(const Tensor** result_tensor) = 0; + virtual absl::Status GetTensor(const Tensor** result_tensor) = 0; - virtual tsl::Status GetMutableTensor(Tensor** result_tensor) = 0; + virtual absl::Status GetMutableTensor(Tensor** result_tensor) = 0; }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/node_file_writer.cc b/tensorflow/core/common_runtime/node_file_writer.cc index c6ad3b7f94ae01..10fb334b38f8d3 100644 --- a/tensorflow/core/common_runtime/node_file_writer.cc +++ b/tensorflow/core/common_runtime/node_file_writer.cc @@ -127,7 +127,7 @@ tensorflow::NodeFileWriter::GetNodeFileWriterIfEnabled( Status NodeFileWriter::RecordNodeExecution(OpKernel* op_kernel, OpKernelContext* context) { if (kOpsToSkipWriting->count(op_kernel->type_string())) { - return OkStatus(); + return absl::OkStatus(); } NodeDef def; def.set_name("NodeFileWriter"); @@ -140,7 +140,7 @@ Status NodeFileWriter::RecordNodeExecution(OpKernel* op_kernel, if (!context->has_input(i) || context->input_is_ref(i)) { // Calling context->input(i) requires the input to exist and not be a ref, // so return immediately if that is not the case. - return OkStatus(); + return absl::OkStatus(); } TensorShapeProto* shape_proto = input_shapes.mutable_list()->add_shape(); const Tensor& input = context->input(i); @@ -155,7 +155,7 @@ Status NodeFileWriter::RecordNodeExecution(OpKernel* op_kernel, } else if (!DataTypeIsFloating(input.dtype())) { // Skip ops with non-floating-point inputs, since these are not useful // when testing determinism. - return OkStatus(); + return absl::OkStatus(); } } return MaybeWriteNodeDefToFile(def); @@ -185,7 +185,7 @@ Status NodeFileWriter::MaybeWriteNodeDefToFile(const NodeDef& def) { // file is never closed. TF_RETURN_IF_ERROR(node_def_file_->Flush()); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/optimization_registry.cc b/tensorflow/core/common_runtime/optimization_registry.cc index b5f25ff9ea22cd..f2999ff521a17a 100644 --- a/tensorflow/core/common_runtime/optimization_registry.cc +++ b/tensorflow/core/common_runtime/optimization_registry.cc @@ -103,7 +103,7 @@ Status OptimizationPassRegistry::RunGrouping( VLOG_IS_ON(3) || (VLOG_IS_ON(2) && grouping == Grouping::POST_REWRITE_FOR_EXEC)); - return OkStatus(); + return absl::OkStatus(); } void OptimizationPassRegistry::LogGrouping(Grouping grouping, int vlog_level) { diff --git a/tensorflow/core/common_runtime/optimization_registry_test.cc b/tensorflow/core/common_runtime/optimization_registry_test.cc index 117105e2ace61b..1b9943d49a2a60 100644 --- a/tensorflow/core/common_runtime/optimization_registry_test.cc +++ b/tensorflow/core/common_runtime/optimization_registry_test.cc @@ -25,7 +25,7 @@ class TestOptimization : public GraphOptimizationPass { static int count_; Status Run(const GraphOptimizationPassOptions& options) override { ++count_; - return OkStatus(); + return absl::OkStatus(); } }; @@ -43,8 +43,9 @@ TEST(OptimizationRegistry, OptimizationPass) { new FunctionLibraryDefinition(OpRegistry::Global(), FunctionDefLibrary())); options.flib_def = flib_def.get(); - EXPECT_EQ(OkStatus(), OptimizationPassRegistry::Global()->RunGrouping( - OptimizationPassRegistry::PRE_PLACEMENT, options)); + EXPECT_EQ(absl::OkStatus(), + OptimizationPassRegistry::Global()->RunGrouping( + OptimizationPassRegistry::PRE_PLACEMENT, options)); EXPECT_EQ(1, TestOptimization::count_); } @@ -72,7 +73,7 @@ class OptimizationPassTest : public ::testing::Test { options.flib_def = flib_def_.get(); // Note that options.graph is not set so this test checks that passes // properly handle this being nullptr (esp. Segfault is avoided). - EXPECT_EQ(OkStatus(), + EXPECT_EQ(absl::OkStatus(), OptimizationPassRegistry::Global()->RunGrouping( OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, options)); } diff --git a/tensorflow/core/common_runtime/optimize_cross_host_control_deps.cc b/tensorflow/core/common_runtime/optimize_cross_host_control_deps.cc index cdd4a4dd017f57..6deb020f816b31 100644 --- a/tensorflow/core/common_runtime/optimize_cross_host_control_deps.cc +++ b/tensorflow/core/common_runtime/optimize_cross_host_control_deps.cc @@ -42,7 +42,7 @@ Status BuildNoopNode(const Node& source, StringPiece name, const string& device, if (!device.empty()) { (*node)->set_assigned_device_name(device); } - return OkStatus(); + return absl::OkStatus(); } Status BuildIdentityNNode(const Node& source, StringPiece name, @@ -62,7 +62,7 @@ Status BuildIdentityNNode(const Node& source, StringPiece name, if (!device.empty()) { (*node)->set_assigned_device_name(device); } - return OkStatus(); + return absl::OkStatus(); } const string& RequestedOrAssignedDevice(const Node* n) { @@ -176,7 +176,7 @@ Status OptimizeCrossHostControlOutputEdges(Graph* graph, } } } - return OkStatus(); + return absl::OkStatus(); } Status OptimizeCrossHostDataOutputEdges(Graph* graph, @@ -249,7 +249,7 @@ Status OptimizeCrossHostDataOutputEdges(Graph* graph, } } } - return OkStatus(); + return absl::OkStatus(); } Status OptimizeCrossHostControlInputEdges(Graph* graph, @@ -322,7 +322,7 @@ Status OptimizeCrossHostControlInputEdges(Graph* graph, } } } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/optimize_function_graph_utils.cc b/tensorflow/core/common_runtime/optimize_function_graph_utils.cc index c2f330ce6db6fb..e041fbbf360a6d 100644 --- a/tensorflow/core/common_runtime/optimize_function_graph_utils.cc +++ b/tensorflow/core/common_runtime/optimize_function_graph_utils.cc @@ -72,7 +72,7 @@ Status ValidateNoListArguments( " and outputs"); } } - return OkStatus(); + return absl::OkStatus(); } Status ValidateMultiDeviceOptions( @@ -107,7 +107,7 @@ Status ValidateMultiDeviceOptions( options.output_devices.size(), " number of arguments = ", signature.output_arg_size()); } - return OkStatus(); + return absl::OkStatus(); } Status SetArgShape(const std::unordered_map& @@ -133,7 +133,7 @@ Status SetArgShape(const std::unordered_map& } } } - return OkStatus(); + return absl::OkStatus(); } const string* AssignedOrRequestedDeviceName(const Node& node) { @@ -202,7 +202,7 @@ Status WriteToCache(const std::string& dir_name, const std::string& file_name, << absl::ToInt64Milliseconds(cache_writing_duration) << " msecs, file name: " << file_name; - return OkStatus(); + return absl::OkStatus(); } // Retrieves the OptimizedFunctionGraphInfo from a cache file. @@ -294,7 +294,7 @@ Status GetGraphAndArgRets(const string& function_name, AttrSlice attrs, for (const Node* node : fbody->control_ret_nodes) { control_ret_node_names->push_back(node->name()); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -460,7 +460,7 @@ Status PinArgsAndRets(const std::vector& input_devices, node->set_assigned_device_name(output_devices[index]); } } - return OkStatus(); + return absl::OkStatus(); } StatusOr OptimizeFunctionGraph( diff --git a/tensorflow/core/common_runtime/parallel_concat_optimizer.cc b/tensorflow/core/common_runtime/parallel_concat_optimizer.cc index ed02bc162175f5..bb6cc17b5b1665 100644 --- a/tensorflow/core/common_runtime/parallel_concat_optimizer.cc +++ b/tensorflow/core/common_runtime/parallel_concat_optimizer.cc @@ -30,7 +30,7 @@ class ParallelConcatRemovePass : public GraphOptimizationPass { if (options.graph == nullptr) { // TODO(apassos) returning OK feels weird here as we can't do anything // without a graph, but some tests require this. - return OkStatus(); + return absl::OkStatus(); } Graph* g = options.graph->get(); if (g == nullptr) { @@ -110,7 +110,7 @@ class ParallelConcatRemovePass : public GraphOptimizationPass { } g->RemoveNode(n); } - return OkStatus(); + return absl::OkStatus(); } }; REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 10, diff --git a/tensorflow/core/common_runtime/partitioning_utils.cc b/tensorflow/core/common_runtime/partitioning_utils.cc index 9756f5584ae02a..b7f84da5441c7a 100644 --- a/tensorflow/core/common_runtime/partitioning_utils.cc +++ b/tensorflow/core/common_runtime/partitioning_utils.cc @@ -108,7 +108,7 @@ Status MakeSendRecvDependencyExplicit(Graph* graph) { } graph->AddControlEdge(send_recv_pair.send_node, send_recv_pair.recv_node); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -136,7 +136,7 @@ Status PartitionFunctionGraph( subgraphs->emplace(device, std::move(subgraph)); } - return OkStatus(); + return absl::OkStatus(); } StatusOr> InsertTransferOps( @@ -259,7 +259,7 @@ Status UpdateArgAndRetvalMetadata( ret_nodes, ints_on_device, *ret_alloc_attrs)); } - return OkStatus(); + return absl::OkStatus(); } string FunctionNameGenerator::GetName() { diff --git a/tensorflow/core/common_runtime/permuter.h b/tensorflow/core/common_runtime/permuter.h index 654ab0b27291d8..a48195d8b10229 100644 --- a/tensorflow/core/common_runtime/permuter.h +++ b/tensorflow/core/common_runtime/permuter.h @@ -50,7 +50,7 @@ class Permuter : public CollectiveImplementationInterface { void Run(StatusCallback done) override; Status InitializeCollectiveParams(CollectiveParams* col_params) override { - return OkStatus(); + return absl::OkStatus(); } // Initializes members of CollectiveContext not yet initialized, i.e. device diff --git a/tensorflow/core/common_runtime/placer.cc b/tensorflow/core/common_runtime/placer.cc index 5f89a5937c8484..05dc029cc74756 100644 --- a/tensorflow/core/common_runtime/placer.cc +++ b/tensorflow/core/common_runtime/placer.cc @@ -86,7 +86,7 @@ Status GetFileName(string base_name, string* fname) { base_name = MakeUniqueFilename(base_name); *fname = absl::StrCat(result, "/", base_name); - return OkStatus(); + return absl::OkStatus(); } void DumpColocationGraph(const string& base_name, @@ -147,7 +147,7 @@ Status AssignAndLog(int assigned_device, Node* node, TF_RETURN_IF_ERROR(colocation_graph->LimitToAssignedDevice(*node)); LogDeviceAssignment(node, log_device_placement); - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -327,7 +327,7 @@ Status Placer::Run(const GraphOptimizationPassOptions& options) { strings::StrCat(options.debug_filename_prefix, "colocation_graph"), colocation_graph); } - return OkStatus(); + return absl::OkStatus(); } bool Placer::CanAssignToDevice(const string& candidate_device_name, diff --git a/tensorflow/core/common_runtime/placer_inspection_required_ops_utils.cc b/tensorflow/core/common_runtime/placer_inspection_required_ops_utils.cc index c6c90d4e772769..52591ec1882c00 100644 --- a/tensorflow/core/common_runtime/placer_inspection_required_ops_utils.cc +++ b/tensorflow/core/common_runtime/placer_inspection_required_ops_utils.cc @@ -44,7 +44,7 @@ Status Set(const Node& node, bool value, bool* is_deep, std::vector>* cache) { *is_deep = value; (*cache)[node.id()] = value; - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -59,7 +59,7 @@ Status PlacerInspectionRequiredOpChecker::IsPlacerInspectionRequired( const Node& node, bool* is_deep) { if (cache_[node.id()].has_value()) { *is_deep = cache_[node.id()].value(); - return OkStatus(); + return absl::OkStatus(); } if (!IsFunctionCall(node)) { @@ -91,7 +91,7 @@ Status GetFunctionDefAndAttrs(const FunctionLibraryDefinition& flib_def, "Failed to find function \"", function_name, "\" in function library: ", flib_def.ToProto().DebugString()); } - return OkStatus(); + return absl::OkStatus(); } FunctionStack::FunctionStack(const string& function_name) @@ -194,7 +194,7 @@ Status AddInputIdentity(Node* node, int input_idx, Graph* graph, VLOG(6) << "Successfully inserted identity. Modified node: \n" << node->DebugString(); - return OkStatus(); + return absl::OkStatus(); } struct EdgePtrCompare { @@ -218,7 +218,7 @@ Status AddOutputIdentities(Node* node, Graph* graph, TF_ASSIGN_OR_RETURN(*identity_node, graph->AddNode(identity_def)); graph->AddEdge(node, src_output, *identity_node, 0); - return OkStatus(); + return absl::OkStatus(); }; // output_used[i] == true iff `node`'s i'th output is used @@ -264,7 +264,7 @@ Status AddOutputIdentities(Node* node, Graph* graph, << " -> : \n" << identity_node->DebugString(); } - return OkStatus(); + return absl::OkStatus(); } Status IsolateNode(Node* node, Graph* graph) { @@ -281,7 +281,7 @@ Status IsolateNode(Node* node, Graph* graph) { TF_RETURN_IF_ERROR(AddInputIdentity(node, i, graph, &node_names)); } TF_RETURN_IF_ERROR(AddOutputIdentities(node, graph, &node_names)); - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -305,7 +305,7 @@ Status IsolatePlacerInspectionRequiredOps( TF_RETURN_IF_ERROR(IsolateNode(node, graph)); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc index 0a2d1a278b2463..784436579e0913 100644 --- a/tensorflow/core/common_runtime/placer_test.cc +++ b/tensorflow/core/common_runtime/placer_test.cc @@ -112,11 +112,11 @@ class FakeDevice : public Device { class DummyFactory : public DeviceFactory { public: Status ListPhysicalDevices(std::vector* devices) override { - return OkStatus(); + return absl::OkStatus(); } Status CreateDevices(const SessionOptions& options, const string& name_prefix, std::vector>* devices) override { - return OkStatus(); + return absl::OkStatus(); } }; @@ -255,14 +255,14 @@ class PlacerTest : public ::testing::Test { Status BuildGraph(const GraphDefBuilder& builder, Graph* out_graph) { TF_RETURN_IF_ERROR(GraphDefBuilderToGraph(builder, out_graph)); RebuildNodeNameMap(*out_graph); - return OkStatus(); + return absl::OkStatus(); } Status BuildGraph(const GraphDef& graph_def, Graph* out_graph) { GraphConstructorOptions opts; TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graph_def, out_graph)); RebuildNodeNameMap(*out_graph); - return OkStatus(); + return absl::OkStatus(); } // Invokes the Placer on "graph". If no DeviceSet is specified, the @@ -960,7 +960,7 @@ Status PlacerTest::ReferenceTestHelper(const string& variable_op_type, EXPECT_DEVICE_TYPE(g, strings::StrCat("assign_", i), expected_device_type); } - return OkStatus(); + return absl::OkStatus(); } // Test all 2^3 combinations of Variable and Assignment op types @@ -1031,7 +1031,7 @@ TEST_F(PlacerTest, TestResourceHandle) { EXPECT_COLOCATED(g, "var", "assign"); EXPECT_DEVICE_TYPE(g, "var", device); EXPECT_DEVICE_TYPE(g, "assign", device); - return OkStatus(); + return absl::OkStatus(); }; TF_EXPECT_OK( handle_test("TestHandleVariable", "TestHandleAssign", "FakeGPU")); @@ -1097,7 +1097,7 @@ TEST_F(PlacerTest, TestResourceHandlesOnDifferentDevicesFails) { << s.ToString(); } - return OkStatus(); + return absl::OkStatus(); }; TF_EXPECT_OK(handle_test(false, false)); @@ -1202,7 +1202,7 @@ TEST_F(PlacerTest, TestResourceHandleOnCompositeDevice) { // `var` is assigned to COMPOSITE. GetNodeByName(*g, "var")->set_assigned_device_name( "/job:a/replica:0/task:0/device:COMPOSITE:0"); - return OkStatus(); + return absl::OkStatus(); }; { diff --git a/tensorflow/core/common_runtime/pluggable_device/BUILD b/tensorflow/core/common_runtime/pluggable_device/BUILD index fcd7f9150e750b..e77b75da78c342 100644 --- a/tensorflow/core/common_runtime/pluggable_device/BUILD +++ b/tensorflow/core/common_runtime/pluggable_device/BUILD @@ -203,6 +203,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/platform:stream_executor", + "@local_xla//xla/stream_executor:platform_manager", ], alwayslink = 1, ) diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device.cc index 69749d99792502..3b08d428ff985b 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device.cc +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device.cc @@ -241,7 +241,7 @@ Status PluggableDevice::Init(const SessionOptions& options) { } } - return OkStatus(); + return absl::OkStatus(); } Allocator* PluggableDevice::GetAllocator(AllocatorAttributes attr) { @@ -326,8 +326,8 @@ Status PluggableDevice::MaybeCopyTensorToPluggableDevice( StatusCallback done) { if (alloc_attrs.on_host()) { *to = from; - done(OkStatus()); - return OkStatus(); + done(absl::OkStatus()); + return absl::OkStatus(); } else { if (!DMAHelper::CanUseDMA(&from)) { Status err = errors::Internal("PluggableDevice copy from non-DMA ", @@ -359,7 +359,7 @@ Status PluggableDevice::MaybeCopyTensorToPluggableDevice( device_context_->CopyCPUTensorToDevice( &from, this, copy, std::move(wrapped_done), false /*sync_dst_compute*/); - return OkStatus(); + return absl::OkStatus(); } } diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.cc index 5c74e602603469..ec3faf2d6329ca 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.cc +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.cc @@ -55,7 +55,7 @@ Status PluggableDeviceContext::ThenExecute(Device* device, se::Stream* stream, const DeviceBase::AcceleratorDeviceInfo* device_info = device->tensorflow_accelerator_device_info(); device_info->event_mgr->ThenExecute(stream, func); - return OkStatus(); + return absl::OkStatus(); } bool PluggableDeviceContext::IsPluggableDevice() { return true; } diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc index f6705b0f15228d..e9b6bee90f733f 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc @@ -115,7 +115,7 @@ Status SingleVirtualDeviceMemoryLimit(const string& platform_name, allocated_memory = total_memory * per_process_device_memory_fraction; } *memory_limit = allocated_memory; - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -135,7 +135,7 @@ Status PluggableDeviceFactory::ListPhysicalDevices( devices->push_back(device_name); } - return OkStatus(); + return absl::OkStatus(); } Status PluggableDeviceFactory::GetDeviceDetails( @@ -143,7 +143,7 @@ Status PluggableDeviceFactory::GetDeviceDetails( TF_RETURN_IF_ERROR(ValidatePluggableDeviceMachineManager(platform_name_)); se::Platform* platform = PluggableDeviceMachineManager(platform_name_); if (platform == nullptr) { - return OkStatus(); + return absl::OkStatus(); } int device_count = platform->VisibleDeviceCount(); @@ -159,7 +159,7 @@ Status PluggableDeviceFactory::GetDeviceDetails( auto desc = std::move(desc_status).value(); (*details)["device_name"] = desc->name(); - return OkStatus(); + return absl::OkStatus(); } Status PluggableDeviceFactory::CreateDevices( @@ -168,11 +168,11 @@ Status PluggableDeviceFactory::CreateDevices( TF_RETURN_IF_ERROR(ValidatePluggableDeviceMachineManager(platform_name_)); se::Platform* platform = PluggableDeviceMachineManager(platform_name_); if (platform == nullptr) { - return OkStatus(); + return absl::OkStatus(); } const int visible_device_count = platform->VisibleDeviceCount(); if (visible_device_count <= 0) { - return OkStatus(); + return absl::OkStatus(); } const absl::flat_hash_map device_count_map( options.config.device_count().begin(), @@ -211,7 +211,7 @@ Status PluggableDeviceFactory::CreateDevices( bytes, device_localities[di], devices)); } - return OkStatus(); + return absl::OkStatus(); } static string GetShortDeviceDescription(PlatformDeviceId platform_device_id, @@ -274,7 +274,7 @@ Status PluggableDeviceFactory::CreatePluggableDevice( << GetShortDeviceDescription(platform_device_id, *desc) << ")"; TF_RETURN_IF_ERROR(pluggable_device->Init(options)); devices->push_back(std::move(pluggable_device)); - return OkStatus(); + return absl::OkStatus(); } Status PluggableDeviceFactory::GetDeviceLocalities( @@ -317,7 +317,7 @@ Status PluggableDeviceFactory::GetDeviceLocalities( << dev_locality.bus_id() << " numa: " << numa_node << "DeviceLocality: " << dev_locality.DebugString(); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.cc index f44daa362cf1b4..0b7279d0098ac1 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.cc +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "xla/stream_executor/platform_manager.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/numbers.h" @@ -30,11 +31,11 @@ limitations under the License. namespace tensorflow { Status ValidatePluggableDeviceMachineManager(const string& platform_name) { - return se::MultiPlatformManager::PlatformWithName(platform_name).status(); + return se::PlatformManager::PlatformWithName(platform_name).status(); } se::Platform* PluggableDeviceMachineManager(const string& platform_name) { - auto result = se::MultiPlatformManager::PlatformWithName(platform_name); + auto result = se::PlatformManager::PlatformWithName(platform_name); if (!result.ok()) { LOG(FATAL) << "Could not find platform with name " // Crash OK << platform_name; diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.cc index 02e33fe46f6adc..a3879acd5daa08 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.cc +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.cc @@ -47,8 +47,8 @@ static Status InitDeviceModule(void* dso_handle) { if (absl::IsNotFound(status)) { VLOG(1) << "Device module not found."; - return OkStatus(); - } else if (status != OkStatus()) { + return absl::OkStatus(); + } else if (status != absl::OkStatus()) { return status; } auto init_fn = reinterpret_cast(dso_symbol); @@ -68,7 +68,7 @@ static Status InitDeviceModule(void* dso_handle) { /*is_pluggable_device=*/true)); // Register the Copy tensor. VLOG(1) << "Successfully initialized Device module."; - return OkStatus(); + return absl::OkStatus(); } typedef const PJRT_Api* (*PjrtApiInitFn)(); @@ -81,8 +81,8 @@ static Status InitNextPluggableDeviceModule(void* dso_handle) { env->GetSymbolFromLibrary(dso_handle, "TFNPD_InitPlugin", &dso_symbol); if (absl::IsNotFound(status)) { VLOG(1) << "Next pluggable device module not found."; - return OkStatus(); - } else if (status != OkStatus()) { + return absl::OkStatus(); + } else if (status != absl::OkStatus()) { return status; } auto init_fn = reinterpret_cast(dso_symbol); @@ -98,7 +98,7 @@ static Status InitNextPluggableDeviceModule(void* dso_handle) { if (absl::IsNotFound(status)) { VLOG(1) << "Loading PJRT plugin failed for " << device_type << ": " << status.message(); - return OkStatus(); + return absl::OkStatus(); } else if (!status.ok()) { return status; } @@ -140,7 +140,7 @@ static Status InitNextPluggableDeviceModule(void* dso_handle) { /*is_pluggable_device=*/true)); // Register the Copy tensor. VLOG(1) << "Successfully initialized NextPluggableDevice module."; - return OkStatus(); + return absl::OkStatus(); } static Status InitGraphModule(void* dso_handle) { @@ -151,15 +151,15 @@ static Status InitGraphModule(void* dso_handle) { if (absl::IsNotFound(status)) { VLOG(1) << "Graph module not found."; - return OkStatus(); - } else if (status != OkStatus()) { + return absl::OkStatus(); + } else if (status != absl::OkStatus()) { return status; } auto init_fn = reinterpret_cast(dso_symbol); TF_RETURN_IF_ERROR(grappler::InitGraphPlugin(init_fn)); VLOG(1) << "Successfully initialized Graph module."; - return OkStatus(); + return absl::OkStatus(); } typedef void (*TFKernelInitFn)(); @@ -171,8 +171,8 @@ static Status InitKernelModule(void* dso_handle) { if (absl::IsNotFound(status)) { VLOG(1) << "Kernel module not found."; - return OkStatus(); - } else if (status != OkStatus()) { + return absl::OkStatus(); + } else if (status != absl::OkStatus()) { return status; } @@ -180,7 +180,7 @@ static Status InitKernelModule(void* dso_handle) { init_fn(); VLOG(1) << "Successfully initialized Kernel module."; - return OkStatus(); + return absl::OkStatus(); } static Status InitProfilerModule(void* dso_handle) { @@ -192,8 +192,8 @@ static Status InitProfilerModule(void* dso_handle) { if (absl::IsNotFound(status)) { VLOG(1) << "Profiler module not found."; - return OkStatus(); - } else if (status != OkStatus()) { + return absl::OkStatus(); + } else if (status != absl::OkStatus()) { return status; } @@ -201,7 +201,7 @@ static Status InitProfilerModule(void* dso_handle) { TF_RETURN_IF_ERROR(profiler::InitPluginProfiler(init_fn)); VLOG(1) << "Successfully initialized Profiler module"; - return OkStatus(); + return absl::OkStatus(); } Status RegisterPluggableDevicePlugin(void* dso_handle) { @@ -220,7 +220,7 @@ Status RegisterPluggableDevicePlugin(void* dso_handle) { // Step 4 Init Profiler Module. TF_RETURN_IF_ERROR(InitProfilerModule(dso_handle)); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.cc index 1289675cab1d55..b03132e68d3d4e 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.cc +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.cc @@ -96,7 +96,7 @@ static Status PrepareCopy(Device* device, const DeviceContext* ctx, return errors::Internal("PluggableDevice copy from non-DMA", DataTypeString(src.dtype()), " tensor."); } - return OkStatus(); + return absl::OkStatus(); } static void* GetBase(const Tensor* src) { @@ -163,7 +163,7 @@ void PluggableDeviceUtil::DeviceToDeviceCopy( LOG(FATAL) << "PluggableDevice->PluggableDevice Memcpy " // Crash OK << "failed."; } - done(OkStatus()); + done(absl::OkStatus()); }); send_dev_context->MaintainLifetimeOnStream(input, send_device_to_device_stream); @@ -212,7 +212,7 @@ void PluggableDeviceUtil::CopyPluggableDeviceTensorToCPU( LOG(FATAL) << "PluggableDevice->CPU Memcpy failed."; // Crash OK } input_ref.Unref(); - done(OkStatus()); + done(absl::OkStatus()); }); } @@ -261,7 +261,7 @@ void PluggableDeviceUtil::CopyCPUTensorToPluggableDevice( if (!recv_host_to_device_stream->ok()) { LOG(FATAL) << "CPU->PluggableDevice Memcpy failed."; // Crash OK } - done(OkStatus()); + done(absl::OkStatus()); }); } @@ -284,7 +284,7 @@ Status PluggableDeviceUtil::SyncAll(Device* device) { !dev_info->stream->ok()) { return errors::Internal("PluggableDevice SyncAll failed."); } - return OkStatus(); + return absl::OkStatus(); } // static @@ -311,7 +311,7 @@ void PluggableDeviceUtil::CopyPluggableDeviceTensorToSameDevice( send_stream->ThenMemcpy(&device_dst_ptr, device_src_ptr, total_bytes); } - done(OkStatus()); + done(absl::OkStatus()); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 801addf64a7e84..ba15b2c3d82368 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -144,7 +144,7 @@ Status ProcessFunctionLibraryRuntime::SendTensors( } TF_RETURN_IF_ERROR(SendTensorsToRendezvous( rendezvous, device_context, alloc_attrs, keys, tensors_to_send)); - return OkStatus(); + return absl::OkStatus(); } /* static */ @@ -174,7 +174,7 @@ Status ProcessFunctionLibraryRuntime::GetRetTypes( auto miter = mdevice_data_.find(h); if (miter != mdevice_data_.end()) { *ret_types = miter->second->ret_types_; - return OkStatus(); + return absl::OkStatus(); } auto fiter = function_data_.find(h); if (fiter != function_data_.end()) { @@ -194,7 +194,7 @@ Status ProcessFunctionLibraryRuntime::GetDeviceIncarnation( return errors::InvalidArgument("Device name: ", device_name, " not found."); } *incarnation = flr->device()->attributes().incarnation(); - return OkStatus(); + return absl::OkStatus(); } Status ProcessFunctionLibraryRuntime::GetDeviceContext( @@ -208,14 +208,14 @@ Status ProcessFunctionLibraryRuntime::GetDeviceContext( string device_type = device->parsed_name().type; if (device_type == "CPU" || device_type == "TPU_SYSTEM") { // "TPU_SYSTEM" indicates that `device` is a CPU. - return OkStatus(); + return absl::OkStatus(); } if (device->IsRemoteCallAllowed()) { auto* dev_info = flr->device()->tensorflow_accelerator_device_info(); if (dev_info) { *device_context = dev_info->default_context; - return OkStatus(); + return absl::OkStatus(); } } @@ -236,7 +236,7 @@ void ProcessFunctionLibraryRuntime::InitializeDeviceAndFlr() { if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) { for (auto d : parent_->remote_device_mgr()->ListDevices()) { Device* device = nullptr; - if (device_mgr_->LookupDevice(d->name(), &device) == OkStatus()) { + if (device_mgr_->LookupDevice(d->name(), &device) == absl::OkStatus()) { // If this device exists in device_mgr, i.e., a local device, // add this device from the instance included in device_mgr_ device_set_->AddDevice(device); @@ -417,7 +417,7 @@ Status FunctionRetsToTensors(const std::vector* function_rets, // NOLINTNEXTLINE tensors->push_back(absl::get(ret)); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -488,7 +488,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( if (it != table_.end()) { *handle = it->second; ++mdevice_data_[*handle]->instantiation_counter_; - return OkStatus(); + return absl::OkStatus(); } } @@ -765,7 +765,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( if (should_publish_function_graphs) { PublishSubgraphs(function_name, std::move(function_records)); } - return OkStatus(); + return absl::OkStatus(); } Status ProcessFunctionLibraryRuntime::GetOutputDevices( @@ -812,7 +812,7 @@ Status ProcessFunctionLibraryRuntime::GetOutputDevices( } } - return OkStatus(); + return absl::OkStatus(); } Status ProcessFunctionLibraryRuntime::PrepareRunMultiDevice( @@ -844,7 +844,7 @@ Status ProcessFunctionLibraryRuntime::PrepareRunMultiDevice( " without an appropriate cross process Rendezvous."); } - return OkStatus(); + return absl::OkStatus(); } std::vector ProcessFunctionLibraryRuntime::GetOrderedSubgraphs( @@ -961,7 +961,7 @@ Status ProcessFunctionLibraryRuntime::RunMultiDeviceSync( return s; } } - return OkStatus(); + return absl::OkStatus(); } void ProcessFunctionLibraryRuntime::RunMultiDeviceAsync( @@ -1105,12 +1105,12 @@ Status ProcessFunctionLibraryRuntime::IsCrossProcess( const auto& mdevice_it = mdevice_data_.find(handle); if (mdevice_it != mdevice_data_.end()) { *is_cross_process = mdevice_it->second->is_cross_process_; - return OkStatus(); + return absl::OkStatus(); } const auto& it = function_data_.find(handle); if (it != function_data_.end()) { *is_cross_process = it->second->is_cross_process(); - return OkStatus(); + return absl::OkStatus(); } return errors::InvalidArgument("Handle ", handle, " not found."); } @@ -1156,7 +1156,7 @@ Status ProcessFunctionLibraryRuntime::RemoveHandle( mutex_lock l(mu_); table_.erase(function_data_[handle]->function_key()); function_data_.erase(handle); - return OkStatus(); + return absl::OkStatus(); } Status ProcessFunctionLibraryRuntime::ReleaseMultiDeviceHandle( @@ -1167,7 +1167,7 @@ Status ProcessFunctionLibraryRuntime::ReleaseMultiDeviceHandle( auto it = mdevice_data_.find(handle); --it->second->instantiation_counter_; if (it->second->instantiation_counter_ != 0) { - return OkStatus(); + return absl::OkStatus(); } mdata = std::move(it->second); table_.erase(mdata->function_key_); @@ -1205,7 +1205,7 @@ Status ProcessFunctionLibraryRuntime::ReleaseMultiDeviceHandle( Status ProcessFunctionLibraryRuntime::ReleaseHandle( FunctionLibraryRuntime::Handle handle) { // Return directly if all function handles has already been released. - if (flr_map_ == nullptr) return OkStatus(); + if (flr_map_ == nullptr) return absl::OkStatus(); if (IsMultiDevice(handle)) { return ReleaseMultiDeviceHandle(handle); @@ -1293,7 +1293,7 @@ Status ProcessFunctionLibraryRuntime::GetComponentArgs( comp_args->args.push_back(args[it.index]); } } - return OkStatus(); + return absl::OkStatus(); } #if !defined(IS_MOBILE_PLATFORM) @@ -1315,7 +1315,7 @@ Status ProcessFunctionLibraryRuntime::GetComponentArgs( comp_args->args.push_back(comp_args->remote_args.back().get()); } } - return OkStatus(); + return absl::OkStatus(); } #endif // IS_MOBILE_PLATFORM @@ -1500,7 +1500,7 @@ void ProcessFunctionLibraryRuntime::Run( return; } } - done(OkStatus()); + done(absl::OkStatus()); }); } @@ -1650,7 +1650,7 @@ Status ProcessFunctionLibraryRuntime::Clone( tf_shared_lock l(mu_); for (auto* d : composite_devices_) (*out_pflr)->AddCompositeDevice(d); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc index 01e243f6300f13..c95f118a1a1589 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc @@ -67,7 +67,7 @@ class TestClusterFLR : public DistributedFunctionLibraryRuntime { *handle = next_handle_; next_handle_++; } - done(OkStatus()); + done(absl::OkStatus()); } void Run(const FunctionLibraryRuntime::Options& opts, @@ -158,7 +158,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { return tsl::core::RefCountPtr( new IntraProcessRendezvous(device_mgr)); }); - return OkStatus(); + return absl::OkStatus(); }})); } @@ -261,7 +261,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { EXPECT_TRUE(errors::IsNotFound(status)) << "Actual status: " << status; EXPECT_TRUE(absl::StrContains(status.message(), "not found.")); - return OkStatus(); + return absl::OkStatus(); } Status Run(const string& name, FunctionLibraryRuntime::Options opts, @@ -308,7 +308,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { for (size_t i = 0; i < rets.size(); ++i) { *rets[i] = out[i]; } - return OkStatus(); + return absl::OkStatus(); } std::unique_ptr device_mgr_; @@ -935,7 +935,7 @@ class TestFunctionPackedArgs : public FunctionArgsInterface { Status GetLocalArg(const FunctionArgIndex& index, Tensor* val) const override { *val = *packed_args_.at(index.index).at(index.sub_index).tensor; - return OkStatus(); + return absl::OkStatus(); }; std::vector GetLocalTensors() const override { return {}; } diff --git a/tensorflow/core/common_runtime/quantize_training.cc b/tensorflow/core/common_runtime/quantize_training.cc index 0269739828ca2d..8f225405cf41d3 100644 --- a/tensorflow/core/common_runtime/quantize_training.cc +++ b/tensorflow/core/common_runtime/quantize_training.cc @@ -148,7 +148,7 @@ Status FindSaveOp(const Graph* graph, Node** save_op, TF_RETURN_IF_ERROR(node->input_edges(in_edges)); } } - return OkStatus(); + return absl::OkStatus(); } Node* FindRestoreAllOp(const Graph* graph, StringPiece save_prefix) { @@ -237,7 +237,7 @@ Status ConnectVariablesToSaveOp(Graph* graph, Node* save_op, } graph->RemoveNode(save_op); - return OkStatus(); + return absl::OkStatus(); } // Add a restore subgraph for each variable and connect to the restore_all op. @@ -307,7 +307,7 @@ Status AddRestoreVariableSubgraphs(Graph* graph, Node* save_op, // Add a control edge from the assign op to restore_all op. graph->AddControlEdge(assign_op, restore_all); } - return OkStatus(); + return absl::OkStatus(); } // Adds new variables to save and restore ops matching the Save and Restore @@ -323,7 +323,7 @@ Status AddSaveAndRestore(Graph* graph, const std::vector& variables) { TF_RETURN_IF_ERROR( ConnectVariablesToSaveOp(graph, save_op, in_edges, variables)); } - return OkStatus(); + return absl::OkStatus(); } // Sets output to the Node that computes reduction axes corresponding to all @@ -358,7 +358,7 @@ Status MakeReductionAxes(Graph* graph, string name_prefix, Node* input, .Input(rank) .Input(delta) .Finalize(graph, output)); - return OkStatus(); + return absl::OkStatus(); } // Computes the exponential moving average of input, updated in update_variable. @@ -401,7 +401,7 @@ Status MakeExponentialMovingAverage(Graph* graph, string name_prefix, .Input(update_variable) .Input(update_value) .Finalize(graph, assign_value)); - return OkStatus(); + return absl::OkStatus(); } // Creates an automatically initialized exponential moving average variable. @@ -454,7 +454,7 @@ Status MakeInitializedEMAVariable(Graph* graph, const string& name, Node* decay, .Input(*var) .Input(assign_value) .Finalize(graph, var)); - return OkStatus(); + return absl::OkStatus(); } // Computes the min and max EMA of input and stores them in min_var and max_var. @@ -492,7 +492,7 @@ Status MakeEMAMinMaxVars(Graph* graph, const string& name_prefix, Node* input, added_variables, min_var)); TF_RETURN_IF_ERROR(MakeInitializedEMAVariable(graph, max_name, decay, max, added_variables, max_var)); - return OkStatus(); + return absl::OkStatus(); } // Makes an input min and max constant if the range is given. Otherwise, makes @@ -525,7 +525,7 @@ Status MakeInputMinMax(Graph* graph, const string& name_prefix, input_max)); } - return OkStatus(); + return absl::OkStatus(); } // Adds a QuantizeAndDequantizeV2 or FakeQuantizeWithMinMaxVars op @@ -559,7 +559,7 @@ Status MakeQuantizeOp(Graph* graph, const string& name_prefix, } else { return errors::InvalidArgument("Unknown quant op type: ", quant_op_type); } - return OkStatus(); + return absl::OkStatus(); } // Insert conversion op, connect it to the graph and remove the old edge. @@ -588,7 +588,7 @@ Status ProcessTargetEdges(Graph* graph, const string& quant_op_type, TF_RETURN_IF_ERROR(AddSaveAndRestore(graph, added_variables)); - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -655,7 +655,7 @@ Status DoQuantizeTraining(int32_t num_bits, const string& quant_op_type, TF_RETURN_IF_ERROR(ProcessTargetEdges(graph, quant_op_type, target_edges)); - return OkStatus(); + return absl::OkStatus(); } Status DoQuantizeTrainingOnGraphDef(const GraphDef& input_graphdef, @@ -671,7 +671,7 @@ Status DoQuantizeTrainingOnGraphDef(const GraphDef& input_graphdef, // Convert the result graph back to a GraphDef. graph.ToGraphDef(result_graphdef); - return OkStatus(); + return absl::OkStatus(); } Status DoQuantizeTrainingOnSerializedGraphDef(const string& input_graph_string, @@ -692,7 +692,7 @@ Status DoQuantizeTrainingOnSerializedGraphDef(const string& input_graph_string, return errors::Internal( "quantize training transformation resulted in invalid GraphDef"); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/quantize_training_test.cc b/tensorflow/core/common_runtime/quantize_training_test.cc index db228c8e08be85..4031e7405049be 100644 --- a/tensorflow/core/common_runtime/quantize_training_test.cc +++ b/tensorflow/core/common_runtime/quantize_training_test.cc @@ -57,14 +57,14 @@ class QuantizeTrainingTest : public ::testing::Test { .Attr("dtype", DT_FLOAT) .Attr("shape", shape) .Finalize(g, out)); - return OkStatus(); + return absl::OkStatus(); } Status FindNode(Graph* g, const string& name, Node** out) { for (Node* node : g->nodes()) { if (node->name() == name) { *out = node; - return OkStatus(); + return absl::OkStatus(); } } return errors::Unimplemented("Node ", name, " not found."); diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.cc b/tensorflow/core/common_runtime/rendezvous_mgr.cc index 80bc0ff53f2c08..75bfb99e5ef6e8 100644 --- a/tensorflow/core/common_runtime/rendezvous_mgr.cc +++ b/tensorflow/core/common_runtime/rendezvous_mgr.cc @@ -60,7 +60,7 @@ void SameWorkerRecvDone(const DeviceMgr* device_mgr, } } *out = in; - done(OkStatus()); + done(absl::OkStatus()); return; } diff --git a/tensorflow/core/common_runtime/rendezvous_util.cc b/tensorflow/core/common_runtime/rendezvous_util.cc index dc5ee4f9759f72..bad10d33dee5b1 100644 --- a/tensorflow/core/common_runtime/rendezvous_util.cc +++ b/tensorflow/core/common_runtime/rendezvous_util.cc @@ -50,7 +50,7 @@ Status SendTensorsToRendezvous( TF_RETURN_IF_ERROR( rendezvous->Send(parsed, rendez_args, tensors_to_send[i], false)); } - return OkStatus(); + return absl::OkStatus(); } void RecvOutputsFromRendezvousAsync( @@ -59,7 +59,7 @@ void RecvOutputsFromRendezvousAsync( const std::vector& keys, std::vector* received_tensors, StatusCallback done) { if (keys.empty()) { - done(OkStatus()); + done(absl::OkStatus()); return; } if (!alloc_attrs.empty() && (keys.size() != alloc_attrs.size())) { @@ -134,7 +134,7 @@ Status RecvOutputsFromRendezvous(RendezvousInterface* rendezvous, " was not valid."); } } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/replicate_constants_pass.cc b/tensorflow/core/common_runtime/replicate_constants_pass.cc index 376129bad99719..11d6f0d53864f4 100644 --- a/tensorflow/core/common_runtime/replicate_constants_pass.cc +++ b/tensorflow/core/common_runtime/replicate_constants_pass.cc @@ -90,7 +90,7 @@ Status DeviceNameToCpuDeviceNameWithDeviceId(const string& device_name, device.has_id = true; *host_device_name = DeviceNameUtils::ParsedNameToString(device); } - return OkStatus(); + return absl::OkStatus(); } // Get the CPU device on the same host as dst. @@ -114,7 +114,7 @@ Status GetSuccessorEdges( if (!device_to_edges.count(device)) device_to_edges.insert({device, {}}); device_to_edges[device].push_back(edge); } - return OkStatus(); + return absl::OkStatus(); } // Replicate the constant to each successor device. @@ -148,7 +148,7 @@ Status ReplicateConstantsPass::Run( if (options.graph == nullptr) { VLOG(1) << "No graph in replicate_constants_pass."; - return OkStatus(); + return absl::OkStatus(); } Graph* graph = options.graph->get(); if (VLOG_IS_ON(1)) { @@ -206,7 +206,7 @@ Status ReplicateConstantsPass::Run( VLOG(1) << DumpGraphToFile("after_replicate_constants_pass", *graph, options.flib_def); } - return OkStatus(); + return absl::OkStatus(); } REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 3, diff --git a/tensorflow/core/common_runtime/replicate_per_replica_nodes.cc b/tensorflow/core/common_runtime/replicate_per_replica_nodes.cc index 1272cbad8eafbc..58bab38fb1d093 100644 --- a/tensorflow/core/common_runtime/replicate_per_replica_nodes.cc +++ b/tensorflow/core/common_runtime/replicate_per_replica_nodes.cc @@ -40,7 +40,7 @@ class ReplicateHelper { } std::vector replicated_nodes(num_allowed_devices, nullptr); replicated_nodes_map_.emplace(node, std::move(replicated_nodes)); - return OkStatus(); + return absl::OkStatus(); } // Replicate the given node to an allowed device. @@ -49,7 +49,7 @@ class ReplicateHelper { int allowed_device_index, Graph* graph) { auto& replicated_nodes = replicated_nodes_map_.at(node); if (replicated_nodes[allowed_device_index] != nullptr) { - return OkStatus(); + return absl::OkStatus(); } const auto& device = allowed_devices.at(allowed_device_index); NodeDef node_def = node->def(); @@ -61,7 +61,7 @@ class ReplicateHelper { replicated_node->AddAttr("sub_index", allowed_device_index); } replicated_nodes[allowed_device_index] = replicated_node; - return OkStatus(); + return absl::OkStatus(); } // Replace an edge (a regular device -> composite device) with @@ -107,7 +107,7 @@ class ReplicateHelper { graph->AddEdge(src_replicated_nodes.at(i), edge->src_output(), dst, edge->dst_input()); } - return OkStatus(); + return absl::OkStatus(); } // Data edge: replace an edge (composite device -> a regular device) with @@ -145,7 +145,7 @@ class ReplicateHelper { graph->AddControlEdge(replicated_node, dst, /*allow_duplicates=*/true); } - return OkStatus(); + return absl::OkStatus(); } if (edge->src()->type_string() == "_Arg") { // This happens when the dst node runs on a host CPU and @@ -188,7 +188,7 @@ class ReplicateHelper { " assigned to ", dst_device); } } - return OkStatus(); + return absl::OkStatus(); } private: @@ -246,7 +246,7 @@ Status ReplicateNodesAndEdges(const std::vector& allowed_devices, cluster_nodes->erase(node); graph->RemoveNode(node); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -279,7 +279,7 @@ Status ReplicatePerReplicaNodesInFunctionGraph( if (composite_device_to_cluster_nodes.empty()) { VLOG(1) << "No nodes with composiste device found."; - return OkStatus(); + return absl::OkStatus(); } for (auto& it : composite_device_to_cluster_nodes) { @@ -331,7 +331,7 @@ Status ReplicatePerReplicaNodesInFunctionGraph( VLOG(1) << "Finished ReplicatePerReplicaNodesInFunctionGraph"; VLOG(1) << "Graph #nodes " << graph->num_nodes() << " #edges " << graph->num_edges(); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/ring_alg.cc b/tensorflow/core/common_runtime/ring_alg.cc index a88f203e61b751..617127cadcc3f1 100644 --- a/tensorflow/core/common_runtime/ring_alg.cc +++ b/tensorflow/core/common_runtime/ring_alg.cc @@ -115,7 +115,7 @@ Status GenerateSubdivsInCollectiveParams(CollectiveParams* col_params) { if (col_params->instance.impl_details.max_subdivs_per_device == -1) { col_params->instance.impl_details.subdiv_offsets = {0}; VLOG(2) << "Limiting to 1 subdivision as max_subdivs_per_device == -1"; - return OkStatus(); + return absl::OkStatus(); } if (col_params->instance.shape.num_elements() == 0) { @@ -173,7 +173,7 @@ Status GenerateSubdivsInCollectiveParams(CollectiveParams* col_params) { << tensor_size << " chunk_size " << chunk_size; } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -252,7 +252,7 @@ Status RingAlg::InitializeCollectiveParams(CollectiveParams* col_params) { } VLOG(2) << collective_util::SubdivPermDebugString(*col_params); - return OkStatus(); + return absl::OkStatus(); } Status RingAlg::InitializeCollectiveContext( diff --git a/tensorflow/core/common_runtime/scoped_allocator_mgr.cc b/tensorflow/core/common_runtime/scoped_allocator_mgr.cc index dbb26ac5beabde..44e64690ccbb25 100644 --- a/tensorflow/core/common_runtime/scoped_allocator_mgr.cc +++ b/tensorflow/core/common_runtime/scoped_allocator_mgr.cc @@ -53,7 +53,7 @@ Status ScopedAllocatorContainer::AddScopedAllocator( allocators_[f.scope_id] = ScopedAllocatorContainer::SAField( i, new ScopedAllocatorInstance(sa, i)); } - return OkStatus(); + return absl::OkStatus(); } ScopedAllocator* ScopedAllocatorContainer::GetAllocator(int32_t scope_id) { diff --git a/tensorflow/core/common_runtime/serving_device_selector.cc b/tensorflow/core/common_runtime/serving_device_selector.cc deleted file mode 100644 index 30ca7d46a1a7fc..00000000000000 --- a/tensorflow/core/common_runtime/serving_device_selector.cc +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/common_runtime/serving_device_selector.h" - -namespace tensorflow { - -DeviceReservation::DeviceReservation(int device_index, - ServingDeviceSelector* device_selector) - : device_index_(device_index), device_selector_(device_selector) {} - -DeviceReservation::~DeviceReservation() { reset(); } - -void DeviceReservation::reset() { - if (device_selector_) device_selector_->FreeDeviceReservation(*this); - device_selector_ = nullptr; -} - -DeviceReservation::DeviceReservation(DeviceReservation&& r) - : device_index_{r.device_index_}, device_selector_{r.device_selector_} { - r.device_selector_ = nullptr; -} - -DeviceReservation& DeviceReservation::operator=(DeviceReservation&& r) { - if (this == &r) return *this; - - if (device_selector_) device_selector_->FreeDeviceReservation(*this); - - device_index_ = r.device_index_; - device_selector_ = r.device_selector_; - r.device_selector_ = nullptr; - return *this; -} - -} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/serving_device_selector.h b/tensorflow/core/common_runtime/serving_device_selector.h deleted file mode 100644 index c776a1fae67d45..00000000000000 --- a/tensorflow/core/common_runtime/serving_device_selector.h +++ /dev/null @@ -1,96 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SERVING_DEVICE_SELECTOR_H_ -#define TENSORFLOW_CORE_COMMON_RUNTIME_SERVING_DEVICE_SELECTOR_H_ - -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "absl/types/span.h" - -namespace tensorflow { - -class ServingDeviceSelector; - -// A RAII type for device reservation. -class DeviceReservation { - public: - DeviceReservation(int device_index, ServingDeviceSelector* selector); - ~DeviceReservation(); - - DeviceReservation(const DeviceReservation&) = delete; - DeviceReservation& operator=(const DeviceReservation&) = delete; - - DeviceReservation(DeviceReservation&& r); - DeviceReservation& operator=(DeviceReservation&& r); - - int device_index() const { return device_index_; } - - void reset(); - - private: - int device_index_; - ServingDeviceSelector* device_selector_; -}; - -// Interface for runtime device selection for serving. -// NOTE: This interface is experimental and subject to change. -class ServingDeviceSelector { - public: - // The state for a single device. - struct DeviceState { - // TODO(b/295352859): Add more stats to track that are useful for the Policy - // to use when selecting a device. - struct ProgramInfo { - absl::string_view fingerprint; - int64_t req_id = -1; - }; - std::deque scheduled_programs; - }; - - // Struct of all tracked device states, which will be passed to Policy. - struct DeviceStates { - absl::Span states; - }; - - // Policy used to select a device. - class Policy { - public: - virtual ~Policy() = default; - // Selects a device based on the tracked states of all devices. - virtual int SelectDevice(absl::string_view program_fingerprint, - const DeviceStates& device_states) = 0; - }; - - virtual ~ServingDeviceSelector() = default; - - // Reserves a device according to a given selection policy. The reserved - // device will be freed when the lifetime of the returned `DeviceReservation` - // object ends. - virtual DeviceReservation ReserveDevice( - absl::string_view program_fingerprint) = 0; - - private: - friend DeviceReservation; - - // Frees the given device reservation. - virtual void FreeDeviceReservation(const DeviceReservation& reservation) = 0; -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SERVING_DEVICE_SELECTOR_H_ diff --git a/tensorflow/core/common_runtime/session_factory.cc b/tensorflow/core/common_runtime/session_factory.cc index 10e269a70477e2..0ce81f9d9aed40 100644 --- a/tensorflow/core/common_runtime/session_factory.cc +++ b/tensorflow/core/common_runtime/session_factory.cc @@ -83,7 +83,7 @@ Status SessionFactory::GetFactory(const SessionOptions& options, if (candidate_factories.size() == 1) { *out_factory = candidate_factories[0].second; - return OkStatus(); + return absl::OkStatus(); } else if (candidate_factories.size() > 1) { // NOTE(mrry): This implementation assumes that the domains (in // terms of acceptable SessionOptions) of the registered diff --git a/tensorflow/core/common_runtime/session_state.cc b/tensorflow/core/common_runtime/session_state.cc index ac5baacd669ca9..7bf14d304c3740 100644 --- a/tensorflow/core/common_runtime/session_state.cc +++ b/tensorflow/core/common_runtime/session_state.cc @@ -31,7 +31,7 @@ Status SessionState::GetTensor(const string& handle, Tensor* tensor) { "' is not in the session store."); } *tensor = it->second; - return OkStatus(); + return absl::OkStatus(); } Status SessionState::AddTensor(const string& handle, const Tensor& tensor) { @@ -40,7 +40,7 @@ Status SessionState::AddTensor(const string& handle, const Tensor& tensor) { return errors::InvalidArgument("Failed to add a tensor with handle '", handle, "' to the session store."); } - return OkStatus(); + return absl::OkStatus(); } Status SessionState::DeleteTensor(const string& handle) { @@ -49,7 +49,7 @@ Status SessionState::DeleteTensor(const string& handle) { return errors::InvalidArgument("Failed to delete a tensor with handle '", handle, "' in the session store."); } - return OkStatus(); + return absl::OkStatus(); } int64_t SessionState::GetNewId() { @@ -64,7 +64,7 @@ Status TensorStore::AddTensor(const string& name, const TensorAndKey& tk) { "' to the tensor store."); } dirty_ = true; - return OkStatus(); + return absl::OkStatus(); } Status TensorStore::SaveTensors(const std::vector& output_names, @@ -83,7 +83,7 @@ Status TensorStore::SaveTensors(const std::vector& output_names, } } } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/session_test.cc b/tensorflow/core/common_runtime/session_test.cc index cbc5e91281060e..7c68f0f001e471 100644 --- a/tensorflow/core/common_runtime/session_test.cc +++ b/tensorflow/core/common_runtime/session_test.cc @@ -50,7 +50,7 @@ class FakeSessionFactory : public SessionFactory { Status NewSession(const SessionOptions& options, Session** out_session) override { *out_session = nullptr; - return OkStatus(); + return absl::OkStatus(); } }; class FakeSessionRegistrar { diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index 3a4d898a2e20d7..f021896691a078 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -147,7 +147,7 @@ Status ShapeRefiner::InferShapesForFunctionSubNode( } } - return OkStatus(); + return absl::OkStatus(); } // TODO(cwhipkey): When an inference context inside function has @@ -185,7 +185,7 @@ Status ShapeRefiner::InferShapesForFunction(const FunctionDef* function_def, } absl::flat_hash_set function_nodes; - Status inference_status = OkStatus(); + Status inference_status = absl::OkStatus(); { auto node_shape_inference_lambda = [this, &outer_context, &function_nodes, &inference_status](const Node* node) { @@ -272,7 +272,7 @@ Status ShapeRefiner::AddNodeInternal( // Store the resulting context object in the map. node_to_context_[node].swap(ec); - return OkStatus(); + return absl::OkStatus(); } Status ShapeRefiner::SetShape(const Node* node, int output_port, @@ -305,7 +305,7 @@ Status ShapeRefiner::SetShape(const Node* node, int output_port, // TODO(vrv): We might need to keep track of the fact that the // existing shape is invalidated, in case we need to propagate // this information to remote workers. - return OkStatus(); + return absl::OkStatus(); } Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) { @@ -390,7 +390,7 @@ Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) { if (!*refined) { // No input shape has changed, we're done - return OkStatus(); + return absl::OkStatus(); } // Get and run the shape function for this node to update the shapes of the @@ -406,7 +406,7 @@ Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) { if (!op_reg_data->shape_inference_fn) { // There is nothing more we can infer - return OkStatus(); + return absl::OkStatus(); } return RunShapeFn(node, op_reg_data, node_ext_context); @@ -463,7 +463,7 @@ Status ShapeRefiner::EvaluateConstantTensorForEdge( *result = *std::move(tensor); } - return OkStatus(); + return absl::OkStatus(); } Status ShapeRefiner::EvaluateConstantIntScalarEdge( @@ -489,7 +489,7 @@ Status ShapeRefiner::EvaluateConstantIntScalarEdge( *result = scalar.scalar()(); } } - return OkStatus(); + return absl::OkStatus(); } Status ShapeRefiner::ConstantPartialShape( @@ -520,10 +520,10 @@ Status ShapeRefiner::ConstantPartialShape( if (t.dims() == 0) { if (t.dtype() == DT_INT32 && t.scalar()() == -1) { *result = target_context->UnknownShape(); - return OkStatus(); + return absl::OkStatus(); } else if (t.dtype() == DT_INT64 && t.scalar()() == -1) { *result = target_context->UnknownShape(); - return OkStatus(); + return absl::OkStatus(); } } return errors::InvalidArgument( @@ -549,7 +549,7 @@ Status ShapeRefiner::ConstantPartialShape( .ok()) { if (evaluated && target_context->MakeShapeFromTensor(&t, src_shape, result).ok()) { - return OkStatus(); + return absl::OkStatus(); } } @@ -564,7 +564,7 @@ Status ShapeRefiner::ConstantPartialShape( if (!target_context->RankKnown(pre_cast_shape)) { // Failed to evaluate. Treat the output as completely unknown. *result = target_context->UnknownShape(); - return OkStatus(); + return absl::OkStatus(); } auto* dest_type = input_edge->src()->attrs().Find("DstT"); if (dest_type == nullptr || dest_type->value_case() != AttrValue::kType || @@ -572,7 +572,7 @@ Status ShapeRefiner::ConstantPartialShape( // Casting to a weird type. Do not attempt to infer across it. *result = target_context->MakeShape(std::vector( target_context->Rank(pre_cast_shape), target_context->UnknownDim())); - return OkStatus(); + return absl::OkStatus(); } *result = pre_cast_shape; } else if (src_op == "Shape") { @@ -612,7 +612,7 @@ Status ShapeRefiner::ConstantPartialShape( // TODO(cwhipkey): we could rely on all inputs being the same rank, so // figure that rank out and append the right number of unknown dims. *result = target_context->UnknownShape(); - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR( target_context->Concatenate(*result, sub_result, result)); @@ -635,7 +635,7 @@ Status ShapeRefiner::ConstantPartialShape( TF_RETURN_IF_ERROR(target_context->MakeShapeFromTensor( evaluated ? &t : nullptr, src_shape, result)); } - return OkStatus(); + return absl::OkStatus(); } Status ShapeRefiner::PartialStridedSliceShape( @@ -646,7 +646,7 @@ Status ShapeRefiner::PartialStridedSliceShape( ShapeHandle input_shape = ctx->input(i); if (ctx->Value(ctx->Dim(input_shape, 0)) != 1) { *result = ctx->UnknownShape(); - return OkStatus(); + return absl::OkStatus(); } } @@ -667,7 +667,7 @@ Status ShapeRefiner::PartialStridedSliceShape( !(end_mask == 0 || end_mask == 1) || ellipsis_mask != 0 || new_axis_mask != 0 || shrink_axis_mask != 0) { *result = ctx->UnknownShape(); - return OkStatus(); + return absl::OkStatus(); } bool evaluated; @@ -679,7 +679,7 @@ Status ShapeRefiner::PartialStridedSliceShape( &begin, outer_context)); if (!evaluated) { *result = ctx->UnknownShape(); - return OkStatus(); + return absl::OkStatus(); } } @@ -691,7 +691,7 @@ Status ShapeRefiner::PartialStridedSliceShape( &end, outer_context)); if (!evaluated) { *result = ctx->UnknownShape(); - return OkStatus(); + return absl::OkStatus(); } } @@ -700,7 +700,7 @@ Status ShapeRefiner::PartialStridedSliceShape( &stride, outer_context)); if (!evaluated) { *result = ctx->UnknownShape(); - return OkStatus(); + return absl::OkStatus(); } // Apply stride to input interpreted as a partial shape. @@ -708,7 +708,7 @@ Status ShapeRefiner::PartialStridedSliceShape( TF_RETURN_IF_ERROR( ConstantPartialShape(ctx, slice_node, 0, &input, outer_context)); TF_RETURN_IF_ERROR(ctx->Subshape(input, begin, end, stride, result)); - return OkStatus(); + return absl::OkStatus(); } Status ShapeRefiner::RunShapeFn(const Node* node, @@ -761,7 +761,7 @@ Status ShapeRefiner::RunShapeFn(const Node* node, } else { TF_RETURN_IF_ERROR(c->Run(shape_inference::UnknownShape)); } - return OkStatus(); + return absl::OkStatus(); }; TF_RETURN_IF_ERROR(run_inference_lambda()); @@ -828,7 +828,7 @@ Status ShapeRefiner::RunShapeFn(const Node* node, } } while (rerun_shape_fn); - return OkStatus(); + return absl::OkStatus(); } bool ShapeRefiner::SameDefinedShape(InferenceContext* c, ShapeHandle s0, diff --git a/tensorflow/core/common_runtime/shape_refiner_test.cc b/tensorflow/core/common_runtime/shape_refiner_test.cc index 7b1e8f8c31367c..b50973981fdb0e 100644 --- a/tensorflow/core/common_runtime/shape_refiner_test.cc +++ b/tensorflow/core/common_runtime/shape_refiner_test.cc @@ -330,7 +330,7 @@ REGISTER_OP("TestOp") if (c->input_tensor(0)) { if (c->input_tensor(1)) { c->set_output(0, c->Matrix(10, 10)); - return OkStatus(); + return absl::OkStatus(); } return shape_inference::ScalarShape(c); } @@ -384,7 +384,7 @@ REGISTER_OP("ShapeData") } c->set_output(0, c->MakeShape(dims)); - return OkStatus(); + return absl::OkStatus(); }); REGISTER_OP("ShapeDataInt64") @@ -403,7 +403,7 @@ REGISTER_OP("ShapeDataInt64") } c->set_output(0, c->MakeShape(dims)); - return OkStatus(); + return absl::OkStatus(); }); // An op with a shape function that looks at its input tensor @@ -422,7 +422,7 @@ REGISTER_OP("ShapeVectorForAllElements") } c->set_output(0, c->Vector(total)); - return OkStatus(); + return absl::OkStatus(); }); REGISTER_OP("MultiIdentity") @@ -433,7 +433,7 @@ REGISTER_OP("MultiIdentity") for (int i = 0; i < c->num_inputs(); ++i) { c->set_output(i, c->input(i)); } - return OkStatus(); + return absl::OkStatus(); }); class MultiIdentity : public OpKernel { @@ -834,7 +834,7 @@ Status TensorAsShapeShapeFn(shape_inference::InferenceContext* c) { shape_inference::ShapeHandle out; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0 /* input_idx */, &out)); c->set_output(0, out); - return OkStatus(); + return absl::OkStatus(); } Status PartialTensorAsShapeShapeFn(shape_inference::InferenceContext* c) { @@ -842,12 +842,12 @@ Status PartialTensorAsShapeShapeFn(shape_inference::InferenceContext* c) { const Tensor* t = c->input_tensor(0); if (t == nullptr || t->NumElements() != 1) { c->set_output(0, c->UnknownShape()); - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR( c->MakeShapeFromTensorShape(TensorShape({t->flat()(0)}), &out)); c->set_output(0, out); - return OkStatus(); + return absl::OkStatus(); } // Register ops used by the ConstantValueAsShape* tests. @@ -881,7 +881,7 @@ REGISTER_OP("WithEmptyVectorShape") .SetDoNotOptimize() .SetShapeFn([](shape_inference::InferenceContext* c) { c->set_output(0, c->Vector(0)); - return OkStatus(); + return absl::OkStatus(); }); REGISTER_OP("WithPartialShape") @@ -891,7 +891,7 @@ REGISTER_OP("WithPartialShape") c->set_output( 0, c->MakeShape({1, shape_inference::InferenceContext::kUnknownDim, 3, shape_inference::InferenceContext::kUnknownDim, 5})); - return OkStatus(); + return absl::OkStatus(); }); REGISTER_OP("WithPartialShape2") @@ -901,7 +901,7 @@ REGISTER_OP("WithPartialShape2") c->set_output( 0, c->MakeShape({6, shape_inference::InferenceContext::kUnknownDim, 8})); - return OkStatus(); + return absl::OkStatus(); }); REGISTER_OP("WithUnknownShape") @@ -909,7 +909,7 @@ REGISTER_OP("WithUnknownShape") .SetDoNotOptimize() .SetShapeFn([](shape_inference::InferenceContext* c) { c->set_output(0, c->UnknownShape()); - return OkStatus(); + return absl::OkStatus(); }); } // namespace diff --git a/tensorflow/core/common_runtime/single_threaded_cpu_device.cc b/tensorflow/core/common_runtime/single_threaded_cpu_device.cc index 495f65cfe910c9..ababa7d14aec33 100644 --- a/tensorflow/core/common_runtime/single_threaded_cpu_device.cc +++ b/tensorflow/core/common_runtime/single_threaded_cpu_device.cc @@ -56,7 +56,7 @@ class SingleThreadedCpuDevice : public Device { ~SingleThreadedCpuDevice() override { eigen_device_.reset(); } - Status Sync() override { return OkStatus(); } + Status Sync() override { return absl::OkStatus(); } Status MakeTensorFromProto(const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, @@ -66,7 +66,7 @@ class SingleThreadedCpuDevice : public Device { return errors::InvalidArgument("Cannot parse tensor from tensor_proto."); } *tensor = parsed; - return OkStatus(); + return absl::OkStatus(); } void CopyTensorInSameDevice(const Tensor* input_tensor, Tensor* output_tensor, @@ -79,7 +79,7 @@ class SingleThreadedCpuDevice : public Device { return; } tensor::DeepCopy(*input_tensor, output_tensor); - done(OkStatus()); + done(absl::OkStatus()); } Allocator* GetAllocator(AllocatorAttributes attr) override { diff --git a/tensorflow/core/common_runtime/single_threaded_executor.cc b/tensorflow/core/common_runtime/single_threaded_executor.cc index 598907b3cd2582..a098f437af17b6 100644 --- a/tensorflow/core/common_runtime/single_threaded_executor.cc +++ b/tensorflow/core/common_runtime/single_threaded_executor.cc @@ -57,7 +57,7 @@ Status ValidateOpIsSafeForSyncExecution( ". Perhaps your graph contains old-style control flow primitives? " "Try using tf.compat.v1.enable_control_flow_v2()."); } - return OkStatus(); + return absl::OkStatus(); } namespace { @@ -251,7 +251,7 @@ class SingleThreadedExecutorImpl : public Executor { } else { total_num_inputs_ = 0; } - return OkStatus(); + return absl::OkStatus(); } Status Run(const Args& args) override { @@ -481,7 +481,7 @@ class SingleThreadedExecutorImpl : public Executor { delete val.tensor; } } - return OkStatus(); + return absl::OkStatus(); } private: @@ -582,7 +582,7 @@ class SingleThreadedExecutorRegistrar { Executor* ret; TF_RETURN_IF_ERROR(NewSingleThreadedExecutor(params, graph, &ret)); out_executor->reset(ret); - return OkStatus(); + return absl::OkStatus(); } }; }; @@ -595,7 +595,7 @@ Status NewSingleThreadedExecutor(const LocalExecutorParams& params, auto impl = std::make_unique(params); TF_RETURN_IF_ERROR(impl->Initialize(graph)); *executor = impl.release(); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/single_threaded_executor_test.cc b/tensorflow/core/common_runtime/single_threaded_executor_test.cc index c1343aea326f6d..04fcd51647efc8 100644 --- a/tensorflow/core/common_runtime/single_threaded_executor_test.cc +++ b/tensorflow/core/common_runtime/single_threaded_executor_test.cc @@ -99,7 +99,7 @@ class ExecutorTest : public ::testing::Test { if ((*kernel)->type_string_view() == "Mock") { down_cast(*kernel)->SetCompute(mock_fn); } - return OkStatus(); + return absl::OkStatus(); }; params.delete_kernel = [](OpKernel* kernel) { DeleteNonCachedKernel(kernel); diff --git a/tensorflow/core/common_runtime/test_collective_executor_mgr.h b/tensorflow/core/common_runtime/test_collective_executor_mgr.h index c123791e0004e5..0a47150b6a0dc3 100644 --- a/tensorflow/core/common_runtime/test_collective_executor_mgr.h +++ b/tensorflow/core/common_runtime/test_collective_executor_mgr.h @@ -104,6 +104,14 @@ class TestCollectiveExecutorMgr : public CollectiveExecutorMgrInterface { } } + void CleanupAll() override { + mutex_lock l(mu_); + for (auto& iter : table_) { + iter.second->Unref(); + } + table_.clear(); + } + ParamResolverInterface* GetParamResolver() const override { return param_resolver_; } diff --git a/tensorflow/core/common_runtime/threadpool_device.cc b/tensorflow/core/common_runtime/threadpool_device.cc index 8365c229c4fd3c..cd77afcf53dbd3 100644 --- a/tensorflow/core/common_runtime/threadpool_device.cc +++ b/tensorflow/core/common_runtime/threadpool_device.cc @@ -126,7 +126,7 @@ Status ThreadPoolDevice::MakeTensorFromProto( Tensor parsed(tensor_proto.dtype()); if (parsed.FromProto(allocator_, tensor_proto)) { *tensor = std::move(parsed); - return OkStatus(); + return absl::OkStatus(); } } return errors::InvalidArgument("Cannot parse tensor from proto: ", @@ -143,7 +143,7 @@ void ThreadPoolDevice::CopyTensorInSameDevice( return; } tensor::DeepCopy(*input_tensor, output_tensor); - done(OkStatus()); + done(absl::OkStatus()); } namespace { diff --git a/tensorflow/core/common_runtime/threadpool_device.h b/tensorflow/core/common_runtime/threadpool_device.h index 5ca3deb7247602..a2b062e0e6f727 100644 --- a/tensorflow/core/common_runtime/threadpool_device.h +++ b/tensorflow/core/common_runtime/threadpool_device.h @@ -43,7 +43,7 @@ class ThreadPoolDevice : public LocalDevice { const DeviceContext* device_context, StatusCallback done) override; - Status Sync() override { return OkStatus(); } + Status Sync() override { return absl::OkStatus(); } void Compute(OpKernel* op_kernel, OpKernelContext* context) override; void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, diff --git a/tensorflow/core/common_runtime/threadpool_device_factory.cc b/tensorflow/core/common_runtime/threadpool_device_factory.cc index 5840a5d64005ce..4ebf3576167e5f 100644 --- a/tensorflow/core/common_runtime/threadpool_device_factory.cc +++ b/tensorflow/core/common_runtime/threadpool_device_factory.cc @@ -32,7 +32,7 @@ class ThreadPoolDeviceFactory : public DeviceFactory { Status ListPhysicalDevices(std::vector* devices) override { devices->push_back("/physical_device:CPU:0"); - return OkStatus(); + return absl::OkStatus(); } Status CreateDevices(const SessionOptions& options, const string& name_prefix, @@ -67,7 +67,7 @@ class ThreadPoolDeviceFactory : public DeviceFactory { devices->push_back(std::move(tpd)); } - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/core/common_runtime/type_inference.cc b/tensorflow/core/common_runtime/type_inference.cc index 54ce1a68bc500c..3ad86d8792bc46 100644 --- a/tensorflow/core/common_runtime/type_inference.cc +++ b/tensorflow/core/common_runtime/type_inference.cc @@ -99,14 +99,14 @@ std::vector> input_types( Status update_inferred_type(Node* target, const FullTypeDef& t, bool& updated) { if (t.type_id() == TFT_UNSET) { VLOG(3) << " " << target->name() << " no inferred type"; - return OkStatus(); + return absl::OkStatus(); } if (target->def().has_experimental_type()) { const auto existing = target->def().experimental_type(); if (full_type::IsSubtype(existing, t)) { VLOG(3) << " " << target->name() << " no new type info"; - return OkStatus(); + return absl::OkStatus(); } else if (!full_type::IsSubtype(t, existing)) { // The only allowable type mismatches are those which would further // specialize the existing type. @@ -121,7 +121,7 @@ Status update_inferred_type(Node* target, const FullTypeDef& t, bool& updated) { *(target->mutable_def()->mutable_experimental_type()) = t; updated = true; VLOG(3) << " " << target->name() << " updated"; - return OkStatus(); + return absl::OkStatus(); } StatusOr run_inference(const string& fn_name, @@ -131,7 +131,7 @@ StatusOr run_inference(const string& fn_name, // * execute pass on its graph // * get retnode types // * return them here - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -174,7 +174,7 @@ Status TypeInferencePass::Run( auto infer_forward = [&forward](Node* n, bool& updated) { if (!forward.contains(n->id())) { - return OkStatus(); + return absl::OkStatus(); } VLOG(4) << " " << n->name() << " has forward function"; @@ -189,12 +189,12 @@ Status TypeInferencePass::Run( update_inferred_type(n, *infer_ret, updated), "while updating its output type."); - return OkStatus(); + return absl::OkStatus(); }; auto infer_reverse = [&reverse](Node* n, bool& updated) { if (!reverse.contains(n->id())) { - return OkStatus(); + return absl::OkStatus(); } VLOG(4) << " " << n->name() << " has reverse function"; @@ -218,7 +218,7 @@ Status TypeInferencePass::Run( absl::StrCat("while updating its output type inferred from '", n->name(), ",")); - return OkStatus(); + return absl::OkStatus(); }; std::list queue; @@ -328,7 +328,7 @@ Status TypeInferencePass::Run( DumpGraphToFile("forward_type_inference_after", *g, flib_def); } - return OkStatus(); + return absl::OkStatus(); } Status WeakTypeInferencePass::Run( @@ -341,7 +341,7 @@ Status WeakTypeInferencePass::Run( "invalid graph that escaped type checking. Error message: " << pass_status.ToString(); } - return OkStatus(); + return absl::OkStatus(); } // Note: This needs to run last because Placer needs it. diff --git a/tensorflow/core/data/BUILD b/tensorflow/core/data/BUILD index 29dbcb95ca4204..d0c2fd315ff4c4 100644 --- a/tensorflow/core/data/BUILD +++ b/tensorflow/core/data/BUILD @@ -227,9 +227,9 @@ cc_library( ":tfdataz_metrics", "//tensorflow/core/platform:env", "//tensorflow/core/platform:numbers", + "//tensorflow/core/platform:status", "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:logging", ], ) @@ -279,6 +279,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core/platform:env", "//tensorflow/core/platform:mutex", @@ -531,6 +532,8 @@ cc_library( ":dataset_utils", ":root_dataset", ":serialization_utils", + ":tf_data_memory_logger", + ":tfdataz_metrics", ":unbounded_thread_pool", "//tensorflow/core:all_kernels", "//tensorflow/core:core_cpu_internal", diff --git a/tensorflow/core/data/captured_function.cc b/tensorflow/core/data/captured_function.cc index d05afc40135f1b..7707b220e91f25 100644 --- a/tensorflow/core/data/captured_function.cc +++ b/tensorflow/core/data/captured_function.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/data/stats_utils.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function_handle_cache.h" #include "tensorflow/core/framework/op_kernel.h" @@ -126,7 +127,7 @@ Status GetCapturedInput(const CapturedFunction* const func, int index, ". Num captured inputs: ", func->captured_inputs().size()); } *out = &func->captured_inputs()[index]; - return OkStatus(); + return absl::OkStatus(); } Status RunShortCircuit(const ShortCircuitInfo& info, @@ -146,7 +147,7 @@ Status RunShortCircuit(const ShortCircuitInfo& info, rets->push_back(*captured_input); } } - return OkStatus(); + return absl::OkStatus(); } Status RunShortCircuit(const ShortCircuitInfo& info, std::vector&& args, @@ -169,7 +170,7 @@ Status RunShortCircuit(const ShortCircuitInfo& info, std::vector&& args, rets->push_back(*captured_input); } } - return OkStatus(); + return absl::OkStatus(); } Status CreateShortCircuitInfo(OpKernelConstruction* ctx, @@ -190,7 +191,7 @@ Status CreateShortCircuitInfo(OpKernelConstruction* ctx, // If the function contains any stateful operations, we conservatively execute // the entire function. if (ctx->function_library()->IsStateful(func.name())) { - return OkStatus(); + return absl::OkStatus(); } const FunctionBody* fn_body = @@ -228,7 +229,7 @@ Status CreateShortCircuitInfo(OpKernelConstruction* ctx, } } - return OkStatus(); + return absl::OkStatus(); } Status CreateFunctionLibraryDefinition( @@ -253,7 +254,7 @@ Status LookupFunction(const FunctionLibraryDefinition& lib_def, "Failed to find function ", name, " in function library: ", lib_def.ToProto().DebugString()); } - return OkStatus(); + return absl::OkStatus(); } class CallFrameBase : public CallFrameInterface { @@ -272,7 +273,7 @@ class CallFrameBase : public CallFrameInterface { retvals->emplace_back(std::move(val.value())); ++i; } - return OkStatus(); + return absl::OkStatus(); } size_t num_retvals() const override { return retvals_.size(); } @@ -283,7 +284,7 @@ class CallFrameBase : public CallFrameInterface { if (index < retvals_size && val.dtype() == ret_types_[index] && !retvals_[index]) { retvals_[index] = val; - return OkStatus(); + return absl::OkStatus(); } else if (index >= retvals_size) { return errors::InvalidArgument("Return value ", index, " is out of range."); @@ -324,10 +325,10 @@ class OwnedArgsCallFrame : public CallFrameBase { const int captured_inputs_size = captured_inputs_->size(); if (index < args_size) { *val = &args_[index]; - return OkStatus(); + return absl::OkStatus(); } else if (index < args_size + captured_inputs_size) { *val = &(*captured_inputs_)[index - args_.size()]; - return OkStatus(); + return absl::OkStatus(); } else { return errors::InvalidArgument("Argument ", index, " is out of range."); } @@ -368,10 +369,10 @@ class BorrowedArgsCallFrame : public CallFrameBase { const int captured_inputs_size = captured_inputs_->size(); if (index < args_size) { *val = &args_[index]; - return OkStatus(); + return absl::OkStatus(); } else if (index < args_size + captured_inputs_size) { *val = &(*captured_inputs_)[index - args_size]; - return OkStatus(); + return absl::OkStatus(); } else { return errors::InvalidArgument("Argument ", index, " is out of range."); } @@ -385,7 +386,7 @@ class BorrowedArgsCallFrame : public CallFrameBase { } // namespace Status MakeIteratorFromInputElement( - IteratorContext* ctx, const IteratorBase* parent, + IteratorContext* ctx, const DatasetBaseIterator* parent, const std::vector& input_element, int64_t thread_index, const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix, std::unique_ptr* out_iterator) { @@ -395,15 +396,18 @@ Status MakeIteratorFromInputElement( } Status MakeIteratorFromInputElement( - IteratorContext* ctx, const IteratorBase* parent, + IteratorContext* ctx, const DatasetBaseIterator* parent, const std::vector& input_element, int64_t thread_index, const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix, std::unique_ptr* out_iterator, const std::shared_ptr& node) { std::vector return_values; - TF_RETURN_IF_ERROR(inst_captured_func.RunWithBorrowedArgs( - ctx, input_element, &return_values, node)); + auto status = inst_captured_func.RunWithBorrowedArgs(ctx, input_element, + &return_values, node); + if (!status.ok()) { + return parent->AddErrorContext(status); + } if (!(return_values.size() == 1 && return_values[0].dtype() == DT_VARIANT && TensorShapeUtils::IsScalar(return_values[0].shape()))) { @@ -423,7 +427,7 @@ Status MakeIteratorFromInputElement( TF_RETURN_IF_ERROR(returned_dataset->MakeIterator( &nested_ctx, parent, iterator_prefix, out_iterator)); ctx->MergeCheckpoint(nested_ctx.checkpoint()); - return OkStatus(); + return absl::OkStatus(); } /* static */ @@ -453,7 +457,7 @@ Status FunctionMetadata::Create( VLOG(1) << "Disabling multi-device execution for a function that uses the " << FunctionLibraryDefinition::kIntsOnDeviceAttr << " attribute."; (*out_metadata)->use_multi_device_function_ = false; - return OkStatus(); + return absl::OkStatus(); } auto validate_arg = [](const OpDef::ArgDef& arg) { if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) { @@ -466,16 +470,16 @@ Status FunctionMetadata::Create( for (const auto& arg : fdef->signature().input_arg()) { if (!validate_arg(arg)) { (*out_metadata)->use_multi_device_function_ = false; - return OkStatus(); + return absl::OkStatus(); } } for (const auto& arg : fdef->signature().output_arg()) { if (!validate_arg(arg)) { (*out_metadata)->use_multi_device_function_ = false; - return OkStatus(); + return absl::OkStatus(); } } - return OkStatus(); + return absl::OkStatus(); } /* static */ @@ -497,7 +501,7 @@ Status CapturedFunction::Create( std::unique_ptr* out_function) { *out_function = absl::WrapUnique( new CapturedFunction(std::move(metadata), std::move(captured_inputs))); - return OkStatus(); + return absl::OkStatus(); } Status CapturedFunction::AddToGraph( @@ -520,7 +524,7 @@ Status CapturedFunction::AddToGraph( } TF_RETURN_IF_ERROR( b->AddFunction(ctx, metadata_->func().name(), *metadata_->lib_def())); - return OkStatus(); + return absl::OkStatus(); } Status CapturedFunction::Instantiate( @@ -679,7 +683,7 @@ Status CapturedFunction::Instantiate( *instantiated_captured_function = absl::WrapUnique( new InstantiatedCapturedFunction(lib, f_handle, std::move(ret_types), *params.runner, this, is_multi_device)); - return OkStatus(); + return absl::OkStatus(); } Status CapturedFunction::CheckExternalState() const { @@ -687,7 +691,7 @@ Status CapturedFunction::CheckExternalState() const { TF_RETURN_IF_ERROR( IsFunctionStateful(*lib_def(), *(lib_def()->Find(name)))); } - return OkStatus(); + return absl::OkStatus(); } CapturedFunction::CapturedFunction( @@ -700,7 +704,7 @@ Status CapturedFunction::IsMultiDevice(FunctionLibraryRuntime* flr, bool* is_multi_device) const { if (!metadata_->use_multi_device_function()) { *is_multi_device = false; - return OkStatus(); + return absl::OkStatus(); } const FunctionDef* fdef; @@ -735,7 +739,7 @@ Status CapturedFunction::IsMultiDevice(FunctionLibraryRuntime* flr, if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name, resource_device_name)) { *is_multi_device = true; - return OkStatus(); + return absl::OkStatus(); } } } @@ -748,7 +752,7 @@ Status CapturedFunction::IsMultiDevice(FunctionLibraryRuntime* flr, // Check if the op has a kernel available for the current device. if (!KernelDefAvailable(current_device_type, node)) { *is_multi_device = true; - return OkStatus(); + return absl::OkStatus(); } // If the op has a requested device, check if the requested device is // compatible with the current device. @@ -761,14 +765,14 @@ Status CapturedFunction::IsMultiDevice(FunctionLibraryRuntime* flr, if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name, node_device_name)) { *is_multi_device = true; - return OkStatus(); + return absl::OkStatus(); } } } } *is_multi_device = false; - return OkStatus(); + return absl::OkStatus(); } InstantiatedCapturedFunction::InstantiatedCapturedFunction( diff --git a/tensorflow/core/data/captured_function.h b/tensorflow/core/data/captured_function.h index 5d9a573aad0d3f..e415c546f970ae 100644 --- a/tensorflow/core/data/captured_function.h +++ b/tensorflow/core/data/captured_function.h @@ -45,7 +45,7 @@ class InstantiatedCapturedFunction; // Creates an iterator for a dataset which is created by applying the given // function to the given input element. Status MakeIteratorFromInputElement( - IteratorContext* ctx, const IteratorBase* parent, + IteratorContext* ctx, const DatasetBaseIterator* parent, const std::vector& input_element, int64_t thread_index, const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix, std::unique_ptr* out_iterator); @@ -54,7 +54,7 @@ Status MakeIteratorFromInputElement( // function to the given input element. Pass non-null `node` to record // processing time for modeling Iterator's GetNext() resource usage. Status MakeIteratorFromInputElement( - IteratorContext* ctx, const IteratorBase* parent, + IteratorContext* ctx, const DatasetBaseIterator* parent, const std::vector& input_element, int64_t thread_index, const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix, std::unique_ptr* out_iterator, diff --git a/tensorflow/core/data/compression_utils.cc b/tensorflow/core/data/compression_utils.cc index fad3515e85e32f..68c8f27f3127ad 100644 --- a/tensorflow/core/data/compression_utils.cc +++ b/tensorflow/core/data/compression_utils.cc @@ -132,7 +132,7 @@ Status CompressElement(const std::vector& element, out->set_version(kCompressedElementVersion); VLOG(3) << "Compressed element from " << iov.NumBytes() << " bytes to " << out->data().size() << " bytes"; - return OkStatus(); + return absl::OkStatus(); } Status UncompressElement(const CompressedElement& compressed, @@ -228,7 +228,7 @@ Status UncompressElement(const CompressedElement& compressed, nonmemcpyable_pos += metadata.uncompressed_bytes(0); } } - return OkStatus(); + return absl::OkStatus(); } REGISTER_UNARY_VARIANT_DECODE_FUNCTION(CompressedElement, diff --git a/tensorflow/core/data/dataset_test_base.cc b/tensorflow/core/data/dataset_test_base.cc index ae6518bd1af289..acc099dcda529b 100644 --- a/tensorflow/core/data/dataset_test_base.cc +++ b/tensorflow/core/data/dataset_test_base.cc @@ -147,7 +147,7 @@ Status WriteDataToFile(const string& filename, const char* data, TF_RETURN_IF_ERROR(file_writer->Flush()); TF_RETURN_IF_ERROR(file_writer->Close()); - return OkStatus(); + return absl::OkStatus(); } Status WriteDataToTFRecordFile(const string& filename, @@ -167,7 +167,7 @@ Status WriteDataToTFRecordFile(const string& filename, TF_RETURN_IF_ERROR(record_writer.Close()); TF_RETURN_IF_ERROR(file_writer->Flush()); TF_RETURN_IF_ERROR(file_writer->Close()); - return OkStatus(); + return absl::OkStatus(); } template @@ -195,7 +195,7 @@ Status IsEqual(const Tensor& t1, const Tensor& t2) { i, "]: ", flat_t1(i), " vs. ", flat_t2(i)); } } - return OkStatus(); + return absl::OkStatus(); } DatasetOpsTestBase::DatasetOpsTestBase() @@ -247,7 +247,7 @@ Status DatasetOpsTestBase::ExpectEqual(const Tensor& a, const Tensor& b) { default: return errors::Internal("Unsupported dtype: ", a.dtype()); } - return OkStatus(); + return absl::OkStatus(); } template @@ -271,7 +271,7 @@ Status DatasetOpsTestBase::ExpectEqual(std::vector produced_tensors, " v.s. ", expected_tensors.size(), ")")); } - if (produced_tensors.empty()) return OkStatus(); + if (produced_tensors.empty()) return absl::OkStatus(); if (produced_tensors[0].dtype() != expected_tensors[0].dtype()) { return Status(tensorflow::errors::Internal( "The two tensor vectors have different dtypes (", @@ -318,7 +318,7 @@ Status DatasetOpsTestBase::ExpectEqual(std::vector produced_tensors, TF_RETURN_IF_ERROR(DatasetOpsTestBase::ExpectEqual(produced_tensors[i], expected_tensors[i])); } - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::CreateOpKernel( @@ -343,7 +343,7 @@ Status DatasetOpsTestBase::CreateOpKernel( device_->resource_manager(), props_with_defaults, TF_GRAPH_DEF_VERSION, &kernel)); op_kernel->reset(kernel); - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::CreateDatasetContext( @@ -357,7 +357,7 @@ Status DatasetOpsTestBase::CreateDatasetContext( } TF_RETURN_IF_ERROR(CreateOpKernelContext( dateset_kernel, inputs, dataset_context_params, dataset_context)); - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::CreateDataset(OpKernel* kernel, @@ -367,7 +367,7 @@ Status DatasetOpsTestBase::CreateDataset(OpKernel* kernel, // Assume that DatasetOp has only one output. DCHECK_EQ(context->num_outputs(), 1); TF_RETURN_IF_ERROR(GetDatasetFromContext(context, 0, dataset)); - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::RestoreIterator( @@ -389,7 +389,7 @@ Status DatasetOpsTestBase::CreateIteratorContext( params.ram_budget_manager = std::make_shared( /*budget*/ std::numeric_limits::max()); *iterator_context = std::make_unique(params); - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::GetDatasetFromContext(OpKernelContext* context, @@ -408,7 +408,7 @@ Status DatasetOpsTestBase::InitThreadPool(int thread_num) { } thread_pool_ = std::make_unique( Env::Default(), ThreadOptions(), "test_thread_pool", thread_num); - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::InitFunctionLibraryRuntime( @@ -441,7 +441,7 @@ Status DatasetOpsTestBase::InitFunctionLibraryRuntime( tsl::core::RefCountPtr* r) { *r = tsl::core::RefCountPtr( new IntraProcessRendezvous(device_mgr)); - return OkStatus(); + return absl::OkStatus(); }}); flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); if (thread_pool_ == nullptr) { @@ -451,7 +451,7 @@ Status DatasetOpsTestBase::InitFunctionLibraryRuntime( thread_pool_->Schedule(std::move(fn)); }; } - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::RunOpKernel(OpKernel* op_kernel, @@ -511,7 +511,7 @@ Status DatasetOpsTestBase::RunFunction( for (int i = 0; i < rets.size(); ++i) { *(rets[i]) = computed[i]; } - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::CreateOpKernelContext( @@ -554,14 +554,14 @@ Status DatasetOpsTestBase::CreateOpKernelContext( *context = std::make_unique(params.get()); *context_params = std::move(params); - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::CreateSerializationContext( std::unique_ptr* context) { *context = std::make_unique(SerializationContext::Params{}); - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::CheckOpKernelInput( @@ -571,7 +571,7 @@ Status DatasetOpsTestBase::CheckOpKernelInput( kernel.num_inputs(), ", but got: ", inputs.size()); } - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::AddDatasetInput( @@ -604,7 +604,7 @@ Status DatasetOpsTestBase::AddDatasetInput( // collect the inputs. tensors_.push_back(std::move(input)); - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::CheckIteratorGetNext( @@ -638,7 +638,7 @@ Status DatasetOpsTestBase::CheckIteratorGetNext( TF_EXPECT_OK(ExpectEqual(out_tensors, expected_outputs, /*compare_order=*/compare_order)); - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::CheckIteratorSkip( @@ -659,7 +659,7 @@ Status DatasetOpsTestBase::CheckIteratorSkip( TF_EXPECT_OK(ExpectEqual(out_tensors, expected_outputs, /*compare_order=*/compare_order)); } - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::CheckSplitProviderFullIteration( @@ -673,7 +673,7 @@ Status DatasetOpsTestBase::CheckSplitProviderFullIteration( MakeIterator(params, *dataset, std::move(split_providers), &iterator)); TF_RETURN_IF_ERROR(CheckIteratorGetNext(iterator.get(), expected_outputs, /*compare_order=*/true)); - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::CheckSplitProviderShardedIteration( @@ -703,65 +703,65 @@ Status DatasetOpsTestBase::CheckSplitProviderShardedIteration( /*breakpoints=*/ {0, mid_breakpoint, near_end_breakpoint, end_breakpoint}, /*compare_order=*/true)); - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::CheckDatasetNodeName( const string& expected_dataset_node_name) { EXPECT_EQ(dataset_->node_name(), expected_dataset_node_name); - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::CheckDatasetTypeString( const string& expected_type_str) { EXPECT_EQ(dataset_->type_string(), expected_type_str); - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::CheckDatasetOutputDtypes( const DataTypeVector& expected_output_dtypes) { TF_EXPECT_OK( VerifyTypesMatch(dataset_->output_dtypes(), expected_output_dtypes)); - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::CheckDatasetOutputShapes( const std::vector& expected_output_shapes) { TF_EXPECT_OK(VerifyShapesCompatible(dataset_->output_shapes(), expected_output_shapes)); - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::CheckDatasetCardinality(int expected_cardinality) { EXPECT_EQ(dataset_->Cardinality(), expected_cardinality); - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::CheckDatasetOptions( const Options& expected_options) { EXPECT_EQ(dataset_->options().SerializeAsString(), expected_options.SerializeAsString()); - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::CheckIteratorOutputDtypes( const DataTypeVector& expected_output_dtypes) { TF_EXPECT_OK( VerifyTypesMatch(iterator_->output_dtypes(), expected_output_dtypes)); - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::CheckIteratorOutputShapes( const std::vector& expected_output_shapes) { TF_EXPECT_OK(VerifyShapesCompatible(iterator_->output_shapes(), expected_output_shapes)); - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::CheckIteratorPrefix( const string& expected_iterator_prefix) { EXPECT_EQ(iterator_->prefix(), expected_iterator_prefix); - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::CheckIteratorSaveAndRestore( @@ -796,7 +796,7 @@ Status DatasetOpsTestBase::CheckIteratorSaveAndRestore( } TF_EXPECT_OK(ExpectEqual(out_tensors, expected_outputs, /*compare_order=*/compare_order)); - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::CheckIteratorSaveAndRestore( @@ -822,7 +822,7 @@ Status DatasetOpsTestBase::Initialize(const DatasetParams& dataset_params) { dataset_->MakeIterator(iterator_ctx_.get(), /*parent=*/nullptr, dataset_params.iterator_prefix(), &iterator_)); initialized_ = true; - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::InitializeRuntime( @@ -830,7 +830,7 @@ Status DatasetOpsTestBase::InitializeRuntime( TF_RETURN_IF_ERROR(InitThreadPool(thread_num_)); TF_RETURN_IF_ERROR( InitFunctionLibraryRuntime(dataset_params.func_lib(), cpu_num_)); - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::MakeDataset(const DatasetParams& dataset_params, @@ -846,7 +846,7 @@ Status DatasetOpsTestBase::MakeDataset(const DatasetParams& dataset_params, *dataset = std::make_unique( std::move(dataset_kernel), std::move(dataset_ctx_params), std::move(dataset_ctx), std::move(created_tensors), dataset_base); - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::RunDatasetOp( @@ -880,7 +880,7 @@ Status DatasetOpsTestBase::RunDatasetOp( TF_RETURN_IF_ERROR(CreateDatasetContext(dataset_kernel->get(), &inputs, dataset_ctx_params, dataset_ctx)); TF_RETURN_IF_ERROR(RunOpKernel(dataset_kernel->get(), dataset_ctx->get())); - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::MakeDataset( @@ -896,7 +896,7 @@ Status DatasetOpsTestBase::MakeDataset( // Assume that DatasetOp has only one output. DCHECK_EQ((*dataset_ctx)->num_outputs(), 1); TF_RETURN_IF_ERROR(GetDatasetFromContext(dataset_ctx->get(), 0, dataset)); - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::MakeIterator( @@ -917,7 +917,7 @@ Status DatasetOpsTestBase::MakeIterator( &iterator_base)); *iterator = std::make_unique(std::move(iterator_ctx), std::move(iterator_base)); - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::MakeIterator( @@ -934,7 +934,7 @@ Status DatasetOpsTestBase::RunDatasetOp(const DatasetParams& dataset_params, for (int i = 0; i < dataset_ctx_->num_outputs(); ++i) { outputs->emplace_back(*dataset_ctx_->mutable_output(i)); } - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::MakeDatasetOpKernel( @@ -950,7 +950,7 @@ Status DatasetOpsTestBase::MakeDatasetOpKernel( test::function::NDef(dataset_params.node_name(), dataset_params.op_name(), input_names, attributes); TF_RETURN_IF_ERROR(CreateOpKernel(node_def, dataset_kernel)); - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::MakeGetOptionsOpKernel( @@ -965,7 +965,7 @@ Status DatasetOpsTestBase::MakeGetOptionsOpKernel( dataset_params.dataset_type(), input_names, attributes); TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel)); - return OkStatus(); + return absl::OkStatus(); } Status DatasetOpsTestBase::MakeDatasetTensor( @@ -1007,7 +1007,7 @@ Status DatasetOpsTestBase::MakeDatasetTensor( TF_RETURN_IF_ERROR( StoreDatasetInVariantTensor(dataset_base, &dataset_tensor)); *dataset = std::make_unique(dataset_tensor); - return OkStatus(); + return absl::OkStatus(); } DatasetParams::DatasetParams(DataTypeVector output_dtypes, @@ -1057,7 +1057,7 @@ std::vector RangeDatasetParams::GetInputTensors() const { Status RangeDatasetParams::GetInputNames( std::vector* input_names) const { *input_names = {"start", "stop", "step"}; - return OkStatus(); + return absl::OkStatus(); } Status RangeDatasetParams::GetAttributes(AttributeVector* attr_vector) const { @@ -1065,7 +1065,7 @@ Status RangeDatasetParams::GetAttributes(AttributeVector* attr_vector) const { {"output_shapes", output_shapes_}, {"replicate_on_split", false}, {"metadata", ""}}; - return OkStatus(); + return absl::OkStatus(); } string RangeDatasetParams::dataset_type() const { return "Range"; } @@ -1080,7 +1080,7 @@ std::vector BatchDatasetParams::GetInputTensors() const { Status BatchDatasetParams::GetInputNames( std::vector* input_names) const { *input_names = {"input_dataset", "batch_size", "drop_remainder"}; - return OkStatus(); + return absl::OkStatus(); } Status BatchDatasetParams::GetAttributes(AttributeVector* attr_vector) const { @@ -1088,7 +1088,7 @@ Status BatchDatasetParams::GetAttributes(AttributeVector* attr_vector) const { {"output_types", output_dtypes_}, {"output_shapes", output_shapes_}, {"metadata", ""}}; - return OkStatus(); + return absl::OkStatus(); } string BatchDatasetParams::dataset_type() const { return "Batch"; } @@ -1102,7 +1102,7 @@ Status MapDatasetParams::GetInputNames(std::vector* input_names) const { for (int i = 0; i < other_arguments_.size(); ++i) { input_names->emplace_back(absl::StrCat("other_arguments_", i)); } - return OkStatus(); + return absl::OkStatus(); } Status MapDatasetParams::GetAttributes(AttributeVector* attr_vector) const { @@ -1113,7 +1113,7 @@ Status MapDatasetParams::GetAttributes(AttributeVector* attr_vector) const { {"use_inter_op_parallelism", use_inter_op_parallelism_}, {"preserve_cardinality", preserve_cardinality_}, {"metadata", ""}}; - return OkStatus(); + return absl::OkStatus(); } string MapDatasetParams::dataset_type() const { return "Map"; } @@ -1139,7 +1139,7 @@ Status TensorSliceDatasetParams::GetInputNames( for (int i = 0; i < components_.size(); ++i) { input_names->emplace_back(absl::StrCat("components_", i)); } - return OkStatus(); + return absl::OkStatus(); } Status TensorSliceDatasetParams::GetAttributes( @@ -1149,7 +1149,7 @@ Status TensorSliceDatasetParams::GetAttributes( {"is_files", is_files_}, {"replicate_on_split", false}, {"metadata", ""}}; - return OkStatus(); + return absl::OkStatus(); } DataTypeVector TensorSliceDatasetParams::TensorSliceDtypes( @@ -1183,14 +1183,14 @@ std::vector TakeDatasetParams::GetInputTensors() const { Status TakeDatasetParams::GetInputNames( std::vector* input_names) const { *input_names = {"input_dataset", "count"}; - return OkStatus(); + return absl::OkStatus(); } Status TakeDatasetParams::GetAttributes(AttributeVector* attr_vector) const { *attr_vector = {{"output_shapes", output_shapes_}, {"output_types", output_dtypes_}, {"metadata", ""}}; - return OkStatus(); + return absl::OkStatus(); } string TakeDatasetParams::dataset_type() const { return "Take"; } @@ -1202,7 +1202,7 @@ std::vector ConcatenateDatasetParams::GetInputTensors() const { Status ConcatenateDatasetParams::GetInputNames( std::vector* input_names) const { *input_names = {"input_dataset", "another_dataset"}; - return OkStatus(); + return absl::OkStatus(); } Status ConcatenateDatasetParams::GetAttributes( @@ -1210,7 +1210,7 @@ Status ConcatenateDatasetParams::GetAttributes( *attr_vector = {{"output_types", output_dtypes_}, {"output_shapes", output_shapes_}, {"metadata", ""}}; - return OkStatus(); + return absl::OkStatus(); } string ConcatenateDatasetParams::dataset_type() const { return "Concatenate"; } @@ -1220,7 +1220,7 @@ std::vector OptionsDatasetParams::GetInputTensors() const { return {}; } Status OptionsDatasetParams::GetInputNames( std::vector* input_names) const { input_names->emplace_back("input_dataset"); - return OkStatus(); + return absl::OkStatus(); } Status OptionsDatasetParams::GetAttributes(AttributeVector* attr_vector) const { @@ -1228,7 +1228,7 @@ Status OptionsDatasetParams::GetAttributes(AttributeVector* attr_vector) const { {"output_shapes", output_shapes_}, {"output_types", output_dtypes_}, {"metadata", ""}}; - return OkStatus(); + return absl::OkStatus(); } string OptionsDatasetParams::dataset_type() const { return "Options"; } diff --git a/tensorflow/core/data/dataset_utils.cc b/tensorflow/core/data/dataset_utils.cc index e61f700b407026..fd8c3c126b571c 100644 --- a/tensorflow/core/data/dataset_utils.cc +++ b/tensorflow/core/data/dataset_utils.cc @@ -99,6 +99,7 @@ constexpr char kMakeSloppyOpt[] = "make_sloppy"; constexpr char kBatchParallelizationOpt[] = "batch_parallelization"; constexpr char kEnableGradientDescentOpt[] = "enable_gradient_descent"; constexpr char kInjectPrefetchOpt[] = "inject_prefetch"; +constexpr char kSeqInterleavePrefetchOpt[] = "seq_interleave_prefetch"; constexpr char kInjectIoPrefetchEligibleOpt[] = "inject_io_prefetch_eligible"; constexpr char kInjectIoPrefetchOpt[] = "inject_io_prefetch"; constexpr char kAutotuneOpt[] = "autotune"; @@ -224,6 +225,14 @@ void DefaultOptimizationGraphRewrites( optimization_disabled->insert(kInjectPrefetchOpt); } } + if (optimization_options.optional_seq_interleave_prefetch_case() == + OptimizationOptions::kSeqInterleavePrefetch) { + if (optimization_options.seq_interleave_prefetch()) { + optimization_enabled->insert(kSeqInterleavePrefetchOpt); + } else { + optimization_disabled->insert(kSeqInterleavePrefetchOpt); + } + } } // Returns whether an op has been allowlisted as stateless. Uses a heuristic to @@ -255,7 +264,7 @@ Status VerifyTypeMatch(const DataType& expected, const DataType& received, ": expected ", DataTypeString(expected), " but got ", DataTypeString(received), "."); } - return OkStatus(); + return absl::OkStatus(); } Status VerifyTypesMatch(const DataTypeVector& expected, @@ -268,7 +277,7 @@ Status VerifyTypesMatch(const DataTypeVector& expected, for (size_t i = 0; i < expected.size(); ++i) { TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i], i)); } - return OkStatus(); + return absl::OkStatus(); } Status VerifyTypesMatch(const DataTypeVector& expected, @@ -281,7 +290,7 @@ Status VerifyTypesMatch(const DataTypeVector& expected, for (size_t i = 0; i < expected.size(); ++i) { TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i].dtype(), i)); } - return OkStatus(); + return absl::OkStatus(); } Status VerifyShapeCompatible(const PartialTensorShape& expected, @@ -291,7 +300,7 @@ Status VerifyShapeCompatible(const PartialTensorShape& expected, ": expected ", expected.DebugString(), " but got ", received.DebugString(), "."); } - return OkStatus(); + return absl::OkStatus(); } Status VerifyShapesCompatible(const std::vector& expected, @@ -305,7 +314,7 @@ Status VerifyShapesCompatible(const std::vector& expected, TF_RETURN_IF_ERROR(VerifyShapeCompatible(expected[i], received[i], i)); } - return OkStatus(); + return absl::OkStatus(); } Status VerifyShapesCompatible(const std::vector& expected, @@ -320,7 +329,7 @@ Status VerifyShapesCompatible(const std::vector& expected, VerifyShapeCompatible(expected[i], received[i].shape(), i)); } - return OkStatus(); + return absl::OkStatus(); } Status AddToFunctionLibrary(FunctionLibraryDefinition* base, @@ -357,13 +366,13 @@ Status AddToFunctionLibrary(FunctionLibraryDefinition* base, Status IsFunctionStateful(const FunctionLibraryDefinition& library, const FunctionDef& function_def) { if (!function_def.signature().is_stateful()) { - return OkStatus(); + return absl::OkStatus(); } for (const NodeDef& node_def : function_def.node_def()) { TF_RETURN_IF_ERROR(IsNodeStateful(library, node_def)); } - return OkStatus(); + return absl::OkStatus(); } Status IsNodeStateful(const FunctionLibraryDefinition& library, @@ -375,7 +384,7 @@ Status IsNodeStateful(const FunctionLibraryDefinition& library, if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok() || IsOpAllowlisted(op_def) || !op_def->is_stateful() || op_def->name() == "Assert") { - return OkStatus(); + return absl::OkStatus(); } if (op_def->name() == "If") { @@ -389,7 +398,7 @@ Status IsNodeStateful(const FunctionLibraryDefinition& library, if (else_func != nullptr) { TF_RETURN_IF_ERROR(IsFunctionStateful(library, *else_func)); } - return OkStatus(); + return absl::OkStatus(); } if (op_def->name() == "While") { @@ -403,7 +412,7 @@ Status IsNodeStateful(const FunctionLibraryDefinition& library, if (body_func != nullptr) { TF_RETURN_IF_ERROR(IsFunctionStateful(library, *body_func)); } - return OkStatus(); + return absl::OkStatus(); } return errors::FailedPrecondition(op_def->name(), " is stateful."); @@ -440,7 +449,7 @@ Status DeterminismPolicy::FromString(const std::string& s, return errors::InvalidArgument("Unrecognized determinism policy: ", s); } *out = DeterminismPolicy(type); - return OkStatus(); + return absl::OkStatus(); } DeterminismPolicy::DeterminismPolicy(bool is_deterministic) { @@ -641,7 +650,7 @@ Status CopyPartialBatch(int64_t num_elements, const Tensor& value, return errors::InvalidArgument("Unsupported data type: ", DataTypeString(value.dtype())); } - return OkStatus(); + return absl::OkStatus(); } Status ReadBatch(IteratorContext* ctx, IteratorStateReader* reader, @@ -674,7 +683,7 @@ Status ReadBatch(IteratorContext* ctx, IteratorStateReader* reader, batch->emplace_back(std::move(t)); } } - return OkStatus(); + return absl::OkStatus(); } Status WriteBatch(int64_t batch_size, int64_t num_elements, @@ -699,7 +708,7 @@ Status WriteBatch(int64_t batch_size, int64_t num_elements, strings::StrCat(kOutput, "_", i), (*batch)[i])); } } - return OkStatus(); + return absl::OkStatus(); } Status ReadStatus(const string& iterator_prefix, const string& prefix, @@ -717,9 +726,9 @@ Status ReadStatus(const string& iterator_prefix, const string& prefix, &error_message)); *status = Status(code, error_message); } else { - *status = OkStatus(); + *status = absl::OkStatus(); } - return OkStatus(); + return absl::OkStatus(); } Status WriteStatus(const string& iterator_prefix, const string& prefix, @@ -732,7 +741,7 @@ Status WriteStatus(const string& iterator_prefix, const string& prefix, FullName(iterator_prefix, strings::StrCat(prefix, "_", kMessage)), std::string(status.message()))); } - return OkStatus(); + return absl::OkStatus(); } Status ProcessBatch(int64_t batch_size, int64_t num_elements, @@ -742,7 +751,7 @@ Status ProcessBatch(int64_t batch_size, int64_t num_elements, if (num_elements == 0) { if (status.ok() || absl::IsOutOfRange(status)) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } else { *end_of_sequence = false; return status; @@ -755,7 +764,7 @@ Status ProcessBatch(int64_t batch_size, int64_t num_elements, if (num_elements < batch_size) { if (drop_remainder) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } for (size_t i = 0; i < batch->size(); ++i) { TensorShape component_shape((*batch)[i].shape()); @@ -775,13 +784,12 @@ Status ProcessBatch(int64_t batch_size, int64_t num_elements, *output = std::move(*batch); } *end_of_sequence = false; - return OkStatus(); + return absl::OkStatus(); } Status CopyBatch(CopyBatchParams params, const std::vector>& batch_elements, bool parallel_copy, - std::function allocation_callback, std::vector* out_tensors) { const size_t num_tuple_components = batch_elements.at(0).size(); out_tensors->reserve(num_tuple_components); @@ -800,9 +808,6 @@ Status CopyBatch(CopyBatchParams params, component_index); } } - if (allocation_callback) { - TF_RETURN_IF_ERROR(allocation_callback()); - } for (size_t component_index = 0; component_index < num_tuple_components; ++component_index) { Tensor& batch_component = out_tensors->at(component_index); @@ -865,13 +870,13 @@ Status CopyBatch(CopyBatchParams params, } } } - return OkStatus(); + return absl::OkStatus(); } absl::flat_hash_set CreateGraphRewriteConfigs(const Options& options) { absl::flat_hash_set configs; const auto& autotune_options = options.autotune_options(); - std::array autotune_only_optimizations = { + std::array autotune_only_optimizations = { kAutotuneBufferSizesOpt, kBatchParallelizationOpt, kDisablePrefetchLegacyAutotuneOpt, @@ -879,6 +884,7 @@ absl::flat_hash_set CreateGraphRewriteConfigs(const Options& options) { kFilterParallelizationOpt, kMapParallelizationOpt, kMapFusionOpt, + kSeqInterleavePrefetchOpt, kInjectPrefetchOpt, kInjectIoPrefetchEligibleOpt, kInjectIoPrefetchOpt}; @@ -935,7 +941,17 @@ bool ShouldApplyOptimizations( } int64 GetAutotuneDefaultParallelism(IteratorContext* ctx) { - return std::min(kAutotuneDefaultParallelism, ctx->runner_threadpool_size()); + int64_t initial_parallelism = 16; + if (ctx->options()) { + int64_t initial_parallelism_option = + ctx->options()->autotune_options().initial_parallelism(); + if (initial_parallelism_option > 0) { + initial_parallelism = initial_parallelism_option; + } + } + int64_t runner_threadpool_size = ctx->runner_threadpool_size(); + int64_t value = std::min(initial_parallelism, runner_threadpool_size); + return value; } IteratorContext MakeNestedIteratorContext(IteratorContext* ctx) { @@ -999,7 +1015,7 @@ REGISTER_DATASET_EXPERIMENT("data_transfer", RandomJobSamplePercentage<0>, AllTasks); REGISTER_DATASET_EXPERIMENT("file_locality", RandomJobSamplePercentage<0>, AllTasks); -REGISTER_DATASET_EXPERIMENT("file_locality_v2", RandomJobSamplePercentage<0>, +REGISTER_DATASET_EXPERIMENT("file_locality_v2", RandomJobSamplePercentage<50>, AllTasks); REGISTER_DATASET_EXPERIMENT("no_compression", RandomJobSamplePercentage<0>, AllTasks); @@ -1009,10 +1025,8 @@ REGISTER_DATASET_EXPERIMENT("inject_io_prefetch", RandomJobSamplePercentage<0>, AllTasks); REGISTER_DATASET_EXPERIMENT("reduce_array_record_dataset_memory_usage", RandomJobSamplePercentage<0>, AllTasks); -REGISTER_DATASET_EXPERIMENT("map_fusion", RandomJobSamplePercentage<50>, - AllTasks); -REGISTER_DATASET_EXPERIMENT("log_filenames", RandomJobSamplePercentage<50>, - AllTasks); +REGISTER_DATASET_EXPERIMENT("map_fusion", RandomJobSamplePercentage<5>, + IndependentHostTasks); } // namespace } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/dataset_utils.h b/tensorflow/core/data/dataset_utils.h index b975fc0aad567e..3bdc036b118344 100644 --- a/tensorflow/core/data/dataset_utils.h +++ b/tensorflow/core/data/dataset_utils.h @@ -37,9 +37,6 @@ namespace data { // should be supplied by the auto-sharding rewrite. constexpr int kShardHint = -1; -// The initial parallelism value before Autotune has a chance to optimize. -constexpr int kAutotuneDefaultParallelism = 16; - // Creates a resource handle with a unique name for the given resource where // the resource is managed by the Resource Manager. template @@ -53,7 +50,7 @@ Status CreateWeakHandle(OpKernelContext* ctx, T* resource, *handle = MakeResourceHandle(container_name, unique_name, *ctx->device(), TypeIndex::Make()); - return OkStatus(); + return absl::OkStatus(); } // Creates a ref-counting resource handle for the given resource, where the @@ -65,7 +62,7 @@ Status CreateHandle(OpKernelContext* ctx, T* resource, ResourceHandle* handle) { ResourceHandle::MakeRefCountingHandle(resource, ctx->device()->name()); TF_RETURN_IF_ERROR( mgr->CreateUnowned(handle->container(), handle->name(), resource)); - return OkStatus(); + return absl::OkStatus(); } // TODO(b/198162355): Merge this class with ResourceOpKernel. @@ -305,14 +302,12 @@ struct CopyBatchParams { // // The `batch_elements` argument contains the individual elements to copy into a // batch. The `parallel_copy` argument indicates whether to parallelize the -// copy. The `allocation_callback` argument can be used to pass a callback to -// invoke upon successful allocation of the memory for the batch. The -// `out_tensors` argument will be used to store the resulting batch (one for +// copy. +// The `out_tensors` argument will be used to store the resulting batch (one for // each component of the input). Status CopyBatch(CopyBatchParams params, const std::vector>& batch_elements, bool parallel_copy, - std::function allocation_callback, std::vector* out_tensors); // Computes the set of experiments to apply based on the job name, task id, diff --git a/tensorflow/core/data/dataset_utils_test.cc b/tensorflow/core/data/dataset_utils_test.cc index 4ade9741a8857b..853f5e6c5c0bfa 100644 --- a/tensorflow/core/data/dataset_utils_test.cc +++ b/tensorflow/core/data/dataset_utils_test.cc @@ -200,19 +200,19 @@ TEST(DatasetUtilsTest, BoolConstructor) { class TestSplitProvider : public SplitProvider { public: Status GetNext(Tensor* split, bool* end_of_splits) override { - return OkStatus(); + return absl::OkStatus(); } - Status Reset() override { return OkStatus(); } + Status Reset() override { return absl::OkStatus(); } Status Save(std::function key_name_fn, IteratorStateWriter* writer) override { - return OkStatus(); + return absl::OkStatus(); } Status Restore(std::function key_name_fn, IteratorStateReader* reader) override { - return OkStatus(); + return absl::OkStatus(); } }; @@ -672,13 +672,15 @@ GetOptimizationsTestCase GetOptimizationTestCase4() { options.mutable_optimization_options()->set_parallel_batch(true); options.mutable_optimization_options()->set_shuffle_and_repeat_fusion(true); options.mutable_optimization_options()->set_inject_prefetch(true); + options.mutable_optimization_options()->set_seq_interleave_prefetch(true); options.set_slack(true); return {options, /*expected_enabled=*/ {"filter_fusion", "filter_parallelization", "make_sloppy", "map_and_batch_fusion", "map_and_filter_fusion", "map_fusion", "map_parallelization", "noop_elimination", "parallel_batch", - "shuffle_and_repeat_fusion", "slack", "inject_prefetch"}, + "shuffle_and_repeat_fusion", "slack", "inject_prefetch", + "seq_interleave_prefetch"}, /*expected_disabled=*/{}, /*expected_default=*/{}}; } diff --git a/tensorflow/core/data/finalization_utils.cc b/tensorflow/core/data/finalization_utils.cc index 64296bf6e91da4..ed4925d86bbc06 100644 --- a/tensorflow/core/data/finalization_utils.cc +++ b/tensorflow/core/data/finalization_utils.cc @@ -22,10 +22,10 @@ limitations under the License. namespace tensorflow { namespace data { -StatusOr GetFinalizedDataset(OpKernelContext* ctx, - const DatasetBase* dataset) { +absl::StatusOr GetFinalizedDataset(OpKernelContext* ctx, + const DatasetBase* dataset) { return dataset->Finalize( - ctx, [ctx, dataset]() -> StatusOr> { + ctx, [ctx, dataset]() -> absl::StatusOr> { core::RefCountPtr dataset_ref_ptr; DatasetBase* raw_ptr; TF_RETURN_IF_ERROR(data::FinalizeDataset(ctx, dataset, &raw_ptr)); diff --git a/tensorflow/core/data/finalization_utils.h b/tensorflow/core/data/finalization_utils.h index c0ed2afb1b053e..95548d7199fd6d 100644 --- a/tensorflow/core/data/finalization_utils.h +++ b/tensorflow/core/data/finalization_utils.h @@ -26,8 +26,8 @@ namespace data { // Returns the finalized version of the dataset. The returned DatasetBase is // unowned and lives for as long as this dataset. -StatusOr GetFinalizedDataset(OpKernelContext* ctx, - const DatasetBase* dataset); +absl::StatusOr GetFinalizedDataset(OpKernelContext* ctx, + const DatasetBase* dataset); } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/hash_utils.cc b/tensorflow/core/data/hash_utils.cc index 1928f14a488f9b..8dbab84a533656 100644 --- a/tensorflow/core/data/hash_utils.cc +++ b/tensorflow/core/data/hash_utils.cc @@ -85,7 +85,7 @@ Status GetSink(const GraphDef& graph_def, const NodeDef** sink) { if (sink == nullptr) { return errors::Internal("Cannot find sink node for dataset graph."); } - return OkStatus(); + return absl::OkStatus(); } Status ShouldIgnoreInput(const NodeDef& node, int i, bool* result) { @@ -103,7 +103,7 @@ Status ShouldIgnoreInput(const NodeDef& node, int i, bool* result) { VLOG(2) << "Ignoring arg: " << input_arg_name << " from node: " << node.name(); *result = true; - return OkStatus(); + return absl::OkStatus(); } } } else if (errors::IsNotFound(status)) { @@ -114,7 +114,7 @@ Status ShouldIgnoreInput(const NodeDef& node, int i, bool* result) { return status; } } - return OkStatus(); + return absl::OkStatus(); } Status ParseInputNodeName(absl::string_view input_name, @@ -123,14 +123,14 @@ Status ParseInputNodeName(absl::string_view input_name, if (input_name[0] == '^') { *node_name = input_name.substr(1); *is_control_input = true; - return OkStatus(); + return absl::OkStatus(); } std::pair node_spec = absl::StrSplit(input_name, absl::MaxSplits(':', 1)); *node_name = node_spec.first; *suffix = node_spec.second; *is_control_input = false; - return OkStatus(); + return absl::OkStatus(); } // Given a graph_def and a root_node, this class computes a fingerprint that @@ -237,7 +237,7 @@ class GraphHasher { } nodes_[node] = node_rep; } - return OkStatus(); + return absl::OkStatus(); } Status HashRoot(uint64* hash) { return HashNode(root_, hash); } @@ -251,7 +251,7 @@ class GraphHasher { auto it = node_cache_->find(node); if (it != node_cache_->end()) { *hash = it->second; - return OkStatus(); + return absl::OkStatus(); } NodeRep* node_rep = gtl::FindOrNull(nodes_, node); @@ -291,7 +291,7 @@ class GraphHasher { return errors::Internal(absl::StrCat("Computed the hash for node ", node->DebugString(), " twice!")); } - return OkStatus(); + return absl::OkStatus(); } Status CheckNodesEqual(const NodeDef* this_node, GraphHasher* that, @@ -339,7 +339,7 @@ class GraphHasher { that_input_suffix); } } - return OkStatus(); + return absl::OkStatus(); } Status HashNodeNonInput(const NodeDef* node, bool hash_functions, @@ -347,7 +347,7 @@ class GraphHasher { auto iter = attr_cache_->find(std::make_pair(node, hash_functions)); if (iter != attr_cache_->end()) { *hash = iter->second; - return OkStatus(); + return absl::OkStatus(); } // Hash Attrs. We get the list of attrs from the op registry and then look // up their values in the NodeDef attr map. This avoids looping over @@ -396,7 +396,7 @@ class GraphHasher { "Computed the hash for non-input node: ", node->DebugString(), " and hash function bool: ", hash_functions, "twice!")); } - return OkStatus(); + return absl::OkStatus(); } Status CheckNodesEqualNonInput(const NodeDef* this_node, GraphHasher* that, @@ -451,7 +451,7 @@ class GraphHasher { that_node->name(), ": ", this_node->device(), " vs ", that_node->device()); } - return OkStatus(); + return absl::OkStatus(); } Status HashAttr(const std::string& attr_name, const AttrValue& attr_value, @@ -473,7 +473,7 @@ class GraphHasher { value_hash = DeterministicProtoHash64(attr_value); } *hash = Hash64Combine(Hash64(attr_name), value_hash); - return OkStatus(); + return absl::OkStatus(); } Status CheckAttrsEqual(const std::string& attr_name, @@ -489,7 +489,7 @@ class GraphHasher { TF_RETURN_IF_ERROR( CheckFunctionsEqual(this_attr.func(), that, that_attr.func())); } - return OkStatus(); + return absl::OkStatus(); } if (this_attr.has_list() != that_attr.has_list()) { return errors::FailedPrecondition( @@ -508,7 +508,7 @@ class GraphHasher { that_attr.list().func(i))); } } - return OkStatus(); + return absl::OkStatus(); } uint64 this_hash, that_hash; TF_RETURN_IF_ERROR( @@ -520,7 +520,7 @@ class GraphHasher { "AttrValues are different: ", this_attr.DebugString(), " vs ", that_attr.DebugString()); } - return OkStatus(); + return absl::OkStatus(); } Status HashFunction(const NameAttrList& func, uint64* hash) { @@ -533,7 +533,7 @@ class GraphHasher { auto it = function_cache_->find(fdef); if (it != function_cache_->end()) { *hash = it->second; - return OkStatus(); + return absl::OkStatus(); } // Convert to a GraphDef. @@ -569,7 +569,7 @@ class GraphHasher { return errors::Internal( absl::StrCat("Computed the hash for function ", name, " twice!")); } - return OkStatus(); + return absl::OkStatus(); } Status CheckFunctionsEqual(const NameAttrList& this_func, GraphHasher* that, @@ -638,7 +638,7 @@ class GraphHasher { } TF_RETURN_IF_ERROR( CheckControlInputsEqual(this_control_rets, that, that_control_rets)); - return OkStatus(); + return absl::OkStatus(); } Status HashControlInputs(const std::vector& inputs, @@ -650,7 +650,7 @@ class GraphHasher { HashNodeNonInput(input, /*hash_functions=*/false, &node_hash)); *hash = Hash64CombineUnordered(*hash, node_hash); } - return OkStatus(); + return absl::OkStatus(); } Status CheckControlInputsEqual( @@ -686,7 +686,7 @@ class GraphHasher { "], which don't match any of the other node's dependencies [", absl::StrJoin(that_hashes, ", ", formatter), "]"); } - return OkStatus(); + return absl::OkStatus(); } private: @@ -748,7 +748,7 @@ Status HashTensor(const Tensor& tensor, uint64* hash) { default: *hash = Hash64(tensor.tensor_data().data(), tensor.tensor_data().size()); } - return OkStatus(); + return absl::OkStatus(); } Status HashNode(const GraphDef& graph, const NodeDef& node, uint64* hash) { diff --git a/tensorflow/core/data/rewrite_utils.cc b/tensorflow/core/data/rewrite_utils.cc index 198871d9632db8..5a9b85279c6098 100644 --- a/tensorflow/core/data/rewrite_utils.cc +++ b/tensorflow/core/data/rewrite_utils.cc @@ -131,7 +131,7 @@ Status ApplyRewrites(OpKernelContext* ctx, RemoveFakeSinks(&function_def); } - return OkStatus(); + return absl::OkStatus(); } } // anonymous namespace @@ -245,7 +245,7 @@ Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input, }); } - return OkStatus(); + return absl::OkStatus(); } std::unique_ptr GetGrapplerItem( diff --git a/tensorflow/core/data/root_dataset.cc b/tensorflow/core/data/root_dataset.cc index 29b32918fc9722..006174e7af28e8 100644 --- a/tensorflow/core/data/root_dataset.cc +++ b/tensorflow/core/data/root_dataset.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/data/rewrite_utils.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/dataset_options.pb.h" +#include "tensorflow/core/framework/metrics.h" #include "tensorflow/core/framework/model.h" #include "tensorflow/core/framework/model.pb.h" #include "tensorflow/core/platform/errors.h" @@ -151,16 +152,22 @@ Status RootDataset::FromOptions(const DatasetBase* input, SetRootDatasetParams(input->options(), ¶ms); *output = new RootDataset(input, params); (*output)->Initialize(/*metadata=*/{}); - return OkStatus(); + for (const auto& framework : input->options().framework_type()) { + metrics::RecordTFDataFrameworkType(framework); + } + return absl::OkStatus(); } Status RootDataset::FromOptions(core::RefCountPtr input, DatasetBase** output) { Params params; + for (const auto& framework : input->options().framework_type()) { + metrics::RecordTFDataFrameworkType(framework); + } SetRootDatasetParams(input->options(), ¶ms); *output = new RootDataset(std::move(input), params); (*output)->Initialize(/*metadata=*/{}); - return OkStatus(); + return absl::OkStatus(); } class RootDataset::Iterator : public DatasetIterator { @@ -195,8 +202,12 @@ class RootDataset::Iterator : public DatasetIterator { dataset()->params_.ComputeInitialAutotuneRamBudget()); if (dataset()->params_.autotune) { - model_ = ctx->model() != nullptr ? ctx->model() - : std::make_shared(); + if (ctx->model() != nullptr) { + model_ = ctx->model(); + } else { + model_ = std::make_shared(); + ctx->SetModel(model_); + } absl::flat_hash_set experiments = GetExperiments(); if (experiments.contains("stage_based_autotune_v2")) { @@ -216,7 +227,7 @@ class RootDataset::Iterator : public DatasetIterator { TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(&iter_ctx, this, prefix(), &input_impl_)); ctx->MergeCheckpoint(iter_ctx.checkpoint()); - return OkStatus(); + return absl::OkStatus(); } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, @@ -238,7 +249,7 @@ class RootDataset::Iterator : public DatasetIterator { mutex_lock l(mu_); end_time_usec_ = std::max(ctx->env()->NowMicros(), end_time_usec_); } - return OkStatus(); + return absl::OkStatus(); } protected: @@ -250,7 +261,7 @@ class RootDataset::Iterator : public DatasetIterator { Status SaveInternal(SerializationContext* ctx, IteratorStateWriter* writer) override { TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -258,7 +269,7 @@ class RootDataset::Iterator : public DatasetIterator { IteratorContext iter_ctx(CreateParams(ctx)); TF_RETURN_IF_ERROR(RestoreInput(&iter_ctx, reader, input_impl_)); ctx->MergeCheckpoint(iter_ctx.checkpoint()); - return OkStatus(); + return absl::OkStatus(); } TraceMeMetadata GetTraceMeMetadata() const override { @@ -325,6 +336,7 @@ class RootDataset::Iterator : public DatasetIterator { params.runner = RunnerWithMaxParallelism(params.runner, max_intra_op_parallelism_); } + params.options = &dataset()->options(); return params; } @@ -351,7 +363,7 @@ class RootDataset::Iterator : public DatasetIterator { } }); } - return OkStatus(); + return absl::OkStatus(); } std::shared_ptr model_ = nullptr; @@ -389,6 +401,10 @@ RootDataset::RootDataset(core::RefCountPtr input, params_(std::move(params)) { owned_input_ = std::move(input); input_ = owned_input_.get(); + random_indexing_compatible_ = absl::OkStatus(); + if (input_ != nullptr) { + random_indexing_compatible_ = input_->RandomIndexingCompatible(); + } AddTraceMetadata(params_, input_->options(), &traceme_metadata_); } @@ -426,7 +442,7 @@ Status RootDataset::Get(OpKernelContext* ctx, int64 index, Status RootDataset::InputDatasets( std::vector* inputs) const { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status RootDataset::CheckExternalState() const { @@ -482,7 +498,7 @@ Status FinalizeDataset(OpKernelContext* ctx, const DatasetBase* input, } else { return RootDataset::FromOptions(std::move(rewritten_output), output); } - return OkStatus(); + return absl::OkStatus(); } #else // !IS_MOBILE_PLATFORM diff --git a/tensorflow/core/data/root_dataset.h b/tensorflow/core/data/root_dataset.h index f679b25541b82a..870741ed9354b2 100644 --- a/tensorflow/core/data/root_dataset.h +++ b/tensorflow/core/data/root_dataset.h @@ -67,6 +67,9 @@ class RootDataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override; std::unique_ptr MakeIteratorInternal( const string& prefix) const override; + Status RandomIndexingCompatible() const override { + return random_indexing_compatible_; + } protected: Status AsGraphDefInternal(SerializationContext* ctx, @@ -84,6 +87,7 @@ class RootDataset : public DatasetBase { core::RefCountPtr owned_input_; const Params params_; TraceMeMetadata traceme_metadata_; + Status random_indexing_compatible_; }; // Finalizes the `input` dataset, which is expected to be called before the diff --git a/tensorflow/core/data/serialization_utils.cc b/tensorflow/core/data/serialization_utils.cc index e948426c02ed0c..204cebfab5ab15 100644 --- a/tensorflow/core/data/serialization_utils.cc +++ b/tensorflow/core/data/serialization_utils.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/dataset.pb.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/variant.h" #include "tensorflow/core/framework/variant_op_registry.h" #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/graph/graph_def_builder.h" @@ -61,7 +62,7 @@ Status FromGraphDef(FunctionLibraryRuntime* flr, const GraphDef& graph_def, TF_RETURN_IF_ERROR(graph_runner.Run(&graph, cloned_flr, input_list, {output_node}, &outputs)); *result = outputs[0]; - return OkStatus(); + return absl::OkStatus(); } // FindStatefulOps searches `graph_def` for all of its stateful ops storing @@ -89,7 +90,7 @@ Status FindStatefulOps(const GraphDef& graph_def, } } } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -118,7 +119,7 @@ Status ReadElementsFromCheckpoint(IteratorContext* ctx, &element.back())); } } - return OkStatus(); + return absl::OkStatus(); } Status WriteElement(IteratorStateWriter* writer, StringPiece key_prefix, @@ -132,7 +133,7 @@ Status WriteElement(IteratorStateWriter* writer, StringPiece key_prefix, TF_RETURN_IF_ERROR(writer->WriteTensor( element_prefix, absl::StrCat(kComponent, "[", j, "]"), element[j])); } - return OkStatus(); + return absl::OkStatus(); } Status WriteElementsToCheckpoint( @@ -143,7 +144,7 @@ Status WriteElementsToCheckpoint( for (int i = 0; i < elements.size(); ++i) { TF_RETURN_IF_ERROR(WriteElement(writer, key_prefix, elements, i)); } - return OkStatus(); + return absl::OkStatus(); } Status UpdateCheckpointElements( @@ -155,7 +156,7 @@ Status UpdateCheckpointElements( for (int64_t i : checkpoint_indices) { TF_RETURN_IF_ERROR(WriteElement(writer, key_prefix, elements, i)); } - return OkStatus(); + return absl::OkStatus(); } VariantTensorDataReader::VariantTensorDataReader( @@ -254,7 +255,7 @@ Status VariantTensorDataReader::ReadScalarInternal(StringPiece n, return errors::NotFound(key); } *val = data_.at(name)->tensors(key_it->second).scalar()(); - return OkStatus(); + return absl::OkStatus(); } Status VariantTensorDataReader::ReadTensorInternal(FunctionLibraryRuntime* flr, @@ -275,7 +276,7 @@ Status VariantTensorDataReader::ReadTensorInternal(FunctionLibraryRuntime* flr, return errors::NotFound(key); } *val = data_.at(name)->tensors(key_it->second); - return OkStatus(); + return absl::OkStatus(); } Status VariantTensorDataReader::ReadDatasetInternal(FunctionLibraryRuntime* flr, @@ -294,7 +295,7 @@ Status VariantTensorDataReader::ReadDatasetInternal(FunctionLibraryRuntime* flr, GraphDef graph_def; graph_def.ParseFromString(serialized_graph_def); TF_RETURN_IF_ERROR(FromGraphDef(flr, graph_def, {}, output_node, val)); - return OkStatus(); + return absl::OkStatus(); } std::map VariantTensorDataReader::ReadAllTensors() { @@ -418,7 +419,7 @@ Status VariantTensorDataWriter::WriteTensorInternal(StringPiece n, data_[name]->set_type_name("tensorflow::Iterator"); } *(data_[name]->add_tensors()) = val; - return OkStatus(); + return absl::OkStatus(); } Status VariantTensorDataWriter::WriteDatasetInternal( @@ -439,7 +440,7 @@ Status VariantTensorDataWriter::WriteDatasetInternal( TF_RETURN_IF_ERROR( WriteScalar(n, strings::StrCat(key, kOutputNode), output_node)); TF_RETURN_IF_ERROR(WriteScalar(n, key, result)); - return OkStatus(); + return absl::OkStatus(); } std::string IteratorStateVariant::TypeName() { @@ -455,7 +456,7 @@ IteratorStateVariant::IteratorStateVariant(const IteratorStateVariant& other) { Status IteratorStateVariant::InitializeFromVariantData( std::unique_ptr data) { data_ = std::move(data); - return OkStatus(); + return absl::OkStatus(); } void IteratorStateVariant::Encode(VariantTensorData* data) const { @@ -548,7 +549,7 @@ Status AsGraphDefForRewrite(OpKernelContext* ctx, const DatasetBase* input, *dataset_node = node.input(0); } } - return OkStatus(); + return absl::OkStatus(); } Status AsGraphDef(const DatasetBase* dataset, @@ -583,7 +584,7 @@ Status AsGraphDef(const DatasetBase* dataset, .WithAttr("T", DT_VARIANT) .WithAttr("index", 0)); TF_RETURN_IF_ERROR(b.ToGraphDef(graph_def)); - return OkStatus(); + return absl::OkStatus(); } tsl::StatusOr> CheckpointStats( @@ -599,19 +600,15 @@ tsl::StatusOr> CheckpointStats( "Failed to parse checkpoint tensor from proto."); } - int64_t num_tensors = t.dim_size(0); - auto serialized_vec = t.vec(); - std::vector data; - data.reserve(num_tensors); - for (int i = 0; i < num_tensors; ++i) { - auto* w = serialized_vec(i).get(); - if (!w) { - return absl::InvalidArgumentError( - "Failed to access IteratorStateVariant inside checkpoint tensor"); - } - data.push_back(w->GetData()); + auto variant = t.scalar()(); + auto* w = variant.get(); + if (!w) { + return absl::InvalidArgumentError( + "Failed to access IteratorStateVariant inside checkpoint tensor"); } - auto reader = std::make_unique(data); + const VariantTensorData* data = w->GetData(); + auto reader = std::make_unique( + std::vector{data}); absl::flat_hash_map stats; for (const auto& [key, tensor] : reader->ReadAllTensors()) { stats[key] = tensor.TotalBytes(); diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD index cb7a59de37908d..53a5b42936e70d 100644 --- a/tensorflow/core/data/service/BUILD +++ b/tensorflow/core/data/service/BUILD @@ -35,10 +35,6 @@ tf_proto_library( create_java_proto = False, create_kotlin_proto = False, protodeps = tf_additional_all_protos(), - visibility = [ - ":data_transfer_visibility", - "//tensorflow:internal", - ], ) tf_proto_library( @@ -413,6 +409,7 @@ tf_cc_test( "//tensorflow/core/platform:status_matchers", "//tensorflow/core/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/protobuf:protos_all_cc", diff --git a/tensorflow/core/data/service/auto_scaler.cc b/tensorflow/core/data/service/auto_scaler.cc index 9afffdbac301dd..c340f3dd599be8 100644 --- a/tensorflow/core/data/service/auto_scaler.cc +++ b/tensorflow/core/data/service/auto_scaler.cc @@ -130,8 +130,8 @@ std::optional AutoScaler::GetOptimalNumberOfWorkers() const return std::max(int64_t{1}, optimal_number_of_workers); } -tsl::Status AutoScaler::ReportProcessingTime(const std::string& worker_address, - absl::Duration processing_time) +absl::Status AutoScaler::ReportProcessingTime(const std::string& worker_address, + absl::Duration processing_time) TF_LOCKS_EXCLUDED(mu_) { if (processing_time <= absl::ZeroDuration()) { return absl::InvalidArgumentError(absl::StrCat( @@ -146,7 +146,7 @@ tsl::Status AutoScaler::ReportProcessingTime(const std::string& worker_address, return tsl::OkStatus(); } -tsl::Status AutoScaler::ReportTargetProcessingTime( +absl::Status AutoScaler::ReportTargetProcessingTime( int64_t consumer_id, absl::Duration target_processing_time) TF_LOCKS_EXCLUDED(mu_) { if (target_processing_time <= absl::ZeroDuration()) { @@ -163,7 +163,7 @@ tsl::Status AutoScaler::ReportTargetProcessingTime( return tsl::OkStatus(); } -tsl::Status AutoScaler::RemoveWorker(const std::string& worker_address) +absl::Status AutoScaler::RemoveWorker(const std::string& worker_address) TF_LOCKS_EXCLUDED(mu_) { tsl::mutex_lock l(mu_); if (!worker_throughputs_.contains(worker_address)) @@ -175,7 +175,7 @@ tsl::Status AutoScaler::RemoveWorker(const std::string& worker_address) return tsl::OkStatus(); } -tsl::Status AutoScaler::RemoveConsumer(int64_t consumer_id) +absl::Status AutoScaler::RemoveConsumer(int64_t consumer_id) TF_LOCKS_EXCLUDED(mu_) { tsl::mutex_lock l(mu_); if (!consumption_rates_.contains(consumer_id)) @@ -194,7 +194,7 @@ void MultipleIterationsAutoScaler::EnsureIterationIsRegistered( } } -tsl::Status MultipleIterationsAutoScaler::UnregisterIteration( +absl::Status MultipleIterationsAutoScaler::UnregisterIteration( int64_t iteration_id) TF_LOCKS_EXCLUDED(mu_) { tsl::mutex_lock l(mu_); if (!auto_scalers_.contains(iteration_id)) @@ -204,7 +204,7 @@ tsl::Status MultipleIterationsAutoScaler::UnregisterIteration( return tsl::OkStatus(); } -tsl::Status MultipleIterationsAutoScaler::UpdateOptimalNumberOfWorkersMetric( +absl::Status MultipleIterationsAutoScaler::UpdateOptimalNumberOfWorkersMetric( int64_t current_number_of_workers) TF_LOCKS_EXCLUDED(mu_) { if (current_number_of_workers <= 0) return absl::InvalidArgumentError( @@ -263,29 +263,29 @@ std::optional MultipleIterationsAutoScaler::GetOptimalNumberOfWorkers() return optimal_number_of_workers; } -tsl::Status MultipleIterationsAutoScaler::ReportProcessingTime( +absl::Status MultipleIterationsAutoScaler::ReportProcessingTime( int64_t iteration_id, const std::string& worker_address, absl::Duration processing_time) TF_LOCKS_EXCLUDED(mu_) { tsl::mutex_lock l(mu_); EnsureIterationIsRegistered(iteration_id); - tsl::Status status = auto_scalers_[iteration_id]->ReportProcessingTime( + absl::Status status = auto_scalers_[iteration_id]->ReportProcessingTime( worker_address, processing_time); return status; } -tsl::Status MultipleIterationsAutoScaler::ReportTargetProcessingTime( +absl::Status MultipleIterationsAutoScaler::ReportTargetProcessingTime( int64_t iteration_id, int64_t consumer_id, absl::Duration target_processing_time) TF_LOCKS_EXCLUDED(mu_) { tsl::mutex_lock l(mu_); EnsureIterationIsRegistered(iteration_id); - tsl::Status status = auto_scalers_[iteration_id]->ReportTargetProcessingTime( + absl::Status status = auto_scalers_[iteration_id]->ReportTargetProcessingTime( consumer_id, target_processing_time); return status; } -tsl::Status MultipleIterationsAutoScaler::RemoveWorker( +absl::Status MultipleIterationsAutoScaler::RemoveWorker( int64_t iteration_id, const std::string& worker_address) TF_LOCKS_EXCLUDED(mu_) { tsl::tf_shared_lock l(mu_); @@ -293,20 +293,21 @@ tsl::Status MultipleIterationsAutoScaler::RemoveWorker( return absl::NotFoundError(absl::StrCat( "There are no reported times for iteration_id ", iteration_id)); - tsl::Status status = + absl::Status status = auto_scalers_[iteration_id]->RemoveWorker(worker_address); return status; } -tsl::Status MultipleIterationsAutoScaler::RemoveConsumer(int64_t iteration_id, - int64_t consumer_id) +absl::Status MultipleIterationsAutoScaler::RemoveConsumer(int64_t iteration_id, + int64_t consumer_id) TF_LOCKS_EXCLUDED(mu_) { tsl::tf_shared_lock l(mu_); if (!auto_scalers_.contains(iteration_id)) return absl::NotFoundError(absl::StrCat( "There are no reported times for iteration_id ", iteration_id)); - tsl::Status status = auto_scalers_[iteration_id]->RemoveConsumer(consumer_id); + absl::Status status = + auto_scalers_[iteration_id]->RemoveConsumer(consumer_id); return status; } diff --git a/tensorflow/core/data/service/auto_scaler.h b/tensorflow/core/data/service/auto_scaler.h index 0c41700e724633..860ccf77368347 100644 --- a/tensorflow/core/data/service/auto_scaler.h +++ b/tensorflow/core/data/service/auto_scaler.h @@ -76,24 +76,24 @@ class AutoScaler { // Reports the latest observed processing time from the worker with // `worker_address`. Returns an error if `processing_time` is ZeroDuration or // negative. - tsl::Status ReportProcessingTime(const std::string& worker_address, - absl::Duration processing_time) + absl::Status ReportProcessingTime(const std::string& worker_address, + absl::Duration processing_time) TF_LOCKS_EXCLUDED(mu_); // Reports the latest observed target processing time from the consumer // identified by `consumer_id`. Returns an error if `target_processing_time` // is ZeroDuration or negative. - tsl::Status ReportTargetProcessingTime(int64_t consumer_id, - absl::Duration target_processing_time) + absl::Status ReportTargetProcessingTime(int64_t consumer_id, + absl::Duration target_processing_time) TF_LOCKS_EXCLUDED(mu_); // Unregisters the worker with `worker_address`, removing its reported // processing time from consideration of the current workload estimation. // Returns an error if the specified worker does not exist. - tsl::Status RemoveWorker(const std::string& worker_address) + absl::Status RemoveWorker(const std::string& worker_address) TF_LOCKS_EXCLUDED(mu_); // Unregisters the consumer identified by `consumer_id`, removing its reported // target processing time from consideration of the current workload // estimation. Returns an error if the specified consumer does not exist. - tsl::Status RemoveConsumer(int64_t consumer_id) TF_LOCKS_EXCLUDED(mu_); + absl::Status RemoveConsumer(int64_t consumer_id) TF_LOCKS_EXCLUDED(mu_); private: mutable tsl::mutex mu_; @@ -118,13 +118,13 @@ class MultipleIterationsAutoScaler { // Unregisters iteration with `iteration_id`, removing its reported // times from consideration of the current workload estimation. // Returns an error if the specified iteration does not exist. - tsl::Status UnregisterIteration(int64_t iteration_id) TF_LOCKS_EXCLUDED(mu_); + absl::Status UnregisterIteration(int64_t iteration_id) TF_LOCKS_EXCLUDED(mu_); // Updates the metric value with the current estimated optimal number of // workers. The estimate is limited to min(4 * `current_number_of_workers`, // `current_number_of_workers` + 500). Returns an error if there are no // previously reported processing and target processing times for at least one // iteration, or `current_number_of_workers` is not positive. - tsl::Status UpdateOptimalNumberOfWorkersMetric( + absl::Status UpdateOptimalNumberOfWorkersMetric( int64_t current_number_of_workers) TF_LOCKS_EXCLUDED(mu_); // Returns the estimated optimal number of workers according to the current // observed workload. If there are no previously reported processing and @@ -134,31 +134,31 @@ class MultipleIterationsAutoScaler { // Reports the latest observed processing time from the worker with // `worker_address` for iteration with `iteration_id`. Returns an error if // `processing_time` is ZeroDuration or negative. - tsl::Status ReportProcessingTime(int64_t iteration_id, - const std::string& worker_address, - absl::Duration processing_time) + absl::Status ReportProcessingTime(int64_t iteration_id, + const std::string& worker_address, + absl::Duration processing_time) TF_LOCKS_EXCLUDED(mu_); // Reports the latest observed target processing time from the consumer // identified by `consumer_id` for iteration with `iteration_id`. Returns an // error if `target_processing_time` is ZeroDuration or negative. - tsl::Status ReportTargetProcessingTime(int64_t iteration_id, - int64_t consumer_id, - absl::Duration target_processing_time) + absl::Status ReportTargetProcessingTime(int64_t iteration_id, + int64_t consumer_id, + absl::Duration target_processing_time) TF_LOCKS_EXCLUDED(mu_); // Unregisters the worker with `worker_address` for iteration with // `iteration_id`, removing its reported processing time from consideration of // the current workload estimation. Returns an error if there are no // previously reported processing times for iteration with `iteration_id` and // the specified worker. - tsl::Status RemoveWorker(int64_t iteration_id, - const std::string& worker_address) + absl::Status RemoveWorker(int64_t iteration_id, + const std::string& worker_address) TF_LOCKS_EXCLUDED(mu_); // Unregisters the consumer identified by `consumer_id` for iteration with // `iteration_id`, removing its reported target processing time from // consideration of the current workload estimation. Returns an error if there // are no previously reported processing times for iteration with // `iteration_id` and the specified consumer. - tsl::Status RemoveConsumer(int64_t iteration_id, int64_t consumer_id) + absl::Status RemoveConsumer(int64_t iteration_id, int64_t consumer_id) TF_LOCKS_EXCLUDED(mu_); private: diff --git a/tensorflow/core/data/service/auto_scaler_test.cc b/tensorflow/core/data/service/auto_scaler_test.cc index 6007b7f758c33b..c04ea49d216bf6 100644 --- a/tensorflow/core/data/service/auto_scaler_test.cc +++ b/tensorflow/core/data/service/auto_scaler_test.cc @@ -202,14 +202,14 @@ TEST(AutoScalerTest, ReportProcessingTimeNewAndExisting) { TEST(AutoScalerTest, ReportProcessingTimeZeroDuration) { AutoScaler auto_scaler; - tsl::Status result = auto_scaler.ReportProcessingTime("/worker/task/0:20000", - absl::ZeroDuration()); + absl::Status result = auto_scaler.ReportProcessingTime("/worker/task/0:20000", + absl::ZeroDuration()); EXPECT_THAT(result, StatusIs(absl::StatusCode::kInvalidArgument)); } TEST(AutoScalerTest, ReportProcessingTimeNegativeDuration) { AutoScaler auto_scaler; - tsl::Status result = auto_scaler.ReportProcessingTime( + absl::Status result = auto_scaler.ReportProcessingTime( "/worker/task/0:20000", absl::Microseconds(-10)); EXPECT_THAT(result, StatusIs(absl::StatusCode::kInvalidArgument)); } @@ -246,14 +246,14 @@ TEST(AutoScalerTest, ReportTargetProcessingTimeNewAndExisting) { TEST(AutoScalerTest, ReportTargetProcessingTimeZeroDuration) { AutoScaler auto_scaler; - tsl::Status result = + absl::Status result = auto_scaler.ReportTargetProcessingTime(0, absl::ZeroDuration()); EXPECT_THAT(result, StatusIs(absl::StatusCode::kInvalidArgument)); } TEST(AutoScalerTest, ReportTargetProcessingTimeNegativeDuration) { AutoScaler auto_scaler; - tsl::Status result = + absl::Status result = auto_scaler.ReportTargetProcessingTime(0, absl::Microseconds(-10)); EXPECT_THAT(result, StatusIs(absl::StatusCode::kInvalidArgument)); } @@ -324,7 +324,7 @@ TEST(MultipleIterationsAutoScalerTest, UnregisterNonexistentIteration) { TEST(MultipleIterationsAutoScalerTest, UpdateOptimalNumberOfWorkersMetricInvalidCurrentWorkers) { MultipleIterationsAutoScaler auto_scaler; - tsl::Status status = auto_scaler.UpdateOptimalNumberOfWorkersMetric(0); + absl::Status status = auto_scaler.UpdateOptimalNumberOfWorkersMetric(0); EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument)); status = auto_scaler.UpdateOptimalNumberOfWorkersMetric(-1); EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument)); @@ -333,7 +333,7 @@ TEST(MultipleIterationsAutoScalerTest, TEST(MultipleIterationsAutoScalerTest, UpdateOptimalNumberOfWorkersMetricNoReportedTimes) { MultipleIterationsAutoScaler auto_scaler; - tsl::Status status = auto_scaler.UpdateOptimalNumberOfWorkersMetric(1); + absl::Status status = auto_scaler.UpdateOptimalNumberOfWorkersMetric(1); EXPECT_THAT(status, StatusIs(absl::StatusCode::kUnavailable)); } @@ -345,7 +345,7 @@ TEST(MultipleIterationsAutoScalerTest, auto_scaler.ReportTargetProcessingTime(0, 0, absl::Microseconds(5))); TF_ASSERT_OK( auto_scaler.ReportTargetProcessingTime(1, 0, absl::Microseconds(5))); - tsl::Status status = auto_scaler.UpdateOptimalNumberOfWorkersMetric(1); + absl::Status status = auto_scaler.UpdateOptimalNumberOfWorkersMetric(1); EXPECT_THAT(status, StatusIs(absl::StatusCode::kUnavailable)); } @@ -357,7 +357,7 @@ TEST(MultipleIterationsAutoScalerTest, absl::Microseconds(10))); TF_ASSERT_OK(auto_scaler.ReportProcessingTime(1, "/worker/task/0:20000", absl::Microseconds(10))); - tsl::Status status = auto_scaler.UpdateOptimalNumberOfWorkersMetric(1); + absl::Status status = auto_scaler.UpdateOptimalNumberOfWorkersMetric(1); EXPECT_THAT(status, StatusIs(absl::StatusCode::kUnavailable)); } @@ -587,7 +587,7 @@ TEST(MultipleIterationsAutoScalerTest, ReportProcessingTimeNewAndExisting) { TEST(MultipleIterationsAutoScalerTest, ReportProcessingTimeZeroDuration) { MultipleIterationsAutoScaler auto_scaler; - tsl::Status result = auto_scaler.ReportProcessingTime( + absl::Status result = auto_scaler.ReportProcessingTime( 0, "/worker/task/0:20000", absl::ZeroDuration()); EXPECT_THAT(result, StatusIs(absl::StatusCode::kInvalidArgument)); } @@ -595,7 +595,7 @@ TEST(MultipleIterationsAutoScalerTest, ReportProcessingTimeZeroDuration) { TEST(MultipleIterationsAutoScalerTest, ReportProcessingTimeNegativeDuration) { MultipleIterationsAutoScaler auto_scaler; - tsl::Status result = auto_scaler.ReportProcessingTime( + absl::Status result = auto_scaler.ReportProcessingTime( 0, "/worker/task/0:20000", absl::Microseconds(-10)); EXPECT_THAT(result, StatusIs(absl::StatusCode::kInvalidArgument)); } @@ -655,7 +655,7 @@ TEST(MultipleIterationsAutoScalerTest, TEST(MultipleIterationsAutoScalerTest, ReportTargetProcessingTimeZeroDuration) { MultipleIterationsAutoScaler auto_scaler; - tsl::Status result = + absl::Status result = auto_scaler.ReportTargetProcessingTime(0, 0, absl::ZeroDuration()); EXPECT_THAT(result, StatusIs(absl::StatusCode::kInvalidArgument)); } @@ -664,7 +664,7 @@ TEST(MultipleIterationsAutoScalerTest, ReportTargetProcessingTimeNegativeDuration) { MultipleIterationsAutoScaler auto_scaler; - tsl::Status result = + absl::Status result = auto_scaler.ReportTargetProcessingTime(0, 0, absl::Microseconds(-10)); EXPECT_THAT(result, StatusIs(absl::StatusCode::kInvalidArgument)); } diff --git a/tensorflow/core/data/service/client/data_service_client.cc b/tensorflow/core/data/service/client/data_service_client.cc index 4b4c16a9c7b009..2f01dc4bb21c0d 100644 --- a/tensorflow/core/data/service/client/data_service_client.cc +++ b/tensorflow/core/data/service/client/data_service_client.cc @@ -133,7 +133,7 @@ Status DataServiceClient::Initialize() { params_.address), deadline_micros)); initialized_ = true; - return OkStatus(); + return absl::OkStatus(); } StatusOr DataServiceClient::GetNext( @@ -425,7 +425,7 @@ Status DataServiceClient::AddTask(const TaskInfo& task_info) std::mt19937 rng; std::shuffle(tasks_.begin(), tasks_.end(), rng); } - return OkStatus(); + return absl::OkStatus(); } void DataServiceClient::Heartbeat() TF_LOCKS_EXCLUDED(mu_) { @@ -808,10 +808,10 @@ Status DataServiceClient::MaybeRemoveTask(Task& task, int64_t deadline_micros, result.ready = true; result.skip = true; get_next_cv_.notify_all(); - return OkStatus(); + return absl::OkStatus(); } VLOG(1) << "Failed to remove task for worker " << task.info.worker_address(); - return OkStatus(); + return absl::OkStatus(); } Status DataServiceClient::GetElement(Task* task, int64_t deadline_micros, @@ -857,7 +857,7 @@ Status DataServiceClient::GetElement(Task* task, int64_t deadline_micros, TF_RETURN_IF_ERROR(MaybeRemoveTask(*task, deadline_micros, *result)); mutex_lock l(mu_); if (result->skip) { - return OkStatus(); + return absl::OkStatus(); } } int64_t backoff_until = std::min( @@ -874,11 +874,11 @@ Status DataServiceClient::GetElement(Task* task, int64_t deadline_micros, // task before returning to this one. result->ready = true; result->skip = true; - return OkStatus(); + return absl::OkStatus(); } } ProcessGetElementResponse(enqueue_result, get_element_result, result, *task); - return OkStatus(); + return absl::OkStatus(); } bool DataServiceClient::ResultReady() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { diff --git a/tensorflow/core/data/service/client/data_service_client.h b/tensorflow/core/data/service/client/data_service_client.h index e96aa45125ab9f..e339a0d1f0d8d9 100644 --- a/tensorflow/core/data/service/client/data_service_client.h +++ b/tensorflow/core/data/service/client/data_service_client.h @@ -229,7 +229,7 @@ class DataServiceClient { // A status to be returned from the next call to `GetNext`. This is set by // asynchronous threads when they encounter errors. - Status status_ TF_GUARDED_BY(mu_) = OkStatus(); + Status status_ TF_GUARDED_BY(mu_) = absl::OkStatus(); // A queue of results for `GetElement` requests to read from. When doing // strict round robin reads, the queue will contain placeholder results with // their `Result::ready` field false until their data has been retrieved diff --git a/tensorflow/core/data/service/client/validate_utils.cc b/tensorflow/core/data/service/client/validate_utils.cc index c8972d34b60d24..b4769ce1282eee 100644 --- a/tensorflow/core/data/service/client/validate_utils.cc +++ b/tensorflow/core/data/service/client/validate_utils.cc @@ -30,7 +30,7 @@ namespace { // Validates local worker related parameters. Status ValidateLocalWorkers(const DataServiceParams& data_service_params) { if (data_service_params.target_workers != TARGET_WORKERS_LOCAL) { - return OkStatus(); + return absl::OkStatus(); } if (LocalWorkers::Empty()) { if (IsStaticShard(data_service_params.processing_mode)) { @@ -54,13 +54,13 @@ Status ValidateLocalWorkers(const DataServiceParams& data_service_params) { "Coordinated reads require non-local workers, but `target_workers` " "is \"LOCAL\"."); } - return OkStatus(); + return absl::OkStatus(); } // Validates cross-trainer cache related parameters. Status ValidateCrossTrainerCache(const DataServiceParams& data_service_params) { if (!data_service_params.cross_trainer_cache_options.has_value()) { - return OkStatus(); + return absl::OkStatus(); } if (data_service_params.job_name.empty()) { return errors::InvalidArgument( @@ -84,14 +84,14 @@ Status ValidateCrossTrainerCache(const DataServiceParams& data_service_params) { "Got number of coordinated consumers: ", data_service_params.num_consumers.value()); } - return OkStatus(); + return absl::OkStatus(); } } // namespace Status ValidateDataServiceParams(const DataServiceParams& data_service_params) { TF_RETURN_IF_ERROR(ValidateLocalWorkers(data_service_params)); TF_RETURN_IF_ERROR(ValidateCrossTrainerCache(data_service_params)); - return OkStatus(); + return absl::OkStatus(); } } // namespace data diff --git a/tensorflow/core/data/service/common.cc b/tensorflow/core/data/service/common.cc index c6ac0e2cdee63b..cf0d352baf5a81 100644 --- a/tensorflow/core/data/service/common.cc +++ b/tensorflow/core/data/service/common.cc @@ -61,10 +61,10 @@ Status ValidateProcessingMode(const ProcessingModeDef& processing_mode) { "specify a valid sharding policy. Please add the policy to either " "`IsDynamicShard` or `IsStaticShard` (i.e., auto-shard)."); } - return OkStatus(); + return absl::OkStatus(); } -StatusOr ToAutoShardPolicy( +absl::StatusOr ToAutoShardPolicy( const ProcessingModeDef::ShardingPolicy sharding_policy) { switch (sharding_policy) { case ProcessingModeDef::FILE: @@ -87,7 +87,7 @@ StatusOr ToAutoShardPolicy( } } -StatusOr ParseTargetWorkers(absl::string_view s) { +absl::StatusOr ParseTargetWorkers(absl::string_view s) { std::string str_upper = absl::AsciiStrToUpper(s); if (str_upper.empty() || str_upper == kAuto) { return TARGET_WORKERS_AUTO; @@ -115,7 +115,7 @@ std::string TargetWorkersToString(TargetWorkers target_workers) { } } -StatusOr ParseDeploymentMode(absl::string_view s) { +absl::StatusOr ParseDeploymentMode(absl::string_view s) { std::string str_upper = absl::AsciiStrToUpper(s); if (str_upper == kColocated) { return DEPLOYMENT_MODE_COLOCATED; diff --git a/tensorflow/core/data/service/common.h b/tensorflow/core/data/service/common.h index 873c8361325840..550cffeb7b9558 100644 --- a/tensorflow/core/data/service/common.h +++ b/tensorflow/core/data/service/common.h @@ -71,19 +71,19 @@ Status ValidateProcessingMode(const ProcessingModeDef& processing_mode); // Converts tf.data service `sharding_policy` to `AutoShardPolicy`. Returns an // internal error if `sharding_policy` is not supported. -StatusOr ToAutoShardPolicy( +absl::StatusOr ToAutoShardPolicy( ProcessingModeDef::ShardingPolicy sharding_policy); // Parses a string representing a `TargetWorkers` (case-insensitive). // Returns InvalidArgument if the string is not recognized. -StatusOr ParseTargetWorkers(absl::string_view s); +absl::StatusOr ParseTargetWorkers(absl::string_view s); // Converts a `TargetWorkers` enum to string. std::string TargetWorkersToString(TargetWorkers target_workers); // Parses a string representing a `DeploymentMode` (case-insensitive). // Returns InvalidArgument if the string is not recognized. -StatusOr ParseDeploymentMode(absl::string_view s); +absl::StatusOr ParseDeploymentMode(absl::string_view s); // Returns true if `status` is a retriable error that indicates preemption. bool IsPreemptedError(const Status& status); diff --git a/tensorflow/core/data/service/common_test.cc b/tensorflow/core/data/service/common_test.cc index b1214d54d9a136..72da0c2c96ae52 100644 --- a/tensorflow/core/data/service/common_test.cc +++ b/tensorflow/core/data/service/common_test.cc @@ -174,7 +174,7 @@ TEST(CommonTest, IsPreemptedError) { EXPECT_TRUE(IsPreemptedError(errors::Aborted("Aborted"))); EXPECT_TRUE(IsPreemptedError(errors::Cancelled("Cancelled"))); EXPECT_TRUE(IsPreemptedError(errors::Unavailable("Unavailable"))); - EXPECT_FALSE(IsPreemptedError(OkStatus())); + EXPECT_FALSE(IsPreemptedError(absl::OkStatus())); } TEST(CommonTest, IsPermanentError) { diff --git a/tensorflow/core/data/service/credentials_factory.cc b/tensorflow/core/data/service/credentials_factory.cc index 4322e3d43754e9..9367296e80bbdf 100644 --- a/tensorflow/core/data/service/credentials_factory.cc +++ b/tensorflow/core/data/service/credentials_factory.cc @@ -55,7 +55,7 @@ Status CredentialsFactory::Get(absl::string_view protocol, auto it = credentials_factories().find(std::string(protocol)); if (it != credentials_factories().end()) { *out = it->second; - return OkStatus(); + return absl::OkStatus(); } std::vector available_types; @@ -75,7 +75,7 @@ Status CredentialsFactory::CreateServerCredentials( CredentialsFactory* factory; TF_RETURN_IF_ERROR(CredentialsFactory::Get(protocol, &factory)); TF_RETURN_IF_ERROR(factory->CreateServerCredentials(out)); - return OkStatus(); + return absl::OkStatus(); } Status CredentialsFactory::CreateClientCredentials( @@ -84,7 +84,7 @@ Status CredentialsFactory::CreateClientCredentials( CredentialsFactory* factory; TF_RETURN_IF_ERROR(CredentialsFactory::Get(protocol, &factory)); TF_RETURN_IF_ERROR(factory->CreateClientCredentials(out)); - return OkStatus(); + return absl::OkStatus(); } bool CredentialsFactory::Exists(absl::string_view protocol) { @@ -100,13 +100,13 @@ class InsecureCredentialsFactory : public CredentialsFactory { Status CreateServerCredentials( std::shared_ptr<::grpc::ServerCredentials>* out) override { *out = ::grpc::InsecureServerCredentials(); - return OkStatus(); + return absl::OkStatus(); } Status CreateClientCredentials( std::shared_ptr<::grpc::ChannelCredentials>* out) override { *out = ::grpc::InsecureChannelCredentials(); - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/core/data/service/cross_trainer_cache.h b/tensorflow/core/data/service/cross_trainer_cache.h index 98f7a52957af63..98c1725c2dce79 100644 --- a/tensorflow/core/data/service/cross_trainer_cache.h +++ b/tensorflow/core/data/service/cross_trainer_cache.h @@ -163,7 +163,7 @@ class CrossTrainerCache { // If `status_` is non-OK, the cache is cancelled, and all method calls will // return this status. - Status status_ TF_GUARDED_BY(mu_) = OkStatus(); + Status status_ TF_GUARDED_BY(mu_) = absl::OkStatus(); // `cache_` stores the cached elements. std::deque> cache_ TF_GUARDED_BY(mu_); @@ -294,7 +294,7 @@ Status CrossTrainerCache::ExtendCache() TF_LOCKS_EXCLUDED(mu_) { FreeSpace(new_element_size_bytes); cache_.push_back(std::make_shared(std::move(element))); cache_size_bytes_ += new_element_size_bytes; - return OkStatus(); + return absl::OkStatus(); } template diff --git a/tensorflow/core/data/service/cross_trainer_cache_test.cc b/tensorflow/core/data/service/cross_trainer_cache_test.cc index 9d4aff79f6f2e6..a0359b5135266e 100644 --- a/tensorflow/core/data/service/cross_trainer_cache_test.cc +++ b/tensorflow/core/data/service/cross_trainer_cache_test.cc @@ -51,7 +51,7 @@ using ::testing::UnorderedElementsAreArray; class InfiniteRange : public CachableSequence { public: - StatusOr GetNext() override { return next_++; } + absl::StatusOr GetNext() override { return next_++; } size_t GetElementSizeBytes(const int64_t& element) const override { return sizeof(element); } @@ -63,7 +63,7 @@ class InfiniteRange : public CachableSequence { class TensorDataset : public CachableSequence { public: - StatusOr GetNext() override { return Tensor("Test Tensor"); } + absl::StatusOr GetNext() override { return Tensor("Test Tensor"); } size_t GetElementSizeBytes(const Tensor& element) const override { return element.TotalBytes(); } @@ -73,7 +73,7 @@ class SlowDataset : public CachableSequence { public: explicit SlowDataset(absl::Duration delay) : delay_(delay) {} - StatusOr GetNext() override { + absl::StatusOr GetNext() override { Env::Default()->SleepForMicroseconds(absl::ToInt64Microseconds(delay_)); return Tensor("Test Tensor"); } @@ -369,7 +369,7 @@ TEST(CrossTrainerCacheTest, Cancel) { /*thread_options=*/{}, /*name=*/absl::StrCat("Trainer_", i), [&cache, &status, &mu]() { for (int j = 0; true; ++j) { - StatusOr> tensor = + absl::StatusOr> tensor = cache.Get(absl::StrCat("Trainer_", j % 1000)); { mutex_lock l(mu); diff --git a/tensorflow/core/data/service/data_service_test.cc b/tensorflow/core/data/service/data_service_test.cc index f506b3bcb13b54..52e07993195a61 100644 --- a/tensorflow/core/data/service/data_service_test.cc +++ b/tensorflow/core/data/service/data_service_test.cc @@ -235,7 +235,7 @@ TEST(DataServiceTest, GcMissingClientsWithSmallTimeout) { TF_ASSERT_OK(dataset_client.GetTasks(iteration_client_id).status()); // Iteration should be garbage collected within 10 seconds. absl::Time wait_start = absl::Now(); - TF_ASSERT_OK(WaitWhile([&]() -> StatusOr { + TF_ASSERT_OK(WaitWhile([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(size_t num_iterations, cluster.NumActiveIterations()); return num_iterations > 0; })); diff --git a/tensorflow/core/data/service/data_transfer.h b/tensorflow/core/data/service/data_transfer.h index 14f5ce64f4b43d..788dd241f185b7 100644 --- a/tensorflow/core/data/service/data_transfer.h +++ b/tensorflow/core/data/service/data_transfer.h @@ -91,7 +91,7 @@ class DataTransferClient { // Returns a string describing properties of the client relevant for checking // compatibility with a server for a given protocol. - virtual StatusOr GetCompatibilityInfo() const { + virtual absl::StatusOr GetCompatibilityInfo() const { return std::string(); } @@ -99,7 +99,7 @@ class DataTransferClient { // properties described in `server_compatibility_info`. virtual Status CheckCompatibility( const std::string& server_compatibility_info) const { - return OkStatus(); + return absl::OkStatus(); } protected: @@ -130,7 +130,7 @@ class DataTransferServer { // Returns a string describing properties of the server relevant for checking // compatibility with a client for a given protocol. - virtual StatusOr GetCompatibilityInfo() const { + virtual absl::StatusOr GetCompatibilityInfo() const { return std::string(); } }; diff --git a/tensorflow/core/data/service/data_transfer_test.cc b/tensorflow/core/data/service/data_transfer_test.cc index 7b057f28a1f06d..182f6198759702 100644 --- a/tensorflow/core/data/service/data_transfer_test.cc +++ b/tensorflow/core/data/service/data_transfer_test.cc @@ -37,7 +37,7 @@ class TestDataTransferServer : public DataTransferServer { explicit TestDataTransferServer(bool* called) : called_(called) {} Status Start() override { *called_ = true; - return OkStatus(); + return absl::OkStatus(); } int Port() const override { return 0; } @@ -58,7 +58,7 @@ TEST(DataTransferTest, RegisterDataTransferServerBuilder) { bool called = false; DataTransferServer::Register("test", [&called](auto ignore, auto* server) { *server = std::make_shared(&called); - return OkStatus(); + return absl::OkStatus(); }); std::shared_ptr server; diff --git a/tensorflow/core/data/service/dataset_store.cc b/tensorflow/core/data/service/dataset_store.cc index 105c5939378d61..e201ff9ef081ee 100644 --- a/tensorflow/core/data/service/dataset_store.cc +++ b/tensorflow/core/data/service/dataset_store.cc @@ -38,7 +38,7 @@ Status FileSystemDatasetStore::Put(const std::string& key, const DatasetDef& dataset) { std::string path_to_write = io::JoinPath(datasets_dir_, key); TF_RETURN_IF_ERROR(WriteDatasetDef(path_to_write, dataset)); - return OkStatus(); + return absl::OkStatus(); } Status FileSystemDatasetStore::Get( @@ -48,14 +48,14 @@ Status FileSystemDatasetStore::Get( DatasetDef def; TF_RETURN_IF_ERROR(ReadDatasetDef(path, def)); dataset_def = std::make_shared(def); - return OkStatus(); + return absl::OkStatus(); } Status MemoryDatasetStore::Put(const std::string& key, const DatasetDef& dataset) { auto& stored_dataset = datasets_[key]; stored_dataset = std::make_shared(dataset); - return OkStatus(); + return absl::OkStatus(); } Status MemoryDatasetStore::Get(const std::string& key, @@ -65,7 +65,7 @@ Status MemoryDatasetStore::Get(const std::string& key, return errors::NotFound("Dataset with key ", key, " not found"); } dataset_def = stored_dataset; - return OkStatus(); + return absl::OkStatus(); } } // namespace data diff --git a/tensorflow/core/data/service/dispatcher_client.cc b/tensorflow/core/data/service/dispatcher_client.cc index 6f59d3a60f055f..ff3b165899f562 100644 --- a/tensorflow/core/data/service/dispatcher_client.cc +++ b/tensorflow/core/data/service/dispatcher_client.cc @@ -49,7 +49,7 @@ namespace data { Status DataServiceDispatcherClient::Initialize() { mutex_lock l(mu_); if (stub_) { - return OkStatus(); + return absl::OkStatus(); } std::shared_ptr credentials; TF_RETURN_IF_ERROR( @@ -81,10 +81,11 @@ Status DataServiceDispatcherClient::Initialize() { "same version of TensorFlow. If you're running an MPM binary, make " "sure the server is running an up-to-date MPM."); } - return OkStatus(); + return absl::OkStatus(); } -StatusOr DataServiceDispatcherClient::WorkerHeartbeat( +absl::StatusOr +DataServiceDispatcherClient::WorkerHeartbeat( const WorkerHeartbeatRequest& request) { WorkerHeartbeatResponse response; grpc::ClientContext client_ctx; @@ -109,7 +110,7 @@ Status DataServiceDispatcherClient::WorkerUpdate( if (!status.ok()) { return grpc_util::WrapError("Failed to send worker update", status); } - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherClient::GetDatasetDef(const std::string& dataset_id, @@ -123,7 +124,7 @@ Status DataServiceDispatcherClient::GetDatasetDef(const std::string& dataset_id, return grpc_util::WrapError("Failed to get dataset def", status); } dataset_def = resp.dataset_def(); - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherClient::GetSplit(int64_t iteration_id, @@ -148,7 +149,7 @@ Status DataServiceDispatcherClient::GetSplit(int64_t iteration_id, return errors::Internal("Failed to parse split tensor proto"); } } - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherClient::Snapshot( @@ -167,7 +168,7 @@ Status DataServiceDispatcherClient::Snapshot( if (!status.ok()) { return grpc_util::WrapError("Failed to snapshot", status); } - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherClient::GetSnapshotSplit( @@ -190,13 +191,13 @@ Status DataServiceDispatcherClient::GetSnapshotSplit( local_split_index = resp.local_split_index(); end_of_splits = resp.end_of_splits(); if (end_of_splits) { - return OkStatus(); + return absl::OkStatus(); } if (!split.FromProto(resp.split())) { return errors::Internal("Failed to parse split tensor proto: ", resp.split().DebugString()); } - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherClient::RegisterDataset( @@ -218,7 +219,7 @@ Status DataServiceDispatcherClient::RegisterDataset( return grpc_util::WrapError("Failed to register dataset", status); } dataset_id = resp.dataset_id(); - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherClient::GetOrCreateJob( @@ -248,7 +249,7 @@ Status DataServiceDispatcherClient::GetOrCreateJob( status); } job_id = resp.job_id(); - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherClient::GetOrCreateIteration( @@ -267,7 +268,7 @@ Status DataServiceDispatcherClient::GetOrCreateIteration( status); } iteration_client_id = resp.iteration_client_id(); - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherClient::ReleaseIterationClient( @@ -284,7 +285,7 @@ Status DataServiceDispatcherClient::ReleaseIterationClient( iteration_client_id), status); } - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherClient::MaybeRemoveTask(int64_t task_id, @@ -303,7 +304,7 @@ Status DataServiceDispatcherClient::MaybeRemoveTask(int64_t task_id, return grpc_util::WrapError("Failed to call MaybeRemoveTask", status); } removed = resp.removed(); - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherClient::ClientHeartbeat( @@ -314,7 +315,7 @@ Status DataServiceDispatcherClient::ClientHeartbeat( if (!s.ok()) { return grpc_util::WrapError("Failed to get tasks", s); } - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherClient::GetWorkers( @@ -331,7 +332,7 @@ Status DataServiceDispatcherClient::GetWorkers( for (auto& worker : resp.workers()) { workers.push_back(worker); } - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherClient::GetDataServiceMetadata( @@ -346,7 +347,7 @@ Status DataServiceDispatcherClient::GetDataServiceMetadata( return grpc_util::WrapError("Failed to get data service metadata", s); } metadata = resp.metadata(); - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherClient::GetDataServiceConfig( @@ -360,7 +361,7 @@ Status DataServiceDispatcherClient::GetDataServiceConfig( return grpc_util::WrapError("Failed to get data service config", s); } config = response.config(); - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherClient::DisableCompressionAtRuntime( @@ -376,7 +377,7 @@ Status DataServiceDispatcherClient::DisableCompressionAtRuntime( return grpc_util::WrapError( "Failed to get runtime compression disabling decision", s); } - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherClient::EnsureInitialized() { diff --git a/tensorflow/core/data/service/dispatcher_client.h b/tensorflow/core/data/service/dispatcher_client.h index 40385928482040..9f521bd210ac6a 100644 --- a/tensorflow/core/data/service/dispatcher_client.h +++ b/tensorflow/core/data/service/dispatcher_client.h @@ -50,7 +50,7 @@ class DataServiceDispatcherClient : public DataServiceClientBase { // registered with the dispatcher, this will register the worker. The // dispatcher will report which new tasks the worker should run, and which // tasks it should delete. - StatusOr WorkerHeartbeat( + absl::StatusOr WorkerHeartbeat( const WorkerHeartbeatRequest& request); // Updates the dispatcher with information about the worker's state. diff --git a/tensorflow/core/data/service/dispatcher_client_test.cc b/tensorflow/core/data/service/dispatcher_client_test.cc index 7a000509bc8d81..64cf5c3c76e360 100644 --- a/tensorflow/core/data/service/dispatcher_client_test.cc +++ b/tensorflow/core/data/service/dispatcher_client_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/data_transfer.h" #include "tensorflow/core/data/service/dataset_store.h" @@ -74,19 +75,21 @@ DataServiceMetadata GetDefaultMetadata() { class DispatcherClientTest : public ::testing::Test { protected: - Status SetUpTfDataService(int64_t num_workers) { + absl::Status SetUpTfDataService(int64_t num_workers, + int64_t worker_max_concurrent_snapshots = 0) { TestCluster::Config config; config.num_workers = num_workers; config.work_dir = tsl::io::JoinPath(tsl::testing::TmpDir(), "work_dir"); + config.worker_max_concurrent_snapshots = worker_max_concurrent_snapshots; test_cluster_ = std::make_unique(config); TF_RETURN_IF_ERROR(test_cluster_->Initialize()); dispatcher_client_ = std::make_unique( test_cluster_->DispatcherAddress(), kProtocol); - return OkStatus(); + return absl::OkStatus(); } // Creates a dataset and returns the dataset ID. - StatusOr RegisterDataset( + absl::StatusOr RegisterDataset( const DatasetDef& dataset, const DataServiceMetadata& metadata, const std::optional& requested_dataset_id = std::nullopt) { std::string dataset_id; @@ -96,12 +99,15 @@ class DispatcherClientTest : public ::testing::Test { } // Starts snapshots and returns the directories. - StatusOr> StartDummySnapshots() { + absl::StatusOr> StartDummySnapshots( + int64_t num_snapshots) { DistributedSnapshotMetadata metadata = CreateDummyDistributedSnapshotMetadata(); // Create a set of local file paths to which snapshots will be materialized. - absl::flat_hash_set directories = {LocalTempFilename(), - LocalTempFilename()}; + absl::flat_hash_set directories; + for (int64_t i = 0; i < num_snapshots; ++i) { + directories.insert(LocalTempFilename()); + } for (const auto& directory : directories) { TF_RETURN_IF_ERROR( dispatcher_client_->Snapshot(RangeDataset(10), directory, metadata)); @@ -156,7 +162,7 @@ TEST_F(DispatcherClientTest, GetDataServiceConfig) { TEST_F(DispatcherClientTest, SnapshotSkeletonWritten) { TF_ASSERT_OK(SetUpTfDataService(/*num_workers=*/1)); TF_ASSERT_OK_AND_ASSIGN(absl::flat_hash_set paths, - StartDummySnapshots()); + StartDummySnapshots(/*num_snapshots=*/3)); for (const auto& path : paths) { TF_ASSERT_OK(Env::Default()->FileExists(CommittedChunksDirectory(path))); TF_ASSERT_OK(Env::Default()->FileExists(StreamsDirectory(path))); @@ -166,7 +172,7 @@ TEST_F(DispatcherClientTest, SnapshotSkeletonWritten) { TEST_F(DispatcherClientTest, SnapshotMetadataAndDatasetDefWritten) { TF_ASSERT_OK(SetUpTfDataService(/*num_workers=*/1)); TF_ASSERT_OK_AND_ASSIGN(absl::flat_hash_set paths, - StartDummySnapshots()); + StartDummySnapshots(/*num_snapshots=*/3)); for (const auto& path : paths) { TF_ASSERT_OK( Env::Default()->FileExists(io::JoinPath(path, "snapshot.metadata"))); @@ -176,25 +182,30 @@ TEST_F(DispatcherClientTest, SnapshotMetadataAndDatasetDefWritten) { } TEST_F(DispatcherClientTest, SnapshotsInHeartbeat) { - TF_ASSERT_OK(SetUpTfDataService(/*num_workers=*/1)); + TF_ASSERT_OK(SetUpTfDataService(/*num_workers=*/1, + /*worker_max_concurrent_snapshots=*/3)); TF_ASSERT_OK_AND_ASSIGN(absl::flat_hash_set paths, - StartDummySnapshots()); + StartDummySnapshots(/*num_snapshots=*/3)); WorkerHeartbeatRequest worker_heartbeat_request; worker_heartbeat_request.set_worker_address(test_cluster_->WorkerAddress(0)); - TF_ASSERT_OK_AND_ASSIGN( - WorkerHeartbeatResponse worker_heartbeat_response, - dispatcher_client_->WorkerHeartbeat(worker_heartbeat_request)); - ASSERT_EQ(worker_heartbeat_response.snapshot_tasks_size(), paths.size()); - for (const auto& snapshot_task : worker_heartbeat_response.snapshot_tasks()) { - ASSERT_TRUE(paths.count(snapshot_task.base_path())); - ASSERT_EQ(snapshot_task.stream_index(), 0); + + for (int64_t i = 1; i <= 3; ++i) { + TF_ASSERT_OK_AND_ASSIGN( + WorkerHeartbeatResponse worker_heartbeat_response, + dispatcher_client_->WorkerHeartbeat(worker_heartbeat_request)); + ASSERT_EQ(worker_heartbeat_response.snapshot_tasks_size(), i); + for (const auto& snapshot_task : + worker_heartbeat_response.snapshot_tasks()) { + ASSERT_TRUE(paths.count(snapshot_task.base_path())); + ASSERT_EQ(snapshot_task.stream_index(), 0); + } } } TEST_F(DispatcherClientTest, GetSnapshotSplit) { TF_ASSERT_OK(SetUpTfDataService(/*num_workers=*/1)); TF_ASSERT_OK_AND_ASSIGN(absl::flat_hash_set paths, - StartDummySnapshots()); + StartDummySnapshots(/*num_snapshots=*/3)); WorkerHeartbeatRequest worker_heartbeat_request; worker_heartbeat_request.set_worker_address(test_cluster_->WorkerAddress(0)); TF_ASSERT_OK_AND_ASSIGN( @@ -219,10 +230,12 @@ TEST_F(DispatcherClientTest, GetSnapshotSplit) { } TEST_F(DispatcherClientTest, GetSnapshotSplitMultipleStreams) { - TF_ASSERT_OK(SetUpTfDataService(/*num_workers=*/3)); + TF_ASSERT_OK(SetUpTfDataService(/*num_workers=*/3, + /*worker_max_concurrent_snapshots=*/1)); TF_ASSERT_OK_AND_ASSIGN(absl::flat_hash_set paths, - StartDummySnapshots()); + StartDummySnapshots(/*num_snapshots=*/3)); + absl::flat_hash_set snapshots_in_progress; for (int64_t i = 0; i < 3; ++i) { WorkerHeartbeatRequest worker_heartbeat_request; worker_heartbeat_request.set_worker_address( @@ -230,8 +243,10 @@ TEST_F(DispatcherClientTest, GetSnapshotSplitMultipleStreams) { TF_ASSERT_OK_AND_ASSIGN( WorkerHeartbeatResponse worker_heartbeat_response, dispatcher_client_->WorkerHeartbeat(worker_heartbeat_request)); + EXPECT_EQ(worker_heartbeat_response.snapshot_tasks().size(), 1); for (const auto& snapshot_task : worker_heartbeat_response.snapshot_tasks()) { + snapshots_in_progress.insert(snapshot_task.base_path()); GetSnapshotSplitRequest get_snapshot_split_request; Tensor split; int64_t local_split_index = 0; @@ -245,6 +260,9 @@ TEST_F(DispatcherClientTest, GetSnapshotSplitMultipleStreams) { EXPECT_FALSE(end_of_splits); } } + + // Each worker writes one snapshot; each snapshot has been assigned a worker. + EXPECT_EQ(snapshots_in_progress, paths); } TEST_F(DispatcherClientTest, RegisterDatasetWithExplicitId) { diff --git a/tensorflow/core/data/service/dispatcher_impl.cc b/tensorflow/core/data/service/dispatcher_impl.cc index 09d082b2fff7f9..73ee07a6c34e64 100644 --- a/tensorflow/core/data/service/dispatcher_impl.cc +++ b/tensorflow/core/data/service/dispatcher_impl.cc @@ -34,6 +34,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "absl/time/time.h" #include "tensorflow/core/data/dataset_utils.h" @@ -139,7 +140,7 @@ Status CreateWorkerStub(const std::string& address, const std::string& protocol, CredentialsFactory::CreateClientCredentials(protocol, &credentials)); auto channel = ::grpc::CreateCustomChannel(address, credentials, args); stub = WorkerService::NewStub(channel); - return OkStatus(); + return absl::OkStatus(); } void PrepareGraph(GraphDef* graph) { @@ -227,7 +228,7 @@ Status DataServiceDispatcherImpl::Start() { LOG(INFO) << "Running with fault_tolerant_mode=False. The dispatcher will " "not be able to recover its state on restart."; started_ = true; - return OkStatus(); + return absl::OkStatus(); } journal_writer_ = std::make_unique(env_, JournalDir(config_.work_dir())); @@ -267,7 +268,9 @@ Status DataServiceDispatcherImpl::Start() { TF_RETURN_IF_ERROR(journal_writer_.value()->EnsureInitialized()); TF_RETURN_IF_ERROR(RestoreSnapshots()); started_ = true; - return OkStatus(); + LOG(INFO) << "Started tf.data service dispatcher with config " + << config_.DebugString(); + return absl::OkStatus(); } void DataServiceDispatcherImpl::Stop() TF_LOCKS_EXCLUDED(mu_) { @@ -334,7 +337,7 @@ Status DataServiceDispatcherImpl::RestoreSplitProviders( } } restored = std::move(split_providers); - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherImpl::FindTasksToDelete( @@ -350,7 +353,7 @@ Status DataServiceDispatcherImpl::FindTasksToDelete( response->add_tasks_to_delete(current_task); } } - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherImpl::FindNewTasks( @@ -381,7 +384,7 @@ Status DataServiceDispatcherImpl::FindNewTasks( TaskDef* task_def = response->add_new_tasks(); TF_RETURN_IF_ERROR(PopulateTaskDef(task, task_def)); } - return OkStatus(); + return absl::OkStatus(); } void DataServiceDispatcherImpl::ReportProcessingTimesFromActiveTasks( @@ -457,12 +460,21 @@ Status DataServiceDispatcherImpl::WorkerHeartbeat( TF_RETURN_IF_ERROR( FindNewTasks(worker_address, current_tasks, assigned_tasks, response)); } + + std::vector snapshot_paths = + snapshot_assignment_manager_.LoadBalanceSnapshots( + request->worker_address()); std::vector snapshots; + snapshots.reserve(snapshot_paths.size()); { tf_shared_lock l(mu_); - snapshots.reserve(snapshots_.size()); - for (const auto& [path, snapshot_manager] : snapshots_) { - snapshots.push_back(snapshot_manager.get()); + for (const std::string& snapshot_path : snapshot_paths) { + const auto it = snapshots_.find(snapshot_path); + if (it == snapshots_.end()) { + return absl::InternalError(absl::StrCat( + "Dataset snapshot at ", snapshot_path, " does not exist.")); + } + snapshots.push_back(it->second.get()); } } for (SnapshotManager* snapshot_manager : snapshots) { @@ -471,7 +483,7 @@ Status DataServiceDispatcherImpl::WorkerHeartbeat( VLOG(3) << "Finished worker heartbeat for worker at address " << request->worker_address(); - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherImpl::WorkerUpdate( @@ -495,7 +507,7 @@ Status DataServiceDispatcherImpl::WorkerUpdate( << task->iteration->iteration_id << " completed"; } } - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherImpl::GetDatasetDef( @@ -507,7 +519,7 @@ Status DataServiceDispatcherImpl::GetDatasetDef( std::shared_ptr dataset_def; TF_RETURN_IF_ERROR(GetDatasetDef(*dataset, dataset_def)); *response->mutable_dataset_def() = *dataset_def; - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherImpl::GetSplit(const GetSplitRequest* request, @@ -538,7 +550,7 @@ Status DataServiceDispatcherImpl::GetSplit(const GetSplitRequest* request, VLOG(3) << "Returning end_of_splits since current repetition " << current_repetition << " is greater than the requested repetition " << repetition; - return OkStatus(); + return absl::OkStatus(); } split_provider = split_providers_[iteration_id][provider_index].get(); } @@ -563,7 +575,7 @@ Status DataServiceDispatcherImpl::GetSplit(const GetSplitRequest* request, } VLOG(3) << "Returning from GetSplit, split=" << split << ", end_of_splits=" << end_of_splits; - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherImpl::MakeSplitProviders( @@ -575,13 +587,13 @@ Status DataServiceDispatcherImpl::MakeSplitProviders( std::shared_ptr dataset_def; TF_RETURN_IF_ERROR(GetDatasetDef(*dataset, dataset_def)); TF_RETURN_IF_ERROR(CreateSplitProviders(*dataset_def, split_providers)); - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherImpl::GetVersion(const GetVersionRequest* request, GetVersionResponse* response) { response->set_version(kDataServiceVersion); - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherImpl::GetOrRegisterDataset( @@ -599,7 +611,7 @@ Status DataServiceDispatcherImpl::GetOrRegisterDataset( VLOG(3) << "RegisterDataset returns an existing dataset with ID = " << *dataset_id; response->set_dataset_id(*dataset_id); - return OkStatus(); + return absl::OkStatus(); } std::string new_dataset_id; @@ -607,10 +619,11 @@ Status DataServiceDispatcherImpl::GetOrRegisterDataset( request->dataset_id(), new_dataset_id)); response->set_dataset_id(new_dataset_id); VLOG(3) << "Registered new dataset with id " << new_dataset_id; - return OkStatus(); + return absl::OkStatus(); } -StatusOr> DataServiceDispatcherImpl::FindDataset( +absl::StatusOr> +DataServiceDispatcherImpl::FindDataset( const GetOrRegisterDatasetRequest& request) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { std::shared_ptr existing_dataset; @@ -655,7 +668,7 @@ Status DataServiceDispatcherImpl::GetDataServiceMetadata( VLOG(3) << "Get the data service metadata for dataset id: " << dataset_id << "."; *response->mutable_metadata() = dataset->metadata; - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherImpl::GetDataServiceConfig( @@ -663,7 +676,7 @@ Status DataServiceDispatcherImpl::GetDataServiceConfig( GetDataServiceConfigResponse* response) { TF_RETURN_IF_ERROR(CheckStarted()); response->mutable_config()->set_deployment_mode(config_.deployment_mode()); - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherImpl::GetOrCreateJob( @@ -692,7 +705,7 @@ Status DataServiceDispatcherImpl::GetOrCreateJob( } VLOG(3) << "Received job id " << job->id << " for CreateJob(" << request->DebugString() << ")"; - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherImpl::GetOrCreateIteration( @@ -723,7 +736,7 @@ Status DataServiceDispatcherImpl::GetOrCreateIteration( TF_RETURN_IF_ERROR(AssignTasks(tasks)); VLOG(3) << "Created iteration " << iteration->iteration_id << " for CreateIteration(" << request->DebugString() << ")"; - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherImpl::MaybeRemoveTask( @@ -737,7 +750,7 @@ Status DataServiceDispatcherImpl::MaybeRemoveTask( if (errors::IsNotFound(s)) { // Task is already removed. response->set_removed(true); - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR(s); auto& remover_ref = remove_task_requests_[task->task_id]; @@ -756,7 +769,7 @@ Status DataServiceDispatcherImpl::MaybeRemoveTask( response->set_removed(removed); if (!removed) { VLOG(1) << "Failed to remove task " << task->task_id; - return OkStatus(); + return absl::OkStatus(); } mutex_lock l(mu_); if (!task->removed) { @@ -774,7 +787,7 @@ Status DataServiceDispatcherImpl::MaybeRemoveTask( << " from tf.data service AutoScaler: " << auto_scaler_status; } VLOG(1) << "Task " << task->task_id << " successfully removed"; - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherImpl::ReleaseIterationClient( @@ -799,7 +812,7 @@ Status DataServiceDispatcherImpl::ReleaseIterationClient( release_iteration_client->set_iteration_client_id(iteration_client_id); release_iteration_client->set_time_micros(env_->NowMicros()); TF_RETURN_IF_ERROR(Apply(update)); - return OkStatus(); + return absl::OkStatus(); } // Validates that the job matches the requested processing mode. @@ -832,7 +845,7 @@ Status DataServiceDispatcherImpl::ValidateMatchingJob( "Tried to create job with name ", job->job_name, ", but found an existing job with different parameters: ", diff); } - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherImpl::CreateJob( @@ -857,7 +870,7 @@ Status DataServiceDispatcherImpl::CreateJob( TF_RETURN_IF_ERROR(state_.JobFromId(job_id, job)); tensorflow::metrics::RecordTFDataServiceJobsCreated( request.processing_mode_def(), is_coordinated_read); - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherImpl::CreateIteration( @@ -882,7 +895,7 @@ Status DataServiceDispatcherImpl::CreateIteration( TF_RETURN_IF_ERROR(Apply(update)); TF_RETURN_IF_ERROR(state_.IterationFromId(iteration_id, iteration)); - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherImpl::CreateTasksForWorker( @@ -900,7 +913,7 @@ Status DataServiceDispatcherImpl::CreateTasksForWorker( std::shared_ptr task; TF_RETURN_IF_ERROR(CreateTask(iteration, worker_address, task)); } - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherImpl::AcquireIterationClientId( @@ -915,7 +928,7 @@ Status DataServiceDispatcherImpl::AcquireIterationClientId( TF_RETURN_IF_ERROR(Apply(update)); // Does not release clients before they start to read from the dataset. latest_client_heartbeats_time_[iteration_client_id] = absl::InfiniteFuture(); - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherImpl::CreateTasksForIteration( @@ -930,7 +943,7 @@ Status DataServiceDispatcherImpl::CreateTasksForIteration( TF_RETURN_IF_ERROR(CreateTask(iteration, worker->address, task)); tasks.push_back(task); } - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherImpl::CreatePendingTask( @@ -952,7 +965,7 @@ Status DataServiceDispatcherImpl::CreatePendingTask( worker->tags.end()}; create_task->set_worker_uid(worker->uid); TF_RETURN_IF_ERROR(Apply(update)); - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherImpl::CreateTask( @@ -974,7 +987,7 @@ Status DataServiceDispatcherImpl::CreateTask( create_task->set_worker_uid(worker->uid); TF_RETURN_IF_ERROR(Apply(update)); TF_RETURN_IF_ERROR(state_.TaskFromId(task_id, task)); - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherImpl::AssignTasks( @@ -982,7 +995,7 @@ Status DataServiceDispatcherImpl::AssignTasks( for (const auto& task : tasks) { TF_RETURN_IF_ERROR(AssignTask(task)); } - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherImpl::GetOrCreateWorkerStub( @@ -993,7 +1006,7 @@ Status DataServiceDispatcherImpl::GetOrCreateWorkerStub( auto it = worker_stubs_.find(worker_address); if (it != worker_stubs_.end()) { out_stub = it->second.get(); - return OkStatus(); + return absl::OkStatus(); } } std::unique_ptr stub; @@ -1008,7 +1021,7 @@ Status DataServiceDispatcherImpl::GetOrCreateWorkerStub( } out_stub = worker.get(); } - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherImpl::AssignTask(std::shared_ptr task) @@ -1032,7 +1045,7 @@ Status DataServiceDispatcherImpl::AssignTask(std::shared_ptr task) s.error_code() == grpc::StatusCode::CANCELLED) { // Worker is presumably preempted. We will assign the task to the worker // when it reconnects. - return OkStatus(); + return absl::OkStatus(); } return grpc_util::WrapError( absl::StrCat("Failed to submit task to worker ", task->worker_address), @@ -1040,7 +1053,7 @@ Status DataServiceDispatcherImpl::AssignTask(std::shared_ptr task) } VLOG(2) << "Finished assigning task " << task->task_id << " to worker " << task->worker_address; - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherImpl::ClientHeartbeat( @@ -1149,7 +1162,7 @@ Status DataServiceDispatcherImpl::ClientHeartbeat( VLOG(4) << "Found " << response->task_info_size() << " tasks for iteration client id " << request->iteration_client_id(); - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherImpl::GetWorkers(const GetWorkersRequest* request, @@ -1164,7 +1177,7 @@ Status DataServiceDispatcherImpl::GetWorkers(const GetWorkersRequest* request, } VLOG(3) << "Returning list of " << response->workers_size() << " workers from GetWorkers"; - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherImpl::Snapshot(const SnapshotRequest* request, @@ -1188,6 +1201,7 @@ Status DataServiceDispatcherImpl::Snapshot(const SnapshotRequest* request, std::unique_ptr snapshot_manager, SnapshotManager::Start(*request, snapshot_assignment_manager_, env_)); snapshots_.insert({request->path(), std::move(snapshot_manager)}); + snapshot_assignment_manager_.AddSnapshot(request->path()); Update update; SnapshotUpdate* snapshot = update.mutable_snapshot(); @@ -1253,6 +1267,7 @@ absl::Status DataServiceDispatcherImpl::RestoreSnapshots() return; } snapshots_.insert({path, std::move(snapshot_manager.value())}); + snapshot_assignment_manager_.AddSnapshot(path); }); } thread_pool.reset(); @@ -1269,14 +1284,14 @@ Status DataServiceDispatcherImpl::DisableCompressionAtRuntime( if (dataset->metadata.compression() != DataServiceMetadata::COMPRESSION_SNAPPY) { response->set_no_compression_to_disable(true); - return OkStatus(); + return absl::OkStatus(); } if (std::optional compression_disabled_at_runtime = state_.CompressionDisabledAtRuntime(request->dataset_id()); compression_disabled_at_runtime.has_value()) { response->set_compression_disabled_at_runtime( *compression_disabled_at_runtime); - return OkStatus(); + return absl::OkStatus(); } response->set_compression_disabled_at_runtime( request->disable_compression_at_runtime()); @@ -1287,7 +1302,7 @@ Status DataServiceDispatcherImpl::DisableCompressionAtRuntime( compression_disabled_at_runtime->set_compression_disabled( request->disable_compression_at_runtime()); TF_RETURN_IF_ERROR(Apply(update)); - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherImpl::PopulateTaskDef( @@ -1326,7 +1341,7 @@ Status DataServiceDispatcherImpl::PopulateTaskDef( io::JoinPath(DatasetsDir(config_.work_dir()), dataset->dataset_id); task_def->set_path(path); } - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherImpl::CheckStarted() TF_LOCKS_EXCLUDED(mu_) { @@ -1334,7 +1349,7 @@ Status DataServiceDispatcherImpl::CheckStarted() TF_LOCKS_EXCLUDED(mu_) { if (!started_) { return errors::Unavailable("Dispatcher has not started yet."); } - return OkStatus(); + return absl::OkStatus(); } Status DataServiceDispatcherImpl::RecordSplitProduced( @@ -1439,7 +1454,7 @@ Status DataServiceDispatcherImpl::ReleaseMissingClients() TF_RETURN_IF_ERROR(Apply(update)); } } - return OkStatus(); + return absl::OkStatus(); } void DataServiceDispatcherImpl::RemoveWorkerFromAutoScaler( @@ -1506,7 +1521,7 @@ Status DataServiceDispatcherImpl::GcOldIterations() } LOG(INFO) << "Garbage collected iteration " << iteration->DebugString(); } - return OkStatus(); + return absl::OkStatus(); } bool DataServiceDispatcherImpl::ShouldGcIteration(const Iteration& iteration, diff --git a/tensorflow/core/data/service/dispatcher_impl.h b/tensorflow/core/data/service/dispatcher_impl.h index 3ece2684c95dab..5f1f31315a49fd 100644 --- a/tensorflow/core/data/service/dispatcher_impl.h +++ b/tensorflow/core/data/service/dispatcher_impl.h @@ -217,7 +217,7 @@ class DataServiceDispatcherImpl { TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Finds the dataset ID with the requested dataset ID. // Returns nullptr if no such dataset exists. - StatusOr> FindDataset( + absl::StatusOr> FindDataset( const GetOrRegisterDatasetRequest& request); // Gets a worker's stub from `worker_stubs_`, or if none exists, creates a // stub and stores it in `worker_stubs_`. A borrowed pointer to the stub is diff --git a/tensorflow/core/data/service/dispatcher_state.cc b/tensorflow/core/data/service/dispatcher_state.cc index c7ba48534aa601..22ab9ff2aeb988 100644 --- a/tensorflow/core/data/service/dispatcher_state.cc +++ b/tensorflow/core/data/service/dispatcher_state.cc @@ -92,7 +92,7 @@ Status DispatcherState::Apply(const Update& update) { return errors::Internal("Update type not set."); } - return OkStatus(); + return absl::OkStatus(); } void DispatcherState::RegisterDataset( @@ -141,7 +141,7 @@ Status DispatcherState::JobFromId(int64_t job_id, return errors::NotFound("Job with id ", job_id, " not found"); } job = it->second; - return OkStatus(); + return absl::OkStatus(); } Status DispatcherState::JobByName(const std::string& job_name, @@ -151,7 +151,7 @@ Status DispatcherState::JobByName(const std::string& job_name, return errors::NotFound("Job with name ", job_name, " not found"); } job = it->second; - return OkStatus(); + return absl::OkStatus(); } void DispatcherState::CreateIteration( @@ -330,7 +330,7 @@ Status DispatcherState::DatasetFromId( return errors::NotFound("Dataset id ", id, " not found"); } dataset = it->second; - return OkStatus(); + return absl::OkStatus(); } Status DispatcherState::WorkerFromAddress( @@ -340,7 +340,7 @@ Status DispatcherState::WorkerFromAddress( return errors::NotFound("Worker with address ", address, " not found."); } worker = it->second; - return OkStatus(); + return absl::OkStatus(); } std::vector> @@ -370,7 +370,7 @@ Status DispatcherState::IterationFromId( return errors::NotFound("Iteration id ", id, " not found"); } iteration = it->second; - return OkStatus(); + return absl::OkStatus(); } Status DispatcherState::IterationByKey( @@ -382,7 +382,7 @@ Status DispatcherState::IterationByKey( " not found"); } iteration = it->second; - return OkStatus(); + return absl::OkStatus(); } int64_t DispatcherState::NextAvailableJobId() const { @@ -400,7 +400,7 @@ Status DispatcherState::IterationForIterationClientId( return errors::NotFound("Iteration client id not found: ", iteration_client_id); } - return OkStatus(); + return absl::OkStatus(); } std::vector DispatcherState::ListActiveClientIds() { @@ -424,7 +424,7 @@ Status DispatcherState::TaskFromId(int64_t id, return errors::NotFound("Task ", id, " not found"); } task = it->second; - return OkStatus(); + return absl::OkStatus(); } Status DispatcherState::TasksForIteration( @@ -439,7 +439,7 @@ Status DispatcherState::TasksForIteration( for (const auto& task : it->second) { tasks.push_back(task); } - return OkStatus(); + return absl::OkStatus(); } Status DispatcherState::TasksForWorker( @@ -456,7 +456,7 @@ Status DispatcherState::TasksForWorker( for (const auto& task : worker_tasks) { tasks.push_back(task.second); } - return OkStatus(); + return absl::OkStatus(); } int64_t DispatcherState::NextAvailableTaskId() const { @@ -467,7 +467,7 @@ Status DispatcherState::ValidateWorker(absl::string_view worker_address) const { return worker_index_resolver_.ValidateWorker(worker_address); } -StatusOr DispatcherState::GetWorkerIndex( +absl::StatusOr DispatcherState::GetWorkerIndex( absl::string_view worker_address) const { return worker_index_resolver_.GetWorkerIndex(worker_address); } diff --git a/tensorflow/core/data/service/dispatcher_state.h b/tensorflow/core/data/service/dispatcher_state.h index 39d10453763251..e64b48771400ad 100644 --- a/tensorflow/core/data/service/dispatcher_state.h +++ b/tensorflow/core/data/service/dispatcher_state.h @@ -288,7 +288,8 @@ class DispatcherState { // If the dispatcher config specifies worker addresses, `GetWorkerIndex` // returns the worker index according to the list. This is useful for // deterministically sharding a dataset among a fixed set of workers. - StatusOr GetWorkerIndex(absl::string_view worker_address) const; + absl::StatusOr GetWorkerIndex( + absl::string_view worker_address) const; // Returns the paths of all snapshots initiated during the lifetime of this // journal. diff --git a/tensorflow/core/data/service/graph_rewriters.cc b/tensorflow/core/data/service/graph_rewriters.cc index 012f36ff175ff1..af2059ae89c707 100644 --- a/tensorflow/core/data/service/graph_rewriters.cc +++ b/tensorflow/core/data/service/graph_rewriters.cc @@ -93,7 +93,7 @@ bool ShouldReplaceDynamicPort(absl::string_view config_address, } } // namespace -StatusOr +absl::StatusOr RemoveCompressionMapRewriter::ApplyRemoveCompressionMapRewrite( const GraphDef& graph_def) { grappler::RemoveCompressionMap remove_compression_map; @@ -122,7 +122,8 @@ RemoveCompressionMapRewriter::GetRewriteConfig() const { return config; } -StatusOr AutoShardRewriter::Create(const TaskDef& task_def) { +absl::StatusOr AutoShardRewriter::Create( + const TaskDef& task_def) { TF_ASSIGN_OR_RETURN( AutoShardPolicy auto_shard_policy, ToAutoShardPolicy(task_def.processing_mode_def().sharding_policy())); @@ -130,7 +131,7 @@ StatusOr AutoShardRewriter::Create(const TaskDef& task_def) { task_def.worker_index()); } -StatusOr AutoShardRewriter::ApplyAutoShardRewrite( +absl::StatusOr AutoShardRewriter::ApplyAutoShardRewrite( const GraphDef& graph_def) { if (auto_shard_policy_ == AutoShardPolicy::OFF) { return graph_def; @@ -184,13 +185,13 @@ AutoShardRewriter::GetRewriteConfig() const { Status WorkerIndexResolver::ValidateWorker( absl::string_view worker_address) const { if (worker_addresses_.empty()) { - return OkStatus(); + return absl::OkStatus(); } for (absl::string_view config_address : worker_addresses_) { if (config_address == worker_address || ShouldReplaceDynamicPort(config_address, worker_address)) { - return OkStatus(); + return absl::OkStatus(); } } @@ -214,7 +215,7 @@ void WorkerIndexResolver::AddWorker(absl::string_view worker_address) { } } -StatusOr WorkerIndexResolver::GetWorkerIndex( +absl::StatusOr WorkerIndexResolver::GetWorkerIndex( absl::string_view worker_address) const { const auto it = absl::c_find(worker_addresses_, worker_address); if (it == worker_addresses_.cend()) { diff --git a/tensorflow/core/data/service/graph_rewriters.h b/tensorflow/core/data/service/graph_rewriters.h index 7c0c347a836b1d..84c43a4f29d579 100644 --- a/tensorflow/core/data/service/graph_rewriters.h +++ b/tensorflow/core/data/service/graph_rewriters.h @@ -37,7 +37,7 @@ namespace data { class RemoveCompressionMapRewriter { public: // Returns `graph_def` with the compression map removed. - StatusOr ApplyRemoveCompressionMapRewrite( + absl::StatusOr ApplyRemoveCompressionMapRewrite( const GraphDef& graph_def); private: @@ -49,11 +49,11 @@ class AutoShardRewriter { public: // Creates an `AutoShardRewriter` according to `task_def`. Returns an error if // the sharding policy is not a valid auto-shard policy. - static StatusOr Create(const TaskDef& task_def); + static absl::StatusOr Create(const TaskDef& task_def); // Applies auto-sharding to `graph_def`. If auto-shard policy is OFF, returns // the same graph as `graph_def`. Otherwise, returns the re-written graph. - StatusOr ApplyAutoShardRewrite(const GraphDef& graph_def); + absl::StatusOr ApplyAutoShardRewrite(const GraphDef& graph_def); private: AutoShardRewriter(AutoShardPolicy auto_shard_policy, int64_t num_workers, @@ -97,7 +97,8 @@ class WorkerIndexResolver { // Returns the worker index for the worker at `worker_address`. Returns a // NotFound error if the worker is not registered. - StatusOr GetWorkerIndex(absl::string_view worker_address) const; + absl::StatusOr GetWorkerIndex( + absl::string_view worker_address) const; private: std::vector worker_addresses_; diff --git a/tensorflow/core/data/service/graph_rewriters_test.cc b/tensorflow/core/data/service/graph_rewriters_test.cc index fe52fe9e4f38cc..a549c548353276 100644 --- a/tensorflow/core/data/service/graph_rewriters_test.cc +++ b/tensorflow/core/data/service/graph_rewriters_test.cc @@ -49,7 +49,8 @@ using ::tensorflow::testing::StatusIs; using ::testing::HasSubstr; using ::testing::SizeIs; -StatusOr GetNode(const GraphDef& graph_def, absl::string_view name) { +absl::StatusOr GetNode(const GraphDef& graph_def, + absl::string_view name) { for (const NodeDef& node : graph_def.node()) { if (node.name() == name) { return node; @@ -59,7 +60,8 @@ StatusOr GetNode(const GraphDef& graph_def, absl::string_view name) { name, graph_def.ShortDebugString())); } -StatusOr GetValue(const GraphDef& graph_def, absl::string_view name) { +absl::StatusOr GetValue(const GraphDef& graph_def, + absl::string_view name) { for (const NodeDef& node : graph_def.node()) { if (node.name() == name) { return node.attr().at("value").tensor().int64_val()[0]; diff --git a/tensorflow/core/data/service/grpc_dispatcher_impl_test.cc b/tensorflow/core/data/service/grpc_dispatcher_impl_test.cc index bee4676be3e5ca..393e754cdf524a 100644 --- a/tensorflow/core/data/service/grpc_dispatcher_impl_test.cc +++ b/tensorflow/core/data/service/grpc_dispatcher_impl_test.cc @@ -77,7 +77,7 @@ class GrpcDispatcherImplTest : public ::testing::Test { std::shared_ptr channel = ::grpc::CreateCustomChannel(GetDispatcherAddress(), credentials, args); dispatcher_client_stub_ = DispatcherService::NewStub(channel); - return OkStatus(); + return absl::OkStatus(); } std::string GetDispatcherAddress() const { diff --git a/tensorflow/core/data/service/grpc_worker_impl.cc b/tensorflow/core/data/service/grpc_worker_impl.cc index ce6ab35729fda2..d83879bed599b9 100644 --- a/tensorflow/core/data/service/grpc_worker_impl.cc +++ b/tensorflow/core/data/service/grpc_worker_impl.cc @@ -47,7 +47,7 @@ Status GrpcWorkerImpl::Start( worker_address_ = worker_address; TF_RETURN_IF_ERROR(impl_->Start(worker_address, transfer_servers)); LocalWorkers::Add(worker_address, impl_); - return OkStatus(); + return absl::OkStatus(); } void GrpcWorkerImpl::Stop() { diff --git a/tensorflow/core/data/service/grpc_worker_impl_test.cc b/tensorflow/core/data/service/grpc_worker_impl_test.cc index c21c6b5056b575..062117c94999d6 100644 --- a/tensorflow/core/data/service/grpc_worker_impl_test.cc +++ b/tensorflow/core/data/service/grpc_worker_impl_test.cc @@ -88,7 +88,7 @@ class GrpcWorkerImplTest : public ::testing::Test { std::shared_ptr channel = ::grpc::CreateCustomChannel(GetWorkerAddress(), credentials, args); worker_client_stub_ = WorkerService::NewStub(channel); - return OkStatus(); + return absl::OkStatus(); } std::string GetDispatcherAddress() const { diff --git a/tensorflow/core/data/service/journal.cc b/tensorflow/core/data/service/journal.cc index 5890e46fbf020c..d8ecd7adafce52 100644 --- a/tensorflow/core/data/service/journal.cc +++ b/tensorflow/core/data/service/journal.cc @@ -42,7 +42,7 @@ Status ParseSequenceNumber(const std::string& journal_file, return errors::InvalidArgument("Failed to parse journal file name: ", journal_file); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -57,7 +57,7 @@ FileJournalWriter::FileJournalWriter(Env* env, const std::string& journal_dir) Status FileJournalWriter::EnsureInitialized() { if (writer_) { - return OkStatus(); + return absl::OkStatus(); } std::vector journal_files; TF_RETURN_IF_ERROR(env_->RecursivelyCreateDir(journal_dir_)); @@ -73,7 +73,7 @@ Status FileJournalWriter::EnsureInitialized() { TF_RETURN_IF_ERROR(env_->NewAppendableFile(journal_file, &file_)); writer_ = std::make_unique(file_.get()); VLOG(1) << "Created journal writer to write to " << journal_file; - return OkStatus(); + return absl::OkStatus(); } Status FileJournalWriter::Write(const Update& update) { @@ -89,7 +89,7 @@ Status FileJournalWriter::Write(const Update& update) { if (VLOG_IS_ON(4)) { VLOG(4) << "Wrote journal entry: " << update.DebugString(); } - return OkStatus(); + return absl::OkStatus(); } FileJournalReader::FileJournalReader(Env* env, StringPiece journal_dir) @@ -97,7 +97,7 @@ FileJournalReader::FileJournalReader(Env* env, StringPiece journal_dir) Status FileJournalReader::EnsureInitialized() { if (reader_) { - return OkStatus(); + return absl::OkStatus(); } return UpdateFile(DataServiceJournalFile(journal_dir_, 0)); } @@ -115,7 +115,7 @@ Status FileJournalReader::Read(Update& update, bool& end_of_journal) { VLOG(3) << "Next journal file " << next_journal_file << " does not exist. End of journal reached."; end_of_journal = true; - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR(UpdateFile(next_journal_file)); continue; @@ -128,7 +128,7 @@ Status FileJournalReader::Read(Update& update, bool& end_of_journal) { VLOG(4) << "Read journal entry: " << update.DebugString(); } end_of_journal = false; - return OkStatus(); + return absl::OkStatus(); } } @@ -138,7 +138,7 @@ Status FileJournalReader::UpdateFile(const std::string& filename) { io::RecordReaderOptions opts; opts.buffer_size = 2 << 20; // 2MB reader_ = std::make_unique(file_.get(), opts); - return OkStatus(); + return absl::OkStatus(); } } // namespace data diff --git a/tensorflow/core/data/service/journal_test.cc b/tensorflow/core/data/service/journal_test.cc index 81d993c209d214..6ee4dd4bd3f4af 100644 --- a/tensorflow/core/data/service/journal_test.cc +++ b/tensorflow/core/data/service/journal_test.cc @@ -83,7 +83,7 @@ Status CheckJournalContent(StringPiece journal_dir, bool end_of_journal = false; TF_RETURN_IF_ERROR(reader.Read(result, end_of_journal)); EXPECT_TRUE(end_of_journal); - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/core/data/service/server_lib.cc b/tensorflow/core/data/service/server_lib.cc index a329526c909b7b..bfb4b3474e00de 100644 --- a/tensorflow/core/data/service/server_lib.cc +++ b/tensorflow/core/data/service/server_lib.cc @@ -54,7 +54,7 @@ Status GrpcDataServerBase::Start() { "Server cannot be started after it has been stopped."); } if (started_) { - return OkStatus(); + return absl::OkStatus(); } ::grpc::ServerBuilder builder; for (std::unique_ptr<::grpc::ServerBuilderOption>& option : server_options_) { @@ -81,7 +81,7 @@ Status GrpcDataServerBase::Start() { started_ = true; LOG(INFO) << "Started tf.data " << server_type_ << " running at 0.0.0.0:" << BoundPort(); - return OkStatus(); + return absl::OkStatus(); } void GrpcDataServerBase::Stop() { @@ -136,7 +136,7 @@ Status DispatchGrpcDataServer::NumWorkers(int* num_workers) { return grpc_util::WrapError("Failed to get workers", s); } *num_workers = resp.workers_size(); - return OkStatus(); + return absl::OkStatus(); } Status DispatchGrpcDataServer::SnapshotStreams( @@ -152,7 +152,7 @@ Status DispatchGrpcDataServer::SnapshotStreams( for (const auto& stream : resp.streams()) { streams->push_back(SnapshotStreamInfoWrapper(stream)); } - return OkStatus(); + return absl::OkStatus(); } size_t DispatchGrpcDataServer::NumActiveIterations() { @@ -212,7 +212,7 @@ void WorkerGrpcDataServer::MaybeStartAlternativeDataTransferServer( str_util::StringReplace(config_.data_transfer_address(), kPortPlaceholder, absl::StrCat(transfer_server_->Port()), /*replace_all=*/false)); - StatusOr compatibility_info = + absl::StatusOr compatibility_info = transfer_server_->GetCompatibilityInfo(); if (!compatibility_info.ok()) { LOG(ERROR) @@ -240,7 +240,7 @@ Status WorkerGrpcDataServer::StartServiceInternal() { std::vector transfer_servers = {grpc_transfer_server}; MaybeStartAlternativeDataTransferServer(transfer_servers); TF_RETURN_IF_ERROR(service_->Start(worker_address, transfer_servers)); - return OkStatus(); + return absl::OkStatus(); } void WorkerGrpcDataServer::StopServiceInternal() { service_->Stop(); } @@ -254,7 +254,7 @@ Status WorkerGrpcDataServer::NumTasks(int* num_tasks) { return grpc_util::WrapError("Failed to get tasks", s); } *num_tasks = resp.tasks_size(); - return OkStatus(); + return absl::OkStatus(); } Status WorkerGrpcDataServer::SnapshotTaskProgresses( @@ -269,7 +269,7 @@ Status WorkerGrpcDataServer::SnapshotTaskProgresses( for (const auto& progress : resp.snapshot_task_progresses()) { snapshot_task_progresses->push_back(SnapshotTaskProgressWrapper(progress)); } - return OkStatus(); + return absl::OkStatus(); } ServerStateExport WorkerGrpcDataServer::ExportState() const { @@ -281,13 +281,13 @@ ServerStateExport WorkerGrpcDataServer::ExportState() const { Status NewDispatchServer(const experimental::DispatcherConfig& config, std::unique_ptr& out_server) { out_server = std::make_unique(config); - return OkStatus(); + return absl::OkStatus(); } Status NewWorkerServer(const experimental::WorkerConfig& config, std::unique_ptr& out_server) { out_server = std::make_unique(config); - return OkStatus(); + return absl::OkStatus(); } } // namespace data diff --git a/tensorflow/core/data/service/snapshot/file_utils.cc b/tensorflow/core/data/service/snapshot/file_utils.cc index d5045122d673cc..0440b00b34f7f0 100644 --- a/tensorflow/core/data/service/snapshot/file_utils.cc +++ b/tensorflow/core/data/service/snapshot/file_utils.cc @@ -44,7 +44,7 @@ constexpr const char kTempFileSuffix[] = ".tmp"; absl::Status AtomicallyWrite( absl::string_view filename, tsl::Env* env, - absl::FunctionRef nonatomically_write) { + absl::FunctionRef nonatomically_write) { std::string uncommitted_filename = absl::StrCat(filename, "__"); if (!env->CreateUniqueFileName(&uncommitted_filename, kTempFileSuffix)) { return tsl::errors::Internal("Failed to write file ", filename, diff --git a/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer.cc b/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer.cc index 370c84f0762b59..0146b33a6f080a 100644 --- a/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer.cc +++ b/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer.cc @@ -47,12 +47,12 @@ ParallelTFRecordWriter::ParallelTFRecordWriter(const std::string& file_prefix, tsl::Env* env, ByteSize max_file_size, int64_t num_write_threads, - int64_t buffer_size_per_thread) + int64_t buffer_size) : env_(env), file_prefix_(file_prefix), compression_(compression), max_file_size_(max_file_size), - buffer_size_(num_write_threads * buffer_size_per_thread) { + buffer_size_(buffer_size) { thread_pool_ = std::make_unique( env_, tsl::ThreadOptions{}, "write_tfrecord_thread", num_write_threads); for (int64_t i = 0; i < num_write_threads; ++i) { @@ -162,8 +162,11 @@ ParallelTFRecordWriter::GetNextRecord(const std::string& filename) } std::vector record = std::move(buffer_.front()); + ByteSize estimated_size = EstimatedSize(record); + LOG_EVERY_N_SEC(INFO, 1) << "Writing TFRecord of " << estimated_size + << " to file " << file_prefix_ << "*."; ++file_stats_[filename].num_records; - file_stats_[filename].estimated_size += EstimatedSize(record); + file_stats_[filename].estimated_size += estimated_size; buffer_.pop_front(); ready_to_push_.SignalAll(); return record; diff --git a/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer.h b/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer.h index 453c0d3c29a7a0..db6cd182213b5e 100644 --- a/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer.h +++ b/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer.h @@ -58,9 +58,9 @@ class ParallelTFRecordWriter { public: explicit ParallelTFRecordWriter(const std::string& file_prefix, const std::string& compression, tsl::Env* env, - ByteSize max_file_size = ByteSize::GB(2), - int64_t num_write_threads = 10, - int64_t buffer_size_per_thread = 1); + ByteSize max_file_size = ByteSize::GB(6), + int64_t num_write_threads = 2, + int64_t buffer_size = 1); virtual ~ParallelTFRecordWriter(); ParallelTFRecordWriter(const ParallelTFRecordWriter&) = delete; ParallelTFRecordWriter& operator=(const ParallelTFRecordWriter&) = delete; diff --git a/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer_test.cc b/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer_test.cc index a71629559cc6c8..d55e7d1f8b0b82 100644 --- a/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer_test.cc +++ b/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer_test.cc @@ -171,7 +171,7 @@ class ParallelTFRecordWriterParamTest int64_t NumClients() const { return std::get<1>(GetParam()); } ByteSize MaxFileSize() const { return std::get<2>(GetParam()); } int64_t NumWriteThreads() const { return std::get<3>(GetParam()); } - int64_t BufferSizePerThread() const { return std::get<4>(GetParam()); } + int64_t BufferSize() const { return std::get<4>(GetParam()); } std::string Compression() const { return std::get<5>(GetParam()); } void VerifyFileStats( @@ -214,7 +214,7 @@ TEST_P(ParallelTFRecordWriterParamTest, WriteRecords) { TF_ASSERT_OK_AND_ASSIGN(std::string test_dir, TestDir()); ParallelTFRecordWriter parallel_tfrecord_writer( test_dir, Compression(), tsl::Env::Default(), MaxFileSize(), - NumWriteThreads(), BufferSizePerThread()); + NumWriteThreads(), BufferSize()); RangeIterator range_iterator(NumElements()); TF_ASSERT_OK_AND_ASSIGN( @@ -231,7 +231,7 @@ TEST_P(ParallelTFRecordWriterParamTest, ConcurrentWrites) { TF_ASSERT_OK_AND_ASSIGN(std::string test_dir, TestDir()); ParallelTFRecordWriter parallel_tfrecord_writer( test_dir, Compression(), tsl::Env::Default(), MaxFileSize(), - NumWriteThreads(), BufferSizePerThread()); + NumWriteThreads(), BufferSize()); std::vector> client_threads; for (int i = 0; i < NumClients(); ++i) { @@ -255,20 +255,21 @@ TEST_P(ParallelTFRecordWriterParamTest, ConcurrentWrites) { VerifyFileStats(stats, NumElements() * NumClients()); } -INSTANTIATE_TEST_SUITE_P( - ParallelTFRecordWriterParams, ParallelTFRecordWriterParamTest, - ::testing::Combine( - /*NumElements*/ ::testing::Values(0, 1, 100), - /*NumClients*/ ::testing::Values(1, 5), - /*MaxFileSize*/ - ::testing::Values(ByteSize::Bytes(1), ByteSize::Bytes(100), - ByteSize::GB(1)), - /*NumWriteThreads*/ ::testing::Values(1, 5), - /*BufferSizePerThread*/ ::testing::Values(1, 10000), - /*Compression*/ - ::testing::Values(tsl::io::compression::kNone, - tsl::io::compression::kSnappy, - tsl::io::compression::kZlib))); +INSTANTIATE_TEST_SUITE_P(ParallelTFRecordWriterParams, + ParallelTFRecordWriterParamTest, + ::testing::Combine( + /*NumElements*/ ::testing::Values(0, 1, 100), + /*NumClients*/ ::testing::Values(1, 5), + /*MaxFileSize*/ + ::testing::Values(ByteSize::Bytes(1), + ByteSize::Bytes(100), + ByteSize::GB(1)), + /*NumWriteThreads*/ ::testing::Values(1, 5), + /*BufferSize*/ ::testing::Values(1, 10000), + /*Compression*/ + ::testing::Values(tsl::io::compression::kNone, + tsl::io::compression::kSnappy, + tsl::io::compression::kZlib))); TEST(ParallelTFRecordWriterTest, WriteNoRecord) { TF_ASSERT_OK_AND_ASSIGN(std::string test_dir, TestDir()); diff --git a/tensorflow/core/data/service/snapshot/prefetched_split_provider.cc b/tensorflow/core/data/service/snapshot/prefetched_split_provider.cc index 9815020a103252..d28285966ff251 100644 --- a/tensorflow/core/data/service/snapshot/prefetched_split_provider.cc +++ b/tensorflow/core/data/service/snapshot/prefetched_split_provider.cc @@ -52,6 +52,7 @@ PrefetchedSplitProvider::PrefetchedSplitProvider( UpdateStatus(std::move(status)); return; } + absl::MutexLock l(&mu_); thread_pool_ = RunPrefetchThreads(); } @@ -160,13 +161,15 @@ PrefetchedSplitProvider::GetSplitFromProvider() ABSL_LOCKS_EXCLUDED(mu_) { } absl::Status PrefetchedSplitProvider::Reset() ABSL_LOCKS_EXCLUDED(mu_) { + std::unique_ptr thread_pool; { absl::MutexLock l(&mu_); reset_ = true; ready_to_push_.SignalAll(); ready_to_pop_.SignalAll(); + thread_pool = std::move(thread_pool_); } - thread_pool_.reset(); + thread_pool.reset(); TF_RETURN_IF_ERROR(split_provider_->Reset()); absl::MutexLock l(&mu_); @@ -182,10 +185,14 @@ absl::Status PrefetchedSplitProvider::Reset() ABSL_LOCKS_EXCLUDED(mu_) { } void PrefetchedSplitProvider::Cancel() { - // Finishes the in-flight threads. UpdateStatus( absl::CancelledError("tf.data prefetched split provider is shut down.")); - thread_pool_.reset(); + // Finishes the in-flight threads. + std::unique_ptr thread_pool; + { + absl::MutexLock l(&mu_); + thread_pool = std::move(thread_pool_); + } } absl::Status PrefetchedSplitProvider::InitDirs() { diff --git a/tensorflow/core/data/service/snapshot/prefetched_split_provider.h b/tensorflow/core/data/service/snapshot/prefetched_split_provider.h index 741f45ee88771d..518f8a3712d099 100644 --- a/tensorflow/core/data/service/snapshot/prefetched_split_provider.h +++ b/tensorflow/core/data/service/snapshot/prefetched_split_provider.h @@ -148,7 +148,7 @@ class PrefetchedSplitProvider { // Buffer to hold the splits. The size should be bounded by `buffer_size_`. absl::btree_set buffer_ ABSL_GUARDED_BY(mu_); - std::unique_ptr thread_pool_; + std::unique_ptr thread_pool_ ABSL_GUARDED_BY(mu_); }; } // namespace data diff --git a/tensorflow/core/data/service/snapshot/snapshot_manager.cc b/tensorflow/core/data/service/snapshot/snapshot_manager.cc index ae46955d08e28f..d2860d38ff9e3d 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_manager.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_manager.cc @@ -25,6 +25,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" @@ -97,7 +98,7 @@ std::string PrefetchedSplitDir(const std::string& snapshot_path, absl::StatusOr SnapshotAssignmentManager::TryAddAssignment( absl::string_view snapshot_path, absl::string_view worker_address, - int64_t stream_index) { + int64_t stream_index) TF_LOCKS_EXCLUDED(mu_) { tsl::mutex_lock l(mu_); if (assignments_[worker_address].size() >= worker_max_concurrent_snapshots()) { @@ -110,15 +111,60 @@ absl::StatusOr SnapshotAssignmentManager::TryAddAssignment( " already had an assignment for ", assignment.DebugString())); } + ++snapshot_assignment_counts_[snapshot_path]; return true; } void SnapshotAssignmentManager::RemoveAssignment( absl::string_view snapshot_path, absl::string_view worker_address, - int64_t stream_index) { + int64_t stream_index) TF_LOCKS_EXCLUDED(mu_) { tsl::mutex_lock l(mu_); - assignments_[worker_address].erase( + auto num_erased = assignments_[worker_address].erase( {std::string(snapshot_path), stream_index}); + if ((snapshot_assignment_counts_[snapshot_path] -= num_erased) <= 0) { + snapshot_assignment_counts_.erase(snapshot_path); + } +} + +void SnapshotAssignmentManager::AddSnapshot(absl::string_view snapshot_path) + TF_LOCKS_EXCLUDED(mu_) { + tsl::mutex_lock l(mu_); + if (!snapshot_assignment_counts_.contains(snapshot_path)) { + snapshot_assignment_counts_[snapshot_path] = 0; + } +} + +std::vector SnapshotAssignmentManager::LoadBalanceSnapshots( + absl::string_view worker_address) TF_LOCKS_EXCLUDED(mu_) { + std::vector result; + + tsl::mutex_lock l(mu_); + result.reserve(snapshot_assignment_counts_.size()); + const auto it = assignments_.find(worker_address); + if (it != assignments_.end()) { + for (const Assignment& assignment : it->second) { + result.push_back(assignment.snapshot_path); + } + } + if (result.size() >= worker_max_concurrent_snapshots()) { + return result; + } + + absl::btree_multimap snapshots_by_count; + for (const auto& [snapshot, count] : snapshot_assignment_counts_) { + snapshots_by_count.emplace(count, snapshot); + } + + for (const auto& [_, snapshot] : snapshots_by_count) { + if (absl::c_find(result, snapshot) == result.end()) { + // Assigns the next least-assigned snapshot. Assigns one snapshot at a + // time in case workers reach the assignment limit before the user has + // submitted all requests. + result.push_back(snapshot); + return result; + } + } + return result; } absl::StatusOr> SnapshotManager::Start( diff --git a/tensorflow/core/data/service/snapshot/snapshot_manager.h b/tensorflow/core/data/service/snapshot/snapshot_manager.h index 52f99203c0ed00..8c53ae98650878 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_manager.h +++ b/tensorflow/core/data/service/snapshot/snapshot_manager.h @@ -60,6 +60,16 @@ class SnapshotAssignmentManager { void RemoveAssignment(absl::string_view snapshot_path, absl::string_view worker_address, int64_t stream_index); + // Adds a new snapshot. + void AddSnapshot(absl::string_view snapshot_path); + + // Load balances snapshots by the number of assigned streams. Given a worker, + // returns snapshots in the following order: + // - Snapshots already assigned to this worker. + // - Snapshots with the fewest assignments. + std::vector LoadBalanceSnapshots( + absl::string_view worker_address); + // Returns the maximum concurrent snapshots processed by each worker. int64_t worker_max_concurrent_snapshots() const { return worker_max_concurrent_snapshots_; @@ -91,6 +101,10 @@ class SnapshotAssignmentManager { absl::flat_hash_map> assignments_ TF_GUARDED_BY(mu_); + // A mapping from snapshot to the number of assigned workers. + absl::flat_hash_map snapshot_assignment_counts_ + TF_GUARDED_BY(mu_); + // The maximum number of snapshots that a worker can concurrently process at a // given point in time. This is a tradeoff between worker resource usage and // snapshot wall time. A value of 0 indicates that the decision should be left diff --git a/tensorflow/core/data/service/snapshot/snapshot_manager_test.cc b/tensorflow/core/data/service/snapshot/snapshot_manager_test.cc index 98c19fc52f3850..9e116536ad07e5 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_manager_test.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_manager_test.cc @@ -37,8 +37,13 @@ namespace tensorflow { namespace data { namespace { +using ::testing::_; +using ::testing::ElementsAre; using ::testing::IsEmpty; +using ::testing::Not; using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; +using ::tsl::testing::IsOkAndHolds; using ::tsl::testing::StatusIs; template @@ -288,6 +293,65 @@ TEST(SnapshotManagerTest, ResumeFromError) { EXPECT_THAT(heartbeat_response.snapshot_tasks(), IsEmpty()); } +TEST(SnapshotAssignmentManagerTest, LoadBalanceSnapshots) { + SnapshotAssignmentManager snapshot_assignment_manager( + /*worker_max_concurrent_snapshots=*/2); + snapshot_assignment_manager.AddSnapshot("snapshot_1"); + snapshot_assignment_manager.AddSnapshot("snapshot_2"); + snapshot_assignment_manager.AddSnapshot("snapshot_3"); + + // Worker 1: snapshot 3 + // Worker 2: N/A + EXPECT_THAT(snapshot_assignment_manager.TryAddAssignment( + "snapshot_3", "worker_1", /*stream_index=*/0), + IsOkAndHolds(true)); + EXPECT_THAT(snapshot_assignment_manager.LoadBalanceSnapshots("worker_1"), + ElementsAre("snapshot_3", _)); + ASSERT_THAT(snapshot_assignment_manager.LoadBalanceSnapshots("worker_2"), + ElementsAre(Not("snapshot_3"))); + + // Worker 1: snapshots 2, 3 + // Worker 2: N/A + EXPECT_THAT(snapshot_assignment_manager.TryAddAssignment( + "snapshot_2", "worker_1", /*stream_index=*/0), + IsOkAndHolds(true)); + ASSERT_THAT(snapshot_assignment_manager.LoadBalanceSnapshots("worker_1"), + UnorderedElementsAre("snapshot_2", "snapshot_3")); + EXPECT_THAT(snapshot_assignment_manager.LoadBalanceSnapshots("worker_2"), + ElementsAre("snapshot_1")); + + // Worker 1: snapshots 2, 3 + // Worker 2: snapshot 2 + EXPECT_THAT(snapshot_assignment_manager.TryAddAssignment( + "snapshot_1", "worker_1", /*stream_index=*/0), + IsOkAndHolds(false)); + EXPECT_THAT(snapshot_assignment_manager.TryAddAssignment( + "snapshot_2", "worker_2", /*stream_index=*/0), + IsOkAndHolds(true)); + ASSERT_THAT(snapshot_assignment_manager.LoadBalanceSnapshots("worker_1"), + UnorderedElementsAre("snapshot_2", "snapshot_3")); + EXPECT_THAT(snapshot_assignment_manager.LoadBalanceSnapshots("worker_2"), + ElementsAre("snapshot_2", "snapshot_1")); + + // Worker 1: snapshot 3 + // Worker 2: snapshot 2 + snapshot_assignment_manager.RemoveAssignment("snapshot_2", "worker_1", + /*stream_index=*/0); + EXPECT_THAT(snapshot_assignment_manager.LoadBalanceSnapshots("worker_1"), + ElementsAre("snapshot_3", "snapshot_1")); + ASSERT_THAT(snapshot_assignment_manager.LoadBalanceSnapshots("worker_2"), + ElementsAre("snapshot_2", "snapshot_1")); + + // Worker 1: N/A + // Worker 2: snapshot 2 + snapshot_assignment_manager.RemoveAssignment("snapshot_3", "worker_1", + /*stream_index=*/0); + ASSERT_THAT(snapshot_assignment_manager.LoadBalanceSnapshots("worker_1"), + ElementsAre("snapshot_1")); + ASSERT_THAT(snapshot_assignment_manager.LoadBalanceSnapshots("worker_2"), + ElementsAre("snapshot_2", "snapshot_1")); +} + } // namespace } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/service/snapshot/snapshot_stream_writer.cc b/tensorflow/core/data/service/snapshot/snapshot_stream_writer.cc index 0a356e9dec7bf8..01412806950427 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_stream_writer.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_stream_writer.cc @@ -53,7 +53,7 @@ namespace tensorflow { namespace data { namespace { -constexpr int64_t kTFRecordReaderOutputBufferSize = 512 << 20; // 512MB +constexpr ByteSize kTFRecordReaderOutputBufferSize = ByteSize::GB(1); constexpr int64_t kUnknownNumElements = -1; constexpr const char kFileShardDelimiter[] = "_CHUNK_SHARDS_"; @@ -192,15 +192,16 @@ absl::Status SnapshotStreamWriter::WriteChunks() { } bool SnapshotStreamWriter::ShouldWriteRecord() const { - { - mutex_lock l(mu_); - if (!completed_.ok()) { - return false; - } + mutex_lock l(mu_); + if (!completed_.ok() || end_of_sequence_) { + return false; } const absl::Time now = absl::FromUnixMicros(params_.env->NowMicros()); - return !end_of_sequence_ && - now < last_commit_time_ + params_.checkpoint_interval; + // Adjusts the checkpoint interval to speed up initial commits during startup. + // It will grow gradually from 5 min to the configured checkpoint interval. + const absl::Duration adjusted_checkpoint_interval = std::min( + params_.checkpoint_interval, absl::Minutes(0.5 * chunk_index_ + 5)); + return now < last_commit_time_ + adjusted_checkpoint_interval; } absl::Status SnapshotStreamWriter::WriteRecord(ParallelTFRecordWriter& writer) { @@ -358,9 +359,9 @@ absl::Status SnapshotStreamWriter::Restore() { kUnknownNumElements); } TF_RETURN_IF_ERROR(checkpoint_name.status()); - snapshot_util::TFRecordReaderImpl reader(CheckpointPath(*checkpoint_name), - params_.compression, - kTFRecordReaderOutputBufferSize); + snapshot_util::TFRecordReaderImpl reader( + CheckpointPath(*checkpoint_name), params_.compression, + kTFRecordReaderOutputBufferSize.ToUnsignedBytes()); TF_RETURN_IF_ERROR(reader.Initialize(params_.env)); TF_ASSIGN_OR_RETURN(std::vector serialized_tensors, reader.GetTensors()); diff --git a/tensorflow/core/data/service/snapshot/snapshot_stream_writer.h b/tensorflow/core/data/service/snapshot/snapshot_stream_writer.h index bd8aa358383aa5..3179ab167a6620 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_stream_writer.h +++ b/tensorflow/core/data/service/snapshot/snapshot_stream_writer.h @@ -43,8 +43,8 @@ limitations under the License. namespace tensorflow { namespace data { -constexpr ByteSize kDefaultMaxChunkSize = ByteSize::GB(2); -constexpr absl::Duration kDefaultCheckpointInterval = absl::Minutes(20); +constexpr ByteSize kDefaultMaxChunkSize = ByteSize::GB(6); +constexpr absl::Duration kDefaultCheckpointInterval = absl::Minutes(30); struct SnapshotWriterParams { // The directory path of the snapshot. See the comment on SnapshotStreamWriter @@ -64,7 +64,9 @@ struct SnapshotWriterParams { // The maximum number of bytes in each chunk. ByteSize max_chunk_size = kDefaultMaxChunkSize; - // How often should checkpoints be written. + // How often should checkpoints be written at the steady state. We write + // checkpoints (and committing chunks) more frequently at the startup time to + // avoid starving training jobs during startup. absl::Duration checkpoint_interval = kDefaultCheckpointInterval; // If true, keep temporary files (e.g., checkpoints) after completing the diff --git a/tensorflow/core/data/service/snapshot/utils.h b/tensorflow/core/data/service/snapshot/utils.h index b7613ed718ed4a..1ea4d80b649dcf 100644 --- a/tensorflow/core/data/service/snapshot/utils.h +++ b/tensorflow/core/data/service/snapshot/utils.h @@ -25,6 +25,7 @@ limitations under the License. namespace tensorflow { namespace data { +// Estimates the size of the Tensors when serialized as TensorProtos. ByteSize EstimatedSize(const std::vector& tensors); } // namespace data diff --git a/tensorflow/core/data/service/split_provider.cc b/tensorflow/core/data/service/split_provider.cc index c5aef37fb5b064..ddf9d7c7a67138 100644 --- a/tensorflow/core/data/service/split_provider.cc +++ b/tensorflow/core/data/service/split_provider.cc @@ -58,13 +58,13 @@ Status DataServiceSplitProvider::GetNext(Tensor* split, bool* end_of_splits) << "; with iteration_id=" << iteration_id_ << ", repetition=" << repetition_; } - return OkStatus(); + return absl::OkStatus(); } Status DataServiceSplitProvider::Reset() TF_LOCKS_EXCLUDED(mu_) { mutex_lock l(mu_); repetition_++; - return OkStatus(); + return absl::OkStatus(); } Status DataServiceSplitProvider::Save( @@ -89,7 +89,7 @@ Status CreateSplitProviders( TF_RETURN_IF_ERROR(standalone::Dataset::FromGraph(params, dataset_def.graph(), &standalone_dataset)); TF_RETURN_IF_ERROR(standalone_dataset->MakeSplitProviders(&split_providers)); - return OkStatus(); + return absl::OkStatus(); } } // namespace data diff --git a/tensorflow/core/data/service/task_runner.cc b/tensorflow/core/data/service/task_runner.cc index 3eb3af401647db..c56d0989b1ceaa 100644 --- a/tensorflow/core/data/service/task_runner.cc +++ b/tensorflow/core/data/service/task_runner.cc @@ -105,7 +105,7 @@ Status TaskRunner::Create(const experimental::WorkerConfig& worker_config, } else { out = std::make_unique(std::move(iterator)); } - return OkStatus(); + return absl::OkStatus(); } FirstComeFirstServedTaskRunner::FirstComeFirstServedTaskRunner( @@ -123,14 +123,14 @@ Status FirstComeFirstServedTaskRunner::GetNext(const GetElementRequest& req, Status FirstComeFirstServedTaskRunner::GetNext(GetElementResult& result) { TF_ASSIGN_OR_RETURN(result, buffer_.Pop()); - return OkStatus(); + return absl::OkStatus(); } Status FirstComeFirstServedTaskRunner::PrefetchFn() { while (true) { TF_RETURN_IF_ERROR(buffer_.Push(GetNextFromInputIterator())); } - return OkStatus(); + return absl::OkStatus(); } void FirstComeFirstServedTaskRunner::RunPrefetchThread() { @@ -189,7 +189,7 @@ Status CachingTaskRunner::GetNext(const GetElementRequest& req, TF_ASSIGN_OR_RETURN(std::shared_ptr element, cache_.Get(req.trainer_id())); result = element->Copy(); - return OkStatus(); + return absl::OkStatus(); } CachingTaskRunner::GetElementResultSequence::GetElementResultSequence( @@ -248,7 +248,7 @@ Status RoundRobinTaskRunner::ValidateRequest(const GetElementRequest& req) { "Requesting data for consumer index ", req.consumer_index(), ", but the task is configured for only ", num_consumers_, " consumers"); } - return OkStatus(); + return absl::OkStatus(); } Status RoundRobinTaskRunner::PrepareFullRound(int64_t wait_us) @@ -259,7 +259,7 @@ Status RoundRobinTaskRunner::PrepareFullRound(int64_t wait_us) TF_RETURN_IF_ERROR(prefetch_thread_.FillBuffer(wait_us, buffer_)); round_skipped_ = buffer_.empty(); new_round_cv_.notify_all(); - return OkStatus(); + return absl::OkStatus(); } Status RoundRobinTaskRunner::PreparePartialRound() @@ -273,11 +273,11 @@ Status RoundRobinTaskRunner::PreparePartialRound() if (next_round_request.skipped_previous_round()) { VLOG(1) << "Skipping partial round"; round_skipped_ = true; - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR(prefetch_thread_.FillBuffer(/*wait_us=*/-1, buffer_)); round_skipped_ = false; - return OkStatus(); + return absl::OkStatus(); } Status RoundRobinTaskRunner::PrepareRound(const GetElementRequest& req) { @@ -333,7 +333,7 @@ Status RoundRobinTaskRunner::GetNext(const GetElementRequest& req, if (round_skipped_) { VLOG(1) << worker_address_ << ": Buffer not ready, skipping round " << current_round_ << " for consumer " << req.consumer_index(); - return OkStatus(); + return absl::OkStatus(); } auto& buffer_result = buffer_[req.consumer_index()]; result.element_index = buffer_result->index; @@ -351,7 +351,7 @@ Status RoundRobinTaskRunner::GetNext(const GetElementRequest& req, << req.round_index() << ". element size " << size; } result.components = std::move(element); - return OkStatus(); + return absl::OkStatus(); } void RoundRobinTaskRunner::Cancel() { @@ -431,14 +431,14 @@ Status PrefetchThread::FillBuffer(int64_t wait_us, } if (buffer_.size() < round_size_) { DCHECK_GE(wait_us, 0); - return OkStatus(); + return absl::OkStatus(); } for (auto& elem : buffer_) { out.push_back(std::move(elem)); } buffer_.clear(); cv_.notify_all(); - return OkStatus(); + return absl::OkStatus(); } Status PrefetchThread::GetStatus() { diff --git a/tensorflow/core/data/service/task_runner.h b/tensorflow/core/data/service/task_runner.h index 565a7f4727a477..d1f95a97d9b2aa 100644 --- a/tensorflow/core/data/service/task_runner.h +++ b/tensorflow/core/data/service/task_runner.h @@ -229,7 +229,7 @@ class PrefetchThread { // Buffered results for the next round. std::vector> buffer_ TF_GUARDED_BY(mu_); // The status if the prefetch thread fails. - Status status_ TF_GUARDED_BY(mu_) = OkStatus(); + Status status_ TF_GUARDED_BY(mu_) = absl::OkStatus(); // Condition variable notified when elements are added to or removed from // `buffer_`, or when `status_` is changed. condition_variable cv_; diff --git a/tensorflow/core/data/service/task_runner_test.cc b/tensorflow/core/data/service/task_runner_test.cc index 0c1ef895742b0c..1ea0d1645adbd8 100644 --- a/tensorflow/core/data/service/task_runner_test.cc +++ b/tensorflow/core/data/service/task_runner_test.cc @@ -64,13 +64,13 @@ class RangeIterator : public TaskIterator { Status GetNext(std::vector& element, bool& end_of_sequence) override { end_of_sequence = (next_ >= range_); if (end_of_sequence) { - return OkStatus(); + return absl::OkStatus(); } element = {Tensor{next_++}}; if (repeat_) { next_ = next_ % range_; } - return OkStatus(); + return absl::OkStatus(); } int64_t Cardinality() const override { @@ -89,7 +89,7 @@ class InfiniteRangeIterator : public TaskIterator { Status GetNext(std::vector& element, bool& end_of_sequence) override { element = {Tensor{next_++}}; - return OkStatus(); + return absl::OkStatus(); } int64_t Cardinality() const override { return kInfiniteCardinality; } @@ -107,12 +107,12 @@ class ElementOrErrorIterator : public TaskIterator { Status GetNext(std::vector& element, bool& end_of_sequence) override { end_of_sequence = (next_ >= elements_.size()); if (end_of_sequence) { - return OkStatus(); + return absl::OkStatus(); } const StatusOr& next_element = elements_[next_++]; TF_RETURN_IF_ERROR(next_element.status()); element = {Tensor{*next_element}}; - return OkStatus(); + return absl::OkStatus(); } int64_t Cardinality() const override { return elements_.size(); } @@ -193,7 +193,7 @@ Status RunConsumer(int64_t consumer_index, int64_t start_index, } } while (result.skip); } - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/core/data/service/test_cluster.cc b/tensorflow/core/data/service/test_cluster.cc index 28fa31900a68ff..ec6976812e83ad 100644 --- a/tensorflow/core/data/service/test_cluster.cc +++ b/tensorflow/core/data/service/test_cluster.cc @@ -72,6 +72,8 @@ Status TestCluster::Initialize() { config_.job_gc_check_interval_ms); dispatcher_config.set_job_gc_timeout_ms(config_.job_gc_timeout_ms); dispatcher_config.set_client_timeout_ms(config_.client_timeout_ms); + dispatcher_config.set_worker_max_concurrent_snapshots( + config_.worker_max_concurrent_snapshots); TF_RETURN_IF_ERROR(NewDispatchServer(dispatcher_config, dispatcher_)); TF_RETURN_IF_ERROR(dispatcher_->Start()); dispatcher_address_ = absl::StrCat("localhost:", dispatcher_->BoundPort()); @@ -80,7 +82,7 @@ Status TestCluster::Initialize() { for (int i = 0; i < num_workers_; ++i) { TF_RETURN_IF_ERROR(AddWorker()); } - return OkStatus(); + return absl::OkStatus(); } Status TestCluster::AddWorker(std::optional port) { @@ -99,7 +101,7 @@ Status TestCluster::AddWorker(std::optional port) { TF_RETURN_IF_ERROR(worker->Start()); worker_addresses_.push_back(absl::StrCat("localhost:", worker->BoundPort())); workers_.push_back(std::move(worker)); - return OkStatus(); + return absl::OkStatus(); } std::string TestCluster::DispatcherAddress() const { diff --git a/tensorflow/core/data/service/test_cluster.h b/tensorflow/core/data/service/test_cluster.h index 8654ca0051ab71..3b5f5f02032d03 100644 --- a/tensorflow/core/data/service/test_cluster.h +++ b/tensorflow/core/data/service/test_cluster.h @@ -54,6 +54,7 @@ class TestCluster { int64_t worker_heartbeat_interval_ms = 0; int64_t job_gc_check_interval_ms = 0; int64_t job_gc_timeout_ms = 0; + int64_t worker_max_concurrent_snapshots = 0; std::string work_dir; }; diff --git a/tensorflow/core/data/service/thread_safe_buffer.h b/tensorflow/core/data/service/thread_safe_buffer.h index 49b9200256e086..3c18da024a52ac 100644 --- a/tensorflow/core/data/service/thread_safe_buffer.h +++ b/tensorflow/core/data/service/thread_safe_buffer.h @@ -54,7 +54,7 @@ class ThreadSafeBuffer final { condition_variable ready_to_pop_; condition_variable ready_to_push_; std::deque> results_ TF_GUARDED_BY(mu_); - Status status_ TF_GUARDED_BY(mu_) = OkStatus(); + Status status_ TF_GUARDED_BY(mu_) = absl::OkStatus(); ThreadSafeBuffer(const ThreadSafeBuffer&) = delete; void operator=(const ThreadSafeBuffer&) = delete; @@ -94,7 +94,7 @@ Status ThreadSafeBuffer::Push(StatusOr value) { } results_.push_back(std::move(value)); ready_to_pop_.notify_one(); - return OkStatus(); + return absl::OkStatus(); } template diff --git a/tensorflow/core/data/service/utils.cc b/tensorflow/core/data/service/utils.cc index 7a4a0c4824ad27..43ad439d21fc55 100644 --- a/tensorflow/core/data/service/utils.cc +++ b/tensorflow/core/data/service/utils.cc @@ -32,7 +32,7 @@ Status WriteDatasetDef(const std::string& path, const DatasetDef& dataset_def) { TF_RETURN_IF_ERROR(Env::Default()->NewWritableFile(path, &file)); io::RecordWriter writer(file.get()); TF_RETURN_IF_ERROR(writer.WriteRecord(dataset_def.SerializeAsString())); - return OkStatus(); + return absl::OkStatus(); } Status ReadDatasetDef(const std::string& path, DatasetDef& dataset_def) { @@ -49,7 +49,7 @@ Status ReadDatasetDef(const std::string& path, DatasetDef& dataset_def) { if (!dataset_def.ParseFromString(record)) { return errors::DataLoss("Failed to parse dataset definition"); } - return OkStatus(); + return absl::OkStatus(); } } // namespace data diff --git a/tensorflow/core/data/service/validate_utils.cc b/tensorflow/core/data/service/validate_utils.cc index 99735979d4b574..45d58584d72ae3 100644 --- a/tensorflow/core/data/service/validate_utils.cc +++ b/tensorflow/core/data/service/validate_utils.cc @@ -47,7 +47,7 @@ Status ValidateElementSpec(const std::string& dataset_id, const std::string& encoded_spec1, const std::string& encoded_spec2) { if (encoded_spec1.empty() && encoded_spec2.empty()) { - return OkStatus(); + return absl::OkStatus(); } TF_ASSIGN_OR_RETURN(StructuredValue element_spec1, DecodeElementSpec(dataset_id, encoded_spec1)); @@ -67,7 +67,7 @@ Status ValidateElementSpec(const std::string& dataset_id, ". To fix this error, make sure you're registering the same dataset ", "with the same ID."); } - return OkStatus(); + return absl::OkStatus(); } Status ValidateDatasetMetadata(const std::string& dataset_id, @@ -89,7 +89,7 @@ Status ValidateDatasetMetadata(const std::string& dataset_id, "for dataset ID ", dataset_id, ": ", diff, ". To fix this error, make ", "sure you're registering the same dataset with the same ID."); } - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/core/data/service/worker_client.cc b/tensorflow/core/data/service/worker_client.cc index 9e633849ae2eb9..344a5fa0cbd223 100644 --- a/tensorflow/core/data/service/worker_client.cc +++ b/tensorflow/core/data/service/worker_client.cc @@ -77,11 +77,11 @@ Status DataServiceWorkerClient::GetElement(const GetElementRequest& req, Status DataServiceWorkerClient::EnsureInitialized() { mutex_lock l(mu_); if (client_) { - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR(DataTransferClient::Build( GetDataTransferProtocol(), {protocol_, address_}, &client_)); - return OkStatus(); + return absl::OkStatus(); } std::string DataServiceWorkerClient::GetDataTransferProtocol() const { @@ -153,7 +153,7 @@ class GrpcDataTransferClient : public DataTransferClient { case GetElementResponse::ELEMENT_NOT_SET: break; } - return OkStatus(); + return absl::OkStatus(); } void TryCancel() override { @@ -188,7 +188,7 @@ class GrpcTransferClientRegistrar { config.protocol, &credentials)); *out = std::make_unique(credentials, config.address); - return OkStatus(); + return absl::OkStatus(); }); } }; @@ -234,7 +234,7 @@ class LocalDataTransferClient : public DataTransferClient { return errors::Cancelled(absl::Substitute( "Client for worker $0 has been cancelled.", worker_address_)); } - return OkStatus(); + return absl::OkStatus(); } StatusOr> GetWorker( @@ -263,7 +263,7 @@ class LocalTransferClientRegistrar { kLocalTransferProtocol, [](DataTransferClient::Config config, std::unique_ptr* out) { *out = std::make_unique(config.address); - return OkStatus(); + return absl::OkStatus(); }); } }; diff --git a/tensorflow/core/data/service/worker_impl.cc b/tensorflow/core/data/service/worker_impl.cc index 41e3b5775971c0..1f74795da60770 100644 --- a/tensorflow/core/data/service/worker_impl.cc +++ b/tensorflow/core/data/service/worker_impl.cc @@ -90,7 +90,7 @@ Status MoveElementToResponse(std::vector&& element, UncompressedElement* uncompressed = resp.mutable_uncompressed(); component.AsProtoTensorContent(uncompressed->add_components()); } - return OkStatus(); + return absl::OkStatus(); } Variant& variant = element[0].scalar()(); CompressedElement* compressed = variant.get(); @@ -101,7 +101,7 @@ Status MoveElementToResponse(std::vector&& element, variant.TypeName()); } *resp.mutable_compressed() = *compressed; - return OkStatus(); + return absl::OkStatus(); } WorkerConfig ApplyWorkerDefaults(const WorkerConfig& config) { @@ -187,7 +187,8 @@ Status DataServiceWorkerImpl::Start( should_retry, "Worker heartbeat.", /*deadline_micros=*/kint64max)); LOG(INFO) << "Worker registered with dispatcher running at " - << config_.dispatcher_address(); + << config_.dispatcher_address() + << ". Worker config: " << config_.DebugString(); task_completion_thread_ = absl::WrapUnique( Env::Default()->StartThread({}, "data-service-worker-task-completion", [this]() { TaskCompletionThread(); })); @@ -195,7 +196,7 @@ Status DataServiceWorkerImpl::Start( {}, "data-service-worker-heartbeat", [this]() { HeartbeatThread(); })); mutex_lock l(mu_); registered_ = true; - return OkStatus(); + return absl::OkStatus(); } void DataServiceWorkerImpl::Stop() { @@ -236,7 +237,7 @@ Status DataServiceWorkerImpl::ValidateWorkerConfig() const { config_.worker_tags().end(), ", "), "}"); } - return OkStatus(); + return absl::OkStatus(); } StatusOr> @@ -283,7 +284,7 @@ Status DataServiceWorkerImpl::GetElementResult( VLOG(3) << "Task is already finished"; result->end_of_sequence = true; result->skip = false; - return OkStatus(); + return absl::OkStatus(); } // Perhaps the worker hasn't gotten the task from the dispatcher yet. // Return Unavailable so that the client knows to continue retrying. @@ -306,7 +307,7 @@ Status DataServiceWorkerImpl::GetElementResult( pending_completed_tasks_.insert(request->task_id()); task_completion_cv_.notify_one(); } - return OkStatus(); + return absl::OkStatus(); } Status DataServiceWorkerImpl::ProcessTask(const ProcessTaskRequest* request, @@ -323,13 +324,13 @@ Status DataServiceWorkerImpl::ProcessTaskInternal(const TaskDef& task_def) if (task) { VLOG(1) << "Received request to process already-processed task " << task->task_def.task_id(); - return OkStatus(); + return absl::OkStatus(); } task = std::make_unique(task_def); VLOG(3) << "Began processing for task " << task_def.task_id() << " with processing mode " << task_def.processing_mode_def().DebugString(); - return OkStatus(); + return absl::OkStatus(); } Status DataServiceWorkerImpl::EnsureTaskInitialized( @@ -342,7 +343,7 @@ Status DataServiceWorkerImpl::EnsureTaskInitialized( mutex_lock l(task.mu); if (task.initialized) { - return OkStatus(); + return absl::OkStatus(); } TF_ASSIGN_OR_RETURN(DatasetDef dataset_def, GetDatasetDef(task.task_def)); TF_ASSIGN_OR_RETURN(std::unique_ptr dataset, @@ -356,7 +357,7 @@ Status DataServiceWorkerImpl::EnsureTaskInitialized( task.initialized = true; VLOG(3) << "Created iterator for task " << task.task_def.task_id(); - return OkStatus(); + return absl::OkStatus(); } StatusOr DataServiceWorkerImpl::GetDatasetDef( @@ -487,7 +488,7 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request, MoveElementToResponse(std::move(result.components), *response)); VLOG(3) << "Producing an element for task " << request->task_id(); } - return OkStatus(); + return absl::OkStatus(); } Status DataServiceWorkerImpl::GetWorkerTasks( @@ -500,7 +501,7 @@ Status DataServiceWorkerImpl::GetWorkerTasks( task_info->set_task_id(task->task_def.task_id()); task_info->set_iteration_id(task->task_def.iteration_id()); } - return OkStatus(); + return absl::OkStatus(); } Status DataServiceWorkerImpl::GetSnapshotTaskProgresses( @@ -509,7 +510,7 @@ Status DataServiceWorkerImpl::GetSnapshotTaskProgresses( for (const auto& snapshot_task_progress : GetSnapshotTaskProgress()) { *response->add_snapshot_task_progresses() = snapshot_task_progress; } - return OkStatus(); + return absl::OkStatus(); } void DataServiceWorkerImpl::TaskCompletionThread() TF_LOCKS_EXCLUDED(mu_) { @@ -519,7 +520,7 @@ void DataServiceWorkerImpl::TaskCompletionThread() TF_LOCKS_EXCLUDED(mu_) { while (!cancelled_ && pending_completed_tasks_.empty()) { task_completion_cv_.wait(l); } - if (cancelled_) { + if (cancelled_ && pending_completed_tasks_.empty()) { VLOG(3) << "Task completion thread shutting down"; return; } @@ -528,7 +529,10 @@ void DataServiceWorkerImpl::TaskCompletionThread() TF_LOCKS_EXCLUDED(mu_) { if (!s.ok()) { LOG(WARNING) << "Failed to send task updates to dispatcher: " << s; mutex_lock l(mu_); - if (!cancelled_) { + if (cancelled_) { + VLOG(3) << "Task completion thread shutting down"; + return; + } else { task_completion_cv_.wait_for( l, absl::ToChronoMicroseconds(kRetryInterval)); } @@ -556,7 +560,7 @@ Status DataServiceWorkerImpl::SendTaskUpdates() TF_LOCKS_EXCLUDED(mu_) { pending_completed_tasks_.erase(update.task_id()); } VLOG(3) << "Sent " << task_progress.size() << " task updates "; - return OkStatus(); + return absl::OkStatus(); } void DataServiceWorkerImpl::HeartbeatThread() TF_LOCKS_EXCLUDED(mu_) { @@ -761,7 +765,7 @@ Status DataServiceWorkerImpl::UpdateSnapshotWriters( } } - return OkStatus(); + return absl::OkStatus(); } StatusOr> diff --git a/tensorflow/core/data/snapshot_utils.cc b/tensorflow/core/data/snapshot_utils.cc index d7b8a325f87014..a1dd6179cc8e50 100644 --- a/tensorflow/core/data/snapshot_utils.cc +++ b/tensorflow/core/data/snapshot_utils.cc @@ -146,7 +146,7 @@ Status TFRecordWriter::Initialize(tensorflow::Env* env) { record_writer_ = std::make_unique( dest_.get(), io::RecordWriterOptions::CreateRecordWriterOptions( /*compression_type=*/compression_type_)); - return OkStatus(); + return absl::OkStatus(); } Status TFRecordWriter::WriteTensors(const std::vector& tensors) { @@ -174,7 +174,7 @@ Status TFRecordWriter::WriteTensors(const std::vector& tensors) { TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto_serialized)); #endif // TF_CORD_SUPPORT } - return OkStatus(); + return absl::OkStatus(); } Status TFRecordWriter::Sync() { @@ -190,7 +190,7 @@ Status TFRecordWriter::Close() { record_writer_ = nullptr; dest_ = nullptr; } - return OkStatus(); + return absl::OkStatus(); } TFRecordWriter::~TFRecordWriter() { @@ -238,7 +238,7 @@ Status CustomWriter::Initialize(tensorflow::Env* env) { } } - return OkStatus(); + return absl::OkStatus(); } Status CustomWriter::WriteTensors(const std::vector& tensors) { @@ -321,7 +321,7 @@ Status CustomWriter::WriteTensors(const std::vector& tensors) { #endif // TF_CORD_SUPPORT TF_RETURN_IF_ERROR(WriteRecord(metadata_serialized)); TF_RETURN_IF_ERROR(WriteRecord(output)); - return OkStatus(); + return absl::OkStatus(); } Status CustomWriter::Sync() { return dest_->Sync(); } @@ -335,7 +335,7 @@ Status CustomWriter::Close() { TF_RETURN_IF_ERROR(zlib_underlying_dest_->Close()); zlib_underlying_dest_ = nullptr; } - return OkStatus(); + return absl::OkStatus(); } CustomWriter::~CustomWriter() { @@ -392,7 +392,7 @@ Status Reader::SkipRecords(int64_t num_records) { std::vector unused_tensors; TF_RETURN_IF_ERROR(ReadTensors(&unused_tensors)); } - return OkStatus(); + return absl::OkStatus(); } class Reader::Dataset : public DatasetBase { @@ -419,10 +419,10 @@ class Reader::Dataset : public DatasetBase { std::string DebugString() const override { return "SnapshotDatasetReader"; } Status InputDatasets(std::vector* inputs) const override { - return OkStatus(); + return absl::OkStatus(); } - Status CheckExternalState() const override { return OkStatus(); } + Status CheckExternalState() const override { return absl::OkStatus(); } protected: Status AsGraphDefInternal(SerializationContext* ctx, @@ -486,7 +486,7 @@ class Reader::Dataset : public DatasetBase { Status status = AdvanceToNextFile(ctx->env()); if (absl::IsNotFound(status)) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } return status; } @@ -497,7 +497,7 @@ class Reader::Dataset : public DatasetBase { current_checkpoint_id_)); TF_RETURN_IF_ERROR( writer->WriteScalar(full_name(kStartIndex), start_index_)); - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -533,7 +533,7 @@ class Reader::Dataset : public DatasetBase { std::vector unused; TF_RETURN_IF_ERROR(reader_->ReadTensors(&unused)); } - return OkStatus(); + return absl::OkStatus(); } std::unique_ptr reader_; @@ -596,10 +596,10 @@ class Reader::NestedDataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->clear(); - return OkStatus(); + return absl::OkStatus(); } - Status CheckExternalState() const override { return OkStatus(); } + Status CheckExternalState() const override { return absl::OkStatus(); } protected: Status AsGraphDefInternal(SerializationContext* ctx, @@ -616,7 +616,7 @@ class Reader::NestedDataset : public DatasetBase { b->AddDataset(this, /*inputs=*/{}, /*list_inputs=*/{std::make_pair(0, input_graph_nodes)}, /*attrs=*/{}, node)); - return OkStatus(); + return absl::OkStatus(); } std::unique_ptr MakeIteratorInternal( @@ -651,19 +651,19 @@ class Reader::NestedDataset : public DatasetBase { index_++; } - return OkStatus(); + return absl::OkStatus(); } Status SaveInternal(SerializationContext* ctx, IteratorStateWriter* writer) override { TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIndex), index_)); - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kIndex), &index_)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -725,7 +725,7 @@ Status Reader::MakeNestedDataset(Env* env, datasets.end()); } MakeNestedDataset(datasets, output); - return OkStatus(); + return absl::OkStatus(); } void Reader::MakeNestedDataset(const std::vector& datasets, @@ -758,20 +758,20 @@ Status TFRecordReaderImpl::Initialize(Env* env) { #endif // IS_SLIM_BUILD record_reader_ = std::make_unique(file_.get(), options); bytes_read_ = 0; - return OkStatus(); + return absl::OkStatus(); } -StatusOr TFRecordReaderImpl::GetNext() { +absl::StatusOr TFRecordReaderImpl::GetNext() { tstring record; TF_RETURN_IF_ERROR(record_reader_->ReadRecord(&offset_, &record)); bytes_read_ += record.size(); return Parse(record); } -StatusOr> TFRecordReaderImpl::GetTensors() { +absl::StatusOr> TFRecordReaderImpl::GetTensors() { std::vector tensors; while (true) { - StatusOr tensor = GetNext(); + absl::StatusOr tensor = GetNext(); if (absl::IsOutOfRange(tensor.status())) { return tensors; } @@ -781,7 +781,7 @@ StatusOr> TFRecordReaderImpl::GetTensors() { return tensors; } -StatusOr TFRecordReaderImpl::Parse(const tstring& record) { +absl::StatusOr TFRecordReaderImpl::Parse(const tstring& record) { TensorProto proto; if (!proto.ParseFromArray(record.data(), record.size())) { return errors::DataLoss( @@ -805,7 +805,7 @@ Status TFRecordReader::ReadTensors(std::vector* read_tensors) { TF_ASSIGN_OR_RETURN(Tensor tensor, reader_impl_.GetNext()); read_tensors->push_back(std::move(tensor)); } - return OkStatus(); + return absl::OkStatus(); } CustomReader::CustomReader(const std::string& filename, @@ -855,7 +855,7 @@ Status CustomReader::Initialize(Env* env) { } } - return OkStatus(); + return absl::OkStatus(); } Status CustomReader::ReadTensors(std::vector* read_tensors) { @@ -909,7 +909,7 @@ Status CustomReader::ReadTensors(std::vector* read_tensors) { complex_index++; } } - return OkStatus(); + return absl::OkStatus(); } Status CustomReader::ReadTensorsV0(std::vector* read_tensors) { @@ -930,7 +930,7 @@ Status CustomReader::ReadTensorsV0(std::vector* read_tensors) { return errors::DataLoss("Unable to parse tensor from proto."); } } - return OkStatus(); + return absl::OkStatus(); } Status CustomReader::SnappyUncompress( @@ -980,7 +980,7 @@ Status CustomReader::SnappyUncompress( iov.data(), num_tensors)) { return errors::Internal("Failed to perform snappy decompression."); } - return OkStatus(); + return absl::OkStatus(); } Status CustomReader::ReadRecord(tstring* record) { @@ -1003,7 +1003,7 @@ Status CustomReader::ReadRecord(absl::Cord* record) { absl::string_view tmp_str_view(*tmp_str); record->Append(absl::MakeCordFromExternal( tmp_str_view, [tmp_str](absl::string_view) { delete tmp_str; })); - return OkStatus(); + return absl::OkStatus(); } } #endif // TF_CORD_SUPPORT @@ -1039,7 +1039,7 @@ Status ReadMetadataFile(Env* env, const string& dir, if (*file_exists) { return ReadBinaryProto(env, metadata_filename, metadata); } else { - return OkStatus(); + return absl::OkStatus(); } } @@ -1053,7 +1053,7 @@ Status ReadMetadataFile(Env* env, const string& dir, if (*file_exists) { return ReadBinaryProto(env, metadata_filename, metadata); } else { - return OkStatus(); + return absl::OkStatus(); } } @@ -1080,30 +1080,30 @@ Status DetermineOpState(const std::string& mode_string, bool file_exists, } LOG(INFO) << "Overriding mode to reader."; *mode = READER; - return OkStatus(); + return absl::OkStatus(); } if (mode_string == kModeWrite) { LOG(INFO) << "Overriding mode to writer."; *mode = WRITER; - return OkStatus(); + return absl::OkStatus(); } if (mode_string == kModePassthrough) { LOG(INFO) << "Overriding mode to passthrough."; *mode = PASSTHROUGH; - return OkStatus(); + return absl::OkStatus(); } if (!file_exists) { *mode = WRITER; - return OkStatus(); + return absl::OkStatus(); } if (metadata->finalized()) { // File found, snapshot has been finalized. *mode = READER; - return OkStatus(); + return absl::OkStatus(); } int64_t expiration_timer = static_cast(EnvTime::NowMicros()) - @@ -1112,11 +1112,11 @@ Status DetermineOpState(const std::string& mode_string, bool file_exists, if (metadata->creation_timestamp() >= expiration_timer) { // Someone else is already writing and time has not expired. *mode = PASSTHROUGH; - return OkStatus(); + return absl::OkStatus(); } else { // Time has expired, we write regardless. *mode = WRITER; - return OkStatus(); + return absl::OkStatus(); } } @@ -1179,7 +1179,7 @@ Status AsyncWriter::WriterThread(Env* env, const std::string& shard_directory, TF_RETURN_IF_ERROR(writer->WriteTensors(be.value)); } - return OkStatus(); + return absl::OkStatus(); } namespace { diff --git a/tensorflow/core/data/snapshot_utils.h b/tensorflow/core/data/snapshot_utils.h index d6cec998495e4c..7e3b897bdddb0d 100644 --- a/tensorflow/core/data/snapshot_utils.h +++ b/tensorflow/core/data/snapshot_utils.h @@ -262,17 +262,17 @@ class TFRecordReaderImpl { Status Initialize(Env* env); // Reads the next Tensor in the input file. - StatusOr GetNext(); + absl::StatusOr GetNext(); // Reads all Tensors in the input file. - StatusOr> GetTensors(); + absl::StatusOr> GetTensors(); // Returns the number of bytes read. uint64_t BytesRead() const { return bytes_read_; } private: // Parses `record` into a Tensor. - StatusOr Parse(const tstring& record); + absl::StatusOr Parse(const tstring& record); std::string filename_; std::unique_ptr file_; diff --git a/tensorflow/core/data/split_utils_test.cc b/tensorflow/core/data/split_utils_test.cc index 15b61e125a0eb7..060f2b99b75018 100644 --- a/tensorflow/core/data/split_utils_test.cc +++ b/tensorflow/core/data/split_utils_test.cc @@ -39,7 +39,7 @@ Status SaveAndRestore(SplitProvider* split_provider) { writer.GetData(&variants); VariantTensorDataReader reader(variants); TF_RETURN_IF_ERROR(split_provider->Restore(full_name, &reader)); - return OkStatus(); + return absl::OkStatus(); } Status CheckOutput(SplitProvider* split_provider, @@ -54,7 +54,7 @@ Status CheckOutput(SplitProvider* split_provider, } } EXPECT_EQ(next, expected.size()); - return OkStatus(); + return absl::OkStatus(); } TEST(IndexSplitProviderTest, Empty) { diff --git a/tensorflow/core/data/standalone.cc b/tensorflow/core/data/standalone.cc index 04a425170b27be..1d80acceae95ac 100644 --- a/tensorflow/core/data/standalone.cc +++ b/tensorflow/core/data/standalone.cc @@ -34,6 +34,8 @@ limitations under the License. #include "tensorflow/core/data/dataset_utils.h" #include "tensorflow/core/data/root_dataset.h" #include "tensorflow/core/data/serialization_utils.h" +#include "tensorflow/core/data/tf_data_memory_logger.h" +#include "tensorflow/core/data/tfdataz_metrics.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/device.h" #include "tensorflow/core/framework/device_factory.h" @@ -76,13 +78,27 @@ OpKernelContext::Params CreateParams( Iterator::Iterator(IteratorBase* iterator, IteratorContext* ctx, SerializationContext* serialization_ctx) - : iterator_(iterator), ctx_(ctx), serialization_ctx_(serialization_ctx) {} + : iterator_(iterator), ctx_(ctx), serialization_ctx_(serialization_ctx) { + if (DatasetBaseIterator* dataset_iterator = + dynamic_cast(iterator_.get())) { + tf_dataz_metrics_collector_ = std::make_shared( + *Env::Default(), dataset_iterator, ctx_->model()); + TfDatazMetricsRegistry::Register(tf_dataz_metrics_collector_); + EnsureIteratorMemoryLoggerStarted(); + } +} + +Iterator::~Iterator() { + if (tf_dataz_metrics_collector_) { + TfDatazMetricsRegistry::Deregister(tf_dataz_metrics_collector_); + } +} Status Iterator::GetNext(std::vector* outputs, bool* end_of_input) { return iterator_->GetNext(ctx_.get(), outputs, end_of_input); } -StatusOr> Iterator::Save() { +absl::StatusOr> Iterator::Save() { VariantTensorDataWriter writer; TF_RETURN_IF_ERROR(iterator_->Save(serialization_ctx_.get(), &writer)); std::vector> data; @@ -141,7 +157,7 @@ Status Dataset::FromGraph(Params params, const GraphDef& graph_def, tsl::core::RefCountPtr* r) { *r = tsl::core::RefCountPtr( new IntraProcessRendezvous(device_mgr)); - return OkStatus(); + return absl::OkStatus(); }}); string fetch_node = ""; @@ -176,7 +192,7 @@ Status Dataset::FromGraph(Params params, const GraphDef& graph_def, *result = absl::WrapUnique(new Dataset( finalized_dataset, dataset, device_mgr.release(), pflr.release(), flib_def.release(), pool.release(), std::move(runner))); - return OkStatus(); + return absl::OkStatus(); } // static Status Dataset::MakeIterator( @@ -215,7 +231,7 @@ Status Dataset::MakeIterator( ctx.get(), /*parent=*/nullptr, "Iterator", &iterator)); *result = absl::WrapUnique(new Iterator(iterator.release(), ctx.release(), serialization_ctx.release())); - return OkStatus(); + return absl::OkStatus(); } Status Dataset::MakeIterator(std::unique_ptr* result) { diff --git a/tensorflow/core/data/standalone.h b/tensorflow/core/data/standalone.h index 5de0d81b274b30..c2e257953a1b29 100644 --- a/tensorflow/core/data/standalone.h +++ b/tensorflow/core/data/standalone.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/data/tfdataz_metrics.h" #include "tensorflow/core/data/unbounded_thread_pool.h" #include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/dataset.h" @@ -79,13 +80,15 @@ class Dataset; // its elements. class Iterator { public: + virtual ~Iterator(); + // Returns the next element of the input pipeline (if there is one) and an // indication of whether the end of the input pipeline has been reached. Status GetNext(std::vector* outputs, bool* end_of_input); // Saves a checkpoint of the iterator. Returns Tensors that can be called with // `Restore()`. - StatusOr> Save(); + absl::StatusOr> Save(); // Restores the iterator from a checkpoint. `saved_iterator` is the serialized // iterator saved by calling `Save()`. @@ -103,6 +106,7 @@ class Iterator { std::unique_ptr iterator_; std::unique_ptr ctx_; std::unique_ptr serialization_ctx_; + std::shared_ptr tf_dataz_metrics_collector_; }; // Represents an input pipeline as a collection of data sources and a logical diff --git a/tensorflow/core/data/standalone_save_restore_test.cc b/tensorflow/core/data/standalone_save_restore_test.cc index 5cf78f45dddc70..fd163691d71f88 100644 --- a/tensorflow/core/data/standalone_save_restore_test.cc +++ b/tensorflow/core/data/standalone_save_restore_test.cc @@ -44,7 +44,7 @@ class TestDataset { Dataset::FromGraph(Dataset::Params(), dataset_def.graph(), &dataset_)); } - StatusOr> MakeIterator() const { + absl::StatusOr> MakeIterator() const { std::unique_ptr iterator; TF_RETURN_IF_ERROR(dataset_->MakeIterator(&iterator)); return iterator; diff --git a/tensorflow/core/data/test_utils.cc b/tensorflow/core/data/test_utils.cc index 577d1401fb052f..b1c4cfc5b27ec4 100644 --- a/tensorflow/core/data/test_utils.cc +++ b/tensorflow/core/data/test_utils.cc @@ -35,7 +35,7 @@ limitations under the License. namespace tensorflow { namespace data { -StatusOr> TestContext::Create() { +absl::StatusOr> TestContext::Create() { auto ctx = std::unique_ptr(new TestContext()); SessionOptions options; auto* device_count = options.config.mutable_device_count(); diff --git a/tensorflow/core/data/test_utils.h b/tensorflow/core/data/test_utils.h index 509e36316954da..61da1807ee4e3a 100644 --- a/tensorflow/core/data/test_utils.h +++ b/tensorflow/core/data/test_utils.h @@ -30,7 +30,7 @@ namespace data { class TestContext { public: - static StatusOr> Create(); + static absl::StatusOr> Create(); virtual ~TestContext() = default; OpKernelContext* op_ctx() const { return op_ctx_.get(); } diff --git a/tensorflow/core/data/tf_data_memory_logger.cc b/tensorflow/core/data/tf_data_memory_logger.cc index a313e435a0bd76..f1edc3fe9d08a8 100644 --- a/tensorflow/core/data/tf_data_memory_logger.cc +++ b/tensorflow/core/data/tf_data_memory_logger.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/data/tfdataz_metrics.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/numbers.h" +#include "tensorflow/core/platform/status.h" #include "tsl/platform/logging.h" namespace tensorflow { @@ -37,6 +38,7 @@ const int64_t kLogFrequencyS = 30; // How often to log. struct IteratorMemoryUsage { std::optional dataset_name; int64_t memory_usage; + std::string model_proto; }; int64_t TotalMemoryUsage(const std::vector& usages) { @@ -52,9 +54,16 @@ void LogDatasetMemoryUsage() { metric_collectors = TfDatazMetricsRegistry::GetIteratorMetricCollectors(); std::vector usages; for (const auto& metric_collector : metric_collectors) { - usages.push_back( - IteratorMemoryUsage{metric_collector->DatasetName(), - metric_collector->GetIteratorTotalMemoryUsage()}); + int64_t total_buffered_bytes = + metric_collector->GetModel()->output()->TotalBufferedBytes(); + model::ModelProto model_proto; + Status s = metric_collector->GetModel()->ToProto(&model_proto); + if (!s.ok()) { + LOG(ERROR) << "Failed to convert model to proto: " << s; + } + usages.push_back(IteratorMemoryUsage{metric_collector->DatasetName(), + total_buffered_bytes, + model_proto.ShortDebugString()}); } std::sort(usages.begin(), usages.end(), [](const auto& a, const auto& b) { return a.memory_usage > b.memory_usage; @@ -75,6 +84,7 @@ void LogDatasetMemoryUsage() { } else { VLOG(4) << "Dataset " << i << " (no name set): " << usage_string; } + VLOG(5) << "Model proto: " << usages[i].model_proto; } } diff --git a/tensorflow/core/data/tfdataz_metrics.cc b/tensorflow/core/data/tfdataz_metrics.cc index 2ba689d5818fe1..0bb251310c8273 100644 --- a/tensorflow/core/data/tfdataz_metrics.cc +++ b/tensorflow/core/data/tfdataz_metrics.cc @@ -16,16 +16,18 @@ limitations under the License. #include #include -#include #include #include #include #include -#include +#include "absl/container/flat_hash_set.h" #include "absl/time/time.h" #include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/model.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" +#include "tsl/platform/thread_annotations.h" namespace tensorflow { namespace data { @@ -93,9 +95,10 @@ absl::Duration ApproximateLatencyEstimator::GetAverageLatency(Duration duration) return absl::Duration(absl::Microseconds(interval_latency)) / interval_count; } -TfDatazMetricsCollector::TfDatazMetricsCollector(const Env& env, - DatasetBaseIterator* iterator) - : iterator_(iterator), latency_estimator_(env) {} +TfDatazMetricsCollector::TfDatazMetricsCollector( + const Env& env, DatasetBaseIterator* iterator, + std::shared_ptr model) + : iterator_(iterator), model_(std::move(model)), latency_estimator_(env) {} void TfDatazMetricsCollector::RecordGetNextLatency( int64_t get_next_latency_usec) { @@ -131,6 +134,10 @@ int64_t TfDatazMetricsCollector::GetIteratorTotalMemoryUsage() { return iterator_->TotalBufferedBytes(); } +std::shared_ptr TfDatazMetricsCollector::GetModel() { + return model_; +} + namespace { static mutex* get_tfdataz_metrics_registry_lock() { static mutex tfdataz_metrics_registry_lock(LINKER_INITIALIZED); diff --git a/tensorflow/core/data/tfdataz_metrics.h b/tensorflow/core/data/tfdataz_metrics.h index dc202f6ca784c4..e37daf89c6a47d 100644 --- a/tensorflow/core/data/tfdataz_metrics.h +++ b/tensorflow/core/data/tfdataz_metrics.h @@ -23,6 +23,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/time/time.h" #include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/model.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -94,7 +95,8 @@ class TfDatazMetricsCollector { // We only collect metrics for CPU devices. This is a heuristic to avoid // collecting metrics for device-side iterators created by the multi-device // iterator mechanism. - TfDatazMetricsCollector(const Env& env, DatasetBaseIterator* iterator); + TfDatazMetricsCollector(const Env& env, DatasetBaseIterator* iterator, + std::shared_ptr model); // Records `GetNext` call latency. void RecordGetNextLatency(int64_t get_next_latency_usec); @@ -116,8 +118,11 @@ class TfDatazMetricsCollector { // buffered in all nodes in the subtree. int64_t GetIteratorTotalMemoryUsage(); + std::shared_ptr GetModel(); + private: DatasetBaseIterator* iterator_; // not owned + std::shared_ptr model_; ApproximateLatencyEstimator latency_estimator_; }; diff --git a/tensorflow/core/data/tfdataz_metrics_test.cc b/tensorflow/core/data/tfdataz_metrics_test.cc index a1fcc540125936..0a75ed4fe6d877 100644 --- a/tensorflow/core/data/tfdataz_metrics_test.cc +++ b/tensorflow/core/data/tfdataz_metrics_test.cc @@ -41,8 +41,8 @@ class TfDatazMetricsTest : public ::testing::Test { protected: void SetUp() override { env_ = std::make_unique(Env::Default()); - tfdataz_metrics_ = - std::make_unique(*env_, iterator_.get()); + tfdataz_metrics_ = std::make_unique( + *env_, iterator_.get(), /*model=*/nullptr); } void TearDown() override { @@ -200,9 +200,9 @@ class ScopedTfDataMetricsRegistration { TEST(TfDatazMetricsRegistryTest, Register) { std::unique_ptr iterator; auto collector_one = std::make_shared( - *Env::Default(), iterator.get()); + *Env::Default(), iterator.get(), /*model=*/nullptr); auto collector_two = std::make_shared( - *Env::Default(), iterator.get()); + *Env::Default(), iterator.get(), /*model=*/nullptr); ScopedTfDataMetricsRegistration scoped_registration_one(collector_one); ScopedTfDataMetricsRegistration scoped_registration_two(collector_two); @@ -213,11 +213,11 @@ TEST(TfDatazMetricsRegistryTest, Register) { TEST(TfDatazMetricsRegistryTest, Deregister) { std::unique_ptr iterator; auto collector_one = std::make_shared( - *Env::Default(), iterator.get()); + *Env::Default(), iterator.get(), /*model=*/nullptr); auto collector_two = std::make_shared( - *Env::Default(), iterator.get()); + *Env::Default(), iterator.get(), /*model=*/nullptr); auto collector_three = std::make_shared( - *Env::Default(), iterator.get()); + *Env::Default(), iterator.get(), /*model=*/nullptr); ScopedTfDataMetricsRegistration scoped_registration_one(collector_one); ScopedTfDataMetricsRegistration scoped_registration_two(collector_two); ScopedTfDataMetricsRegistration scoped_registration_three(collector_three); diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc index b31dfa4f93f61c..87addeeceddc0d 100644 --- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc +++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc @@ -129,7 +129,7 @@ Status BaseRemoteRendezvous::Initialize(WorkerSession* session) { if (session_ != nullptr) { if (session_->worker_name() == session->worker_name()) { VLOG(1) << "Skipping rendezvous re-initialization."; - return OkStatus(); + return absl::OkStatus(); } Status s = errors::Internal( "Double init! Worker names would have changed from: ", @@ -143,7 +143,7 @@ Status BaseRemoteRendezvous::Initialize(WorkerSession* session) { for (auto& call : deferred_calls) { RecvLocalAsyncInternal(call.parsed, std::move(call.done)); } - return OkStatus(); + return absl::OkStatus(); } WorkerSession* BaseRemoteRendezvous::session() { @@ -203,7 +203,7 @@ Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed, "Invalid rendezvous key (dst): ", parsed.FullKey(), " @ ", sess->worker_name()); } - return OkStatus(); + return absl::OkStatus(); } void BaseRemoteRendezvous::SameWorkerRecvDone( @@ -218,7 +218,7 @@ void BaseRemoteRendezvous::SameWorkerRecvDone( (recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU"); if (src_host && dst_host) { *out = in; - done(OkStatus()); + done(absl::OkStatus()); return; } diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc index e4806099683b8f..47dbf7c35a9670 100644 --- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc +++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc @@ -169,7 +169,7 @@ Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph( // from the library. *(gdef->mutable_library()) = flib_def.ReachableDefinitions(*gdef).ToProto(); - return OkStatus(); + return absl::OkStatus(); } ClusterFunctionLibraryRuntime::~ClusterFunctionLibraryRuntime() { @@ -209,7 +209,7 @@ void ClusterFunctionLibraryRuntime::Instantiate( const OpDef& sig = fdef->signature(); TF_RETURN_IF_ERROR(ConstructFunctionGraph(sig, attrs, options, *lib_def, &gdef, send_keys, recv_keys)); - return OkStatus(); + return absl::OkStatus(); }; Status s; if (options.lib_def) { diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc b/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc index 8a3df302cd9566..899cc193ee87d1 100644 --- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc +++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc @@ -120,7 +120,7 @@ class ClusterFunctionLibraryRuntimeTest : public ::testing::Test { *rets[i] = out[i]; } - return OkStatus(); + return absl::OkStatus(); } protected: diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc index caffcca302b800..92f7251d3ad2c5 100644 --- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc +++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc @@ -269,7 +269,7 @@ Status CollectiveParamResolverDistributed::UpdateGroupCache( absl::CEscape(previous_gr->group.runtime_details.communicator_key)); } } - return OkStatus(); + return absl::OkStatus(); } void CollectiveParamResolverDistributed::CompleteGroupDistributed( diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc index 6a6e6f455603fe..d699b1d2275a2c 100644 --- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc +++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc @@ -39,7 +39,7 @@ static std::unique_ptr NewDevice(const string& type, class FakeDevice : public Device { public: explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {} - Status Sync() override { return OkStatus(); } + Status Sync() override { return absl::OkStatus(); } Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; } }; DeviceAttributes attr; @@ -83,7 +83,7 @@ class FakeCache : public TestWorkerCache { for (const auto& it : resp.device_attributes()) { if (it.name() == device) { *locality = it.locality(); - done(OkStatus()); + done(absl::OkStatus()); return; } } @@ -98,7 +98,7 @@ class FakeNcclCommunicator : public NcclCommunicatorInterface { void Enqueue(std::shared_ptr col_ctx, StatusCallback done) override { - done(OkStatus()); + done(absl::OkStatus()); } void StartAbort(const Status& s) override {} diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc index b97aff64bbd2a5..984a5bda243c82 100644 --- a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc +++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc @@ -85,7 +85,7 @@ Status PopulateTensorFromResponse(const RecvBufResponse& response, // If there are no transport options, then the tensor has already been // copied into request.buf_ptr. - if (!has_transport_options) return OkStatus(); + if (!has_transport_options) return absl::OkStatus(); const int64_t total_bytes = cpu_tensor->TotalBytes(); int64_t num_bytes = 0; @@ -101,7 +101,7 @@ Status PopulateTensorFromResponse(const RecvBufResponse& response, " bytes, expected: ", cpu_tensor->TotalBytes()); } PopulateTensorFromExtra(extra, cpu_tensor); - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -236,7 +236,7 @@ void CollectiveRemoteAccessDistributed::CheckPeerHealth( const StatusCallback& done) { if (peer_task == task_name_) { // Fast path if the peer is the worker itself. - done(OkStatus()); + done(absl::OkStatus()); return; } // We send a GetStatus RPC to check the health of a peer task. If the RPC @@ -282,7 +282,7 @@ void CollectiveRemoteAccessDistributed::CheckPeerHealth( // Skip validating device incarnation if we don't know what the // incarnation should be. The device attribute is cached after the // first collective. - s = OkStatus(); + s = absl::OkStatus(); } delete opts; delete req; diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc index bd9d5f100d8758..8ece4bf2f0fe70 100644 --- a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc +++ b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc @@ -60,7 +60,7 @@ static std::unique_ptr NewDevice(const string& type, const string& name, public: explicit FakeDevice(const DeviceAttributes& attr, Allocator* allocator) : Device(nullptr, attr), allocator_(allocator) {} - Status Sync() override { return OkStatus(); } + Status Sync() override { return absl::OkStatus(); } Allocator* GetAllocator(AllocatorAttributes) override { return allocator_; } private: @@ -104,7 +104,7 @@ class FakeWorker : public TestWorkerInterface { for (const auto& da : dev_attr) { *response->add_device_attributes() = da; } - done(OkStatus()); + done(absl::OkStatus()); } void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, @@ -202,7 +202,7 @@ class FakeCache : public TestWorkerCache { for (const auto& it : resp.device_attributes()) { if (it.name() == device) { *locality = it.locality(); - done(OkStatus()); + done(absl::OkStatus()); return; } } diff --git a/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy.cc b/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy.cc index a1c58b8b6ea756..952d38d852560f 100644 --- a/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy.cc +++ b/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy.cc @@ -60,7 +60,7 @@ std::pair BarrierProxy::Wait() { // We should have a mechanism to remove it after it has been passed. status_ = agent_->WaitAtBarrier(key_, timeout_, tasks_); } else { - status_ = OkStatus(); + status_ = absl::OkStatus(); } status_set_ = true; cv_.notify_all(); @@ -98,7 +98,7 @@ Status BarrierProxyManager::Wait(tsl::CoordinationServiceAgent* agent, int num_local_threads, absl::string_view key, absl::Duration timeout) { // Only one device, no need to wait. - if (tasks.size() == 1 && num_local_threads <= 1) return OkStatus(); + if (tasks.size() == 1 && num_local_threads <= 1) return absl::OkStatus(); profiler::TraceMe traceme([&] { return profiler::TraceMeEncode("BarrierProxyManager::Wait", diff --git a/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy_test.cc b/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy_test.cc index 9ff0473321a07c..464c82bfca3fd0 100644 --- a/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy_test.cc +++ b/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy_test.cc @@ -75,21 +75,22 @@ class MockCoordinationServiceAgent : public CoordinationServiceAgent { MOCK_METHOD(Status, WaitForAllTasks, (const DeviceInfo& local_devices), (override)); MOCK_METHOD(const DeviceInfo&, GetClusterDeviceInfo, (), (override)); - MOCK_METHOD(StatusOr, GetOwnTask, (), (override)); - MOCK_METHOD(StatusOr>, GetTaskState, - (const std::vector& task), (override)); + MOCK_METHOD(absl::StatusOr, GetOwnTask, (), (override)); + MOCK_METHOD(absl::StatusOr>, + GetTaskState, (const std::vector& task), + (override)); MOCK_METHOD(Status, ReportError, (const Status& error), (override)); MOCK_METHOD(Status, Shutdown, (), (override)); MOCK_METHOD(Status, Reset, (), (override)); - MOCK_METHOD(StatusOr, GetKeyValue, (std::string_view key), + MOCK_METHOD(absl::StatusOr, GetKeyValue, (std::string_view key), (override)); - MOCK_METHOD(StatusOr, GetKeyValue, + MOCK_METHOD(absl::StatusOr, GetKeyValue, (std::string_view key, absl::Duration timeout), (override)); MOCK_METHOD(std::shared_ptr, GetKeyValueAsync, (std::string_view key, StatusOrValueCallback done), (override)); - MOCK_METHOD(StatusOr, TryGetKeyValue, (std::string_view key), - (override)); - MOCK_METHOD(StatusOr>, GetKeyValueDir, + MOCK_METHOD(absl::StatusOr, TryGetKeyValue, + (std::string_view key), (override)); + MOCK_METHOD(absl::StatusOr>, GetKeyValueDir, (std::string_view key), (override)); MOCK_METHOD(void, GetKeyValueDirAsync, (std::string_view key, StatusOrValueDirCallback done), @@ -109,7 +110,7 @@ class MockCoordinationServiceAgent : public CoordinationServiceAgent { (override)); MOCK_METHOD(void, CancelBarrierAsync, (std::string_view barrier_id, StatusCallback done), (override)); - MOCK_METHOD(StatusOr, GetEnv, (), (override)); + MOCK_METHOD(absl::StatusOr, GetEnv, (), (override)); MOCK_METHOD(void, SetError, (const Status& error), (override)); MOCK_METHOD(Status, ActivateWatch, (std::string_view key, @@ -163,8 +164,8 @@ TEST(BarrierProxyTest, AllThreadsExitBarrier) { /*num_threads_planned=*/8, /*num_threads_entered=*/8, /*expected_ok_count=*/8, - /*agent_wait_status=*/OkStatus(), - /*expected_same_exit_status_for_all_threads=*/OkStatus()); + /*agent_wait_status=*/absl::OkStatus(), + /*expected_same_exit_status_for_all_threads=*/absl::OkStatus()); } TEST(BarrierProxyTest, AgentErrorBroadcastedToAllThreads) { @@ -184,7 +185,7 @@ TEST(BarrierProxyTest, AgentIsIgnoredIfThereIsOnlyOneTask) { /*num_threads_entered=*/8, /*expected_ok_count=*/8, /*agent_wait_status=*/{}, - /*expected_same_exit_status_for_all_threads=*/OkStatus()); + /*expected_same_exit_status_for_all_threads=*/absl::OkStatus()); } TEST(BarrierProxyTest, TimeoutIfNotEnoughThreadEntered) { @@ -204,7 +205,7 @@ TEST(BarrierProxyTest, ExtraThreadsEnteringTheBarrierGetErrors) { /*num_threads_planned=*/8, /*num_threads_entered=*/10, /*expected_ok_count=*/8, - /*agent_wait_status=*/OkStatus(), + /*agent_wait_status=*/absl::OkStatus(), /*expected_same_exit_status_for_all_threads=*/{}); } @@ -240,7 +241,7 @@ TEST(BarrierProxyManagerTest, AllThreadExited) { TestBarrierProxyManagerWaitSingleKey( /*num_threads_planned=*/8, /*num_threads_entered=*/8, - /*agent_wait_status=*/OkStatus(), + /*agent_wait_status=*/absl::OkStatus(), /*expected_ok_count=*/8); } @@ -264,7 +265,7 @@ TEST(BarrierProxyManagerTest, ExtraThreadsEnteringTheSameKeyGetErrors) { TestBarrierProxyManagerWaitSingleKey( /*num_threads_planned=*/8, /*num_threads_entered=*/10, - /*agent_wait_status=*/OkStatus(), + /*agent_wait_status=*/absl::OkStatus(), /*expected_ok_count=*/8); } @@ -275,16 +276,16 @@ TEST(BarrierProxyManagerTest, DifferentKeysDoNotInterfereWithEachOther) { BarrierProxyManager mgr; EXPECT_CALL(*agent, WaitAtBarrier("key0", kTestTimeout, _)) - .WillOnce(Return(OkStatus())); + .WillOnce(Return(absl::OkStatus())); EXPECT_CALL(*agent, WaitAtBarrier("key1", kTestTimeout, _)) - .WillOnce(Return(OkStatus())); + .WillOnce(Return(absl::OkStatus())); { thread::ThreadPool pool(Env::Default(), /*name=*/"TestPool", kThreadPoolSize); for (int i = 0; i < kNumThreads * 2; ++i) { pool.Schedule([&, key = absl::StrCat("key", i % 2)]() { ASSERT_EQ(mgr.Wait(agent.get(), tasks, kNumThreads, key, kTestTimeout), - OkStatus()); + absl::OkStatus()); }); } } diff --git a/tensorflow/core/distributed_runtime/device_resolver_distributed.cc b/tensorflow/core/distributed_runtime/device_resolver_distributed.cc index c560ef6cc1075a..948ddf2afc0419 100644 --- a/tensorflow/core/distributed_runtime/device_resolver_distributed.cc +++ b/tensorflow/core/distributed_runtime/device_resolver_distributed.cc @@ -35,7 +35,7 @@ Status DeviceResolverDistributed::GetDeviceAttributes( return errors::NotFound(device, " not found"); } *attributes = it->second; - return OkStatus(); + return absl::OkStatus(); } Status DeviceResolverDistributed::GetAllDeviceAttributes( @@ -51,7 +51,7 @@ Status DeviceResolverDistributed::GetAllDeviceAttributes( if (attributes->empty()) { return errors::NotFound(task, " not found in the cache"); } - return OkStatus(); + return absl::OkStatus(); } Status DeviceResolverDistributed::UpdateDeviceAttributes( @@ -70,7 +70,7 @@ Status DeviceResolverDistributed::UpdateDeviceAttributes( "This usually means the remote worker has restarted"); } } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/device_resolver_distributed_test.cc b/tensorflow/core/distributed_runtime/device_resolver_distributed_test.cc index 32bf0875c1248b..59fadf9e73971f 100644 --- a/tensorflow/core/distributed_runtime/device_resolver_distributed_test.cc +++ b/tensorflow/core/distributed_runtime/device_resolver_distributed_test.cc @@ -38,7 +38,7 @@ std::unique_ptr NewDevice(const string& type, const string& name) { class FakeDevice : public Device { public: explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {} - Status Sync() override { return OkStatus(); } + Status Sync() override { return absl::OkStatus(); } Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; } }; DeviceAttributes attr; diff --git a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc index ac5d1b2bb09022..3c5d53cb0dc6b1 100644 --- a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc +++ b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc @@ -249,7 +249,7 @@ void EagerClusterFunctionLibraryRuntime::Run( return; } } - done(OkStatus()); + done(absl::OkStatus()); }); } diff --git a/tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h b/tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h index 45aab6e5f840ea..ca5eaa2526f6cb 100644 --- a/tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h +++ b/tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h @@ -61,7 +61,7 @@ class DestroyTensorHandleNode : public tensorflow::AsyncEagerNode { "remote tensors handles: " << s.ToString(); } - done(OkStatus()); + done(absl::OkStatus()); delete response; }); } diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index f6f3bf1ee1668c..f5062d6b75ef42 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -92,7 +93,7 @@ Status GetNumRetvals(FunctionLibraryDefinition* func_lib_def, } } - return OkStatus(); + return absl::OkStatus(); } Status GetEagerOperationAndNumRetvals(const Operation& operation, @@ -167,7 +168,7 @@ Status TensorHandleProto(TensorHandle* handle, TensorProto* proto) { const tensorflow::Tensor* t = nullptr; TF_RETURN_IF_ERROR(handle->Tensor(&t)); t->AsProtoTensorContent(proto); - return OkStatus(); + return absl::OkStatus(); } Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) { @@ -184,7 +185,7 @@ Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) { shape.AsProto(proto); } - return OkStatus(); + return absl::OkStatus(); } Status AddOpRetvalsToResponse( @@ -245,7 +246,7 @@ Status ResetAgentAndConnectToCoordinationService( return s; } } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -281,6 +282,27 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, return tensorflow::errors::Internal( "invalid eager env_ or env_->rendezvous_mgr."); } + if (request->clear_existing_contexts()) { + // Cleanup state from WorkerEnv + for (auto* device : env_->device_mgr->ListDevices()) { + device->ClearResourceMgr(); + } + env_->rendezvous_mgr->CleanupAll(); + env_->collective_executor_mgr->CleanupAll(); + TF_RETURN_IF_ERROR(env_->session_mgr->DeleteAllSessions()); + + // Cleanup existing contexts if any. + std::unordered_map tmp_contexts; + { + mutex_lock l(contexts_mu_); + if (!contexts_.empty()) { + std::swap(tmp_contexts, contexts_); + } + } + for (auto& context : tmp_contexts) { + context.second->Unref(); + } + } tsl::core::RefCountPtr r = env_->rendezvous_mgr->Find(request->context_id()); @@ -395,7 +417,7 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, tsl::PreemptionNotifier::CreatePreemptionNotifier("sigterm", Env::Default()); preemption_notifier->WillBePreemptedAtAsync( - [coord_agent](StatusOr time_or_status) { + [coord_agent](absl::StatusOr time_or_status) { if (time_or_status.ok()) { const auto coord_task = coord_agent->GetOwnTask().value(); Status s = coord_agent->InsertKeyValue( @@ -431,7 +453,7 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, new ServerContext(ctx, request->keep_alive_secs(), env_)); } - return OkStatus(); + return absl::OkStatus(); } Status EagerServiceImpl::UpdateContext(const UpdateContextRequest* request, @@ -463,7 +485,7 @@ Status EagerServiceImpl::UpdateContext(const UpdateContextRequest* request, ctx->IncrementContextViewId(); VLOG(1) << "Processing simplified UpdateContextRequest on " << ctx->HostCPU()->name(); - return OkStatus(); + return absl::OkStatus(); } auto session_name = @@ -522,7 +544,7 @@ Status EagerServiceImpl::UpdateContext(const UpdateContextRequest* request, *response->add_device_attributes() = da; } - return OkStatus(); + return absl::OkStatus(); } Status EagerServiceImpl::CreateMasterContext( @@ -540,7 +562,7 @@ Status EagerServiceImpl::CreateMasterContext( ServerContext::CreateMasterContext(context, env_); mutex_lock l(contexts_mu_); contexts_.emplace(context_id, server_context); - return OkStatus(); + return absl::OkStatus(); } void EagerServiceImpl::RunComponentFunction( @@ -723,7 +745,7 @@ Status EagerServiceImpl::Enqueue(CallOptions* call_opts, } } - return OkStatus(); + return absl::OkStatus(); } Status EagerServiceImpl::WaitQueueDone(const WaitQueueDoneRequest* request, @@ -748,17 +770,15 @@ Status EagerServiceImpl::KeepAlive(const KeepAliveRequest* request, tensorflow::EagerContext* ctx = context->Context(); response->set_context_view_id(ctx->GetContextViewId()); - return OkStatus(); + return absl::OkStatus(); } Status EagerServiceImpl::CloseContext(const CloseContextRequest* request, CloseContextResponse* response) { - VLOG(1) << "Executing EagerService::CloseContext for context " - << request->context_id(); ServerContext* context = nullptr; if (!GetServerContext(request->context_id(), &context).ok()) { // Swallow the error here. - return OkStatus(); + return absl::OkStatus(); } core::ScopedUnref context_unref(context); @@ -768,7 +788,7 @@ Status EagerServiceImpl::CloseContext(const CloseContextRequest* request, << request->context_view_id() << " for context_id " << request->context_id() << ". The current context_view_id is " << context->Context()->GetContextViewId() << "."; - return OkStatus(); + return absl::OkStatus(); } mutex_lock l(contexts_mu_); @@ -779,7 +799,7 @@ Status EagerServiceImpl::CloseContext(const CloseContextRequest* request, // we are releasing it from the map. context->Unref(); - return OkStatus(); + return absl::OkStatus(); } Status EagerServiceImpl::RegisterFunction( @@ -804,7 +824,7 @@ Status EagerServiceImpl::RemoveFunction(const RemoveFunctionOp& remove_function, Status EagerServiceImpl::CleanupFunction( const CleanupFunctionOp& cleanup_function) { env_->rendezvous_mgr->Cleanup(cleanup_function.step_id()); - return OkStatus(); + return absl::OkStatus(); } Status EagerServiceImpl::SendTensor(const SendTensorOp& send_tensor, @@ -831,7 +851,7 @@ Status EagerServiceImpl::SendTensor(const SendTensorOp& send_tensor, eager_context->RemoteMgr()->AddOperationOutputs(tensors, send_tensor.op_id()); - return OkStatus(); + return absl::OkStatus(); } Status EagerServiceImpl::SendPackedHandle( @@ -877,7 +897,7 @@ Status EagerServiceImpl::SendPackedHandle( eager_context->RemoteMgr()->AddOperationOutputs({packed_handle}, send_packed_handle.op_id()); - return OkStatus(); + return absl::OkStatus(); } tensorflow::Status EagerServiceImpl::GetServerContext( @@ -897,7 +917,7 @@ tensorflow::Status EagerServiceImpl::GetServerContext( (*server_context)->RecordAccess(); - return OkStatus(); + return absl::OkStatus(); } } // namespace eager diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc index 2ab6631de71d9b..84cfe637697c70 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc @@ -60,7 +60,7 @@ class TestEagerServiceImpl : public EagerServiceImpl { TF_RETURN_IF_ERROR(GetServerContext(context_id, &context)); core::ScopedUnref context_unref(context); *ctx = context->Context(); - return OkStatus(); + return absl::OkStatus(); } Status GetTensorHandle(const uint64 context_id, const RemoteTensorHandleInternal& remote_handle, @@ -139,7 +139,7 @@ class DummyEagerClientCache : public EagerClientCache { core::RefCountPtr* client) override { client->reset(client_.get()); client_->Ref(); - return OkStatus(); + return absl::OkStatus(); } private: @@ -150,7 +150,7 @@ class FakeCache : public TestWorkerCache { Status GetEagerClientCache( std::unique_ptr* eager_client_cache) override { *eager_client_cache = std::make_unique(); - return OkStatus(); + return absl::OkStatus(); } void ListWorkers(std::vector* workers) const override { @@ -168,7 +168,7 @@ class EagerServiceImplTest : public ::testing::Test { [](const ServerDef& server_def, WorkerCacheInterface** worker_cache) { *worker_cache = new FakeCache; - return OkStatus(); + return absl::OkStatus(); }, /*coordination_handler=*/nullptr)) { worker_env_.env = Env::Default(); @@ -1009,7 +1009,7 @@ class FunctionWithRemoteInputsTest : public EagerServiceImplTest { tsl::core::RefCountPtr* r) { *r = tsl::core::RefCountPtr( worker_env_.rendezvous_mgr->Find(step_id).release()); - return OkStatus(); + return absl::OkStatus(); }}); } @@ -1100,7 +1100,7 @@ TEST_F(FunctionWithRemoteInputsTest, EagerPFLRTest) { std::move(tensor_args), [&inputs](const int i, RemoteTensorHandle* handle) -> Status { *handle = inputs.at(i); - return OkStatus(); + return absl::OkStatus(); }); eager_pflr_->Run(opts, handle, args, &outputs, [&status, &done](const Status& s) { @@ -1204,7 +1204,7 @@ TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncTest) { std::move(input_tensors), [&remote_handles](const int index, RemoteTensorHandle* handle) -> Status { *handle = remote_handles.at(index); - return OkStatus(); + return absl::OkStatus(); }); std::vector outputs; @@ -1259,7 +1259,7 @@ TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncAsyncTest) { std::move(input_tensors), [&remote_handles](const int index, RemoteTensorHandle* handle) -> Status { *handle = remote_handles.at(index); - return OkStatus(); + return absl::OkStatus(); }); std::vector outputs; diff --git a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc index bd5bc39622b9d6..afcbce0b0a1215 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc @@ -216,7 +216,7 @@ Status RemoteCopyNode::RunLocalRecv(EagerOperation* op, "Expect to receive a Tensor but got a TensorShape."); } } - return OkStatus(); + return absl::OkStatus(); } void RemoteCopyNode::RunRemoteRecv(EagerOperation* op, StatusCallback done) { @@ -358,7 +358,7 @@ Status SerializePackedHandle(const uint64 op_id, TensorHandle* packed_handle, return errors::InvalidArgument("Nested packed handles are not supported"); } } - return OkStatus(); + return absl::OkStatus(); } void RemoteCopyNode::StartSendPackedHandle(StatusCallback done) { @@ -479,7 +479,7 @@ void RemoteCopyNode::StartRemoteSendTensor(StatusCallback done) { Status RemoteCopyNode::Prepare() { TF_RETURN_IF_ERROR(captured_state_->dst()->CopyInferenceShape(src_)); - return OkStatus(); + return absl::OkStatus(); } void RemoteCopyNode::RunAsync(StatusCallback done) { diff --git a/tensorflow/core/distributed_runtime/eager/remote_mgr.cc b/tensorflow/core/distributed_runtime/eager/remote_mgr.cc index 2bec995a1f616e..313c60ccb53bea 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_mgr.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_mgr.cc @@ -76,7 +76,7 @@ Status RemoteMgr::GetTensorHandleImpl( *handle = iter->second; - return OkStatus(); + return absl::OkStatus(); } Status RemoteMgr::GetTensorHandle( @@ -105,7 +105,7 @@ Status RemoteMgr::GetMirroredResourceShape( *handle = iter->second; - return OkStatus(); + return absl::OkStatus(); } Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle, @@ -121,7 +121,7 @@ Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle, "Found two different tensor handles with the same op_id:", *op_id, " and output_num:", *output_num)); } - return OkStatus(); + return absl::OkStatus(); } Status RemoteMgr::DeleteTensorHandle( @@ -132,7 +132,7 @@ Status RemoteMgr::DeleteTensorHandle( if (iter != remote_tensor_handle_map_.end()) { iter->second->Unref(); remote_tensor_handle_map_.erase(iter); - return OkStatus(); + return absl::OkStatus(); } } { @@ -140,7 +140,7 @@ Status RemoteMgr::DeleteTensorHandle( auto iter = mirrored_resource_shape_map_.find(remote_handle); if (iter != mirrored_resource_shape_map_.end()) { mirrored_resource_shape_map_.erase(iter); - return OkStatus(); + return absl::OkStatus(); } } return WithErrorSourcePayload(errors::InvalidArgument( @@ -176,7 +176,7 @@ Status RemoteMgr::SerializeRemoteTensorHandle( dtype_and_shape.shape.AsProto(dtype_and_shape_proto->mutable_shape()); } } - return OkStatus(); + return absl::OkStatus(); } Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in, @@ -214,7 +214,7 @@ Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in, (*out)->SetResourceHandleDtypeAndShape(std::move(dtypes_and_shapes)); } - return OkStatus(); + return absl::OkStatus(); } EagerExecutor& RemoteMgr::GetOrCreateExecutorForStream(uint64 stream_id) { diff --git a/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc b/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc index 48952917f60f69..e89c29774e7ff8 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc @@ -131,7 +131,7 @@ Status RemoteTensorHandleData::Shape(TensorShape* shape) const { tf_shared_lock l(mu_); *shape = shape_; - return OkStatus(); + return absl::OkStatus(); } Status RemoteTensorHandleData::NumDims(int* num_dims) const { @@ -140,7 +140,7 @@ Status RemoteTensorHandleData::NumDims(int* num_dims) const { tf_shared_lock l(mu_); *num_dims = shape_.dims(); - return OkStatus(); + return absl::OkStatus(); } Status RemoteTensorHandleData::Dim(int dim_index, int64_t* dim) const { @@ -149,7 +149,7 @@ Status RemoteTensorHandleData::Dim(int dim_index, int64_t* dim) const { tf_shared_lock l(mu_); *dim = shape_.dim_size(dim_index); - return OkStatus(); + return absl::OkStatus(); } Status RemoteTensorHandleData::NumElements(int64_t* num_elements) const { @@ -158,7 +158,7 @@ Status RemoteTensorHandleData::NumElements(int64_t* num_elements) const { tf_shared_lock l(mu_); *num_elements = shape_.num_elements(); - return OkStatus(); + return absl::OkStatus(); } bool RemoteTensorHandleData::IsReady() const { @@ -203,17 +203,17 @@ Status RemoteTensorHandleData::SetShapeAndRemoteTask( " from existing shape of ", shape_.DebugString())); } LOG(WARNING) << "SetShape can only be called on non-ready handles."; - return OkStatus(); + return absl::OkStatus(); } shape_ = shape; if (!remote_task.empty()) { remote_task_ = remote_task; } - is_poisoned_ = OkStatus(); + is_poisoned_ = absl::OkStatus(); is_ready_ = true; - return OkStatus(); + return absl::OkStatus(); } string RemoteTensorHandleData::DebugString() const { @@ -229,7 +229,7 @@ Status RemoteTensorHandleData::OpIdAndOutputNum(const bool wait_until_ready, } *op_id = op_id_; *output_num = output_num_; - return OkStatus(); + return absl::OkStatus(); } Status RemoteTensorHandleData::WaitReady(const char* caller) const { diff --git a/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h b/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h index c5b348eca2fadc..92f0a66ebbbba7 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h +++ b/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h @@ -45,7 +45,7 @@ class RemoteTensorHandleData { Status NumDims(int* num_dims) const; Status Dim(int dim_index, int64_t* dim) const; Status NumElements(int64_t* num_elements) const; - Status Unprotect() { return OkStatus(); } + Status Unprotect() { return absl::OkStatus(); } bool IsReady() const; Status WaitReady(const char* caller) const; diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index dd749b9e86810a..cb827d584e260b 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -98,7 +98,7 @@ static Status ValidateGraphDefForDevices(const GraphDef& gdef) { FormatNodeDefForError(ndef)); } } - return OkStatus(); + return absl::OkStatus(); } Status GraphMgr::DecorateAndPublishGraphForDebug( @@ -108,7 +108,7 @@ Status GraphMgr::DecorateAndPublishGraphForDebug( DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator)); TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device)); TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name())); - return OkStatus(); + return absl::OkStatus(); } // Creates executors given a graph definition "gdef" of a "session". @@ -150,7 +150,7 @@ Status GraphMgr::InitItem(const string& handle, const GraphDef& gdef, this->worker_env_->rendezvous_mgr->Find(step_id); TF_RETURN_IF_ERROR(remote_r->Initialize(session)); *r = std::move(remote_r); - return OkStatus(); + return absl::OkStatus(); }})); // Constructs the graph out of "gdef". @@ -288,7 +288,7 @@ Status GraphMgr::InitItem(const string& handle, const GraphDef& gdef, } TF_RETURN_IF_ERROR(NewLocalExecutor(params, *unit->graph, &unit->root)); } - return OkStatus(); + return absl::OkStatus(); } Status GraphMgr::Register(const string& handle, const GraphDef& gdef, @@ -314,7 +314,7 @@ Status GraphMgr::Register(const string& handle, const GraphDef& gdef, item->handle = *graph_handle; CHECK(table_.insert({*graph_handle, item}).second); } - return OkStatus(); + return absl::OkStatus(); } Status GraphMgr::Deregister(const string& handle) { @@ -331,7 +331,7 @@ Status GraphMgr::Deregister(const string& handle) { table_.erase(iter); } item->Unref(); - return OkStatus(); + return absl::OkStatus(); } Status GraphMgr::DeregisterAll() { @@ -347,7 +347,7 @@ Status GraphMgr::DeregisterAll() { for (auto item : items) { item->Unref(); } - return OkStatus(); + return absl::OkStatus(); } Status GraphMgr::SendInputs(const int64_t step_id, const NamedTensors& in) { diff --git a/tensorflow/core/distributed_runtime/local_master.cc b/tensorflow/core/distributed_runtime/local_master.cc index 76f8d4f37b20e6..a41f977b059f82 100644 --- a/tensorflow/core/distributed_runtime/local_master.cc +++ b/tensorflow/core/distributed_runtime/local_master.cc @@ -43,7 +43,7 @@ Status WaitForNotification(CallOptions* call_options, } else { n->WaitForNotification(); } - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc index 3992a211ef68b6..2602c471a73671 100644 --- a/tensorflow/core/distributed_runtime/master.cc +++ b/tensorflow/core/distributed_runtime/master.cc @@ -144,7 +144,7 @@ class DeviceFinder { finder.Start(); TF_RETURN_IF_ERROR(finder.Wait()); finder.GetRemoteDevices(env->local_devices, out_remote); - return OkStatus(); + return absl::OkStatus(); } static void GetRemoteWorkers( diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index e54c577bc012a5..5593963988d9e5 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -870,7 +870,7 @@ Status MasterSession::ReffedClientGraph::RunPartitions( } fetch_proto->Swap(&iter->second); } - return OkStatus(); + return absl::OkStatus(); } namespace { @@ -967,7 +967,7 @@ void MasterSession::ReffedClientGraph::ProcessStats(int64_t step_id, } ph->StepDone(pss->start_micros, pss->end_micros, Microseconds(0) /*cleanup_time*/, 0 /*total_runops*/, - OkStatus()); + absl::OkStatus()); } // Assemble all stats for this timeline into a merged StepStats. if (pss->collect_timeline) { @@ -1091,7 +1091,7 @@ Status MasterSession::ReffedClientGraph::CheckFetches( } } } - return OkStatus(); + return absl::OkStatus(); } // Asynchronously deregisters subgraphs on the workers, without waiting for the @@ -1304,7 +1304,7 @@ Status MasterSession::CreateWorkerSessions(const ClusterDef& cluster_def) { // Request and responses used for a given worker. CreateWorkerSessionRequest request; CreateWorkerSessionResponse response; - Status status = OkStatus(); + Status status = absl::OkStatus(); }; BlockingCounter done(worker_names.size()); std::vector workers(worker_names.size()); @@ -1325,7 +1325,7 @@ Status MasterSession::CreateWorkerSessions(const ClusterDef& cluster_def) { const int64_t client_device_incarnation = devices_->client_device()->attributes().incarnation(); - Status status = OkStatus(); + Status status = absl::OkStatus(); // Create all the workers & kick off the computations. for (size_t i = 0; i < worker_names.size(); ++i) { workers[i].name = &worker_names[i]; @@ -1439,7 +1439,7 @@ Status MasterSession::DeleteWorkerSessions() { // Request and responses used for a given worker. DeleteWorkerSessionRequest request; DeleteWorkerSessionResponse response; - Status status = OkStatus(); + Status status = absl::OkStatus(); }; BlockingCounter done(worker_names.size()); std::vector workers(worker_names.size()); @@ -1453,7 +1453,7 @@ Status MasterSession::DeleteWorkerSessions() { } }); - Status status = OkStatus(); + Status status = absl::OkStatus(); // Create all the workers & kick off the computations. for (size_t i = 0; i < worker_names.size(); ++i) { workers[i].name = &worker_names[i]; @@ -1501,7 +1501,7 @@ Status MasterSession::ListDevices(ListDevicesResponse* resp) const { *(resp->add_local_device()) = dev->attributes(); } } - return OkStatus(); + return absl::OkStatus(); } Status MasterSession::Extend(const ExtendSessionRequest* req, @@ -1530,7 +1530,7 @@ Status MasterSession::Extend(const ExtendSessionRequest* req, ++graph_version_; resp->set_new_graph_version(graph_version_); } - return OkStatus(); + return absl::OkStatus(); } WorkerCacheInterface* MasterSession::get_worker_cache() const { @@ -1571,7 +1571,7 @@ Status MasterSession::StartStep(const BuildGraphOptions& opts, bool is_partial, (*out_rcg)->Ref(); *out_count = (*out_rcg)->get_and_increment_execution_count(); } - return OkStatus(); + return absl::OkStatus(); } void MasterSession::ClearRunsTable(std::vector* to_unref, @@ -1654,7 +1654,7 @@ Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req, TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg)); resp->set_partial_run_handle(handle); - return OkStatus(); + return absl::OkStatus(); } Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req, @@ -1727,7 +1727,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) { TF_RETURN_IF_ERROR(rcg->RegisterPartitions(std::move(popts))); - return OkStatus(); + return absl::OkStatus(); } Status MasterSession::DoPartialRun(CallOptions* opts, @@ -1888,7 +1888,7 @@ Status MasterSession::CreateDebuggerState( debug_options.global_step(), rcg_execution_count, rcg_execution_count, input_names, output_names, target_names)); - return OkStatus(); + return absl::OkStatus(); } void MasterSession::FillPerStepState(MasterSession::ReffedClientGraph* rcg, @@ -2049,7 +2049,7 @@ Status MasterSession::MakeCallable(const MakeCallableRequest& req, } resp->set_handle(handle); - return OkStatus(); + return absl::OkStatus(); } Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg, @@ -2123,7 +2123,7 @@ Status MasterSession::ReleaseCallable(const ReleaseCallableRequest& req, if (to_unref != nullptr) { to_unref->Unref(); } - return OkStatus(); + return absl::OkStatus(); } Status MasterSession::Close() { @@ -2149,7 +2149,7 @@ Status MasterSession::Close() { LOG(WARNING) << s; } } - return OkStatus(); + return absl::OkStatus(); } void MasterSession::GarbageCollect() { diff --git a/tensorflow/core/distributed_runtime/message_wrappers.cc b/tensorflow/core/distributed_runtime/message_wrappers.cc index 35bc935496ea91..d3ae6f252e8482 100644 --- a/tensorflow/core/distributed_runtime/message_wrappers.cc +++ b/tensorflow/core/distributed_runtime/message_wrappers.cc @@ -59,13 +59,13 @@ const string& InMemoryRunStepRequest::feed_name(size_t i) const { Status InMemoryRunStepRequest::FeedValue(size_t i, Tensor* out_tensor) const { *out_tensor = feeds_[i].second; - return OkStatus(); + return absl::OkStatus(); } Status InMemoryRunStepRequest::FeedValue(size_t i, TensorProto* out_tensor) const { feeds_[i].second.AsProtoTensorContent(out_tensor); - return OkStatus(); + return absl::OkStatus(); } void InMemoryRunStepRequest::add_feed(const string& name, const Tensor& value) { @@ -155,14 +155,14 @@ Status MutableProtoRunStepRequest::FeedValue(size_t i, if (!ParseTensorProtoToTensor(request_.feed(i).tensor(), out_tensor)) { return errors::InvalidArgument("Invalid TensorProto for feed value ", i); } else { - return OkStatus(); + return absl::OkStatus(); } } Status MutableProtoRunStepRequest::FeedValue(size_t i, TensorProto* out_tensor) const { *out_tensor = request_.feed(i).tensor(); - return OkStatus(); + return absl::OkStatus(); } void MutableProtoRunStepRequest::add_feed(const string& name, @@ -246,13 +246,13 @@ Status ProtoRunStepRequest::FeedValue(size_t i, Tensor* out_tensor) const { if (!ParseTensorProtoToTensor(request_->feed(i).tensor(), out_tensor)) { return errors::InvalidArgument("Invalid TensorProto for feed value ", i); } else { - return OkStatus(); + return absl::OkStatus(); } } Status ProtoRunStepRequest::FeedValue(size_t i, TensorProto* out_tensor) const { *out_tensor = request_->feed(i).tensor(); - return OkStatus(); + return absl::OkStatus(); } size_t ProtoRunStepRequest::num_fetches() const { @@ -335,7 +335,7 @@ const string& InMemoryRunGraphRequest::send_key(size_t i) const { Status InMemoryRunGraphRequest::SendValue(size_t i, Tensor* out_tensor) const { *out_tensor = sends_[i].second; - return OkStatus(); + return absl::OkStatus(); } Status InMemoryRunGraphRequest::AddSendFromRunStepRequest( @@ -344,7 +344,7 @@ Status InMemoryRunGraphRequest::AddSendFromRunStepRequest( Tensor tensor; TF_RETURN_IF_ERROR(run_step_request.FeedValue(i, &tensor)); sends_.emplace_back(send_key, std::move(tensor)); - return OkStatus(); + return absl::OkStatus(); } Status InMemoryRunGraphRequest::AddSendFromRunCallableRequest( @@ -355,7 +355,7 @@ Status InMemoryRunGraphRequest::AddSendFromRunCallableRequest( return errors::InvalidArgument("Invalid TensorProto for feed value ", i); } sends_.emplace_back(send_key, std::move(tensor)); - return OkStatus(); + return absl::OkStatus(); } size_t InMemoryRunGraphRequest::num_recvs() const { return recvs_.size(); } @@ -478,7 +478,7 @@ Status MutableProtoRunGraphRequest::SendValue(size_t i, if (!ParseTensorProtoToTensor(request_.send(i).tensor(), out_tensor)) { return errors::InvalidArgument("Invalid TensorProto for feed value ", i); } else { - return OkStatus(); + return absl::OkStatus(); } } @@ -488,7 +488,7 @@ Status MutableProtoRunGraphRequest::AddSendFromRunStepRequest( NamedTensorProto* send = request_.add_send(); send->set_name(send_key); TF_RETURN_IF_ERROR(run_step_request.FeedValue(i, send->mutable_tensor())); - return OkStatus(); + return absl::OkStatus(); } Status MutableProtoRunGraphRequest::AddSendFromRunCallableRequest( @@ -497,7 +497,7 @@ Status MutableProtoRunGraphRequest::AddSendFromRunCallableRequest( NamedTensorProto* send = request_.add_send(); send->set_name(send_key); *send->mutable_tensor() = run_callable_request.feed(i); - return OkStatus(); + return absl::OkStatus(); } size_t MutableProtoRunGraphRequest::num_recvs() const { @@ -581,7 +581,7 @@ Status ProtoRunGraphRequest::SendValue(size_t i, Tensor* out_tensor) const { if (!ParseTensorProtoToTensor(request_->send(i).tensor(), out_tensor)) { return errors::InvalidArgument("Invalid TensorProto for feed value ", i); } else { - return OkStatus(); + return absl::OkStatus(); } } @@ -619,12 +619,12 @@ const string& InMemoryRunGraphResponse::recv_key(size_t i) const { Status InMemoryRunGraphResponse::RecvValue(size_t i, TensorProto* out_tensor) { recvs_[i].second.AsProtoTensorContent(out_tensor); - return OkStatus(); + return absl::OkStatus(); } Status InMemoryRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) { *out_tensor = recvs_[i].second; - return OkStatus(); + return absl::OkStatus(); } void InMemoryRunGraphResponse::AddRecv(const string& key, const Tensor& value) { @@ -678,14 +678,14 @@ const string& OwnedProtoRunGraphResponse::recv_key(size_t i) const { Status OwnedProtoRunGraphResponse::RecvValue(size_t i, TensorProto* out_tensor) { out_tensor->Swap(response_.mutable_recv(i)->mutable_tensor()); - return OkStatus(); + return absl::OkStatus(); } Status OwnedProtoRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) { if (!ParseTensorProtoToTensor(response_.recv(i).tensor(), out_tensor)) { return errors::InvalidArgument("Invalid TensorProto for recv value ", i); } else { - return OkStatus(); + return absl::OkStatus(); } } @@ -750,14 +750,14 @@ const string& NonOwnedProtoRunGraphResponse::recv_key(size_t i) const { Status NonOwnedProtoRunGraphResponse::RecvValue(size_t i, TensorProto* out_tensor) { out_tensor->Swap(response_->mutable_recv(i)->mutable_tensor()); - return OkStatus(); + return absl::OkStatus(); } Status NonOwnedProtoRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) { if (!ParseTensorProtoToTensor(response_->recv(i).tensor(), out_tensor)) { return errors::InvalidArgument("Invalid TensorProto for recv value ", i); } else { - return OkStatus(); + return absl::OkStatus(); } } @@ -820,7 +820,7 @@ const string& InMemoryRunStepResponse::tensor_name(size_t i) const { Status InMemoryRunStepResponse::TensorValue(size_t i, Tensor* out_tensor) const { *out_tensor = tensors_[i].second; - return OkStatus(); + return absl::OkStatus(); } const RunMetadata& InMemoryRunStepResponse::metadata() const { @@ -832,7 +832,7 @@ Status InMemoryRunStepResponse::AddTensorFromRunGraphResponse( Tensor tensor; TF_RETURN_IF_ERROR(wrapper->RecvValue(i, &tensor)); tensors_.emplace_back(name, tensor); - return OkStatus(); + return absl::OkStatus(); } RunMetadata* InMemoryRunStepResponse::mutable_metadata() { return &metadata_; } @@ -865,7 +865,7 @@ Status OwnedProtoRunStepResponse::TensorValue(size_t i, if (!ParseTensorProtoToTensor(response_.tensor(i).tensor(), out_tensor)) { return errors::InvalidArgument("Invalid TensorProto for fetch value ", i); } else { - return OkStatus(); + return absl::OkStatus(); } } @@ -918,7 +918,7 @@ Status NonOwnedProtoRunStepResponse::TensorValue(size_t i, if (!ParseTensorProtoToTensor(response_->tensor(i).tensor(), out_tensor)) { return errors::InvalidArgument("Invalid TensorProto for fetch value ", i); } else { - return OkStatus(); + return absl::OkStatus(); } } diff --git a/tensorflow/core/distributed_runtime/partial_run_mgr_test.cc b/tensorflow/core/distributed_runtime/partial_run_mgr_test.cc index f0277c5aeb94ee..a2d56e5b2fd6ad 100644 --- a/tensorflow/core/distributed_runtime/partial_run_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/partial_run_mgr_test.cc @@ -65,15 +65,15 @@ TEST(PartialRunMgr, PartialRunRemoved) { int called = 0; partial_run_mgr.PartialRunDone( - step_id, [&called](Status status) { called++; }, OkStatus()); - partial_run_mgr.ExecutorDone(step_id, OkStatus()); + step_id, [&called](Status status) { called++; }, absl::OkStatus()); + partial_run_mgr.ExecutorDone(step_id, absl::OkStatus()); // Calling ExecutorDone and PartialRunDone on the step_id should still only // result in the callback being called once. // This proves that the original PartialRun has been removed. partial_run_mgr.PartialRunDone( - step_id, [&called](Status status) { called++; }, OkStatus()); - partial_run_mgr.ExecutorDone(step_id, OkStatus()); + step_id, [&called](Status status) { called++; }, absl::OkStatus()); + partial_run_mgr.ExecutorDone(step_id, absl::OkStatus()); EXPECT_EQ(1, called); } @@ -142,9 +142,9 @@ Status PartialRunError() { return errors::Internal("partial run error"); } INSTANTIATE_TEST_SUITE_P( PartialRunMgr, StatusPropagationTest, ::testing::Values( - StatusTestParam{OkStatus(), OkStatus(), OkStatus()}, - StatusTestParam{ExecutorError(), OkStatus(), ExecutorError()}, - StatusTestParam{OkStatus(), PartialRunError(), PartialRunError()}, + StatusTestParam{absl::OkStatus(), absl::OkStatus(), absl::OkStatus()}, + StatusTestParam{ExecutorError(), absl::OkStatus(), ExecutorError()}, + StatusTestParam{absl::OkStatus(), PartialRunError(), PartialRunError()}, StatusTestParam{ExecutorError(), PartialRunError(), ExecutorError()})); } // namespace diff --git a/tensorflow/core/distributed_runtime/recent_request_ids.cc b/tensorflow/core/distributed_runtime/recent_request_ids.cc index 490755f1125fec..e7ea66286fb341 100644 --- a/tensorflow/core/distributed_runtime/recent_request_ids.cc +++ b/tensorflow/core/distributed_runtime/recent_request_ids.cc @@ -64,7 +64,7 @@ Status RecentRequestIds::TrackUnique(int64_t request_id, const string& method_name, const protobuf::Message& request) { if (Insert(request_id)) { - return OkStatus(); + return absl::OkStatus(); } else { return errors::Aborted("The same ", method_name, " request was received twice. ", diff --git a/tensorflow/core/distributed_runtime/recent_request_ids.h b/tensorflow/core/distributed_runtime/recent_request_ids.h index 5cb8b82634f82d..bc2b1a3d615035 100644 --- a/tensorflow/core/distributed_runtime/recent_request_ids.h +++ b/tensorflow/core/distributed_runtime/recent_request_ids.h @@ -91,7 +91,7 @@ Status RecentRequestIds::TrackUnique(int64_t request_id, const string& method_name, const RequestWrapper* wrapper) { if (Insert(request_id)) { - return OkStatus(); + return absl::OkStatus(); } else { return errors::Aborted("The same ", method_name, " request was received twice. ", diff --git a/tensorflow/core/distributed_runtime/remote_device.cc b/tensorflow/core/distributed_runtime/remote_device.cc index a7a56435d829b7..9e6479b0e77ba1 100644 --- a/tensorflow/core/distributed_runtime/remote_device.cc +++ b/tensorflow/core/distributed_runtime/remote_device.cc @@ -39,7 +39,7 @@ class RemoteDevice : public Device { : Device(env, da), local_dev_name_(DeviceNameUtils::LocalName(da.name())) {} - Status Sync() override { return OkStatus(); } + Status Sync() override { return absl::OkStatus(); } Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; } ResourceMgr* resource_manager() override { diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc index d3b0280dcd500a..8142ea84660a36 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc @@ -154,11 +154,18 @@ class GrpcEagerClient : public EagerClient { void method##Async(const method##Request* request, \ method##Response* response, StatusCallback done, \ int64_t init_timeout_in_ms, int retries) override { \ - StatusCallback done_wrapped = callback_wrapper(std::move(done)); \ CallOptions* call_ops = nullptr; \ + StatusCallback done_wrapped; \ if (init_timeout_in_ms > 0) { \ call_ops = new CallOptions; \ call_ops->SetTimeout(init_timeout_in_ms); \ + auto new_done = [call_ops, done = std::move(done)](const Status& s) { \ + done(s); \ + delete call_ops; \ + }; \ + done_wrapped = callback_wrapper(new_done); \ + } else { \ + done_wrapped = callback_wrapper(std::move(done)); \ } \ new RPCState( \ &stub_, cq_, "/tensorflow.eager.EagerService/" #method, *request, \ @@ -317,7 +324,7 @@ class GrpcEagerClientCache : public EagerClientCache { it->second->Ref(); client->reset(it->second.get()); - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc index 141dcff71b15e2..50dc3d76889a02 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc @@ -211,7 +211,7 @@ class GrpcMasterService : public tsl::AsyncServiceInterface { static_cast(status.code())); call->response.set_status_error_message( std::string(status.message())); - call->SendResponse(ToGrpcStatus(OkStatus())); + call->SendResponse(ToGrpcStatus(absl::OkStatus())); } else { call->SendResponse(ToGrpcStatus(status)); } diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index 475573f93096e7..d5eacc7e6a16cd 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -179,7 +179,7 @@ Status GrpcServer::GetHostAndPort(const ServerDef& server_def, "\" was not defined in cluster"); } - return OkStatus(); + return absl::OkStatus(); } Status GrpcServer::Init(const GrpcServerOptions& opts) { @@ -356,7 +356,7 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) { LocalMaster::Register(target(), master_impl_.get(), config.operation_timeout_in_ms()); - return OkStatus(); + return absl::OkStatus(); } Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options, @@ -383,7 +383,7 @@ Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options, } TF_RETURN_IF_ERROR(channel_spec->AddHostPortsJob(job.name(), host_ports)); } - return OkStatus(); + return absl::OkStatus(); } Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options, @@ -422,7 +422,7 @@ Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options, } *worker_cache = NewGrpcWorkerCacheWithLocalWorker( channel_cache, grpc_worker_env(), worker_impl(), name_prefix); - return OkStatus(); + return absl::OkStatus(); } Status GrpcServer::Start() { @@ -455,11 +455,11 @@ Status GrpcServer::Start() { state_ = STARTED; LOG(INFO) << "Started server with target: " << target(); - return OkStatus(); + return absl::OkStatus(); } case STARTED: LOG(INFO) << "Server already started (target: " << target() << ")"; - return OkStatus(); + return absl::OkStatus(); case STOPPED: return errors::FailedPrecondition("Server has stopped."); default: @@ -502,7 +502,7 @@ Status GrpcServer::UpdateServerDef(const ServerDef& server_def) { master_env_.worker_cache = worker_cache; master_env_.collective_executor_mgr = worker_env_.collective_executor_mgr.get(); - return OkStatus(); + return absl::OkStatus(); } // TODO(haoyuzhang): Remove this method once we have a mechanism to directly set @@ -512,7 +512,7 @@ Status GrpcServer::SetCoordinationServiceAgentInstance( auto* coord_service = static_cast(coordination_service_); coord_service->SetCoordinationServiceAgentInstance(agent); - return OkStatus(); + return absl::OkStatus(); } Status GrpcServer::SetCoordinationServiceInstance( @@ -520,7 +520,7 @@ Status GrpcServer::SetCoordinationServiceInstance( auto* coord_service = static_cast(coordination_service_); coord_service->SetCoordinationServiceInstance(service); - return OkStatus(); + return absl::OkStatus(); } Status GrpcServer::StopCoordinationService() { @@ -534,7 +534,7 @@ Status GrpcServer::StopCoordinationService() { TF_RETURN_IF_ERROR(SetCoordinationServiceInstance(nullptr)); coordination_service_->Shutdown(); worker_env()->session_mgr->TeardownCoordinationService(); - return OkStatus(); + return absl::OkStatus(); } Status GrpcServer::Stop() { @@ -542,13 +542,13 @@ Status GrpcServer::Stop() { switch (state_) { case NEW: state_ = STOPPED; - return OkStatus(); + return absl::OkStatus(); case STARTED: return errors::Unimplemented( "Clean shutdown is not currently implemented"); case STOPPED: LOG(INFO) << "Server already stopped (target: " << target() << ")"; - return OkStatus(); + return absl::OkStatus(); default: LOG(FATAL); } @@ -560,7 +560,7 @@ Status GrpcServer::Join() { case NEW: // Prevent the server from being started subsequently. state_ = STOPPED; - return OkStatus(); + return absl::OkStatus(); case STARTED: case STOPPED: master_thread_.reset(); @@ -569,7 +569,7 @@ Status GrpcServer::Join() { for (auto& thread : extra_service_threads_) { thread.reset(); } - return OkStatus(); + return absl::OkStatus(); default: LOG(FATAL); } @@ -609,7 +609,7 @@ Status GrpcServer::Create(const ServerDef& server_def, Env* env, return s; } *out_server = std::move(ret); - return OkStatus(); + return absl::OkStatus(); } /* static */ @@ -627,7 +627,7 @@ Status GrpcServer::Create(const ServerDef& server_def, Env* env, return s; } out_server->reset(dynamic_cast(server.release())); - return OkStatus(); + return absl::OkStatus(); } namespace { diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc index ea7c59c2dcc503..7911ea2e59dc03 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc @@ -63,7 +63,7 @@ Status GrpcSession::Create(const SessionOptions& options, } session->SetRemoteMaster(std::move(master)); *out_session = std::move(session); - return OkStatus(); + return absl::OkStatus(); } namespace { @@ -109,7 +109,7 @@ Status GrpcSession::Handle(string* out_handle) { return errors::InvalidArgument("A session is not created yet...."); } *out_handle = handle_; - return OkStatus(); + return absl::OkStatus(); } Status GrpcSession::CreateImpl(CallOptions* call_options, GraphDef graph) { @@ -281,7 +281,7 @@ Status GrpcSession::RunHelper( run_metadata->Swap(resp->mutable_metadata()); } - return OkStatus(); + return absl::OkStatus(); } Status GrpcSession::Run(const RunOptions& run_options, @@ -335,7 +335,7 @@ Status GrpcSession::PRunSetup(const std::vector& input_names, call_options.SetTimeout(options_.config.operation_timeout_in_ms()); TF_RETURN_IF_ERROR(master_->PartialRunSetup(&call_options, &req, &resp)); *handle = resp.partial_run_handle(); - return OkStatus(); + return absl::OkStatus(); } Status GrpcSession::PRun(const string& handle, @@ -353,7 +353,7 @@ Status GrpcSession::Close() { { mutex_lock l(mu_); if (handle_.empty()) { - return OkStatus(); + return absl::OkStatus(); } req.set_session_handle(handle_); handle_.clear(); @@ -398,7 +398,7 @@ Status GrpcSession::ListDevices(std::vector* response) { for (const auto& device_attr : resp.remote_device()) { response->emplace_back(device_attr); } - return OkStatus(); + return absl::OkStatus(); } void GrpcSession::SetRemoteMaster(std::unique_ptr master) { @@ -435,7 +435,7 @@ Status GrpcSession::MakeCallable(const CallableOptions& callable_options, call_options.SetTimeout(options_.config.operation_timeout_in_ms()); TF_RETURN_IF_ERROR(master_->MakeCallable(&call_options, &req, &resp)); *out_handle = resp.handle(); - return OkStatus(); + return absl::OkStatus(); } Status GrpcSession::RunCallable(CallableHandle handle, @@ -462,7 +462,7 @@ Status GrpcSession::RunCallable(CallableHandle handle, } fetch_tensors->push_back(std::move(fetch_tensor)); } - return OkStatus(); + return absl::OkStatus(); } Status GrpcSession::ReleaseCallable(CallableHandle handle) { @@ -486,7 +486,7 @@ class GrpcSessionFactory : public SessionFactory { std::unique_ptr session; TF_RETURN_IF_ERROR(GrpcSession::Create(options, &session)); *out_session = session.release(); - return OkStatus(); + return absl::OkStatus(); } // Invokes the session specific static method to reset containers. diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_state.h b/tensorflow/core/distributed_runtime/rpc/grpc_state.h index bcf044ff0b87c4..8fdc1d9437cf2f 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_state.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_state.h @@ -332,7 +332,7 @@ class StreamingRPCState : public UntypedStreamingRPCState { e = &exchanges_.GetFront(); mu_.unlock(); - e->Complete(OkStatus()); + e->Complete(absl::OkStatus()); { mutex_lock l(mu_); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc index f9314e92e6631b..8c7a686dd002e6 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc @@ -78,7 +78,7 @@ Status FillServerDef(const string& cluster_spec, const string& job_name, " is invalid (job \"", options->job_name(), "\" contains ", my_num_tasks, " tasks"); } - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc index bd821118629981..0e9769cefd866e 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc @@ -102,7 +102,7 @@ Status TestCluster::MakeTestCluster(const TestClusterConfig& config, TF_RETURN_IF_ERROR(session->ListDevices(&ret->devices_)); *out_cluster = std::move(ret); - return OkStatus(); + return absl::OkStatus(); } TestCluster::~TestCluster() { diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc index 1f93608522a88e..2900259a83867d 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc @@ -106,7 +106,7 @@ Status FillServerDef(const string& job_spec, const string& job_name, ConfigProto* config = options->mutable_default_session_config(); (*config->mutable_device_count())["CPU"] = num_cpus; (*config->mutable_device_count())["GPU"] = num_gpus; - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc index 45d2015155de9b..2a18d0d28fe885 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc @@ -78,14 +78,14 @@ class GrpcWorkerCache : public WorkerCachePartial { Status GetEagerClientCache( std::unique_ptr* eager_client_cache) override { eager_client_cache->reset(eager::NewGrpcEagerClientCache(channel_cache_)); - return OkStatus(); + return absl::OkStatus(); } Status GetCoordinationClientCache(std::unique_ptr* coordination_client_cache) override { coordination_client_cache->reset( NewGrpcCoordinationClientCache(channel_cache_)); - return OkStatus(); + return absl::OkStatus(); } void SetLogging(bool v) override { logger_.SetLogging(v); } diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc index 888d13a398a418..2eaaeda2d571dc 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc @@ -725,7 +725,7 @@ void GrpcWorker::LoggingAsync(const LoggingRequest* request, } } } - done(OkStatus()); + done(absl::OkStatus()); } void GrpcWorker::CleanupGraphAsync(const CleanupGraphRequest* request, diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc index 1bf7183245a8a9..5f74b27223af7a 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc @@ -85,7 +85,7 @@ class RpcRecvTensorCall : public BaseRecvTensorCall { resp_.Clear(); { mutex_lock l(mu_); - status_ = OkStatus(); + status_ = absl::OkStatus(); } done_ = nullptr; } diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc index ea374f7346c41c..fa784d27e0be39 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc @@ -63,7 +63,7 @@ class DummyWorker : public TestWorkerInterface { // RPC call objects. const int64_t t_us = random::New64() % 100 * 1000; Env::Default()->SleepForMicroseconds(t_us); - done(OkStatus()); + done(absl::OkStatus()); }); } }; @@ -103,7 +103,7 @@ static Device* CreateDevice(const char* type, const char* name) { class FakeDevice : public Device { public: explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {} - Status Sync() override { return OkStatus(); } + Status Sync() override { return absl::OkStatus(); } Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; } }; DeviceAttributes attr; @@ -320,7 +320,7 @@ TEST_F(RpcRendezvousMgrTest, RemoteRecvAsyncMany) { int num_requests = 10000; Tensor val(DT_STRING); mutex mu_; - Status status = OkStatus(); + Status status = absl::OkStatus(); BlockingCounter counter(num_requests); for (int i = 0; i < num_requests; i++) { diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc index d89ab15689fafe..fd649181a38b97 100644 --- a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc +++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc @@ -80,7 +80,7 @@ void RpcCollectiveExecutorMgr::RefreshStepIdSequenceAsync( gks = it->second; } gks->next_step_id_ = NewRandomStepId(); - done(OkStatus()); + done(absl::OkStatus()); } else { WorkerInterface* wi = worker_cache_->GetOrCreateWorker(group_leader_); GetStepSequenceRequest* req = new GetStepSequenceRequest; @@ -124,7 +124,7 @@ void RpcCollectiveExecutorMgr::GetStepSequenceAsync( ss->set_graph_key(graph_key); ss->set_next_step_id(gks->next_step_id_); } - done(OkStatus()); + done(absl::OkStatus()); } } @@ -142,7 +142,7 @@ Status RpcCollectiveExecutorMgr::UpdateStepSequences( } gks->next_step_id_ = ss.next_step_id(); } - return OkStatus(); + return absl::OkStatus(); } int64_t RpcCollectiveExecutorMgr::NextStepId(int64_t graph_key) { diff --git a/tensorflow/core/distributed_runtime/server_lib.cc b/tensorflow/core/distributed_runtime/server_lib.cc index e9f622bcf2e3fd..a653a7999fed41 100644 --- a/tensorflow/core/distributed_runtime/server_lib.cc +++ b/tensorflow/core/distributed_runtime/server_lib.cc @@ -52,7 +52,7 @@ Status ServerFactory::GetFactory(const ServerDef& server_def, for (const auto& server_factory : *server_factories()) { if (server_factory.second->AcceptsOptions(server_def)) { *out_factory = server_factory.second; - return OkStatus(); + return absl::OkStatus(); } } diff --git a/tensorflow/core/distributed_runtime/server_lib_test.cc b/tensorflow/core/distributed_runtime/server_lib_test.cc index 2ac89ae67d554a..49abd7e7a639e9 100644 --- a/tensorflow/core/distributed_runtime/server_lib_test.cc +++ b/tensorflow/core/distributed_runtime/server_lib_test.cc @@ -28,7 +28,7 @@ class TestServerFactory : public ServerFactory { Status NewServer(const ServerDef& server_def, const Options& options, std::unique_ptr* out_server) override { - return OkStatus(); + return absl::OkStatus(); } }; @@ -45,7 +45,7 @@ TEST(ServerLibTest, NewServerNoFactoriesAccept) { server_def.set_protocol("fake_protocol"); std::unique_ptr server; Status s = NewServer(server_def, &server); - ASSERT_NE(s, OkStatus()); + ASSERT_NE(s, absl::OkStatus()); EXPECT_TRUE(absl::StrContains( s.message(), "No server factory registered for the given ServerDef")); EXPECT_TRUE( diff --git a/tensorflow/core/distributed_runtime/session_mgr.cc b/tensorflow/core/distributed_runtime/session_mgr.cc index f2826a45aecc1a..5fcd27c3c17b89 100644 --- a/tensorflow/core/distributed_runtime/session_mgr.cc +++ b/tensorflow/core/distributed_runtime/session_mgr.cc @@ -307,7 +307,7 @@ Status SessionMgr::CreateSession( activity_watcher::MaybeEnableMultiWorkersWatching( coordination_service_agent_.get()); } - return OkStatus(); + return absl::OkStatus(); } void SessionMgr::ResetDefaultWorkerCache(WorkerCacheInterface* worker_cache) { @@ -374,7 +374,7 @@ Status SessionMgr::UpdateSession( TF_RETURN_IF_ERROR(worker_session->UpdateWorkerCacheAndDevices( std::unique_ptr(worker_cache), std::move(added_remote_devices), removed_remote_devices)); - return OkStatus(); + return absl::OkStatus(); } Status SessionMgr::DeleteSession(const std::string& session) { @@ -383,7 +383,20 @@ Status SessionMgr::DeleteSession(const std::string& session) { if (it != sessions_.end()) { sessions_.erase(it); } - return OkStatus(); + return absl::OkStatus(); +} + +Status SessionMgr::DeleteAllSessions() { + std::map> tmp_sessions; + { + mutex_lock l(mu_); + swap(sessions_, tmp_sessions); + } + for (auto& session : tmp_sessions) { + session.second.reset(); + } + + return absl::OkStatus(); } Status SessionMgr::WorkerSessionForSessionLocked( @@ -406,7 +419,7 @@ Status SessionMgr::WorkerSessionForSessionLocked( *out_session = it->second; } } - return OkStatus(); + return absl::OkStatus(); } Status SessionMgr::WorkerSessionForSession( diff --git a/tensorflow/core/distributed_runtime/session_mgr.h b/tensorflow/core/distributed_runtime/session_mgr.h index fc16faa9ccbc2b..f3339f568747be 100644 --- a/tensorflow/core/distributed_runtime/session_mgr.h +++ b/tensorflow/core/distributed_runtime/session_mgr.h @@ -94,6 +94,9 @@ class SessionMgr { Status DeleteSession(const std::string& session); + // Deletes all existing sessions. + Status DeleteAllSessions(); + // Provides access to the coordination service agent. This method should only // be called after the agent has been initialized during session creation, or // an invalid nullptr is returned. Note: the agent is thread-safe and mutable. diff --git a/tensorflow/core/distributed_runtime/session_mgr_test.cc b/tensorflow/core/distributed_runtime/session_mgr_test.cc index 21982e7d86a33f..9e19a878750e77 100644 --- a/tensorflow/core/distributed_runtime/session_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/session_mgr_test.cc @@ -61,7 +61,7 @@ class SessionMgrTest : public ::testing::Test { SessionMgr::WorkerCacheFactory factory_ = [](const ServerDef& server_def, WorkerCacheInterface** worker_cache) { *worker_cache = nullptr; // Set to null to make debugging easier. - return OkStatus(); + return absl::OkStatus(); }; SessionMgr mgr_; }; @@ -194,11 +194,13 @@ TEST_F(SessionMgrTest, CreateSessionWithMasterName) { cluster_device_attributes, true, master_name, new_incarnation)); - EXPECT_NE(mgr_.WorkerSessionForSession(sess_handle1, &session), OkStatus()) + EXPECT_NE(mgr_.WorkerSessionForSession(sess_handle1, &session), + absl::OkStatus()) << "Session for " << sess_handle1 << " should have been garbage collected."; - EXPECT_NE(mgr_.WorkerSessionForSession(sess_handle2, &session), OkStatus()) + EXPECT_NE(mgr_.WorkerSessionForSession(sess_handle2, &session), + absl::OkStatus()) << "Session for " << sess_handle2 << " should have been garbage collected."; diff --git a/tensorflow/core/distributed_runtime/tensor_coding.cc b/tensorflow/core/distributed_runtime/tensor_coding.cc index 70e7fda6bb99c1..4779fb5777742b 100644 --- a/tensorflow/core/distributed_runtime/tensor_coding.cc +++ b/tensorflow/core/distributed_runtime/tensor_coding.cc @@ -101,9 +101,9 @@ Status TensorResponse::ParseFrom(Source* source) { ClearTensor(); } already_used_ = true; - if (ParseFast(source)) return OkStatus(); + if (ParseFast(source)) return absl::OkStatus(); meta_.Clear(); - if (ParseSlow(source)) return OkStatus(); + if (ParseSlow(source)) return absl::OkStatus(); return errors::InvalidArgument("Cannot parse tensor from response"); } diff --git a/tensorflow/core/distributed_runtime/test_utils.h b/tensorflow/core/distributed_runtime/test_utils.h index 2e787c17d79eb9..ec8ba7be22da8e 100644 --- a/tensorflow/core/distributed_runtime/test_utils.h +++ b/tensorflow/core/distributed_runtime/test_utils.h @@ -187,7 +187,7 @@ class TestWorkerCache : public WorkerCacheInterface { auto it = localities_.find(device); if (it != localities_.end()) { *locality = it->second; - done(OkStatus()); + done(absl::OkStatus()); return; } done(errors::Internal("Device not found: ", device)); diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc index 9d3f8187f8bdfe..42708852b36f36 100644 --- a/tensorflow/core/distributed_runtime/worker.cc +++ b/tensorflow/core/distributed_runtime/worker.cc @@ -53,7 +53,7 @@ void Worker::GetStatusAsync(CallOptions* opts, const GetStatusRequest* request, for (auto& d : devices) { response->add_device_attributes()->Swap(&d); } - done(OkStatus()); + done(absl::OkStatus()); } void Worker::CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request, @@ -142,7 +142,7 @@ Status Worker::PrepareRunGraph(RunGraphRequestWrapper* req, for (size_t i = 0; i < req->num_recvs(); ++i) { out->insert({req->recv_key(i), empty_tensor}); } - return OkStatus(); + return absl::OkStatus(); } void Worker::RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request, @@ -151,7 +151,7 @@ void Worker::RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request, if (request->store_errors_in_response_body()) { done = [response, done](const Status& status) { response->set_status(status); - done(OkStatus()); + done(absl::OkStatus()); }; } if (request->is_partial()) { @@ -371,7 +371,7 @@ void Worker::CleanupGraphAsync(const CleanupGraphRequest* request, sam->Cleanup(step_id); } } - done(OkStatus()); + done(absl::OkStatus()); } void Worker::CleanupAllAsync(const CleanupAllRequest* request, @@ -380,7 +380,7 @@ void Worker::CleanupAllAsync(const CleanupAllRequest* request, std::vector containers; for (const auto& c : request->container()) containers.push_back(c); env_->device_mgr->ClearContainers(containers); - done(OkStatus()); + done(absl::OkStatus()); } void Worker::LoggingAsync(const LoggingRequest* request, @@ -489,7 +489,7 @@ Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed, distributed_runtime::WorkerPossiblyRestarted().SerializeAsString()}}); } - return OkStatus(); + return absl::OkStatus(); } void Worker::RecvTensorAsync(CallOptions* opts, diff --git a/tensorflow/core/distributed_runtime/worker_cache_partial.cc b/tensorflow/core/distributed_runtime/worker_cache_partial.cc index 7c53d6f0fed938..f224094f5f7d5b 100644 --- a/tensorflow/core/distributed_runtime/worker_cache_partial.cc +++ b/tensorflow/core/distributed_runtime/worker_cache_partial.cc @@ -51,7 +51,7 @@ void WorkerCachePartial::GetDeviceLocalityAsync(const string& device_name, }); return; } - done(OkStatus()); + done(absl::OkStatus()); } Status WorkerCachePartial::RefreshDeviceStatus(const string& device_name) { diff --git a/tensorflow/core/distributed_runtime/worker_session.cc b/tensorflow/core/distributed_runtime/worker_session.cc index 9fbf08760ae9d8..1bbf1a7bb6c329 100644 --- a/tensorflow/core/distributed_runtime/worker_session.cc +++ b/tensorflow/core/distributed_runtime/worker_session.cc @@ -158,7 +158,7 @@ Status WorkerSession::UpdateWorkerCacheAndDevices( TF_RETURN_IF_ERROR(remote_device_mgr_->RemoveDevices(removed_remote_devices)); TF_RETURN_IF_ERROR( remote_device_mgr_->AddDevices(std::move(added_remote_devices))); - return OkStatus(); + return absl::OkStatus(); } /* static */ diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index d82c215546f5f7..b5f24419245868 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -715,6 +715,7 @@ cc_library( "//learning/brain/google/data/core/kernels:__pkg__", "//learning/deepmind/tensorflow/queues:__pkg__", "//learning/deepmind/tensorflow/sstable:__pkg__", + "//tensorflow/compiler/mlir/tools/kernel_gen:__pkg__", ], deps = [ "//tensorflow/core/lib/core:refcount", @@ -1384,6 +1385,7 @@ tf_cc_tests( "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", "@eigen_archive//:eigen3", + "@local_tsl//tsl/profiler/utils:xplane_utils", ], ) diff --git a/tensorflow/core/framework/allocator_test.cc b/tensorflow/core/framework/allocator_test.cc index 0eaf97d4b054c3..7e85b25a9df6f7 100644 --- a/tensorflow/core/framework/allocator_test.cc +++ b/tensorflow/core/framework/allocator_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" +#include "tsl/profiler/utils/xplane_utils.h" namespace tensorflow { @@ -238,16 +239,16 @@ TEST(CPUAllocatorTest, ProfilerReporting) { EXPECT_EQ(OkStatus(), profiler->CollectData(&xspace)); // Validate the output - ASSERT_EQ(xspace.planes_size(), 1) << "XSpace: " << xspace.DebugString(); - const auto& plane = xspace.planes(0); - ::tensorflow::profiler::XPlaneVisitor xplane(&plane); + const auto plane = ::tsl::profiler::FindPlaneWithName( + xspace, ::tensorflow::profiler::kHostThreadsPlaneName); + ::tensorflow::profiler::XPlaneVisitor xplane(plane); - ASSERT_EQ(plane.name(), ::tensorflow::profiler::kHostThreadsPlaneName) + ASSERT_EQ(plane->name(), ::tensorflow::profiler::kHostThreadsPlaneName) << "XSpace: " << xspace.DebugString(); - ASSERT_EQ(plane.event_metadata_size(), 2) + ASSERT_EQ(plane->event_metadata_size(), 2) << "XSpace: " << xspace.DebugString(); - const auto& line = plane.lines(0); + const auto& line = plane->lines(0); ASSERT_EQ(line.events_size(), 2) << "XSpace: " << xspace.DebugString(); const auto& events = line.events(); diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h index 4b43fb226af248..ffbdfc0d038c8b 100644 --- a/tensorflow/core/framework/collective.h +++ b/tensorflow/core/framework/collective.h @@ -265,6 +265,9 @@ class CollectiveExecutorMgrInterface : public StepSequenceInterface { // table. virtual void Cleanup(int64_t step_id) = 0; + // Cleanup the entire table, removing all entries for step_ids. + virtual void CleanupAll() = 0; + virtual ParamResolverInterface* GetParamResolver() const = 0; virtual DeviceResolverInterface* GetDeviceResolver() const = 0; diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc index 1623798a3263b8..5b86ffc51b5988 100644 --- a/tensorflow/core/framework/dataset.cc +++ b/tensorflow/core/framework/dataset.cc @@ -546,8 +546,7 @@ Status IteratorBase::InitializeBase(IteratorContext* ctx, auto factory = [ctx, this](model::Node::Args args) { return CreateNode(ctx, std::move(args)); }; - model->AddNode(std::move(factory), prefix(), parent->model_node(), - &node_); + model->AddNode(std::move(factory), name(), parent->model_node(), &node_); cleanup_fns_.push_back([this, model]() { model->RemoveNode(node_); }); } } @@ -709,6 +708,10 @@ void WarnProtoConflicts(const protobuf::Message& src, protobuf::Message* dst) { set_dst.end(), std::back_inserter(in_both)); for (auto field : in_both) { + // Used for Job Instrumentation, users should not be warned. + if (field->name() == "framework_type") { + continue; + } if (field->type() == protobuf::FieldDescriptor::TYPE_MESSAGE) { WarnProtoConflicts(reflection->GetMessage(src, field), reflection->MutableMessage(dst, field)); @@ -773,10 +776,17 @@ void DatasetBase::Initialize(const Metadata& metadata) { LOG_EVERY_N_SEC(ERROR, 10) << s; } metadata_ = metadata; + if (absl::StrContains(metadata_.name(), ":")) { + // Type string is already included in the name, no need to add it. + return; + } if (metadata_.name() == "") { static std::atomic id_counter(0); *metadata_.mutable_name() = strings::StrCat(type_string(), ":", id_counter.fetch_add(1)); + } else { + *metadata_.mutable_name() = + strings::StrCat(type_string(), ":", metadata_.name()); } } diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index b74cfa3e6b01f4..40b95f8e26b887 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -25,6 +25,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" @@ -666,6 +667,7 @@ class IteratorContext { interleave_depth(ctx->interleave_depth()), is_restoring(ctx->is_restoring()), model(ctx->model()), + options(ctx->options()), ram_budget_manager(ctx->ram_budget_manager()), resource_mgr(ctx->resource_mgr()), runner(*(ctx->runner())), @@ -676,7 +678,8 @@ class IteratorContext { thread_factory(ctx->thread_factory()), thread_pool(ctx->thread_pool()), id_registry(ctx->id_registry()), - warm_start(ctx->warm_start()) {} + warm_start(ctx->warm_start()), + index_mapper(ctx->index_mapper()) {} explicit Params(OpKernelContext* ctx) : collective_executor(ctx->collective_executor()), @@ -736,12 +739,12 @@ class IteratorContext { // If non-null, identifies the object used for performance modeling. std::shared_ptr model = nullptr; - // Manager for the ram budget when using autotune. - std::shared_ptr ram_budget_manager = nullptr; - // The input pipeline options. const Options* options = nullptr; + // Manager for the ram budget when using autotune. + std::shared_ptr ram_budget_manager = nullptr; + // A resource manager for storing dataset-related state, e.g. random // seeds or cached tensors. Not owned. ResourceMgr* resource_mgr = nullptr; @@ -781,6 +784,11 @@ class IteratorContext { // Specifies the tf.data pipeline run mode. RunMode run_mode = RunMode::DEFAULT; + + // Maps the index of dataset elements to a shuffled index. In other words, + // given an index i, returns the permuted index p(i) for the iterator. Used + // to support global shuffling of datasets that support random access. + std::function index_mapper = nullptr; }; explicit IteratorContext(IteratorContext* ctx) @@ -835,6 +843,8 @@ class IteratorContext { const std::shared_ptr& model() const { return params_.model; } + const Options* options() const { return params_.options; } + const std::shared_ptr& ram_budget_manager() { return params_.ram_budget_manager; } @@ -867,6 +877,12 @@ class IteratorContext { RunMode run_mode() { return params_.run_mode; } + std::function index_mapper() const { + return params_.index_mapper; + } + + void SetModel(std::shared_ptr model) { params_.model = model; } + std::unique_ptr CreateThreadPool(const string& name, int num_threads) { if (params_.thread_pool) { @@ -1025,6 +1041,10 @@ class IteratorBase : public Checkpointable { // this iterator. virtual const string& prefix() const = 0; + // Returns a string identifying the iterator, e.g. "ParallelMapDatasetV2:" + // or "ParallelMapDatasetV2:". + virtual const string& name() const = 0; + // Indicates whether the iterator is compatible with symbolic checkpointing. virtual bool SymbolicCheckpointCompatible() const { return false; } @@ -1310,6 +1330,12 @@ class DatasetBase : public core::RefCounted { virtual Status Get(OpKernelContext* ctx, int64 index, std::vector* out_tensors) const; + // Returns true if the dataset and its inputs support random access. + virtual absl::Status RandomIndexingCompatible() const { + return absl::FailedPreconditionError( + absl::StrCat(type_string(), " does not support random access.")); + } + // Return a finalized version of the dataset. The returned DatasetBase is // unowned and lives for as long as this dataset. virtual StatusOr Finalize( @@ -1415,6 +1441,8 @@ class DatasetBaseIterator : public IteratorBase { const string& prefix() const override { return params_.prefix; } + const string& name() const override { return dataset()->metadata().name(); } + // Returns a name to be used for the TraceMe event. // // NOTE: TraceMe supports passing key-value pairs of "arguments" using the @@ -1438,6 +1466,16 @@ class DatasetBaseIterator : public IteratorBase { return IteratorBase::Save(ctx, writer); } + // Returns a copy of the `status` where the error message is prepended with + // dataset name and the iterator prefix. + Status AddErrorContext(const Status& status) const { + return Status(status.code(), + strings::StrCat("Error in user-defined function passed to ", + dataset()->metadata().name(), + " transformation with iterator: ", prefix(), + ": ", status.message())); + } + protected: Status Restore(IteratorContext* ctx, IteratorStateReader* reader) final { VLOG(2) << "Attempting to restore checkpoints on iterator (prefix: " diff --git a/tensorflow/core/framework/dataset_options.proto b/tensorflow/core/framework/dataset_options.proto index 1fe7c0bd8a8895..26276adec73178 100644 --- a/tensorflow/core/framework/dataset_options.proto +++ b/tensorflow/core/framework/dataset_options.proto @@ -27,7 +27,7 @@ enum AutoShardPolicy { OFF = -1; } -// next: 5 +// next: 6 message AutotuneOptions { // Whether to automatically tune performance knobs. oneof optional_enabled { @@ -54,6 +54,15 @@ message AutotuneOptions { oneof optional_autotune_algorithm { model.AutotuneAlgorithm autotune_algorithm = 4; } + + // The initial parallelism to use for parallel transformations before autotune + // has a chance to run. A higher value can help with quick startup, but may + // cause the ram_budget to temporarily be exceeded. Memory-sensitive datasets + // should consider setting this to `1` to avoid running out of memory. + // Defaults to 16. + oneof optional_initial_parallelism { + int64 initial_parallelism = 5; + } } // next: 2 @@ -82,7 +91,7 @@ message DistributeOptions { } } -// next: 21 +// next: 22 message OptimizationOptions { // Whether to apply default graph optimizations. If False, only graph // optimizations that have been explicitly enabled will be applied. @@ -151,6 +160,12 @@ message OptimizationOptions { } // NOTE: field id 20 was removed in August 2023. reserved 20; + // Whether to replace parallel interleave with interleave and prefetch. Only + // takes effect if the parallel interleave is deterministic; otherwise does + // nothing. + oneof optional_seq_interleave_prefetch { + bool seq_interleave_prefetch = 21; + } } // next: 3 @@ -175,13 +190,14 @@ enum ExternalStatePolicy { // Message stored with Dataset objects to control how datasets are processed and // optimized. // -// next: 11 +// next: 12 message Options { // Optional name for the dataset. oneof optional_dataset_name { string dataset_name = 10; } - + // List of frameworks used to generate this dataset. + repeated string framework_type = 11; // Whether the outputs need to be produced in deterministic order. oneof optional_deterministic { bool deterministic = 1; diff --git a/tensorflow/core/framework/metrics.cc b/tensorflow/core/framework/metrics.cc index 6a1796d9e457c7..43e3dc5c11d9a7 100644 --- a/tensorflow/core/framework/metrics.cc +++ b/tensorflow/core/framework/metrics.cc @@ -35,6 +35,11 @@ auto* persistent_cache_load_count = tsl::monitoring::Counter<0>::New( "/tensorflow/core/persistent_cache_load_count", "The number of times a binary is loaded from the persistent cache."); +auto* aot_bef_mlir_load_count = tsl::monitoring::Counter<0>::New( + "/tensorflow/core/aot_bef_mlir_load_count", + "The number of times BEF and MLIR are deserialized instead of generated " + "and used."); + auto* graph_runs = tsl::monitoring::Counter<0>::New( "/tensorflow/core/graph_runs", "The number of graph executions used to collect " @@ -301,11 +306,19 @@ auto* tf_data_autotune_stopping_criteria_counter = "algorithm stopping criterion is met.", "name"); +auto* tf_data_debug = tsl::monitoring::Counter<1>::New( + "/tensorflow/data/debug", + "The number of times this event occured, for debugging.", "event"); + auto* tf_data_error = tsl::monitoring::Counter<2>::New( "/tensorflow/data/error", "The number of times an error of this type occurred with this status code.", "error_type", "status_code"); +auto* tf_data_framework_type = tsl::monitoring::Counter<1>::New( + "/tensorflow/data/framework_type", + "The framework type used to build the tf.data.Dataset.", "framework_type"); + auto* parse_dense_feature_counter = tsl::monitoring::Counter<0>::New( "/tensorflow/data/dense_feature", "The number of dense features parsed by ops for parsing tf.Example."); @@ -670,10 +683,18 @@ void RecordTFDataAutotuneStoppingCriteria(const string& name) { tf_data_autotune_stopping_criteria_counter->GetCell(name)->IncrementBy(1); } +void RecordTFDataDebug(const string& event) { + tf_data_debug->GetCell(event)->IncrementBy(1); +} + void RecordTFDataError(const string& error_type, const string& status_code) { tf_data_error->GetCell(error_type, status_code)->IncrementBy(1); } +void RecordTFDataFrameworkType(const std::string& framework_type) { + tf_data_framework_type->GetCell(framework_type)->IncrementBy(1); +} + void RecordParseDenseFeature(int64 num_features) { static auto* parse_dense_feature_counter_cell = parse_dense_feature_counter->GetCell(); @@ -715,6 +736,12 @@ void UpdatePersistentCacheLoadCount() { persistent_cache_load_count_cell->IncrementBy(1); } +void UpdateAotBefMlirLoadCount() { + static auto* aot_bef_mlir_load_count_cell = + aot_bef_mlir_load_count->GetCell(); + aot_bef_mlir_load_count_cell->IncrementBy(1); +} + void UpdateGraphExecTime(const uint64 running_time_usecs) { if (running_time_usecs > 0) { static auto* graph_runs_cell = graph_runs->GetCell(); diff --git a/tensorflow/core/framework/metrics.h b/tensorflow/core/framework/metrics.h index c75a5fe81d237e..2fb2c2355f6903 100644 --- a/tensorflow/core/framework/metrics.h +++ b/tensorflow/core/framework/metrics.h @@ -217,10 +217,16 @@ void RecordTFDataAutoShardRewriteBatchSize( // criterion is met. void RecordTFDataAutotuneStoppingCriteria(const string& name); +// Records the number of times this event occured, for debugging. +void RecordTFDataDebug(const string& event); + // Records the number of times an error of this type occurred with this status // code. void RecordTFDataError(const string& error_type, const string& error_code); +// Records the framework type used to build the tf.data.Dataset. +void RecordTFDataFrameworkType(const std::string& framework_type); + // Records parsing of dense tensor features. void RecordParseDenseFeature(int64_t num_features); @@ -250,6 +256,9 @@ void RecordPipelineProcessingTime(const string& id, // Increments the count of binaries loaded from the persistent cache. void UpdatePersistentCacheLoadCount(); +// Increments the count of BEF and MLIR deserialized. +void UpdateAotBefMlirLoadCount(); + // Updates the metrics stored about time spent building graphs. // // By "GraphBuild", we refer to building a client graph, which is a sub-graph of diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc index 2e8188f9caa0b2..41594a5cf387ab 100644 --- a/tensorflow/core/framework/model.cc +++ b/tensorflow/core/framework/model.cc @@ -270,11 +270,6 @@ inline bool IsSyncNode(const std::shared_ptr node) { return !node->IsAsync(); } -// Helper function for node traversal that returns only `DataService` nodes. -inline bool IsDataServiceNode(const std::shared_ptr node) { - return absl::StartsWith(node->name(), kDataService); -} - // Helper function for node traversal that returns only asynchronous interleave // many nodes. inline bool IsAsyncInterleaveManyNode(const std::shared_ptr node) { @@ -888,9 +883,9 @@ class AsyncRatio : public Node { std::vector> parameters, bool is_legacy_prefetch_autotuned = false) : Node(args), + is_legacy_prefetch_autotuned_(is_legacy_prefetch_autotuned), ratio_(ratio), - memory_ratio_(memory_ratio), - is_legacy_prefetch_autotuned_(is_legacy_prefetch_autotuned) { + memory_ratio_(memory_ratio) { for (auto& parameter : parameters) { parameters_[parameter->name] = std::move(parameter); } @@ -1119,6 +1114,10 @@ class AsyncRatio : public Node { return result; } + // Whether this node represents a prefetch node tuned by the legacy prefetch + // autotuner, rather than the model. + const bool is_legacy_prefetch_autotuned_; + private: // Identifies how many input elements need to be created to construct an // element for the dataset. @@ -1131,9 +1130,6 @@ class AsyncRatio : public Node { // budget bound with given num_parallel_calls (or buffer_size) combined with // the estimated average size of buffered elements. const double memory_ratio_; - // Whether this node represents a prefetch node tuned by the legacy prefetch - // autotuner, rather than the model. - const bool is_legacy_prefetch_autotuned_; }; class UnknownRatio : public Node { @@ -1338,6 +1334,18 @@ class AsyncKnownRatio : public AsyncRatio { node_proto->set_node_class(NodeClass::ASYNC_KNOWN_RATIO); node_proto->set_ratio(Ratio()); node_proto->set_memory_ratio(MemoryRatio()); + if (is_legacy_prefetch_autotuned_) { + // Update buffer_size parameter to make sense from a user perspective. + if (node_proto->parameters_size() != 1) { + return absl::InternalError(absl::StrCat( + "Expected prefetch node to have one parameter, but it has ", + node_proto->parameters_size())); + } + auto* parameter = node_proto->mutable_parameters(0); + // Legacy autotuner only modifies the state_value, not the model value. + parameter->set_value(parameter->state_value()); + parameter->set_tunable(true); + } return OkStatus(); } }; @@ -2278,9 +2286,8 @@ void Model::AddNode(Node::Factory factory, const string& name, std::shared_ptr* out_node) { // The name captures the sequence of iterators joined by `::`. We only use the // last element of the sequence as the name node. - auto node_name = str_util::Split(name, ':', str_util::SkipEmpty()).back(); mutex_lock l(mu_); - std::shared_ptr node = factory({id_counter_++, node_name, parent}); + std::shared_ptr node = factory({id_counter_++, name, parent}); if (!output_) { output_ = node; } @@ -2445,9 +2452,11 @@ Model::ModelParameters Model::CollectTunableParameters( } void Model::MaybeSyncStateValuesToValues(std::shared_ptr snapshot) { - auto subtree_nodes = - snapshot->CollectNodes(TraversalOrder::BFS, IsDataServiceNode); + auto subtree_nodes = snapshot->CollectNodes(TraversalOrder::BFS, IsAnyNode); for (const auto& node : subtree_nodes) { + if (!absl::StartsWith(node->name(), kDataService)) { + continue; + } node->SyncStateValuesToParameterValues(kBufferSize); } } diff --git a/tensorflow/core/framework/variant_encode_decode.h b/tensorflow/core/framework/variant_encode_decode.h index f481714eee2c4b..05ecb775709d3d 100644 --- a/tensorflow/core/framework/variant_encode_decode.h +++ b/tensorflow/core/framework/variant_encode_decode.h @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/abi.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" namespace tensorflow { @@ -68,7 +69,10 @@ void EncodeVariantImpl(const T& value, TypeResolver, VariantTensorData* data) { - value.SerializeToString(&data->metadata_); + if (!value.SerializeToString(&data->metadata_)) { + data->metadata_.clear(); + LOG(ERROR) << "Failed to encode variant " << value.DebugString(); + } } // Specialization for other types diff --git a/tensorflow/core/function/runtime_client/runtime_client.cc b/tensorflow/core/function/runtime_client/runtime_client.cc index 6438a1ca2b83c1..1458d1d94b4913 100644 --- a/tensorflow/core/function/runtime_client/runtime_client.cc +++ b/tensorflow/core/function/runtime_client/runtime_client.cc @@ -89,7 +89,7 @@ EagerContext& GlobalPythonEagerContext() { return *ctx; } -StatusOr Runtime::GetFunctionProto(StringPiece name) { +absl::StatusOr Runtime::GetFunctionProto(StringPiece name) { EagerContext& ctx = this->eager_ctx_; const FunctionDef* f = ctx.FindFunctionDef(std::string(name)); @@ -170,7 +170,7 @@ Status Runtime::TransformFunction(StringPiece name, StringPiece pipeline_name, CreateFunction(reinterpret_cast(&fn)), absl::StrCat("updating function ", fn.getName().str())); } - return OkStatus(); + return absl::OkStatus(); } if (dialect == Dialect::TF) { @@ -196,7 +196,7 @@ Status Runtime::TransformFunction(StringPiece name, StringPiece pipeline_name, CreateFunction(reinterpret_cast(&fn)), absl::StrCat("updating function ", fn.getName().str())); } - return OkStatus(); + return absl::OkStatus(); } return Status( @@ -205,7 +205,7 @@ Status Runtime::TransformFunction(StringPiece name, StringPiece pipeline_name, ". Supported dialects are Dialect::TFG and Dialect::TF.")); } -StatusOr Runtime::CallFunction( +absl::StatusOr Runtime::CallFunction( StringPiece name, absl::Span args) { EagerContext& ctx = this->eager_ctx_; diff --git a/tensorflow/core/function/runtime_client/runtime_client.h b/tensorflow/core/function/runtime_client/runtime_client.h index e542fa56e039ac..e2cffdf4d74796 100644 --- a/tensorflow/core/function/runtime_client/runtime_client.h +++ b/tensorflow/core/function/runtime_client/runtime_client.h @@ -70,7 +70,7 @@ class Runtime { TF, }; - StatusOr GetFunctionProto(StringPiece name); + absl::StatusOr GetFunctionProto(StringPiece name); // TODO(mdan): Enforce creation or rename to SetFunction. Status CreateFunction(const FunctionDef& fdef); @@ -85,7 +85,7 @@ class Runtime { Status TransformFunction(StringPiece name, StringPiece pipeline_name, Dialect dialect = Dialect::TFG); - StatusOr CallFunction( + absl::StatusOr CallFunction( StringPiece name, absl::Span args); private: diff --git a/tensorflow/core/function/runtime_client/runtime_client_test.cc b/tensorflow/core/function/runtime_client/runtime_client_test.cc index 9cfe63fa23738e..91effd88b7f701 100644 --- a/tensorflow/core/function/runtime_client/runtime_client_test.cc +++ b/tensorflow/core/function/runtime_client/runtime_client_test.cc @@ -208,7 +208,7 @@ TEST(GlobalContext, Basic) { Runtime rt(GlobalEagerContext()); TF_ASSERT_OK(rt.CreateFunction(MakeNullaryFunction())); - StatusOr rets = rt.CallFunction("NullaryFunction", {}); + absl::StatusOr rets = rt.CallFunction("NullaryFunction", {}); TF_ASSERT_OK(rets.status()); ASSERT_EQ(rets->size(), 1); ASSERT_EQ(rets->at(0)->DataType(), DT_INT32); @@ -220,7 +220,7 @@ TEST(CreateTest, Call) { Runtime rt(*ctx); TF_ASSERT_OK(rt.CreateFunction(MakeNullaryFunction())); - StatusOr rets = rt.CallFunction("NullaryFunction", {}); + absl::StatusOr rets = rt.CallFunction("NullaryFunction", {}); TF_ASSERT_OK(rets.status()); ASSERT_EQ(rets->size(), 1); ASSERT_EQ(rets->at(0)->DataType(), DT_INT32); @@ -232,7 +232,7 @@ TEST(CreateTest, GetRoundtrip) { Runtime rt(*ctx); TF_ASSERT_OK(rt.CreateFunction(MakeNullaryFunction())); - StatusOr fdef_ret = rt.GetFunctionProto("NullaryFunction"); + absl::StatusOr fdef_ret = rt.GetFunctionProto("NullaryFunction"); TF_ASSERT_OK(fdef_ret.status()); FunctionDef fdef = *fdef_ret; @@ -240,7 +240,7 @@ TEST(CreateTest, GetRoundtrip) { TF_ASSERT_OK(rt.CreateFunction(fdef)); - StatusOr rets = rt.CallFunction("SecondFunction", {}); + absl::StatusOr rets = rt.CallFunction("SecondFunction", {}); TF_ASSERT_OK(rets.status()); ASSERT_EQ(rets->size(), 1); ASSERT_EQ(rets->at(0)->DataType(), DT_INT32); @@ -276,7 +276,7 @@ TEST(CreateTest, MlirFromGraphDef) { reinterpret_cast(&fop); TF_ASSERT_OK(rt.CreateFunction(opaque_fop)); - StatusOr rets = rt.CallFunction("NullaryFunction", {}); + absl::StatusOr rets = rt.CallFunction("NullaryFunction", {}); TF_ASSERT_OK(rets.status()); ASSERT_EQ(rets->size(), 1); ASSERT_EQ(rets->at(0)->DataType(), DT_INT32); @@ -288,7 +288,7 @@ TEST(CallTest, Nullary) { Runtime rt(*ctx); TF_ASSERT_OK(rt.CreateFunction(MakeNullaryFunction())); - StatusOr rets = rt.CallFunction("NullaryFunction", {}); + absl::StatusOr rets = rt.CallFunction("NullaryFunction", {}); TF_ASSERT_OK(rets.status()); ASSERT_EQ(rets->size(), 1); ASSERT_EQ(rets->at(0)->DataType(), DT_INT32); @@ -301,7 +301,8 @@ TEST(CallTest, Unary) { TF_ASSERT_OK(rt.CreateFunction(MakeUnaryFunction())); auto x = IntScalarTensor(*ctx, 1); - StatusOr rets = rt.CallFunction("UnaryFunction", {x.get()}); + absl::StatusOr rets = + rt.CallFunction("UnaryFunction", {x.get()}); TF_ASSERT_OK(rets.status()); ASSERT_EQ(rets->size(), 1); ASSERT_EQ(rets->at(0)->DataType(), DT_INT32); @@ -315,7 +316,7 @@ TEST(CallTest, Binary) { auto x = IntScalarTensor(*ctx, 1); auto y = IntScalarTensor(*ctx, 1); - StatusOr rets = + absl::StatusOr rets = rt.CallFunction("BinaryFunction", {x.get(), y.get()}); TF_ASSERT_OK(rets.status()); ASSERT_EQ(rets->size(), 1); @@ -333,7 +334,7 @@ TEST(TransformTest, TestPassOnBinaryFunction) { auto x = IntScalarTensor(*ctx, 2); auto y = IntScalarTensor(*ctx, 3); - StatusOr rets = + absl::StatusOr rets = rt.CallFunction("BinaryFunction", {x.get(), y.get()}); TF_ASSERT_OK(rets.status()); ASSERT_EQ(rets->size(), 1); @@ -352,7 +353,7 @@ TEST(TransformTest, TestPassOnMultiplyFunction) { auto x = IntScalarTensor(*ctx, 2); auto y = IntScalarTensor(*ctx, 3); - StatusOr rets = + absl::StatusOr rets = rt.CallFunction("MultiplyFunction", {x.get(), y.get()}); TF_ASSERT_OK(rets.status()); ASSERT_EQ(rets->size(), 1); @@ -372,7 +373,7 @@ TEST(TransformTest, TestMixedPassesOnBinaryFunction) { auto x = IntScalarTensor(*ctx, 2); auto y = IntScalarTensor(*ctx, 3); - StatusOr rets = + absl::StatusOr rets = rt.CallFunction("BinaryFunction", {x.get(), y.get()}); TF_ASSERT_OK(rets.status()); ASSERT_EQ(rets->size(), 1); diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index 85d0be75346447..6cf79b2c1c3be4 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -970,10 +970,14 @@ inline bool IsDistributedCommunication(const Node* n) { // https://en.cppreference.com/w/cpp/iterator/iterator). // Iterator for stepping through the nodes of a graph. -class NodeIter - : public std::iterator { +class NodeIter { public: + using iterator_category = std::forward_iterator_tag; + using value_type = Node; + using difference_type = std::ptrdiff_t; + using pointer = Node*; + using reference = Node*; + NodeIter(const Graph* graph, int id); bool operator==(const NodeIter& rhs) const; bool operator!=(const NodeIter& rhs) const; @@ -988,10 +992,14 @@ class NodeIter }; // Iterator for stepping through the neighbors of a node. -class NeighborIter - : public std::iterator { +class NeighborIter { public: + using iterator_category = std::forward_iterator_tag; + using value_type = Node; + using difference_type = std::ptrdiff_t; + using pointer = Node*; + using reference = Node*; + NeighborIter(EdgeSet::const_iterator iter, bool incoming); bool operator==(const NeighborIter& rhs) const; bool operator!=(const NeighborIter& rhs) const; diff --git a/tensorflow/core/graph/mkl_graph_util.h b/tensorflow/core/graph/mkl_graph_util.h index 79a1867155ef33..9ee13ac9a5d998 100644 --- a/tensorflow/core/graph/mkl_graph_util.h +++ b/tensorflow/core/graph/mkl_graph_util.h @@ -1,4 +1,4 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -143,7 +143,6 @@ inline string GetMklNativeOpName(const string& name) { // prefixed with _Mkl instead of _MklNative. bool result = (0 == name.compare("ConjugateTranspose") || - 0 == name.compare("SparseTensorDenseMatMul") || 0 == name.compare("BatchMatMul") || 0 == name.compare("BatchMatMulV2") || 0 == name.compare("Einsum") || 0 == name.compare("MatMul") || 0 == name.compare("Transpose") || 0 == name.compare("QuantizeV2") || diff --git a/tensorflow/core/graph/mkl_testlib.cc b/tensorflow/core/graph/mkl_testlib.cc index e8955da2f1748c..05d0b67d3e1a16 100644 --- a/tensorflow/core/graph/mkl_testlib.cc +++ b/tensorflow/core/graph/mkl_testlib.cc @@ -1,4 +1,4 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -32,17 +32,6 @@ Node* oneDNNSoftmax(Graph* g, Node* input) { return ret; } -Node* oneDNNSparseCSRMatmul(Graph* g, Node* csr_matrix_t, Node* b) { - Node* ret = nullptr; - TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_MklNativeSparseMatrixMatMul") - .Input(csr_matrix_t) - .Input(b) - .Attr("T", DT_FLOAT) - .Attr("_kernel", mkl_op_registry::kMklNameChangeOpLabel) - .Finalize(g, &ret)); - return ret; -} - } // namespace graph } // namespace test } // namespace tensorflow diff --git a/tensorflow/core/graph/mkl_testlib.h b/tensorflow/core/graph/mkl_testlib.h index 1b783923c1f03c..733f124168d949 100644 --- a/tensorflow/core/graph/mkl_testlib.h +++ b/tensorflow/core/graph/mkl_testlib.h @@ -1,4 +1,4 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -26,8 +26,6 @@ namespace graph { Node* oneDNNSoftmax(Graph* g, Node* input); -Node* oneDNNSparseCSRMatmul(Graph* g, Node* csr_matrix_t, Node* b); - } // namespace graph } // namespace test } // namespace tensorflow diff --git a/tensorflow/core/grappler/clusters/cluster.h b/tensorflow/core/grappler/clusters/cluster.h index 11b62675b299c9..f1c5f210f13a00 100644 --- a/tensorflow/core/grappler/clusters/cluster.h +++ b/tensorflow/core/grappler/clusters/cluster.h @@ -58,7 +58,7 @@ class Cluster { // Returns OK iff there are no pending calls to the Run() method and all the // resources used by the cluster could be released. Returns an error // otherwise. - virtual Status Shutdown() { return OkStatus(); } + virtual Status Shutdown() { return absl::OkStatus(); } // Whether soft placement is allowed. If allow_soft_placement is true, // an op will be placed on CPU if there's no GPU implementation for the OP diff --git a/tensorflow/core/grappler/clusters/single_machine.cc b/tensorflow/core/grappler/clusters/single_machine.cc index 9bd403f2c14781..92f17cc30d1a42 100644 --- a/tensorflow/core/grappler/clusters/single_machine.cc +++ b/tensorflow/core/grappler/clusters/single_machine.cc @@ -120,7 +120,7 @@ Status SingleMachine::Provision() { if (cpu_allocator_stats_enabled_) { TF_RETURN_IF_ERROR(ClearAllocatorStats()); } - return OkStatus(); + return absl::OkStatus(); } Status SingleMachine::Initialize(const GrapplerItem& item) { @@ -132,7 +132,7 @@ Status SingleMachine::Initialize(const GrapplerItem& item) { queue_runner_defs_ = item.queue_runners; last_graph_id_ = item.id; } - return OkStatus(); + return absl::OkStatus(); } Status SingleMachine::Shutdown() { @@ -142,7 +142,7 @@ Status SingleMachine::Shutdown() { last_graph_ = nullptr; already_provisioned = false; - return OkStatus(); + return absl::OkStatus(); } Status SingleMachine::Run(const GraphDef& graph_def, @@ -203,14 +203,14 @@ Status SingleMachine::Run(const GraphDef& graph_def, last_graph_ = &graph_def; - return OkStatus(); + return absl::OkStatus(); } Status SingleMachine::EnablePeakMemoryStats() { EnableCPUAllocatorStats(); cpu_allocator_stats_enabled_ = true; // No need to enable GPU allocator stats since its stats are always collected. - return OkStatus(); + return absl::OkStatus(); } Status SingleMachine::GetPeakMemoryUsage( @@ -238,7 +238,7 @@ Status SingleMachine::GetPeakMemoryUsage( (stats ? stats->peak_bytes_in_use : 0); } - return OkStatus(); + return absl::OkStatus(); } Status SingleMachine::RunWithTimeout( @@ -276,7 +276,7 @@ Status SingleMachine::RunWithTimeout( Status SingleMachine::CloseSession(bool use_timeout) { if (!session_ || !thread_pool_) { - return OkStatus(); + return absl::OkStatus(); } { @@ -317,7 +317,7 @@ Status SingleMachine::CloseSession(bool use_timeout) { " seconds, aborting")); } - return OkStatus(); + return absl::OkStatus(); } Status SingleMachine::ShutdownSession() { @@ -343,7 +343,7 @@ Status SingleMachine::ShutdownSession() { "The session is still running graphs after ", timeout_s_, " seconds")); } - return OkStatus(); + return absl::OkStatus(); } Status SingleMachine::ResetSession() { @@ -379,7 +379,7 @@ Status SingleMachine::ResetSession() { // We currently don't care about the client device. } - return OkStatus(); + return absl::OkStatus(); } void SingleMachine::MergeCosts(CostGraphDef* graph_costs, @@ -469,7 +469,7 @@ Status SingleMachine::ClearAllocatorStats() const { device->name())); } } - return OkStatus(); + return absl::OkStatus(); } } // namespace grappler diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.cc b/tensorflow/core/grappler/clusters/virtual_cluster.cc index 0a902e8ac1bd2e..03dea8dd13a378 100644 --- a/tensorflow/core/grappler/clusters/virtual_cluster.cc +++ b/tensorflow/core/grappler/clusters/virtual_cluster.cc @@ -57,10 +57,10 @@ VirtualCluster::VirtualCluster(const DeviceSet* device_set) VirtualCluster::~VirtualCluster() {} -Status VirtualCluster::Provision() { return OkStatus(); } +Status VirtualCluster::Provision() { return absl::OkStatus(); } Status VirtualCluster::Initialize(const GrapplerItem& item) { - return OkStatus(); + return absl::OkStatus(); } Status VirtualCluster::Run(const GraphDef& graph, @@ -114,7 +114,7 @@ Status VirtualCluster::Run(const GrapplerItem& item, RunMetadata* metadata) { } } - return OkStatus(); + return absl::OkStatus(); } } // namespace grappler diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc index 2f3975d73e78e5..10bf5ea4fc1aae 100644 --- a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc @@ -108,7 +108,7 @@ Status AddCostNode(ReadyNodeManager* node_manager, const OpContext& op_context, } output_info->set_size(size); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -151,7 +151,7 @@ AnalyticalCostEstimator::AnalyticalCostEstimator( Status AnalyticalCostEstimator::Initialize(const GrapplerItem& item) { item_ = &item; - return OkStatus(); + return absl::OkStatus(); } Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph, @@ -244,7 +244,7 @@ Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph, } } - return OkStatus(); + return absl::OkStatus(); } } // end namespace grappler diff --git a/tensorflow/core/grappler/costs/graph_memory.cc b/tensorflow/core/grappler/costs/graph_memory.cc index e937dbac274b7d..4099c2495edc00 100644 --- a/tensorflow/core/grappler/costs/graph_memory.cc +++ b/tensorflow/core/grappler/costs/graph_memory.cc @@ -45,7 +45,7 @@ Status GraphMemory::InferStatically( return s; } InferFromTrace(metadata.step_stats()); - return OkStatus(); + return absl::OkStatus(); } Status GraphMemory::InferDynamically(Cluster* cluster) { @@ -57,7 +57,7 @@ Status GraphMemory::InferDynamically(Cluster* cluster) { RunMetadata metadata; TF_RETURN_IF_ERROR(cluster->Run(item_, &metadata)); InferFromTrace(metadata.step_stats()); - return OkStatus(); + return absl::OkStatus(); } int64_t GraphMemory::GetWorstCaseMemoryUsage() const { diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 84b33460db1b03..94b86ba901c44d 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -231,7 +231,7 @@ Status DisjointSet::Merge(Handle x, Handle y) { // x and y are already in the same set if (x_root == y_root) { - return OkStatus(); + return absl::OkStatus(); } // x and y are not in same set, so we merge them // Use the occasion to strengthen what we know about the handle by merging the @@ -248,7 +248,7 @@ Status DisjointSet::Merge(Handle x, Handle y) { y_root->parent = x_root; x_root->rank = x_root->rank + 1; } - return OkStatus(); + return absl::OkStatus(); } template @@ -807,7 +807,7 @@ class SymbolicShapeRefiner { TF_RETURN_IF_ERROR(SetUnknownShape(function_node, i)); } - return OkStatus(); + return absl::OkStatus(); } // Copy (not reference) so that changes we make here (e.g., replacing @@ -983,7 +983,7 @@ class SymbolicShapeRefiner { output++; } - return OkStatus(); + return absl::OkStatus(); } // Prepares input shapes/values/handles, then runs shape inference, and @@ -1071,7 +1071,7 @@ class SymbolicShapeRefiner { if (!*refined) { // No input shape has changed, we're done. - return OkStatus(); + return absl::OkStatus(); } // Convert all kUnknownDimFromConst to -1 for shape inference. @@ -1097,14 +1097,14 @@ class SymbolicShapeRefiner { // get output values or output shapes as tensor from function node. auto s = UpdateOutputShapesUsingAnnotatedInformation(*node, ctx); if (s.ok() && AllOutputShapesKnown(ctx)) { - return OkStatus(); + return absl::OkStatus(); } // If shape annotation was not available, incomplete, or incompatible, // fall through to call UpdateFunction(). } auto s = UpdateFunction(node); if (s.ok()) { - return OkStatus(); + return absl::OkStatus(); } else { VLOG(1) << "UpdateFunction failed for " << node->op() << ". Defaulting to ShapeUnknown.\n" @@ -1144,7 +1144,7 @@ class SymbolicShapeRefiner { ") but was ", output_port); } ctx->set_output(output_port, shape); - return OkStatus(); + return absl::OkStatus(); } struct ShapeId { @@ -1309,7 +1309,7 @@ class SymbolicShapeRefiner { const std::string& function_name) { auto it = fun_to_grappler_function_item_.find(function_name); if (it != fun_to_grappler_function_item_.end()) { - return OkStatus(); + return absl::OkStatus(); } const FunctionDef* function_def = @@ -1325,7 +1325,7 @@ class SymbolicShapeRefiner { << function_instantiated.message(); fun_to_grappler_function_item_[function_def->signature().name()] = absl::nullopt; - return OkStatus(); + return absl::OkStatus(); } if (static_cast(grappler_function_item.inputs().size()) > @@ -1348,7 +1348,7 @@ class SymbolicShapeRefiner { fun_to_grappler_function_item_[function_def->signature().name()] = grappler_function_item; - return OkStatus(); + return absl::OkStatus(); } Status AddNode(const NodeDef* node) { @@ -1629,7 +1629,7 @@ class SymbolicShapeRefiner { const_tensors_to_propagate_.push_back(tensor_proto); c->output_tensor_protos[k] = &const_tensors_to_propagate_.back(); } - return OkStatus(); + return absl::OkStatus(); } // Update output shapes with annotated information. @@ -1641,7 +1641,7 @@ class SymbolicShapeRefiner { const auto& attr = node.attr(); if (attr.count(kOutputSame) == 0 || !attr.at(kOutputSame).b() || attr.count(kOutputShapes) == 0) - return OkStatus(); + return absl::OkStatus(); InferenceContext* ic = c->inference_context.get(); int output_size = attr.at(kOutputShapes).list().shape_size(); @@ -1694,7 +1694,7 @@ class SymbolicShapeRefiner { } } - return OkStatus(); + return absl::OkStatus(); } Status MaybeUpdateNodeContextOutput(const NodeDef& node, const bool is_fed, @@ -1911,11 +1911,11 @@ class SymbolicShapeRefiner { const int max_element_size = 17; // Max up to 4x4 matrix or similar. if (AllOutputValuesKnown(c) || !AllInputValuesKnown(c) || !ShouldUpdateOutputShapesAndValues(c, max_element_size)) { - return OkStatus(); + return absl::OkStatus(); } UpdateOutputShapesAndValues(node, c).IgnoreError(); // This is optional. } - return OkStatus(); + return absl::OkStatus(); } Status InferShapes(const NodeDef& node, NodeContext* c) { @@ -1929,7 +1929,7 @@ class SymbolicShapeRefiner { TF_RETURN_IF_ERROR( c->inference_context->Run(shape_inference::UnknownShape)); } - Status status = OkStatus(); + Status status = absl::OkStatus(); auto it = fed_ports_.find(node.name()); const bool is_fed = it != fed_ports_.end(); if (is_fed) { @@ -2073,7 +2073,7 @@ class SymbolicShapeManager { Status Merge(ShapeHandle s1, ShapeHandle s2) { if (!s1.IsSet() || !s2.IsSet()) { - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR(shapes_.Merge(s1, s2)); if (InferenceContext::Rank(s1) > 0 && InferenceContext::Rank(s2) > 0) { @@ -2083,11 +2083,11 @@ class SymbolicShapeManager { InferenceContext::DimKnownRank(s2, i))); } } - return OkStatus(); + return absl::OkStatus(); } Status Merge(DimensionHandle d1, DimensionHandle d2) { if (!d1.IsSet() || !d2.IsSet()) { - return OkStatus(); + return absl::OkStatus(); } return dims_.Merge(d1, d2); } @@ -2141,7 +2141,7 @@ Status ValidateSymbolicShapeManager(const GraphDef& graph_def, SymbolicShapeRefiner* refiner, SymbolicShapeManager* shape_manager) { if (!VLOG_IS_ON(1)) { - return OkStatus(); + return absl::OkStatus(); } VLOG(1) << "Checking any conflicts in shapes and dimensions ..."; @@ -2182,7 +2182,7 @@ Status ValidateSymbolicShapeManager(const GraphDef& graph_def, VLOG(1) << "**** No incompatible shape found from SymbolicShapeManager."; } - return OkStatus(); + return absl::OkStatus(); } // Log shape inference and its merged shapes. @@ -2194,7 +2194,7 @@ Status VerboseShapeInferenceLogging(const GraphDef& graph_def, // node_names_for_logging to enable detailed logging. absl::flat_hash_set node_names_for_logging = {}; if (!VLOG_IS_ON(3) || node_names_for_logging.empty()) { - return OkStatus(); + return absl::OkStatus(); } auto should_log = [&node_names_for_logging](std::string node_name) { @@ -2231,7 +2231,7 @@ Status VerboseShapeInferenceLogging(const GraphDef& graph_def, VLOG(3) << ""; } - return OkStatus(); + return absl::OkStatus(); } Status GraphProperties::RelaxEnqueueShapesAndMergeTypes( @@ -2254,7 +2254,7 @@ Status GraphProperties::RelaxEnqueueShapesAndMergeTypes( b.shape = shape_refiner->OutputAsUnion(qnode, i, a.shape, b.shape); } - return OkStatus(); + return absl::OkStatus(); } // Compute the output shape of the merge node as the union of the available @@ -2308,7 +2308,7 @@ Status GraphProperties::UpdateMerge(SymbolicShapeRefiner* shape_refiner, *new_shapes = true; } - return OkStatus(); + return absl::OkStatus(); } // Manually propagate the input shape for Enter nodes. @@ -2336,7 +2336,7 @@ Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner, ic->set_output_handle_shapes_and_types(0, *outputs); *new_shapes = true; } - return OkStatus(); + return absl::OkStatus(); } Status GraphProperties::UpdateShapes( @@ -2364,7 +2364,7 @@ Status GraphProperties::UpdateShapes( TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, new_shapes)); } - return OkStatus(); + return absl::OkStatus(); } // Propagates the shapes in the transitive fan-out of . @@ -2417,7 +2417,7 @@ Status GraphProperties::PropagateShapes( return errors::Internal("Shape inference failed to converge"); } - return OkStatus(); + return absl::OkStatus(); } Status GraphProperties::UpdateQueue(const NodeDef* queue_node, @@ -2480,12 +2480,12 @@ Status GraphProperties::UpdateEnqueue( auto it = resource_handles.find(enqueue_node); if (it == resource_handles.end()) { // The corresponding queue was not found, there isn't much we can do. - return OkStatus(); + return absl::OkStatus(); } const NodeDef* qnode = it->second; auto qctx = shape_refiner->GetContext(qnode); if (!qctx) { - return OkStatus(); + return absl::OkStatus(); } auto* queue_handle_data = qctx->output_handle_shapes_and_types(0); @@ -2511,7 +2511,7 @@ Status GraphProperties::UpdateEnqueue( qctx->set_output_handle_shapes_and_types(0, shapes_and_types); } - return OkStatus(); + return absl::OkStatus(); } Status GraphProperties::InferStatically(bool assume_valid_feeds, @@ -2756,7 +2756,7 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds, TF_RETURN_IF_ERROR(VerboseShapeInferenceLogging(item_.graph, refiner.get(), shape_manager.get())); - return OkStatus(); + return absl::OkStatus(); } Status GraphProperties::InferDynamically(Cluster* cluster) { @@ -2783,7 +2783,7 @@ Status GraphProperties::AnnotateOutputShapes(GraphDef* output_graph_def) const { } (*node->mutable_attr())["_output_shapes"] = std::move(attr_output_shape); } - return OkStatus(); + return absl::OkStatus(); } Status GraphProperties::InferFromCostGraph(const CostGraphDef& cost_graph) { @@ -2819,7 +2819,7 @@ Status GraphProperties::InferFromCostGraph(const CostGraphDef& cost_graph) { input_properties_[node.name()] = inputs; } - return OkStatus(); + return absl::OkStatus(); } bool GraphProperties::HasInputProperties(const string& node_name) const { diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index 97296d6a0f3cda..0824a95eefaf19 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -300,7 +300,7 @@ REGISTER_OP("DetectInputValueInShapeInferenceOp") if (c->input_tensor(0)) { // 10x10 if input_tensor is given to the inference context. c->set_output(0, c->Matrix(10, 10)); - return OkStatus(); + return absl::OkStatus(); } // unknown rank if input_tensor is not provided. return shape_inference::UnknownShape(c); diff --git a/tensorflow/core/grappler/costs/measuring_cost_estimator.cc b/tensorflow/core/grappler/costs/measuring_cost_estimator.cc index f3385e67802a2c..e7c330cbe22e6c 100644 --- a/tensorflow/core/grappler/costs/measuring_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/measuring_cost_estimator.cc @@ -144,7 +144,7 @@ Status MeasuringCostEstimator::PredictCosts(const GraphDef& optimized_graph, RobustStats stats(times); costs->execution_time = Costs::Duration(stats.mean()); - return OkStatus(); + return absl::OkStatus(); } } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 3fcb21e5862068..43941e62226648 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -180,7 +180,7 @@ Status HeapReadyManager::Init( // Sets up the comparator for the heap. greater_ = Greater(); - return OkStatus(); + return absl::OkStatus(); } void HeapReadyManager::AddNode(const NodeDef* node) { @@ -269,7 +269,7 @@ void PriorityReadyManager::AddNode(const NodeDef* node) { Status PriorityReadyManager::SetPriority( const std::unordered_map& node_priority) { node_priority_ = node_priority; - return OkStatus(); + return absl::OkStatus(); } CompositeNodeManager::CompositeNodeManager() @@ -281,7 +281,7 @@ Status CompositeNodeManager::Init( TF_RETURN_IF_ERROR(send_manager_.Init(node_map)); TF_RETURN_IF_ERROR(recv_manager_.Init(node_map)); curr_node_ = nullptr; - return OkStatus(); + return absl::OkStatus(); } void CompositeNodeManager::AddNode(const NodeDef* node) { @@ -593,7 +593,7 @@ Status SchedulerState::Init(const GrapplerItem* item, } initialized_ = true; - return OkStatus(); + return absl::OkStatus(); } void SchedulerState::MaybeUpdateInputOutput(const NodeDef* node) { diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h index fe33b8eeda04a7..12aaa1ea7da325 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.h +++ b/tensorflow/core/grappler/costs/virtual_scheduler.h @@ -174,7 +174,7 @@ class ReadyNodeManager { virtual ~ReadyNodeManager() {} virtual Status Init( const std::unordered_map* node_map) { - return OkStatus(); + return absl::OkStatus(); } virtual void AddNode(const NodeDef* node) = 0; virtual const NodeDef* GetCurrNode() = 0; diff --git a/tensorflow/core/grappler/graph_analyzer/gen_node.cc b/tensorflow/core/grappler/graph_analyzer/gen_node.cc index f7a9f11bf3dcfc..175466f31e7f4a 100644 --- a/tensorflow/core/grappler/graph_analyzer/gen_node.cc +++ b/tensorflow/core/grappler/graph_analyzer/gen_node.cc @@ -44,7 +44,7 @@ Status GenNode::BuildGraphInMap(const GraphDef& source, GenNodeMap* map) { return st; } } - return OkStatus(); + return absl::OkStatus(); } Status GenNode::ParseInputs(const GenNodeMap* map) { @@ -119,7 +119,7 @@ Status GenNode::ParseInputs(const GenNodeMap* map) { links_[this_port].emplace_back(LinkTarget(other_node, other_port)); other_node->links_[other_port].emplace_back(LinkTarget(this, this_port)); } - return OkStatus(); + return absl::OkStatus(); } bool GenNode::IsMultiInput(Port port) const { diff --git a/tensorflow/core/grappler/graph_analyzer/gen_node_test.cc b/tensorflow/core/grappler/graph_analyzer/gen_node_test.cc index 79d32dc348f4d5..e41acc5337fa5f 100644 --- a/tensorflow/core/grappler/graph_analyzer/gen_node_test.cc +++ b/tensorflow/core/grappler/graph_analyzer/gen_node_test.cc @@ -79,7 +79,7 @@ TEST(GenNodeTest, ParseNodeNoInputs) { map["node1"] = std::make_unique(&node1); auto gn1 = map["node1"].get(); - ASSERT_THAT(gn1->ParseInputs(&map), Eq(OkStatus())); + ASSERT_THAT(gn1->ParseInputs(&map), Eq(absl::OkStatus())); EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre()); } @@ -101,7 +101,7 @@ TEST(GenNodeTest, ParseNodeWithControl) { auto gn1 = map["node1"].get(); auto gn2 = map["node2"].get(); auto gn3 = map["node3"].get(); - ASSERT_THAT(gn3->ParseInputs(&map), Eq(OkStatus())); + ASSERT_THAT(gn3->ParseInputs(&map), Eq(absl::OkStatus())); // clang-format off EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre( "o0: node3[i0]", @@ -147,7 +147,7 @@ TEST(GenNodeTest, ParseNodeCommutative) { auto gn1 = map["node1"].get(); auto gn2 = map["node2"].get(); auto gn3 = map["node3"].get(); - ASSERT_THAT(gn3->ParseInputs(&map), Eq(OkStatus())); + ASSERT_THAT(gn3->ParseInputs(&map), Eq(absl::OkStatus())); // clang-format off EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre( "o0: node3[i0]" @@ -180,7 +180,7 @@ TEST(GenNodeTest, ParseNodeMultiInputCommutative) { auto gn1 = map["node1"].get(); auto gn2 = map["node2"].get(); auto gn3 = map["node3"].get(); - ASSERT_THAT(gn3->ParseInputs(&map), Eq(OkStatus())); + ASSERT_THAT(gn3->ParseInputs(&map), Eq(absl::OkStatus())); // clang-format off EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre( "o0: node3[i0]" @@ -216,7 +216,7 @@ TEST(GenNodeTest, ParseNodeMultiInputNotCommutative) { auto gn1 = map["node1"].get(); auto gn2 = map["node2"].get(); auto gn3 = map["node3"].get(); - ASSERT_THAT(gn3->ParseInputs(&map), Eq(OkStatus())); + ASSERT_THAT(gn3->ParseInputs(&map), Eq(absl::OkStatus())); // clang-format off EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre( "o0: node3[i0]" @@ -250,7 +250,7 @@ TEST(GenNodeTest, ParseNodeMultiInputList) { auto gn1 = map["node1"].get(); auto gn2 = map["node2"].get(); auto gn3 = map["node3"].get(); - ASSERT_THAT(gn3->ParseInputs(&map), Eq(OkStatus())); + ASSERT_THAT(gn3->ParseInputs(&map), Eq(absl::OkStatus())); // clang-format off EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre( "o0: node3[i0]" @@ -293,7 +293,7 @@ TEST(GenNodeTest, ParseNodeMultiMultiInput) { auto gn3 = map["node3"].get(); auto gn4 = map["node4"].get(); auto gn5 = map["node5"].get(); - ASSERT_THAT(gn5->ParseInputs(&map), Eq(OkStatus())); + ASSERT_THAT(gn5->ParseInputs(&map), Eq(absl::OkStatus())); // clang-format off EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre( "o0: node5[i0]" @@ -337,7 +337,7 @@ TEST(GenNodeTest, ParseNodeMultiOutput) { map["node4"] = std::make_unique(&node4); auto gn4 = map["node4"].get(); - ASSERT_THAT(gn4->ParseInputs(&map), Eq(OkStatus())); + ASSERT_THAT(gn4->ParseInputs(&map), Eq(absl::OkStatus())); // clang-format off EXPECT_THAT(DumpLinkMap(gn4->links()), ElementsAre( "i0: node3[o1]", @@ -403,7 +403,7 @@ TEST(GenNodeTest, ParseNodeControlInputsAlwaysOk) { map["node1"] = std::make_unique(&node1); node1.add_input("^node1"); auto gn1 = map["node1"].get(); - ASSERT_THAT(gn1->ParseInputs(&map), Eq(OkStatus())); + ASSERT_THAT(gn1->ParseInputs(&map), Eq(absl::OkStatus())); // clang-format off EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre( "iC: node1[oC]", @@ -434,7 +434,7 @@ TEST(GenNodeTest, BuildGraphInMap) { MakeNodeBroadcastGradientArgs("node3", "node1", "node2"); GenNodeMap map; - ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(OkStatus())); + ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(absl::OkStatus())); ASSERT_THAT(map.find("node1"), Ne(map.end())); ASSERT_THAT(map.find("node2"), Ne(map.end())); ASSERT_THAT(map.find("node3"), Ne(map.end())); diff --git a/tensorflow/core/grappler/graph_analyzer/graph_analyzer.cc b/tensorflow/core/grappler/graph_analyzer/graph_analyzer.cc index cbea47e263c1b5..70b4637d7b9d21 100644 --- a/tensorflow/core/grappler/graph_analyzer/graph_analyzer.cc +++ b/tensorflow/core/grappler/graph_analyzer/graph_analyzer.cc @@ -53,7 +53,7 @@ Status GraphAnalyzer::Run() { return st; } - return OkStatus(); + return absl::OkStatus(); } Status GraphAnalyzer::BuildMap() { @@ -305,7 +305,7 @@ Status GraphAnalyzer::CollateResult() { result_.clear(); // Not needed after collation. - return OkStatus(); + return absl::OkStatus(); } std::vector GraphAnalyzer::DumpRawSubgraphs() { @@ -335,7 +335,7 @@ Status GraphAnalyzer::OutputSubgraphs() { if (std::cout.fail()) { return Status(absl::StatusCode::kDataLoss, "Failed to write to stdout"); } else { - return OkStatus(); + return absl::OkStatus(); } } diff --git a/tensorflow/core/grappler/graph_analyzer/graph_analyzer_test.cc b/tensorflow/core/grappler/graph_analyzer/graph_analyzer_test.cc index 011959e2665d93..65e57ef5443deb 100644 --- a/tensorflow/core/grappler/graph_analyzer/graph_analyzer_test.cc +++ b/tensorflow/core/grappler/graph_analyzer/graph_analyzer_test.cc @@ -79,7 +79,7 @@ class GraphAnalyzerTest : public ::testing::Test, protected TestGraphs { TEST_F(GraphAnalyzerTest, BuildMap) { gran_ = std::make_unique(graph_3n_self_control_, 1); Status st = BuildMap(); - EXPECT_THAT(st, Eq(OkStatus())); + EXPECT_THAT(st, Eq(absl::OkStatus())); auto& map = GetNodes(); EXPECT_THAT(map.find("node1"), Ne(map.end())); @@ -99,7 +99,7 @@ TEST_F(GraphAnalyzerTest, BuildMapError) { TEST_F(GraphAnalyzerTest, FindSubgraphs0) { gran_ = std::make_unique(graph_3n_self_control_, 0); Status st = BuildMap(); - ASSERT_THAT(st, Eq(OkStatus())); + ASSERT_THAT(st, Eq(absl::OkStatus())); FindSubgraphs(); auto& subgraphs = GetResult(); @@ -112,7 +112,7 @@ TEST_F(GraphAnalyzerTest, FindSubgraphs0) { TEST_F(GraphAnalyzerTest, FindSubgraphs1) { gran_ = std::make_unique(graph_3n_self_control_, 1); Status st = BuildMap(); - ASSERT_THAT(st, Eq(OkStatus())); + ASSERT_THAT(st, Eq(absl::OkStatus())); FindSubgraphs(); auto& subgraphs = GetResult(); @@ -133,7 +133,7 @@ TEST_F(GraphAnalyzerTest, FindSubgraphs1) { TEST_F(GraphAnalyzerTest, FindSubgraphsTooLarge) { gran_ = std::make_unique(graph_3n_self_control_, 4); Status st = BuildMap(); - ASSERT_THAT(st, Eq(OkStatus())); + ASSERT_THAT(st, Eq(absl::OkStatus())); FindSubgraphs(); EXPECT_THAT(DumpRawSubgraphs(), ElementsAre()); @@ -148,7 +148,7 @@ TEST_F(GraphAnalyzerTest, FindSubgraphsTooLarge) { TEST_F(GraphAnalyzerTest, MultiInputSuccessBackwardsBaseIn) { gran_ = std::make_unique(graph_multi_input_, 4); Status st = BuildMap(); - ASSERT_THAT(st, Eq(OkStatus())); + ASSERT_THAT(st, Eq(absl::OkStatus())); auto root = std::make_unique(Subgraph::Identity({GetNode("add2")})); @@ -170,7 +170,7 @@ TEST_F(GraphAnalyzerTest, MultiInputSuccessBackwardsBaseIn) { TEST_F(GraphAnalyzerTest, MultiInputSuccessBackwardsBaseOut) { gran_ = std::make_unique(graph_multi_input_, 4); Status st = BuildMap(); - ASSERT_THAT(st, Eq(OkStatus())); + ASSERT_THAT(st, Eq(absl::OkStatus())); auto parent = std::make_unique(Subgraph::Identity()); auto root = @@ -193,7 +193,7 @@ TEST_F(GraphAnalyzerTest, MultiInputSuccessBackwardsBaseOut) { TEST_F(GraphAnalyzerTest, MultiInputSuccessBackwardsIncomplete) { gran_ = std::make_unique(graph_multi_input_, 5); Status st = BuildMap(); - ASSERT_THAT(st, Eq(OkStatus())); + ASSERT_THAT(st, Eq(absl::OkStatus())); auto root = std::make_unique(Subgraph::Identity({GetNode("add2")})); @@ -215,7 +215,7 @@ TEST_F(GraphAnalyzerTest, MultiInputSuccessBackwardsIncomplete) { TEST_F(GraphAnalyzerTest, MultiInputTooLargeBackwards) { gran_ = std::make_unique(graph_multi_input_, 3); Status st = BuildMap(); - ASSERT_THAT(st, Eq(OkStatus())); + ASSERT_THAT(st, Eq(absl::OkStatus())); auto root = std::make_unique(Subgraph::Identity({GetNode("add2")})); @@ -233,7 +233,7 @@ TEST_F(GraphAnalyzerTest, MultiInputTooLargeBackwards) { TEST_F(GraphAnalyzerTest, MultiInputNothingAddedBackwards) { gran_ = std::make_unique(graph_multi_input_, 4); Status st = BuildMap(); - ASSERT_THAT(st, Eq(OkStatus())); + ASSERT_THAT(st, Eq(absl::OkStatus())); auto root = std::make_unique( Subgraph::Identity({GetNode("add2"), GetNode("const2_1"), @@ -252,7 +252,7 @@ TEST_F(GraphAnalyzerTest, MultiInputNothingAddedBackwards) { TEST_F(GraphAnalyzerTest, MultiInputSuccessForwardsBaseOut) { gran_ = std::make_unique(graph_multi_input_, 4); Status st = BuildMap(); - ASSERT_THAT(st, Eq(OkStatus())); + ASSERT_THAT(st, Eq(absl::OkStatus())); auto root = std::make_unique(Subgraph::Identity({GetNode("const2_1")})); @@ -273,7 +273,7 @@ TEST_F(GraphAnalyzerTest, MultiInputSuccessForwardsBaseOut) { TEST_F(GraphAnalyzerTest, MultiInputSuccessBackwardsFull) { gran_ = std::make_unique(graph_multi_input_, 4); Status st = BuildMap(); - ASSERT_THAT(st, Eq(OkStatus())); + ASSERT_THAT(st, Eq(absl::OkStatus())); auto root = std::make_unique(Subgraph::Identity({GetNode("add2")})); @@ -295,7 +295,7 @@ TEST_F(GraphAnalyzerTest, MultiInputSuccessBackwardsFull) { TEST_F(GraphAnalyzerTest, MultiInputSuccessForwardsFull) { gran_ = std::make_unique(graph_multi_input_, 4); Status st = BuildMap(); - ASSERT_THAT(st, Eq(OkStatus())); + ASSERT_THAT(st, Eq(absl::OkStatus())); auto root = std::make_unique(Subgraph::Identity({GetNode("const2_1")})); @@ -314,7 +314,7 @@ TEST_F(GraphAnalyzerTest, MultiInputSuccessForwardsFull) { TEST_F(GraphAnalyzerTest, DropInvalidSubgraphsMulti) { gran_ = std::make_unique(graph_multi_input_, 3); Status st = BuildMap(); - ASSERT_THAT(st, Eq(OkStatus())); + ASSERT_THAT(st, Eq(absl::OkStatus())); // A good one, multi-input is all-in. GetResult().insert(std::make_unique(Subgraph::Identity({ @@ -360,7 +360,7 @@ TEST_F(GraphAnalyzerTest, DropInvalidSubgraphsMulti) { TEST_F(GraphAnalyzerTest, AllOrNoneInputSuccessBackwards) { gran_ = std::make_unique(graph_all_or_none_, 4); Status st = BuildMap(); - ASSERT_THAT(st, Eq(OkStatus())); + ASSERT_THAT(st, Eq(absl::OkStatus())); auto root = std::make_unique(Subgraph::Identity({GetNode("pass2")})); @@ -382,7 +382,7 @@ TEST_F(GraphAnalyzerTest, AllOrNoneInputSuccessBackwards) { TEST_F(GraphAnalyzerTest, AllOrNoneInputSuccessBackwardsNoControl) { gran_ = std::make_unique(graph_all_or_none_, 5); Status st = BuildMap(); - ASSERT_THAT(st, Eq(OkStatus())); + ASSERT_THAT(st, Eq(absl::OkStatus())); auto root = std::make_unique(Subgraph::Identity({GetNode("pass1")})); @@ -403,7 +403,7 @@ TEST_F(GraphAnalyzerTest, AllOrNoneInputSuccessBackwardsNoControl) { TEST_F(GraphAnalyzerTest, AllOrNoneInputSeparateControl) { gran_ = std::make_unique(graph_all_or_none_, 5); Status st = BuildMap(); - ASSERT_THAT(st, Eq(OkStatus())); + ASSERT_THAT(st, Eq(absl::OkStatus())); auto root = std::make_unique(Subgraph::Identity({GetNode("pass1")})); @@ -425,7 +425,7 @@ TEST_F(GraphAnalyzerTest, AllOrNoneInputSeparateControl) { TEST_F(GraphAnalyzerTest, AllOrNoneInputTooLargeBackwards) { gran_ = std::make_unique(graph_all_or_none_, 3); Status st = BuildMap(); - ASSERT_THAT(st, Eq(OkStatus())); + ASSERT_THAT(st, Eq(absl::OkStatus())); auto root = std::make_unique(Subgraph::Identity({GetNode("pass2")})); @@ -442,7 +442,7 @@ TEST_F(GraphAnalyzerTest, AllOrNoneInputTooLargeBackwards) { TEST_F(GraphAnalyzerTest, AllOrNoneInputNothingAddedBackwards) { gran_ = std::make_unique(graph_all_or_none_, 4); Status st = BuildMap(); - ASSERT_THAT(st, Eq(OkStatus())); + ASSERT_THAT(st, Eq(absl::OkStatus())); auto root = std::make_unique( Subgraph::Identity({GetNode("pass2"), GetNode("const2_1"), @@ -460,7 +460,7 @@ TEST_F(GraphAnalyzerTest, AllOrNoneInputNothingAddedBackwards) { TEST_F(GraphAnalyzerTest, AllOrNoneInputSuccessForwardsBaseOut) { gran_ = std::make_unique(graph_all_or_none_, 4); Status st = BuildMap(); - ASSERT_THAT(st, Eq(OkStatus())); + ASSERT_THAT(st, Eq(absl::OkStatus())); auto root = std::make_unique(Subgraph::Identity({GetNode("const2_1")})); @@ -480,7 +480,7 @@ TEST_F(GraphAnalyzerTest, AllOrNoneInputSuccessForwardsBaseOut) { TEST_F(GraphAnalyzerTest, AllOrNoneInputSuccessBackwardsFull) { gran_ = std::make_unique(graph_all_or_none_, 4); Status st = BuildMap(); - ASSERT_THAT(st, Eq(OkStatus())); + ASSERT_THAT(st, Eq(absl::OkStatus())); auto root = std::make_unique(Subgraph::Identity({GetNode("pass2")})); @@ -504,7 +504,7 @@ TEST_F(GraphAnalyzerTest, AllOrNoneInputSuccessBackwardsFull) { TEST_F(GraphAnalyzerTest, AllOrNoneInputSuccessForwardsFull) { gran_ = std::make_unique(graph_all_or_none_, 4); Status st = BuildMap(); - ASSERT_THAT(st, Eq(OkStatus())); + ASSERT_THAT(st, Eq(absl::OkStatus())); auto root = std::make_unique(Subgraph::Identity({GetNode("const2_1")})); @@ -524,7 +524,7 @@ TEST_F(GraphAnalyzerTest, AllOrNoneInputSuccessForwardsFull) { TEST_F(GraphAnalyzerTest, DropInvalidSubgraphsAllOrNone) { gran_ = std::make_unique(graph_all_or_none_, 3); Status st = BuildMap(); - ASSERT_THAT(st, Eq(OkStatus())); + ASSERT_THAT(st, Eq(absl::OkStatus())); // A good one, all-or-none is all-in. GetResult().insert(std::make_unique(Subgraph::Identity({ diff --git a/tensorflow/core/grappler/graph_analyzer/sig_node.cc b/tensorflow/core/grappler/graph_analyzer/sig_node.cc index 8e53e13b88833a..fc680b04265558 100644 --- a/tensorflow/core/grappler/graph_analyzer/sig_node.cc +++ b/tensorflow/core/grappler/graph_analyzer/sig_node.cc @@ -248,7 +248,7 @@ Status Signature::Compute() { OrderLinks(); - return OkStatus(); + return absl::OkStatus(); } void Signature::PrepareNodes() { diff --git a/tensorflow/core/grappler/graph_analyzer/sig_node_test.cc b/tensorflow/core/grappler/graph_analyzer/sig_node_test.cc index 82081e898d9bb4..70f8e9a0fe3761 100644 --- a/tensorflow/core/grappler/graph_analyzer/sig_node_test.cc +++ b/tensorflow/core/grappler/graph_analyzer/sig_node_test.cc @@ -479,7 +479,8 @@ TEST_F(SigNodeTest, EqualsLinkSize) { (*graph1.add_node()) = MakeNodeMul("node2", "node1", "node1"); GenNodeMap gen_map1; - ASSERT_THAT(GenNode::BuildGraphInMap(graph1, &gen_map1), Eq(OkStatus())); + ASSERT_THAT(GenNode::BuildGraphInMap(graph1, &gen_map1), + Eq(absl::OkStatus())); Subgraph::Identity id1; id1.insert(gen_map1["node1"].get()); @@ -497,7 +498,8 @@ TEST_F(SigNodeTest, EqualsLinkSize) { node22->add_input("node2"); GenNodeMap gen_map2; - ASSERT_THAT(GenNode::BuildGraphInMap(graph2, &gen_map2), Eq(OkStatus())); + ASSERT_THAT(GenNode::BuildGraphInMap(graph2, &gen_map2), + Eq(absl::OkStatus())); Subgraph::Identity id2; id2.insert(gen_map2["node1"].get()); @@ -519,7 +521,8 @@ TEST_F(SigNodeTest, EqualsLinks) { (*graph1.add_node()) = MakeNodeMul("node2", "node1", "node1"); GenNodeMap gen_map1; - ASSERT_THAT(GenNode::BuildGraphInMap(graph1, &gen_map1), Eq(OkStatus())); + ASSERT_THAT(GenNode::BuildGraphInMap(graph1, &gen_map1), + Eq(absl::OkStatus())); Subgraph::Identity id1; id1.insert(gen_map1["node1"].get()); @@ -530,7 +533,8 @@ TEST_F(SigNodeTest, EqualsLinks) { sg1.ExtractForSignature(&sig_map1); GenNodeMap gen_map2; - ASSERT_THAT(GenNode::BuildGraphInMap(graph1, &gen_map2), Eq(OkStatus())); + ASSERT_THAT(GenNode::BuildGraphInMap(graph1, &gen_map2), + Eq(absl::OkStatus())); Subgraph::Identity id2; id2.insert(gen_map2["node1"].get()); @@ -610,7 +614,7 @@ class SignatureTest : public SigBaseTest { gen_map_.clear(); sig_.map.clear(); Status result = GenNode::BuildGraphInMap(graph, &gen_map_); - ASSERT_THAT(result, Eq(OkStatus())); + ASSERT_THAT(result, Eq(absl::OkStatus())); Subgraph::Identity id; for (const auto& entry : gen_map_) { id.insert(entry.second.get()); @@ -668,7 +672,7 @@ class SignatureTest : public SigBaseTest { OrderLinks(&sig_); // The test as such. - ASSERT_THAT(sig_.Compute(), Eq(OkStatus())); + ASSERT_THAT(sig_.Compute(), Eq(absl::OkStatus())); signatures.insert(sig_.ToString()); @@ -1035,7 +1039,7 @@ TEST_F(SignatureTest, OrderLinks) { gen_map_.clear(); sig_.map.clear(); Status result = GenNode::BuildGraphInMap(graph_for_link_order_, &gen_map_); - ASSERT_THAT(result, Eq(OkStatus())); + ASSERT_THAT(result, Eq(absl::OkStatus())); Subgraph::Identity id; for (const auto& entry : gen_map_) { id.insert(entry.second.get()); @@ -1084,7 +1088,7 @@ TEST_F(SignatureTest, GraphTooBig) { (*graph.add_node()) = MakeNodeConst(absl::StrFormat("node%d", i)); } - ASSERT_THAT(GenNode::BuildGraphInMap(graph, &gen_map_), Eq(OkStatus())); + ASSERT_THAT(GenNode::BuildGraphInMap(graph, &gen_map_), Eq(absl::OkStatus())); Subgraph::Identity id; for (const auto& entry : gen_map_) { @@ -1173,7 +1177,7 @@ TEST_F(SignatureTest, Equals) { // Start with 2 copies of the same graph. GenNodeMap gen_map1; ASSERT_THAT(GenNode::BuildGraphInMap(graph_circular_bidir_, &gen_map1), - Eq(OkStatus())); + Eq(absl::OkStatus())); Subgraph::Identity id1; id1.insert(gen_map1["node1"].get()); @@ -1182,11 +1186,11 @@ TEST_F(SignatureTest, Equals) { Signature sig1; sg1.ExtractForSignature(&sig1.map); - ASSERT_THAT(sig1.Compute(), Eq(OkStatus())); + ASSERT_THAT(sig1.Compute(), Eq(absl::OkStatus())); GenNodeMap gen_map2; ASSERT_THAT(GenNode::BuildGraphInMap(graph_circular_bidir_, &gen_map2), - Eq(OkStatus())); + Eq(absl::OkStatus())); Subgraph::Identity id2; id2.insert(gen_map2["node1"].get()); @@ -1195,7 +1199,7 @@ TEST_F(SignatureTest, Equals) { Signature sig2; sg2.ExtractForSignature(&sig2.map); - ASSERT_THAT(sig2.Compute(), Eq(OkStatus())); + ASSERT_THAT(sig2.Compute(), Eq(absl::OkStatus())); EXPECT_TRUE(sig1 == sig2); diff --git a/tensorflow/core/grappler/graph_analyzer/subgraph_test.cc b/tensorflow/core/grappler/graph_analyzer/subgraph_test.cc index c6ec39238727d7..bbe6856417adbe 100644 --- a/tensorflow/core/grappler/graph_analyzer/subgraph_test.cc +++ b/tensorflow/core/grappler/graph_analyzer/subgraph_test.cc @@ -41,7 +41,7 @@ TEST(SubgraphTest, Comparison) { (*graph.add_node()) = MakeNodeConst("node1"); (*graph.add_node()) = MakeNodeConst("node2"); GenNodeMap map; - ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(OkStatus())); + ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(absl::OkStatus())); auto gn1 = map["node1"].get(); auto gn2 = map["node2"].get(); ASSERT_THAT(gn1, Ne(nullptr)); @@ -88,7 +88,7 @@ TEST(SubgraphTest, Iteration) { node3->add_input("^node3"); // The control link goes back to self. GenNodeMap map; - ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(OkStatus())); + ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(absl::OkStatus())); ASSERT_THAT(map.find("node3"), Ne(map.end())); Subgraph::Identity id; @@ -151,7 +151,7 @@ TEST(SubgraphTest, IterationSamePort) { (*graph.add_node()) = MakeNodeAddN("node3", "node1", "node2"); GenNodeMap map; - ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(OkStatus())); + ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(absl::OkStatus())); ASSERT_THAT(map.find("node3"), Ne(map.end())); Subgraph::Identity id; @@ -201,7 +201,7 @@ TEST(SubgraphTest, IterationSameNode) { (*graph.add_node()) = MakeNodeAddN("node3", "node1", "node2"); GenNodeMap map; - ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(OkStatus())); + ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(absl::OkStatus())); ASSERT_THAT(map.find("node3"), Ne(map.end())); Subgraph::Identity id; @@ -252,7 +252,7 @@ TEST(SubgraphTest, ExtendSet) { node3->add_input("^node3"); // The control link goes back to self. GenNodeMap map; - ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(OkStatus())); + ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(absl::OkStatus())); ASSERT_THAT(map.find("node2"), Ne(map.end())); ASSERT_THAT(map.find("node3"), Ne(map.end())); @@ -301,7 +301,7 @@ TEST(SubgraphTest, ExtractForSignature) { node3->add_input("^node3"); // The control link goes back to self. GenNodeMap map; - ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(OkStatus())); + ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(absl::OkStatus())); ASSERT_THAT(map.find("node1"), Ne(map.end())); ASSERT_THAT(map.find("node2"), Ne(map.end())); ASSERT_THAT(map.find("node3"), Ne(map.end())); diff --git a/tensorflow/core/grappler/graph_topology_view.cc b/tensorflow/core/grappler/graph_topology_view.cc index 265b1f4e58fab6..7dfb15c2c07f6a 100644 --- a/tensorflow/core/grappler/graph_topology_view.cc +++ b/tensorflow/core/grappler/graph_topology_view.cc @@ -137,7 +137,7 @@ Status GraphTopologyView::InitializeFromGraph( SortAndRemoveDuplicates(&fanouts_[node_idx]); } - return OkStatus(); + return absl::OkStatus(); } Status GraphTopologyView::InitializeFromGraph( diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h index a7775aaeb31c7d..03bffa03bc112a 100644 --- a/tensorflow/core/grappler/graph_view.h +++ b/tensorflow/core/grappler/graph_view.h @@ -323,7 +323,7 @@ class GraphViewInternal { Status AddUniqueNode(NodeDefT* node) { auto inserted = nodes_.emplace(node->name(), node); return inserted.second - ? OkStatus() + ? absl::OkStatus() : absl::InvalidArgumentError(absl::StrCat( "Non unique node name detected: ", node->name())); } diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc index 18e711c5aeb547..17e01a67bf3793 100644 --- a/tensorflow/core/grappler/grappler_item.cc +++ b/tensorflow/core/grappler/grappler_item.cc @@ -184,7 +184,7 @@ Status GrapplerItem::AddDevice(const string& device) { } devices_.insert(DeviceNameUtils::ParsedNameToString(name)); - return OkStatus(); + return absl::OkStatus(); } Status GrapplerItem::AddDevices(const GrapplerItem& other) { @@ -194,7 +194,7 @@ Status GrapplerItem::AddDevices(const GrapplerItem& other) { if (!added.ok()) invalid_devices.emplace_back(device); } return invalid_devices.empty() - ? OkStatus() + ? absl::OkStatus() : errors::InvalidArgument("Skipped invalid devices: [", absl::StrJoin(invalid_devices, ", "), "]"); @@ -208,7 +208,7 @@ Status GrapplerItem::InferDevicesFromGraph() { } VLOG(2) << "Inferred device set: [" << absl::StrJoin(devices_, ", ") << "]"; return invalid_devices.empty() - ? OkStatus() + ? absl::OkStatus() : errors::InvalidArgument("Skipped invalid devices: [", absl::StrJoin(invalid_devices, ", "), "]"); diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index ce4f96ac204e08..35198461cc8033 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -86,7 +86,7 @@ Status PruneGraph(GrapplerItem* item) { Cluster* cluster = nullptr; // ModelPruner doesn't check cluster. TF_RETURN_IF_ERROR(pruner.Optimize(cluster, *item, &pruned_graph)); item->graph = std::move(pruned_graph); - return OkStatus(); + return absl::OkStatus(); } // Replace any unknown dimensions in a shape with @@ -203,7 +203,7 @@ Status UpdatePlaceholderShape( if (!shape_proto.dim().empty()) *(node->mutable_attr()->at("shape").mutable_shape()) = shape_proto; - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -223,7 +223,7 @@ Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg, if (output_graph_def != &graph_def_arg) { *output_graph_def = graph_def_arg; } - return OkStatus(); + return absl::OkStatus(); } // Create a session option for a single GPU device. diff --git a/tensorflow/core/grappler/mutable_graph_view.cc b/tensorflow/core/grappler/mutable_graph_view.cc index 743a925cb74c39..638a6a33f9395f 100644 --- a/tensorflow/core/grappler/mutable_graph_view.cc +++ b/tensorflow/core/grappler/mutable_graph_view.cc @@ -272,7 +272,7 @@ Status CheckFaninIsRegular(const TensorId& fanin, ErrorHandler handler) { return handler(absl::Substitute("fanin '$0' must be a regular tensor id", fanin.ToString())); } - return OkStatus(); + return absl::OkStatus(); } Status CheckFaninIsValid(const TensorId& fanin, ErrorHandler handler) { @@ -280,7 +280,7 @@ Status CheckFaninIsValid(const TensorId& fanin, ErrorHandler handler) { return handler(absl::Substitute("fanin '$0' must be a valid tensor id", fanin.ToString())); } - return OkStatus(); + return absl::OkStatus(); } Status CheckAddingFaninToSelf(absl::string_view node_name, @@ -289,7 +289,7 @@ Status CheckAddingFaninToSelf(absl::string_view node_name, return handler( absl::Substitute("can't add fanin '$0' to self", fanin.ToString())); } - return OkStatus(); + return absl::OkStatus(); } Status CheckRemovingFaninFromSelf(absl::string_view node_name, @@ -298,7 +298,7 @@ Status CheckRemovingFaninFromSelf(absl::string_view node_name, return handler(absl::Substitute("can't remove fanin '$0' from self", fanin.ToString())); } - return OkStatus(); + return absl::OkStatus(); } string NodeMissingErrorMsg(absl::string_view node_name) { @@ -310,7 +310,7 @@ Status CheckNodeExists(absl::string_view node_name, NodeDef* node, if (node == nullptr) { return handler(NodeMissingErrorMsg(node_name)); } - return OkStatus(); + return absl::OkStatus(); } Status CheckPortRange(int port, int min, int max, ErrorHandler handler) { @@ -321,7 +321,7 @@ Status CheckPortRange(int port, int min, int max, ErrorHandler handler) { return handler( absl::Substitute("port must be in range [$0, $1]", min, max)); } - return OkStatus(); + return absl::OkStatus(); } string SwapNodeNamesSwitchControlErrorMsg(absl::string_view node_name) { @@ -508,7 +508,7 @@ Status MutableGraphView::AddSubgraph(GraphDef&& subgraph) { AddAndDedupFanouts(node); } - return OkStatus(); + return absl::OkStatus(); } Status MutableGraphView::UpdateNode( @@ -549,7 +549,7 @@ Status MutableGraphView::UpdateNode( } if (node->op() == op) { - return OkStatus(); + return absl::OkStatus(); } node->set_op(string(op)); @@ -562,7 +562,7 @@ Status MutableGraphView::UpdateNode( } } - return OkStatus(); + return absl::OkStatus(); } Status MutableGraphView::UpdateNodeName(absl::string_view from_node_name, @@ -580,7 +580,7 @@ Status MutableGraphView::UpdateNodeName(absl::string_view from_node_name, TF_RETURN_IF_ERROR(CheckNodeExists(from_node_name, node, error_status)); if (node->name() == to_node_name) { - return OkStatus(); + return absl::OkStatus(); } if (HasNode(to_node_name)) { return error_status( @@ -605,7 +605,7 @@ Status MutableGraphView::UpdateNodeName(absl::string_view from_node_name, nodes().erase(node->name()); node->set_name(string(to_node_name)); nodes().emplace(node->name(), node); - return OkStatus(); + return absl::OkStatus(); } Status MutableGraphView::SwapNodeNames(absl::string_view from_node_name, @@ -622,7 +622,7 @@ Status MutableGraphView::SwapNodeNames(absl::string_view from_node_name, NodeDef* from_node = GetNode(from_node_name); TF_RETURN_IF_ERROR(CheckNodeExists(from_node_name, from_node, error_status)); if (from_node_name == to_node_name) { - return OkStatus(); + return absl::OkStatus(); } NodeDef* to_node = GetNode(to_node_name); TF_RETURN_IF_ERROR(CheckNodeExists(to_node_name, to_node, error_status)); @@ -639,7 +639,7 @@ Status MutableGraphView::SwapNodeNames(absl::string_view from_node_name, SwapFanoutInputs(*this, &fanouts(), &max_regular_output_port(), from_node, to_node); swap_names(); - return OkStatus(); + return absl::OkStatus(); } bool from_is_switch = IsSwitch(*from_node); @@ -750,7 +750,7 @@ Status MutableGraphView::SwapNodeNames(absl::string_view from_node_name, } } - return OkStatus(); + return absl::OkStatus(); } Status MutableGraphView::UpdateFanouts(absl::string_view from_node_name, @@ -771,7 +771,7 @@ Status MutableGraphView::UpdateFanoutsInternal(NodeDef* from_node, VLOG(2) << absl::Substitute("Update fanouts from '$0' to '$1'.", from_node->name(), to_node->name()); if (from_node == to_node) { - return OkStatus(); + return absl::OkStatus(); } // Update internal state with the new output_port->input_port edge. @@ -860,7 +860,7 @@ Status MutableGraphView::UpdateFanoutsInternal(NodeDef* from_node, max_regular_output_port().erase(from_node); } - return OkStatus(); + return absl::OkStatus(); } bool MutableGraphView::AddFaninInternal(NodeDef* node, @@ -927,7 +927,7 @@ Status MutableGraphView::AddRegularFanin(absl::string_view node_name, TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status)); AddFaninInternal(node, {fanin_node, fanin.index()}); - return OkStatus(); + return absl::OkStatus(); } Status MutableGraphView::AddRegularFaninByPort(absl::string_view node_name, @@ -971,7 +971,7 @@ Status MutableGraphView::AddRegularFaninByPort(absl::string_view node_name, RemoveControllingFaninInternal(node, fanin_node); } - return OkStatus(); + return absl::OkStatus(); } NodeDef* MutableGraphView::GetControllingFaninToAdd(absl::string_view node_name, @@ -1061,7 +1061,7 @@ Status MutableGraphView::AddControllingFanin(absl::string_view node_name, } AddFaninInternal(node, {control_node, Graph::kControlSlot}); - return OkStatus(); + return absl::OkStatus(); } bool MutableGraphView::RemoveRegularFaninInternal(NodeDef* node, @@ -1137,7 +1137,7 @@ Status MutableGraphView::RemoveRegularFanin(absl::string_view node_name, TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status)); RemoveRegularFaninInternal(node, {fanin_node, fanin.index()}); - return OkStatus(); + return absl::OkStatus(); } Status MutableGraphView::RemoveRegularFaninByPort(absl::string_view node_name, @@ -1180,7 +1180,7 @@ Status MutableGraphView::RemoveRegularFaninByPort(absl::string_view node_name, max_regular_input_port()[node] = updated_last_regular_input_port; } - return OkStatus(); + return absl::OkStatus(); } bool MutableGraphView::RemoveControllingFaninInternal(NodeDef* node, @@ -1218,7 +1218,7 @@ Status MutableGraphView::RemoveControllingFanin( CheckNodeExists(fanin_node_name, fanin_node, error_status)); RemoveControllingFaninInternal(node, fanin_node); - return OkStatus(); + return absl::OkStatus(); } Status MutableGraphView::RemoveAllFanins(absl::string_view node_name, @@ -1233,7 +1233,7 @@ Status MutableGraphView::RemoveAllFanins(absl::string_view node_name, } if (node->input().empty()) { - return OkStatus(); + return absl::OkStatus(); } const int num_regular_fanins = @@ -1241,7 +1241,7 @@ Status MutableGraphView::RemoveAllFanins(absl::string_view node_name, RemoveFaninsInternal(node, keep_controlling_fanins); if (keep_controlling_fanins) { if (num_regular_fanins == 0) { - return OkStatus(); + return absl::OkStatus(); } else if (num_regular_fanins < node->input_size()) { node->mutable_input()->DeleteSubrange(0, num_regular_fanins); } else { @@ -1250,7 +1250,7 @@ Status MutableGraphView::RemoveAllFanins(absl::string_view node_name, } else { node->clear_input(); } - return OkStatus(); + return absl::OkStatus(); } Status MutableGraphView::UpdateFanin(absl::string_view node_name, @@ -1289,7 +1289,7 @@ Status MutableGraphView::UpdateFanin(absl::string_view node_name, } if (from_fanin == to_fanin) { - return OkStatus(); + return absl::OkStatus(); } bool from_fanin_is_control = IsTensorIdControlling(from_fanin); @@ -1304,7 +1304,7 @@ Status MutableGraphView::UpdateFanin(absl::string_view node_name, if (modified) { AddFaninInternal(node, {to_fanin_node, to_fanin.index()}); } - return OkStatus(); + return absl::OkStatus(); } // In place mutation of regular fanins, requires no shifting of ports. @@ -1340,7 +1340,7 @@ Status MutableGraphView::UpdateFanin(absl::string_view node_name, } } - return OkStatus(); + return absl::OkStatus(); } Status MutableGraphView::UpdateRegularFaninByPort(absl::string_view node_name, @@ -1365,7 +1365,7 @@ Status MutableGraphView::UpdateRegularFaninByPort(absl::string_view node_name, TensorId tensor_id = ParseTensorName(node->input(port)); if (tensor_id == fanin) { - return OkStatus(); + return absl::OkStatus(); } InputPort input(node, port); @@ -1384,7 +1384,7 @@ Status MutableGraphView::UpdateRegularFaninByPort(absl::string_view node_name, RemoveControllingFaninInternal(node, fanin_node); } - return OkStatus(); + return absl::OkStatus(); } Status MutableGraphView::SwapRegularFaninsByPorts(absl::string_view node_name, @@ -1405,12 +1405,12 @@ Status MutableGraphView::SwapRegularFaninsByPorts(absl::string_view node_name, error_status)); if (from_port == to_port) { - return OkStatus(); + return absl::OkStatus(); } TensorId from_fanin = ParseTensorName(node->input(from_port)); TensorId to_fanin = ParseTensorName(node->input(to_port)); if (from_fanin == to_fanin) { - return OkStatus(); + return absl::OkStatus(); } InputPort from_input(node, from_port); @@ -1428,7 +1428,7 @@ Status MutableGraphView::SwapRegularFaninsByPorts(absl::string_view node_name, node->mutable_input()->SwapElements(from_port, to_port); - return OkStatus(); + return absl::OkStatus(); } Status MutableGraphView::UpdateAllRegularFaninsToControlling( @@ -1499,7 +1499,7 @@ Status MutableGraphView::UpdateAllRegularFaninsToControlling( node->mutable_input()->DeleteSubrange(pos, node->input_size() - pos); max_regular_input_port().erase(node); - return OkStatus(); + return absl::OkStatus(); } Status MutableGraphView::CheckNodesCanBeDeleted( @@ -1562,7 +1562,7 @@ Status MutableGraphView::CheckNodesCanBeDeleted( return MutationError("DeleteNodes", params, error_msg); } - return OkStatus(); + return absl::OkStatus(); } Status MutableGraphView::DeleteNodes( @@ -1600,7 +1600,7 @@ Status MutableGraphView::DeleteNodes( graph()->mutable_node()->DeleteSubrange(last_pos + 1, last_idx - last_pos); } - return OkStatus(); + return absl::OkStatus(); } void MutableGraphView::RemoveFaninsInternal(NodeDef* deleted_node, diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 4c1ee2e302f19d..f3079b745d029c 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -255,7 +255,7 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage { // Update consumers of node to take new_input as input instead. Status UpdateConsumers(NodeDef* node, const string& new_input) { const auto consumers = ctx().node_map->GetOutputs(node->name()); - if (consumers.empty()) return OkStatus(); + if (consumers.empty()) return absl::OkStatus(); const TensorId new_tensor = ParseTensorName(new_input); for (NodeDef* consumer : consumers) { if (consumer->name() == new_tensor.node()) continue; @@ -283,7 +283,7 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage { AddToOptimizationQueue(consumer); } } - return OkStatus(); + return absl::OkStatus(); } // TODO(ezhulenev): remove this method from ArithmeticOptimizer when all @@ -397,7 +397,7 @@ class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage { *simplified_node_name = RewriteOptimizedNodesGroup(group); } - return OkStatus(); + return absl::OkStatus(); } protected: @@ -438,7 +438,7 @@ class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage { } } - return OkStatus(); + return absl::OkStatus(); } Status CreateOptimizedNodesGroup(NodeDef* root_node, @@ -458,7 +458,7 @@ class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage { TF_RETURN_IF_ERROR(AbsorbInputByOptimizedNodesGroup(input_i, group)); } - return OkStatus(); + return absl::OkStatus(); } // Check if all inputs can be broadcasted to the same shape @@ -833,7 +833,7 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage { *simplified_node_name = new_outer_node->name(); } } - return OkStatus(); + return absl::OkStatus(); } private: @@ -917,7 +917,7 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage { } *common_factor_is_denominator = has_div; - return OkStatus(); + return absl::OkStatus(); } // Gather up the non-shared factors (the y's in the example). @@ -951,7 +951,7 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage { *shapes_match = ShapesSymbolicallyEqual(*lhs, *rhs); } } - return OkStatus(); + return absl::OkStatus(); } bool IsRewritten(const NodeDef* node) const { @@ -1193,7 +1193,7 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage { NodeDef* node_perm; TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &node_perm)); if (!IsConstant(*node_perm)) { - return OkStatus(); + return absl::OkStatus(); } std::vector node_perm_values; TF_RETURN_IF_ERROR(GetPermutation(*node_perm, &node_perm_values)); @@ -1203,7 +1203,7 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage { TF_RETURN_IF_ERROR( GetInputNode(first_transpose->input(1), &first_transpose_perm)); if (!IsConstant(*first_transpose_perm)) { - return OkStatus(); + return absl::OkStatus(); } std::vector first_transpose_perm_values; TF_RETURN_IF_ERROR( @@ -1240,7 +1240,7 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage { } } } - return OkStatus(); + return absl::OkStatus(); } private: @@ -1252,10 +1252,10 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage { for (int val : perm32) { perm64->push_back(static_cast(val)); } - return OkStatus(); + return absl::OkStatus(); } if (ValuesFromConstNode(node_perm, perm64)) { - return OkStatus(); + return absl::OkStatus(); } return errors::InvalidArgument("Couldn't extract permutation from ", node_perm.name()); @@ -1321,7 +1321,7 @@ class RemoveInvolution : public ArithmeticOptimizerStage { } } - return OkStatus(); + return absl::OkStatus(); } }; @@ -1351,7 +1351,7 @@ class RemoveRedundantBitcastStage : public ArithmeticOptimizerStage { TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "type", &output_type)); if ((input_type == output_type) && !IsInPreserveSet(*node)) { *simplified_node_name = node->input(0); - return OkStatus(); + return absl::OkStatus(); } NodeDef* bitcast; @@ -1372,7 +1372,7 @@ class RemoveRedundantBitcastStage : public ArithmeticOptimizerStage { *simplified_node_name = bitcast->name(); } - return OkStatus(); + return absl::OkStatus(); } }; @@ -1400,7 +1400,7 @@ class RemoveRedundantCastStage : public ArithmeticOptimizerStage { if (input_type == output_type) { *simplified_node_name = node->input(0); } - return OkStatus(); + return absl::OkStatus(); } }; @@ -1440,7 +1440,7 @@ class RemoveNegationStage : public ArithmeticOptimizerStage { if (updated) { AddToOptimizationQueue(node); } - return OkStatus(); + return absl::OkStatus(); } }; @@ -1461,7 +1461,7 @@ class RemoveLogicalNotStage : public ArithmeticOptimizerStage { TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input)); if (IsInPreserveSet(*input) || NumNonControlOutputs(*input, *ctx().node_map) > 1) { - return OkStatus(); + return absl::OkStatus(); } string new_op; if (IsEqual(*input)) { @@ -1481,7 +1481,7 @@ class RemoveLogicalNotStage : public ArithmeticOptimizerStage { input->set_op(new_op); *simplified_node_name = input->name(); } - return OkStatus(); + return absl::OkStatus(); } }; @@ -1564,7 +1564,7 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { TF_RETURN_IF_ERROR( HoistUnaryOpChain(prefix_length, tails, &ctrl_inputs, node)); } - return OkStatus(); + return absl::OkStatus(); } private: @@ -1594,7 +1594,7 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { ChainLinkSet cur_tails; TF_RETURN_IF_ERROR(InitializeChains(root_node, &cur_tails)); if (cur_tails.size() < 2) { - return OkStatus(); + return absl::OkStatus(); } ctrl_inputs->clear(); bool stop = false; @@ -1608,7 +1608,7 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { // Advance tail pointers to the next level. TF_RETURN_IF_ERROR(AdvanceTails(*tails, &cur_tails, &stop)); } - return OkStatus(); + return absl::OkStatus(); } // Hoists the chains to the other side of concat or split and attaches the @@ -1621,7 +1621,7 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { << absl::StrJoin(*ctrl_inputs, ", ") << "]"; if (tails.empty()) { - return OkStatus(); + return absl::OkStatus(); } AddToOptimizationQueue(root_node); optimized_nodes_.insert(root_node->name()); @@ -1680,7 +1680,7 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { TF_RETURN_IF_ERROR(GetInputNode(node.input(input_port), &tail)); tails->insert(ChainLink(tail, input_port)); } - return OkStatus(); + return absl::OkStatus(); } else { // Handle split nodes by looking forwards in the graph. const auto& outputs = ctx().node_map->GetOutputs(node.name()); @@ -1695,11 +1695,11 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { // This output node has a non-control input other than the split node, // abort. tails->clear(); - return OkStatus(); + return absl::OkStatus(); } } } - return OkStatus(); + return absl::OkStatus(); } bool OpsAreSafeToHoist(const NodeDef& root_node, @@ -1740,7 +1740,7 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { const NodeDef* tail = link.node; if (node_is_concat_) { if (tail->input_size() == 0 || IsControlInput(tail->input(0))) { - return OkStatus(); + return absl::OkStatus(); } NodeDef* new_tail; TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &new_tail)); @@ -1750,7 +1750,7 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { for (NodeDef* new_tail : ctx().node_map->GetOutputs(tail->name())) { const TensorId tensor = ParseTensorName(new_tail->input(0)); if (tensor.node() != tail->name()) { - return OkStatus(); + return absl::OkStatus(); } // Skip control outputs. if (tensor.index() >= 0) { @@ -1761,7 +1761,7 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { } } *stop = false; - return OkStatus(); + return absl::OkStatus(); } Status HoistChainForConcat(const int prefix_length, const ChainLinkSet& tails, @@ -1788,7 +1788,7 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { ctx().node_map->UpdateInput(tail->name(), tail_input, concat_name); } } - return OkStatus(); + return absl::OkStatus(); } Status HoistChainForSplit(const int prefix_length, const ChainLinkSet& tails, @@ -1839,7 +1839,7 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { ? split_name : strings::StrCat(split_name, ":", link.port_origin))); } - return OkStatus(); + return absl::OkStatus(); } bool IsAlreadyOptimized(const NodeDef& node) const { @@ -1869,7 +1869,7 @@ class RemoveIdempotentStage : public ArithmeticOptimizerStage { if (input->op() == node->op() && input->device() == node->device()) { *simplified_node_name = node->input(0); } - return OkStatus(); + return absl::OkStatus(); } }; @@ -1908,7 +1908,7 @@ class SqrtDivToRsqrtMulStage : public ArithmeticOptimizerStage { AddToOptimizationQueue(node); AddToOptimizationQueue(y); } - return OkStatus(); + return absl::OkStatus(); } }; @@ -1935,13 +1935,14 @@ class FuseSquaredDiffStage : public ArithmeticOptimizerStage { // For complex, SquaredDiff computes conj(x-y)*(x-y), so this rewrite is // invalid. const DataType type = GetDataTypeFromAttr(*b, "T"); - if ((type == DT_COMPLEX64) || (type == DT_COMPLEX128)) return OkStatus(); + if ((type == DT_COMPLEX64) || (type == DT_COMPLEX128)) + return absl::OkStatus(); node->set_op("Identity"); b->set_op("SquaredDifference"); AddToOptimizationQueue(node); AddToOptimizationQueue(b); } - return OkStatus(); + return absl::OkStatus(); } }; @@ -1969,7 +1970,7 @@ class LogSoftmaxStage : public ArithmeticOptimizerStage { AddToOptimizationQueue(node); AddToOptimizationQueue(x); } - return OkStatus(); + return absl::OkStatus(); } }; @@ -2007,7 +2008,7 @@ class RemoveRedundantReshapeOrBroadcastTo : public ArithmeticOptimizerStage { if (!IsInPreserveSet(*node) && InputMatchesTargetShape(*node) && !HasControlInputs(*node)) { *simplified_node_name = node->input(0); - return OkStatus(); + return absl::OkStatus(); } // 2. Bypass reshape followed by reshape, possibly separated by a simple @@ -2040,7 +2041,7 @@ class RemoveRedundantReshapeOrBroadcastTo : public ArithmeticOptimizerStage { (!IsReshape(*reshape_to_bypass) || NumNonControlOutputs(*reshape_to_bypass, *ctx().node_map) > 1 || IsInPreserveSet(*reshape_to_bypass))) { - return OkStatus(); + return absl::OkStatus(); } // Clearing invalid shape inference results of nodes in chain. for (const NodeDef* node_in_chain : nodes_in_chain) { @@ -2059,11 +2060,11 @@ class RemoveRedundantReshapeOrBroadcastTo : public ArithmeticOptimizerStage { // Change the bypassed reshape to NoOp. ReplaceWithNoOp(reshape_to_bypass, ctx()); *simplified_node_name = node->name(); - return OkStatus(); + return absl::OkStatus(); } } - return OkStatus(); + return absl::OkStatus(); } private: @@ -2130,7 +2131,7 @@ class ReorderCastLikeAndValuePreserving : public ArithmeticOptimizerStage { if (!can_optimize || IsControlFlow(*producer) || IsInPreserveSet(*producer) || producer->device() != consumer->device()) { - return OkStatus(); + return absl::OkStatus(); } const NodeDef* cast_like_node = producer_is_cast ? producer : consumer; @@ -2144,13 +2145,13 @@ class ReorderCastLikeAndValuePreserving : public ArithmeticOptimizerStage { TF_RETURN_IF_ERROR(OutputTypeForNode(*cast_like_node, *cast_like_op_def, 0, &cast_dst_type)); if (!IsFixedSizeType(cast_src_type) || !IsFixedSizeType(cast_dst_type)) { - return OkStatus(); + return absl::OkStatus(); } else if (producer_is_cast && DataTypeSize(cast_dst_type) <= DataTypeSize(cast_src_type)) { - return OkStatus(); + return absl::OkStatus(); } else if (!producer_is_cast && DataTypeSize(cast_dst_type) >= DataTypeSize(cast_src_type)) { - return OkStatus(); + return absl::OkStatus(); } // Check that nodes were not already optimized. @@ -2162,7 +2163,7 @@ class ReorderCastLikeAndValuePreserving : public ArithmeticOptimizerStage { ctx().node_map->NodeExists(optimized_consumer_name) || ctx().node_map->NodeExists(optimized_producer_name); if (is_already_optimized) { - return OkStatus(); + return absl::OkStatus(); } // Add copies of consumer and producer in reverse order. @@ -2192,7 +2193,7 @@ class ReorderCastLikeAndValuePreserving : public ArithmeticOptimizerStage { AddToOptimizationQueue(new_producer); *simplified_node_name = new_consumer->name(); - return OkStatus(); + return absl::OkStatus(); } private: @@ -2209,11 +2210,11 @@ class ReorderCastLikeAndValuePreserving : public ArithmeticOptimizerStage { DataTypeString(dtype)); } else { // Op has fixed input type that already matches dtype. - return OkStatus(); + return absl::OkStatus(); } } SetDataTypeToAttr(dtype, type_attr_name, node); - return OkStatus(); + return absl::OkStatus(); } // This optimization can be dangerous on devices other than CPU and // GPU. The transpose might not be implemented for image.type, or @@ -2360,7 +2361,7 @@ class FoldMultiplyIntoConv : public ArithmeticOptimizerStage { AddToOptimizationQueue(tail); *simplified_node_name = conv->name(); - return OkStatus(); + return absl::OkStatus(); #undef TF_RETURN_IF_TRUE } }; @@ -2380,7 +2381,8 @@ class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage { Status TrySimplify(NodeDef* node, string* simplified_node_name) override { const NodeScopeAndName matmul = ParseNodeScopeAndName(node->name()); const string optimized_node_name = OptimizedNodeName(matmul); - if (ctx().node_map->NodeExists(optimized_node_name)) return OkStatus(); + if (ctx().node_map->NodeExists(optimized_node_name)) + return absl::OkStatus(); NodeDef* a; NodeDef* b; @@ -2403,7 +2405,7 @@ class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage { IsInnerMatrixTransposeNode(*a, ctx().node_map); const bool b_is_foldable = foldable_transpose_ops.count(b->op()) > 0 && IsInnerMatrixTransposeNode(*b, ctx().node_map); - if (!a_is_foldable && !b_is_foldable) return OkStatus(); + if (!a_is_foldable && !b_is_foldable) return absl::OkStatus(); NodeDef* new_op = AddCopyNode(optimized_node_name, node); @@ -2431,7 +2433,7 @@ class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage { ForwardControlDependencies(new_op, deps_to_forward); *simplified_node_name = new_op->name(); - return OkStatus(); + return absl::OkStatus(); } private: @@ -2488,7 +2490,8 @@ class FoldConjugateIntoTranspose : public ArithmeticOptimizerStage { Status TrySimplify(NodeDef* node, string* simplified_node_name) override { const NodeScopeAndName matmul = ParseNodeScopeAndName(node->name()); const string optimized_node_name = OptimizedNodeName(matmul); - if (ctx().node_map->NodeExists(optimized_node_name)) return OkStatus(); + if (ctx().node_map->NodeExists(optimized_node_name)) + return absl::OkStatus(); NodeDef* input; TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input)); @@ -2510,7 +2513,7 @@ class FoldConjugateIntoTranspose : public ArithmeticOptimizerStage { *simplified_node_name = new_op->name(); } - return OkStatus(); + return absl::OkStatus(); } }; @@ -2534,7 +2537,8 @@ class ReplaceMulWithSquare : public ArithmeticOptimizerStage { Status TrySimplify(NodeDef* node, string* simplified_node_name) override { const NodeScopeAndName mul = ParseNodeScopeAndName(node->name()); const string optimized_node_name = OptimizedNodeName(mul); - if (ctx().node_map->NodeExists(optimized_node_name)) return OkStatus(); + if (ctx().node_map->NodeExists(optimized_node_name)) + return absl::OkStatus(); const DataType type = GetDataTypeFromAttr(*node, "T"); bool is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128); @@ -2552,7 +2556,7 @@ class ReplaceMulWithSquare : public ArithmeticOptimizerStage { *simplified_node_name = new_square_node->name(); } - return OkStatus(); + return absl::OkStatus(); } }; @@ -2583,11 +2587,11 @@ class ReplaceMulWithBroadcastByTile : public ArithmeticOptimizerStage { TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &ones)); if (IsInPreserveSet(*node) || IsInPreserveSet(*input) || IsInPreserveSet(*ones)) { - return OkStatus(); + return absl::OkStatus(); } // TODO(kkiningh): Generalize using IsOnes from constant_folding.cc - if (IsConstant(*input) || !IsOnes(*ones)) return OkStatus(); + if (IsConstant(*input) || !IsOnes(*ones)) return absl::OkStatus(); // Avoid optimizing the same node twice const NodeScopeAndName scope_and_name = ParseNodeScopeAndName(node->name()); @@ -2595,28 +2599,28 @@ class ReplaceMulWithBroadcastByTile : public ArithmeticOptimizerStage { const string const_node_name = OptimizedNodeName(scope_and_name, "Const"); if (ctx().node_map->NodeExists(tile_node_name) || ctx().node_map->NodeExists(const_node_name)) { - return OkStatus(); + return absl::OkStatus(); } const std::vector& props = ctx().graph_properties->GetInputProperties(node->name()); - if (props.size() != 2) return OkStatus(); + if (props.size() != 2) return absl::OkStatus(); // Ignore ops where the shape doesn't change const TensorShapeProto& input_shape = props[0].shape(); const TensorShapeProto& ones_shape = props[1].shape(); TensorShapeProto output_shape; if (!ShapeAfterBroadcast(input_shape, ones_shape, &output_shape)) { - return OkStatus(); + return absl::OkStatus(); } if (ShapesSymbolicallyEqual(input_shape, output_shape)) { - return OkStatus(); + return absl::OkStatus(); } // All inputs must have same input/output dimensions if (input_shape.dim_size() != output_shape.dim_size() || ones_shape.dim_size() != output_shape.dim_size()) - return OkStatus(); + return absl::OkStatus(); // At this point all preconditions are met. Can proceed with rewrite. VLOG(3) << "Simplify multiply with all ones input: node=" << node->name() @@ -2653,7 +2657,7 @@ class ReplaceMulWithBroadcastByTile : public ArithmeticOptimizerStage { ForwardControlDependencies(tile_node, {node}); *simplified_node_name = tile_node->name(); - return OkStatus(); + return absl::OkStatus(); } protected: @@ -2707,20 +2711,20 @@ class ReduceUpsamplingDims : public ArithmeticOptimizerStage { NodeDef* tile; TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &tile)); if (!IsTile(*tile) || IsInPreserveSet(*tile)) { - return OkStatus(); + return absl::OkStatus(); } if (NumNonControlOutputs(*tile, *ctx().node_map) != 1) { // Optimization is only worthwile when there is a single output from Tile. // Otherwise, we need to insert additional Reshape ops that can't be // easily removed. - return OkStatus(); + return absl::OkStatus(); } NodeDef* reshape; TF_RETURN_IF_ERROR(GetInputNode(tile->input(0), &reshape)); if (!IsReshape(*reshape) || IsInPreserveSet(*reshape)) { - return OkStatus(); + return absl::OkStatus(); } NodeDef* multiples; @@ -2741,18 +2745,18 @@ class ReduceUpsamplingDims : public ArithmeticOptimizerStage { ctx().node_map->NodeExists(new_tile_name) || ctx().node_map->NodeExists(new_shape_name) || ctx().node_map->NodeExists(new_multiples_name)) { - return OkStatus(); + return absl::OkStatus(); } // Compuate updated multiples/shape values. AttrValue new_multiples_attr; if (!CreateUpdatedMultiplesProto(multiples, new_multiples_attr.mutable_tensor())) { - return OkStatus(); + return absl::OkStatus(); } AttrValue new_shape_attr; if (!CreateUpdatedShapeProto(shape, new_shape_attr.mutable_tensor())) { - return OkStatus(); + return absl::OkStatus(); } // At this point the graph is validated and can be updated @@ -2790,7 +2794,7 @@ class ReduceUpsamplingDims : public ArithmeticOptimizerStage { ForwardControlDependencies(new_shape, {shape}); *simplified_node_name = node->name(); - return OkStatus(); + return absl::OkStatus(); } private: @@ -2922,7 +2926,7 @@ class ReplacePackWithTileReshape : public ArithmeticOptimizerStage { // Must be at least two Pack operations to consider for replacement if (chain.empty()) { - return OkStatus(); + return absl::OkStatus(); } // Avoid optimizing the same node twice @@ -2939,7 +2943,7 @@ class ReplacePackWithTileReshape : public ArithmeticOptimizerStage { ctx().node_map->NodeExists(new_tile_name) || ctx().node_map->NodeExists(new_shape_name) || ctx().node_map->NodeExists(new_reshape_name)) { - return OkStatus(); + return absl::OkStatus(); } // 2. Calculate the multiples and shape tensor using the chain @@ -2947,7 +2951,7 @@ class ReplacePackWithTileReshape : public ArithmeticOptimizerStage { TF_RETURN_IF_ERROR(GetTensorProperties(input->name(), &input_props)); const TensorShapeProto& input_shape = input_props->shape(); if (!PartialTensorShape(input_shape).IsFullyDefined()) { - return OkStatus(); + return absl::OkStatus(); } Tensor multiples(DT_INT32, TensorShape({input_shape.dim_size()})); TF_RETURN_IF_ERROR(CalculateMultiplesFromChain(chain, &multiples)); @@ -2956,7 +2960,7 @@ class ReplacePackWithTileReshape : public ArithmeticOptimizerStage { TF_RETURN_IF_ERROR(GetTensorProperties(node->name(), &output_props)); const TensorShapeProto& output_shape = output_props->shape(); if (!PartialTensorShape(output_shape).IsFullyDefined()) { - return OkStatus(); + return absl::OkStatus(); } Tensor output_shape_tensor(DT_INT32, TensorShape({output_shape.dim_size()})); @@ -3011,7 +3015,7 @@ class ReplacePackWithTileReshape : public ArithmeticOptimizerStage { *simplified_node_name = new_reshape_node->name(); - return OkStatus(); + return absl::OkStatus(); } protected: @@ -3052,7 +3056,7 @@ class ReplacePackWithTileReshape : public ArithmeticOptimizerStage { dims.insert(dims.begin() + axis, dims[axis]); } - return OkStatus(); + return absl::OkStatus(); } }; @@ -3086,7 +3090,7 @@ class SimplifyAggregation : public ArithmeticOptimizerStage { // 1. Discard aggregate nodes with a single input and no control deps. if (node->input_size() == 1) { *simplified_node_name = node->input(0); - return OkStatus(); + return absl::OkStatus(); } // 2. Rewrite aggregations of N >= 2 identical terms. @@ -3102,7 +3106,7 @@ class SimplifyAggregation : public ArithmeticOptimizerStage { break; } } - if (!all_equal) return OkStatus(); + if (!all_equal) return absl::OkStatus(); // And node should not be optimized earlier. const NodeScopeAndName node_scope_and_name = @@ -3116,7 +3120,7 @@ class SimplifyAggregation : public ArithmeticOptimizerStage { ctx().node_map->NodeExists(optimized_const_name) || ctx().node_map->NodeExists(optimized_mul_name); - if (is_already_optimized) return OkStatus(); + if (is_already_optimized) return absl::OkStatus(); // At this point all preconditions are met, and we safely do the rewrite. VLOG(3) << "Simplify aggregation with identical inputs: node=" @@ -3157,7 +3161,7 @@ class SimplifyAggregation : public ArithmeticOptimizerStage { ForwardControlDependencies(new_mul_node, {node}); *simplified_node_name = new_mul_node->name(); - return OkStatus(); + return absl::OkStatus(); } }; @@ -3175,16 +3179,16 @@ class ConvertPowStage : public ArithmeticOptimizerStage { Status TrySimplify(NodeDef* node, string* simplified_node_name) override { Tensor pow; - if (!GetTensorFromConstNode(node->input(1), &pow)) return OkStatus(); + if (!GetTensorFromConstNode(node->input(1), &pow)) return absl::OkStatus(); complex128 prev, curr; for (int i = 0; i < pow.NumElements(); ++i) { if (!GetElementUnexhaustive(pow, i, {pow.dtype()}, &curr)) { // input data type is not supported by Pow. Skip. - return OkStatus(); + return absl::OkStatus(); } if (i != 0 && curr != prev) { // pow has different values on different elements. Skip. - return OkStatus(); + return absl::OkStatus(); } prev = curr; } @@ -3266,7 +3270,7 @@ class ConvertPowStage : public ArithmeticOptimizerStage { AddToOptimizationQueue(node); AddToOptimizationQueue(y); } - return OkStatus(); + return absl::OkStatus(); } private: @@ -3274,22 +3278,22 @@ class ConvertPowStage : public ArithmeticOptimizerStage { switch (t->dtype()) { case DT_INT32: t->flat()(i) = 1; - return OkStatus(); + return absl::OkStatus(); case DT_INT64: t->flat()(i) = 1L; - return OkStatus(); + return absl::OkStatus(); case DT_FLOAT: t->flat()(i) = 1.0f; - return OkStatus(); + return absl::OkStatus(); case DT_DOUBLE: t->flat()(i) = 1.0; - return OkStatus(); + return absl::OkStatus(); case DT_COMPLEX64: t->flat()(i) = complex64(1); - return OkStatus(); + return absl::OkStatus(); case DT_COMPLEX128: t->flat()(i) = complex128(1); - return OkStatus(); + return absl::OkStatus(); default: return errors::InvalidArgument("Invalid data type: ", t->dtype()); } @@ -3309,11 +3313,11 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage { NodeDef* input; TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input)); if (!IsAdd(*input)) { - return OkStatus(); + return absl::OkStatus(); } if (ctx().graph_properties->GetInputProperties(input->name()).size() < 2) { - return OkStatus(); + return absl::OkStatus(); } bool modified = false; @@ -3324,7 +3328,7 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage { if (modified) { *simplified_node_name = node->name(); } - return OkStatus(); + return absl::OkStatus(); } private: @@ -3337,17 +3341,17 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage { for (int k = 0; k < c.shape().dim_size(); ++k) { // Skip if c shape is not fully determined. if (c.shape().dim(k).size() < 0) { - return OkStatus(); + return absl::OkStatus(); } } TensorShapeProto broadcast_shape; if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) { - return OkStatus(); + return absl::OkStatus(); } if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) { // skip if the non-constant tensor doesn't have the same shape after // broadcast. - return OkStatus(); + return absl::OkStatus(); } Tensor constant; if (GetTensorFromConstNode(add_node->input(j), &constant)) { @@ -3362,11 +3366,11 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage { DT_COMPLEX64, DT_COMPLEX128}, &element)) { // input data type is not supported by log1p. Skip. - return OkStatus(); + return absl::OkStatus(); } if (element != complex128(1)) { // current element is not 1. Skip. - return OkStatus(); + return absl::OkStatus(); } } NodeDef *x, *y; @@ -3383,7 +3387,7 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage { AddToOptimizationQueue(y); *modified = true; } - return OkStatus(); + return absl::OkStatus(); } }; @@ -3405,21 +3409,22 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage { Status TrySimplify(NodeDef* node, string* simplified_node_name) override { if (ctx().graph_properties->GetInputProperties(node->name()).size() < 2) { - return OkStatus(); + return absl::OkStatus(); } const auto& t = ctx().graph_properties->GetInputProperties(node->name())[0]; const auto& c = ctx().graph_properties->GetInputProperties(node->name())[1]; TensorShapeProto broadcast_shape; if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) { - return OkStatus(); + return absl::OkStatus(); } if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) { // skip if the non-constant tensor doesn't have the same shape after // broadcast. - return OkStatus(); + return absl::OkStatus(); } Tensor constant; - if (!GetTensorFromConstNode(node->input(1), &constant)) return OkStatus(); + if (!GetTensorFromConstNode(node->input(1), &constant)) + return absl::OkStatus(); // TODO(rmlarsen): Use the more general IsOnes helper here. complex128 element; for (int k = 0; k < constant.NumElements(); ++k) { @@ -3428,11 +3433,11 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage { DT_COMPLEX64, DT_COMPLEX128}, &element)) { // input data type is not supported by expm1. Skip. - return OkStatus(); + return absl::OkStatus(); } if (element != complex128(1)) { // current element is not 1. Skip. - return OkStatus(); + return absl::OkStatus(); } } NodeDef* exp; @@ -3450,7 +3455,7 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage { AddToOptimizationQueue(exp_input); AddToOptimizationQueue(ones); *simplified_node_name = node->name(); - return OkStatus(); + return absl::OkStatus(); } }; @@ -3476,7 +3481,7 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage { Status TrySimplify(NodeDef* reduction_node, string* simplified_node_name) override { if (IsInPreserveSet(*reduction_node)) { - return OkStatus(); + return absl::OkStatus(); } NodeDef* inner_function; @@ -3538,7 +3543,7 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage { AddToOptimizationQueue(inner_function); AddToOptimizationQueue(inner_input); } - return OkStatus(); + return absl::OkStatus(); } private: @@ -3638,7 +3643,7 @@ class UnaryOpsComposition : public ArithmeticOptimizerStage { *root, *ctx().node_map, /*follow_control_input*/ false, predicate_fn); // We were not able to find a chain that can be replaced. - if (op_names.size() == 1) return OkStatus(); + if (op_names.size() == 1) return absl::OkStatus(); // Do not add fused nodes to any other chain. std::for_each(op_nodes.begin(), op_nodes.end(), @@ -3666,7 +3671,7 @@ class UnaryOpsComposition : public ArithmeticOptimizerStage { *simplified_node_name = composition_node->name(); - return OkStatus(); + return absl::OkStatus(); } private: @@ -3745,14 +3750,14 @@ class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage { // Get the input and see if it's a Pack op. TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &pack)); - if (!IsPack(*pack)) return OkStatus(); + if (!IsPack(*pack)) return absl::OkStatus(); bool return_early; PartialTensorShape pack_output_shape; int pack_axis; TF_RETURN_IF_ERROR( CheckInputs(node, pack, &pack_output_shape, &pack_axis, &return_early)); - if (return_early) return OkStatus(); + if (return_early) return absl::OkStatus(); int64_t slice_start_value; bool found; @@ -3760,7 +3765,7 @@ class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage { TF_RETURN_IF_ERROR(GetSliceAxis(node, pack, pack_output_shape, pack_axis, &slice_start_value, &found, &must_expand_dims)); - if (!found) return OkStatus(); + if (!found) return absl::OkStatus(); return RewriteGraph(node, pack, slice_start_value, pack_axis, must_expand_dims, simplified_node_name); @@ -3778,7 +3783,7 @@ class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage { ctx().graph_properties->GetInputProperties(node->name()); if (slice_properties.empty() || slice_properties[0].shape().unknown_rank()) { - return OkStatus(); + return absl::OkStatus(); } *pack_output_shape = slice_properties[0].shape(); const int pack_output_rank = pack_output_shape->dims(); @@ -3791,7 +3796,7 @@ class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage { ") axis attribute is out of bounds: ", pack->attr().at("axis").i()); } *return_early = false; - return OkStatus(); + return absl::OkStatus(); } Status GetSliceAxis(const NodeDef* node, const NodeDef* pack, @@ -3818,18 +3823,18 @@ class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage { TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &slice_begin)); TF_RETURN_IF_ERROR(GetInputNode(node->input(2), &slice_size)); for (const auto* n : {slice_begin, slice_size}) { - if (!IsReallyConstant(*n)) return OkStatus(); + if (!IsReallyConstant(*n)) return absl::OkStatus(); } Tensor slice_begin_t; Tensor slice_size_t; TF_RETURN_IF_ERROR(CheckAttrExists(*slice_begin, "value")); if (!slice_begin_t.FromProto(slice_begin->attr().at("value").tensor())) { - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR(CheckAttrExists(*slice_size, "value")); if (!slice_size_t.FromProto(slice_size->attr().at("value").tensor())) { - return OkStatus(); + return absl::OkStatus(); } auto copy_tensor_values_to_vector = @@ -3845,7 +3850,7 @@ class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage { " has invalid type for Index attr: ", DataTypeString(t.dtype())); } - return OkStatus(); + return absl::OkStatus(); }; gtl::InlinedVector slice_begin_vec; @@ -3864,7 +3869,7 @@ class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage { int slice_begin_vec_size = slice_begin_vec.size(); if (!pack_output_shape.unknown_rank() && slice_begin_vec_size != pack_output_shape.dims()) { - return OkStatus(); + return absl::OkStatus(); } if (pack_axis >= slice_begin_vec_size) { return errors::InvalidArgument( @@ -3875,7 +3880,7 @@ class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage { *slice_start_value = slice_begin_vec[pack_axis]; if (slice_size_vec[pack_axis] != 1) { // Not slicing a single value out. - return OkStatus(); + return absl::OkStatus(); } for (int i = 0; i < slice_begin_vec_size; ++i) { @@ -3884,7 +3889,7 @@ class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage { !(slice_size_vec[i] == -1 || slice_size_vec[i] == pack_output_shape.dim_size(i))) { // Not slicing on the same axis as the Pack op. - return OkStatus(); + return absl::OkStatus(); } } } @@ -3897,7 +3902,7 @@ class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage { } *found = true; // slice_start_value is valid. - return OkStatus(); + return absl::OkStatus(); } Status GetStridedSliceAxis(const NodeDef* node, const NodeDef* pack, @@ -3928,7 +3933,7 @@ class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage { TF_RETURN_IF_ERROR(GetInputNode(node->input(3), &slice_strides)); for (const auto* n : {slice_begin, slice_end, slice_strides}) { - if (!IsReallyConstant(*n)) return OkStatus(); + if (!IsReallyConstant(*n)) return absl::OkStatus(); } Tensor slice_begin_t; @@ -3937,16 +3942,16 @@ class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage { TF_RETURN_IF_ERROR(CheckAttrExists(*slice_begin, "value")); if (!slice_begin_t.FromProto(slice_begin->attr().at("value").tensor())) { - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR(CheckAttrExists(*slice_end, "value")); if (!slice_end_t.FromProto(slice_end->attr().at("value").tensor())) { - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR(CheckAttrExists(*slice_strides, "value")); if (!slice_strides_t.FromProto( slice_strides->attr().at("value").tensor())) { - return OkStatus(); + return absl::OkStatus(); } TensorShape processing_shape; TensorShape final_shape; @@ -3962,7 +3967,7 @@ class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage { &processing_shape, &final_shape, &is_identity, &is_simple_slice, &slice_dim0, &slice_begin_vec, &slice_end_vec, &slice_strides_vec)); - if (!is_simple_slice) return OkStatus(); + if (!is_simple_slice) return absl::OkStatus(); int begin_index = -1; int64_t begin_value = 0; @@ -3971,7 +3976,7 @@ class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage { if (v != 0) { if (begin_index != -1) { // At least two start values that are nonzero. - return OkStatus(); + return absl::OkStatus(); } begin_index = i; begin_value = v; @@ -3985,29 +3990,29 @@ class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage { if (v != pack_output_shape.dim_size(i)) { if (end_index != -1) { // At least two end values that are nonzero. - return OkStatus(); + return absl::OkStatus(); } end_index = i; end_value = v; } } - if (begin_index == -1 && end_index == -1) return OkStatus(); + if (begin_index == -1 && end_index == -1) return absl::OkStatus(); if (begin_index != -1 && end_index != -1 && begin_index != end_index) { // Somehow received different axes for begin/end slicing - return OkStatus(); + return absl::OkStatus(); } const int slice_axis = (begin_index == -1) ? end_index : begin_index; if (slice_axis != pack_axis) { // Not slicing on the same axis as the Pack op. - return OkStatus(); + return absl::OkStatus(); } *slice_start_value = (begin_index == -1) ? 0 : begin_value; const int64_t slice_end_value = (end_index == -1) ? pack_output_shape.dim_size(slice_axis) : end_value; if (slice_end_value != *slice_start_value + 1) { // Not slicing a single value out. - return OkStatus(); + return absl::OkStatus(); } if (*slice_start_value < 0 || *slice_start_value >= pack->input_size()) { @@ -4023,11 +4028,11 @@ class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage { *must_expand_dims = false; } else { // Shrinking on a different axis from the one that we are slicing on. - return OkStatus(); + return absl::OkStatus(); } *found = true; // slice_start_value is valid. - return OkStatus(); + return absl::OkStatus(); } Status RewriteGraph(const NodeDef* node, const NodeDef* pack, @@ -4079,7 +4084,7 @@ class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage { AddToOptimizationQueue(output); *simplified_node_name = output->name(); - return OkStatus(); + return absl::OkStatus(); } }; @@ -4120,7 +4125,7 @@ class SimplifyEmbeddingLookupStage : public ArithmeticOptimizerStage { Status TrySimplify(NodeDef* reduction_node, string* simplified_node_name) override { - if (IsInPreserveSet(*reduction_node)) return OkStatus(); + if (IsInPreserveSet(*reduction_node)) return absl::OkStatus(); // Input 0 (data) of the reduction node must be a tf.gather() on the 0th // axis. @@ -4128,9 +4133,9 @@ class SimplifyEmbeddingLookupStage : public ArithmeticOptimizerStage { TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &gather_node)); if (!IsGather(*gather_node) || IsInPreserveSet(*gather_node) || gather_node->device() != reduction_node->device()) - return OkStatus(); + return absl::OkStatus(); if (gather_node->op() == "GatherV2" && !IsAxis0(*gather_node, 2)) - return OkStatus(); + return absl::OkStatus(); // Input 1 (indices) of the gather node must be a tf.unique() on the 0th // axis. @@ -4138,9 +4143,9 @@ class SimplifyEmbeddingLookupStage : public ArithmeticOptimizerStage { TF_RETURN_IF_ERROR(GetInputNode(gather_node->input(1), &unique_node)); if (!IsUnique(*unique_node) || IsInPreserveSet(*unique_node) || unique_node->device() != gather_node->device()) - return OkStatus(); + return absl::OkStatus(); if (unique_node->op() == "UniqueV2" && !IsAxis0(*unique_node, 1)) - return OkStatus(); + return absl::OkStatus(); DataType unique_element_type; TF_RETURN_IF_ERROR(GetNodeAttr(*unique_node, "T", &unique_element_type)); @@ -4148,7 +4153,7 @@ class SimplifyEmbeddingLookupStage : public ArithmeticOptimizerStage { // Input 1 (indices) of the reduction node must be output 1 of the unique // node. const TensorId idx_tensor = ParseTensorName(reduction_node->input(1)); - if (idx_tensor != TensorId(unique_node->name(), 1)) return OkStatus(); + if (idx_tensor != TensorId(unique_node->name(), 1)) return absl::OkStatus(); // Input 1 (indices) of the reduction node becomes input 0 (x) of the unique // node. @@ -4207,7 +4212,7 @@ class SimplifyEmbeddingLookupStage : public ArithmeticOptimizerStage { gather_node->input(0)); } *simplified_node_name = reduction_node->name(); - return OkStatus(); + return absl::OkStatus(); } private: @@ -4247,7 +4252,7 @@ class RemoveCastIntoSegmentReductionStage : public ArithmeticOptimizerStage { Status TrySimplify(NodeDef* reduction_node, string* simplified_node_name) override { - if (IsInPreserveSet(*reduction_node)) return OkStatus(); + if (IsInPreserveSet(*reduction_node)) return absl::OkStatus(); bool optimized = false; @@ -4274,7 +4279,7 @@ class RemoveCastIntoSegmentReductionStage : public ArithmeticOptimizerStage { } if (optimized) *simplified_node_name = reduction_node->name(); - return OkStatus(); + return absl::OkStatus(); } private: @@ -4415,7 +4420,7 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { } } } - return OkStatus(); + return absl::OkStatus(); } Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/, @@ -4456,7 +4461,7 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/, // Perform the optimizations. TF_RETURN_IF_ERROR(SimplifyArithmeticOps(can_use_shapes)); *optimized_graph = std::move(*optimized_graph_); - return OkStatus(); + return absl::OkStatus(); } } // namespace grappler diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc index 7cab9376515a87..134b9017045ea6 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc @@ -298,7 +298,7 @@ class NodeTypeAttrMap { for (const NodeDef& node : graph.node()) { TF_RETURN_IF_ERROR(AddNode(node)); } - return OkStatus(); + return absl::OkStatus(); } bool is_initialized() const { return graph_ != nullptr; } @@ -412,7 +412,7 @@ class NodeTypeAttrMap { } } } - return OkStatus(); + return absl::OkStatus(); } // WARN: `graph_` must outlive this object (node pointers must remain valid). @@ -628,7 +628,7 @@ Status GraphTypeTopologyView::InitializeFromGraph( SortAndRemoveDuplicates(&fanouts_[node_type_idx]); } - return OkStatus(); + return absl::OkStatus(); } Status GraphTypeTopologyView::AddEphemeralEdges( @@ -678,7 +678,7 @@ Status GraphTypeTopologyView::AddEphemeralEdges( SortAndRemoveDuplicates(&fanouts_[node_type_idx]); } - return OkStatus(); + return absl::OkStatus(); } bool GraphTypeTopologyView::HasNode(absl::string_view node_name, @@ -931,7 +931,7 @@ Status ValidateLists(const gtl::FlatSet& allow_list, if (duplicates) { return errors::InvalidArgument("Op lists have conflicting entries"); } else { - return OkStatus(); + return absl::OkStatus(); } } @@ -1195,7 +1195,7 @@ Status AutoMixedPrecisionImpl::PrintDebugLogs(bool preop, size_t timestamp) { string prepend_path; TF_RETURN_IF_ERROR(ReadStringFromEnvVar( "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LOG_PATH", "", &prepend_path)); - if (prepend_path.empty()) return OkStatus(); + if (prepend_path.empty()) return absl::OkStatus(); string suffix = strings::StrCat("_", preop ? "preop" : kSuffix, "_", id_, "_", timestamp); @@ -1242,7 +1242,7 @@ Status AutoMixedPrecisionImpl::PrintDebugLogs(bool preop, size_t timestamp) { f.close(); LOG(INFO) << "Saved paint bucket info to " << fname; } - return OkStatus(); + return absl::OkStatus(); } void AutoMixedPrecisionImpl::LogSkippedNode(const NodeDef& node, @@ -1500,7 +1500,7 @@ Status AutoMixedPrecisionImpl::Optimize() { if (allow_set.empty()) { LOG(INFO) << "No allowlist ops found, nothing to do"; - return OkStatus(); + return absl::OkStatus(); } VLOG(2) << "Beginning pass 2 to propagate deny forwards from denylist ops " @@ -1552,7 +1552,7 @@ Status AutoMixedPrecisionImpl::Optimize() { TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ false, timestamp)); - return OkStatus(); + return absl::OkStatus(); } // If node is a Tensor List op with a float32 data type attribute then this @@ -1968,7 +1968,7 @@ Status AutoMixedPrecisionImpl::ForceColorMatchOnRecurrentEdges( } } } - return OkStatus(); + return absl::OkStatus(); } // Forces all of the given Tensor List nodes into the same color set. @@ -2255,7 +2255,7 @@ Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts( << " nodes to " << type_str << " precision using " << num_nonvar_casts_to_f16_ << " cast(s) to " << type_str << " (excluding Const and Variable casts)"; - return OkStatus(); + return absl::OkStatus(); } int GetNumGPUs(const Cluster& cluster) { @@ -2302,7 +2302,7 @@ Status AutoMixedPrecision::Optimize(Cluster* cluster, const GrapplerItem& item, // AutoMixedPrecision is currently only tuned for GPU. LOG(WARNING) << "No (suitable) GPUs detected, skipping " << name() << " graph optimizer"; - return OkStatus(); + return absl::OkStatus(); } if (num_gpus >= 1 && mode_ == AutoMixedPrecisionMode::BF16) { diff --git a/tensorflow/core/grappler/optimizers/auto_parallel.cc b/tensorflow/core/grappler/optimizers/auto_parallel.cc index f4b6da85ed2372..8753db8cb46fb1 100644 --- a/tensorflow/core/grappler/optimizers/auto_parallel.cc +++ b/tensorflow/core/grappler/optimizers/auto_parallel.cc @@ -198,7 +198,7 @@ Status AutoParallel::Initialize(const GrapplerItem& item) { } } LOG(INFO) << "Number of shared nodes: " << shared_nodes_.size(); - return OkStatus(); + return absl::OkStatus(); } bool AutoParallel::NotSharedNode(const string& name) { @@ -268,7 +268,7 @@ Status AutoParallel::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* output) { TF_RETURN_IF_ERROR(Initialize(item)); BuildGraph(output); - return OkStatus(); + return absl::OkStatus(); } } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/common_subgraph_elimination.cc b/tensorflow/core/grappler/optimizers/common_subgraph_elimination.cc index e755cabfbf84b5..587031b416a609 100644 --- a/tensorflow/core/grappler/optimizers/common_subgraph_elimination.cc +++ b/tensorflow/core/grappler/optimizers/common_subgraph_elimination.cc @@ -177,7 +177,7 @@ Status CommonSubgraphElimination::DedupComputations(GraphDef* optimized_graph) { GraphTopologyView graph_view; if (!graph_view.InitializeFromGraph(*optimized_graph).ok()) { LOG(WARNING) << "Failed to initialize GraphTopologyView."; - return OkStatus(); + return absl::OkStatus(); } // If either node or rep feeds an inplace op, deduping them may cause data @@ -270,7 +270,7 @@ Status CommonSubgraphElimination::DedupComputations(GraphDef* optimized_graph) { EraseNodesFromGraph(duplicates, optimized_graph); } - return OkStatus(); + return absl::OkStatus(); } Status CommonSubgraphElimination::Optimize(Cluster* /*cluster*/, diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index c018067c9c626a..ce0dab802427f7 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -336,7 +336,7 @@ static Status PutValueIntoTensor(const int64_t value, const DataType& type, } else { tensor->flat()(index) = value; } - return OkStatus(); + return absl::OkStatus(); } // Writes the given tensor shape into the given tensor. @@ -361,7 +361,7 @@ static Status ConvertShapeToConstant(const string& op, const DataType& type, *tensor = Tensor(type, TensorShape({})); TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dims(), type, 0, tensor)); } - return OkStatus(); + return absl::OkStatus(); } // TODO(rmlarsen): Perhaps we should move this to the GraphOptimizer base class. @@ -563,7 +563,7 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) { } } - return OkStatus(); + return absl::OkStatus(); } namespace { @@ -615,22 +615,22 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs( (shape_node1->op() != "Shape" && !IsReallyConstant(*shape_node1)) || shape_node2 == nullptr || (shape_node2->op() != "Shape" && !IsReallyConstant(*shape_node2))) { - return OkStatus(); + return absl::OkStatus(); } // Don't optimize this again if it was already optimized and folded. if (OptimizedNodeExists(node, "-folded-1") || OptimizedNodeExists(node, "-folded-2")) { - return OkStatus(); + return absl::OkStatus(); } int64_t min_id = 0; BCast::Vec shape1; if (!ExtractShape(*shape_node1, properties, &shape1, &min_id)) { - return OkStatus(); + return absl::OkStatus(); } BCast::Vec shape2; if (!ExtractShape(*shape_node2, properties, &shape2, &min_id)) { - return OkStatus(); + return absl::OkStatus(); } // A value of -1 means we don't known anything about the dimension. Replace // the -1 values with unique dimension ids since we don't want two '-1' @@ -659,7 +659,7 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs( // We're either dealing with 2 different symbolic dimensions or a symbolic // and a know dimensions. We can't be sure whether both are equal or not, // so we can't be sure whether we'll be broadcasting or not. - return OkStatus(); + return absl::OkStatus(); } } // These extra dims could be equal to 1, in which case there is no @@ -667,18 +667,18 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs( // be broadcasting. Since we don't know, we'll just punt. for (int i = common_dims, end = shape1.size(); i < end; ++i) { if (shape1[i] < 0) { - return OkStatus(); + return absl::OkStatus(); } } for (int i = common_dims, end = shape2.size(); i < end; ++i) { if (shape2[i] < 0) { - return OkStatus(); + return absl::OkStatus(); } } BCast bcast(shape1, shape2); if (!bcast.IsValid()) { - return OkStatus(); + return absl::OkStatus(); } BCast::Vec reduce_dims[2]; @@ -727,39 +727,39 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs( } } - return OkStatus(); + return absl::OkStatus(); } Status ConstantFolding::MaterializeReductionIndices( NodeDef* node, const GraphProperties& properties) { if (node->input_size() < 2) { - return OkStatus(); + return absl::OkStatus(); } const NodeDef* indices = node_map_->GetNode(node->input(1)); if (!indices || IsReallyConstant(*indices)) { // The reduction indices are already constant, there's nothing to do. - return OkStatus(); + return absl::OkStatus(); } const std::vector& input_props = properties.GetInputProperties(node->name()); if (input_props.size() != 2) { - return OkStatus(); + return absl::OkStatus(); } const OpInfo::TensorProperties& input_prop = input_props[0]; if (input_prop.shape().unknown_rank()) { // We can't do anything if we don't know the rank of the input. - return OkStatus(); + return absl::OkStatus(); } const int input_rank = input_prop.shape().dim_size(); if (input_rank < 1) { // Unexpected graph, don't try to change it. - return OkStatus(); + return absl::OkStatus(); } const OpInfo::TensorProperties& reduction_indices_prop = input_props[1]; DataType dtype = reduction_indices_prop.dtype(); if (dtype != DT_INT32 && dtype != DT_INT64) { - return OkStatus(); + return absl::OkStatus(); } PartialTensorShape reduction_indices_shape(reduction_indices_prop.shape()); const int num_reduction_indices = reduction_indices_shape.num_elements(); @@ -767,7 +767,7 @@ Status ConstantFolding::MaterializeReductionIndices( const std::vector& output_props = properties.GetOutputProperties(node->name()); if (output_props.size() != 1) { - return OkStatus(); + return absl::OkStatus(); } const OpInfo::TensorProperties& output_prop = output_props[0]; const int output_rank = @@ -782,23 +782,23 @@ Status ConstantFolding::MaterializeReductionIndices( for (const NodeDef* fanout : node_map_->GetOutputs(node->name())) { full_reduction = false; if (!IsReshape(*fanout)) { - return OkStatus(); + return absl::OkStatus(); } const std::vector& reshape_props = properties.GetOutputProperties(fanout->name()); if (reshape_props.size() != 1) { - return OkStatus(); + return absl::OkStatus(); } const OpInfo::TensorProperties& reshape_prop = reshape_props[0]; PartialTensorShape shape(reshape_prop.shape()); if (shape.num_elements() != 1) { - return OkStatus(); + return absl::OkStatus(); } else { full_reduction = true; } } if (!full_reduction) { - return OkStatus(); + return absl::OkStatus(); } } @@ -806,7 +806,7 @@ Status ConstantFolding::MaterializeReductionIndices( // reduce as a constant node. string const_name = OptimizedNodeName(*node, "-reduction_indices"); if (node_map_->GetNode(const_name)) { - return OkStatus(); + return absl::OkStatus(); } NodeDef* reduction_indices = graph_->add_node(); Tensor value(dtype, TensorShape({input_rank})); @@ -831,22 +831,22 @@ Status ConstantFolding::MaterializeReductionIndices( node_map_->UpdateInput(node->name(), indices->name(), reduction_indices->name()); - return OkStatus(); + return absl::OkStatus(); } Status ConstantFolding::MaterializeConstantValuedNode( NodeDef* node, const GraphProperties& properties) { if (disable_compressed_tensor_optimization_) { - return OkStatus(); + return absl::OkStatus(); } // Nodes that generate constant-valued outputs can be represented compactly in // compressed format, regardless of their shape. const std::vector& output_props = properties.GetOutputProperties(node->name()); - if (output_props.size() != 1) return OkStatus(); + if (output_props.size() != 1) return absl::OkStatus(); const auto& output_shape = output_props[0].shape(); if (!PartialTensorShape(output_shape).IsFullyDefined()) { - return OkStatus(); + return absl::OkStatus(); } if (IsFill(*node)) { const auto output_dtype = output_props[0].dtype(); @@ -854,7 +854,7 @@ Status ConstantFolding::MaterializeConstantValuedNode( for (int i = 0; i < 2; ++i) { input_node = node_map_->GetNode(NodeName(node->input(i))); if (input_node == nullptr || !IsReallyConstant(*input_node)) { - return OkStatus(); + return absl::OkStatus(); } } TF_RETURN_IF_ERROR(CheckAttrExists(*input_node, "value")); @@ -898,7 +898,7 @@ Status ConstantFolding::MaterializeConstantValuedNode( value, properties, output_shape, node, graph_)); } } - return OkStatus(); + return absl::OkStatus(); } // Materialize output values inferred by the shape inference. @@ -908,7 +908,7 @@ Status ConstantFolding::MaterializeOutputValues( properties.GetOutputProperties(node->name()); if (output.size() != 1 || !output[0].has_value() || !IsFoldable(*node, &properties)) { - return OkStatus(); + return absl::OkStatus(); } // If this is a trivial Identity node with a constant input, just route the @@ -919,7 +919,7 @@ Status ConstantFolding::MaterializeOutputValues( std::vector inputs_to_forward; std::iota(inputs_to_forward.begin(), inputs_to_forward.end(), 0); graph_modified_ = ForwardInputs(node, inputs_to_forward); - return OkStatus(); + return absl::OkStatus(); } } // Repurpose the existing node to be the constant. @@ -945,7 +945,7 @@ Status ConstantFolding::MaterializeConstants( TF_RETURN_IF_ERROR(MaterializeOutputValues(&node, properties)); } } - return OkStatus(); + return absl::OkStatus(); } bool ConstantFolding::IsFoldable(const NodeDef& node, @@ -1182,7 +1182,7 @@ Status CreateConstantTensorAttrValue(DataType type, double value, absl::StrCat("Unsupported type in CreateConstantTensorAttrValue: ", DataTypeString(type))); } - return OkStatus(); + return absl::OkStatus(); } #undef SET_TENSOR_CAL_CASE @@ -1331,7 +1331,7 @@ Status ConstantFolding::CreateNodeDef(const string& name, absl::StrCat("Can't fold ", name, ", its size would be too large (", encoded_size, " >= ", kMaxConstantSize, " bytes)")); } - return OkStatus(); + return absl::OkStatus(); } Status ConstantFolding::EvaluateNode(const NodeDef& node, @@ -1420,7 +1420,7 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node, outputs->at(i) = NodeDef(); } } - return OkStatus(); + return absl::OkStatus(); } Status ConstantFolding::FoldMergeNode(NodeDef* node, GraphDef* output_graph) { @@ -1510,9 +1510,9 @@ Status ConstantFolding::FoldMergeNode(NodeDef* node, GraphDef* output_graph) { } } } - return OkStatus(); + return absl::OkStatus(); } - return OkStatus(); + return absl::OkStatus(); } Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph, @@ -1638,7 +1638,7 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph, node->clear_input(); } } - return OkStatus(); + return absl::OkStatus(); } Status ConstantFolding::FoldGraph( @@ -1703,7 +1703,7 @@ Status ConstantFolding::FoldGraph( *(optimized_graph->add_node()) = std::move(*node); } } - return OkStatus(); + return absl::OkStatus(); } Status ConstantFolding::IsSimplifiableReshape( @@ -1785,7 +1785,7 @@ Status ConstantFolding::IsSimplifiableReshape( "to be compatible with ", new_dims.DebugString())); } - return OkStatus(); + return absl::OkStatus(); } #define IS_VALUE_CASE(DTYPE, VALUE) \ @@ -2057,7 +2057,7 @@ Status ConstantFolding::ReplaceOperationWithConstantTensor(DataType dtype, TensorProto* value, NodeDef* node, GraphDef* graph) { - if (dtype == DT_VARIANT) return OkStatus(); + if (dtype == DT_VARIANT) return absl::OkStatus(); node->set_op("Const"); EraseRegularNodeAttributes(node); (*node->mutable_attr())["dtype"].set_type(dtype); @@ -2074,14 +2074,14 @@ Status ConstantFolding::ReplaceOperationWithConstantTensor(DataType dtype, } DedupControlInputs(node); graph_modified_ = true; - return OkStatus(); + return absl::OkStatus(); } Status ConstantFolding::ReplaceOperationWithConstant( double value, const GraphProperties& properties, const TensorShapeProto& shape, NodeDef* node, GraphDef* graph) { const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties); - if (dtype == DT_VARIANT) return OkStatus(); + if (dtype == DT_VARIANT) return absl::OkStatus(); AttrValue tensor_attr; Status s = CreateConstantTensorAttrValue(dtype, value, shape, &tensor_attr); if (!s.ok()) { @@ -2089,7 +2089,7 @@ Status ConstantFolding::ReplaceOperationWithConstant( VLOG(1) << "Failed to replace node " << node->name() << " of type " << DataTypeString(dtype) << " with constant tensor of value " << value; - return OkStatus(); + return absl::OkStatus(); } return ReplaceOperationWithConstantTensor(dtype, tensor_attr.mutable_tensor(), node, graph); @@ -2112,7 +2112,7 @@ Status ConstantFolding::SimplifyGraph( TF_RETURN_IF_ERROR(SimplifyNode(node, optimized_graph, properties)); } } - return OkStatus(); + return absl::OkStatus(); } #define RETURN_IF_ERROR_OR_MODIFIED(EXPR) \ @@ -2180,7 +2180,7 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph, RemoveRedundantVariableUpdates(properties, optimized_graph, node)); graph_modified_ = graph_modified_cached; - return OkStatus(); + return absl::OkStatus(); } void ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties, @@ -2199,7 +2199,7 @@ Status ConstantFolding::RemoveShuffleOrTranspose( const GraphProperties& properties, bool use_shape_info, GraphDef* optimized_graph, NodeDef* node) { if (!use_shape_info || !(IsShuffle(*node) || IsTranspose(*node))) - return OkStatus(); + return absl::OkStatus(); Tensor permutation_tensor; if (GetTensorFromConstNode(node->input(1), &permutation_tensor) && properties.HasInputProperties(node->name())) { @@ -2215,7 +2215,7 @@ Status ConstantFolding::RemoveShuffleOrTranspose( int permutation_size = permutation.size(); if (permutation_size != shape.dim_size()) { // Number of elements in perm should be same as dim_size. Skip if not. - return OkStatus(); + return absl::OkStatus(); } // The node is replaceable iff // dim_size == 0 || all dims have size 1 || @@ -2228,7 +2228,7 @@ Status ConstantFolding::RemoveShuffleOrTranspose( ReplaceOperationWithIdentity(0, properties, node, optimized_graph); } } - return OkStatus(); + return absl::OkStatus(); } void ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties, @@ -2251,12 +2251,12 @@ Status ConstantFolding::RemoveReverse(const GraphProperties& properties, bool use_shape_info, GraphDef* optimized_graph, NodeDef* node) { - if (!use_shape_info || node->op() != "ReverseV2") return OkStatus(); + if (!use_shape_info || node->op() != "ReverseV2") return absl::OkStatus(); Tensor axis; if (properties.HasInputProperties(node->name()) && GetTensorFromConstNode(node->input(1), &axis)) { const auto& shape = properties.GetInputProperties(node->name())[0].shape(); - if (shape.unknown_rank()) return OkStatus(); + if (shape.unknown_rank()) return absl::OkStatus(); std::set target_axes; for (int j = 0; j < axis.NumElements(); ++j) { // value of axis can be negative. @@ -2282,14 +2282,14 @@ Status ConstantFolding::RemoveReverse(const GraphProperties& properties, ReplaceOperationWithIdentity(0, properties, node, optimized_graph); } } - return OkStatus(); + return absl::OkStatus(); } Status ConstantFolding::SimplifySlice(const GraphProperties& properties, bool use_shape_info, GraphDef* optimized_graph, NodeDef* node) { - if (!use_shape_info || !IsSlice(*node)) return OkStatus(); + if (!use_shape_info || !IsSlice(*node)) return absl::OkStatus(); Tensor begin; Tensor size; if (properties.HasInputProperties(node->name()) && @@ -2317,7 +2317,7 @@ Status ConstantFolding::SimplifySlice(const GraphProperties& properties, ReplaceOperationWithIdentity(0, properties, node, optimized_graph); } } - return OkStatus(); + return absl::OkStatus(); } Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties, @@ -2332,20 +2332,20 @@ Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties, node->attr().at("shrink_axis_mask").i() != 0) { // Skip nodes with new/shrink axis mask, since they involve dimension // changes. - return OkStatus(); + return absl::OkStatus(); } const auto& input = properties.GetInputProperties(node->name())[0]; for (int j = 0; j < input.shape().dim_size(); ++j) { // Skip if input shape is not fully determined. if (input.shape().dim(j).size() < 0) { - return OkStatus(); + return absl::OkStatus(); } } std::vector input_tensors(3); for (int i = 1; i < 4; ++i) { if (!GetTensorFromConstNode(node->input(i), &input_tensors[i - 1])) { - return OkStatus(); + return absl::OkStatus(); } } @@ -2409,7 +2409,7 @@ Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties, ReplaceOperationWithIdentity(0, properties, node, optimized_graph); } } - return OkStatus(); + return absl::OkStatus(); } Status ConstantFolding::SimplifyTile(const GraphProperties& properties, @@ -2434,13 +2434,13 @@ Status ConstantFolding::SimplifyTile(const GraphProperties& properties, ReplaceOperationWithIdentity(0, properties, node, optimized_graph); } } - return OkStatus(); + return absl::OkStatus(); } Status ConstantFolding::SimplifyPad(const GraphProperties& properties, bool use_shape_info, GraphDef* optimized_graph, NodeDef* node) { - if (!use_shape_info || !IsPad(*node)) return OkStatus(); + if (!use_shape_info || !IsPad(*node)) return absl::OkStatus(); Tensor paddings; if (GetTensorFromConstNode(node->input(1), &paddings)) { @@ -2461,7 +2461,7 @@ Status ConstantFolding::SimplifyPad(const GraphProperties& properties, ReplaceOperationWithIdentity(0, properties, node, optimized_graph); } } - return OkStatus(); + return absl::OkStatus(); } void ConstantFolding::SimplifySqueeze(const GraphProperties& properties, @@ -3026,13 +3026,13 @@ Status ConstantFolding::SimplifyArithmeticOperations( ReplaceBinaryOperationWithBroadcastTo(1, properties, node, optimized_graph); } - return OkStatus(); + return absl::OkStatus(); } if (y_matches_output_shape && (is_sub && x_is_zero)) { // Replace 0 - y with Neg(y). ReplaceSubtractionFromZeroByNegation(node, optimized_graph); - return OkStatus(); + return absl::OkStatus(); } // Replace 1 / y with Reciprocal op. @@ -3041,7 +3041,7 @@ Status ConstantFolding::SimplifyArithmeticOperations( DataType type = node->attr().at("T").type(); if (DataTypeIsFloating(type) || DataTypeIsComplex(type)) { ReplaceDivisionOfOnesByReciprocal(node, optimized_graph); - return OkStatus(); + return absl::OkStatus(); } } @@ -3056,7 +3056,7 @@ Status ConstantFolding::SimplifyArithmeticOperations( ReplaceBinaryOperationWithBroadcastTo(0, properties, node, optimized_graph); } - return OkStatus(); + return absl::OkStatus(); } // x OR true = true OR y = true. @@ -3064,7 +3064,7 @@ Status ConstantFolding::SimplifyArithmeticOperations( if (shp.IsFullyDefined() && IsLogicalOr(*node) && (y_is_one || x_is_one)) { TF_RETURN_IF_ERROR(ReplaceOperationWithConstant( 1, properties, output_shape, node, optimized_graph)); - return OkStatus(); + return absl::OkStatus(); } // Simplify multiplication and matmul by zeros. @@ -3082,7 +3082,7 @@ Status ConstantFolding::SimplifyArithmeticOperations( TF_RETURN_IF_ERROR( AddQuantizedMatMulMinMaxOutConstNodes(node, optimized_graph)); } - return OkStatus(); + return absl::OkStatus(); } // Even if an input shape is only partially known, we may known that it // matches the output shape and thus forward or broadcast the @@ -3094,7 +3094,7 @@ Status ConstantFolding::SimplifyArithmeticOperations( ReplaceBinaryOperationWithBroadcastTo(0, properties, node, optimized_graph); } - return OkStatus(); + return absl::OkStatus(); } else if (is_mul && y_is_zero) { if (y_matches_output_shape) { ReplaceOperationWithIdentity(1, properties, node, optimized_graph); @@ -3102,11 +3102,11 @@ Status ConstantFolding::SimplifyArithmeticOperations( ReplaceBinaryOperationWithBroadcastTo(1, properties, node, optimized_graph); } - return OkStatus(); + return absl::OkStatus(); } } } - return OkStatus(); + return absl::OkStatus(); } bool ConstantFolding::ReduceDivToReciprocalMul(GraphDef* optimized_graph, @@ -3980,7 +3980,7 @@ Status ConstantFolding::AddQuantizedMatMulMinMaxOutConstNodes( } } - return OkStatus(); + return absl::OkStatus(); }; const string min_out_const_name = OptimizedNodeName(*node, "-quantized_matmul_min_out"); @@ -3996,7 +3996,7 @@ Status ConstantFolding::AddQuantizedMatMulMinMaxOutConstNodes( "node '$0' because of node name conflict", node->name())); } - return OkStatus(); + return absl::OkStatus(); } Status ConstantFolding::RunOptimizationPass(Cluster* cluster, @@ -4035,7 +4035,7 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster, TF_RETURN_IF_ERROR( SimplifyGraph(optimized_graph, properties, &nodes_to_not_simplify)); - return OkStatus(); + return absl::OkStatus(); } Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item, @@ -4093,7 +4093,7 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item, *optimized_graph->mutable_library() = item.graph.library(); *optimized_graph->mutable_versions() = item.graph.versions(); - return OkStatus(); + return absl::OkStatus(); } } // namespace grappler diff --git a/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry_test.cc b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry_test.cc index ce7f574c8f4d99..4719e066a78276 100644 --- a/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry_test.cc +++ b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry_test.cc @@ -35,13 +35,13 @@ class TestGraphOptimizer : public CustomGraphOptimizer { public: Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { - return OkStatus(); + return absl::OkStatus(); } string name() const override { return kTestOptimizerName; } bool UsesFunctionLibrary() const override { return false; } Status Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) override { - return OkStatus(); + return absl::OkStatus(); } }; @@ -89,13 +89,13 @@ class TestPluginGraphOptimizer : public CustomGraphOptimizer { public: Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { - return OkStatus(); + return absl::OkStatus(); } string name() const override { return kTestPluginOptimizerName; } bool UsesFunctionLibrary() const override { return false; } Status Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) override { - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index 0418e287ac5f02..bfed9693e0dcbe 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -38,6 +38,7 @@ cc_library( ":parallel_batch", ":remove_compression_map", ":replicate_on_split", + ":seq_interleave_prefetch", ":shuffle_and_repeat_fusion", ":slack", ":use_private_thread_pool", @@ -969,6 +970,62 @@ tf_cc_test( ], ) +cc_library( + name = "seq_interleave_prefetch", + srcs = ["seq_interleave_prefetch.cc"], + hdrs = ["seq_interleave_prefetch.h"], + deps = [ + ":function_utils", + ":fusion_utils", + ":graph_utils", + ":optimizer_base", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core/data:dataset_utils", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:mutable_graph_view", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/clusters:cluster", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + "//tensorflow/core/grappler/utils:topological_sort", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", + ] + tf_protos_all(), + alwayslink = 1, +) + +tf_cc_test( + name = "seq_interleave_prefetch_test", + size = "small", + srcs = ["seq_interleave_prefetch_test.cc"], + deps = [ + ":function_utils", + ":graph_test_utils", + ":graph_utils", + ":inject_io_prefetch", + ":seq_interleave_prefetch", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:test_main", + "//tensorflow/core/data:dataset_utils", + "//tensorflow/core/framework:function_proto_cc", + "//tensorflow/core/framework:function_testlib", + "//tensorflow/core/framework:graph_proto_cc", + "//tensorflow/core/framework:node_def_proto_cc", + "//tensorflow/core/framework:types_proto_cc", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/platform:status", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/lib/core:status_test_util", + ], +) + cc_library( name = "inject_io_prefetch", srcs = ["inject_io_prefetch.cc"], diff --git a/tensorflow/core/grappler/optimizers/data/auto_shard.cc b/tensorflow/core/grappler/optimizers/data/auto_shard.cc index ecb6e140c2056c..1ba30b81062cf8 100644 --- a/tensorflow/core/grappler/optimizers/data/auto_shard.cc +++ b/tensorflow/core/grappler/optimizers/data/auto_shard.cc @@ -263,7 +263,7 @@ Status AddShardNode(MutableGraphView* graph, const NodeDef& add_before, TF_RETURN_IF_ERROR( graph->UpdateFanouts(add_after->name(), new_node_graph->name())); - return OkStatus(); + return absl::OkStatus(); } Status AddShuffleDataset(MutableGraphView* graph, const NodeDef& add_before, @@ -292,7 +292,7 @@ Status AddShuffleDataset(MutableGraphView* graph, const NodeDef& add_before, TF_RETURN_IF_ERROR( graph->UpdateFanouts(add_after->name(), new_node_graph->name())); - return OkStatus(); + return absl::OkStatus(); } Status AddShuffleDatasetV2(MutableGraphView* graph, const NodeDef& add_before, @@ -315,7 +315,7 @@ Status AddShuffleDatasetV2(MutableGraphView* graph, const NodeDef& add_before, TF_RETURN_IF_ERROR( graph->UpdateFanouts(add_after->name(), new_node_graph->name())); - return OkStatus(); + return absl::OkStatus(); } Status AddShuffleDatasetV3(MutableGraphView* graph, const NodeDef& add_before, @@ -346,7 +346,7 @@ Status AddShuffleDatasetV3(MutableGraphView* graph, const NodeDef& add_before, TF_RETURN_IF_ERROR( graph->UpdateFanouts(add_after->name(), new_node_graph->name())); - return OkStatus(); + return absl::OkStatus(); } bool ReaderOpInFunction(const NodeDef& node, @@ -390,7 +390,7 @@ Status RemoveShuffleDataset(MutableGraphView* graph, const NodeDef& node, } // TODO(frankchn): Traverse functions too. - return OkStatus(); + return absl::OkStatus(); } Status RemoveShuffleDatasetV2(MutableGraphView* graph, const NodeDef& node, @@ -412,7 +412,7 @@ Status RemoveShuffleDatasetV2(MutableGraphView* graph, const NodeDef& node, } // TODO(frankchn): Traverse functions too. - return OkStatus(); + return absl::OkStatus(); } Status RemoveShuffleDatasetV3(MutableGraphView* graph, const NodeDef& node, @@ -439,7 +439,7 @@ Status RemoveShuffleDatasetV3(MutableGraphView* graph, const NodeDef& node, } // TODO(frankchn): Traverse functions too. - return OkStatus(); + return absl::OkStatus(); } Status ProcessDatasetSourceNode(MutableGraphView* graph, const NodeDef& node, @@ -481,7 +481,7 @@ Status ProcessDatasetSourceNode(MutableGraphView* graph, const NodeDef& node, seed_generator_node, reshuffle_each_iteration)); } - return OkStatus(); + return absl::OkStatus(); } const NodeDef* FindFuncAndTensorSliceDataset( @@ -570,7 +570,7 @@ Status RecursivelyHandleOp(const NodeDef& node, int64_t num_workers, TF_RETURN_IF_ERROR(RecursivelyHandleOp(*input_node, num_workers, index, flib, graph, nodes_to_delete)); } - return OkStatus(); + return absl::OkStatus(); } // This handles the case for the following subgraph: @@ -673,7 +673,7 @@ Status RewriteRebatchV2ToV1(const NodeDef& sink_node, int64_t num_replicas, // sink_node to get the RebatchDataset. NodeDef* input_node = graph_utils::GetInputNode(sink_node, *graph); if (input_node->op() != kRebatchDatasetV2OpName) { - return OkStatus(); + return absl::OkStatus(); } NodeDef* rebatch_node = input_node; @@ -710,7 +710,7 @@ Status RewriteRebatchV2ToV1(const NodeDef& sink_node, int64_t num_replicas, shape->mutable_dim(0)->set_size(-1); } - return OkStatus(); + return absl::OkStatus(); } Status ShardByData(const NodeDef& sink_node, int64_t num_workers, int64_t index, @@ -762,7 +762,7 @@ Status ShardByHint(const NodeDef& sink_node, int64_t num_workers, int64_t index, (*(mutable_node->mutable_attr()))[data::ShardDatasetOp::kRequireNonEmpty] .set_b(true); } - return OkStatus(); + return absl::OkStatus(); } Status ApplyAutoShard(const NodeDef& sink_node, int64_t num_workers, @@ -774,7 +774,7 @@ Status ApplyAutoShard(const NodeDef& sink_node, int64_t num_workers, graph->graph()->library()); switch (policy) { case AutoShardPolicy::OFF: - return OkStatus(); + return absl::OkStatus(); case AutoShardPolicy::FILE: return ShardByFile(sink_node, num_workers, index, &flib, graph); case AutoShardPolicy::DATA: @@ -830,7 +830,7 @@ Status OptimizeGraph(const GrapplerItem& item, int64_t num_workers, metrics::RecordTFDataAutoShard(id, policy_applied, num_workers, num_replicas); } - return OkStatus(); + return absl::OkStatus(); } } // anonymous namespace @@ -949,7 +949,7 @@ Status AutoShard::Init( return errors::InvalidArgument(kNumReplicasAttrName, " should be >= 0"); } - return OkStatus(); + return absl::OkStatus(); } Status AutoShard::OptimizeAndCollectStats(Cluster* cluster, @@ -960,7 +960,7 @@ Status AutoShard::OptimizeAndCollectStats(Cluster* cluster, TF_RETURN_IF_ERROR(OptimizeGraph(item, num_workers_, index_, auto_shard_policy_, num_replicas_, output)); stats->num_changes++; - return OkStatus(); + return absl::OkStatus(); } REGISTER_GRAPH_OPTIMIZER_AS(AutoShard, "tf_auto_shard"); diff --git a/tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes.cc b/tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes.cc index cc9b45086986cf..adcc2ec3549dff 100644 --- a/tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes.cc +++ b/tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes.cc @@ -56,7 +56,7 @@ Status AutotuneBufferSizes::OptimizeAndCollectStats(Cluster* cluster, if (!autotune_) { VLOG(1) << "The optimization autotune_buffer_sizes is not applied if " "autotune is off."; - return OkStatus(); + return absl::OkStatus(); } MutableGraphView graph(output); @@ -110,7 +110,7 @@ Status AutotuneBufferSizes::OptimizeAndCollectStats(Cluster* cluster, } } - if (async_datasets.empty()) return OkStatus(); + if (async_datasets.empty()) return absl::OkStatus(); for (const NodeDef* async_dataset_node : async_datasets) { NodeDef prefetch_node; @@ -130,7 +130,7 @@ Status AutotuneBufferSizes::OptimizeAndCollectStats(Cluster* cluster, graph.UpdateFanouts(async_dataset_node->name(), added_node->name())); } - return OkStatus(); + return absl::OkStatus(); } REGISTER_GRAPH_OPTIMIZER_AS(AutotuneBufferSizes, "autotune_buffer_sizes"); diff --git a/tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes.h b/tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes.h index 5594e77dc85f69..6ef62b74dd9f10 100644 --- a/tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes.h +++ b/tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes.h @@ -45,7 +45,7 @@ class AutotuneBufferSizes : public TFDataOptimizerBase { Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { - if (!config) return OkStatus(); + if (!config) return absl::OkStatus(); const string& autotune = config->parameter_map().at(kAutotune).s(); if (autotune == "true") { @@ -57,7 +57,7 @@ class AutotuneBufferSizes : public TFDataOptimizerBase { absl::StrCat("Received an invalid value for parameter ", kAutotune, ": ", autotune)); } - return OkStatus(); + return absl::OkStatus(); } Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, diff --git a/tensorflow/core/grappler/optimizers/data/batch_parallelization.cc b/tensorflow/core/grappler/optimizers/data/batch_parallelization.cc index 8a8af73b0cd987..cdb82ba93cf82f 100644 --- a/tensorflow/core/grappler/optimizers/data/batch_parallelization.cc +++ b/tensorflow/core/grappler/optimizers/data/batch_parallelization.cc @@ -63,14 +63,15 @@ Status BatchParallelization::OptimizeAndCollectStats(Cluster* cluster, if (!autotune_) { VLOG(1) << "The optimization batch_parallelization is not applied if " "autotune is off."; - return OkStatus(); + return absl::OkStatus(); } MutableGraphView graph(output); // If the GrapplerItem is derived from a FunctionDef, we don't optimize it, // because we only want to enable extra batch parallelism on the main dataset // pipeline. - if (graph_utils::IsItemDerivedFromFunctionDef(item, graph)) return OkStatus(); + if (graph_utils::IsItemDerivedFromFunctionDef(item, graph)) + return absl::OkStatus(); absl::flat_hash_set nodes_to_delete; FunctionLibraryDefinition function_library(OpRegistry::Global(), @@ -93,7 +94,7 @@ Status BatchParallelization::OptimizeAndCollectStats(Cluster* cluster, } TF_RETURN_IF_ERROR(graph.DeleteNodes(nodes_to_delete)); - return OkStatus(); + return absl::OkStatus(); } REGISTER_GRAPH_OPTIMIZER_AS(BatchParallelization, "batch_parallelization"); diff --git a/tensorflow/core/grappler/optimizers/data/batch_parallelization.h b/tensorflow/core/grappler/optimizers/data/batch_parallelization.h index 430b57a4cbd0bd..ac872763fff0ed 100644 --- a/tensorflow/core/grappler/optimizers/data/batch_parallelization.h +++ b/tensorflow/core/grappler/optimizers/data/batch_parallelization.h @@ -36,7 +36,7 @@ class BatchParallelization : public TFDataOptimizerBase { Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { - if (!config) return OkStatus(); + if (!config) return absl::OkStatus(); const string& autotune = config->parameter_map().at(kAutotune).s(); if (autotune == "true") { @@ -47,7 +47,7 @@ class BatchParallelization : public TFDataOptimizerBase { return errors::InvalidArgument("Received an invalid value for parameter ", kAutotune, ": ", autotune); } - return OkStatus(); + return absl::OkStatus(); } Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, diff --git a/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism.cc b/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism.cc index 67c9918d4d6556..1812504dcae3e9 100644 --- a/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism.cc +++ b/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism.cc @@ -48,7 +48,8 @@ Status DisableIntraOpParallelism::OptimizeAndCollectStats( // If the GrapplerItem is derived from a FunctionDef, we don't optimize it, // because we only want to disable intra op parallelism on the main dataset // pipeline. - if (graph_utils::IsItemDerivedFromFunctionDef(item, graph)) return OkStatus(); + if (graph_utils::IsItemDerivedFromFunctionDef(item, graph)) + return absl::OkStatus(); if (item.fetch.size() != 1) { return errors::InvalidArgument( @@ -61,7 +62,7 @@ Status DisableIntraOpParallelism::OptimizeAndCollectStats( if (node.op() == target_dataset_op) { // If parallelism is set by the user, we keep the user setting instead // of disabling it. - return OkStatus(); + return absl::OkStatus(); } } } @@ -97,14 +98,14 @@ Status DisableIntraOpParallelism::OptimizeAndCollectStats( // attrs from the input node. If we fail to set the attributes, we abort the // rewrite. if (!graph_utils::CopyShapesAndTypesAttrs(*last_node, &insert_node)) - return OkStatus(); + return absl::OkStatus(); auto* added_node = graph.AddNode(std::move(insert_node)); TF_RETURN_IF_ERROR( graph.UpdateFanouts(last_node->name(), added_node->name())); stats->num_changes++; - return OkStatus(); + return absl::OkStatus(); } REGISTER_GRAPH_OPTIMIZER_AS(DisableIntraOpParallelism, diff --git a/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism.h b/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism.h index c8f5df89c32263..5ea8c79ec02855 100644 --- a/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism.h +++ b/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism.h @@ -33,7 +33,7 @@ class DisableIntraOpParallelism : public TFDataOptimizerBase { Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { - return OkStatus(); + return absl::OkStatus(); } Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, diff --git a/tensorflow/core/grappler/optimizers/data/disable_prefetch_legacy_autotune.cc b/tensorflow/core/grappler/optimizers/data/disable_prefetch_legacy_autotune.cc index 4c2e0ffe0312e1..0218d590ec303e 100644 --- a/tensorflow/core/grappler/optimizers/data/disable_prefetch_legacy_autotune.cc +++ b/tensorflow/core/grappler/optimizers/data/disable_prefetch_legacy_autotune.cc @@ -41,7 +41,7 @@ Status DisablePrefetchLegacyAutotune::OptimizeAndCollectStats( if (!autotune_) { VLOG(1) << "The optimization disable_prefetch_legacy_autotune is not " "applied if autotune is off."; - return OkStatus(); + return absl::OkStatus(); } MutableGraphView graph(output); @@ -58,7 +58,7 @@ Status DisablePrefetchLegacyAutotune::OptimizeAndCollectStats( } } - return OkStatus(); + return absl::OkStatus(); } REGISTER_GRAPH_OPTIMIZER_AS(DisablePrefetchLegacyAutotune, diff --git a/tensorflow/core/grappler/optimizers/data/disable_prefetch_legacy_autotune.h b/tensorflow/core/grappler/optimizers/data/disable_prefetch_legacy_autotune.h index e3de9bf8ca15cc..a4f480f4f9c6bf 100644 --- a/tensorflow/core/grappler/optimizers/data/disable_prefetch_legacy_autotune.h +++ b/tensorflow/core/grappler/optimizers/data/disable_prefetch_legacy_autotune.h @@ -36,7 +36,7 @@ class DisablePrefetchLegacyAutotune : public TFDataOptimizerBase { Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { - if (!config) return OkStatus(); + if (!config) return absl::OkStatus(); const string& autotune = config->parameter_map().at(kAutotune).s(); if (autotune == "true") { @@ -47,7 +47,7 @@ class DisablePrefetchLegacyAutotune : public TFDataOptimizerBase { return errors::InvalidArgument("Received an invalid value for parameter ", kAutotune, ": ", autotune); } - return OkStatus(); + return absl::OkStatus(); } Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, diff --git a/tensorflow/core/grappler/optimizers/data/enable_gradient_descent.cc b/tensorflow/core/grappler/optimizers/data/enable_gradient_descent.cc index 88d587de0dddee..e13ba28148cde9 100644 --- a/tensorflow/core/grappler/optimizers/data/enable_gradient_descent.cc +++ b/tensorflow/core/grappler/optimizers/data/enable_gradient_descent.cc @@ -44,14 +44,15 @@ Status EnableGradientDescent::OptimizeAndCollectStats( if (!autotune_) { VLOG(1) << "The optimization enable_gradient_descent is not applied if " "autotune is off."; - return OkStatus(); + return absl::OkStatus(); } MutableGraphView graph(output); // If the GrapplerItem is derived from a FunctionDef, we don't optimize it, // because we only want to enable gradient descent on the main dataset // pipeline. - if (graph_utils::IsItemDerivedFromFunctionDef(item, graph)) return OkStatus(); + if (graph_utils::IsItemDerivedFromFunctionDef(item, graph)) + return absl::OkStatus(); int index = graph_utils::FindGraphNodeWithOp(kModelDataset, *output); NodeDef& model_node = *(output->mutable_node(index)); @@ -61,7 +62,7 @@ Status EnableGradientDescent::OptimizeAndCollectStats( stats->num_changes++; } - return OkStatus(); + return absl::OkStatus(); } REGISTER_GRAPH_OPTIMIZER_AS(EnableGradientDescent, "enable_gradient_descent"); diff --git a/tensorflow/core/grappler/optimizers/data/enable_gradient_descent.h b/tensorflow/core/grappler/optimizers/data/enable_gradient_descent.h index 21f742a957de88..faf97e7aafa44c 100644 --- a/tensorflow/core/grappler/optimizers/data/enable_gradient_descent.h +++ b/tensorflow/core/grappler/optimizers/data/enable_gradient_descent.h @@ -36,7 +36,7 @@ class EnableGradientDescent : public TFDataOptimizerBase { Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { - if (!config) return OkStatus(); + if (!config) return absl::OkStatus(); const string& autotune = config->parameter_map().at(kAutotune).s(); if (autotune == "true") { @@ -47,7 +47,7 @@ class EnableGradientDescent : public TFDataOptimizerBase { return errors::InvalidArgument("Received an invalid value for parameter ", kAutotune, ": ", autotune); } - return OkStatus(); + return absl::OkStatus(); } Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, diff --git a/tensorflow/core/grappler/optimizers/data/filter_fusion.cc b/tensorflow/core/grappler/optimizers/data/filter_fusion.cc index 7e214b1ea459d0..a3e929fda7adbb 100644 --- a/tensorflow/core/grappler/optimizers/data/filter_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/filter_fusion.cc @@ -125,7 +125,7 @@ Status FilterFusion::OptimizeAndCollectStats(Cluster* cluster, } TF_RETURN_IF_ERROR(graph.DeleteNodes(nodes_to_delete)); - return OkStatus(); + return absl::OkStatus(); } REGISTER_GRAPH_OPTIMIZER_AS(FilterFusion, "filter_fusion"); diff --git a/tensorflow/core/grappler/optimizers/data/filter_fusion.h b/tensorflow/core/grappler/optimizers/data/filter_fusion.h index 30058b4975aa92..f7d8849d7d0dab 100644 --- a/tensorflow/core/grappler/optimizers/data/filter_fusion.h +++ b/tensorflow/core/grappler/optimizers/data/filter_fusion.h @@ -33,7 +33,7 @@ class FilterFusion : public TFDataOptimizerBase { Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { - return OkStatus(); + return absl::OkStatus(); } Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, diff --git a/tensorflow/core/grappler/optimizers/data/filter_parallelization.cc b/tensorflow/core/grappler/optimizers/data/filter_parallelization.cc index ee7b9a07f470b7..36b476acf7a9fa 100644 --- a/tensorflow/core/grappler/optimizers/data/filter_parallelization.cc +++ b/tensorflow/core/grappler/optimizers/data/filter_parallelization.cc @@ -61,14 +61,15 @@ Status FilterParallelization::OptimizeAndCollectStats( if (!autotune_) { VLOG(1) << "The optimization filter_parallelization is not applied if " "autotune is off."; - return OkStatus(); + return absl::OkStatus(); } MutableGraphView graph(output); // If the GrapplerItem is derived from a FunctionDef, we don't optimize it, // because we only want to enable extra filter parallelism on the main dataset // pipeline. - if (graph_utils::IsItemDerivedFromFunctionDef(item, graph)) return OkStatus(); + if (graph_utils::IsItemDerivedFromFunctionDef(item, graph)) + return absl::OkStatus(); absl::flat_hash_set nodes_to_delete; FunctionLibraryDefinition function_library(OpRegistry::Global(), @@ -96,7 +97,7 @@ Status FilterParallelization::OptimizeAndCollectStats( } TF_RETURN_IF_ERROR(graph.DeleteNodes(nodes_to_delete)); - return OkStatus(); + return absl::OkStatus(); } REGISTER_GRAPH_OPTIMIZER_AS(FilterParallelization, "filter_parallelization"); diff --git a/tensorflow/core/grappler/optimizers/data/filter_parallelization.h b/tensorflow/core/grappler/optimizers/data/filter_parallelization.h index fd9f7827b52e99..ab501a02242cce 100644 --- a/tensorflow/core/grappler/optimizers/data/filter_parallelization.h +++ b/tensorflow/core/grappler/optimizers/data/filter_parallelization.h @@ -36,7 +36,7 @@ class FilterParallelization : public TFDataOptimizerBase { Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { - if (!config) return OkStatus(); + if (!config) return absl::OkStatus(); const string& autotune = config->parameter_map().at(kAutotune).s(); if (autotune == "true") { @@ -47,7 +47,7 @@ class FilterParallelization : public TFDataOptimizerBase { return errors::InvalidArgument("Received an invalid value for parameter ", kAutotune, ": ", autotune); } - return OkStatus(); + return absl::OkStatus(); } Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc index 07955f04f89e46..20b5940f98102c 100644 --- a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/strip.h" #include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/metrics.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/op_def.pb.h" @@ -391,6 +392,9 @@ void ComposeSignature(const OpDef& first_signature, *fused_signature->mutable_output_arg() = second_signature.output_arg(); if (first_signature.is_stateful() || second_signature.is_stateful()) { + if (!(first_signature.is_stateful() && second_signature.is_stateful())) { + metrics::RecordTFDataDebug("fused_with_mixed_statefulness"); + } fused_signature->set_is_stateful(true); } diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc index 578575b23e71e1..d01d5c8f4621c4 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc @@ -142,6 +142,24 @@ NodeDef MakeParallelInterleaveV4Node(StringPiece name, }); } +NodeDef MakeInterleaveNode(StringPiece name, StringPiece input_node_name, + StringPiece cycle_length_node_name, + StringPiece block_length_node_name, + StringPiece function_name, + StringPiece deterministic) { + return test::function::NDef( + name, "InterleaveDataset", + {string(input_node_name), string(cycle_length_node_name), + string(block_length_node_name)}, + { + {"f", FunctionDefHelper::FunctionRef(string(function_name))}, + {"Targuments", {}}, + {"output_shapes", gtl::ArraySlice{}}, + {"output_types", gtl::ArraySlice{}}, + {"deterministic", string(deterministic)}, + }); +} + NodeDef MakeParallelMapNode(StringPiece name, StringPiece input_node_name, StringPiece num_parallel_calls_node_name, StringPiece function_name, bool sloppy) { diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.h b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h index 5865f03a81a747..7341329ac36030 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_test_utils.h +++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h @@ -73,6 +73,13 @@ NodeDef MakeParallelInterleaveV4Node(StringPiece name, StringPiece function_name, StringPiece deterministic); +// Creates a test NodeDef for InterleaveDataset. +NodeDef MakeInterleaveNode(StringPiece name, StringPiece input_node_name, + StringPiece cycle_length_node_name, + StringPiece block_length_node_name, + StringPiece function_name, + StringPiece deterministic); + // Creates a test NodeDef for ParallelMapDataset. NodeDef MakeParallelMapNode(StringPiece name, StringPiece input_node_name, StringPiece num_parallel_calls_node_name, diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc index 41558b47747375..4494f576bc5078 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc @@ -189,7 +189,7 @@ Status GetScalarConstNodeValueHelper( get_value(tensor); - return OkStatus(); + return absl::OkStatus(); } template <> @@ -360,7 +360,7 @@ Status EnsureNodeNamesUnique(Graph* g) { } } - return OkStatus(); + return absl::OkStatus(); } Status GetFetchNode(const MutableGraphView& graph, const GrapplerItem& item, @@ -373,7 +373,7 @@ Status GetFetchNode(const MutableGraphView& graph, const GrapplerItem& item, *fetch_node = graph.GetNode(item.fetch.at(0)); - return OkStatus(); + return absl::OkStatus(); } bool IsItemDerivedFromFunctionDef(const GrapplerItem& item, @@ -466,7 +466,7 @@ Status SetMetadataName(const std::string& name, NodeDef* node) { } *metadata.mutable_name() = name; metadata.SerializeToString((*node->mutable_attr())["metadata"].mutable_s()); - return OkStatus(); + return absl::OkStatus(); } } // namespace graph_utils diff --git a/tensorflow/core/grappler/optimizers/data/inject_prefetch.cc b/tensorflow/core/grappler/optimizers/data/inject_prefetch.cc index 5051f23b1ec4f5..76eea7b955d063 100644 --- a/tensorflow/core/grappler/optimizers/data/inject_prefetch.cc +++ b/tensorflow/core/grappler/optimizers/data/inject_prefetch.cc @@ -86,13 +86,13 @@ Status InjectPrefetch::OptimizeAndCollectStats(Cluster* cluster, if (!autotune_) { VLOG(1) << "The optimization inject_prefetch is not applied if autotune is " "off."; - return OkStatus(); + return absl::OkStatus(); } MutableGraphView graph(output); // If the GrapplerItem is derived from a FunctionDef, we don't optimize it. if (graph_utils::IsItemDerivedFromFunctionDef(item, graph)) { - return OkStatus(); + return absl::OkStatus(); } if (item.fetch.size() != 1) { @@ -104,7 +104,7 @@ Status InjectPrefetch::OptimizeAndCollectStats(Cluster* cluster, NodeDef* sink_node = graph.GetNode(item.fetch.at(0)); NodeDef* last_node = graph_utils::GetInputNode(*sink_node, graph); if (!ShouldInjectPrefetch(last_node, graph)) { - return OkStatus(); + return absl::OkStatus(); } // Insert `prefetch(AUTOTUNE)` after the last node. @@ -124,7 +124,7 @@ Status InjectPrefetch::OptimizeAndCollectStats(Cluster* cluster, // attrs from the input node. If we fail to set the attributes, we abort the // rewrite. if (!graph_utils::CopyShapesAndTypesAttrs(*last_node, &prefetch_node)) - return OkStatus(); + return absl::OkStatus(); TF_RETURN_IF_ERROR( graph_utils::SetMetadataName(prefetch_node.name(), &prefetch_node)); @@ -134,7 +134,7 @@ Status InjectPrefetch::OptimizeAndCollectStats(Cluster* cluster, graph.UpdateFanouts(last_node->name(), added_node->name())); stats->num_changes++; - return OkStatus(); + return absl::OkStatus(); } REGISTER_GRAPH_OPTIMIZER_AS(InjectPrefetch, "inject_prefetch"); diff --git a/tensorflow/core/grappler/optimizers/data/inject_prefetch.h b/tensorflow/core/grappler/optimizers/data/inject_prefetch.h index 5ce6b3dadec085..433841fdbef251 100644 --- a/tensorflow/core/grappler/optimizers/data/inject_prefetch.h +++ b/tensorflow/core/grappler/optimizers/data/inject_prefetch.h @@ -37,7 +37,7 @@ class InjectPrefetch : public TFDataOptimizerBase { Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { - if (!config) return OkStatus(); + if (!config) return absl::OkStatus(); const std::string& autotune = config->parameter_map().at(kAutotune).s(); if (autotune == "true") { @@ -48,7 +48,7 @@ class InjectPrefetch : public TFDataOptimizerBase { return errors::InvalidArgument("Received an invalid value for parameter ", kAutotune, ": ", autotune); } - return OkStatus(); + return absl::OkStatus(); } Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, diff --git a/tensorflow/core/grappler/optimizers/data/make_deterministic.cc b/tensorflow/core/grappler/optimizers/data/make_deterministic.cc index 7e402d05e7a953..3defd73e0b48a6 100644 --- a/tensorflow/core/grappler/optimizers/data/make_deterministic.cc +++ b/tensorflow/core/grappler/optimizers/data/make_deterministic.cc @@ -210,7 +210,7 @@ Status ConvertMapOrInterleave(const string& node_name, // Remove extra attributes not in Interleave or Map. node->mutable_attr()->erase("deterministic"); node->mutable_attr()->erase("sloppy"); - return OkStatus(); + return absl::OkStatus(); } // Returns all transitive dependencies of a set of nodes, including the nodes @@ -379,7 +379,7 @@ Status SplitMap( split_results.first_function; *graph->graph()->mutable_library()->mutable_function()->Add() = split_results.second_function; - return OkStatus(); + return absl::OkStatus(); } // Converts a ParallalBatch dataset to a Batch dataset, to make it @@ -391,7 +391,7 @@ Status ConvertBatch(const string& node_name, MutableGraphView* graph) { node->set_input(2, node->input(3)); node->set_input(3, absl::StrCat("^", num_parallel_calls_input)); node->mutable_attr()->erase("deterministic"); - return OkStatus(); + return absl::OkStatus(); } // Convert a MapAndBatch node to a separate Map node and Batch node, to make it @@ -473,7 +473,7 @@ Status ConvertMapAndBatch(const string& node_name, MutableGraphView* graph) { NodeDef* graph_batch_node = graph->AddNode(std::move(new_batch_node)); TF_RETURN_IF_ERROR( graph->UpdateFanouts(orig_node.name(), graph_batch_node->name())); - return OkStatus(); + return absl::OkStatus(); } // Change the buffer_size of a Prefetch node to zero, effectively disabling it, @@ -484,7 +484,7 @@ Status ConvertPrefetch(const string& node_name, MutableGraphView* graph) { node->add_input(absl::StrCat("^", node->input(buffer_size_index))); NodeDef* tmp = graph_utils::AddScalarConstNode(0, graph); node->set_input(buffer_size_index, tmp->name()); - return OkStatus(); + return absl::OkStatus(); } // The two ways nondeterminism can occur in an input pipeline when there are @@ -746,7 +746,7 @@ Status MakeDeterministic::OptimizeAndCollectStats(Cluster* cluster, } TF_RETURN_IF_ERROR(graph.DeleteNodes(nodes_to_delete)); - return OkStatus(); + return absl::OkStatus(); } REGISTER_GRAPH_OPTIMIZER_AS(MakeDeterministic, "make_deterministic"); diff --git a/tensorflow/core/grappler/optimizers/data/make_deterministic.h b/tensorflow/core/grappler/optimizers/data/make_deterministic.h index 1df6e10def699e..3f331c398d282b 100644 --- a/tensorflow/core/grappler/optimizers/data/make_deterministic.h +++ b/tensorflow/core/grappler/optimizers/data/make_deterministic.h @@ -62,7 +62,7 @@ class MakeDeterministic : public TFDataOptimizerBase { Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { - return OkStatus(); + return absl::OkStatus(); } Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, diff --git a/tensorflow/core/grappler/optimizers/data/make_sloppy.cc b/tensorflow/core/grappler/optimizers/data/make_sloppy.cc index c76a4b6d80ffcf..6cc34fc7dd7f43 100644 --- a/tensorflow/core/grappler/optimizers/data/make_sloppy.cc +++ b/tensorflow/core/grappler/optimizers/data/make_sloppy.cc @@ -43,7 +43,7 @@ Status MakeSloppy::OptimizeAndCollectStats(Cluster* cluster, stats->num_changes++; } } - return OkStatus(); + return absl::OkStatus(); } REGISTER_GRAPH_OPTIMIZER_AS(MakeSloppy, "make_sloppy"); diff --git a/tensorflow/core/grappler/optimizers/data/make_sloppy.h b/tensorflow/core/grappler/optimizers/data/make_sloppy.h index 72caf5138eb7bb..bedf80dd9bccc7 100644 --- a/tensorflow/core/grappler/optimizers/data/make_sloppy.h +++ b/tensorflow/core/grappler/optimizers/data/make_sloppy.h @@ -32,7 +32,7 @@ class MakeSloppy : public TFDataOptimizerBase { Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { - return OkStatus(); + return absl::OkStatus(); } Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc index e3e712deb2daea..69943e81044728 100644 --- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc @@ -139,7 +139,7 @@ Status MapAndBatchFusion::OptimizeAndCollectStats(Cluster* cluster, } TF_RETURN_IF_ERROR(graph.DeleteNodes(nodes_to_delete)); - return OkStatus(); + return absl::OkStatus(); } REGISTER_GRAPH_OPTIMIZER_AS(MapAndBatchFusion, "map_and_batch_fusion"); diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h index dfc809a343e6ac..1eff9f2b4e2bfe 100644 --- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h +++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h @@ -32,7 +32,7 @@ class MapAndBatchFusion : public TFDataOptimizerBase { Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { - return OkStatus(); + return absl::OkStatus(); } Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, diff --git a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc index 0db7273578ba7f..89ae48482c3a2d 100644 --- a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc @@ -240,7 +240,7 @@ Status MapAndFilterFusion::OptimizeAndCollectStats(Cluster* cluster, } TF_RETURN_IF_ERROR(graph.DeleteNodes(nodes_to_delete)); - return OkStatus(); + return absl::OkStatus(); } REGISTER_GRAPH_OPTIMIZER_AS(MapAndFilterFusion, "map_and_filter_fusion"); diff --git a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.h b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.h index 11eca17cfe8ba8..9645ca8532b26b 100644 --- a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.h +++ b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.h @@ -41,7 +41,7 @@ class MapAndFilterFusion : public TFDataOptimizerBase { Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { - return OkStatus(); + return absl::OkStatus(); } Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_fusion.cc index 67f015ff2efd9b..dfaa2f31d6ce35 100644 --- a/tensorflow/core/grappler/optimizers/data/map_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/map_fusion.cc @@ -164,7 +164,7 @@ Status MapFusion::OptimizeAndCollectStats(Cluster* cluster, if (!autotune_) { VLOG(1) << "The optimization map_fusion is not applied if " "autotune is off."; - return OkStatus(); + return absl::OkStatus(); } MutableGraphView graph(output); @@ -244,7 +244,7 @@ Status MapFusion::OptimizeAndCollectStats(Cluster* cluster, } TF_RETURN_IF_ERROR(graph.DeleteNodes(nodes_to_delete)); - return OkStatus(); + return absl::OkStatus(); } REGISTER_GRAPH_OPTIMIZER_AS(MapFusion, "map_fusion"); diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion.h b/tensorflow/core/grappler/optimizers/data/map_fusion.h index 7ce44d51cf7fde..b6db4f74c08a74 100644 --- a/tensorflow/core/grappler/optimizers/data/map_fusion.h +++ b/tensorflow/core/grappler/optimizers/data/map_fusion.h @@ -36,7 +36,7 @@ class MapFusion : public TFDataOptimizerBase { Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { - if (!config) return OkStatus(); + if (!config) return absl::OkStatus(); const string& autotune = config->parameter_map().at(kAutotune).s(); if (autotune == "true") { @@ -47,7 +47,7 @@ class MapFusion : public TFDataOptimizerBase { return errors::InvalidArgument("Received an invalid value for parameter ", kAutotune, ": ", autotune); } - return OkStatus(); + return absl::OkStatus(); } Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc index a7b7884620a3b1..44b33c4f83471d 100644 --- a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc +++ b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc @@ -62,14 +62,15 @@ Status MapParallelization::OptimizeAndCollectStats(Cluster* cluster, if (!autotune_) { VLOG(1) << "The optimization map_parallelization is not applied if " "autotune is off."; - return OkStatus(); + return absl::OkStatus(); } MutableGraphView graph(output); // If the GrapplerItem is derived from a FunctionDef, we don't optimize it, // because we only want to enable extra map parallelism on the main dataset // pipeline. - if (graph_utils::IsItemDerivedFromFunctionDef(item, graph)) return OkStatus(); + if (graph_utils::IsItemDerivedFromFunctionDef(item, graph)) + return absl::OkStatus(); absl::flat_hash_set nodes_to_delete; FunctionLibraryDefinition function_library(OpRegistry::Global(), @@ -97,7 +98,7 @@ Status MapParallelization::OptimizeAndCollectStats(Cluster* cluster, } TF_RETURN_IF_ERROR(graph.DeleteNodes(nodes_to_delete)); - return OkStatus(); + return absl::OkStatus(); } REGISTER_GRAPH_OPTIMIZER_AS(MapParallelization, "map_parallelization"); diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization.h b/tensorflow/core/grappler/optimizers/data/map_parallelization.h index 3b6de4da545c9b..8f6c3a795fba58 100644 --- a/tensorflow/core/grappler/optimizers/data/map_parallelization.h +++ b/tensorflow/core/grappler/optimizers/data/map_parallelization.h @@ -36,7 +36,7 @@ class MapParallelization : public TFDataOptimizerBase { Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { - if (!config) return OkStatus(); + if (!config) return absl::OkStatus(); const string& autotune = config->parameter_map().at(kAutotune).s(); if (autotune == "true") { @@ -47,7 +47,7 @@ class MapParallelization : public TFDataOptimizerBase { return errors::InvalidArgument("Received an invalid value for parameter ", kAutotune, ": ", autotune); } - return OkStatus(); + return absl::OkStatus(); } Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, diff --git a/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc index 87bb504f146c12..3627ab9b626cd2 100644 --- a/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc @@ -38,7 +38,8 @@ using ConfigMap = std::map; // tf.data optimizations, in the order we want to perform them. -constexpr std::array kTFDataOptimizations = { +// clang-format off +constexpr std::array kTFDataOptimizations = { "noop_elimination", "disable_intra_op_parallelism", "use_private_thread_pool", @@ -54,12 +55,14 @@ constexpr std::array kTFDataOptimizations = { "parallel_batch", "slack", "autotune_buffer_sizes", + "seq_interleave_prefetch", "inject_prefetch", "inject_io_prefetch_eligible", "inject_io_prefetch", "disable_prefetch_legacy_autotune", "enable_gradient_descent", "make_deterministic"}; +// clang-format on // Parses a list of string optimizer configurations into a map from // optimizer name -> rewriter config for that optimizer. @@ -67,7 +70,7 @@ Status ToConfigMap( const tensorflow::RewriterConfig_CustomGraphOptimizer* config, ConfigMap* result) { auto found = gtl::FindOrNull(config->parameter_map(), "optimizer_configs"); - if (!found) return OkStatus(); + if (!found) return absl::OkStatus(); auto& options = found->list().s(); for (const auto& option_string : options) { @@ -95,7 +98,7 @@ Status ToConfigMap( config_value); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -160,7 +163,7 @@ Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, if (optimized_functions) { *output->mutable_library() = flib.ToProto(); } - return OkStatus(); + return absl::OkStatus(); } Status TFDataMetaOptimizer::ApplyOptimization(const string& name, @@ -170,7 +173,7 @@ Status TFDataMetaOptimizer::ApplyOptimization(const string& name, const auto* optimizer = gtl::FindOrNull(enabled_optimizers_, name); if (!optimizer) { - return OkStatus(); + return absl::OkStatus(); } GraphDef result; @@ -183,7 +186,7 @@ Status TFDataMetaOptimizer::ApplyOptimization(const string& name, // A status of errors::Aborted just means that the optimizer was a no-op and // did not populate result. Swallow the error status and leave the original // graph in item. - status = OkStatus(); + status = absl::OkStatus(); } return status; @@ -191,7 +194,7 @@ Status TFDataMetaOptimizer::ApplyOptimization(const string& name, Status TFDataMetaOptimizer::Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) { - if (!config) return OkStatus(); + if (!config) return absl::OkStatus(); // Initialize custom tf.data optimizers based on config. auto& optimizers = config->parameter_map().at("optimizers").list().s(); @@ -213,7 +216,7 @@ Status TFDataMetaOptimizer::Init( } } - return OkStatus(); + return absl::OkStatus(); } REGISTER_GRAPH_OPTIMIZER_AS(TFDataMetaOptimizer, "tf_data_meta_optimizer"); diff --git a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc index 7faa95a5cb10ef..8e32fba097a97c 100644 --- a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc +++ b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc @@ -169,7 +169,7 @@ Status NoOpElimination::OptimizeAndCollectStats(Cluster* cluster, } TF_RETURN_IF_ERROR(graph.DeleteNodes(nodes_to_delete)); - return OkStatus(); + return absl::OkStatus(); } REGISTER_GRAPH_OPTIMIZER_AS(NoOpElimination, "noop_elimination"); diff --git a/tensorflow/core/grappler/optimizers/data/noop_elimination.h b/tensorflow/core/grappler/optimizers/data/noop_elimination.h index 45fd59326c289a..a35f7a1b632772 100644 --- a/tensorflow/core/grappler/optimizers/data/noop_elimination.h +++ b/tensorflow/core/grappler/optimizers/data/noop_elimination.h @@ -34,7 +34,7 @@ class NoOpElimination : public TFDataOptimizerBase { Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { - return OkStatus(); + return absl::OkStatus(); } Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, diff --git a/tensorflow/core/grappler/optimizers/data/parallel_batch.cc b/tensorflow/core/grappler/optimizers/data/parallel_batch.cc index 25e858f44713a9..512db36330b926 100644 --- a/tensorflow/core/grappler/optimizers/data/parallel_batch.cc +++ b/tensorflow/core/grappler/optimizers/data/parallel_batch.cc @@ -38,7 +38,7 @@ Status ParallelBatch::OptimizeAndCollectStats(Cluster* cluster, stats->num_changes++; } } - return OkStatus(); + return absl::OkStatus(); } REGISTER_GRAPH_OPTIMIZER_AS(ParallelBatch, "parallel_batch"); diff --git a/tensorflow/core/grappler/optimizers/data/parallel_batch.h b/tensorflow/core/grappler/optimizers/data/parallel_batch.h index b4ac857a9eaa6c..c6b0a7b64036c3 100644 --- a/tensorflow/core/grappler/optimizers/data/parallel_batch.h +++ b/tensorflow/core/grappler/optimizers/data/parallel_batch.h @@ -32,7 +32,7 @@ class ParallelBatch : public TFDataOptimizerBase { Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { - return OkStatus(); + return absl::OkStatus(); } Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, diff --git a/tensorflow/core/grappler/optimizers/data/remove_compression_map.cc b/tensorflow/core/grappler/optimizers/data/remove_compression_map.cc index 8c674a87359347..4753892e7f28b4 100644 --- a/tensorflow/core/grappler/optimizers/data/remove_compression_map.cc +++ b/tensorflow/core/grappler/optimizers/data/remove_compression_map.cc @@ -33,7 +33,7 @@ namespace grappler { namespace { -StatusOr GetCompressionFunctionName(const GraphDef& graph) { +absl::StatusOr GetCompressionFunctionName(const GraphDef& graph) { for (const auto& function : graph.library().function()) { for (const auto& node : function.node_def()) { if (node.op() == "CompressElement") { @@ -44,7 +44,7 @@ StatusOr GetCompressionFunctionName(const GraphDef& graph) { return errors::Internal("Compression function not found."); } -StatusOr GetCompressionMapNode(const GraphDef& graph) { +absl::StatusOr GetCompressionMapNode(const GraphDef& graph) { TF_ASSIGN_OR_RETURN(std::string compression_function_name, GetCompressionFunctionName(graph)); for (const auto& node : graph.node()) { @@ -76,7 +76,7 @@ Status RemoveCompressionMap::OptimizeAndCollectStats(Cluster* cluster, compression_map_output.node->add_input(compression_map_node.input().Get(0)); ++stats->num_changes; } - return OkStatus(); + return absl::OkStatus(); } REGISTER_GRAPH_OPTIMIZER_AS(RemoveCompressionMap, "remove_compression_map"); diff --git a/tensorflow/core/grappler/optimizers/data/remove_compression_map.h b/tensorflow/core/grappler/optimizers/data/remove_compression_map.h index d3dc881d119a42..6306cca768894f 100644 --- a/tensorflow/core/grappler/optimizers/data/remove_compression_map.h +++ b/tensorflow/core/grappler/optimizers/data/remove_compression_map.h @@ -32,7 +32,7 @@ class RemoveCompressionMap : public TFDataOptimizerBase { Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { - return OkStatus(); + return absl::OkStatus(); } Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, diff --git a/tensorflow/core/grappler/optimizers/data/replicate_on_split.cc b/tensorflow/core/grappler/optimizers/data/replicate_on_split.cc index 00d501eed99874..43bcca49e2a7a5 100644 --- a/tensorflow/core/grappler/optimizers/data/replicate_on_split.cc +++ b/tensorflow/core/grappler/optimizers/data/replicate_on_split.cc @@ -40,7 +40,7 @@ Status ReplicateOnSplit::OptimizeAndCollectStats(Cluster* cluster, stats->num_changes++; } } - return OkStatus(); + return absl::OkStatus(); } REGISTER_GRAPH_OPTIMIZER_AS(ReplicateOnSplit, "replicate_on_split"); diff --git a/tensorflow/core/grappler/optimizers/data/replicate_on_split.h b/tensorflow/core/grappler/optimizers/data/replicate_on_split.h index 42732fbb1a1bea..338ef29b3fdcc3 100644 --- a/tensorflow/core/grappler/optimizers/data/replicate_on_split.h +++ b/tensorflow/core/grappler/optimizers/data/replicate_on_split.h @@ -32,7 +32,7 @@ class ReplicateOnSplit : public TFDataOptimizerBase { Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { - return OkStatus(); + return absl::OkStatus(); } Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, diff --git a/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch.cc b/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch.cc new file mode 100644 index 00000000000000..11e4916e5a18e5 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch.cc @@ -0,0 +1,362 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/core/data/dataset_utils.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/model.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/mutable_graph_view.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/optimizers/data/function_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/errors.h" + +namespace tensorflow { +namespace grappler { +namespace { + +constexpr char kInterleaveDatasetOpName[] = "InterleaveDataset"; +constexpr char kParallelInterleaveDatasetV2OpName[] = + "ParallelInterleaveDatasetV2"; +constexpr char kParallelInterleaveDatasetV3OpName[] = + "ParallelInterleaveDatasetV3"; +constexpr char kParallelInterleaveDatasetV4OpName[] = + "ParallelInterleaveDatasetV4"; +constexpr char kParallelInterleaveDatasetOpName[] = "ParallelInterleaveDataset"; +constexpr char kPrefetchDatasetOpName[] = "PrefetchDataset"; +constexpr char kDatasetStr[] = "Dataset"; +constexpr char kConstOpName[] = "Const"; +constexpr char kOutputShapes[] = "output_shapes"; +constexpr char kOutputTypes[] = "output_types"; +constexpr char kConstNodeOutputSuffix[] = ":output:0"; +constexpr char kDatasetNodeOutputSuffix[] = ":handle:0"; +constexpr char kDeterministicAttr[] = "deterministic"; +constexpr char kFunctionAttr[] = "f"; +constexpr char kDTypeAttr[] = "dtype"; +constexpr char kValueAttr[] = "value"; +constexpr char kTArgumentsAttr[] = "Targuments"; +constexpr char kOutputTypesAttr[] = "output_types"; +constexpr char kMetadataAttr[] = "metadata"; +constexpr char kOutputShapesAttr[] = "output_shapes"; +constexpr char kTOutputTypesAttr[] = "Toutput_types"; +constexpr char kSeqInterleavePrefetchRewritePrefix[] = + "inject/seq_interleave_prefetch_rewrite_"; + +// +// Steps involved in rewrite: +// +// For every deterministic parallel interleave node, +// 1. Create interleave node and set the `interleave_fn` function same as the +// `interleave_fn` in parallel interleave node. +// - Update fan outs in the top level graph. +// - Delete parallel interleave nodes and its unused input nodes. +// 2. Create a prefetch node with 'input set to (input of Identity node in +// FLD of the graph) +// - From the signature of 'f', find the output node (Identity node). +// - Find the input of this output node and set it as input of Prefetch +// node. +//. - Add prefetch and its input nodes to the FunctionDef. +// - Update fan outs of prefetch node. +// + +bool IsParallelInterleave(const std::string& op) { + return data::MatchesAnyVersion(kParallelInterleaveDatasetOpName, op); +} + +// Returns the number of inputs accepted by the parallel interleave op as per +// the version excluding the `other_arguments` input. +int GetNumInputsForParallelInterleaveOp(const std::string& op) { + if (op == kParallelInterleaveDatasetV2OpName) { + return 4; + } else if (op == kParallelInterleaveDatasetV3OpName) { + return 4; + } else if (op == kParallelInterleaveDatasetV4OpName) { + return 6; + } + return 0; +} + +// Check if op type of `node` has "Dataset" suffix. +bool NodeOpHasDatasetSuffix(const NodeDef& node) { + return absl::EndsWith(node.op(), kDatasetStr); +} + +// Returns true if there is at least one function node with dataset op. +bool DatasetOpInFunction(const NodeDef& node, const FunctionDef* fn) { + for (const auto& node : fn->node_def()) { + if (NodeOpHasDatasetSuffix(node)) { + return true; + } + } + return false; +} + +// A node is eligible for rewrite if it is a deterministic parallel interleave +// node and has a function node creating `Dataset`. +bool RewritePossibleForNode(const NodeDef& node, + const FunctionLibraryDefinition& fld) { + auto is_deterministic_parallel_interleave_node = [&]() -> bool { + if (!IsParallelInterleave(node.op())) return false; + auto determinism_value = node.attr().find(kDeterministicAttr); + return (determinism_value != node.attr().end()) && + (determinism_value->second.s() == "true"); + }; + + if (node.attr().count(kFunctionAttr) == 0) return false; + const FunctionDef* fn = fld.Find(node.attr().at(kFunctionAttr).func().name()); + + if (fn == nullptr) return false; + if (fn->signature().output_arg_size() != 1) return false; + if (is_deterministic_parallel_interleave_node()) { + return DatasetOpInFunction(node, fn); + } + + return false; +} + +NodeDef CreateBufferSizeNode(DataType dtype, + const std::function& add_value, + MutableGraphView* graph, FunctionDef& fdef) { + NodeDef node; + node.set_op(kConstOpName); + function_utils::SetUniqueFunctionNodeName( + absl::StrCat(kSeqInterleavePrefetchRewritePrefix, "buffer_size"), &fdef, + &node); + + (*node.mutable_attr())[kDTypeAttr].set_type(dtype); + auto tensor = std::make_unique(); + auto tensor_shape = std::make_unique(); + tensor->set_allocated_tensor_shape(tensor_shape.release()); + tensor->set_dtype(dtype); + add_value(tensor.get()); + (*node.mutable_attr())[kValueAttr].set_allocated_tensor(tensor.release()); + + return node; +} + +Status CreateAndAppendPrefetchNode(MutableGraphView* graph, FunctionDef& fdef) { + auto get_last_dataset_op_node = [&]() -> const NodeDef* { + // Find the input node of fdef's ret value. + const auto& output_arg = fdef.signature().output_arg(0).name(); + const auto& ret_val = fdef.ret().at(output_arg); + auto input = function_utils::FunctionDefTensorDesc(ret_val); + // Walk from output to input and find the first eligible node. + const NodeDef* dataset_op_node = nullptr; + while ( + function_utils::ContainsFunctionNodeWithName(input.node_name, fdef)) { + int idx = function_utils::FindFunctionNodeWithName(input.node_name, fdef); + const NodeDef& node = fdef.node_def(idx); + if (NodeOpHasDatasetSuffix(node)) { + dataset_op_node = &node; + break; + } + input = function_utils::FunctionDefTensorDesc(node.input(0)); + } + return dataset_op_node; + }; + + // 1. Find the position for the `prefetch` node. + const NodeDef* add_after = get_last_dataset_op_node(); + if (add_after == nullptr) { + return errors::NotFound( + "Could not find any dataset node to append `Prefetch` at its output in " + "`seq_interleave_prefetch` rewrite"); + } + + // 2. Create prefetch node. + NodeDef prefetch_node; + prefetch_node.set_op(kPrefetchDatasetOpName); + function_utils::SetUniqueFunctionNodeName( + absl::StrCat(kSeqInterleavePrefetchRewritePrefix, + fdef.signature().name()), + &fdef, &prefetch_node); + + // 3. Construct argument nodes. + const auto input_dataset = + absl::StrCat(add_after->name(), kDatasetNodeOutputSuffix); + NodeDef buffer_size_node = CreateBufferSizeNode( + DT_INT64, + [](TensorProto* proto) { proto->add_int64_val(data::model::kAutotune); }, + graph, fdef); + + // 4. Add inputs to prefetch nodes. + prefetch_node.add_input(input_dataset); + prefetch_node.add_input( + absl::StrCat(buffer_size_node.name(), kConstNodeOutputSuffix)); + + // 5. Set other attributes of prefetch node. + if (add_after->attr().count(kOutputShapes) > 0) { + graph_utils::CopyAttribute(kOutputShapes, *add_after, &prefetch_node); + } else { + tensorflow::TensorShapeProto* shape = + (*(prefetch_node.mutable_attr()))[kOutputShapes] + .mutable_list() + ->add_shape(); + shape->set_unknown_rank(true); + } + + if (add_after->attr().count(kOutputTypes) > 0) { + graph_utils::CopyAttribute(kOutputTypes, *add_after, &prefetch_node); + } else if (add_after->attr().count(kTOutputTypesAttr) > 0) { + (*(prefetch_node.mutable_attr()))[kOutputTypes] = + add_after->attr().at(kTOutputTypesAttr); + } else { + (*(prefetch_node.mutable_attr()))[kOutputTypes].mutable_list()->add_type( + tensorflow::DataType::DT_STRING); + } + + // 6. Update fanouts. + std::string old_input = input_dataset; + std::string new_input = + absl::StrCat(prefetch_node.name(), kDatasetNodeOutputSuffix); + function_utils::ReplaceReferences(old_input, new_input, &fdef); + + // 7. Add `prefetch` and its argument nodes to `fdef`. + *fdef.add_node_def() = std::move(prefetch_node); + *fdef.add_node_def() = std::move(buffer_size_node); + + return absl::OkStatus(); +} + +Status AddInterleaveNode(MutableGraphView* graph, + const NodeDef& parallel_interleave_node, + const std::string& interleave_map_func_name, + absl::flat_hash_set& nodes_to_delete) { + NodeDef interleave_node; + interleave_node.set_op(kInterleaveDatasetOpName); + graph_utils::SetUniqueGraphNodeName( + absl::StrCat(kSeqInterleavePrefetchRewritePrefix, + parallel_interleave_node.name()), + graph->graph(), &interleave_node); + + // Inputs to interleave node passed from parallel interleave node would + // comprise of `input_dataset`, `other_arguments`, `cycle_length`, and + // `block_length`. + int num_other_args = + parallel_interleave_node.input_size() - + GetNumInputsForParallelInterleaveOp(parallel_interleave_node.op()); + int inputs_from_parallel_interleave = 1 /* input_dataset */ + num_other_args + + 1 /* cycle_length */ + + 1 /* block_length */; + for (int i = 0; i < inputs_from_parallel_interleave; ++i) { + interleave_node.add_input(parallel_interleave_node.input(i)); + } + + // Copy attributes. + if (parallel_interleave_node.attr().contains(kTArgumentsAttr)) { + graph_utils::CopyAttribute(kTArgumentsAttr, parallel_interleave_node, + &interleave_node); + } + if (parallel_interleave_node.attr().contains(kOutputTypesAttr)) { + graph_utils::CopyAttribute(kOutputTypesAttr, parallel_interleave_node, + &interleave_node); + } + if (parallel_interleave_node.attr().contains(kOutputShapesAttr)) { + graph_utils::CopyAttribute(kOutputShapesAttr, parallel_interleave_node, + &interleave_node); + } + if (parallel_interleave_node.attr().contains(kMetadataAttr)) { + graph_utils::CopyAttribute(kMetadataAttr, parallel_interleave_node, + &interleave_node); + } + + // Set the interleave function attr to the same function as in parallel + // interleave. + const auto& parallel_interleave_fn_attr = + parallel_interleave_node.attr().at(kFunctionAttr); + (*interleave_node.mutable_attr())[kFunctionAttr] = + parallel_interleave_fn_attr; + (*interleave_node.mutable_attr())[kFunctionAttr].mutable_func()->set_name( + interleave_map_func_name); + + // Copy shapes and types attributes. + graph_utils::CopyShapesAndTypesAttrs(parallel_interleave_node, + &interleave_node); + + // Copy experimental types. + *interleave_node.mutable_experimental_type() = + parallel_interleave_node.experimental_type(); + + // Add new node into graph and update edges + NodeDef* new_node_graph = graph->AddNode(std::move(interleave_node)); + TF_RETURN_IF_ERROR(graph->UpdateFanouts(parallel_interleave_node.name(), + new_node_graph->name())); + + // Delete the parallel interleave node. + nodes_to_delete.insert(parallel_interleave_node.name()); + return absl::OkStatus(); +} +} // namespace + +Status SeqInterleavePrefetch::OptimizeAndCollectStats( + Cluster* cluster, const GrapplerItem& item, GraphDef* output, + OptimizationStats* stats) { + *output = item.graph; + MutableGraphView graph(output); + absl::flat_hash_set nodes_to_delete; + FunctionLibraryDefinition fld(OpRegistry::Global(), item.graph.library()); + + for (const NodeDef& node : item.graph.node()) { + if (!RewritePossibleForNode(node, fld)) continue; + // Find the parallel_interleave_node's `map_func`. + const FunctionDef* parallel_interleave_fn = + fld.Find(node.attr().at("f").func().name()); + FunctionDef interleave_fn(*parallel_interleave_fn); + interleave_fn.mutable_signature()->set_name( + absl::StrCat(kSeqInterleavePrefetchRewritePrefix, + parallel_interleave_fn->signature().name())); + // Replace the parallel interleave node with interleave. + TF_RETURN_IF_ERROR(AddInterleaveNode( + &graph, node, interleave_fn.signature().name(), nodes_to_delete)); + // Create and append the prefetch node to the interleave_fn. + TF_RETURN_IF_ERROR(CreateAndAppendPrefetchNode(&graph, interleave_fn)); + // Replace the `parallel_interleave_fn` with `interleave_fn`. + TF_RETURN_IF_ERROR(fld.ReplaceFunction( + parallel_interleave_fn->signature().name(), interleave_fn)); + stats->num_changes++; + } + + // Update the `FunctionDefLibrary` of the optimized graph. + *output->mutable_library() = fld.ToProto(); + TF_RETURN_IF_ERROR(graph.DeleteNodes(nodes_to_delete)); + + return absl::OkStatus(); +} + +REGISTER_GRAPH_OPTIMIZER_AS(SeqInterleavePrefetch, "seq_interleave_prefetch"); + +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch.h b/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch.h new file mode 100644 index 00000000000000..00cfed1ed78abd --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch.h @@ -0,0 +1,54 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_SEQ_INTERLEAVE_PREFETCH_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_SEQ_INTERLEAVE_PREFETCH_H_ + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h" + +namespace tensorflow { +namespace grappler { + +// This optimization replaces parallel interleave with sequential interleave and +// adds `prefetch(AUTOTUNE)` after the user defined map function in interleave. +class SeqInterleavePrefetch : public TFDataOptimizerBase { + public: + SeqInterleavePrefetch() = default; + ~SeqInterleavePrefetch() override = default; + + std::string name() const override { return "seq_interleave_prefetch"; }; + + // The SeqInterleavePrefetch optimizer requires access to the function + // library. + bool UsesFunctionLibrary() const override { return true; } + + Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return absl::OkStatus(); + } + + Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; + + protected: + bool autotune_ = true; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_SEQ_INTERLEAVE_PREFETCH_H_ diff --git a/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch_test.cc b/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch_test.cc new file mode 100644 index 00000000000000..076357accef923 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch_test.cc @@ -0,0 +1,381 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch.h" + +#include + +#include +#include "absl/strings/str_cat.h" +#include "tensorflow/core/data/dataset_utils.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/data/function_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" +#include "tensorflow/core/platform/status.h" +#include "tsl/lib/core/status_test_util.h" + +namespace tensorflow { +namespace grappler { +namespace { + +using test::function::GDef; +using test::function::NDef; + +constexpr char kPrefetchDatasetOpName[] = "PrefetchDataset"; +constexpr char kInterleaveDatasetOpName[] = "InterleaveDataset"; +constexpr char kParallelInterleaveDatasetOpName[] = + "ParallelInterleaveDatasetV4"; +constexpr char kSeqInterleavePrefetchRewritePrefix[] = + "inject/seq_interleave_prefetch_rewrite_"; +constexpr char kFdefProtoStr[] = + R"pb(signature { + name: "parallel_interleave_fdef" + input_arg { name: "args_0" type: DT_STRING } + output_arg { name: "identity" type: DT_VARIANT } + is_stateful: true + control_output: "SSTableDataset" + } + node_def { + name: "key_prefix" + op: "Const" + attr { + key: "dtype" + value { type: DT_STRING } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape {} + string_val: "" + } + } + } + } + node_def { + name: "start_key" + op: "Const" + attr { + key: "dtype" + value { type: DT_STRING } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape {} + string_val: "" + } + } + } + } + node_def { + name: "stop_key" + op: "Const" + attr { + key: "dtype" + value { type: DT_STRING } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape {} + string_val: "" + } + } + } + } + node_def { + name: "SSTableDataset" + op: "SSTableDataset" + input: "args_0" + input: "key_prefix:output:0" + input: "start_key:output:0" + input: "stop_key:output:0" + attr { + key: "metadata" + value { s: "" } + } + attr { + key: "split_size" + value { i: 0 } + } + experimental_type { + type_id: TFT_PRODUCT + args { + type_id: TFT_DATASET + args { + type_id: TFT_TENSOR + args { type_id: TFT_STRING } + } + } + } + } + node_def { + name: "Identity" + op: "Identity" + input: "SSTableDataset:handle:0" + input: "^NoOp" + attr { + key: "T" + value { type: DT_VARIANT } + } + } + node_def { name: "NoOp" op: "NoOp" input: "^SSTableDataset" } + ret { key: "identity" value: "Identity:output:0" } + attr { + key: "_construction_context" + value { s: "kEagerRuntime" } + } + attr { + key: "_tf_data_function" + value { b: true } + } + control_ret { key: "SSTableDataset" value: "SSTableDataset" } + arg_attr { + key: 0 + value { + attr { + key: "_output_shapes" + value { list { shape {} } } + } + attr { + key: "_user_specified_name" + value { s: "args_0" } + } + } + })pb"; + +GraphDef ParallelInterleaveCase(bool deterministic) { + FunctionDef fdef; + protobuf::TextFormat::ParseFromString(kFdefProtoStr, &fdef); + return GDef( + {NDef("stop", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + NDef("range", "RangeDataset", {"stop"}, {}), + NDef("cycle_length", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + NDef("block_length", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + NDef("num_parallel_calls", "Const", {}, + {{"value", 1}, {"dtype", DT_INT32}}), + graph_tests_utils::MakeParallelInterleaveV4Node( + "parallel_interleave", "range", "cycle_length", "block_length", + "num_parallel_calls", "parallel_interleave_fdef", + deterministic ? "true" : "false")}, + // FunctionLib + { + fdef, + }); +} + +GraphDef MultipleParallelInterleaveCase(bool deterministic) { + FunctionDef fdef_1, fdef_2, fdef_3; + protobuf::TextFormat::ParseFromString(kFdefProtoStr, &fdef_1); + fdef_1.mutable_signature()->set_name("parallel_interleave_fdef_1"); + protobuf::TextFormat::ParseFromString(kFdefProtoStr, &fdef_2); + fdef_2.mutable_signature()->set_name("parallel_interleave_fdef_2"); + protobuf::TextFormat::ParseFromString(kFdefProtoStr, &fdef_3); + fdef_3.mutable_signature()->set_name("parallel_interleave_fdef_3"); + + auto make_parallel_interleave_node = + [&deterministic](const int node_num, const FunctionDef &fdef) { + return graph_tests_utils::MakeParallelInterleaveV4Node( + absl::StrCat("parallel_interleave_", node_num), "range", + "cycle_length", "block_length", "num_parallel_calls", + fdef.signature().name(), deterministic ? "true" : "false"); + }; + + return GDef( + {NDef("stop", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + NDef("range", "RangeDataset", {"stop"}, {}), + NDef("cycle_length", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + NDef("block_length", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + NDef("num_parallel_calls", "Const", {}, + {{"value", 1}, {"dtype", DT_INT32}}), + make_parallel_interleave_node(1, fdef_1), + make_parallel_interleave_node(2, fdef_2), + make_parallel_interleave_node(3, fdef_3)}, + // FunctionLib + { + fdef_1, + fdef_2, + fdef_3, + }); +} + +GraphDef InterleaveCase(bool deterministic) { + FunctionDef fdef; + protobuf::TextFormat::ParseFromString(kFdefProtoStr, &fdef); + return GDef( + {NDef("stop", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + NDef("range", "RangeDataset", {"stop"}, {}), + NDef("cycle_length", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + NDef("block_length", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + graph_tests_utils::MakeInterleaveNode( + "sequential_interleave", "range", "cycle_length", "block_length", + "parallel_interleave_fdef", deterministic ? "true" : "false")}, + // FunctionLib + { + fdef, + }); +} + +bool PrefetchInFunction(const NodeDef &node, + const FunctionLibraryDefinition &flib) { + auto f_attr_it = node.attr().find("f"); + if (f_attr_it == node.attr().end()) return false; + const FunctionDef *func = flib.Find(f_attr_it->second.func().name()); + if (func == nullptr) { + return false; + } + for (int i = 0; i < func->node_def_size(); i++) { + NodeDef node_in_func = func->node_def(i); + if (tensorflow::data::MatchesAnyVersion( + /*op_prefix=*/kPrefetchDatasetOpName, + /*op_to_match=*/node_in_func.op())) { + return true; + } + } + return false; +} + +bool IsInterleaveNode(const NodeDef &node) { + return (node.op() == kInterleaveDatasetOpName); +} + +} // namespace + +Status OptimizeWithInjectInterleavePrefetch(const GrapplerItem &item, + GraphDef *output) { + SeqInterleavePrefetch optimizer; + return optimizer.Optimize(nullptr, item, output); +} + +class SeqInterleavePrefetchParameterizedTest + : public ::testing::TestWithParam {}; + +TEST_P(SeqInterleavePrefetchParameterizedTest, + ParallelInterleaveHasConditionalInjection) { + GrapplerItem item; + bool deterministic = GetParam(); + item.graph = ParallelInterleaveCase(deterministic); + item.fetch.push_back("Sink"); + + GraphDef output; + TF_ASSERT_OK(OptimizeWithInjectInterleavePrefetch(item, &output)); + FunctionLibraryDefinition lib_def(OpRegistry::Global(), output.library()); + const std::string ¶llel_interleave_fdef_name = "parallel_interleave_fdef"; + const std::string &interleave_fdef_name = absl::StrCat( + kSeqInterleavePrefetchRewritePrefix, parallel_interleave_fdef_name); + if (deterministic) { + EXPECT_TRUE( + !graph_utils::ContainsGraphNodeWithName("parallel_interleave", output)); + EXPECT_TRUE(!graph_utils::ContainsNodeWithOp( + kParallelInterleaveDatasetOpName, output)); + EXPECT_TRUE( + graph_utils::ContainsNodeWithOp(kInterleaveDatasetOpName, output)); + for (auto node : output.node()) { + if (!IsInterleaveNode(node)) continue; + EXPECT_TRUE(PrefetchInFunction(node, lib_def)); + } + const FunctionDef *parallel_interleave_fdef = + lib_def.Find(parallel_interleave_fdef_name); + const FunctionDef *interleave_fdef = lib_def.Find(interleave_fdef_name); + EXPECT_EQ(parallel_interleave_fdef, nullptr); + EXPECT_NE(interleave_fdef, nullptr); + EXPECT_EQ(lib_def.ListFunctionNames().at(0), interleave_fdef_name); + EXPECT_TRUE(function_utils::FindFunctionNodeWithOp(kPrefetchDatasetOpName, + *interleave_fdef)); + } else { + EXPECT_TRUE(graph_utils::ContainsNodeWithOp( + kParallelInterleaveDatasetOpName, output)); + EXPECT_TRUE( + !graph_utils::ContainsNodeWithOp(kInterleaveDatasetOpName, output)); + EXPECT_TRUE( + graph_utils::ContainsGraphNodeWithName("parallel_interleave", output)); + EXPECT_NE(lib_def.Find(parallel_interleave_fdef_name), nullptr); + } + EXPECT_EQ(lib_def.num_functions(), 1); +} + +TEST_P(SeqInterleavePrefetchParameterizedTest, + MultipleParallelInterleavesHaveConditionalInjection) { + GrapplerItem item; + bool deterministic = GetParam(); + item.graph = MultipleParallelInterleaveCase(deterministic); + item.fetch.push_back("Sink"); + + GraphDef output; + TF_ASSERT_OK(OptimizeWithInjectInterleavePrefetch(item, &output)); + FunctionLibraryDefinition lib_def(OpRegistry::Global(), output.library()); + if (deterministic) { + EXPECT_TRUE(!graph_utils::ContainsNodeWithOp( + kParallelInterleaveDatasetOpName, output)); + EXPECT_TRUE( + graph_utils::ContainsNodeWithOp(kInterleaveDatasetOpName, output)); + for (int i = 1; i <= 3; ++i) { + EXPECT_TRUE(!graph_utils::ContainsGraphNodeWithName( + absl::StrCat("parallel_interleave_", std::to_string(i)), output)); + } + for (auto node : output.node()) { + if (!IsInterleaveNode(node)) continue; + EXPECT_TRUE(PrefetchInFunction(node, lib_def)); + } + } else { + EXPECT_TRUE(graph_utils::ContainsNodeWithOp( + kParallelInterleaveDatasetOpName, output)); + EXPECT_TRUE( + !graph_utils::ContainsNodeWithOp(kInterleaveDatasetOpName, output)); + for (int i = 1; i <= 3; ++i) { + EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName( + absl::StrCat("parallel_interleave_", std::to_string(i)), output)); + } + } +} + +TEST_P(SeqInterleavePrefetchParameterizedTest, + SequentialInterleaveHasNoInjection) { + GrapplerItem item; + item.graph = InterleaveCase(/*deterministic=*/GetParam()); + item.fetch.push_back("Sink"); + + GraphDef output; + TF_ASSERT_OK(OptimizeWithInjectInterleavePrefetch(item, &output)); + EXPECT_TRUE( + graph_utils::ContainsNodeWithOp(kInterleaveDatasetOpName, output)); + EXPECT_TRUE( + graph_utils::ContainsGraphNodeWithName("sequential_interleave", output)); + FunctionLibraryDefinition lib_def(OpRegistry::Global(), output.library()); + for (auto node : output.node()) { + if (!IsInterleaveNode(node)) continue; + EXPECT_FALSE(PrefetchInFunction(node, lib_def)); + } +} + +INSTANTIATE_TEST_SUITE_P(Determinism, SeqInterleavePrefetchParameterizedTest, + ::testing::Values(false, true)); + +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc index f0391112b32d0e..5fabf42bf03872 100644 --- a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc @@ -72,7 +72,7 @@ Status FuseShuffleV1AndRepeat(const NodeDef& shuffle_node, // Optionally set the `metadata` attribute. graph_utils::MaybeSetFusedMetadata(shuffle_node, repeat_node, fused_node); - return OkStatus(); + return absl::OkStatus(); } Status FuseShuffleV2AndRepeat(const NodeDef& shuffle_node, @@ -112,7 +112,7 @@ Status FuseShuffleV2AndRepeat(const NodeDef& shuffle_node, // Optionally set the `metadata` attribute. graph_utils::MaybeSetFusedMetadata(shuffle_node, repeat_node, fused_node); - return OkStatus(); + return absl::OkStatus(); } Status FuseShuffleV3AndRepeat(const NodeDef& shuffle_node, @@ -149,7 +149,7 @@ Status FuseShuffleV3AndRepeat(const NodeDef& shuffle_node, // Optionally set the `metadata` attribute. graph_utils::MaybeSetFusedMetadata(shuffle_node, repeat_node, fused_node); - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -205,7 +205,7 @@ Status ShuffleAndRepeatFusion::OptimizeAndCollectStats( } TF_RETURN_IF_ERROR(graph.DeleteNodes(nodes_to_delete)); - return OkStatus(); + return absl::OkStatus(); } REGISTER_GRAPH_OPTIMIZER_AS(ShuffleAndRepeatFusion, diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.h b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.h index d21619b92b34aa..5ce38242bbe3b0 100644 --- a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.h +++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.h @@ -32,7 +32,7 @@ class ShuffleAndRepeatFusion : public TFDataOptimizerBase { Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { - return OkStatus(); + return absl::OkStatus(); } Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, diff --git a/tensorflow/core/grappler/optimizers/data/slack.cc b/tensorflow/core/grappler/optimizers/data/slack.cc index 61b305fabb1c26..c83a371973609c 100644 --- a/tensorflow/core/grappler/optimizers/data/slack.cc +++ b/tensorflow/core/grappler/optimizers/data/slack.cc @@ -85,7 +85,7 @@ Status Slack::RecursivelyHandleOp(const MutableGraphView& graph, } else { AddNodeAttr("slack_period", slack_period_, dataset_node); } - return OkStatus(); + return absl::OkStatus(); } if (IsDatasetNodeOfType(*dataset_node, kPassThroughOps)) { NodeDef* input_node = graph_utils::GetInputNode(*dataset_node, graph, 0); @@ -97,12 +97,12 @@ Status Slack::RecursivelyHandleOp(const MutableGraphView& graph, NodeDef* input_node = graph_utils::GetInputNode(*dataset_node, graph, i); TF_RETURN_IF_ERROR(RecursivelyHandleOp(graph, input_node)); } - return OkStatus(); + return absl::OkStatus(); } LOG(WARNING) << "Could not find a final `prefetch` in the input pipeline to " "which to introduce slack."; - return OkStatus(); + return absl::OkStatus(); } Status Slack::OptimizeAndCollectStats(Cluster* cluster, @@ -119,7 +119,8 @@ Status Slack::OptimizeAndCollectStats(Cluster* cluster, // If the GrapplerItem is derived from a FunctionDef, we don't optimize it, // because we only want to add slack to the prefetch on the main dataset // pipeline. - if (graph_utils::IsItemDerivedFromFunctionDef(item, graph)) return OkStatus(); + if (graph_utils::IsItemDerivedFromFunctionDef(item, graph)) + return absl::OkStatus(); if (item.fetch.size() != 1) { return errors::InvalidArgument( diff --git a/tensorflow/core/grappler/optimizers/data/slack.h b/tensorflow/core/grappler/optimizers/data/slack.h index 8ce90841ebf49c..b39cfc65094567 100644 --- a/tensorflow/core/grappler/optimizers/data/slack.h +++ b/tensorflow/core/grappler/optimizers/data/slack.h @@ -45,7 +45,7 @@ class Slack : public TFDataOptimizerBase { return errors::InvalidArgument("Invalid `slack_period` parameter: ", slack_period_param); } - return OkStatus(); + return absl::OkStatus(); } Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, diff --git a/tensorflow/core/grappler/optimizers/data/split_utils.cc b/tensorflow/core/grappler/optimizers/data/split_utils.cc index d023b9fbd86ded..1798f1de44f054 100644 --- a/tensorflow/core/grappler/optimizers/data/split_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/split_utils.cc @@ -178,7 +178,7 @@ Status InputRewriter::RewriteInput(absl::string_view input_str, auto iter = input_map_.find(input_str); if (iter != input_map_.end()) { *new_input_str = iter->second; - return OkStatus(); + return absl::OkStatus(); } if (IsControlInput(input_str)) { @@ -189,7 +189,7 @@ Status InputRewriter::RewriteInput(absl::string_view input_str, TF_RETURN_IF_ERROR(RewriteNodeInput(input_str, new_input_str)); } input_map_.insert({input_str, *new_input_str}); - return OkStatus(); + return absl::OkStatus(); } Status InputRewriter::RewriteControlInput(absl::string_view input_str, @@ -201,7 +201,7 @@ Status InputRewriter::RewriteControlInput(absl::string_view input_str, } else { *new_input_str = string{input_str}; } - return OkStatus(); + return absl::OkStatus(); } Status InputRewriter::RewriteArgumentInput(absl::string_view input_str, @@ -233,7 +233,7 @@ Status InputRewriter::RewriteArgumentInput(absl::string_view input_str, original_function_.signature().input_arg_size() - num_captured_inputs_) { // Argument is a captured input. No need to modify argument string. *new_input_str = string{input_str}; - return OkStatus(); + return absl::OkStatus(); } const OpDef::ArgDef* found_arg_def = &original_function_.signature().input_arg(i); @@ -266,7 +266,7 @@ Status InputRewriter::RewriteNodeInput(absl::string_view input_str, components.size() == 3 ? components[2] : "0"; if (!IsInFirstFunction(node_name)) { *new_input_str = string{input_str}; - return OkStatus(); + return absl::OkStatus(); } auto index_iter = name_to_node_.find(node_name); @@ -352,7 +352,7 @@ Status InputRewriter::RewriteCrossFunctionInput( added_input_arg->set_description(absl::StrCat("Input ", input_index)); *new_input_str = added_input_arg->name(); - return OkStatus(); + return absl::OkStatus(); } void InitializeSignatures( @@ -400,7 +400,7 @@ void InitializeSignatures( } // namespace -StatusOr SplitFunction( +absl::StatusOr SplitFunction( const FunctionDef& function, const absl::flat_hash_set& nodes_in_first_function, int64_t num_captured_inputs, const FunctionLibraryDefinition& library) { diff --git a/tensorflow/core/grappler/optimizers/data/split_utils.h b/tensorflow/core/grappler/optimizers/data/split_utils.h index ac9951952fccac..df4c52b2176ad1 100644 --- a/tensorflow/core/grappler/optimizers/data/split_utils.h +++ b/tensorflow/core/grappler/optimizers/data/split_utils.h @@ -64,7 +64,7 @@ struct SplitResults { // Splitting functions in certain cases is unimplemented, in which case an // Unimplemented status will be returned. Grappler passes must gracefully handle // Unimplemented statuses without returning the error to its caller. -StatusOr SplitFunction( +absl::StatusOr SplitFunction( const FunctionDef& function, const absl::flat_hash_set& nodes_in_first_function, int64_t num_captured_inputs, const FunctionLibraryDefinition& library); diff --git a/tensorflow/core/grappler/optimizers/data/use_private_thread_pool.cc b/tensorflow/core/grappler/optimizers/data/use_private_thread_pool.cc index 4bf2e02635bd88..917c66939d8208 100644 --- a/tensorflow/core/grappler/optimizers/data/use_private_thread_pool.cc +++ b/tensorflow/core/grappler/optimizers/data/use_private_thread_pool.cc @@ -42,7 +42,8 @@ Status UsePrivateThreadPool::OptimizeAndCollectStats(Cluster* cluster, MutableGraphView graph(output); // If the GrapplerItem is derived from a FunctionDef, we don't optimize it. - if (graph_utils::IsItemDerivedFromFunctionDef(item, graph)) return OkStatus(); + if (graph_utils::IsItemDerivedFromFunctionDef(item, graph)) + return absl::OkStatus(); if (item.fetch.size() != 1) { return errors::InvalidArgument( @@ -54,7 +55,7 @@ Status UsePrivateThreadPool::OptimizeAndCollectStats(Cluster* cluster, if (node.op() == kPrivateThreadPoolDataset) { // If private thread pool is set by the user, we keep the user setting // instead of rewriting it. - return OkStatus(); + return absl::OkStatus(); } } @@ -89,14 +90,14 @@ Status UsePrivateThreadPool::OptimizeAndCollectStats(Cluster* cluster, // attrs from the input node. If we fail to set the attributes, we abort the // rewrite. if (!graph_utils::CopyShapesAndTypesAttrs(*last_node, &insert_node)) - return OkStatus(); + return absl::OkStatus(); auto* added_node = graph.AddNode(std::move(insert_node)); TF_RETURN_IF_ERROR( graph.UpdateFanouts(last_node->name(), added_node->name())); stats->num_changes++; - return OkStatus(); + return absl::OkStatus(); } REGISTER_GRAPH_OPTIMIZER_AS(UsePrivateThreadPool, "use_private_thread_pool"); diff --git a/tensorflow/core/grappler/optimizers/data/use_private_thread_pool.h b/tensorflow/core/grappler/optimizers/data/use_private_thread_pool.h index 8f2868af3852af..f515b3afb41371 100644 --- a/tensorflow/core/grappler/optimizers/data/use_private_thread_pool.h +++ b/tensorflow/core/grappler/optimizers/data/use_private_thread_pool.h @@ -33,7 +33,7 @@ class UsePrivateThreadPool : public TFDataOptimizerBase { Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { - return OkStatus(); + return absl::OkStatus(); } Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, diff --git a/tensorflow/core/grappler/optimizers/inference/batch_op_rewriter.cc b/tensorflow/core/grappler/optimizers/inference/batch_op_rewriter.cc index 822dbf5c2b4552..1a0cdbc1d2833c 100644 --- a/tensorflow/core/grappler/optimizers/inference/batch_op_rewriter.cc +++ b/tensorflow/core/grappler/optimizers/inference/batch_op_rewriter.cc @@ -152,7 +152,7 @@ Status BatchOpRewriter::Init( // (e.g., enable_adaptive_shared_batching_thread_pool is false), proto // is considered as empty. VLOG(2) << "Empty batch-op rewrite config"; - return OkStatus(); + return absl::OkStatus(); } if (!absl::Base64Unescape(params.s(), &unencoded)) { return absl::InternalError( @@ -163,7 +163,7 @@ Status BatchOpRewriter::Init( "Failed to parse batch_op_rewrite_config from params."); } VLOG(2) << "BatchOp Rewrite config is " << config_.DebugString(); - return OkStatus(); + return absl::OkStatus(); } Status BatchOpRewriter::Optimize(Cluster* cluster, const GrapplerItem& item, @@ -270,7 +270,7 @@ Status BatchOpRewriter::Optimize(Cluster* cluster, const GrapplerItem& item, } if (asbs_overridden) { - return OkStatus(); + return absl::OkStatus(); } if (config_.enable_adaptive_shared_batching_thread_pool()) { @@ -280,7 +280,7 @@ Status BatchOpRewriter::Optimize(Cluster* cluster, const GrapplerItem& item, batch_op); }); } - return OkStatus(); + return absl::OkStatus(); } REGISTER_GRAPH_OPTIMIZER_AS(BatchOpRewriter, "batch_op_rewrite"); diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.cc b/tensorflow/core/grappler/optimizers/loop_optimizer.cc index 148047a62553ed..94c2c22f472f19 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/loop_optimizer.cc @@ -111,7 +111,7 @@ Status LoopInvariantNodeMotionOptimizer::HandleInvariantEnter( } } } - return OkStatus(); + return absl::OkStatus(); } Status LoopInvariantNodeMotionOptimizer::HandleConst(NodeDef* node, @@ -183,7 +183,7 @@ Status LoopInvariantNodeMotionOptimizer::HandleConst(NodeDef* node, const_node->add_input(ctrl_dep); node_map_->AddOutput(NodeName(ctrl_dep), const_node->name()); } - return OkStatus(); + return absl::OkStatus(); } Status LoopInvariantNodeMotionOptimizer::HandleInvariantNode( @@ -197,7 +197,7 @@ Status LoopInvariantNodeMotionOptimizer::HandleInvariantNode( } } if (num_outputs == 0) { - return OkStatus(); + return absl::OkStatus(); } DataTypeVector input_types; @@ -252,7 +252,7 @@ Status LoopInvariantNodeMotionOptimizer::HandleInvariantNode( } } } - return OkStatus(); + return absl::OkStatus(); } Status LoopInvariantNodeMotionOptimizer::MoveInvariantNodes( @@ -270,7 +270,7 @@ Status LoopInvariantNodeMotionOptimizer::MoveInvariantNodes( HandleInvariantNode(invariant_node, num_outputs, frame_id)); } } - return OkStatus(); + return absl::OkStatus(); } Status LoopInvariantNodeMotionOptimizer::RevertInvariantNodes() { @@ -327,7 +327,7 @@ Status LoopInvariantNodeMotionOptimizer::RevertInvariantNodes() { } } } - return OkStatus(); + return absl::OkStatus(); } Status LoopInvariantNodeMotionOptimizer::FindInvariantNodes( @@ -376,7 +376,7 @@ Status LoopInvariantNodeMotionOptimizer::FindInvariantNodes( } } } - return OkStatus(); + return absl::OkStatus(); } Status LoopInvariantNodeMotionOptimizer::Optimize() { @@ -450,7 +450,7 @@ Status LoopInvariantNodeMotionOptimizer::Optimize() { TF_RETURN_IF_ERROR(MoveInvariantNodes(frame_id)); } - return OkStatus(); + return absl::OkStatus(); } std::vector GetStackPushNodesToConvert( @@ -544,7 +544,7 @@ Status RemoveStackOps(const std::unordered_set& nodes_to_preserve, } } } - return OkStatus(); + return absl::OkStatus(); } bool IsSimpleBinaryOperator(const NodeDef& node) { @@ -582,7 +582,7 @@ Status EvaluateBoolOpForConstantOperands(const NodeDef& op_node, *value = outputs[0].tensor->scalar()(); delete outputs[0].tensor; - return OkStatus(); + return absl::OkStatus(); } // TODO(lyandy): Consolidate with ConstantFolding implementation. @@ -614,7 +614,7 @@ Status CheckForDeadFanout(const MutableGraphView& view, CHECK(selector.FromProto(switch_predicate->attr().at("value").tensor())); *has_dead_fanout = true; *dead_fanout = selector.scalar()() ? 0 : 1; - return OkStatus(); + return absl::OkStatus(); } GraphView::InputPort switch_input_port(&switch_node, 0); @@ -625,7 +625,7 @@ Status CheckForDeadFanout(const MutableGraphView& view, // operator which returns false for the initialization value. // TODO(srjoglekar): Improve to work with arbitrary predicate subgraphs. if (!IsMerge(*switch_input) || !IsLoopCond(*switch_predicate)) { - return OkStatus(); + return absl::OkStatus(); } VLOG(4) << "Try to find a zero iteration while loop:" @@ -634,7 +634,7 @@ Status CheckForDeadFanout(const MutableGraphView& view, // Find the boolean predicate from a LoopCond node (e.g. Greater). NodeDef* switch_ctrl_node = view.GetRegularFanin({switch_predicate, 0}).node; if (!switch_ctrl_node || !IsSimpleBinaryOperator(*switch_ctrl_node)) { - return OkStatus(); + return absl::OkStatus(); } // Find the Merge node & the Constant Operand to the condition node, if @@ -656,7 +656,7 @@ Status CheckForDeadFanout(const MutableGraphView& view, } } if (merge_node == nullptr || constant_ctrl_input == nullptr) { - return OkStatus(); + return absl::OkStatus(); } // Find the initialization constant (via Enter, if one exists). @@ -672,7 +672,7 @@ Status CheckForDeadFanout(const MutableGraphView& view, } } if (enter_node != nullptr) { - if (constant_init_node != nullptr) return OkStatus(); + if (constant_init_node != nullptr) return absl::OkStatus(); for (const auto& input : enter_node->input()) { NodeDef* node = node_map.GetNode(input); if (IsReallyConstant(*node, feed_nodes)) { @@ -681,7 +681,7 @@ Status CheckForDeadFanout(const MutableGraphView& view, } } if (constant_init_node == nullptr) { - return OkStatus(); + return absl::OkStatus(); } VLOG(4) << "Check if loop will be 0 iterations:" @@ -712,7 +712,7 @@ Status CheckForDeadFanout(const MutableGraphView& view, } else { VLOG(4) << "Was not able to prove that loop has 0 iterations."; } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -756,7 +756,7 @@ Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, feed_nodes, optimized_graph)); } - return OkStatus(); + return absl::OkStatus(); } // An Identity node has only 1 output, but Switch and Merge nodes have 2. @@ -781,7 +781,7 @@ static Status update_identity_node_type(NodeDef* sw_node) { *(new_t.add_args()) = old_t.args()[0]; *(sw_node->mutable_experimental_type()) = new_t; } - return OkStatus(); + return absl::OkStatus(); } Status LoopOptimizer::RemoveDeadBranches( @@ -933,14 +933,14 @@ Status LoopOptimizer::RemoveDeadBranches( LOG(WARNING) << "Skipping loop optimization for Merge node with control input: " << merge_node->name(); - return OkStatus(); + return absl::OkStatus(); } else if (dead_inputs.size() != 1 || num_data_inputs != 2) { LOG(WARNING) << "Skipping loop optimization for Merge node (" << merge_node->name() << ") with unexpected dead_inputs.size() (" << dead_inputs.size() << " or num_data_inputs" << num_data_inputs; - return OkStatus(); + return absl::OkStatus(); } } @@ -1013,7 +1013,7 @@ Status LoopOptimizer::RemoveDeadBranches( } EraseNodesFromGraph(std::move(nodes_idx_to_delete), optimized_graph); - return OkStatus(); + return absl::OkStatus(); } } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc index c9cf8dbc9fb03e..08e9eb9d23815c 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc @@ -794,7 +794,7 @@ Status BuildSwapPair(NodeDef* node, int input_to_swap, (*swap_out_node->mutable_attr())["T"].set_type(input_type); *swap_pair = std::make_pair(swap_out_node, swap_in_node); - return OkStatus(); + return absl::OkStatus(); } struct SwapInfo { @@ -1313,7 +1313,7 @@ Status FindAssignNodesToRelax(const GraphDef& graph, } if (!found_send && devices.size() == 1) { nodes_to_relax->insert(assign_nodes.begin(), assign_nodes.end()); - return OkStatus(); + return absl::OkStatus(); } GraphTopologyView graph_view; @@ -1373,7 +1373,7 @@ Status FindAssignNodesToRelax(const GraphDef& graph, } } } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -1439,7 +1439,7 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, } optimized_graph->Swap(&optimized_item.graph); - return OkStatus(); + return absl::OkStatus(); } } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 999e7c0dc6d092..5fe1d762b2c517 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -183,7 +183,7 @@ Status GetGraphDevice(const GraphDef& g_def, std::set* devices) { } devices->insert(parsed_name.type); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -279,7 +279,7 @@ Status MetaOptimizer::InitializeOptimizers( const std::set& device_types, std::vector>* optimizers) const { if (cfg_.disable_meta_optimizer()) { - return OkStatus(); + return absl::OkStatus(); } ConfigList plugin_configs = PluginGraphOptimizerRegistry::GetPluginConfigs( @@ -574,13 +574,14 @@ Status MetaOptimizer::InitializeCustomGraphOptimizers( Status MetaOptimizer::InitializePluginGraphOptimizers( const std::set& device_types, std::vector>* optimizers) const { - if (cfg_.use_plugin_optimizers() == RewriterConfig::OFF) return OkStatus(); + if (cfg_.use_plugin_optimizers() == RewriterConfig::OFF) + return absl::OkStatus(); auto plugin_optimizers = PluginGraphOptimizerRegistry::CreateOptimizers(device_types); for (auto& plugin_optimizer : plugin_optimizers) { optimizers->push_back(std::move(plugin_optimizer)); } - return OkStatus(); + return absl::OkStatus(); } const RewriterConfig::CustomGraphOptimizer* @@ -758,7 +759,7 @@ Status MetaOptimizer::OptimizeGraph( VLOG(3) << "Skipping optimization, graph has less than " << min_graph_nodes << " nodes."; *optimized_graph = item.graph; - return OkStatus(); + return absl::OkStatus(); } tensorflow::metrics::ScopedCounter<2> timings( @@ -789,7 +790,7 @@ Status MetaOptimizer::OptimizeGraph( if (optimizers.empty()) { VLOG(3) << "Skipping graph optimization, no optimizers registered"; *optimized_graph = item.graph; - return OkStatus(); + return absl::OkStatus(); } // Invariant: optimized_graph contains the most recently optimized version of @@ -892,7 +893,7 @@ Status MetaOptimizer::OptimizeGraph( DCHECK_EQ(optimized_graph->versions().producer(), original_producer); } - return OkStatus(); + return absl::OkStatus(); } Status MetaOptimizer::OptimizeGraph(Cluster* cluster, GrapplerItem&& item, @@ -950,7 +951,7 @@ Status MetaOptimizer::RunOptimizer( message = strings::StrCat(optimizer->name(), " did nothing. time = ", duration_ms, "ms."); // Swallow the non-critical error. - status = OkStatus(); + status = absl::OkStatus(); } else if (absl::IsDeadlineExceeded(status)) { message = strings::StrCat(status.ToString(), ", time = ", duration_ms, "ms."); @@ -983,7 +984,7 @@ Status MetaOptimizer::RunOptimizer( if (absl::StartsWith(optimizer->name(), "tfg_optimizer")) return status; } - return OkStatus(); + return absl::OkStatus(); } // Propagates `_tf_data_function` attributes from functions to their callees. @@ -1311,7 +1312,7 @@ Status MetaOptimizer::OptimizeConsumeItem(Cluster* cluster, GrapplerItem&& item, *optimized_graph); } - return OkStatus(); + return absl::OkStatus(); } string MetaOptimizer::GetResultString() const { @@ -1379,7 +1380,7 @@ Status OptimizeGraph( const GrapplerItem::OptimizationOptions& optimization_options, std::unique_ptr* g) { if (!tensorflow::grappler::MetaOptimizerEnabled(config_proto)) { - return OkStatus(); + return absl::OkStatus(); } tensorflow::grappler::GrapplerItem item; @@ -1459,7 +1460,7 @@ Status OptimizeGraph( } *g = std::move(optimized_graph); - return OkStatus(); + return absl::OkStatus(); } } // namespace grappler diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc index a8c9aed523ad62..d1ad7e5a580c3c 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc @@ -53,14 +53,14 @@ class TestOptimizer : public CustomGraphOptimizer { Status Init(const tensorflow::RewriterConfig_CustomGraphOptimizer* config = nullptr) override { - return OkStatus(); + return absl::OkStatus(); } Status Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) override { optimized_ = true; *optimized_graph = item.graph; - return OkStatus(); + return absl::OkStatus(); } private: @@ -83,7 +83,7 @@ class TestOptimizerWithParams : public TestOptimizer { Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { CHECK(config != nullptr); - return OkStatus(); + return absl::OkStatus(); } }; @@ -107,7 +107,7 @@ class GrapplerItemPropertiesAccumulator : public CustomGraphOptimizer { Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { - return OkStatus(); + return absl::OkStatus(); } Status Optimize(Cluster* cluster, const GrapplerItem& item, @@ -116,7 +116,7 @@ class GrapplerItemPropertiesAccumulator : public CustomGraphOptimizer { if (optimization_options_) { optimization_options_->insert({item.id, item.optimization_options()}); } - return OkStatus(); + return absl::OkStatus(); } private: @@ -711,7 +711,7 @@ class SleepingOptimizer : public CustomGraphOptimizer { Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { - return OkStatus(); + return absl::OkStatus(); } Status Optimize(Cluster* cluster, const GrapplerItem& item, @@ -720,7 +720,7 @@ class SleepingOptimizer : public CustomGraphOptimizer { Env::Default()->SleepForMicroseconds(1000000); GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED(); optimized_graph->add_node(); - return OkStatus(); + return absl::OkStatus(); } }; @@ -1201,14 +1201,14 @@ class TfDataTestOptimizer : public CustomGraphOptimizer { Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { - return OkStatus(); + return absl::OkStatus(); } Status Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) override { ++count_; *optimized_graph = item.graph; - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc b/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc index e53086933deb29..56aa21a7ab09a3 100644 --- a/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc +++ b/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc @@ -39,6 +39,9 @@ class MklRemapperTest : public GrapplerTest { void FuseConv2DWithBiasAndAddNOrAdd(const string& data_format, const string& activation, string add_op, bool add_with_bcast) { +#ifdef DNNL_AARCH64_USE_ACL + GTEST_SKIP() << "Skipping test due to different behaviour on AARCH64"; +#endif if (!IsMKLEnabled()) GTEST_SKIP() << "Test only applicable to oneDNN."; using ::tensorflow::ops::Placeholder; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); @@ -216,6 +219,7 @@ CREATE_CONV2DFUSION_ADD_BCAST_TEST(AddV2); #undef CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST #undef CREATE_CONV2DFUSION_TEST +#ifndef DNNL_AARCH64_USE_ACL #define REGISTER_TEST(NAME, T, INPUT) \ TEST_F(MklRemapperTest, NAME##_##T) { \ if (!IsMKLEnabled()) GTEST_SKIP() << "Test only applicable to oneDNN."; \ @@ -310,6 +314,7 @@ CREATE_CONV2DFUSION_ADD_BCAST_TEST(AddV2); } REGISTER_TEST_ALL_TYPES(FuseDepthwiseConv2DWithBiasAndActivation); #undef REGISTER_TEST +#endif TEST_F(MklRemapperTest, FuseBatchNormWithRelu) { if (!IsMKLEnabled()) GTEST_SKIP() << "Test only applicable to oneDNN."; @@ -444,12 +449,15 @@ TEST_F(MklRemapperTest, FuseBatchNormWithRelu) { ASSERT_EQ(tensors_expected.size(), 1); auto tensors = EvaluateNodes(output, item.fetch, item.feed); ASSERT_EQ(tensors.size(), 1); - test::ExpectTensorNear(tensors[0], tensors_expected[0], 1e-6); + test::ExpectTensorNear(tensors[0], tensors_expected[0], 2e-5); } } } TEST_F(MklRemapperTest, FuseMatMulWithBiasAddAndAdd) { +#ifdef DNNL_AARCH64_USE_ACL + GTEST_SKIP() << "Skipping test due to different behaviour on AARCH64"; +#endif if (!IsMKLEnabled()) GTEST_SKIP() << "Test only applicable to oneDNN."; using ::tensorflow::ops::Placeholder; @@ -543,6 +551,9 @@ class RelpaceAddWithBiasAddTest : public GrapplerTest { protected: template void RelpaceAddWithBiasAddDepthwiseConv2D(const string& add_op) { +#ifdef DNNL_AARCH64_USE_ACL + GTEST_SKIP() << "Skipping test due to different behaviour on AARCH64"; +#endif if (!IsMKLEnabled()) GTEST_SKIP() << "Test only applicable to oneDNN."; using ::tensorflow::ops::Placeholder; @@ -1087,6 +1098,9 @@ class MklRemapperConv2dBiasAddSwishTest : public GrapplerTest { protected: template void RunTest() { +#ifdef DNNL_AARCH64_USE_ACL + GTEST_SKIP() << "Skipping test due to different behaviour on AARCH64"; +#endif if (!IsMKLEnabled()) GTEST_SKIP() << "Test only applicable to oneDNN."; if (!IsDataTypeSupportedByOneDNNOnThisCPU(DTYPE)) GTEST_SKIP() << "Intel oneDNN with " << DataType_Name(DTYPE) @@ -1172,6 +1186,9 @@ class MklRemapperConv2dFusedBatchNormSwishTest : public GrapplerTest { protected: template void RunTest() { +#ifdef DNNL_AARCH64_USE_ACL + GTEST_SKIP() << "Skipping test due to different behaviour on AARCH64"; +#endif if (!IsMKLEnabled()) GTEST_SKIP() << "Test only applicable to oneDNN."; using ::tensorflow::ops::Placeholder; @@ -1557,6 +1574,9 @@ class FusedConvBiasAddAndHardSwishTest : public GrapplerTest { template void RunTest(const string& add_op, const bool is_depthwise) { +#ifdef DNNL_AARCH64_USE_ACL + GTEST_SKIP() << "Skipping test due to different behaviour on AARCH64"; +#endif if (!IsMKLEnabled()) GTEST_SKIP() << "Test only applicable to oneDNN."; if (!IsDataTypeSupportedByOneDNNOnThisCPU(DType)) GTEST_SKIP() << "Intel oneDNN with " << DataType_Name(DType) diff --git a/tensorflow/core/grappler/optimizers/model_pruner.cc b/tensorflow/core/grappler/optimizers/model_pruner.cc index b8ed4c24c6c750..cbc630b4952b95 100644 --- a/tensorflow/core/grappler/optimizers/model_pruner.cc +++ b/tensorflow/core/grappler/optimizers/model_pruner.cc @@ -381,7 +381,7 @@ Status RewriteIdentityNAndInputsOutputs( } mutable_inputs->DeleteSubrange(curr_pos, num_inputs - curr_pos); - return OkStatus(); + return absl::OkStatus(); } Status SplitIdentityNInputs(GraphDef* graph, @@ -413,7 +413,7 @@ Status SplitIdentityNInputs(GraphDef* graph, *updated_graph = true; } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -504,7 +504,7 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, *optimized_graph->mutable_versions() = item.graph.versions(); if (nodes_to_delete.empty()) { optimized_graph->mutable_node()->Swap(pruned_graph->mutable_node()); - return OkStatus(); + return absl::OkStatus(); } const bool fetches_are_known = !item.fetch.empty(); @@ -526,7 +526,7 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, if (optimized_graph->node_size() > item.graph.node_size()) { return errors::Internal("Pruning increased graph size."); } - return OkStatus(); + return absl::OkStatus(); } } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc index 7b41ff2c4dc75b..0792f6f0b2dd06 100644 --- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc @@ -78,7 +78,7 @@ Status TryFindKernelDef(const std::vector& devices, if (kdef) { *kdef = kernel; } - return OkStatus(); + return absl::OkStatus(); } } @@ -95,7 +95,7 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph, // Make sure we are not a denylisted op. if (IsDenylisted(node)) { - return OkStatus(); + return absl::OkStatus(); } // Check to make sure we have the right properties (i.e., statically shaped). @@ -112,10 +112,10 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph, << " but output_properties.size()=" << output_properties.size() << "\n" << node.DebugString(); - return OkStatus(); + return absl::OkStatus(); } if (!IsTensorSmall(output_properties[port_id])) { - return OkStatus(); + return absl::OkStatus(); } // These nodes may be optimized away downstream (even if pinned to Host), we @@ -126,17 +126,17 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph, TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly( graph, properties, *fanin.node, fanin.port_id, &fanin_candidate)); if (!fanin_candidate) { - return OkStatus(); + return absl::OkStatus(); } } *is_candidate = true; - return OkStatus(); + return absl::OkStatus(); } // Check if op's device is on CPU. if (absl::StrContains(node.device(), DEVICE_CPU)) { *is_candidate = true; - return OkStatus(); + return absl::OkStatus(); } // Check if op's output port is pinned to HostMemory. @@ -144,7 +144,7 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph, Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op); if (!s.ok()) { LOG(WARNING) << "Could not find OpDef for : " << node.op(); - return OkStatus(); + return absl::OkStatus(); } // Map the port_id to output_arg_id. @@ -153,7 +153,7 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph, LOG(WARNING) << "Invalid port: " << port_id << "!\n" << node.DebugString() << "\n" << op->DebugString(); - return OkStatus(); + return absl::OkStatus(); } // Find the kernel. @@ -162,7 +162,7 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph, &kernel); if (!s.ok()) { LOG(INFO) << "Could not find KernelDef for: " << node.op(); - return OkStatus(); + return absl::OkStatus(); } // Check if the output_arg is pinned to Host. @@ -173,7 +173,7 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph, } } - return OkStatus(); + return absl::OkStatus(); } // Checks if a node's input port is Host friendly. @@ -225,18 +225,18 @@ Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties, // Check if node already on CPU. if (absl::StrContains(node.device(), DEVICE_CPU)) { *is_candidate = true; - return OkStatus(); + return absl::OkStatus(); } // Skip these node types. if (IsDenylisted(node)) { - return OkStatus(); + return absl::OkStatus(); } // Check the node can be run on CPU. Status s = TryFindKernelDef({DEVICE_CPU}, node, nullptr); if (!s.ok()) { - return OkStatus(); + return absl::OkStatus(); } // Check all inputs are Host friendly. @@ -246,7 +246,7 @@ Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties, TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly( graph, properties, *fanin.node, fanin.port_id, &fanin_candidate)); if (!fanin_candidate) { - return OkStatus(); + return absl::OkStatus(); } } @@ -259,12 +259,12 @@ Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties, } for (const auto& prop : properties->GetOutputProperties(node.name())) { if (!IsTensorSmall(prop)) { - return OkStatus(); + return absl::OkStatus(); } } *is_candidate = true; - return OkStatus(); + return absl::OkStatus(); } // Tries to find a Host device from `devices`. Returns empty string if no @@ -301,7 +301,7 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // Skip Legacy TPU bridge graphs. if (IsLegacyTPUBridgeGraphDef(*optimized_graph)) { - return OkStatus(); + return absl::OkStatus(); } GraphProperties properties(item); @@ -364,7 +364,7 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, } } } - return OkStatus(); + return absl::OkStatus(); } } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index b6d3d69af6cfc2..7fe058101d0a75 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -3275,7 +3275,7 @@ Status AddFusedContractionNode(RemapperContext* ctx, (*invalidated_nodes)[matched.bias_add] = true; (*nodes_to_delete)[matched.contraction] = true; - return OkStatus(); + return absl::OkStatus(); } Status AddFusedContractionNode(RemapperContext* ctx, @@ -3328,7 +3328,7 @@ Status AddFusedContractionNode(RemapperContext* ctx, (*nodes_to_delete)[matched.contraction] = true; (*invalidated_nodes)[matched.activation] = true; - return OkStatus(); + return absl::OkStatus(); } Status AddFusedContractionNode( @@ -3382,7 +3382,7 @@ Status AddFusedContractionNode( (*nodes_to_delete)[matched.bias_add] = true; (*invalidated_nodes)[matched.activation] = true; - return OkStatus(); + return absl::OkStatus(); } Status AddFusedConvNode(RemapperContext* ctx, @@ -3437,7 +3437,7 @@ Status AddFusedConvNode(RemapperContext* ctx, (*invalidated_nodes)[matched.bias_add] = true; (*nodes_to_delete)[matched.squeeze] = true; - return OkStatus(); + return absl::OkStatus(); } Status AddFusedConv2DNode(RemapperContext* ctx, @@ -3476,7 +3476,7 @@ Status AddFusedConv2DNode(RemapperContext* ctx, (*invalidated_nodes)[matched.fused_batch_norm] = true; (*nodes_to_delete)[matched.contraction] = true; - return OkStatus(); + return absl::OkStatus(); } Status AddFusedConv2DNode(RemapperContext* ctx, @@ -3521,7 +3521,7 @@ Status AddFusedConv2DNode(RemapperContext* ctx, (*nodes_to_delete)[matched.contraction] = true; (*nodes_to_delete)[matched.fused_batch_norm] = true; - return OkStatus(); + return absl::OkStatus(); } Status AddFusedContractionNode(RemapperContext* ctx, @@ -3575,7 +3575,7 @@ Status AddFusedContractionNode(RemapperContext* ctx, (*nodes_to_delete)[matched.contraction] = true; (*nodes_to_delete)[matched.bias_add] = true; - return OkStatus(); + return absl::OkStatus(); } Status AddFusedConv3DNode(RemapperContext* ctx, const PadWithConv3D& matched, @@ -3616,7 +3616,7 @@ Status AddFusedConv3DNode(RemapperContext* ctx, const PadWithConv3D& matched, } else { VLOG(2) << "Pad fusion with " << contraction.op() << " is invalidated, " << "it requires padding dim sizes to be constant."; - return OkStatus(); + return absl::OkStatus(); } utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); @@ -3627,7 +3627,7 @@ Status AddFusedConv3DNode(RemapperContext* ctx, const PadWithConv3D& matched, (*invalidated_nodes)[matched.contraction_idx] = true; (*nodes_to_delete)[matched.pad_idx] = true; - return OkStatus(); + return absl::OkStatus(); } Status AddFusedContractionNode( @@ -3674,7 +3674,7 @@ Status AddFusedContractionNode( (*nodes_to_delete)[matched.bias_add] = true; (*nodes_to_delete)[matched.contraction] = true; - return OkStatus(); + return absl::OkStatus(); } Status FuseContractionWithBiasAddAndHardSwish( @@ -3715,7 +3715,7 @@ Status FuseContractionWithBiasAddAndHardSwish( for (const auto& node_idx : *remove_node_indices) { (*nodes_to_delete)[node_idx] = true; } - return OkStatus(); + return absl::OkStatus(); } Status FuseConv2DSwish(RemapperContext* ctx, @@ -3768,7 +3768,7 @@ Status FuseConv2DSwish(RemapperContext* ctx, (*nodes_to_delete)[node_index] = true; } - return OkStatus(); + return absl::OkStatus(); } Status AddFusedMatMulBiasAddAndGelu( @@ -3811,7 +3811,7 @@ Status AddFusedMatMulBiasAddAndGelu( for (const auto& node_idx : remove_node_indices) { (*nodes_to_delete)[node_idx] = true; } - return OkStatus(); + return absl::OkStatus(); } Status AddMklLayerNorm(RemapperContext* ctx, @@ -3844,7 +3844,7 @@ Status AddMklLayerNorm(RemapperContext* ctx, for (const auto& node_idx : remove_node_indices) { (*nodes_to_delete)[node_idx] = true; } - return OkStatus(); + return absl::OkStatus(); } Status ReplaceMulMaximumWithLeakyRelu( @@ -3880,7 +3880,7 @@ Status ReplaceMulMaximumWithLeakyRelu( (*nodes_to_delete)[node_index] = true; } - return OkStatus(); + return absl::OkStatus(); } Status ReplaceSigmoidMulWithSwish( @@ -3912,7 +3912,7 @@ Status ReplaceSigmoidMulWithSwish( for (const auto& node_index : remove_node_indices) { (*nodes_to_delete)[node_index] = true; } - return OkStatus(); + return absl::OkStatus(); } Status AddFusedBatchNormExNode(RemapperContext* ctx, @@ -3982,7 +3982,7 @@ Status AddFusedBatchNormExNode(RemapperContext* ctx, (*nodes_to_delete)[matched.invalidated] = true; } - return OkStatus(); + return absl::OkStatus(); } Status AddFusedBatchNormGradExNode(RemapperContext* ctx, @@ -4055,7 +4055,7 @@ Status AddFusedBatchNormGradExNode(RemapperContext* ctx, (*nodes_to_delete)[matched.activation_grad] = true; } - return OkStatus(); + return absl::OkStatus(); } Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) { @@ -4288,7 +4288,7 @@ Status AddTensorToHashBucketNode(RemapperContext* ctx, (*invalidated_nodes)[matched.string_to_hash_bucket] = true; (*nodes_to_delete)[matched.as_string] = true; - return OkStatus(); + return absl::OkStatus(); } Status AddFusedBatchMatMul(RemapperContext* ctx, @@ -4321,7 +4321,7 @@ Status AddFusedBatchMatMul(RemapperContext* ctx, for (const auto& node_idx : remove_node_indices) { (*nodes_to_delete)[node_idx] = true; } - return OkStatus(); + return absl::OkStatus(); } // Helper function to get data of type T from a given tensor and @@ -4359,19 +4359,19 @@ Status AddMklFusedInstanceNorm(RemapperContext* ctx, if (!mean_axes_node || mean_axes_node->op() != "Const") { VLOG(2) << "Mean reduction axes node is not valid, abort fusion"; - return OkStatus(); + return absl::OkStatus(); } DataType dtype; Tensor mean_axes_tensor; if (!mean_axes_tensor.FromProto( mean_axes_node->attr().at("value").tensor())) { VLOG(2) << "Unable to get mean reduction axes, abort fusion"; - return OkStatus(); + return absl::OkStatus(); } dtype = mean_axes_tensor.dtype(); if (dtype != DT_INT32 && dtype != DT_INT64) { VLOG(2) << "Unexpected mean reduction axes data type, abort fusion"; - return OkStatus(); + return absl::OkStatus(); } std::vector reduction_axes = (dtype == DT_INT32) ? GetTensorValues(mean_axes_tensor) @@ -4383,11 +4383,11 @@ Status AddMklFusedInstanceNorm(RemapperContext* ctx, ctx->graph_view.GetNode(matched_nodes_map->at("activation"))->node(); if (!activation_node) { VLOG(2) << "Error to retrieve activation node, abort fusion"; - return OkStatus(); + return absl::OkStatus(); } if (!IsLeakyRelu(*activation_node) && !IsRelu(*activation_node)) { VLOG(2) << "Unsupported activation node, abort fusion"; - return OkStatus(); + return absl::OkStatus(); } } @@ -4445,7 +4445,7 @@ Status AddMklFusedInstanceNorm(RemapperContext* ctx, for (const auto& node_idx : *remove_node_indices) { (*nodes_to_delete)[node_idx] = true; } - return OkStatus(); + return absl::OkStatus(); } // This function supports below patterns that require inferred @@ -4617,7 +4617,7 @@ Status ReplaceSoftplusTanhAndMulWithMish( (*nodes_to_delete)[node_index] = true; } - return OkStatus(); + return absl::OkStatus(); } // Check if a node is a candidate to one of the patterns that require inferred @@ -5139,7 +5139,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item, *optimized_graph = std::move(mutable_item.graph); - return OkStatus(); + return absl::OkStatus(); } } // namespace grappler diff --git a/tensorflow/core/grappler/optimizers/remapper_test.cc b/tensorflow/core/grappler/optimizers/remapper_test.cc index cc315186474712..f5efc5b7d98bc7 100644 --- a/tensorflow/core/grappler/optimizers/remapper_test.cc +++ b/tensorflow/core/grappler/optimizers/remapper_test.cc @@ -1029,7 +1029,7 @@ class RemapperFuseConvWithBiasAndAddActivation : public RemapperTest { ASSERT_EQ(tensors_expected.size(), 1); auto tensors = EvaluateNodes(output, item.fetch, item.feed); ASSERT_EQ(tensors.size(), 1); - test::ExpectTensorNear(tensors[0], tensors_expected[0], 1e-6); + test::ExpectClose(tensors[0], tensors_expected[0], 0, 1e-6); } } }; @@ -2087,6 +2087,9 @@ TEST_F(RemapperFuseMatMulWithBiasAndActivationTest, Bf16) { } TEST_F(RemapperTest, FuseConv2DWithBatchNorm) { +#ifdef DNNL_AARCH64_USE_ACL + GTEST_SKIP() << "Skipping test due to different behaviour on AARCH64"; +#endif using ops::Placeholder; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); @@ -2165,6 +2168,9 @@ TEST_F(RemapperTest, FuseConv2DWithBatchNorm) { } TEST_F(RemapperTest, FuseConv2DWithBatchNormAndActivation) { +#ifdef DNNL_AARCH64_USE_ACL + GTEST_SKIP() << "Skipping test due to different behaviour on AARCH64"; +#endif using ops::Placeholder; for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) { @@ -2271,6 +2277,9 @@ TEST_F(RemapperTest, FuseConv2DWithBatchNormAndActivation) { #ifdef INTEL_MKL TEST_F(RemapperTest, FuseConv3DWithBiasAndAddN) { +#ifdef DNNL_AARCH64_USE_ACL + GTEST_SKIP() << "Skipping test due to different behaviour on AARCH64"; +#endif if (!IsMKLEnabled()) GTEST_SKIP() << "Test only applicable to oneDNN."; using ::tensorflow::ops::Placeholder; @@ -2339,10 +2348,13 @@ TEST_F(RemapperTest, FuseConv3DWithBiasAndAddN) { ASSERT_EQ(tensors_expected.size(), 1); auto tensors = EvaluateNodes(output, item.fetch, item.feed); ASSERT_EQ(tensors.size(), 1); - test::ExpectTensorNear(tensors[0], tensors_expected[0], 1e-6); + test::ExpectClose(tensors[0], tensors_expected[0], 0, 1e-6); } TEST_F(RemapperTest, FuseConv3DWithBiasAndAdd) { +#ifdef DNNL_AARCH64_USE_ACL + GTEST_SKIP() << "Skipping test due to different behaviour on AARCH64"; +#endif if (!IsMKLEnabled()) GTEST_SKIP() << "Test only applicable to oneDNN."; using ::tensorflow::ops::Placeholder; @@ -2410,7 +2422,7 @@ TEST_F(RemapperTest, FuseConv3DWithBiasAndAdd) { ASSERT_EQ(tensors_expected.size(), 1); auto tensors = EvaluateNodes(output, item.fetch, item.feed); ASSERT_EQ(tensors.size(), 1); - test::ExpectTensorNear(tensors[0], tensors_expected[0], 1e-6); + test::ExpectClose(tensors[0], tensors_expected[0], 0, 1e-6); } // Conv2D + Add {6,} + Conv2D + Biasadd fusion. diff --git a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc index 6cc45a31b7ebd1..0caf9d89e925a9 100644 --- a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc @@ -79,7 +79,7 @@ Status GetOutputDataType( " size of output_props ", output_props.size()); } *dtype = output_props[output_index].dtype(); - return OkStatus(); + return absl::OkStatus(); } // After shape inference has been done each op should be annotated @@ -133,7 +133,7 @@ Status CheckTypesAndGetShapes(const GraphProperties& graph_properties, VLOG(2) << "Adding shape " << props.shape().DebugString(); shapes->push_back(TensorShape(props.shape())); } - return OkStatus(); + return absl::OkStatus(); } // Describes an existing input edge in the graph. @@ -183,7 +183,7 @@ Status RemoveEdge(const string& input_edge_name, const string& from_node_name, node_map->RemoveOutput(from_node_name, to_node->name()); } inputs->DeleteSubrange(edge_index, 1); - return OkStatus(); + return absl::OkStatus(); } // In certain cases, we would like to insert an identity op between `input` and @@ -224,7 +224,7 @@ Status MaybeRewriteInput(ScopedAllocatorOptimizer* sa_opti, if (!(*rewrite)) { *new_input = input; *new_output_index = output_index; - return OkStatus(); + return absl::OkStatus(); } // Create new Identity op. @@ -249,7 +249,7 @@ Status MaybeRewriteInput(ScopedAllocatorOptimizer* sa_opti, VLOG(1) << "Rewrite input " << edge_name << " op " << op->name() << " old output index " << output_index << " with identity " << identity_name << " new output index 0"; - return OkStatus(); + return absl::OkStatus(); } // Populates *inputs with all of the non-control inputs of ops. @@ -310,7 +310,7 @@ Status GetInputs(ScopedAllocatorOptimizer* sa_opti, int64_t invocation_count, } inputs->emplace_back(inode, output_index, n); } - return OkStatus(); + return absl::OkStatus(); } // Return non-control inputs of `op` in `inputs`. @@ -333,7 +333,7 @@ Status GetDataInputs(GraphDef* graph, NodeMap* node_map, NodeDef* op, << output_index; inputs->emplace_back(inode, output_index, op); } - return OkStatus(); + return absl::OkStatus(); } void DumpGraphToVLOG(const GraphDef& graph, int log_level) { @@ -377,7 +377,7 @@ class UnaryElementwiseRewriter : public ScopedAllocatorOptimizer::Rewriter { " is a Const op which does not use AllocatorAttributes"); } } - return OkStatus(); + return absl::OkStatus(); } // Return non-OK if any input is already committed to a ScopedAllocator. @@ -403,7 +403,7 @@ class UnaryElementwiseRewriter : public ScopedAllocatorOptimizer::Rewriter { "assigned to scope_id ", scope_ids[1]); } } - return OkStatus(); + return absl::OkStatus(); } // Return non-OK if any input is a member of op_set. @@ -419,7 +419,7 @@ class UnaryElementwiseRewriter : public ScopedAllocatorOptimizer::Rewriter { } } } - return OkStatus(); + return absl::OkStatus(); } // Remove all control edges between members of ops. @@ -478,7 +478,7 @@ class UnaryElementwiseRewriter : public ScopedAllocatorOptimizer::Rewriter { int64_t num_elts = num_bytes / DataTypeSize(*dtype); VLOG(2) << "num_bytes " << num_bytes << " num_elts=" << num_elts; *sa_shape = TensorShape({num_elts}); - return OkStatus(); + return absl::OkStatus(); } // Returns the set of all nodes that are transitively reachable via data or @@ -505,7 +505,7 @@ class UnaryElementwiseRewriter : public ScopedAllocatorOptimizer::Rewriter { } } - return OkStatus(); + return absl::OkStatus(); } // Build the ScopedAllocator node that will be assigned to allocate @@ -594,7 +594,7 @@ class UnaryElementwiseRewriter : public ScopedAllocatorOptimizer::Rewriter { "ScopedAllocatorOptimizer and file a bug."; } - return OkStatus(); + return absl::OkStatus(); } Status BuildSAConcatNode(GraphDef* graph, NodeMap* node_map, @@ -657,7 +657,7 @@ class UnaryElementwiseRewriter : public ScopedAllocatorOptimizer::Rewriter { sac_node->add_input(ctl_edge); node_map->AddOutput(input_name, sac_node->name()); } - return OkStatus(); + return absl::OkStatus(); } Status BuildReplacementOp(GraphDef* graph, NodeMap* node_map, @@ -684,7 +684,7 @@ class UnaryElementwiseRewriter : public ScopedAllocatorOptimizer::Rewriter { LOG_WARNING_AND_RETURN_IF_ERROR(op_builder.Finalize(sa_op_node)); node_map->AddNode(sa_op_name, sa_op_node); node_map->AddOutput(sac_name, sa_op_name); - return OkStatus(); + return absl::OkStatus(); } Status BuildSplitNode(GraphDef* graph, NodeMap* node_map, @@ -713,7 +713,7 @@ class UnaryElementwiseRewriter : public ScopedAllocatorOptimizer::Rewriter { for (const auto& input : sas_inputs) { node_map->AddOutput(input.node, sas_name); } - return OkStatus(); + return absl::OkStatus(); } // After the new ScopedAllocator and its corresponding Concat and @@ -804,7 +804,7 @@ class UnaryElementwiseRewriter : public ScopedAllocatorOptimizer::Rewriter { // Remove. RemoveNode(old_op, graph, node_map); } - return OkStatus(); + return absl::OkStatus(); } // Given a collection of instances of op_name, presumed to be @@ -885,7 +885,7 @@ class UnaryElementwiseRewriter : public ScopedAllocatorOptimizer::Rewriter { op_name, sas_name)); *applied = true; - return OkStatus(); + return absl::OkStatus(); } }; @@ -933,7 +933,7 @@ Status ScopedAllocatorOptimizer::Optimize(Cluster* /*cluster*/, VLOG(1) << "ScopedAllocatorOptimizer::Optimize() done"; VLOG(3) << "Optimized graph:"; DumpGraphToVLOG(*optimized_graph, /*log_level=*/3); - return OkStatus(); + return absl::OkStatus(); } ScopedAllocatorOptimizer::Rewriter* ScopedAllocatorOptimizer::GetRewriter( @@ -958,7 +958,7 @@ Status ScopedAllocatorOptimizer::NewIdentityId(int* id) { if (next_identity_id_ < 0) { return errors::Aborted("NewIdentityId overflow"); } - return OkStatus(); + return absl::OkStatus(); } ScopedAllocatorOptimizer::~ScopedAllocatorOptimizer() { @@ -1115,7 +1115,7 @@ Status ScopedAllocatorOptimizer::ProcessGraphDef( absl::flat_hash_set seen_outputs; status = ApplyToAll(root.get(), [this, &seen_outputs](Tree* t) { IdentifyRepeatedInputs(t->nodes_, &seen_outputs, &repeated_outputs_); - return OkStatus(); + return absl::OkStatus(); }); if (!status.ok()) { break; @@ -1142,7 +1142,7 @@ Status ScopedAllocatorOptimizer::ProcessGraphDef( } } } - return OkStatus(); + return absl::OkStatus(); }); if (!status.ok()) { break; @@ -1198,13 +1198,13 @@ Status ScopedAllocatorOptimizer::OrderNodeSet( // Nodes should be identical type. Default order is by name but for // collectives we order by increasing instance_key so each group gets // the same instance_key. - if (nodes->size() <= 1) return OkStatus(); + if (nodes->size() <= 1) return absl::OkStatus(); if (IsCollectiveNode(*nodes->at(0))) { std::sort(nodes->begin(), nodes->end(), InstanceKeyLess()); } else { std::sort(nodes->begin(), nodes->end(), NameLess()); } - return OkStatus(); + return absl::OkStatus(); } } // namespace grappler diff --git a/tensorflow/core/grappler/optimizers/shape_optimizer.cc b/tensorflow/core/grappler/optimizers/shape_optimizer.cc index c0053a49a5f6c8..e6f7d069860adb 100644 --- a/tensorflow/core/grappler/optimizers/shape_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/shape_optimizer.cc @@ -191,7 +191,7 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, } } } - return OkStatus(); + return absl::OkStatus(); } } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/static_schedule.cc b/tensorflow/core/grappler/optimizers/static_schedule.cc index 5eff0db3c9a663..0f7bc13140e416 100644 --- a/tensorflow/core/grappler/optimizers/static_schedule.cc +++ b/tensorflow/core/grappler/optimizers/static_schedule.cc @@ -127,7 +127,7 @@ Status EstimateEarliestExecutionTimes( } } - return OkStatus(); + return absl::OkStatus(); } Status EstimateRequiredTimes( @@ -196,7 +196,7 @@ Status EstimateRequiredTimes( } } - return OkStatus(); + return absl::OkStatus(); } } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/tfg_optimizer_hook.cc b/tensorflow/core/grappler/optimizers/tfg_optimizer_hook.cc index df6fcc04c133bd..12c494c6d0456c 100644 --- a/tensorflow/core/grappler/optimizers/tfg_optimizer_hook.cc +++ b/tensorflow/core/grappler/optimizers/tfg_optimizer_hook.cc @@ -172,7 +172,7 @@ Status TFGGrapplerOptimizer::Optimize( module.dump(); } - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // end namespace tfg diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index 7c51706dcaaed3..4610011c3e6a0d 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -454,7 +454,7 @@ Status SetTensorValue(DataType dtype, int value, Tensor* tensor) { return errors::InvalidArgument("Unsupported type ", DataTypeString(dtype)); } - return OkStatus(); + return absl::OkStatus(); } #undef HANDLE_CASE @@ -464,14 +464,14 @@ Status CheckAttrExists(const NodeDef& node, const string& key) { return errors::InvalidArgument("Node '", node.name(), "' lacks '", key, "' attr: ", node.ShortDebugString()); } - return OkStatus(); + return absl::OkStatus(); } Status CheckAttrsExist(const NodeDef& node, absl::Span keys) { for (const string& key : keys) { TF_RETURN_IF_ERROR(CheckAttrExists(node, key)); } - return OkStatus(); + return absl::OkStatus(); } Status IsKernelRegisteredForNode( diff --git a/tensorflow/core/grappler/utils/frame.cc b/tensorflow/core/grappler/utils/frame.cc index f7e114ff0ff705..280b04a71889d5 100644 --- a/tensorflow/core/grappler/utils/frame.cc +++ b/tensorflow/core/grappler/utils/frame.cc @@ -111,7 +111,7 @@ inline Status FrameView::InferFromGraphViewT(const GraphViewT& graph_view) { " does not match frame ids for it's fanout ", fanout_node->name()); } } - return OkStatus(); + return absl::OkStatus(); }; while (!ready_node_indices.empty()) { @@ -138,7 +138,7 @@ inline Status FrameView::InferFromGraphViewT(const GraphViewT& graph_view) { } num_frames_ = static_cast(frame_name_to_id.size()); - return OkStatus(); + return absl::OkStatus(); } Status FrameView::InferFromGraphView(const utils::GraphView& graph_view) { diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc index eb1493f968173d..c2d848aaa67ae6 100644 --- a/tensorflow/core/grappler/utils/functions.cc +++ b/tensorflow/core/grappler/utils/functions.cc @@ -180,7 +180,7 @@ Status InstantiationTypeParameters( ++index; } } - return OkStatus(); + return absl::OkStatus(); }; for (const auto& input : func.signature().input_arg()) @@ -188,7 +188,7 @@ Status InstantiationTypeParameters( for (const auto& output : func.signature().output_arg()) TF_RETURN_IF_ERROR(resolve_type_attr(output)); - return OkStatus(); + return absl::OkStatus(); } Status InstantiationBodyParameters( @@ -218,7 +218,7 @@ Status InstantiationBodyParameters( } } - return OkStatus(); + return absl::OkStatus(); } Status MakeGrapplerFunctionItem(const FunctionDef& func, @@ -308,7 +308,7 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func, /*func_attr=*/AttrSlice(&func.attr()), std::move(arg_attr), std::move(inputs), std::move(outputs), std::move(control_outputs), graph_def_version, signature.is_stateful(), std::move(function_body)); - return OkStatus(); + return absl::OkStatus(); } Status MakeGrapplerFunctionItem(const FunctionDef& func, @@ -357,7 +357,7 @@ Status ReplaceInputWithConst(const NodeDef& input_const, int input_index, item->input_args_.erase(item->input_args_.begin() + input_index); item->arg_attr_.erase(item->arg_attr_.begin() + input_index); - return OkStatus(); + return absl::OkStatus(); } Status RemoveFunctionOutputs(const absl::flat_hash_set& remove_outputs, @@ -412,7 +412,7 @@ Status RemoveFunctionOutputs(const absl::flat_hash_set& remove_outputs, auto& o = item->output_args_; o.erase(std::remove_if(o.begin(), o.end(), is_remove_output_arg), o.end()); - return OkStatus(); + return absl::OkStatus(); } namespace { @@ -472,14 +472,14 @@ Status MakeFunctionDefHelper::Initialize( function_body_outputs_.emplace(node.name(), std::move(outputs_range_map)); } - return OkStatus(); + return absl::OkStatus(); } Status MakeFunctionDefHelper::AsFunctionDefInput(const string& graph_def_input, string* func_def_input) const { if (IsControlInput(graph_def_input)) { *func_def_input = graph_def_input; - return OkStatus(); + return absl::OkStatus(); } const SafeTensorId tensor = ParseTensorName(graph_def_input); @@ -490,7 +490,7 @@ Status MakeFunctionDefHelper::AsFunctionDefInput(const string& graph_def_input, if (is_input != input_nodes_.end()) { DCHECK_EQ(tensor.index(), 0); *func_def_input = tensor.node(); - return OkStatus(); + return absl::OkStatus(); } // Or it must be output from one of the function body nodes @@ -505,7 +505,7 @@ Status MakeFunctionDefHelper::AsFunctionDefInput(const string& graph_def_input, tensor.index() < output_range.second) { *func_def_input = absl::StrCat(tensor.node(), ":", output_name, ":", tensor.index() - output_range.first); - return OkStatus(); + return absl::OkStatus(); } } } @@ -524,7 +524,7 @@ Status MakeFunctionDefHelper::AsFunctionDefNode( function_body_node->set_input(i, func_def_input); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -617,7 +617,7 @@ Status MakeFunctionDef(const GrapplerFunctionItem& item, TF_RETURN_IF_ERROR(helper.AsFunctionDefNode(func_def_node)); } - return OkStatus(); + return absl::OkStatus(); } } // end namespace grappler diff --git a/tensorflow/core/grappler/utils/graph_view.cc b/tensorflow/core/grappler/utils/graph_view.cc index 1d26a4a6bfe38c..aeb947808488e8 100644 --- a/tensorflow/core/grappler/utils/graph_view.cc +++ b/tensorflow/core/grappler/utils/graph_view.cc @@ -104,7 +104,7 @@ GraphView::GraphView(const GraphDef* graph, Status* status) return; } } - *status = OkStatus(); + *status = absl::OkStatus(); } bool GraphView::AddUniqueNodeInternal(const NodeDef* node) { @@ -165,7 +165,7 @@ Status GraphView::CheckAndAddFaninsInternal(NodeView* node_view) { node_view->fanins_set_.emplace(fanin_node_view.node(), fanin_id.index()); } } - return OkStatus(); + return absl::OkStatus(); } MutableFaninView::MutableFaninView(MutableNodeView* node_view, int index) @@ -257,7 +257,7 @@ MutationNewNode Mutation::AddNode(NodeDef&& node, Status* status) { mutation_node.regular_fanins = std::move(regular_fanins); mutation_node.num_regular_fanins = mutation_node.regular_fanins.size(); mutation_node.controlling_fanins = std::move(controlling_fanins); - *status = OkStatus(); + *status = absl::OkStatus(); return MutationNewNode(this, mutation_counter_, new_nodes_.size() - 1); } @@ -484,7 +484,7 @@ MutableGraphView::MutableGraphView(GraphDef* graph, Status* status) } AddFaninsInternal(&fanins); mutation_.ResetInternal(); - *status = OkStatus(); + *status = absl::OkStatus(); } Mutation* MutableGraphView::GetMutationBuilder() { return &mutation_; } @@ -534,7 +534,7 @@ Status MutableGraphView::CheckFaninsInternal( } fanins->push_back(std::move(node_fanins)); } - return OkStatus(); + return absl::OkStatus(); } void MutableGraphView::AddFaninsInternal( @@ -686,7 +686,7 @@ Status MutableGraphView::GetNodeNamesAndPartitionUpdatedNodes( } } - return OkStatus(); + return absl::OkStatus(); } Status MutableGraphView::RemovedOrMissingNodeFanoutsWellFormed( @@ -764,7 +764,7 @@ Status MutableGraphView::RemovedOrMissingNodeFanoutsWellFormed( } } - return OkStatus(); + return absl::OkStatus(); } Status MutableGraphView::CheckNodeNamesAndFanins( @@ -801,7 +801,7 @@ Status MutableGraphView::CheckNodeNamesAndFanins( } } - return OkStatus(); + return absl::OkStatus(); } Status MutableGraphView::CheckKernelRegisteredForNodes() { @@ -849,7 +849,7 @@ Status MutableGraphView::CheckKernelRegisteredForNodes() { LOG(WARNING) << s.message(); } } - return OkStatus(); + return absl::OkStatus(); } template @@ -1617,7 +1617,7 @@ Status MutableGraphView::SortTopologically( // Permute graph NodeDefs. PermuteNodesInPlace(graph_, &order, /*invert_permutation=*/false); - return OkStatus(); + return absl::OkStatus(); } inline Status MutableGraphView::ValidateInternal( @@ -1637,7 +1637,7 @@ inline Status MutableGraphView::ValidateInternal( // Check if nodes after mutation have kernels registered. TF_RETURN_IF_ERROR(CheckKernelRegisteredForNodes()); - return OkStatus(); + return absl::OkStatus(); } Status MutableGraphView::ApplyMutationInternal() { @@ -1701,7 +1701,7 @@ Status MutableGraphView::ApplyMutationInternal() { mutation_.mutation_counter_++; - return OkStatus(); + return absl::OkStatus(); } } // namespace utils diff --git a/tensorflow/core/grappler/utils/topological_sort.cc b/tensorflow/core/grappler/utils/topological_sort.cc index 4e10be8d16d03f..29e00240028715 100644 --- a/tensorflow/core/grappler/utils/topological_sort.cc +++ b/tensorflow/core/grappler/utils/topological_sort.cc @@ -105,7 +105,7 @@ Status ComputeTopologicalOrder( return errors::InvalidArgument( "The graph couldn't be sorted in topological order."); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -123,7 +123,7 @@ Status ComputeTopologicalOrder( topo_order->emplace_back(&graph.node(ready_node_idx)); } - return OkStatus(); + return absl::OkStatus(); } Status ComputeTopologicalOrder(const GraphDef& graph, @@ -136,14 +136,14 @@ Status ReversedTopologicalSort(GraphDef* graph) { TF_RETURN_IF_ERROR(ComputeTopologicalOrder(*graph, {}, &ready_nodes)); std::reverse(ready_nodes.begin(), ready_nodes.end()); PermuteNodesInPlace(graph, &ready_nodes, /*invert_permutation=*/true); - return OkStatus(); + return absl::OkStatus(); } Status TopologicalSort(GraphDef* graph) { std::vector ready_nodes; TF_RETURN_IF_ERROR(ComputeTopologicalOrder(*graph, {}, &ready_nodes)); PermuteNodesInPlace(graph, &ready_nodes, /*invert_permutation=*/true); - return OkStatus(); + return absl::OkStatus(); } } // namespace grappler diff --git a/tensorflow/core/grappler/utils/transitive_fanin.cc b/tensorflow/core/grappler/utils/transitive_fanin.cc index bc122674664190..2b2c0d8e6662d8 100644 --- a/tensorflow/core/grappler/utils/transitive_fanin.cc +++ b/tensorflow/core/grappler/utils/transitive_fanin.cc @@ -82,7 +82,7 @@ Status ComputeTransitiveFanin( // So, we do not set ill_formed for missing _Send. } } - return OkStatus(); + return absl::OkStatus(); } Status ComputeTransitiveFanin(const GraphDef& graph, @@ -106,7 +106,7 @@ Status SetTransitiveFaninGraph(const GraphDef& input_graph, *output_graph->add_node() = *keep[i]; } - return OkStatus(); + return absl::OkStatus(); } } // namespace grappler diff --git a/tensorflow/core/grappler/verifiers/structure_verifier_test.cc b/tensorflow/core/grappler/verifiers/structure_verifier_test.cc index bea5de0c6c6fdb..eac350e2882554 100644 --- a/tensorflow/core/grappler/verifiers/structure_verifier_test.cc +++ b/tensorflow/core/grappler/verifiers/structure_verifier_test.cc @@ -46,7 +46,7 @@ Status Scalars(shape_inference::InferenceContext* c) { for (int i = 0; i < c->num_outputs(); ++i) { c->set_output(i, c->Scalar()); } - return OkStatus(); + return absl::OkStatus(); } REGISTER_OP("TestParams").Output("o: float").SetShapeFn(Scalars); diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 5ce78756bf3cd0..3f3687c50224f4 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1601,9 +1601,9 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core/framework:bounds_check", "//tensorflow/core/platform:stream_executor", "//tensorflow/core/profiler/lib:scoped_annotation", + "@com_google_absl//absl/strings", "@eigen_archive//:eigen3", ], ) @@ -3877,7 +3877,7 @@ tf_cuda_cc_test( name = "matmul_op_test", srcs = ["matmul_op_test.cc"], tags = [ - "no_arm64", # b/282068262 + "no_aarch64", # b/282068262 ], deps = [ ":matmul_op", diff --git a/tensorflow/core/kernels/barrier_ops.cc b/tensorflow/core/kernels/barrier_ops.cc index 9c144faf997cb3..91b1de594b121a 100644 --- a/tensorflow/core/kernels/barrier_ops.cc +++ b/tensorflow/core/kernels/barrier_ops.cc @@ -396,7 +396,7 @@ class Barrier : public ResourceBase { TF_RETURN_IF_ERROR(ready_queue_->ValidateTuple(ready_tuple)); ready_tuples->push_back(ready_tuple); } - return OkStatus(); + return absl::OkStatus(); } void CloseQueueLocked(OpKernelContext* ctx, bool cancel_pending_enqueues, @@ -485,7 +485,7 @@ class BarrierOp : public ResourceOpKernel { " but requested component shapes were ", TensorShapeUtils::ShapeListString(value_component_shapes_)); } - return OkStatus(); + return absl::OkStatus(); } DataTypeVector value_component_types_; diff --git a/tensorflow/core/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc index 8862e5b0c98f58..bd2d7880d8520e 100644 --- a/tensorflow/core/kernels/batch_kernels.cc +++ b/tensorflow/core/kernels/batch_kernels.cc @@ -200,7 +200,7 @@ class BatchResource : public serving::BatchResourceBase { low_priority_max_enqueued_batches, low_priority_allowed_batch_sizes), allowed_batch_sizes)); - return OkStatus(); + return absl::OkStatus(); } static Status Create( @@ -221,7 +221,7 @@ class BatchResource : public serving::BatchResourceBase { /*enable_large_batch_splitting=*/true, allowed_batch_sizes, /*disable_padding=*/false), allowed_batch_sizes)); - return OkStatus(); + return absl::OkStatus(); } string DebugString() const final { return "BatchResource"; } @@ -411,7 +411,7 @@ void BatchFunctionKernel::ComputeAsync(OpKernelContext* c, DoneCallback done) { new_resource->set_session_metadata(*session_metadata); } *r = new_resource.release(); - return OkStatus(); + return absl::OkStatus(); }; } else { creator = [this, @@ -428,7 +428,7 @@ void BatchFunctionKernel::ComputeAsync(OpKernelContext* c, DoneCallback done) { new_resource->set_session_metadata(*session_metadata); } *r = new_resource.release(); - return OkStatus(); + return absl::OkStatus(); }; } @@ -438,8 +438,9 @@ void BatchFunctionKernel::ComputeAsync(OpKernelContext* c, DoneCallback done) { container_, shared_name_, &br, creator), done); const uint64_t guid = random::New64(); - auto create_batch_task_fn = [handle]() - -> StatusOr> { + auto create_batch_task_fn = + [handle]() -> absl::StatusOr< + std::unique_ptr> { return {std::make_unique(handle)}; }; Status status; @@ -525,7 +526,7 @@ Status BatchFunctionKernel::GetOrCreateFunctionHandle( } else { *handle = fhandle_.value(); } - return OkStatus(); + return absl::OkStatus(); } // Validates 'allowed_batch_sizes_'. The entries must increase monotonically. @@ -534,7 +535,7 @@ Status BatchFunctionKernel::GetOrCreateFunctionHandle( // to `max_batch_size_`. Status BatchFunctionKernel::ValidateAllowedBatchSizes() const { if (allowed_batch_sizes_.empty()) { - return OkStatus(); + return absl::OkStatus(); } int32_t last_size = 0; for (size_t i = 0; i < allowed_batch_sizes_.size(); ++i) { @@ -553,7 +554,7 @@ Status BatchFunctionKernel::ValidateAllowedBatchSizes() const { last_size = size; } - return OkStatus(); + return absl::OkStatus(); } // Initialize vars by reading from op-kernel-construction. @@ -667,7 +668,7 @@ class BatchKernel : public AsyncOpKernel { max_batch_size_, batch_timeout_micros_, max_enqueued_batches_, allowed_batch_sizes_, false, &new_resource)); *r = new_resource.release(); - return OkStatus(); + return absl::OkStatus(); }; OP_REQUIRES_OK_ASYNC(c, c->resource_manager()->LookupOrCreate( @@ -675,7 +676,7 @@ class BatchKernel : public AsyncOpKernel { done); const Status status = br->RegisterInput( random::New64(), c, batcher_queue_, - []() -> StatusOr< + []() -> absl::StatusOr< std::unique_ptr> { return {std::make_unique(kInvalidHandle)}; }, @@ -689,7 +690,7 @@ class BatchKernel : public AsyncOpKernel { // monotonically, and the last one must equal 'max_batch_size_'. Status ValidateAllowedBatchSizes() const { if (allowed_batch_sizes_.empty()) { - return OkStatus(); + return absl::OkStatus(); } int32_t last_size = 0; for (size_t i = 0; i < allowed_batch_sizes_.size(); ++i) { @@ -704,7 +705,7 @@ class BatchKernel : public AsyncOpKernel { } last_size = size; } - return OkStatus(); + return absl::OkStatus(); } private: @@ -798,7 +799,7 @@ class UnbatchResource : public ResourceBase { context->set_output(0, tensor_it->second.tensor); waiting_tensors_.erase(tensor_it); done_callbacks_to_call.push_back(done); - return OkStatus(); + return absl::OkStatus(); } const uint64 deadline_micros = @@ -837,7 +838,7 @@ class UnbatchResource : public ResourceBase { } } - return OkStatus(); + return absl::OkStatus(); }(); for (const AsyncOpKernel::DoneCallback& done_callback : @@ -930,7 +931,7 @@ class UnbatchKernel : public AsyncOpKernel { std::function creator = [this](UnbatchResource** r) { *r = new UnbatchResource(timeout_micros_); - return OkStatus(); + return absl::OkStatus(); }; OP_REQUIRES_OK_ASYNC(c, c->resource_manager()->LookupOrCreate( @@ -989,7 +990,7 @@ class UnbatchGradResource : public ResourceBase { return errors::InvalidArgument("Unsupported data type: ", type); } done(); - return OkStatus(); + return absl::OkStatus(); } // Ingests data from one invocation of the op. @@ -1076,7 +1077,7 @@ class UnbatchGradResource : public ResourceBase { available_batches_.erase(batch_it); } } - return OkStatus(); + return absl::OkStatus(); } private: @@ -1126,7 +1127,7 @@ class UnbatchGradKernel : public AsyncOpKernel { std::function creator = [](UnbatchGradResource** r) { *r = new UnbatchGradResource(); - return OkStatus(); + return absl::OkStatus(); }; OP_REQUIRES_OK_ASYNC(c, c->resource_manager()->LookupOrCreate( diff --git a/tensorflow/core/kernels/batch_kernels_test.cc b/tensorflow/core/kernels/batch_kernels_test.cc index af7546a062169d..62d5ef22ec2484 100644 --- a/tensorflow/core/kernels/batch_kernels_test.cc +++ b/tensorflow/core/kernels/batch_kernels_test.cc @@ -23,19 +23,24 @@ limitations under the License. #include #include "absl/strings/match.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/framework/device_factory.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/batch_kernel_test_util.h" #include "tensorflow/core/kernels/batching_util/warmup.h" #include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/public/version.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/blocking_counter.h" #include "tsl/platform/errors.h" +#include "tsl/platform/refcount.h" #include "tsl/platform/status.h" namespace tensorflow { @@ -120,7 +125,7 @@ class BatchFunctionKernelParallelWarmupTestState : public OpsTestBase { tsl::core::RefCountPtr *r) { *r = tsl::core::RefCountPtr( new IntraProcessRendezvous(device_mgr)); - return OkStatus(); + return absl::OkStatus(); }}); std::vector inputs( @@ -152,6 +157,77 @@ class BatchFunctionKernelParallelWarmupTestState : public OpsTestBase { class BatchFunctionKernelParallelWarmupTest : public ::testing::TestWithParam {}; +TEST_P(BatchFunctionKernelParallelWarmupTest, HandlesLargeBatchSplitting) { + // This test fails if it does not come before the others in the suite, + // because `SharedBatchScheduler::QueueOptions::input_batch_size_limit` + // does not get reset. + SessionMetadata session_metadata; + session_metadata.set_name("test_model"); + session_metadata.set_version(123); + serving::WarmupStateRegistry::Key key(session_metadata.name(), + session_metadata.version()); + + int num_requests = 16; + + { + auto per_model_data = std::make_unique(); + per_model_data->warmup_all_batch_sizes = true; + auto handle = serving::GetGlobalWarmupStateRegistry().Register( + key, std::move(per_model_data)); + + tsl::BlockingCounter blocking_counter(num_requests); + for (int i = 0; i < num_requests; ++i) { + Env::Default()->SchedClosure([&]() { + BatchFunctionKernelParallelWarmupTestState test; + test.set_session_metadata(session_metadata); + TF_CHECK_OK(test.Init(/*enable_splitting=*/true, + /*check_output_shape=*/true)); + test.AddInputFromList( + TensorShape({16}), + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + auto status = test.RunOpKernel(); + ASSERT_FALSE(status.ok()); + // This proves the kernel is executed with batch sizes other than 2. + EXPECT_TRUE(absl::StrContains(status.message(), + "is not compatible with expected shape")); + blocking_counter.DecrementCount(); + }); + } + blocking_counter.Wait(); + } + + { + EXPECT_FALSE(serving::GetGlobalWarmupStateRegistry().Lookup(key)); + auto per_model_data = std::make_unique(); + per_model_data->warmup_all_batch_sizes = true; + auto handle = serving::GetGlobalWarmupStateRegistry().Register( + key, std::move(per_model_data)); + + tsl::BlockingCounter blocking_counter(num_requests); + for (int i = 0; i < num_requests; ++i) { + Env::Default()->SchedClosure([&]() { + BatchFunctionKernelParallelWarmupTestState test; + test.set_session_metadata(session_metadata); + // Error free when the EnsureShapeOp is replaced with an Identity op. + TF_CHECK_OK( + test.Init(/*enable_splitting=*/true, /*check_output_shape=*/false)); + test.AddInputFromList( + TensorShape({16}), + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + TF_CHECK_OK(test.RunOpKernel()); + + test::ExpectTensorEqual( + *test.GetOutput(0), + test::AsTensor( + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15})); + + blocking_counter.DecrementCount(); + }); + } + blocking_counter.Wait(); + } +} + TEST_P(BatchFunctionKernelParallelWarmupTest, ParallelWarmup) { SessionMetadata session_metadata; session_metadata.set_name("test_model"); diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD index 8b8bb863e78aad..26b33b2a71e781 100644 --- a/tensorflow/core/kernels/batching_util/BUILD +++ b/tensorflow/core/kernels/batching_util/BUILD @@ -1,7 +1,7 @@ # Description: Utilities. -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -347,6 +347,7 @@ cc_library( ":adaptive_shared_batch_scheduler", ":batch_scheduler", ":concat_split_util", + ":input_split_metadata", ":shared_batch_scheduler", ":threadsafe_status", ":warmup", @@ -360,18 +361,22 @@ cc_library( "//tensorflow/core/common_runtime:request_cost", "//tensorflow/core/common_runtime:request_cost_accessor", "//tensorflow/core/common_runtime:request_cost_accessor_registry", + "//tensorflow/core/platform:errors", "//tensorflow/core/platform:status", "//tensorflow/core/platform:thread_annotations", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/profiler/lib:traceme_encode", "//tensorflow/core/protobuf:for_core_protos_cc", "//tensorflow/core/util:incremental_barrier", + "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", "@local_tsl//tsl/platform:criticality", + "@local_tsl//tsl/platform:logging", ], ) diff --git a/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h b/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h index b9ca84b3bef2b6..28e8f4396fd668 100644 --- a/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h +++ b/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h @@ -424,7 +424,7 @@ Status AdaptiveSharedBatchScheduler::Create( options.batches_to_average_over); } scheduler->reset(new AdaptiveSharedBatchScheduler(options)); - return OkStatus(); + return absl::OkStatus(); } template @@ -472,7 +472,7 @@ Status AdaptiveSharedBatchScheduler::AddQueue( this->shared_from_this(), options)); mutex_lock l(mu_); queues_and_callbacks_[asbs_queue_raw] = process_batch_callback; - return OkStatus(); + return absl::OkStatus(); } template @@ -822,7 +822,7 @@ Status ASBSQueue::Schedule(std::unique_ptr* task) { if (closed_batch) { scheduler_->MaybeScheduleClosedBatches(); } - return OkStatus(); + return absl::OkStatus(); } template diff --git a/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler_test.cc b/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler_test.cc index 1cfeb68b028de2..8858390da400fc 100644 --- a/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler_test.cc +++ b/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler_test.cc @@ -470,7 +470,7 @@ TEST(AdaptiveSharedBatchSchedulerTest, TruncateBatches) { output_tasks->emplace_back(new FakeTask(task_size)); remaining_size -= task_size; } - return OkStatus(); + return absl::OkStatus(); }; TF_ASSERT_OK(scheduler->AddQueue(queue_options, queue_callback, &queue)); TF_ASSERT_OK(ScheduleTask(30, queue.get())); diff --git a/tensorflow/core/kernels/batching_util/basic_batch_scheduler.h b/tensorflow/core/kernels/batching_util/basic_batch_scheduler.h index 607a8212175662..047058c0c50f8e 100644 --- a/tensorflow/core/kernels/batching_util/basic_batch_scheduler.h +++ b/tensorflow/core/kernels/batching_util/basic_batch_scheduler.h @@ -342,7 +342,7 @@ Status BasicBatchScheduler::Create( scheduler->reset( new BasicBatchScheduler(std::move(shared_scheduler_queue))); - return OkStatus(); + return absl::OkStatus(); } template diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.cc b/tensorflow/core/kernels/batching_util/batch_resource_base.cc index fa900a9c87789c..08ad9866808498 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base.cc +++ b/tensorflow/core/kernels/batching_util/batch_resource_base.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/kernels/batching_util/batch_resource_base.h" #include +#include #include #include #include @@ -26,10 +27,13 @@ limitations under the License. #include #include +#include "absl/container/fixed_array.h" #include "absl/container/flat_hash_map.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/synchronization/blocking_counter.h" +#include "absl/synchronization/mutex.h" #include "absl/time/time.h" #include "absl/types/optional.h" #include "tensorflow/core/common_runtime/cost_constants.h" @@ -39,10 +43,16 @@ limitations under the License. #include "tensorflow/core/common_runtime/request_cost.h" #include "tensorflow/core/common_runtime/request_cost_accessor.h" #include "tensorflow/core/common_runtime/request_cost_accessor_registry.h" +#include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/ops_util.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_util.h" #include "tensorflow/core/kernels/batching_util/concat_split_util.h" +#include "tensorflow/core/kernels/batching_util/input_split_metadata.h" +#include "tensorflow/core/kernels/batching_util/threadsafe_status.h" #include "tensorflow/core/kernels/batching_util/warmup.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/monitoring/counter.h" @@ -50,6 +60,10 @@ limitations under the License. #include "tensorflow/core/lib/monitoring/percentile_sampler.h" #include "tensorflow/core/lib/monitoring/sampler.h" #include "tensorflow/core/lib/monitoring/types.h" +#include "tensorflow/core/platform/context.h" +#include "tensorflow/core/platform/env_time.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/lib/traceme.h" @@ -277,6 +291,7 @@ BatchResourceBase::BatchTask::CreateSplitTask( task->is_partial = true; task->start_time = this->start_time; task->request_cost = this->request_cost; + task->forced_warmup_batch_size = this->forced_warmup_batch_size; return task; } @@ -300,12 +315,22 @@ Status BatchResourceBase::RegisterWarmupInputs( int64_t guid, OpKernelContext* context, const string& batcher_queue_name, const CreateBatchTaskFn& create_batch_task_fn, AsyncOpKernel::DoneCallback done) { + auto shared_status = std::make_shared(); + auto create_batch_task_fn_share_status = [&create_batch_task_fn, + &shared_status]() { + auto batch_task = create_batch_task_fn(); + if (!batch_task.ok()) { + return batch_task; + } + (*batch_task)->status = shared_status; + return batch_task; + }; auto warmup_counter = std::make_shared(allowed_batch_sizes_.size()); // Enqueue warmup batches. for (int i = 0; i < allowed_batch_sizes_.size(); ++i) { Status status = RegisterInput( - guid, context, batcher_queue_name, create_batch_task_fn, + guid, context, batcher_queue_name, create_batch_task_fn_share_status, [warmup_counter = warmup_counter.get()]() { warmup_counter->DecrementCount(); }, @@ -313,11 +338,13 @@ Status BatchResourceBase::RegisterWarmupInputs( if (!status.ok()) return status; } // Enqueue real batch if the other batches were enqueued successfully. - return RegisterInput(guid, context, batcher_queue_name, create_batch_task_fn, - [warmup_counter, done = done]() { - warmup_counter->Wait(); - done(); - }); + return RegisterInput( + guid, context, batcher_queue_name, create_batch_task_fn_share_status, + [warmup_counter, context, shared_status, done = std::move(done)]() { + warmup_counter->Wait(); + context->SetStatus(shared_status->status()); + done(); + }); } Status BatchResourceBase::RegisterInput( @@ -394,7 +421,7 @@ Status BatchResourceBase::RegisterInput( &empty_output, cpu_alloc)); } done_callback(); - return OkStatus(); + return absl::OkStatus(); } OpInputList captured_tensors; const auto captured_status = @@ -406,10 +433,24 @@ Status BatchResourceBase::RegisterInput( } } batch_components->context = context; - batch_components->done_callback = std::move(done_callback); batch_components->split_index = 0; batch_components->output = std::make_shared(); - batch_components->status = std::make_shared(); + if (!batch_components->status) { + // A shared status has already been injected if `RegisterWarmupInputs` + // was called. If not, create the `ThreadSafeStatus` and tie the setting + // of the kernel context's status to this shared status. + batch_components->status = std::make_shared(); + batch_components->done_callback = [done_callback = std::move(done_callback), + shared_status = batch_components->status, + context = context]() { + context->SetStatus(shared_status->status()); + done_callback(); + }; + } else { + // Otherwise `RegisterWarmupInputs` was called and already setup the + // `done_callback` and `status` correctly for this `BatchTask`. + batch_components->done_callback = std::move(done_callback); + } batch_components->forced_warmup_batch_size = forced_warmup_batch_size; std::unique_ptr request_cost_accessor = @@ -558,7 +599,7 @@ BatchResourceBase::GetAdaptiveBatcherQueueOptions( } } - return OkStatus(); + return absl::OkStatus(); } // Returns the smallest entry in 'allowed_batch_sizes_' that is greater than @@ -656,7 +697,7 @@ Status BatchResourceBase::ConcatInputTensors( TF_RETURN_IF_ERROR(concat_status); concatenated_tensors->push_back(concatenated_tensor); } - return OkStatus(); + return absl::OkStatus(); } /*static*/ Status BatchResourceBase::SplitInputTask( @@ -673,7 +714,9 @@ Status BatchResourceBase::ConcatInputTensors( // complete. std::function split_task_done_callback = [done_callback = input_task.done_callback, output = input_task.output, - op_kernel_context = input_task.context, status = shared_status]() { + forced_warmup_batch_size = input_task.forced_warmup_batch_size, + op_kernel_context = input_task.context, + status = shared_status]() mutable { const int num_output = op_kernel_context->num_outputs(); for (int i = 0; i < num_output; ++i) { Tensor output_tensor; @@ -692,10 +735,10 @@ Status BatchResourceBase::ConcatInputTensors( if (!concat_status.ok()) { status->Update(concat_status); } - - op_kernel_context->set_output(i, std::move(output_tensor)); + if (forced_warmup_batch_size == 0) { + op_kernel_context->set_output(i, std::move(output_tensor)); + } } - op_kernel_context->SetStatus(status->status()); done_callback(); }; IncrementalBarrier barrier(split_task_done_callback); @@ -751,7 +794,7 @@ Status BatchResourceBase::ConcatInputTensors( std::back_inserter(output_task.inputs)); } } - return OkStatus(); + return absl::OkStatus(); } Status BatchResourceBase::SplitOutputTensors( @@ -824,7 +867,7 @@ Status BatchResourceBase::SplitOutputTensors( } } - return OkStatus(); + return absl::OkStatus(); } void BatchResourceBase::ProcessFuncBatch(std::unique_ptr batch) const { @@ -864,10 +907,15 @@ void BatchResourceBase::ProcessFuncBatch(std::unique_ptr batch) const { batch_cost_measurements.clear(); for (int i = 0; i < batch->num_tasks(); ++i) { WithContext wc(batch->task(i).propagated_context); - if (batch->task(i).is_partial) { - batch->mutable_task(i)->status->Update(status); - } else { - batch->mutable_task(i)->context->SetStatus(status); + if (!status.ok()) { + if (!absl::StrContains( + status.message(), + "Function was cancelled before it was started")) { + batch->mutable_task(i)->status->Update(status); + } else { + // Do not propagate this error; Prefer a more helpful error message. + LOG(ERROR) << "ERROR!!!! " << status.message(); + } } batch->mutable_task(i)->done_callback(); } @@ -1034,7 +1082,7 @@ void BatchResourceBase::ProcessBatch(std::unique_ptr batch) const { index_flat(task_idx, 2) = offset + task.size(); offset += task.size(); } - return OkStatus(); + return absl::OkStatus(); } // Looks up the batcher queue for 'queue_name'. If it didn't previously exist, @@ -1046,7 +1094,7 @@ Status BatchResourceBase::LookupOrCreateBatcherQueue(const string& queue_name, auto it = batcher_queues_.find(queue_name); if (it != batcher_queues_.end()) { *queue = it->second.get(); - return OkStatus(); + return absl::OkStatus(); } std::unique_ptr new_queue; @@ -1072,7 +1120,7 @@ Status BatchResourceBase::LookupOrCreateBatcherQueue(const string& queue_name, } *queue = new_queue.get(); batcher_queues_[queue_name] = std::move(new_queue); - return OkStatus(); + return absl::OkStatus(); } void BatchResourceBase::SplitBatchCostsAndRecordMetrics( diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.h b/tensorflow/core/kernels/batching_util/batch_resource_base.h index b86d25c097da39..3983a28e1b70f2 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base.h +++ b/tensorflow/core/kernels/batching_util/batch_resource_base.h @@ -171,6 +171,8 @@ class BatchResourceBase : public ResourceBase { AsyncOpKernel::DoneCallback done); // Ingests data from one invocation of the batch op. The data is enqueued to // be combined with others into a batch, asynchronously. + // `CreateBatchTaskFn` should be used to instantiate fields added to a + // child class of `BatchTask` by the caller. Status RegisterInput(int64_t guid, OpKernelContext* context, const string& batcher_queue_name, const CreateBatchTaskFn& create_batch_task_fn, diff --git a/tensorflow/core/kernels/batching_util/concat_split_util.h b/tensorflow/core/kernels/batching_util/concat_split_util.h index 77a2b16340dd08..d3b2839ae9752e 100644 --- a/tensorflow/core/kernels/batching_util/concat_split_util.h +++ b/tensorflow/core/kernels/batching_util/concat_split_util.h @@ -87,7 +87,7 @@ Status Concat(OpKernelContext* context, const gtl::ArraySlice inputs, ConcatCPU(context->device(), inputs_flat, &output_flat); } - return OkStatus(); + return absl::OkStatus(); } // Same as 'Concat' above, but handles Tensor dtype deduction automatically. @@ -135,7 +135,7 @@ Status SplitEasyCases(OpKernelContext* context, const Tensor& input, if (sizes.size() == 1 && sizes.at(0) == input.shape().dim_size(0)) { outputs->push_back(input); *done = true; - return OkStatus(); + return absl::OkStatus(); } // Special case 1: input is aligned. @@ -146,10 +146,10 @@ Status SplitEasyCases(OpKernelContext* context, const Tensor& input, position += size; } *done = true; - return OkStatus(); + return absl::OkStatus(); } - return OkStatus(); + return absl::OkStatus(); } // Handles the general case, on CPU. @@ -189,7 +189,7 @@ Status SplitCPU(OpKernelContext* context, const Tensor& input, position += size; } - return OkStatus(); + return absl::OkStatus(); } #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ @@ -215,7 +215,7 @@ Status Split(OpKernelContext* context, const Tensor& input, TF_RETURN_IF_ERROR( SplitEasyCases(context, input, sizes, outputs, &easy_cases_done)); if (easy_cases_done) { - return OkStatus(); + return absl::OkStatus(); } #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ diff --git a/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h b/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h index 80d24faa794838..7340ace6317603 100644 --- a/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h +++ b/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h @@ -303,7 +303,7 @@ Status SerialDeviceBatchScheduler::Create( "specified"); } scheduler->reset(new SerialDeviceBatchScheduler(options)); - return OkStatus(); + return absl::OkStatus(); } template @@ -349,7 +349,7 @@ Status SerialDeviceBatchScheduler::AddQueue( this->shared_from_this(), options)); mutex_lock l(mu_); queues_and_callbacks_[SDBS_queue_raw] = process_batch_callback; - return OkStatus(); + return absl::OkStatus(); } template @@ -516,7 +516,7 @@ Status SDBSQueue::Schedule(std::unique_ptr* task) { } // AddBatch must be called outside of lock, since it may call ReleaseBatch. if (new_batch != nullptr) scheduler_->AddBatch(new_batch); - return OkStatus(); + return absl::OkStatus(); } template diff --git a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h index 4e218e83dd51c2..f018076e96f633 100644 --- a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h +++ b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h @@ -570,7 +570,7 @@ Status SharedBatchScheduler::Create( options.num_batch_threads); } scheduler->reset(new SharedBatchScheduler(options)); - return OkStatus(); + return absl::OkStatus(); } template @@ -678,7 +678,7 @@ Status SharedBatchScheduler::AddQueueAfterRewritingOptions( } } *queue = std::move(handle); - return OkStatus(); + return absl::OkStatus(); } template @@ -909,7 +909,7 @@ Status Queue::ScheduleWithLazySplit(std::unique_ptr* task) { schedulable_batch_callback_(); } - return OkStatus(); + return absl::OkStatus(); } // TODO(b/194294263): @@ -984,7 +984,7 @@ Status Queue::ScheduleWithoutOrEagerSplit( schedulable_batch_callback_(); } - return OkStatus(); + return absl::OkStatus(); } template @@ -1041,7 +1041,7 @@ Status Queue::ValidateBatchTaskQueueCapacity(TaskType* task) const { ", open_batch_size=", tail_batch_task_size(), ", max_execution_batch_size=", max_execution_batch_size(), ")"); } - return OkStatus(); + return absl::OkStatus(); } // NOTE, the capacity checking below is loose and is retained @@ -1062,7 +1062,7 @@ Status Queue::ValidateBatchTaskQueueCapacity(TaskType* task) const { options_.max_enqueued_batches); } } - return OkStatus(); + return absl::OkStatus(); } template diff --git a/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc b/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc index 04d0b78df1d3ab..093999b319d7f0 100644 --- a/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc +++ b/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc @@ -172,7 +172,7 @@ class SharedBatchSchedulerTest (*output_tasks)[i] = std::make_unique(task_sizes[i]); } - return OkStatus(); + return absl::OkStatus(); }; } return nullptr; @@ -840,7 +840,7 @@ void CreateQueues() { }); busy_waiter.join(); notifier.join(); - return OkStatus(); + return absl::OkStatus(); }; internal::Queue::ProcessBatchCallback process_batch_callback = diff --git a/tensorflow/core/kernels/batching_util/threadsafe_status_test.cc b/tensorflow/core/kernels/batching_util/threadsafe_status_test.cc index 94ab08c8a1743d..3dc3185ce43b4b 100644 --- a/tensorflow/core/kernels/batching_util/threadsafe_status_test.cc +++ b/tensorflow/core/kernels/batching_util/threadsafe_status_test.cc @@ -35,7 +35,7 @@ TEST(ThreadSafeStatus, Update) { status.Update(errors::FailedPrecondition("original error")); EXPECT_EQ(status.status().code(), error::FAILED_PRECONDITION); - status.Update(OkStatus()); + status.Update(absl::OkStatus()); EXPECT_EQ(status.status().code(), error::FAILED_PRECONDITION); status.Update(errors::Internal("new error")); diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc index c892af27b1eed4..79d4462489c583 100644 --- a/tensorflow/core/kernels/cast_op.cc +++ b/tensorflow/core/kernels/cast_op.cc @@ -125,7 +125,7 @@ CpuCastOp::CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) { Status CpuCastOp::Prepare() { if (external_src_dtype_ == external_dst_dtype_) { work_ = nullptr; // Identity - return OkStatus(); + return absl::OkStatus(); } if (src_dtype_ == DT_BOOL) { work_ = GetCpuCastFromBool(dst_dtype_); @@ -172,7 +172,7 @@ Status CpuCastOp::Prepare() { // vectorized versions (not the least based on F16C for Haswell // or newer). - return work_ == nullptr ? Unimplemented() : OkStatus(); + return work_ == nullptr ? Unimplemented() : absl::OkStatus(); } #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ diff --git a/tensorflow/core/kernels/check_numerics_op.cc b/tensorflow/core/kernels/check_numerics_op.cc index 238239fdf917c9..539f109bf72bad 100644 --- a/tensorflow/core/kernels/check_numerics_op.cc +++ b/tensorflow/core/kernels/check_numerics_op.cc @@ -226,8 +226,10 @@ class CheckNumericsOp : public AsyncOpKernel { se::DeviceMemoryBase abnormal_detected_ptr( abnormal_detected.flat().data(), abnormal_detected.flat().size()); - stream->ThenMemset32(&abnormal_detected_ptr, 0, - abnormal_detected.flat().size() * sizeof(int)); + OP_REQUIRES_OK( + context, + stream->Memset32(&abnormal_detected_ptr, 0, + abnormal_detected.flat().size() * sizeof(int))); // Call the GPU kernels for the numerical checks const Device& d = context->eigen_device(); @@ -244,14 +246,14 @@ class CheckNumericsOp : public AsyncOpKernel { context->allocate_temp(DT_INT32, TensorShape({abnormal_detected_size}), &abnormal_detected_host, attr), done); - OP_REQUIRES_ASYNC( - context, - stream - ->ThenMemcpy(abnormal_detected_host.flat().data(), - abnormal_detected_ptr, - abnormal_detected_size * sizeof(int)) - .ok(), - errors::Internal("GPU memcpy from device to host failed"), done); + OP_REQUIRES_ASYNC(context, + stream + ->Memcpy(abnormal_detected_host.flat().data(), + abnormal_detected_ptr, + abnormal_detected_size * sizeof(int)) + .ok(), + errors::Internal("GPU memcpy from device to host failed"), + done); // We have observed crashes on some network stacks when not holding // this tensor reference. diff --git a/tensorflow/core/kernels/checkpoint_callback_manager.cc b/tensorflow/core/kernels/checkpoint_callback_manager.cc index 85d9150e822dcf..30ac196d2a52be 100644 --- a/tensorflow/core/kernels/checkpoint_callback_manager.cc +++ b/tensorflow/core/kernels/checkpoint_callback_manager.cc @@ -165,7 +165,7 @@ Status CheckpointCallbackManager::RegisterSaveCallback( TriggerSaveCallbackIfFileNotExist(checkpoint_id, checkpoint_dir, file_extension, lazy_callback); } - return OkStatus(); + return absl::OkStatus(); } bool CheckpointCallbackManager::DoesSaveCallbackExist( @@ -199,7 +199,7 @@ Status CheckpointCallbackManager::RegisterRestoreCallback( TriggerRestoreCallbackIfFileExists(checkpoint_id, checkpoint_dir, file_extension, lazy_callback); } - return OkStatus(); + return absl::OkStatus(); } bool CheckpointCallbackManager::DoesRestoreCallbackExist( diff --git a/tensorflow/core/kernels/checkpoint_callback_manager_test.cc b/tensorflow/core/kernels/checkpoint_callback_manager_test.cc index 115eb50b68c5c2..57f3c19fc68e2e 100644 --- a/tensorflow/core/kernels/checkpoint_callback_manager_test.cc +++ b/tensorflow/core/kernels/checkpoint_callback_manager_test.cc @@ -111,11 +111,11 @@ TEST_F(CheckpointCallbackManagerTest, RegisterSaveCallbackTwice) { TEST_F(CheckpointCallbackManagerTest, RegisterRestoreCallbackTwice) { RestoreCallback first_callback = [](absl::string_view checkpoint_id, absl::string_view str) { - return OkStatus(); + return absl::OkStatus(); }; RestoreCallback second_callback = [](absl::string_view checkpoint_id, absl::string_view str) { - return OkStatus(); + return absl::OkStatus(); }; TF_ASSERT_OK(checkpoint_callback_manager_->RegisterRestoreCallback( @@ -150,11 +150,11 @@ TEST_F(CheckpointCallbackManagerTest, DoesSaveCallbackExist) { TEST_F(CheckpointCallbackManagerTest, DoesRestoreCallbackExist) { RestoreCallback first_callback = [](absl::string_view checkpoint_id, absl::string_view str) { - return OkStatus(); + return absl::OkStatus(); }; RestoreCallback second_callback = [](absl::string_view checkpoint_id, absl::string_view str) { - return OkStatus(); + return absl::OkStatus(); }; TF_ASSERT_OK(checkpoint_callback_manager_->RegisterRestoreCallback( @@ -235,7 +235,7 @@ TEST_F(CheckpointCallbackManagerTest, Restore) { EXPECT_EQ(checkpoint_id, "model.ckpt-100"); EXPECT_EQ(str, "Apple"); ++callback_call_count; - return OkStatus(); + return absl::OkStatus(); }; TF_ASSERT_OK(checkpoint_callback_manager_->RegisterRestoreCallback( @@ -271,7 +271,7 @@ TEST_F(CheckpointCallbackManagerTest, SaveAndRestore) { EXPECT_EQ(checkpoint_id, "model.ckpt-500"); EXPECT_EQ(str, "Apple"); ++restore_callback_count; - return OkStatus(); + return absl::OkStatus(); }; TF_ASSERT_OK(checkpoint_callback_manager_->RegisterRestoreCallback( @@ -312,7 +312,7 @@ TEST_F(CheckpointCallbackManagerTest, RestoreLazyCallback) { EXPECT_EQ(checkpoint_id, "model.ckpt-100"); EXPECT_EQ(str, "Apple"); ++callback_call_count; - return OkStatus(); + return absl::OkStatus(); }; TF_EXPECT_OK(WriteStringToFile( diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc index 57815190f4073a..dea6e0ec5f0953 100644 --- a/tensorflow/core/kernels/collective_ops.cc +++ b/tensorflow/core/kernels/collective_ops.cc @@ -560,7 +560,7 @@ class CollectiveAssignGroupV2OpKernel : public OpKernel { << " device_index = " << index << " group_key = " << group_key->DebugString() << " group_size = " << group_size->DebugString(); - return OkStatus(); + return absl::OkStatus(); } } } @@ -638,7 +638,7 @@ class CollectiveOpV2Kernel : public AsyncOpKernel { col_params->instance.instance_key = instance_key.unaligned_flat()(0); col_params->instance.impl_details.communication_hint = communication_hint_; col_params->instance.impl_details.timeout_seconds = timeout_seconds_; - return OkStatus(); + return absl::OkStatus(); } // Runs a collective. The output tensor must be allocated before calling this @@ -1072,7 +1072,7 @@ class CollectiveInitializeCommunicatorOpKernel : public AsyncOpKernel { "rank must be less than group size but got ", rank, " >= ", group_size); } - return OkStatus(); + return absl::OkStatus(); } void ComputeAsync(OpKernelContext* c, DoneCallback done) override { @@ -1196,7 +1196,7 @@ class CollectiveOpV3Kernel : public AsyncOpKernel { col_params->instance.impl_details.timeout_seconds = timeout_seconds_ > 0 ? resource->timeout_seconds() : timeout_seconds_; col_params->run_group_initialization = false; - return OkStatus(); + return absl::OkStatus(); } // Runs a collective. The output tensor must be allocated before calling this diff --git a/tensorflow/core/kernels/conditional_accumulator.h b/tensorflow/core/kernels/conditional_accumulator.h index 5f901b12d43d58..13343a99b64acd 100644 --- a/tensorflow/core/kernels/conditional_accumulator.h +++ b/tensorflow/core/kernels/conditional_accumulator.h @@ -81,7 +81,7 @@ class ConditionalAccumulator shape_.DebugString(), ", got ", tensor->shape().DebugString()); } - return OkStatus(); + return absl::OkStatus(); } void AllocateAndAssignToAccumGradFunction(OpKernelContext* ctx, diff --git a/tensorflow/core/kernels/conditional_accumulator_base.cc b/tensorflow/core/kernels/conditional_accumulator_base.cc index c830ac160324fa..8a0c73d0bdbca2 100644 --- a/tensorflow/core/kernels/conditional_accumulator_base.cc +++ b/tensorflow/core/kernels/conditional_accumulator_base.cc @@ -31,7 +31,7 @@ ConditionalAccumulatorBase::ConditionalAccumulatorBase( Status ConditionalAccumulatorBase::MatchesNodeDef(const NodeDef& node_def) { // TODO(xinghao@): implement the checks for the node definition - return OkStatus(); + return absl::OkStatus(); } /** @@ -47,7 +47,7 @@ Status ConditionalAccumulatorBase::SetGlobalStep(int64_t new_global_step) { << " >= " << new_global_step << " = new_global_step."; } current_global_step_ = new_global_step; - return OkStatus(); + return absl::OkStatus(); } /** diff --git a/tensorflow/core/kernels/conditional_accumulator_base_op.h b/tensorflow/core/kernels/conditional_accumulator_base_op.h index 9acf1ca1e21817..73504fbd495327 100644 --- a/tensorflow/core/kernels/conditional_accumulator_base_op.h +++ b/tensorflow/core/kernels/conditional_accumulator_base_op.h @@ -116,7 +116,7 @@ class ConditionalAccumulatorBaseOp : public OpKernel { h(0) = cinfo_.container(); h(1) = cinfo_.name(); accumulator_set_ = true; - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/core/kernels/conditional_accumulator_op.cc b/tensorflow/core/kernels/conditional_accumulator_op.cc index 6df0b7f701e2fd..732847ff200db7 100644 --- a/tensorflow/core/kernels/conditional_accumulator_op.cc +++ b/tensorflow/core/kernels/conditional_accumulator_op.cc @@ -37,13 +37,13 @@ class ConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp { new ConditionalAccumulator(dtype_, shape_, cinfo_.name(), reduction_type_); *ret = accumulator; - return OkStatus(); + return absl::OkStatus(); }; } Status CheckSignature(OpKernelContext* ctx) override { TF_RETURN_IF_ERROR(ctx->MatchSignature({}, {DT_STRING_REF})); - return OkStatus(); + return absl::OkStatus(); } void SetHandleToOutput(OpKernelContext* ctx) @@ -75,13 +75,13 @@ class ResourceConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp { new ConditionalAccumulator(dtype_, shape_, cinfo_.name(), reduction_type_); *ret = accumulator; - return OkStatus(); + return absl::OkStatus(); }; } Status CheckSignature(OpKernelContext* ctx) override { TF_RETURN_IF_ERROR(ctx->MatchSignature({}, {DT_RESOURCE})); - return OkStatus(); + return absl::OkStatus(); } void SetHandleToOutput(OpKernelContext* ctx) diff --git a/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc b/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc index c21a0cc907bce7..4d43ee0bdf8592 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc @@ -707,6 +707,9 @@ void LaunchConvBackpropFilterOpImpl( "with cuDNN on Ampere GPUs or later.")); return; } + auto* blas = stream->parent()->AsBlas(); + OP_REQUIRES(context, blas != nullptr, + absl::InternalError("No BLAS for stream.")); bool is_grouped_convolution = filter_shape.dim_size(3) != dims.in_depth; if (!is_grouped_convolution && dims.filter_size(1) == 1 && @@ -737,8 +740,8 @@ void LaunchConvBackpropFilterOpImpl( auto c_ptr = AsDeviceMemory(filter_backprop->template flat().data(), filter_backprop->template flat().size()); - OP_REQUIRES_OK(context, stream->ThenBlasGemm( - se::blas::Transpose::kNoTranspose, + OP_REQUIRES_OK( + context, blas->BlasGemm(stream, se::blas::Transpose::kNoTranspose, se::blas::Transpose::kTranspose, n, m, k, a_ptr, n, b_ptr, m, &c_ptr, n, GetNumericOptions(), se::blas::CallContext::kNone)); @@ -760,8 +763,8 @@ void LaunchConvBackpropFilterOpImpl( auto c_ptr = AsDeviceMemory(filter_backprop->template flat().data(), filter_backprop->template flat().size()); - OP_REQUIRES_OK(context, stream->ThenBlasGemm( - se::blas::Transpose::kNoTranspose, + OP_REQUIRES_OK( + context, blas->BlasGemm(stream, se::blas::Transpose::kNoTranspose, se::blas::Transpose::kTranspose, n, m, k, b_ptr, n, a_ptr, m, &c_ptr, n, GetNumericOptions(), se::blas::CallContext::kNone)); diff --git a/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc b/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc index e65e5995e92045..f1c8958b127e88 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc @@ -226,6 +226,8 @@ void LaunchConv2DBackpropFilterOpImpl( "without cudnn")); return; } + auto* blas = stream->parent()->AsBlas(); + OP_REQUIRES(ctx, blas != nullptr, absl::InternalError("No BLAS for stream.")); // If the filter in-depth (filter_shape.dim_size(2)) is 1 and smaller than the // input depth, it's a depthwise convolution. More generally, if the filter @@ -261,11 +263,11 @@ void LaunchConv2DBackpropFilterOpImpl( auto c_ptr = AsDeviceMemory(filter_backprop->template flat().data(), filter_backprop->template flat().size()); - OP_REQUIRES_OK(ctx, stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose, - se::blas::Transpose::kTranspose, n, - m, k, a_ptr, n, b_ptr, m, &c_ptr, - n, GetNumericOptions(), - se::blas::CallContext::kNone)); + OP_REQUIRES_OK( + ctx, blas->BlasGemm(stream, se::blas::Transpose::kNoTranspose, + se::blas::Transpose::kTranspose, n, m, k, a_ptr, n, + b_ptr, m, &c_ptr, n, GetNumericOptions(), + se::blas::CallContext::kNone)); return; } else if (dims.spatial_dims[0].filter_size == dims.spatial_dims[0].input_size && @@ -287,11 +289,11 @@ void LaunchConv2DBackpropFilterOpImpl( auto c_ptr = AsDeviceMemory(filter_backprop->template flat().data(), filter_backprop->template flat().size()); - OP_REQUIRES_OK(ctx, stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose, - se::blas::Transpose::kTranspose, n, - m, k, b_ptr, n, a_ptr, m, &c_ptr, - n, GetNumericOptions(), - se::blas::CallContext::kNone)); + OP_REQUIRES_OK( + ctx, blas->BlasGemm(stream, se::blas::Transpose::kNoTranspose, + se::blas::Transpose::kTranspose, n, m, k, b_ptr, n, + a_ptr, m, &c_ptr, n, GetNumericOptions(), + se::blas::CallContext::kNone)); return; } diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc index cf805027cb5835..ee41b2ddce4eb8 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops.cc @@ -131,6 +131,8 @@ void LaunchConv2DBackpropInputOpGpuImpl( "without cudnn")); return; } + auto* blas = stream->parent()->AsBlas(); + OP_REQUIRES(ctx, blas != nullptr, absl::InternalError("No BLAS for stream.")); // If the filter in-depth (filter_shape.dim_size(2)) is 1 and smaller than the // input depth, it's a depthwise convolution. More generally, if the filter @@ -158,9 +160,9 @@ void LaunchConv2DBackpropInputOpGpuImpl( auto no_transpose = se::blas::Transpose::kNoTranspose; OP_REQUIRES_OK( - ctx, stream->ThenBlasGemm(transpose, no_transpose, n, m, k, b_ptr, k, - a_ptr, k, &c_ptr, n, GetNumericOptions(), - se::blas::CallContext::kNone)); + ctx, blas->BlasGemm(stream, transpose, no_transpose, n, m, k, b_ptr, k, + a_ptr, k, &c_ptr, n, GetNumericOptions(), + se::blas::CallContext::kNone)); return; } else if (dims.spatial_dims[0].filter_size == dims.spatial_dims[0].input_size && @@ -186,9 +188,9 @@ void LaunchConv2DBackpropInputOpGpuImpl( auto no_transpose = se::blas::Transpose::kNoTranspose; OP_REQUIRES_OK( - ctx, stream->ThenBlasGemm(transpose, no_transpose, n, m, k, b_ptr, k, - a_ptr, k, &c_ptr, n, GetNumericOptions(), - se::blas::CallContext::kNone)); + ctx, blas->BlasGemm(stream, transpose, no_transpose, n, m, k, b_ptr, k, + a_ptr, k, &c_ptr, n, GetNumericOptions(), + se::blas::CallContext::kNone)); return; } diff --git a/tensorflow/core/kernels/conv_grad_input_ops_3d.cc b/tensorflow/core/kernels/conv_grad_input_ops_3d.cc index 70311cbbd7a3d7..2fc59b60e3034b 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops_3d.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops_3d.cc @@ -718,6 +718,9 @@ void LaunchConvBackpropInputOpImpl( auto* stream = context->op_device_context()->stream(); OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); + auto* blas = stream->parent()->AsBlas(); + OP_REQUIRES(context, blas != nullptr, + absl::InternalError("No BLAS for stream.")); bool is_grouped_convolution = filter_shape.dim_size(3) != dims.in_depth; if (!is_grouped_convolution && dims.filter_size(0) == 1 && @@ -740,10 +743,10 @@ void LaunchConvBackpropInputOpImpl( auto transpose = se::blas::Transpose::kTranspose; auto no_transpose = se::blas::Transpose::kNoTranspose; - OP_REQUIRES_OK(context, stream->ThenBlasGemm(transpose, no_transpose, n, m, - k, b_ptr, k, a_ptr, k, &c_ptr, - n, GetNumericOptions(), - se::blas::CallContext::kNone)); + OP_REQUIRES_OK( + context, blas->BlasGemm(stream, transpose, no_transpose, n, m, k, b_ptr, + k, a_ptr, k, &c_ptr, n, GetNumericOptions(), + se::blas::CallContext::kNone)); return; } else if (!is_grouped_convolution && dims.filter_size(0) == dims.input_size(0) && @@ -765,10 +768,10 @@ void LaunchConvBackpropInputOpImpl( auto transpose = se::blas::Transpose::kTranspose; auto no_transpose = se::blas::Transpose::kNoTranspose; - OP_REQUIRES_OK(context, stream->ThenBlasGemm(transpose, no_transpose, n, m, - k, b_ptr, k, a_ptr, k, &c_ptr, - n, GetNumericOptions(), - se::blas::CallContext::kNone)); + OP_REQUIRES_OK( + context, blas->BlasGemm(stream, transpose, no_transpose, n, m, k, b_ptr, + k, a_ptr, k, &c_ptr, n, GetNumericOptions(), + se::blas::CallContext::kNone)); return; } diff --git a/tensorflow/core/kernels/conv_grad_shape_utils.cc b/tensorflow/core/kernels/conv_grad_shape_utils.cc index 9560a37fd6eea6..4bce9b873473f2 100644 --- a/tensorflow/core/kernels/conv_grad_shape_utils.cc +++ b/tensorflow/core/kernels/conv_grad_shape_utils.cc @@ -87,7 +87,7 @@ Status ConvBackpropExtractAndVerifyDimension( << ", pad_before = " << dim->pad_before << ", pad_after = " << dim->pad_after << ", dilation = " << dim->dilation << ", strides = " << dim->stride; - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -154,7 +154,7 @@ Status ConvBackpropComputeDimensionsV2( strides, padding, padding_before, padding_after, image_dim, i, &dims->spatial_dims[i])); } - return OkStatus(); + return absl::OkStatus(); } Status ConvBackpropComputeDimensions(StringPiece label, int num_spatial_dims, diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc index 063899c6b4a3b7..751a7c0b1957f4 100644 --- a/tensorflow/core/kernels/conv_ops.cc +++ b/tensorflow/core/kernels/conv_ops.cc @@ -91,7 +91,7 @@ Status InitConv2DParameters(const OpKernelConstruction* context, TF_RETURN_IF_ERROR(CheckValidPadding( params->padding, params->explicit_paddings, num_dims, data_format)); - return OkStatus(); + return absl::OkStatus(); } Status ComputeConv2DDimension(const Conv2DParameters& params, @@ -210,7 +210,7 @@ Status ComputeConv2DDimension(const Conv2DParameters& params, dimensions->pad_cols_before = pad_cols_before; dimensions->pad_cols_after = pad_cols_after; - return OkStatus(); + return absl::OkStatus(); } #undef TF_REQUIRES diff --git a/tensorflow/core/kernels/conv_ops_fused_image_transform.cc b/tensorflow/core/kernels/conv_ops_fused_image_transform.cc index 042ee1dfc3cd93..d45ff0171dfe59 100644 --- a/tensorflow/core/kernels/conv_ops_fused_image_transform.cc +++ b/tensorflow/core/kernels/conv_ops_fused_image_transform.cc @@ -344,7 +344,7 @@ class FusedResizeAndPadConvFunctor { std::function**)> creator = [](Im2ColBufferResource** resource) { *resource = new Im2ColBufferResource(); - return OkStatus(); + return absl::OkStatus(); }; OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate( "Conv2d", "im2col_buffer", @@ -382,7 +382,7 @@ class FusedResizeAndPadConvFunctor { resize_creator = [](Im2ColBufferResource** resource) { *resource = new Im2ColBufferResource(); - return OkStatus(); + return absl::OkStatus(); }; OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate( "Conv2d", "resize_cache", diff --git a/tensorflow/core/kernels/conv_ops_fused_impl.h b/tensorflow/core/kernels/conv_ops_fused_impl.h index a70fff1ed4aef6..5e35562b6f1c3a 100644 --- a/tensorflow/core/kernels/conv_ops_fused_impl.h +++ b/tensorflow/core/kernels/conv_ops_fused_impl.h @@ -665,8 +665,11 @@ struct LaunchFusedConv2DOp { stream, nullptr, std::get(runner_and_scratch), input_ptr, filter_ptr, side_input_ptr, bias_ptr, output_ptr); } else { - cudnn_launch_status = stream->FusedConvolveWithAlgorithm( - input_desc, input_ptr, // input + auto dnn = stream->parent()->AsDnn(); + OP_REQUIRES(context, dnn != nullptr, + absl::InternalError("No DNN for stream.")); + cudnn_launch_status = dnn->FusedConvolveWithAlgorithm( + stream, input_desc, input_ptr, // input kConvScale, // input_scale filter_desc, filter_ptr, // filter conv_desc, // conv diff --git a/tensorflow/core/kernels/conv_ops_fused_int8.cc b/tensorflow/core/kernels/conv_ops_fused_int8.cc index d8a80ec3523576..6ce8a906409e46 100644 --- a/tensorflow/core/kernels/conv_ops_fused_int8.cc +++ b/tensorflow/core/kernels/conv_ops_fused_int8.cc @@ -705,11 +705,13 @@ void operator()( std::get(runner_and_scratch), conv_input_ptr, filter_ptr, side_input_ptr, bias_ptr, output_ptr); } else { - cudnn_launch_status = stream->FusedConvolveWithAlgorithm( - conv_input_desc, conv_input_ptr, conv_scale, filter_desc, filter_ptr, - conv_desc, side_input_ptr, side_input_scale, bias_desc, bias_ptr, - dnn_activation_mode, output_desc, &output_ptr, &scratch_allocator, - autotune_entry.GetAlgorithmConfig(), + auto dnn = stream->parent()->AsDnn(); + OP_REQUIRES(ctx, dnn != nullptr, absl::InternalError("No DNN for stream.")); + cudnn_launch_status = dnn->FusedConvolveWithAlgorithm( + stream, conv_input_desc, conv_input_ptr, conv_scale, filter_desc, + filter_ptr, conv_desc, side_input_ptr, side_input_scale, bias_desc, + bias_ptr, dnn_activation_mode, output_desc, &output_ptr, + &scratch_allocator, autotune_entry.GetAlgorithmConfig(), /*output_profile_result=*/nullptr); } diff --git a/tensorflow/core/kernels/conv_ops_gpu.cc b/tensorflow/core/kernels/conv_ops_gpu.cc index 135c85d271246b..ecc689a383d4d8 100644 --- a/tensorflow/core/kernels/conv_ops_gpu.cc +++ b/tensorflow/core/kernels/conv_ops_gpu.cc @@ -357,7 +357,11 @@ StatusOr> AutotuneUnfusedConv( DnnScratchAllocator scratch_allocator(scratch_size_limit, ctx); std::vector algorithms; - if (!stream->parent()->GetMIOpenConvolveAlgorithms( + auto dnn = stream->parent()->AsDnn(); + if (dnn == nullptr) { + return absl::InvalidArgumentError("No DNN in stream executor."); + } + if (!dnn->GetMIOpenConvolveAlgorithms( kind, se::dnn::ToDataType::value, stream, input_desc, input_ptr, filter_desc, filter_ptr, output_desc, output_ptr, conv_desc, &scratch_allocator, &algorithms)) { @@ -381,9 +385,9 @@ StatusOr> AutotuneUnfusedConv( for (auto miopen_algorithm : algorithms) { auto profile_algorithm = miopen_algorithm.algorithm(); se::dnn::ProfileResult profile_result; - auto miopen_launch_status = stream->ConvolveWithAlgorithm( - kind, input_desc, input_ptr, filter_desc, filter_ptr, output_desc, - output_ptr, conv_desc, &scratch_allocator, + auto miopen_launch_status = dnn->ConvolveWithAlgorithm( + stream, kind, input_desc, input_ptr, filter_desc, filter_ptr, + output_desc, output_ptr, conv_desc, &scratch_allocator, se::dnn::AlgorithmConfig(profile_algorithm, miopen_algorithm.scratch_size()), &profile_result); diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h index 80646badbade7d..627450ef2d6321 100644 --- a/tensorflow/core/kernels/conv_ops_gpu.h +++ b/tensorflow/core/kernels/conv_ops_gpu.h @@ -195,10 +195,14 @@ Status LaunchAutotunedConv(const AutotuneEntry& autotune_entry, std::get(runner_and_scratch), in_ptr, filter_ptr, out_ptr); } else { - return stream->ConvolveWithAlgorithm( - kind, input_desc, in_ptr, filter_desc, filter_ptr, output_desc, out_ptr, - conv_desc, scratch_allocator, autotune_entry.GetAlgorithmConfig(), - nullptr); + auto dnn = stream->parent()->AsDnn(); + if (dnn == nullptr) { + return absl::InternalError("No DNN for stream."); + } + return dnn->ConvolveWithAlgorithm( + stream, kind, input_desc, in_ptr, filter_desc, filter_ptr, output_desc, + out_ptr, conv_desc, scratch_allocator, + autotune_entry.GetAlgorithmConfig(), nullptr); } } diff --git a/tensorflow/core/kernels/conv_ops_impl.h b/tensorflow/core/kernels/conv_ops_impl.h index 432093b48e3d5d..0d3fc798bbe3c2 100644 --- a/tensorflow/core/kernels/conv_ops_impl.h +++ b/tensorflow/core/kernels/conv_ops_impl.h @@ -789,6 +789,9 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, if (filter_dims[i] != in_dims[i]) filter_same_dims = false; } + auto* blas = stream->parent()->AsBlas(); + OP_REQUIRES(context, blas != nullptr, + absl::InternalError("No BLAS for stream.")); if (!is_grouped_convolution && one_filter && one_dilations && one_stride && data_format == FORMAT_NHWC && (padding == VALID || padding == SAME)) { // 1x1 filter, so call cublas directly. @@ -805,10 +808,10 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, output->template flat().size()); auto no_transpose = se::blas::Transpose::kNoTranspose; - OP_REQUIRES_OK(context, stream->ThenBlasGemm(no_transpose, no_transpose, n, - m, k, b_ptr, n, a_ptr, k, - &c_ptr, n, GetNumericOptions(), - se::blas::CallContext::kNone)); + OP_REQUIRES_OK(context, blas->BlasGemm(stream, no_transpose, no_transpose, + n, m, k, b_ptr, n, a_ptr, k, &c_ptr, + n, GetNumericOptions(), + se::blas::CallContext::kNone)); return; } else if (!is_grouped_convolution && filter_same_dims && padding == VALID && data_format == FORMAT_NHWC) { @@ -827,10 +830,10 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, output->template flat().size()); auto no_transpose = se::blas::Transpose::kNoTranspose; - OP_REQUIRES_OK(context, stream->ThenBlasGemm(no_transpose, no_transpose, n, - m, k, b_ptr, n, a_ptr, k, - &c_ptr, n, GetNumericOptions(), - se::blas::CallContext::kNone)); + OP_REQUIRES_OK(context, blas->BlasGemm(stream, no_transpose, no_transpose, + n, m, k, b_ptr, n, a_ptr, k, &c_ptr, + n, GetNumericOptions(), + se::blas::CallContext::kNone)); return; } diff --git a/tensorflow/core/kernels/conv_ops_using_gemm.cc b/tensorflow/core/kernels/conv_ops_using_gemm.cc index 8374935243fec7..07efd8069b019f 100644 --- a/tensorflow/core/kernels/conv_ops_using_gemm.cc +++ b/tensorflow/core/kernels/conv_ops_using_gemm.cc @@ -311,7 +311,7 @@ class Im2ColConvFunctor { std::function**)> creator = [](Im2ColBufferResource** resource) { *resource = new Im2ColBufferResource(); - return OkStatus(); + return absl::OkStatus(); }; OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate( "Conv2d", "im2col_buffer", diff --git a/tensorflow/core/kernels/count_ops.cc b/tensorflow/core/kernels/count_ops.cc index 84741c3ed49bac..dd1d3db048046d 100644 --- a/tensorflow/core/kernels/count_ops.cc +++ b/tensorflow/core/kernels/count_ops.cc @@ -86,7 +86,7 @@ Status OutputSparse(const BatchedMap& per_batch_counts, int64_t num_values, dense_shape->flat().data()[1] = num_values; } - return OkStatus(); + return absl::OkStatus(); } int64_t GetOutputSize(int64_t max_seen, int64_t max_length, diff --git a/tensorflow/core/kernels/ctc_decoder_ops.cc b/tensorflow/core/kernels/ctc_decoder_ops.cc index e89c9c8c7c4220..2480bc435bd01c 100644 --- a/tensorflow/core/kernels/ctc_decoder_ops.cc +++ b/tensorflow/core/kernels/ctc_decoder_ops.cc @@ -111,7 +111,7 @@ class CTCDecodeHelper { s = ctx->output_list("decoded_shape", decoded_shape); if (!s.ok()) return s; - return OkStatus(); + return absl::OkStatus(); } // sequences[b][p][ix] stores decoded value "ix" of path "p" for batch "b". @@ -174,7 +174,7 @@ class CTCDecodeHelper { shape_t(0) = batch_size; shape_t(1) = max_decoded; } - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/ctc_loss_op.cc b/tensorflow/core/kernels/ctc_loss_op.cc index 7e38ac5fcb0669..dfe9de54c91535 100644 --- a/tensorflow/core/kernels/ctc_loss_op.cc +++ b/tensorflow/core/kernels/ctc_loss_op.cc @@ -331,13 +331,13 @@ class CTCLossOpGPU : public OpKernel { StreamExecutor* executor = ctx->op_device_context()->stream()->parent(); se::dnn::DataType data_type = ToDataType::value; - auto probs_desc_s = executor->createRnnStateTensorDescriptor( + auto probs_desc_s = executor->AsDnn()->CreateRnnStateTensorDescriptor( max_time, batch_size, num_classes, data_type); OP_REQUIRES_OK(ctx, probs_desc_s.status()); std::unique_ptr probs_desc = std::move(probs_desc_s).value(); - auto grads_desc_s = executor->createRnnStateTensorDescriptor( + auto grads_desc_s = executor->AsDnn()->CreateRnnStateTensorDescriptor( max_time, batch_size, num_classes, data_type); OP_REQUIRES_OK(ctx, grads_desc_s.status()); std::unique_ptr grads_desc = @@ -358,14 +358,23 @@ class CTCLossOpGPU : public OpKernel { DnnScratchAllocator workspace_allocator(1LL << 32, ctx); Stream* stream = ctx->op_device_context()->stream(); + auto dnn = stream->parent()->AsDnn(); + OP_REQUIRES(ctx, dnn != nullptr, + absl::InternalError("stream->parent() has no DNN support")); + stream_executor::DeviceMemory scratch_memory; + int ctc_loss_algo_id; bool cudnn_launch_status = - stream - ->ThenCtcLoss(*probs_desc, probs_data, labels_data, - labels_lengths_data, input_lengths_data, - GetNumericOptions(), &costs_data, *grads_desc, - &grads_data, &workspace_allocator) + dnn->PrepareForCtcLoss( + stream, *probs_desc, probs_data, *grads_desc, labels_data, + labels_lengths_data, input_lengths_data, GetNumericOptions(), + &workspace_allocator, &scratch_memory, &ctc_loss_algo_id) .ok(); - + if (cudnn_launch_status) { + cudnn_launch_status = dnn->DoCtcLoss( + stream, *probs_desc, probs_data, labels_data, labels_lengths_data, + input_lengths_data, &costs_data, *grads_desc, &grads_data, + &scratch_memory, ctc_loss_algo_id); + } if (!cudnn_launch_status) { ctx->SetStatus(errors::Internal("cuDNN CTCLoss launch failure")); } diff --git a/tensorflow/core/kernels/cudnn_pooling_gpu.cc b/tensorflow/core/kernels/cudnn_pooling_gpu.cc index bd6e9ed054762a..e9917cce87a3e3 100644 --- a/tensorflow/core/kernels/cudnn_pooling_gpu.cc +++ b/tensorflow/core/kernels/cudnn_pooling_gpu.cc @@ -104,7 +104,9 @@ void DnnPooling3dImpl(OpKernelContext* context, auto* stream = context->op_device_context()->stream(); OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); - + auto* dnn = stream->parent()->AsDnn(); + OP_REQUIRES(context, dnn != nullptr, + errors::Internal("No DNN support for stream.")); #if TENSORFLOW_USE_ROCM static int64 PoolingScratchSize = GetDnnWorkspaceLimit( // default value is in bytes despite the name of the environment variable @@ -113,13 +115,14 @@ void DnnPooling3dImpl(OpKernelContext* context, DnnScratchAllocator scratch_allocator(PoolingScratchSize, context); OP_REQUIRES_OK(context, - stream->ThenPoolForward(pooling_desc, GetNumericOptions(), - input_desc, input_data, output_desc, - &output_data, &scratch_allocator)); + dnn->PoolForward(stream, pooling_desc, GetNumericOptions(), + input_desc, input_data, output_desc, + &output_data, &scratch_allocator)); #else - OP_REQUIRES_OK(context, stream->ThenPoolForward( - pooling_desc, GetNumericOptions(), input_desc, - input_data, output_desc, &output_data)); + OP_REQUIRES_OK( + context, + dnn->PoolForward(stream, pooling_desc, GetNumericOptions(), input_desc, + input_data, output_desc, &output_data)); #endif if (data_format == FORMAT_NHWC) { @@ -294,6 +297,9 @@ void DnnPooling3dGradImpl( auto* stream = context->op_device_context()->stream(); OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); + auto* dnn = stream->parent()->AsDnn(); + OP_REQUIRES(context, dnn != nullptr, + errors::Internal("No DNN support for stream.")); #if TENSORFLOW_USE_ROCM static int64 PoolingScratchSize = GetDnnWorkspaceLimit( @@ -304,16 +310,16 @@ void DnnPooling3dGradImpl( DnnScratchAllocator scratch_allocator(PoolingScratchSize, context); OP_REQUIRES_OK( context, - stream->ThenPoolBackward( - pooling_desc, GetNumericOptions(), orig_input_desc, orig_input_data, - orig_output_desc, orig_output_data, output_backprop_data, - &input_backprop_data, &scratch_allocator)); + dnn->PoolBackward(stream, pooling_desc, GetNumericOptions(), + orig_input_desc, orig_input_data, orig_output_desc, + orig_output_data, output_backprop_data, + &input_backprop_data, &scratch_allocator)); #else OP_REQUIRES_OK(context, - stream->ThenPoolBackward( - pooling_desc, GetNumericOptions(), orig_input_desc, - orig_input_data, orig_output_desc, orig_output_data, - output_backprop_data, &input_backprop_data)); + dnn->PoolBackward(stream, pooling_desc, GetNumericOptions(), + orig_input_desc, orig_input_data, + orig_output_desc, orig_output_data, + output_backprop_data, &input_backprop_data)); #endif if (data_format == FORMAT_NHWC) { diff --git a/tensorflow/core/kernels/cudnn_rnn_ops.cc b/tensorflow/core/kernels/cudnn_rnn_ops.cc index 1e6cf1846adbc1..a96cf937e277a5 100644 --- a/tensorflow/core/kernels/cudnn_rnn_ops.cc +++ b/tensorflow/core/kernels/cudnn_rnn_ops.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/kernel_def_builder.h" @@ -137,6 +138,13 @@ using se::dnn::RnnStateTensorDescriptor; using se::dnn::ToDataType; using tsl::StatusOr; +absl::StatusOr GetDnn(Stream* stream) { + auto dnn = stream->parent()->AsDnn(); + if (dnn == nullptr) { + return absl::InternalError("No DNN for stream"); + } + return dnn; +} uint64 HashList(const std::vector& list) { if (list.empty()) { return 0; @@ -708,22 +716,26 @@ Status CreateForwardAndBackwardIODescriptors( const TensorShape& output_shape = model_shapes.output_shape; DCHECK_EQ(input_shape.dims(), 3); + auto dnn = executor->AsDnn(); + if (dnn == nullptr) { + return absl::InvalidArgumentError("No dnn in the executor."); + } if (seq_lengths.data() != nullptr) { if (time_major) { - auto input_desc_s = executor->createRnnSequenceTensorDescriptor( + auto input_desc_s = dnn->CreateRnnSequenceTensorDescriptor( input_shape.dim_size(0), input_shape.dim_size(1), input_shape.dim_size(2), seq_lengths, time_major, data_type); TF_RETURN_IF_ERROR(input_desc_s.status()); *input_desc = std::move(input_desc_s).value(); } else { - auto input_desc_s = executor->createRnnSequenceTensorDescriptor( + auto input_desc_s = dnn->CreateRnnSequenceTensorDescriptor( input_shape.dim_size(1), input_shape.dim_size(0), input_shape.dim_size(2), seq_lengths, time_major, data_type); TF_RETURN_IF_ERROR(input_desc_s.status()); *input_desc = std::move(input_desc_s).value(); } } else { - auto input_desc_s = executor->createRnnSequenceTensorDescriptor( + auto input_desc_s = dnn->CreateRnnSequenceTensorDescriptor( input_shape.dim_size(0), input_shape.dim_size(1), input_shape.dim_size(2), data_type); TF_RETURN_IF_ERROR(input_desc_s.status()); @@ -732,13 +744,13 @@ Status CreateForwardAndBackwardIODescriptors( DCHECK_EQ(hidden_state_shape.dims(), 3); if (time_major) { - auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor( + auto hidden_state_desc_s = dnn->CreateRnnStateTensorDescriptor( hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1), hidden_state_shape.dim_size(2), data_type); TF_RETURN_IF_ERROR(hidden_state_desc_s.status()); *h_state_desc = std::move(hidden_state_desc_s).value(); } else { - auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor( + auto hidden_state_desc_s = dnn->CreateRnnStateTensorDescriptor( hidden_state_shape.dim_size(1), hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(2), data_type); TF_RETURN_IF_ERROR(hidden_state_desc_s.status()); @@ -747,13 +759,13 @@ Status CreateForwardAndBackwardIODescriptors( DCHECK_EQ(cell_state_shape.dims(), 3); if (time_major) { - auto cell_state_desc_s = executor->createRnnStateTensorDescriptor( + auto cell_state_desc_s = dnn->CreateRnnStateTensorDescriptor( cell_state_shape.dim_size(0), cell_state_shape.dim_size(1), cell_state_shape.dim_size(2), data_type); TF_RETURN_IF_ERROR(cell_state_desc_s.status()); *c_state_desc = std::move(cell_state_desc_s).value(); } else { - auto cell_state_desc_s = executor->createRnnStateTensorDescriptor( + auto cell_state_desc_s = dnn->CreateRnnStateTensorDescriptor( cell_state_shape.dim_size(1), cell_state_shape.dim_size(0), cell_state_shape.dim_size(2), data_type); TF_RETURN_IF_ERROR(cell_state_desc_s.status()); @@ -763,20 +775,20 @@ Status CreateForwardAndBackwardIODescriptors( DCHECK_EQ(output_shape.dims(), 3); if (seq_lengths.data() != nullptr) { if (time_major) { - auto output_desc_s = executor->createRnnSequenceTensorDescriptor( + auto output_desc_s = dnn->CreateRnnSequenceTensorDescriptor( output_shape.dim_size(0), output_shape.dim_size(1), output_shape.dim_size(2), seq_lengths, time_major, data_type); TF_RETURN_IF_ERROR(output_desc_s.status()); *output_desc = std::move(output_desc_s).value(); } else { - auto output_desc_s = executor->createRnnSequenceTensorDescriptor( + auto output_desc_s = dnn->CreateRnnSequenceTensorDescriptor( output_shape.dim_size(1), output_shape.dim_size(0), output_shape.dim_size(2), seq_lengths, time_major, data_type); TF_RETURN_IF_ERROR(output_desc_s.status()); *output_desc = std::move(output_desc_s).value(); } } else { - auto output_desc_s = executor->createRnnSequenceTensorDescriptor( + auto output_desc_s = dnn->CreateRnnSequenceTensorDescriptor( output_shape.dim_size(0), output_shape.dim_size(1), output_shape.dim_size(2), data_type); TF_RETURN_IF_ERROR(output_desc_s.status()); @@ -849,21 +861,18 @@ Status DoForwardImpl(OpKernelContext* context, const RnnDescriptor& rnn_desc, } } - bool launch_success = - stream - ->ThenRnnForward(rnn_desc, *input_desc, input_data, seq_lengths_ptr, - *h_state_desc, input_h_data, *c_state_desc, - input_c_data, params_data, *output_desc, - &output_data, *h_state_desc, &output_h_data, - *c_state_desc, &output_c_data, is_training, - reserve_space_allocator, workspace_allocator, - output_profile_result) - .ok(); - return launch_success - ? OkStatus() - : errors::Internal( - "Failed to call ThenRnnForward with model config: ", - model_types.DebugString(), ", ", model_shapes.DebugString()); + TF_ASSIGN_OR_RETURN(auto dnn, GetDnn(stream)); + bool launch_success = dnn->DoRnnForward( + stream, rnn_desc, *input_desc, input_data, seq_lengths_ptr, *h_state_desc, + input_h_data, *c_state_desc, input_c_data, params_data, *output_desc, + &output_data, *h_state_desc, &output_h_data, *c_state_desc, + &output_c_data, is_training, reserve_space_allocator, workspace_allocator, + output_profile_result); + return launch_success ? OkStatus() + : absl::InternalError(absl::StrCat( + "Failed to call DoRnnForward with model config: ", + model_types.DebugString(), ", ", + model_shapes.DebugString())); } template @@ -1035,23 +1044,21 @@ Status DoBackwardImpl( } } - bool launch_success = - stream - ->ThenRnnBackward( - rnn_desc, *input_desc, input_data, seq_lengths_ptr, *h_state_desc, - input_h_data, *c_state_desc, input_c_data, params_data, - *output_desc, output_data, *h_state_desc, output_h_data, - *c_state_desc, output_c_data, output_backprop_data, - output_h_backprop_data, output_c_backprop_data, - &input_backprop_data, &input_h_backprop_data, - &input_c_backprop_data, ¶ms_backprop_data, - &reserve_space_uint8, workspace_allocator, output_profile_result) - .ok(); + TF_ASSIGN_OR_RETURN(auto dnn, GetDnn(stream)); + bool launch_success = dnn->DoRnnBackward( + stream, rnn_desc, *input_desc, input_data, seq_lengths_ptr, *h_state_desc, + input_h_data, *c_state_desc, input_c_data, params_data, *output_desc, + output_data, *h_state_desc, output_h_data, *c_state_desc, output_c_data, + output_backprop_data, output_h_backprop_data, output_c_backprop_data, + &input_backprop_data, &input_h_backprop_data, &input_c_backprop_data, + ¶ms_backprop_data, &reserve_space_uint8, workspace_allocator, + output_profile_result); return launch_success ? OkStatus() - : errors::Internal( - "Failed to call ThenRnnBackward with model config: ", - model_types.DebugString(), ", ", model_shapes.DebugString()); + : absl::InternalError(absl::StrCat( + "Failed to call DoRnnBackward with model config: ", + model_types.DebugString(), ", ", + model_shapes.DebugString())); } template @@ -1294,7 +1301,11 @@ class CudnnRNNKernelCommon : public OpKernel { // ExtracCudnnRNNParamsInfo is only called by op_kernels that do not require // random number generator, therefore set state_allocator to nullptr. const AlgorithmConfig algo_config; - auto rnn_desc_s = stream->parent()->createRnnDescriptor( + auto dnn = stream->parent()->AsDnn(); + if (dnn == nullptr) { + return absl::InvalidArgumentError("Dnn is not supported"); + } + auto rnn_desc_s = dnn->CreateRnnDescriptor( num_layers, h_num_units, input_size, /*cell_size=*/c_num_units, /*batch_size=*/0, input_mode, rnn_direction_mode(), rnn_mode(), ToDataType::value, algo_config, GetNumericOptions(), dropout(), @@ -1319,7 +1330,11 @@ class CudnnRNNKernelCommon : public OpKernel { se::dnn::DataType data_type = std::is_same_v ? se::dnn::DataType::kFloat : ToDataType::value; - auto rnn_desc_s = executor->createRnnDescriptor( + auto dnn = executor->AsDnn(); + if (dnn == nullptr) { + return absl::InvalidArgumentError("Dnn is not supported"); + } + auto rnn_desc_s = dnn->CreateRnnDescriptor( model_shapes.num_layers, model_shapes.num_units, model_shapes.input_size, model_shapes.cell_num_units, model_shapes.batch_size, input_mode, rnn_direction_mode(), rnn_mode(), @@ -1892,7 +1907,11 @@ class CudnnRNNForwardOpV2 std::vector algorithms; auto* stream = context->op_device_context()->stream(); - CHECK(stream->parent()->GetRnnAlgorithms(&algorithms)); + auto dnn = stream->parent()->AsDnn(); + if (dnn == nullptr) { + return absl::InvalidArgumentError("No DNN found"); + } + CHECK(dnn->GetRnnAlgorithms(&algorithms)); if (algorithms.empty()) { LOG(WARNING) << "No Rnn algorithm found"; return OkStatus(); diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index dda6eb01038eb2..a0d17a35b104b0 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -1003,6 +1003,8 @@ tf_kernel_library( "//tensorflow/core/data:name_utils", "//tensorflow/core/data:split_utils", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:types", ], ) diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc index 2b4610ea751a6a..2d7c09361fdd63 100644 --- a/tensorflow/core/kernels/data/batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/batch_dataset_op.cc @@ -119,7 +119,7 @@ class BatchDatasetOp::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -143,9 +143,8 @@ class BatchDatasetOp::Dataset : public DatasetBase { batch_elements.emplace_back(std::move(batch_element_tuple)); } TF_RETURN_IF_ERROR(CopyBatch(CopyBatchParams(ctx), batch_elements, - parallel_copy_, - /*allocation_callback=*/nullptr, out_tensors)); - return OkStatus(); + parallel_copy_, out_tensors)); + return absl::OkStatus(); } protected: @@ -163,7 +162,7 @@ class BatchDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR( b->AddDataset(this, {input_graph_node, batch_size, drop_remainder}, {{kParallelCopy, parallel_copy}}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -188,7 +187,7 @@ class BatchDatasetOp::Dataset : public DatasetBase { mutex_lock l(mu_); if (!input_impl_) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } batch_elements.reserve(dataset()->reserve_size_); *end_of_sequence = false; @@ -206,13 +205,13 @@ class BatchDatasetOp::Dataset : public DatasetBase { if (batch_elements.empty()) { DCHECK(*end_of_sequence); - return OkStatus(); + return absl::OkStatus(); } if (dataset()->drop_remainder_ && batch_elements.size() < dataset()->batch_size_) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } // Copy the retrieved batch elements into one output tensor per tuple @@ -223,12 +222,11 @@ class BatchDatasetOp::Dataset : public DatasetBase { // respective slice locations. This would require a different GetNext() // overload that supports zero-copy, and might make sense in an // optimization pass. - TF_RETURN_IF_ERROR(CopyBatch( - CopyBatchParams(ctx), batch_elements, dataset()->parallel_copy_, - /*allocation_callback=*/nullptr, out_tensors)); + TF_RETURN_IF_ERROR(CopyBatch(CopyBatchParams(ctx), batch_elements, + dataset()->parallel_copy_, out_tensors)); *end_of_sequence = false; - return OkStatus(); + return absl::OkStatus(); } protected: @@ -245,7 +243,7 @@ class BatchDatasetOp::Dataset : public DatasetBase { if (input_impl_) { TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -259,7 +257,7 @@ class BatchDatasetOp::Dataset : public DatasetBase { } else { input_impl_.reset(); } - return OkStatus(); + return absl::OkStatus(); } TraceMeMetadata GetTraceMeMetadata() const override { diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc index 4344e57af08fb1..98d75f2cfe2641 100644 --- a/tensorflow/core/kernels/data/cache_dataset_ops.cc +++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc @@ -93,7 +93,7 @@ class PartialCache { TF_RETURN_IF_ERROR(ExtendTempCacheToIndex(index, ctx)); } *out_tensors = cache_.at(index); - return OkStatus(); + return absl::OkStatus(); } // Returns the data which has been cached up to this point. @@ -112,11 +112,12 @@ class PartialCache { } cache_.push_back(out_tensors); } - return OkStatus(); + return absl::OkStatus(); } - StatusOr> GetIteratorResourceFromDataset( - OpKernelContext* ctx, const DatasetBase* dataset) { + absl::StatusOr> + GetIteratorResourceFromDataset(OpKernelContext* ctx, + const DatasetBase* dataset) { FunctionLibraryRuntime* flr; std::unique_ptr device_mgr(nullptr); std::unique_ptr flib_def(nullptr); @@ -183,7 +184,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -329,7 +330,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { *end_of_sequence = false; TF_RETURN_IF_ERROR(EnsureLockFileExists(end_of_sequence)); if (*end_of_sequence) { - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR(writer_->status()); if (cur_index_ >= kMaxItems) { @@ -348,7 +349,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { if (*end_of_sequence && out_tensors->empty()) { TF_RETURN_IF_ERROR(Finish()); cur_index_++; - return OkStatus(); + return absl::OkStatus(); } if (out_tensors->size() != dataset()->num_tensors_) { return errors::Internal( @@ -366,7 +367,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { TF_RETURN_IF_ERROR(Finish()); } cur_index_++; - return OkStatus(); + return absl::OkStatus(); } protected: @@ -385,7 +386,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { if (iteration_completed_) { TF_RETURN_IF_ERROR( writer->WriteScalar(prefix(), kIterationCompleted, "")); - return OkStatus(); + return absl::OkStatus(); } // lockfile is created on the first call to GetNextInternal. The @@ -409,7 +410,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { } TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kShardId, shard_id_)); - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -428,7 +429,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { if (reader->Contains(prefix(), kIterationCompleted)) { iteration_completed_ = true; - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); @@ -445,7 +446,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { filename_ = strings::StrCat(dataset()->filename_, "_", shard_id_); lockfile_ = strings::StrCat(filename_, kLockFileSuffix); writer_ = std::make_unique(dataset()->env_, filename_); - return OkStatus(); + return absl::OkStatus(); } private: @@ -453,10 +454,10 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (iteration_completed_) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } if (lockfile_created_) { - return OkStatus(); + return absl::OkStatus(); } // Perform rudimentary locking to help catch concurrent writes to the @@ -508,7 +509,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { // BundleWriter in another Session. writer_ = std::make_unique(dataset()->env_, filename_); lockfile_created_ = true; - return OkStatus(); + return absl::OkStatus(); } Status Finish() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { @@ -537,7 +538,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { TF_RETURN_IF_ERROR(dataset()->env_->DeleteFile( strings::StrCat(dataset()->filename_, "_", i, kLockFileSuffix))); } - return OkStatus(); + return absl::OkStatus(); } mutex mu_; @@ -571,7 +572,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { TF_RETURN_IF_ERROR(reader_.status()); if (!reader_.Valid()) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } out_tensors->clear(); out_tensors->resize(dataset()->num_tensors_); @@ -588,7 +589,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { if (!reader_.Valid()) { out_tensors->clear(); *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } StringPiece key = reader_.key(); DCHECK_EQ(key, dataset()->FormatName(cur_index_, i)); @@ -596,7 +597,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { TF_RETURN_IF_ERROR(reader_.status()); } cur_index_++; - return OkStatus(); + return absl::OkStatus(); } protected: @@ -611,7 +612,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { mutex_lock l(mu_); TF_RETURN_IF_ERROR( writer->WriteScalar(prefix(), kCurIndex, cur_index_)); - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal( @@ -634,7 +635,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { } reader_.Seek(dataset()->FormatName(cur_index_, 0)); iterator_restored_ = true; - return OkStatus(); + return absl::OkStatus(); } private: @@ -696,7 +697,7 @@ class CacheDatasetOp::FileDataset : public CacheDatasetOp::FileDatasetBase { Node* filename = nullptr; TF_RETURN_IF_ERROR(b->AddScalar(filename_, &filename)); TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph, filename}, output)); - return OkStatus(); + return absl::OkStatus(); } }; @@ -720,7 +721,7 @@ class CacheDatasetOp::FileDatasetV2 : public CacheDatasetOp::FileDatasetBase { TF_RETURN_IF_ERROR(b->AddTensor(resource_handle_, &resource_handle_node)); TF_RETURN_IF_ERROR(b->AddDataset( this, {input_node, filename_node, resource_handle_node}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -788,7 +789,7 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -876,7 +877,7 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { VLOG(2) << "Finalizing the cache because EOF has been reached."; cache_->Complete(std::move(temp_cache_)); } - return OkStatus(); + return absl::OkStatus(); } RecordBufferEnqueue(ctx, *out_tensors); temp_cache_.emplace_back(*out_tensors); @@ -885,7 +886,7 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { "expected input cardinality."; cache_->Complete(std::move(temp_cache_)); } - return OkStatus(); + return absl::OkStatus(); } protected: @@ -939,7 +940,7 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { for (size_t i = 0; i < cache_->size(); ++i) { RecordBufferEnqueue(ctx, cache_->at(i)); } - return OkStatus(); + return absl::OkStatus(); } Status GetNextInternal(IteratorContext* ctx, @@ -952,10 +953,10 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { cache_tensors.end()); index_++; *end_of_sequence = false; - return OkStatus(); + return absl::OkStatus(); } else { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } } @@ -970,7 +971,7 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kIndex, index_)); - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -985,7 +986,7 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { } index_ = static_cast(temp); } - return OkStatus(); + return absl::OkStatus(); } private: @@ -1053,7 +1054,7 @@ class CacheDatasetOp::MemoryDataset : public CacheDatasetOp::MemoryDatasetBase { TF_RETURN_IF_ERROR(b->AddScalar(tstring(""), &filename_node)); TF_RETURN_IF_ERROR( b->AddDataset(this, {input_node, filename_node}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -1102,7 +1103,7 @@ class CacheDatasetOp::MemoryDatasetV2 TF_RETURN_IF_ERROR(b->AddTensor(handle, &resource_handle_node)); TF_RETURN_IF_ERROR(b->AddDataset( this, {input_node, filename_node, resource_handle_node}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -1139,7 +1140,7 @@ void CacheDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, ctx->resource_manager()->LookupOrCreate( container, name, &manager, [](MemoryCacheManager** manager) { *manager = new MemoryCacheManager(); - return OkStatus(); + return absl::OkStatus(); })); handle = MakeResourceHandle(ctx, container, name); } else { @@ -1154,7 +1155,7 @@ void CacheDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, ctx, ctx->resource_manager()->LookupOrCreate( container, name, &manager, [](MemoryCacheManager** manager) { *manager = new MemoryCacheManager(); - return OkStatus(); + return absl::OkStatus(); })); auto handle = MakeResourceHandle(ctx, container, name); diff --git a/tensorflow/core/kernels/data/cache_dataset_ops_test.cc b/tensorflow/core/kernels/data/cache_dataset_ops_test.cc index 578633988a3ce1..0f8f7f8824b2b5 100644 --- a/tensorflow/core/kernels/data/cache_dataset_ops_test.cc +++ b/tensorflow/core/kernels/data/cache_dataset_ops_test.cc @@ -51,14 +51,14 @@ class CacheDatasetParams : public DatasetParams { Status GetInputNames(std::vector* input_names) const override { *input_names = {CacheDatasetOp::kInputDataset, CacheDatasetOp::kFileName}; - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { *attr_vector = {{"output_types", output_dtypes_}, {"output_shapes", output_shapes_}, {"metadata", ""}}; - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { return CacheDatasetOp::kDatasetType; } @@ -75,7 +75,7 @@ class CacheDatasetOpTest : public DatasetOpsTestBase { TF_RETURN_IF_ERROR(DatasetOpsTestBase::Initialize(dataset_params)); auto params = static_cast(dataset_params); cache_filename_ = params.filename(); - return OkStatus(); + return absl::OkStatus(); } ~CacheDatasetOpTest() override { diff --git a/tensorflow/core/kernels/data/cache_ops.cc b/tensorflow/core/kernels/data/cache_ops.cc index 6c6c7a6980aeee..002d3876e61ef0 100644 --- a/tensorflow/core/kernels/data/cache_ops.cc +++ b/tensorflow/core/kernels/data/cache_ops.cc @@ -81,7 +81,7 @@ Status AnonymousMemoryCacheHandleOp::CreateResource( std::unique_ptr pflr, FunctionLibraryRuntime* lib, MemoryCacheManager** manager) { *manager = new MemoryCacheManager(); - return OkStatus(); + return absl::OkStatus(); } void DeleteMemoryCacheOp::Compute(OpKernelContext* ctx) { diff --git a/tensorflow/core/kernels/data/concatenate_dataset_op.cc b/tensorflow/core/kernels/data/concatenate_dataset_op.cc index fe8990b333c822..8f380a29b87193 100644 --- a/tensorflow/core/kernels/data/concatenate_dataset_op.cc +++ b/tensorflow/core/kernels/data/concatenate_dataset_op.cc @@ -73,7 +73,7 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase { Status MakeSplitProviders(std::vector>* split_providers) const override { TF_ASSIGN_OR_RETURN(*split_providers, GetSplitProviders(this)); - return OkStatus(); + return absl::OkStatus(); } const DataTypeVector& output_dtypes() const override { @@ -106,7 +106,7 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); inputs->push_back(to_concatenate_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -123,7 +123,7 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR( to_concatenate_->Get(ctx, index - input_cardinality_, out_tensors)); } - return OkStatus(); + return absl::OkStatus(); } protected: @@ -137,7 +137,7 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase { b->AddInputDataset(ctx, to_concatenate_, &to_concatenate_graph)); TF_RETURN_IF_ERROR( b->AddDataset(this, {input_graph, to_concatenate_graph}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -155,7 +155,7 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase { &input_contexts_[0], this, strings::StrCat(prefix(), "[0]"), &input_impl_)); ctx->MergeCheckpoint(input_contexts_[0].checkpoint()); - return OkStatus(); + return absl::OkStatus(); } Status GetNextInternal(IteratorContext* ctx, @@ -164,14 +164,14 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase { mutex_lock l(mu_); if (!input_impl_) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } while (i_ < 2) { TF_RETURN_IF_ERROR(input_impl_->GetNext(&input_contexts_[i_], out_tensors, end_of_sequence)); ctx->MergeCheckpoint(input_contexts_[i_].checkpoint()); if (!*end_of_sequence) { - return OkStatus(); + return absl::OkStatus(); } if (++i_ < 2) { TF_RETURN_IF_ERROR(dataset()->to_concatenate_->MakeIterator( @@ -181,7 +181,7 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase { } *end_of_sequence = true; input_impl_.reset(); - return OkStatus(); + return absl::OkStatus(); } protected: @@ -201,7 +201,7 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase { if (input_impl_) { TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -213,7 +213,7 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase { &input_uninitialized)); if (static_cast(input_uninitialized)) { input_impl_.reset(); - return OkStatus(); + return absl::OkStatus(); } if (!TF_PREDICT_TRUE(i_ >= 0 && i_ <= 2)) return errors::InvalidArgument("i_ must be in range [0, 2]."); @@ -226,7 +226,7 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase { if (input_impl_) { TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); } - return OkStatus(); + return absl::OkStatus(); } private: @@ -240,7 +240,7 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase { const PartialTensorShape& ts2, PartialTensorShape* output_tensorshape) { if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank()) - return OkStatus(); + return absl::OkStatus(); auto dims1 = ts1.dim_sizes(); auto dims2 = ts2.dim_sizes(); for (int d = 0; d < ts1.dims(); d++) { @@ -249,7 +249,7 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase { else TF_RETURN_IF_ERROR(output_tensorshape->AddDimWithStatus(-1)); } - return OkStatus(); + return absl::OkStatus(); } const DatasetBase* input_; diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD index a88b206f7d752c..b8c54db8486fb0 100644 --- a/tensorflow/core/kernels/data/experimental/BUILD +++ b/tensorflow/core/kernels/data/experimental/BUILD @@ -329,6 +329,26 @@ tf_cc_test( ], ) +tf_kernel_library( + name = "global_shuffle_dataset_op", + srcs = ["global_shuffle_dataset_op.cc"], + deps = [ + "//tensorflow/core:dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core/data:name_utils", + "//tensorflow/core/framework:types_proto_cc", + "//tensorflow/core/kernels:random_index_shuffle", + "//tensorflow/core/kernels/data:random_seed_ops", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + ], +) + tf_kernel_library( name = "group_by_reducer_dataset_op", srcs = ["group_by_reducer_dataset_op.cc"], @@ -935,6 +955,7 @@ tf_kernel_library( ":csv_dataset_op", ":dense_to_sparse_batch_dataset_op", ":directed_interleave_dataset_op", + ":global_shuffle_dataset_op", ":group_by_reducer_dataset_op", ":group_by_window_dataset_op", ":ignore_errors_dataset_op", diff --git a/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc b/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc index c9221f8ae502f1..99c8045f7c0fe7 100644 --- a/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc @@ -72,7 +72,7 @@ class AssertCardinalityDatasetOp::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -89,7 +89,7 @@ class AssertCardinalityDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(b->AddScalar(cardinality_, &cardinality_node)); TF_RETURN_IF_ERROR( b->AddDataset(this, {input_graph_node, cardinality_node}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -125,7 +125,7 @@ class AssertCardinalityDatasetOp::Dataset : public DatasetBase { ElementString(dataset()->cardinality_), " but contained at least ", ElementString(num_elements_), "."); } - return OkStatus(); + return absl::OkStatus(); } protected: @@ -140,7 +140,7 @@ class AssertCardinalityDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("num_elements"), num_elements_)); TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -148,7 +148,7 @@ class AssertCardinalityDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR( reader->ReadScalar(full_name("num_elements"), &num_elements_)); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc index d65faf6f7d3e1d..6dac458703c872 100644 --- a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc @@ -68,7 +68,7 @@ class AssertNextDatasetOp::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -85,7 +85,7 @@ class AssertNextDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(b->AddVector(transformations_, &transformations_node)); TF_RETURN_IF_ERROR( b->AddDataset(this, {input_graph_node, transformations_node}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -132,13 +132,13 @@ class AssertNextDatasetOp::Dataset : public DatasetBase { Status SaveInternal(SerializationContext* ctx, IteratorStateWriter* writer) override { TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op_test.cc index 18da142aad6a42..96a08f9acf2a2b 100644 --- a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op_test.cc @@ -49,13 +49,13 @@ class AssertNextDatasetParams : public DatasetParams { input_names->reserve(input_dataset_params_.size() + 1); input_names->emplace_back(AssertNextDatasetOp::kInputDataset); input_names->emplace_back(AssertNextDatasetOp::kTransformations); - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { *attr_vector = {{AssertNextDatasetOp::kOutputShapes, output_shapes_}, {AssertNextDatasetOp::kOutputTypes, output_dtypes_}}; - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { diff --git a/tensorflow/core/kernels/data/experimental/assert_prev_dataset_op.cc b/tensorflow/core/kernels/data/experimental/assert_prev_dataset_op.cc index bcee58bdd317ef..a7589ecc12bdd4 100644 --- a/tensorflow/core/kernels/data/experimental/assert_prev_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/assert_prev_dataset_op.cc @@ -74,7 +74,7 @@ Status CheckOpName(const DatasetBase& dataset, const NameAttrList& assertions) { assertions.name(), "', but found '", dataset.type_string(), "'."); } - return OkStatus(); + return absl::OkStatus(); } // Returns a NodeDef representation of `dataset`. @@ -92,7 +92,7 @@ StatusOr GetDatasetNode(const DatasetBase& dataset, // Checks `dataset`'s attrs against those in `assertions`. Status CheckAttributes(const DatasetBase& dataset, const NameAttrList& assertions) { - if (assertions.attr().empty()) return OkStatus(); + if (assertions.attr().empty()) return absl::OkStatus(); TF_ASSIGN_OR_RETURN(NodeDef node, GetDatasetNode(dataset, assertions.name())); std::vector attrs_not_found; for (const auto& attr : assertions.attr()) { @@ -116,7 +116,7 @@ Status CheckAttributes(const DatasetBase& dataset, attr.second.DebugString(), "', but found no such attribute defined."); } } - return OkStatus(); + return absl::OkStatus(); } // Checks `dataset`'s op name and attrs against those in `transformation`. @@ -125,7 +125,7 @@ Status CheckTransformation(const DatasetBase& dataset, TF_ASSIGN_OR_RETURN(NameAttrList assertions, GetAssertions(transformation)); TF_RETURN_IF_ERROR(CheckOpName(dataset, assertions)); TF_RETURN_IF_ERROR(CheckAttributes(dataset, assertions)); - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -167,7 +167,7 @@ class AssertPrevDatasetOp::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -184,7 +184,7 @@ class AssertPrevDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(b->AddVector(transformations_, &transformations_node)); TF_RETURN_IF_ERROR( b->AddDataset(this, {input_graph_node, transformations_node}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -233,13 +233,13 @@ class AssertPrevDatasetOp::Dataset : public DatasetBase { Status SaveInternal(SerializationContext* ctx, IteratorStateWriter* writer) override { TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/data/experimental/assert_prev_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/assert_prev_dataset_op_test.cc index 0bff98260f0055..cb7bd224bb23f2 100644 --- a/tensorflow/core/kernels/data/experimental/assert_prev_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/experimental/assert_prev_dataset_op_test.cc @@ -71,13 +71,13 @@ class AssertPrevDatasetParams : public DatasetParams { input_names->reserve(input_dataset_params_.size() + 1); input_names->emplace_back(AssertPrevDatasetOp::kInputDataset); input_names->emplace_back(AssertPrevDatasetOp::kTransformations); - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { *attr_vector = {{AssertPrevDatasetOp::kOutputShapes, output_shapes_}, {AssertPrevDatasetOp::kOutputTypes, output_dtypes_}}; - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { diff --git a/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op_test.cc index 935c2f9387d23f..e7b2925b0dca25 100644 --- a/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op_test.cc @@ -55,7 +55,7 @@ class AutoShardDatasetParams : public DatasetParams { input_names->emplace_back(AutoShardDatasetOp::kInputDataset); input_names->emplace_back(AutoShardDatasetOp::kNumWorkers); input_names->emplace_back(AutoShardDatasetOp::kIndex); - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { @@ -66,7 +66,7 @@ class AutoShardDatasetParams : public DatasetParams { attr_vector->emplace_back(AutoShardDatasetOp::kOutputTypes, output_dtypes_); attr_vector->emplace_back(AutoShardDatasetOp::kOutputShapes, output_shapes_); - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { diff --git a/tensorflow/core/kernels/data/experimental/choose_fastest_branch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/choose_fastest_branch_dataset_op.cc index efd9102f0240ae..2d2b239205302b 100644 --- a/tensorflow/core/kernels/data/experimental/choose_fastest_branch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/choose_fastest_branch_dataset_op.cc @@ -56,10 +56,10 @@ class WrapperDataset : public DatasetBase { string DebugString() const override { return "WrapperDataset"; } Status InputDatasets(std::vector* inputs) const override { - return OkStatus(); + return absl::OkStatus(); } - Status CheckExternalState() const override { return OkStatus(); } + Status CheckExternalState() const override { return absl::OkStatus(); } protected: Status AsGraphDefInternal(SerializationContext* ctx, @@ -93,7 +93,7 @@ class WrapperDataset : public DatasetBase { "Make sure the branches to ChooseFastestDataset do not expect the " "input to repeat."); } - return OkStatus(); + return absl::OkStatus(); } Status GetNextInternal(IteratorContext* ctx, @@ -111,12 +111,12 @@ class WrapperDataset : public DatasetBase { Status SaveInternal(SerializationContext* ctx, IteratorStateWriter* writer) override { - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { - return OkStatus(); + return absl::OkStatus(); } private: @@ -252,7 +252,7 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel { Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -351,7 +351,7 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel { ctx, &instantiated_captured_funcs_[i])); } - return OkStatus(); + return absl::OkStatus(); } // The first num_elements_per_branch * num_branches iterations, we run @@ -422,7 +422,7 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("input_impl_empty"), "")); } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -449,7 +449,7 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel { } TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, current_iterator_)); } - return OkStatus(); + return absl::OkStatus(); } private: @@ -545,7 +545,7 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel { ¤t_iterator_, /*node=*/nullptr)); } - return OkStatus(); + return absl::OkStatus(); } mutex mu_; diff --git a/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc b/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc index 606dd8b1f216b4..41352a2cd40f5a 100644 --- a/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc @@ -165,14 +165,14 @@ class ChooseFastestDatasetOp : public DatasetOpKernel { for (const auto& input : inputs_) { inputs->push_back(input); } - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { for (const auto& input : inputs_) { TF_RETURN_IF_ERROR(input->CheckExternalState()); } - return OkStatus(); + return absl::OkStatus(); } protected: @@ -210,7 +210,7 @@ class ChooseFastestDatasetOp : public DatasetOpKernel { ctx, this, strings::StrCat(prefix(), "[", i, "]"), &input_impls_[i])); } - return OkStatus(); + return absl::OkStatus(); } Status GetNextInternal(IteratorContext* ctx, @@ -266,7 +266,7 @@ class ChooseFastestDatasetOp : public DatasetOpKernel { TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl)); } } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -289,7 +289,7 @@ class ChooseFastestDatasetOp : public DatasetOpKernel { TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl)); } } - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc index 935ea38b0797dc..80c717806f7be7 100644 --- a/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc @@ -200,12 +200,12 @@ class CSVDatasetOp : public DatasetOpKernel { string DebugString() const override { return "CSVDatasetOp::Dataset"; } - Status CheckExternalState() const override { return OkStatus(); } + Status CheckExternalState() const override { return absl::OkStatus(); } Status InputDatasets( std::vector* inputs) const override { inputs->clear(); - return OkStatus(); + return absl::OkStatus(); } protected: @@ -269,7 +269,7 @@ class CSVDatasetOp : public DatasetOpKernel { {std::make_pair(8, record_defaults)}, // Tensor list inputs {}, output)); } - return OkStatus(); + return absl::OkStatus(); } private: @@ -314,7 +314,7 @@ class CSVDatasetOp : public DatasetOpKernel { // Iteration ends when there are no more files to process. if (current_file_index_ == dataset()->filenames_.size()) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); } while (true); @@ -340,7 +340,7 @@ class CSVDatasetOp : public DatasetOpKernel { TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_buffer_reads"), num_buffer_reads_)); } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -380,7 +380,7 @@ class CSVDatasetOp : public DatasetOpKernel { } pos_ = size_t(pos); } - return OkStatus(); + return absl::OkStatus(); } private: @@ -451,7 +451,7 @@ class CSVDatasetOp : public DatasetOpKernel { if (include) { return FieldToOutput(ctx, StringPiece(), out_tensors); } else { - return OkStatus(); + return absl::OkStatus(); } } else if (!s.ok()) { return s; // Surface other errors back to caller @@ -575,7 +575,7 @@ class CSVDatasetOp : public DatasetOpKernel { const std::vector& earlier_pieces, bool include) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (!include) return OkStatus(); + if (!include) return absl::OkStatus(); if (earlier_pieces.empty()) { if (field.find('\"', 1) == field.size() - 1) { @@ -695,7 +695,7 @@ class CSVDatasetOp : public DatasetOpKernel { if (errors::IsOutOfRange(s) && !result->empty()) { // Ignore OutOfRange error when ReadNBytes read < N bytes. - return OkStatus(); + return absl::OkStatus(); } return s; } @@ -798,7 +798,7 @@ class CSVDatasetOp : public DatasetOpKernel { " not supported in field ", output_idx); } - return OkStatus(); + return absl::OkStatus(); } // Records can be delimited by "\r\n" line breaks. When we encounter a @@ -825,7 +825,7 @@ class CSVDatasetOp : public DatasetOpKernel { const std::vector& earlier_pieces, bool include) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (!include) return OkStatus(); + if (!include) return absl::OkStatus(); if (earlier_pieces.empty()) { return FieldToOutput(ctx, field, out_tensors); @@ -883,7 +883,7 @@ class CSVDatasetOp : public DatasetOpKernel { return errors::InvalidArgument("Can't read header of file"); } } - return OkStatus(); + return absl::OkStatus(); } // Resets all reader streams. diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc index d3affc6c824070..a75018294d7874 100644 --- a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc @@ -196,7 +196,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->clear(); - return OkStatus(); + return absl::OkStatus(); } protected: @@ -353,7 +353,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { data_service_client_.GetNext(ctx_factory)); *out_tensors = std::move(result.tensors); *end_of_sequence = result.end_of_sequence; - return OkStatus(); + return absl::OkStatus(); } protected: @@ -668,7 +668,7 @@ void DataServiceDatasetOp::MakeDataset(OpKernelContext* ctx, container, name, &iteration_counter, [](IterationCounter** counter) { *counter = new IterationCounter(); - return OkStatus(); + return absl::OkStatus(); })); iteration_counter_handle = MakeResourceHandle(ctx, container, name); diff --git a/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc index cbdf0ae1b88cf2..8bfc9ade778f1f 100644 --- a/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc @@ -123,7 +123,7 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel { Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -147,7 +147,7 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR(b->AddVector(row_shape, &row_shape_node)); TF_RETURN_IF_ERROR(b->AddDataset( this, {input_node, batch_size_node, row_shape_node}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -231,7 +231,7 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel { if (batch_elements.empty()) { DCHECK(*end_of_sequence); - return OkStatus(); + return absl::OkStatus(); } // * indices will be [`total_elements`, `row_shape + 1`]. @@ -284,7 +284,7 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel { out_tensors->push_back(std::move(serialized_sparse)); *end_of_sequence = false; - return OkStatus(); + return absl::OkStatus(); } protected: @@ -299,14 +299,14 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel { IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(Iterator::SaveInput(ctx, writer, input_impl_)); - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(Iterator::RestoreInput(ctx, reader, input_impl_)); - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc index da86d78c9329d2..059105c0aaa39b 100644 --- a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc @@ -85,7 +85,7 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase { Status MakeSplitProviders(std::vector>* split_providers) const override { TF_ASSIGN_OR_RETURN(*split_providers, GetSplitProviders(this)); - return OkStatus(); + return absl::OkStatus(); } const DataTypeVector& output_dtypes() const override { @@ -117,7 +117,7 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase { for (const auto& data_input : data_inputs_) { inputs->push_back(data_input); } - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -151,7 +151,7 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase { /*attrs=*/ {std::make_pair(kStopOnEmptyDataset, stop_on_empty_dataset_attr)}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -178,7 +178,7 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase { strings::StrCat(prefix(), "[", i, "]"), &data_input_impls_[i])); ctx->MergeCheckpoint(input_contexts_[i + 1].checkpoint()); } - return OkStatus(); + return absl::OkStatus(); } Status GetNextInternal(IteratorContext* ctx, @@ -187,7 +187,7 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase { mutex_lock l(mu_); if (!selector_input_impl_) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } while (true) { @@ -198,7 +198,7 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase { ctx->MergeCheckpoint(input_contexts_[0].checkpoint()); if (*end_of_sequence) { ResetInputs(); - return OkStatus(); + return absl::OkStatus(); } int64_t selected_input = selector_result[0].scalar()(); @@ -216,7 +216,7 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase { ctx->MergeCheckpoint( input_contexts_[selected_input + 1].checkpoint()); if (!end_of_selected_input) { - return OkStatus(); + return absl::OkStatus(); } // End of selected input here. Do cleanup on checkpoints. @@ -225,7 +225,7 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase { if (dataset()->stop_on_empty_dataset_) { *end_of_sequence = true; ResetInputs(); - return OkStatus(); + return absl::OkStatus(); } data_input_impls_[selected_input].reset(); @@ -234,7 +234,7 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase { if (num_active_inputs_ == 0) { selector_input_impl_.reset(); *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } } @@ -269,7 +269,7 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(SaveInput(ctx, writer, data_input_impl)); } } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -293,7 +293,7 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase { data_input_impls_[i].reset(); } } - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op_test.cc index 4bf22292f2ac0d..e79c64d5750e72 100644 --- a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op_test.cc @@ -57,7 +57,7 @@ class DirectedInterleaveDatasetParams : public DatasetParams { input_names->emplace_back(absl::StrCat( DirectedInterleaveDatasetOp::kDataInputDatasets, "_", i)); } - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { @@ -70,7 +70,7 @@ class DirectedInterleaveDatasetParams : public DatasetParams { num_input_datasets_); attr_vector->emplace_back(DirectedInterleaveDatasetOp::kStopOnEmptyDataset, stop_on_empty_dataset_); - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { diff --git a/tensorflow/core/kernels/data/experimental/global_shuffle_dataset_op.cc b/tensorflow/core/kernels/data/experimental/global_shuffle_dataset_op.cc new file mode 100644 index 00000000000000..dcf76df269df4f --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/global_shuffle_dataset_op.cc @@ -0,0 +1,322 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" +#include "tensorflow/core/data/name_utils.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_requires.h" +#include "tensorflow/core/framework/resource_handle.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/data/random_seed_ops.h" +#include "tensorflow/core/kernels/random_index_shuffle.h" +#include "tsl/platform/errors.h" + +namespace tensorflow { +namespace data { +namespace { + +constexpr int32_t kIndexShuffleRounds = 8; + +constexpr const char kGlobalShuffleDataset[] = "GlobalShuffleDataset"; +constexpr const char kReshuffleEachIteration[] = "reshuffle_each_iteration"; +constexpr const char kSeed[] = "seed"; +constexpr const char kSeed2[] = "seed2"; +constexpr const char kSeedGenerator[] = "SeedGenerator"; + +class GlobalShuffleDatasetOp : public UnaryDatasetOpKernel { + public: + explicit GlobalShuffleDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; + + bool reshuffle_each_iteration_ = true; +}; + +class GlobalShuffleDatasetOp::Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, const DatasetBase* input, + SeedGeneratorManager* seed_generator, RandomSeeds&& input_seeds, + bool owns_resource, ResourceHandle&& resource_handle) + : DatasetBase(DatasetContext(ctx)), + input_(input), + seed_generator_(seed_generator), + input_seeds_(std::move(input_seeds)), + owns_resource_(owns_resource), + resource_handle_(std::move(resource_handle)), + resource_mgr_(ctx->resource_manager()) { + input_->Ref(); + } + + ~Dataset() override { + seed_generator_->Unref(); + if (owns_resource_) { + absl::Status s = resource_mgr_->Delete( + resource_handle_.container(), resource_handle_.name()); + if (!s.ok()) { + LOG(WARNING) << "Failed to delete random seed generator resource for " + << "tf.data global shuffle dataset: " << s; + } + } + input_->Unref(); + } + + const DataTypeVector& output_dtypes() const override { + return input_->output_dtypes(); + } + + const std::vector& output_shapes() const override { + return input_->output_shapes(); + } + + std::string DebugString() const override { + return name_utils::DatasetDebugString(kGlobalShuffleDataset); + } + + int64_t CardinalityInternal(CardinalityOptions options) const override { + return input_->Cardinality(options); + } + + absl::Status InputDatasets( + std::vector* inputs) const override { + inputs->push_back(input_); + return absl::OkStatus(); + } + + absl::Status CheckExternalState() const override { + return input_->CheckExternalState(); + } + + protected: + std::unique_ptr MakeIteratorInternal( + const std::string& prefix) const override; + + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + // Inputs + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); + Node* seed_node = nullptr; + Node* seed2_node = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(input_seeds_.input_seed(), &seed_node)); + TF_RETURN_IF_ERROR(b->AddScalar(input_seeds_.input_seed2(), &seed2_node)); + + Node* resource_handle_node = nullptr; + Tensor handle(DT_RESOURCE, TensorShape({})); + handle.scalar()() = resource_handle_; + TF_RETURN_IF_ERROR(b->AddTensor(handle, &resource_handle_node)); + + // Attrs + AttrValue reshuffle_each_iteration; + b->BuildAttrValue(seed_generator_->get()->reshuffle_each_iteration(), + &reshuffle_each_iteration); + return b->AddDataset( + this, /*inputs=*/ + {input_graph_node, seed_node, seed2_node, resource_handle_node}, + /*attrs=*/ + {std::make_pair(kReshuffleEachIteration, reshuffle_each_iteration)}, + output); + } + + private: + class Iterator; + + const DatasetBase* const input_; + SeedGeneratorManager* const seed_generator_; // Owned + const RandomSeeds input_seeds_; + const bool owns_resource_; + const ResourceHandle resource_handle_; + ResourceMgr* const resource_mgr_; // Not owned. +}; + +class GlobalShuffleDatasetOp::Dataset::Iterator + : public DatasetIterator { + public: + explicit Iterator(const Params& params, + std::shared_ptr seed_generator) + : DatasetIterator(params), + cardinality_(dataset()->Cardinality()), + seed_generator_(seed_generator) {} + + bool SymbolicCheckpointCompatible() const override { return true; } + + absl::Status Initialize(IteratorContext* ctx) override + ABSL_LOCKS_EXCLUDED(mu_) { + absl::MutexLock l(&mu_); + int64_t seed4; + seed_generator_->GenerateSeeds(&seed_, &seed2_); + seed_generator_->GenerateSeeds(&seed3_, &seed4); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_)); + return absl::OkStatus(); + } + + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override + ABSL_LOCKS_EXCLUDED(mu_) { + absl::MutexLock l(&mu_); + IteratorContext::Params params(ctx); + params.index_mapper = GetIndexMapper(ctx->index_mapper()); + IteratorContext global_shuffle_ctx(params); + TF_RETURN_IF_ERROR(input_impl_->GetNext(&global_shuffle_ctx, out_tensors, + end_of_sequence)); + ctx->MergeCheckpoint(global_shuffle_ctx.checkpoint()); + return absl::OkStatus(); + } + + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { + return absl::UnimplementedError( + "TODO(b/325112575): Support checkpoints for random access iterators."); + } + + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + return absl::UnimplementedError( + "TODO(b/325112575): Support checkpoints for random access iterators."); + } + + private: + std::function GetIndexMapper( + std::function parent_index_mapper) const + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + uint32_t seed = static_cast(seed_); + uint32_t seed2 = static_cast(seed2_); + uint32_t seed3 = static_cast(seed3_); + uint64_t max_index = + cardinality_ > 0 ? static_cast(cardinality_ - 1) : 0; + return [parent_index_mapper, seed, seed2, seed3, + max_index](int64_t element_position) { + if (parent_index_mapper != nullptr) { + element_position = parent_index_mapper(element_position); + } + return static_cast(tensorflow::random::index_shuffle( + static_cast(element_position), {seed, seed2, seed3}, + max_index, kIndexShuffleRounds)); + }; + } + + const int64_t cardinality_; + + mutable absl::Mutex mu_; + std::shared_ptr seed_generator_ ABSL_GUARDED_BY(mu_); + int64_t seed_ ABSL_GUARDED_BY(mu_) = 0; + int64_t seed2_ ABSL_GUARDED_BY(mu_) = 0; + int64_t seed3_ ABSL_GUARDED_BY(mu_) = 0; + std::unique_ptr input_impl_ ABSL_GUARDED_BY(mu_); +}; + +GlobalShuffleDatasetOp::GlobalShuffleDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) { + if (ctx->HasAttr(kReshuffleEachIteration)) { + OP_REQUIRES_OK( + ctx, ctx->GetAttr(kReshuffleEachIteration, &reshuffle_each_iteration_)); + } +} + +void GlobalShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, + DatasetBase* input, + DatasetBase** output) { + OP_REQUIRES(ctx, input->RandomIndexingCompatible().ok(), + absl::FailedPreconditionError(absl::StrCat( + "`global_shuffle` requires all upstream transformations be " + "compatible with random access. Got: ", + input->RandomIndexingCompatible().ToString()))); + + int64_t cardinality = input->Cardinality(); + OP_REQUIRES(ctx, cardinality > 0, + absl::InvalidArgumentError(absl::StrCat( + "`global_shuffle` requires the input dataset to have a " + "non-empty finite cardinality. Got cardinality ", + cardinality, " for dataset ", input->DebugString()))); + + int64_t seed, seed2; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kSeed, &seed)); + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kSeed2, &seed2)); + RandomSeeds input_seeds(seed, seed2); + + static std::atomic resource_id_counter(0); + const std::string& container = ctx->resource_manager()->default_container(); + std::string name = absl::StrCat(ctx->op_kernel().name(), "/", kSeedGenerator, + "_", resource_id_counter.fetch_add(1)); + + auto handle = HandleFromInput(ctx, 3); + SeedGeneratorManager* seed_generator = nullptr; + absl::Status s = ctx->resource_manager()->Lookup( + handle.container(), handle.name(), &seed_generator); + + bool owns_resource = false; + if (absl::IsNotFound(s)) { + owns_resource = true; + OP_REQUIRES_OK( + ctx, ctx->resource_manager()->LookupOrCreate( + container, name, &seed_generator, + [reshuffle = reshuffle_each_iteration_, + &input_seeds](SeedGeneratorManager** seed_generator) { + if (reshuffle) { + *seed_generator = new SeedGeneratorManager( + new RandomSeedGenerator(input_seeds)); + } else { + *seed_generator = new SeedGeneratorManager( + new FixedSeedGenerator(input_seeds)); + } + return absl::OkStatus(); + })); + handle = MakeResourceHandle(ctx, container, name); + } else { + OP_REQUIRES_OK(ctx, s); + } + + *output = new Dataset(ctx, input, seed_generator, std::move(input_seeds), + owns_resource, std::move(handle)); +} + +std::unique_ptr +GlobalShuffleDatasetOp::Dataset::MakeIteratorInternal( + const std::string& prefix) const { + return std::make_unique( + Iterator::Params{ + this, name_utils::IteratorPrefix(kGlobalShuffleDataset, prefix)}, + seed_generator_->get()); +} + +REGISTER_KERNEL_BUILDER(Name(kGlobalShuffleDataset).Device(DEVICE_CPU), + GlobalShuffleDatasetOp); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/experimental/group_by_reducer_dataset_op.cc index 2fbdc124c5f907..059719f214f3f0 100644 --- a/tensorflow/core/kernels/data/experimental/group_by_reducer_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/group_by_reducer_dataset_op.cc @@ -191,7 +191,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { {"Tfinalize_func_other_arguments", finalize_func_other_arguments_types_attr}}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -211,7 +211,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { ctx, &instantiated_reduce_func_)); TF_RETURN_IF_ERROR(dataset()->captured_finalize_func_->Instantiate( ctx, &instantiated_finalize_func_)); - return OkStatus(); + return absl::OkStatus(); } Status GetNextInternal(IteratorContext* ctx, @@ -272,12 +272,12 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { if (keys_index_ == keys_.size()) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR(instantiated_finalize_func_->RunWithBorrowedArgs( ctx, states_[keys_[keys_index_++]], out_tensors, model_node())); *end_of_sequence = false; - return OkStatus(); + return absl::OkStatus(); } protected: @@ -341,7 +341,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { } } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -398,7 +398,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { } } - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc index 05035da404cc99..508a281f43f49d 100644 --- a/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc @@ -118,7 +118,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -184,7 +184,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { {"Twindow_size_func_other_arguments", window_size_func_other_arguments_types_attr}}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -202,7 +202,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { ctx, &instantiated_reduce_func_)); TF_RETURN_IF_ERROR(dataset()->captured_window_size_func_->Instantiate( ctx, &instantiated_window_size_func_)); - return OkStatus(); + return absl::OkStatus(); } Status GetNextInternal(IteratorContext* ctx, @@ -219,7 +219,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { if (!end_of_group) { // Produce the subelement as output. *end_of_sequence = false; - return OkStatus(); + return absl::OkStatus(); } // We have reached the end of the current group, so maybe move on // to the next group. @@ -301,7 +301,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { } while (current_group_iterator_ || !end_of_input_); *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } protected: @@ -371,7 +371,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { } TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("group_counter"), group_counter_ - 1)); - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -431,7 +431,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR( RestoreInput(ctx, reader, current_group_iterator_)); } - return OkStatus(); + return absl::OkStatus(); } private: @@ -448,7 +448,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { strings::StrCat(name, "[", i, "][", j, "]"), group[i][j])); } } - return OkStatus(); + return absl::OkStatus(); } Status RestoreGroup(IteratorContext* ctx, IteratorStateReader* reader, @@ -470,7 +470,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { &group->at(i)[j])); } } - return OkStatus(); + return absl::OkStatus(); } Status StartFlushingGroup(IteratorContext* ctx, int64_t key) diff --git a/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc b/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc index ae420ee85ba8cd..5d0c9ff554b531 100644 --- a/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc @@ -71,7 +71,7 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -89,7 +89,7 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR( b->AddDataset(this, {std::make_pair(0, input_graph_node)}, {}, {{"log_warning", log_warning_attr}}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -111,7 +111,7 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { tf_shared_lock l(mu_); if (!input_impl_) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence); while (!s.ok() && !errors::IsCancelled(s)) { @@ -144,7 +144,7 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { else TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("input_impls_empty"), "")); - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -154,7 +154,7 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { input_impl_.reset(); else TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/data/experimental/list_dataset_op.cc b/tensorflow/core/kernels/data/experimental/list_dataset_op.cc index 94f788a2812816..8678a687fa3637 100644 --- a/tensorflow/core/kernels/data/experimental/list_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/list_dataset_op.cc @@ -64,7 +64,7 @@ class ListDatasetOp::Dataset : public DatasetBase { split_providers) const override { split_providers->push_back( std::make_unique(num_elements_)); - return OkStatus(); + return absl::OkStatus(); } const DataTypeVector& output_dtypes() const override { return output_types_; } @@ -82,10 +82,10 @@ class ListDatasetOp::Dataset : public DatasetBase { } Status InputDatasets(std::vector* inputs) const override { - return OkStatus(); + return absl::OkStatus(); } - Status CheckExternalState() const override { return OkStatus(); } + Status CheckExternalState() const override { return absl::OkStatus(); } Status Get(OpKernelContext* ctx, int64 index, std::vector* out_tensors) const override { @@ -95,7 +95,7 @@ class ListDatasetOp::Dataset : public DatasetBase { for (int i = 0; i < num_components_; ++i) { out_tensors->push_back(tensors_[i + num_components_ * index]); } - return OkStatus(); + return absl::OkStatus(); } protected: @@ -119,7 +119,7 @@ class ListDatasetOp::Dataset : public DatasetBase { b->BuildAttrValue(input_types_, &input_types); TF_RETURN_IF_ERROR(b->AddDataset(this, {}, {{0, tensors}}, {{kTinputTypes, input_types}}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -138,7 +138,7 @@ class ListDatasetOp::Dataset : public DatasetBase { TF_ASSIGN_OR_RETURN(split_provider_, GetSingleSplitProvider(ctx, dataset())); } - return OkStatus(); + return absl::OkStatus(); } Status GetNextInternal(IteratorContext* ctx, @@ -147,7 +147,7 @@ class ListDatasetOp::Dataset : public DatasetBase { Tensor split; TF_RETURN_IF_ERROR(split_provider_->GetNext(&split, end_of_sequence)); if (*end_of_sequence) { - return OkStatus(); + return absl::OkStatus(); } int64_t index = split.scalar()(); out_tensors->reserve(dataset()->num_components_); @@ -156,7 +156,7 @@ class ListDatasetOp::Dataset : public DatasetBase { dataset()->tensors_[i + dataset()->num_components_ * index]); } *end_of_sequence = false; - return OkStatus(); + return absl::OkStatus(); } protected: diff --git a/tensorflow/core/kernels/data/experimental/list_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/list_dataset_op_test.cc index 524302e0a332d9..86f4b00386bab8 100644 --- a/tensorflow/core/kernels/data/experimental/list_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/experimental/list_dataset_op_test.cc @@ -52,7 +52,7 @@ class ListDatasetParams : public DatasetParams { for (int i = 0; i < tensors_.size(); ++i) { input_names->emplace_back(absl::StrCat("tensors_", i)); } - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { @@ -60,7 +60,7 @@ class ListDatasetParams : public DatasetParams { {"output_types", output_dtypes_}, {"output_shapes", output_shapes_}, {"metadata", ""}}; - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { return "List"; } diff --git a/tensorflow/core/kernels/data/experimental/load_dataset_op.cc b/tensorflow/core/kernels/data/experimental/load_dataset_op.cc index c5dc060f44e9ee..a28743164d43a5 100644 --- a/tensorflow/core/kernels/data/experimental/load_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/load_dataset_op.cc @@ -87,7 +87,7 @@ class LoadDatasetOp::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->clear(); - return OkStatus(); + return absl::OkStatus(); } protected: @@ -122,7 +122,7 @@ class LoadDatasetOp::Dataset : public DatasetBase { std::make_pair(kReaderFuncTarguments, reader_func_arguments_types_attr)}, // Attrs output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -210,7 +210,7 @@ class LoadDatasetOp::Dataset : public DatasetBase { // We need to take a reference here as we will use the input_ and // its iterator. input_->Ref(); - return OkStatus(); + return absl::OkStatus(); } mutex mu_; diff --git a/tensorflow/core/kernels/data/experimental/lookup_ops.cc b/tensorflow/core/kernels/data/experimental/lookup_ops.cc index 91d6b1dfe03266..7fff4076d71bd8 100644 --- a/tensorflow/core/kernels/data/experimental/lookup_ops.cc +++ b/tensorflow/core/kernels/data/experimental/lookup_ops.cc @@ -70,7 +70,7 @@ class DatasetIterator iterator_ctx_.get(), nullptr, "LookupTable", &iterator_)); core::ScopedUnref unref(finalized_dataset); Next(); - return OkStatus(); + return absl::OkStatus(); } void Next() override { @@ -132,7 +132,7 @@ std::unique_ptr MakeDatasetInitializerSerializer( "Failed to create InitializeTableFromDataset op: ", builder->opts().StatusToString()); } - return OkStatus(); + return absl::OkStatus(); }, /*cleanup=*/std::move(unref_dataset)); } diff --git a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc index 88558fab26cc60..80b208a8f3e2d6 100644 --- a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc @@ -138,7 +138,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -182,7 +182,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { std::make_pair(kPreserveCardinality, preserve_cardinality_attr)}, // Attrs output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -239,7 +239,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { if (ctx->warm_start() && !ctx->is_restoring()) { EnsureThreadsStarted(ctx); } - return OkStatus(); + return absl::OkStatus(); } Status GetNextInternal(IteratorContext* ctx, @@ -279,7 +279,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { ProcessBatch(dataset()->batch_size_, result->num_elements, dataset()->drop_remainder_, result->status, ctx, out_tensors, end_of_sequence, &result->output)); - return OkStatus(); + return absl::OkStatus(); } protected: @@ -298,7 +298,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { if (ctx->symbolic_checkpoint()) { TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kCallCounter, 0)); TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kBatchResultsSize, 0)); - return OkStatus(); + return absl::OkStatus(); } mutex_lock l(*mu_); // Wait for all in-flight calls to complete. @@ -314,7 +314,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { for (size_t i = 0; i < batch_results_.size(); ++i) { TF_RETURN_IF_ERROR(WriteBatchResult(writer, i)); } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -334,7 +334,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { if (ctx->warm_start()) { EnsureThreadsStarted(ctx); } - return OkStatus(); + return absl::OkStatus(); } TraceMeMetadata GetTraceMeMetadata() const override { @@ -370,7 +370,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { : end_of_input(false), num_elements(0), output_allocated(false), - status(OkStatus()), + status(absl::OkStatus()), status_offset(-1), num_calls(batch_size), checkpoint(MemoryCheckpoint{ctx->id_registry()}), @@ -534,7 +534,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { const std::shared_ptr>& return_values) { mutex_lock l(result->mu); if (result->output_allocated) { - return OkStatus(); + return absl::OkStatus(); } const size_t num_components = return_values->size(); result->output.reserve(num_components); @@ -553,7 +553,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { } RecordBufferEnqueue(ctx.get(), result->output); result->output_allocated = true; - return OkStatus(); + return absl::OkStatus(); } void RunnerThread(const std::shared_ptr& ctx) @@ -647,7 +647,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { if (result->output_allocated) { RecordBufferEnqueue(ctx, result->output); } - return OkStatus(); + return absl::OkStatus(); } Status WriteBatchResult(IteratorStateWriter* writer, size_t index) @@ -677,7 +677,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR( WriteStatus(prefix(), strings::StrCat(batch_prefix, "_", kStatus), result->status, writer)); - return OkStatus(); + return absl::OkStatus(); } // Used for coordination between the main thread, the runner thread, and diff --git a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op_test.cc index 96507577f2abd9..a340bd20758a48 100644 --- a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op_test.cc @@ -67,7 +67,7 @@ class MapAndBatchDatasetParams : public DatasetParams { input_names->emplace_back(MapAndBatchDatasetOp::kNumParallelCalls); input_names->emplace_back(MapAndBatchDatasetOp::kDropRemainder); - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { @@ -77,7 +77,7 @@ class MapAndBatchDatasetParams : public DatasetParams { {"output_types", output_dtypes_}, {"preserve_cardinality", preserve_cardinality_}, {"metadata", ""}}; - return OkStatus(); + return absl::OkStatus(); } std::vector func_lib() const override { return func_lib_; } diff --git a/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc b/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc index 90b1a041fb9a21..403439b1604277 100644 --- a/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc @@ -84,10 +84,10 @@ class MatchingFilesDatasetOp : public DatasetOpKernel { Status InputDatasets( std::vector* inputs) const override { - return OkStatus(); + return absl::OkStatus(); } - Status CheckExternalState() const override { return OkStatus(); } + Status CheckExternalState() const override { return absl::OkStatus(); } protected: Status AsGraphDefInternal(SerializationContext* ctx, @@ -96,7 +96,7 @@ class MatchingFilesDatasetOp : public DatasetOpKernel { Node* patterns_node = nullptr; TF_RETURN_IF_ERROR(b->AddVector(patterns_, &patterns_node)); TF_RETURN_IF_ERROR(b->AddDataset(this, {patterns_node}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -139,7 +139,7 @@ class MatchingFilesDatasetOp : public DatasetOpKernel { out_tensors->emplace_back(std::move(filepath_tensor)); *end_of_sequence = false; hasMatch_ = true; - return OkStatus(); + return absl::OkStatus(); } // In this case, current_path is a directory. Then continue the @@ -185,7 +185,7 @@ class MatchingFilesDatasetOp : public DatasetOpKernel { *end_of_sequence = true; if (hasMatch_) { - return OkStatus(); + return absl::OkStatus(); } else { return errors::NotFound("Don't find any matched files"); } @@ -226,7 +226,7 @@ class MatchingFilesDatasetOp : public DatasetOpKernel { } } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -268,7 +268,7 @@ class MatchingFilesDatasetOp : public DatasetOpKernel { } } - return OkStatus(); + return absl::OkStatus(); } private: @@ -289,7 +289,7 @@ class MatchingFilesDatasetOp : public DatasetOpKernel { // All the files in the heap are matched with the pattern, so finish // the search if current_path is a file. if (!current_path.second) { - return OkStatus(); + return absl::OkStatus(); } filepath_queue_.pop(); diff --git a/tensorflow/core/kernels/data/experimental/non_serializable_dataset_op.cc b/tensorflow/core/kernels/data/experimental/non_serializable_dataset_op.cc index 237dd5ca4cae29..1649bb7d54a93b 100644 --- a/tensorflow/core/kernels/data/experimental/non_serializable_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/non_serializable_dataset_op.cc @@ -72,7 +72,7 @@ class NonSerializableDatasetOp : public UnaryDatasetOpKernel { Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -117,13 +117,13 @@ class NonSerializableDatasetOp : public UnaryDatasetOpKernel { Status SaveInternal(SerializationContext* ctx, IteratorStateWriter* writer) override { TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc index 162076e4ae07ac..9facb7f4fb52ea 100644 --- a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc @@ -151,7 +151,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -219,7 +219,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { attrs.emplace_back(kTarguments, other_arguments_types_attr); TF_RETURN_IF_ERROR(b->AddDataset(this, inputs, list_inputs, attrs, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -392,7 +392,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { if (!can_produce_elements && !input_impl_) { // No potential for future values. *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } if (must_wait_for_input) { @@ -462,7 +462,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR( writer->WriteScalar(prefix(), kWorkerThreadsRunning, "")); } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -494,7 +494,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { } std::unique_ptr threadpool = ctx->CreateThreadPool( "read_worker_thread_state", dataset()->num_threads()); - Status s = OkStatus(); + Status s = absl::OkStatus(); BlockingCounter counter(dataset()->num_threads()); for (size_t i = 0; i < dataset()->num_threads(); ++i) { threadpool->Schedule([this, i, ctx, reader, &s, &counter] { @@ -572,7 +572,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { [this, new_ctx, i]() { WorkerThread(new_ctx, i); })); } } - return OkStatus(); + return absl::OkStatus(); } TraceMeMetadata GetTraceMeMetadata() const override { @@ -655,7 +655,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { std::unique_ptr iterator; - WorkerThreadState() : output_elem(OkStatus()) {} + WorkerThreadState() : output_elem(absl::OkStatus()) {} }; void CancelThreads() TF_LOCKS_EXCLUDED(mu_) { @@ -677,7 +677,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { Status s = input_impl_->GetNext(ctx, &args, &end_of_input); if (end_of_input) { input_impl_.reset(); - return OkStatus(); + return absl::OkStatus(); } if (i < dataset()->cycle_length_) { interleave_indices_.push_back(i); @@ -693,7 +693,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { DCHECK(interleave_indices_.size() == dataset()->cycle_length_); DCHECK(staging_indices_.size() == dataset()->prefetch_input_elements_); } - return OkStatus(); + return absl::OkStatus(); } // Produces elements into the worker's output buffers. @@ -837,7 +837,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { workers_[thread_index].outputs.emplace_back(iterator_creation_status); workers_[thread_index].is_producing = false; worker_thread_states_[thread_index].iterator_creation_status = - OkStatus(); + absl::OkStatus(); // CHECKPOINT_MARKER_C // Non-OK iterator creation status has been notified to the // client. @@ -929,7 +929,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { worker_thread_states_[thread_index].output_elem.output); } worker_thread_states_[thread_index].output_elem.status = - OkStatus(); + absl::OkStatus(); if (deterministic_) { workers_[thread_index].cond_var.notify_one(); } else { @@ -966,7 +966,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR( writer->WriteScalar(iterator_name, kIsProducing, "")); } - return OkStatus(); + return absl::OkStatus(); } Status ReadWorkerStateLocked(IteratorContext* ctx, @@ -989,7 +989,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR( reader->ReadScalar(worker_prefix, kOutputsSize, &outputs_size)); for (int i = 0; i < outputs_size; ++i) { - workers_[index].outputs.emplace_back(OkStatus()); + workers_[index].outputs.emplace_back(absl::OkStatus()); TF_RETURN_IF_ERROR(ReadOutputElemLocked( ctx, reader, &workers_[index].outputs.back(), worker_prefix, strings::StrCat(kOutputs, "_", i))); @@ -999,7 +999,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { } else { workers_[index].is_producing = false; } - return OkStatus(); + return absl::OkStatus(); } Status WriteWorkerThreadStateLocked(SerializationContext* ctx, @@ -1032,7 +1032,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR( writer->WriteScalar(iterator_name, kEndOfSequence, "")); } - return OkStatus(); + return absl::OkStatus(); } Status ReadWorkerThreadStateLocked(IteratorContext* ctx, @@ -1073,7 +1073,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { } else { state->end_of_sequence = false; } - return OkStatus(); + return absl::OkStatus(); } Status WriteOutputElemLocked(IteratorStateWriter* writer, @@ -1092,7 +1092,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { iterator_name, strings::StrCat(prefix, "_", kOutput, "_", i), output_elem.output[i])); } - return OkStatus(); + return absl::OkStatus(); } Status ReadOutputElemLocked(IteratorContext* ctx, @@ -1115,7 +1115,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { strings::StrCat(prefix, "_", kOutput, "_", i), &output_elem->output.back())); } - return OkStatus(); + return absl::OkStatus(); } Status WriteStatusLocked(IteratorStateWriter* writer, @@ -1130,7 +1130,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { iterator_name, strings::StrCat(prefix, "_", KMessage), std::string(status.message()))); } - return OkStatus(); + return absl::OkStatus(); } Status ReadStatusLocked(IteratorStateReader* reader, @@ -1148,9 +1148,9 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { &error_message)); *status = Status(code, error_message); } else { - *status = OkStatus(); + *status = absl::OkStatus(); } - return OkStatus(); + return absl::OkStatus(); } // Mutex & condition variable to guard mutable iterator internals and diff --git a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op_test.cc index 8cd0d9a5019ea1..1ba4c37e0513a8 100644 --- a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op_test.cc @@ -78,7 +78,7 @@ class ParallelInterleaveDatasetParams : public DatasetParams { ParallelInterleaveDatasetOp::kBufferOutputElements); input_names->emplace_back( ParallelInterleaveDatasetOp::kPrefetchInputElements); - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { @@ -88,7 +88,7 @@ class ParallelInterleaveDatasetParams : public DatasetParams { {"output_shapes", output_shapes_}, {"output_types", output_dtypes_}, {"metadata", ""}}; - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { diff --git a/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc index ec90ead5920270..9ca3a2bf9aa059 100644 --- a/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc @@ -276,7 +276,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -357,7 +357,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { }, {{2, dense_defaults_nodes}}, attrs, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -462,7 +462,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { "")); } } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -507,7 +507,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { RecordBufferEnqueue(ctx, result.return_values); result.notification.Notify(); } - return OkStatus(); + return absl::OkStatus(); } TraceMeMetadata GetTraceMeMetadata() const override { @@ -643,7 +643,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { dataset()->output_shapes()[output_index].DebugString(), ", got ", tensor.shape().DebugString(), ")."); } - return OkStatus(); + return absl::OkStatus(); } Status ParseExample(IteratorContext* ctx, std::vector input, @@ -722,7 +722,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { steps); } } - return OkStatus(); + return absl::OkStatus(); } Status ProcessResult(IteratorContext* ctx, @@ -733,7 +733,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { *out_tensors = std::move(result->return_values); RecordBufferDequeue(ctx, *out_tensors); *end_of_sequence = false; - return OkStatus(); + return absl::OkStatus(); } if (errors::IsOutOfRange(result->status)) { // To guarantee that the transformation preserves the cardinality of @@ -858,7 +858,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR(writer->WriteScalar( ErrorMessageKey(index), std::string(status.message()))); } - return OkStatus(); + return absl::OkStatus(); } Status ReadStatusLocked(IteratorStateReader* reader, size_t index, @@ -874,9 +874,9 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { reader->ReadScalar(ErrorMessageKey(index), &error_message)); *status = Status(code, error_message); } else { - *status = OkStatus(); + *status = absl::OkStatus(); } - return OkStatus(); + return absl::OkStatus(); } string CodeKey(size_t index) { diff --git a/tensorflow/core/kernels/data/experimental/random_access_ops.cc b/tensorflow/core/kernels/data/experimental/random_access_ops.cc index a868e975732b38..b4c26b8136ed2c 100644 --- a/tensorflow/core/kernels/data/experimental/random_access_ops.cc +++ b/tensorflow/core/kernels/data/experimental/random_access_ops.cc @@ -46,7 +46,7 @@ Status GetElementAtIndexOp::DoCompute(OpKernelContext* ctx) { for (int i = 0; i < components.size(); ++i) { ctx->set_output(i, components[i]); } - return OkStatus(); + return absl::OkStatus(); } namespace { diff --git a/tensorflow/core/kernels/data/experimental/random_dataset_op.cc b/tensorflow/core/kernels/data/experimental/random_dataset_op.cc index 1f1163839ff968..c83febdeac422b 100644 --- a/tensorflow/core/kernels/data/experimental/random_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/random_dataset_op.cc @@ -85,7 +85,7 @@ class RandomDatasetOp::Dataset : public DatasetBase { // TODO(aaudibert): Avoid sending dummy splits over RPC when using tf.data // service with RandomDataset. split_providers->push_back(std::make_unique(kint64max)); - return OkStatus(); + return absl::OkStatus(); } std::unique_ptr MakeIteratorInternal( @@ -119,10 +119,10 @@ class RandomDatasetOp::Dataset : public DatasetBase { } Status InputDatasets(std::vector* inputs) const override { - return OkStatus(); + return absl::OkStatus(); } - Status CheckExternalState() const override { return OkStatus(); } + Status CheckExternalState() const override { return absl::OkStatus(); } protected: Status AsGraphDefInternal(SerializationContext* ctx, @@ -163,7 +163,7 @@ class RandomDatasetOp::Dataset : public DatasetBase { mutex_lock l(mu_); seed_generator_->GenerateSeeds(&seed_, &seed2_); ResetRngs(); - return OkStatus(); + return absl::OkStatus(); } Status GetNextInternal(IteratorContext* ctx, @@ -174,7 +174,7 @@ class RandomDatasetOp::Dataset : public DatasetBase { out_tensors->emplace_back(ctx->allocator({}), DT_INT64, TensorShape({})); out_tensors->back().scalar()() = Random(); *end_of_sequence = false; - return OkStatus(); + return absl::OkStatus(); } std::shared_ptr CreateNode( @@ -193,7 +193,7 @@ class RandomDatasetOp::Dataset : public DatasetBase { num_random_samples_)); TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name(kSeed), seed_)); TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name(kSeed2), seed2_)); - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -210,7 +210,7 @@ class RandomDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name(kSeed), &seed_)); TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name(kSeed2), &seed2_)); ResetRngs(); - return OkStatus(); + return absl::OkStatus(); } protected: @@ -304,7 +304,7 @@ void RandomDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase** output) { *manager = new SeedGeneratorManager(new FixedSeedGenerator(seeds)); } - return OkStatus(); + return absl::OkStatus(); })); handle = MakeResourceHandle(ctx, container, name); } diff --git a/tensorflow/core/kernels/data/experimental/random_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/random_dataset_op_test.cc index f93cf82dbeefc7..35bd077fb8bd8f 100644 --- a/tensorflow/core/kernels/data/experimental/random_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/experimental/random_dataset_op_test.cc @@ -104,7 +104,7 @@ class RandomDatasetParams : public DatasetParams { if (op_version_ == 2) { input_names->emplace_back("seed_generator"); } - return OkStatus(); + return absl::OkStatus(); } virtual Status GetAttributes(AttributeVector* attributes) const override { @@ -115,7 +115,7 @@ class RandomDatasetParams : public DatasetParams { attributes->emplace_back("rerandomize_each_iteration", rerandomize_each_iteration_); } - return OkStatus(); + return absl::OkStatus(); } virtual string dataset_type() const override { diff --git a/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc index 220e78e6572da9..e6c02d226756ed 100644 --- a/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc @@ -97,7 +97,7 @@ class RebatchDatasetOp : public UnaryDatasetOpKernel { Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -114,7 +114,7 @@ class RebatchDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR(b->AddScalar(num_replicas_, &num_replicas)); TF_RETURN_IF_ERROR( b->AddDataset(this, {input_graph_node, num_replicas}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -141,7 +141,7 @@ class RebatchDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR( input_impl_->GetNext(ctx, &input_tensors, end_of_sequence)); if (*end_of_sequence) { - return OkStatus(); + return absl::OkStatus(); } input_descriptors_.reserve(input_tensors.size()); @@ -185,7 +185,7 @@ class RebatchDatasetOp : public UnaryDatasetOpKernel { } } slice_number_ = (slice_number_ + 1) % dataset()->num_replicas_; - return OkStatus(); + return absl::OkStatus(); } protected: @@ -209,7 +209,7 @@ class RebatchDatasetOp : public UnaryDatasetOpKernel { input_descriptors_[i].whole_tensor)); } } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -237,7 +237,7 @@ class RebatchDatasetOp : public UnaryDatasetOpKernel { dataset()->num_replicas_); } } - return OkStatus(); + return absl::OkStatus(); } TraceMeMetadata GetTraceMeMetadata() const override { @@ -358,7 +358,7 @@ class RebatchDatasetV2Op : public UnaryDatasetOpKernel { Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -377,7 +377,7 @@ class RebatchDatasetV2Op : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder)); TF_RETURN_IF_ERROR(b->AddDataset( this, {input_graph_node, batch_sizes, drop_remainder}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -399,7 +399,7 @@ class RebatchDatasetV2Op : public UnaryDatasetOpKernel { mutex_lock l(mu_); if (end_of_sequence_) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } *end_of_sequence = false; @@ -455,7 +455,7 @@ class RebatchDatasetV2Op : public UnaryDatasetOpKernel { (dataset()->drop_remainder_ && batch_size < desired_batch_size)) { DCHECK(end_of_sequence_); *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } const size_t num_components = dataset()->output_dtypes().size(); @@ -490,7 +490,7 @@ class RebatchDatasetV2Op : public UnaryDatasetOpKernel { Tensor(dataset()->output_dtypes()[i], tensor_shape)); } } - return OkStatus(); + return absl::OkStatus(); } // Special case: when there's only one slice, we return the slice @@ -507,7 +507,7 @@ class RebatchDatasetV2Op : public UnaryDatasetOpKernel { tensors.push_back(std::move(tensor)); } *out_tensors = std::move(tensors); - return OkStatus(); + return absl::OkStatus(); } // For each component, concatenate slices into one tensor. @@ -567,7 +567,7 @@ class RebatchDatasetV2Op : public UnaryDatasetOpKernel { } } - return OkStatus(); + return absl::OkStatus(); } protected: @@ -589,7 +589,7 @@ class RebatchDatasetV2Op : public UnaryDatasetOpKernel { full_name(strings::StrCat("tensors[", i, "]")), tensors_[i])); } } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -613,7 +613,7 @@ class RebatchDatasetV2Op : public UnaryDatasetOpKernel { &tensors_[i])); } } - return OkStatus(); + return absl::OkStatus(); } TraceMeMetadata GetTraceMeMetadata() const override { @@ -636,7 +636,7 @@ class RebatchDatasetV2Op : public UnaryDatasetOpKernel { tensors_[i].dim_size(0), "."); } } - return OkStatus(); + return absl::OkStatus(); } class TensorSlice { diff --git a/tensorflow/core/kernels/data/experimental/sampling_dataset_op.cc b/tensorflow/core/kernels/data/experimental/sampling_dataset_op.cc index fd01c81061f227..ca84cf49ee33f6 100644 --- a/tensorflow/core/kernels/data/experimental/sampling_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/sampling_dataset_op.cc @@ -72,7 +72,7 @@ class SamplingDatasetOp::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -93,7 +93,7 @@ class SamplingDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(b->AddScalar(seeds_.second, &seed2)); TF_RETURN_IF_ERROR( b->AddDataset(this, {input_graph_node, rate, seed, seed2}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -118,7 +118,7 @@ class SamplingDatasetOp::Dataset : public DatasetBase { tf_shared_lock l(mu_); if (!input_impl_) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR( input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); @@ -126,7 +126,7 @@ class SamplingDatasetOp::Dataset : public DatasetBase { if (*end_of_sequence) { mutex_lock l(mu_); input_impl_.reset(); - return OkStatus(); + return absl::OkStatus(); } // generate a number from random uniform [0, 1) @@ -138,7 +138,7 @@ class SamplingDatasetOp::Dataset : public DatasetBase { } } while (!rand_val_hit); *end_of_sequence = false; - return OkStatus(); + return absl::OkStatus(); } protected: @@ -167,7 +167,7 @@ class SamplingDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("input_impl_empty"), "")); } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -188,7 +188,7 @@ class SamplingDatasetOp::Dataset : public DatasetBase { } else { input_impl_.reset(); } - return OkStatus(); + return absl::OkStatus(); } mutex mu_; diff --git a/tensorflow/core/kernels/data/experimental/sampling_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/sampling_dataset_op_test.cc index 87bcf12b1925d4..3c0f3dac7b9bc9 100644 --- a/tensorflow/core/kernels/data/experimental/sampling_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/experimental/sampling_dataset_op_test.cc @@ -50,13 +50,13 @@ class SamplingDatasetParams : public DatasetParams { *input_names = {SamplingDatasetOp::kInputDataset, SamplingDatasetOp::kRate, SamplingDatasetOp::kSeed, SamplingDatasetOp::kSeed2}; - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { *attr_vector = {{SamplingDatasetOp::kOutputTypes, output_dtypes_}, {SamplingDatasetOp::kOutputShapes, output_shapes_}}; - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { diff --git a/tensorflow/core/kernels/data/experimental/save_dataset_op.cc b/tensorflow/core/kernels/data/experimental/save_dataset_op.cc index 989bfceb6f63d9..0110618143c85f 100644 --- a/tensorflow/core/kernels/data/experimental/save_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/save_dataset_op.cc @@ -92,7 +92,7 @@ Status SaveDatasetOp::DoCompute(OpKernelContext* ctx) { TF_RETURN_IF_ERROR(WriteMetadataFile(ctx->env(), path, run_id, dataset->output_dtypes(), num_elements, /*finalized=*/true)); - return OkStatus(); + return absl::OkStatus(); } Status SaveDatasetOp::WriteData(OpKernelContext* ctx, DatasetBase* dataset, @@ -173,7 +173,7 @@ Status SaveDatasetOp::GetShardIndex(IteratorContext* ctx, int64_t* shard_index) { if (!use_shard_func_) { *shard_index = (*shard_index + 1) % GetCpuBudget(); - return OkStatus(); + return absl::OkStatus(); } std::vector output_tensors; TF_RETURN_IF_ERROR(function->RunWithBorrowedArgs( @@ -184,7 +184,7 @@ Status SaveDatasetOp::GetShardIndex(IteratorContext* ctx, return errors::InvalidArgument("`shard_func` must return a scalar int64."); } *shard_index = output_tensors[0].flat()(0); - return OkStatus(); + return absl::OkStatus(); } Status SaveDatasetOp::WriteMetadataFile(Env* env, const std::string& path, @@ -244,7 +244,7 @@ class SaveDatasetV2Op::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -296,7 +296,7 @@ class SaveDatasetV2Op::Dataset : public DatasetBase { std::make_pair(kShardFuncTarguments, shard_func_arguments_types_attr)}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -396,7 +396,7 @@ class SaveDatasetV2Op::Dataset : public DatasetBase { } current_writer->Write(*out_tensors); - return OkStatus(); + return absl::OkStatus(); } protected: @@ -447,7 +447,7 @@ class SaveDatasetV2Op::Dataset : public DatasetBase { TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (!use_shard_func) { *shard_index = (*shard_index + 1) % GetCpuBudget(); - return OkStatus(); + return absl::OkStatus(); } std::vector output_tensors; TF_RETURN_IF_ERROR(function->RunWithBorrowedArgs( @@ -459,7 +459,7 @@ class SaveDatasetV2Op::Dataset : public DatasetBase { "`shard_func` must return a scalar int64."); } *shard_index = output_tensors[0].flat()(0); - return OkStatus(); + return absl::OkStatus(); } Status WriteMetadataFile(Env* env, const std::string& path, uint64 run_id, diff --git a/tensorflow/core/kernels/data/experimental/save_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/save_dataset_op_test.cc index 5ae5757eb799ee..f0b7745a46e93d 100644 --- a/tensorflow/core/kernels/data/experimental/save_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/experimental/save_dataset_op_test.cc @@ -63,7 +63,7 @@ class SaveDatasetV2Params : public DatasetParams { input_names->clear(); input_names->emplace_back(SaveDatasetV2Op::kInputDataset); input_names->emplace_back(SaveDatasetV2Op::kPath); - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { @@ -75,7 +75,7 @@ class SaveDatasetV2Params : public DatasetParams { type_arguments_); attr_vector->emplace_back(SaveDatasetV2Op::kOutputTypes, output_dtypes_); attr_vector->emplace_back(SaveDatasetV2Op::kOutputShapes, output_shapes_); - return OkStatus(); + return absl::OkStatus(); } string path() const { return path_; } @@ -101,7 +101,7 @@ class SaveDatasetV2OpTest : public DatasetOpsTestBase { TF_RETURN_IF_ERROR(DatasetOpsTestBase::Initialize(dataset_params)); auto params = static_cast(dataset_params); save_filename_ = params.path(); - return OkStatus(); + return absl::OkStatus(); } protected: diff --git a/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc b/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc index 7e937c682343a9..a4aa38277870ee 100644 --- a/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc @@ -117,7 +117,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -161,7 +161,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { {"preserve_cardinality", preserve_cardinality_attr}, {"use_default_device", use_default_device_attr}}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -189,7 +189,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR( input_impl_->GetNext(ctx, &next_element, end_of_sequence)); if (*end_of_sequence) { - return OkStatus(); + return absl::OkStatus(); } std::vector args; @@ -248,7 +248,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { // `f` may deliberately raise `errors::OutOfRange` to indicate // that we should terminate the iteration early. *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } } return s; @@ -273,7 +273,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR(writer->WriteTensor( full_name(strings::StrCat("state[", idx, "]")), state_[idx])); } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -288,7 +288,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { ctx->flr(), full_name(strings::StrCat("state[", idx, "]")), &state_[idx])); } - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc index d1f05d8a3bcc2f..3d2ae0e7f7434a 100644 --- a/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc @@ -146,7 +146,7 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -168,7 +168,7 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR(b->AddDataset( this, {input_graph_node, resource_handle_node, tag_node, prefix_node}, output)); - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/data/experimental/sleep_dataset_op.cc b/tensorflow/core/kernels/data/experimental/sleep_dataset_op.cc index 7fbd26a5a81b67..a8d8a7fa44228f 100644 --- a/tensorflow/core/kernels/data/experimental/sleep_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/sleep_dataset_op.cc @@ -73,7 +73,7 @@ class SleepDatasetOp : public UnaryDatasetOpKernel { Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { diff --git a/tensorflow/core/kernels/data/experimental/sliding_window_dataset_op.cc b/tensorflow/core/kernels/data/experimental/sliding_window_dataset_op.cc index 8e95a7128266db..fa522ea8e74bdf 100644 --- a/tensorflow/core/kernels/data/experimental/sliding_window_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/sliding_window_dataset_op.cc @@ -121,7 +121,7 @@ class SlidingWindowDatasetOp : public UnaryDatasetOpKernel { Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -148,7 +148,7 @@ class SlidingWindowDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR(b->AddDataset( this, {input_graph_node, window_size, window_shift, window_stride}, {std::make_pair(kDropRemainder, drop_remainder_attr)}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -193,7 +193,7 @@ class SlidingWindowDatasetOp : public UnaryDatasetOpKernel { (buffer_.size() < target_size && drop_remainder)) { DCHECK(input_impl_ == nullptr); *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } for (size_t i = 0; i < buffer_.size(); i += window_stride) { @@ -248,7 +248,7 @@ class SlidingWindowDatasetOp : public UnaryDatasetOpKernel { } } *end_of_sequence = false; - return OkStatus(); + return absl::OkStatus(); } protected: @@ -278,7 +278,7 @@ class SlidingWindowDatasetOp : public UnaryDatasetOpKernel { strings::StrCat("buffer[", i, "][", j, "]"), buffer_[i][j])); } } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -305,7 +305,7 @@ class SlidingWindowDatasetOp : public UnaryDatasetOpKernel { &buffer_[i][j])); } } - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc index b62e0241570ed9..9dfdbafe6e8e3c 100644 --- a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc @@ -148,7 +148,7 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -306,12 +306,12 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase { // We do not need to checkpoint the reader as we are rebuilding the // reader datasets from information that is already saved by the main // iterator. - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { - return OkStatus(); + return absl::OkStatus(); } private: @@ -429,7 +429,7 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase { } current_writer->Write(*out_tensors); - return OkStatus(); + return absl::OkStatus(); } protected: @@ -488,7 +488,7 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase { // Create writable files if we see an index bigger than our current // files. *shard_index = output_tensors[0].flat()(0); - return OkStatus(); + return absl::OkStatus(); } Status WriteMetadataFile(Env* env, bool finalized) @@ -621,7 +621,7 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name(kGraphHashDirectory), hash_dir_)); } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -631,7 +631,7 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(InitializeIterator(ctx, reader)); return RestoreInput(ctx, reader, iterator_); } - return OkStatus(); + return absl::OkStatus(); } private: @@ -998,7 +998,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -1081,7 +1081,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { {"mode", mode_attr}, {"snapshot_name", snapshot_name_attr}}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -1106,7 +1106,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { // Initialize at first and at that point we don't know which iterator // (Reader / Writer / Passthrough) we need to restore as this info is part // of the checkpoint. - Status Initialize(IteratorContext* ctx) override { return OkStatus(); } + Status Initialize(IteratorContext* ctx) override { + return absl::OkStatus(); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, @@ -1137,7 +1139,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { static_cast(state_))); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kHashDir), hash_dir_)); VLOG(2) << "Saving Snapshot iterator: " << state_; - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -1149,7 +1151,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { LOG(ERROR) << "Dataset has changed while restoring from the " "checkpoint. Old hash: " << hash_dir << "; new hash: " << hash_dir_; - return OkStatus(); + return absl::OkStatus(); } int64_t temp; TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kState), &temp)); @@ -1275,7 +1277,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { for (auto i = 0; i < dataset()->num_reader_threads_; ++i) { curr_filenames_.push_back(GetNextFilename()); } - return OkStatus(); + return absl::OkStatus(); } Status GetNextInternal(IteratorContext* ctx, @@ -1358,7 +1360,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { if (background_threads_finished_) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } return errors::Internal("Unreachable point in SnapshotReader"); @@ -1397,7 +1399,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { num_elements_read_)); VLOG(2) << "Saving SnapshotReaderIterator: " << num_elements_read_ << "; elements_produced: " << elements_produced_; - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -1413,7 +1415,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { << "run_dir: " << run_dir << " but new run_dir is: " << run_dir_ << ". We'll now restart snapshot creation."; - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunId), &run_id_)); TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunDir), &run_dir_)); @@ -1464,7 +1466,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { &num_elements_read_)); VLOG(2) << "Restoring SnapshotReaderIterator: " << num_elements_read_ << "; elements_produced: " << elements_produced_; - return OkStatus(); + return absl::OkStatus(); } private: @@ -1497,18 +1499,18 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { profiler::TraceMeLevel::kInfo); BufferElement elem; elem.value = std::move(read_tensors); - elem.status = OkStatus(); + elem.status = absl::OkStatus(); mutex_lock l(mu_); buffer_.push_back(std::move(elem)); num_elements_read_++; cond_var_.notify_all(); } else if (errors::IsOutOfRange(s)) { - return OkStatus(); + return absl::OkStatus(); } else { return s; } } - return OkStatus(); + return absl::OkStatus(); } string GetNextFilename() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { @@ -1575,7 +1577,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR(writer->WriteScalar( ErrorMessageKey(index), std::string(status.message()))); } - return OkStatus(); + return absl::OkStatus(); } Status ReadStatus(IteratorStateReader* reader, size_t index, @@ -1590,9 +1592,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { reader->ReadScalar(ErrorMessageKey(index), &error_message)); *status = Status(code, error_message); } else { - *status = OkStatus(); + *status = absl::OkStatus(); } - return OkStatus(); + return absl::OkStatus(); } string CodeKey(size_t index) { @@ -1774,7 +1776,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { << write_throughput; } } - return OkStatus(); + return absl::OkStatus(); } protected: @@ -1829,7 +1831,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { } VLOG(2) << "Saving SnapshotWriterIterator: " << num_elements_written_ << "; elements_produced: " << elements_produced_; - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -1844,7 +1846,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { if (hash_dir != hash_dir_) { LOG(INFO) << "Old hash dir from ckpt: " << hash_dir << " is not the same as the new one: " << hash_dir_; - return OkStatus(); + return absl::OkStatus(); } is_restored_ = true; if (reader->Contains(full_name(kEndOfSequence))) { @@ -1937,7 +1939,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { VLOG(2) << "Restoring SnapshotWriterIterator: " << num_elements_written_ << "; elements_produced: " << elements_produced_; - return OkStatus(); + return absl::OkStatus(); } private: @@ -1966,7 +1968,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { while (num_active_threads_ > 0) { cond_var_.wait(l); } - return OkStatus(); + return absl::OkStatus(); } // Wait for a space in the buffer_. @@ -1989,7 +1991,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { snapshot_util::ElementOrEOF elem_copy = next_elem_; buffer_.push_back(elem_copy); cond_var_.notify_all(); - return OkStatus(); + return absl::OkStatus(); } Status ProcessOneElement(Env* env, int64_t* bytes_written, @@ -2055,7 +2057,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { *bytes_written = 0; } TF_RETURN_IF_ERROR((*writer)->WriteTensors(elem.value)); - return OkStatus(); + return absl::OkStatus(); } if (*end_of_processing) { @@ -2078,7 +2080,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { cond_var_.notify_all(); } } - return OkStatus(); + return absl::OkStatus(); } // Just pulls off elements from the buffer and writes them. @@ -2135,14 +2137,14 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { if (compression_ratio_ > 0.0) { *should_close = bytes_written > (compression_ratio_ * dataset()->shard_size_bytes_); - return OkStatus(); + return absl::OkStatus(); } } // If the number of bytes written aren't shard_size_bytes_ yet, we // keep on going. if (bytes_written <= dataset()->shard_size_bytes_) { *should_close = false; - return OkStatus(); + return absl::OkStatus(); } // Use the actual file size to determine compression ratio. // Make sure that all bytes are written out. @@ -2154,7 +2156,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { static_cast(file_size); LOG(INFO) << "Writing compression achieved: " << compression_ratio_; *should_close = file_size >= dataset()->shard_size_bytes_; - return OkStatus(); + return absl::OkStatus(); } mutex mu_; @@ -2265,7 +2267,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { *hash = Hash64Combine(*hash, Hash64(reader_path_prefix_)); *hash = Hash64Combine(*hash, Hash64(writer_path_prefix_)); *hash = Hash64Combine(*hash, shard_size_bytes_); - return OkStatus(); + return absl::OkStatus(); } const int graph_def_version_; diff --git a/tensorflow/core/kernels/data/experimental/sql/sqlite_query_connection.cc b/tensorflow/core/kernels/data/experimental/sql/sqlite_query_connection.cc index 82fcb2fa4d7299..8b63d10a6e86e5 100644 --- a/tensorflow/core/kernels/data/experimental/sql/sqlite_query_connection.cc +++ b/tensorflow/core/kernels/data/experimental/sql/sqlite_query_connection.cc @@ -40,14 +40,14 @@ Status SqliteQueryConnection::Open(const string& data_source_name, data_source_name, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, &db_)); query_ = query; output_types_ = output_types; - return OkStatus(); + return absl::OkStatus(); } Status SqliteQueryConnection::Close() { stmt_ = SqliteStatement(); db_->Unref(); db_ = nullptr; - return OkStatus(); + return absl::OkStatus(); } Status SqliteQueryConnection::GetNext(IteratorContext* ctx, @@ -63,7 +63,7 @@ Status SqliteQueryConnection::GetNext(IteratorContext* ctx, FillTensorWithResultSetEntry(dt, i, &out_tensors->back()); } } - return OkStatus(); + return absl::OkStatus(); } Status SqliteQueryConnection::PrepareQuery() { @@ -77,7 +77,7 @@ Status SqliteQueryConnection::PrepareQuery() { column_count, output_types_.size())); } column_count_ = column_count; - return OkStatus(); + return absl::OkStatus(); } void SqliteQueryConnection::FillTensorWithResultSetEntry( diff --git a/tensorflow/core/kernels/data/experimental/sql_dataset_op.cc b/tensorflow/core/kernels/data/experimental/sql_dataset_op.cc index 5d190bd8598958..daf55e01f85d29 100644 --- a/tensorflow/core/kernels/data/experimental/sql_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/sql_dataset_op.cc @@ -105,10 +105,10 @@ class SqlDatasetOp : public DatasetOpKernel { Status InputDatasets( std::vector* inputs) const override { - return OkStatus(); + return absl::OkStatus(); } - Status CheckExternalState() const override { return OkStatus(); } + Status CheckExternalState() const override { return absl::OkStatus(); } protected: Status AsGraphDefInternal(SerializationContext* ctx, @@ -123,7 +123,7 @@ class SqlDatasetOp : public DatasetOpKernel { TF_RETURN_IF_ERROR(b->AddScalar(query_, &query_node)); TF_RETURN_IF_ERROR(b->AddDataset( this, {driver_name_node, data_source_name_node, query_node}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -147,7 +147,7 @@ class SqlDatasetOp : public DatasetOpKernel { if (!query_connection_initialized_) { TF_RETURN_IF_ERROR(InitializeQueryConnection()); } - Status status = OkStatus(); + Status status = absl::OkStatus(); if (!end_of_sequence_) { next_calls_++; status = @@ -170,7 +170,7 @@ class SqlDatasetOp : public DatasetOpKernel { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("next_calls"), next_calls_)); } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -192,7 +192,7 @@ class SqlDatasetOp : public DatasetOpKernel { query_connection_initialized_ = false; end_of_sequence_ = false; } - return OkStatus(); + return absl::OkStatus(); } private: @@ -209,7 +209,7 @@ class SqlDatasetOp : public DatasetOpKernel { LOG(WARNING) << "Failed to connect to database: " << s; return s; } - return OkStatus(); + return absl::OkStatus(); } mutex mu_; diff --git a/tensorflow/core/kernels/data/experimental/stats_aggregator_ops.cc b/tensorflow/core/kernels/data/experimental/stats_aggregator_ops.cc index 6404c0fc52a200..4bb95b087fcb5b 100644 --- a/tensorflow/core/kernels/data/experimental/stats_aggregator_ops.cc +++ b/tensorflow/core/kernels/data/experimental/stats_aggregator_ops.cc @@ -86,7 +86,7 @@ class StatsAggregatorImpl : public StatsAggregator { // in V1. Status SetSummaryWriter( SummaryWriterInterface* summary_writer_interface) override { - return OkStatus(); + return absl::OkStatus(); } void IncrementCounter(const string& name, const string& label, @@ -124,7 +124,7 @@ class StatsAggregatorHandleOp Status CreateResource(StatsAggregatorResource** ret) override TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { *ret = new StatsAggregatorResource(std::make_unique()); - return OkStatus(); + return absl::OkStatus(); } }; @@ -159,7 +159,7 @@ class StatsAggregatorImplV2 : public StatsAggregator { mutex_lock l(mu_); if (summary_writer_interface_) TF_RETURN_IF_ERROR(summary_writer_interface_->Flush()); - return OkStatus(); + return absl::OkStatus(); } void IncrementCounter(const string& name, const string& label, @@ -194,7 +194,7 @@ class StatsAggregatorImplV2 : public StatsAggregator { } summary_writer_interface_ = summary_writer_interface; summary_writer_interface_->Ref(); - return OkStatus(); + return absl::OkStatus(); } private: @@ -250,7 +250,7 @@ class StatsAggregatorHandleOpV2 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { *ret = new StatsAggregatorResource(std::make_unique()); - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/core/kernels/data/experimental/stats_dataset_ops.cc b/tensorflow/core/kernels/data/experimental/stats_dataset_ops.cc index 146a74ec85d0f0..0ff1595afae7d3 100644 --- a/tensorflow/core/kernels/data/experimental/stats_dataset_ops.cc +++ b/tensorflow/core/kernels/data/experimental/stats_dataset_ops.cc @@ -83,7 +83,7 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel { Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -99,7 +99,7 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel { Node* tag_node; TF_RETURN_IF_ERROR(b->AddScalar(tag_, &tag_node)); TF_RETURN_IF_ERROR(b->AddDataset(this, {input_node, tag_node}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -140,14 +140,14 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel { IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -218,7 +218,7 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel { Node* tag_node; TF_RETURN_IF_ERROR(b->AddScalar(tag_, &tag_node)); TF_RETURN_IF_ERROR(b->AddDataset(this, {input_node, tag_node}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -261,14 +261,14 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel { IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/data/experimental/take_while_dataset_op.cc b/tensorflow/core/kernels/data/experimental/take_while_dataset_op.cc index 395c390f403527..50fe016ee7e822 100644 --- a/tensorflow/core/kernels/data/experimental/take_while_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/take_while_dataset_op.cc @@ -88,7 +88,7 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel { Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -119,7 +119,7 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel { {std::make_pair("predicate", f_attr), std::make_pair("Targuments", other_arguments_types_attr)}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -144,7 +144,7 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel { tf_shared_lock l(mu_); if (!input_impl_) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR( input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); @@ -152,7 +152,7 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel { if (*end_of_sequence) { mutex_lock l(mu_); input_impl_.reset(); - return OkStatus(); + return absl::OkStatus(); } std::vector result; TF_RETURN_IF_ERROR(instantiated_captured_func_->RunWithBorrowedArgs( @@ -169,7 +169,7 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel { input_impl_.reset(); out_tensors->clear(); } - return OkStatus(); + return absl::OkStatus(); } protected: @@ -189,7 +189,7 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel { if (input_impl_) { TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -203,7 +203,7 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel { } else { TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); } - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc index 62a317acec4fec..79f7689c2a52bc 100644 --- a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc @@ -51,7 +51,7 @@ Status ValidateNumThreads(int32_t num_threads) { if (num_threads >= kThreadLimit) { return errors::InvalidArgument("`num_threads` must be < ", kThreadLimit); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -136,7 +136,7 @@ class ThreadPoolHandleOp : public OpKernel { num_threads_, /*low_latency_hint=*/false, max_intra_op_parallelism_); - return OkStatus(); + return absl::OkStatus(); })); initialized_ = true; } @@ -210,7 +210,7 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -227,7 +227,7 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR(b->AddTensor(resource_handle_, &resource_handle_node)); TF_RETURN_IF_ERROR(b->AddDataset( this, {input_graph_node, resource_handle_node}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -259,13 +259,13 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { IteratorStateWriter* writer) override { DCHECK(input_impl_ != nullptr); TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -332,7 +332,7 @@ class MaxIntraOpParallelismDatasetOp::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->clear(); inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -350,7 +350,7 @@ class MaxIntraOpParallelismDatasetOp::Dataset : public DatasetBase { &max_intra_op_parallelism_node)); TF_RETURN_IF_ERROR(b->AddDataset( this, {input_graph_node, max_intra_op_parallelism_node}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -383,13 +383,13 @@ class MaxIntraOpParallelismDatasetOp::Dataset : public DatasetBase { IteratorStateWriter* writer) override { DCHECK(input_impl_ != nullptr); TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); - return OkStatus(); + return absl::OkStatus(); } TraceMeMetadata GetTraceMeMetadata() const override { @@ -475,7 +475,7 @@ class PrivateThreadPoolDatasetOp::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->clear(); inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -492,7 +492,7 @@ class PrivateThreadPoolDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(b->AddScalar(num_threads_, &num_threads_node)); TF_RETURN_IF_ERROR( b->AddDataset(this, {input_graph_node, num_threads_node}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -528,13 +528,13 @@ class PrivateThreadPoolDatasetOp::Dataset : public DatasetBase { IteratorStateWriter* writer) override { DCHECK(input_impl_ != nullptr); TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); - return OkStatus(); + return absl::OkStatus(); } TraceMeMetadata GetTraceMeMetadata() const override { diff --git a/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc b/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc index ca3a985881a8db..c8d53de774ebe6 100644 --- a/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc +++ b/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc @@ -45,7 +45,7 @@ class ToTFRecordOp : public AsyncOpKernel { return errors::InvalidArgument(argument_name, " must be a scalar"); } *output = argument_t->scalar()(); - return OkStatus(); + return absl::OkStatus(); } void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { @@ -119,7 +119,7 @@ class ToTFRecordOp : public AsyncOpKernel { } components.clear(); } while (!end_of_sequence); - return OkStatus(); + return absl::OkStatus(); } BackgroundWorker background_worker_; diff --git a/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc index 88bcf453693ed9..0d153486f10eee 100644 --- a/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc @@ -116,7 +116,7 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel { Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -130,7 +130,7 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -157,7 +157,7 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel { mutex_lock l(mu_); if (!input_impl_) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } *end_of_sequence = false; while (!*end_of_sequence) { @@ -177,7 +177,7 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel { if (current_index_ >= current_batch_size_) { ctx->MergeCheckpoint(input_ckpt_.get()); } - return OkStatus(); + return absl::OkStatus(); } current_index_ = 0; current_batch_size_ = 0; @@ -207,7 +207,7 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel { } } input_impl_.reset(); - return OkStatus(); + return absl::OkStatus(); } protected: @@ -244,7 +244,7 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel { full_name(StrCat("tensors[", i, "]")), tensors_[i])); } } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -267,7 +267,7 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel { if (current_index_ < current_batch_size_) { TF_RETURN_IF_ERROR(RestoreTensors(ctx, reader)); } - return OkStatus(); + return absl::OkStatus(); } private: @@ -298,7 +298,7 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel { shapes_[i] = tensors_[i].shape(); shapes_[i].RemoveDim(0); } - return OkStatus(); + return absl::OkStatus(); } mutex mu_; diff --git a/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc index d2f4ac53e5abcd..466a5f13bb47b8 100644 --- a/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc @@ -56,7 +56,7 @@ class UniqueDatasetOp::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -70,7 +70,7 @@ class UniqueDatasetOp::Dataset : public DatasetBase { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -99,7 +99,7 @@ class UniqueDatasetOp::Dataset : public DatasetBase { DCHECK_EQ(1, out_tensors->size()); saw_new_value = unique_elements_.insert((*out_tensors)[0]).second; } while (!saw_new_value); - return OkStatus(); + return absl::OkStatus(); } protected: @@ -124,7 +124,7 @@ class UniqueDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(writer->WriteTensor( full_name(strings::StrCat("unique_elements[", i++, "]")), t)); } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -151,7 +151,7 @@ class UniqueDatasetOp::Dataset : public DatasetBase { "value."); } } - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/data/experimental/unique_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/unique_dataset_op_test.cc index eeabf2efd6ea49..e918812bb04bc5 100644 --- a/tensorflow/core/kernels/data/experimental/unique_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/experimental/unique_dataset_op_test.cc @@ -40,14 +40,14 @@ class UniqueDatasetParams : public DatasetParams { Status GetInputNames(std::vector* input_names) const override { input_names->clear(); input_names->emplace_back(UniqueDatasetOp::kInputDataset); - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attributes) const override { *attributes = {{"output_types", output_dtypes_}, {"output_shapes", output_shapes_}, {"metadata", ""}}; - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { return UniqueDatasetOp::kDatasetType; } diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc index a14e8f32b9ac23..425b38b458149e 100644 --- a/tensorflow/core/kernels/data/filter_dataset_op.cc +++ b/tensorflow/core/kernels/data/filter_dataset_op.cc @@ -81,7 +81,7 @@ class FilterDatasetOp::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -107,7 +107,7 @@ class FilterDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(b->AddDataset( this, {{0, input_graph_node}}, {{1, other_arguments}}, {{kPredicate, f}, {kTarguments, other_arguments_types_attr}}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -140,7 +140,7 @@ class FilterDatasetOp::Dataset : public DatasetBase { tf_shared_lock l(mu_); if (!input_impl_) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR( input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); @@ -148,12 +148,15 @@ class FilterDatasetOp::Dataset : public DatasetBase { if (*end_of_sequence) { mutex_lock l(mu_); input_impl_.reset(); - return OkStatus(); + return absl::OkStatus(); } std::vector result; - TF_RETURN_IF_ERROR(instantiated_captured_func_->RunWithBorrowedArgs( - ctx, *out_tensors, &result, model_node())); + auto status = instantiated_captured_func_->RunWithBorrowedArgs( + ctx, *out_tensors, &result, model_node()); + if (!status.ok()) { + return AddErrorContext(status); + } if (result.size() != 1 || result[0].dtype() != DT_BOOL || result[0].NumElements() != 1) { @@ -201,7 +204,7 @@ class FilterDatasetOp::Dataset : public DatasetBase { static_cast(1)); } *end_of_sequence = false; - return OkStatus(); + return absl::OkStatus(); } protected: @@ -224,7 +227,7 @@ class FilterDatasetOp::Dataset : public DatasetBase { writer->WriteScalar(prefix(), kFilteredElements, filtered_elements_)); TF_RETURN_IF_ERROR( writer->WriteScalar(prefix(), kDroppedElements, dropped_elements_)); - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -242,7 +245,7 @@ class FilterDatasetOp::Dataset : public DatasetBase { reader->ReadScalar(prefix(), kFilteredElements, &filtered_elements_)); TF_RETURN_IF_ERROR( reader->ReadScalar(prefix(), kDroppedElements, &dropped_elements_)); - return OkStatus(); + return absl::OkStatus(); } data::TraceMeMetadata GetTraceMeMetadata() const override { diff --git a/tensorflow/core/kernels/data/filter_dataset_op_test.cc b/tensorflow/core/kernels/data/filter_dataset_op_test.cc index 1bb0e374ebb632..e325b604c60dda 100644 --- a/tensorflow/core/kernels/data/filter_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/filter_dataset_op_test.cc @@ -56,7 +56,7 @@ class FilterDatasetParams : public DatasetParams { absl::StrCat(FilterDatasetOp::kOtherArguments, "_", i)); } - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { @@ -65,7 +65,7 @@ class FilterDatasetParams : public DatasetParams { {"output_shapes", output_shapes_}, {"output_types", output_dtypes_}, {"metadata", ""}}; - return OkStatus(); + return absl::OkStatus(); } std::vector func_lib() const override { return func_lib_; } diff --git a/tensorflow/core/kernels/data/finalize_dataset_op_test.cc b/tensorflow/core/kernels/data/finalize_dataset_op_test.cc index 0cd087bc0bfa80..efd135c0e24839 100644 --- a/tensorflow/core/kernels/data/finalize_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/finalize_dataset_op_test.cc @@ -42,14 +42,14 @@ class FinalizeDatasetParams : public DatasetParams { Status GetInputNames(std::vector* input_names) const override { input_names->emplace_back(FinalizeDatasetOp::kInputDataset); - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { *attr_vector = {{FinalizeDatasetOp::kHasCapturedRef, has_captured_ref_}, {FinalizeDatasetOp::kOutputTypes, output_dtypes_}, {FinalizeDatasetOp::kOutputShapes, output_shapes_}}; - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { return "Finalize"; } diff --git a/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc b/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc index 9440721bc858a6..fab7523bd4587c 100644 --- a/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc +++ b/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc @@ -95,10 +95,10 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { } Status InputDatasets(std::vector* inputs) const override { - return OkStatus(); + return absl::OkStatus(); } - Status CheckExternalState() const override { return OkStatus(); } + Status CheckExternalState() const override { return absl::OkStatus(); } protected: Status AsGraphDefInternal(SerializationContext* ctx, @@ -121,7 +121,7 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { {filenames, header_bytes, record_bytes, footer_bytes, buffer_size, compression_type}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -152,7 +152,7 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { record_tensor.scalar()() = record; out_tensors->emplace_back(std::move(record_tensor)); *end_of_sequence = false; - return OkStatus(); + return absl::OkStatus(); } // We have reached the end of the current file, so maybe move on to @@ -165,7 +165,7 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { // Iteration ends when there are no more files to process. if (current_file_index_ == dataset()->filenames_.size()) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } // Actually move on to next file. @@ -208,7 +208,7 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { int64_t current_pos = input_buffer_ ? input_buffer_->Tell() : -1; TF_RETURN_IF_ERROR( writer->WriteScalar(prefix(), kCurrentPos, current_pos)); - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -239,7 +239,7 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(input_buffer_->Seek(current_pos)); } - return OkStatus(); + return absl::OkStatus(); } private: @@ -279,7 +279,7 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { record_tensor.scalar()() = std::move(record); out_tensors->emplace_back(std::move(record_tensor)); *end_of_sequence = false; - return OkStatus(); + return absl::OkStatus(); } } else { tstring record; @@ -298,7 +298,7 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { record_tensor.scalar()() = std::move(record); out_tensors->emplace_back(std::move(record_tensor)); *end_of_sequence = false; - return OkStatus(); + return absl::OkStatus(); } if (errors::IsOutOfRange(s) && !record.empty()) { uint64 body_size = @@ -326,7 +326,7 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { // Iteration ends when there are no more files to process. if (current_file_index_ == dataset()->filenames_.size()) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } // Actually move on to next file. @@ -397,7 +397,7 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { buffered_input_stream_ ? buffered_input_stream_->Tell() : -1; TF_RETURN_IF_ERROR( writer->WriteScalar(prefix(), kCurrentPos, current_pos)); - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -434,7 +434,7 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { dataset()->footer_bytes_, &lookahead_cache_)); } - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/data/fixed_length_record_dataset_op_test.cc b/tensorflow/core/kernels/data/fixed_length_record_dataset_op_test.cc index 366ca2078c2c89..125d6de637abc1 100644 --- a/tensorflow/core/kernels/data/fixed_length_record_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/fixed_length_record_dataset_op_test.cc @@ -63,13 +63,13 @@ class FixedLengthRecordDatasetParams : public DatasetParams { FixedLengthRecordDatasetOp::kFooterBytes, FixedLengthRecordDatasetOp::kBufferSize, FixedLengthRecordDatasetOp::kCompressionType}; - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { attr_vector->clear(); attr_vector->emplace_back("metadata", ""); - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { @@ -107,7 +107,7 @@ Status CreateTestFiles(const std::vector& filenames, WriteDataToFile(filenames[i], contents[i].data(), params)); } } - return OkStatus(); + return absl::OkStatus(); } // Test case 1: multiple fixed-length record files with ZLIB compression. diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc index 9abee3f1112296..6a9afa758ed3ab 100644 --- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc @@ -87,7 +87,7 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -116,7 +116,7 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { {std::make_pair(kFunc, f), std::make_pair(kTarguments, other_arguments_types_attr)}, // Attrs output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -144,7 +144,7 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { do { if (!input_impl_) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } if (current_element_iterator_) { // We are currently processing a mapped element, so try to get the @@ -162,7 +162,7 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { if (!end_of_element) { // Produce the subelement as output. *end_of_sequence = false; - return OkStatus(); + return absl::OkStatus(); } // Since this sub-iterator is done, // we can commit `input_ckpt_` to `ctx->checkpoint()` @@ -186,7 +186,7 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { input_ckpt_->Merge(input_ctx->checkpoint()); if (*end_of_sequence) { input_impl_.reset(); - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR( @@ -203,7 +203,7 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { while (*num_skipped < num_to_skip) { if (!input_impl_) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } if (current_element_iterator_) { // We are currently processing a mapped element, so try to get the @@ -256,13 +256,13 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { if (*end_of_sequence) { input_impl_.reset(); *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR( BuildCurrentElementIteratorLocked(ctx, /*is_get_next=*/false)); } *end_of_sequence = false; - return OkStatus(); + return absl::OkStatus(); // LINT.ThenChange(:GetNextInternal) } @@ -298,7 +298,7 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(SaveInput(ctx, writer, current_element_iterator_)); } } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -329,7 +329,7 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(RestoreCurrentElementIterator(ctx, reader)); } } - return OkStatus(); + return absl::OkStatus(); } private: @@ -367,7 +367,7 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR( BuildCurrentElementIteratorLocked(ctx, /*is_get_next=*/false)); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, current_element_iterator_)); - return OkStatus(); + return absl::OkStatus(); } Status RestoreCurrentElementIteratorSymbolic(IteratorContext* ctx, @@ -388,7 +388,7 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR( BuildCurrentElementIteratorLocked(ctx, /*is_get_next=*/false)); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, current_element_iterator_)); - return OkStatus(); + return absl::OkStatus(); } mutex mu_; diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op_test.cc b/tensorflow/core/kernels/data/flat_map_dataset_op_test.cc index b2bf4052dcaa04..b6a68065c93845 100644 --- a/tensorflow/core/kernels/data/flat_map_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/flat_map_dataset_op_test.cc @@ -52,7 +52,7 @@ class FlatMapDatasetParams : public DatasetParams { input_names->emplace_back( absl::StrCat(FlatMapDatasetOp::kOtherArguments, "_", i)); } - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { @@ -61,7 +61,7 @@ class FlatMapDatasetParams : public DatasetParams { {"output_shapes", output_shapes_}, {"output_types", output_dtypes_}, {"metadata", ""}}; - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc index 99149b606ecdf0..401a8e50284859 100644 --- a/tensorflow/core/kernels/data/generator_dataset_op.cc +++ b/tensorflow/core/kernels/data/generator_dataset_op.cc @@ -76,7 +76,7 @@ class GeneratorDatasetOp::Dataset : public DatasetBase { } Status InputDatasets(std::vector* inputs) const override { - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -119,7 +119,7 @@ class GeneratorDatasetOp::Dataset : public DatasetBase { dataset()->next_func_->Instantiate(ctx, &instantiated_next_func_)); TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate( ctx, &instantiated_finalize_func_)); - return OkStatus(); + return absl::OkStatus(); } Status GetNextInternal(IteratorContext* ctx, @@ -135,7 +135,7 @@ class GeneratorDatasetOp::Dataset : public DatasetBase { if (finalized_) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } Status s = instantiated_next_func_->RunWithBorrowedArgs( @@ -145,7 +145,7 @@ class GeneratorDatasetOp::Dataset : public DatasetBase { } else if (errors::IsOutOfRange(s)) { // `next_func` may deliberately raise `errors::OutOfRange` // to indicate that we should terminate the iteration. - s = OkStatus(); + s = absl::OkStatus(); *end_of_sequence = true; // NOTE(mrry): We ignore any tensors returned by the finalize function. diff --git a/tensorflow/core/kernels/data/get_options_op_test.cc b/tensorflow/core/kernels/data/get_options_op_test.cc index c7b8face3323ce..8f5ae9d7ea7d8b 100644 --- a/tensorflow/core/kernels/data/get_options_op_test.cc +++ b/tensorflow/core/kernels/data/get_options_op_test.cc @@ -44,11 +44,11 @@ class GetOptionsParams : public DatasetParams { Status GetInputNames(std::vector* input_names) const override { input_names->emplace_back(OptionsDatasetOp::kInputDataset); - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { return "GetOptions"; } diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc index bbc0c1d76947f7..772f1fdb416d50 100644 --- a/tensorflow/core/kernels/data/interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc @@ -106,7 +106,7 @@ class InterleaveDatasetOp::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -137,7 +137,7 @@ class InterleaveDatasetOp::Dataset : public DatasetBase { this, {{0, input_node}, {2, cycle_length_node}, {3, block_length_node}}, {{1, other_arguments}}, {{kFunc, f}, {kTarguments, other_arguments_types_attr}}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -202,7 +202,7 @@ class InterleaveDatasetOp::Dataset : public DatasetBase { // Produce the subelement as output. AdvancePosition(); *end_of_sequence = false; - return OkStatus(); + return absl::OkStatus(); } else { // We have reached the end of the current element, so move // on to the next element in the cycle. @@ -228,7 +228,7 @@ class InterleaveDatasetOp::Dataset : public DatasetBase { ctx->MergeCheckpoint(input_ckpt_.get()); *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } Status SkipInternal(IteratorContext* ctx, int num_to_skip, @@ -267,7 +267,7 @@ class InterleaveDatasetOp::Dataset : public DatasetBase { } if (num_to_skip == *num_skipped) { *end_of_sequence = false; - return OkStatus(); + return absl::OkStatus(); } } else { TF_RETURN_IF_ERROR(MoveToNextElement(ctx)); @@ -277,7 +277,7 @@ class InterleaveDatasetOp::Dataset : public DatasetBase { ctx->MergeCheckpoint(input_ckpt_.get()); *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } protected: @@ -309,7 +309,7 @@ class InterleaveDatasetOp::Dataset : public DatasetBase { last_checkpointed_input_element_index_)); TF_RETURN_IF_ERROR(SaveCurrentElements(ctx, writer)); - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -357,7 +357,7 @@ class InterleaveDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR( RestoreCurrentElements(ctx, reader, input_element_indices, std::move(checkpoints), std::move(args))); - return OkStatus(); + return absl::OkStatus(); } TraceMeMetadata GetTraceMeMetadata() const override { @@ -415,7 +415,7 @@ class InterleaveDatasetOp::Dataset : public DatasetBase { current_elements_[idx]->input_element_index)); } } - return OkStatus(); + return absl::OkStatus(); } absl::StatusOr> RestoreInputOffsets( @@ -569,7 +569,7 @@ class InterleaveDatasetOp::Dataset : public DatasetBase { } input_ckpt_->Merge(input_ctx->checkpoint()); - return OkStatus(); + return absl::OkStatus(); } Status RestoreCurrentElements( @@ -630,7 +630,7 @@ class InterleaveDatasetOp::Dataset : public DatasetBase { } } - return OkStatus(); + return absl::OkStatus(); } Status MoveToNextElement(IteratorContext* ctx) @@ -669,7 +669,7 @@ class InterleaveDatasetOp::Dataset : public DatasetBase { } else { AdvanceToNextInCycle(); } - return OkStatus(); + return absl::OkStatus(); } // Check if the given `input_element_index` is the earliest(oldest) current diff --git a/tensorflow/core/kernels/data/interleave_dataset_op_test.cc b/tensorflow/core/kernels/data/interleave_dataset_op_test.cc index ceda37554e0195..e27e147ae7c6da 100644 --- a/tensorflow/core/kernels/data/interleave_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/interleave_dataset_op_test.cc @@ -65,7 +65,7 @@ class InterleaveDatasetParams : public DatasetParams { } input_names->emplace_back(InterleaveDatasetOp::kCycleLength); input_names->emplace_back(InterleaveDatasetOp::kBlockLength); - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { @@ -74,7 +74,7 @@ class InterleaveDatasetParams : public DatasetParams { {"output_shapes", output_shapes_}, {"output_types", output_dtypes_}, {"metadata", ""}}; - return OkStatus(); + return absl::OkStatus(); } std::vector func_lib() const override { return func_lib_; } diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index 7ec2442016c936..3a00a0a39b35dc 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -171,7 +171,7 @@ Status IteratorResource::Save(OpKernelContext* ctx, } LOG(INFO) << "Saving symbolic checkpoint"; TF_RETURN_IF_ERROR(checkpoint.Save(writer)); - return OkStatus(); + return absl::OkStatus(); } SerializationContext::Params params(ctx); params.external_state_policy = external_state_policy; @@ -231,7 +231,7 @@ Status IteratorResource::Restore(OpKernelContext* ctx, new_state->MergeCheckpoint(iter_ctx.checkpoint()); mutex_lock l(mu_); std::swap(iterator_state_, new_state); - return OkStatus(); + return absl::OkStatus(); } Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx, @@ -280,14 +280,15 @@ Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx, TF_RETURN_IF_ERROR( VerifyShapesCompatible(output_shapes_, iterator->output_shapes())); new_state->DowncastAndSetIteratorAndDataset(std::move(iterator), dataset); + new_state->SetModel(iter_ctx.model()); new_state->MergeCheckpoint(iter_ctx.checkpoint()); mutex_lock l(mu_); std::swap(iterator_state_, new_state); tf_dataz_metrics_collector_ = std::make_shared( - env_, iterator_state_->iterator()); + env_, iterator_state_->iterator(), iterator_state_->model()); EnsureIteratorMemoryLoggerStarted(); TfDatazMetricsRegistry::Register(tf_dataz_metrics_collector_); - return OkStatus(); + return absl::OkStatus(); } void IteratorResource::State::DowncastAndSetIteratorAndDataset( @@ -305,6 +306,10 @@ void IteratorResource::State::MergeCheckpoint(MemoryCheckpoint* other) { } } +void IteratorResource::State::SetModel(std::shared_ptr model) { + model_ = model; +} + namespace { // A helper class that uses a list of IteratorStateVariant objects to represent @@ -344,7 +349,7 @@ class IteratorVariantSerializer { } num_tensors_ = variants_.size(); can_serialize_ = true; - return OkStatus(); + return absl::OkStatus(); } // Initializes `this` from `serialized_t` while restoring the iterator state. @@ -365,7 +370,7 @@ class IteratorVariantSerializer { } reader_ = std::make_unique(data); num_tensors_ = data.size(); - return OkStatus(); + return absl::OkStatus(); } int64_t NumTensors() { return num_tensors_; } @@ -385,7 +390,7 @@ class IteratorVariantSerializer { } serialized->vec()(i) = variants_[i]; } - return OkStatus(); + return absl::OkStatus(); } // Returns an IteratorStateReader to restore iterator state. Expects that @@ -462,7 +467,7 @@ void IteratorHandleOp::Compute(OpKernelContext* context) context->env(), output_dtypes_, output_shapes_, std::move(device_mgr), std::move(flib_def), std::move(pflr), flr); - return OkStatus(); + return absl::OkStatus(); })); Status s = VerifyResource(resource); @@ -485,7 +490,7 @@ Status IteratorHandleOp::VerifyResource(IteratorResource* resource) { VerifyTypesMatch(output_dtypes_, resource->output_dtypes())); TF_RETURN_IF_ERROR( VerifyShapesCompatible(output_shapes_, resource->output_shapes())); - return OkStatus(); + return absl::OkStatus(); } FunctionLibraryRuntime* IteratorHandleOp::CreatePrivateFLR( @@ -544,7 +549,7 @@ Status AnonymousIteratorHandleOp::CreateResource( *resource = new IteratorResource(ctx->env(), output_dtypes_, output_shapes_, std::move(device_mgr), std::move(flib_def), std::move(pflr), lib); - return OkStatus(); + return absl::OkStatus(); } HybridAsyncOpKernel::HybridAsyncOpKernel(OpKernelConstruction* ctx, @@ -659,7 +664,7 @@ class ToSingleElementOp : public AsyncOpKernel { if (!end_of_sequence) { return errors::InvalidArgument("Dataset had more than one element."); } - return OkStatus(); + return absl::OkStatus(); } IteratorMetricsCollector metrics_collector_; @@ -770,7 +775,7 @@ class OneShotIteratorOp : public AsyncOpKernel { ctx->env(), output_dtypes_, output_shapes_, /*device_mgr=*/nullptr, std::move(flib_def), std::move(pflr), flr); - return OkStatus(); + return absl::OkStatus(); })); core::ScopedUnref unref_iterator(*iterator); @@ -810,7 +815,7 @@ class OneShotIteratorOp : public AsyncOpKernel { TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset)); TF_RETURN_IF_ERROR((*iterator)->SetIteratorFromDataset(ctx, dataset)); (*iterator)->Ref(); - return OkStatus(); + return absl::OkStatus(); } void ProduceOutput(OpKernelContext* ctx, const DoneCallback& done) { @@ -901,7 +906,7 @@ Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) { for (int i = 0; i < components.size(); ++i) { ctx->set_output(i, components[i]); } - return OkStatus(); + return absl::OkStatus(); } Status IteratorGetNextAsOptionalOp::DoCompute(OpKernelContext* ctx) { diff --git a/tensorflow/core/kernels/data/iterator_ops.h b/tensorflow/core/kernels/data/iterator_ops.h index 1b5d75622c8872..841df1a3b19bce 100644 --- a/tensorflow/core/kernels/data/iterator_ops.h +++ b/tensorflow/core/kernels/data/iterator_ops.h @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/data/unbounded_thread_pool.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/function_handle_cache.h" +#include "tensorflow/core/framework/model.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" @@ -114,6 +115,8 @@ class IteratorResource : public ResourceBase { DatasetBaseIterator* iterator() { return iterator_.get(); } + std::shared_ptr model() { return model_; } + const MemoryCheckpoint& checkpoint() const { return checkpoint_; } DatasetBase* dataset() { return dataset_.get(); } @@ -126,6 +129,8 @@ class IteratorResource : public ResourceBase { // Merges the given checkpoint with the checkpoint of this state. void MergeCheckpoint(MemoryCheckpoint* other); + void SetModel(std::shared_ptr model); + std::shared_ptr id_registry() { return id_registry_; } @@ -141,6 +146,7 @@ class IteratorResource : public ResourceBase { core::RefCountPtr dataset_; std::shared_ptr id_registry_; MemoryCheckpoint checkpoint_; + std::shared_ptr model_; }; IteratorMetricsCollector metrics_collector_; diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc index 6661b103b69915..8db529beca95d5 100644 --- a/tensorflow/core/kernels/data/map_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_dataset_op.cc @@ -83,7 +83,7 @@ class MapDatasetOp::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -142,7 +142,7 @@ class MapDatasetOp::Dataset : public DatasetBase { std::make_pair(kPreserveCardinality, preserve_cardinality_attr)}, // Attrs output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -169,7 +169,7 @@ class MapDatasetOp::Dataset : public DatasetBase { std::vector args; TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &args, end_of_sequence)); if (*end_of_sequence) { - return OkStatus(); + return absl::OkStatus(); } Status s = instantiated_captured_func_->Run(ctx, std::move(args), @@ -185,11 +185,13 @@ class MapDatasetOp::Dataset : public DatasetBase { // `f` may deliberately raise `errors::OutOfRange` to indicate // that we should terminate the iteration early. *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } - } else { - return s; } + if (!s.ok()) { + return AddErrorContext(s); + } + return s; } protected: @@ -203,13 +205,13 @@ class MapDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( dataset()->captured_func_->CheckExternalState())); TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc index cf09562bba52e8..3171116e7ae404 100644 --- a/tensorflow/core/kernels/data/map_defun_op.cc +++ b/tensorflow/core/kernels/data/map_defun_op.cc @@ -102,7 +102,7 @@ class MapDefunOp::MapFunctionCallFrame : public CallFrameInterface { // The function is calling for a captured input *val = &compute_opts_->captured_inputs[index - compute_opts_->args.size()]; - return OkStatus(); + return absl::OkStatus(); } // NOTE: If contention on mu_ becomes problematic, we could create a vector @@ -118,7 +118,7 @@ class MapDefunOp::MapFunctionCallFrame : public CallFrameInterface { sliced_args_[index] = tensor::DeepCopy(sliced_args_[index]); } *val = &sliced_args_[index]; - return OkStatus(); + return absl::OkStatus(); } Status SetRetval(int index, const Tensor& val) override { @@ -287,7 +287,7 @@ Status MapDefunOp::SetupArgs(OpKernelContext* ctx, *compute_opts = new ComputeOptions(ctx, arguments, captured_inputs, std::move(arg_shapes), batch_size, output_shapes_, max_intra_op_parallelism_); - return OkStatus(); + return absl::OkStatus(); } Status MapDefunOp::SetupOutputs(OpKernelContext* ctx, ComputeOptions* opts) { @@ -303,7 +303,7 @@ Status MapDefunOp::SetupOutputs(OpKernelContext* ctx, ComputeOptions* opts) { TF_RETURN_IF_ERROR(opts->output.allocate(i, output_shape, &out)); } } - return OkStatus(); + return absl::OkStatus(); } namespace { diff --git a/tensorflow/core/kernels/data/map_defun_op_test.cc b/tensorflow/core/kernels/data/map_defun_op_test.cc index 4e390d44ead5c1..d48650cada3f82 100644 --- a/tensorflow/core/kernels/data/map_defun_op_test.cc +++ b/tensorflow/core/kernels/data/map_defun_op_test.cc @@ -59,7 +59,7 @@ class MapDefunOpParams : public DatasetParams { input_names->emplace_back( strings::StrCat(MapDefunOp::kCapturedInputs, "_", i)); } - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { @@ -70,7 +70,7 @@ class MapDefunOpParams : public DatasetParams { {MapDefunOp::kOutputTypes, output_dtypes_}, {MapDefunOp::kFunc, func_}, {MapDefunOp::kMaxIntraOpParallelism, max_intra_op_parallelism_}}; - return OkStatus(); + return absl::OkStatus(); } std::vector func_lib() const override { return func_lib_; } @@ -100,7 +100,7 @@ class MapDefunOpTest : public DatasetOpsTestBase { NodeDef node_def = test::function::NDef(kNodeName, kOpName, input_namess, attributes); TF_RETURN_IF_ERROR(CreateOpKernel(node_def, map_defun_kernel)); - return OkStatus(); + return absl::OkStatus(); } // Creates a new `MapDefun` op kernel context. @@ -109,7 +109,7 @@ class MapDefunOpTest : public DatasetOpsTestBase { std::unique_ptr* context) { TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs)); TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context)); - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc index 7bd4b738520741..15a19b992d66ee 100644 --- a/tensorflow/core/kernels/data/model_dataset_op.cc +++ b/tensorflow/core/kernels/data/model_dataset_op.cc @@ -95,7 +95,7 @@ class ModelDatasetOp::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -121,7 +121,7 @@ class ModelDatasetOp::Dataset : public DatasetBase { std::make_pair(kCpuBudget, cpu_budget_attr), std::make_pair(kRamBudget, ram_budget_attr)}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -206,7 +206,7 @@ class ModelDatasetOp::Dataset : public DatasetBase { } }); } - return OkStatus(); + return absl::OkStatus(); } mutex mu_; diff --git a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc index 390663822aa986..a6de1b57025f84 100644 --- a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc +++ b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc @@ -127,7 +127,7 @@ class MultiDeviceIterator : public ResourceBase { multi_device_buffer_ = std::make_unique( devices_.size(), max_buffer_size, incarnation_id_, std::move(iterator), this); - return OkStatus(); + return absl::OkStatus(); } Status GetNextFromShard(OpKernelContext* ctx, int shard_num, @@ -144,7 +144,7 @@ class MultiDeviceIterator : public ResourceBase { IteratorContext iter_ctx(std::move(params)); multi_device_buffer_->GetNextFromShard(&iter_ctx, shard_num, incarnation_id, std::move(callback)); - return OkStatus(); + return absl::OkStatus(); } const DataTypeVector& output_types() const { return output_types_; } @@ -547,7 +547,7 @@ class MultiDeviceIteratorHandleOp : public OpKernel { output_shapes_, devices_, std::move(flib_def), std::move(pflr), flr, std::move(function_handle_cache)); - return OkStatus(); + return absl::OkStatus(); })); Status s = VerifyResource(resource); if (TF_PREDICT_FALSE(!s.ok())) { @@ -575,7 +575,7 @@ class MultiDeviceIteratorHandleOp : public OpKernel { VerifyTypesMatch(output_types_, resource->output_types())); TF_RETURN_IF_ERROR( VerifyShapesCompatible(output_shapes_, resource->output_shapes())); - return OkStatus(); + return absl::OkStatus(); } mutex mu_; @@ -620,7 +620,7 @@ class AnonymousMultiDeviceIteratorOp new MultiDeviceIterator(ctx->env(), output_dtypes_, output_shapes_, devices_, std::move(flib_def), std::move(pflr), lib, std::move(function_handle_cache)); - return OkStatus(); + return absl::OkStatus(); } std::vector devices_; diff --git a/tensorflow/core/kernels/data/optimize_dataset_op_test.cc b/tensorflow/core/kernels/data/optimize_dataset_op_test.cc index f90d3d47d0704a..6e03748b970582 100644 --- a/tensorflow/core/kernels/data/optimize_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/optimize_dataset_op_test.cc @@ -47,7 +47,7 @@ class OptimizeDatasetParams : public DatasetParams { Status GetInputNames(std::vector* input_names) const override { *input_names = {OptimizeDatasetOp::kInputDataset, OptimizeDatasetOp::kOptimizations}; - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { @@ -55,7 +55,7 @@ class OptimizeDatasetParams : public DatasetParams { {OptimizeDatasetOp::kOutputShapes, output_shapes_}, {OptimizeDatasetOp::kOutputTypes, output_dtypes_}, {OptimizeDatasetOp::kOptimizationConfigs, optimization_configs_}}; - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { diff --git a/tensorflow/core/kernels/data/optional_ops.cc b/tensorflow/core/kernels/data/optional_ops.cc index 78e66510fdb437..07967a672135e3 100644 --- a/tensorflow/core/kernels/data/optional_ops.cc +++ b/tensorflow/core/kernels/data/optional_ops.cc @@ -50,7 +50,7 @@ static Status OptionalDeviceCopy( } else { *to = from; } - return OkStatus(); + return absl::OkStatus(); } #define REGISTER_OPTIONAL_COPY(DIRECTION) \ @@ -144,7 +144,7 @@ Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index, TF_RETURN_IF_ERROR(ctx->allocate_output(output_index, TensorShape({}), &variant_t, cpu_alloc)); variant_t->scalar()() = v; - return OkStatus(); + return absl::OkStatus(); } Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index) { @@ -155,7 +155,7 @@ Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index) { TF_RETURN_IF_ERROR(ctx->allocate_output(output_index, TensorShape({}), &variant_t, cpu_alloc)); variant_t->scalar()() = v; - return OkStatus(); + return absl::OkStatus(); } namespace { diff --git a/tensorflow/core/kernels/data/optional_ops_util.cc b/tensorflow/core/kernels/data/optional_ops_util.cc index 4392dda39a2e5e..c504c99c7a528d 100644 --- a/tensorflow/core/kernels/data/optional_ops_util.cc +++ b/tensorflow/core/kernels/data/optional_ops_util.cc @@ -32,7 +32,7 @@ Status OptionalZerosLike(OpKernelContext* ctx, const OptionalVariant& x, const Tensor& input, Tensor* out)> zeros_like_func) { if (!x.has_value()) { - return OkStatus(); + return absl::OkStatus(); } std::vector zero_tensors; for (const Tensor& tensor : x.get_values()) { @@ -41,7 +41,7 @@ Status OptionalZerosLike(OpKernelContext* ctx, const OptionalVariant& x, zero_tensors.push_back(std::move(zero_t)); } *y = OptionalVariant(zero_tensors); - return OkStatus(); + return absl::OkStatus(); } Status OptionalBinaryAdd( @@ -56,7 +56,7 @@ Status OptionalBinaryAdd( "Cannot add optionals because one has a value and the other doesn't."); } if (!a.has_value()) { - return OkStatus(); + return absl::OkStatus(); } if (a.get_values().size() != b.get_values().size()) { return errors::InvalidArgument( @@ -73,7 +73,7 @@ Status OptionalBinaryAdd( out_tensors.push_back(std::move(out_tensor)); } *out = OptionalVariant(out_tensors); - return OkStatus(); + return absl::OkStatus(); } } // namespace data diff --git a/tensorflow/core/kernels/data/options_dataset_op.cc b/tensorflow/core/kernels/data/options_dataset_op.cc index 350fa76d16af65..311957f5e28178 100644 --- a/tensorflow/core/kernels/data/options_dataset_op.cc +++ b/tensorflow/core/kernels/data/options_dataset_op.cc @@ -82,7 +82,7 @@ class OptionsDatasetOp::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -100,7 +100,7 @@ class OptionsDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(b->AddDataset( this, {input_graph_node}, {std::make_pair(kSerializedOptions, serialized_options_attr)}, output)); - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc index 9cfde812ed5a3f..1c8ef0caef8b04 100644 --- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc @@ -123,7 +123,7 @@ class PaddedBatchDatasetOp::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -178,7 +178,7 @@ class PaddedBatchDatasetOp::Dataset : public DatasetBase { {kToutputTypes, output_types}, {kNumPaddedShapes, N}}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -203,7 +203,7 @@ class PaddedBatchDatasetOp::Dataset : public DatasetBase { mutex_lock l(mu_); if (!input_impl_) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } else { *end_of_sequence = false; batch_elements.reserve(dataset()->batch_size_); @@ -224,18 +224,18 @@ class PaddedBatchDatasetOp::Dataset : public DatasetBase { if (batch_elements.empty()) { DCHECK(*end_of_sequence); - return OkStatus(); + return absl::OkStatus(); } if (dataset()->drop_remainder_ && batch_elements.size() < dataset()->batch_size_) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR(CopyBatch(ctx, batch_elements, out_tensors)); *end_of_sequence = false; - return OkStatus(); + return absl::OkStatus(); } protected: @@ -252,7 +252,7 @@ class PaddedBatchDatasetOp::Dataset : public DatasetBase { if (input_impl_) { TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -268,7 +268,7 @@ class PaddedBatchDatasetOp::Dataset : public DatasetBase { dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_)); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); } - return OkStatus(); + return absl::OkStatus(); } TraceMeMetadata GetTraceMeMetadata() const override { @@ -365,7 +365,7 @@ class PaddedBatchDatasetOp::Dataset : public DatasetBase { batch_elements[index][component_index], &batch_component, index)); } - return OkStatus(); + return absl::OkStatus(); }; if (dataset()->parallel_copy_ && (batch_component.AllocatedBytes() / @@ -403,7 +403,7 @@ class PaddedBatchDatasetOp::Dataset : public DatasetBase { } } } - return OkStatus(); + return absl::OkStatus(); } mutex mu_; diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op_test.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op_test.cc index 9f0b8e2cb5f3a2..6edacb8b7e3444 100644 --- a/tensorflow/core/kernels/data/padded_batch_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/padded_batch_dataset_op_test.cc @@ -76,7 +76,7 @@ class PaddedBatchDatasetParams : public DatasetParams { strings::StrCat(PaddedBatchDatasetOp::kPaddingValues, "_", j)); } input_names->push_back(PaddedBatchDatasetOp::kDropRemainder); - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { @@ -85,7 +85,7 @@ class PaddedBatchDatasetParams : public DatasetParams { {"output_shapes", output_shapes_}, {"N", num_padded_shapes_}, {"metadata", ""}}; - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { diff --git a/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc b/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc index f8963ba3d5c428..d60f83da21ddd5 100644 --- a/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc @@ -139,7 +139,7 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -181,7 +181,7 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase { this, {input_graph_node, batch_size, num_parallel_calls, drop_remainder}, attrs, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -229,7 +229,7 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase { if (ctx->warm_start() && !ctx->is_restoring()) { EnsureThreadsStarted(ctx); } - return OkStatus(); + return absl::OkStatus(); } Status GetNextInternal(IteratorContext* ctx, @@ -266,7 +266,7 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase { ProcessBatch(dataset()->batch_size_, result->num_elements, dataset()->drop_remainder_, result->status, ctx, out_tensors, end_of_sequence, &result->output)); - return OkStatus(); + return absl::OkStatus(); } protected: @@ -296,7 +296,7 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase { for (size_t i = 0; i < batch_results_.size(); ++i) { TF_RETURN_IF_ERROR(WriteBatchResult(writer, i)); } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -314,7 +314,7 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase { if (ctx->warm_start()) { EnsureThreadsStarted(ctx); } - return OkStatus(); + return absl::OkStatus(); } TraceMeMetadata GetTraceMeMetadata() const override { @@ -344,7 +344,7 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase { explicit BatchResult(IteratorContext* ctx) : end_of_input(false), num_elements(0), - status(OkStatus()), + status(absl::OkStatus()), call_finished(false), output_allocated(false), uid(tensorflow::EnvTime::NowNanos()), @@ -422,17 +422,17 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase { Status status; { mutex_lock l(result->mu); - auto allocation_callback = - [this, ctx, result]() - TF_EXCLUSIVE_LOCKS_REQUIRED(&BatchResult::mu) { - result->output_allocated = true; - RecordBufferEnqueue(ctx.get(), result->output); - return OkStatus(); - }; status = CopyBatch(CopyBatchParams(ctx.get()), *batch_elements, - dataset()->parallel_copy_, - std::move(allocation_callback), &result->output); + dataset()->parallel_copy_, &result->output); result->status.Update(status); + + if (result->status.ok()) { + result->output_allocated = true; + RecordBufferEnqueue(ctx.get(), result->output); + } else { + result->output.clear(); + result->output_allocated = false; + } } CallCompleted(ctx, result); return status; @@ -566,7 +566,7 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase { if (result->output_allocated) { RecordBufferEnqueue(ctx, result->output); } - return OkStatus(); + return absl::OkStatus(); } Status WriteBatchResult(IteratorStateWriter* writer, size_t index) @@ -597,7 +597,7 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR( WriteStatus(prefix(), strings::StrCat(batch_prefix, "_", kStatus), result->status, writer)); - return OkStatus(); + return absl::OkStatus(); } // Used for coordination between the main thread and the runner thread. diff --git a/tensorflow/core/kernels/data/parallel_batch_dataset_op_test.cc b/tensorflow/core/kernels/data/parallel_batch_dataset_op_test.cc index 6f07d87289441d..f4fc0c519e79e6 100644 --- a/tensorflow/core/kernels/data/parallel_batch_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/parallel_batch_dataset_op_test.cc @@ -60,7 +60,7 @@ class ParallelBatchDatasetParams : public DatasetParams { ParallelBatchDatasetOp::kBatchSize, ParallelBatchDatasetOp::kNumParallelCalls, ParallelBatchDatasetOp::kDropRemainder}; - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { @@ -71,7 +71,7 @@ class ParallelBatchDatasetParams : public DatasetParams { {"deterministic", deterministic_}, {"metadata", ""}, }; - return OkStatus(); + return absl::OkStatus(); }; string dataset_type() const override { diff --git a/tensorflow/core/kernels/data/parallel_filter_dataset_op.cc b/tensorflow/core/kernels/data/parallel_filter_dataset_op.cc index 15fbfd59bc4c5f..2a7786d91624e6 100644 --- a/tensorflow/core/kernels/data/parallel_filter_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_filter_dataset_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/status/status.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h" +#include "tensorflow/core/data/dataset_utils.h" #include "tensorflow/core/data/name_utils.h" #include "tensorflow/core/framework/metrics.h" #include "tensorflow/core/framework/model.h" @@ -94,7 +95,7 @@ class ParallelFilterDatasetOp::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -128,7 +129,7 @@ class ParallelFilterDatasetOp::Dataset : public DatasetBase { {kPredicate, predicate_attr}, {kTarguments, other_arguments_types_attr}}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -234,7 +235,7 @@ class ParallelFilterDatasetOp::Dataset : public DatasetBase { writer->WriteScalar(element_prefix, kEndOfInput, "")); } } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -263,7 +264,7 @@ class ParallelFilterDatasetOp::Dataset : public DatasetBase { RecordBufferEnqueue(ctx, result.return_values); result.notification.Notify(); } - return OkStatus(); + return absl::OkStatus(); } TraceMeMetadata GetTraceMeMetadata() const override { @@ -403,7 +404,7 @@ class ParallelFilterDatasetOp::Dataset : public DatasetBase { if (!result->end_of_input && result->status.ok()) { *out_tensors = std::move(result->return_values); *end_of_sequence = false; - return OkStatus(); + return absl::OkStatus(); } if (errors::IsOutOfRange(result->status)) { // `predicate` may deliberately raise `errors::OutOfRange` to indicate @@ -530,7 +531,7 @@ class ParallelFilterDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(writer->WriteTensor( prefix, absl::StrCat(kComponent, "[", j, "]"), values[j])); } - return OkStatus(); + return absl::OkStatus(); } Status ReadComponentsLocked(IteratorContext* ctx, @@ -552,7 +553,7 @@ class ParallelFilterDatasetOp::Dataset : public DatasetBase { ctx->flr(), prefix, absl::StrCat(kComponent, "[", j, "]"), &values->back())); } - return OkStatus(); + return absl::OkStatus(); } Status WriteStatusLocked(IteratorStateWriter* writer, @@ -564,7 +565,7 @@ class ParallelFilterDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(writer->WriteScalar(key, kErrorMessage, std::string(status.message()))); } - return OkStatus(); + return absl::OkStatus(); } Status ReadStatusLocked(IteratorStateReader* reader, const std::string& key, @@ -579,9 +580,9 @@ class ParallelFilterDatasetOp::Dataset : public DatasetBase { reader->ReadScalar(key, kErrorMessage, &error_message)); *status = Status(code, error_message); } else { - *status = OkStatus(); + *status = absl::OkStatus(); } - return OkStatus(); + return absl::OkStatus(); } // Used for coordination between the main thread and the runner thread. diff --git a/tensorflow/core/kernels/data/parallel_filter_dataset_op_test.cc b/tensorflow/core/kernels/data/parallel_filter_dataset_op_test.cc index 0ebadffa1c9f02..9537ea1c1aef2c 100644 --- a/tensorflow/core/kernels/data/parallel_filter_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/parallel_filter_dataset_op_test.cc @@ -66,7 +66,7 @@ class ParallelFilterDatasetParams : public DatasetParams { absl::StrCat(ParallelFilterDatasetOp::kOtherArguments, "_", i)); } input_names->emplace_back(ParallelFilterDatasetOp::kNumParallelCalls); - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { @@ -74,7 +74,7 @@ class ParallelFilterDatasetParams : public DatasetParams { {"predicate", pred_func_}, {"Targuments", type_arguments_}, {"output_shapes", output_shapes_}, {"output_types", output_dtypes_}, {"deterministic", deterministic_}, {"metadata", ""}}; - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index 43699475d26c3b..e134259ee4abfa 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -275,7 +275,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -351,7 +351,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { } TF_RETURN_IF_ERROR(b->AddDataset(this, inputs, list_inputs, attrs, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -415,7 +415,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { EnsureInitialElementsCreated(ctx); EnsureThreadsStarted(ctx); } - return OkStatus(); + return absl::OkStatus(); } Status GetNextInternal(IteratorContext* ctx, @@ -446,7 +446,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { } if (!result) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } profiler::TraceMe traceme([&] { return profiler::TraceMeEncode("ParallelInterleaveConsume", @@ -530,7 +530,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { // Wake workers back up. current_workers_cond_var_.notify_all(); future_workers_cond_var_.notify_all(); - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -579,7 +579,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { } VLOG(2) << "Parallel interleave iterator restored"; VLOG(4) << "State after restore:\n" << DebugString(); - return OkStatus(); + return absl::OkStatus(); } TraceMeMetadata GetTraceMeMetadata() const override { @@ -1286,7 +1286,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { ErrorMessageKey(idx), std::string(status.message()))); } - return OkStatus(); + return absl::OkStatus(); } Status ReadStatusLocked(IteratorStateReader* reader, @@ -1303,9 +1303,9 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { iterator_name, ErrorMessageKey(idx), &error_message)); *status = Status(code, error_message); } else { - *status = OkStatus(); + *status = absl::OkStatus(); } - return OkStatus(); + return absl::OkStatus(); } string CodeKey(size_t idx) { @@ -1362,7 +1362,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { key_prefix, absl::StrCat(kResultsSuffix, "[", i, "]", kIsReadySuffix), "")); } - return OkStatus(); + return absl::OkStatus(); } Status WriteCurrentElements(SerializationContext* ctx, @@ -1381,7 +1381,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { WriteElement(ctx, current_elements_[idx], key_prefix, writer)); } } - return OkStatus(); + return absl::OkStatus(); } Status WriteFutureElements(SerializationContext* ctx, @@ -1400,7 +1400,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { WriteElement(ctx, future_elements_[idx], key_prefix, writer)); } } - return OkStatus(); + return absl::OkStatus(); } Status ReadElement(IteratorContext* ctx, IteratorStateReader* reader, @@ -1410,7 +1410,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(reader->ReadScalar(key_prefix, kElementUninitialized, &element_uninitialized)); if (static_cast(element_uninitialized)) { - return OkStatus(); + return absl::OkStatus(); } std::unique_ptr iterator; auto element = std::make_shared(); @@ -1447,7 +1447,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { if (static_cast(!restore_iterator)) { element->iterator.reset(); *out = std::move(element); - return OkStatus(); + return absl::OkStatus(); } int64_t inputs_size; TF_RETURN_IF_ERROR(reader->ReadScalar( @@ -1475,7 +1475,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { mutex_lock l(*mu_); element->iterator = std::move(iterator); *out = std::move(element); - return OkStatus(); + return absl::OkStatus(); } Status ReadCurrentElements(IteratorContext* ctx, @@ -1501,7 +1501,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { } } if (size == 0) { - return OkStatus(); + return absl::OkStatus(); } std::vector> elements; TF_RETURN_IF_ERROR( @@ -1513,7 +1513,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { for (int idx = 0; idx < size; ++idx) { current_elements_[idx] = std::move(elements[idx]); } - return OkStatus(); + return absl::OkStatus(); } Status ReadFutureElements(IteratorContext* ctx, @@ -1526,7 +1526,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { future_elements_.resize(size); } if (size == 0) { - return OkStatus(); + return absl::OkStatus(); } std::vector> elements; TF_RETURN_IF_ERROR( @@ -1538,14 +1538,14 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { for (int idx = 0; idx < size; ++idx) { future_elements_[idx] = std::move(elements[idx]); } - return OkStatus(); + return absl::OkStatus(); } Status ReadElementsParallel( IteratorContext* ctx, IteratorStateReader* reader, int64_t size, const string& name, std::vector>* elements) { elements->resize(size); - Status s = OkStatus(); + Status s = absl::OkStatus(); BlockingCounter counter(size); for (int idx = 0; idx < size; ++idx) { thread_pool_->Schedule([this, ctx, reader, idx, name, &s, &counter, diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op_test.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op_test.cc index e8117a8a07210f..7649e7aa996996 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op_test.cc @@ -86,7 +86,7 @@ class ParallelInterleaveDatasetParams : public DatasetParams { input_names->emplace_back( ParallelInterleaveDatasetOp::kPrefetchInputElements); input_names->emplace_back(ParallelInterleaveDatasetOp::kNumParallelCalls); - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { @@ -96,7 +96,7 @@ class ParallelInterleaveDatasetParams : public DatasetParams { {"output_shapes", output_shapes_}, {"output_types", output_dtypes_}, {"metadata", ""}}; - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index dd859aaf333c12..390860b4ae7550 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -153,7 +153,7 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -225,7 +225,7 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { std::make_pair(2, num_parallel_calls)}, // Single tensor inputs. {std::make_pair(1, other_arguments)}, // Tensor list inputs. attrs, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -274,7 +274,7 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { if (ctx->warm_start() && !ctx->is_restoring()) { EnsureThreadsStarted(ctx); } - return OkStatus(); + return absl::OkStatus(); } Status GetNextInternal(IteratorContext* ctx, @@ -318,7 +318,7 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( dataset()->captured_func_->CheckExternalState())); if (ctx->symbolic_checkpoint()) { - return OkStatus(); + return absl::OkStatus(); } mutex_lock l(*mu_); // Wait for all in-flight calls to complete. @@ -350,7 +350,7 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { writer->WriteScalar(element_prefix, kEndOfInput, static_cast(result.end_of_input))); } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -359,7 +359,7 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); DCHECK(invocation_results_.empty()); if (ctx->symbolic_checkpoint()) { - return OkStatus(); + return absl::OkStatus(); } int64_t invocation_results_size; TF_RETURN_IF_ERROR(reader->ReadScalar( @@ -397,7 +397,7 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { RecordBufferEnqueue(ctx, result.return_values); result.notification.Notify(); } - return OkStatus(); + return absl::OkStatus(); } TraceMeMetadata GetTraceMeMetadata() const override { @@ -491,7 +491,9 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { } auto done = [this, ctx, result](Status status) { - result->status.Update(status); + if (!status.ok()) { + result->status = AddErrorContext(status); + } RecordBufferEnqueue(ctx.get(), result->return_values); CallCompleted(ctx, result); }; @@ -539,7 +541,7 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { *out_tensors = std::move(result->return_values); RecordBufferDequeue(ctx, *out_tensors); *end_of_sequence = false; - return OkStatus(); + return absl::OkStatus(); } if (errors::IsOutOfRange(result->status)) { if (preserve_cardinality_) { @@ -553,7 +555,7 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { // `f` may deliberately raise `errors::OutOfRange` to indicate // that we should terminate the iteration early. *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } } *end_of_sequence = result->end_of_input; @@ -672,7 +674,7 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { absl::StrCat("_", kErrorMessage), std::string(status.message()))); } - return OkStatus(); + return absl::OkStatus(); } Status ReadStatusLocked(IteratorStateReader* reader, @@ -689,9 +691,9 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { prefix, absl::StrCat("_", kErrorMessage), &error_message)); *status = Status(code, error_message); } else { - *status = OkStatus(); + *status = absl::OkStatus(); } - return OkStatus(); + return absl::OkStatus(); } // Used for coordination between the main thread and the runner thread. diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc index 6072b37cbd70b6..cfd139cc6bbfe7 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc @@ -65,7 +65,7 @@ class ParallelMapDatasetParams : public DatasetParams { absl::StrCat(ParallelMapDatasetOp::kOtherArguments, "_", i)); } input_names->emplace_back(ParallelMapDatasetOp::kNumParallelCalls); - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { @@ -77,7 +77,7 @@ class ParallelMapDatasetParams : public DatasetParams { {"deterministic", deterministic_}, {"preserve_cardinality", preserve_cardinality_}, {"metadata", ""}}; - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc index cc23163c17ff94..2bcda3283a76bb 100644 --- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc +++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc @@ -104,7 +104,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -137,7 +137,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { std::make_pair(kLegacyAutotune, legacy_autotune_attr), std::make_pair(kBufferSizeMin, buffer_size_min_attr)}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -188,7 +188,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(EnsureThreadsStarted(ctx)); } ctx->MergeCheckpoint(iter_ctx.checkpoint()); - return OkStatus(); + return absl::OkStatus(); } Status GetNextInternal(IteratorContext* ctx, @@ -217,7 +217,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { if (prefetch_thread_finished_) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } DCHECK_EQ(buffer_limit(), 0); @@ -259,7 +259,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { Status SaveInternal(SerializationContext* ctx, IteratorStateWriter* writer) override { if (ctx->symbolic_checkpoint()) { - return OkStatus(); + return absl::OkStatus(); } // Acquire both locks to ensure that the prefetch thread and // all GetNext threads are blocked. @@ -282,7 +282,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { } } } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -301,7 +301,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(EnsureThreadsStarted(ctx)); } cond_var_->notify_all(); - return OkStatus(); + return absl::OkStatus(); } data::TraceMeMetadata GetTraceMeMetadata() const override { @@ -393,7 +393,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { } RecordBufferEnqueue(ctx, buffer_element.value); } - return OkStatus(); + return absl::OkStatus(); } int64_t buffer_limit() const TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { @@ -492,7 +492,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { prefetch_thread_ = ctx->StartThread( "tf_data_prefetch", [this, new_ctx]() { PrefetchThread(new_ctx); }); } - return OkStatus(); + return absl::OkStatus(); } // Prefetches elements of the input, storing results in an internal buffer. @@ -576,7 +576,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { absl::StrCat(prefix(), "::", index), ErrorMessageKey(), std::string(status.message()))); } - return OkStatus(); + return absl::OkStatus(); } Status ReadStatus(IteratorStateReader* reader, size_t index, Status* status) @@ -593,9 +593,9 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { ErrorMessageKey(), &error_message)); *status = Status(code, error_message); } else { - *status = OkStatus(); + *status = absl::OkStatus(); } - return OkStatus(); + return absl::OkStatus(); } string CodeKey() { return absl::StrCat(kStatus, kCodeSuffix); } diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op_test.cc b/tensorflow/core/kernels/data/prefetch_dataset_op_test.cc index 36a35f97bae619..7fe3674db2aca4 100644 --- a/tensorflow/core/kernels/data/prefetch_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/prefetch_dataset_op_test.cc @@ -50,7 +50,7 @@ class PrefetchDatasetParams : public DatasetParams { input_names->clear(); input_names->emplace_back(PrefetchDatasetOp::kInputDataset); input_names->emplace_back(PrefetchDatasetOp::kBufferSize); - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { @@ -61,7 +61,7 @@ class PrefetchDatasetParams : public DatasetParams { attr_vector->emplace_back("legacy_autotune", legacy_autotune_); attr_vector->emplace_back("buffer_size_min", buffer_size_min_); attr_vector->emplace_back("metadata", ""); - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { diff --git a/tensorflow/core/kernels/data/random_seed_ops.cc b/tensorflow/core/kernels/data/random_seed_ops.cc index 5cf316fded10d5..61566cab3fabd3 100644 --- a/tensorflow/core/kernels/data/random_seed_ops.cc +++ b/tensorflow/core/kernels/data/random_seed_ops.cc @@ -91,7 +91,7 @@ Status AnonymousSeedGeneratorHandleOp::CreateResource( *manager = new SeedGeneratorManager(new FixedSeedGenerator(*seeds_)); } seeds_ = nullptr; - return OkStatus(); + return absl::OkStatus(); } void DeleteSeedGeneratorOp::Compute(OpKernelContext* ctx) { diff --git a/tensorflow/core/kernels/data/range_dataset_op.cc b/tensorflow/core/kernels/data/range_dataset_op.cc index b6aa84aa7089f5..7e0baae0632d4e 100644 --- a/tensorflow/core/kernels/data/range_dataset_op.cc +++ b/tensorflow/core/kernels/data/range_dataset_op.cc @@ -19,12 +19,14 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/status/status.h" #include "tensorflow/core/data/name_utils.h" #include "tensorflow/core/data/split_utils.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/errors.h" +#include "tsl/platform/mutex.h" #include "tsl/platform/types.h" namespace tensorflow { @@ -61,7 +63,7 @@ Status ConvertOutputTypes(const tensorflow::DataTypeVector& output_dtypes, return errors::InvalidArgument("Unsupported data type: ", DataTypeString(output_dtypes[0])); } - return OkStatus(); + return absl::OkStatus(); } int64_t sgn(int64_t val) { return (0 < val) - (val < 0); } @@ -142,23 +144,23 @@ class RangeDatasetOp::RangeSplitProvider : public SplitProvider { Status GetNext(Tensor* split, bool* end_of_splits) override { int64_t next = counter_.GetNext(end_of_splits); if (*end_of_splits) { - return OkStatus(); + return absl::OkStatus(); } *split = Tensor(DT_INT64, TensorShape{}); split->scalar()() = next; - return OkStatus(); + return absl::OkStatus(); } Status Reset() override { counter_.Reset(); - return OkStatus(); + return absl::OkStatus(); } Status Save(std::function key_name_fn, IteratorStateWriter* writer) override { TF_RETURN_IF_ERROR( writer->WriteScalar(key_name_fn(kNext), counter_.Peek())); - return OkStatus(); + return absl::OkStatus(); } Status Restore(std::function key_name_fn, @@ -166,7 +168,7 @@ class RangeDatasetOp::RangeSplitProvider : public SplitProvider { int64_t next; TF_RETURN_IF_ERROR(reader->ReadScalar(key_name_fn(kNext), &next)); counter_.SetNext(next); - return OkStatus(); + return absl::OkStatus(); } int64_t Cardinality() const override { return counter_.Cardinality(); } @@ -186,6 +188,10 @@ class RangeDatasetOp::Dataset : public DatasetBase { output_dtypes_(output_dtypes), replicate_on_split_(replicate_on_split) {} + absl::Status RandomIndexingCompatible() const override { + return absl::OkStatus(); + } + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::make_unique(Iterator::Params{ @@ -216,15 +222,15 @@ class RangeDatasetOp::Dataset : public DatasetBase { split_providers) const override { split_providers->push_back( std::make_unique(start_, stop_, step_)); - return OkStatus(); + return absl::OkStatus(); } Status InputDatasets(std::vector* inputs) const override { inputs->clear(); - return OkStatus(); + return absl::OkStatus(); } - Status CheckExternalState() const override { return OkStatus(); } + Status CheckExternalState() const override { return absl::OkStatus(); } Status Get(OpKernelContext* ctx, int64 index, std::vector* out_tensors) const override { @@ -250,7 +256,7 @@ class RangeDatasetOp::Dataset : public DatasetBase { this, {start, stop, step}, // Inputs {std::make_pair(kReplicateOnSplit, replicate_on_split)}, // Attrs output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -269,30 +275,47 @@ class RangeDatasetOp::Dataset : public DatasetBase { TF_ASSIGN_OR_RETURN(split_provider_, GetSingleSplitProvider(ctx, dataset())); } - return OkStatus(); + return absl::OkStatus(); } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { + if (ctx->index_mapper() != nullptr) { + return Get(ctx, out_tensors, end_of_sequence); + } int64_t value; if (split_provider_ != nullptr) { Tensor split; TF_RETURN_IF_ERROR(split_provider_->GetNext(&split, end_of_sequence)); if (*end_of_sequence) { - return OkStatus(); + return absl::OkStatus(); } value = split.scalar()(); } else { value = counter_->GetNext(end_of_sequence); if (*end_of_sequence) { - return OkStatus(); + return absl::OkStatus(); } } out_tensors->reserve(1); return ConvertOutputTypes(output_dtypes(), out_tensors, value); } + absl::Status Get(IteratorContext* ctx, std::vector* out_tensors, + bool* end_of_sequence) { + tsl::mutex_lock l(mu_); + if (element_count_ >= + (dataset()->stop_ - dataset()->start_) / dataset()->step_) { + *end_of_sequence = true; + return absl::OkStatus(); + } + int64_t output_index = ctx->index_mapper()(element_count_++); + int64_t value = dataset()->start_ + output_index * dataset()->step_; + *end_of_sequence = false; + return ConvertOutputTypes(output_dtypes(), out_tensors, value); + } + protected: std::shared_ptr CreateNode( IteratorContext* ctx, model::Node::Args args) const override { @@ -313,7 +336,7 @@ class RangeDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR( writer->WriteScalar(prefix(), kNext, counter_->Peek())); } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -329,7 +352,7 @@ class RangeDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kNext, &next)); counter_->SetNext(next); } - return OkStatus(); + return absl::OkStatus(); } std::string SplitProviderKeyNameFn(const std::string& key) { @@ -339,6 +362,11 @@ class RangeDatasetOp::Dataset : public DatasetBase { private: std::unique_ptr counter_; std::shared_ptr split_provider_; + + mutable tsl::mutex mu_; + // Count of elements produced by this iterator when it runs in the random + // access mode. + int64_t element_count_ TF_GUARDED_BY(mu_) = 0; }; const int64_t start_; diff --git a/tensorflow/core/kernels/data/reduce_dataset_op.cc b/tensorflow/core/kernels/data/reduce_dataset_op.cc index 5b29a2958d0d8b..226ed3d2e66587 100644 --- a/tensorflow/core/kernels/data/reduce_dataset_op.cc +++ b/tensorflow/core/kernels/data/reduce_dataset_op.cc @@ -125,7 +125,7 @@ Status ReduceDatasetOp::DoCompute(OpKernelContext* ctx) { for (size_t i = 0; i < state.size(); ++i) { ctx->set_output(i, state[i]); } - return OkStatus(); + return absl::OkStatus(); } namespace { diff --git a/tensorflow/core/kernels/data/reduce_dataset_op_test.cc b/tensorflow/core/kernels/data/reduce_dataset_op_test.cc index 9bb932a87faae5..65e9e0a4ba5e6c 100644 --- a/tensorflow/core/kernels/data/reduce_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/reduce_dataset_op_test.cc @@ -63,7 +63,7 @@ class ReduceDatasetParams : public DatasetParams { for (int i = 0; i < other_arguments_.size(); ++i) { input_names->emplace_back(strings::StrCat("other_arguments_", i)); } - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { @@ -75,7 +75,7 @@ class ReduceDatasetParams : public DatasetParams { {"output_shapes", output_shapes_}, {"use_inter_op_parallelism", use_inter_op_parallelism_}, {"metadata", ""}}; - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { return "Reduce"; } diff --git a/tensorflow/core/kernels/data/repeat_dataset_op.cc b/tensorflow/core/kernels/data/repeat_dataset_op.cc index 904d8ac23baad4..156c47caf957b1 100644 --- a/tensorflow/core/kernels/data/repeat_dataset_op.cc +++ b/tensorflow/core/kernels/data/repeat_dataset_op.cc @@ -188,7 +188,7 @@ class RepeatDatasetOp::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -210,7 +210,7 @@ class RepeatDatasetOp::Dataset : public DatasetBase { Node* count = nullptr; TF_RETURN_IF_ERROR(b->AddScalar(count_, &count)); TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node, count}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -225,7 +225,7 @@ class RepeatDatasetOp::Dataset : public DatasetBase { std::vector* out_tensors, bool* end_of_sequence) override { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } protected: @@ -237,11 +237,11 @@ class RepeatDatasetOp::Dataset : public DatasetBase { Status SaveInternal(SerializationContext* ctx, IteratorStateWriter* writer) override { - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { - return OkStatus(); + return absl::OkStatus(); } }; @@ -264,13 +264,13 @@ class RepeatDatasetOp::Dataset : public DatasetBase { mutex_lock l(mu_); // TODO(mrry): Make locking less conservative. if (!input_impl_) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } while (i_ < dataset()->count_) { TF_RETURN_IF_ERROR( input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); if (!*end_of_sequence) { - return OkStatus(); + return absl::OkStatus(); } ctx->PurgeCheckpoint(nested_prefix(prefix(), i_)); ++i_; @@ -283,7 +283,7 @@ class RepeatDatasetOp::Dataset : public DatasetBase { } *end_of_sequence = true; input_impl_.reset(); - return OkStatus(); + return absl::OkStatus(); } protected: @@ -302,7 +302,7 @@ class RepeatDatasetOp::Dataset : public DatasetBase { if (input_impl_) { TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -319,7 +319,7 @@ class RepeatDatasetOp::Dataset : public DatasetBase { } else { input_impl_.reset(); } - return OkStatus(); + return absl::OkStatus(); } private: @@ -364,12 +364,12 @@ class RepeatDatasetOp::Dataset : public DatasetBase { // Otherwise, this iterator could loop infinitely. if (!has_data_service_input_) { input_impl_.reset(); - return OkStatus(); + return absl::OkStatus(); } } first_call_ = false; if (!*end_of_sequence) { - return OkStatus(); + return absl::OkStatus(); } ctx->PurgeCheckpoint(nested_prefix(prefix(), i_)); ++i_; @@ -397,7 +397,7 @@ class RepeatDatasetOp::Dataset : public DatasetBase { if (input_impl_) { TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -416,7 +416,7 @@ class RepeatDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); first_call_ = false; } - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/data/repeat_dataset_op_test.cc b/tensorflow/core/kernels/data/repeat_dataset_op_test.cc index 812563f689d7b7..451d7c107de04c 100644 --- a/tensorflow/core/kernels/data/repeat_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/repeat_dataset_op_test.cc @@ -51,7 +51,7 @@ class RepeatDatasetParams : public DatasetParams { input_names->clear(); input_names->emplace_back(RepeatDatasetOp::kInputDataset); input_names->emplace_back(RepeatDatasetOp::kCount); - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { @@ -59,7 +59,7 @@ class RepeatDatasetParams : public DatasetParams { attr_vector->emplace_back("output_types", output_dtypes_); attr_vector->emplace_back("output_shapes", output_shapes_); attr_vector->emplace_back("metadata", ""); - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { return RepeatDatasetOp::kDatasetType; } diff --git a/tensorflow/core/kernels/data/rewrite_dataset_op_test.cc b/tensorflow/core/kernels/data/rewrite_dataset_op_test.cc index 05b5f049f64dc8..bde5797ffd2d69 100644 --- a/tensorflow/core/kernels/data/rewrite_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/rewrite_dataset_op_test.cc @@ -45,13 +45,13 @@ class RewriteDatasetParams : public DatasetParams { Status GetInputNames(std::vector* input_names) const override { *input_names = {RewriteDatasetOp::kInputDataset, RewriteDatasetOp::kRewriteName}; - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { attr_vector->emplace_back("output_types", output_dtypes_); attr_vector->emplace_back("output_shapes", output_shapes_); - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { diff --git a/tensorflow/core/kernels/data/shard_dataset_op.cc b/tensorflow/core/kernels/data/shard_dataset_op.cc index 8542b9d4c33ece..dd102c3a4d8e09 100644 --- a/tensorflow/core/kernels/data/shard_dataset_op.cc +++ b/tensorflow/core/kernels/data/shard_dataset_op.cc @@ -92,7 +92,7 @@ class ShardDatasetOp::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -122,7 +122,7 @@ class ShardDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR( b->AddDataset(this, {input_graph_node, num_shards, index}, {{kRequireNonEmpty, require_non_empty_attr}}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -154,7 +154,7 @@ class ShardDatasetOp::Dataset : public DatasetBase { *end_of_sequence = false; if (!input_impl_) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } int num_to_skip = @@ -168,14 +168,14 @@ class ShardDatasetOp::Dataset : public DatasetBase { next_index_ += num_skipped; if (*end_of_sequence) { input_impl_.reset(); - return OkStatus(); + return absl::OkStatus(); } std::vector result; TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &result, end_of_sequence)); if (*end_of_sequence) { input_impl_.reset(); - return OkStatus(); + return absl::OkStatus(); } next_index_++; @@ -207,7 +207,7 @@ class ShardDatasetOp::Dataset : public DatasetBase { } *out_tensors = std::move(result); - return OkStatus(); + return absl::OkStatus(); } protected: @@ -227,7 +227,7 @@ class ShardDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR( writer->WriteScalar(prefix(), kNextIndex, next_index_)); } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -243,7 +243,7 @@ class ShardDatasetOp::Dataset : public DatasetBase { } else { input_impl_.reset(); } - return OkStatus(); + return absl::OkStatus(); } TraceMeMetadata GetTraceMeMetadata() const override { diff --git a/tensorflow/core/kernels/data/shard_dataset_op_test.cc b/tensorflow/core/kernels/data/shard_dataset_op_test.cc index de9ad00821cf8e..d593bc6be3accb 100644 --- a/tensorflow/core/kernels/data/shard_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/shard_dataset_op_test.cc @@ -46,7 +46,7 @@ class ShardDatasetParams : public DatasetParams { input_names->emplace_back(ShardDatasetOp::kInputDataset); input_names->emplace_back(ShardDatasetOp::kNumShards); input_names->emplace_back(ShardDatasetOp::kIndex); - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { @@ -55,7 +55,7 @@ class ShardDatasetParams : public DatasetParams { attr_vector->emplace_back("output_types", output_dtypes_); attr_vector->emplace_back("output_shapes", output_shapes_); attr_vector->emplace_back("metadata", ""); - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { return ShardDatasetOp::kDatasetType; } diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc index cb2c28dbf1ea58..c1002e7c7c64e5 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc @@ -126,7 +126,7 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -148,7 +148,7 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { shuffled_index = shuffled_indices_[index]; } TF_RETURN_IF_ERROR(input_->Get(ctx, shuffled_index, out_tensors)); - return OkStatus(); + return absl::OkStatus(); } string DebugString() const override { @@ -211,7 +211,7 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { checkpoint_indices_.insert(i); } } - return OkStatus(); + return absl::OkStatus(); } Status GetNextInternal(IteratorContext* ctx, @@ -222,7 +222,7 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { if (num_elements_ == 0) { DCHECK(input_impl_ == nullptr); *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } *end_of_sequence = false; @@ -241,7 +241,7 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { checkpoint_indices_.insert(slices_.front()->start % buffer_->size()); slices_.front()->start++; num_elements_--; - return OkStatus(); + return absl::OkStatus(); } protected: @@ -316,7 +316,7 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { writer->WriteScalar(this->prefix(), kDataProduced, "")); } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -393,7 +393,7 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { } data_produced_ = reader->Contains(this->prefix(), kDataProduced); - return OkStatus(); + return absl::OkStatus(); } TraceMeMetadata GetTraceMeMetadata() const override { @@ -473,13 +473,13 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { // If we encounter the end of sequence without producing data, we // terminate the iteration immediately. (Otherwise, this iterator // would loop infinitely and never produce a value.) - return OkStatus(); + return absl::OkStatus(); } } if (num_log_entries > 0) { LOG(INFO) << "Shuffle buffer filled."; } - return OkStatus(); + return absl::OkStatus(); } bool ShouldFillBuffer() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { @@ -517,7 +517,7 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator( ctx, this, this->prefix(), &input_impl_)); epoch_++; - return OkStatus(); + return absl::OkStatus(); } void AddToShuffleBuffer(IteratorContext* ctx, std::vector&& element) @@ -641,7 +641,7 @@ class ShuffleDatasetOp::Dataset : public ShuffleDatasetBase { {std::make_pair(kReshuffleEachIteration, reshuffle_each_iteration)}, // Attrs output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -696,7 +696,7 @@ class ShuffleDatasetOp::DatasetV2 : public ShuffleDatasetBase { {input_graph_node, buffer_size_node, resource_handle_node}, // Inputs {}, // Attrs output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -760,7 +760,7 @@ class ShuffleDatasetOp::DatasetV3 : public ShuffleDatasetBase { {std::make_pair(kReshuffleEachIteration, reshuffle_each_iteration)}, // Attrs output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -828,7 +828,7 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, *manager = new SeedGeneratorManager(new FixedSeedGenerator(seeds)); } - return OkStatus(); + return absl::OkStatus(); })); handle = MakeResourceHandle(ctx, container, name); } else { @@ -857,7 +857,7 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, [&seeds](SeedGeneratorManager** manager) { *manager = new SeedGeneratorManager( new RandomSeedGenerator(seeds)); - return OkStatus(); + return absl::OkStatus(); })); handle = MakeResourceHandle(ctx, container, name); } else { @@ -892,7 +892,7 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, *manager = new SeedGeneratorManager(new FixedSeedGenerator(seeds)); } - return OkStatus(); + return absl::OkStatus(); })); auto handle = MakeResourceHandle(ctx, container, name); @@ -949,7 +949,7 @@ class ShuffleAndRepeatDatasetOp::Dataset : public ShuffleDatasetBase { {std::make_pair(kReshuffleEachIteration, reshuffle_each_iteration)}, // Attrs output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -1012,7 +1012,7 @@ class ShuffleAndRepeatDatasetOp::DatasetV2 : public ShuffleDatasetBase { {std::make_pair(kReshuffleEachIteration, reshuffle_each_iteration)}, // Attrs output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -1088,7 +1088,7 @@ void ShuffleAndRepeatDatasetOp::MakeDataset(OpKernelContext* ctx, *manager = new SeedGeneratorManager(new FixedSeedGenerator(seeds)); } - return OkStatus(); + return absl::OkStatus(); })); handle = MakeResourceHandle(ctx, container, name); } else { @@ -1118,7 +1118,7 @@ void ShuffleAndRepeatDatasetOp::MakeDataset(OpKernelContext* ctx, *manager = new SeedGeneratorManager(new FixedSeedGenerator(seeds)); } - return OkStatus(); + return absl::OkStatus(); })); auto handle = MakeResourceHandle(ctx, container, name); diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op_test.cc b/tensorflow/core/kernels/data/shuffle_dataset_op_test.cc index 5845894baf592a..99f907de2c4dc7 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/shuffle_dataset_op_test.cc @@ -68,7 +68,7 @@ class ShuffleDatasetParams : public DatasetParams { if (count_ != 1) { input_names->emplace_back(ShuffleAndRepeatDatasetOp::kCount); } - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { @@ -78,7 +78,7 @@ class ShuffleDatasetParams : public DatasetParams { attr_vector->emplace_back("reshuffle_each_iteration", reshuffle_each_iteration_); attr_vector->emplace_back("metadata", ""); - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { diff --git a/tensorflow/core/kernels/data/skip_dataset_op.cc b/tensorflow/core/kernels/data/skip_dataset_op.cc index 1b5c027af0f7b2..2a0c75f4c54b93 100644 --- a/tensorflow/core/kernels/data/skip_dataset_op.cc +++ b/tensorflow/core/kernels/data/skip_dataset_op.cc @@ -77,7 +77,7 @@ class SkipDatasetOp::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -99,7 +99,7 @@ class SkipDatasetOp::Dataset : public DatasetBase { Node* count = nullptr; TF_RETURN_IF_ERROR(b->AddScalar(count_, &count)); TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node, count}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -114,7 +114,7 @@ class SkipDatasetOp::Dataset : public DatasetBase { std::vector* out_tensors, bool* end_of_sequence) override { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } protected: @@ -126,12 +126,12 @@ class SkipDatasetOp::Dataset : public DatasetBase { Status SaveInternal(SerializationContext* ctx, IteratorStateWriter* writer) override { - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { - return OkStatus(); + return absl::OkStatus(); } }; @@ -153,7 +153,7 @@ class SkipDatasetOp::Dataset : public DatasetBase { if (!input_impl_) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } if (i_ < dataset()->count_) { @@ -164,7 +164,7 @@ class SkipDatasetOp::Dataset : public DatasetBase { if (*end_of_sequence) { // We reached the end before the count was reached. input_impl_.reset(); - return OkStatus(); + return absl::OkStatus(); } } @@ -174,7 +174,7 @@ class SkipDatasetOp::Dataset : public DatasetBase { if (*end_of_sequence) { input_impl_.reset(); } - return OkStatus(); + return absl::OkStatus(); } protected: @@ -193,7 +193,7 @@ class SkipDatasetOp::Dataset : public DatasetBase { if (input_impl_) { TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -208,7 +208,7 @@ class SkipDatasetOp::Dataset : public DatasetBase { } else { input_impl_.reset(); } - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/data/skip_dataset_op_test.cc b/tensorflow/core/kernels/data/skip_dataset_op_test.cc index 70af1581d40c51..5a0fdf0060a660 100644 --- a/tensorflow/core/kernels/data/skip_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/skip_dataset_op_test.cc @@ -43,7 +43,7 @@ class SkipDatasetParams : public DatasetParams { input_names->clear(); input_names->emplace_back(SkipDatasetOp::kInputDataset); input_names->emplace_back(SkipDatasetOp::kCount); - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { @@ -51,7 +51,7 @@ class SkipDatasetParams : public DatasetParams { attr_vector->emplace_back("output_types", output_dtypes_); attr_vector->emplace_back("output_shapes", output_shapes_); attr_vector->emplace_back("metadata", ""); - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { return SkipDatasetOp::kDatasetType; } diff --git a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc index 846d7be6d45af6..695263eee45c38 100644 --- a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc @@ -59,10 +59,10 @@ class Dataset : public DatasetBase { } Status InputDatasets(std::vector* inputs) const override { - return OkStatus(); + return absl::OkStatus(); } - Status CheckExternalState() const override { return OkStatus(); } + Status CheckExternalState() const override { return absl::OkStatus(); } protected: Status AsGraphDefInternal(SerializationContext* ctx, @@ -83,7 +83,7 @@ class Dataset : public DatasetBase { TF_RETURN_IF_ERROR( b->AddDataset(this, {indices_node, value_node, dense_shape_node}, {{"Tvalues", val_dtype}}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -107,7 +107,7 @@ class Dataset : public DatasetBase { mutex_lock l(mu_); if (i_ == num_elements_) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } out_tensors->clear(); @@ -158,7 +158,7 @@ class Dataset : public DatasetBase { ++i_; *end_of_sequence = false; - return OkStatus(); + return absl::OkStatus(); } protected: @@ -181,7 +181,7 @@ class Dataset : public DatasetBase { TF_RETURN_IF_ERROR(writer->WriteTensor(Iterator::prefix(), "next_values_", next_values_)); } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -200,7 +200,7 @@ class Dataset : public DatasetBase { TF_RETURN_IF_ERROR(reader->ReadTensor(Iterator::prefix(), "next_values_", &next_values_)); } - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op_test.cc b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op_test.cc index dbc76441aeb26b..680b6c704a7722 100644 --- a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op_test.cc @@ -50,13 +50,13 @@ class SparseTensorSliceDatasetParams : public DatasetParams { input_names->emplace_back("indices"); input_names->emplace_back("values"); input_names->emplace_back("dense_shape"); - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { attr_vector->clear(); attr_vector->emplace_back("Tvalues", tvalues_); - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { return kDatasetType; } diff --git a/tensorflow/core/kernels/data/take_dataset_op.cc b/tensorflow/core/kernels/data/take_dataset_op.cc index 931b7dcdbfefd0..adbf81c06888c1 100644 --- a/tensorflow/core/kernels/data/take_dataset_op.cc +++ b/tensorflow/core/kernels/data/take_dataset_op.cc @@ -77,7 +77,7 @@ int64_t TakeDataset::CardinalityInternal(CardinalityOptions options) const { Status TakeDataset::InputDatasets( std::vector* inputs) const { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status TakeDataset::CheckExternalState() const { @@ -100,7 +100,7 @@ class TakeDataset::EmptyIterator : public DatasetIterator { Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } protected: @@ -112,12 +112,12 @@ class TakeDataset::EmptyIterator : public DatasetIterator { Status SaveInternal(SerializationContext* ctx, IteratorStateWriter* writer) override { - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { - return OkStatus(); + return absl::OkStatus(); } }; @@ -137,20 +137,20 @@ class TakeDataset::FiniteIterator : public DatasetIterator { mutex_lock l(mu_); // TODO(mrry): Make locking less conservative. if (!input_impl_) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } while (dataset()->count_ < 0 || i_ < dataset()->count_) { TF_RETURN_IF_ERROR( input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); if (!*end_of_sequence) { ++i_; - return OkStatus(); + return absl::OkStatus(); } break; } *end_of_sequence = true; input_impl_.reset(); - return OkStatus(); + return absl::OkStatus(); } protected: @@ -169,7 +169,7 @@ class TakeDataset::FiniteIterator : public DatasetIterator { if (input_impl_) { TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -184,7 +184,7 @@ class TakeDataset::FiniteIterator : public DatasetIterator { } else { input_impl_.reset(); } - return OkStatus(); + return absl::OkStatus(); } private: @@ -214,7 +214,7 @@ Status TakeDataset::AsGraphDefInternal(SerializationContext* ctx, Node* count = nullptr; TF_RETURN_IF_ERROR(b->AddScalar(count_, &count)); TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node, count}, output)); - return OkStatus(); + return absl::OkStatus(); } TakeDatasetOp::TakeDatasetOp(OpKernelConstruction* ctx) diff --git a/tensorflow/core/kernels/data/tensor_dataset_op.cc b/tensorflow/core/kernels/data/tensor_dataset_op.cc index 81822eb3db007d..7d0fbc40c454bf 100644 --- a/tensorflow/core/kernels/data/tensor_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_dataset_op.cc @@ -59,7 +59,7 @@ class TensorDatasetOp::Dataset : public DatasetBase { Status MakeSplitProviders(std::vector>* split_providers) const override { split_providers->push_back(std::make_unique(1)); - return OkStatus(); + return absl::OkStatus(); } const DataTypeVector& output_dtypes() const override { return dtypes_; } @@ -77,16 +77,16 @@ class TensorDatasetOp::Dataset : public DatasetBase { } Status InputDatasets(std::vector* inputs) const override { - return OkStatus(); + return absl::OkStatus(); } - Status CheckExternalState() const override { return OkStatus(); } + Status CheckExternalState() const override { return absl::OkStatus(); } Status Get(OpKernelContext* ctx, int64 index, std::vector* out_tensors) const override { TF_RETURN_IF_ERROR(CheckRandomAccessCompatible(index)); *out_tensors = tensors_; - return OkStatus(); + return absl::OkStatus(); } protected: @@ -110,7 +110,7 @@ class TensorDatasetOp::Dataset : public DatasetBase { b->BuildAttrValue(dtypes_, &dtypes); TF_RETURN_IF_ERROR(b->AddDataset(this, {}, {{0, components}}, {{kToutput_types, dtypes}}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -126,7 +126,7 @@ class TensorDatasetOp::Dataset : public DatasetBase { TF_ASSIGN_OR_RETURN(split_provider_, GetSingleSplitProvider(ctx, dataset())); } - return OkStatus(); + return absl::OkStatus(); } Status GetNextInternal(IteratorContext* ctx, @@ -145,10 +145,10 @@ class TensorDatasetOp::Dataset : public DatasetBase { *out_tensors = dataset()->tensors_; produced_ = true; *end_of_sequence = false; - return OkStatus(); + return absl::OkStatus(); } else { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } } @@ -163,7 +163,7 @@ class TensorDatasetOp::Dataset : public DatasetBase { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kProduced, static_cast(produced_))); - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -172,7 +172,7 @@ class TensorDatasetOp::Dataset : public DatasetBase { int64_t produced; TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kProduced, &produced)); produced_ = static_cast(produced); - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/data/tensor_dataset_op_test.cc b/tensorflow/core/kernels/data/tensor_dataset_op_test.cc index bee3ccee541c0a..d3e62d757f9988 100644 --- a/tensorflow/core/kernels/data/tensor_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/tensor_dataset_op_test.cc @@ -42,14 +42,14 @@ class TensorDatasetParams : public DatasetParams { input_names->emplace_back( absl::StrCat(TensorDatasetOp::kComponents, "_", i)); } - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { *attr_vector = {{"Toutput_types", output_dtypes_}, {"output_shapes", output_shapes_}, {"metadata", ""}}; - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { return TensorDatasetOp::kDatasetType; } diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc index 2e9e1726deb6a2..6de213fe921992 100644 --- a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc @@ -71,7 +71,7 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase { split_providers) const override { split_providers->push_back( std::make_unique(tensors_[0].dim_size(0))); - return OkStatus(); + return absl::OkStatus(); } const DataTypeVector& output_dtypes() const override { return dtypes_; } @@ -89,10 +89,10 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase { } Status InputDatasets(std::vector* inputs) const override { - return OkStatus(); + return absl::OkStatus(); } - Status CheckExternalState() const override { return OkStatus(); } + Status CheckExternalState() const override { return absl::OkStatus(); } Status Get(OpKernelContext* ctx, int64 index, std::vector* out_tensors) const override { @@ -102,7 +102,7 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase { for (int i = 0; i < tensors_.size(); ++i) { out_tensors->push_back(MaybeCopySubSlice(tensors_[i], index)); } - return OkStatus(); + return absl::OkStatus(); } protected: @@ -138,7 +138,7 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase { {kIsFiles, is_files}, {kReplicateOnSplit, replicate_on_split}}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -157,7 +157,7 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase { TF_ASSIGN_OR_RETURN(split_provider_, GetSingleSplitProvider(ctx, dataset())); } - return OkStatus(); + return absl::OkStatus(); } Status GetNextInternal(IteratorContext* ctx, @@ -166,7 +166,7 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase { Tensor split; TF_RETURN_IF_ERROR(split_provider_->GetNext(&split, end_of_sequence)); if (*end_of_sequence) { - return OkStatus(); + return absl::OkStatus(); } int64_t index = split.scalar()(); out_tensors->reserve(dataset()->tensors_.size()); @@ -175,7 +175,7 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase { MaybeCopySubSlice(dataset()->tensors_[i], index)); } *end_of_sequence = false; - return OkStatus(); + return absl::OkStatus(); } protected: diff --git a/tensorflow/core/kernels/data/text_line_dataset_op.cc b/tensorflow/core/kernels/data/text_line_dataset_op.cc index e9b7852d5c0a95..317165d3f36ecd 100644 --- a/tensorflow/core/kernels/data/text_line_dataset_op.cc +++ b/tensorflow/core/kernels/data/text_line_dataset_op.cc @@ -72,10 +72,10 @@ class TextLineDatasetOp::Dataset : public DatasetBase { } Status InputDatasets(std::vector* inputs) const override { - return OkStatus(); + return absl::OkStatus(); } - Status CheckExternalState() const override { return OkStatus(); } + Status CheckExternalState() const override { return absl::OkStatus(); } protected: Status AsGraphDefInternal(SerializationContext* ctx, @@ -89,7 +89,7 @@ class TextLineDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(b->AddScalar(options_.input_buffer_size, &buffer_size)); TF_RETURN_IF_ERROR(b->AddDataset( this, {filenames, compression_type, buffer_size}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -119,7 +119,7 @@ class TextLineDatasetOp::Dataset : public DatasetBase { bytes_counter->IncrementBy(line_contents_str.size()); out_tensors->push_back(std::move(line_contents)); *end_of_sequence = false; - return OkStatus(); + return absl::OkStatus(); } else if (!errors::IsOutOfRange(s)) { // Report non-EOF errors to the caller. return s; @@ -133,7 +133,7 @@ class TextLineDatasetOp::Dataset : public DatasetBase { // Iteration ends when there are no more files to process. if (current_file_index_ == dataset()->filenames_.size()) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); @@ -158,7 +158,7 @@ class TextLineDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kCurrentPos, buffered_input_stream_->Tell())); } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -179,7 +179,7 @@ class TextLineDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); TF_RETURN_IF_ERROR(buffered_input_stream_->Seek(current_pos)); } - return OkStatus(); + return absl::OkStatus(); } private: @@ -209,7 +209,7 @@ class TextLineDatasetOp::Dataset : public DatasetBase { buffered_input_stream_ = std::make_unique( input_stream_.get(), dataset()->options_.input_buffer_size, false); } - return OkStatus(); + return absl::OkStatus(); } // Resets all reader streams. diff --git a/tensorflow/core/kernels/data/text_line_dataset_op_test.cc b/tensorflow/core/kernels/data/text_line_dataset_op_test.cc index a65874ebd2756a..02495e859a58b7 100644 --- a/tensorflow/core/kernels/data/text_line_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/text_line_dataset_op_test.cc @@ -51,13 +51,13 @@ class TextLineDatasetParams : public DatasetParams { TextLineDatasetOp::kCompressionType, TextLineDatasetOp::kBufferSize, }; - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { attr_vector->clear(); attr_vector->emplace_back("metadata", ""); - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { @@ -86,7 +86,7 @@ Status CreateTestFiles(const std::vector& filenames, TF_RETURN_IF_ERROR( WriteDataToFile(filenames[i], contents[i].data(), params)); } - return OkStatus(); + return absl::OkStatus(); } // Test case 1: multiple text files with ZLIB compression. diff --git a/tensorflow/core/kernels/data/tf_record_dataset_op.cc b/tensorflow/core/kernels/data/tf_record_dataset_op.cc index 9e6cb5185511c7..8d95feab7bdce7 100644 --- a/tensorflow/core/kernels/data/tf_record_dataset_op.cc +++ b/tensorflow/core/kernels/data/tf_record_dataset_op.cc @@ -97,10 +97,10 @@ class TFRecordDatasetOp::Dataset : public DatasetBase { } Status InputDatasets(std::vector* inputs) const override { - return OkStatus(); + return absl::OkStatus(); } - Status CheckExternalState() const override { return OkStatus(); } + Status CheckExternalState() const override { return absl::OkStatus(); } protected: Status AsGraphDefInternal(SerializationContext* ctx, @@ -116,7 +116,7 @@ class TFRecordDatasetOp::Dataset : public DatasetBase { this, {filenames, compression_type, buffer_size}, output)); Node* byte_offsets = nullptr; TF_RETURN_IF_ERROR(b->AddVector(byte_offsets_, &byte_offsets)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -145,7 +145,7 @@ class TFRecordDatasetOp::Dataset : public DatasetBase { bytes_counter->IncrementBy( out_tensors->back().scalar()().size()); *end_of_sequence = false; - return OkStatus(); + return absl::OkStatus(); } out_tensors->pop_back(); if (!errors::IsOutOfRange(s)) { @@ -166,7 +166,7 @@ class TFRecordDatasetOp::Dataset : public DatasetBase { // Iteration ends when there are no more files to process. if (current_file_index_ == dataset()->filenames_.size()) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); @@ -187,7 +187,7 @@ class TFRecordDatasetOp::Dataset : public DatasetBase { *num_skipped += last_num_skipped; if (s.ok()) { *end_of_sequence = false; - return OkStatus(); + return absl::OkStatus(); } if (!errors::IsOutOfRange(s)) { // In case of other errors e.g., DataLoss, we still move forward @@ -207,7 +207,7 @@ class TFRecordDatasetOp::Dataset : public DatasetBase { // Iteration ends when there are no more files to process. if (current_file_index_ == dataset()->filenames_.size()) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); @@ -230,7 +230,7 @@ class TFRecordDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR( writer->WriteScalar(prefix(), kOffset, reader_->TellOffset())); } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -247,7 +247,7 @@ class TFRecordDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); TF_RETURN_IF_ERROR(reader_->SeekOffset(offset)); } - return OkStatus(); + return absl::OkStatus(); } private: @@ -269,7 +269,7 @@ class TFRecordDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR( reader_->SeekOffset(dataset()->byte_offsets_[current_file_index_])); } - return OkStatus(); + return absl::OkStatus(); } // Resets all reader streams. diff --git a/tensorflow/core/kernels/data/tf_record_dataset_op_test.cc b/tensorflow/core/kernels/data/tf_record_dataset_op_test.cc index c154dda6169e99..7b918b6e784ccc 100644 --- a/tensorflow/core/kernels/data/tf_record_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/tf_record_dataset_op_test.cc @@ -74,13 +74,13 @@ class TFRecordDatasetParams : public DatasetParams { TFRecordDatasetOp::kBufferSize, TFRecordDatasetOp::kByteOffsets, }; - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { attr_vector->clear(); attr_vector->emplace_back("metadata", ""); - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { @@ -111,7 +111,7 @@ Status CreateTestFiles(const std::vector& filenames, contents[i].end()); TF_RETURN_IF_ERROR(WriteDataToTFRecordFile(filenames[i], records, params)); } - return OkStatus(); + return absl::OkStatus(); } // Test case 1: multiple text files with ZLIB compression. diff --git a/tensorflow/core/kernels/data/window_dataset.cc b/tensorflow/core/kernels/data/window_dataset.cc index 4a945f8b6644ac..4ca4c2a8ac1f93 100644 --- a/tensorflow/core/kernels/data/window_dataset.cc +++ b/tensorflow/core/kernels/data/window_dataset.cc @@ -77,10 +77,10 @@ class Window : public DatasetBase { string DebugString() const override { return kWindow; } Status InputDatasets(std::vector* inputs) const override { - return OkStatus(); + return absl::OkStatus(); } - Status CheckExternalState() const override { return OkStatus(); } + Status CheckExternalState() const override { return absl::OkStatus(); } protected: Status AsGraphDefInternal(SerializationContext* ctx, @@ -103,7 +103,7 @@ class Window : public DatasetBase { } TF_RETURN_IF_ERROR( b->AddDataset(this, {}, {std::make_pair(0, input_nodes)}, {}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -121,14 +121,14 @@ class Window : public DatasetBase { *end_of_sequence = false; *out_tensors = dataset()->elements_[i_++]; } - return OkStatus(); + return absl::OkStatus(); } Status SaveInternal(SerializationContext* ctx, IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kCurIndex, i_)); - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -137,7 +137,7 @@ class Window : public DatasetBase { int64_t i; TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kCurIndex, &i)); i_ = size_t(i); - return OkStatus(); + return absl::OkStatus(); } mutex mu_; @@ -191,7 +191,7 @@ Status NewWindow(std::vector> elements, *out_dataset = new Window(std::move(elements), std::move(output_types), std::move(output_shapes)); (*out_dataset)->Initialize(/*metadata=*/{}); - return OkStatus(); + return absl::OkStatus(); } } // namespace data diff --git a/tensorflow/core/kernels/data/window_dataset_op.cc b/tensorflow/core/kernels/data/window_dataset_op.cc index 711f60dd8f8ace..2aa058078e1c39 100644 --- a/tensorflow/core/kernels/data/window_dataset_op.cc +++ b/tensorflow/core/kernels/data/window_dataset_op.cc @@ -107,7 +107,7 @@ class WindowDatasetOp::Dataset : public DatasetBase { Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { @@ -133,7 +133,7 @@ class WindowDatasetOp::Dataset : public DatasetBase { {input_graph_node, window_size_node, window_shift_node, window_stride_node, drop_remainder_node}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -153,7 +153,7 @@ class WindowDatasetOp::Dataset : public DatasetBase { const int64_t window_shift = dataset()->window_shift_; const int64_t window_stride = dataset()->window_stride_; std::vector> window_elements; - Status status = OkStatus(); + Status status = absl::OkStatus(); { const size_t target_size = TargetBufferSize(window_size, window_stride); @@ -162,7 +162,7 @@ class WindowDatasetOp::Dataset : public DatasetBase { (buffer_.empty() || (dataset()->drop_remainder_ && buffer_.size() < target_size))) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } // Add elements to the buffer. @@ -187,7 +187,7 @@ class WindowDatasetOp::Dataset : public DatasetBase { if (buffer_.empty() || (dataset()->drop_remainder_ && buffer_.size() < target_size)) { DCHECK(*end_of_sequence); - return OkStatus(); + return absl::OkStatus(); } int num_elements = 1 + (buffer_.size() - 1) / window_stride; @@ -252,7 +252,7 @@ class WindowDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR( StoreDatasetInVariantTensor(window_dataset, &out_tensors->back())); } - return OkStatus(); + return absl::OkStatus(); } protected: @@ -284,7 +284,7 @@ class WindowDatasetOp::Dataset : public DatasetBase { buffer_[i].result[j])); } } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -314,7 +314,7 @@ class WindowDatasetOp::Dataset : public DatasetBase { &buffer_[i].result[j])); } } - return OkStatus(); + return absl::OkStatus(); } TraceMeMetadata GetTraceMeMetadata() const override { @@ -340,7 +340,7 @@ class WindowDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), ErrorMessageKey(index), std::string(status.message()))); } - return OkStatus(); + return absl::OkStatus(); } Status ReadStatusLocked(IteratorStateReader* reader, size_t index, @@ -356,9 +356,9 @@ class WindowDatasetOp::Dataset : public DatasetBase { &error_message)); *status = Status(code, error_message); } else { - *status = OkStatus(); + *status = absl::OkStatus(); } - return OkStatus(); + return absl::OkStatus(); } string CodeKey(size_t index) { diff --git a/tensorflow/core/kernels/data/window_dataset_op_test.cc b/tensorflow/core/kernels/data/window_dataset_op_test.cc index fda3afbc8a8025..7d8d5b6bc0192b 100644 --- a/tensorflow/core/kernels/data/window_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/window_dataset_op_test.cc @@ -59,7 +59,7 @@ class WindowDatasetParams : public DatasetParams { input_names->emplace_back(WindowDatasetOp::kShift); input_names->emplace_back(WindowDatasetOp::kStride); input_names->emplace_back(WindowDatasetOp::kDropRemainder); - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { @@ -67,7 +67,7 @@ class WindowDatasetParams : public DatasetParams { attr_vector->emplace_back("output_types", output_dtypes_); attr_vector->emplace_back("output_shapes", output_shapes_); attr_vector->emplace_back("metadata", ""); - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { return WindowDatasetOp::kDatasetType; } diff --git a/tensorflow/core/kernels/data/zip_dataset_op.cc b/tensorflow/core/kernels/data/zip_dataset_op.cc index 41582cf0a1e5e9..3ceadfb5db2def 100644 --- a/tensorflow/core/kernels/data/zip_dataset_op.cc +++ b/tensorflow/core/kernels/data/zip_dataset_op.cc @@ -70,7 +70,7 @@ class ZipDatasetOp::Dataset : public DatasetBase { Status MakeSplitProviders(std::vector>* split_providers) const override { TF_ASSIGN_OR_RETURN(*split_providers, GetSplitProviders(this)); - return OkStatus(); + return absl::OkStatus(); } const DataTypeVector& output_dtypes() const override { @@ -104,14 +104,14 @@ class ZipDatasetOp::Dataset : public DatasetBase { for (const auto& input : inputs_) { inputs->push_back(input); } - return OkStatus(); + return absl::OkStatus(); } Status CheckExternalState() const override { for (const auto& input : inputs_) { TF_RETURN_IF_ERROR(input->CheckExternalState()); } - return OkStatus(); + return absl::OkStatus(); } Status Get(OpKernelContext* ctx, int64 index, @@ -124,7 +124,7 @@ class ZipDatasetOp::Dataset : public DatasetBase { out_tensors->insert(out_tensors->end(), input_tensors.begin(), input_tensors.end()); } - return OkStatus(); + return absl::OkStatus(); } protected: @@ -140,7 +140,7 @@ class ZipDatasetOp::Dataset : public DatasetBase { } TF_RETURN_IF_ERROR(b->AddDataset( this, {}, {std::make_pair(0, input_graph_nodes)}, {}, output)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -162,7 +162,7 @@ class ZipDatasetOp::Dataset : public DatasetBase { &input_impls_[i])); ctx->MergeCheckpoint(input_contexts_[i].checkpoint()); } - return OkStatus(); + return absl::OkStatus(); } Status GetNextInternal(IteratorContext* ctx, @@ -171,11 +171,11 @@ class ZipDatasetOp::Dataset : public DatasetBase { mutex_lock l(mu_); if (input_impls_.empty()) { *end_of_sequence = true; - return OkStatus(); + return absl::OkStatus(); } out_tensors->clear(); out_tensors->reserve(dataset()->output_dtypes().size()); - Status status = OkStatus(); + Status status = absl::OkStatus(); *end_of_sequence = false; for (int i = 0; i < input_impls_.size(); ++i) { const auto& input_impl = input_impls_[i]; @@ -232,7 +232,7 @@ class ZipDatasetOp::Dataset : public DatasetBase { for (auto& input_impl : input_impls_) { TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl)); } - return OkStatus(); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, @@ -248,7 +248,7 @@ class ZipDatasetOp::Dataset : public DatasetBase { for (auto& input_impl : input_impls_) TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl)); } - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/data/zip_dataset_op_test.cc b/tensorflow/core/kernels/data/zip_dataset_op_test.cc index 460c22f205747d..fcbd97f725491b 100644 --- a/tensorflow/core/kernels/data/zip_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/zip_dataset_op_test.cc @@ -49,7 +49,7 @@ class ZipDatasetParams : public DatasetParams { input_names->emplace_back( absl::StrCat(ZipDatasetOp::kDatasetType, "_", i)); } - return OkStatus(); + return absl::OkStatus(); } Status GetAttributes(AttributeVector* attr_vector) const override { @@ -58,7 +58,7 @@ class ZipDatasetParams : public DatasetParams { attr_vector->emplace_back("output_shapes", output_shapes_); attr_vector->emplace_back("N", num_input_datasets_); attr_vector->emplace_back("metadata", ""); - return OkStatus(); + return absl::OkStatus(); } string dataset_type() const override { return ZipDatasetOp::kDatasetType; } diff --git a/tensorflow/core/kernels/debug_ops.h b/tensorflow/core/kernels/debug_ops.h index eb1d2db77a7f2e..d7c0c762fa7648 100644 --- a/tensorflow/core/kernels/debug_ops.h +++ b/tensorflow/core/kernels/debug_ops.h @@ -181,7 +181,7 @@ class BaseDebugOp : public OpKernel { // Log an error if the publishing failed. Status PublishTensor(const Tensor& tensor, int64_t step_id = -1) { if (debug_urls_.empty()) { - return OkStatus(); + return absl::OkStatus(); } else { Status status = DebugIO::PublishDebugTensor( *debug_watch_key_, tensor, Env::Default()->NowMicros(), debug_urls_, @@ -778,9 +778,11 @@ class DebugNumericSummaryV2Op : public AsyncOpKernel { se::DeviceMemoryBase output_tensor_ptr( output_tensor->flat().data(), output_tensor->flat().size()); - stream->ThenMemZero(&output_tensor_ptr, 2 * sizeof(Tout)); + OP_REQUIRES_OK(context, + stream->MemZero(&output_tensor_ptr, 2 * sizeof(Tout))); // Copy tensor_id to slot zero - stream->ThenMemcpy(&output_tensor_ptr, &tensor_id, sizeof(Tout)); + OP_REQUIRES_OK(context, stream->Memcpy(&output_tensor_ptr, &tensor_id, + sizeof(Tout))); if (num_elem == 0) { done(); return; @@ -812,9 +814,11 @@ class DebugNumericSummaryV2Op : public AsyncOpKernel { se::DeviceMemoryBase output_tensor_ptr( output_tensor->flat().data(), output_tensor->flat().size()); - stream->ThenMemset32(&output_tensor_ptr, 0, 5 * sizeof(Tout)); + OP_REQUIRES_OK(context, + stream->Memset32(&output_tensor_ptr, 0, 5 * sizeof(Tout))); const Tout static_output[] = {tensor_id, num_elem}; - stream->ThenMemcpy(&output_tensor_ptr, &static_output, 2 * sizeof(Tout)); + OP_REQUIRES_OK(context, stream->Memcpy(&output_tensor_ptr, &static_output, + 2 * sizeof(Tout))); if (num_elem == 0) { done(); return; @@ -846,14 +850,16 @@ class DebugNumericSummaryV2Op : public AsyncOpKernel { se::DeviceMemoryBase output_tensor_ptr( output_tensor->flat().data(), output_tensor->flat().size()); - stream->ThenMemset32(&output_tensor_ptr, 0, 11 * sizeof(Tout)); + OP_REQUIRES_OK( + context, stream->Memset32(&output_tensor_ptr, 0, 11 * sizeof(Tout))); int num_dims = tensor.dims(); const Tout static_output[] = {tensor_id, -1.0, // TODO(144919262): Device ID static_cast(tensor.dtype()), static_cast(num_dims), num_elem}; - stream->ThenMemcpy(&output_tensor_ptr, &static_output, 5 * sizeof(Tout)); + OP_REQUIRES_OK(context, stream->Memcpy(&output_tensor_ptr, &static_output, + 5 * sizeof(Tout))); if (num_elem == 0) { done(); return; @@ -897,7 +903,8 @@ class DebugNumericSummaryV2Op : public AsyncOpKernel { static_output[dim_idx++] = static_cast(tensor.dim_size(i)); } // Write to device stream - stream->ThenMemcpy(&output_tensor_ptr, &static_output, sizeof(Tout) * 10); + OP_REQUIRES_OK(context, stream->Memcpy(&output_tensor_ptr, &static_output, + sizeof(Tout) * 10)); context->device() ->tensorflow_accelerator_device_info() ->event_mgr->ThenExecute(stream, std::move(check_cb)); @@ -913,8 +920,10 @@ class DebugNumericSummaryV2Op : public AsyncOpKernel { se::DeviceMemoryBase output_tensor_ptr( output_tensor->flat().data(), output_tensor->flat().size()); - stream->ThenMemset32(&output_tensor_ptr, 0, - output_tensor->flat().size() * sizeof(Tout)); + OP_REQUIRES_OK( + context, + stream->Memset32(&output_tensor_ptr, 0, + output_tensor->flat().size() * sizeof(Tout))); if (num_elem == 0) { done(); return; diff --git a/tensorflow/core/kernels/decode_compressed_op.cc b/tensorflow/core/kernels/decode_compressed_op.cc index 971faa01ac4fa3..407746a9e20b02 100644 --- a/tensorflow/core/kernels/decode_compressed_op.cc +++ b/tensorflow/core/kernels/decode_compressed_op.cc @@ -41,7 +41,7 @@ class MemoryInputStream : public io::InputStreamInterface { bytes_to_read); } int64_t bytes = bytes_to_read; - Status s = OkStatus(); + Status s = absl::OkStatus(); if (pos_ + bytes_to_read > len_) { bytes = len_ - pos_; s = errors::OutOfRange("reached end of file"); @@ -58,7 +58,7 @@ class MemoryInputStream : public io::InputStreamInterface { Status Reset() override { pos_ = 0; - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/decode_proto_op.cc b/tensorflow/core/kernels/decode_proto_op.cc index bc2622bd1a8372..eb55543c528355 100644 --- a/tensorflow/core/kernels/decode_proto_op.cc +++ b/tensorflow/core/kernels/decode_proto_op.cc @@ -123,7 +123,7 @@ Status InitDefaultValue(DataType dtype, const T value, DefaultValue* result) { "Cannot initialize default value for unsupported type: ", DataTypeString(dtype)); } - return OkStatus(); + return absl::OkStatus(); } template <> @@ -190,7 +190,7 @@ Status InitDefaultValueFromFieldDescriptor(DataType dtype, return InitDefaultValue(dtype, "", result); // default: intentionally omitted in order to enable static checking. } - return OkStatus(); + return absl::OkStatus(); } // A FieldInfo holds a handful of information from the FieldDescriptor @@ -263,14 +263,14 @@ class CountCollector { if (!SkipValue(input, field)) { return errors::DataLoss("ReadValue: Failed skipping field when counting"); } - return OkStatus(); + return absl::OkStatus(); } // Reads (in this case counts) a length-delimited list of values. Status ReadPackedValues(CodedInputStream* input, const FieldInfo& field, size_t buf_size) { if (buf_size == 0) { - return OkStatus(); + return absl::OkStatus(); } const void* tmpbuf; @@ -356,7 +356,7 @@ class CountCollector { if (!field.is_repeated && *count_ptr_ > 1) { *count_ptr_ = 1; } - return OkStatus(); + return absl::OkStatus(); } private: @@ -395,7 +395,7 @@ class CountCollector { } *count_ptr_ += count; - return OkStatus(); + return absl::OkStatus(); } // Counts the number of fixed-size values in a packed field. This can be done @@ -408,7 +408,7 @@ class CountCollector { "Illegal data length for packed fixed-size type: ", len); } *count_ptr_ += len / sizeof(T); - return OkStatus(); + return absl::OkStatus(); } // Skips a single value in the input stream. Dispatches to the appropriately @@ -578,7 +578,7 @@ class DenseCollector { for (int i = next_repeat_index_; i < max_repeat_count_; i++) { reinterpret_cast(datap_)[i] = default_value; } - return OkStatus(); + return absl::OkStatus(); } int32 next_repeat_index_ = 0; @@ -1033,7 +1033,7 @@ class DecodeProtoOp : public OpKernel { *field_info, WireFormatLite::GetTagWireType(tag), input, &collectors[expected_field_info_iter - fields_.begin()])); } - return OkStatus(); + return absl::OkStatus(); } // Collects values for a single field. @@ -1073,7 +1073,7 @@ class DecodeProtoOp : public OpKernel { return errors::DataLoss( "CollectField: Failed skipping malformed field"); } - return OkStatus(); + return absl::OkStatus(); } return collector->ReadValue(input, field); } diff --git a/tensorflow/core/kernels/deserialize_sparse_string_op.cc b/tensorflow/core/kernels/deserialize_sparse_string_op.cc index b02308b827b8f1..c7275f139421a8 100644 --- a/tensorflow/core/kernels/deserialize_sparse_string_op.cc +++ b/tensorflow/core/kernels/deserialize_sparse_string_op.cc @@ -223,7 +223,7 @@ class DeserializeSparseOp : public OpKernel { return errors::InvalidArgument("Could not construct tensor from proto"); } *result = tensor; - return OkStatus(); + return absl::OkStatus(); } Status GetAndValidateSparseTensor( @@ -278,7 +278,7 @@ class DeserializeSparseOp : public OpKernel { index, "].shape but they do not: ", rank, " vs. ", output_shape->dim_size(0)); } - return OkStatus(); + return absl::OkStatus(); } DataType dtype_; diff --git a/tensorflow/core/kernels/deserialize_sparse_variant_op.cc b/tensorflow/core/kernels/deserialize_sparse_variant_op.cc index 9be55421aefa5b..57d246ff5912e3 100644 --- a/tensorflow/core/kernels/deserialize_sparse_variant_op.cc +++ b/tensorflow/core/kernels/deserialize_sparse_variant_op.cc @@ -296,7 +296,7 @@ class DeserializeSparseOp : public OpKernel { (*output_shape)->shape().DebugString()); } *output_num_non_zeros = serialized_values.get()->NumElements(); - return OkStatus(); + return absl::OkStatus(); } Status GetAndValidateSparseTensorIndicesAndValues( @@ -356,7 +356,7 @@ class DeserializeSparseOp : public OpKernel { (*output_values)->dim_size(0)); } - return OkStatus(); + return absl::OkStatus(); } DataType dtype_; diff --git a/tensorflow/core/kernels/edit_distance_op.cc b/tensorflow/core/kernels/edit_distance_op.cc index 9e2b430dbc03d6..b495f2f80524e4 100644 --- a/tensorflow/core/kernels/edit_distance_op.cc +++ b/tensorflow/core/kernels/edit_distance_op.cc @@ -100,7 +100,7 @@ Status ValidateShapes(OpKernelContext* ctx, const Tensor& hypothesis_indices, truth_shape.shape().DebugString(), " and ", hypothesis_shape.shape().DebugString()); - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/core/kernels/encode_proto_op.cc b/tensorflow/core/kernels/encode_proto_op.cc index d585dff1f9e90f..8a79d0817d4126 100644 --- a/tensorflow/core/kernels/encode_proto_op.cc +++ b/tensorflow/core/kernels/encode_proto_op.cc @@ -277,7 +277,7 @@ Status WriteField(const FieldDescriptor& field_desc, const Tensor& input, Writer(value, output); } } - return OkStatus(); + return absl::OkStatus(); } // Writes a possibly repeated string, bytes, or message field. @@ -293,7 +293,7 @@ Status WriteVarLenField(const FieldDescriptor& field_desc, const Tensor& input, // small speedup. Writer(field_desc.number(), value, output); } - return OkStatus(); + return absl::OkStatus(); } static void WriteStringAdapter(int field_number, const tstring& value, @@ -331,7 +331,7 @@ Status WriteGroup(const FieldDescriptor& field_desc, const Tensor& input, WireFormatLite::WriteTag(field_desc.number(), WireFormatLite::WIRETYPE_END_GROUP, output); } - return OkStatus(); + return absl::OkStatus(); } // Writes a (possibly repeated) field into an output stream. It is the caller's diff --git a/tensorflow/core/kernels/example_parsing_ops.cc b/tensorflow/core/kernels/example_parsing_ops.cc index e241665f297110..b425b0a845612b 100644 --- a/tensorflow/core/kernels/example_parsing_ops.cc +++ b/tensorflow/core/kernels/example_parsing_ops.cc @@ -111,7 +111,7 @@ class ParseExampleOp : public OpKernel { for (int i = 0; i < keys_flat.size(); ++i) { keys->push_back(keys_flat(i)); } - return OkStatus(); + return absl::OkStatus(); } // Copies keys from OpInputList of scalar to std::vector. @@ -123,7 +123,7 @@ class ParseExampleOp : public OpKernel { for (const auto& key : key_list) { keys->push_back(key.scalar()()); } - return OkStatus(); + return absl::OkStatus(); } // Validates the shapes of input tensors. @@ -205,7 +205,7 @@ class ParseExampleOp : public OpKernel { "] == ", DataTypeString(attrs_.dense_types[d])); } } - return OkStatus(); + return absl::OkStatus(); } // Populates the FastParseExampleConfig from keys & defaults. @@ -284,7 +284,7 @@ class ParseExampleOp : public OpKernel { ragged_splits.set(d, result.ragged_splits[d]); } } - return OkStatus(); + return absl::OkStatus(); } ParseExampleAttrs attrs_; @@ -566,7 +566,7 @@ class ParseSequenceExampleOp : public OpKernel { } } } - return OkStatus(); + return absl::OkStatus(); } example::FastParseExampleConfig MakeContextConfig( @@ -765,7 +765,7 @@ class ParseSequenceExampleOp : public OpKernel { d, feature_list_result.ragged_splits[d]); } } - return OkStatus(); + return absl::OkStatus(); } ParseSequenceExampleAttrs attrs_; diff --git a/tensorflow/core/kernels/fft_ops.cc b/tensorflow/core/kernels/fft_ops.cc index 099a57025b3657..bcc81f903b84f6 100644 --- a/tensorflow/core/kernels/fft_ops.cc +++ b/tensorflow/core/kernels/fft_ops.cc @@ -727,14 +727,16 @@ class FFTGPUBase : public FFTBase { // Create a new plan if one doesn't exist. Otherwise, we need only set // the scratch allocator. + auto fft = stream->parent()->AsFft(); + OP_REQUIRES(ctx, fft != nullptr, absl::InternalError("No FFT for stream.")); if (plan == nullptr) { - plan = stream->parent()->AsFft()->CreateBatchedPlanWithScratchAllocator( + plan = fft->CreateBatchedPlanWithScratchAllocator( stream, fft_rank, fft_shape, input_embed, input_stride, input_distance, output_embed, output_stride, output_distance, kFftType, kInPlaceFft, batch_size, &scratch_allocator); } else { - stream->parent()->AsFft()->UpdatePlanWithScratchAllocator( - stream, plan.get(), &scratch_allocator); + fft->UpdatePlanWithScratchAllocator(stream, plan.get(), + &scratch_allocator); } OP_REQUIRES( @@ -807,17 +809,21 @@ class FFTGPUBase : public FFTBase { AsDeviceMemory(in.flat().data(), input_shape.num_elements()); auto dst = AsDeviceMemory(out->flat().data(), output_shape.num_elements()); + auto fft = stream->parent()->AsFft(); + OP_REQUIRES(ctx, fft != nullptr, absl::InternalError("No FFT for stream.")); OP_REQUIRES( - ctx, stream->ThenFft(plan, src, &dst).ok(), + ctx, fft->DoFft(stream, plan, src, &dst), errors::Internal("fft failed : type=", static_cast(fft_type), " in.shape=", input_shape.DebugString())); if (!IsForward()) { typedef typename RealTypeFromComplexType::RealT RealT; RealT alpha = 1.0 / output_distance; + auto blas = stream->parent()->AsBlas(); + OP_REQUIRES(ctx, blas != nullptr, + absl::InternalError("No Blas for stream.")); OP_REQUIRES( ctx, - stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1) - .ok(), + blas->DoBlasScal(stream, output_shape.num_elements(), alpha, &dst, 1), errors::Internal("BlasScal failed : in.shape=", input_shape.DebugString())); } @@ -909,16 +915,18 @@ class FFTNGPUBase : public FFTNBase { plan = std::move(*plan_or); } } + auto fft = stream->parent()->AsFft(); + OP_REQUIRES(ctx, fft != nullptr, absl::InternalError("No FFT for stream.")); // Create a new plan if one doesn't exist. Otherwise, we need only set // the scratch allocator. if (plan == nullptr) { - plan = stream->parent()->AsFft()->CreateBatchedPlanWithScratchAllocator( + plan = fft->CreateBatchedPlanWithScratchAllocator( stream, fft_rank, fft_shape, input_embed, input_stride, input_distance, output_embed, output_stride, output_distance, kFftType, kInPlaceFft, batch_size, &scratch_allocator); } else { - stream->parent()->AsFft()->UpdatePlanWithScratchAllocator( - stream, plan.get(), &scratch_allocator); + fft->UpdatePlanWithScratchAllocator(stream, plan.get(), + &scratch_allocator); } OP_REQUIRES( @@ -989,17 +997,21 @@ class FFTNGPUBase : public FFTNBase { AsDeviceMemory(in.flat().data(), input_shape.num_elements()); auto dst = AsDeviceMemory(out->flat().data(), output_shape.num_elements()); - OP_REQUIRES(ctx, stream->ThenFft(plan, src, &dst).ok(), + auto fft = stream->parent()->AsFft(); + OP_REQUIRES(ctx, fft != nullptr, absl::InternalError("No FFT for stream.")); + OP_REQUIRES(ctx, fft->DoFft(stream, plan, src, &dst), absl::InternalError(absl::StrCat( "fft failed : type=", static_cast(fft_type), " in.shape=", input_shape.DebugString()))); if (!IsForward()) { typedef typename RealTypeFromComplexType::RealT RealT; RealT alpha = 1.0 / output_distance; + auto blas = stream->parent()->AsBlas(); + OP_REQUIRES(ctx, blas != nullptr, + absl::InternalError("No blas for stream.")); OP_REQUIRES( ctx, - stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1) - .ok(), + blas->DoBlasScal(stream, output_shape.num_elements(), alpha, &dst, 1), absl::InternalError(absl::StrCat("BlasScal failed : in.shape=", input_shape.DebugString()))); } diff --git a/tensorflow/core/kernels/fifo_queue.cc b/tensorflow/core/kernels/fifo_queue.cc index e58ee0e111c54e..27adb244c4d5e9 100644 --- a/tensorflow/core/kernels/fifo_queue.cc +++ b/tensorflow/core/kernels/fifo_queue.cc @@ -96,7 +96,7 @@ Status FIFOQueue::GetElementComponentFromBatch(const FIFOQueue::Tuple& tuple, ctx->allocate_temp(tuple[component].dtype(), element_shape, out_tensor)); TF_RETURN_IF_ERROR( batch_util::CopySliceToElement(tuple[component], out_tensor, index)); - return OkStatus(); + return absl::OkStatus(); } void FIFOQueue::TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx, @@ -363,7 +363,7 @@ Status FIFOQueue::MatchesNodeDef(const NodeDef& node_def) { TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_)); TF_RETURN_IF_ERROR(MatchesNodeDefTypes(node_def)); TF_RETURN_IF_ERROR(MatchesNodeDefShapes(node_def)); - return OkStatus(); + return absl::OkStatus(); } // Defines a FIFOQueueOp, which produces a Queue (specifically, one diff --git a/tensorflow/core/kernels/fill_empty_rows_functor_gpu.cu.cc b/tensorflow/core/kernels/fill_empty_rows_functor_gpu.cu.cc index 8c1b354997c6e5..331b3a77130a6a 100644 --- a/tensorflow/core/kernels/fill_empty_rows_functor_gpu.cu.cc +++ b/tensorflow/core/kernels/fill_empty_rows_functor_gpu.cu.cc @@ -237,22 +237,16 @@ struct FillEmptyRows { auto elements_per_row = elements_per_row_t.flat(); se::DeviceMemoryBase elements_per_row_gpu_memory( elements_per_row.data(), dense_rows * sizeof(Tindex)); - if (!stream - ->ThenMemZero(&elements_per_row_gpu_memory, - dense_rows * sizeof(Tindex)) - .ok()) { - return errors::Internal("Failed to zero elements_per_row"); - } + TF_RETURN_IF_ERROR(stream->MemZero(&elements_per_row_gpu_memory, + dense_rows * sizeof(Tindex))); Tensor rows_are_not_ordered_t; TF_RETURN_IF_ERROR(context->allocate_temp(DT_INT32, TensorShape({1}), &rows_are_not_ordered_t)); auto rows_are_not_ordered_gpu = rows_are_not_ordered_t.flat(); se::DeviceMemoryBase rows_are_not_ordered_gpu_memory( rows_are_not_ordered_gpu.data(), sizeof(int)); - if (!stream->ThenMemZero(&rows_are_not_ordered_gpu_memory, sizeof(int)) - .ok()) { - return errors::Internal("Failed to zero rows_are_not_ordered"); - } + TF_RETURN_IF_ERROR( + stream->MemZero(&rows_are_not_ordered_gpu_memory, sizeof(int))); Tensor first_invalid_index_t; TF_RETURN_IF_ERROR(context->allocate_temp(DT_INT32, TensorShape({1}), &first_invalid_index_t)); @@ -260,12 +254,8 @@ struct FillEmptyRows { constexpr const int kAllIndicesValid = std::numeric_limits::max(); se::DeviceMemoryBase first_invalid_index_gpu_memory( first_invalid_index_gpu.data(), sizeof(int)); - if (!stream - ->ThenMemset32(&first_invalid_index_gpu_memory, kAllIndicesValid, - sizeof(int)) - .ok()) { - return errors::Internal("Failed to initialize first_invalid_index"); - } + TF_RETURN_IF_ERROR(stream->Memset32(&first_invalid_index_gpu_memory, + kAllIndicesValid, sizeof(int))); if (N > 0) { TF_RETURN_IF_ERROR(wrap_kernel_call( @@ -321,33 +311,22 @@ struct FillEmptyRows { /*output=*/num_empty_rows_through.data())); ScratchSpace num_empty_rows_host(context, 1, /*on_host=*/true); - if (!stream - ->ThenMemcpy(num_empty_rows_host.mutable_data(), - se::DeviceMemoryBase( - num_empty_rows_through.data() + (dense_rows - 1), - sizeof(*num_empty_rows_host.data())), - sizeof(*num_empty_rows_host.data())) - .ok()) { - return errors::Internal("Failed to copy num_empty_rows to host"); - } + TF_RETURN_IF_ERROR(stream->Memcpy( + num_empty_rows_host.mutable_data(), + se::DeviceMemoryBase(num_empty_rows_through.data() + (dense_rows - 1), + sizeof(*num_empty_rows_host.data())), + sizeof(*num_empty_rows_host.data()))); ScratchSpace rows_are_not_ordered_host(context, 1, /*on_host=*/true); - if (!stream - ->ThenMemcpy(rows_are_not_ordered_host.mutable_data(), - rows_are_not_ordered_gpu_memory, - sizeof(*rows_are_not_ordered_host.data())) - .ok()) { - return errors::Internal("Failed to copy rows_are_not_ordered to host"); - } + TF_RETURN_IF_ERROR( + stream->Memcpy(rows_are_not_ordered_host.mutable_data(), + rows_are_not_ordered_gpu_memory, + sizeof(*rows_are_not_ordered_host.data()))); ScratchSpace first_invalid_index_host(context, 1, /*on_host=*/true); - if (!stream - ->ThenMemcpy(first_invalid_index_host.mutable_data(), - first_invalid_index_gpu_memory, - sizeof(*first_invalid_index_host.data())) - .ok()) { - return errors::Internal("Failed to copy first_invalid_index to host"); - } + TF_RETURN_IF_ERROR(stream->Memcpy( + first_invalid_index_host.mutable_data(), first_invalid_index_gpu_memory, + sizeof(*first_invalid_index_host.data()))); // We must wait for num_empty_rows and rows_are_not_ordered to be copied to // the host, so we enqueue the remainder of the computation onto the stream diff --git a/tensorflow/core/kernels/fingerprint_op_test.cc b/tensorflow/core/kernels/fingerprint_op_test.cc index 34f6fa354f007d..2ba9640d09e963 100644 --- a/tensorflow/core/kernels/fingerprint_op_test.cc +++ b/tensorflow/core/kernels/fingerprint_op_test.cc @@ -54,7 +54,7 @@ class FingerprintOpTest : public OpsTestBase { method_ = Tensor(DT_STRING, TensorShape{}); method_.scalar()() = method; inputs_.push_back(TensorValue(&method_)); - return OkStatus(); + return absl::OkStatus(); } Tensor batch_dims_; diff --git a/tensorflow/core/kernels/fixed_length_record_reader_op.cc b/tensorflow/core/kernels/fixed_length_record_reader_op.cc index 7787b0ad5f1c02..8db930d7795480 100644 --- a/tensorflow/core/kernels/fixed_length_record_reader_op.cc +++ b/tensorflow/core/kernels/fixed_length_record_reader_op.cc @@ -69,12 +69,12 @@ class FixedLengthRecordReader : public ReaderBase { // header_bytes_ is always skipped. TF_RETURN_IF_ERROR(buffered_inputstream_->SkipNBytes(header_bytes_)); - return OkStatus(); + return absl::OkStatus(); } Status OnWorkFinishedLocked() override { buffered_inputstream_.reset(nullptr); - return OkStatus(); + return absl::OkStatus(); } Status ReadLocked(tstring* key, tstring* value, bool* produced, @@ -98,7 +98,7 @@ class FixedLengthRecordReader : public ReaderBase { return s; } *at_end = true; - return OkStatus(); + return absl::OkStatus(); } } } @@ -112,7 +112,7 @@ class FixedLengthRecordReader : public ReaderBase { return s; } *at_end = true; - return OkStatus(); + return absl::OkStatus(); } lookahead_cache_.append(*value, 0, bytes_to_read); value->clear(); @@ -124,7 +124,7 @@ class FixedLengthRecordReader : public ReaderBase { *produced = true; ++record_number_; - return OkStatus(); + return absl::OkStatus(); } Status ResetLocked() override { diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index 3ce1145493731b..3f4bb1bb96ed4d 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -52,7 +52,7 @@ void ArgOp::Compute(OpKernelContext* ctx) { auto validate_type = [this](const Tensor& val) { if (val.dtype() == dtype_) { - return OkStatus(); + return absl::OkStatus(); } else { return errors::InvalidArgument("Type mismatch: actual ", DataTypeString(val.dtype()), diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc index dd098b01bc65d2..79c393facbb6f1 100644 --- a/tensorflow/core/kernels/functional_ops.cc +++ b/tensorflow/core/kernels/functional_ops.cc @@ -90,7 +90,7 @@ Status ToBool(gtl::ArraySlice t, bool* v) { } else { *v = t[0].NumElements() > 0; } - return OkStatus(); + return absl::OkStatus(); } // Sets "rets" to be the output of "ctx". Validates rets' types based @@ -109,7 +109,7 @@ Status SetOutputs(const OpKernel* kernel, OpKernelContext* ctx, } ctx->set_output(i, rets[i]); } - return OkStatus(); + return absl::OkStatus(); } void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts, @@ -263,7 +263,7 @@ class IfOp : public AsyncOpKernel { tsl::core::WeakPtr(lib)); } } - return OkStatus(); + return absl::OkStatus(); } }; @@ -349,7 +349,7 @@ class CaseOp : public AsyncOpKernel { } } branch_handles.assign(handles.begin(), handles.end()); - return OkStatus(); + return absl::OkStatus(); } class State { @@ -556,7 +556,7 @@ class WhileOp : public AsyncOpKernel { Status GetArg(int index, const Tensor** val) override { if (index < args_->size()) { *val = &(*args_)[index]; - return OkStatus(); + return absl::OkStatus(); } else { return errors::InvalidArgument("Argument ", index, " is out of range."); } @@ -586,7 +586,7 @@ class WhileOp : public AsyncOpKernel { DataTypeString(val.dtype()), "."); } (*retvals_)[index] = val; - return OkStatus(); + return absl::OkStatus(); } private: @@ -665,7 +665,7 @@ class WhileOp : public AsyncOpKernel { } if (!cond) { - return Finish(OkStatus()); + return Finish(absl::OkStatus()); } rets_.clear(); rets_.resize(args_.size()); @@ -790,7 +790,7 @@ class WhileOp : public AsyncOpKernel { tsl::core::WeakPtr(lib)); } } - return OkStatus(); + return absl::OkStatus(); } }; // TODO(drpng): remove these. @@ -825,7 +825,7 @@ Status GetScalar(OpKernelContext* ctx, int index, int32* value, t.shape().DebugString()); } *value = t.scalar()(); - return OkStatus(); + return absl::OkStatus(); } class ForOp : public AsyncOpKernel { @@ -895,7 +895,7 @@ class ForOp : public AsyncOpKernel { *body_handle, tsl::core::WeakPtr(lib)); } } - return OkStatus(); + return absl::OkStatus(); } class State { @@ -953,7 +953,7 @@ class ForOp : public AsyncOpKernel { (delta_ < 0 && *iter_ >= limit_) || (delta_ == 0 && *iter_ == limit_)) { RunNext(); - return OkStatus(); + return absl::OkStatus(); } else { return errors::InvalidArgument("Invalid start/limit/delta: ", *iter_, " ", limit_, " ", delta_); @@ -968,7 +968,7 @@ class ForOp : public AsyncOpKernel { done_loop = *iter_ <= limit_; } if (done_loop) { - Finish(OkStatus()); + Finish(absl::OkStatus()); return; } diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc index 009710f113fd7a..6510cfc8fc4c7b 100644 --- a/tensorflow/core/kernels/fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/fused_batch_norm_op.cc @@ -73,11 +73,11 @@ Status ParseActivationMode(OpKernelConstruction* context, if (activation_mode_str == "Identity") { *activation_mode = FusedBatchNormActivationMode::kIdentity; - return OkStatus(); + return absl::OkStatus(); } if (activation_mode_str == "Relu") { *activation_mode = FusedBatchNormActivationMode::kRelu; - return OkStatus(); + return absl::OkStatus(); } return errors::InvalidArgument("Unsupported activation mode: ", activation_mode_str); @@ -967,38 +967,30 @@ struct FusedBatchNormImplGPU { } if (!batch_mean->SharesBufferWith(estimated_mean) && exponential_avg_factor != 1.0f) { - OP_REQUIRES( - context, - stream - ->ThenMemcpyD2D(&batch_mean_ptr, estimated_mean_ptr, - estimated_mean.NumElements() * sizeof(U)) - .ok(), - errors::Internal("MatrixTriangularSolveOp: failed to copy rhs " - "from device")); + OP_REQUIRES_OK( + context, stream->MemcpyD2D(&batch_mean_ptr, estimated_mean_ptr, + estimated_mean.NumElements() * sizeof(U))); } if (!batch_var->SharesBufferWith(estimated_variance) && exponential_avg_factor != 1.0f) { - OP_REQUIRES( + OP_REQUIRES_OK( context, - stream - ->ThenMemcpyD2D(&batch_var_ptr, estimated_variance_ptr, - estimated_variance.NumElements() * sizeof(U)) - .ok(), - errors::Internal("MatrixTriangularSolveOp: failed to copy rhs " - "from device")); + stream->MemcpyD2D(&batch_var_ptr, estimated_variance_ptr, + estimated_variance.NumElements() * sizeof(U))); } - bool cudnn_launch_status = - stream - ->ThenBatchNormalizationForward( - x_ptr, scale_ptr, offset_ptr, estimated_mean_ptr, - estimated_variance_ptr, side_input_ptr, x_desc, - scale_offset_desc, static_cast(epsilon), - static_cast(exponential_avg_factor), - AsDnnActivationMode(activation_mode), &y_ptr, &batch_mean_ptr, - &batch_var_ptr, &saved_mean_ptr, &saved_inv_var_ptr, - is_training, reserve_space_allocator.get(), - workspace_allocator.get()) - .ok(); + auto dnn = stream->parent()->AsDnn(); + if (dnn == nullptr) { + context->SetStatus(absl::InternalError("No DNN support for stream")); + return; + } + bool cudnn_launch_status = dnn->DoBatchNormalizationForward( + stream, x_ptr, scale_ptr, offset_ptr, estimated_mean_ptr, + estimated_variance_ptr, side_input_ptr, x_desc, scale_offset_desc, + static_cast(epsilon), + static_cast(exponential_avg_factor), + AsDnnActivationMode(activation_mode), &y_ptr, &batch_mean_ptr, + &batch_var_ptr, &saved_mean_ptr, &saved_inv_var_ptr, is_training, + reserve_space_allocator.get(), workspace_allocator.get()); if (!cudnn_launch_status) { context->SetStatus( @@ -1256,18 +1248,19 @@ struct FusedBatchNormGradImplGPU { } } #endif // CUDNN_VERSION >= 7402 + auto dnn = stream->parent()->AsDnn(); + if (dnn == nullptr) { + context->SetStatus(absl::InternalError("No DNN support for stream")); + return; + } - bool cudnn_launch_status = - stream - ->ThenBatchNormalizationBackward( - y_backprop_ptr, x_ptr, scale_ptr, offset_ptr, mean_ptr, - inv_variance_ptr, y_ptr, x_desc, scale_offset_desc, - static_cast(epsilon), - AsDnnActivationMode(activation_mode), &x_backprop_ptr, - &scale_backprop_ptr, &offset_backprop_ptr, - &side_input_backprop_ptr, reserve_space_data_ptr, - workspace_allocator.get()) - .ok(); + bool cudnn_launch_status = dnn->DoBatchNormalizationBackward( + stream, y_backprop_ptr, x_ptr, scale_ptr, offset_ptr, mean_ptr, + inv_variance_ptr, y_ptr, x_desc, scale_offset_desc, + static_cast(epsilon), AsDnnActivationMode(activation_mode), + &x_backprop_ptr, &scale_backprop_ptr, &offset_backprop_ptr, + &side_input_backprop_ptr, reserve_space_data_ptr, + workspace_allocator.get()); if (!cudnn_launch_status) { context->SetStatus( diff --git a/tensorflow/core/kernels/fused_eigen_output_kernels.cc b/tensorflow/core/kernels/fused_eigen_output_kernels.cc index 9784441d57c0d6..39054713db5ac2 100644 --- a/tensorflow/core/kernels/fused_eigen_output_kernels.cc +++ b/tensorflow/core/kernels/fused_eigen_output_kernels.cc @@ -144,7 +144,7 @@ Status InitializeFusedComputation( } } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/kernels/fused_eigen_output_kernels.h b/tensorflow/core/kernels/fused_eigen_output_kernels.h index 9a50882a1016c7..21bcf17df3e9d6 100644 --- a/tensorflow/core/kernels/fused_eigen_output_kernels.h +++ b/tensorflow/core/kernels/fused_eigen_output_kernels.h @@ -428,7 +428,7 @@ Status InitBiasAddArgs(OpKernelContext* context, BiasAddArgs* args, args->leakyrelu_alpha = *leakyrelu_alpha; } - return OkStatus(); + return absl::OkStatus(); } template @@ -471,7 +471,7 @@ Status InitFusedBatchNormArgs(OpKernelContext* context, float epsilon, args->leakyrelu_alpha = *leakyrelu_alpha; } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/kernels/fuzzing/fuzz_session.h b/tensorflow/core/kernels/fuzzing/fuzz_session.h index 6aeadddb49b40b..09c7563d2efd17 100644 --- a/tensorflow/core/kernels/fuzzing/fuzz_session.h +++ b/tensorflow/core/kernels/fuzzing/fuzz_session.h @@ -83,7 +83,7 @@ class FuzzSession { // can't be put into the constructor. Status InitIfNeeded() { if (initialized_) { - return OkStatus(); + return absl::OkStatus(); } initialized_ = true; diff --git a/tensorflow/core/kernels/image/crop_and_resize_op.cc b/tensorflow/core/kernels/image/crop_and_resize_op.cc index 42d098612e23c0..8f8e5741349bfc 100644 --- a/tensorflow/core/kernels/image/crop_and_resize_op.cc +++ b/tensorflow/core/kernels/image/crop_and_resize_op.cc @@ -61,7 +61,7 @@ static inline Status ParseAndCheckBoxSizes(const Tensor& boxes, int* num_boxes) { if (boxes.NumElements() == 0 && box_index.NumElements() == 0) { *num_boxes = 0; - return OkStatus(); + return absl::OkStatus(); } // The shape of 'boxes' is [num_boxes, 4]. if (boxes.dims() != 2) { @@ -80,7 +80,7 @@ static inline Status ParseAndCheckBoxSizes(const Tensor& boxes, if (box_index.dim_size(0) != *num_boxes) { return errors::InvalidArgument("box_index has incompatible shape"); } - return OkStatus(); + return absl::OkStatus(); } // Conditionally calls the compute callback if all values in box_index are in diff --git a/tensorflow/core/kernels/image/scale_and_translate_op.cc b/tensorflow/core/kernels/image/scale_and_translate_op.cc index 7b9ecaeadc8df3..a2bfb80438c007 100644 --- a/tensorflow/core/kernels/image/scale_and_translate_op.cc +++ b/tensorflow/core/kernels/image/scale_and_translate_op.cc @@ -120,7 +120,7 @@ Status ComputeSpansCore(OpKernelContext* context, const Kernel& kernel, } starts_vec(x) = span_start; } - return OkStatus(); + return absl::OkStatus(); } Status ComputeGradSpansCore(OpKernelContext* context, const Spans& spans, @@ -180,7 +180,7 @@ Status ComputeGradSpansCore(OpKernelContext* context, const Spans& spans, grad_starts_vec(input_index) = 0; } } - return OkStatus(); + return absl::OkStatus(); } // Computes the spans for the passed kernel, for a input dimension of length @@ -229,7 +229,7 @@ Status ComputeSpans(OpKernelContext* context, return errors::InvalidArgument(Printf("Unrecognized kernel type: %d", static_cast(kernel_type))); } - return OkStatus(); + return absl::OkStatus(); } // Computes the grad spans for the passed kernel. diff --git a/tensorflow/core/kernels/immutable_constant_op.cc b/tensorflow/core/kernels/immutable_constant_op.cc index c496f8036de8da..2b9f8f34f5d4c1 100644 --- a/tensorflow/core/kernels/immutable_constant_op.cc +++ b/tensorflow/core/kernels/immutable_constant_op.cc @@ -32,7 +32,7 @@ class MemmappedTensorAllocator : public Allocator { if (!status.ok()) { return status; } - return OkStatus(); + return absl::OkStatus(); } string Name() override { return "MemmappedTensorAllocator"; } diff --git a/tensorflow/core/kernels/immutable_constant_op_test.cc b/tensorflow/core/kernels/immutable_constant_op_test.cc index d32e01ed526621..a91932a240b292 100644 --- a/tensorflow/core/kernels/immutable_constant_op_test.cc +++ b/tensorflow/core/kernels/immutable_constant_op_test.cc @@ -83,7 +83,7 @@ class TestFileSystem : public NullFileSystem { auto region = new TestReadOnlyMemoryRegion(kTestTensorSizeBytes); std::fill_n(region->GetWritableDataStart(), kTestTensorSize, val); result->reset(region); - return OkStatus(); + return absl::OkStatus(); } }; @@ -158,7 +158,7 @@ Status CreateTempFileFloat(Env* env, float value, uint64 size, TF_RETURN_IF_ERROR(file->Append(sp)); } TF_RETURN_IF_ERROR(file->Close()); - return OkStatus(); + return absl::OkStatus(); } TEST(ImmutableConstantOpTest, FromFile) { @@ -199,7 +199,7 @@ Status CreateTempFileBadString(Env* env, char value, uint64 size, TF_RETURN_IF_ERROR(env->NewWritableFile(*filename, &file)); TF_RETURN_IF_ERROR(file->Append(std::string(size, value))); TF_RETURN_IF_ERROR(file->Close()); - return OkStatus(); + return absl::OkStatus(); } TEST(ImmutableConstantOpTest, FromFileStringUnimplmented) { diff --git a/tensorflow/core/kernels/initializable_lookup_table.cc b/tensorflow/core/kernels/initializable_lookup_table.cc index 612c8c70fd3298..7c295b970ab538 100644 --- a/tensorflow/core/kernels/initializable_lookup_table.cc +++ b/tensorflow/core/kernels/initializable_lookup_table.cc @@ -54,7 +54,7 @@ Status InitializableLookupTable::ImportValues(OpKernelContext* ctx, .WithAttr("Tout", values.dtype())); *out = ops::UnaryOp("Identity", table, builder->opts().WithControlInput(import_table)); - return OkStatus(); + return absl::OkStatus(); }); return Initialize(iter, std::move(serializer)); @@ -84,7 +84,7 @@ Status InitializableLookupTable::Initialize( "Table was already initialized with " "different data."); } else { - return OkStatus(); + return absl::OkStatus(); } } TF_RETURN_IF_ERROR(DoLazyPrepare([&iter]() { return iter.total_size(); })); @@ -98,13 +98,13 @@ Status InitializableLookupTable::Initialize( initializer_serializer_ = std::move(serializer); is_initialized_.store(true, std::memory_order_release); - return OkStatus(); + return absl::OkStatus(); } Status InitializableLookupTable::AreEntriesSame(const InitTableIterator& iter, bool* result) { *result = static_cast(iter.total_size()) == size(); - return OkStatus(); + return absl::OkStatus(); } } // namespace lookup diff --git a/tensorflow/core/kernels/initializable_lookup_table.h b/tensorflow/core/kernels/initializable_lookup_table.h index 674352278ef568..010febb73e8cca 100644 --- a/tensorflow/core/kernels/initializable_lookup_table.h +++ b/tensorflow/core/kernels/initializable_lookup_table.h @@ -222,7 +222,7 @@ class KeyValueTensorIterator public: // keys and values are not owned by the iterator. explicit KeyValueTensorIterator(const Tensor* keys, const Tensor* values) - : keys_(keys), values_(values), valid_(true), status_(OkStatus()) { + : keys_(keys), values_(values), valid_(true), status_(absl::OkStatus()) { TensorShape key_shape = keys_->shape(); if (!key_shape.IsSameSize(values_->shape())) { valid_ = false; diff --git a/tensorflow/core/kernels/inplace_ops.cc b/tensorflow/core/kernels/inplace_ops.cc index 1ce48822ea2c20..9c1a502214f590 100644 --- a/tensorflow/core/kernels/inplace_ops.cc +++ b/tensorflow/core/kernels/inplace_ops.cc @@ -37,7 +37,7 @@ Status DoParallelConcatUpdate(const Device& d, const Tensor& value, int32_t loc, auto nrows = Toutput.dimension(0); auto r = (loc % nrows + nrows) % nrows; // Guard index range. Toutput.template chip<0>(r).device(d) = Tvalue.template chip<0>(0); - return OkStatus(); + return absl::OkStatus(); } template <> diff --git a/tensorflow/core/kernels/linalg/einsum_op_impl.h b/tensorflow/core/kernels/linalg/einsum_op_impl.h index da5c6718f4a271..6dc4b07070e81b 100644 --- a/tensorflow/core/kernels/linalg/einsum_op_impl.h +++ b/tensorflow/core/kernels/linalg/einsum_op_impl.h @@ -90,7 +90,7 @@ struct EinsumHelper { " but got dimension ", input_dim); } (*label_to_dim_sizes)[label] = input_dim; - return OkStatus(); + return absl::OkStatus(); } // Validate input dimensions and populate unnamed labels and their label @@ -160,7 +160,7 @@ struct EinsumHelper { } if (!absl::c_linear_search(input_has_ellipsis, true) && !output_has_ellipsis) { - return OkStatus(); + return absl::OkStatus(); } // Insert broadcasting dimensions in the output labels. auto it = @@ -178,7 +178,7 @@ struct EinsumHelper { // Populate EinsumDimensionType for the new broadcasting labels. label_types->resize(num_named_labels + max_bcast_dims, EinsumDimensionType::kBroadcasting); - return OkStatus(); + return absl::OkStatus(); } // Permutes the labels according to the given permutation. @@ -194,7 +194,7 @@ struct EinsumHelper { // Returns a reshaped input Tensor. The underlying buffer is not copied. static Status CopyFrom(const Tensor& input, const TensorShape& shape, Tensor* output) { - if (output->CopyFrom(input, shape)) return OkStatus(); + if (output->CopyFrom(input, shape)) return absl::OkStatus(); return errors::Internal( "Encountered error while reshaping a Tensor of shape ", input.shape().DebugString(), " to shape ", shape.DebugString()); @@ -234,7 +234,7 @@ struct EinsumHelper { ctx->allocate_temp(DataTypeToEnum::value, transposed_shape, output)); const Device& device = ctx->eigen_device(); TF_RETURN_IF_ERROR(DoTranspose(device, input, permutation, output)); - return OkStatus(); + return absl::OkStatus(); } // If there are repeated labels in either the input or output, then this @@ -310,7 +310,7 @@ struct EinsumHelper { " while handling repeated indices. Up to rank 6 is supported."); #undef NDIMS_CASE } - return OkStatus(); + return absl::OkStatus(); } // Returns true if the input dimensions are already sorted in the order @@ -413,7 +413,7 @@ struct EinsumHelper { const_cast(input_deduped) .shaped({output_size, reshape[kReduce]}), Eigen::array({1}), Reducer()); - return OkStatus(); + return absl::OkStatus(); } // Reshapes a Tensor of shape [b0,b1...bk,N,M] to [prod(b0,b1...bk),N,M]. @@ -464,7 +464,7 @@ struct EinsumHelper { if (lhs.NumElements() == 0 || rhs.NumElements() == 0) { functor::SetZeroFunctor set_zero; set_zero(ctx->eigen_device(), output->flat()); - return OkStatus(); + return absl::OkStatus(); } Tensor output_reshaped; TF_RETURN_IF_ERROR( @@ -473,7 +473,7 @@ struct EinsumHelper { /*adj_y=*/false, trans_x, trans_y, /*grad_x=*/false, /*grad_y=*/false, bcast, &output_reshaped); - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/core/kernels/linalg/matrix_triangular_solve_op_impl.h b/tensorflow/core/kernels/linalg/matrix_triangular_solve_op_impl.h index 8ed5315e2babcf..8e5243476e4069 100644 --- a/tensorflow/core/kernels/linalg/matrix_triangular_solve_op_impl.h +++ b/tensorflow/core/kernels/linalg/matrix_triangular_solve_op_impl.h @@ -267,14 +267,9 @@ struct LaunchBatchMatrixTriangularSolve { if (!bcast.IsBroadcastingRequired() || out->shape() == in_y.shape()) { auto src_device_mem = AsDeviceMemory(in_y.template flat().data()); auto dst_device_mem = AsDeviceMemory(out->template flat().data()); - OP_REQUIRES( - context, - stream - ->ThenMemcpyD2D(&dst_device_mem, src_device_mem, - bcast.y_batch_size() * m * n * sizeof(Scalar)) - .ok(), - errors::Internal("MatrixTriangularSolveOp: failed to copy rhs " - "from device")); + OP_REQUIRES_OK(context, stream->MemcpyD2D(&dst_device_mem, src_device_mem, + bcast.y_batch_size() * m * n * + sizeof(Scalar))); } else { std::vector out_ptrs; std::vector b_tmp_ptrs; @@ -287,14 +282,9 @@ struct LaunchBatchMatrixTriangularSolve { auto src_device_mem = AsDeviceMemory(b_tmp_ptrs[b_batch_indices[i]]); auto dst_device_mem = AsDeviceMemory(out->template flat().data() + i * m * n); - OP_REQUIRES( - context, - stream - ->ThenMemcpyD2D(&dst_device_mem, src_device_mem, - m * n * sizeof(Scalar)) - .ok(), - errors::Internal("MatrixTriangularSolveOp: failed to copy rhs " - "from device")); + OP_REQUIRES_OK(context, + stream->MemcpyD2D(&dst_device_mem, src_device_mem, + m * n * sizeof(Scalar))); } } diff --git a/tensorflow/core/kernels/linalg/tridiagonal_solve_op_gpu.cu.cc b/tensorflow/core/kernels/linalg/tridiagonal_solve_op_gpu.cu.cc index ae2fbcb4c1cf44..1b9eb709a42595 100644 --- a/tensorflow/core/kernels/linalg/tridiagonal_solve_op_gpu.cu.cc +++ b/tensorflow/core/kernels/linalg/tridiagonal_solve_op_gpu.cu.cc @@ -76,8 +76,8 @@ void CopyDeviceToDevice(OpKernelContext* context, const Scalar* src, auto dst_device_mem = AsDeviceMemory(dst); auto* stream = context->op_device_context()->stream(); bool copy_status = stream - ->ThenMemcpyD2D(&dst_device_mem, src_device_mem, - sizeof(Scalar) * num_elements) + ->MemcpyD2D(&dst_device_mem, src_device_mem, + sizeof(Scalar) * num_elements) .ok(); if (!copy_status) { diff --git a/tensorflow/core/kernels/list_kernels.cc b/tensorflow/core/kernels/list_kernels.cc index a1dcb5058ef989..fa4b082c9779c2 100644 --- a/tensorflow/core/kernels/list_kernels.cc +++ b/tensorflow/core/kernels/list_kernels.cc @@ -51,7 +51,7 @@ Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out) { if ((t.dtype() == DT_INT32 && t.scalar()() == -1) || (t.dtype() == DT_INT64 && t.scalar()() == -1)) { *out = PartialTensorShape(); - return OkStatus(); + return absl::OkStatus(); } return errors::InvalidArgument( "The only valid scalar shape tensor is the fully unknown shape " @@ -80,7 +80,7 @@ Status GetElementShapeFromInput(OpKernelContext* c, // compatible and store the merged shape in `element_shape`. PartialTensorShape tmp = *element_shape; TF_RETURN_IF_ERROR(tmp.MergeWith(tensor_list.element_shape, element_shape)); - return OkStatus(); + return absl::OkStatus(); } Status GetInputList(OpKernelContext* c, int index, const TensorList** list) { @@ -95,7 +95,7 @@ Status GetInputList(OpKernelContext* c, int index, const TensorList** list) { c->input(index).scalar()().DebugString(), "'"); } *list = l; - return OkStatus(); + return absl::OkStatus(); } Status ForwardInputOrCreateNewList(OpKernelContext* c, int32_t input_index, @@ -120,7 +120,7 @@ Status ForwardInputOrCreateNewList(OpKernelContext* c, int32_t input_index, // Woohoo, forwarding succeeded! c->set_output(output_index, *output_tensor); *output_list = tmp_out; - return OkStatus(); + return absl::OkStatus(); } } @@ -133,7 +133,7 @@ Status ForwardInputOrCreateNewList(OpKernelContext* c, int32_t input_index, output_tensor->scalar()() = input_list.Copy(); *output_list = output_tensor->scalar()().get(); - return OkStatus(); + return absl::OkStatus(); } class EmptyTensorList : public OpKernel { @@ -710,7 +710,7 @@ static Status TensorListDeviceCopy( TF_RETURN_IF_ERROR(copy(t, &to->tensors().back())); } } - return OkStatus(); + return absl::OkStatus(); } #define REGISTER_LIST_COPY(DIRECTION) \ diff --git a/tensorflow/core/kernels/list_kernels.h b/tensorflow/core/kernels/list_kernels.h index 41f475799fc9b9..a43fc195c702ab 100644 --- a/tensorflow/core/kernels/list_kernels.h +++ b/tensorflow/core/kernels/list_kernels.h @@ -845,7 +845,7 @@ Status Scatter(OpKernelContext* c, const Tensor& value, const Tensor& indices, copy_tensor(c, tmp, aligned); std::swap(list->tensors()[i], aligned); } - return OkStatus(); + return absl::OkStatus(); } template diff --git a/tensorflow/core/kernels/load_and_remap_matrix_op.cc b/tensorflow/core/kernels/load_and_remap_matrix_op.cc index 3d9d048bd7167e..fb2f9d40495c94 100644 --- a/tensorflow/core/kernels/load_and_remap_matrix_op.cc +++ b/tensorflow/core/kernels/load_and_remap_matrix_op.cc @@ -50,7 +50,7 @@ Status RemapVectorToMap( ", which is not supported.")); } } - return OkStatus(); + return absl::OkStatus(); } } // anonymous namespace diff --git a/tensorflow/core/kernels/logging_ops_test.cc b/tensorflow/core/kernels/logging_ops_test.cc index 6eb1f226400da8..8ba44782a194c0 100644 --- a/tensorflow/core/kernels/logging_ops_test.cc +++ b/tensorflow/core/kernels/logging_ops_test.cc @@ -50,13 +50,13 @@ TEST_F(PrintingV2GraphTest, StringSuccess) { } TEST_F(PrintingV2GraphTest, InvalidOutputStream) { - ASSERT_NE(OkStatus(), (Init("invalid_output_stream"))); + ASSERT_NE(absl::OkStatus(), (Init("invalid_output_stream"))); } TEST_F(PrintingV2GraphTest, InvalidInputRank) { TF_ASSERT_OK(Init()); AddInputFromArray(TensorShape({2}), {"bar", "foo"}); - ASSERT_NE(OkStatus(), RunOpKernel()); + ASSERT_NE(absl::OkStatus(), RunOpKernel()); } class PrintingGraphTest : public OpsTestBase { diff --git a/tensorflow/core/kernels/logistic-loss.h b/tensorflow/core/kernels/logistic-loss.h index 6a7016a9625345..d06b1228f62a59 100644 --- a/tensorflow/core/kernels/logistic-loss.h +++ b/tensorflow/core/kernels/logistic-loss.h @@ -99,10 +99,10 @@ class LogisticLossUpdater : public DualLossUpdater { Status ConvertLabel(float* const example_label) const final { if (*example_label == 0.0) { *example_label = -1; - return OkStatus(); + return absl::OkStatus(); } if (*example_label == 1.0) { - return OkStatus(); + return absl::OkStatus(); } return errors::InvalidArgument( "Only labels of 0.0 or 1.0 are supported right now. " diff --git a/tensorflow/core/kernels/lookup_table_init_op.cc b/tensorflow/core/kernels/lookup_table_init_op.cc index 1a80ba41958eef..27cc76ee11b945 100644 --- a/tensorflow/core/kernels/lookup_table_init_op.cc +++ b/tensorflow/core/kernels/lookup_table_init_op.cc @@ -178,7 +178,7 @@ class InitializeTableFromTextFileOp : public OpKernel { .WithAttr("delimiter", delimiter_string)); *out = ops::UnaryOp("Identity", table, builder->opts().WithControlInput(import_table)); - return OkStatus(); + return absl::OkStatus(); }); } diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc index 50dd6bd1d8ed2a..78a2716f0b95d2 100644 --- a/tensorflow/core/kernels/lookup_table_op.cc +++ b/tensorflow/core/kernels/lookup_table_op.cc @@ -83,7 +83,7 @@ class MutableHashTableOfScalars final : public LookupInterface { is_full_size_default ? default_flat(i) : default_flat(0)); } - return OkStatus(); + return absl::OkStatus(); } Status DoInsert(bool clear, const Tensor& keys, const Tensor& values) { @@ -98,7 +98,7 @@ class MutableHashTableOfScalars final : public LookupInterface { gtl::InsertOrUpdate(&table_, SubtleMustCopyIfIntegral(key_values(i)), SubtleMustCopyIfIntegral(value_values(i))); } - return OkStatus(); + return absl::OkStatus(); } Status Insert(OpKernelContext* ctx, const Tensor& keys, @@ -113,7 +113,7 @@ class MutableHashTableOfScalars final : public LookupInterface { for (int64_t i = 0; i < key_values.size(); ++i) { table_.erase(SubtleMustCopyIfIntegral(key_values(i))); } - return OkStatus(); + return absl::OkStatus(); } Status ImportValues(OpKernelContext* ctx, const Tensor& keys, @@ -132,7 +132,7 @@ class MutableHashTableOfScalars final : public LookupInterface { TF_RETURN_IF_ERROR( ctx->allocate_output("values", TensorShape({size}), &values)); ExportKeysAndValues(keys, values); - return OkStatus(); + return absl::OkStatus(); } DataType key_dtype() const override { return DataTypeToEnum::v(); } @@ -191,7 +191,7 @@ class MutableHashTableOfScalars final : public LookupInterface { .WithAttr("Tout", value_dtype())); *out = ops::UnaryOp("Identity", table, builder->opts().WithControlInput(import_table)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -264,7 +264,7 @@ class MutableHashTableOfTensors final : public LookupInterface { } } - return OkStatus(); + return absl::OkStatus(); } Status DoInsert(bool clear, const Tensor& keys, const Tensor& values) { @@ -285,7 +285,7 @@ class MutableHashTableOfTensors final : public LookupInterface { gtl::InsertOrUpdate(&table_, SubtleMustCopyIfIntegral(key_values(i)), value_vec); } - return OkStatus(); + return absl::OkStatus(); } Status Insert(OpKernelContext* ctx, const Tensor& keys, @@ -300,7 +300,7 @@ class MutableHashTableOfTensors final : public LookupInterface { for (int64_t i = 0; i < key_values.size(); ++i) { table_.erase(SubtleMustCopyIfIntegral(key_values(i))); } - return OkStatus(); + return absl::OkStatus(); } Status ImportValues(OpKernelContext* ctx, const Tensor& keys, @@ -320,7 +320,7 @@ class MutableHashTableOfTensors final : public LookupInterface { TF_RETURN_IF_ERROR(ctx->allocate_output( "values", TensorShape({size, value_dim}), &values)); ExportKeysAndValues(keys, values); - return OkStatus(); + return absl::OkStatus(); } DataType key_dtype() const override { return DataTypeToEnum::v(); } @@ -380,7 +380,7 @@ class MutableHashTableOfTensors final : public LookupInterface { .WithAttr("Tout", value_dtype())); *out = ops::UnaryOp("Identity", table, builder->opts().WithControlInput(import_table)); - return OkStatus(); + return absl::OkStatus(); } private: @@ -560,7 +560,7 @@ class MutableDenseHashTable final : public LookupInterface { } } } - return OkStatus(); + return absl::OkStatus(); } Status Insert(OpKernelContext* ctx, const Tensor& key, @@ -623,14 +623,14 @@ class MutableDenseHashTable final : public LookupInterface { ++num_entries_; } } - return OkStatus(); + return absl::OkStatus(); } Status ExportValues(OpKernelContext* ctx) override TF_LOCKS_EXCLUDED(mu_) { tf_shared_lock l(mu_); TF_RETURN_IF_ERROR(ctx->set_output("keys", key_buckets_)); TF_RETURN_IF_ERROR(ctx->set_output("values", value_buckets_)); - return OkStatus(); + return absl::OkStatus(); } Status CheckKeyAndValueTensorsForImport(const Tensor& keys, @@ -655,7 +655,7 @@ class MutableDenseHashTable final : public LookupInterface { "Expected shape ", expected_value_shape.DebugString(), " for value, got ", values.shape().DebugString()); } - return OkStatus(); + return absl::OkStatus(); } DataType key_dtype() const override { return DataTypeToEnum::v(); } @@ -740,7 +740,7 @@ class MutableDenseHashTable final : public LookupInterface { } } } - return OkStatus(); + return absl::OkStatus(); } Status DoRemove(OpKernelContext* ctx, const Tensor& key) @@ -791,7 +791,7 @@ class MutableDenseHashTable final : public LookupInterface { } } } - return OkStatus(); + return absl::OkStatus(); } Status AllocateBuckets(OpKernelContext* ctx, int64_t new_num_buckets) @@ -829,7 +829,7 @@ class MutableDenseHashTable final : public LookupInterface { value_buckets_matrix(i, j) = V(); } } - return OkStatus(); + return absl::OkStatus(); } Status Rebucket(OpKernelContext* ctx, int64_t num_new_buckets) diff --git a/tensorflow/core/kernels/lookup_table_op.h b/tensorflow/core/kernels/lookup_table_op.h index 97d5f404d7c1df..f4855a22d73665 100644 --- a/tensorflow/core/kernels/lookup_table_op.h +++ b/tensorflow/core/kernels/lookup_table_op.h @@ -79,7 +79,7 @@ class LookupTableOp : public OpKernel { container->MemoryUsed() + table_.AllocatedBytes()); } *ret = container; - return OkStatus(); + return absl::OkStatus(); }; lookup::LookupInterface* table = nullptr; @@ -236,7 +236,7 @@ class HashTable : public InitializableLookupTable { .WithAttr("use_node_name_sharing", true)); if (table_.empty()) { *out = hash_table_node; - return OkStatus(); + return absl::OkStatus(); } if (initializer_serializer_ == nullptr) { @@ -251,7 +251,7 @@ class HashTable : public InitializableLookupTable { builder, hash_table_node, &initializer)); *out = ops::UnaryOp("Identity", hash_table_node, builder->opts().WithControlInput(initializer)); - return OkStatus(); + return absl::OkStatus(); } size_t size() const override { @@ -282,7 +282,7 @@ class HashTable : public InitializableLookupTable { keys_data(i) = it->first; values_data(i) = it->second; } - return OkStatus(); + return absl::OkStatus(); } DataType key_dtype() const override { return DataTypeToEnum::v(); } @@ -297,7 +297,7 @@ class HashTable : public InitializableLookupTable { if (size > 0) { table_.reserve(size); } - return OkStatus(); + return absl::OkStatus(); }; Status DoLazyPrepare(std::function size_fn) override { @@ -317,7 +317,7 @@ class HashTable : public InitializableLookupTable { result.first->second, " and trying to add value ", value); } } - return OkStatus(); + return absl::OkStatus(); } Status DoFind(const Tensor& key, Tensor* value, @@ -330,7 +330,7 @@ class HashTable : public InitializableLookupTable { value_values(i) = gtl::FindWithDefault( table_, SubtleMustCopyIfIntegral(key_values(i)), default_val); } - return OkStatus(); + return absl::OkStatus(); } int64_t MemoryUsed() const override { diff --git a/tensorflow/core/kernels/lookup_util.cc b/tensorflow/core/kernels/lookup_util.cc index 2ea77327c7df3b..44b93c75e5c988 100644 --- a/tensorflow/core/kernels/lookup_util.cc +++ b/tensorflow/core/kernels/lookup_util.cc @@ -54,7 +54,7 @@ Status GetNumLinesInTextFile(Env* env, const string& vocab_file, return s; } *num_lines = next_id; - return OkStatus(); + return absl::OkStatus(); } // Iterator that reads a text file. Each iteration process one line, it parses @@ -210,7 +210,7 @@ class TextFileLineIterator int64_t index, Tensor* tensor) { if (index == kLineNumber) { tensor->flat()(0) = next_id_ + offset_; - return OkStatus(); + return absl::OkStatus(); } const string& token = (index == kWholeLine) ? line : tokens[index]; const DataType& dtype = tensor->dtype(); @@ -259,7 +259,7 @@ class TextFileLineIterator return errors::InvalidArgument("Data type ", DataTypeString(dtype), " not supported."); } - return OkStatus(); + return absl::OkStatus(); } TextFileLineIterator(const TextFileLineIterator&) = delete; @@ -283,7 +283,7 @@ Status GetTableHandle(StringPiece input_name, OpKernelContext* ctx, *container = h(0); *table_handle = h(1); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -345,7 +345,7 @@ Status GetInitializableLookupTable(StringPiece input_name, OpKernelContext* ctx, " is not initializable"); } } - return OkStatus(); + return absl::OkStatus(); } Status CheckTableDataTypes(const LookupInterface& table, DataType key_dtype, @@ -357,7 +357,7 @@ Status CheckTableDataTypes(const LookupInterface& table, DataType key_dtype, DataTypeString(table.key_dtype()), "-", DataTypeString(table.value_dtype()), " for table ", table_name); } - return OkStatus(); + return absl::OkStatus(); } // Helper function to initialize an InitializableLookupTable from a text file. @@ -413,7 +413,7 @@ Status InitializeTableFromTextFile( if (absl::IsFailedPrecondition(s) && table->is_initialized()) { LOG(INFO) << "Table trying to initialize from file " << filename << " is already initialized."; - return OkStatus(); + return absl::OkStatus(); } return s; } diff --git a/tensorflow/core/kernels/lrn_op.cc b/tensorflow/core/kernels/lrn_op.cc index 22a25eeb6842a4..aa4eafabd0ab71 100644 --- a/tensorflow/core/kernels/lrn_op.cc +++ b/tensorflow/core/kernels/lrn_op.cc @@ -219,11 +219,11 @@ struct LaunchLRN { auto* stream = context->op_device_context()->stream(); OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); - bool status = - stream - ->ThenNormalizeWithDimensions(normalize_desc, dimensions_desc, - input_data, &output_data) - .ok(); + auto dnn = stream->parent()->AsDnn(); + OP_REQUIRES(context, dnn != nullptr, + absl::InternalError("No DNN support for stream.")); + bool status = dnn->DoNormalizeWithDimensions( + stream, normalize_desc, dimensions_desc, input_data, &output_data); OP_REQUIRES(context, status, errors::Internal("NormalizeWithDimensions launch failed")); #elif TENSORFLOW_USE_ROCM @@ -279,12 +279,12 @@ struct LaunchLRN { auto* stream = context->op_device_context()->stream(); OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); + auto dnn = stream->parent()->AsDnn(); + OP_REQUIRES(context, dnn != nullptr, + absl::InternalError("No DNN support for stream.")); - bool status = - stream - ->ThenNormalizeWithDimensions(normalize_desc, dimensions_desc, - input_data, &output_data) - .ok(); + bool status = dnn->DoNormalizeWithDimensions( + stream, normalize_desc, dimensions_desc, input_data, &output_data); OP_REQUIRES(context, status, errors::Internal("NormalizeWithDimensions launch failed")); @@ -517,12 +517,13 @@ struct LaunchLRNGrad { auto* stream = context->op_device_context()->stream(); OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); - bool status = - stream - ->ThenNormalizeBackwardWithDimensions( - normalize_desc, dimensions_desc, input_image_data, - output_image_data, input_grads_data, &output_grads_data) - .ok(); + auto dnn = stream->parent()->AsDnn(); + OP_REQUIRES(context, dnn != nullptr, + absl::InternalError("No DNN support for stream.")); + bool status = dnn->DoNormalizeBackwardWithDimensions( + stream, normalize_desc, dimensions_desc, input_image_data, + output_image_data, input_grads_data, &output_grads_data, + /*workspace_allocator=*/nullptr); OP_REQUIRES( context, status, errors::Internal("NormalizeBackwardWithDimensions launch failed")); @@ -616,12 +617,13 @@ struct LaunchLRNGrad { DnnScratchAllocator scratch_allocator(NormalizeBackwardScratchSize, context); - bool status = stream - ->ThenNormalizeBackwardWithDimensions( - normalize_desc, dimensions_desc, input_image_data, - output_image_data, input_grads_data, - &output_grads_data, &scratch_allocator) - .ok(); + auto dnn = stream->parent()->AsDnn(); + OP_REQUIRES(context, dnn != nullptr, + absl::InternalError("No DNN support for stream.")); + bool status = dnn->DoNormalizeBackwardWithDimensions( + stream, normalize_desc, dimensions_desc, input_image_data, + output_image_data, input_grads_data, &output_grads_data, + /*workspace_allocator=*/nullptr, &scratch_allocator); OP_REQUIRES( context, status, errors::Internal("NormalizeBackwardWithDimensions launch failed")); diff --git a/tensorflow/core/kernels/map_kernels.h b/tensorflow/core/kernels/map_kernels.h index 6a05762983d627..ad01ef15932661 100644 --- a/tensorflow/core/kernels/map_kernels.h +++ b/tensorflow/core/kernels/map_kernels.h @@ -35,7 +35,7 @@ inline Status GetInputMap(OpKernelContext* ctx, int index, ctx->input(index).scalar()().DebugString(), "'"); } *ret_map = map; - return OkStatus(); + return absl::OkStatus(); } // TODO(kattian): change into templated function @@ -62,7 +62,7 @@ inline Status ForwardInputOrCreateNewMap(OpKernelContext* ctx, // Woohoo, forwarding succeeded! ctx->set_output(output_index, *output_tensor); *output_map = tmp_out; - return OkStatus(); + return absl::OkStatus(); } } @@ -75,7 +75,7 @@ inline Status ForwardInputOrCreateNewMap(OpKernelContext* ctx, output_tensor->scalar()() = input_map.Copy(); *output_map = output_tensor->scalar()().get(); - return OkStatus(); + return absl::OkStatus(); } class EmptyTensorMap : public OpKernel { @@ -240,14 +240,14 @@ Status TensorMapBinaryAdd(OpKernelContext* ctx, const TensorMap& a, out->tensors().emplace(p.first, p.second); } } - return OkStatus(); + return absl::OkStatus(); } template Status TensorMapZerosLike(OpKernelContext* ctx, const TensorMap& x, TensorMap* y) { // Zeros like returns an empty map. - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/kernels/map_stage_op.cc b/tensorflow/core/kernels/map_stage_op.cc index d60d3435ff5294..bf16340081466a 100644 --- a/tensorflow/core/kernels/map_stage_op.cc +++ b/tensorflow/core/kernels/map_stage_op.cc @@ -171,7 +171,7 @@ class StagingMap : public ResourceBase { "' was out of bounds '", dtypes_.size(), "'.")); } - return OkStatus(); + return absl::OkStatus(); } Status copy_or_move_tensors(OptionalTuple* map_tuple, const Tensor& key, @@ -203,7 +203,7 @@ class StagingMap : public ResourceBase { } } - return OkStatus(); + return absl::OkStatus(); } // Check that the optional value at the specified index @@ -218,7 +218,7 @@ class StagingMap : public ResourceBase { dtypes_.size(), "'."); } - return OkStatus(); + return absl::OkStatus(); } // Check that the indices are strictly ordered @@ -237,7 +237,7 @@ class StagingMap : public ResourceBase { return errors::InvalidArgument("Indices are not strictly ordered"); } - return OkStatus(); + return absl::OkStatus(); } // Check bytes are within memory limits memory limits @@ -250,7 +250,7 @@ class StagingMap : public ResourceBase { "'."); } - return OkStatus(); + return absl::OkStatus(); } // Insert incomplete data into the Barrier @@ -327,7 +327,7 @@ class StagingMap : public ResourceBase { } } - return OkStatus(); + return absl::OkStatus(); } // Does the insertion into the actual staging area @@ -338,7 +338,7 @@ class StagingMap : public ResourceBase { notify_removers(); - return OkStatus(); + return absl::OkStatus(); } public: @@ -376,7 +376,7 @@ class StagingMap : public ResourceBase { // Update the current size current_bytes_ += tuple_bytes; - return OkStatus(); + return absl::OkStatus(); } Status get(const KeyType* key, const Tensor* indices, Tuple* tuple) { @@ -398,7 +398,7 @@ class StagingMap : public ResourceBase { // Update bytes in the Staging Area current_bytes_ -= get_tuple_bytes(*tuple); - return OkStatus(); + return absl::OkStatus(); } Status pop(const KeyType* key, const Tensor* indices, Tuple* tuple) { @@ -429,7 +429,7 @@ class StagingMap : public ResourceBase { notify_inserters_if_bounded(); - return OkStatus(); + return absl::OkStatus(); } Status popitem(KeyType* key, const Tensor* indices, Tuple* tuple) { @@ -464,7 +464,7 @@ class StagingMap : public ResourceBase { notify_inserters_if_bounded(); - return OkStatus(); + return absl::OkStatus(); } Status clear() { @@ -475,7 +475,7 @@ class StagingMap : public ResourceBase { notify_inserters_if_bounded(); - return OkStatus(); + return absl::OkStatus(); } std::size_t incomplete_size() { @@ -506,13 +506,13 @@ Status GetStagingMap(OpKernelContext* ctx, const NodeDef& ndef, TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "capacity", &capacity)); TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "memory_limit", &memory_limit)); *ret = new StagingMap(dtypes, capacity, memory_limit); - return OkStatus(); + return absl::OkStatus(); }; TF_RETURN_IF_ERROR(cinfo.Init(rm, ndef, true /* use name() */)); TF_RETURN_IF_ERROR(rm->LookupOrCreate>( cinfo.container(), cinfo.name(), map, create_fn)); - return OkStatus(); + return absl::OkStatus(); } template diff --git a/tensorflow/core/kernels/matmul_op_fused.cc b/tensorflow/core/kernels/matmul_op_fused.cc index f937a06016dc8c..e38b854db4ec19 100644 --- a/tensorflow/core/kernels/matmul_op_fused.cc +++ b/tensorflow/core/kernels/matmul_op_fused.cc @@ -361,7 +361,11 @@ StatusOr> AutotuneFusedMatmul( std::vector> runners; auto element_type = se::dnn::ToDataType::value; - TF_RETURN_IF_ERROR(stream->parent()->GetFusedMatmulRunners( + auto dnn = stream->parent()->AsDnn(); + if (dnn == nullptr) { + return errors::Internal("No DNN in stream executor."); + } + TF_RETURN_IF_ERROR(dnn->GetFusedMatmulRunners( CudnnUseFrontend(), element_type, element_type, element_type, stream, trans_a, trans_b, m, n, k, lda, ldb, ldc, activation_mode, /*use_fallback=*/false, GetNumericOptions(), &runners)); @@ -405,7 +409,7 @@ StatusOr> AutotuneFusedMatmul( << params.ToString(); std::vector> fallback_runners; - TF_RETURN_IF_ERROR(stream->parent()->GetFusedMatmulRunners( + TF_RETURN_IF_ERROR(dnn->GetFusedMatmulRunners( CudnnUseFrontend(), element_type, element_type, element_type, stream, trans_a, trans_b, m, n, k, lda, ldb, ldc, activation_mode, /*use_fallback=*/true, GetNumericOptions(), &fallback_runners)); diff --git a/tensorflow/core/kernels/matmul_op_impl.h b/tensorflow/core/kernels/matmul_op_impl.h index 7180fb1d4e35f9..9c5fe075d97ff5 100644 --- a/tensorflow/core/kernels/matmul_op_impl.h +++ b/tensorflow/core/kernels/matmul_op_impl.h @@ -720,15 +720,15 @@ struct LaunchBatchMatMul { } BlasScratchAllocator scratch_allocator(context, max_scratch_size); - bool blas_launch_status = - stream - ->ThenBlasGemmBatchedWithScratch( - blas_transpose_b, blas_transpose_a, n, m, k, - static_cast(1.0), b_ptrs, - adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k, - static_cast(0.0), c_ptrs, n, batch_size, - GetNumericOptions(), &scratch_allocator, call_context) - .ok(); + auto blas = stream->parent()->AsBlas(); + OP_REQUIRES(context, blas != nullptr, + absl::InternalError("No blas support for stream")); + bool blas_launch_status = blas->DoBlasGemmBatched( + stream, blas_transpose_b, blas_transpose_a, n, m, k, + static_cast(1.0), b_ptrs, adj_y || trans_y ? k : n, + a_ptrs, adj_x || trans_x ? m : k, static_cast(0.0), + c_ptrs, n, batch_size, GetNumericOptions(), &scratch_allocator, + call_context); if (!blas_launch_status) { context->SetStatus(errors::Internal( "Blas xGEMMBatched launch failed: a.shape=", @@ -785,6 +785,9 @@ struct LaunchBatchMatMul { // C' = B' x A', where ' stands for transpose (not adjoint). // TODO(yangzihao): Choose the best of the three strategies using // autotune. + auto blas = stream->parent()->AsBlas(); + OP_REQUIRES(context, blas != nullptr, + absl::InternalError("No blas support for stream")); if (batch_size == 1) { // This is a regular matrix*matrix or matrix*vector multiply. Avoid the // overhead of the scratch allocator and the batch interface. @@ -803,14 +806,11 @@ struct LaunchBatchMatMul { blas_transpose_a == se::blas::Transpose::kTranspose ? se::blas::Transpose::kNoTranspose : se::blas::Transpose::kTranspose; - bool blas_launch_status = - stream - ->ThenBlasGemv(gemv_trans_a, adj_x || trans_x ? m : k, - adj_x || trans_x ? k : m, - static_cast(1.0), *(a_ptrs[0]), - adj_x || trans_x ? m : k, *(b_ptrs[0]), 1, - static_cast(0.0), c_ptrs[0], 1) - .ok(); + bool blas_launch_status = blas->DoBlasGemv( + stream, gemv_trans_a, adj_x || trans_x ? m : k, + adj_x || trans_x ? k : m, static_cast(1.0), + *(a_ptrs[0]), adj_x || trans_x ? m : k, *(b_ptrs[0]), 1, + static_cast(0.0), c_ptrs[0], 1); if (!blas_launch_status) { context->SetStatus(errors::Internal( "Blas xGEMV launch failed : a.shape=", @@ -821,16 +821,16 @@ struct LaunchBatchMatMul { } } - OP_REQUIRES_OK(context, - stream->ThenBlasGemm( - blas_transpose_b, blas_transpose_a, n, m, k, + OP_REQUIRES_OK( + context, + blas->BlasGemm(stream, blas_transpose_b, blas_transpose_a, n, m, k, *(b_ptrs[0]), adj_y || trans_y ? k : n, *(a_ptrs[0]), adj_x || trans_x ? m : k, c_ptrs[0], n, GetNumericOptions(), call_context)); } else if (use_strided_batched) { OP_REQUIRES_OK( - context, stream->ThenBlasGemmStridedBatched( - blas_transpose_b, blas_transpose_a, n, m, k, + context, blas->BlasGemmStridedBatched( + stream, blas_transpose_b, blas_transpose_a, n, m, k, static_cast(1.0), *b_ptrs[0], adj_y || trans_y ? k : n, b_stride, *a_ptrs[0], adj_x || trans_x ? m : k, a_stride, @@ -838,15 +838,12 @@ struct LaunchBatchMatMul { batch_size, GetNumericOptions(), call_context)); } else { BlasScratchAllocator scratch_allocator(context); - bool blas_launch_status = - stream - ->ThenBlasGemmBatchedWithScratch( - blas_transpose_b, blas_transpose_a, n, m, k, - static_cast(1.0), b_ptrs, - adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k, - static_cast(0.0), c_ptrs, n, batch_size, - GetNumericOptions(), &scratch_allocator, call_context) - .ok(); + bool blas_launch_status = blas->DoBlasGemmBatched( + stream, blas_transpose_b, blas_transpose_a, n, m, k, + static_cast(1.0), b_ptrs, adj_y || trans_y ? k : n, + a_ptrs, adj_x || trans_x ? m : k, static_cast(0.0), + c_ptrs, n, batch_size, GetNumericOptions(), &scratch_allocator, + call_context); if (!blas_launch_status) { context->SetStatus(errors::Internal( "Blas xGEMMBatched launch failed : a.shape=", @@ -1084,7 +1081,7 @@ class BatchMatMulOp : public BaseBatchMatMulOp { } } } - return OkStatus(); + return absl::OkStatus(); } }; @@ -1110,7 +1107,7 @@ class BatchMatMulV2Op : public BaseBatchMatMulOp { if (in1.dims() < 2) { return errors::InvalidArgument("In[1] ndims must be >= 2: ", in1.dims()); } - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/core/kernels/matmul_op_test.cc b/tensorflow/core/kernels/matmul_op_test.cc index 96c37ac97817b8..897d8fd1772b07 100644 --- a/tensorflow/core/kernels/matmul_op_test.cc +++ b/tensorflow/core/kernels/matmul_op_test.cc @@ -55,7 +55,7 @@ class FusedMatMulOpTest : public OpsTestBase { void RunAndFetch(const tensorflow::Scope& root, const string& fetch, Tensor* output, bool allow_gpu_device, const NodeDef* fetch_node = nullptr, - tsl::Status* last_status = nullptr) { + absl::Status* last_status = nullptr) { tensorflow::GraphDef graph; TF_ASSERT_OK(root.ToGraphDef(&graph)); @@ -208,7 +208,7 @@ class FusedMatMulOpTest : public OpsTestBase { .Attr("transpose_b", transpose_b) .Finalize(&fused_matmul)); - tsl::Status last_status; + absl::Status last_status; RunAndFetch(root, fused_matmul.name(), output, allow_gpu_device, &fused_matmul, &last_status); diff --git a/tensorflow/core/kernels/mkl/BUILD b/tensorflow/core/kernels/mkl/BUILD index 72d1971c501d15..6e83cb8afef9da 100644 --- a/tensorflow/core/kernels/mkl/BUILD +++ b/tensorflow/core/kernels/mkl/BUILD @@ -101,32 +101,6 @@ tf_mkl_kernel_library( ] + MKL_DEPS, ) -tf_mkl_kernel_library( - name = "mkl_sparse_matrix_matmul_op", - srcs = [ - "mkl_sparse_matrix_matmul_op.cc", - ], - hdrs = [ - "mkl_kernel_util.h", - "mkl_matmul_ops_common.h", - ], - deps = [ - "//tensorflow/core/kernels:cwise_op", - "//tensorflow/core/kernels:dense_update_functor", - "//tensorflow/core/kernels/sparse:kernels", - ] + MKL_DEPS, -) - -tf_cc_test_mkl( - name = "mkl_sparse_matrix_matmul_op_benchmark", - size = "small", - srcs = ["mkl_sparse_matrix_matmul_op_benchmark.cc"], - linkstatic = 1, - deps = [ - "//tensorflow/core/kernels/mkl:mkl_sparse_matrix_matmul_op", - ] + MKL_TEST_DEPS, -) - tf_cc_test_mkl( name = "mkl_quantized_conv_ops_perchannel_test", size = "small", diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h index 7342e348a12ed4..6c5364979fa6df 100644 --- a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h +++ b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h @@ -1,4 +1,4 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #include "dnnl.hpp" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/core/kernels/mkl/mkl_kernel_util.h" #include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/onednn_env_vars.h" -#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) #include "tensorflow/core/platform/mutex.h" #endif @@ -776,7 +776,6 @@ struct MklMatMulParams { memory::dims a_strides; memory::dims b_strides; memory::dims c_strides; - memory::dim a_nnz; struct PostOpParam { string name; std::vector param; @@ -788,19 +787,17 @@ struct MklMatMulParams { MklMatMulParams(string prefix, memory::dims a_dims, memory::dims b_dims, memory::dims c_dims, memory::dims a_strides, - memory::dims b_strides, memory::dims c_strides, - memory::dim a_nnz = 0) + memory::dims b_strides, memory::dims c_strides) : prefix(prefix), a_dims(a_dims), b_dims(b_dims), c_dims(c_dims), a_strides(a_strides), b_strides(b_strides), - c_strides(c_strides), - a_nnz(a_nnz) {} + c_strides(c_strides) {} }; -template +template class MklMatMulPrimitive : public MklPrimitive { public: explicit MklMatMulPrimitive(const MklMatMulParams& params) @@ -814,12 +811,9 @@ class MklMatMulPrimitive : public MklPrimitive { dnnl::memory::desc GetScratchPadDesc() { return context_.prim_desc->scratchpad_desc(); } - void Execute(const std::shared_ptr& stream, const Tlhs* a_data, const Trhs* b_data, const Toutput* c_data, void* sp_data, - void* mul_data = nullptr, void* add_data = nullptr, - const int32_t* a_col_indices = nullptr, - const int32_t* a_row_pointers = nullptr) { + void* mul_data = nullptr, void* add_data = nullptr) { #if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex_lock lock(primitive_execution_mu_); #endif @@ -830,29 +824,20 @@ class MklMatMulPrimitive : public MklPrimitive { static_cast(const_cast(b_data)), *stream); context_.c_mem->set_data_handle( static_cast(const_cast(c_data)), *stream); + context_.sp_mem->set_data_handle(sp_data, *stream); - if (sp_data != nullptr) context_.sp_mem->set_data_handle(sp_data, *stream); if (mul_data != nullptr) context_.mul_mem->set_data_handle(mul_data, *stream); if (add_data != nullptr) context_.add_mem->set_data_handle(add_data, *stream); #else - if constexpr (CSR) { - context_.a_mem->set_data_handle( - static_cast(const_cast(a_data)), 0); - context_.a_mem->set_data_handle( - static_cast(const_cast(a_col_indices)), 1); - context_.a_mem->set_data_handle( - static_cast(const_cast(a_row_pointers)), 2); - } else { - context_.a_mem->set_data_handle( - static_cast(const_cast(a_data))); - } + context_.a_mem->set_data_handle( + static_cast(const_cast(a_data))); context_.b_mem->set_data_handle( static_cast(const_cast(b_data))); context_.c_mem->set_data_handle( static_cast(const_cast(c_data))); - if (sp_data != nullptr) context_.sp_mem->set_data_handle(sp_data); + context_.sp_mem->set_data_handle(sp_data); if (mul_data != nullptr) context_.mul_mem->set_data_handle(mul_data); if (add_data != nullptr) context_.add_mem->set_data_handle(add_data); #endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 @@ -862,7 +847,7 @@ class MklMatMulPrimitive : public MklPrimitive { context_.a_mem->set_data_handle(DummyData); context_.b_mem->set_data_handle(DummyData); context_.c_mem->set_data_handle(DummyData); - if (sp_data != nullptr) context_.sp_mem->set_data_handle(DummyData); + context_.sp_mem->set_data_handle(DummyData); if (mul_data != nullptr) context_.mul_mem->set_data_handle(DummyData); if (add_data != nullptr) context_.add_mem->set_data_handle(DummyData); } @@ -922,16 +907,8 @@ class MklMatMulPrimitive : public MklPrimitive { std::shared_ptr matmul_primitive = nullptr; // Create MatMul descriptor and primitive descriptor. - if constexpr (CSR) { - // If it's a CSR matrix. - const auto tmp = memory::desc::csr( - params.a_dims, MklDnnType(), params.a_nnz, - dnnl::memory::data_type::s32, dnnl::memory::data_type::s32); - context_.a_md.reset(new memory::desc(tmp)); - } else { - context_.a_md.reset(new memory::desc({params.a_dims}, MklDnnType(), - params.a_strides)); - } + context_.a_md.reset(new memory::desc({params.a_dims}, MklDnnType(), + params.a_strides)); context_.b_md.reset(new memory::desc({params.b_dims}, MklDnnType(), #ifdef DNNL_AARCH64_USE_ACL @@ -990,13 +967,8 @@ class MklMatMulPrimitive : public MklPrimitive { #endif // !ENABLE_ONEDNN_V3 // Create memory primitive based on dummy data. - if constexpr (CSR) { - context_.a_mem.reset(new dnnl::memory(*context_.a_md, cpu_engine_, - std::vector(3, DummyData))); - } else { - context_.a_mem.reset( - new dnnl::memory(*context_.a_md, cpu_engine_, DummyData)); - } + context_.a_mem.reset( + new dnnl::memory(*context_.a_md, cpu_engine_, DummyData)); #ifdef DNNL_AARCH64_USE_ACL context_.b_mem.reset(new dnnl::memory( context_.prim_desc.get()->weights_desc(), cpu_engine_, DummyData)); @@ -1047,25 +1019,24 @@ class MklMatMulPrimitive : public MklPrimitive { #endif }; -template +template class MklMatMulPrimitiveFactory : public MklPrimitiveFactory { public: - static MklMatMulPrimitive* Get( + static MklMatMulPrimitive* Get( const MklMatMulParams& params, bool do_not_cache) { - MklMatMulPrimitive* matmul_prim = nullptr; + MklMatMulPrimitive* matmul_prim = nullptr; if (do_not_cache) { // Always create new primitive - matmul_prim = new MklMatMulPrimitive(params); + matmul_prim = new MklMatMulPrimitive(params); } else { // Try to find a suitable one in pool - matmul_prim = dynamic_cast*>( - MklMatMulPrimitiveFactory::GetInstance() + matmul_prim = dynamic_cast*>( + MklMatMulPrimitiveFactory::GetInstance() .GetMklMatMul(params)); if (matmul_prim == nullptr) { - matmul_prim = new MklMatMulPrimitive(params); - MklMatMulPrimitiveFactory::GetInstance() + matmul_prim = new MklMatMulPrimitive(params); + MklMatMulPrimitiveFactory::GetInstance() .SetMklMatMul(params, matmul_prim); } } diff --git a/tensorflow/core/kernels/mkl/mkl_sparse_matrix_matmul_op.cc b/tensorflow/core/kernels/mkl/mkl_sparse_matrix_matmul_op.cc deleted file mode 100644 index 029deec8f576a9..00000000000000 --- a/tensorflow/core/kernels/mkl/mkl_sparse_matrix_matmul_op.cc +++ /dev/null @@ -1,226 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) - -#define EIGEN_USE_THREADS - -#include "Eigen/Core" -#include "Eigen/SparseCore" -#include "dnnl.hpp" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/framework/type_traits.h" -#include "tensorflow/core/framework/variant_op_registry.h" -#include "tensorflow/core/kernels/cwise_ops_common.h" -#include "tensorflow/core/kernels/dense_update_functor.h" -#include "tensorflow/core/kernels/fill_functor.h" -#include "tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h" -#include "tensorflow/core/kernels/sparse/kernels.h" -#include "tensorflow/core/kernels/sparse/mat_mul_op.h" -#include "tensorflow/core/kernels/sparse/sparse_matrix.h" -#include "tensorflow/core/kernels/sparse/transpose_op.h" -#include "tensorflow/core/kernels/transpose_functor.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" -#include "tensorflow/core/platform/threadpool.h" -#include "tensorflow/core/util/mkl_util.h" -#include "unsupported/Eigen/CXX11/Tensor" - -using dnnl::stream; - -namespace tensorflow { - -typedef Eigen::ThreadPoolDevice CPUDevice; - -// Implements a kernel which, given a SparseMatrix `a` and dense Tensor `b`, -// computes a dense Tensor `c` satisfying `c = a * b` where * denotes matrix -// multiplication. -// -// The rank of both `a` and `b` must be equal and their shapes must be -// compatible for matrix multiplication. Otherwise, InvalidArgument runtime -// errors will be thrown. Only inputs of rank 2 are supported. -// -template -class MklSparseMatrixMatMulOp : public MklDnnMatMulOpBase { - private: - tensorflow::CSRMatMulCPUOp eigen_sparse_matmul_op_; - - public: - explicit MklSparseMatrixMatMulOp(OpKernelConstruction* ctx) - : MklDnnMatMulOpBase(ctx), eigen_sparse_matmul_op_(ctx) {} - - // Throws errors if there are issues with the input. - Status ValidateInputs(const CSRSparseMatrix& sparse_matrix_a, - const Tensor& dense_tensor_b, int* rank) { - // Validate datatypes. - if (sparse_matrix_a.dtype() != dense_tensor_b.dtype()) { - return absl::InvalidArgumentError(absl::StrCat( - "Input types don't match. a.dtype == ", - DataTypeString(sparse_matrix_a.dtype()), - " vs. b.dtype == ", DataTypeString(dense_tensor_b.dtype()))); - } - - // Validate the ranks. - *rank = sparse_matrix_a.dims(); - if (*rank != dense_tensor_b.dims()) { - return absl::InvalidArgumentError( - absl::StrCat("Ranks of a and b must match, saw: ", *rank, " vs. ", - dense_tensor_b.dims(), ".")); - } - - // Validate shapes. - const auto& a_dense_shape = sparse_matrix_a.dense_shape().vec(); - const int64_t a_inner_dim = a_dense_shape(*rank - 1); - const int64_t b_inner_dim = dense_tensor_b.dim_size(*rank - 2); - if (a_inner_dim != b_inner_dim) { - return absl::InvalidArgumentError( - absl::StrCat("Inner product dimensions of A and B do not agree. ", - "Shapes are: ", TensorShape(a_dense_shape).DebugString(), - " vs. ", dense_tensor_b.shape().DebugString())); - } - - Status s = sparse_matrix_a.Validate(); - return s; - } - - // Determine if we should call the Eigen kernel as a fallback. - bool ShouldCallEigenFallback(const CSRSparseMatrix& sparse_matrix_a, - const Tensor& dense_tensor_b, int rank) { - if (sparse_matrix_a.dtype() != DT_FLOAT) { - VLOG(1) << "sparse_matrix_a.dtype() is not DT_FLOAT"; - return true; - } - if (rank != 2) { - VLOG(1) << "rank is not 2, but " << rank << " instead."; - return true; - } - - return false; - } - - void Compute(OpKernelContext* ctx) override { - // Try to catch any exceptions during the matmul itself. - try { - // Handle the input. - const CSRSparseMatrix* sparse_matrix_a; - OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &sparse_matrix_a)); - const Tensor& rhs_tensor = ctx->input(1); - Tensor* output_tensor = nullptr; - - int rank; - OP_REQUIRES_OK(ctx, - this->ValidateInputs(*sparse_matrix_a, rhs_tensor, &rank)); - - const auto dense_shape = sparse_matrix_a->dense_shape().vec(); - if (ShouldCallEigenFallback(*sparse_matrix_a, rhs_tensor, rank)) { - return eigen_sparse_matmul_op_.Compute(ctx); - return; - } - - // Dimensions of the matrices. - int64_t num_lhs_rows = dense_shape(rank - 2); - int64_t num_lhs_cols = dense_shape(rank - 1); - int64_t num_rhs_rows = rhs_tensor.dim_size(rank - 2); - int64_t num_rhs_cols = rhs_tensor.dim_size(rank - 1); - memory::dims lhs_dims = memory::dims({num_lhs_rows, num_lhs_cols}); - memory::dims rhs_dims = memory::dims({num_rhs_rows, num_rhs_cols}); - memory::dims output_dims = memory::dims({num_lhs_rows, num_rhs_cols}); - - // Choose the datatype. - const float* lhs_data; - dnnl::memory::data_type lhs_datatype; - switch (sparse_matrix_a->dtype()) { - case DT_FLOAT: - lhs_data = sparse_matrix_a->values().flat().data(); - lhs_datatype = dnnl::memory::data_type::f32; - break; - default: - OP_REQUIRES(ctx, sparse_matrix_a->dtype() == DT_FLOAT, - absl::InvalidArgumentError(absl::StrCat( - "MklSparseMatrixMatMulOp got an unexpected data ", - "type for sparse-matrix input."))); - } - - // Get the oneDNN primitive. - string prefix = "sparsecsrmatmul"; - MklMatMulParams matmul_params(prefix, lhs_dims, rhs_dims, output_dims, - dnnl::memory::dims(), dnnl::memory::dims(), - dnnl::memory::dims(), - sparse_matrix_a->total_nnz()); - MklMatMulPrimitive* matmul_prim = - MklMatMulPrimitiveFactory::Get(matmul_params, 0); - - // Threading. - auto st = ExecuteSingleThreadedGemm(num_lhs_rows, num_rhs_rows, - num_rhs_cols, sizeof(T)); - Eigen::ThreadPoolInterface* eigen_interface = - EigenThreadPoolFromTfContext(ctx); - tsl::OneDnnThreadPool eigen_tp(eigen_interface, - ThreadPoolUseCallerThread(), st ? 1 : -1); - - // Get the cached primitive. - std::shared_ptr matmul_pd = - matmul_prim->GetPrimitiveDesc(); - - // Allocate room for the result. - TensorShape output_tf_shape({num_lhs_rows, num_rhs_cols}); - OP_REQUIRES_OK(ctx, - ctx->allocate_output(0, output_tf_shape, &output_tensor)); - - T* rhs_data = const_cast(rhs_tensor.flat().data()); - T* output_data = const_cast(output_tensor->flat().data()); - MklDnnData lhs_mkl(&(this->cpu_engine_)); - MklDnnData rhs_mkl(&(this->cpu_engine_)); - - // CPU stream. - std::shared_ptr cpu_stream; - cpu_stream.reset(CreateStream(&eigen_tp, matmul_prim->GetEngine())); - - // Allocate a scratchpad. - UserScratchPad scratch_pad; - scratch_pad.AllocateSPTensor(matmul_prim, ctx); - - // Execute the actual matmul. - matmul_prim->Execute( - cpu_stream, lhs_data, rhs_data, output_data, scratch_pad.Get(), - nullptr, nullptr, - sparse_matrix_a->col_indices().flat().data(), - sparse_matrix_a->row_pointers().flat().data()); - } catch (dnnl::error& e) { - OP_REQUIRES_OK( - ctx, - absl::AbortedError(absl::StrCat( - "Operation received an exception:", "Status: ", - std::to_string(e.status), ", message: ", string(e.message), - ", in file ", string(__FILE__), ":", std::to_string(__LINE__)))); - } - } -}; - -#define REGISTER_CPU(T) \ - REGISTER_KERNEL_BUILDER(Name("_MklNativeSparseMatrixMatMul") \ - .Device(DEVICE_CPU) \ - .Label(mkl_op_registry::kMklNameChangeOpLabel) \ - .TypeConstraint("T"), \ - MklSparseMatrixMatMulOp); - -REGISTER_CPU(float) - -#undef REGISTER_CPU - -} // namespace tensorflow - -#endif // INTEL_MKL && ENABLE_ONEDNN_V3 diff --git a/tensorflow/core/kernels/mkl/mkl_sparse_matrix_matmul_op_benchmark.cc b/tensorflow/core/kernels/mkl/mkl_sparse_matrix_matmul_op_benchmark.cc deleted file mode 100644 index 20ef908e9b9521..00000000000000 --- a/tensorflow/core/kernels/mkl/mkl_sparse_matrix_matmul_op_benchmark.cc +++ /dev/null @@ -1,199 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifdef INTEL_MKL - -#include - -#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/graph/mkl_testlib.h" -#include "tensorflow/core/graph/node_builder.h" -#include "tensorflow/core/kernels/sparse/kernels.h" -#include "tensorflow/core/kernels/sparse/sparse_matrix.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/test_benchmark.h" - -namespace tensorflow { -namespace { - -static Graph *SparseMatrixMatmulGenerate(int nnz, int m, int k, int n, - Tensor **csr_matrix_t, - Tensor **dense_matrix_t) { - Graph *g = new Graph(OpRegistry::Global()); - CSRSparseMatrix csr_matrix; - - // Generate the random COO matrix. - Tensor a_values_t(DT_FLOAT, TensorShape({nnz})); - Tensor a_indices_t(DT_INT64, TensorShape({nnz, 2})); - Tensor a_shape_t(DT_INT64, TensorShape({2})); - auto a_shape_vec = a_shape_t.vec(); - a_shape_vec(0) = m; - a_shape_vec(1) = k; - a_values_t.flat().setRandom(); - auto a_indices_mat = a_indices_t.matrix(); - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution<> a_lhs_dist(0, a_shape_vec(0) - 1); - std::uniform_int_distribution<> a_rhs_dist(0, a_shape_vec(1) - 1); - for (int32_t i = 0; i < nnz; ++i) { - a_indices_mat(i, 0) = (const int64_t)a_lhs_dist(gen); - a_indices_mat(i, 1) = (const int64_t)a_rhs_dist(gen); - } - - // Calculate some constants for the conversion. - const int64_t batch_size = 1; - const int num_rows = a_shape_vec(0); - const int num_cols = a_shape_vec(1); - - // Allocate memory for the output CSR. - Tensor csr_batch_pointers(DT_INT32, TensorShape({batch_size + 1})); - Tensor csr_column_indices(DT_INT32, TensorShape({nnz})); - Tensor csr_row_pointers(DT_INT32, TensorShape({(num_rows + 1) * batch_size})); - - // Cast the indices matrix to const. - auto a_indices_mat_const = std::as_const(a_indices_t).matrix(); - - // Zero out the row pointers. - memset(csr_row_pointers.flat().data(), 0, - (num_rows + 1) * batch_size * sizeof(int32)); - - // Convert from COO to CSR. - functor::SparseTensorToCSRSparseMatrixCPUFunctor coo_to_csr; - TF_CHECK_OK(coo_to_csr(batch_size, num_rows, num_cols, a_indices_mat_const, - csr_batch_pointers.vec(), - csr_row_pointers.vec(), - csr_column_indices.vec())); - - // Construct a CSRSparseMatrix. - TF_CHECK_OK(CSRSparseMatrix::CreateCSRSparseMatrix( - DT_FLOAT, a_shape_t, csr_batch_pointers, csr_row_pointers, - csr_column_indices, a_values_t, &csr_matrix)); - *csr_matrix_t = new Tensor(cpu_allocator(), DT_VARIANT, TensorShape({})); - (*csr_matrix_t)->scalar()() = std::move(csr_matrix); - - // Generate the dense tensor to multiply against. - *dense_matrix_t = new Tensor(DT_FLOAT, TensorShape({k, n})); - (*dense_matrix_t)->flat().setRandom(); - - return g; -} - -static Graph *SparseMatrixMatmul(const string &kind, Graph *g, - Tensor *csr_matrix_t, Tensor *dense_matrix_t) { - const bool isDefault = (kind == "Default"); - Node *ret = nullptr; - - if (isDefault) { - TF_CHECK_OK(NodeBuilder(g->NewName("n1"), "SparseMatrixMatMul") - .Input(test::graph::Constant(g, *csr_matrix_t)) - .Input(test::graph::Constant(g, *dense_matrix_t)) - .Attr("T", DT_FLOAT) - .Finalize(g, &ret)); - } else { - test::graph::oneDNNSparseCSRMatmul( - g, test::graph::Constant(g, *csr_matrix_t), - test::graph::Constant(g, *dense_matrix_t)); - } - return g; -} - -// NOLINTBEGIN -#define BM_SparseMatrixMatmulDev(kind, NNZ, M, K, N, DEVICE) \ - static void BM_SparseMatrixMatmul_##kind##NNZ##_##M##_##K##_##N##_##DEVICE( \ - ::testing::benchmark::State &state) { \ - Tensor *csr_matrix_t, *dense_matrix_t; \ - Graph *g; \ - int64_t items_per_iter = (static_cast(NNZ) * N); \ - g = SparseMatrixMatmulGenerate(NNZ, M, K, N, &csr_matrix_t, \ - &dense_matrix_t); \ - test::Benchmark( \ - #DEVICE, SparseMatrixMatmul(#kind, g, csr_matrix_t, dense_matrix_t), \ - /*old_benchmark_api*/ false) \ - .Run(state); \ - state.SetItemsProcessed(state.iterations() * items_per_iter); \ - state.SetBytesProcessed(state.iterations() * items_per_iter * \ - sizeof(float)); \ - } \ - BENCHMARK(BM_SparseMatrixMatmul_##kind##NNZ##_##M##_##K##_##N##_##DEVICE) \ - ->Arg(/* unused arg */ 1); -// NOLINTEND - -#define BM_SparseMatrixMatmul(NNZ, M, K, N) \ - BM_SparseMatrixMatmulDev(Default, NNZ, M, K, N, cpu); \ - BM_SparseMatrixMatmulDev(Mkl, NNZ, M, K, N, cpu); - -BM_SparseMatrixMatmul(128, 8, 512, 1); -BM_SparseMatrixMatmul(128, 16, 512, 1); -BM_SparseMatrixMatmul(128, 128, 512, 1); - -BM_SparseMatrixMatmul(128, 4096, 4096, 1); -BM_SparseMatrixMatmul(1024, 4096, 4096, 1); -BM_SparseMatrixMatmul(16384, 4096, 4096, 1); - -BM_SparseMatrixMatmul(128, 8, 1024, 16); -BM_SparseMatrixMatmul(128, 16, 1024, 16); -BM_SparseMatrixMatmul(128, 128, 1024, 16); -BM_SparseMatrixMatmul(128, 4096, 4096, 128); -BM_SparseMatrixMatmul(128, 4096, 4096, 1024); - -BM_SparseMatrixMatmul(1024, 8, 1024, 16); -BM_SparseMatrixMatmul(1024, 16, 1024, 16); -BM_SparseMatrixMatmul(1024, 128, 1024, 16); -BM_SparseMatrixMatmul(1024, 4096, 4096, 128); -BM_SparseMatrixMatmul(1024, 4096, 4096, 1024); - -BM_SparseMatrixMatmul(16384, 8, 1024, 16); -BM_SparseMatrixMatmul(16384, 16, 1024, 16); -BM_SparseMatrixMatmul(16384, 128, 1024, 16); -BM_SparseMatrixMatmul(16384, 4096, 4096, 128); -BM_SparseMatrixMatmul(16384, 4096, 4096, 1024); - -BM_SparseMatrixMatmul(16384, 4096, 4096, 4096); - -// The big ones. -BM_SparseMatrixMatmul(100, 1, 1000000, 100); -BM_SparseMatrixMatmul(200, 1, 2000000, 100); -BM_SparseMatrixMatmul(400, 1, 4000000, 100); - -BM_SparseMatrixMatmul(400, 4, 1000000, 100); -BM_SparseMatrixMatmul(800, 4, 2000000, 100); -BM_SparseMatrixMatmul(1600, 4, 4000000, 100); - -BM_SparseMatrixMatmul(800, 8, 1000000, 100); -BM_SparseMatrixMatmul(1600, 8, 2000000, 100); -BM_SparseMatrixMatmul(3200, 8, 4000000, 100); - -// The bigger ones. -// BM_SparseMatrixMatmul(100, 1, 1000000, 1000); -// BM_SparseMatrixMatmul(200, 1, 2000000, 1000); -// BM_SparseMatrixMatmul(400, 1, 4000000, 1000); - -// BM_SparseMatrixMatmul(400, 4, 1000000, 1000); -// BM_SparseMatrixMatmul(800, 4, 2000000, 1000); -// BM_SparseMatrixMatmul(1600, 4, 4000000, 1000); - -// BM_SparseMatrixMatmul(800, 8, 1000000, 1000); -// BM_SparseMatrixMatmul(1600, 8, 2000000, 1000); -// BM_SparseMatrixMatmul(3200, 8, 4000000, 1000); - -} // namespace -} // end namespace tensorflow - -#endif diff --git a/tensorflow/core/kernels/mlir_generated/BUILD b/tensorflow/core/kernels/mlir_generated/BUILD index cb43658770bc9d..a59140af636cf6 100644 --- a/tensorflow/core/kernels/mlir_generated/BUILD +++ b/tensorflow/core/kernels/mlir_generated/BUILD @@ -1942,7 +1942,6 @@ gpu_kernel_library( "c64", "c128", ], - max_supported_rank = 8, op = "select_v2", tile_size = "256", types = [], diff --git a/tensorflow/core/kernels/mlir_generated/build_defs.bzl b/tensorflow/core/kernels/mlir_generated/build_defs.bzl index 3ed387826e1f0d..e6c1c6a93b2446 100644 --- a/tensorflow/core/kernels/mlir_generated/build_defs.bzl +++ b/tensorflow/core/kernels/mlir_generated/build_defs.bzl @@ -158,7 +158,6 @@ def _gen_kernel_bin_impl(ctx): executable = ctx.executable._tool, arguments = cmd_args + [ "--tile_sizes=%s" % ctx.attr.tile_size, - "--max-supported-rank=%s" % ctx.attr.max_supported_rank, "--host-triple=%s" % ctx.attr.host_triple, "--arch=%s" % ",".join(ctx.attr.gpu_archs), "--input=%s" % ctx.file.mlir_op.path, @@ -194,7 +193,6 @@ _gen_kernel_bin_rule = rule( "data_type": attr.string(mandatory = True), "tile_size": attr.string(mandatory = True), "unroll_factors": attr.string(), - "max_supported_rank": attr.int(), "host_triple": attr.string(mandatory = True), "gpu_archs": attr.string_list(), "jit": attr.bool(), @@ -253,7 +251,6 @@ def _gen_kernel_library( platform, tile_size, tile_size_override = {}, - max_supported_rank = 5, output_types = [], jit_types = [], output_jit_types = [], @@ -275,7 +272,6 @@ def _gen_kernel_library( types: The types ("f16", "f32", "f64") for which a kernel should be generated. tile_size: The tiling specification, e.g. "16x16" or "16Bx16". tile_size_override: dict of type-specific tile_size. - max_supported_rank: Maximum supported rank for rank specialization. jit_types: The types ("f16", "f32", "f64") for which a kernel should be generated. These kernels are different in that they are only partially compiled and will be JIT compiled at execution time. @@ -372,7 +368,6 @@ def _gen_kernel_library( host_triple = host_triple, gpu_archs = gpu_archs, jit = jit, - max_supported_rank = max_supported_rank, mlir_op = "{op}_{name}_{platform}_{type}_{output_type}.mlir".format( op = op, name = name, diff --git a/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc b/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc index 561ca57c67ca11..7bea10ceabd737 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc @@ -1503,7 +1503,7 @@ GENERATE_DEFAULT_TESTS(Xlogy, /*test_name=*/Complex64, std::complex, test::OpsTestConfig().ATol(2e-6).RTol(2e-6)) GENERATE_DEFAULT_TESTS(Xlogy, /*test_name=*/Complex128, std::complex, std::complex, baseline_xlogy, - test::OpsTestConfig()) + test::OpsTestConfig().ATol(1e-12).RTol(1e-12)) /// Test `tf.Xlog1py`. diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/abs.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/abs.mlir.tmpl index 896c5850ff2bce..8e3acc17ae313c 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/abs.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/abs.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Abs_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.abs %arg0 : tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @Abs_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = mhlo.abs %2 : tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xelem_type> + return %4 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/acos.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/acos.mlir.tmpl index 1508141049b649..5c4a2789351012 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/acos.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/acos.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Acos_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.acos %arg0 : tensor<*xelem_type> -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @Acos_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = chlo.acos %2 : tensor -> tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xoutput_type> + return %4 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/acosh.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/acosh.mlir.tmpl index 2d5a7cca69cc42..237f906d37282f 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/acosh.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/acosh.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Acosh_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.acosh %arg0 : tensor<*xelem_type> -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @Acosh_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = chlo.acosh %2 : tensor -> tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xoutput_type> + return %4 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/add_v2.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/add_v2.mlir.tmpl index b113c33f2ed2a0..8681673198aad1 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/add_v2.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/add_v2.mlir.tmpl @@ -1,6 +1,129 @@ -func.func @AddV2_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_add %arg0, %arg1 - : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @AddV2_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xoutput_type>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_add %15, %16 : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xoutput_type>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %20 = chlo.broadcast_add %18, %19 : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xoutput_type>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_add %21, %22 : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xoutput_type>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_add %27, %29 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xoutput_type>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_add %29, %31 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xoutput_type>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_add %31, %33 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xoutput_type>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_add %33, %35 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_add %34, %36 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } + scf.yield %31 : tensor<*xoutput_type> + } + scf.yield %29 : tensor<*xoutput_type> + } + scf.yield %27 : tensor<*xoutput_type> + } + scf.yield %25 : tensor<*xoutput_type> + } + scf.yield %18 : tensor<*xoutput_type> + } + scf.yield %16 : tensor<*xoutput_type> + } + %10 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xoutput_type>, tensor) -> tensor<*xoutput_type> + return %13 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/angle.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/angle.mlir.tmpl index d53cc3093fbd1d..009e352e52ae01 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/angle.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/angle.mlir.tmpl @@ -1,7 +1,11 @@ -func.func @Angle_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.imag %arg0 : (tensor<*xelem_type>) -> tensor<*xoutput_type> - %1 = mhlo.real %arg0 : (tensor<*xelem_type>) -> tensor<*xoutput_type> - %2 = mhlo.atan2 %0, %1 : tensor<*xoutput_type> - func.return %2 : tensor<*xoutput_type> +func.func @Angle_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = mhlo.imag %2 : (tensor) -> tensor + %4 = mhlo.real %2 : (tensor) -> tensor + %5 = mhlo.atan2 %3, %4 : tensor + %6 = mhlo.dynamic_reshape %5, %0 : (tensor, tensor) -> tensor<*xoutput_type> + return %6 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/asin.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/asin.mlir.tmpl index 68a92623fe45b1..534932181bd4d6 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/asin.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/asin.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Asin_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.asin %arg0 : tensor<*xelem_type> -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @Asin_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = chlo.asin %2 : tensor -> tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xoutput_type> + return %4 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/asinh.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/asinh.mlir.tmpl index 9cc65cb8c76d95..e3e256c8425d5a 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/asinh.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/asinh.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Asinh_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.asinh %arg0 : tensor<*xelem_type> -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @Asinh_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = chlo.asinh %2 : tensor -> tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xoutput_type> + return %4 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/atan.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/atan.mlir.tmpl index 9dee2cd71aa9e0..679f4b2c0c48a4 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/atan.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/atan.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Atan_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.atan %arg0 : tensor<*xelem_type> -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @Atan_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = chlo.atan %2 : tensor -> tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xoutput_type> + return %4 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/atan2.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/atan2.mlir.tmpl index b0820d84a35a29..a61936c7ec2bfe 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/atan2.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/atan2.mlir.tmpl @@ -1,6 +1,129 @@ -func.func @Atan2_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_atan2 %arg0, %arg1 - : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @Atan2_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xoutput_type>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_atan2 %15, %16 : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xoutput_type>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %20 = chlo.broadcast_atan2 %18, %19 : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xoutput_type>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_atan2 %21, %22 : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xoutput_type>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_atan2 %27, %29 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xoutput_type>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_atan2 %29, %31 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xoutput_type>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_atan2 %31, %33 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xoutput_type>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_atan2 %33, %35 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_atan2 %34, %36 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } + scf.yield %31 : tensor<*xoutput_type> + } + scf.yield %29 : tensor<*xoutput_type> + } + scf.yield %27 : tensor<*xoutput_type> + } + scf.yield %25 : tensor<*xoutput_type> + } + scf.yield %18 : tensor<*xoutput_type> + } + scf.yield %16 : tensor<*xoutput_type> + } + %10 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xoutput_type>, tensor) -> tensor<*xoutput_type> + return %13 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/atanh.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/atanh.mlir.tmpl index f6c06c331dd14d..1f584b1d65ae03 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/atanh.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/atanh.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Atanh_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.atanh %arg0 : tensor<*xelem_type> -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @Atanh_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = chlo.atanh %2 : tensor -> tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xoutput_type> + return %4 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/bitwise_and.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/bitwise_and.mlir.tmpl index 495c1fa8961f38..2f4e9084ee8beb 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/bitwise_and.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/bitwise_and.mlir.tmpl @@ -1,5 +1,129 @@ -func.func @BitwiseAnd_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_and %arg0, %arg1 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @BitwiseAnd_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xelem_type>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_and %15, %16 : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xelem_type>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %20 = chlo.broadcast_and %18, %19 : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xelem_type>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_and %21, %22 : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xelem_type>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_and %27, %29 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xelem_type>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_and %29, %31 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xelem_type>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_and %31, %33 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xelem_type>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_and %33, %35 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_and %34, %36 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } + scf.yield %31 : tensor<*xelem_type> + } + scf.yield %29 : tensor<*xelem_type> + } + scf.yield %27 : tensor<*xelem_type> + } + scf.yield %25 : tensor<*xelem_type> + } + scf.yield %18 : tensor<*xelem_type> + } + scf.yield %16 : tensor<*xelem_type> + } + %10 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> + return %13 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/bitwise_or.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/bitwise_or.mlir.tmpl index ebf18924609f42..8f9cb17256e7f5 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/bitwise_or.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/bitwise_or.mlir.tmpl @@ -1,5 +1,129 @@ -func.func @BitwiseOr_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_or %arg0, %arg1 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @BitwiseOr_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xelem_type>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_or %15, %16 : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xelem_type>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %20 = chlo.broadcast_or %18, %19 : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xelem_type>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_or %21, %22 : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xelem_type>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_or %27, %29 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xelem_type>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_or %29, %31 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xelem_type>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_or %31, %33 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xelem_type>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_or %33, %35 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_or %34, %36 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } + scf.yield %31 : tensor<*xelem_type> + } + scf.yield %29 : tensor<*xelem_type> + } + scf.yield %27 : tensor<*xelem_type> + } + scf.yield %25 : tensor<*xelem_type> + } + scf.yield %18 : tensor<*xelem_type> + } + scf.yield %16 : tensor<*xelem_type> + } + %10 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> + return %13 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/bitwise_xor.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/bitwise_xor.mlir.tmpl index fb1773b24942bb..e4ba7c1486571a 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/bitwise_xor.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/bitwise_xor.mlir.tmpl @@ -1,5 +1,129 @@ -func.func @BitwiseXor_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_xor %arg0, %arg1 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @BitwiseXor_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xelem_type>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_xor %15, %16 : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xelem_type>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %20 = chlo.broadcast_xor %18, %19 : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xelem_type>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_xor %21, %22 : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xelem_type>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_xor %27, %29 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xelem_type>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_xor %29, %31 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xelem_type>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_xor %31, %33 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xelem_type>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_xor %33, %35 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_xor %34, %36 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } + scf.yield %31 : tensor<*xelem_type> + } + scf.yield %29 : tensor<*xelem_type> + } + scf.yield %27 : tensor<*xelem_type> + } + scf.yield %25 : tensor<*xelem_type> + } + scf.yield %18 : tensor<*xelem_type> + } + scf.yield %16 : tensor<*xelem_type> + } + %10 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> + return %13 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/cast.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/cast.mlir.tmpl index cf61977c9dad34..6e5a042e84e8bd 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/cast.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/cast.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Cast_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.convert %arg0 : (tensor<*xelem_type>) -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @Cast_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = mhlo.convert %2 : (tensor) -> tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xoutput_type> + return %4 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/ceil.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/ceil.mlir.tmpl index cfc627c90316d9..7ab5d18badadc9 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/ceil.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/ceil.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Ceil_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.ceil %arg0 : tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @Ceil_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = mhlo.ceil %2 : tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xelem_type> + return %4 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/complex.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/complex.mlir.tmpl index 1adb56de892015..c9c7330058e09e 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/complex.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/complex.mlir.tmpl @@ -1,6 +1,129 @@ -func.func @Complex_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_complex %arg0, %arg1 - : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @Complex_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xoutput_type>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_complex %15, %16 : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xoutput_type>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %20 = chlo.broadcast_complex %18, %19 : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xoutput_type>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_complex %21, %22 : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xoutput_type>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_complex %27, %29 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xoutput_type>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_complex %29, %31 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xoutput_type>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_complex %31, %33 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xoutput_type>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_complex %33, %35 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_complex %34, %36 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } + scf.yield %31 : tensor<*xoutput_type> + } + scf.yield %29 : tensor<*xoutput_type> + } + scf.yield %27 : tensor<*xoutput_type> + } + scf.yield %25 : tensor<*xoutput_type> + } + scf.yield %18 : tensor<*xoutput_type> + } + scf.yield %16 : tensor<*xoutput_type> + } + %10 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xoutput_type>, tensor) -> tensor<*xoutput_type> + return %13 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/complex_abs.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/complex_abs.mlir.tmpl index ebcc01234a3862..ed04436a48bb16 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/complex_abs.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/complex_abs.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @ComplexAbs_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.abs %arg0 : (tensor<*xelem_type>) -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @ComplexAbs_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = mhlo.abs %2 : (tensor) -> tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xoutput_type> + return %4 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/conj.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/conj.mlir.tmpl index 6bbd0f2417a0f9..cdd354aa43d0de 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/conj.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/conj.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Conj_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.conj %arg0 : tensor<*xelem_type> -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @Conj_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = chlo.conj %2 : tensor -> tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xoutput_type> + return %4 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/cos.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/cos.mlir.tmpl index e5394ba75ff576..39a6d7b1d82339 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/cos.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/cos.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Cos_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.cosine %arg0 : tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @Cos_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = mhlo.cosine %2 : tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xelem_type> + return %4 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/cosh.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/cosh.mlir.tmpl index cb90a204235a27..1047460621e760 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/cosh.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/cosh.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Cosh_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.cosh %arg0 : tensor<*xelem_type> -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @Cosh_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = chlo.cosh %2 : tensor -> tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xoutput_type> + return %4 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/digamma.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/digamma.mlir.tmpl index 59a5f0e0cab368..7acbe863efd3d3 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/digamma.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/digamma.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Digamma_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.digamma %arg0 : tensor<*xelem_type> -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @Digamma_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = chlo.digamma %2 : tensor -> tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xoutput_type> + return %4 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/div.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/div.mlir.tmpl index ca6746a1d99730..fc4650e4ea4096 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/div.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/div.mlir.tmpl @@ -1,6 +1,129 @@ -func.func @Div_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_divide %arg0, %arg1 - : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @Div_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xoutput_type>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_divide %15, %16 : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xoutput_type>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %20 = chlo.broadcast_divide %18, %19 : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xoutput_type>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_divide %21, %22 : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xoutput_type>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_divide %27, %29 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xoutput_type>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_divide %29, %31 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xoutput_type>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_divide %31, %33 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xoutput_type>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_divide %33, %35 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_divide %34, %36 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } + scf.yield %31 : tensor<*xoutput_type> + } + scf.yield %29 : tensor<*xoutput_type> + } + scf.yield %27 : tensor<*xoutput_type> + } + scf.yield %25 : tensor<*xoutput_type> + } + scf.yield %18 : tensor<*xoutput_type> + } + scf.yield %16 : tensor<*xoutput_type> + } + %10 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xoutput_type>, tensor) -> tensor<*xoutput_type> + return %13 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/div_no_nan.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/div_no_nan.mlir.tmpl index 53b9f5f97b8863..0dd4fc448970e6 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/div_no_nan.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/div_no_nan.mlir.tmpl @@ -1,9 +1,148 @@ -func.func @DivNoNan_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, - %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {tf_entry, - llvm.emit_c_interface} { - %0 = mhlo.constant dense<0.000000e+00> : tensor - %1 = chlo.broadcast_compare %arg1, %0 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %2 = chlo.broadcast_divide %arg0, %arg1 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xoutput_type> - %3 = chlo.broadcast_select %1, %0, %2 : (tensor<*xi1>, tensor, tensor<*xoutput_type>) -> tensor<*xoutput_type> - func.return %3 : tensor<*xoutput_type> +func.func @DivNoNan_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = mhlo.constant dense<0.000000e+00> : tensor + %6 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %7 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %8 = shape.num_elements %6 : tensor -> index + %9 = arith.cmpi eq, %8, %c1 : index + %10 = scf.if %9 -> (tensor<*xoutput_type>) { + %17 = shape.num_elements %7 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %19 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %20 = chlo.broadcast_compare %19, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %21 = chlo.broadcast_divide %18, %19 : (tensor, tensor) -> tensor + %22 = chlo.broadcast_select %20, %5, %21 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %22 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %17 = shape.num_elements %7 : tensor -> index + %18 = arith.cmpi eq, %17, %c1 : index + %19 = scf.if %18 -> (tensor<*xoutput_type>) { + %20 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %23 = chlo.broadcast_compare %22, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %24 = chlo.broadcast_divide %21, %22 : (tensor, tensor) -> tensor + %25 = chlo.broadcast_select %23, %5, %24 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %25 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %20 = shape.shape_eq %6, %7 : tensor, tensor + %21 = scf.if %20 -> (tensor<*xoutput_type>) { + %22 = shape.any %6, %7 : tensor, tensor -> tensor + %23 = shape.num_elements %22 : tensor -> index + %from_elements = tensor.from_elements %23 : tensor<1xindex> + %24 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %25 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %26 = chlo.broadcast_compare %25, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %27 = chlo.broadcast_divide %24, %25 : (tensor, tensor) -> tensor + %28 = chlo.broadcast_select %26, %5, %27 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %28 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %22:2 = chlo.minimum_broadcast_shapes %6, %7 : tensor, tensor -> tensor, tensor + %23 = shape.rank %22#0 : tensor -> index + %24 = shape.rank %22#1 : tensor -> index + %25 = arith.cmpi sgt, %23, %24 : index + %26 = arith.select %25, %23, %24 : index + %27 = arith.cmpi ule, %26, %c1 : index + %28 = scf.if %27 -> (tensor<*xoutput_type>) { + %29 = shape.broadcast %22#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %29 : tensor to tensor<1xindex> + %30 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %31 = shape.broadcast %22#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %31 : tensor to tensor<1xindex> + %32 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %33 = chlo.broadcast_compare %32, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %34 = chlo.broadcast_divide %30, %32 : (tensor, tensor) -> tensor + %35 = chlo.broadcast_select %33, %5, %34 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %35 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %29 = arith.cmpi ule, %26, %c2 : index + %30 = scf.if %29 -> (tensor<*xoutput_type>) { + %31 = shape.broadcast %22#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %31 : tensor to tensor<2xindex> + %32 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %33 = shape.broadcast %22#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %33 : tensor to tensor<2xindex> + %34 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %35 = chlo.broadcast_compare %34, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %36 = chlo.broadcast_divide %32, %34 : (tensor, tensor) -> tensor + %37 = chlo.broadcast_select %35, %5, %36 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %31 = arith.cmpi ule, %26, %c3 : index + %32 = scf.if %31 -> (tensor<*xoutput_type>) { + %33 = shape.broadcast %22#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<3xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %35 = shape.broadcast %22#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<3xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %37 = chlo.broadcast_compare %36, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %38 = chlo.broadcast_divide %34, %36 : (tensor, tensor) -> tensor + %39 = chlo.broadcast_select %37, %5, %38 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %39 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %33 = arith.cmpi ule, %26, %c4 : index + %34 = scf.if %33 -> (tensor<*xoutput_type>) { + %35 = shape.broadcast %22#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %35 : tensor to tensor<4xindex> + %36 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %37 = shape.broadcast %22#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %37 : tensor to tensor<4xindex> + %38 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %39 = chlo.broadcast_compare %38, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %40 = chlo.broadcast_divide %36, %38 : (tensor, tensor) -> tensor + %41 = chlo.broadcast_select %39, %5, %40 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %41 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %35 = arith.cmpi ule, %26, %c5 : index + cf.assert %35, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %36 = shape.broadcast %22#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %36 : tensor to tensor<5xindex> + %37 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %38 = shape.broadcast %22#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %38 : tensor to tensor<5xindex> + %39 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %40 = chlo.broadcast_compare %39, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %41 = chlo.broadcast_divide %37, %39 : (tensor, tensor) -> tensor + %42 = chlo.broadcast_select %40, %5, %41 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %42 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } + scf.yield %34 : tensor<*xoutput_type> + } + scf.yield %32 : tensor<*xoutput_type> + } + scf.yield %30 : tensor<*xoutput_type> + } + scf.yield %28 : tensor<*xoutput_type> + } + scf.yield %21 : tensor<*xoutput_type> + } + scf.yield %19 : tensor<*xoutput_type> + } + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %13 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %14 = shape.broadcast %12, %13 : tensor, tensor -> tensor + %15 = shape.broadcast %11, %14 : tensor, tensor -> tensor + %16 = mhlo.dynamic_reshape %10, %15 : (tensor<*xoutput_type>, tensor) -> tensor<*xoutput_type> + return %16 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/div_no_nan_cmplx.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/div_no_nan_cmplx.mlir.tmpl index 3b5509af9f03c7..f75314cf0590d2 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/div_no_nan_cmplx.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/div_no_nan_cmplx.mlir.tmpl @@ -1,9 +1,148 @@ -func.func @DivNoNan_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, - %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {tf_entry, - llvm.emit_c_interface} { - %0 = mhlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor - %1 = chlo.broadcast_compare %arg1, %0 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %2 = chlo.broadcast_divide %arg0, %arg1 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xoutput_type> - %3 = chlo.broadcast_select %1, %0, %2 : (tensor<*xi1>, tensor, tensor<*xoutput_type>) -> tensor<*xoutput_type> - func.return %3 : tensor<*xoutput_type> +func.func @DivNoNan_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = mhlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor + %6 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %7 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %8 = shape.num_elements %6 : tensor -> index + %9 = arith.cmpi eq, %8, %c1 : index + %10 = scf.if %9 -> (tensor<*xoutput_type>) { + %17 = shape.num_elements %7 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %19 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %20 = chlo.broadcast_compare %19, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %21 = chlo.broadcast_divide %18, %19 : (tensor, tensor) -> tensor + %22 = chlo.broadcast_select %20, %5, %21 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %22 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %17 = shape.num_elements %7 : tensor -> index + %18 = arith.cmpi eq, %17, %c1 : index + %19 = scf.if %18 -> (tensor<*xoutput_type>) { + %20 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %23 = chlo.broadcast_compare %22, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %24 = chlo.broadcast_divide %21, %22 : (tensor, tensor) -> tensor + %25 = chlo.broadcast_select %23, %5, %24 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %25 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %20 = shape.shape_eq %6, %7 : tensor, tensor + %21 = scf.if %20 -> (tensor<*xoutput_type>) { + %22 = shape.any %6, %7 : tensor, tensor -> tensor + %23 = shape.num_elements %22 : tensor -> index + %from_elements = tensor.from_elements %23 : tensor<1xindex> + %24 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %25 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %26 = chlo.broadcast_compare %25, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %27 = chlo.broadcast_divide %24, %25 : (tensor, tensor) -> tensor + %28 = chlo.broadcast_select %26, %5, %27 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %28 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %22:2 = chlo.minimum_broadcast_shapes %6, %7 : tensor, tensor -> tensor, tensor + %23 = shape.rank %22#0 : tensor -> index + %24 = shape.rank %22#1 : tensor -> index + %25 = arith.cmpi sgt, %23, %24 : index + %26 = arith.select %25, %23, %24 : index + %27 = arith.cmpi ule, %26, %c1 : index + %28 = scf.if %27 -> (tensor<*xoutput_type>) { + %29 = shape.broadcast %22#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %29 : tensor to tensor<1xindex> + %30 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %31 = shape.broadcast %22#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %31 : tensor to tensor<1xindex> + %32 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %33 = chlo.broadcast_compare %32, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %34 = chlo.broadcast_divide %30, %32 : (tensor, tensor) -> tensor + %35 = chlo.broadcast_select %33, %5, %34 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %35 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %29 = arith.cmpi ule, %26, %c2 : index + %30 = scf.if %29 -> (tensor<*xoutput_type>) { + %31 = shape.broadcast %22#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %31 : tensor to tensor<2xindex> + %32 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %33 = shape.broadcast %22#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %33 : tensor to tensor<2xindex> + %34 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %35 = chlo.broadcast_compare %34, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %36 = chlo.broadcast_divide %32, %34 : (tensor, tensor) -> tensor + %37 = chlo.broadcast_select %35, %5, %36 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %31 = arith.cmpi ule, %26, %c3 : index + %32 = scf.if %31 -> (tensor<*xoutput_type>) { + %33 = shape.broadcast %22#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<3xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %35 = shape.broadcast %22#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<3xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %37 = chlo.broadcast_compare %36, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %38 = chlo.broadcast_divide %34, %36 : (tensor, tensor) -> tensor + %39 = chlo.broadcast_select %37, %5, %38 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %39 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %33 = arith.cmpi ule, %26, %c4 : index + %34 = scf.if %33 -> (tensor<*xoutput_type>) { + %35 = shape.broadcast %22#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %35 : tensor to tensor<4xindex> + %36 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %37 = shape.broadcast %22#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %37 : tensor to tensor<4xindex> + %38 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %39 = chlo.broadcast_compare %38, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %40 = chlo.broadcast_divide %36, %38 : (tensor, tensor) -> tensor + %41 = chlo.broadcast_select %39, %5, %40 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %41 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %35 = arith.cmpi ule, %26, %c5 : index + cf.assert %35, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %36 = shape.broadcast %22#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %36 : tensor to tensor<5xindex> + %37 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %38 = shape.broadcast %22#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %38 : tensor to tensor<5xindex> + %39 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %40 = chlo.broadcast_compare %39, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %41 = chlo.broadcast_divide %37, %39 : (tensor, tensor) -> tensor + %42 = chlo.broadcast_select %40, %5, %41 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %42 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } + scf.yield %34 : tensor<*xoutput_type> + } + scf.yield %32 : tensor<*xoutput_type> + } + scf.yield %30 : tensor<*xoutput_type> + } + scf.yield %28 : tensor<*xoutput_type> + } + scf.yield %21 : tensor<*xoutput_type> + } + scf.yield %19 : tensor<*xoutput_type> + } + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %13 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %14 = shape.broadcast %12, %13 : tensor, tensor -> tensor + %15 = shape.broadcast %11, %14 : tensor, tensor -> tensor + %16 = mhlo.dynamic_reshape %10, %15 : (tensor<*xoutput_type>, tensor) -> tensor<*xoutput_type> + return %16 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/elu.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/elu.mlir.tmpl index 4575994a5e5ac6..40378cce20097f 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/elu.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/elu.mlir.tmpl @@ -1,8 +1,12 @@ -func.func @Elu_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : elem_type} : (tensor<*xelem_type>) -> tensor<*xelem_type> - %1 = mhlo.compare GT, %arg0, %0 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xi1> - %2 = mhlo.exponential_minus_one %arg0 : tensor<*xelem_type> - %3 = mhlo.select %1, %arg0, %2 : tensor<*xi1>, tensor<*xelem_type> - func.return %3 : tensor<*xoutput_type> +func.func @Elu_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = "chlo.constant_like"(%2) {value = 0.000000e+00 : elem_type} : (tensor) -> tensor + %4 = mhlo.compare GT, %2, %3 : (tensor, tensor) -> tensor + %5 = mhlo.exponential_minus_one %2 : tensor + %6 = mhlo.select %4, %2, %5 : tensor, tensor + %7 = mhlo.dynamic_reshape %6, %0 : (tensor, tensor) -> tensor<*xelem_type> + return %7 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/equal.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/equal.mlir.tmpl index 96d97f0273536d..fbf599ae3b3375 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/equal.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/equal.mlir.tmpl @@ -1,5 +1,129 @@ -func.func @Equal_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @Equal_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xoutput_type>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_compare %15, %16 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xoutput_type>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %20 = chlo.broadcast_compare %18, %19 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xoutput_type>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_compare %21, %22 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xoutput_type>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_compare %27, %29 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xoutput_type>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_compare %29, %31 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xoutput_type>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_compare %31, %33 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xoutput_type>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_compare %33, %35 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_compare %34, %36 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } + scf.yield %31 : tensor<*xoutput_type> + } + scf.yield %29 : tensor<*xoutput_type> + } + scf.yield %27 : tensor<*xoutput_type> + } + scf.yield %25 : tensor<*xoutput_type> + } + scf.yield %18 : tensor<*xoutput_type> + } + scf.yield %16 : tensor<*xoutput_type> + } + %10 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xoutput_type>, tensor) -> tensor<*xoutput_type> + return %13 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/erf.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/erf.mlir.tmpl index 83fb89eaf43fc5..51ab4c0c995caf 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/erf.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/erf.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Erf_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.erf %arg0 : tensor<*xelem_type> -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @Erf_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = chlo.erf %2 : tensor -> tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xoutput_type> + return %4 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/erfc.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/erfc.mlir.tmpl index f40bfee5e926f6..e0254a1d1a84e6 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/erfc.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/erfc.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Erfc_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.erfc %arg0 : tensor<*xelem_type> -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @Erfc_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = chlo.erfc %2 : tensor -> tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xoutput_type> + return %4 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/exp.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/exp.mlir.tmpl index c41a3fb83d7d5f..030ff18ac939c9 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/exp.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/exp.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Exp_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.exponential %arg0 : tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @Exp_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = mhlo.exponential %2 : tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xelem_type> + return %4 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/expm1.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/expm1.mlir.tmpl index ee0f193dffe2dd..f5ed49876d04fb 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/expm1.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/expm1.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Expm1_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.exponential_minus_one %arg0 : tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @Expm1_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = mhlo.exponential_minus_one %2 : tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xelem_type> + return %4 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/floor.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/floor.mlir.tmpl index 226a224eb5012b..5200ecc0f4f54f 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/floor.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/floor.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Floor_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.floor %arg0 : tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @Floor_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = mhlo.floor %2 : tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xelem_type> + return %4 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/floor_div.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/floor_div.mlir.tmpl index 265074b3f33418..0b19056d6c641a 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/floor_div.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/floor_div.mlir.tmpl @@ -1,16 +1,195 @@ -func.func @FloorDiv_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_divide %arg0, %arg1 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - %1 = chlo.broadcast_multiply %0, %arg1 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - %2 = chlo.broadcast_compare %1, %arg0 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xi1> - %3 = mhlo.constant dense<0> : tensor - %4 = chlo.broadcast_compare %arg0, %3 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %5 = mhlo.constant dense<0> : tensor - %6 = chlo.broadcast_compare %arg1, %5 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %7 = chlo.broadcast_compare %4, %6 {comparison_direction = #chlo} : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1> - %8 = chlo.broadcast_and %2, %7 : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1> - %9 = mhlo.constant dense<1> : tensor - %10 = chlo.broadcast_subtract %0, %9 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> - %11 = mhlo.select %8, %10, %0 : tensor<*xi1>, tensor<*xelem_type> - return %11 : tensor<*xoutput_type> +func.func @FloorDiv_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = mhlo.constant dense<1> : tensor + %6 = mhlo.constant dense<0> : tensor + %7 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %8 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %9 = shape.num_elements %7 : tensor -> index + %10 = arith.cmpi eq, %9, %c1 : index + %11 = scf.if %10 -> (tensor<*xelem_type>) { + %16 = shape.num_elements %8 : tensor -> index + %from_elements = tensor.from_elements %16 : tensor<1xindex> + %17 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %19 = chlo.broadcast_divide %18, %17 : (tensor, tensor) -> tensor + %20 = chlo.broadcast_multiply %19, %17 : (tensor, tensor) -> tensor + %21 = chlo.broadcast_compare %20, %18 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %22 = chlo.broadcast_compare %18, %6 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %23 = chlo.broadcast_compare %17, %6 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %24 = chlo.broadcast_compare %22, %23 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %25 = chlo.broadcast_and %21, %24 : (tensor, tensor) -> tensor + %26 = chlo.broadcast_subtract %19, %5 : (tensor, tensor) -> tensor + %27 = mhlo.select %25, %26, %19 : tensor, tensor + %cast = tensor.cast %27 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %16 = shape.num_elements %8 : tensor -> index + %17 = arith.cmpi eq, %16, %c1 : index + %18 = scf.if %17 -> (tensor<*xelem_type>) { + %19 = shape.num_elements %7 : tensor -> index + %from_elements = tensor.from_elements %19 : tensor<1xindex> + %20 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %21 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %22 = chlo.broadcast_divide %21, %20 : (tensor, tensor) -> tensor + %23 = chlo.broadcast_multiply %22, %20 : (tensor, tensor) -> tensor + %24 = chlo.broadcast_compare %23, %21 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %25 = chlo.broadcast_compare %21, %6 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %26 = chlo.broadcast_compare %20, %6 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %27 = chlo.broadcast_compare %25, %26 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %28 = chlo.broadcast_and %24, %27 : (tensor, tensor) -> tensor + %29 = chlo.broadcast_subtract %22, %5 : (tensor, tensor) -> tensor + %30 = mhlo.select %28, %29, %22 : tensor, tensor + %cast = tensor.cast %30 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %19 = shape.shape_eq %7, %8 : tensor, tensor + %20 = scf.if %19 -> (tensor<*xelem_type>) { + %21 = shape.any %7, %8 : tensor, tensor -> tensor + %22 = shape.num_elements %21 : tensor -> index + %from_elements = tensor.from_elements %22 : tensor<1xindex> + %23 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %24 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %25 = chlo.broadcast_divide %24, %23 : (tensor, tensor) -> tensor + %26 = chlo.broadcast_multiply %25, %23 : (tensor, tensor) -> tensor + %27 = chlo.broadcast_compare %26, %24 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %28 = chlo.broadcast_compare %24, %6 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %29 = chlo.broadcast_compare %23, %6 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %30 = chlo.broadcast_compare %28, %29 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %31 = chlo.broadcast_and %27, %30 : (tensor, tensor) -> tensor + %32 = chlo.broadcast_subtract %25, %5 : (tensor, tensor) -> tensor + %33 = mhlo.select %31, %32, %25 : tensor, tensor + %cast = tensor.cast %33 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %21:2 = chlo.minimum_broadcast_shapes %7, %8 : tensor, tensor -> tensor, tensor + %22 = shape.rank %21#0 : tensor -> index + %23 = shape.rank %21#1 : tensor -> index + %24 = arith.cmpi sgt, %22, %23 : index + %25 = arith.select %24, %22, %23 : index + %26 = arith.cmpi ule, %25, %c1 : index + %27 = scf.if %26 -> (tensor<*xelem_type>) { + %28 = shape.broadcast %21#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = shape.broadcast %21#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<1xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %32 = chlo.broadcast_divide %31, %29 : (tensor, tensor) -> tensor + %33 = chlo.broadcast_multiply %32, %29 : (tensor, tensor) -> tensor + %34 = chlo.broadcast_compare %33, %31 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %35 = chlo.broadcast_compare %31, %6 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %36 = chlo.broadcast_compare %29, %6 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %37 = chlo.broadcast_compare %35, %36 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %38 = chlo.broadcast_and %34, %37 : (tensor, tensor) -> tensor + %39 = chlo.broadcast_subtract %32, %5 : (tensor, tensor) -> tensor + %40 = mhlo.select %38, %39, %32 : tensor, tensor + %cast_1 = tensor.cast %40 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %28 = arith.cmpi ule, %25, %c2 : index + %29 = scf.if %28 -> (tensor<*xelem_type>) { + %30 = shape.broadcast %21#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %32 = shape.broadcast %21#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<2xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %34 = chlo.broadcast_divide %33, %31 : (tensor, tensor) -> tensor + %35 = chlo.broadcast_multiply %34, %31 : (tensor, tensor) -> tensor + %36 = chlo.broadcast_compare %35, %33 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %37 = chlo.broadcast_compare %33, %6 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %38 = chlo.broadcast_compare %31, %6 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %39 = chlo.broadcast_compare %37, %38 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %40 = chlo.broadcast_and %36, %39 : (tensor, tensor) -> tensor + %41 = chlo.broadcast_subtract %34, %5 : (tensor, tensor) -> tensor + %42 = mhlo.select %40, %41, %34 : tensor, tensor + %cast_1 = tensor.cast %42 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %30 = arith.cmpi ule, %25, %c3 : index + %31 = scf.if %30 -> (tensor<*xelem_type>) { + %32 = shape.broadcast %21#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %34 = shape.broadcast %21#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<3xindex> + %35 = mhlo.dynamic_reshape %arg0, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %36 = chlo.broadcast_divide %35, %33 : (tensor, tensor) -> tensor + %37 = chlo.broadcast_multiply %36, %33 : (tensor, tensor) -> tensor + %38 = chlo.broadcast_compare %37, %35 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %39 = chlo.broadcast_compare %35, %6 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %40 = chlo.broadcast_compare %33, %6 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %41 = chlo.broadcast_compare %39, %40 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %42 = chlo.broadcast_and %38, %41 : (tensor, tensor) -> tensor + %43 = chlo.broadcast_subtract %36, %5 : (tensor, tensor) -> tensor + %44 = mhlo.select %42, %43, %36 : tensor, tensor + %cast_1 = tensor.cast %44 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %32 = arith.cmpi ule, %25, %c4 : index + %33 = scf.if %32 -> (tensor<*xelem_type>) { + %34 = shape.broadcast %21#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %36 = shape.broadcast %21#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %36 : tensor to tensor<4xindex> + %37 = mhlo.dynamic_reshape %arg0, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %38 = chlo.broadcast_divide %37, %35 : (tensor, tensor) -> tensor + %39 = chlo.broadcast_multiply %38, %35 : (tensor, tensor) -> tensor + %40 = chlo.broadcast_compare %39, %37 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %41 = chlo.broadcast_compare %37, %6 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %42 = chlo.broadcast_compare %35, %6 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %43 = chlo.broadcast_compare %41, %42 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %44 = chlo.broadcast_and %40, %43 : (tensor, tensor) -> tensor + %45 = chlo.broadcast_subtract %38, %5 : (tensor, tensor) -> tensor + %46 = mhlo.select %44, %45, %38 : tensor, tensor + %cast_1 = tensor.cast %46 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %34 = arith.cmpi ule, %25, %c5 : index + cf.assert %34, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %35 = shape.broadcast %21#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %37 = shape.broadcast %21#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %37 : tensor to tensor<5xindex> + %38 = mhlo.dynamic_reshape %arg0, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %39 = chlo.broadcast_divide %38, %36 : (tensor, tensor) -> tensor + %40 = chlo.broadcast_multiply %39, %36 : (tensor, tensor) -> tensor + %41 = chlo.broadcast_compare %40, %38 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %42 = chlo.broadcast_compare %38, %6 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %43 = chlo.broadcast_compare %36, %6 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %44 = chlo.broadcast_compare %42, %43 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %45 = chlo.broadcast_and %41, %44 : (tensor, tensor) -> tensor + %46 = chlo.broadcast_subtract %39, %5 : (tensor, tensor) -> tensor + %47 = mhlo.select %45, %46, %39 : tensor, tensor + %cast_1 = tensor.cast %47 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } + scf.yield %33 : tensor<*xelem_type> + } + scf.yield %31 : tensor<*xelem_type> + } + scf.yield %29 : tensor<*xelem_type> + } + scf.yield %27 : tensor<*xelem_type> + } + scf.yield %20 : tensor<*xelem_type> + } + scf.yield %18 : tensor<*xelem_type> + } + %12 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %13 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %14 = shape.broadcast %12, %13 : tensor, tensor -> tensor + %15 = mhlo.dynamic_reshape %11, %14 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> + return %15 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/floor_div_float.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/floor_div_float.mlir.tmpl index 32d9f89431eb3e..280d957a6bf705 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/floor_div_float.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/floor_div_float.mlir.tmpl @@ -1,6 +1,137 @@ -func.func @FloorDiv_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_divide %arg0, %arg1 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xoutput_type> - %1 = mhlo.floor %0 : tensor<*xoutput_type> - func.return %1 : tensor<*xoutput_type> +func.func @FloorDiv_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xoutput_type>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_divide %15, %16 : (tensor, tensor) -> tensor + %18 = mhlo.floor %17 : tensor + %cast = tensor.cast %18 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xoutput_type>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %20 = chlo.broadcast_divide %18, %19 : (tensor, tensor) -> tensor + %21 = mhlo.floor %20 : tensor + %cast = tensor.cast %21 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xoutput_type>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_divide %21, %22 : (tensor, tensor) -> tensor + %24 = mhlo.floor %23 : tensor + %cast = tensor.cast %24 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xoutput_type>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_divide %27, %29 : (tensor, tensor) -> tensor + %31 = mhlo.floor %30 : tensor + %cast_1 = tensor.cast %31 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xoutput_type>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_divide %29, %31 : (tensor, tensor) -> tensor + %33 = mhlo.floor %32 : tensor + %cast_1 = tensor.cast %33 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xoutput_type>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_divide %31, %33 : (tensor, tensor) -> tensor + %35 = mhlo.floor %34 : tensor + %cast_1 = tensor.cast %35 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xoutput_type>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_divide %33, %35 : (tensor, tensor) -> tensor + %37 = mhlo.floor %36 : tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_divide %34, %36 : (tensor, tensor) -> tensor + %38 = mhlo.floor %37 : tensor + %cast_1 = tensor.cast %38 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } + scf.yield %31 : tensor<*xoutput_type> + } + scf.yield %29 : tensor<*xoutput_type> + } + scf.yield %27 : tensor<*xoutput_type> + } + scf.yield %25 : tensor<*xoutput_type> + } + scf.yield %18 : tensor<*xoutput_type> + } + scf.yield %16 : tensor<*xoutput_type> + } + %10 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xoutput_type>, tensor) -> tensor<*xoutput_type> + return %13 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/floor_mod.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/floor_mod.mlir.tmpl index 92497856f22264..d7a703c12f3b37 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/floor_mod.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/floor_mod.mlir.tmpl @@ -1,14 +1,188 @@ -func.func @FloorMod_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_remainder %arg0, %arg1 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - %1 = mhlo.constant dense<0> : tensor - %2 = chlo.broadcast_compare %0, %1 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %3 = mhlo.constant dense<0> : tensor - %4 = chlo.broadcast_compare %arg1, %3 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %5 = chlo.broadcast_compare %0, %3 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %6 = chlo.broadcast_compare %4, %5 {comparison_direction = #chlo} : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1> - %7 = chlo.broadcast_and %2, %6 : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1> - %8 = chlo.broadcast_add %arg1, %0 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - %9 = mhlo.select %7, %8, %0 : tensor<*xi1>, tensor<*xelem_type> - return %9 : tensor<*xoutput_type> +func.func @FloorMod_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = mhlo.constant dense<0> : tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %8 = shape.num_elements %6 : tensor -> index + %9 = arith.cmpi eq, %8, %c1 : index + %10 = scf.if %9 -> (tensor<*xelem_type>) { + %17 = shape.num_elements %7 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %19 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %20 = chlo.broadcast_remainder %19, %18 : (tensor, tensor) -> tensor + %21 = chlo.broadcast_compare %20, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %22 = chlo.broadcast_compare %18, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %23 = chlo.broadcast_compare %20, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %24 = chlo.broadcast_compare %22, %23 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %25 = chlo.broadcast_and %21, %24 : (tensor, tensor) -> tensor + %26 = chlo.broadcast_add %18, %20 : (tensor, tensor) -> tensor + %27 = mhlo.select %25, %26, %20 : tensor, tensor + %cast = tensor.cast %27 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %17 = shape.num_elements %7 : tensor -> index + %18 = arith.cmpi eq, %17, %c1 : index + %19 = scf.if %18 -> (tensor<*xelem_type>) { + %20 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %23 = chlo.broadcast_remainder %22, %21 : (tensor, tensor) -> tensor + %24 = chlo.broadcast_compare %23, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %25 = chlo.broadcast_compare %21, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %26 = chlo.broadcast_compare %23, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %27 = chlo.broadcast_compare %25, %26 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %28 = chlo.broadcast_and %24, %27 : (tensor, tensor) -> tensor + %29 = chlo.broadcast_add %21, %23 : (tensor, tensor) -> tensor + %30 = mhlo.select %28, %29, %23 : tensor, tensor + %cast = tensor.cast %30 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %20 = shape.shape_eq %6, %7 : tensor, tensor + %21 = scf.if %20 -> (tensor<*xelem_type>) { + %22 = shape.any %6, %7 : tensor, tensor -> tensor + %23 = shape.num_elements %22 : tensor -> index + %from_elements = tensor.from_elements %23 : tensor<1xindex> + %24 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %25 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %26 = chlo.broadcast_remainder %25, %24 : (tensor, tensor) -> tensor + %27 = chlo.broadcast_compare %26, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %28 = chlo.broadcast_compare %24, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %29 = chlo.broadcast_compare %26, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %30 = chlo.broadcast_compare %28, %29 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %31 = chlo.broadcast_and %27, %30 : (tensor, tensor) -> tensor + %32 = chlo.broadcast_add %24, %26 : (tensor, tensor) -> tensor + %33 = mhlo.select %31, %32, %26 : tensor, tensor + %cast = tensor.cast %33 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %22:2 = chlo.minimum_broadcast_shapes %6, %7 : tensor, tensor -> tensor, tensor + %23 = shape.rank %22#0 : tensor -> index + %24 = shape.rank %22#1 : tensor -> index + %25 = arith.cmpi sgt, %23, %24 : index + %26 = arith.select %25, %23, %24 : index + %27 = arith.cmpi ule, %26, %c1 : index + %28 = scf.if %27 -> (tensor<*xelem_type>) { + %29 = shape.broadcast %22#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %29 : tensor to tensor<1xindex> + %30 = mhlo.dynamic_reshape %arg1, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %31 = shape.broadcast %22#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %31 : tensor to tensor<1xindex> + %32 = mhlo.dynamic_reshape %arg0, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %33 = chlo.broadcast_remainder %32, %30 : (tensor, tensor) -> tensor + %34 = chlo.broadcast_compare %33, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %35 = chlo.broadcast_compare %30, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %36 = chlo.broadcast_compare %33, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %37 = chlo.broadcast_compare %35, %36 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %38 = chlo.broadcast_and %34, %37 : (tensor, tensor) -> tensor + %39 = chlo.broadcast_add %30, %33 : (tensor, tensor) -> tensor + %40 = mhlo.select %38, %39, %33 : tensor, tensor + %cast_1 = tensor.cast %40 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %29 = arith.cmpi ule, %26, %c2 : index + %30 = scf.if %29 -> (tensor<*xelem_type>) { + %31 = shape.broadcast %22#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %31 : tensor to tensor<2xindex> + %32 = mhlo.dynamic_reshape %arg1, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %33 = shape.broadcast %22#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %33 : tensor to tensor<2xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %35 = chlo.broadcast_remainder %34, %32 : (tensor, tensor) -> tensor + %36 = chlo.broadcast_compare %35, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %37 = chlo.broadcast_compare %32, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %38 = chlo.broadcast_compare %35, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %39 = chlo.broadcast_compare %37, %38 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %40 = chlo.broadcast_and %36, %39 : (tensor, tensor) -> tensor + %41 = chlo.broadcast_add %32, %35 : (tensor, tensor) -> tensor + %42 = mhlo.select %40, %41, %35 : tensor, tensor + %cast_1 = tensor.cast %42 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %31 = arith.cmpi ule, %26, %c3 : index + %32 = scf.if %31 -> (tensor<*xelem_type>) { + %33 = shape.broadcast %22#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<3xindex> + %34 = mhlo.dynamic_reshape %arg1, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %35 = shape.broadcast %22#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<3xindex> + %36 = mhlo.dynamic_reshape %arg0, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %37 = chlo.broadcast_remainder %36, %34 : (tensor, tensor) -> tensor + %38 = chlo.broadcast_compare %37, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %39 = chlo.broadcast_compare %34, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %40 = chlo.broadcast_compare %37, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %41 = chlo.broadcast_compare %39, %40 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %42 = chlo.broadcast_and %38, %41 : (tensor, tensor) -> tensor + %43 = chlo.broadcast_add %34, %37 : (tensor, tensor) -> tensor + %44 = mhlo.select %42, %43, %37 : tensor, tensor + %cast_1 = tensor.cast %44 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %33 = arith.cmpi ule, %26, %c4 : index + %34 = scf.if %33 -> (tensor<*xelem_type>) { + %35 = shape.broadcast %22#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %35 : tensor to tensor<4xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %37 = shape.broadcast %22#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %37 : tensor to tensor<4xindex> + %38 = mhlo.dynamic_reshape %arg0, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %39 = chlo.broadcast_remainder %38, %36 : (tensor, tensor) -> tensor + %40 = chlo.broadcast_compare %39, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %41 = chlo.broadcast_compare %36, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %42 = chlo.broadcast_compare %39, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %43 = chlo.broadcast_compare %41, %42 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %44 = chlo.broadcast_and %40, %43 : (tensor, tensor) -> tensor + %45 = chlo.broadcast_add %36, %39 : (tensor, tensor) -> tensor + %46 = mhlo.select %44, %45, %39 : tensor, tensor + %cast_1 = tensor.cast %46 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %35 = arith.cmpi ule, %26, %c5 : index + cf.assert %35, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %36 = shape.broadcast %22#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %36 : tensor to tensor<5xindex> + %37 = mhlo.dynamic_reshape %arg1, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %38 = shape.broadcast %22#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %38 : tensor to tensor<5xindex> + %39 = mhlo.dynamic_reshape %arg0, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %40 = chlo.broadcast_remainder %39, %37 : (tensor, tensor) -> tensor + %41 = chlo.broadcast_compare %40, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %42 = chlo.broadcast_compare %37, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %43 = chlo.broadcast_compare %40, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %44 = chlo.broadcast_compare %42, %43 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %45 = chlo.broadcast_and %41, %44 : (tensor, tensor) -> tensor + %46 = chlo.broadcast_add %37, %40 : (tensor, tensor) -> tensor + %47 = mhlo.select %45, %46, %40 : tensor, tensor + %cast_1 = tensor.cast %47 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } + scf.yield %34 : tensor<*xelem_type> + } + scf.yield %32 : tensor<*xelem_type> + } + scf.yield %30 : tensor<*xelem_type> + } + scf.yield %28 : tensor<*xelem_type> + } + scf.yield %21 : tensor<*xelem_type> + } + scf.yield %19 : tensor<*xelem_type> + } + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %13 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %14 = shape.broadcast %12, %13 : tensor, tensor -> tensor + %15 = shape.broadcast %11, %14 : tensor, tensor -> tensor + %16 = mhlo.dynamic_reshape %10, %15 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> + return %16 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/floor_mod_float.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/floor_mod_float.mlir.tmpl index dc86b86ff5b1d5..c71f34da2cd8ac 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/floor_mod_float.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/floor_mod_float.mlir.tmpl @@ -1,14 +1,188 @@ -func.func @FloorMod_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_remainder %arg0, %arg1 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - %1 = mhlo.constant dense<0.000000e+00> : tensor - %2 = chlo.broadcast_compare %0, %1 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %3 = mhlo.constant dense<0.000000e+00> : tensor - %4 = chlo.broadcast_compare %arg1, %3 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %5 = chlo.broadcast_compare %0, %3 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %6 = chlo.broadcast_compare %4, %5 {comparison_direction = #chlo} : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1> - %7 = chlo.broadcast_and %2, %6 : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1> - %8 = chlo.broadcast_add %arg1, %0 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - %9 = mhlo.select %7, %8, %0 : tensor<*xi1>, tensor<*xelem_type> - func.return %9 : tensor<*xoutput_type> +func.func @FloorMod_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = mhlo.constant dense<0.000000e+00> : tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %8 = shape.num_elements %6 : tensor -> index + %9 = arith.cmpi eq, %8, %c1 : index + %10 = scf.if %9 -> (tensor<*xelem_type>) { + %17 = shape.num_elements %7 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %19 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %20 = chlo.broadcast_remainder %19, %18 : (tensor, tensor) -> tensor + %21 = chlo.broadcast_compare %20, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %22 = chlo.broadcast_compare %18, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %23 = chlo.broadcast_compare %20, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %24 = chlo.broadcast_compare %22, %23 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %25 = chlo.broadcast_and %21, %24 : (tensor, tensor) -> tensor + %26 = chlo.broadcast_add %18, %20 : (tensor, tensor) -> tensor + %27 = mhlo.select %25, %26, %20 : tensor, tensor + %cast = tensor.cast %27 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %17 = shape.num_elements %7 : tensor -> index + %18 = arith.cmpi eq, %17, %c1 : index + %19 = scf.if %18 -> (tensor<*xelem_type>) { + %20 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %23 = chlo.broadcast_remainder %22, %21 : (tensor, tensor) -> tensor + %24 = chlo.broadcast_compare %23, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %25 = chlo.broadcast_compare %21, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %26 = chlo.broadcast_compare %23, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %27 = chlo.broadcast_compare %25, %26 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %28 = chlo.broadcast_and %24, %27 : (tensor, tensor) -> tensor + %29 = chlo.broadcast_add %21, %23 : (tensor, tensor) -> tensor + %30 = mhlo.select %28, %29, %23 : tensor, tensor + %cast = tensor.cast %30 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %20 = shape.shape_eq %6, %7 : tensor, tensor + %21 = scf.if %20 -> (tensor<*xelem_type>) { + %22 = shape.any %6, %7 : tensor, tensor -> tensor + %23 = shape.num_elements %22 : tensor -> index + %from_elements = tensor.from_elements %23 : tensor<1xindex> + %24 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %25 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %26 = chlo.broadcast_remainder %25, %24 : (tensor, tensor) -> tensor + %27 = chlo.broadcast_compare %26, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %28 = chlo.broadcast_compare %24, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %29 = chlo.broadcast_compare %26, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %30 = chlo.broadcast_compare %28, %29 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %31 = chlo.broadcast_and %27, %30 : (tensor, tensor) -> tensor + %32 = chlo.broadcast_add %24, %26 : (tensor, tensor) -> tensor + %33 = mhlo.select %31, %32, %26 : tensor, tensor + %cast = tensor.cast %33 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %22:2 = chlo.minimum_broadcast_shapes %6, %7 : tensor, tensor -> tensor, tensor + %23 = shape.rank %22#0 : tensor -> index + %24 = shape.rank %22#1 : tensor -> index + %25 = arith.cmpi sgt, %23, %24 : index + %26 = arith.select %25, %23, %24 : index + %27 = arith.cmpi ule, %26, %c1 : index + %28 = scf.if %27 -> (tensor<*xelem_type>) { + %29 = shape.broadcast %22#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %29 : tensor to tensor<1xindex> + %30 = mhlo.dynamic_reshape %arg1, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %31 = shape.broadcast %22#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %31 : tensor to tensor<1xindex> + %32 = mhlo.dynamic_reshape %arg0, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %33 = chlo.broadcast_remainder %32, %30 : (tensor, tensor) -> tensor + %34 = chlo.broadcast_compare %33, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %35 = chlo.broadcast_compare %30, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %36 = chlo.broadcast_compare %33, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %37 = chlo.broadcast_compare %35, %36 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %38 = chlo.broadcast_and %34, %37 : (tensor, tensor) -> tensor + %39 = chlo.broadcast_add %30, %33 : (tensor, tensor) -> tensor + %40 = mhlo.select %38, %39, %33 : tensor, tensor + %cast_1 = tensor.cast %40 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %29 = arith.cmpi ule, %26, %c2 : index + %30 = scf.if %29 -> (tensor<*xelem_type>) { + %31 = shape.broadcast %22#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %31 : tensor to tensor<2xindex> + %32 = mhlo.dynamic_reshape %arg1, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %33 = shape.broadcast %22#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %33 : tensor to tensor<2xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %35 = chlo.broadcast_remainder %34, %32 : (tensor, tensor) -> tensor + %36 = chlo.broadcast_compare %35, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %37 = chlo.broadcast_compare %32, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %38 = chlo.broadcast_compare %35, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %39 = chlo.broadcast_compare %37, %38 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %40 = chlo.broadcast_and %36, %39 : (tensor, tensor) -> tensor + %41 = chlo.broadcast_add %32, %35 : (tensor, tensor) -> tensor + %42 = mhlo.select %40, %41, %35 : tensor, tensor + %cast_1 = tensor.cast %42 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %31 = arith.cmpi ule, %26, %c3 : index + %32 = scf.if %31 -> (tensor<*xelem_type>) { + %33 = shape.broadcast %22#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<3xindex> + %34 = mhlo.dynamic_reshape %arg1, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %35 = shape.broadcast %22#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<3xindex> + %36 = mhlo.dynamic_reshape %arg0, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %37 = chlo.broadcast_remainder %36, %34 : (tensor, tensor) -> tensor + %38 = chlo.broadcast_compare %37, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %39 = chlo.broadcast_compare %34, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %40 = chlo.broadcast_compare %37, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %41 = chlo.broadcast_compare %39, %40 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %42 = chlo.broadcast_and %38, %41 : (tensor, tensor) -> tensor + %43 = chlo.broadcast_add %34, %37 : (tensor, tensor) -> tensor + %44 = mhlo.select %42, %43, %37 : tensor, tensor + %cast_1 = tensor.cast %44 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %33 = arith.cmpi ule, %26, %c4 : index + %34 = scf.if %33 -> (tensor<*xelem_type>) { + %35 = shape.broadcast %22#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %35 : tensor to tensor<4xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %37 = shape.broadcast %22#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %37 : tensor to tensor<4xindex> + %38 = mhlo.dynamic_reshape %arg0, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %39 = chlo.broadcast_remainder %38, %36 : (tensor, tensor) -> tensor + %40 = chlo.broadcast_compare %39, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %41 = chlo.broadcast_compare %36, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %42 = chlo.broadcast_compare %39, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %43 = chlo.broadcast_compare %41, %42 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %44 = chlo.broadcast_and %40, %43 : (tensor, tensor) -> tensor + %45 = chlo.broadcast_add %36, %39 : (tensor, tensor) -> tensor + %46 = mhlo.select %44, %45, %39 : tensor, tensor + %cast_1 = tensor.cast %46 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %35 = arith.cmpi ule, %26, %c5 : index + cf.assert %35, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %36 = shape.broadcast %22#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %36 : tensor to tensor<5xindex> + %37 = mhlo.dynamic_reshape %arg1, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %38 = shape.broadcast %22#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %38 : tensor to tensor<5xindex> + %39 = mhlo.dynamic_reshape %arg0, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %40 = chlo.broadcast_remainder %39, %37 : (tensor, tensor) -> tensor + %41 = chlo.broadcast_compare %40, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %42 = chlo.broadcast_compare %37, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %43 = chlo.broadcast_compare %40, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %44 = chlo.broadcast_compare %42, %43 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %45 = chlo.broadcast_and %41, %44 : (tensor, tensor) -> tensor + %46 = chlo.broadcast_add %37, %40 : (tensor, tensor) -> tensor + %47 = mhlo.select %45, %46, %40 : tensor, tensor + %cast_1 = tensor.cast %47 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } + scf.yield %34 : tensor<*xelem_type> + } + scf.yield %32 : tensor<*xelem_type> + } + scf.yield %30 : tensor<*xelem_type> + } + scf.yield %28 : tensor<*xelem_type> + } + scf.yield %21 : tensor<*xelem_type> + } + scf.yield %19 : tensor<*xelem_type> + } + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %13 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %14 = shape.broadcast %12, %13 : tensor, tensor -> tensor + %15 = shape.broadcast %11, %14 : tensor, tensor -> tensor + %16 = mhlo.dynamic_reshape %10, %15 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> + return %16 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/floor_mod_unsigned.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/floor_mod_unsigned.mlir.tmpl index b4b7654ca700f7..e9babe7dcf6c6b 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/floor_mod_unsigned.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/floor_mod_unsigned.mlir.tmpl @@ -1,5 +1,129 @@ -func.func @FloorMod_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_remainder %arg0, %arg1 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @FloorMod_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xelem_type>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_remainder %15, %16 : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xelem_type>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %20 = chlo.broadcast_remainder %18, %19 : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xelem_type>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_remainder %21, %22 : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xelem_type>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_remainder %27, %29 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xelem_type>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_remainder %29, %31 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xelem_type>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_remainder %31, %33 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xelem_type>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_remainder %33, %35 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_remainder %34, %36 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } + scf.yield %31 : tensor<*xelem_type> + } + scf.yield %29 : tensor<*xelem_type> + } + scf.yield %27 : tensor<*xelem_type> + } + scf.yield %25 : tensor<*xelem_type> + } + scf.yield %18 : tensor<*xelem_type> + } + scf.yield %16 : tensor<*xelem_type> + } + %10 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> + return %13 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/greater.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/greater.mlir.tmpl index b3df2502ea01fa..f4a4472466165a 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/greater.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/greater.mlir.tmpl @@ -1,5 +1,129 @@ -func.func @Greater_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @Greater_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xoutput_type>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_compare %15, %16 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xoutput_type>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %20 = chlo.broadcast_compare %18, %19 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xoutput_type>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_compare %21, %22 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xoutput_type>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_compare %27, %29 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xoutput_type>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_compare %29, %31 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xoutput_type>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_compare %31, %33 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xoutput_type>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_compare %33, %35 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_compare %34, %36 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } + scf.yield %31 : tensor<*xoutput_type> + } + scf.yield %29 : tensor<*xoutput_type> + } + scf.yield %27 : tensor<*xoutput_type> + } + scf.yield %25 : tensor<*xoutput_type> + } + scf.yield %18 : tensor<*xoutput_type> + } + scf.yield %16 : tensor<*xoutput_type> + } + %10 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xoutput_type>, tensor) -> tensor<*xoutput_type> + return %13 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/greater_equal.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/greater_equal.mlir.tmpl index 8160335dc8936f..573b389ebb5148 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/greater_equal.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/greater_equal.mlir.tmpl @@ -1,5 +1,129 @@ -func.func @GreaterEqual_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @GreaterEqual_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xoutput_type>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_compare %15, %16 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xoutput_type>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %20 = chlo.broadcast_compare %18, %19 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xoutput_type>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_compare %21, %22 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xoutput_type>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_compare %27, %29 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xoutput_type>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_compare %29, %31 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xoutput_type>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_compare %31, %33 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xoutput_type>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_compare %33, %35 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_compare %34, %36 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } + scf.yield %31 : tensor<*xoutput_type> + } + scf.yield %29 : tensor<*xoutput_type> + } + scf.yield %27 : tensor<*xoutput_type> + } + scf.yield %25 : tensor<*xoutput_type> + } + scf.yield %18 : tensor<*xoutput_type> + } + scf.yield %16 : tensor<*xoutput_type> + } + %10 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xoutput_type>, tensor) -> tensor<*xoutput_type> + return %13 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/imag.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/imag.mlir.tmpl index 779bd89fb60407..ac5dd500138258 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/imag.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/imag.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Imag_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.imag %arg0 : (tensor<*xelem_type>) -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @Imag_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = mhlo.imag %2 : (tensor) -> tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xoutput_type> + return %4 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/invert.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/invert.mlir.tmpl index d82438e4bff4b8..bba97765817ea9 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/invert.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/invert.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Invert_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.not %arg0 : tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @Invert_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = mhlo.not %2 : tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xelem_type> + return %4 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/is_finite.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/is_finite.mlir.tmpl index c931c8b28e83f7..20ffb573f72eaf 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/is_finite.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/is_finite.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @IsFinite_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.is_finite %arg0 : (tensor<*xelem_type>) -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @IsFinite_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = mhlo.is_finite %2 : (tensor) -> tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xoutput_type> + return %4 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/is_inf.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/is_inf.mlir.tmpl index 74675ebf4341fe..368dd246600504 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/is_inf.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/is_inf.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @IsInf_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.is_inf %arg0 : tensor<*xelem_type> -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @IsInf_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = chlo.is_inf %2 : tensor -> tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xoutput_type> + return %4 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/is_nan.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/is_nan.mlir.tmpl index 027b7527432601..739ceca3e1d4d0 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/is_nan.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/is_nan.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @IsNan_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_compare %arg0, %arg0 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @IsNan_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = chlo.broadcast_compare %2, %2 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xoutput_type> + return %4 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/left_shift.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/left_shift.mlir.tmpl index b91e79991378c1..a8638d08cc86b8 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/left_shift.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/left_shift.mlir.tmpl @@ -1,5 +1,129 @@ -func.func @LeftShift_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_shift_left %arg0, %arg1 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @LeftShift_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xelem_type>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_shift_left %15, %16 : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xelem_type>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %20 = chlo.broadcast_shift_left %18, %19 : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xelem_type>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_shift_left %21, %22 : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xelem_type>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_shift_left %27, %29 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xelem_type>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_shift_left %29, %31 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xelem_type>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_shift_left %31, %33 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xelem_type>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_shift_left %33, %35 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_shift_left %34, %36 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } + scf.yield %31 : tensor<*xelem_type> + } + scf.yield %29 : tensor<*xelem_type> + } + scf.yield %27 : tensor<*xelem_type> + } + scf.yield %25 : tensor<*xelem_type> + } + scf.yield %18 : tensor<*xelem_type> + } + scf.yield %16 : tensor<*xelem_type> + } + %10 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> + return %13 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/less.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/less.mlir.tmpl index cc68d6f7113892..0d3a3cd1844dc7 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/less.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/less.mlir.tmpl @@ -1,5 +1,129 @@ -func.func @Less_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @Less_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xoutput_type>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_compare %15, %16 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xoutput_type>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %20 = chlo.broadcast_compare %18, %19 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xoutput_type>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_compare %21, %22 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xoutput_type>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_compare %27, %29 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xoutput_type>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_compare %29, %31 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xoutput_type>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_compare %31, %33 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xoutput_type>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_compare %33, %35 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_compare %34, %36 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } + scf.yield %31 : tensor<*xoutput_type> + } + scf.yield %29 : tensor<*xoutput_type> + } + scf.yield %27 : tensor<*xoutput_type> + } + scf.yield %25 : tensor<*xoutput_type> + } + scf.yield %18 : tensor<*xoutput_type> + } + scf.yield %16 : tensor<*xoutput_type> + } + %10 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xoutput_type>, tensor) -> tensor<*xoutput_type> + return %13 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/less_equal.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/less_equal.mlir.tmpl index d81d5d48a9e86d..31a1516a298815 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/less_equal.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/less_equal.mlir.tmpl @@ -1,5 +1,129 @@ -func.func @LessEqual_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @LessEqual_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xoutput_type>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_compare %15, %16 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xoutput_type>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %20 = chlo.broadcast_compare %18, %19 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xoutput_type>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_compare %21, %22 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xoutput_type>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_compare %27, %29 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xoutput_type>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_compare %29, %31 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xoutput_type>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_compare %31, %33 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xoutput_type>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_compare %33, %35 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_compare %34, %36 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } + scf.yield %31 : tensor<*xoutput_type> + } + scf.yield %29 : tensor<*xoutput_type> + } + scf.yield %27 : tensor<*xoutput_type> + } + scf.yield %25 : tensor<*xoutput_type> + } + scf.yield %18 : tensor<*xoutput_type> + } + scf.yield %16 : tensor<*xoutput_type> + } + %10 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xoutput_type>, tensor) -> tensor<*xoutput_type> + return %13 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/lgamma.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/lgamma.mlir.tmpl index fcbc71e4dbd6c0..1af17e4d37e512 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/lgamma.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/lgamma.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Lgamma_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.lgamma %arg0 : tensor<*xelem_type> -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @Lgamma_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = chlo.lgamma %2 : tensor -> tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xoutput_type> + return %4 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/log.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/log.mlir.tmpl index 7cbd3f2d5173cd..a794234f629d58 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/log.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/log.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Log_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.log %arg0 : tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @Log_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = mhlo.log %2 : tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xelem_type> + return %4 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/log1p.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/log1p.mlir.tmpl index 421aa3c61fcd4d..9aeb15d4c69d38 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/log1p.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/log1p.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Log1p_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.log_plus_one %arg0 : tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @Log1p_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = mhlo.log_plus_one %2 : tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xelem_type> + return %4 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/logical_and.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/logical_and.mlir.tmpl index d4a23178c10ecc..62ea4c76bb670a 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/logical_and.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/logical_and.mlir.tmpl @@ -1,5 +1,129 @@ -func.func @LogicalAnd_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_and %arg0, %arg1 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @LogicalAnd_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xelem_type>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_and %15, %16 : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xelem_type>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %20 = chlo.broadcast_and %18, %19 : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xelem_type>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_and %21, %22 : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xelem_type>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_and %27, %29 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xelem_type>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_and %29, %31 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xelem_type>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_and %31, %33 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xelem_type>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_and %33, %35 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_and %34, %36 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } + scf.yield %31 : tensor<*xelem_type> + } + scf.yield %29 : tensor<*xelem_type> + } + scf.yield %27 : tensor<*xelem_type> + } + scf.yield %25 : tensor<*xelem_type> + } + scf.yield %18 : tensor<*xelem_type> + } + scf.yield %16 : tensor<*xelem_type> + } + %10 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> + return %13 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/logical_not.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/logical_not.mlir.tmpl index d2fc7ce3fdd955..0478252983a5b1 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/logical_not.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/logical_not.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @LogicalNot_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.not %arg0 : tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @LogicalNot_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = mhlo.not %2 : tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xelem_type> + return %4 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/logical_or.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/logical_or.mlir.tmpl index ffe5092bd61ced..d515ae404b1e68 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/logical_or.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/logical_or.mlir.tmpl @@ -1,5 +1,129 @@ -func.func @LogicalOr_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_or %arg0, %arg1 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @LogicalOr_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xelem_type>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_or %15, %16 : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xelem_type>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %20 = chlo.broadcast_or %18, %19 : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xelem_type>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_or %21, %22 : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xelem_type>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_or %27, %29 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xelem_type>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_or %29, %31 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xelem_type>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_or %31, %33 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xelem_type>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_or %33, %35 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_or %34, %36 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } + scf.yield %31 : tensor<*xelem_type> + } + scf.yield %29 : tensor<*xelem_type> + } + scf.yield %27 : tensor<*xelem_type> + } + scf.yield %25 : tensor<*xelem_type> + } + scf.yield %18 : tensor<*xelem_type> + } + scf.yield %16 : tensor<*xelem_type> + } + %10 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> + return %13 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/maximum.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/maximum.mlir.tmpl index de2fcd8171576b..8529768fb90285 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/maximum.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/maximum.mlir.tmpl @@ -1,6 +1,129 @@ -func.func @Maximum_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_maximum %arg0, %arg1 - : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @Maximum_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xoutput_type>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_maximum %15, %16 : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xoutput_type>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %20 = chlo.broadcast_maximum %18, %19 : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xoutput_type>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_maximum %21, %22 : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xoutput_type>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_maximum %27, %29 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xoutput_type>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_maximum %29, %31 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xoutput_type>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_maximum %31, %33 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xoutput_type>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_maximum %33, %35 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_maximum %34, %36 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } + scf.yield %31 : tensor<*xoutput_type> + } + scf.yield %29 : tensor<*xoutput_type> + } + scf.yield %27 : tensor<*xoutput_type> + } + scf.yield %25 : tensor<*xoutput_type> + } + scf.yield %18 : tensor<*xoutput_type> + } + scf.yield %16 : tensor<*xoutput_type> + } + %10 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xoutput_type>, tensor) -> tensor<*xoutput_type> + return %13 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/minimum.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/minimum.mlir.tmpl index d3e0e22b815a7f..42ceb3a3306f1f 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/minimum.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/minimum.mlir.tmpl @@ -1,6 +1,129 @@ -func.func @Minimum_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_minimum %arg0, %arg1 - : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @Minimum_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xelem_type>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_minimum %15, %16 : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xelem_type>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %20 = chlo.broadcast_minimum %18, %19 : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xelem_type>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_minimum %21, %22 : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xelem_type>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_minimum %27, %29 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xelem_type>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_minimum %29, %31 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xelem_type>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_minimum %31, %33 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xelem_type>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_minimum %33, %35 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_minimum %34, %36 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } + scf.yield %31 : tensor<*xelem_type> + } + scf.yield %29 : tensor<*xelem_type> + } + scf.yield %27 : tensor<*xelem_type> + } + scf.yield %25 : tensor<*xelem_type> + } + scf.yield %18 : tensor<*xelem_type> + } + scf.yield %16 : tensor<*xelem_type> + } + %10 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> + return %13 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/mul.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/mul.mlir.tmpl index 3cedec1a308144..ace06691c5977c 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/mul.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/mul.mlir.tmpl @@ -1,6 +1,129 @@ -func.func @Mul_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_multiply %arg0, %arg1 - : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @Mul_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xoutput_type>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_multiply %15, %16 : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xoutput_type>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %20 = chlo.broadcast_multiply %18, %19 : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xoutput_type>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_multiply %21, %22 : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xoutput_type>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_multiply %27, %29 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xoutput_type>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_multiply %29, %31 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xoutput_type>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_multiply %31, %33 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xoutput_type>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_multiply %33, %35 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_multiply %34, %36 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } + scf.yield %31 : tensor<*xoutput_type> + } + scf.yield %29 : tensor<*xoutput_type> + } + scf.yield %27 : tensor<*xoutput_type> + } + scf.yield %25 : tensor<*xoutput_type> + } + scf.yield %18 : tensor<*xoutput_type> + } + scf.yield %16 : tensor<*xoutput_type> + } + %10 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xoutput_type>, tensor) -> tensor<*xoutput_type> + return %13 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/mul_no_nan.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/mul_no_nan.mlir.tmpl index 4974e11359410a..4699da4dbf95f6 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/mul_no_nan.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/mul_no_nan.mlir.tmpl @@ -1,8 +1,148 @@ -func.func @MulNoNan_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.constant dense<0.000000e+00> : tensor - %1 = chlo.broadcast_compare %arg1, %0 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %2 = chlo.broadcast_multiply %arg0, %arg1 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - %3 = chlo.broadcast_select %1, %0, %2 : (tensor<*xi1>, tensor, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %3 : tensor<*xoutput_type> +func.func @MulNoNan_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = mhlo.constant dense<0.000000e+00> : tensor + %6 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %7 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %8 = shape.num_elements %6 : tensor -> index + %9 = arith.cmpi eq, %8, %c1 : index + %10 = scf.if %9 -> (tensor<*xelem_type>) { + %17 = shape.num_elements %7 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %19 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %20 = chlo.broadcast_compare %19, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %21 = chlo.broadcast_multiply %18, %19 : (tensor, tensor) -> tensor + %22 = chlo.broadcast_select %20, %5, %21 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %22 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %17 = shape.num_elements %7 : tensor -> index + %18 = arith.cmpi eq, %17, %c1 : index + %19 = scf.if %18 -> (tensor<*xelem_type>) { + %20 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %23 = chlo.broadcast_compare %22, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %24 = chlo.broadcast_multiply %21, %22 : (tensor, tensor) -> tensor + %25 = chlo.broadcast_select %23, %5, %24 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %25 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %20 = shape.shape_eq %6, %7 : tensor, tensor + %21 = scf.if %20 -> (tensor<*xelem_type>) { + %22 = shape.any %6, %7 : tensor, tensor -> tensor + %23 = shape.num_elements %22 : tensor -> index + %from_elements = tensor.from_elements %23 : tensor<1xindex> + %24 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %25 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %26 = chlo.broadcast_compare %25, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %27 = chlo.broadcast_multiply %24, %25 : (tensor, tensor) -> tensor + %28 = chlo.broadcast_select %26, %5, %27 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %28 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %22:2 = chlo.minimum_broadcast_shapes %6, %7 : tensor, tensor -> tensor, tensor + %23 = shape.rank %22#0 : tensor -> index + %24 = shape.rank %22#1 : tensor -> index + %25 = arith.cmpi sgt, %23, %24 : index + %26 = arith.select %25, %23, %24 : index + %27 = arith.cmpi ule, %26, %c1 : index + %28 = scf.if %27 -> (tensor<*xelem_type>) { + %29 = shape.broadcast %22#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %29 : tensor to tensor<1xindex> + %30 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %31 = shape.broadcast %22#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %31 : tensor to tensor<1xindex> + %32 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %33 = chlo.broadcast_compare %32, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %34 = chlo.broadcast_multiply %30, %32 : (tensor, tensor) -> tensor + %35 = chlo.broadcast_select %33, %5, %34 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %35 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %29 = arith.cmpi ule, %26, %c2 : index + %30 = scf.if %29 -> (tensor<*xelem_type>) { + %31 = shape.broadcast %22#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %31 : tensor to tensor<2xindex> + %32 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %33 = shape.broadcast %22#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %33 : tensor to tensor<2xindex> + %34 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %35 = chlo.broadcast_compare %34, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %36 = chlo.broadcast_multiply %32, %34 : (tensor, tensor) -> tensor + %37 = chlo.broadcast_select %35, %5, %36 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %31 = arith.cmpi ule, %26, %c3 : index + %32 = scf.if %31 -> (tensor<*xelem_type>) { + %33 = shape.broadcast %22#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<3xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %35 = shape.broadcast %22#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<3xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %37 = chlo.broadcast_compare %36, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %38 = chlo.broadcast_multiply %34, %36 : (tensor, tensor) -> tensor + %39 = chlo.broadcast_select %37, %5, %38 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %39 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %33 = arith.cmpi ule, %26, %c4 : index + %34 = scf.if %33 -> (tensor<*xelem_type>) { + %35 = shape.broadcast %22#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %35 : tensor to tensor<4xindex> + %36 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %37 = shape.broadcast %22#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %37 : tensor to tensor<4xindex> + %38 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %39 = chlo.broadcast_compare %38, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %40 = chlo.broadcast_multiply %36, %38 : (tensor, tensor) -> tensor + %41 = chlo.broadcast_select %39, %5, %40 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %41 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %35 = arith.cmpi ule, %26, %c5 : index + cf.assert %35, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %36 = shape.broadcast %22#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %36 : tensor to tensor<5xindex> + %37 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %38 = shape.broadcast %22#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %38 : tensor to tensor<5xindex> + %39 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %40 = chlo.broadcast_compare %39, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %41 = chlo.broadcast_multiply %37, %39 : (tensor, tensor) -> tensor + %42 = chlo.broadcast_select %40, %5, %41 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %42 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } + scf.yield %34 : tensor<*xelem_type> + } + scf.yield %32 : tensor<*xelem_type> + } + scf.yield %30 : tensor<*xelem_type> + } + scf.yield %28 : tensor<*xelem_type> + } + scf.yield %21 : tensor<*xelem_type> + } + scf.yield %19 : tensor<*xelem_type> + } + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %13 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %14 = shape.broadcast %12, %13 : tensor, tensor -> tensor + %15 = shape.broadcast %11, %14 : tensor, tensor -> tensor + %16 = mhlo.dynamic_reshape %10, %15 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> + return %16 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/mul_no_nan_cmplx.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/mul_no_nan_cmplx.mlir.tmpl index f25272ef9812a7..80374748a840d9 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/mul_no_nan_cmplx.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/mul_no_nan_cmplx.mlir.tmpl @@ -1,8 +1,148 @@ -func.func @MulNoNan_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor - %1 = chlo.broadcast_compare %arg1, %0 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %2 = chlo.broadcast_multiply %arg0, %arg1 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - %3 = chlo.broadcast_select %1, %0, %2 : (tensor<*xi1>, tensor, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %3 : tensor<*xoutput_type> +func.func @MulNoNan_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = mhlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor + %6 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %7 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %8 = shape.num_elements %6 : tensor -> index + %9 = arith.cmpi eq, %8, %c1 : index + %10 = scf.if %9 -> (tensor<*xelem_type>) { + %17 = shape.num_elements %7 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %19 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %20 = chlo.broadcast_compare %19, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %21 = chlo.broadcast_multiply %18, %19 : (tensor, tensor) -> tensor + %22 = chlo.broadcast_select %20, %5, %21 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %22 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %17 = shape.num_elements %7 : tensor -> index + %18 = arith.cmpi eq, %17, %c1 : index + %19 = scf.if %18 -> (tensor<*xelem_type>) { + %20 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %23 = chlo.broadcast_compare %22, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %24 = chlo.broadcast_multiply %21, %22 : (tensor, tensor) -> tensor + %25 = chlo.broadcast_select %23, %5, %24 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %25 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %20 = shape.shape_eq %6, %7 : tensor, tensor + %21 = scf.if %20 -> (tensor<*xelem_type>) { + %22 = shape.any %6, %7 : tensor, tensor -> tensor + %23 = shape.num_elements %22 : tensor -> index + %from_elements = tensor.from_elements %23 : tensor<1xindex> + %24 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %25 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %26 = chlo.broadcast_compare %25, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %27 = chlo.broadcast_multiply %24, %25 : (tensor, tensor) -> tensor + %28 = chlo.broadcast_select %26, %5, %27 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %28 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %22:2 = chlo.minimum_broadcast_shapes %6, %7 : tensor, tensor -> tensor, tensor + %23 = shape.rank %22#0 : tensor -> index + %24 = shape.rank %22#1 : tensor -> index + %25 = arith.cmpi sgt, %23, %24 : index + %26 = arith.select %25, %23, %24 : index + %27 = arith.cmpi ule, %26, %c1 : index + %28 = scf.if %27 -> (tensor<*xelem_type>) { + %29 = shape.broadcast %22#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %29 : tensor to tensor<1xindex> + %30 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %31 = shape.broadcast %22#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %31 : tensor to tensor<1xindex> + %32 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %33 = chlo.broadcast_compare %32, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %34 = chlo.broadcast_multiply %30, %32 : (tensor, tensor) -> tensor + %35 = chlo.broadcast_select %33, %5, %34 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %35 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %29 = arith.cmpi ule, %26, %c2 : index + %30 = scf.if %29 -> (tensor<*xelem_type>) { + %31 = shape.broadcast %22#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %31 : tensor to tensor<2xindex> + %32 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %33 = shape.broadcast %22#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %33 : tensor to tensor<2xindex> + %34 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %35 = chlo.broadcast_compare %34, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %36 = chlo.broadcast_multiply %32, %34 : (tensor, tensor) -> tensor + %37 = chlo.broadcast_select %35, %5, %36 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %31 = arith.cmpi ule, %26, %c3 : index + %32 = scf.if %31 -> (tensor<*xelem_type>) { + %33 = shape.broadcast %22#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<3xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %35 = shape.broadcast %22#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<3xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %37 = chlo.broadcast_compare %36, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %38 = chlo.broadcast_multiply %34, %36 : (tensor, tensor) -> tensor + %39 = chlo.broadcast_select %37, %5, %38 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %39 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %33 = arith.cmpi ule, %26, %c4 : index + %34 = scf.if %33 -> (tensor<*xelem_type>) { + %35 = shape.broadcast %22#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %35 : tensor to tensor<4xindex> + %36 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %37 = shape.broadcast %22#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %37 : tensor to tensor<4xindex> + %38 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %39 = chlo.broadcast_compare %38, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %40 = chlo.broadcast_multiply %36, %38 : (tensor, tensor) -> tensor + %41 = chlo.broadcast_select %39, %5, %40 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %41 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %35 = arith.cmpi ule, %26, %c5 : index + cf.assert %35, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %36 = shape.broadcast %22#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %36 : tensor to tensor<5xindex> + %37 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %38 = shape.broadcast %22#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %38 : tensor to tensor<5xindex> + %39 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %40 = chlo.broadcast_compare %39, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %41 = chlo.broadcast_multiply %37, %39 : (tensor, tensor) -> tensor + %42 = chlo.broadcast_select %40, %5, %41 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %42 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } + scf.yield %34 : tensor<*xelem_type> + } + scf.yield %32 : tensor<*xelem_type> + } + scf.yield %30 : tensor<*xelem_type> + } + scf.yield %28 : tensor<*xelem_type> + } + scf.yield %21 : tensor<*xelem_type> + } + scf.yield %19 : tensor<*xelem_type> + } + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %13 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %14 = shape.broadcast %12, %13 : tensor, tensor -> tensor + %15 = shape.broadcast %11, %14 : tensor, tensor -> tensor + %16 = mhlo.dynamic_reshape %10, %15 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> + return %16 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/neg.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/neg.mlir.tmpl index cc63a97b842103..e4e303867a4b5f 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/neg.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/neg.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Neg_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.negate %arg0 : tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @Neg_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = mhlo.negate %2 : tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xelem_type> + return %4 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/next_after.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/next_after.mlir.tmpl index 8444fc60ef966c..6c1ab509c9ff4a 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/next_after.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/next_after.mlir.tmpl @@ -1,6 +1,129 @@ -func.func @NextAfter_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_next_after %arg0, %arg1 - : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @NextAfter_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xoutput_type>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_next_after %15, %16 : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xoutput_type>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %20 = chlo.broadcast_next_after %18, %19 : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xoutput_type>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_next_after %21, %22 : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xoutput_type>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_next_after %27, %29 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xoutput_type>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_next_after %29, %31 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xoutput_type>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_next_after %31, %33 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xoutput_type>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_next_after %33, %35 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_next_after %34, %36 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } + scf.yield %31 : tensor<*xoutput_type> + } + scf.yield %29 : tensor<*xoutput_type> + } + scf.yield %27 : tensor<*xoutput_type> + } + scf.yield %25 : tensor<*xoutput_type> + } + scf.yield %18 : tensor<*xoutput_type> + } + scf.yield %16 : tensor<*xoutput_type> + } + %10 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xoutput_type>, tensor) -> tensor<*xoutput_type> + return %13 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/not_equal.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/not_equal.mlir.tmpl index 50a2a2a717cb06..21c1f7c0ca1933 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/not_equal.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/not_equal.mlir.tmpl @@ -1,5 +1,129 @@ -func.func @NotEqual_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @NotEqual_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xoutput_type>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_compare %15, %16 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xoutput_type>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %20 = chlo.broadcast_compare %18, %19 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xoutput_type>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_compare %21, %22 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xoutput_type>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_compare %27, %29 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xoutput_type>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_compare %29, %31 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xoutput_type>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_compare %31, %33 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xoutput_type>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_compare %33, %35 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_compare %34, %36 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } + scf.yield %31 : tensor<*xoutput_type> + } + scf.yield %29 : tensor<*xoutput_type> + } + scf.yield %27 : tensor<*xoutput_type> + } + scf.yield %25 : tensor<*xoutput_type> + } + scf.yield %18 : tensor<*xoutput_type> + } + scf.yield %16 : tensor<*xoutput_type> + } + %10 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xoutput_type>, tensor) -> tensor<*xoutput_type> + return %13 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/ones_like.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/ones_like.mlir.tmpl index e02415e579fa26..e176f96dd6eabf 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/ones_like.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/ones_like.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @OnesLike_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = "chlo.constant_like"(%arg0) {value = 1 : elem_type} : (tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @OnesLike_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = "chlo.constant_like"(%2) {value = 1 : elem_type} : (tensor) -> tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xelem_type> + return %4 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/ones_like_cmplx.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/ones_like_cmplx.mlir.tmpl index b59eb9df791c10..ebc73531160654 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/ones_like_cmplx.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/ones_like_cmplx.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @OnesLike_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = "chlo.constant_like"(%arg0) {value = #complex.number<:scalar_type 1.000000e+00, 0.000000e+00> : elem_type} : (tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @OnesLike_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = "chlo.constant_like"(%2) {value = #complex.number<:scalar_type 1.000000e+00, 0.000000e+00> : elem_type} : (tensor) -> tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xelem_type> + return %4 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/ones_like_float.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/ones_like_float.mlir.tmpl index 2f14209ebfde93..e2e3b9049c8637 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/ones_like_float.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/ones_like_float.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @OnesLike_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = "chlo.constant_like"(%arg0) {value = 1.000000e+00 : elem_type} : (tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @OnesLike_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = "chlo.constant_like"(%2) {value = 1.000000e+00 : elem_type} : (tensor) -> tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xelem_type> + return %4 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/polygamma.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/polygamma.mlir.tmpl index bf6f31603b0aca..1eb02924c460df 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/polygamma.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/polygamma.mlir.tmpl @@ -1,7 +1,129 @@ -func.func @Polygamma_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, - %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> - attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_polygamma %arg0, %arg1 : (tensor<*xelem_type>, - tensor<*xelem_type>) -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @Polygamma_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xoutput_type>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_polygamma %15, %16 : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xoutput_type>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %20 = chlo.broadcast_polygamma %18, %19 : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xoutput_type>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_polygamma %21, %22 : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xoutput_type>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_polygamma %27, %29 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xoutput_type>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_polygamma %29, %31 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xoutput_type>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_polygamma %31, %33 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xoutput_type>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_polygamma %33, %35 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_polygamma %34, %36 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } + scf.yield %31 : tensor<*xoutput_type> + } + scf.yield %29 : tensor<*xoutput_type> + } + scf.yield %27 : tensor<*xoutput_type> + } + scf.yield %25 : tensor<*xoutput_type> + } + scf.yield %18 : tensor<*xoutput_type> + } + scf.yield %16 : tensor<*xoutput_type> + } + %10 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xoutput_type>, tensor) -> tensor<*xoutput_type> + return %13 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/pow.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/pow.mlir.tmpl index 0858ec866d5467..a64b681bb30d95 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/pow.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/pow.mlir.tmpl @@ -1,6 +1,129 @@ -func.func @Pow_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_power %arg0, %arg1 - : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @Pow_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xoutput_type>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_power %15, %16 : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xoutput_type>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %20 = chlo.broadcast_power %18, %19 : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xoutput_type>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_power %21, %22 : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xoutput_type>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_power %27, %29 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xoutput_type>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_power %29, %31 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xoutput_type>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_power %31, %33 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xoutput_type>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_power %33, %35 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_power %34, %36 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } + scf.yield %31 : tensor<*xoutput_type> + } + scf.yield %29 : tensor<*xoutput_type> + } + scf.yield %27 : tensor<*xoutput_type> + } + scf.yield %25 : tensor<*xoutput_type> + } + scf.yield %18 : tensor<*xoutput_type> + } + scf.yield %16 : tensor<*xoutput_type> + } + %10 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xoutput_type>, tensor) -> tensor<*xoutput_type> + return %13 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/real.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/real.mlir.tmpl index ec4206ed73be12..435295f8538eb5 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/real.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/real.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Real_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.real %arg0 : (tensor<*xelem_type>) -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @Real_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = mhlo.real %2 : (tensor) -> tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xoutput_type> + return %4 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/reciprocal.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/reciprocal.mlir.tmpl index 691de823709f55..f9fb66a39658db 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/reciprocal.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/reciprocal.mlir.tmpl @@ -1,6 +1,10 @@ -func.func @Reciprocal_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { +func.func @Reciprocal_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { %0 = mhlo.constant dense<1> : tensor - %1 = chlo.broadcast_divide %0, %arg0 : (tensor, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %1 : tensor<*xoutput_type> + %1 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %2 = shape.num_elements %1 : tensor -> index + %from_elements = tensor.from_elements %2 : tensor<1xindex> + %3 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %4 = chlo.broadcast_divide %0, %3 : (tensor, tensor) -> tensor + %5 = mhlo.dynamic_reshape %4, %1 : (tensor, tensor) -> tensor<*xelem_type> + return %5 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/reciprocal_cmplx.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/reciprocal_cmplx.mlir.tmpl index 4533174f0fb510..02889889064ea3 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/reciprocal_cmplx.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/reciprocal_cmplx.mlir.tmpl @@ -1,6 +1,10 @@ -func.func @Reciprocal_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { +func.func @Reciprocal_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { %0 = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor - %1 = chlo.broadcast_divide %0, %arg0 : (tensor, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %1 : tensor<*xoutput_type> + %1 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %2 = shape.num_elements %1 : tensor -> index + %from_elements = tensor.from_elements %2 : tensor<1xindex> + %3 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %4 = chlo.broadcast_divide %0, %3 : (tensor, tensor) -> tensor + %5 = mhlo.dynamic_reshape %4, %1 : (tensor, tensor) -> tensor<*xelem_type> + return %5 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/reciprocal_float.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/reciprocal_float.mlir.tmpl index afa19f32504afb..75fe52db7a277b 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/reciprocal_float.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/reciprocal_float.mlir.tmpl @@ -1,6 +1,10 @@ -func.func @Reciprocal_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { +func.func @Reciprocal_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { %0 = mhlo.constant dense<1.000000e+00> : tensor - %1 = chlo.broadcast_divide %0, %arg0 : (tensor, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %1 : tensor<*xoutput_type> + %1 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %2 = shape.num_elements %1 : tensor -> index + %from_elements = tensor.from_elements %2 : tensor<1xindex> + %3 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %4 = chlo.broadcast_divide %0, %3 : (tensor, tensor) -> tensor + %5 = mhlo.dynamic_reshape %4, %1 : (tensor, tensor) -> tensor<*xelem_type> + return %5 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/relu.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/relu.mlir.tmpl index bde2dd294361e6..83029a0a31aab0 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/relu.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/relu.mlir.tmpl @@ -1,6 +1,10 @@ -func.func @Relu_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { +func.func @Relu_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { %0 = mhlo.constant dense<0> : tensor - %1 = chlo.broadcast_maximum %0, %arg0 : (tensor, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %1 : tensor<*xoutput_type> + %1 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %2 = shape.num_elements %1 : tensor -> index + %from_elements = tensor.from_elements %2 : tensor<1xindex> + %3 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %4 = chlo.broadcast_maximum %3, %0 : (tensor, tensor) -> tensor + %5 = mhlo.dynamic_reshape %4, %1 : (tensor, tensor) -> tensor<*xelem_type> + return %5 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/relu_float.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/relu_float.mlir.tmpl index c31cf52a3c1f49..77fa9802579cee 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/relu_float.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/relu_float.mlir.tmpl @@ -1,6 +1,10 @@ -func.func @Relu_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { +func.func @Relu_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { %0 = mhlo.constant dense<0.000000e+00> : tensor - %1 = chlo.broadcast_maximum %0, %arg0 : (tensor, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %1 : tensor<*xoutput_type> + %1 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %2 = shape.num_elements %1 : tensor -> index + %from_elements = tensor.from_elements %2 : tensor<1xindex> + %3 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %4 = chlo.broadcast_maximum %3, %0 : (tensor, tensor) -> tensor + %5 = mhlo.dynamic_reshape %4, %1 : (tensor, tensor) -> tensor<*xelem_type> + return %5 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/relu_grad.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/relu_grad.mlir.tmpl index 1408372bff6415..2f5533b8384130 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/relu_grad.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/relu_grad.mlir.tmpl @@ -1,7 +1,14 @@ -func.func @ReluGrad_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1 : tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = "chlo.constant_like"(%arg1) {value = 0.000000e+00 : elem_type} : (tensor<*xelem_type>) -> tensor<*xelem_type> - %1 = mhlo.compare GT, %arg1, %0 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xi1> - %2 = mhlo.select %1, %arg0, %0 : tensor<*xi1>, tensor<*xelem_type> - func.return %2 : tensor<*xoutput_type> +func.func @ReluGrad_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %2 = shape.any %0, %1 : tensor, tensor -> tensor + %3 = shape.num_elements %2 : tensor -> index + %from_elements = tensor.from_elements %3 : tensor<1xindex> + %4 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %5 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %6 = "chlo.constant_like"(%5) {value = 0.000000e+00 : elem_type} : (tensor) -> tensor + %7 = mhlo.compare GT, %5, %6 : (tensor, tensor) -> tensor + %8 = mhlo.select %7, %4, %6 : tensor, tensor + %9 = mhlo.dynamic_reshape %8, %0 : (tensor, tensor) -> tensor<*xelem_type> + return %9 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/right_shift.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/right_shift.mlir.tmpl index b6f6b9f98627fc..1398a68d4713a4 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/right_shift.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/right_shift.mlir.tmpl @@ -1,5 +1,129 @@ -func.func @RightShift_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_shift_right_arithmetic %arg0, %arg1 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @RightShift_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xelem_type>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_shift_right_arithmetic %15, %16 : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xelem_type>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %20 = chlo.broadcast_shift_right_arithmetic %18, %19 : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xelem_type>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_shift_right_arithmetic %21, %22 : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xelem_type>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_shift_right_arithmetic %27, %29 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xelem_type>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_shift_right_arithmetic %29, %31 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xelem_type>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_shift_right_arithmetic %31, %33 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xelem_type>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_shift_right_arithmetic %33, %35 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_shift_right_arithmetic %34, %36 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } + scf.yield %31 : tensor<*xelem_type> + } + scf.yield %29 : tensor<*xelem_type> + } + scf.yield %27 : tensor<*xelem_type> + } + scf.yield %25 : tensor<*xelem_type> + } + scf.yield %18 : tensor<*xelem_type> + } + scf.yield %16 : tensor<*xelem_type> + } + %10 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> + return %13 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/right_shift_unsigned.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/right_shift_unsigned.mlir.tmpl index bdd71bc93c58a5..647783751d5cd0 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/right_shift_unsigned.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/right_shift_unsigned.mlir.tmpl @@ -1,5 +1,129 @@ -func.func @RightShift_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_shift_right_logical %arg0, %arg1 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @RightShift_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xelem_type>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_shift_right_logical %15, %16 : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xelem_type>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %20 = chlo.broadcast_shift_right_logical %18, %19 : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xelem_type>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_shift_right_logical %21, %22 : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xelem_type>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_shift_right_logical %27, %29 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xelem_type>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_shift_right_logical %29, %31 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xelem_type>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_shift_right_logical %31, %33 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xelem_type>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_shift_right_logical %33, %35 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_shift_right_logical %34, %36 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } + scf.yield %31 : tensor<*xelem_type> + } + scf.yield %29 : tensor<*xelem_type> + } + scf.yield %27 : tensor<*xelem_type> + } + scf.yield %25 : tensor<*xelem_type> + } + scf.yield %18 : tensor<*xelem_type> + } + scf.yield %16 : tensor<*xelem_type> + } + %10 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> + return %13 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/rint.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/rint.mlir.tmpl index 347c322bc14da7..1ed002e10c9176 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/rint.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/rint.mlir.tmpl @@ -1,23 +1,27 @@ -func.func @Rint_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.constant dense<0.000000e+00> : tensor - %1 = mhlo.floor %arg0 : tensor<*xelem_type> - %2 = chlo.broadcast_subtract %arg0, %1 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - %3 = mhlo.constant dense<5.000000e-01> : tensor - %4 = chlo.broadcast_compare %2, %3 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %5 = chlo.broadcast_compare %2, %3 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %6 = mhlo.constant dense<2.000000e+00> : tensor - %7 = chlo.broadcast_multiply %arg0, %3 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> - %8 = mhlo.floor %7 : tensor<*xelem_type> - %9 = chlo.broadcast_multiply %8, %6 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> - %10 = chlo.broadcast_subtract %1, %9 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - %11 = mhlo.constant dense<1.000000e+00> : tensor - %12 = chlo.broadcast_compare %10, %11 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %13 = chlo.broadcast_and %5, %12 : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1> - %14 = chlo.broadcast_or %4, %13 : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1> - %15 = chlo.broadcast_add %1, %11 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> - %16 = chlo.broadcast_select %14, %15, %1 : (tensor<*xi1>, tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - %17 = chlo.broadcast_compare %16, %0 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %18 = chlo.broadcast_select %17, %0, %16 : (tensor<*xi1>, tensor, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %18 : tensor<*xoutput_type> +func.func @Rint_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = mhlo.constant dense<1.000000e+00> : tensor + %1 = mhlo.constant dense<2.000000e+00> : tensor + %2 = mhlo.constant dense<5.000000e-01> : tensor + %3 = mhlo.constant dense<0.000000e+00> : tensor + %4 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %5 = shape.num_elements %4 : tensor -> index + %from_elements = tensor.from_elements %5 : tensor<1xindex> + %6 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %7 = mhlo.floor %6 : tensor + %8 = chlo.broadcast_subtract %6, %7 : (tensor, tensor) -> tensor + %9 = chlo.broadcast_compare %8, %2 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %10 = chlo.broadcast_compare %8, %2 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %11 = chlo.broadcast_multiply %6, %2 : (tensor, tensor) -> tensor + %12 = mhlo.floor %11 : tensor + %13 = chlo.broadcast_multiply %12, %1 : (tensor, tensor) -> tensor + %14 = chlo.broadcast_subtract %7, %13 : (tensor, tensor) -> tensor + %15 = chlo.broadcast_compare %14, %0 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %16 = chlo.broadcast_and %10, %15 : (tensor, tensor) -> tensor + %17 = chlo.broadcast_or %9, %16 : (tensor, tensor) -> tensor + %18 = chlo.broadcast_add %7, %0 : (tensor, tensor) -> tensor + %19 = chlo.broadcast_select %17, %18, %7 : (tensor, tensor, tensor) -> tensor + %20 = chlo.broadcast_compare %19, %3 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %21 = chlo.broadcast_select %20, %3, %19 : (tensor, tensor, tensor) -> tensor + %22 = mhlo.dynamic_reshape %21, %4 : (tensor, tensor) -> tensor<*xelem_type> + return %22 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/round.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/round.mlir.tmpl index 8790dccd9431da..eabbd15e939f06 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/round.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/round.mlir.tmpl @@ -1,4 +1,3 @@ -func.func @Round_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - func.return %arg0 : tensor<*xoutput_type> +func.func @Round_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + return %arg0 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/round_float.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/round_float.mlir.tmpl index 4e6549245f24fa..63d31af013b6d1 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/round_float.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/round_float.mlir.tmpl @@ -1,23 +1,27 @@ -func.func @Round_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.constant dense<0.000000e+00> : tensor - %1 = mhlo.floor %arg0 : tensor<*xelem_type> - %2 = chlo.broadcast_subtract %arg0, %1 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - %3 = mhlo.constant dense<5.000000e-01> : tensor - %4 = chlo.broadcast_compare %2, %3 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %5 = chlo.broadcast_compare %2, %3 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %6 = mhlo.constant dense<2.000000e+00> : tensor - %7 = chlo.broadcast_multiply %arg0, %3 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> - %8 = mhlo.floor %7 : tensor<*xelem_type> - %9 = chlo.broadcast_multiply %8, %6 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> - %10 = chlo.broadcast_subtract %1, %9 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - %11 = mhlo.constant dense<1.000000e+00> : tensor - %12 = chlo.broadcast_compare %10, %11 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %13 = chlo.broadcast_and %5, %12 : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1> - %14 = chlo.broadcast_or %4, %13 : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1> - %15 = chlo.broadcast_add %1, %11 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> - %16 = chlo.broadcast_select %14, %15, %1 : (tensor<*xi1>, tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - %17 = chlo.broadcast_compare %16, %0 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %18 = chlo.broadcast_select %17, %0, %16 : (tensor<*xi1>, tensor, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %18 : tensor<*xoutput_type> +func.func @Round_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = mhlo.constant dense<1.000000e+00> : tensor + %1 = mhlo.constant dense<2.000000e+00> : tensor + %2 = mhlo.constant dense<5.000000e-01> : tensor + %3 = mhlo.constant dense<0.000000e+00> : tensor + %4 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %5 = shape.num_elements %4 : tensor -> index + %from_elements = tensor.from_elements %5 : tensor<1xindex> + %6 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %7 = mhlo.floor %6 : tensor + %8 = chlo.broadcast_subtract %6, %7 : (tensor, tensor) -> tensor + %9 = chlo.broadcast_compare %8, %2 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %10 = chlo.broadcast_compare %8, %2 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %11 = chlo.broadcast_multiply %6, %2 : (tensor, tensor) -> tensor + %12 = mhlo.floor %11 : tensor + %13 = chlo.broadcast_multiply %12, %1 : (tensor, tensor) -> tensor + %14 = chlo.broadcast_subtract %7, %13 : (tensor, tensor) -> tensor + %15 = chlo.broadcast_compare %14, %0 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %16 = chlo.broadcast_and %10, %15 : (tensor, tensor) -> tensor + %17 = chlo.broadcast_or %9, %16 : (tensor, tensor) -> tensor + %18 = chlo.broadcast_add %7, %0 : (tensor, tensor) -> tensor + %19 = chlo.broadcast_select %17, %18, %7 : (tensor, tensor, tensor) -> tensor + %20 = chlo.broadcast_compare %19, %3 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %21 = chlo.broadcast_select %20, %3, %19 : (tensor, tensor, tensor) -> tensor + %22 = mhlo.dynamic_reshape %21, %4 : (tensor, tensor) -> tensor<*xelem_type> + return %22 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/rsqrt.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/rsqrt.mlir.tmpl index 5489e909beed69..c3f74cbde99bbc 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/rsqrt.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/rsqrt.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Rsqrt_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.rsqrt %arg0 : tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @Rsqrt_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = mhlo.rsqrt %2 : tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xelem_type> + return %4 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/select_v2.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/select_v2.mlir.tmpl index e37d98ae4a8be4..0c39b00d033bce 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/select_v2.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/select_v2.mlir.tmpl @@ -1,5 +1,183 @@ -func.func @SelectV2_platform_elem_type_output_type(%arg0: tensor<*xi1>, %arg1: tensor<*xelem_type>, %arg2: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_select %arg0, %arg1, %arg2 : (tensor<*xi1>, tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @SelectV2_platform_elem_type_output_type(%arg0: tensor<*xi1>, %arg1: tensor<*xelem_type>, %arg2: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1, 1, 1, 1] : tensor<8xindex> + %c8 = arith.constant 8 : index + %1 = shape.const_shape [1, 1, 1, 1, 1, 1, 1] : tensor<7xindex> + %c7 = arith.constant 7 : index + %2 = shape.const_shape [1, 1, 1, 1, 1, 1] : tensor<6xindex> + %c6 = arith.constant 6 : index + %3 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %4 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %5 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %6 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %7 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %8 = shape.shape_of %arg0 : tensor<*xi1> -> tensor + %9 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %10 = shape.shape_of %arg2 : tensor<*xelem_type> -> tensor + %11 = shape.shape_eq %8, %9 : tensor, tensor + %12 = shape.shape_eq %8, %10 : tensor, tensor + %13 = arith.andi %11, %12 : i1 + %14 = scf.if %13 -> (tensor<*xelem_type>) { + %20 = shape.any %8, %9, %10 : tensor, tensor, tensor -> tensor + %21 = shape.num_elements %20 : tensor -> index + %from_elements = tensor.from_elements %21 : tensor<1xindex> + %22 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xi1>, tensor<1xindex>) -> tensor + %23 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %24 = mhlo.dynamic_reshape %arg2, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %25 = chlo.broadcast_select %22, %23, %24 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %25 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %20:3 = chlo.minimum_broadcast_shapes %8, %9, %10 : tensor, tensor, tensor -> tensor, tensor, tensor + %21 = shape.rank %20#0 : tensor -> index + %22 = shape.rank %20#1 : tensor -> index + %23 = arith.cmpi sgt, %21, %22 : index + %24 = arith.select %23, %21, %22 : index + %25 = shape.rank %20#2 : tensor -> index + %26 = arith.cmpi sgt, %24, %25 : index + %27 = arith.select %26, %24, %25 : index + %28 = arith.cmpi ule, %27, %c1 : index + %29 = scf.if %28 -> (tensor<*xelem_type>) { + %30 = shape.broadcast %20#0, %7 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<1xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xi1>, tensor<1xindex>) -> tensor + %32 = shape.broadcast %20#1, %7 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<1xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %34 = shape.broadcast %20#2, %7 : tensor, tensor<1xindex> -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<1xindex> + %35 = mhlo.dynamic_reshape %arg2, %cast_1 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %36 = chlo.broadcast_select %31, %33, %35 : (tensor, tensor, tensor) -> tensor + %cast_2 = tensor.cast %36 : tensor to tensor<*xelem_type> + scf.yield %cast_2 : tensor<*xelem_type> + } else { + %30 = arith.cmpi ule, %27, %c2 : index + %31 = scf.if %30 -> (tensor<*xelem_type>) { + %32 = shape.broadcast %20#0, %6 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<2xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xi1>, tensor<2xindex>) -> tensor + %34 = shape.broadcast %20#1, %6 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<2xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %36 = shape.broadcast %20#2, %6 : tensor, tensor<2xindex> -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<2xindex> + %37 = mhlo.dynamic_reshape %arg2, %cast_1 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %38 = chlo.broadcast_select %33, %35, %37 : (tensor, tensor, tensor) -> tensor + %cast_2 = tensor.cast %38 : tensor to tensor<*xelem_type> + scf.yield %cast_2 : tensor<*xelem_type> + } else { + %32 = arith.cmpi ule, %27, %c3 : index + %33 = scf.if %32 -> (tensor<*xelem_type>) { + %34 = shape.broadcast %20#0, %5 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %34 : tensor to tensor<3xindex> + %35 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xi1>, tensor<3xindex>) -> tensor + %36 = shape.broadcast %20#1, %5 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %36 : tensor to tensor<3xindex> + %37 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %38 = shape.broadcast %20#2, %5 : tensor, tensor<3xindex> -> tensor + %cast_1 = tensor.cast %38 : tensor to tensor<3xindex> + %39 = mhlo.dynamic_reshape %arg2, %cast_1 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %40 = chlo.broadcast_select %35, %37, %39 : (tensor, tensor, tensor) -> tensor + %cast_2 = tensor.cast %40 : tensor to tensor<*xelem_type> + scf.yield %cast_2 : tensor<*xelem_type> + } else { + %34 = arith.cmpi ule, %27, %c4 : index + %35 = scf.if %34 -> (tensor<*xelem_type>) { + %36 = shape.broadcast %20#0, %4 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %36 : tensor to tensor<4xindex> + %37 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xi1>, tensor<4xindex>) -> tensor + %38 = shape.broadcast %20#1, %4 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %38 : tensor to tensor<4xindex> + %39 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %40 = shape.broadcast %20#2, %4 : tensor, tensor<4xindex> -> tensor + %cast_1 = tensor.cast %40 : tensor to tensor<4xindex> + %41 = mhlo.dynamic_reshape %arg2, %cast_1 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %42 = chlo.broadcast_select %37, %39, %41 : (tensor, tensor, tensor) -> tensor + %cast_2 = tensor.cast %42 : tensor to tensor<*xelem_type> + scf.yield %cast_2 : tensor<*xelem_type> + } else { + %36 = arith.cmpi ule, %27, %c5 : index + %37 = scf.if %36 -> (tensor<*xelem_type>) { + %38 = shape.broadcast %20#0, %3 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %38 : tensor to tensor<5xindex> + %39 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xi1>, tensor<5xindex>) -> tensor + %40 = shape.broadcast %20#1, %3 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %40 : tensor to tensor<5xindex> + %41 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %42 = shape.broadcast %20#2, %3 : tensor, tensor<5xindex> -> tensor + %cast_1 = tensor.cast %42 : tensor to tensor<5xindex> + %43 = mhlo.dynamic_reshape %arg2, %cast_1 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %44 = chlo.broadcast_select %39, %41, %43 : (tensor, tensor, tensor) -> tensor + %cast_2 = tensor.cast %44 : tensor to tensor<*xelem_type> + scf.yield %cast_2 : tensor<*xelem_type> + } else { + %38 = arith.cmpi ule, %27, %c6 : index + %39 = scf.if %38 -> (tensor<*xelem_type>) { + %40 = shape.broadcast %20#0, %2 : tensor, tensor<6xindex> -> tensor + %cast = tensor.cast %40 : tensor to tensor<6xindex> + %41 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xi1>, tensor<6xindex>) -> tensor + %42 = shape.broadcast %20#1, %2 : tensor, tensor<6xindex> -> tensor + %cast_0 = tensor.cast %42 : tensor to tensor<6xindex> + %43 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<6xindex>) -> tensor + %44 = shape.broadcast %20#2, %2 : tensor, tensor<6xindex> -> tensor + %cast_1 = tensor.cast %44 : tensor to tensor<6xindex> + %45 = mhlo.dynamic_reshape %arg2, %cast_1 : (tensor<*xelem_type>, tensor<6xindex>) -> tensor + %46 = chlo.broadcast_select %41, %43, %45 : (tensor, tensor, tensor) -> tensor + %cast_2 = tensor.cast %46 : tensor to tensor<*xelem_type> + scf.yield %cast_2 : tensor<*xelem_type> + } else { + %40 = arith.cmpi ule, %27, %c7 : index + %41 = scf.if %40 -> (tensor<*xelem_type>) { + %42 = shape.broadcast %20#0, %1 : tensor, tensor<7xindex> -> tensor + %cast = tensor.cast %42 : tensor to tensor<7xindex> + %43 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xi1>, tensor<7xindex>) -> tensor + %44 = shape.broadcast %20#1, %1 : tensor, tensor<7xindex> -> tensor + %cast_0 = tensor.cast %44 : tensor to tensor<7xindex> + %45 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<7xindex>) -> tensor + %46 = shape.broadcast %20#2, %1 : tensor, tensor<7xindex> -> tensor + %cast_1 = tensor.cast %46 : tensor to tensor<7xindex> + %47 = mhlo.dynamic_reshape %arg2, %cast_1 : (tensor<*xelem_type>, tensor<7xindex>) -> tensor + %48 = chlo.broadcast_select %43, %45, %47 : (tensor, tensor, tensor) -> tensor + %cast_2 = tensor.cast %48 : tensor to tensor<*xelem_type> + scf.yield %cast_2 : tensor<*xelem_type> + } else { + %42 = arith.cmpi ule, %27, %c8 : index + cf.assert %42, "Input for dynamic binary or n-ary op lowering was of a rank greater than 8" + %43 = shape.broadcast %20#0, %0 : tensor, tensor<8xindex> -> tensor + %cast = tensor.cast %43 : tensor to tensor<8xindex> + %44 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xi1>, tensor<8xindex>) -> tensor + %45 = shape.broadcast %20#1, %0 : tensor, tensor<8xindex> -> tensor + %cast_0 = tensor.cast %45 : tensor to tensor<8xindex> + %46 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<8xindex>) -> tensor + %47 = shape.broadcast %20#2, %0 : tensor, tensor<8xindex> -> tensor + %cast_1 = tensor.cast %47 : tensor to tensor<8xindex> + %48 = mhlo.dynamic_reshape %arg2, %cast_1 : (tensor<*xelem_type>, tensor<8xindex>) -> tensor + %49 = chlo.broadcast_select %44, %46, %48 : (tensor, tensor, tensor) -> tensor + %cast_2 = tensor.cast %49 : tensor to tensor<*xelem_type> + scf.yield %cast_2 : tensor<*xelem_type> + } + scf.yield %41 : tensor<*xelem_type> + } + scf.yield %39 : tensor<*xelem_type> + } + scf.yield %37 : tensor<*xelem_type> + } + scf.yield %35 : tensor<*xelem_type> + } + scf.yield %33 : tensor<*xelem_type> + } + scf.yield %31 : tensor<*xelem_type> + } + scf.yield %29 : tensor<*xelem_type> + } + %15 = shape.shape_of %arg0 : tensor<*xi1> -> tensor + %16 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %17 = shape.shape_of %arg2 : tensor<*xelem_type> -> tensor + %18 = shape.broadcast %15, %16, %17 : tensor, tensor, tensor -> tensor + %19 = mhlo.dynamic_reshape %14, %18 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> + return %19 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/selu.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/selu.mlir.tmpl index b769ef9f1e28a8..5fd921268e5dac 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/selu.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/selu.mlir.tmpl @@ -1,12 +1,16 @@ -func.func @Selu_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.constant dense<0.000000e+00> : tensor - %1 = chlo.broadcast_compare %arg0, %0 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %2 = mhlo.constant dense<1.05070102> : tensor - %3 = chlo.broadcast_multiply %arg0, %2 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> - %4 = mhlo.constant dense<1.75809932> : tensor - %5 = mhlo.exponential_minus_one %arg0 : tensor<*xelem_type> - %6 = chlo.broadcast_multiply %5, %4 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> - %7 = chlo.broadcast_select %1, %3, %6 : (tensor<*xi1>, tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %7 : tensor<*xoutput_type> +func.func @Selu_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = mhlo.constant dense<1.75809932> : tensor + %1 = mhlo.constant dense<1.05070102> : tensor + %2 = mhlo.constant dense<0.000000e+00> : tensor + %3 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %4 = shape.num_elements %3 : tensor -> index + %from_elements = tensor.from_elements %4 : tensor<1xindex> + %5 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %6 = chlo.broadcast_compare %5, %2 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %7 = chlo.broadcast_multiply %5, %1 : (tensor, tensor) -> tensor + %8 = mhlo.exponential_minus_one %5 : tensor + %9 = chlo.broadcast_multiply %8, %0 : (tensor, tensor) -> tensor + %10 = chlo.broadcast_select %6, %7, %9 : (tensor, tensor, tensor) -> tensor + %11 = mhlo.dynamic_reshape %10, %3 : (tensor, tensor) -> tensor<*xelem_type> + return %11 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/sigmoid.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/sigmoid.mlir.tmpl index 61baae4b916c88..a4a8bf75c4cf82 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/sigmoid.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/sigmoid.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Sigmoid_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.logistic %arg0 : tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @Sigmoid_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = mhlo.logistic %2 : tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xelem_type> + return %4 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/sign.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/sign.mlir.tmpl index 01065d26d9f3d1..c7ce54c81b06b3 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/sign.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/sign.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Sign_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.sign %arg0 : tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @Sign_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = mhlo.sign %2 : tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xelem_type> + return %4 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/sin.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/sin.mlir.tmpl index 9d6d99bddb26d5..2ee09d30a8ceab 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/sin.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/sin.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Sin_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.sine %arg0 : tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @Sin_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = mhlo.sine %2 : tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xelem_type> + return %4 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/sinh.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/sinh.mlir.tmpl index 04ead92ada5c9e..350c9c240057aa 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/sinh.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/sinh.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Sinh_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.sinh %arg0 : tensor<*xelem_type> -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @Sinh_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = chlo.sinh %2 : tensor -> tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xoutput_type> + return %4 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/softplus_f16.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/softplus_f16.mlir.tmpl index f20d3fc58d612f..1ccfd2cb2310cc 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/softplus_f16.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/softplus_f16.mlir.tmpl @@ -1,15 +1,18 @@ -func.func @Softplus_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.exponential %arg0 : tensor<*xelem_type> - %1 = mhlo.constant dense<1.220700e-04> : tensor - %2 = mhlo.log %1 : tensor - %3 = mhlo.constant dense<2.000000e+00> : tensor - %4 = chlo.broadcast_add %2, %3 : (tensor, tensor) -> tensor - %5 = mhlo.negate %4 : tensor - %6 = chlo.broadcast_compare %arg0, %5 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %7 = chlo.broadcast_compare %arg0, %4 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %8 = mhlo.log_plus_one %0 : tensor<*xelem_type> - %9 = mhlo.select %7, %0, %8 : tensor<*xi1>, tensor<*xelem_type> - %10 = mhlo.select %6, %arg0, %9 : tensor<*xi1>, tensor<*xelem_type> - func.return %10 : tensor<*xoutput_type> +func.func @Softplus_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = mhlo.constant dense<-9.01091575> : tensor + %1 = mhlo.constant dense<2.000000e+00> : tensor + %2 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %3 = shape.num_elements %2 : tensor -> index + %from_elements = tensor.from_elements %3 : tensor<1xindex> + %4 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %5 = mhlo.exponential %4 : tensor + %6 = chlo.broadcast_add %0, %1 : (tensor, tensor) -> tensor + %7 = mhlo.negate %6 : tensor + %8 = chlo.broadcast_compare %4, %7 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %9 = chlo.broadcast_compare %4, %6 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %10 = mhlo.log_plus_one %5 : tensor + %11 = mhlo.select %9, %5, %10 : tensor, tensor + %12 = mhlo.select %8, %4, %11 : tensor, tensor + %13 = mhlo.dynamic_reshape %12, %2 : (tensor, tensor) -> tensor<*xelem_type> + return %13 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/softplus_f32.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/softplus_f32.mlir.tmpl index fd7541de4f03d9..5ad40b6a769541 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/softplus_f32.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/softplus_f32.mlir.tmpl @@ -1,15 +1,18 @@ -func.func @Softplus_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.exponential %arg0 : tensor<*xelem_type> - %1 = mhlo.constant dense<1.1920929E-7> : tensor - %2 = mhlo.log %1 : tensor - %3 = mhlo.constant dense<2.000000e+00> : tensor - %4 = chlo.broadcast_add %2, %3 : (tensor, tensor) -> tensor - %5 = mhlo.negate %4 : tensor - %6 = chlo.broadcast_compare %arg0, %5 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %7 = chlo.broadcast_compare %arg0, %4 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %8 = mhlo.log_plus_one %0 : tensor<*xelem_type> - %9 = mhlo.select %7, %0, %8 : tensor<*xi1>, tensor<*xelem_type> - %10 = mhlo.select %6, %arg0, %9 : tensor<*xi1>, tensor<*xelem_type> - func.return %10 : tensor<*xoutput_type> +func.func @Softplus_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = mhlo.constant dense<-15.9423847> : tensor + %1 = mhlo.constant dense<2.000000e+00> : tensor + %2 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %3 = shape.num_elements %2 : tensor -> index + %from_elements = tensor.from_elements %3 : tensor<1xindex> + %4 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %5 = mhlo.exponential %4 : tensor + %6 = chlo.broadcast_add %0, %1 : (tensor, tensor) -> tensor + %7 = mhlo.negate %6 : tensor + %8 = chlo.broadcast_compare %4, %7 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %9 = chlo.broadcast_compare %4, %6 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %10 = mhlo.log_plus_one %5 : tensor + %11 = mhlo.select %9, %5, %10 : tensor, tensor + %12 = mhlo.select %8, %4, %11 : tensor, tensor + %13 = mhlo.dynamic_reshape %12, %2 : (tensor, tensor) -> tensor<*xelem_type> + return %13 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/softplus_f64.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/softplus_f64.mlir.tmpl index 2fd33d20f95b58..d8e986994d01f1 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/softplus_f64.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/softplus_f64.mlir.tmpl @@ -1,15 +1,18 @@ -func.func @Softplus_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.exponential %arg0 : tensor<*xelem_type> - %1 = mhlo.constant dense<2.2204460492503131E-16> : tensor - %2 = mhlo.log %1 : tensor - %3 = mhlo.constant dense<2.000000e+00> : tensor - %4 = chlo.broadcast_add %2, %3 : (tensor, tensor) -> tensor - %5 = mhlo.negate %4 : tensor - %6 = chlo.broadcast_compare %arg0, %5 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %7 = chlo.broadcast_compare %arg0, %4 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %8 = mhlo.log_plus_one %0 : tensor<*xelem_type> - %9 = mhlo.select %7, %0, %8 : tensor<*xi1>, tensor<*xelem_type> - %10 = mhlo.select %6, %arg0, %9 : tensor<*xi1>, tensor<*xelem_type> - func.return %10 : tensor<*xoutput_type> +func.func @Softplus_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = mhlo.constant dense<-36.0436516> : tensor + %1 = mhlo.constant dense<2.000000e+00> : tensor + %2 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %3 = shape.num_elements %2 : tensor -> index + %from_elements = tensor.from_elements %3 : tensor<1xindex> + %4 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %5 = mhlo.exponential %4 : tensor + %6 = chlo.broadcast_add %0, %1 : (tensor, tensor) -> tensor + %7 = mhlo.negate %6 : tensor + %8 = chlo.broadcast_compare %4, %7 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %9 = chlo.broadcast_compare %4, %6 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %10 = mhlo.log_plus_one %5 : tensor + %11 = mhlo.select %9, %5, %10 : tensor, tensor + %12 = mhlo.select %8, %4, %11 : tensor, tensor + %13 = mhlo.dynamic_reshape %12, %2 : (tensor, tensor) -> tensor<*xelem_type> + return %13 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/softsign.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/softsign.mlir.tmpl index 65cf477d78571d..a77e5b170d74f0 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/softsign.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/softsign.mlir.tmpl @@ -1,8 +1,12 @@ -func.func @Softsign_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = "chlo.constant_like"(%arg0) {value = 1.000000e+00 : elem_type} : (tensor<*xelem_type>) -> tensor<*xelem_type> - %1 = mhlo.abs %arg0 : tensor<*xelem_type> - %2 = mhlo.add %0, %1 : tensor<*xelem_type> - %3 = mhlo.divide %arg0, %2 : tensor<*xelem_type> - func.return %3 : tensor<*xoutput_type> +func.func @Softsign_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = "chlo.constant_like"(%2) {value = 1.000000e+00 : elem_type} : (tensor) -> tensor + %4 = mhlo.abs %2 : tensor + %5 = mhlo.add %3, %4 : tensor + %6 = mhlo.divide %2, %5 : tensor + %7 = mhlo.dynamic_reshape %6, %0 : (tensor, tensor) -> tensor<*xelem_type> + return %7 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/sqrt.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/sqrt.mlir.tmpl index 4e60e21a6d366e..08707c451e0312 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/sqrt.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/sqrt.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Sqrt_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.sqrt %arg0 : tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @Sqrt_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = mhlo.sqrt %2 : tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xelem_type> + return %4 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/square.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/square.mlir.tmpl index 5863dfde38e5b6..8df4f3d08ac3b4 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/square.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/square.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Square_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_multiply %arg0, %arg0 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @Square_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = chlo.broadcast_multiply %2, %2 : (tensor, tensor) -> tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xelem_type> + return %4 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/squared_difference.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/squared_difference.mlir.tmpl index cf2b58356c9153..0c4220928702b5 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/squared_difference.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/squared_difference.mlir.tmpl @@ -1,6 +1,141 @@ -func.func @SquaredDifference_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_subtract %arg0, %arg1 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - %1 = chlo.broadcast_multiply %0, %0 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %1 : tensor<*xoutput_type> +func.func @SquaredDifference_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xelem_type>) { + %18 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %18 : tensor<1xindex> + %19 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %20 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %21 = chlo.broadcast_subtract %19, %20 : (tensor, tensor) -> tensor + %22 = chlo.broadcast_multiply %21, %21 : (tensor, tensor) -> tensor + %cast = tensor.cast %22 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %18 = shape.num_elements %6 : tensor -> index + %19 = arith.cmpi eq, %18, %c1 : index + %20 = scf.if %19 -> (tensor<*xelem_type>) { + %21 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %21 : tensor<1xindex> + %22 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %24 = chlo.broadcast_subtract %22, %23 : (tensor, tensor) -> tensor + %25 = chlo.broadcast_multiply %24, %24 : (tensor, tensor) -> tensor + %cast = tensor.cast %25 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %21 = shape.shape_eq %5, %6 : tensor, tensor + %22 = scf.if %21 -> (tensor<*xelem_type>) { + %23 = shape.any %5, %6 : tensor, tensor -> tensor + %24 = shape.num_elements %23 : tensor -> index + %from_elements = tensor.from_elements %24 : tensor<1xindex> + %25 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %26 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %27 = chlo.broadcast_subtract %25, %26 : (tensor, tensor) -> tensor + %28 = chlo.broadcast_multiply %27, %27 : (tensor, tensor) -> tensor + %cast = tensor.cast %28 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %23:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %24 = shape.rank %23#0 : tensor -> index + %25 = shape.rank %23#1 : tensor -> index + %26 = arith.cmpi sgt, %24, %25 : index + %27 = arith.select %26, %24, %25 : index + %28 = arith.cmpi ule, %27, %c1 : index + %29 = scf.if %28 -> (tensor<*xelem_type>) { + %30 = shape.broadcast %23#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<1xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %32 = shape.broadcast %23#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<1xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %34 = chlo.broadcast_subtract %31, %33 : (tensor, tensor) -> tensor + %35 = chlo.broadcast_multiply %34, %34 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %35 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %30 = arith.cmpi ule, %27, %c2 : index + %31 = scf.if %30 -> (tensor<*xelem_type>) { + %32 = shape.broadcast %23#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<2xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %34 = shape.broadcast %23#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<2xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %36 = chlo.broadcast_subtract %33, %35 : (tensor, tensor) -> tensor + %37 = chlo.broadcast_multiply %36, %36 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %32 = arith.cmpi ule, %27, %c3 : index + %33 = scf.if %32 -> (tensor<*xelem_type>) { + %34 = shape.broadcast %23#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %34 : tensor to tensor<3xindex> + %35 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %36 = shape.broadcast %23#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %36 : tensor to tensor<3xindex> + %37 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %38 = chlo.broadcast_subtract %35, %37 : (tensor, tensor) -> tensor + %39 = chlo.broadcast_multiply %38, %38 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %39 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %34 = arith.cmpi ule, %27, %c4 : index + %35 = scf.if %34 -> (tensor<*xelem_type>) { + %36 = shape.broadcast %23#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %36 : tensor to tensor<4xindex> + %37 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %38 = shape.broadcast %23#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %38 : tensor to tensor<4xindex> + %39 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %40 = chlo.broadcast_subtract %37, %39 : (tensor, tensor) -> tensor + %41 = chlo.broadcast_multiply %40, %40 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %41 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %36 = arith.cmpi ule, %27, %c5 : index + cf.assert %36, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %37 = shape.broadcast %23#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %37 : tensor to tensor<5xindex> + %38 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %39 = shape.broadcast %23#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %39 : tensor to tensor<5xindex> + %40 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %41 = chlo.broadcast_subtract %38, %40 : (tensor, tensor) -> tensor + %42 = chlo.broadcast_multiply %41, %41 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %42 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } + scf.yield %35 : tensor<*xelem_type> + } + scf.yield %33 : tensor<*xelem_type> + } + scf.yield %31 : tensor<*xelem_type> + } + scf.yield %29 : tensor<*xelem_type> + } + scf.yield %22 : tensor<*xelem_type> + } + scf.yield %20 : tensor<*xelem_type> + } + %10 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %14 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %15 = shape.broadcast %13, %14 : tensor, tensor -> tensor + %16 = shape.broadcast %12, %15 : tensor, tensor -> tensor + %17 = mhlo.dynamic_reshape %9, %16 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> + return %17 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/sub.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/sub.mlir.tmpl index 9422b04a4c3d30..b4f74d07b0e79b 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/sub.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/sub.mlir.tmpl @@ -1,5 +1,129 @@ -func.func @Sub_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_subtract %arg0, %arg1 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @Sub_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xelem_type>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_subtract %15, %16 : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xelem_type>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %20 = chlo.broadcast_subtract %18, %19 : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xelem_type>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_subtract %21, %22 : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xelem_type>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_subtract %27, %29 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xelem_type>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_subtract %29, %31 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xelem_type>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_subtract %31, %33 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xelem_type>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_subtract %33, %35 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_subtract %34, %36 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } + scf.yield %31 : tensor<*xelem_type> + } + scf.yield %29 : tensor<*xelem_type> + } + scf.yield %27 : tensor<*xelem_type> + } + scf.yield %25 : tensor<*xelem_type> + } + scf.yield %18 : tensor<*xelem_type> + } + scf.yield %16 : tensor<*xelem_type> + } + %10 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> + return %13 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/tan.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/tan.mlir.tmpl index ce6df920308c9d..c9d47268d5bd44 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/tan.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/tan.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Tan_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.tan %arg0 : tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @Tan_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = mhlo.tan %2 : tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xelem_type> + return %4 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/tanh.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/tanh.mlir.tmpl index cb47c6fbb43f19..1c2bf19d63c124 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/tanh.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/tanh.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @Tanh_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.tanh %arg0 : tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @Tanh_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = mhlo.tanh %2 : tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xelem_type> + return %4 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/truncate_div.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/truncate_div.mlir.tmpl index 83c9e31ed37385..18921f3a068eea 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/truncate_div.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/truncate_div.mlir.tmpl @@ -1,6 +1,129 @@ -func.func @TruncateDiv_platform_elem_type_output_type( - %arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_divide %arg0, %arg1 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @TruncateDiv_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xelem_type>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_divide %15, %16 : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xelem_type>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %20 = chlo.broadcast_divide %18, %19 : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xelem_type>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_divide %21, %22 : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xelem_type>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_divide %27, %29 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xelem_type>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_divide %29, %31 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xelem_type>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_divide %31, %33 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xelem_type>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_divide %33, %35 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_divide %34, %36 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } + scf.yield %31 : tensor<*xelem_type> + } + scf.yield %29 : tensor<*xelem_type> + } + scf.yield %27 : tensor<*xelem_type> + } + scf.yield %25 : tensor<*xelem_type> + } + scf.yield %18 : tensor<*xelem_type> + } + scf.yield %16 : tensor<*xelem_type> + } + %10 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> + return %13 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/truncate_div_float.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/truncate_div_float.mlir.tmpl index 6b74e21508bc3b..bff954b451b307 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/truncate_div_float.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/truncate_div_float.mlir.tmpl @@ -1,11 +1,169 @@ -func.func @TruncateDiv_platform_elem_type_output_type( - %arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_divide %arg0, %arg1 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - %1 = mhlo.constant dense<0.000000e+00> : tensor - %2 = chlo.broadcast_compare %0, %1 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %3 = mhlo.ceil %0 : tensor<*xelem_type> - %4 = mhlo.floor %0 : tensor<*xelem_type> - %5 = chlo.broadcast_select %2, %3, %4 : (tensor<*xi1>, tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %5 : tensor<*xoutput_type> +func.func @TruncateDiv_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = mhlo.constant dense<0.000000e+00> : tensor + %6 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %7 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %8 = shape.num_elements %6 : tensor -> index + %9 = arith.cmpi eq, %8, %c1 : index + %10 = scf.if %9 -> (tensor<*xelem_type>) { + %22 = shape.num_elements %7 : tensor -> index + %from_elements = tensor.from_elements %22 : tensor<1xindex> + %23 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %24 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %25 = chlo.broadcast_divide %23, %24 : (tensor, tensor) -> tensor + %26 = chlo.broadcast_compare %25, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %27 = mhlo.ceil %25 : tensor + %28 = mhlo.floor %25 : tensor + %29 = chlo.broadcast_select %26, %27, %28 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %29 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %22 = shape.num_elements %7 : tensor -> index + %23 = arith.cmpi eq, %22, %c1 : index + %24 = scf.if %23 -> (tensor<*xelem_type>) { + %25 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %25 : tensor<1xindex> + %26 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %27 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %28 = chlo.broadcast_divide %26, %27 : (tensor, tensor) -> tensor + %29 = chlo.broadcast_compare %28, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %30 = mhlo.ceil %28 : tensor + %31 = mhlo.floor %28 : tensor + %32 = chlo.broadcast_select %29, %30, %31 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %32 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %25 = shape.shape_eq %6, %7 : tensor, tensor + %26 = scf.if %25 -> (tensor<*xelem_type>) { + %27 = shape.any %6, %7 : tensor, tensor -> tensor + %28 = shape.num_elements %27 : tensor -> index + %from_elements = tensor.from_elements %28 : tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %31 = chlo.broadcast_divide %29, %30 : (tensor, tensor) -> tensor + %32 = chlo.broadcast_compare %31, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %33 = mhlo.ceil %31 : tensor + %34 = mhlo.floor %31 : tensor + %35 = chlo.broadcast_select %32, %33, %34 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %35 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %27:2 = chlo.minimum_broadcast_shapes %6, %7 : tensor, tensor -> tensor, tensor + %28 = shape.rank %27#0 : tensor -> index + %29 = shape.rank %27#1 : tensor -> index + %30 = arith.cmpi sgt, %28, %29 : index + %31 = arith.select %30, %28, %29 : index + %32 = arith.cmpi ule, %31, %c1 : index + %33 = scf.if %32 -> (tensor<*xelem_type>) { + %34 = shape.broadcast %27#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %34 : tensor to tensor<1xindex> + %35 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %36 = shape.broadcast %27#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %36 : tensor to tensor<1xindex> + %37 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %38 = chlo.broadcast_divide %35, %37 : (tensor, tensor) -> tensor + %39 = chlo.broadcast_compare %38, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %40 = mhlo.ceil %38 : tensor + %41 = mhlo.floor %38 : tensor + %42 = chlo.broadcast_select %39, %40, %41 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %42 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %34 = arith.cmpi ule, %31, %c2 : index + %35 = scf.if %34 -> (tensor<*xelem_type>) { + %36 = shape.broadcast %27#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %36 : tensor to tensor<2xindex> + %37 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %38 = shape.broadcast %27#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %38 : tensor to tensor<2xindex> + %39 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %40 = chlo.broadcast_divide %37, %39 : (tensor, tensor) -> tensor + %41 = chlo.broadcast_compare %40, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %42 = mhlo.ceil %40 : tensor + %43 = mhlo.floor %40 : tensor + %44 = chlo.broadcast_select %41, %42, %43 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %44 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %36 = arith.cmpi ule, %31, %c3 : index + %37 = scf.if %36 -> (tensor<*xelem_type>) { + %38 = shape.broadcast %27#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %38 : tensor to tensor<3xindex> + %39 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %40 = shape.broadcast %27#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %40 : tensor to tensor<3xindex> + %41 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %42 = chlo.broadcast_divide %39, %41 : (tensor, tensor) -> tensor + %43 = chlo.broadcast_compare %42, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %44 = mhlo.ceil %42 : tensor + %45 = mhlo.floor %42 : tensor + %46 = chlo.broadcast_select %43, %44, %45 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %46 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %38 = arith.cmpi ule, %31, %c4 : index + %39 = scf.if %38 -> (tensor<*xelem_type>) { + %40 = shape.broadcast %27#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %40 : tensor to tensor<4xindex> + %41 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %42 = shape.broadcast %27#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %42 : tensor to tensor<4xindex> + %43 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %44 = chlo.broadcast_divide %41, %43 : (tensor, tensor) -> tensor + %45 = chlo.broadcast_compare %44, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %46 = mhlo.ceil %44 : tensor + %47 = mhlo.floor %44 : tensor + %48 = chlo.broadcast_select %45, %46, %47 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %48 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %40 = arith.cmpi ule, %31, %c5 : index + cf.assert %40, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %41 = shape.broadcast %27#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %41 : tensor to tensor<5xindex> + %42 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %43 = shape.broadcast %27#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %43 : tensor to tensor<5xindex> + %44 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %45 = chlo.broadcast_divide %42, %44 : (tensor, tensor) -> tensor + %46 = chlo.broadcast_compare %45, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %47 = mhlo.ceil %45 : tensor + %48 = mhlo.floor %45 : tensor + %49 = chlo.broadcast_select %46, %47, %48 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %49 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } + scf.yield %39 : tensor<*xelem_type> + } + scf.yield %37 : tensor<*xelem_type> + } + scf.yield %35 : tensor<*xelem_type> + } + scf.yield %33 : tensor<*xelem_type> + } + scf.yield %26 : tensor<*xelem_type> + } + scf.yield %24 : tensor<*xelem_type> + } + %11 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %12 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %13 = shape.broadcast %11, %12 : tensor, tensor -> tensor + %14 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %15 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %16 = shape.broadcast %14, %15 : tensor, tensor -> tensor + %17 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %18 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %19 = shape.broadcast %17, %18 : tensor, tensor -> tensor + %20 = shape.broadcast %13, %16, %19 : tensor, tensor, tensor -> tensor + %21 = mhlo.dynamic_reshape %10, %20 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> + return %21 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/xdivy.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/xdivy.mlir.tmpl index d53eae03ce1bc3..4abf5a03d7fb04 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/xdivy.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/xdivy.mlir.tmpl @@ -1,9 +1,149 @@ -func.func @Xdivy_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, - %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes - {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.constant dense<0.000000e+00> : tensor - %1 = chlo.broadcast_compare %arg0, %0 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %2 = chlo.broadcast_divide %arg0, %arg1 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - %3 = chlo.broadcast_select %1, %arg0, %2 : (tensor<*xi1>, tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %3 : tensor<*xoutput_type> +func.func @Xdivy_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = mhlo.constant dense<0.000000e+00> : tensor + %6 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %7 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %8 = shape.num_elements %6 : tensor -> index + %9 = arith.cmpi eq, %8, %c1 : index + %10 = scf.if %9 -> (tensor<*xelem_type>) { + %18 = shape.num_elements %7 : tensor -> index + %from_elements = tensor.from_elements %18 : tensor<1xindex> + %19 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %20 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %21 = chlo.broadcast_compare %19, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %22 = chlo.broadcast_divide %19, %20 : (tensor, tensor) -> tensor + %23 = chlo.broadcast_select %21, %19, %22 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %18 = shape.num_elements %7 : tensor -> index + %19 = arith.cmpi eq, %18, %c1 : index + %20 = scf.if %19 -> (tensor<*xelem_type>) { + %21 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %21 : tensor<1xindex> + %22 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %24 = chlo.broadcast_compare %22, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %25 = chlo.broadcast_divide %22, %23 : (tensor, tensor) -> tensor + %26 = chlo.broadcast_select %24, %22, %25 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %26 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %21 = shape.shape_eq %6, %7 : tensor, tensor + %22 = scf.if %21 -> (tensor<*xelem_type>) { + %23 = shape.any %6, %7 : tensor, tensor -> tensor + %24 = shape.num_elements %23 : tensor -> index + %from_elements = tensor.from_elements %24 : tensor<1xindex> + %25 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %26 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %27 = chlo.broadcast_compare %25, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %28 = chlo.broadcast_divide %25, %26 : (tensor, tensor) -> tensor + %29 = chlo.broadcast_select %27, %25, %28 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %29 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %23:2 = chlo.minimum_broadcast_shapes %6, %7 : tensor, tensor -> tensor, tensor + %24 = shape.rank %23#0 : tensor -> index + %25 = shape.rank %23#1 : tensor -> index + %26 = arith.cmpi sgt, %24, %25 : index + %27 = arith.select %26, %24, %25 : index + %28 = arith.cmpi ule, %27, %c1 : index + %29 = scf.if %28 -> (tensor<*xelem_type>) { + %30 = shape.broadcast %23#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<1xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %32 = shape.broadcast %23#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<1xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %34 = chlo.broadcast_compare %31, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %35 = chlo.broadcast_divide %31, %33 : (tensor, tensor) -> tensor + %36 = chlo.broadcast_select %34, %31, %35 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %30 = arith.cmpi ule, %27, %c2 : index + %31 = scf.if %30 -> (tensor<*xelem_type>) { + %32 = shape.broadcast %23#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<2xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %34 = shape.broadcast %23#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<2xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %36 = chlo.broadcast_compare %33, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %37 = chlo.broadcast_divide %33, %35 : (tensor, tensor) -> tensor + %38 = chlo.broadcast_select %36, %33, %37 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %38 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %32 = arith.cmpi ule, %27, %c3 : index + %33 = scf.if %32 -> (tensor<*xelem_type>) { + %34 = shape.broadcast %23#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %34 : tensor to tensor<3xindex> + %35 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %36 = shape.broadcast %23#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %36 : tensor to tensor<3xindex> + %37 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %38 = chlo.broadcast_compare %35, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %39 = chlo.broadcast_divide %35, %37 : (tensor, tensor) -> tensor + %40 = chlo.broadcast_select %38, %35, %39 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %40 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %34 = arith.cmpi ule, %27, %c4 : index + %35 = scf.if %34 -> (tensor<*xelem_type>) { + %36 = shape.broadcast %23#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %36 : tensor to tensor<4xindex> + %37 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %38 = shape.broadcast %23#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %38 : tensor to tensor<4xindex> + %39 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %40 = chlo.broadcast_compare %37, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %41 = chlo.broadcast_divide %37, %39 : (tensor, tensor) -> tensor + %42 = chlo.broadcast_select %40, %37, %41 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %42 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %36 = arith.cmpi ule, %27, %c5 : index + cf.assert %36, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %37 = shape.broadcast %23#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %37 : tensor to tensor<5xindex> + %38 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %39 = shape.broadcast %23#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %39 : tensor to tensor<5xindex> + %40 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %41 = chlo.broadcast_compare %38, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %42 = chlo.broadcast_divide %38, %40 : (tensor, tensor) -> tensor + %43 = chlo.broadcast_select %41, %38, %42 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %43 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } + scf.yield %35 : tensor<*xelem_type> + } + scf.yield %33 : tensor<*xelem_type> + } + scf.yield %31 : tensor<*xelem_type> + } + scf.yield %29 : tensor<*xelem_type> + } + scf.yield %22 : tensor<*xelem_type> + } + scf.yield %20 : tensor<*xelem_type> + } + %11 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %12 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %13 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %14 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %15 = shape.broadcast %13, %14 : tensor, tensor -> tensor + %16 = shape.broadcast %11, %12, %15 : tensor, tensor, tensor -> tensor + %17 = mhlo.dynamic_reshape %10, %16 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> + return %17 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/xdivy_cmplx.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/xdivy_cmplx.mlir.tmpl index bc4861cc614af2..ad1baeb9204232 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/xdivy_cmplx.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/xdivy_cmplx.mlir.tmpl @@ -1,9 +1,149 @@ -func.func @Xdivy_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, - %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes - {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor - %1 = chlo.broadcast_compare %arg0, %0 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %2 = chlo.broadcast_divide %arg0, %arg1 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - %3 = chlo.broadcast_select %1, %arg0, %2 : (tensor<*xi1>, tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %3 : tensor<*xoutput_type> +func.func @Xdivy_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = mhlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor + %6 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %7 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %8 = shape.num_elements %6 : tensor -> index + %9 = arith.cmpi eq, %8, %c1 : index + %10 = scf.if %9 -> (tensor<*xelem_type>) { + %18 = shape.num_elements %7 : tensor -> index + %from_elements = tensor.from_elements %18 : tensor<1xindex> + %19 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %20 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %21 = chlo.broadcast_compare %19, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %22 = chlo.broadcast_divide %19, %20 : (tensor, tensor) -> tensor + %23 = chlo.broadcast_select %21, %19, %22 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %18 = shape.num_elements %7 : tensor -> index + %19 = arith.cmpi eq, %18, %c1 : index + %20 = scf.if %19 -> (tensor<*xelem_type>) { + %21 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %21 : tensor<1xindex> + %22 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %24 = chlo.broadcast_compare %22, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %25 = chlo.broadcast_divide %22, %23 : (tensor, tensor) -> tensor + %26 = chlo.broadcast_select %24, %22, %25 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %26 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %21 = shape.shape_eq %6, %7 : tensor, tensor + %22 = scf.if %21 -> (tensor<*xelem_type>) { + %23 = shape.any %6, %7 : tensor, tensor -> tensor + %24 = shape.num_elements %23 : tensor -> index + %from_elements = tensor.from_elements %24 : tensor<1xindex> + %25 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %26 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %27 = chlo.broadcast_compare %25, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %28 = chlo.broadcast_divide %25, %26 : (tensor, tensor) -> tensor + %29 = chlo.broadcast_select %27, %25, %28 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %29 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %23:2 = chlo.minimum_broadcast_shapes %6, %7 : tensor, tensor -> tensor, tensor + %24 = shape.rank %23#0 : tensor -> index + %25 = shape.rank %23#1 : tensor -> index + %26 = arith.cmpi sgt, %24, %25 : index + %27 = arith.select %26, %24, %25 : index + %28 = arith.cmpi ule, %27, %c1 : index + %29 = scf.if %28 -> (tensor<*xelem_type>) { + %30 = shape.broadcast %23#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<1xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %32 = shape.broadcast %23#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<1xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %34 = chlo.broadcast_compare %31, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %35 = chlo.broadcast_divide %31, %33 : (tensor, tensor) -> tensor + %36 = chlo.broadcast_select %34, %31, %35 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %30 = arith.cmpi ule, %27, %c2 : index + %31 = scf.if %30 -> (tensor<*xelem_type>) { + %32 = shape.broadcast %23#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<2xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %34 = shape.broadcast %23#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<2xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %36 = chlo.broadcast_compare %33, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %37 = chlo.broadcast_divide %33, %35 : (tensor, tensor) -> tensor + %38 = chlo.broadcast_select %36, %33, %37 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %38 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %32 = arith.cmpi ule, %27, %c3 : index + %33 = scf.if %32 -> (tensor<*xelem_type>) { + %34 = shape.broadcast %23#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %34 : tensor to tensor<3xindex> + %35 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %36 = shape.broadcast %23#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %36 : tensor to tensor<3xindex> + %37 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %38 = chlo.broadcast_compare %35, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %39 = chlo.broadcast_divide %35, %37 : (tensor, tensor) -> tensor + %40 = chlo.broadcast_select %38, %35, %39 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %40 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %34 = arith.cmpi ule, %27, %c4 : index + %35 = scf.if %34 -> (tensor<*xelem_type>) { + %36 = shape.broadcast %23#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %36 : tensor to tensor<4xindex> + %37 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %38 = shape.broadcast %23#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %38 : tensor to tensor<4xindex> + %39 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %40 = chlo.broadcast_compare %37, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %41 = chlo.broadcast_divide %37, %39 : (tensor, tensor) -> tensor + %42 = chlo.broadcast_select %40, %37, %41 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %42 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %36 = arith.cmpi ule, %27, %c5 : index + cf.assert %36, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %37 = shape.broadcast %23#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %37 : tensor to tensor<5xindex> + %38 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %39 = shape.broadcast %23#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %39 : tensor to tensor<5xindex> + %40 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %41 = chlo.broadcast_compare %38, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %42 = chlo.broadcast_divide %38, %40 : (tensor, tensor) -> tensor + %43 = chlo.broadcast_select %41, %38, %42 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %43 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } + scf.yield %35 : tensor<*xelem_type> + } + scf.yield %33 : tensor<*xelem_type> + } + scf.yield %31 : tensor<*xelem_type> + } + scf.yield %29 : tensor<*xelem_type> + } + scf.yield %22 : tensor<*xelem_type> + } + scf.yield %20 : tensor<*xelem_type> + } + %11 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %12 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %13 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %14 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %15 = shape.broadcast %13, %14 : tensor, tensor -> tensor + %16 = shape.broadcast %11, %12, %15 : tensor, tensor, tensor -> tensor + %17 = mhlo.dynamic_reshape %10, %16 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> + return %17 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/xlog1py.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/xlog1py.mlir.tmpl index 5246d903414270..f53f620dea1f08 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/xlog1py.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/xlog1py.mlir.tmpl @@ -1,10 +1,157 @@ -func.func @Xlog1py_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, - %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes - {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.constant dense<0.000000e+00> : tensor - %1 = chlo.broadcast_compare %arg0, %0 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %2 = mhlo.log_plus_one %arg1 : tensor<*xelem_type> - %3 = chlo.broadcast_multiply %arg0, %2 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - %4 = chlo.broadcast_select %1, %arg0, %3 : (tensor<*xi1>, tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %4 : tensor<*xoutput_type> +func.func @Xlog1py_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = mhlo.constant dense<0.000000e+00> : tensor + %6 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %7 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %8 = shape.num_elements %6 : tensor -> index + %9 = arith.cmpi eq, %8, %c1 : index + %10 = scf.if %9 -> (tensor<*xelem_type>) { + %18 = shape.num_elements %7 : tensor -> index + %from_elements = tensor.from_elements %18 : tensor<1xindex> + %19 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %20 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %21 = chlo.broadcast_compare %19, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %22 = mhlo.log_plus_one %20 : tensor + %23 = chlo.broadcast_multiply %19, %22 : (tensor, tensor) -> tensor + %24 = chlo.broadcast_select %21, %19, %23 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %24 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %18 = shape.num_elements %7 : tensor -> index + %19 = arith.cmpi eq, %18, %c1 : index + %20 = scf.if %19 -> (tensor<*xelem_type>) { + %21 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %21 : tensor<1xindex> + %22 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %24 = chlo.broadcast_compare %22, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %25 = mhlo.log_plus_one %23 : tensor + %26 = chlo.broadcast_multiply %22, %25 : (tensor, tensor) -> tensor + %27 = chlo.broadcast_select %24, %22, %26 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %27 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %21 = shape.shape_eq %6, %7 : tensor, tensor + %22 = scf.if %21 -> (tensor<*xelem_type>) { + %23 = shape.any %6, %7 : tensor, tensor -> tensor + %24 = shape.num_elements %23 : tensor -> index + %from_elements = tensor.from_elements %24 : tensor<1xindex> + %25 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %26 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %27 = chlo.broadcast_compare %25, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %28 = mhlo.log_plus_one %26 : tensor + %29 = chlo.broadcast_multiply %25, %28 : (tensor, tensor) -> tensor + %30 = chlo.broadcast_select %27, %25, %29 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %30 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %23:2 = chlo.minimum_broadcast_shapes %6, %7 : tensor, tensor -> tensor, tensor + %24 = shape.rank %23#0 : tensor -> index + %25 = shape.rank %23#1 : tensor -> index + %26 = arith.cmpi sgt, %24, %25 : index + %27 = arith.select %26, %24, %25 : index + %28 = arith.cmpi ule, %27, %c1 : index + %29 = scf.if %28 -> (tensor<*xelem_type>) { + %30 = shape.broadcast %23#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<1xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %32 = shape.broadcast %23#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<1xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %34 = chlo.broadcast_compare %31, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %35 = mhlo.log_plus_one %33 : tensor + %36 = chlo.broadcast_multiply %31, %35 : (tensor, tensor) -> tensor + %37 = chlo.broadcast_select %34, %31, %36 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %30 = arith.cmpi ule, %27, %c2 : index + %31 = scf.if %30 -> (tensor<*xelem_type>) { + %32 = shape.broadcast %23#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<2xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %34 = shape.broadcast %23#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<2xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %36 = chlo.broadcast_compare %33, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %37 = mhlo.log_plus_one %35 : tensor + %38 = chlo.broadcast_multiply %33, %37 : (tensor, tensor) -> tensor + %39 = chlo.broadcast_select %36, %33, %38 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %39 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %32 = arith.cmpi ule, %27, %c3 : index + %33 = scf.if %32 -> (tensor<*xelem_type>) { + %34 = shape.broadcast %23#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %34 : tensor to tensor<3xindex> + %35 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %36 = shape.broadcast %23#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %36 : tensor to tensor<3xindex> + %37 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %38 = chlo.broadcast_compare %35, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %39 = mhlo.log_plus_one %37 : tensor + %40 = chlo.broadcast_multiply %35, %39 : (tensor, tensor) -> tensor + %41 = chlo.broadcast_select %38, %35, %40 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %41 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %34 = arith.cmpi ule, %27, %c4 : index + %35 = scf.if %34 -> (tensor<*xelem_type>) { + %36 = shape.broadcast %23#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %36 : tensor to tensor<4xindex> + %37 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %38 = shape.broadcast %23#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %38 : tensor to tensor<4xindex> + %39 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %40 = chlo.broadcast_compare %37, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %41 = mhlo.log_plus_one %39 : tensor + %42 = chlo.broadcast_multiply %37, %41 : (tensor, tensor) -> tensor + %43 = chlo.broadcast_select %40, %37, %42 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %43 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %36 = arith.cmpi ule, %27, %c5 : index + cf.assert %36, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %37 = shape.broadcast %23#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %37 : tensor to tensor<5xindex> + %38 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %39 = shape.broadcast %23#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %39 : tensor to tensor<5xindex> + %40 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %41 = chlo.broadcast_compare %38, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %42 = mhlo.log_plus_one %40 : tensor + %43 = chlo.broadcast_multiply %38, %42 : (tensor, tensor) -> tensor + %44 = chlo.broadcast_select %41, %38, %43 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %44 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } + scf.yield %35 : tensor<*xelem_type> + } + scf.yield %33 : tensor<*xelem_type> + } + scf.yield %31 : tensor<*xelem_type> + } + scf.yield %29 : tensor<*xelem_type> + } + scf.yield %22 : tensor<*xelem_type> + } + scf.yield %20 : tensor<*xelem_type> + } + %11 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %12 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %13 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %14 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %15 = shape.broadcast %13, %14 : tensor, tensor -> tensor + %16 = shape.broadcast %11, %12, %15 : tensor, tensor, tensor -> tensor + %17 = mhlo.dynamic_reshape %10, %16 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> + return %17 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/xlog1py_cmplx.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/xlog1py_cmplx.mlir.tmpl index 4b3a9b96e17b4d..c5282303d94f7a 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/xlog1py_cmplx.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/xlog1py_cmplx.mlir.tmpl @@ -1,10 +1,157 @@ -func.func @Xlog1py_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, - %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes - {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor - %1 = chlo.broadcast_compare %arg0, %0 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %2 = mhlo.log_plus_one %arg1 : tensor<*xelem_type> - %3 = chlo.broadcast_multiply %arg0, %2 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - %4 = chlo.broadcast_select %1, %arg0, %3 : (tensor<*xi1>, tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %4 : tensor<*xoutput_type> +func.func @Xlog1py_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor + %6 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %7 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %8 = shape.num_elements %6 : tensor -> index + %9 = arith.cmpi eq, %8, %c1 : index + %10 = scf.if %9 -> (tensor<*xelem_type>) { + %18 = shape.num_elements %7 : tensor -> index + %from_elements = tensor.from_elements %18 : tensor<1xindex> + %19 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %20 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %21 = chlo.broadcast_compare %19, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %22 = mhlo.log_plus_one %20 : tensor + %23 = chlo.broadcast_multiply %19, %22 : (tensor, tensor) -> tensor + %24 = chlo.broadcast_select %21, %19, %23 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %24 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %18 = shape.num_elements %7 : tensor -> index + %19 = arith.cmpi eq, %18, %c1 : index + %20 = scf.if %19 -> (tensor<*xelem_type>) { + %21 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %21 : tensor<1xindex> + %22 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %24 = chlo.broadcast_compare %22, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %25 = mhlo.log_plus_one %23 : tensor + %26 = chlo.broadcast_multiply %22, %25 : (tensor, tensor) -> tensor + %27 = chlo.broadcast_select %24, %22, %26 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %27 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %21 = shape.shape_eq %6, %7 : tensor, tensor + %22 = scf.if %21 -> (tensor<*xelem_type>) { + %23 = shape.any %6, %7 : tensor, tensor -> tensor + %24 = shape.num_elements %23 : tensor -> index + %from_elements = tensor.from_elements %24 : tensor<1xindex> + %25 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %26 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %27 = chlo.broadcast_compare %25, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %28 = mhlo.log_plus_one %26 : tensor + %29 = chlo.broadcast_multiply %25, %28 : (tensor, tensor) -> tensor + %30 = chlo.broadcast_select %27, %25, %29 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %30 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %23:2 = chlo.minimum_broadcast_shapes %6, %7 : tensor, tensor -> tensor, tensor + %24 = shape.rank %23#0 : tensor -> index + %25 = shape.rank %23#1 : tensor -> index + %26 = arith.cmpi sgt, %24, %25 : index + %27 = arith.select %26, %24, %25 : index + %28 = arith.cmpi ule, %27, %c1 : index + %29 = scf.if %28 -> (tensor<*xelem_type>) { + %30 = shape.broadcast %23#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<1xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %32 = shape.broadcast %23#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<1xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %34 = chlo.broadcast_compare %31, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %35 = mhlo.log_plus_one %33 : tensor + %36 = chlo.broadcast_multiply %31, %35 : (tensor, tensor) -> tensor + %37 = chlo.broadcast_select %34, %31, %36 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %30 = arith.cmpi ule, %27, %c2 : index + %31 = scf.if %30 -> (tensor<*xelem_type>) { + %32 = shape.broadcast %23#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<2xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %34 = shape.broadcast %23#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<2xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %36 = chlo.broadcast_compare %33, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %37 = mhlo.log_plus_one %35 : tensor + %38 = chlo.broadcast_multiply %33, %37 : (tensor, tensor) -> tensor + %39 = chlo.broadcast_select %36, %33, %38 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %39 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %32 = arith.cmpi ule, %27, %c3 : index + %33 = scf.if %32 -> (tensor<*xelem_type>) { + %34 = shape.broadcast %23#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %34 : tensor to tensor<3xindex> + %35 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %36 = shape.broadcast %23#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %36 : tensor to tensor<3xindex> + %37 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %38 = chlo.broadcast_compare %35, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %39 = mhlo.log_plus_one %37 : tensor + %40 = chlo.broadcast_multiply %35, %39 : (tensor, tensor) -> tensor + %41 = chlo.broadcast_select %38, %35, %40 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %41 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %34 = arith.cmpi ule, %27, %c4 : index + %35 = scf.if %34 -> (tensor<*xelem_type>) { + %36 = shape.broadcast %23#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %36 : tensor to tensor<4xindex> + %37 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %38 = shape.broadcast %23#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %38 : tensor to tensor<4xindex> + %39 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %40 = chlo.broadcast_compare %37, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %41 = mhlo.log_plus_one %39 : tensor + %42 = chlo.broadcast_multiply %37, %41 : (tensor, tensor) -> tensor + %43 = chlo.broadcast_select %40, %37, %42 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %43 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %36 = arith.cmpi ule, %27, %c5 : index + cf.assert %36, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %37 = shape.broadcast %23#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %37 : tensor to tensor<5xindex> + %38 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %39 = shape.broadcast %23#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %39 : tensor to tensor<5xindex> + %40 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %41 = chlo.broadcast_compare %38, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %42 = mhlo.log_plus_one %40 : tensor + %43 = chlo.broadcast_multiply %38, %42 : (tensor, tensor) -> tensor + %44 = chlo.broadcast_select %41, %38, %43 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %44 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } + scf.yield %35 : tensor<*xelem_type> + } + scf.yield %33 : tensor<*xelem_type> + } + scf.yield %31 : tensor<*xelem_type> + } + scf.yield %29 : tensor<*xelem_type> + } + scf.yield %22 : tensor<*xelem_type> + } + scf.yield %20 : tensor<*xelem_type> + } + %11 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %12 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %13 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %14 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %15 = shape.broadcast %13, %14 : tensor, tensor -> tensor + %16 = shape.broadcast %11, %12, %15 : tensor, tensor, tensor -> tensor + %17 = mhlo.dynamic_reshape %10, %16 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> + return %17 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/xlogy.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/xlogy.mlir.tmpl index 59cf8a1b594114..224b409127f35c 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/xlogy.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/xlogy.mlir.tmpl @@ -1,10 +1,157 @@ -func.func @Xlogy_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, - %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes - {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.constant dense<0.000000e+00> : tensor - %1 = chlo.broadcast_compare %arg0, %0 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %2 = mhlo.log %arg1 : tensor<*xelem_type> - %3 = chlo.broadcast_multiply %arg0, %2 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - %4 = chlo.broadcast_select %1, %arg0, %3 : (tensor<*xi1>, tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %4 : tensor<*xoutput_type> +func.func @Xlogy_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = mhlo.constant dense<0.000000e+00> : tensor + %6 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %7 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %8 = shape.num_elements %6 : tensor -> index + %9 = arith.cmpi eq, %8, %c1 : index + %10 = scf.if %9 -> (tensor<*xelem_type>) { + %18 = shape.num_elements %7 : tensor -> index + %from_elements = tensor.from_elements %18 : tensor<1xindex> + %19 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %20 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %21 = chlo.broadcast_compare %19, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %22 = mhlo.log %20 : tensor + %23 = chlo.broadcast_multiply %19, %22 : (tensor, tensor) -> tensor + %24 = chlo.broadcast_select %21, %19, %23 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %24 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %18 = shape.num_elements %7 : tensor -> index + %19 = arith.cmpi eq, %18, %c1 : index + %20 = scf.if %19 -> (tensor<*xelem_type>) { + %21 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %21 : tensor<1xindex> + %22 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %24 = chlo.broadcast_compare %22, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %25 = mhlo.log %23 : tensor + %26 = chlo.broadcast_multiply %22, %25 : (tensor, tensor) -> tensor + %27 = chlo.broadcast_select %24, %22, %26 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %27 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %21 = shape.shape_eq %6, %7 : tensor, tensor + %22 = scf.if %21 -> (tensor<*xelem_type>) { + %23 = shape.any %6, %7 : tensor, tensor -> tensor + %24 = shape.num_elements %23 : tensor -> index + %from_elements = tensor.from_elements %24 : tensor<1xindex> + %25 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %26 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %27 = chlo.broadcast_compare %25, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %28 = mhlo.log %26 : tensor + %29 = chlo.broadcast_multiply %25, %28 : (tensor, tensor) -> tensor + %30 = chlo.broadcast_select %27, %25, %29 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %30 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %23:2 = chlo.minimum_broadcast_shapes %6, %7 : tensor, tensor -> tensor, tensor + %24 = shape.rank %23#0 : tensor -> index + %25 = shape.rank %23#1 : tensor -> index + %26 = arith.cmpi sgt, %24, %25 : index + %27 = arith.select %26, %24, %25 : index + %28 = arith.cmpi ule, %27, %c1 : index + %29 = scf.if %28 -> (tensor<*xelem_type>) { + %30 = shape.broadcast %23#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<1xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %32 = shape.broadcast %23#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<1xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %34 = chlo.broadcast_compare %31, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %35 = mhlo.log %33 : tensor + %36 = chlo.broadcast_multiply %31, %35 : (tensor, tensor) -> tensor + %37 = chlo.broadcast_select %34, %31, %36 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %30 = arith.cmpi ule, %27, %c2 : index + %31 = scf.if %30 -> (tensor<*xelem_type>) { + %32 = shape.broadcast %23#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<2xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %34 = shape.broadcast %23#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<2xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %36 = chlo.broadcast_compare %33, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %37 = mhlo.log %35 : tensor + %38 = chlo.broadcast_multiply %33, %37 : (tensor, tensor) -> tensor + %39 = chlo.broadcast_select %36, %33, %38 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %39 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %32 = arith.cmpi ule, %27, %c3 : index + %33 = scf.if %32 -> (tensor<*xelem_type>) { + %34 = shape.broadcast %23#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %34 : tensor to tensor<3xindex> + %35 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %36 = shape.broadcast %23#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %36 : tensor to tensor<3xindex> + %37 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %38 = chlo.broadcast_compare %35, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %39 = mhlo.log %37 : tensor + %40 = chlo.broadcast_multiply %35, %39 : (tensor, tensor) -> tensor + %41 = chlo.broadcast_select %38, %35, %40 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %41 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %34 = arith.cmpi ule, %27, %c4 : index + %35 = scf.if %34 -> (tensor<*xelem_type>) { + %36 = shape.broadcast %23#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %36 : tensor to tensor<4xindex> + %37 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %38 = shape.broadcast %23#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %38 : tensor to tensor<4xindex> + %39 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %40 = chlo.broadcast_compare %37, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %41 = mhlo.log %39 : tensor + %42 = chlo.broadcast_multiply %37, %41 : (tensor, tensor) -> tensor + %43 = chlo.broadcast_select %40, %37, %42 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %43 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %36 = arith.cmpi ule, %27, %c5 : index + cf.assert %36, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %37 = shape.broadcast %23#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %37 : tensor to tensor<5xindex> + %38 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %39 = shape.broadcast %23#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %39 : tensor to tensor<5xindex> + %40 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %41 = chlo.broadcast_compare %38, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %42 = mhlo.log %40 : tensor + %43 = chlo.broadcast_multiply %38, %42 : (tensor, tensor) -> tensor + %44 = chlo.broadcast_select %41, %38, %43 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %44 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } + scf.yield %35 : tensor<*xelem_type> + } + scf.yield %33 : tensor<*xelem_type> + } + scf.yield %31 : tensor<*xelem_type> + } + scf.yield %29 : tensor<*xelem_type> + } + scf.yield %22 : tensor<*xelem_type> + } + scf.yield %20 : tensor<*xelem_type> + } + %11 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %12 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %13 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %14 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %15 = shape.broadcast %13, %14 : tensor, tensor -> tensor + %16 = shape.broadcast %11, %12, %15 : tensor, tensor, tensor -> tensor + %17 = mhlo.dynamic_reshape %10, %16 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> + return %17 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/xlogy_cmplx.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/xlogy_cmplx.mlir.tmpl index 441095d77d4940..4c5fa22db919db 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/xlogy_cmplx.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/xlogy_cmplx.mlir.tmpl @@ -1,10 +1,157 @@ -func.func @Xlogy_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, - %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes - {tf_entry, llvm.emit_c_interface} { - %0 = mhlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor - %1 = chlo.broadcast_compare %arg0, %0 {comparison_direction = #chlo} : (tensor<*xelem_type>, tensor) -> tensor<*xi1> - %2 = mhlo.log %arg1 : tensor<*xelem_type> - %3 = chlo.broadcast_multiply %arg0, %2 : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - %4 = chlo.broadcast_select %1, %arg0, %3 : (tensor<*xi1>, tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %4 : tensor<*xoutput_type> +func.func @Xlogy_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = mhlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor + %6 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %7 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %8 = shape.num_elements %6 : tensor -> index + %9 = arith.cmpi eq, %8, %c1 : index + %10 = scf.if %9 -> (tensor<*xelem_type>) { + %18 = shape.num_elements %7 : tensor -> index + %from_elements = tensor.from_elements %18 : tensor<1xindex> + %19 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %20 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %21 = chlo.broadcast_compare %19, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %22 = mhlo.log %20 : tensor + %23 = chlo.broadcast_multiply %19, %22 : (tensor, tensor) -> tensor + %24 = chlo.broadcast_select %21, %19, %23 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %24 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %18 = shape.num_elements %7 : tensor -> index + %19 = arith.cmpi eq, %18, %c1 : index + %20 = scf.if %19 -> (tensor<*xelem_type>) { + %21 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %21 : tensor<1xindex> + %22 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %24 = chlo.broadcast_compare %22, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %25 = mhlo.log %23 : tensor + %26 = chlo.broadcast_multiply %22, %25 : (tensor, tensor) -> tensor + %27 = chlo.broadcast_select %24, %22, %26 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %27 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %21 = shape.shape_eq %6, %7 : tensor, tensor + %22 = scf.if %21 -> (tensor<*xelem_type>) { + %23 = shape.any %6, %7 : tensor, tensor -> tensor + %24 = shape.num_elements %23 : tensor -> index + %from_elements = tensor.from_elements %24 : tensor<1xindex> + %25 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %26 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %27 = chlo.broadcast_compare %25, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %28 = mhlo.log %26 : tensor + %29 = chlo.broadcast_multiply %25, %28 : (tensor, tensor) -> tensor + %30 = chlo.broadcast_select %27, %25, %29 : (tensor, tensor, tensor) -> tensor + %cast = tensor.cast %30 : tensor to tensor<*xelem_type> + scf.yield %cast : tensor<*xelem_type> + } else { + %23:2 = chlo.minimum_broadcast_shapes %6, %7 : tensor, tensor -> tensor, tensor + %24 = shape.rank %23#0 : tensor -> index + %25 = shape.rank %23#1 : tensor -> index + %26 = arith.cmpi sgt, %24, %25 : index + %27 = arith.select %26, %24, %25 : index + %28 = arith.cmpi ule, %27, %c1 : index + %29 = scf.if %28 -> (tensor<*xelem_type>) { + %30 = shape.broadcast %23#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<1xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %32 = shape.broadcast %23#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<1xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %34 = chlo.broadcast_compare %31, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %35 = mhlo.log %33 : tensor + %36 = chlo.broadcast_multiply %31, %35 : (tensor, tensor) -> tensor + %37 = chlo.broadcast_select %34, %31, %36 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %30 = arith.cmpi ule, %27, %c2 : index + %31 = scf.if %30 -> (tensor<*xelem_type>) { + %32 = shape.broadcast %23#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<2xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %34 = shape.broadcast %23#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<2xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %36 = chlo.broadcast_compare %33, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %37 = mhlo.log %35 : tensor + %38 = chlo.broadcast_multiply %33, %37 : (tensor, tensor) -> tensor + %39 = chlo.broadcast_select %36, %33, %38 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %39 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %32 = arith.cmpi ule, %27, %c3 : index + %33 = scf.if %32 -> (tensor<*xelem_type>) { + %34 = shape.broadcast %23#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %34 : tensor to tensor<3xindex> + %35 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %36 = shape.broadcast %23#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %36 : tensor to tensor<3xindex> + %37 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %38 = chlo.broadcast_compare %35, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %39 = mhlo.log %37 : tensor + %40 = chlo.broadcast_multiply %35, %39 : (tensor, tensor) -> tensor + %41 = chlo.broadcast_select %38, %35, %40 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %41 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %34 = arith.cmpi ule, %27, %c4 : index + %35 = scf.if %34 -> (tensor<*xelem_type>) { + %36 = shape.broadcast %23#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %36 : tensor to tensor<4xindex> + %37 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %38 = shape.broadcast %23#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %38 : tensor to tensor<4xindex> + %39 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %40 = chlo.broadcast_compare %37, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %41 = mhlo.log %39 : tensor + %42 = chlo.broadcast_multiply %37, %41 : (tensor, tensor) -> tensor + %43 = chlo.broadcast_select %40, %37, %42 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %43 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } else { + %36 = arith.cmpi ule, %27, %c5 : index + cf.assert %36, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %37 = shape.broadcast %23#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %37 : tensor to tensor<5xindex> + %38 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %39 = shape.broadcast %23#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %39 : tensor to tensor<5xindex> + %40 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %41 = chlo.broadcast_compare %38, %5 {comparison_direction = #chlo} : (tensor, tensor) -> tensor + %42 = mhlo.log %40 : tensor + %43 = chlo.broadcast_multiply %38, %42 : (tensor, tensor) -> tensor + %44 = chlo.broadcast_select %41, %38, %43 : (tensor, tensor, tensor) -> tensor + %cast_1 = tensor.cast %44 : tensor to tensor<*xelem_type> + scf.yield %cast_1 : tensor<*xelem_type> + } + scf.yield %35 : tensor<*xelem_type> + } + scf.yield %33 : tensor<*xelem_type> + } + scf.yield %31 : tensor<*xelem_type> + } + scf.yield %29 : tensor<*xelem_type> + } + scf.yield %22 : tensor<*xelem_type> + } + scf.yield %20 : tensor<*xelem_type> + } + %11 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %12 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %13 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %14 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %15 = shape.broadcast %13, %14 : tensor, tensor -> tensor + %16 = shape.broadcast %11, %12, %15 : tensor, tensor, tensor -> tensor + %17 = mhlo.dynamic_reshape %10, %16 : (tensor<*xelem_type>, tensor) -> tensor<*xelem_type> + return %17 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/zeros_like.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/zeros_like.mlir.tmpl index d7bd4c96416a62..995255ccba44c6 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/zeros_like.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/zeros_like.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @ZerosLike_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = "chlo.constant_like"(%arg0) {value = 0 : elem_type} : (tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @ZerosLike_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = "chlo.constant_like"(%2) {value = 0 : elem_type} : (tensor) -> tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xelem_type> + return %4 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/zeros_like_cmplx.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/zeros_like_cmplx.mlir.tmpl index 13f4559d002ff8..804330e2d337ad 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/zeros_like_cmplx.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/zeros_like_cmplx.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @ZerosLike_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = "chlo.constant_like"(%arg0) {value = #complex.number<:scalar_type 0.000000e+00, 0.000000e+00> : elem_type} : (tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @ZerosLike_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = "chlo.constant_like"(%2) {value = #complex.number<:scalar_type 0.000000e+00, 0.000000e+00> : elem_type} : (tensor) -> tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xelem_type> + return %4 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/zeros_like_float.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/zeros_like_float.mlir.tmpl index ec2facf391fef9..169b88dc35c313 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/zeros_like_float.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/zeros_like_float.mlir.tmpl @@ -1,5 +1,9 @@ -func.func @ZerosLike_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : elem_type} : (tensor<*xelem_type>) -> tensor<*xelem_type> - func.return %0 : tensor<*xoutput_type> +func.func @ZerosLike_platform_elem_type_output_type(%arg0: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %3 = "chlo.constant_like"(%2) {value = 0.000000e+00 : elem_type} : (tensor) -> tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xelem_type> + return %4 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/zeta.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/zeta.mlir.tmpl index d7c88c2119b539..c31280f1daa362 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/zeta.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/zeta.mlir.tmpl @@ -1,6 +1,129 @@ -func.func @Zeta_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) - -> tensor<*xoutput_type> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_zeta %arg0, %arg1 - : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xoutput_type> - func.return %0 : tensor<*xoutput_type> +func.func @Zeta_platform_elem_type_output_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) -> tensor<*xoutput_type> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xoutput_type>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xelem_type>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_zeta %15, %16 : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xoutput_type>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xelem_type>) -> tensor + %20 = chlo.broadcast_zeta %18, %19 : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xoutput_type>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_zeta %21, %22 : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xoutput_type> + scf.yield %cast : tensor<*xoutput_type> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xoutput_type>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_zeta %27, %29 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xoutput_type>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_zeta %29, %31 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xoutput_type>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_zeta %31, %33 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xoutput_type>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_zeta %33, %35 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xelem_type>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_zeta %34, %36 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xoutput_type> + scf.yield %cast_1 : tensor<*xoutput_type> + } + scf.yield %31 : tensor<*xoutput_type> + } + scf.yield %29 : tensor<*xoutput_type> + } + scf.yield %27 : tensor<*xoutput_type> + } + scf.yield %25 : tensor<*xoutput_type> + } + scf.yield %18 : tensor<*xoutput_type> + } + scf.yield %16 : tensor<*xoutput_type> + } + %10 = shape.shape_of %arg0 : tensor<*xelem_type> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xelem_type> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xoutput_type>, tensor) -> tensor<*xoutput_type> + return %13 : tensor<*xoutput_type> } diff --git a/tensorflow/core/kernels/mutex_ops.cc b/tensorflow/core/kernels/mutex_ops.cc index dd1d02cfdfb0cb..6147afc73e58a0 100644 --- a/tensorflow/core/kernels/mutex_ops.cc +++ b/tensorflow/core/kernels/mutex_ops.cc @@ -113,7 +113,7 @@ class Mutex : public ResourceBase { delete cancelled; } if (local_locked) { // Not cancelled. - fn_(OkStatus(), + fn_(absl::OkStatus(), SharedLockReleaser{std::make_shared(this)}); } else { fn_(errors::Cancelled("Lock acquisition cancelled."), @@ -146,7 +146,7 @@ class MutexLockOp : public AsyncOpKernel { [c](Mutex** ptr) { *ptr = new Mutex( c, HandleFromInput(c, 0).name()); - return OkStatus(); + return absl::OkStatus(); }), done); diff --git a/tensorflow/core/kernels/ops_testutil.cc b/tensorflow/core/kernels/ops_testutil.cc index 98618167f358a8..06762dbf9652eb 100644 --- a/tensorflow/core/kernels/ops_testutil.cc +++ b/tensorflow/core/kernels/ops_testutil.cc @@ -149,7 +149,7 @@ Status OpsTestBase::InitOpWithGraphVersion(int graph_def_version) { device_->resource_manager(), props, graph_def_version, &kernel)); kernel_.reset(kernel); input_types_ = kernel_->input_types(); - return OkStatus(); + return absl::OkStatus(); } static std::function)>* GetDefaultRunner() { diff --git a/tensorflow/core/kernels/ops_util_test.cc b/tensorflow/core/kernels/ops_util_test.cc index 955bca25b16776..82275979f758e1 100644 --- a/tensorflow/core/kernels/ops_util_test.cc +++ b/tensorflow/core/kernels/ops_util_test.cc @@ -133,7 +133,7 @@ class OpsUtilTest : public ::testing::Test { static void VerifyBcastValues(bcast_struct bcast) { int new_index, new_size; - EXPECT_EQ(OkStatus(), + EXPECT_EQ(absl::OkStatus(), GetBroadcastSize(bcast.input.index, bcast.input.in_size, bcast.input.ksize, bcast.input.stride, bcast.input.pad_size, &new_index, &new_size)); diff --git a/tensorflow/core/kernels/padding_fifo_queue.cc b/tensorflow/core/kernels/padding_fifo_queue.cc index 95a5483437e307..57b46a0a06f48e 100644 --- a/tensorflow/core/kernels/padding_fifo_queue.cc +++ b/tensorflow/core/kernels/padding_fifo_queue.cc @@ -52,7 +52,7 @@ Status PaddingFIFOQueue::Initialize() { " shapes."); } - return OkStatus(); + return absl::OkStatus(); } /* static */ @@ -63,7 +63,7 @@ Status PaddingFIFOQueue::GetElementComponent( TF_RETURN_IF_ERROR( ctx->allocate_temp(tuple[component].dtype(), element_shape, out_tensor)); *out_tensor = tuple[component]; - return OkStatus(); + return absl::OkStatus(); } void PaddingFIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx, @@ -243,7 +243,7 @@ Status PaddingFIFOQueue::ValidateTuple(const Tuple& tuple) { tuple[i].shape().DebugString()); } } - return OkStatus(); + return absl::OkStatus(); } Status PaddingFIFOQueue::ValidateManyTuple(const Tuple& tuple) { @@ -260,7 +260,7 @@ Status PaddingFIFOQueue::ValidateManyTuple(const Tuple& tuple) { tuple[i].shape().DebugString()); } } - return OkStatus(); + return absl::OkStatus(); } Status PaddingFIFOQueue::CompatibleNodeDefShapes( @@ -275,7 +275,7 @@ Status PaddingFIFOQueue::CompatibleNodeDefShapes( " but requested component shapes were ", PartialTensorShapeUtils::PartialShapeListString(requested_shapes)); } else { - return OkStatus(); + return absl::OkStatus(); } } @@ -288,7 +288,7 @@ Status PaddingFIFOQueue::MatchesNodeDef(const NodeDef& node_def) { TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_)); TF_RETURN_IF_ERROR(MatchesNodeDefTypes(node_def)); TF_RETURN_IF_ERROR(CompatibleNodeDefShapes(node_def)); - return OkStatus(); + return absl::OkStatus(); } static Status ValidateElementToLargerSlice(const Tensor& element, @@ -303,7 +303,7 @@ static Status ValidateElementToLargerSlice(const Tensor& element, "Shapes are: [element]: ", element.shape().DebugString(), ", [parent slice]: ", chip_shape.DebugString()); } - return OkStatus(); + return absl::OkStatus(); } template @@ -314,7 +314,7 @@ Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent, return s; } if (element.NumElements() == 0) { - return OkStatus(); + return absl::OkStatus(); } auto element_t = element.tensor(); auto parent_t = parent->tensor(); @@ -326,7 +326,7 @@ Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent, slice_size[i] = element_t.dimension(i - 1); } parent_t.slice(slice_indices, slice_size) = element_t.reshape(slice_size); - return OkStatus(); + return absl::OkStatus(); } namespace { diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc index 724c7aae337800..761920189c3933 100644 --- a/tensorflow/core/kernels/partitioned_function_ops.cc +++ b/tensorflow/core/kernels/partitioned_function_ops.cc @@ -162,7 +162,7 @@ Status PartitionedCallOp::FillOutputDevices( } } } - return OkStatus(); + return absl::OkStatus(); } Status PartitionedCallOp::Instantiate(FunctionLibraryRuntime* lib, @@ -227,7 +227,7 @@ Status PartitionedCallOp::Instantiate(FunctionLibraryRuntime* lib, TF_RETURN_IF_ERROR( lib->Instantiate(func_->name(), AttrSlice(&func_->attr()), opts, handle)); - return OkStatus(); + return absl::OkStatus(); } void PartitionedCallOp::RunFunction(FunctionLibraryRuntime::Handle handle, diff --git a/tensorflow/core/kernels/poisson-loss.h b/tensorflow/core/kernels/poisson-loss.h index 12888bda3135aa..7cdee546dd2b17 100644 --- a/tensorflow/core/kernels/poisson-loss.h +++ b/tensorflow/core/kernels/poisson-loss.h @@ -84,7 +84,7 @@ class PoissonLossUpdater : public DualLossUpdater { "Only non-negative labels can be used with the Poisson log loss. " "Found example with label: ", *example_label); } - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/pooling_ops_common.cc b/tensorflow/core/kernels/pooling_ops_common.cc index 8d1220de4d4e0e..55a5dde92f12ca 100644 --- a/tensorflow/core/kernels/pooling_ops_common.cc +++ b/tensorflow/core/kernels/pooling_ops_common.cc @@ -113,7 +113,7 @@ Status CheckPaddingSize(int64_t window_rows, int64_t window_cols, "window size ", window_cols); } - return OkStatus(); + return absl::OkStatus(); } PoolParameters::PoolParameters(OpKernelContext* context, @@ -220,7 +220,7 @@ Status PoolParameters::forward_output_shape(TensorShape* shape) { *shape = TensorShape( {tensor_in_batch, tensor_in_rows, tensor_in_cols, out_depth}); } - return OkStatus(); + return absl::OkStatus(); } #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM @@ -402,6 +402,9 @@ void DnnPoolingImpl(OpKernelContext* context, se::dnn::PoolingMode pooling_mode, auto* stream = context->op_device_context()->stream(); OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); + auto* dnn = stream->parent()->AsDnn(); + OP_REQUIRES(context, dnn != nullptr, + errors::Internal("No DNN support for stream.")); #if TENSORFLOW_USE_ROCM static int64 PoolingScratchSize = GetDnnWorkspaceLimit( @@ -411,13 +414,14 @@ void DnnPoolingImpl(OpKernelContext* context, se::dnn::PoolingMode pooling_mode, DnnScratchAllocator scratch_allocator(PoolingScratchSize, context); OP_REQUIRES_OK(context, - stream->ThenPoolForward(pooling_desc, GetNumericOptions(), - input_desc, input_data, output_desc, - &output_data, &scratch_allocator)); + dnn->PoolForward(stream, pooling_desc, GetNumericOptions(), + input_desc, input_data, output_desc, + &output_data, &scratch_allocator)); #else - OP_REQUIRES_OK(context, stream->ThenPoolForward( - pooling_desc, GetNumericOptions(), input_desc, - input_data, output_desc, &output_data)); + OP_REQUIRES_OK( + context, + dnn->PoolForward(stream, pooling_desc, GetNumericOptions(), input_desc, + input_data, output_desc, &output_data)); #endif #if CUDNN_VERSION < 7300 @@ -798,6 +802,9 @@ void DnnPoolingGradImpl(OpKernelContext* context, auto* stream = context->op_device_context()->stream(); OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); + auto* dnn = stream->parent()->AsDnn(); + OP_REQUIRES(context, dnn != nullptr, + errors::Internal("No DNN support for stream.")); #if TENSORFLOW_USE_ROCM static int64 PoolingScratchSize = GetDnnWorkspaceLimit( @@ -808,16 +815,16 @@ void DnnPoolingGradImpl(OpKernelContext* context, DnnScratchAllocator scratch_allocator(PoolingScratchSize, context); OP_REQUIRES_OK( context, - stream->ThenPoolBackward( - pooling_desc, GetNumericOptions(), orig_input_desc, orig_input_data, - orig_output_desc, orig_output_data, output_backprop_data, - &input_backprop_data, &scratch_allocator)); + dnn->PoolBackward(stream, pooling_desc, GetNumericOptions(), + orig_input_desc, orig_input_data, orig_output_desc, + orig_output_data, output_backprop_data, + &input_backprop_data, &scratch_allocator)); #else OP_REQUIRES_OK(context, - stream->ThenPoolBackward( - pooling_desc, GetNumericOptions(), orig_input_desc, - orig_input_data, orig_output_desc, orig_output_data, - output_backprop_data, &input_backprop_data)); + dnn->PoolBackward(stream, pooling_desc, GetNumericOptions(), + orig_input_desc, orig_input_data, + orig_output_desc, orig_output_data, + output_backprop_data, &input_backprop_data)); #endif if (padding == EXPLICIT && (params.pad_top != params.pad_bottom || diff --git a/tensorflow/core/kernels/priority_queue.cc b/tensorflow/core/kernels/priority_queue.cc index b5426a0f75e46c..dde3ad973610b4 100644 --- a/tensorflow/core/kernels/priority_queue.cc +++ b/tensorflow/core/kernels/priority_queue.cc @@ -57,7 +57,7 @@ Status PriorityQueue::Initialize() { "is: ", component_shapes_[0].DebugString()); } - return OkStatus(); + return absl::OkStatus(); } void PriorityQueue::DequeueLocked(OpKernelContext* ctx, Tuple* tuple) { @@ -124,7 +124,7 @@ Status PriorityQueue::GetElementComponentFromBatch( ctx->allocate_temp(tuple[component].dtype(), element_shape, out_element)); TF_RETURN_IF_ERROR( batch_util::CopySliceToElement(tuple[component], out_element, index)); - return OkStatus(); + return absl::OkStatus(); } void PriorityQueue::TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx, @@ -393,7 +393,7 @@ Status PriorityQueue::MatchesNodeDef(const NodeDef& node_def) { TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_)); TF_RETURN_IF_ERROR(MatchesPriorityNodeDefTypes(node_def)); TF_RETURN_IF_ERROR(MatchesPriorityNodeDefShapes(node_def)); - return OkStatus(); + return absl::OkStatus(); } Status PriorityQueue::MatchesPriorityNodeDefTypes( @@ -409,7 +409,7 @@ Status PriorityQueue::MatchesPriorityNodeDefTypes( " but requested component types were ", DataTypeSliceString(requested_dtypes)); } - return OkStatus(); + return absl::OkStatus(); } Status PriorityQueue::MatchesPriorityNodeDefShapes( @@ -424,7 +424,7 @@ Status PriorityQueue::MatchesPriorityNodeDefShapes( " but requested component shapes were ", ShapeListString(requested_shapes)); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/kernels/quantized_concat_op.cc b/tensorflow/core/kernels/quantized_concat_op.cc index 19ec612fdea654..997bc0293ba7a5 100644 --- a/tensorflow/core/kernels/quantized_concat_op.cc +++ b/tensorflow/core/kernels/quantized_concat_op.cc @@ -118,7 +118,7 @@ class QuantizedConcatOp : public OpKernel { *output_min = overall_min; *output_max = overall_max; } - return OkStatus(); + return absl::OkStatus(); } int64_t CalculateInputsDim(const TensorShape& input_shape, @@ -171,7 +171,7 @@ class QuantizedConcatOp : public OpKernel { } *output_concat_dim += in.dims() > 0 ? in.dim_size(concat_dim) : 1; } - return OkStatus(); + return absl::OkStatus(); } void Compute(OpKernelContext* context) override { diff --git a/tensorflow/core/kernels/quantized_conv_ops.cc b/tensorflow/core/kernels/quantized_conv_ops.cc index 495836794ff9e6..aa5135dc6bdfc9 100644 --- a/tensorflow/core/kernels/quantized_conv_ops.cc +++ b/tensorflow/core/kernels/quantized_conv_ops.cc @@ -284,7 +284,7 @@ class Im2ColConvFunctor { (kMaxChunkSize + (sizeof(T1) - 1)) / sizeof(T1); #endif *resource = new Im2ColBufferResource(); - return OkStatus(); + return absl::OkStatus(); }; OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate( "Conv2d", "im2col_buffer", diff --git a/tensorflow/core/kernels/queue_base.cc b/tensorflow/core/kernels/queue_base.cc index 6ab7c97ba3f414..1b1d7f448ada46 100644 --- a/tensorflow/core/kernels/queue_base.cc +++ b/tensorflow/core/kernels/queue_base.cc @@ -44,7 +44,7 @@ Status HandleSliceToElement(const Tensor& parent, Tensor* element, } auto parent_as_matrix = parent.flat_outer_dims(); element->flat() = parent_as_matrix.chip(index, 0); - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -74,7 +74,7 @@ Status QueueBase::ValidateTupleCommon(const Tuple& tuple) const { DataTypeString(tuple[i].dtype())); } } - return OkStatus(); + return absl::OkStatus(); } // static @@ -96,7 +96,7 @@ Status QueueBase::MatchesNodeDefOp(const NodeDef& node_def, "' that does not match type of Node '", node_def.name(), "': ", node_def.op()); } - return OkStatus(); + return absl::OkStatus(); } Status QueueBase::MatchesNodeDefCapacity(const NodeDef& node_def, @@ -109,7 +109,7 @@ Status QueueBase::MatchesNodeDefCapacity(const NodeDef& node_def, capacity, " but requested capacity was ", requested_capacity); } - return OkStatus(); + return absl::OkStatus(); } Status QueueBase::MatchesNodeDefTypes(const NodeDef& node_def) const { @@ -123,7 +123,7 @@ Status QueueBase::MatchesNodeDefTypes(const NodeDef& node_def) const { " but requested component types were ", DataTypeSliceString(requested_dtypes)); } - return OkStatus(); + return absl::OkStatus(); } Status QueueBase::MatchesNodeDefShapes(const NodeDef& node_def) const { @@ -136,7 +136,7 @@ Status QueueBase::MatchesNodeDefShapes(const NodeDef& node_def) const { " but requested component shapes were ", ShapeListString(requested_shapes)); } - return OkStatus(); + return absl::OkStatus(); } // TODO(mrry): If these checks become a bottleneck, find a way to @@ -153,7 +153,7 @@ Status QueueBase::ValidateTuple(const Tuple& tuple) { } } } - return OkStatus(); + return absl::OkStatus(); } // TODO(mrry): If these checks become a bottleneck, find a way to @@ -182,7 +182,7 @@ Status QueueBase::ValidateManyTuple(const Tuple& tuple) { } } } - return OkStatus(); + return absl::OkStatus(); } void QueueBase::Cancel(Action action, CancellationManager* cancellation_manager, diff --git a/tensorflow/core/kernels/ragged_cross_op.cc b/tensorflow/core/kernels/ragged_cross_op.cc index c8f27051b449cc..c6afecc8b80f0d 100644 --- a/tensorflow/core/kernels/ragged_cross_op.cc +++ b/tensorflow/core/kernels/ragged_cross_op.cc @@ -439,7 +439,7 @@ class RaggedCrossOp : public OpKernel { } } - return OkStatus(); + return absl::OkStatus(); } // Calculate the batch size from any input tensor. (We check that all input @@ -518,7 +518,7 @@ class RaggedCrossOp : public OpKernel { } } - return OkStatus(); + return absl::OkStatus(); } // Builds a RaggedReatureReader @@ -552,7 +552,7 @@ class RaggedCrossOp : public OpKernel { new RaggedFeatureReader(values, splits)); } } - return OkStatus(); + return absl::OkStatus(); } // Builds a DenseFaggedReatureReader. @@ -567,7 +567,7 @@ class RaggedCrossOp : public OpKernel { (features->size() + 1), ": ", values.dtype()); } - return OkStatus(); + return absl::OkStatus(); } // Builds a SparseFaggedReatureReader. @@ -586,7 +586,7 @@ class RaggedCrossOp : public OpKernel { (features->size() + 1), ": ", values.dtype()); } - return OkStatus(); + return absl::OkStatus(); } // Allocates output tensors with proper size, and populates row_splits_out. @@ -612,7 +612,7 @@ class RaggedCrossOp : public OpKernel { TF_RETURN_IF_ERROR(context->allocate_output( 0, TensorShape({cross_count_total}), values_out)); - return OkStatus(); + return absl::OkStatus(); } // Returns number of crosses for a given batch_index diff --git a/tensorflow/core/kernels/ragged_gather_op.cc b/tensorflow/core/kernels/ragged_gather_op.cc index 7d5495e10f71c3..0252be8de06803 100644 --- a/tensorflow/core/kernels/ragged_gather_op.cc +++ b/tensorflow/core/kernels/ragged_gather_op.cc @@ -111,7 +111,7 @@ class RaggedGatherOpBase : public OpKernel { " is not in [0, ", num_params, ")"); } } - return OkStatus(); + return absl::OkStatus(); } // Construct the `splits` output tensors, encoded using a nested vector. @@ -188,7 +188,7 @@ class RaggedGatherOpBase : public OpKernel { *num_values += limit - start; } } - return OkStatus(); + return absl::OkStatus(); } ::tensorflow::Status ValidateSplits( @@ -216,7 +216,7 @@ class RaggedGatherOpBase : public OpKernel { } } } - return OkStatus(); + return absl::OkStatus(); } ::tensorflow::Status WriteSplits( @@ -234,7 +234,7 @@ class RaggedGatherOpBase : public OpKernel { std::copy_n(out_splits[i].data(), out_splits[i].size(), splits_flat.data()); } - return OkStatus(); + return absl::OkStatus(); } ::tensorflow::Status WriteValues( @@ -253,7 +253,7 @@ class RaggedGatherOpBase : public OpKernel { : (num_elements / params_dense_values_in.dim_size(0)); CallWriteValueSlices(params_dense_values_in, value_slices, value_size, values_out); - return OkStatus(); + return absl::OkStatus(); } protected: diff --git a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc index ea137114b55248..2b3052394766a0 100644 --- a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc +++ b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc @@ -77,7 +77,7 @@ Status RaggedComponentsFromVariant( } } } - return OkStatus(); + return absl::OkStatus(); } /* Takes a set of RaggedTensorVariants for non-ragged tensors, stacks @@ -96,7 +96,7 @@ Status StackNonRaggedTensors( RaggedTensorVariant* output_ragged) { if (ragged_components.empty()) { output_ragged->set_values(Tensor(DataTypeToEnum::value, {0})); - return OkStatus(); + return absl::OkStatus(); } TensorShape component_values_shape = ragged_components[0].values().shape(); @@ -120,7 +120,7 @@ Status StackNonRaggedTensors( output_values_flat(values_index++) = component_values_flat(j); } } - return OkStatus(); + return absl::OkStatus(); } template @@ -261,7 +261,7 @@ Status NestedStackRaggedTensors( } } } - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/core/kernels/ragged_tensor_to_sparse_kernel.cc b/tensorflow/core/kernels/ragged_tensor_to_sparse_kernel.cc index c5e3e2e7896a14..41667581eab3d1 100644 --- a/tensorflow/core/kernels/ragged_tensor_to_sparse_kernel.cc +++ b/tensorflow/core/kernels/ragged_tensor_to_sparse_kernel.cc @@ -188,7 +188,7 @@ class RaggedTensorToSparseOp : public OpKernel { "Final value of ragged splits must match the length " "the corresponding ragged values."); } - return OkStatus(); + return absl::OkStatus(); } // Build a list of index suffixes that should be added for each ragged item, diff --git a/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc b/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc index 3985680337a6e3..a3dbeb9aac9f84 100644 --- a/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc +++ b/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc @@ -90,10 +90,10 @@ class RaggedTensorToTensorBaseOp : public OpKernel { switch (GetRowPartitionTypeByDimension(dimension - 1)) { case RowPartitionType::VALUE_ROWIDS: *result = GetMaxWidthValueRowID(row_partition_tensor); - return OkStatus(); + return absl::OkStatus(); case RowPartitionType::ROW_SPLITS: *result = GetMaxWidthRowSplit(row_partition_tensor); - return OkStatus(); + return absl::OkStatus(); default: return errors::InvalidArgument( "Cannot handle partition type ", @@ -176,7 +176,7 @@ class RaggedTensorToTensorBaseOp : public OpKernel { TF_RETURN_IF_ERROR(GetMaxWidth(c, i, &(*result)[i])); } } - return OkStatus(); + return absl::OkStatus(); } /** @@ -236,7 +236,7 @@ class RaggedTensorToTensorBaseOp : public OpKernel { return errors::InvalidArgument("Invalid row split size."); } - return OkStatus(); + return absl::OkStatus(); } // Calculate the output index of the first element of a list. @@ -268,7 +268,7 @@ class RaggedTensorToTensorBaseOp : public OpKernel { const INDEX_TYPE index_size = value_rowids.size(); result->reserve(index_size); if (index_size == 0) { - return OkStatus(); + return absl::OkStatus(); } INDEX_TYPE current_output_column = 0; @@ -312,7 +312,7 @@ class RaggedTensorToTensorBaseOp : public OpKernel { return errors::InvalidArgument("Invalid row ids."); } - return OkStatus(); + return absl::OkStatus(); } Status CalculateOutputIndex(OpKernelContext* context, int dimension, @@ -355,13 +355,13 @@ class RaggedTensorToTensorBaseOp : public OpKernel { switch (first_partition_type) { case RowPartitionType::FIRST_DIM_SIZE: *result = first_partition_tensor.scalar()(); - return OkStatus(); + return absl::OkStatus(); case RowPartitionType::VALUE_ROWIDS: return errors::InvalidArgument( "Cannot handle VALUE_ROWIDS in first dimension."); case RowPartitionType::ROW_SPLITS: *result = first_partition_tensor.shape().dim_size(0) - 1; - return OkStatus(); + return absl::OkStatus(); default: return errors::InvalidArgument( "Cannot handle type ", diff --git a/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc index 153fd5a98fea1e..04237d8ecb7f99 100644 --- a/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc +++ b/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc @@ -61,7 +61,7 @@ Status UnbatchDenseZerothDim( } } - return OkStatus(); + return absl::OkStatus(); } template @@ -110,7 +110,7 @@ Status UnbatchRaggedZerothDim( batched_flat(j + start * num_inner_elems); } } - return OkStatus(); + return absl::OkStatus(); } // Unbatch nested splits. @@ -166,7 +166,7 @@ Status UnbatchRaggedZerothDim( } } - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/core/kernels/ragged_tensor_variant.cc b/tensorflow/core/kernels/ragged_tensor_variant.cc index c418961f64acc6..3ef3294ee957ba 100644 --- a/tensorflow/core/kernels/ragged_tensor_variant.cc +++ b/tensorflow/core/kernels/ragged_tensor_variant.cc @@ -58,7 +58,7 @@ Status RaggedTensorVariantDeviceCopy( TF_RETURN_IF_ERROR(copy(from.values(), to->mutable_values())); // TODO(b/170415165) Should we use `copy` to move splits from device<->host? *to->mutable_nested_splits() = from.nested_splits(); - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/core/kernels/ragged_tensor_variant.h b/tensorflow/core/kernels/ragged_tensor_variant.h index 4bef66a9e745b3..db35b8bcf0d35b 100644 --- a/tensorflow/core/kernels/ragged_tensor_variant.h +++ b/tensorflow/core/kernels/ragged_tensor_variant.h @@ -74,7 +74,7 @@ Status RaggedTensorVariantZerosLike(OpKernelContext* c, y->set_nested_splits(x.nested_splits()); TF_RETURN_IF_ERROR( ZerosLikeTensor(c, x.values(), y->mutable_values())); - return OkStatus(); + return absl::OkStatus(); } template @@ -102,7 +102,7 @@ Status RaggedTensorVariantBinaryAdd(OpKernelContext* c, out->set_nested_splits(x.nested_splits()); TF_RETURN_IF_ERROR(BinaryAddTensors(c, x.values(), y.values(), out->mutable_values())); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/kernels/random_index_shuffle_ops.cc b/tensorflow/core/kernels/random_index_shuffle_ops.cc index 6233f982381a81..d1f275b5fe4413 100644 --- a/tensorflow/core/kernels/random_index_shuffle_ops.cc +++ b/tensorflow/core/kernels/random_index_shuffle_ops.cc @@ -60,7 +60,7 @@ Status GetSeed(const Tensor& seed_t, const int row, return errors::InvalidArgument("Invalid seed type: ", DataTypeString(seed_t.dtype())); } - return OkStatus(); + return absl::OkStatus(); } template diff --git a/tensorflow/core/kernels/random_shuffle_queue_op.cc b/tensorflow/core/kernels/random_shuffle_queue_op.cc index 1f0352a153d7d5..bed97066d14d03 100644 --- a/tensorflow/core/kernels/random_shuffle_queue_op.cc +++ b/tensorflow/core/kernels/random_shuffle_queue_op.cc @@ -113,7 +113,7 @@ Status RandomShuffleQueue::Initialize() { for (int i = 0; i < num_components(); ++i) { queues_[i].reserve(min_after_dequeue_); } - return OkStatus(); + return absl::OkStatus(); } void RandomShuffleQueue::DequeueLocked(OpKernelContext* ctx, Tuple* tuple) { @@ -176,7 +176,7 @@ Status RandomShuffleQueue::GetElementComponentFromBatch(const Tuple& tuple, ctx->allocate_temp(tuple[component].dtype(), element_shape, out_tensor)); TF_RETURN_IF_ERROR( batch_util::CopySliceToElement(tuple[component], out_tensor, index)); - return OkStatus(); + return absl::OkStatus(); } void RandomShuffleQueue::TryEnqueueMany(const Tuple& tuple, @@ -469,7 +469,7 @@ Status RandomShuffleQueue::MatchesNodeDef(const NodeDef& node_def) { TF_RETURN_IF_ERROR(MatchesNodeDefTypes(node_def)); TF_RETURN_IF_ERROR(MatchesNodeDefShapes(node_def)); - return OkStatus(); + return absl::OkStatus(); } // Defines a RandomShuffleQueueOp, which produces a Queue (specifically, one diff --git a/tensorflow/core/kernels/range_sampler.cc b/tensorflow/core/kernels/range_sampler.cc index 79e5889454be5f..eae756b89896e5 100644 --- a/tensorflow/core/kernels/range_sampler.cc +++ b/tensorflow/core/kernels/range_sampler.cc @@ -252,7 +252,7 @@ Status FixedUnigramSampler::SetDistributionSampler(Env* env, " must be equal to weights size ", weights_.size())); dist_sampler_.reset(new random::DistributionSampler(weights_)); - return OkStatus(); + return absl::OkStatus(); } Status FixedUnigramSampler::SetDistributionSampler( @@ -263,7 +263,7 @@ Status FixedUnigramSampler::SetDistributionSampler( " must be equal to weights size ", weights_.size())); dist_sampler_.reset(new random::DistributionSampler(weights_)); - return OkStatus(); + return absl::OkStatus(); } float FixedUnigramSampler::Probability(int64_t value) const { @@ -309,7 +309,7 @@ Status FixedUnigramSampler::LoadFromFile(Env* env, const string& vocab_file, } ++word_id; } - return OkStatus(); + return absl::OkStatus(); } void FixedUnigramSampler::LoadFromUnigrams(const std::vector& unigrams, diff --git a/tensorflow/core/kernels/record_yielder.cc b/tensorflow/core/kernels/record_yielder.cc index 593d11cb2efd91..6bba98ab2c1ae9 100644 --- a/tensorflow/core/kernels/record_yielder.cc +++ b/tensorflow/core/kernels/record_yielder.cc @@ -91,7 +91,7 @@ static Status MatchFiles(const string& patterns, std::make_move_iterator(tmp_filenames.begin()), std::make_move_iterator(tmp_filenames.end())); } - return OkStatus(); + return absl::OkStatus(); } void RecordYielder::MainLoop() { @@ -200,7 +200,7 @@ void RecordYielder::ShardLoop(Shard* shard) { const int64_t kRecords = 16; for (const string& filename : shard->filenames) { std::unique_ptr file; - if (ShouldFinish(OkStatus())) break; + if (ShouldFinish(absl::OkStatus())) break; Status s = Env::Default()->NewRandomAccessFile(filename, &file); if (!s.ok()) { shard->status = errors::InvalidArgument("Can't open ", filename); diff --git a/tensorflow/core/kernels/reduction_ops_common.cc b/tensorflow/core/kernels/reduction_ops_common.cc index 943e75e93ffef2..d66ad0954f710b 100644 --- a/tensorflow/core/kernels/reduction_ops_common.cc +++ b/tensorflow/core/kernels/reduction_ops_common.cc @@ -76,7 +76,7 @@ Status SimplifyHelper(const Tensor& data, const Tensor& axis, } bitmap[index] = true; } - return OkStatus(); + return absl::OkStatus(); } Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis, @@ -154,7 +154,7 @@ Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis, VLOG(1) << "data reshape: " << absl::StrJoin(data_reshape_, ","); VLOG(1) << "out reshape: " << absl::StrJoin(out_reshape_, ","); VLOG(1) << "out shape: " << absl::StrJoin(out_shape_, ","); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/kernels/regex_replace_op.cc b/tensorflow/core/kernels/regex_replace_op.cc index 7b615aaf401fef..a3e5041d189f83 100644 --- a/tensorflow/core/kernels/regex_replace_op.cc +++ b/tensorflow/core/kernels/regex_replace_op.cc @@ -59,7 +59,7 @@ Status InternalCompute(const RE2& regex, const string& rewrite, } output_flat(i) = std::move(buf); } - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/core/kernels/reshape_op.h b/tensorflow/core/kernels/reshape_op.h index 099678c94324b8..e7459d91aef01d 100644 --- a/tensorflow/core/kernels/reshape_op.h +++ b/tensorflow/core/kernels/reshape_op.h @@ -159,7 +159,7 @@ class ReshapeOp : public OpKernel { (*product) *= size; } } - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index df0b87f4349f8b..d6d6cf228a9175 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -138,7 +138,7 @@ Status CopyVariable(int output_idx, OpKernelContext* ctx, const Tensor* t) { return errors::Internal("Unsupported dtype", t->dtype()); } } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -432,7 +432,7 @@ class AssignVariableOp : public OpKernel { *ptr = new Var(dtype_); *(*ptr)->tensor() = value; (*ptr)->is_initialized = true; - return OkStatus(); + return absl::OkStatus(); })); mutex_lock ml(*variable->mu()); // (variable->tensor()->dtype() == DT_INVALID && !variable->is_initialized) @@ -1096,7 +1096,7 @@ Status DoScatter(OpKernelContext* c, Tensor* params, const Tensor& indices, } } } - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/core/kernels/resource_variable_util.cc b/tensorflow/core/kernels/resource_variable_util.cc index f6c0bac09fff9e..d66ac4e3e1d47f 100644 --- a/tensorflow/core/kernels/resource_variable_util.cc +++ b/tensorflow/core/kernels/resource_variable_util.cc @@ -27,7 +27,7 @@ Status ValidateAssignUpdateVariableOpShapes(const TensorShape& variable_shape, " using a Tensor with shape ", value_shape.DebugString(), ", shapes must be equal."); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/kernels/rnn/blas_gemm.cc b/tensorflow/core/kernels/rnn/blas_gemm.cc index 96c9cde2a74f1d..a79a5506e34e68 100644 --- a/tensorflow/core/kernels/rnn/blas_gemm.cc +++ b/tensorflow/core/kernels/rnn/blas_gemm.cc @@ -49,12 +49,15 @@ void TensorCuBlasGemm::operator()(OpKernelContext* ctx, bool transa, auto a_ptr = AsDeviceMemory(a); auto b_ptr = AsDeviceMemory(b); auto c_ptr = AsDeviceMemory(c); + auto* stream = ctx->op_device_context()->stream(); + auto* blas = stream->parent()->AsBlas(); + OP_REQUIRES(ctx, blas != nullptr, absl::InternalError("No BLAS for stream.")); OP_REQUIRES_OK( - ctx, ctx->op_device_context()->stream()->ThenBlasGemm( - trans[transa], trans[transb], m, n, k, static_cast(alpha), - a_ptr, lda, b_ptr, ldb, static_cast(beta), &c_ptr, ldc, - GetNumericOptions(), se::blas::CallContext::kNone)); + ctx, blas->BlasGemm(stream, trans[transa], trans[transb], m, n, k, + static_cast(alpha), a_ptr, lda, b_ptr, ldb, + static_cast(beta), &c_ptr, ldc, + GetNumericOptions(), se::blas::CallContext::kNone)); #else ctx->SetStatus(errors::InvalidArgument("CuBlasGemm needs CUDA.")); #endif diff --git a/tensorflow/core/kernels/save_restore_tensor.cc b/tensorflow/core/kernels/save_restore_tensor.cc index fa292292a34cef..7b433ab8b5b9a0 100644 --- a/tensorflow/core/kernels/save_restore_tensor.cc +++ b/tensorflow/core/kernels/save_restore_tensor.cc @@ -346,7 +346,7 @@ struct RestoreOp { } VLOG(1) << "Done restoring tensor " << idx << " : " << tensor_name << " : " << restored_full_shape.num_elements(); - return OkStatus(); + return absl::OkStatus(); } OpKernelContext* context; @@ -444,7 +444,7 @@ Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix, } } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/kernels/save_restore_v2_ops.cc b/tensorflow/core/kernels/save_restore_v2_ops.cc index 617cf60d37b89d..eb329b9b90b616 100644 --- a/tensorflow/core/kernels/save_restore_v2_ops.cc +++ b/tensorflow/core/kernels/save_restore_v2_ops.cc @@ -176,7 +176,7 @@ class SaveV2 : public OpKernel { &checkpoint_callback_manager, [](checkpoint::CheckpointCallbackManager** out) { *out = new checkpoint::CheckpointCallbackManager(); - return OkStatus(); + return absl::OkStatus(); })); checkpoint_callback_manager->Save(prefix_string); checkpoint_callback_manager->Unref(); @@ -245,7 +245,7 @@ class RestoreV2 : public OpKernel { &checkpoint_callback_manager, [](checkpoint::CheckpointCallbackManager** out) { *out = new checkpoint::CheckpointCallbackManager(); - return OkStatus(); + return absl::OkStatus(); })); checkpoint_callback_manager->Restore(prefix_string); checkpoint_callback_manager->Unref(); diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc index 9effcca70f0ef8..1cb4837133d95d 100644 --- a/tensorflow/core/kernels/scatter_nd_op.cc +++ b/tensorflow/core/kernels/scatter_nd_op.cc @@ -863,7 +863,7 @@ Status PrepareAndValidateInputs(const TensorShape& params_shape, const int64_t safe_slice_dim = (*slice_dim < 1) ? 1 : *slice_dim; *num_updates = indices_shape.num_elements() / safe_slice_dim; - return OkStatus(); + return absl::OkStatus(); } template @@ -904,7 +904,7 @@ Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices, } if (shape.num_elements() == 0) { - return OkStatus(); + return absl::OkStatus(); } if (allocate) { @@ -956,7 +956,7 @@ Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices, gtl::ArraySlice(&indices_flat(bad_i, 0), slice_dim), ", "), "] does not index into shape ", shape.DebugString()); } - return OkStatus(); + return absl::OkStatus(); } template diff --git a/tensorflow/core/kernels/scatter_nd_util.cc b/tensorflow/core/kernels/scatter_nd_util.cc index 403c5361225880..4793e4ce99761c 100644 --- a/tensorflow/core/kernels/scatter_nd_util.cc +++ b/tensorflow/core/kernels/scatter_nd_util.cc @@ -61,7 +61,7 @@ Status ValidateScatterNdUpdateShape(const TensorShape& params_shape, return shape_err_suffix(); } } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/kernels/sdca_internal.cc b/tensorflow/core/kernels/sdca_internal.cc index a3b3db6bdc9121..bf669e0a13e7c5 100644 --- a/tensorflow/core/kernels/sdca_internal.cc +++ b/tensorflow/core/kernels/sdca_internal.cc @@ -156,7 +156,7 @@ Status ModelWeights::Initialize(OpKernelContext* const context) { {1, weight_inputs[i].NumElements()}), deltas}); } - return OkStatus(); + return absl::OkStatus(); }; return initialize_weights(dense_weights_inputs, &dense_weights_outputs, @@ -319,7 +319,7 @@ Status Examples::SampleAdaptiveProbabilities( for (int i = id; i < num_examples(); ++i) { sampled_count_[i] = examples_not_seen[i - id].first; } - return OkStatus(); + return absl::OkStatus(); } void Examples::RandomShuffle() { @@ -421,7 +421,7 @@ Status Examples::Initialize(OpKernelContext* const context, TF_RETURN_IF_ERROR(ComputeSquaredNormPerExample( worker_threads, num_examples, num_sparse_features, num_dense_features, &examples_)); - return OkStatus(); + return absl::OkStatus(); } Status Examples::CreateSparseFeatureRepresentation( diff --git a/tensorflow/core/kernels/sdca_internal.h b/tensorflow/core/kernels/sdca_internal.h index 5f5f8bacd15c1e..5bf6325f4d2fae 100644 --- a/tensorflow/core/kernels/sdca_internal.h +++ b/tensorflow/core/kernels/sdca_internal.h @@ -80,7 +80,7 @@ class Regularizations { TF_RETURN_IF_ERROR(context->GetAttr("l1", &symmetric_l1_)); TF_RETURN_IF_ERROR(context->GetAttr("l2", &symmetric_l2_)); shrinkage_ = symmetric_l1_ / symmetric_l2_; - return OkStatus(); + return absl::OkStatus(); } // Proximal SDCA shrinking for L1 regularization. diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl.h b/tensorflow/core/kernels/segment_reduction_ops_impl.h index f2fe9f528607db..b3364da3af53be 100644 --- a/tensorflow/core/kernels/segment_reduction_ops_impl.h +++ b/tensorflow/core/kernels/segment_reduction_ops_impl.h @@ -1355,7 +1355,7 @@ class SparseSegmentGradV2OpCommon { Tensor* sorted_unique_indices = nullptr; TF_RETURN_IF_ERROR(context->allocate_output(1, TensorShape({0}), &sorted_unique_indices)); - return OkStatus(); + return absl::OkStatus(); } auto input_flat = input.flat_outer_dims(); @@ -1366,7 +1366,7 @@ class SparseSegmentGradV2OpCommon { context, operation, input_flat, indices_vec, segment_vec, dense_output_shape, done); - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc b/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc index 4b0e55de7bb5bb..24d5f94b60f987 100644 --- a/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc +++ b/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc @@ -34,7 +34,7 @@ Status ValidateSegmentReduction(OpKernelContext* context, const Tensor& input, " input."); } - return OkStatus(); + return absl::OkStatus(); } // check routines not in the templated class to reduce code size @@ -55,7 +55,7 @@ Status ValidateUnsortedSegmentReduction(OpKernel* op_kernel, segment_ids.shape().DebugString()); } - return OkStatus(); + return absl::OkStatus(); } Status ValidateSparseSegmentReduction(OpKernelContext* context, @@ -97,7 +97,7 @@ Status ValidateSparseSegmentReduction(OpKernelContext* context, return errors::InvalidArgument("Shape must be at least rank 1"); } - return OkStatus(); + return absl::OkStatus(); } } // namespace internal diff --git a/tensorflow/core/kernels/sendrecv_ops_test.cc b/tensorflow/core/kernels/sendrecv_ops_test.cc index 75c02e6308e7ea..10f23e418d9a57 100644 --- a/tensorflow/core/kernels/sendrecv_ops_test.cc +++ b/tensorflow/core/kernels/sendrecv_ops_test.cc @@ -29,12 +29,12 @@ namespace { class DummyRendezvous : public Rendezvous { Status Send(const ParsedKey& key, const Args& args, const Tensor& val, const bool is_dead) override { - return OkStatus(); + return absl::OkStatus(); } void RecvAsync(const ParsedKey& key, const Args& args, DoneCallback done) override { static Tensor* t = new Tensor(DT_FLOAT, TensorShape({0})); - done(OkStatus(), args, args, *t, false); + done(absl::OkStatus(), args, args, *t, false); } void StartAbort(const Status& status) override {} }; diff --git a/tensorflow/core/kernels/set_kernels.cc b/tensorflow/core/kernels/set_kernels.cc index 2ae7e276fd0eaf..8ba24346ab1cb5 100644 --- a/tensorflow/core/kernels/set_kernels.cc +++ b/tensorflow/core/kernels/set_kernels.cc @@ -60,7 +60,7 @@ Status GroupShape(const VarDimArray& input_shape, ShapeArray* grouped_shape) { } // grouped_shape is input_shape[:-1] *grouped_shape = ShapeArray(input_shape.begin(), input_shape.end() - 1); - return OkStatus(); + return absl::OkStatus(); } // Build `SparseTensor` from indices, values, and shape in inputs @@ -425,7 +425,7 @@ Status CheckShapesMatch(VarDimArray shape1, VarDimArray shape2) { absl::StrJoin(shape1, ","), "] vs [", absl::StrJoin(shape2, ","), "]"); } - return OkStatus(); + return absl::OkStatus(); } // Validate ranks are the same, and all but last dimension are the same. @@ -438,7 +438,7 @@ Status GroupShapeFromInputs(VarDimArray shape1, VarDimArray shape2, TF_RETURN_IF_ERROR(GroupShape(shape2, &group_shape_2)); TF_RETURN_IF_ERROR(CheckShapesMatch(group_shape_1, group_shape_2)); *group_shape = group_shape_1; - return OkStatus(); + return absl::OkStatus(); } // Split `flat_group_index` into separate dimensions based on `group_shape`. diff --git a/tensorflow/core/kernels/shape_ops.h b/tensorflow/core/kernels/shape_ops.h index 296031eceb53d6..dcddddc5e38686 100644 --- a/tensorflow/core/kernels/shape_ops.h +++ b/tensorflow/core/kernels/shape_ops.h @@ -34,7 +34,7 @@ namespace shape_op_helpers { inline Status GetShape(OpKernelContext* ctx, int input_index, TensorShape* shape) { *shape = ctx->input(input_index).shape(); - return OkStatus(); + return absl::OkStatus(); } } // namespace shape_op_helpers diff --git a/tensorflow/core/kernels/shuffle_common.h b/tensorflow/core/kernels/shuffle_common.h index feb62accc02e53..0fd93bdfca9573 100644 --- a/tensorflow/core/kernels/shuffle_common.h +++ b/tensorflow/core/kernels/shuffle_common.h @@ -94,7 +94,7 @@ Status RandomShuffle(OpKernelContext* context, const Tensor& input, } } } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/kernels/smooth-hinge-loss.h b/tensorflow/core/kernels/smooth-hinge-loss.h index beb56fbf01882b..f1019b7c53cb7c 100644 --- a/tensorflow/core/kernels/smooth-hinge-loss.h +++ b/tensorflow/core/kernels/smooth-hinge-loss.h @@ -78,10 +78,10 @@ class SmoothHingeLossUpdater : public DualLossUpdater { Status ConvertLabel(float* const example_label) const final { if (*example_label == 0.0) { *example_label = -1; - return OkStatus(); + return absl::OkStatus(); } if (*example_label == 1.0) { - return OkStatus(); + return absl::OkStatus(); } return errors::InvalidArgument( "Only labels of 0.0 or 1.0 are supported right now. " diff --git a/tensorflow/core/kernels/spacetobatch_op.cc b/tensorflow/core/kernels/spacetobatch_op.cc index 079c3c32e3d5b3..210ff0c93023a4 100644 --- a/tensorflow/core/kernels/spacetobatch_op.cc +++ b/tensorflow/core/kernels/spacetobatch_op.cc @@ -123,7 +123,7 @@ Status SpaceToBatchOpCompute(OpKernelContext* context, if (internal_block_dims == 0) { context->set_output(0, orig_input_tensor); - return OkStatus(); + return absl::OkStatus(); } // For the purpose of computing the result, the input will be treated as @@ -212,7 +212,7 @@ Status SpaceToBatchOpCompute(OpKernelContext* context, TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS(TF_SPACETOBATCH_BLOCK_DIMS_CASE) #undef TF_SPACETOBATCH_BLOCK_DIMS_CASE } - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/core/kernels/sparse/BUILD b/tensorflow/core/kernels/sparse/BUILD index d99e8b6f79b3b3..103401b95dc348 100644 --- a/tensorflow/core/kernels/sparse/BUILD +++ b/tensorflow/core/kernels/sparse/BUILD @@ -53,7 +53,6 @@ tf_kernel_library( ], hdrs = [ "kernels.h", - "mat_mul_op.h", "transpose_op.h", "zeros_op.h", ], diff --git a/tensorflow/core/kernels/sparse/add_op.cc b/tensorflow/core/kernels/sparse/add_op.cc index 6dd92e11320043..e27de2a1782b91 100644 --- a/tensorflow/core/kernels/sparse/add_op.cc +++ b/tensorflow/core/kernels/sparse/add_op.cc @@ -175,7 +175,7 @@ class CSRSparseMatrixAddFunctor { TF_RETURN_IF_ERROR(csr_geam.Compute(a_comp, b_comp, &c_comp, workspace)); } - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/sparse/conj_op.cc b/tensorflow/core/kernels/sparse/conj_op.cc index 36ed86c587b1bc..0436ea0a85f889 100644 --- a/tensorflow/core/kernels/sparse/conj_op.cc +++ b/tensorflow/core/kernels/sparse/conj_op.cc @@ -60,7 +60,7 @@ class CSRSparseMatrixConjFunctor { functor::UnaryFunctor> func; func(d, b->values().flat() /*out*/, a.values().flat() /*in*/); - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/sparse/csr_sparse_matrix_to_sparse_tensor_op.cc b/tensorflow/core/kernels/sparse/csr_sparse_matrix_to_sparse_tensor_op.cc index 32d452eadd1f8c..ff3543f4602c77 100644 --- a/tensorflow/core/kernels/sparse/csr_sparse_matrix_to_sparse_tensor_op.cc +++ b/tensorflow/core/kernels/sparse/csr_sparse_matrix_to_sparse_tensor_op.cc @@ -55,7 +55,7 @@ Status ValidateCSRSparseMatrix(const CSRSparseMatrix& csr_sparse_matrix, return errors::InvalidArgument("CSR SparseMatrix must have rank 2 or 3; ", "but dense_shape has size ", rank); } - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/core/kernels/sparse/kernels.cc b/tensorflow/core/kernels/sparse/kernels.cc index 63651473ab9c96..bbb96743d0be21 100644 --- a/tensorflow/core/kernels/sparse/kernels.cc +++ b/tensorflow/core/kernels/sparse/kernels.cc @@ -140,7 +140,7 @@ Status SparseTensorToCSRSparseMatrixCPUFunctor::operator()( std::partial_sum(row_ptr_batch, row_ptr_batch + num_rows + 1, row_ptr_batch); } - return OkStatus(); + return absl::OkStatus(); } } // namespace functor diff --git a/tensorflow/core/kernels/sparse/mat_mul_op.cc b/tensorflow/core/kernels/sparse/mat_mul_op.cc index 7422bdb5fcb40c..ed5fd044f825f8 100644 --- a/tensorflow/core/kernels/sparse/mat_mul_op.cc +++ b/tensorflow/core/kernels/sparse/mat_mul_op.cc @@ -19,8 +19,9 @@ limitations under the License. #define EIGEN_USE_GPU #endif -#include "Eigen/Core" // from @eigen_archive +#include "Eigen/Core" // from @eigen_archive #include "Eigen/SparseCore" // from @eigen_archive +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_types.h" @@ -35,12 +36,10 @@ limitations under the License. #include "tensorflow/core/kernels/transpose_functor.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/platform/threadpool.h" -#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/util/cuda_sparse.h" #include "tensorflow/core/util/gpu_solvers.h" -#include "third_party/gpus/cuda/include/cuda.h" #endif namespace tensorflow { @@ -86,12 +85,12 @@ class CSRMatMulOp : public OpKernel { bool adjoint_a; OP_REQUIRES_OK(c, c->GetAttr("adjoint_a", &adjoint_a)); OP_REQUIRES(c, !(adjoint_a && transpose_a_), - absl::InvalidArgumentError( + errors::InvalidArgument( "Only one of adjoint_a and transpose_a may be true.")); bool adjoint_b; OP_REQUIRES_OK(c, c->GetAttr("adjoint_b", &adjoint_b)); OP_REQUIRES(c, !(adjoint_b && transpose_b_), - absl::InvalidArgumentError( + errors::InvalidArgument( "Only one of adjoint_b and transpose_b may be true.")); OP_REQUIRES_OK(c, c->GetAttr("transpose_output", &transpose_output_)); OP_REQUIRES_OK(c, c->GetAttr("conjugate_output", &conjugate_output_)); @@ -112,24 +111,23 @@ class CSRMatMulOp : public OpKernel { const Tensor& dense_tensor_b, int* rank, int64_t* batch_size) { if (sparse_matrix_a.dtype() != dense_tensor_b.dtype()) { - return absl::InvalidArgumentError(absl::StrCat( + return errors::InvalidArgument( "Input types don't match. a.dtype == ", DataTypeString(sparse_matrix_a.dtype()), - " vs. b.dtype == ", DataTypeString(dense_tensor_b.dtype()))); + " vs. b.dtype == ", DataTypeString(dense_tensor_b.dtype())); } *rank = sparse_matrix_a.dims(); // TODO(ebrevdo): Add support for broadcasting matmul. if (*rank != dense_tensor_b.dims()) { - return absl::InvalidArgumentError( - absl::StrCat("Ranks of a and b must match, saw: ", *rank, " vs. ", - dense_tensor_b.dims(), ".")); + return errors::InvalidArgument("Ranks of a and b must match, saw: ", rank, + " vs. ", dense_tensor_b.dims(), "."); } // A valid CSR SparseMatrix has rank 2 or rank 3. *batch_size = (*rank == 2) ? 1 : dense_tensor_b.dim_size(0); if (sparse_matrix_a.batch_size() != *batch_size) { - return absl::InvalidArgumentError(absl::StrCat( - "Batch sizes of a and b must match, saw: ", - sparse_matrix_a.batch_size(), " vs. ", *batch_size, ".")); + return errors::InvalidArgument("Batch sizes of a and b must match, saw: ", + sparse_matrix_a.batch_size(), " vs. ", + batch_size, "."); } const auto& a_dense_shape = sparse_matrix_a.dense_shape().vec(); const int64_t a_inner_dim = @@ -137,12 +135,12 @@ class CSRMatMulOp : public OpKernel { const int64_t b_inner_dim = dense_tensor_b.dim_size(this->transpose_b_ ? *rank - 1 : *rank - 2); if (a_inner_dim != b_inner_dim) { - return absl::InvalidArgumentError( - absl::StrCat("Inner product dimensions of A and B do not agree. ", - "Shapes are: ", TensorShape(a_dense_shape).DebugString(), - " vs. ", dense_tensor_b.shape().DebugString())); + return errors::InvalidArgument( + "Inner product dimensions of A and B do not agree. Shapes are: ", + TensorShape(a_dense_shape), " vs. ", + dense_tensor_b.shape().DebugString()); } - return OkStatus(); + return absl::OkStatus(); } public: @@ -265,7 +263,7 @@ class CSRMatMulCPUOp : public CSRMatMulOp { TF_RETURN_IF_ERROR(ctx->allocate_output(0, output_shape, output)); *matmul_result = output_transposed; } - return OkStatus(); + return absl::OkStatus(); } // Returns an Eigen::Ref expression of a sparse sub-matrix from the given @@ -490,7 +488,7 @@ class CSRMatMulCPUOp : public CSRMatMulOp { TF_RETURN_IF_ERROR( DoMatrixTranspose(ctx->eigen_device(), input, output)); } - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/core/kernels/sparse/mat_mul_op.h b/tensorflow/core/kernels/sparse/mat_mul_op.h deleted file mode 100644 index a59c136a44c079..00000000000000 --- a/tensorflow/core/kernels/sparse/mat_mul_op.h +++ /dev/null @@ -1,166 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_MAT_MUL_OP_H_ -#define TENSORFLOW_CORE_KERNELS_SPARSE_MAT_MUL_OP_H_ - -#define EIGEN_USE_THREADS - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#define EIGEN_USE_GPU -#endif - -#include "Eigen/Core" // from @eigen_archive -#include "Eigen/SparseCore" // from @eigen_archive -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/framework/type_traits.h" -#include "tensorflow/core/framework/variant_op_registry.h" -#include "tensorflow/core/kernels/cwise_ops_common.h" -#include "tensorflow/core/kernels/dense_update_functor.h" -#include "tensorflow/core/kernels/fill_functor.h" -#include "tensorflow/core/kernels/sparse/kernels.h" -#include "tensorflow/core/kernels/sparse/sparse_matrix.h" -#include "tensorflow/core/kernels/sparse/transpose_op.h" -#include "tensorflow/core/kernels/transpose_functor.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" -#include "tensorflow/core/platform/threadpool.h" -#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "tensorflow/core/util/cuda_sparse.h" -#include "tensorflow/core/util/gpu_solvers.h" -#endif - -#include "tensorflow/core/kernels/sparse/mat_mul_op.h" - -namespace tensorflow { - -// TODO(anudhyan): These constants may be tuned based on the performance of -// 'benchmark_sparse_matrix_mat_vec_mul'. We would like to find constants -// which work across hardware platforms for typical matrix sizes. It should be -// possible to observe at least 30-50% improvement as we increase the number -// of threads by 1. If not, then it may we worth increasing kMaxShards and -// kNumShardsPerThread. However, once we have too many shards, latency may be -// dominated by per-shard overhead. -// -// Maximum number of shards into which to divide the computation for each CSR -// Sparse Matrix instance. -static constexpr int32_t kMaxShards = 20; -// Number of shards allocated to each thread. -static constexpr int32_t kNumShardsPerThread = 3; - -typedef Eigen::ThreadPoolDevice CPUDevice; -typedef Eigen::GpuDevice GPUDevice; - -// Abstract OpKernel to compute sparse-dense matrix multiplication. -// -// Implements a kernel which, given a SparseMatrix `a` and dense Tensor `b`, -// computes a dense Tensor `c` satisfying `c = a * b` where * denotes matrix -// multiplication. -// -// The boolean attributes `transpose_a` and `adjoint_a` will transpose or -// adjoint `a` before multiplication, respectively. At most one of these -// attributes must be set to True. Corresponding attributes will transpose or -// adjoint `b` or the output (after multiplication). -// -// The rank of both `a` and `b` must be equal and their shapes must be -// compatible for matrix multiplication. Otherwise, InvalidArgument runtime -// errors will be thrown. Only rank 2 or rank 3 inputs are supported. -// -template -class CSRMatMulOp : public OpKernel { - public: - explicit CSRMatMulOp(OpKernelConstruction* c); - - ~CSRMatMulOp() override {} - - Status ValidateInputs(const CSRSparseMatrix& sparse_matrix_a, - const Tensor& dense_tensor_b, int* rank, - int64_t* batch_size); - - public: - bool transpose_a_; - bool transpose_b_; - bool conjugate_a_; - bool conjugate_b_; - bool transpose_output_; - bool conjugate_output_; -}; - -// CPU Kernel to compute sparse-dense matrix multiplication. -// -// Uses Eigen SparseMatrix to compute the sparse-dense multiplication between -// a CSR SparseMatrix `a` and dense Tensor `b`. If intra-op parallelism is -// available, the implementation parallelizes the computation across each row -// of the sparse matrix. -template -class CSRMatMulCPUOp : public CSRMatMulOp { - using SparseMatrix = Eigen::SparseMatrix; - using Matrix = - Eigen::Matrix; - using ConstMatrixMap = Eigen::Map; - using MatrixMap = Eigen::Map; - - public: - explicit CSRMatMulCPUOp(OpKernelConstruction* c) - : CSRMatMulOp(c) {} - - ~CSRMatMulCPUOp() override{}; - - void Compute(OpKernelContext* ctx) final; - - private: - Status AllocateOutput(OpKernelContext* ctx, const int32_t rank, - const int64_t batch_size, const int64_t num_rows, - const int64_t num_cols, const bool transpose_output, - Tensor** output, Tensor* output_transposed, - Tensor** matmul_result); - - Eigen::Ref GetSparseMatrixRef( - const CSRSparseMatrix& csr_matrix, const int batch_index, - const int64_t row_begin, const int64_t num_shard_rows, - std::vector* row_ptrs); - - void SparseDenseMatMulWithoutTransposedLHS(OpKernelContext* ctx, - const int64_t batch_size, - const int64_t num_lhs_rows, - const CSRSparseMatrix& lhs, - const Tensor& rhs, Tensor* output); - - void SparseDenseMatMulWithTransposedLHS(OpKernelContext* ctx, - const int64_t batch_size, - const int64_t num_lhs_rows, - const int64_t num_lhs_cols, - const CSRSparseMatrix& lhs, - const Tensor& rhs, Tensor* output); - - void HandleBatchAndRowRange( - const int64_t num_rows, const int64_t batch_and_row_begin, - const int64_t batch_and_row_end, - const std::function& fn); - - Status TransposeAndConjugateTensor(OpKernelContext* ctx, const Tensor& input, - bool conjugate, Tensor* output); - - Status TransposeAndConjugateAllocatedTensor(OpKernelContext* ctx, - const Tensor& input, - bool conjugate, Tensor* output); -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_KERNELS_SPARSE_MAT_MUL_OP_H_ diff --git a/tensorflow/core/kernels/sparse/sparse_cholesky_op.cc b/tensorflow/core/kernels/sparse/sparse_cholesky_op.cc index 3045f97ce0ac18..90f4fbde158748 100644 --- a/tensorflow/core/kernels/sparse/sparse_cholesky_op.cc +++ b/tensorflow/core/kernels/sparse/sparse_cholesky_op.cc @@ -269,7 +269,7 @@ class CSRSparseCholeskyCPUOp : public OpKernel { perm_shape.dim_size(0), " != ", *batch_size); } - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/core/kernels/sparse/sparse_matrix.h b/tensorflow/core/kernels/sparse/sparse_matrix.h index 5601c6c7de2196..95b93443863ecc 100644 --- a/tensorflow/core/kernels/sparse/sparse_matrix.h +++ b/tensorflow/core/kernels/sparse/sparse_matrix.h @@ -525,7 +525,7 @@ class CSRSparseMatrix { "CSRSparseMatrix::Validate: size(col_indices) = ", col_indices.dim_size(0), " != size(values) = ", values.dim_size(0)); } - return OkStatus(); + return absl::OkStatus(); } struct Metadata { @@ -648,7 +648,7 @@ Status ExtractVariantFromInput(OpKernelContext* ctx, int index, if (!(*value)->valid()) { return errors::InvalidArgument("Variant input ", index, " is not valid."); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/kernels/sparse/transpose_op.cc b/tensorflow/core/kernels/sparse/transpose_op.cc index 6f98e6704e471b..44e4783613ccae 100644 --- a/tensorflow/core/kernels/sparse/transpose_op.cc +++ b/tensorflow/core/kernels/sparse/transpose_op.cc @@ -87,7 +87,7 @@ Status ValidateTransposeInputs(const ConstCSRComponent& input, "Input nnz should equal the output values size. Got ", nnz, " vs. ", output.values.size()); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -205,7 +205,7 @@ Status CSRSparseMatrixTranspose::operator()( maybe_conj_inplace::run(d, &output_values_t); } - return OkStatus(); + return absl::OkStatus(); } // CPU kernel for transposing a single component of a CSR SparseMatrix. diff --git a/tensorflow/core/kernels/sparse/zeros_op.h b/tensorflow/core/kernels/sparse/zeros_op.h index a6d03974a78ae4..8df31337110275 100644 --- a/tensorflow/core/kernels/sparse/zeros_op.h +++ b/tensorflow/core/kernels/sparse/zeros_op.h @@ -74,7 +74,7 @@ struct CSRSparseMatrixZeros { dtype, dense_shape_t, batch_ptr_t, csr_row_ptr_t, coo_col_ind_t, csr_values_t, matrix)); - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/core/kernels/sparse_conditional_accumulator.h b/tensorflow/core/kernels/sparse_conditional_accumulator.h index 3b809f03e95dc2..e41caf8e0e4a45 100644 --- a/tensorflow/core/kernels/sparse_conditional_accumulator.h +++ b/tensorflow/core/kernels/sparse_conditional_accumulator.h @@ -140,7 +140,7 @@ class SparseConditionalAccumulator } } - return OkStatus(); + return absl::OkStatus(); } void AllocateAndAssignToAccumGradFunction( diff --git a/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc b/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc index 46137639b1d1d7..7a915cded37527 100644 --- a/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc +++ b/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc @@ -37,7 +37,7 @@ class SparseConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp { new SparseConditionalAccumulator( dtype_, shape_, cinfo_.name(), reduction_type_); *ret = accumulator; - return OkStatus(); + return absl::OkStatus(); }; } @@ -45,7 +45,7 @@ class SparseConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp { // it with cond2 otherwise. Status CheckSignature(OpKernelContext* ctx) override { TF_RETURN_IF_ERROR(ctx->MatchSignature({}, {DT_STRING_REF})); - return OkStatus(); + return absl::OkStatus(); } void SetHandleToOutput(OpKernelContext* ctx) diff --git a/tensorflow/core/kernels/sparse_cross_op.cc b/tensorflow/core/kernels/sparse_cross_op.cc index 9235ebe9efa74b..bc8d3e3b329e9d 100644 --- a/tensorflow/core/kernels/sparse_cross_op.cc +++ b/tensorflow/core/kernels/sparse_cross_op.cc @@ -591,7 +591,7 @@ Status ValidateInput(const OpInputList& indices_list_in, } } - return OkStatus(); + return absl::OkStatus(); } // Extracts data about the features and populates feature data. @@ -733,7 +733,7 @@ Status CreateOutputTensors( shape_vec(0) = batch_size; shape_vec(1) = max_cross_count; - return OkStatus(); + return absl::OkStatus(); } template diff --git a/tensorflow/core/kernels/sparse_reduce_op.cc b/tensorflow/core/kernels/sparse_reduce_op.cc index 9e040fe7224420..d64d7829a65fc1 100644 --- a/tensorflow/core/kernels/sparse_reduce_op.cc +++ b/tensorflow/core/kernels/sparse_reduce_op.cc @@ -129,7 +129,7 @@ Status ValidateInputs(const Tensor *shape_t, const Tensor *reduction_axes_t) { } } - return OkStatus(); + return absl::OkStatus(); } struct SumOp { diff --git a/tensorflow/core/kernels/sparse_tensor_dense_add_op.cc b/tensorflow/core/kernels/sparse_tensor_dense_add_op.cc index 80bb5a42e3b0aa..1d8b3b0156c756 100644 --- a/tensorflow/core/kernels/sparse_tensor_dense_add_op.cc +++ b/tensorflow/core/kernels/sparse_tensor_dense_add_op.cc @@ -91,7 +91,7 @@ Status ValidateInputs(const Tensor *a_indices, const Tensor *a_values, } } - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc index 73870c5b37c2b1..04aff711362552 100644 --- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc +++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc @@ -319,7 +319,7 @@ Status SparseTensorDenseMatMulImpl( } #undef LOOP_NNZ } - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/core/kernels/sparse_tensors_map_ops.cc b/tensorflow/core/kernels/sparse_tensors_map_ops.cc index 841ed2cc36eaac..f14af265464f48 100644 --- a/tensorflow/core/kernels/sparse_tensors_map_ops.cc +++ b/tensorflow/core/kernels/sparse_tensors_map_ops.cc @@ -68,7 +68,7 @@ class SparseTensorsMap : public ResourceBase { gtl::InlinedVector(sp.shape().begin(), sp.shape().end())}; *handle = unique_st_handle; } - return OkStatus(); + return absl::OkStatus(); } Status RetrieveAndClearSparseTensors( @@ -95,7 +95,7 @@ class SparseTensorsMap : public ResourceBase { } } - return OkStatus(); + return absl::OkStatus(); } protected: @@ -128,7 +128,7 @@ class SparseTensorAccessingOp : public OpKernel { if (sparse_tensors_map_) { *sparse_tensors_map = sparse_tensors_map_; - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def(), @@ -137,7 +137,7 @@ class SparseTensorAccessingOp : public OpKernel { CreatorCallback sparse_tensors_map_creator = [this](SparseTensorsMap** c) { SparseTensorsMap* map = new SparseTensorsMap(cinfo_.name()); *c = map; - return OkStatus(); + return absl::OkStatus(); }; TF_RETURN_IF_ERROR( @@ -146,7 +146,7 @@ class SparseTensorAccessingOp : public OpKernel { sparse_tensors_map_creator)); *sparse_tensors_map = sparse_tensors_map_; - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/sparse_to_dense_op.cc b/tensorflow/core/kernels/sparse_to_dense_op.cc index f6fe495b637078..048461daede0db 100644 --- a/tensorflow/core/kernels/sparse_to_dense_op.cc +++ b/tensorflow/core/kernels/sparse_to_dense_op.cc @@ -85,7 +85,7 @@ Status CheckSparseToDenseShapes(const Tensor& indices, if (!TensorShapeUtils::IsScalar(default_value.shape())) { return errors::InvalidArgument("default_value should be a scalar."); } - return OkStatus(); + return absl::OkStatus(); } } // end namespace diff --git a/tensorflow/core/kernels/sparse_to_dense_op_gpu.cu.cc b/tensorflow/core/kernels/sparse_to_dense_op_gpu.cu.cc index 473e3d2b919c02..e68447990d60f4 100644 --- a/tensorflow/core/kernels/sparse_to_dense_op_gpu.cu.cc +++ b/tensorflow/core/kernels/sparse_to_dense_op_gpu.cu.cc @@ -173,15 +173,16 @@ void LaunchSparseToDense::operator()( se::DeviceMemoryBase valid_status_ptr(status_ptr, valid_status_bytes); GpuLaunchConfig config = GetGpuLaunchConfig(num_elems, d); - stream->ThenMemset32(&valid_status_ptr, INT_MAX, valid_status_bytes); + OP_REQUIRES_OK( + c, stream->Memset32(&valid_status_ptr, INT_MAX, valid_status_bytes)); OP_REQUIRES_OK_ASYNC( c, GpuLaunchKernel(CheckIndicesValid, config.block_count, config.thread_per_block, 0, d.stream(), indices_ptr, num_elems, shape_ptr, num_dims, status_ptr), done); - stream->ThenMemcpy(reinterpret_cast(&valid_status), valid_status_ptr, - valid_status_bytes); + OP_REQUIRES_OK(c, stream->Memcpy(reinterpret_cast(&valid_status), + valid_status_ptr, valid_status_bytes)); // We capture 'shape' instead of 'shape_ptr' since this lambda outlives // the 'shape' tensor. diff --git a/tensorflow/core/kernels/sparse_utils.cc b/tensorflow/core/kernels/sparse_utils.cc index cf39f8102cb0dd..d9a2850e596519 100644 --- a/tensorflow/core/kernels/sparse_utils.cc +++ b/tensorflow/core/kernels/sparse_utils.cc @@ -176,7 +176,7 @@ Status ValidateSparseTensorShape(const Tensor& indices, const Tensor& values, shape.NumElements(), ") do not match"); } - return OkStatus(); + return absl::OkStatus(); } // Creates a debug string for the index tuple in indices(row, :). @@ -215,7 +215,7 @@ Status ValidateSparseTensorIndicesUnordered(const Tensor& indices, } } - return OkStatus(); + return absl::OkStatus(); } // Ensures all sparse indices are within correct bounds and are @@ -229,7 +229,7 @@ Status ValidateSparseTensorIndicesOrdered(const Tensor& indices, int64_t ndims = indices.dim_size(1); if (nnz == 0) { - return OkStatus(); + return absl::OkStatus(); } // First set of indices must be within range. @@ -282,7 +282,7 @@ Status ValidateSparseTensorIndicesOrdered(const Tensor& indices, } } // for i in [1, nnz) - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -300,7 +300,7 @@ Status ValidateSparseTensor(const Tensor& indices, const Tensor& values, case IndexValidation::kNone: { } } - return OkStatus(); + return absl::OkStatus(); } #define REGISTER_SPARSE_UTIL_FUNCTIONS(TypeIndex) \ diff --git a/tensorflow/core/kernels/sparse_xent_op.cc b/tensorflow/core/kernels/sparse_xent_op.cc index 01f8ab42a2c506..4ece900f6c5b95 100644 --- a/tensorflow/core/kernels/sparse_xent_op.cc +++ b/tensorflow/core/kernels/sparse_xent_op.cc @@ -34,7 +34,7 @@ typedef Eigen::GpuDevice GPUDevice; template Status CheckInvalidLabelIndex(const Tensor& labels, int64_t max_index) { - if (labels.NumElements() == 0) return OkStatus(); + if (labels.NumElements() == 0) return absl::OkStatus(); const auto label_values = labels.vec(); int64_t bad_index; auto min_max_dim_value = std::minmax_element( @@ -47,7 +47,7 @@ Status CheckInvalidLabelIndex(const Tensor& labels, int64_t max_index) { " which is outside the valid range of [0, ", max_index, "). Label values: ", labels.SummarizeValue(labels.NumElements())); } - return OkStatus(); + return absl::OkStatus(); } template diff --git a/tensorflow/core/kernels/spectrogram_convert_test_data.cc b/tensorflow/core/kernels/spectrogram_convert_test_data.cc index 18ce56fb52bb39..1878eb5999b505 100644 --- a/tensorflow/core/kernels/spectrogram_convert_test_data.cc +++ b/tensorflow/core/kernels/spectrogram_convert_test_data.cc @@ -34,7 +34,7 @@ Status ConvertCsvToRaw(const string& input_filename) { input_filename); } LOG(INFO) << "Wrote raw file to " << output_filename; - return OkStatus(); + return absl::OkStatus(); } } // namespace wav diff --git a/tensorflow/core/kernels/squared-loss.h b/tensorflow/core/kernels/squared-loss.h index 7222813bbbf823..3a0f6d2abb2253 100644 --- a/tensorflow/core/kernels/squared-loss.h +++ b/tensorflow/core/kernels/squared-loss.h @@ -64,7 +64,7 @@ class SquaredLossUpdater : public DualLossUpdater { // Labels don't require conversion for linear regression. Status ConvertLabel(float* const example_label) const final { - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/core/kernels/stack.cc b/tensorflow/core/kernels/stack.cc index 1ee4f4268d43ed..90eaf2efebe1bd 100644 --- a/tensorflow/core/kernels/stack.cc +++ b/tensorflow/core/kernels/stack.cc @@ -63,7 +63,7 @@ class Stack : public ResourceBase { "its max_size (", max_size_, ")"); } stack_.push_back(value); - return OkStatus(); + return absl::OkStatus(); } Status Pop(TensorAndAllocation* value) { @@ -75,7 +75,7 @@ class Stack : public ResourceBase { } *value = stack_.back(); stack_.pop_back(); - return OkStatus(); + return absl::OkStatus(); } // We don't swap the first tensor on the stack and any subsequent tensors @@ -121,7 +121,7 @@ class Stack : public ResourceBase { return errors::InvalidArgument("Stack[", stack_name_, "] has already been closed."); } - return OkStatus(); + return absl::OkStatus(); } }; @@ -147,7 +147,7 @@ Status GetStack(OpKernelContext* ctx, Stack** stack) { return errors::Internal("No step container."); } TF_RETURN_IF_ERROR(step_container->Lookup(rm, key, stack)); - return OkStatus(); + return absl::OkStatus(); } } diff --git a/tensorflow/core/kernels/stage_op.cc b/tensorflow/core/kernels/stage_op.cc index 5d20ea3536004b..63d84513b3f5e1 100644 --- a/tensorflow/core/kernels/stage_op.cc +++ b/tensorflow/core/kernels/stage_op.cc @@ -82,7 +82,7 @@ class Buffer : public ResourceBase { // we should wake them all. non_empty_cond_var_.notify_all(); - return OkStatus(); + return absl::OkStatus(); } // Get tuple at front of the buffer @@ -115,7 +115,7 @@ class Buffer : public ResourceBase { tuple->push_back(tensor); } - return OkStatus(); + return absl::OkStatus(); } // Buffer size @@ -187,13 +187,13 @@ Status GetBuffer(OpKernelContext* ctx, const NodeDef& ndef, Buffer** buf) { TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "capacity", &capacity)); TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "memory_limit", &memory_limit)); *ret = new Buffer(capacity, memory_limit); - return OkStatus(); + return absl::OkStatus(); }; TF_RETURN_IF_ERROR(cinfo.Init(rm, ndef, true /* use name() */)); TF_RETURN_IF_ERROR(rm->LookupOrCreate(cinfo.container(), cinfo.name(), buf, create_fn)); - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/core/kernels/stateful_random_ops.cc b/tensorflow/core/kernels/stateful_random_ops.cc index 80f2f9ae0805ba..ef54ced28e7e6e 100644 --- a/tensorflow/core/kernels/stateful_random_ops.cc +++ b/tensorflow/core/kernels/stateful_random_ops.cc @@ -65,7 +65,7 @@ Status CheckState(const Tensor& state) { return errors::InvalidArgument( "RNG state must have one and only one dimension, not ", state.dims()); } - return OkStatus(); + return absl::OkStatus(); } Status CheckPhiloxState(const Tensor& state, int64_t alg_tag_skip = 0) { @@ -80,7 +80,7 @@ Status CheckPhiloxState(const Tensor& state, int64_t alg_tag_skip = 0) { " must be at least ", min_size, "; got ", state.NumElements()); } - return OkStatus(); + return absl::OkStatus(); } template @@ -149,7 +149,7 @@ Status UpdateVariableAndFill( arg.state_tensor = var_tensor; functor::UpdateVariableAndFill_Philox()( ctx, ctx->eigen_device(), dist, &arg, output_data); - return OkStatus(); + return absl::OkStatus(); case ConcreteRngAlgorithm::RNG_ALG_THREEFRY: return errors::Unimplemented( "Non-XLA devices don't support the ThreeFry algorithm."); @@ -202,7 +202,7 @@ Status GetScalar(const Tensor& tensor, int input_idx, T* result) { ", not ", DataTypeString(tensor.dtype())); } *result = tensor.flat()(0); - return OkStatus(); + return absl::OkStatus(); } template diff --git a/tensorflow/core/kernels/stateless_random_ops.cc b/tensorflow/core/kernels/stateless_random_ops.cc index beb2391b03f7de..1ce076e4c1a0d1 100644 --- a/tensorflow/core/kernels/stateless_random_ops.cc +++ b/tensorflow/core/kernels/stateless_random_ops.cc @@ -62,7 +62,7 @@ Status GenerateKey(Tensor seed, random::PhiloxRandom::Key* out_key, (*out_counter)[0] = (*out_counter)[1] = 0; (*out_counter)[2] = mix[2]; (*out_counter)[3] = mix[3]; - return OkStatus(); + return absl::OkStatus(); } StatelessRandomOpBase::StatelessRandomOpBase(OpKernelConstruction* context) diff --git a/tensorflow/core/kernels/stateless_random_ops_v2.h b/tensorflow/core/kernels/stateless_random_ops_v2.h index f88e5330041f6b..b566f490fdd6fb 100644 --- a/tensorflow/core/kernels/stateless_random_ops_v2.h +++ b/tensorflow/core/kernels/stateless_random_ops_v2.h @@ -38,7 +38,7 @@ inline Status CheckKeyCounterShape(int minimum_counter_size, "; got shape: ", counter_shape.DebugString(), ". (Note that batched counters are not supported yet.)"); } - return OkStatus(); + return absl::OkStatus(); } // A base class for kernels of stateless RNG ops that take shape, key, counter diff --git a/tensorflow/core/kernels/stateless_random_ops_v2_util.h b/tensorflow/core/kernels/stateless_random_ops_v2_util.h index 8744d848e869ba..c606a90fec23e2 100644 --- a/tensorflow/core/kernels/stateless_random_ops_v2_util.h +++ b/tensorflow/core/kernels/stateless_random_ops_v2_util.h @@ -41,7 +41,7 @@ Status GetScalar(const Tensor& tensor, int input_idx, T* result) { ", not ", DataTypeString(tensor.dtype())); } *result = tensor.flat()(0); - return OkStatus(); + return absl::OkStatus(); } inline StatusOr > diff --git a/tensorflow/core/kernels/string_util.cc b/tensorflow/core/kernels/string_util.cc index a0486aa6c925c5..105a89f589a0fe 100644 --- a/tensorflow/core/kernels/string_util.cc +++ b/tensorflow/core/kernels/string_util.cc @@ -31,7 +31,7 @@ Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding) { strings::StrCat("Invalid encoding \"", str, "\": Should be one of: UTF-8, UTF-16-BE, UTF-32-BE")); } - return OkStatus(); + return absl::OkStatus(); } // Sets unit value based on str. @@ -44,7 +44,7 @@ Status ParseCharUnit(const string& str, CharUnit* unit) { return errors::InvalidArgument(strings::StrCat( "Invalid unit \"", str, "\": Should be one of: BYTE, UTF8_CHAR")); } - return OkStatus(); + return absl::OkStatus(); } // Return the number of Unicode characters in a UTF-8 string. diff --git a/tensorflow/core/kernels/summary_image_op.cc b/tensorflow/core/kernels/summary_image_op.cc index c2ab2ab585b6fc..a68bf724cf9efc 100644 --- a/tensorflow/core/kernels/summary_image_op.cc +++ b/tensorflow/core/kernels/summary_image_op.cc @@ -173,7 +173,7 @@ class SummaryImageOp : public OpKernel { return errors::Internal("PNG encoding failed"); } } - return OkStatus(); + return absl::OkStatus(); } template diff --git a/tensorflow/core/kernels/summary_kernels.cc b/tensorflow/core/kernels/summary_kernels.cc index e2348ca7a75953..81d30e8dbee42b 100644 --- a/tensorflow/core/kernels/summary_kernels.cc +++ b/tensorflow/core/kernels/summary_kernels.cc @@ -97,7 +97,7 @@ class CreateSummaryDbWriterOp : public OpKernel { TF_RETURN_IF_ERROR(SetupTensorboardSqliteDb(db)); TF_RETURN_IF_ERROR(CreateSummaryDbWriter( db, experiment_name, run_name, user_name, ctx->env(), s)); - return OkStatus(); + return absl::OkStatus(); })); } }; diff --git a/tensorflow/core/kernels/tensor_array.cc b/tensorflow/core/kernels/tensor_array.cc index 644e6c373aaf05..fa24b716a9c822 100644 --- a/tensorflow/core/kernels/tensor_array.cc +++ b/tensorflow/core/kernels/tensor_array.cc @@ -111,7 +111,7 @@ Status TensorArray::CopyShapesFrom(TensorArray* rhs, tensors_[i].written = true; } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/kernels/tensor_array.h b/tensorflow/core/kernels/tensor_array.h index 97e4cd45b085fe..1081c2be8a08a8 100644 --- a/tensorflow/core/kernels/tensor_array.h +++ b/tensorflow/core/kernels/tensor_array.h @@ -202,7 +202,7 @@ class TensorArray : public ResourceBase { ++i; TF_RETURN_IF_ERROR(s); } - return OkStatus(); + return absl::OkStatus(); } // Read from index 'index' into Tensor 'value'. @@ -238,7 +238,7 @@ class TensorArray : public ResourceBase { ++i; if (!s.ok()) return s; } - return OkStatus(); + return absl::OkStatus(); } DataType ElemType() const { return dtype_; } @@ -256,7 +256,7 @@ class TensorArray : public ResourceBase { return s; } element_shape_ = new_element_shape_; - return OkStatus(); + return absl::OkStatus(); } string DebugString() const override { @@ -275,7 +275,7 @@ class TensorArray : public ResourceBase { mutex_lock l(mu_); TF_RETURN_IF_ERROR(LockedReturnIfClosed()); *size = tensors_.size(); - return OkStatus(); + return absl::OkStatus(); } // Record the size of the TensorArray after an unpack or split. @@ -285,7 +285,7 @@ class TensorArray : public ResourceBase { if (!is_grad_) { marked_size_ = size; } - return OkStatus(); + return absl::OkStatus(); } // Return the marked size of the TensorArray. @@ -293,7 +293,7 @@ class TensorArray : public ResourceBase { mutex_lock l(mu_); TF_RETURN_IF_ERROR(LockedReturnIfClosed()); *size = marked_size_; - return OkStatus(); + return absl::OkStatus(); } // Return the size that should be used by pack or concat op. @@ -301,7 +301,7 @@ class TensorArray : public ResourceBase { mutex_lock l(mu_); TF_RETURN_IF_ERROR(LockedReturnIfClosed()); *size = is_grad_ ? marked_size_ : tensors_.size(); - return OkStatus(); + return absl::OkStatus(); } // Once a TensorArray is being used for gradient calculations, it @@ -367,7 +367,7 @@ class TensorArray : public ResourceBase { return errors::InvalidArgument("TensorArray ", handle_.vec()(1), " has already been closed."); } - return OkStatus(); + return absl::OkStatus(); } const string key_; @@ -508,7 +508,7 @@ Status TensorArray::LockedWriteOrAggregate(OpKernelContext* ctx, // was just a shape, which just means zeros. So all we must do in this // case is copy the reference over and return early. t.tensor = *value; - return OkStatus(); + return absl::OkStatus(); } Tensor* existing_t = &t.tensor; @@ -536,7 +536,7 @@ Status TensorArray::LockedWriteOrAggregate(OpKernelContext* ctx, t.shape = value->shape(); t.written = true; } - return OkStatus(); + return absl::OkStatus(); } template @@ -619,7 +619,7 @@ Status TensorArray::LockedRead(OpKernelContext* ctx, const int32_t index, t.cleared = true; } t.read = true; - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc index dc4f1ca9400bf2..5d1322e6f4b7d6 100644 --- a/tensorflow/core/kernels/tensor_array_ops.cc +++ b/tensorflow/core/kernels/tensor_array_ops.cc @@ -75,7 +75,7 @@ Status GetHandle(OpKernelContext* ctx, string* container, string* ta_handle) { *container = h(0); *ta_handle = h(1); } - return OkStatus(); + return absl::OkStatus(); } Status GetTensorArray(OpKernelContext* ctx, TensorArray** tensor_array) { @@ -88,7 +88,7 @@ Status GetTensorArray(OpKernelContext* ctx, TensorArray** tensor_array) { ScopedStepContainer* sc = ctx->step_container(); if (sc == nullptr) return errors::Internal("No step container."); TF_RETURN_IF_ERROR(sc->Lookup(rm, container + ta_handle, tensor_array)); - return OkStatus(); + return absl::OkStatus(); } else { return LookupResource(ctx, HandleFromInput(ctx, 0), tensor_array); } @@ -100,7 +100,7 @@ Status SetupFlowControlInputs(OpKernelContext* ctx, bool set_output) { if (set_output) { TF_RETURN_IF_ERROR(ctx->set_output("flow_out", *flow_control)); } - return OkStatus(); + return absl::OkStatus(); } // CREATION ******************************************************************* @@ -220,7 +220,7 @@ class TensorArrayOp : public TensorArrayCreationOp { *output_tensor_array = tensor_array; - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/tensor_flag_utils.cc b/tensorflow/core/kernels/tensor_flag_utils.cc index 2f0165d08a3911..974c4622a69a89 100644 --- a/tensorflow/core/kernels/tensor_flag_utils.cc +++ b/tensorflow/core/kernels/tensor_flag_utils.cc @@ -25,7 +25,7 @@ Status ValidateSparseMatrixShardingConfig(const Tensor& config) { if (TensorShapeUtils::IsScalar(config.shape())) { const float scalar_config = config.template scalar()(); if (0 < scalar_config && scalar_config <= 1.0) { - return OkStatus(); + return absl::OkStatus(); } return Status( absl::StatusCode::kInvalidArgument, @@ -69,7 +69,7 @@ Status ValidateSparseMatrixShardingConfig(const Tensor& config) { config_matrix(i, 2), " in row ", i); } } - return OkStatus(); + return absl::OkStatus(); } template @@ -89,7 +89,7 @@ Status ValidateScalarQuantityShardingConfig(const Tensor& config) { if (TensorShapeUtils::IsScalar(config.shape())) { const float scalar_config = config.template scalar()(); if (0 < scalar_config && scalar_config <= 1.0) { - return OkStatus(); + return absl::OkStatus(); } return Status( absl::StatusCode::kInvalidArgument, @@ -126,7 +126,7 @@ Status ValidateScalarQuantityShardingConfig(const Tensor& config) { config_matrix(i, 1), " in row ", i); } } - return OkStatus(); + return absl::OkStatus(); } template diff --git a/tensorflow/core/kernels/tensor_list_util.cc b/tensorflow/core/kernels/tensor_list_util.cc index 34aa1d35d9af6e..7dc0d01b56b61d 100644 --- a/tensorflow/core/kernels/tensor_list_util.cc +++ b/tensorflow/core/kernels/tensor_list_util.cc @@ -61,7 +61,7 @@ Status TensorListBinaryAdd( TF_RETURN_IF_ERROR(binary_add_func(c, a_tensor, b_tensor, &out_tensor)); out->tensors().push_back(out_tensor); } - return OkStatus(); + return absl::OkStatus(); } Status TensorListZerosLike( @@ -77,7 +77,7 @@ Status TensorListZerosLike( TF_RETURN_IF_ERROR(zeros_like_func(c, t, &out_tensor)); y->tensors().emplace_back(out_tensor); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/kernels/tensor_map.cc b/tensorflow/core/kernels/tensor_map.cc index 94a22fbfcabe93..a95d256cff92f4 100644 --- a/tensorflow/core/kernels/tensor_map.cc +++ b/tensorflow/core/kernels/tensor_map.cc @@ -53,7 +53,7 @@ static Status TensorMapDeviceCopy( TF_RETURN_IF_ERROR(copy(p.second, &to_val)); to->tensors().emplace(to_key, to_val); } - return OkStatus(); + return absl::OkStatus(); } #define REGISTER_LIST_COPY(DIRECTION) \ diff --git a/tensorflow/core/kernels/text_line_reader_op.cc b/tensorflow/core/kernels/text_line_reader_op.cc index ae05e581ed09b9..89b56cb1853bd7 100644 --- a/tensorflow/core/kernels/text_line_reader_op.cc +++ b/tensorflow/core/kernels/text_line_reader_op.cc @@ -46,16 +46,16 @@ class TextLineReader : public ReaderBase { if (absl::IsOutOfRange(status)) { // We ignore an end of file error when skipping header lines. // We will end up skipping this file. - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR(status); } - return OkStatus(); + return absl::OkStatus(); } Status OnWorkFinishedLocked() override { input_buffer_.reset(nullptr); - return OkStatus(); + return absl::OkStatus(); } Status ReadLocked(tstring* key, tstring* value, bool* produced, @@ -69,7 +69,7 @@ class TextLineReader : public ReaderBase { } if (absl::IsOutOfRange(status)) { // End of file, advance to the next. *at_end = true; - return OkStatus(); + return absl::OkStatus(); } else { // Some other reading error return status; } diff --git a/tensorflow/core/kernels/tf_record_reader_op.cc b/tensorflow/core/kernels/tf_record_reader_op.cc index 416bc22b9413d2..9126139afc6b65 100644 --- a/tensorflow/core/kernels/tf_record_reader_op.cc +++ b/tensorflow/core/kernels/tf_record_reader_op.cc @@ -43,13 +43,13 @@ class TFRecordReader : public ReaderBase { io::RecordReaderOptions options = io::RecordReaderOptions::CreateRecordReaderOptions(compression_type_); reader_.reset(new io::RecordReader(file_.get(), options)); - return OkStatus(); + return absl::OkStatus(); } Status OnWorkFinishedLocked() override { reader_.reset(nullptr); file_.reset(nullptr); - return OkStatus(); + return absl::OkStatus(); } Status ReadLocked(tstring* key, tstring* value, bool* produced, @@ -58,11 +58,11 @@ class TFRecordReader : public ReaderBase { Status status = reader_->ReadRecord(&offset_, value); if (absl::IsOutOfRange(status)) { *at_end = true; - return OkStatus(); + return absl::OkStatus(); } if (!status.ok()) return status; *produced = true; - return OkStatus(); + return absl::OkStatus(); } Status ResetLocked() override { diff --git a/tensorflow/core/kernels/training_op_helpers.h b/tensorflow/core/kernels/training_op_helpers.h index fa26c829566a49..69447348f067a5 100644 --- a/tensorflow/core/kernels/training_op_helpers.h +++ b/tensorflow/core/kernels/training_op_helpers.h @@ -45,7 +45,7 @@ template Status EnsureSparseVariableAccess(OpKernelContext* ctx, Var* var, bool lock_held = false) { if (var->copy_on_read_mode.load()) { - return OkStatus(); + return absl::OkStatus(); } std::optional ml; @@ -58,7 +58,7 @@ Status EnsureSparseVariableAccess(OpKernelContext* ctx, Var* var, // copy-on-read mode is false. if (var->tensor()->RefCountIsOne()) { var->copy_on_read_mode.store(true); - return OkStatus(); + return absl::OkStatus(); } Tensor tmp; if (std::is_same::value) { @@ -84,7 +84,7 @@ Status EnsureSparseVariableAccess(OpKernelContext* ctx, Var* var, } *var->tensor() = tmp; var->copy_on_read_mode.store(true); - return OkStatus(); + return absl::OkStatus(); } // Utility structure that releases a sequence of borrowed mutexes when it is @@ -249,7 +249,7 @@ Status PrepareToUpdateVariable(OpKernelContext* ctx, Tensor* tensor, } *tensor = tmp; } - return OkStatus(); + return absl::OkStatus(); } // This gives you `*out`, a tensor you can update, corresponding to a variable @@ -269,15 +269,15 @@ Status GetInputTensorFromVariable(OpKernelContext* ctx, int input, if (sparse) { TF_RETURN_IF_ERROR(EnsureSparseVariableAccess(ctx, var.get())); *out = *var->tensor(); - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR(PrepareToUpdateVariable( ctx, var->tensor(), var->copy_on_read_mode.load())); *out = *var->tensor(); - return OkStatus(); + return absl::OkStatus(); } *out = ctx->mutable_input(input, lock_held); - return OkStatus(); + return absl::OkStatus(); } } // end namespace tensorflow diff --git a/tensorflow/core/kernels/transpose_functor.h b/tensorflow/core/kernels/transpose_functor.h index 2969918c33df1f..d640d051a40f4d 100644 --- a/tensorflow/core/kernels/transpose_functor.h +++ b/tensorflow/core/kernels/transpose_functor.h @@ -235,14 +235,14 @@ Status DoTransposeImpl(const Device& d, const Tensor& in, default: return errors::Unimplemented("Unsupported dtype on CPU: ", in.dtype()); } - return OkStatus(); + return absl::OkStatus(); } template inline Status DoMatrixTransposeImpl(const Device& device, const Tensor& in, bool conjugate, Tensor* out) { const int ndims = in.dims(); - if (ndims == 0) return OkStatus(); + if (ndims == 0) return absl::OkStatus(); TransposePermsVec perm(ndims); std::iota(perm.begin(), perm.end(), 0); std::swap(perm[ndims - 2], perm[ndims - 1]); diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc index df129f78c889f3..e3719aab6c648e 100644 --- a/tensorflow/core/kernels/transpose_op.cc +++ b/tensorflow/core/kernels/transpose_op.cc @@ -107,7 +107,7 @@ Status PermutationHelper(const Tensor& perm, const int dims, reinterpret_cast(Vperm.data()); *permutation = std::vector(perm_begin, perm_begin + dims); - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/core/kernels/typed_queue.h b/tensorflow/core/kernels/typed_queue.h index f11029cddadba2..2e67261841859d 100644 --- a/tensorflow/core/kernels/typed_queue.h +++ b/tensorflow/core/kernels/typed_queue.h @@ -68,7 +68,7 @@ Status TypedQueue::Initialize() { for (int i = 0; i < num_components(); ++i) { queues_.push_back(SubQueue()); } - return OkStatus(); + return absl::OkStatus(); } template diff --git a/tensorflow/core/kernels/unary_ops_composition.cc b/tensorflow/core/kernels/unary_ops_composition.cc index 4805e6c2aef9dc..98684f382ecd21 100644 --- a/tensorflow/core/kernels/unary_ops_composition.cc +++ b/tensorflow/core/kernels/unary_ops_composition.cc @@ -69,7 +69,7 @@ struct UnaryOpsCompositionBase { *cost += reg.cost; } - return OkStatus(); + return absl::OkStatus(); } std::unordered_map compute_fns; diff --git a/tensorflow/core/kernels/unicode_ops.cc b/tensorflow/core/kernels/unicode_ops.cc index b884b5fd354b23..3d59cc034480b3 100644 --- a/tensorflow/core/kernels/unicode_ops.cc +++ b/tensorflow/core/kernels/unicode_ops.cc @@ -237,7 +237,7 @@ Status GetErrorOptions(OpKernelConstruction* ctx, ErrorOptions* out) { &(out->replace_control_chars))); } - return OkStatus(); + return absl::OkStatus(); } inline bool ShouldHandleFormatError(const ErrorOptions& error_options, diff --git a/tensorflow/core/kernels/uniform_quant_ops/math_utils.cc b/tensorflow/core/kernels/uniform_quant_ops/math_utils.cc index 8fc4510608f7f0..413882ca810835 100644 --- a/tensorflow/core/kernels/uniform_quant_ops/math_utils.cc +++ b/tensorflow/core/kernels/uniform_quant_ops/math_utils.cc @@ -48,7 +48,7 @@ Status QuantizeMultiplier(double double_multiplier, q_fixed = (1LL << 31) - 1; } quantized_multiplier = static_cast(q_fixed); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/kernels/uniform_quant_ops/math_utils.h b/tensorflow/core/kernels/uniform_quant_ops/math_utils.h index 8b342b4295121e..8d471f9d21139d 100644 --- a/tensorflow/core/kernels/uniform_quant_ops/math_utils.h +++ b/tensorflow/core/kernels/uniform_quant_ops/math_utils.h @@ -137,7 +137,7 @@ Status AsymmetricQuantize(const ConstTensorTin& input_tensor, quantized_tensor.setZero(); scale = 1.0; zero_point = 0; - return OkStatus(); + return absl::OkStatus(); } // Using the scale calculated from the quantization range and data range, @@ -166,7 +166,7 @@ Status AsymmetricQuantize(const ConstTensorTin& input_tensor, AffineQuantize(input_tensor, inv_scale, zero_point, quantization_min_val, quantization_max_val, quantized_tensor); - return OkStatus(); + return absl::OkStatus(); } // Given double_multiplier, quantize it where it is represented by two int32_t, @@ -227,7 +227,7 @@ Status PerTensorToPerTensorRequantize( input_zero_point, output_zero_point, quantization_min_val, quantization_max_val); }); - return OkStatus(); + return absl::OkStatus(); } // Requantize where the input or output contains any per-axis quantized cases. @@ -298,7 +298,7 @@ Status PerAxisRequantize(OpKernelContext* context, const Tensor& input, quantization_min_val, quantization_max_val); }); } - return OkStatus(); + return absl::OkStatus(); } } // namespace internal diff --git a/tensorflow/core/kernels/uniform_quant_ops/tensor_utils.cc b/tensorflow/core/kernels/uniform_quant_ops/tensor_utils.cc index 7843162ddf9e79..d75659ad5161f6 100644 --- a/tensorflow/core/kernels/uniform_quant_ops/tensor_utils.cc +++ b/tensorflow/core/kernels/uniform_quant_ops/tensor_utils.cc @@ -53,7 +53,7 @@ Status QuantizationAxisAndShapeValid(const TensorShape& data_shape, " and zero_points shape ", zero_points_shape.DebugString()); } } - return OkStatus(); + return absl::OkStatus(); } TensorShape TransposedShape(const TensorShape& in_shape, diff --git a/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_add_op.cc b/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_add_op.cc index 1a3291dfa54df3..821ccd6e6bd1c1 100644 --- a/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_add_op.cc +++ b/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_add_op.cc @@ -158,7 +158,7 @@ Status EvalQuantizedAdd(OpKernelContext* context, const Tensor& lhs, lhs_quantization_axis, rhs_quantization_axis, output_quantization_axis, output); - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_convolution_ops.cc b/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_convolution_ops.cc index 3381dbc3c99622..1fa972a8c9801e 100644 --- a/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_convolution_ops.cc +++ b/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_convolution_ops.cc @@ -302,7 +302,7 @@ Status EvalLhsPerTensorAndRhsPerTensorQuantizedConv( /*input_zero_point=*/0, output_zero_point, output_quantization_min_val, output_quantization_max_val); }); - return OkStatus(); + return absl::OkStatus(); } // Quantized Conv on per-tensor quantized padded and dilated transposed lhs and @@ -383,7 +383,7 @@ Status EvalLhsPerTensorAndRhsPerChannelQuantizedConv( : out_feature_idx], output_quantization_min_val, output_quantization_max_val); }); - return OkStatus(); + return absl::OkStatus(); } // Quantized Conv on per-batch quantized padded and dilated transposed lhs and @@ -509,7 +509,7 @@ Status EvalQuantizedConv( // Transpose transposed_out back to out. const auto& out_perm_back = OutBackTransposePerm(out_perm); Transpose(out_transposed, out_perm_back, out); - return OkStatus(); + return absl::OkStatus(); } // Given float `lhs` and quantized `rhs`, performs per-batch dynamic range @@ -593,7 +593,7 @@ Status EvalHybridConv( // Transpose transposed_out back to out. const auto& out_perm_back = OutBackTransposePerm(out_perm); Transpose(out_transposed, out_perm_back, out); - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_dot_ops.cc b/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_dot_ops.cc index bd37ad7fbcc6a2..f2d51987c6f40c 100644 --- a/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_dot_ops.cc +++ b/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_dot_ops.cc @@ -39,7 +39,7 @@ Status DotInputShapeValid(const TensorShape& lhs_shape, "shape ", lhs_shape.DebugString(), " and rhs shape ", rhs_shape.DebugString()); } - return OkStatus(); + return absl::OkStatus(); } // Performs dot(lhs, rhs) and writes output to output. Assumes that output is @@ -109,7 +109,7 @@ Status EvalLhsPerTensorAndRhsPerTensorQuantizedDot( /*input_zero_point=*/0, output_zero_point, output_quantization_min_val, output_quantization_max_val); }); - return OkStatus(); + return absl::OkStatus(); } // Performs dot on per-tensor quantized lhs and per-channel (dimension 1) @@ -178,7 +178,7 @@ Status EvalLhsPerTensorAndRhsPerChannelQuantizedDot( output_zero_points_data[is_output_scales_scalar ? 0 : out_c], output_quantization_min_val, output_quantization_max_val); }); - return OkStatus(); + return absl::OkStatus(); } // Performs dot on per-batch (dimension 0) quantized lhs and per-tensor @@ -300,7 +300,7 @@ Status EvalHybridDot(OpKernelContext* context, const Tensor& lhs, rhs_scales.scalar()(), rhs_zero_points.scalar()(), output); } - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/core/kernels/variable_ops.cc b/tensorflow/core/kernels/variable_ops.cc index 6b30bc9c0aebb2..870dfd01fabc07 100644 --- a/tensorflow/core/kernels/variable_ops.cc +++ b/tensorflow/core/kernels/variable_ops.cc @@ -76,7 +76,7 @@ void VariableOp::Compute(OpKernelContext* ctx) { auto creator = [this](LegacyVar** var) { *var = new LegacyVar(dtype_); (*var)->tensor()->set_shape(shape_); - return OkStatus(); + return absl::OkStatus(); }; LegacyVar* var; OP_REQUIRES_OK(ctx, cinfo_.resource_manager()->LookupOrCreate( diff --git a/tensorflow/core/kernels/variant_ops_util.cc b/tensorflow/core/kernels/variant_ops_util.cc index 1665fe23d74236..989d5c70d28629 100644 --- a/tensorflow/core/kernels/variant_ops_util.cc +++ b/tensorflow/core/kernels/variant_ops_util.cc @@ -49,7 +49,7 @@ static inline Status AddVariantTo( Variant* c = &temp->at(lhs_ix); TF_RETURN_IF_ERROR(binary_add_variant(ctx, a, b, c)); temp_filled->at(lhs_ix) = true; - return OkStatus(); + return absl::OkStatus(); } void AddNVariant(OpKernelContext* ctx, diff --git a/tensorflow/core/kernels/while_op_test.cc b/tensorflow/core/kernels/while_op_test.cc index 62b602e8106aba..b7f5af047b8186 100644 --- a/tensorflow/core/kernels/while_op_test.cc +++ b/tensorflow/core/kernels/while_op_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" #include "xla/stream_executor/event.h" -#include "xla/stream_executor/multi_platform_manager.h" +#include "xla/stream_executor/platform_manager.h" #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/graph/node_builder.h" @@ -166,8 +166,8 @@ TEST_F(WhileOpTest, WhileOpCPUBuildWithPluggableDevice) { std::move(platform_fns_), stream_executor::test_util::DestroyPlatformFns, std::move(device_fns_), std::move(se_), std::move(timer_fns_))); - TF_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform( - std::move(cplatform))); + TF_CHECK_OK( + stream_executor::PlatformManager::RegisterPlatform(std::move(cplatform))); DeviceFactory::Register( platform_type, new PluggableDeviceFactory(platform_type, platform_name), diff --git a/tensorflow/core/kernels/whole_file_read_ops.cc b/tensorflow/core/kernels/whole_file_read_ops.cc index 7ded247d961852..f57e1d7c602ade 100644 --- a/tensorflow/core/kernels/whole_file_read_ops.cc +++ b/tensorflow/core/kernels/whole_file_read_ops.cc @@ -41,7 +41,7 @@ static Status ReadEntireFile(Env* env, const string& filename, T* contents) { io::RandomAccessInputStream input_stream(file.get()); io::BufferedInputStream in(&input_stream, 1 << 20); TF_RETURN_IF_ERROR(in.ReadAll(contents)); - return OkStatus(); + return absl::OkStatus(); } class WholeFileReader : public ReaderBase { @@ -56,7 +56,7 @@ class WholeFileReader : public ReaderBase { TF_RETURN_IF_ERROR(ReadEntireFile(env_, *key, value)); *produced = true; *at_end = true; - return OkStatus(); + return absl::OkStatus(); } // Stores state in a ReaderBaseState proto, since WholeFileReader has @@ -65,7 +65,7 @@ class WholeFileReader : public ReaderBase { ReaderBaseState base_state; SaveBaseState(&base_state); SerializeToTString(base_state, state); - return OkStatus(); + return absl::OkStatus(); } Status RestoreStateLocked(const tstring& state) override { @@ -75,7 +75,7 @@ class WholeFileReader : public ReaderBase { absl::CEscape(state)); } TF_RETURN_IF_ERROR(RestoreBaseState(base_state)); - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/core/kernels/word2vec_kernels.cc b/tensorflow/core/kernels/word2vec_kernels.cc index 1e249136744b3c..bc81562edf3155 100644 --- a/tensorflow/core/kernels/word2vec_kernels.cc +++ b/tensorflow/core/kernels/word2vec_kernels.cc @@ -233,7 +233,7 @@ class SkipgramOp : public OpKernel { } precalc_examples_.resize(kPrecalc); sentence_.resize(kSentenceSize); - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/core/lib/core/status_test.cc b/tensorflow/core/lib/core/status_test.cc index 9927545f73ab69..af80f615baf65b 100644 --- a/tensorflow/core/lib/core/status_test.cc +++ b/tensorflow/core/lib/core/status_test.cc @@ -24,11 +24,11 @@ limitations under the License. namespace tensorflow { TEST(Status, OK) { - EXPECT_EQ(OkStatus().code(), error::OK); - EXPECT_EQ(OkStatus().message(), ""); - TF_EXPECT_OK(OkStatus()); - TF_ASSERT_OK(OkStatus()); - EXPECT_EQ(OkStatus(), Status()); + EXPECT_EQ(absl::OkStatus().code(), error::OK); + EXPECT_EQ(absl::OkStatus().message(), ""); + TF_EXPECT_OK(absl::OkStatus()); + TF_ASSERT_OK(absl::OkStatus()); + EXPECT_EQ(absl::OkStatus(), Status()); Status s; EXPECT_TRUE(s.ok()); } @@ -73,7 +73,7 @@ TEST(Status, MoveAssign) { TEST(Status, Update) { Status s; - s.Update(OkStatus()); + s.Update(absl::OkStatus()); ASSERT_TRUE(s.ok()); Status a(errors::InvalidArgument("Invalid")); s.Update(a); @@ -81,12 +81,12 @@ TEST(Status, Update) { Status b(errors::Internal("Internal")); s.Update(b); ASSERT_EQ(s.ToString(), a.ToString()); - s.Update(OkStatus()); + s.Update(absl::OkStatus()); ASSERT_EQ(s.ToString(), a.ToString()); ASSERT_FALSE(s.ok()); } -TEST(Status, EqualsOK) { ASSERT_EQ(OkStatus(), Status()); } +TEST(Status, EqualsOK) { ASSERT_EQ(absl::OkStatus(), Status()); } TEST(Status, EqualsSame) { Status a(errors::InvalidArgument("Invalid")); @@ -114,10 +114,10 @@ TEST(Status, EqualsDifferentMessage) { TEST(StatusGroup, OKStatusGroup) { StatusGroup c; - c.Update(OkStatus()); - c.Update(OkStatus()); - ASSERT_EQ(c.as_summary_status(), OkStatus()); - ASSERT_EQ(c.as_concatenated_status(), OkStatus()); + c.Update(absl::OkStatus()); + c.Update(absl::OkStatus()); + ASSERT_EQ(c.as_summary_status(), absl::OkStatus()); + ASSERT_EQ(c.as_concatenated_status(), absl::OkStatus()); } TEST(StatusGroup, AggregateWithSingleErrorStatus) { @@ -197,7 +197,7 @@ TEST(Status, ErasePayloadRemovesIt) { static void BM_TF_CHECK_OK(::testing::benchmark::State& state) { tensorflow::Status s = (state.max_iterations < 0) ? errors::InvalidArgument("Invalid") - : OkStatus(); + : absl::OkStatus(); for (auto i : state) { TF_CHECK_OK(s); } diff --git a/tensorflow/core/lib/core/threadpool_test.cc b/tensorflow/core/lib/core/threadpool_test.cc index e4af7e70812e9d..504e2c894d62af 100644 --- a/tensorflow/core/lib/core/threadpool_test.cc +++ b/tensorflow/core/lib/core/threadpool_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/lib/core/threadpool.h" #include +#include #include "absl/synchronization/barrier.h" #include "absl/synchronization/blocking_counter.h" @@ -76,7 +77,7 @@ void RunWithFixedBlockSize(int64_t block_size, int64_t total, total, ThreadPool::SchedulingParams( ThreadPool::SchedulingStrategy::kFixedBlockSize /* strategy */, - absl::nullopt /* cost_per_unit */, block_size /* block_size */), + std::nullopt /* cost_per_unit */, block_size /* block_size */), [=, &mu, &num_shards, &num_done_work, &work](int64_t start, int64_t end) { VLOG(1) << "Shard [" << start << "," << end << ")"; EXPECT_GE(start, 0); @@ -220,7 +221,7 @@ void RunFixedBlockSizeShardingWithWorkerId(int64_t block_size, int64_t total, total, ThreadPool::SchedulingParams( ThreadPool::SchedulingStrategy::kFixedBlockSize /* strategy */, - absl::nullopt /* cost_per_unit */, block_size /* block_size */), + std::nullopt /* cost_per_unit */, block_size /* block_size */), [=, &mu, &num_done_work, &work, &threads_running](int64_t start, int64_t end, int id) { VLOG(1) << "Shard [" << start << "," << end << ")"; @@ -301,7 +302,7 @@ TEST(ThreadPool, ParallelForWithAdaptiveSchedulingStrategy) { kWorkItems, ThreadPool::SchedulingParams( ThreadPool::SchedulingStrategy::kAdaptive /* strategy */, - kHugeCost /* cost_per_unit */, absl::nullopt /* block_size */), + kHugeCost /* cost_per_unit */, std::nullopt /* block_size */), [&outer_context, &work](int64_t begin, int64_t end) { Context inner_context(ContextKind::kThread); ASSERT_EQ(outer_context, inner_context); diff --git a/tensorflow/core/lib/db/sqlite.cc b/tensorflow/core/lib/db/sqlite.cc index 583aeee34f37f6..b208c3a7242c65 100644 --- a/tensorflow/core/lib/db/sqlite.cc +++ b/tensorflow/core/lib/db/sqlite.cc @@ -83,7 +83,7 @@ sqlite3_stmt* PrepareRawOrDie(sqlite3* db, const char* sql) { } Status SetPragma(Sqlite* db, const char* pragma, const StringPiece& value) { - if (value.empty()) return OkStatus(); + if (value.empty()) return absl::OkStatus(); for (auto p = value.begin(); p < value.end(); ++p) { if (!(('0' <= *p && *p <= '9') || ('A' <= *p && *p <= 'Z') || ('a' <= *p && *p <= 'z') || *p == '-')) { @@ -105,7 +105,7 @@ const StringPiece GetEnv(const char* var) { Status EnvPragma(Sqlite* db, const char* pragma, const char* var) { TF_RETURN_WITH_CONTEXT_IF_ERROR(SetPragma(db, pragma, GetEnv(var)), "getenv(", var, ")"); - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -130,7 +130,7 @@ Status Sqlite::Open(const string& path, int flags, Sqlite** db) { sqlite3_stmt* commit = PrepareRawOrDie(sqlite, "COMMIT"); sqlite3_stmt* rollback = PrepareRawOrDie(sqlite, "ROLLBACK"); *db = new Sqlite(sqlite, begin, commit, rollback); - Status s = OkStatus(); + Status s = absl::OkStatus(); // Up until 2016 the default SQLite page_size was 1024. This ensures // the new default regardless of linkage unless configured otherwise. s.Update(SetPragma(*db, "page_size", "4096")); @@ -172,7 +172,7 @@ Status Sqlite::Prepare(const StringPiece& sql, SqliteStatement* stmt) { sql.size(), sql.data()); } *stmt = SqliteStatement(this, ps); - return OkStatus(); + return absl::OkStatus(); } Status SqliteStatement::Step(bool* is_done) { @@ -188,10 +188,10 @@ Status SqliteStatement::Step(bool* is_done) { switch (rc) { case SQLITE_ROW: *is_done = false; - return OkStatus(); + return absl::OkStatus(); case SQLITE_DONE: *is_done = true; - return OkStatus(); + return absl::OkStatus(); default: *is_done = true; return PrintfStatus(rc, "Step() failed: [%d] %s: %s", rc, db_->errmsg(), @@ -211,7 +211,7 @@ Status SqliteStatement::StepOnce() { if (TF_PREDICT_FALSE(is_done)) { return errors::Internal("No rows returned: ", sql()); } - return OkStatus(); + return absl::OkStatus(); } const SqliteStatement& SqliteStatement::StepOnceOrDie() { @@ -277,7 +277,7 @@ Status SqliteTransaction::Commit() { sqlite3_reset(db_->commit_); sqlite3_reset(db_->begin_); Begin(); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/lib/gif/BUILD b/tensorflow/core/lib/gif/BUILD index 579b7e39024fb1..7816383907c309 100644 --- a/tensorflow/core/lib/gif/BUILD +++ b/tensorflow/core/lib/gif/BUILD @@ -44,7 +44,6 @@ cc_library( srcs = if_mobile([ "gif_io.cc", "//tensorflow/core/platform:gif_hdrs", - "@local_tsl//tsl/platform:gif_hdrs", ]), hdrs = [ "gif_io.h", diff --git a/tensorflow/core/lib/jpeg/BUILD b/tensorflow/core/lib/jpeg/BUILD index fec1f63cc6790b..26fbb6398bdfb8 100644 --- a/tensorflow/core/lib/jpeg/BUILD +++ b/tensorflow/core/lib/jpeg/BUILD @@ -26,7 +26,6 @@ cc_library( "jpeg_handle.cc", "jpeg_mem.cc", "//tensorflow/core/platform:jpeg_hdrs", - "@local_tsl//tsl/platform:jpeg_hdrs", ], hdrs = [ "jpeg_handle.h", @@ -52,7 +51,6 @@ cc_library( "jpeg_handle.cc", "jpeg_mem.cc", "//tensorflow/core/platform:jpeg_hdrs", - "@local_tsl//tsl/platform:jpeg_hdrs", ]), hdrs = [ "jpeg_handle.h", diff --git a/tensorflow/core/lib/wav/wav_io.cc b/tensorflow/core/lib/wav/wav_io.cc index c3588e8ae76130..670e7ba7fa4143 100644 --- a/tensorflow/core/lib/wav/wav_io.cc +++ b/tensorflow/core/lib/wav/wav_io.cc @@ -111,7 +111,7 @@ Status IncrementOffset(int old_offset, int64_t increment, size_t max_size, return errors::InvalidArgument("Offset too large, overflowed: ", sum); } *new_offset = sum; - return OkStatus(); + return absl::OkStatus(); } Status ExpectText(const std::string& data, const std::string& expected_text, @@ -126,7 +126,7 @@ Status ExpectText(const std::string& data, const std::string& expected_text, " but found ", found_text); } *offset = new_offset; - return OkStatus(); + return absl::OkStatus(); } Status ReadString(const std::string& data, int expected_length, @@ -136,7 +136,7 @@ Status ReadString(const std::string& data, int expected_length, IncrementOffset(*offset, expected_length, data.size(), &new_offset)); *value = std::string(data.begin() + *offset, data.begin() + new_offset); *offset = new_offset; - return OkStatus(); + return absl::OkStatus(); } template @@ -211,7 +211,7 @@ Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate, core::EncodeFixed16(&data[i * kBytesPerSample], static_cast(sample)); } - return OkStatus(); + return absl::OkStatus(); } template Status EncodeAudioAsS16LEWav(const float* audio, @@ -348,7 +348,7 @@ Status DecodeLin16WaveAsFloatVector(const std::string& wav_string, if (!was_data_found) { return errors::InvalidArgument("No data chunk found in WAV"); } - return OkStatus(); + return absl::OkStatus(); } } // namespace wav diff --git a/tensorflow/core/lib/wav/wav_io.h b/tensorflow/core/lib/wav/wav_io.h index f404b13db63030..9918b72da5c983 100644 --- a/tensorflow/core/lib/wav/wav_io.h +++ b/tensorflow/core/lib/wav/wav_io.h @@ -95,7 +95,7 @@ Status ReadValue(const std::string& data, T* value, int* offset) { } } *offset = new_offset; - return OkStatus(); + return absl::OkStatus(); } } // namespace wav diff --git a/tensorflow/core/ops/compat/ops_history_v2/GlobalShuffleDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/GlobalShuffleDataset.pbtxt new file mode 100644 index 00000000000000..131281a80ec590 --- /dev/null +++ b/tensorflow/core/ops/compat/ops_history_v2/GlobalShuffleDataset.pbtxt @@ -0,0 +1,70 @@ +op { + name: "GlobalShuffleDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "seed" + type: DT_INT64 + } + input_arg { + name: "seed2" + type: DT_INT64 + } + input_arg { + name: "seed_generator" + type: DT_RESOURCE + } + output_arg { + name: "handle" + type: DT_VARIANT + experimental_full_type { + type_id: TFT_DATASET + args { + type_id: TFT_FOR_EACH + args { + type_id: TFT_PRODUCT + } + args { + type_id: TFT_TENSOR + args { + type_id: TFT_VAR + s: "output_types" + } + } + args { + type_id: TFT_VAR + s: "output_types" + } + } + } + } + attr { + name: "reshuffle_each_iteration" + type: "bool" + default_value { + b: true + } + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + attr { + name: "metadata" + type: "string" + default_value { + s: "" + } + } + is_stateful: true +} diff --git a/tensorflow/core/ops/experimental_dataset_ops.cc b/tensorflow/core/ops/experimental_dataset_ops.cc index 396e1720aaf2fd..e3c8a17116efce 100644 --- a/tensorflow/core/ops/experimental_dataset_ops.cc +++ b/tensorflow/core/ops/experimental_dataset_ops.cc @@ -502,6 +502,27 @@ REGISTER_OP("ExperimentalGroupByWindowDataset") "output_types")) .SetShapeFn(shape_inference::ScalarShape); +REGISTER_OP("GlobalShuffleDataset") + .Input("input_dataset: variant") + .Input("seed: int64") + .Input("seed2: int64") + .Input("seed_generator: resource") + .Output("handle: variant") + .Attr("reshuffle_each_iteration: bool = true") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .Attr("metadata: string = ''") + .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, + "output_types")) + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // seed, seed2, and seed_generator should be scalars. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + return shape_inference::ScalarShape(c); + }); + REGISTER_OP("IgnoreErrorsDataset") .Input("input_dataset: variant") .Output("handle: variant") diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index a2fca10c85576b..39e6cb6b3bd12d 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -22095,6 +22095,76 @@ op { } is_stateful: true } +op { + name: "GlobalShuffleDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "seed" + type: DT_INT64 + } + input_arg { + name: "seed2" + type: DT_INT64 + } + input_arg { + name: "seed_generator" + type: DT_RESOURCE + } + output_arg { + name: "handle" + type: DT_VARIANT + experimental_full_type { + type_id: TFT_DATASET + args { + type_id: TFT_FOR_EACH + args { + type_id: TFT_PRODUCT + } + args { + type_id: TFT_TENSOR + args { + type_id: TFT_VAR + s: "output_types" + } + } + args { + type_id: TFT_VAR + s: "output_types" + } + } + } + } + attr { + name: "reshuffle_each_iteration" + type: "bool" + default_value { + b: true + } + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + attr { + name: "metadata" + type: "string" + default_value { + s: "" + } + } + is_stateful: true +} op { name: "Greater" input_arg { diff --git a/tensorflow/core/ops/sparse_csr_matrix_ops.cc b/tensorflow/core/ops/sparse_csr_matrix_ops.cc index b1b9cb5cbeb073..d872dacd7fba45 100644 --- a/tensorflow/core/ops/sparse_csr_matrix_ops.cc +++ b/tensorflow/core/ops/sparse_csr_matrix_ops.cc @@ -1,4 +1,4 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -32,8 +32,8 @@ Status GetVariantInput(InferenceContext* c, int index, TF_RETURN_IF_ERROR(c->WithRank(c->input(index), 0, &variant)); auto* shapes_and_types = c->input_handle_shapes_and_types(index); if (shapes_and_types == nullptr || shapes_and_types->size() != 1) { - return absl::InvalidArgumentError(absl::StrCat( - "Unable to access shape and type info from variant input ", index)); + return errors::InvalidArgument( + "Unable to access shape and type info from variant input ", index); } *shape_and_type = shapes_and_types->at(0); return OkStatus(); @@ -48,7 +48,7 @@ Status ValidateSquareMatrixShape(InferenceContext* c, TF_RETURN_IF_ERROR(c->WithRankAtLeast(matrix_shape, 2, &out)); TF_RETURN_IF_ERROR(c->WithRankAtMost(matrix_shape, 3, &out)); if (!c->RankKnown(matrix_shape)) { - return absl::InvalidArgumentError("Sparse matrix has an unknown rank."); + return errors::Internal("Sparse matrix has an unknown rank."); } TF_RETURN_IF_ERROR(c->Merge(c->Dim(matrix_shape, -2), @@ -71,9 +71,9 @@ REGISTER_OP("SparseTensorToCSRSparseMatrix") TF_RETURN_IF_ERROR(c->WithRank(dense_shape, rank, &dense_shape)); if (!c->RankKnown(dense_shape) || c->Rank(dense_shape) < 2 || c->Rank(dense_shape) > 3) { - return absl::InvalidArgumentError( - absl::StrCat("Invalid rank: ", c->Rank(dense_shape), - ". Expected a known rank of either 2 or 3.")); + return errors::InvalidArgument( + "Invalid rank: ", c->Rank(dense_shape), + ". Expected a known rank of either 2 or 3."); } DataType dtype; @@ -96,7 +96,7 @@ REGISTER_OP("CSRSparseMatrixToSparseTensor") ShapeHandle sparse_matrix = sparse_matrix_shape_and_type.shape; TF_RETURN_IF_ERROR(c->WithRankAtMost(sparse_matrix, 3, &sparse_matrix)); if (!c->RankKnown(sparse_matrix)) { - return absl::InvalidArgumentError("sparse_matrix has an unknown rank."); + return errors::InvalidArgument("sparse_matrix has an unknown rank."); } int rank = c->Rank(sparse_matrix); ShapeHandle indices = c->Matrix(c->UnknownDim(), rank); @@ -117,23 +117,23 @@ REGISTER_OP("DenseToCSRSparseMatrix") ShapeHandle dense_shape = c->input(0); if (!c->RankKnown(dense_shape) || c->Rank(dense_shape) < 2 || c->Rank(dense_shape) > 3) { - return absl::InvalidArgumentError( - absl::StrCat("Invalid rank of dense: ", c->Rank(dense_shape), - ". Expected a known rank of either 2 or 3.")); + return errors::InvalidArgument( + "Invalid rank of dense: ", c->Rank(dense_shape), + ". Expected a known rank of either 2 or 3."); } auto rank = c->Rank(dense_shape); ShapeHandle indices = c->input(1); if (!c->RankKnown(indices) || c->Rank(indices) != 2) { - return absl::InvalidArgumentError( - absl::StrCat("indices must be a matrix; but its rank is not 2: ", - c->Rank(indices))); + return errors::InvalidArgument( + "indices must be a matrix; but its rank is not 2: ", + c->Rank(indices)); } auto indices_col = c->Dim(indices, 1); if (!c->ValueKnown(indices_col) || c->Value(indices_col) != rank) { - return absl::InvalidArgumentError( - absl::StrCat("indices.shape[1] must match rank of dense; saw: ", - c->Value(indices_col), " vs. ", rank)); + return errors::InvalidArgument( + "indices.shape[1] must match rank of dense; saw: ", + c->Value(indices_col), " vs. ", rank); } ShapeHandle fake_values_vec = c->Vector(c->Dim(indices, 0)); ShapeHandle fake_shape_shape = c->Vector(rank); @@ -158,7 +158,7 @@ REGISTER_OP("CSRSparseMatrixToDense") ShapeHandle sparse_matrix = sparse_matrix_shape_and_type.shape; TF_RETURN_IF_ERROR(c->WithRankAtMost(sparse_matrix, 3, &sparse_matrix)); if (!c->RankKnown(sparse_matrix)) { - return absl::InvalidArgumentError("sparse_matrix has an unknown rank."); + return errors::InvalidArgument("sparse_matrix has an unknown rank."); } c->set_output(0, sparse_matrix); return OkStatus(); @@ -181,10 +181,10 @@ REGISTER_OP("CSRSparseMatrixComponents") c->WithRankAtMost(csr_sparse_matrix, 3, &csr_sparse_matrix)); ShapeHandle index; if (c->Rank(c->input(1)) != 0) { - return absl::InvalidArgumentError("index must be a scalar."); + return errors::InvalidArgument("index must be a scalar."); } if (!c->RankKnown(csr_sparse_matrix)) { - return absl::InvalidArgumentError( + return errors::InvalidArgument( "csr_sparse_matrix has an unknown rank."); } auto row_ptrs_dh = c->Dim(csr_sparse_matrix, -2); @@ -206,7 +206,7 @@ REGISTER_OP("SparseMatrixNNZ") TF_RETURN_IF_ERROR(c->WithRankAtLeast(sparse_matrix, 2, &sparse_matrix)); TF_RETURN_IF_ERROR(c->WithRankAtMost(sparse_matrix, 3, &sparse_matrix)); if (!c->RankKnown(sparse_matrix)) { - return absl::InvalidArgumentError("sparse_matrix has an unknown rank."); + return errors::InvalidArgument("sparse_matrix has an unknown rank."); } ShapeHandle out; if (c->Rank(sparse_matrix) == 3) { @@ -236,7 +236,7 @@ REGISTER_OP("SparseMatrixMatMul") TF_RETURN_IF_ERROR(c->WithRankAtLeast(a_shape, 2, &a_shape)); TF_RETURN_IF_ERROR(c->WithRankAtMost(a_shape, 3, &a_shape)); if (!c->RankKnown(a_shape)) { - return absl::InvalidArgumentError("a has an unknown rank."); + return errors::Internal("a has an unknown rank."); } ShapeHandle b_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &b_shape)); @@ -256,11 +256,11 @@ REGISTER_OP("SparseMatrixMatMul") TF_RETURN_IF_ERROR(c->GetAttr("adjoint_a", &adjoint_a)); TF_RETURN_IF_ERROR(c->GetAttr("adjoint_b", &adjoint_b)); if (adjoint_a && transpose_a) { - return absl::InvalidArgumentError( + return errors::InvalidArgument( "Only one of adjoint_a and transpose_a may be true."); } if (adjoint_b && transpose_b) { - return absl::InvalidArgumentError( + return errors::InvalidArgument( "Only one of adjoint_b and transpose_b may be true."); } transpose_a = transpose_a || adjoint_a; @@ -295,86 +295,6 @@ REGISTER_OP("SparseMatrixMatMul") return OkStatus(); }); -#ifdef INTEL_MKL - -REGISTER_OP("_MklNativeSparseMatrixMatMul") - .Input("a: variant") - .Input("b: T") - .Attr("T: type") - .Attr("transpose_a: bool = false") - .Attr("transpose_b: bool = false") - .Attr("adjoint_a: bool = false") - .Attr("adjoint_b: bool = false") - .Attr("transpose_output: bool = false") - .Attr("conjugate_output: bool = false") - .Output("output: T") - .SetShapeFn([](InferenceContext* c) { - VLOG(1) << "_MklNativeSparseMatrixMatMul shape function"; - ShapeAndType sparse_matrix_shape_and_type; - TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type)); - ShapeHandle a_shape = sparse_matrix_shape_and_type.shape; - TF_RETURN_IF_ERROR(c->WithRank(a_shape, 2, &a_shape)); - if (!c->RankKnown(a_shape)) { - return absl::InvalidArgumentError("a has an unknown rank."); - } - ShapeHandle b_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &b_shape)); - VLOG(1) << "_MklNativeSparseMatrixMatMul shape function still"; - - bool transpose_a = false; - bool transpose_b = false; - bool transpose_output = false; - - // TODO(ebrevdo): Add transpose support. - TF_RETURN_IF_ERROR(c->GetAttr("transpose_a", &transpose_a)); - TF_RETURN_IF_ERROR(c->GetAttr("transpose_b", &transpose_b)); - TF_RETURN_IF_ERROR(c->GetAttr("transpose_output", &transpose_output)); - - bool adjoint_a = false; - bool adjoint_b = false; - TF_RETURN_IF_ERROR(c->GetAttr("adjoint_a", &adjoint_a)); - TF_RETURN_IF_ERROR(c->GetAttr("adjoint_b", &adjoint_b)); - if (adjoint_a && transpose_a) { - return absl::InvalidArgumentError( - "Only one of adjoint_a and transpose_a may be true."); - } - if (adjoint_b && transpose_b) { - return absl::InvalidArgumentError( - "Only one of adjoint_b and transpose_b may be true."); - } - transpose_a = transpose_a || adjoint_a; - transpose_b = transpose_b || adjoint_b; - - auto output_rows = c->Dim(a_shape, transpose_a ? -1 : -2); - auto output_cols = c->Dim(b_shape, transpose_b ? -2 : -1); - if (transpose_output) { - std::tie(output_rows, output_cols) = - std::make_tuple(output_cols, output_rows); - } - - // Batch dims match between inputs. - ShapeHandle a_batch_dims; - ShapeHandle b_batch_dims; - ShapeHandle batch_dims; - TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_dims)); - TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_dims)); - TF_RETURN_IF_ERROR(c->Merge(a_batch_dims, b_batch_dims, &batch_dims)); - - // Assert inner dims match. - shape_inference::DimensionHandle unused; - TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, transpose_a ? -2 : -1), - c->Dim(b_shape, transpose_b ? -1 : -2), - &unused)); - - ShapeHandle out; - TF_RETURN_IF_ERROR(c->Concatenate( - batch_dims, c->Matrix(output_rows, output_cols), &out)); - - c->set_output(0, out); - return OkStatus(); - }); -#endif - REGISTER_OP("SparseMatrixMul") .Input("a: variant") .Input("b: T") @@ -386,24 +306,23 @@ REGISTER_OP("SparseMatrixMul") ShapeHandle a_shape = sparse_matrix_shape_and_type.shape; TF_RETURN_IF_ERROR(c->WithRankAtMost(a_shape, 3, &a_shape)); if (!c->RankKnown(a_shape)) { - return absl::InvalidArgumentError("a has an unknown rank."); + return errors::Internal("a has an unknown rank."); } ShapeHandle b_shape; TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 3, &b_shape)); if (!c->RankKnown(b_shape)) { - return absl::InvalidArgumentError("b has an unknown rank."); + return errors::Internal("b has an unknown rank."); } ShapeHandle out; if (c->Rank(b_shape) == 0) { out = a_shape; } else if (c->Rank(b_shape) == 3) { if (c->Rank(a_shape) != 3) { - return absl::UnimplementedError( - "rank of b is 3 but rank of a is not."); + return errors::Unimplemented("rank of b is 3 but rank of a is not."); } if (!(c->Value(c->Dim(b_shape, 1)) == 1 && c->Value(c->Dim(b_shape, 2)) == 1)) { - return absl::UnimplementedError( + return errors::Unimplemented( "b must be a scalar or shaped [batch_size, 1, 1]"); } DimensionHandle batch_size = c->Dim(a_shape, 0); @@ -413,7 +332,7 @@ REGISTER_OP("SparseMatrixMul") TF_RETURN_IF_ERROR(c->ReplaceDim(a_shape, 0, batch_size, &a_shape)); out = a_shape; } else { - return absl::UnimplementedError( + return errors::Unimplemented( "b must be a scalar or shaped [batch_size, 1, 1]"); } c->set_output_handle_shapes_and_types( @@ -441,7 +360,7 @@ REGISTER_OP("SparseMatrixAdd") TF_RETURN_IF_ERROR(c->WithRankAtLeast(a_shape, 2, &a_shape)); TF_RETURN_IF_ERROR(c->WithRankAtMost(a_shape, 3, &a_shape)); if (!c->RankKnown(a_shape)) { - return absl::InvalidArgumentError("a has an unknown rank."); + return errors::InvalidArgument("a has an unknown rank."); } TF_RETURN_IF_ERROR(GetVariantInput(c, 1, &sparse_matrix_shape_and_type)); @@ -449,7 +368,7 @@ REGISTER_OP("SparseMatrixAdd") TF_RETURN_IF_ERROR(c->WithRankAtLeast(b_shape, 2, &b_shape)); TF_RETURN_IF_ERROR(c->WithRankAtMost(b_shape, 3, &b_shape)); if (!c->RankKnown(b_shape)) { - return absl::InvalidArgumentError("b has an unknown rank."); + return errors::InvalidArgument("b has an unknown rank."); } ShapeHandle out; TF_RETURN_IF_ERROR(c->Merge(a_shape, b_shape, &out)); @@ -475,7 +394,7 @@ REGISTER_OP("SparseMatrixSparseMatMul") TF_RETURN_IF_ERROR(c->WithRankAtLeast(a_shape, 2, &a_shape)); TF_RETURN_IF_ERROR(c->WithRankAtMost(a_shape, 3, &a_shape)); if (!c->RankKnown(a_shape)) { - return absl::InvalidArgumentError("a has an unknown rank."); + return errors::Internal("a has an unknown rank."); } TF_RETURN_IF_ERROR(GetVariantInput(c, 1, &sparse_matrix_shape_and_type)); @@ -483,7 +402,7 @@ REGISTER_OP("SparseMatrixSparseMatMul") TF_RETURN_IF_ERROR(c->WithRankAtLeast(b_shape, 2, &b_shape)); TF_RETURN_IF_ERROR(c->WithRankAtMost(b_shape, 3, &b_shape)); if (!c->RankKnown(b_shape)) { - return absl::InvalidArgumentError("b has an unknown rank."); + return errors::Internal("b has an unknown rank."); } bool transpose_a = false; @@ -495,10 +414,10 @@ REGISTER_OP("SparseMatrixSparseMatMul") TF_RETURN_IF_ERROR(c->GetAttr("adjoint_a", &adjoint_a)); TF_RETURN_IF_ERROR(c->GetAttr("adjoint_b", &adjoint_b)); if (adjoint_a && transpose_a) { - return absl::InvalidArgumentError( + return errors::InvalidArgument( "Only one of adjoint_a and transpose_a may be true."); } else if (adjoint_b && transpose_b) { - return absl::InvalidArgumentError( + return errors::InvalidArgument( "Only one of adjoint_b and transpose_b may be true."); } transpose_a = transpose_a || adjoint_a; @@ -545,9 +464,9 @@ REGISTER_OP("SparseMatrixZeros") c->WithRank(dense_shape, c->Value(rank), &dense_shape)); if (!c->RankKnown(dense_shape) || c->Rank(dense_shape) < 2 || c->Rank(dense_shape) > 3) { - return absl::InvalidArgumentError( - absl::StrCat("Invalid rank: ", c->Rank(dense_shape), - ". Expected a known rank of either 2 or 3.")); + return errors::InvalidArgument( + "Invalid rank: ", c->Rank(dense_shape), + ". Expected a known rank of either 2 or 3."); } DataType dtype; TF_RETURN_IF_ERROR(c->GetAttr("type", &dtype)); @@ -569,7 +488,7 @@ REGISTER_OP("SparseMatrixTranspose") TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 2, &input)); TF_RETURN_IF_ERROR(c->WithRankAtMost(input, 3, &input)); if (!c->RankKnown(input)) { - return absl::InvalidArgumentError("input has an unknown rank."); + return errors::InvalidArgument("input has an unknown rank."); } ShapeHandle output; if (c->Rank(input) == 2) { @@ -596,7 +515,7 @@ REGISTER_OP("SparseMatrixSoftmax") TF_RETURN_IF_ERROR(c->WithRankAtLeast(logits, 2, &logits)); TF_RETURN_IF_ERROR(c->WithRankAtMost(logits, 3, &logits)); if (!c->RankKnown(logits)) { - return absl::InvalidArgumentError("logits has an unknown rank."); + return errors::InvalidArgument("logits has an unknown rank."); } c->set_output_handle_shapes_and_types( 0, {ShapeAndType{logits, sparse_matrix_shape_and_type.dtype}}); @@ -616,14 +535,14 @@ REGISTER_OP("SparseMatrixSoftmaxGrad") TF_RETURN_IF_ERROR(c->WithRankAtLeast(softmax, 2, &softmax)); TF_RETURN_IF_ERROR(c->WithRankAtMost(softmax, 3, &softmax)); if (!c->RankKnown(softmax)) { - return absl::InvalidArgumentError("softmax has an unknown rank."); + return errors::InvalidArgument("softmax has an unknown rank."); } TF_RETURN_IF_ERROR(GetVariantInput(c, 1, &sparse_matrix_shape_and_type)); ShapeHandle grad_softmax = sparse_matrix_shape_and_type.shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast(grad_softmax, 2, &grad_softmax)); TF_RETURN_IF_ERROR(c->WithRankAtMost(grad_softmax, 3, &grad_softmax)); if (!c->RankKnown(grad_softmax)) { - return absl::InvalidArgumentError("grad_softmax has an unknown rank."); + return errors::InvalidArgument("grad_softmax has an unknown rank."); } TF_RETURN_IF_ERROR(c->Merge(softmax, grad_softmax, &softmax)); c->set_output_handle_shapes_and_types( @@ -668,7 +587,7 @@ REGISTER_OP("SparseMatrixSparseCholesky") TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &perm_shape)); TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 2, &perm_shape)); if (!c->RankKnown(perm_shape)) { - return absl::InvalidArgumentError("permutation has an unknown rank."); + return errors::Internal("permutation has an unknown rank."); } // Each batch component of permutation must have the same number of diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD index 0019b2c868b613..9af3f143790b36 100644 --- a/tensorflow/core/platform/BUILD +++ b/tensorflow/core/platform/BUILD @@ -384,7 +384,7 @@ cc_library( name = "gif", hdrs = ["gif.h"], deps = [ - "@local_tsl//tsl/platform:gif", + "@gif", ], ) @@ -412,7 +412,7 @@ cc_library( name = "jpeg", hdrs = ["jpeg.h"], deps = [ - "@local_tsl//tsl/platform:jpeg", + "@libjpeg_turbo//:jpeg", ], ) @@ -580,7 +580,7 @@ cc_library( hdrs = ["png.h"], deps = [ ":platform", - "@local_tsl//tsl/platform:png", + "@png", ], ) @@ -1088,7 +1088,7 @@ tf_cuda_library( "@local_tsl//tsl/platform", "@local_xla//xla/stream_executor", "@local_xla//xla/stream_executor:dnn", - "@local_xla//xla/stream_executor:multi_platform_manager", + "@local_xla//xla/stream_executor:platform_manager", "@local_xla//xla/stream_executor/cuda:cuda_activation_header", "@local_xla//xla/stream_executor/cuda:cuda_platform_id", "@local_xla//xla/stream_executor/host:host_platform_id", @@ -1111,7 +1111,7 @@ cc_library( deps = [ "@local_xla//xla/stream_executor", "@local_xla//xla/stream_executor:dnn", - "@local_xla//xla/stream_executor:multi_platform_manager", + "@local_xla//xla/stream_executor:platform_manager", "@local_xla//xla/stream_executor/cuda:cuda_platform_id", "@local_xla//xla/stream_executor/host:host_platform", "@local_xla//xla/stream_executor/host:host_platform_id", diff --git a/tensorflow/core/platform/cpu_feature_guard.cc b/tensorflow/core/platform/cpu_feature_guard.cc index f70c874c1c3c77..748dc21090a557 100644 --- a/tensorflow/core/platform/cpu_feature_guard.cc +++ b/tensorflow/core/platform/cpu_feature_guard.cc @@ -111,18 +111,18 @@ class CPUFeatureGuard { #ifdef __AVXNECONVERT__ CheckFeatureOrDie(CPUFeature::AVX_NE_CONVERT, "AVX_NE_CONVERT"); #endif // __AVXNECONVERT__ -#ifdef __AMXTILE__ +#ifdef __AMX_TILE__ CheckFeatureOrDie(CPUFeature::AMX_TILE, "AMX_TILE"); -#endif // __AMXTILE__ -#ifdef __AMXINT8__ +#endif // __AMX_TILE__ +#ifdef __AMX_INT8__ CheckFeatureOrDie(CPUFeature::AMX_INT8, "AMX_INT8"); -#endif // __AMXINT8__ -#ifdef __AMXBF16__ +#endif // __AMX_INT8__ +#ifdef __AMX_BF16__ CheckFeatureOrDie(CPUFeature::AMX_BF16, "AMX_BF16"); -#endif // __AMXBF16__ -#ifdef __AMXFP16__ +#endif // __AMX_BF16__ +#ifdef __AMX_FP16__ CheckFeatureOrDie(CPUFeature::AMX_FP16, "AMX_FP16"); -#endif // __AMXFP16__ +#endif // __AMX_FP16__ #ifdef __FMA__ CheckFeatureOrDie(CPUFeature::FMA, "FMA"); #endif // __FMA__ @@ -187,22 +187,22 @@ void InfoAboutUnusedCPUFeatures() { CheckIfFeatureUnused(CPUFeature::AVX_NE_CONVERT, "AVX_NE_CONVERT", missing_instructions); #endif // __AVXNECONVERT__ -#ifndef __AMXTILE__ +#ifndef __AMX_TILE__ CheckIfFeatureUnused(CPUFeature::AMX_TILE, "AMX_TILE", missing_instructions); -#endif // __AMXTILE__ -#ifndef __AMXINT8__ +#endif // __AMX_TILE__ +#ifndef __AMX_INT8__ CheckIfFeatureUnused(CPUFeature::AMX_INT8, "AMX_INT8", missing_instructions); -#endif // __AMXINT8__ -#ifndef __AMXBF16__ +#endif // __AMX_INT8__ +#ifndef __AMX_BF16__ CheckIfFeatureUnused(CPUFeature::AMX_BF16, "AMX_BF16", missing_instructions); -#endif // __AMXBF16__ -#ifndef __AMXFP16__ +#endif // __AMX_BF16__ +#ifndef __AMX_FP16__ CheckIfFeatureUnused(CPUFeature::AMX_FP16, "AMX_FP16", missing_instructions); -#endif // __AMXFP16__ +#endif // __AMX_FP16__ #ifndef __FMA__ CheckIfFeatureUnused(CPUFeature::FMA, "FMA", missing_instructions); #endif // __FMA__ diff --git a/tensorflow/core/platform/default/build_config/BUILD b/tensorflow/core/platform/default/build_config/BUILD index 122ff0e928e3df..7745a940ec7727 100644 --- a/tensorflow/core/platform/default/build_config/BUILD +++ b/tensorflow/core/platform/default/build_config/BUILD @@ -1,8 +1,8 @@ # Description: # Platform-specific build configurations. -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_copts") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package(default_visibility = ["//tensorflow:internal"]) @@ -19,8 +19,20 @@ cc_library( "@farmhash_archive//:farmhash", "@fft2d", "@highwayhash//:sip_hash", - "@local_tsl//tsl/platform/default/build_config:gif", - "@local_tsl//tsl/platform/default/build_config:jpeg", "@zlib", ], ) + +cc_library( + name = "tensorflow_platform_specific", + copts = tf_copts(), + linkstatic = 1, + deps = [], +) + +cc_library( + name = "test_main", + testonly = 1, + linkstatic = 1, + deps = [], +) diff --git a/tensorflow/core/platform/gif.h b/tensorflow/core/platform/gif.h index 0b9dc9289c6ba2..79af3822d29831 100644 --- a/tensorflow/core/platform/gif.h +++ b/tensorflow/core/platform/gif.h @@ -16,6 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PLATFORM_GIF_H_ #define TENSORFLOW_CORE_PLATFORM_GIF_H_ -#include "tsl/platform/gif.h" // IWYU pragma: export +#include "gif_lib.h" // from @gif #endif // TENSORFLOW_CORE_PLATFORM_GIF_H_ diff --git a/tensorflow/core/platform/jpeg.h b/tensorflow/core/platform/jpeg.h index 7f205d8cfd1503..68dadd18a03da6 100644 --- a/tensorflow/core/platform/jpeg.h +++ b/tensorflow/core/platform/jpeg.h @@ -16,6 +16,14 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PLATFORM_JPEG_H_ #define TENSORFLOW_CORE_PLATFORM_JPEG_H_ -#include "tsl/platform/jpeg.h" +#include +#include +#include +#include + +extern "C" { +#include "jerror.h" // from @libjpeg_turbo // IWYU pragma: export +#include "jpeglib.h" // from @libjpeg_turbo // IWYU pragma: export +} #endif // TENSORFLOW_CORE_PLATFORM_JPEG_H_ diff --git a/tensorflow/core/platform/platform_strings_test.cc b/tensorflow/core/platform/platform_strings_test.cc index 807cbc00b457fc..3943ccdc018585 100644 --- a/tensorflow/core/platform/platform_strings_test.cc +++ b/tensorflow/core/platform/platform_strings_test.cc @@ -20,7 +20,10 @@ limitations under the License. #include #include #include + +#ifndef _WIN32 #include +#endif // _WIN32 #include #include diff --git a/tensorflow/core/platform/png.h b/tensorflow/core/platform/png.h index dde8ad007c6719..fc1a342165fac3 100644 --- a/tensorflow/core/platform/png.h +++ b/tensorflow/core/platform/png.h @@ -17,6 +17,14 @@ limitations under the License. #define TENSORFLOW_CORE_PLATFORM_PNG_H_ #include "tensorflow/core/platform/platform.h" -#include "tsl/platform/png.h" + +#if defined(PLATFORM_GOOGLE) && !defined(IS_MOBILE_PLATFORM) +#include "png.h" // from @png // IWYU pragma: export +#elif defined(PLATFORM_POSIX) || defined(PLATFORM_WINDOWS) || \ + defined(PLATFORM_POSIX_ANDROID) || defined(IS_MOBILE_PLATFORM) +#include // IWYU pragma: export +#else +#error Define the appropriate PLATFORM_ macro for this platform +#endif #endif // TENSORFLOW_CORE_PLATFORM_PNG_H_ diff --git a/tensorflow/core/platform/stream_executor.h b/tensorflow/core/platform/stream_executor.h index 5b836e2f194d4b..f72e3566645e59 100644 --- a/tensorflow/core/platform/stream_executor.h +++ b/tensorflow/core/platform/stream_executor.h @@ -21,9 +21,9 @@ limitations under the License. #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/host/host_platform_id.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform/dso_loader.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/scratch_allocator.h" #include "xla/stream_executor/stream.h" diff --git a/tensorflow/core/platform/stream_executor_no_cuda.h b/tensorflow/core/platform/stream_executor_no_cuda.h index d4305ec2af1e62..53f5ccefed2616 100644 --- a/tensorflow/core/platform/stream_executor_no_cuda.h +++ b/tensorflow/core/platform/stream_executor_no_cuda.h @@ -21,9 +21,9 @@ limitations under the License. #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/host/host_platform_id.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform/dso_loader.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/scratch_allocator.h" #include "xla/stream_executor/stream.h" diff --git a/tensorflow/core/platform/vmodule_test.cc b/tensorflow/core/platform/vmodule_test.cc index a9ae307815038e..72a95dba1f8b3a 100644 --- a/tensorflow/core/platform/vmodule_test.cc +++ b/tensorflow/core/platform/vmodule_test.cc @@ -45,7 +45,8 @@ int RealMain(const char* argv0, bool do_vlog) { tsl::internal::LogMessage::VmoduleActivated("vmodule_test.cc", 7) && tsl::internal::LogMessage::VmoduleActivated("shoobadooba.h", 3); if (!ok) { - fprintf(stderr, "vmodule activated levels not as expected.\n"); + fprintf(stderr, + "vmodule activated levels not as expected.\n[ FAILED ]\n"); return EXIT_FAILURE; } #endif @@ -82,7 +83,7 @@ int RealMain(const char* argv0, bool do_vlog) { } // Read data from the child's stdout. - constexpr int kBufferSizeBytes = 4096; + constexpr int kBufferSizeBytes = 8192; char buffer[kBufferSizeBytes]; size_t result = fread(buffer, sizeof(buffer[0]), kBufferSizeBytes - 1, f); if (result == 0) { @@ -103,6 +104,10 @@ int RealMain(const char* argv0, bool do_vlog) { if (strstr(buffer, kExpected) == nullptr) { fprintf(stderr, "error: unexpected output from child: \"%.*s\"\n", kBufferSizeBytes, buffer); + fprintf(stderr, + "\n\nCould not find string \"%s\" in the above log buffer.\n[ " + "FAILED ]\n", + kExpected); return EXIT_FAILURE; } bool ok = strstr(buffer, "VLOG(7)\n") != nullptr && @@ -111,10 +116,14 @@ int RealMain(const char* argv0, bool do_vlog) { if (!ok) { fprintf(stderr, "error: VLOG output not as expected: \"%.*s\"\n", kBufferSizeBytes, buffer); + fprintf(stderr, + "\n\nCould not find expected VLOG statements in the above log " + "buffer.\n[ FAILED ]\n"); return EXIT_FAILURE; } // Success! + fprintf(stderr, "\n[ PASSED ]\n"); return EXIT_SUCCESS; } diff --git a/tensorflow/core/profiler/backends/gpu/BUILD b/tensorflow/core/profiler/backends/gpu/BUILD index 9a5038d9712de6..1f40f0657d948c 100644 --- a/tensorflow/core/profiler/backends/gpu/BUILD +++ b/tensorflow/core/profiler/backends/gpu/BUILD @@ -1,16 +1,5 @@ -load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library", "if_cuda") -load( - "//tensorflow:tensorflow.bzl", - "tf_copts", - "tf_cuda_library", -) load("//tensorflow:tensorflow.default.bzl", "tf_cuda_cc_test") -load( - "//tensorflow/core/platform:build_config_root.bzl", - "tf_cuda_tests_tags", -) -load("//tensorflow/core/profiler/builds:build_config.bzl", "tf_profiler_copts") -load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") +load("//tensorflow/core/platform:build_config_root.bzl", "tf_cuda_tests_tags") load( "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", @@ -18,7 +7,6 @@ load( package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow:internal"], licenses = ["notice"], ) @@ -32,7 +20,6 @@ tf_cuda_cc_test( "nomac", ], deps = [ - ":cupti_collector", "//tensorflow/cc:cc_ops", "//tensorflow/core:all_kernels", "//tensorflow/core:core_cpu", @@ -56,194 +43,11 @@ tf_cuda_cc_test( "//tensorflow/core/profiler/utils:xplane_visitor", "@com_google_absl//absl/strings", "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/backends/profiler/gpu:cuda_test", + "@local_xla//xla/backends/profiler/gpu:cupti_collector", "@local_xla//xla/backends/profiler/gpu:device_tracer", ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:cupti_headers", ]), ) - -tf_cuda_library( - name = "cupti_interface", - hdrs = if_cuda(["cupti_interface.h"]), - copts = tf_profiler_copts() + tf_copts(), - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:lib", - "//tensorflow/core:platform_base", - "@local_xla//xla/backends/profiler/gpu:cupti_interface", - ] + if_cuda(["@local_tsl//tsl/cuda:cupti"]), -) - -tf_cuda_library( - name = "mock_cupti", - testonly = 1, - hdrs = if_cuda(["mock_cupti.h"]), - copts = tf_profiler_copts() + tf_copts(), - cuda_deps = [ - ":cupti_interface", - ], - deps = [ - "//tensorflow/core:test", - "@local_xla//xla/backends/profiler/gpu:mock_cupti", - ], -) - -tf_cuda_library( - name = "cupti_error_manager", - hdrs = if_cuda(["cupti_error_manager.h"]), - copts = tf_profiler_copts() + tf_copts(), - cuda_deps = [ - ":cupti_interface", - ":cupti_wrapper", - "@local_xla//xla/backends/profiler/gpu:cupti_error_manager", - "//tensorflow/core/platform:mutex", - "//tensorflow/core/platform:thread_annotations", - ], - visibility = ["//visibility:public"], - deps = [ - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/debugging:leak_check", - "@com_google_absl//absl/synchronization", - ], -) - -cuda_library( - name = "cuda_test", - testonly = 1, - hdrs = ["cuda_test.h"], - copts = select({ - "@local_config_cuda//cuda:using_nvcc": [ - "-nvcc_options", - "ptxas-options=-v", - ], - "//conditions:default": [], - }), - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:test", - "@local_config_cuda//cuda:cuda_headers", - "@local_config_cuda//cuda:cudart", - "@local_xla//xla/backends/profiler/gpu:cuda_test", - ], -) - -# Rationale for linkstatic: The symbols in libcupti_static.a have hidden -# visibility. The wrapper will fail to find them if it's ever built as a -# shared library. This is the same issue as b/11094727. Always linking -# the wrapper statically works around the issue. An alternative would be -# to patch libcupti_static, but it's not worth the trouble considering -# that the wrapper is about the only direct user. -tf_cuda_library( - name = "cupti_wrapper", - hdrs = if_cuda(["cupti_wrapper.h"]), - copts = tf_profiler_copts() + tf_copts(), - linkstatic = 1, - visibility = ["//visibility:public"], - deps = [ - ":cupti_interface", - "@local_xla//xla/backends/profiler/gpu:cupti_wrapper", - ] + if_cuda(["@local_tsl//tsl/cuda:cupti"]), -) - -tf_cuda_library( - name = "cupti_tracer", - hdrs = if_cuda(["cupti_tracer.h"]), - copts = tf_profiler_copts() + tf_copts(), - visibility = ["//visibility:public"], - deps = [ - ":cupti_collector", - ":cupti_interface", - ":cupti_utils", - ":nvtx_utils", - "//tensorflow/core:lib", - "//tensorflow/core/profiler/backends/cpu:annotation_stack", - "//tensorflow/core/profiler/lib:scoped_annotation", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/container:node_hash_set", - "@com_google_absl//absl/types:optional", - "@local_tsl//tsl/profiler/utils:buffer_pool", - "@local_xla//xla/backends/profiler/gpu:cupti_tracer", - ], -) - -tf_cuda_library( - name = "rocm_tracer", - hdrs = if_rocm(["rocm_tracer.h"]), - copts = tf_profiler_copts() + tf_copts(), - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:lib", - "//tensorflow/core/profiler/backends/cpu:annotation_stack", - "@com_google_absl//absl/container:fixed_array", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/container:node_hash_set", - "@com_google_absl//absl/types:optional", - "@local_tsl//tsl/profiler/utils:time_utils", - "@local_xla//xla/backends/profiler/gpu:rocm_tracer", - "@local_xla//xla/stream_executor/rocm:roctracer_wrapper", - ], -) - -tf_cuda_library( - name = "nvtx_utils", - hdrs = if_cuda(["nvtx_utils.h"]), - copts = tf_profiler_copts() + tf_copts(), - cuda_deps = ["@com_google_absl//absl/strings:string_view"], - deps = [ - "//tensorflow/core:lib", - "@local_xla//xla/backends/profiler/gpu:nvtx_utils", - ], -) - -tf_cuda_library( - name = "cupti_collector", - hdrs = if_cuda(["cupti_collector.h"]), - copts = tf_profiler_copts() + tf_copts(), - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:lib", - "//tensorflow/core/profiler/protobuf:xplane_proto_cc", - "//tensorflow/core/profiler/utils:trace_utils", - "//tensorflow/core/profiler/utils:xplane_builder", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_utils", - "@com_google_absl//absl/container:fixed_array", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:node_hash_set", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/utils:parse_annotation", - "@local_xla//xla/backends/profiler/gpu:cupti_collector", - ] + if_cuda(["@local_tsl//tsl/cuda:cupti"]), -) - -cc_library( - name = "cupti_collector_header", - hdrs = ["cupti_collector.h"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:lib", - "//tensorflow/core/profiler/protobuf:xplane_proto_cc", - "@com_google_absl//absl/container:fixed_array", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:node_hash_set", - "@com_google_absl//absl/strings", - "@local_xla//xla/backends/profiler/gpu:cupti_collector_header", - ], -) - -tf_cuda_library( - name = "cupti_utils", - copts = tf_profiler_copts() + tf_copts(), - cuda_deps = [ - ":cupti_error_manager", - ":cupti_interface", - ":cupti_wrapper", - "@local_xla//xla/backends/profiler/gpu:cupti_utils", - ], - visibility = ["//visibility:public"], - alwayslink = 1, -) diff --git a/tensorflow/core/profiler/backends/gpu/cuda_test.h b/tensorflow/core/profiler/backends/gpu/cuda_test.h deleted file mode 100644 index 65d8c395d7d74c..00000000000000 --- a/tensorflow/core/profiler/backends/gpu/cuda_test.h +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_BACKENDS_GPU_CUDA_TEST_H_ -#define TENSORFLOW_CORE_PROFILER_BACKENDS_GPU_CUDA_TEST_H_ - -#include "xla/backends/profiler/gpu/cuda_test.h" - -namespace tensorflow { -namespace profiler { -namespace test { - -using xla::profiler::test::EmptyKernel; // NOLINT -using xla::profiler::test::MemCopyD2H; // NOLINT -using xla::profiler::test::MemCopyH2D; // NOLINT -using xla::profiler::test::MemCopyH2D_Async; // NOLINT -using xla::profiler::test::MemCopyP2PAvailable; // NOLINT -using xla::profiler::test::MemCopyP2PExplicit; // NOLINT -using xla::profiler::test::MemCopyP2PImplicit; // NOLINT -using xla::profiler::test::PrintfKernel; // NOLINT -using xla::profiler::test::Synchronize; // NOLINT - -} // namespace test -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_BACKENDS_GPU_CUDA_TEST_H_ diff --git a/tensorflow/core/profiler/backends/gpu/cupti_collector.h b/tensorflow/core/profiler/backends/gpu/cupti_collector.h deleted file mode 100644 index 7673cbbcbf5239..00000000000000 --- a/tensorflow/core/profiler/backends/gpu/cupti_collector.h +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_BACKENDS_GPU_CUPTI_COLLECTOR_H_ -#define TENSORFLOW_CORE_PROFILER_BACKENDS_GPU_CUPTI_COLLECTOR_H_ - -#include - -#include "absl/container/fixed_array.h" -#include "absl/container/flat_hash_map.h" -#include "absl/container/node_hash_set.h" -#include "absl/strings/string_view.h" -#include "xla/backends/profiler/gpu/cupti_collector.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -using xla::profiler::AnnotationMap; // NOLINT -using xla::profiler::CreateCuptiCollector; // NOLINT -using xla::profiler::CuptiTraceCollector; // NOLINT -using xla::profiler::CuptiTracerCollectorOptions; // NOLINT -using xla::profiler::CuptiTracerEvent; // NOLINT -using xla::profiler::CuptiTracerEventSource; // NOLINT -using xla::profiler::CuptiTracerEventType; // NOLINT -using xla::profiler::GetMemoryKindName; // NOLINT -using xla::profiler::GetTraceEventTypeName; // NOLINT -using xla::profiler::KernelDetails; // NOLINT -using xla::profiler::MemAllocDetails; // NOLINT -using xla::profiler::MemcpyDetails; // NOLINT -using xla::profiler::MemsetDetails; // NOLINT -using xla::profiler::ToXStat; // NOLINT - -using MemFreeDetails = MemAllocDetails; -using MemoryResidencyDetails = MemAllocDetails; - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_BACKENDS_GPU_CUPTI_COLLECTOR_H_ diff --git a/tensorflow/core/profiler/backends/gpu/cupti_error_manager.h b/tensorflow/core/profiler/backends/gpu/cupti_error_manager.h deleted file mode 100644 index 99c6ed2352b760..00000000000000 --- a/tensorflow/core/profiler/backends/gpu/cupti_error_manager.h +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_BACKENDS_GPU_CUPTI_ERROR_MANAGER_H_ -#define TENSORFLOW_CORE_PROFILER_BACKENDS_GPU_CUPTI_ERROR_MANAGER_H_ - -#include -#include - -#include -#include -#include -#include -#include - -#include "xla/backends/profiler/gpu/cupti_error_manager.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/profiler/backends/gpu/cupti_interface.h" - -namespace tensorflow { -namespace profiler { - -using xla::profiler::CuptiErrorManager; // NOLINT - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_BACKENDS_GPU_CUPTI_ERROR_MANAGER_H_ diff --git a/tensorflow/core/profiler/backends/gpu/cupti_interface.h b/tensorflow/core/profiler/backends/gpu/cupti_interface.h deleted file mode 100644 index c0f693f295cf75..00000000000000 --- a/tensorflow/core/profiler/backends/gpu/cupti_interface.h +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_BACKENDS_GPU_CUPTI_INTERFACE_H_ -#define TENSORFLOW_CORE_PROFILER_BACKENDS_GPU_CUPTI_INTERFACE_H_ - -#include -#include - -#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h" -#include "third_party/gpus/cuda/include/cuda.h" -#include "xla/backends/profiler/gpu/cupti_interface.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -namespace profiler { - -using xla::profiler::CuptiInterface; // NOLINT -using xla::profiler::GetCuptiInterface; // NOLINT - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_BACKENDS_GPU_CUPTI_INTERFACE_H_ diff --git a/tensorflow/core/profiler/backends/gpu/cupti_tracer.h b/tensorflow/core/profiler/backends/gpu/cupti_tracer.h deleted file mode 100644 index 6afa95a703cc4b..00000000000000 --- a/tensorflow/core/profiler/backends/gpu/cupti_tracer.h +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_BACKENDS_GPU_CUPTI_TRACER_H_ -#define TENSORFLOW_CORE_PROFILER_BACKENDS_GPU_CUPTI_TRACER_H_ - -#include "absl/types/optional.h" -#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h" -#include "third_party/gpus/cuda/include/nvtx3/nvToolsExt.h" -#include "xla/backends/profiler/gpu/cupti_tracer.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/backends/gpu/cupti_collector.h" -#include "tensorflow/core/profiler/backends/gpu/cupti_interface.h" -#include "tsl/profiler/utils/buffer_pool.h" - -namespace tensorflow { -namespace profiler { - -using xla::profiler::CuptiTracer; // NOLINT -using xla::profiler::CuptiTracerOptions; // NOLINT - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_BACKENDS_GPU_CUPTI_TRACER_H_ diff --git a/tensorflow/core/profiler/backends/gpu/device_tracer_test.cc b/tensorflow/core/profiler/backends/gpu/device_tracer_test.cc index effc959443ac77..fe017117ae664c 100644 --- a/tensorflow/core/profiler/backends/gpu/device_tracer_test.cc +++ b/tensorflow/core/profiler/backends/gpu/device_tracer_test.cc @@ -26,7 +26,7 @@ limitations under the License. #include "third_party/gpus/cuda/extras/CUPTI/include/cupti_activity.h" #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda_runtime.h" -#include "tensorflow/core/profiler/backends/gpu/cupti_collector.h" +#include "xla/backends/profiler/gpu/cupti_collector.h" #endif // GOOGLE_CUDA #include "tensorflow/core/common_runtime/direct_session.h" #include "tensorflow/core/framework/allocator.h" @@ -439,13 +439,13 @@ TEST_F(DeviceTracerTest, CudaRuntimeResource) { if (addr == reinterpret_cast(devptr) && num_bytes == size_in_bytes) { found_activity_memory_device = true; - EXPECT_EQ(kind, - GetMemoryKindName(CUPTI_ACTIVITY_MEMORY_KIND_DEVICE)); + EXPECT_EQ(kind, xla::profiler::GetMemoryKindName( + CUPTI_ACTIVITY_MEMORY_KIND_DEVICE)); } else if (addr == reinterpret_cast(hostptr) && num_bytes == size_in_bytes) { found_activity_memory_host = true; - EXPECT_EQ(kind, - GetMemoryKindName(CUPTI_ACTIVITY_MEMORY_KIND_PINNED)); + EXPECT_EQ(kind, xla::profiler::GetMemoryKindName( + CUPTI_ACTIVITY_MEMORY_KIND_PINNED)); } } else if (stat.Type() == StatType::kMemsetDetails) { CHECK(!found_activity_memset); @@ -459,8 +459,8 @@ TEST_F(DeviceTracerTest, CudaRuntimeResource) { (void)absl::SimpleAtoi(name_value[1], &num_bytes); EXPECT_EQ(num_bytes, 8); } else if (absl::StartsWith(detail, "kind:")) { - EXPECT_EQ(name_value[1], - GetMemoryKindName(CUPTI_ACTIVITY_MEMORY_KIND_DEVICE)); + EXPECT_EQ(name_value[1], xla::profiler::GetMemoryKindName( + CUPTI_ACTIVITY_MEMORY_KIND_DEVICE)); } } } else if (stat.Type() == StatType::kMemcpyDetails) { @@ -475,11 +475,11 @@ TEST_F(DeviceTracerTest, CudaRuntimeResource) { (void)absl::SimpleAtoi(name_value[1], &num_bytes); EXPECT_EQ(num_bytes, 8); } else if (absl::StartsWith(detail, "kind_src:")) { - EXPECT_EQ(name_value[1], - GetMemoryKindName(CUPTI_ACTIVITY_MEMORY_KIND_DEVICE)); + EXPECT_EQ(name_value[1], xla::profiler::GetMemoryKindName( + CUPTI_ACTIVITY_MEMORY_KIND_DEVICE)); } else if (absl::StartsWith(detail, "kind_dst:")) { - EXPECT_EQ(name_value[1], - GetMemoryKindName(CUPTI_ACTIVITY_MEMORY_KIND_PINNED)); + EXPECT_EQ(name_value[1], xla::profiler::GetMemoryKindName( + CUPTI_ACTIVITY_MEMORY_KIND_PINNED)); } } } diff --git a/tensorflow/core/profiler/backends/gpu/rocm_tracer.h b/tensorflow/core/profiler/backends/gpu/rocm_tracer.h deleted file mode 100644 index 30d121aaac1579..00000000000000 --- a/tensorflow/core/profiler/backends/gpu/rocm_tracer.h +++ /dev/null @@ -1,58 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_BACKENDS_GPU_ROCM_TRACER_H_ -#define TENSORFLOW_CORE_PROFILER_BACKENDS_GPU_ROCM_TRACER_H_ - -#include "absl/container/fixed_array.h" -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/container/node_hash_set.h" -#include "absl/types/optional.h" -#include "xla/backends/profiler/gpu/rocm_tracer.h" -#include "xla/stream_executor/rocm/roctracer_wrapper.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -namespace profiler { - -using xla::profiler::AnnotationMap; // NOLINT -using xla::profiler::DumpRocmTracerEvent; // NOLINT -using xla::profiler::GetRocmTracerEventDomainName; // NOLINT -using xla::profiler::GetRocmTracerEventSourceName; // NOLINT -using xla::profiler::GetRocmTracerEventTypeName; // NOLINT -using xla::profiler::KernelDetails; // NOLINT -using xla::profiler::MemAllocDetails; // NOLINT -using xla::profiler::MemcpyDetails; // NOLINT -using xla::profiler::MemsetDetails; // NOLINT -using xla::profiler::RocmActivityCallbackImpl; // NOLINT -using xla::profiler::RocmApiCallbackImpl; // NOLINT -using xla::profiler::RocmTraceCollector; // NOLINT -using xla::profiler::RocmTraceCollectorOptions; // NOLINT -using xla::profiler::RocmTracer; // NOLINT -using xla::profiler::RocmTracerEvent; // NOLINT -using xla::profiler::RocmTracerEventDomain; // NOLINT -using xla::profiler::RocmTracerEventSource; // NOLINT -using xla::profiler::RocmTracerEventType; // NOLINT -using xla::profiler::RocmTracerOptions; // NOLINT -using xla::profiler::RocmTracerSyncTypes; // NOLINT -using xla::profiler::SynchronizationDetails; // NOLINT - -} // namespace profiler -} // namespace tensorflow -#endif // TENSORFLOW_CORE_PROFILER_BACKENDS_GPU_ROCM_TRACER_H_ diff --git a/tensorflow/core/profiler/convert/BUILD b/tensorflow/core/profiler/convert/BUILD index b1b86e51aa8ce3..3fc84fa8818707 100644 --- a/tensorflow/core/profiler/convert/BUILD +++ b/tensorflow/core/profiler/convert/BUILD @@ -324,6 +324,7 @@ cc_library( "//tensorflow/core/profiler/utils:device_caps_utils", "//tensorflow/core/profiler/utils:event_span", "//tensorflow/core/profiler/utils:hardware_type_utils", + "//tensorflow/core/profiler/utils:hlo_proto_map", "//tensorflow/core/profiler/utils:kernel_stats_utils", "//tensorflow/core/profiler/utils:math_utils", "//tensorflow/core/profiler/utils:xplane_schema", diff --git a/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.cc b/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.cc index 3a372e263ab038..bee450104be0a1 100644 --- a/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.cc +++ b/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.cc @@ -266,6 +266,8 @@ std::string WrapDotInHtml(std::string dot) { } #graph-container {height:95vh;width:100%;padding:10px;display:block;} #graph-container svg { height: 100% !important; width: 100% !important;} + .node, .cluster {cursor:pointer;} + .cluster:hover, .node:hover {outline: solid 3px black;} diff --git a/tensorflow/core/profiler/convert/op_metrics_to_record.h b/tensorflow/core/profiler/convert/op_metrics_to_record.h index 514e248df86041..47b675554b691b 100644 --- a/tensorflow/core/profiler/convert/op_metrics_to_record.h +++ b/tensorflow/core/profiler/convert/op_metrics_to_record.h @@ -35,6 +35,13 @@ inline double GigaFlopsPerSecondPerCore(const OpMetrics& metrics) { return SafeDivide(metrics.flops(), PicoToNano(metrics.time_ps())); } +inline double GigaModelFlopsPerSecondPerCore(const OpMetrics& metrics) { + // flops and time_ps are accumulated across all occurrences on all cores. + // time_ps is used instead of self_time_ps because flops for an op includes + // the flops executed by children (nested) ops. + return SafeDivide(metrics.model_flops(), PicoToNano(metrics.time_ps())); +} + // Return ByteAccessed for memory_space and operation_type. inline double BytesAccessedPerCore( const OpMetrics& metrics, uint64_t memory_space, diff --git a/tensorflow/core/profiler/convert/op_profile_builder.cc b/tensorflow/core/profiler/convert/op_profile_builder.cc index ac999a4c24e244..c6d45f70df29c7 100644 --- a/tensorflow/core/profiler/convert/op_profile_builder.cc +++ b/tensorflow/core/profiler/convert/op_profile_builder.cc @@ -169,7 +169,7 @@ void PopulateOpMetricsNode( // and memory_bandwidth = raw_bytes_accessed / raw_time. See: // https://github.com/tensorflow/profiler/blob/master/frontend/app/common/utils/utils.ts metrics->set_raw_time(op_metrics.time_ps()); - metrics->set_raw_flops(op_metrics.flops()); + metrics->set_raw_flops(op_metrics.model_flops()); metrics->set_occurrences(op_metrics.occurrences()); metrics->set_avg_time_ps( SafeDivide(op_metrics.time_ps(), op_metrics.occurrences())); diff --git a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc index 66a677441f7803..120e22bab5d64d 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc @@ -659,7 +659,7 @@ bool InputAnalysis(double input_percent, double all_other_percent, "could be due to I/O or Python execution or both)."); return true; } else { - // Defintely not input-bound. + // Definitely not input-bound. *input_classification = "device"; *input_statement = absl::StrCat("Your program is NOT input-bound because only ", diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc index c5135e5eec4464..2954cc2206f19e 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc @@ -247,11 +247,14 @@ OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb(const XPlane& device_trace) { absl::string_view tf_op_full_name; bool is_eager = false; + int64_t program_id = 0; event.ForEachStat([&](const XStatVisitor& stat) { if (stat.Type() == StatType::kTfOp) { tf_op_full_name = stat.StrOrRefValue(); } else if (stat.Type() == StatType::kIsEager) { is_eager = stat.IntValue(); + } else if (stat.Type() == StatType::kProgramId) { + program_id = stat.IntOrUintValue(); } }); if (tf_op_full_name.empty()) return; @@ -262,8 +265,9 @@ OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb(const XPlane& device_trace) { costs = op_level_cost_estimator.Predict(event); } device_op_metrics_db_builder.EnterOp( - /*program_id=*/0, absl::StrCat(tf_op.name, "/", event.Name()), - tf_op.type, tf_op_full_name, is_eager, + /*program_id=*/program_id, + absl::StrCat(tf_op.name, "/", event.Name()), tf_op.type, + tf_op_full_name, is_eager, /*occurrences=*/1, event.DurationPs(), /*children_time_ps=*/0, costs.flops, costs.bytes_accessed); }); diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc index 9d742ac95fad09..5226b07353f74c 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc @@ -226,6 +226,7 @@ TEST(ConvertXPlaneToOpMetricsDb, TpuDeviceOpMetricsDb) { #if defined(PLATFORM_GOOGLE) EXPECT_THAT(op_metrics, EqualsProto(R"pb(metrics_db { + hlo_module_id: 1 self_time_ps: 10000 flops: 68 occurrences: 2 diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc index f8b1844dd4bb42..ab5906965f23fe 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/core/profiler/utils/device_caps_utils.h" #include "tensorflow/core/profiler/utils/event_span.h" #include "tensorflow/core/profiler/utils/hardware_type_utils.h" +#include "tensorflow/core/profiler/utils/hlo_proto_map.h" #include "tensorflow/core/profiler/utils/kernel_stats_utils.h" #include "tensorflow/core/profiler/utils/math_utils.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" @@ -56,8 +57,6 @@ std::string Hostname(const XSpace& space) { if (space.hostnames().empty()) return "localhost"; DCHECK_EQ(space.hostnames_size(), 1); const std::string& hostname = space.hostnames(0); - // This shouldn't be a taskname in host:port format. - DCHECK(!absl::StrContains(hostname, ':')); return hostname; } @@ -166,6 +165,15 @@ void PropagateXSpaceDiagnosticsToOpStats(const XSpace& space, } } +// This function should be idempotent to be called +void SetProgramIdToNameMap(const HloProtoMap& hlo_proto_map, + tensorflow::profiler::OpStats& op_stats) { + auto& program_id_to_name_map = *op_stats.mutable_program_id_to_name_map(); + for (const auto& [program_id, hlo_proto] : hlo_proto_map) { + program_id_to_name_map[program_id] = hlo_proto->hlo_module().name(); + } +} + OpStats ConvertXSpaceToOpStats(const XSpace& space, const OpStatsOptions& options) { std::vector device_planes = FindTensorCorePlanes(space); @@ -259,6 +267,13 @@ OpStats ConvertXSpaceToOpStats(const XSpace& space, (*op_stats.mutable_core_id_to_details())[kDefaultGpuLocalCoreId]; details.set_hostname(Hostname(space)); } + + // Set program_id_to_name map in OpStats from Xspace + // Will be non-op if the space does not have materialized device traces + HloProtoMap hlo_proto_map; + hlo_proto_map.AddHloProtosFromXSpace(space); + SetProgramIdToNameMap(hlo_proto_map, op_stats); + return op_stats; } diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats.h b/tensorflow/core/profiler/convert/xplane_to_op_stats.h index 0a1cde34b13afb..994efb032920a6 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_stats.h +++ b/tensorflow/core/profiler/convert/xplane_to_op_stats.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/profiler/convert/repository.h" #include "tensorflow/core/profiler/protobuf/op_stats.pb.h" +#include "tensorflow/core/profiler/utils/hlo_proto_map.h" #include "tsl/profiler/protobuf/xplane.pb.h" namespace tensorflow { @@ -36,6 +37,10 @@ struct OpStatsOptions { OpStats ConvertXSpaceToOpStats(const XSpace& space, const OpStatsOptions& options); +// Populates the program_id_to_name map in OpStats. +void SetProgramIdToNameMap(const HloProtoMap& hlo_proto_map, + tensorflow::profiler::OpStats& op_stats); + // Populates the given RunEnvironment with data from XSpace. void SetRunEnvironment(const XSpace& space, RunEnvironment* env); diff --git a/tensorflow/core/profiler/lib/BUILD b/tensorflow/core/profiler/lib/BUILD index 96e06e52dbca13..a7049b302e4640 100644 --- a/tensorflow/core/profiler/lib/BUILD +++ b/tensorflow/core/profiler/lib/BUILD @@ -166,8 +166,6 @@ cc_library( ":scoped_annotation", ":traceme", "//tensorflow/core:lib", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", ], ) @@ -209,6 +207,7 @@ cc_library( ], ) +# TODO(csigg): Remove this forwarding target. cc_library( name = "scoped_annotation", hdrs = ["scoped_annotation.h"], @@ -223,22 +222,6 @@ cc_library( ]), ) -cc_library( - name = "scoped_annotation_stack", - hdrs = ["scoped_annotation_stack.h"], - visibility = [ - "//visibility:private", # Only private by automation, not intent. Owner may accept CLs adding visibility. See go/scheuklappen#explicit-private. - ], - deps = [ - "//tensorflow/core:lib", - "//tensorflow/core/platform", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/lib:scoped_annotation_stack", - ] + if_not_android([ - "//tensorflow/core/profiler/backends/cpu:annotation_stack", - ]), -) - cc_library( name = "profiler_lock", hdrs = ["profiler_lock.h"], diff --git a/tensorflow/core/profiler/lib/annotated_traceme.h b/tensorflow/core/profiler/lib/annotated_traceme.h index 24ab188674f101..313fca86d8d08c 100644 --- a/tensorflow/core/profiler/lib/annotated_traceme.h +++ b/tensorflow/core/profiler/lib/annotated_traceme.h @@ -15,13 +15,12 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_LIB_ANNOTATED_TRACEME_H_ #define TENSORFLOW_CORE_PROFILER_LIB_ANNOTATED_TRACEME_H_ +#include +#include #include -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/lib/scoped_annotation.h" #include "tensorflow/core/profiler/lib/traceme.h" @@ -37,20 +36,21 @@ class AnnotatedTraceMe { DCHECK_GE(level, 1); bool annotation_enabled = ScopedAnnotation::IsEnabled(); bool traceme_enabled = TraceMe::Active(level); - if (TF_PREDICT_FALSE(annotation_enabled || traceme_enabled)) { - string name = std::forward(name_generator)(); - if (annotation_enabled) { - scoped_annotation_.emplace(absl::string_view(name)); - } - if (TF_PREDICT_TRUE(traceme_enabled)) { - trace_me_.emplace([&name] { return std::move(name); }, level); - } + if (TF_PREDICT_TRUE(!annotation_enabled && !traceme_enabled)) { + return; + } + std::string name = name_generator(); + if (annotation_enabled) { + scoped_annotation_.emplace(name); + } + if (TF_PREDICT_TRUE(traceme_enabled)) { + trace_me_.emplace([&name] { return std::move(name); }, level); } } private: - absl::optional trace_me_; - absl::optional scoped_annotation_; + std::optional trace_me_; + std::optional scoped_annotation_; }; } // namespace profiler diff --git a/tensorflow/core/profiler/lib/scoped_annotation_stack.h b/tensorflow/core/profiler/lib/scoped_annotation_stack.h deleted file mode 100644 index 7c7d6688c68593..00000000000000 --- a/tensorflow/core/profiler/lib/scoped_annotation_stack.h +++ /dev/null @@ -1,40 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_LIB_SCOPED_ANNOTATION_STACK_H_ -#define TENSORFLOW_CORE_PROFILER_LIB_SCOPED_ANNOTATION_STACK_H_ - -#include - -#include -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "tsl/profiler/lib/scoped_annotation_stack.h" - -#if !defined(IS_MOBILE_PLATFORM) -#include "tensorflow/core/profiler/backends/cpu/annotation_stack.h" -#endif - -namespace tensorflow { -namespace profiler { - -using tsl::profiler::ScopedAnnotationStack; // NOLINT - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_LIB_SCOPED_ANNOTATION_STACK_H_ diff --git a/tensorflow/core/profiler/protobuf/overview_page.proto b/tensorflow/core/profiler/protobuf/overview_page.proto index 501d1d09170f55..4b91b96da14c60 100644 --- a/tensorflow/core/profiler/protobuf/overview_page.proto +++ b/tensorflow/core/profiler/protobuf/overview_page.proto @@ -69,6 +69,14 @@ message OverviewPageAnalysis { // BEGIN-INTERNAL // Program Goodput metric in percentage. double program_goodput_percent = 18; + // Sparse core step time in ms average. + double sc_step_time_ms_average = 19; + // Sparse core infeed time in ms average. + double sc_infeed_time_ms_avg = 20; + // Sparse core outfeed time in ms average. + double sc_outfeed_time_ms_avg = 21; + // Sparse core idle time in ms average. + double sc_idle_time_ms_avg = 22; // END-INTERNAL } diff --git a/tensorflow/core/profiler/utils/BUILD b/tensorflow/core/profiler/utils/BUILD index 064c4dc76fa45c..b67fb16cd8f7c4 100644 --- a/tensorflow/core/profiler/utils/BUILD +++ b/tensorflow/core/profiler/utils/BUILD @@ -281,8 +281,8 @@ tf_cc_test( ":kernel_stats_utils", "//tensorflow/core:test", "//tensorflow/core:test_main", - "//tensorflow/core/profiler/backends/gpu:cupti_collector_header", "//tensorflow/core/profiler/protobuf:kernel_stats_proto_cc", + "@local_xla//xla/backends/profiler/gpu:cupti_collector", ], ) diff --git a/tensorflow/core/profiler/utils/kernel_stats_utils_test.cc b/tensorflow/core/profiler/utils/kernel_stats_utils_test.cc index 45aea85f943adf..99096213f56704 100644 --- a/tensorflow/core/profiler/utils/kernel_stats_utils_test.cc +++ b/tensorflow/core/profiler/utils/kernel_stats_utils_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/core/profiler/utils/kernel_stats_utils.h" +#include "xla/backends/profiler/gpu/cupti_collector.h" #include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/backends/gpu/cupti_collector.h" #include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h" namespace tensorflow { @@ -68,7 +68,7 @@ TEST(KernelStatsUtilsTest, TestGroupKernelReportsByOpName) { } TEST(KernelStatsUtilsTest, KernelDetailsXStatParser) { - KernelDetails kernel_info; + xla::profiler::KernelDetails kernel_info; kernel_info.registers_per_thread = 10; kernel_info.static_shared_memory_usage = 128; kernel_info.dynamic_shared_memory_usage = 256; diff --git a/tensorflow/core/profiler/utils/op_metrics_db_utils.cc b/tensorflow/core/profiler/utils/op_metrics_db_utils.cc index 68f4551df6223f..5ea6b463571a4e 100644 --- a/tensorflow/core/profiler/utils/op_metrics_db_utils.cc +++ b/tensorflow/core/profiler/utils/op_metrics_db_utils.cc @@ -111,6 +111,9 @@ void SetOpMetadataFromHloEventMetadata( hlo_event_metadata.ForEachStat([&](const XStatVisitor& stat) { if (stat.Type().has_value()) { switch (static_cast(*stat.Type())) { + case StatType::kProgramId: + op_metrics->set_hlo_module_id(stat.IntOrUintValue()); + break; case StatType::kHloCategory: op_metrics->set_category(std::string(stat.StrOrRefValue())); break; diff --git a/tensorflow/core/protobuf/BUILD b/tensorflow/core/protobuf/BUILD index 8c60bdab0656ee..76f6c6551f8d74 100644 --- a/tensorflow/core/protobuf/BUILD +++ b/tensorflow/core/protobuf/BUILD @@ -216,7 +216,6 @@ tf_proto_library( "@local_tsl//tsl/protobuf:status_proto", ], tags = ["alt_dep=//third_party/tensorflow/core:protos_all"], - visibility = ["//visibility:public"], exports = [ "@local_tsl//tsl/protobuf:bfc_memory_map_proto", "@local_tsl//tsl/protobuf:rpc_options_proto", diff --git a/tensorflow/core/protobuf/eager_service.proto b/tensorflow/core/protobuf/eager_service.proto index 501ec15d636988..0f3b89409d970f 100644 --- a/tensorflow/core/protobuf/eager_service.proto +++ b/tensorflow/core/protobuf/eager_service.proto @@ -121,6 +121,9 @@ message CreateContextRequest { // target devices after function instantiation to avoid redundant copies. bool lazy_copy_remote_function_inputs = 9; + // If true, clears resource managers created in the worker environment. + bool clear_existing_contexts = 10; + reserved 5; } diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index a22f1072887762..8e61961924d6c5 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -21,7 +21,7 @@ limitations under the License. // Also update tensorflow/tensorflow.bzl and // tensorflow/tools/pip_package/setup.py #define TF_MAJOR_VERSION 2 -#define TF_MINOR_VERSION 16 +#define TF_MINOR_VERSION 17 #define TF_PATCH_VERSION 0 // TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1", @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1757 // Updated: 2024/1/30 +#define TF_GRAPH_DEF_VERSION 1780 // Updated: 2024/2/22 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // diff --git a/tensorflow/core/runtime_fallback/kernel/BUILD b/tensorflow/core/runtime_fallback/kernel/BUILD index de61ddf6451b97..371808583c5345 100644 --- a/tensorflow/core/runtime_fallback/kernel/BUILD +++ b/tensorflow/core/runtime_fallback/kernel/BUILD @@ -69,7 +69,9 @@ cc_library( hdrs = ["attr_util.h"], visibility = ["//visibility:public"], deps = [ + "//tensorflow/core:portable_gif_internal", "//tensorflow/core/runtime_fallback/util:attr_util", + "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@tf_runtime//:core_runtime", "@tf_runtime//:hostcontext", diff --git a/tensorflow/core/runtime_fallback/kernel/attr_util.cc b/tensorflow/core/runtime_fallback/kernel/attr_util.cc index 47cd6d91fb0af6..3efe36f3766a44 100644 --- a/tensorflow/core/runtime_fallback/kernel/attr_util.cc +++ b/tensorflow/core/runtime_fallback/kernel/attr_util.cc @@ -20,12 +20,18 @@ limitations under the License. #include #include +#include "absl/strings/numbers.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/str_util.h" #include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/runtime_fallback/util/attr_util.h" +#include "tensorflow/core/util/padding.h" +#include "tfrt/core_runtime/op_attr_type.h" // from @tf_runtime +#include "tfrt/core_runtime/op_attrs.h" // from @tf_runtime +#include "tfrt/host_context/kernel_utils.h" // from @tf_runtime namespace tensorflow { diff --git a/tensorflow/core/runtime_fallback/kernel/attr_util.h b/tensorflow/core/runtime_fallback/kernel/attr_util.h index 6e56c3c17ebba8..387f227f1c8cb4 100644 --- a/tensorflow/core/runtime_fallback/kernel/attr_util.h +++ b/tensorflow/core/runtime_fallback/kernel/attr_util.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/runtime_fallback/util/attr_util.h" #include "tensorflow/core/util/padding.h" #include "tfrt/core_runtime/op_attrs.h" // from @tf_runtime diff --git a/tensorflow/core/runtime_fallback/kernel/attr_util_test.cc b/tensorflow/core/runtime_fallback/kernel/attr_util_test.cc index f95c150c84e85b..bdb6f9ee1e1cd2 100644 --- a/tensorflow/core/runtime_fallback/kernel/attr_util_test.cc +++ b/tensorflow/core/runtime_fallback/kernel/attr_util_test.cc @@ -16,13 +16,12 @@ limitations under the License. #include -#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" +#include "tsl/lib/core/status_test_util.h" #include "tfrt/core_runtime/op_attr_type.h" // from @tf_runtime #include "tfrt/core_runtime/op_attrs.h" // from @tf_runtime -#include "tfrt/host_context/kernel_utils.h" // from @tf_runtime #include "tfrt/support/forward_decls.h" // from @tf_runtime using llvm::ArrayRef; diff --git a/tensorflow/core/runtime_fallback/test/BUILD b/tensorflow/core/runtime_fallback/test/BUILD index e210a8296f3b85..e339e59bcd6fa9 100644 --- a/tensorflow/core/runtime_fallback/test/BUILD +++ b/tensorflow/core/runtime_fallback/test/BUILD @@ -190,9 +190,9 @@ cc_library( # ":coreruntime_driver", # "@com_google_googletest//:gtest", # "//tensorflow/core/platform:test_benchmark", +# "//tensorflow/core/platform/default/build_config:test_main", # "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_op_handler", # "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_tensor", -# "@local_tsl//tsl/platform/default/build_config:test_main", # "@tf_runtime//:core_runtime_alwayslink", # "@tf_runtime//:hostcontext", # "@tf_runtime//:tensor", @@ -215,9 +215,9 @@ cc_library( # deps = [ # "@com_google_googletest//:gtest", # "//tensorflow/core/framework:tensor_testutil", +# "//tensorflow/core/platform/default/build_config:test_main", # "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_compat_request_state", # "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_op_handler", -# "@local_tsl//tsl/platform/default/build_config:test_main", # "@tf_runtime//:core_runtime_alwayslink", # ], # ) diff --git a/tensorflow/core/tfrt/common/BUILD b/tensorflow/core/tfrt/common/BUILD index 469bbff25d6f04..48ad12ffda86f0 100644 --- a/tensorflow/core/tfrt/common/BUILD +++ b/tensorflow/core/tfrt/common/BUILD @@ -17,6 +17,7 @@ package_group( packages = [ # copybara:uncomment "//learning/brain/experimental/dtensor/...", # copybara:uncomment "//learning/brain/experimental/tfrt/...", + # copybara:uncomment "//learning/brain/google/monitoring/...", # copybara:uncomment "//learning/brain/google/xla/...", # copybara:uncomment "//learning/brain/tfrc/...", # copybara:uncomment "//learning/brain/tfrt/...", @@ -249,6 +250,16 @@ cc_library( alwayslink = True, ) +cc_library( + name = "metrics", + srcs = ["metrics.cc"], + hdrs = ["metrics.h"], + deps = [ + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + ], +) + tf_cuda_cc_test( name = "pjrt_gpu_client_registration_test", size = "small", diff --git a/tensorflow/core/tfrt/common/metrics.cc b/tensorflow/core/tfrt/common/metrics.cc new file mode 100644 index 00000000000000..ad6f3fb218eb96 --- /dev/null +++ b/tensorflow/core/tfrt/common/metrics.cc @@ -0,0 +1,40 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/tfrt/common/metrics.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "tensorflow/core/lib/monitoring/sampler.h" +#include "tsl/lib/monitoring/sampler.h" + +namespace tensorflow { +namespace tfrt_metrics { + +monitoring::SamplerCell* GetTfrtGraphExecutorLatencySampler( + const std::string& model_name, int64_t model_version, + const std::string& graph_name) { + static auto* cell = tensorflow::monitoring::Sampler<3>::New( + {"/tfrt/graph_executor/latency", + "Tracks the latency of GraphExecutor (in microseconds) of a graph.", + "model_name", "model_version", "graph_name"}, + monitoring::Buckets::Exponential(10, 1.5, 33)); + return cell->GetCell(model_name, absl::StrCat(model_version), graph_name); +} + +} // namespace tfrt_metrics +} // namespace tensorflow diff --git a/tensorflow/core/profiler/backends/gpu/mock_cupti.h b/tensorflow/core/tfrt/common/metrics.h similarity index 55% rename from tensorflow/core/profiler/backends/gpu/mock_cupti.h rename to tensorflow/core/tfrt/common/metrics.h index 2631def6819992..749dfdae4328b2 100644 --- a/tensorflow/core/profiler/backends/gpu/mock_cupti.h +++ b/tensorflow/core/tfrt/common/metrics.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,24 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_BACKENDS_GPU_MOCK_CUPTI_H_ -#define TENSORFLOW_CORE_PROFILER_BACKENDS_GPU_MOCK_CUPTI_H_ - -#include -#include +#ifndef TENSORFLOW_CORE_TFRT_COMMON_METRICS_H_ +#define TENSORFLOW_CORE_TFRT_COMMON_METRICS_H_ #include +#include -#include "xla/backends/profiler/gpu/mock_cupti.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/backends/gpu/cupti_interface.h" +#include "tensorflow/core/lib/monitoring/sampler.h" +#include "tsl/lib/monitoring/sampler.h" namespace tensorflow { -namespace profiler { +namespace tfrt_metrics { -using xla::profiler::MockCupti; // NOLINT +monitoring::SamplerCell* GetTfrtGraphExecutorLatencySampler( + const std::string& model_name, int64_t model_version, + const std::string& graph_name); -} // namespace profiler +} // namespace tfrt_metrics } // namespace tensorflow -#endif // TENSORFLOW_CORE_PROFILER_BACKENDS_GPU_MOCK_CUPTI_H_ +#endif // TENSORFLOW_CORE_TFRT_COMMON_METRICS_H_ diff --git a/tensorflow/core/tfrt/gpu/kernel/BUILD b/tensorflow/core/tfrt/gpu/kernel/BUILD index 2d24322cf93f3e..139da4d3b16103 100644 --- a/tensorflow/core/tfrt/gpu/kernel/BUILD +++ b/tensorflow/core/tfrt/gpu/kernel/BUILD @@ -47,7 +47,6 @@ cc_library( "//tensorflow/compiler/jit:xla_launch_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:framework", - "//tensorflow/core/common_runtime:serving_device_selector", "//tensorflow/core/platform:notification", "//tensorflow/core/platform:status", "//tensorflow/core/platform:statusor", @@ -64,6 +63,7 @@ cc_library( "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@local_tsl//tsl/framework:device_id", + "@local_tsl//tsl/framework:serving_device_selector", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:fingerprint", "@local_tsl//tsl/platform:statusor", @@ -90,7 +90,6 @@ tf_cuda_cc_test( "//tensorflow/compiler/jit:xla_gpu_jit", "//tensorflow/core:framework", "//tensorflow/core:test", - "//tensorflow/core/common_runtime:serving_device_selector_policies", "//tensorflow/core/common_runtime/gpu:gpu_serving_device_selector", "//tensorflow/core/framework:tensor_testutil", "//tensorflow/core/kernels:ops_testutil", @@ -114,10 +113,10 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":gpu_runner", - "//tensorflow/core/common_runtime:serving_device_selector_policies", "//tensorflow/core/common_runtime/gpu:gpu_serving_device_selector", "//tensorflow/core/platform:status", "//tensorflow/core/tfrt/runtime", + "@local_tsl//tsl/framework:serving_device_selector_policies", "@tf_runtime//:hostcontext", ], ) diff --git a/tensorflow/core/tfrt/gpu/kernel/gpu_runner.cc b/tensorflow/core/tfrt/gpu/kernel/gpu_runner.cc index 7862c43f74d8cf..95b8fd7141054e 100644 --- a/tensorflow/core/tfrt/gpu/kernel/gpu_runner.cc +++ b/tensorflow/core/tfrt/gpu/kernel/gpu_runner.cc @@ -36,7 +36,6 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_platform_info.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "xla/pjrt/pjrt_client.h" -#include "tensorflow/core/common_runtime/serving_device_selector.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/device.h" #include "tensorflow/core/framework/device_base.h" @@ -53,6 +52,7 @@ limitations under the License. #include "tensorflow/core/tfrt/utils/gpu_variables_table.h" #include "tsl/framework/device_id.h" #include "tsl/framework/device_id_manager.h" +#include "tsl/framework/serving_device_selector.h" #include "tsl/platform/errors.h" #include "tsl/platform/fingerprint.h" #include "tsl/platform/statusor.h" @@ -347,7 +347,7 @@ GpuRunner::Run(const GpuRunInputs& run_inputs) { TF_ASSIGN_OR_RETURN(uint64_t fingerprint, GenerateFingerprint(run_inputs.func_name, run_inputs.fallback_request_state)); - DeviceReservation device_reservation = + tsl::DeviceReservation device_reservation = serving_device_selector_->ReserveDevice(absl::StrCat(fingerprint)); const int device_idx = device_reservation.device_index(); diff --git a/tensorflow/core/tfrt/gpu/kernel/gpu_runner.h b/tensorflow/core/tfrt/gpu/kernel/gpu_runner.h index 97e93bb19ed70f..3671d83993a873 100644 --- a/tensorflow/core/tfrt/gpu/kernel/gpu_runner.h +++ b/tensorflow/core/tfrt/gpu/kernel/gpu_runner.h @@ -18,13 +18,13 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" -#include "tensorflow/core/common_runtime/serving_device_selector.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h" #include "tensorflow/core/tfrt/utils/fallback_tensor.h" #include "tensorflow/core/tfrt/utils/gpu_variables_table.h" +#include "tsl/framework/serving_device_selector.h" #include "tfrt/host_context/async_value_ref.h" // from @tf_runtime #include "tfrt/host_context/execution_context.h" // from @tf_runtime @@ -47,7 +47,7 @@ struct GpuRunInputs { class GpuRunner { public: - explicit GpuRunner(ServingDeviceSelector* serving_device_selector) + explicit GpuRunner(tsl::ServingDeviceSelector* serving_device_selector) : serving_device_selector_(serving_device_selector) {} // This compiles the given program and runs the given input tensors in @@ -56,7 +56,7 @@ class GpuRunner { Run(const GpuRunInputs& run_inputs); private: - ServingDeviceSelector* serving_device_selector_; + tsl::ServingDeviceSelector* serving_device_selector_; tfrt::gpu::GpuVariablesTable vars_table_; }; diff --git a/tensorflow/core/tfrt/gpu/kernel/gpu_runner_test.cc b/tensorflow/core/tfrt/gpu/kernel/gpu_runner_test.cc index cbde67fb55ecb3..e2b999cb23ec06 100644 --- a/tensorflow/core/tfrt/gpu/kernel/gpu_runner_test.cc +++ b/tensorflow/core/tfrt/gpu/kernel/gpu_runner_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/math_ops.h" #include "tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.h" -#include "tensorflow/core/common_runtime/serving_device_selector_policies.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" @@ -34,6 +33,7 @@ limitations under the License. #include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h" #include "tensorflow/core/tfrt/fallback/fallback_state.h" +#include "tsl/framework/serving_device_selector_policies.h" #include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime #include "tfrt/host_context/diagnostic.h" // from @tf_runtime #include "tfrt/host_context/function.h" // from @tf_runtime diff --git a/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.cc b/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.cc index 9d8abca5a0207b..d3541e594c8a4e 100644 --- a/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.cc +++ b/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.cc @@ -18,10 +18,10 @@ limitations under the License. #include #include "tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.h" -#include "tensorflow/core/common_runtime/serving_device_selector_policies.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/tfrt/gpu/kernel/gpu_runner.h" #include "tensorflow/core/tfrt/runtime/runtime.h" +#include "tsl/framework/serving_device_selector_policies.h" #include "tfrt/host_context/resource_context.h" // from @tf_runtime namespace tensorflow { @@ -29,7 +29,7 @@ namespace gpu { Status InitTfrtGpu(const GpuRunnerOptions& options, tensorflow::tfrt_stub::Runtime& runtime) { - auto policy = std::make_unique(); + auto policy = std::make_unique(); auto serving_device_selector = std::make_unique( options.num_gpu_streams, std::move(policy)); diff --git a/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.h b/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.h index 7475a6ddf50423..dcb4e2787aaaaa 100644 --- a/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.h +++ b/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.h @@ -14,16 +14,16 @@ limitations under the License. ==============================================================================*/ #ifndef TENSORFLOW_CORE_TFRT_GPU_KERNEL_TFRT_GPU_INIT_H_ #define TENSORFLOW_CORE_TFRT_GPU_KERNEL_TFRT_GPU_INIT_H_ -#include "tensorflow/core/common_runtime/serving_device_selector_policies.h" #include "tensorflow/core/tfrt/runtime/runtime.h" +#include "tsl/framework/serving_device_selector_policies.h" namespace tensorflow { namespace gpu { struct GpuRunnerOptions { int num_gpu_streams = 1; - ServingDeviceSelectorPolicy serving_selector_policy = - ServingDeviceSelectorPolicy::kRoundRobin; + tsl::ServingDeviceSelectorPolicy serving_selector_policy = + tsl::ServingDeviceSelectorPolicy::kRoundRobin; }; Status InitTfrtGpu(const GpuRunnerOptions& options, diff --git a/tensorflow/core/tfrt/graph_executor/BUILD b/tensorflow/core/tfrt/graph_executor/BUILD index 9d63aee1133459..931e00ff95aef4 100644 --- a/tensorflow/core/tfrt/graph_executor/BUILD +++ b/tensorflow/core/tfrt/graph_executor/BUILD @@ -77,6 +77,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/mlir/tensorflow:import_model", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "//tensorflow/compiler/mlir/tfrt:backend_compiler", "//tensorflow/compiler/mlir/tfrt:import_model", "//tensorflow/compiler/mlir/tfrt:tfrt_compile_options", "//tensorflow/compiler/mlir/tfrt:transforms/update_op_cost_in_tfrt_mlir", @@ -96,6 +97,7 @@ cc_library( "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_compat_request_state", "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_execute_compat", "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_utils", + "//tensorflow/core/tfrt/common:metrics", "//tensorflow/core/tfrt/fallback:cost_recorder", "//tensorflow/core/tfrt/fallback:fallback_state", "//tensorflow/core/tfrt/fallback:op_kernel_runner", @@ -116,6 +118,7 @@ cc_library( "//tensorflow/core/tfrt/utils:tfrt_graph_execution_state", "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -123,9 +126,11 @@ cc_library( "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:IR", + "@local_tsl//tsl/concurrency:async_value", "@local_tsl//tsl/concurrency:ref_count", "@local_tsl//tsl/platform:refcount", "@local_tsl//tsl/platform:status", @@ -149,25 +154,39 @@ tf_cc_test( srcs = ["graph_executor_test.cc"], tags = ["no_oss"], deps = [ + ":graph_execution_options", ":graph_executor", "//tensorflow/cc:array_ops", "//tensorflow/cc:cc_ops", "//tensorflow/cc:const_op", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:framework_lite", + "//tensorflow/core:framework_types_hdr", "//tensorflow/core:test", + "//tensorflow/core/framework:common_shape_fns", "//tensorflow/core/framework:graph_proto_cc", + "//tensorflow/core/framework:op", + "//tensorflow/core/framework:tensor", "//tensorflow/core/framework:types_proto_cc", "//tensorflow/core/grappler/utils:grappler_test", + "//tensorflow/core/platform:status", "//tensorflow/core/platform:statusor", "//tensorflow/core/protobuf:for_core_protos_cc", + "//tensorflow/core/tfrt/fallback:fallback_state", "//tensorflow/core/tfrt/mlrt/interpreter:context", "//tensorflow/core/tfrt/mlrt/interpreter:value", "//tensorflow/core/tfrt/mlrt/kernel", "//tensorflow/core/tfrt/saved_model:saved_model_testutil", + "@com_google_absl//absl/status", "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", + "@tf_runtime//:hostcontext", "@tf_runtime//:tensor", "@tf_runtime//cpp_tests:common", ] + if_google( diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor.cc b/tensorflow/core/tfrt/graph_executor/graph_executor.cc index 1dec5ab1f1e053..414f64dbda95b2 100644 --- a/tensorflow/core/tfrt/graph_executor/graph_executor.cc +++ b/tensorflow/core/tfrt/graph_executor/graph_executor.cc @@ -25,6 +25,7 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -35,6 +36,8 @@ limitations under the License. #include "absl/time/time.h" #include "absl/types/optional.h" #include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -46,6 +49,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.h" #include "tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.h" @@ -57,7 +61,9 @@ limitations under the License. #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/monitoring/sampler.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/platform/tstring.h" @@ -65,17 +71,22 @@ limitations under the License. #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/version.h" +#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h" #include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_utils.h" +#include "tensorflow/core/tfrt/common/metrics.h" #include "tensorflow/core/tfrt/fallback/cost_recorder.h" #include "tensorflow/core/tfrt/fallback/fallback_state.h" +#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h" #include "tensorflow/core/tfrt/graph_executor/executable_context.h" #include "tensorflow/core/tfrt/graph_executor/export_mlir.h" #include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h" #include "tensorflow/core/tfrt/graph_executor/sync_resource_state.h" #include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" #include "tensorflow/core/tfrt/mlrt/bytecode/executable.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/function.h" #include "tensorflow/core/tfrt/mlrt/interpreter/context.h" #include "tensorflow/core/tfrt/mlrt/interpreter/execute.h" +#include "tensorflow/core/tfrt/mlrt/interpreter/value.h" #include "tensorflow/core/tfrt/mlrt/kernel/context.h" #include "tensorflow/core/tfrt/runtime/runtime.h" #include "tensorflow/core/tfrt/runtime/step_id.h" @@ -83,11 +94,14 @@ limitations under the License. #include "tensorflow/core/tfrt/runtime/work_queue_interface.h" #include "tensorflow/core/tfrt/stubs/tfrt_native_lowering_stub.h" #include "tensorflow/core/tfrt/utils/fallback_tensor.h" +#include "tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h" #include "tensorflow/core/tfrt/utils/utils.h" +#include "tsl/concurrency/async_value_ref.h" #include "tsl/platform/errors.h" #include "tsl/platform/refcount.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" +#include "tfrt/bef/bef_buffer.h" // from @tf_runtime #include "tfrt/bef_converter/mlir_to_bef.h" // from @tf_runtime #include "tfrt/core_runtime/core_runtime.h" // from @tf_runtime #include "tfrt/host_context/async_dispatch.h" // from @tf_runtime @@ -599,7 +613,11 @@ tensorflow::Status GraphExecutor::Run( } // Possibly record costs, depending on the particular setting of - // `CostAnalysisOptions`. + // `CostAnalysisOptions`. As of this comment, for correctness of that feature, + // the time needs to be created after the client graph is created + // + // To reduce system calls, this value is also used for timing the duration of + // `::Run`. auto now = absl::Now() + simulated_duration_; bool do_recompilation; CostRecorder* cost_recorder = @@ -634,7 +652,10 @@ tensorflow::Status GraphExecutor::Run( (*outputs)[original_index] = std::move(*flat_output_iter); ++flat_output_iter; } - + absl::Time end = absl::Now() + simulated_duration_; + absl::Duration elapsed_duration = end - now; + loaded_client_graph.latency_sampler()->Add( + absl::ToDoubleMicroseconds(elapsed_duration)); return OkStatus(); } @@ -649,7 +670,7 @@ GraphExecutor::ImportAndCompileClientGraph( // Step 1 of loading: Import the client graph from proto to an MLIR module. auto import_start_time = absl::Now(); mlir::DialectRegistry registry; - RegisterMlirDialect(registry); + RegisterMlirDialect(registry, options_.compile_options.backend_compiler); // Disable multi-threading in lazy loading as the thread pool it uses is out // of our control and this affects serving performance. // @@ -749,12 +770,15 @@ GraphExecutor::ImportAndCompileClientGraph( LOG(INFO) << "TFRT finished compiling client graph (" << &client_graph << "). Took " << absl::ToInt64Milliseconds(compile_duration) << " ms. Client graph name: " << client_graph.name; - + auto* latency_sampler = + tensorflow::tfrt_metrics::GetTfrtGraphExecutorLatencySampler( + options_.model_metadata.name(), options_.model_metadata.version(), + client_graph.name); return std::make_unique( client_graph.name, std::move(symbol_uids), this, std::move(context), std::move(module_with_op_keys), std::move(module), std::move(executable_context), stream_callback_id, - !checkpoint_path.empty(), std::move(flib_def)); + !checkpoint_path.empty(), std::move(flib_def), latency_sampler); } StatusOr> @@ -1073,7 +1097,8 @@ GraphExecutor::LoadedClientGraph::LoadedClientGraph( mlir::OwningOpRef tfrt_mlir, std::shared_ptr executable_context, std::optional stream_callback_id, bool is_restore, - FunctionLibraryDefinition flib_def) + FunctionLibraryDefinition flib_def, + tensorflow::monitoring::SamplerCell* latency_sampler) : name_(std::move(name)), symbol_uids_(std::move(symbol_uids)), graph_executor_(graph_executor), @@ -1097,7 +1122,8 @@ GraphExecutor::LoadedClientGraph::LoadedClientGraph( *r = tsl::core::RefCountPtr( new IntraProcessRendezvous(device_mgr)); return OkStatus(); - }}) { + }}), + latency_sampler_(latency_sampler) { const auto& options = graph_executor_->options().cost_analysis_options; if (options.version != Options::CostAnalysisOptions::kDisabled) { // Initialize in a way that ensures recompilation on the first run. @@ -1151,9 +1177,13 @@ tensorflow::Status GraphExecutor::CompileGraph( .status(); } -void RegisterMlirDialect(mlir::DialectRegistry& registry) { +void RegisterMlirDialect(mlir::DialectRegistry& registry, + tensorflow::BackendCompiler* backend_compiler) { registry.insert(); mlir::RegisterAllTensorFlowDialects(registry); + if (backend_compiler) { + backend_compiler->GetDependentDialects(registry); + } } } // namespace tfrt_stub diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor.h b/tensorflow/core/tfrt/graph_executor/graph_executor.h index 66c291dcf530d5..6deee210444099 100644 --- a/tensorflow/core/tfrt/graph_executor/graph_executor.h +++ b/tensorflow/core/tfrt/graph_executor/graph_executor.h @@ -33,10 +33,12 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/compiler/mlir/tfrt/backend_compiler.h" #include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/monitoring/sampler.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/statusor.h" @@ -156,7 +158,8 @@ class GraphExecutor { mlir::OwningOpRef tfrt_mlir, std::shared_ptr executable_context, std::optional stream_callback_id, - bool is_restore, FunctionLibraryDefinition flib_def); + bool is_restore, FunctionLibraryDefinition flib_def, + tensorflow::monitoring::SamplerCell* latency_sampler); // Returns this instance's CostRecorder if it is time to update costs, // else returns nullptr. Only allows one non-null return value at a time @@ -192,6 +195,9 @@ class GraphExecutor { const { return pflr_; } + tensorflow::monitoring::SamplerCell* latency_sampler() { + return latency_sampler_; + } private: std::string name_; @@ -230,6 +236,7 @@ class GraphExecutor { bool is_restore_; FunctionLibraryDefinition flib_def_; ProcessFunctionLibraryRuntime pflr_; + tensorflow::monitoring::SamplerCell* latency_sampler_; }; // A subgraph constructed by specifying input/output tensors. @@ -376,7 +383,8 @@ class GraphExecutor { int num_recompilations_ TF_GUARDED_BY(num_recompilations_mu_) = 0; }; -void RegisterMlirDialect(mlir::DialectRegistry& registry); +void RegisterMlirDialect(mlir::DialectRegistry& registry, + tensorflow::BackendCompiler* backend_compiler); } // namespace tfrt_stub } // namespace tensorflow diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc b/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc index 7088231d3ff3e6..0a0f073fe589fa 100644 --- a/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc +++ b/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/tfrt/graph_executor/graph_executor.h" +#include #include -#include #include #include #include @@ -24,23 +24,36 @@ limitations under the License. #include "learning/brain/experimental/tfrt/native_lowering/kernels/sync_fallback_kernels.h" #include #include +#include "absl/status/status.h" #include "absl/time/time.h" +#include "absl/types/span.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_def_builder.h" -#include "tensorflow/core/grappler/utils/grappler_test.h" -#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" +#include "tensorflow/core/tfrt/fallback/fallback_state.h" +#include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h" #include "tensorflow/core/tfrt/mlrt/interpreter/context.h" #include "tensorflow/core/tfrt/mlrt/interpreter/value.h" #include "tensorflow/core/tfrt/mlrt/kernel/kernel.h" #include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h" #include "tsl/lib/core/status_test_util.h" -#include "tsl/platform/status.h" #include "tsl/platform/statusor.h" #include "tfrt/cpp_tests/test_util.h" // from @tf_runtime +#include "tfrt/host_context/resource_context.h" // from @tf_runtime #include "tfrt/tensor/dense_host_tensor.h" // from @tf_runtime namespace tensorflow { diff --git a/tensorflow/core/tfrt/ifrt/BUILD b/tensorflow/core/tfrt/ifrt/BUILD index df5dac9b0d91c7..d58e0e14f7ec03 100644 --- a/tensorflow/core/tfrt/ifrt/BUILD +++ b/tensorflow/core/tfrt/ifrt/BUILD @@ -29,6 +29,7 @@ cc_library( srcs = ["ifrt_serving_executable.cc"], hdrs = ["ifrt_serving_executable.h"], deps = [ + ":ifrt_loaded_variable_registry", ":ifrt_tensor_utils", ":sharding_utils", "//tensorflow/compiler/mlir/tfrt/transforms/ifrt:tf2hlo", @@ -50,6 +51,7 @@ cc_library( "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:tstring", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/ir:hlo", "@local_xla//xla/pjrt:pjrt_executable", @@ -76,12 +78,32 @@ cc_library( ], ) +cc_library( + name = "ifrt_loaded_variable_registry", + srcs = ["ifrt_loaded_variable_registry.cc"], + hdrs = ["ifrt_loaded_variable_registry.h"], + deps = [ + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@local_tsl//tsl/concurrency:ref_count", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + "@local_xla//xla/python/ifrt", + ], +) + cc_library( name = "ifrt_model_context", srcs = ["ifrt_model_context.cc"], hdrs = ["ifrt_model_context.h"], deps = [ ":ifrt_executable_registry", + ":ifrt_loaded_variable_registry", "//tensorflow/compiler/tf2xla:xla_helpers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", @@ -89,6 +111,9 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@eigen_archive//:eigen3", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", "@local_tsl//tsl/concurrency:ref_count", "@local_xla//xla/python/ifrt", ], @@ -181,30 +206,6 @@ tf_cc_test( ], ) -tf_cc_test( - name = "ifrt_model_context_test", - srcs = [ - "ifrt_model_context_test.cc", - ], - tags = ["no_oss"], - deps = [ - ":ifrt_model_context", - "//tensorflow/core:test", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - "@eigen_archive//:eigen3", - "@local_tsl//tsl/concurrency:ref_count", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:statusor", - "@local_xla//xla/python/ifrt", - "@local_xla//xla/python/ifrt:test_util", - "@local_xla//xla/python/pjrt_ifrt:tfrt_cpu_client_test_lib", - ], -) - tf_cc_test( name = "ifrt_serving_executable_test", srcs = [ @@ -215,7 +216,9 @@ tf_cc_test( ], tags = ["no_oss"], deps = [ + ":ifrt_loaded_variable_registry", ":ifrt_serving_executable", + ":sharding_utils", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/tf2xla:xla_helpers", "//tensorflow/core:framework", @@ -225,6 +228,7 @@ tf_cc_test( "//tensorflow/core/framework:tensor_testutil", "//tensorflow/core/framework:types_proto_cc", "//tensorflow/core/platform:resource_loader", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", @@ -232,8 +236,11 @@ tf_cc_test( "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", + "@local_tsl//tsl/concurrency:ref_count", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:tstring", + "@local_xla//xla/hlo/ir:hlo", "@local_xla//xla/python/ifrt", "@local_xla//xla/python/ifrt:test_util", "@local_xla//xla/python/pjrt_ifrt:tfrt_cpu_client_test_lib", diff --git a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.cc b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.cc new file mode 100644 index 00000000000000..213d75431df6c9 --- /dev/null +++ b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.cc @@ -0,0 +1,60 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h" + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "xla/python/ifrt/array.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/statusor.h" + +namespace tensorflow { +namespace ifrt_serving { + +absl::Status IfrtLoadedVariableRegistry::TryRegisterLoadedVariable( + absl::string_view name, + LoadedVariableConstructor&& loaded_variable_constructor) { + absl::MutexLock lock(&mutex_); + auto& variable = loaded_variable_map_[name]; + if (variable != nullptr) { + // Already registered. This is rare. + VLOG(1) << "Variable '" << name << "' already registered."; + return absl::OkStatus(); + } + TF_ASSIGN_OR_RETURN(variable, loaded_variable_constructor()); + return absl::OkStatus(); +} + +absl::StatusOr> +IfrtLoadedVariableRegistry::GetLoadedVariable(absl::string_view name) const { + absl::MutexLock lock(&mutex_); + auto it = loaded_variable_map_.find(name); + if (it == loaded_variable_map_.end()) { + return absl::NotFoundError( + absl::StrCat("Variable '", name, "' not found.")); + } + return it->second; +} + +} // namespace ifrt_serving +} // namespace tensorflow diff --git a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h new file mode 100644 index 00000000000000..ccfc4aa3d46a1b --- /dev/null +++ b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h @@ -0,0 +1,60 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TFRT_IFRT_IFRT_LOADED_VARIABLE_REGISTRY_H_ +#define TENSORFLOW_CORE_TFRT_IFRT_IFRT_LOADED_VARIABLE_REGISTRY_H_ + +#include "absl/container/flat_hash_map.h" +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "xla/python/ifrt/array.h" +#include "tsl/concurrency/ref_count.h" + +namespace tensorflow { +namespace ifrt_serving { + +// This class is thread safe. +class IfrtLoadedVariableRegistry { + public: + using LoadedVariableConstructor = + absl::AnyInvocable>() + const>; + + // Tries to register a loaded variable with the given name. + // Returns an error if the named array does not already exists and + // loaded_variable_constructor failed to create an array. Note that it returns + // OK if the named array already exists. + // loaded_variable_constructor is invoked in the caller thread. + absl::Status TryRegisterLoadedVariable( + absl::string_view name, + LoadedVariableConstructor&& loaded_variable_constructor) + ABSL_LOCKS_EXCLUDED(mutex_); + + absl::StatusOr> GetLoadedVariable( + absl::string_view name) const ABSL_LOCKS_EXCLUDED(mutex_); + + private: + mutable absl::Mutex mutex_; + absl::flat_hash_map> + loaded_variable_map_ ABSL_GUARDED_BY(mutex_); +}; + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_IFRT_IFRT_LOADED_VARIABLE_REGISTRY_H_ diff --git a/tensorflow/core/tfrt/ifrt/ifrt_model_context.cc b/tensorflow/core/tfrt/ifrt/ifrt_model_context.cc index 4942407cfdab7e..fa74302f625dfd 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_model_context.cc +++ b/tensorflow/core/tfrt/ifrt/ifrt_model_context.cc @@ -18,12 +18,8 @@ limitations under the License. #include -#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" // Enable Eigen::ThreadPoolDevice structure definition, rather than just // declaration. @@ -38,29 +34,6 @@ namespace ifrt_serving { const Eigen::ThreadPoolDevice& IfrtModelContext::GetThreadPoolDevice() const { return thread_pool_device_; } -absl::Status IfrtModelContext::RegisterLoadedVariable( - absl::string_view name, - tsl::RCReference loaded_variable) { - absl::MutexLock lock(&mutex_); - auto& variable = loaded_variable_map_[name]; - if (variable != nullptr) { - return absl::AlreadyExistsError( - absl::StrCat("Variable '", name, "' already exists.")); - } - variable = std::move(loaded_variable); - return absl::OkStatus(); -} - -absl::StatusOr> -IfrtModelContext::GetLoadedVariable(absl::string_view name) const { - absl::MutexLock lock(&mutex_); - auto it = loaded_variable_map_.find(name); - if (it == loaded_variable_map_.end()) { - return absl::NotFoundError( - absl::StrCat("Variable '", name, "' not found.")); - } - return it->second; -} } // namespace ifrt_serving } // namespace tensorflow diff --git a/tensorflow/core/tfrt/ifrt/ifrt_model_context.h b/tensorflow/core/tfrt/ifrt/ifrt_model_context.h index 3eb8a282e02cb7..1c942ce0e045d7 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_model_context.h +++ b/tensorflow/core/tfrt/ifrt/ifrt_model_context.h @@ -31,6 +31,7 @@ limitations under the License. #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/client.h" #include "tensorflow/core/tfrt/ifrt/ifrt_executable_registry.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h" #include "tsl/concurrency/ref_count.h" namespace tensorflow { @@ -74,13 +75,12 @@ class IfrtModelContext { const Eigen::ThreadPoolDevice& GetThreadPoolDevice() const; - absl::Status RegisterLoadedVariable( - absl::string_view name, - tsl::RCReference loaded_variable) - ABSL_LOCKS_EXCLUDED(mutex_); - - absl::StatusOr> GetLoadedVariable( - absl::string_view name) const ABSL_LOCKS_EXCLUDED(mutex_); + const IfrtLoadedVariableRegistry& GetLoadedVariableRegistry() const { + return loaded_variable_registry_; + } + IfrtLoadedVariableRegistry& GetLoadedVariableRegistry() { + return loaded_variable_registry_; + } private: std::shared_ptr client_; @@ -90,9 +90,7 @@ class IfrtModelContext { std::vector handles_; - mutable absl::Mutex mutex_; - absl::flat_hash_map> - loaded_variable_map_ ABSL_GUARDED_BY(mutex_); + IfrtLoadedVariableRegistry loaded_variable_registry_; }; } // namespace ifrt_serving diff --git a/tensorflow/core/tfrt/ifrt/ifrt_model_context_test.cc b/tensorflow/core/tfrt/ifrt/ifrt_model_context_test.cc deleted file mode 100644 index 201a9050801298..00000000000000 --- a/tensorflow/core/tfrt/ifrt/ifrt_model_context_test.cc +++ /dev/null @@ -1,115 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// Enable definition of Eigen::ThreadPoolDevice instead of just declaration. -#define EIGEN_USE_THREADS - -#include "tensorflow/core/tfrt/ifrt/ifrt_model_context.h" - -#include -#include -#include -#include - -#include -#include -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive -#include "xla/python/ifrt/array.h" -#include "xla/python/ifrt/client.h" -#include "xla/python/ifrt/dtype.h" -#include "xla/python/ifrt/shape.h" -#include "xla/python/ifrt/sharding.h" -#include "xla/python/ifrt/test_util.h" -#include "tensorflow/core/platform/test.h" -#include "tsl/concurrency/ref_count.h" -#include "tsl/lib/core/status_test_util.h" -#include "tsl/platform/env.h" -#include "tsl/platform/status_matchers.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/threadpool.h" - -namespace tensorflow { -namespace ifrt_serving { -namespace { - -Eigen::ThreadPoolDevice GetThreadPoolDevice() { - constexpr int kMaxParallelism = 16; - static tsl::thread::ThreadPool* thread_pool = []() { - return new tsl::thread::ThreadPool(tsl::Env::Default(), - tsl::ThreadOptions(), "IfrtSharding", - kMaxParallelism); - }(); - return Eigen::ThreadPoolDevice(thread_pool->AsEigenThreadPool(), - kMaxParallelism); -} - -absl::StatusOr> CreateDummyArray( - xla::ifrt::Client& client) { - xla::ifrt::DType dtype(xla::ifrt::DType::kF32); - xla::ifrt::Shape shape({2, 3}); - std::vector data(6); - std::iota(data.begin(), data.end(), 0); - - return client.MakeArrayFromHostBuffer( - data.data(), dtype, shape, - /*byte_strides=*/std::nullopt, - xla::ifrt::SingleDeviceSharding::Create(client.devices()[0], - xla::ifrt::MemoryKind()), - xla::ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall, - /*on_done_with_host_buffer=*/{}); -} - -TEST(IfrtModelContext, ReRegisterShallFail) { - // Create contexts required for the compiler execution. - TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr client, - xla::ifrt::test_util::GetClient()); - Eigen::ThreadPoolDevice thread_pool_device = GetThreadPoolDevice(); - IfrtModelContext ifrt_model_context(client, &thread_pool_device); - - TF_ASSERT_OK_AND_ASSIGN(tsl::RCReference loaded_variable, - CreateDummyArray(*client)); - - absl::string_view variable_name = "variable"; - - TF_ASSERT_OK(ifrt_model_context.RegisterLoadedVariable(variable_name, - loaded_variable)); - - auto re_register_status = - ifrt_model_context.RegisterLoadedVariable(variable_name, loaded_variable); - - EXPECT_THAT(re_register_status, - tsl::testing::StatusIs(absl::StatusCode::kAlreadyExists)); -} - -TEST(IfrtModelContext, GetUnregisterVariableShallFail) { - // Create contexts required for the compiler execution. - TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr client, - xla::ifrt::test_util::GetClient()); - Eigen::ThreadPoolDevice thread_pool_device = GetThreadPoolDevice(); - IfrtModelContext ifrt_model_context(client, &thread_pool_device); - - absl::string_view variable_name = "variable"; - - auto statusor = ifrt_model_context.GetLoadedVariable(variable_name); - - EXPECT_THAT(statusor.status(), - tsl::testing::StatusIs(absl::StatusCode::kNotFound)); -} - -} // namespace -} // namespace ifrt_serving -} // namespace tensorflow diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc index 1e6daad32f0adc..fb77f5c1671019 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc +++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include @@ -50,15 +51,44 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h" #include "tensorflow/core/tfrt/ifrt/ifrt_tensor_utils.h" #include "tensorflow/core/tfrt/ifrt/sharding_utils.h" #include "tsl/concurrency/ref_count.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" +#include "tsl/platform/tstring.h" namespace tensorflow { namespace ifrt_serving { namespace { +absl::StatusOr> BuildDtypeAndShape( + absl::Span inputs, + absl::Span variable_arg_indices, + const IfrtLoadedVariableRegistry& ifrt_loaded_variable_registry) { + std::vector dtypes_and_shapes; + dtypes_and_shapes.reserve(inputs.size()); + + int variable_index = 0; + for (int i = 0; i < inputs.size(); i++) { + if (variable_index < variable_arg_indices.size() && + i == variable_arg_indices[variable_index]) { + // Get already loaded variable tensor. + TF_ASSIGN_OR_RETURN(auto single_array, + ifrt_loaded_variable_registry.GetLoadedVariable( + inputs[i].scalar()())); + TF_ASSIGN_OR_RETURN(auto dtype, ToTensorDataType(single_array->dtype())); + dtypes_and_shapes.push_back(DtypeAndShape{ + .dtype = dtype, .shape = ToTensorShape(single_array->shape())}); + + variable_index++; + } else { + dtypes_and_shapes.push_back(DtypeAndShape{.dtype = inputs[i].dtype(), + .shape = inputs[i].shape()}); + } + } + return dtypes_and_shapes; +} absl::StatusOr GetXlaDeviceAssignment( const xla::ifrt::Client& ifrt_client, @@ -111,28 +141,6 @@ absl::StatusOr> GetAssignedDevices( return devices; } -absl::StatusOr> -CreateArrayFromHostTensorForSingleDevice(xla::ifrt::Client& ifrt_client, - const tensorflow::Tensor& tensor, - xla::ifrt::Device* device) { - TF_ASSIGN_OR_RETURN(auto dtype, ToIfrtDType(tensor.dtype())); - - VLOG(2) << "Make single device array for buffer slice " - << " at " << tensor.data(); - auto single_device_sharding = - xla::ifrt::SingleDeviceSharding::Create(device, xla::ifrt::MemoryKind()); - - return ifrt_client.MakeArrayFromHostBuffer( - tensor.data(), dtype, ToIfrtShape(tensor.shape()), - /*byte_strides=*/{}, std::move(single_device_sharding), - xla::ifrt::Client::HostBufferSemantics::kImmutableUntilTransferCompletes, - [tensor]() { - // Keep tensor alive - VLOG(2) << "Done with single device host buffer for slice " - << " at " << tensor.data(); - }); -} - } // namespace absl::StatusOr> @@ -143,55 +151,18 @@ IfrtServingExecutable::ConvertTensorToArray( VLOG(2) << "Converting tensor of shape " << input_shape; TF_ASSIGN_OR_RETURN(auto hlo_sharding, xla::HloSharding::FromProto(sharding)); - VLOG(3) << "IsTiled: " << hlo_sharding.IsTiled(); - VLOG(3) << "IsReplicated: " << hlo_sharding.IsReplicated(); - VLOG(3) << "IsTileMaximal: " << hlo_sharding.IsTileMaximal(); - if (!hlo_sharding.IsTiled() && !hlo_sharding.IsReplicated() && - !hlo_sharding.IsTileMaximal()) { - return absl::UnimplementedError(absl::StrCat( - "Only support MAXIMAL, OTHER or REPLICATED, but got sharding : ", - hlo_sharding.ToString())); - } - - VLOG(1) << "Hlo sharding: " << hlo_sharding.ToString(); - VLOG(1) << "Device list size: " << device_list.size(); - if (device_list.size() == 1) { - if (hlo_sharding.IsTiled()) { - return absl::InvalidArgumentError( - absl::StrCat("Tiled sharding", hlo_sharding.ToString(), - " expect more than 1 device, but got 1 only")); - } - return CreateArrayFromHostTensorForSingleDevice(*ifrt_client_, tensor, - device_list[0]); - } - - // Replicate implies Maximal, but not vice versa. Only Maximal is single - // device. - if (!hlo_sharding.IsReplicated() && hlo_sharding.IsTileMaximal()) { - VLOG(1) << "Single device fast path for Maximal tiled tensor"; - xla::ifrt::Device* device; - if (hlo_sharding.HasUniqueDevice()) { - int unique_device_id = hlo_sharding.GetUniqueDevice(); - TF_ASSIGN_OR_RETURN(device, ifrt_client_->LookupDevice(unique_device_id)); - } else { - device = device_list[0]; - } - return CreateArrayFromHostTensorForSingleDevice(*ifrt_client_, tensor, - device); - } - - return MakeAssembledArrayFromHostBuffer(*ifrt_client_, tensor, - std::move(hlo_sharding), device_list, - thread_pool_device_); + return MakeArrayFromTensor(*ifrt_client_, tensor, device_list, + std::move(hlo_sharding), thread_pool_device_); } absl::StatusOr IfrtServingExecutable::CreateExecutableSynchronously( - absl::Span inputs) { - TF_ASSIGN_OR_RETURN(Tf2HloResult tf2hlo_result, - CompileTfToHlo(*module_, inputs, signature_name(), - *ifrt_client_, shape_representation_fn_)); + absl::Span dtypes_and_shapes) { + TF_ASSIGN_OR_RETURN( + Tf2HloResult tf2hlo_result, + CompileTfToHlo(*module_, dtypes_and_shapes, signature_name(), + *ifrt_client_, shape_representation_fn_)); const int num_replicas = tf2hlo_result.compile_metadata.num_replicas(); const int num_partitions = @@ -240,10 +211,10 @@ IfrtServingExecutable::CreateExecutableSynchronously( xla::ifrt::Future> IfrtServingExecutable::LookUpOrCreateExecutable( - absl::Span inputs) { + absl::Span dtypes_and_shapes) { std::vector input_shapes; - for (const auto& tensor : inputs) { - input_shapes.push_back(tensor.shape()); + for (const auto& dtype_and_shape : dtypes_and_shapes) { + input_shapes.push_back(dtype_and_shape.shape); } Key key = {input_shapes}; @@ -268,19 +239,50 @@ IfrtServingExecutable::LookUpOrCreateExecutable( LOG(INFO) << "Cache missed. Building executable"; absl::StatusOr executable_bundle = - CreateExecutableSynchronously(inputs); + CreateExecutableSynchronously(dtypes_and_shapes); promise.Set(std::move(executable_bundle)); return future; } absl::StatusOr> IfrtServingExecutable::Execute( - absl::Span inputs) { - TF_ASSIGN_OR_RETURN(CachedExecutableBundle executable_bundle, - LookUpOrCreateExecutable(inputs).Await()); + absl::Span inputs, + absl::Span variable_arg_indices) { + for (int i = 1; i < variable_arg_indices.size(); i++) { + if (variable_arg_indices[i] <= variable_arg_indices[i - 1]) { + return absl::FailedPreconditionError(absl::StrCat( + "Expected variable_arg_indices in ascending order. But subsequence " + "starting at ", + i - 1, ": (", variable_arg_indices[i - 1], ", ", + variable_arg_indices[i], ")", " is not in ascending order")); + } + } - std::vector> args; - args.reserve(inputs.size()); + if (!variable_arg_indices.empty() && + inputs.size() <= variable_arg_indices.back()) { + return absl::FailedPreconditionError(absl::StrCat( + "Expected at most ", inputs.size(), " inputs, but got up to ", + variable_arg_indices.back(), " variables.")); + } + + // Ensure the variable tensor holds a valid key: a scalar string tensor. + for (const int i : variable_arg_indices) { + if (inputs[i].dtype() != tensorflow::DT_STRING || + !tensorflow::TensorShapeUtils::IsScalar(inputs[i].shape())) { + return absl::FailedPreconditionError( + absl::StrCat("Expected a scalar tensor as loaded variable array key, " + "but got type ", + inputs[i].dtype(), " and shape ", + inputs[i].shape().DebugString(), " at index ", i)); + } + } + + TF_ASSIGN_OR_RETURN(std::vector dtypes_and_shapes, + BuildDtypeAndShape(inputs, variable_arg_indices, + ifrt_loaded_variable_registry_)); + TF_ASSIGN_OR_RETURN( + CachedExecutableBundle executable_bundle, + LookUpOrCreateExecutable(absl::MakeSpan(dtypes_and_shapes)).Await()); TF_ASSIGN_OR_RETURN( std::vector devices, @@ -288,21 +290,35 @@ absl::StatusOr> IfrtServingExecutable::Execute( xla::ifrt::DeviceList device_list( xla::ifrt::DeviceList::Devices(devices.begin(), devices.end())); - auto compile_metadata_arg_iter = - executable_bundle.compile_metadata.args().begin(); - if (executable_bundle.compile_metadata.args().size() != inputs.size()) { + if (executable_bundle.compile_metadata.args().size() != + dtypes_and_shapes.size()) { return absl::InternalError(absl::StrCat( - "Expect ", executable_bundle.compile_metadata.args().size(), - " but got ", inputs.size(), " arguments")); + "Expected ", executable_bundle.compile_metadata.args().size(), + " but got ", dtypes_and_shapes.size(), " arguments")); } - for (const auto& input_tensor : inputs) { - TF_ASSIGN_OR_RETURN( - auto single_array, - ConvertTensorToArray(input_tensor, device_list, - compile_metadata_arg_iter->sharding())); - args.push_back(single_array); - compile_metadata_arg_iter++; + + std::vector> args; + args.reserve(inputs.size()); + + int variable_index = 0; + for (int i = 0; i < inputs.size(); i++) { + if (variable_index < variable_arg_indices.size() && + i == variable_arg_indices[variable_index]) { + TF_ASSIGN_OR_RETURN(auto single_array, + ifrt_loaded_variable_registry_.GetLoadedVariable( + inputs[i].scalar()())); + args.push_back(single_array); + variable_index++; + } else { + TF_ASSIGN_OR_RETURN( + auto single_array, + ConvertTensorToArray( + inputs[i], device_list, + executable_bundle.compile_metadata.args()[i].sharding())); + args.push_back(single_array); + } } + DCHECK_EQ(args.size(), dtypes_and_shapes.size()); VLOG(2) << "Start Execution"; diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h index cc7d53154d0c28..678eaeefdce7f6 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h +++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h @@ -31,6 +31,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/client.h" @@ -42,6 +43,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h" #include "tsl/concurrency/ref_count.h" namespace tensorflow { @@ -54,12 +56,14 @@ class IfrtServingExecutable { mlir::OwningOpRef module, std::shared_ptr client, const Eigen::ThreadPoolDevice* thread_pool_device, + const IfrtLoadedVariableRegistry* ifrt_loaded_variable_registry, tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn) : model_name_(std::string(model_name)), signature_name_(std::string(signature_name)), module_(std::move(module)), ifrt_client_(std::move(client)), thread_pool_device_(*thread_pool_device), + ifrt_loaded_variable_registry_(*ifrt_loaded_variable_registry), shape_representation_fn_(std::move(shape_representation_fn)) {} // Movable but not copyable. @@ -72,8 +76,10 @@ class IfrtServingExecutable { absl::string_view signature_name() const { return signature_name_; } // Executes the computation. + // variable_arg_indices are in sorted order. absl::StatusOr> Execute( - absl::Span inputs); + absl::Span inputs, + absl::Span variable_arg_indices); int num_executables() const { absl::MutexLock lock(&mutex_); @@ -113,6 +119,7 @@ class IfrtServingExecutable { std::shared_ptr ifrt_client_; const Eigen::ThreadPoolDevice& thread_pool_device_; + const IfrtLoadedVariableRegistry& ifrt_loaded_variable_registry_; tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn_; mutable absl::Mutex mutex_; @@ -126,9 +133,10 @@ class IfrtServingExecutable { const xla::OpSharding& sharding); xla::ifrt::Future> - LookUpOrCreateExecutable(absl::Span inputs); + LookUpOrCreateExecutable(absl::Span dtypes_and_shapes); absl::StatusOr - CreateExecutableSynchronously(absl::Span inputs); + CreateExecutableSynchronously( + absl::Span dtypes_and_shapes); absl::StatusOr> CreateSharding( int num_devices, const xla::ifrt::Shape& arg_xla_shape, diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test.cc b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test.cc index 391ba1b06b351e..a738f9e399c562 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test.cc +++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -38,6 +39,8 @@ limitations under the License. #include "mlir/Parser/Parser.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/python/ifrt/array.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/test_util.h" #include "tensorflow/core/framework/tensor.h" @@ -47,13 +50,27 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/resource_loader.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h" +#include "tensorflow/core/tfrt/ifrt/sharding_utils.h" +#include "tsl/concurrency/ref_count.h" #include "tsl/platform/env.h" #include "tsl/platform/statusor.h" #include "tsl/platform/threadpool.h" +#include "tsl/platform/tstring.h" namespace tensorflow { namespace ifrt_serving { namespace { +struct VariableInputTestParam { + std::vector in_tensors; + std::vector + is_variable; // if is_variable[i] = true, then in_tensor[i] is a variable + // and can be preloaded as an ifrt array. + std::vector expected_out_tensors; +}; +using VariableInputTest = ::testing::TestWithParam; + +using ::tensorflow::test::AsTensor; using ::tensorflow::test::TensorEq; using ::testing::ElementsAre; @@ -91,21 +108,21 @@ TEST(IfrtServingExecutableTest, Basic) { xla::ifrt::test_util::GetClient()); Eigen::ThreadPoolDevice thread_pool_device = GetThreadPoolDevice(); + IfrtLoadedVariableRegistry ifrt_loaded_variable_registry; IfrtServingExecutable executable("test", "main", std::move(mlir_module), client, &thread_pool_device, + &ifrt_loaded_variable_registry, tensorflow::IdentityShapeRepresentationFn()); - auto x = tensorflow::test::AsTensor({1, 2, 3}, - tensorflow::TensorShape({1, 3})); - auto y = tensorflow::test::AsTensor({1, 2, 3}, - tensorflow::TensorShape({3, 1})); + auto x = AsTensor({1, 2, 3}, tensorflow::TensorShape({1, 3})); + auto y = AsTensor({1, 2, 3}, tensorflow::TensorShape({3, 1})); std::vector inputs{x, y}; TF_ASSERT_OK_AND_ASSIGN(auto result, - executable.Execute(absl::MakeSpan(inputs))); + executable.Execute(absl::MakeSpan(inputs), {})); - const auto expected_out = tensorflow::test::AsTensor( - {14}, tensorflow::TensorShape({1, 1})); + const auto expected_out = + AsTensor({14}, tensorflow::TensorShape({1, 1})); EXPECT_THAT(result, ElementsAre(TensorEq(expected_out))); } @@ -133,33 +150,32 @@ TEST(IfrtServingExecutableTest, MultipleShapes) { xla::ifrt::test_util::GetClient()); Eigen::ThreadPoolDevice thread_pool_device = GetThreadPoolDevice(); + IfrtLoadedVariableRegistry ifrt_loaded_variable_registry; + IfrtServingExecutable executable("test", "main", std::move(mlir_module), client, &thread_pool_device, + &ifrt_loaded_variable_registry, tensorflow::IdentityShapeRepresentationFn()); - auto x1 = tensorflow::test::AsTensor( - {1, 2, 3}, tensorflow::TensorShape({1, 3})); - auto y1 = tensorflow::test::AsTensor( - {1, 2, 3}, tensorflow::TensorShape({3, 1})); - const auto expected_out1 = tensorflow::test::AsTensor( - {14}, tensorflow::TensorShape({1, 1})); + auto x1 = AsTensor({1, 2, 3}, tensorflow::TensorShape({1, 3})); + auto y1 = AsTensor({1, 2, 3}, tensorflow::TensorShape({3, 1})); + const auto expected_out1 = + AsTensor({14}, tensorflow::TensorShape({1, 1})); std::vector inputs1{x1, y1}; - auto x2 = tensorflow::test::AsTensor( - {1, 2, 3, 4}, tensorflow::TensorShape({1, 4})); - auto y2 = tensorflow::test::AsTensor( - {1, 2, 3, 4}, tensorflow::TensorShape({4, 1})); - const auto expected_out2 = tensorflow::test::AsTensor( - {30}, tensorflow::TensorShape({1, 1})); + auto x2 = AsTensor({1, 2, 3, 4}, tensorflow::TensorShape({1, 4})); + auto y2 = AsTensor({1, 2, 3, 4}, tensorflow::TensorShape({4, 1})); + const auto expected_out2 = + AsTensor({30}, tensorflow::TensorShape({1, 1})); std::vector inputs2{x2, y2}; std::vector outputs1, outputs2; for (int i = 0; i < 3; i++) { TF_ASSERT_OK_AND_ASSIGN(outputs1, - executable.Execute(absl::MakeSpan(inputs1))); + executable.Execute(absl::MakeSpan(inputs1), {})); TF_ASSERT_OK_AND_ASSIGN(outputs2, - executable.Execute(absl::MakeSpan(inputs2))); + executable.Execute(absl::MakeSpan(inputs2), {})); } ASSERT_EQ(executable.num_executables(), 2); @@ -192,24 +208,27 @@ TEST(IfrtServingExecutableTest, Spmd) { xla::ifrt::test_util::GetClient()); Eigen::ThreadPoolDevice thread_pool_device = GetThreadPoolDevice(); + IfrtLoadedVariableRegistry ifrt_loaded_variable_registry; + IfrtServingExecutable executable("test", "main", std::move(mlir_module), client, &thread_pool_device, + &ifrt_loaded_variable_registry, tensorflow::IdentityShapeRepresentationFn()); - auto x = tensorflow::test::AsTensor({1, 2, 3, 4, 5, 6, 7, 8}, - tensorflow::TensorShape({4, 2})); - auto y = tensorflow::test::AsTensor({11, 12, 13, 14, 15, 16, 17, 18}, - tensorflow::TensorShape({4, 2})); + auto x = AsTensor({1, 2, 3, 4, 5, 6, 7, 8}, + tensorflow::TensorShape({4, 2})); + auto y = AsTensor({11, 12, 13, 14, 15, 16, 17, 18}, + tensorflow::TensorShape({4, 2})); - auto z = tensorflow::test::AsTensor({21, 22, 23, 24, 25, 26, 27, 28}, - tensorflow::TensorShape({4, 2})); + auto z = AsTensor({21, 22, 23, 24, 25, 26, 27, 28}, + tensorflow::TensorShape({4, 2})); - const auto expected_out = tensorflow::test::AsTensor( - {33, 36, 39, 42, 45, 48, 51, 54}, tensorflow::TensorShape({4, 2})); + const auto expected_out = AsTensor({33, 36, 39, 42, 45, 48, 51, 54}, + tensorflow::TensorShape({4, 2})); std::vector inputs{x, y, z}; TF_ASSERT_OK_AND_ASSIGN(auto result, - executable.Execute(absl::MakeSpan(inputs))); + executable.Execute(absl::MakeSpan(inputs), {})); EXPECT_THAT(result, ElementsAre(TensorEq(expected_out))); } @@ -237,26 +256,30 @@ TEST(IfrtServingExecutableTest, SpmdTwoReturns) { xla::ifrt::test_util::GetClient()); Eigen::ThreadPoolDevice thread_pool_device = GetThreadPoolDevice(); + IfrtLoadedVariableRegistry ifrt_loaded_variable_registry; + IfrtServingExecutable executable("test", "main", std::move(mlir_module), client, &thread_pool_device, + &ifrt_loaded_variable_registry, tensorflow::IdentityShapeRepresentationFn()); - auto x = tensorflow::test::AsTensor({1, 2, 3, 4, 5, 6, 7, 8}, - tensorflow::TensorShape({4, 2})); - auto y = tensorflow::test::AsTensor({11, 12, 13, 14, 15, 16, 17, 18}, - tensorflow::TensorShape({4, 2})); + auto x = AsTensor({1, 2, 3, 4, 5, 6, 7, 8}, + tensorflow::TensorShape({4, 2})); + auto y = AsTensor({11, 12, 13, 14, 15, 16, 17, 18}, + tensorflow::TensorShape({4, 2})); - auto z = tensorflow::test::AsTensor({21, 22, 23, 24, 25, 26, 27, 28}, - tensorflow::TensorShape({4, 2})); + auto z = AsTensor({21, 22, 23, 24, 25, 26, 27, 28}, + tensorflow::TensorShape({4, 2})); - const auto expected_out0 = tensorflow::test::AsTensor( - {33, 36, 39, 42, 45, 48, 51, 54}, tensorflow::TensorShape({4, 2})); - const auto expected_out1 = tensorflow::test::AsTensor( - {20, 20, 20, 20, 20, 20, 20, 20}, tensorflow::TensorShape({4, 2})); + const auto expected_out0 = AsTensor({33, 36, 39, 42, 45, 48, 51, 54}, + tensorflow::TensorShape({4, 2})); + const auto expected_out1 = AsTensor({20, 20, 20, 20, 20, 20, 20, 20}, + tensorflow::TensorShape({4, 2})); std::vector inputs{x, y, z}; + TF_ASSERT_OK_AND_ASSIGN(auto result, - executable.Execute(absl::MakeSpan(inputs))); + executable.Execute(absl::MakeSpan(inputs), {})); EXPECT_THAT(result, ElementsAre(TensorEq(expected_out0), TensorEq(expected_out1))); @@ -285,22 +308,204 @@ TEST(IfrtServingExecutableTest, NoReturn) { xla::ifrt::test_util::GetClient()); Eigen::ThreadPoolDevice thread_pool_device = GetThreadPoolDevice(); + IfrtLoadedVariableRegistry ifrt_loaded_variable_registry; + IfrtServingExecutable executable("test", "main", std::move(mlir_module), client, &thread_pool_device, + &ifrt_loaded_variable_registry, tensorflow::IdentityShapeRepresentationFn()); - auto x = tensorflow::test::AsTensor({1, 2, 3}, - tensorflow::TensorShape({1, 3})); - auto y = tensorflow::test::AsTensor({1, 2, 3}, - tensorflow::TensorShape({3, 1})); + auto x = AsTensor({1, 2, 3}, tensorflow::TensorShape({1, 3})); + auto y = AsTensor({1, 2, 3}, tensorflow::TensorShape({3, 1})); std::vector inputs{x, y}; TF_ASSERT_OK_AND_ASSIGN(auto result, - executable.Execute(absl::MakeSpan(inputs))); + executable.Execute(absl::MakeSpan(inputs), {})); ASSERT_EQ(result.size(), 0); } +TEST_P(VariableInputTest, InterleaveVariable) { + // Create test input module + constexpr absl::string_view kDataDirectory = + "tensorflow/core/tfrt/ifrt/testdata"; + std::string mlir_module_path = tensorflow::GetDataDependencyFilepath( + absl::StrCat(kDataDirectory, "/executable_long_inputs.mlir")); + + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + mlir::RegisterAllTensorFlowDialects(registry); + + mlir::MLIRContext context(registry); + + mlir::OwningOpRef mlir_module = + mlir::parseSourceFile(mlir_module_path, &context); + + ASSERT_TRUE(mlir_module); + + // Create contexts required for the compiler execution. + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr client, + xla::ifrt::test_util::GetClient()); + Eigen::ThreadPoolDevice thread_pool_device = GetThreadPoolDevice(); + + IfrtLoadedVariableRegistry ifrt_loaded_variable_registry; + IfrtServingExecutable executable("test", "main", std::move(mlir_module), + client, &thread_pool_device, + &ifrt_loaded_variable_registry, + tensorflow::IdentityShapeRepresentationFn()); + + std::vector inputs; + std::vector loaded_variable_indices; + for (int i = 0; i < GetParam().in_tensors.size(); i++) { + if (GetParam().is_variable[i]) { + std::string variable_name = absl::StrCat("variable_", i); + ASSERT_OK(ifrt_loaded_variable_registry.TryRegisterLoadedVariable( + variable_name, + [&]() -> absl::StatusOr> { + TF_ASSIGN_OR_RETURN( + tsl::RCReference array, + MakeArrayFromTensor(*client, GetParam().in_tensors[i], + /*device_ids=*/{0}, + xla::HloSharding::Replicate(), + thread_pool_device)); + + return array; + })); + loaded_variable_indices.push_back(i); + + // Use string tensor containing the key (name) in place of variable + // tensor. + tensorflow::Tensor key_tensor(tensorflow::DT_STRING, {}); + key_tensor.scalar()() = variable_name; + inputs.push_back(key_tensor); + } else { + inputs.push_back(GetParam().in_tensors[i]); + } + } + + ASSERT_EQ(inputs.size(), GetParam().is_variable.size()); + + TF_ASSERT_OK_AND_ASSIGN( + auto result, executable.Execute(absl::MakeSpan(inputs), + absl::MakeSpan(loaded_variable_indices))); + + EXPECT_THAT(result, + ElementsAre(TensorEq(GetParam().expected_out_tensors[0]), + TensorEq(GetParam().expected_out_tensors[1]), + TensorEq(GetParam().expected_out_tensors[2]))); +} + +INSTANTIATE_TEST_SUITE_P( + VariableInputTests, VariableInputTest, + ::testing::ValuesIn( + { + // Basic case: all variables or all non-variables. + { + .in_tensors = + { + AsTensor({2, 2}, TensorShape({1, 2})), + AsTensor({3, 3}, TensorShape({2, 1})), + AsTensor({4, 4}, TensorShape({1, 2})), + AsTensor({5, 5}, TensorShape({2, 1})), + AsTensor({10, 10}, TensorShape({1, 2})), + }, + .is_variable = {true, true, true, true, true}, + .expected_out_tensors = + { + AsTensor({12}, TensorShape({1, 1})), + AsTensor({40}, TensorShape({1, 1})), + AsTensor({100}, TensorShape({1, 1})), + }, + }, + { + .in_tensors = + { + AsTensor({2, 2}, TensorShape({1, 2})), + AsTensor({3, 3}, TensorShape({2, 1})), + AsTensor({4, 4}, TensorShape({1, 2})), + AsTensor({5, 5}, TensorShape({2, 1})), + AsTensor({10, 10}, TensorShape({1, 2})), + }, + .is_variable = {false, false, false, false, false}, + .expected_out_tensors = + { + AsTensor({12}, TensorShape({1, 1})), + AsTensor({40}, TensorShape({1, 1})), + AsTensor({100}, TensorShape({1, 1})), + }, + }, + // Variable and non-variables are non-interleaved + { + .in_tensors = + { + AsTensor({2, 2}, TensorShape({1, 2})), + AsTensor({3, 3}, TensorShape({2, 1})), + AsTensor({4, 4}, TensorShape({1, 2})), + AsTensor({5, 5}, TensorShape({2, 1})), + AsTensor({10, 10}, TensorShape({1, 2})), + }, + .is_variable = {false, false, false, true, true}, + .expected_out_tensors = + { + AsTensor({12}, TensorShape({1, 1})), + AsTensor({40}, TensorShape({1, 1})), + AsTensor({100}, TensorShape({1, 1})), + }, + }, + { + .in_tensors = + { + AsTensor({2, 2}, TensorShape({1, 2})), + AsTensor({3, 3}, TensorShape({2, 1})), + AsTensor({4, 4}, TensorShape({1, 2})), + AsTensor({5, 5}, TensorShape({2, 1})), + AsTensor({10, 10}, TensorShape({1, 2})), + }, + .is_variable = {true, true, false, false, false}, + .expected_out_tensors = + { + AsTensor({12}, TensorShape({1, 1})), + AsTensor({40}, TensorShape({1, 1})), + AsTensor({100}, TensorShape({1, 1})), + }, + }, + // Variable and non-variables are interleaved + { + .in_tensors = + { + AsTensor({2, 2}, TensorShape({1, 2})), + AsTensor({3, 3}, TensorShape({2, 1})), + AsTensor({4, 4}, TensorShape({1, 2})), + AsTensor({5, 5}, TensorShape({2, 1})), + AsTensor({10, 10}, TensorShape({1, 2})), + }, + .is_variable = {true, false, false, true, false}, + .expected_out_tensors = + { + AsTensor({12}, TensorShape({1, 1})), + AsTensor({40}, TensorShape({1, 1})), + AsTensor({100}, TensorShape({1, 1})), + }, + }, + { + .in_tensors = + { + AsTensor({2, 2}, TensorShape({1, 2})), + AsTensor({3, 3}, TensorShape({2, 1})), + AsTensor({4, 4}, TensorShape({1, 2})), + AsTensor({5, 5}, TensorShape({2, 1})), + AsTensor({10, 10}, TensorShape({1, 2})), + }, + .is_variable = {false, true, true, false, true}, + .expected_out_tensors = + { + AsTensor({12}, TensorShape({1, 1})), + AsTensor({40}, TensorShape({1, 1})), + AsTensor({100}, TensorShape({1, 1})), + }, + }, + })); + } // namespace } // namespace ifrt_serving } // namespace tensorflow diff --git a/tensorflow/core/tfrt/ifrt/sharding_utils.cc b/tensorflow/core/tfrt/ifrt/sharding_utils.cc index bdc4d331fc7997..dda5beff68c9c7 100644 --- a/tensorflow/core/tfrt/ifrt/sharding_utils.cc +++ b/tensorflow/core/tfrt/ifrt/sharding_utils.cc @@ -364,7 +364,6 @@ CreateArrayFromHostTensorForSingleDevice(xla::ifrt::Client& ifrt_client, }); } -} // namespace StatusOr> MakeAssembledArrayFromHostBuffer( xla::ifrt::Client& ifrt_client, const tensorflow::Tensor& input_tensor, @@ -474,6 +473,8 @@ StatusOr> MakeAssembledArrayFromHostBuffer( xla::ifrt::ArrayCopySemantics::kDonateInput); } +} // namespace + absl::StatusOr MakeTensorFromArray( xla::ifrt::Client& ifrt_client, xla::ifrt::Array& input_array, const xla::HloSharding& hlo_sharding, @@ -636,23 +637,11 @@ absl::StatusOr MakeTensorFromArray( tensor_shape, thread_pool_device); } -StatusOr> MakeArrayFromTensor( +absl::StatusOr> MakeArrayFromTensor( xla::ifrt::Client& ifrt_client, const tensorflow::Tensor& input_tensor, - absl::Span device_ids, const xla::HloSharding& hlo_sharding, + const xla::ifrt::DeviceList& device_list, + const xla::HloSharding& hlo_sharding, const Eigen::ThreadPoolDevice& thread_pool_device) { - if (device_ids.empty()) { - return absl::InvalidArgumentError("device_ids cannot be empty"); - } - std::vector devices; - devices.reserve(device_ids.size()); - for (auto device_id : device_ids) { - TF_ASSIGN_OR_RETURN(xla::ifrt::Device * device, - ifrt_client.LookupDevice(device_id)); - devices.push_back(device); - } - xla::ifrt::DeviceList device_list( - xla::ifrt::DeviceList::Devices(devices.begin(), devices.end())); - VLOG(3) << "IsTiled: " << hlo_sharding.IsTiled(); VLOG(3) << "IsReplicated: " << hlo_sharding.IsReplicated(); VLOG(3) << "IsTileMaximal: " << hlo_sharding.IsTileMaximal(); @@ -687,5 +676,26 @@ StatusOr> MakeArrayFromTensor( thread_pool_device); } +absl::StatusOr> MakeArrayFromTensor( + xla::ifrt::Client& ifrt_client, const tensorflow::Tensor& input_tensor, + absl::Span device_ids, const xla::HloSharding& hlo_sharding, + const Eigen::ThreadPoolDevice& thread_pool_device) { + if (device_ids.empty()) { + return absl::InvalidArgumentError("device_ids cannot be empty"); + } + std::vector devices; + devices.reserve(device_ids.size()); + for (auto device_id : device_ids) { + TF_ASSIGN_OR_RETURN(xla::ifrt::Device * device, + ifrt_client.LookupDevice(device_id)); + devices.push_back(device); + } + xla::ifrt::DeviceList device_list( + xla::ifrt::DeviceList::Devices(devices.begin(), devices.end())); + + return MakeArrayFromTensor(ifrt_client, input_tensor, device_list, + hlo_sharding, thread_pool_device); +} + } // namespace ifrt_serving } // namespace tensorflow diff --git a/tensorflow/core/tfrt/ifrt/sharding_utils.h b/tensorflow/core/tfrt/ifrt/sharding_utils.h index 246d49c6c636fb..13599d4e6c3032 100644 --- a/tensorflow/core/tfrt/ifrt/sharding_utils.h +++ b/tensorflow/core/tfrt/ifrt/sharding_utils.h @@ -32,20 +32,18 @@ namespace tensorflow { namespace ifrt_serving { // Create a tensor from the given host tensor based on given device ids and -// sharding information. This is different from -// `MakeAssembledArrayFromHostBuffer` in that this function is a generic version -// that supports single device. -StatusOr> MakeArrayFromTensor( +// sharding information. +absl::StatusOr> MakeArrayFromTensor( xla::ifrt::Client& ifrt_client, const tensorflow::Tensor& input_tensor, absl::Span device_ids, const xla::HloSharding& hlo_sharding, const Eigen::ThreadPoolDevice& thread_pool_device); -// Sharded the given `data` by the `sharding` specification. -// It currently supports even sharding, replication and partial replication. -StatusOr> MakeAssembledArrayFromHostBuffer( +// A variant of the above api. The difference is that the user passes in +// device_list directly instead of a list of device_ids. +absl::StatusOr> MakeArrayFromTensor( xla::ifrt::Client& ifrt_client, const tensorflow::Tensor& input_tensor, - const xla::HloSharding& hlo_sharding, const xla::ifrt::DeviceList& device_list, + const xla::HloSharding& hlo_sharding, const Eigen::ThreadPoolDevice& thread_pool_device); // Reshard an disassembled array list back to one single tensor diff --git a/tensorflow/core/tfrt/ifrt/sharding_utils_test.cc b/tensorflow/core/tfrt/ifrt/sharding_utils_test.cc index ba95d2aed528d2..a4604fef756ba7 100644 --- a/tensorflow/core/tfrt/ifrt/sharding_utils_test.cc +++ b/tensorflow/core/tfrt/ifrt/sharding_utils_test.cc @@ -64,13 +64,6 @@ struct ReshardToTensorTestParam { xla::HloSharding sharding; }; -struct ShardToArrayTestParam { - tensorflow::Tensor in_tensor; - std::vector expected_out_tensors; - std::vector device_indices; - xla::HloSharding sharding; -}; - struct TensorToArrayTestParam { tensorflow::Tensor in_tensor; std::vector expected_out_tensors; @@ -78,7 +71,6 @@ struct TensorToArrayTestParam { xla::HloSharding sharding; }; -using ShardToArrayTest = ::testing::TestWithParam; using ReshardToTensorTest = ::testing::TestWithParam; using TensorToArrayTest = ::testing::TestWithParam; @@ -313,7 +305,7 @@ INSTANTIATE_TEST_SUITE_P( }, })); -TEST_P(ShardToArrayTest, MakeAssembledArrayFromHostBuffer) { +TEST_P(TensorToArrayTest, MakeArrayFromTensor) { constexpr int kMaxParallelism = 16; auto thread_pool = std::make_unique( tsl::Env::Default(), tsl::ThreadOptions(), "Resharding", kMaxParallelism); @@ -326,14 +318,12 @@ TEST_P(ShardToArrayTest, MakeAssembledArrayFromHostBuffer) { // Create contexts required for the compiler execution. TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr client, xla::ifrt::test_util::GetClient()); - TF_ASSERT_OK_AND_ASSIGN(auto device_list, - xla::ifrt::test_util::GetDevices( - client.get(), GetParam().device_indices)); TF_ASSERT_OK_AND_ASSIGN( auto assembled_array, - MakeAssembledArrayFromHostBuffer( - *client, input_tensor, GetParam().sharding, device_list, device)); + MakeArrayFromTensor(*client, input_tensor, + absl::MakeSpan(GetParam().device_ids), + GetParam().sharding, device)); TF_ASSERT_OK_AND_ASSIGN(auto disassembled_arrays, assembled_array->DisassembleIntoSingleDeviceArrays( @@ -341,9 +331,6 @@ TEST_P(ShardToArrayTest, MakeAssembledArrayFromHostBuffer) { ASSERT_EQ(disassembled_arrays.size(), GetParam().expected_out_tensors.size()); - tensorflow::Tensor host_tensor(tensorflow::DT_INT32, - tensorflow::TensorShape({1, 2})); - for (int i = 0; i < disassembled_arrays.size(); ++i) { SCOPED_TRACE(absl::StrCat("Array ", i, " of ", disassembled_arrays.size())); auto disassembled_array = disassembled_arrays[i]; @@ -362,226 +349,36 @@ TEST_P(ShardToArrayTest, MakeAssembledArrayFromHostBuffer) { } INSTANTIATE_TEST_SUITE_P( - HloShardingTests, ShardToArrayTest, - ::testing::ValuesIn( + TensorToArrayTests, TensorToArrayTest, + ::testing::ValuesIn( { - // Full replication. + // Single device { .in_tensor = test::AsTensor({1}, TensorShape({})), .expected_out_tensors = { test::AsTensor({1}, TensorShape({})), - test::AsTensor({1}, TensorShape({})), - }, - .device_indices = {0, 1}, - .sharding = Replicate(), - }, - { - .in_tensor = test::AsTensor({1, 2, 3}, - TensorShape({3, 1})), - .expected_out_tensors = - { - test::AsTensor({1, 2, 3}, TensorShape({3, 1})), - test::AsTensor({1, 2, 3}, TensorShape({3, 1})), }, - .device_indices = {0, 1}, + .device_ids = {0}, .sharding = Replicate(), }, - // 1-D sharding - { - .in_tensor = test::AsTensor({1, 2, 3, 4}, - TensorShape({4})), - .expected_out_tensors = - { - test::AsTensor({1, 2}, TensorShape({2})), - test::AsTensor({3, 4}, TensorShape({2})), - }, - .device_indices = {0, 1}, - .sharding = Tile({2}), - }, - { - .in_tensor = test::AsTensor({1, 2, 3, 4}, - TensorShape({2, 2})), - .expected_out_tensors = - { - test::AsTensor({1, 2}, TensorShape({1, 2})), - test::AsTensor({3, 4}, TensorShape({1, 2})), - }, - .device_indices = {0, 1}, - .sharding = Tile({2, 1}), - }, - { - .in_tensor = test::AsTensor({1, 2, 3, 4}, - TensorShape({1, 2, 2})), - .expected_out_tensors = - { - test::AsTensor({1, 3}, TensorShape({1, 2, 1})), - test::AsTensor({2, 4}, TensorShape({1, 2, 1})), - }, - .device_indices = {0, 1}, - .sharding = Tile({1, 1, 2}), - }, - { - .in_tensor = test::AsTensor({1, 2, 3, 4, 5, 6, 7, 8}, - TensorShape({4, 2})), - .expected_out_tensors = - { - test::AsTensor({1, 2}, TensorShape({1, 2})), - test::AsTensor({3, 4}, TensorShape({1, 2})), - test::AsTensor({5, 6}, TensorShape({1, 2})), - test::AsTensor({7, 8}, TensorShape({1, 2})), - }, - .device_indices = {0, 1, 2, 3}, - .sharding = Tile({4, 1}), - }, - { - .in_tensor = test::AsTensor({1, 2, 3, 4, 5, 6, 7, 8}, - TensorShape({4, 2})), - .expected_out_tensors = - { - test::AsTensor({1, 3, 5, 7}, - TensorShape({4, 1})), - test::AsTensor({2, 4, 6, 8}, - TensorShape({4, 1})), - }, - .device_indices = {0, 1}, - .sharding = Tile({1, 2}), - }, - // 2-D sharding { - .in_tensor = test::AsTensor( - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - TensorShape({4, 4})), + .in_tensor = test::AsTensor({2}, TensorShape({})), .expected_out_tensors = { - test::AsTensor({1, 2, 5, 6}, - TensorShape({2, 2})), - test::AsTensor({3, 4, 7, 8}, - TensorShape({2, 2})), - test::AsTensor({9, 10, 13, 14}, - TensorShape({2, 2})), - test::AsTensor({11, 12, 15, 16}, - TensorShape({2, 2})), + test::AsTensor({2}, TensorShape({})), }, - .device_indices = {0, 1, 2, 3}, - .sharding = Tile({2, 2}), - }, - { - .in_tensor = test::AsTensor( - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - TensorShape({4, 1, 4})), - .expected_out_tensors = - { - test::AsTensor({1, 2, 5, 6}, - TensorShape({2, 1, 2})), - test::AsTensor({3, 4, 7, 8}, - TensorShape({2, 1, 2})), - test::AsTensor({9, 10, 13, 14}, - TensorShape({2, 1, 2})), - test::AsTensor({11, 12, 15, 16}, - TensorShape({2, 1, 2})), - }, - .device_indices = {0, 1, 2, 3}, - .sharding = Tile({2, 1, 2}), - }, - // Partial replication - { - .in_tensor = test::AsTensor({1, 2, 3, 4}, - TensorShape({2, 2})), - .expected_out_tensors = - { - test::AsTensor({1, 3}, TensorShape({2, 1})), - test::AsTensor({1, 3}, TensorShape({2, 1})), - test::AsTensor({2, 4}, TensorShape({2, 1})), - test::AsTensor({2, 4}, TensorShape({2, 1})), - }, - .device_indices = {0, 1, 2, 3}, - .sharding = PartialTile({1, 2, 2}), - }, - { - .in_tensor = test::AsTensor({1, 2, 3, 4}, - TensorShape({2, 2})), - .expected_out_tensors = - { - test::AsTensor({1, 2}, TensorShape({1, 2})), - test::AsTensor({1, 2}, TensorShape({1, 2})), - test::AsTensor({3, 4}, TensorShape({1, 2})), - test::AsTensor({3, 4}, TensorShape({1, 2})), - }, - .device_indices = {0, 1, 2, 3}, - .sharding = PartialTile({2, 1, 2}), + .device_ids = {0}, + .sharding = Maximal(0), }, { - .in_tensor = test::AsTensor({1, 2, 3, 4}, - TensorShape({2, 2})), + .in_tensor = test::AsTensor({3}, TensorShape({})), .expected_out_tensors = { - test::AsTensor({1, 2}, TensorShape({1, 2})), - test::AsTensor({1, 2}, TensorShape({1, 2})), - test::AsTensor({3, 4}, TensorShape({1, 2})), - test::AsTensor({3, 4}, TensorShape({1, 2})), - }, - .device_indices = {3, 2, 1, 0}, - .sharding = PartialTile({2, 1, 2}), - }, - })); - -TEST_P(TensorToArrayTest, MakeArrayFromTensor) { - constexpr int kMaxParallelism = 16; - auto thread_pool = std::make_unique( - tsl::Env::Default(), tsl::ThreadOptions(), "Resharding", kMaxParallelism); - - Eigen::ThreadPoolDevice device(thread_pool->AsEigenThreadPool(), - kMaxParallelism); - - auto input_tensor = GetParam().in_tensor; - - // Create contexts required for the compiler execution. - TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr client, - xla::ifrt::test_util::GetClient()); - - TF_ASSERT_OK_AND_ASSIGN( - auto assembled_array, - MakeArrayFromTensor(*client, input_tensor, - absl::MakeSpan(GetParam().device_ids), - GetParam().sharding, device)); - - TF_ASSERT_OK_AND_ASSIGN(auto disassembled_arrays, - assembled_array->DisassembleIntoSingleDeviceArrays( - xla::ifrt::ArrayCopySemantics::kAlwaysCopy)); - - ASSERT_EQ(disassembled_arrays.size(), GetParam().expected_out_tensors.size()); - - for (int i = 0; i < disassembled_arrays.size(); ++i) { - SCOPED_TRACE(absl::StrCat("Array ", i, " of ", disassembled_arrays.size())); - auto disassembled_array = disassembled_arrays[i]; - auto expected_out_tensor = GetParam().expected_out_tensors[i]; - ASSERT_EQ(disassembled_array->shape(), - xla::ifrt::Shape(expected_out_tensor.shape().dim_sizes())); - tensorflow::Tensor host_tensor(expected_out_tensor.dtype(), - expected_out_tensor.shape()); - TF_ASSERT_OK( - disassembled_array - ->CopyToHostBuffer(host_tensor.data(), /*byte_strides=*/{}, - xla::ifrt::ArrayCopySemantics::kAlwaysCopy) - .Await()); - EXPECT_THAT(expected_out_tensor, TensorEq(host_tensor)); - } -} - -INSTANTIATE_TEST_SUITE_P( - TensorToArrayTests, TensorToArrayTest, - ::testing::ValuesIn( - { - // Single device - { - .in_tensor = test::AsTensor({1}, TensorShape({})), - .expected_out_tensors = - { - test::AsTensor({1}, TensorShape({})), + test::AsTensor({3}, TensorShape({})), }, - .device_ids = {0}, - .sharding = Replicate(), + .device_ids = {0, 1}, + .sharding = Maximal(1), }, // Full replication. { @@ -763,12 +560,11 @@ TEST(ShardingUtilsTest, MismatchRank) { xla::HloSharding sharding = Tile({2, 1}); - EXPECT_THAT( - MakeAssembledArrayFromHostBuffer( - *client, input_tensor, std::move(sharding), device_list, device), - StatusIs(absl::StatusCode::kInvalidArgument, - "shape must have 2 dimensions, but has 3 dimensions: " - "shape=[2,1,2], sharding={devices=[2,1]<=[2]}")); + EXPECT_THAT(MakeArrayFromTensor(*client, input_tensor, device_list, + std::move(sharding), device), + StatusIs(absl::StatusCode::kInvalidArgument, + "shape must have 2 dimensions, but has 3 dimensions: " + "shape=[2,1,2], sharding={devices=[2,1]<=[2]}")); } } // namespace diff --git a/tensorflow/core/tfrt/ifrt/testdata/executable_long_inputs.mlir b/tensorflow/core/tfrt/ifrt/testdata/executable_long_inputs.mlir new file mode 100644 index 00000000000000..529aa607276072 --- /dev/null +++ b/tensorflow/core/tfrt/ifrt/testdata/executable_long_inputs.mlir @@ -0,0 +1,8 @@ +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { + func.func @main(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>, %arg4: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>) attributes {__tpu_compile_metadata_text = "args { dtype: DT_INT32 kind: PARAMETER } args { dtype: DT_INT32 kind: PARAMETER } args { dtype: DT_INT32 kind: PARAMETER } args { dtype: DT_INT32 kind: PARAMETER } args { dtype: DT_INT32 kind: PARAMETER } retvals { } retvals { } retvals { } num_replicas: 1 num_cores_per_replica: 1"} { + %0 = "tf.MatMul"(%arg0, %arg1): (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + %1 = "tf.MatMul"(%arg2, %arg3): (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + %2 = "tf.MatMul"(%arg4, %arg3): (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + func.return %0, %1, %2 : tensor<*xi32>, tensor<*xi32>, tensor<*xi32> + } +} diff --git a/tensorflow/core/tfrt/kernels/BUILD b/tensorflow/core/tfrt/kernels/BUILD index 390bef2009b9b0..fa342c7a6bcb17 100644 --- a/tensorflow/core/tfrt/kernels/BUILD +++ b/tensorflow/core/tfrt/kernels/BUILD @@ -29,6 +29,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], alwayslink = 1, ) diff --git a/tensorflow/core/tfrt/kernels/ifrt_program_ops.cc b/tensorflow/core/tfrt/kernels/ifrt_program_ops.cc index 92ce3cad2c1e04..9f0ecea2708a79 100644 --- a/tensorflow/core/tfrt/kernels/ifrt_program_ops.cc +++ b/tensorflow/core/tfrt/kernels/ifrt_program_ops.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/tensor.h" @@ -34,6 +35,8 @@ namespace tfrt_stub { IfrtCallOp::IfrtCallOp(tensorflow::OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("program_id", &program_id_)); + OP_REQUIRES_OK(ctx, + ctx->GetAttr("variable_arg_indices", &variable_arg_indices_)); } void IfrtCallOp::Compute(tensorflow::OpKernelContext* ctx) { @@ -51,7 +54,8 @@ void IfrtCallOp::Compute(tensorflow::OpKernelContext* ctx) { inputs.push_back(ctx->input(i)); } - absl::StatusOr> results = executable_->Execute(inputs); + absl::StatusOr> results = + executable_->Execute(inputs, absl::MakeSpan(variable_arg_indices_)); OP_REQUIRES(ctx, results.ok(), results.status()); tensorflow::OpOutputList outputs(ctx, 0, results->size()); diff --git a/tensorflow/core/tfrt/kernels/ifrt_program_ops.h b/tensorflow/core/tfrt/kernels/ifrt_program_ops.h index 578ccae70b8e4b..31bb908519d405 100644 --- a/tensorflow/core/tfrt/kernels/ifrt_program_ops.h +++ b/tensorflow/core/tfrt/kernels/ifrt_program_ops.h @@ -41,6 +41,9 @@ class IfrtCallOp : public tensorflow::OpKernel { // Op attributes. int64_t program_id_; + std::vector variable_names_; + std::vector variable_arg_indices_; + // Ifrt program to be called. Cached after the first call. absl::once_flag init_once_; std::shared_ptr executable_; diff --git a/tensorflow/core/tfrt/mlrt/kernel/BUILD b/tensorflow/core/tfrt/mlrt/kernel/BUILD index 2e0d7cd7fe185f..b22a9fa16ffce0 100644 --- a/tensorflow/core/tfrt/mlrt/kernel/BUILD +++ b/tensorflow/core/tfrt/mlrt/kernel/BUILD @@ -51,7 +51,9 @@ cc_library( ":context", ":kernel", "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", "//tensorflow/core/platform:protobuf", + "//tensorflow/core/platform:refcount", "//tensorflow/core/tfrt/ifrt:ifrt_config_proto_cc", "//tensorflow/core/tfrt/ifrt:ifrt_model_context", "//tensorflow/core/tfrt/ifrt:sharding_utils", @@ -63,9 +65,13 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@local_tsl//tsl/concurrency:ref_count", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:tstring", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/ir:hlo", + "@local_xla//xla/python/ifrt", ], ) @@ -163,6 +169,8 @@ tf_cc_shared_test( ":kernel", "//tensorflow/core:core_cpu", "//tensorflow/core/framework:tensor", + "//tensorflow/core/framework:tensor_matcher", + "//tensorflow/core/framework:tensor_testutil", "//tensorflow/core/kernels:math", "//tensorflow/core/ops:math_ops_op_lib", "//tensorflow/core/platform:protobuf", @@ -179,6 +187,7 @@ tf_cc_shared_test( "//tensorflow/core/tfrt/mlrt/interpreter:interpreter_testutil", "//tensorflow/core/tfrt/mlrt/interpreter:value", "//tensorflow/core/tfrt/utils:fallback_tensor", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", @@ -190,6 +199,7 @@ tf_cc_shared_test( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:tstring", "@local_xla//xla/python/ifrt", "@local_xla//xla/python/ifrt:test_util", "@local_xla//xla/python/pjrt_ifrt:tfrt_cpu_client_test_lib", diff --git a/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel.cc b/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel.cc index 422f96447dab5b..9f58ca6fa6b8f8 100644 --- a/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel.cc +++ b/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include #include #include "absl/log/check.h" @@ -24,9 +25,14 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/python/ifrt/array.h" #include "xla/xla_data.pb.h" +#include "tensorflow/core/framework/resource_handle.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/resource_var.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/protobuf.h" // IWYU pragma: keep +#include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h" #include "tensorflow/core/tfrt/ifrt/ifrt_model_context.h" #include "tensorflow/core/tfrt/ifrt/sharding_utils.h" @@ -35,21 +41,24 @@ limitations under the License. #include "tensorflow/core/tfrt/mlrt/kernel/context.h" #include "tensorflow/core/tfrt/mlrt/kernel/kernel.h" #include "tensorflow/core/tfrt/utils/fallback_tensor.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" +#include "tsl/platform/tstring.h" namespace tensorflow { namespace tf_mlrt { namespace { -absl::Status IfrtLoadVariable( +absl::StatusOr> LoadIfrtVariable( tensorflow::ifrt_serving::IfrtModelContext& ifrt_model_context, const tensorflow::Tensor& variable, absl::string_view sharding_config_proto_text, absl::string_view name) { tensorflow::ifrt_serving::VariableDeviceShardingConfigProto sharding_config; if (!tensorflow::protobuf::TextFormat::ParseFromString( - std::string(sharding_config_proto_text), &sharding_config)) { + sharding_config_proto_text, &sharding_config)) { return absl::InvalidArgumentError(absl::StrCat( "Attribute: ", sharding_config_proto_text, " cannot be parsed")); } @@ -64,7 +73,7 @@ absl::Status IfrtLoadVariable( *ifrt_model_context.GetClient(), variable, absl::MakeSpan(device_ids), hlo_sharding, ifrt_model_context.GetThreadPoolDevice())); - return ifrt_model_context.RegisterLoadedVariable(name, result_array); + return result_array; } struct MlrtIfrtLoadVariableKernel : mlrt::KernelFrame { @@ -72,11 +81,14 @@ struct MlrtIfrtLoadVariableKernel : mlrt::KernelFrame { static constexpr char kName[] = "tf_mlrt.ifrt_load_variable"; - const tensorflow::Tensor& variable() const { + const ResourceHandle& variable() const { DCHECK_GE(arguments().size(), 1); - return arguments()[0].Get().tensor(); - } + const auto& tensor = + arguments()[0].Get().tensor(); + DCHECK_EQ(tensor.NumElements(), 1); + return tensor.scalar()(); + } absl::string_view sharding_config_proto_text() const { DCHECK_EQ(attributes().size(), 2); return attributes().GetAs(0).Get(); @@ -91,6 +103,7 @@ struct MlrtIfrtLoadVariableKernel : mlrt::KernelFrame { }; void MlrtIfrtLoadVariableKernel::Invoke() { + DCHECK_EQ(1, results().size()); std::optional ifrt_model_context = context() @@ -103,12 +116,30 @@ void MlrtIfrtLoadVariableKernel::Invoke() { return; } - auto status = IfrtLoadVariable(**ifrt_model_context, variable(), - sharding_config_proto_text(), name()); + auto status = + (*ifrt_model_context) + ->GetLoadedVariableRegistry() + .TryRegisterLoadedVariable( + name(), + [&]() -> absl::StatusOr> { + core::RefCountPtr variable_resource; + TF_RETURN_IF_ERROR( + LookupResource(&context().op_kernel_context(), variable(), + &variable_resource)); + + return LoadIfrtVariable(**ifrt_model_context, + *(variable_resource->tensor()), + sharding_config_proto_text(), name()); + }); if (!status.ok()) { - execution_context().Fail(status); + execution_context().Fail(std::move(status)); return; } + + // Return the name as the key + tensorflow::Tensor key_tensor(tensorflow::DT_STRING, {}); + key_tensor.scalar()() = std::string(name()); + results()[0].Set(tensorflow::tfrt_stub::FallbackTensor(key_tensor)); } void RegisterTfMlrtIfrtKernels(mlrt::KernelRegistry& registry) { registry.Register(); diff --git a/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel_test.cc b/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel_test.cc index 5f8471159a333d..1333ab67e59f04 100644 --- a/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel_test.cc +++ b/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel_test.cc @@ -20,11 +20,13 @@ limitations under the License. #include #include #include + // Enable definition of Eigen::ThreadPoolDevice instead of just declaration. #define EIGEN_USE_THREADS #include #include +#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/synchronization/notification.h" @@ -33,6 +35,7 @@ limitations under the License. #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/test_util.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/platform/protobuf.h" // IWYU pragma: keep #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h" @@ -56,12 +59,15 @@ limitations under the License. #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/threadpool.h" +#include "tsl/platform/tstring.h" #include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime #include "tfrt/host_context/resource_context.h" // from @tf_runtime namespace tensorflow { namespace tf_mlrt { namespace { +using tensorflow::test::AsScalar; +using tensorflow::test::ExpectEqual; static absl::string_view kVariableName = "test_variable"; @@ -76,22 +82,24 @@ Eigen::ThreadPoolDevice GetThreadPoolDevice() { kMaxParallelism); } -mlrt::bc::Buffer CreateExecutableForIfrtLoadVariableOp() { +mlrt::bc::Buffer CreateExecutableForIfrtLoadVariableOp( + bool redundant_ifrt_load_variable_op = false) { mlrt::bc::Buffer buffer; mlrt::bc::Allocator allocator(&buffer); auto executable_ctor = mlrt::bc::New(&allocator); mlrt::testing::SymbolTable kernels; - std::vector kernel_names = {"tf_mlrt.ifrt_load_variable", - "return"}; + std::vector kernel_names = { + "tf_mlrt.createop", "tf_mlrt.executeop", "tf_mlrt.ifrt_load_variable", + "return"}; executable_ctor.construct_kernel_names(kernel_names.size()) .Assign(kernel_names); kernels.Def(kernel_names); mlrt::testing::AttributeTable attributes( - executable_ctor.construct_attributes(2)); + executable_ctor.construct_attributes(6)); tensorflow::ifrt_serving::VariableDeviceShardingConfigProto sharding_config; sharding_config.add_device_ids(0); @@ -103,6 +111,44 @@ mlrt::bc::Buffer CreateExecutableForIfrtLoadVariableOp() { attributes.Add("sharding_config", serialized_sharding_config); attributes.Add("variable_name", kVariableName); + attributes.Add("var_handle_op_node_def", + R"pb(name: "VarHandleOp" + op: "VarHandleOp" + device: "/job:localhost/replica:0/task:0/device:CPU:0" + attr { + key: "container" + value { s: "test" } + } + attr { + key: "shared_name" + value { s: "y" } + } + attr { + key: "dtype" + value { type: DT_INT32 } + } + attr { + key: "shape" + value { shape { dim { size: 1 } } } + } + )pb"); + + attributes.Add("var_handle_op_key", 0); + + attributes.Add("assign_variable_op_node_def", + R"pb(name: "AssignVariableOp" + op: "AssignVariableOp" + input: "dummy_resource" + input: "dummy_tensor" + device: "/job:localhost/replica:0/task:0/device:CPU:0" + attr { + key: "dtype" + value { type: DT_INT32 } + } + )pb"); + + attributes.Add("assign_variable_op_key", 1); + auto functions_ctor = executable_ctor.construct_functions(1); { @@ -112,23 +158,90 @@ mlrt::bc::Buffer CreateExecutableForIfrtLoadVariableOp() { mlrt::testing::SymbolTable regs; function_ctor.construct_input_regs(1).Assign({regs.Def("input_tensor")}); + function_ctor.construct_output_regs(1).Assign({regs.Def("output_tensor")}); + + const int kNumKernels = 6 + (redundant_ifrt_load_variable_op ? 1 : 0); + auto kernels_ctor = function_ctor.construct_kernels(kNumKernels); + int kernel_index = 0; + + { + // Create VarHandleOp + auto createop_ctor = kernels_ctor.ConstructAt(kernel_index); + createop_ctor.set_code(kernels.Use("tf_mlrt.createop")); + createop_ctor.construct_arguments(0); + createop_ctor.construct_results(0); + createop_ctor.construct_attributes(2).Assign( + {attributes.GetHandle("var_handle_op_node_def"), + attributes.GetHandle("var_handle_op_key")}); + kernel_index++; + } + { + // Create AssignVariableOp + auto createop_ctor = kernels_ctor.ConstructAt(kernel_index); + createop_ctor.set_code(kernels.Use("tf_mlrt.createop")); + createop_ctor.construct_arguments(0); + createop_ctor.construct_results(0); + createop_ctor.construct_attributes(2).Assign( + {attributes.GetHandle("assign_variable_op_node_def"), + attributes.GetHandle("assign_variable_op_key")}); + kernel_index++; + } + { + // Execute VarHandleOp + auto executeop_ctor = kernels_ctor.ConstructAt(kernel_index); + executeop_ctor.set_code(kernels.Use("tf_mlrt.executeop")); + executeop_ctor.construct_arguments(0); + executeop_ctor.construct_results(1).Assign({regs.Def("variable_handle")}); + executeop_ctor.construct_attributes(2).Assign( + {attributes.GetHandle("var_handle_op_node_def"), + attributes.GetHandle("var_handle_op_key")}); + executeop_ctor.construct_last_uses(1).Assign({0}); + kernel_index++; + } - auto kernels_ctor = function_ctor.construct_kernels(2); + { + // Execute AssignVariableOp + auto executeop_ctor = kernels_ctor.ConstructAt(kernel_index); + executeop_ctor.set_code(kernels.Use("tf_mlrt.executeop")); + executeop_ctor.construct_arguments(2).Assign( + regs.Use({"variable_handle", "input_tensor"})); + executeop_ctor.construct_results(0); + executeop_ctor.construct_attributes(2).Assign( + {attributes.GetHandle("assign_variable_op_node_def"), + attributes.GetHandle("assign_variable_op_key")}); + executeop_ctor.construct_last_uses(2).Assign({0, 0}); + kernel_index++; + } { - auto kernel_ctor = kernels_ctor.ConstructAt(0); + auto kernel_ctor = kernels_ctor.ConstructAt(kernel_index); kernel_ctor.set_code(kernels.Use("tf_mlrt.ifrt_load_variable")); + kernel_ctor.construct_results(1).Assign({regs.Use("output_tensor")}); + kernel_ctor.construct_arguments(1).Assign({regs.Use("variable_handle")}); + kernel_ctor.construct_attributes(2).Assign( + {attributes.GetHandle("sharding_config"), + attributes.GetHandle("variable_name")}); + kernel_ctor.construct_last_uses(1).Assign({1}); + kernel_index++; + } + if (redundant_ifrt_load_variable_op) { + auto kernel_ctor = kernels_ctor.ConstructAt(kernel_index); + kernel_ctor.set_code(kernels.Use("tf_mlrt.ifrt_load_variable")); + kernel_ctor.construct_results(1).Assign({regs.Def("dummy")}); kernel_ctor.construct_attributes(2).Assign( {attributes.GetHandle("sharding_config"), attributes.GetHandle("variable_name")}); kernel_ctor.construct_arguments(1).Assign({regs.Use("input_tensor")}); kernel_ctor.construct_last_uses(1).Assign({1}); + kernel_index++; } - { - auto kernel_ctor = kernels_ctor.ConstructAt(1); + auto kernel_ctor = kernels_ctor.ConstructAt(kernel_index); kernel_ctor.set_code(kernels.Use("return")); + kernel_ctor.construct_arguments(1).Assign({regs.Use("output_tensor")}); + kernel_index++; } + DCHECK_EQ(kernel_index, kNumKernels); function_ctor.set_num_regs(regs.size()); } @@ -186,7 +299,99 @@ TEST(KernelTest, IfrtLoadVariableOp) { "IfrtModelContext"); ASSERT_TRUE(ifrt_model_context.has_value()); - EXPECT_THAT((*ifrt_model_context)->GetLoadedVariable(kVariableName).status(), + EXPECT_THAT((*ifrt_model_context) + ->GetLoadedVariableRegistry() + .GetLoadedVariable(kVariableName) + .status(), + ::tsl::testing::StatusIs(absl::StatusCode::kNotFound)); + + std::vector args; + args.resize(1); + tensorflow::Tensor input_tensor; + TF_CHECK_OK(tensorflow::Tensor::BuildTensor(DT_INT32, {}, &input_tensor)); + input_tensor.scalar()() = 1234; + args.at(0).Set(tfrt_stub::FallbackTensor(std::move(input_tensor))); + + std::vector last_uses = {true}; + std::vector results; + results.resize(1); + + absl::Notification notification; + execution_context.set_exit_handler( + [¬ification]() { notification.Notify(); }); + + execution_context.Call(executable.functions()[0], last_uses, + absl::MakeSpan(args), absl::MakeSpan(results)); + mlrt::Execute(execution_context); + + notification.WaitForNotification(); + + TF_ASSERT_OK(execution_context.status()); + + TF_ASSERT_OK((*ifrt_model_context) + ->GetLoadedVariableRegistry() + .GetLoadedVariable(kVariableName) + .status()); + + ExpectEqual(results[0].Get().tensor(), + AsScalar(tsl::tstring(kVariableName))); +} + +TEST(KernelTest, DuplicateIfrtLoadVariableOpShallSucceed) { + auto buffer = CreateExecutableForIfrtLoadVariableOp( + /*redundant_ifrt_load_variable_op=*/true); + + mlrt::bc::Executable executable(buffer.data()); + + mlrt::KernelRegistry registry; + mlrt::RegisterBuiltinKernels(registry); + RegisterTfMlrtKernels(registry); + + mlrt::LoadedExecutable loaded_executable(executable, registry); + + auto work_queue = tfrt::CreateMultiThreadedWorkQueue( + /*num_threads=*/4, /*num_blocking_threads=*/4); + mlrt::ExecutionContext execution_context(&loaded_executable); + execution_context.set_work_queue(work_queue.get()); + + tensorflow::SessionOptions session_options; + tensorflow::FunctionDefLibrary fdef_lib; + TF_ASSERT_OK_AND_ASSIGN(auto fallback_state, tfrt_stub::FallbackState::Create( + session_options, fdef_lib)); + + std::function)> runner = + [](const std::function& f) { f(); }; + tfrt_stub::OpKernelRunnerTable runner_table; + tfd::FallbackResourceArray resource_array; + tfd::KernelFallbackCompatRequestState fallback_request_state( + &runner, &fallback_state->device_manager(), /*step_id=*/0, &runner_table, + &resource_array, /*user_intra_op_threadpool=*/nullptr, + /*model_metadata=*/std::nullopt, + &fallback_state->process_function_library_runtime()); + + tfrt::ResourceContext resource_context; + + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr client, + xla::ifrt::test_util::GetClient()); + Eigen::ThreadPoolDevice thread_pool_device = GetThreadPoolDevice(); + resource_context.CreateResource( + "IfrtModelContext", client, &thread_pool_device); + + auto tf_context = + std::make_unique(&fallback_request_state, &resource_context); + execution_context.AddUserContext(std::move(tf_context)); + + std::optional + ifrt_model_context = + resource_context + .GetResource( + "IfrtModelContext"); + + ASSERT_TRUE(ifrt_model_context.has_value()); + EXPECT_THAT((*ifrt_model_context) + ->GetLoadedVariableRegistry() + .GetLoadedVariable(kVariableName) + .status(), ::tsl::testing::StatusIs(absl::StatusCode::kNotFound)); std::vector args; @@ -212,8 +417,13 @@ TEST(KernelTest, IfrtLoadVariableOp) { TF_ASSERT_OK(execution_context.status()); - TF_ASSERT_OK( - (*ifrt_model_context)->GetLoadedVariable(kVariableName).status()); + TF_ASSERT_OK((*ifrt_model_context) + ->GetLoadedVariableRegistry() + .GetLoadedVariable(kVariableName) + .status()); + + ExpectEqual(results[0].Get().tensor(), + AsScalar(tsl::tstring(kVariableName))); } } // namespace diff --git a/tensorflow/core/tfrt/ops/ifrt_program_ops.cc b/tensorflow/core/tfrt/ops/ifrt_program_ops.cc index ab8e14b2e41eac..ba60bba95ecd83 100644 --- a/tensorflow/core/tfrt/ops/ifrt_program_ops.cc +++ b/tensorflow/core/tfrt/ops/ifrt_program_ops.cc @@ -25,6 +25,7 @@ REGISTER_OP("IfrtCall") .Attr("Tin: list(type) >= 0") .Attr("Tout: list(type) >= 0") .Attr("program_id: int") + .Attr("variable_arg_indices: list(int)") .SetIsStateful() .SetShapeFn(tensorflow::shape_inference::UnknownShape) .Doc(R"( @@ -39,7 +40,34 @@ in their SavedModel and instead rely on Ifrt Serving's mechanism that automatically inserts this op with graph rewrite. program_id: int64 id that can be used to look up compiled programs from - `ServingExecutableRegistry`. +ServingExecutableRegistry`. + +variable_arg_indices: must be in sorted ascending order. The argument at position +variable_arg_indices[k] in tpu program is already loaded as an ifrt array and +the input `args[variable_arg_indices[k]]` is the key to look for this loaded array. +)"); + +REGISTER_OP("IfrtLoadVariable") + .Input("variable: Tin") + .Output("array_key: Tout") + .Attr("Tin: type") + .Attr("Tout: type") + .Attr("config: string") + .Attr("name: string") + .SetIsStateful() + .SetShapeFn(tensorflow::shape_inference::UnknownShape) + .Doc(R"( +Converts the given tensor to a named array. + +This op loads the `variable` tensor to an IFRT device array based the sharding +spec in a `config` and the array can be looked up by `name` by the runtime. +The `config` is a text proto of `IfrtVariableDeviceShardingConfigProto`. +The `name` is typically a concatenation of `container` and `shared_name` from `tf.VarHandle`. +The idea is to avoid transferring to device repeatedly. + +Note that this op is not part of a stable interface. Users must not use this op +in their SavedModel and instead rely on Ifrt Serving's mechanism that +automatically inserts this op with graph rewrite. )"); } // namespace tfrt_stub diff --git a/tensorflow/core/tfrt/run_handler_thread_pool/BUILD b/tensorflow/core/tfrt/run_handler_thread_pool/BUILD index ff4c3bf8994505..3acb0edd074e96 100644 --- a/tensorflow/core/tfrt/run_handler_thread_pool/BUILD +++ b/tensorflow/core/tfrt/run_handler_thread_pool/BUILD @@ -15,7 +15,7 @@ package_group( # copybara:uncomment "//learning/serving/...", # copybara:uncomment "//learning/brain/experimental/tfrt/...", # copybara:uncomment "//learning/brain/tfrt/cpp_tests/tpu_model/...", - # copybara:uncomment "//learning/brain/tfrt/tfrt_session/...", + # copybara:uncomment "//learning/brain/tfrt/...", "//tensorflow/core/tfrt/...", "//tensorflow/core/runtime_fallback/runtime/...", "//tensorflow_serving/...", @@ -48,7 +48,6 @@ cc_library( hdrs = ["run_handler.h"], deps = [ ":run_handler_util", - "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", diff --git a/tensorflow/core/tfrt/saved_model/saved_model.cc b/tensorflow/core/tfrt/saved_model/saved_model.cc index 1714c87b4ee100..d103e4160fc01c 100644 --- a/tensorflow/core/tfrt/saved_model/saved_model.cc +++ b/tensorflow/core/tfrt/saved_model/saved_model.cc @@ -50,6 +50,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h" #include "xla/status_macros.h" #include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/metrics.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/types.h" @@ -105,11 +106,6 @@ namespace { constexpr absl::string_view kSignatureJoiningDelimiter = "+"; -auto* saved_model_mla_check_time_milli_seconds = - tensorflow::monitoring::Gauge::New( - "/tensorflow/tfrt/saved_model/mla_check_time", - "Record the MLA check time for the savedmodel.", "model_name"); - auto* saved_model_import_time_seconds = tensorflow::monitoring::Gauge::New( "/tensorflow/tfrt/saved_model/import_time", @@ -478,7 +474,9 @@ SavedModelImpl::LoadSavedModel(Options options, LOG(INFO) << "Found AOT package. Register required dialects."; RegisterTfrtDialectsForAot(registry); } - RegisterMlirDialect(registry); + RegisterMlirDialect( + registry, + options.graph_execution_options.compile_options.backend_compiler); mlir::MLIRContext context(registry); // Step 1: Import saved model from a proto to an MLIR module. @@ -592,6 +590,7 @@ SavedModelImpl::LoadSavedModel(Options options, bef, LoadBefAndMlir(options.graph_execution_options.compile_options, mlir_module.get(), saved_model_dir_string, fallback_state.get())); + metrics::UpdateAotBefMlirLoadCount(); } } else { @@ -1016,7 +1015,8 @@ StatusOr> SavedModelImpl::LoadJoinedSignature(const JoinedSignature& joined_signature) { // Step 1: Import the combined subgraph from proto to an MLIR module. mlir::DialectRegistry registry; - RegisterMlirDialect(registry); + RegisterMlirDialect( + registry, graph_executor_->options().compile_options.backend_compiler); mlir::MLIRContext context(registry); ASSIGN_OR_RETURN_IN_IMPORT(auto module, diff --git a/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc b/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc index 759ec2e5b2dccc..6547d2172d3cad 100644 --- a/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc +++ b/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc @@ -204,7 +204,9 @@ StatusOr AotCompileSavedModel(absl::string_view input_model_dir, meta_graph_def.graph_def()); UpdateCompileOptions(aot_options); mlir::DialectRegistry registry; - RegisterMlirDialect(registry); + RegisterMlirDialect( + registry, + aot_options.graph_execution_options->compile_options.backend_compiler); mlir::MLIRContext context(registry); tensorflow::SessionOptions session_options = diff --git a/tensorflow/core/tfrt/saved_model/tests/BUILD b/tensorflow/core/tfrt/saved_model/tests/BUILD index 9496180fe1aa03..09f1e7deaa8d10 100644 --- a/tensorflow/core/tfrt/saved_model/tests/BUILD +++ b/tensorflow/core/tfrt/saved_model/tests/BUILD @@ -28,6 +28,30 @@ alias( actual = if_google("//learning/brain/public:disable_tf2", ":empty"), ) +filegroup( + name = "testdata", + srcs = glob([ + "testdata/**", + ]) + [ + ":saved_model_gen_control_flow_v1", + ":saved_model_gen_data", + ":saved_model_gen_dtype_coverage_v1", + ":saved_model_gen_error_v1", + ":saved_model_gen_hash_table_asset_v1", + ":saved_model_gen_if_v1", + ":saved_model_gen_matmul_gpu", + ":saved_model_gen_pow", + ":saved_model_gen_pow_v2", + ":saved_model_gen_ref_type_tensor_input", + ":saved_model_gen_resource_gather_v1", + ":saved_model_gen_sparse_tensor_input", + ":saved_model_gen_toy_v1", + ":saved_model_gen_toy_v2", + ":saved_model_gen_variable_on_tpu", + ":saved_model_gen_while_v1", + ], +) + pytype_strict_binary( name = "gen_resource_gather_v1", srcs = ["gen_resource_gather_v1.py"], @@ -433,6 +457,7 @@ gen_saved_model( gen_saved_model( model_name = "toy_v1", script = "gen_saved_model_v1", + version = "1", ) gen_saved_model( @@ -569,36 +594,7 @@ cc_library( testonly = 1, srcs = ["saved_model_test.cc"], data = [ - "control_flow_v1/saved_model.pb", - "data/saved_model.pb", - "dtype_coverage_v1/saved_model.pb", - "dtype_coverage_v1/variables/variables.data-00000-of-00001", - "dtype_coverage_v1/variables/variables.index", - "error_v1/saved_model.pb", - "hash_table_asset_v1/assets/tokens.txt", - "hash_table_asset_v1/saved_model.pb", - "if_v1/saved_model.pb", - "if_v1/variables/variables.data-00000-of-00001", - "if_v1/variables/variables.index", - "pow/saved_model.pb", - "pow_v2/saved_model.pb", - "ref_type_tensor_input/saved_model.pb", - "ref_type_tensor_input/variables/variables.data-00000-of-00001", - "ref_type_tensor_input/variables/variables.index", - "resource_gather_v1/saved_model.pb", - "resource_gather_v1/variables/variables.data-00000-of-00001", - "resource_gather_v1/variables/variables.index", - "sparse_tensor_input/saved_model.pb", - "toy_v1/saved_model.pb", - "toy_v1/variables/variables.data-00000-of-00001", - "toy_v1/variables/variables.index", - "toy_v2/saved_model.pb", - "toy_v2/variables/variables.data-00000-of-00001", - "toy_v2/variables/variables.index", - "variable_on_tpu/saved_model.pb", - "variable_on_tpu/variables/variables.data-00000-of-00001", - "variable_on_tpu/variables/variables.index", - "while_v1/saved_model.pb", + ":testdata", ], tags = ["no_oss"], deps = [ @@ -624,9 +620,7 @@ cc_library( testonly = 1, srcs = ["saved_model_ifrt_test.cc"], data = [ - "toy_v2/saved_model.pb", - "toy_v2/variables/variables.data-00000-of-00001", - "toy_v2/variables/variables.index", + ":testdata", ], tags = ["no_oss"], deps = [ @@ -704,9 +698,7 @@ tf_cuda_cc_test( name = "saved_model_gpu_test", srcs = ["saved_model_gpu_test.cc"], data = [ - "matmul_gpu/saved_model.pb", - "matmul_gpu/variables/variables.data-00000-of-00001", - "matmul_gpu/variables/variables.index", + ":testdata", ], tags = ["no_oss"], deps = [ diff --git a/tensorflow/core/tfrt/saved_model/tests/gen_saved_model.bzl b/tensorflow/core/tfrt/saved_model/tests/gen_saved_model.bzl index f3ed254c39689e..227fdc6c5ebc6d 100644 --- a/tensorflow/core/tfrt/saved_model/tests/gen_saved_model.bzl +++ b/tensorflow/core/tfrt/saved_model/tests/gen_saved_model.bzl @@ -2,17 +2,21 @@ load("//tensorflow:tensorflow.bzl", "if_google") -def gen_saved_model(model_name = "", script = "", **kwargs): +def gen_saved_model(model_name = "", script = "", version = "", **kwargs): + model_path = model_name + if version != "": + model_path = model_name + "/" + version + native.genrule( name = "saved_model_gen_" + model_name, srcs = [], outs = [ - model_name + "/saved_model.pb", - model_name + "/variables/variables.data-00000-of-00001", - model_name + "/variables/variables.index", + model_path + "/saved_model.pb", + model_path + "/variables/variables.data-00000-of-00001", + model_path + "/variables/variables.index", ], cmd = if_google( - "$(location " + script + ") --saved_model_path=$(RULEDIR)/" + model_name, + "$(location " + script + ") --saved_model_path=$(RULEDIR)/" + model_path, "touch $(OUTS)", # TODO(b/188517768): fix model gen. ), tools = [script], diff --git a/tensorflow/core/tfrt/saved_model/tests/gen_saved_model_v1.py b/tensorflow/core/tfrt/saved_model/tests/gen_saved_model_v1.py index 714dc5168cbb9b..3a67303c3c7161 100644 --- a/tensorflow/core/tfrt/saved_model/tests/gen_saved_model_v1.py +++ b/tensorflow/core/tfrt/saved_model/tests/gen_saved_model_v1.py @@ -19,7 +19,6 @@ from absl import flags from tensorflow.python.client import session from tensorflow.python.framework import dtypes -from tensorflow.python.framework import test_ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope @@ -58,9 +57,6 @@ def main(argv): r32 = math_ops.add(x3, r31, name='result32') r33 = math_ops.add(x3, r32, name='result33') - # Sleep for 1 second. - sleep_op = test_ops.sleep_identity_op(1, x1, name='sleep') - sess = session.Session() sess.run(variables.global_variables_initializer()) @@ -75,46 +71,41 @@ def main(argv): tensor_info_r31 = utils.build_tensor_info(r31) tensor_info_r32 = utils.build_tensor_info(r32) tensor_info_r33 = utils.build_tensor_info(r33) - tensor_info_sleep = utils.build_tensor_info(sleep_op) - - toy_signature = ( - signature_def_utils.build_signature_def( - inputs={'x1': tensor_info_x1}, - outputs={'r1': tensor_info_r1}, - method_name=signature_constants.PREDICT_METHOD_NAME)) - another_toy_signature = ( - signature_def_utils.build_signature_def( - inputs={'x2': tensor_info_x2}, - outputs={ - 'r21': tensor_info_r21, - 'r22': tensor_info_r22, - }, - method_name=signature_constants.PREDICT_METHOD_NAME)) - yet_another_toy_signature = ( - signature_def_utils.build_signature_def( - inputs={'x3': tensor_info_x3}, - outputs={ - 'r31': tensor_info_r31, - 'r32': tensor_info_r32, - 'r33': tensor_info_r33, - }, - method_name=signature_constants.PREDICT_METHOD_NAME)) - sleep_signature = ( - signature_def_utils.build_signature_def( - inputs={'x1': tensor_info_x1}, - outputs={'sleep': tensor_info_sleep}, - method_name=signature_constants.PREDICT_METHOD_NAME)) + + toy_signature = signature_def_utils.build_signature_def( + inputs={'x1': tensor_info_x1}, + outputs={'r1': tensor_info_r1}, + method_name=signature_constants.PREDICT_METHOD_NAME, + ) + another_toy_signature = signature_def_utils.build_signature_def( + inputs={'x2': tensor_info_x2}, + outputs={ + 'r21': tensor_info_r21, + 'r22': tensor_info_r22, + }, + method_name=signature_constants.PREDICT_METHOD_NAME, + ) + yet_another_toy_signature = signature_def_utils.build_signature_def( + inputs={'x3': tensor_info_x3}, + outputs={ + 'r31': tensor_info_r31, + 'r32': tensor_info_r32, + 'r33': tensor_info_r33, + }, + method_name=signature_constants.PREDICT_METHOD_NAME, + ) sm_builder.add_meta_graph_and_variables( - sess, [tag_constants.SERVING], + sess, + [tag_constants.SERVING], signature_def_map={ 'toy': toy_signature, 'another_toy': another_toy_signature, 'yet_another_toy': yet_another_toy_signature, - 'sleep': sleep_signature, signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: toy_signature, }, - strip_default_attrs=True) + strip_default_attrs=True, + ) sm_builder.save() diff --git a/tensorflow/core/tfrt/saved_model/tests/saved_model_test.cc b/tensorflow/core/tfrt/saved_model/tests/saved_model_test.cc index 70f83330f715ff..07ff5edbae1008 100644 --- a/tensorflow/core/tfrt/saved_model/tests/saved_model_test.cc +++ b/tensorflow/core/tfrt/saved_model/tests/saved_model_test.cc @@ -51,7 +51,7 @@ TEST_P(SavedModelTest, BasicV1) { // y = tf.compat.v1.get_variable(name='y', initializer=[1, 2, 3]) // r = tf.matmul(x, y) std::string saved_model_dir = tensorflow::GetDataDependencyFilepath( - "tensorflow/core/tfrt/saved_model/tests/toy_v1"); + "tensorflow/core/tfrt/saved_model/tests/toy_v1/1"); auto runtime = DefaultTfrtRuntime(/*num_threads=*/1); auto options = DefaultSavedModelOptions(runtime.get()); @@ -126,7 +126,7 @@ TEST_P(SavedModelTest, OnlineCostAnalysis) { // y = tf.compat.v1.get_variable(name='y', initializer=[1, 2, 3]) // r = tf.matmul(x, y) std::string saved_model_dir = tensorflow::GetDataDependencyFilepath( - "tensorflow/core/tfrt/saved_model/tests/toy_v1"); + "tensorflow/core/tfrt/saved_model/tests/toy_v1/1"); auto runtime = DefaultTfrtRuntime(/*num_threads=*/1); auto options = DefaultSavedModelOptions(runtime.get()); @@ -258,7 +258,7 @@ TEST(SavedModelTest, RunMultipleSignatures) { // y = tf.compat.v1.get_variable(name='y', initializer=[1, 2, 3]) // r = tf.matmul(x, y) std::string saved_model_dir = tensorflow::GetDataDependencyFilepath( - "tensorflow/core/tfrt/saved_model/tests/toy_v1"); + "tensorflow/core/tfrt/saved_model/tests/toy_v1/1"); auto runtime = DefaultTfrtRuntime(/*num_threads=*/1); auto options = DefaultSavedModelOptions(runtime.get()); @@ -370,7 +370,7 @@ TEST(SavedModelTest, RunMultipleSignatures_OverlappingNodes) { // y = tf.compat.v1.get_variable(name='y', initializer=[1, 2, 3]) // r = tf.matmul(x, y) std::string saved_model_dir = tensorflow::GetDataDependencyFilepath( - "tensorflow/core/tfrt/saved_model/tests/toy_v1"); + "tensorflow/core/tfrt/saved_model/tests/toy_v1/1"); auto runtime = DefaultTfrtRuntime(/*num_threads=*/1); auto options = DefaultSavedModelOptions(runtime.get()); @@ -435,7 +435,7 @@ class SavedModelRunByTensorNamesTest : public ::testing::Test { // y = tf.compat.v1.get_variable(name='y', initializer=[1, 2, 3]) // r = tf.matmul(x, y) auto saved_model_dir = tensorflow::GetDataDependencyFilepath( - "tensorflow/core/tfrt/saved_model/tests/toy_v1"); + "tensorflow/core/tfrt/saved_model/tests/toy_v1/1"); runtime_ = DefaultTfrtRuntime(/*num_threads=*/1); auto options = DefaultSavedModelOptions(runtime_.get()); @@ -537,7 +537,7 @@ TEST(SavedModelTest, CustomWorkQueue) { // y = tf.compat.v1.get_variable(name='y', initializer=[1, 2, 3]) // r = tf.matmul(x, y) std::string saved_model_dir = tensorflow::GetDataDependencyFilepath( - "tensorflow/core/tfrt/saved_model/tests/toy_v1"); + "tensorflow/core/tfrt/saved_model/tests/toy_v1/1"); tfrt::tf::RunHandlerThreadWorkQueue::Options queue_options; queue_options.num_complementary_threads = 1; @@ -583,7 +583,7 @@ TEST(SavedModelTest, RunOptionsWorkQueue) { // y = tf.compat.v1.get_variable(name='y', initializer=[1, 2, 3]) // r = tf.matmul(x, y) std::string saved_model_dir = tensorflow::GetDataDependencyFilepath( - "tensorflow/core/tfrt/saved_model/tests/toy_v1"); + "tensorflow/core/tfrt/saved_model/tests/toy_v1/1"); auto runtime = tensorflow::tfrt_stub::Runtime::Create(/*num_inter_op_threads=*/4); @@ -633,7 +633,7 @@ TEST(SavedModelTest, FunctionMetadata) { // y = tf.compat.v1.get_variable(name='y', initializer=[1, 2, 3]) // r = tf.matmul(x, y) std::string saved_model_dir = tensorflow::GetDataDependencyFilepath( - "tensorflow/core/tfrt/saved_model/tests/toy_v1"); + "tensorflow/core/tfrt/saved_model/tests/toy_v1/1"); TFRTSavedModelTest test(saved_model_dir); auto* saved_model = test.GetSavedModel(); @@ -989,7 +989,7 @@ TEST(SavedModelTest, DeadlineExceeded) { // y = tf.compat.v1.get_variable(name='y', initializer=[1, 2, 3]) // r = tf.matmul(x, y) std::string saved_model_dir = tensorflow::GetDataDependencyFilepath( - "tensorflow/core/tfrt/saved_model/tests/toy_v1"); + "tensorflow/core/tfrt/saved_model/tests/toy_v1/1"); auto runtime = DefaultTfrtRuntime(/*num_threads=*/1); auto options = DefaultSavedModelOptions(runtime.get()); @@ -1021,7 +1021,7 @@ TEST(SavedModelTest, DisableCompilation) { // y = tf.compat.v1.get_variable(name='y', initializer=[1, 2, 3]) // r = tf.matmul(x, y) std::string saved_model_dir = tensorflow::GetDataDependencyFilepath( - "tensorflow/core/tfrt/saved_model/tests/toy_v1"); + "tensorflow/core/tfrt/saved_model/tests/toy_v1/1"); auto runtime = DefaultTfrtRuntime(/*num_threads=*/1); auto options = DefaultSavedModelOptions(runtime.get()); @@ -1060,7 +1060,7 @@ TEST(SavedModelTest, CustomModelConfig) { // y = tf.compat.v1.get_variable(name='y', initializer=[1, 2, 3]) // r = tf.matmul(x, y) std::string saved_model_dir = tensorflow::GetDataDependencyFilepath( - "tensorflow/core/tfrt/saved_model/tests/toy_v1"); + "tensorflow/core/tfrt/saved_model/tests/toy_v1/1"); auto runtime = DefaultTfrtRuntime(/*num_threads=*/1); @@ -1124,7 +1124,7 @@ class TestCompiler : public BackendCompiler { TEST(SavedModelTest, CustomCompiler) { std::string saved_model_dir = tensorflow::GetDataDependencyFilepath( - "tensorflow/core/tfrt/saved_model/tests/toy_v1"); + "tensorflow/core/tfrt/saved_model/tests/toy_v1/1"); auto runtime = DefaultTfrtRuntime(/*num_threads=*/1); diff --git a/tensorflow/core/tfrt/tfrt_session/BUILD b/tensorflow/core/tfrt/tfrt_session/BUILD index e7a18cb4306744..b4157cacbd7785 100644 --- a/tensorflow/core/tfrt/tfrt_session/BUILD +++ b/tensorflow/core/tfrt/tfrt_session/BUILD @@ -78,9 +78,9 @@ tf_cc_shared_test( size = "small", srcs = ["tfrt_session_test.cc"], data = [ - "//tensorflow/core/tfrt/saved_model/tests:toy_v1/saved_model.pb", - "//tensorflow/core/tfrt/saved_model/tests:toy_v1/variables/variables.data-00000-of-00001", - "//tensorflow/core/tfrt/saved_model/tests:toy_v1/variables/variables.index", + "//tensorflow/core/tfrt/saved_model/tests:toy_v1/1/saved_model.pb", + "//tensorflow/core/tfrt/saved_model/tests:toy_v1/1/variables/variables.data-00000-of-00001", + "//tensorflow/core/tfrt/saved_model/tests:toy_v1/1/variables/variables.index", ], tags = ["no_oss"], deps = [ @@ -108,8 +108,10 @@ tf_cc_shared_test( "//tensorflow/core/tfrt/saved_model:saved_model_testutil", "//tensorflow/core/tfrt/utils:thread_pool", "//tensorflow/python/framework:test_ops_kernels", + "@com_google_absl//absl/memory", "@com_google_absl//absl/time", "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:protobuf", ], ) diff --git a/tensorflow/core/tfrt/tfrt_session/tfrt_session_test.cc b/tensorflow/core/tfrt/tfrt_session/tfrt_session_test.cc index a5898333c23a4f..490d132ecefd50 100644 --- a/tensorflow/core/tfrt/tfrt_session/tfrt_session_test.cc +++ b/tensorflow/core/tfrt/tfrt_session/tfrt_session_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "absl/time/time.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" @@ -44,6 +45,7 @@ limitations under the License. #include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h" #include "tensorflow/core/tfrt/utils/thread_pool.h" #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/protobuf.h" namespace tensorflow { namespace { @@ -78,7 +80,7 @@ class TfrtSessionTest : public ::testing::Test { // Initialize the session with a GraphDef. std::string saved_model_dir = GetDataDependencyFilepath( - "tensorflow/core/tfrt/saved_model/tests/toy_v1"); + "tensorflow/core/tfrt/saved_model/tests/toy_v1/1"); MetaGraphDef meta_graph_def; TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(saved_model_dir, {"serve"}, @@ -131,27 +133,83 @@ TEST_F(TfrtSessionTest, NoTargetNodes) { } TEST_F(TfrtSessionTest, RunOptions) { + SessionOptions options; + options.config.mutable_experimental()->set_use_tfrt(true); + auto* model_metadata = + options.config.mutable_experimental()->mutable_session_metadata(); + model_metadata->set_name("toy_v1"); + model_metadata->set_version(0); + + auto session = absl::WrapUnique(NewSession(options)); + ASSERT_TRUE(session != nullptr); + + tensorflow::GraphDef graph_def; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + node: { + name: "input" + op: "Placeholder" + attr: { + key: "dtype" + value: { type: DT_INT32 } + } + } + node: { + name: "sleep_seconds" + op: "Const" + attr: { + key: "dtype" + value: { type: DT_INT32 } + } + attr: { + key: "value" + value: { + tensor: { + tensor_shape: {} + dtype: DT_INT32 + int_val: 2 + } + } + } + } + node: { + name: "sleep" + op: "SleepIdentityOp" + input: "sleep_seconds:0" + input: "input:0" + attr: { + key: "T" + value: { type: DT_INT32 } + } + })pb" + + , + &graph_def)); + + TF_ASSERT_OK(session->Create(graph_def)); + std::vector outputs; // Test the Run() overload with RunOptions and RunMetadata RunMetadata run_metadata; - TF_ASSERT_OK(session_->Run(RunOptions{}, inputs_, output_tensor_names_, - /*target_tensor_names=*/{}, &outputs, - &run_metadata)); + TF_ASSERT_OK(session->Run( + RunOptions{}, + /*inputs=*/{{"input", test::AsTensor({1}, TensorShape{1})}}, + /*output_tensor_names=*/{"sleep"}, + /*target_tensor_names=*/{}, &outputs, &run_metadata)); - ASSERT_EQ(outputs.size(), 3); + ASSERT_EQ(outputs.size(), 1); // Check output "r1". - test::ExpectEqual(outputs[0], - test::AsTensor({6}, TensorShape{1, 1})); + test::ExpectEqual(outputs[0], test::AsTensor({1}, TensorShape{1})); // Test timeout. RunOptions run_options; run_options.set_timeout_in_ms(1); - // The "sleep" op will sleep for 1 second, so the Session::Run() call will - // time out. - auto status = - session_->Run(run_options, inputs_, output_tensor_names_, - /*target_tensor_names=*/{"sleep"}, &outputs, &run_metadata); + auto status = session->Run( + run_options, + /*inputs=*/{{"input", test::AsTensor({1}, TensorShape{1})}}, + /*output_tensor_names=*/{"sleep"}, + /*target_tensor_names=*/{}, &outputs, &run_metadata); ASSERT_FALSE(status.ok()); EXPECT_THAT(status.ToString(), ::testing::HasSubstr("Deadline exceeded")); @@ -234,7 +292,7 @@ TEST_F(TfrtSessionTest, RunInCallerThreadSessionOptions) { // Initialize the session with a GraphDef. std::string saved_model_dir = GetDataDependencyFilepath( - "tensorflow/core/tfrt/saved_model/tests/toy_v1"); + "tensorflow/core/tfrt/saved_model/tests/toy_v1/1"); MetaGraphDef meta_graph_def; TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(saved_model_dir, {"serve"}, @@ -340,7 +398,7 @@ TEST_F(TfrtSessionTest, CreateWithEmptyGraphIsNoop) { // Create agian with an unempty GraphDef. std::string saved_model_dir = GetDataDependencyFilepath( - "tensorflow/core/tfrt/saved_model/tests/toy_v1"); + "tensorflow/core/tfrt/saved_model/tests/toy_v1/1"); MetaGraphDef meta_graph_def; TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(saved_model_dir, {"serve"}, @@ -352,7 +410,7 @@ TEST_F(TfrtSessionTest, CreateWithEmptyGraphIsNoop) { TEST_F(TfrtSessionTest, CreateAgainError) { // On a created session, create agian with a GraphDef. std::string saved_model_dir = GetDataDependencyFilepath( - "tensorflow/core/tfrt/saved_model/tests/toy_v1"); + "tensorflow/core/tfrt/saved_model/tests/toy_v1/1"); MetaGraphDef meta_graph_def; TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(saved_model_dir, {"serve"}, @@ -378,7 +436,7 @@ TEST_F(TfrtSessionTest, CreateAfterCloseError) { // Create the session with a GraphDef. std::string saved_model_dir = GetDataDependencyFilepath( - "tensorflow/core/tfrt/saved_model/tests/toy_v1"); + "tensorflow/core/tfrt/saved_model/tests/toy_v1/1"); MetaGraphDef meta_graph_def; TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(saved_model_dir, {"serve"}, @@ -400,7 +458,7 @@ TEST_F(TfrtSessionTest, ExtendWhenNotCreated) { // Extend the session with a GraphDef. std::string saved_model_dir = GetDataDependencyFilepath( - "tensorflow/core/tfrt/saved_model/tests/toy_v1"); + "tensorflow/core/tfrt/saved_model/tests/toy_v1/1"); MetaGraphDef meta_graph_def; TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(saved_model_dir, {"serve"}, @@ -535,7 +593,7 @@ TEST_F(TfrtSessionTest, ExtendAfterCloseError) { // Extend the session with a GraphDef. std::string saved_model_dir = GetDataDependencyFilepath( - "tensorflow/core/tfrt/saved_model/tests/toy_v1"); + "tensorflow/core/tfrt/saved_model/tests/toy_v1/1"); MetaGraphDef meta_graph_def; TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(saved_model_dir, {"serve"}, diff --git a/tensorflow/core/tfrt/utils/debug/BUILD b/tensorflow/core/tfrt/utils/debug/BUILD index 2396428514fb6f..38965aed32f1dc 100644 --- a/tensorflow/core/tfrt/utils/debug/BUILD +++ b/tensorflow/core/tfrt/utils/debug/BUILD @@ -28,9 +28,9 @@ cc_library( # name = "node_io_dump_rewriter_test", # srcs = ["node_io_dump_rewriter_test.cc"], # data = [ -# "//tensorflow/core/tfrt/saved_model/tests:toy_v1/saved_model.pb", -# "//tensorflow/core/tfrt/saved_model/tests:toy_v1/variables/variables.data-00000-of-00001", -# "//tensorflow/core/tfrt/saved_model/tests:toy_v1/variables/variables.index", +# "//tensorflow/core/tfrt/saved_model/tests:toy_v1/1/saved_model.pb", +# "//tensorflow/core/tfrt/saved_model/tests:toy_v1/1/variables/variables.data-00000-of-00001", +# "//tensorflow/core/tfrt/saved_model/tests:toy_v1/1/variables/variables.index", # "//tensorflow/core/tfrt/saved_model/tests:toy_v2/saved_model.pb", # "//tensorflow/core/tfrt/saved_model/tests:toy_v2/variables/variables.data-00000-of-00001", # "//tensorflow/core/tfrt/saved_model/tests:toy_v2/variables/variables.index", diff --git a/tensorflow/core/tfrt/utils/debug/node_io_dump_rewriter_test.cc b/tensorflow/core/tfrt/utils/debug/node_io_dump_rewriter_test.cc index 273a0cab075e16..3ccb66f704df60 100644 --- a/tensorflow/core/tfrt/utils/debug/node_io_dump_rewriter_test.cc +++ b/tensorflow/core/tfrt/utils/debug/node_io_dump_rewriter_test.cc @@ -150,7 +150,7 @@ TEST(NodeIoDumpRewriterTest, OnGraph) { TEST(NodeIoDumpRewriterTest, OnSavedModelV1) { // Read meta_graph_def. std::string saved_model_dir = GetDataDependencyFilepath( - "tensorflow/core/tfrt/saved_model/tests/toy_v1"); + "tensorflow/core/tfrt/saved_model/tests/toy_v1/1"); MetaGraphDef meta_graph_def; TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(saved_model_dir, {"serve"}, &meta_graph_def)); diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op.cc b/tensorflow/core/tpu/kernels/tpu_compile_op.cc index 1ada53b73f4b60..b4a462a1e20b72 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op.cc +++ b/tensorflow/core/tpu/kernels/tpu_compile_op.cc @@ -88,7 +88,7 @@ void TpuCompileSucceededAssertOp::Compute(OpKernelContext* ctx) { } } -REGISTER_MODULE_INITIALIZER(register_tpu_compile_op_kernel, { +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(register_tpu_compile_op_kernel, { VLOG(1) << "Register TpuCompileOp kernel."; REGISTER_KERNEL_BUILDER(Name("TPUCompile").Device(DEVICE_CPU), TpuCompileOp); REGISTER_KERNEL_BUILDER(Name("_TPUCompileMlir").Device(DEVICE_CPU), diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_impl.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_impl.cc index 1bf78a934c3646..1e45bb2079f316 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_impl.cc +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_impl.cc @@ -90,7 +90,7 @@ class TpuCompileOpImplFactory : public CompileOpImplFactory { }; #if defined(LIBTPU_ON_GCE) -REGISTER_MODULE_INITIALIZER(tpu_compile_op_impl_factory, { +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(tpu_compile_op_impl_factory, { VLOG(1) << "register TpuCompileOpImplFactory()"; CompileOpImplFactory::Register(new TpuCompileOpImplFactory()); }); diff --git a/tensorflow/core/tpu/kernels/xla/BUILD b/tensorflow/core/tpu/kernels/xla/BUILD index c75d6ded6f2f45..d91c12d5abde2f 100644 --- a/tensorflow/core/tpu/kernels/xla/BUILD +++ b/tensorflow/core/tpu/kernels/xla/BUILD @@ -7,11 +7,45 @@ package( licenses = ["notice"], ) +cc_library( + name = "host_compute_ops", + srcs = [ + "host_compute_ops.cc", + ], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", + "//tensorflow/compiler/tf2xla:side_effect_util", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_context", + "//tensorflow/compiler/tf2xla:xla_op_registry", + "//tensorflow/compiler/tf2xla/kernels:if_op", + "//tensorflow/compiler/tf2xla/kernels:index_ops", + "//tensorflow/compiler/tf2xla/kernels:while_op", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:status", + "//tensorflow/core/tpu/kernels:cross_replica_ops", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_xla//xla:shape_util", + "@local_xla//xla:side_effect_util", + "@local_xla//xla:xla_data_proto_cc", + "@local_xla//xla/client:sharding_builder", + "@local_xla//xla/client:xla_builder", + ], + alwayslink = 1, +) + cc_library( name = "xla_ops", srcs = [ "get_item_op.cc", - "host_compute_ops.cc", "index_ops.cc", "infeed_op.cc", "inplace_ops.cc", @@ -19,9 +53,8 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ + ":host_compute_ops", "//tensorflow/compiler/tf2xla:common", - "//tensorflow/compiler/tf2xla:sharding_util", - "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", "//tensorflow/compiler/tf2xla:xla_op_registry", @@ -39,18 +72,15 @@ cc_library( "//tensorflow/core/tpu/kernels:cross_replica_ops", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:macros", "@local_xla//xla:shape_util", - "@local_xla//xla:side_effect_util", "@local_xla//xla:util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:sharding_builder", "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/ir:hlo", "@local_xla//xla/stream_executor/tpu:c_api_conversions", "@local_xla//xla/stream_executor/tpu:c_api_decl", - "@local_xla//xla/stream_executor/tpu:tpu_api", + "@local_xla//xla/stream_executor/tpu:tpu_executor_api", ], alwayslink = 1, ) diff --git a/tensorflow/core/tpu/kernels/xla/host_compute_ops.cc b/tensorflow/core/tpu/kernels/xla/host_compute_ops.cc index fe0fd5b309715b..2c2dfc8df4a825 100644 --- a/tensorflow/core/tpu/kernels/xla/host_compute_ops.cc +++ b/tensorflow/core/tpu/kernels/xla/host_compute_ops.cc @@ -18,15 +18,15 @@ limitations under the License. #include #include -#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/client/sharding_builder.h" #include "xla/client/xla_builder.h" -#include "xla/primitive_util.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/side_effect_util.h" @@ -56,7 +56,6 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" // IWYU pragma: keep -#include "tsl/platform/macros.h" namespace tensorflow { @@ -180,7 +179,7 @@ class HostComputeOp : public XlaOpKernel { // Send values to the host. std::vector send_to_host_tokens; for (int i = 0; i < input_handles.size(); ++i) { - const string channel_name = absl::StrCat(send_key_, "_dtoh_", i); + const string channel_name = GetDeviceToHostChannelName(send_key_, i); xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(input_dtypes_[i], input_shapes[i], &xla_shape)); @@ -242,7 +241,7 @@ class HostComputeOp : public XlaOpKernel { // Copy results to the device. std::vector recv_from_host_tokens; for (int i = 0; i < output_shapes->size(); ++i) { - const string channel_name = absl::StrCat(recv_key_, "_htod_", i); + const string channel_name = GetHostToDeviceChannelName(recv_key_, i); // Specify frontend attributes. xla::FrontendAttributes attrs; (*attrs.mutable_map())[xla::kXlaHostTransferRendezvousNameAttr] = @@ -538,6 +537,7 @@ class RecvFromHostOp : public XlaOpKernel { REGISTER_XLA_OP(Name("XlaHostCompute"), HostComputeOp); REGISTER_XLA_OP(Name("XlaSendToHost"), SendToHostOp); REGISTER_XLA_OP(Name("XlaRecvFromHost"), RecvFromHostOp); +REGISTER_XLA_OP(Name("_XlaHostComputeMlir"), MlirXlaOpKernel); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/xla/infeed_op.cc b/tensorflow/core/tpu/kernels/xla/infeed_op.cc index 050829f1ecfd5c..94ae03c0158fd5 100644 --- a/tensorflow/core/tpu/kernels/xla/infeed_op.cc +++ b/tensorflow/core/tpu/kernels/xla/infeed_op.cc @@ -17,18 +17,25 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/sharding_util.h" -#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/client/xla_builder.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/layout_util.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/stream_executor/tpu/c_api_conversions.h" #include "xla/stream_executor/tpu/c_api_decl.h" -#include "xla/stream_executor/tpu/tpu_api.h" +#include "xla/stream_executor/tpu/tpu_executor_api.h" #include "xla/util.h" -#include "tensorflow/core/framework/kernel_def_builder.h" +#include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/tpu/tpu_defs.h" +#include "tensorflow/core/framework/op_requires.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/status.h" +#include "tsl/platform/statusor.h" namespace tensorflow { diff --git a/tensorflow/core/tpu/kernels/xla/outfeed_ops.cc b/tensorflow/core/tpu/kernels/xla/outfeed_ops.cc index aa2dacfb319b35..fac67ba2289c31 100644 --- a/tensorflow/core/tpu/kernels/xla/outfeed_ops.cc +++ b/tensorflow/core/tpu/kernels/xla/outfeed_ops.cc @@ -28,7 +28,6 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/types.h" #include "tsl/platform/logging.h" // IWYU pragma: keep -#include "tsl/platform/macros.h" namespace tensorflow { diff --git a/tensorflow/core/tpu/tpu_execute.cc b/tensorflow/core/tpu/tpu_execute.cc index 22e299e9bbb26d..8a50c90c646510 100644 --- a/tensorflow/core/tpu/tpu_execute.cc +++ b/tensorflow/core/tpu/tpu_execute.cc @@ -199,38 +199,39 @@ xla::Status UpdateDynamicInputs( ShapeSizeCompact(compile_time_shape), -1); auto raw_input_runtime = std::make_shared>( ShapeSizeCompact(runtime_shape) / sizeof(uint32_t)); - stream->ThenMemcpyD2H( + TF_RETURN_IF_ERROR(stream->MemcpyD2H( se::DeviceMemory(mutable_input_mem->AsDeviceMemoryBase()), absl::MakeSpan(absl::bit_cast(raw_input_runtime->data()), - ShapeSizeCompactRaw(runtime_shape))); - stream->ThenDoHostCallbackWithStatus([raw_input_runtime, padded_data, - runtime_shape, - compile_time_shape]() { - // After getting the data onto the host, transpose the data to - // the correct layout by delinearizing it and linearizing it again. - XLA_Shape c_runtime_shape, c_compile_time_shape; - ApiConverter::ToC(runtime_shape, &c_runtime_shape); - ApiConverter::ToC(compile_time_shape, &c_compile_time_shape); - StatusHelper status; - - TpuExecute_RuntimeInputToPaddedData_Params params; - params.struct_size = - TpuExecute_RuntimeInputToPaddedData_Params_SIZE; - params.priv = nullptr; - params.runtime_input_ptr = raw_input_runtime->data(); - params.runtime_input_size = raw_input_runtime->size(); - params.padded_data_ptr = padded_data->data(); - params.padded_data_size = padded_data->size(); - params.runtime_shape = &c_runtime_shape; - params.compile_time_shape = &c_compile_time_shape; - params.status = status.c_status; - - stream_executor::tpu::OpsApiFn() - ->TpuExecute_RuntimeInputToPaddedDataFn(¶ms); - ApiConverter::Destroy(&c_runtime_shape); - ApiConverter::Destroy(&c_compile_time_shape); - return status.status(); - }); + ShapeSizeCompactRaw(runtime_shape)))); + TF_RETURN_IF_ERROR(stream->DoHostCallbackWithStatus( + [raw_input_runtime, padded_data, runtime_shape, + compile_time_shape]() { + // After getting the data onto the host, transpose the data to + // the correct layout by delinearizing it and linearizing it + // again. + XLA_Shape c_runtime_shape, c_compile_time_shape; + ApiConverter::ToC(runtime_shape, &c_runtime_shape); + ApiConverter::ToC(compile_time_shape, &c_compile_time_shape); + StatusHelper status; + + TpuExecute_RuntimeInputToPaddedData_Params params; + params.struct_size = + TpuExecute_RuntimeInputToPaddedData_Params_SIZE; + params.priv = nullptr; + params.runtime_input_ptr = raw_input_runtime->data(); + params.runtime_input_size = raw_input_runtime->size(); + params.padded_data_ptr = padded_data->data(); + params.padded_data_size = padded_data->size(); + params.runtime_shape = &c_runtime_shape; + params.compile_time_shape = &c_compile_time_shape; + params.status = status.c_status; + + stream_executor::tpu::OpsApiFn() + ->TpuExecute_RuntimeInputToPaddedDataFn(¶ms); + ApiConverter::Destroy(&c_runtime_shape); + ApiConverter::Destroy(&c_compile_time_shape); + return status.status(); + })); // Allocate new input and transfer the padded and transposed data to // the new input location. TF_ASSIGN_OR_RETURN( @@ -239,10 +240,11 @@ xla::Status UpdateDynamicInputs( ShapeSizeCompact(compile_time_shape))); auto typed_new_input_memory = se::DeviceMemory(new_input.cref()); - stream->ThenMemcpyH2D(*padded_data, &typed_new_input_memory); + TF_RETURN_IF_ERROR( + stream->MemcpyH2D(*padded_data, &typed_new_input_memory)); // Retain the memory until the end of the transfer. - stream->ThenDoHostCallback([padded_data] {}); + TF_RETURN_IF_ERROR(stream->DoHostCallback([padded_data] {})); // Modify the memory location in the input shape tree to point to the // new input. @@ -343,38 +345,44 @@ void UnregisterCancellation(OpKernelContext* ctx, // the frequency of back-to-back programs (which are most efficient because // they don't require host synchronization). Instead, borrow a substream and // have the substream wait on the compute stream. - se::Stream* deregister_stream = stream->GetOrCreateSubStream(); - deregister_stream->ThenWaitFor(stream); - deregister_stream->ThenDoHostCallback([=]() { - // We must deregister the callback in the success case, to avoid closing all - // devices. In the failure case we must NOT call DeregisterCallback as that - // waits for all previous cancellation callbacks to complete and any call - // to XlaDevice::Sync() will cause deadlock. Consider: - // 1) CancellationManager::StartCancel() is in progress (state is - // cancelling_). - // 2) The call below to DeregisterCallback will block until state is - // cancelled_ (all callbacks are completed). - // 3) A different cancellation callback has called XlaDevice::Sync(), - // which will block until (2) is done. - // 4) StartCancel() in (1) cannot complete until (3) is done. - // - // Instead, call TryDeregisterCallback. The functional difference is - // TryDeregisterCallback will not block if cancellation is in progress - // so makes no guarantees as to the state of any callbacks. - // This is not a problem, as our cancellation handler does not rely on - // any external state. - VLOG(1) << "cancellation_manager->TryDeregisterCallback on device " - << device_ordinal; - cancellation_manager->TryDeregisterCallback(token); - VLOG(1) << "cancellation_manager->TryDeregisterCallback done on device " - << device_ordinal; - - // ExecutorState is held alive until at least this point to ensure - // cancellation_manager is valid. After all outstanding - // dec_num_deferred_ops_function are called, ExecutorState::Finish will be - // allowed to proceed. - dec_num_deferred_ops_function(); - }); + se::Stream* deregister_stream = + stream->GetOrCreateSubStream().value_or(nullptr); + if (deregister_stream == nullptr) { + return; + } + deregister_stream->WaitFor(stream).IgnoreError(); + deregister_stream + ->DoHostCallback([=]() { + // We must deregister the callback in the success case, to avoid closing + // all devices. In the failure case we must NOT call DeregisterCallback + // as that waits for all previous cancellation callbacks to complete and + // any call to XlaDevice::Sync() will cause deadlock. Consider: + // 1) CancellationManager::StartCancel() is in progress (state is + // cancelling_). + // 2) The call below to DeregisterCallback will block until state is + // cancelled_ (all callbacks are completed). + // 3) A different cancellation callback has called XlaDevice::Sync(), + // which will block until (2) is done. + // 4) StartCancel() in (1) cannot complete until (3) is done. + // + // Instead, call TryDeregisterCallback. The functional difference is + // TryDeregisterCallback will not block if cancellation is in progress + // so makes no guarantees as to the state of any callbacks. + // This is not a problem, as our cancellation handler does not rely on + // any external state. + VLOG(1) << "cancellation_manager->TryDeregisterCallback on device " + << device_ordinal; + cancellation_manager->TryDeregisterCallback(token); + VLOG(1) << "cancellation_manager->TryDeregisterCallback done on device " + << device_ordinal; + + // ExecutorState is held alive until at least this point to ensure + // cancellation_manager is valid. After all outstanding + // dec_num_deferred_ops_function are called, ExecutorState::Finish will + // be allowed to proceed. + dec_num_deferred_ops_function(); + }) + .IgnoreError(); stream->ReturnSubStream(deregister_stream); } diff --git a/tensorflow/core/transforms/eliminate_passthrough_iter_args/BUILD b/tensorflow/core/transforms/eliminate_passthrough_iter_args/BUILD index b19c211a8abdcb..043cab4e10fb11 100644 --- a/tensorflow/core/transforms/eliminate_passthrough_iter_args/BUILD +++ b/tensorflow/core/transforms/eliminate_passthrough_iter_args/BUILD @@ -22,7 +22,6 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", - "@llvm-project//mlir:TransformUtils", ], ) diff --git a/tensorflow/core/transforms/eliminate_passthrough_iter_args/pass.cc b/tensorflow/core/transforms/eliminate_passthrough_iter_args/pass.cc index 7b40357f79998c..a770f8528b0f74 100644 --- a/tensorflow/core/transforms/eliminate_passthrough_iter_args/pass.cc +++ b/tensorflow/core/transforms/eliminate_passthrough_iter_args/pass.cc @@ -18,15 +18,16 @@ limitations under the License. #include #include +#include "llvm/ADT/ADL.h" #include "llvm/ADT/BitVector.h" -#include "llvm/ADT/EpochTracker.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" -#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/core/ir/ops.h" #include "tensorflow/core/ir/utility.h" #include "tensorflow/core/transforms/utils/utils.h" diff --git a/tensorflow/core/transforms/shape_inference/BUILD b/tensorflow/core/transforms/shape_inference/BUILD index d9eb50e9762b2a..bc8c4762d809e5 100644 --- a/tensorflow/core/transforms/shape_inference/BUILD +++ b/tensorflow/core/transforms/shape_inference/BUILD @@ -18,7 +18,6 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core/ir:Dialect", "//tensorflow/core/ir:shape_inference_utils", - "//tensorflow/core/ir/importexport:convert_tensor", "//tensorflow/core/ir/types:Dialect", "//tensorflow/core/transforms:PassIncGen", "@llvm-project//llvm:Support", diff --git a/tensorflow/core/transforms/shape_inference/pass.cc b/tensorflow/core/transforms/shape_inference/pass.cc index c6763b4b51e5e7..80a28e0b53dbfb 100644 --- a/tensorflow/core/transforms/shape_inference/pass.cc +++ b/tensorflow/core/transforms/shape_inference/pass.cc @@ -19,16 +19,22 @@ limitations under the License. #include #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SetVector.h" -#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/core/framework/shape_inference.h" -#include "tensorflow/core/ir/importexport/convert_tensor.h" +#include "tensorflow/core/ir/dialect.h" #include "tensorflow/core/ir/ops.h" #include "tensorflow/core/ir/tf_op_wrapper.h" #include "tensorflow/core/ir/types/dialect.h" diff --git a/tensorflow/core/util/autotune_maps/BUILD b/tensorflow/core/util/autotune_maps/BUILD index f4f13211ab2f8e..10078d6ce696a8 100644 --- a/tensorflow/core/util/autotune_maps/BUILD +++ b/tensorflow/core/util/autotune_maps/BUILD @@ -146,6 +146,7 @@ tf_cuda_library( "@local_xla//xla:status_macros", "@local_xla//xla/stream_executor", "@local_xla//xla/stream_executor:lazy_op_runner", + "@local_xla//xla/stream_executor:platform_manager", "@local_xla//xla/stream_executor/gpu:gpu_init", ], ) diff --git a/tensorflow/core/util/autotune_maps/autotune_serialize.cc b/tensorflow/core/util/autotune_maps/autotune_serialize.cc index f943aecdfeedd4..464364df6768ef 100644 --- a/tensorflow/core/util/autotune_maps/autotune_serialize.cc +++ b/tensorflow/core/util/autotune_maps/autotune_serialize.cc @@ -24,6 +24,7 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/gpu/gpu_init.h" +#include "xla/stream_executor/platform_manager.h" #include "tensorflow/core/platform/str_util.h" #include "tensorflow/core/util/activation_mode.h" #include "tensorflow/core/util/autotune_maps/autotune_map.pb.h" @@ -100,7 +101,7 @@ Status PopulateConvMap( // Get the list of all GPU StreamExecutors. TF_ASSIGN_OR_RETURN( se::Platform * platform, - se::MultiPlatformManager::PlatformWithName(se::GpuPlatformName())); + se::PlatformManager::PlatformWithName(se::GpuPlatformName())); std::vector device_descs; for (int i = 0; i < platform->VisibleDeviceCount(); i++) { TF_ASSIGN_OR_RETURN(std::unique_ptr device_desc, diff --git a/tensorflow/core/util/autotune_maps/autotune_serialize_test.cc b/tensorflow/core/util/autotune_maps/autotune_serialize_test.cc index ef62dfaf414cf5..c135beb198756e 100644 --- a/tensorflow/core/util/autotune_maps/autotune_serialize_test.cc +++ b/tensorflow/core/util/autotune_maps/autotune_serialize_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ // For Google-internal use only. +#include "xla/stream_executor/platform_manager.h" #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/util/autotune_maps/autotune_serialize.h" @@ -37,7 +38,7 @@ using ::testing::HasSubstr; // Gets a GPU StreamExecutor instance. Any one will do. se::StreamExecutor* GetStreamExec() { se::Platform* platform = - se::MultiPlatformManager::PlatformWithName(se::GpuPlatformName()).value(); + se::PlatformManager::PlatformWithName(se::GpuPlatformName()).value(); CHECK_GT(platform->VisibleDeviceCount(), 0); return platform->ExecutorForDevice(0).value(); } diff --git a/tensorflow/core/util/onednn_env_vars.cc b/tensorflow/core/util/onednn_env_vars.cc index 1b73ef8e862fd7..0bba22bc9b16b6 100644 --- a/tensorflow/core/util/onednn_env_vars.cc +++ b/tensorflow/core/util/onednn_env_vars.cc @@ -1,4 +1,4 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -53,17 +53,6 @@ bool ThreadPoolUseCallerThread() { return threadpool_use_caller_thread; } -bool UseOnednnSpmm() { - static bool use_onednn_spmm = [] { - bool setting; - TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_ONEDNN_SPMM", - /*default_value*/ false, &setting)); - return setting; - }(); - - return use_onednn_spmm; -} - std::string FPMathModeSetting() { static std::string math_mode_setting = [] { std::string setting = ""; diff --git a/tensorflow/core/util/onednn_env_vars.h b/tensorflow/core/util/onednn_env_vars.h index e2cb27ccfc8115..4cf97500379262 100644 --- a/tensorflow/core/util/onednn_env_vars.h +++ b/tensorflow/core/util/onednn_env_vars.h @@ -1,4 +1,4 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,8 +27,6 @@ bool UseSystemAlloc(); bool ThreadPoolUseCallerThread(); -bool UseOnednnSpmm(); - std::string FPMathModeSetting(); } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc index 83a37ad11c3b5b..a12a832222537d 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc @@ -916,7 +916,7 @@ Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) { if (DataTypeCanUseMemcpy(entry.dtype())) { char* backing_buffer = const_cast((ret->tensor_data().data())); size_t unused_bytes_read; - if (entry.size() > kBufferSize) { + if (entry.size() > kBufferSize || enable_multi_threading_for_testing_) { StringPiece sp; if (!enable_multi_threading_for_testing_ && entry.size() < kLargeTensorThreshold) { @@ -964,6 +964,8 @@ Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) { statuses[i] = std::move(status); }); } + reader_pool = nullptr; // Wait for reads to finish + for (const auto& status : statuses) { TF_RETURN_IF_ERROR(status); } diff --git a/tensorflow/dtensor/cc/dtensor_device.cc b/tensorflow/dtensor/cc/dtensor_device.cc index b9ddf25fde6caf..cc3a92ec8d1c97 100644 --- a/tensorflow/dtensor/cc/dtensor_device.cc +++ b/tensorflow/dtensor/cc/dtensor_device.cc @@ -249,7 +249,7 @@ class DTensorDevice { } Mesh::tpu_core_ids()[mesh_name].assign(tpu_core_ids.begin(), tpu_core_ids.end()); - return OkStatus(); + return absl::OkStatus(); } void ClearTPUCoreIDs() { Mesh::tpu_core_ids().clear(); } @@ -1582,7 +1582,7 @@ Status AddExecutionFunctionDefsToFunctionDefLibrary( to_run, stack_traces)); } - return OkStatus(); + return absl::OkStatus(); } StatusOr diff --git a/tensorflow/dtensor/cc/dtensor_device_util.cc b/tensorflow/dtensor/cc/dtensor_device_util.cc index d3a84bac0d2017..8bb49f31d64d64 100644 --- a/tensorflow/dtensor/cc/dtensor_device_util.cc +++ b/tensorflow/dtensor/cc/dtensor_device_util.cc @@ -244,7 +244,7 @@ Status ParseAttrMap(const Node& node, absl::string_view indices_attr, std::map* indices_layout_map) { std::vector layouts; if (!TryGetNodeAttr(node.attrs(), layout_attr, &layouts)) { - return OkStatus(); + return absl::OkStatus(); } const TensorProto* indices; if (!TryGetNodeAttr(node.attrs(), indices_attr, &indices)) { @@ -262,7 +262,7 @@ Status ParseAttrMap(const Node& node, absl::string_view indices_attr, indices_layout_map->emplace( arg_index, tensorflow::dtensor::Layout::FromString(arg_layout).value()); } - return OkStatus(); + return absl::OkStatus(); } Status ParseResourceArgumentLayouts( @@ -598,7 +598,7 @@ tensorflow::Fprint128 ResourceHandleWithLayout::CacheKey() const { return f; } -tsl::Status ResourceHandleWithLayout::UpdateLayout(const Layout& new_layout) { +absl::Status ResourceHandleWithLayout::UpdateLayout(const Layout& new_layout) { // Only set the value for deferenced layout if the incoming layout is not // empty. This is still hacky as we use empty layout as placeholder for // eagerly placed VarHandleOp. @@ -612,7 +612,7 @@ tsl::Status ResourceHandleWithLayout::UpdateLayout(const Layout& new_layout) { "Attempted to overwrite an existing Layout."); } dereferenced_layout_.emplace(new_layout); - return tsl::OkStatus(); + return absl::OkStatus(); } char SparseTensorWithLayout::ID = 0; @@ -753,7 +753,7 @@ Status InferOutputLayouts(const DTensorOperation& doperation, output_layouts->push_back(layout); } graph->RemoveNode(op_node); - return OkStatus(); + return absl::OkStatus(); } Status PrepareGraphForMlir( @@ -944,7 +944,7 @@ Status PrepareGraphForMlir( ret_node->AddAttr(kDefaultLayoutAttr, layout->ToString()); } } - return OkStatus(); + return absl::OkStatus(); } StatusOr> GetNumLocalOutputs(Node* node) { @@ -995,7 +995,7 @@ Status SetMultiDeviceFunctionOutputs( } } function.num_local_outputs = std::move(num_local_outputs); - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -1109,7 +1109,7 @@ StatusOr IdentifyAllFunctionsToExecute( // nodes. Status MaybeInsertIdentityNodes(const FunctionDef* function_def, Graph* graph) { if (function_def == nullptr || function_def->control_ret().empty()) { - return OkStatus(); + return absl::OkStatus(); } tensorflow::Status status; for (Node* n : graph->nodes()) { @@ -1138,7 +1138,7 @@ Status MaybeInsertIdentityNodes(const FunctionDef* function_def, Graph* graph) { // Add an edge between Identity and _Retval. graph->AddEdge(ret_identity_node, 0, n, 0); } - return OkStatus(); + return absl::OkStatus(); } void AddDTensorFunctionAttr(FunctionDef& function_def) { diff --git a/tensorflow/dtensor/cc/dtensor_device_util.h b/tensorflow/dtensor/cc/dtensor_device_util.h index ce01047cc7bce8..604d33f9691cd3 100644 --- a/tensorflow/dtensor/cc/dtensor_device_util.h +++ b/tensorflow/dtensor/cc/dtensor_device_util.h @@ -294,20 +294,20 @@ class ResourceHandleWithLayout tensorflow::Fprint128 CacheKey() const override; // Updates the layout for the tensors. - tsl::Status UpdateLayout(const Layout& new_layout); + absl::Status UpdateLayout(const Layout& new_layout); // Updates the element layouts for the tensors. - tsl::Status UpdateElementLayouts(const std::vector& layouts) { + absl::Status UpdateElementLayouts(const std::vector& layouts) { dereferenced_element_layouts_.emplace(layouts); - return tsl::OkStatus(); + return absl::OkStatus(); } // Updates the local shape and dtype of the tensors. - tsl::Status UpdateShapeAndDType(const TensorShapeProto& shape, - const DataType& dtype) { + absl::Status UpdateShapeAndDType(const TensorShapeProto& shape, + const DataType& dtype) { set_dereferenced_shape(shape); set_dereferenced_dtype(dtype); - return tsl::OkStatus(); + return absl::OkStatus(); } ConstValueNode* const_value_node() const override { return nullptr; } diff --git a/tensorflow/dtensor/cc/dtensor_graph_to_mlir_pass.cc b/tensorflow/dtensor/cc/dtensor_graph_to_mlir_pass.cc index 6049ae707a42e2..25767572bed386 100644 --- a/tensorflow/dtensor/cc/dtensor_graph_to_mlir_pass.cc +++ b/tensorflow/dtensor/cc/dtensor_graph_to_mlir_pass.cc @@ -152,7 +152,7 @@ Status DTensorMlirPassRunner::Run(mlir::ModuleOp module) { TF_RETURN_IF_ERROR(diag_handler.ConsumeStatus()); if (logging_enabled_) pass_manager_.getContext()->enableMultithreading(); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/dtensor/cc/dtensor_meta_ops.cc b/tensorflow/dtensor/cc/dtensor_meta_ops.cc index a38e8392171c26..b4a51359209044 100644 --- a/tensorflow/dtensor/cc/dtensor_meta_ops.cc +++ b/tensorflow/dtensor/cc/dtensor_meta_ops.cc @@ -57,7 +57,7 @@ REGISTER_OP("DTensorAllScatter") if (!c->RankKnown(in)) { // Input shape unknown, so set unknown output shape. c->set_output(0, in); - return OkStatus(); + return absl::OkStatus(); } std::string input_layout_string; @@ -100,7 +100,7 @@ REGISTER_OP("DTensorAllScatter") } } c->set_output(0, c->MakeShape(out_dims)); - return OkStatus(); + return absl::OkStatus(); }); REGISTER_OP("DTensorAllGather") @@ -117,7 +117,7 @@ REGISTER_OP("DTensorAllGather") if (!c->RankKnown(in)) { // Input shape unknown, so set unknown output shape. c->set_output(0, in); - return OkStatus(); + return absl::OkStatus(); } std::string input_layout_string; @@ -159,7 +159,7 @@ REGISTER_OP("DTensorAllGather") } } c->set_output(0, c->MakeShape(out_dims)); - return OkStatus(); + return absl::OkStatus(); }); REGISTER_OP("DTensorAllToAll") @@ -173,7 +173,7 @@ REGISTER_OP("DTensorAllToAll") if (!c->RankKnown(in)) { // Input shape unknown, so set unknown output shape. c->set_output(0, in); - return OkStatus(); + return absl::OkStatus(); } std::string input_layout_string; @@ -218,7 +218,7 @@ REGISTER_OP("DTensorAllToAll") } } c->set_output(0, c->MakeShape(out_dims)); - return OkStatus(); + return absl::OkStatus(); }); } // namespace dtensor diff --git a/tensorflow/dtensor/cc/dtensor_tpu_kernels.cc b/tensorflow/dtensor/cc/dtensor_tpu_kernels.cc index f23b7eb54d6cb7..c6d1b560d4f108 100644 --- a/tensorflow/dtensor/cc/dtensor_tpu_kernels.cc +++ b/tensorflow/dtensor/cc/dtensor_tpu_kernels.cc @@ -59,11 +59,11 @@ Status DeleteIfExists(ResourceMgr* resource_manager, resource_manager->default_container(), resource_name); if (status.ok()) { VLOG(1) << "Removed existing resource " << resource_name; - return OkStatus(); + return absl::OkStatus(); } if (status.code() == error::NOT_FOUND) { VLOG(1) << "No resource " << resource_name << " to remove"; - return OkStatus(); + return absl::OkStatus(); } VLOG(1) << "Error removing resource " << resource_name << " : " << status; return status; @@ -151,7 +151,7 @@ class ConfigureAndInitializeGlobalTPUOpKernel : public OpKernel { } auto start = absl::Now(); - auto init_status = OkStatus(); + auto init_status = absl::OkStatus(); // Keep trying to initialize underlying TPU system until either TPU system // is initialized or initialization times out. @@ -242,7 +242,7 @@ class ConfigureAndInitializeGlobalTPUOpKernel : public OpKernel { tpu::kTpuEmbeddingEngineStateInterfaceResourceName, tpu::TpuEmbeddingEngineStateInterface::Create())); - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/dtensor/cc/dtensor_tpu_ops.cc b/tensorflow/dtensor/cc/dtensor_tpu_ops.cc index 3e096a4c7c32f6..8c2d5a4931cf25 100644 --- a/tensorflow/dtensor/cc/dtensor_tpu_ops.cc +++ b/tensorflow/dtensor/cc/dtensor_tpu_ops.cc @@ -44,7 +44,7 @@ REGISTER_OP("ConfigureAndInitializeGlobalTPU") TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &input)); } c->set_output(0, c->Vector(c->UnknownDim())); - return OkStatus(); + return absl::OkStatus(); }); REGISTER_OP("ShutdownTPUSystem") @@ -58,7 +58,7 @@ REGISTER_OP("DTensorSetGlobalTPUArray") .SetShapeFn([](InferenceContext* c) { ShapeHandle input; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &input)); - return OkStatus(); + return absl::OkStatus(); }); } // namespace dtensor diff --git a/tensorflow/dtensor/cc/slice_util.h b/tensorflow/dtensor/cc/slice_util.h index f8167f5dff2627..c16a18227a9495 100644 --- a/tensorflow/dtensor/cc/slice_util.h +++ b/tensorflow/dtensor/cc/slice_util.h @@ -191,7 +191,7 @@ class ForwardLayoutInference : public TokenProcessor { TF_ASSIGN_OR_RETURN( expander_value_layout_, Layout::GetLayout(expander_value_sharding_, input_layout_.mesh())); - return OkStatus(); + return absl::OkStatus(); } private: @@ -291,7 +291,7 @@ class BackwardLayoutInference : public TokenProcessor { TF_ASSIGN_OR_RETURN( expander_value_layout_, Layout::GetLayout(expander_value_sharding_, value_layout_.mesh())); - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/dtensor/mlir/dtensor_multi_device_expansion.cc b/tensorflow/dtensor/mlir/dtensor_multi_device_expansion.cc index b50b2a7b7d12c5..1f47934230e3bc 100644 --- a/tensorflow/dtensor/mlir/dtensor_multi_device_expansion.cc +++ b/tensorflow/dtensor/mlir/dtensor_multi_device_expansion.cc @@ -747,7 +747,7 @@ mlir::LogicalResult BuildOuterMainFunc( Status ExtractResultLayouts(mlir::Operation* op, mlir::func::ReturnOp return_op, std::vector& expanded_results) { if (!return_op || (return_op.getNumOperands() == 0)) { - return OkStatus(); + return absl::OkStatus(); } TF_ASSIGN_OR_RETURN(std::vector> layouts, ExtractLayoutFromOp(op)); @@ -760,7 +760,7 @@ Status ExtractResultLayouts(mlir::Operation* op, mlir::func::ReturnOp return_op, size_t result_index = std::distance(operands.begin(), search); expanded_results[result_index].layout = layouts[layout_index]; } - return OkStatus(); + return absl::OkStatus(); } struct DTensorMultiDeviceExpansion diff --git a/tensorflow/dtensor/mlir/expansions/concat_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/concat_spmd_expander.cc index 1d3948fc7d8144..f71d8f72d7bd4e 100644 --- a/tensorflow/dtensor/mlir/expansions/concat_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/concat_spmd_expander.cc @@ -47,7 +47,7 @@ Status VerifyConcatLayout(mlir::Value concat_dim_operand, } } - return OkStatus(); + return absl::OkStatus(); } StatusOr ReduceForConcatOutputLayout(mlir::Value concat_dim_operand, diff --git a/tensorflow/dtensor/mlir/expansions/conv_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/conv_spmd_expander.cc index e3c893e4c07a8a..2602a16eb9f824 100644 --- a/tensorflow/dtensor/mlir/expansions/conv_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/conv_spmd_expander.cc @@ -59,7 +59,7 @@ Status VerifyConvLayout(const Layout& input_layout, const Layout& filter_layout, if (input_layout.IsBatchParallel() || input_layout.IsFullyReplicated()) // No further checks needed for replicated case. - return OkStatus(); + return absl::OkStatus(); if (conv_op.getPadding() == "EXPLICIT") return errors::InvalidArgument( @@ -106,7 +106,7 @@ Status VerifyConvLayout(const Layout& input_layout, const Layout& filter_layout, "spatial partitions."); } - return OkStatus(); + return absl::OkStatus(); } mlir::Value PadInputOnUnshardedDim(mlir::OpBuilder& builder, diff --git a/tensorflow/dtensor/mlir/expansions/dtensor_op_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/dtensor_op_spmd_expander.cc index 6e8d896545a02c..432bc75e2b800a 100644 --- a/tensorflow/dtensor/mlir/expansions/dtensor_op_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/dtensor_op_spmd_expander.cc @@ -53,7 +53,7 @@ Status ValidateSendRecvLayoutConfiguration(mlir::TF::DTensorSend dtensor_send, mlir::TF::DTensorRecv dtensor_recv) { // If either one of the send/recv ops has already been lowered, then send/recv // configuration has already been verified. - if (!dtensor_send || !dtensor_recv) return OkStatus(); + if (!dtensor_send || !dtensor_recv) return absl::OkStatus(); TF_ASSIGN_OR_RETURN(const absl::optional send_layout_or_null, ExtractLayoutFromOperand(dtensor_send.getInput())); @@ -110,7 +110,7 @@ Status ValidateSendRecvLayoutConfiguration(mlir::TF::DTensorSend dtensor_send, return absl::InvalidArgumentError( "tf.CopyToMesh op must be used to send data from/to host mesh."); - return OkStatus(); + return absl::OkStatus(); } template diff --git a/tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.cc index eff66fa0354784..e7427b924f41ce 100644 --- a/tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.cc @@ -134,7 +134,7 @@ Status ExtractEquationRelations( } } - return OkStatus(); + return absl::OkStatus(); } // For a set of layouts and mappings from labels to offsets in the layouts, @@ -545,7 +545,7 @@ Status EinsumSPMDExpander::MaybeRelayoutInputs( for (const auto& contracting : contracting_labels) reduce_dims.emplace(input_label_to_sharding_spec[contracting]); - return OkStatus(); + return absl::OkStatus(); } } // namespace dtensor diff --git a/tensorflow/dtensor/mlir/expansions/gather_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/gather_spmd_expander.cc index fd42942f7e4130..7c3706e0f2d858 100644 --- a/tensorflow/dtensor/mlir/expansions/gather_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/gather_spmd_expander.cc @@ -262,7 +262,7 @@ Status GatherNdGetInputLayoutFromOutput(const Layout& output_layout, TF_ASSIGN_OR_RETURN(*params_layout, Layout::GetLayout(params_specs, mesh)); TF_ASSIGN_OR_RETURN(*indices_layout, Layout::GetLayout(indices_specs, mesh)); - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/dtensor/mlir/expansions/matmul_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/matmul_spmd_expander.cc index ad409760344d12..f2c6521e71c744 100644 --- a/tensorflow/dtensor/mlir/expansions/matmul_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/matmul_spmd_expander.cc @@ -114,7 +114,7 @@ StatusOr MatMulSPMDExpander::OutputLayoutAndReducedDims( Layout batch_layout; if (!*left || !*right) { - if (allow_unknown_layouts) return OkStatus(); + if (allow_unknown_layouts) return absl::OkStatus(); return errors::Unimplemented("failed to do SPMD expansion for ", OpName(op), " operand layouts " "unknown"); @@ -360,7 +360,7 @@ Status MatMulSPMDExpander::MaybeRelayoutInputs( TF_ASSIGN_OR_RETURN( right, EmitRelayout(op->getOperand(1), right_layout, new_right_layout)); - return OkStatus(); + return absl::OkStatus(); } StatusOr> MatMulSPMDExpander::ComputeLayoutForward( diff --git a/tensorflow/dtensor/mlir/expansions/meta_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/meta_spmd_expander.cc index 5606187b1e17ee..cc82fa42901fa7 100644 --- a/tensorflow/dtensor/mlir/expansions/meta_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/meta_spmd_expander.cc @@ -262,7 +262,7 @@ Status VerifyPaddedDimensionNotSharded(const Layout& layout, "Padding over sharded dimension is not allowed."); } } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -353,7 +353,7 @@ Status VerifyTileOperandLayout(const Layout& operand_layout, "tile op with input sharded at dimension where `multiple` > 1 is not " "supported."); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -981,7 +981,7 @@ Status RelayoutOneHotInput(const absl::optional& input_layout, one_hot->setOperand(0, new_input); - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/dtensor/mlir/expansions/random_op_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/random_op_spmd_expander.cc index 32e6f4315849ba..5156e09e417c1f 100644 --- a/tensorflow/dtensor/mlir/expansions/random_op_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/random_op_spmd_expander.cc @@ -40,7 +40,7 @@ Status CheckLayoutIsSupported(const Layout& layout) { return errors::InvalidArgument("Large mesh rank size is not supported", layout.ToString()); - return OkStatus(); + return absl::OkStatus(); } Status ValidateShapeAndGetNewShape( @@ -70,7 +70,7 @@ Status ValidateShapeAndGetNewShape( } new_random_shape.emplace_back(op_dimension_size / dimension_sharding); } - return OkStatus(); + return absl::OkStatus(); } // Get a device seed for this layout and device_id. diff --git a/tensorflow/dtensor/mlir/expansions/reduce_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/reduce_spmd_expander.cc index 217b41c986bbf0..9d40da3173cee4 100644 --- a/tensorflow/dtensor/mlir/expansions/reduce_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/reduce_spmd_expander.cc @@ -69,14 +69,14 @@ absl::string_view DefiningOpName(mlir::Value operand) { Status AssertReplicated(mlir::Value operand) { TF_ASSIGN_OR_RETURN(auto layout, ExtractLayoutFromOperand(operand)); - if (!layout) return OkStatus(); + if (!layout) return absl::OkStatus(); if (!layout->IsFullyReplicated()) { return errors::InvalidArgument( "Expected layout for ", DefiningOpName(operand), " to be fully replicated, but found ", layout->ToString()); } - return OkStatus(); + return absl::OkStatus(); } absl::flat_hash_set ReducedMeshDimensions( @@ -95,7 +95,7 @@ template Status ExtractDims(mlir::Operation* op, llvm::SmallVector* reduced_dims, bool* keep_dims, bool* matched) { - if (!llvm::isa(op)) return OkStatus(); + if (!llvm::isa(op)) return absl::OkStatus(); auto reduce_op = llvm::cast(op); *keep_dims = reduce_op.getKeepDims(); TF_RETURN_IF_ERROR(ExtractConstVectorFromValue( @@ -103,7 +103,7 @@ Status ExtractDims(mlir::Operation* op, TF_RETURN_IF_ERROR(AssertReplicated(reduce_op.getReductionIndices())); *matched = true; - return OkStatus(); + return absl::OkStatus(); } template <> @@ -191,7 +191,7 @@ Status ExtractReductionParameters(mlir::Operation* op, " not yet implemented."); reduced_dims_set.insert(reduced_dims.begin(), reduced_dims.end()); - return OkStatus(); + return absl::OkStatus(); } StatusOr ComputeResultLayout(mlir::Operation* op, diff --git a/tensorflow/dtensor/mlir/expansions/resource_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/resource_spmd_expander.cc index a7db8e1879ad38..c6f5042dfb8034 100644 --- a/tensorflow/dtensor/mlir/expansions/resource_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/resource_spmd_expander.cc @@ -153,7 +153,7 @@ Status ValidateAndAssignResourceInputLayout(mlir::tf_device::ClusterOp op, add_layout_as_attributes(mutable_input_layouts, mutable_input_indices, resource_arg_index, layout_string); } - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/dtensor/mlir/expansions/slice_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/slice_spmd_expander.cc index d3e7487006ce88..13768bc95419b5 100644 --- a/tensorflow/dtensor/mlir/expansions/slice_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/slice_spmd_expander.cc @@ -50,7 +50,7 @@ Status GetSliceOpArguments(mlir::TF::SliceOp slice_op, ExtractConstVectorFromValue(slice_op.getSize(), &sizes), "expected constant argument for SliceOp::size()"); - return OkStatus(); + return absl::OkStatus(); } StatusOr VerifySliceLayout( diff --git a/tensorflow/dtensor/mlir/expansions/softmax_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/softmax_spmd_expander.cc index 41f99410e16a71..8ccdf0ee8168ae 100644 --- a/tensorflow/dtensor/mlir/expansions/softmax_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/softmax_spmd_expander.cc @@ -141,7 +141,7 @@ Status ComputeExpAndSum(mlir::OpBuilder& builder, const mlir::Value& logits, ComputeGlobalReduce(builder, exp_of_shifted_logits, logits_layout, {class_dimension}, kReduceOpAdd, /*keep_dims=*/true)); - return OkStatus(); + return absl::OkStatus(); } // Computes softmax from its components. Assumes that builder's insertion point diff --git a/tensorflow/dtensor/mlir/expansions/strided_slice_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/strided_slice_spmd_expander.cc index 0a4104e2b736a0..506607fb35401b 100644 --- a/tensorflow/dtensor/mlir/expansions/strided_slice_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/strided_slice_spmd_expander.cc @@ -136,7 +136,7 @@ Status UpdateOpFromTokens(T strided_slice, mlir::Value new_end = IntConstWithMatchingType( builder, strided_slice.getLoc(), end, strided_slice.getBegin().getType()); strided_slice.getEndMutable().assign(new_end); - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/dtensor/mlir/layout_propagation_v2.cc b/tensorflow/dtensor/mlir/layout_propagation_v2.cc index adf9219a4fd922..a7f089eb1726ec 100644 --- a/tensorflow/dtensor/mlir/layout_propagation_v2.cc +++ b/tensorflow/dtensor/mlir/layout_propagation_v2.cc @@ -1359,7 +1359,7 @@ Status RunOneIteration( llvm::DenseMap>& consumers, llvm::DenseMap& merged_layouts, mlir::ModuleOp& module, int stage, int* steps) { - if (is_updated.empty()) return OkStatus(); + if (is_updated.empty()) return absl::OkStatus(); // Merge any possibly updated layouts. if (mlir::failed( MergeAndGetUpdatedLayouts(is_locked, is_updated, producer_request, @@ -1384,7 +1384,7 @@ Status RunOneIteration( return errors::Internal("UpdateLayoutsForOp failed to update layouts."); } ++(*steps); - return OkStatus(); + return absl::OkStatus(); } // Compares every value's layouts in `merged_a` with the ones in `merged_b`, @@ -1406,7 +1406,7 @@ Status CompareMergedLayouts(const llvm::DenseMap& merged_a, changed.insert(value); } } - return OkStatus(); + return absl::OkStatus(); } // MLIR pass that propagates layout for all ops the module. diff --git a/tensorflow/dtensor/mlir/shape_utils.cc b/tensorflow/dtensor/mlir/shape_utils.cc index 57b07683f67940..d0eda18a6bc94b 100644 --- a/tensorflow/dtensor/mlir/shape_utils.cc +++ b/tensorflow/dtensor/mlir/shape_utils.cc @@ -239,7 +239,7 @@ Status InferSPMDExpandedLocalShapeForResourceOutput( mlir::ArrayRef{local_variable_subtype}, context)); op_result->setType(new_var_type); } - return OkStatus(); + return absl::OkStatus(); } mlir::Operation* InferSPMDExpandedLocalShape(mlir::Operation* op) { diff --git a/tensorflow/dtensor/mlir/sparse_expander.cc b/tensorflow/dtensor/mlir/sparse_expander.cc index c471fbc1ed4241..92c5d616436423 100644 --- a/tensorflow/dtensor/mlir/sparse_expander.cc +++ b/tensorflow/dtensor/mlir/sparse_expander.cc @@ -64,7 +64,7 @@ Status RunSparseExpansion(mlir::Operation* op, mlir::Operation** output) { } else { // If there is no SparseTensor inputs then just return the op. *output = op; } - return OkStatus(); + return absl::OkStatus(); } } // namespace dtensor diff --git a/tensorflow/dtensor/mlir/sparse_expander_common.h b/tensorflow/dtensor/mlir/sparse_expander_common.h index e168e4f507e34d..c217e59e02b5b4 100644 --- a/tensorflow/dtensor/mlir/sparse_expander_common.h +++ b/tensorflow/dtensor/mlir/sparse_expander_common.h @@ -32,7 +32,7 @@ namespace dtensor { // SparseToDenseOp. If this value is eventually an output of a SparseToDenseOp, // there should only be DTensor related ops between the actual SparseToDenseOp, // e.g. DTensorRelayout ops or DTensorLayout op. -StatusOr GetSparseToDenseOp(mlir::Value value); +absl::StatusOr GetSparseToDenseOp(mlir::Value value); // Checks whether `value is an output of a SparseToDenseOp value. bool IsSparseValue(mlir::Value value); @@ -45,15 +45,15 @@ bool AllSparseInput(mlir::Operation* op); // Returns the indices component dense tensor from `value`. `value` represents // a SparseTensor value. -StatusOr GetIndicesFromSparseTensor(mlir::Value value); +absl::StatusOr GetIndicesFromSparseTensor(mlir::Value value); // Returns the values component dense tensor from `value`.`value` represents // a SparseTensor value. -StatusOr GetValuesFromSparseTensor(mlir::Value value); +absl::StatusOr GetValuesFromSparseTensor(mlir::Value value); // Returns the dense shape component dense tensor from `value`. `value` // represents a SparseTensor value. -StatusOr GetDenseShapesFromSparseTensor(mlir::Value value); +absl::StatusOr GetDenseShapesFromSparseTensor(mlir::Value value); } // namespace dtensor } // namespace tensorflow diff --git a/tensorflow/dtensor/mlir/spmd_expander.cc b/tensorflow/dtensor/mlir/spmd_expander.cc index ce6b34c7a004b5..4e63f87970777f 100644 --- a/tensorflow/dtensor/mlir/spmd_expander.cc +++ b/tensorflow/dtensor/mlir/spmd_expander.cc @@ -73,7 +73,7 @@ Status AdjustPartedLayout(const llvm::DenseMap& input_layouts, computed_layout.getSecond() = parted; } } - return OkStatus(); + return absl::OkStatus(); } // Returns whether DTensor should skip SPMD expansion because `op` uses parted @@ -168,7 +168,7 @@ Status SPMDExpanderBase::ExpandOpAndSetLayout(mlir::Operation* op, } SetLayoutOnOp(*output, absl::Span>( computed_layout.data(), computed_layout.size())); - return OkStatus(); + return absl::OkStatus(); } // `op` may be removed/replaced from the graph during SPMD expansion, so @@ -239,7 +239,7 @@ Status SPMDExpanderBase::ExpandOpAndSetLayout(mlir::Operation* op, } } - return OkStatus(); + return absl::OkStatus(); } StatusOr> SPMDExpanderBase::ComputeLayoutForward( @@ -299,7 +299,7 @@ Status RunSPMDExpansion(mlir::Operation* op, mlir::Operation** output) { VLOG(1) << "No expansion found for " << OpName(op) << "\n"; *output = op; } - return OkStatus(); + return absl::OkStatus(); } } // namespace dtensor diff --git a/tensorflow/dtensor/mlir/spmd_expander_common.cc b/tensorflow/dtensor/mlir/spmd_expander_common.cc index 7f411c4da74c73..b3f823ae7e9fc4 100644 --- a/tensorflow/dtensor/mlir/spmd_expander_common.cc +++ b/tensorflow/dtensor/mlir/spmd_expander_common.cc @@ -141,7 +141,7 @@ Status CreateSplitOp(const int num_split, const int split_dimension, llvm::SmallVector output_types(num_split, output_type); *split_op = builder->create( location, output_types, split_dimension_op.getOutput(), src_input); - return OkStatus(); + return absl::OkStatus(); } // Given layouts + shapes, determines if the two are broadcasting compatible. @@ -682,7 +682,7 @@ Status SetBuilderInsertionAfterValue(mlir::Value value, mlir::OpBuilder& builder) { if (value.isa()) { builder.setInsertionPointAfterValue(value); - return OkStatus(); + return absl::OkStatus(); } mlir::tf_device::ClusterOp cluster; for (mlir::Operation* op : value.getUsers()) { @@ -696,7 +696,7 @@ Status SetBuilderInsertionAfterValue(mlir::Value value, if (!cluster) return errors::Internal("value not used in any cluster"); builder.setInsertionPointToStart(cluster.SingleBlock::getBody()); - return OkStatus(); + return absl::OkStatus(); } Status PrintTensor(mlir::Value value, const std::string& format_string = "%s") { @@ -713,7 +713,7 @@ Status PrintTensor(mlir::Value value, const std::string& format_string = "%s") { builder.create(value.getLoc(), format.getOutput(), /*output_stream=*/"log(info)", /*end=*/"\n"); - return OkStatus(); + return absl::OkStatus(); } Status ExtractConstStringVectorFromValue( @@ -731,7 +731,7 @@ Status ExtractConstStringVectorFromValue( for (const auto& str : attr.getRawStringData()) { out_vector.push_back(str.str()); } - return OkStatus(); + return absl::OkStatus(); } StatusOr ExtractConstScalarStringFromValue(mlir::Value value) { diff --git a/tensorflow/dtensor/mlir/utils/update_tpu_metadata.cc b/tensorflow/dtensor/mlir/utils/update_tpu_metadata.cc index b55d22d229ac4f..802d46fd27ecde 100644 --- a/tensorflow/dtensor/mlir/utils/update_tpu_metadata.cc +++ b/tensorflow/dtensor/mlir/utils/update_tpu_metadata.cc @@ -179,7 +179,7 @@ Status UpdateMetadataProtoXlaSpmd(const Mesh& mesh_config, } *proto.mutable_device_assignment() = device_assignment; } - return OkStatus(); + return absl::OkStatus(); } Status UpdateMetadataProtoDtensorSpmd(const Mesh& mesh_config, @@ -238,7 +238,7 @@ Status UpdateMetadataProtoDtensorSpmd(const Mesh& mesh_config, } *proto.mutable_device_assignment() = device_assignment; } - return OkStatus(); + return absl::OkStatus(); } mlir::LogicalResult UpdateTPUCompileMetadata(const Mesh& mesh_config, diff --git a/tensorflow/dtensor/mlir/value_utils.cc b/tensorflow/dtensor/mlir/value_utils.cc index 206bc8615d7d5a..cfe56ff0e5f7b7 100644 --- a/tensorflow/dtensor/mlir/value_utils.cc +++ b/tensorflow/dtensor/mlir/value_utils.cc @@ -185,7 +185,7 @@ Status ExtractConstVectorFromValue(mlir::Value value, } for (const mlir::APInt& index : attr) out_vector->emplace_back(index.getSExtValue()); - return OkStatus(); + return absl::OkStatus(); } mlir::Value CreateIntScalarConst(const int64_t value, mlir::OpBuilder builder, diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index be2c5f2e74a066..6bfe5fc092ce66 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -8,7 +8,7 @@ load("//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "alias_with_tfli load("@bazel_skylib//:bzl_library.bzl", "bzl_library") package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:LICENSE"], default_visibility = ["//visibility:public"], licenses = ["notice"], ) @@ -645,6 +645,7 @@ cc_library_with_tflite( copts = tflite_copts() + tflite_copts_warnings(), generate_opaque_delegate_target = True, visibility = [ + "//research/drishti/benchmarking/async:__subpackages__", "//tensorflow/compiler/mlir/lite:__subpackages__", "//tensorflow/lite/core/kernels:__subpackages__", "//tensorflow/lite/core/shims:__subpackages__", diff --git a/tensorflow/lite/acceleration/configuration/BUILD b/tensorflow/lite/acceleration/configuration/BUILD index 6b6aab6638d033..375d73fa2c23a4 100644 --- a/tensorflow/lite/acceleration/configuration/BUILD +++ b/tensorflow/lite/acceleration/configuration/BUILD @@ -235,24 +235,28 @@ cc_library_with_tflite( deps = ["//tensorflow/lite/core/acceleration/configuration:nnapi_plugin"], ) -cc_library( - name = "hexagon_plugin", - srcs = ["hexagon_plugin.cc"], - deps = [ - ":configuration_fbs", - "//tensorflow/lite/core/acceleration/configuration:delegate_registry", - "@com_google_absl//absl/memory", - ] + select({ - "//third_party/bazel_platforms/cpu:aarch64": [ - "//tensorflow/lite/delegates/hexagon:hexagon_delegate", - ], - "//third_party/bazel_platforms/cpu:armv7": [ - "//tensorflow/lite/delegates/hexagon:hexagon_delegate", - ], - "//conditions:default": [], - }), - alwayslink = 1, # For registration to always run. -) +# Commented under b/279852433 because caused an error in the OSS +# TODO(zhurakovskyi): Uncomment when fixed. +# copybara:uncomment_begin +# cc_library( +# name = "hexagon_plugin", +# srcs = ["hexagon_plugin.cc"], +# deps = [ +# ":configuration_fbs", +# "@com_google_absl//absl/memory", +# "//tensorflow/lite/core/acceleration/configuration:delegate_registry", +# ] + select({ +# "//third_party/bazel_platforms/cpu:aarch64": [ +# "//tensorflow/lite/delegates/hexagon:hexagon_delegate", +# ], +# "//third_party/bazel_platforms/cpu:armv7": [ +# "//tensorflow/lite/delegates/hexagon:hexagon_delegate", +# ], +# "//conditions:default": [], +# }), +# alwayslink = 1, # For registration to always run. +# ) +# copybara:uncomment_end cc_library( name = "gpu_plugin", diff --git a/tensorflow/lite/arena_planner.cc b/tensorflow/lite/arena_planner.cc index 8fd1a794369b50..4f21713e84d3a6 100644 --- a/tensorflow/lite/arena_planner.cc +++ b/tensorflow/lite/arena_planner.cc @@ -141,6 +141,10 @@ bool ArenaPlanner::InputTensorCanBeShared(const TfLiteTensor& input_tensor, input_allocation_type != kTfLiteArenaRw) { return false; } + if (preserve_all_tensors_) { + return false; + } + return true; } diff --git a/tensorflow/lite/arena_planner_test.cc b/tensorflow/lite/arena_planner_test.cc index 2021ac0797654c..9d6c896cffa5a6 100644 --- a/tensorflow/lite/arena_planner_test.cc +++ b/tensorflow/lite/arena_planner_test.cc @@ -1074,6 +1074,33 @@ TEST_F(ArenaPlannerTest, DebugTensors) { EXPECT_EQ(tensorOffsets.size(), 8); } +TEST_F(ArenaPlannerTest, DebugTensorsInputReuse) { + TestGraph graph({0, 1}, + { + /* in, out, tmp */ + {{0, 1}, {2, 3}, {}}, + {{2, 3}, {4}, {}, kTfLiteBuiltinMul}, + {{4, 2}, {5}, {}, kTfLiteBuiltinSub}, + {{5}, {6}, {}}, + }, + {6}); + + (*graph.tensors())[4].bytes = 200; + (*graph.tensors())[5].bytes = 200; + + SetGraph(&graph, /*preserve_all_tensors=*/false); + Execute(0, graph.nodes().size() - 1); + + // Output of mul node should be reused for output of sub node. + EXPECT_EQ(GetOffset(4), GetOffset(5)); + + SetGraph(&graph, /*preserve_all_tensors=*/true); + Execute(0, graph.nodes().size() - 1); + + // Output of mul node should not be reused for output of sub node. + EXPECT_NE(GetOffset(4), GetOffset(5)); +} + TEST_F(ArenaPlannerTest, SimpleProfilerTest) { gNumAlloc = 0; gNumDealloc = 0; diff --git a/tensorflow/lite/core/BUILD b/tensorflow/lite/core/BUILD index f405e37a47bf30..e5ace050bb9a10 100644 --- a/tensorflow/lite/core/BUILD +++ b/tensorflow/lite/core/BUILD @@ -1,5 +1,5 @@ -load("//tensorflow/lite:build_def.bzl", "tflite_copts", "tflite_copts_warnings", "tflite_self_contained_libs_test_suite") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") +load("//tensorflow/lite:build_def.bzl", "tflite_copts", "tflite_copts_warnings", "tflite_self_contained_libs_test_suite") load("//tensorflow/lite:special_rules.bzl", "internal_visibility_allowlist", "tflite_portable_test_suite") load("//tensorflow/lite/core:special_rules.bzl", "macros_visibility_allowlist") load("@bazel_skylib//:bzl_library.bzl", "bzl_library") @@ -170,6 +170,7 @@ cc_library( ], compatible_with = get_compatible_with_portable(), visibility = [ + "//research/drishti/benchmarking/async:__subpackages__", "//tensorflow/lite:__subpackages__", ], deps = [ @@ -515,6 +516,7 @@ cc_test( "//tensorflow/lite:util", "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/kernels:builtin_ops", # build_cleaner: keep + "@com_google_absl//absl/log:check", "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/lite/core/acceleration/configuration/c/BUILD b/tensorflow/lite/core/acceleration/configuration/c/BUILD index af5a8ad252af7f..6972ec631b3d44 100644 --- a/tensorflow/lite/core/acceleration/configuration/c/BUILD +++ b/tensorflow/lite/core/acceleration/configuration/c/BUILD @@ -194,43 +194,48 @@ tflite_cc_library_with_c_headers_test( ], ) -# This rule invokes the "flatcc" FlatBuffer C API compiler to generate the sources -# use by the ":configuration_c_fbs" C library rule below. -genrule( - name = "configuration_c_fbs_gen", - srcs = ["//tensorflow/lite/acceleration/configuration:configuration.fbs"], - outs = [ - "configuration_builder.h", - "configuration_reader.h", - ], - cmd = "$(location //third_party/flatcc:flatcc) -o$(RULEDIR) --builder --reader $(SRCS)", - tools = ["//third_party/flatcc"], - # Currently this only enables the API for _building_ configuration flatbuffer objects, - # not the APIs for reading them, verifying them, or converting them to/from JSON. - # [If you need to enable those, replace the two lines above with the following - # outs = ["configuration_builder.h", "configuration_reader.h", "configuration_verifier.h", - # "configuration_json_parser.h", "configuration_json_printer.h"], - # cmd = "$(location //third_party/flatcc:flatcc) -o$(RULEDIR) " + - # "--builder --reader --verifier --json $(SRCS)", - # and then in the rule below -- or preferably in a separate target -- - # add the additional header files in "hdrs" and fix the dependencies.] -) - -# This rule defines a C library containing the Flatbuffer-generated C API for constructing objects -# using the FlatBuffer schema generated from tensorflow/lite/acceleration/configuration/configuration.proto, -# which defines the 'TFLiteSettings' FlatBuffer table and related types. -tflite_cc_library_with_c_headers_test( - name = "configuration_c_fbs", - hdrs = [ - "configuration_builder.h", - "configuration_reader.h", - ], - deps = ["//third_party/flatcc:runtime"], -) - -build_test( - name = "configuration_c_fbs_build_test", - targets = [ - ":configuration_c_fbs", - ], -) +# Commented out under the (b/279852433) because caused an error in the OSS +# TODO(zhurakovskyi): Uncomment when fixed. +# +# copybara:uncomment_begin +# # This rule invokes the "flatcc" FlatBuffer C API compiler to generate the sources +# # use by the ":configuration_c_fbs" C library rule below. +# genrule( +# name = "configuration_c_fbs_gen", +# srcs = ["//tensorflow/lite/acceleration/configuration:configuration.fbs"], +# outs = [ +# "configuration_builder.h", +# "configuration_reader.h", +# ], +# cmd = "$(location //third_party/flatcc:flatcc) -o$(RULEDIR) --builder --reader $(SRCS)", +# tools = ["//third_party/flatcc"], +# # Currently this only enables the API for _building_ configuration flatbuffer objects, +# # not the APIs for reading them, verifying them, or converting them to/from JSON. +# # [If you need to enable those, replace the two lines above with the following +# # outs = ["configuration_builder.h", "configuration_reader.h", "configuration_verifier.h", +# # "configuration_json_parser.h", "configuration_json_printer.h"], +# # cmd = "$(location //third_party/flatcc:flatcc) -o$(RULEDIR) " + +# # "--builder --reader --verifier --json $(SRCS)", +# # and then in the rule below -- or preferably in a separate target -- +# # add the additional header files in "hdrs" and fix the dependencies.] +# ) +# +# # This rule defines a C library containing the Flatbuffer-generated C API for constructing objects +# # using the FlatBuffer schema generated from tensorflow/lite/acceleration/configuration/configuration.proto, +# # which defines the 'TFLiteSettings' FlatBuffer table and related types. +# tflite_cc_library_with_c_headers_test( +# name = "configuration_c_fbs", +# hdrs = [ +# "configuration_builder.h", +# "configuration_reader.h", +# ], +# deps = ["//third_party/flatcc:runtime"], +# ) +# +# build_test( +# name = "configuration_c_fbs_build_test", +# targets = [ +# ":configuration_c_fbs", +# ], +# ) +# copybara:uncomment_end diff --git a/tensorflow/lite/core/async/c/async_kernel.h b/tensorflow/lite/core/async/c/async_kernel.h index 57ed576a34c456..1b3c76acee6324 100644 --- a/tensorflow/lite/core/async/c/async_kernel.h +++ b/tensorflow/lite/core/async/c/async_kernel.h @@ -20,6 +20,7 @@ limitations under the License. // for documentation. #include +#include #include "tensorflow/lite/core/async/c/types.h" #include "tensorflow/lite/core/async/interop/c/attribute_map.h" diff --git a/tensorflow/lite/core/c/registration_external.h b/tensorflow/lite/core/c/registration_external.h index 897e321875d7ea..5561bc474d0079 100644 --- a/tensorflow/lite/core/c/registration_external.h +++ b/tensorflow/lite/core/c/registration_external.h @@ -19,6 +19,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_CORE_C_REGISTRATION_EXTERNAL_H_ #define TENSORFLOW_LITE_CORE_C_REGISTRATION_EXTERNAL_H_ +#include #include #include "tensorflow/lite/builtin_ops.h" diff --git a/tensorflow/lite/core/interpreter.cc b/tensorflow/lite/core/interpreter.cc index 5c2917e8be9f24..d794902490649c 100644 --- a/tensorflow/lite/core/interpreter.cc +++ b/tensorflow/lite/core/interpreter.cc @@ -492,14 +492,6 @@ TfLiteStatus Interpreter::ApplyOptionsImpl(InterpreterOptions* options) { for (auto& subgraph : subgraphs_) { subgraph->SetOptions(options_.get()); } - - // Handle `experimental_dynamic_allocation_for_large_tensors_`. - if (options->GetDynamicAllocationForLargeTensors() > 0) { - for (auto& subgraph : subgraphs_) { - subgraph->OptimizeMemoryForLargeTensors( - options->GetDynamicAllocationForLargeTensors()); - } - } return kTfLiteOk; } diff --git a/tensorflow/lite/core/interpreter_builder.cc b/tensorflow/lite/core/interpreter_builder.cc index fc60df69278986..6f225aecd08e38 100644 --- a/tensorflow/lite/core/interpreter_builder.cc +++ b/tensorflow/lite/core/interpreter_builder.cc @@ -365,14 +365,13 @@ TfLiteStatus InterpreterBuilder::ParseNodes( EnumNameBuiltinOperator(op_type)); } + void* builtin_data = nullptr; + const char* init_data = nullptr; + size_t init_data_size = 0; if (op_type == BuiltinOperator_CUSTOM) { if (op->custom_options()) { - subgraph->AddNodeWithParameters( - FlatBufferIntArrayToVector(op->inputs()), - FlatBufferIntArrayToVector(op->outputs()), - FlatBufferIntArrayToVector(op->intermediates()), - reinterpret_cast(op->custom_options()->data()), - op->custom_options()->size(), nullptr, registration); + init_data = reinterpret_cast(op->custom_options()->data()); + init_data_size = op->custom_options()->size(); } else if (op->large_custom_options_offset() > 1 && allocation_) { if (op->large_custom_options_offset() + op->large_custom_options_size() > @@ -384,31 +383,20 @@ TfLiteStatus InterpreterBuilder::ParseNodes( return kTfLiteError; } // If the custom op is storing payloads outside of flatbuffers - subgraph->AddNodeWithParameters( - FlatBufferIntArrayToVector(op->inputs()), - FlatBufferIntArrayToVector(op->outputs()), - FlatBufferIntArrayToVector(op->intermediates()), - reinterpret_cast(allocation_->base()) + - op->large_custom_options_offset(), - op->large_custom_options_size(), nullptr, registration); - } else { - subgraph->AddNodeWithParameters( - FlatBufferIntArrayToVector(op->inputs()), - FlatBufferIntArrayToVector(op->outputs()), - FlatBufferIntArrayToVector(op->intermediates()), nullptr, 0, - nullptr, registration); + init_data = reinterpret_cast(allocation_->base()) + + op->large_custom_options_offset(); + init_data_size = op->large_custom_options_size(); } } else { - void* builtin_data = nullptr; MallocDataAllocator malloc_allocator; TF_LITE_ENSURE_STATUS(ParseOpData(op, op_type, error_reporter_, &malloc_allocator, &builtin_data)); - subgraph->AddNodeWithParameters( - FlatBufferIntArrayToVector(op->inputs()), - FlatBufferIntArrayToVector(op->outputs()), - FlatBufferIntArrayToVector(op->intermediates()), nullptr, 0, - builtin_data, registration); } + subgraph->AddNodeWithParameters( + FlatBufferIntArrayToVector(op->inputs()), + FlatBufferIntArrayToVector(op->outputs()), + FlatBufferIntArrayToVector(op->intermediates()), init_data, + init_data_size, builtin_data, registration); } return status; diff --git a/tensorflow/lite/core/kernels/BUILD b/tensorflow/lite/core/kernels/BUILD index 81d643a6759004..f200394d6123de 100644 --- a/tensorflow/lite/core/kernels/BUILD +++ b/tensorflow/lite/core/kernels/BUILD @@ -12,7 +12,10 @@ exports_files( "builtin_op_kernels.h", "register.h", ], - visibility = ["//tensorflow/lite:__subpackages__"], + visibility = [ + "//research/drishti/benchmarking/async:__subpackages__", + "//tensorflow/lite:__subpackages__", + ], ) cc_test( diff --git a/tensorflow/lite/core/kernels/register.cc b/tensorflow/lite/core/kernels/register.cc index 0e3eacf4d65017..a20ef3540fb1e9 100644 --- a/tensorflow/lite/core/kernels/register.cc +++ b/tensorflow/lite/core/kernels/register.cc @@ -82,7 +82,7 @@ BuiltinOpResolver::BuiltinOpResolver() { Register_EMBEDDING_LOOKUP_SPARSE()); AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED(), /* min_version = */ 1, - /* max_version = */ 11); + /* max_version = */ 12); AddBuiltin(BuiltinOperator_LSH_PROJECTION, Register_LSH_PROJECTION()); AddBuiltin(BuiltinOperator_HASHTABLE_LOOKUP, Register_HASHTABLE_LOOKUP()); AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX(), diff --git a/tensorflow/lite/core/subgraph.h b/tensorflow/lite/core/subgraph.h index 5de491609ccdb6..d01cd1b899c000 100644 --- a/tensorflow/lite/core/subgraph.h +++ b/tensorflow/lite/core/subgraph.h @@ -430,7 +430,17 @@ class Subgraph { // WARNING: This is an experimental API and subject to change. // Set the given `InterpreterOptions` object. - void SetOptions(InterpreterOptions* options) { options_ = options; } + void SetOptions(InterpreterOptions* options) { + options_ = options; + if (options && options->GetDynamicAllocationForLargeTensors() > 0) { + // Note: this operation cannot be reversed. + OptimizeMemoryForLargeTensors( + options->GetDynamicAllocationForLargeTensors()); + } + } + + // WARNING: This is an experimental API and subject to change. + const InterpreterOptions* GetOptions() const { return options_; } // WARNING: This is an experimental API and subject to change. // True if all intermediates tensors should be preserved for debugging. diff --git a/tensorflow/lite/core/subgraph_test.cc b/tensorflow/lite/core/subgraph_test.cc index 7515fcc0ca0aed..9c169c6ea256e3 100644 --- a/tensorflow/lite/core/subgraph_test.cc +++ b/tensorflow/lite/core/subgraph_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include +#include "absl/log/check.h" #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/core/interpreter.h" #include "tensorflow/lite/stderr_reporter.h" diff --git a/tensorflow/lite/delegates/coreml/builders/BUILD b/tensorflow/lite/delegates/coreml/builders/BUILD index 730a47ef29354a..3b1991dc4aba30 100644 --- a/tensorflow/lite/delegates/coreml/builders/BUILD +++ b/tensorflow/lite/delegates/coreml/builders/BUILD @@ -76,7 +76,7 @@ cc_test( deps = [ ":util", "//tensorflow/lite/core/c:common", - "//third_party/eigen3", "@com_google_googletest//:gtest_main", + "@eigen_archive//:eigen3", ], ) diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD index bd47abe905729c..d6eb080e6a74e9 100644 --- a/tensorflow/lite/delegates/gpu/BUILD +++ b/tensorflow/lite/delegates/gpu/BUILD @@ -264,6 +264,7 @@ cc_library( "//tensorflow/lite:kernel_api", "//tensorflow/lite:minimal_logging", "//tensorflow/lite/async:backend_async_kernel_interface", + "//tensorflow/lite/core/async/interop/c:types", "//tensorflow/lite/core/c:common", "//tensorflow/lite/delegates:serialization", "//tensorflow/lite/delegates/gpu/cl:api", @@ -278,6 +279,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/lite/delegates/gpu/cl/cl_operation.cc b/tensorflow/lite/delegates/gpu/cl/cl_operation.cc index 5325d834d49dee..1cc1738d071d44 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_operation.cc +++ b/tensorflow/lite/delegates/gpu/cl/cl_operation.cc @@ -24,10 +24,6 @@ namespace { std::string GetCommonOpenCLDefines(CalculationsPrecision precision) { std::string result; - result += "#define FLT16_0123(V) V.s0123\n"; - result += "#define FLT16_4567(V) V.s4567\n"; - result += "#define FLT16_89ab(V) V.s89ab\n"; - result += "#define FLT16_cdef(V) V.scdef\n"; result += "#define GLOBAL_ID_0 get_global_id(0)\n"; result += "#define GLOBAL_ID_1 get_global_id(1)\n"; result += "#define GLOBAL_ID_2 get_global_id(2)\n"; @@ -165,9 +161,9 @@ absl::Status ClOperation::SetDstTensor(int index, Tensor* tensor) { absl::Status ClOperation::Compile(const CreationContext& creation_context) { operation_->code_ = GetCommonOpenCLDefines(operation_->GetPrecision()) + operation_->code_; - RETURN_IF_ERROR(cl_args_.Init( - creation_context.GetGpuInfo(), - creation_context.context, &operation_->args_, &operation_->code_)); + RETURN_IF_ERROR(cl_args_.Init(creation_context.GetGpuInfo(), + creation_context.context, &operation_->args_, + &operation_->code_)); operation_->args_.ReleaseCPURepresentation(); RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel( operation_->code_, "main_function", operation_->compiler_options_, diff --git a/tensorflow/lite/delegates/gpu/cl/cl_program.cc b/tensorflow/lite/delegates/gpu/cl/cl_program.cc index 856896b667ba71..63e1837d17070c 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_program.cc +++ b/tensorflow/lite/delegates/gpu/cl/cl_program.cc @@ -79,6 +79,9 @@ std::string CompilerOptionToString(const GpuInfo& gpu_info, CompilerOptions option) { switch (option) { case CompilerOptions::kAdrenoFullSimd: + if (gpu_info.opencl_info.IsCLVK()) { + return ""; + } if (gpu_info.IsAdreno()) { if (gpu_info.adreno_info.IsAdreno3xx() || gpu_info.adreno_info.IsAdreno4xx()) { @@ -90,6 +93,9 @@ std::string CompilerOptionToString(const GpuInfo& gpu_info, return "unsupported"; } case CompilerOptions::kAdrenoMoreWaves: + if (gpu_info.opencl_info.IsCLVK()) { + return ""; + } if (gpu_info.IsAdreno()) { if (!(gpu_info.adreno_info.IsAdreno3xx() || gpu_info.adreno_info.IsAdreno4xx())) { diff --git a/tensorflow/lite/delegates/gpu/common/gpu_info.cc b/tensorflow/lite/delegates/gpu/common/gpu_info.cc index 30489a1721f9d8..3b1c264e786f48 100644 --- a/tensorflow/lite/delegates/gpu/common/gpu_info.cc +++ b/tensorflow/lite/delegates/gpu/common/gpu_info.cc @@ -397,6 +397,7 @@ AppleInfo::AppleInfo(const std::string& gpu_description) { {"apple a14 gpu", AppleGpu::kA14}, {"apple a15 gpu", AppleGpu::kA15}, {"apple a16 gpu", AppleGpu::kA16}, + {"apple a17 pro gpu", AppleGpu::kA17Pro}, // on tablets we have metal device name "apple m1 gpu" // and on notebooks "apple m1" {"apple m1 gpu", AppleGpu::kM1}, @@ -412,15 +413,82 @@ AppleInfo::AppleInfo(const std::string& gpu_description) { } else { gpu_type = AppleGpu::kUnknown; } + gpu_family = GetGpuFamily(); } -bool AppleInfo::IsA7GenerationGpu() const { return gpu_type == AppleGpu::kA7; } -bool AppleInfo::IsA8GenerationGpu() const { - return gpu_type == AppleGpu::kA8 || gpu_type == AppleGpu::kA8X; +AppleInfo::Family AppleInfo::GetGpuFamily() const { + if (gpu_type == AppleGpu::kA7) { + return AppleInfo::Family::kApple1; + } else if (gpu_type == AppleGpu::kA8 || gpu_type == AppleGpu::kA8X) { + return AppleInfo::Family::kApple2; + } else if (gpu_type == AppleGpu::kA9 || gpu_type == AppleGpu::kA9X || + gpu_type == AppleGpu::kA10 || gpu_type == AppleGpu::kA10X) { + return AppleInfo::Family::kApple3; + } else if (gpu_type == AppleGpu::kA11) { + return AppleInfo::Family::kApple4; + } else if (gpu_type == AppleGpu::kA12 || gpu_type == AppleGpu::kA12X || + gpu_type == AppleGpu::kA12Z) { + return AppleInfo::Family::kApple5; + } else if (gpu_type == AppleGpu::kA13) { + return AppleInfo::Family::kApple6; + } else if (gpu_type == AppleGpu::kA14 || IsM1Series()) { + return AppleInfo::Family::kApple7; + } else if (gpu_type == AppleGpu::kA15 || gpu_type == AppleGpu::kA16 || + gpu_type == AppleGpu::kM2) { + return AppleInfo::Family::kApple8; + } else if (gpu_type == AppleGpu::kA17Pro) { + return AppleInfo::Family::kApple9; + } + return AppleInfo::Family::kApple1; +} + +bool AppleInfo::IsFamilyApple1() const { + return gpu_family == AppleInfo::Family::kApple1; +} + +bool AppleInfo::IsFamilyApple2() const { + return gpu_family == AppleInfo::Family::kApple2; +} + +bool AppleInfo::IsFamilyApple3() const { + return gpu_family == AppleInfo::Family::kApple3; +} + +bool AppleInfo::IsFamilyApple4() const { + return gpu_family == AppleInfo::Family::kApple4; +} + +bool AppleInfo::IsFamilyApple5() const { + return gpu_family == AppleInfo::Family::kApple5; +} + +bool AppleInfo::IsFamilyApple6() const { + return gpu_family == AppleInfo::Family::kApple6; +} + +bool AppleInfo::IsFamilyApple7() const { + return gpu_family == AppleInfo::Family::kApple7; +} + +bool AppleInfo::IsFamilyApple8() const { + return gpu_family == AppleInfo::Family::kApple8; +} + +bool AppleInfo::IsFamilyApple9() const { + return gpu_family == AppleInfo::Family::kApple9; +} + +bool AppleInfo::IsFamilyOrLower(AppleInfo::Family family) const { + return gpu_family <= family; } bool AppleInfo::IsLocalMemoryPreferredOverGlobal() const { - return IsA7GenerationGpu() || IsA8GenerationGpu(); + return IsFamilyOrLower(AppleInfo::Family::kApple2); +} + +bool AppleInfo::IsM1Series() const { + return gpu_type == AppleGpu::kM1 || gpu_type == AppleGpu::kM1Pro || + gpu_type == AppleGpu::kM1Max || gpu_type == AppleGpu::kM1Ultra; } bool AppleInfo::IsBionic() const { @@ -428,21 +496,22 @@ bool AppleInfo::IsBionic() const { gpu_type == AppleGpu::kA12X || gpu_type == AppleGpu::kA12Z || gpu_type == AppleGpu::kA13 || gpu_type == AppleGpu::kA14 || gpu_type == AppleGpu::kA15 || gpu_type == AppleGpu::kA16 || - gpu_type == AppleGpu::kM1 || gpu_type == AppleGpu::kM1Pro || - gpu_type == AppleGpu::kM1Max || gpu_type == AppleGpu::kM1Ultra || - gpu_type == AppleGpu::kM2; + gpu_type == AppleGpu::kA17Pro || gpu_type == AppleGpu::kM1 || + gpu_type == AppleGpu::kM1Pro || gpu_type == AppleGpu::kM1Max || + gpu_type == AppleGpu::kM1Ultra || gpu_type == AppleGpu::kM2; } bool AppleInfo::IsSIMDMatMulSupported() const { return gpu_type == AppleGpu::kA14 || gpu_type == AppleGpu::kA15 || - gpu_type == AppleGpu::kA16 || gpu_type == AppleGpu::kM1 || - gpu_type == AppleGpu::kM1Pro || gpu_type == AppleGpu::kM1Max || - gpu_type == AppleGpu::kM1Ultra || gpu_type == AppleGpu::kM2; + gpu_type == AppleGpu::kA16 || gpu_type == AppleGpu::kA17Pro || + gpu_type == AppleGpu::kM1 || gpu_type == AppleGpu::kM1Pro || + gpu_type == AppleGpu::kM1Max || gpu_type == AppleGpu::kM1Ultra || + gpu_type == AppleGpu::kM2; } bool AppleInfo::IsSIMDMatMulFp32Perf2x() const { return gpu_type == AppleGpu::kA15 || gpu_type == AppleGpu::kA16 || - gpu_type == AppleGpu::kM2; + gpu_type == AppleGpu::kA17Pro || gpu_type == AppleGpu::kM2; } bool AppleInfo::IsRoundToNearestSupported() const { return IsBionic(); } @@ -484,6 +553,8 @@ int AppleInfo::GetComputeUnitsCount() const { return 5; case AppleGpu::kA16: return 5; + case AppleGpu::kA17Pro: + return 6; case AppleGpu::kM1: // approximate, can be 7 or 8 return 8; diff --git a/tensorflow/lite/delegates/gpu/common/gpu_info.h b/tensorflow/lite/delegates/gpu/common/gpu_info.h index 4bf986b7a60e7e..eeb6c33c312de8 100644 --- a/tensorflow/lite/delegates/gpu/common/gpu_info.h +++ b/tensorflow/lite/delegates/gpu/common/gpu_info.h @@ -172,6 +172,7 @@ enum class AppleGpu { kA14, kA15, kA16, + kA17Pro, kM1, kM1Pro, kM1Max, @@ -180,15 +181,39 @@ enum class AppleGpu { }; struct AppleInfo { + // https://developer.apple.com/documentation/metal/mtlgpufamily + enum class Family { + kApple9 = 9, + kApple8 = 8, + kApple7 = 7, + kApple6 = 6, + kApple5 = 5, + kApple4 = 4, + kApple3 = 3, + kApple2 = 2, + kApple1 = 1, + }; AppleInfo() = default; explicit AppleInfo(const std::string& gpu_description); AppleGpu gpu_type; + Family gpu_family; + + bool IsFamilyApple1() const; + bool IsFamilyApple2() const; + bool IsFamilyApple3() const; + bool IsFamilyApple4() const; + bool IsFamilyApple5() const; + bool IsFamilyApple6() const; + bool IsFamilyApple7() const; + bool IsFamilyApple8() const; + bool IsFamilyApple9() const; + + bool IsFamilyOrLower(Family family) const; - bool IsA7GenerationGpu() const; - bool IsA8GenerationGpu() const; bool IsLocalMemoryPreferredOverGlobal() const; bool IsBionic() const; + bool IsM1Series() const; bool IsSIMDMatMulSupported() const; // Often, fp32 alu performance is 1/2 of fp16 alu performance @@ -206,6 +231,7 @@ struct AppleInfo { void SetComputeUnits(int compute_units_count); private: + Family GetGpuFamily() const; int compute_units = -1; }; diff --git a/tensorflow/lite/delegates/gpu/common/gpu_model.cc b/tensorflow/lite/delegates/gpu/common/gpu_model.cc index ac0a7b455f7052..33301f407f9430 100644 --- a/tensorflow/lite/delegates/gpu/common/gpu_model.cc +++ b/tensorflow/lite/delegates/gpu/common/gpu_model.cc @@ -248,9 +248,7 @@ absl::Status ReserveGraphTensors(const CreateGpuModelInfo& create_info, tensor_desc.UpdateToSupportedStorageType(gpu_info, shape)); if (gpu_info.IsApiMetal() && storage_type == TensorStorageType::TEXTURE_2D) { - const bool a7_gen_gpu = - gpu_info.IsApple() && gpu_info.apple_info.IsA7GenerationGpu(); - if (!a7_gen_gpu) { + if (!(gpu_info.IsApple() && gpu_info.apple_info.IsFamilyApple1())) { tensor_desc.SetUseBufferForWriteOnlyTexture2d(true); } } diff --git a/tensorflow/lite/delegates/gpu/common/task/buffer_desc.cc b/tensorflow/lite/delegates/gpu/common/task/buffer_desc.cc index 56e1ce74d62fab..43b3d0df9a57fb 100644 --- a/tensorflow/lite/delegates/gpu/common/task/buffer_desc.cc +++ b/tensorflow/lite/delegates/gpu/common/task/buffer_desc.cc @@ -20,7 +20,6 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "absl/strings/substitute.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/task/util.h" @@ -91,30 +90,9 @@ absl::Status BufferDescriptor::PerformReadSelector( " % 2 == 0 ? 0 : 2]), unpackHalf2x16(buffer[", arg0, " / 2][", arg0, " % 2 == 0 ? 1 : 3]))"); } else { - if (element_size == 4) { - *result = - absl::StrCat("vec4(unpackHalf2x16(buffer[", args[0], - "].x), unpackHalf2x16(buffer[", args[0], "].y))"); - } else if (element_size == 16) { - const std::string vec0 = absl::Substitute( - "vec4(unpackHalf2x16(buffer[$0].a.x), " - "unpackHalf2x16(buffer[$0].a.y))", - args[0]); - const std::string vec1 = absl::Substitute( - "vec4(unpackHalf2x16(buffer[$0].a.z), " - "unpackHalf2x16(buffer[$0].a.w))", - args[0]); - const std::string vec2 = absl::Substitute( - "vec4(unpackHalf2x16(buffer[$0].b.x), " - "unpackHalf2x16(buffer[$0].b.y))", - args[0]); - const std::string vec3 = absl::Substitute( - "vec4(unpackHalf2x16(buffer[$0].b.z), " - "unpackHalf2x16(buffer[$0].b.w))", - args[0]); - *result = absl::Substitute("mat4x4($0, $1, $2, $3)", vec0, vec1, vec2, - vec3); - } + *result = + absl::StrCat("vec4(unpackHalf2x16(buffer[", args[0], + "].x), unpackHalf2x16(buffer[", args[0], "].y))"); } } else { *result = absl::StrCat("buffer[", args[0], "]"); diff --git a/tensorflow/lite/delegates/gpu/common/task/tensor_desc.cc b/tensorflow/lite/delegates/gpu/common/task/tensor_desc.cc index b3d342d5f71474..66b2d7bfcb3480 100644 --- a/tensorflow/lite/delegates/gpu/common/task/tensor_desc.cc +++ b/tensorflow/lite/delegates/gpu/common/task/tensor_desc.cc @@ -1590,11 +1590,9 @@ TensorDescriptor CreateHwcTensorDescriptor(DataType data_type, TensorStorageType GetStorageTypeForLinearTensor(const GpuInfo& gpu_info, DataType data_type, const Linear& shape) { - if (gpu_info.IsApple()) { - if (gpu_info.apple_info.IsA7GenerationGpu() || - gpu_info.apple_info.IsA8GenerationGpu()) { - return TensorStorageType::TEXTURE_2D; - } + if (gpu_info.IsApple() && + gpu_info.apple_info.IsFamilyOrLower(AppleInfo::Family::kApple2)) { + return TensorStorageType::TEXTURE_2D; } if (!gpu_info.SupportsImages() || gpu_info.IsMali() || gpu_info.IsApple() || gpu_info.IsAMD()) { diff --git a/tensorflow/lite/delegates/gpu/common/tasks/conv_generic.cc b/tensorflow/lite/delegates/gpu/common/tasks/conv_generic.cc index e49f4a8cb8e5a5..c2a1bd8f539512 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/conv_generic.cc +++ b/tensorflow/lite/delegates/gpu/common/tasks/conv_generic.cc @@ -1293,7 +1293,7 @@ ConvGeneric::ConvParams GetConvParamsForA7A8(const AppleInfo& apple_info, options.push_back(CreateWorkGroupSizeOption( {8, 4, 1}, WorkGroupSizeOption::ThreadMapping::kDefault, 1.0f, dst_shape, params.block_size)); - if (!apple_info.IsA7GenerationGpu()) { + if (!apple_info.IsFamilyApple1()) { options.push_back(CreateWorkGroupSizeOption( {4, 4, 1}, WorkGroupSizeOption::ThreadMapping::kDefault, 1.01f, dst_shape, params.block_size)); @@ -1304,7 +1304,7 @@ ConvGeneric::ConvParams GetConvParamsForA7A8(const AppleInfo& apple_info, options.push_back(CreateWorkGroupSizeOption( {32, 1, 1}, WorkGroupSizeOption::ThreadMapping::kLinearSpatial, 1.0f, dst_shape, params.block_size)); - if (!apple_info.IsA7GenerationGpu()) { + if (!apple_info.IsFamilyApple1()) { options.push_back(CreateWorkGroupSizeOption( {16, 1, 1}, WorkGroupSizeOption::ThreadMapping::kLinearSpatial, 1.01f, dst_shape, params.block_size)); diff --git a/tensorflow/lite/delegates/gpu/common/tasks/convolution_transposed.cc b/tensorflow/lite/delegates/gpu/common/tasks/convolution_transposed.cc index 4b96d3bfbfc002..481121d441da44 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/convolution_transposed.cc +++ b/tensorflow/lite/delegates/gpu/common/tasks/convolution_transposed.cc @@ -140,7 +140,7 @@ std::string ConvolutionTransposed::GenerateConvolutionTransposedCode( weights_layout_ == WeightsLayout::kOSpatialIOGroupO4I4) { BufferDescriptor desc; desc.element_type = op_def.src_tensors[1].GetDataType(); - desc.element_size = 16; + desc.element_size = 4; desc.memory_type = MemoryType::GLOBAL; AddSrcBuffer("weights", desc); } else { @@ -160,15 +160,15 @@ std::string ConvolutionTransposed::GenerateConvolutionTransposedCode( std::string f0, f1, f2, f3; if (weights_are_buffer) { if (gpu_info.SupportsPointersInKernels()) { - f0 = "FLT16_0123(weights_cache[" + std::to_string(s) + "])"; - f1 = "FLT16_4567(weights_cache[" + std::to_string(s) + "])"; - f2 = "FLT16_89ab(weights_cache[" + std::to_string(s) + "])"; - f3 = "FLT16_cdef(weights_cache[" + std::to_string(s) + "])"; + f0 = "weights_cache[" + std::to_string(s * 4 + 0) + "]"; + f1 = "weights_cache[" + std::to_string(s * 4 + 1) + "]"; + f2 = "weights_cache[" + std::to_string(s * 4 + 2) + "]"; + f3 = "weights_cache[" + std::to_string(s * 4 + 3) + "]"; } else { - f0 = "FLT16_0123(flt16val)"; - f1 = "FLT16_4567(flt16val)"; - f2 = "FLT16_89ab(flt16val)"; - f3 = "FLT16_cdef(flt16val)"; + f0 = "f0"; + f1 = "f1"; + f2 = "f2"; + f3 = "f3"; } } else { f0 = "f" + std::to_string(s * 4 + 0); @@ -250,16 +250,6 @@ std::string ConvolutionTransposed::GenerateConvolutionTransposedCode( return check; }; - switch (op_def.precision) { - case CalculationsPrecision::F32: - c += "#define FLT16 float16\n"; - break; - case CalculationsPrecision::F32_F16: - case CalculationsPrecision::F16: - c += "#define FLT16 half16\n"; - break; - } - c += "MAIN_FUNCTION($0) {\n"; if (op_def.IsBatchSupported()) { c += " int linear_id = GLOBAL_ID_0;\n"; @@ -300,7 +290,7 @@ std::string ConvolutionTransposed::GenerateConvolutionTransposedCode( if (src_def.HasAxis(Axis::DEPTH)) { c += " * args.kernel_size_z"; } - c += ";\n"; + c += " * 4;\n"; } for (int s = 0; s < block_size.w; ++s) { const std::string sind = std::to_string(s); @@ -444,7 +434,7 @@ std::string ConvolutionTransposed::GenerateConvolutionTransposedCode( if (weights_are_buffer) { c += " int f_offset = f_base + kernel_index * " "args.src_tensor.Slices() * " + - std::to_string(block_size.w) + ";\n"; + std::to_string(block_size.w * 4) + ";\n"; } else { c += " int x_c = kernel_index * args.src_tensor.Slices();\n"; } @@ -494,7 +484,7 @@ std::string ConvolutionTransposed::GenerateConvolutionTransposedCode( } if (weights_are_buffer) { if (gpu_info.SupportsPointersInKernels()) { - c += " __global FLT16* weights_cache = " + c += " __global FLT4* weights_cache = " "args.weights.GetPtr(f_offset);\n"; } } else { @@ -510,12 +500,18 @@ std::string ConvolutionTransposed::GenerateConvolutionTransposedCode( c += " x_c++;\n"; } if (weights_are_buffer && !gpu_info.SupportsPointersInKernels()) { - c += " FLT16 flt16val;\n"; + c += " FLT4 f0, f1, f2, f3;\n"; } for (int s = 0; s < block_size.w; ++s) { if (weights_are_buffer && !gpu_info.SupportsPointersInKernels()) { - c += " flt16val = args.weights.Read(f_offset + " + - std::to_string(s) + ");\n"; + c += " f0 = args.weights.Read(f_offset + " + + std::to_string(s * 4 + 0) + ");\n"; + c += " f1 = args.weights.Read(f_offset + " + + std::to_string(s * 4 + 1) + ");\n"; + c += " f2 = args.weights.Read(f_offset + " + + std::to_string(s * 4 + 2) + ");\n"; + c += " f3 = args.weights.Read(f_offset + " + + std::to_string(s * 4 + 3) + ");\n"; } const std::string sind = std::to_string(s); for (int z = 0; z < block_size.z; ++z) { @@ -532,7 +528,7 @@ std::string ConvolutionTransposed::GenerateConvolutionTransposedCode( } } if (weights_are_buffer) { - c += " f_offset += " + std::to_string(block_size.w) + ";\n"; + c += " f_offset += " + std::to_string(block_size.w * 4) + ";\n"; } c += " }\n"; c += " }\n"; diff --git a/tensorflow/lite/delegates/gpu/common/tasks/convolution_transposed.h b/tensorflow/lite/delegates/gpu/common/tasks/convolution_transposed.h index f170afbb0cc098..5eb41cc83f637e 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/convolution_transposed.h +++ b/tensorflow/lite/delegates/gpu/common/tasks/convolution_transposed.h @@ -108,7 +108,7 @@ void ConvolutionTransposed::UploadWeights( if (weights_are_buffer) { BufferDescriptor desc; desc.element_type = weights_desc.type; - desc.element_size = 16; + desc.element_size = 4; desc.size = weights_data.size(); desc.data = std::move(weights_data); args_.AddObject("weights", @@ -139,7 +139,7 @@ void ConvolutionTransposed::UploadWeights( if (weights_are_buffer) { BufferDescriptor desc; desc.element_type = weights_desc.type; - desc.element_size = 16; + desc.element_size = 4; desc.size = weights_data.size(); desc.data = std::move(weights_data); args_.AddObject("weights", diff --git a/tensorflow/lite/delegates/gpu/common/tasks/depthwise_conv.cc b/tensorflow/lite/delegates/gpu/common/tasks/depthwise_conv.cc index 31eb5942fd1306..09fff84d5cdbf7 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/depthwise_conv.cc +++ b/tensorflow/lite/delegates/gpu/common/tasks/depthwise_conv.cc @@ -101,11 +101,9 @@ std::string GetSrcXYCheck(const GpuInfo& gpu_info, } bool UseBuffersForWeights(const GpuInfo& gpu_info) { - if (gpu_info.IsApple()) { - if (gpu_info.apple_info.IsA7GenerationGpu() || - gpu_info.apple_info.IsA8GenerationGpu()) { - return false; - } + if (gpu_info.IsApple() && + gpu_info.apple_info.IsFamilyOrLower(AppleInfo::Family::kApple2)) { + return false; } return !gpu_info.SupportsImages() || gpu_info.IsMali() || gpu_info.IsApple() || gpu_info.IsAMD(); diff --git a/tensorflow/lite/delegates/gpu/common/tasks/fully_connected.cc b/tensorflow/lite/delegates/gpu/common/tasks/fully_connected.cc index 3194924ad1e16c..96e53b2f5a7534 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/fully_connected.cc +++ b/tensorflow/lite/delegates/gpu/common/tasks/fully_connected.cc @@ -113,15 +113,6 @@ std::string FullyConnected::GetFullyConnectedKernelCode( AddDstTensor("dst_tensor", op_def.dst_tensors[0]); std::string c; - switch (op_def.precision) { - case CalculationsPrecision::F32: - c += "#define FLT16 float16\n"; - break; - case CalculationsPrecision::F32_F16: - case CalculationsPrecision::F16: - c += "#define FLT16 half16\n"; - break; - } c += "#define WG_X " + std::to_string(work_group_size_.x) + "\n"; c += "#define WG_Y " + std::to_string(work_group_size_.y) + "\n"; @@ -135,11 +126,11 @@ std::string FullyConnected::GetFullyConnectedKernelCode( FLT4 v = args.src_tensor.Read(0, 0, c); )"; if (weights_are_buffer) { - c += R"(FLT16 w = args.weights.Read(c * args.dst_tensor.Slices() + gid); - FLT4 partial = v.x * FLT16_0123(w); - partial += v.y * FLT16_4567(w); - partial += v.z * FLT16_89ab(w); - partial += v.w * FLT16_cdef(w); + c += R"(int weights_index = (c * args.dst_tensor.Slices() + gid) * 4; + FLT4 partial = v.x * args.weights.Read(weights_index + 0); + partial += v.y * args.weights.Read(weights_index + 1); + partial += v.z * args.weights.Read(weights_index + 2); + partial += v.w * args.weights.Read(weights_index + 3); s += TO_ACCUM_TYPE(partial); )"; } else { diff --git a/tensorflow/lite/delegates/gpu/common/tasks/fully_connected.h b/tensorflow/lite/delegates/gpu/common/tasks/fully_connected.h index 19ca95644354bb..f44383028840bd 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/fully_connected.h +++ b/tensorflow/lite/delegates/gpu/common/tasks/fully_connected.h @@ -166,7 +166,7 @@ void FullyConnected::UploadWeights(const tflite::gpu::Tensor& weights, if (weights_are_buffer) { BufferDescriptor desc; desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16; - desc.element_size = 16; + desc.element_size = 4; desc.size = float4_size * elements_count; desc.data.resize(desc.size); diff --git a/tensorflow/lite/delegates/gpu/common/tasks/reduce.cc b/tensorflow/lite/delegates/gpu/common/tasks/reduce.cc index 6fd1d262c5e145..b16f7d9265b1a3 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/reduce.cc +++ b/tensorflow/lite/delegates/gpu/common/tasks/reduce.cc @@ -44,6 +44,13 @@ int GetMaximumWGTotalSize(const GpuInfo& gpu_info) { total_wg_size = 64; } } + if (gpu_info.IsPowerVR()) { + if (gpu_info.IsCL30OrHigher()) { + total_wg_size = gpu_info.opencl_info.preferred_work_group_size_multiple; + } else { + total_wg_size = 32; + } + } return total_wg_size; } diff --git a/tensorflow/lite/delegates/gpu/common/tasks/special/fc_fc_add.cc b/tensorflow/lite/delegates/gpu/common/tasks/special/fc_fc_add.cc index 28f1c3f5f48217..c985b550a87f2f 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/special/fc_fc_add.cc +++ b/tensorflow/lite/delegates/gpu/common/tasks/special/fc_fc_add.cc @@ -108,15 +108,6 @@ std::string FCFCAdd::GetFCFCAddKernelCode(const OperationDef& op_def, AddDstTensor("dst_tensor", op_def.dst_tensors[0]); std::string c; - switch (op_def.precision) { - case CalculationsPrecision::F32: - c += "#define FLT16 float16\n"; - break; - case CalculationsPrecision::F32_F16: - case CalculationsPrecision::F16: - c += "#define FLT16 half16\n"; - break; - } c += "#define WG_X " + std::to_string(work_group_size_.x) + "\n"; c += "#define WG_Y " + std::to_string(work_group_size_.y) + "\n"; @@ -132,11 +123,11 @@ std::string FCFCAdd::GetFCFCAddKernelCode(const OperationDef& op_def, FLT4 v = args.src_tensor_0.Read(0, 0, c); )"; if (weights_are_buffer) { - c += R"(FLT16 w = args.weights0.Read(c * args.dst_tensor.Slices() + gid); - FLT4 partial = v.x * FLT16_0123(w); - partial += v.y * FLT16_4567(w); - partial += v.z * FLT16_89ab(w); - partial += v.w * FLT16_cdef(w); + c += R"(int weights_index = (c * args.dst_tensor.Slices() + gid) * 4; + FLT4 partial = v.x * args.weights0.Read(weights_index + 0); + partial += v.y * args.weights0.Read(weights_index + 1); + partial += v.z * args.weights0.Read(weights_index + 2); + partial += v.w * args.weights0.Read(weights_index + 3); s += TO_ACCUM_TYPE(partial); )"; } else { @@ -169,11 +160,11 @@ std::string FCFCAdd::GetFCFCAddKernelCode(const OperationDef& op_def, FLT4 v = args.src_tensor_1.Read(0, 0, c); )"; if (weights_are_buffer) { - c += R"(FLT16 w = args.weights1.Read(c * args.dst_tensor.Slices() + gid); - FLT4 partial = v.x * FLT16_0123(w); - partial += v.y * FLT16_4567(w); - partial += v.z * FLT16_89ab(w); - partial += v.w * FLT16_cdef(w); + c += R"(int weights_index = (c * args.dst_tensor.Slices() + gid) * 4; + FLT4 partial = v.x * args.weights1.Read(weights_index + 0); + partial += v.y * args.weights1.Read(weights_index + 1); + partial += v.z * args.weights1.Read(weights_index + 2); + partial += v.w * args.weights1.Read(weights_index + 3); s += TO_ACCUM_TYPE(partial); )"; } else { diff --git a/tensorflow/lite/delegates/gpu/common/tasks/special/fc_fc_add.h b/tensorflow/lite/delegates/gpu/common/tasks/special/fc_fc_add.h index d632d372e931a3..7d4f2a99ce3898 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/special/fc_fc_add.h +++ b/tensorflow/lite/delegates/gpu/common/tasks/special/fc_fc_add.h @@ -148,7 +148,7 @@ void FCFCAdd::UploadWeights(const tflite::gpu::Tensor& weights, if (weights_are_buffer) { BufferDescriptor desc; desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16; - desc.element_size = 16; + desc.element_size = 4; desc.size = float4_size * elements_count; desc.data.resize(desc.size); diff --git a/tensorflow/lite/delegates/gpu/common/tasks/special/thin_pointwise_fuser.cc b/tensorflow/lite/delegates/gpu/common/tasks/special/thin_pointwise_fuser.cc index 923aaf53efae77..902ef54a4b3372 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/special/thin_pointwise_fuser.cc +++ b/tensorflow/lite/delegates/gpu/common/tasks/special/thin_pointwise_fuser.cc @@ -930,7 +930,8 @@ absl::Status TryThinPointwiseFuser( gpu_info.IsApple() || gpu_info.IsAMD())) { return absl::NotFoundError("ThinPointwiseFuser not suitable."); } - if (gpu_info.IsMali() && gpu_info.mali_info.IsMidgard()) { + // TODO(b/322801363): Add more precise checks for Mali + if (gpu_info.IsMali()) { return absl::NotFoundError("ThinPointwiseFuser not suitable."); } auto* node = graph.GetNode(first_node_id); diff --git a/tensorflow/lite/delegates/gpu/delegate.cc b/tensorflow/lite/delegates/gpu/delegate.cc index 00c58d37e6f37b..aeccd7ad574920 100644 --- a/tensorflow/lite/delegates/gpu/delegate.cc +++ b/tensorflow/lite/delegates/gpu/delegate.cc @@ -20,12 +20,15 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/delegate.h" +#include "tensorflow/lite/logger.h" + #if defined(__ANDROID__) #include #endif #include #include +#include #include #include #include // NOLINT(build/c++11) @@ -34,6 +37,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/strings/numbers.h" #include "absl/types/span.h" #include "tensorflow/lite/builtin_ops.h" @@ -116,6 +120,7 @@ namespace { using delegates::Serialization; using delegates::SerializationParams; +using tflite::TFLITE_LOG_WARNING; constexpr char kSerializedDataPrefix[] = "gpuv2_data_"; @@ -151,6 +156,69 @@ InferenceUsage ToUsage(int32_t usage) { return InferenceUsage::UNKNOWN; } +bool ParseOptions(const char* const* options_keys, + const char* const* options_values, size_t num_options, + TfLiteGpuDelegateOptionsV2* options) { + for (size_t i = 0; i < num_options; ++i) { + if (strcmp(options_keys[i], "is_precision_loss_allowed")) { + if (!absl::SimpleAtoi(options_values[i], + &options->is_precision_loss_allowed)) { + TFLITE_LOG(TFLITE_LOG_WARNING, "ParseOptions: malformed option %s.", + options_keys[i]); + return false; + } + } else if (strcmp(options_keys[i], "inference_preference")) { + if (!absl::SimpleAtoi(options_values[i], + &options->inference_preference)) { + TFLITE_LOG(TFLITE_LOG_WARNING, "ParseOptions: malformed option %s.", + options_keys[i]); + return false; + } + } else if (strcmp(options_keys[i], "inference_priority1")) { + if (!absl::SimpleAtoi(options_values[i], &options->inference_priority1)) { + TFLITE_LOG(TFLITE_LOG_WARNING, "ParseOptions: malformed option %s.", + options_keys[i]); + return false; + } + } else if (strcmp(options_keys[i], "inference_priority2")) { + if (!absl::SimpleAtoi(options_values[i], &options->inference_priority2)) { + TFLITE_LOG(TFLITE_LOG_WARNING, "ParseOptions: malformed option %s.", + options_keys[i]); + return false; + } + } else if (strcmp(options_keys[i], "inference_priority3")) { + if (!absl::SimpleAtoi(options_values[i], &options->inference_priority3)) { + TFLITE_LOG(TFLITE_LOG_WARNING, "ParseOptions: malformed option %s.", + options_keys[i]); + return false; + } + } else if (strcmp(options_keys[i], "experimental_flags")) { + if (!absl::SimpleAtoi(options_values[i], &options->experimental_flags)) { + TFLITE_LOG(TFLITE_LOG_WARNING, "ParseOptions: malformed option %s.", + options_keys[i]); + return false; + } + } else if (strcmp(options_keys[i], "max_delegated_partitions")) { + if (!absl::SimpleAtoi(options_values[i], + &options->max_delegated_partitions)) { + TFLITE_LOG(TFLITE_LOG_WARNING, "ParseOptions: malformed option %s.", + options_keys[i]); + return false; + } + } else if (strcmp(options_keys[i], "serialization_dir")) { + options->serialization_dir = options_values[i]; + } else if (strcmp(options_keys[i], "model_token")) { + options->model_token = options_values[i]; + } else { + TFLITE_LOG(TFLITE_LOG_WARNING, "ParseOptions: unknown option %s.", + options_keys[i]); + return false; + } + } + + return true; +} + // Forward declarations. TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate); @@ -521,9 +589,19 @@ absl::Status DelegateKernelCore::InitializeOpenGlApi( auto delegate_options = delegate_->options(); gl::InferenceOptions options; options.usage = ToUsage(delegate_options.inference_preference); - options.priority1 = ToPriority(delegate_options.inference_priority1); - options.priority2 = ToPriority(delegate_options.inference_priority2); - options.priority3 = ToPriority(delegate_options.inference_priority3); + // If is_precision_loss_allowed == -1, then just use priorities instead + // of paying attention to is_precision_loss_allowed value. + if (delegate_options.is_precision_loss_allowed == -1) { + options.priority1 = ToPriority(delegate_options.inference_priority1); + options.priority2 = ToPriority(delegate_options.inference_priority2); + options.priority3 = ToPriority(delegate_options.inference_priority3); + } else { + if (delegate_options.is_precision_loss_allowed == 0) { + options.priority1 = InferencePriority::MAX_PRECISION; + } else { + options.priority1 = InferencePriority::MIN_LATENCY; + } + } RETURN_IF_ERROR(gl_environment_->NewInferenceBuilder(std::move(*graph), options, builder)); enforce_same_thread_ = true; @@ -1419,3 +1497,19 @@ TfLiteDelegate* TfLiteGpuDelegateV2CreateAsync( void TfLiteGpuDelegateV2Delete(TfLiteDelegate* delegate) { delete tflite::gpu::GetDelegate(delegate); } + +TfLiteDelegate* tflite_plugin_create_delegate( + const char* const* options_keys, const char* const* options_values, + size_t num_options, void (*report_error)(const char*)) { + TfLiteGpuDelegateOptionsV2 options = TfLiteGpuDelegateOptionsV2Default(); + if (!tflite::gpu::ParseOptions(options_keys, options_values, num_options, + &options)) { + return nullptr; + } + + return TfLiteGpuDelegateV2Create(&options); +} + +void tflite_plugin_destroy_delegate(TfLiteDelegate* delegate) { + TfLiteGpuDelegateV2Delete(delegate); +} diff --git a/tensorflow/lite/delegates/gpu/delegate.h b/tensorflow/lite/delegates/gpu/delegate.h index 8c1ada4d5c5d56..0b11f41bc687e5 100644 --- a/tensorflow/lite/delegates/gpu/delegate.h +++ b/tensorflow/lite/delegates/gpu/delegate.h @@ -18,6 +18,8 @@ limitations under the License. #include +#include + #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/delegates/gpu/delegate_options.h" @@ -46,6 +48,12 @@ TFL_CAPI_EXPORT TfLiteDelegate* TfLiteGpuDelegateV2CreateAsync( // Destroys a delegate created with `TfLiteGpuDelegateV2Create` call. TFL_CAPI_EXPORT void TfLiteGpuDelegateV2Delete(TfLiteDelegate* delegate); +TFL_CAPI_EXPORT TfLiteDelegate* tflite_plugin_create_delegate( + const char* const* options_keys, const char* const* options_values, + size_t num_options, void (*report_error)(const char*)); + +TFL_CAPI_EXPORT void tflite_plugin_destroy_delegate(TfLiteDelegate* delegate); + #ifdef __cplusplus } #endif // __cplusplus diff --git a/tensorflow/lite/delegates/gpu/metal/compute_task.cc b/tensorflow/lite/delegates/gpu/metal/compute_task.cc index ab8aec211b7aef..68e3eeffcb7fd5 100644 --- a/tensorflow/lite/delegates/gpu/metal/compute_task.cc +++ b/tensorflow/lite/delegates/gpu/metal/compute_task.cc @@ -35,27 +35,6 @@ namespace tflite { namespace gpu { namespace metal { namespace { -bool IsWordSymbol(char symbol) { - return absl::ascii_isalnum(symbol) || symbol == '_'; -} - -void ReplaceAllWords(const std::string& old_word, const std::string& new_word, - std::string* str) { - size_t position = str->find(old_word); - while (position != std::string::npos) { - const char prev = position == 0 ? ' ' : (*str)[position - 1]; - const char next = position + old_word.size() < str->size() - ? (*str)[position + old_word.size()] - : ' '; - if (IsWordSymbol(prev) || IsWordSymbol(next)) { - position = str->find(old_word, position + 1); - continue; - } - str->replace(position, old_word.size(), new_word); - position = str->find(old_word, position + new_word.size()); - } -} - std::map GetMetalDefines( MetalDevice* device, CalculationsPrecision precision) { std::string simdgroup_barrier; @@ -82,10 +61,6 @@ std::map GetMetalDefines( } } return { - {"FLT16_0123(V)", "V[0]"}, - {"FLT16_4567(V)", "V[1]"}, - {"FLT16_89ab(V)", "V[2]"}, - {"FLT16_cdef(V)", "V[3]"}, {"FLT", storage_type}, {"FLT2", storage_type + "2"}, {"FLT3", storage_type + "3"}, @@ -190,12 +165,6 @@ absl::Status ComputeTask::Compile(MetalDevice* device) { &operation_->args_, &operation_->code_)); operation_->args_.ReleaseCPURepresentation(); - - // manually resolving this defines, so as Metal has reserved words for them - ReplaceAllWords("float16", "float4x4", &operation_->code_); - ReplaceAllWords("half16", "half4x4", &operation_->code_); - ReplaceAllWords("float8", "float2x4", &operation_->code_); - ReplaceAllWords("half8", "half2x4", &operation_->code_); defines_ = GetMetalDefines(device, operation_->GetPrecision()); return CompileProgram(device, operation_->code_, defines_); } diff --git a/tensorflow/lite/delegates/gpu/metal/inference_context.cc b/tensorflow/lite/delegates/gpu/metal/inference_context.cc index 84542dde27faef..0dcfb1440c1e23 100644 --- a/tensorflow/lite/delegates/gpu/metal/inference_context.cc +++ b/tensorflow/lite/delegates/gpu/metal/inference_context.cc @@ -48,10 +48,10 @@ namespace { // returns true if actual memory for this storage type is buffer bool IsBufferBased(const GpuInfo& gpu_info, const TensorStorageType& type) { - const bool a7_gen_gpu = - gpu_info.IsApple() && gpu_info.apple_info.IsA7GenerationGpu(); - if (!a7_gen_gpu && (type == TensorStorageType::TEXTURE_2D || - type == TensorStorageType::SINGLE_TEXTURE_2D)) { + const bool family_apple1 = + gpu_info.IsApple() && gpu_info.apple_info.IsFamilyApple1(); + if (!family_apple1 && (type == TensorStorageType::TEXTURE_2D || + type == TensorStorageType::SINGLE_TEXTURE_2D)) { return true; } return type == TensorStorageType::BUFFER || diff --git a/tensorflow/lite/delegates/gpu/metal/metal_device.cc b/tensorflow/lite/delegates/gpu/metal/metal_device.cc index 1e387723ff66f6..9f3893af6306a4 100644 --- a/tensorflow/lite/delegates/gpu/metal/metal_device.cc +++ b/tensorflow/lite/delegates/gpu/metal/metal_device.cc @@ -40,11 +40,13 @@ GpuInfo CreateGpuInfoFromMetalDevice(id device) { } } - const bool a7_or_a8 = - gpu_info.IsApple() && (gpu_info.apple_info.IsA7GenerationGpu() || - gpu_info.apple_info.IsA8GenerationGpu()); - gpu_info.metal_info.image2d_max_width = a7_or_a8 ? 1024 * 8 : 1024 * 16; - gpu_info.metal_info.image2d_max_height = a7_or_a8 ? 1024 * 8 : 1024 * 16; + const bool family_apple1_or_2 = + gpu_info.IsApple() && + gpu_info.apple_info.IsFamilyOrLower(AppleInfo::Family::kApple2); + gpu_info.metal_info.image2d_max_width = + family_apple1_or_2 ? 1024 * 8 : 1024 * 16; + gpu_info.metal_info.image2d_max_height = + family_apple1_or_2 ? 1024 * 8 : 1024 * 16; gpu_info.metal_info.image_array_max_layers = 2048; gpu_info.metal_info.image3d_max_width = 2048; gpu_info.metal_info.image3d_max_height = 2048; diff --git a/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.cc b/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.cc index 87256bd682cb3e..996c31c4b2923b 100644 --- a/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.cc +++ b/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.cc @@ -409,10 +409,8 @@ absl::Status CreateTensorSharedImage2DBuffer(id buffer, } TensorStorageType GetFastestStorageType(const GpuInfo& gpu_info) { - const bool a7_or_a8 = - gpu_info.IsApple() && (gpu_info.apple_info.IsA7GenerationGpu() || - gpu_info.apple_info.IsA8GenerationGpu()); - if (a7_or_a8) { + if (gpu_info.IsApple() && + gpu_info.apple_info.IsFamilyOrLower(AppleInfo::Family::kApple2)) { return TensorStorageType::TEXTURE_2D; } else { return TensorStorageType::BUFFER; diff --git a/tensorflow/lite/delegates/hexagon/java/src/main/java/org/tensorflow/lite/BUILD b/tensorflow/lite/delegates/hexagon/java/src/main/java/org/tensorflow/lite/BUILD index 49c5a70aff578d..dfe05cf2abae27 100644 --- a/tensorflow/lite/delegates/hexagon/java/src/main/java/org/tensorflow/lite/BUILD +++ b/tensorflow/lite/delegates/hexagon/java/src/main/java/org/tensorflow/lite/BUILD @@ -1,4 +1,4 @@ -# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:LICENSE"]) licenses(["notice"]) diff --git a/tensorflow/lite/delegates/xnnpack/BUILD b/tensorflow/lite/delegates/xnnpack/BUILD index 34e879ae9375e3..6b878eb9ee6827 100644 --- a/tensorflow/lite/delegates/xnnpack/BUILD +++ b/tensorflow/lite/delegates/xnnpack/BUILD @@ -33,6 +33,12 @@ config_setting( define_values = {"xnnpack_use_latest_ops": "true"}, ) +# Use transient indirection buffers. +config_setting( + name = "xnnpack_use_transient_indirection_buffers_explicit", + define_values = {"xnnpack_use_transient_indirection_buffers": "true"}, +) + # Enable offloading of quantized 8-bit signed operators to XNNPACK delegate config_setting( name = "tflite_with_xnnpack_qs8_explicit_true", @@ -223,6 +229,9 @@ cc_library( }) + select({ ":xnnpack_use_latest_ops_explicit": ["-DXNNPACK_DELEGATE_USE_LATEST_OPS=1"], "//conditions:default": [], + }) + select({ + ":xnnpack_use_transient_indirection_buffers_explicit": ["-DXNNPACK_DELEGATE_USE_TRANSIENT_INDIRECTION_BUFFERS=1"], + "//conditions:default": [], }), linkstatic = True, deps = [ diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc index 8c036c0dd3f95a..391629c49ba4fb 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc @@ -213,7 +213,7 @@ xnn_datatype GetXNNPackDatatype(TfLiteContext* context, TF_LITE_KERNEL_LOG(context, "unsupported zero-point value %d in channel " "%d of INT8 tensor %d in XNNPACK delegate", - quantization_params->zero_point[c], c, t); + quantization_params->zero_point->data[c], c, t); return xnn_datatype_invalid; } } @@ -380,7 +380,7 @@ class VariableHolder { TF_LITE_MAYBE_KERNEL_LOG( logging_context, "global id mismatch for tensor " - "%d, expected %zu, found %zu at VAR_HANDLE node %d", + "%d, expected %u, found %u at VAR_HANDLE node %d", tensor_id, global_id, it.first->second, node_index); return kTfLiteError; } @@ -434,9 +434,10 @@ class VariableHolder { for (size_t i = 0; i < NumDimensions(found_tensor); i++) { if (found_tensor->dims->data[i] != tensor->dims->data[i]) { TF_LITE_MAYBE_KERNEL_LOG(logging_context, - "mismatch between dimension %d of " + "mismatch between dimension %zu of " "variable tensor id %d: expected %d, got %d", - i, local_id, dims[i], tensor->dims->data[i]); + i, local_id, dims->data[i], + tensor->dims->data[i]); return kTfLiteError; } } @@ -556,8 +557,12 @@ class Delegate { } bool transient_indirection_buffer() const { +#ifdef XNNPACK_DELEGATE_USE_TRANSIENT_INDIRECTION_BUFFERS + return true; +#else return (options_.flags & TFLITE_XNNPACK_DELEGATE_FLAG_TRANSIENT_INDIRECTION_BUFFER) != 0; +#endif } bool experimental_adaptive_avx_optimization() const { @@ -900,7 +905,7 @@ class Subgraph { const auto it = global_id_to_xnnpack_id.find(global_id); if (it == global_id_to_xnnpack_id.end()) { TF_LITE_KERNEL_LOG(context, - "could not find variable with global id %zu in " + "could not find variable with global id %u in " "context %p for local tensor %d", global_id, context, t); return nullptr; @@ -1574,7 +1579,7 @@ class Subgraph { BuiltinOperator op_type, int node_index) { if (node->inputs->size != expected_num_inputs) { TF_LITE_MAYBE_KERNEL_LOG( - context, "unexpected number of inputs (%d != %d) in node #%d", + context, "unexpected number of inputs (%d != %d) in node %s #%d", node->inputs->size, expected_num_inputs, EnumNameBuiltinOperator(op_type), node_index); return kTfLiteError; @@ -2060,8 +2065,8 @@ class Subgraph { "unexpected value %d of shape dimension #%d in " "tensor #%d in %s node #%d: " "expected 1 for non-channel dimensions", - tensor.dims[i], i, tensor_index, EnumNameBuiltinOperator(op_type), - node_index); + tensor.dims->data[i], i, tensor_index, + EnumNameBuiltinOperator(op_type), node_index); return kTfLiteError; } } @@ -2920,7 +2925,7 @@ class Subgraph { "failed to delegate %s node #%d input tensor #%d and input tensor " "#%d. Mismatch at dimensions %zu (%d != %d)", EnumNameBuiltinOperator(BuiltinOperator_BATCH_MATMUL), node_index, - node->inputs->data[0], node->inputs->data[1], + node->inputs->data[0], node->inputs->data[1], i, input1_tensor.dims->data[i], input2_tensor.dims->data[i]); return kTfLiteError; } @@ -2930,7 +2935,7 @@ class Subgraph { "failed to delegate %s node #%d input tensor #%d and output tensor " "#%d. Mismatch at dimensions %zu (%d != %d)", EnumNameBuiltinOperator(BuiltinOperator_BATCH_MATMUL), node_index, - node->inputs->data[0], node->outputs->data[0], + node->inputs->data[0], node->outputs->data[0], i, input1_tensor.dims->data[i], output_tensor.dims->data[i]); return kTfLiteError; } @@ -4270,8 +4275,11 @@ class Subgraph { CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, node->outputs->data[0], node_index)); int expected_output_dims = 4; + uint32_t flags = 0; if (!reducer_params->keep_dims) { expected_output_dims -= num_reduction_axes; + } else { + flags = XNN_FLAG_KEEP_DIMS; } TF_LITE_ENSURE_STATUS(CheckTensorShape( logging_context, output_tensor, expected_output_dims, @@ -4279,7 +4287,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( logging_context, output_tensor, node->outputs->data[0], node_index)); - uint32_t flags = 0; const float output_min = -std::numeric_limits::infinity(); const float output_max = +std::numeric_limits::infinity(); @@ -4436,6 +4443,7 @@ class Subgraph { logging_context, output_tensor, node->outputs->data[0], node_index)); if (subgraph != nullptr) { + uint32_t flags = reducer_params->keep_dims ? XNN_FLAG_KEEP_DIMS : 0; xnn_status status = xnn_status_success; switch (num_reduction_axes) { case 1: @@ -4445,7 +4453,7 @@ class Subgraph { /*output_max=*/+std::numeric_limits::infinity(), /*input_id=*/input_output_tensors.at(node->inputs->data[0]), /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); + flags); break; case 2: status = xnn_define_global_average_pooling_2d( @@ -4454,7 +4462,7 @@ class Subgraph { /*output_max=*/+std::numeric_limits::infinity(), /*input_id=*/input_output_tensors.at(node->inputs->data[0]), /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); + flags); break; default: break; @@ -5464,7 +5472,7 @@ class Subgraph { TF_LITE_MAYBE_KERNEL_LOG(logging_context, "size %" PRId64 " does not match output shape %d at " - "dimension %d in SLICE node #%d", + "dimension %zu in SLICE node #%d", size[i], output_shape->data[i], i, node_index); return kTfLiteError; } @@ -5965,7 +5973,7 @@ class Subgraph { for (size_t i = 0; i < num_dims; i++) { if (stride_data[i] != 1) { TF_LITE_MAYBE_KERNEL_LOG(logging_context, - "stride at dimension %d, %d, must be 1" + "stride at dimension %zu, %d, must be 1" "in STRIDED_SLICE node #%d", i, stride_data[i], node_index); return kTfLiteError; @@ -6278,7 +6286,7 @@ class Subgraph { logging_context, "transpose convolution kernel input channel dimension (%d) " "doesn't match filter input channel (%d) in node #%d", - input_channels, input_tensor_dims[3]); + input_channels, input_tensor_dims[3], node_index); return kTfLiteError; } @@ -6586,7 +6594,7 @@ TfLiteIntArray* Delegate::PrepareOpsToDelegate(TfLiteContext* context) { if (node->inputs->size != 1) { TF_LITE_KERNEL_LOG( - context, "unexpected number of inputs (%d) in %s node %d", + context, "unexpected number of inputs (%d) in %d node %d", node->inputs->size, static_cast(registration->builtin_code), producer_index); @@ -6596,7 +6604,7 @@ TfLiteIntArray* Delegate::PrepareOpsToDelegate(TfLiteContext* context) { if (node->outputs->size != 1) { TF_LITE_KERNEL_LOG( - context, "unexpected number of outputs (%d) in %s node %d", + context, "unexpected number of outputs (%d) in %d node %d", node->outputs->size, static_cast(registration->builtin_code), producer_index); diff --git a/tensorflow/lite/examples/label_image/BUILD b/tensorflow/lite/examples/label_image/BUILD index f9d368dfbceb99..b5942f010521ea 100644 --- a/tensorflow/lite/examples/label_image/BUILD +++ b/tensorflow/lite/examples/label_image/BUILD @@ -31,16 +31,18 @@ cc_binary( deps = [ ":bitmap_helpers", "//tensorflow/lite:framework", - "//tensorflow/lite:string_util", + "//tensorflow/lite:string", + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/c:common", + "//tensorflow/lite/core:cc_api_stable", "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/profiling:profile_buffer", "//tensorflow/lite/profiling:profiler", + "//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/tools:command_line_flags", "//tensorflow/lite/tools:tool_params", "//tensorflow/lite/tools/delegates:delegate_provider_hdr", "//tensorflow/lite/tools/delegates:tflite_execution_providers", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", ], ) @@ -59,8 +61,10 @@ cc_library( "//tensorflow/lite:framework", "//tensorflow/lite:string", "//tensorflow/lite:string_util", + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/schema:schema_fbs", + "@local_tsl//tsl/platform:tstring", ] + select({ "//tensorflow:android": [ "//tensorflow/lite/delegates/gpu:delegate", @@ -86,6 +90,7 @@ cc_test( ], deps = [ ":bitmap_helpers", + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/c:common", "@com_google_googletest//:gtest_main", ], diff --git a/tensorflow/lite/examples/label_image/CMakeLists.txt b/tensorflow/lite/examples/label_image/CMakeLists.txt index ae4ab447064713..08044b1675beb3 100644 --- a/tensorflow/lite/examples/label_image/CMakeLists.txt +++ b/tensorflow/lite/examples/label_image/CMakeLists.txt @@ -56,6 +56,11 @@ if(TFLITE_ENABLE_GPU) ) endif() # TFLITE_ENABLE_GPU +if(TFLITE_ENABLE_EXTERNAL_DELEGATE) + list(APPEND TFLITE_LABEL_IMAGE_SRCS + ${TFLITE_SOURCE_DIR}/tools/delegates/external_delegate_provider.cc) +endif() + add_executable(label_image ${TFLITE_LABEL_IMAGE_SRCS} ) diff --git a/tensorflow/lite/examples/label_image/bitmap_helpers.cc b/tensorflow/lite/examples/label_image/bitmap_helpers.cc index f060287154022d..d3698f3b22218b 100644 --- a/tensorflow/lite/examples/label_image/bitmap_helpers.cc +++ b/tensorflow/lite/examples/label_image/bitmap_helpers.cc @@ -24,8 +24,9 @@ limitations under the License. #include #include -#include "tensorflow/core/platform/ctstring_internal.h" +#include "tensorflow/lite/examples/label_image/label_image.h" #include "tensorflow/lite/examples/label_image/log.h" +#include "tsl/platform/ctstring_internal.h" namespace tflite { namespace label_image { diff --git a/tensorflow/lite/examples/label_image/label_image.cc b/tensorflow/lite/examples/label_image/label_image.cc index 803a2a89c931ef..5811fc3006553d 100644 --- a/tensorflow/lite/examples/label_image/label_image.cc +++ b/tensorflow/lite/examples/label_image/label_image.cc @@ -36,16 +36,23 @@ limitations under the License. #include #include -#include "absl/memory/memory.h" +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/examples/label_image/bitmap_helpers.h" #include "tensorflow/lite/examples/label_image/get_top_n.h" #include "tensorflow/lite/examples/label_image/log.h" +#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model_builder.h" #include "tensorflow/lite/optional_debug_tools.h" +#include "tensorflow/lite/profiling/profile_buffer.h" #include "tensorflow/lite/profiling/profiler.h" -#include "tensorflow/lite/string_util.h" +#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/string_type.h" #include "tensorflow/lite/tools/command_line_flags.h" #include "tensorflow/lite/tools/delegates/delegate_provider.h" +#include "tensorflow/lite/tools/tool_params.h" namespace tflite { namespace label_image { diff --git a/tensorflow/lite/examples/label_image/label_image.h b/tensorflow/lite/examples/label_image/label_image.h index 1c00edb6558f7a..db55265b7e08db 100644 --- a/tensorflow/lite/examples/label_image/label_image.h +++ b/tensorflow/lite/examples/label_image/label_image.h @@ -16,7 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H_ #define TENSORFLOW_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H_ +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/model.h" +#include "tensorflow/lite/model_builder.h" #include "tensorflow/lite/string_type.h" namespace tflite { diff --git a/tensorflow/lite/examples/label_image/label_image_test.cc b/tensorflow/lite/examples/label_image/label_image_test.cc index 3e51ca5ea32dbf..d4e2e87270484b 100644 --- a/tensorflow/lite/examples/label_image/label_image_test.cc +++ b/tensorflow/lite/examples/label_image/label_image_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include -#include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/examples/label_image/bitmap_helpers.h" #include "tensorflow/lite/examples/label_image/get_top_n.h" diff --git a/tensorflow/lite/examples/python/BUILD b/tensorflow/lite/examples/python/BUILD index 5de058126f8054..1b853c6546c0a6 100644 --- a/tensorflow/lite/examples/python/BUILD +++ b/tensorflow/lite/examples/python/BUILD @@ -5,15 +5,19 @@ package( licenses = ["notice"], ) -py_strict_binary( - name = "label_image", - srcs = ["label_image.py"], - main = "label_image.py", - python_version = "PY3", - srcs_version = "PY3", - deps = [ - "//tensorflow:tensorflow_py", - "//third_party/py/PIL:pil", - "//third_party/py/numpy", - ], -) +# Commented out under the (b/279852433) because caused an error in the OSS +# TODO(zhurakovskyi): Uncomment when fixed. +# copybara:uncomment_begin +# py_strict_binary( +# name = "label_image", +# srcs = ["label_image.py"], +# main = "label_image.py", +# python_version = "PY3", +# srcs_version = "PY3", +# deps = [ +# "//third_party/py/PIL:pil", +# "//third_party/py/numpy", +# "//tensorflow:tensorflow_py", +# ], +# ) +# copybara:uncomment_end diff --git a/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.bin b/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.bin index 52ae5ffb37ce02..bfe2754f16ff3c 100644 Binary files a/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.bin and b/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.bin differ diff --git a/tensorflow/lite/experimental/acceleration/configuration/BUILD b/tensorflow/lite/experimental/acceleration/configuration/BUILD index 2e14980693eee9..a1329c13fb6bc5 100644 --- a/tensorflow/lite/experimental/acceleration/configuration/BUILD +++ b/tensorflow/lite/experimental/acceleration/configuration/BUILD @@ -155,10 +155,10 @@ cc_library( "//tensorflow/lite/core/experimental/acceleration/configuration:delegate_registry", "@com_google_absl//absl/memory", ] + select({ - "//third_party/bazel_platforms/cpu:aarch64": [ + "@platforms//cpu:aarch64": [ "//tensorflow/lite/delegates/hexagon:hexagon_delegate", ], - "//third_party/bazel_platforms/cpu:armv7": [ + "@platforms//cpu:armv7": [ "//tensorflow/lite/delegates/hexagon:hexagon_delegate", ], "//conditions:default": [], diff --git a/tensorflow/lite/experimental/microfrontend/lib/bits.h b/tensorflow/lite/experimental/microfrontend/lib/bits.h index 04b3ba6f055f95..5e79ccaf6a79ca 100644 --- a/tensorflow/lite/experimental/microfrontend/lib/bits.h +++ b/tensorflow/lite/experimental/microfrontend/lib/bits.h @@ -30,13 +30,13 @@ static inline int CountLeadingZeros32Slow(uint64_t n) { } static inline int CountLeadingZeros32(uint32_t n) { -#if defined(_MSC_VER) +#if !defined(__clang__) && defined(_MSC_VER) unsigned long result = 0; // NOLINT(runtime/int) if (_BitScanReverse(&result, n)) { return 31 - result; } return 32; -#elif defined(__GNUC__) +#elif defined(__clang__) && defined(__GNUC__) // Handle 0 as a special case because __builtin_clz(0) is undefined. if (n == 0) { @@ -62,14 +62,14 @@ static inline int CountLeadingZeros64Slow(uint64_t n) { } static inline int CountLeadingZeros64(uint64_t n) { -#if defined(_MSC_VER) && defined(_M_X64) +#if !defined(__clang__) && defined(_MSC_VER) && defined(_M_X64) // MSVC does not have __builtin_clzll. Use _BitScanReverse64. unsigned long result = 0; // NOLINT(runtime/int) if (_BitScanReverse64(&result, n)) { return 63 - result; } return 64; -#elif defined(_MSC_VER) +#elif !defined(__clang__) && defined(_MSC_VER) // MSVC does not have __builtin_clzll. Compose two calls to _BitScanReverse unsigned long result = 0; // NOLINT(runtime/int) if ((n >> 32) && _BitScanReverse(&result, n >> 32)) { @@ -79,7 +79,7 @@ static inline int CountLeadingZeros64(uint64_t n) { return 63 - result; } return 64; -#elif defined(__GNUC__) +#elif defined(__clang__) || defined(__GNUC__) // Handle 0 as a special case because __builtin_clzll(0) is undefined. if (n == 0) { diff --git a/tensorflow/lite/experimental/shlo/BUILD b/tensorflow/lite/experimental/shlo/BUILD index f45baf77128a2f..ab22092d029954 100644 --- a/tensorflow/lite/experimental/shlo/BUILD +++ b/tensorflow/lite/experimental/shlo/BUILD @@ -1,65 +1,99 @@ -# StableHLO Device Profile reference implementation +# StableHLO Reference Library package( - default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [":__subpackages__"], + default_applicable_licenses = ["//tensorflow:LICENSE"], + default_visibility = ["//visibility:public"], ) cc_library( name = "shlo", - srcs = [ - "src/broadcast_in_dim.cc", - "src/clamp.cc", - "src/compare.cc", - "src/concatenate.cc", - "src/dispatch.h", - "src/elementwise_binary.cc", - "src/elementwise_unary.cc", - "src/iota.cc", - "src/is_finite.cc", - "src/select.cc", - "src/shlo.cc", - "src/uniform_dequantize_quantize.cc", - ], - hdrs = [ - "include/shlo.h", - "src/storage.h", - "src/util.h", - ], deps = [ - ":float", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/types:span", + ":tensor", ], ) cc_library( - name = "debug", - srcs = [ - "src/debug.cc", - ], - hdrs = [ - "src/debug.h", + name = "tensor", + srcs = ["tensor.cc"], + hdrs = ["tensor.h"], + deps = [ + ":data_type", + ":quantized_tensor_element_type", + ":shape", ], +) + +cc_library( + name = "shape", + srcs = ["shape.cc"], + hdrs = ["shape.h"], deps = [ - ":float", - ":shlo", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/types:span", ], ) +cc_test( + name = "shape_test", + srcs = ["shape_test.cc"], + deps = [ + ":shape", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( - name = "float", - srcs = [ + name = "quantized_tensor_element_type", + hdrs = ["quantized_tensor_element_type.h"], + deps = [ + ":data_type", + ":shape", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/types:span", ], - hdrs = [ - "src/bf16.h", - "src/f16.h", - "src/has_keyword.h", +) + +cc_test( + name = "quantized_tensor_element_type_test", + srcs = ["quantized_tensor_element_type_test.cc"], + deps = [ + ":data_type", + ":quantized_tensor_element_type", + "@com_google_googletest//:gtest_main", ], +) + +cc_library( + name = "bf16", + hdrs = ["bf16.h"], + deps = [":has_keyword"], +) + +cc_library( + name = "f16", + hdrs = ["f16.h"], + deps = [":has_keyword"], +) + +cc_library( + name = "has_keyword", + hdrs = ["has_keyword.h"], +) + +cc_library( + name = "data_type", + hdrs = ["data_type.h"], deps = [ + ":bf16", + ":f16", ], ) + +cc_library( + name = "dispatch", + hdrs = ["dispatch.h"], + visibility = ["//visibility:private"], + deps = [":data_type"], +) diff --git a/tensorflow/lite/experimental/shlo/bf16.h b/tensorflow/lite/experimental/shlo/bf16.h new file mode 100644 index 00000000000000..b89ccda148c9f7 --- /dev/null +++ b/tensorflow/lite/experimental/shlo/bf16.h @@ -0,0 +1,129 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_BF16_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_BF16_H_ + +#include "tensorflow/lite/experimental/shlo/has_keyword.h" + +#if defined(__STDCPP_BFLOAT16_T__) +#include +namespace shlo_ref { +using BF16 = bfloat16_t; +} // namespace shlo_ref + +#elif __has_keyword(__bf16) && __x86_64__ +namespace shlo_ref { +// On x86 the compiler is able to generate code for __bf16 operations. +using BF16 = __bf16; +} // namespace shlo_ref + +#elif __has_keyword(__bf16) && __aarch64__ +#include +#include + +namespace shlo_ref { + +// On arm64 the compiler is not yet able to generate code for __bf16 +// operations. Therefore, we resort to a software-based implementation of BF16 +// based on promoting ops to float. +class BF16 { + public: + BF16(float f = 0.0f) { + if (std::isnan(f)) { + // If the value is a NaN, squash it to a NaN with the msb of the + // mantissa. This avoids that after the truncation below we don't end up + // with an inf. + value_ = std::signbit(f) ? 0xFFC0 : 0x7FC0; + } else { + // Fast rounding algorithm that rounds a half value to nearest even. This + // reduces expected error when we convert a large number of floats. + uint32_t input = *reinterpret_cast(&f); + + // Least significant bit of resulting bfloat. + uint32_t lsb = (input >> 16) & 1; + uint32_t rounding_bias = 0x7fff + lsb; + input += rounding_bias; + + value_ = static_cast(input >> 16u); + } + } + + BF16& operator=(BF16 other) { + value_ = other.value_; + return *this; + } + + bool operator==(BF16 other) const { return value_ == other.value_; } + bool operator!=(BF16 other) const { return !(*this == other); } + + operator float() const { + uint32_t tmp = value_ << 16; + return *reinterpret_cast(&tmp); + } + + BF16 operator-() const { return BF16(-static_cast(*this)); } + + BF16& operator+=(BF16 other) { + value_ = BF16(static_cast(*this) + static_cast(other)).value_; + return *this; + } + + BF16& operator-=(BF16 other) { + value_ = BF16(static_cast(*this) - static_cast(other)).value_; + return *this; + } + + BF16& operator*=(BF16 other) { + value_ = BF16(static_cast(*this) * static_cast(other)).value_; + return *this; + } + + BF16& operator/=(BF16 other) { + value_ = BF16(static_cast(*this) / static_cast(other)).value_; + return *this; + } + + private: + uint16_t value_; +}; + +inline BF16 operator+(BF16 x, BF16 y) { + x += y; + return x; +} + +inline BF16 operator-(BF16 x, BF16 y) { + x -= y; + return x; +} + +inline BF16 operator*(BF16 x, BF16 y) { + x *= y; + return x; +} + +inline BF16 operator/(BF16 x, BF16 y) { + x /= y; + return x; +} + +} // namespace shlo_ref + +#else +#error Type BF16 is not available +#endif + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_BF16_H_ diff --git a/tensorflow/lite/experimental/shlo/data_type.h b/tensorflow/lite/experimental/shlo/data_type.h new file mode 100644 index 00000000000000..dbd90ec428b067 --- /dev/null +++ b/tensorflow/lite/experimental/shlo/data_type.h @@ -0,0 +1,119 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_DATA_TYPE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_DATA_TYPE_H_ + +#include + +#include "tensorflow/lite/experimental/shlo/bf16.h" +#include "tensorflow/lite/experimental/shlo/f16.h" + +namespace shlo_ref { + +// For more information on StableHLO types, see the spec., search for "Element +// types". The SHLO Device Profile does not include unsigned or 64 bit types. +enum class DataType { + kI1, + kSI4, + kSI8, + kSI16, + kSI32, + kBF16, + kF16, + kF32, +}; + +// Storage provides the corresponding C++ type for the given DataType. +template +struct Storage {}; + +template <> +struct Storage { + using Type = bool; +}; +template <> +struct Storage { + using Type = int8_t; +}; +template <> +struct Storage { + using Type = int8_t; +}; +template <> +struct Storage { + using Type = int16_t; +}; +template <> +struct Storage { + using Type = int32_t; +}; +template <> +struct Storage { + using Type = BF16; +}; +template <> +struct Storage { + using Type = F16; +}; +template <> +struct Storage { + using Type = float; +}; + +template +using StorageType = typename Storage::Type; + +constexpr bool IsBool(DataType data_type) { return data_type == DataType::kI1; } + +constexpr bool IsInteger(DataType data_type) { + return data_type == DataType::kSI4 || data_type == DataType::kSI8 || + data_type == DataType::kSI16 || data_type == DataType::kSI32; +} +constexpr bool IsFloat(DataType data_type) { + return data_type == DataType::kBF16 || data_type == DataType::kF16 || + data_type == DataType::kF32; +} + +template +constexpr int64_t SizeOf() { + return sizeof(StorageType); +} + +constexpr int64_t SizeOf(DataType data_type) { + switch (data_type) { + case DataType::kI1: + return SizeOf(); + case DataType::kSI4: + return SizeOf(); + case DataType::kSI8: + return SizeOf(); + case DataType::kSI16: + return SizeOf(); + case DataType::kSI32: + return SizeOf(); + case DataType::kBF16: + return SizeOf(); + case DataType::kF16: + return SizeOf(); + case DataType::kF32: + return SizeOf(); + break; + } +} + +} // namespace shlo_ref + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_DATA_TYPE_H_ diff --git a/tensorflow/lite/experimental/shlo/dispatch.h b/tensorflow/lite/experimental/shlo/dispatch.h new file mode 100644 index 00000000000000..9e152d686fcf13 --- /dev/null +++ b/tensorflow/lite/experimental/shlo/dispatch.h @@ -0,0 +1,151 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_DISPATCH_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_DISPATCH_H_ + +#define DISPATCH_INT(name, element_type, ...) \ + { \ + switch (element_type) { \ + case DataType::kSI4: \ + return name(__VA_ARGS__); \ + case DataType::kSI8: \ + return name(__VA_ARGS__); \ + case DataType::kSI16: \ + return name(__VA_ARGS__); \ + case DataType::kSI32: \ + return name(__VA_ARGS__); \ + default: \ + return absl::InvalidArgumentError("Unsupported element type"); \ + } \ + } + +#define DISPATCH_FLOAT(name, element_type, ...) \ + { \ + switch (element_type) { \ + case DataType::kBF16: \ + return name(__VA_ARGS__); \ + case DataType::kF16: \ + return name(__VA_ARGS__); \ + case DataType::kF32: \ + return name(__VA_ARGS__); \ + default: \ + return absl::InvalidArgumentError("Unsupported element type"); \ + } \ + } + +#define DISPATCH_INT_FLOAT(name, element_type, ...) \ + { \ + switch (element_type) { \ + case DataType::kSI4: \ + return name(__VA_ARGS__); \ + case DataType::kSI8: \ + return name(__VA_ARGS__); \ + case DataType::kSI16: \ + return name(__VA_ARGS__); \ + case DataType::kSI32: \ + return name(__VA_ARGS__); \ + case DataType::kBF16: \ + return name(__VA_ARGS__); \ + case DataType::kF16: \ + return name(__VA_ARGS__); \ + case DataType::kF32: \ + return name(__VA_ARGS__); \ + default: \ + return absl::InvalidArgumentError("Unsupported element type"); \ + } \ + } + +#define DISPATCH_BOOL_INT_FLOAT(name, element_type, ...) \ + { \ + switch (element_type) { \ + case DataType::kI1: \ + return name(__VA_ARGS__); \ + case DataType::kSI4: \ + return name(__VA_ARGS__); \ + case DataType::kSI8: \ + return name(__VA_ARGS__); \ + case DataType::kSI16: \ + return name(__VA_ARGS__); \ + case DataType::kSI32: \ + return name(__VA_ARGS__); \ + case DataType::kBF16: \ + return name(__VA_ARGS__); \ + case DataType::kF16: \ + return name(__VA_ARGS__); \ + case DataType::kF32: \ + return name(__VA_ARGS__); \ + default: \ + return absl::InvalidArgumentError("Unsupported element type"); \ + } \ + } + +#define DISPATCH_QUANTIZED(name, storage_type, expressed_type, ...) \ + { \ + switch (storage_type) { \ + case DataType::kSI4: \ + switch (expressed_type) { \ + case DataType::kBF16: \ + return name(__VA_ARGS__); \ + case DataType::kF16: \ + return name(__VA_ARGS__); \ + case DataType::kF32: \ + return name(__VA_ARGS__); \ + default: \ + return absl::InvalidArgumentError("Unsupported expressed type"); \ + } \ + break; \ + case DataType::kSI8: \ + switch (expressed_type) { \ + case DataType::kBF16: \ + return name(__VA_ARGS__); \ + case DataType::kF16: \ + return name(__VA_ARGS__); \ + case DataType::kF32: \ + return name(__VA_ARGS__); \ + default: \ + return absl::InvalidArgumentError("Unsupported expressed type"); \ + } \ + break; \ + case DataType::kSI16: \ + switch (expressed_type) { \ + case DataType::kBF16: \ + return name(__VA_ARGS__); \ + case DataType::kF16: \ + return name(__VA_ARGS__); \ + case DataType::kF32: \ + return name(__VA_ARGS__); \ + default: \ + return absl::InvalidArgumentError("Unsupported expressed type"); \ + } \ + break; \ + case DataType::kSI32: \ + switch (expressed_type) { \ + case DataType::kBF16: \ + return name(__VA_ARGS__); \ + case DataType::kF16: \ + return name(__VA_ARGS__); \ + case DataType::kF32: \ + return name(__VA_ARGS__); \ + default: \ + return absl::InvalidArgumentError("Unsupported expressed type"); \ + } \ + break; \ + default: \ + return absl::InvalidArgumentError("Unsupported storage type"); \ + } \ + } + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_DISPATCH_H_ diff --git a/tensorflow/lite/experimental/shlo/f16.h b/tensorflow/lite/experimental/shlo/f16.h new file mode 100644 index 00000000000000..2496b31b84dc9f --- /dev/null +++ b/tensorflow/lite/experimental/shlo/f16.h @@ -0,0 +1,36 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_F16_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_F16_H_ + +#include "tensorflow/lite/experimental/shlo/has_keyword.h" + +#if defined(__STDCPP_FLOAT16_T__) +#include +namespace shlo_ref { +using F16 = float16_t; +} // namespace shlo_ref + +#elif __has_keyword(_Float16) +namespace shlo_ref { +using F16 = _Float16; +} // namespace shlo_ref + +#else +#error Type F16 is not available +#endif + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_F16_H_ diff --git a/tensorflow/lite/experimental/shlo/src/has_keyword.h b/tensorflow/lite/experimental/shlo/has_keyword.h similarity index 86% rename from tensorflow/lite/experimental/shlo/src/has_keyword.h rename to tensorflow/lite/experimental/shlo/has_keyword.h index eb6f73f899e56c..548c86eec4de36 100644 --- a/tensorflow/lite/experimental/shlo/src/has_keyword.h +++ b/tensorflow/lite/experimental/shlo/has_keyword.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_SRC_HAS_KEYWORD_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_SRC_HAS_KEYWORD_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_HAS_KEYWORD_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_HAS_KEYWORD_H_ // CAUTION: __is_identifier behaves opposite how you would expect! // '__is_identifier' returns '0' if '__x' is a reserved identifier provided by @@ -29,4 +29,4 @@ limitations under the License. // More sensible macro for keyword detection #define __has_keyword(__x) !(__is_identifier(__x)) -#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_SRC_HAS_KEYWORD_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_HAS_KEYWORD_H_ diff --git a/tensorflow/lite/experimental/shlo/legacy/BUILD b/tensorflow/lite/experimental/shlo/legacy/BUILD new file mode 100644 index 00000000000000..2814b314f1d038 --- /dev/null +++ b/tensorflow/lite/experimental/shlo/legacy/BUILD @@ -0,0 +1,65 @@ +# StableHLO Device Profile reference implementation + +package( + default_applicable_licenses = ["//tensorflow:LICENSE"], + default_visibility = [":__subpackages__"], +) + +cc_library( + name = "shlo", + srcs = [ + "src/broadcast_in_dim.cc", + "src/clamp.cc", + "src/compare.cc", + "src/concatenate.cc", + "src/dispatch.h", + "src/elementwise_binary.cc", + "src/elementwise_unary.cc", + "src/iota.cc", + "src/is_finite.cc", + "src/select.cc", + "src/shlo.cc", + "src/uniform_dequantize_quantize.cc", + ], + hdrs = [ + "include/shlo.h", + "src/storage.h", + "src/util.h", + ], + deps = [ + ":float", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "debug", + srcs = [ + "src/debug.cc", + ], + hdrs = [ + "src/debug.h", + ], + deps = [ + ":float", + ":shlo", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "float", + srcs = [ + ], + hdrs = [ + "src/bf16.h", + "src/f16.h", + "src/has_keyword.h", + ], + deps = [ + ], +) diff --git a/tensorflow/lite/experimental/shlo/bench/BUILD b/tensorflow/lite/experimental/shlo/legacy/bench/BUILD similarity index 68% rename from tensorflow/lite/experimental/shlo/bench/BUILD rename to tensorflow/lite/experimental/shlo/legacy/bench/BUILD index c021814deabe51..400f05e72dafa0 100644 --- a/tensorflow/lite/experimental/shlo/bench/BUILD +++ b/tensorflow/lite/experimental/shlo/legacy/bench/BUILD @@ -1,4 +1,4 @@ -package(default_applicable_licenses = ["//tensorflow:license"]) +package(default_applicable_licenses = ["//tensorflow:LICENSE"]) cc_library( name = "util", @@ -8,7 +8,7 @@ cc_library( "util.h", ], deps = [ - "//tensorflow/lite/experimental/shlo:float", + "//tensorflow/lite/experimental/shlo/legacy:float", ], ) @@ -19,8 +19,8 @@ cc_binary( ], deps = [ ":util", - "//tensorflow/lite/experimental/shlo", - "//tensorflow/lite/experimental/shlo/test:util", + "//tensorflow/lite/experimental/shlo/legacy:shlo", + "//tensorflow/lite/experimental/shlo/legacy/test:util", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -37,7 +37,7 @@ cc_binary( ], deps = [ ":util", - "//tensorflow/lite/experimental/shlo:float", + "//tensorflow/lite/experimental/shlo/legacy:float", "@XNNPACK", "@com_google_absl//absl/log", "@com_google_benchmark//:benchmark", diff --git a/tensorflow/lite/experimental/shlo/bench/shlo_benchmark.cc b/tensorflow/lite/experimental/shlo/legacy/bench/shlo_benchmark.cc similarity index 96% rename from tensorflow/lite/experimental/shlo/bench/shlo_benchmark.cc rename to tensorflow/lite/experimental/shlo/legacy/bench/shlo_benchmark.cc index bf33e710217ad0..4d37f2f521420b 100644 --- a/tensorflow/lite/experimental/shlo/bench/shlo_benchmark.cc +++ b/tensorflow/lite/experimental/shlo/legacy/bench/shlo_benchmark.cc @@ -19,10 +19,10 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "benchmark/benchmark.h" // from @com_google_benchmark -#include "tensorflow/lite/experimental/shlo/bench/util.h" -#include "tensorflow/lite/experimental/shlo/include/shlo.h" -#include "tensorflow/lite/experimental/shlo/src/storage.h" -#include "tensorflow/lite/experimental/shlo/test/util.h" +#include "tensorflow/lite/experimental/shlo/legacy/bench/util.h" +#include "tensorflow/lite/experimental/shlo/legacy/include/shlo.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/storage.h" +#include "tensorflow/lite/experimental/shlo/legacy/test/util.h" namespace stablehlo { namespace benchmark { diff --git a/tensorflow/lite/experimental/shlo/bench/util.h b/tensorflow/lite/experimental/shlo/legacy/bench/util.h similarity index 86% rename from tensorflow/lite/experimental/shlo/bench/util.h rename to tensorflow/lite/experimental/shlo/legacy/bench/util.h index d79e822bd84990..29c81541175974 100644 --- a/tensorflow/lite/experimental/shlo/bench/util.h +++ b/tensorflow/lite/experimental/shlo/legacy/bench/util.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_BENCH_UTIL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_BENCH_UTIL_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_BENCH_UTIL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_BENCH_UTIL_H_ #include #include @@ -23,8 +23,8 @@ limitations under the License. #include #include -#include "tensorflow/lite/experimental/shlo/src/bf16.h" -#include "tensorflow/lite/experimental/shlo/src/f16.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/bf16.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/f16.h" namespace stablehlo { @@ -61,4 +61,4 @@ std::vector GenerateRandomVector( } // namespace stablehlo -#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_BENCH_UTIL_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_BENCH_UTIL_H_ diff --git a/tensorflow/lite/experimental/shlo/bench/xnn_benchmark.cc b/tensorflow/lite/experimental/shlo/legacy/bench/xnn_benchmark.cc similarity index 98% rename from tensorflow/lite/experimental/shlo/bench/xnn_benchmark.cc rename to tensorflow/lite/experimental/shlo/legacy/bench/xnn_benchmark.cc index bc01b2b250d43b..6e9d502692c3f6 100644 --- a/tensorflow/lite/experimental/shlo/bench/xnn_benchmark.cc +++ b/tensorflow/lite/experimental/shlo/legacy/bench/xnn_benchmark.cc @@ -22,8 +22,8 @@ limitations under the License. #include "xnnpack.h" // from @XNNPACK #include "absl/log/log.h" #include "benchmark/benchmark.h" // from @com_google_benchmark -#include "tensorflow/lite/experimental/shlo/bench/util.h" -#include "tensorflow/lite/experimental/shlo/src/f16.h" +#include "tensorflow/lite/experimental/shlo/legacy/bench/util.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/f16.h" namespace stablehlo { namespace benchmark { diff --git a/tensorflow/lite/experimental/shlo/include/shlo.h b/tensorflow/lite/experimental/shlo/legacy/include/shlo.h similarity index 98% rename from tensorflow/lite/experimental/shlo/include/shlo.h rename to tensorflow/lite/experimental/shlo/legacy/include/shlo.h index 28df8d7ba3d53b..31dbbc90576dc8 100644 --- a/tensorflow/lite/experimental/shlo/include/shlo.h +++ b/tensorflow/lite/experimental/shlo/legacy/include/shlo.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_INCLUDE_SHLO_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_INCLUDE_SHLO_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_INCLUDE_SHLO_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_INCLUDE_SHLO_H_ #include #include @@ -419,4 +419,4 @@ absl::Status Xor(const Tensor& lhs, const Tensor& rhs, Tensor& result); } // namespace stablehlo -#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_INCLUDE_SHLO_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_INCLUDE_SHLO_H_ diff --git a/tensorflow/lite/experimental/shlo/src/bf16.h b/tensorflow/lite/experimental/shlo/legacy/src/bf16.h similarity index 92% rename from tensorflow/lite/experimental/shlo/src/bf16.h rename to tensorflow/lite/experimental/shlo/legacy/src/bf16.h index fd2e95d1ceff37..fbb1d480546409 100644 --- a/tensorflow/lite/experimental/shlo/src/bf16.h +++ b/tensorflow/lite/experimental/shlo/legacy/src/bf16.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_SRC_BF16_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_SRC_BF16_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_SRC_BF16_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_SRC_BF16_H_ -#include "tensorflow/lite/experimental/shlo/src/has_keyword.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/has_keyword.h" #if defined(__STDCPP_BFLOAT16_T__) #include @@ -126,4 +126,4 @@ inline BF16 operator/(BF16 x, BF16 y) { #error Type BF16 is not available #endif -#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_SRC_BF16_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_SRC_BF16_H_ diff --git a/tensorflow/lite/experimental/shlo/src/broadcast_in_dim.cc b/tensorflow/lite/experimental/shlo/legacy/src/broadcast_in_dim.cc similarity index 96% rename from tensorflow/lite/experimental/shlo/src/broadcast_in_dim.cc rename to tensorflow/lite/experimental/shlo/legacy/src/broadcast_in_dim.cc index f05cd7823bb063..b20fac03d97a9a 100644 --- a/tensorflow/lite/experimental/shlo/src/broadcast_in_dim.cc +++ b/tensorflow/lite/experimental/shlo/legacy/src/broadcast_in_dim.cc @@ -20,10 +20,10 @@ limitations under the License. #include "absl/status/status.h" #include "absl/types/span.h" -#include "tensorflow/lite/experimental/shlo/include/shlo.h" -#include "tensorflow/lite/experimental/shlo/src/dispatch.h" -#include "tensorflow/lite/experimental/shlo/src/storage.h" -#include "tensorflow/lite/experimental/shlo/src/util.h" +#include "tensorflow/lite/experimental/shlo/legacy/include/shlo.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/dispatch.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/storage.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/util.h" namespace stablehlo { diff --git a/tensorflow/lite/experimental/shlo/src/clamp.cc b/tensorflow/lite/experimental/shlo/legacy/src/clamp.cc similarity index 95% rename from tensorflow/lite/experimental/shlo/src/clamp.cc rename to tensorflow/lite/experimental/shlo/legacy/src/clamp.cc index cd221bae43c55b..b7481345cd227f 100644 --- a/tensorflow/lite/experimental/shlo/src/clamp.cc +++ b/tensorflow/lite/experimental/shlo/legacy/src/clamp.cc @@ -18,10 +18,10 @@ limitations under the License. #include #include "absl/status/status.h" -#include "tensorflow/lite/experimental/shlo/include/shlo.h" -#include "tensorflow/lite/experimental/shlo/src/dispatch.h" -#include "tensorflow/lite/experimental/shlo/src/storage.h" -#include "tensorflow/lite/experimental/shlo/src/util.h" +#include "tensorflow/lite/experimental/shlo/legacy/include/shlo.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/dispatch.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/storage.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/util.h" namespace stablehlo { diff --git a/tensorflow/lite/experimental/shlo/src/compare.cc b/tensorflow/lite/experimental/shlo/legacy/src/compare.cc similarity index 95% rename from tensorflow/lite/experimental/shlo/src/compare.cc rename to tensorflow/lite/experimental/shlo/legacy/src/compare.cc index 923dd7222caeaf..354ea14977e0e8 100644 --- a/tensorflow/lite/experimental/shlo/src/compare.cc +++ b/tensorflow/lite/experimental/shlo/legacy/src/compare.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include "absl/status/status.h" -#include "tensorflow/lite/experimental/shlo/include/shlo.h" -#include "tensorflow/lite/experimental/shlo/src/dispatch.h" -#include "tensorflow/lite/experimental/shlo/src/storage.h" -#include "tensorflow/lite/experimental/shlo/src/util.h" +#include "tensorflow/lite/experimental/shlo/legacy/include/shlo.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/dispatch.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/storage.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/util.h" namespace stablehlo { diff --git a/tensorflow/lite/experimental/shlo/src/concatenate.cc b/tensorflow/lite/experimental/shlo/legacy/src/concatenate.cc similarity index 96% rename from tensorflow/lite/experimental/shlo/src/concatenate.cc rename to tensorflow/lite/experimental/shlo/legacy/src/concatenate.cc index 507105cae6a83b..39998950e0358c 100644 --- a/tensorflow/lite/experimental/shlo/src/concatenate.cc +++ b/tensorflow/lite/experimental/shlo/legacy/src/concatenate.cc @@ -19,10 +19,10 @@ limitations under the License. #include "absl/status/status.h" #include "absl/types/span.h" -#include "tensorflow/lite/experimental/shlo/include/shlo.h" -#include "tensorflow/lite/experimental/shlo/src/dispatch.h" -#include "tensorflow/lite/experimental/shlo/src/storage.h" -#include "tensorflow/lite/experimental/shlo/src/util.h" +#include "tensorflow/lite/experimental/shlo/legacy/include/shlo.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/dispatch.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/storage.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/util.h" namespace stablehlo { diff --git a/tensorflow/lite/experimental/shlo/src/debug.cc b/tensorflow/lite/experimental/shlo/legacy/src/debug.cc similarity index 97% rename from tensorflow/lite/experimental/shlo/src/debug.cc rename to tensorflow/lite/experimental/shlo/legacy/src/debug.cc index 590ba31fcca5bd..655285c9c60825 100644 --- a/tensorflow/lite/experimental/shlo/src/debug.cc +++ b/tensorflow/lite/experimental/shlo/legacy/src/debug.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/shlo/src/debug.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/debug.h" #include #include @@ -24,11 +24,11 @@ limitations under the License. #include "absl/log/check.h" #include "absl/log/log.h" -#include "tensorflow/lite/experimental/shlo/include/shlo.h" -#include "tensorflow/lite/experimental/shlo/src/bf16.h" -#include "tensorflow/lite/experimental/shlo/src/f16.h" -#include "tensorflow/lite/experimental/shlo/src/storage.h" -#include "tensorflow/lite/experimental/shlo/src/util.h" +#include "tensorflow/lite/experimental/shlo/legacy/include/shlo.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/bf16.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/f16.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/storage.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/util.h" namespace stablehlo { diff --git a/tensorflow/lite/experimental/shlo/src/debug.h b/tensorflow/lite/experimental/shlo/legacy/src/debug.h similarity index 88% rename from tensorflow/lite/experimental/shlo/src/debug.h rename to tensorflow/lite/experimental/shlo/legacy/src/debug.h index 4020ddc4b07de9..eb676526afa082 100644 --- a/tensorflow/lite/experimental/shlo/src/debug.h +++ b/tensorflow/lite/experimental/shlo/legacy/src/debug.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_SRC_DEBUG_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_SRC_DEBUG_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_SRC_DEBUG_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_SRC_DEBUG_H_ #include #include @@ -29,10 +29,10 @@ limitations under the License. #include #include "absl/types/span.h" -#include "tensorflow/lite/experimental/shlo/include/shlo.h" -#include "tensorflow/lite/experimental/shlo/src/bf16.h" -#include "tensorflow/lite/experimental/shlo/src/f16.h" -#include "tensorflow/lite/experimental/shlo/src/util.h" +#include "tensorflow/lite/experimental/shlo/legacy/include/shlo.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/bf16.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/f16.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/util.h" namespace stablehlo { @@ -114,4 +114,4 @@ inline std::string ToString(absl::Span span) { } // namespace stablehlo -#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_SRC_DEBUG_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_SRC_DEBUG_H_ diff --git a/tensorflow/lite/experimental/shlo/src/dispatch.h b/tensorflow/lite/experimental/shlo/legacy/src/dispatch.h similarity index 97% rename from tensorflow/lite/experimental/shlo/src/dispatch.h rename to tensorflow/lite/experimental/shlo/legacy/src/dispatch.h index ac0028fdb18f6d..649bc0baad2a94 100644 --- a/tensorflow/lite/experimental/shlo/src/dispatch.h +++ b/tensorflow/lite/experimental/shlo/legacy/src/dispatch.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_SRC_DISPATCH_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_SRC_DISPATCH_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_SRC_DISPATCH_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_SRC_DISPATCH_H_ namespace stablehlo { @@ -134,4 +134,4 @@ namespace stablehlo { } // namespace stablehlo -#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_SRC_DISPATCH_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_SRC_DISPATCH_H_ diff --git a/tensorflow/lite/experimental/shlo/src/elementwise_binary.cc b/tensorflow/lite/experimental/shlo/legacy/src/elementwise_binary.cc similarity index 99% rename from tensorflow/lite/experimental/shlo/src/elementwise_binary.cc rename to tensorflow/lite/experimental/shlo/legacy/src/elementwise_binary.cc index e449f875d3d338..feed5f3d6322fa 100644 --- a/tensorflow/lite/experimental/shlo/src/elementwise_binary.cc +++ b/tensorflow/lite/experimental/shlo/legacy/src/elementwise_binary.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include "absl/status/status.h" -#include "tensorflow/lite/experimental/shlo/include/shlo.h" -#include "tensorflow/lite/experimental/shlo/src/storage.h" -#include "tensorflow/lite/experimental/shlo/src/util.h" +#include "tensorflow/lite/experimental/shlo/legacy/include/shlo.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/storage.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/util.h" namespace stablehlo { diff --git a/tensorflow/lite/experimental/shlo/src/elementwise_unary.cc b/tensorflow/lite/experimental/shlo/legacy/src/elementwise_unary.cc similarity index 98% rename from tensorflow/lite/experimental/shlo/src/elementwise_unary.cc rename to tensorflow/lite/experimental/shlo/legacy/src/elementwise_unary.cc index 43510324917877..2c404cce6ac89a 100644 --- a/tensorflow/lite/experimental/shlo/src/elementwise_unary.cc +++ b/tensorflow/lite/experimental/shlo/legacy/src/elementwise_unary.cc @@ -21,11 +21,11 @@ limitations under the License. #include #include "absl/status/status.h" -#include "tensorflow/lite/experimental/shlo/include/shlo.h" -#include "tensorflow/lite/experimental/shlo/src/bf16.h" -#include "tensorflow/lite/experimental/shlo/src/f16.h" -#include "tensorflow/lite/experimental/shlo/src/storage.h" -#include "tensorflow/lite/experimental/shlo/src/util.h" +#include "tensorflow/lite/experimental/shlo/legacy/include/shlo.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/bf16.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/f16.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/storage.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/util.h" namespace stablehlo { diff --git a/tensorflow/lite/experimental/shlo/src/f16.h b/tensorflow/lite/experimental/shlo/legacy/src/f16.h similarity index 78% rename from tensorflow/lite/experimental/shlo/src/f16.h rename to tensorflow/lite/experimental/shlo/legacy/src/f16.h index 72dbef86235c32..a2679306e83129 100644 --- a/tensorflow/lite/experimental/shlo/src/f16.h +++ b/tensorflow/lite/experimental/shlo/legacy/src/f16.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_SRC_F16_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_SRC_F16_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_SRC_F16_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_SRC_F16_H_ -#include "tensorflow/lite/experimental/shlo/src/has_keyword.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/has_keyword.h" #if defined(__STDCPP_FLOAT16_T__) #include @@ -33,4 +33,4 @@ using F16 = _Float16; #error Type F16 is not available #endif -#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_SRC_F16_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_SRC_F16_H_ diff --git a/tensorflow/lite/experimental/shlo/legacy/src/has_keyword.h b/tensorflow/lite/experimental/shlo/legacy/src/has_keyword.h new file mode 100644 index 00000000000000..7c8efdc044c317 --- /dev/null +++ b/tensorflow/lite/experimental/shlo/legacy/src/has_keyword.h @@ -0,0 +1,32 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_SRC_HAS_KEYWORD_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_SRC_HAS_KEYWORD_H_ + +// CAUTION: __is_identifier behaves opposite how you would expect! +// '__is_identifier' returns '0' if '__x' is a reserved identifier provided by +// the compiler and '1' otherwise. borrowed from LLVM __config header under +// Apache license 2. +// (https://www.mend.io/blog/top-10-apache-license-questions-answered/) + +#ifndef __is_identifier // Optional of course. +#define __is_identifier(x) 1 // Compatibility with non-clang compilers. +#endif + +// More sensible macro for keyword detection +#define __has_keyword(__x) !(__is_identifier(__x)) + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_SRC_HAS_KEYWORD_H_ diff --git a/tensorflow/lite/experimental/shlo/src/iota.cc b/tensorflow/lite/experimental/shlo/legacy/src/iota.cc similarity index 92% rename from tensorflow/lite/experimental/shlo/src/iota.cc rename to tensorflow/lite/experimental/shlo/legacy/src/iota.cc index 6df453b2544ab7..a2bd73febcd2df 100644 --- a/tensorflow/lite/experimental/shlo/src/iota.cc +++ b/tensorflow/lite/experimental/shlo/legacy/src/iota.cc @@ -16,10 +16,10 @@ limitations under the License. #include #include "absl/status/status.h" -#include "tensorflow/lite/experimental/shlo/include/shlo.h" -#include "tensorflow/lite/experimental/shlo/src/dispatch.h" -#include "tensorflow/lite/experimental/shlo/src/storage.h" -#include "tensorflow/lite/experimental/shlo/src/util.h" +#include "tensorflow/lite/experimental/shlo/legacy/include/shlo.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/dispatch.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/storage.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/util.h" namespace stablehlo { diff --git a/tensorflow/lite/experimental/shlo/src/is_finite.cc b/tensorflow/lite/experimental/shlo/legacy/src/is_finite.cc similarity index 92% rename from tensorflow/lite/experimental/shlo/src/is_finite.cc rename to tensorflow/lite/experimental/shlo/legacy/src/is_finite.cc index dce589e3f21bca..11be44cead6fb5 100644 --- a/tensorflow/lite/experimental/shlo/src/is_finite.cc +++ b/tensorflow/lite/experimental/shlo/legacy/src/is_finite.cc @@ -18,10 +18,10 @@ limitations under the License. #include #include "absl/status/status.h" -#include "tensorflow/lite/experimental/shlo/include/shlo.h" -#include "tensorflow/lite/experimental/shlo/src/dispatch.h" -#include "tensorflow/lite/experimental/shlo/src/storage.h" -#include "tensorflow/lite/experimental/shlo/src/util.h" +#include "tensorflow/lite/experimental/shlo/legacy/include/shlo.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/dispatch.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/storage.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/util.h" namespace stablehlo { diff --git a/tensorflow/lite/experimental/shlo/src/select.cc b/tensorflow/lite/experimental/shlo/legacy/src/select.cc similarity index 95% rename from tensorflow/lite/experimental/shlo/src/select.cc rename to tensorflow/lite/experimental/shlo/legacy/src/select.cc index 43913fc03bc2ad..5bef465c9a0df7 100644 --- a/tensorflow/lite/experimental/shlo/src/select.cc +++ b/tensorflow/lite/experimental/shlo/legacy/src/select.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include "absl/status/status.h" -#include "tensorflow/lite/experimental/shlo/include/shlo.h" -#include "tensorflow/lite/experimental/shlo/src/dispatch.h" -#include "tensorflow/lite/experimental/shlo/src/storage.h" -#include "tensorflow/lite/experimental/shlo/src/util.h" +#include "tensorflow/lite/experimental/shlo/legacy/include/shlo.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/dispatch.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/storage.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/util.h" namespace stablehlo { diff --git a/tensorflow/lite/experimental/shlo/src/shlo.cc b/tensorflow/lite/experimental/shlo/legacy/src/shlo.cc similarity index 96% rename from tensorflow/lite/experimental/shlo/src/shlo.cc rename to tensorflow/lite/experimental/shlo/legacy/src/shlo.cc index c04020deaa1ae2..d5b47bc28371be 100644 --- a/tensorflow/lite/experimental/shlo/src/shlo.cc +++ b/tensorflow/lite/experimental/shlo/legacy/src/shlo.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/shlo/include/shlo.h" +#include "tensorflow/lite/experimental/shlo/legacy/include/shlo.h" #include #include @@ -23,7 +23,7 @@ limitations under the License. #include #include "absl/log/log.h" -#include "tensorflow/lite/experimental/shlo/src/storage.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/storage.h" namespace stablehlo { diff --git a/tensorflow/lite/experimental/shlo/src/storage.h b/tensorflow/lite/experimental/shlo/legacy/src/storage.h similarity index 88% rename from tensorflow/lite/experimental/shlo/src/storage.h rename to tensorflow/lite/experimental/shlo/legacy/src/storage.h index 5b09031a3f37d3..280f8a8946b86d 100644 --- a/tensorflow/lite/experimental/shlo/src/storage.h +++ b/tensorflow/lite/experimental/shlo/legacy/src/storage.h @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_SRC_STORAGE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_SRC_STORAGE_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_SRC_STORAGE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_SRC_STORAGE_H_ #include #include -#include "tensorflow/lite/experimental/shlo/include/shlo.h" -#include "tensorflow/lite/experimental/shlo/src/bf16.h" -#include "tensorflow/lite/experimental/shlo/src/f16.h" +#include "tensorflow/lite/experimental/shlo/legacy/include/shlo.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/bf16.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/f16.h" namespace stablehlo { @@ -121,4 +121,4 @@ struct Storage { } // namespace stablehlo -#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_SRC_STORAGE_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_SRC_STORAGE_H_ diff --git a/tensorflow/lite/experimental/shlo/src/uniform_dequantize_quantize.cc b/tensorflow/lite/experimental/shlo/legacy/src/uniform_dequantize_quantize.cc similarity index 94% rename from tensorflow/lite/experimental/shlo/src/uniform_dequantize_quantize.cc rename to tensorflow/lite/experimental/shlo/legacy/src/uniform_dequantize_quantize.cc index a7090c525ed21c..5b09f14b7189f5 100644 --- a/tensorflow/lite/experimental/shlo/src/uniform_dequantize_quantize.cc +++ b/tensorflow/lite/experimental/shlo/legacy/src/uniform_dequantize_quantize.cc @@ -16,10 +16,10 @@ limitations under the License. #include #include "absl/status/status.h" -#include "tensorflow/lite/experimental/shlo/include/shlo.h" -#include "tensorflow/lite/experimental/shlo/src/dispatch.h" -#include "tensorflow/lite/experimental/shlo/src/storage.h" -#include "tensorflow/lite/experimental/shlo/src/util.h" +#include "tensorflow/lite/experimental/shlo/legacy/include/shlo.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/dispatch.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/storage.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/util.h" namespace stablehlo { diff --git a/tensorflow/lite/experimental/shlo/src/util.h b/tensorflow/lite/experimental/shlo/legacy/src/util.h similarity index 95% rename from tensorflow/lite/experimental/shlo/src/util.h rename to tensorflow/lite/experimental/shlo/legacy/src/util.h index aff723b822d10a..8df271be84f5a7 100644 --- a/tensorflow/lite/experimental/shlo/src/util.h +++ b/tensorflow/lite/experimental/shlo/legacy/src/util.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_SRC_UTIL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_SRC_UTIL_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_SRC_UTIL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_SRC_UTIL_H_ #include #include @@ -24,8 +24,8 @@ limitations under the License. #include #include "absl/status/status.h" -#include "tensorflow/lite/experimental/shlo/include/shlo.h" -#include "tensorflow/lite/experimental/shlo/src/storage.h" +#include "tensorflow/lite/experimental/shlo/legacy/include/shlo.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/storage.h" namespace stablehlo { @@ -218,4 +218,4 @@ class TensorIndexIterator { } // namespace stablehlo -#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_SRC_UTIL_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_SRC_UTIL_H_ diff --git a/tensorflow/lite/experimental/shlo/test/BUILD b/tensorflow/lite/experimental/shlo/legacy/test/BUILD similarity index 56% rename from tensorflow/lite/experimental/shlo/test/BUILD rename to tensorflow/lite/experimental/shlo/legacy/test/BUILD index 5576856bc8eb75..1e0ed3830a60b3 100644 --- a/tensorflow/lite/experimental/shlo/test/BUILD +++ b/tensorflow/lite/experimental/shlo/legacy/test/BUILD @@ -1,4 +1,14 @@ -package(default_applicable_licenses = ["//tensorflow:license"]) +package(default_applicable_licenses = ["//tensorflow:LICENSE"]) + +cc_library( + name = "matchers", + testonly = True, + hdrs = ["matchers.h"], + deps = [ + "//tensorflow/lite/experimental/shlo/legacy:debug", + "@com_google_googletest//:gtest_main", + ], +) cc_library( name = "util", @@ -7,9 +17,10 @@ cc_library( hdrs = [ "util.h", ], - visibility = ["//tensorflow/lite/experimental/shlo/bench:__subpackages__"], + visibility = ["//tensorflow/lite/experimental/shlo/legacy/bench:__subpackages__"], deps = [ - "//tensorflow/lite/experimental/shlo", + "//tensorflow/lite/experimental/shlo/legacy:debug", + "//tensorflow/lite/experimental/shlo/legacy:shlo", "@com_google_absl//absl/log:check", ], ) @@ -23,9 +34,8 @@ cc_test( ], deps = [ ":util", - "//tensorflow/lite/experimental/shlo", - "//tensorflow/lite/experimental/shlo:debug", - "@com_google_absl//absl/log", + "//tensorflow/lite/experimental/shlo/legacy:debug", + "//tensorflow/lite/experimental/shlo/legacy:shlo", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", ], @@ -40,9 +50,8 @@ cc_test( ], deps = [ ":util", - "//tensorflow/lite/experimental/shlo", - "//tensorflow/lite/experimental/shlo:debug", - "@com_google_absl//absl/log", + "//tensorflow/lite/experimental/shlo/legacy:debug", + "//tensorflow/lite/experimental/shlo/legacy:shlo", "@com_google_googletest//:gtest_main", ], ) @@ -56,9 +65,8 @@ cc_test( ], deps = [ ":util", - "//tensorflow/lite/experimental/shlo", - "//tensorflow/lite/experimental/shlo:debug", - "@com_google_absl//absl/log", + "//tensorflow/lite/experimental/shlo/legacy:debug", + "//tensorflow/lite/experimental/shlo/legacy:shlo", "@com_google_googletest//:gtest_main", ], ) @@ -72,9 +80,8 @@ cc_test( ], deps = [ ":util", - "//tensorflow/lite/experimental/shlo", - "//tensorflow/lite/experimental/shlo:debug", - "@com_google_absl//absl/log", + "//tensorflow/lite/experimental/shlo/legacy:debug", + "//tensorflow/lite/experimental/shlo/legacy:shlo", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", ], @@ -88,10 +95,10 @@ cc_test( data = [ ], deps = [ + ":matchers", ":util", - "//tensorflow/lite/experimental/shlo", - "//tensorflow/lite/experimental/shlo:debug", - "@com_google_absl//absl/log", + "//tensorflow/lite/experimental/shlo/legacy:debug", + "//tensorflow/lite/experimental/shlo/legacy:shlo", "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", ], @@ -105,10 +112,10 @@ cc_test( data = [ ], deps = [ + ":matchers", ":util", - "//tensorflow/lite/experimental/shlo", - "//tensorflow/lite/experimental/shlo:debug", - "@com_google_absl//absl/log", + "//tensorflow/lite/experimental/shlo/legacy:debug", + "//tensorflow/lite/experimental/shlo/legacy:shlo", "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", ], @@ -123,9 +130,8 @@ cc_test( ], deps = [ ":util", - "//tensorflow/lite/experimental/shlo", - "//tensorflow/lite/experimental/shlo:debug", - "@com_google_absl//absl/log", + "//tensorflow/lite/experimental/shlo/legacy:debug", + "//tensorflow/lite/experimental/shlo/legacy:shlo", "@com_google_googletest//:gtest_main", ], ) @@ -138,9 +144,8 @@ cc_test( data = [ ], deps = [ - "//tensorflow/lite/experimental/shlo", - "//tensorflow/lite/experimental/shlo:debug", - "@com_google_absl//absl/log", + "//tensorflow/lite/experimental/shlo/legacy:debug", + "//tensorflow/lite/experimental/shlo/legacy:shlo", "@com_google_googletest//:gtest_main", ], ) @@ -154,9 +159,8 @@ cc_test( ], deps = [ ":util", - "//tensorflow/lite/experimental/shlo", - "//tensorflow/lite/experimental/shlo:debug", - "@com_google_absl//absl/log", + "//tensorflow/lite/experimental/shlo/legacy:debug", + "//tensorflow/lite/experimental/shlo/legacy:shlo", "@com_google_googletest//:gtest_main", ], ) @@ -169,10 +173,9 @@ cc_test( data = [ ], deps = [ - "//tensorflow/lite/experimental/shlo", - "//tensorflow/lite/experimental/shlo:debug", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", + ":matchers", + "//tensorflow/lite/experimental/shlo/legacy:debug", + "//tensorflow/lite/experimental/shlo/legacy:shlo", "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/lite/experimental/shlo/test/broadcast_in_dim_test.cc b/tensorflow/lite/experimental/shlo/legacy/test/broadcast_in_dim_test.cc similarity index 85% rename from tensorflow/lite/experimental/shlo/test/broadcast_in_dim_test.cc rename to tensorflow/lite/experimental/shlo/legacy/test/broadcast_in_dim_test.cc index f5fb64f95b496e..b8bc9a22bf1d36 100644 --- a/tensorflow/lite/experimental/shlo/test/broadcast_in_dim_test.cc +++ b/tensorflow/lite/experimental/shlo/legacy/test/broadcast_in_dim_test.cc @@ -17,13 +17,13 @@ limitations under the License. #include #include +#include #include -#include "absl/log/log.h" #include "absl/types/span.h" -#include "tensorflow/lite/experimental/shlo/include/shlo.h" -#include "tensorflow/lite/experimental/shlo/src/debug.h" -#include "tensorflow/lite/experimental/shlo/src/storage.h" -#include "tensorflow/lite/experimental/shlo/test/util.h" +#include "tensorflow/lite/experimental/shlo/legacy/include/shlo.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/debug.h" // IWYU pragma: keep, b/321245930 +#include "tensorflow/lite/experimental/shlo/legacy/src/storage.h" +#include "tensorflow/lite/experimental/shlo/legacy/test/util.h" namespace stablehlo { namespace testing { @@ -47,20 +47,10 @@ void test(std::initializer_list&& operand_shape, absl::Span broadcast_dimensions( broadcast_dimensions_values); - auto res = BroadcastInDim(operand, broadcast_dimensions, result); - - if (!res.ok()) { - LOG(INFO) << "Failure: " << res; - } - ASSERT_EQ(res.ok(), true); - - if (result != expected) { - LOG(INFO) << "operand: " << operand; - LOG(INFO) << "broadcast_dimensions: " << ToString(broadcast_dimensions); - LOG(INFO) << "expected: " << expected; - LOG(INFO) << "result: " << result; - } - ASSERT_EQ(result, expected); + ASSERT_OK(BroadcastInDim(operand, broadcast_dimensions, result)); + EXPECT_EQ(result, expected) + << "operand: " << operand + << "\nbroadcast_dimensions: " << ToString(broadcast_dimensions); } template @@ -97,18 +87,10 @@ void test( broadcast_dimensions_values); auto res = BroadcastInDim(operand, broadcast_dimensions, result); - if (!res.ok()) { - LOG(INFO) << "Failure: " << res; - } - ASSERT_EQ(res.ok(), true); - - if (result != expected) { - LOG(INFO) << "operand: " << operand; - LOG(INFO) << "broadcast_dimensions: " << ToString(broadcast_dimensions); - LOG(INFO) << "expected: " << expected; - LOG(INFO) << "result: " << result; - } - ASSERT_EQ(result, expected); + ASSERT_OK(BroadcastInDim(operand, broadcast_dimensions, result)); + EXPECT_EQ(result, expected) + << "operand: " << operand + << "\nbroadcast_dimensions: " << ToString(broadcast_dimensions); } TEST(BroadcastInDim, Unquantized) { diff --git a/tensorflow/lite/experimental/shlo/test/clamp_test.cc b/tensorflow/lite/experimental/shlo/legacy/test/clamp_test.cc similarity index 88% rename from tensorflow/lite/experimental/shlo/test/clamp_test.cc rename to tensorflow/lite/experimental/shlo/legacy/test/clamp_test.cc index a6d78c6cc2ffc1..001fd6f761d999 100644 --- a/tensorflow/lite/experimental/shlo/test/clamp_test.cc +++ b/tensorflow/lite/experimental/shlo/legacy/test/clamp_test.cc @@ -17,12 +17,12 @@ limitations under the License. #include #include +#include #include -#include "absl/log/log.h" -#include "tensorflow/lite/experimental/shlo/include/shlo.h" -#include "tensorflow/lite/experimental/shlo/src/debug.h" -#include "tensorflow/lite/experimental/shlo/src/storage.h" -#include "tensorflow/lite/experimental/shlo/test/util.h" +#include "tensorflow/lite/experimental/shlo/legacy/include/shlo.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/debug.h" // IWYU pragma: keep, b/321245930 +#include "tensorflow/lite/experimental/shlo/legacy/src/storage.h" +#include "tensorflow/lite/experimental/shlo/legacy/test/util.h" namespace stablehlo { namespace testing { @@ -45,21 +45,9 @@ void test(std::initializer_list&& shape, expected_values.size()); Tensor result(TensorType(Shape(shape), element_type), result_values.data()); - auto res = Clamp(min, operand, max, result); - - if (!res.ok()) { - LOG(INFO) << "Failure: " << res; - } - ASSERT_EQ(res.ok(), true); - - if (result != expected) { - LOG(INFO) << "min=" << min; - LOG(INFO) << "max=" << max; - LOG(INFO) << "operand=" << operand; - LOG(INFO) << "expected=" << expected; - LOG(INFO) << "result=" << result; - } - ASSERT_EQ(result, expected); + ASSERT_OK(Clamp(min, operand, max, result)); + EXPECT_EQ(result, expected) + << "min: " << min << "\nmax: " << max << "\noperand: " << operand; } template @@ -107,21 +95,9 @@ void test( QuantizedTensorElementType(element_type)), result_quant_values.data()); - auto res = Clamp(min, operand, max, result); - - if (!res.ok()) { - LOG(INFO) << "Failure: " << res; - } - ASSERT_EQ(res.ok(), true); - - if (result != expected) { - LOG(INFO) << "min=" << min; - LOG(INFO) << "max=" << max; - LOG(INFO) << "operand=" << operand; - LOG(INFO) << "expected=" << expected; - LOG(INFO) << "result=" << result; - } - ASSERT_EQ(result, expected); + ASSERT_OK(Clamp(min, operand, max, result)); + EXPECT_EQ(result, expected) + << "min: " << min << "\nmax: " << max << "\noperand: " << operand; } TEST(Clamp, Unquantized) { diff --git a/tensorflow/lite/experimental/shlo/test/compare_test.cc b/tensorflow/lite/experimental/shlo/legacy/test/compare_test.cc similarity index 94% rename from tensorflow/lite/experimental/shlo/test/compare_test.cc rename to tensorflow/lite/experimental/shlo/legacy/test/compare_test.cc index b14a2d1268acb7..217ea28424fdb0 100644 --- a/tensorflow/lite/experimental/shlo/test/compare_test.cc +++ b/tensorflow/lite/experimental/shlo/legacy/test/compare_test.cc @@ -17,12 +17,12 @@ limitations under the License. #include #include +#include #include -#include "absl/log/log.h" -#include "tensorflow/lite/experimental/shlo/include/shlo.h" -#include "tensorflow/lite/experimental/shlo/src/debug.h" // IWYU pragma: keep, b/321245930 -#include "tensorflow/lite/experimental/shlo/src/storage.h" -#include "tensorflow/lite/experimental/shlo/test/util.h" +#include "tensorflow/lite/experimental/shlo/legacy/include/shlo.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/debug.h" // IWYU pragma: keep, b/321245930 +#include "tensorflow/lite/experimental/shlo/legacy/src/storage.h" +#include "tensorflow/lite/experimental/shlo/legacy/test/util.h" namespace stablehlo { namespace testing { @@ -44,22 +44,11 @@ void test( Tensor result(TensorType(Shape(shape), ElementType::kI1), result_values.data()); - auto res = Compare(lhs, rhs, comparison_direction, compare_type, result); - - if (!res.ok()) { - LOG(INFO) << "Failure: " << res; - } - ASSERT_EQ(res.ok(), true); - - if (result != expected) { - LOG(INFO) << "comparison_direction: " << comparison_direction; - LOG(INFO) << "compare_type: " << compare_type; - LOG(INFO) << "lhs: " << lhs; - LOG(INFO) << "rhs: " << rhs; - LOG(INFO) << "expected: " << expected; - LOG(INFO) << "result: " << result; - } - ASSERT_EQ(result, expected); + ASSERT_OK(Compare(lhs, rhs, comparison_direction, compare_type, result)); + EXPECT_EQ(result, expected) + << "comparison_direction: " << comparison_direction + << "\ncompare_type: " << compare_type << "\nlhs: " << lhs + << "\nrhs: " << rhs; } template @@ -94,22 +83,11 @@ void test( Tensor result(TensorType(Shape(shape), ElementType::kI1), result_values.data()); - auto res = Compare(lhs, rhs, comparison_direction, compare_type, result); - - if (!res.ok()) { - LOG(INFO) << "Failure: " << res; - } - ASSERT_EQ(res.ok(), true); - - if (result != expected) { - LOG(INFO) << "comparison_direction: " << comparison_direction; - LOG(INFO) << "compare_type: " << compare_type; - LOG(INFO) << "lhs: " << lhs; - LOG(INFO) << "rhs: " << rhs; - LOG(INFO) << "expected: " << expected; - LOG(INFO) << "result: " << result; - } - ASSERT_EQ(result, expected); + ASSERT_OK(Compare(lhs, rhs, comparison_direction, compare_type, result)); + EXPECT_EQ(result, expected) + << "comparison_direction: " << comparison_direction + << "\ncompare_type: " << compare_type << "\nlhs: " << lhs + << "\nrhs: " << rhs; } TEST(Compare, Unquantized) { diff --git a/tensorflow/lite/experimental/shlo/test/concatenate_test.cc b/tensorflow/lite/experimental/shlo/legacy/test/concatenate_test.cc similarity index 87% rename from tensorflow/lite/experimental/shlo/test/concatenate_test.cc rename to tensorflow/lite/experimental/shlo/legacy/test/concatenate_test.cc index 54dc17b6b899df..3494ad9940a58f 100644 --- a/tensorflow/lite/experimental/shlo/test/concatenate_test.cc +++ b/tensorflow/lite/experimental/shlo/legacy/test/concatenate_test.cc @@ -13,17 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include +#include +#include +#include #include #include +#include #include -#include "absl/log/log.h" #include "absl/types/span.h" -#include "tensorflow/lite/experimental/shlo/include/shlo.h" -#include "tensorflow/lite/experimental/shlo/src/debug.h" // IWYU pragma: keep, b/321245930 -#include "tensorflow/lite/experimental/shlo/src/storage.h" -#include "tensorflow/lite/experimental/shlo/test/util.h" +#include "tensorflow/lite/experimental/shlo/legacy/include/shlo.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/debug.h" // IWYU pragma: keep, b/321245930 +#include "tensorflow/lite/experimental/shlo/legacy/src/storage.h" +#include "tensorflow/lite/experimental/shlo/legacy/test/util.h" namespace stablehlo { namespace testing { @@ -34,6 +38,16 @@ struct TensorConst { std::vector::Type>&& values; }; +template +std::string ToString(std::string_view name, + const std::vector& tensors) { + std::ostringstream result; + for (size_t i = 0; i < tensors.size(); ++i) { + result << name << "[" << i << "]: " << *tensors[i] << "\n"; + } + return result.str(); +} + template void test(std::initializer_list>&& inputs_, DimensionSize dimension, TensorConst&& expected_) { @@ -54,22 +68,9 @@ void test(std::initializer_list>&& inputs_, expected.num_elements()); Tensor result(TensorType(expected.type()), result_values.data()); - auto res = Concatenate(absl::Span(inputs), dimension, result); - - if (!res.ok()) { - LOG(INFO) << "Failure: " << res; - } - ASSERT_EQ(res.ok(), true); - - if (result != expected) { - for (auto i = 0; i < inputs.size(); ++i) { - LOG(INFO) << "input[" << i << "]: " << *inputs[i]; - } - LOG(INFO) << "dimension: " << dimension; - LOG(INFO) << "expected: " << expected; - LOG(INFO) << "result: " << result; - } - ASSERT_EQ(result, expected); + ASSERT_OK(Concatenate(absl::Span(inputs), dimension, result)); + EXPECT_EQ(result, expected) + << ToString("inputs", inputs) << "dimension: " << dimension; } template @@ -106,23 +107,10 @@ void test(QuantizedParameter&& quantized_parameter, QuantizedTensor result(QuantizedTensorType(expected.type()), result_values.data()); - auto res = Concatenate(absl::Span(inputs), dimension, - result); - - if (!res.ok()) { - LOG(INFO) << "Failure: " << res; - } - ASSERT_EQ(res.ok(), true); - - if (result != expected) { - for (auto i = 0; i < inputs.size(); ++i) { - LOG(INFO) << "input[" << i << "]: " << *inputs[i]; - } - LOG(INFO) << "dimension: " << dimension; - LOG(INFO) << "expected: " << expected; - LOG(INFO) << "result: " << result; - } - ASSERT_EQ(result, expected); + ASSERT_OK(Concatenate(absl::Span(inputs), dimension, + result)); + EXPECT_EQ(result, expected) + << ToString("inputs", inputs) << "dimension: " << dimension; } TEST(Concatenate, Unquantized) { diff --git a/tensorflow/lite/experimental/shlo/test/elementwise_binary_test.cc b/tensorflow/lite/experimental/shlo/legacy/test/elementwise_binary_test.cc similarity index 96% rename from tensorflow/lite/experimental/shlo/test/elementwise_binary_test.cc rename to tensorflow/lite/experimental/shlo/legacy/test/elementwise_binary_test.cc index fff0ceae0862f0..18714f25ae490c 100644 --- a/tensorflow/lite/experimental/shlo/test/elementwise_binary_test.cc +++ b/tensorflow/lite/experimental/shlo/legacy/test/elementwise_binary_test.cc @@ -18,13 +18,14 @@ limitations under the License. #include #include +#include #include -#include "absl/log/log.h" #include "absl/status/status.h" -#include "tensorflow/lite/experimental/shlo/include/shlo.h" -#include "tensorflow/lite/experimental/shlo/src/debug.h" -#include "tensorflow/lite/experimental/shlo/src/storage.h" -#include "tensorflow/lite/experimental/shlo/test/util.h" +#include "tensorflow/lite/experimental/shlo/legacy/include/shlo.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/debug.h" // IWYU pragma: keep, b/321245930 +#include "tensorflow/lite/experimental/shlo/legacy/src/storage.h" +#include "tensorflow/lite/experimental/shlo/legacy/test/matchers.h" +#include "tensorflow/lite/experimental/shlo/legacy/test/util.h" namespace stablehlo { namespace testing { @@ -46,20 +47,9 @@ void test(absl::Status (*op)(const Tensor&, const Tensor&, Tensor&), expected_values.size()); Tensor result(TensorType(Shape(shape), element_type), result_values.data()); - auto res = op(input1, input2, result); - - if (!res.ok()) { - LOG(INFO) << "Failure: " << res; - } - if (result != expected) { - LOG(INFO) << "input1=" << input1; - LOG(INFO) << "input2=" << input2; - LOG(INFO) << "expected=" << expected; - LOG(INFO) << "result=" << result; - } - - ASSERT_EQ(res.ok(), true); - ASSERT_EQ(AlmostSame(result, expected), true); + ASSERT_OK(op(input1, input2, result)); + EXPECT_THAT(result, IsAlmostSame(expected)) + << "input1: " << input1 << "\ninput2: " << input2; } template @@ -100,20 +90,9 @@ void test( QuantizedTensorElementType(element_type)), result_quant_values.data()); - auto res = op(input1, input2, result); - - if (!res.ok()) { - LOG(INFO) << "Failure: " << res; - } - if (result != expected) { - LOG(INFO) << "input1=" << input1; - LOG(INFO) << "input2=" << input2; - LOG(INFO) << "expected=" << expected; - LOG(INFO) << "result=" << result; - } - - ASSERT_EQ(res.ok(), true); - ASSERT_EQ(AlmostSame(result, expected), true); + ASSERT_OK(op(input1, input2, result)); + EXPECT_THAT(result, IsAlmostSame(expected)) + << "input1: " << input1 << "\ninput2: " << input2; } TEST(ElementwiseBinary, Add) { diff --git a/tensorflow/lite/experimental/shlo/test/elementwise_unary_test.cc b/tensorflow/lite/experimental/shlo/legacy/test/elementwise_unary_test.cc similarity index 97% rename from tensorflow/lite/experimental/shlo/test/elementwise_unary_test.cc rename to tensorflow/lite/experimental/shlo/legacy/test/elementwise_unary_test.cc index 995cf6deaf7bc6..11a2564cba9d11 100644 --- a/tensorflow/lite/experimental/shlo/test/elementwise_unary_test.cc +++ b/tensorflow/lite/experimental/shlo/legacy/test/elementwise_unary_test.cc @@ -19,13 +19,14 @@ limitations under the License. #include #include +#include #include -#include "absl/log/log.h" #include "absl/status/status.h" -#include "tensorflow/lite/experimental/shlo/include/shlo.h" -#include "tensorflow/lite/experimental/shlo/src/debug.h" -#include "tensorflow/lite/experimental/shlo/src/storage.h" -#include "tensorflow/lite/experimental/shlo/test/util.h" +#include "tensorflow/lite/experimental/shlo/legacy/include/shlo.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/debug.h" // IWYU pragma: keep, b/321245930 +#include "tensorflow/lite/experimental/shlo/legacy/src/storage.h" +#include "tensorflow/lite/experimental/shlo/legacy/test/matchers.h" +#include "tensorflow/lite/experimental/shlo/legacy/test/util.h" namespace stablehlo { namespace testing { @@ -43,19 +44,8 @@ void test(absl::Status (*op)(const Tensor&, Tensor&), expected_values.size()); Tensor result(TensorType(Shape(shape), element_type), result_values.data()); - auto res = op(input, result); - - if (!res.ok()) { - LOG(INFO) << "Failure: " << res; - } - if (result != expected) { - LOG(INFO) << "input=" << input; - LOG(INFO) << "expected=" << expected; - LOG(INFO) << "result=" << result; - } - - ASSERT_EQ(res.ok(), true); - ASSERT_EQ(AlmostSame(result, expected), true); + ASSERT_OK(op(input, result)); + EXPECT_THAT(result, IsAlmostSame(expected)) << "input: " << input; } template @@ -88,19 +78,8 @@ void test( QuantizedTensorElementType(element_type)), result_quant_values.data()); - auto res = op(input, result); - - if (!res.ok()) { - LOG(INFO) << "Failure: " << res; - } - if (result != expected) { - LOG(INFO) << "input=" << input; - LOG(INFO) << "expected=" << expected; - LOG(INFO) << "result=" << result; - } - - ASSERT_EQ(res.ok(), true); - ASSERT_EQ(AlmostSame(result, expected), true); + ASSERT_OK(op(input, result)); + EXPECT_THAT(result, IsAlmostSame(expected)) << "input: " << input; } TEST(ElementwiseUnary, Abs) { diff --git a/tensorflow/lite/experimental/shlo/test/iota_test.cc b/tensorflow/lite/experimental/shlo/legacy/test/iota_test.cc similarity index 87% rename from tensorflow/lite/experimental/shlo/test/iota_test.cc rename to tensorflow/lite/experimental/shlo/legacy/test/iota_test.cc index 794e9518c5708d..7fe34cdc01d902 100644 --- a/tensorflow/lite/experimental/shlo/test/iota_test.cc +++ b/tensorflow/lite/experimental/shlo/legacy/test/iota_test.cc @@ -17,12 +17,12 @@ limitations under the License. #include #include +#include #include -#include "absl/log/log.h" -#include "tensorflow/lite/experimental/shlo/include/shlo.h" -#include "tensorflow/lite/experimental/shlo/src/debug.h" // IWYU pragma: keep, b/321245930 -#include "tensorflow/lite/experimental/shlo/src/storage.h" -#include "tensorflow/lite/experimental/shlo/test/util.h" +#include "tensorflow/lite/experimental/shlo/legacy/include/shlo.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/debug.h" // IWYU pragma: keep, b/321245930 +#include "tensorflow/lite/experimental/shlo/legacy/src/storage.h" +#include "tensorflow/lite/experimental/shlo/legacy/test/util.h" namespace stablehlo { namespace testing { @@ -38,19 +38,8 @@ void test(std::initializer_list&& shape, expected_values.size()); Tensor result(TensorType(Shape(shape), element_type), result_values.data()); - auto res = Iota(iota_dimension, result); - if (!res.ok()) { - ABSL_LOG(INFO) << "Failure: " << res; - } - ASSERT_EQ(res.ok(), true); - - if (result != expected) { - ABSL_LOG(INFO) << "iota_dimension=" << iota_dimension; - LOG(INFO) << "expected=" << expected; - LOG(INFO) << "result=" << result; - } - - ASSERT_EQ(result, expected); + ASSERT_OK(Iota(iota_dimension, result)); + EXPECT_EQ(result, expected) << "\niota_dimension: " << iota_dimension; } template @@ -74,19 +63,8 @@ void test( QuantizedTensorElementType(element_type)), result_quant_values.data()); - auto res = Iota(iota_dimension, result); - if (!res.ok()) { - LOG(INFO) << "Failure: " << res; - } - ASSERT_EQ(res.ok(), true); - - if (result != expected) { - LOG(INFO) << "iota_dimension=" << iota_dimension; - LOG(INFO) << "expected=" << expected; - LOG(INFO) << "result=" << result; - } - - ASSERT_EQ(result, expected); + ASSERT_OK(Iota(iota_dimension, result)); + EXPECT_EQ(result, expected) << "\niota_dimension: " << iota_dimension; } TEST(Iota, Unquantized) { diff --git a/tensorflow/lite/experimental/shlo/test/is_finite_test.cc b/tensorflow/lite/experimental/shlo/legacy/test/is_finite_test.cc similarity index 84% rename from tensorflow/lite/experimental/shlo/test/is_finite_test.cc rename to tensorflow/lite/experimental/shlo/legacy/test/is_finite_test.cc index d52838b081a64a..5c097287233822 100644 --- a/tensorflow/lite/experimental/shlo/test/is_finite_test.cc +++ b/tensorflow/lite/experimental/shlo/legacy/test/is_finite_test.cc @@ -17,11 +17,11 @@ limitations under the License. #include #include +#include #include -#include "absl/log/log.h" -#include "tensorflow/lite/experimental/shlo/include/shlo.h" -#include "tensorflow/lite/experimental/shlo/src/debug.h" // IWYU pragma: keep, b/321245930 -#include "tensorflow/lite/experimental/shlo/src/storage.h" +#include "tensorflow/lite/experimental/shlo/legacy/include/shlo.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/debug.h" // IWYU pragma: keep, b/321245930 +#include "tensorflow/lite/experimental/shlo/legacy/src/storage.h" namespace stablehlo { namespace testing { @@ -40,19 +40,8 @@ void test( Tensor result(TensorType(Shape(shape), ElementType::kI1), result_values.data()); - auto res = IsFinite(input, result); - if (!res.ok()) { - LOG(INFO) << "Failure: " << res; - } - ASSERT_EQ(res.ok(), true); - - if (result != expected) { - LOG(INFO) << "input=" << input; - LOG(INFO) << "expected=" << expected; - LOG(INFO) << "result=" << result; - } - - ASSERT_EQ(result, expected); + ASSERT_OK(IsFinite(input, result)); + EXPECT_EQ(result, expected) << "input: " << input; } template @@ -70,19 +59,8 @@ void test( Tensor result(TensorType(Shape(shape), ElementType::kI1), result_values.data()); - auto res = IsFinite(input, result); - if (!res.ok()) { - LOG(INFO) << "Failure: " << res; - } - ASSERT_EQ(res.ok(), true); - - if (result != expected) { - LOG(INFO) << "input=" << input; - LOG(INFO) << "expected=" << expected; - LOG(INFO) << "result=" << result; - } - - ASSERT_EQ(result, expected); + ASSERT_OK(IsFinite(input, result)); + EXPECT_EQ(result, expected) << "input: " << input; } TEST(IsFinite, Unquantized) { diff --git a/third_party/xla/xla/service/gpu/runtime/conv_reorder.h b/tensorflow/lite/experimental/shlo/legacy/test/matchers.h similarity index 55% rename from third_party/xla/xla/service/gpu/runtime/conv_reorder.h rename to tensorflow/lite/experimental/shlo/legacy/test/matchers.h index 1c43dad384e820..6f06428db3fd61 100644 --- a/third_party/xla/xla/service/gpu/runtime/conv_reorder.h +++ b/tensorflow/lite/experimental/shlo/legacy/test/matchers.h @@ -1,10 +1,10 @@ -/* Copyright 2022 The OpenXLA Authors. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 +http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, @@ -13,19 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME_CONV_REORDER_H_ -#define XLA_SERVICE_GPU_RUNTIME_CONV_REORDER_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_TEST_MATCHERS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_TEST_MATCHERS_H_ -#include "xla/runtime/custom_call_registry.h" +#include +#include "tensorflow/lite/experimental/shlo/legacy/src/debug.h" -namespace xla { -namespace gpu { +namespace stablehlo { +namespace testing { -// Registers XLA Gpu runtime convolution reorder custom calls. -void RegisterConvReorderCustomCalls( - runtime::DirectCustomCallRegistry& registry); +MATCHER_P(IsAlmostSame, expected, "") { return AlmostSame(arg, expected); } -} // namespace gpu -} // namespace xla +} // namespace testing +} // namespace stablehlo -#endif // XLA_SERVICE_GPU_RUNTIME_CONV_REORDER_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_TEST_MATCHERS_H_ diff --git a/tensorflow/lite/experimental/shlo/test/select_test.cc b/tensorflow/lite/experimental/shlo/legacy/test/select_test.cc similarity index 88% rename from tensorflow/lite/experimental/shlo/test/select_test.cc rename to tensorflow/lite/experimental/shlo/legacy/test/select_test.cc index 16bdcc6a7418f4..d82accf8abaeb2 100644 --- a/tensorflow/lite/experimental/shlo/test/select_test.cc +++ b/tensorflow/lite/experimental/shlo/legacy/test/select_test.cc @@ -17,12 +17,12 @@ limitations under the License. #include #include +#include #include -#include "absl/log/log.h" -#include "tensorflow/lite/experimental/shlo/include/shlo.h" -#include "tensorflow/lite/experimental/shlo/src/debug.h" // IWYU pragma: keep, b/321245930 -#include "tensorflow/lite/experimental/shlo/src/storage.h" -#include "tensorflow/lite/experimental/shlo/test/util.h" +#include "tensorflow/lite/experimental/shlo/legacy/include/shlo.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/debug.h" // IWYU pragma: keep, b/321245930 +#include "tensorflow/lite/experimental/shlo/legacy/src/storage.h" +#include "tensorflow/lite/experimental/shlo/legacy/test/util.h" namespace stablehlo { namespace testing { @@ -46,21 +46,9 @@ void test(std::initializer_list&& shape, expected_values.size()); Tensor result(TensorType(Shape(shape), element_type), result_values.data()); - auto res = Select(pred, on_true, on_false, result); - - if (!res.ok()) { - LOG(INFO) << "Failure: " << res; - } - ASSERT_EQ(res.ok(), true); - - if (result != expected) { - LOG(INFO) << "pred=" << pred; - LOG(INFO) << "on_true=" << on_true; - LOG(INFO) << "on_false=" << on_false; - LOG(INFO) << "expected=" << expected; - LOG(INFO) << "result=" << result; - } - ASSERT_EQ(result, expected); + ASSERT_OK(Select(pred, on_true, on_false, result)); + EXPECT_EQ(result, expected) << "pred: " << pred << "\non_true: " << on_true + << "\nnon_false: " << on_false; } template @@ -103,21 +91,9 @@ void test( QuantizedTensorElementType(element_type)), result_quant_values.data()); - auto res = Select(pred, on_true, on_false, result); - - if (!res.ok()) { - LOG(INFO) << "Failure: " << res; - } - ASSERT_EQ(res.ok(), true); - - if (result != expected) { - LOG(INFO) << "pred=" << pred; - LOG(INFO) << "on_true=" << on_true; - LOG(INFO) << "on_false=" << on_false; - LOG(INFO) << "expected=" << expected; - LOG(INFO) << "result=" << result; - } - ASSERT_EQ(result, expected); + ASSERT_OK(Select(pred, on_true, on_false, result)); + EXPECT_EQ(result, expected) << "pred: " << pred << "\non_true: " << on_true + << "\nnon_false: " << on_false; } TEST(Select, Unquantized) { diff --git a/tensorflow/lite/experimental/shlo/test/uniform_dequantize_quantize_test.cc b/tensorflow/lite/experimental/shlo/legacy/test/uniform_dequantize_quantize_test.cc similarity index 87% rename from tensorflow/lite/experimental/shlo/test/uniform_dequantize_quantize_test.cc rename to tensorflow/lite/experimental/shlo/legacy/test/uniform_dequantize_quantize_test.cc index 51b823a1b9f8b2..9cb5d288cedc87 100644 --- a/tensorflow/lite/experimental/shlo/test/uniform_dequantize_quantize_test.cc +++ b/tensorflow/lite/experimental/shlo/legacy/test/uniform_dequantize_quantize_test.cc @@ -17,12 +17,12 @@ limitations under the License. #include #include +#include #include -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "tensorflow/lite/experimental/shlo/include/shlo.h" -#include "tensorflow/lite/experimental/shlo/src/debug.h" -#include "tensorflow/lite/experimental/shlo/src/storage.h" +#include "tensorflow/lite/experimental/shlo/legacy/include/shlo.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/debug.h" // IWYU pragma: keep, b/321245930 +#include "tensorflow/lite/experimental/shlo/legacy/src/storage.h" +#include "tensorflow/lite/experimental/shlo/legacy/test/matchers.h" namespace stablehlo { namespace testing { @@ -45,27 +45,14 @@ void test(std::initializer_list&& shape, input_values.size()); Tensor result(TensorType(Shape(shape), expressed_type), result_values.data()); - auto res = UniformQuantize(input, quant); - if (!res.ok()) { - LOG(INFO) << "Failure: " << res; - } - ASSERT_EQ(res.ok(), true); - - res = UniformDequantize(quant, result); - if (!res.ok()) { - LOG(INFO) << "Failure: " << res; - } - ASSERT_EQ(res.ok(), true); - - if (result != input) { - LOG(INFO) << "input=" << input; - LOG(INFO) << "result=" << result; - } - - ASSERT_EQ(AlmostSame(result, input), true); + ASSERT_OK(UniformQuantize(input, quant)); + ASSERT_OK(UniformDequantize(quant, result)); + EXPECT_THAT(result, IsAlmostSame(input)); } TEST(QuantizeDequantize, All) { + test( + {4}, {.scale = 1, .zero_point = 0}, {-2, -1, 0, 1, 2}); test( {4}, {.scale = 1, .zero_point = 0}, {-2, -1, 0, 1, 2}); test( diff --git a/tensorflow/lite/experimental/shlo/test/util.h b/tensorflow/lite/experimental/shlo/legacy/test/util.h similarity index 80% rename from tensorflow/lite/experimental/shlo/test/util.h rename to tensorflow/lite/experimental/shlo/legacy/test/util.h index e1bb4fb87c1d1d..2fb80c47d814c3 100644 --- a/tensorflow/lite/experimental/shlo/test/util.h +++ b/tensorflow/lite/experimental/shlo/legacy/test/util.h @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_TEST_UTIL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_TEST_UTIL_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_TEST_UTIL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_TEST_UTIL_H_ #include #include #include "absl/log/check.h" -#include "tensorflow/lite/experimental/shlo/include/shlo.h" -#include "tensorflow/lite/experimental/shlo/src/storage.h" -#include "tensorflow/lite/experimental/shlo/src/util.h" +#include "tensorflow/lite/experimental/shlo/legacy/include/shlo.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/storage.h" +#include "tensorflow/lite/experimental/shlo/legacy/src/util.h" namespace stablehlo { @@ -51,4 +51,4 @@ std::vector::Type> QuantizeVector( } // namespace stablehlo -#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_TEST_UTIL_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_LEGACY_TEST_UTIL_H_ diff --git a/tensorflow/lite/experimental/shlo/ops/BUILD b/tensorflow/lite/experimental/shlo/ops/BUILD new file mode 100644 index 00000000000000..33ffee9a9c8f69 --- /dev/null +++ b/tensorflow/lite/experimental/shlo/ops/BUILD @@ -0,0 +1,6 @@ +# Implementation of StableHLO operations. + +package( + default_applicable_licenses = ["//tensorflow:LICENSE"], + default_visibility = ["//visibility:public"], +) diff --git a/tensorflow/lite/experimental/shlo/quantized_tensor_element_type.h b/tensorflow/lite/experimental/shlo/quantized_tensor_element_type.h new file mode 100644 index 00000000000000..7e94966e893057 --- /dev/null +++ b/tensorflow/lite/experimental/shlo/quantized_tensor_element_type.h @@ -0,0 +1,130 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_QUANTIZED_TENSOR_ELEMENT_TYPE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_QUANTIZED_TENSOR_ELEMENT_TYPE_H_ + +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/log/absl_check.h" +#include "absl/types/span.h" +#include "tensorflow/lite/experimental/shlo/data_type.h" +#include "tensorflow/lite/experimental/shlo/shape.h" + +namespace shlo_ref { + +class QuantizedTensorElementType { + public: + template + static QuantizedTensorElementType PerTensor( + StorageType scale, StorageType zero_point) { + static_assert(IsInteger(storage_type), + "Storage type must be an integer type"); + static_assert(IsFloat(expressed_type), + "Expressed type must be a floating point type"); + using StorageT = Storage::Type; + using ExpressedT = Storage::Type; + + return QuantizedTensorElementType( + storage_type, expressed_type, std::nullopt, + SmallInlinedVector({scale}), + SmallInlinedVector({zero_point})); + } + + template + static QuantizedTensorElementType PerAxis( + absl::Span> scales, + absl::Span> zero_points, + Axis quantized_dimension) { + static_assert(IsInteger(storage_type), + "Storage type must be an integer type"); + static_assert(IsFloat(expressed_type), + "Expressed type must be a floating point type"); + using StorageT = Storage::Type; + using ExpressedT = Storage::Type; + + ABSL_CHECK(scales.size() == zero_points.size()); + return QuantizedTensorElementType( + storage_type, expressed_type, quantized_dimension, + SmallInlinedVector(scales.begin(), scales.end()), + SmallInlinedVector(zero_points.begin(), zero_points.end())); + } + + DataType StorageType() const { return storage_type_; } + DataType ExpressedType() const { return expressed_type_; } + + bool IsPerTensorQuantized() const { return !quantized_dimension_; } + bool IsPerAxisQuantized() const { return !IsPerTensorQuantized(); } + + Axis QuantizedDimension() const { + ABSL_CHECK(IsPerAxisQuantized()); + return quantized_dimension_.value(); + } + + template ::Type> + absl::Span Scales() const { + ABSL_CHECK(expressed_type == expressed_type_); + ABSL_CHECK(std::holds_alternative>(scales_)); + return std::get>(scales_); + } + + template ::Type> + absl::Span ZeroPoints() const { + ABSL_CHECK(storage_type == storage_type_); + ABSL_CHECK(std::holds_alternative>(zero_points_)); + return std::get>(zero_points_); + } + + private: + // Most quantized tensors will likely be per tensor quantized, which will have + // a single element in the vector. Use an InlinedVector with a single element + // so we only allocate when using per axis quantization. + template + using SmallInlinedVector = absl::InlinedVector; + + template + QuantizedTensorElementType(DataType storage_type, DataType expressed_type, + std::optional quantized_dimension, + SmallInlinedVector scales, + SmallInlinedVector zero_points) + : storage_type_(storage_type), + expressed_type_(expressed_type), + quantized_dimension_(quantized_dimension), + scales_(std::move(scales)), + zero_points_(std::move(zero_points)) {} + + DataType storage_type_; + DataType expressed_type_; + std::optional quantized_dimension_; + + std::variant::Type>, + SmallInlinedVector::Type>, + SmallInlinedVector::Type>> + scales_; + + // There is no need for kSI4 because it currently uses the same underlying + // storage type as kSI8, which complicates accessing the variant. If they ever + // use different underlying types, please add an alternative for kSI4. + std::variant::Type>, + SmallInlinedVector::Type>, + SmallInlinedVector::Type>> + zero_points_; +}; + +} // namespace shlo_ref +#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_QUANTIZED_TENSOR_ELEMENT_TYPE_H_ diff --git a/tensorflow/lite/experimental/shlo/quantized_tensor_element_type_test.cc b/tensorflow/lite/experimental/shlo/quantized_tensor_element_type_test.cc new file mode 100644 index 00000000000000..f8e00423ea14f7 --- /dev/null +++ b/tensorflow/lite/experimental/shlo/quantized_tensor_element_type_test.cc @@ -0,0 +1,93 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/experimental/shlo/quantized_tensor_element_type.h" + +#include +#include +#include "tensorflow/lite/experimental/shlo/data_type.h" + +namespace shlo_ref { +namespace { + +using ::testing::ElementsAre; +using ::testing::Eq; + +template +struct TestPair { + using StorageT = StorageType; + using ExpressedT = StorageType; + + static constexpr DataType kStorageType = storage_type; + static constexpr DataType kExpressedType = expressed_type; +}; + +template +class QuantizedTensorElementTypeTest : public ::testing::Test {}; + +using TestTypes = ::testing::Types, + TestPair, + TestPair, + TestPair, + TestPair, + TestPair, + TestPair, + TestPair, + TestPair, + TestPair, + TestPair, + TestPair>; + +TYPED_TEST_SUITE(QuantizedTensorElementTypeTest, TestTypes); + +TYPED_TEST(QuantizedTensorElementTypeTest, PerTensor) { + typename TypeParam::ExpressedT scale = .5; + typename TypeParam::StorageT zero_point = 3; + + auto element = + QuantizedTensorElementType::PerTensor( + scale, zero_point); + + EXPECT_THAT(element.StorageType(), Eq(TypeParam::kStorageType)); + EXPECT_THAT(element.ExpressedType(), Eq(TypeParam::kExpressedType)); + EXPECT_THAT(element.IsPerTensorQuantized(), Eq(true)); + EXPECT_THAT(element.IsPerAxisQuantized(), Eq(false)); + EXPECT_THAT(element.template Scales(), + ElementsAre(.5f)); + EXPECT_THAT(element.template ZeroPoints(), + ElementsAre(3)); +} + +TYPED_TEST(QuantizedTensorElementTypeTest, PerAxis) { + typename TypeParam::ExpressedT scales[] = {.5, .6, .2}; + typename TypeParam::StorageT zero_points[] = {3, 1, 2}; + auto element = QuantizedTensorElementType::PerAxis( + absl::MakeConstSpan(scales), absl::MakeConstSpan(zero_points), 3u); + + EXPECT_THAT(element.StorageType(), Eq(TypeParam::kStorageType)); + EXPECT_THAT(element.ExpressedType(), Eq(TypeParam::kExpressedType)); + EXPECT_THAT(element.IsPerTensorQuantized(), Eq(false)); + EXPECT_THAT(element.IsPerAxisQuantized(), Eq(true)); + EXPECT_THAT(element.QuantizedDimension(), Eq(3)); + EXPECT_THAT(element.template Scales(), + ElementsAre(.5f, .6f, .2f)); + EXPECT_THAT(element.template ZeroPoints(), + ElementsAre(3, 1, 2)); +} + +} // namespace +} // namespace shlo_ref diff --git a/tensorflow/lite/experimental/shlo/shape.cc b/tensorflow/lite/experimental/shlo/shape.cc new file mode 100644 index 00000000000000..a7cc4ec9800d57 --- /dev/null +++ b/tensorflow/lite/experimental/shlo/shape.cc @@ -0,0 +1,71 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/experimental/shlo/shape.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/types/span.h" + +namespace shlo_ref { + +Shape::Shape(absl::Span dims) + : dims_(dims.begin(), dims.end()) {} + +absl::Span Shape::Dimensions() const { return dims_; } + +absl::Span Shape::MutableDimensions() { + return absl::MakeSpan(dims_); +} + +absl::InlinedVector Shape::Axes() const { + absl::InlinedVector axes(dims_.size()); + absl::c_iota(axes, 0); + return axes; +} + +DimensionSize Shape::Dim(Axis axis) const { return dims_[axis]; } + +absl::InlinedVector Shape::Dims( + absl::Span axes) const { + absl::InlinedVector dims; + for (const auto axis : axes) { + // Ignore invalid axis + if (axis < dims_.size()) { + dims.push_back(Dim(axis)); + } + } + return dims; +} + +size_t Shape::Rank() const { return dims_.size(); } + +DimensionSize Shape::NumElements() const { + if (dims_.empty()) { + return 0; + } + return absl::c_accumulate(dims_, 1, std::multiplies<>()); +} + +bool operator==(const Shape& lhs, const Shape& rhs) { + return lhs.Dimensions() == rhs.Dimensions(); +} + +bool operator!=(const Shape& lhs, const Shape& rhs) { return !(lhs == rhs); } + +} // namespace shlo_ref diff --git a/tensorflow/lite/experimental/shlo/shape.h b/tensorflow/lite/experimental/shlo/shape.h new file mode 100644 index 00000000000000..b3c04ef0322db2 --- /dev/null +++ b/tensorflow/lite/experimental/shlo/shape.h @@ -0,0 +1,76 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_SHAPE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_SHAPE_H_ + +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/types/span.h" + +namespace shlo_ref { + +// The SHLO Spec states that dimensions are non-negative. We diverge from the +// spec here to use negative values to represent dynamic dimensions. +using DimensionSize = int64_t; +using Axis = size_t; + +inline constexpr DimensionSize kDynamicDimension = -1; +inline constexpr Axis kMaxNumDimensions = 6; + +class Shape { + public: + Shape() = default; + ~Shape() = default; + Shape(const Shape&) = default; + Shape& operator=(const Shape&) = default; + Shape(Shape&&) = default; + Shape& operator=(Shape&&) = default; + + explicit Shape(absl::Span dims); + + absl::Span Dimensions() const; + absl::Span MutableDimensions(); + + // range(rank(x)) + absl::InlinedVector Axes() const; + + // shape(x)[axis] + DimensionSize Dim(Axis axis) const; + + // list(map(lambda axis: dim(x, axis), axes)) + absl::InlinedVector Dims( + absl::Span axes) const; + + // size(shape(x)) + size_t Rank() const; + + // reduce(lambda x, y: x * y, shape(x)) + // Note: in the SHLO spec, this is called size. We've diverged for readability + // and possible confusion with C++ container's usage of size(). + DimensionSize NumElements() const; + + private: + absl::InlinedVector dims_; +}; + +bool operator==(const Shape& lhs, const Shape& rhs); +bool operator!=(const Shape& lhs, const Shape& rhs); + +} // namespace shlo_ref + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_SHAPE_H_ diff --git a/tensorflow/lite/experimental/shlo/shape_test.cc b/tensorflow/lite/experimental/shlo/shape_test.cc new file mode 100644 index 00000000000000..41495efb9cbd15 --- /dev/null +++ b/tensorflow/lite/experimental/shlo/shape_test.cc @@ -0,0 +1,92 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/experimental/shlo/shape.h" + +#include +#include + +namespace shlo_ref { +namespace { + +using ::testing::ElementsAre; +using ::testing::Eq; + +TEST(ShapeTest, DimensionsAccess) { + const Shape shape({1, 2, 4, 8}); + EXPECT_THAT(shape.Dimensions(), ElementsAre(1, 2, 4, 8)); +} + +TEST(ShapeTest, DimensionsMutableAccess) { + Shape shape({1, 2, 4, 8}); + + shape.MutableDimensions()[2] = 42; + EXPECT_THAT(shape.Dimensions(), ElementsAre(1, 2, 42, 8)); +} + +TEST(ShapeTest, Axes) { + Shape shape({1, 2, 4, 8}); + EXPECT_THAT(shape.Axes(), ElementsAre(0, 1, 2, 3)); +} + +TEST(ShapeTest, Dim) { + Shape shape({1, 2, 4, 8}); + EXPECT_THAT(shape.Dim(1), Eq(2)); + EXPECT_THAT(shape.Dim(3), Eq(8)); +} + +TEST(ShapeTest, Dims) { + Shape shape({1, 2, 4, 8}); + EXPECT_THAT(shape.Dims({1, 3}), ElementsAre(2, 8)); +} + +TEST(ShapeTest, DimsInvalidAxisIgnored) { + Shape shape({1, 2, 4, 8}); + EXPECT_THAT(shape.Dims({1, 8}), ElementsAre(2)); +} + +TEST(ShapeTest, Rank) { + Shape shape({1, 2, 4, 8}); + EXPECT_THAT(shape.Rank(), Eq(4)); +} + +TEST(ShapeTest, RankEmpty) { + Shape shape{}; + EXPECT_THAT(shape.Rank(), Eq(0)); +} + +TEST(ShapeTest, NumElementsEmpty) { + Shape shape{}; + EXPECT_THAT(shape.NumElements(), Eq(0)); +} +TEST(ShapeTest, NumElements) { + Shape shape({1, 2, 4, 8}); + EXPECT_THAT(shape.NumElements(), Eq(64)); +} + +TEST(ShapeTest, Equals) { + Shape s1({1, 2, 4, 8}); + Shape s2({1, 2, 4, 8}); + EXPECT_TRUE(s1 == s2); +} + +TEST(ShapeTest, NotEquals) { + Shape s1({1, 2, 4, 8}); + Shape s2({1, 4, 2, 8}); + EXPECT_TRUE(s1 != s2); +} + +} // namespace +} // namespace shlo_ref diff --git a/tensorflow/lite/experimental/shlo/tensor.cc b/tensorflow/lite/experimental/shlo/tensor.cc new file mode 100644 index 00000000000000..b65506a7a5bab2 --- /dev/null +++ b/tensorflow/lite/experimental/shlo/tensor.cc @@ -0,0 +1,100 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/experimental/shlo/tensor.h" + +#include +#include +#include + +#include "tensorflow/lite/experimental/shlo/data_type.h" +#include "tensorflow/lite/experimental/shlo/quantized_tensor_element_type.h" +#include "tensorflow/lite/experimental/shlo/shape.h" + +namespace shlo_ref { + +const Shape& Tensor::shape() const { + if (IsQuantized()) { + return quantized_tensor_type().shape; + } else { + return tensor_type().shape; + } +} + +Shape& Tensor::shape() { + if (IsQuantized()) { + return quantized_tensor_type().shape; + } else { + return tensor_type().shape; + } +} + +bool Tensor::IsQuantized() const { + return std::holds_alternative(type); +} + +bool Tensor::IsPerAxisQuantized() const { + return IsQuantized() && + std::get(type).element_type.IsPerAxisQuantized(); +} +bool Tensor::IsPerTensorQuantized() const { + return IsQuantized() && std::get(type) + .element_type.IsPerTensorQuantized(); +} + +size_t Tensor::Rank() const { + return IsQuantized() ? quantized_tensor_type().shape.Rank() + : tensor_type().shape.Rank(); +} + +DataType Tensor::StorageType() const { + return IsQuantized() ? quantized_tensor_type().element_type.StorageType() + : tensor_type().element_type; +} + +DimensionSize Tensor::NumElements() const { + return IsQuantized() ? quantized_tensor_type().shape.NumElements() + : tensor_type().shape.NumElements(); +} + +TensorType& Tensor::tensor_type() { + assert(std::holds_alternative(type)); + return std::get(type); +} + +const TensorType& Tensor::tensor_type() const { + assert(std::holds_alternative(type)); + return std::get(type); +} + +QuantizedTensorType& Tensor::quantized_tensor_type() { + assert(std::holds_alternative(type)); + return std::get(type); +} + +const QuantizedTensorType& Tensor::quantized_tensor_type() const { + assert(std::holds_alternative(type)); + return std::get(type); +} + +const TensorElementType& Tensor::tensor_element_type() const { + return tensor_type().element_type; +} +const QuantizedTensorElementType& Tensor::quantized_tensor_element_type() + const { + return quantized_tensor_type().element_type; +} + +} // namespace shlo_ref diff --git a/tensorflow/lite/experimental/shlo/tensor.h b/tensorflow/lite/experimental/shlo/tensor.h new file mode 100644 index 00000000000000..759a401c683a8a --- /dev/null +++ b/tensorflow/lite/experimental/shlo/tensor.h @@ -0,0 +1,85 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_TENSOR_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_TENSOR_H_ + +#include +#include +#include + +#include "tensorflow/lite/experimental/shlo/data_type.h" +#include "tensorflow/lite/experimental/shlo/quantized_tensor_element_type.h" +#include "tensorflow/lite/experimental/shlo/shape.h" + +namespace shlo_ref { + +using TensorElementType = DataType; + +struct TensorType { + Shape shape; + TensorElementType element_type; +}; + +struct QuantizedTensorType { + Shape shape; + QuantizedTensorElementType element_type; +}; + +struct Tensor { + const Shape& shape() const; + Shape& shape(); + + bool IsQuantized() const; + bool IsPerAxisQuantized() const; + bool IsPerTensorQuantized() const; + + size_t Rank() const; + DataType StorageType() const; + + DimensionSize NumElements() const; + + TensorType& tensor_type(); + const TensorType& tensor_type() const; + + QuantizedTensorType& quantized_tensor_type(); + const QuantizedTensorType& quantized_tensor_type() const; + + const TensorElementType& tensor_element_type() const; + const QuantizedTensorElementType& quantized_tensor_element_type() const; + + template ::Type> + T* GetDataAs() { + return reinterpret_cast(data); + } + + template ::Type> + const T* GetDataAs() const { + return reinterpret_cast(data); + } + + std::variant type; + + // If type is TensorType, the type should be Storage::Type. + // If type is QuantizedTensorType, the type should be + // Storage::Type. + // May be nullptr if buffers are not yet available. + // The size of the array must be equal to Size(shape). + void* data = nullptr; +}; + +} // namespace shlo_ref + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_TENSOR_H_ diff --git a/tensorflow/lite/g3doc/_book.yaml b/tensorflow/lite/g3doc/_book.yaml index a3c832afadcbd8..a229d8baf5e0c3 100644 --- a/tensorflow/lite/g3doc/_book.yaml +++ b/tensorflow/lite/g3doc/_book.yaml @@ -91,6 +91,12 @@ upper_tabs: path: /lite/android/quickstart - title: "Google Play services runtime" path: /lite/android/play_services + section: + - title: "Java API" + path: /lite/android/java.md + - title: "C API" + path: /lite/android/native.md + status: experimental - title: "Development tools" path: /lite/android/development @@ -130,7 +136,7 @@ upper_tabs: path: /lite/android/delegates/gpu.md - title: "Task library API" path: /lite/android/delegates/gpu_task.md - - title: "Native API" + - title: "C/C++ API" path: /lite/android/delegates/gpu_native.md - title: "NNAPI delegate" path: /lite/android/delegates/nnapi @@ -276,6 +282,9 @@ upper_tabs: - title: "Convert JAX models" path: /lite/examples/jax_conversion/overview status: nightly + - title: "Colab for JAX to TFLite" + path: /lite/examples/jax_conversion/jax_to_tflite + status: nightly - title: "Model compatibility" section: - title: "Overview" diff --git a/tensorflow/lite/g3doc/android/delegates/gpu_native.md b/tensorflow/lite/g3doc/android/delegates/gpu_native.md index 996056e1789d61..2221c2066f9cb6 100644 --- a/tensorflow/lite/g3doc/android/delegates/gpu_native.md +++ b/tensorflow/lite/g3doc/android/delegates/gpu_native.md @@ -63,6 +63,87 @@ an `EGLContext` does not exist, the delegate creates one internally, but then you must ensure that `Interpreter::Invoke()` is always called from the same thread in which `Interpreter::ModifyGraphWithDelegate()` was called. +#### With TensorFlow Lite in Google Play Services: + +If you are using TensorFlow Lite in Google Play Services [C API](../native), +you’ll need to use the Java/Kotlin API to check if a GPU delegate is available +for your device before initializing the TensorFlow Lite runtime. + +Add the GPU delegate gradle dependencies to your application: + +``` +implementation 'com.google.android.gms:play-services-tflite-gpu:16.2.0' +``` + +Then, check the GPU availability and initialize TfLiteNative if the check is +successful: + +
+ +
+

Java

+
+Task tfLiteHandleTask =
+TfLiteGpu.isGpuDelegateAvailable(this)
+   .onSuccessTask(gpuAvailable -> {
+      TfLiteInitializationOptions options =
+        TfLiteInitializationOptions.builder()
+          .setEnableGpuDelegateSupport(gpuAvailable).build();
+        return TfLiteNative.initialize(this, options);
+      }
+    );
+      
+
+
+

Kotlin

+
+val tfLiteHandleTask = TfLiteGpu.isGpuDelegateAvailable(this)
+    .onSuccessTask { gpuAvailable ->
+        val options = TfLiteInitializationOptions.Builder()
+            .setEnableGpuDelegateSupport(gpuAvailable)
+            .build()
+        TfLiteNative.initialize(this, options)
+    }
+        
+
+
+
+ +You also need to update your CMake configuration to include the +`TFLITE_USE_OPAQUE_DELEGATE` compiler flag: + +``` +add_compile_definitions(TFLITE_USE_OPAQUE_DELEGATE) +``` + +The [FlatBuffers](https://flatbuffers.dev/) library is used to configure +delegate plugins, so you need to add it to the dependencies of your native code. +You can use the official `CMake` project configuration as follow: + +``` +target_include_directories(tflite-jni PUBLIC + third_party/headers # flatbuffers + ...) +``` + +You can also just bundle the headers to your app. + +Finally to use GPU inference in your C code, create the GPU delegate using +`TFLiteSettings`: + +``` +#include "flatbuffers/flatbuffers.h" +#include "tensorflow/lite/acceleration/configuration/configuration_generated.h" + +flatbuffers::FlatBufferBuilder fbb; +tflite::TFLiteSettingsBuilder builder(fbb); +const tflite::TFLiteSettings* tflite_settings = + flatbuffers::GetTemporaryPointer(fbb, builder.Finish()); + +const TfLiteOpaqueDelegatePlugin* pluginCApi = TfLiteGpuDelegatePluginCApi(); +TfLiteOpaqueDelegate* gpu_delegate = pluginCApi->create(tflite_settings); +``` + ## Quantized models {:#quantized-models} Android GPU delegate libraries support quantized models by default. You do not diff --git a/tensorflow/lite/g3doc/android/java.md b/tensorflow/lite/g3doc/android/java.md new file mode 100644 index 00000000000000..9278df326476e9 --- /dev/null +++ b/tensorflow/lite/g3doc/android/java.md @@ -0,0 +1,556 @@ +# TensorFlow Lite in Google Play services Java API + +TensorFlow Lite in Google Play services can also be accessed using Java APIs, in +addition to the Native API. In particular, TensorFlow Lite in Google Play +services is available through the +[TensorFlow Lite Task API](https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/task/core/package-summary) +and the +[TensorFlow Lite Interpreter API](https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/InterpreterApi). +The Task Library provides optimized out-of-the-box model interfaces for common +machine learning tasks using visual, audio, and text data. The TensorFlow Lite +Interpreter API, provided by the TensorFlow runtime, provides a more +general-purpose interface for building and running ML models. + +The following sections provide instructions on how to use the Interpreter and +Task Library APIs with TensorFlow Lite in Google Play services. While it is +possible for an app to use both the Interpreter APIs and Task Library APIs, most +apps should only use one set of APIs. + +### Using the Task Library APIs + +The TensorFlow Lite Task API wraps the Interpreter API and provides a high-level +programming interface for common machine learning tasks that use visual, audio, +and text data. You should use the Task API if your application requires one of +the +[supported tasks](../inference_with_metadata/task_library/overview#supported_tasks). + +#### 1. Add project dependencies + +Your project dependency depends on your machine learning use case. The Task APIs +contain the following libraries: + +* Vision library: `org.tensorflow:tensorflow-lite-task-vision-play-services` +* Audio library: `org.tensorflow:tensorflow-lite-task-audio-play-services` +* Text library: `org.tensorflow:tensorflow-lite-task-text-play-services` + +Add one of the dependencies to your app project code to access the Play services +API for TensorFlow Lite. For example, use the following to implement a vision +task: + +``` +dependencies { +... + implementation 'org.tensorflow:tensorflow-lite-task-vision-play-services:0.4.2' +... +} +``` + +Caution: The TensorFlow Lite Tasks Audio library version 0.4.2 maven repository +is incomplete. Use version 0.4.2.1 for this library instead: +`org.tensorflow:tensorflow-lite-task-audio-play-services:0.4.2.1`. + +#### 2. Add initialization of TensorFlow Lite + +Initialize the TensorFlow Lite component of the Google Play services API +*before* using the TensorFlow Lite APIs. The following example initializes the +vision library: + +
+ +
+

Kotlin

+
+init {
+  TfLiteVision.initialize(context)
+}
+
+
+
+
+ +Important: Make sure the `TfLite.initialize` task completes before executing +code that accesses TensorFlow Lite APIs. + +Tip: The TensorFlow Lite modules are installed at the same time your application +is installed or updated from the Play Store. You can check the availability of +the modules by using `ModuleInstallClient` from the Google Play services APIs. +For more information on checking module availability, see +[Ensuring API availability with ModuleInstallClient](https://developers.google.com/android/guides/module-install-apis). + +#### 3. Run inferences + +After initializing the TensorFlow Lite component, call the `detect()` method to +generate inferences. The exact code within the `detect()` method varies +depending on the library and use case. The following is for a simple object +detection use case with the `TfLiteVision` library: + +
+ +
+

Kotlin

+
+fun detect(...) {
+  if (!TfLiteVision.isInitialized()) {
+    Log.e(TAG, "detect: TfLiteVision is not initialized yet")
+    return
+  }
+
+  if (objectDetector == null) {
+    setupObjectDetector()
+  }
+
+  ...
+
+}
+
+
+
+
+ +Depending on the data format, you may also need to preprocess and convert your +data within the `detect()` method before generating inferences. For example, +image data for an object detector requires the following: + +```kotlin +val imageProcessor = ImageProcessor.Builder().add(Rot90Op(-imageRotation / 90)).build() +val tensorImage = imageProcessor.process(TensorImage.fromBitmap(image)) +val results = objectDetector?.detect(tensorImage) +``` + +### Using the Interpreter APIs + +The Interpreter APIs offer more control and flexibility than the Task Library +APIs. You should use the Interpreter APIs if your machine learning task is not +supported by the Task library, or if you require a more general-purpose +interface for building and running ML models. + +#### 1. Add project dependencies + +Add the following dependencies to your app project code to access the Play +services API for TensorFlow Lite: + +``` +dependencies { +... + // Tensorflow Lite dependencies for Google Play services + implementation 'com.google.android.gms:play-services-tflite-java:16.0.1' + // Optional: include Tensorflow Lite Support Library + implementation 'com.google.android.gms:play-services-tflite-support:16.0.1' +... +} +``` + +#### 2. Add initialization of TensorFlow Lite + +Initialize the TensorFlow Lite component of the Google Play services API +*before* using the TensorFlow Lite APIs: + +
+ +
+

Kotlin

+
+val initializeTask: Task<Void> by lazy { TfLite.initialize(this) }
+
+
+
+

Java

+
+Task<Void> initializeTask = TfLite.initialize(context);
+
+
+
+
+ +Note: Make sure the `TfLite.initialize` task completes before executing code +that accesses TensorFlow Lite APIs. Use the `addOnSuccessListener()` method, as +shown in the next section. + +#### 3. Create an Interpreter and set runtime option {:#step_3_interpreter} + +Create an interpreter using `InterpreterApi.create()` and configure it to use +Google Play services runtime, by calling `InterpreterApi.Options.setRuntime()`, +as shown in the following example code: + +
+ +
+

Kotlin

+
+import org.tensorflow.lite.InterpreterApi
+import org.tensorflow.lite.InterpreterApi.Options.TfLiteRuntime
+...
+private lateinit var interpreter: InterpreterApi
+...
+initializeTask.addOnSuccessListener {
+  val interpreterOption =
+    InterpreterApi.Options().setRuntime(TfLiteRuntime.FROM_SYSTEM_ONLY)
+  interpreter = InterpreterApi.create(
+    modelBuffer,
+    interpreterOption
+  )}
+  .addOnFailureListener { e ->
+    Log.e("Interpreter", "Cannot initialize interpreter", e)
+  }
+
+
+
+

Java

+
+import org.tensorflow.lite.InterpreterApi
+import org.tensorflow.lite.InterpreterApi.Options.TfLiteRuntime
+...
+private InterpreterApi interpreter;
+...
+initializeTask.addOnSuccessListener(a -> {
+    interpreter = InterpreterApi.create(modelBuffer,
+      new InterpreterApi.Options().setRuntime(TfLiteRuntime.FROM_SYSTEM_ONLY));
+  })
+  .addOnFailureListener(e -> {
+    Log.e("Interpreter", String.format("Cannot initialize interpreter: %s",
+          e.getMessage()));
+  });
+
+
+
+
+ +You should use the implementation above because it avoids blocking the Android +user interface thread. If you need to manage thread execution more closely, you +can add a `Tasks.await()` call to interpreter creation: + +
+ +
+

Kotlin

+
+import androidx.lifecycle.lifecycleScope
+...
+lifecycleScope.launchWhenStarted { // uses coroutine
+  initializeTask.await()
+}
+
+
+
+

Java

+
+@BackgroundThread
+InterpreterApi initializeInterpreter() {
+    Tasks.await(initializeTask);
+    return InterpreterApi.create(...);
+}
+
+
+
+
+ +Warning: Do not call `.await()` on the foreground user interface thread because +it interrupts display of user interface elements and creates a poor user +experience. + +#### 4. Run inferences + +Using the `interpreter` object you created, call the `run()` method to generate +an inference. + +
+ +
+

Kotlin

+
+interpreter.run(inputBuffer, outputBuffer)
+
+
+
+

Java

+
+interpreter.run(inputBuffer, outputBuffer);
+
+
+
+
+ +## Hardware acceleration {:#hardware-acceleration} + +TensorFlow Lite allows you to accelerate the performance of your model using +specialized hardware processors, such as graphics processing units (GPUs). You +can take advantage of these specialized processors using hardware drivers called +[*delegates*](https://www.tensorflow.org/lite/performance/delegates). You can +use the following hardware acceleration delegates with TensorFlow Lite in Google +Play services: + +- *[GPU delegate](https://www.tensorflow.org/lite/performance/gpu) + (recommended)* - This delegate is provided through Google Play services and + is dynamically loaded, just like the Play services versions of the Task API + and Interpreter API. + +- [*NNAPI delegate*](https://www.tensorflow.org/lite/android/delegates/nnapi) - + This delegate is available as an included library dependency in your Android + development project, and is bundled into your app. + +For more information about hardware acceleration with TensorFlow Lite, see the +[TensorFlow Lite Delegates](https://www.tensorflow.org/lite/performance/delegates) +page. + +### Checking device compatibility + +Not all devices support GPU hardware acceleration with TFLite. In order to +mitigate errors and potential crashes, use the +`TfLiteGpu.isGpuDelegateAvailable` method to check whether a device is +compatible with the GPU delegate. + +Use this method to confirm whether a device is compatible with GPU, and use CPU +or the NNAPI delegate as a fallback for when GPU is not supported. + +``` +useGpuTask = TfLiteGpu.isGpuDelegateAvailable(context) +``` + +Once you have a variable like `useGpuTask`, you can use it to determine whether +devices use the GPU delegate. The following examples show how this can be done +with both the Task Library and Interpreter APIs. + +**With the Task Api** + +
+ +
+

Kotlin

+
+lateinit val optionsTask = useGpuTask.continueWith { task ->
+  val baseOptionsBuilder = BaseOptions.builder()
+  if (task.result) {
+    baseOptionsBuilder.useGpu()
+  }
+ ObjectDetectorOptions.builder()
+          .setBaseOptions(baseOptionsBuilder.build())
+          .setMaxResults(1)
+          .build()
+}
+    
+
+
+

Java

+
+Task<ObjectDetectorOptions> optionsTask = useGpuTask.continueWith({ task ->
+  BaseOptions baseOptionsBuilder = BaseOptions.builder();
+  if (task.getResult()) {
+    baseOptionsBuilder.useGpu();
+  }
+  return ObjectDetectorOptions.builder()
+          .setBaseOptions(baseOptionsBuilder.build())
+          .setMaxResults(1)
+          .build()
+});
+    
+
+
+
+ +**With the Interpreter Api** + +
+ +
+

Kotlin

+
+val interpreterTask = useGpuTask.continueWith { task ->
+  val interpreterOptions = InterpreterApi.Options()
+      .setRuntime(TfLiteRuntime.FROM_SYSTEM_ONLY)
+  if (task.result) {
+      interpreterOptions.addDelegateFactory(GpuDelegateFactory())
+  }
+  InterpreterApi.create(FileUtil.loadMappedFile(context, MODEL_PATH), interpreterOptions)
+}
+    
+
+
+

Java

+
+Task<InterpreterApi.Options> interpreterOptionsTask = useGpuTask.continueWith({ task ->
+  InterpreterApi.Options options =
+      new InterpreterApi.Options().setRuntime(TfLiteRuntime.FROM_SYSTEM_ONLY);
+  if (task.getResult()) {
+     options.addDelegateFactory(new GpuDelegateFactory());
+  }
+  return options;
+});
+    
+
+
+
+ +### GPU with Task Library APIs + +To use the GPU delegate with the Task APIs: + +1. Update the project dependencies to use the GPU delegate from Play services: + + ``` + implementation 'com.google.android.gms:play-services-tflite-gpu:16.1.0' + ``` + +1. Initialize the GPU delegate with `setEnableGpuDelegateSupport`. For example, + you can initialize the GPU delegate for `TfLiteVision` with the following: + +
+ +
+

Kotlin

+
+        TfLiteVision.initialize(context, TfLiteInitializationOptions.builder().setEnableGpuDelegateSupport(true).build())
+        
+
+
+

Java

+
+        TfLiteVision.initialize(context, TfLiteInitializationOptions.builder().setEnableGpuDelegateSupport(true).build());
+        
+
+
+
+ +1. Enable the GPU delegate option with + [`BaseOptions`](https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/task/core/BaseOptions.Builder): + +
+ +
+

Kotlin

+
+        val baseOptions = BaseOptions.builder().useGpu().build()
+        
+
+
+

Java

+
+        BaseOptions baseOptions = BaseOptions.builder().useGpu().build();
+        
+
+
+
+ +1. Configure the options using `.setBaseOptions`. For example, you can set up + GPU in `ObjectDetector` with the following: + +
+ +
+

Kotlin

+
+        val options =
+            ObjectDetectorOptions.builder()
+                .setBaseOptions(baseOptions)
+                .setMaxResults(1)
+                .build()
+        
+
+
+

Java

+
+        ObjectDetectorOptions options =
+            ObjectDetectorOptions.builder()
+                .setBaseOptions(baseOptions)
+                .setMaxResults(1)
+                .build();
+        
+
+
+
+ +### GPU with Interpreter APIs + +To use the GPU delegate with the Interpreter APIs: + +1. Update the project dependencies to use the GPU delegate from Play services: + + ``` + implementation 'com.google.android.gms:play-services-tflite-gpu:16.1.0' + ``` + +1. Enable the GPU delegate option in the TFlite initialization: + +
+ +
+

Kotlin

+
+        TfLite.initialize(context,
+          TfLiteInitializationOptions.builder()
+           .setEnableGpuDelegateSupport(true)
+           .build())
+        
+
+
+

Java

+
+        TfLite.initialize(context,
+          TfLiteInitializationOptions.builder()
+           .setEnableGpuDelegateSupport(true)
+           .build());
+        
+
+
+
+ +1. Enable GPU delegate in the interpreter options: set the delegate factory to + GpuDelegateFactory by calling `addDelegateFactory() + within`InterpreterApi.Options()`: + +
+ +
+

Kotlin

+
+        val interpreterOption = InterpreterApi.Options()
+         .setRuntime(TfLiteRuntime.FROM_SYSTEM_ONLY)
+         .addDelegateFactory(GpuDelegateFactory())
+        
+
+
+

Java

+
+        Options interpreterOption = InterpreterApi.Options()
+          .setRuntime(TfLiteRuntime.FROM_SYSTEM_ONLY)
+          .addDelegateFactory(new GpuDelegateFactory());
+        
+
+
+
+ +## Migrating from stand-alone TensorFlow Lite {:#migrating} + +If you are planning to migrate your app from stand-alone TensorFlow Lite to the +Play services API, review the following additional guidance for updating your +app project code: + +1. Review the [Limitations](#limitations) section of this page to ensure your + use case is supported. +2. Prior to updating your code, do performance and accuracy checks for your + models, particularly if you are using versions of TensorFlow Lite earlier + than version 2.1, so you have a baseline to compare against the new + implementation. +3. If you have migrated all of your code to use the Play services API for + TensorFlow Lite, you should remove the existing TensorFlow Lite *runtime + library* dependencies (entries with + org.tensorflow:**tensorflow-lite**:*) from your build.gradle + file so that you can reduce your app size. +4. Identify all occurrences of `new Interpreter` object creation in your code, + and modify each one so that it uses the InterpreterApi.create() call. The + new TfLite.initialize is asynchronous, which means in most cases it's not a + drop-in replacement: you must register a listener for when the call + completes. Refer to the code snippet in [Step 3](#step_3_interpreter) code. +5. Add `import org.tensorflow.lite.InterpreterApi;` and `import + org.tensorflow.lite.InterpreterApi.Options.TfLiteRuntime;` to any source + files using the `org.tensorflow.lite.Interpreter` or + `org.tensorflow.lite.InterpreterApi` classes. +6. If any of the resulting calls to `InterpreterApi.create()` have only a + single argument, append `new InterpreterApi.Options()` to the argument list. +7. Append `.setRuntime(TfLiteRuntime.FROM_SYSTEM_ONLY)` to the last argument of + any calls to `InterpreterApi.create()`. +8. Replace all other occurrences of the `org.tensorflow.lite.Interpreter` class + with `org.tensorflow.lite.InterpreterApi`. + +If you want to use stand-alone TensorFlow Lite and the Play services API +side-by-side, you must use TensorFlow Lite 2.9 (or later). TensorFlow Lite 2.8 +and earlier versions are not compatible with the Play services API version. diff --git a/tensorflow/lite/g3doc/android/lite_build.md b/tensorflow/lite/g3doc/android/lite_build.md index 4324c0b9d5d4bf..b9e0ab86649537 100644 --- a/tensorflow/lite/g3doc/android/lite_build.md +++ b/tensorflow/lite/g3doc/android/lite_build.md @@ -23,6 +23,20 @@ allprojects { } ``` +add nightly snapshots to dependencies (or edit as needed) to your build.gradle + +```groovy +... +dependencies { + ... + implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly-SNAPSHOT' + implementation 'org.tensorflow:tensorflow-lite-gpu:0.0.0-nightly-SNAPSHOT' + implementation 'org.tensorflow:tensorflow-lite-support:0.0.0-nightly-SNAPSHOT' + ... +} +... +``` + ## Build TensorFlow Lite locally In some cases, you might wish to use a local build of TensorFlow Lite. For diff --git a/tensorflow/lite/g3doc/android/native.md b/tensorflow/lite/g3doc/android/native.md new file mode 100644 index 00000000000000..a7ef8f055b3974 --- /dev/null +++ b/tensorflow/lite/g3doc/android/native.md @@ -0,0 +1,207 @@ +# TensorFlow Lite in Google Play services C API (Beta) + +Beta: TensorFlow Lite in Google Play services C API is currently in Beta. + +TensorFlow Lite in Google Play services runtime allows you to run machine +learning (ML) models without statically bundling TensorFlow Lite libraries into +your app. This guide provide instructions on how to use the C APIs for Google +Play services. + +Before working with the TensorFlow Lite in Google Play services C API, make sure +you have the [CMake](https://cmake.org/) build tool installed. + +## Update your build configuration + +Add the following dependencies to your app project code to access the Play +services API for TensorFlow Lite: + +``` +implementation "com.google.android.gms:play-services-tflite-java:16.2.0-beta02" +``` + +Then, enable the +[Prefab](https://developer.android.com/build/dependencies#build-system-configuration) +feature to access the C API from your CMake script by updating the android block +of your module's build.gradle file: + +``` +buildFeatures { + prefab = true +} +``` + +You finally need to add the package `tensorflowlite_jni_gms_client` imported +from the AAR as a dependency in your CMake script: + +``` +find_package(tensorflowlite_jni_gms_client REQUIRED CONFIG) + +target_link_libraries(tflite-jni # your JNI lib target + tensorflowlite_jni_gms_client::tensorflowlite_jni_gms_client + android # other deps for your target + log) + +# Also add -DTFLITE_IN_GMSCORE -DTFLITE_WITH_STABLE_ABI +# to the C/C++ compiler flags. + +add_compile_definitions(TFLITE_IN_GMSCORE) +add_compile_definitions(TFLITE_WITH_STABLE_ABI) +``` + +## Initialize the TensorFlow Lite runtime + +Before calling the TensorFlow Lite Native API you must initialize the +`TfLiteNative` runtime in your Java/Kotlin code. + +
+ +
+

Java

+
+Task tfLiteInitializeTask = TfLiteNative.initialize(context);
+      
+
+
+

Kotlin

+
+val tfLiteInitializeTask: Task = TfLiteNative.initialize(context)
+        
+
+
+
+ +Using the Google Play services Task API, `TfLiteNative.initialize` +asynchronously loads the TFLite runtime from Google Play services into your +application's runtime process. Use `addOnSuccessListener()` to make sure the +`TfLite.initialize()` task completes before executing code that accesses +TensorFlow Lite APIs. Once the task has completed successfully, you can invoke +all the available TFLite Native APIs. + +## Native code implementation + +To use TensorFlow Lite in Google Play services with your native code, you can do +one of the following: + +- declare new JNI functions to call native functions from your Java code +- Call the TensorFlow Lite Native API from your existing native C code. + +JNI functions: + +You can declare a new JNI function to make the TensorFlow Lite runtime declared +in Java/Kotlin accessible to your native code as follow: + +
+ +
+

Java

+
+package com.google.samples.gms.tflite.c;
+
+public class TfLiteJni {
+  static {
+    System.loadLibrary("tflite-jni");
+  }
+  public TfLiteJni() { /**/ };
+  public native void loadModel(AssetManager assetManager, String assetName);
+  public native float[] runInference(float[] input);
+}
+      
+
+
+

Kotlin

+
+package com.google.samples.gms.tflite.c
+
+class TfLiteJni() {
+  companion object {
+    init {
+      System.loadLibrary("tflite-jni")
+    }
+  }
+  external fun loadModel(assetManager: AssetManager, assetName: String)
+  external fun runInference(input: FloatArray): FloatArray
+}
+        
+
+
+
+ +Matching the following `loadModel` and `runInference` native functions: + +``` +#ifdef __cplusplus +extern "C" { +#endif + +void Java_com_google_samples_gms_tflite_c_loadModel( + JNIEnv *env, jobject tflite_jni, jobject asset_manager, jstring asset_name){} + //... +} + +jfloatArray Java_com_google_samples_gms_tflite_c_TfLiteJni_runInference( + JNIEnv* env, jobject tfliteJni, jfloatArray input) { + //... +} + +#ifdef __cplusplus +} // extern "C". +#endif +``` + +You can then call your C functions from your Java/Kotlin code: + +
+ +
+

Java

+
+tfLiteHandleTask.onSuccessTask(unused -> {
+    TfLiteJni jni = new TfLiteJni();
+    jni.loadModel(getAssets(), "add.bin");
+    //...
+});
+    
+
+
+

Kotlin

+
+tfLiteHandleTask.onSuccessTask {
+    val jni = TfLiteJni()
+    jni.loadModel(assets, "add.bin")
+    // ...
+}
+      
+
+
+
+ +### TensorFlow Lite in C code + +Include the appropriate API header file to include the TfLite with Google Play +services API: + +``` +#include "tensorflow/lite/c/c_api.h" +``` + +You can then use the regular TensorFlow Lite C API: + +``` +auto model = TfLiteModelCreate(model_asset, model_asset_length); +// ... +auto options = TfLiteInterpreterOptionsCreate(); +// ... +auto interpreter = TfLiteInterpreterCreate(model, options); +``` + +The TensorFlow Lite with Google Play services Native API headers provide the +same API as the regular +[TensorFlow Lite C API](https://www.tensorflow.org/lite/api_docs/c), excluding +features that are deprecated or experimental. For now the functions and types +from the `c_api.h`, `c_api_types.h` and `common.h` headers are available. Please +note that functions from the `c_api_experimental.h` header are not supported. +The documentation can be found +[online](https://www.tensorflow.org/lite/api_docs/c). + +You can use functions specific to TensorFlow Lite with Google Play Services by +including `tflite.h`. diff --git a/tensorflow/lite/g3doc/android/play_services.md b/tensorflow/lite/g3doc/android/play_services.md index 76b3e136518ce9..95fb79e0c344a6 100644 --- a/tensorflow/lite/g3doc/android/play_services.md +++ b/tensorflow/lite/g3doc/android/play_services.md @@ -27,558 +27,11 @@ the APIs. ## Using the Play services runtime -TensorFlow Lite in Google Play services is available through the -[TensorFlow Lite Task API](../api_docs/java/org/tensorflow/lite/task/core/package-summary) -and -[TensorFlow Lite Interpreter API](../api_docs/java/org/tensorflow/lite/InterpreterApi). -The Task Library provides optimized out-of-box model interfaces for common -machine learning tasks using visual, audio, and text data. The TensorFlow Lite -Interpreter API, provided by the TensorFlow runtime and support libraries, -provides a more general-purpose interface for building and running ML models. +The TensorFlow Lite in Google Play services is available through the following +programming language apis: -The following sections provide instructions on how to implement the Interpreter -and Task Library APIs in Google Play services. While it is possible for an app -to use both the Interpreter APIs and Task Library APIs, most apps should only -use one set of APIs. - -### Using the Task Library APIs - -The TensorFlow Lite Task API wraps the Interpreter API and provides a high-level -programming interface for common machine learning tasks that use visual, audio, -and text data. You should use the Task API if your application requires one of -the -[supported tasks](../inference_with_metadata/task_library/overview#supported_tasks). - -#### 1. Add project dependencies - -Your project dependency depends on your machine learning use case. The Task APIs -contain the following libraries: - -* Vision library: `org.tensorflow:tensorflow-lite-task-vision-play-services` -* Audio library: `org.tensorflow:tensorflow-lite-task-audio-play-services` -* Text library: `org.tensorflow:tensorflow-lite-task-text-play-services` - -Add one of the dependencies to your app project code to access the Play services -API for TensorFlow Lite. For example, use the following to implement a vision -task: - -``` -dependencies { -... - implementation 'org.tensorflow:tensorflow-lite-task-vision-play-services:0.4.2' -... -} -``` - -Caution: The TensorFlow Lite Tasks Audio library version 0.4.2 -maven repository is incomplete. Use version 0.4.2.1 for this library instead: -`org.tensorflow:tensorflow-lite-task-audio-play-services:0.4.2.1`. - -#### 2. Add initialization of TensorFlow Lite - -Initialize the TensorFlow Lite component of the Google Play services API -*before* using the TensorFlow Lite APIs. The following example initializes the -vision library: - -
- -
-

Kotlin

-
-init {
-  TfLiteVision.initialize(context)
-    }
-  }
-
-
-
-
- -Important: Make sure the `TfLite.initialize` task completes before executing -code that accesses TensorFlow Lite APIs. - -Tip: The TensorFlow Lite modules are installed at the same time your application -is installed or updated from the Play Store. You can check the availability of -the modules by using `ModuleInstallClient` from the Google Play services APIs. -For more information on checking module availability, see -[Ensuring API availability with ModuleInstallClient](https://developers.google.com/android/guides/module-install-apis). - -#### 3. Run inferences - -After initializing the TensorFlow Lite component, call the `detect()` method to -generate inferences. The exact code within the `detect()` method varies -depending on the library and use case. The following is for a simple object -detection use case with the `TfLiteVision` library: - -
- -
-

Kotlin

-
-fun detect(...) {
-  if (!TfLiteVision.isInitialized()) {
-    Log.e(TAG, "detect: TfLiteVision is not initialized yet")
-    return
-  }
-
-  if (objectDetector == null) {
-    setupObjectDetector()
-  }
-
-  ...
-
-}
-
-
-
-
- -Depending on the data format, you may also need to preprocess and convert your -data within the `detect()` method before generating inferences. For example, -image data for an object detector requires the following: - -```kotlin -val imageProcessor = ImageProcessor.Builder().add(Rot90Op(-imageRotation / 90)).build() -val tensorImage = imageProcessor.process(TensorImage.fromBitmap(image)) -val results = objectDetector?.detect(tensorImage) -``` - -### Using the Interpreter APIs - -The Interpreter APIs offer more control and flexibility than the Task Library -APIs. You should use the Interpreter APIs if your machine learning task is not -supported by the Task library, or if you require a more general-purpose -interface for building and running ML models. - -#### 1. Add project dependencies - -Add the following dependencies to your app project code to access the Play -services API for TensorFlow Lite: - -``` -dependencies { -... - // Tensorflow Lite dependencies for Google Play services - implementation 'com.google.android.gms:play-services-tflite-java:16.0.1' - // Optional: include Tensorflow Lite Support Library - implementation 'com.google.android.gms:play-services-tflite-support:16.0.1' -... -} -``` - -#### 2. Add initialization of TensorFlow Lite - -Initialize the TensorFlow Lite component of the Google Play services API -*before* using the TensorFlow Lite APIs: - -
- -
-

Kotlin

-
-val initializeTask: Task<Void> by lazy { TfLite.initialize(this) }
-
-
-
-

Java

-
-Task<Void> initializeTask = TfLite.initialize(context);
-
-
-
-
- -Note: Make sure the `TfLite.initialize` task completes before executing code -that accesses TensorFlow Lite APIs. Use the `addOnSuccessListener()` method, as -shown in the next section. - -#### 3. Create an Interpreter and set runtime option {:#step_3_interpreter} - -Create an interpreter using `InterpreterApi.create()` and configure it to use -Google Play services runtime, by calling `InterpreterApi.Options.setRuntime()`, -as shown in the following example code: - -
- -
-

Kotlin

-
-import org.tensorflow.lite.InterpreterApi
-import org.tensorflow.lite.InterpreterApi.Options.TfLiteRuntime
-...
-private lateinit var interpreter: InterpreterApi
-...
-initializeTask.addOnSuccessListener {
-  val interpreterOption =
-    InterpreterApi.Options().setRuntime(TfLiteRuntime.FROM_SYSTEM_ONLY)
-  interpreter = InterpreterApi.create(
-    modelBuffer,
-    interpreterOption
-  )}
-  .addOnFailureListener { e ->
-    Log.e("Interpreter", "Cannot initialize interpreter", e)
-  }
-
-
-
-

Java

-
-import org.tensorflow.lite.InterpreterApi
-import org.tensorflow.lite.InterpreterApi.Options.TfLiteRuntime
-...
-private InterpreterApi interpreter;
-...
-initializeTask.addOnSuccessListener(a -> {
-    interpreter = InterpreterApi.create(modelBuffer,
-      new InterpreterApi.Options().setRuntime(TfLiteRuntime.FROM_SYSTEM_ONLY));
-  })
-  .addOnFailureListener(e -> {
-    Log.e("Interpreter", String.format("Cannot initialize interpreter: %s",
-          e.getMessage()));
-  });
-
-
-
-
- -You should use the implementation above because it avoids blocking the Android -user interface thread. If you need to manage thread execution more closely, you -can add a `Tasks.await()` call to interpreter creation: - -
- -
-

Kotlin

-
-import androidx.lifecycle.lifecycleScope
-...
-lifecycleScope.launchWhenStarted { // uses coroutine
-  initializeTask.await()
-}
-
-
-
-

Java

-
-@BackgroundThread
-InterpreterApi initializeInterpreter() {
-    Tasks.await(initializeTask);
-    return InterpreterApi.create(...);
-}
-
-
-
-
- -Warning: Do not call `.await()` on the foreground user interface thread because -it interrupts display of user interface elements and creates a poor user -experience. - -#### 4. Run inferences - -Using the `interpreter` object you created, call the `run()` method to generate -an inference. - -
- -
-

Kotlin

-
-interpreter.run(inputBuffer, outputBuffer)
-
-
-
-

Java

-
-interpreter.run(inputBuffer, outputBuffer);
-
-
-
-
- -## Hardware acceleration {:#hardware-acceleration} - -TensorFlow Lite allows you to accelerate the performance of your model using -specialized hardware processors, such as graphics processing units (GPUs). You -can take advantage of these specialized processors using hardware drivers called -[*delegates*](https://www.tensorflow.org/lite/performance/delegates). You can -use the following hardware acceleration delegates with TensorFlow Lite in Google -Play services: - -- *[GPU delegate](https://www.tensorflow.org/lite/performance/gpu) - (recommended)* - This delegate is provided through Google Play services and - is dynamically loaded, just like the Play services versions of the Task API - and Interpreter API. - -- [*NNAPI delegate*](https://www.tensorflow.org/lite/android/delegates/nnapi) - - This delegate is available as an included library dependency in your Android - development project, and is bundled into your app. - -For more information about hardware acceleration with TensorFlow Lite, see the -[TensorFlow Lite Delegates](https://www.tensorflow.org/lite/performance/delegates) -page. - -### Checking device compatibility - -Not all devices support GPU hardware acceleration with TFLite. In order to -mitigate errors and potential crashes, use the -`TfLiteGpu.isGpuDelegateAvailable` method to check whether a device is -compatible with the GPU delegate. - -Use this method to confirm whether a device is compatible with GPU, and use CPU -or the NNAPI delegate as a fallback for when GPU is not supported. - -``` -useGpuTask = TfLiteGpu.isGpuDelegateAvailable(context) -``` - -Once you have a variable like `useGpuTask`, you can use it to determine whether -devices use the GPU delegate. The following examples show how this can be done -with both the Task Library and Interpreter APIs. - -**With the Task Api** - -
- -
-

Kotlin

-
-lateinit val optionsTask = useGpuTask.continueWith { task ->
-  val baseOptionsBuilder = BaseOptions.builder()
-  if (task.result) {
-    baseOptionsBuilder.useGpu()
-  }
- ObjectDetectorOptions.builder()
-          .setBaseOptions(baseOptionsBuilder.build())
-          .setMaxResults(1)
-          .build()
-}
-    
-
-
-

Java

-
-Task<ObjectDetectorOptions> optionsTask = useGpuTask.continueWith({ task ->
-  BaseOptions baseOptionsBuilder = BaseOptions.builder();
-  if (task.getResult()) {
-    baseOptionsBuilder.useGpu();
-  }
-  return ObjectDetectorOptions.builder()
-          .setBaseOptions(baseOptionsBuilder.build())
-          .setMaxResults(1)
-          .build()
-});
-    
-
-
-
- -**With the Interpreter Api** - -
- -
-

Kotlin

-
-val interpreterTask = useGpuTask.continueWith { task ->
-  val interpreterOptions = InterpreterApi.Options()
-      .setRuntime(TfLiteRuntime.FROM_SYSTEM_ONLY)
-  if (task.result) {
-      interpreterOptions.addDelegateFactory(GpuDelegateFactory())
-  }
-  InterpreterApi.create(FileUtil.loadMappedFile(context, MODEL_PATH), interpreterOptions)
-}
-    
-
-
-

Java

-
-Task<InterpreterApi.Options> interpreterOptionsTask = useGpuTask.continueWith({ task ->
-  InterpreterApi.Options options =
-      new InterpreterApi.Options().setRuntime(TfLiteRuntime.FROM_SYSTEM_ONLY);
-  if (task.getResult()) {
-     options.addDelegateFactory(new GpuDelegateFactory());
-  }
-  return options;
-});
-    
-
-
-
- -### GPU with Task Library APIs - -To use the GPU delegate with the Task APIs: - -1. Update the project dependencies to use the GPU delegate from Play services: - - ``` - implementation 'com.google.android.gms:play-services-tflite-gpu:16.1.0' - ``` - -1. Initialize the GPU delegate with `setEnableGpuDelegateSupport`. For example, - you can initialize the GPU delegate for `TfLiteVision` with the following: - -
- -
-

Kotlin

-
-        TfLiteVision.initialize(context, TfLiteInitializationOptions.builder().setEnableGpuDelegateSupport(true).build())
-        
-
-
-

Java

-
-        TfLiteVision.initialize(context, TfLiteInitializationOptions.builder().setEnableGpuDelegateSupport(true).build());
-        
-
-
-
- -1. Enable the GPU delegate option with - [`BaseOptions`](https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/task/core/BaseOptions.Builder): - -
- -
-

Kotlin

-
-        val baseOptions = BaseOptions.builder().useGpu().build()
-        
-
-
-

Java

-
-        BaseOptions baseOptions = BaseOptions.builder().useGpu().build();
-        
-
-
-
- -1. Configure the options using `.setBaseOptions`. For example, you can set up - GPU in `ObjectDetector` with the following: - -
- -
-

Kotlin

-
-        val options =
-            ObjectDetectorOptions.builder()
-                .setBaseOptions(baseOptions)
-                .setMaxResults(1)
-                .build()
-        
-
-
-

Java

-
-        ObjectDetectorOptions options =
-            ObjectDetectorOptions.builder()
-                .setBaseOptions(baseOptions)
-                .setMaxResults(1)
-                .build();
-        
-
-
-
- -### GPU with Interpreter APIs - -To use the GPU delegate with the Interpreter APIs: - -1. Update the project dependencies to use the GPU delegate from Play services: - - ``` - implementation 'com.google.android.gms:play-services-tflite-gpu:16.1.0' - ``` - -1. Enable the GPU delegate option in the TFlite initialization: - -
- -
-

Kotlin

-
-        TfLite.initialize(context,
-          TfLiteInitializationOptions.builder()
-           .setEnableGpuDelegateSupport(true)
-           .build())
-        
-
-
-

Java

-
-        TfLite.initialize(context,
-          TfLiteInitializationOptions.builder()
-           .setEnableGpuDelegateSupport(true)
-           .build());
-        
-
-
-
- -1. Set GPU delegate in interpreter options to use `DelegateFactory` by calling - `addDelegateFactory()` within `InterpreterApi.Options()`: - -
- -
-

Kotlin

-
-        val interpreterOption = InterpreterApi.Options()
-         .setRuntime(TfLiteRuntime.FROM_SYSTEM_ONLY)
-         .addDelegateFactory(GpuDelegateFactory())
-        
-
-
-

Java

-
-        Options interpreterOption = InterpreterApi.Options()
-          .setRuntime(TfLiteRuntime.FROM_SYSTEM_ONLY)
-          .addDelegateFactory(new GpuDelegateFactory());
-        
-
-
-
- -## Migrating from stand-alone TensorFlow Lite {:#migrating} - -If you are planning to migrate your app from stand-alone TensorFlow Lite to the -Play services API, review the following additional guidance for updating your -app project code: - -1. Review the [Limitations](#limitations) section of this page to ensure your - use case is supported. -2. Prior to updating your code, do performance and accuracy checks for your - models, particularly if you are using versions of TensorFlow Lite earlier - than version 2.1, so you have a baseline to compare against the new - implementation. -3. If you have migrated all of your code to use the Play services API for - TensorFlow Lite, you should remove the existing TensorFlow Lite *runtime - library* dependencies (entries with - org.tensorflow:**tensorflow-lite**:*) from your build.gradle - file so that you can reduce your app size. -4. Identify all occurrences of `new Interpreter` object creation in your code, - and modify it so that it uses the InterpreterApi.create() call. This new API - is asynchronous, which means in most cases it's not a drop-in replacement, - and you must register a listener for when the call completes. Refer to the - code snippet in [Step 3](#step_3_interpreter) code. -5. Add `import org.tensorflow.lite.InterpreterApi;` and `import - org.tensorflow.lite.InterpreterApi.Options.TfLiteRuntime;` to any source - files using the `org.tensorflow.lite.Interpreter` or - `org.tensorflow.lite.InterpreterApi` classes. -6. If any of the resulting calls to `InterpreterApi.create()` have only a - single argument, append `new InterpreterApi.Options()` to the argument list. -7. Append `.setRuntime(TfLiteRuntime.FROM_SYSTEM_ONLY)` to the last argument of - any calls to `InterpreterApi.create()`. -8. Replace all other occurrences of the `org.tensorflow.lite.Interpreter` class - with `org.tensorflow.lite.InterpreterApi`. - -If you want to use stand-alone TensorFlow Lite and the Play services API -side-by-side, you must use TensorFlow Lite 2.9 (or later). TensorFlow Lite 2.8 -and earlier versions are not compatible with the Play services API version. +- Java API - [see guide](../android/java) +- C API - [see guide](../android/native) ## Limitations @@ -587,10 +40,6 @@ TensorFlow Lite in Google Play services has the following limitations: * Support for hardware acceleration delegates is limited to the delegates listed in the [Hardware acceleration](#hardware-acceleration) section. No other acceleration delegates are supported. -* Access to TensorFlow Lite via - [native APIs](https://www.tensorflow.org/lite/guide/inference#load_and_run_a_model_in_c) - is not supported. Only the TensorFlow Lite Java APIs are available through - Google Play services. * Experimental or deprecated TensorFlow Lite APIs, including custom ops, are not supported. diff --git a/tensorflow/lite/g3doc/examples/jax_conversion/overview.ipynb b/tensorflow/lite/g3doc/examples/jax_conversion/jax_to_tflite.ipynb similarity index 73% rename from tensorflow/lite/g3doc/examples/jax_conversion/overview.ipynb rename to tensorflow/lite/g3doc/examples/jax_conversion/jax_to_tflite.ipynb index 164c1f03cf42ff..5adcf13b27adcb 100644 --- a/tensorflow/lite/g3doc/examples/jax_conversion/overview.ipynb +++ b/tensorflow/lite/g3doc/examples/jax_conversion/jax_to_tflite.ipynb @@ -1,21 +1,4 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "Jax to TFLite.ipynb", - "provenance": [], - "collapsed_sections": [], - "toc_visible": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } - }, "cells": [ { "cell_type": "markdown", @@ -28,9 +11,11 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "qLCxmWRyRMZE" }, + "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", @@ -43,9 +28,14 @@ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." - ], - "execution_count": null, - "outputs": [] + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8LYgHRFPRpS1" + }, + "source": [] }, { "cell_type": "markdown", @@ -55,7 +45,7 @@ "source": [ "# Jax Model Conversion For TFLite\n", "## Overview\n", - "Note: This API is new and only available via pip install tf-nightly. It will be available in TensorFlow version 2.7. Also, the API is still experimental and subject to changes.\n", + "Note: This API is new and we recommend using via pip install tf-nightly. Also, the API is still experimental and subject to changes.\n", "\n", "This CodeLab demonstrates how to build a model for MNIST recognition using Jax, and how to convert it to TensorFlow Lite. This codelab will also demonstrate how to optimize the Jax-converted TFLite model with post-training quantiztion." ] @@ -66,20 +56,20 @@ "id": "i8cfOBcjSByO" }, "source": [ - "\n", - " \n", - " \n", - " \n", - " \n", - "
\n", - " View on TensorFlow.org\n", - " \n", - " Run in Google Colab\n", - " \n", - " View source on GitHub\n", - " \n", - " Download notebook\n", - "
" + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/lite/examples/jax_conversion/jax_to_tflite\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/jax_conversion/jax_to_tflite.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/jax_conversion/jax_to_tflite.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/tensorflow/lite/g3doc/examples/jax_conversion/jax_to_tflite.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n", + " \u003c/td\u003e\n", + "\u003c/table\u003e" ] }, { @@ -94,16 +84,52 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "EV04hKdrnE4f" }, + "outputs": [], "source": [ "!pip install tf-nightly --upgrade\n", - "!pip install jax --upgrade\n", - "!pip install jaxlib --upgrade" - ], + "!pip install jax --upgrade" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "vsilblGuGQa2" + }, + "outputs": [], + "source": [ + "# Make sure your JAX version is at least 0.4.20 or above.\n", + "import jax\n", + "jax.__version__" + ] + }, + { + "cell_type": "code", "execution_count": null, - "outputs": [] + "metadata": { + "id": "PJeQhMUwH0oX" + }, + "outputs": [], + "source": [ + "!pip install orbax-export --upgrade" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "j9_CVA0THQNc" + }, + "outputs": [], + "source": [ + "from orbax.export import ExportManager\n", + "from orbax.export import JaxModule\n", + "from orbax.export import ServingConfig" + ] }, { "cell_type": "markdown", @@ -117,9 +143,11 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "qSOPSZJn1_Tj" }, + "outputs": [], "source": [ "import numpy as np\n", "import tensorflow as tf\n", @@ -134,15 +162,15 @@ "from jax import jit, grad, random\n", "from jax.example_libraries import optimizers\n", "from jax.example_libraries import stax\n" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "hdJIt3Da2Qn1" }, + "outputs": [], "source": [ "def _one_hot(x, k, dtype=np.float32):\n", " \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n", @@ -155,9 +183,7 @@ "\n", "train_labels = _one_hot(train_labels, 10)\n", "test_labels = _one_hot(test_labels, 10)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -170,9 +196,11 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "mi3TKB9nnQdK" }, + "outputs": [], "source": [ "def loss(params, batch):\n", " inputs, targets = batch\n", @@ -192,9 +220,7 @@ " stax.Dense(10), stax.LogSoftmax)\n", "\n", "rng = random.PRNGKey(0)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -202,14 +228,16 @@ "id": "bRtnOBdJLd63" }, "source": [ - "## Train & Evaluate the model" + "## Train \u0026 Evaluate the model" ] }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "SWbYRyj7LYZt" }, + "outputs": [], "source": [ "step_size = 0.001\n", "num_epochs = 10\n", @@ -254,9 +282,7 @@ " print(\"Epoch {} in {:0.2f} sec\".format(epoch, epoch_time))\n", " print(\"Training set accuracy {}\".format(train_acc))\n", " print(\"Test set accuracy {}\".format(test_acc))" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -266,11 +292,10 @@ "source": [ "## Convert to TFLite model.\n", "Note here, we\n", - "1. Inline the params to the Jax `predict` func with `functools.partial`.\n", - "2. Build a `jnp.zeros`, this is a \"placeholder\" tensor used for Jax to trace the model.\n", - "3. Call `experimental_from_jax`:\n", - "> * The `serving_func` is wrapped in a list.\n", - "> * The input is associated with a given name and passed in as an array wrapped in a list.\n", + "1. Export the `JAX` model to `TF SavedModel` using `orbax`.\n", + "2. Call TFLite converter API to convert the `TF SavedModel` to `.tflite` model:\n", + "\n", + "\n", "\n", "\n", "\n" @@ -278,20 +303,25 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "6pcqKZqdNTmn" }, + "outputs": [], "source": [ - "serving_func = functools.partial(predict, params)\n", - "x_input = jnp.zeros((1, 28, 28))\n", - "converter = tf.lite.TFLiteConverter.experimental_from_jax(\n", - " [serving_func], [[('input1', x_input)]])\n", + "jax_module = JaxModule(params, predict, input_polymorphic_shape='b, ...')\n", + "converter = tf.lite.TFLiteConverter.from_concrete_functions(\n", + " [\n", + " jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(\n", + " tf.TensorSpec(shape=(1, 28, 28), dtype=tf.float32, name=\"input\")\n", + " )\n", + " ]\n", + ")\n", + "\n", "tflite_model = converter.convert()\n", "with open('jax_mnist.tflite', 'wb') as f:\n", " f.write(tflite_model)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -305,10 +335,13 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "acj2AYzjSlaY" }, + "outputs": [], "source": [ + "serving_func = functools.partial(predict, params)\n", "expected = serving_func(train_images[0:1])\n", "\n", "# Run the model with TensorFlow Lite\n", @@ -322,9 +355,7 @@ "\n", "# Assert if the result of TFLite model is consistent with the JAX model.\n", "np.testing.assert_almost_equal(expected, result, 1e-5)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -340,15 +371,17 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "KI0rLV-Meg-2" }, + "outputs": [], "source": [ "def representative_dataset():\n", " for i in range(1000):\n", " x = train_images[i:i+1]\n", " yield [x]\n", - "\n", + "x_input = jnp.zeros((1, 28, 28))\n", "converter = tf.lite.TFLiteConverter.experimental_from_jax(\n", " [serving_func], [[('x', x_input)]])\n", "tflite_model = converter.convert()\n", @@ -358,9 +391,7 @@ "tflite_quant_model = converter.convert()\n", "with open('jax_mnist_quant.tflite', 'wb') as f:\n", " f.write(tflite_quant_model)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -373,9 +404,11 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "X3oOm0OaevD6" }, + "outputs": [], "source": [ "expected = serving_func(train_images[0:1])\n", "\n", @@ -390,9 +423,7 @@ "\n", "# Assert if the result of TFLite model is consistent with the Jax model.\n", "np.testing.assert_almost_equal(expected, result, 1e-5)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -406,15 +437,49 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "imFPw007juVG" }, + "outputs": [], "source": [ "!du -h jax_mnist.tflite\n", "!du -h jax_mnist_quant.tflite" - ], + ] + }, + { + "cell_type": "code", "execution_count": null, - "outputs": [] + "metadata": { + "id": "N5WdO18wNfyn" + }, + "outputs": [], + "source": [] } - ] -} \ No newline at end of file + ], + "metadata": { + "colab": { + "private_outputs": true, + "provenance": [ + { + "file_id": "1UlzbQspn2an2kzlLWZBWhP_JEqvShxmi", + "timestamp": 1705015454450 + }, + { + "file_id": "https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/jax_conversion/overview.ipynb", + "timestamp": 1698963811786 + } + ], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/lite/g3doc/examples/jax_conversion/overview.md b/tensorflow/lite/g3doc/examples/jax_conversion/overview.md new file mode 100644 index 00000000000000..d3cf89e15a7a24 --- /dev/null +++ b/tensorflow/lite/g3doc/examples/jax_conversion/overview.md @@ -0,0 +1,97 @@ +# JAX models with TensorFlow Lite + +This page provides a path for users who want to train models in JAX and deploy +to mobile for inference ([example colab](https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/jax_conversion/jax_to_tflite.ipynb)). + +The methods in this guide produce a `tflite_model` which can be used directly +with the TFLite interpreter code example or saved to a TFLite FlatBuffer file. + +## Prerequisite + +It's recommended to try this feature with the newest TensorFlow nightly Python +package. + +``` +pip install tf-nightly --upgrade +``` + +We will use the [Orbax +Export](https://orbax.readthedocs.io/en/latest/orbax_export_101.html) library to +export JAX models. Make sure your JAX version is at least 0.4.20 or above. + +``` +pip install jax --upgrade +pip install orbax-export --upgrade +``` + +## Convert JAX models to TensorFlow Lite + +We use the TensorFlow +[SavedModel](https://www.tensorflow.org/guide/saved_model) as the intermediate +format between JAX and TensorFlow Lite. Once you have a SavedModel then +existing TensorFlow Lite APIs can be used to complete the conversion process. + +```py +# This code snippet converts a JAX model to TFLite through TF SavedModel. +from orbax.export import ExportManager +from orbax.export import JaxModule +from orbax.export import ServingConfig +import tensorflow as tf +import jax.numpy as jnp + +def model_fn(_, x): + return jnp.sin(jnp.cos(x)) + +jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...') + +# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post +# processing. +tf.saved_model.save( + jax_module, + '/some/directory', + signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function( + tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input") + ), + options=tf.saved_model.SaveOptions(experimental_custom_gradients=True), +) +converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory') +tflite_model = converter.convert() + +# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize). +serving_config = ServingConfig( + 'Serving_default', + # Corresponds to the input signature of `tf_preprocessor` + input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')], + tf_preprocessor=lambda x: x, + tf_postprocessor=lambda out: {'output': out} +) +export_mgr = ExportManager(jax_module, [serving_config]) +export_mgr.save('/some/directory') +converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory') +tflite_model = converter.convert() + +# Option 3: Convert from TF concrete function directly +converter = tf.lite.TFLiteConverter.from_concrete_functions( + [ + jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function( + tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input") + ) + ] +) +tflite_model = converter.convert() +``` + +## Check the converted TFLite model + +After the model is converted to TFLite, you can run TFLite interpreter APIs to +check model outputs. + +```py +# Run the model with TensorFlow Lite +interpreter = tf.lite.Interpreter(model_content=tflite_model) +interpreter.allocate_tensors() input_details = interpreter.get_input_details() +output_details = interpreter.get_output_details() +interpreter.set_tensor(input_details[0]["index"], input_data) +interpreter.invoke() +result = interpreter.get_tensor(output_details[0]["index"]) +``` diff --git a/tensorflow/lite/g3doc/guide/ops_custom.md b/tensorflow/lite/g3doc/guide/ops_custom.md index f9aa9bd31f4ab6..296fd6216ac397 100644 --- a/tensorflow/lite/g3doc/guide/ops_custom.md +++ b/tensorflow/lite/g3doc/guide/ops_custom.md @@ -167,19 +167,15 @@ defining those four functions and a global registration function that usually looks like this: ```c++ -namespace tflite { -namespace ops { -namespace custom { - TfLiteRegistration* Register_MY_CUSTOM_OP() { - static TfLiteRegistration r = {my_custom_op::Init, - my_custom_op::Free, - my_custom_op::Prepare, - my_custom_op::Eval}; +namespace my_namespace { + const TfLiteRegistration* Register_MY_CUSTOM_OP() { + static const TfLiteRegistration r = {my_custom_op::Init, + my_custom_op::Free, + my_custom_op::Prepare, + my_custom_op::Eval}; return &r; } -} // namespace custom -} // namespace ops -} // namespace tflite +} // namespace my_namespace ``` Note that registration is not automatic and an explicit call to @@ -231,8 +227,8 @@ TfLiteStatus AtanEval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -TfLiteRegistration* Register_ATAN() { - static TfLiteRegistration r = {nullptr, nullptr, AtanPrepare, AtanEval}; +const TfLiteRegistration* Register_ATAN() { + static const TfLiteRegistration r = {nullptr, nullptr, AtanPrepare, AtanEval}; return &r; } ``` @@ -259,10 +255,29 @@ code, is defined like this: ```c++ class OpResolver { + public: virtual TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const = 0; virtual TfLiteRegistration* FindOp(const char* op) const = 0; - virtual void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration) = 0; - virtual void AddCustom(const char* op, TfLiteRegistration* registration) = 0; + ... +}; +``` + +The `MutableOpResolver` and `BuiltinOpResolver` classes are derived from +`OpResolver`: + +```c++ +class MutableOpResolver : public OpResolver { + public: + MutableOpResolver(); // Constructs an initially empty op resolver. + void AddBuiltin(tflite::BuiltinOperator op, const TfLiteRegistration* registration) = 0; + void AddCustom(const char* op, const TfLiteRegistration* registration) = 0; + void AddAll(const MutableOpResolver& other); + ... +}; + +class BuiltinOpResolver : public MutableOpResolver { + public: + BuiltinOpResolver(); // Constructs an op resolver with all the builtin ops. }; ``` @@ -272,10 +287,13 @@ Regular usage requires that you use the `BuiltinOpResolver` and write: tflite::ops::builtin::BuiltinOpResolver resolver; ``` -To add the custom op created above, you call `AddOp` (before you pass the -resolver to the `InterpreterBuilder`): +To add the custom op created above, you can instead use a `MutableOpResolver`, +and call `AddCustom` (before you pass the resolver to the +`InterpreterBuilder`): ```c++ +tflite::ops::builtin::MutableOpResolver resolver; +resolver.AddAll(tflite::ops::builtin::BuiltinOpResolver()); resolver.AddCustom("Atan", Register_ATAN()); ``` @@ -293,8 +311,8 @@ place your registrations in the Note that a similar process as above can be followed for supporting a set of operations instead of a single operator. Just add as many `AddCustom` operators -as you need. In addition, `BuiltinOpResolver` also allows you to override -implementations of builtins by using the `AddBuiltin`. +as you need. In addition, `MutableOpResolver` also allows you to override +implementations of builtins by using `AddBuiltin`. ### Test and profile your operator diff --git a/tensorflow/lite/g3doc/performance/quantization_spec.md b/tensorflow/lite/g3doc/performance/quantization_spec.md index f98f2922fc5df2..11f808e3980026 100644 --- a/tensorflow/lite/g3doc/performance/quantization_spec.md +++ b/tensorflow/lite/g3doc/performance/quantization_spec.md @@ -191,7 +191,7 @@ FULLY_CONNECTED Input 1 (Weight): data_type : int8 range : [-127, 127] - granularity: per-tensor + granularity: per-axis (dim = 0) restriction: zero_point = 0 Input 2 (Bias): data_type : int32 diff --git a/tensorflow/lite/interpreter_options.h b/tensorflow/lite/interpreter_options.h index d20fd5cb087def..a7557e5412188e 100644 --- a/tensorflow/lite/interpreter_options.h +++ b/tensorflow/lite/interpreter_options.h @@ -25,12 +25,6 @@ namespace tflite { /// WARNING: This is an experimental API and subject to change. class InterpreterOptions { public: - InterpreterOptions() - : experimental_preserve_all_tensors_(false), - experimental_ensure_dynamic_tensors_are_released_(false), - experimental_optimize_memory_for_large_tensors_(0), - experimental_disable_delegate_clustering_(false) {} - /// Preserving all intermediates tensors for debugging. /// WARNING: This is an experimental API and subject to change. void SetPreserveAllTensors(bool value = true) { @@ -93,11 +87,28 @@ class InterpreterOptions { experimental_disable_delegate_clustering_ = value; } + // If set to `true`, the CAST op will cache its output when its input is a + // constant tensor. + // + // WARNING: This is an experimental API and subject to change. + void SetCacheConstantCastOp(bool value) { + experimental_cache_constant_cast_op_ = value; + } + + // If `true`, the CAST op will cache its output when its input is a constant + // tensor. + // + // WARNING: This is an experimental API and subject to change. + bool GetCacheConstantCastOp() const { + return experimental_cache_constant_cast_op_; + } + private: - bool experimental_preserve_all_tensors_; - bool experimental_ensure_dynamic_tensors_are_released_; - int experimental_optimize_memory_for_large_tensors_; - bool experimental_disable_delegate_clustering_; + bool experimental_preserve_all_tensors_ = false; + bool experimental_ensure_dynamic_tensors_are_released_ = false; + int experimental_optimize_memory_for_large_tensors_ = 0; + bool experimental_disable_delegate_clustering_ = false; + bool experimental_cache_constant_cast_op_ = false; }; } // namespace tflite diff --git a/tensorflow/lite/ios/allowlist_TensorFlowLiteCMetal.txt b/tensorflow/lite/ios/allowlist_TensorFlowLiteCMetal.txt index 1ae124bcbddab8..2a745622252dc3 100644 --- a/tensorflow/lite/ios/allowlist_TensorFlowLiteCMetal.txt +++ b/tensorflow/lite/ios/allowlist_TensorFlowLiteCMetal.txt @@ -1,3 +1,4 @@ +_TFLGpuDelegateOptionsDefault _TFLGpuDelegateCreate _TFLGpuDelegateDelete _TFLGpuDelegateBindMetalBufferToTensor \ No newline at end of file diff --git a/tensorflow/lite/java/BUILD b/tensorflow/lite/java/BUILD index 3f9fe7fea2a364..ae7a0fd7dad92f 100644 --- a/tensorflow/lite/java/BUILD +++ b/tensorflow/lite/java/BUILD @@ -671,39 +671,44 @@ java_test_with_tflite( ], ) -java_test_with_tflite( - name = "InterpreterApiTest", - size = "small", - srcs = [ - "src/test/java/org/tensorflow/lite/InterpreterApiTest.java", - "src/test/java/org/tensorflow/lite/SupportedFeatures.java", - "src/test/java/org/tensorflow/lite/TestUtils.java", - ], - data = [ - "src/testdata/add.bin", - "src/testdata/add_unknown_dimensions.bin", - "src/testdata/mul_add_signature_def.bin", - "src/testdata/tile_with_bool_input.bin", - "//tensorflow/lite:testdata/dynamic_shapes.bin", - "//tensorflow/lite:testdata/multi_add.bin", - "//tensorflow/lite:testdata/multi_add_flex.bin", - ], - javacopts = JAVACOPTS, - test_class = "org.tensorflow.lite.InterpreterApiTest", - tflite_deps = [ - ":test_init", - ], - tflite_jni_binaries = [ - "//tensorflow/lite/java/src/test/native:libtensorflowlite_stable_test_jni.so", - ], - visibility = ["//visibility:private"], - deps = [ - ":tensorflowlite_javalib_stable", - "//third_party/java/mockito", - "@com_google_truth", - "@junit", - ], -) +# Disabled under the (b/279852433) because caused an error in the OSS +# TODO(zhurakovskyi): Uncomment when fixed. +# +# copybara:uncomment_begin +# java_test_with_tflite( +# name = "InterpreterApiTest", +# size = "small", +# srcs = [ +# "src/test/java/org/tensorflow/lite/InterpreterApiTest.java", +# "src/test/java/org/tensorflow/lite/SupportedFeatures.java", +# "src/test/java/org/tensorflow/lite/TestUtils.java", +# ], +# data = [ +# "src/testdata/add.bin", +# "src/testdata/add_unknown_dimensions.bin", +# "src/testdata/mul_add_signature_def.bin", +# "src/testdata/tile_with_bool_input.bin", +# "//tensorflow/lite:testdata/dynamic_shapes.bin", +# "//tensorflow/lite:testdata/multi_add.bin", +# "//tensorflow/lite:testdata/multi_add_flex.bin", +# ], +# javacopts = JAVACOPTS, +# test_class = "org.tensorflow.lite.InterpreterApiTest", +# tflite_deps = [ +# ":test_init", +# ], +# tflite_jni_binaries = [ +# "//tensorflow/lite/java/src/test/native:libtensorflowlite_stable_test_jni.so", +# ], +# visibility = ["//visibility:private"], +# deps = [ +# ":tensorflowlite_javalib_stable", +# "//third_party/java/mockito", +# "@com_google_truth", +# "@junit", +# ], +# ) +# copybara:uncomment_end java_test_with_tflite( name = "InterpreterApiNoRuntimeTest", @@ -732,33 +737,37 @@ java_test_with_tflite( ], ) -java_test_with_tflite( - name = "NnApiDelegateNativeTest", - size = "small", - srcs = [ - "src/test/java/org/tensorflow/lite/NnApiDelegateNativeTest.java", - "src/test/java/org/tensorflow/lite/SupportedFeatures.java", - "src/test/java/org/tensorflow/lite/TestUtils.java", - ], - data = [ - "src/testdata/add.bin", - ], - tags = ["no_mac"], - test_class = "org.tensorflow.lite.NnApiDelegateNativeTest", - tflite_deps = [ - ":test_init", - ], - tflite_jni_binaries = [ - "//tensorflow/lite/java/src/test/native:libtensorflowlite_test_jni.so", - ], - visibility = ["//visibility:private"], - deps = [ - ":tensorflowlite_javalib", - "//third_party/java/mockito", - "@com_google_truth", - "@junit", - ], -) +# Commented out under the (b/279852433) because caused an error in the OSS +# TODO(zhurakovskyi): Uncomment when fixed. +# copybara:uncomment_begin +# java_test_with_tflite( +# name = "NnApiDelegateNativeTest", +# size = "small", +# srcs = [ +# "src/test/java/org/tensorflow/lite/NnApiDelegateNativeTest.java", +# "src/test/java/org/tensorflow/lite/SupportedFeatures.java", +# "src/test/java/org/tensorflow/lite/TestUtils.java", +# ], +# data = [ +# "src/testdata/add.bin", +# ], +# tags = ["no_mac"], +# test_class = "org.tensorflow.lite.NnApiDelegateNativeTest", +# tflite_deps = [ +# ":test_init", +# ], +# tflite_jni_binaries = [ +# "//tensorflow/lite/java/src/test/native:libtensorflowlite_test_jni.so", +# ], +# visibility = ["//visibility:private"], +# deps = [ +# ":tensorflowlite_javalib", +# "//third_party/java/mockito", +# "@com_google_truth", +# "@junit", +# ], +# ) +# copybara:uncomment_end java_test_with_tflite( name = "NnApiDelegateTest", diff --git a/tensorflow/lite/java/demo/app/src/main/BUILD b/tensorflow/lite/java/demo/app/src/main/BUILD index 0c2baa9d3101d7..6a8113274d4b2b 100644 --- a/tensorflow/lite/java/demo/app/src/main/BUILD +++ b/tensorflow/lite/java/demo/app/src/main/BUILD @@ -30,7 +30,5 @@ android_binary( "//tensorflow/lite/java:tensorflowlite", "//tensorflow/lite/java:tensorflowlite_gpu", "//tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", - "@androidsdk//com.android.support:support-v13-25.2.0", - "@androidsdk//com.android.support:support-v4-25.2.0", ], ) diff --git a/tensorflow/lite/java/ovic/demo/app/BUILD b/tensorflow/lite/java/ovic/demo/app/BUILD index 126625be7f7e1a..4c3b55828723eb 100644 --- a/tensorflow/lite/java/ovic/demo/app/BUILD +++ b/tensorflow/lite/java/ovic/demo/app/BUILD @@ -28,11 +28,9 @@ android_binary( resource_files = glob(["res/**"]), tags = ["manual"], deps = [ + # copybara:uncomment "//third_party/java/android/android_sdk_linux/extras/android/compatibility/multidex", "//tensorflow/lite/java:tensorflowlite", "//tensorflow/lite/java/ovic:ovicbenchmarkerlib", "//tensorflow/lite/java/ovic:ovicdetectionbenchmarkerlib", - "//third_party/java/android/android_sdk_linux/extras/android/compatibility/multidex", - "@androidsdk//com.android.support:support-v13-25.2.0", - "@androidsdk//com.android.support:support-v4-25.2.0", ], ) diff --git a/tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD b/tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD index 2264288a3631d4..365942d6490601 100644 --- a/tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD +++ b/tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD @@ -4,7 +4,7 @@ load("@build_bazel_rules_android//android:rules.bzl", "android_library") package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:LICENSE"], default_visibility = ["//visibility:public"], licenses = ["notice"], ) diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index a0eb5ce425fc2a..9445fee50db352 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -889,6 +889,7 @@ cc_library( "//tensorflow/lite/c:c_api_types", "//tensorflow/lite:array", "//tensorflow/lite:builtin_ops", + "//tensorflow/lite:cc_api_stable", "@local_tsl//tsl/lib/random:philox_random", "@local_tsl//tsl/lib/random:random_distributions_utils", "//tensorflow/lite/core/c:c_api_types", @@ -1493,10 +1494,9 @@ cc_library( ], deps = [ ":test_util", - "//tensorflow/lite:string", + "//tensorflow/lite:cc_api_stable", "//tensorflow/lite/schema:schema_fbs", "@com_google_absl//absl/types:span", - "@flatbuffers", ], ) @@ -1509,6 +1509,7 @@ cc_test( ":cast_test_common", ":test_main", ":test_util", + "//tensorflow/lite/c:common", "//tensorflow/lite/core/c:c_api_types", "//tensorflow/lite/schema:schema_fbs", "@com_google_absl//absl/types:span", diff --git a/tensorflow/lite/kernels/activations.cc b/tensorflow/lite/kernels/activations.cc index 4ed6784ad0d7c5..82fcf583fe7c38 100644 --- a/tensorflow/lite/kernels/activations.cc +++ b/tensorflow/lite/kernels/activations.cc @@ -1543,7 +1543,7 @@ TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) { default: TF_LITE_KERNEL_LOG( context, - "Only float32 and uint8 and int8 are supported currently, got %d.", + "Only float32 and uint8 and int8 are supported currently, got %s.", TfLiteTypeGetName(input->type)); return kTfLiteError; } diff --git a/tensorflow/lite/kernels/cast.cc b/tensorflow/lite/kernels/cast.cc index e7d6cdf09b18b1..57c65247f8788d 100644 --- a/tensorflow/lite/kernels/cast.cc +++ b/tensorflow/lite/kernels/cast.cc @@ -14,11 +14,13 @@ limitations under the License. ==============================================================================*/ #include #include +#include #include +#include "Eigen/Core" // from @eigen_archive #include "tensorflow/lite/core/c/common.h" -#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/core/subgraph.h" +#include "tensorflow/lite/interpreter_options.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/op_macros.h" @@ -27,29 +29,11 @@ namespace tflite { namespace ops { namespace builtin { namespace cast { -constexpr int kInputTensor = 0; -constexpr int kOutputTensor = 0; -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); - TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - const TfLiteTensor* input; - TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); - TfLiteTensor* output; - TF_LITE_ENSURE_OK(context, - GetOutputSafe(context, node, kOutputTensor, &output)); +namespace { - // TODO(ahentz): these two checks would make the new implementation - // incompatible with some existing models, where params is not specified. It - // is OK not to have them because toco would have set input and output types - // to match the parameters. - // auto* params = reinterpret_cast(node->builtin_data); - // TF_LITE_ENSURE_EQ(context, input->type, params->in_data_type); - // TF_LITE_ENSURE_EQ(context, output->type, params->out_data_type); - - return context->ResizeTensor(context, output, - TfLiteIntArrayCopy(input->dims)); -} +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; template void copyCast(const FromT* in, ToT* out, int num_elements) { @@ -213,14 +197,8 @@ TfLiteStatus copyToTensor(TfLiteContext* context, const FromT* in, return kTfLiteOk; } -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - const TfLiteTensor* input; - TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); - TfLiteTensor* output; - TF_LITE_ENSURE_OK(context, - GetOutputSafe(context, node, kOutputTensor, &output)); - const int num_elements = NumElements(input); - TF_LITE_ENSURE_EQ(context, num_elements, NumElements(output)); +TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input, + TfLiteTensor* output, const int num_elements) { switch (input->type) { case kTfLiteInt64: return copyToTensor(context, input->data.i64, output, num_elements); @@ -260,11 +238,90 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // Unsupported type. TF_LITE_UNSUPPORTED_TYPE(context, input->type, "Cast"); } + return kTfLiteError; } + +struct OpData { + bool cached_output = false; +}; + +void* Init(TfLiteContext* context, const char* /*buffer*/, size_t /*length*/) { + return new OpData(); +} + +void Free(TfLiteContext* context, void* op_data) { + delete reinterpret_cast(op_data); +} + +bool OutputCachingEnabled(const TfLiteContext* context) { + if (context && context->impl_) { + const InterpreterOptions* options = + reinterpret_cast(context->impl_)->GetOptions(); + if (options) { + return options->GetCacheConstantCastOp(); + } + } + return false; +} + +bool ShouldCacheOutput(const TfLiteContext* context, + const TfLiteTensor* input) { + return OutputCachingEnabled(context) && IsConstantTensor(input); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + const TfLiteTensor* input; + TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); + TfLiteTensor* output; + TF_LITE_ENSURE_OK(context, + GetOutputSafe(context, node, kOutputTensor, &output)); + + // TODO(ahentz): these two checks would make the new implementation + // incompatible with some existing models, where params is not specified. It + // is OK not to have them because toco would have set input and output types + // to match the parameters. + // auto* params = reinterpret_cast(node->builtin_data); + // TF_LITE_ENSURE_EQ(context, input->type, params->in_data_type); + // TF_LITE_ENSURE_EQ(context, output->type, params->out_data_type); + + if (ShouldCacheOutput(context, input)) { + output->allocation_type = kTfLiteArenaRwPersistent; + } + + TF_LITE_ENSURE_OK( + context, + context->ResizeTensor(context, output, TfLiteIntArrayCopy(input->dims))); + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input; + TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); + TfLiteTensor* output; + TF_LITE_ENSURE_OK(context, + GetOutputSafe(context, node, kOutputTensor, &output)); + const int num_elements = NumElements(input); + TF_LITE_ENSURE_EQ(context, num_elements, NumElements(output)); + + OpData& op_data = *reinterpret_cast(node->user_data); + if (ShouldCacheOutput(context, input)) { + if (op_data.cached_output) { + return kTfLiteOk; + } + op_data.cached_output = true; + } + return EvalImpl(context, input, output, num_elements); +} + +} // namespace } // namespace cast TfLiteRegistration* Register_CAST() { - static TfLiteRegistration r = {nullptr, nullptr, cast::Prepare, cast::Eval}; + static TfLiteRegistration r = {cast::Init, cast::Free, cast::Prepare, + cast::Eval}; return &r; } diff --git a/tensorflow/lite/kernels/cast_test.cc b/tensorflow/lite/kernels/cast_test.cc index e2971016619532..c2eef57197119b 100644 --- a/tensorflow/lite/kernels/cast_test.cc +++ b/tensorflow/lite/kernels/cast_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/types/span.h" #include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/c/c_api_types.h" #include "tensorflow/lite/kernels/cast_test_common.h" #include "tensorflow/lite/kernels/test_util.h" @@ -291,5 +292,31 @@ TEST(CastOpModel, CastInt16ToUInt16) { ElementsAreArray({10, 20, 30, 40, 50, 60})); } +TEST(CastOpModel, CastConstInputCachingWorks) { + // This tests the implementation of a performance optimization. If that + // optimization is changed, this test will likely break/need to be updated. + // + // We are relying on the fact that casting a constant input can be cached and + // that the output tensor does not need to be updated on every call. + CastOpModel m({TensorType_INT8, {2, 3}}, + std::vector{10, 20, 30, 40, 50, 60}, + {TensorType_FLOAT32, {2, 3}}); + EXPECT_EQ(m.GetOutputTensor(0)->allocation_type, kTfLiteArenaRwPersistent); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray({10, 20, 30, 40, 50, 60})); + // We are cheating here. If the values of the output tensor are cached then if + // we modify the cache and call the op again the output tensor values should + // not change. + float* output_data = + reinterpret_cast(m.GetOutputTensor(0)->data.data); + for (int i = 0; i < 6; ++i) { + ++output_data[i]; + } + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray({11, 21, 31, 41, 51, 61})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/cast_test_common.h b/tensorflow/lite/kernels/cast_test_common.h index 123cce213228f1..1cfa4bd740eecf 100644 --- a/tensorflow/lite/kernels/cast_test_common.h +++ b/tensorflow/lite/kernels/cast_test_common.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/types/span.h" +#include "tensorflow/lite/interpreter_options.h" #include "tensorflow/lite/kernels/test_util.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -37,6 +38,23 @@ class CastOpModel : public SingleOpModel { BuildInterpreter({GetShape(input_)}); } + template + CastOpModel(const TensorData& input, ConstInputData&& data, + const TensorData& output) { + input_ = AddConstInput(input, static_cast(data)); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_CAST, BuiltinOptions_CastOptions, + CreateCastOptions(builder_).Union()); + BuildInterpreter({GetShape(input_)}, /*num_threads=*/-1, + /*allow_fp32_relax_to_fp16=*/false, + /*apply_delegate=*/true, /*allocate_and_delegate=*/false, + /*use_simple_allocator=*/false); + InterpreterOptions options; + options.SetCacheConstantCastOp(true); + interpreter_->ApplyOptions(&options); + AllocateAndDelegate(/*apply_delegate=*/true); + } + void Set4BitInput(absl::Span f) { PopulateTensor4bit(input_, 0, f.data(), f.data() + f.size()); } diff --git a/tensorflow/lite/kernels/complex_support.cc b/tensorflow/lite/kernels/complex_support.cc index 8713e73ea5772c..6bc59e4a815e60 100644 --- a/tensorflow/lite/kernels/complex_support.cc +++ b/tensorflow/lite/kernels/complex_support.cc @@ -83,7 +83,7 @@ TfLiteStatus EvalReal(TfLiteContext* context, TfLiteNode* node) { default: { TF_LITE_KERNEL_LOG(context, "Unsupported input type, Real op only supports " - "complex input, but got: ", + "complex input, but got: %s", TfLiteTypeGetName(input->type)); return kTfLiteError; } @@ -115,7 +115,7 @@ TfLiteStatus EvalImag(TfLiteContext* context, TfLiteNode* node) { default: { TF_LITE_KERNEL_LOG(context, "Unsupported input type, Imag op only supports " - "complex input, but got: ", + "complex input, but got: %s", TfLiteTypeGetName(input->type)); return kTfLiteError; } @@ -146,7 +146,7 @@ TfLiteStatus EvalAbs(TfLiteContext* context, TfLiteNode* node) { default: { TF_LITE_KERNEL_LOG(context, "Unsupported input type, ComplexAbs op only supports " - "complex input, but got: ", + "complex input, but got: %s", TfLiteTypeGetName(input->type)); return kTfLiteError; } diff --git a/tensorflow/lite/kernels/cumsum.cc b/tensorflow/lite/kernels/cumsum.cc index 997589d2b3b278..7d4d04dad82b41 100644 --- a/tensorflow/lite/kernels/cumsum.cc +++ b/tensorflow/lite/kernels/cumsum.cc @@ -63,7 +63,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { if (axis < 0) axis += NumDimensions(input); if (axis < 0 || axis >= NumDimensions(input)) { - TF_LITE_KERNEL_LOG(context, "Invalid axis: ", axis); + TF_LITE_KERNEL_LOG(context, "Invalid axis: %d", axis); return kTfLiteError; } diff --git a/tensorflow/lite/kernels/fill.cc b/tensorflow/lite/kernels/fill.cc index 910b8082238147..6f704cb028495a 100644 --- a/tensorflow/lite/kernels/fill.cc +++ b/tensorflow/lite/kernels/fill.cc @@ -41,7 +41,8 @@ TfLiteStatus ResizeOutputImpl(TfLiteContext* context, const TfLiteTensor* dims, T data = GetTensorData(dims)[i]; if (data < 0) { TfLiteIntArrayFree(output_shape); - TF_LITE_KERNEL_LOG(context, "Fill dimensions must be >= 0", dims->type); + TF_LITE_KERNEL_LOG(context, "Fill dimensions must be >= 0 got %d", + dims->type); return kTfLiteError; } output_shape->data[i] = data; diff --git a/tensorflow/lite/kernels/fully_connected.cc b/tensorflow/lite/kernels/fully_connected.cc index 8ef28868dc47d6..0c92b0e2641b85 100644 --- a/tensorflow/lite/kernels/fully_connected.cc +++ b/tensorflow/lite/kernels/fully_connected.cc @@ -413,14 +413,6 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node, // parameters set. This is usually done during quantized training. if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8 || input->type == kTfLiteInt16) { - // Populate scalar quantization parameters. - double real_multiplier = 0.0; - TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( - context, input, filter, bias, output, &real_multiplier)); - int exponent; - QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent); - data->output_shift = exponent; - // Populate per-channel quantization parameters, if per-channel // quantization. TF_LITE_ENSURE_EQ(context, input->quantization.type, @@ -466,6 +458,14 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node, per_channel_multiplier[i] = significand; per_channel_shift[i] = channel_shift; } + } else { + // Populate scalar quantization parameters otherwise. + double real_multiplier = 0.0; + TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( + context, input, filter, bias, output, &real_multiplier)); + int exponent; + QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent); + data->output_shift = exponent; } TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized( @@ -1354,7 +1354,9 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, // Block sparse with block size of 1x16. optimized_ops::FullyConnectedSparseWeight1x16( sparsity, op_params, input_shape, GetTensorData(input), - filter_shape, GetTensorData(filter), bias_shape, + filter_shape, GetTensorData(filter), + data->per_channel_output_multiplier.data(), + data->per_channel_output_shift.data(), bias_shape, GetTensorData(bias), output_shape, GetTensorData(output), CpuBackendContext::GetFromContext(context)); diff --git a/tensorflow/lite/kernels/fully_connected_test.cc b/tensorflow/lite/kernels/fully_connected_test.cc index 16bb850221a250..86a3f1ef8619ed 100644 --- a/tensorflow/lite/kernels/fully_connected_test.cc +++ b/tensorflow/lite/kernels/fully_connected_test.cc @@ -1713,7 +1713,13 @@ class SparseFullyConnectedOpModel : public SingleOpModel { } else if (input.type == TensorType_INT8) { // This is a quantized version. The scale of 'bias' depends on the scales // of input and filter. - auto bias_scale = GetScale(input_) * GetScale(weights_); + float bias_scale = GetScale(input_); + if (weights.per_channel_quantization && + !weights.per_channel_quantization_scales.empty()) { + bias_scale *= weights.per_channel_quantization_scales[0]; + } else { + bias_scale *= GetScale(weights_); + } TensorData bias = {TensorType_INT32, {units_}, 0, 0, bias_scale}; bias_ = AddInput(bias); } else { @@ -2333,6 +2339,44 @@ TEST_P(SparseQuantizedFullyConnectedOpTest, Simple1x16TestScaledInputOutput) { EXPECT_THAT(m.GetOutput(), ElementsAre(-52, -50, -52)); } +TEST_P(SparseQuantizedFullyConnectedOpTest, + Simple1x16PerChannelQuantizationTest) { + std::vector weight_data = { + 1, 2, 3, 4, -1, -2, -3, -4, 1, 2, 3, 4, -4, -3, -2, -1, // u = 0 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // u = 1 + -1, -2, -3, -4, 4, 3, 2, 1, -1, -2, -3, 4, 1, 2, 3, 4, // u = 2 + }; + TensorData weight = {TensorType_INT8, + {3, 16}, + 0.0, + 0.0, + 0.0, + 0, + true, + {4.0 / 127.0, 1.0 / 127.0, 4.0 / 127.0}, + {0, 0, 0}}; + weight.traversal_order = {0, 1, 2}; + weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR}; + weight.block_map = {1}; + weight.block_size = {16}; + SparseQuantizedFullyConnectedOpModel m( + GetRegistration(), + /*units=*/3, /*batches=*/2, + /*input=*/{TensorType_INT8, {2, 16}, 0, 0, 1}, weight, weight_data, + /*output=*/{TensorType_INT8, {}, 0, 0, 1}); + + m.SetBias({1, 2, 3}); + m.SetInput({ + 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, // b = 0 + 4, 3, 2, 1, 4, 3, 2, 1, 4, 3, 2, 1, 4, 3, 2, 1, // b = 1 + }); + + ASSERT_EQ(m.Invoke(), kTfLiteOk); + + EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 3)); + EXPECT_THAT(m.GetOutput(), ElementsAre(11, 1, 25, 0, 1, 21)); +} + INSTANTIATE_TEST_SUITE_P( SparseQuantizedFullyConnectedOpTest, SparseQuantizedFullyConnectedOpTest, ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMapNoPie))); diff --git a/tensorflow/lite/kernels/internal/common.cc b/tensorflow/lite/kernels/internal/common.cc index 1654ab84f0d7ce..fabb0208b7d21c 100644 --- a/tensorflow/lite/kernels/internal/common.cc +++ b/tensorflow/lite/kernels/internal/common.cc @@ -17,6 +17,53 @@ limitations under the License. namespace tflite { +// Single-rounding MultiplyByQuantizedMultiplier +#if TFLITE_SINGLE_ROUNDING +int32_t MultiplyByQuantizedMultiplier(int32_t x, int32_t quantized_multiplier, + int shift) { + TFLITE_DCHECK(quantized_multiplier >= 0); + TFLITE_DCHECK(shift >= -31 && shift <= 30); + + const int64_t total_shift = 31 - shift; + const int64_t round = static_cast(1) << (total_shift - 1); + int64_t result = x * static_cast(quantized_multiplier) + round; + result = result >> total_shift; + + TFLITE_DCHECK(result >= std::numeric_limits::min() && + result <= std::numeric_limits::max()); + return static_cast(result); +} + +int32_t MultiplyByQuantizedMultiplier(int64_t x, int32_t quantized_multiplier, + int shift) { + // Inputs: + // - quantized_multiplier has fixed point at bit 31 + // - shift is -31 to +7 (negative for right shift) + // + // Assumptions: The following input ranges are assumed + // - quantize_scale>=0 (the usual range is (1<<30) to (1>>31)-1) + // - scaling is chosen so final scaled result fits in int32_t + // - input x is in the range -(1<<47) <= x < (1<<47) + TFLITE_DCHECK(quantized_multiplier >= 0); + TFLITE_DCHECK(shift >= -31 && shift < 8); + TFLITE_DCHECK(x >= -(static_cast(1) << 47) && + x < (static_cast(1) << 47)); + + const int32_t reduced_multiplier = + (quantized_multiplier < 0x7FFF0000) + ? ((quantized_multiplier + (1 << 15)) >> 16) + : 0x7FFF; + const int64_t total_shift = 15 - shift; + const int64_t round = static_cast(1) << (total_shift - 1); + int64_t result = x * static_cast(reduced_multiplier) + round; + result = result >> total_shift; + + TFLITE_DCHECK(result >= std::numeric_limits::min() && + result <= std::numeric_limits::max()); + return static_cast(result); +} +// Double-rounding MultiplyByQuantizedMultiplier +#else int32_t MultiplyByQuantizedMultiplier(int32_t x, int32_t quantized_multiplier, int shift) { using gemmlowp::RoundingDivideByPOT; @@ -51,5 +98,6 @@ int32_t MultiplyByQuantizedMultiplier(int64_t x, int32_t quantized_multiplier, int32_t result = x >> total_shift; return result; } +#endif // TFLITE_SINGLE_ROUNDING } // namespace tflite diff --git a/tensorflow/lite/kernels/internal/common.h b/tensorflow/lite/kernels/internal/common.h index 14d859917522f7..9761a8cc07a8ec 100644 --- a/tensorflow/lite/kernels/internal/common.h +++ b/tensorflow/lite/kernels/internal/common.h @@ -257,24 +257,14 @@ inline void BiasAndClamp(float clamp_min, float clamp_max, int bias_size, #endif } -// Single-rounding MultiplyByQuantizedMultiplier -#if TFLITE_SINGLE_ROUNDING -inline int32_t MultiplyByQuantizedMultiplier(int32_t x, - int32_t quantized_multiplier, - int shift) { - TFLITE_DCHECK(quantized_multiplier >= 0); - TFLITE_DCHECK(shift >= -31 && shift <= 30); - - const int64_t total_shift = 31 - shift; - const int64_t round = static_cast(1) << (total_shift - 1); - int64_t result = x * static_cast(quantized_multiplier) + round; - result = result >> total_shift; +TFLITE_NOINLINE int32_t MultiplyByQuantizedMultiplier( + int32_t x, int32_t quantized_multiplier, int shift); - TFLITE_DCHECK(result >= std::numeric_limits::min() && - result <= std::numeric_limits::max()); - return static_cast(result); -} +TFLITE_NOINLINE int32_t MultiplyByQuantizedMultiplier( + int64_t x, int32_t quantized_multiplier, int shift); +// Single-rounding MultiplyByQuantizedMultiplier +#if TFLITE_SINGLE_ROUNDING inline int32_t MultiplyByQuantizedMultiplierSmallerThanOneExp( int32_t x, int32_t quantized_multiplier, int shift) { TFLITE_DCHECK_LE(shift, 0); @@ -287,36 +277,6 @@ inline int32_t MultiplyByQuantizedMultiplierGreaterThanOne( return MultiplyByQuantizedMultiplier(x, quantized_multiplier, shift); } -inline int32_t MultiplyByQuantizedMultiplier(int64_t x, - int32_t quantized_multiplier, - int shift) { - // Inputs: - // - quantized_multiplier has fixed point at bit 31 - // - shift is -31 to +7 (negative for right shift) - // - // Assumptions: The following input ranges are assumed - // - quantize_scale>=0 (the usual range is (1<<30) to (1>>31)-1) - // - scaling is chosen so final scaled result fits in int32_t - // - input x is in the range -(1<<47) <= x < (1<<47) - TFLITE_DCHECK(quantized_multiplier >= 0); - TFLITE_DCHECK(shift >= -31 && shift < 8); - TFLITE_DCHECK(x >= -(static_cast(1) << 47) && - x < (static_cast(1) << 47)); - - const int32_t reduced_multiplier = - (quantized_multiplier < 0x7FFF0000) - ? ((quantized_multiplier + (1 << 15)) >> 16) - : 0x7FFF; - const int64_t total_shift = 15 - shift; - const int64_t round = static_cast(1) << (total_shift - 1); - int64_t result = x * static_cast(reduced_multiplier) + round; - result = result >> total_shift; - - TFLITE_DCHECK(result >= std::numeric_limits::min() && - result <= std::numeric_limits::max()); - return static_cast(result); -} - #ifdef USE_NEON inline int32x4x4_t MultiplyByQuantizedMultiplier4Rows( int32x4x4_t input_val, int32_t quantized_multiplier, int shift) { @@ -366,12 +326,6 @@ inline int32_t MultiplyByQuantizedMultiplierGreaterThanOne( quantized_multiplier); } -TFLITE_NOINLINE int32_t MultiplyByQuantizedMultiplier( - int32_t x, int32_t quantized_multiplier, int shift); - -TFLITE_NOINLINE int32_t MultiplyByQuantizedMultiplier( - int64_t x, int32_t quantized_multiplier, int shift); - #ifdef USE_NEON // Round uses ARM's rounding shift right. inline int32x4x4_t MultiplyByQuantizedMultiplier4Rows( diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc index ffdf4b6d4b1756..f8e7601f6b4f14 100644 --- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -1966,7 +1966,8 @@ void NeonSparseMatrixBatchVectorMultiplyAccumulate1x16( const int32_t* __restrict__ indices, int m_rows, int m_cols, const int8_t* __restrict__ vector, const int32_t* __restrict__ bias_vector, int n_batch, const int32_t input_offset, const int32_t output_multiplier, - const int32_t output_shift, const int32_t output_offset, + const int32_t output_shift, const int32_t* per_channel_scale, + const int32_t* per_channel_shift, const int32_t output_offset, const int32_t output_activation_min, const int32_t output_activation_max, int8_t* __restrict__ result) { constexpr int kBlockSize = kInt8ValuesPerNeonVector; @@ -2028,7 +2029,9 @@ void NeonSparseMatrixBatchVectorMultiplyAccumulate1x16( #endif const int32_t bias_value = bias_vector != nullptr ? bias_vector[row] : 0; acc = acc + bias_value + input_offset * matrix_row_sum; - acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift); + acc = MultiplyByQuantizedMultiplier( + acc, per_channel_scale ? per_channel_scale[row] : output_multiplier, + per_channel_shift ? per_channel_shift[row] : output_shift); acc += output_offset; result[batch * m_rows + row] = static_cast(ActivationFunctionWithMinMax( diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h index 096a6943ca4559..ebb5a2abad425a 100644 --- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h +++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h @@ -84,14 +84,15 @@ void SparseMatrixBatchVectorMultiplyAccumulate1x16( const int32_t* __restrict__ indices, int m_rows, int m_cols, const int8_t* __restrict__ vector, const int32_t* __restrict__ bias_vector, int n_batch, const int32_t input_offset, const int32_t output_multiplier, - const int32_t output_shift, const int32_t output_offset, + const int32_t output_shift, const int32_t* per_channel_scale, + const int32_t* per_channel_shift, const int32_t output_offset, const int32_t output_activation_min, const int32_t output_activation_max, int8_t* __restrict__ result) { NEON_OR_PORTABLE(SparseMatrixBatchVectorMultiplyAccumulate1x16, matrix, segments, indices, m_rows, m_cols, vector, bias_vector, n_batch, input_offset, output_multiplier, output_shift, - output_offset, output_activation_min, output_activation_max, - result); + per_channel_scale, per_channel_shift, output_offset, + output_activation_min, output_activation_max, result); } void SparseMatrixBatchVectorMultiplyAccumulate( diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h index 2e65529aa6191f..dd8a05f4a37f8e 100644 --- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h +++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h @@ -122,7 +122,8 @@ void NeonSparseMatrixBatchVectorMultiplyAccumulate1x16( const int32_t* __restrict__ indices, int m_rows, int m_cols, const int8_t* __restrict__ vector, const int32_t* __restrict__ bias_vector, int n_batch, const int32_t input_offset, const int32_t output_multiplier, - const int32_t output_shift, const int32_t output_offset, + int32_t output_shift, const int32_t* per_channel_scale, + const int32_t* per_channel_shift, int32_t output_offset, const int32_t output_activation_min, const int32_t output_activation_max, int8_t* __restrict__ result); diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_4bit_test.cc b/tensorflow/lite/kernels/internal/optimized/optimized_4bit_test.cc index 313ef85e75b990..a4db8457594d85 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_4bit_test.cc +++ b/tensorflow/lite/kernels/internal/optimized/optimized_4bit_test.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include #include +#include +#include #include #include @@ -37,10 +39,19 @@ struct TestPack { depth(depth), rows((src_rows + (width - 1)) & ~(width - 1)), cols((src_cols + (depth - 1)) & ~(depth - 1)), - packed_data_buffer(rows * cols + padding) {} + // Must be vector-aligned. + packed_data_buffer( + [=]() -> uint8_t* { + void* ptr; + if (posix_memalign(&ptr, 64, rows * cols + padding)) { + abort(); + } + return static_cast(ptr); + }(), + [](uint8_t* ptr) { free(ptr); }) {} void Prepack() { - packed_data = packed_data_buffer.data(); + packed_data = packed_data_buffer.get(); optimized_4bit::Prepack(packed_data, src_data.data(), rows, cols, src_rows, src_cols, width, depth); } @@ -64,7 +75,7 @@ struct TestPack { int rows; int cols; int padding = optimized_4bit::kDefaultAlignmentPadding; - std::vector packed_data_buffer; + std::unique_ptr> packed_data_buffer; }; class RunPackTests diff --git a/tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h b/tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h index df556fea3719e9..1dd8c3e5d4bcc4 100644 --- a/tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h +++ b/tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h @@ -16,6 +16,7 @@ limitations under the License. #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SPARSE_OPS_FULLY_CONNECTED_H_ #include +#include #include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/core/c/common.h" @@ -80,6 +81,7 @@ inline void FullyConnectedSparseWeight1x16Impl( const TfLiteSparsity& sparsity, const FullyConnectedParams& params, const RuntimeShape& input_shape, const int8_t* input_data, const RuntimeShape& weights_shape, const int8_t* weights_data, + const int32_t* per_channel_scale, const int32_t* per_channel_shift, const RuntimeShape& bias_shape, const int32_t* bias_data, const RuntimeShape& output_shape, int8_t* output_data, int thread_start, int thread_end, const CpuBackendContext& cpu_backend_context) { @@ -107,9 +109,9 @@ inline void FullyConnectedSparseWeight1x16Impl( tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate1x16( weights_data, w1_segments, w1_indices, weights_shape.Dims(0), weights_shape.Dims(1), input_data + thread_start * input_depth, bias_data, - batches, input_offset, output_multiplier, output_shift, output_offset, - output_activation_min, output_activation_max, - output_data + thread_start * output_depth); + batches, input_offset, output_multiplier, output_shift, per_channel_scale, + per_channel_shift, output_offset, output_activation_min, + output_activation_max, output_data + thread_start * output_depth); } inline void FullyConnectedSparseWeight1x4Impl( @@ -200,6 +202,7 @@ inline void FullyConnectedSparseWeight1x16( const TfLiteSparsity& sparsity, const FullyConnectedParams& params, const RuntimeShape& input_shape, const int8_t* input_data, const RuntimeShape& weights_shape, const int8_t* weights_data, + const int32_t* per_channel_scale, const int32_t* per_channel_shift, const RuntimeShape& bias_shape, const int32_t* bias_data, const RuntimeShape& output_shape, int8_t* output_data, CpuBackendContext* cpu_backend_context) { @@ -212,8 +215,8 @@ inline void FullyConnectedSparseWeight1x16( // TODO(b/220851507): Add multi-thread support for quantized sparse kernel. return FullyConnectedSparseWeight1x16Impl( sparsity, params, input_shape, input_data, weights_shape, weights_data, - bias_shape, bias_data, output_shape, output_data, 0, batches, - *cpu_backend_context); + per_channel_scale, per_channel_shift, bias_shape, bias_data, output_shape, + output_data, 0, batches, *cpu_backend_context); } // The multi-threaded kernel slices the workload along the batch dimension. If diff --git a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h index 896159c77b5010..4f313bd37a7c5b 100644 --- a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h +++ b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h @@ -89,14 +89,15 @@ void SparseMatrixBatchVectorMultiplyAccumulate1x16( const int32_t* __restrict__ indices, int m_rows, int m_cols, const int8_t* __restrict__ vector, const int32_t* __restrict__ bias_vector, int n_batch, const int32_t input_offset, const int32_t output_multiplier, - const int32_t output_shift, const int32_t output_offset, + const int32_t output_shift, const int32_t* per_channel_scale, + const int32_t* per_channel_shift, const int32_t output_offset, const int32_t output_activation_min, const int32_t output_activation_max, int8_t* __restrict__ result) { NEON_OR_PORTABLE(SparseMatrixBatchVectorMultiplyAccumulate1x16, matrix, segments, indices, m_rows, m_cols, vector, bias_vector, n_batch, input_offset, output_multiplier, output_shift, - output_offset, output_activation_min, output_activation_max, - result); + per_channel_scale, per_channel_shift, output_offset, + output_activation_min, output_activation_max, result); } void SparseMatrixBatchVectorMultiplyAccumulate( diff --git a/tensorflow/lite/kernels/internal/portable_tensor_utils.h b/tensorflow/lite/kernels/internal/portable_tensor_utils.h index 03bfdc8f41026e..d37fe6e4c89836 100644 --- a/tensorflow/lite/kernels/internal/portable_tensor_utils.h +++ b/tensorflow/lite/kernels/internal/portable_tensor_utils.h @@ -241,7 +241,8 @@ void SparseMatrixBatchVectorMultiplyAccumulate1x16( const int32_t* __restrict__ indices, int m_rows, int m_cols, const int8_t* __restrict__ vector, const int32_t* __restrict__ bias_vector, int n_batch, const int32_t input_offset, const int32_t output_multiplier, - const int32_t output_shift, const int32_t output_offset, + int32_t output_shift, const int32_t* per_channel_scale, + const int32_t* per_channel_shift, int32_t output_offset, const int32_t output_activation_min, const int32_t output_activation_max, int8_t* __restrict__ result); diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc index b519d8139f646b..7d40df42b33db2 100644 --- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc @@ -157,7 +157,7 @@ void PortableMatrixBatchVectorMultiplyAccumulate( *result += dotprod * batch_scaling_factor; ++result; } // for row - } // for batch + } // for batch } void PortableMatrixBatchVectorMultiplyAccumulate( @@ -200,7 +200,7 @@ void PortableMatrixBatchVectorMultiplyAccumulate( *result += dotprod * scale; ++result; } // for row - } // for batch + } // for batch } void PortableSparseMatrixBatchVectorMultiplyAccumulate1x4( @@ -232,7 +232,8 @@ void PortableSparseMatrixBatchVectorMultiplyAccumulate1x16( const int32_t* __restrict__ indices, int m_rows, int m_cols, const int8_t* __restrict__ vector, const int32_t* __restrict__ bias_vector, int n_batch, const int32_t input_offset, const int32_t output_multiplier, - const int32_t output_shift, const int32_t output_offset, + const int32_t output_shift, const int32_t* per_channel_scale, + const int32_t* per_channel_shift, const int32_t output_offset, const int32_t output_activation_min, const int32_t output_activation_max, int8_t* __restrict__ result) { const int kBlockSize = 16; @@ -252,8 +253,10 @@ void PortableSparseMatrixBatchVectorMultiplyAccumulate1x16( } } const int32_t bias_value = bias_vector != nullptr ? bias_vector[row] : 0; - dot_prod = MultiplyByQuantizedMultiplier(dot_prod + bias_value, - output_multiplier, output_shift); + dot_prod = MultiplyByQuantizedMultiplier( + dot_prod + bias_value, + per_channel_scale ? per_channel_scale[row] : output_multiplier, + per_channel_shift ? per_channel_shift[row] : output_shift); dot_prod += output_offset; result[batch * m_rows + row] = static_cast(ActivationFunctionWithMinMax( @@ -319,14 +322,14 @@ void PortableSparseMatrixBatchVectorMultiplyAccumulate( for (int c = 0; c < kBlockSize; c++) { dotprod += (*row_ptr++) * (*vector_block_ptr++); } // for block - } // for num_nonzero_blocks + } // for num_nonzero_blocks float scaling_factor = batch_scaling_factor; if (per_channel_scale) { scaling_factor *= per_channel_scale[row]; } result[batch * m_rows + row] += dotprod * scaling_factor; } // for row - } // for batch + } // for batch } template diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h index ab0185e5a850d8..7c623f71007166 100644 --- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h +++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h @@ -116,14 +116,16 @@ void SparseMatrixBatchVectorMultiplyAccumulate1x16( const int32_t* __restrict__ indices, int m_rows, int m_cols, const int8_t* __restrict__ vector, const int32_t* __restrict__ bias_vector, int n_batch, const int32_t input_offset, const int32_t output_multiplier, - const int32_t output_shift, const int32_t output_offset, + const int32_t output_shift, const int32_t* per_channel_scale, + const int32_t* per_channel_shift, const int32_t output_offset, const int32_t output_activation_min, const int32_t output_activation_max, int8_t* __restrict__ result) { PortableSparseMatrixBatchVectorMultiplyAccumulate1x16( matrix, segments, indices, m_rows, m_cols, vector, bias_vector, n_batch, - input_offset, output_multiplier, output_shift, output_offset, - output_activation_min, output_activation_max, result); + input_offset, output_multiplier, output_shift, per_channel_scale, + per_channel_shift, output_offset, output_activation_min, + output_activation_max, result); } void SparseMatrixBatchVectorMultiplyAccumulate( diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h index 6a1f18e7244c13..11765ec7379599 100644 --- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h +++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h @@ -92,7 +92,8 @@ void PortableSparseMatrixBatchVectorMultiplyAccumulate1x16( const int32_t* __restrict__ indices, int m_rows, int m_cols, const int8_t* __restrict__ vector, const int32_t* __restrict__ bias_vector, int n_batch, const int32_t input_offset, const int32_t output_multiplier, - const int32_t output_shift, const int32_t output_offset, + int32_t output_shift, const int32_t* per_channel_scale, + const int32_t* per_channel_shift, int32_t output_offset, const int32_t output_activation_min, const int32_t output_activation_max, int8_t* __restrict__ result); diff --git a/tensorflow/lite/kernels/internal/reference/softmax.h b/tensorflow/lite/kernels/internal/reference/softmax.h index c09a7eae8131b4..2930217b61f91d 100644 --- a/tensorflow/lite/kernels/internal/reference/softmax.h +++ b/tensorflow/lite/kernels/internal/reference/softmax.h @@ -115,6 +115,9 @@ inline void Softmax(const SoftmaxParams& params, FixedPoint0 shifted_scale = FixedPoint0::FromRaw(GetReciprocal( sum_of_exps.raw(), kAccumulationIntegerBits, &num_bits_over_unit)); + const int exponent = num_bits_over_unit + 31 - (sizeof(OutputT) * 8); + TFLITE_CHECK(0 <= exponent && exponent <= 31); + for (int c = 0; c < depth; ++c) { int32_t input_diff = static_cast(input_data[i * depth + c]) - max_in_row; @@ -127,8 +130,7 @@ inline void Softmax(const SoftmaxParams& params, FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8); int32_t unsat_output = gemmlowp::RoundingDivideByPOT( - (shifted_scale * exp_in_0).raw(), - num_bits_over_unit + 31 - (sizeof(OutputT) * 8)); + (shifted_scale * exp_in_0).raw(), exponent); const int32_t shifted_output = unsat_output + diff --git a/tensorflow/lite/kernels/internal/utils/sparsity_format_converter.cc b/tensorflow/lite/kernels/internal/utils/sparsity_format_converter.cc index eb02a3286c7d25..6f3b90d1783726 100644 --- a/tensorflow/lite/kernels/internal/utils/sparsity_format_converter.cc +++ b/tensorflow/lite/kernels/internal/utils/sparsity_format_converter.cc @@ -364,7 +364,7 @@ TfLiteStatus FormatConverter::SparseToDense(const T* src_data, TfLiteContext* context) { if (dest_size != dense_size_) { TF_LITE_MAYBE_KERNEL_LOG( - context, "unexpected buffer size for densified data, expected %lld.\n", + context, "unexpected buffer size for densified data, expected %zu.\n", dense_size_); return kTfLiteError; } diff --git a/tensorflow/lite/kernels/mul.cc b/tensorflow/lite/kernels/mul.cc index 8a47d5dc60693a..67b0cfd112479b 100644 --- a/tensorflow/lite/kernels/mul.cc +++ b/tensorflow/lite/kernels/mul.cc @@ -398,9 +398,10 @@ TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node, OpData* data, context, EvalQuantized(context, node, params, data, input1, input2, output)); } else { - TF_LITE_KERNEL_LOG( - context, "Mul only supports FLOAT32, COMPLEX32, INT8, INT16,", - " INT32, INT64 and quantized UINT8 now, got %d.", output->type); + TF_LITE_KERNEL_LOG(context, + "Mul only supports FLOAT32, COMPLEX32, INT8, INT16," + " INT32, INT64 and quantized UINT8 now, got %d.", + output->type); return kTfLiteError; } diff --git a/tensorflow/lite/kernels/parse_example/parse_example.cc b/tensorflow/lite/kernels/parse_example/parse_example.cc index a0006605623519..2fdc06fa5c3316 100644 --- a/tensorflow/lite/kernels/parse_example/parse_example.cc +++ b/tensorflow/lite/kernels/parse_example/parse_example.cc @@ -467,7 +467,7 @@ Status FastParseExampleLite( std::vector varlen_dense_buffers(config.dense.size()); Status status_of_minibatch; for (size_t e = 0; e < count; ++e) { - Status status_of_minibatch = FastParseSerializedExample( + status_of_minibatch = FastParseSerializedExample( GetString(serialized, e), (!example_names.empty() ? example_names[e] : ""), e, config, quick_filter, quick_filter_size, config_index, config_index_size, diff --git a/tensorflow/lite/kernels/stablehlo_reduce_window.cc b/tensorflow/lite/kernels/stablehlo_reduce_window.cc index 78385506c8d768..538b50fd8a7a6e 100644 --- a/tensorflow/lite/kernels/stablehlo_reduce_window.cc +++ b/tensorflow/lite/kernels/stablehlo_reduce_window.cc @@ -687,7 +687,7 @@ struct StablehloData : public OpData { if (execution_plan.size() != 1) { TF_LITE_KERNEL_LOG(context, "Only one kernel is allowed within " - "stablehlo.reduce_window body. (%d) kernels found.\n", + "stablehlo.reduce_window body. (%zu) kernels found.\n", execution_plan.size()); return TfLiteReduceWindowFunctionUnsupported; } diff --git a/tensorflow/lite/kernels/stablehlo_scatter.cc b/tensorflow/lite/kernels/stablehlo_scatter.cc index 9885e8c2d94858..be67dc39e911ac 100644 --- a/tensorflow/lite/kernels/stablehlo_scatter.cc +++ b/tensorflow/lite/kernels/stablehlo_scatter.cc @@ -135,7 +135,7 @@ static TfLiteStatus GetComputationType(const Subgraph* computation_subgraph, if (computation_subgraph->execution_plan().size() > 1) { TF_LITE_KERNEL_LOG(context, "Only one kernel allowed withing the stablehlo region. " - "(%i) kernels found.\n", + "(%zu) kernels found.\n", computation_subgraph->execution_plan().size()); return kTfLiteError; } diff --git a/tensorflow/lite/kernels/subgraph_test_util.cc b/tensorflow/lite/kernels/subgraph_test_util.cc index 2bd0992909a1b8..0521ae31889fa0 100644 --- a/tensorflow/lite/kernels/subgraph_test_util.cc +++ b/tensorflow/lite/kernels/subgraph_test_util.cc @@ -176,6 +176,13 @@ void AddDynamicUpdateSliceNode(Subgraph* subgraph, int input0, int input1, } } // namespace +void Setup1DTensor(Subgraph* subgraph, int tensor_index, TfLiteType type) { + int dim = 1; + ASSERT_EQ(subgraph->SetTensorParametersReadWrite(tensor_index, type, "", 1, + &dim, {}, false), + kTfLiteOk); +} + void SetupTensor(Subgraph* subgraph, int tensor_index, TfLiteType type) { ASSERT_EQ(subgraph->SetTensorParametersReadWrite(tensor_index, type, "", 0, nullptr, {}, false), @@ -275,7 +282,7 @@ void SubgraphBuilder::BuildOutputNotConsumedSubgraph(Subgraph& subgraph) { ASSERT_EQ(subgraph.SetInputs({kInput0, kInput1, kInput2}), kTfLiteOk); ASSERT_EQ(subgraph.SetOutputs({kOutput0, kOutput1, kConstRhs}), kTfLiteOk); for (int i = 0; i < kTensorCount; ++i) { - SetupTensor(&subgraph, i, kTfLiteInt32); + Setup1DTensor(&subgraph, i, kTfLiteInt32); } // kInput0 --> +---+ diff --git a/tensorflow/lite/kernels/transpose_conv.cc b/tensorflow/lite/kernels/transpose_conv.cc index ebcc3011ddcd01..93c6df28890c9c 100644 --- a/tensorflow/lite/kernels/transpose_conv.cc +++ b/tensorflow/lite/kernels/transpose_conv.cc @@ -655,6 +655,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { if (IsDynamicTensor(output)) { TF_LITE_ENSURE_OK(context, ResizeTensor(context, output_shape, output)); } + TF_LITE_ENSURE_EQ(context, SizeOfDimension(input, 0), + SizeOfDimension(output, 0)); if (data->has_col2im && IsDynamicTensor(col2im)) { TF_LITE_ENSURE_OK(context, ResizeCol2ImTensor(context, output_shape, weights, input, col2im)); diff --git a/tensorflow/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/lite/kernels/unidirectional_sequence_rnn.cc index d6b7daab0cff7d..7c640f63444517 100644 --- a/tensorflow/lite/kernels/unidirectional_sequence_rnn.cc +++ b/tensorflow/lite/kernels/unidirectional_sequence_rnn.cc @@ -405,7 +405,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { accum_scratch, row_sums, &op_data->compute_row_sums); } default: - TF_LITE_KERNEL_LOG(context, "Type %d not currently supported.", + TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.", TfLiteTypeGetName(input_weights->type)); return kTfLiteError; } diff --git a/tensorflow/lite/lib_package/BUILD b/tensorflow/lite/lib_package/BUILD index 3c1b8d3d45f2bb..791ee3cdb62f4a 100644 --- a/tensorflow/lite/lib_package/BUILD +++ b/tensorflow/lite/lib_package/BUILD @@ -5,7 +5,7 @@ package(default_visibility = ["//visibility:private"]) genrule( name = "clicenses_generate", srcs = [ - "//third_party/eigen3:LICENSE", + # copybara:uncomment "//third_party/eigen3:LICENSE", "@arm_neon_2_x86_sse//:LICENSE", "@farmhash_archive//:COPYING", "@gemmlowp//:LICENSE", diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD index 4a26931da922ee..91f9e5ddb67dcf 100644 --- a/tensorflow/lite/python/BUILD +++ b/tensorflow/lite/python/BUILD @@ -9,6 +9,7 @@ package( default_visibility = [ "//tensorflow:__subpackages__", "//tensorflow:internal", + "//third_party/odml/model_customization/quantization:__subpackages__", "//third_party/py/tensorflow_federated:__subpackages__", "//third_party/tflite_micro:__subpackages__", ], @@ -453,6 +454,7 @@ py_strict_library( ], srcs_version = "PY3", deps = [ + "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib_py", "//tensorflow/python:_pywrap_toco_api", "//tensorflow/python:pywrap_tensorflow", ], diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py index cfaff27a849199..6ae3a1724f5202 100644 --- a/tensorflow/lite/python/convert.py +++ b/tensorflow/lite/python/convert.py @@ -230,6 +230,7 @@ def mlir_quantize( denylisted_ops=None, denylisted_nodes=None, enable_variable_quantization=False, + disable_per_channel_for_dense_layers=False, ): """Quantize `input_data_str` with calibration results. @@ -255,6 +256,9 @@ def mlir_quantize( enable_variable_quantization: Experimental. Subject to change. Bool indicating whether to enable quantization of the residual variables remaining after the variable freezing pass. + disable_per_channel_for_dense_layers: Bool indicating whether to do + per-channel or per-tensor quantization in Fully Connected layers. Default + value is False meaning per-channel quantization is enabled. Returns: Quantized model in serialized form (e.g. a TFLITE model) with floating-point @@ -272,6 +276,7 @@ def mlir_quantize( denylisted_ops, denylisted_nodes, enable_variable_quantization, + disable_per_channel_for_dense_layers, ) @@ -419,15 +424,12 @@ def _run_deprecated_conversion_binary( output_filename: str = None try: # Build all input files - with _tempfile.NamedTemporaryFile( - delete=False - ) as fp_conversion, _tempfile.NamedTemporaryFile( - delete=False - ) as fp_model, _tempfile.NamedTemporaryFile( - delete=False - ) as fp_input, _tempfile.NamedTemporaryFile( - delete=False - ) as fp_debug: + with ( + _tempfile.NamedTemporaryFile(delete=False) as fp_conversion, + _tempfile.NamedTemporaryFile(delete=False) as fp_model, + _tempfile.NamedTemporaryFile(delete=False) as fp_input, + _tempfile.NamedTemporaryFile(delete=False) as fp_debug, + ): conversion_filename = fp_conversion.name input_filename = fp_input.name model_filename = fp_model.name @@ -502,7 +504,7 @@ def build_model_flags( saved_model_version=0, saved_model_tags=None, saved_model_exported_names=None, - **_ + **_, ): """Builds the model flags object from params. @@ -592,7 +594,8 @@ def build_conversion_flags( use_buffer_offset=False, reduce_type_precision=False, qdq_conversion_mode=None, - **_ + disable_per_channel_quantization_for_dense_layers=False, + **_, ): """Builds protocol buffer describing a conversion of a model. @@ -718,6 +721,9 @@ def build_conversion_flags( This could have side effects e.g. reduced flatbuffer size. qdq_conversion_mode: If set, assume input model is a quantized model represented with QDQ ops and convert to quantized kernels. + disable_per_channel_quantization_for_dense_layers: If set, disables per + channel end enables per tensor integer quantization for weights in Dense + layers. The flag works only for integer quantized model. Returns: conversion_flags: protocol buffer describing the conversion process. @@ -835,6 +841,9 @@ def build_conversion_flags( conversion_flags.reduce_type_precision = reduce_type_precision if qdq_conversion_mode is not None: conversion_flags.qdq_conversion_mode = qdq_conversion_mode + conversion_flags.disable_per_channel_quantization_for_dense_layers = ( + disable_per_channel_quantization_for_dense_layers + ) return conversion_flags @@ -846,7 +855,7 @@ def convert_graphdef_with_arrays( input_arrays_with_shape, output_arrays, control_output_arrays, - **kwargs + **kwargs, ): """Convert a frozen GraphDef that can't be loaded in TF. diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 8f793938b834c8..e211a0ce40f7d5 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -31,7 +31,7 @@ from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as rd from tensorflow.core.framework import graph_pb2 as _graph_pb2 from tensorflow.lite.experimental.microfrontend.python.ops import audio_microfrontend_op # pylint: disable=unused-import -from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metdata_fb +from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metadata_fb from tensorflow.lite.python import lite_constants as constants from tensorflow.lite.python.convert import convert_graphdef as _convert_graphdef from tensorflow.lite.python.convert import convert_graphdef_with_arrays as _convert_graphdef_with_arrays @@ -151,7 +151,6 @@ class Optimize(enum.Enum): # The flag can be used alone to optimize float32 models with sparse weights. # It can also be used together with the DEFAULT optimization mode to optimize # quantized models with sparse weights. - # TODO(b/161560631): Add log message when this optimization is applied. EXPERIMENTAL_SPARSITY = "EXPERIMENTAL_SPARSITY" def __str__(self): @@ -228,8 +227,6 @@ def __init__( # Hint for the supported accumulation type used for inference. Typically # used for fp16 post-training quantization, where some models can use fp16 # accumulators instead of the typical fp32 type. - # TODO(b/188185962): Provide full API and authoring support for - # reduced precision accumulation types. self._experimental_supported_accumulation_type = None @@ -605,7 +602,7 @@ class TFLiteConverterBase: # Stores the original model type temporarily to transmit the information # from the factory class methods to TFLiteConverterBase init function. - _original_model_type = conversion_metdata_fb.ModelType.NONE + _original_model_type = conversion_metadata_fb.ModelType.NONE def __init__(self): self.optimizations = set() @@ -642,9 +639,9 @@ def __init__(self): self.experimental_use_stablehlo_quantizer = False # Initializes conversion metadata. self.exclude_conversion_metadata = False - self._metadata = conversion_metdata_fb.ConversionMetadataT() - self._metadata.environment = conversion_metdata_fb.EnvironmentT() - self._metadata.options = conversion_metdata_fb.ConversionOptionsT() + self._metadata = conversion_metadata_fb.ConversionMetadataT() + self._metadata.environment = conversion_metadata_fb.EnvironmentT() + self._metadata.options = conversion_metadata_fb.ConversionOptionsT() self._metadata.environment.tensorflowVersion = versions.__version__ self._metadata.environment.modelType = self._get_original_model_type() self._experimental_enable_dynamic_update_slice = False @@ -666,6 +663,7 @@ def __init__(self): self._experimental_use_buffer_offset = False self._experimental_reduce_type_precision = False self._experimental_qdq_conversion_mode = None + self._experimental_disable_per_channel_quantization_for_dense_layers = False # Debug parameters self.ir_dump_dir = None @@ -745,14 +743,13 @@ def _quantize( elif self.experimental_new_quantizer and ( activations_type != _dtypes.int16 ): - # TODO(b/175659372): remove the activations_type restriction and enable - # it for all the activation types. return _mlir_quantize( calibrated, self._experimental_disable_per_channel, input_data_type=input_type, output_data_type=output_type, enable_variable_quantization=enable_variable_quantization, + disable_per_channel_for_dense_layers=self._experimental_disable_per_channel_quantization_for_dense_layers, ) else: return calibrate_quantize.calibrate_and_quantize( @@ -818,6 +815,9 @@ def _get_base_converter_args(self): "reduce_type_precision": self._experimental_reduce_type_precision, "use_stablehlo_quantizer": self.experimental_use_stablehlo_quantizer, "qdq_conversion_mode": self._experimental_qdq_conversion_mode, + "disable_per_channel_quantization_for_dense_layers": ( + self._experimental_disable_per_channel_quantization_for_dense_layers + ), } if self.saved_model_dir: @@ -836,7 +836,6 @@ def _get_base_converter_args(self): " path will be a no-op." ) - # TODO: b/307626169 - Integrate StableHLO Quantizer. if self.experimental_use_stablehlo_quantizer: if Optimize.DEFAULT in self.optimizations and self.representative_dataset: if len(self._saved_model_exported_names) != 1: @@ -862,7 +861,8 @@ def _get_base_converter_args(self): qc.RepresentativeDatasetConfig( tf_record=qc.TfRecordFile(path=tfrecord_file_path) ) - ] + ], + enable_per_channel_quantized_weight=True, ) ) @@ -936,7 +936,7 @@ def _increase_conversion_success_metric(self): @classmethod def _set_original_model_type(cls, model_type): """Stores the original model type.""" - if model_type == conversion_metdata_fb.ModelType.NONE: + if model_type == conversion_metadata_fb.ModelType.NONE: raise ValueError("The original model type should be specified.") cls._original_model_type = model_type @@ -944,7 +944,7 @@ def _get_original_model_type(self): """One-time getter to return original model type and set it to NONE.""" model_type = TFLiteConverterBase._original_model_type TFLiteConverterBase._original_model_type = ( - conversion_metdata_fb.ModelType.NONE + conversion_metadata_fb.ModelType.NONE ) return model_type @@ -1030,27 +1030,27 @@ def format_param(param): if quant_mode.is_post_training_float16_quantization(): self._metadata.options.modelOptimizationModes.append( - conversion_metdata_fb.ModelOptimizationMode.PTQ_FLOAT16 + conversion_metadata_fb.ModelOptimizationMode.PTQ_FLOAT16 ) if quant_mode.is_post_training_dynamic_range_quantization(): self._metadata.options.modelOptimizationModes.append( - conversion_metdata_fb.ModelOptimizationMode.PTQ_DYNAMIC_RANGE + conversion_metadata_fb.ModelOptimizationMode.PTQ_DYNAMIC_RANGE ) if quant_mode.is_post_training_int8_quantization(): self._metadata.options.modelOptimizationModes.append( - conversion_metdata_fb.ModelOptimizationMode.PTQ_FULL_INTEGER + conversion_metadata_fb.ModelOptimizationMode.PTQ_FULL_INTEGER ) if quant_mode.is_post_training_int16x8_quantization(): self._metadata.options.modelOptimizationModes.append( - conversion_metdata_fb.ModelOptimizationMode.PTQ_INT16 + conversion_metadata_fb.ModelOptimizationMode.PTQ_INT16 ) if quant_mode.is_quantization_aware_training(): self._metadata.options.modelOptimizationModes.append( - conversion_metdata_fb.ModelOptimizationMode.QUANTIZATION_AWARE_TRAINING + conversion_metadata_fb.ModelOptimizationMode.QUANTIZATION_AWARE_TRAINING ) def _set_conversion_latency_metric(self, value): @@ -1103,7 +1103,6 @@ def _optimize_tflite_model(self, model, quant_mode, quant_io=True): model = _mlir_sparsify(model) if not self._experimental_use_buffer_offset: - # TODO(b/287476027): move this logic into c++ try: model_object = flatbuffer_utils.convert_bytearray_to_object(model) if _check_model_use_buffer_offset(model_object): @@ -1489,7 +1488,6 @@ def convert(self): ) # We make sure to clear the saved_model_dir as there is some # legacy code down in the caller that checks this. - # TODO(b/162537905): Clean these indirect dependencies. self.saved_model_dir = None return super(TFLiteSavedModelConverterV2, self).convert( graph_def, input_tensors, output_tensors @@ -1693,8 +1691,6 @@ def _freeze_concrete_function(self): Raises: ValueError: none or multiple ConcreteFunctions provided. """ - # TODO(b/130297984): Add support for converting multiple function. - if len(self._funcs) == 0: # pylint: disable=g-explicit-length-test raise ValueError("No ConcreteFunction is specified.") @@ -2058,7 +2054,7 @@ def from_concrete_functions(cls, funcs, trackable_obj=None): """ # pylint: disable=protected-access TFLiteConverterBase._set_original_model_type( - conversion_metdata_fb.ModelType.TF_CONCRETE_FUNCTIONS + conversion_metadata_fb.ModelType.TF_CONCRETE_FUNCTIONS ) # pylint: enable=protected-access if trackable_obj is None: @@ -2101,7 +2097,7 @@ def from_saved_model(cls, saved_model_dir, signature_keys=None, tags=None): """ # pylint: disable=protected-access TFLiteConverterBase._set_original_model_type( - conversion_metdata_fb.ModelType.TF_SAVED_MODEL + conversion_metadata_fb.ModelType.TF_SAVED_MODEL ) # pylint: enable=protected-access # When run without eager enabled, this will return the legacy @@ -2176,7 +2172,7 @@ def from_keras_model(cls, model): """ # pylint: disable=protected-access TFLiteConverterBase._set_original_model_type( - conversion_metdata_fb.ModelType.KERAS_MODEL + conversion_metadata_fb.ModelType.KERAS_MODEL ) # pylint: enable=protected-access return TFLiteKerasModelConverterV2(model) @@ -2204,7 +2200,7 @@ def experimental_from_jax(cls, serving_funcs, inputs): """ # pylint: disable=protected-access TFLiteConverterBase._set_original_model_type( - conversion_metdata_fb.ModelType.JAX + conversion_metadata_fb.ModelType.JAX ) # pylint: enable=protected-access return TFLiteJaxConverterV2(serving_funcs, inputs) @@ -3019,7 +3015,7 @@ def from_session(cls, sess, input_tensors, output_tensors): """ # pylint: disable=protected-access TFLiteConverterBase._set_original_model_type( - conversion_metdata_fb.ModelType.TF_SESSION + conversion_metadata_fb.ModelType.TF_SESSION ) # pylint: enable=protected-access graph_def = _freeze_graph(sess, input_tensors, output_tensors) @@ -3059,7 +3055,7 @@ def from_frozen_graph( """ # pylint: disable=protected-access TFLiteConverterBase._set_original_model_type( - conversion_metdata_fb.ModelType.TF_GRAPH_DEF + conversion_metadata_fb.ModelType.TF_GRAPH_DEF ) # pylint: enable=protected-access with _ops.Graph().as_default(): @@ -3165,7 +3161,7 @@ def from_saved_model( """ # pylint: disable=protected-access TFLiteConverterBase._set_original_model_type( - conversion_metdata_fb.ModelType.TF_SAVED_MODEL + conversion_metadata_fb.ModelType.TF_SAVED_MODEL ) # pylint: enable=protected-access if tag_set is None: @@ -3224,7 +3220,7 @@ def from_keras_model_file( """ # pylint: disable=protected-access TFLiteConverterBase._set_original_model_type( - conversion_metdata_fb.ModelType.KERAS_MODEL + conversion_metadata_fb.ModelType.KERAS_MODEL ) # pylint: enable=protected-access return TFLiteKerasModelConverter( diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py index ee41ff0330b0da..74d8c2940d56ed 100644 --- a/tensorflow/lite/python/lite_v2_test.py +++ b/tensorflow/lite/python/lite_v2_test.py @@ -34,21 +34,15 @@ from tensorflow.compiler.mlir.quantization.stablehlo import quantization_options_pb2 as quant_opts_pb2 from tensorflow.lite.python import conversion_metadata_schema_py_generated as metadata_fb from tensorflow.lite.python import convert +from tensorflow.lite.python import interpreter from tensorflow.lite.python import lite from tensorflow.lite.python import lite_v2_test_util from tensorflow.lite.python import schema_py_generated as schema_fb from tensorflow.lite.python import test_util as tflite_test_util from tensorflow.lite.python import util -from tensorflow.lite.python.convert import mlir_quantize -from tensorflow.lite.python.interpreter import Interpreter -from tensorflow.lite.python.interpreter import InterpreterWithCustomOps -from tensorflow.lite.python.interpreter import OpResolverType from tensorflow.lite.python.testdata import _pywrap_test_registerer as test_registerer from tensorflow.lite.python.testdata import double_op -from tensorflow.lite.python.util import get_conversion_metadata -# TODO(b/175659372): We should support 16x8 mode in the mlir quantizer -# from tensorflow.lite.toco import types_pb2 as _types_pb2 -from tensorflow.lite.tools.flatbuffer_utils import convert_bytearray_to_object as _convert_bytearray_to_object +from tensorflow.lite.tools import flatbuffer_utils from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util @@ -58,10 +52,10 @@ from tensorflow.python.ops import rnn from tensorflow.python.platform import resource_loader from tensorflow.python.platform import test +from tensorflow.python.saved_model import loader_impl +from tensorflow.python.saved_model import save from tensorflow.python.saved_model import save_options from tensorflow.python.saved_model import saved_model -from tensorflow.python.saved_model.loader_impl import parse_saved_model -from tensorflow.python.saved_model.save import save from tensorflow.python.trackable import autotrackable # Type alias for preset quantization method protobuf enums. @@ -166,9 +160,9 @@ def __call__(self, x): [str(x) for x in range(11)], shape=(11,), dtype=tf.dtypes.string ) # Check values from converted model. - interpreter = tf.lite.Interpreter(model_content=tflite_model) - interpreter.allocate_tensors() - my_signature = interpreter.get_signature_runner() + interp = interpreter.Interpreter(model_content=tflite_model) + interp.allocate_tensors() + my_signature = interp.get_signature_runner() with self.assertRaises(ValueError) as error: _ = my_signature(x=input_data) @@ -233,13 +227,13 @@ def testConvertMultipleFunctions(self): tflite_model = converter.convert() # Check signatures are valid from converted model. - interpreter = Interpreter(model_content=tflite_model) - signature_defs = interpreter.get_signature_list() + interp = interpreter.Interpreter(model_content=tflite_model) + signature_defs = interp.get_signature_list() # Verify the SignatureDef structure returned is as expected. - self.assertEqual(len(signature_defs), 2) + self.assertLen(signature_defs, 2) self.assertEqual(list(signature_defs.keys()), ['add', 'sub']) - self.assertEqual(len(signature_defs.values()), 2) + self.assertLen(signature_defs.values(), 2) self.assertEqual(list(signature_defs['add'].keys()), ['inputs', 'outputs']) self.assertCountEqual(signature_defs['add']['inputs'], ['x']) self.assertEqual(list(signature_defs['add']['outputs']), ['output_0']) @@ -248,21 +242,21 @@ def testConvertMultipleFunctions(self): self.assertEqual(list(signature_defs['sub']['outputs']), ['output_0']) # Verify the Signature runner executions. - add_signature_runner = interpreter.get_signature_runner('add') + add_signature_runner = interp.get_signature_runner('add') add_output = add_signature_runner(x=input_data) self.assertEqual(add_output['output_0'], 3) input_details = add_signature_runner.get_input_details() - self.assertEqual(1, len(input_details)) + self.assertLen(input_details, 1) self.assertEqual('add_x:0', input_details['x']['name']) self.assertEqual(np.float32, input_details['x']['dtype']) self.assertTrue(([1] == input_details['x']['shape']).all()) self.assertEqual((0.0, 0), input_details['x']['quantization']) - sub_signature_runner = interpreter.get_signature_runner('sub') + sub_signature_runner = interp.get_signature_runner('sub') sub_output = sub_signature_runner(x=input_data) self.assertEqual(sub_output['output_0'], -2) output_details = sub_signature_runner.get_output_details() - self.assertEqual(1, len(output_details)) + self.assertLen(output_details, 1) self.assertEqual( 'StatefulPartitionedCall_1:0', output_details['output_0']['name'] ) @@ -271,7 +265,7 @@ def testConvertMultipleFunctions(self): self.assertEqual((0.0, 0), output_details['output_0']['quantization']) # Check the conversion metadata. - metadata = get_conversion_metadata(tflite_model) + metadata = util.get_conversion_metadata(tflite_model) self.assertIsNotNone(metadata) self.assertEqual(metadata.environment.apiVersion, 2) self.assertEqual( @@ -330,7 +324,7 @@ def testPostTrainingCalibrateAndQuantize(self, mlir_quantizer): quantized_tflite_model = quantized_converter.convert() self.assertIsNotNone(quantized_tflite_model) # Check the conversion metadata. - metadata = get_conversion_metadata(quantized_tflite_model) + metadata = util.get_conversion_metadata(quantized_tflite_model) self.assertIsNotNone(metadata) self.assertEqual( metadata.environment.tensorflowVersion.decode('utf-8'), @@ -350,12 +344,12 @@ def testPostTrainingCalibrateAndQuantize(self, mlir_quantizer): ) # The default input and output types should be float. - interpreter = Interpreter(model_content=quantized_tflite_model) - interpreter.allocate_tensors() - input_details = interpreter.get_input_details() + interp = interpreter.Interpreter(model_content=quantized_tflite_model) + interp.allocate_tensors() + input_details = interp.get_input_details() self.assertLen(input_details, 1) self.assertEqual(np.float32, input_details[0]['dtype']) - output_details = interpreter.get_output_details() + output_details = interp.get_output_details() self.assertLen(output_details, 1) self.assertEqual(np.float32, output_details[0]['dtype']) @@ -415,13 +409,13 @@ def testQuantizationRemovesQDQsForFloatIOInQAT(self): # Because assertions on the model later, we opt out applying default TFLite # delegates (i.e. the XNNPACK delegate). - interpreter = Interpreter( + interp = interpreter.Interpreter( model_content=quantized_model, - experimental_op_resolver_type=OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES, + experimental_op_resolver_type=interpreter.OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES, ) - interpreter.allocate_tensors() + interp.allocate_tensors() # The model should have LOGISTIC op, instead of DEQUANTIZE op. - op_details = interpreter._get_ops_details() + op_details = interp._get_ops_details() self.assertEqual(op_details[len(op_details) - 1]['op_name'], 'LOGISTIC') @parameterized.named_parameters( @@ -440,13 +434,13 @@ def testQuantizationRemovesQDQsForFloatIO(self, mlir_quantizer): # Because assertions on the model later, we opt out applying default TFLite # delegates (i.e. the XNNPACK delegate). - interpreter = Interpreter( + interp = interpreter.Interpreter( model_content=quantized_model, - experimental_op_resolver_type=OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES, + experimental_op_resolver_type=interpreter.OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES, ) - interpreter.allocate_tensors() + interp.allocate_tensors() # The model should have only one sqrt op. - op_details = interpreter._get_ops_details() + op_details = interp._get_ops_details() self.assertLen(op_details, 1) self.assertEqual(op_details[0]['op_name'], 'SQRT') @@ -498,7 +492,7 @@ def testIntegerQuantization( quantized_tflite_model = quantized_converter.convert() self.assertIsNotNone(quantized_tflite_model) # Check the conversion metadata. - metadata = get_conversion_metadata(quantized_tflite_model) + metadata = util.get_conversion_metadata(quantized_tflite_model) self.assertIsNotNone(metadata) expected_opt_options = [metadata_fb.ModelOptimizationMode.PTQ_FULL_INTEGER] if is_int16_quantize: @@ -507,14 +501,14 @@ def testIntegerQuantization( expected_opt_options, metadata.options.modelOptimizationModes ) - interpreter = Interpreter(model_content=quantized_tflite_model) - interpreter.allocate_tensors() - input_details = interpreter.get_input_details() + interp = interpreter.Interpreter(model_content=quantized_tflite_model) + interp.allocate_tensors() + input_details = interp.get_input_details() self.assertLen(input_details, 1) self.assertEqual( inference_input_output_type.as_numpy_dtype, input_details[0]['dtype'] ) - output_details = interpreter.get_output_details() + output_details = interp.get_output_details() self.assertLen(output_details, 1) self.assertEqual( inference_input_output_type.as_numpy_dtype, output_details[0]['dtype'] @@ -523,12 +517,8 @@ def testIntegerQuantization( # Ensure that the quantized tflite model is smaller. self.assertLess(len(quantized_tflite_model), len(tflite_model)) - @parameterized.named_parameters( - ('_INT16Quantize_INT8InputOutput', True, dtypes.int8) - ) - def testInvalidIntegerQuantization( - self, is_int16_quantize, inference_input_output_type - ): + @parameterized.named_parameters(('_INT16Quantize_INT8InputOutput', True)) + def testInvalidIntegerQuantization(self, is_int16_quantize): root, func, calibration_gen = self._getIntegerQuantizeModel() # Convert quantized model. @@ -563,41 +553,29 @@ def testCalibrateAndQuantizeBuiltinInt16(self): self.assertIsNotNone(float_tflite_model) converter = lite.TFLiteConverterV2.from_concrete_functions([func], root) - # TODO(b/156309549): We should add INT16 to the builtin types. converter.optimizations = [lite.Optimize.DEFAULT] converter.target_spec.supported_ops = [ lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8 ] converter.representative_dataset = calibration_gen - # TODO(b/175659372): We should support 16x8 mode in the mlir quantizer - # converter._experimental_calibrate_only = True - # calibrated_tflite = converter.convert() - # quantized_tflite_model = mlir_quantize( - # calibrated_tflite, inference_type=_types_pb2.QUANTIZED_INT16) quantized_tflite_model = converter.convert() self.assertIsNotNone(quantized_tflite_model) # The default input and output types should be float. - interpreter = Interpreter(model_content=quantized_tflite_model) - interpreter.allocate_tensors() - input_details = interpreter.get_input_details() + interp = interpreter.Interpreter(model_content=quantized_tflite_model) + interp.allocate_tensors() + input_details = interp.get_input_details() self.assertLen(input_details, 1) self.assertEqual(np.float32, input_details[0]['dtype']) - output_details = interpreter.get_output_details() + output_details = interp.get_output_details() self.assertLen(output_details, 1) self.assertEqual(np.float32, output_details[0]['dtype']) # The weights tensor should be quantized to 8 bits, # the bias tensor should be 32 bits to utilize optimized kernels, # and the activations should be 16 bits. - tensor_details = interpreter.get_tensor_details() - # TODO(b/175659372): The old quantizer yields a 64 bit bias and a - # slightly different tensor order than the new one. - # self.assertEqual(np.int8, tensor_details[1]['dtype']) - # self.assertEqual(np.int32, tensor_details[0]['dtype']) - # self.assertEqual(np.int16, tensor_details[2]['dtype']) - # self.assertEqual(np.int16, tensor_details[3]['dtype']) + tensor_details = interp.get_tensor_details() self.assertEqual(np.int8, tensor_details[2]['dtype']) self.assertEqual(np.int64, tensor_details[1]['dtype']) self.assertEqual(np.int16, tensor_details[0]['dtype']) @@ -618,8 +596,8 @@ def testSignatureDefs(self): # Check values from converted model. expected_value = add_func(input_data) - interpreter = Interpreter(model_content=tflite_model) - signature_defs = interpreter.get_signature_list() + interp = interpreter.Interpreter(model_content=tflite_model) + signature_defs = interp.get_signature_list() results = self._evaluateTFLiteModelUsingSignatureDef( tflite_model, 'serving_default', {'x': input_data} ) @@ -631,9 +609,9 @@ def testSignatureDefs(self): ) # Verify the SignatureDef structure returned is as expected. - self.assertEqual(len(signature_defs), 1) + self.assertLen(signature_defs, 1) self.assertEqual(list(signature_defs.keys()), ['serving_default']) - self.assertEqual(len(signature_defs.values()), 1) + self.assertLen(signature_defs.values(), 1) self.assertEqual( list(signature_defs['serving_default'].keys()), ['inputs', 'outputs'] ) @@ -656,10 +634,10 @@ def testNoSignatureDefsWhenTrackingObjIsNone(self): tflite_model = converter.convert() # Check values from converted model. - interpreter = Interpreter(model_content=tflite_model) - signature_defs = interpreter.get_signature_list() + interp = interpreter.Interpreter(model_content=tflite_model) + signature_defs = interp.get_signature_list() # Verify that there is no SignatureDef structure found. - self.assertEqual(len(signature_defs), 0) + self.assertEmpty(signature_defs) @test_util.run_v2_only def testNoSignatureDefsWhenInvalidTrackingObjIsGiven(self): @@ -674,10 +652,10 @@ def testNoSignatureDefsWhenInvalidTrackingObjIsGiven(self): tflite_model = converter.convert() # Check values from converted model. - interpreter = Interpreter(model_content=tflite_model) - signature_defs = interpreter.get_signature_list() + interp = interpreter.Interpreter(model_content=tflite_model) + signature_defs = interp.get_signature_list() # Verify that there is no SignatureDef structure found. - self.assertEqual(len(signature_defs), 0) + self.assertEmpty(signature_defs) @test_util.run_v2_only def testTrackbleObject(self): @@ -700,7 +678,7 @@ def _getTrainingTimeQuantizedModel(self): class QLinear(tf.keras.layers.Layer): def __init__(self, units=3, **kwargs): - super(QLinear, self).__init__(**kwargs) + super().__init__(**kwargs) self.units = units def build(self, input_shape): @@ -759,21 +737,21 @@ def testTrainingTimeQuantization(self, inference_input_output_type): quantized_tflite_model = quantized_converter.convert() self.assertIsNotNone(quantized_tflite_model) # Check the conversion metadata. - metadata = get_conversion_metadata(quantized_tflite_model) + metadata = util.get_conversion_metadata(quantized_tflite_model) self.assertIsNotNone(metadata) self.assertAllEqual( [metadata_fb.ModelOptimizationMode.QUANTIZATION_AWARE_TRAINING], metadata.options.modelOptimizationModes, ) - interpreter = Interpreter(model_content=quantized_tflite_model) - interpreter.allocate_tensors() - input_details = interpreter.get_input_details() + interp = interpreter.Interpreter(model_content=quantized_tflite_model) + interp.allocate_tensors() + input_details = interp.get_input_details() self.assertLen(input_details, 1) self.assertEqual( inference_input_output_type.as_numpy_dtype, input_details[0]['dtype'] ) - output_details = interpreter.get_output_details() + output_details = interp.get_output_details() self.assertLen(output_details, 1) self.assertEqual( inference_input_output_type.as_numpy_dtype, output_details[0]['dtype'] @@ -811,6 +789,38 @@ def testNewQuantizer(self): new_value = self._evaluateTFLiteModel(new_tflite, [input_data]) self.assertAllClose(old_value, new_value, atol=1e-01) + @test_util.run_v2_only + def testGatherNDQI8(self): + """Test gather_nd with quantized i8 parameters.""" + + class GatherNDQI8QDQ(tf.keras.Model): + @tf.function( + input_signature=[tf.TensorSpec(shape=(2, 2), dtype=tf.float32)] + ) + + def func(self, input_tensor): + x = tf.quantization.fake_quant_with_min_max_args( + input_tensor, -3.0, 3.0 + ) + x = tf.gather_nd(x, [[0, 0], [1, 1]]) + return tf.quantization.fake_quant_with_min_max_args(x, -3.0, 3.0) + + # Build a QDQ model so that tfl.gather_nd will be converted to a QI8 version + # with the `_experimental_qdq_conversion_mode`` flag + root = GatherNDQI8QDQ() + concrete_func = root.func.get_concrete_function() + converter = lite.TFLiteConverterV2.from_concrete_functions( + [concrete_func], root + ) + converter._experimental_qdq_conversion_mode = 'STATIC' + tflite_model = converter.convert() + + np_data = np.array([[1, 2], [3, 4]], dtype=np.float32) + input_tensor = tf.constant(np_data, dtype=tf.int8) + expected_value = [1, 4] + actual_value = self._evaluateTFLiteModel(tflite_model, [input_tensor]) + self.assertAllClose(expected_value, actual_value[0], atol=1e-05) + @test_util.run_v2_only def testEmbeddings(self): """Test model with embeddings.""" @@ -821,7 +831,7 @@ def testEmbeddings(self): class EmbeddingModel(tf.keras.Model): def __init__(self): - super(EmbeddingModel, self).__init__() + super().__init__() self.shared_weights = self.add_weight( 'weights', shape=(2000, 300), @@ -949,7 +959,7 @@ def testIntegerQuantizationWithFlexOp( quantized_tflite_model = quantized_converter.convert() self.assertIsNotNone(quantized_tflite_model) # Check the conversion metadata. - metadata = get_conversion_metadata(quantized_tflite_model) + metadata = util.get_conversion_metadata(quantized_tflite_model) self.assertIsNotNone(metadata) self.assertEqual(metadata.options.enableSelectTfOps, True) expected_opt_options = [metadata_fb.ModelOptimizationMode.PTQ_FULL_INTEGER] @@ -959,14 +969,14 @@ def testIntegerQuantizationWithFlexOp( expected_opt_options, metadata.options.modelOptimizationModes ) - interpreter = Interpreter(model_content=quantized_tflite_model) - interpreter.allocate_tensors() - input_details = interpreter.get_input_details() + interp = interpreter.Interpreter(model_content=quantized_tflite_model) + interp.allocate_tensors() + input_details = interp.get_input_details() self.assertLen(input_details, 1) self.assertEqual( inference_input_output_type.as_numpy_dtype, input_details[0]['dtype'] ) - output_details = interpreter.get_output_details() + output_details = interp.get_output_details() self.assertLen(output_details, 1) self.assertEqual( inference_input_output_type.as_numpy_dtype, output_details[0]['dtype'] @@ -1063,13 +1073,13 @@ def testIntegerQuantizationWithUnsupportedOps( expected_dtype if enable_mlir_quantizer else dtypes.float32 ) - interpreter = Interpreter(model_content=quantized_tflite_model) - interpreter.allocate_tensors() - input_details = interpreter.get_input_details() + interp = interpreter.Interpreter(model_content=quantized_tflite_model) + interp.allocate_tensors() + input_details = interp.get_input_details() self.assertLen(input_details, 2) self.assertEqual(input_details[0]['dtype'], expected_dtype) self.assertEqual(input_details[1]['dtype'], expected_ceil_dtype) - output_details = interpreter.get_output_details() + output_details = interp.get_output_details() self.assertLen(output_details, 2) self.assertEqual(output_details[0]['dtype'], expected_dtype) self.assertEqual(output_details[1]['dtype'], expected_ceil_dtype) @@ -1127,9 +1137,6 @@ def calibration_gen(): ('_IntOnly_INT8InputOutput', True, False, dtypes.int8), ('_IntOnly_UINT8InputOutput', True, False, dtypes.uint8), ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16), - # TODO(b/198231624): Support control flow ops in MLIR quantizer - # ('_IntOnly_INT8InputOutputMlirQuant', True, False, dtypes.int8, True), - # ('_IntOnly_UINT8InputOutputMlirQuant', True, False, dtypes.uint8, True), ) @test_util.run_v2_only def testIntegerQuantizationWithControlFlow( @@ -1177,13 +1184,13 @@ def testIntegerQuantizationWithControlFlow( expected_dtype = inference_input_output_type.as_numpy_dtype - interpreter = Interpreter(model_content=quantized_tflite_model) - interpreter.allocate_tensors() - input_details = interpreter.get_input_details() + interp = interpreter.Interpreter(model_content=quantized_tflite_model) + interp.allocate_tensors() + input_details = interp.get_input_details() self.assertLen(input_details, 2) self.assertEqual(input_details[0]['dtype'], expected_dtype) self.assertEqual(input_details[1]['dtype'], dtypes.bool) - output_details = interpreter.get_output_details() + output_details = interp.get_output_details() self.assertLen(output_details, 1) self.assertEqual(output_details[0]['dtype'], expected_dtype) @@ -1213,13 +1220,13 @@ def testNewQuantizerBlocklistingArgs( quantized_converter._experimental_calibrate_only = True quantized_converter.experimental_lower_to_saved_model = lower_to_saved_model calibrated = quantized_converter.convert() - quantized_tflite_model = mlir_quantize( + quantized_tflite_model = convert.mlir_quantize( calibrated, denylisted_ops=denylisted_ops, denylisted_nodes=denylisted_nodes, ) - interpreter = Interpreter(model_content=quantized_tflite_model) - details = interpreter.get_tensor_details() + interp = interpreter.Interpreter(model_content=quantized_tflite_model) + details = interp.get_tensor_details() num_quantized_tensors = sum([ 1 for detail in details @@ -1254,7 +1261,7 @@ def testNewQuantizerNumericVerificationDebugMode(self, whole_model_verify): # Create a TFLite model with new quantizer and numeric verify ops. quantized_converter._experimental_calibrate_only = True calibrated = quantized_converter.convert() - debug_mode_tflite = mlir_quantize( + debug_mode_tflite = convert.mlir_quantize( calibrated, enable_numeric_verify=True, enable_whole_model_verify=whole_model_verify, @@ -1269,18 +1276,18 @@ def testNewQuantizerNumericVerificationDebugMode(self, whole_model_verify): ) def examine_tflite_model(tflite_content, input_data): - interpreter = Interpreter( + interp = interpreter.Interpreter( model_content=tflite_content, - experimental_op_resolver_type=OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES, + experimental_op_resolver_type=interpreter.OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES, ) - interpreter.allocate_tensors() - input_details = interpreter.get_input_details() - interpreter.set_tensor(input_details[0]['index'], input_data.numpy()) - interpreter.invoke() - tensor_details = interpreter.get_tensor_details() + interp.allocate_tensors() + input_details = interp.get_input_details() + interp.set_tensor(input_details[0]['index'], input_data.numpy()) + interp.invoke() + tensor_details = interp.get_tensor_details() return { - details['name']: interpreter.get_tensor(details['index']) - for details in interpreter.get_tensor_details() + details['name']: interp.get_tensor(details['index']) + for details in interp.get_tensor_details() }, tensor_details tflite_result, _ = examine_tflite_model(production_tflite, input_data) @@ -1328,15 +1335,14 @@ def examine_tflite_model(tflite_content, input_data): ('_PerChannelMlirQuant', False, True), ('_PerTensorQuant', True, False), ('_PerTensorMlirQuant', True, True), - ('_PerChannelDynamicRange', False, False, False), - ('_PerTensorDynamicRange', True, False, False), + ('_PerChannelDynamicRange', False, False), + ('_PerTensorDynamicRange', True, False), ) @test_util.run_v2_only def testDisablePerChannelQuantization( self, disable_per_channel=False, enable_mlir_quantizer=False, - representative_dataset=True, ): k_conv_name = 'tfl.pseudo_qconst' if enable_mlir_quantizer else 'Conv2D' # Dynamic range quant requires total num elements of filters > 1024. @@ -1358,11 +1364,11 @@ def testDisablePerChannelQuantization( quantized_tflite_model = quantized_converter.convert() self.assertIsNotNone(quantized_tflite_model) - interpreter = Interpreter(model_content=quantized_tflite_model) - interpreter.allocate_tensors() + interp = interpreter.Interpreter(model_content=quantized_tflite_model) + interp.allocate_tensors() detail = next(( d - for d in interpreter.get_tensor_details() + for d in interp.get_tensor_details() if d['name'].startswith(k_conv_name) )) quant_params = detail['quantization_parameters'] @@ -1370,6 +1376,75 @@ def testDisablePerChannelQuantization( self.assertLen(quant_params['scales'], expected_num_params) self.assertLen(quant_params['zero_points'], expected_num_params) + def _getIntegerQuantizeDenseModel(self, num_filters=32): + np.random.seed(0) + + root = autotrackable.AutoTrackable() + + @tf.function( + input_signature=[tf.TensorSpec(shape=[1, 16], dtype=tf.float32)] + ) + def func(inp): + dense = tf.matmul(a=inp, b=tf.ones([16, num_filters])) + output = tf.nn.relu(dense, name='output') + return output + + def calibration_gen(): + for _ in range(5): + yield [np.random.uniform(-1, 1, size=(1, 16)).astype(np.float32)] + + root.f = func + to_save = root.f.get_concrete_function() + return (root, to_save, calibration_gen) + + @parameterized.named_parameters( + ('_PerChannelQuant', False, False), + ('_PerChannelMlirQuant', False, True), + ('_PerTensorQuant', True, False), + ('_PerTensorMlirQuant', True, True), + ('_PerChannelDynamicRange', False, True, True), + ('_PerTensorDynamicRange', True, True, True), + ) + @test_util.run_v2_only + def testDisablePerChannelQuantizationForDenseLayers( + self, + disable_per_channel_for_dense=False, + enable_mlir_quantizer=False, + representative_dataset=False, + ): + k_dense_name = 'tfl.pseudo_qconst' if representative_dataset else 'MatMul' + # Dynamic range quant requires total num elements of filters > 1024. + k_num_filters = 64 + root, func, calib_gen = self._getIntegerQuantizeDenseModel(k_num_filters) + quantized_converter = tf.lite.TFLiteConverter.from_concrete_functions( + [func], root + ) + quantized_converter.optimizations = [lite.Optimize.DEFAULT] + if representative_dataset: + quantized_converter.representative_dataset = calib_gen + quantized_converter.target_spec.supported_ops = [ + lite.OpsSet.TFLITE_BUILTINS + ] + quantized_converter.experimental_new_quantizer = enable_mlir_quantizer + if disable_per_channel_for_dense: + quantized_converter._experimental_disable_per_channel_quantization_for_dense_layers = ( + disable_per_channel_for_dense + ) + quantized_tflite_model = quantized_converter.convert() + self.assertIsNotNone(quantized_tflite_model) + + interp = interpreter.Interpreter(model_content=quantized_tflite_model) + interp.allocate_tensors() + detail = next(( + d + for d in interp.get_tensor_details() + if d['name'].startswith(k_dense_name) + )) + quant_params = detail['quantization_parameters'] + expected_num_params = 1 if disable_per_channel_for_dense else k_num_filters + self.assertLen(quant_params['scales'], expected_num_params) + self.assertLen(quant_params['zero_points'], expected_num_params) + @parameterized.named_parameters( ('MlirQuantize', True), ('TocoQuantize', False) ) @@ -1392,13 +1467,13 @@ def calibration_gen(): converter.experimental_new_quantizer = enable_mlir_quantizer quantized_model = converter.convert() - interpreter = Interpreter(model_content=quantized_model) - interpreter.allocate_tensors() - input_details = interpreter.get_input_details() - interpreter.set_tensor(input_details[0]['index'], input_data) - interpreter.invoke() - output_details = interpreter.get_output_details() - output = interpreter.get_tensor(output_details[0]['index']) + interp = interpreter.Interpreter(model_content=quantized_model) + interp.allocate_tensors() + input_details = interp.get_input_details() + interp.set_tensor(input_details[0]['index'], input_data) + interp.invoke() + output_details = interp.get_output_details() + output = interp.get_tensor(output_details[0]['index']) # the inputs and weights are far smaller than the biases, so the final # result should be equal to the biases. self.assertAllClose(root.bias, output.flatten()) @@ -1446,7 +1521,7 @@ def testForceSelectTFOps(self): converter.target_spec.supported_ops = [tf.lite.OpsSet.SELECT_TF_OPS] tflite_model = converter.convert() # Check the conversion metadata. - metadata = get_conversion_metadata(tflite_model) + metadata = util.get_conversion_metadata(tflite_model) self.assertIsNotNone(metadata) self.assertEqual(metadata.options.forceSelectTfOps, True) @@ -1467,7 +1542,7 @@ def testExcludeConversionMetadata(self): converter.exclude_conversion_metadata = True tflite_model = converter.convert() # Check the conversion metadata. - metadata = get_conversion_metadata(tflite_model) + metadata = util.get_conversion_metadata(tflite_model) self.assertIsNone(metadata) def testConversionMetadataForDynamicRange(self): @@ -1478,7 +1553,7 @@ def testConversionMetadataForDynamicRange(self): converter.optimizations = [lite.Optimize.DEFAULT] quantized_model = converter.convert() # Check the conversion metadata. - metadata = get_conversion_metadata(quantized_model) + metadata = util.get_conversion_metadata(quantized_model) self.assertIsNotNone(metadata) self.assertAllEqual( [metadata_fb.ModelOptimizationMode.PTQ_DYNAMIC_RANGE], @@ -1493,7 +1568,7 @@ def testConversionMetadataForFloat16(self): converter.target_spec.supported_types = [dtypes.float16] quantized_model = converter.convert() # Check the conversion metadata. - metadata = get_conversion_metadata(quantized_model) + metadata = util.get_conversion_metadata(quantized_model) self.assertIsNotNone(metadata) self.assertAllEqual( [metadata_fb.ModelOptimizationMode.PTQ_FLOAT16], @@ -1546,7 +1621,7 @@ def testStableHloQuantizerSupportsOnlyStaticRangePtq(self): to_save = root.f.get_concrete_function(input_data) save_dir = os.path.join(self.get_temp_dir(), 'saved_model') - save(root, save_dir, to_save) + save.save(root, save_dir, to_save) converter = lite.TFLiteConverterV2.from_saved_model(save_dir) converter.experimental_use_stablehlo_quantizer = True @@ -1554,22 +1629,21 @@ def testStableHloQuantizerSupportsOnlyStaticRangePtq(self): converter.convert() @test_util.run_v2_only - def testStableHloQuantizerNoOpForStaticRangePtq(self): - """Tests that StableHLO Quantizer performs a no-op for Static-Range PTQ.""" - # TODO: b/307626169 - Provide a full test after StableHLO Quantizer - # integration. + def testStableHloQuantizerNoOpForTfSavedModel(self): + """Tests that StableHLO Quantizer does not run for TF SavedModel.""" input_data = tf.constant(1.0, shape=[1]) root = autotrackable.AutoTrackable() root.f = tf.function(lambda x: 2.0 * x) to_save = root.f.get_concrete_function(input_data) save_dir = os.path.join(self.get_temp_dir(), 'saved_model') - save(root, save_dir, to_save) + save.save(root, save_dir, to_save) def _representative_data_gen(): return [{'x': np.ones(shape=(1,), dtype=np.float32)}] converter = lite.TFLiteConverterV2.from_saved_model(save_dir) + # Set the flags to enable StableHLO Quantizer. converter.experimental_use_stablehlo_quantizer = True converter.optimizations = [lite.Optimize.DEFAULT] converter.representative_dataset = _representative_data_gen @@ -1578,8 +1652,8 @@ def _representative_data_gen(): self.assertIsNotNone(tflite_model) # Test that no tensor is quantized. - interpreter = tf.lite.Interpreter(model_content=tflite_model) - all_tensor_details = interpreter.get_tensor_details() + interp = interpreter.Interpreter(model_content=tflite_model) + all_tensor_details = interp.get_tensor_details() for tensor_detail in all_tensor_details: self.assertIn('dtype', tensor_detail) self.assertEqual(tensor_detail['dtype'], np.float32) @@ -1595,10 +1669,10 @@ def testV1SimpleModel(self): tflite_model = converter.convert() self.assertTrue(tflite_model) - interpreter = Interpreter(model_content=tflite_model) - interpreter.allocate_tensors() + interp = interpreter.Interpreter(model_content=tflite_model) + interp.allocate_tensors() - input_details = interpreter.get_input_details() + input_details = interp.get_input_details() self.assertLen(input_details, 2) self.assertStartsWith(input_details[0]['name'], 'inputA') self.assertEqual(np.float32, input_details[0]['dtype']) @@ -1613,7 +1687,7 @@ def testV1SimpleModel(self): self.assertTrue([1, 16, 16, 3], input_details[1]['shape']) self.assertEqual((0.0, 0.0), input_details[1]['quantization']) - output_details = interpreter.get_output_details() + output_details = interp.get_output_details() self.assertLen(output_details, 1) self.assertStartsWith(output_details[0]['name'], 'add') self.assertEqual(np.float32, output_details[0]['dtype']) @@ -1659,29 +1733,29 @@ def testUnfoldLargeConstant(self, unfold_large_constant): ) # Check values from converted model. - interpreter = Interpreter(model_content=tflite_model) - interpreter.allocate_tensors() + interp = interpreter.Interpreter(model_content=tflite_model) + interp.allocate_tensors() - input_details = interpreter.get_input_details() + input_details = interp.get_input_details() self.assertLen(input_details, 1) self.assertEqual('input:0', input_details[0]['name']) self.assertEqual(np.float32, input_details[0]['dtype']) self.assertAllEqual([1000, 1000], input_details[0]['shape']) self.assertEqual((0.0, 0.0), input_details[0]['quantization']) - output_details = interpreter.get_output_details() + output_details = interp.get_output_details() self.assertEqual('add:0', output_details[0]['name']) self.assertEqual(np.float32, output_details[0]['dtype']) self.assertAllEqual([1000, 1000], output_details[0]['shape']) self.assertEqual((0.0, 0.0), output_details[0]['quantization']) - interpreter.set_tensor( + interp.set_tensor( input_details[0]['index'], np.ones(shape=[1000, 1000], dtype=np.float32) ) - interpreter.invoke() + interp.invoke() self.assertAllEqual( np.full(shape=[1000, 1000], fill_value=2.0, dtype=np.float32), - interpreter.get_tensor(output_details[0]['index']), + interp.get_tensor(output_details[0]['index']), ) @test_util.run_v2_only @@ -1726,7 +1800,7 @@ def testTF1HubFormattedModel(self): # TF1 hub model is based on V1 saved model and they omit the saved model # schema version setting. - saved_model_proto = parse_saved_model(saved_model_dir) + saved_model_proto = loader_impl.parse_saved_model(saved_model_dir) saved_model_proto.saved_model_schema_version = 0 saved_model_pb_file_path = os.path.join(saved_model_dir, 'saved_model.pb') @@ -1804,29 +1878,27 @@ def testModelWithHashTableInitializer(self): tflite_model = converter.convert() # Check values from converted model. - interpreter = Interpreter(model_content=tflite_model) - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() + interp = interpreter.Interpreter(model_content=tflite_model) + input_details = interp.get_input_details() + output_details = interp.get_output_details() input_data = np.array(['a', 'b', 'c', 'z'], dtype=np.string_) - interpreter.resize_tensor_input( - input_details[0]['index'], [4], strict=False - ) - interpreter.allocate_tensors() + interp.resize_tensor_input(input_details[0]['index'], [4], strict=False) + interp.allocate_tensors() - interpreter.set_tensor(input_details[0]['index'], input_data) + interp.set_tensor(input_details[0]['index'], input_data) # Invoke multiple times to ensure the initializer graph runs only once. - interpreter.invoke() - actual_value = interpreter.get_tensor(output_details[0]['index']) + interp.invoke() + actual_value = interp.get_tensor(output_details[0]['index']) self.assertEqual([1, 2, 3, -1], list(actual_value)) - interpreter.invoke() - actual_value = interpreter.get_tensor(output_details[0]['index']) + interp.invoke() + actual_value = interp.get_tensor(output_details[0]['index']) self.assertEqual([1, 2, 3, -1], list(actual_value)) - interpreter.invoke() - actual_value = interpreter.get_tensor(output_details[0]['index']) + interp.invoke() + actual_value = interp.get_tensor(output_details[0]['index']) self.assertEqual([1, 2, 3, -1], list(actual_value)) def _createV1ModelWithMutableHashTable(self): @@ -1900,20 +1972,18 @@ def testModelWithMutableHashTable(self): tflite_model = converter.convert() # Check values from converted model. - interpreter = Interpreter(model_content=tflite_model) - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() + interp = interpreter.Interpreter(model_content=tflite_model) + input_details = interp.get_input_details() + output_details = interp.get_output_details() input_data = np.array(['a', 'b', 'c'], dtype=np.string_) - interpreter.resize_tensor_input( - input_details[0]['index'], [3], strict=False - ) - interpreter.allocate_tensors() + interp.resize_tensor_input(input_details[0]['index'], [3], strict=False) + interp.allocate_tensors() - interpreter.set_tensor(input_details[0]['index'], input_data) + interp.set_tensor(input_details[0]['index'], input_data) - interpreter.invoke() - actual_value = interpreter.get_tensor(output_details[0]['index']) + interp.invoke() + actual_value = interp.get_tensor(output_details[0]['index']) self.assertEqual([1, 5, -1], list(actual_value)) @test_util.run_v2_only @@ -1936,8 +2006,8 @@ def testReduceSumWithInt16Quant(self): converter.representative_dataset = lambda: [inputs] content = converter.convert() - interpreter = tf.lite.Interpreter(model_content=content) - runner = interpreter.get_signature_runner('serving_default') + interp = interpreter.Interpreter(model_content=content) + runner = interp.get_signature_runner('serving_default') y = runner(x=np.array([[1, 1, 1], [2, 2, 2], [3, 3, 3]]).astype(np.int16)) self.assertEqual([3, 6, 9], list(list(y.values())[0])) @@ -1950,7 +2020,7 @@ def testConstModel(self): to_save = root.f.get_concrete_function(input_data) save_dir = os.path.join(self.get_temp_dir(), 'saved_model') - save(root, save_dir, to_save) + save.save(root, save_dir, to_save) # Convert model and ensure model is not None. converter = lite.TFLiteConverterV2.from_saved_model(save_dir) @@ -1969,13 +2039,13 @@ def testVariableModel(self): to_save = root.f.get_concrete_function(input_data) save_dir = os.path.join(self.get_temp_dir(), 'saved_model') - save(root, save_dir, to_save) + save.save(root, save_dir, to_save) # Convert model and ensure model is not None. converter = lite.TFLiteConverterV2.from_saved_model(save_dir) tflite_model = converter.convert() # Check the conversion metadata. - metadata = get_conversion_metadata(tflite_model) + metadata = util.get_conversion_metadata(tflite_model) self.assertIsNotNone(metadata) self.assertEqual( metadata.environment.modelType, metadata_fb.ModelType.TF_SAVED_MODEL @@ -1997,7 +2067,7 @@ def testNativeVariablesModel(self, enable_resource_variables): to_save = root.assign_add.get_concrete_function(input_data) save_dir = os.path.join(self.get_temp_dir(), 'saved_model') - save(root, save_dir, to_save) + save.save(root, save_dir, to_save) # Convert model and ensure model is not None. converter = lite.TFLiteConverterV2.from_saved_model(save_dir) @@ -2031,7 +2101,7 @@ def testSignatures(self): to_save = root.f.get_concrete_function(input_data) save_dir = os.path.join(self.get_temp_dir(), 'saved_model') - save(root, save_dir, to_save) + save.save(root, save_dir, to_save) # Convert model with invalid `signature_keys`. with self.assertRaises(ValueError) as error: @@ -2068,12 +2138,10 @@ def testSignatureDefsWithFullIntegerQuantization(self): converter.optimizations = [tf.lite.Optimize.DEFAULT] tflite_model = converter.convert() # 2. Initialize the Interpreter - interpreter = Interpreter(model_content=tflite_model) - input_details = interpreter.get_input_details()[0] - output_details = interpreter.get_output_details()[0] - interpreter.resize_tensor_input(input_details['index'], tflite_input_shape) - interpreter.allocate_tensors() - signature_list = interpreter._get_full_signature_list()['serving_default'] + interp = interpreter.Interpreter(model_content=tflite_model) + input_details = interp.get_input_details()[0] + interp.resize_tensor_input(input_details['index'], tflite_input_shape) + interp.allocate_tensors() # 3. (Skip) Verify that signature def input/output tensors are in the model. # 4. Evaluate the model input_data = np.random.random(tflite_input_shape).astype(np.float32) @@ -2089,14 +2157,14 @@ def testSignatureDefsWithFullIntegerQuantization(self): converter.inference_output_type = tf.int8 tflite_model_quant = converter.convert() # 2. Initialize the Interpreter - interpreter = Interpreter(model_content=tflite_model_quant) - input_details = interpreter.get_input_details()[0] - output_details = interpreter.get_output_details()[0] - interpreter.resize_tensor_input(input_details['index'], tflite_input_shape) - interpreter.allocate_tensors() + interp = interpreter.Interpreter(model_content=tflite_model_quant) + input_details = interp.get_input_details()[0] + output_details = interp.get_output_details()[0] + interp.resize_tensor_input(input_details['index'], tflite_input_shape) + interp.allocate_tensors() # 3. Verify that signature def input/output tensors are in the model. - all_indices = {item['index'] for item in interpreter.get_tensor_details()} - signature_list = interpreter._get_full_signature_list()['serving_default'] + all_indices = {item['index'] for item in interp.get_tensor_details()} + signature_list = interp._get_full_signature_list()['serving_default'] input_tensor_indices = set(signature_list['inputs'].values()) assert input_tensor_indices.issubset(all_indices) output_tensor_indices = set(signature_list['outputs'].values()) @@ -2131,7 +2199,7 @@ def testSignatureDefs(self): ) save_dir = os.path.join(self.get_temp_dir(), 'saved_model') - save(root, save_dir, {'mul_add': mul_add_func}) + save.save(root, save_dir, {'mul_add': mul_add_func}) converter = lite.TFLiteConverterV2.from_saved_model( save_dir, signature_keys=['mul_add'] @@ -2140,8 +2208,8 @@ def testSignatureDefs(self): # Check values from converted model. expected_value = root.mul_add(input_data_1, input_data_0) - interpreter = Interpreter(model_content=tflite_model) - signature_defs = interpreter.get_signature_list() + interp = interpreter.Interpreter(model_content=tflite_model) + signature_defs = interp.get_signature_list() results = self._evaluateTFLiteModelUsingSignatureDef( tflite_model, 'mul_add', {'y': input_data_0, 'x': input_data_1} ) @@ -2149,9 +2217,9 @@ def testSignatureDefs(self): self.assertEqual(expected_value.numpy(), results['output_0']) # Verify the SignatureDef structure returned is as expected. - self.assertEqual(len(signature_defs), 1) + self.assertLen(signature_defs, 1) self.assertEqual(list(signature_defs.keys()), ['mul_add']) - self.assertEqual(len(signature_defs.values()), 1) + self.assertLen(signature_defs.values(), 1) self.assertEqual( list(signature_defs['mul_add'].keys()), ['inputs', 'outputs'] ) @@ -2172,7 +2240,7 @@ def testSignatureDefsWithDefaultValue(self): ) save_dir = os.path.join(self.get_temp_dir(), 'saved_model') - save(root, save_dir, {'mul_add': mul_add_func}) + save.save(root, save_dir, {'mul_add': mul_add_func}) converter = lite.TFLiteConverterV2.from_saved_model( save_dir, signature_keys=['mul_add'] @@ -2181,8 +2249,8 @@ def testSignatureDefsWithDefaultValue(self): # Check values from converted model. expected_value = root.mul_add(input_data_1, input_data_0) - interpreter = Interpreter(model_content=tflite_model) - signature_defs = interpreter.get_signature_list() + interp = interpreter.Interpreter(model_content=tflite_model) + signature_defs = interp.get_signature_list() results = self._evaluateTFLiteModelUsingSignatureDef( tflite_model, None, {'y': input_data_0, 'x': input_data_1} ) @@ -2190,9 +2258,9 @@ def testSignatureDefsWithDefaultValue(self): self.assertEqual(expected_value.numpy(), results['output_0']) # Verify the SignatureDef structure returned is as expected. - self.assertEqual(len(signature_defs), 1) + self.assertLen(signature_defs, 1) self.assertEqual(list(signature_defs.keys()), ['mul_add']) - self.assertEqual(len(signature_defs.values()), 1) + self.assertLen(signature_defs.values(), 1) self.assertEqual( list(signature_defs['mul_add'].keys()), ['inputs', 'outputs'] ) @@ -2210,7 +2278,7 @@ def testSignatureDefsQuantizedModel(self): ) save_dir = os.path.join(self.get_temp_dir(), 'saved_model') - save(root, save_dir, {'mul_add': mul_add_func}) + save.save(root, save_dir, {'mul_add': mul_add_func}) converter = lite.TFLiteConverterV2.from_saved_model( save_dir, signature_keys=['mul_add'] @@ -2233,13 +2301,13 @@ def representative_dataset_gen(): tflite_model = converter.convert() # Check signatures are valid from converted model. - interpreter = Interpreter(model_content=tflite_model) - signature_defs = interpreter.get_signature_list() + interp = interpreter.Interpreter(model_content=tflite_model) + signature_defs = interp.get_signature_list() # Verify the SignatureDef structure returned is as expected. - self.assertEqual(len(signature_defs), 1) + self.assertLen(signature_defs, 1) self.assertEqual(list(signature_defs.keys()), ['mul_add']) - self.assertEqual(len(signature_defs.values()), 1) + self.assertLen(signature_defs.values(), 1) self.assertEqual( list(signature_defs['mul_add'].keys()), ['inputs', 'outputs'] ) @@ -2255,20 +2323,20 @@ def testMultipleFunctionModel(self): sub_func = root.sub.get_concrete_function(input_data) save_dir = os.path.join(self.get_temp_dir(), 'saved_model') - save(root, save_dir, {'add': add_func, 'sub': sub_func}) + save.save(root, save_dir, {'add': add_func, 'sub': sub_func}) # Try converting multiple functions. converter = lite.TFLiteConverterV2.from_saved_model(save_dir) tflite_model = converter.convert() self.assertIsNotNone(tflite_model) - interpreter = tf.lite.Interpreter(model_content=tflite_model) - signature_defs = interpreter.get_signature_list() + interp = interpreter.Interpreter(model_content=tflite_model) + signature_defs = interp.get_signature_list() # Verify the SignatureDef structure returned is as expected. - self.assertEqual(len(signature_defs), 2) + self.assertLen(signature_defs, 2) self.assertEqual(list(signature_defs.keys()), ['add', 'sub']) - self.assertEqual(len(signature_defs.values()), 2) + self.assertLen(signature_defs.values(), 2) self.assertEqual(list(signature_defs['add'].keys()), ['inputs', 'outputs']) self.assertCountEqual(signature_defs['add']['inputs'], ['x']) self.assertEqual(list(signature_defs['add']['outputs']), ['output_0']) @@ -2277,11 +2345,11 @@ def testMultipleFunctionModel(self): self.assertEqual(list(signature_defs['sub']['outputs']), ['output_0']) # Verify the Signature runner executions. - add_signature_runner = interpreter.get_signature_runner('add') + add_signature_runner = interp.get_signature_runner('add') add_output = add_signature_runner(x=input_data) self.assertEqual(add_output['output_0'], 3) - sub_signature_runner = interpreter.get_signature_runner('sub') + sub_signature_runner = interp.get_signature_runner('sub') sub_output = sub_signature_runner(x=input_data) self.assertEqual(sub_output['output_0'], -2) @@ -2312,7 +2380,7 @@ def testMultipleFunctionQuantizedModel( sub_func = root.sub.get_concrete_function(input_data) save_dir = os.path.join(self.get_temp_dir(), 'saved_model') - save(root, save_dir, {'add': add_func, 'sub': sub_func}) + save.save(root, save_dir, {'add': add_func, 'sub': sub_func}) # Try converting multiple functions. converter = lite.TFLiteConverterV2.from_saved_model(save_dir) @@ -2359,13 +2427,13 @@ def representative_dataset_gen(): tflite_model = converter.convert() self.assertIsNotNone(tflite_model) - interpreter = tf.lite.Interpreter(model_content=tflite_model) - signature_defs = interpreter.get_signature_list() + interp = interpreter.Interpreter(model_content=tflite_model) + signature_defs = interp.get_signature_list() # Verify the SignatureDef structure returned is as expected. - self.assertEqual(len(signature_defs), 2) + self.assertLen(signature_defs, 2) self.assertEqual(list(signature_defs.keys()), ['add', 'sub']) - self.assertEqual(len(signature_defs.values()), 2) + self.assertLen(signature_defs.values(), 2) self.assertEqual(list(signature_defs['add'].keys()), ['inputs', 'outputs']) self.assertCountEqual(signature_defs['add']['inputs'], ['x']) self.assertEqual(list(signature_defs['add']['outputs']), ['output_0']) @@ -2379,7 +2447,7 @@ def representative_dataset_gen(): inference_input_output_type.as_numpy_dtype ) ) - add_signature_runner = interpreter.get_signature_runner('add') + add_signature_runner = interp.get_signature_runner('add') add_output = add_signature_runner(x=input_data) self.assertIsNotNone(add_output['output_0']) input_details = add_signature_runner.get_input_details() @@ -2392,7 +2460,7 @@ def representative_dataset_gen(): if inference_input_output_type == dtypes.float32: self.assertEqual((0.0, 0), input_details['x']['quantization']) - sub_signature_runner = interpreter.get_signature_runner('sub') + sub_signature_runner = interp.get_signature_runner('sub') sub_output = sub_signature_runner(x=input_data) self.assertIsNotNone(sub_output['output_0']) output_details = sub_signature_runner.get_output_details() @@ -2418,7 +2486,9 @@ def testMultipleFunctionModelWithSharedWeight(self): mul_func = root.mul.get_concrete_function(input_data) save_dir = os.path.join(self.get_temp_dir(), 'saved_model') - save(root, save_dir, {'add': add_func, 'sub': sub_func, 'mul': mul_func}) + save.save( + root, save_dir, {'add': add_func, 'sub': sub_func, 'mul': mul_func} + ) # Try converting multiple functions. converter = lite.TFLiteConverterV2.from_saved_model(save_dir) @@ -2428,14 +2498,12 @@ def testMultipleFunctionModelWithSharedWeight(self): # Make sure that the weight tensors are shared. self.assertLess(len(tflite_model), 1100000) - # TODO(b/184696047): Write down the test codes for multiple signature - # runners once the Python API is ready to use. - interpreter = tf.lite.Interpreter(model_content=tflite_model) - signature_defs = interpreter.get_signature_list() + interp = tf.lite.Interpreter(model_content=tflite_model) + signature_defs = interp.get_signature_list() self.assertLen(signature_defs, 3) - add_signature_runner = interpreter.get_signature_runner('add') - sub_signature_runner = interpreter.get_signature_runner('sub') - mul_signature_runner = interpreter.get_signature_runner('mul') + add_signature_runner = interp.get_signature_runner('add') + sub_signature_runner = interp.get_signature_runner('sub') + mul_signature_runner = interp.get_signature_runner('mul') self.assertIsNotNone(add_signature_runner) self.assertIsNotNone(sub_signature_runner) self.assertIsNotNone(mul_signature_runner) @@ -2445,7 +2513,7 @@ def testNoConcreteFunctionModel(self): root = self._getMultiFunctionModel() save_dir = os.path.join(self.get_temp_dir(), 'saved_model') - save(root, save_dir) + save.save(root, save_dir) with self.assertRaises(ValueError) as error: _ = lite.TFLiteConverterV2.from_saved_model(save_dir) @@ -2469,7 +2537,7 @@ def testKerasSequentialModel(self): model.fit(x, y, epochs=1) save_dir = os.path.join(self.get_temp_dir(), 'saved_model') - save(model, save_dir) + save.save(model, save_dir) # Convert model and ensure model is not None. converter = lite.TFLiteConverterV2.from_saved_model(save_dir) @@ -2503,8 +2571,8 @@ def testKerasSequentialModelExport(self): tflite_model = converter.convert() # Validate endpoints following `.export` to TFLite conversion. - interpreter = Interpreter(model_content=tflite_model) - signature_defs = interpreter.get_signature_list() + interp = interpreter.Interpreter(model_content=tflite_model) + signature_defs = interp.get_signature_list() self.assertLen(signature_defs, 1) self.assertEqual(next(iter(signature_defs)), 'serving_default') @@ -2522,7 +2590,7 @@ def testGraphDebugInfo(self): to_save = root.f.get_concrete_function(input_data) options = save_options.SaveOptions(save_debug_info=True) save_dir = os.path.join(self.get_temp_dir(), 'saved_model') - save(root, save_dir, to_save, options) + save.save(root, save_dir, to_save, options) # Convert model and ensure model is not None. converter = lite.TFLiteConverterV2.from_saved_model(save_dir) @@ -2624,16 +2692,16 @@ def testKerasFullyConnectedOutputShape3D(self): tflite_model = converter.convert() self.assertTrue(tflite_model) - interpreter = Interpreter(model_content=tflite_model) - output_details = interpreter.get_output_details() - input_details = interpreter.get_input_details() - interpreter.allocate_tensors() + interp = interpreter.Interpreter(model_content=tflite_model) + output_details = interp.get_output_details() + input_details = interp.get_input_details() + interp.allocate_tensors() input_data = np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], np.float32) - interpreter.set_tensor(input_details[0]['index'], input_data) - interpreter.invoke() + interp.set_tensor(input_details[0]['index'], input_data) + interp.invoke() - actual_value = interpreter.get_tensor(output_details[0]['index']) + actual_value = interp.get_tensor(output_details[0]['index']) expected_value = model.predict(input_data) self.assertLen(output_details[0]['shape_signature'], 3) @@ -2711,25 +2779,25 @@ def testUnknownInputShapeModel(self): self.assertTrue(tflite_model) # Validate that tensors with unknown shape have unknown rank. - tflite_model_obj = _convert_bytearray_to_object(tflite_model) + tflite_model_obj = flatbuffer_utils.convert_bytearray_to_object( + tflite_model + ) for tensor in tflite_model_obj.subgraphs[0].tensors: self.assertEqual(False, tensor.hasRank) self.assertEqual([], tensor.shape.tolist()) # Check values from converted model. - interpreter = Interpreter(model_content=tflite_model) - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() + interp = interpreter.Interpreter(model_content=tflite_model) + input_details = interp.get_input_details() + output_details = interp.get_output_details() input_data = np.array([1.0, 2.0, 3.0], dtype=np.float32) - interpreter.resize_tensor_input( - input_details[0]['index'], [3], strict=False - ) - interpreter.allocate_tensors() + interp.resize_tensor_input(input_details[0]['index'], [3], strict=False) + interp.allocate_tensors() - interpreter.set_tensor(input_details[0]['index'], input_data) - interpreter.invoke() - actual_value = interpreter.get_tensor(output_details[0]['index']) + interp.set_tensor(input_details[0]['index'], input_data) + interp.invoke() + actual_value = interp.get_tensor(output_details[0]['index']) self.assertEqual([2.0, 4.0, 6.0], list(actual_value)) @test_util.run_v2_only @@ -2742,7 +2810,9 @@ def testScalarInputShapeModel(self): self.assertTrue(tflite_model) # Validate that scalar tensors have a rank = 0. - tflite_model_obj = _convert_bytearray_to_object(tflite_model) + tflite_model_obj = flatbuffer_utils.convert_bytearray_to_object( + tflite_model + ) for tensor in tflite_model_obj.subgraphs[0].tensors: self.assertEqual(True, tensor.hasRank) self.assertEqual([], tensor.shape.tolist()) @@ -2757,7 +2827,9 @@ def testMatrixInputShapeModel(self): self.assertTrue(tflite_model) # Validate that matrix tensors have a rank = 2. - tflite_model_obj = _convert_bytearray_to_object(tflite_model) + tflite_model_obj = flatbuffer_utils.convert_bytearray_to_object( + tflite_model + ) for tensor in tflite_model_obj.subgraphs[0].tensors: self.assertEqual(True, tensor.hasRank) self.assertEqual([2, 3], tensor.shape.tolist()) @@ -2784,7 +2856,7 @@ def testDisablePerChannelQuantization( ) model.build(input_shape=(1, 5, 5, 3)) saved_model_dir = os.path.join(self.get_temp_dir(), 'conv_saved_model') - save(model, saved_model_dir) + save.save(model, saved_model_dir) k_conv_name = ( 'tfl.pseudo_qconst' if enable_mlir_quantizer @@ -2812,11 +2884,11 @@ def calib_gen(): quantized_tflite_model = quantized_converter.convert() self.assertIsNotNone(quantized_tflite_model) - interpreter = Interpreter(model_content=quantized_tflite_model) - interpreter.allocate_tensors() + interp = interpreter.Interpreter(model_content=quantized_tflite_model) + interp.allocate_tensors() detail = next(( d - for d in interpreter.get_tensor_details() + for d in interp.get_tensor_details() if d['name'].startswith(k_conv_name) )) quant_params = detail['quantization_parameters'] @@ -2849,7 +2921,7 @@ def testBiasQuantization( ) ]) saved_model_dir = os.path.join(self.get_temp_dir(), 'dense_saved_model') - save(model, saved_model_dir) + save.save(model, saved_model_dir) k_dense_bias_name = ( 'sequential/dense/BiasAdd/ReadVariableOp' if is_int16_quantize @@ -2888,11 +2960,11 @@ def calibration_gen(): quantized_tflite_model = quantized_converter.convert() self.assertIsNotNone(quantized_tflite_model) - interpreter = Interpreter(model_content=quantized_tflite_model) - interpreter.allocate_tensors() + interp = interpreter.Interpreter(model_content=quantized_tflite_model) + interp.allocate_tensors() dense_bias = next(( d - for d in interpreter.get_tensor_details() + for d in interp.get_tensor_details() if d['name'].startswith(k_dense_bias_name) )) self.assertEqual(bias_type, dense_bias['dtype']) @@ -2918,7 +2990,7 @@ def testMlirDynamicRangeQuantization( ) model.build(input_shape=(1, 32, 32, 3)) saved_model_dir = self.create_tempdir() - save(model, saved_model_dir.full_path) + save.save(model, saved_model_dir.full_path) converter = tf.lite.TFLiteConverter.from_saved_model( saved_model_dir.full_path @@ -2933,17 +3005,17 @@ def testMlirDynamicRangeQuantization( quantized_tflite_model = converter.convert() self.assertIsNotNone(quantized_tflite_model) - interpreter = Interpreter(model_content=quantized_tflite_model) - interpreter.allocate_tensors() + interp = interpreter.Interpreter(model_content=quantized_tflite_model) + interp.allocate_tensors() quantized_weight = None quantized_weight_with_one_postfix = None quantized_weight_without_one_postfix = None - for d in interpreter.get_tensor_details(): + for d in interp.get_tensor_details(): if d['name'] == conv_name + '1': quantized_weight = d quantized_weight_with_one_postfix = d break - for d in interpreter.get_tensor_details(): + for d in interp.get_tensor_details(): if d['name'].startswith(conv_name): if quantized_weight is None: quantized_weight = d @@ -2960,8 +3032,8 @@ def testMlirDynamicRangeQuantization( self.assertLen(quant_params['scales'], expected_num_params) self.assertLen(quant_params['zero_points'], expected_num_params) - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() + input_details = interp.get_input_details() + output_details = interp.get_output_details() self.assertEqual(np.float32, input_details[0]['dtype']) self.assertEqual(np.float32, output_details[0]['dtype']) if enable_float16_quant: @@ -2991,7 +3063,7 @@ def testQDQConversionMode(self, mode): ) model.build(input_shape=(1, 32, 32, 3)) saved_model_dir = self.create_tempdir() - save(model, saved_model_dir.full_path) + save.save(model, saved_model_dir.full_path) converter = tf.lite.TFLiteConverter.from_saved_model( saved_model_dir.full_path ) @@ -3005,85 +3077,6 @@ def testQDQConversionMode(self, mode): model = converter.convert() self.assertIsNotNone(model) - # pylint: disable=pointless-string-statement - """disable test for now """ - """@parameterized.named_parameters( - ( - '_Float16Quantization', - _PresetQuantizationMethod.FLOAT16, - ), - ) - @test_util.run_v2_only - def testMlirStableHLOPresetQuantizationMethod( - self, preset_quantization_method - ): - num_filters = 38 - model = tf.keras.models.Sequential( - [tf.keras.layers.Conv2D(num_filters, (3, 3), activation='relu')] - ) - model.build(input_shape=(1, 5, 5, 3)) - saved_model_dir = os.path.join(self.get_temp_dir(), 'conv_saved_model') - save(model, saved_model_dir) - - quantization_options = quant_opts_pb2.QuantizationOptions( - quantization_method=quant_opts_pb2.QuantizationMethod( - preset_quantization_method=quant_opts_pb2.PresetQuantizationMethod( - preset_method=preset_quantization_method - ) - ) - ) - - converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) - converter._experimental_quantization_options = quantization_options - - converter.target_spec.supported_ops = [ - tf.lite.OpsSet.EXPERIMENTAL_STABLEHLO_OPS - ] - converter.exclude_conversion_metadata = True - converter.optimizations = [lite.Optimize.DEFAULT] - quantized_stablehlo_model = converter.convert() - self.assertIsNotNone(quantized_stablehlo_model) - - @test_util.run_v2_only - def testMlirStableHLOCustomQuantizationMethod(self): - num_filters = 38 - model = tf.keras.models.Sequential( - [tf.keras.layers.Conv2D(num_filters, (3, 3), activation='relu')] - ) - model.build(input_shape=(1, 5, 5, 3)) - saved_model_dir = os.path.join(self.get_temp_dir(), 'conv_saved_model') - save(model, saved_model_dir) - - quantization_options = quant_opts_pb2.QuantizationOptions( - quantization_method=quant_opts_pb2.QuantizationMethod( - custom_quantization_method=quant_opts_pb2.CustomQuantizationMethod( - quantization_component_spec=[ - quant_opts_pb2.QuantizationComponentSpec( - quantization_component=quant_opts_pb2.QuantizationComponentSpec.QuantizationComponent.COMPONENT_WEIGHT, - bit_width=quant_opts_pb2.QuantizationComponentSpec.BitWidth.BIT_WIDTH_16, - bit_type=quant_opts_pb2.QuantizationComponentSpec.BitType.BIT_TYPE_FLOAT, - ), - quant_opts_pb2.QuantizationComponentSpec( - quantization_component=quant_opts_pb2.QuantizationComponentSpec.QuantizationComponent.COMPONENT_BIAS, - bit_width=quant_opts_pb2.QuantizationComponentSpec.BitWidth.BIT_WIDTH_16, - ), - ] - ) - ) - ) - - converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) - converter._experimental_quantization_options = quantization_options - - converter.target_spec.supported_ops = [ - tf.lite.OpsSet.EXPERIMENTAL_STABLEHLO_OPS - ] - converter.exclude_conversion_metadata = True - converter.optimizations = [lite.Optimize.DEFAULT] - quantized_stablehlo_model = converter.convert() - self.assertIsNotNone(quantized_stablehlo_model)""" - # pylint: enable=pointless-string-statement - class FromKerasModelTest(lite_v2_test_util.ModelTest): @@ -3108,12 +3101,12 @@ def testVariableQuantization(self, variable_quantization, number_of_states): quantized_tflite_model = converter.convert() - interpreter = Interpreter(model_content=quantized_tflite_model) - interpreter.allocate_tensors() + interp = interpreter.Interpreter(model_content=quantized_tflite_model) + interp.allocate_tensors() detail = next(( d - for d in interpreter.get_tensor_details() + for d in interp.get_tensor_details() if d['name'].startswith(k_readvariable_name) )) quant_params = detail['quantization_parameters'] @@ -3172,7 +3165,7 @@ def testSequentialModel(self): converter = lite.TFLiteConverterV2.from_keras_model(model) tflite_model = converter.convert() # Check the conversion metadata. - metadata = get_conversion_metadata(tflite_model) + metadata = util.get_conversion_metadata(tflite_model) self.assertIsNotNone(metadata) self.assertEqual( metadata.environment.modelType, metadata_fb.ModelType.KERAS_MODEL @@ -3252,7 +3245,7 @@ def testKerasFallbackPath(self): class Model(tf.keras.Model): def __init__(self): - super(Model, self).__init__() + super().__init__() # A None name will cause a failure in exporting to a saved model. self.shared_weights = self.add_weight( name=None, @@ -3299,8 +3292,8 @@ def testSignatureDefs(self): np.random.uniform(-1, 1, size=(1, 32, 32, 3)).astype(np.float32) ) expected_value = keras_model(input_data) - interpreter = Interpreter(model_content=tflite_model) - signature_defs = interpreter.get_signature_list() + interp = interpreter.Interpreter(model_content=tflite_model) + signature_defs = interp.get_signature_list() results = self._evaluateTFLiteModelUsingSignatureDef( tflite_model, 'serving_default', {'tensor_input': input_data} ) @@ -3308,9 +3301,9 @@ def testSignatureDefs(self): self.assertAllClose(expected_value.numpy(), results['output_tensor']) # Verify the SignatureDef structure returned is as expected. - self.assertEqual(len(signature_defs), 1) + self.assertLen(signature_defs, 1) self.assertEqual(list(signature_defs.keys()), ['serving_default']) - self.assertEqual(len(signature_defs.values()), 1) + self.assertLen(signature_defs.values(), 1) self.assertEqual( list(signature_defs['serving_default'].keys()), ['inputs', 'outputs'] ) @@ -3354,17 +3347,17 @@ def testMlirDynamicRangeQuantization( quantized_tflite_model = converter.convert() self.assertIsNotNone(quantized_tflite_model) - interpreter = Interpreter(model_content=quantized_tflite_model) - interpreter.allocate_tensors() + interp = interpreter.Interpreter(model_content=quantized_tflite_model) + interp.allocate_tensors() quantized_weight = None quantized_weight_with_one_postfix = None quantized_weight_without_one_postfix = None - for d in interpreter.get_tensor_details(): + for d in interp.get_tensor_details(): if d['name'] == conv_name + '1': quantized_weight = d quantized_weight_with_one_postfix = d break - for d in interpreter.get_tensor_details(): + for d in interp.get_tensor_details(): if d['name'].startswith(conv_name): if quantized_weight is None: quantized_weight = d @@ -3381,8 +3374,8 @@ def testMlirDynamicRangeQuantization( self.assertLen(quant_params['scales'], expected_num_params) self.assertLen(quant_params['zero_points'], expected_num_params) - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() + input_details = interp.get_input_details() + output_details = interp.get_output_details() self.assertEqual(np.float32, input_details[0]['dtype']) self.assertEqual(np.float32, output_details[0]['dtype']) if enable_float16_quant: @@ -3439,8 +3432,8 @@ def testQATLowBitKerasModel(self, num_bits, weight_only, low_bit): self.assertAllClose( [np.linalg.norm(result - tf_result.numpy().astype(np.float32))], [0.0] ) - interpreter = tf.lite.Interpreter(model_content=tflite_model) - interpreter.allocate_tensors() + interp = interpreter.Interpreter(model_content=tflite_model) + interp.allocate_tensors() num_8bit_activations = 0 num_8bit_weights = 0 kernel_name = ( @@ -3448,14 +3441,14 @@ def testQATLowBitKerasModel(self, num_bits, weight_only, low_bit): 'FakeQuantWithMinMaxVarsPerChannel' ) - for detail in interpreter.get_tensor_details(): + for detail in interp.get_tensor_details(): if ( detail['dtype'] == np.int8 and detail['name'] and detail['name'] == kernel_name ): num_8bit_weights += 1 - weights = interpreter.get_tensor(detail['index']) + weights = interp.get_tensor(detail['index']) if low_bit: self.assertFalse( (bit_min > weights).any() or (weights > bit_max).any() @@ -3525,15 +3518,15 @@ def call(self, inputs): converted_model = converter.convert() tf.lite.experimental.Analyzer.analyze(model_content=converted_model) - interpreter = tf.lite.Interpreter(model_content=converted_model) - interpreter.allocate_tensors() + interp = interpreter.Interpreter(model_content=converted_model) + interp.allocate_tensors() - input_index = interpreter.get_input_details()[0]['index'] - output_index = interpreter.get_output_details()[0]['index'] + input_index = interp.get_input_details()[0]['index'] + output_index = interp.get_output_details()[0]['index'] - interpreter.set_tensor(input_index, input_data.astype(np.float32)) - interpreter.invoke() - tflite_result = interpreter.tensor(output_index)() + interp.set_tensor(input_index, input_data.astype(np.float32)) + interp.invoke() + tflite_result = interp.tensor(output_index)() self.assertAllClose( [np.linalg.norm(tflite_result - tf_result.numpy().astype(np.float32))], @@ -3541,13 +3534,67 @@ def call(self, inputs): ) num_float32_tensor = 0 - for detail in interpreter.get_tensor_details(): + for detail in interp.get_tensor_details(): if detail['dtype'] == np.float32: num_float32_tensor += 1 # There should be only 2 float tensors, input and output. self.assertEqual(num_float32_tensor, 2) + @parameterized.named_parameters( + ('_PerChannelQuant', False, False), + ('_PerChannelMlirQuant', False, True), + ('_PerTensorQuant', True, False), + ('_PerTensorMlirQuant', True, True), + ('_PerChannelDynamicRange', False, True, True), + ('_PerTensorDynamicRange', True, True, True), + ) + @test_util.run_v2_only + def testDisablePerChannelQuantizationForDenseLayers( + self, + disable_per_channel_for_dense=False, + enable_mlir_quantizer=False, + representative_dataset=False, + ): + k_dense_name = 'tfl.pseudo_qconst' if representative_dataset else 'MatMul' + # Dynamic range quant requires total num elements of filters > 1024. + k_num_filters = 64 + model = tf.keras.models.Sequential([ + tf.keras.Input(shape=(16,)), + tf.keras.layers.Dense(k_num_filters, activation='relu'), + ]) + model.build() + + quantized_converter = lite.TFLiteConverterV2.from_keras_model(model) + quantized_converter.optimizations = [lite.Optimize.DEFAULT] + if representative_dataset: + + def calibration_gen(): + for _ in range(5): + yield [np.random.uniform(-1, 1, size=(1, 16)).astype(np.float32)] + + quantized_converter.representative_dataset = calibration_gen + quantized_converter.target_spec.supported_ops = [ + lite.OpsSet.TFLITE_BUILTINS + ] + quantized_converter.experimental_new_quantizer = enable_mlir_quantizer + if disable_per_channel_for_dense: + quantized_converter._experimental_disable_per_channel_quantization_for_dense_layers = ( + disable_per_channel_for_dense + ) + quantized_tflite_model = quantized_converter.convert() + self.assertIsNotNone(quantized_tflite_model) + + interp = interpreter.Interpreter(model_content=quantized_tflite_model) + interp.allocate_tensors() + detail = next( + (d for d in interp.get_tensor_details() if k_dense_name in d['name']) + ) + quant_params = detail['quantization_parameters'] + expected_num_params = 1 if disable_per_channel_for_dense else k_num_filters + self.assertLen(quant_params['scales'], expected_num_params) + self.assertLen(quant_params['zero_points'], expected_num_params) + class FromJaxModelTest(lite_v2_test_util.ModelTest): @@ -3645,7 +3692,7 @@ def single_input(input_tensor): ) tflite_model = converter.convert() # Check the conversion metadata. - metadata = get_conversion_metadata(tflite_model) + metadata = util.get_conversion_metadata(tflite_model) self.assertIsNotNone(metadata) self.assertEqual(metadata.environment.modelType, metadata_fb.ModelType.JAX) @@ -4246,9 +4293,9 @@ def testMatMulQuantize(self): quantized_tflite_model = quantized_converter.convert() # The default input and output types should be float. - quantized_interpreter = Interpreter(model_content=quantized_tflite_model) - quantized_interpreter.allocate_tensors() - input_details = quantized_interpreter.get_input_details() + interp = interpreter.Interpreter(model_content=quantized_tflite_model) + interp.allocate_tensors() + input_details = interp.get_input_details() self.assertLen(input_details, 1) self.assertEqual(np.float32, input_details[0]['dtype']) self.assertAllEqual([-1, 33], input_details[0]['shape_signature']) @@ -4274,9 +4321,9 @@ def testMatMulCalibrateAndQuantize(self): quantized_tflite_model = quantized_converter.convert() # The default input and output types should be float. - quantized_interpreter = Interpreter(model_content=quantized_tflite_model) - quantized_interpreter.allocate_tensors() - input_details = quantized_interpreter.get_input_details() + interp = interpreter.Interpreter(model_content=quantized_tflite_model) + interp.allocate_tensors() + input_details = interp.get_input_details() self.assertLen(input_details, 1) self.assertEqual(np.float32, input_details[0]['dtype']) self.assertAllEqual([-1, 33], input_details[0]['shape_signature']) @@ -4457,25 +4504,25 @@ def model(v): self.assertIsNotNone(tflite_model) # Check values from converted model. - interpreter = Interpreter(model_content=tflite_model) - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() + interp = interpreter.Interpreter(model_content=tflite_model) + input_details = interp.get_input_details() + output_details = interp.get_output_details() - interpreter.allocate_tensors() + interp.allocate_tensors() input_data = np.array([1.0], dtype=np.float32) - interpreter.set_tensor(input_details[0]['index'], input_data) + interp.set_tensor(input_details[0]['index'], input_data) - interpreter.invoke() - actual_value = interpreter.get_tensor(output_details[0]['index']) + interp.invoke() + actual_value = interp.get_tensor(output_details[0]['index']) self.assertEqual(1, actual_value) - interpreter.invoke() - actual_value = interpreter.get_tensor(output_details[0]['index']) + interp.invoke() + actual_value = interp.get_tensor(output_details[0]['index']) self.assertEqual(1, actual_value) - interpreter.invoke() - actual_value = interpreter.get_tensor(output_details[0]['index']) + interp.invoke() + actual_value = interp.get_tensor(output_details[0]['index']) self.assertEqual(1, actual_value) @test_util.run_v2_only @@ -4516,26 +4563,26 @@ def body(i, m): self.assertIsNotNone(tflite_model) # Check values from converted model. - interpreter = Interpreter(model_content=tflite_model) - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() + interp = interpreter.Interpreter(model_content=tflite_model) + input_details = interp.get_input_details() + output_details = interp.get_output_details() - interpreter.allocate_tensors() + interp.allocate_tensors() input_data = np.array([0], dtype=np.int32) - interpreter.set_tensor(input_details[0]['index'], input_data) + interp.set_tensor(input_details[0]['index'], input_data) - interpreter.invoke() + interp.invoke() expected_value = np.array([1], dtype=np.int32) - actual_value = interpreter.get_tensor(output_details[0]['index']) + actual_value = interp.get_tensor(output_details[0]['index']) self.assertEqual(expected_value, actual_value) - interpreter.invoke() - actual_value = interpreter.get_tensor(output_details[0]['index']) + interp.invoke() + actual_value = interp.get_tensor(output_details[0]['index']) self.assertEqual(expected_value, actual_value) - interpreter.invoke() - actual_value = interpreter.get_tensor(output_details[0]['index']) + interp.invoke() + actual_value = interp.get_tensor(output_details[0]['index']) self.assertEqual(expected_value, actual_value) @test_util.run_v2_only @@ -4576,25 +4623,25 @@ def body(i, m): self.assertIsNotNone(tflite_model) # Check values from converted model. - interpreter = Interpreter(model_content=tflite_model) - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() + interp = interpreter.Interpreter(model_content=tflite_model) + input_details = interp.get_input_details() + output_details = interp.get_output_details() - interpreter.allocate_tensors() + interp.allocate_tensors() input_data = np.array([0], dtype=np.int32) - interpreter.set_tensor(input_details[0]['index'], input_data) + interp.set_tensor(input_details[0]['index'], input_data) - interpreter.invoke() - actual_value = interpreter.get_tensor(output_details[0]['index']) + interp.invoke() + actual_value = interp.get_tensor(output_details[0]['index']) self.assertEqual(10, actual_value) - interpreter.invoke() - actual_value = interpreter.get_tensor(output_details[0]['index']) + interp.invoke() + actual_value = interp.get_tensor(output_details[0]['index']) self.assertEqual(10, actual_value) - interpreter.invoke() - actual_value = interpreter.get_tensor(output_details[0]['index']) + interp.invoke() + actual_value = interp.get_tensor(output_details[0]['index']) self.assertEqual(10, actual_value) @test_util.run_v2_only @@ -4632,25 +4679,25 @@ def create_v1_saved_model(): self.assertIsNotNone(tflite_model) # Check values from converted model. - interpreter = Interpreter(model_content=tflite_model) - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() + interp = interpreter.Interpreter(model_content=tflite_model) + input_details = interp.get_input_details() + output_details = interp.get_output_details() - interpreter.allocate_tensors() + interp.allocate_tensors() input_data = np.array([1.0], dtype=np.float32) - interpreter.set_tensor(input_details[0]['index'], input_data) + interp.set_tensor(input_details[0]['index'], input_data) - interpreter.invoke() - actual_value = interpreter.get_tensor(output_details[0]['index']) + interp.invoke() + actual_value = interp.get_tensor(output_details[0]['index']) self.assertEqual(3.0, actual_value) - interpreter.invoke() - actual_value = interpreter.get_tensor(output_details[0]['index']) + interp.invoke() + actual_value = interp.get_tensor(output_details[0]['index']) self.assertEqual(3.0, actual_value) - interpreter.invoke() - actual_value = interpreter.get_tensor(output_details[0]['index']) + interp.invoke() + actual_value = interp.get_tensor(output_details[0]['index']) self.assertEqual(3.0, actual_value) @test_util.run_v2_only @@ -4695,17 +4742,17 @@ def body(i, arr): self.assertIsNotNone(tflite_model) # Check values from converted model. - interpreter = Interpreter(model_content=tflite_model) - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() + interp = interpreter.Interpreter(model_content=tflite_model) + input_details = interp.get_input_details() + output_details = interp.get_output_details() - interpreter.allocate_tensors() + interp.allocate_tensors() input_data = np.array([1.0], dtype=np.float32) - interpreter.set_tensor(input_details[0]['index'], input_data) + interp.set_tensor(input_details[0]['index'], input_data) - interpreter.invoke() - actual_value = interpreter.get_tensor(output_details[0]['index']) + interp.invoke() + actual_value = interp.get_tensor(output_details[0]['index']) self.assertEqual(0.0, actual_value) @test_util.run_v2_only @@ -4756,17 +4803,17 @@ def body(i, arr, m): self.assertIsNotNone(tflite_model) # Check values from converted model. - interpreter = Interpreter(model_content=tflite_model) - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() + interp = interpreter.Interpreter(model_content=tflite_model) + input_details = interp.get_input_details() + output_details = interp.get_output_details() - interpreter.allocate_tensors() + interp.allocate_tensors() input_data = np.array([1.0], dtype=np.float32) - interpreter.set_tensor(input_details[0]['index'], input_data) + interp.set_tensor(input_details[0]['index'], input_data) - interpreter.invoke() - actual_value = interpreter.get_tensor(output_details[0]['index']) + interp.invoke() + actual_value = interp.get_tensor(output_details[0]['index']) self.assertEqual(9.0, actual_value) @parameterized.named_parameters( @@ -4812,17 +4859,17 @@ def create_v1_saved_model(): self.assertIsNotNone(tflite_model) # Check values from converted model. - interpreter = Interpreter(model_content=tflite_model) - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() + interp = interpreter.Interpreter(model_content=tflite_model) + input_details = interp.get_input_details() + output_details = interp.get_output_details() - interpreter.allocate_tensors() + interp.allocate_tensors() input_data = np.array([1.0], dtype=np.float32) - interpreter.set_tensor(input_details[0]['index'], input_data) + interp.set_tensor(input_details[0]['index'], input_data) - interpreter.invoke() - actual_value = interpreter.get_tensor(output_details[0]['index']) + interp.invoke() + actual_value = interp.get_tensor(output_details[0]['index']) self.assertEqual(40.0, actual_value) @parameterized.named_parameters( @@ -4876,17 +4923,17 @@ def create_v1_saved_model(): self.assertIsNotNone(tflite_model) # Check values from converted model. - interpreter = Interpreter(model_content=tflite_model) - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() + interp = interpreter.Interpreter(model_content=tflite_model) + input_details = interp.get_input_details() + output_details = interp.get_output_details() - interpreter.allocate_tensors() + interp.allocate_tensors() input_data = np.array([1.0], dtype=np.float32) - interpreter.set_tensor(input_details[0]['index'], input_data) + interp.set_tensor(input_details[0]['index'], input_data) - interpreter.invoke() - actual_value = interpreter.get_tensor(output_details[0]['index']) + interp.invoke() + actual_value = interp.get_tensor(output_details[0]['index']) self.assertEqual(40.0, actual_value) @@ -4929,23 +4976,23 @@ def testCustomOpRegistererByName(self): self.assertGreater(test_registerer.get_num_test_registerer_calls(), 0) self.assertIn('Double', tflite_test_util.get_ops_list(tflite_model)) # Check the conversion metadata. - metadata = get_conversion_metadata(tflite_model) + metadata = util.get_conversion_metadata(tflite_model) self.assertIsNotNone(metadata) self.assertEqual(metadata.options.allowCustomOps, True) # Check the model works with custom ops. - interpreter = InterpreterWithCustomOps( + interp = interpreter.InterpreterWithCustomOps( model_content=tflite_model, custom_op_registerers=['TF_TestRegisterer'] ) - interpreter.allocate_tensors() - input_details = interpreter.get_input_details() + interp.allocate_tensors() + input_details = interp.get_input_details() test_input = np.array([[0.0, 0.1, 0.2, 0.3]], dtype=np.float32) - interpreter.set_tensor(input_details[0]['index'], test_input) - interpreter.invoke() + interp.set_tensor(input_details[0]['index'], test_input) + interp.invoke() - output_details = interpreter.get_output_details() + output_details = interp.get_output_details() expected_output = np.array([[0.0, 0.2, 0.4, 0.6]], dtype=np.float32) - output_data = interpreter.get_tensor(output_details[0]['index']) + output_data = interp.get_tensor(output_details[0]['index']) self.assertArrayNear(expected_output[0], output_data[0], err=1e-2) def testCustomOpRegistererByFunc(self): @@ -4965,19 +5012,19 @@ def testCustomOpRegistererByFunc(self): self.assertIn('Double', tflite_test_util.get_ops_list(tflite_model)) # Check the model works with custom ops. - interpreter = InterpreterWithCustomOps( + interp = interpreter.InterpreterWithCustomOps( model_content=tflite_model, custom_op_registerers=[test_registerer.TF_TestRegisterer], ) - interpreter.allocate_tensors() - input_details = interpreter.get_input_details() + interp.allocate_tensors() + input_details = interp.get_input_details() test_input = np.array([[0.0, 0.1, 0.2, 0.3]], dtype=np.float32) - interpreter.set_tensor(input_details[0]['index'], test_input) - interpreter.invoke() + interp.set_tensor(input_details[0]['index'], test_input) + interp.invoke() - output_details = interpreter.get_output_details() + output_details = interp.get_output_details() expected_output = np.array([[0.0, 0.2, 0.4, 0.6]], dtype=np.float32) - output_data = interpreter.get_tensor(output_details[0]['index']) + output_data = interp.get_tensor(output_details[0]['index']) self.assertArrayNear(expected_output[0], output_data[0], err=1e-2) def testCustomOpRegistererFailure(self): @@ -5008,7 +5055,7 @@ def f(x): w = tf.add(z, z, name='w') return w - # NOTE this is exactly representable as a float as are the intermeidates of + # NOTE this is exactly representable as a float as are the intermediates of # f. So direct comparison is ok below. input_data = np.array(2.0, np.float32) @@ -5017,25 +5064,23 @@ def f(x): [concrete_func], f ) tflite_model = converter.convert() - interpreter = Interpreter( + interp = interpreter.Interpreter( model_content=tflite_model, experimental_preserve_all_tensors=experimental_preserve_all_tensors, ) - interpreter.allocate_tensors() - interpreter.set_tensor( - interpreter.get_input_details()[0]['index'], input_data - ) - interpreter.invoke() - out = interpreter.get_tensor(interpreter.get_output_details()[0]['index']) + interp.allocate_tensors() + interp.set_tensor(interp.get_input_details()[0]['index'], input_data) + interp.invoke() + out = interp.get_tensor(interp.get_output_details()[0]['index']) tensors = {} - for t in interpreter.get_tensor_details(): + for t in interp.get_tensor_details(): # With Tensorflow Lite default delegate applied to the model graph, the # access to original tensors of a delegated op could cause a ValueError # (i.e. 'Tensor data is null. Run allocate_tensors() first') to be thrown # out because the tensor memory isn't allocated at all. val = None try: - val = interpreter.get_tensor(t['index']) + val = interp.get_tensor(t['index']) except ValueError: pass tensors.update({t['name']: val}) @@ -5080,13 +5125,13 @@ def model(): self.assertIsNotNone(tflite_model) # Check values from converted model. - interpreter = Interpreter(model_content=tflite_model) - output_details = interpreter.get_output_details() + interp = interpreter.Interpreter(model_content=tflite_model) + output_details = interp.get_output_details() - interpreter.allocate_tensors() + interp.allocate_tensors() - interpreter.invoke() - actual_value = interpreter.get_tensor(output_details[0]['index']) + interp.invoke() + actual_value = interp.get_tensor(output_details[0]['index']) self.assertEqual(10, actual_value) @@ -5125,7 +5170,7 @@ def testRandomSparsity(self): float_tflite_model = float_converter.convert() self.assertIsNotNone(float_tflite_model) # Check the conversion metadata. - metadata = get_conversion_metadata(float_tflite_model) + metadata = util.get_conversion_metadata(float_tflite_model) self.assertIsNotNone(metadata) self.assertAllEqual( [metadata_fb.ModelOptimizationMode.RANDOM_SPARSITY], @@ -5147,14 +5192,20 @@ def testBlockSparsity(self): float_tflite_model = float_converter.convert() self.assertIsNotNone(float_tflite_model) # Check the conversion metadata. - metadata = get_conversion_metadata(float_tflite_model) + metadata = util.get_conversion_metadata(float_tflite_model) self.assertIsNotNone(metadata) self.assertAllEqual( [metadata_fb.ModelOptimizationMode.BLOCK_SPARSITY], metadata.options.modelOptimizationModes, ) - def testQuantizedBlockSparsity(self): + @parameterized.named_parameters( + ('_PerChannelQuantForDense', False), + ('_PerTensorQuantForDense', True), + ) + def testQuantizedBlockSparsity( + self, disable_per_channel_quantization_for_dense_layers + ): weight_values = np.array([ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 2, 0, 0, 0, 0, 5, 0, 0, 0, 3, 0, 0, 0, 1, 0], @@ -5183,11 +5234,14 @@ def calibration_gen(): lite.Optimize.DEFAULT, ] quantized_converter.representative_dataset = calibration_gen + quantized_converter._experimental_disable_per_channel_quantization_for_dense_layers = ( + disable_per_channel_quantization_for_dense_layers + ) quantized_tflite_model = quantized_converter.convert() self.assertIsNotNone(quantized_tflite_model) # Check the conversion metadata. - metadata = get_conversion_metadata(quantized_tflite_model) + metadata = util.get_conversion_metadata(quantized_tflite_model) self.assertIsNotNone(metadata) self.assertEqual( metadata.environment.tensorflowVersion.decode('utf-8'), @@ -5203,17 +5257,17 @@ def calibration_gen(): ) # Check values from converted model. - interpreter = Interpreter(model_content=quantized_tflite_model) - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() - interpreter.allocate_tensors() + interp = interpreter.Interpreter(model_content=quantized_tflite_model) + input_details = interp.get_input_details() + output_details = interp.get_output_details() + interp.allocate_tensors() input_data = np.array( [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]], dtype=np.float32, ) - interpreter.set_tensor(input_details[0]['index'], input_data) - interpreter.invoke() - actual_value = interpreter.get_tensor(output_details[0]['index']) + interp.set_tensor(input_details[0]['index'], input_data) + interp.invoke() + actual_value = interp.get_tensor(output_details[0]['index']) self.assertArrayNear( np.array([0, 87, 0, 0, 0, 0, 0, 34], dtype=np.float32), actual_value.flatten(), @@ -5250,7 +5304,7 @@ def calibration_gen(): self.assertIsNotNone(quantized_tflite_model) # Check the conversion metadata. - metadata = get_conversion_metadata(quantized_tflite_model) + metadata = util.get_conversion_metadata(quantized_tflite_model) self.assertIsNotNone(metadata) self.assertEqual( metadata.environment.tensorflowVersion.decode('utf-8'), @@ -5273,17 +5327,17 @@ def calibration_gen(): ) # Check values from converted model. - interpreter = Interpreter(model_content=quantized_tflite_model) - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() - interpreter.allocate_tensors() + interp = interpreter.Interpreter(model_content=quantized_tflite_model) + input_details = interp.get_input_details() + output_details = interp.get_output_details() + interp.allocate_tensors() input_data = np.array( [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]], dtype=np.float32, ) - interpreter.set_tensor(input_details[0]['index'], input_data) - interpreter.invoke() - actual_value = interpreter.get_tensor(output_details[0]['index']) + interp.set_tensor(input_details[0]['index'], input_data) + interp.invoke() + actual_value = interp.get_tensor(output_details[0]['index']) self.assertArrayNear( np.array([0, -3, 4, 35], dtype=np.float32), actual_value.flatten(), @@ -5333,9 +5387,9 @@ def __call__(self, x): [str(x) for x in range(11)], shape=(11,), dtype=tf.dtypes.string ) # Check values from converted model. - interpreter = tf.lite.Interpreter(model_content=tflite_model) - interpreter.allocate_tensors() - my_signature = interpreter.get_signature_runner() + interp = interpreter.Interpreter(model_content=tflite_model) + interp.allocate_tensors() + my_signature = interp.get_signature_runner() with self.assertRaises(ValueError) as error: _ = my_signature(x=input_data) @@ -5355,7 +5409,7 @@ def testSavedModelSignatureDefs(self): ) save_dir = os.path.join(self.get_temp_dir(), 'saved_model') - save(root, save_dir, {'mul_add': mul_add_func}) + save.save(root, save_dir, {'mul_add': mul_add_func}) converter = lite.TFLiteConverterV2.from_saved_model( save_dir, signature_keys=['mul_add'] @@ -5365,8 +5419,8 @@ def testSavedModelSignatureDefs(self): # Check values from converted model. expected_value = root.mul_add(input_data_1, input_data_0) - interpreter = Interpreter(model_content=tflite_model) - signature_defs = interpreter.get_signature_list() + interp = interpreter.Interpreter(model_content=tflite_model) + signature_defs = interp.get_signature_list() results = self._evaluateTFLiteModelUsingSignatureDef( tflite_model, 'mul_add', {'y': input_data_0, 'x': input_data_1} ) @@ -5374,9 +5428,9 @@ def testSavedModelSignatureDefs(self): self.assertEqual(expected_value.numpy(), results['output_0']) # Verify the SignatureDef structure returned is as expected. - self.assertEqual(len(signature_defs), 1) + self.assertLen(signature_defs, 1) self.assertEqual(list(signature_defs.keys()), ['mul_add']) - self.assertEqual(len(signature_defs.values()), 1) + self.assertLen(signature_defs.values(), 1) self.assertEqual( list(signature_defs['mul_add'].keys()), ['inputs', 'outputs'] ) diff --git a/tensorflow/lite/python/wrap_toco.py b/tensorflow/lite/python/wrap_toco.py index 049d2babf36ff8..9d8a8bc11a6456 100644 --- a/tensorflow/lite/python/wrap_toco.py +++ b/tensorflow/lite/python/wrap_toco.py @@ -17,12 +17,16 @@ # pylint: disable=invalid-import-order,g-bad-import-order from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import from tensorflow.python import _pywrap_toco_api +from tensorflow.compiler.mlir.quantization.tensorflow.python import py_function_lib -# TODO(b/137402359): Remove lazy loading wrapper - -def wrapped_toco_convert(model_flags_str, toco_flags_str, input_data_str, - debug_info_str, enable_mlir_converter): +def wrapped_toco_convert( + model_flags_str, + toco_flags_str, + input_data_str, + debug_info_str, + enable_mlir_converter, +): """Wraps TocoConvert with lazy loader.""" return _pywrap_toco_api.TocoConvert( model_flags_str, @@ -30,20 +34,40 @@ def wrapped_toco_convert(model_flags_str, toco_flags_str, input_data_str, input_data_str, False, # extended_return debug_info_str, - enable_mlir_converter) + enable_mlir_converter, + py_function_lib.PyFunctionLibrary(), + ) def wrapped_experimental_mlir_quantize( - input_data_str, disable_per_channel, fully_quantize, inference_type, - input_data_type, output_data_type, enable_numeric_verify, - enable_whole_model_verify, denylisted_ops, denylisted_nodes, - enable_variable_quantization): + input_data_str, + disable_per_channel, + fully_quantize, + inference_type, + input_data_type, + output_data_type, + enable_numeric_verify, + enable_whole_model_verify, + denylisted_ops, + denylisted_nodes, + enable_variable_quantization, + disable_per_channel_for_dense_layers, +): """Wraps experimental mlir quantize model.""" return _pywrap_toco_api.ExperimentalMlirQuantizeModel( - input_data_str, disable_per_channel, fully_quantize, inference_type, - input_data_type, output_data_type, enable_numeric_verify, - enable_whole_model_verify, denylisted_ops, denylisted_nodes, - enable_variable_quantization) + input_data_str, + disable_per_channel, + fully_quantize, + inference_type, + input_data_type, + output_data_type, + enable_numeric_verify, + enable_whole_model_verify, + denylisted_ops, + denylisted_nodes, + enable_variable_quantization, + disable_per_channel_for_dense_layers, + ) def wrapped_experimental_mlir_sparsify(input_data_str): diff --git a/tensorflow/lite/testing/join.h b/tensorflow/lite/testing/join.h index 3f17f7fad46a8b..f7337d701fec63 100644 --- a/tensorflow/lite/testing/join.h +++ b/tensorflow/lite/testing/join.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_TESTING_JOIN_H_ #define TENSORFLOW_LITE_TESTING_JOIN_H_ +#include #include #include #include diff --git a/tensorflow/lite/toco/python/BUILD b/tensorflow/lite/toco/python/BUILD index 1357f2e76988ef..145a9aa0e0a57a 100644 --- a/tensorflow/lite/toco/python/BUILD +++ b/tensorflow/lite/toco/python/BUILD @@ -1,5 +1,5 @@ -load("//tensorflow:tensorflow.default.bzl", "tf_py_strict_test") load("//tensorflow:strict.default.bzl", "py_strict_binary", "py_strict_library") +load("//tensorflow:tensorflow.default.bzl", "tf_py_strict_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -45,6 +45,7 @@ cc_library( "//tensorflow/compiler/mlir/lite/python:saved_model_to_tfl_flatbuffer", "//tensorflow/compiler/mlir/lite/quantization/lite:quantize_model", "//tensorflow/compiler/mlir/lite/sparsity:sparsify_model", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -99,6 +100,7 @@ py_strict_binary( srcs_version = "PY3", deps = [ "@absl_py//absl:app", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib_py", "//tensorflow/python:_pywrap_toco_api", "//tensorflow/python:pywrap_tensorflow", # Needed to provide PyArray_API diff --git a/tensorflow/lite/toco/python/toco_from_protos.py b/tensorflow/lite/toco/python/toco_from_protos.py index 617b5e9093d795..663bf6fd79bd52 100644 --- a/tensorflow/lite/toco/python/toco_from_protos.py +++ b/tensorflow/lite/toco/python/toco_from_protos.py @@ -20,6 +20,7 @@ # pylint: disable=invalid-import-order,g-bad-import-order from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import from tensorflow.python import _pywrap_toco_api +from tensorflow.compiler.mlir.quantization.tensorflow.python import py_function_lib # pylint: disable=unused-import; required for TocoConvert to understand the type: PyFunctionLibrary from absl import app FLAGS = None @@ -49,7 +50,9 @@ def execute(unused_args): input_str, False, # extended_return debug_info_str, - enable_mlir_converter) + enable_mlir_converter, + None, # quantization_py_function_library + ) open(FLAGS.model_output_file, "wb").write(output_str) sys.exit(0) diff --git a/tensorflow/lite/toco/python/toco_python_api.cc b/tensorflow/lite/toco/python/toco_python_api.cc index 48af2bdce7cb9a..345380916fbc55 100644 --- a/tensorflow/lite/toco/python/toco_python_api.cc +++ b/tensorflow/lite/toco/python/toco_python_api.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h" #include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h" #include "tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/lite/core/api/error_reporter.h" @@ -96,8 +97,9 @@ void PopulateConversionLogHelper(const toco::ModelFlags& model_flags, PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, PyObject* toco_flags_proto_txt_raw, PyObject* input_contents_txt_raw, bool extended_return, - PyObject* debug_info_txt_raw, - bool enable_mlir_converter) { + PyObject* debug_info_txt_raw, bool enable_mlir_converter, + const tensorflow::quantization::PyFunctionLibrary* + quantization_py_function_library) { // Use Python C API to validate and convert arguments. In py3 (bytes), // in py2 (str). auto ConvertArg = [&](PyObject* obj, bool* error) { @@ -196,7 +198,8 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, &output_file_contents_txt); } else if (!model_flags.saved_model_dir().empty()) { status = tensorflow::ConvertSavedModelToTFLiteFlatBuffer( - model_flags, toco_flags, &output_file_contents_txt); + model_flags, toco_flags, &output_file_contents_txt, + quantization_py_function_library); } else { tensorflow::GraphDef graph_def; if (!graph_def.ParseFromString(input_contents_txt)) { @@ -294,7 +297,8 @@ PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel, bool enable_numeric_verify, bool enable_whole_model_verify, PyObject* op_denylist, PyObject* node_denylist, - bool enable_variable_quantization) { + bool enable_variable_quantization, + bool disable_per_channel_for_dense_layers) { using tflite::interpreter_wrapper::PythonErrorReporter; char* buf = nullptr; Py_ssize_t length; @@ -340,7 +344,7 @@ PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel, /*operator_names=*/{}, disable_per_channel, fully_quantize, output_model, error_reporter.get(), enable_numeric_verify, enable_whole_model_verify, /*legacy_float_scale=*/true, denylisted_ops, denylisted_nodes, - enable_variable_quantization); + enable_variable_quantization, disable_per_channel_for_dense_layers); if (status != kTfLiteOk) { error_reporter->exception(); return nullptr; diff --git a/tensorflow/lite/toco/python/toco_python_api.h b/tensorflow/lite/toco/python/toco_python_api.h index 3cd289037def18..37d42bcf170316 100644 --- a/tensorflow/lite/toco/python/toco_python_api.h +++ b/tensorflow/lite/toco/python/toco_python_api.h @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" + namespace toco { // Convert a model represented in `input_contents`. `model_flags_proto` @@ -36,7 +38,9 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, PyObject* input_contents_txt_raw, bool extended_return = false, PyObject* debug_info_txt_raw = nullptr, - bool enable_mlir_converter = false); + bool enable_mlir_converter = false, + const tensorflow::quantization::PyFunctionLibrary* + quantization_py_function_library = nullptr); // Quantize the model with calibration data. Throw errors if `fully_quantize` // is specified by the calibration data are not sufficient to quantize the @@ -48,7 +52,8 @@ PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel, bool enable_whole_model_verify = false, PyObject* op_denylist = nullptr, PyObject* node_denylist = nullptr, - bool enable_variable_quantization = false); + bool enable_variable_quantization = false, + bool disable_per_channel_for_dense_layers = false); // Sparsifies model to encode sparse tensors with proper format. Throws error if // sparsification fails. diff --git a/tensorflow/lite/toco/toco_flags.proto b/tensorflow/lite/toco/toco_flags.proto index 68a8a9de5b9e9d..e7a31365a65330 100644 --- a/tensorflow/lite/toco/toco_flags.proto +++ b/tensorflow/lite/toco/toco_flags.proto @@ -41,7 +41,7 @@ enum FileFormat { // of as properties of models, instead describing how models are to be // processed in the context of the present tooling job. // -// Next ID to use: 62. +// Next ID to use: 63. message TocoFlags { // Input file format optional FileFormat input_format = 1; @@ -350,4 +350,10 @@ message TocoFlags { // Quantizer integrated in the converter. // WARNING: Experimental interface, subject to change. optional stablehlo.quantization.QuantizationConfig quantization_config = 61; + + // Disables per channel weights quantization for Dense layers and enables + // legacy per tensor quantization. The legacy quantization for Dense layers is + // inconsistent with Conv 1x1 which always performs per channel quantization. + optional bool disable_per_channel_quantization_for_dense_layers = 62 + [default = false]; } diff --git a/tensorflow/lite/toco/toco_port.cc b/tensorflow/lite/toco/toco_port.cc index 289decfe92683f..56bac97b2ed1ba 100644 --- a/tensorflow/lite/toco/toco_port.cc +++ b/tensorflow/lite/toco/toco_port.cc @@ -129,13 +129,15 @@ std::string JoinPath(const std::string& a, const std::string& b) { #else // !PLATFORM_GOOGLE || __APPLE__ || __ANDROID__ || _WIN32 #include -#if defined(_WIN32) -#include // for _close, _open, _read -#endif #include #include -#include + #include +#if defined(_WIN32) +#include // for _close, _open, _read +#else +#include +#endif #if defined(PLATFORM_GOOGLE) #include "base/commandlineflags.h" diff --git a/tensorflow/lite/tools/benchmark/README.md b/tensorflow/lite/tools/benchmark/README.md index 2d7398525b99dc..1eb038b45ca77d 100644 --- a/tensorflow/lite/tools/benchmark/README.md +++ b/tensorflow/lite/tools/benchmark/README.md @@ -106,6 +106,12 @@ and the following optional parameters: Whether to optimize memory usage for large tensors with sacrificing latency. When the feature is enabled, `release_dynamic_tensors` is also enabled. +* `enable_builtin_cast_constant_cache`: `bool` (default=false) \ + Configure the builtin TFLite CAST operation to cache its output if its input + is a constant tensor. + + WARNING: This is an experimental option that may be removed at any time. + This list of parameters is not exhaustive. See [here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/benchmark/benchmark_model.cc) and diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc index a36eab0f29ab53..05ce93d2e6d588 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc @@ -265,14 +265,13 @@ TfLiteStatus PopulateInputLayerInfo( std::vector shapes = Split(shapes_string, ':'); if (names.size() != shapes.size()) { - TFLITE_LOG(ERROR) << "The number of items in" - << " --input_layer_shape (" << shapes_string << ", with " - << shapes.size() << " items)" - << " must match the number of items in" - << " --input_layer (" << names_string << ", with " - << names.size() << " items)." - << " For example --input_layer=input1,input2" - << " --input_layer_shape=1,224,224,4:1,20"; + TFLITE_LOG(ERROR) + << "The number of items in --input_layer_shape (" << shapes_string + << ", with " << shapes.size() + << " items) must match the number of items in --input_layer (" + << names_string << ", with " << names.size() + << " items). For example --input_layer=input1,input2 " + "--input_layer_shape=1,224,224,4:1,20"; return kTfLiteError; } @@ -381,6 +380,8 @@ BenchmarkParams BenchmarkTfLiteModel::DefaultParams() { BenchmarkParam::Create(0)); default_params.AddParam("disable_delegate_clustering", BenchmarkParam::Create(false)); + default_params.AddParam("enable_builtin_cast_constant_cache", + BenchmarkParam::Create(false)); default_params.AddParam("output_filepath", BenchmarkParam::Create("")); @@ -469,6 +470,10 @@ std::vector BenchmarkTfLiteModel::GetFlags() { "Optimize memory usage for large tensors with sacrificing latency."), CreateFlag("disable_delegate_clustering", ¶ms_, "Disable delegate clustering."), + CreateFlag( + "enable_builtin_cast_constant_cache", ¶ms_, + "Cache the output of the builtin cast operation when its input " + "is a constant tensor."), CreateFlag( "output_filepath", ¶ms_, "File path to export outputs layer as binary data."), @@ -528,6 +533,8 @@ void BenchmarkTfLiteModel::LogParams() { "Optimize memory usage for large tensors", verbose); LOG_BENCHMARK_PARAM(bool, "disable_delegate_clustering", "Disable delegate clustering", verbose); + LOG_BENCHMARK_PARAM(bool, "enable_builtin_cast_constant_cache", + "Constant CAST output cache", verbose); LOG_BENCHMARK_PARAM(std::string, "output_filepath", "File path to export outputs layer to", verbose); LOG_BENCHMARK_PARAM(int32_t, "tensor_name_display_length", @@ -726,6 +733,8 @@ TfLiteStatus BenchmarkTfLiteModel::InitInterpreter() { params_.Get("optimize_memory_for_large_tensors")); options.SetDisableDelegateClustering( params_.Get("disable_delegate_clustering")); + options.SetCacheConstantCastOp( + params_.Get("enable_builtin_cast_constant_cache")); tflite::InterpreterBuilder builder(*model_, *resolver, &options); if (builder.SetNumThreads(num_threads) != kTfLiteOk) { diff --git a/tensorflow/lite/tools/cmake/modules/xnnpack.cmake b/tensorflow/lite/tools/cmake/modules/xnnpack.cmake index dbc832812f4f17..485751f3373d71 100644 --- a/tensorflow/lite/tools/cmake/modules/xnnpack.cmake +++ b/tensorflow/lite/tools/cmake/modules/xnnpack.cmake @@ -23,7 +23,7 @@ OverridableFetchContent_Declare( xnnpack GIT_REPOSITORY https://github.com/google/XNNPACK # Sync with tensorflow/workspace2.bzl - GIT_TAG dcbfffb80fb4f6fcfcfb5b3723854ec8797fa546 + GIT_TAG 9325fcfe52092b2f8f816db218bca208db7b2750 GIT_PROGRESS TRUE PREFIX "${CMAKE_BINARY_DIR}" SOURCE_DIR "${CMAKE_BINARY_DIR}/xnnpack" diff --git a/tensorflow/lite/tools/cmake/native_tools/flatbuffers/CMakeLists.txt b/tensorflow/lite/tools/cmake/native_tools/flatbuffers/CMakeLists.txt index c13f36fb140b2e..583a7c8c6bacf7 100644 --- a/tensorflow/lite/tools/cmake/native_tools/flatbuffers/CMakeLists.txt +++ b/tensorflow/lite/tools/cmake/native_tools/flatbuffers/CMakeLists.txt @@ -40,4 +40,4 @@ else() set(FLATC_INSTALL_PREFIX ${CMAKE_INSTALL_PREFIX} CACHE PATH "Flatc installation directory") endif() -find_package(flatbuffers) +find_package(FlatBuffers) diff --git a/tensorflow/lite/tools/delegates/experimental/stable_delegate/BUILD b/tensorflow/lite/tools/delegates/experimental/stable_delegate/BUILD index 9601726f043849..d4eeefe01432eb 100644 --- a/tensorflow/lite/tools/delegates/experimental/stable_delegate/BUILD +++ b/tensorflow/lite/tools/delegates/experimental/stable_delegate/BUILD @@ -13,6 +13,8 @@ cc_library( deps = [ "//tensorflow/lite/c:common", "//tensorflow/lite/tools:command_line_flags", + "//tensorflow/lite/tools:logging", + "//tensorflow/lite/tools:tool_params", "//tensorflow/lite/tools/delegates:delegate_provider_hdr", ] + select({ # Stable delegate does not support Windows because the shared library loader hasn't been diff --git a/tensorflow/lite/tools/delegates/experimental/stable_delegate/stable_delegate_provider.cc b/tensorflow/lite/tools/delegates/experimental/stable_delegate/stable_delegate_provider.cc index d8715648758543..ae8ab083237cd9 100644 --- a/tensorflow/lite/tools/delegates/experimental/stable_delegate/stable_delegate_provider.cc +++ b/tensorflow/lite/tools/delegates/experimental/stable_delegate/stable_delegate_provider.cc @@ -19,9 +19,10 @@ limitations under the License. #include #include -#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/tools/command_line_flags.h" #include "tensorflow/lite/tools/delegates/delegate_provider.h" +#include "tensorflow/lite/tools/logging.h" +#include "tensorflow/lite/tools/tool_params.h" #if !defined(_WIN32) #include "tensorflow/lite/acceleration/configuration/c/delegate_plugin.h" diff --git a/tensorflow/lite/tools/versioning/op_signature.cc b/tensorflow/lite/tools/versioning/op_signature.cc index 5b8523622effd5..da367699004f8a 100644 --- a/tensorflow/lite/tools/versioning/op_signature.cc +++ b/tensorflow/lite/tools/versioning/op_signature.cc @@ -163,6 +163,15 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op, subgraph->tensors()->Get(op->inputs()->Get(1)); op_sig.ext_options.fully_connected.sparse_weight = (weight_tensor->sparsity() != nullptr); + const QuantizationParameters* weight_quant = + weight_tensor->quantization(); + if (weight_quant && weight_quant->scale() && + weight_quant->scale()->size() && weight_tensor->shape() && + weight_tensor->shape()->size()) { + op_sig.ext_options.fully_connected.is_per_channel_quantized = + weight_quant->scale()->size() > 1 && + weight_quant->scale()->size() == weight_tensor->shape()->Get(0); + } } break; case BuiltinOperator_MUL: { diff --git a/tensorflow/lite/tools/versioning/op_signature.h b/tensorflow/lite/tools/versioning/op_signature.h index 6e09767273357b..aece1638eca1d6 100644 --- a/tensorflow/lite/tools/versioning/op_signature.h +++ b/tensorflow/lite/tools/versioning/op_signature.h @@ -52,6 +52,7 @@ typedef struct { // TODO(b/156530611): Make this global when more ops support sparse // computation. bool sparse_weight; + bool is_per_channel_quantized; } fully_connected; struct { float input1_scale; diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc index bba285328d2527..b5d8bb151e7145 100644 --- a/tensorflow/lite/tools/versioning/op_version.cc +++ b/tensorflow/lite/tools/versioning/op_version.cc @@ -173,6 +173,13 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { reinterpret_cast(op_sig.builtin_data); TFLITE_DCHECK(fully_connected_params != nullptr); + if (op_sig.inputs.at(0).type == kTfLiteFloat32 && + op_sig.inputs.at(1).type == kTfLiteInt8 && + op_sig.outputs.at(0).type == kTfLiteFloat32 && + op_sig.ext_options.fully_connected.is_per_channel_quantized) { + return 12; + } + if (op_sig.inputs.at(0).type == kTfLiteInt16 && op_sig.inputs.at(1).type == kTfLiteInt8 && op_sig.outputs.at(0).type == kTfLiteInt16) { diff --git a/tensorflow/lite/tools/versioning/op_version_test.cc b/tensorflow/lite/tools/versioning/op_version_test.cc index 5cff633f0ee0d0..acddf0d0de9043 100644 --- a/tensorflow/lite/tools/versioning/op_version_test.cc +++ b/tensorflow/lite/tools/versioning/op_version_test.cc @@ -709,6 +709,16 @@ TEST(OpVersionTest, VersioningFullyConnectedTest) { }; fully_connected_params.quantized_bias_type = kTfLiteInt32; EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 11); + + fake_op_sig = { + .op = BuiltinOperator_FULLY_CONNECTED, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteFloat32, kTfLiteInt8}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), + .builtin_data = reinterpret_cast(&fully_connected_params), + }; + fake_op_sig.ext_options.fully_connected.is_per_channel_quantized = true; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 12); } TEST(OpVersionTest, VersioningDequantizeTest) { diff --git a/tensorflow/lite/tools/versioning/runtime_version.cc b/tensorflow/lite/tools/versioning/runtime_version.cc index d011a5d5438e46..170927f81d7a55 100644 --- a/tensorflow/lite/tools/versioning/runtime_version.cc +++ b/tensorflow/lite/tools/versioning/runtime_version.cc @@ -132,6 +132,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, {{BuiltinOperator_FULLY_CONNECTED, 9}, "2.3.0"}, {{BuiltinOperator_FULLY_CONNECTED, 10}, "2.11.0"}, {{BuiltinOperator_FULLY_CONNECTED, 11}, "2.15.0"}, + {{BuiltinOperator_FULLY_CONNECTED, 12}, "2.17.0"}, {{BuiltinOperator_GATHER, 1}, "1.6.0"}, {{BuiltinOperator_GATHER, 2}, "1.14.0"}, {{BuiltinOperator_GATHER, 3}, "1.15.0"}, diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 69030356583d98..675709d4314eb4 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -546,6 +546,7 @@ tf_python_pybind_extension( "_pywrap_toco_api.pyi", ], deps = [ + "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", "//tensorflow/python/lib/core:pybind11_lib", "//third_party/python_runtime:headers", "@pybind11", diff --git a/tensorflow/python/_pywrap_toco_api.pyi b/tensorflow/python/_pywrap_toco_api.pyi index 619c686c336c32..213c6f14872f7d 100644 --- a/tensorflow/python/_pywrap_toco_api.pyi +++ b/tensorflow/python/_pywrap_toco_api.pyi @@ -13,9 +13,9 @@ # limitations under the License. # ============================================================================== -def ExperimentalMlirQuantizeModel(input_contents_txt_raw: object, disable_per_channel: bool = ..., fully_quantize: bool = ..., inference_type: int = ..., input_data_type: int = ..., output_data_type: int = ..., enable_numeric_verify: bool = ..., enable_whole_model_verify: bool = ..., op_blocklist: object = ..., node_blocklist: object = ..., enable_variable_quantization: bool = ...) -> object: ... +def ExperimentalMlirQuantizeModel(input_contents_txt_raw: object, disable_per_channel: bool = ..., fully_quantize: bool = ..., inference_type: int = ..., input_data_type: int = ..., output_data_type: int = ..., enable_numeric_verify: bool = ..., enable_whole_model_verify: bool = ..., op_blocklist: object = ..., node_blocklist: object = ..., enable_variable_quantization: bool = ..., disable_per_channel_for_dense_layers: bool = ...) -> object: ... def ExperimentalMlirSparsifyModel(input_contents_txt_raw: object) -> object: ... def FlatBufferToMlir(arg0: str, arg1: bool) -> str: ... def RegisterCustomOpdefs(custom_opdefs_txt_raw: object) -> object: ... def RetrieveCollectedErrors() -> list: ... -def TocoConvert(model_flags_proto_txt_raw: object, toco_flags_proto_txt_raw: object, input_contents_txt_raw: object, extended_return: bool = ..., debug_info_txt_raw: object = ..., enable_mlir_converter: bool = ...) -> object: ... +def TocoConvert(model_flags_proto_txt_raw: object, toco_flags_proto_txt_raw: object, input_contents_txt_raw: object, extended_return: bool = ..., debug_info_txt_raw: object = ..., enable_mlir_converter: bool = ..., quantization_py_function_library = ...) -> object: ... diff --git a/tensorflow/python/checkpoint/functional_saver.py b/tensorflow/python/checkpoint/functional_saver.py index c1a359480c3adf..77a8e98382cb96 100644 --- a/tensorflow/python/checkpoint/functional_saver.py +++ b/tensorflow/python/checkpoint/functional_saver.py @@ -57,7 +57,7 @@ def _single_shard_save( file_prefix: tensor_lib.Tensor, - shard: sharding_util.TensorSliceDict, + shard: sharding_util.Shard, task: device_lib.DeviceSpec, options: "checkpoint_options.CheckpointOptions | None" = None, ) -> ops.Operation: @@ -106,7 +106,7 @@ def _single_shard_restore( file_prefix: tensor_lib.Tensor, shardable_tensors: Sequence[sharding_util.ShardableTensor], options: "checkpoint_options.CheckpointOptions | None" = None -) -> sharding_util.TensorSliceDict: +) -> sharding_util.Shard: """Restore the saveable objects from a checkpoint with `file_prefix`. Args: @@ -221,6 +221,9 @@ def restore_fn_with_replaced_captures( _restore_noop = lambda *args, **kwargs: None +TensorKeyAndSliceSpec = tuple[str, str] +RestoreFn = Callable[[Mapping[str, tensor_lib.Tensor]], ops.Operation] + class MultiDeviceSaver: """Saves checkpoints directly from multiple devices. @@ -233,7 +236,7 @@ class MultiDeviceSaver: def __init__( self, serialized_tensors: Mapping[ - base.Trackable, sharding_util.TensorSliceDict], + base.Trackable, sharding_util.Shard], registered_savers: "RegisteredSaversDict | None" = None, call_with_mapped_captures: "MappedCapturesCallable | None" = None): """Specify a list of `SaveableObject`s to save and restore. @@ -255,11 +258,9 @@ def __init__( # Keep these two data structures so that we can map restored tensors to # the Trackable restore functions. self._keys_to_restore_fn: MutableMapping[ - sharding_util.TensorSlice, - Callable[Mapping[str, tensor_lib.Tensor]]] = {} + TensorKeyAndSliceSpec, RestoreFn] = {} self._restore_fn_to_keys: MutableMapping[ - Callable[Mapping[str, tensor_lib.Tensor]], - MutableSequence[sharding_util.TensorSlice]] = {} + RestoreFn, MutableSequence[TensorKeyAndSliceSpec]] = {} unique_tasks = set() for obj, tensor_dict in serialized_tensors.items(): @@ -377,7 +378,7 @@ def _traced_restore( def _get_shards_by_task( self, sharding_callback: sharding_util.ShardingCallback - ) -> Sequence[sharding_util.TensorSliceDict]: + ) -> Sequence[sharding_util.Shard]: """Calls the sharding callback with shardable_tensors. Args: diff --git a/tensorflow/python/checkpoint/sharding/sharding_policies.py b/tensorflow/python/checkpoint/sharding/sharding_policies.py index 889f2fa09ce31a..269ed85f7f7c12 100644 --- a/tensorflow/python/checkpoint/sharding/sharding_policies.py +++ b/tensorflow/python/checkpoint/sharding/sharding_policies.py @@ -15,7 +15,7 @@ """Checkpoint policies that determine how tensors are split into shards.""" import math -from typing import MutableSequence, Sequence +from typing import Sequence from absl import logging @@ -43,7 +43,7 @@ def description(self) -> str: def __call__( self, shardable_tensors: Sequence[sharding_util.ShardableTensor] - ) -> Sequence[sharding_util.TensorSliceDict]: + ) -> Sequence[sharding_util.Shard]: """Callback to split tensors into shards based on their device spec task. Args: @@ -66,7 +66,6 @@ def __call__( return [tensors_by_task] -_PartitionAxisAndSize = tuple[int, int] _OffsetAndShape = tuple[Sequence[int], Sequence[int]] @@ -79,6 +78,227 @@ class MaxShardSizePolicy(sharding_util.ShardingCallback): checkpoint object graph, whose size cannot be calculated when saving. """ + class MaxShardSizePartitioner(): + """Partition tensors into shards with a max shard size.""" + + def _get_next_partition(self) -> tuple[int, float]: + """Gets tensor partition with size closest to shard_size_remaining. + + Returns: + A tuple containing the axis and size of the next partition. + """ + rank = self._working_tensor_shape.rank + if rank is None or rank == 0: + return 0, math.inf + + num_elems = self._working_tensor_shape.num_elements() + + def num_partitions(axis: int) -> float: + axis_len = self._working_tensor_shape.dims[axis].value + slice_elems = num_elems // axis_len + bytes_per_slice = slice_elems * self._dtype_size + slices_per_shard = self._shard_size_remaining // bytes_per_slice + if slices_per_shard == 0: + return math.inf + return math.ceil(axis_len / slices_per_shard) + + # Find axis with minimum partitions. (axis with maximum partition size) + # (max partition size is as close as possible to the shard_size_remaining) + min_parts = num_partitions(0) + min_axis = 0 + for axis in range(1, rank): + parts_along_axis = num_partitions(axis) + part_size = num_elems * self._dtype_size / parts_along_axis + if (parts_along_axis < min_parts and + part_size <= self._shard_size_remaining): + min_axis, min_parts = axis, int(parts_along_axis) + return (min_axis, + math.ceil(int(self._working_tensor_shape[min_axis]) / min_parts)) + + def _add_partition( + self, part_axis: int, part_size: float + ) -> tuple[tensor_lib.Tensor, _OffsetAndShape]: + """Adds the tensor partition to the shard, if possible. + + Args: + part_axis: The axis of the partition. + part_size: The size of the partition. + + Returns: + A tuple containing the size of the slice that was added to the shard and + the offset & shape of the remaining portion of the tensor. + """ + if self._root_shape.rank is None or self._root_shape.rank == 0: + return None, (None, None) + + # Add what we can to the current shard. + slice_offset = self._working_tensor_offset + slice_shape = [self._root_shape[i] - slice_offset[i] + for i in range(self._root_shape.rank)] + slice_shape[part_axis] = part_size + slice_size_in_bytes = int(math.prod(slice_shape)) * self._dtype_size + with ops.device(self._device): + tensor_slice = array_ops.slice( + self._root_tensor, begin=slice_offset, size=slice_shape) + slice_spec = variables.Variable.SaveSliceInfo( + full_name=self._checkpoint_key, + full_shape=self._root_shape, + var_offset=slice_offset, + var_shape=slice_shape).spec.strip() + remaining_size = self._shard_size_remaining + if slice_size_in_bytes > self.max_shard_size: + logging.warning("Slice %s of tensor %s is a scalar of size %s bytes " + "and cannot be partitioned into a shard of max shard " + "size %s bytes. It will be added as an individual " + "shard that exceeds the max shard size.", slice_spec, + self._checkpoint_key, slice_size_in_bytes, + self.max_shard_size) + self._large_scalars.append( + {self._checkpoint_key: {slice_spec: tensor_slice}}) + elif slice_size_in_bytes > self._shard_size_remaining: + # Smallest partition can't fit in the remaining shard space. Start fresh + # with a new shard. + return None, (None, None) + else: + if not self._tensors_by_shard or self._shard_size_remaining < 1: + self._tensors_by_shard.append({}) + remaining_size = self.max_shard_size + (self._tensors_by_shard[-1] + .setdefault(self._checkpoint_key, {})[slice_spec]) = tensor_slice + remaining_size -= slice_size_in_bytes + + # Get remaining portion of tensor to add to the next shard(s). + slice_offset[part_axis] += part_size + slice_shape = [self._root_shape[i] - slice_offset[i] + for i in range(self._root_shape.rank)] + + return (remaining_size, (slice_offset, slice_shape)) + + def get_shards( + self, + max_shard_size: int, + shardable_tensors: Sequence[sharding_util.ShardableTensor] + ) -> Sequence[sharding_util.Shard]: + """Callback to split tensors into shards with a max shard size. + + Args: + max_shard_size: The maximum size of a shard file in bytes. + shardable_tensors: A list of ShardableTensors. + + Returns: + List of shard dicts containing tensors. + [ {checkpoint key: {slice_spec: tensor} } ] + """ + self.max_shard_size = max_shard_size + self._tensors_by_shard = [] + self._large_scalars = [] + + self._shard_size_remaining = self.max_shard_size + for shardable_tensor in shardable_tensors: + self._root_tensor = shardable_tensor.tensor + self._root_shape = shardable_tensor.shape + self._dtype = shardable_tensor.dtype + self._device = shardable_tensor.device + self._checkpoint_key = shardable_tensor.checkpoint_key + self._slice_spec = shardable_tensor.slice_spec + + self._dtype_size = dtypes.as_dtype(self._dtype).size + total_size = self._root_shape.num_elements() * self._dtype_size # bytes + + # Calculate string tensor sizes. + if self._checkpoint_key == base.OBJECT_GRAPH_PROTO_KEY: + # In graph mode, the object graph is populated using feed_additions + # when the session is run. So, we can't calculate the size here. + # Fortunately, the serialized object graph string will never be that + # big, so we just place it in the current shard without worrying about + # its size. + total_size = self._dtype_size = 0 + elif self._dtype == dtypes.string: + if not context.executing_eagerly(): + with ops.device(self._device): + self._root_tensor = ops.get_default_session().run( + self._root_tensor) + + if self._root_shape.rank is None or self._root_shape.rank == 0: + sizes = [string_ops.string_length(self._root_tensor, unit="BYTE")] + else: + sizes = [string_ops.string_length(elem, unit="BYTE") + for elem in self._root_tensor] + + if context.executing_eagerly(): + sizes = [size.numpy() for size in sizes] + else: + with ops.device(self._device): + sizes = ops.get_default_session().run(sizes) + + total_size = sum(sizes) + self._dtype_size = max(sizes) + + if (total_size > self.max_shard_size and + (self._root_shape.rank is None or self._root_shape.rank == 0)): + logging.warning("Tensor %s is a scalar of size %s bytes and cannot " + "be partitioned into a shard of max shard size %s " + "bytes. It will be added as an individual shard that " + "exceeds the max shard size.", + self._checkpoint_key, total_size, self.max_shard_size) + self._large_scalars.append( + {self._checkpoint_key: {self._slice_spec: self._root_tensor}}) + continue + + # Partition tensor and add partitions to shards. + self._working_tensor_offset = [0] * self._root_shape.rank + self._working_tensor_shape = self._root_shape + working_tensor_size = total_size + while working_tensor_size > self._shard_size_remaining: + (part_axis, part_size) = self._get_next_partition() + + if part_size == 0: + # Tensor partition couldn't fit in remaining shard space. Try again + # with the next full shard. + self._tensors_by_shard.append({}) + self._shard_size_remaining = self.max_shard_size + continue + + (remaining_size, + (remaining_offset, remaining_shape)) = self._add_partition( + part_axis=part_axis, part_size=part_size) + + if remaining_size is None: + # Tensor partition couldn't fit in remaining shard space. Try again + # with the next full shard. + self._tensors_by_shard.append({}) + self._shard_size_remaining = self.max_shard_size + else: + self._working_tensor_offset = remaining_offset + self._working_tensor_shape = tensor_shape.TensorShape( + remaining_shape) + working_tensor_size = ( + int(math.prod(remaining_shape)) * self._dtype_size) + self._shard_size_remaining = remaining_size + + if self._working_tensor_shape.num_elements() > 0: + if self._working_tensor_offset and self._working_tensor_shape: + with ops.device(self._device): + working_tensor = array_ops.slice( + self._root_tensor, + begin=self._working_tensor_offset, + size=self._working_tensor_shape.as_list()) + else: + working_tensor = self._root_tensor + remaining_tensor_slice_spec = variables.Variable.SaveSliceInfo( + full_name=self._checkpoint_key, + full_shape=self._root_shape, + var_offset=self._working_tensor_offset, + var_shape=self._working_tensor_shape).spec.strip() + if not self._tensors_by_shard: + self._tensors_by_shard.append({}) + (self._tensors_by_shard[-1] + .setdefault(self._checkpoint_key, {}) + [remaining_tensor_slice_spec]) = working_tensor + self._shard_size_remaining -= working_tensor_size + + return self._tensors_by_shard + self._large_scalars + def __init__(self, max_shard_size: int): self.max_shard_size = max_shard_size @@ -86,237 +306,8 @@ def __init__(self, max_shard_size: int): def description(self) -> str: return "Split tensors into shards with a max shard size." - def _get_next_partition( - self, - shard_size_remaining: int, - shape: tensor_shape.TensorShape, - dtype_size: int, - num_elems: int - ) -> _PartitionAxisAndSize: - """Gets tensor partition with size closest to shard_size_remaining. - - Args: - shard_size_remaining: Size in bytes of the space remaining in the shard. - shape: Shape of the working tensor to partition in the remaining - shard space. - dtype_size: Size in bytes of the dtype of the working tensor. - num_elems: Number of elements in the working tensor. - - Returns: - A tuple containing the axis of the next partition and that partition size. - """ - if shape.rank is None or shape.rank == 0: - return 0, math.inf - - # Find axis with minimum partitions. (aka axis with maximum partition size) - # (max partition size is as close as possible to the shard_size_remaining) - bytes_per_slice = num_elems // shape.dims[0].value * dtype_size - slices_per_shard = max( - 1, math.floor(shard_size_remaining / bytes_per_slice)) - min_parts = math.ceil(shape.dims[0].value / slices_per_shard) - min_axis = 0 - for axis in range(1, shape.rank): - bytes_per_slice = num_elems // shape.dims[axis].value * dtype_size - slices_per_shard = max( - 1, math.floor(shard_size_remaining / bytes_per_slice)) - axis_parts = math.ceil(shape.dims[axis].value / slices_per_shard) - partition_size = num_elems * dtype_size / axis_parts - if (axis_parts < min_parts and - partition_size < shard_size_remaining): - min_axis, min_parts = axis, int(axis_parts) - return min_axis, math.ceil(int(shape[min_axis]) / min_parts) - - def _add_partition( - self, - root_shardable_tensor: sharding_util.ShardableTensor, - dtype_size: int, - working_tensor_offset: Sequence[int], - part_axis_and_size: _PartitionAxisAndSize, - shard_size_remaining: int, - max_shard_size: int, - tensors_by_shard: MutableSequence[sharding_util.TensorSliceDict], - large_scalars: MutableSequence[sharding_util.TensorSliceDict], - ) -> tuple[tensor_lib.Tensor, _OffsetAndShape]: - """Adds the tensor partition to the shard, if possible. - - Args: - root_shardable_tensor: The full tensor being partitioned. - dtype_size: Size in bytes of the dtype of the working tensor. - working_tensor_offset: The offset of the working tensor in the full - tensor. - part_axis_and_size: A tuple containing the axis of the partition and that - partition size. - shard_size_remaining: Size in bytes of the space remaining in the shard. - max_shard_size: Max size in bytes allowed for a checkpoint shard. - tensors_by_shard: List of shard dicts containing tensors. - [ {checkpoint key: {slice_spec: tensor} } ] - large_scalars: List of shard dicts containing scalars too large to fit in - the max_shard_size. [ {checkpoint key: {slice_spec: tensor} } ] - - Returns: - A tuple containing the size of the slice that was added to the shard and - the offset & shape of the remaining portion of the tensor. - """ - root_tensor = root_shardable_tensor.tensor - root_tensor_shape = root_shardable_tensor.shape - checkpoint_key = root_shardable_tensor.checkpoint_key - - if root_tensor_shape.rank is None or root_tensor_shape.rank == 0: - return None, (None, None) - - min_axis, part_size = part_axis_and_size - - # Add what we can to the current shard. - slice_offset = working_tensor_offset - slice_shape = [root_tensor_shape[i] - slice_offset[i] - for i in range(root_tensor_shape.rank)] - slice_shape[min_axis] = part_size - slice_size_in_bytes = int(math.prod(slice_shape)) * dtype_size - with ops.device(root_shardable_tensor.device): - tensor_slice = array_ops.slice( - root_tensor, begin=slice_offset, size=slice_shape) - slice_spec = variables.Variable.SaveSliceInfo( - full_name=checkpoint_key, - full_shape=root_tensor_shape, - var_offset=slice_offset, - var_shape=slice_shape).spec.strip() - remaining_size = shard_size_remaining - if slice_size_in_bytes > max_shard_size: - logging.warning("Slice %s of tensor %s is a scalar of size %s bytes and " - "cannot be partitioned into a shard of max shard size %s " - "bytes. It will be added as an individual shard that " - "exceeds the max shard size.", slice_spec, checkpoint_key, - slice_size_in_bytes, max_shard_size) - large_scalars.append({checkpoint_key: {slice_spec: tensor_slice}}) - elif slice_size_in_bytes > shard_size_remaining: - # Smallest partition can't fit in the remaining shard space. Start fresh - # with a new shard. - return None, (None, None) - else: - if not tensors_by_shard or shard_size_remaining < 1: - tensors_by_shard.append({}) - remaining_size = max_shard_size - (tensors_by_shard[-1] - .setdefault(checkpoint_key, {})[slice_spec]) = tensor_slice - remaining_size -= slice_size_in_bytes - - # Get remaining portion of tensor to add to the next shard(s). - slice_offset[min_axis] += part_size - slice_shape = [root_tensor_shape[i] - slice_offset[i] - for i in range(root_tensor_shape.rank)] - - return (remaining_size, (slice_offset, slice_shape)) - def __call__( self, shardable_tensors: Sequence[sharding_util.ShardableTensor] - ) -> Sequence[sharding_util.TensorSliceDict]: - """Callback to split tensors into shards with a max shard size. - - Args: - shardable_tensors: A list of ShardableTensors. - - Returns: - List of shard dicts containing tensors. - [ {checkpoint key: {slice_spec: tensor} } ] - """ - tensors_by_shard = [] - large_scalars = [] - - shard_size_remaining = self.max_shard_size - for shardable_tensor in shardable_tensors: - root_tensor = shardable_tensor.tensor - root_shape = shardable_tensor.shape - dtype = shardable_tensor.dtype - checkpoint_key = shardable_tensor.checkpoint_key - - dtype_size = dtypes.as_dtype(dtype).size - total_size = root_shape.num_elements() * dtype_size # in bytes - - # Calculate string tensor sizes. - if checkpoint_key == base.OBJECT_GRAPH_PROTO_KEY: - # In graph mode, the object graph is populated using feed_additions when - # the session is run. So, we can't calculate the size here. Fortunately, - # the serialized object graph string will never be that big, so we just - # place it in the current shard without worrying about its size. - total_size = dtype_size = 0 - elif dtype == dtypes.string: - if not context.executing_eagerly(): - with ops.device(shardable_tensor.device): - root_tensor = ops.get_default_session().run(root_tensor) - - if root_shape.rank is None or root_shape.rank == 0: - sizes = [string_ops.string_length(root_tensor, unit="BYTE")] - else: - sizes = [string_ops.string_length(elem, unit="BYTE") - for elem in root_tensor] - - if context.executing_eagerly(): - sizes = [size.numpy() for size in sizes] - else: - with ops.device(shardable_tensor.device): - sizes = ops.get_default_session().run(sizes) - - total_size = sum(sizes) - dtype_size = max(sizes) - - if (total_size > self.max_shard_size and - (root_shape.rank is None or root_shape.rank == 0)): - logging.warning("Tensor %s is a scalar of size %s bytes and cannot be " - "partitioned into a shard of max shard size %s bytes. " - "It will be added as an individual shard that exceeds " - "the max shard size.", - checkpoint_key, total_size, self.max_shard_size) - large_scalars.append( - {checkpoint_key: {shardable_tensor.slice_spec: root_tensor}}) - continue - - # Partition tensor and add partitions to shards. - working_tensor = root_tensor - working_tensor_var_offset = [0] * root_shape.rank - working_tensor_shape = root_shape - working_tensor_size = total_size - while working_tensor_size > shard_size_remaining: - part_axis_and_size = self._get_next_partition( - shard_size_remaining=shard_size_remaining, - shape=working_tensor_shape, - dtype_size=dtype_size, - num_elems=working_tensor_shape.num_elements()) - - (remaining_size, - (remaining_offset, remaining_shape)) = self._add_partition( - root_shardable_tensor=shardable_tensor, - dtype_size=dtype_size, - working_tensor_offset=working_tensor_var_offset, - part_axis_and_size=part_axis_and_size, - shard_size_remaining=shard_size_remaining, - max_shard_size=self.max_shard_size, - tensors_by_shard=tensors_by_shard, - large_scalars=large_scalars) - - if remaining_size is None: - # Tensor partition couldn't fit in remaining shard space. Try again - # with the next full shard. - tensors_by_shard.append({}) - shard_size_remaining = self.max_shard_size - else: - working_tensor = array_ops.slice( - root_tensor, begin=remaining_offset, size=remaining_shape) - working_tensor_var_offset = remaining_offset - working_tensor_shape = working_tensor.shape - working_tensor_size = int(math.prod(remaining_shape)) * dtype_size - shard_size_remaining = remaining_size - - if working_tensor_shape.num_elements() > 0: - remaining_tensor_slice_spec = variables.Variable.SaveSliceInfo( - full_name=checkpoint_key, - full_shape=root_shape, - var_offset=working_tensor_var_offset, - var_shape=working_tensor_shape).spec.strip() - if not tensors_by_shard: - tensors_by_shard.append({}) - (tensors_by_shard[-1] - .setdefault(checkpoint_key, {}) - [remaining_tensor_slice_spec]) = working_tensor - shard_size_remaining -= working_tensor_size - - return tensors_by_shard + large_scalars + ) -> Sequence[sharding_util.Shard]: + return self.MaxShardSizePartitioner().get_shards( + self.max_shard_size, shardable_tensors) diff --git a/tensorflow/python/checkpoint/sharding/sharding_policies_test.py b/tensorflow/python/checkpoint/sharding/sharding_policies_test.py index 133a0b923d6338..1c96f621b54a24 100644 --- a/tensorflow/python/checkpoint/sharding/sharding_policies_test.py +++ b/tensorflow/python/checkpoint/sharding/sharding_policies_test.py @@ -485,15 +485,8 @@ def __init__(self, var_offset, var_shape): self.evaluate(shards[5][v1_name][slice_spec]), [[[10.0], [11.0]]]) # max_shard_size: 12 bytes - # 12 bytes is enough to fit 3 elements per variable in each shard, BUT that - # would require concurrent multidimensional tensor partitioning, which is - # not currently implemented for MaxShardSizePolicy. (When partitioning a - # tensor into a shard, we choose an axis to partition along. This can - # happen multiple times for a given tensor (in the case that the tensor - # spans multiple shards). In that case, multiple dimensions can be - # partitioned along (each time the tensor is partitioned, a new axis can be - # chosen), but not within a single iteration of adding a tensor partition to - # the shard.) So, v0/v1 should be split into 3 shards each. + # 12 bytes is enough to fit 3 elements per variable in each shard. + # v0/v1 should be split into 2 shards each. callback = sharding_policies.MaxShardSizePolicy(max_shard_size=12) shards = [] for tensors in shardable_tensors: @@ -504,37 +497,29 @@ def __init__(self, var_offset, var_shape): [ {"v0/.ATTRIBUTES/VARIABLE_VALUE",}, {"v0/.ATTRIBUTES/VARIABLE_VALUE",}, - {"v0/.ATTRIBUTES/VARIABLE_VALUE",}, - {"v1/.ATTRIBUTES/VARIABLE_VALUE",}, {"v1/.ATTRIBUTES/VARIABLE_VALUE",}, {"v1/.ATTRIBUTES/VARIABLE_VALUE", "_CHECKPOINTABLE_OBJECT_GRAPH",} ]) # V0 - slice_spec = V0SaveSliceInfo(var_offset=[0, 0], var_shape=[1, 2]).spec - self.assertAllEqual( - self.evaluate(shards[0][v0_name][slice_spec]), [[0, 1]]) - - slice_spec = V0SaveSliceInfo(var_offset=[1, 0], var_shape=[1, 2]).spec + slice_spec = V0SaveSliceInfo(var_offset=[0, 0], var_shape=[3, 1]).spec self.assertAllEqual( - self.evaluate(shards[1][v0_name][slice_spec]), [[2, 3]]) + self.evaluate(shards[0][v0_name][slice_spec]), [[0], [2], [4]]) - slice_spec = V0SaveSliceInfo(var_offset=[2, 0], var_shape=[1, 2]).spec + slice_spec = V0SaveSliceInfo(var_offset=[0, 1], var_shape=[3, 1]).spec self.assertAllEqual( - self.evaluate(shards[2][v0_name][slice_spec]), [[4, 5]]) + self.evaluate(shards[1][v0_name][slice_spec]), [[1], [3], [5]]) # V1 - slice_spec = V1SaveSliceInfo(var_offset=[0, 0, 0], var_shape=[1, 2, 1]).spec - self.assertAllEqual( - self.evaluate(shards[3][v1_name][slice_spec]), [[[6.0], [7.0]]]) - - slice_spec = V1SaveSliceInfo(var_offset=[1, 0, 0], var_shape=[1, 2, 1]).spec + slice_spec = V1SaveSliceInfo(var_offset=[0, 0, 0], var_shape=[3, 1, 1]).spec self.assertAllEqual( - self.evaluate(shards[4][v1_name][slice_spec]), [[[8.0], [9.0]]]) + self.evaluate(shards[2][v1_name][slice_spec]), + [[[6.0]], [[8.0]], [[10.0]]]) - slice_spec = V1SaveSliceInfo(var_offset=[2, 0, 0], var_shape=[1, 2, 1]).spec + slice_spec = V1SaveSliceInfo(var_offset=[0, 1, 0], var_shape=[3, 1, 1]).spec self.assertAllEqual( - self.evaluate(shards[5][v1_name][slice_spec]), [[[10.0], [11.0]]]) + self.evaluate(shards[3][v1_name][slice_spec]), + [[[7.0]], [[9.0]], [[11.0]]]) # max_shard_size: 16 bytes # Each variable should be split into 1.5 shards. The middle shard will @@ -688,6 +673,7 @@ def test_CheckpointOption_MaxShardSizePolicy(self): tmp_dir, options=checkpoint_options.CheckpointOptions( experimental_sharding_callback=( sharding_policies.MaxShardSizePolicy(max_shard_size=10)))) + # 8 files = 3 shards for v0, 3 for v1, 1 for v2, and 1 for the object graph self.assertLen(gfile.Glob(save_path + ".data*"), 8) ckpt.restore(save_path) diff --git a/tensorflow/python/checkpoint/sharding/sharding_util.py b/tensorflow/python/checkpoint/sharding/sharding_util.py index 322bba18dcfa84..d97a4955433452 100644 --- a/tensorflow/python/checkpoint/sharding/sharding_util.py +++ b/tensorflow/python/checkpoint/sharding/sharding_util.py @@ -30,8 +30,11 @@ from tensorflow.python.util import tf_export -TensorSlice = MutableMapping[tensor_spec.TensorSpec, tensor_lib.Tensor] -TensorSliceDict = MutableMapping[str, TensorSlice] +TensorSlices = MutableMapping[tensor_spec.TensorSpec, tensor_lib.Tensor] +# A mapping from a checkpoint key (full tensor name) to the corresponding tensor +# slices of the full tensor. It represents the collection of tensors stored in a +# checkpoint shard data file. +Shard = MutableMapping[str, TensorSlices] @tf_export.tf_export("train.experimental.ShardableTensor") @@ -144,7 +147,7 @@ def description(self) -> str: @abc.abstractmethod def __call__( self, shardable_tensors: Sequence[ShardableTensor] - ) -> Sequence[TensorSliceDict]: + ) -> Sequence[Shard]: pass def __hash__(self) -> int: @@ -159,7 +162,7 @@ def __hash__(self) -> int: def validate_shards( - shards: Sequence[TensorSliceDict], + shards: Sequence[Shard], shardable_tensors: Sequence[ShardableTensor], callback_description: str ) -> None: @@ -221,12 +224,33 @@ def validate_shards( f" original tensor_dtype: {target_dtype}\n" f" new tensor_dtype: {shard_tensor.dtype}\n") - # Validate same task in shard. + # Validate no task change. + target_task = device_lib.DeviceSpec.from_string( + unseen_tensor_dict[checkpoint_key][slice_spec].device).task + shard_tensor_task = device_lib.DeviceSpec.from_string( + shard_tensor.device).task + if shard_tensor_task != target_task: + raise RuntimeError( + "After executing the checkpoint sharding callback, a tensor " + "was found with an altered task:\n" + f" callback_description: {callback_description}\n" + f" checkpoint_key: {checkpoint_key}\n" + f" slice_spec: {slice_spec}\n" + f" original tensor_task: {target_task}\n" + f" new tensor_task: {shard_tensor_task}\n") + + # Validate tensors in shard have the same task. if task_tensor is None: - task_tensor = ShardableTensor - task_tensor.device = shard_tensor.device - task_tensor.checkpoint_key = checkpoint_key - task_tensor.slice_spec = slice_spec + task_tensor = ShardableTensor( + _tensor_save_spec=None, + tensor=None, + dtype=None, + device=shard_tensor.device, + name=None, + shape=None, + slice_spec=slice_spec, + checkpoint_key=checkpoint_key, + trackable=None) else: task1 = device_lib.DeviceSpec.from_string(task_tensor.device).task task2 = device_lib.DeviceSpec.from_string(shard_tensor.device).task diff --git a/tensorflow/python/checkpoint/sharding/sharding_util_test.py b/tensorflow/python/checkpoint/sharding/sharding_util_test.py index 1c5acbea791b78..be1170520fb830 100644 --- a/tensorflow/python/checkpoint/sharding/sharding_util_test.py +++ b/tensorflow/python/checkpoint/sharding/sharding_util_test.py @@ -84,7 +84,7 @@ def description(self): def __call__( self, shardable_tensors: Sequence[sharding_util.ShardableTensor] - ) -> Sequence[sharding_util.TensorSliceDict]: + ) -> Sequence[sharding_util.Shard]: pass self.assertEqual(hash(BlankCallback()), hash(BlankCallback())) @@ -99,7 +99,7 @@ def description(self): def __call__( self, shardable_tensors: Sequence[sharding_util.ShardableTensor] - ) -> Sequence[sharding_util.TensorSliceDict]: + ) -> Sequence[sharding_util.Shard]: pass self.assertEqual(hash(ValueCallback(1)), hash(ValueCallback(1))) @@ -165,7 +165,7 @@ def description(self): def __call__( self, shardable_tensors: Sequence[sharding_util.ShardableTensor] - ) -> Sequence[sharding_util.TensorSliceDict]: + ) -> Sequence[sharding_util.Shard]: tensor = shardable_tensors[0].tensor checkpoint_key = shardable_tensors[0].checkpoint_key slice_spec = shardable_tensors[0].slice_spec @@ -204,7 +204,7 @@ def description(self): def __call__( self, shardable_tensors: Sequence[sharding_util.ShardableTensor] - ) -> Sequence[sharding_util.TensorSliceDict]: + ) -> Sequence[sharding_util.Shard]: checkpoint_key = "ADDED_TENSOR_ABC123" slice_spec = "" tensor = tensor_lib.Tensor() @@ -238,7 +238,7 @@ def description(self): def __call__( self, shardable_tensors: Sequence[sharding_util.ShardableTensor] - ) -> Sequence[sharding_util.TensorSliceDict]: + ) -> Sequence[sharding_util.Shard]: shards = [] for shardable_tensor in shardable_tensors: tensor = shardable_tensor.tensor @@ -277,7 +277,7 @@ def description(self): def __call__( self, shardable_tensors: Sequence[sharding_util.ShardableTensor] - ) -> Sequence[sharding_util.TensorSliceDict]: + ) -> Sequence[sharding_util.Shard]: shards = [] for shardable_tensor in shardable_tensors: tensor = shardable_tensor.tensor @@ -303,8 +303,56 @@ def __call__( sharding_util.validate_shards( shards, shardable_tensors_flat, sharding_callback.description) + def test_validate_shards_task_change(self): + servers = [server_lib.Server.create_local_server() for _ in range(2)] + cluster_spec = server_lib.ClusterSpec({ + "worker": [s.target[len("grpc://"):] for s in servers]}) + remote.connect_to_cluster(cluster_spec) + + root = module.Module() + with ops.device("/job:worker/task:0/cpu:0"): + v0 = resource_variable_ops.ResourceVariable(0.0, name="v0") + with ops.device("/job:worker/task:1/cpu:0"): + v1 = resource_variable_ops.ResourceVariable(0.0, name="v1") + root.v0 = v0 + root.v1 = v1 + + class TaskChangeCallback(sharding_util.ShardingCallback): + @property + def description(self): + return "task change callback" + + def __call__( + self, shardable_tensors: Sequence[sharding_util.ShardableTensor] + ) -> Sequence[sharding_util.Shard]: + shards = [] + for shardable_tensor in shardable_tensors: + tensor = shardable_tensor.tensor + checkpoint_key = shardable_tensor.checkpoint_key + slice_spec = shardable_tensor.slice_spec + if checkpoint_key == "v0/.ATTRIBUTES/VARIABLE_VALUE": + with ops.device("/job:worker/task:1/cpu:0"): + tensor = array_ops.identity(tensor) + shards.append({checkpoint_key: {slice_spec: tensor}}) + return shards + + shardable_tensors = self._get_shardable_tensors_by_task(root) + shardable_tensors_flat = [] + for tensors in shardable_tensors: + shardable_tensors_flat.extend(tensors) + + sharding_callback = TaskChangeCallback() + shards = [] + for tensors in shardable_tensors: + shards.extend(sharding_callback(tensors)) + + with self.assertRaisesRegex(RuntimeError, + "a tensor was found with an altered task"): + sharding_util.validate_shards( + shards, shardable_tensors_flat, sharding_callback.description) + def test_validate_shards_different_tasks(self): - servers = [server_lib.Server.create_local_server() for _ in range(3)] + servers = [server_lib.Server.create_local_server() for _ in range(2)] cluster_spec = server_lib.ClusterSpec({ "worker": [s.target[len("grpc://"):] for s in servers]}) remote.connect_to_cluster(cluster_spec) @@ -324,7 +372,7 @@ def description(self): def __call__( self, shardable_tensors: Sequence[sharding_util.ShardableTensor] - ) -> Sequence[sharding_util.TensorSliceDict]: + ) -> Sequence[sharding_util.Shard]: shard = {} for shardable_tensor in shardable_tensors: tensor = shardable_tensor.tensor @@ -359,7 +407,7 @@ def description(self): def __call__( self, shardable_tensors: Sequence[sharding_util.ShardableTensor] - ) -> Sequence[sharding_util.TensorSliceDict]: + ) -> Sequence[sharding_util.Shard]: return [] shardable_tensors = self._get_shardable_tensors_by_task(root) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 3965b8212fe851..56b5d1b3213c95 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 1, 30) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 2, 22) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD index 47f44a8e31576b..721383b1e51ab4 100644 --- a/tensorflow/python/data/experimental/kernel_tests/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/BUILD @@ -191,6 +191,23 @@ tf_py_strict_test( ], ) +tf_py_strict_test( + name = "global_shuffle_test", + srcs = ["global_shuffle_test.py"], + deps = [ + "//tensorflow/python/data/experimental/ops:global_shuffle_op", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/framework:combinations", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:random_seed", + "//tensorflow/python/platform:client_testlib", + "@absl_py//absl/testing:parameterized", + ], +) + tf_py_strict_test( name = "group_by_reducer_test", size = "small", diff --git a/tensorflow/python/data/experimental/kernel_tests/global_shuffle_test.py b/tensorflow/python/data/experimental/kernel_tests/global_shuffle_test.py new file mode 100644 index 00000000000000..c3772209495e3f --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/global_shuffle_test.py @@ -0,0 +1,149 @@ +# Copyright 2024 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for global shuffling of tf.data datasets.""" + +from typing import Optional + +from absl.testing import parameterized + +from tensorflow.python.data.experimental.ops import global_shuffle_op +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import random_seed +from tensorflow.python.platform import test + + +class GlobalShuffleTest(test_base.DatasetTestBase, parameterized.TestCase): + """Tests for global shuffling of tf.data datasets.""" + + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine(seed=[None, 42], use_tensor_seed=[True, False]))) + def testRange(self, seed: Optional[int], use_tensor_seed: bool): + dataset_range = 100 + dataset = dataset_ops.Dataset.range(dataset_range) + seed = (constant_op.constant(seed, dtype=dtypes.int64) + if seed and use_tensor_seed else seed) + dataset = global_shuffle_op._global_shuffle(dataset, seed=seed) + dataset = dataset.repeat(3) + output = self.getDatasetOutput(dataset, requires_initialization=True) + self.assertCountEqual(output, list(range(dataset_range)) * 3) + + output_per_iteration = [ + output[i : i + dataset_range] + for i in range(0, len(output), dataset_range)] + self.assertCountEqual(output_per_iteration[0], list(range(dataset_range))) + self.assertCountEqual(output_per_iteration[1], list(range(dataset_range))) + self.assertCountEqual(output_per_iteration[2], list(range(dataset_range))) + self.assertNotEqual(output_per_iteration[0], output_per_iteration[1]) + self.assertNotEqual(output_per_iteration[0], output_per_iteration[2]) + self.assertNotEqual(output_per_iteration[1], output_per_iteration[2]) + + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine(seed=[None, 42]))) + def testNegativeRange(self, seed: Optional[int]): + dataset_range = 10 + dataset = dataset_ops.Dataset.range(dataset_range, -dataset_range, -1) + dataset = global_shuffle_op._global_shuffle(dataset) + dataset = dataset.repeat(3) + output = self.getDatasetOutput(dataset, requires_initialization=True) + self.assertCountEqual( + output, list(range(dataset_range, -dataset_range, -1)) * 3) + + output_per_iteration = [ + output[i : i + dataset_range * 2] + for i in range(0, len(output), dataset_range * 2)] + self.assertCountEqual(output_per_iteration[0], + list(range(dataset_range, -dataset_range, -1))) + self.assertCountEqual(output_per_iteration[1], + list(range(dataset_range, -dataset_range, -1))) + self.assertCountEqual(output_per_iteration[2], + list(range(dataset_range, -dataset_range, -1))) + self.assertNotEqual(output_per_iteration[0], output_per_iteration[1]) + self.assertNotEqual(output_per_iteration[0], output_per_iteration[2]) + self.assertNotEqual(output_per_iteration[1], output_per_iteration[2]) + + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine(reshuffle=[True, False], seed=[None, 42]))) + def testReshuffleRepeatEpochs(self, reshuffle: bool, seed: Optional[int]): + dataset_range = 100 + dataset = dataset_ops.Dataset.range(dataset_range) + dataset = global_shuffle_op._global_shuffle( + dataset, seed=seed, reshuffle_each_iteration=reshuffle) + dataset = dataset.repeat(2) + + output = self.getDatasetOutput(dataset, requires_initialization=True) + self.assertCountEqual(output, list(range(dataset_range)) * 2) + output_per_iteration = [ + output[i : i + dataset_range] + for i in range(0, len(output), dataset_range)] + if reshuffle: + self.assertNotEqual(output_per_iteration[0], output_per_iteration[1]) + else: + self.assertEqual(output_per_iteration[0], output_per_iteration[1]) + + @combinations.generate( + combinations.times( + combinations.combine(tf_api_version=2, mode="eager"), + combinations.combine(reshuffle=[True, False], seed=[None, 42]))) + def testReshuffleIterationEpochs(self, reshuffle: bool, seed: Optional[int]): + # TensorFlow unit tests set the global graph seed. We unset it here so that + # we can control determinism via the `seed` parameter. + random_seed.set_random_seed(None) + dataset_range = 100 + dataset = dataset_ops.Dataset.range(dataset_range) + dataset = global_shuffle_op._global_shuffle( + dataset, seed=seed, reshuffle_each_iteration=reshuffle) + + first_epoch = self.getDatasetOutput(dataset) + second_epoch = self.getDatasetOutput(dataset) + if reshuffle: + self.assertNotEqual(first_epoch, second_epoch) + else: + self.assertEqual(first_epoch, second_epoch) + + @combinations.generate(test_base.default_test_combinations()) + def testEmptyDataset(self): + dataset = dataset_ops.Dataset.range(0) + with self.assertRaisesRegex( + errors.InvalidArgumentError, + "`global_shuffle` requires the input dataset to have a non-empty " + "finite cardinality."): + dataset = global_shuffle_op._global_shuffle(dataset) + self.getDatasetOutput(dataset, requires_initialization=True) + + @combinations.generate(test_base.default_test_combinations()) + def testUnsupportedDataset(self): + dataset = dataset_ops.Dataset.range(100) + dataset = dataset.shuffle(buffer_size=1) + with self.assertRaisesRegex( + errors.FailedPreconditionError, + "`global_shuffle` requires all upstream transformations be compatible " + "with random access."): + dataset = global_shuffle_op._global_shuffle(dataset) + self.getDatasetOutput(dataset, requires_initialization=True) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD index 075e1eb1794969..34c1098ccae8fc 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD @@ -205,6 +205,7 @@ tf_py_strict_test( "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:options", "//tensorflow/python/framework:combinations", + "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", @@ -238,3 +239,31 @@ tf_py_strict_test( "@absl_py//absl/testing:parameterized", ], ) + +tf_py_strict_test( + name = "seq_interleave_prefetch_test", + size = "medium", + srcs = ["seq_interleave_prefetch_test.py"], + deps = [ + "//tensorflow/python/data/experimental/ops:batching", + "//tensorflow/python/data/experimental/ops:grouping", + "//tensorflow/python/data/experimental/ops:scan_ops", + "//tensorflow/python/data/experimental/ops:testing", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:options", + "//tensorflow/python/framework:combinations", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:random_ops", + "//tensorflow/python/ops:variable_scope", + "//tensorflow/python/ops:variables", + "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/platform:tf_logging", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/seq_interleave_prefetch_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/seq_interleave_prefetch_test.py new file mode 100644 index 00000000000000..a5dfa78f5a0fc0 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/seq_interleave_prefetch_test.py @@ -0,0 +1,100 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the `SeqInterleavePrefetch` optimization.""" +from absl.testing import parameterized + +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import options as options_lib +from tensorflow.python.framework import combinations +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.platform import test + + +class SeqInterleavePrefetchTest( + test_base.DatasetTestBase, parameterized.TestCase +): + + @combinations.generate( + combinations.times( + test_base.eager_only_combinations(), + combinations.combine(cycle_length=[2, 4]), + combinations.combine(block_length=[2, 4]), + combinations.combine(other_arguments=[True, False]), + ) + ) + def testOptimizationSeqInterleavePrefetch( + self, + cycle_length, + block_length, + other_arguments, + ): + num_input_elements = 16 + var1 = constant_op.constant(9, dtype=dtypes.int64) + var2 = constant_op.constant(11, dtype=dtypes.int64) + + # dataset1: Deterministic parallel interleave dataset. + dataset1 = dataset_ops.Dataset.range(num_input_elements) + options1 = options_lib.Options() + options1.experimental_optimization.apply_default_optimizations = False + options1.experimental_optimization.seq_interleave_prefetch = False + dataset1 = dataset1.with_options(options1) + if other_arguments: + dataset1 = dataset1.interleave( + (lambda _: dataset_ops.Dataset.range(var1 + var2 + 1)), + cycle_length=cycle_length, + block_length=block_length, + num_parallel_calls=dataset_ops.AUTOTUNE, + deterministic=True, + ) + else: + dataset1 = dataset1.interleave( + (lambda _: dataset_ops.Dataset.range(num_input_elements)), + cycle_length=cycle_length, + block_length=block_length, + num_parallel_calls=dataset_ops.AUTOTUNE, + deterministic=True, + ) + + # dataset2: Deterministic parallel interleave dataset with + # `seq_interleave_prefetch` optimization enabled. + dataset2 = dataset_ops.Dataset.range(num_input_elements) + options2 = options_lib.Options() + options2.experimental_optimization.apply_default_optimizations = False + options2.experimental_optimization.seq_interleave_prefetch = True + dataset2 = dataset2.with_options(options2) + if other_arguments: + dataset2 = dataset2.interleave( + (lambda _: dataset_ops.Dataset.range(var1 + var2 + 1)), + cycle_length=cycle_length, + block_length=block_length, + num_parallel_calls=dataset_ops.AUTOTUNE, + deterministic=True, + ) + else: + dataset2 = dataset2.interleave( + (lambda _: dataset_ops.Dataset.range(num_input_elements)), + cycle_length=cycle_length, + block_length=block_length, + num_parallel_calls=dataset_ops.AUTOTUNE, + deterministic=True, + ) + + self.assertDatasetsEqual(dataset1, dataset2) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/service/distributed_save_load_ft_test.py b/tensorflow/python/data/experimental/kernel_tests/service/distributed_save_load_ft_test.py index 2e93dff8ed4c5b..e752b073206d9a 100644 --- a/tensorflow/python/data/experimental/kernel_tests/service/distributed_save_load_ft_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/service/distributed_save_load_ft_test.py @@ -39,6 +39,7 @@ class DistributedSaveLoadFtTest( combinations.combine( num_elements=[200], num_workers=[1, 2], + save_repetitions=[1, 2], load_repetitions=[1, 2], sharding_policy=[ data_service_ops.ShardingPolicy.OFF, @@ -47,17 +48,21 @@ def test_dispatcher_restart( self, num_workers: int, num_elements: int, + save_repetitions: int, load_repetitions: int, sharding_policy: data_service_ops.ShardingPolicy): cluster = data_service_test_base.TestCluster(num_workers=num_workers) snapshot_dir = data_service_test_base.TempDir() dataset = dataset_ops.Dataset.range(num_elements) + if save_repetitions > 1: + dataset = dataset.repeat(save_repetitions) self.evaluate( distributed_save_op.distributed_save( dataset, snapshot_dir.full_path, cluster.dispatcher_address())) dataset = load_op._load_with_retry(snapshot_dir.full_path) - dataset = dataset.repeat(load_repetitions) + if load_repetitions > 1: + dataset = dataset.repeat(load_repetitions) dataset = dataset.apply( data_service_ops.distribute( sharding_policy, @@ -71,8 +76,9 @@ def test_dispatcher_restart( # For no sharding, dispatcher restarts do not affect data processing # happening at the workers. + repetitions = save_repetitions * load_repetitions if sharding_policy == data_service_ops.ShardingPolicy.OFF: - expected = list(range(num_elements)) * load_repetitions * num_workers + expected = list(range(num_elements)) * repetitions * num_workers self.assertCountEqual(output, expected) # Dynamic sharding may lose splits if the dispatcher fails. @@ -85,24 +91,33 @@ def test_dispatcher_restart( test_base.eager_only_combinations(), combinations.combine( num_elements=[200], + num_workers=[1, 2], + save_repetitions=[1, 2], load_repetitions=[1, 2], sharding_policy=[ - data_service_ops.ShardingPolicy.OFF, - data_service_ops.ShardingPolicy.DYNAMIC]))) + # TODO(b/297930782): Enable dynamic sharding. Need to fix the + # race condition where workers restart before sending the + # final task completion update. + data_service_ops.ShardingPolicy.OFF]))) def test_dispatcher_and_worker_restart( self, num_elements: int, + num_workers: int, + save_repetitions: int, load_repetitions: int, sharding_policy: data_service_ops.ShardingPolicy): - cluster = data_service_test_base.TestCluster(num_workers=1) + cluster = data_service_test_base.TestCluster(num_workers=num_workers) snapshot_dir = data_service_test_base.TempDir() dataset = dataset_ops.Dataset.range(num_elements) + if save_repetitions > 1: + dataset = dataset.repeat(save_repetitions) self.evaluate( distributed_save_op.distributed_save( dataset, snapshot_dir.full_path, cluster.dispatcher_address())) dataset = load_op._load_with_retry(snapshot_dir.full_path) - dataset = dataset.repeat(load_repetitions) + if load_repetitions > 1: + dataset = dataset.repeat(load_repetitions) dataset = dataset.apply( data_service_ops.distribute( sharding_policy, @@ -111,23 +126,19 @@ def test_dispatcher_and_worker_restart( iterator = self.getNext(dataset) output = [self.evaluate(iterator())] - cluster.restart_dispatcher() - cluster.workers[0].restart() + for i in range(num_workers): + cluster.restart_dispatcher() + cluster.workers[i].restart() output.extend(self.getIteratorOutput(iterator)) # If the sharding policy is OFF, the restarted worker will produce elements # from the beginning of the dataset. The result is a partial range plus # `num_elements` repetitions. if sharding_policy == data_service_ops.ShardingPolicy.OFF: - self.assertContainsSubset( - list(range(num_elements)) * load_repetitions, output) - - # For dynamic sharding, the first split (and possibly prefetched splits) may - # be lost. The result is a partial range plus zero or more `num_elements` - # ranges. - if sharding_policy == data_service_ops.ShardingPolicy.DYNAMIC: - num_ranges = len(output) // num_elements - self.assertContainsSubset(list(range(num_elements)) * num_ranges, output) + repetitions = save_repetitions * load_repetitions + self.assertContainsSubsequence( + sorted(output), + sorted(list(range(num_elements)) * repetitions * num_workers)) @combinations.generate( combinations.times( diff --git a/tensorflow/python/data/experimental/ops/BUILD b/tensorflow/python/data/experimental/ops/BUILD index 75e297261a51bb..7731cfb8252687 100644 --- a/tensorflow/python/data/experimental/ops/BUILD +++ b/tensorflow/python/data/experimental/ops/BUILD @@ -171,6 +171,21 @@ py_strict_library( ], ) +py_strict_library( + name = "global_shuffle_op", + srcs = [ + "global_shuffle_op.py", + ], + srcs_version = "PY3", + deps = [ + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:random_seed", + "//tensorflow/python/framework:tensor", + "//tensorflow/python/ops:dataset_ops_gen", + "//tensorflow/python/ops:experimental_dataset_ops_gen", + ], +) + py_strict_library( name = "grouping", srcs = ["grouping.py"], @@ -518,6 +533,7 @@ py_strict_library( ":error_ops", ":from_list", ":get_single_element", + ":global_shuffle_op", ":grouping", ":interleave_ops", ":io", diff --git a/tensorflow/python/data/experimental/ops/global_shuffle_op.py b/tensorflow/python/data/experimental/ops/global_shuffle_op.py new file mode 100644 index 00000000000000..dc3899338b2f6c --- /dev/null +++ b/tensorflow/python/data/experimental/ops/global_shuffle_op.py @@ -0,0 +1,90 @@ +# Copyright 2024 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Globally shuffles tf.data datasets.""" + +from typing import Optional, Union + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import random_seed +from tensorflow.python.framework import tensor +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops + + +def _global_shuffle( # pylint: disable=unused-private-name + input_dataset: dataset_ops.DatasetV2, + seed: Optional[Union[int, tensor.Tensor]] = None, + reshuffle_each_iteration: bool = True, + name: Optional[str] = None) -> dataset_ops.DatasetV2: + """Globally shuffles the elements of `input_dataset`. + + The shuffling is done efficiently, without needing to buffer any additional + data. To achieve this, the transformations preceding global_shuffle must all + support random access. + + Requires that: + - The shuffled dataset and all its input datasets support random access. + - The input_dataset to have a known, finite cardinality. Users can use + `tf.data.experimental.assert_cardinality` to specify the cardinality of a + dataset if it cannot be determined at runtime. + + TODO(b/325112575): Move the API to dataset_ops.py. + TODO(b/325112575): Support reshuffle_each_iteration. + TODO(b/325112575): Support checkpoints. + + Args: + input_dataset: The dataset to be shuffled. + seed: An int or `tf.int64` scalar `tf.Tensor` to control the shuffle order. + If `None`, a random seed will be used. + reshuffle_each_iteration: A boolean, which if True, indicates that a + different shuffle order should be generated for each iteration of the + dataset. (Defaults to `True`.) + name: (Optional.) A name for the tf.data operation. + + Returns: + A new `Dataset` where elements are produced in a globally shuffled order. + + Raises: + InvalidArgumentError if the input dataset does not support random access, or + it has infinite or unknown cardinality. + """ + return _GlobalShuffleDataset( + input_dataset, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration, + name=name) + + +class _GlobalShuffleDataset(dataset_ops.UnaryUnchangedStructureDataset): + """Shuffles all elements in the input dataset.""" + + def __init__( + self, + input_dataset: dataset_ops.DatasetV2, + seed: Optional[Union[int, tensor.Tensor]] = None, + reshuffle_each_iteration: bool = True, + name: Optional[str] = None): + self._input_dataset = input_dataset + self._seed, self._seed2 = random_seed.get_seed(seed) + self._reshuffle_each_iteration = reshuffle_each_iteration + self._name = name + variant_tensor = ged_ops.global_shuffle_dataset( + self._input_dataset._variant_tensor, # pylint: disable=protected-access + seed=self._seed, + seed2=self._seed2, + seed_generator=gen_dataset_ops.dummy_seed_generator(), + reshuffle_each_iteration=self._reshuffle_each_iteration, + **self._common_args) + super().__init__(input_dataset, variant_tensor) diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index 28a7d986a6434c..b87ddfe04dbf67 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -1144,6 +1144,7 @@ tf_py_strict_test( "//tensorflow/python/data/ops:options", "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:combinations", + "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", diff --git a/tensorflow/python/data/kernel_tests/filter_test.py b/tensorflow/python/data/kernel_tests/filter_test.py index e81ab7058bb2d3..e5ca81018ce9e4 100644 --- a/tensorflow/python/data/kernel_tests/filter_test.py +++ b/tensorflow/python/data/kernel_tests/filter_test.py @@ -162,6 +162,17 @@ def testName(self): lambda x: True, name="filter") self.assertDatasetProduces(dataset, [42]) + @combinations.generate(test_base.default_test_combinations()) + def testPredicateFailWithErrorContext(self): + dataset = dataset_ops.Dataset.from_tensors(42).filter( + lambda x: (x // 0) > 0, name="filter") + get_next = self.getNext(dataset) + with self.assertRaisesRegex( + errors.InvalidArgumentError, + r".*Error in user-defined function passed to .* transformation with " + r"iterator: Iterator::Root::.*"): + self.evaluate(get_next()) + class FilterCheckpointTest(checkpoint_test_base.CheckpointTestBase, parameterized.TestCase): diff --git a/tensorflow/python/data/kernel_tests/flat_map_test.py b/tensorflow/python/data/kernel_tests/flat_map_test.py index 4a3becfd753faa..49be5a7124613e 100644 --- a/tensorflow/python/data/kernel_tests/flat_map_test.py +++ b/tensorflow/python/data/kernel_tests/flat_map_test.py @@ -189,6 +189,20 @@ def fn(x): dataset = dataset_ops.Dataset.from_tensors(42).flat_map(fn, name="flat_map") self.assertDatasetProduces(dataset, [42]) + @combinations.generate(test_base.default_test_combinations()) + def testMapFuncFailWithErrorContext(self): + + def fn(x): + return dataset_ops.Dataset.from_tensors(x // 0) + + dataset = dataset_ops.Dataset.from_tensors(42).flat_map(fn, name="flat_map") + get_next = self.getNext(dataset) + with self.assertRaisesRegex( + errors.InvalidArgumentError, + r".*Error in user-defined function passed to .* transformation with " + r"iterator: Iterator::Root::.*"): + self.evaluate(get_next()) + @combinations.generate(test_base.v2_eager_only_combinations()) def testSymbolicCheckpointSize(self): examples_per_flat_map = 100 diff --git a/tensorflow/python/data/kernel_tests/interleave_test.py b/tensorflow/python/data/kernel_tests/interleave_test.py index 442109dcd2b339..d4c23c19edeb47 100644 --- a/tensorflow/python/data/kernel_tests/interleave_test.py +++ b/tensorflow/python/data/kernel_tests/interleave_test.py @@ -386,6 +386,23 @@ def map_fn(x): dataset_ops.Dataset.from_tensors(42).interleave( map_fn, num_parallel_calls=num_parallel_calls) + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(num_parallel_calls=[None, 1]))) + def testMapFuncFailWithErrorContext(self, num_parallel_calls): + + def fn(x): + return dataset_ops.Dataset.from_tensors(x // 0) + + dataset = dataset_ops.Dataset.from_tensors(42).interleave( + fn, num_parallel_calls=num_parallel_calls, name="interleave") + get_next = self.getNext(dataset) + with self.assertRaisesRegex( + errors.InvalidArgumentError, + r".*Error in user-defined function passed to .* transformation with " + r"iterator: Iterator::Root::.*"): + self.evaluate(get_next()) + @combinations.generate(test_base.v2_eager_only_combinations()) def testSymbolicCheckpointSize(self): if sys.platform == "darwin": diff --git a/tensorflow/python/data/kernel_tests/iterator_test.py b/tensorflow/python/data/kernel_tests/iterator_test.py index 8d1e384e033683..b9cd6d2ae2d32f 100644 --- a/tensorflow/python/data/kernel_tests/iterator_test.py +++ b/tensorflow/python/data/kernel_tests/iterator_test.py @@ -1072,6 +1072,27 @@ def fn(): self.evaluate(counter_var.initializer) self.assertEqual(self.evaluate(fn()), 10) + @combinations.generate(test_base.eager_only_combinations()) + def testSaveRestore(self): + ds = dataset_ops.Dataset.range(10) + ds = ds.shuffle(5, seed=42, reshuffle_each_iteration=False) + it = ds.as_numpy_iterator() + + expected = list(ds.as_numpy_iterator()) + + for i in range(3): + self.assertEqual(next(it), expected[i]) + + state = it.save() + + for i in range(3, 6): + self.assertEqual(next(it), expected[i]) + + it.restore(state) + + for i in range(3, 6): + self.assertEqual(next(it), expected[i]) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/kernel_tests/map_test.py b/tensorflow/python/data/kernel_tests/map_test.py index 70f56db29714d8..3ef584ba2dbed6 100644 --- a/tensorflow/python/data/kernel_tests/map_test.py +++ b/tensorflow/python/data/kernel_tests/map_test.py @@ -1540,6 +1540,22 @@ def testName(self, num_parallel_calls): lambda x: x * 2, num_parallel_calls=num_parallel_calls, name="map") self.assertDatasetProduces(dataset, [42]) + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(num_parallel_calls=[None, 1]))) + def testStatusMessage(self, num_parallel_calls): + dataset = dataset_ops.Dataset.from_tensors(21).map( + lambda x: x // 0, num_parallel_calls=num_parallel_calls, name="map") + options = options_lib.Options() + options.experimental_optimization.apply_default_optimizations = False + dataset = dataset.with_options(options) + get_next = self.getNext(dataset) + with self.assertRaisesRegex( + errors.InvalidArgumentError, + r".*Error in user-defined function passed to .* transformation with " + r"iterator: Iterator::Root::.*"): + self.evaluate(get_next()) + class MapCheckpointTest(checkpoint_test_base.CheckpointTestBase, parameterized.TestCase): diff --git a/tensorflow/python/data/kernel_tests/options_test.py b/tensorflow/python/data/kernel_tests/options_test.py index ecae99543fbd84..ebf9909c6c4670 100644 --- a/tensorflow/python/data/kernel_tests/options_test.py +++ b/tensorflow/python/data/kernel_tests/options_test.py @@ -79,6 +79,19 @@ def testOptionsTwiceSameOption(self): ds = ds.with_options(options2) self.assertTrue(self._get_options(ds).autotune.enabled) + @combinations.generate(test_base.default_test_combinations()) + def testOptionsTwiceSameOptionWithMap(self): + options1 = options_lib.Options() + options1.framework_type = ["seqio"] + options2 = options_lib.Options() + options2.framework_type = ["tfgrain"] + ds = dataset_ops.Dataset.range(5) + ds = ds.with_options(options1) + ds = ds.map(lambda x: x + 1) + ds = ds.with_options(options2) + self.assertDatasetProduces(ds, [1, 2, 3, 4, 5]) + self.assertLen(self._get_options(ds).framework_type, 2) + @combinations.generate(test_base.default_test_combinations()) def testOptionsMergeOptionsFromMultipleInputs(self): options1 = options_lib.Options() @@ -149,14 +162,17 @@ def testOptionsProtoRoundTrip(self): options.experimental_optimization.noop_elimination = True options.experimental_optimization.parallel_batch = True options.experimental_optimization.shuffle_and_repeat_fusion = True + options.experimental_optimization.seq_interleave_prefetch = True options.experimental_warm_start = True options.experimental_slack = True options.dataset_name = "test_name" + options.framework_type = ["TFDS", "TfGrain"] options.threading.max_intra_op_parallelism = 30 options.threading.private_threadpool_size = 40 pb = options._to_proto() result = options_lib.Options() result._from_proto(pb) + self.assertEqual(options.framework_type, result.framework_type) self.assertEqual(options, result) @combinations.generate(test_base.default_test_combinations()) diff --git a/tensorflow/python/data/kernel_tests/shuffle_test.py b/tensorflow/python/data/kernel_tests/shuffle_test.py index 39b203d71b2433..99266134f0f79d 100644 --- a/tensorflow/python/data/kernel_tests/shuffle_test.py +++ b/tensorflow/python/data/kernel_tests/shuffle_test.py @@ -30,6 +30,7 @@ from tensorflow.python.data.ops import options as options_lib from tensorflow.python.eager import def_function from tensorflow.python.framework import combinations +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -107,7 +108,8 @@ def dataset_fn(count=5, buffer_size=None, seed=0): # Assert that shuffling twice with a different seed gives a different # permutation of the same elements. - get_next = self.getNext(dataset_fn(buffer_size=100, seed=137)) + get_next = self.getNext(dataset_fn( + buffer_size=100, seed=constant_op.constant(137, dtype=dtypes.int64))) reshuffled_elements_different_seed = [] for _ in range(20): reshuffled_elements_different_seed.append(self.evaluate(get_next())) @@ -176,6 +178,33 @@ def testDefaultArguments(self): for i in range(5): self.assertEqual(10, counts[i]) + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine( + dataset_range=[100], + buffer_size=[None, 10, 200], + seed=[None, 42], + use_tensor_input=[True, False]))) + def testTensorInput(self, dataset_range, buffer_size, seed, use_tensor_input): + dataset = dataset_ops.Dataset.range(dataset_range) + unshuffled_output = self.getDatasetOutput(dataset) + + if buffer_size: + buffer_size = ( + constant_op.constant(buffer_size, dtype=dtypes.int64) + if use_tensor_input else buffer_size) + else: + buffer_size = dataset.cardinality() + seed = (constant_op.constant(seed, dtype=dtypes.int64) + if seed and use_tensor_input else seed) + + shuffled_dataset = dataset.shuffle(buffer_size, seed=seed) + shuffled_output = self.getDatasetOutput(shuffled_dataset) + self.assertEqual(unshuffled_output, list(range(dataset_range))) + self.assertCountEqual(shuffled_output, unshuffled_output) + self.assertNotEqual(shuffled_output, unshuffled_output) + @combinations.generate(test_base.default_test_combinations()) def testUnknownCardinality(self): components = [0, 1, 2, 3, 4] diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD index 6955f4d01de4ad..58a4d1d3c9748e 100644 --- a/tensorflow/python/data/ops/BUILD +++ b/tensorflow/python/data/ops/BUILD @@ -180,6 +180,7 @@ py_strict_library( "//tensorflow/python/data/util:structure", "//tensorflow/python/eager:context", "//tensorflow/python/framework:composite_tensor", + "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", @@ -189,9 +190,12 @@ py_strict_library( "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:type_spec", "//tensorflow/python/framework:type_utils", + "//tensorflow/python/ops:array_ops_stack", "//tensorflow/python/ops:cond", "//tensorflow/python/ops:dataset_ops_gen", "//tensorflow/python/ops:parsing_ops", + "//tensorflow/python/ops:string_ops", + "//tensorflow/python/ops/ragged:ragged_string_ops", "//tensorflow/python/saved_model:nested_structure_coder", "//tensorflow/python/trackable:base", "//tensorflow/python/training:saver", diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 21ca38c228a8e2..574a827c1c0788 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -1479,12 +1479,12 @@ def shuffle( ``` Args: - buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the number of - elements from this dataset from which the new dataset will sample. To - uniformly shuffle the entire dataset, use + buffer_size: An int or `tf.int64` scalar `tf.Tensor`, representing the + number of elements from this dataset from which the new dataset will + sample. To uniformly shuffle the entire dataset, use `buffer_size=dataset.cardinality()`. - seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random - seed that will be used to create the distribution. See + seed: (Optional.) An int or `tf.int64` scalar `tf.Tensor`, representing + the random seed that will be used to create the distribution. See `tf.random.set_seed` for behavior. reshuffle_each_iteration: (Optional.) A boolean, which if true indicates that the dataset should be pseudorandomly reshuffled each time it is diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py index 6db3abca84c880..1b8db599df025e 100644 --- a/tensorflow/python/data/ops/iterator_ops.py +++ b/tensorflow/python/data/ops/iterator_ops.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== """Python wrappers for Iterators.""" + import abc import threading import warnings @@ -34,8 +35,11 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import type_spec from tensorflow.python.framework import type_utils +from tensorflow.python.ops import array_ops_stack from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import string_ops +from tensorflow.python.ops.ragged import ragged_string_ops from tensorflow.python.saved_model import nested_structure_coder from tensorflow.python.trackable import base as trackable from tensorflow.python.training.saver import BaseSaverBuilder @@ -792,10 +796,23 @@ def _save(self): state_variant = gen_dataset_ops.serialize_iterator( self._iterator_resource, external_state_policy ) - return parsing_ops.serialize_tensor(state_variant) + # Serialize each slice of the state_variant separately, to avoid hitting the + # 2GB proto serialization limit. + state = array_ops_stack.unstack(state_variant) + state = [parsing_ops.serialize_tensor(x) for x in state] + state = array_ops_stack.stack(state) + + state = string_ops.encode_base64(state) + state = string_ops.string_join(state, separator=",") + return state def _restore(self, state): - state_variant = parsing_ops.parse_tensor(state, dtypes.variant) + state = ragged_string_ops.string_split_v2(state, sep=",") + state = string_ops.decode_base64(state) + + state = array_ops_stack.unstack(state) + state = [parsing_ops.parse_tensor(x, dtypes.variant) for x in state] + state_variant = array_ops_stack.stack(state) return gen_dataset_ops.deserialize_iterator( self._iterator_resource, state_variant ) diff --git a/tensorflow/python/data/ops/map_op.py b/tensorflow/python/data/ops/map_op.py index 1751ac8d219488..2447b545ab47ac 100644 --- a/tensorflow/python/data/ops/map_op.py +++ b/tensorflow/python/data/ops/map_op.py @@ -14,6 +14,7 @@ # ============================================================================== """The implementation of `tf.data.Dataset.map`.""" +import inspect import warnings from tensorflow.python.data.ops import dataset_ops @@ -24,6 +25,27 @@ from tensorflow.python.ops import gen_dataset_ops +def _generate_default_name(): + """Generates a transformation name based on the current call stack. + + The name is useful for debugging, e.g. identifying transformations in xprofs. + + Returns: + The generated name. + """ + # Use the closest non-tf-data stack frame. + for frame in inspect.stack(): + if "tensorflow/python/data" in frame.filename: + continue + name = ( + "file_" + frame.filename.split("/")[-1] + "_line_" + str(frame.lineno) + ).replace(".", "_") + if name.isidentifier(): + return name + + return None + + def _map_v2(input_dataset, # pylint: disable=unused-private-name map_func, num_parallel_calls=None, @@ -109,7 +131,7 @@ def __init__(self, self._transformation_name(), dataset=input_dataset, use_legacy_function=use_legacy_function) - self._name = name + self._name = name or _generate_default_name() variant_tensor = gen_dataset_ops.map_dataset( input_dataset._variant_tensor, # pylint: disable=protected-access self._map_func.function.captured_inputs, @@ -159,7 +181,7 @@ def __init__(self, self._preserve_cardinality = preserve_cardinality self._num_parallel_calls = ops.convert_to_tensor( num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls") - self._name = name + self._name = name or _generate_default_name() variant_tensor = gen_dataset_ops.parallel_map_dataset_v2( input_dataset._variant_tensor, # pylint: disable=protected-access self._map_func.function.captured_inputs, diff --git a/tensorflow/python/data/ops/options.py b/tensorflow/python/data/ops/options.py index 396ab1b6d78856..41124b0bb56d37 100644 --- a/tensorflow/python/data/ops/options.py +++ b/tensorflow/python/data/ops/options.py @@ -229,6 +229,18 @@ class AutotuneOptions(options_lib.OptionsBase): docstring="When autotuning is enabled (through `autotune`), determines " "the algorithm to use.") + initial_parallelism = options_lib.create_option( + name="initial_parallelism", + ty=int, + docstring=( + "The initial parallelism to use for parallel transformations before" + " autotune has a chance to run. A higher value can help with quick" + " startup, but may cause the ram_budget to temporarily be exceeded." + " Memory-sensitive datasets should consider setting this to `1` to" + " avoid running out of memory. Defaults to 16." + ), + ) + def _to_proto(self): pb = dataset_options_pb2.AutotuneOptions() if self.enabled is not None: @@ -240,6 +252,8 @@ def _to_proto(self): if self.autotune_algorithm is not None: pb.autotune_algorithm = AutotuneAlgorithm._to_proto( # pylint: disable=protected-access self.autotune_algorithm) + if self.initial_parallelism is not None: + pb.initial_parallelism = self.initial_parallelism return pb def _from_proto(self, pb): @@ -252,6 +266,8 @@ def _from_proto(self, pb): if pb.WhichOneof("optional_autotune_algorithm") is not None: self.autotune_algorithm = AutotuneAlgorithm._from_proto( # pylint: disable=protected-access pb.autotune_algorithm) + if pb.WhichOneof("optional_initial_parallelism") is not None: + self.initial_parallelism = pb.initial_parallelism def _set_mutable(self, mutable): """Change the mutability value to `mutable` on this options and children.""" @@ -344,6 +360,16 @@ class OptimizationOptions(options_lib.OptionsBase): "when the last transformation is a synchronous transformation. If None, " "defaults to True.") + seq_interleave_prefetch = options_lib.create_option( + name="seq_interleave_prefetch", + ty=bool, + docstring=( + "Whether to replace parallel interleave using a sequential interleave" + " that prefetches elements from its input iterators. If None," + " defaults to False." + ), + ) + map_and_batch_fusion = options_lib.create_option( name="map_and_batch_fusion", ty=bool, @@ -399,6 +425,8 @@ def _to_proto(self): pb.filter_parallelization = self.filter_parallelization if self.inject_prefetch is not None: pb.inject_prefetch = self.inject_prefetch + if self.seq_interleave_prefetch is not None: + pb.seq_interleave_prefetch = self.seq_interleave_prefetch if self.map_and_batch_fusion is not None: pb.map_and_batch_fusion = self.map_and_batch_fusion if self.map_and_filter_fusion is not None: @@ -424,6 +452,8 @@ def _from_proto(self, pb): self.filter_parallelization = pb.filter_parallelization if pb.WhichOneof("optional_inject_prefetch") is not None: self.inject_prefetch = pb.inject_prefetch + if pb.WhichOneof("optional_seq_interleave_prefetch") is not None: + self.seq_interleave_prefetch = pb.seq_interleave_prefetch if pb.WhichOneof("optional_map_and_batch_fusion") is not None: self.map_and_batch_fusion = pb.map_and_batch_fusion if pb.WhichOneof("optional_map_and_filter_fusion") is not None: @@ -602,11 +632,18 @@ class Options(options_lib.OptionsBase): ), default_factory=lambda: True if test_mode.TEST_MODE else None, ) + dataset_name = options_lib.create_option( name="dataset_name", ty=str, docstring="A name for the dataset, to help in debugging.") + framework_type = options_lib.create_option( + name="framework_type", + ty=list, + docstring="The list of frameworks that are used to generate this " + "pipeline, used for telemetry.") + threading = options_lib.create_option( name="threading", ty=ThreadingOptions, @@ -664,6 +701,9 @@ def _to_proto(self): pb.warm_start = self.experimental_warm_start if self.dataset_name is not None: pb.dataset_name = self.dataset_name + if self.framework_type: + for framework_type in self.framework_type: + pb.framework_type.append(framework_type) pb.threading_options.CopyFrom(self.threading._to_proto()) # pylint: disable=protected-access return pb @@ -685,6 +725,10 @@ def _from_proto(self, pb): self.experimental_warm_start = pb.warm_start if pb.WhichOneof("optional_dataset_name") is not None: self.dataset_name = pb.dataset_name + if pb.framework_type: + self.framework_type = [] + for framework_type in pb.framework_type: + self.framework_type.append(framework_type) self.threading._from_proto(pb.threading_options) # pylint: disable=protected-access def _set_mutable(self, mutable): diff --git a/tensorflow/python/data/ops/prefetch_op.py b/tensorflow/python/data/ops/prefetch_op.py index d4982c67e027bc..49ae904d83d8aa 100644 --- a/tensorflow/python/data/ops/prefetch_op.py +++ b/tensorflow/python/data/ops/prefetch_op.py @@ -47,5 +47,6 @@ def __init__(self, input_dataset, buffer_size, slack_period=None, name=None): input_dataset._variant_tensor, buffer_size=self._buffer_size, slack_period=slack_period, + legacy_autotune=(buffer_size == dataset_ops.AUTOTUNE), **self._common_args) super().__init__(input_dataset, variant_tensor) diff --git a/tensorflow/python/data/util/options.py b/tensorflow/python/data/util/options.py index 3ec1c53ff65087..3d438b6a19b9c0 100644 --- a/tensorflow/python/data/util/options.py +++ b/tensorflow/python/data/util/options.py @@ -166,6 +166,10 @@ def merge_options(*options_list): setattr(result, name, that) elif isinstance(this, OptionsBase): setattr(result, name, merge_options(this, that)) + elif name == "framework_type": + # Since, `framework_type`` is a repeated string field (list), the merged + # result will be a combined list. + setattr(result, name, this+that) elif this != that: logging.warning("Changing the value of option %s from %r to %r.", name, this, that) diff --git a/tensorflow/python/debug/lib/BUILD b/tensorflow/python/debug/lib/BUILD index ad2637cdc41a15..53b20deb3cd8b3 100644 --- a/tensorflow/python/debug/lib/BUILD +++ b/tensorflow/python/debug/lib/BUILD @@ -130,6 +130,7 @@ py_strict_library( visibility = [ "//tensorflow:internal", "//third_party/py/tf_slim:__subpackages__", + "//waymo/ml/deploy/numeric_debugging:__subpackages__", ], deps = [ ":debug_graphs", diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 2c497b61f9404a..7e86a2deff87e5 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -1756,6 +1756,7 @@ distribute_py_strict_test( tags = [ "multi_and_single_gpu", "no_cuda_asan", # times out + "no_oss", # TODO(b/292104274): Sometimes times out ], deps = [ ":collective_all_reduce_strategy", @@ -2156,13 +2157,6 @@ cuda_py_strict_test( tpu_py_strict_test( name = "collective_all_reduce_strategy_test_tpu", srcs = ["collective_all_reduce_strategy_test.py"], - # copybara:uncomment_begin - # args = [ - # "--tpu_use_tfrt=false", #TODO(b/227404010): Remove once the bug is fixed. - # ], - # copybara:uncomment_end - # FIXME(b/227404010): On TFRT TPU, eager CollectiveReduceV2 is broken. - disable_tfrt = True, main = "collective_all_reduce_strategy_test.py", python_version = "PY3", deps = [ diff --git a/tensorflow/python/framework/BUILD b/tensorflow/python/framework/BUILD index b773293d3dcb60..6c9ccdeb1d41ee 100644 --- a/tensorflow/python/framework/BUILD +++ b/tensorflow/python/framework/BUILD @@ -181,14 +181,12 @@ tf_cc_test( tf_proto_library( name = "op_reg_offset_proto", srcs = ["op_reg_offset.proto"], - cc_api_version = 3, visibility = ["//tensorflow:internal"], ) tf_proto_library( name = "kythe_metadata_proto", srcs = ["kythe_metadata.proto"], - cc_api_version = 3, visibility = ["//tensorflow:internal"], ) diff --git a/tensorflow/python/kernel_tests/array_ops/BUILD b/tensorflow/python/kernel_tests/array_ops/BUILD index 80cb2b53072a28..1a8d26ef226f7c 100644 --- a/tensorflow/python/kernel_tests/array_ops/BUILD +++ b/tensorflow/python/kernel_tests/array_ops/BUILD @@ -415,17 +415,12 @@ cuda_py_strict_test( size = "medium", srcs = ["init_ops_test.py"], shard_count = 4, - tags = [ - "noasan", - "notap", - ], deps = [ "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:for_generated_wrappers", "//tensorflow/python/framework:random_seed", "//tensorflow/python/framework:test_lib", - "//tensorflow/python/layers", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:init_ops", "//tensorflow/python/ops:linalg_ops", diff --git a/tensorflow/python/kernel_tests/array_ops/init_ops_test.py b/tensorflow/python/kernel_tests/array_ops/init_ops_test.py index dbdd44ba9380f7..460b8f8e064e2c 100644 --- a/tensorflow/python/kernel_tests/array_ops/init_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops/init_ops_test.py @@ -22,10 +22,8 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework import test_util -from tensorflow.python.layers import convolutional from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops -from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import random_ops @@ -538,13 +536,6 @@ def testMixedDType(self): constant_op.constant(4, dtype=dtypes.int32), dtype=dtypes.int64) self.assertAllEqual(self.evaluate(tf_ans), np.array([0, 1, 2, 3])) - def testLargeLimits(self): - # Test case for GitHub issue 46913. - with self.session(): - with self.assertRaises(errors_impl.ResourceExhaustedError): - v = math_ops.range(0, 9223372036854775807) - self.evaluate(v) - def testLargeStarts(self): # Test case for GitHub issue 46899. with self.session(): @@ -889,45 +880,6 @@ def testGain(self): t2 = init2(shape).eval() self.assertAllClose(t1, t2 / 3.14) - @test_util.run_deprecated_v1 - def testShapesValues(self): - gain = 3.14 - for dtype in [dtypes.float32]: - for kernel_size in [[3], [8], [3, 5], [2, 4], [3, 3, 3], [2, 2, 2]]: - tol = 1e-2 - # Check orthogonality by computing ratio between - # the 2-norms of the inputs and outputs. - if len(kernel_size) == 1: - shape = [4, 32, 64] - convolution = convolutional.conv1d - elif len(kernel_size) == 2: - convolution = convolutional.conv2d - shape = [4, 32, 32, 64] - else: - shape = [4, 16, 16, 16, 64] - convolution = convolutional.conv3d - inputs = random_ops.random_normal(shape, dtype=dtype) - inputs_2norm = linalg_ops.norm(inputs) - outputs = convolution( - inputs, - padding="same", - filters=128, - kernel_size=kernel_size, - use_bias=False, - kernel_initializer=init_ops.convolutional_delta_orthogonal( - gain=gain)) - outputs_shape = shape[0:-1] + [128] - outputs_2norm = linalg_ops.norm(outputs) - ratio = outputs_2norm / inputs_2norm - my_ops = variables.global_variables_initializer() - with self.session(): - self.evaluate(my_ops) - # Check the shape of the outputs - t = self.evaluate(outputs) - self.assertAllEqual(t.shape, outputs_shape) - # Check isometry of the delta-orthogonal kernel. - self.assertAllClose(self.evaluate(ratio), gain, rtol=tol, atol=tol) - @test_util.run_deprecated_v1 def testNonuniformity(self): value = 0 @@ -1025,62 +977,6 @@ def testNonuniformity(self): # Compute the sum of the absolute values of 'count' determinants self.assertAllClose(abs_value, count, rtol=tol, atol=tol) - @test_util.run_deprecated_v1 - def testShapesValues(self): - - def circular_pad(input_, width, kernel_size): - """Pad input_ for computing (circular) convolution. - - Args: - input_: the input tensor - width: the width of the tensor. - kernel_size: the kernel size of the filter. - - Returns: - a tensor whose width is (width + kernel_size - 1). - """ - - beginning = kernel_size // 2 - end = kernel_size - 1 - beginning - - tmp_up = array_ops.slice(input_, [0, width - beginning, 0], - [-1, beginning, -1]) - tmp_down = array_ops.slice(input_, [0, 0, 0], [-1, end, -1]) - tmp = array_ops.concat([tmp_up, input_, tmp_down], 1) - - return tmp - - cout = 64 - shape = [10, 20, 32] - outputs_shape = shape[0:-1] + [cout] - dtype = dtypes.float32 - tol = 1e-3 - gain = 3.14 - # Check orthogonality/isometry by computing the ratio between - # the 2-norms of the inputs and outputs. - for kernel_size in [[1], [2], [3], [4], [5], [6]]: - convolution = convolutional.conv1d - inputs = random_ops.random_normal(shape, dtype=dtype) - inputs_2norm = linalg_ops.norm(inputs) - input_with_circular_pad = circular_pad(inputs, shape[1], kernel_size[0]) - outputs = convolution( - input_with_circular_pad, - padding="valid", - filters=cout, - kernel_size=kernel_size[0], - use_bias=False, - kernel_initializer=init_ops.convolutional_orthogonal_1d(gain=gain)) - outputs_2norm = linalg_ops.norm(outputs) - ratio = outputs_2norm / inputs_2norm - my_ops = variables.global_variables_initializer() - with self.session(): - self.evaluate(my_ops) - # Check the shape of the outputs - t = self.evaluate(outputs) - self.assertAllEqual(t.shape, outputs_shape) - # Check isometry of the orthogonal kernel. - self.assertAllClose(self.evaluate(ratio), gain, rtol=tol, atol=tol) - class ConvolutionOrthogonal2dInitializerTest(test.TestCase): @@ -1124,67 +1020,6 @@ def testGain(self): t2 = init2(shape).eval() self.assertAllClose(t1, t2 / 3.14) - @test_util.run_deprecated_v1 - def testShapesValues(self): - - def circular_pad(input_, width, kernel_size): - """Pad input_ for computing (circular) convolution. - - Args: - input_: the input tensor - width: the width of the tensor. - kernel_size: the kernel size of the filter. - - Returns: - a tensor whose width is (width + kernel_size - 1). - """ - beginning = kernel_size // 2 - end = kernel_size - 1 - beginning - - tmp_up = array_ops.slice(input_, [0, width - beginning, 0, 0], - [-1, beginning, width, -1]) - tmp_down = array_ops.slice(input_, [0, 0, 0, 0], [-1, end, width, -1]) - tmp = array_ops.concat([tmp_up, input_, tmp_down], 1) - - new_width = width + kernel_size - 1 - tmp_left = array_ops.slice(tmp, [0, 0, width - beginning, 0], - [-1, new_width, beginning, -1]) - tmp_right = array_ops.slice(tmp, [0, 0, 0, 0], [-1, new_width, end, -1]) - - final = array_ops.concat([tmp_left, tmp, tmp_right], 2) - return final - - cout = 45 - shape = [64, 28, 28, 32] - outputs_shape = shape[0:-1] + [cout] - dtype = dtypes.float32 - tol = 1e-3 - gain = 3.14 - # Check orthogonality/isometry by computing the ratio between - # the 2-norms of the inputs and outputs. - for kernel_size in [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]]: - convolution = convolutional.conv2d - inputs = random_ops.random_normal(shape, dtype=dtype) - inputs_2norm = linalg_ops.norm(inputs) - input_with_circular_pad = circular_pad(inputs, shape[1], kernel_size[0]) - outputs = convolution( - input_with_circular_pad, - padding="valid", - filters=cout, - kernel_size=kernel_size, - use_bias=False, - kernel_initializer=init_ops.convolutional_orthogonal_2d(gain=gain)) - outputs_2norm = linalg_ops.norm(outputs) - ratio = outputs_2norm / inputs_2norm - my_ops = variables.global_variables_initializer() - with self.session(): - self.evaluate(my_ops) - # Check the shape of the outputs - t = self.evaluate(outputs) - self.assertAllEqual(t.shape, outputs_shape) - # Check isometry of the orthogonal kernel. - self.assertAllClose(self.evaluate(ratio), gain, rtol=tol, atol=tol) - @test_util.run_all_without_tensor_float_32( "Tests convolutional_orthogonal_3d, which calls matmul") @@ -1256,70 +1091,6 @@ def testNonuniformity(self): # Compute the sum of the absolute values of 'count' determinants self.assertAllClose(abs_value, count, rtol=tol, atol=tol) - @test_util.run_deprecated_v1 - def testShapesValues(self): - - def circular_pad(input_, width, kernel_size): - """Padding input_ for computing circular convolution. - - Args: - input_: the input tensor - width: the width of the tensor. - kernel_size: the kernel size of the filter. - - Returns: - a tensor whose width is (width + kernel_size - 1). - """ - - beginning = kernel_size // 2 - end = kernel_size - 1 - beginning - - tmp_up = array_ops.slice(input_, [0, width - beginning, 0, 0, 0], - [-1, beginning, -1, -1, -1]) - tmp_down = array_ops.slice(input_, [0, 0, 0, 0, 0], [-1, end, -1, -1, -1]) - tmp = array_ops.concat([tmp_up, input_, tmp_down], 1) - - tmp_left = array_ops.slice(tmp, [0, 0, width - beginning, 0, 0], - [-1, -1, beginning, -1, -1]) - tmp_right = array_ops.slice(tmp, [0, 0, 0, 0, 0], [-1, -1, end, -1, -1]) - tmp = array_ops.concat([tmp_left, tmp, tmp_right], 2) - - tmp_front = array_ops.slice(tmp, [0, 0, 0, width - beginning, 0], - [-1, -1, -1, beginning, -1]) - tmp_back = array_ops.slice(tmp, [0, 0, 0, 0, 0], [-1, -1, -1, end, -1]) - return array_ops.concat([tmp_front, tmp, tmp_back], 3) - - cout = 32 - shape = [1, 7, 7, 7, 16] - outputs_shape = shape[0:-1] + [cout] - dtype = dtypes.float32 - tol = 1e-3 - gain = 3.14 - # Check orthogonality/isometry by computing the ratio between - # the 2-norms of the inputs and outputs. - for kernel_size in [[1, 1, 1], [2, 2, 2], [3, 3, 3]]: - convolution = convolutional.conv3d - inputs = random_ops.random_normal(shape, dtype=dtype) - inputs_2norm = linalg_ops.norm(inputs) - input_with_circular_pad = circular_pad(inputs, shape[1], kernel_size[0]) - outputs = convolution( - input_with_circular_pad, - padding="valid", - filters=cout, - kernel_size=kernel_size[0], - use_bias=False, - kernel_initializer=init_ops.convolutional_orthogonal_3d(gain=gain)) - outputs_2norm = linalg_ops.norm(outputs) - ratio = outputs_2norm / inputs_2norm - my_ops = variables.global_variables_initializer() - with self.cached_session(): - self.evaluate(my_ops) - # Check the shape of the outputs - t = self.evaluate(outputs) - self.assertAllEqual(t.shape, outputs_shape) - # Check isometry of the orthogonal kernel. - self.assertAllClose(self.evaluate(ratio), gain, rtol=tol, atol=tol) - class IdentityInitializerTest(test.TestCase): diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD index 936e4204d3ace8..4f12dc12ed3b7f 100644 --- a/tensorflow/python/kernel_tests/linalg/BUILD +++ b/tensorflow/python/kernel_tests/linalg/BUILD @@ -627,6 +627,7 @@ tf_py_strict_test( cuda_py_strict_test( name = "matrix_inverse_op_test", size = "small", + timeout = "moderate", srcs = ["matrix_inverse_op_test.py"], tags = ["optonly"], deps = [ diff --git a/tensorflow/python/kernel_tests/linalg/sparse/BUILD b/tensorflow/python/kernel_tests/linalg/sparse/BUILD index 280a56580ebdfe..d4d8d65195db00 100644 --- a/tensorflow/python/kernel_tests/linalg/sparse/BUILD +++ b/tensorflow/python/kernel_tests/linalg/sparse/BUILD @@ -1,6 +1,6 @@ # Tests of TensorFlow sparse linear algebra kernels using the Python API. -load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "tf_py_strict_test") +load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -121,30 +121,6 @@ cuda_py_strict_test( ], ) -tf_py_strict_test( - name = "csr_sparse_matrix_dense_mat_mul_onednn_grad_test", - size = "medium", - srcs = ["csr_sparse_matrix_dense_mat_mul_grad_test.py"], - env = { - "TF_ENABLE_ONEDNN_SPMM": "1", - }, - main = "csr_sparse_matrix_dense_mat_mul_grad_test.py", - shard_count = 50, - tags = [ - "no_cuda_asan", # TODO(b/190824595) - ], - deps = [ - "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:test_lib", - "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:gradient_checker", - "//tensorflow/python/ops/linalg/sparse:sparse_csr_matrix_grad", - "//tensorflow/python/ops/linalg/sparse:sparse_csr_matrix_ops", - "//tensorflow/python/platform:client_testlib", - "//tensorflow/python/platform:tf_logging", - ], -) - cuda_py_strict_test( name = "csr_sparse_matrix_sparse_mat_mul_grad_test", size = "medium", diff --git a/tensorflow/python/kernel_tests/nn_ops/conv_ops_test.py b/tensorflow/python/kernel_tests/nn_ops/conv_ops_test.py index 71c2f3a208f122..b48876f81ee011 100644 --- a/tensorflow/python/kernel_tests/nn_ops/conv_ops_test.py +++ b/tensorflow/python/kernel_tests/nn_ops/conv_ops_test.py @@ -30,7 +30,6 @@ from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.layers import convolutional from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_nn_ops @@ -3538,13 +3537,15 @@ def benchmarkGPUConvStackFirst(self): timesteps = 600 features = 1 - inputs = random_ops.random_uniform( - [batch_size, 1, timesteps, features], seed=1234) + x = random_ops.random_uniform( + [batch_size, 1, timesteps, features], seed=1234 + ) num_outputs_list = [512] * 40 + [1] - kernel_w = 3 - x = inputs for num_outputs in num_outputs_list: - x = convolutional.conv2d(x, num_outputs, [1, kernel_w]) + kernel = random_ops.random_uniform( + [1, 3, features, num_outputs], seed=1234 + ) + x = nn_ops.conv2d(x, kernel) outputs = x self.evaluate(variables.global_variables_initializer()) diff --git a/tensorflow/python/kernel_tests/variables/BUILD b/tensorflow/python/kernel_tests/variables/BUILD index 04c6acd39386ec..bc18cb8d6d23a7 100644 --- a/tensorflow/python/kernel_tests/variables/BUILD +++ b/tensorflow/python/kernel_tests/variables/BUILD @@ -153,7 +153,6 @@ tf_py_strict_test( "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:test_lib", - "//tensorflow/python/layers", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:cond", "//tensorflow/python/ops:init_ops", diff --git a/tensorflow/python/kernel_tests/variables/variable_scope_test.py b/tensorflow/python/kernel_tests/variables/variable_scope_test.py index 44e65874621ca3..ac57badf8f56e7 100644 --- a/tensorflow/python/kernel_tests/variables/variable_scope_test.py +++ b/tensorflow/python/kernel_tests/variables/variable_scope_test.py @@ -27,7 +27,6 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.layers import core as core_layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import cond from tensorflow.python.ops import init_ops @@ -253,21 +252,6 @@ def testEagerVariablesOutsideStoreNotAddedToCollections(self): self.assertFalse(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) self.assertFalse(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) - def testEagerVariableStoreWithFunctionalLayer(self): - with context.eager_mode(): - container = variable_scope.EagerVariableStore() - x = constant_op.constant([[2.0]]) - with container.as_default(): - y = core_layers.dense(x, 1, name="my_dense", - kernel_initializer=init_ops.ones_initializer()) - self.assertAllEqual(y, [[2.0]]) - self.assertEqual(len(container.variables()), 2) - # Recreate the layer to test reuse. - with container.as_default(): - core_layers.dense(x, 1, name="my_dense", - kernel_initializer=init_ops.ones_initializer()) - self.assertEqual(len(container.variables()), 2) - # Not converted to use wrap_function because of # TypeError: Expected tf.group() expected Tensor arguments not 'None' with # type ''. diff --git a/tensorflow/python/kernel_tests/variables/variables_test.py b/tensorflow/python/kernel_tests/variables/variables_test.py index 45b88857090313..79e393cce8a5a9 100644 --- a/tensorflow/python/kernel_tests/variables/variables_test.py +++ b/tensorflow/python/kernel_tests/variables/variables_test.py @@ -646,6 +646,29 @@ def testTrainableVariable(self, cls): trainable=False) self.assertEqual(False, v4.trainable) + def testSaveSliceInfoFromSpecPasses(self): + save_slice_info = variables.Variable.SaveSliceInfo( + full_name="foo", + full_shape=[2, 3, 4], + var_offset=[0, 2, 0], + var_shape=[1, 1, 3]) + + save_slice_info_from_spec = variables.Variable.SaveSliceInfo.from_spec( + save_slice_info.spec) + + self.assertEqual(save_slice_info.spec, save_slice_info_from_spec.spec) + + @parameterized.parameters( + dict(spec="0", error_message="contain space-separated full_shape info"), + dict(spec="0:0", error_message="contain space-separated full_shape info"), + dict(spec="a b", error_message="full_shape must be a sequence of int"), + dict(spec="0 0:0", error_message="comma-separated pairs of offsets and"), + dict(spec="0 a,0:0:0", error_message="var_offset must be an integer"), + dict(spec="0 0,a:0:0", error_message="var_shape must be an integer")) + def testSaveSliceInfoFromSpecFails(self, spec, error_message): + with self.assertRaisesRegex(ValueError, error_message): + variables.Variable.SaveSliceInfo.from_spec(spec) + class IsInitializedTest(test.TestCase): diff --git a/tensorflow/python/lite/toco_python_api_wrapper.cc b/tensorflow/python/lite/toco_python_api_wrapper.cc index a650948a85ad89..39dc7802f28f0e 100644 --- a/tensorflow/python/lite/toco_python_api_wrapper.cc +++ b/tensorflow/python/lite/toco_python_api_wrapper.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include "pybind11/pybind11.h" // from @pybind11 +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/lite/toco/python/toco_python_api.h" #include "tensorflow/python/lib/core/pybind11_lib.h" @@ -28,16 +29,20 @@ PYBIND11_MODULE(_pywrap_toco_api, m) { [](py::object model_flags_proto_txt_raw, py::object toco_flags_proto_txt_raw, py::object input_contents_txt_raw, bool extended_return, py::object debug_info_txt_raw, - bool enable_mlir_converter) { + bool enable_mlir_converter, + const tensorflow::quantization::PyFunctionLibrary* + quantization_py_function_library) { return tensorflow::PyoOrThrow(toco::TocoConvert( model_flags_proto_txt_raw.ptr(), toco_flags_proto_txt_raw.ptr(), input_contents_txt_raw.ptr(), extended_return, - debug_info_txt_raw.ptr(), enable_mlir_converter)); + debug_info_txt_raw.ptr(), enable_mlir_converter, + quantization_py_function_library)); }, py::arg("model_flags_proto_txt_raw"), py::arg("toco_flags_proto_txt_raw"), py::arg("input_contents_txt_raw"), py::arg("extended_return") = false, py::arg("debug_info_txt_raw") = py::none(), py::arg("enable_mlir_converter") = false, + py::arg("quantization_py_function_library") = py::none(), R"pbdoc( Convert a model represented in `input_contents`. `model_flags_proto` describes model parameters. `toco_flags_proto` describes conversion @@ -55,13 +60,15 @@ PYBIND11_MODULE(_pywrap_toco_api, m) { bool fully_quantize, int inference_type, int input_data_type, int output_data_type, bool enable_numeric_verify, bool enable_whole_model_verify, py::object op_blocklist, - py::object node_blocklist, bool enable_variable_quantization) { + py::object node_blocklist, bool enable_variable_quantization, + bool disable_per_channel_for_dense_layers) { return tensorflow::PyoOrThrow(toco::MlirQuantizeModel( input_contents_txt_raw.ptr(), disable_per_channel, fully_quantize, inference_type, input_data_type, output_data_type, enable_numeric_verify, enable_whole_model_verify, op_blocklist.ptr(), node_blocklist.ptr(), - enable_variable_quantization)); + enable_variable_quantization, + disable_per_channel_for_dense_layers)); }, py::arg("input_contents_txt_raw"), py::arg("disable_per_channel") = false, py::arg("fully_quantize") = true, py::arg("inference_type") = 9, @@ -71,6 +78,7 @@ PYBIND11_MODULE(_pywrap_toco_api, m) { py::arg("op_blocklist") = py::none(), py::arg("node_blocklist") = py::none(), py::arg("enable_variable_quantization") = false, + py::arg("disable_per_channel_for_dense_layers") = false, R"pbdoc( Returns a quantized model. )pbdoc"); diff --git a/tensorflow/python/ops/BUILD b/tensorflow/python/ops/BUILD index 2da4f034f0c3c0..f4c80ec02dd88b 100644 --- a/tensorflow/python/ops/BUILD +++ b/tensorflow/python/ops/BUILD @@ -3153,6 +3153,7 @@ py_strict_library( "//tensorflow/python/util:tf_should_use", "//tensorflow/python/util:traceback_utils", "//third_party/py/numpy", + "@pypi_typing_extensions//:pkg", ], ) diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py index 01eefe80f74ba2..1694537d68aaf6 100644 --- a/tensorflow/python/ops/math_ops_test.py +++ b/tensorflow/python/ops/math_ops_test.py @@ -791,16 +791,38 @@ def testWithPythonValue(self): def intEdgeTestData(self, dtype): """Edge-case test data for integer types.""" - # INT_MIN/-1 expected to produce signed-integer overflow, - # INT_MIN/INT_MAX expected to work. - nums = np.array([np.iinfo(dtype).min, -1, 1, - np.iinfo(dtype).max], - dtype=dtype).reshape([4, 1]) - divs = nums.reshape([1, 4]) + # INT_MIN/-1 will produce signed-integer overflow, so we instead test + # (INT_MIN + 1) / -1. + nums = np.array( + [ + [np.iinfo(dtype).min, -1, 1, np.iinfo(dtype).max], + [np.iinfo(dtype).min + 1, -1, 1, np.iinfo(dtype).max], + [np.iinfo(dtype).min, -1, 1, np.iinfo(dtype).max], + [np.iinfo(dtype).min, -1, 1, np.iinfo(dtype).max], + ], + dtype=dtype, + ) + divs = np.array( + [ + [ + np.iinfo(dtype).min, + np.iinfo(dtype).min, + np.iinfo(dtype).min, + np.iinfo(dtype).min, + ], + [-1, -1, -1, -1], + [1, 1, 1, 1], + [ + np.iinfo(dtype).max, + np.iinfo(dtype).max, + np.iinfo(dtype).max, + np.iinfo(dtype).max, + ], + ], + dtype=dtype, + ) return nums, divs - @test_util.disable_asan("Expected signed integer overflow.") - @test_util.disable_ubsan("Expected signed integer overflow.") def testFloorDivModIntEdges(self): for dtype in [np.int32, np.int64]: x, y = self.intEdgeTestData(dtype) @@ -810,12 +832,7 @@ def testFloorDivModIntEdges(self): tf_floor_mod = math_ops.floormod(x, y) np_floor_mod = self.numpySafeFloorModInt(x, y) self.assertAllEqual(tf_floor_mod, np_floor_mod) - z = math_ops.add(math_ops.multiply(tf_floor_div, y), tf_floor_mod) - # x = floor_div(x, y) * y + floor_mod(x, y) - self.assertAllEqual(z, np.broadcast_to(x, z.shape)) - @test_util.disable_asan("Expected signed integer overflow.") - @test_util.disable_ubsan("Expected signed integer overflow.") def testTruncateDivModIntEdges(self): for dtype in [np.int32, np.int64]: x, y = self.intEdgeTestData(dtype) @@ -825,9 +842,6 @@ def testTruncateDivModIntEdges(self): tf_truncate_mod = math_ops.truncatemod(x, y) np_truncate_mod = self.numpySafeTruncateModInt(x, y) self.assertAllEqual(tf_truncate_mod, np_truncate_mod) - z = math_ops.add(math_ops.multiply(tf_truncate_div, y), tf_truncate_mod) - # x = truncatediv(x, y) * y + truncatemod(x, y) - self.assertAllEqual(z, np.broadcast_to(x, z.shape)) @test_util.run_all_in_graph_and_eager_modes diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 49821d75da445d..a09834cc428536 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -19,6 +19,7 @@ import functools import itertools import os +from typing_extensions import Self from tensorflow.core.framework import variable_pb2 from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import @@ -1300,6 +1301,74 @@ def spec(self): "%d,%d" % (o, s) for o, s in zip(self.var_offset, self.var_shape)) return full_shape_str + sl_spec + @classmethod + def from_spec(cls, spec: str) -> Self: + """Parses a SaveSliceInfo spec string and returns a SaveSliceInfo object. + + Args: + spec: The tensor slice spec string according to the SaveSliceInfo.spec + property. The spec contains the space-separated shape of the full + variable, followed by colon-separated pairs of the variable's offset + and shape, where each pair is comma-separated. For example, consider a + variable whose full shape is [4 3 5], offset is [0 1 3], and shape is + [4 1 2]. This variable's SaveSliceInfo.spec would be + "4 3 5 0,4:1,1:3,2". + + Returns: + A SaveSliceInfo object containing the extracted information. + + Raises: + ValueError: If the input string is not in the expected format. + """ + if not spec: + return cls() + + try: + full_shape_str, slice_str = spec.rsplit(" ", 1) + except ValueError as e: + raise ValueError( + "Spec string must contain space-separated full_shape info.") from e + + # Parse the full shape. + full_shape = [] + for dim in full_shape_str.split(): + try: + full_shape.append(int(dim)) + except ValueError as e: + raise ValueError( + "Spec string full_shape must be a sequence of integers. " + f"Found '{dim}', which is not an integer.") from e + + # Parse the slice specification. + var_offset = [] + var_shape = [] + for dim_spec in slice_str.split(":"): + try: + offset, shape = dim_spec.split(",") + except ValueError as e: + raise ValueError( + "Spec string must contain comma-separated pairs of offsets and " + "shapes.") from e + + try: + var_offset.append(int(offset)) + except ValueError as e: + raise ValueError( + "Spec string var_offset must be an integer. " + f"Found '{offset}', which is not an integer.") from e + try: + var_shape.append(int(shape)) + except ValueError as e: + raise ValueError( + "Spec string var_shape must be an integer. " + f"Found '{shape}', which is not an integer.") from e + + return cls( + full_shape=full_shape, + var_offset=var_offset, + var_shape=var_shape + ) + def to_proto(self, export_scope=None): """Returns a SaveSliceInfoDef() proto. diff --git a/tensorflow/python/ops/weak_tensor_math_ops_test.py b/tensorflow/python/ops/weak_tensor_math_ops_test.py index 3e74cb1ff76cde..f34fb65e214261 100644 --- a/tensorflow/python/ops/weak_tensor_math_ops_test.py +++ b/tensorflow/python/ops/weak_tensor_math_ops_test.py @@ -839,12 +839,36 @@ def numpySafeTruncateModInt(self, x, y): def intEdgeTestData(self, dtype): """Edge-case test data for integer types.""" - # INT_MIN/-1 expected to produce signed-integer overflow, - # INT_MIN/INT_MAX expected to work. + # INT_MIN/-1 will produce signed-integer overflow, so we instead test + # (INT_MIN + 1) / -1. nums = np.array( - [np.iinfo(dtype).min, -1, 1, np.iinfo(dtype).max], dtype=dtype - ).reshape([4, 1]) - divs = nums.reshape([1, 4]) + [ + [np.iinfo(dtype).min, -1, 1, np.iinfo(dtype).max], + [np.iinfo(dtype).min + 1, -1, 1, np.iinfo(dtype).max], + [np.iinfo(dtype).min, -1, 1, np.iinfo(dtype).max], + [np.iinfo(dtype).min, -1, 1, np.iinfo(dtype).max], + ], + dtype=dtype, + ) + divs = np.array( + [ + [ + np.iinfo(dtype).min, + np.iinfo(dtype).min, + np.iinfo(dtype).min, + np.iinfo(dtype).min, + ], + [-1, -1, -1, -1], + [1, 1, 1, 1], + [ + np.iinfo(dtype).max, + np.iinfo(dtype).max, + np.iinfo(dtype).max, + np.iinfo(dtype).max, + ], + ], + dtype=dtype, + ) return nums, divs @test_util.disable_asan("Expected signed integer overflow.") diff --git a/tensorflow/python/profiler/internal/_pywrap_profiler.pyi b/tensorflow/python/profiler/internal/_pywrap_profiler.pyi index 64193615707e1b..a514598c9682f7 100644 --- a/tensorflow/python/profiler/internal/_pywrap_profiler.pyi +++ b/tensorflow/python/profiler/internal/_pywrap_profiler.pyi @@ -23,4 +23,3 @@ def monitor(arg0: str, arg1: int, arg2: int, arg3: bool) -> str: ... def start_server(arg0: int) -> None: ... def trace(arg0: str, arg1: str, arg2: str, arg3: bool, arg4: int, arg5: int, arg6: dict) -> None: ... def xspace_to_tools_data(arg0: list, arg1: str, arg2: dict = ...) -> tuple: ... -def xspace_to_tools_data_from_byte_string(arg0: list, arg1: list, arg2: str, arg3: dict) -> tuple: ... diff --git a/tensorflow/python/profiler/internal/profiler_wrapper.cc b/tensorflow/python/profiler/internal/profiler_wrapper.cc index e2f14743cee3b8..a4c3f25a21318d 100644 --- a/tensorflow/python/profiler/internal/profiler_wrapper.cc +++ b/tensorflow/python/profiler/internal/profiler_wrapper.cc @@ -177,53 +177,4 @@ PYBIND11_MODULE(_pywrap_profiler, m) { // TODO: consider defaulting `xspace_path_list` to empty list, since // this parameter is only used for two of the tools... py::arg(), py::arg(), py::arg() = py::dict()); - - m.def("xspace_to_tools_data_from_byte_string", - [](const py::list& xspace_string_list, const py::list& filenames_list, - const py::str& py_tool_name, const py::dict options = py::dict()) { - std::vector> xspaces; - xspaces.reserve(xspace_string_list.size()); - std::vector xspace_paths; - xspace_paths.reserve(filenames_list.size()); - - // XSpace string inputs - for (py::handle obj : xspace_string_list) { - std::string xspace_string = std::string(py::cast(obj)); - auto xspace = std::make_unique(); - if (!xspace->ParseFromString(xspace_string)) { - return py::make_tuple(py::bytes(""), py::bool_(false)); - } - for (int i = 0; i < xspace->hostnames_size(); ++i) { - std::string hostname = xspace->hostnames(i); - std::replace(hostname.begin(), hostname.end(), ':', '_'); - xspace->mutable_hostnames(i)->swap(hostname); - } - xspaces.push_back(std::move(xspace)); - } - - // XSpace paths. - for (py::handle obj : filenames_list) { - xspace_paths.push_back(std::string(py::cast(obj))); - } - - auto status_or_session_snapshot = - tensorflow::profiler::SessionSnapshot::Create( - std::move(xspace_paths), std::move(xspaces)); - if (!status_or_session_snapshot.ok()) { - LOG(ERROR) << status_or_session_snapshot.status().message(); - return py::make_tuple(py::bytes(""), py::bool_(false)); - } - - std::string tool_name = std::string(py_tool_name); - ToolOptions tool_options = ToolOptionsFromPythonDict(options); - auto status_or_tool_data = - tensorflow::profiler::ConvertMultiXSpacesToToolData( - status_or_session_snapshot.value(), tool_name, tool_options); - if (!status_or_tool_data.ok()) { - LOG(ERROR) << status_or_tool_data.status().message(); - return py::make_tuple(py::bytes(""), py::bool_(false)); - } - return py::make_tuple(py::bytes(status_or_tool_data.value()), - py::bool_(true)); - }); }; diff --git a/tensorflow/python/saved_model/pywrap_saved_model_fingerprinting.cc b/tensorflow/python/saved_model/pywrap_saved_model_fingerprinting.cc index d3336a80f0881a..ce8206cdeee160 100644 --- a/tensorflow/python/saved_model/pywrap_saved_model_fingerprinting.cc +++ b/tensorflow/python/saved_model/pywrap_saved_model_fingerprinting.cc @@ -85,8 +85,8 @@ void DefineFingerprintingModule(py::module main_module) { m.def( "CreateFingerprintDef", - [](std::string export_dir) -> StatusOr { - StatusOr fingerprint = + [](std::string export_dir) -> absl::StatusOr { + absl::StatusOr fingerprint = fingerprinting::CreateFingerprintDef(export_dir); if (fingerprint.ok()) { return py::bytes(fingerprint.value().SerializeAsString()); @@ -105,7 +105,7 @@ void DefineFingerprintingModule(py::module main_module) { m.def( "ReadSavedModelFingerprint", [](std::string export_dir) { - StatusOr fingerprint = + absl::StatusOr fingerprint = fingerprinting::ReadSavedModelFingerprint(export_dir); if (fingerprint.ok()) { return py::bytes(fingerprint.value().SerializeAsString()); @@ -135,7 +135,7 @@ void DefineFingerprintingModule(py::module main_module) { m.def( "SingleprintFromFP", [](std::string export_dir) { - StatusOr singleprint = + absl::StatusOr singleprint = fingerprinting::Singleprint(export_dir); if (singleprint.ok()) { return py::str(singleprint.value()); @@ -153,7 +153,7 @@ void DefineFingerprintingModule(py::module main_module) { m.def( "SingleprintFromSM", [](std::string export_dir) { - StatusOr fingerprint_def = + absl::StatusOr fingerprint_def = fingerprinting::CreateFingerprintDef(export_dir); if (!fingerprint_def.ok()) { throw FingerprintException( @@ -164,7 +164,7 @@ void DefineFingerprintingModule(py::module main_module) { .c_str()); } - StatusOr singleprint = + absl::StatusOr singleprint = fingerprinting::Singleprint(fingerprint_def.value()); if (!singleprint.ok()) { throw FingerprintException( @@ -184,7 +184,7 @@ void DefineFingerprintingModule(py::module main_module) { "Singleprint", [](uint64 graph_def_program_hash, uint64 signature_def_hash, uint64 saved_object_graph_hash, uint64 checkpoint_hash) { - StatusOr singleprint = fingerprinting::Singleprint( + absl::StatusOr singleprint = fingerprinting::Singleprint( graph_def_program_hash, signature_def_hash, saved_object_graph_hash, checkpoint_hash); if (singleprint.ok()) { diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py index 99c20bab308cb2..25b1d9f1d6797f 100644 --- a/tensorflow/python/saved_model/save.py +++ b/tensorflow/python/saved_model/save.py @@ -25,6 +25,7 @@ from tensorflow.core.framework import function_pb2 from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import node_def_pb2 from tensorflow.core.framework import versions_pb2 from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import saved_model_pb2 @@ -843,6 +844,95 @@ def _trace_gradient_functions(graph: ops.Graph, saveable_view: _SaveableView): saveable_view.gradient_defs.append(grad_def) +def _strip_debug_nodes(meta_graph_def: meta_graph_pb2.MetaGraphDef) -> None: + """An experimental function to remove debug nodes from the final graph. + + This function removes all Assert and CheckNumerics nodes from the meta_graph. + It strips the operators in both the nodes and in all of the function defs, + with the Assert ops being replaced by `NoOp`s and the CheckNumerics ops being + transformed into `Identity` ops. In addition to this, it creates control + inputs for the nodes that are not relevant for the op. For more information + about control inputs please see go/how-tensors-flow#control-dependencies. + + Args: + meta_graph_def: The meta_graph that will be exported. + """ + + def erase_regular_node_attributes(node: node_def_pb2.NodeDef) -> None: + """Erases regular node attributes.""" + attributes_to_remove = [ + attribute + for attribute in node.attr.keys() + if not attribute.startswith("_") + ] + for attribute in attributes_to_remove: + node.attr.pop(attribute) + + def prune_all_non_t_attributes(node: node_def_pb2.NodeDef) -> None: + """Prunes all attributes that are not `T`.""" + if "T" in node.attr: + t_value = node.attr["T"] + node.ClearField("attr") + node.attr["T"].CopyFrom(t_value) + else: + node.ClearField("attr") + + def is_control_input(name: str) -> str: + """Returns whether or not the input is a control input.""" + return name and name[0] == "^" + + def as_control_dep(name: str) -> str: + """Returns the input as a control dependency.""" + return "^" + name.split(":")[0] + + def maybe_do_strip(node: node_def_pb2.NodeDef) -> None: + """Strips the graph from Assert and CheckNumerics ops. + + For Assert ops, this function also rewrites all of the inputs to the nodes + that were transformed by making them into control dependencies. It also + removes all of the regular node attributes, that is all node attributes + that do not start with `_`. + + For CheckNumerics ops, this function turns the op into an Identity op, + which will be pruned later (according to the original implementation in + grappler's `debug_stripper.cc`. Then, since Identity ops only take one + input, it leaves the first input as is while transforming the other ones + into control dependencies. + + Args: + node: The node to potentally strip. + """ + if node.op == "Assert" or node.op == "PrintV2": + node.op = "NoOp" + erase_regular_node_attributes(node) + new_inputs = [] + for inp in node.input: + if not is_control_input(inp): + new_inputs.append(as_control_dep(inp)) + else: + new_inputs.append(inp) + node.ClearField("input") + node.input.extend(new_inputs) + elif node.op == "CheckNumerics" or node.op == "Print": + # The identity op will be pruned later. + node.op = "Identity" + prune_all_non_t_attributes(node) + # As Identity op only takes one input, mark redundant inputs as control + # inputs. + for i in range(1, len(node.input)): + if not is_control_input(node.input[i]): + node.input[i] = as_control_dep(node.input[i]) + + # First, we strip the assert nodes from the graph. + for node in meta_graph_def.graph_def.node: + maybe_do_strip(node) + + # Then, we strip the assert nodes from all of the function defs. + for func in meta_graph_def.graph_def.library.function: + for node in func.node_def: + maybe_do_strip(node) + + def _fill_meta_graph_def( meta_graph_def: meta_graph_pb2.MetaGraphDef, saveable_view: _SaveableView, @@ -850,6 +940,7 @@ def _fill_meta_graph_def( namespace_whitelist: List[str], save_custom_gradients: bool, create_saver: bool, + enable_debug_stripper: bool, defaults=None, ) -> Tuple[_AssetInfo, ops.Graph]: """Generates a MetaGraph which calls `signature_functions`. @@ -862,6 +953,7 @@ def _fill_meta_graph_def( namespace_whitelist: List of strings containing whitelisted op namespaces. save_custom_gradients: Whether to save custom gradients. create_saver: Whether to add SavedModel's native save and restore ops. + enable_debug_stripper: Whether to strip the debug nodes from the graph. defaults: A dictionary mapping signature_key to dictionary of user_specified_name to Tensor representing default values. @@ -952,8 +1044,6 @@ def call_with_mapped_captures(function, args): versions.__git_version__) # We currently always strip default attributes. meta_graph_def.meta_info_def.stripped_default_attrs = True - meta_graph_def.meta_info_def.stripped_op_list.MergeFrom( - meta_graph.stripped_op_list_for_graph(meta_graph_def.graph_def)) meta_graph_def.asset_file_def.extend(asset_info.asset_defs) for signature_key, signature in signatures.items(): meta_graph_def.signature_def[signature_key].CopyFrom(signature) @@ -961,6 +1051,10 @@ def call_with_mapped_captures(function, args): # store tensor_content in litle endian format if sys.byteorder == "big": utils_impl.swap_function_tensor_content(meta_graph_def, "big", "little") + if enable_debug_stripper: + _strip_debug_nodes(meta_graph_def) + meta_graph_def.meta_info_def.stripped_op_list.MergeFrom( + meta_graph.stripped_op_list_for_graph(meta_graph_def.graph_def)) return asset_info, exported_graph @@ -1514,6 +1608,7 @@ def _build_meta_graph_impl( namespace_whitelist=options.namespace_whitelist, save_custom_gradients=options.experimental_custom_gradients, create_saver=not options.experimental_skip_saver, + enable_debug_stripper=options.experimental_debug_stripper, defaults=defaults, ) if options.function_aliases: diff --git a/tensorflow/python/saved_model/save_options.py b/tensorflow/python/saved_model/save_options.py index c0da1ff1f8cd22..bc6c7b657ef95e 100644 --- a/tensorflow/python/saved_model/save_options.py +++ b/tensorflow/python/saved_model/save_options.py @@ -103,6 +103,7 @@ class SaveOptions: "namespace_whitelist", "save_debug_info", "function_aliases", + "experimental_debug_stripper", "experimental_io_device", "experimental_variable_policy", "experimental_custom_gradients", @@ -116,6 +117,7 @@ def __init__( namespace_whitelist=None, save_debug_info=False, function_aliases=None, + experimental_debug_stripper=False, experimental_io_device=None, experimental_variable_policy=None, experimental_custom_gradients=True, @@ -149,6 +151,10 @@ class Adder(tf.Module): ... @tf.function ... def double(self, x): tf.saved_model.SaveOptions( ... function_aliases={'double': model.double}) >>> tf.saved_model.save(model, '/tmp/adder', options=options) + experimental_debug_stripper: bool. If set to True, this strips the debug + nodes from the graph, from both the nodes and the function defs. Note + that this currently only strips the `Assert` nodes from the graph and + converts them into `NoOp`s instead. experimental_io_device: string. Applies in a distributed setting. Tensorflow device to use to access the filesystem. If `None` (default) then for each variable the filesystem is accessed from the CPU:0 device @@ -186,6 +192,7 @@ class Adder(tf.Module): ... @tf.function ... def double(self, x): self.save_debug_info = save_debug_info self.function_aliases = function_aliases if function_aliases else dict() self.experimental_custom_gradients = experimental_custom_gradients + self.experimental_debug_stripper = experimental_debug_stripper self.experimental_io_device = experimental_io_device self.experimental_variable_policy = VariablePolicy.from_obj( experimental_variable_policy diff --git a/tensorflow/python/saved_model/save_test.py b/tensorflow/python/saved_model/save_test.py index 743831edc042e6..45fb5c4f11e0ab 100644 --- a/tensorflow/python/saved_model/save_test.py +++ b/tensorflow/python/saved_model/save_test.py @@ -20,8 +20,12 @@ from google.protobuf import text_format from tensorflow.core.config import flags +from tensorflow.core.framework import attr_value_pb2 +from tensorflow.core.framework import function_pb2 from tensorflow.core.framework import graph_debug_info_pb2 from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import node_def_pb2 +from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.checkpoint import checkpoint from tensorflow.python.checkpoint.sharding import sharding_policies from tensorflow.python.client import session as session_lib @@ -1110,22 +1114,160 @@ def test_save_custom_op_with_no_whitelist_specified(self): # If the user passes an empty list for the namespace whitelist rather than # nothing, we should then throw an exception if a custom op is used. with self.assertRaisesRegex( - ValueError, "Attempted to save ops from non-whitelisted namespaces"): + ValueError, "Attempted to save ops from non-whitelisted namespaces" + ): save._verify_ops(graph_def, []) + def test_strip_debug_nodes(self): + # Test that we are able to strip debug nodes from a meta_graph correctly. + test_node_defs = [ + node_def_pb2.NodeDef( + name="AssertNode", + op="Assert", + input=[ + "NonControlInput:output:0", + "^ControlInput:output:0", + ], + attr={ + "regular_node_attr": attr_value_pb2.AttrValue(i=1), + "_non_regular_node_attr": attr_value_pb2.AttrValue(i=2), + } + ), + node_def_pb2.NodeDef( + name="ConstNode", + op="Const", + ), + node_def_pb2.NodeDef( + name="CheckNumericsNode", + op="CheckNumerics", + input=[ + "NonControlInput:output:0", + "NonControlInputTwo:output:0", + "^ControlInput:output:0", + ], + attr={ + "T": attr_value_pb2.AttrValue(i=4), + "NotT": attr_value_pb2.AttrValue(i=5), + } + ), + node_def_pb2.NodeDef( + name="CheckNumericsNodeTwo", + op="CheckNumerics", + input=[ + "NonControlInput:output:0", + "NonControlInputTwo:output:0", + "^ControlInput:output:0", + ], + attr={ + "OnlyNotT": attr_value_pb2.AttrValue(i=6), + }, + ), + node_def_pb2.NodeDef( + name="PrintNode", + op="Print", + input=[ + "NonControlInput:output:0", + ], + ), + node_def_pb2.NodeDef( + name="PrintV2Node", + op="PrintV2", + input=[ + "NonControlInput:output:0", + ], + ), + ] + + expected_node_defs = [ + node_def_pb2.NodeDef( + name="AssertNode", + op="NoOp", + input=[ + "^NonControlInput", + "^ControlInput:output:0", + ], + attr={ + "_non_regular_node_attr": attr_value_pb2.AttrValue(i=2), + } + ), + node_def_pb2.NodeDef( + name="ConstNode", + op="Const", + ), + node_def_pb2.NodeDef( + name="CheckNumericsNode", + op="Identity", + input=[ + "NonControlInput:output:0", + "^NonControlInputTwo", + "^ControlInput:output:0", + ], + attr={ + "T": attr_value_pb2.AttrValue(i=4), + } + ), + node_def_pb2.NodeDef( + name="CheckNumericsNodeTwo", + op="Identity", + input=[ + "NonControlInput:output:0", + "^NonControlInputTwo", + "^ControlInput:output:0", + ], + ), + node_def_pb2.NodeDef( + name="PrintNode", + op="Identity", + input=[ + "NonControlInput:output:0", + ], + ), + node_def_pb2.NodeDef( + name="PrintV2Node", + op="NoOp", + input=[ + "^NonControlInput", + ], + ), + ] + + meta_graph_def = meta_graph_pb2.MetaGraphDef( + graph_def=graph_pb2.GraphDef( + node=test_node_defs, + library=function_pb2.FunctionDefLibrary( + function=[function_pb2.FunctionDef(node_def=test_node_defs)] + ), + ), + ) + + expected = meta_graph_pb2.MetaGraphDef( + graph_def=graph_pb2.GraphDef( + node=expected_node_defs, + library=function_pb2.FunctionDefLibrary( + function=[function_pb2.FunctionDef(node_def=expected_node_defs)] + ), + ), + ) + + save._strip_debug_nodes(meta_graph_def) + self.assertEqual(expected, meta_graph_def) + def test_save_debug_info_enabled(self): root = autotrackable.AutoTrackable() root.f = def_function.function( - lambda x: math_ops.mul(2., x, name="DEBUG_INFO_OP"), - input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)]) + lambda x: math_ops.mul(2.0, x, name="DEBUG_INFO_OP"), + input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)], + ) save_dir = os.path.join(self.get_temp_dir(), "saved_model") save.save( root, save_dir, root.f, - options=save_options.SaveOptions(save_debug_info=True)) - debug_info_file_name = os.path.join(save_dir, "debug", - "saved_model_debug_info.pb") + options=save_options.SaveOptions(save_debug_info=True), + ) + debug_info_file_name = os.path.join( + save_dir, "debug", "saved_model_debug_info.pb" + ) self.assertTrue(os.path.exists(debug_info_file_name)) debug_info = graph_debug_info_pb2.GraphDebugInfo() with open(debug_info_file_name, "rb") as f: diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index 1c77e02abd529a..a308a621912917 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -1005,7 +1005,7 @@ PYBIND11_MODULE(_pywrap_tfe, m) { TFE_ContextSetServerDefWithTimeoutAndRetries( tensorflow::InputTFE_Context(ctx), keep_alive_secs, buf.get()->data, buf.get()->length, timeout, retries, - status.get()); + status.get(), /*clear_existing_contexts=*/false); Py_END_ALLOW_THREADS; tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); }); diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD index 208964d198579e..ac1aa9ccc6a197 100644 --- a/tensorflow/python/tools/BUILD +++ b/tensorflow/python/tools/BUILD @@ -429,6 +429,9 @@ py_strict_test( ], python_version = "PY3", srcs_version = "PY3", + tags = [ + "noasan", # TODO(b/222716501) + ], deps = [ ":saved_model_cli_lib", # copybara:uncomment "//third_party/py/google/protobuf:use_fast_cpp_protos", @@ -469,6 +472,18 @@ py_strict_binary( ], ) +# copybara:comment_begin(oss-only) +py_strict_binary( + name = "grpc_tpu_worker", + srcs = ["grpc_tpu_worker.py"], +) + +py_strict_binary( + name = "grpc_tpu_worker_service", + srcs = ["grpc_tpu_worker_service.py"], +) +# copybara:comment_end + EMITTED_AOT_SAVE_MODEL_OBJECTS = [ "x_matmul_y_large/saved_model.pb", "x_matmul_y_large/variables/variables.index", diff --git a/tensorflow/python/tools/api/generator2/generate_api.bzl b/tensorflow/python/tools/api/generator2/generate_api.bzl index c2a96438576d22..554389e625147a 100644 --- a/tensorflow/python/tools/api/generator2/generate_api.bzl +++ b/tensorflow/python/tools/api/generator2/generate_api.bzl @@ -176,7 +176,7 @@ def _generate_api_impl(ctx): args.use_param_file("--flagfile=%s") args.add_joined("--output_files", ctx.outputs.output_files, join_with = ",") - args.add("--output_dir", paths.join(ctx.bin_dir.path, ctx.label.package, ctx.attr.output_dir)) + args.add("--output_dir", paths.join(ctx.bin_dir.path, ctx.label.workspace_root, ctx.label.package, ctx.attr.output_dir)) if ctx.file.root_init_template: args.add("--root_init_template", ctx.file.root_init_template) args.add("--apiversion", ctx.attr.api_version) diff --git a/tensorflow/python/tools/grpc_tpu_worker.py b/tensorflow/python/tools/grpc_tpu_worker.py new file mode 100644 index 00000000000000..154fcaeacf07d9 --- /dev/null +++ b/tensorflow/python/tools/grpc_tpu_worker.py @@ -0,0 +1,138 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=g-import-not-at-top +"""Python-based TPU Worker GRPC server. + +Start a blocking TPU Worker GRPC server. + +Usage: + python3 grpc_tpu_worker.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import os +import sys + +import requests + +from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import tensorflow_server_pb2 +from tensorflow.python.training import server_lib + + +def get_metadata(key): + return requests.get( + 'http://metadata.google.internal/computeMetadata' + '/v1/instance/attributes/{}'.format(key), + headers={ + 'Metadata-Flavor': 'Google' + }).text + + +def get_host_ip(): + return requests.get( + 'http://metadata.google.internal/computeMetadata/v1/instance/network-interfaces/0/ip', + headers={ + 'Metadata-Flavor': 'Google' + }).text + + +def setup_env_vars(): + """Set environment variables.""" + worker_id = get_metadata('agent-worker-number') + accelerator_type = get_metadata('accelerator-type') + worker_network_endpoints = get_metadata('worker-network-endpoints') + os.environ['TPU_STDERR_LOG_LEVEL'] = '0' + os.environ['CLOUD_TPU_TASK_ID'] = worker_id + os.environ['TPU_LOCK_DEVICE'] = 'true' + os.environ['TPU_CHIPS_PER_HOST_BOUNDS'] = '2,2,1' + accelerator_type_to_host_bounds = { + # v2 + 'v2-8': '1,1,1', + 'v2-32': '2,2,1', + 'v2-128': '4,4,1', + 'v2-256': '4,8,1', + 'v2-512': '8,8,1', + # v3 + 'v3-8': '1,1,1', + 'v3-32': '2,2,1', + 'v3-64': '2,4,1', + 'v3-128': '4,4,1', + 'v3-256': '4,8,1', + 'v3-512': '8,8,1', + 'v3-1024': '8,16,1', + 'v3-2048': '16,16,1', + # v4 + 'v4-8': '1,1,1', + 'v4-16': '1,1,2', + 'v4-32': '1,1,4', + 'v4-64': '1,2,4', + 'v4-128': '2,2,4', + 'v4-256': '2,2,8', + 'v4-512': '2,4,8', + 'v4-1024': '4,4,8', + 'v4-2048': '4,4,16', + 'v4-4096': '4,8,16', + } + + os.environ['TPU_HOST_BOUNDS'] = accelerator_type_to_host_bounds[ + accelerator_type] + os.environ['TPU_MESH_CONTROLLER_ADDRESS'] = worker_network_endpoints.split( + ',')[0].split(':')[2] + ':8476' + os.environ['TPU_MESH_CONTROLLER_PORT'] = '8476' + + os.environ['TPU_STDERR_LOG_LEVEL'] = '0' + + if accelerator_type not in ['v4-8', 'v4-16', 'v4-32', 'v4-64']: + os.environ['TPU_TOPOLOGY_WRAP'] = 'true,true,true' + + # Set the hostname override. + os.environ['TPU_HOSTNAME_OVERRIDE'] = get_host_ip() + + +def main(unused_args): + # Create Protobuf ServerDef. + server_def = tensorflow_server_pb2.ServerDef(protocol='grpc') + job_def = server_def.cluster.job.add() + job_def.name = 'tpu_worker' + job_def.tasks[0] = 'localhost:8470' + server_def.job_name = 'tpu_worker' + server_def.task_index = 0 + + config = config_pb2.ConfigProto() + + # Create GRPC Server instance + server = server_lib.Server(server_def, config=config) + + # join() is blocking, unlike start() + server.join() + + +def run(): + parser = argparse.ArgumentParser() + + _, unparsed = parser.parse_known_args() + # Must set environment variables before importing tensorflow. + setup_env_vars() + from tensorflow.python.platform import app + app.run(main=main, argv=[sys.argv[0]] + unparsed) + + +if __name__ == '__main__': + run() diff --git a/tensorflow/python/tools/grpc_tpu_worker_service.py b/tensorflow/python/tools/grpc_tpu_worker_service.py new file mode 100644 index 00000000000000..9da28227492f07 --- /dev/null +++ b/tensorflow/python/tools/grpc_tpu_worker_service.py @@ -0,0 +1,97 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=g-import-not-at-top +"""Script to start GRPC worker as a service. + +Usage: + python3 grpc_tpu_worker_service.py +""" + +import os +import subprocess +import sys + + +def get_sys_path(): + return ":".join(sys.path) + + +def get_username(): + return os.environ.get("USER") + + +username = get_username() +sys_path = get_sys_path() + +SERVICE_FILE_CONTENT = f""" +[Unit] +Description=GRPC TPU Worker Service +After=network.target +[Service] +Type=simple +Environment="PYTHONPATH=$PYTHONPATH{sys_path}" +EnvironmentFile=/home/tpu-runtime/tpu-env +#ExecStartPre=/bin/mkdir -p /tmp/tflogs +#ExecStartPre=/bin/touch /tmp/tflogs/grpc_tpu_worker.log +#ExecStartPre=/bin/chmod +r /tmp/tflogs +ExecStart=/home/{get_username()}/.local/bin/start_grpc_tpu_worker #2>&1 | tee -a /tmp/tflogs/grpc_tpu_worker.log +Restart=on-failure +# Restart service after 10 seconds if the service crashes: +RestartSec=10 +[Install] +WantedBy=multi-user.target +""" +SERVICE_NAME = "grpc_tpu_worker.service" + + +def create_systemd_service_file(service_content, service_name): + with open(service_name, "w") as file: + file.write(service_content) + print(f"Service file {service_name} created") + + +def move_file_to_systemd(service_name): + if not os.path.exists("~/.config/systemd/user/"): + mkdir_command = "mkdir -p ~/.config/systemd/user" + subprocess.run(mkdir_command, shell=True, check=True) + print("Created directory ~/.config/systemd/user/") + command = f"mv {service_name} ~/.config/systemd/user/{service_name}" + subprocess.run(command, shell=True, check=True) + print(f"Service file moved to ~/.config/systemd/user/{service_name}") + + +def enable_start_service(service_name): + commands = [ + "systemctl --user daemon-reload", + f"systemctl --user enable {service_name}", + f"systemctl --user start {service_name}", + ] + for command in commands: + subprocess.run(command, shell=True, check=True) + print(f"Executed: {command}") + + +def run(): + if os.path.exists(f"~/.config/systemd/user/{SERVICE_NAME}"): + print(f"Service file ~/.config/systemd/user/{SERVICE_NAME} already exists") + sys.exit(1) + else: + create_systemd_service_file(SERVICE_FILE_CONTENT, SERVICE_NAME) + move_file_to_systemd(SERVICE_NAME) + enable_start_service(SERVICE_NAME) + + +if __name__ == "__main__": + run() diff --git a/tensorflow/python/tpu/tpu.py b/tensorflow/python/tpu/tpu.py index cc830a68dc427d..a38bd9f881ee09 100644 --- a/tensorflow/python/tpu/tpu.py +++ b/tensorflow/python/tpu/tpu.py @@ -1513,6 +1513,7 @@ def under_tpu_inference_context() -> bool: graph = graph.outer_graph else: return False + return False class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext): diff --git a/tensorflow/python/tpu/tpu_embedding_v2_utils.py b/tensorflow/python/tpu/tpu_embedding_v2_utils.py index 0ea650b336c630..98978846b9103d 100644 --- a/tensorflow/python/tpu/tpu_embedding_v2_utils.py +++ b/tensorflow/python/tpu/tpu_embedding_v2_utils.py @@ -240,7 +240,7 @@ def __init__( clip_weight_min: Optional[float] = None, clip_weight_max: Optional[float] = None, weight_decay_factor: Optional[float] = None, - multiply_weight_decay_factor_by_learning_rate: bool = None, + multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, clipvalue: Optional[ClipValueType] = None, low_dimensional_packing_status: bool = False, ): @@ -357,7 +357,7 @@ def __init__( clip_weight_min: Optional[float] = None, clip_weight_max: Optional[float] = None, weight_decay_factor: Optional[float] = None, - multiply_weight_decay_factor_by_learning_rate: bool = None, + multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, slot_variable_creation_fn: Optional[SlotVarCreationFnType] = None, clipvalue: Optional[ClipValueType] = None, low_dimensional_packing_status: bool = False, @@ -490,7 +490,7 @@ def __init__( clip_weight_min: Optional[float] = None, clip_weight_max: Optional[float] = None, weight_decay_factor: Optional[float] = None, - multiply_weight_decay_factor_by_learning_rate: bool = None, + multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, slot_variable_creation_fn: Optional[SlotVarCreationFnType] = None, clipvalue: Optional[ClipValueType] = None, low_dimensional_packing_status: bool = False, @@ -640,7 +640,7 @@ def __init__( clip_weight_min: Optional[float] = None, clip_weight_max: Optional[float] = None, weight_decay_factor: Optional[float] = None, - multiply_weight_decay_factor_by_learning_rate: bool = None, + multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, slot_variable_creation_fn: Optional[SlotVarCreationFnType] = None, clipvalue: Optional[ClipValueType] = None, multiply_linear_by_learning_rate: bool = False, @@ -815,7 +815,7 @@ def __init__( clip_weight_min: Optional[float] = None, clip_weight_max: Optional[float] = None, weight_decay_factor: Optional[float] = None, - multiply_weight_decay_factor_by_learning_rate: bool = None, + multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, slot_variable_creation_fn: Optional[SlotVarCreationFnType] = None, clipvalue: Optional[ClipValueType] = None, low_dimensional_packing_status: bool = False, diff --git a/tensorflow/python/trackable/data_structures.py b/tensorflow/python/trackable/data_structures.py index 4cb9904f275c21..989549c0b65e77 100644 --- a/tensorflow/python/trackable/data_structures.py +++ b/tensorflow/python/trackable/data_structures.py @@ -818,7 +818,13 @@ def __getattribute__(self, name): # of the wrapper without this logic. return object.__getattribute__(self, name) else: - return super().__getattribute__(name) + # Raise TypeError as AttributeError to fix breakage in wrapt 1.15 for + # `__getattribute__` as suggested in discussion with library author in + # GitHub https://github.com/GrahamDumpleton/wrapt/issues/231 + try: + return super().__getattribute__(name) + except TypeError as e: + raise AttributeError from e def copy(self): return copy.copy(self) diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 6f98d8ab55622a..13fd78d1ebf68f 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -82,7 +82,7 @@ def register_extension_info(**kwargs): # not contain rc or alpha, only numbers. # Also update tensorflow/core/public/version.h # and tensorflow/tools/pip_package/setup.py -VERSION = "2.16.0" +VERSION = "2.17.0" VERSION_MAJOR = VERSION.split(".")[0] two_gpu_tags = ["requires-gpu-nvidia:2", "manual", "no_pip"] @@ -2547,7 +2547,7 @@ def py_test(deps = [], data = [], kernels = [], exec_properties = None, test_rul }), data = data + select({ "//conditions:default": kernels, - clean_dep("//tensorflow:no_tensorflow_py_deps"): ["//tensorflow/tools/pip_package:win_pip_package_marker"], + clean_dep("//tensorflow:no_tensorflow_py_deps"): [], }), exec_properties = exec_properties, **kwargs diff --git a/tensorflow/tf_framework_version_script.lds b/tensorflow/tf_framework_version_script.lds index 5a62eb9060f054..3e0dca3e587610 100644 --- a/tensorflow/tf_framework_version_script.lds +++ b/tensorflow/tf_framework_version_script.lds @@ -1,13 +1,5 @@ VERS_1.0 { # Hide libjpeg symbols to avoid symbol conflict with OpenCV local: - jpeg_*; - jinit_*; - jdiv_round_up; - jround_up; - jzero_far; - jcopy_*; - jsimd_*; - hwloc_*; *mlir*; }; diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-variable.-save-slice-info.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-variable.-save-slice-info.pbtxt index ac3ccd468b216a..612803d0358fc2 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-variable.-save-slice-info.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-variable.-save-slice-info.pbtxt @@ -10,6 +10,10 @@ tf_class { name: "__init__" argspec: "args=[\'self\', \'full_name\', \'full_shape\', \'var_offset\', \'var_shape\', \'save_slice_info_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "from_spec" + argspec: "args=[\'cls\', \'spec\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "to_proto" argspec: "args=[\'self\', \'export_scope\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-options.pbtxt index 3bc59721e5ccc2..141e749bb3242f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-options.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-options.pbtxt @@ -47,6 +47,10 @@ tf_class { name: "experimental_warm_start" mtype: "" } + member { + name: "framework_type" + mtype: "" + } member { name: "threading" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-autotune-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-autotune-options.pbtxt index 91a0a2f6c57c7f..4c9a8ccfe98730 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-autotune-options.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-autotune-options.pbtxt @@ -15,6 +15,10 @@ tf_class { name: "enabled" mtype: "" } + member { + name: "initial_parallelism" + mtype: "" + } member { name: "ram_budget" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-optimization-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-optimization-options.pbtxt index b4ce611b3ff696..6e8bac7c375f2e 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-optimization-options.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-optimization-options.pbtxt @@ -43,6 +43,10 @@ tf_class { name: "parallel_batch" mtype: "" } + member { + name: "seq_interleave_prefetch" + mtype: "" + } member { name: "shuffle_and_repeat_fusion" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index ba8a37dac5e4f0..4829cda77f6b2c 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -1944,6 +1944,10 @@ tf_module { name: "GlobalIterId" argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "GlobalShuffleDataset" + argspec: "args=[\'input_dataset\', \'seed\', \'seed2\', \'seed_generator\', \'output_types\', \'output_shapes\', \'reshuffle_each_iteration\', \'metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'\', \'None\'], " + } member_method { name: "Greater" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.saved_model.-save-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.-save-options.pbtxt index 99aeb4d5057ef9..44764be213ac66 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.saved_model.-save-options.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.-save-options.pbtxt @@ -6,6 +6,10 @@ tf_class { name: "experimental_custom_gradients" mtype: "" } + member { + name: "experimental_debug_stripper" + mtype: "" + } member { name: "experimental_image_format" mtype: "" @@ -40,6 +44,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'namespace_whitelist\', \'save_debug_info\', \'function_aliases\', \'experimental_io_device\', \'experimental_variable_policy\', \'experimental_custom_gradients\', \'experimental_image_format\', \'experimental_skip_saver\', \'experimental_sharding_callback\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'None\', \'None\', \'True\', \'False\', \'False\', \'None\'], " + argspec: "args=[\'self\', \'namespace_whitelist\', \'save_debug_info\', \'function_aliases\', \'experimental_debug_stripper\', \'experimental_io_device\', \'experimental_variable_policy\', \'experimental_custom_gradients\', \'experimental_image_format\', \'experimental_skip_saver\', \'experimental_sharding_callback\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\', \'None\', \'None\', \'True\', \'False\', \'False\', \'None\'], " } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.-max-shard-size-policy.-max-shard-size-partitioner.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.-max-shard-size-policy.-max-shard-size-partitioner.pbtxt new file mode 100644 index 00000000000000..f546ec054e340c --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.-max-shard-size-policy.-max-shard-size-partitioner.pbtxt @@ -0,0 +1,12 @@ +path: "tensorflow.train.experimental.MaxShardSizePolicy.MaxShardSizePartitioner" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "get_shards" + argspec: "args=[\'self\', \'max_shard_size\', \'shardable_tensors\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.-max-shard-size-policy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.-max-shard-size-policy.pbtxt index eeb8a04569157a..d2e902bc6d235d 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.-max-shard-size-policy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.-max-shard-size-policy.pbtxt @@ -3,6 +3,10 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + member { + name: "MaxShardSizePartitioner" + mtype: "" + } member { name: "description" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-variable.-save-slice-info.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-variable.-save-slice-info.pbtxt index ac3ccd468b216a..612803d0358fc2 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-variable.-save-slice-info.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-variable.-save-slice-info.pbtxt @@ -10,6 +10,10 @@ tf_class { name: "__init__" argspec: "args=[\'self\', \'full_name\', \'full_shape\', \'var_offset\', \'var_shape\', \'save_slice_info_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "from_spec" + argspec: "args=[\'cls\', \'spec\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "to_proto" argspec: "args=[\'self\', \'export_scope\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-options.pbtxt index 3bc59721e5ccc2..141e749bb3242f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-options.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-options.pbtxt @@ -47,6 +47,10 @@ tf_class { name: "experimental_warm_start" mtype: "" } + member { + name: "framework_type" + mtype: "" + } member { name: "threading" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-autotune-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-autotune-options.pbtxt index 91a0a2f6c57c7f..4c9a8ccfe98730 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-autotune-options.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-autotune-options.pbtxt @@ -15,6 +15,10 @@ tf_class { name: "enabled" mtype: "" } + member { + name: "initial_parallelism" + mtype: "" + } member { name: "ram_budget" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-optimization-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-optimization-options.pbtxt index b4ce611b3ff696..6e8bac7c375f2e 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-optimization-options.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-optimization-options.pbtxt @@ -43,6 +43,10 @@ tf_class { name: "parallel_batch" mtype: "" } + member { + name: "seq_interleave_prefetch" + mtype: "" + } member { name: "shuffle_and_repeat_fusion" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.experimental.dtensor.-d-variable.-save-slice-info.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.experimental.dtensor.-d-variable.-save-slice-info.pbtxt index 790fdc05122e38..ebd596ecb91165 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.experimental.dtensor.-d-variable.-save-slice-info.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.experimental.dtensor.-d-variable.-save-slice-info.pbtxt @@ -10,6 +10,10 @@ tf_class { name: "__init__" argspec: "args=[\'self\', \'full_name\', \'full_shape\', \'var_offset\', \'var_shape\', \'save_slice_info_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "from_spec" + argspec: "args=[\'cls\', \'spec\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "to_proto" argspec: "args=[\'self\', \'export_scope\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index ba8a37dac5e4f0..4829cda77f6b2c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -1944,6 +1944,10 @@ tf_module { name: "GlobalIterId" argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "GlobalShuffleDataset" + argspec: "args=[\'input_dataset\', \'seed\', \'seed2\', \'seed_generator\', \'output_types\', \'output_shapes\', \'reshuffle_each_iteration\', \'metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'\', \'None\'], " + } member_method { name: "Greater" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.saved_model.-save-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.-save-options.pbtxt index 99aeb4d5057ef9..44764be213ac66 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.saved_model.-save-options.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.-save-options.pbtxt @@ -6,6 +6,10 @@ tf_class { name: "experimental_custom_gradients" mtype: "" } + member { + name: "experimental_debug_stripper" + mtype: "" + } member { name: "experimental_image_format" mtype: "" @@ -40,6 +44,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'namespace_whitelist\', \'save_debug_info\', \'function_aliases\', \'experimental_io_device\', \'experimental_variable_policy\', \'experimental_custom_gradients\', \'experimental_image_format\', \'experimental_skip_saver\', \'experimental_sharding_callback\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'None\', \'None\', \'True\', \'False\', \'False\', \'None\'], " + argspec: "args=[\'self\', \'namespace_whitelist\', \'save_debug_info\', \'function_aliases\', \'experimental_debug_stripper\', \'experimental_io_device\', \'experimental_variable_policy\', \'experimental_custom_gradients\', \'experimental_image_format\', \'experimental_skip_saver\', \'experimental_sharding_callback\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\', \'None\', \'None\', \'True\', \'False\', \'False\', \'None\'], " } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.-max-shard-size-policy.-max-shard-size-partitioner.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.-max-shard-size-policy.-max-shard-size-partitioner.pbtxt new file mode 100644 index 00000000000000..f546ec054e340c --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.-max-shard-size-policy.-max-shard-size-partitioner.pbtxt @@ -0,0 +1,12 @@ +path: "tensorflow.train.experimental.MaxShardSizePolicy.MaxShardSizePartitioner" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "get_shards" + argspec: "args=[\'self\', \'max_shard_size\', \'shardable_tensors\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.-max-shard-size-policy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.-max-shard-size-policy.pbtxt index eeb8a04569157a..d2e902bc6d235d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.-max-shard-size-policy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.-max-shard-size-policy.pbtxt @@ -3,6 +3,10 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + member { + name: "MaxShardSizePartitioner" + mtype: "" + } member { name: "description" mtype: "" diff --git a/tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS_EXTENDED.sh b/tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS_EXTENDED.sh index 09c56cbdcc48bf..63b34e895ef5ef 100644 --- a/tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS_EXTENDED.sh +++ b/tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS_EXTENDED.sh @@ -18,6 +18,4 @@ set -x source tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS.sh ARM_SKIP_TESTS="${ARM_SKIP_TESTS} \ --//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu \ --//tensorflow/core/grappler/optimizers:remapper_test_cpu \ " diff --git a/tensorflow/tools/ci_build/release/common.sh b/tensorflow/tools/ci_build/release/common.sh index 741bc643409102..340db2d9a529bf 100644 --- a/tensorflow/tools/ci_build/release/common.sh +++ b/tensorflow/tools/ci_build/release/common.sh @@ -20,10 +20,10 @@ LATEST_BAZEL_VERSION=6.5.0 # LINT.ThenChange( # //tensorflow/opensource_only/.bazelversion, +# //tensorflow/opensource_only/ci/official/requirements_updater/.bazelversion # //tensorflow/tools/ci_build/install/install_bazel.sh, # //tensorflow/tools/ci_build/install/install_bazel_from_source.sh, -# //tensorflow/tools/toolchains/cross_compile/cc/cc_toolchain_config.bzl, -# //tensorflow_estimator/google/kokoro/common.sh) +# //tensorflow/tools/toolchains/cross_compile/cc/cc_toolchain_config.bzl) # Run flaky functions with retries. # run_with_retry cmd @@ -106,8 +106,7 @@ function update_bazel_linux { which bazel bazel version } -# LINT.ThenChange( -# //tensorflow_estimator/google/kokoro/common.sh) +# LINT.ThenChange() function install_ubuntu_16_pip_deps { PIP_CMD="pip" diff --git a/tensorflow/tools/optimization/BUILD b/tensorflow/tools/optimization/BUILD index fc12208932e00b..6c2b53f0da4dc8 100644 --- a/tensorflow/tools/optimization/BUILD +++ b/tensorflow/tools/optimization/BUILD @@ -28,6 +28,9 @@ tf_cuda_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:tensorflow", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", ], ) @@ -39,14 +42,12 @@ tf_cc_binary( "//tensorflow/compiler/jit:xla_cpu_jit", "//tensorflow/compiler/jit:xla_gpu_jit", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_base", - "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:framework_lite", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:tensorflow", - "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:status", ], ) diff --git a/tensorflow/tools/optimization/gpu_optimization_pass_runner_main.cc b/tensorflow/tools/optimization/gpu_optimization_pass_runner_main.cc index 563b1d5197cb71..ad2856332dd0a0 100644 --- a/tensorflow/tools/optimization/gpu_optimization_pass_runner_main.cc +++ b/tensorflow/tools/optimization/gpu_optimization_pass_runner_main.cc @@ -17,14 +17,16 @@ limitations under the License. // --output_file_path=/tmp/output.pbtxt // --optimization_pass=NameOfGraphOptimizationPass -#include "absl/strings/str_cat.h" -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/tools/optimization/optimization_pass_runner.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status.h" namespace tensorflow { namespace { diff --git a/tensorflow/tools/optimization/optimization_pass_runner.cc b/tensorflow/tools/optimization/optimization_pass_runner.cc index 33bd9caf37452a..502977c4f72eb8 100644 --- a/tensorflow/tools/optimization/optimization_pass_runner.cc +++ b/tensorflow/tools/optimization/optimization_pass_runner.cc @@ -23,11 +23,12 @@ limitations under the License. #include #include -#include "tensorflow/core/common_runtime/device.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/optimization_registry.h" -#include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.pb.h" @@ -37,9 +38,10 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/public/session_options.h" +#include "tsl/platform/errors.h" namespace tensorflow { namespace { diff --git a/tensorflow/tools/optimization/optimization_pass_runner.h b/tensorflow/tools/optimization/optimization_pass_runner.h index 0f7e3cfb56461d..0b96ce3e5a9d47 100644 --- a/tensorflow/tools/optimization/optimization_pass_runner.h +++ b/tensorflow/tools/optimization/optimization_pass_runner.h @@ -19,8 +19,11 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/framework/device.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/protobuf/config.pb.h" diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index 0209ada2f68c74..91be8d17531433 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -3,7 +3,7 @@ load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_syslibs//:build_defs.bzl", "if_not_system_lib") -load("//tensorflow:tensorflow.bzl", "filegroup_as_file", "transitive_hdrs") +load("//tensorflow:tensorflow.bzl", "filegroup_as_file", "if_with_tpu_support", "transitive_hdrs") load("//tensorflow/core/platform:build_config_root.bzl", "tf_additional_license_deps") load("//third_party/mkl:build_defs.bzl", "if_enable_mkl", "if_mkl", "if_mkl_ml") @@ -197,7 +197,10 @@ COMMON_PIP_DEPS = [ "//tensorflow/python/distribute/experimental/rpc:rpc_ops", "//tensorflow/python/util:pywrap_xla_ops", "//tensorflow:tensorflow_py", -] +] + if_with_tpu_support([ + "//tensorflow/python/tools:grpc_tpu_worker", + "//tensorflow/python/tools:grpc_tpu_worker_service", +]) filegroup( name = "licenses", diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index e6a52936a47f1c..a3bf69c4e3a285 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -48,7 +48,7 @@ # result for pip. # Also update tensorflow/tensorflow.bzl and # tensorflow/core/public/version.h -_VERSION = '2.16.0' +_VERSION = '2.17.0' # We use the same setup.py for all tensorflow_* packages and for the nightly @@ -105,7 +105,7 @@ def standard_or_nightly(standard, nightly): 'six >= 1.12.0', 'termcolor >= 1.1.0', 'typing_extensions >= 3.6.6', - 'wrapt >= 1.11.0, < 1.15', + 'wrapt >= 1.11.0', # TODO(b/305196096): Remove the <3.12 condition once the pkg is updated 'tensorflow-io-gcs-filesystem >= 0.23.1 ; python_version < "3.12"', # grpcio does not build correctly on big-endian machines due to lack of @@ -121,8 +121,8 @@ def standard_or_nightly(standard, nightly): # dependencies on the release branch is updated to the stable releases (RC # or final). For example, 'keras-nightly ~= 2.14.0.dev' will be replaced by # 'keras >= 2.14.0rc0, < 2.15' on the release branch after the branch cut. - 'tb-nightly ~= 2.16.0.a', - 'keras-nightly ~= 3.0.0.dev', + 'tb-nightly ~= 2.17.0.a', + 'keras-nightly ~= 3.1.0.dev', ] REQUIRED_PACKAGES = [p for p in REQUIRED_PACKAGES if p is not None] @@ -155,9 +155,15 @@ def standard_or_nightly(standard, nightly): # Windows machine. standard_or_nightly('tensorflow-intel', 'tf-nightly-intel') + '==' + _VERSION + ';platform_system=="Windows"', - # Install the TensorFlow package built by Apple if the user is running - # macOS on an Apple Silicon machine. - standard_or_nightly('tensorflow-macos', 'tf-nightly-macos') + '==' + + # Starting with TF 2.16, Apple Silicon packages are uploaded directly + # to the "tensorflow" project on PyPI. In order to not break users who + # are still using `tensorflow-macos`, we upload an empty installer wheel + # to "tensorflow-macos" and add "tensorflow" as its dependency. Please + # note that this will go away in TF 2.17 and `tensorflow-macos` will be + # considered deprecated. Installer packages are not uploaded to + # `tf-nightly-macos`, `tf-nightly` is added below only to avoid breaking + # CI builds. + standard_or_nightly('tensorflow', 'tf-nightly') + '==' + _VERSION + ';platform_system=="Darwin" and platform_machine=="arm64"', ] @@ -249,6 +255,10 @@ def finalize_options(self): def mkdir_and_copy_file(self, header): install_dir = os.path.join(self.install_dir, os.path.dirname(header)) + # Windows platform uses "\" in path strings, the external header location + # expects "/" in paths. Hence, we replaced "\" with "/" for this reason + if platform.system() == 'Windows': + install_dir = install_dir.replace('\\', '/') # Get rid of some extra intervening directories so we can have fewer # directories for -I install_dir = re.sub('/google/protobuf_archive/src', '', install_dir) @@ -320,19 +330,27 @@ def find_files(pattern, root): # $ pip install -f https://storage.googleapis.com/libtpu-releases/index.html # libtpu is built and uploaded to this link every night (PST). if '_tpu' in project_name: - # For tensorflow-tpu releases, use a set libtpu-nightly version; + # For tensorflow-tpu releases, use a set libtpu version; # For tf-nightly-tpu, use the most recent libtpu-nightly. Because of the # timing of these tests, the UTC date from eight hours ago is expected to be a # valid version. _libtpu_version = standard_or_nightly( - '0.1.dev20231018', + '2.16.0rc0', '0.1.dev' + ( datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(hours=8) ).strftime('%Y%m%d'), ) - REQUIRED_PACKAGES.append([f'libtpu-nightly=={_libtpu_version}']) + if _libtpu_version.startswith('0.1'): + REQUIRED_PACKAGES.append([f'libtpu-nightly=={_libtpu_version}']) + else: + REQUIRED_PACKAGES.append([f'libtpu=={_libtpu_version}']) + CONSOLE_SCRIPTS.extend([ + 'start_grpc_tpu_worker = tensorflow.python.tools.grpc_tpu_worker:run', + ('start_grpc_tpu_service = ' + 'tensorflow.python.tools.grpc_tpu_worker_service:run'), + ]) if os.name == 'nt': EXTENSION_NAME = 'python/_pywrap_tensorflow_internal.pyd' diff --git a/tensorflow/tools/pip_package/v2/BUILD b/tensorflow/tools/pip_package/v2/BUILD index 4f63f8c37d772f..99996a4e8fda25 100644 --- a/tensorflow/tools/pip_package/v2/BUILD +++ b/tensorflow/tools/pip_package/v2/BUILD @@ -3,7 +3,7 @@ load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_syslibs//:build_defs.bzl", "if_not_system_lib") -load("//tensorflow:tensorflow.bzl", "transitive_hdrs") +load("//tensorflow:tensorflow.bzl", "if_with_tpu_support", "transitive_hdrs") load("//tensorflow/core/platform:build_config_root.bzl", "tf_additional_license_deps") load("//third_party/mkl:build_defs.bzl", "if_enable_mkl", "if_mkl", "if_mkl_ml") load("//tensorflow/tools/pip_package/v2/utils:data_deps.bzl", "collect_data_files") @@ -226,7 +226,10 @@ COMMON_PIP_DEPS = [ "//tensorflow/python/util:pywrap_xla_ops", "//tensorflow:tensorflow_py", "//tensorflow/tools/compatibility:tf_upgrade_v2", -] +] + if_with_tpu_support([ + "//tensorflow/python/tools:grpc_tpu_worker", + "//tensorflow/python/tools:grpc_tpu_worker_service", +]) py_binary( name = "build_pip_package_py", diff --git a/tensorflow/tools/pip_package/v2/setup.py b/tensorflow/tools/pip_package/v2/setup.py index cd5525963eabe3..9ad5b0c34c2698 100644 --- a/tensorflow/tools/pip_package/v2/setup.py +++ b/tensorflow/tools/pip_package/v2/setup.py @@ -48,7 +48,7 @@ # result for pip. # Also update tensorflow/tensorflow.bzl and # tensorflow/core/public/version.h -_VERSION = '2.16.0' +_VERSION = '2.17.0' # We use the same setup.py for all tensorflow_* packages and for the nightly @@ -98,7 +98,7 @@ def standard_or_nightly(standard, nightly): 'six >= 1.12.0', 'termcolor >= 1.1.0', 'typing_extensions >= 3.6.6', - 'wrapt >= 1.11.0, < 1.15', + 'wrapt >= 1.11.0', # TODO(b/305196096): Remove the <3.12 condition once the pkg is updated 'tensorflow-io-gcs-filesystem >= 0.23.1 ; python_version < "3.12"', # grpcio does not build correctly on big-endian machines due to lack of @@ -114,8 +114,8 @@ def standard_or_nightly(standard, nightly): # dependencies on the release branch is updated to the stable releases (RC # or final). For example, 'keras-nightly ~= 2.14.0.dev' will be replaced by # 'keras >= 2.14.0rc0, < 2.15' on the release branch after the branch cut. - 'tb-nightly ~= 2.16.0.a', - 'keras-nightly ~= 3.0.0.dev', + 'tb-nightly ~= 2.17.0.a', + 'keras-nightly ~= 3.1.0.dev', ] REQUIRED_PACKAGES = [p for p in REQUIRED_PACKAGES if p is not None] @@ -242,6 +242,9 @@ def finalize_options(self): def mkdir_and_copy_file(self, header): install_dir = os.path.join(self.install_dir, os.path.dirname(header)) + # Windows platform uses "\" in path strings, the external header location + # expects "/" in paths. Hence, we replaced "\" with "/" for this reason + install_dir = install_dir.replace('\\', '/') # Get rid of some extra intervening directories so we can have fewer # directories for -I install_dir = re.sub('/google/protobuf_archive/src', '', install_dir) @@ -314,19 +317,27 @@ def find_files(pattern, root): # https://storage.googleapis.com/libtpu-releases/index.html # libtpu is built and uploaded to this link every night (PST). if '_tpu' in project_name: - # For tensorflow-tpu releases, use a set libtpu-nightly version; + # For tensorflow-tpu releases, use a set libtpu version; # For tf-nightly-tpu, use the most recent libtpu-nightly. Because of the # timing of these tests, the UTC date from eight hours ago is expected to be a # valid version. _libtpu_version = standard_or_nightly( - '0.1.dev20231018', + '2.16.0rc0', '0.1.dev' + ( datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(hours=8) ).strftime('%Y%m%d'), ) - REQUIRED_PACKAGES.append([f'libtpu-nightly=={_libtpu_version}']) + if _libtpu_version.startswith('0.1'): + REQUIRED_PACKAGES.append([f'libtpu-nightly=={_libtpu_version}']) + else: + REQUIRED_PACKAGES.append([f'libtpu=={_libtpu_version}']) + CONSOLE_SCRIPTS.extend([ + 'start_grpc_tpu_worker = tensorflow.python.tools.grpc_tpu_worker:run', + ('start_grpc_tpu_service = ' + 'tensorflow.python.tools.grpc_tpu_worker_service:run'), + ]) if os.name == 'nt': EXTENSION_NAME = 'python/_pywrap_tensorflow_internal.pyd' diff --git a/tensorflow/tools/proto_splitter/cc/BUILD b/tensorflow/tools/proto_splitter/cc/BUILD index 1188ed94533864..cbaceb37e12720 100644 --- a/tensorflow/tools/proto_splitter/cc/BUILD +++ b/tensorflow/tools/proto_splitter/cc/BUILD @@ -23,6 +23,7 @@ cc_library( name = "split", hdrs = ["split.h"], deps = [ + ":util", "//tensorflow/tools/proto_splitter:chunk_proto_cc", "//tensorflow/tools/proto_splitter:versions_proto_cc", "@com_google_absl//absl/status", @@ -84,6 +85,7 @@ tf_cc_test( deps = [ ":composable_splitter", ":test_util", + ":util", "//tensorflow/core:lib", "//tensorflow/core:portable_gif_internal", "//tensorflow/core:protos_all_cc", @@ -123,6 +125,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:str_format", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:protobuf", @@ -237,6 +240,7 @@ tf_cc_test( ":graph_def_splitter", ":max_size", ":test_util", + ":util", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", @@ -284,6 +288,7 @@ tf_cc_test( deps = [ ":max_size", ":saved_model_splitter", + ":util", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", diff --git a/tensorflow/tools/proto_splitter/cc/composable_splitter_base.cc b/tensorflow/tools/proto_splitter/cc/composable_splitter_base.cc index b02c09c6fa8d62..e3010d2fcdd7c0 100644 --- a/tensorflow/tools/proto_splitter/cc/composable_splitter_base.cc +++ b/tensorflow/tools/proto_splitter/cc/composable_splitter_base.cc @@ -53,7 +53,7 @@ limitations under the License. namespace tensorflow { namespace tools::proto_splitter { -using ::proto_splitter::ChunkMetadata; +using ::tensorflow::proto_splitter::ChunkMetadata; VersionDef ComposableSplitterBase::Version() { VersionDef version; @@ -71,8 +71,7 @@ size_t ComposableSplitterBase::GetInitialSize() { return size_; } -absl::StatusOr*, ChunkedMessage*>> -ComposableSplitterBase::Split() { +absl::StatusOr ComposableSplitterBase::Split() { if (parent_splitter_ != nullptr) { return absl::UnimplementedError( "The `Split` function behavior for children ComposableSplitter has not " @@ -99,14 +98,15 @@ ComposableSplitterBase::Split() { << ". " << chunk_msg; built_ = true; } - return std::make_pair(&chunks_, &chunked_message_); + return (ChunkedProto){.chunks = &chunks_, + .chunked_message = &chunked_message_}; } template static absl::Status WriteToRecordWriter( riegeli::RecordWriter& writer, const std::vector& chunks, ChunkedMessage& chunked_message, - const ::proto_splitter::VersionDef& version) { + const ::tensorflow::proto_splitter::VersionDef& version) { // Export Riegeli / chunked file. ChunkMetadata metadata; *metadata.mutable_message() = chunked_message; @@ -122,17 +122,19 @@ static absl::Status WriteToRecordWriter( LOG(INFO) << "Writing chunk of size " << msg_chunk->ByteSizeLong(); writer.WriteRecord(*msg_chunk); chunk_metadata->set_size(msg_chunk->ByteSizeLong()); - chunk_metadata->set_type(::proto_splitter::ChunkInfo::MESSAGE); + chunk_metadata->set_type( + ::tensorflow::proto_splitter::ChunkInfo::MESSAGE); } else if (std::holds_alternative(chunk)) { auto* msg_chunk = std::get(chunk); writer.WriteRecord(*msg_chunk); chunk_metadata->set_size(msg_chunk->ByteSizeLong()); - chunk_metadata->set_type(::proto_splitter::ChunkInfo::MESSAGE); + chunk_metadata->set_type( + ::tensorflow::proto_splitter::ChunkInfo::MESSAGE); } else { const auto& str_chunk = std::get(chunk); writer.WriteRecord(str_chunk); chunk_metadata->set_size(str_chunk.size()); - chunk_metadata->set_type(::proto_splitter::ChunkInfo::BYTES); + chunk_metadata->set_type(::tensorflow::proto_splitter::ChunkInfo::BYTES); } chunk_metadata->set_offset(writer.LastPos().get().numeric()); } @@ -154,15 +156,16 @@ absl::Status ComposableSplitterBase::Write(std::string file_prefix) { auto split_results = Split(); if (!split_results.ok()) return split_results.status(); - auto& chunks = *split_results.value().first; - auto& chunked_message = *split_results.value().second; + + std::vector* chunks = split_results.value().chunks; + ChunkedMessage* chunked_message = split_results.value().chunked_message; tsl::Env* env = tsl::Env::Default(); TF_RETURN_IF_ERROR(env->RecursivelyCreateDir( std::string{tensorflow::io::Dirname(file_prefix)})); std::string output_path; - if (chunked_message.chunked_fields().empty()) { + if (chunked_message->chunked_fields().empty()) { // Export regular pb. output_path = absl::StrCat(file_prefix, ".pb"); TF_RETURN_IF_ERROR( @@ -174,7 +177,7 @@ absl::Status ComposableSplitterBase::Write(std::string file_prefix) { riegeli::RecordWriter writer((WriterType(output_path))); if (!writer.is_open()) return writer.status(); TF_RETURN_IF_ERROR(WriteToRecordWriter( - writer, chunks, chunked_message, Version())); + writer, *chunks, *chunked_message, Version())); if (!writer.Close()) return writer.status(); } LOG(INFO) << "Splitter output written to " << output_path; @@ -187,11 +190,11 @@ ComposableSplitterBase::WriteToString() { auto split_results = Split(); if (!split_results.ok()) return split_results.status(); - auto& chunks = *split_results.value().first; - auto& chunked_message = *split_results.value().second; + std::vector* chunks = split_results.value().chunks; + ChunkedMessage* chunked_message = split_results.value().chunked_message; std::string output; - if (chunked_message.chunked_fields().empty()) { + if (chunked_message->chunked_fields().empty()) { // Export regular pb. if (!message_->SerializeToString(&output)) return absl::InvalidArgumentError("Serialization to string failed"); @@ -203,7 +206,7 @@ ComposableSplitterBase::WriteToString() { riegeli::RecordWriter writer((WriterType(&output))); if (!writer.is_open()) return writer.status(); TF_RETURN_IF_ERROR(WriteToRecordWriter( - writer, chunks, chunked_message, Version())); + writer, *chunks, *chunked_message, Version())); if (!writer.Close()) return writer.status(); LOG(INFO) << "Splitter output written to string"; return std::make_tuple(output, true); @@ -217,11 +220,11 @@ ComposableSplitterBase::WriteToCord() { auto split_results = Split(); if (!split_results.ok()) return split_results.status(); - auto& chunks = *split_results.value().first; - auto& chunked_message = *split_results.value().second; + std::vector* chunks = split_results.value().chunks; + ChunkedMessage* chunked_message = split_results.value().chunked_message; absl::Cord output; - if (chunked_message.chunked_fields().empty()) { + if (chunked_message->chunked_fields().empty()) { // Export regular pb. if (!message_->SerializeToCord(&output)) return absl::InvalidArgumentError("Serialization to absl::Cord failed"); @@ -233,7 +236,7 @@ ComposableSplitterBase::WriteToCord() { riegeli::RecordWriter writer((WriterType(&output))); if (!writer.is_open()) return writer.status(); TF_RETURN_IF_ERROR(WriteToRecordWriter( - writer, chunks, chunked_message, Version())); + writer, *chunks, *chunked_message, Version())); if (!writer.Close()) return writer.status(); LOG(INFO) << "Splitter output written to absl::Cord"; return std::make_tuple(output, true); diff --git a/tensorflow/tools/proto_splitter/cc/composable_splitter_base.h b/tensorflow/tools/proto_splitter/cc/composable_splitter_base.h index a37a3c61ca0a02..611f075723f43d 100644 --- a/tensorflow/tools/proto_splitter/cc/composable_splitter_base.h +++ b/tensorflow/tools/proto_splitter/cc/composable_splitter_base.h @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/status/status.h" @@ -56,9 +55,7 @@ class ComposableSplitterBase : public Splitter { // ChunkedMessage: Metadata about the chunked fields.) // If the message is not split, `chunks` should only contain the original // message. - absl::StatusOr< - std::pair*, ::proto_splitter::ChunkedMessage*>> - Split() override; + absl::StatusOr Split() override; // Serializes a proto to disk. // The writer writes all chunks into a Riegeli file. The chunk metadata @@ -109,7 +106,7 @@ class ComposableSplitterBase : public Splitter { bool built_; tsl::protobuf::Message* message_; std::vector chunks_; - ::proto_splitter::ChunkedMessage chunked_message_; + ::tensorflow::proto_splitter::ChunkedMessage chunked_message_; ComposableSplitterBase* parent_splitter_; std::vector* fields_in_parent_; size_t size_ = 0; diff --git a/tensorflow/tools/proto_splitter/cc/composable_splitter_test.cc b/tensorflow/tools/proto_splitter/cc/composable_splitter_test.cc index 8efdf36caee628..59e99ff4a3ec98 100644 --- a/tensorflow/tools/proto_splitter/cc/composable_splitter_test.cc +++ b/tensorflow/tools/proto_splitter/cc/composable_splitter_test.cc @@ -25,6 +25,8 @@ limitations under the License. #include #include "absl/status/status.h" #include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "riegeli/bytes/cord_reader.h" // from @riegeli #include "riegeli/bytes/fd_reader.h" // from @riegeli #include "riegeli/bytes/string_reader.h" // from @riegeli @@ -35,6 +37,7 @@ limitations under the License. #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/tools/proto_splitter/cc/test_util.h" +#include "tensorflow/tools/proto_splitter/cc/util.h" #include "tensorflow/tools/proto_splitter/chunk.pb.h" #include "tensorflow/tools/proto_splitter/testdata/test_message.pb.h" #include "tsl/lib/core/status_test_util.h" @@ -48,10 +51,10 @@ namespace tensorflow { namespace tools::proto_splitter { namespace { -using ::proto_splitter::ChunkedMessage; -using ::proto_splitter::ChunkMetadata; -using ::proto_splitter_testdata::RepeatedRepeatedString; -using ::proto_splitter_testdata::RepeatedString; +using ::tensorflow::proto_splitter::ChunkedMessage; +using ::tensorflow::proto_splitter::ChunkMetadata; +using ::tensorflow::proto_splitter_testdata::RepeatedRepeatedString; +using ::tensorflow::proto_splitter_testdata::RepeatedString; using ::testing::HasSubstr; using ::testing::SizeIs; using tsl::testing::StatusIs; @@ -95,13 +98,14 @@ TEST(RepeatedStringSplitterTest, TestSplitChunks) { auto message = SetUpRepeatedString(strings); RepeatedStringSplitter splitter = RepeatedStringSplitter(&message); TF_ASSERT_OK_AND_ASSIGN(auto ret, splitter.Split()); - auto chunks = ret.first; - auto chunked_message = ret.second; + std::vector* chunks = ret.chunks; + ASSERT_NE(chunks, nullptr); + ChunkedMessage* chunked_message = ret.chunked_message; + ASSERT_NE(chunked_message, nullptr); for (int i = 0; i < chunks->size(); i++) { - auto chunk = chunks->at(i); - EXPECT_TRUE(std::holds_alternative(chunk)); - EXPECT_EQ(strings[i], std::get(chunk)); + MessageBytes chunk = (*chunks)[i]; + EXPECT_THAT(chunk, ::testing::VariantWith(strings[i])); } EXPECT_THAT(*chunked_message, EqualsProto(R"pb(chunked_fields { field_tag { field: 1 } @@ -121,8 +125,8 @@ TEST(RepeatedStringSplitterTest, TestSplitChunks) { // Calling split again should return the same chunks/ChunkedMessage. TF_ASSERT_OK_AND_ASSIGN(auto ret2, splitter.Split()); - auto chunks2 = ret2.first; - auto chunked_message2 = ret2.second; + std::vector* chunks2 = ret2.chunks; + ChunkedMessage* chunked_message2 = ret2.chunked_message; EXPECT_EQ(chunks2, chunks); EXPECT_EQ(chunked_message2, chunked_message); } @@ -135,7 +139,7 @@ static void CheckChunks(riegeli::RecordReader& reader, reader.SeekBack(); reader.ReadRecord(chunk_metadata); - auto chunk_info = chunk_metadata.chunks(); + auto& chunk_info = chunk_metadata.chunks(); EXPECT_EQ(chunk_info.size(), strings.size()); for (int i = 0; i < chunk_info.size(); i++) { reader.Seek(chunk_info[i].offset()); @@ -220,11 +224,13 @@ TEST(RepeatedStringSplitterTest, TestNoSplit) { RepeatedString message; // No strings RepeatedStringSplitter splitter = RepeatedStringSplitter(&message); TF_ASSERT_OK_AND_ASSIGN(auto ret, splitter.Split()); - auto chunks = ret.first; - auto chunked_message = ret.second; + std::vector* chunks = ret.chunks; + ASSERT_NE(chunks, nullptr); + ChunkedMessage* chunked_message = ret.chunked_message; + ASSERT_NE(chunked_message, nullptr); EXPECT_THAT(*chunks, SizeIs(1)); - EXPECT_THAT(*std::get(chunks->at(0)), + EXPECT_THAT(*std::get((*chunks)[0]), EqualsProto("")); EXPECT_THAT(*chunked_message, EqualsProto(R"pb(chunk_index: 0)pb")); } @@ -266,8 +272,10 @@ TEST(ComposableTest, RepeatedRepeatedStringTest) { RepeatedRepeatedStringSplitter splitter = RepeatedRepeatedStringSplitter(&message); TF_ASSERT_OK_AND_ASSIGN(auto ret, splitter.Split()); - auto chunks = ret.first; - auto chunked_message = ret.second; + std::vector* chunks = ret.chunks; + ASSERT_NE(chunks, nullptr); + ChunkedMessage* chunked_message = ret.chunked_message; + ASSERT_NE(chunked_message, nullptr); std::vector expected_chunks = {"piece-1", "piece-2", "piece-3", "new-strings-1", "foo-1", "foo-2"}; @@ -275,13 +283,13 @@ TEST(ComposableTest, RepeatedRepeatedStringTest) { // RepeatedRepeatedStringSplitter sets the first chunk as the user-provided // message, so the expected size is 7. EXPECT_THAT(*chunks, SizeIs(7)); - EXPECT_THAT(*std::get(chunks->at(0)), + EXPECT_THAT(*std::get((*chunks)[0]), EqualsProto(message)); for (int i = 1; i < chunks->size(); i++) { - auto chunk = chunks->at(i); - EXPECT_TRUE(std::holds_alternative(chunk)); - EXPECT_EQ(expected_chunks[i - 1], std::get(chunk)); + MessageBytes chunk = (*chunks)[i]; + EXPECT_THAT(chunk, + ::testing::VariantWith(expected_chunks[i - 1])); } // message.rs[2].strings[0] (value = "foo-1") should be the chunk at index 5. @@ -305,7 +313,8 @@ TEST(ComposableTest, ChildSplitterTest) { TF_EXPECT_OK(child.BuildChunks()); TF_ASSERT_OK_AND_ASSIGN(auto ret, splitter.Split()); - auto chunks = ret.first; + std::vector* chunks = ret.chunks; + ASSERT_NE(chunks, nullptr); EXPECT_THAT(*chunks, SizeIs(5)); // Total 5 chunks should be generated. } diff --git a/tensorflow/tools/proto_splitter/cc/graph_def_splitter_test.cc b/tensorflow/tools/proto_splitter/cc/graph_def_splitter_test.cc index bbb2587a2d3c39..1d98a3a390bd33 100644 --- a/tensorflow/tools/proto_splitter/cc/graph_def_splitter_test.cc +++ b/tensorflow/tools/proto_splitter/cc/graph_def_splitter_test.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #include "tensorflow/tools/proto_splitter/cc/max_size.h" #include "tensorflow/tools/proto_splitter/cc/test_util.h" +#include "tensorflow/tools/proto_splitter/cc/util.h" #include "tensorflow/tools/proto_splitter/testdata/test_message.pb.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/protobuf.h" @@ -41,6 +42,8 @@ namespace tensorflow { namespace tools::proto_splitter { namespace { +using ::tensorflow::proto_splitter::ChunkedMessage; + // Ensures that all Messages are less than the max size. std::string chunks are // not limited by the max size, so they are ignored in this check. #define EXPECT_CHUNK_SIZES(chunks, max_size) \ @@ -68,15 +71,29 @@ TEST(GraphDefSplitterTest, TestLargeConstant) { TF_EXPECT_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), graph_def_path, &proto)); - EXPECT_GE(proto.ByteSize(), GetMaxSize()); - auto large_constant_1 = + EXPECT_GE(proto.ByteSizeLong(), GetMaxSize()); + std::string large_constant_1, large_constant_2; + const std::variant& tensor_constant_1 = proto.node(2).attr().at("value").tensor().tensor_content(); - auto large_constant_2 = + const std::variant& tensor_constant_2 = proto.node(4).attr().at("value").tensor().tensor_content(); + if (std::holds_alternative(tensor_constant_1)) { + large_constant_1 = std::get(tensor_constant_1); + } else { + absl::CopyCordToString(std::get(tensor_constant_1), + &large_constant_1); + } + if (std::holds_alternative(tensor_constant_2)) { + large_constant_2 = std::get(tensor_constant_2); + } else { + absl::CopyCordToString(std::get(tensor_constant_2), + &large_constant_2); + } GraphDefSplitter splitter(&proto); TF_ASSERT_OK_AND_ASSIGN(auto x, splitter.Split()); - auto chunked_message = x.second; + ChunkedMessage* chunked_message = x.chunked_message; + ASSERT_NE(chunked_message, nullptr); EXPECT_THAT(*chunked_message, EqualsProto(R"pb(chunk_index: 0 chunked_fields { @@ -98,13 +115,14 @@ TEST(GraphDefSplitterTest, TestLargeConstant) { message { chunk_index: 2 } })pb")); - auto chunks = x.first; + std::vector* chunks = x.chunks; + ASSERT_NE(chunks, nullptr); EXPECT_CHUNK_SIZES(chunks, max_size); - EXPECT_TRUE(std::holds_alternative(chunks->at(1))); - EXPECT_EQ(large_constant_1, std::get(chunks->at(1))); - EXPECT_TRUE(std::holds_alternative(chunks->at(2))); - EXPECT_EQ(large_constant_2, std::get(chunks->at(2))); + EXPECT_THAT((*chunks)[1], + ::testing::VariantWith(large_constant_1)); + EXPECT_THAT((*chunks)[2], + ::testing::VariantWith(large_constant_2)); } TEST(GraphDefSplitterTest, TestLargeNodes) { @@ -128,7 +146,8 @@ TEST(GraphDefSplitterTest, TestLargeNodes) { GraphDefSplitter splitter(&proto); TF_ASSERT_OK_AND_ASSIGN(auto x, splitter.Split()); - auto chunked_message = x.second; + ChunkedMessage* chunked_message = x.chunked_message; + ASSERT_NE(chunked_message, nullptr); EXPECT_THAT(*chunked_message, EqualsProto(R"pb(chunk_index: 0 chunked_fields { field_tag { field: 1 } @@ -150,29 +169,30 @@ TEST(GraphDefSplitterTest, TestLargeNodes) { field_tag { index: 5 } message { chunk_index: 4 } })pb")); - auto chunks = x.first; + std::vector* chunks = x.chunks; + ASSERT_NE(chunks, nullptr); EXPECT_CHUNK_SIZES(chunks, max_size); EXPECT_TRUE(std::holds_alternative>( - chunks->at(1))); + (*chunks)[1])); EXPECT_TRUE(std::holds_alternative>( - chunks->at(2))); + (*chunks)[2])); EXPECT_TRUE(std::holds_alternative>( - chunks->at(3))); + (*chunks)[3])); EXPECT_TRUE(std::holds_alternative>( - chunks->at(4))); + (*chunks)[4])); EXPECT_THAT( - *std::get>(chunks->at(1)).get(), + *std::get>((*chunks)[1]).get(), EqualsProto(node_1)); EXPECT_THAT( - *std::get>(chunks->at(2)).get(), + *std::get>((*chunks)[2]).get(), EqualsProto(node_2)); EXPECT_THAT( - *std::get>(chunks->at(3)).get(), + *std::get>((*chunks)[3]).get(), EqualsProto(node_3)); EXPECT_THAT( - *std::get>(chunks->at(4)).get(), + *std::get>((*chunks)[4]).get(), EqualsProto(node_5)); } TEST(GraphDefSplitterTest, TestLotsNodes) { @@ -197,7 +217,8 @@ TEST(GraphDefSplitterTest, TestLotsNodes) { TF_ASSERT_OK_AND_ASSIGN(auto x, splitter.Split()); - auto chunked_message = x.second; + ChunkedMessage* chunked_message = x.chunked_message; + ASSERT_NE(chunked_message, nullptr); EXPECT_THAT( *chunked_message, EqualsProto(R"pb(chunk_index: 0 @@ -206,11 +227,12 @@ TEST(GraphDefSplitterTest, TestLotsNodes) { chunked_fields { message { chunk_index: 3 } } chunked_fields { message { chunk_index: 4 } })pb")); - auto chunks = x.first; + std::vector* chunks = x.chunks; + ASSERT_NE(chunks, nullptr); EXPECT_CHUNK_SIZES(chunks, max_size); int actual_node_size = 0; - for (auto chunk : *chunks) { + for (MessageBytes& chunk : *chunks) { GraphDef* message = nullptr; if (std::holds_alternative>( chunk)) { @@ -242,7 +264,8 @@ TEST(GraphDefSplitterTest, TestFunctionLotsOfNodes) { GraphDefSplitter splitter(&proto); TF_ASSERT_OK_AND_ASSIGN(auto x, splitter.Split()); - auto chunks = x.first; + std::vector* chunks = x.chunks; + ASSERT_NE(chunks, nullptr); EXPECT_CHUNK_SIZES(chunks, max_size); } @@ -261,7 +284,8 @@ TEST(GraphDefSplitterTest, TestFunctionLargeNodes) { GraphDefSplitter splitter(&proto); TF_ASSERT_OK_AND_ASSIGN(auto x, splitter.Split()); - auto chunks = x.first; + std::vector* chunks = x.chunks; + ASSERT_NE(chunks, nullptr); EXPECT_CHUNK_SIZES(chunks, max_size); } @@ -280,7 +304,8 @@ TEST(GraphDefSplitterTest, TestGraphAndFunction) { GraphDefSplitter splitter(&proto); TF_ASSERT_OK_AND_ASSIGN(auto x, splitter.Split()); - auto chunks = x.first; + std::vector* chunks = x.chunks; + ASSERT_NE(chunks, nullptr); EXPECT_CHUNK_SIZES(chunks, max_size); TF_ASSERT_OK(splitter.Write("/tmp/hoi")); diff --git a/tensorflow/tools/proto_splitter/cc/saved_model_splitter_test.cc b/tensorflow/tools/proto_splitter/cc/saved_model_splitter_test.cc index 17b72ba7d2a7de..b03bcc118f77c3 100644 --- a/tensorflow/tools/proto_splitter/cc/saved_model_splitter_test.cc +++ b/tensorflow/tools/proto_splitter/cc/saved_model_splitter_test.cc @@ -14,7 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/tools/proto_splitter/cc/saved_model_splitter.h" +#include #include +#include #include #include @@ -29,6 +31,7 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/saved_model.pb.h" #include "tensorflow/tools/proto_splitter/cc/max_size.h" +#include "tensorflow/tools/proto_splitter/cc/util.h" #include "tensorflow/tools/proto_splitter/testdata/test_message.pb.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/protobuf.h" @@ -55,7 +58,7 @@ namespace { } \ } while (0) -string NonChunkedSavedModel() { +std::string NonChunkedSavedModel() { return io::JoinPath(testing::TensorFlowSrcRoot(), "cc", "saved_model", "testdata", "chunked_saved_model", "non_chunked_model", "saved_model.pb"); @@ -68,12 +71,13 @@ TEST(SavedModelSplitterTest, TestSplit) { TF_EXPECT_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), NonChunkedSavedModel(), &proto)); - EXPECT_GE(proto.ByteSize(), GetMaxSize()); + EXPECT_GE(proto.ByteSizeLong(), GetMaxSize()); SavedModelSplitter splitter(&proto); TF_ASSERT_OK_AND_ASSIGN(auto x, splitter.Split()); - auto chunks = x.first; + std::vector* chunks = x.chunks; + ASSERT_NE(chunks, nullptr); // Should create a new chunk with the single large constant. EXPECT_EQ(2, chunks->size()); diff --git a/tensorflow/tools/proto_splitter/cc/split.h b/tensorflow/tools/proto_splitter/cc/split.h index 083b2219286bb4..6fe0571337f9c7 100644 --- a/tensorflow/tools/proto_splitter/cc/split.h +++ b/tensorflow/tools/proto_splitter/cc/split.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_TOOLS_PROTO_SPLITTER_CC_SPLIT_H_ #define TENSORFLOW_TOOLS_PROTO_SPLITTER_CC_SPLIT_H_ -#include #include #include #include @@ -24,6 +23,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "tensorflow/tools/proto_splitter/cc/util.h" #include "tensorflow/tools/proto_splitter/chunk.pb.h" #include "tensorflow/tools/proto_splitter/versions.pb.h" #include "tsl/platform/protobuf.h" @@ -31,10 +31,8 @@ limitations under the License. namespace tensorflow { namespace tools::proto_splitter { -using ::proto_splitter::ChunkedMessage; -using ::proto_splitter::VersionDef; -using MessageBytes = std::variant, - tsl::protobuf::Message*, std::string>; +using ::tensorflow::proto_splitter::ChunkedMessage; +using ::tensorflow::proto_splitter::VersionDef; // Interface for proto message splitters. class Splitter { @@ -42,8 +40,7 @@ class Splitter { virtual ~Splitter() = default; // Split message into chunks. - virtual absl::StatusOr*, ChunkedMessage*>> - Split() = 0; + virtual absl::StatusOr Split() = 0; // Write message to disk. virtual absl::Status Write(std::string file_prefix) = 0; diff --git a/tensorflow/tools/proto_splitter/cc/util.cc b/tensorflow/tools/proto_splitter/cc/util.cc index 669f4288c8ef80..4e03d5c94990eb 100644 --- a/tensorflow/tools/proto_splitter/cc/util.cc +++ b/tensorflow/tools/proto_splitter/cc/util.cc @@ -42,7 +42,7 @@ limitations under the License. namespace tensorflow { namespace tools::proto_splitter { -using ::proto_splitter::ChunkedField; +using ::tensorflow::proto_splitter::ChunkedField; namespace { absl::StatusOr FieldInt(const FieldType& field) { @@ -207,33 +207,34 @@ absl::Status AddMapKey(const tsl::protobuf::FieldDescriptor& key_field, } absl::StatusOr GetMapKeyFromFieldIndex( - ::proto_splitter::FieldIndex field_index) { + ::tensorflow::proto_splitter::FieldIndex field_index) { if (!field_index.has_map_key()) return absl::FailedPreconditionError( "Field index doesn't contain a map key."); switch (field_index.map_key().type_case()) { - case ::proto_splitter::FieldIndex::MapKey::TypeCase::kBoolean: + case ::tensorflow::proto_splitter::FieldIndex::MapKey::TypeCase::kBoolean: return field_index.map_key().boolean(); break; - case ::proto_splitter::FieldIndex::MapKey::TypeCase::kS: + case ::tensorflow::proto_splitter::FieldIndex::MapKey::TypeCase::kS: return field_index.map_key().s(); break; - case ::proto_splitter::FieldIndex::MapKey::TypeCase::kI32: + case ::tensorflow::proto_splitter::FieldIndex::MapKey::TypeCase::kI32: return field_index.map_key().i32(); break; - case ::proto_splitter::FieldIndex::MapKey::TypeCase::kI64: + case ::tensorflow::proto_splitter::FieldIndex::MapKey::TypeCase::kI64: // Cast to int type, which may be lossy. We'll deal with it when it // becomes an issue. return static_cast(field_index.map_key().i64()); break; - case ::proto_splitter::FieldIndex::MapKey::TypeCase::kUi32: + case ::tensorflow::proto_splitter::FieldIndex::MapKey::TypeCase::kUi32: return static_cast(field_index.map_key().ui32()); break; - case ::proto_splitter::FieldIndex::MapKey::TypeCase::kUi64: + case ::tensorflow::proto_splitter::FieldIndex::MapKey::TypeCase::kUi64: return static_cast(field_index.map_key().ui64()); break; - case ::proto_splitter::FieldIndex::MapKey::TypeCase::TYPE_NOT_SET: + case ::tensorflow::proto_splitter::FieldIndex::MapKey::TypeCase:: + TYPE_NOT_SET: default: return absl::FailedPreconditionError( absl::StrCat("Unknown map key type: ", field_index.DebugString())); @@ -243,12 +244,12 @@ absl::StatusOr GetMapKeyFromFieldIndex( } // namespace absl::StatusOr> GetFieldTypes( - const tsl::protobuf::RepeatedPtrField<::proto_splitter::FieldIndex>& - field_tags) { + const tsl::protobuf::RepeatedPtrField< + ::tensorflow::proto_splitter::FieldIndex>& field_tags) { std::vector fields; for (int fti = 0; fti < field_tags.size();) { switch (field_tags[fti].kind_case()) { - case ::proto_splitter::FieldIndex::KindCase::kField: + case ::tensorflow::proto_splitter::FieldIndex::KindCase::kField: fields.push_back( Field(static_cast(field_tags[fti].field()), std::nullopt)); fti++; @@ -263,15 +264,15 @@ absl::StatusOr> GetFieldTypes( fields.back().second = map_key; } break; - case ::proto_splitter::FieldIndex::KindCase::kIndex: + case ::tensorflow::proto_splitter::FieldIndex::KindCase::kIndex: return absl::FailedPreconditionError( "Index doesn't belong to any field."); break; - case ::proto_splitter::FieldIndex::KindCase::kMapKey: + case ::tensorflow::proto_splitter::FieldIndex::KindCase::kMapKey: return absl::FailedPreconditionError( "Map key doesn't belong to any field."); break; - case ::proto_splitter::FieldIndex::KindCase::KIND_NOT_SET: + case ::tensorflow::proto_splitter::FieldIndex::KindCase::KIND_NOT_SET: default: return absl::FailedPreconditionError(absl::StrCat( "Unknown field kind: ", field_tags[fti].DebugString())); @@ -745,9 +746,9 @@ absl::StatusOr>> GetRiegeliReader( return reader; } -absl::StatusOr<::proto_splitter::ChunkMetadata> GetChunkMetadata( +absl::StatusOr<::tensorflow::proto_splitter::ChunkMetadata> GetChunkMetadata( riegeli::RecordReader>& reader) { - ::proto_splitter::ChunkMetadata chunk_metadata; + ::tensorflow::proto_splitter::ChunkMetadata chunk_metadata; bool read_metadata_success = reader.Seek(reader.Size().value()) && reader.SeekBack() && reader.ReadRecord(chunk_metadata); @@ -757,7 +758,7 @@ absl::StatusOr<::proto_splitter::ChunkMetadata> GetChunkMetadata( absl::StatusOr ReadChunk( riegeli::RecordReader>& reader, - const ::proto_splitter::ChunkInfo& chunk_info) { + const ::tensorflow::proto_splitter::ChunkInfo& chunk_info) { riegeli::Position pos = chunk_info.offset(); std::string chunk(chunk_info.size(), '\0'); if (reader.Seek(pos) && reader.ReadRecord(chunk)) return chunk; diff --git a/tensorflow/tools/proto_splitter/cc/util.h b/tensorflow/tools/proto_splitter/cc/util.h index bb82d562ba3526..751298025cbd6b 100644 --- a/tensorflow/tools/proto_splitter/cc/util.h +++ b/tensorflow/tools/proto_splitter/cc/util.h @@ -16,6 +16,7 @@ limitations under the License. #define TENSORFLOW_TOOLS_PROTO_SPLITTER_CC_UTIL_H_ #include +#include #include #include #include @@ -23,6 +24,7 @@ limitations under the License. #include #include "absl/status/statusor.h" +#include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "riegeli/bytes/fd_reader.h" // from @riegeli #include "riegeli/records/record_reader.h" // from @riegeli @@ -32,6 +34,14 @@ limitations under the License. namespace tensorflow { namespace tools::proto_splitter { +using MessageBytes = std::variant, + tsl::protobuf::Message*, std::string>; + +struct ChunkedProto { + std::vector* chunks = nullptr; + ::tensorflow::proto_splitter::ChunkedMessage* chunked_message = nullptr; +}; + // TODO(b/282796592): Consider switching to `tsl::protobuf::FieldPath` in the // future. @@ -44,8 +54,8 @@ using Field = std::pair>; // std::vector, since multiple field tags may correspond to a single // field when the field is repeated or a map. absl::StatusOr> GetFieldTypes( - const tsl::protobuf::RepeatedPtrField<::proto_splitter::FieldIndex>& - field_tags); + const tsl::protobuf::RepeatedPtrField< + ::tensorflow::proto_splitter::FieldIndex>& field_tags); // Sets message.field_desc[field_index] to the data contained in chunk, // according to the (cpp) type described by field_desc. Uses message_callback @@ -111,13 +121,13 @@ absl::StatusOr GetField(const tsl::protobuf::Message& message, const std::vector& fields); // Updates `field_tag` in the ChunkedField proto. -absl::Status AddFieldTag(const tsl::protobuf::Descriptor& desc, - const std::vector& fields, - ::proto_splitter::ChunkedField& chunked_field); +absl::Status AddFieldTag( + const tsl::protobuf::Descriptor& desc, const std::vector& fields, + ::tensorflow::proto_splitter::ChunkedField& chunked_field); -absl::Status AddFieldTag(const tsl::protobuf::Descriptor& desc, - const Field& field, - ::proto_splitter::ChunkedField& chunked_field); +absl::Status AddFieldTag( + const tsl::protobuf::Descriptor& desc, const Field& field, + ::tensorflow::proto_splitter::ChunkedField& chunked_field); // Returns the index of the map key in the map field. If the key is not found, // returns -1. @@ -137,13 +147,13 @@ absl::StatusOr>> GetRiegeliReader( // Read the last chunk, which contains metadata necessary for reading the // remaining chunks. -absl::StatusOr<::proto_splitter::ChunkMetadata> GetChunkMetadata( +absl::StatusOr<::tensorflow::proto_splitter::ChunkMetadata> GetChunkMetadata( riegeli::RecordReader>& reader); // Use the `reader` to read in the chunk specified by `chunk_info`. absl::StatusOr ReadChunk( riegeli::RecordReader>& reader, - const ::proto_splitter::ChunkInfo& chunk_info); + const ::tensorflow::proto_splitter::ChunkInfo& chunk_info); // Returns true if prefix can only be found as a .pb file, and false if a .cpb // file exists. Returns an error if neither .pb nor .cpb exist. diff --git a/tensorflow/tools/proto_splitter/cc/util_test.cc b/tensorflow/tools/proto_splitter/cc/util_test.cc index ebdb6d3d8a37e4..23e9c8067db0e1 100644 --- a/tensorflow/tools/proto_splitter/cc/util_test.cc +++ b/tensorflow/tools/proto_splitter/cc/util_test.cc @@ -36,8 +36,8 @@ namespace tensorflow { namespace tools::proto_splitter { namespace { -using ::proto_splitter::ChunkedField; -using ::proto_splitter_testdata::ManyFields; +using ::tensorflow::proto_splitter::ChunkedField; +using ::tensorflow::proto_splitter_testdata::ManyFields; using ::testing::HasSubstr; using tsl::testing::IsOkAndHolds; using tsl::testing::StatusIs; @@ -66,7 +66,8 @@ tsl::StatusOr MakeManyFields() { })pb"); } -tsl::StatusOr> +tsl::StatusOr< + tsl::protobuf::RepeatedPtrField<::tensorflow::proto_splitter::FieldIndex>> MakeFieldTags() { TF_ASSIGN_OR_RETURN(auto ret, ParseTextProto(R"pb( field_tag { field: 2 } @@ -77,7 +78,8 @@ MakeFieldTags() { return ret.field_tag(); } -tsl::StatusOr> +tsl::StatusOr< + tsl::protobuf::RepeatedPtrField<::tensorflow::proto_splitter::FieldIndex>> MakeFieldTagsTooManyIndices() { TF_ASSIGN_OR_RETURN(auto ret, ParseTextProto(R"pb( field_tag { field: 2 } @@ -89,7 +91,8 @@ MakeFieldTagsTooManyIndices() { return ret.field_tag(); } -tsl::StatusOr> +tsl::StatusOr< + tsl::protobuf::RepeatedPtrField<::tensorflow::proto_splitter::FieldIndex>> MakeFieldTagsTooManyMapKeys() { TF_ASSIGN_OR_RETURN(auto ret, ParseTextProto(R"pb( field_tag { field: 2 } @@ -101,7 +104,8 @@ MakeFieldTagsTooManyMapKeys() { return ret.field_tag(); } -tsl::StatusOr> +tsl::StatusOr< + tsl::protobuf::RepeatedPtrField<::tensorflow::proto_splitter::FieldIndex>> MakeFieldTagsMisplacedIndex() { TF_ASSIGN_OR_RETURN(auto ret, ParseTextProto(R"pb( field_tag { field: 2 } @@ -113,7 +117,8 @@ MakeFieldTagsMisplacedIndex() { return ret.field_tag(); } -tsl::StatusOr> +tsl::StatusOr< + tsl::protobuf::RepeatedPtrField<::tensorflow::proto_splitter::FieldIndex>> MakeFieldTagsMisplacedMapKey() { TF_ASSIGN_OR_RETURN(auto ret, ParseTextProto(R"pb( field_tag { field: 2 } @@ -525,8 +530,8 @@ TEST(UtilTest, TestReadChunk) { reader.Close(); TF_ASSERT_OK(read_metadata.status()); } - ::proto_splitter::ChunkMetadata metadata = read_metadata.value(); - std::vector<::proto_splitter::ChunkInfo> chunks_info( + ::tensorflow::proto_splitter::ChunkMetadata metadata = read_metadata.value(); + std::vector<::tensorflow::proto_splitter::ChunkInfo> chunks_info( metadata.chunks().begin(), metadata.chunks().end()); for (const auto& chunk_info : chunks_info) { diff --git a/tensorflow/tools/proto_splitter/chunk.proto b/tensorflow/tools/proto_splitter/chunk.proto index d4b1fd637009b5..8484e92141ec39 100644 --- a/tensorflow/tools/proto_splitter/chunk.proto +++ b/tensorflow/tools/proto_splitter/chunk.proto @@ -1,6 +1,6 @@ syntax = "proto3"; -package proto_splitter; +package tensorflow.proto_splitter; import "tensorflow/tools/proto_splitter/versions.proto"; diff --git a/tensorflow/tools/proto_splitter/merge.cc b/tensorflow/tools/proto_splitter/merge.cc index 16b96698433da4..63f9079ac29620 100644 --- a/tensorflow/tools/proto_splitter/merge.cc +++ b/tensorflow/tools/proto_splitter/merge.cc @@ -38,11 +38,11 @@ limitations under the License. namespace tensorflow::tools::proto_splitter { -using ::proto_splitter::ChunkedField; -using ::proto_splitter::ChunkedMessage; -using ::proto_splitter::ChunkInfo; -using ::proto_splitter::ChunkMetadata; -using ::proto_splitter::FieldIndex; +using ::tensorflow::proto_splitter::ChunkedField; +using ::tensorflow::proto_splitter::ChunkedMessage; +using ::tensorflow::proto_splitter::ChunkInfo; +using ::tensorflow::proto_splitter::ChunkMetadata; +using ::tensorflow::proto_splitter::FieldIndex; using tools::proto_splitter::GetChunkMetadata; using tools::proto_splitter::GetRiegeliReader; using tools::proto_splitter::OnlyContainsPb; diff --git a/tensorflow/tools/proto_splitter/merge.h b/tensorflow/tools/proto_splitter/merge.h index 994b87574400c0..379f7dd14e6a85 100644 --- a/tensorflow/tools/proto_splitter/merge.h +++ b/tensorflow/tools/proto_splitter/merge.h @@ -41,7 +41,7 @@ class Merger { // TODO(b/282775853): Integrate Splitter return type with Merge input type static absl::Status Merge( const std::vector>& chunks, - const ::proto_splitter::ChunkedMessage& chunked_message, + const ::tensorflow::proto_splitter::ChunkedMessage& chunked_message, tsl::protobuf::Message* merged_message); // Reads a TF SavedModel chunked protobuf from `prefix` (must be .pb or .cpb) @@ -56,7 +56,7 @@ class Merger { // Like `Merger::Read`, but only reads what's specified in `chunk_metadata`. static absl::Status ReadPartial( absl::string_view prefix, - const ::proto_splitter::ChunkMetadata& chunk_metadata, + const ::tensorflow::proto_splitter::ChunkMetadata& chunk_metadata, tsl::protobuf::Message* merged_message); private: @@ -67,9 +67,9 @@ class Merger { // Uses metadata contained in `chunked_message` to fill `merged_message` with // data accessed by the `reader` using `chunks_info`. static absl::Status ReadFields( - const ::proto_splitter::ChunkedMessage& chunked_message, + const ::tensorflow::proto_splitter::ChunkedMessage& chunked_message, riegeli::RecordReader>& reader, - const std::vector<::proto_splitter::ChunkInfo>& + const std::vector<::tensorflow::proto_splitter::ChunkInfo>& chunks_info, // TODO(adamcogdell): this can just be a // RepeatedPtrField tsl::protobuf::Message* merged_message); @@ -80,9 +80,9 @@ class Merger { // value of `op`) to add those fields to `merged_message`. Otherwise, the // field is simply added to `merged_message` using reflection. static absl::Status ProcessField( - const ::proto_splitter::ChunkedField& chunked_field, + const ::tensorflow::proto_splitter::ChunkedField& chunked_field, tsl::protobuf::Message* merged_message, - const std::vector<::proto_splitter::ChunkInfo>& chunks_info, + const std::vector<::tensorflow::proto_splitter::ChunkInfo>& chunks_info, const std::vector>& chunks, riegeli::RecordReader>& reader, MergerOp op); }; diff --git a/tensorflow/tools/proto_splitter/merge_test.cc b/tensorflow/tools/proto_splitter/merge_test.cc index 60d9e6c1d96ee9..82a67360debf11 100644 --- a/tensorflow/tools/proto_splitter/merge_test.cc +++ b/tensorflow/tools/proto_splitter/merge_test.cc @@ -56,13 +56,13 @@ TEST(MergeTest, TestReadRiegeliTreeDepthFirst) { const std::string cpb_path = io::JoinPath(testing::TensorFlowSrcRoot(), "tools/proto_splitter/testdata", "df-split-tree"); - ::proto_splitter_testdata::StringNode merged_tree; + ::tensorflow::proto_splitter_testdata::StringNode merged_tree; TF_ASSERT_OK(Merger::Read(cpb_path, &merged_tree)); const std::string pbtxt_path = io::JoinPath(testing::TensorFlowSrcRoot(), "tools/proto_splitter/testdata", "split-tree"); - ::proto_splitter_testdata::StringNode test_proto; + ::tensorflow::proto_splitter_testdata::StringNode test_proto; TF_ASSERT_OK(tsl::ReadTextProto( tsl::Env::Default(), absl::StrCat(pbtxt_path, ".pbtxt"), &test_proto)); @@ -73,14 +73,14 @@ TEST(MergeTest, TestReadRiegeliTreeBreadthFirst) { const std::string cpb_path = io::JoinPath(testing::TensorFlowSrcRoot(), "tools/proto_splitter/testdata", "bf-split-tree"); - ::proto_splitter_testdata::StringNode merged_tree; + ::tensorflow::proto_splitter_testdata::StringNode merged_tree; TF_ASSERT_OK(Merger::Read(cpb_path, &merged_tree)); const std::string pbtxt_path = io::JoinPath(testing::TensorFlowSrcRoot(), "tools/proto_splitter/testdata", "split-tree"); - ::proto_splitter_testdata::StringNode test_proto; + ::tensorflow::proto_splitter_testdata::StringNode test_proto; TF_ASSERT_OK(tsl::ReadTextProto( tsl::Env::Default(), absl::StrCat(pbtxt_path, ".pbtxt"), &test_proto)); @@ -93,28 +93,29 @@ TEST(MergeTest, TestMergeTreeChunksDepthFirst) { "tools/proto_splitter/testdata", "df-split-tree"); std::vector> chunks; for (const auto& chunk : kDFSplitTreeChunks) { - ::proto_splitter_testdata::StringNode string_node; + ::tensorflow::proto_splitter_testdata::StringNode string_node; ::tsl::protobuf::TextFormat::ParseFromString(chunk, &string_node); std::unique_ptr<::tsl::protobuf::Message> node = - std::make_unique<::proto_splitter_testdata::StringNode>(string_node); + std::make_unique<::tensorflow::proto_splitter_testdata::StringNode>( + string_node); chunks.push_back(std::move(node)); } std::string split_tree_metadata; TF_ASSERT_OK(tsl::ReadFileToString( tsl::Env::Default(), absl::StrCat(path, ".pbtxt"), &split_tree_metadata)); - ::proto_splitter::ChunkedMessage chunked_message; + ::tensorflow::proto_splitter::ChunkedMessage chunked_message; ::tsl::protobuf::TextFormat::ParseFromString(split_tree_metadata, &chunked_message); - ::proto_splitter_testdata::StringNode merged_tree; + ::tensorflow::proto_splitter_testdata::StringNode merged_tree; TF_ASSERT_OK(Merger::Merge(chunks, chunked_message, &merged_tree)); const std::string pbtxt_path = io::JoinPath(testing::TensorFlowSrcRoot(), "tools/proto_splitter/testdata", "split-tree"); - ::proto_splitter_testdata::StringNode test_proto; + ::tensorflow::proto_splitter_testdata::StringNode test_proto; TF_ASSERT_OK(tsl::ReadTextProto( tsl::Env::Default(), absl::StrCat(pbtxt_path, ".pbtxt"), &test_proto)); @@ -127,28 +128,29 @@ TEST(MergeTest, TestMergeTreeChunksBreadthFirst) { "tools/proto_splitter/testdata", "bf-split-tree"); std::vector> chunks; for (const auto& chunk : kBFSplitTreeChunks) { - ::proto_splitter_testdata::StringNode string_node; + ::tensorflow::proto_splitter_testdata::StringNode string_node; ::tsl::protobuf::TextFormat::ParseFromString(chunk, &string_node); std::unique_ptr<::tsl::protobuf::Message> node = - std::make_unique<::proto_splitter_testdata::StringNode>(string_node); + std::make_unique<::tensorflow::proto_splitter_testdata::StringNode>( + string_node); chunks.push_back(std::move(node)); } std::string split_tree_metadata; TF_ASSERT_OK(tsl::ReadFileToString( tsl::Env::Default(), absl::StrCat(path, ".pbtxt"), &split_tree_metadata)); - ::proto_splitter::ChunkedMessage chunked_message; + ::tensorflow::proto_splitter::ChunkedMessage chunked_message; ::tsl::protobuf::TextFormat::ParseFromString(split_tree_metadata, &chunked_message); - ::proto_splitter_testdata::StringNode merged_tree; + ::tensorflow::proto_splitter_testdata::StringNode merged_tree; TF_ASSERT_OK(Merger::Merge(chunks, chunked_message, &merged_tree)); const std::string pbtxt_path = io::JoinPath(testing::TensorFlowSrcRoot(), "tools/proto_splitter/testdata", "split-tree"); - ::proto_splitter_testdata::StringNode test_proto; + ::tensorflow::proto_splitter_testdata::StringNode test_proto; TF_ASSERT_OK(tsl::ReadTextProto( tsl::Env::Default(), absl::StrCat(pbtxt_path, ".pbtxt"), &test_proto)); @@ -201,10 +203,10 @@ TEST(MergeTest, TestReadManyField) { const std::string path = io::JoinPath(testing::TensorFlowSrcRoot(), "tools/proto_splitter/testdata", "many-field"); - ::proto_splitter_testdata::ManyFields merged_many_field; + ::tensorflow::proto_splitter_testdata::ManyFields merged_many_field; TF_ASSERT_OK(Merger::Read(path, &merged_many_field)); - ::proto_splitter_testdata::ManyFields test_many_field; + ::tensorflow::proto_splitter_testdata::ManyFields test_many_field; TF_ASSERT_OK(tsl::ReadTextProto( tsl::Env::Default(), absl::StrCat(path, ".pbtxt"), &test_many_field)); @@ -251,8 +253,9 @@ TEST(MergeTest, TestReadPartial) { reader.Close(); TF_ASSERT_OK(read_metadata.status()); } - ::proto_splitter::ChunkMetadata chunk_metadata = read_metadata.value(); - ::proto_splitter::ChunkMetadata partial_chunk_metadata; + ::tensorflow::proto_splitter::ChunkMetadata chunk_metadata = + read_metadata.value(); + ::tensorflow::proto_splitter::ChunkMetadata partial_chunk_metadata; partial_chunk_metadata.mutable_chunks()->CopyFrom(chunk_metadata.chunks()); partial_chunk_metadata.mutable_message()->set_chunk_index( chunk_metadata.message().chunk_index()); diff --git a/tensorflow/tools/proto_splitter/testdata/test_message.proto b/tensorflow/tools/proto_splitter/testdata/test_message.proto index f09f4198dc4277..8bca5cece8c70e 100644 --- a/tensorflow/tools/proto_splitter/testdata/test_message.proto +++ b/tensorflow/tools/proto_splitter/testdata/test_message.proto @@ -1,6 +1,6 @@ syntax = "proto3"; -package proto_splitter_testdata; +package tensorflow.proto_splitter_testdata; message RepeatedString { repeated string strings = 1; @@ -24,4 +24,4 @@ message ManyFields { message StringNode { string val = 1; repeated StringNode child_nodes = 2; -} \ No newline at end of file +} diff --git a/tensorflow/tools/proto_splitter/versions.proto b/tensorflow/tools/proto_splitter/versions.proto index de6f020b99cbfc..abd89d3ba3bb05 100644 --- a/tensorflow/tools/proto_splitter/versions.proto +++ b/tensorflow/tools/proto_splitter/versions.proto @@ -1,6 +1,6 @@ syntax = "proto3"; -package proto_splitter; +package tensorflow.proto_splitter; option cc_enable_arenas = true; diff --git a/tensorflow/tools/test/performance.bzl b/tensorflow/tools/test/performance.bzl index f918da44589729..1576299656943c 100644 --- a/tensorflow/tools/test/performance.bzl +++ b/tensorflow/tools/test/performance.bzl @@ -49,7 +49,7 @@ def tf_cc_logged_benchmark( deps = [ "@absl_py//absl:app", "@absl_py//absl/flags", - "//tensorflow/core:protos_all_py_pb2", + "@org_tensorflow//tensorflow/core:protos_all_py", "//tensorflow/python/platform:gfile", "//tensorflow/python/platform:test", "//tensorflow/python/platform:tf_logging", diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/Dockerfile b/tensorflow/tools/tf_sig_build_dockerfiles/Dockerfile index c8253ec6213875..a659294bb2cf77 100644 --- a/tensorflow/tools/tf_sig_build_dockerfiles/Dockerfile +++ b/tensorflow/tools/tf_sig_build_dockerfiles/Dockerfile @@ -28,6 +28,9 @@ COPY setup.cuda.sh /setup.cuda.sh COPY devel.packages.txt /devel.packages.txt RUN /setup.sources.sh && /setup.packages.sh /devel.packages.txt && /setup.cuda.sh +# Make sure clang is on the path +RUN ln -s /usr/lib/llvm-17/bin/clang /usr/bin/clang + # Install various tools. # - bats: bash unit testing framework # - bazelisk: always use the correct bazel version diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/devel.packages.txt b/tensorflow/tools/tf_sig_build_dockerfiles/devel.packages.txt index 827d1d396d0031..49db07a0f190ab 100644 --- a/tensorflow/tools/tf_sig_build_dockerfiles/devel.packages.txt +++ b/tensorflow/tools/tf_sig_build_dockerfiles/devel.packages.txt @@ -36,10 +36,9 @@ autoconf automake build-essential ca-certificates -# TODO(b/308399490) Remove CMake once dm-tree (Keras dependency) has 3.12 wheels -cmake llvm-17 clang-17 +clang-tidy-17 lld-17 clang-format-12 colordiff diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/devel.requirements.txt b/tensorflow/tools/tf_sig_build_dockerfiles/devel.requirements.txt index 6e5bbbaa16e8f1..480cdc54385e5a 100644 --- a/tensorflow/tools/tf_sig_build_dockerfiles/devel.requirements.txt +++ b/tensorflow/tools/tf_sig_build_dockerfiles/devel.requirements.txt @@ -43,8 +43,6 @@ scipy ~= 1.11.3; python_version >= '3.12' # Earliest version for Python 3.12 # Required for TFLite import from JAX tests jax ~= 0.4.1; python_version <= '3.11' jaxlib ~= 0.4.1; python_version <= '3.11' # Earliest version for Python 3.11 -# Needs to be addressed. Unblocked 2.4 branchcut cl/338377048 -PyYAML ~= 6.0 # For uploading auditwheel ~= 5.0.0 twine ~= 3.6.0 diff --git a/tensorflow/tools/toolchains/remote_config/configs.bzl b/tensorflow/tools/toolchains/remote_config/configs.bzl index 4d99c0ff41333f..9555f225773cc8 100644 --- a/tensorflow/tools/toolchains/remote_config/configs.bzl +++ b/tensorflow/tools/toolchains/remote_config/configs.bzl @@ -627,7 +627,7 @@ def initialize_rbe_configs(): sigbuild_tf_configs( name_container_map = { - "sigbuild-r2.16": "docker://gcr.io/tensorflow-sigs/build@sha256:22d863e6fe3f98946015b9e1264b2eeb8e56e504535a6c1d5e564cae65ae5d37", + "sigbuild-r2.16": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", "sigbuild-r2.16-python3.9": "docker://gcr.io/tensorflow-sigs/build@sha256:22d863e6fe3f98946015b9e1264b2eeb8e56e504535a6c1d5e564cae65ae5d37", "sigbuild-r2.16-python3.10": "docker://gcr.io/tensorflow-sigs/build@sha256:da15288c8464153eadd35da720540a544b76aa9d78cceb42a6821b2f3e70a0fa", "sigbuild-r2.16-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", @@ -667,7 +667,7 @@ def initialize_rbe_configs(): sigbuild_tf_configs( name_container_map = { - "sigbuild-r2.16-clang": "docker://gcr.io/tensorflow-sigs/build@sha256:22d863e6fe3f98946015b9e1264b2eeb8e56e504535a6c1d5e564cae65ae5d37", + "sigbuild-r2.16-clang": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", "sigbuild-r2.16-clang-python3.9": "docker://gcr.io/tensorflow-sigs/build@sha256:22d863e6fe3f98946015b9e1264b2eeb8e56e504535a6c1d5e564cae65ae5d37", "sigbuild-r2.16-clang-python3.10": "docker://gcr.io/tensorflow-sigs/build@sha256:da15288c8464153eadd35da720540a544b76aa9d78cceb42a6821b2f3e70a0fa", "sigbuild-r2.16-clang-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index a996e90e385f81..f76ff11862d9df 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -150,9 +150,9 @@ def _tf_repositories(): # LINT.IfChange tf_http_archive( name = "XNNPACK", - sha256 = "434fe914cb52da3e66ba920082af969f527f23729fff182aecd87ac5324e9f90", - strip_prefix = "XNNPACK-dcbfffb80fb4f6fcfcfb5b3723854ec8797fa546", - urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/dcbfffb80fb4f6fcfcfb5b3723854ec8797fa546.zip"), + sha256 = "bd7592e11699a34e94ce7fd36d95798092b508c01b6ae44ff232c7929bb2c927", + strip_prefix = "XNNPACK-9325fcfe52092b2f8f816db218bca208db7b2750", + urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/9325fcfe52092b2f8f816db218bca208db7b2750.zip"), ) # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/xnnpack.cmake) @@ -181,9 +181,9 @@ def _tf_repositories(): name = "cudnn_frontend_archive", build_file = "//third_party:cudnn_frontend.BUILD", patch_file = ["//third_party:cudnn_frontend_header_fix.patch"], - sha256 = "015ea933139a30e9ccd177b5e0dbfb16f3d08df78334aaacea57880275df734b", - strip_prefix = "cudnn-frontend-1.0.0", - urls = tf_mirror_urls("https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.0.0.zip"), + sha256 = "c2f5373ddf84e33d289dad5766667f52de652dfbbb1dccb2fada9cfcf2d774cf", + strip_prefix = "cudnn-frontend-1.1.0", + urls = tf_mirror_urls("https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.1.0.zip"), ) tf_http_archive( @@ -205,9 +205,9 @@ def _tf_repositories(): tf_http_archive( name = "onednn", build_file = "//third_party/mkl_dnn:mkldnn_v1.BUILD", - sha256 = "8d150a77025f38bff182aaef4dd643625563b2f311c635f86cf4b769b04d7b48", - strip_prefix = "oneDNN-3.3", - urls = tf_mirror_urls("https://github.com/oneapi-src/oneDNN/archive/refs/tags/v3.3.tar.gz"), + sha256 = "e291fa4702f4bcfa6c8c23cb5b6599f0fefa8f23bc08edb9e15ddc5254ab7843", + strip_prefix = "oneDNN-3.3.4", + urls = tf_mirror_urls("https://github.com/oneapi-src/oneDNN/archive/refs/tags/v3.3.4.tar.gz"), ) tf_http_archive( @@ -608,12 +608,13 @@ def _tf_repositories(): urls = tf_mirror_urls("https://github.com/NVlabs/cub/archive/1.9.9.zip"), ) + # Note that we are currently taking NVTX headers from a NCCL release to get nvToolsExtPayload.h tf_http_archive( name = "nvtx_archive", - build_file = "//third_party:nvtx.BUILD", - sha256 = "bb8d1536aad708ec807bc675e12e5838c2f84481dec4005cd7a9bbd49e326ba1", - strip_prefix = "NVTX-3.0.1/c/include", - urls = tf_mirror_urls("https://github.com/NVIDIA/NVTX/archive/v3.0.1.tar.gz"), + build_file = "//third_party:nvtx/BUILD", + sha256 = "1c5474553afedb88e878c772f13d6f90b9226b3f2971dfa6f873adb9443100c2", + strip_prefix = "nccl-2.19.3-1/src/include/nvtx3", + urls = tf_mirror_urls("https://github.com/nvidia/nccl/archive/v2.19.3-1.tar.gz"), ) tf_http_archive( diff --git a/tensorflow/workspace3.bzl b/tensorflow/workspace3.bzl index af1613994a749c..7d187724bb1e47 100644 --- a/tensorflow/workspace3.bzl +++ b/tensorflow/workspace3.bzl @@ -34,10 +34,10 @@ def workspace(): http_archive( name = "rules_license", urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/rules_license/releases/download/0.0.4/rules_license-0.0.4.tar.gz", - "https://github.com/bazelbuild/rules_license/releases/download/0.0.4/rules_license-0.0.4.tar.gz", + "https://mirror.bazel.build/github.com/bazelbuild/rules_license/releases/download/0.0.7/rules_license-0.0.7.tar.gz", + "https://github.com/bazelbuild/rules_license/releases/download/0.0.7/rules_license-0.0.7.tar.gz", ], - sha256 = "6157e1e68378532d0241ecd15d3c45f6e5cfd98fc10846045509fb2a7cc9e381", + sha256 = "4531deccb913639c30e5c7512a054d5d875698daeb75d8cf90f284375fe7c360", ) http_archive( diff --git a/third_party/cudnn_frontend_header_fix.patch b/third_party/cudnn_frontend_header_fix.patch index af22372c66009e..70476bd3ff5d56 100644 --- a/third_party/cudnn_frontend_header_fix.patch +++ b/third_party/cudnn_frontend_header_fix.patch @@ -1,234 +1,13 @@ -diff --git a/include/cudnn_backend_base.h b/include/cudnn_backend_base.h -index 1240282..cba52ec 100644 ---- a/include/cudnn_backend_base.h -+++ b/include/cudnn_backend_base.h -@@ -24,7 +24,7 @@ - - #include - --#include -+#include "third_party/gpus/cudnn/cudnn.h" - - namespace cudnn_frontend { - -diff --git a/include/cudnn_frontend_ConvDesc.h b/include/cudnn_frontend_ConvDesc.h -index 6e1d7ab..4deec88 100644 ---- a/include/cudnn_frontend_ConvDesc.h -+++ b/include/cudnn_frontend_ConvDesc.h -@@ -29,8 +29,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_utils.h" - -diff --git a/include/cudnn_frontend_Engine.h b/include/cudnn_frontend_Engine.h -index b95efb8..867541e 100644 ---- a/include/cudnn_frontend_Engine.h -+++ b/include/cudnn_frontend_Engine.h -@@ -30,8 +30,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_OperationGraph.h" - #include "cudnn_frontend_utils.h" -diff --git a/include/cudnn_frontend_EngineConfig.h b/include/cudnn_frontend_EngineConfig.h -index 973e777..97f0883 100644 ---- a/include/cudnn_frontend_EngineConfig.h -+++ b/include/cudnn_frontend_EngineConfig.h -@@ -29,8 +29,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_Engine.h" - #include "cudnn_frontend_utils.h" -diff --git a/include/cudnn_frontend_EngineFallbackList.h b/include/cudnn_frontend_EngineFallbackList.h -index 4d4e5be..6390bc5 100644 ---- a/include/cudnn_frontend_EngineFallbackList.h -+++ b/include/cudnn_frontend_EngineFallbackList.h -@@ -22,7 +22,7 @@ - - #pragma once - --#include -+#include "third_party/gpus/cudnn/cudnn.h" - #include - #include "cudnn_frontend_Heuristics.h" - -diff --git a/include/cudnn_frontend_ExecutionPlan.h b/include/cudnn_frontend_ExecutionPlan.h -index afceeb3..3d426e2 100644 ---- a/include/cudnn_frontend_ExecutionPlan.h -+++ b/include/cudnn_frontend_ExecutionPlan.h -@@ -30,8 +30,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_EngineConfig.h" - #include "cudnn_frontend_Engine.h" -diff --git a/include/cudnn_frontend_Filters.h b/include/cudnn_frontend_Filters.h -index 676f0f2..4d1c020 100644 ---- a/include/cudnn_frontend_Filters.h -+++ b/include/cudnn_frontend_Filters.h -@@ -22,7 +22,7 @@ - - #pragma once - --#include -+#include "third_party/gpus/cudnn/cudnn.h" - - namespace cudnn_frontend { - -diff --git a/include/cudnn_frontend_Heuristics.h b/include/cudnn_frontend_Heuristics.h -index dda3fb3..3e89237 100644 ---- a/include/cudnn_frontend_Heuristics.h -+++ b/include/cudnn_frontend_Heuristics.h -@@ -25,8 +25,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_OperationGraph.h" - #include "cudnn_frontend_EngineConfig.h" -diff --git a/include/cudnn_frontend_MatMulDesc.h b/include/cudnn_frontend_MatMulDesc.h -index c9258ba..141f2f9 100644 ---- a/include/cudnn_frontend_MatMulDesc.h -+++ b/include/cudnn_frontend_MatMulDesc.h -@@ -29,8 +29,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_utils.h" - -diff --git a/include/cudnn_frontend_Operation.h b/include/cudnn_frontend_Operation.h -index bf16cfa..f3086e1 100644 ---- a/include/cudnn_frontend_Operation.h -+++ b/include/cudnn_frontend_Operation.h -@@ -30,8 +30,8 @@ - #include - #include +diff --git a/include/cudnn_frontend.h b/include/cudnn_frontend.h +index 0f0d5a6..802bcbb 100644 +--- a/include/cudnn_frontend.h ++++ b/include/cudnn_frontend.h +@@ -97,7 +97,7 @@ + * - Simpler samples on how to use the new API. + */ -#include --#include +#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" #include "cudnn_frontend_ConvDesc.h" - #include "cudnn_frontend_PointWiseDesc.h" -diff --git a/include/cudnn_frontend_OperationGraph.h b/include/cudnn_frontend_OperationGraph.h -index c5e2704..71589b2 100644 ---- a/include/cudnn_frontend_OperationGraph.h -+++ b/include/cudnn_frontend_OperationGraph.h -@@ -30,8 +30,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_Operation.h" - #include "cudnn_frontend_utils.h" -diff --git a/include/cudnn_frontend_PointWiseDesc.h b/include/cudnn_frontend_PointWiseDesc.h -index afa71ce..56b6507 100644 ---- a/include/cudnn_frontend_PointWiseDesc.h -+++ b/include/cudnn_frontend_PointWiseDesc.h -@@ -30,8 +30,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_utils.h" - -diff --git a/include/cudnn_frontend_ReductionDesc.h b/include/cudnn_frontend_ReductionDesc.h -index 5df2c5e..419fc93 100644 ---- a/include/cudnn_frontend_ReductionDesc.h -+++ b/include/cudnn_frontend_ReductionDesc.h -@@ -29,8 +29,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_utils.h" - -diff --git a/include/cudnn_frontend_Resample.h b/include/cudnn_frontend_Resample.h -index 351e2da..b1a1904 100644 ---- a/include/cudnn_frontend_Resample.h -+++ b/include/cudnn_frontend_Resample.h -@@ -29,8 +29,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_utils.h" - -diff --git a/include/cudnn_frontend_Rng.h b/include/cudnn_frontend_Rng.h -index 9d4e6ca..4224b61 100644 ---- a/include/cudnn_frontend_Rng.h -+++ b/include/cudnn_frontend_Rng.h -@@ -29,8 +29,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_utils.h" - -diff --git a/include/cudnn_frontend_VariantPack.h b/include/cudnn_frontend_VariantPack.h -index 455ab8b..4173860 100644 ---- a/include/cudnn_frontend_VariantPack.h -+++ b/include/cudnn_frontend_VariantPack.h -@@ -30,8 +30,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_utils.h" - + #include "cudnn_frontend_Heuristics.h" diff --git a/third_party/curl.BUILD b/third_party/curl.BUILD index 8dcd54451dae66..b31d8488aaa0ba 100644 --- a/third_party/curl.BUILD +++ b/third_party/curl.BUILD @@ -14,11 +14,6 @@ CURL_WIN_COPTS = [ "/DCURL_DISABLE_PROXY", "/DHAVE_LIBZ", "/DHAVE_ZLIB_H", - # Defining _USING_V110_SDK71_ is hackery to defeat curl's incorrect - # detection of what OS releases we can build on with VC 2012. This - # may not be needed (or may have to change) if the WINVER setting - # changes in //third_party/msvc/vc_12_0/CROSSTOOL. - "/D_USING_V110_SDK71_", ] CURL_WIN_SRCS = [ diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl index 0da1d7b58f4bb0..74fafb9b32f516 100755 --- a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl +++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -41,7 +41,7 @@ import os import subprocess import re import sys -import pipes +import shlex # Template values set by cuda_autoconf. CPU_COMPILER = ('%{cpu_compiler}') @@ -299,7 +299,7 @@ def main(): if args.x and args.x[0] == 'cuda': if args.cuda_log: Log('-x cuda') - leftover = [pipes.quote(s) for s in leftover] + leftover = [shlex.quote(s) for s in leftover] if args.cuda_log: Log('using nvcc') return InvokeNvcc(leftover, log=args.cuda_log) diff --git a/third_party/gpus/cuda/build_defs.bzl.tpl b/third_party/gpus/cuda/build_defs.bzl.tpl index 189d3e3e784003..bc865cecb3240a 100644 --- a/third_party/gpus/cuda/build_defs.bzl.tpl +++ b/third_party/gpus/cuda/build_defs.bzl.tpl @@ -94,6 +94,25 @@ def if_cuda_is_configured(x, no_cuda = []): return select({"//conditions:default": x}) return select({"//conditions:default": no_cuda}) +def if_cuda_newer_than(wanted_ver, if_true, if_false = []): + """Tests if CUDA was enabled during the configured process and if the + configured version is at least `wanted_ver`. `wanted_ver` needs + to be provided as a string in the format `_`. + Example: `11_0` + """ + + wanted_major = int(wanted_ver.split('_')[0]) + wanted_minor = int(wanted_ver.split('_')[1]) + + configured_version = "%{cuda_version}" + configured_major = int(configured_version.split('.')[0]) + configured_minor = int(configured_version.split('.')[1]) + + if %{cuda_is_configured} and (wanted_major, wanted_minor) <= (configured_major, configured_minor): + return select({"//conditions:default": if_true}) + return select({"//conditions:default": if_false}) + + def cuda_header_library( name, hdrs, diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index bc692500836410..5bf0504dc91bcc 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -827,6 +827,7 @@ def _create_dummy_repository(repository_ctx): "%{cuda_is_configured}": "False", "%{cuda_extra_copts}": "[]", "%{cuda_gpu_architectures}": "[]", + "%{cuda_version}": "0.0", }, ) _tpl( @@ -1214,6 +1215,7 @@ def _create_local_cuda_repository(repository_ctx): cuda_config.compute_capabilities, ), "%{cuda_gpu_architectures}": str(cuda_config.compute_capabilities), + "%{cuda_version}": cuda_config.cuda_version, }, ) @@ -1427,6 +1429,7 @@ def _create_remote_cuda_repository(repository_ctx, remote_config_repo): repository_ctx, compute_capabilities(repository_ctx), ), + "%{cuda_version}": get_host_environ(repository_ctx, _TF_CUDA_VERSION), }, ) repository_ctx.template( diff --git a/third_party/gpus/rocm/build_defs.bzl.tpl b/third_party/gpus/rocm/build_defs.bzl.tpl index 2b4595bb222885..339733755d6f1f 100644 --- a/third_party/gpus/rocm/build_defs.bzl.tpl +++ b/third_party/gpus/rocm/build_defs.bzl.tpl @@ -38,6 +38,16 @@ def rocm_version_number(): """Returns a list of supported GPU architectures.""" return %{rocm_version_number} +def if_gpu_is_configured(if_true, if_false = []): + """Tests if ROCm or CUDA was enabled during the configure process. + + Unlike if_rocm() or if_cuda(), this does not require that we are building + with --config=rocm or --config=cuda, respectively. Used to allow non-GPU + code to depend on ROCm or CUDA libraries. + + """ + return select({"//conditions:default": %{gpu_is_configured}}) + def if_rocm_is_configured(x): """Tests if the ROCm was enabled during the configure process. diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl index 520c9bce6c5265..a83755f0c17f80 100644 --- a/third_party/gpus/rocm_configure.bzl +++ b/third_party/gpus/rocm_configure.bzl @@ -10,6 +10,7 @@ load( ":cuda_configure.bzl", + "enable_cuda", "make_copy_dir_rule", "make_copy_files_rule", "to_list_of_strings", @@ -449,6 +450,7 @@ def _create_dummy_repository(repository_ctx): "rocm:build_defs.bzl", { "%{rocm_is_configured}": "False", + "%{gpu_is_configured}": "if_true" if enable_cuda(repository_ctx) else "if_false", "%{rocm_extra_copts}": "[]", "%{rocm_gpu_architectures}": "[]", "%{rocm_version_number}": "0", @@ -634,6 +636,7 @@ def _create_local_rocm_repository(repository_ctx): tpl_paths["rocm:build_defs.bzl"], { "%{rocm_is_configured}": "True", + "%{gpu_is_configured}": "if_true", "%{rocm_extra_copts}": _compute_rocm_extra_copts( repository_ctx, rocm_config.amdgpu_targets, @@ -762,6 +765,7 @@ def _create_remote_rocm_repository(repository_ctx, remote_config_repo): "rocm:build_defs.bzl", { "%{rocm_is_configured}": "True", + "%{gpu_is_configured}": "if_true", "%{rocm_extra_copts}": _compute_rocm_extra_copts( repository_ctx, [], #_compute_capabilities(repository_ctx) @@ -815,6 +819,7 @@ _ENVIRONS = [ _GCC_HOST_COMPILER_PATH, _GCC_HOST_COMPILER_PREFIX, "TF_NEED_ROCM", + "TF_NEED_CUDA", # Needed by the `if_gpu_is_configured` macro _ROCM_TOOLKIT_PATH, _TF_ROCM_AMDGPU_TARGETS, ] diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index ad6be050767281..7394824e266d75 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "c9a6e993f7b349405b6c8f9244cd9cf0f56a6a81" - LLVM_SHA256 = "fcf63e5fa636345867bb699ff0e134c21eb1d1d60328cf0eed8ab3c911b6b19e" + LLVM_COMMIT = "e899641df2391179e8ec29ca14c53b09ae7ce85c" + LLVM_SHA256 = "d94296fbfde8ba0bc80a917c816017258b83568dd627cf5f02be1f927e8472bb" tf_http_archive( name = name, diff --git a/third_party/mkl_dnn/mkldnn_v1.BUILD b/third_party/mkl_dnn/mkldnn_v1.BUILD index dee86c455cbf82..add18406251df9 100644 --- a/third_party/mkl_dnn/mkldnn_v1.BUILD +++ b/third_party/mkl_dnn/mkldnn_v1.BUILD @@ -14,9 +14,8 @@ _CMAKE_COMMON_LIST = { "#cmakedefine DNNL_SYCL_CUDA": "#undef DNNL_SYCL_CUDA", "#cmakedefine DNNL_SYCL_HIP": "#undef DNNL_SYCL_HIP", "#cmakedefine DNNL_ENABLE_STACK_CHECKER": "#undef DNNL_ENABLE_STACK_CHECKER", - "#cmakedefine ONEDNN_BUILD_GRAPH": "#undef ONEDNN_BUILD_GRAPH", - "#cmakedefine DNNL_EXPERIMENTAL_SPARSE": "#define DNNL_EXPERIMENTAL_SPARSE", "#cmakedefine DNNL_EXPERIMENTAL": "#undef DNNL_EXPERIMENTAL", + "#cmakedefine ONEDNN_BUILD_GRAPH": "#undef ONEDNN_BUILD_GRAPH", "#cmakedefine01 BUILD_TRAINING": "#define BUILD_TRAINING 1", "#cmakedefine01 BUILD_INFERENCE": "#define BUILD_INFERENCE 0", "#cmakedefine01 BUILD_PRIMITIVE_ALL": "#define BUILD_PRIMITIVE_ALL 1", @@ -96,7 +95,7 @@ expand_template( substitutions = { "@DNNL_VERSION_MAJOR@": "3", "@DNNL_VERSION_MINOR@": "3", - "@DNNL_VERSION_PATCH@": "0", + "@DNNL_VERSION_PATCH@": "4", "@DNNL_VERSION_HASH@": "N/A", }, template = "include/oneapi/dnnl/dnnl_version.h.in", diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 8a53bd5f19638f..c8aa227b5a35ef 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1,7 +1,7 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt --- stablehlo/CMakeLists.txt +++ stablehlo/CMakeLists.txt -@@ -13,131 +13,20 @@ +@@ -13,153 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # @@ -25,6 +25,11 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt -if(POLICY CMP0116) - cmake_policy(SET CMP0116 OLD) -endif() +- +-# Support for return(PROPAGATE ...) in functions. +-if (POLICY CMP0140) +- cmake_policy(SET CMP0140 NEW) +-endif() +# This build of StableHLO is meant to be embedded in MLIR-HLO. +# As a result, its root CMakeLists.txt is different from the original +# CMakeLists.txt from https://github.com/openxla/stablehlo. @@ -39,6 +44,9 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt -option(STABLEHLO_BUILD_EMBEDDED "Build StableHLO as part of another project" OFF) -option(STABLEHLO_ENABLE_BINDINGS_PYTHON "Enables StableHLO Python bindings" OFF) -option(STABLEHLO_ENABLE_STRICT_BUILD "Build StableHLO with strict warnings and warnings as errors" OFF) +-option(STABLEHLO_ENABLE_SANITIZER "Enable a sanitizer [OFF, address]" OFF) +-option(STABLEHLO_ENABLE_SPLIT_DWARF "Enable split DWARF if the platform supports it" OFF) +-option(STABLEHLO_ENABLE_LLD "Use LLD as the linker if available" OFF) -#------------------------------------------------------------------------------- -# Project setup and globals @@ -55,29 +63,6 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt - set(CMAKE_CXX_STANDARD 17) -endif() - --# Build with ccache if the package is present --set(LLVM_CCACHE_BUILD OFF CACHE BOOL "Set to ON for a ccache enabled build") --if(LLVM_CCACHE_BUILD) -- find_program(CCACHE_PROGRAM ccache) -- if(CCACHE_PROGRAM) -- set(LLVM_CCACHE_MAXSIZE "" CACHE STRING "Size of ccache") -- set(LLVM_CCACHE_DIR "" CACHE STRING "Directory to keep ccached data") -- set(LLVM_CCACHE_PARAMS "CCACHE_CPP2=yes CCACHE_HASHDIR=yes" -- CACHE STRING "Parameters to pass through to ccache") -- -- set(CCACHE_PROGRAM "${LLVM_CCACHE_PARAMS} ${CCACHE_PROGRAM}") -- if (LLVM_CCACHE_MAXSIZE) -- set(CCACHE_PROGRAM "CCACHE_MAXSIZE=${LLVM_CCACHE_MAXSIZE} ${CCACHE_PROGRAM}") -- endif() -- if (LLVM_CCACHE_DIR) -- set(CCACHE_PROGRAM "CCACHE_DIR=${LLVM_CCACHE_DIR} ${CCACHE_PROGRAM}") -- endif() -- set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ${CCACHE_PROGRAM}) -- else() -- message(FATAL_ERROR "Unable to find the program ccache. Set LLVM_CCACHE_BUILD to OFF") -- endif() --endif() -- -#------------------------------------------------------------------------------- -# MLIR/LLVM Configuration -#------------------------------------------------------------------------------- @@ -114,10 +99,39 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt - message(STATUS "Building StableHLO embedded in another project") -endif() - +-# Add the CMake modules specific to StableHLO +-list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/cmake") +- -if(LLVM_ENABLE_ZLIB) - find_package(ZLIB) -endif() - +-#------------------------------------------------------------------------------- +-# Performance configuration +-#------------------------------------------------------------------------------- +- +-include(CheckCXXCompilerFlag) +-include(CheckLinkerFlag) +-if (STABLEHLO_ENABLE_LLD) +- message(STATUS "Enabling LLD as the linker") +- add_link_options("-fuse-ld=lld") +-endif() +- +-if(STABLEHLO_ENABLE_SPLIT_DWARF) +- check_cxx_compiler_flag(-gsplit-dwarf STABLEHLO_SUPPORTS_SPLIT_DWARF) +- if (STABLEHLO_SUPPORTS_SPLIT_DWARF) +- message(STATUS "Enabling split-dwarf build") +- add_compile_options(-gsplit-dwarf -ggnu-pubnames) +- endif() +- check_linker_flag(CXX "-Wl,--gdb-index" STABLEHLO_SUPPORTS_GDB_INDEX) +- # If we set LLD it doesn't seem to affect the check_linker_flag above. +- # Account for it with the generator expression OR +- if (STABLEHLO_SUPPORTS_GDB_INDEX OR STABLEHLO_ENABLE_LLD) +- message(STATUS "Enabling GDB index in binary") +- add_link_options("-Wl,--gdb-index") +- endif() +-endif() +- -include(TableGen) -include(AddLLVM) -include(AddMLIR) @@ -129,6 +143,14 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt -link_directories(${LLVM_BUILD_LIBRARY_DIR}) -add_definitions(${LLVM_DEFINITIONS}) - +- +-#------------------------------------------------------------------------------- +-# Sanitizer configuration +-#------------------------------------------------------------------------------- +- +-include(SetupSanitizers) +-setup_sanitizers() +- -#------------------------------------------------------------------------------- -# Python configuration -#------------------------------------------------------------------------------- @@ -141,6 +163,27 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt #------------------------------------------------------------------------------- # Directory setup +diff --ruN a/stablehlo/MODULE.bazel.lock b/stablehlo/MODULE.bazel.lock +--- stablehlo/MODULE.bazel.lock ++++ stablehlo/MODULE.bazel.lock +@@ -1,3 +1,17 @@ ++# Copyright 2024 The StableHLO Authors. All Rights Reserved. ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# https://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++ + { + "lockFileVersion": 3, + "moduleFileHash": "836f0a7d2276ed93403f104a10008b94ec7e7f81b8d6921cea287f0a6d364efa", diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists.txt --- stablehlo/stablehlo/CMakeLists.txt +++ stablehlo/stablehlo/CMakeLists.txt @@ -152,18 +195,927 @@ diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists add_subdirectory(integrations) add_subdirectory(reference) add_subdirectory(tests) -diff --ruN a/stablehlo/stablehlo/api/PortableApi.h b/stablehlo/stablehlo/api/PortableApi.h ---- stablehlo/stablehlo/api/PortableApi.h -+++ stablehlo/stablehlo/api/PortableApi.h -@@ -27,7 +27,7 @@ +diff --ruN a/stablehlo/stablehlo/dialect/AssemblyFormat.cpp b/stablehlo/stablehlo/dialect/AssemblyFormat.cpp +--- stablehlo/stablehlo/dialect/AssemblyFormat.cpp ++++ stablehlo/stablehlo/dialect/AssemblyFormat.cpp +@@ -16,15 +16,28 @@ + #include "stablehlo/dialect/AssemblyFormat.h" + + #include ++#include + #include - /// Return the current version for portable API. - /// Increments on all meaningful changes to this file. --inline int64_t getApiVersion() { return 5; } -+inline int64_t getApiVersion() { return 6; } + #include "llvm/ADT/ArrayRef.h" + #include "llvm/ADT/STLExtras.h" ++#include "llvm/ADT/StringExtras.h" ++#include "llvm/Support/Debug.h" + #include "llvm/Support/ErrorHandling.h" + #include "llvm/Support/Regex.h" ++#include "llvm/Support/SMLoc.h" ++#include "mlir/IR/Builders.h" + #include "mlir/IR/BuiltinAttributes.h" + #include "mlir/IR/BuiltinTypeInterfaces.h" ++#include "mlir/IR/OpImplementation.h" ++#include "mlir/IR/OperationSupport.h" ++#include "mlir/IR/Region.h" ++#include "mlir/IR/TypeUtilities.h" ++#include "mlir/IR/ValueRange.h" ++#include "mlir/Support/LLVM.h" + #include "mlir/Support/LogicalResult.h" ++ ++#define DEBUG_TYPE "hlo-assembly" + + namespace mlir { + namespace hlo { +@@ -212,6 +225,343 @@ + return success(); + } - // Get the current StableHLO version. ++namespace { ++void createArgs(ArrayRef operands, ++ ArrayRef types, ++ SmallVector& args) { ++ for (auto argAndType : llvm::zip(operands, types)) { ++ auto& arg = args.emplace_back(); ++ arg.ssaName = std::get<0>(argAndType); ++ arg.type = std::get<1>(argAndType); ++ } ++} ++ ++Operation* createReturn(OpBuilder& builder, Dialect* dialect, Location loc, ++ ResultRange operands) { ++ auto returnOpName = dialect->getNamespace() + ".return"; ++ OperationState returnOpState(loc, returnOpName.str()); ++ returnOpState.operands.append(operands.begin(), operands.end()); ++ return builder.create(returnOpState); ++} ++ ++bool hasSameOperandAndResultTypes(Operation& op) { ++ Type expected; ++ if (op.getNumResults() != 0) expected = op.getResult(0).getType(); ++ if (op.getNumOperands() != 0) expected = op.getOperand(0).getType(); ++ if (!expected) return false; ++ ++ auto typeMatch = [&](Type actual) { return actual == expected; }; ++ return llvm::all_of(op.getOperandTypes(), typeMatch) && ++ llvm::all_of(op.getResultTypes(), typeMatch); ++} ++ ++// Checks the following eligibility criteria for compact printing of reduce: ++// E1. The reduce-op wraps a single inner-op in the associated region. ++// E2. The single operation is a commutative binary-op from the dialect, zero ++// region, producing single result such that the operands and result all ++// have the same type. ++// E3. The reduce-op consist of at least one input-operand; The operand-types of ++// inner-op should be derived trivially from the element-type of reduce-op's ++// first input-operand. ++// E4. The arguments of the region's only basic block are forwarded perfectly ++// to inner-op's operands. ++// E5. The single operation result is perfectly forwarded to the reduce op ++// return. ++static bool isReduceEligibleForCompactPrint(Operation* op, ValueRange inputs, ++ Region& body) { ++ // Check E1. ++ LLVM_DEBUG(llvm::dbgs() << "Checking ReduceOp compact print E1\n"); ++ auto& block = body.front(); ++ if (!hasSingleElement(block.without_terminator())) return false; ++ ++ Operation& innerOp = *block.begin(); ++ ++ // Check E2. ++ LLVM_DEBUG(llvm::dbgs() << "Checking ReduceOp compact print E2\n"); ++ if (innerOp.getDialect() != op->getDialect()) return false; ++ ++ if (innerOp.getNumOperands() != 2 || ++ !innerOp.hasTrait() || ++ !hasSameOperandAndResultTypes(innerOp) || ++ (!innerOp.hasTrait() && ++ !innerOp.hasTrait()) || ++ !innerOp.hasTrait()) ++ return false; ++ ++ // Check E3. ++ LLVM_DEBUG(llvm::dbgs() << "Checking ReduceOp compact print E3\n"); ++ if (inputs.empty()) return false; ++ ++ auto elemType = inputs[0].getType().cast().getElementType(); ++ auto expectedInnerOpType = RankedTensorType::get(/*shape=*/{}, elemType); ++ if (innerOp.getOperands()[0].getType() != expectedInnerOpType) return false; ++ ++ // Check E4. ++ LLVM_DEBUG(llvm::dbgs() << "Checking ReduceOp compact print E4\n"); ++ if (!llvm::equal(block.getArguments(), innerOp.getOperands())) return false; ++ ++ // Check E5. ++ LLVM_DEBUG(llvm::dbgs() << "Checking ReduceOp compact print E5\n"); ++ auto retOp = block.getTerminator(); ++ if (!retOp->getName().stripDialect().equals("return")) return false; ++ ++ return llvm::equal(innerOp.getResults(), retOp->getOperands()); ++} ++} // namespace ++ ++void printReduceOp(OpAsmPrinter& p, Operation* op, ValueRange inputs, ++ ArrayRef dimensions, Region& body) { ++ { ++ // Print the pairs of operands under the form: ++ // (%arg0 init: %arg3), (%arg1 init: %arg4), (%arg2 init: %arg5) ++ StringRef comma = ""; ++ int numOperandPairs = op->getNumOperands() / 2; ++ for (int opId : llvm::seq(0, numOperandPairs)) { ++ p << comma << "(" << op->getOperand(opId) ++ << " init: " << op->getOperand(opId + numOperandPairs) << ")"; ++ comma = ", "; ++ } ++ } ++ ++ // If the reduce-op is eligible for compact printing, we emit the one-liner: ++ // stablehlo.reduce applies across dimensions = [...] : ++ // Note: We are not printing the function type of reduction operation. We ++ // have some simplifying assumptions (refer to IsEligibleForCompactPrint::E3) ++ // to derive the type from that of reduce-op. ++ if (isReduceEligibleForCompactPrint(op, inputs, body)) { ++ Operation& innerOp = body.front().front(); ++ p << " applies "; ++ llvm::printEscapedString(innerOp.getName().getStringRef(), p.getStream()); ++ p << " across dimensions = ["; ++ llvm::interleaveComma(dimensions, p); ++ p << "]"; ++ p.printOptionalAttrDict(op->getAttrs(), {"dimensions"}); ++ p << " : "; ++ p.printFunctionalType(op); ++ } else { ++ p << " across dimensions = ["; ++ llvm::interleaveComma(dimensions, p); ++ p << "]"; ++ p.printOptionalAttrDict(op->getAttrs(), {"dimensions"}); ++ p << " : "; ++ p.printFunctionalType(op); ++ p.printNewline(); ++ p << " reducer"; ++ { ++ // Print the pairs of block operands under the form: ++ // (%arg0_elt, %arg0_acc) (%arg1_elt, %arg1_acc): ++ Block& reducer = body.front(); ++ int numOperandPairs = op->getNumOperands() / 2; ++ for (int opId : llvm::seq(0, numOperandPairs)) { ++ p << "("; ++ p.printRegionArgument(reducer.getArgument(opId)); ++ p << ", "; ++ p.printRegionArgument(reducer.getArgument(opId + numOperandPairs)); ++ p << ") "; ++ } ++ } ++ p << ' '; ++ p.printRegion(body, /*printEntryBlockArgs=*/false); ++ } ++} ++ ++ParseResult parseReduceOp( ++ OpAsmParser& parser, OperationState& result, ++ std::function)> createDimensions) { ++ llvm::SMLoc loc = parser.getCurrentLocation(); ++ Location currLocation = parser.getEncodedSourceLoc(loc); ++ ++ // Parse the operands of reduce-op, this is a list of pair under the form: ++ // (%arg0 init: %arg3), (%arg1 init: %arg4), (%arg2 init: %arg5) ++ // Each input to reduce is paired with its init value, even though in memory ++ // they are stored with the input first and the init values after. ++ SmallVector operands; ++ SmallVector initOperands; ++ do { ++ (void)parser.parseOptionalComma(); ++ if (parser.parseOptionalLParen()) break; ++ OpAsmParser::UnresolvedOperand operand, initOperand; ++ if (parser.parseOperand(operand) || parser.parseKeyword("init") || ++ parser.parseColon() || parser.parseOperand(initOperand) || ++ parser.parseRParen()) ++ return failure(); ++ operands.push_back(operand); ++ initOperands.push_back(initOperand); ++ } while (true); ++ operands.append(initOperands); ++ ++ // Check if we are parsing the compact version of reduce-op: ++ // stablehlo.reduce applies across dimensions = [...] : ++ // else parse the "region-based" variant. ++ if (failed(parser.parseOptionalKeyword("applies"))) { ++ // Parse the inner-op dimensions, reduce-op's function-type and ++ // optional location. ++ SmallVector dimensions; ++ auto parseDim = [&]() -> ParseResult { ++ if (parser.parseInteger(dimensions.emplace_back())) return failure(); ++ return success(); ++ }; ++ ++ FunctionType reduceOpFnType; ++ if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") || ++ parser.parseEqual() || ++ parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, ++ parseDim) || ++ parser.parseOptionalAttrDict(result.attributes) || ++ parser.parseColon() || parser.parseType(reduceOpFnType) || ++ parser.parseKeyword("reducer")) ++ return failure(); ++ OpBuilder builder(parser.getBuilder().getContext()); ++ result.addAttribute("dimensions", createDimensions(builder, dimensions)); ++ ++ // Parse the "reducer" region now. ++ SmallVector reducerOperands; ++ SmallVector reducerInitOperands; ++ SmallVector reducerTypes; ++ SmallVector reducerInitTypes; ++ SmallVector, 2> reducerLocs; ++ SmallVector, 2> reducerInitLocs; ++ auto parseBlockOperand = ++ [&](SmallVectorImpl& operands, ++ SmallVectorImpl& types, ++ SmallVectorImpl>& locs) -> ParseResult { ++ OpAsmParser::UnresolvedOperand operand; ++ Type type; ++ std::optional loc; ++ if (parser.parseOperand(operand, /*allowResultNumber=*/false) || ++ parser.parseColon() || parser.parseType(type) || ++ parser.parseOptionalLocationSpecifier(loc)) ++ return failure(); ++ operands.push_back(operand); ++ types.push_back(type); ++ locs.push_back(loc); ++ return success(); ++ }; ++ do { ++ if (failed(parser.parseOptionalLParen())) break; ++ if (parseBlockOperand(reducerOperands, reducerTypes, reducerLocs) || ++ parser.parseComma() || ++ parseBlockOperand(reducerInitOperands, reducerInitTypes, ++ reducerInitLocs) || ++ parser.parseRParen()) ++ return failure(); ++ } while (true); ++ reducerOperands.append(reducerInitOperands); ++ reducerTypes.append(reducerInitTypes); ++ reducerLocs.append(reducerInitLocs); ++ result.addTypes(reduceOpFnType.getResults()); ++ SmallVector reducerArgs; ++ createArgs(reducerOperands, reducerTypes, reducerArgs); ++ ++ // Derive the SSA-values for reduce-op's operands and parse the region, and ++ // the optional trailing location. ++ std::optional trailingLoc; ++ if (parser.resolveOperands(operands, reduceOpFnType.getInputs(), loc, ++ result.operands) || ++ parser.parseRegion(*result.addRegion(), reducerArgs)) ++ return failure(); ++ // Set the individual block arguments. ++ for (auto argAndLoc : ++ llvm::zip(result.regions.front()->front().getArguments(), reducerLocs)) ++ if (std::get<1>(argAndLoc)) ++ std::get<0>(argAndLoc).setLoc(std::get<1>(argAndLoc).value()); ++ result.location = trailingLoc.value_or(currLocation); ++ return success(); ++ } ++ ++ // Parse the inner-op name and check if the contract on inner-op ++ // mentioned in "isEligibleForCompactPrint::E2" for pretty-printing is met. ++ FailureOr innerOpNameInfo = parser.parseCustomOperationName(); ++ if (failed(innerOpNameInfo)) return failure(); ++ ++ StringRef innerOpName = innerOpNameInfo->getStringRef(); ++ Dialect* innerOpDialect = innerOpNameInfo->getDialect(); ++ StringRef reduceOpDialect = result.name.getDialectNamespace(); ++ LLVM_DEBUG(llvm::dbgs() << "Reduce: " << reduceOpDialect << "\n"); ++ LLVM_DEBUG(llvm::dbgs() << "inner: " << innerOpDialect->getNamespace() ++ << "\n"); ++ if (!innerOpDialect || ++ !innerOpDialect->getNamespace().equals(reduceOpDialect) || ++ !innerOpNameInfo->hasTrait::Impl>() || ++ !innerOpNameInfo->hasTrait() || ++ (!innerOpNameInfo->hasTrait() && ++ !innerOpNameInfo->hasTrait()) || ++ !innerOpNameInfo->hasTrait()) { ++ parser.emitError(loc, ++ "expected the inner-op to be a commutative binary-op that " ++ "matching the reduce op dialect, with zero region, " ++ "producing single result"); ++ return failure(); ++ } ++ ++ // Parse the inner-op dimensions, reduce-op's function-type and ++ // optional location. ++ SmallVector dimensions; ++ auto parseDim = [&]() -> ParseResult { ++ if (parser.parseInteger(dimensions.emplace_back())) return failure(); ++ return success(); ++ }; ++ ++ std::optional explicitLoc; ++ FunctionType reduceOpFnType; ++ if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") || ++ parser.parseEqual() || ++ parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, parseDim) || ++ parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || ++ parser.parseType(reduceOpFnType) || ++ parser.parseOptionalLocationSpecifier(explicitLoc)) ++ return failure(); ++ ++ if (!reduceOpFnType || reduceOpFnType.getInputs().empty()) { ++ if (!reduceOpFnType) return parser.emitError(loc, "expected function type"); ++ return parser.emitError(loc, ++ "input types missing in reduce-op function type"); ++ } ++ ++ // If location of reduce-op is explicitly provided, then use it; Else use ++ // the parser's current location. ++ Location reduceOpLoc = explicitLoc.value_or(currLocation); ++ ++ // Derive the SSA-values for reduce-op's operands. ++ if (parser.resolveOperands(operands, reduceOpFnType.getInputs(), loc, ++ result.operands)) ++ return failure(); ++ ++ // Derive the type of inner-op from that of reduce-op's input operand. ++ auto innerOpType = RankedTensorType::get( ++ /*shape=*/{}, getElementTypeOrSelf(reduceOpFnType.getInput(0))); ++ ++ // Add a region for reduce-op. ++ Region& region = *result.addRegion(); ++ ++ // Create a basic-block inside reduce-op's region. ++ Block& block = region.emplaceBlock(); ++ auto lhs = block.addArgument(innerOpType, reduceOpLoc); ++ auto rhs = block.addArgument(innerOpType, reduceOpLoc); ++ ++ // Create and insert an "inner-op" operation in the block. ++ OpBuilder builder(parser.getBuilder().getContext()); ++ builder.setInsertionPointToStart(&block); ++ ++ OperationState innerOpState(reduceOpLoc, innerOpName); ++ innerOpState.operands.push_back(lhs); ++ innerOpState.operands.push_back(rhs); ++ innerOpState.addTypes(innerOpType); ++ ++ Operation* innerOp = builder.create(innerOpState); ++ ++ // Insert a return statement in the block returning the inner-op's result. ++ createReturn(builder, innerOp->getDialect(), innerOp->getLoc(), ++ innerOp->getResults()); ++ ++ // Populate the reduce-op operation-state with result-type, location, and ++ // dimension attribute. ++ result.addTypes(reduceOpFnType.getResults()); ++ result.location = innerOp->getLoc(); ++ result.addAttribute("dimensions", createDimensions(builder, dimensions)); ++ return success(); ++} ++ + void printSelectOpType(OpAsmPrinter& p, Operation* op, ShapedType pred, + ShapedType onTrue, ShapedType onFalse, + ShapedType result) { +@@ -250,6 +600,63 @@ + auto fnType = types[0].cast(); + return assignFromFunctionType(parser, loc, {&pred, &onTrue, &onFalse}, result, + fnType); ++} ++ ++void printWhileOp(OpAsmPrinter& p, Operation* op, Region& cond, Region& body) { ++ p << '('; ++ llvm::interleaveComma(llvm::zip(body.getArguments(), op->getOperands()), p, ++ [&](auto zip) { ++ p.printOperand(std::get<0>(zip)); ++ p << " = "; ++ p.printOperand(std::get<1>(zip)); ++ }); ++ p << ")"; ++ if (op->getNumOperands()) { ++ p << " : "; ++ llvm::interleaveComma(op->getOperandTypes(), p); ++ } ++ p.printOptionalAttrDictWithKeyword(op->getAttrs()); ++ p.printNewline(); ++ p << " cond "; ++ p.printRegion(cond, /*printEntryBlockArgs=*/false); ++ p << " do "; ++ p.printRegion(body, /*printEntryBlockArgs=*/false); ++} ++ ++ParseResult parseWhileOp(OpAsmParser& parser, OperationState& result) { ++ llvm::SMLoc loc = parser.getCurrentLocation(); ++ // Parse the operands of the while: these are of the form: ++ // %iter_arg = %init_val ++ // where %iter_arg is the name of the block argument in the cond/body blocks ++ // and %init_val is the actual operand. ++ SmallVector operands; ++ SmallVector iterArgs; ++ if (parser.parseLParen()) return failure(); ++ do { ++ if (succeeded(parser.parseOptionalRParen())) break; ++ OpAsmParser::UnresolvedOperand operand, iterArg; ++ if (parser.parseOperand(iterArg) || parser.parseEqual() || ++ parser.parseOperand(operand)) ++ return failure(); ++ iterArgs.push_back(iterArg); ++ operands.push_back(operand); ++ if (succeeded(parser.parseOptionalRParen())) break; ++ if (failed(parser.parseComma())) return failure(); ++ } while (true); ++ if (!operands.empty()) { ++ if (parser.parseColon() || parser.parseTypeList(result.types)) ++ return failure(); ++ } ++ SmallVector args; ++ createArgs(iterArgs, result.types, args); ++ if (parser.resolveOperands(operands, result.types, loc, result.operands) || ++ parser.parseOptionalAttrDictWithKeyword(result.attributes) || ++ parser.parseKeyword("cond") || ++ parser.parseRegion(*result.addRegion(), args) || ++ parser.parseKeyword("do") || ++ parser.parseRegion(*result.addRegion(), args)) ++ return failure(); ++ return success(); + } + + //===----------------------------------------------------------------------===// +diff --ruN a/stablehlo/stablehlo/dialect/AssemblyFormat.h b/stablehlo/stablehlo/dialect/AssemblyFormat.h +--- stablehlo/stablehlo/dialect/AssemblyFormat.h ++++ stablehlo/stablehlo/dialect/AssemblyFormat.h +@@ -16,19 +16,25 @@ + #ifndef STABLEHLO_DIALECT_ASSEMBLYFORMAT_H + #define STABLEHLO_DIALECT_ASSEMBLYFORMAT_H + ++#include ++#include ++ + #include "llvm/ADT/ArrayRef.h" + #include "llvm/ADT/SmallVector.h" +-#include "llvm/ADT/StringRef.h" + #include "mlir/IR/Attributes.h" + #include "mlir/IR/Builders.h" + #include "mlir/IR/BuiltinAttributes.h" ++#include "mlir/IR/BuiltinTypeInterfaces.h" + #include "mlir/IR/Dialect.h" + #include "mlir/IR/DialectImplementation.h" +-#include "mlir/IR/MLIRContext.h" + #include "mlir/IR/OpImplementation.h" + #include "mlir/IR/Operation.h" ++#include "mlir/IR/OperationSupport.h" ++#include "mlir/IR/Region.h" + #include "mlir/IR/TypeRange.h" + #include "mlir/IR/Types.h" ++#include "mlir/IR/ValueRange.h" ++#include "mlir/Support/LLVM.h" + #include "mlir/Support/LogicalResult.h" + #include "stablehlo/dialect/Base.h" + +@@ -154,6 +160,15 @@ + ParseResult parseComplexOpType(OpAsmParser& parser, Type& lhs, Type& rhs, + Type& result); + ++// Print reduce with or without compact printing ++void printReduceOp(OpAsmPrinter& p, Operation* op, ValueRange inputs, ++ ArrayRef dimensions, Region& body); ++ ++// Parse reduce with or without compact parsing ++ParseResult parseReduceOp( ++ OpAsmParser& parser, OperationState& result, ++ std::function)> createDimensions); ++ + // SelectOpType - only print the condition and result type when branch types + // match the result type. // +@@ -170,15 +185,27 @@ + ParseResult parseSelectOpType(OpAsmParser& parser, Type& pred, Type& onTrue, + Type& onFalse, Type& result); + ++// Print a `while` op. ++// ++// op ::= `stablehlo.while` `(` assignment-list `)` `:` types attribute-dict ++// `cond` region ++// `do` region ++// assignment-list ::= assignment | assignment `,` assignment-list ++// assignment ::= ssa-value `=` ssa-value ++void printWhileOp(OpAsmPrinter& p, Operation* op, Region& cond, Region& body); ++ ++// Parse reduce with or without compact parsing ++ParseResult parseWhileOp(OpAsmParser& parser, OperationState& result); ++ + //===----------------------------------------------------------------------===// + // Attribute Printers and Parsers + //===----------------------------------------------------------------------===// + + // SliceRanges - Used to print multi-dimensional ranges for slice. + void printSliceRanges(OpAsmPrinter& p, Operation* op, +- ArrayRef startIndices, +- ArrayRef limitIndices, +- ArrayRef strides); ++ llvm::ArrayRef startIndices, ++ llvm::ArrayRef limitIndices, ++ llvm::ArrayRef strides); + + ParseResult parseSliceRanges(OpAsmParser& parser, + DenseI64ArrayAttr& startIndices, +diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.cpp b/stablehlo/stablehlo/dialect/StablehloOps.cpp +--- stablehlo/stablehlo/dialect/StablehloOps.cpp ++++ stablehlo/stablehlo/dialect/StablehloOps.cpp +@@ -99,16 +99,6 @@ + return dialect->getRegisteredInterface(); + } + +-void createArgs(ArrayRef operands, +- ArrayRef types, +- SmallVector& args) { +- for (auto argAndType : llvm::zip(operands, types)) { +- auto& arg = args.emplace_back(); +- arg.ssaName = std::get<0>(argAndType); +- arg.type = std::get<1>(argAndType); +- } +-} +- + // Returns a new scalar integer value having type `type`. Here `type` must be + // an integer or index type. + Value maybeCastTo(OpBuilder& b, Location loc, Value value, Type type) { +@@ -1472,305 +1462,16 @@ + // ReduceOp + //===----------------------------------------------------------------------===// + +-bool hasSameOperandAndResultTypes(Operation& op) { +- Type expected; +- if (op.getNumResults() != 0) expected = op.getResult(0).getType(); +- if (op.getNumOperands() != 0) expected = op.getOperand(0).getType(); +- if (!expected) return false; +- +- auto typeMatch = [&](Type actual) { return actual == expected; }; +- return llvm::all_of(op.getOperandTypes(), typeMatch) && +- llvm::all_of(op.getResultTypes(), typeMatch); +-} +- +-// Checks the following eligibility criteria for compact printing of reduce: +-// E1. The reduce-op wraps a single inner-op in the associated region. +-// E2. The single operation is a commutative binary-op from the dialect, zero +-// region, producing single result such that the operands and result all +-// have the same type. +-// E3. The reduce-op consist of at least one input-operand; The operand-types of +-// inner-op should be derived trivially from the element-type of reduce-op's +-// first input-operand. +-// E4. The arguments of the region's only basic block are forwarded perfectly +-// to inner-op's operands. +-// E5. The single operation result is perfectly forwarded to the reduce op +-// return. +-static bool isEligibleForCompactPrint(ReduceOp op) { +- // Check E1. +- auto& block = op.getBody().front(); +- if (!hasSingleElement(block.without_terminator())) return false; +- +- Operation& innerOp = *block.begin(); +- +- // Check E2. +- if (innerOp.getDialect() != op->getDialect()) return false; +- +- if (innerOp.getNumOperands() != 2 || +- !innerOp.hasTrait() || +- !hasSameOperandAndResultTypes(innerOp) || +- !innerOp.hasTrait() || +- !innerOp.hasTrait()) +- return false; +- +- // Check E3. +- if (op.getInputs().empty()) return false; +- +- auto elemType = +- op.getInputs()[0].getType().cast().getElementType(); +- auto expectedInnerOpType = RankedTensorType::get(/*shape=*/{}, elemType); +- if (innerOp.getOperands()[0].getType() != expectedInnerOpType) return false; +- +- // Check E4. +- if (!llvm::equal(block.getArguments(), innerOp.getOperands())) return false; +- +- // Check E5. +- auto retOp = dyn_cast(block.getTerminator()); +- if (!retOp) return false; +- +- return llvm::equal(innerOp.getResults(), retOp.getOperands()); +-} +- + void ReduceOp::print(OpAsmPrinter& p) { +- { +- // Print the pairs of operands under the form: +- // (%arg0 init: %arg3), (%arg1 init: %arg4), (%arg2 init: %arg5) +- StringRef comma = ""; +- int numOperandPairs = getNumOperands() / 2; +- for (int opId : llvm::seq(0, numOperandPairs)) { +- p << comma << "(" << getOperand(opId) +- << " init: " << getOperand(opId + numOperandPairs) << ")"; +- comma = ", "; +- } +- } +- +- // If the reduce-op is eligible for compact printing, we emit the one-liner: +- // stablehlo.reduce applies across dimensions = [...] : +- // Note: We are not printing the function type of reduction operation. We +- // have some simplifying assumptions (refer to IsEligibleForCompactPrint::E3) +- // to derive the type from that of reduce-op. +- if (isEligibleForCompactPrint(*this)) { +- Operation& innerOp = getBody().front().front(); +- p << " applies "; +- printEscapedString(innerOp.getName().getStringRef(), p.getStream()); +- +- p << " across dimensions = ["; +- llvm::interleaveComma(getDimensions(), p); +- p << "]"; +- p.printOptionalAttrDict(getOperation()->getAttrs(), {"dimensions"}); +- p << " : "; +- p.printFunctionalType(*this); +- } else { +- p << " across dimensions = ["; +- llvm::interleaveComma(getDimensions(), p); +- p << "]"; +- p.printOptionalAttrDict(getOperation()->getAttrs(), {"dimensions"}); +- p << " : "; +- p.printFunctionalType(*this); +- p.printNewline(); +- p << " reducer"; +- { +- // Print the pairs of block operands under the form: +- // (%arg0_elt, %arg0_acc) (%arg1_elt, %arg1_acc): +- Block& reducer = getBody().front(); +- int numOperandPairs = getNumOperands() / 2; +- for (int opId : llvm::seq(0, numOperandPairs)) { +- p << "("; +- p.printRegionArgument(reducer.getArgument(opId)); +- p << ", "; +- p.printRegionArgument(reducer.getArgument(opId + numOperandPairs)); +- p << ") "; +- } +- } +- p << ' '; +- p.printRegion(getBody(), /*printEntryBlockArgs=*/false); +- } ++ hlo::printReduceOp(p, getOperation(), getInputs(), getDimensions(), ++ getBody()); + } + + ParseResult ReduceOp::parse(OpAsmParser& parser, OperationState& result) { +- llvm::SMLoc loc = parser.getCurrentLocation(); +- Location currLocation = parser.getEncodedSourceLoc(loc); +- +- // Parse the operands of reduce-op, this is a list of pair under the form: +- // (%arg0 init: %arg3), (%arg1 init: %arg4), (%arg2 init: %arg5) +- // Each input to reduce is paired with its init value, even though in memory +- // they are stored with the input first and the init values after. +- SmallVector operands; +- SmallVector initOperands; +- do { +- (void)parser.parseOptionalComma(); +- if (parser.parseOptionalLParen()) break; +- OpAsmParser::UnresolvedOperand operand, initOperand; +- if (parser.parseOperand(operand) || parser.parseKeyword("init") || +- parser.parseColon() || parser.parseOperand(initOperand) || +- parser.parseRParen()) +- return failure(); +- operands.push_back(operand); +- initOperands.push_back(initOperand); +- } while (true); +- operands.append(initOperands); +- +- // Check if we are parsing the compact version of reduce-op: +- // stablehlo.reduce applies across dimensions = [...] : +- // else parse the "region-based" variant. +- if (failed(parser.parseOptionalKeyword("applies"))) { +- // Parse the inner-op dimensions, reduce-op's function-type and +- // optional location. +- SmallVector dimensions; +- auto parseDim = [&]() -> ParseResult { +- if (parser.parseInteger(dimensions.emplace_back())) return failure(); +- return success(); +- }; +- +- FunctionType reduceOpFnType; +- if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") || +- parser.parseEqual() || +- parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, +- parseDim) || +- parser.parseOptionalAttrDict(result.attributes) || +- parser.parseColon() || parser.parseType(reduceOpFnType) || +- parser.parseKeyword("reducer")) +- return failure(); +- OpBuilder builder(parser.getBuilder().getContext()); +- result.addAttribute("dimensions", builder.getDenseI64ArrayAttr(dimensions)); +- +- // Parse the "reducer" region now. +- SmallVector reducerOperands; +- SmallVector reducerInitOperands; +- SmallVector reducerTypes; +- SmallVector reducerInitTypes; +- SmallVector, 2> reducerLocs; +- SmallVector, 2> reducerInitLocs; +- auto parseBlockOperand = +- [&](SmallVectorImpl& operands, +- SmallVectorImpl& types, +- SmallVectorImpl>& locs) -> ParseResult { +- OpAsmParser::UnresolvedOperand operand; +- Type type; +- std::optional loc; +- if (parser.parseOperand(operand, /*allowResultNumber=*/false) || +- parser.parseColon() || parser.parseType(type) || +- parser.parseOptionalLocationSpecifier(loc)) +- return failure(); +- operands.push_back(operand); +- types.push_back(type); +- locs.push_back(loc); +- return success(); +- }; +- do { +- if (failed(parser.parseOptionalLParen())) break; +- if (parseBlockOperand(reducerOperands, reducerTypes, reducerLocs) || +- parser.parseComma() || +- parseBlockOperand(reducerInitOperands, reducerInitTypes, +- reducerInitLocs) || +- parser.parseRParen()) +- return failure(); +- } while (true); +- reducerOperands.append(reducerInitOperands); +- reducerTypes.append(reducerInitTypes); +- reducerLocs.append(reducerInitLocs); +- result.addTypes(reduceOpFnType.getResults()); +- SmallVector reducerArgs; +- createArgs(reducerOperands, reducerTypes, reducerArgs); +- +- // Derive the SSA-values for reduce-op's operands and parse the region, and +- // the optional trailing location. +- std::optional trailingLoc; +- if (parser.resolveOperands(operands, reduceOpFnType.getInputs(), loc, +- result.operands) || +- parser.parseRegion(*result.addRegion(), reducerArgs)) +- return failure(); +- // Set the individual block arguments. +- for (auto argAndLoc : +- llvm::zip(result.regions.front()->front().getArguments(), reducerLocs)) +- if (std::get<1>(argAndLoc)) +- std::get<0>(argAndLoc).setLoc(std::get<1>(argAndLoc).value()); +- result.location = trailingLoc.value_or(currLocation); +- return success(); +- } +- +- // Parse the inner-op name and check if the contract on inner-op +- // mentioned in "isEligibleForCompactPrint::E2" for pretty-printing is met. +- FailureOr innerOpNameInfo = parser.parseCustomOperationName(); +- if (failed(innerOpNameInfo)) return failure(); +- +- StringRef innerOpName = innerOpNameInfo->getStringRef(); +- Dialect* innerOpDialect = innerOpNameInfo->getDialect(); +- if (!innerOpDialect || !innerOpDialect->getNamespace().equals("stablehlo") || +- !innerOpNameInfo->hasTrait::Impl>() || +- !innerOpNameInfo->hasTrait() || +- !innerOpNameInfo->hasTrait() || +- !innerOpNameInfo->hasTrait()) { +- parser.emitError(loc, +- "expected the inner-op to be a commutative binary-op from " +- "stablehlo dialect, zero region, producing single result"); +- return failure(); +- } +- +- // Parse the inner-op dimensions, reduce-op's function-type and +- // optional location. +- SmallVector dimensions; +- auto parseDim = [&]() -> ParseResult { +- if (parser.parseInteger(dimensions.emplace_back())) return failure(); +- return success(); ++ auto parseDenseArray = [](OpBuilder& b, ArrayRef dims) -> Attribute { ++ return b.getDenseI64ArrayAttr(dims); + }; +- +- std::optional explicitLoc; +- FunctionType reduceOpFnType; +- if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") || +- parser.parseEqual() || +- parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, parseDim) || +- parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || +- parser.parseType(reduceOpFnType) || +- parser.parseOptionalLocationSpecifier(explicitLoc)) +- return failure(); +- +- if (!reduceOpFnType || reduceOpFnType.getInputs().empty()) { +- if (!reduceOpFnType) return parser.emitError(loc, "expected function type"); +- return parser.emitError(loc, +- "input types missing in reduce-op function type"); +- } +- +- // If location of reduce-op is explicitly provided, then use it; Else use +- // the parser's current location. +- Location reduceOpLoc = explicitLoc.value_or(currLocation); +- +- // Derive the SSA-values for reduce-op's operands. +- if (parser.resolveOperands(operands, reduceOpFnType.getInputs(), loc, +- result.operands)) +- return failure(); +- +- // Derive the type of inner-op from that of reduce-op's input operand. +- auto innerOpType = RankedTensorType::get( +- /*shape=*/{}, getElementTypeOrSelf(reduceOpFnType.getInput(0))); +- +- // Add a region for reduce-op. +- Region& region = *result.addRegion(); +- +- // Create a basic-block inside reduce-op's region. +- Block& block = region.emplaceBlock(); +- auto lhs = block.addArgument(innerOpType, reduceOpLoc); +- auto rhs = block.addArgument(innerOpType, reduceOpLoc); +- +- // Create and insert an "inner-op" operation in the block. +- OpBuilder builder(parser.getBuilder().getContext()); +- builder.setInsertionPointToStart(&block); +- +- OperationState innerOpState(reduceOpLoc, innerOpName); +- innerOpState.operands.push_back(lhs); +- innerOpState.operands.push_back(rhs); +- innerOpState.addTypes(innerOpType); +- +- Operation* innerOp = builder.create(innerOpState); +- +- // Insert a return statement in the block returning the inner-op's result. +- builder.create(innerOp->getLoc(), innerOp->getResults()); +- +- // Populate the reduce-op operation-state with result-type, location, and +- // dimension attribute. +- result.addTypes(reduceOpFnType.getResults()); +- result.location = innerOp->getLoc(); +- result.addAttribute("dimensions", builder.getDenseI64ArrayAttr(dimensions)); +- return success(); ++ return hlo::parseReduceOp(parser, result, parseDenseArray); + } + + LogicalResult ReduceOp::inferReturnTypeComponents( +@@ -2385,69 +2086,12 @@ + return hlo::verifyWhileOp(getLoc(), getOperand(), getCond(), getBody()); + } + +-/// Print a `while` op. +-/// +-/// op ::= `stablehlo.while` `(` assignment-list `)` `:` types attribute-dict +-/// `cond` region +-/// `do` region +-/// assignment-list ::= assignment | assignment `,` assignment-list +-/// assignment ::= ssa-value `=` ssa-value + void WhileOp::print(OpAsmPrinter& p) { +- p << '('; +- llvm::interleaveComma( +- llvm::zip(SingleBlock::getBody()->getArguments(), getOperands()), p, +- [&](auto zip) { +- p.printOperand(std::get<0>(zip)); +- p << " = "; +- p.printOperand(std::get<1>(zip)); +- }); +- p << ")"; +- if (getNumOperands()) { +- p << " : "; +- llvm::interleaveComma(getOperandTypes(), p); +- } +- p.printOptionalAttrDictWithKeyword(getOperation()->getAttrs()); +- p.printNewline(); +- p << " cond "; +- p.printRegion(getRegion(0), /*printEntryBlockArgs=*/false); +- p << " do "; +- p.printRegion(getRegion(1), /*printEntryBlockArgs=*/false); ++ hlo::printWhileOp(p, getOperation(), getCond(), getBody()); + } + + ParseResult WhileOp::parse(OpAsmParser& parser, OperationState& result) { +- llvm::SMLoc loc = parser.getCurrentLocation(); +- // Parse the operands of the while: these are of the form: +- // %iter_arg = %init_val +- // where %iter_arg is the name of the block argument in the cond/body blocks +- // and %init_val is the actual operand. +- SmallVector operands; +- SmallVector iterArgs; +- if (parser.parseLParen()) return failure(); +- do { +- if (succeeded(parser.parseOptionalRParen())) break; +- OpAsmParser::UnresolvedOperand operand, iterArg; +- if (parser.parseOperand(iterArg) || parser.parseEqual() || +- parser.parseOperand(operand)) +- return failure(); +- iterArgs.push_back(iterArg); +- operands.push_back(operand); +- if (succeeded(parser.parseOptionalRParen())) break; +- if (failed(parser.parseComma())) return failure(); +- } while (true); +- if (!operands.empty()) { +- if (parser.parseColon() || parser.parseTypeList(result.types)) +- return failure(); +- } +- SmallVector args; +- createArgs(iterArgs, result.types, args); +- if (parser.resolveOperands(operands, result.types, loc, result.operands) || +- parser.parseOptionalAttrDictWithKeyword(result.attributes) || +- parser.parseKeyword("cond") || +- parser.parseRegion(*result.addRegion(), args) || +- parser.parseKeyword("do") || +- parser.parseRegion(*result.addRegion(), args)) +- return failure(); +- return success(); ++ return hlo::parseWhileOp(parser, result); + } + + LogicalResult UniformDequantizeOp::inferReturnTypeComponents( diff --ruN a/stablehlo/stablehlo/experimental/BUILD.bazel b/stablehlo/stablehlo/experimental/BUILD.bazel --- stablehlo/stablehlo/experimental/BUILD.bazel +++ stablehlo/stablehlo/experimental/BUILD.bazel @@ -1274,7 +2226,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/tests/CMakeLists.txt b/stablehlo/s diff --ruN a/stablehlo/stablehlo/experimental/tests/chlo_recompose_ops.mlir b/stablehlo/stablehlo/experimental/tests/chlo_recompose_ops.mlir --- stablehlo/stablehlo/experimental/tests/chlo_recompose_ops.mlir +++ stablehlo/stablehlo/experimental/tests/chlo_recompose_ops.mlir -@@ -0,0 +1,36 @@ +@@ -0,0 +1,51 @@ +// RUN: experimental-stablehlo-opt --experimental-chlo-recompose-ops --split-input-file --verify-diagnostics %s | FileCheck %s + +// ----- @@ -1311,6 +2263,21 @@ diff --ruN a/stablehlo/stablehlo/experimental/tests/chlo_recompose_ops.mlir b/st + } : (tensor<16xf32>) -> tensor + func.return %0 : tensor +} ++ ++// ----- ++ ++// CHECK-LABEL: @recompose_erf ++func.func @recompose_erf(%arg0: tensor<3x20x20xbf16>) -> tensor { ++ // CHECK: %0 = chlo.erf %arg0 : tensor<3x20x20xbf16> -> tensor ++ %0 = "stablehlo.custom_call"(%arg0) { ++ backend_config = "", ++ call_target_name = "mhlo.erf", ++ mhlo.attributes = {}, ++ mhlo.version = 1 : i64 ++ } : (tensor<3x20x20xbf16>) -> tensor ++ func.return %0 : tensor ++} ++ diff --ruN a/stablehlo/stablehlo/experimental/tests/lit.cfg.py b/stablehlo/stablehlo/experimental/tests/lit.cfg.py --- stablehlo/stablehlo/experimental/tests/lit.cfg.py +++ stablehlo/stablehlo/experimental/tests/lit.cfg.py @@ -1922,7 +2889,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/CMakeLists.txt b/stable diff --ruN a/stablehlo/stablehlo/experimental/transforms/ChloRecomposeOps.cpp b/stablehlo/stablehlo/experimental/transforms/ChloRecomposeOps.cpp --- stablehlo/stablehlo/experimental/transforms/ChloRecomposeOps.cpp +++ stablehlo/stablehlo/experimental/transforms/ChloRecomposeOps.cpp -@@ -0,0 +1,151 @@ +@@ -0,0 +1,168 @@ +/* Copyright 2024 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. @@ -1940,6 +2907,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/ChloRecomposeOps.cpp b/ +#include +#include + ++#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" @@ -2044,6 +3012,15 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/ChloRecomposeOps.cpp b/ + } +}; + ++struct ErfOpRecomposePattern : public OpRewritePattern { ++ using OpRewritePattern::OpRewritePattern; ++ LogicalResult matchAndRewrite(CustomCallOp op, ++ PatternRewriter& rewriter) const override { ++ if (op.getCallTargetName() != "mhlo.erf") return failure(); ++ return recomposeChloOpFromCustomCall(op, rewriter); ++ } ++}; ++ +} // namespace + +struct ChloRecomposeOpsPass @@ -2051,21 +3028,28 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/ChloRecomposeOps.cpp b/ + using ChloRecomposeOpsPassBase::ChloRecomposeOpsPassBase; + + void runOnOperation() override { -+ // Do a single traversal to recompose CHLO ops. -+ // TODO(#1048): Find out why .maxIterations = 1 no longer works. ++ // Do a single traversal to recompose CustomCallOp to CHLO ops. + GreedyRewriteConfig config; + config.useTopDownTraversal = true; + config.enableRegionSimplification = true; -+ config.maxIterations = 2; ++ config.maxIterations = 1; + config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; -+ config.strictMode = GreedyRewriteStrictness::AnyOp; ++ config.strictMode = GreedyRewriteStrictness::ExistingOps; + + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); ++ patterns.add(&getContext()); + -+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), -+ config))) { ++ // Only apply to CustomCallOps ++ auto moduleOp = getOperation(); ++ llvm::SmallVector candidateOps; ++ moduleOp.walk([&](CustomCallOp op) { candidateOps.push_back(op); }); ++ ++ if (failed(applyOpPatternsAndFold(candidateOps, std::move(patterns), ++ config))) { ++ moduleOp.emitError("Failed to converge ChloRecomposeOps in ") ++ << config.maxIterations << " iterations"; + return signalPassFailure(); + } + } @@ -2160,7 +3144,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/Passes.td b/stablehlo/s diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDynamism.cpp b/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDynamism.cpp --- stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDynamism.cpp +++ stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDynamism.cpp -@@ -0,0 +1,167 @@ +@@ -0,0 +1,171 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + Copyright 2023 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); @@ -2317,8 +3301,12 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy + patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); -+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), ++ ++ auto funcOp = getOperation(); ++ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns), + config))) { ++ funcOp.emitError("Failed to converge StablehloCanonicalizeDynamism in ") ++ << config.maxIterations << " iterations"; + return signalPassFailure(); + } + } @@ -2502,4 +3490,61 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c +} // namespace experimental +} // namespace stablehlo +} // namespace mlir +diff --ruN a/stablehlo/stablehlo/tests/verify_reduce.mlir b/stablehlo/stablehlo/tests/verify_reduce.mlir +--- stablehlo/stablehlo/tests/verify_reduce.mlir ++++ stablehlo/stablehlo/tests/verify_reduce.mlir +@@ -490,7 +490,7 @@ + // ----- + + func.func @reduce_parsing_pretty_reduce_non_commutative(%arg0: tensor , %arg1: tensor ) -> tensor { +- // expected-error@+1 {{expected the inner-op to be a commutative binary-op from stablehlo dialect, zero region, producing single result}} ++ // expected-error@+1 {{expected the inner-op to be a commutative binary-op that matching the reduce op dialect, with zero region, producing single result}} + %0 = stablehlo.reduce(%arg0 init: %arg1) applies stablehlo.divide across dimensions = [1] : (tensor, tensor) -> tensor loc("foo") + func.return %0 : tensor + } +@@ -498,7 +498,7 @@ + // ----- + + func.func @reduce_parsing_pretty_reduce_wrong_dialect(%arg0: tensor , %arg1: tensor ) -> tensor { +- // expected-error@+1 {{expected the inner-op to be a commutative binary-op from stablehlo dialect, zero region, producing single result}} ++ // expected-error@+1 {{expected the inner-op to be a commutative binary-op that matching the reduce op dialect, with zero region, producing single result}} + %0 = stablehlo.reduce(%arg0 init: %arg1) applies std.add across dimensions = [1] : (tensor, tensor) -> tensor loc("foo") + func.return %0 : tensor + } +@@ -506,7 +506,7 @@ + // ----- + + func.func @reduce_parsing_pretty_reduce_non_binary(%arg0: tensor , %arg1: tensor ) -> tensor { +- // expected-error@+1 {{expected the inner-op to be a commutative binary-op from stablehlo dialect, zero region, producing single result}} ++ // expected-error@+1 {{expected the inner-op to be a commutative binary-op that matching the reduce op dialect, with zero region, producing single result}} + %0 = stablehlo.reduce(%arg0 init: %arg1) applies stablehlo.reshape across dimensions = [1] : (tensor, tensor) -> tensor loc("foo") + func.return %0 : tensor + } +diff --ruN a/stablehlo/stablehlo/transforms/StablehloAggressiveSimplification.cpp b/stablehlo/stablehlo/transforms/StablehloAggressiveSimplification.cpp +--- stablehlo/stablehlo/transforms/StablehloAggressiveSimplification.cpp ++++ stablehlo/stablehlo/transforms/StablehloAggressiveSimplification.cpp +@@ -126,9 +126,8 @@ + + // The canonical form has the constant operand as the RHS. + if (isa(type.getElementType()) && lhsAttr && !rhsAttr) { +- rewriter.modifyOpInPlace(op, [op, lhs, rhs] { +- op->setOperands(ValueRange{rhs, lhs}); +- }); ++ rewriter.modifyOpInPlace( ++ op, [op, lhs, rhs] { op->setOperands(ValueRange{rhs, lhs}); }); + return success(); + } + +@@ -221,9 +220,8 @@ + + // The canonical form has the constant operand as the RHS. + if (isa(type.getElementType()) && lhsAttr && !rhsAttr) { +- rewriter.modifyOpInPlace(op, [op, lhs, rhs] { +- op->setOperands(ValueRange{rhs, lhs}); +- }); ++ rewriter.modifyOpInPlace( ++ op, [op, lhs, rhs] { op->setOperands(ValueRange{rhs, lhs}); }); + return success(); + } + diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index 271c373a66a8c1..411c6103290796 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "c30f551469ca37a1f2a8c8ac42ef1b989573dce6" - STABLEHLO_SHA256 = "71720bd4003f417beb1acefd8f87b20f0e1db5edf498bf2f4642e5f0a3542c02" + STABLEHLO_COMMIT = "e708c82502982697540886738a307f72f9e9a7ff" + STABLEHLO_SHA256 = "3fecbe7779bee0801af746d974738748f7b461df54a4f610b32bb75647b32125" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 10a56bfa0edce1..99aa32a79a30aa 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "e99b8f121f63cdfae811b2cafc4dab5ce97986f6" - TFRT_SHA256 = "0e0ec61414532ec44f271ed7450253c462c2789d2a2c24c178e1377bef10f3da" + TFRT_COMMIT = "aec2070dee4792b80177d167f26491b1d30eced4" + TFRT_SHA256 = "0d398b68353ae8e547f4f974d43b3c29c9ce9cce535c66dba5efcd5bee4ad36d" tf_http_archive( name = "tf_runtime", diff --git a/third_party/triton/cl607293980.patch b/third_party/triton/cl607293980.patch new file mode 100644 index 00000000000000..b7b9d0e84fab2e --- /dev/null +++ b/third_party/triton/cl607293980.patch @@ -0,0 +1,17 @@ +Long standing patch due to licensing issues. +diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp +index 31bc03fe1..a19a432df 100644 +--- a/include/triton/Tools/Sys/GetEnv.hpp ++++ b/include/triton/Tools/Sys/GetEnv.hpp +@@ -34,9 +34,10 @@ inline const std::set ENV_VARS = { + "AMDGCN_ENABLE_DUMP", + "DISABLE_FAST_REDUCTION", + "DISABLE_LLVM_OPT", +- "DISABLE_MMA_V3", ++ "ENABLE_MMA_V3", + "DISABLE_PTXAS_OPT", + "LLVM_IR_ENABLE_DUMP", ++ "MLIR_ENABLE_DIAGNOSTICS", + "MLIR_ENABLE_DUMP", + "TRITON_DISABLE_LINE_INFO", + "TRITON_DISABLE_RESHAPE_ENCODING_INFERENCE", diff --git a/third_party/triton/workspace.bzl b/third_party/triton/workspace.bzl index d2cf68c50e4306..e0364e4b646929 100644 --- a/third_party/triton/workspace.bzl +++ b/third_party/triton/workspace.bzl @@ -5,13 +5,15 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): """Imports Triton.""" - TRITON_COMMIT = "cl601105910" - TRITON_SHA256 = "523b31822e431c79e2d6bc566272e7fc4f4183ae28aebcf662d11db740691d6d" + TRITON_COMMIT = "cl608559313" + TRITON_SHA256 = "d37c0a2921f756cb355dc7ea7e91ea708cef867117edff37106f5a947c5a5a38" tf_http_archive( name = "triton", sha256 = TRITON_SHA256, strip_prefix = "triton-{commit}".format(commit = TRITON_COMMIT), urls = tf_mirror_urls("https://github.com/openxla/triton/archive/{commit}.tar.gz".format(commit = TRITON_COMMIT)), # For temporary changes which haven't landed upstream yet. - patch_file = [], + patch_file = [ + "//third_party/triton:cl607293980.patch", # long standing :( + ], ) diff --git a/third_party/xla/.bazelrc b/third_party/xla/.bazelrc index a635862b43a43c..c21cf6e6e15d5d 100644 --- a/third_party/xla/.bazelrc +++ b/third_party/xla/.bazelrc @@ -255,6 +255,14 @@ build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-1 build:cuda_clang_official --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:cuda_clang_official --crosstool_top="@sigbuild-r2.16-clang_config_cuda//crosstool:toolchain" +# Build with nvcc for CUDA and clang for host +build:nvcc_clang --config=cuda +# Unfortunately, cuda_configure.bzl demands this for using nvcc + clang +build:nvcc_clang --action_env=TF_CUDA_CLANG="1" +build:nvcc_clang --action_env=TF_NVCC_CLANG="1" +build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc + + # Debug config build:dbg -c dbg # Only include debug info for files under tensorflow/, excluding kernels, to @@ -527,8 +535,8 @@ build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.16-clang_confi test:rbe_linux_cuda --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda +build:rbe_linux_cuda_nvcc --config=nvcc_clang build:rbe_linux_cuda_nvcc --repo_env TF_NCCL_USE_STUB=1 -build:rbe_linux_cuda_nvcc --action_env=TF_NVCC_CLANG="1" # TODO(kanglan): Remove rbe_win and rbe_win_py3* after b/289091160 is fixed build:rbe_win --config=rbe_base @@ -577,6 +585,7 @@ build:elinux_armhf --copt -mfp16-format=ieee # Load rc file written by ./configure. try-import %workspace%/.tf_configure.bazelrc +try-import %workspace%/xla_configure.bazelrc # Load rc file with user-specific options. try-import %workspace%/.bazelrc.user @@ -777,28 +786,38 @@ test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-os test:linux_cuda_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # ARM64 PYCPP -test:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only -test:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only -test:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --flaky_test_attempts=3 +# In Linux Arm64 presubmit/continuous build, we cross-compile the binaries on +# Linux x86 so that we can use RBE. Since tests still need to run on the single +# host Arm64 machine, the build becomes too slow (~30 min) to be a presubmit. +# For testing purposes, we want to see the runtime performance of an +# experimental job that is build-only, i.e, we only build the test targets and +# do not run them. By prefixing the configs with "build", we can run both +# `bazel build` and `bazel test` commands with the same config as test configs +# inherit from build. +build:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +build:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +build:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --flaky_test_attempts=3 # TODO(michaelhudgins): Why do we need to specifically omit go and java here? -test:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test -//tensorflow/python/tools:aot_compiled_test +build:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test -//tensorflow/python/tools:aot_compiled_test # CROSS-COMPILE ARM64 PYCPP -test:cross_compile_linux_arm64_pycpp_test --config=linux_arm64_pycpp_test +build:cross_compile_linux_arm64_pycpp_test --config=linux_arm64_pycpp_test # Tests that fail only when cross-compiled -test:cross_compile_linux_arm64_pycpp_test -//tensorflow/compiler/mlir/quantization/stablehlo:convert_tf_quant_to_mhlo_int_test +build:cross_compile_linux_arm64_pycpp_test -//tensorflow/compiler/mlir/quantization/stablehlo:convert_tf_quant_to_mhlo_int_test # MACOS ARM64 PYCPP test:macos_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... -//tensorflow/core/kernels/image:resize_bicubic_op_test # MACOS X86 PYCPP -test:macos_x86_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test -test:macos_x86_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test -test:macos_x86_pycpp_test_filters --keep_going --test_lang_filters=cc,py --test_size_filters=small,medium -test:macos_x86_pycpp_test --config=macos_x86_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/python/integration_testing/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... +# These are defined as build configs so that we can run a build only job. See +# the note under "ARM64 PYCPP" for more details. +build:macos_x86_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +build:macos_x86_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +build:macos_x86_pycpp_test_filters --keep_going --test_lang_filters=cc,py --test_size_filters=small,medium +build:macos_x86_pycpp_test --config=macos_x86_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/python/integration_testing/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... # CROSS-COMPILE MACOS X86 PYCPP -test:cross_compile_macos_x86_pycpp_test --config=macos_x86_pycpp_test -test:cross_compile_macos_x86_pycpp_test -//tensorflow/core/kernels:quantized_conv_ops_test -//tensorflow/core/kernels:quantized_matmul_op_test -//tensorflow/python/ops:quantized_conv_ops_test -//tensorflow/tools/graph_transforms:transforms_test -//tensorflow/python/tools:aot_compiled_test +build:cross_compile_macos_x86_pycpp_test --config=macos_x86_pycpp_test +build:cross_compile_macos_x86_pycpp_test -//tensorflow/core/kernels:quantized_conv_ops_test -//tensorflow/core/kernels:quantized_matmul_op_test -//tensorflow/python/ops:quantized_conv_ops_test -//tensorflow/tools/graph_transforms:transforms_test -//tensorflow/python/tools:aot_compiled_test # END TF TEST SUITE OPTIONS # START CROSS-COMPILE CONFIGS @@ -855,8 +874,12 @@ build:cross_compile_macos_x86 --platform_mappings=tensorflow/tools/toolchains/cr # RBE cross-compile configs for Darwin x86 build:rbe_cross_compile_macos_x86 --config=cross_compile_macos_x86 build:rbe_cross_compile_macos_x86 --config=rbe_cross_compile_base +build:rbe_cross_compile_macos_x86 --bes_upload_mode=nowait_for_upload_complete test:rbe_cross_compile_macos_x86 --config=rbe_cross_compile_base # Increase the test timeout as tests often take longer on mac. test:rbe_cross_compile_macos_x86 --test_timeout=300,450,1200,3600 +# Limit jobs to 100 to avoid running into "out of memory" issues (b/316266643) +build:rbe_cross_compile_macos_x86 --jobs=100 +test:rbe_cross_compile_macos_x86 --jobs=100 # END MACOS CROSS-COMPILE CONFIGS # END CROSS-COMPILE CONFIGS diff --git a/third_party/xla/xla/.clang-format b/third_party/xla/.clang-format similarity index 75% rename from third_party/xla/xla/.clang-format rename to third_party/xla/.clang-format index c2aa8675561990..b894b2a3e874c8 100644 --- a/third_party/xla/xla/.clang-format +++ b/third_party/xla/.clang-format @@ -1,3 +1,4 @@ BasedOnStyle: Google Language: Cpp PointerBindsToType: true +SortIncludes: Never diff --git a/third_party/xla/.gitignore b/third_party/xla/.gitignore index 2f09259330bc65..ee6ca187cc7fa4 100644 --- a/third_party/xla/.gitignore +++ b/third_party/xla/.gitignore @@ -10,6 +10,7 @@ bazel-testlogs # Ignore files produced by `configure` .tf_configure.bazelrc +xla_configure.bazelrc tools/python_bin_path.sh # Emacs autosaves diff --git a/third_party/xla/.kokoro/linux/build.sh b/third_party/xla/.kokoro/linux/build.sh index 662b66094d8f1d..1bb64faf7ba041 100644 --- a/third_party/xla/.kokoro/linux/build.sh +++ b/third_party/xla/.kokoro/linux/build.sh @@ -27,15 +27,20 @@ function is_linux_gpu_job() { } function is_linux_cpu_arm64_job() { - [[ "$KOKORO_JOB_NAME" =~ tensorflow/xla/linux/arm64/.*cpu.* ]] + [[ "$KOKORO_JOB_NAME" =~ tensorflow/xla/linux/.*arm64.*/.*cpu.* ]] } -# Pull the container (in case it was updated since the instance started) and -# store its SHA in the Sponge log. -docker pull "$DOCKER_IMAGE" -echo "TF_INFO_DOCKER_IMAGE,$DOCKER_IMAGE" >> "$KOKORO_ARTIFACTS_DIR/custom_sponge_config.csv" -echo "TF_INFO_DOCKER_SHA,$(docker pull "$DOCKER_IMAGE" | sed -n '/Digest:/s/Digest: //g p')" >> "$KOKORO_ARTIFACTS_DIR/custom_sponge_config.csv" +function pull_docker_image_with_retries() { + # Pull the container (in case it was updated since the instance started) and + # store its SHA in the Sponge log. + docker pull "$DOCKER_IMAGE" || sleep 15 + docker pull "$DOCKER_IMAGE" || sleep 15 + docker pull "$DOCKER_IMAGE" + echo "TF_INFO_DOCKER_IMAGE,$DOCKER_IMAGE" >> "$KOKORO_ARTIFACTS_DIR/custom_sponge_config.csv" + echo "TF_INFO_DOCKER_SHA,$(docker pull "$DOCKER_IMAGE" | sed -n '/Digest:/s/Digest: //g p')" >> "$KOKORO_ARTIFACTS_DIR/custom_sponge_config.csv" +} +pull_docker_image_with_retries # Start a container in the background docker run --name xla -w /tf/xla -itd --rm \ -v "$KOKORO_ARTIFACTS_DIR/github/xla:/tf/xla" \ @@ -49,8 +54,15 @@ RBE_FLAGS="" if is_linux_gpu_job ; then TAGS_FILTER="$TAGS_FILTER,gpu,requires-gpu-nvidia,-no_gpu" + + # We are currently running XLA presubmits on machines with NVIDIA T4 GPUs, + # which have a compute compatibility of 7.5. Se we filter out all the tests + # that need a newer GPU: + UNSUPPORTED_GPU_TAGS="$(echo -requires-gpu-sm{80,86,89,90}{,-only})" + TAGS_FILTER="${TAGS_FILTER},${UNSUPPORTED_GPU_TAGS// /,}" + ADDITIONAL_FLAGS="$ADDITIONAL_FLAGS --run_under=//tools/ci_build/gpu_build:parallel_gpu_execute" - RBE_FLAGS="--config=rbe_linux_cuda_nvcc" + RBE_FLAGS="--config=rbe_linux_cuda_nvcc --jobs=150" echo "***NOTE: nvidia-smi lists the highest CUDA version the driver supports, which may be different than the version of CUDA actually used!!***" nvidia-smi else @@ -58,10 +70,10 @@ else ADDITIONAL_FLAGS="$ADDITIONAL_FLAGS --config=nonccl" if is_linux_cpu_arm64_job ; then - TAGS_FILTER="$TAGS_FILTER,-no_arm64" - ADDITIONAL_FLAGS="$ADDITIONAL_FLAGS --action_env PYTHON_BIN_PATH=/usr/bin/python3.10 --python_path=/usr/bin/python3.10" + TAGS_FILTER="$TAGS_FILTER,-no_aarch64" + ADDITIONAL_FLAGS="$ADDITIONAL_FLAGS --config=tf_public_cache_push --action_env PYTHON_BIN_PATH=/usr/bin/python3.10 --python_path=/usr/bin/python3.10" else - RBE_FLAGS="--config=rbe_linux_cpu" + RBE_FLAGS="--config=rbe_linux_cpu --jobs=150" fi fi @@ -76,7 +88,6 @@ docker exec xla bazel \ --profile=/tf/pkg/profile.json.gz \ --flaky_test_attempts=3 \ $RBE_FLAGS \ - --jobs=150 \ --nobuild_tests_only \ $ADDITIONAL_FLAGS \ -- //xla/... //build_tools/... diff --git a/third_party/xla/build_tools/BUILD b/third_party/xla/build_tools/BUILD new file mode 100644 index 00000000000000..3111ccb2505db4 --- /dev/null +++ b/third_party/xla/build_tools/BUILD @@ -0,0 +1,28 @@ +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +load("//xla:pytype.default.bzl", "pytype_strict_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) + +pytype_strict_library( + name = "test_utils", + testonly = True, + srcs = ["test_utils.py"], + visibility = ["//visibility:public"], +) diff --git a/third_party/xla/build_tools/configure/BUILD b/third_party/xla/build_tools/configure/BUILD new file mode 100644 index 00000000000000..90d63687cfb1b7 --- /dev/null +++ b/third_party/xla/build_tools/configure/BUILD @@ -0,0 +1,83 @@ +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +load("//xla:pytype.default.bzl", "pytype_strict_library") + +# Placeholder: load py_test +load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) + +pytype_strict_library( + name = "configure", + srcs = ["configure.py"], +) + +py_test( + name = "configure_test", + srcs = ["configure_test.py"], + data = [ + "testdata/clang.bazelrc", + "testdata/cuda_clang.bazelrc", + "testdata/gcc.bazelrc", + "testdata/nvcc_clang.bazelrc", + "testdata/nvcc_gcc.bazelrc", + ], + deps = [ + ":configure", + "//build_tools:test_utils", + "@absl_py//absl/testing:absltest", + ], +) + +# Below targets are just for checking if the host/CUDA compiler are configured +# as expected. +cc_library( + name = "assert_clang", + srcs = ["assert_clang.cc"], + tags = ["manual"], +) + +cc_library( + name = "assert_gcc", + srcs = ["assert_gcc.cc"], + tags = ["manual"], +) + +cuda_library( + name = "assert_cuda_clang", + srcs = ["assert_cuda_clang.cu.cc"], + tags = [ + "gpu", + "manual", + ], + deps = ["@local_config_cuda//cuda:cuda_headers"], +) + +cuda_library( + name = "assert_nvcc", + srcs = ["assert_nvcc.cu.cc"], + tags = [ + "gpu", + "manual", + ], + # Notably, this builds fine in OSS without this dependency. Apparently, + # NVCC can give targets access to CUDA headers without letting Bazel know, + # while CUDA clang cannot. + deps = ["@local_config_cuda//cuda:cuda_headers"], +) diff --git a/third_party/xla/xla/service/gpu/runtime/custom_call.h b/third_party/xla/build_tools/configure/assert_clang.cc similarity index 61% rename from third_party/xla/xla/service/gpu/runtime/custom_call.h rename to third_party/xla/build_tools/configure/assert_clang.cc index b79081bd753b46..3dd57d1d1ff8b5 100644 --- a/third_party/xla/xla/service/gpu/runtime/custom_call.h +++ b/third_party/xla/build_tools/configure/assert_clang.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The OpenXLA Authors. +/* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,17 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME_CUSTOM_CALL_H_ -#define XLA_SERVICE_GPU_RUNTIME_CUSTOM_CALL_H_ - -#include "xla/runtime/custom_call_registry.h" - -namespace xla { -namespace gpu { - -void RegisterXlaClassicCustomCalls(runtime::DirectCustomCallRegistry& registry); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_CUSTOM_CALL_H_ +#ifndef __clang__ +#error "__clang__ not defined!" +#endif // #ifdef __clang__ diff --git a/third_party/xla/xla/service/gpu/runtime/memset.h b/third_party/xla/build_tools/configure/assert_cuda_clang.cu.cc similarity index 59% rename from third_party/xla/xla/service/gpu/runtime/memset.h rename to third_party/xla/build_tools/configure/assert_cuda_clang.cu.cc index a6b5a9ed38a526..12aeb2743b6356 100644 --- a/third_party/xla/xla/service/gpu/runtime/memset.h +++ b/third_party/xla/build_tools/configure/assert_cuda_clang.cu.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The OpenXLA Authors. +/* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,18 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME_MEMSET_H_ -#define XLA_SERVICE_GPU_RUNTIME_MEMSET_H_ - -#include "xla/runtime/custom_call_registry.h" - -namespace xla { -namespace gpu { - -// Registers XLA Gpu runtime memset custom calls. -void RegisterMemsetCustomCalls(runtime::DirectCustomCallRegistry& registry); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_MEMSET_H_ +#if !defined(__clang__) || !defined(__CUDA__) +#error "__clang__ or __CUDA__ not defined!" +#endif // #if !defined(__clang__) || !defined(__CUDA__) diff --git a/third_party/xla/build_tools/configure/assert_gcc.cc b/third_party/xla/build_tools/configure/assert_gcc.cc new file mode 100644 index 00000000000000..617da0d621a01f --- /dev/null +++ b/third_party/xla/build_tools/configure/assert_gcc.cc @@ -0,0 +1,21 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Notably, clang will define `__GNUC__`, so need to make sure __clang__ is not +// defined to detect GCC (or, most correctly, some compiler that supports GNU +// extensions that is not clang). +#if !defined(__GNUC__) || defined(__clang__) +#error "__GNUC__ is not defined independently of __clang__!" +#endif // #if !defined(__GNUC__) || defined(__clang__) diff --git a/third_party/xla/build_tools/configure/assert_nvcc.cu.cc b/third_party/xla/build_tools/configure/assert_nvcc.cu.cc new file mode 100644 index 00000000000000..ea9287565755cc --- /dev/null +++ b/third_party/xla/build_tools/configure/assert_nvcc.cu.cc @@ -0,0 +1,17 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef __NVCC__ +#error "__NVCC__ not defined!" +#endif // #ifdef __NVCC__ diff --git a/third_party/xla/build_tools/configure/configure.py b/third_party/xla/build_tools/configure/configure.py new file mode 100755 index 00000000000000..ae40f9f41e7fdc --- /dev/null +++ b/third_party/xla/build_tools/configure/configure.py @@ -0,0 +1,537 @@ +#!/usr/bin/env python3 +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Configure script to get build parameters from user. + +This script populates a bazelrc file that tells Bazel where to look for +cuda versions and compilers. Note: that a configuration is possible to request, +does not mean that it is supported (e.g. building with gcc). That being said, +if this stops working for you on an unsupported build and you have a fix, please +send a PR! + +Example usage: + `./configure.py --backend=cpu --host_compiler=clang` + Will write a bazelrc to the root of the repo with the lines required to find + the clang in your path. If that isn't the correct clang, you can override like + `./configure.py --backend=cpu --clang_path=`. + +NOTE(ddunleavy): Lots of these things should probably be outside of configure.py +but are here because of complexity in `cuda_configure.bzl` and the TF bazelrc. +Once XLA has it's own bazelrc, and cuda_configure.bzl is replaced or refactored, +we can probably make this file smaller. + +TODO(ddunleavy): add more thorough validation. +""" +import argparse +import dataclasses +import enum +import logging +import os +import pathlib +import shutil +import subprocess +import sys +from typing import Optional + +_REQUIRED_CUDA_LIBRARIES = ["cublas", "cuda", "cudnn"] +_DEFAULT_BUILD_AND_TEST_TAG_FILTERS = ("-no_oss",) +# Assume we are being invoked from the symlink at the root of the repo +_XLA_SRC_ROOT = pathlib.Path(__file__).absolute().parent +_FIND_CUDA_CONFIG = str( + _XLA_SRC_ROOT + / "third_party" + / "tsl" + / "third_party" + / "gpus" + / "find_cuda_config.py" +) +_XLA_BAZELRC_NAME = "xla_configure.bazelrc" +_KW_ONLY_IF_PYTHON310 = {"kw_only": True} if sys.version_info >= (3, 10) else {} + + +def _find_executable(executable: str) -> Optional[str]: + logging.info("Trying to find path to %s...", executable) + # Resolving the symlink is necessary for finding system headers. + if unresolved_path := shutil.which(executable): + return str(pathlib.Path(unresolved_path).resolve()) + return None + + +def _find_executable_or_die(executable: str) -> str: + """Finds executable and resolves symlinks or raises RuntimeError. + + Resolving symlinks is sometimes necessary for finding system headers. + + Args: + executable: The name of the executable that we want to find. + + Returns: + The path to the executable we are looking for. + Raises: + RuntimeError: if path to the executable cannot be found. + """ + resolved_path_to_exe = _find_executable(executable) + if resolved_path_to_exe is None: + raise RuntimeError( + f"Could not find executable `{executable}`! " + "Please change your $PATH or pass the path directly like" + f"`--{executable}_path=path/to/executable." + ) + logging.info("Found path to %s at %s", executable, resolved_path_to_exe) + + return resolved_path_to_exe + + +def _get_cuda_compute_capabilities_or_die() -> list[str]: + """Finds compute capabilities via nvidia-smi or rasies exception. + + Returns: + list of unique, sorted strings representing compute capabilities: + Raises: + RuntimeError: if path to nvidia-smi couldn't be found. + subprocess.CalledProcessError: if nvidia-smi process failed. + """ + try: + nvidia_smi = _find_executable_or_die("nvidia-smi") + nvidia_smi_proc = subprocess.run( + [nvidia_smi, "--query-gpu=compute_cap", "--format=csv,noheader"], + capture_output=True, + check=True, + text=True, + ) + # Command above returns a newline separated list of compute capabilities + # with possible repeats. So we should unique them and sort the final result. + capabilities = sorted(set(nvidia_smi_proc.stdout.strip().split("\n"))) + logging.info("Found CUDA compute capabilities: %s", capabilities) + return capabilities + except (RuntimeError, subprocess.CalledProcessError) as e: + logging.info( + "Could not find nvidia-smi, or nvidia-smi command failed. Please pass" + " capabilities directly using --cuda_compute_capabilities." + ) + raise e + + +def _get_clang_major_version(path_to_clang: str) -> int: + """Gets the major version of the clang at `path_to_clang`. + + Args: + path_to_clang: Path to a clang executable + + Returns: + The major version. + """ + logging.info("Running echo __clang_major__ | %s -E -P -", path_to_clang) + clang_version_proc = subprocess.run( + [path_to_clang, "-E", "-P", "-"], + input="__clang_major__", + check=True, + capture_output=True, + text=True, + ) + major_version = int(clang_version_proc.stdout) + logging.info("%s reports major version %s.", path_to_clang, major_version) + + return major_version + + +class ArgparseableEnum(enum.Enum): + """Enum base class with helper methods for working with argparse. + + Example usage: + ``` + class Fruit(ArgparseableEnum): + APPLE = enum.auto() + + # argparse setup + parser.add_argument("--fruit", type=Fruit.from_str, choices=list(Fruit)) + ``` + Users can pass strings like `--fruit=apple` with nice error messages and the + parser will get the corresponding enum value. + + NOTE: PyType gets confused when this class is used to create Enums in the + functional style like `ArgparseableEnum("Fruit", ["APPLE", "BANANA"])`. + """ + + def __str__(self): + return self.name + + @classmethod + def from_str(cls, s): + s = s.upper() + try: + return cls[s] + except KeyError: + # Sloppy looking exception handling, but argparse will catch ValueError + # and give a pleasant error message. KeyError would not work here. + raise ValueError # pylint: disable=raise-missing-from + + +class Backend(ArgparseableEnum): + CPU = enum.auto() + CUDA = enum.auto() + ROCM = enum.auto() + + +class HostCompiler(ArgparseableEnum): + CLANG = enum.auto() + GCC = enum.auto() + + +class CudaCompiler(ArgparseableEnum): + CLANG = enum.auto() + NVCC = enum.auto() + + +class OS(ArgparseableEnum): + LINUX = enum.auto() + MACOS = enum.auto() + WINDOWS = enum.auto() + + +@dataclasses.dataclass(**_KW_ONLY_IF_PYTHON310) +class DiscoverablePathsAndVersions: + """Paths to various tools and libraries needed to build XLA. + + This class is where all 'stateful' activity should happen, like trying to read + environment variables or looking for things in the $PATH. An instance that has + all fields set should not try to do any of these things though, so that this + file can remain unit testable. + """ + + clang_path: Optional[str] = None + clang_major_version: Optional[int] = None + gcc_path: Optional[str] = None + lld_path: Optional[str] = None + ld_library_path: Optional[str] = None + + # CUDA specific + cublas_version: Optional[str] = None + cuda_toolkit_path: Optional[str] = None + cuda_compute_capabilities: Optional[list[str]] = None + cudnn_version: Optional[str] = None + nccl_version: Optional[str] = None + + def get_relevant_paths_and_versions(self, config: "XLAConfigOptions"): + """Gets paths and versions as needed by the config. + + Args: + config: XLAConfigOptions instance that determines what paths and versions + to try to autoconfigure. + """ + if self.ld_library_path is None: + self.ld_library_path = os.environ.get("LD_LIBRARY_PATH", None) + + if config.host_compiler == HostCompiler.CLANG: + self.clang_path = self.clang_path or _find_executable_or_die("clang") + self.clang_major_version = ( + self.clang_major_version or _get_clang_major_version(self.clang_path) + ) + + # Notably, we don't use `_find_executable_or_die` for lld, as it changes + # which commands it accepts based on it's name! ld.lld is symlinked to a + # different executable just called lld, which should not be invoked + # directly. + self.lld_path = self.lld_path or shutil.which("ld.lld") + elif config.host_compiler == HostCompiler.GCC: + self.gcc_path = self.gcc_path or _find_executable_or_die("gcc") + + if config.backend == Backend.CUDA: + if config.cuda_compiler == CudaCompiler.CLANG: + self.clang_path = self.clang_path or _find_executable_or_die("clang") + + if not self.cuda_compute_capabilities: + self.cuda_compute_capabilities = _get_cuda_compute_capabilities_or_die() + + self._get_cuda_libraries_paths_and_versions_if_needed(config) + + def _get_cuda_libraries_paths_and_versions_if_needed( + self, config: "XLAConfigOptions" + ): + """Gets cuda paths and versions if user left any unspecified. + + This uses `find_cuda_config.py` to find versions for all libraries in + `_REQUIRED_CUDA_LIBRARIES`. + + Args: + config: config that determines which libraries should be found. + """ + should_find_nccl = config.using_nccl and self.nccl_version is None + any_cuda_config_unset = any([ + self.cublas_version is None, + self.cuda_toolkit_path is None, + self.cudnn_version is None, + should_find_nccl, + ]) + + maybe_nccl = ["nccl"] if should_find_nccl else [] + + if any_cuda_config_unset: + logging.info( + "Some CUDA config versions and paths were not provided, " + "so trying to find them using find_cuda_config.py" + ) + try: + find_cuda_config_proc = subprocess.run( + [ + sys.executable, + _FIND_CUDA_CONFIG, + *_REQUIRED_CUDA_LIBRARIES, + *maybe_nccl, + ], + capture_output=True, + check=True, + text=True, + ) + except subprocess.CalledProcessError as e: + logging.info("Command %s failed. Is CUDA installed?", e.cmd) + logging.info("Dumping %s ouptut:\n %s", e.cmd, e.output) + raise e + + cuda_config = dict( + tuple(line.split(": ")) + for line in find_cuda_config_proc.stdout.strip().split("\n") + ) + + self.cublas_version = self.cublas_version or cuda_config["cublas_version"] + self.cuda_toolkit_path = ( + self.cuda_toolkit_path or cuda_config["cuda_toolkit_path"] + ) + self.cudnn_version = self.cudnn_version or cuda_config["cudnn_version"] + if should_find_nccl: + self.nccl_version = self.nccl_version or cuda_config["nccl_version"] + + +@dataclasses.dataclass(frozen=True, **_KW_ONLY_IF_PYTHON310) +class XLAConfigOptions: + """Represents XLA configuration options.""" + + backend: Backend + os: OS + python_bin_path: str + host_compiler: HostCompiler + compiler_options: list[str] + + # CUDA specific + cuda_compiler: CudaCompiler + using_nccl: bool + using_tensorrt: bool + + def to_bazelrc_lines( + self, + dpav: DiscoverablePathsAndVersions, + ) -> list[str]: + """Creates a bazelrc given an XLAConfigOptions. + + Necessary paths are provided by the user, or retrieved via + `self._get_relevant_paths`. + + Args: + dpav: DiscoverablePathsAndVersions that may hold user-specified paths and + versions. The dpav will then read from `self` to determine what to try + to auto-configure. + + Returns: + The lines of a bazelrc. + """ + dpav.get_relevant_paths_and_versions(self) + rc = [] + build_and_test_tag_filters = list(_DEFAULT_BUILD_AND_TEST_TAG_FILTERS) + + # Platform independent options based on host compiler + if self.host_compiler == HostCompiler.GCC: + rc.append(f"build --action_env GCC_HOST_COMPILER_PATH={dpav.gcc_path}") + elif self.host_compiler == HostCompiler.CLANG: + rc.append(f"build --action_env CLANG_COMPILER_PATH={dpav.clang_path}") + rc.append(f"build --repo_env CC={dpav.clang_path}") + rc.append(f"build --repo_env BAZEL_COMPILER={dpav.clang_path}") + self.compiler_options.append("-Wno-error=unused-command-line-argument") + if dpav.lld_path: + rc.append(f"build --linkopt --ld-path={dpav.lld_path}") + + if self.backend == Backend.CPU: + build_and_test_tag_filters.append("-gpu") + + elif self.backend == Backend.CUDA: + compiler_pair = self.cuda_compiler, self.host_compiler + + if compiler_pair == (CudaCompiler.CLANG, HostCompiler.CLANG): + rc.append("build --config cuda_clang") + rc.append( + f"build --action_env CLANG_CUDA_COMPILER_PATH={dpav.clang_path}" + ) + elif compiler_pair == (CudaCompiler.NVCC, HostCompiler.CLANG): + rc.append("build --config nvcc_clang") + # This is demanded by cuda_configure.bzl + rc.append( + f"build --action_env CLANG_CUDA_COMPILER_PATH={dpav.clang_path}" + ) + elif compiler_pair == (CudaCompiler.NVCC, HostCompiler.GCC): + rc.append("build --config cuda") + else: + raise NotImplementedError( + "CUDA clang with host compiler gcc not supported" + ) + + # Lines needed for CUDA backend regardless of CUDA/host compiler + rc.append( + f"build --action_env CUDA_TOOLKIT_PATH={dpav.cuda_toolkit_path}" + ) + rc.append(f"build --action_env TF_CUBLAS_VERSION={dpav.cublas_version}") + rc.append( + "build --action_env" + f" TF_CUDA_COMPUTE_CAPABILITIES={','.join(dpav.cuda_compute_capabilities)}" + ) + rc.append(f"build --action_env TF_CUDNN_VERSION={dpav.cudnn_version}") + rc.append(f"build --repo_env TF_NEED_TENSORRT={int(self.using_tensorrt)}") + if self.using_nccl: + rc.append(f"build --action_env TF_NCCL_VERSION={dpav.nccl_version}") + else: + rc.append("build --config nonccl") + elif self.backend == Backend.ROCM: + pass + + # Lines that are added for every backend + if dpav.ld_library_path: + rc.append(f"build --action_env LD_LIBRARY_PATH={dpav.ld_library_path}") + + if dpav.clang_major_version in (16, 17): + self.compiler_options.append("-Wno-gnu-offsetof-extensions") + + rc.append(f"build --action_env PYTHON_BIN_PATH={self.python_bin_path}") + rc.append(f"build --python_path {self.python_bin_path}") + rc.append("test --test_env LD_LIBRARY_PATH") + rc.append("test --test_size_filters small,medium") + + rc.extend([ + f"build --copt {compiler_option}" + for compiler_option in self.compiler_options + ]) + + # Add build and test tag filters + build_and_test_tag_filters = ",".join(build_and_test_tag_filters) + rc.append(f"build --build_tag_filters {build_and_test_tag_filters}") + rc.append(f"build --test_tag_filters {build_and_test_tag_filters}") + rc.append(f"test --build_tag_filters {build_and_test_tag_filters}") + rc.append(f"test --test_tag_filters {build_and_test_tag_filters}") + + return rc + + +def _parse_args(): + """Creates an argparse.ArgumentParser and parses arguments.""" + comma_separated_list = lambda l: [s.strip() for s in l.split(",")] + + parser = argparse.ArgumentParser(allow_abbrev=False) + parser.add_argument( + "--backend", + type=Backend.from_str, + choices=list(Backend), + required=True, + ) + parser.add_argument( + "--os", type=OS.from_str, choices=list(OS), default="linux" + ) + parser.add_argument( + "--host_compiler", + type=HostCompiler.from_str, + choices=list(HostCompiler), + default="clang", + ) + parser.add_argument( + "--cuda_compiler", + type=CudaCompiler.from_str, + choices=list(CudaCompiler), + default="nvcc", + ) + parser.add_argument( + "--cuda_compute_capabilities", + type=comma_separated_list, + default=None, + ) + parser.add_argument("--python_bin_path", default=sys.executable) + parser.add_argument( + "--compiler_options", + type=comma_separated_list, + default="-Wno-sign-compare", + ) + parser.add_argument("--nccl", action="store_true") + parser.add_argument("--tensorrt", action="store_true") + + # Path and version overrides + path_help = "Optional: will be found on PATH if possible." + parser.add_argument("--clang_path", help=path_help) + parser.add_argument("--gcc_path", help=path_help) + parser.add_argument( + "--ld_library_path", + help=( + "Optional: will be automatically taken from the current environment" + " if flag is not set" + ), + ) + parser.add_argument("--lld_path", help=path_help) + + # CUDA specific + find_cuda_config_help = ( + "Optional: will be found using `find_cuda_config.py` if flag is not set." + ) + parser.add_argument("--cublas_version", help=find_cuda_config_help) + parser.add_argument("--cuda_toolkit_path", help=find_cuda_config_help) + parser.add_argument("--cudnn_version", help=find_cuda_config_help) + parser.add_argument("--nccl_version", help=find_cuda_config_help) + + return parser.parse_args() + + +def main(): + # Setup logging + logging.basicConfig() + logging.getLogger().setLevel(logging.INFO) + + args = _parse_args() + + config = XLAConfigOptions( + backend=args.backend, + os=args.os, + host_compiler=args.host_compiler, + cuda_compiler=args.cuda_compiler, + python_bin_path=args.python_bin_path, + compiler_options=args.compiler_options, + using_nccl=args.nccl, + using_tensorrt=args.tensorrt, + ) + + bazelrc_lines = config.to_bazelrc_lines( + DiscoverablePathsAndVersions( + clang_path=args.clang_path, + gcc_path=args.gcc_path, + ld_library_path=args.ld_library_path, + cublas_version=args.cublas_version, + cuda_compute_capabilities=args.cuda_compute_capabilities, + cuda_toolkit_path=args.cuda_toolkit_path, + cudnn_version=args.cudnn_version, + nccl_version=args.nccl_version, + ) + ) + + bazelrc_path = _XLA_SRC_ROOT / _XLA_BAZELRC_NAME + bazelrc_contents = "\n".join(bazelrc_lines) + "\n" + + with (bazelrc_path).open("w") as f: + logging.info("Writing bazelrc to %s...", bazelrc_path) + f.write(bazelrc_contents) + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/third_party/xla/build_tools/configure/configure_test.py b/third_party/xla/build_tools/configure/configure_test.py new file mode 100644 index 00000000000000..c952c8f9241f4f --- /dev/null +++ b/third_party/xla/build_tools/configure/configure_test.py @@ -0,0 +1,179 @@ +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from absl.testing import absltest + +from xla.build_tools import test_utils +from xla.build_tools.configure import configure + + +XLAConfigOptions = configure.XLAConfigOptions +DiscoverablePathsAndVersions = configure.DiscoverablePathsAndVersions +Backend = configure.Backend +HostCompiler = configure.HostCompiler +CudaCompiler = configure.CudaCompiler +OS = configure.OS + +_PYTHON_BIN_PATH = "/usr/bin/python3" +_CLANG_PATH = "/usr/lib/llvm-17/bin/clang" +_GCC_PATH = "/usr/bin/gcc" +_COMPILER_OPTIONS = ("-Wno-sign-compare",) + +# CUDA specific paths and versions +_CUDA_SPECIFIC_PATHS_AND_VERSIONS = { + "cublas_version": "12.3", + "cuda_toolkit_path": "/usr/local/cuda-12.2", + "cuda_compute_capabilities": ["7.5"], + "cudnn_version": "8", + "ld_library_path": "/usr/local/nvidia/lib:/usr/local/nvidia/lib64", + "nccl_version": "2", +} + + +class ConfigureTest(absltest.TestCase): + + @classmethod + def setUpClass(cls): + super().setUpClass() + + testdata = ( + test_utils.xla_src_root() / "build_tools" / "configure" / "testdata" + ) + + with (testdata / "clang.bazelrc").open() as f: + cls.clang_bazelrc_lines = [line.strip() for line in f.readlines()] + + with (testdata / "gcc.bazelrc").open() as f: + cls.gcc_bazelrc_lines = [line.strip() for line in f.readlines()] + + with (testdata / "cuda_clang.bazelrc").open() as f: + cls.cuda_clang_bazelrc_lines = [line.strip() for line in f.readlines()] + + with (testdata / "nvcc_clang.bazelrc").open() as f: + cls.nvcc_clang_bazelrc_lines = [line.strip() for line in f.readlines()] + + with (testdata / "nvcc_gcc.bazelrc").open() as f: + cls.nvcc_gcc_bazelrc_lines = [line.strip() for line in f.readlines()] + + def test_clang_bazelrc(self): + config = XLAConfigOptions( + backend=Backend.CPU, + os=OS.LINUX, + python_bin_path=_PYTHON_BIN_PATH, + host_compiler=HostCompiler.CLANG, + compiler_options=list(_COMPILER_OPTIONS), + cuda_compiler=CudaCompiler.NVCC, + using_nccl=False, + using_tensorrt=False, + ) + + bazelrc_lines = config.to_bazelrc_lines( + DiscoverablePathsAndVersions( + clang_path=_CLANG_PATH, + ld_library_path="", + clang_major_version=17, + ) + ) + + self.assertEqual(bazelrc_lines, self.clang_bazelrc_lines) + + def test_gcc_bazelrc(self): + config = XLAConfigOptions( + backend=Backend.CPU, + os=OS.LINUX, + python_bin_path=_PYTHON_BIN_PATH, + host_compiler=HostCompiler.GCC, + compiler_options=list(_COMPILER_OPTIONS), + cuda_compiler=CudaCompiler.NVCC, + using_nccl=False, + using_tensorrt=False, + ) + + bazelrc_lines = config.to_bazelrc_lines( + DiscoverablePathsAndVersions( + gcc_path=_GCC_PATH, + ld_library_path="", + ) + ) + + self.assertEqual(bazelrc_lines, self.gcc_bazelrc_lines) + + def test_cuda_clang_bazelrc(self): + config = XLAConfigOptions( + backend=Backend.CUDA, + os=OS.LINUX, + python_bin_path=_PYTHON_BIN_PATH, + host_compiler=HostCompiler.CLANG, + compiler_options=list(_COMPILER_OPTIONS), + cuda_compiler=CudaCompiler.CLANG, + using_nccl=False, + using_tensorrt=False, + ) + + bazelrc_lines = config.to_bazelrc_lines( + DiscoverablePathsAndVersions( + clang_path=_CLANG_PATH, + clang_major_version=17, + **_CUDA_SPECIFIC_PATHS_AND_VERSIONS, + ) + ) + + self.assertEqual(bazelrc_lines, self.cuda_clang_bazelrc_lines) + + def test_nvcc_clang_bazelrc(self): + config = XLAConfigOptions( + backend=Backend.CUDA, + os=OS.LINUX, + python_bin_path=_PYTHON_BIN_PATH, + host_compiler=HostCompiler.CLANG, + compiler_options=list(_COMPILER_OPTIONS), + cuda_compiler=CudaCompiler.NVCC, + using_nccl=False, + using_tensorrt=False, + ) + + bazelrc_lines = config.to_bazelrc_lines( + DiscoverablePathsAndVersions( + clang_path=_CLANG_PATH, + clang_major_version=17, + **_CUDA_SPECIFIC_PATHS_AND_VERSIONS, + ) + ) + + self.assertEqual(bazelrc_lines, self.nvcc_clang_bazelrc_lines) + + def test_nvcc_gcc_bazelrc(self): + config = XLAConfigOptions( + backend=Backend.CUDA, + os=OS.LINUX, + python_bin_path=_PYTHON_BIN_PATH, + host_compiler=HostCompiler.GCC, + compiler_options=list(_COMPILER_OPTIONS), + cuda_compiler=CudaCompiler.NVCC, + using_nccl=False, + using_tensorrt=False, + ) + + bazelrc_lines = config.to_bazelrc_lines( + DiscoverablePathsAndVersions( + gcc_path=_GCC_PATH, + **_CUDA_SPECIFIC_PATHS_AND_VERSIONS, + ) + ) + + self.assertEqual(bazelrc_lines, self.nvcc_gcc_bazelrc_lines) + + +if __name__ == "__main__": + absltest.main() diff --git a/third_party/xla/build_tools/configure/testdata/clang.bazelrc b/third_party/xla/build_tools/configure/testdata/clang.bazelrc new file mode 100644 index 00000000000000..317be65966633d --- /dev/null +++ b/third_party/xla/build_tools/configure/testdata/clang.bazelrc @@ -0,0 +1,14 @@ +build --action_env CLANG_COMPILER_PATH=/usr/lib/llvm-17/bin/clang +build --repo_env CC=/usr/lib/llvm-17/bin/clang +build --repo_env BAZEL_COMPILER=/usr/lib/llvm-17/bin/clang +build --action_env PYTHON_BIN_PATH=/usr/bin/python3 +build --python_path /usr/bin/python3 +test --test_env LD_LIBRARY_PATH +test --test_size_filters small,medium +build --copt -Wno-sign-compare +build --copt -Wno-error=unused-command-line-argument +build --copt -Wno-gnu-offsetof-extensions +build --build_tag_filters -no_oss,-gpu +build --test_tag_filters -no_oss,-gpu +test --build_tag_filters -no_oss,-gpu +test --test_tag_filters -no_oss,-gpu diff --git a/third_party/xla/build_tools/configure/testdata/cuda_clang.bazelrc b/third_party/xla/build_tools/configure/testdata/cuda_clang.bazelrc new file mode 100644 index 00000000000000..b998cf06935f33 --- /dev/null +++ b/third_party/xla/build_tools/configure/testdata/cuda_clang.bazelrc @@ -0,0 +1,23 @@ +build --action_env CLANG_COMPILER_PATH=/usr/lib/llvm-17/bin/clang +build --repo_env CC=/usr/lib/llvm-17/bin/clang +build --repo_env BAZEL_COMPILER=/usr/lib/llvm-17/bin/clang +build --config cuda_clang +build --action_env CLANG_CUDA_COMPILER_PATH=/usr/lib/llvm-17/bin/clang +build --action_env CUDA_TOOLKIT_PATH=/usr/local/cuda-12.2 +build --action_env TF_CUBLAS_VERSION=12.3 +build --action_env TF_CUDA_COMPUTE_CAPABILITIES=7.5 +build --action_env TF_CUDNN_VERSION=8 +build --repo_env TF_NEED_TENSORRT=0 +build --config nonccl +build --action_env LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 +build --action_env PYTHON_BIN_PATH=/usr/bin/python3 +build --python_path /usr/bin/python3 +test --test_env LD_LIBRARY_PATH +test --test_size_filters small,medium +build --copt -Wno-sign-compare +build --copt -Wno-error=unused-command-line-argument +build --copt -Wno-gnu-offsetof-extensions +build --build_tag_filters -no_oss +build --test_tag_filters -no_oss +test --build_tag_filters -no_oss +test --test_tag_filters -no_oss diff --git a/third_party/xla/build_tools/configure/testdata/gcc.bazelrc b/third_party/xla/build_tools/configure/testdata/gcc.bazelrc new file mode 100644 index 00000000000000..8eefec15ee8efb --- /dev/null +++ b/third_party/xla/build_tools/configure/testdata/gcc.bazelrc @@ -0,0 +1,10 @@ +build --action_env GCC_HOST_COMPILER_PATH=/usr/bin/gcc +build --action_env PYTHON_BIN_PATH=/usr/bin/python3 +build --python_path /usr/bin/python3 +test --test_env LD_LIBRARY_PATH +test --test_size_filters small,medium +build --copt -Wno-sign-compare +build --build_tag_filters -no_oss,-gpu +build --test_tag_filters -no_oss,-gpu +test --build_tag_filters -no_oss,-gpu +test --test_tag_filters -no_oss,-gpu diff --git a/third_party/xla/build_tools/configure/testdata/nvcc_clang.bazelrc b/third_party/xla/build_tools/configure/testdata/nvcc_clang.bazelrc new file mode 100644 index 00000000000000..912dc50faff4c1 --- /dev/null +++ b/third_party/xla/build_tools/configure/testdata/nvcc_clang.bazelrc @@ -0,0 +1,23 @@ +build --action_env CLANG_COMPILER_PATH=/usr/lib/llvm-17/bin/clang +build --repo_env CC=/usr/lib/llvm-17/bin/clang +build --repo_env BAZEL_COMPILER=/usr/lib/llvm-17/bin/clang +build --config nvcc_clang +build --action_env CLANG_CUDA_COMPILER_PATH=/usr/lib/llvm-17/bin/clang +build --action_env CUDA_TOOLKIT_PATH=/usr/local/cuda-12.2 +build --action_env TF_CUBLAS_VERSION=12.3 +build --action_env TF_CUDA_COMPUTE_CAPABILITIES=7.5 +build --action_env TF_CUDNN_VERSION=8 +build --repo_env TF_NEED_TENSORRT=0 +build --config nonccl +build --action_env LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 +build --action_env PYTHON_BIN_PATH=/usr/bin/python3 +build --python_path /usr/bin/python3 +test --test_env LD_LIBRARY_PATH +test --test_size_filters small,medium +build --copt -Wno-sign-compare +build --copt -Wno-error=unused-command-line-argument +build --copt -Wno-gnu-offsetof-extensions +build --build_tag_filters -no_oss +build --test_tag_filters -no_oss +test --build_tag_filters -no_oss +test --test_tag_filters -no_oss diff --git a/third_party/xla/build_tools/configure/testdata/nvcc_gcc.bazelrc b/third_party/xla/build_tools/configure/testdata/nvcc_gcc.bazelrc new file mode 100644 index 00000000000000..863209697362de --- /dev/null +++ b/third_party/xla/build_tools/configure/testdata/nvcc_gcc.bazelrc @@ -0,0 +1,18 @@ +build --action_env GCC_HOST_COMPILER_PATH=/usr/bin/gcc +build --config cuda +build --action_env CUDA_TOOLKIT_PATH=/usr/local/cuda-12.2 +build --action_env TF_CUBLAS_VERSION=12.3 +build --action_env TF_CUDA_COMPUTE_CAPABILITIES=7.5 +build --action_env TF_CUDNN_VERSION=8 +build --repo_env TF_NEED_TENSORRT=0 +build --config nonccl +build --action_env LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 +build --action_env PYTHON_BIN_PATH=/usr/bin/python3 +build --python_path /usr/bin/python3 +test --test_env LD_LIBRARY_PATH +test --test_size_filters small,medium +build --copt -Wno-sign-compare +build --build_tag_filters -no_oss +build --test_tag_filters -no_oss +test --build_tag_filters -no_oss +test --test_tag_filters -no_oss diff --git a/third_party/xla/build_tools/lint/BUILD b/third_party/xla/build_tools/lint/BUILD index 9cb4f866ce7af8..b4b825c925425c 100644 --- a/third_party/xla/build_tools/lint/BUILD +++ b/third_party/xla/build_tools/lint/BUILD @@ -17,7 +17,6 @@ load("//xla:pytype.default.bzl", "pytype_strict_library") # Placeholder: load py_test package( - default_visibility = ["//visibility:public"], # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) @@ -46,9 +45,9 @@ py_test( "testdata/bad_cc.diff", "testdata/important_cc.diff", ], - tags = ["no_oss"], deps = [ ":check_contents", + "//build_tools:test_utils", "@absl_py//absl/testing:absltest", ], ) @@ -61,9 +60,9 @@ py_test( "testdata/crosstool.diff", "testdata/important_cc.diff", ], - tags = ["no_oss"], deps = [ ":diff_parser", + "//build_tools:test_utils", "@absl_py//absl/testing:absltest", ], ) @@ -71,7 +70,6 @@ py_test( py_test( name = "generate_compile_commands_test", srcs = ["generate_compile_commands_test.py"], - tags = ["no_oss"], deps = [ ":generate_compile_commands", "@absl_py//absl/testing:absltest", diff --git a/third_party/xla/build_tools/lint/check_contents_test.py b/third_party/xla/build_tools/lint/check_contents_test.py index 8781076ed2dacb..21d58785c6f6a5 100644 --- a/third_party/xla/build_tools/lint/check_contents_test.py +++ b/third_party/xla/build_tools/lint/check_contents_test.py @@ -14,6 +14,7 @@ # ============================================================================ from absl.testing import absltest +from xla.build_tools import test_utils from xla.build_tools.lint import check_contents from xla.build_tools.lint import diff_parser @@ -24,11 +25,11 @@ class CheckDiffsTest(absltest.TestCase): def setUpClass(cls): super().setUpClass() - base_path = "third_party/xla/build_tools/lint" - with open(f"{base_path}/testdata/bad_cc.diff") as f: + testdata = test_utils.xla_src_root() / "build_tools" / "lint" / "testdata" + with (testdata / "bad_cc.diff").open() as f: cls.bad_cc_hunks = diff_parser.parse_hunks(f.read()) - with open(f"{base_path}/testdata/important_cc.diff") as f: + with (testdata / "important_cc.diff").open() as f: cls.important_cc_hunks = diff_parser.parse_hunks(f.read()) def test_check_good_diff(self): diff --git a/third_party/xla/build_tools/lint/diff_parser_test.py b/third_party/xla/build_tools/lint/diff_parser_test.py index a21761040fc347..787020cc865033 100644 --- a/third_party/xla/build_tools/lint/diff_parser_test.py +++ b/third_party/xla/build_tools/lint/diff_parser_test.py @@ -14,6 +14,7 @@ # ============================================================================ from absl.testing import absltest +from xla.build_tools import test_utils from xla.build_tools.lint import diff_parser @@ -23,15 +24,15 @@ class ParseDiffTest(absltest.TestCase): def setUpClass(cls): super().setUpClass() - base_path = "third_party/xla/build_tools/lint" + testdata = test_utils.xla_src_root() / "build_tools" / "lint" / "testdata" - with open(f"{base_path}/testdata/bad_cc.diff") as f: + with (testdata / "bad_cc.diff").open() as f: cls.bad_cc_diff = f.read() - with open(f"{base_path}/testdata/important_cc.diff") as f: + with (testdata / "important_cc.diff").open() as f: cls.important_cc_diff = f.read() - with open(f"{base_path}/testdata/crosstool.diff") as f: + with (testdata / "crosstool.diff").open() as f: cls.crosstool_diff = f.read() def test_parse_important_cc_diff(self): diff --git a/third_party/xla/build_tools/test_utils.py b/third_party/xla/build_tools/test_utils.py new file mode 100644 index 00000000000000..1d9672379d9ca3 --- /dev/null +++ b/third_party/xla/build_tools/test_utils.py @@ -0,0 +1,28 @@ +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test utils for python tests in XLA.""" +import os +import pathlib + + +def xla_src_root() -> pathlib.Path: + """Gets the path to the root of the XLA source tree.""" + is_oss = "BAZEL_TEST" in os.environ + test_srcdir = os.environ["TEST_SRCDIR"] + test_workspace = os.environ["TEST_WORKSPACE"] + if is_oss: + return pathlib.Path(test_srcdir) / test_workspace + else: + return pathlib.Path(test_srcdir) / test_workspace / "third_party" / "xla" diff --git a/third_party/xla/configure b/third_party/xla/configure deleted file mode 100755 index e43908e39da0cc..00000000000000 --- a/third_party/xla/configure +++ /dev/null @@ -1,15 +0,0 @@ -#!/usr/bin/env bash - -set -e -set -o pipefail - -if [ -z "$PYTHON_BIN_PATH" ]; then - PYTHON_BIN_PATH=$(which python3 || which python || true) -fi - -# Set all env variables -CONFIGURE_DIR=$(dirname "$0") -"$PYTHON_BIN_PATH" "${CONFIGURE_DIR}/configure.py" "$@" - -echo "Configuration finished" - diff --git a/third_party/xla/configure.cmd b/third_party/xla/configure.cmd deleted file mode 100644 index 4efb802b42155c..00000000000000 --- a/third_party/xla/configure.cmd +++ /dev/null @@ -1,20 +0,0 @@ -:: Copyright 2019 The OpenXLA Authors. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -:: WARRANTIES OR CONDITIONS OF ANY KIND< either express or implied. See the -:: License for the specific language governing permissions and limitations under -:: the License. - -@echo off - -set configure_dir=%~dp0 -set configure_dir=%configure_dir:~0,-1% -python "%configure_dir%\configure.py" %* || ( exit /b ) -echo Configuration finished diff --git a/third_party/xla/configure.py b/third_party/xla/configure.py deleted file mode 100644 index 949a74d05ba8d8..00000000000000 --- a/third_party/xla/configure.py +++ /dev/null @@ -1,1194 +0,0 @@ -# Copyright 2017 The OpenXLA Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""configure script to get build parameters from user.""" - -import argparse -import os -import pathlib -import platform -import re -import subprocess -import sys - -# pylint: disable=g-import-not-at-top -try: - from shutil import which -except ImportError: - from distutils.spawn import find_executable as which -# pylint: enable=g-import-not-at-top - -_DEFAULT_CUDA_VERSION = '11' -_DEFAULT_CUDNN_VERSION = '2' -_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '5.2,7.0' - -_DEFAULT_PROMPT_ASK_ATTEMPTS = 10 - -_TF_BAZELRC_FILENAME = '.tf_configure.bazelrc' -_TF_WORKSPACE_ROOT = '' -_TF_BAZELRC = '' -_TF_CURRENT_BAZEL_VERSION = None - - -class UserInputError(Exception): - pass - - -def is_windows(): - return platform.system() == 'Windows' - - -def is_linux(): - return platform.system() == 'Linux' - - -def is_macos(): - return platform.system() == 'Darwin' - - -def is_ppc64le(): - return platform.machine() == 'ppc64le' - - -def is_cygwin(): - return platform.system().startswith('CYGWIN_NT') - - -def get_input(question): - try: - try: - answer = raw_input(question) - except NameError: - answer = input(question) # pylint: disable=bad-builtin - except EOFError: - answer = '' - return answer - - -def write_to_bazelrc(line): - with open(_TF_BAZELRC, 'a') as f: - f.write(line + '\n') - - -def write_action_env_to_bazelrc(var_name, var): - write_to_bazelrc('build --action_env {}="{}"'.format(var_name, str(var))) - - -def run_shell(cmd, allow_non_zero=False, stderr=None): - if stderr is None: - stderr = sys.stdout - if allow_non_zero: - try: - output = subprocess.check_output(cmd, stderr=stderr) - except subprocess.CalledProcessError as e: - output = e.output - else: - output = subprocess.check_output(cmd, stderr=stderr) - return output.decode('UTF-8').strip() - - -def cygpath(path): - """Convert path from posix to windows.""" - return os.path.abspath(path).replace('\\', '/') - - -def get_python_path(environ_cp, python_bin_path): - """Get the python site package paths.""" - python_paths = [] - if environ_cp.get('PYTHONPATH'): - python_paths = environ_cp.get('PYTHONPATH').split(':') - try: - stderr = open(os.devnull, 'wb') - library_paths = run_shell([ - python_bin_path, '-c', - 'import site; print("\\n".join(site.getsitepackages()))' - ], - stderr=stderr).split('\n') - except subprocess.CalledProcessError: - library_paths = [ - run_shell([ - python_bin_path, '-c', - 'from distutils.sysconfig import get_python_lib;' - 'print(get_python_lib())' - ]) - ] - - all_paths = set(python_paths + library_paths) - # Sort set so order is deterministic - all_paths = sorted(all_paths) - - paths = [] - for path in all_paths: - if os.path.isdir(path): - paths.append(path) - return paths - - -def get_python_major_version(python_bin_path): - """Get the python major version.""" - return run_shell([python_bin_path, '-c', 'import sys; print(sys.version[0])']) - - -def setup_python(environ_cp): - """Setup python related env variables.""" - # Get PYTHON_BIN_PATH, default is the current running python. - default_python_bin_path = sys.executable - ask_python_bin_path = ('Please specify the location of python. [Default is ' - '{}]: ').format(default_python_bin_path) - while True: - python_bin_path = get_from_env_or_user_or_default(environ_cp, - 'PYTHON_BIN_PATH', - ask_python_bin_path, - default_python_bin_path) - # Check if the path is valid - if os.path.isfile(python_bin_path) and os.access(python_bin_path, os.X_OK): - break - elif not os.path.exists(python_bin_path): - print('Invalid python path: {} cannot be found.'.format(python_bin_path)) - else: - print('{} is not executable. Is it the python binary?'.format( - python_bin_path)) - environ_cp['PYTHON_BIN_PATH'] = '' - - # Convert python path to Windows style before checking lib and version - if is_windows() or is_cygwin(): - python_bin_path = cygpath(python_bin_path) - - # Get PYTHON_LIB_PATH - python_lib_path = environ_cp.get('PYTHON_LIB_PATH') - if not python_lib_path: - python_lib_paths = get_python_path(environ_cp, python_bin_path) - if environ_cp.get('USE_DEFAULT_PYTHON_LIB_PATH') == '1': - python_lib_path = python_lib_paths[0] - else: - print('Found possible Python library paths:\n %s' % - '\n '.join(python_lib_paths)) - default_python_lib_path = python_lib_paths[0] - python_lib_path = get_input( - 'Please input the desired Python library path to use. ' - 'Default is [{}]\n'.format(python_lib_paths[0])) - if not python_lib_path: - python_lib_path = default_python_lib_path - environ_cp['PYTHON_LIB_PATH'] = python_lib_path - - python_major_version = get_python_major_version(python_bin_path) - if python_major_version == '2': - write_to_bazelrc('build --host_force_python=PY2') - - # Convert python path to Windows style before writing into bazel.rc - if is_windows() or is_cygwin(): - python_lib_path = cygpath(python_lib_path) - - # Set-up env variables used by python_configure.bzl - write_action_env_to_bazelrc('PYTHON_BIN_PATH', python_bin_path) - write_action_env_to_bazelrc('PYTHON_LIB_PATH', python_lib_path) - write_to_bazelrc('build --python_path=\"{}"'.format(python_bin_path)) - environ_cp['PYTHON_BIN_PATH'] = python_bin_path - - # If choosen python_lib_path is from a path specified in the PYTHONPATH - # variable, need to tell bazel to include PYTHONPATH - if environ_cp.get('PYTHONPATH'): - python_paths = environ_cp.get('PYTHONPATH').split(':') - if python_lib_path in python_paths: - write_action_env_to_bazelrc('PYTHONPATH', environ_cp.get('PYTHONPATH')) - - # Write tools/python_bin_path.sh - with open( - os.path.join(_TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'), - 'w') as f: - f.write('export PYTHON_BIN_PATH="{}"'.format(python_bin_path)) - - -def reset_tf_configure_bazelrc(): - """Reset file that contains customized config settings.""" - open(_TF_BAZELRC, 'w').close() - - -def get_var(environ_cp, - var_name, - query_item, - enabled_by_default, - question=None, - yes_reply=None, - no_reply=None): - """Get boolean input from user. - - If var_name is not set in env, ask user to enable query_item or not. If the - response is empty, use the default. - - Args: - environ_cp: copy of the os.environ. - var_name: string for name of environment variable, e.g. "TF_NEED_CUDA". - query_item: string for feature related to the variable, e.g. "CUDA for - Nvidia GPUs". - enabled_by_default: boolean for default behavior. - question: optional string for how to ask for user input. - yes_reply: optional string for reply when feature is enabled. - no_reply: optional string for reply when feature is disabled. - - Returns: - boolean value of the variable. - - Raises: - UserInputError: if an environment variable is set, but it cannot be - interpreted as a boolean indicator, assume that the user has made a - scripting error, and will continue to provide invalid input. - Raise the error to avoid infinitely looping. - """ - if not question: - question = 'Do you wish to build XLA with {} support?'.format( - query_item) - if not yes_reply: - yes_reply = '{} support will be enabled for XLA.'.format(query_item) - if not no_reply: - no_reply = 'No {}'.format(yes_reply) - - yes_reply += '\n' - no_reply += '\n' - - if enabled_by_default: - question += ' [Y/n]: ' - else: - question += ' [y/N]: ' - - var = environ_cp.get(var_name) - if var is not None: - var_content = var.strip().lower() - true_strings = ('1', 't', 'true', 'y', 'yes') - false_strings = ('0', 'f', 'false', 'n', 'no') - if var_content in true_strings: - var = True - elif var_content in false_strings: - var = False - else: - raise UserInputError( - 'Environment variable %s must be set as a boolean indicator.\n' - 'The following are accepted as TRUE : %s.\n' - 'The following are accepted as FALSE: %s.\n' - 'Current value is %s.' % - (var_name, ', '.join(true_strings), ', '.join(false_strings), var)) - - while var is None: - user_input_origin = get_input(question) - user_input = user_input_origin.strip().lower() - if user_input == 'y': - print(yes_reply) - var = True - elif user_input == 'n': - print(no_reply) - var = False - elif not user_input: - if enabled_by_default: - print(yes_reply) - var = True - else: - print(no_reply) - var = False - else: - print('Invalid selection: {}'.format(user_input_origin)) - return var - - -def set_action_env_var(environ_cp, - var_name, - query_item, - enabled_by_default, - question=None, - yes_reply=None, - no_reply=None, - bazel_config_name=None): - """Set boolean action_env variable. - - Ask user if query_item will be enabled. Default is used if no input is given. - Set environment variable and write to .bazelrc. - - Args: - environ_cp: copy of the os.environ. - var_name: string for name of environment variable, e.g. "TF_NEED_CUDA". - query_item: string for feature related to the variable, e.g. "CUDA for - Nvidia GPUs". - enabled_by_default: boolean for default behavior. - question: optional string for how to ask for user input. - yes_reply: optional string for reply when feature is enabled. - no_reply: optional string for reply when feature is disabled. - bazel_config_name: adding config to .bazelrc instead of action_env. - """ - var = int( - get_var(environ_cp, var_name, query_item, enabled_by_default, question, - yes_reply, no_reply)) - - if not bazel_config_name: - write_action_env_to_bazelrc(var_name, var) - elif var: - write_to_bazelrc('build --config=%s' % bazel_config_name) - environ_cp[var_name] = str(var) - - -def convert_version_to_int(version): - """Convert a version number to a integer that can be used to compare. - - Version strings of the form X.YZ and X.Y.Z-xxxxx are supported. The - 'xxxxx' part, for instance 'homebrew' on OS/X, is ignored. - - Args: - version: a version to be converted - - Returns: - An integer if converted successfully, otherwise return None. - """ - version = version.split('-')[0] - version_segments = version.split('.') - # Treat "0.24" as "0.24.0" - if len(version_segments) == 2: - version_segments.append('0') - for seg in version_segments: - if not seg.isdigit(): - return None - - version_str = ''.join(['%03d' % int(seg) for seg in version_segments]) - return int(version_str) - - -def retrieve_bazel_version(): - """Retrieve installed bazel version (or bazelisk). - - Returns: - The bazel version detected. - """ - bazel_executable = which('bazel') - if bazel_executable is None: - bazel_executable = which('bazelisk') - if bazel_executable is None: - print('Cannot find bazel. Please install bazel/bazelisk.') - sys.exit(1) - - stderr = open(os.devnull, 'wb') - curr_version = run_shell([bazel_executable, '--version'], - allow_non_zero=True, - stderr=stderr) - if curr_version.startswith('bazel '): - curr_version = curr_version.split('bazel ')[1] - - curr_version_int = convert_version_to_int(curr_version) - - # Check if current bazel version can be detected properly. - if not curr_version_int: - print('WARNING: current bazel installation is not a release version.') - return curr_version - - print('You have bazel %s installed.' % curr_version) - return curr_version - - -def set_cc_opt_flags(environ_cp): - """Set up architecture-dependent optimization flags. - - Also append CC optimization flags to bazel.rc.. - - Args: - environ_cp: copy of the os.environ. - """ - if is_ppc64le(): - # gcc on ppc64le does not support -march, use mcpu instead - default_cc_opt_flags = '-mcpu=native' - elif is_windows(): - default_cc_opt_flags = '/arch:AVX' - else: - # On all other platforms, no longer use `-march=native` as this can result - # in instructions that are too modern being generated. Users that want - # maximum performance should compile TF in their environment and can pass - # `-march=native` there. - # See https://github.com/tensorflow/tensorflow/issues/45744 and duplicates - default_cc_opt_flags = '-Wno-sign-compare' - question = ('Please specify optimization flags to use during compilation when' - ' bazel option "--config=opt" is specified [Default is %s]: ' - ) % default_cc_opt_flags - cc_opt_flags = get_from_env_or_user_or_default(environ_cp, 'CC_OPT_FLAGS', - question, default_cc_opt_flags) - for opt in cc_opt_flags.split(): - write_to_bazelrc('build:opt --copt=%s' % opt) - write_to_bazelrc('build:opt --host_copt=%s' % opt) - - -def set_tf_cuda_clang(environ_cp): - """set TF_CUDA_CLANG action_env. - - Args: - environ_cp: copy of the os.environ. - """ - question = 'Do you want to use clang as CUDA compiler?' - yes_reply = 'Clang will be used as CUDA compiler.' - no_reply = 'nvcc will be used as CUDA compiler.' - set_action_env_var( - environ_cp, - 'TF_CUDA_CLANG', - None, - False, - question=question, - yes_reply=yes_reply, - no_reply=no_reply, - bazel_config_name='cuda_clang') - - -def set_tf_download_clang(environ_cp): - """Set TF_DOWNLOAD_CLANG action_env.""" - question = 'Do you wish to download a fresh release of clang? (Experimental)' - yes_reply = 'Clang will be downloaded and used to compile tensorflow.' - no_reply = 'Clang will not be downloaded.' - set_action_env_var( - environ_cp, - 'TF_DOWNLOAD_CLANG', - None, - False, - question=question, - yes_reply=yes_reply, - no_reply=no_reply, - bazel_config_name='download_clang') - - -def get_from_env_or_user_or_default(environ_cp, var_name, ask_for_var, - var_default): - """Get var_name either from env, or user or default. - - If var_name has been set as environment variable, use the preset value, else - ask for user input. If no input is provided, the default is used. - - Args: - environ_cp: copy of the os.environ. - var_name: string for name of environment variable, e.g. "TF_NEED_CUDA". - ask_for_var: string for how to ask for user input. - var_default: default value string. - - Returns: - string value for var_name - """ - var = environ_cp.get(var_name) - if not var: - var = get_input(ask_for_var) - print('\n') - if not var: - var = var_default - return var - - -def set_clang_cuda_compiler_path(environ_cp): - """Set CLANG_CUDA_COMPILER_PATH.""" - default_clang_path = '/usr/lib/llvm-16/bin/clang' - if not os.path.exists(default_clang_path): - default_clang_path = which('clang') or '' - - clang_cuda_compiler_path = prompt_loop_or_load_from_env( - environ_cp, - var_name='CLANG_CUDA_COMPILER_PATH', - var_default=default_clang_path, - ask_for_var='Please specify clang path that to be used as host compiler.', - check_success=os.path.exists, - resolve_symlinks=True, - error_msg='Invalid clang path. %s cannot be found.', - ) - - # Set CLANG_CUDA_COMPILER_PATH - environ_cp['CLANG_CUDA_COMPILER_PATH'] = clang_cuda_compiler_path - write_action_env_to_bazelrc('CLANG_CUDA_COMPILER_PATH', - clang_cuda_compiler_path) - return clang_cuda_compiler_path - - -def prompt_loop_or_load_from_env(environ_cp, - var_name, - var_default, - ask_for_var, - check_success, - error_msg, - suppress_default_error=False, - resolve_symlinks=False, - n_ask_attempts=_DEFAULT_PROMPT_ASK_ATTEMPTS): - """Loop over user prompts for an ENV param until receiving a valid response. - - For the env param var_name, read from the environment or verify user input - until receiving valid input. When done, set var_name in the environ_cp to its - new value. - - Args: - environ_cp: (Dict) copy of the os.environ. - var_name: (String) string for name of environment variable, e.g. "TF_MYVAR". - var_default: (String) default value string. - ask_for_var: (String) string for how to ask for user input. - check_success: (Function) function that takes one argument and returns a - boolean. Should return True if the value provided is considered valid. May - contain a complex error message if error_msg does not provide enough - information. In that case, set suppress_default_error to True. - error_msg: (String) String with one and only one '%s'. Formatted with each - invalid response upon check_success(input) failure. - suppress_default_error: (Bool) Suppress the above error message in favor of - one from the check_success function. - resolve_symlinks: (Bool) Translate symbolic links into the real filepath. - n_ask_attempts: (Integer) Number of times to query for valid input before - raising an error and quitting. - - Returns: - [String] The value of var_name after querying for input. - - Raises: - UserInputError: if a query has been attempted n_ask_attempts times without - success, assume that the user has made a scripting error, and will - continue to provide invalid input. Raise the error to avoid infinitely - looping. - """ - default = environ_cp.get(var_name) or var_default - full_query = '%s [Default is %s]: ' % ( - ask_for_var, - default, - ) - - for _ in range(n_ask_attempts): - val = get_from_env_or_user_or_default(environ_cp, var_name, full_query, - default) - if check_success(val): - break - if not suppress_default_error: - print(error_msg % val) - environ_cp[var_name] = '' - else: - raise UserInputError('Invalid %s setting was provided %d times in a row. ' - 'Assuming to be a scripting mistake.' % - (var_name, n_ask_attempts)) - - if resolve_symlinks: - val = os.path.realpath(val) - environ_cp[var_name] = val - return val - - -def set_gcc_host_compiler_path(environ_cp): - """Set GCC_HOST_COMPILER_PATH.""" - default_gcc_host_compiler_path = which('gcc') or '' - cuda_bin_symlink = '%s/bin/gcc' % environ_cp.get('CUDA_TOOLKIT_PATH') - - if os.path.islink(cuda_bin_symlink): - # os.readlink is only available in linux - default_gcc_host_compiler_path = os.path.realpath(cuda_bin_symlink) - - gcc_host_compiler_path = prompt_loop_or_load_from_env( - environ_cp, - var_name='GCC_HOST_COMPILER_PATH', - var_default=default_gcc_host_compiler_path, - ask_for_var='Please specify which gcc should be used by nvcc as the host ' - 'compiler.', - check_success=os.path.exists, - resolve_symlinks=True, - error_msg='Invalid gcc path. %s cannot be found.', - ) - - write_action_env_to_bazelrc('GCC_HOST_COMPILER_PATH', gcc_host_compiler_path) - - -def choose_compiler(environ_cp): - question = 'Do you want to use Clang to build TensorFlow?' - yes_reply = 'Clang will be used to compile TensorFlow.' - no_reply = 'GCC will be used to compile TensorFlow.' - var = int( - get_var( - environ_cp, 'TF_NEED_CLANG', None, True, question, yes_reply, no_reply - ) - ) - return var - - -def set_clang_compiler_path(environ_cp): - """Set CLANG_COMPILER_PATH and environment variables. - - Loop over user prompts for clang path until receiving a valid response. - Default is used if no input is given. Set CLANG_COMPILER_PATH and write - environment variables CC and BAZEL_COMPILER to .bazelrc. - - Args: - environ_cp: (Dict) copy of the os.environ. - - Returns: - string value for clang_compiler_path. - """ - # Default path if clang-16 is installed by using apt-get install - default_clang_path = '/usr/lib/llvm-16/bin/clang' - if not os.path.exists(default_clang_path): - default_clang_path = which('clang') or '' - - clang_compiler_path = prompt_loop_or_load_from_env( - environ_cp, - var_name='CLANG_COMPILER_PATH', - var_default=default_clang_path, - ask_for_var='Please specify the path to clang executable.', - check_success=os.path.exists, - resolve_symlinks=True, - error_msg=( - 'Invalid clang path. %s cannot be found. Note that TensorFlow now' - ' requires clang to compile. You may override this behavior by' - ' setting TF_NEED_CLANG=0' - ), - ) - - write_action_env_to_bazelrc('CLANG_COMPILER_PATH', clang_compiler_path) - write_to_bazelrc('build --repo_env=CC=%s' % clang_compiler_path) - write_to_bazelrc('build --repo_env=BAZEL_COMPILER=%s' % clang_compiler_path) - write_to_bazelrc('build --linkopt="-fuse-ld=lld"') - write_to_bazelrc('build --linkopt="-lm"') - - return clang_compiler_path - - -def retrieve_clang_version(clang_executable): - """Retrieve installed clang version. - - Args: - clang_executable: (String) path to clang executable - - Returns: - The clang version detected. - """ - stderr = open(os.devnull, 'wb') - curr_version = run_shell( - [clang_executable, '--version'], allow_non_zero=True, stderr=stderr - ) - - curr_version_split = curr_version.lower().split('clang version ') - if len(curr_version_split) > 1: - curr_version = curr_version_split[1].split()[0] - - curr_version_int = convert_version_to_int(curr_version) - # Check if current clang version can be detected properly. - if not curr_version_int: - print('WARNING: current clang installation is not a release version.\n') - return None - - print('You have Clang %s installed.\n' % curr_version) - return curr_version - - -# Disable clang extension that rejects type definitions within offsetof. -# This was added in clang-16 by https://reviews.llvm.org/D133574. -# Can be removed once upb is updated, since a type definition is used within -# offset of in the current version of ubp. See -# https://github.com/protocolbuffers/upb/blob/9effcbcb27f0a665f9f345030188c0b291e32482/upb/upb.c#L183. -def disable_clang16_offsetof_extension(clang_version): - if int(clang_version.split('.')[0]) == 16: - write_to_bazelrc('build --copt=-Wno-gnu-offsetof-extensions') - - -def set_tf_cuda_paths(environ_cp): - """Set TF_CUDA_PATHS.""" - ask_cuda_paths = ( - 'Please specify the comma-separated list of base paths to look for CUDA ' - 'libraries and headers. [Leave empty to use the default]: ') - tf_cuda_paths = get_from_env_or_user_or_default(environ_cp, 'TF_CUDA_PATHS', - ask_cuda_paths, '') - if tf_cuda_paths: - environ_cp['TF_CUDA_PATHS'] = tf_cuda_paths - - -def set_tf_cuda_version(environ_cp): - """Set TF_CUDA_VERSION.""" - ask_cuda_version = ( - 'Please specify the CUDA SDK version you want to use. ' - '[Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION - tf_cuda_version = get_from_env_or_user_or_default(environ_cp, - 'TF_CUDA_VERSION', - ask_cuda_version, - _DEFAULT_CUDA_VERSION) - environ_cp['TF_CUDA_VERSION'] = tf_cuda_version - - -def set_tf_cudnn_version(environ_cp): - """Set TF_CUDNN_VERSION.""" - ask_cudnn_version = ( - 'Please specify the cuDNN version you want to use. ' - '[Leave empty to default to cuDNN %s]: ') % _DEFAULT_CUDNN_VERSION - tf_cudnn_version = get_from_env_or_user_or_default(environ_cp, - 'TF_CUDNN_VERSION', - ask_cudnn_version, - _DEFAULT_CUDNN_VERSION) - environ_cp['TF_CUDNN_VERSION'] = tf_cudnn_version - - -def set_tf_nccl_version(environ_cp): - """Set TF_NCCL_VERSION.""" - if not is_linux(): - raise ValueError('Currently NCCL is only supported on Linux platform.') - - if 'TF_NCCL_VERSION' in environ_cp: - return - - ask_nccl_version = ( - 'Please specify the locally installed NCCL version you want to use. ' - '[Leave empty to use http://github.com/nvidia/nccl]: ') - tf_nccl_version = get_from_env_or_user_or_default(environ_cp, - 'TF_NCCL_VERSION', - ask_nccl_version, '') - environ_cp['TF_NCCL_VERSION'] = tf_nccl_version - - -def get_native_cuda_compute_capabilities(environ_cp): - """Get native cuda compute capabilities. - - Args: - environ_cp: copy of the os.environ. - - Returns: - string of native cuda compute capabilities, separated by comma. - """ - device_query_bin = os.path.join( - environ_cp.get('CUDA_TOOLKIT_PATH'), 'extras/demo_suite/deviceQuery') - if os.path.isfile(device_query_bin) and os.access(device_query_bin, os.X_OK): - try: - output = run_shell(device_query_bin).split('\n') - pattern = re.compile('[0-9]*\\.[0-9]*') - output = [pattern.search(x) for x in output if 'Capability' in x] - output = ','.join(x.group() for x in output if x is not None) - except subprocess.CalledProcessError: - output = '' - else: - output = '' - return output - - -def set_tf_cuda_compute_capabilities(environ_cp): - """Set TF_CUDA_COMPUTE_CAPABILITIES.""" - while True: - native_cuda_compute_capabilities = get_native_cuda_compute_capabilities( - environ_cp) - if not native_cuda_compute_capabilities: - default_cuda_compute_capabilities = _DEFAULT_CUDA_COMPUTE_CAPABILITIES - else: - default_cuda_compute_capabilities = native_cuda_compute_capabilities - - ask_cuda_compute_capabilities = ( - 'Please specify a list of comma-separated CUDA compute capabilities ' - 'you want to build with.\nYou can find the compute capability of your ' - 'device at: https://developer.nvidia.com/cuda-gpus. Each capability ' - 'can be specified as "x.y" or "compute_xy" to include both virtual and' - ' binary GPU code, or as "sm_xy" to only include the binary ' - 'code.\nPlease note that each additional compute capability ' - 'significantly increases your build time and binary size, and that ' - 'XLA only supports compute capabilities >= 5.2 [Default is: ' - '%s]: ' % default_cuda_compute_capabilities - ) - tf_cuda_compute_capabilities = get_from_env_or_user_or_default( - environ_cp, 'TF_CUDA_COMPUTE_CAPABILITIES', - ask_cuda_compute_capabilities, default_cuda_compute_capabilities) - # Check whether all capabilities from the input is valid - all_valid = True - # Remove all whitespace characters before splitting the string - # that users may insert by accident, as this will result in error - tf_cuda_compute_capabilities = ''.join(tf_cuda_compute_capabilities.split()) - for compute_capability in tf_cuda_compute_capabilities.split(','): - m = re.match('[0-9]+.[0-9]+', compute_capability) - if not m: - # We now support sm_52,compute_70. - sm_compute_match = re.match('(sm|compute)_?([0-9]+[0-9]+)', - compute_capability) - if not sm_compute_match: - print('Invalid compute capability: %s' % compute_capability) - all_valid = False - else: - ver = int(sm_compute_match.group(2)) - if ver < 52: - print( - 'ERROR: XLA only supports small CUDA compute' - ' capabilities of sm_52 and higher. Please re-specify the list' - ' of compute capabilities excluding version %s.' % ver - ) - all_valid = False - else: - ver = float(m.group(0)) - if ver < 5.2: - print( - 'ERROR: XLA only supports CUDA compute capabilities 5.2 ' - 'and higher. Please re-specify the list of compute ' - 'capabilities excluding version %s.' % ver - ) - all_valid = False - - if all_valid: - break - - # Reset and Retry - environ_cp['TF_CUDA_COMPUTE_CAPABILITIES'] = '' - - # Set TF_CUDA_COMPUTE_CAPABILITIES - environ_cp['TF_CUDA_COMPUTE_CAPABILITIES'] = tf_cuda_compute_capabilities - write_action_env_to_bazelrc('TF_CUDA_COMPUTE_CAPABILITIES', - tf_cuda_compute_capabilities) - - -def set_other_cuda_vars(environ_cp): - """Set other CUDA related variables.""" - # If CUDA is enabled, always use GPU during build and test. - if environ_cp.get('TF_CUDA_CLANG') == '1': - write_to_bazelrc('build --config=cuda_clang') - else: - write_to_bazelrc('build --config=cuda') - - -def system_specific_test_config(environ_cp): - """Add default build and test flags required for TF tests to bazelrc.""" - write_to_bazelrc('test --flaky_test_attempts=3') - write_to_bazelrc('test --test_size_filters=small,medium') - - # Each instance of --test_tag_filters or --build_tag_filters overrides all - # previous instances, so we need to build up a complete list and write a - # single list of filters for the .bazelrc file. - - # Filters to use with both --test_tag_filters and --build_tag_filters - test_and_build_filters = ['-benchmark-test', '-no_oss', '-oss_excluded'] - # Additional filters for --test_tag_filters beyond those in - # test_and_build_filters - test_only_filters = ['-oss_serial'] - if is_windows(): - test_and_build_filters += ['-no_windows', '-windows_excluded'] - if ((environ_cp.get('TF_NEED_CUDA', None) == '1') or - (environ_cp.get('TF_NEED_ROCM', None) == '1')): - test_and_build_filters += ['-no_windows_gpu', '-no_gpu'] - else: - test_and_build_filters.append('-gpu') - elif is_macos(): - test_and_build_filters += ['-gpu', '-nomac', '-no_mac', '-mac_excluded'] - elif is_linux(): - if ((environ_cp.get('TF_NEED_CUDA', None) == '1') or - (environ_cp.get('TF_NEED_ROCM', None) == '1')): - test_and_build_filters.append('-no_gpu') - write_to_bazelrc('test --test_env=LD_LIBRARY_PATH') - else: - test_and_build_filters.append('-gpu') - if environ_cp.get('TF_NEED_ROCM', None) == '1': - test_and_build_filters.append('-no_rocm') - - write_to_bazelrc('test --test_tag_filters=%s' % - ','.join(test_and_build_filters + test_only_filters)) - write_to_bazelrc('test --build_tag_filters=%s' % - ','.join(test_and_build_filters)) - write_to_bazelrc('build --test_tag_filters=%s' % - ','.join(test_and_build_filters + test_only_filters)) - write_to_bazelrc('build --build_tag_filters=%s' % - ','.join(test_and_build_filters)) - - # Disable tests with "v1only" tag in "v2" Bazel config, but not in "v1" config - write_to_bazelrc('test:v1 --test_tag_filters=%s' % - ','.join(test_and_build_filters + test_only_filters)) - write_to_bazelrc('test:v1 --build_tag_filters=%s' % - ','.join(test_and_build_filters)) - write_to_bazelrc( - 'test:v2 --test_tag_filters=%s' % - ','.join(test_and_build_filters + test_only_filters + ['-v1only'])) - write_to_bazelrc('test:v2 --build_tag_filters=%s' % - ','.join(test_and_build_filters + ['-v1only'])) - - -def set_system_libs_flag(environ_cp): - syslibs = environ_cp.get('TF_SYSTEM_LIBS', '') - if syslibs: - if ',' in syslibs: - syslibs = ','.join(sorted(syslibs.split(','))) - else: - syslibs = ','.join(sorted(syslibs.split())) - write_action_env_to_bazelrc('TF_SYSTEM_LIBS', syslibs) - - for varname in ('PREFIX', 'LIBDIR', 'INCLUDEDIR', 'PROTOBUF_INCLUDE_PATH'): - if varname in environ_cp: - write_to_bazelrc('build --define=%s=%s' % (varname, environ_cp[varname])) - - -def set_windows_build_flags(): - """Set Windows specific build options.""" - - # First available in VS 16.4. Speeds up Windows compile times by a lot. See - # https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion - # pylint: disable=line-too-long - write_to_bazelrc( - 'build --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions' - ) - - -def config_info_line(name, help_text): - """Helper function to print formatted help text for Bazel config options.""" - print('\t--config=%-12s\t# %s' % (name, help_text)) - - -def validate_cuda_config(environ_cp): - """Run find_cuda_config.py and return cuda_toolkit_path, or None.""" - - def maybe_encode_env(env): - """Encodes unicode in env to str on Windows python 2.x.""" - if not is_windows() or sys.version_info[0] != 2: - return env - for k, v in env.items(): - if isinstance(k, unicode): - k = k.encode('ascii') - if isinstance(v, unicode): - v = v.encode('ascii') - env[k] = v - return env - - cuda_libraries = ['cuda', 'cudnn'] - if is_linux(): - if environ_cp.get('TF_NCCL_VERSION', None): - cuda_libraries.append('nccl') - - find_cuda_script = os.path.join( - pathlib.Path(__file__).parent.resolve(), - 'third_party/tsl/third_party/gpus/find_cuda_config.py', - ) - if not os.path.isfile(find_cuda_script): - raise FileNotFoundError( - "Can't find 'find_cuda_config.py' script inside working directory," - f' expected in {find_cuda_script}' - ) - proc = subprocess.Popen( - [environ_cp['PYTHON_BIN_PATH'], find_cuda_script] + cuda_libraries, - stdout=subprocess.PIPE, - env=maybe_encode_env(environ_cp), - ) - - if proc.wait(): - # Errors from find_cuda_config.py were sent to stderr. - print('Asking for detailed CUDA configuration...\n') - return False - - config = dict( - tuple(line.decode('ascii').rstrip().split(': ')) for line in proc.stdout) - - print('Found CUDA %s in:' % config['cuda_version']) - print(' %s' % config['cuda_library_dir']) - print(' %s' % config['cuda_include_dir']) - - print('Found cuDNN %s in:' % config['cudnn_version']) - print(' %s' % config['cudnn_library_dir']) - print(' %s' % config['cudnn_include_dir']) - - if config.get('nccl_version', None): - print('Found NCCL %s in:' % config['nccl_version']) - print(' %s' % config['nccl_library_dir']) - print(' %s' % config['nccl_include_dir']) - - print('\n') - - environ_cp['CUDA_TOOLKIT_PATH'] = config['cuda_toolkit_path'] - return True - - -def get_gcc_compiler(environ_cp): - gcc_env = environ_cp.get('CXX') or environ_cp.get('CC') or which('gcc') - if gcc_env is not None: - gcc_version = run_shell([gcc_env, '--version']).split() - if gcc_version[0] in ('gcc', 'g++'): - return gcc_env - return None - - -def main(): - global _TF_WORKSPACE_ROOT - global _TF_BAZELRC - global _TF_CURRENT_BAZEL_VERSION - - parser = argparse.ArgumentParser() - parser.add_argument( - '--workspace', - type=str, - default=os.path.abspath(os.path.dirname(__file__)), - help='The absolute path to your active Bazel workspace.') - args = parser.parse_args() - - _TF_WORKSPACE_ROOT = args.workspace - _TF_BAZELRC = os.path.join(_TF_WORKSPACE_ROOT, _TF_BAZELRC_FILENAME) - - # Make a copy of os.environ to be clear when functions and getting and setting - # environment variables. - environ_cp = dict(os.environ) - - try: - current_bazel_version = retrieve_bazel_version() - except subprocess.CalledProcessError as e: - print('Error retrieving bazel version: ', e.output.decode('UTF-8').strip()) - raise e - - _TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version) - - reset_tf_configure_bazelrc() - - setup_python(environ_cp) - - if is_windows(): - environ_cp['TF_NEED_OPENCL'] = '0' - environ_cp['TF_CUDA_CLANG'] = '0' - # TODO(ibiryukov): Investigate using clang as a cpu or cuda compiler on - # Windows. - environ_cp['TF_DOWNLOAD_CLANG'] = '0' - environ_cp['TF_NEED_MPI'] = '0' - - if is_ppc64le(): - # Enable MMA Dynamic Dispatch support if 'gcc' and if linker >= 2.35 - gcc_env = get_gcc_compiler(environ_cp) - if gcc_env is not None: - - # Use gold linker if 'gcc' and if 'ppc64le' - write_to_bazelrc('build --linkopt="-fuse-ld=gold"') - - # Get the linker version - ld_version = run_shell([gcc_env, '-Wl,-version']).split() - - ld_version_int = convert_version_to_int(ld_version[3]) - if ld_version_int is None: - ld_version_int = convert_version_to_int(ld_version[4]) - - # Enable if 'ld' version >= 2.35 - if ld_version_int >= 2035000: - write_to_bazelrc( - 'build --copt="-DEIGEN_ALTIVEC_ENABLE_MMA_DYNAMIC_DISPATCH=1"') - - set_action_env_var( - environ_cp, 'TF_NEED_ROCM', 'ROCm', False, bazel_config_name='rocm') - if (environ_cp.get('TF_NEED_ROCM') == '1' and - 'LD_LIBRARY_PATH' in environ_cp and - environ_cp.get('LD_LIBRARY_PATH') != '1'): - write_action_env_to_bazelrc('LD_LIBRARY_PATH', - environ_cp.get('LD_LIBRARY_PATH')) - - if (environ_cp.get('TF_NEED_ROCM') == '1' and environ_cp.get('ROCM_PATH')): - write_action_env_to_bazelrc('ROCM_PATH', environ_cp.get('ROCM_PATH')) - - if (environ_cp.get('TF_NEED_ROCM') == '1' and environ_cp.get('HIP_PLATFORM')): - write_action_env_to_bazelrc('HIP_PLATFORM', environ_cp.get('HIP_PLATFORM')) - - if is_windows(): - print('\nWARNING: Cannot build with CUDA support on Windows.\n' - 'Starting in TF 2.11, CUDA build is not supported for Windows. ' - 'For using XLA GPU on Windows, you will need to build/install ' - 'XLA in WSL2.\n') - environ_cp['TF_NEED_CUDA'] = '0' - else: - environ_cp['TF_NEED_CUDA'] = str( - int(get_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False))) - if (environ_cp.get('TF_NEED_CUDA') == '1' and - 'TF_CUDA_CONFIG_REPO' not in environ_cp): - - environ_save = dict(environ_cp) - for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): - - if validate_cuda_config(environ_cp): - cuda_env_names = [ - 'TF_CUDA_VERSION', - 'TF_CUBLAS_VERSION', - 'TF_CUDNN_VERSION', - 'TF_NCCL_VERSION', - 'TF_CUDA_PATHS', - # Items below are for backwards compatibility when not using - # TF_CUDA_PATHS. - 'CUDA_TOOLKIT_PATH', - 'CUDNN_INSTALL_PATH', - 'NCCL_INSTALL_PATH', - 'NCCL_HDR_PATH', - ] - # Note: set_action_env_var above already writes to bazelrc. - for name in cuda_env_names: - if name in environ_cp: - write_action_env_to_bazelrc(name, environ_cp[name]) - break - - # Restore settings changed below if CUDA config could not be validated. - environ_cp = dict(environ_save) - - set_tf_cuda_version(environ_cp) - set_tf_cudnn_version(environ_cp) - if is_linux(): - set_tf_nccl_version(environ_cp) - - set_tf_cuda_paths(environ_cp) - - else: - raise UserInputError( - 'Invalid CUDA setting were provided %d ' - 'times in a row. Assuming to be a scripting mistake.' - % _DEFAULT_PROMPT_ASK_ATTEMPTS - ) - - set_tf_cuda_compute_capabilities(environ_cp) - if 'LD_LIBRARY_PATH' in environ_cp and environ_cp.get( - 'LD_LIBRARY_PATH') != '1': - write_action_env_to_bazelrc('LD_LIBRARY_PATH', - environ_cp.get('LD_LIBRARY_PATH')) - - set_tf_cuda_clang(environ_cp) - if environ_cp.get('TF_CUDA_CLANG') == '1': - # Set up which clang we should use as the cuda / host compiler. - clang_cuda_compiler_path = set_clang_cuda_compiler_path(environ_cp) - clang_version = retrieve_clang_version(clang_cuda_compiler_path) - disable_clang16_offsetof_extension(clang_version) - else: - # Set up which gcc nvcc should use as the host compiler - # No need to set this on Windows - if not is_windows(): - set_gcc_host_compiler_path(environ_cp) - set_other_cuda_vars(environ_cp) - else: - # CUDA not required. Ask whether we should use clang for the CPU build. - if is_linux(): - environ_cp['TF_NEED_CLANG'] = str(choose_compiler(environ_cp)) - if environ_cp.get('TF_NEED_CLANG') == '1': - clang_compiler_path = set_clang_compiler_path(environ_cp) - clang_version = retrieve_clang_version(clang_compiler_path) - disable_clang16_offsetof_extension(clang_version) - - # ROCm / CUDA are mutually exclusive. - # At most 1 GPU platform can be configured. - gpu_platform_count = 0 - if environ_cp.get('TF_NEED_ROCM') == '1': - gpu_platform_count += 1 - if environ_cp.get('TF_NEED_CUDA') == '1': - gpu_platform_count += 1 - if gpu_platform_count >= 2: - raise UserInputError('CUDA / ROCm are mututally exclusive. ' - 'At most 1 GPU platform can be configured.') - - # Disable NCCL if XLA is configured for CPU - if gpu_platform_count == 0: - write_to_bazelrc('build --config=nonccl') - - set_cc_opt_flags(environ_cp) - set_system_libs_flag(environ_cp) - if is_windows(): - set_windows_build_flags() - - system_specific_test_config(environ_cp) - - print('Preconfigured Bazel build configs. You can use any of the below by ' - 'adding "--config=<>" to your build command. See .bazelrc for more ' - 'details.') - config_info_line('mkl', 'Build with MKL support.') - config_info_line( - 'mkl_aarch64', - 'Build with oneDNN and Compute Library for the Arm Architecture (ACL).') - config_info_line('monolithic', 'Config for mostly static monolithic build.') - config_info_line('numa', 'Build with NUMA support.') - config_info_line( - 'dynamic_kernels', - '(Experimental) Build kernels into separate shared objects.') - config_info_line('v1', 'Build with TensorFlow 1 API instead of TF 2 API.') - - print('Preconfigured Bazel build configs to DISABLE default on features:') - config_info_line('nogcp', 'Disable GCP support.') - - if gpu_platform_count == 1: - config_info_line('nonccl', 'Disable NVIDIA NCCL support.') - - -if __name__ == '__main__': - main() diff --git a/third_party/xla/docs/_toc.yaml b/third_party/xla/docs/_toc.yaml index df0bf39c6c84a7..50a24a1a6607c8 100644 --- a/third_party/xla/docs/_toc.yaml +++ b/third_party/xla/docs/_toc.yaml @@ -1,28 +1,51 @@ toc: - heading: XLA developer guide -- title: Overview - path: /xla -- title: Aliasing - path: /xla/aliasing -- title: Architecture - path: /xla/architecture -- title: Broadcasting - path: /xla/broadcasting -- title: Build from source - path: /xla/build_from_source -- title: Code reviews - path: /xla/code_reviews -- title: Copybara quirks - path: /xla/copybara -- title: Custom calls - path: /xla/custom_call -- title: Developer guide - path: /xla/developer_guide -- title: Developing a new backend - path: /xla/developing_new_backend -- title: Operation semantics - path: /xla/operation_semantics -- title: Shapes and layout - path: /xla/shapes -- title: Tiled layout - path: /xla/tiled_layout +- title: Getting started + section: + - title: Overview + path: /xla + - title: XLA architecture + path: /xla/architecture + - title: Operation semantics + path: /xla/operation_semantics +- title: Developer details + section: + - title: Broadcasting + path: /xla/broadcasting + - title: Shapes and layout + path: /xla/shapes + - title: Aliasing + path: /xla/aliasing + - title: Tiled layout + path: /xla/tiled_layout + - title: Writing custom calls + path: /xla/custom_call + - title: Persisted autotuning + path: /xla/persisted_autotuning + - title: Copybara quirks + path: /xla/copybara + - title: XLA Tooling + path: /xla/tools + - title: Using LSP autocompletion + path: /xla/lsp +- title: Contributing + section: + - title: Develop a new backend for XLA + path: /xla/developing_new_backend + - title: Developer guide + path: /xla/developer_guide + - title: Code reviews + path: /xla/code_reviews + - title: Build from source + path: /xla/build_from_source +- title: Using XLA in TensorFlow + section: + - title: Using XLA in TensorFlow + path: /xla/tf2xla + - title: Use tfcompile + path: /xla/tf2xla/tfcompile + - title: Autoclustering tutorial + path: /xla/tf2xla/tutorials/autoclustering_xla + - title: Use XLA with tf.function + path: /xla/tf2xla/tutorials/jit_compile + diff --git a/third_party/xla/docs/async_ops.md b/third_party/xla/docs/async_ops.md index 889272eecc4411..a2f7c1dbc7aff7 100644 --- a/third_party/xla/docs/async_ops.md +++ b/third_party/xla/docs/async_ops.md @@ -24,22 +24,13 @@ instructions. %async-start = (f32[64], f32[32], s32[]) async-start(f32[64] %operand), calls=%async_op -%async-done = f32[32] async-done((f32[64], f32[32], s32[]) %async-start), - calls=%async_op +%async-done = f32[32] async-done((f32[64], f32[32], s32[]) %async-start) ``` In the representation above, only `async-start` has a called computation since it is trivial to find what the `async-done` does by following its operand to find the corresponding `async-start` to find the called computation. -Today both `async-start` and `async-done` have a called computation attribute, -but long term we plan to keep it only for `async-start`, since it is trivial -to find what the `async-done` does by following its operand to find the -corresponding `async-start` to find the called computation. - -> [!NOTE] -> Tracked as b/302594825 internally. - Also note that the first element in the output tuple of `async-start` aliases with the operand, so the buffer stays alive until at least the async-done instruction. @@ -102,10 +93,8 @@ to the following and the two can be parsed to the same representation: (f32[64], f32[32], s32[]) %op-start), op_specific_attr=”foo” %op-update1 = (f32[64], f32[32], s32[]) op-update( - (f32[64], f32[32], s32[]) %op-update0), - op_specific_attr=”foo” -%op-done = f32[32] op-done((f32[64], f32[32], s32[]) %op-update1), - op_specific_attr=”foo” + (f32[64], f32[32], s32[]) %op-update0) +%op-done = f32[32] op-done((f32[64], f32[32], s32[]) %op-update1) ``` diff --git a/third_party/xla/docs/build_from_source.md b/third_party/xla/docs/build_from_source.md index ca2eb50234bf32..91ef1e49608818 100644 --- a/third_party/xla/docs/build_from_source.md +++ b/third_party/xla/docs/build_from_source.md @@ -10,23 +10,14 @@ If you did not clone the XLA repository or install Bazel, please check out the ### Configure XLA builds are configured by the `.bazelrc` file in the repository's root -directory. The `./configure` or `./configure.py` scripts can be used to adjust +directory. The `./configure.py` script can be used to adjust common settings. -If you need to change the configuration, run the `./configure` script from the -repository's root directory. This script will prompt you for the location of XLA -dependencies and asks for additional build configuration options (compiler +If you need to change the configuration, run the `./configure.py` script from the +repository's root directory. This script has flags for the location of XLA +dependencies and additional build configuration options (compiler flags, for example). Refer to the *Sample session* section for details. -``` -./configure -``` - -There is also a python version of this script, `./configure.py`. If using a -virtual environment, `python configure.py` prioritizes paths within the -environment, whereas `./configure` prioritizes paths outside the environment. In -both cases you can change the default. - ### CPU support We recommend using a suitable docker container to build/test XLA, such as @@ -39,19 +30,18 @@ docker run --name xla -w /xla -it -d --rm -v $PWD:/xla tensorflow/build:latest-p Using a docker container you can build XLA with CPU support using the following commands: ``` -docker exec xla ./configure +docker exec xla ./configure.py --backend=CPU docker exec xla bazel build //xla/... --spawn_strategy=sandboxed --test_output=all ``` -If you want to build XLA targets with CPU support without Docker you need to install gcc-10: +If you want to build XLA targets with CPU support without Docker you need to install clang. XLA currently builds on CI with clang-17, but earlier versions should also work: ``` -apt install gcc-10 g++-10 +apt install clang ``` Then configure and build targets using the following commands: -``` -yes '' | GCC_HOST_COMPILER_PATH=/usr/bin/gcc-10 CC=/usr/bin/gcc-10 TF_NEED_ROCM=0 TF_NEED_CUDA=0 TF_CUDA_CLANG=0 ./configure +``` ./configure.py --backend=CPU bazel build --test_output=all --spawn_strategy=sandboxed //xla/... ``` @@ -69,7 +59,7 @@ docker run --name xla_gpu -w /xla -it -d --rm -v $PWD:/xla tensorflow/build:late To build XLA with GPU support use the following command: ``` -docker exec -e TF_NEED_CUDA=1 xla_gpu ./configure +docker exec xla_gpu ./configure.py --backend=CUDA docker exec xla_gpu bazel build --test_output=all --spawn_strategy=sandboxed //xla/... ``` @@ -81,7 +71,7 @@ install the following additional dependencies: Then configure and build targets using the following commands: ``` -yes '' | GCC_HOST_COMPILER_PATH=/usr/bin/gcc-10 CC=/usr/bin/gcc-10 TF_NEED_ROCM=0 TF_NEED_CUDA=1 TF_CUDA_CLANG=0 ./configure +./configure.py --backend=CUDA bazel build --test_output=all --spawn_strategy=sandboxed //xla/... ``` diff --git a/third_party/xla/docs/custom_call.md b/third_party/xla/docs/custom_call.md index bd2bff418b1768..84633d697daa20 100644 --- a/third_party/xla/docs/custom_call.md +++ b/third_party/xla/docs/custom_call.md @@ -14,6 +14,15 @@ program. > to change it capriciously, but it may change. Some possible future changes are > described below. +> **Caution** The HLO-visible names of functions registered with the custom-call +> macros API do not respect C++ namespaces. As a result, accidental collisions +> from functions registered by different libraries are entirely possible! The +> API will reject such duplicate registrations, but to avoid issues in large +> projects the safest option is to either fully namespace-qualify all references +> to the functions in both the `XLA_REGISTER_CUSTOM_CALL` registration macros +> and custom call target references or to use C-style namespacing directly in +> the function name. + ## Create a custom call on CPU You can create an HLO instruction that represents a custom call via XLA's client diff --git a/third_party/xla/docs/developer_guide.md b/third_party/xla/docs/developer_guide.md index fe30b24740d24d..53b3efcd8cab5c 100644 --- a/third_party/xla/docs/developer_guide.md +++ b/third_party/xla/docs/developer_guide.md @@ -53,14 +53,14 @@ the repository, and create a pull request. Build for CPU: ```sh -docker exec xla ./configure +docker exec xla ./configure.py --backend=CPU docker exec xla bazel build --test_output=all --spawn_strategy=sandboxed //xla/... ``` Build for GPU: ```sh -docker exec -e TF_NEED_CUDA=1 xla ./configure +docker exec xla ./configure.py --backend=CUDA docker exec xla bazel build --test_output=all --spawn_strategy=sandboxed //xla/... ``` diff --git a/third_party/xla/docs/index.md b/third_party/xla/docs/index.md index 0d54e0ab8eed56..76bbb657f9fe7f 100644 --- a/third_party/xla/docs/index.md +++ b/third_party/xla/docs/index.md @@ -37,18 +37,6 @@ Alibaba, Amazon Web Services, AMD, Apple, Arm, Google, Intel, Meta, and NVIDIA. ## Documentation -To learn more about XLA, check out the guides below. If you're a new XLA +To learn more about XLA, check out the links on the left. If you're a new XLA developer, you might want to start with [XLA architecture](architecture.md) and then read [Code reviews](code_reviews.md). - -- [Aliasing in XLA](aliasing.md) -- [XLA architecture](architecture.md) -- [Broadcasting](broadcasting.md) -- [Code reviews](code_reviews.md) -- [XLA custom calls](custom_call.md) -- [Developing a new backend for XLA](developing_new_backend.md) -- [Indexing Analysis](indexing.md) -- [Operation semantics](operation_semantics.md) -- [Shapes and layout](shapes.md) -- [Tiled layout](tiled_layout.md) -- [Setting up LSP with clangd](lsp.md) diff --git a/third_party/xla/docs/indexing.md b/third_party/xla/docs/indexing.md index 111242b2cc8258..5dbffb8acb5a94 100644 --- a/third_party/xla/docs/indexing.md +++ b/third_party/xla/docs/indexing.md @@ -65,13 +65,20 @@ $s_1 \in [0, 16)$. This mapping can be constructed from the attributes of HLO instructions or the mappings of unfused instructions can be composed to get indexing for a fusion. The mapping also has a domain, which specifies for what elements of the tensor -the mapping exists. $$ \begin{eqnarray} \boldsymbol{f}(\boldsymbol{d}, -\boldsymbol{s})\; &s.t.& \\ \boldsymbol{lb}_d &\leq& \boldsymbol{d} \leq -\boldsymbol{ub}_d \\ \boldsymbol{lb}_s &\leq& \boldsymbol{s} \leq -\boldsymbol{ub}_s \\ \boldsymbol{lb}_g &\leq& \boldsymbol{g}(\boldsymbol{d}, -\boldsymbol{s}) \leq \boldsymbol{ub}_g \end{eqnarray} $$ Since we want to -minimize recomputation, we need a library for symbolic computations. XLA already -depends on MLIR, so we use +the mapping exists. + +$$ +\begin{eqnarray} +\boldsymbol{f}(\boldsymbol{d}, \boldsymbol{s})\; &s.t.& \\ +\boldsymbol{lb}_d &\leq& \boldsymbol{d} \leq \boldsymbol{ub}_d \\ +\boldsymbol{lb}_s &\leq& \boldsymbol{s} \leq \boldsymbol{ub}_s \\ +\boldsymbol{lb}_g &\leq& \boldsymbol{g}(\boldsymbol{d}, + \boldsymbol{s}) \leq \boldsymbol{ub}_g +\end{eqnarray} +$$ + +Since we want to minimize recomputation, we need a library for symbolic +computations. XLA already depends on MLIR, so we use [mlir::AffineMap](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/IR/AffineMap.h) instead of writing a symbolic arithmetic library. @@ -92,17 +99,13 @@ struct Range { int64_t upper_bound; }; -struct Domain { +struct IndexingMap { + mlir::AffineMap affine_map; std::vector dimension_ranges; std::vector symbol_ranges; llvm::DenseMap expr_ranges; }; -struct IndexingMap { - mlir::AffineMap affine_map; - Domain domain; -}; - ``` `dim_ranges` encodes the **inclusive** box constraints for the dimension @@ -140,7 +143,7 @@ The input to output maps - input_i -> output: $(d_0, d_1) \mapsto (d_0, d_1)$ for $\boldsymbol{d} \in {\rm Dom}(output)$ -### Broadcast +### [Broadcast](https://openxla.org/xla/operation_semantics#broadcastindim) Broadcasting means that some of the dimensions will be removed when we map output to input and added when we map input to output. @@ -165,12 +168,12 @@ mapping. Those are the symbols that represent ranges of values. For example, in this particular case every element of input with index $d_0$ is mapped to a 10x1x30 slice of the output. -### Constant and Iota +### Constant and [Iota](https://openxla.org/xla/operation_semantics#iota) Conveniently, they do not have any input parameters, so there is nothing to compute indexing for. -### Transpose +### [Transpose](https://openxla.org/xla/operation_semantics#transpose) Indexing map for transpose is a permutation of input/output dimensions. @@ -189,7 +192,7 @@ The input to output map: - input -> output: $(d_0, d_1, d_2, d_3) \mapsto (d_0, d_2, d_3, d_1)$ for $\boldsymbol{d} \in {\rm Dom}(input)$ -### Reverse +### [Reverse](https://openxla.org/xla/operation_semantics#rev_reverse) Indexing map for reverse changes the reverted dimensions to $upper\_bound(d_i) - d_i$: @@ -209,7 +212,7 @@ The input to output map: - input -> output: $(d_0, d_1, d_2, d_3) \mapsto (d_0, -d_1 + 16, -d_2 + 8, d_3)$ for $\boldsymbol{d} \in {\rm Dom}(input)$ -### **(Variadic)Reduce** +### **[(Variadic)Reduce](https://openxla.org/xla/operation_semantics#reduce)** Variadic reduction have several inputs and several inits, the map from output to input adds the reduced dimensions. So, it behaves like an inverse to a broadcast @@ -239,7 +242,7 @@ The input to output maps: for $i, j = 0, \ldots, INPUT\\_COUNT$. -### Slice +### [Slice](https://openxla.org/xla/operation_semantics#slice) Indexing from output to input for slice results in a strided indexing map which is valid for every element of the output. Mapping from the input to output is @@ -264,7 +267,7 @@ The input to output map: **TBD**: input-to-output indexing -### Reshape +### [Reshape](https://openxla.org/xla/operation_semantics#reshape) Reshapes come in different flavors. @@ -366,7 +369,7 @@ A bitcast op can be represented as a Therefore, its indexing maps are just a composition of indexing maps for this sequence. -### Concatenate +### [Concatenate](https://openxla.org/xla/operation_semantics#concatenate) Output-to-input mapping for concat is defined for all inputs, but with non-overlapping domains, i.e. only one of the inputs will be used at a time. @@ -396,7 +399,7 @@ The inputs to output map: - input 2 -> output: $(d_0, d_1) \mapsto (d_0, d_1 + 50)$ for $\boldsymbol{d} \in {\rm Dom}(input_2)$. -### Dot (output-to-input implemented +### [Dot](https://openxla.org/xla/operation_semantics#dot) Indexing maps for dot are very similar to the ones of reduce. @@ -422,9 +425,43 @@ The inputs to output maps: - input_2 -> output: $(d_0, d_1, d_2) \mapsto (d_0, s_0, d_1)$ for $\boldsymbol{d} \in {\rm Dom}(input_2)$ and $\boldsymbol{s} \in [0, 127]$ -### Reduce-window (TBD) +### [Pad](https://openxla.org/xla/operation_semantics#pad) + +Indexing of PadOp is inverse of SliceOp indexing. + +```c+ +p0 = f32[4, 4] parameter(0) +p1 = f32[] parameter(1) +pad = f32[12, 16] pad(p0, p1), padding=1_4_1x4_8_0 +``` + +The padding config `1_4_1x4_8_0` denotes `lowPad_highPad_interiorPad_dim_0 x lowPad_highPad_interiorPad_dim_1`. + +The output to input maps: -### Pad (TBD) +- output -> input: $(d_0, d_1) \mapsto ((d_0 - 1) / 2, d_1 - 4)$ + for $\boldsymbol{d} \in [1, 7] \times [4, 7]$ and $(d_0 - 1) \mod 2 \equiv 0$ +- output -> init: $(d_0, d_1) \mapsto ()$ for $\boldsymbol{d} \in {\rm Dom}(output)$ + + +### [ReduceWindow](https://openxla.org/xla/operation_semantics#reducewindow) + +ReduceWindow in XLA also performs padding. Therefore, the indexing maps can be +computed as a composition of ReduceWindow indexing that does not do any padding +and PadOp's indexing. + + +```c+ +c_inf = f32[] constant(-inf) +p0 = f32[1024, 514] parameter(0) +reduce-window = f32[1024, 3] reduce-window(p0, c_inf), + window={size=1x512 pad=0_0x0_0}, to_apply=max +``` + +The output to input maps: + +- output -> input: $(d_0, d_1) \mapsto (d_0, d_1 + s_0)$ for $\boldsymbol{d} \in [0, 1023] \times [0, 2]$ and $\boldsymbol{s} \in [0, 511]$ +- output -> init: $(d_0, d_1) \mapsto ()$ for $\boldsymbol{d} \in {\rm Dom}(output)$ ## Indexing Maps for Fusion @@ -471,7 +508,6 @@ f { The output-to-input indexing map for `p0` in this case is just $(d_0, d_1, d_2) \mapsto (d_2, d_0, d_1)$. -​ ### Softmax @@ -519,3 +555,13 @@ reshape2 = f32[10, 10, 10] reshape(reshape1) After the composition of indexing maps and their simplification we will get $(d_0, d_1, d_2) \mapsto (d_0, d_1, d_2)$. + +Indexing map simplification also simplifies the constraints. + +1. Constraints of type +`lower_bound <= affine_expr (floordiv, +, -, *) constant <= upper_bound` are +rewritten as `updated_lower_bound <= affine_expr <= updated_upped_bound`. +2. Constraints that are always satisfied, e.g. $d_0 + s_0 in [0, 20]$ +for $d_0 \in [0, 5]$ and $s_0 \in [1, 3]$ are eliminated. +3. Affine expressions in the constraints are optimized as the indexing affine +map above. diff --git a/third_party/xla/docs/persisted_autotuning.md b/third_party/xla/docs/persisted_autotuning.md new file mode 100644 index 00000000000000..5d1f01ab501442 --- /dev/null +++ b/third_party/xla/docs/persisted_autotuning.md @@ -0,0 +1,76 @@ +# Persisted autotuning (GPU only) + +We use OpenAI Triton for generating some of the GPU kernels. Triton allows +generating fast GPU kernels for certain fusions, but we have to tune some +parameters for each such fusion. + +This can take a long time if there are many fusions, so we provide a way to load +those autotuning results, while still running the other compilation steps +normally. Autotuning caches are still useful if we make a few changes: the +fusions that are present in the cache will use the cache, and the other ones +will be autotuned normally. + +The autotuning results can be dumped/loaded using these parameters: + +``` +--xla_gpu_dump_autotune_results_to= +--xla_gpu_load_autotune_results_from= +``` + +If we specify a .txt or .textproto file, then the cache will be dumped in +textproto format, otherwise in binary protobuf format. + +## In tests + +Persisted autotuning can also be used in tests. It is recommended to use it if +the tests are very big, especially if the performance of the test environment is +limited. + +It only works well if the autotune cache contains results generated on the same +type of GPU where the tests are being run. + +### Making a test use persisted autotuning + +For now let's assume that the test in question always uses the same GPU type. + +1. We have to export the autotune results from the test, for example by + specifying these parameters to the test command: + + ``` + --test_env=XLA_FLAGS=--xla_gpu_dump_autotune_results_to=TEST_UNDECLARED_OUTPUTS_DIR/autotune_cache.textproto + --test_sharding_strategy=disabled + ``` + + Sharding must be disabled to correctly get a single autotune cache for all + tests. + +2. Then we have to upload that cache to our code repository. + +3. Then we have to add the cache to the data dependencies of our test target, + and load it using an environment variable. + + ``` + data = ["test_autotune_cache.textproto"], + env = {"XLA_FLAGS": "--xla_gpu_load_autotune_results_from=" + + "$(execpath test_autotune_cache.textproto)"}, + ``` + + (It is OK to use sharding in tests that load autotune results.) + +Please also see the example tests in +[xla/service/gpu/tests/BUILD](https://github.com/openxla/xla/blob/main/xla/service/gpu/tests/BUILD): + +- load_autotune_results_using_execpath_test +- load_autotune_results_from_test_workspace_test +- dump_autotune_results_to_test_outputs_test + +### Cache obsolescence + +If many changes are made to a model, it is possible that the cache will no +longer contain all fusions, so the test will become slower. In this case we +would have to regenerate the autotuning cache. + +If we start using a new type of GPU for running the tests, the same applies. + +The cache may also become obsolete if the XLA compiler evolves and generates +different fusions. diff --git a/third_party/xla/docs/tf2xla/index.md b/third_party/xla/docs/tf2xla/index.md new file mode 100644 index 00000000000000..edde1f7de62374 --- /dev/null +++ b/third_party/xla/docs/tf2xla/index.md @@ -0,0 +1,239 @@ +# XLA: Optimizing Compiler for Machine Learning + +[OpenXLA](https://openxla.org) is a domain-specific compiler for linear +algebra that can accelerate TensorFlow models with potentially no source code +changes. + +## Introduction + +When a TensorFlow program is run, all of the operations are executed +individually by the TensorFlow executor. Each TensorFlow operation has a +precompiled GPU kernel implementation that the executor dispatches to. + +XLA provides an alternative mode of running models: it compiles the TensorFlow +graph into a sequence of computation kernels generated specifically for the +given model. Because these kernels are unique to the model, they can exploit +model-specific information for optimization. For example, let's look at an +optimization XLA does in the context of a simple TensorFlow computation: + +``` +def model_fn(x, y, z): + return tf.reduce_sum(x + y * z) +``` + +Run without XLA, the graph launches three kernels: one for the multiplication, +one for the addition and one for the reduction. However, XLA can optimize the +graph so that it computes the result in a single kernel launch. It does this by +"fusing" the addition, multiplication and reduction into a single GPU kernel. +Moreover, this fused operation does not write out the intermediate values +produced by `y*z` and `x+y*z` to memory; instead it "streams" the results of +these intermediate computations directly to their users while keeping them +entirely in GPU registers. Fusion is XLA's single most important optimization. +Memory bandwidth is typically the scarcest resource on hardware accelerators, so +removing memory operations is one of the best ways to improve performance. + +## Enable XLA for TensorFlow models + +### Explicit compilation with `tf.function(jit_compile=True)` + +Explicit compilation API offers a fine-grained control for choosing which +functions should be compiled. For example, the following TensorFlow function +which performs the MNIST training is compiled with XLA: + +``` +@tf.function(jit_compile=True) +def train_mnist(images, labels): + images, labels = cast(images, labels) + + with tf.GradientTape() as tape: + predicted_labels = layer(images) + loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=predicted_labels, labels=labels + )) + layer_variables = layer.trainable_variables + grads = tape.gradient(loss, layer_variables) + optimizer.apply_gradients(zip(grads, layer_variables)) +``` + +The `jit_compile` API has _must-compile_ semantics: either the entire +function is compiled with XLA, or an `errors.InvalidArgumentError` exception is +thrown. XLA can not currently compile functions where dimensions are not +_inferrable_: that is, if it's not possible to infer the dimensions of all +tensors without running the entire computation. For example, the following +function will not compile: + +``` +@tf.function +def not_compilable(x): + return tf.unique(x) +``` + +Shapes can vary across the runs though: + +``` +@tf.function(jit_compile=True) +def recompiled_on_launch(a, b): + return a + b + +recompiled_on_launch(tf.ones([1, 10]), tf.ones([1, 10])) +recompiled_on_launch(tf.ones([1, 100]), tf.ones([1, 100])) +``` + +Note: Nesting behavior: the function will be compiled if at least one function +in its call stack has `jit_compile=True`. + +See the [tutorial colab](./tutorials/jit_compile.ipynb) for a more detailed +usage example, and a +[tutorial video](https://www.youtube.com/watch?v=cPAD9vLKE0c) on +`jit_compile=True` usage. + +### Usage with Keras + +For Keras models, `jit_compile=True` can be set as an argument to +[`model.compile`](https://www.tensorflow.org/api_docs/python/tf/keras/Model#compile): + +``` +model.compile(optimizer="adam", jit_compile=True) +``` + +### Usage with distributed strategy + +XLA:GPU can be used with TF distributed strategy +([`MirroredStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy) +or +[`MultiWorkerMirroredStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy)) +by annotating step function with `jit_compile=True`: + +``` +@tf.function(jit_compile=True) +def step_fn(): + t = tf.ones(shape=[100], dtype=tf.float32) + ctx = tf.distribute.get_replica_context() + return ctx.all_reduce(tf.distribute.ReduceOp.SUM, t) + +@tf.function +def run_fn(): + return strategy.run(step_fn) +``` + +### Auto-clustering + +A simple way to start using XLA in TensorFlow models without any changes is to +enable _auto-clustering_, which automatically finds _clusters_ (connected +subgraphs) within the TensorFlow functions which can be compiled and executed +using XLA. Auto-clustering on GPU can be enabled by setting the `TF_XLA_FLAGS` +environment variable: + +Note: In TF2, only the code inside `tf.function` will be clustered. + +``` +$ TF_XLA_FLAGS=--tf_xla_auto_jit=2 path/to/your/tf/program +``` + +Auto-clustering is currently optimized for GPU workloads, but it can also be +enabled on CPU by additionally using the flag `--tf_xla_cpu_global_jit`: + +``` +$ TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit" path/to/your/program +``` + +Note: Auto-clustering support on CPU and on multi-GPU environments is +experimental. + +For a detailed usage example see the +[auto-clustering tutorial colab](./tutorials/autoclustering_xla.ipynb). + +### AOT (Ahead-of-time) compilation for CPU with `tfcompile` + +You can also use a standalone [`tfcompile`](./tfcompile.md) tool, which converts +TensorFlow graph into executable code (for x86-64 CPU only). + +## Inspect compiled programs + +XLA provides introspection facilities which let you inspect the generated +programs. To dump the generated programs, use the environment variable +`XLA_FLAGS`: + +``` +$ XLA_FLAGS="--xla_dump_to=/tmp/generated" TF_XLA_FLAGS="--tf_xla_auto_jit=2" my/tensorflow/program +``` + +After the dumping is performed, you can find the following files in +`/tmp/generated`: + +- `module_XXXX.*_optimizations.txt` Generated + [XLA programs](./operation_semantics.md), one per each compiled cluster. + Attaching those when submitting XLA bug reports is extremely helpful! + +- `module_XXXX.ir-*.ll` Generated files in + [LLVM](https://llvm.org/docs/LangRef.html) intermediate representation, with + [NVPTX](https://llvm.org/docs/NVPTXUsage.html) intrinsics. + +- `module_XXXX.ptx` Generated + [PTX](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html) + files. + +You can also dump the graph visualizing the embedding of XLA clusters inside of +the TensorFlow graph with: + +``` +$ TF_DUMP_GRAPH_PREFIX=/tmp/generated TF_XLA_FLAGS="--tf_xla_clustering_debug" +``` + +## Reproducible bug reports + +A bug report is much easier to reproduce if it includes dumps for the generated +XLA programs and the used auto-clustering embedding. +To generate them for a TensorFlow program running with auto-clustering, launch: + +``` +$ TF_DUMP_GRAPH_PREFIX=/tmp/generated \ + TF_XLA_FLAGS="--tf_xla_clustering_debug --tf_xla_auto_jit=2" \ + XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=/tmp/generated" \ + my/tensorflow/program" +``` + +When filing bugs, attach the contents of the `/tmp/generated` directory +(referenced above). + +If possible, try to isolate +a bug to a single XLA program by using the +[`run_hlo_module`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/tools/run_hlo_module_main.cc) +and iteratively running it on generated programs. + +## Further reading + +- [OpenXLA Documentation](https://openxla.org) OpenXLA Documentation +- [Known Issues](./known_issues.md) List of known issues with XLA+TF +- [XLA - TensorFlow, Compiled](https://developers.googleblog.com/2017/03/xla-tensorflow-compiled.html): + Read on Google Developers Blog +- Check out the + [XLA source](https://github.com/openxla/xla) + on Github! + +## XLA Frontends + +Apart from TensorFlow, XLA programs can be generated by: + +- [JAX](https://github.com/google/jax): Composable transformations of + Python+NumPy programs +- [Julia](https://github.com/JuliaTPU/XLA.jl): The Julia language for + scientific computing +- [PyTorch](https://github.com/pytorch/xla): PyTorch framework +- [Nx](https://github.com/elixir-nx/nx): Numerical computing library for the + Elixir programming language + +## Talks + +### Using XLA from TF using `jit_compile=True` + + + +### XLA Overview + + diff --git a/third_party/xla/docs/tf2xla/tfcompile.md b/third_party/xla/docs/tf2xla/tfcompile.md new file mode 100644 index 00000000000000..5d60a4e90a9acb --- /dev/null +++ b/third_party/xla/docs/tf2xla/tfcompile.md @@ -0,0 +1,279 @@ +# Using AOT compilation + +## What is tfcompile? + +`tfcompile` is a standalone tool that ahead-of-time (AOT) compiles TensorFlow +graphs into executable code. It can reduce total binary size, and also avoid +some runtime overheads. A typical use-case of `tfcompile` is to compile an +inference graph into executable code for mobile devices. + +The TensorFlow graph is normally executed by the TensorFlow runtime. This incurs +some runtime overhead for execution of each node in the graph. This also leads +to a larger total binary size, since the code for the TensorFlow runtime needs +to be available, in addition to the graph itself. The executable code produced +by `tfcompile` does not use the TensorFlow runtime, and only has dependencies on +kernels that are actually used in the computation. + +The compiler is built on top of the XLA framework. The code bridging TensorFlow +to the XLA framework resides under +[tensorflow/compiler](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/). + +## What does tfcompile do? + +`tfcompile` takes a subgraph, identified by the TensorFlow concepts of +feeds and fetches, and generates a function that implements that subgraph. +The `feeds` are the input arguments for the function, and the `fetches` are the +output arguments for the function. All inputs must be fully specified by the +feeds; the resulting pruned subgraph cannot contain Placeholder or Variable +nodes. It is common to specify all Placeholders and Variables as feeds, which +ensures the resulting subgraph no longer contains these nodes. The generated +function is packaged as a `cc_library`, with a header file exporting the +function signature, and an object file containing the implementation. The user +writes code to invoke the generated function as appropriate. + +## Using tfcompile + +This section details high level steps for generating an executable binary with +`tfcompile` from a TensorFlow subgraph. The steps are: + +* Step 1: Configure the subgraph to compile +* Step 2: Use the `tf_library` build macro to compile the subgraph +* Step 3: Write code to invoke the subgraph +* Step 4: Create the final binary + +### Step 1: Configure the subgraph to compile + +Identify the feeds and fetches that correspond to the input and output +arguments for the generated function. Then configure the `feeds` and `fetches` +in a [`tensorflow.tf2xla.Config`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/tf2xla/tf2xla.proto) +proto. + +```textproto +# Each feed is a positional input argument for the generated function. The order +# of each entry matches the order of each input argument. Here “x_hold” and “y_hold” +# refer to the names of placeholder nodes defined in the graph. +feed { + id { node_name: "x_hold" } + shape { + dim { size: 2 } + dim { size: 3 } + } +} +feed { + id { node_name: "y_hold" } + shape { + dim { size: 3 } + dim { size: 2 } + } +} + +# Each fetch is a positional output argument for the generated function. The order +# of each entry matches the order of each output argument. Here “x_y_prod” +# refers to the name of a matmul node defined in the graph. +fetch { + id { node_name: "x_y_prod" } +} +``` + +### Step 2: Use tf_library build macro to compile the subgraph + +This step converts the graph into a `cc_library` using the `tf_library` build +macro. The `cc_library` consists of an object file containing the code generated +from the graph, along with a header file that gives access to the generated +code. `tf_library` utilizes `tfcompile` to compile the TensorFlow graph into +executable code. + +```build +load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") + +# Use the tf_library macro to compile your graph into executable code. +tf_library( + # name is used to generate the following underlying build rules: + # : cc_library packaging the generated header and object files + # _test : cc_test containing a simple test and benchmark + # _benchmark : cc_binary containing a stand-alone benchmark with minimal deps; + # can be run on a mobile device + name = "test_graph_tfmatmul", + # cpp_class specifies the name of the generated C++ class, with namespaces allowed. + # The class will be generated in the given namespace(s), or if no namespaces are + # given, within the global namespace. + cpp_class = "foo::bar::MatMulComp", + # graph is the input GraphDef proto, by default expected in binary format. To + # use the text format instead, just use the ‘.pbtxt’ suffix. A subgraph will be + # created from this input graph, with feeds as inputs and fetches as outputs. + # No Placeholder or Variable ops may exist in this subgraph. + graph = "test_graph_tfmatmul.pb", + # config is the input Config proto, by default expected in binary format. To + # use the text format instead, use the ‘.pbtxt’ suffix. This is where the + # feeds and fetches were specified above, in the previous step. + config = "test_graph_tfmatmul.config.pbtxt", +) +``` + +> To generate the GraphDef proto (test_graph_tfmatmul.pb) for this example, run +> [make_test_graphs.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/aot/tests/make_test_graphs.py) +> and specify the output location with the --out_dir flag. + +Typical graphs contain [`Variables`](https://www.tensorflow.org/guide/variables) +representing the weights that are learned via training, but `tfcompile` cannot +compile a subgraph that contain `Variables`. The +[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py) +tool converts variables into constants, using values stored in a checkpoint +file. As a convenience, the `tf_library` macro supports the `freeze_checkpoint` +argument, which runs the tool. For more examples see +[tensorflow/compiler/aot/tests/BUILD](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/aot/tests/BUILD). + +> Constants that show up in the compiled subgraph are compiled directly into the +> generated code. To pass the constants into the generated function, rather than +> having them compiled-in, simply pass them in as feeds. + +For details on the `tf_library` build macro, see +[tfcompile.bzl](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/aot/tfcompile.bzl). + +For details on the underlying `tfcompile` tool, see +[tfcompile_main.cc](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/aot/tfcompile_main.cc). + +### Step 3: Write code to invoke the subgraph + +This step uses the header file (`test_graph_tfmatmul.h`) generated by the +`tf_library` build macro in the previous step to invoke the generated code. The +header file is located in the `bazel-bin` directory corresponding to the +build package, and is named based on the name attribute set on the `tf_library` +build macro. For example, the header generated for `test_graph_tfmatmul` would +be `test_graph_tfmatmul.h`. Below is an abbreviated version of what is +generated. The generated file, in `bazel-bin`, contains additional useful +comments. + +```c++ +namespace foo { +namespace bar { + +// MatMulComp represents a computation previously specified in a +// TensorFlow graph, now compiled into executable code. +class MatMulComp { + public: + // AllocMode controls the buffer allocation mode. + enum class AllocMode { + ARGS_RESULTS_AND_TEMPS, // Allocate arg, result and temp buffers + RESULTS_AND_TEMPS_ONLY, // Only allocate result and temp buffers + }; + + MatMulComp(AllocMode mode = AllocMode::ARGS_RESULTS_AND_TEMPS); + ~MatMulComp(); + + // Runs the computation, with inputs read from arg buffers, and outputs + // written to result buffers. Returns true on success and false on failure. + bool Run(); + + // Arg methods for managing input buffers. Buffers are in row-major order. + // There is a set of methods for each positional argument. + void** args(); + + void set_arg0_data(float* data); + float* arg0_data(); + float& arg0(size_t dim0, size_t dim1); + + void set_arg1_data(float* data); + float* arg1_data(); + float& arg1(size_t dim0, size_t dim1); + + // Result methods for managing output buffers. Buffers are in row-major order. + // Must only be called after a successful Run call. There is a set of methods + // for each positional result. + void** results(); + + + float* result0_data(); + float& result0(size_t dim0, size_t dim1); +}; + +} // end namespace bar +} // end namespace foo +``` + +The generated C++ class is called `MatMulComp` in the `foo::bar` namespace, +because that was the `cpp_class` specified in the `tf_library` macro. All +generated classes have a similar API, with the only difference being the methods +to handle arg and result buffers. Those methods differ based on the number and +types of the buffers, which were specified by the `feed` and `fetch` arguments +to the `tf_library` macro. + +There are three types of buffers managed within the generated class: `args` +representing the inputs, `results` representing the outputs, and `temps` +representing temporary buffers used internally to perform the computation. By +default, each instance of the generated class allocates and manages all of these +buffers for you. The `AllocMode` constructor argument may be used to change this +behavior. All buffers are aligned to 64-byte boundaries. + +The generated C++ class is just a wrapper around the low-level code generated by +XLA. + +Example of invoking the generated function based on +[`tfcompile_test.cc`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/aot/tests/tfcompile_test.cc): + +```c++ +#define EIGEN_USE_THREADS +#define EIGEN_USE_CUSTOM_THREAD_POOL + +#include +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "third_party/tensorflow/compiler/aot/tests/test_graph_tfmatmul.h" // generated + +int main(int argc, char** argv) { + Eigen::ThreadPool tp(2); // Size the thread pool as appropriate. + Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); + + + foo::bar::MatMulComp matmul; + matmul.set_thread_pool(&device); + + // Set up args and run the computation. + const float args[12] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + std::copy(args + 0, args + 6, matmul.arg0_data()); + std::copy(args + 6, args + 12, matmul.arg1_data()); + matmul.Run(); + + // Check result + if (matmul.result0(0, 0) == 58) { + std::cout << "Success" << std::endl; + } else { + std::cout << "Failed. Expected value 58 at 0,0. Got:" + << matmul.result0(0, 0) << std::endl; + } + + return 0; +} +``` + +### Step 4: Create the final binary + +This step combines the library generated by `tf_library` in step 2 and the code +written in step 3 to create a final binary. Below is an example `bazel` BUILD +file. + +```build +# Example of linking your binary +# Also see //tensorflow/compiler/aot/tests/BUILD +load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") + +# The same tf_library call from step 2 above. +tf_library( + name = "test_graph_tfmatmul", + ... +) + +# The executable code generated by tf_library can then be linked into your code. +cc_binary( + name = "my_binary", + srcs = [ + "my_code.cc", # include test_graph_tfmatmul.h to access the generated header + ], + deps = [ + ":test_graph_tfmatmul", # link in the generated object file + "//third_party/eigen3", + ], + linkopts = [ + "-lpthread", + ] +) +``` diff --git a/third_party/xla/docs/tf2xla/tutorials/autoclustering_xla.ipynb b/third_party/xla/docs/tf2xla/tutorials/autoclustering_xla.ipynb new file mode 100644 index 00000000000000..88f94c2bbc3f82 --- /dev/null +++ b/third_party/xla/docs/tf2xla/tutorials/autoclustering_xla.ipynb @@ -0,0 +1,272 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "f4TSNCvpENrW" + }, + "source": [ + "##### Copyright 2019 The TensorFlow Authors." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "vamNSA0vEP-m" + }, + "outputs": [], + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "asd4sdga7g" + }, + "source": [ + "# Classifying CIFAR-10 with XLA\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "b7noD9NjFRL-" + }, + "source": [ + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/xla/tutorials/autoclustering_xla\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/autoclustering_xla.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/autoclustering_xla.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/tensorflow/compiler/xla/g3doc/tutorials/autoclustering_xla.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n", + " \u003c/td\u003e\n", + "\u003c/table\u003e" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mz65veHXsmnS" + }, + "source": [ + "This tutorial trains a TensorFlow model to classify the [CIFAR-10](https://en.wikipedia.org/wiki/CIFAR-10) dataset, and we compile it using XLA.\n", + "\n", + "You will load and normalize the dataset using the [TensorFlow Datasets (TFDS)](https://tensorflow.org/datasets) API. First, install/upgrade TensorFlow and TFDS:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "R4xtYyOf78e3" + }, + "outputs": [], + "source": [ + "!pip install -U -q tensorflow tensorflow_datasets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PH2HbLW65tmo" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "import tensorflow_datasets as tfds" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7vm2QsMisCxI" + }, + "outputs": [], + "source": [ + "# Check that GPU is available: cf. https://colab.research.google.com/notebooks/gpu.ipynb\n", + "assert(tf.test.gpu_device_name())\n", + "\n", + "tf.keras.backend.clear_session()\n", + "tf.config.optimizer.set_jit(False) # Start with XLA disabled.\n", + "\n", + "def load_data():\n", + " result = tfds.load('cifar10', batch_size = -1)\n", + " (x_train, y_train) = result['train']['image'],result['train']['label']\n", + " (x_test, y_test) = result['test']['image'],result['test']['label']\n", + " \n", + " x_train = x_train.numpy().astype('float32') / 256\n", + " x_test = x_test.numpy().astype('float32') / 256\n", + "\n", + " # Convert class vectors to binary class matrices.\n", + " y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)\n", + " y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)\n", + " return ((x_train, y_train), (x_test, y_test))\n", + "\n", + "(x_train, y_train), (x_test, y_test) = load_data()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MgNM2tbgtScx" + }, + "source": [ + "We define the model, adapted from the Keras [CIFAR-10 example](https://keras.io/examples/cifar10_cnn/):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3ZRQSwoRsKM_" + }, + "outputs": [], + "source": [ + "def generate_model():\n", + " return tf.keras.models.Sequential([\n", + " tf.keras.layers.Conv2D(32, (3, 3), padding='same', input_shape=x_train.shape[1:]),\n", + " tf.keras.layers.Activation('relu'),\n", + " tf.keras.layers.Conv2D(32, (3, 3)),\n", + " tf.keras.layers.Activation('relu'),\n", + " tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),\n", + " tf.keras.layers.Dropout(0.25),\n", + "\n", + " tf.keras.layers.Conv2D(64, (3, 3), padding='same'),\n", + " tf.keras.layers.Activation('relu'),\n", + " tf.keras.layers.Conv2D(64, (3, 3)),\n", + " tf.keras.layers.Activation('relu'),\n", + " tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),\n", + " tf.keras.layers.Dropout(0.25),\n", + "\n", + " tf.keras.layers.Flatten(),\n", + " tf.keras.layers.Dense(512),\n", + " tf.keras.layers.Activation('relu'),\n", + " tf.keras.layers.Dropout(0.5),\n", + " tf.keras.layers.Dense(10),\n", + " tf.keras.layers.Activation('softmax')\n", + " ])\n", + "\n", + "model = generate_model()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-M4GtGDZtb8a" + }, + "source": [ + "We train the model using the\n", + "[RMSprop](https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer)\n", + "optimizer:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "UKCmrhF0tiMa" + }, + "outputs": [], + "source": [ + "def compile_model(model):\n", + " opt = tf.keras.optimizers.RMSprop(learning_rate=0.0001)\n", + " model.compile(loss='categorical_crossentropy',\n", + " optimizer=opt,\n", + " metrics=['accuracy'])\n", + " return model\n", + "\n", + "model = compile_model(model)\n", + "\n", + "def train_model(model, x_train, y_train, x_test, y_test, epochs=25):\n", + " model.fit(x_train, y_train, batch_size=256, epochs=epochs, validation_data=(x_test, y_test), shuffle=True)\n", + "\n", + "def warmup(model, x_train, y_train, x_test, y_test):\n", + " # Warm up the JIT, we do not wish to measure the compilation time.\n", + " initial_weights = model.get_weights()\n", + " train_model(model, x_train, y_train, x_test, y_test, epochs=1)\n", + " model.set_weights(initial_weights)\n", + "\n", + "warmup(model, x_train, y_train, x_test, y_test)\n", + "%time train_model(model, x_train, y_train, x_test, y_test)\n", + "\n", + "scores = model.evaluate(x_test, y_test, verbose=1)\n", + "print('Test loss:', scores[0])\n", + "print('Test accuracy:', scores[1])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SLpfQ0StRgsu" + }, + "source": [ + "Now let's train the model again, using the XLA compiler.\n", + "To enable the compiler in the middle of the application, we need to reset the Keras session." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jxU-Tzy4SX7p" + }, + "outputs": [], + "source": [ + "# We need to clear the session to enable JIT in the middle of the program.\n", + "tf.keras.backend.clear_session()\n", + "tf.config.optimizer.set_jit(True) # Enable XLA.\n", + "model = compile_model(generate_model())\n", + "(x_train, y_train), (x_test, y_test) = load_data()\n", + "\n", + "warmup(model, x_train, y_train, x_test, y_test)\n", + "%time train_model(model, x_train, y_train, x_test, y_test)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iWHz6P1se92F" + }, + "source": [ + "On a machine with a Titan V GPU and an Intel Xeon E5-2690 CPU the speed up is ~1.17x." + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "CIFAR-10 with XLA.ipynb", + "private_outputs": true, + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/third_party/xla/docs/tf2xla/tutorials/jit_compile.ipynb b/third_party/xla/docs/tf2xla/tutorials/jit_compile.ipynb new file mode 100644 index 00000000000000..b9967f4e94f4da --- /dev/null +++ b/third_party/xla/docs/tf2xla/tutorials/jit_compile.ipynb @@ -0,0 +1,270 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "f4TSNCvpENrW" + }, + "source": [ + "##### Copyright 2019 The TensorFlow Authors." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "vamNSA0vEP-m" + }, + "outputs": [], + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "e1oSi4lHFt3z" + }, + "source": [ + "# Use XLA with tf.function" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "b7noD9NjFRL-" + }, + "source": [ + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/xla/tutorials/compile\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/jit_compile.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/tensorflow/compiler/xla/g3doc/tutorials/jit_compile.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/jit_compile.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n", + " \u003c/td\u003e\n", + "\u003c/table\u003e" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sDy5lSBd4BDE" + }, + "source": [ + "This tutorial trains a TensorFlow model to classify the MNIST dataset, where the training function is compiled using XLA.\n", + "\n", + "First, load TensorFlow and enable eager execution." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "45kUPj5ZFrRa" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GZVNiRmTDV-5" + }, + "source": [ + "Then define some necessary constants and prepare the MNIST dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "f37TSEGvGX4_" + }, + "outputs": [], + "source": [ + "# Size of each input image, 28 x 28 pixels\n", + "IMAGE_SIZE = 28 * 28\n", + "# Number of distinct number labels, [0..9]\n", + "NUM_CLASSES = 10\n", + "# Number of examples in each training batch (step)\n", + "TRAIN_BATCH_SIZE = 100\n", + "# Number of training steps to run\n", + "TRAIN_STEPS = 1000\n", + "\n", + "# Loads MNIST dataset.\n", + "train, test = tf.keras.datasets.mnist.load_data()\n", + "train_ds = tf.data.Dataset.from_tensor_slices(train).batch(TRAIN_BATCH_SIZE).repeat()\n", + "\n", + "# Casting from raw data to the required datatypes.\n", + "def cast(images, labels):\n", + " images = tf.cast(\n", + " tf.reshape(images, [-1, IMAGE_SIZE]), tf.float32)\n", + " labels = tf.cast(labels, tf.int64)\n", + " return (images, labels)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lv7I-u_82v1S" + }, + "source": [ + "Finally, define the model and the optimizer. The model uses a single dense layer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7O2NcEfG206Q" + }, + "outputs": [], + "source": [ + "layer = tf.keras.layers.Dense(NUM_CLASSES)\n", + "optimizer = tf.keras.optimizers.Adam()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "x_ZehpZP-SfS" + }, + "source": [ + "# Define the training function\n", + "\n", + "In the training function, you get the predicted labels using the layer defined above, and then minimize the gradient of the loss using the optimizer. In order to compile the computation using XLA, place it inside `tf.function` with `jit_compile=True`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZbhJl_WvGa3g" + }, + "outputs": [], + "source": [ + "@tf.function(jit_compile=True)\n", + "def train_mnist(images, labels):\n", + " images, labels = cast(images, labels)\n", + "\n", + " with tf.GradientTape() as tape:\n", + " predicted_labels = layer(images)\n", + " loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(\n", + " logits=predicted_labels, labels=labels\n", + " ))\n", + " layer_variables = layer.trainable_variables\n", + " grads = tape.gradient(loss, layer_variables)\n", + " optimizer.apply_gradients(zip(grads, layer_variables))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EZD1m_n1DxAF" + }, + "source": [ + "# Train and test the model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gukC2Hol3sFZ" + }, + "source": [ + "Once you have defined the training function, define the model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qe28bAHNHUG2" + }, + "outputs": [], + "source": [ + "for images, labels in train_ds:\n", + " if optimizer.iterations \u003e TRAIN_STEPS:\n", + " break\n", + " train_mnist(images, labels)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qgsKmz3n2UiW" + }, + "source": [ + "And, finally, check the accuracy:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_GxF6jTRHVuA" + }, + "outputs": [], + "source": [ + "images, labels = cast(test[0], test[1])\n", + "predicted_labels = layer(images)\n", + "correct_prediction = tf.equal(tf.argmax(predicted_labels, 1), labels)\n", + "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n", + "print(\"Prediction accuracy after training: %s\" % accuracy)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PXoOjJnuZRaV" + }, + "source": [ + "Behind the scenes, the XLA compiler has compiled the entire TF function to HLO, which has enabled fusion optimizations. Using the introspection facilities, we can see the HLO code (other interesting possible values for \"stage\" are `optimized_hlo` for HLO after optimizations and `optimized_hlo_dot` for a Graphviz graph):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_a8GsNLVaLSQ" + }, + "outputs": [], + "source": [ + "print(train_mnist.experimental_get_compiler_ir(images, labels)(stage='hlo'))" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "jit_compile.ipynb", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/third_party/xla/docs/tools.md b/third_party/xla/docs/tools.md new file mode 100644 index 00000000000000..fe7d8ee5a6a86c --- /dev/null +++ b/third_party/xla/docs/tools.md @@ -0,0 +1,155 @@ +# Using XLA tooling + +The XLA development workflow is usually centered around +[HLO](./operation_semantics) IR, which represents isolated functional +computation given to the compiler. XLA comes with multiple command line tools +(described below) which consume HLO and either run it, or provide an +intermediate compilation stage. Using such tools is invaluable for a fast +`compile->modify->run` iteration cycle, as HLO is both visualizable and +hackable, and iteratively changing and running it is often the fastest way to +understand and to fix an XLA performance or behavior. + +The easiest way to obtain the HLO for a program being compiled with XLA is +usually to use the `XLA_FLAGS` environment variable: + +``` +XLA_FLAGS=--xla_dump_to=/tmp/myfolder ./myprogram-entry-point +``` + +which stores all before-optimization HLO files in the folder specified, along +with many other useful artifacts. + +## Running HLO snippets: `run_hlo_module` + +The tool `run_hlo_module` operates on pre-optimization HLO, and by default +bundles compilation, running and comparison with the reference interpreter +implementation. For example, the usual invocation to run an input file +`computation.hlo` on an NVIDIA GPU and to check it for correctness is: + +``` +run_hlo_module --platform=CUDA --reference_platform=Interpreter computation.hlo +``` + +As with all the tools, `--help` can be used to obtain the full list of options. + +## Running HLO snippets with SPMD support: `multihost_hlo_runner` + +Multihost HLO runner is a very similar tool, with the caveat that it supports +SPMD, including cross host communication. A typical invocation looks like: + +``` +hlo_runner_main --device_type=gpu --use_spmd_partitioning=true --num_partitions=4 --num_replicas=1 --hlo_file=computation.hlo +``` + +## Running passes/stages of HLO compilation: `hlo-opt` + +When debugging or understanding the workings of the compiler, it is often useful +to get the expansion for a particular hardware at a particular point in the +pipeline (be it HLO, optimized HLO, TritonIR or LLVM), for a given (Stable) HLO +input. + +`hlo-opt` supports multiple output stages: be it PTX, HLO after optimizations, +LLVM IR before optimizations, or TritonIR. The exact set of stages supported +depends on the platform (as e.g. PTX is NVIDIA-specific), and can be seen using +the --list-stages command: + +``` +$ hlo-opt --platform=CUDA --list-stages +hlo +llvm +ptx +``` + +After selecting a stage, the user can write the result of the conversion for a +given platform to a given stream: + +``` +$ hlo-opt myinput.hlo --platform=CUDA --stage=llvm +``` + +which would print the dump to stdout (or to a given file if `-o` was specified). + +### Deviceless Usage + +Access to a GPU is not needed for most of the compilation, and by specifying a +GPU spec on the command line we can get e.g. PTX output without access to an +accelerator: + +``` +$ hlo-opt --platform=CUDA --stage=llvm --xla_gpu_target_config_filename=(pwd)/tools/data/gpu_specs/a100_80.txtpb input.hlo +``` + +Note: For the above invocation to work, the user would usually either need to +disable autotuning with `--xla_gpu_autotune_level=0` or load a pre-existing +autotuning results with `--xla_gpu_load_autotune_results_from=` +(obtained with `--xla_gpu_dump_autotune_results_to=`). + +Specs for popular GPUs are shipped with the compiler, and the provided file is +string serialization of `device_description.proto`: + +``` +gpu_device_info { + cuda_compute_capability { + major: 8 + minor: 0 + } + threads_per_block_limit: 1024 + threads_per_warp: 32 + shared_memory_per_block: 127152 + shared_memory_per_core: 65536 + threads_per_core_limit: 2048 + core_count: 6192 + fpus_per_core: 64 + block_dim_limit_x: 2147483647 + block_dim_limit_y: 65535 + block_dim_limit_z: 65535 + memory_bandwidth: 2039000000000 + l2_cache_size: 4194304 + clock_rate_ghz: 1.1105 + device_memory_size: 79050250240 +} +platform_name: "CUDA" +``` + +Deviceless compilation might run into issues if autotuning is required. Luckily, +we can also provide those on the command line: + +``` +hlo-opt --platform=CUDA --stage=llvm --xla_gpu_target_config_filename=gpu_specs/a100_80.txtpb --xla_gpu_load_autotune_results_from=results.textpb input.hlo +``` + +The autotune file is text serialization of `autotune_results.proto`, with +example looking like: + +``` +version: 2 +results { + device: "sm_8.0 with 42331013120B RAM, 108 cores, 1410000KHz clock, 1215000KHz mem clock, 41943040B L2$" + hlo: "{\n tmp_0 = f16[1,16,17,3]{3,2,1,0} parameter(0)\n tmp_1 = f16[16,51]{1,0} bitcast(f16[1,16,17,3]{3,2,1,0} tmp_0)\n tmp_2 = s8[16,17,3]{2,1,0} parameter(1)\n tmp_3 = s8[51,16]{0,1} bitcast(s8[16,17,3]{2,1,0} tmp_2)\n tmp_4 = f16[51,16]{0,1} convert(s8[51,16]{0,1} tmp_3)\n tmp_5 = f16[16,16]{1,0} dot(f16[16,51]{1,0} tmp_1, f16[51,16]{0,1} tmp_4), lhs_contracting_dims={1}, rhs_contracting_dims={0}\n ROOT tmp_6 = f16[1,16,16]{2,1,0} bitcast(f16[16,16]{1,0} tmp_5)\n}" + result { + run_time { + nanos: 31744 + } + triton { + block_m: 32 + block_n: 32 + block_k: 32 + split_k: 1 + num_stages: 1 + num_warps: 4 + } + } +} +``` + +The autotuning database can be serialized using +`XLA_FLAGS=--xla_gpu_dump_autotune_results_t=` + +### Running a Single Compiler Pass + +The flags from `XLA_FLAGS` are also supported, so the tool can be used to test +running a single pass: + +``` +hlo-opt --platform=CUDA --stage=hlo --xla-hlo-enable-passes-only=algebraic_simplifer input.hlo +``` diff --git a/third_party/xla/third_party/compute_library/BUILD b/third_party/xla/third_party/compute_library/BUILD index b353c1fa0aedba..4fc694c50a43cf 100644 --- a/third_party/xla/third_party/compute_library/BUILD +++ b/third_party/xla/third_party/compute_library/BUILD @@ -1,9 +1,6 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") -exports_files( - ["LICENSE"], - visibility = ["//visibility:public"], -) +exports_files(["LICENSE"]) config_setting( name = "build_with_acl", diff --git a/third_party/xla/third_party/cudnn_frontend_header_fix.patch b/third_party/xla/third_party/cudnn_frontend_header_fix.patch index af22372c66009e..70476bd3ff5d56 100644 --- a/third_party/xla/third_party/cudnn_frontend_header_fix.patch +++ b/third_party/xla/third_party/cudnn_frontend_header_fix.patch @@ -1,234 +1,13 @@ -diff --git a/include/cudnn_backend_base.h b/include/cudnn_backend_base.h -index 1240282..cba52ec 100644 ---- a/include/cudnn_backend_base.h -+++ b/include/cudnn_backend_base.h -@@ -24,7 +24,7 @@ - - #include - --#include -+#include "third_party/gpus/cudnn/cudnn.h" - - namespace cudnn_frontend { - -diff --git a/include/cudnn_frontend_ConvDesc.h b/include/cudnn_frontend_ConvDesc.h -index 6e1d7ab..4deec88 100644 ---- a/include/cudnn_frontend_ConvDesc.h -+++ b/include/cudnn_frontend_ConvDesc.h -@@ -29,8 +29,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_utils.h" - -diff --git a/include/cudnn_frontend_Engine.h b/include/cudnn_frontend_Engine.h -index b95efb8..867541e 100644 ---- a/include/cudnn_frontend_Engine.h -+++ b/include/cudnn_frontend_Engine.h -@@ -30,8 +30,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_OperationGraph.h" - #include "cudnn_frontend_utils.h" -diff --git a/include/cudnn_frontend_EngineConfig.h b/include/cudnn_frontend_EngineConfig.h -index 973e777..97f0883 100644 ---- a/include/cudnn_frontend_EngineConfig.h -+++ b/include/cudnn_frontend_EngineConfig.h -@@ -29,8 +29,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_Engine.h" - #include "cudnn_frontend_utils.h" -diff --git a/include/cudnn_frontend_EngineFallbackList.h b/include/cudnn_frontend_EngineFallbackList.h -index 4d4e5be..6390bc5 100644 ---- a/include/cudnn_frontend_EngineFallbackList.h -+++ b/include/cudnn_frontend_EngineFallbackList.h -@@ -22,7 +22,7 @@ - - #pragma once - --#include -+#include "third_party/gpus/cudnn/cudnn.h" - #include - #include "cudnn_frontend_Heuristics.h" - -diff --git a/include/cudnn_frontend_ExecutionPlan.h b/include/cudnn_frontend_ExecutionPlan.h -index afceeb3..3d426e2 100644 ---- a/include/cudnn_frontend_ExecutionPlan.h -+++ b/include/cudnn_frontend_ExecutionPlan.h -@@ -30,8 +30,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_EngineConfig.h" - #include "cudnn_frontend_Engine.h" -diff --git a/include/cudnn_frontend_Filters.h b/include/cudnn_frontend_Filters.h -index 676f0f2..4d1c020 100644 ---- a/include/cudnn_frontend_Filters.h -+++ b/include/cudnn_frontend_Filters.h -@@ -22,7 +22,7 @@ - - #pragma once - --#include -+#include "third_party/gpus/cudnn/cudnn.h" - - namespace cudnn_frontend { - -diff --git a/include/cudnn_frontend_Heuristics.h b/include/cudnn_frontend_Heuristics.h -index dda3fb3..3e89237 100644 ---- a/include/cudnn_frontend_Heuristics.h -+++ b/include/cudnn_frontend_Heuristics.h -@@ -25,8 +25,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_OperationGraph.h" - #include "cudnn_frontend_EngineConfig.h" -diff --git a/include/cudnn_frontend_MatMulDesc.h b/include/cudnn_frontend_MatMulDesc.h -index c9258ba..141f2f9 100644 ---- a/include/cudnn_frontend_MatMulDesc.h -+++ b/include/cudnn_frontend_MatMulDesc.h -@@ -29,8 +29,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_utils.h" - -diff --git a/include/cudnn_frontend_Operation.h b/include/cudnn_frontend_Operation.h -index bf16cfa..f3086e1 100644 ---- a/include/cudnn_frontend_Operation.h -+++ b/include/cudnn_frontend_Operation.h -@@ -30,8 +30,8 @@ - #include - #include +diff --git a/include/cudnn_frontend.h b/include/cudnn_frontend.h +index 0f0d5a6..802bcbb 100644 +--- a/include/cudnn_frontend.h ++++ b/include/cudnn_frontend.h +@@ -97,7 +97,7 @@ + * - Simpler samples on how to use the new API. + */ -#include --#include +#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" #include "cudnn_frontend_ConvDesc.h" - #include "cudnn_frontend_PointWiseDesc.h" -diff --git a/include/cudnn_frontend_OperationGraph.h b/include/cudnn_frontend_OperationGraph.h -index c5e2704..71589b2 100644 ---- a/include/cudnn_frontend_OperationGraph.h -+++ b/include/cudnn_frontend_OperationGraph.h -@@ -30,8 +30,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_Operation.h" - #include "cudnn_frontend_utils.h" -diff --git a/include/cudnn_frontend_PointWiseDesc.h b/include/cudnn_frontend_PointWiseDesc.h -index afa71ce..56b6507 100644 ---- a/include/cudnn_frontend_PointWiseDesc.h -+++ b/include/cudnn_frontend_PointWiseDesc.h -@@ -30,8 +30,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_utils.h" - -diff --git a/include/cudnn_frontend_ReductionDesc.h b/include/cudnn_frontend_ReductionDesc.h -index 5df2c5e..419fc93 100644 ---- a/include/cudnn_frontend_ReductionDesc.h -+++ b/include/cudnn_frontend_ReductionDesc.h -@@ -29,8 +29,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_utils.h" - -diff --git a/include/cudnn_frontend_Resample.h b/include/cudnn_frontend_Resample.h -index 351e2da..b1a1904 100644 ---- a/include/cudnn_frontend_Resample.h -+++ b/include/cudnn_frontend_Resample.h -@@ -29,8 +29,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_utils.h" - -diff --git a/include/cudnn_frontend_Rng.h b/include/cudnn_frontend_Rng.h -index 9d4e6ca..4224b61 100644 ---- a/include/cudnn_frontend_Rng.h -+++ b/include/cudnn_frontend_Rng.h -@@ -29,8 +29,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_utils.h" - -diff --git a/include/cudnn_frontend_VariantPack.h b/include/cudnn_frontend_VariantPack.h -index 455ab8b..4173860 100644 ---- a/include/cudnn_frontend_VariantPack.h -+++ b/include/cudnn_frontend_VariantPack.h -@@ -30,8 +30,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_utils.h" - + #include "cudnn_frontend_Heuristics.h" diff --git a/third_party/xla/third_party/llvm_openmp/BUILD b/third_party/xla/third_party/llvm_openmp/BUILD index 1592cf2ad51913..71a21b4e3786bb 100644 --- a/third_party/xla/third_party/llvm_openmp/BUILD +++ b/third_party/xla/third_party/llvm_openmp/BUILD @@ -19,19 +19,20 @@ load( load("@bazel_skylib//:bzl_library.bzl", "bzl_library") package( - default_visibility = ["//visibility:public"], + default_visibility = [ + "//visibility:public", + ], ) -exports_files( - ["LICENSE.txt"], - visibility = ["//visibility:public"], -) +exports_files(["LICENSE.txt"]) py_binary( name = "expand_cmake_vars", srcs = ["expand_cmake_vars.py"], srcs_version = "PY3", - visibility = ["//visibility:public"], + visibility = [ + "@llvm_openmp//:__subpackages__", + ], ) kmp_i18n_os_type = select({ @@ -239,5 +240,4 @@ if_windows(a = libiomp5_cc_binary( bzl_library( name = "openmp_bzl", srcs = ["openmp.bzl"], - visibility = ["//visibility:public"], ) diff --git a/third_party/xla/third_party/python_runtime/BUILD b/third_party/xla/third_party/python_runtime/BUILD index 14210ebf684fa9..2a1609191fe351 100644 --- a/third_party/xla/third_party/python_runtime/BUILD +++ b/third_party/xla/third_party/python_runtime/BUILD @@ -5,5 +5,4 @@ package(default_visibility = ["//visibility:public"]) alias( name = "headers", actual = "@local_config_python//:python_headers", - visibility = ["//visibility:public"], ) diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch index 8a53bd5f19638f..c8aa227b5a35ef 100755 --- a/third_party/xla/third_party/stablehlo/temporary.patch +++ b/third_party/xla/third_party/stablehlo/temporary.patch @@ -1,7 +1,7 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt --- stablehlo/CMakeLists.txt +++ stablehlo/CMakeLists.txt -@@ -13,131 +13,20 @@ +@@ -13,153 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # @@ -25,6 +25,11 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt -if(POLICY CMP0116) - cmake_policy(SET CMP0116 OLD) -endif() +- +-# Support for return(PROPAGATE ...) in functions. +-if (POLICY CMP0140) +- cmake_policy(SET CMP0140 NEW) +-endif() +# This build of StableHLO is meant to be embedded in MLIR-HLO. +# As a result, its root CMakeLists.txt is different from the original +# CMakeLists.txt from https://github.com/openxla/stablehlo. @@ -39,6 +44,9 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt -option(STABLEHLO_BUILD_EMBEDDED "Build StableHLO as part of another project" OFF) -option(STABLEHLO_ENABLE_BINDINGS_PYTHON "Enables StableHLO Python bindings" OFF) -option(STABLEHLO_ENABLE_STRICT_BUILD "Build StableHLO with strict warnings and warnings as errors" OFF) +-option(STABLEHLO_ENABLE_SANITIZER "Enable a sanitizer [OFF, address]" OFF) +-option(STABLEHLO_ENABLE_SPLIT_DWARF "Enable split DWARF if the platform supports it" OFF) +-option(STABLEHLO_ENABLE_LLD "Use LLD as the linker if available" OFF) -#------------------------------------------------------------------------------- -# Project setup and globals @@ -55,29 +63,6 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt - set(CMAKE_CXX_STANDARD 17) -endif() - --# Build with ccache if the package is present --set(LLVM_CCACHE_BUILD OFF CACHE BOOL "Set to ON for a ccache enabled build") --if(LLVM_CCACHE_BUILD) -- find_program(CCACHE_PROGRAM ccache) -- if(CCACHE_PROGRAM) -- set(LLVM_CCACHE_MAXSIZE "" CACHE STRING "Size of ccache") -- set(LLVM_CCACHE_DIR "" CACHE STRING "Directory to keep ccached data") -- set(LLVM_CCACHE_PARAMS "CCACHE_CPP2=yes CCACHE_HASHDIR=yes" -- CACHE STRING "Parameters to pass through to ccache") -- -- set(CCACHE_PROGRAM "${LLVM_CCACHE_PARAMS} ${CCACHE_PROGRAM}") -- if (LLVM_CCACHE_MAXSIZE) -- set(CCACHE_PROGRAM "CCACHE_MAXSIZE=${LLVM_CCACHE_MAXSIZE} ${CCACHE_PROGRAM}") -- endif() -- if (LLVM_CCACHE_DIR) -- set(CCACHE_PROGRAM "CCACHE_DIR=${LLVM_CCACHE_DIR} ${CCACHE_PROGRAM}") -- endif() -- set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ${CCACHE_PROGRAM}) -- else() -- message(FATAL_ERROR "Unable to find the program ccache. Set LLVM_CCACHE_BUILD to OFF") -- endif() --endif() -- -#------------------------------------------------------------------------------- -# MLIR/LLVM Configuration -#------------------------------------------------------------------------------- @@ -114,10 +99,39 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt - message(STATUS "Building StableHLO embedded in another project") -endif() - +-# Add the CMake modules specific to StableHLO +-list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/cmake") +- -if(LLVM_ENABLE_ZLIB) - find_package(ZLIB) -endif() - +-#------------------------------------------------------------------------------- +-# Performance configuration +-#------------------------------------------------------------------------------- +- +-include(CheckCXXCompilerFlag) +-include(CheckLinkerFlag) +-if (STABLEHLO_ENABLE_LLD) +- message(STATUS "Enabling LLD as the linker") +- add_link_options("-fuse-ld=lld") +-endif() +- +-if(STABLEHLO_ENABLE_SPLIT_DWARF) +- check_cxx_compiler_flag(-gsplit-dwarf STABLEHLO_SUPPORTS_SPLIT_DWARF) +- if (STABLEHLO_SUPPORTS_SPLIT_DWARF) +- message(STATUS "Enabling split-dwarf build") +- add_compile_options(-gsplit-dwarf -ggnu-pubnames) +- endif() +- check_linker_flag(CXX "-Wl,--gdb-index" STABLEHLO_SUPPORTS_GDB_INDEX) +- # If we set LLD it doesn't seem to affect the check_linker_flag above. +- # Account for it with the generator expression OR +- if (STABLEHLO_SUPPORTS_GDB_INDEX OR STABLEHLO_ENABLE_LLD) +- message(STATUS "Enabling GDB index in binary") +- add_link_options("-Wl,--gdb-index") +- endif() +-endif() +- -include(TableGen) -include(AddLLVM) -include(AddMLIR) @@ -129,6 +143,14 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt -link_directories(${LLVM_BUILD_LIBRARY_DIR}) -add_definitions(${LLVM_DEFINITIONS}) - +- +-#------------------------------------------------------------------------------- +-# Sanitizer configuration +-#------------------------------------------------------------------------------- +- +-include(SetupSanitizers) +-setup_sanitizers() +- -#------------------------------------------------------------------------------- -# Python configuration -#------------------------------------------------------------------------------- @@ -141,6 +163,27 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt #------------------------------------------------------------------------------- # Directory setup +diff --ruN a/stablehlo/MODULE.bazel.lock b/stablehlo/MODULE.bazel.lock +--- stablehlo/MODULE.bazel.lock ++++ stablehlo/MODULE.bazel.lock +@@ -1,3 +1,17 @@ ++# Copyright 2024 The StableHLO Authors. All Rights Reserved. ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# https://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++ + { + "lockFileVersion": 3, + "moduleFileHash": "836f0a7d2276ed93403f104a10008b94ec7e7f81b8d6921cea287f0a6d364efa", diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists.txt --- stablehlo/stablehlo/CMakeLists.txt +++ stablehlo/stablehlo/CMakeLists.txt @@ -152,18 +195,927 @@ diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists add_subdirectory(integrations) add_subdirectory(reference) add_subdirectory(tests) -diff --ruN a/stablehlo/stablehlo/api/PortableApi.h b/stablehlo/stablehlo/api/PortableApi.h ---- stablehlo/stablehlo/api/PortableApi.h -+++ stablehlo/stablehlo/api/PortableApi.h -@@ -27,7 +27,7 @@ +diff --ruN a/stablehlo/stablehlo/dialect/AssemblyFormat.cpp b/stablehlo/stablehlo/dialect/AssemblyFormat.cpp +--- stablehlo/stablehlo/dialect/AssemblyFormat.cpp ++++ stablehlo/stablehlo/dialect/AssemblyFormat.cpp +@@ -16,15 +16,28 @@ + #include "stablehlo/dialect/AssemblyFormat.h" + + #include ++#include + #include - /// Return the current version for portable API. - /// Increments on all meaningful changes to this file. --inline int64_t getApiVersion() { return 5; } -+inline int64_t getApiVersion() { return 6; } + #include "llvm/ADT/ArrayRef.h" + #include "llvm/ADT/STLExtras.h" ++#include "llvm/ADT/StringExtras.h" ++#include "llvm/Support/Debug.h" + #include "llvm/Support/ErrorHandling.h" + #include "llvm/Support/Regex.h" ++#include "llvm/Support/SMLoc.h" ++#include "mlir/IR/Builders.h" + #include "mlir/IR/BuiltinAttributes.h" + #include "mlir/IR/BuiltinTypeInterfaces.h" ++#include "mlir/IR/OpImplementation.h" ++#include "mlir/IR/OperationSupport.h" ++#include "mlir/IR/Region.h" ++#include "mlir/IR/TypeUtilities.h" ++#include "mlir/IR/ValueRange.h" ++#include "mlir/Support/LLVM.h" + #include "mlir/Support/LogicalResult.h" ++ ++#define DEBUG_TYPE "hlo-assembly" + + namespace mlir { + namespace hlo { +@@ -212,6 +225,343 @@ + return success(); + } - // Get the current StableHLO version. ++namespace { ++void createArgs(ArrayRef operands, ++ ArrayRef types, ++ SmallVector& args) { ++ for (auto argAndType : llvm::zip(operands, types)) { ++ auto& arg = args.emplace_back(); ++ arg.ssaName = std::get<0>(argAndType); ++ arg.type = std::get<1>(argAndType); ++ } ++} ++ ++Operation* createReturn(OpBuilder& builder, Dialect* dialect, Location loc, ++ ResultRange operands) { ++ auto returnOpName = dialect->getNamespace() + ".return"; ++ OperationState returnOpState(loc, returnOpName.str()); ++ returnOpState.operands.append(operands.begin(), operands.end()); ++ return builder.create(returnOpState); ++} ++ ++bool hasSameOperandAndResultTypes(Operation& op) { ++ Type expected; ++ if (op.getNumResults() != 0) expected = op.getResult(0).getType(); ++ if (op.getNumOperands() != 0) expected = op.getOperand(0).getType(); ++ if (!expected) return false; ++ ++ auto typeMatch = [&](Type actual) { return actual == expected; }; ++ return llvm::all_of(op.getOperandTypes(), typeMatch) && ++ llvm::all_of(op.getResultTypes(), typeMatch); ++} ++ ++// Checks the following eligibility criteria for compact printing of reduce: ++// E1. The reduce-op wraps a single inner-op in the associated region. ++// E2. The single operation is a commutative binary-op from the dialect, zero ++// region, producing single result such that the operands and result all ++// have the same type. ++// E3. The reduce-op consist of at least one input-operand; The operand-types of ++// inner-op should be derived trivially from the element-type of reduce-op's ++// first input-operand. ++// E4. The arguments of the region's only basic block are forwarded perfectly ++// to inner-op's operands. ++// E5. The single operation result is perfectly forwarded to the reduce op ++// return. ++static bool isReduceEligibleForCompactPrint(Operation* op, ValueRange inputs, ++ Region& body) { ++ // Check E1. ++ LLVM_DEBUG(llvm::dbgs() << "Checking ReduceOp compact print E1\n"); ++ auto& block = body.front(); ++ if (!hasSingleElement(block.without_terminator())) return false; ++ ++ Operation& innerOp = *block.begin(); ++ ++ // Check E2. ++ LLVM_DEBUG(llvm::dbgs() << "Checking ReduceOp compact print E2\n"); ++ if (innerOp.getDialect() != op->getDialect()) return false; ++ ++ if (innerOp.getNumOperands() != 2 || ++ !innerOp.hasTrait() || ++ !hasSameOperandAndResultTypes(innerOp) || ++ (!innerOp.hasTrait() && ++ !innerOp.hasTrait()) || ++ !innerOp.hasTrait()) ++ return false; ++ ++ // Check E3. ++ LLVM_DEBUG(llvm::dbgs() << "Checking ReduceOp compact print E3\n"); ++ if (inputs.empty()) return false; ++ ++ auto elemType = inputs[0].getType().cast().getElementType(); ++ auto expectedInnerOpType = RankedTensorType::get(/*shape=*/{}, elemType); ++ if (innerOp.getOperands()[0].getType() != expectedInnerOpType) return false; ++ ++ // Check E4. ++ LLVM_DEBUG(llvm::dbgs() << "Checking ReduceOp compact print E4\n"); ++ if (!llvm::equal(block.getArguments(), innerOp.getOperands())) return false; ++ ++ // Check E5. ++ LLVM_DEBUG(llvm::dbgs() << "Checking ReduceOp compact print E5\n"); ++ auto retOp = block.getTerminator(); ++ if (!retOp->getName().stripDialect().equals("return")) return false; ++ ++ return llvm::equal(innerOp.getResults(), retOp->getOperands()); ++} ++} // namespace ++ ++void printReduceOp(OpAsmPrinter& p, Operation* op, ValueRange inputs, ++ ArrayRef dimensions, Region& body) { ++ { ++ // Print the pairs of operands under the form: ++ // (%arg0 init: %arg3), (%arg1 init: %arg4), (%arg2 init: %arg5) ++ StringRef comma = ""; ++ int numOperandPairs = op->getNumOperands() / 2; ++ for (int opId : llvm::seq(0, numOperandPairs)) { ++ p << comma << "(" << op->getOperand(opId) ++ << " init: " << op->getOperand(opId + numOperandPairs) << ")"; ++ comma = ", "; ++ } ++ } ++ ++ // If the reduce-op is eligible for compact printing, we emit the one-liner: ++ // stablehlo.reduce applies across dimensions = [...] : ++ // Note: We are not printing the function type of reduction operation. We ++ // have some simplifying assumptions (refer to IsEligibleForCompactPrint::E3) ++ // to derive the type from that of reduce-op. ++ if (isReduceEligibleForCompactPrint(op, inputs, body)) { ++ Operation& innerOp = body.front().front(); ++ p << " applies "; ++ llvm::printEscapedString(innerOp.getName().getStringRef(), p.getStream()); ++ p << " across dimensions = ["; ++ llvm::interleaveComma(dimensions, p); ++ p << "]"; ++ p.printOptionalAttrDict(op->getAttrs(), {"dimensions"}); ++ p << " : "; ++ p.printFunctionalType(op); ++ } else { ++ p << " across dimensions = ["; ++ llvm::interleaveComma(dimensions, p); ++ p << "]"; ++ p.printOptionalAttrDict(op->getAttrs(), {"dimensions"}); ++ p << " : "; ++ p.printFunctionalType(op); ++ p.printNewline(); ++ p << " reducer"; ++ { ++ // Print the pairs of block operands under the form: ++ // (%arg0_elt, %arg0_acc) (%arg1_elt, %arg1_acc): ++ Block& reducer = body.front(); ++ int numOperandPairs = op->getNumOperands() / 2; ++ for (int opId : llvm::seq(0, numOperandPairs)) { ++ p << "("; ++ p.printRegionArgument(reducer.getArgument(opId)); ++ p << ", "; ++ p.printRegionArgument(reducer.getArgument(opId + numOperandPairs)); ++ p << ") "; ++ } ++ } ++ p << ' '; ++ p.printRegion(body, /*printEntryBlockArgs=*/false); ++ } ++} ++ ++ParseResult parseReduceOp( ++ OpAsmParser& parser, OperationState& result, ++ std::function)> createDimensions) { ++ llvm::SMLoc loc = parser.getCurrentLocation(); ++ Location currLocation = parser.getEncodedSourceLoc(loc); ++ ++ // Parse the operands of reduce-op, this is a list of pair under the form: ++ // (%arg0 init: %arg3), (%arg1 init: %arg4), (%arg2 init: %arg5) ++ // Each input to reduce is paired with its init value, even though in memory ++ // they are stored with the input first and the init values after. ++ SmallVector operands; ++ SmallVector initOperands; ++ do { ++ (void)parser.parseOptionalComma(); ++ if (parser.parseOptionalLParen()) break; ++ OpAsmParser::UnresolvedOperand operand, initOperand; ++ if (parser.parseOperand(operand) || parser.parseKeyword("init") || ++ parser.parseColon() || parser.parseOperand(initOperand) || ++ parser.parseRParen()) ++ return failure(); ++ operands.push_back(operand); ++ initOperands.push_back(initOperand); ++ } while (true); ++ operands.append(initOperands); ++ ++ // Check if we are parsing the compact version of reduce-op: ++ // stablehlo.reduce applies across dimensions = [...] : ++ // else parse the "region-based" variant. ++ if (failed(parser.parseOptionalKeyword("applies"))) { ++ // Parse the inner-op dimensions, reduce-op's function-type and ++ // optional location. ++ SmallVector dimensions; ++ auto parseDim = [&]() -> ParseResult { ++ if (parser.parseInteger(dimensions.emplace_back())) return failure(); ++ return success(); ++ }; ++ ++ FunctionType reduceOpFnType; ++ if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") || ++ parser.parseEqual() || ++ parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, ++ parseDim) || ++ parser.parseOptionalAttrDict(result.attributes) || ++ parser.parseColon() || parser.parseType(reduceOpFnType) || ++ parser.parseKeyword("reducer")) ++ return failure(); ++ OpBuilder builder(parser.getBuilder().getContext()); ++ result.addAttribute("dimensions", createDimensions(builder, dimensions)); ++ ++ // Parse the "reducer" region now. ++ SmallVector reducerOperands; ++ SmallVector reducerInitOperands; ++ SmallVector reducerTypes; ++ SmallVector reducerInitTypes; ++ SmallVector, 2> reducerLocs; ++ SmallVector, 2> reducerInitLocs; ++ auto parseBlockOperand = ++ [&](SmallVectorImpl& operands, ++ SmallVectorImpl& types, ++ SmallVectorImpl>& locs) -> ParseResult { ++ OpAsmParser::UnresolvedOperand operand; ++ Type type; ++ std::optional loc; ++ if (parser.parseOperand(operand, /*allowResultNumber=*/false) || ++ parser.parseColon() || parser.parseType(type) || ++ parser.parseOptionalLocationSpecifier(loc)) ++ return failure(); ++ operands.push_back(operand); ++ types.push_back(type); ++ locs.push_back(loc); ++ return success(); ++ }; ++ do { ++ if (failed(parser.parseOptionalLParen())) break; ++ if (parseBlockOperand(reducerOperands, reducerTypes, reducerLocs) || ++ parser.parseComma() || ++ parseBlockOperand(reducerInitOperands, reducerInitTypes, ++ reducerInitLocs) || ++ parser.parseRParen()) ++ return failure(); ++ } while (true); ++ reducerOperands.append(reducerInitOperands); ++ reducerTypes.append(reducerInitTypes); ++ reducerLocs.append(reducerInitLocs); ++ result.addTypes(reduceOpFnType.getResults()); ++ SmallVector reducerArgs; ++ createArgs(reducerOperands, reducerTypes, reducerArgs); ++ ++ // Derive the SSA-values for reduce-op's operands and parse the region, and ++ // the optional trailing location. ++ std::optional trailingLoc; ++ if (parser.resolveOperands(operands, reduceOpFnType.getInputs(), loc, ++ result.operands) || ++ parser.parseRegion(*result.addRegion(), reducerArgs)) ++ return failure(); ++ // Set the individual block arguments. ++ for (auto argAndLoc : ++ llvm::zip(result.regions.front()->front().getArguments(), reducerLocs)) ++ if (std::get<1>(argAndLoc)) ++ std::get<0>(argAndLoc).setLoc(std::get<1>(argAndLoc).value()); ++ result.location = trailingLoc.value_or(currLocation); ++ return success(); ++ } ++ ++ // Parse the inner-op name and check if the contract on inner-op ++ // mentioned in "isEligibleForCompactPrint::E2" for pretty-printing is met. ++ FailureOr innerOpNameInfo = parser.parseCustomOperationName(); ++ if (failed(innerOpNameInfo)) return failure(); ++ ++ StringRef innerOpName = innerOpNameInfo->getStringRef(); ++ Dialect* innerOpDialect = innerOpNameInfo->getDialect(); ++ StringRef reduceOpDialect = result.name.getDialectNamespace(); ++ LLVM_DEBUG(llvm::dbgs() << "Reduce: " << reduceOpDialect << "\n"); ++ LLVM_DEBUG(llvm::dbgs() << "inner: " << innerOpDialect->getNamespace() ++ << "\n"); ++ if (!innerOpDialect || ++ !innerOpDialect->getNamespace().equals(reduceOpDialect) || ++ !innerOpNameInfo->hasTrait::Impl>() || ++ !innerOpNameInfo->hasTrait() || ++ (!innerOpNameInfo->hasTrait() && ++ !innerOpNameInfo->hasTrait()) || ++ !innerOpNameInfo->hasTrait()) { ++ parser.emitError(loc, ++ "expected the inner-op to be a commutative binary-op that " ++ "matching the reduce op dialect, with zero region, " ++ "producing single result"); ++ return failure(); ++ } ++ ++ // Parse the inner-op dimensions, reduce-op's function-type and ++ // optional location. ++ SmallVector dimensions; ++ auto parseDim = [&]() -> ParseResult { ++ if (parser.parseInteger(dimensions.emplace_back())) return failure(); ++ return success(); ++ }; ++ ++ std::optional explicitLoc; ++ FunctionType reduceOpFnType; ++ if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") || ++ parser.parseEqual() || ++ parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, parseDim) || ++ parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || ++ parser.parseType(reduceOpFnType) || ++ parser.parseOptionalLocationSpecifier(explicitLoc)) ++ return failure(); ++ ++ if (!reduceOpFnType || reduceOpFnType.getInputs().empty()) { ++ if (!reduceOpFnType) return parser.emitError(loc, "expected function type"); ++ return parser.emitError(loc, ++ "input types missing in reduce-op function type"); ++ } ++ ++ // If location of reduce-op is explicitly provided, then use it; Else use ++ // the parser's current location. ++ Location reduceOpLoc = explicitLoc.value_or(currLocation); ++ ++ // Derive the SSA-values for reduce-op's operands. ++ if (parser.resolveOperands(operands, reduceOpFnType.getInputs(), loc, ++ result.operands)) ++ return failure(); ++ ++ // Derive the type of inner-op from that of reduce-op's input operand. ++ auto innerOpType = RankedTensorType::get( ++ /*shape=*/{}, getElementTypeOrSelf(reduceOpFnType.getInput(0))); ++ ++ // Add a region for reduce-op. ++ Region& region = *result.addRegion(); ++ ++ // Create a basic-block inside reduce-op's region. ++ Block& block = region.emplaceBlock(); ++ auto lhs = block.addArgument(innerOpType, reduceOpLoc); ++ auto rhs = block.addArgument(innerOpType, reduceOpLoc); ++ ++ // Create and insert an "inner-op" operation in the block. ++ OpBuilder builder(parser.getBuilder().getContext()); ++ builder.setInsertionPointToStart(&block); ++ ++ OperationState innerOpState(reduceOpLoc, innerOpName); ++ innerOpState.operands.push_back(lhs); ++ innerOpState.operands.push_back(rhs); ++ innerOpState.addTypes(innerOpType); ++ ++ Operation* innerOp = builder.create(innerOpState); ++ ++ // Insert a return statement in the block returning the inner-op's result. ++ createReturn(builder, innerOp->getDialect(), innerOp->getLoc(), ++ innerOp->getResults()); ++ ++ // Populate the reduce-op operation-state with result-type, location, and ++ // dimension attribute. ++ result.addTypes(reduceOpFnType.getResults()); ++ result.location = innerOp->getLoc(); ++ result.addAttribute("dimensions", createDimensions(builder, dimensions)); ++ return success(); ++} ++ + void printSelectOpType(OpAsmPrinter& p, Operation* op, ShapedType pred, + ShapedType onTrue, ShapedType onFalse, + ShapedType result) { +@@ -250,6 +600,63 @@ + auto fnType = types[0].cast(); + return assignFromFunctionType(parser, loc, {&pred, &onTrue, &onFalse}, result, + fnType); ++} ++ ++void printWhileOp(OpAsmPrinter& p, Operation* op, Region& cond, Region& body) { ++ p << '('; ++ llvm::interleaveComma(llvm::zip(body.getArguments(), op->getOperands()), p, ++ [&](auto zip) { ++ p.printOperand(std::get<0>(zip)); ++ p << " = "; ++ p.printOperand(std::get<1>(zip)); ++ }); ++ p << ")"; ++ if (op->getNumOperands()) { ++ p << " : "; ++ llvm::interleaveComma(op->getOperandTypes(), p); ++ } ++ p.printOptionalAttrDictWithKeyword(op->getAttrs()); ++ p.printNewline(); ++ p << " cond "; ++ p.printRegion(cond, /*printEntryBlockArgs=*/false); ++ p << " do "; ++ p.printRegion(body, /*printEntryBlockArgs=*/false); ++} ++ ++ParseResult parseWhileOp(OpAsmParser& parser, OperationState& result) { ++ llvm::SMLoc loc = parser.getCurrentLocation(); ++ // Parse the operands of the while: these are of the form: ++ // %iter_arg = %init_val ++ // where %iter_arg is the name of the block argument in the cond/body blocks ++ // and %init_val is the actual operand. ++ SmallVector operands; ++ SmallVector iterArgs; ++ if (parser.parseLParen()) return failure(); ++ do { ++ if (succeeded(parser.parseOptionalRParen())) break; ++ OpAsmParser::UnresolvedOperand operand, iterArg; ++ if (parser.parseOperand(iterArg) || parser.parseEqual() || ++ parser.parseOperand(operand)) ++ return failure(); ++ iterArgs.push_back(iterArg); ++ operands.push_back(operand); ++ if (succeeded(parser.parseOptionalRParen())) break; ++ if (failed(parser.parseComma())) return failure(); ++ } while (true); ++ if (!operands.empty()) { ++ if (parser.parseColon() || parser.parseTypeList(result.types)) ++ return failure(); ++ } ++ SmallVector args; ++ createArgs(iterArgs, result.types, args); ++ if (parser.resolveOperands(operands, result.types, loc, result.operands) || ++ parser.parseOptionalAttrDictWithKeyword(result.attributes) || ++ parser.parseKeyword("cond") || ++ parser.parseRegion(*result.addRegion(), args) || ++ parser.parseKeyword("do") || ++ parser.parseRegion(*result.addRegion(), args)) ++ return failure(); ++ return success(); + } + + //===----------------------------------------------------------------------===// +diff --ruN a/stablehlo/stablehlo/dialect/AssemblyFormat.h b/stablehlo/stablehlo/dialect/AssemblyFormat.h +--- stablehlo/stablehlo/dialect/AssemblyFormat.h ++++ stablehlo/stablehlo/dialect/AssemblyFormat.h +@@ -16,19 +16,25 @@ + #ifndef STABLEHLO_DIALECT_ASSEMBLYFORMAT_H + #define STABLEHLO_DIALECT_ASSEMBLYFORMAT_H + ++#include ++#include ++ + #include "llvm/ADT/ArrayRef.h" + #include "llvm/ADT/SmallVector.h" +-#include "llvm/ADT/StringRef.h" + #include "mlir/IR/Attributes.h" + #include "mlir/IR/Builders.h" + #include "mlir/IR/BuiltinAttributes.h" ++#include "mlir/IR/BuiltinTypeInterfaces.h" + #include "mlir/IR/Dialect.h" + #include "mlir/IR/DialectImplementation.h" +-#include "mlir/IR/MLIRContext.h" + #include "mlir/IR/OpImplementation.h" + #include "mlir/IR/Operation.h" ++#include "mlir/IR/OperationSupport.h" ++#include "mlir/IR/Region.h" + #include "mlir/IR/TypeRange.h" + #include "mlir/IR/Types.h" ++#include "mlir/IR/ValueRange.h" ++#include "mlir/Support/LLVM.h" + #include "mlir/Support/LogicalResult.h" + #include "stablehlo/dialect/Base.h" + +@@ -154,6 +160,15 @@ + ParseResult parseComplexOpType(OpAsmParser& parser, Type& lhs, Type& rhs, + Type& result); + ++// Print reduce with or without compact printing ++void printReduceOp(OpAsmPrinter& p, Operation* op, ValueRange inputs, ++ ArrayRef dimensions, Region& body); ++ ++// Parse reduce with or without compact parsing ++ParseResult parseReduceOp( ++ OpAsmParser& parser, OperationState& result, ++ std::function)> createDimensions); ++ + // SelectOpType - only print the condition and result type when branch types + // match the result type. // +@@ -170,15 +185,27 @@ + ParseResult parseSelectOpType(OpAsmParser& parser, Type& pred, Type& onTrue, + Type& onFalse, Type& result); + ++// Print a `while` op. ++// ++// op ::= `stablehlo.while` `(` assignment-list `)` `:` types attribute-dict ++// `cond` region ++// `do` region ++// assignment-list ::= assignment | assignment `,` assignment-list ++// assignment ::= ssa-value `=` ssa-value ++void printWhileOp(OpAsmPrinter& p, Operation* op, Region& cond, Region& body); ++ ++// Parse reduce with or without compact parsing ++ParseResult parseWhileOp(OpAsmParser& parser, OperationState& result); ++ + //===----------------------------------------------------------------------===// + // Attribute Printers and Parsers + //===----------------------------------------------------------------------===// + + // SliceRanges - Used to print multi-dimensional ranges for slice. + void printSliceRanges(OpAsmPrinter& p, Operation* op, +- ArrayRef startIndices, +- ArrayRef limitIndices, +- ArrayRef strides); ++ llvm::ArrayRef startIndices, ++ llvm::ArrayRef limitIndices, ++ llvm::ArrayRef strides); + + ParseResult parseSliceRanges(OpAsmParser& parser, + DenseI64ArrayAttr& startIndices, +diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.cpp b/stablehlo/stablehlo/dialect/StablehloOps.cpp +--- stablehlo/stablehlo/dialect/StablehloOps.cpp ++++ stablehlo/stablehlo/dialect/StablehloOps.cpp +@@ -99,16 +99,6 @@ + return dialect->getRegisteredInterface(); + } + +-void createArgs(ArrayRef operands, +- ArrayRef types, +- SmallVector& args) { +- for (auto argAndType : llvm::zip(operands, types)) { +- auto& arg = args.emplace_back(); +- arg.ssaName = std::get<0>(argAndType); +- arg.type = std::get<1>(argAndType); +- } +-} +- + // Returns a new scalar integer value having type `type`. Here `type` must be + // an integer or index type. + Value maybeCastTo(OpBuilder& b, Location loc, Value value, Type type) { +@@ -1472,305 +1462,16 @@ + // ReduceOp + //===----------------------------------------------------------------------===// + +-bool hasSameOperandAndResultTypes(Operation& op) { +- Type expected; +- if (op.getNumResults() != 0) expected = op.getResult(0).getType(); +- if (op.getNumOperands() != 0) expected = op.getOperand(0).getType(); +- if (!expected) return false; +- +- auto typeMatch = [&](Type actual) { return actual == expected; }; +- return llvm::all_of(op.getOperandTypes(), typeMatch) && +- llvm::all_of(op.getResultTypes(), typeMatch); +-} +- +-// Checks the following eligibility criteria for compact printing of reduce: +-// E1. The reduce-op wraps a single inner-op in the associated region. +-// E2. The single operation is a commutative binary-op from the dialect, zero +-// region, producing single result such that the operands and result all +-// have the same type. +-// E3. The reduce-op consist of at least one input-operand; The operand-types of +-// inner-op should be derived trivially from the element-type of reduce-op's +-// first input-operand. +-// E4. The arguments of the region's only basic block are forwarded perfectly +-// to inner-op's operands. +-// E5. The single operation result is perfectly forwarded to the reduce op +-// return. +-static bool isEligibleForCompactPrint(ReduceOp op) { +- // Check E1. +- auto& block = op.getBody().front(); +- if (!hasSingleElement(block.without_terminator())) return false; +- +- Operation& innerOp = *block.begin(); +- +- // Check E2. +- if (innerOp.getDialect() != op->getDialect()) return false; +- +- if (innerOp.getNumOperands() != 2 || +- !innerOp.hasTrait() || +- !hasSameOperandAndResultTypes(innerOp) || +- !innerOp.hasTrait() || +- !innerOp.hasTrait()) +- return false; +- +- // Check E3. +- if (op.getInputs().empty()) return false; +- +- auto elemType = +- op.getInputs()[0].getType().cast().getElementType(); +- auto expectedInnerOpType = RankedTensorType::get(/*shape=*/{}, elemType); +- if (innerOp.getOperands()[0].getType() != expectedInnerOpType) return false; +- +- // Check E4. +- if (!llvm::equal(block.getArguments(), innerOp.getOperands())) return false; +- +- // Check E5. +- auto retOp = dyn_cast(block.getTerminator()); +- if (!retOp) return false; +- +- return llvm::equal(innerOp.getResults(), retOp.getOperands()); +-} +- + void ReduceOp::print(OpAsmPrinter& p) { +- { +- // Print the pairs of operands under the form: +- // (%arg0 init: %arg3), (%arg1 init: %arg4), (%arg2 init: %arg5) +- StringRef comma = ""; +- int numOperandPairs = getNumOperands() / 2; +- for (int opId : llvm::seq(0, numOperandPairs)) { +- p << comma << "(" << getOperand(opId) +- << " init: " << getOperand(opId + numOperandPairs) << ")"; +- comma = ", "; +- } +- } +- +- // If the reduce-op is eligible for compact printing, we emit the one-liner: +- // stablehlo.reduce applies across dimensions = [...] : +- // Note: We are not printing the function type of reduction operation. We +- // have some simplifying assumptions (refer to IsEligibleForCompactPrint::E3) +- // to derive the type from that of reduce-op. +- if (isEligibleForCompactPrint(*this)) { +- Operation& innerOp = getBody().front().front(); +- p << " applies "; +- printEscapedString(innerOp.getName().getStringRef(), p.getStream()); +- +- p << " across dimensions = ["; +- llvm::interleaveComma(getDimensions(), p); +- p << "]"; +- p.printOptionalAttrDict(getOperation()->getAttrs(), {"dimensions"}); +- p << " : "; +- p.printFunctionalType(*this); +- } else { +- p << " across dimensions = ["; +- llvm::interleaveComma(getDimensions(), p); +- p << "]"; +- p.printOptionalAttrDict(getOperation()->getAttrs(), {"dimensions"}); +- p << " : "; +- p.printFunctionalType(*this); +- p.printNewline(); +- p << " reducer"; +- { +- // Print the pairs of block operands under the form: +- // (%arg0_elt, %arg0_acc) (%arg1_elt, %arg1_acc): +- Block& reducer = getBody().front(); +- int numOperandPairs = getNumOperands() / 2; +- for (int opId : llvm::seq(0, numOperandPairs)) { +- p << "("; +- p.printRegionArgument(reducer.getArgument(opId)); +- p << ", "; +- p.printRegionArgument(reducer.getArgument(opId + numOperandPairs)); +- p << ") "; +- } +- } +- p << ' '; +- p.printRegion(getBody(), /*printEntryBlockArgs=*/false); +- } ++ hlo::printReduceOp(p, getOperation(), getInputs(), getDimensions(), ++ getBody()); + } + + ParseResult ReduceOp::parse(OpAsmParser& parser, OperationState& result) { +- llvm::SMLoc loc = parser.getCurrentLocation(); +- Location currLocation = parser.getEncodedSourceLoc(loc); +- +- // Parse the operands of reduce-op, this is a list of pair under the form: +- // (%arg0 init: %arg3), (%arg1 init: %arg4), (%arg2 init: %arg5) +- // Each input to reduce is paired with its init value, even though in memory +- // they are stored with the input first and the init values after. +- SmallVector operands; +- SmallVector initOperands; +- do { +- (void)parser.parseOptionalComma(); +- if (parser.parseOptionalLParen()) break; +- OpAsmParser::UnresolvedOperand operand, initOperand; +- if (parser.parseOperand(operand) || parser.parseKeyword("init") || +- parser.parseColon() || parser.parseOperand(initOperand) || +- parser.parseRParen()) +- return failure(); +- operands.push_back(operand); +- initOperands.push_back(initOperand); +- } while (true); +- operands.append(initOperands); +- +- // Check if we are parsing the compact version of reduce-op: +- // stablehlo.reduce applies across dimensions = [...] : +- // else parse the "region-based" variant. +- if (failed(parser.parseOptionalKeyword("applies"))) { +- // Parse the inner-op dimensions, reduce-op's function-type and +- // optional location. +- SmallVector dimensions; +- auto parseDim = [&]() -> ParseResult { +- if (parser.parseInteger(dimensions.emplace_back())) return failure(); +- return success(); +- }; +- +- FunctionType reduceOpFnType; +- if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") || +- parser.parseEqual() || +- parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, +- parseDim) || +- parser.parseOptionalAttrDict(result.attributes) || +- parser.parseColon() || parser.parseType(reduceOpFnType) || +- parser.parseKeyword("reducer")) +- return failure(); +- OpBuilder builder(parser.getBuilder().getContext()); +- result.addAttribute("dimensions", builder.getDenseI64ArrayAttr(dimensions)); +- +- // Parse the "reducer" region now. +- SmallVector reducerOperands; +- SmallVector reducerInitOperands; +- SmallVector reducerTypes; +- SmallVector reducerInitTypes; +- SmallVector, 2> reducerLocs; +- SmallVector, 2> reducerInitLocs; +- auto parseBlockOperand = +- [&](SmallVectorImpl& operands, +- SmallVectorImpl& types, +- SmallVectorImpl>& locs) -> ParseResult { +- OpAsmParser::UnresolvedOperand operand; +- Type type; +- std::optional loc; +- if (parser.parseOperand(operand, /*allowResultNumber=*/false) || +- parser.parseColon() || parser.parseType(type) || +- parser.parseOptionalLocationSpecifier(loc)) +- return failure(); +- operands.push_back(operand); +- types.push_back(type); +- locs.push_back(loc); +- return success(); +- }; +- do { +- if (failed(parser.parseOptionalLParen())) break; +- if (parseBlockOperand(reducerOperands, reducerTypes, reducerLocs) || +- parser.parseComma() || +- parseBlockOperand(reducerInitOperands, reducerInitTypes, +- reducerInitLocs) || +- parser.parseRParen()) +- return failure(); +- } while (true); +- reducerOperands.append(reducerInitOperands); +- reducerTypes.append(reducerInitTypes); +- reducerLocs.append(reducerInitLocs); +- result.addTypes(reduceOpFnType.getResults()); +- SmallVector reducerArgs; +- createArgs(reducerOperands, reducerTypes, reducerArgs); +- +- // Derive the SSA-values for reduce-op's operands and parse the region, and +- // the optional trailing location. +- std::optional trailingLoc; +- if (parser.resolveOperands(operands, reduceOpFnType.getInputs(), loc, +- result.operands) || +- parser.parseRegion(*result.addRegion(), reducerArgs)) +- return failure(); +- // Set the individual block arguments. +- for (auto argAndLoc : +- llvm::zip(result.regions.front()->front().getArguments(), reducerLocs)) +- if (std::get<1>(argAndLoc)) +- std::get<0>(argAndLoc).setLoc(std::get<1>(argAndLoc).value()); +- result.location = trailingLoc.value_or(currLocation); +- return success(); +- } +- +- // Parse the inner-op name and check if the contract on inner-op +- // mentioned in "isEligibleForCompactPrint::E2" for pretty-printing is met. +- FailureOr innerOpNameInfo = parser.parseCustomOperationName(); +- if (failed(innerOpNameInfo)) return failure(); +- +- StringRef innerOpName = innerOpNameInfo->getStringRef(); +- Dialect* innerOpDialect = innerOpNameInfo->getDialect(); +- if (!innerOpDialect || !innerOpDialect->getNamespace().equals("stablehlo") || +- !innerOpNameInfo->hasTrait::Impl>() || +- !innerOpNameInfo->hasTrait() || +- !innerOpNameInfo->hasTrait() || +- !innerOpNameInfo->hasTrait()) { +- parser.emitError(loc, +- "expected the inner-op to be a commutative binary-op from " +- "stablehlo dialect, zero region, producing single result"); +- return failure(); +- } +- +- // Parse the inner-op dimensions, reduce-op's function-type and +- // optional location. +- SmallVector dimensions; +- auto parseDim = [&]() -> ParseResult { +- if (parser.parseInteger(dimensions.emplace_back())) return failure(); +- return success(); ++ auto parseDenseArray = [](OpBuilder& b, ArrayRef dims) -> Attribute { ++ return b.getDenseI64ArrayAttr(dims); + }; +- +- std::optional explicitLoc; +- FunctionType reduceOpFnType; +- if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") || +- parser.parseEqual() || +- parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, parseDim) || +- parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || +- parser.parseType(reduceOpFnType) || +- parser.parseOptionalLocationSpecifier(explicitLoc)) +- return failure(); +- +- if (!reduceOpFnType || reduceOpFnType.getInputs().empty()) { +- if (!reduceOpFnType) return parser.emitError(loc, "expected function type"); +- return parser.emitError(loc, +- "input types missing in reduce-op function type"); +- } +- +- // If location of reduce-op is explicitly provided, then use it; Else use +- // the parser's current location. +- Location reduceOpLoc = explicitLoc.value_or(currLocation); +- +- // Derive the SSA-values for reduce-op's operands. +- if (parser.resolveOperands(operands, reduceOpFnType.getInputs(), loc, +- result.operands)) +- return failure(); +- +- // Derive the type of inner-op from that of reduce-op's input operand. +- auto innerOpType = RankedTensorType::get( +- /*shape=*/{}, getElementTypeOrSelf(reduceOpFnType.getInput(0))); +- +- // Add a region for reduce-op. +- Region& region = *result.addRegion(); +- +- // Create a basic-block inside reduce-op's region. +- Block& block = region.emplaceBlock(); +- auto lhs = block.addArgument(innerOpType, reduceOpLoc); +- auto rhs = block.addArgument(innerOpType, reduceOpLoc); +- +- // Create and insert an "inner-op" operation in the block. +- OpBuilder builder(parser.getBuilder().getContext()); +- builder.setInsertionPointToStart(&block); +- +- OperationState innerOpState(reduceOpLoc, innerOpName); +- innerOpState.operands.push_back(lhs); +- innerOpState.operands.push_back(rhs); +- innerOpState.addTypes(innerOpType); +- +- Operation* innerOp = builder.create(innerOpState); +- +- // Insert a return statement in the block returning the inner-op's result. +- builder.create(innerOp->getLoc(), innerOp->getResults()); +- +- // Populate the reduce-op operation-state with result-type, location, and +- // dimension attribute. +- result.addTypes(reduceOpFnType.getResults()); +- result.location = innerOp->getLoc(); +- result.addAttribute("dimensions", builder.getDenseI64ArrayAttr(dimensions)); +- return success(); ++ return hlo::parseReduceOp(parser, result, parseDenseArray); + } + + LogicalResult ReduceOp::inferReturnTypeComponents( +@@ -2385,69 +2086,12 @@ + return hlo::verifyWhileOp(getLoc(), getOperand(), getCond(), getBody()); + } + +-/// Print a `while` op. +-/// +-/// op ::= `stablehlo.while` `(` assignment-list `)` `:` types attribute-dict +-/// `cond` region +-/// `do` region +-/// assignment-list ::= assignment | assignment `,` assignment-list +-/// assignment ::= ssa-value `=` ssa-value + void WhileOp::print(OpAsmPrinter& p) { +- p << '('; +- llvm::interleaveComma( +- llvm::zip(SingleBlock::getBody()->getArguments(), getOperands()), p, +- [&](auto zip) { +- p.printOperand(std::get<0>(zip)); +- p << " = "; +- p.printOperand(std::get<1>(zip)); +- }); +- p << ")"; +- if (getNumOperands()) { +- p << " : "; +- llvm::interleaveComma(getOperandTypes(), p); +- } +- p.printOptionalAttrDictWithKeyword(getOperation()->getAttrs()); +- p.printNewline(); +- p << " cond "; +- p.printRegion(getRegion(0), /*printEntryBlockArgs=*/false); +- p << " do "; +- p.printRegion(getRegion(1), /*printEntryBlockArgs=*/false); ++ hlo::printWhileOp(p, getOperation(), getCond(), getBody()); + } + + ParseResult WhileOp::parse(OpAsmParser& parser, OperationState& result) { +- llvm::SMLoc loc = parser.getCurrentLocation(); +- // Parse the operands of the while: these are of the form: +- // %iter_arg = %init_val +- // where %iter_arg is the name of the block argument in the cond/body blocks +- // and %init_val is the actual operand. +- SmallVector operands; +- SmallVector iterArgs; +- if (parser.parseLParen()) return failure(); +- do { +- if (succeeded(parser.parseOptionalRParen())) break; +- OpAsmParser::UnresolvedOperand operand, iterArg; +- if (parser.parseOperand(iterArg) || parser.parseEqual() || +- parser.parseOperand(operand)) +- return failure(); +- iterArgs.push_back(iterArg); +- operands.push_back(operand); +- if (succeeded(parser.parseOptionalRParen())) break; +- if (failed(parser.parseComma())) return failure(); +- } while (true); +- if (!operands.empty()) { +- if (parser.parseColon() || parser.parseTypeList(result.types)) +- return failure(); +- } +- SmallVector args; +- createArgs(iterArgs, result.types, args); +- if (parser.resolveOperands(operands, result.types, loc, result.operands) || +- parser.parseOptionalAttrDictWithKeyword(result.attributes) || +- parser.parseKeyword("cond") || +- parser.parseRegion(*result.addRegion(), args) || +- parser.parseKeyword("do") || +- parser.parseRegion(*result.addRegion(), args)) +- return failure(); +- return success(); ++ return hlo::parseWhileOp(parser, result); + } + + LogicalResult UniformDequantizeOp::inferReturnTypeComponents( diff --ruN a/stablehlo/stablehlo/experimental/BUILD.bazel b/stablehlo/stablehlo/experimental/BUILD.bazel --- stablehlo/stablehlo/experimental/BUILD.bazel +++ stablehlo/stablehlo/experimental/BUILD.bazel @@ -1274,7 +2226,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/tests/CMakeLists.txt b/stablehlo/s diff --ruN a/stablehlo/stablehlo/experimental/tests/chlo_recompose_ops.mlir b/stablehlo/stablehlo/experimental/tests/chlo_recompose_ops.mlir --- stablehlo/stablehlo/experimental/tests/chlo_recompose_ops.mlir +++ stablehlo/stablehlo/experimental/tests/chlo_recompose_ops.mlir -@@ -0,0 +1,36 @@ +@@ -0,0 +1,51 @@ +// RUN: experimental-stablehlo-opt --experimental-chlo-recompose-ops --split-input-file --verify-diagnostics %s | FileCheck %s + +// ----- @@ -1311,6 +2263,21 @@ diff --ruN a/stablehlo/stablehlo/experimental/tests/chlo_recompose_ops.mlir b/st + } : (tensor<16xf32>) -> tensor + func.return %0 : tensor +} ++ ++// ----- ++ ++// CHECK-LABEL: @recompose_erf ++func.func @recompose_erf(%arg0: tensor<3x20x20xbf16>) -> tensor { ++ // CHECK: %0 = chlo.erf %arg0 : tensor<3x20x20xbf16> -> tensor ++ %0 = "stablehlo.custom_call"(%arg0) { ++ backend_config = "", ++ call_target_name = "mhlo.erf", ++ mhlo.attributes = {}, ++ mhlo.version = 1 : i64 ++ } : (tensor<3x20x20xbf16>) -> tensor ++ func.return %0 : tensor ++} ++ diff --ruN a/stablehlo/stablehlo/experimental/tests/lit.cfg.py b/stablehlo/stablehlo/experimental/tests/lit.cfg.py --- stablehlo/stablehlo/experimental/tests/lit.cfg.py +++ stablehlo/stablehlo/experimental/tests/lit.cfg.py @@ -1922,7 +2889,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/CMakeLists.txt b/stable diff --ruN a/stablehlo/stablehlo/experimental/transforms/ChloRecomposeOps.cpp b/stablehlo/stablehlo/experimental/transforms/ChloRecomposeOps.cpp --- stablehlo/stablehlo/experimental/transforms/ChloRecomposeOps.cpp +++ stablehlo/stablehlo/experimental/transforms/ChloRecomposeOps.cpp -@@ -0,0 +1,151 @@ +@@ -0,0 +1,168 @@ +/* Copyright 2024 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. @@ -1940,6 +2907,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/ChloRecomposeOps.cpp b/ +#include +#include + ++#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" @@ -2044,6 +3012,15 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/ChloRecomposeOps.cpp b/ + } +}; + ++struct ErfOpRecomposePattern : public OpRewritePattern { ++ using OpRewritePattern::OpRewritePattern; ++ LogicalResult matchAndRewrite(CustomCallOp op, ++ PatternRewriter& rewriter) const override { ++ if (op.getCallTargetName() != "mhlo.erf") return failure(); ++ return recomposeChloOpFromCustomCall(op, rewriter); ++ } ++}; ++ +} // namespace + +struct ChloRecomposeOpsPass @@ -2051,21 +3028,28 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/ChloRecomposeOps.cpp b/ + using ChloRecomposeOpsPassBase::ChloRecomposeOpsPassBase; + + void runOnOperation() override { -+ // Do a single traversal to recompose CHLO ops. -+ // TODO(#1048): Find out why .maxIterations = 1 no longer works. ++ // Do a single traversal to recompose CustomCallOp to CHLO ops. + GreedyRewriteConfig config; + config.useTopDownTraversal = true; + config.enableRegionSimplification = true; -+ config.maxIterations = 2; ++ config.maxIterations = 1; + config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; -+ config.strictMode = GreedyRewriteStrictness::AnyOp; ++ config.strictMode = GreedyRewriteStrictness::ExistingOps; + + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); ++ patterns.add(&getContext()); + -+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), -+ config))) { ++ // Only apply to CustomCallOps ++ auto moduleOp = getOperation(); ++ llvm::SmallVector candidateOps; ++ moduleOp.walk([&](CustomCallOp op) { candidateOps.push_back(op); }); ++ ++ if (failed(applyOpPatternsAndFold(candidateOps, std::move(patterns), ++ config))) { ++ moduleOp.emitError("Failed to converge ChloRecomposeOps in ") ++ << config.maxIterations << " iterations"; + return signalPassFailure(); + } + } @@ -2160,7 +3144,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/Passes.td b/stablehlo/s diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDynamism.cpp b/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDynamism.cpp --- stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDynamism.cpp +++ stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDynamism.cpp -@@ -0,0 +1,167 @@ +@@ -0,0 +1,171 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + Copyright 2023 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); @@ -2317,8 +3301,12 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy + patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); -+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), ++ ++ auto funcOp = getOperation(); ++ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns), + config))) { ++ funcOp.emitError("Failed to converge StablehloCanonicalizeDynamism in ") ++ << config.maxIterations << " iterations"; + return signalPassFailure(); + } + } @@ -2502,4 +3490,61 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c +} // namespace experimental +} // namespace stablehlo +} // namespace mlir +diff --ruN a/stablehlo/stablehlo/tests/verify_reduce.mlir b/stablehlo/stablehlo/tests/verify_reduce.mlir +--- stablehlo/stablehlo/tests/verify_reduce.mlir ++++ stablehlo/stablehlo/tests/verify_reduce.mlir +@@ -490,7 +490,7 @@ + // ----- + + func.func @reduce_parsing_pretty_reduce_non_commutative(%arg0: tensor , %arg1: tensor ) -> tensor { +- // expected-error@+1 {{expected the inner-op to be a commutative binary-op from stablehlo dialect, zero region, producing single result}} ++ // expected-error@+1 {{expected the inner-op to be a commutative binary-op that matching the reduce op dialect, with zero region, producing single result}} + %0 = stablehlo.reduce(%arg0 init: %arg1) applies stablehlo.divide across dimensions = [1] : (tensor, tensor) -> tensor loc("foo") + func.return %0 : tensor + } +@@ -498,7 +498,7 @@ + // ----- + + func.func @reduce_parsing_pretty_reduce_wrong_dialect(%arg0: tensor , %arg1: tensor ) -> tensor { +- // expected-error@+1 {{expected the inner-op to be a commutative binary-op from stablehlo dialect, zero region, producing single result}} ++ // expected-error@+1 {{expected the inner-op to be a commutative binary-op that matching the reduce op dialect, with zero region, producing single result}} + %0 = stablehlo.reduce(%arg0 init: %arg1) applies std.add across dimensions = [1] : (tensor, tensor) -> tensor loc("foo") + func.return %0 : tensor + } +@@ -506,7 +506,7 @@ + // ----- + + func.func @reduce_parsing_pretty_reduce_non_binary(%arg0: tensor , %arg1: tensor ) -> tensor { +- // expected-error@+1 {{expected the inner-op to be a commutative binary-op from stablehlo dialect, zero region, producing single result}} ++ // expected-error@+1 {{expected the inner-op to be a commutative binary-op that matching the reduce op dialect, with zero region, producing single result}} + %0 = stablehlo.reduce(%arg0 init: %arg1) applies stablehlo.reshape across dimensions = [1] : (tensor, tensor) -> tensor loc("foo") + func.return %0 : tensor + } +diff --ruN a/stablehlo/stablehlo/transforms/StablehloAggressiveSimplification.cpp b/stablehlo/stablehlo/transforms/StablehloAggressiveSimplification.cpp +--- stablehlo/stablehlo/transforms/StablehloAggressiveSimplification.cpp ++++ stablehlo/stablehlo/transforms/StablehloAggressiveSimplification.cpp +@@ -126,9 +126,8 @@ + + // The canonical form has the constant operand as the RHS. + if (isa(type.getElementType()) && lhsAttr && !rhsAttr) { +- rewriter.modifyOpInPlace(op, [op, lhs, rhs] { +- op->setOperands(ValueRange{rhs, lhs}); +- }); ++ rewriter.modifyOpInPlace( ++ op, [op, lhs, rhs] { op->setOperands(ValueRange{rhs, lhs}); }); + return success(); + } + +@@ -221,9 +220,8 @@ + + // The canonical form has the constant operand as the RHS. + if (isa(type.getElementType()) && lhsAttr && !rhsAttr) { +- rewriter.modifyOpInPlace(op, [op, lhs, rhs] { +- op->setOperands(ValueRange{rhs, lhs}); +- }); ++ rewriter.modifyOpInPlace( ++ op, [op, lhs, rhs] { op->setOperands(ValueRange{rhs, lhs}); }); + return success(); + } + diff --git a/third_party/xla/third_party/stablehlo/workspace.bzl b/third_party/xla/third_party/stablehlo/workspace.bzl index 271c373a66a8c1..411c6103290796 100644 --- a/third_party/xla/third_party/stablehlo/workspace.bzl +++ b/third_party/xla/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "c30f551469ca37a1f2a8c8ac42ef1b989573dce6" - STABLEHLO_SHA256 = "71720bd4003f417beb1acefd8f87b20f0e1db5edf498bf2f4642e5f0a3542c02" + STABLEHLO_COMMIT = "e708c82502982697540886738a307f72f9e9a7ff" + STABLEHLO_SHA256 = "3fecbe7779bee0801af746d974738748f7b461df54a4f610b32bb75647b32125" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/xla/third_party/triton/cl607293980.patch b/third_party/xla/third_party/triton/cl607293980.patch new file mode 100644 index 00000000000000..b7b9d0e84fab2e --- /dev/null +++ b/third_party/xla/third_party/triton/cl607293980.patch @@ -0,0 +1,17 @@ +Long standing patch due to licensing issues. +diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp +index 31bc03fe1..a19a432df 100644 +--- a/include/triton/Tools/Sys/GetEnv.hpp ++++ b/include/triton/Tools/Sys/GetEnv.hpp +@@ -34,9 +34,10 @@ inline const std::set ENV_VARS = { + "AMDGCN_ENABLE_DUMP", + "DISABLE_FAST_REDUCTION", + "DISABLE_LLVM_OPT", +- "DISABLE_MMA_V3", ++ "ENABLE_MMA_V3", + "DISABLE_PTXAS_OPT", + "LLVM_IR_ENABLE_DUMP", ++ "MLIR_ENABLE_DIAGNOSTICS", + "MLIR_ENABLE_DUMP", + "TRITON_DISABLE_LINE_INFO", + "TRITON_DISABLE_RESHAPE_ENCODING_INFERENCE", diff --git a/third_party/xla/third_party/triton/workspace.bzl b/third_party/xla/third_party/triton/workspace.bzl index d2cf68c50e4306..e0364e4b646929 100644 --- a/third_party/xla/third_party/triton/workspace.bzl +++ b/third_party/xla/third_party/triton/workspace.bzl @@ -5,13 +5,15 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): """Imports Triton.""" - TRITON_COMMIT = "cl601105910" - TRITON_SHA256 = "523b31822e431c79e2d6bc566272e7fc4f4183ae28aebcf662d11db740691d6d" + TRITON_COMMIT = "cl608559313" + TRITON_SHA256 = "d37c0a2921f756cb355dc7ea7e91ea708cef867117edff37106f5a947c5a5a38" tf_http_archive( name = "triton", sha256 = TRITON_SHA256, strip_prefix = "triton-{commit}".format(commit = TRITON_COMMIT), urls = tf_mirror_urls("https://github.com/openxla/triton/archive/{commit}.tar.gz".format(commit = TRITON_COMMIT)), # For temporary changes which haven't landed upstream yet. - patch_file = [], + patch_file = [ + "//third_party/triton:cl607293980.patch", # long standing :( + ], ) diff --git a/third_party/xla/third_party/tsl/.bazelrc b/third_party/xla/third_party/tsl/.bazelrc index a635862b43a43c..c21cf6e6e15d5d 100644 --- a/third_party/xla/third_party/tsl/.bazelrc +++ b/third_party/xla/third_party/tsl/.bazelrc @@ -255,6 +255,14 @@ build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-1 build:cuda_clang_official --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:cuda_clang_official --crosstool_top="@sigbuild-r2.16-clang_config_cuda//crosstool:toolchain" +# Build with nvcc for CUDA and clang for host +build:nvcc_clang --config=cuda +# Unfortunately, cuda_configure.bzl demands this for using nvcc + clang +build:nvcc_clang --action_env=TF_CUDA_CLANG="1" +build:nvcc_clang --action_env=TF_NVCC_CLANG="1" +build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc + + # Debug config build:dbg -c dbg # Only include debug info for files under tensorflow/, excluding kernels, to @@ -527,8 +535,8 @@ build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.16-clang_confi test:rbe_linux_cuda --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda +build:rbe_linux_cuda_nvcc --config=nvcc_clang build:rbe_linux_cuda_nvcc --repo_env TF_NCCL_USE_STUB=1 -build:rbe_linux_cuda_nvcc --action_env=TF_NVCC_CLANG="1" # TODO(kanglan): Remove rbe_win and rbe_win_py3* after b/289091160 is fixed build:rbe_win --config=rbe_base @@ -577,6 +585,7 @@ build:elinux_armhf --copt -mfp16-format=ieee # Load rc file written by ./configure. try-import %workspace%/.tf_configure.bazelrc +try-import %workspace%/xla_configure.bazelrc # Load rc file with user-specific options. try-import %workspace%/.bazelrc.user @@ -777,28 +786,38 @@ test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-os test:linux_cuda_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # ARM64 PYCPP -test:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only -test:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only -test:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --flaky_test_attempts=3 +# In Linux Arm64 presubmit/continuous build, we cross-compile the binaries on +# Linux x86 so that we can use RBE. Since tests still need to run on the single +# host Arm64 machine, the build becomes too slow (~30 min) to be a presubmit. +# For testing purposes, we want to see the runtime performance of an +# experimental job that is build-only, i.e, we only build the test targets and +# do not run them. By prefixing the configs with "build", we can run both +# `bazel build` and `bazel test` commands with the same config as test configs +# inherit from build. +build:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +build:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +build:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --flaky_test_attempts=3 # TODO(michaelhudgins): Why do we need to specifically omit go and java here? -test:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test -//tensorflow/python/tools:aot_compiled_test +build:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test -//tensorflow/python/tools:aot_compiled_test # CROSS-COMPILE ARM64 PYCPP -test:cross_compile_linux_arm64_pycpp_test --config=linux_arm64_pycpp_test +build:cross_compile_linux_arm64_pycpp_test --config=linux_arm64_pycpp_test # Tests that fail only when cross-compiled -test:cross_compile_linux_arm64_pycpp_test -//tensorflow/compiler/mlir/quantization/stablehlo:convert_tf_quant_to_mhlo_int_test +build:cross_compile_linux_arm64_pycpp_test -//tensorflow/compiler/mlir/quantization/stablehlo:convert_tf_quant_to_mhlo_int_test # MACOS ARM64 PYCPP test:macos_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... -//tensorflow/core/kernels/image:resize_bicubic_op_test # MACOS X86 PYCPP -test:macos_x86_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test -test:macos_x86_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test -test:macos_x86_pycpp_test_filters --keep_going --test_lang_filters=cc,py --test_size_filters=small,medium -test:macos_x86_pycpp_test --config=macos_x86_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/python/integration_testing/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... +# These are defined as build configs so that we can run a build only job. See +# the note under "ARM64 PYCPP" for more details. +build:macos_x86_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +build:macos_x86_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +build:macos_x86_pycpp_test_filters --keep_going --test_lang_filters=cc,py --test_size_filters=small,medium +build:macos_x86_pycpp_test --config=macos_x86_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/python/integration_testing/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... # CROSS-COMPILE MACOS X86 PYCPP -test:cross_compile_macos_x86_pycpp_test --config=macos_x86_pycpp_test -test:cross_compile_macos_x86_pycpp_test -//tensorflow/core/kernels:quantized_conv_ops_test -//tensorflow/core/kernels:quantized_matmul_op_test -//tensorflow/python/ops:quantized_conv_ops_test -//tensorflow/tools/graph_transforms:transforms_test -//tensorflow/python/tools:aot_compiled_test +build:cross_compile_macos_x86_pycpp_test --config=macos_x86_pycpp_test +build:cross_compile_macos_x86_pycpp_test -//tensorflow/core/kernels:quantized_conv_ops_test -//tensorflow/core/kernels:quantized_matmul_op_test -//tensorflow/python/ops:quantized_conv_ops_test -//tensorflow/tools/graph_transforms:transforms_test -//tensorflow/python/tools:aot_compiled_test # END TF TEST SUITE OPTIONS # START CROSS-COMPILE CONFIGS @@ -855,8 +874,12 @@ build:cross_compile_macos_x86 --platform_mappings=tensorflow/tools/toolchains/cr # RBE cross-compile configs for Darwin x86 build:rbe_cross_compile_macos_x86 --config=cross_compile_macos_x86 build:rbe_cross_compile_macos_x86 --config=rbe_cross_compile_base +build:rbe_cross_compile_macos_x86 --bes_upload_mode=nowait_for_upload_complete test:rbe_cross_compile_macos_x86 --config=rbe_cross_compile_base # Increase the test timeout as tests often take longer on mac. test:rbe_cross_compile_macos_x86 --test_timeout=300,450,1200,3600 +# Limit jobs to 100 to avoid running into "out of memory" issues (b/316266643) +build:rbe_cross_compile_macos_x86 --jobs=100 +test:rbe_cross_compile_macos_x86 --jobs=100 # END MACOS CROSS-COMPILE CONFIGS # END CROSS-COMPILE CONFIGS diff --git a/third_party/xla/third_party/tsl/opensource_only.files b/third_party/xla/third_party/tsl/opensource_only.files index 75e0698b80b04e..865ad4d8aa3038 100644 --- a/third_party/xla/third_party/tsl/opensource_only.files +++ b/third_party/xla/third_party/tsl/opensource_only.files @@ -17,8 +17,6 @@ third_party/ducc/threading.h: third_party/eigen3/BUILD: third_party/eigen3/LICENSE: third_party/eigen3/eigen_archive.BUILD: -third_party/gif.BUILD: -third_party/gif_fix_strtok_r.patch: third_party/git/BUILD.tpl: third_party/git/BUILD: third_party/git/git_configure.bzl: @@ -68,8 +66,6 @@ third_party/nccl/nccl_configure.bzl: third_party/nccl/system.BUILD.tpl: third_party/nvtx/BUILD: third_party/nvtx/LICENSE: -third_party/png.BUILD: -third_party/png_fix_rpi.patch: third_party/protobuf/BUILD: third_party/py/non_hermetic/BUILD.tpl: third_party/py/non_hermetic/BUILD: @@ -151,4 +147,4 @@ tsl/mkl/BUILD: tsl/mkl/LICENSE: tsl/mkl/MKL_LICENSE: tsl/mkl/build_defs.bzl: -tsl/platform/default/build_config/BUILD: +tsl/profiler/BUILD: diff --git a/third_party/xla/third_party/tsl/third_party/compute_library/BUILD b/third_party/xla/third_party/tsl/third_party/compute_library/BUILD index b353c1fa0aedba..4fc694c50a43cf 100644 --- a/third_party/xla/third_party/tsl/third_party/compute_library/BUILD +++ b/third_party/xla/third_party/tsl/third_party/compute_library/BUILD @@ -1,9 +1,6 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") -exports_files( - ["LICENSE"], - visibility = ["//visibility:public"], -) +exports_files(["LICENSE"]) config_setting( name = "build_with_acl", diff --git a/third_party/xla/third_party/tsl/third_party/curl.BUILD b/third_party/xla/third_party/tsl/third_party/curl.BUILD index 8dcd54451dae66..b31d8488aaa0ba 100644 --- a/third_party/xla/third_party/tsl/third_party/curl.BUILD +++ b/third_party/xla/third_party/tsl/third_party/curl.BUILD @@ -14,11 +14,6 @@ CURL_WIN_COPTS = [ "/DCURL_DISABLE_PROXY", "/DHAVE_LIBZ", "/DHAVE_ZLIB_H", - # Defining _USING_V110_SDK71_ is hackery to defeat curl's incorrect - # detection of what OS releases we can build on with VC 2012. This - # may not be needed (or may have to change) if the WINVER setting - # changes in //third_party/msvc/vc_12_0/CROSSTOOL. - "/D_USING_V110_SDK71_", ] CURL_WIN_SRCS = [ diff --git a/third_party/xla/third_party/tsl/third_party/gif.BUILD b/third_party/xla/third_party/tsl/third_party/gif.BUILD deleted file mode 100644 index 51621ba953e6e2..00000000000000 --- a/third_party/xla/third_party/tsl/third_party/gif.BUILD +++ /dev/null @@ -1,61 +0,0 @@ -# Description: -# A library for decoding and encoding GIF images - -licenses(["notice"]) # MIT - -exports_files(["COPYING"]) - -cc_library( - name = "gif", - srcs = [ - "dgif_lib.c", - "egif_lib.c", - "gif_err.c", - "gif_font.c", - "gif_hash.c", - "gif_hash.h", - "gif_lib_private.h", - "gifalloc.c", - "openbsd-reallocarray.c", - "quantize.c", - ], - hdrs = ["gif_lib.h"], - defines = select({ - ":android": [ - "S_IREAD=S_IRUSR", - "S_IWRITE=S_IWUSR", - "S_IEXEC=S_IXUSR", - ], - "//conditions:default": [], - }), - includes = ["."], - visibility = ["//visibility:public"], - deps = select({ - ":windows": [":windows_polyfill"], - "//conditions:default": [], - }), -) - -cc_library( - name = "windows_polyfill", - hdrs = ["windows/unistd.h"], - includes = ["windows"], -) - -genrule( - name = "windows_unistd_h", - outs = ["windows/unistd.h"], - cmd = "touch $@", -) - -config_setting( - name = "windows", - values = { - "cpu": "x64_windows", - }, -) - -config_setting( - name = "android", - values = {"crosstool_top": "//external:android/crosstool"}, -) diff --git a/third_party/xla/third_party/tsl/third_party/gif_fix_strtok_r.patch b/third_party/xla/third_party/tsl/third_party/gif_fix_strtok_r.patch deleted file mode 100644 index c9c9c30c41fab9..00000000000000 --- a/third_party/xla/third_party/tsl/third_party/gif_fix_strtok_r.patch +++ /dev/null @@ -1,15 +0,0 @@ -diff -r -u ./fixed_gif_font.c ./gif_font.c ---- ./fixed_gif_font.c 2019-09-05 11:05:25.009598262 -0700 -+++ ./gif_font.c 2019-09-05 10:52:45.308389085 -0700 -@@ -11,6 +11,11 @@ - - #include "gif_lib.h" - -+// Windows doesn't have strtok_r. -+#if defined(WIN32) || defined(_WIN32) || defined(__WIN32) && !defined(__CYGWIN__) -+#define strtok_r strtok_s -+#endif -+ - /***************************************************************************** - Ascii 8 by 8 regular font - only first 128 characters are supported. - *****************************************************************************/ diff --git a/third_party/xla/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl index 0da1d7b58f4bb0..74fafb9b32f516 100755 --- a/third_party/xla/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -41,7 +41,7 @@ import os import subprocess import re import sys -import pipes +import shlex # Template values set by cuda_autoconf. CPU_COMPILER = ('%{cpu_compiler}') @@ -299,7 +299,7 @@ def main(): if args.x and args.x[0] == 'cuda': if args.cuda_log: Log('-x cuda') - leftover = [pipes.quote(s) for s in leftover] + leftover = [shlex.quote(s) for s in leftover] if args.cuda_log: Log('using nvcc') return InvokeNvcc(leftover, log=args.cuda_log) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/build_defs.bzl.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/build_defs.bzl.tpl index 189d3e3e784003..bc865cecb3240a 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/build_defs.bzl.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/build_defs.bzl.tpl @@ -94,6 +94,25 @@ def if_cuda_is_configured(x, no_cuda = []): return select({"//conditions:default": x}) return select({"//conditions:default": no_cuda}) +def if_cuda_newer_than(wanted_ver, if_true, if_false = []): + """Tests if CUDA was enabled during the configured process and if the + configured version is at least `wanted_ver`. `wanted_ver` needs + to be provided as a string in the format `_`. + Example: `11_0` + """ + + wanted_major = int(wanted_ver.split('_')[0]) + wanted_minor = int(wanted_ver.split('_')[1]) + + configured_version = "%{cuda_version}" + configured_major = int(configured_version.split('.')[0]) + configured_minor = int(configured_version.split('.')[1]) + + if %{cuda_is_configured} and (wanted_major, wanted_minor) <= (configured_major, configured_minor): + return select({"//conditions:default": if_true}) + return select({"//conditions:default": if_false}) + + def cuda_header_library( name, hdrs, diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda_configure.bzl b/third_party/xla/third_party/tsl/third_party/gpus/cuda_configure.bzl index 5cd589e8eeb2be..89ea8f54a3495e 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda_configure.bzl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda_configure.bzl @@ -827,6 +827,7 @@ def _create_dummy_repository(repository_ctx): "%{cuda_is_configured}": "False", "%{cuda_extra_copts}": "[]", "%{cuda_gpu_architectures}": "[]", + "%{cuda_version}": "0.0", }, ) _tpl( @@ -1214,6 +1215,7 @@ def _create_local_cuda_repository(repository_ctx): cuda_config.compute_capabilities, ), "%{cuda_gpu_architectures}": str(cuda_config.compute_capabilities), + "%{cuda_version}": cuda_config.cuda_version, }, ) @@ -1427,6 +1429,7 @@ def _create_remote_cuda_repository(repository_ctx, remote_config_repo): repository_ctx, compute_capabilities(repository_ctx), ), + "%{cuda_version}": get_host_environ(repository_ctx, _TF_CUDA_VERSION), }, ) repository_ctx.template( diff --git a/third_party/xla/third_party/tsl/third_party/gpus/rocm/build_defs.bzl.tpl b/third_party/xla/third_party/tsl/third_party/gpus/rocm/build_defs.bzl.tpl index 2b4595bb222885..339733755d6f1f 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/rocm/build_defs.bzl.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/rocm/build_defs.bzl.tpl @@ -38,6 +38,16 @@ def rocm_version_number(): """Returns a list of supported GPU architectures.""" return %{rocm_version_number} +def if_gpu_is_configured(if_true, if_false = []): + """Tests if ROCm or CUDA was enabled during the configure process. + + Unlike if_rocm() or if_cuda(), this does not require that we are building + with --config=rocm or --config=cuda, respectively. Used to allow non-GPU + code to depend on ROCm or CUDA libraries. + + """ + return select({"//conditions:default": %{gpu_is_configured}}) + def if_rocm_is_configured(x): """Tests if the ROCm was enabled during the configure process. diff --git a/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl b/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl index 5c1195bada43f8..c96ecf4d62eb64 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl +++ b/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl @@ -10,6 +10,7 @@ load( ":cuda_configure.bzl", + "enable_cuda", "make_copy_dir_rule", "make_copy_files_rule", "to_list_of_strings", @@ -449,6 +450,7 @@ def _create_dummy_repository(repository_ctx): "rocm:build_defs.bzl", { "%{rocm_is_configured}": "False", + "%{gpu_is_configured}": "if_true" if enable_cuda(repository_ctx) else "if_false", "%{rocm_extra_copts}": "[]", "%{rocm_gpu_architectures}": "[]", "%{rocm_version_number}": "0", @@ -634,6 +636,7 @@ def _create_local_rocm_repository(repository_ctx): tpl_paths["rocm:build_defs.bzl"], { "%{rocm_is_configured}": "True", + "%{gpu_is_configured}": "if_true", "%{rocm_extra_copts}": _compute_rocm_extra_copts( repository_ctx, rocm_config.amdgpu_targets, @@ -762,6 +765,7 @@ def _create_remote_rocm_repository(repository_ctx, remote_config_repo): "rocm:build_defs.bzl", { "%{rocm_is_configured}": "True", + "%{gpu_is_configured}": "if_true", "%{rocm_extra_copts}": _compute_rocm_extra_copts( repository_ctx, [], #_compute_capabilities(repository_ctx) @@ -815,6 +819,7 @@ _ENVIRONS = [ _GCC_HOST_COMPILER_PATH, _GCC_HOST_COMPILER_PREFIX, "TF_NEED_ROCM", + "TF_NEED_CUDA", # Needed by the `if_gpu_is_configured` macro _ROCM_TOOLKIT_PATH, _TF_ROCM_AMDGPU_TARGETS, ] diff --git a/third_party/xla/third_party/tsl/third_party/hwloc/BUILD b/third_party/xla/third_party/tsl/third_party/hwloc/BUILD index 3848c0818e77db..db5c9ec8873bb4 100644 --- a/third_party/xla/third_party/tsl/third_party/hwloc/BUILD +++ b/third_party/xla/third_party/tsl/third_party/hwloc/BUILD @@ -1,12 +1,10 @@ # BUILD file to make this directory a package. package( - default_visibility = ["//visibility:public"], # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) exports_files( ["static-components.h"], - visibility = ["//visibility:public"], ) diff --git a/third_party/xla/third_party/tsl/third_party/implib_so/BUILD b/third_party/xla/third_party/tsl/third_party/implib_so/BUILD index 8401d6152b88ab..ca6976cd8d3425 100644 --- a/third_party/xla/third_party/tsl/third_party/implib_so/BUILD +++ b/third_party/xla/third_party/tsl/third_party/implib_so/BUILD @@ -5,7 +5,6 @@ licenses(["notice"]) # MIT py_binary( name = "get_symbols", srcs = ["get_symbols.py"], - visibility = ["//visibility:public"], deps = [ "@bazel_tools//tools/python/runfiles", "@implib_so//:implib_gen_lib", @@ -15,7 +14,6 @@ py_binary( py_binary( name = "make_stub", srcs = ["make_stub.py"], - visibility = ["//visibility:public"], deps = [ "@bazel_tools//tools/python/runfiles", "@implib_so//:implib_gen_lib", diff --git a/third_party/xla/third_party/tsl/third_party/jpeg/BUILD b/third_party/xla/third_party/tsl/third_party/jpeg/BUILD deleted file mode 100644 index ed1568c32f33ed..00000000000000 --- a/third_party/xla/third_party/tsl/third_party/jpeg/BUILD +++ /dev/null @@ -1,3 +0,0 @@ -# Needed to make this a package. - -# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) diff --git a/third_party/xla/third_party/tsl/third_party/jpeg/BUILD.system b/third_party/xla/third_party/tsl/third_party/jpeg/BUILD.system deleted file mode 100644 index f4f52da9bdae1b..00000000000000 --- a/third_party/xla/third_party/tsl/third_party/jpeg/BUILD.system +++ /dev/null @@ -1,12 +0,0 @@ -licenses(["notice"]) # custom notice-style license, see LICENSE.md - -filegroup( - name = "LICENSE.md", - visibility = ["//visibility:public"], -) - -cc_library( - name = "jpeg", - linkopts = ["-ljpeg"], - visibility = ["//visibility:public"], -) diff --git a/third_party/xla/third_party/tsl/third_party/jpeg/jpeg.BUILD b/third_party/xla/third_party/tsl/third_party/jpeg/jpeg.BUILD deleted file mode 100644 index 9f61f9e31e5e12..00000000000000 --- a/third_party/xla/third_party/tsl/third_party/jpeg/jpeg.BUILD +++ /dev/null @@ -1,806 +0,0 @@ -# Description: -# libjpeg-turbo is a drop in replacement for jpeglib optimized with SIMD. - -load("@bazel_skylib//rules:expand_template.bzl", "expand_template") -load("@bazel_skylib//rules:common_settings.bzl", "string_flag") - -licenses(["notice"]) # custom notice-style license, see LICENSE.md - -exports_files(["LICENSE.md"]) - -WIN_COPTS = [ - "/Ox", - "-DWITH_SIMD", - "-wd4996", -] - -libjpegturbo_copts = select({ - ":android": [ - "-O3", - "-fPIC", - "-w", - ], - ":windows": WIN_COPTS, - "//conditions:default": [ - "-O3", - "-w", - ], -}) + select({ - ":armeabi-v7a": [ - "-D__ARM_NEON__", - "-DNEON_INTRINSICS", - "-march=armv7-a", - "-mfpu=neon", - "-mfloat-abi=softfp", - "-fprefetch-loop-arrays", - ], - ":arm64-v8a": [ - "-DNEON_INTRINSICS", - ], - ":linux_ppc64le": [ - "-mcpu=power8", - "-mtune=power8", - ], - "//conditions:default": [], -}) - -cc_library( - name = "jpeg", - srcs = [ - "jaricom.c", - "jcapimin.c", - "jcapistd.c", - "jcarith.c", - "jccoefct.c", - "jccolor.c", - "jcdctmgr.c", - "jchuff.c", - "jchuff.h", - "jcinit.c", - "jcmainct.c", - "jcmarker.c", - "jcmaster.c", - "jcomapi.c", - "jconfig.h", - "jconfigint.h", - "jcparam.c", - "jcphuff.c", - "jcprepct.c", - "jcsample.c", - "jctrans.c", - "jdapimin.c", - "jdapistd.c", - "jdarith.c", - "jdatadst.c", - "jdatasrc.c", - "jdcoefct.c", - "jdcoefct.h", - "jdcolor.c", - "jdct.h", - "jddctmgr.c", - "jdhuff.c", - "jdhuff.h", - "jdinput.c", - "jdmainct.c", - "jdmainct.h", - "jdmarker.c", - "jdmaster.c", - "jdmaster.h", - "jdmerge.c", - "jdmerge.h", - "jdphuff.c", - "jdpostct.c", - "jdsample.c", - "jdsample.h", - "jdtrans.c", - "jerror.c", - "jfdctflt.c", - "jfdctfst.c", - "jfdctint.c", - "jidctflt.c", - "jidctfst.c", - "jidctint.c", - "jidctred.c", - "jinclude.h", - "jmemmgr.c", - "jmemnobs.c", - "jmemsys.h", - "jpeg_nbits_table.h", - "jpegcomp.h", - "jquant1.c", - "jquant2.c", - "jutils.c", - "jversion.h", - ], - hdrs = [ - "jccolext.c", # should have been named .inc - "jdcol565.c", # should have been named .inc - "jdcolext.c", # should have been named .inc - "jdmrg565.c", # should have been named .inc - "jdmrgext.c", # should have been named .inc - "jerror.h", - "jmorecfg.h", - "jpegint.h", - "jpeglib.h", - "jstdhuff.c", # should have been named .inc - ], - copts = libjpegturbo_copts, - visibility = ["//visibility:public"], - deps = select({ - ":nosimd": [":simd_none"], - ":k8": [":simd_x86_64"], - ":armeabi-v7a": [":simd_armv7a"], - ":arm64-v8a": [":simd_armv8a"], - ":linux_ppc64le": [":simd_altivec"], - ":windows": [":simd_win_x86_64"], - "//conditions:default": [":simd_none"], - }), -) - -cc_library( - name = "simd_altivec", - srcs = [ - "jchuff.h", - "jconfig.h", - "jconfigint.h", - "jdct.h", - "jerror.h", - "jinclude.h", - "jmorecfg.h", - "jpegint.h", - "jpeglib.h", - "jsimd.h", - "jsimddct.h", - "simd/jsimd.h", - "simd/powerpc/jccolor-altivec.c", - "simd/powerpc/jcgray-altivec.c", - "simd/powerpc/jcsample-altivec.c", - "simd/powerpc/jdcolor-altivec.c", - "simd/powerpc/jdmerge-altivec.c", - "simd/powerpc/jdsample-altivec.c", - "simd/powerpc/jfdctfst-altivec.c", - "simd/powerpc/jfdctint-altivec.c", - "simd/powerpc/jidctfst-altivec.c", - "simd/powerpc/jidctint-altivec.c", - "simd/powerpc/jquanti-altivec.c", - "simd/powerpc/jsimd.c", - ], - hdrs = [ - "simd/powerpc/jccolext-altivec.c", - "simd/powerpc/jcgryext-altivec.c", - "simd/powerpc/jcsample.h", - "simd/powerpc/jdcolext-altivec.c", - "simd/powerpc/jdmrgext-altivec.c", - "simd/powerpc/jsimd_altivec.h", - ], - copts = libjpegturbo_copts, -) - -SRCS_SIMD_COMMON = [ - "jchuff.h", - "jconfig.h", - "jconfigint.h", - "jdct.h", - "jerror.h", - "jinclude.h", - "jmorecfg.h", - "jpegint.h", - "jpeglib.h", - "jsimddct.h", - "jsimd.h", - "simd/jsimd.h", -] - -cc_library( - name = "simd_x86_64", - srcs = [ - "simd/x86_64/jccolor-avx2.o", - "simd/x86_64/jccolor-sse2.o", - "simd/x86_64/jcgray-avx2.o", - "simd/x86_64/jcgray-sse2.o", - "simd/x86_64/jchuff-sse2.o", - "simd/x86_64/jcphuff-sse2.o", - "simd/x86_64/jcsample-avx2.o", - "simd/x86_64/jcsample-sse2.o", - "simd/x86_64/jdcolor-avx2.o", - "simd/x86_64/jdcolor-sse2.o", - "simd/x86_64/jdmerge-avx2.o", - "simd/x86_64/jdmerge-sse2.o", - "simd/x86_64/jdsample-avx2.o", - "simd/x86_64/jdsample-sse2.o", - "simd/x86_64/jfdctflt-sse.o", - "simd/x86_64/jfdctfst-sse2.o", - "simd/x86_64/jfdctint-avx2.o", - "simd/x86_64/jfdctint-sse2.o", - "simd/x86_64/jidctflt-sse2.o", - "simd/x86_64/jidctfst-sse2.o", - "simd/x86_64/jidctint-avx2.o", - "simd/x86_64/jidctint-sse2.o", - "simd/x86_64/jidctred-sse2.o", - "simd/x86_64/jquantf-sse2.o", - "simd/x86_64/jquanti-avx2.o", - "simd/x86_64/jquanti-sse2.o", - "simd/x86_64/jsimd.c", - "simd/x86_64/jsimdcpu.o", - ] + SRCS_SIMD_COMMON, - copts = libjpegturbo_copts, - linkstatic = 1, -) - -genrule( - name = "simd_x86_64_assemblage23", - srcs = [ - "jconfig.h", - "jconfigint.h", - "simd/x86_64/jccolext-avx2.asm", - "simd/x86_64/jccolext-sse2.asm", - "simd/x86_64/jccolor-avx2.asm", - "simd/x86_64/jccolor-sse2.asm", - "simd/x86_64/jcgray-avx2.asm", - "simd/x86_64/jcgray-sse2.asm", - "simd/x86_64/jcgryext-avx2.asm", - "simd/x86_64/jcgryext-sse2.asm", - "simd/x86_64/jchuff-sse2.asm", - "simd/x86_64/jcphuff-sse2.asm", - "simd/x86_64/jcsample-avx2.asm", - "simd/x86_64/jcsample-sse2.asm", - "simd/x86_64/jdcolext-avx2.asm", - "simd/x86_64/jdcolext-sse2.asm", - "simd/x86_64/jdcolor-avx2.asm", - "simd/x86_64/jdcolor-sse2.asm", - "simd/x86_64/jdmerge-avx2.asm", - "simd/x86_64/jdmerge-sse2.asm", - "simd/x86_64/jdmrgext-avx2.asm", - "simd/x86_64/jdmrgext-sse2.asm", - "simd/x86_64/jdsample-avx2.asm", - "simd/x86_64/jdsample-sse2.asm", - "simd/x86_64/jfdctflt-sse.asm", - "simd/x86_64/jfdctfst-sse2.asm", - "simd/x86_64/jfdctint-avx2.asm", - "simd/x86_64/jfdctint-sse2.asm", - "simd/x86_64/jidctflt-sse2.asm", - "simd/x86_64/jidctfst-sse2.asm", - "simd/x86_64/jidctint-avx2.asm", - "simd/x86_64/jidctint-sse2.asm", - "simd/x86_64/jidctred-sse2.asm", - "simd/x86_64/jquantf-sse2.asm", - "simd/x86_64/jquanti-avx2.asm", - "simd/x86_64/jquanti-sse2.asm", - "simd/x86_64/jsimdcpu.asm", - "simd/nasm/jcolsamp.inc", - "simd/nasm/jdct.inc", - "simd/nasm/jsimdcfg.inc", - "simd/nasm/jsimdcfg.inc.h", - "simd/nasm/jsimdext.inc", - ], - outs = [ - "simd/x86_64/jccolor-avx2.o", - "simd/x86_64/jccolor-sse2.o", - "simd/x86_64/jcgray-avx2.o", - "simd/x86_64/jcgray-sse2.o", - "simd/x86_64/jchuff-sse2.o", - "simd/x86_64/jcphuff-sse2.o", - "simd/x86_64/jcsample-avx2.o", - "simd/x86_64/jcsample-sse2.o", - "simd/x86_64/jdcolor-avx2.o", - "simd/x86_64/jdcolor-sse2.o", - "simd/x86_64/jdmerge-avx2.o", - "simd/x86_64/jdmerge-sse2.o", - "simd/x86_64/jdsample-avx2.o", - "simd/x86_64/jdsample-sse2.o", - "simd/x86_64/jfdctflt-sse.o", - "simd/x86_64/jfdctfst-sse2.o", - "simd/x86_64/jfdctint-avx2.o", - "simd/x86_64/jfdctint-sse2.o", - "simd/x86_64/jidctflt-sse2.o", - "simd/x86_64/jidctfst-sse2.o", - "simd/x86_64/jidctint-avx2.o", - "simd/x86_64/jidctint-sse2.o", - "simd/x86_64/jidctred-sse2.o", - "simd/x86_64/jquantf-sse2.o", - "simd/x86_64/jquanti-avx2.o", - "simd/x86_64/jquanti-sse2.o", - "simd/x86_64/jsimdcpu.o", - ], - cmd = "for out in $(OUTS); do\n" + - " $(location @nasm//:nasm) -f elf64" + - " -DELF -DPIC -D__x86_64__" + - " -I $$(dirname $(location jconfig.h))/" + - " -I $$(dirname $(location jconfigint.h))/" + - " -I $$(dirname $(location simd/nasm/jsimdcfg.inc.h))/" + - " -I $$(dirname $(location simd/x86_64/jccolext-sse2.asm))/" + - " -o $$out" + - " $$(dirname $(location simd/x86_64/jccolext-sse2.asm))/$$(basename $${out%.o}.asm)\n" + - "done", - tools = ["@nasm"], -) - -expand_template( - name = "neon-compat_gen", - out = "simd/arm/neon-compat.h", - substitutions = { - "#cmakedefine HAVE_VLD1_S16_X3": "#define HAVE_VLD1_S16_X3", - "#cmakedefine HAVE_VLD1_U16_X2": "#define HAVE_VLD1_U16_X2", - "#cmakedefine HAVE_VLD1Q_U8_X4": "#define HAVE_VLD1Q_U8_X4", - }, - template = "simd/arm/neon-compat.h.in", -) - -genrule( - name = "neon-compat_hdr_src", - srcs = ["simd/arm/neon-compat.h"], - outs = ["neon-compat.h"], - cmd = "cp $(location simd/arm/neon-compat.h) $@", -) - -cc_library( - name = "neon-compat_hdr", - hdrs = ["neon-compat.h"], - copts = libjpegturbo_copts, -) - -SRCS_SIMD_ARM = [ - "simd/arm/jccolor-neon.c", - "simd/arm/jcgray-neon.c", - "simd/arm/jcphuff-neon.c", - "simd/arm/jcsample-neon.c", - "simd/arm/jdcolor-neon.c", - "simd/arm/jdmerge-neon.c", - "simd/arm/jdsample-neon.c", - "simd/arm/jfdctfst-neon.c", - "simd/arm/jfdctint-neon.c", - "simd/arm/jidctfst-neon.c", - "simd/arm/jidctint-neon.c", - "simd/arm/jidctred-neon.c", - "simd/arm/jquanti-neon.c", -] - -# .c files in the following list are used like .h files in that they are -# "#include"-ed in the actual .c files. So, treat them like normal headers, and -# they *should not* be compiled into individual objects. -HDRS_SIMD_ARM = [ - "simd/arm/align.h", - "simd/arm/jchuff.h", - "simd/arm/jcgryext-neon.c", - "simd/arm/jdcolext-neon.c", - "simd/arm/jdmrgext-neon.c", -] - -cc_library( - name = "simd_armv7a", - srcs = [ - "simd/arm/aarch32/jchuff-neon.c", - "simd/arm/aarch32/jsimd.c", - ] + SRCS_SIMD_COMMON + SRCS_SIMD_ARM, - hdrs = [ - "simd/arm/aarch32/jccolext-neon.c", - ] + HDRS_SIMD_ARM, - copts = libjpegturbo_copts, - visibility = ["//visibility:private"], - deps = [":neon-compat_hdr"], -) - -cc_library( - name = "simd_armv8a", - srcs = [ - "simd/arm/aarch64/jchuff-neon.c", - "simd/arm/aarch64/jsimd.c", - ] + SRCS_SIMD_COMMON + SRCS_SIMD_ARM, - hdrs = [ - "simd/arm/aarch64/jccolext-neon.c", - ] + HDRS_SIMD_ARM, - copts = libjpegturbo_copts, - visibility = ["//visibility:private"], - deps = [":neon-compat_hdr"], -) - -cc_library( - name = "simd_win_x86_64", - srcs = [ - "simd/x86_64/jccolor-avx2.obj", - "simd/x86_64/jccolor-sse2.obj", - "simd/x86_64/jcgray-avx2.obj", - "simd/x86_64/jcgray-sse2.obj", - "simd/x86_64/jchuff-sse2.obj", - "simd/x86_64/jcphuff-sse2.obj", - "simd/x86_64/jcsample-avx2.obj", - "simd/x86_64/jcsample-sse2.obj", - "simd/x86_64/jdcolor-avx2.obj", - "simd/x86_64/jdcolor-sse2.obj", - "simd/x86_64/jdmerge-avx2.obj", - "simd/x86_64/jdmerge-sse2.obj", - "simd/x86_64/jdsample-avx2.obj", - "simd/x86_64/jdsample-sse2.obj", - "simd/x86_64/jfdctflt-sse.obj", - "simd/x86_64/jfdctfst-sse2.obj", - "simd/x86_64/jfdctint-avx2.obj", - "simd/x86_64/jfdctint-sse2.obj", - "simd/x86_64/jidctflt-sse2.obj", - "simd/x86_64/jidctfst-sse2.obj", - "simd/x86_64/jidctint-avx2.obj", - "simd/x86_64/jidctint-sse2.obj", - "simd/x86_64/jidctred-sse2.obj", - "simd/x86_64/jquantf-sse2.obj", - "simd/x86_64/jquanti-avx2.obj", - "simd/x86_64/jquanti-sse2.obj", - "simd/x86_64/jsimd.c", - "simd/x86_64/jsimdcpu.obj", - ] + SRCS_SIMD_COMMON, - copts = libjpegturbo_copts, -) - -genrule( - name = "simd_win_x86_64_assemble", - srcs = [ - "jconfig.h", - "jconfigint.h", - "simd/x86_64/jccolext-avx2.asm", - "simd/x86_64/jccolext-sse2.asm", - "simd/x86_64/jccolor-avx2.asm", - "simd/x86_64/jccolor-sse2.asm", - "simd/x86_64/jcgray-avx2.asm", - "simd/x86_64/jcgray-sse2.asm", - "simd/x86_64/jcgryext-avx2.asm", - "simd/x86_64/jcgryext-sse2.asm", - "simd/x86_64/jchuff-sse2.asm", - "simd/x86_64/jcphuff-sse2.asm", - "simd/x86_64/jcsample-avx2.asm", - "simd/x86_64/jcsample-sse2.asm", - "simd/x86_64/jdcolext-avx2.asm", - "simd/x86_64/jdcolext-sse2.asm", - "simd/x86_64/jdcolor-avx2.asm", - "simd/x86_64/jdcolor-sse2.asm", - "simd/x86_64/jdmerge-avx2.asm", - "simd/x86_64/jdmerge-sse2.asm", - "simd/x86_64/jdmrgext-avx2.asm", - "simd/x86_64/jdmrgext-sse2.asm", - "simd/x86_64/jdsample-avx2.asm", - "simd/x86_64/jdsample-sse2.asm", - "simd/x86_64/jfdctflt-sse.asm", - "simd/x86_64/jfdctfst-sse2.asm", - "simd/x86_64/jfdctint-avx2.asm", - "simd/x86_64/jfdctint-sse2.asm", - "simd/x86_64/jidctflt-sse2.asm", - "simd/x86_64/jidctfst-sse2.asm", - "simd/x86_64/jidctint-avx2.asm", - "simd/x86_64/jidctint-sse2.asm", - "simd/x86_64/jidctred-sse2.asm", - "simd/x86_64/jquantf-sse2.asm", - "simd/x86_64/jquanti-avx2.asm", - "simd/x86_64/jquanti-sse2.asm", - "simd/x86_64/jsimdcpu.asm", - "simd/nasm/jcolsamp.inc", - "simd/nasm/jdct.inc", - "simd/nasm/jsimdcfg.inc", - "simd/nasm/jsimdcfg.inc.h", - "simd/nasm/jsimdext.inc", - ], - outs = [ - "simd/x86_64/jccolor-avx2.obj", - "simd/x86_64/jccolor-sse2.obj", - "simd/x86_64/jcgray-avx2.obj", - "simd/x86_64/jcgray-sse2.obj", - "simd/x86_64/jchuff-sse2.obj", - "simd/x86_64/jcphuff-sse2.obj", - "simd/x86_64/jcsample-avx2.obj", - "simd/x86_64/jcsample-sse2.obj", - "simd/x86_64/jdcolor-avx2.obj", - "simd/x86_64/jdcolor-sse2.obj", - "simd/x86_64/jdmerge-avx2.obj", - "simd/x86_64/jdmerge-sse2.obj", - "simd/x86_64/jdsample-avx2.obj", - "simd/x86_64/jdsample-sse2.obj", - "simd/x86_64/jfdctflt-sse.obj", - "simd/x86_64/jfdctfst-sse2.obj", - "simd/x86_64/jfdctint-avx2.obj", - "simd/x86_64/jfdctint-sse2.obj", - "simd/x86_64/jidctflt-sse2.obj", - "simd/x86_64/jidctfst-sse2.obj", - "simd/x86_64/jidctint-avx2.obj", - "simd/x86_64/jidctint-sse2.obj", - "simd/x86_64/jidctred-sse2.obj", - "simd/x86_64/jquantf-sse2.obj", - "simd/x86_64/jquanti-avx2.obj", - "simd/x86_64/jquanti-sse2.obj", - "simd/x86_64/jsimdcpu.obj", - ], - cmd = "for out in $(OUTS); do\n" + - " $(location @nasm//:nasm) -fwin64 -DWIN64 -D__x86_64__" + - " -I $$(dirname $(location simd/x86_64/jccolext-sse2.asm))/" + - " -I $$(dirname $(location simd/nasm/jdct.inc))/" + - " -I $$(dirname $(location simd/nasm/jdct.inc))/../../win/" + - " -o $$out" + - " $$(dirname $(location simd/x86_64/jccolext-sse2.asm))/$$(basename $${out%.obj}.asm)\n" + - "done", - tools = ["@nasm"], -) - -cc_library( - name = "simd_none", - srcs = [ - "jchuff.h", - "jconfig.h", - "jconfigint.h", - "jdct.h", - "jerror.h", - "jinclude.h", - "jmorecfg.h", - "jpegint.h", - "jpeglib.h", - "jsimd.h", - "jsimd_none.c", - "jsimddct.h", - ], - copts = libjpegturbo_copts, -) - -expand_template( - name = "jversion", - out = "jversion.h", - substitutions = { - "@COPYRIGHT_YEAR@": "1991-2022", - }, - template = "jversion.h.in", -) - -expand_template( - name = "jconfig_win", - out = "jconfig_win.h", - substitutions = { - "@JPEG_LIB_VERSION@": "62", - "@VERSION@": "2.1.4", - "@LIBJPEG_TURBO_VERSION_NUMBER@": "2001004", - "@BITS_IN_JSAMPLE@": "8", - "#cmakedefine C_ARITH_CODING_SUPPORTED": "#define C_ARITH_CODING_SUPPORTED", - "#cmakedefine D_ARITH_CODING_SUPPORTED": "#define D_ARITH_CODING_SUPPORTED", - "#cmakedefine MEM_SRCDST_SUPPORTED": "#define MEM_SRCDST_SUPPORTED", - "#cmakedefine WITH_SIMD": "", - }, - template = "win/jconfig.h.in", -) - -JCONFIG_NOWIN_COMMON_SUBSTITUTIONS = { - "@JPEG_LIB_VERSION@": "62", - "@VERSION@": "2.1.4", - "@LIBJPEG_TURBO_VERSION_NUMBER@": "2001004", - "#cmakedefine C_ARITH_CODING_SUPPORTED 1": "#define C_ARITH_CODING_SUPPORTED 1", - "#cmakedefine D_ARITH_CODING_SUPPORTED 1": "#define D_ARITH_CODING_SUPPORTED 1", - "#cmakedefine MEM_SRCDST_SUPPORTED 1": "#define MEM_SRCDST_SUPPORTED 1", - "@BITS_IN_JSAMPLE@": "8", - "#cmakedefine HAVE_LOCALE_H 1": "#define HAVE_LOCALE_H 1", - "#cmakedefine HAVE_STDDEF_H 1": "#define HAVE_STDDEF_H 1", - "#cmakedefine HAVE_STDLIB_H 1": "#define HAVE_STDLIB_H 1", - "#cmakedefine NEED_SYS_TYPES_H 1": "#define NEED_SYS_TYPES_H 1", - "#cmakedefine NEED_BSD_STRINGS 1": "", - "#cmakedefine HAVE_UNSIGNED_CHAR 1": "#define HAVE_UNSIGNED_CHAR 1", - "#cmakedefine HAVE_UNSIGNED_SHORT 1": "#define HAVE_UNSIGNED_SHORT 1", - "#cmakedefine INCOMPLETE_TYPES_BROKEN 1": "", - "#cmakedefine RIGHT_SHIFT_IS_UNSIGNED 1": "", - "#cmakedefine __CHAR_UNSIGNED__ 1": "", - "#undef const": "", - "#undef size_t": "", -} - -JCONFIG_NOWIN_SIMD_SUBSTITUTIONS = { - "#cmakedefine WITH_SIMD 1": "#define WITH_SIMD 1", -} - -JCONFIG_NOWIN_NOSIMD_SUBSTITUTIONS = { - "#cmakedefine WITH_SIMD 1": "", -} - -JCONFIG_NOWIN_SIMD_SUBSTITUTIONS.update(JCONFIG_NOWIN_COMMON_SUBSTITUTIONS) - -JCONFIG_NOWIN_NOSIMD_SUBSTITUTIONS.update(JCONFIG_NOWIN_COMMON_SUBSTITUTIONS) - -expand_template( - name = "jconfig_nowin_nosimd", - out = "jconfig_nowin_nosimd.h", - substitutions = JCONFIG_NOWIN_NOSIMD_SUBSTITUTIONS, - template = "jconfig.h.in", -) - -expand_template( - name = "jconfig_nowin_simd", - out = "jconfig_nowin_simd.h", - substitutions = JCONFIG_NOWIN_SIMD_SUBSTITUTIONS, - template = "jconfig.h.in", -) - -JCONFIGINT_COMMON_SUBSTITUTIONS = { - "@BUILD@": "20221022", - "@VERSION@": "2.1.4", - "@CMAKE_PROJECT_NAME@": "libjpeg-turbo", - "#undef inline": "", - "#cmakedefine HAVE_INTRIN_H": "", -} - -JCONFIGINT_NOWIN_SUBSTITUTIONS = { - "#cmakedefine HAVE_BUILTIN_CTZL": "#define HAVE_BUILTIN_CTZL", - "@INLINE@": "inline __attribute__((always_inline))", - "#define SIZEOF_SIZE_T @SIZE_T@": "#if (__WORDSIZE==64 && !defined(__native_client__))\n" + - "#define SIZEOF_SIZE_T 8\n" + - "#else\n" + - "#define SIZEOF_SIZE_T 4\n" + - "#endif\n", -} - -JCONFIGINT_WIN_SUBSTITUTIONS = { - "#cmakedefine HAVE_BUILTIN_CTZL": "", - "#define INLINE @INLINE@": "#if defined(__GNUC__)\n" + - "#define INLINE inline __attribute__((always_inline))\n" + - "#elif defined(_MSC_VER)\n" + - "#define INLINE __forceinline\n" + - "#else\n" + - "#define INLINE\n" + - "#endif\n", - "#define SIZEOF_SIZE_T @SIZE_T@": "#if (__WORDSIZE==64)\n" + - "#define SIZEOF_SIZE_T 8\n" + - "#else\n" + - "#define SIZEOF_SIZE_T 4\n" + - "#endif\n", -} - -JCONFIGINT_NOWIN_SUBSTITUTIONS.update(JCONFIGINT_COMMON_SUBSTITUTIONS) - -JCONFIGINT_WIN_SUBSTITUTIONS.update(JCONFIGINT_COMMON_SUBSTITUTIONS) - -expand_template( - name = "jconfigint_nowin", - out = "jconfigint_nowin.h", - substitutions = JCONFIGINT_NOWIN_SUBSTITUTIONS, - template = "jconfigint.h.in", -) - -expand_template( - name = "jconfigint_win", - out = "jconfigint_win.h", - substitutions = JCONFIGINT_WIN_SUBSTITUTIONS, - template = "jconfigint.h.in", -) - -genrule( - name = "configure", - srcs = [ - "jconfig_win.h", - "jconfig_nowin_nosimd.h", - "jconfig_nowin_simd.h", - ], - outs = ["jconfig.h"], - cmd = select({ - ":windows": "cp $(location jconfig_win.h) $@", - ":k8": "cp $(location jconfig_nowin_simd.h) $@", - ":armeabi-v7a": "cp $(location jconfig_nowin_simd.h) $@", - ":arm64-v8a": "cp $(location jconfig_nowin_simd.h) $@", - ":linux_ppc64le": "cp $(location jconfig_nowin_simd.h) $@", - "//conditions:default": "cp $(location jconfig_nowin_nosimd.h) $@", - }), -) - -genrule( - name = "configure_internal", - srcs = [ - "jconfigint_win.h", - "jconfigint_nowin.h", - ], - outs = ["jconfigint.h"], - cmd = select({ - ":windows": "cp $(location jconfigint_win.h) $@", - "//conditions:default": "cp $(location jconfigint_nowin.h) $@", - }), -) - -# jiminy cricket the way this file is generated is completely outrageous -genrule( - name = "configure_simd", - outs = ["simd/jsimdcfg.inc"], - cmd = "cat <<'EOF' >$@\n" + - "%define DCTSIZE 8\n" + - "%define DCTSIZE2 64\n" + - "%define RGB_RED 0\n" + - "%define RGB_GREEN 1\n" + - "%define RGB_BLUE 2\n" + - "%define RGB_PIXELSIZE 3\n" + - "%define EXT_RGB_RED 0\n" + - "%define EXT_RGB_GREEN 1\n" + - "%define EXT_RGB_BLUE 2\n" + - "%define EXT_RGB_PIXELSIZE 3\n" + - "%define EXT_RGBX_RED 0\n" + - "%define EXT_RGBX_GREEN 1\n" + - "%define EXT_RGBX_BLUE 2\n" + - "%define EXT_RGBX_PIXELSIZE 4\n" + - "%define EXT_BGR_RED 2\n" + - "%define EXT_BGR_GREEN 1\n" + - "%define EXT_BGR_BLUE 0\n" + - "%define EXT_BGR_PIXELSIZE 3\n" + - "%define EXT_BGRX_RED 2\n" + - "%define EXT_BGRX_GREEN 1\n" + - "%define EXT_BGRX_BLUE 0\n" + - "%define EXT_BGRX_PIXELSIZE 4\n" + - "%define EXT_XBGR_RED 3\n" + - "%define EXT_XBGR_GREEN 2\n" + - "%define EXT_XBGR_BLUE 1\n" + - "%define EXT_XBGR_PIXELSIZE 4\n" + - "%define EXT_XRGB_RED 1\n" + - "%define EXT_XRGB_GREEN 2\n" + - "%define EXT_XRGB_BLUE 3\n" + - "%define EXT_XRGB_PIXELSIZE 4\n" + - "%define RGBX_FILLER_0XFF 1\n" + - "%define JSAMPLE byte ; unsigned char\n" + - "%define SIZEOF_JSAMPLE SIZEOF_BYTE ; sizeof(JSAMPLE)\n" + - "%define CENTERJSAMPLE 128\n" + - "%define JCOEF word ; short\n" + - "%define SIZEOF_JCOEF SIZEOF_WORD ; sizeof(JCOEF)\n" + - "%define JDIMENSION dword ; unsigned int\n" + - "%define SIZEOF_JDIMENSION SIZEOF_DWORD ; sizeof(JDIMENSION)\n" + - "%define JSAMPROW POINTER ; JSAMPLE * (jpeglib.h)\n" + - "%define JSAMPARRAY POINTER ; JSAMPROW * (jpeglib.h)\n" + - "%define JSAMPIMAGE POINTER ; JSAMPARRAY * (jpeglib.h)\n" + - "%define JCOEFPTR POINTER ; JCOEF * (jpeglib.h)\n" + - "%define SIZEOF_JSAMPROW SIZEOF_POINTER ; sizeof(JSAMPROW)\n" + - "%define SIZEOF_JSAMPARRAY SIZEOF_POINTER ; sizeof(JSAMPARRAY)\n" + - "%define SIZEOF_JSAMPIMAGE SIZEOF_POINTER ; sizeof(JSAMPIMAGE)\n" + - "%define SIZEOF_JCOEFPTR SIZEOF_POINTER ; sizeof(JCOEFPTR)\n" + - "%define DCTELEM word ; short\n" + - "%define SIZEOF_DCTELEM SIZEOF_WORD ; sizeof(DCTELEM)\n" + - "%define float FP32 ; float\n" + - "%define SIZEOF_FAST_FLOAT SIZEOF_FP32 ; sizeof(float)\n" + - "%define ISLOW_MULT_TYPE word ; must be short\n" + - "%define SIZEOF_ISLOW_MULT_TYPE SIZEOF_WORD ; sizeof(ISLOW_MULT_TYPE)\n" + - "%define IFAST_MULT_TYPE word ; must be short\n" + - "%define SIZEOF_IFAST_MULT_TYPE SIZEOF_WORD ; sizeof(IFAST_MULT_TYPE)\n" + - "%define IFAST_SCALE_BITS 2 ; fractional bits in scale factors\n" + - "%define FLOAT_MULT_TYPE FP32 ; must be float\n" + - "%define SIZEOF_FLOAT_MULT_TYPE SIZEOF_FP32 ; sizeof(FLOAT_MULT_TYPE)\n" + - "%define JSIMD_NONE 0x00\n" + - "%define JSIMD_MMX 0x01\n" + - "%define JSIMD_3DNOW 0x02\n" + - "%define JSIMD_SSE 0x04\n" + - "%define JSIMD_SSE2 0x08\n" + - "EOF", -) - -string_flag( - name = "noasm", - build_setting_default = "no", -) - -config_setting( - name = "nosimd", - flag_values = {":noasm": "yes"}, -) - -config_setting( - name = "k8", - flag_values = {":noasm": "no"}, - values = {"cpu": "k8"}, -) - -config_setting( - name = "android", - values = {"crosstool_top": "//external:android/crosstool"}, -) - -config_setting( - name = "armeabi-v7a", - flag_values = {":noasm": "no"}, - values = {"cpu": "armeabi-v7a"}, -) - -config_setting( - name = "arm64-v8a", - flag_values = {":noasm": "no"}, - values = {"cpu": "arm64-v8a"}, -) - -config_setting( - name = "windows", - flag_values = {":noasm": "no"}, - values = {"cpu": "x64_windows"}, -) - -config_setting( - name = "linux_ppc64le", - flag_values = {":noasm": "no"}, - values = {"cpu": "ppc"}, -) diff --git a/third_party/xla/third_party/tsl/third_party/jpeg/jpeg_helpers.BUILD.bazel b/third_party/xla/third_party/tsl/third_party/jpeg/jpeg_helpers.BUILD.bazel deleted file mode 100644 index 5b01f6e3e4cfd1..00000000000000 --- a/third_party/xla/third_party/tsl/third_party/jpeg/jpeg_helpers.BUILD.bazel +++ /dev/null @@ -1 +0,0 @@ -licenses(["notice"]) diff --git a/third_party/xla/third_party/tsl/third_party/jpeg/workspace.bzl b/third_party/xla/third_party/tsl/third_party/jpeg/workspace.bzl deleted file mode 100644 index 631cc933bc60d9..00000000000000 --- a/third_party/xla/third_party/tsl/third_party/jpeg/workspace.bzl +++ /dev/null @@ -1,13 +0,0 @@ -"""loads the jpeg library, used by TF.""" - -load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") - -def repo(): - tf_http_archive( - name = "libjpeg_turbo", - urls = tf_mirror_urls("https://github.com/libjpeg-turbo/libjpeg-turbo/archive/refs/tags/2.1.4.tar.gz"), - sha256 = "a78b05c0d8427a90eb5b4eb08af25309770c8379592bb0b8a863373128e6143f", - strip_prefix = "libjpeg-turbo-2.1.4", - build_file = "//third_party/jpeg:jpeg.BUILD", - system_build_file = "//third_party/jpeg:BUILD.system", - ) diff --git a/third_party/xla/third_party/tsl/third_party/llvm_openmp/BUILD b/third_party/xla/third_party/tsl/third_party/llvm_openmp/BUILD index 9c1e23f1f888c5..34ad101bd35036 100644 --- a/third_party/xla/third_party/tsl/third_party/llvm_openmp/BUILD +++ b/third_party/xla/third_party/tsl/third_party/llvm_openmp/BUILD @@ -19,19 +19,20 @@ load( load("@bazel_skylib//:bzl_library.bzl", "bzl_library") package( - default_visibility = ["//visibility:public"], + default_visibility = [ + "//visibility:public", + ], ) -exports_files( - ["LICENSE.txt"], - visibility = ["//visibility:public"], -) +exports_files(["LICENSE.txt"]) py_binary( name = "expand_cmake_vars", srcs = ["expand_cmake_vars.py"], srcs_version = "PY3", - visibility = ["//visibility:public"], + visibility = [ + "@llvm_openmp//:__subpackages__", + ], ) kmp_i18n_os_type = select({ @@ -239,5 +240,4 @@ if_windows(a = libiomp5_cc_binary( bzl_library( name = "openmp_bzl", srcs = ["openmp.bzl"], - visibility = ["//visibility:public"], ) diff --git a/third_party/xla/third_party/tsl/third_party/mkl/BUILD b/third_party/xla/third_party/tsl/third_party/mkl/BUILD index cfbdc211790d25..6da193d41ba067 100644 --- a/third_party/xla/third_party/tsl/third_party/mkl/BUILD +++ b/third_party/xla/third_party/tsl/third_party/mkl/BUILD @@ -7,67 +7,56 @@ package(default_visibility = ["//visibility:public"]) alias( name = "build_with_mkl", actual = "//tsl/mkl:build_with_mkl", - visibility = ["//visibility:public"], ) alias( name = "build_with_mkl_lnx_x64", actual = "//tsl/mkl:build_with_mkl_lnx_x64", - visibility = ["//visibility:public"], ) alias( name = "build_with_mkl_lnx_openmp", actual = "//tsl/mkl:build_with_mkl_lnx_openmp", - visibility = ["//visibility:public"], ) alias( name = "build_with_mkl_windows_openmp", actual = "//tsl/mkl:build_with_mkl_windows_openmp", - visibility = ["//visibility:public"], ) alias( name = "build_with_mkl_aarch64", actual = "//tsl/mkl:build_with_mkl_aarch64", - visibility = ["//visibility:public"], ) alias( name = "enable_mkl", actual = "//tsl/mkl:enable_mkl", - visibility = ["//visibility:public"], ) alias( name = "intel_binary_blob", actual = "//tsl/mkl:intel_binary_blob", - visibility = ["//visibility:public"], ) alias( name = "LICENSE", actual = "//tsl/mkl:LICENSE", - visibility = ["//visibility:public"], ) alias( name = "mkl_libs_linux", actual = "//tsl/mkl:mkl_libs_linux", - visibility = ["//visibility:public"], ) alias( name = "mkl_libs_darwin", actual = "//tsl/mkl:mkl_libs_darwin", - visibility = ["//visibility:public"], ) alias( name = "mkl_libs_windows", actual = "//tsl/mkl:mkl_libs_windows", - visibility = ["//visibility:public"], ) bzl_library( diff --git a/third_party/xla/third_party/tsl/third_party/mkl/build_defs.bzl b/third_party/xla/third_party/tsl/third_party/mkl/build_defs.bzl index 76bea5d8552b2e..6af999122a90f0 100644 --- a/third_party/xla/third_party/tsl/third_party/mkl/build_defs.bzl +++ b/third_party/xla/third_party/tsl/third_party/mkl/build_defs.bzl @@ -13,7 +13,7 @@ mkl_repository depends on the following environment variables: """ load( - "@local_tsl//tsl/mkl:build_defs.bzl", + "//tsl/mkl:build_defs.bzl", _if_enable_mkl = "if_enable_mkl", _if_mkl = "if_mkl", _if_mkl_lnx_x64 = "if_mkl_lnx_x64", diff --git a/third_party/xla/third_party/tsl/third_party/mkl_dnn/BUILD b/third_party/xla/third_party/tsl/third_party/mkl_dnn/BUILD index c536923f794b07..a9cdec9dd632aa 100644 --- a/third_party/xla/third_party/tsl/third_party/mkl_dnn/BUILD +++ b/third_party/xla/third_party/tsl/third_party/mkl_dnn/BUILD @@ -1,6 +1,7 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], licenses = ["notice"], ) @@ -16,7 +17,6 @@ config_setting( "build_with_mkl": "true", "build_with_mkl_opensource": "true", }, - visibility = ["//visibility:public"], ) config_setting( @@ -25,7 +25,6 @@ config_setting( "build_with_mkl": "true", "build_with_openmp": "true", }, - visibility = ["//visibility:public"], ) config_setting( @@ -34,7 +33,6 @@ config_setting( "build_with_mkl_aarch64": "true", "build_with_openmp": "true", }, - visibility = ["//visibility:public"], ) config_setting( @@ -42,11 +40,9 @@ config_setting( define_values = { "build_with_mkl_aarch64": "true", }, - visibility = ["//visibility:public"], ) bzl_library( name = "build_defs_bzl", srcs = ["build_defs.bzl"], - visibility = ["//visibility:public"], ) diff --git a/third_party/xla/third_party/tsl/third_party/mkl_dnn/mkldnn_v1.BUILD b/third_party/xla/third_party/tsl/third_party/mkl_dnn/mkldnn_v1.BUILD index a2528211936972..69843d482c80d7 100644 --- a/third_party/xla/third_party/tsl/third_party/mkl_dnn/mkldnn_v1.BUILD +++ b/third_party/xla/third_party/tsl/third_party/mkl_dnn/mkldnn_v1.BUILD @@ -95,7 +95,7 @@ expand_template( substitutions = { "@DNNL_VERSION_MAJOR@": "3", "@DNNL_VERSION_MINOR@": "3", - "@DNNL_VERSION_PATCH@": "0", + "@DNNL_VERSION_PATCH@": "4", "@DNNL_VERSION_HASH@": "N/A", }, template = "include/oneapi/dnnl/dnnl_version.h.in", diff --git a/third_party/xla/third_party/tsl/third_party/nvtx/BUILD b/third_party/xla/third_party/tsl/third_party/nvtx/BUILD index 48fb92f3cc5840..af6de99cb8fcf7 100644 --- a/third_party/xla/third_party/tsl/third_party/nvtx/BUILD +++ b/third_party/xla/third_party/tsl/third_party/nvtx/BUILD @@ -2,10 +2,7 @@ licenses(["notice"]) -exports_files( - ["LICENSE.txt"], - visibility = ["//visibility:public"], -) +exports_files(["LICENSE.txt"]) cc_library( name = "headers", diff --git a/third_party/xla/third_party/tsl/third_party/png.BUILD b/third_party/xla/third_party/tsl/third_party/png.BUILD deleted file mode 100644 index 95ea74b8a03c3c..00000000000000 --- a/third_party/xla/third_party/tsl/third_party/png.BUILD +++ /dev/null @@ -1,70 +0,0 @@ -# Description: -# libpng is the official PNG reference library. - -licenses(["notice"]) # BSD/MIT-like license - -exports_files(["LICENSE"]) - -cc_library( - name = "png", - srcs = [ - "png.c", - "pngdebug.h", - "pngerror.c", - "pngget.c", - "pnginfo.h", - "pnglibconf.h", - "pngmem.c", - "pngpread.c", - "pngpriv.h", - "pngread.c", - "pngrio.c", - "pngrtran.c", - "pngrutil.c", - "pngset.c", - "pngstruct.h", - "pngtrans.c", - "pngwio.c", - "pngwrite.c", - "pngwtran.c", - "pngwutil.c", - ] + select({ - ":windows": [ - "intel/filter_sse2_intrinsics.c", - "intel/intel_init.c", - ], - "@local_tsl//tsl:linux_ppc64le": [ - #"powerpc/filter_vsx_intrinsics.c", - #"powerpc/powerpc_init.c", - ], - "//conditions:default": [ - ], - }), - hdrs = [ - "png.h", - "pngconf.h", - ], - copts = select({ - ":windows": ["-DPNG_INTEL_SSE_OPT=1"], - "//conditions:default": [], - }), - includes = ["."], - linkopts = select({ - ":windows": [], - "//conditions:default": ["-lm"], - }), - visibility = ["//visibility:public"], - deps = ["@zlib"], -) - -genrule( - name = "snappy_stubs_public_h", - srcs = ["scripts/pnglibconf.h.prebuilt"], - outs = ["pnglibconf.h"], - cmd = "sed -e 's/PNG_ZLIB_VERNUM 0/PNG_ZLIB_VERNUM 0x12d0/' $< >$@", -) - -config_setting( - name = "windows", - values = {"cpu": "x64_windows"}, -) diff --git a/third_party/xla/third_party/tsl/third_party/png_fix_rpi.patch b/third_party/xla/third_party/tsl/third_party/png_fix_rpi.patch deleted file mode 100644 index df6cfd7ffaee55..00000000000000 --- a/third_party/xla/third_party/tsl/third_party/png_fix_rpi.patch +++ /dev/null @@ -1,16 +0,0 @@ -diff -r -u ./scripts/pnglibconf.h.prebuilt ./scripts/pnglibconf.h.prebuilt ---- ./scripts/pnglibconf.h.prebuilt -+++ ./scripts/pnglibconf.h.prebuilt -@@ -19,6 +19,12 @@ - #define PNG_ALIGNED_MEMORY_SUPPORTED - /*#undef PNG_ARM_NEON_API_SUPPORTED*/ - /*#undef PNG_ARM_NEON_CHECK_SUPPORTED*/ -+ -+/* Workaround not having a great build file by forcing -+ * png filter optimization to be disabled on arm */ -+#define PNG_ARM_NEON_OPT 0 -+ -+ - #define PNG_BENIGN_ERRORS_SUPPORTED - #define PNG_BENIGN_READ_ERRORS_SUPPORTED - /*#undef PNG_BENIGN_WRITE_ERRORS_SUPPORTED*/ diff --git a/third_party/xla/third_party/tsl/third_party/python_runtime/BUILD b/third_party/xla/third_party/tsl/third_party/python_runtime/BUILD index 14210ebf684fa9..2a1609191fe351 100644 --- a/third_party/xla/third_party/tsl/third_party/python_runtime/BUILD +++ b/third_party/xla/third_party/tsl/third_party/python_runtime/BUILD @@ -5,5 +5,4 @@ package(default_visibility = ["//visibility:public"]) alias( name = "headers", actual = "@local_config_python//:python_headers", - visibility = ["//visibility:public"], ) diff --git a/third_party/xla/third_party/tsl/third_party/tensorrt/plugin/BUILD b/third_party/xla/third_party/tsl/third_party/tensorrt/plugin/BUILD index 2c76d3db31dd70..56e26d779de155 100644 --- a/third_party/xla/third_party/tsl/third_party/tensorrt/plugin/BUILD +++ b/third_party/xla/third_party/tsl/third_party/tensorrt/plugin/BUILD @@ -3,10 +3,7 @@ # TensorRT open source repository. load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts", "cuda_library") -exports_files( - ["LICENSE"], - visibility = ["//visibility:public"], -) +exports_files(["LICENSE"]) cuda_library( name = "plugin_common", @@ -19,7 +16,6 @@ cuda_library( "plugin/common/plugin.h", ], strip_include_prefix = "plugin/common", - visibility = ["//visibility:public"], deps = [ "@local_config_tensorrt//:tensorrt", "@local_config_tensorrt//:tensorrt_headers", @@ -33,7 +29,6 @@ cc_library( "plugin/efficientNMSPlugin/efficientNMSParameters.h", "plugin/efficientNMSPlugin/efficientNMSPlugin.h", ], - visibility = ["//visibility:public"], ) cuda_library( diff --git a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl index 10a56bfa0edce1..99aa32a79a30aa 100644 --- a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl +++ b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "e99b8f121f63cdfae811b2cafc4dab5ce97986f6" - TFRT_SHA256 = "0e0ec61414532ec44f271ed7450253c462c2789d2a2c24c178e1377bef10f3da" + TFRT_COMMIT = "aec2070dee4792b80177d167f26491b1d30eced4" + TFRT_SHA256 = "0d398b68353ae8e547f4f974d43b3c29c9ce9cce535c66dba5efcd5bee4ad36d" tf_http_archive( name = "tf_runtime", diff --git a/third_party/xla/third_party/tsl/tools/def_file_filter/BUILD b/third_party/xla/third_party/tsl/tools/def_file_filter/BUILD index 8fa81a4475ac64..250f31a6beb7ae 100644 --- a/third_party/xla/third_party/tsl/tools/def_file_filter/BUILD +++ b/third_party/xla/third_party/tsl/tools/def_file_filter/BUILD @@ -11,5 +11,4 @@ package(default_visibility = ["//visibility:public"]) filegroup( name = "symbols_pybind", srcs = ["symbols_pybind.txt"], - visibility = ["//visibility:public"], ) diff --git a/third_party/xla/third_party/tsl/tools/toolchains/cpus/py/BUILD b/third_party/xla/third_party/tsl/tools/toolchains/cpus/py/BUILD index 54feb1695a21e3..1235988abb7fa9 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/cpus/py/BUILD +++ b/third_party/xla/third_party/tsl/tools/toolchains/cpus/py/BUILD @@ -22,7 +22,6 @@ cc_library( name = "python_headers", hdrs = [":python_include"], includes = ["python_include"], - visibility = ["//visibility:public"], deps = select({ ":windows": [":python_lib"], "//conditions:default": [], @@ -33,7 +32,6 @@ cc_library( name = "numpy_headers", hdrs = [":numpy_include"], includes = ["numpy_include"], - visibility = ["//visibility:public"], ) config_setting( diff --git a/third_party/xla/third_party/tsl/tools/toolchains/cpus/py3/BUILD b/third_party/xla/third_party/tsl/tools/toolchains/cpus/py3/BUILD index 5dc47b98284c89..d47256ebef88fa 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/cpus/py3/BUILD +++ b/third_party/xla/third_party/tsl/tools/toolchains/cpus/py3/BUILD @@ -22,7 +22,6 @@ cc_library( name = "python_headers", hdrs = [":python_include"], includes = ["python_include"], - visibility = ["//visibility:public"], deps = select({ ":windows": [":python_lib"], "//conditions:default": [], @@ -33,7 +32,6 @@ cc_library( name = "numpy_headers", hdrs = [":numpy_include"], includes = ["numpy_include"], - visibility = ["//visibility:public"], ) config_setting( diff --git a/third_party/xla/third_party/tsl/tools/toolchains/cross_compile/cc/BUILD b/third_party/xla/third_party/tsl/tools/toolchains/cross_compile/cc/BUILD index 976e57b777f15d..7cf6d8c3747b27 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/cross_compile/cc/BUILD +++ b/third_party/xla/third_party/tsl/tools/toolchains/cross_compile/cc/BUILD @@ -15,10 +15,7 @@ cc_toolchain_suite( }, ) -filegroup( - name = "empty", - visibility = ["//visibility:public"], -) +filegroup(name = "empty") # We define a wraper ("cc_wrapper.sh") around the compiler to replace all paths # in the binary (bazel-out/.../path/to/original/library.so) by the paths @@ -27,7 +24,6 @@ filegroup( filegroup( name = "cc_wrapper_and_macos_sysroot", srcs = ["cc_wrapper.sh"] + glob(["MacOSX.sdk/**"]), - visibility = ["//visibility:public"], ) cc_toolchain( diff --git a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl index aab6ac89e37c67..7bbfb8b2854ca4 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl +++ b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl @@ -627,7 +627,7 @@ def initialize_rbe_configs(): sigbuild_tf_configs( name_container_map = { - "sigbuild-r2.16": "docker://gcr.io/tensorflow-sigs/build@sha256:22d863e6fe3f98946015b9e1264b2eeb8e56e504535a6c1d5e564cae65ae5d37", + "sigbuild-r2.16": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", "sigbuild-r2.16-python3.9": "docker://gcr.io/tensorflow-sigs/build@sha256:22d863e6fe3f98946015b9e1264b2eeb8e56e504535a6c1d5e564cae65ae5d37", "sigbuild-r2.16-python3.10": "docker://gcr.io/tensorflow-sigs/build@sha256:da15288c8464153eadd35da720540a544b76aa9d78cceb42a6821b2f3e70a0fa", "sigbuild-r2.16-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", @@ -667,7 +667,7 @@ def initialize_rbe_configs(): sigbuild_tf_configs( name_container_map = { - "sigbuild-r2.16-clang": "docker://gcr.io/tensorflow-sigs/build@sha256:22d863e6fe3f98946015b9e1264b2eeb8e56e504535a6c1d5e564cae65ae5d37", + "sigbuild-r2.16-clang": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", "sigbuild-r2.16-clang-python3.9": "docker://gcr.io/tensorflow-sigs/build@sha256:22d863e6fe3f98946015b9e1264b2eeb8e56e504535a6c1d5e564cae65ae5d37", "sigbuild-r2.16-clang-python3.10": "docker://gcr.io/tensorflow-sigs/build@sha256:da15288c8464153eadd35da720540a544b76aa9d78cceb42a6821b2f3e70a0fa", "sigbuild-r2.16-clang-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", diff --git a/third_party/xla/third_party/tsl/tools/toolchains/win/bazel_211/BUILD b/third_party/xla/third_party/tsl/tools/toolchains/win/bazel_211/BUILD index 07aff97390d02e..cc23c8ecb22680 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/win/bazel_211/BUILD +++ b/third_party/xla/third_party/tsl/tools/toolchains/win/bazel_211/BUILD @@ -22,31 +22,26 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "malloc", - visibility = ["//visibility:public"], ) filegroup( name = "empty", srcs = [], - visibility = ["//visibility:public"], ) filegroup( name = "mingw_compiler_files", srcs = [":builtin_include_directory_paths_mingw"], - visibility = ["//visibility:public"], ) filegroup( name = "clangcl_compiler_files", srcs = [":builtin_include_directory_paths_clangcl"], - visibility = ["//visibility:public"], ) filegroup( name = "msvc_compiler_files", srcs = [":builtin_include_directory_paths_msvc"], - visibility = ["//visibility:public"], ) # Hardcoded toolchain, legacy behaviour. @@ -358,5 +353,4 @@ toolchain( filegroup( name = "link_dynamic_library", srcs = ["link_dynamic_library.sh"], - visibility = ["//visibility:public"], ) diff --git a/third_party/xla/third_party/tsl/tools/toolchains/win/tf_win_05022023/BUILD b/third_party/xla/third_party/tsl/tools/toolchains/win/tf_win_05022023/BUILD index 4d6c76ca644f07..f245f6d0789c9d 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/win/tf_win_05022023/BUILD +++ b/third_party/xla/third_party/tsl/tools/toolchains/win/tf_win_05022023/BUILD @@ -22,31 +22,26 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "malloc", - visibility = ["//visibility:public"], ) filegroup( name = "empty", srcs = [], - visibility = ["//visibility:public"], ) filegroup( name = "mingw_compiler_files", srcs = [":builtin_include_directory_paths_mingw"], - visibility = ["//visibility:public"], ) filegroup( name = "clangcl_compiler_files", srcs = [":builtin_include_directory_paths_clangcl"], - visibility = ["//visibility:public"], ) filegroup( name = "msvc_compiler_files", srcs = [":builtin_include_directory_paths_msvc"], - visibility = ["//visibility:public"], ) # Hardcoded toolchain, legacy behaviour. diff --git a/third_party/xla/third_party/tsl/tools/toolchains/win_1803/py38/BUILD b/third_party/xla/third_party/tsl/tools/toolchains/win_1803/py38/BUILD index 3efde36a05097c..9aa4d82e6daca6 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/win_1803/py38/BUILD +++ b/third_party/xla/third_party/tsl/tools/toolchains/win_1803/py38/BUILD @@ -39,7 +39,6 @@ cc_library( name = "python_headers", hdrs = [":python_include"], includes = ["python_include"], - visibility = ["//visibility:public"], deps = select({ ":windows": [":python_lib"], "//conditions:default": [], @@ -50,7 +49,6 @@ cc_library( name = "numpy_headers", hdrs = [":numpy_include"], includes = ["numpy_include"], - visibility = ["//visibility:public"], ) config_setting( diff --git a/third_party/xla/third_party/tsl/tools/toolchains/win_1803/py39/BUILD b/third_party/xla/third_party/tsl/tools/toolchains/win_1803/py39/BUILD index f3892df8ef4fae..f5b545cb161b3a 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/win_1803/py39/BUILD +++ b/third_party/xla/third_party/tsl/tools/toolchains/win_1803/py39/BUILD @@ -66,7 +66,6 @@ cc_library( name = "python_headers", hdrs = [":python_include"], includes = ["python_include"], - visibility = ["//visibility:public"], deps = select({ ":windows": [":python_lib"], "//conditions:default": [], @@ -77,7 +76,6 @@ cc_library( name = "numpy_headers", hdrs = [":numpy_include"], includes = ["numpy_include"], - visibility = ["//visibility:public"], ) config_setting( diff --git a/third_party/xla/third_party/tsl/tsl/BUILD b/third_party/xla/third_party/tsl/tsl/BUILD index c7cacce00c94dc..e41e9b1237faad 100644 --- a/third_party/xla/third_party/tsl/tsl/BUILD +++ b/third_party/xla/third_party/tsl/tsl/BUILD @@ -1,7 +1,7 @@ -load("tsl.bzl", "if_google", "if_oss") load("@bazel_skylib//lib:selects.bzl", "selects") load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "bool_setting") load("@bazel_skylib//:bzl_library.bzl", "bzl_library") +load("tsl.bzl", "if_google", "if_oss") # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) @@ -132,7 +132,6 @@ config_setting( "apple_platform_type": "macos", "cpu": "darwin", }, - visibility = ["//visibility:public"], ) config_setting( @@ -145,7 +144,6 @@ config_setting( "apple_platform_type": "macos", "cpu": "darwin_x86_64", }, - visibility = ["//visibility:public"], ) selects.config_setting_group( diff --git a/third_party/xla/third_party/tsl/tsl/c/BUILD b/third_party/xla/third_party/tsl/tsl/c/BUILD index d2997f86eea138..39019dffa29ee2 100644 --- a/third_party/xla/third_party/tsl/tsl/c/BUILD +++ b/third_party/xla/third_party/tsl/tsl/c/BUILD @@ -1,15 +1,14 @@ # Description: # C API for TensorFlow, for use by client language bindings. +load("//tsl:tsl.bzl", "internal_visibility", "tsl_copts", "tsl_gpu_library") load("//tsl/platform:build_config.bzl", "tsl_cc_test") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") -load("//tsl:tsl.bzl", "tsl_copts", "tsl_gpu_library") # buildifier: disable=same-origin-load load("//tsl:tsl.default.bzl", "filegroup") package( - default_visibility = ["//visibility:public"], # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) @@ -22,7 +21,7 @@ filegroup( srcs = [ "tsl_status.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow:__subpackages__"]), ) filegroup( @@ -36,7 +35,9 @@ filegroup( "*test*", ], ), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/c:__subpackages__", + ]), ) tsl_gpu_library( @@ -78,7 +79,6 @@ cc_library( tsl_cc_test( name = "tsl_status_test", srcs = ["tsl_status_test.cc"], - visibility = ["//visibility:public"], deps = [ ":tsl_status", ":tsl_status_internal", @@ -111,5 +111,7 @@ tsl_gpu_library( filegroup( name = "tsl_status_internal_headers", srcs = ["tsl_status_internal.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/c:__subpackages__", + ]), ) diff --git a/third_party/xla/third_party/tsl/tsl/concurrency/BUILD b/third_party/xla/third_party/tsl/tsl/concurrency/BUILD index 0cfb38d9697f61..4250b3db059897 100644 --- a/third_party/xla/third_party/tsl/tsl/concurrency/BUILD +++ b/third_party/xla/third_party/tsl/tsl/concurrency/BUILD @@ -3,6 +3,7 @@ load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load("//tsl/platform:build_config.bzl", "tsl_cc_test") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], licenses = ["notice"], ) @@ -19,7 +20,6 @@ cc_library( "chain.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":concurrent_vector", ":ref_count", @@ -35,7 +35,6 @@ cc_library( tsl_cc_test( name = "async_value_test", srcs = ["async_value_test.cc"], - visibility = ["//visibility:public"], deps = [ ":async_value", "//tsl/platform:test", @@ -46,7 +45,6 @@ tsl_cc_test( tsl_cc_test( name = "async_value_ref_test", srcs = ["async_value_ref_test.cc"], - visibility = ["//visibility:public"], deps = [ ":async_value", "//tsl/platform:test", @@ -58,7 +56,6 @@ cc_library( name = "concurrent_vector", hdrs = ["concurrent_vector.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", @@ -68,7 +65,6 @@ cc_library( tsl_cc_test( name = "concurrent_vector_test", srcs = ["concurrent_vector_test.cc"], - visibility = ["//visibility:public"], deps = [ ":concurrent_vector", "//tsl/platform:env", @@ -82,5 +78,4 @@ cc_library( name = "ref_count", hdrs = ["ref_count.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], ) diff --git a/third_party/xla/third_party/tsl/tsl/cuda/BUILD.bazel b/third_party/xla/third_party/tsl/tsl/cuda/BUILD.bazel index a62af76811a3e5..acb55221627a16 100644 --- a/third_party/xla/third_party/tsl/tsl/cuda/BUILD.bazel +++ b/third_party/xla/third_party/tsl/tsl/cuda/BUILD.bazel @@ -13,7 +13,6 @@ load( ) package( - default_visibility = ["//visibility:public"], licenses = ["notice"], ) diff --git a/third_party/xla/third_party/tsl/tsl/cuda/cudnn.symbols b/third_party/xla/third_party/tsl/tsl/cuda/cudnn.symbols index 2c4dbd71030b38..95c46295e1dcbb 100644 --- a/third_party/xla/third_party/tsl/tsl/cuda/cudnn.symbols +++ b/third_party/xla/third_party/tsl/tsl/cuda/cudnn.symbols @@ -3,6 +3,7 @@ cudnnActivationForward cudnnAddTensor cudnnAdvInferVersionCheck cudnnAdvTrainVersionCheck +cudnnAdvVersionCheck cudnnBackendCreateDescriptor cudnnBackendDestroyDescriptor cudnnBackendExecute @@ -20,6 +21,7 @@ cudnnCTCLoss cudnnCTCLoss_v8 cudnnCnnInferVersionCheck cudnnCnnTrainVersionCheck +cudnnCnnVersionCheck cudnnConvolutionBackwardBias cudnnConvolutionBackwardData cudnnConvolutionBackwardFilter @@ -175,6 +177,7 @@ cudnnGetTensorNdDescriptor cudnnGetTensorSizeInBytes cudnnGetTensorTransformDescriptor cudnnGetVersion +cudnnGraphVersionCheck cudnnIm2Col cudnnInitTransformDest cudnnLRNCrossChannelBackward @@ -189,6 +192,7 @@ cudnnNormalizationForwardTraining cudnnOpTensor cudnnOpsInferVersionCheck cudnnOpsTrainVersionCheck +cudnnOpsVersionCheck cudnnPoolingBackward cudnnPoolingForward cudnnQueryRuntimeError diff --git a/third_party/xla/third_party/tsl/tsl/distributed_runtime/BUILD b/third_party/xla/third_party/tsl/tsl/distributed_runtime/BUILD index 5ada05374fad08..479903d7841ca0 100644 --- a/third_party/xla/third_party/tsl/tsl/distributed_runtime/BUILD +++ b/third_party/xla/third_party/tsl/tsl/distributed_runtime/BUILD @@ -2,13 +2,17 @@ # Distributed runtime modules for machine learning, which allows coordination between multiple # processes for distributed operations. +load("//tsl:tsl.bzl", "internal_visibility") load( "@local_tsl//tsl/platform:rules_cc.bzl", "cc_library", ) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//tsl:internal", + ]), licenses = ["notice"], ) @@ -16,7 +20,6 @@ cc_library( name = "call_options", srcs = ["call_options.cc"], hdrs = ["call_options.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:macros", "//tsl/platform:mutex", @@ -30,5 +33,4 @@ filegroup( srcs = [ "call_options.h", ], - visibility = ["//visibility:public"], ) diff --git a/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/BUILD b/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/BUILD index 72cfcb34967870..68a79930df849b 100644 --- a/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/BUILD +++ b/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/BUILD @@ -1,16 +1,18 @@ -load("//tsl:tsl.bzl", "if_oss", "set_external_visibility", "tsl_gpu_library") +load("//tsl:tsl.bzl", "if_oss", "internal_visibility", "tsl_gpu_library") load("//tsl/platform:build_config.bzl", "tf_proto_library", "tsl_cc_test") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//tsl:internal", + ]), licenses = ["notice"], ) cc_library( name = "coordination_service_error_util", hdrs = ["coordination_service_error_util.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:errors", "//tsl/platform:status", @@ -22,7 +24,6 @@ cc_library( tsl_cc_test( name = "coordination_service_error_util_test", srcs = ["coordination_service_error_util_test.cc"], - visibility = ["//visibility:public"], deps = [ ":coordination_service_error_util", "//tsl/platform:errors", @@ -36,7 +37,6 @@ tsl_cc_test( cc_library( name = "coordination_client", hdrs = ["coordination_client.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/distributed_runtime:call_options", "//tsl/platform:status", @@ -47,7 +47,6 @@ cc_library( cc_library( name = "coordination_service", hdrs = ["coordination_service.h"], - visibility = ["//visibility:public"], deps = [ ":coordination_client", "//tsl/platform:status", @@ -64,7 +63,6 @@ cc_library( tsl_gpu_library( name = "coordination_service_impl", srcs = ["coordination_service.cc"], - visibility = ["//visibility:public"], deps = [ ":coordination_client", ":coordination_service", @@ -96,7 +94,6 @@ tf_proto_library( testonly = 1, srcs = ["test_device.proto"], cc_api_version = 2, - visibility = ["//visibility:public"], ) tsl_cc_test( @@ -106,7 +103,6 @@ tsl_cc_test( "manual", "no_oss", ]), # b/169705709, no protobuf matchers in OSS. - visibility = ["//visibility:public"], deps = [ ":coordination_client", ":coordination_service", @@ -137,7 +133,6 @@ tsl_gpu_library( name = "coordination_service_agent", srcs = ["coordination_service_agent.cc"], hdrs = ["coordination_service_agent.h"], - visibility = ["//visibility:public"], deps = [ ":coordination_client", ":coordination_service_error_util", @@ -164,7 +159,6 @@ tsl_gpu_library( tsl_cc_test( name = "coordination_service_agent_test", srcs = ["coordination_service_agent_test.cc"], - visibility = ["//visibility:public"], deps = [ ":coordination_client", ":coordination_service_agent", @@ -194,7 +188,6 @@ cc_library( hdrs = [ "coordination_service_rpc_handler.h", ], - visibility = ["//visibility:public"], deps = [ ":coordination_service", ":coordination_service_agent", @@ -214,7 +207,6 @@ cc_library( tsl_cc_test( name = "coordination_service_recoverable_job_test", srcs = ["coordination_service_recoverable_job_test.cc"], - visibility = ["//visibility:public"], deps = [ ":coordination_client", ":coordination_service", @@ -245,5 +237,4 @@ filegroup( "coordination_client.h", "coordination_service.h", ], - visibility = ["//visibility:public"], ) diff --git a/third_party/xla/third_party/tsl/tsl/distributed_runtime/preemption/BUILD b/third_party/xla/third_party/tsl/tsl/distributed_runtime/preemption/BUILD index 99df97792f73a0..aca10c8ca69783 100644 --- a/third_party/xla/third_party/tsl/tsl/distributed_runtime/preemption/BUILD +++ b/third_party/xla/third_party/tsl/tsl/distributed_runtime/preemption/BUILD @@ -1,10 +1,13 @@ -load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") -load("//tsl/platform:build_config.bzl", "tsl_cc_test") +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl:tsl.default.bzl", "get_compatible_with_portable", "tsl_grpc_cc_dependencies") -load("//tsl:tsl.bzl", "set_external_visibility") +load("//tsl/platform:build_config.bzl", "tsl_cc_test") +load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//tsl:internal", + ]), licenses = ["notice"], ) @@ -13,7 +16,6 @@ cc_library( srcs = ["preemption_notifier.cc"], hdrs = ["preemption_notifier.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "//tsl/platform:env", "//tsl/platform:errors", @@ -29,7 +31,6 @@ tsl_cc_test( name = "preemption_notifier_test", size = "small", srcs = ["preemption_notifier_test.cc"], - visibility = ["//visibility:public"], deps = [ ":preemption_notifier", "//tsl/platform:env", @@ -49,7 +50,6 @@ cc_library( name = "preemption_sync_manager", srcs = ["preemption_sync_manager.cc"], hdrs = ["preemption_sync_manager.h"], - visibility = ["//visibility:public"], deps = [ ":preemption_notifier", "//tsl/distributed_runtime:call_options", @@ -70,7 +70,6 @@ tsl_cc_test( name = "preemption_sync_manager_test", size = "small", srcs = ["preemption_sync_manager_test.cc"], - visibility = ["//visibility:public"], deps = [ ":preemption_notifier", ":preemption_sync_manager", diff --git a/third_party/xla/third_party/tsl/tsl/distributed_runtime/preemption/preemption_notifier.cc b/third_party/xla/third_party/tsl/tsl/distributed_runtime/preemption/preemption_notifier.cc index f9838f2823b250..56226a85896e5f 100644 --- a/third_party/xla/third_party/tsl/tsl/distributed_runtime/preemption/preemption_notifier.cc +++ b/third_party/xla/third_party/tsl/tsl/distributed_runtime/preemption/preemption_notifier.cc @@ -94,13 +94,14 @@ void SigtermNotifier::StartListenerThread() { } // namespace -StatusOr PreemptionNotifier::WillBePreemptedAt() { +absl::StatusOr PreemptionNotifier::WillBePreemptedAt() { absl::Notification n; - StatusOr result; - WillBePreemptedAtAsync([&n, &result](StatusOr async_result) { - result = async_result; - n.Notify(); - }); + absl::StatusOr result; + WillBePreemptedAtAsync( + [&n, &result](absl::StatusOr async_result) { + result = async_result; + n.Notify(); + }); n.WaitForNotification(); return result; } @@ -117,7 +118,7 @@ void PreemptionNotifier::WillBePreemptedAtAsync(PreemptTimeCallback callback) { } void PreemptionNotifier::NotifyRegisteredListeners( - StatusOr death_time) { + absl::StatusOr death_time) { mutex_lock l(mu_); if (death_time.ok()) { death_time_ = death_time.value(); diff --git a/third_party/xla/third_party/tsl/tsl/distributed_runtime/preemption/preemption_notifier.h b/third_party/xla/third_party/tsl/tsl/distributed_runtime/preemption/preemption_notifier.h index 53941ceea6493e..075af20fcd3346 100644 --- a/third_party/xla/third_party/tsl/tsl/distributed_runtime/preemption/preemption_notifier.h +++ b/third_party/xla/third_party/tsl/tsl/distributed_runtime/preemption/preemption_notifier.h @@ -75,7 +75,7 @@ namespace tsl { class PreemptionNotifier { public: - typedef std::function)> PreemptTimeCallback; + typedef std::function)> PreemptTimeCallback; using PreemptionNotifierFactory = std::function(Env* env)>; @@ -112,7 +112,7 @@ class PreemptionNotifier { // termination will occur once the listener receives the preemption // notification. If no death time is specified, absl::Now() is returned. // Returns error::Cancelled if UnregisterListeners() is called. - StatusOr WillBePreemptedAt(); + absl::StatusOr WillBePreemptedAt(); // Registers a callback that takes the death time as input once the listener // receives the preemption notification. @@ -126,7 +126,7 @@ class PreemptionNotifier { Env* GetEnv() { return env_; } // Invokes all pending callbacks upon receipt of preemption notice with death // time or errors (e.g. cancellation during shutdown). - void NotifyRegisteredListeners(StatusOr death_time); + void NotifyRegisteredListeners(absl::StatusOr death_time); private: static std::unordered_map* diff --git a/third_party/xla/third_party/tsl/tsl/distributed_runtime/preemption/preemption_notifier_test.cc b/third_party/xla/third_party/tsl/tsl/distributed_runtime/preemption/preemption_notifier_test.cc index abd1d24c9f51e4..d083e2ef1ba2ab 100644 --- a/third_party/xla/third_party/tsl/tsl/distributed_runtime/preemption/preemption_notifier_test.cc +++ b/third_party/xla/third_party/tsl/tsl/distributed_runtime/preemption/preemption_notifier_test.cc @@ -59,7 +59,7 @@ TEST_F(PreemptNotifierTest, WillBePreemptedAt) { []() { std::raise(SIGTERM); }); // Preempt time should be current timestamp. - StatusOr result = preempt_notifier->WillBePreemptedAt(); + absl::StatusOr result = preempt_notifier->WillBePreemptedAt(); TF_CHECK_OK(result.status()); absl::Time preempt_time = result.value(); @@ -84,7 +84,7 @@ TEST_F(PreemptNotifierTest, env->SleepForMicroseconds(absl::ToInt64Microseconds(absl::Seconds(2))); // Preempt time should be current timestamp. - StatusOr result = preempt_notifier->WillBePreemptedAt(); + absl::StatusOr result = preempt_notifier->WillBePreemptedAt(); TF_CHECK_OK(result.status()); absl::Time preempt_time = result.value(); @@ -105,17 +105,17 @@ TEST_F(PreemptNotifierTest, WillBePreemptedAtAsync_SameResultForAllCallbacks) { []() { std::raise(SIGTERM); }); // Preempt time should be current timestamp. - StatusOr preempt_time; - StatusOr preempt_time_2; + absl::StatusOr preempt_time; + absl::StatusOr preempt_time_2; absl::Notification n; absl::Notification n_2; preempt_notifier->WillBePreemptedAtAsync( - [&preempt_time, &n](StatusOr result) { + [&preempt_time, &n](absl::StatusOr result) { preempt_time = result; n.Notify(); }); preempt_notifier->WillBePreemptedAtAsync( - [&preempt_time_2, &n_2](StatusOr result) { + [&preempt_time_2, &n_2](absl::StatusOr result) { preempt_time_2 = result; n_2.Notify(); }); @@ -135,7 +135,7 @@ TEST_F(PreemptNotifierTest, Reset_TwoDifferentPreemptTimesRecorded) { // Raise first signal. std::raise(SIGTERM); - StatusOr result = preempt_notifier->WillBePreemptedAt(); + absl::StatusOr result = preempt_notifier->WillBePreemptedAt(); TF_CHECK_OK(result.status()); absl::Time preempt_time = result.value(); @@ -154,10 +154,10 @@ TEST_F(PreemptNotifierTest, DestructorCancelsPendingCalls) { auto env = Env::Default(); std::unique_ptr preempt_notifier = PreemptionNotifier::CreatePreemptionNotifier("sigterm", env); - StatusOr result; + absl::StatusOr result; absl::Notification n; preempt_notifier->WillBePreemptedAtAsync( - [&result, &n](StatusOr status_or_time) { + [&result, &n](absl::StatusOr status_or_time) { result = status_or_time; n.Notify(); }); diff --git a/third_party/xla/third_party/tsl/tsl/distributed_runtime/preemption/preemption_sync_manager.cc b/third_party/xla/third_party/tsl/tsl/distributed_runtime/preemption/preemption_sync_manager.cc index a4ca1ac9159dee..ef501c692a7726 100644 --- a/third_party/xla/third_party/tsl/tsl/distributed_runtime/preemption/preemption_sync_manager.cc +++ b/third_party/xla/third_party/tsl/tsl/distributed_runtime/preemption/preemption_sync_manager.cc @@ -73,11 +73,12 @@ class PreemptionSyncManagerImpl : public PreemptionSyncManager { ~PreemptionSyncManagerImpl() override { shutdown_.Notify(); } - Status Initialize(CoordinationServiceAgent* agent) override; - Status Initialize(CoordinationServiceAgent* agent, - const std::string& preemption_notifier_type) override; - Status Initialize(CoordinationServiceAgent* agent, - std::unique_ptr notifier) override; + absl::Status Initialize(CoordinationServiceAgent* agent) override; + absl::Status Initialize(CoordinationServiceAgent* agent, + const std::string& preemption_notifier_type) override; + absl::Status Initialize( + CoordinationServiceAgent* agent, + std::unique_ptr notifier) override; bool ReachedSyncPoint(int step_counter) override; private: @@ -103,11 +104,12 @@ class PreemptionSyncManagerImpl : public PreemptionSyncManager { std::shared_ptr call_opts_; }; -Status PreemptionSyncManagerImpl::Initialize(CoordinationServiceAgent* agent) { +absl::Status PreemptionSyncManagerImpl::Initialize( + CoordinationServiceAgent* agent) { return Initialize(agent, "sigterm"); } -Status PreemptionSyncManagerImpl::Initialize( +absl::Status PreemptionSyncManagerImpl::Initialize( CoordinationServiceAgent* agent, const std::string& preemption_notifier_type) { TF_ASSIGN_OR_RETURN(Env * env, agent->GetEnv()); @@ -115,7 +117,7 @@ Status PreemptionSyncManagerImpl::Initialize( preemption_notifier_type, env)); } -Status PreemptionSyncManagerImpl::Initialize( +absl::Status PreemptionSyncManagerImpl::Initialize( CoordinationServiceAgent* agent, std::unique_ptr notifier) { TF_ASSIGN_OR_RETURN(Env * env, agent->GetEnv()); @@ -147,8 +149,8 @@ Status PreemptionSyncManagerImpl::Initialize( } notified_metric->GetCell()->Set(true); // Notify coordination service about preemption notice. - const Status s = agent->InsertKeyValue(kPreemptionNoticeKey, - absl::FormatTime(*death_time)); + const absl::Status s = agent->InsertKeyValue( + kPreemptionNoticeKey, absl::FormatTime(*death_time)); LOG(INFO) << "Notified coordination service that this task will " "be preempted at " << *death_time << ". Status: " << s; @@ -177,7 +179,7 @@ Status PreemptionSyncManagerImpl::Initialize( // CancelPreemptionBarrier() cannot be used because this may be // triggered after preemption sync manager has been destroyed. agent->CancelBarrierAsync( - kPreemptionBarrier, [](const Status& status) { + kPreemptionBarrier, [](const absl::Status& status) { if (!status.ok()) { LOG(ERROR) << "Failed to cancel preemption barrier: " << status; @@ -205,7 +207,7 @@ Status PreemptionSyncManagerImpl::Initialize( death_time))); }); - return OkStatus(); + return absl::OkStatus(); } void PreemptionSyncManagerImpl::ComputeSyncCallCounter(absl::Time death_time) { @@ -231,7 +233,7 @@ void PreemptionSyncManagerImpl::ComputeSyncCallCounter(absl::Time death_time) { // `preemption_sync_counter_` or the protocol failed. This ensures correctness // of the preemption sync protocol. mutex_lock l(mu_); - const Status notified_status = agent_->InsertKeyValue( + const absl::Status notified_status = agent_->InsertKeyValue( current_call_counter_key_, std::to_string(call_counter_)); if (!notified_status.ok()) { LOG(ERROR) << "Preemption sync failed - could not inform service of " @@ -243,7 +245,7 @@ void PreemptionSyncManagerImpl::ComputeSyncCallCounter(absl::Time death_time) { // 3. Impose a barrier to wait until everybody sends their current call // counter. - const Status barrier_status = + const absl::Status barrier_status = agent_->WaitAtBarrier(kPreemptionBarrier, kPreemptionBarrierTimeout, {}); if (!barrier_status.ok()) { LOG(ERROR) << "Preemption sync barrier failed: " << barrier_status; @@ -287,11 +289,12 @@ void PreemptionSyncManagerImpl::ComputeSyncCallCounter(absl::Time death_time) { } void PreemptionSyncManagerImpl::CancelPreemptionBarrier() { - agent_->CancelBarrierAsync(kPreemptionBarrier, [](const Status& status) { - if (!status.ok()) { - LOG(ERROR) << "Failed to cancel preemption barrier: " << status; - } - }); + agent_->CancelBarrierAsync( + kPreemptionBarrier, [](const absl::Status& status) { + if (!status.ok()) { + LOG(ERROR) << "Failed to cancel preemption barrier: " << status; + } + }); } bool PreemptionSyncManagerImpl::ReachedSyncPoint(int step_counter) { diff --git a/third_party/xla/third_party/tsl/tsl/distributed_runtime/preemption/preemption_sync_manager.h b/third_party/xla/third_party/tsl/tsl/distributed_runtime/preemption/preemption_sync_manager.h index 2c359b686ffc58..baf1911cac2d6d 100644 --- a/third_party/xla/third_party/tsl/tsl/distributed_runtime/preemption/preemption_sync_manager.h +++ b/third_party/xla/third_party/tsl/tsl/distributed_runtime/preemption/preemption_sync_manager.h @@ -35,11 +35,13 @@ class PreemptionSyncManager { public: virtual ~PreemptionSyncManager() = default; - virtual Status Initialize(CoordinationServiceAgent* agent) = 0; - virtual Status Initialize(CoordinationServiceAgent* agent, - const std::string& preemption_notifier_type) = 0; - virtual Status Initialize(CoordinationServiceAgent* agent, - std::unique_ptr notifier) = 0; + virtual absl::Status Initialize(CoordinationServiceAgent* agent) = 0; + virtual absl::Status Initialize( + CoordinationServiceAgent* agent, + const std::string& preemption_notifier_type) = 0; + virtual absl::Status Initialize( + CoordinationServiceAgent* agent, + std::unique_ptr notifier) = 0; // Check if the synchronized point has been reached. When a task has been // preempted, a safe sync point will be determined by using the fastest task's diff --git a/third_party/xla/third_party/tsl/tsl/distributed_runtime/preemption/preemption_sync_manager_test.cc b/third_party/xla/third_party/tsl/tsl/distributed_runtime/preemption/preemption_sync_manager_test.cc index 4caed02d705ad2..82d578c2b9658c 100644 --- a/third_party/xla/third_party/tsl/tsl/distributed_runtime/preemption/preemption_sync_manager_test.cc +++ b/third_party/xla/third_party/tsl/tsl/distributed_runtime/preemption/preemption_sync_manager_test.cc @@ -158,7 +158,7 @@ class PreemptionSyncManagerTest : public ::testing::Test { std::unique_ptr coord_client2 = absl::WrapUnique(NewGrpcCoordinationClient( grpc_server_->InProcessChannel(::grpc::ChannelArguments()))); - auto error_fn = [](const Status& status) { + auto error_fn = [](const absl::Status& status) { LOG(ERROR) << "Coordination service agent in error status: " << status; }; CoordinationServiceConfig coord_config; diff --git a/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/BUILD b/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/BUILD index 6ef4e72e665a6e..9effe4beb3e2e8 100644 --- a/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/BUILD +++ b/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/BUILD @@ -1,13 +1,16 @@ # Description: # RPC communication interfaces and implementations for TensorFlow. -load("//tsl:tsl.bzl", "set_external_visibility") +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl:tsl.default.bzl", "tsl_grpc_cc_dependencies") load("//tsl/platform:build_config.bzl", "tf_proto_library", "tsl_cc_test") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//tsl:internal", + ]), licenses = ["notice"], ) @@ -15,7 +18,6 @@ cc_library( name = "async_service_interface", srcs = [], hdrs = ["async_service_interface.h"], - visibility = ["//visibility:public"], deps = [], ) @@ -23,7 +25,6 @@ cc_library( name = "grpc_call", srcs = [], hdrs = ["grpc_call.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:mutex", "//tsl/platform:refcount", @@ -34,7 +35,6 @@ cc_library( name = "grpc_util", srcs = ["grpc_util.cc"], hdrs = ["grpc_util.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:protobuf", "//tsl/platform:status", @@ -53,7 +53,6 @@ tsl_cc_test( tags = [ "no_mac", ], - visibility = ["//visibility:public"], deps = [ ":grpc_util", ":test_request_proto_cc_impl", @@ -69,7 +68,6 @@ tsl_cc_test( cc_library( name = "grpc_channel_common", hdrs = ["grpc_channel_common.h"], - visibility = ["//visibility:public"], deps = [ ":grpc_util", "//tsl/platform:logging", @@ -82,7 +80,6 @@ cc_library( name = "grpc_channel", srcs = ["grpc_channel.cc"], hdrs = ["grpc_channel.h"], - visibility = ["//visibility:public"], deps = [ ":grpc_channel_common", ":grpc_util", @@ -109,7 +106,6 @@ tsl_cc_test( srcs = [ "grpc_channel_test.cc", ], - visibility = ["//visibility:public"], deps = [ ":grpc_channel", "//tsl/lib/core:status_test_util", @@ -125,7 +121,6 @@ tsl_cc_test( cc_library( name = "grpc_state", hdrs = ["grpc_state.h"], - visibility = ["//visibility:public"], deps = [ ":grpc_client_cq_tag", ":grpc_util", @@ -143,7 +138,6 @@ cc_library( name = "grpc_client_cq_tag", srcs = [], hdrs = ["grpc_client_cq_tag.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:macros", ], @@ -154,7 +148,6 @@ tf_proto_library( testonly = 1, srcs = ["test_request.proto"], create_java_proto = False, - visibility = ["//visibility:public"], ) filegroup( @@ -163,5 +156,4 @@ filegroup( "**/*.cc", "**/*.h", ]), - visibility = ["//visibility:public"], ) diff --git a/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/coordination/BUILD b/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/coordination/BUILD index 6bf8ca8a9874f1..ba8ec2333d197d 100644 --- a/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/coordination/BUILD +++ b/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/coordination/BUILD @@ -1,9 +1,12 @@ -load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl:tsl.default.bzl", "tsl_grpc_cc_dependencies") -load("//tsl:tsl.bzl", "set_external_visibility") +load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//tsl:internal", + ]), licenses = ["notice"], ) @@ -11,7 +14,6 @@ cc_library( name = "grpc_coordination_client", srcs = ["grpc_coordination_client.cc"], hdrs = ["grpc_coordination_client.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/distributed_runtime:call_options", "//tsl/distributed_runtime/coordination:coordination_client", @@ -31,7 +33,6 @@ cc_library( name = "grpc_coordination_service_impl", srcs = ["grpc_coordination_service_impl.cc"], hdrs = ["grpc_coordination_service_impl.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/distributed_runtime/coordination:coordination_service_agent", "//tsl/distributed_runtime/coordination:coordination_service_rpc_handler", diff --git a/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_channel.cc b/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_channel.cc index ba886c1a7bf1d3..492c984e12f13a 100644 --- a/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_channel.cc +++ b/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_channel.cc @@ -49,10 +49,10 @@ string MakeAddress(const string& job, int replica, int task) { } // Allows the host to be a raw IP (either v4 or v6). -Status ValidateHostPortPair(const string& host_port) { +absl::Status ValidateHostPortPair(const string& host_port) { string bns_prefix = "/bns/"; if (host_port.substr(0, bns_prefix.length()) == bns_prefix) { - return OkStatus(); + return absl::OkStatus(); } uint32 port; auto colon_index = host_port.find_last_of(':'); @@ -61,7 +61,7 @@ Status ValidateHostPortPair(const string& host_port) { return errors::InvalidArgument("Could not interpret \"", host_port, "\" as a host-port pair."); } - return OkStatus(); + return absl::OkStatus(); } ::grpc::ChannelArguments* CreateDefaultChannelArguments() { @@ -140,21 +140,22 @@ ::grpc::ChannelArguments GetChannelArguments(const RPCOptions* rpc_options) { return args; } -Status NewHostPortGrpcChannel(const string& target, - const RPCOptions* rpc_options, - SharedGrpcChannelPtr* channel_pointer) { +absl::Status NewHostPortGrpcChannel(const string& target, + const RPCOptions* rpc_options, + SharedGrpcChannelPtr* channel_pointer) { // Minimally ensure that the target is valid TF_RETURN_IF_ERROR(ValidateHostPortPair(target)); ::grpc::ChannelArguments args = GetChannelArguments(rpc_options); *channel_pointer = ::grpc::CreateCustomChannel( "dns:///" + target, ::grpc::InsecureChannelCredentials(), args); - return OkStatus(); + return absl::OkStatus(); } ChannelCreationFunction ConvertToChannelCreationFunction( - const std::function& new_channel_func_ptr) { + const std::function& + new_channel_func_ptr) { return [new_channel_func_ptr](const string& target) -> SharedGrpcChannelPtr { SharedGrpcChannelPtr channel_ptr; if (new_channel_func_ptr(target, /*rpc_options=*/nullptr, &channel_ptr) @@ -166,7 +167,7 @@ ChannelCreationFunction ConvertToChannelCreationFunction( }; } -Status GrpcChannelSpec::AddHostPortsJob( +absl::Status GrpcChannelSpec::AddHostPortsJob( const string& job_id, const std::map& host_ports) { if (!job_ids_.insert(job_id).second) { return errors::InvalidArgument( @@ -176,7 +177,7 @@ Status GrpcChannelSpec::AddHostPortsJob( TF_RETURN_IF_ERROR(ValidateHostPortPair(id_host_port.second)); } host_ports_jobs_.emplace_back(job_id, host_ports); - return OkStatus(); + return absl::OkStatus(); } namespace { diff --git a/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_channel.h b/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_channel.h index b019377f9986dd..654e7aa91c3218 100644 --- a/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_channel.h +++ b/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_channel.h @@ -43,8 +43,8 @@ class GrpcChannelSpec { const std::map host_ports; }; - Status AddHostPortsJob(const string& job_id, - const std::map& host_ports); + absl::Status AddHostPortsJob(const string& job_id, + const std::map& host_ports); const std::vector& host_ports_jobs() const { return host_ports_jobs_; @@ -88,12 +88,13 @@ GrpcChannelCache* NewGrpcChannelCache( ::grpc::ChannelArguments GetChannelArguments(const RPCOptions* rpc_options); ChannelCreationFunction ConvertToChannelCreationFunction( - const std::function& new_channel_func_ptr); + const std::function& + new_channel_func_ptr); -Status NewHostPortGrpcChannel(const string& target, - const RPCOptions* rpc_options, - SharedGrpcChannelPtr* channel_pointer); +absl::Status NewHostPortGrpcChannel(const string& target, + const RPCOptions* rpc_options, + SharedGrpcChannelPtr* channel_pointer); } // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_state.h b/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_state.h index 37b41edc0a0103..893e1b0192f694 100644 --- a/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_state.h +++ b/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_state.h @@ -149,7 +149,7 @@ class RPCState : public GrpcClientCQTag { VLOG(2) << "Completed call: " << method_; - Status s = FromGrpcStatus(status_); + absl::Status s = FromGrpcStatus(status_); if (s.ok() && !ok) { // Since this function is only being used for processing the response // to Finish for client-side unary calls, ok should never be false @@ -206,7 +206,7 @@ class RPCState : public GrpcClientCQTag { } void ParseAndCallDone() { - Status s; + absl::Status s; if (!parse_proto_fn_(&response_buf_, response_)) { s.Update(errors::Internal("could not parse rpc response")); } diff --git a/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_util.h b/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_util.h index c1cce692b2a197..b10fff85a003e0 100644 --- a/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_util.h +++ b/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_util.h @@ -52,7 +52,7 @@ inline bool IsStreamRemovedError(const ::grpc::Status& s) { s.error_message() == kStreamRemovedMessage; } -inline std::string SerializePayloads(const Status& s) { +inline std::string SerializePayloads(const absl::Status& s) { tensorflow::distributed_runtime::GrpcPayloadContainer container; s.ForEachPayload([&container](StringPiece key, const absl::Cord& value) { (*container.mutable_payloads())[std::string(key)] = std::string(value); @@ -60,7 +60,7 @@ inline std::string SerializePayloads(const Status& s) { return container.SerializeAsString(); } -inline void InsertSerializedPayloads(Status& s, std::string payloads) { +inline void InsertSerializedPayloads(absl::Status& s, std::string payloads) { tensorflow::distributed_runtime::GrpcPayloadContainer container; if (container.ParseFromString(payloads)) { for (const auto& key_val : container.payloads()) { @@ -73,24 +73,25 @@ inline void InsertSerializedPayloads(Status& s, std::string payloads) { } } -inline Status FromGrpcStatus(const ::grpc::Status& s) { +inline absl::Status FromGrpcStatus(const ::grpc::Status& s) { if (s.ok()) { - return OkStatus(); + return absl::OkStatus(); } else { - Status converted; + absl::Status converted; // Convert "UNKNOWN" stream removed errors into unavailable, to allow // for retry upstream. if (IsStreamRemovedError(s)) { - converted = Status(absl::StatusCode::kUnavailable, s.error_message()); + converted = + absl::Status(absl::StatusCode::kUnavailable, s.error_message()); } - converted = Status(static_cast(s.error_code()), - s.error_message()); + converted = absl::Status(static_cast(s.error_code()), + s.error_message()); InsertSerializedPayloads(converted, s.error_details()); return converted; } } -inline ::grpc::Status ToGrpcStatus(const Status& s) { +inline ::grpc::Status ToGrpcStatus(const absl::Status& s) { if (s.ok()) { return ::grpc::Status::OK; } else { diff --git a/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_util_test.cc b/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_util_test.cc index 2d5554b2c3c19c..9872c1cf705a0f 100644 --- a/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_util_test.cc +++ b/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_util_test.cc @@ -71,9 +71,9 @@ TestRequest MakeProto(int size) { } TEST(PayloadSerialization, PayloadsAreTransmitted) { - Status status = errors::InvalidArgument("invalid arg message"); + absl::Status status = errors::InvalidArgument("invalid arg message"); status.SetPayload("a", absl::Cord("\\xFF\\x02\\x03")); - Status status_recovered = FromGrpcStatus(ToGrpcStatus(status)); + absl::Status status_recovered = FromGrpcStatus(ToGrpcStatus(status)); ASSERT_TRUE(status_recovered.GetPayload("a").has_value()); EXPECT_EQ(status_recovered.GetPayload("a").value(), "\\xFF\\x02\\x03"); @@ -84,7 +84,7 @@ TEST(PayloadSerialization, PayloadsCorrupted) { ::grpc::StatusCode::INVALID_ARGUMENT, "invalid arg message", "string that can not be serialized to the GrpcPayloadContainer proto"); - Status converted = FromGrpcStatus(status); + absl::Status converted = FromGrpcStatus(status); EXPECT_TRUE(converted.GetPayload(kGrpcPayloadsLost).has_value()); } diff --git a/third_party/xla/third_party/tsl/tsl/framework/BUILD b/third_party/xla/third_party/tsl/tsl/framework/BUILD index 3e23aa05cbd3d9..1706df89dc563c 100644 --- a/third_party/xla/third_party/tsl/tsl/framework/BUILD +++ b/third_party/xla/third_party/tsl/tsl/framework/BUILD @@ -4,8 +4,12 @@ # The libraries in this package are not allowed to have ANY dependencies # to other TF components outside of TSL. -load("//tsl:tsl.bzl", "set_external_visibility") +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") +load( + "//tsl/platform:build_config.bzl", + "tsl_cc_test", +) load( "//tsl/platform:build_config_root.bzl", "if_static", @@ -14,13 +18,12 @@ load( "@local_tsl//tsl/platform:rules_cc.bzl", "cc_library", ) -load( - "//tsl/platform:build_config.bzl", - "tsl_cc_test", -) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//visibility:public", + ], licenses = ["notice"], ) @@ -37,7 +40,6 @@ filegroup( "tracking_allocator.h", "type_traits.h", ], - visibility = ["//visibility:public"], ) # Files needed for core:mobile_srcs_no_runtime. @@ -56,7 +58,6 @@ filegroup( "tracking_allocator.h", "type_traits.h", ], - visibility = ["//visibility:public"], ) # Files needed for core:mobile_srcs_only_runtime. @@ -69,7 +70,6 @@ filegroup( "metrics.cc", "metrics.h", ], - visibility = ["//visibility:public"], ) filegroup( @@ -82,14 +82,13 @@ filegroup( "tracking_allocator.h", "type_traits.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow/core:__subpackages__"]), ) # Files needed for tf2xla build. filegroup( name = "xla_cpu_runtime_hdrs", srcs = ["fixedpoint_types.h"], - visibility = ["//visibility:public"], ) # Individual targets. These should be preferred over tensorflow/core:framework @@ -155,7 +154,11 @@ cc_library( "cpu_allocator_impl.cc", "tracking_allocator.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "@local_xla//xla:__subpackages__", + "//tensorflow/core:__subpackages__", + "//tsl:__subpackages__", + ]), deps = [ ":numeric_types", ":type_traits", @@ -228,7 +231,6 @@ cc_library( "device_id.h", "device_id_manager.h", ], - visibility = ["//visibility:public"], deps = [ "//tsl/lib/gtl:int_type", ] + if_static([ @@ -243,7 +245,6 @@ cc_library( "device_id.h", "device_id_manager.h", ], - visibility = ["//visibility:public"], deps = [ ":device_type", "//tsl/lib/gtl:int_type", @@ -265,7 +266,6 @@ cc_library( hdrs = [ "device_id_utils.h", ], - visibility = ["//visibility:public"], deps = [ ":device_id_impl", ":device_type", @@ -286,13 +286,15 @@ filegroup( "device_id.h", "device_id_manager.h", ], - visibility = ["//visibility:public"], ) cc_library( name = "numeric_types", hdrs = ["numeric_types.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/compiler:__subpackages__", + "//tensorflow/core:__subpackages__", + ]), deps = [ ":fixedpoint_types", "//tsl/platform:types", @@ -313,7 +315,6 @@ cc_library( name = "metrics", srcs = ["metrics.cc"], hdrs = ["metrics.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/lib/monitoring:counter", ], @@ -332,7 +333,9 @@ cc_library( cc_library( name = "type_traits", hdrs = ["type_traits.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core/framework:__pkg__", + ]), deps = [ ":numeric_types", "//tsl/platform:types", @@ -344,7 +347,7 @@ filegroup( srcs = [ "cancellation.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow/core:__subpackages__"]), ) cc_library( @@ -371,11 +374,35 @@ cc_library( ], ) +cc_library( + name = "serving_device_selector", + srcs = ["serving_device_selector.cc"], + hdrs = ["serving_device_selector.h"], + visibility = ["//visibility:public"], + deps = [ + "//tsl/platform:logging", + "@com_google_absl//absl/container:fixed_array", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "serving_device_selector_policies", + srcs = ["serving_device_selector_policies.cc"], + hdrs = ["serving_device_selector_policies.h"], + features = ["-layering_check"], + deps = [ + ":serving_device_selector", + "@com_google_absl//absl/strings:string_view", + ], +) + tsl_cc_test( name = "cancellation_test", size = "small", srcs = ["cancellation_test.cc"], - visibility = ["//visibility:public"], deps = [ ":cancellation", "//tsl/platform:env", @@ -401,7 +428,12 @@ exports_files( "shared_counter.h", "tracking_allocator.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + "//tensorflow/core/common_runtime:__pkg__", + "//tensorflow/core/common_runtime/gpu:__pkg__", + "//tensorflow/core/framework:__pkg__", + ]), ) # Files whose users still need to be migrated from core:framework to the @@ -415,7 +447,6 @@ exports_files( "numeric_types.h", "type_traits.h", ], - visibility = ["//visibility:public"], ) tsl_cc_test( @@ -423,7 +454,6 @@ tsl_cc_test( srcs = [ "device_id_utils_test.cc", ], - visibility = ["//visibility:public"], deps = [ ":device_id_impl", ":device_id_utils", diff --git a/third_party/xla/third_party/tsl/tsl/framework/bfc_allocator.cc b/third_party/xla/third_party/tsl/tsl/framework/bfc_allocator.cc index e19c0018d26e18..79a8ea7c892d00 100644 --- a/third_party/xla/third_party/tsl/tsl/framework/bfc_allocator.cc +++ b/third_party/xla/third_party/tsl/tsl/framework/bfc_allocator.cc @@ -135,24 +135,19 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) { size_t bytes = std::min(curr_region_allocation_bytes_, available_bytes); size_t bytes_received; void* mem_addr = sub_allocator_->Alloc(alignment, bytes, &bytes_received); - if (mem_addr == nullptr && !started_backpedal_) { - // Only backpedal once. - started_backpedal_ = true; - + if (mem_addr == nullptr) { static constexpr float kBackpedalFactor = 0.9; // Try allocating less memory. while (mem_addr == nullptr) { bytes = RoundedBytes(bytes * kBackpedalFactor); - if (bytes < rounded_bytes) break; + if (bytes < rounded_bytes) { + return false; + } mem_addr = sub_allocator_->Alloc(alignment, bytes, &bytes_received); } } - if (mem_addr == nullptr) { - return false; - } - if (!increased_allocation) { // Increase the region size of the next required allocation. curr_region_allocation_bytes_ *= 2; diff --git a/third_party/xla/third_party/tsl/tsl/framework/bfc_allocator.h b/third_party/xla/third_party/tsl/tsl/framework/bfc_allocator.h index 47619856abe8dd..76921c5f04a79d 100644 --- a/third_party/xla/third_party/tsl/tsl/framework/bfc_allocator.h +++ b/third_party/xla/third_party/tsl/tsl/framework/bfc_allocator.h @@ -579,10 +579,6 @@ class BFCAllocator : public Allocator { // The size of the current region allocation. size_t curr_region_allocation_bytes_; - // An indicator that expansion of a region has hit the limits - // of the available memory. - bool started_backpedal_ = false; - // Whether the allocator will coalesce adjacent sub allocator provided // AllocationRegions. This may be disabled if discrete sub allocator // regions can't be treated as contiguous (e.g. if the allocation refers to diff --git a/third_party/xla/third_party/tsl/tsl/framework/contraction/BUILD b/third_party/xla/third_party/tsl/tsl/framework/contraction/BUILD index 26dafa4ef99e41..e265644dce3d79 100644 --- a/third_party/xla/third_party/tsl/tsl/framework/contraction/BUILD +++ b/third_party/xla/third_party/tsl/tsl/framework/contraction/BUILD @@ -4,6 +4,7 @@ load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") load("@bazel_skylib//:bzl_library.bzl", "bzl_library") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], features = [ # Required since headers are not self-contained. @@ -24,13 +25,11 @@ config_setting( define_values = { "tensorflow_mkldnn_contraction_kernel": "0", }, - visibility = ["//visibility:public"], ) bzl_library( name = "build_defs", srcs = ["build_defs.bzl"], - visibility = ["//visibility:public"], ) # Add @@ -45,7 +44,6 @@ bool_flag( config_setting( name = "disable_onednn_contraction_kernel_config", flag_values = {":disable_onednn_contraction_kernel": "True"}, - visibility = ["//visibility:public"], ) # Depending on a build configuration this target provides custom kernel for Eigen @@ -77,7 +75,6 @@ cc_library( name = "eigen_contraction_kernel", hdrs = ["eigen_contraction_kernel.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = select({ ":no_mkldnn_contraction_kernel": [":eigen_contraction_kernel_no_mkl"], ":disable_onednn_contraction_kernel_config": [":eigen_contraction_kernel_no_mkl"], @@ -105,7 +102,6 @@ cc_library( "TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL", ], }), - visibility = ["//visibility:public"], deps = [ "//tsl/framework/fixedpoint", "//tsl/platform:dynamic_annotations", @@ -131,7 +127,6 @@ exports_files( "eigen_contraction_kernel.cc", "eigen_contraction_kernel.h", ], - visibility = ["//visibility:public"], ) cc_library( @@ -139,7 +134,6 @@ cc_library( srcs = ["eigen_contraction_kernel.cc"], hdrs = ["eigen_contraction_kernel.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], # Somehow the following code works with fixedpoint, but not here. # visibility = [ # "//tensorflow:__subpackages__", @@ -160,7 +154,6 @@ filegroup( srcs = [ "eigen_contraction_kernel.h", ], - visibility = ["//visibility:public"], ) # Maintain the same name as other directories until a principled refactor is done, as these files @@ -170,5 +163,4 @@ filegroup( srcs = [ "eigen_contraction_kernel.cc", ], - visibility = ["//visibility:public"], ) diff --git a/third_party/xla/third_party/tsl/tsl/framework/convolution/BUILD b/third_party/xla/third_party/tsl/tsl/framework/convolution/BUILD index d6ed7283d0dbfc..6af7b983946c64 100644 --- a/third_party/xla/third_party/tsl/tsl/framework/convolution/BUILD +++ b/third_party/xla/third_party/tsl/tsl/framework/convolution/BUILD @@ -6,6 +6,7 @@ load( ) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], features = [ # Required since headers are not self-contained. @@ -19,7 +20,6 @@ cc_library( "eigen_spatial_convolutions-inl.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "//tsl/framework/convolution:eigen_convolution_helpers", ], @@ -34,7 +34,6 @@ cc_library( defines = [ "EIGEN_ALTIVEC_USE_CUSTOM_PACK=0", ], - visibility = ["//visibility:public"], ) # Tensorflow also has an eigen_helpers that is closely related, so maintain the same name. @@ -45,7 +44,6 @@ cc_library( ], compatible_with = get_compatible_with_portable(), defines = ["EIGEN_NEON_GEBP_NR=4"], - visibility = ["//visibility:public"], deps = [ "//tsl/framework/contraction:eigen_contraction_kernel", "//tsl/framework/convolution:eigen_convolution_helpers", @@ -63,7 +61,6 @@ filegroup( "eigen_spatial_convolutions.h", "eigen_spatial_convolutions-inl.h", ], - visibility = ["//visibility:public"], # Somehow the following code works with fixedpoint, but not here. # visibility = [ # "//tensorflow:__subpackages__", @@ -92,7 +89,6 @@ exports_files( "eigen_spatial_convolutions.h", "eigen_spatial_convolutions-inl.h", ], - visibility = ["//visibility:public"], ) tsl_cc_test( @@ -101,7 +97,6 @@ tsl_cc_test( srcs = [ "eigen_spatial_convolutions_test.cc", ], - visibility = ["//visibility:public"], deps = [ ":eigen_helpers", "//tsl/platform:test", diff --git a/third_party/xla/third_party/tsl/tsl/framework/fixedpoint/BUILD b/third_party/xla/third_party/tsl/tsl/framework/fixedpoint/BUILD index b9319feacd4fbc..310080f5285a60 100644 --- a/third_party/xla/third_party/tsl/tsl/framework/fixedpoint/BUILD +++ b/third_party/xla/third_party/tsl/tsl/framework/fixedpoint/BUILD @@ -1,7 +1,9 @@ -load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl:tsl.default.bzl", "get_compatible_with_portable") +load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], features = [ # Required since headers are not self-contained. @@ -25,7 +27,6 @@ cc_library( "TypeCastingAVX512.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "//tsl/framework:fixedpoint_types", "@eigen_archive//:eigen3", @@ -47,7 +48,10 @@ filegroup( "TypeCastingAVX512.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow:__subpackages__", + "//tsl:internal", + ]), ) # Files needed for core:mobile_srcs_no_runtime. @@ -66,7 +70,6 @@ filegroup( "TypeCastingAVX512.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], ) filegroup( @@ -84,5 +87,4 @@ filegroup( "TypeCastingAVX512.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], ) diff --git a/third_party/xla/third_party/tsl/tsl/framework/serving_device_selector.cc b/third_party/xla/third_party/tsl/tsl/framework/serving_device_selector.cc new file mode 100644 index 00000000000000..96ea75e258b960 --- /dev/null +++ b/third_party/xla/third_party/tsl/tsl/framework/serving_device_selector.cc @@ -0,0 +1,161 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tsl/framework/serving_device_selector.h" + +#include +#include +#include +#include +#include + +#include "absl/container/fixed_array.h" +#include "absl/strings/string_view.h" +#include "tsl/platform/logging.h" + +namespace tsl { + +inline constexpr int kHighPriority = 0; + +DeviceReservation::DeviceReservation(int device_index, + ServingDeviceSelector* device_selector) + : device_index_(device_index), device_selector_(device_selector) {} + +DeviceReservation::~DeviceReservation() { reset(); } + +void DeviceReservation::reset() { + if (device_selector_) device_selector_->FreeDeviceReservation(*this); + device_selector_ = nullptr; +} + +DeviceReservation::DeviceReservation(DeviceReservation&& r) + : device_index_{r.device_index_}, device_selector_{r.device_selector_} { + r.device_selector_ = nullptr; +} + +DeviceReservation& DeviceReservation::operator=(DeviceReservation&& r) { + if (this == &r) return *this; + + if (device_selector_) device_selector_->FreeDeviceReservation(*this); + + device_index_ = r.device_index_; + device_selector_ = r.device_selector_; + r.device_selector_ = nullptr; + return *this; +} + +/*static*/ void ServingDeviceSelector::CompletedHelper( + DeviceState& device_state, int32_t device_index, int32_t priority, + std::optional& min_exec_time, bool had_error, int64_t now_ns) { + // Check that priority 'priority' queue is non-empty. + DCHECK(!device_state.enqueued_programs[priority].empty()); + auto& program_info = device_state.enqueued_programs[priority].front(); + auto prefetch_results = program_info.prefetch_results; + auto execution_info = program_info.execution_info; + device_state.enqueued_programs[priority].pop_front(); + // To make tracked execution time as accurate as possible, we only record this + // execution time if two programs ran back-to-back without host round trip. + if (!device_state.timer_reset && !had_error) { + LOG(INFO) << "Complete. update device[" << device_index + << "], priority: " << priority + << ", prefetch: " << static_cast(prefetch_results) + << ", time: " << now_ns - device_state.last_started_ns; + const_cast(execution_info) + ->AddTime(now_ns - device_state.last_started_ns, prefetch_results); + // Only update min_exec_time_ when running_average is updated. This avoids + // the case where running_average is zero. + if (!min_exec_time.has_value() || + execution_info->GetTime(prefetch_results) < min_exec_time.value()) { + min_exec_time = execution_info->GetTime(prefetch_results); + } + } + // If there are remaining programs, update the start time. + if (!device_state.enqueued_programs.empty()) { + device_state.last_started_ns = now_ns; + device_state.timer_reset = false; + } +} + +/*static*/ int64_t ServingDeviceSelector::EstimateTimeTillIdleNs( + const DeviceState& device_state, int32_t priority, int64_t min_exec_time, + int64_t now_ns) { + int64_t ns_till_idle = 0; + // Add time from each program in queues with priority 'priority' or higher. + for (int32_t i = 0; i <= priority; i++) { + for (auto& info : device_state.enqueued_programs[i]) { + ns_till_idle += + info.execution_info->MaybeGetValidTime(info.prefetch_results); + } + } + // Accounts for the elapsed time of the currently running but unfinished + // program (i.e., enqueued programs). + if (ns_till_idle > 0) { + DCHECK_GT(device_state.last_started_ns, 0); + ns_till_idle = std::max( + 0, ns_till_idle - (now_ns - device_state.last_started_ns)); + } + + // Add time from scheduled programs with priority 'priority' or higher + int64_t ns_of_schedule_programs = 0; + for (int32_t i = 0; i <= priority; i++) { + for (auto& info : device_state.scheduled_programs[i]) { + ns_of_schedule_programs += std::max( + info.execution_info->MaybeGetValidTime(info.prefetch_results), + min_exec_time); + } + } + return ns_till_idle + ns_of_schedule_programs; +} +/*static*/ void ServingDeviceSelector::EnqueueHelper( + DeviceState& device_state, int32_t device_index, + ExecutionInfo& execution_info, absl::string_view fingerprint, + int32_t priority, int64_t req_id, size_t priority_queue_count, + int prefetch_results, int64_t now_ns) { + if (!device_state.scheduled_programs[priority].empty()) { + auto& program = device_state.scheduled_programs[priority].front(); + if (program.fingerprint.empty()) { + program.execution_info = &execution_info; + program.fingerprint = fingerprint; + if (priority == kHighPriority) { + device_state.last_fingerprint = fingerprint; + } + device_state.unknown_fingerprint_requests--; + } + device_state.enqueued_programs[static_cast(priority)].push_back( + std::move(program)); + device_state.scheduled_programs[static_cast(priority)].pop_front(); + } else { + DeviceState::ProgramInfo program; + program.execution_info = &execution_info; + program.fingerprint = fingerprint; + program.req_id = req_id; + program.priority = priority; + program.prefetch_results = prefetch_results; + device_state.enqueued_programs[priority].push_back(program); + device_state.last_fingerprint = fingerprint; + } + + // Count number of programs in enqueued_programs queues. + int64_t num_programs_enqueued = 0; + for (int64_t i = 0; i < priority_queue_count; i++) { + num_programs_enqueued += device_state.enqueued_programs[i].size(); + } + + if (num_programs_enqueued == 1) { + device_state.last_started_ns = now_ns; + device_state.timer_reset = true; + } +} +} // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/framework/serving_device_selector.h b/third_party/xla/third_party/tsl/tsl/framework/serving_device_selector.h new file mode 100644 index 00000000000000..9ec14a61dccf63 --- /dev/null +++ b/third_party/xla/third_party/tsl/tsl/framework/serving_device_selector.h @@ -0,0 +1,193 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_TSL_FRAMEWORK_SERVING_DEVICE_SELECTOR_H_ +#define TENSORFLOW_TSL_FRAMEWORK_SERVING_DEVICE_SELECTOR_H_ + +#include +#include +#include + +#include "absl/container/fixed_array.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tsl/platform/logging.h" + +namespace tsl { + +class ServingDeviceSelector; + +// A RAII type for device reservation. +class DeviceReservation { + public: + DeviceReservation(int device_index, ServingDeviceSelector* selector); + ~DeviceReservation(); + + DeviceReservation(const DeviceReservation&) = delete; + DeviceReservation& operator=(const DeviceReservation&) = delete; + + DeviceReservation(DeviceReservation&& r); + DeviceReservation& operator=(DeviceReservation&& r); + + int device_index() const { return device_index_; } + + void reset(); + + private: + int device_index_; + ServingDeviceSelector* device_selector_; +}; + +// Interface for runtime device selection for serving. +// NOTE: This interface is experimental and subject to change. +class ServingDeviceSelector { + public: + // Tracks the running average of certain program execution time. + class RunningAverage { + public: + void Add(int64_t value) { + DCHECK_GE(value, 0); + sum_ += value; + ++count_; + latency_ = sum_ / count_; + } + + int64_t Get() const { return latency_; } + + private: + int64_t sum_ = 0; + int64_t count_ = 0; + int64_t latency_ = 0; + }; + + // Tracks the program execution information, including execution time. + class ExecutionInfo { + public: + explicit ExecutionInfo(int64_t num_prefetch_result = 1) + : running_average_(num_prefetch_result) {} + + virtual ~ExecutionInfo() = default; + + void AddTime(int64_t value, int result) { + DCHECK_GE(value, 0); + DCHECK_LT(result, running_average_.size()); + running_average_.at(result).Add(value); + } + + int64_t GetTime(int result) const { + DCHECK_LT(result, running_average_.size()); + return running_average_.at(result).Get(); + } + + // To be conservative when one of the path is missing. + virtual int64_t MaybeGetValidTime(int result) const { + return GetTime(result); + } + + private: + // Records program average execution time, one for each prefetch result. + absl::FixedArray running_average_; + }; + + struct DeviceState { + explicit DeviceState(int64_t priority_count = 1) + : enqueued_programs(priority_count), + scheduled_programs(priority_count) {} + // TODO(b/295352859): Add more stats to track that are useful for the Policy + // to use when selecting a device. + struct ProgramInfo { + std::string fingerprint; + int32_t priority; + int64_t req_id = -1; + const ExecutionInfo* execution_info; + int prefetch_results; + }; + // A queue of enqueued programs, one for each priority level + absl::FixedArray> enqueued_programs; + // A queue of scheduled yet enqueued programs, one for each priority level. + // May or may not have fingerprint. + absl::FixedArray> scheduled_programs; + // Timestamp in nanoseconds of last started program. + int64_t last_started_ns = 0; + // Fingerprint of last enqueued high priority program. + std::string last_fingerprint; + // The number of scheduled not yet enqueued programs with unknown + // fingerprints. + int32_t unknown_fingerprint_requests; + // Whether execution timer was reset, true iff a program is enqueued while + // all queues (for all priorities) were empty. + bool timer_reset = true; + }; + + // Struct of all tracked device states, which will be passed to Policy. + struct DeviceStates { + absl::Span states; + }; + + // Policy used to select a device. + class Policy { + public: + virtual ~Policy() = default; + // Selects a device based on the tracked states of all devices. + virtual int SelectDevice(absl::string_view program_fingerprint, + const DeviceStates& device_states) = 0; + }; + + virtual ~ServingDeviceSelector() = default; + + // Reserves a device according to a given selection policy. The reserved + // device will be freed when the lifetime of the returned `DeviceReservation` + // object ends. + virtual DeviceReservation ReserveDevice( + absl::string_view program_fingerprint) = 0; + + protected: + // A helper function for Enqueue. The EnqueueHelper does the following things. + // 1. If there are programs in the scheduled_programs queue of the given + // priority, move the program to the corresponding enqueued_programs + // queue. Update the fingerprint if it is unknown. This is a typical TF1 + // use case. + // 2. If there are no programs in the scheduled_programs queue of the given + // priority, create the program of the fingerprint and place it in the + // corresponding enqueued_programs queue. + // This can happen in two cases: (1) TFRT that doesn't need + // scheduled_programs queue. (2) In TF1, Schedule() was not called prior + // to Enqueue(). + // This helper also updates last_started_ns and timer_reset. + static void EnqueueHelper(DeviceState& device_state, int32_t device_index, + ExecutionInfo& execution_info, + absl::string_view fingerprint, int32_t priority, + int64_t req_id, size_t priority_queue_count, + int prefetch_results, int64_t now_ns); + // A helper function tells a program has completed on the given device. + static void CompletedHelper(DeviceState& device_state, int32_t device_index, + int32_t priority, + std::optional& min_exec_time, + bool had_error, int64_t now_ns); + // Helper to estimate the time until the core becomes idle in nanoseconds. + // Only considers queues with priority at least as high as 'priority'. + static int64_t EstimateTimeTillIdleNs(const DeviceState& device_state, + int32_t priority, int64_t min_exec_time, + int64_t now_ns); + + private: + friend DeviceReservation; + + // Frees the given device reservation. + virtual void FreeDeviceReservation(const DeviceReservation& reservation) = 0; +}; + +} // namespace tsl + +#endif // TENSORFLOW_TSL_FRAMEWORK_SERVING_DEVICE_SELECTOR_H_ diff --git a/tensorflow/core/common_runtime/serving_device_selector_policies.cc b/third_party/xla/third_party/tsl/tsl/framework/serving_device_selector_policies.cc similarity index 83% rename from tensorflow/core/common_runtime/serving_device_selector_policies.cc rename to third_party/xla/third_party/tsl/tsl/framework/serving_device_selector_policies.cc index 336b955760f30d..7c074ff0780187 100644 --- a/tensorflow/core/common_runtime/serving_device_selector_policies.cc +++ b/third_party/xla/third_party/tsl/tsl/framework/serving_device_selector_policies.cc @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/common_runtime/serving_device_selector_policies.h" +#include "tsl/framework/serving_device_selector_policies.h" #include #include "absl/strings/string_view.h" -#include "tensorflow/core/common_runtime/serving_device_selector.h" +#include "tsl/framework/serving_device_selector.h" -namespace tensorflow { +namespace tsl { int RoundRobinPolicy::SelectDevice( absl::string_view program_fingerprint, @@ -28,4 +28,4 @@ int RoundRobinPolicy::SelectDevice( return ordinal_.fetch_add(1, std::memory_order_relaxed) % num_devices; } -} // namespace tensorflow +} // namespace tsl diff --git a/tensorflow/core/common_runtime/serving_device_selector_policies.h b/third_party/xla/third_party/tsl/tsl/framework/serving_device_selector_policies.h similarity index 75% rename from tensorflow/core/common_runtime/serving_device_selector_policies.h rename to third_party/xla/third_party/tsl/tsl/framework/serving_device_selector_policies.h index 916e91bfaf1fb9..638206bc1229c8 100644 --- a/tensorflow/core/common_runtime/serving_device_selector_policies.h +++ b/third_party/xla/third_party/tsl/tsl/framework/serving_device_selector_policies.h @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SERVING_DEVICE_SELECTOR_POLICIES_H_ -#define TENSORFLOW_CORE_COMMON_RUNTIME_SERVING_DEVICE_SELECTOR_POLICIES_H_ +#ifndef TENSORFLOW_TSL_FRAMEWORK_SERVING_DEVICE_SELECTOR_POLICIES_H_ +#define TENSORFLOW_TSL_FRAMEWORK_SERVING_DEVICE_SELECTOR_POLICIES_H_ #include -#include "tensorflow/core/common_runtime/serving_device_selector.h" +#include "tsl/framework/serving_device_selector.h" -namespace tensorflow { +namespace tsl { enum class ServingDeviceSelectorPolicy { kRoundRobin, @@ -37,6 +37,6 @@ class RoundRobinPolicy : public ServingDeviceSelector::Policy { std::atomic ordinal_; }; -} // namespace tensorflow +} // namespace tsl -#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SERVING_DEVICE_SELECTOR_POLICIES_H_ +#endif // TENSORFLOW_TSL_FRAMEWORK_SERVING_DEVICE_SELECTOR_POLICIES_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/core/BUILD b/third_party/xla/third_party/tsl/tsl/lib/core/BUILD index f597293bf733ae..eb5bbda819c7c3 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/core/BUILD +++ b/third_party/xla/third_party/tsl/tsl/lib/core/BUILD @@ -4,9 +4,9 @@ # The libraries in this package are not allowed to have ANY dependencies # to other TF components outside of TSL. -load("//tsl/platform:build_config.bzl", "tsl_cc_test") -load("//tsl:tsl.bzl", "set_external_visibility") +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl:tsl.default.bzl", "get_compatible_with_portable") +load("//tsl/platform:build_config.bzl", "tsl_cc_test") load( "@local_tsl//tsl/platform:rules_cc.bzl", "cc_library", @@ -14,7 +14,10 @@ load( # TODO(rdzhabarov): Tighten visibility after migration is complete. package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//visibility:public", + ], licenses = ["notice"], ) @@ -25,7 +28,7 @@ filegroup( "bits.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow/core:__pkg__"]), ) filegroup( @@ -36,7 +39,10 @@ filegroup( "status_test_util.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + "//tensorflow/core/lib/core:__pkg__", + ]), ) filegroup( @@ -45,7 +51,7 @@ filegroup( "bitmap_test.cc", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow/core:__pkg__"]), ) filegroup( @@ -55,7 +61,7 @@ filegroup( "bits.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow/core:__pkg__"]), ) filegroup( @@ -64,7 +70,10 @@ filegroup( "status_test_util.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + "//tensorflow/core/lib/core:__pkg__", + ]), ) cc_library( @@ -72,7 +81,6 @@ cc_library( srcs = ["bitmap.cc"], hdrs = ["bitmap.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "//tsl/platform:logging", ], @@ -83,7 +91,6 @@ cc_library( name = "status_test_util", testonly = 1, hdrs = ["status_test_util.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:status", "//tsl/platform:test", @@ -93,7 +100,6 @@ cc_library( cc_library( name = "bits", hdrs = ["bits.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:logging", "//tsl/platform:types", @@ -105,7 +111,6 @@ tsl_cc_test( name = "bits_test", size = "small", srcs = ["bits_test.cc"], - visibility = ["//visibility:public"], deps = [ ":bits", "//tsl/platform:test", diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/BUILD b/third_party/xla/third_party/tsl/tsl/lib/gtl/BUILD index 8ade5f0eb0beb3..306360d7cd6b16 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/BUILD +++ b/third_party/xla/third_party/tsl/tsl/lib/gtl/BUILD @@ -1,4 +1,4 @@ -load("//tsl:tsl.bzl", "set_external_visibility") +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl:tsl.default.bzl", "filegroup") load( "//tsl/platform:build_config.bzl", @@ -10,7 +10,28 @@ load( ) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + # tensorflow/core:lib effectively exposes all targets under tensorflow/core/lib/** + "//tensorflow/core:__pkg__", + # tensorflow/core/lib/strings:proto_serialization uses on gtl:inlined_vector + "//tensorflow/core/lib/strings:__pkg__", + "//tsl/lib/strings:__pkg__", + # tensorflow/core/framework uses map_util, and flatmap + "//tensorflow/core/framework:__pkg__", + "//tsl/framework:__pkg__", + "//tsl/platform/cloud:__pkg__", + # tensorflow/core/util uses inlined_vector + "//tensorflow/core/util:__pkg__", + # tensorflow/core/tfrt/utils uses inlined_vector + "//tensorflow/core/tfrt/utils:__pkg__", + # tensorflow/examples/custom_ops_doc/simple_hash_table uses map_util + "//tensorflow/examples/custom_ops_doc/simple_hash_table:__pkg__", + "@local_xla//xla:__subpackages__", + "//tensorflow/core/lib/gtl:__subpackages__", + "//tsl/distributed_runtime/rpc:__pkg__", + "//tsl/profiler/utils:__pkg__", + ]), licenses = ["notice"], ) @@ -19,14 +40,12 @@ package( cc_library( name = "compactptrset", hdrs = ["compactptrset.h"], - visibility = ["//visibility:public"], deps = [":flatset"], ) cc_library( name = "flatmap", hdrs = ["flatmap.h"], - visibility = ["//visibility:public"], deps = [ ":flatrep", "//tsl/platform:hash", @@ -38,7 +57,6 @@ cc_library( cc_library( name = "flatrep", hdrs = ["flatrep.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:types", "@com_google_absl//absl/base:prefetch", @@ -48,7 +66,6 @@ cc_library( cc_library( name = "flatset", hdrs = ["flatset.h"], - visibility = ["//visibility:public"], deps = [ ":flatrep", "//tsl/platform:hash", @@ -60,7 +77,6 @@ cc_library( cc_library( name = "inlined_vector", hdrs = ["inlined_vector.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:macros", "//tsl/platform:types", @@ -71,7 +87,6 @@ cc_library( cc_library( name = "int_type", hdrs = ["int_type.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:macros", "//tsl/platform:types", @@ -81,7 +96,6 @@ cc_library( cc_library( name = "iterator_range", hdrs = ["iterator_range.h"], - visibility = ["//visibility:public"], ) cc_library( @@ -91,7 +105,6 @@ cc_library( "//tsl/lib/gtl/subtle:map_traits", ], hdrs = ["map_util.h"], - visibility = ["//visibility:public"], ) filegroup( @@ -103,7 +116,10 @@ filegroup( "inlined_vector.h", "iterator_range.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + "//tensorflow/core/lib/gtl:__pkg__", + ]), ) filegroup( @@ -112,21 +128,30 @@ filegroup( "int_type.h", "map_util.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + "//tensorflow/core/lib/gtl:__pkg__", + ]), ) filegroup( name = "legacy_lib_test_internal_headers", srcs = [ ], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + "//tensorflow/core/lib/gtl:__pkg__", + ]), ) filegroup( name = "legacy_android_gif_internal_headers", srcs = [ ], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + "//tensorflow/core/lib/gtl:__pkg__", + ]), ) # Export source files needed for mobile builds, which do not use granular targets. @@ -137,7 +162,11 @@ filegroup( "flatrep.h", "inlined_vector.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + "//tensorflow/core/lib/gtl:__pkg__", + "//tsl:__subpackages__", + ]), ) filegroup( @@ -149,7 +178,10 @@ filegroup( "map_util.h", "//tsl/lib/gtl/subtle:map_traits", ], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + "//tensorflow/core/lib/gtl:__pkg__", + ]), ) filegroup( @@ -165,7 +197,10 @@ filegroup( "map_util.h", "//tsl/lib/gtl/subtle:map_traits", ], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + "//tensorflow/core/lib/gtl:__pkg__", + ]), ) tsl_cc_test( @@ -178,7 +213,6 @@ tsl_cc_test( "iterator_range_test.cc", "map_util_test.cc", ], - visibility = ["//visibility:public"], deps = [ ":compactptrset", ":flatmap", diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/subtle/BUILD b/third_party/xla/third_party/tsl/tsl/lib/gtl/subtle/BUILD index 5bb45dad10ec9d..3e9bfe7a5d03e4 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/subtle/BUILD +++ b/third_party/xla/third_party/tsl/tsl/lib/gtl/subtle/BUILD @@ -1,10 +1,10 @@ # Description: # gtl subtle packages. +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl:tsl.default.bzl", "filegroup") package( - default_visibility = ["//visibility:public"], # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) @@ -14,5 +14,8 @@ filegroup( srcs = [ "map_traits.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core/lib/gtl/subtle:__pkg__", + "//tsl/lib/gtl:__pkg__", + ]), ) diff --git a/third_party/xla/third_party/tsl/tsl/lib/hash/BUILD b/third_party/xla/third_party/tsl/tsl/lib/hash/BUILD index d4b155ab625da6..9ef78a04c27284 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/hash/BUILD +++ b/third_party/xla/third_party/tsl/tsl/lib/hash/BUILD @@ -1,7 +1,7 @@ load( "//tsl:tsl.bzl", "if_linux_x86_64", - "set_external_visibility", + "internal_visibility", "tsl_copts", ) load("//tsl:tsl.default.bzl", "filegroup") @@ -15,7 +15,13 @@ load( ) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + # tensorflow/tsl/lib/io/table_builder.cc uses crc functionality + "//tsl/lib/io:__pkg__", + # tensorflow/core/lib/hash aliases hash for now + "//tensorflow/core/lib/hash:__pkg__", + ]), licenses = ["notice"], ) @@ -27,7 +33,6 @@ cc_library( hdrs = ["crc32c.h"], # -msse4.2 enables the use of crc32c compiler builtins. copts = tsl_copts() + if_linux_x86_64(["-msse4.2"]), - visibility = ["//visibility:public"], deps = [ "//tsl/platform", "//tsl/platform:cord", @@ -44,7 +49,7 @@ filegroup( "crc32c.cc", "crc32c.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow/core/lib/hash:__pkg__"]), ) filegroup( @@ -52,14 +57,13 @@ filegroup( srcs = [ "crc32c.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow/core/lib/hash:__pkg__"]), ) tsl_cc_test( name = "crc32c_test", size = "small", srcs = ["crc32c_test.cc"], - visibility = ["//visibility:public"], deps = [ ":crc32c", "//tsl/platform:logging", diff --git a/third_party/xla/third_party/tsl/tsl/lib/histogram/BUILD b/third_party/xla/third_party/tsl/tsl/lib/histogram/BUILD index f669b2b57232e3..0093ad1b5274da 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/histogram/BUILD +++ b/third_party/xla/third_party/tsl/tsl/lib/histogram/BUILD @@ -1,3 +1,4 @@ +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl:tsl.default.bzl", "filegroup") load( "//tsl/platform:build_config.bzl", @@ -9,7 +10,6 @@ load( ) package( - default_visibility = ["//visibility:public"], # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) @@ -18,7 +18,11 @@ cc_library( name = "histogram", srcs = ["histogram.cc"], hdrs = ["histogram.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//learning/brain/google/monitoring:__pkg__", + "//tensorflow/core/lib/histogram:__pkg__", + "//tsl/lib/monitoring:__pkg__", + ]), deps = [ "//tsl/platform:logging", "//tsl/platform:macros", @@ -37,7 +41,7 @@ filegroup( "histogram.cc", "histogram.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow/core/lib/histogram:__pkg__"]), ) filegroup( @@ -45,7 +49,7 @@ filegroup( srcs = [ "histogram.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow/core/lib/histogram:__pkg__"]), ) tsl_cc_test( @@ -53,7 +57,6 @@ tsl_cc_test( srcs = [ "histogram_test.cc", ], - visibility = ["//visibility:public"], deps = [ ":histogram", "//tsl/platform:logging", diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/BUILD b/third_party/xla/third_party/tsl/tsl/lib/io/BUILD index 01aaf40d7b7c2c..5b45d10a620a1e 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/BUILD +++ b/third_party/xla/third_party/tsl/tsl/lib/io/BUILD @@ -1,13 +1,24 @@ -load("//tsl/platform:build_config.bzl", "tsl_cc_test") -load("//tsl:tsl.bzl", "set_external_visibility") +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl:tsl.default.bzl", "filegroup") +load("//tsl/platform:build_config.bzl", "tsl_cc_test") load( "@local_tsl//tsl/platform:rules_cc.bzl", "cc_library", ) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//tensorflow/c/experimental/filesystem:__pkg__", + "//tensorflow/c/experimental/filesystem/plugins/posix:__pkg__", + "//tsl/lib/io/snappy:__pkg__", + "@local_xla//xla:__subpackages__", + # tensorflow/core:lib effectively exposes all targets under tensorflow/core/lib/** + "//tensorflow/core/util:__subpackages__", + "//tensorflow/core:__pkg__", + "//tensorflow/core/lib/io:__subpackages__", + "//tsl/profiler:__subpackages__", + ]), licenses = ["notice"], ) @@ -25,7 +36,6 @@ cc_library( "format.h", "table_builder.h", ], - visibility = ["//visibility:public"], deps = [ ":iterator", ":table_options", @@ -47,7 +57,6 @@ cc_library( name = "buffered_inputstream", srcs = ["buffered_inputstream.cc"], hdrs = ["buffered_inputstream.h"], - visibility = ["//visibility:public"], deps = [ ":inputstream_interface", ":random_inputstream", @@ -61,7 +70,6 @@ cc_library( name = "compression", srcs = ["compression.cc"], hdrs = ["compression.h"], - visibility = ["//visibility:public"], alwayslink = True, ) @@ -69,7 +77,6 @@ cc_library( name = "inputbuffer", srcs = ["inputbuffer.cc"], hdrs = ["inputbuffer.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:coding", "//tsl/platform:env", @@ -86,7 +93,6 @@ cc_library( name = "inputstream_interface", srcs = ["inputstream_interface.cc"], hdrs = ["inputstream_interface.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:cord", "//tsl/platform:errors", @@ -100,7 +106,6 @@ cc_library( name = "iterator", srcs = ["iterator.cc"], hdrs = ["iterator.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:status", "//tsl/platform:stringpiece", @@ -111,7 +116,6 @@ cc_library( cc_library( name = "proto_encode_helper", hdrs = ["proto_encode_helper.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:coding", "//tsl/platform:logging", @@ -124,7 +128,6 @@ cc_library( name = "random_inputstream", srcs = ["random_inputstream.cc"], hdrs = ["random_inputstream.h"], - visibility = ["//visibility:public"], deps = [ ":inputstream_interface", "//tsl/platform:cord", @@ -137,7 +140,6 @@ cc_library( name = "record_reader", srcs = ["record_reader.cc"], hdrs = ["record_reader.h"], - visibility = ["//visibility:public"], deps = [ ":buffered_inputstream", ":compression", @@ -162,7 +164,6 @@ cc_library( name = "record_writer", srcs = ["record_writer.cc"], hdrs = ["record_writer.h"], - visibility = ["//visibility:public"], deps = [ ":compression", ":snappy_compression_options", @@ -184,25 +185,21 @@ cc_library( alias( name = "snappy_inputbuffer", actual = "//tsl/lib/io/snappy:snappy_inputbuffer", - visibility = ["//visibility:public"], ) alias( name = "snappy_inputstream", actual = "//tsl/lib/io/snappy:snappy_inputstream", - visibility = ["//visibility:public"], ) alias( name = "snappy_outputbuffer", actual = "//tsl/lib/io/snappy:snappy_outputbuffer", - visibility = ["//visibility:public"], ) alias( name = "snappy_compression_options", actual = "//tsl/lib/io/snappy:snappy_compression_options", - visibility = ["//visibility:public"], ) cc_library( @@ -213,7 +210,6 @@ cc_library( hdrs = [ "cache.h", ], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:mutex", "//tsl/platform:raw_coding", @@ -231,7 +227,6 @@ cc_library( "table.h", "two_level_iterator.h", ], - visibility = ["//visibility:public"], deps = [ ":block", ":cache", @@ -247,7 +242,6 @@ cc_library( cc_library( name = "table_options", hdrs = ["table_options.h"], - visibility = ["//visibility:public"], ) cc_library( @@ -266,7 +260,6 @@ tsl_cc_test( name = "buffered_file_test", size = "small", srcs = ["buffered_file_test.cc"], - visibility = ["//visibility:public"], deps = [ ":buffered_file", "//tsl/lib/core:status_test_util", @@ -282,7 +275,6 @@ cc_library( name = "zlib_compression_options", srcs = ["zlib_compression_options.cc"], hdrs = ["zlib_compression_options.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:types", "@zlib", @@ -294,7 +286,6 @@ cc_library( name = "zlib_inputstream", srcs = ["zlib_inputstream.cc"], hdrs = ["zlib_inputstream.h"], - visibility = ["//visibility:public"], deps = [ ":inputstream_interface", ":zlib_compression_options", @@ -313,7 +304,6 @@ cc_library( name = "zlib_outputbuffer", srcs = ["zlib_outputbuffer.cc"], hdrs = ["zlib_outputbuffer.h"], - visibility = ["//visibility:public"], deps = [ ":zlib_compression_options", "//tsl/platform:env", @@ -369,7 +359,6 @@ filegroup( "//tsl/lib/io/snappy:snappy_inputstream.cc", "//tsl/lib/io/snappy:snappy_inputstream.h", ], - visibility = ["//visibility:public"], ) filegroup( @@ -399,7 +388,7 @@ filegroup( "//tsl/lib/io/snappy:snappy_inputstream.h", "//tsl/lib/io/snappy:snappy_outputbuffer.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow/core:__pkg__"]), ) filegroup( @@ -417,7 +406,7 @@ filegroup( "table_builder.h", "table_options.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow/core:__pkg__"]), ) filegroup( @@ -433,7 +422,7 @@ filegroup( "//tsl/lib/io/snappy:snappy_inputstream.h", "//tsl/lib/io/snappy:snappy_outputbuffer.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow/core:__pkg__"]), ) filegroup( @@ -443,14 +432,13 @@ filegroup( "block_builder.h", "format.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow/core:__pkg__"]), ) tsl_cc_test( name = "buffered_inputstream_test", size = "small", srcs = ["buffered_inputstream_test.cc"], - visibility = ["//visibility:public"], deps = [ ":buffered_inputstream", ":random_inputstream", @@ -467,7 +455,6 @@ tsl_cc_test( name = "cache_test", size = "small", srcs = ["cache_test.cc"], - visibility = ["//visibility:public"], deps = [ ":cache", "//tsl/platform:coding", @@ -481,7 +468,6 @@ tsl_cc_test( name = "inputbuffer_test", size = "small", srcs = ["inputbuffer_test.cc"], - visibility = ["//visibility:public"], deps = [ ":inputbuffer", "//tsl/lib/core:status_test_util", @@ -502,7 +488,6 @@ tsl_cc_test( name = "inputstream_interface_test", size = "small", srcs = ["inputstream_interface_test.cc"], - visibility = ["//visibility:public"], deps = [ ":inputstream_interface", "//tsl/lib/core:status_test_util", @@ -516,7 +501,6 @@ tsl_cc_test( name = "random_inputstream_test", size = "small", srcs = ["random_inputstream_test.cc"], - visibility = ["//visibility:public"], deps = [ ":random_inputstream", "//tsl/lib/core:status_test_util", @@ -531,7 +515,6 @@ tsl_cc_test( name = "record_reader_writer_test", size = "small", srcs = ["record_reader_writer_test.cc"], - visibility = ["//visibility:public"], deps = [ ":record_reader", ":record_writer", @@ -552,7 +535,6 @@ tsl_cc_test( name = "recordio_test", size = "small", srcs = ["recordio_test.cc"], - visibility = ["//visibility:public"], deps = [ ":record_reader", ":record_writer", @@ -573,7 +555,6 @@ tsl_cc_test( name = "table_test", size = "small", srcs = ["table_test.cc"], - visibility = ["//visibility:public"], deps = [ ":block", ":iterator", @@ -593,7 +574,6 @@ tsl_cc_test( name = "zlib_buffers_test", size = "small", srcs = ["zlib_buffers_test.cc"], - visibility = ["//visibility:public"], deps = [ ":random_inputstream", ":zlib_compression_options", diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/BUILD b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/BUILD index 9d817cefdb1030..57e178d668e126 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/BUILD +++ b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/BUILD @@ -1,3 +1,4 @@ +load("//tsl:tsl.bzl", "internal_visibility") load( "//tsl/platform:build_config.bzl", "tsl_cc_test", @@ -11,27 +12,27 @@ load( ) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//tensorflow/core/lib/io:__pkg__", + "//tsl/lib/io:__pkg__", + ]), licenses = ["notice"], ) -exports_files( - [ - "snappy_compression_options.h", - "snappy_inputbuffer.h", - "snappy_inputstream.h", - "snappy_outputbuffer.h", - "snappy_inputstream.cc", - "snappy_test.cc", - ], - visibility = ["//visibility:public"], -) +exports_files([ + "snappy_compression_options.h", + "snappy_inputbuffer.h", + "snappy_inputstream.h", + "snappy_outputbuffer.h", + "snappy_inputstream.cc", + "snappy_test.cc", +]) cc_library( name = "snappy_inputbuffer", srcs = ["snappy_inputbuffer.cc"], hdrs = ["snappy_inputbuffer.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/lib/io:inputstream_interface", "//tsl/platform:env", @@ -48,7 +49,6 @@ cc_library( name = "snappy_outputbuffer", srcs = ["snappy_outputbuffer.cc"], hdrs = ["snappy_outputbuffer.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform", "//tsl/platform:env", @@ -64,7 +64,6 @@ cc_library( name = "snappy_inputstream", srcs = ["snappy_inputstream.cc"], hdrs = ["snappy_inputstream.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/lib/io:inputstream_interface", "//tsl/platform:errors", @@ -77,7 +76,6 @@ cc_library( cc_library( name = "snappy_compression_options", hdrs = ["snappy_compression_options.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:types", ], @@ -88,7 +86,6 @@ tsl_cc_test( name = "snappy_test", size = "small", srcs = ["snappy_test.cc"], - visibility = ["//visibility:public"], deps = [ ":snappy_inputbuffer", ":snappy_inputstream", diff --git a/third_party/xla/third_party/tsl/tsl/lib/math/BUILD b/third_party/xla/third_party/tsl/tsl/lib/math/BUILD index a48ddadcf92d06..e5f1178382650a 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/math/BUILD +++ b/third_party/xla/third_party/tsl/tsl/lib/math/BUILD @@ -1,12 +1,15 @@ +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl:tsl.default.bzl", "get_compatible_with_portable") -load("//tsl:tsl.bzl", "set_external_visibility") load( "//tsl/platform:build_config.bzl", "tsl_cc_test", ) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//tensorflow:__subpackages__", + ]), licenses = ["notice"], ) @@ -14,7 +17,11 @@ cc_library( name = "math_util", hdrs = ["math_util.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//platforms/performance/tf_sim/utils:__subpackages__", + "//platforms/xla/service:__subpackages__", + "//tensorflow:__subpackages__", + ]), deps = ["@com_google_absl//absl/base:core_headers"], ) @@ -24,7 +31,6 @@ tsl_cc_test( srcs = [ "math_util_test.cc", ], - visibility = ["//visibility:public"], deps = [ ":math_util", "//tsl/platform:logging", @@ -42,12 +48,9 @@ filegroup( "math_util.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow/core:__pkg__"]), ) -exports_files( - [ - "math_util.h", - ], - visibility = ["//visibility:public"], -) +exports_files([ + "math_util.h", +]) diff --git a/third_party/xla/third_party/tsl/tsl/lib/monitoring/BUILD b/third_party/xla/third_party/tsl/tsl/lib/monitoring/BUILD index 8dcff6a0337f5e..1842f5ab656071 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/monitoring/BUILD +++ b/third_party/xla/third_party/tsl/tsl/lib/monitoring/BUILD @@ -1,4 +1,4 @@ -load("//tsl:tsl.bzl", "set_external_visibility") +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl:tsl.default.bzl", "filegroup") load( "@local_tsl//tsl/platform:rules_cc.bzl", @@ -6,14 +6,33 @@ load( ) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//learning/brain/google/data:__subpackages__", + "//learning/brain/google/monitoring:__subpackages__", + # tensorflow/core:lib effectively exposes all targets under tensorflow/core/lib/** + "//tensorflow/core:__pkg__", + # tensorflow/core/platform:monitoring depends on this package + "//tensorflow/core/platform:__subpackages__", + # tensorflow/compiler/xla/pjrt:metrics depends on this package + "@local_xla//xla/pjrt:__subpackages__", + "@local_xla//xla/service/gpu:__subpackages__", + # tensorflow/compiler/mlir/tfrt:tf_jitrt depends on this package + "//tensorflow/compiler/mlir/tfrt:__subpackages__", + "@local_xla//xla/stream_executor:__subpackages__", + "@local_xla//xla/hlo/experimental:__subpackages__", + "//tensorflow/core/lib/monitoring:__subpackages__", + "@local_xla//xla/service:__subpackages__", + "//tsl/framework:__subpackages__", + "//tsl/distributed_runtime:__subpackages__", + "//tensorflow/compiler/mlir/tf2xla:__subpackages__", + ]), licenses = ["notice"], ) cc_library( name = "counter", hdrs = ["counter.h"], - visibility = ["//visibility:public"], deps = [ ":collection_registry", ":metric_def", @@ -32,7 +51,6 @@ cc_library( hdrs = [ "gauge.h", ], - visibility = ["//visibility:public"], deps = [ ":collection_registry", ":metric_def", @@ -49,7 +67,6 @@ cc_library( name = "sampler", srcs = ["sampler.cc"], hdrs = ["sampler.h"], - visibility = ["//visibility:public"], deps = [ ":collection_registry", ":metric_def", @@ -69,7 +86,6 @@ cc_library( hdrs = [ "types.h", ], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:types", ], @@ -78,7 +94,9 @@ cc_library( cc_library( name = "metric_def", hdrs = ["metric_def.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core:__subpackages__", + ]), deps = [ ":types", "//tsl/platform:stringpiece", @@ -91,7 +109,9 @@ cc_library( name = "collection_registry", srcs = ["collection_registry.cc"], hdrs = ["collection_registry.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core:__subpackages__", + ]), deps = [ ":collected_metrics", ":metric_def", @@ -113,7 +133,6 @@ cc_library( hdrs = [ "collected_metrics.h", ], - visibility = ["//visibility:public"], deps = [ ":metric_def", ":types", @@ -139,7 +158,6 @@ cc_library( testonly = 1, srcs = ["cell_reader-inl.cc"], hdrs = ["cell_reader-inl.h"], - visibility = ["//visibility:public"], #visibility = ["//visibility:private"], deps = [ ":collected_metrics", @@ -159,7 +177,6 @@ cc_library( name = "percentile_sampler", srcs = ["percentile_sampler.cc"], hdrs = ["percentile_sampler.h"], - visibility = ["//visibility:public"], deps = [ ":collection_registry", ":metric_def", @@ -193,7 +210,6 @@ cc_library( hdrs = [ "timed.h", ], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:env_time", "//tsl/platform:types", @@ -214,7 +230,10 @@ filegroup( "timed.h", "types.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + "//tensorflow/core/lib/monitoring:__pkg__", + ]), ) filegroup( @@ -232,7 +251,10 @@ filegroup( "timed.h", "types.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + "//tensorflow/core/lib/monitoring:__pkg__", + ]), ) filegroup( @@ -249,5 +271,8 @@ filegroup( "test_utils.h", "types.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + "//tensorflow/core/lib/monitoring:__pkg__", + ]), ) diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/BUILD b/third_party/xla/third_party/tsl/tsl/lib/random/BUILD index f20d9a1123870f..3223845738ed7c 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/BUILD +++ b/third_party/xla/third_party/tsl/tsl/lib/random/BUILD @@ -1,19 +1,24 @@ +load("//tsl:tsl.bzl", "internal_visibility") +load("//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") load( "//tsl/platform:build_config.bzl", "tsl_cc_test", ) -load("//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//tsl/lib/io:__pkg__", + # tensorflow/core/platform/random aliases this package + "//tensorflow/core/lib/random:__pkg__", + ]), licenses = ["notice"], ) cc_library( name = "exact_uniform_int", hdrs = ["exact_uniform_int.h"], - visibility = ["//visibility:public"], ) cc_library( @@ -28,7 +33,6 @@ cc_library( "random_distributions.h", "simple_philox.h", ], - visibility = ["//visibility:public"], deps = [ ":exact_uniform_int", ":philox_random", @@ -46,7 +50,10 @@ cc_library( name = "random_distributions_utils", hdrs = ["random_distributions_utils.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core/lib/random:__pkg__", + "//tensorflow/lite:__subpackages__", + ]), deps = [":philox_random"], ) @@ -54,14 +61,16 @@ cc_library( name = "philox_random", hdrs = ["philox_random.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core/lib/random:__pkg__", + "//tensorflow/lite:__subpackages__", + ]), ) cc_library( name = "philox_random_test_utils", testonly = True, hdrs = ["philox_random_test_utils.h"], - visibility = ["//visibility:public"], deps = [ ":philox_random", "//tsl/platform:logging", @@ -73,7 +82,6 @@ cc_library( name = "weighted_picker", srcs = ["weighted_picker.cc"], hdrs = ["weighted_picker.h"], - visibility = ["//visibility:public"], deps = [ ":philox", "//tsl/platform:logging", @@ -98,7 +106,6 @@ filegroup( "weighted_picker.cc", "weighted_picker.h", ], - visibility = ["//visibility:public"], ) filegroup( @@ -110,7 +117,7 @@ filegroup( "random_distributions_utils.h", "simple_philox.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow/core/lib/random:__pkg__"]), ) filegroup( @@ -120,7 +127,7 @@ filegroup( "random_distributions_utils.h", "weighted_picker.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow/core/lib/random:__pkg__"]), ) filegroup( @@ -128,7 +135,7 @@ filegroup( srcs = [ "philox_random_test_utils.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow/core/lib/random:__pkg__"]), ) filegroup( @@ -143,14 +150,13 @@ filegroup( "simple_philox.h", "weighted_picker.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow/core/lib/random:__pkg__"]), ) tsl_cc_test( name = "distribution_sampler_test", size = "small", srcs = ["distribution_sampler_test.cc"], - visibility = ["//visibility:public"], deps = [ ":philox", "//tsl/platform:macros", @@ -165,7 +171,6 @@ tsl_cc_test( name = "philox_random_test", size = "small", srcs = ["philox_random_test.cc"], - visibility = ["//visibility:public"], deps = [ ":philox", ":philox_random", @@ -181,7 +186,6 @@ tsl_cc_test( name = "random_distributions_test", srcs = ["random_distributions_test.cc"], tags = ["optonly"], - visibility = ["//visibility:public"], deps = [ ":philox", ":philox_random", @@ -198,7 +202,6 @@ tsl_cc_test( name = "simple_philox_test", size = "small", srcs = ["simple_philox_test.cc"], - visibility = ["//visibility:public"], deps = [ ":philox", "//tsl/platform:logging", @@ -212,7 +215,6 @@ tsl_cc_test( name = "weighted_picker_test", size = "medium", srcs = ["weighted_picker_test.cc"], - visibility = ["//visibility:public"], deps = [ ":philox", ":weighted_picker", diff --git a/third_party/xla/third_party/tsl/tsl/lib/strings/BUILD b/third_party/xla/third_party/tsl/tsl/lib/strings/BUILD index a1c9dc81ac0e4c..b19d298365ae15 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/strings/BUILD +++ b/third_party/xla/third_party/tsl/tsl/lib/strings/BUILD @@ -1,4 +1,4 @@ -load("//tsl:tsl.bzl", "set_external_visibility") +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl:tsl.default.bzl", "filegroup") load( "@local_tsl//tsl/platform:rules_cc.bzl", @@ -11,7 +11,15 @@ cc_library( name = "proto_serialization", srcs = ["proto_serialization.cc"], hdrs = ["proto_serialization.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "@local_xla//xla/pjrt:__subpackages__", + "@local_xla//xla/python:__pkg__", + "@local_xla//xla/service:__pkg__", + "@local_xla//xla/stream_executor:__pkg__", + "//tensorflow/core/lib/strings:__pkg__", + "//tensorflow/compiler/tf2xla/kernels:__pkg__", + "//tensorflow/core/util/autotune_maps:__pkg__", + ]), deps = [ "//tsl/lib/gtl:inlined_vector", "//tsl/platform:hash", @@ -29,7 +37,7 @@ filegroup( "proto_serialization.cc", "proto_serialization.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow/core/lib/strings:__pkg__"]), ) filegroup( @@ -37,7 +45,7 @@ filegroup( srcs = [ "proto_serialization.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow/core/lib/strings:__pkg__"]), ) filegroup( @@ -45,7 +53,7 @@ filegroup( srcs = [ "proto_serialization.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow/core/lib/strings:__pkg__"]), ) filegroup( @@ -53,5 +61,5 @@ filegroup( srcs = [ "proto_serialization.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow/core/lib/strings:__pkg__"]), ) diff --git a/third_party/xla/third_party/tsl/tsl/platform/BUILD b/third_party/xla/third_party/tsl/tsl/platform/BUILD index 28171fdaf50044..91ec2f9ea17098 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/BUILD @@ -7,7 +7,7 @@ load( "//tsl:tsl.bzl", "if_not_fuchsia", - "set_external_visibility", + "internal_visibility", "tsl_copts", ) load("//tsl:tsl.default.bzl", "get_compatible_with_portable") @@ -25,6 +25,7 @@ load( "tf_stream_executor_deps", "tf_windows_aware_platform_deps", "tsl_cc_test", + "tsl_grpc_credentials_deps", "tsl_protobuf_deps", ) load("//tsl/platform:build_config_root.bzl", "if_static") @@ -38,7 +39,10 @@ load( ) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//visibility:public", + ], licenses = ["notice"], ) @@ -48,14 +52,16 @@ exports_files( "load_library.h", "stringpiece_test.cc", ], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core/platform:__subpackages__", + "//tsl:__subpackages__", + ]), ) cc_library( name = "base64", srcs = ["base64.cc"], hdrs = ["base64.h"], - visibility = ["//visibility:public"], deps = [ ":errors", ":macros", @@ -69,7 +75,6 @@ cc_library( name = "blocking_counter", hdrs = ["blocking_counter.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":logging", ":mutex", @@ -80,14 +85,12 @@ cc_library( name = "byte_order", hdrs = ["byte_order.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], ) cc_library( name = "coding", srcs = ["coding.cc"], hdrs = ["coding.h"], - visibility = ["//visibility:public"], deps = [ ":byte_order", ":stringpiece", @@ -99,7 +102,6 @@ cc_library( name = "criticality", compatible_with = get_compatible_with_portable(), textual_hdrs = ["criticality.h"], - visibility = ["//visibility:public"], deps = tf_platform_deps("criticality"), ) @@ -109,7 +111,6 @@ tsl_cc_test( srcs = [ "criticality_test.cc", ], - visibility = ["//visibility:public"], deps = [ ":criticality", ":test", @@ -121,7 +122,6 @@ cc_library( name = "denormal", srcs = ["denormal.cc"], hdrs = ["denormal.h"], - visibility = ["//visibility:public"], deps = [ ":macros", ":platform", @@ -133,7 +133,6 @@ tsl_cc_test( name = "denormal_test", size = "small", srcs = ["denormal_test.cc"], - visibility = ["//visibility:public"], deps = [ ":denormal", ":test", @@ -149,13 +148,11 @@ cc_library( "file_system_helper.h", "threadpool.h", ], - visibility = ["//visibility:public"], deps = tf_windows_aware_platform_deps("env") + if_static([":env_impl"]), ) cc_library( name = "env_impl", - visibility = ["//visibility:public"], deps = tf_windows_aware_platform_deps("env_impl"), ) @@ -163,7 +160,6 @@ cc_library( name = "env_time", compatible_with = get_compatible_with_portable(), textual_hdrs = ["env_time.h"], - visibility = ["//visibility:public"], deps = tf_windows_aware_platform_deps("env_time"), ) @@ -171,7 +167,6 @@ cc_library( name = "errors", srcs = ["errors.cc"], hdrs = ["errors.h"], - visibility = ["//visibility:public"], deps = [ ":logging", ":macros", @@ -189,26 +184,15 @@ cc_library( name = "dynamic_annotations", hdrs = ["dynamic_annotations.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/base:dynamic_annotations", ], ) -cc_library( - name = "gif", - hdrs = ["gif.h"], - visibility = ["//visibility:public"], - deps = [ - "@gif", - ], -) - cc_library( name = "mutex", compatible_with = get_compatible_with_portable(), textual_hdrs = ["mutex.h"], - visibility = ["//visibility:public"], deps = tf_platform_deps("mutex"), ) @@ -216,7 +200,6 @@ cc_library( name = "numbers", srcs = ["numbers.cc"], hdrs = ["numbers.h"], - visibility = ["//visibility:public"], deps = [ ":logging", ":macros", @@ -232,7 +215,6 @@ cc_library( name = "path", srcs = ["path.cc"], hdrs = ["path.h"], - visibility = ["//visibility:public"], deps = [ ":logging", ":mutex", @@ -253,7 +235,6 @@ cc_library( "protobuf_util.cc", ], hdrs = ["protobuf.h"], - visibility = ["//visibility:public"], deps = [ ":platform", ":types", @@ -263,7 +244,6 @@ cc_library( cc_library( name = "regexp", hdrs = ["regexp.h"], - visibility = ["//visibility:public"], deps = [ ":platform", "@com_googlesource_code_re2//:re2", @@ -273,7 +253,6 @@ cc_library( cc_library( name = "resource", textual_hdrs = ["resource.h"], - visibility = ["//visibility:public"], deps = [ ":stringpiece", ] + tf_resource_deps(), @@ -283,14 +262,12 @@ cc_library( name = "stack_frame", hdrs = ["stack_frame.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], ) cc_library( name = "status", srcs = ["status.cc"], hdrs = ["status.h"], - visibility = ["//visibility:public"], deps = [ ":logging", ":macros", @@ -319,7 +296,6 @@ cc_library( "status_to_from_proto.cc", ], hdrs = ["status_to_from_proto.h"], - visibility = ["//visibility:public"], deps = [ ":status", "//tsl/protobuf:error_codes_proto_impl_cc", @@ -332,7 +308,6 @@ cc_library( testonly = 1, srcs = ["status_matchers.cc"], hdrs = ["status_matchers.h"], - visibility = ["//visibility:public"], deps = [ ":status", ":statusor", @@ -344,7 +319,6 @@ cc_library( cc_library( name = "statusor", hdrs = ["statusor.h"], - visibility = ["//visibility:public"], deps = [ ":errors", ":logging", @@ -363,7 +337,6 @@ cc_library( name = "thread_annotations", hdrs = ["thread_annotations.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], ) cc_library( @@ -372,7 +345,6 @@ cc_library( srcs = ["test.cc"], compatible_with = get_compatible_with_portable(), textual_hdrs = ["test.h"], - visibility = ["//visibility:public"], deps = [ ":logging", ":macros", @@ -389,7 +361,6 @@ cc_library( testonly = True, hdrs = ["test_benchmark.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":platform", "@com_google_benchmark//:benchmark", @@ -404,7 +375,10 @@ filegroup( "test_benchmark.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + "//tensorflow/core/platform:__pkg__", + ]), ) filegroup( @@ -414,13 +388,15 @@ filegroup( "test.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + "//tensorflow/core/platform:__pkg__", + ]), ) cc_library( name = "tracing", textual_hdrs = ["tracing.h"], - visibility = ["//visibility:public"], deps = tf_platform_deps("tracing"), ) @@ -432,7 +408,6 @@ cc_library( "tstring.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":cord", ":platform", @@ -447,7 +422,6 @@ filegroup( "ctstring_internal.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], ) filegroup( @@ -495,7 +469,6 @@ filegroup( "//tsl/platform/profile_utils:i_cpu_utils_helper.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], ) # Header files for tensorflow/core:platform_base. @@ -510,7 +483,6 @@ filegroup( "threadpool_options.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], ) filegroup( @@ -527,7 +499,6 @@ filegroup( "thread_annotations.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], ) # Export source files needed for mobile builds, which do not use granular targets. @@ -615,7 +586,6 @@ filegroup( ], }), compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], ) filegroup( @@ -624,7 +594,6 @@ filegroup( "error_logging.h", "fingerprint.h", "notification.h", - "png.h", "random.cc", "random.h", "test_benchmark.h", @@ -639,16 +608,10 @@ filegroup( "subprocess.h", ]), compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], -) - -filegroup( - name = "gif_hdrs", - srcs = [ - "gif.h", - ], - compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + "//tensorflow/core/platform:__pkg__", + ]), ) filegroup( @@ -659,9 +622,6 @@ filegroup( ], exclude = [ "dynamic_annotations.h", - "gif.h", - "png.h", - "jpeg.h", ], ) + [ "//tsl/platform/profile_utils:android_armv7a_cpu_utils_helper.h", @@ -670,7 +630,6 @@ filegroup( "//tsl/platform/profile_utils:i_cpu_utils_helper.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], ) exports_files( @@ -691,6 +650,7 @@ exports_files( "file_system.h", "file_system_helper.cc", "file_system_helper.h", + "grpc_credentials.h", "host_info.h", "human_readable_json.h", "init_main.h", @@ -711,7 +671,10 @@ exports_files( "tracing.cc", "tracing.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + ":__subpackages__", + "//tensorflow:__subpackages__", + ]), ) filegroup( @@ -724,13 +687,14 @@ filegroup( "stringpiece.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + ]), ) cc_library( name = "intrusive_ptr", hdrs = ["intrusive_ptr.h"], - visibility = ["//visibility:public"], deps = [], ) @@ -752,25 +716,10 @@ filegroup( "unbounded_work_queue.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], -) - -filegroup( - name = "jpeg_hdrs", - srcs = [ - "jpeg.h", - ], - compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], -) - -cc_library( - name = "jpeg", - hdrs = ["jpeg.h"], - visibility = ["//visibility:public"], - deps = [ - "@libjpeg_turbo//:jpeg", - ], + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + "//tensorflow/core/platform:__pkg__", + ]), ) filegroup( @@ -781,7 +730,10 @@ filegroup( "platform.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + "//tensorflow/core/lib/jpeg:__pkg__", + ]), ) filegroup( @@ -795,7 +747,11 @@ filegroup( "stringpiece.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + "//tensorflow/core/lib/jpeg:__pkg__", + "//tensorflow/core/platform:__pkg__", + ]), ) filegroup( @@ -808,14 +764,17 @@ filegroup( "platform.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + "//tensorflow/core/lib/gif:__pkg__", + "//tensorflow/core/platform:__pkg__", + ]), ) cc_library( name = "macros", hdrs = ["macros.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], ) filegroup( @@ -827,13 +786,11 @@ filegroup( "platform.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], ) cc_library( name = "net", textual_hdrs = ["net.h"], - visibility = ["//visibility:public"], deps = tf_windows_aware_platform_deps("net"), ) @@ -841,7 +798,6 @@ cc_library( name = "platform", hdrs = ["platform.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], ) cc_library( @@ -856,7 +812,6 @@ cc_library( "numa.h", "snappy.h", ], - visibility = ["//visibility:public"], deps = tf_windows_aware_platform_deps("platform_port"), ) @@ -865,13 +820,11 @@ cc_library( srcs = [ "platform_strings_computed.h", ], - visibility = ["//visibility:public"], ) cc_library( name = "protobuf_compiler", hdrs = ["protobuf_compiler.h"], - visibility = ["//visibility:public"], deps = tf_protobuf_compiler_deps(), ) @@ -879,7 +832,6 @@ cc_library( name = "random", srcs = ["random.cc"], hdrs = ["random.h"], - visibility = ["//visibility:public"], deps = [ ":mutex", ":types", @@ -891,7 +843,6 @@ cc_library( testonly = 1, srcs = ["resource_loader.cc"], textual_hdrs = ["resource_loader.h"], - visibility = ["//visibility:public"], deps = [ ":path", ":test", @@ -901,7 +852,6 @@ cc_library( cc_library( name = "rocm_rocdl_path", textual_hdrs = ["rocm_rocdl_path.h"], - visibility = ["//visibility:public"], deps = tf_platform_deps("rocm_rocdl_path"), ) @@ -911,20 +861,17 @@ filegroup( "stacktrace_handler.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], ) cc_library( name = "stacktrace_handler_hdrs_lib", hdrs = ["stacktrace_handler.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], ) cc_library( name = "stacktrace_handler", textual_hdrs = ["stacktrace_handler.h"], - visibility = ["//visibility:public"], deps = tf_windows_aware_platform_deps("stacktrace_handler"), alwayslink = 1, ) @@ -936,7 +883,6 @@ tsl_cc_test( "stacktrace_handler_test.cc", ], tags = ["no_windows"], - visibility = ["//visibility:public"], deps = [ ":logging", ":stacktrace", @@ -950,7 +896,6 @@ cc_library( name = "str_util", srcs = ["str_util.cc"], hdrs = ["str_util.h"], - visibility = ["//visibility:public"], deps = [ ":logging", ":macros", @@ -964,7 +909,6 @@ cc_library( name = "strcat", srcs = ["strcat.cc"], hdrs = ["strcat.h"], - visibility = ["//visibility:public"], deps = [ ":logging", ":macros", @@ -979,7 +923,6 @@ cc_library( name = "stringpiece", hdrs = ["stringpiece.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/strings", ], @@ -989,7 +932,6 @@ cc_library( name = "crash_analysis", hdrs = ["crash_analysis.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":platform", ] + tf_platform_deps("crash_analysis"), @@ -999,7 +941,6 @@ cc_library( name = "stringprintf", srcs = ["stringprintf.cc"], hdrs = ["stringprintf.h"], - visibility = ["//visibility:public"], deps = [ ":macros", ":types", @@ -1011,7 +952,6 @@ cc_library( textual_hdrs = [ "subprocess.h", ], - visibility = ["//visibility:public"], deps = tf_windows_aware_platform_deps("subprocess"), ) @@ -1019,7 +959,6 @@ cc_library( name = "cord", hdrs = ["cord.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/strings:cord", ], @@ -1029,7 +968,6 @@ cc_library( name = "threadpool_interface", hdrs = ["threadpool_interface.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":mutex", ":types", @@ -1041,7 +979,6 @@ cc_library( name = "types", hdrs = ["types.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":bfloat16", ":ml_dtypes", @@ -1053,12 +990,12 @@ cc_library( cc_library( name = "build_test", testonly = 1, - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core/platform:__pkg__", + ]), deps = [ ":byte_order", ":fingerprint", - ":gif", - ":jpeg", ":macros", ":net", ":platform", @@ -1074,7 +1011,6 @@ cc_library( name = "bfloat16", hdrs = ["bfloat16.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":byte_order", "@eigen_archive//:eigen3", @@ -1085,7 +1021,6 @@ cc_library( name = "ml_dtypes", hdrs = ["ml_dtypes.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "@ml_dtypes//:float8", "@ml_dtypes//:int4", @@ -1095,7 +1030,6 @@ cc_library( cc_library( name = "dso_loader", hdrs = ["dso_loader.h"], - visibility = ["//visibility:public"], deps = [ ":platform", ] + tf_stream_executor_deps("dso_loader"), @@ -1105,7 +1039,9 @@ cc_library( name = "logging", compatible_with = get_compatible_with_portable(), textual_hdrs = ["logging.h"], - visibility = ["//visibility:public"], + visibility = [ + "//visibility:public", + ], deps = tf_logging_deps(), ) @@ -1113,18 +1049,26 @@ cc_library( name = "error_logging", compatible_with = get_compatible_with_portable(), textual_hdrs = ["error_logging.h"], - visibility = ["//visibility:public"], + visibility = [ + "//visibility:public", + ], deps = [ "@com_google_absl//absl/status", "@com_google_absl//absl/strings", ] + tf_error_logging_deps(), ) +cc_library( + name = "grpc_credentials", + compatible_with = get_compatible_with_portable(), + textual_hdrs = ["grpc_credentials.h"], + deps = tsl_grpc_credentials_deps(), +) + cc_library( name = "prefetch", hdrs = ["prefetch.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/base:prefetch", ], @@ -1135,7 +1079,6 @@ cc_library( srcs = ["hash.cc"], hdrs = ["hash.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":macros", ":raw_coding", @@ -1147,7 +1090,6 @@ cc_library( cc_library( name = "human_readable_json", textual_hdrs = ["human_readable_json.h"], - visibility = ["//visibility:public"], deps = tf_platform_deps("human_readable_json"), ) @@ -1155,7 +1097,6 @@ cc_library( name = "raw_coding", hdrs = ["raw_coding.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":byte_order", ":types", @@ -1171,13 +1112,12 @@ filegroup( "str_util.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow/core:__pkg__"]), ) cc_library( name = "casts", hdrs = ["casts.h"], - visibility = ["//visibility:public"], deps = [ ":platform", ] + tf_platform_deps("casts"), @@ -1187,7 +1127,6 @@ cc_library( name = "setround", srcs = ["setround.cc"], hdrs = ["setround.h"], - visibility = ["//visibility:public"], deps = [ ":logging", ":macros", @@ -1198,7 +1137,6 @@ cc_library( name = "stacktrace", hdrs = ["stacktrace.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":platform", ] + tf_windows_aware_platform_deps("stacktrace"), @@ -1211,7 +1149,6 @@ tsl_cc_test( "stacktrace_test.cc", ], tags = ["no_windows"], - visibility = ["//visibility:public"], deps = [ ":logging", ":stacktrace", @@ -1224,14 +1161,12 @@ cc_library( name = "cuda_libdevice_path", compatible_with = get_compatible_with_portable(), textual_hdrs = ["cuda_libdevice_path.h"], - visibility = ["//visibility:public"], deps = tf_cuda_libdevice_path_deps(), ) cc_library( name = "file_statistics", hdrs = ["file_statistics.h"], - visibility = ["//visibility:public"], deps = [ ":types", ], @@ -1241,7 +1176,6 @@ cc_library( name = "fingerprint", hdrs = ["fingerprint.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":platform", ":stringpiece", @@ -1253,7 +1187,6 @@ tsl_cc_test( name = "fingerprint_test", size = "small", srcs = ["fingerprint_test.cc"], - visibility = ["//visibility:public"], deps = [ ":fingerprint", ":test", @@ -1267,7 +1200,6 @@ cc_library( srcs = ["tensor_float_32_utils.cc"], hdrs = ["tensor_float_32_utils.h"], copts = tsl_copts(), - visibility = ["//visibility:public"], alwayslink = 1, ) @@ -1275,7 +1207,6 @@ cc_library( name = "scanner", srcs = ["scanner.cc"], hdrs = ["scanner.h"], - visibility = ["//visibility:public"], deps = [ ":macros", ":str_util", @@ -1287,21 +1218,18 @@ filegroup( name = "tensor_float_32_hdr", srcs = ["tensor_float_32_utils.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], ) cc_library( name = "tensor_float_32_hdr_lib", hdrs = [":tensor_float_32_hdr"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], ) tsl_cc_test( name = "ctstring_test", size = "small", srcs = ["ctstring_test.cc"], - visibility = ["//visibility:public"], deps = [ ":test", ":test_main", @@ -1313,7 +1241,6 @@ tsl_cc_test( name = "hash_test", size = "small", srcs = ["hash_test.cc"], - visibility = ["//visibility:public"], deps = [ ":hash", ":logging", @@ -1327,7 +1254,6 @@ tsl_cc_test( name = "path_test", size = "small", srcs = ["path_test.cc"], - visibility = ["//visibility:public"], deps = [ ":env", ":env_impl", @@ -1341,7 +1267,6 @@ tsl_cc_test( tsl_cc_test( name = "random_test", srcs = ["random_test.cc"], - visibility = ["//visibility:public"], deps = [ ":random", ":test", @@ -1354,7 +1279,6 @@ tsl_cc_test( name = "tstring_test", size = "small", srcs = ["tstring_test.cc"], - visibility = ["//visibility:public"], deps = [ ":cord", ":platform", @@ -1374,7 +1298,6 @@ cc_library( "//tsl:windows": [], "//conditions:default": ["-lm"], }), - visibility = ["//visibility:public"], deps = [ ":platform", ":stacktrace_handler", @@ -1389,7 +1312,6 @@ tsl_cc_test( name = "status_test", size = "small", srcs = ["status_test.cc"], - visibility = ["//visibility:public"], deps = [ ":errors", ":stack_frame", @@ -1410,7 +1332,6 @@ tsl_cc_test( name = "statusor_test", size = "small", srcs = ["statusor_test.cc"], - visibility = ["//visibility:public"], deps = [ ":errors", ":macros", @@ -1426,7 +1347,6 @@ tsl_cc_test( name = "status_matchers_test", size = "small", srcs = ["status_matchers_test.cc"], - visibility = ["//visibility:public"], deps = [ ":errors", ":status", @@ -1442,29 +1362,16 @@ cc_library( name = "notification", hdrs = ["notification.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", ], ) -cc_library( - name = "png", - hdrs = ["png.h"], - compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], - deps = [ - ":platform", - "@png", - ], -) - cc_library( name = "threadpool_options", hdrs = ["threadpool_options.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":threadpool_interface", ], @@ -1474,7 +1381,6 @@ cc_library( name = "unbounded_work_queue", hdrs = ["unbounded_work_queue.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":platform", ] + tf_platform_deps("unbounded_work_queue"), @@ -1483,7 +1389,6 @@ cc_library( tsl_cc_test( name = "unbounded_work_queue_test", srcs = ["unbounded_work_queue_test.cc"], - visibility = ["//visibility:public"], deps = [ ":blocking_counter", ":env", @@ -1500,14 +1405,12 @@ cc_library( name = "context", compatible_with = get_compatible_with_portable(), textual_hdrs = ["context.h"], - visibility = ["//visibility:public"], deps = tf_platform_deps("context"), ) cc_library( name = "load_library", textual_hdrs = ["load_library.h"], - visibility = ["//visibility:public"], deps = [ ":status", ] + tf_windows_aware_platform_deps("load_library"), @@ -1517,7 +1420,6 @@ cc_library( name = "abi", srcs = ["abi.cc"], hdrs = ["abi.h"], - visibility = ["//visibility:public"], deps = [ ":types", ], @@ -1526,7 +1428,6 @@ cc_library( cc_library( name = "refcount", hdrs = ["refcount.h"], - visibility = ["//visibility:public"], deps = [ ":logging", ":mutex", @@ -1537,7 +1438,6 @@ cc_library( cc_library( name = "null_file_system", hdrs = ["null_file_system.h"], - visibility = ["//visibility:public"], deps = [ ":env", ], @@ -1553,7 +1453,6 @@ tsl_cc_test( "//tsl/platform/testdata:test_noop", "//tsl/platform/testdata:test_stderr", ], - visibility = ["//visibility:public"], deps = [ ":path", ":strcat", @@ -1568,7 +1467,6 @@ tsl_cc_test( name = "errors_test", size = "small", srcs = ["errors_test.cc"], - visibility = ["//visibility:public"], deps = [ ":errors", ":test", @@ -1583,7 +1481,6 @@ tsl_cc_test( srcs = [ "intrusive_ptr_test.cc", ], - visibility = ["//visibility:public"], deps = [ ":intrusive_ptr", ":refcount", @@ -1602,7 +1499,6 @@ tsl_cc_test( "manual", "notap", ], - visibility = ["//visibility:public"], deps = [ ":logging", ":test", @@ -1615,7 +1511,6 @@ tsl_cc_test( size = "small", srcs = ["setround_test.cc"], tags = ["noclang"], - visibility = ["//visibility:public"], deps = [ ":setround", ":test", @@ -1629,7 +1524,6 @@ tsl_cc_test( srcs = [ "refcount_test.cc", ], - visibility = ["//visibility:public"], deps = [ ":env", ":env_impl", @@ -1645,7 +1539,6 @@ tsl_cc_test( srcs = [ "integral_types_test.cc", ], - visibility = ["//visibility:public"], deps = [ ":test", ":test_main", @@ -1659,7 +1552,6 @@ tsl_cc_test( srcs = [ "logging_test.cc", ], - visibility = ["//visibility:public"], deps = [ ":logging", ":test", @@ -1673,7 +1565,6 @@ tsl_cc_test( srcs = [ "mutex_test.cc", ], - visibility = ["//visibility:public"], deps = [ ":env", ":env_impl", @@ -1699,7 +1590,6 @@ tsl_cc_test( srcs = [ "net_test.cc", ], - visibility = ["//visibility:public"], deps = [ ":logging", ":net", @@ -1717,7 +1607,6 @@ tsl_cc_test( tags = [ "notap", #TODO(b/245510532) : disabled due to flakiness. ], - visibility = ["//visibility:public"], deps = [ ":env", ":env_impl", @@ -1735,7 +1624,6 @@ tsl_cc_test( srcs = [ "scanner_test.cc", ], - visibility = ["//visibility:public"], deps = [ ":scanner", ":test", @@ -1749,7 +1637,6 @@ tsl_cc_test( srcs = [ "str_util_test.cc", ], - visibility = ["//visibility:public"], deps = [ ":str_util", ":test", @@ -1763,7 +1650,6 @@ tsl_cc_test( srcs = [ "strcat_test.cc", ], - visibility = ["//visibility:public"], deps = [ ":strcat", ":stringprintf", @@ -1780,7 +1666,6 @@ tsl_cc_test( srcs = [ "stringpiece_test.cc", ], - visibility = ["//visibility:public"], deps = [ ":stringpiece", ":test", @@ -1794,7 +1679,6 @@ tsl_cc_test( srcs = [ "stringprintf_test.cc", ], - visibility = ["//visibility:public"], deps = [ ":stringprintf", ":test", @@ -1808,7 +1692,6 @@ tsl_cc_test( srcs = [ "numbers_test.cc", ], - visibility = ["//visibility:public"], deps = [ ":numbers", ":test", @@ -1819,7 +1702,6 @@ tsl_cc_test( bzl_library( name = "rules_cc_bzl", srcs = ["rules_cc.bzl"], - visibility = ["//visibility:public"], deps = tf_platform_alias("rules_cc_bzl"), ) @@ -1832,7 +1714,6 @@ cc_library( "retrying_utils.h", ], copts = tsl_copts(), - visibility = ["//visibility:public"], deps = [ ":env", ":errors", @@ -1849,7 +1730,6 @@ cc_library( "retrying_file_system.h", ], copts = tsl_copts(), - visibility = ["//visibility:public"], deps = [ ":env", ":errors", @@ -1863,7 +1743,6 @@ tsl_cc_test( name = "retrying_file_system_test", size = "small", srcs = ["retrying_file_system_test.cc"], - visibility = ["//visibility:public"], deps = [ ":env_impl", ":retrying_file_system", @@ -1878,7 +1757,6 @@ tsl_cc_test( name = "retrying_utils_test", size = "small", srcs = ["retrying_utils_test.cc"], - visibility = ["//visibility:public"], deps = [ ":env", ":env_impl", diff --git a/third_party/xla/third_party/tsl/tsl/platform/build_config.bzl b/third_party/xla/third_party/tsl/tsl/platform/build_config.bzl index 1658d0b017ec68..6a9193f65b1554 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/build_config.bzl +++ b/third_party/xla/third_party/tsl/tsl/platform/build_config.bzl @@ -38,6 +38,7 @@ load( _tf_stream_executor_deps = "tf_stream_executor_deps", _tf_windows_aware_platform_deps = "tf_windows_aware_platform_deps", _tsl_cc_test = "tsl_cc_test", + _tsl_grpc_credentials_deps = "tsl_grpc_credentials_deps", _tsl_protobuf_deps = "tsl_protobuf_deps", ) @@ -78,3 +79,4 @@ tf_stream_executor_deps = _tf_stream_executor_deps tf_windows_aware_platform_deps = _tf_windows_aware_platform_deps tsl_protobuf_deps = _tsl_protobuf_deps tsl_cc_test = _tsl_cc_test +tsl_grpc_credentials_deps = _tsl_grpc_credentials_deps diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD b/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD index ad242085cef386..e9588213ddf932 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD @@ -1,17 +1,20 @@ # Description: # Cloud file system implementation. -load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load( "//tsl:tsl.bzl", "if_windows", - "set_external_visibility", + "internal_visibility", "tsl_copts", ) load("//tsl/platform:build_config.bzl", "tsl_cc_test") +load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + ":dependency_allowlist", + ]), licenses = ["notice"], ) @@ -29,7 +32,6 @@ cc_library( name = "expiring_lru_cache", hdrs = ["expiring_lru_cache.h"], copts = tsl_copts(), - visibility = ["//visibility:public"], deps = [ "//tsl/platform:env", "//tsl/platform:mutex", @@ -42,7 +44,6 @@ cc_library( name = "file_block_cache", hdrs = ["file_block_cache.h"], copts = tsl_copts(), - visibility = ["//visibility:public"], deps = [ "//tsl/platform:env", "//tsl/platform:mutex", @@ -78,7 +79,6 @@ cc_library( srcs = ["gcs_dns_cache.cc"], hdrs = ["gcs_dns_cache.h"], copts = tsl_copts(), - visibility = ["//visibility:public"], deps = [ ":http_request", "//tsl/platform:env", @@ -95,7 +95,6 @@ cc_library( srcs = ["gcs_throttle.cc"], hdrs = ["gcs_throttle.h"], copts = tsl_copts(), - visibility = ["//visibility:public"], deps = [ "//tsl/platform:env", ], @@ -190,7 +189,6 @@ cc_library( name = "http_request", hdrs = ["http_request.h"], copts = tsl_copts(), - visibility = ["//visibility:public"], deps = [ "//tsl/platform:env", "//tsl/platform:errors", @@ -207,7 +205,6 @@ cc_library( srcs = ["curl_http_request.cc"], hdrs = ["curl_http_request.h"], copts = tsl_copts(), - visibility = ["//visibility:public"], deps = [ ":http_request", "//tsl/lib/gtl:map_util", @@ -232,7 +229,6 @@ cc_library( "http_request_fake.h", ], copts = tsl_copts(), - visibility = ["//visibility:public"], deps = [ ":curl_http_request", "//tsl/lib/core:status_test_util", @@ -255,7 +251,6 @@ cc_library( "google_auth_provider.h", ], copts = tsl_copts(), - visibility = ["//visibility:public"], deps = [ ":compute_engine_metadata_client", ":oauth_client", @@ -281,7 +276,6 @@ cc_library( "compute_engine_metadata_client.h", ], copts = tsl_copts(), - visibility = ["//visibility:public"], deps = [ ":curl_http_request", ":http_request", @@ -302,7 +296,6 @@ cc_library( "zone_provider.h", ], copts = tsl_copts(), - visibility = ["//visibility:public"], deps = [ ":compute_engine_metadata_client", "//tsl/platform:errors", @@ -316,7 +309,6 @@ cc_library( testonly = 1, hdrs = ["now_seconds_env.h"], copts = tsl_copts(), - visibility = ["//visibility:public"], deps = [ "//tsl/platform:env", "//tsl/platform:mutex", @@ -333,7 +325,6 @@ cc_library( "oauth_client.h", ], copts = tsl_copts(), - visibility = ["//visibility:public"], deps = [ ":curl_http_request", ":http_request", @@ -355,7 +346,6 @@ cc_library( "time_util.h", ], copts = tsl_copts(), - visibility = ["//visibility:public"], deps = [ "//tsl/platform:errors", "//tsl/platform:status", @@ -366,7 +356,6 @@ tsl_cc_test( name = "expiring_lru_cache_test", size = "small", srcs = ["expiring_lru_cache_test.cc"], - visibility = ["//visibility:public"], deps = [ ":expiring_lru_cache", ":now_seconds_env", @@ -381,7 +370,6 @@ tsl_cc_test( name = "ram_file_block_cache_test", size = "small", srcs = ["ram_file_block_cache_test.cc"], - visibility = ["//visibility:public"], deps = [ ":now_seconds_env", ":ram_file_block_cache", @@ -399,7 +387,6 @@ tsl_cc_test( name = "gcs_file_system_test", size = "small", srcs = ["gcs_file_system_test.cc"], - visibility = ["//visibility:public"], deps = [ ":gcs_file_system", ":http_request_fake", @@ -420,7 +407,6 @@ tsl_cc_test( size = "small", srcs = ["gcs_dns_cache_test.cc"], linkopts = if_windows(["-DEFAULTLIB:ws2_32.lib"]), - visibility = ["//visibility:public"], deps = [ ":gcs_dns_cache", "//tsl/platform:env_impl", @@ -435,7 +421,6 @@ tsl_cc_test( size = "small", srcs = ["gcs_throttle_test.cc"], linkopts = if_windows(["-DEFAULTLIB:ws2_32.lib"]), - visibility = ["//visibility:public"], deps = [ ":gcs_throttle", "//tsl/lib/core:status_test_util", @@ -450,7 +435,6 @@ tsl_cc_test( name = "curl_http_request_test", size = "small", srcs = ["curl_http_request_test.cc"], - visibility = ["//visibility:public"], deps = [ ":curl_http_request", "//tsl/lib/core:status_test_util", @@ -471,7 +455,6 @@ tsl_cc_test( "//tsl/platform/cloud/testdata:service_account_credentials", "//tsl/platform/cloud/testdata:service_account_public_key", ], - visibility = ["//visibility:public"], deps = [ ":http_request_fake", ":oauth_client", @@ -495,7 +478,6 @@ tsl_cc_test( "//tsl/platform/cloud/testdata:application_default_credentials", "//tsl/platform/cloud/testdata:service_account_credentials", ], - visibility = ["//visibility:public"], deps = [ ":google_auth_provider", ":http_request_fake", @@ -512,7 +494,6 @@ tsl_cc_test( name = "compute_engine_metadata_client_test", size = "small", srcs = ["compute_engine_metadata_client_test.cc"], - visibility = ["//visibility:public"], deps = [ ":compute_engine_metadata_client", ":http_request_fake", @@ -527,7 +508,6 @@ tsl_cc_test( name = "compute_engine_zone_provider_test", size = "small", srcs = ["compute_engine_zone_provider_test.cc"], - visibility = ["//visibility:public"], deps = [ ":compute_engine_zone_provider", ":http_request_fake", @@ -541,7 +521,6 @@ tsl_cc_test( name = "time_util_test", size = "small", srcs = ["time_util_test.cc"], - visibility = ["//visibility:public"], deps = [ ":time_util", "//tsl/lib/core:status_test_util", diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.cc index ea65028a96cd22..869dc993ee0a9d 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.cc @@ -66,10 +66,10 @@ limitations under the License. namespace tsl { namespace { -constexpr char kGcsUriBase[] = "https://www.googleapis.com/storage/v1/"; +constexpr char kGcsUriBase[] = "https://www.googleapis.com./storage/v1/"; constexpr char kGcsUploadUriBase[] = - "https://www.googleapis.com/upload/storage/v1/"; -constexpr char kStorageHost[] = "storage.googleapis.com"; + "https://www.googleapis.com./upload/storage/v1/"; +constexpr char kStorageHost[] = "storage.googleapis.com."; constexpr char kBucketMetadataLocationKey[] = "location"; constexpr size_t kReadAppendableFileBufferSize = 1024 * 1024; // In bytes. constexpr int kGetChildrenDefaultPageSize = 1000; diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system_test.cc index 9221128276af9e..e403599096e5f3 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system_test.cc @@ -62,13 +62,13 @@ class FakeZoneProvider : public ZoneProvider { TEST(GcsFileSystemTest, NewRandomAccessFile_NoBlockCache) { std::vector requests( {new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-5\n" "Timeouts: 5 1 20\n", "012345"), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 6-11\n" "Timeouts: 5 1 20\n", @@ -108,13 +108,13 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_NoBlockCache) { TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered) { std::vector requests({ new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-9\n" "Timeouts: 5 1 20\n", "0123456789"), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 10-19\n" "Timeouts: 5 1 20\n", @@ -155,14 +155,14 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered) { TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_Errors) { std::vector requests({ new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-9\n" "Timeouts: 5 1 20\n", "Server Not", errors::Unavailable("important HTTP error 308"), nullptr, {}, 308), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 6-15\n" "Timeouts: 5 1 20\n", @@ -204,13 +204,13 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_Errors) { TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_ReadAtEOF) { std::vector requests( {new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-9\n" "Timeouts: 5 1 20\n", "0123456789"), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 10-19\n" "Timeouts: 5 1 20\n", @@ -251,7 +251,7 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_CachedOutOfRange) { // In this test, there is only one backend request since we cache the file // size. std::vector requests({new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-9\n" "Timeouts: 5 1 20\n", @@ -297,13 +297,13 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_CachedNotSequential) { // a backend request. std::vector requests( {new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 1-10\n" "Timeouts: 5 1 20\n", "12345678"), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-9\n" "Timeouts: 5 1 20\n", @@ -339,13 +339,13 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_CachedNotSequential) { TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_Growing) { std::vector requests( {new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-9\n" "Timeouts: 5 1 20\n", "012345678"), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 9-18\n" "Timeouts: 5 1 20\n", @@ -387,13 +387,13 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_ReadBackwards) { // Go backwards in the file. It should trigger a new read. std::vector requests( {new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 5-14\n" "Timeouts: 5 1 20\n", "56789"), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-9\n" "Timeouts: 5 1 20\n", @@ -433,7 +433,7 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_ReadBackwards) { TEST(GcsFileSystemTest, NewRandomAccessFile_WithLocationConstraintInSameLocation) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", R"( @@ -460,7 +460,7 @@ TEST(GcsFileSystemTest, TEST(GcsFileSystemTest, NewRandomAccessFile_WithLocationConstraintCaching) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", R"( @@ -468,7 +468,7 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithLocationConstraintCaching) { "location":"US-EAST1" })"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/anotherbucket\n" + "Uri: https://www.googleapis.com./storage/v1/b/anotherbucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", R"( @@ -476,7 +476,7 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithLocationConstraintCaching) { "location":"US-EAST1" })"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", R"( @@ -517,7 +517,7 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithLocationConstraintCaching) { TEST(GcsFileSystemTest, NewRandomAccessFile_WithLocationConstraintInDifferentLocation) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", R"( @@ -547,13 +547,13 @@ TEST(GcsFileSystemTest, TEST(GcsFileSystemTest, NewRandomAccessFile_NoBlockCache_DifferentN) { std::vector requests( {new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-2\n" "Timeouts: 5 1 20\n", "012"), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 3-12\n" "Timeouts: 5 1 20\n", @@ -593,26 +593,26 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache) { // "0123456789abcde". std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "random_access.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"15\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-8\n" "Timeouts: 5 1 20\n", "012345678"), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 9-17\n" "Timeouts: 5 1 20\n", "9abcde"), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 18-26\n" "Timeouts: 5 1 20\n", @@ -679,27 +679,27 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_Flush) { // "0123456789abcde". std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "random_access.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"15\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-8\n" "Timeouts: 5 1 20\n", "012345678"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "random_access.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"15\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-8\n" "Timeouts: 5 1 20\n", @@ -738,22 +738,24 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_MaxStaleness) { // "0123456789abcdef". std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "object?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"16\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), - new FakeHttpRequest("Uri: https://storage.googleapis.com/bucket/object\n" - "Auth Token: fake_token\n" - "Range: 0-7\n" - "Timeouts: 5 1 20\n", - "01234567"), - new FakeHttpRequest("Uri: https://storage.googleapis.com/bucket/object\n" - "Auth Token: fake_token\n" - "Range: 8-15\n" - "Timeouts: 5 1 20\n", - "89abcdef")}); + new FakeHttpRequest( + "Uri: https://storage.googleapis.com./bucket/object\n" + "Auth Token: fake_token\n" + "Range: 0-7\n" + "Timeouts: 5 1 20\n", + "01234567"), + new FakeHttpRequest( + "Uri: https://storage.googleapis.com./bucket/object\n" + "Auth Token: fake_token\n" + "Range: 8-15\n" + "Timeouts: 5 1 20\n", + "89abcdef")}); GcsFileSystem fs( std::unique_ptr(new FakeAuthProvider), std::unique_ptr( @@ -800,27 +802,27 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_FileSignatureChanges) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "random_access.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"5\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-8\n" "Timeouts: 5 1 20\n", "01234"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "random_access.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"5\",\"generation\": \"2\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-8\n" "Timeouts: 5 1 20\n", @@ -874,14 +876,14 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_NoObjectName) { TEST(GcsFileSystemTest, NewRandomAccessFile_InconsistentRead) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "random_access.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"6\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-5\n" "Timeouts: 5 1 20\n", @@ -917,20 +919,20 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_InconsistentRead) { TEST(GcsFileSystemTest, NewWritableFile) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fwriteable?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"16\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/path%2Fwriteable\n" + "Uri: https://storage.googleapis.com./bucket/path%2Fwriteable\n" "Auth Token: fake_token\n" "Range: 0-7\n" "Timeouts: 5 1 20\n", "01234567"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=path%2Fwriteable\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 17\n" @@ -944,14 +946,14 @@ TEST(GcsFileSystemTest, NewWritableFile) { "Put body: content1,content2\n", ""), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fwriteable?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"33\",\"generation\": \"2\"," "\"updated\": \"2016-04-29T23:15:34.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/path%2Fwriteable\n" + "Uri: https://storage.googleapis.com./bucket/path%2Fwriteable\n" "Auth Token: fake_token\n" "Range: 0-7\n" "Timeouts: 5 1 20\n", @@ -998,7 +1000,7 @@ TEST(GcsFileSystemTest, NewWritableFile) { TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadSucceeds) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=path%2Fwriteable.txt\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 17\n" @@ -1076,20 +1078,20 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadSucceedsOnGetStatus) { // path. std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fwriteable?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"16\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/path%2Fwriteable\n" + "Uri: https://storage.googleapis.com./bucket/path%2Fwriteable\n" "Auth Token: fake_token\n" "Range: 0-7\n" "Timeouts: 5 1 20\n", "01234567"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=path%2Fwriteable\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 17\n" @@ -1109,14 +1111,14 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadSucceedsOnGetStatus) { "Put: yes\n", "", OkStatus(), nullptr, {}, 201), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fwriteable?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"33\",\"generation\": \"2\"," "\"updated\": \"2016-04-29T23:19:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/path%2Fwriteable\n" + "Uri: https://storage.googleapis.com./bucket/path%2Fwriteable\n" "Auth Token: fake_token\n" "Range: 0-7\n" "Timeouts: 5 1 20\n", @@ -1163,7 +1165,7 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadSucceedsOnGetStatus) { TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadAllAttemptsFail) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=path%2Fwriteable.txt\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 17\n" @@ -1196,7 +1198,7 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadAllAttemptsFail) { // These calls will be made in the Close() attempt from the destructor. // Letting the destructor succeed. requests.emplace_back(new FakeHttpRequest( - "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=path%2Fwriteable.txt\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 17\n" @@ -1245,7 +1247,7 @@ TEST(GcsFileSystemTest, NewWritableFile_UploadReturns410) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=path%2Fwriteable.txt\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 17\n" @@ -1262,7 +1264,7 @@ TEST(GcsFileSystemTest, NewWritableFile_UploadReturns410) { // These calls will be made in the Close() attempt from the destructor. // Letting the destructor succeed. new FakeHttpRequest( - "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=path%2Fwriteable.txt\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 17\n" @@ -1334,26 +1336,26 @@ TEST(GcsFileSystemTest, NewWritableFile_NoObjectName) { TEST(GcsFileSystemTest, NewAppendableFile) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fappendable?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"8\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/path%2Fappendable\n" + "Uri: https://storage.googleapis.com./bucket/path%2Fappendable\n" "Auth Token: fake_token\n" "Range: 0-1048575\n" "Timeouts: 5 1 20\n", "content1,"), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/path%2Fappendable\n" + "Uri: https://storage.googleapis.com./bucket/path%2Fappendable\n" "Auth Token: fake_token\n" "Range: 0-31\n" "Timeouts: 5 1 20\n", "content1,"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=path%2Fappendable\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 17\n" @@ -1367,14 +1369,14 @@ TEST(GcsFileSystemTest, NewAppendableFile) { "Put body: content1,content2\n", ""), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fappendable?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"8\",\"generation\": \"2\"," "\"updated\": \"2016-04-29T23:25:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/path%2Fappendable\n" + "Uri: https://storage.googleapis.com./bucket/path%2Fappendable\n" "Auth Token: fake_token\n" "Range: 0-31\n" "Timeouts: 5 1 20\n", @@ -1435,13 +1437,13 @@ TEST(GcsFileSystemTest, NewAppendableFile_NoObjectName) { TEST(GcsFileSystemTest, NewAppendableFile_ObjectDoesNotExist) { std::vector requests( {new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/filename\n" + "Uri: https://storage.googleapis.com./bucket/filename\n" "Auth Token: fake_token\n" "Range: 0-1048575\n" "Timeouts: 5 1 20\n", "", errors::NotFound("404"), 404), new FakeHttpRequest( - "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o" + "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o" "?uploadType=resumable&name=filename\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 0\n" @@ -1467,7 +1469,7 @@ TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile) { const string content = "file content"; std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Frandom_access.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -1475,7 +1477,7 @@ TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile) { ", \"generation\": \"1\"", ", \"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - strings::StrCat("Uri: https://storage.googleapis.com/bucket/" + strings::StrCat("Uri: https://storage.googleapis.com./bucket/" "path%2Frandom_access.txt\n" "Auth Token: fake_token\n" "Range: 0-", @@ -1520,7 +1522,7 @@ TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile_NoObjectName) { TEST(GcsFileSystemTest, FileExists_YesAsObject) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Ffile1.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -1543,13 +1545,13 @@ TEST(GcsFileSystemTest, FileExists_YesAsObject) { TEST(GcsFileSystemTest, FileExists_YesAsFolder) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fsubfolder?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsubfolder%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -1573,12 +1575,12 @@ TEST(GcsFileSystemTest, FileExists_YesAsFolder) { TEST(GcsFileSystemTest, FileExists_YesAsBucket) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket1\n" + "Uri: https://www.googleapis.com./storage/v1/b/bucket1\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{\"size\": \"100\"}"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket1\n" + "Uri: https://www.googleapis.com./storage/v1/b/bucket1\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{\"size\": \"100\"}")}); @@ -1600,13 +1602,13 @@ TEST(GcsFileSystemTest, FileExists_YesAsBucket) { TEST(GcsFileSystemTest, FileExists_NotAsObjectOrFolder) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Ffile1.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Ffile1.txt%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -1630,12 +1632,12 @@ TEST(GcsFileSystemTest, FileExists_NotAsObjectOrFolder) { TEST(GcsFileSystemTest, FileExists_NotAsBucket) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket2\n" + "Uri: https://www.googleapis.com./storage/v1/b/bucket2\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket2\n" + "Uri: https://www.googleapis.com./storage/v1/b/bucket2\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404)}); @@ -1656,20 +1658,20 @@ TEST(GcsFileSystemTest, FileExists_NotAsBucket) { TEST(GcsFileSystemTest, FileExists_StatCache) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Ffile1.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fsubfolder%2F?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsubfolder%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -1697,7 +1699,7 @@ TEST(GcsFileSystemTest, FileExists_StatCache) { TEST(GcsFileSystemTest, FileExists_DirectoryMark) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "dir%2F?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -1720,7 +1722,7 @@ TEST(GcsFileSystemTest, FileExists_DirectoryMark) { TEST(GcsFileSystemTest, GetChildren_NoItems) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2Cprefixes%2CnextPageToken&delimiter=%2F&prefix=" "path%2F\n" "Auth Token: fake_token\n" @@ -1745,7 +1747,7 @@ TEST(GcsFileSystemTest, GetChildren_NoItems) { TEST(GcsFileSystemTest, GetChildren_ThreeFiles) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2Cprefixes%2CnextPageToken&delimiter=%2F&prefix=" "path%2F\n" "Auth Token: fake_token\n" @@ -1774,7 +1776,7 @@ TEST(GcsFileSystemTest, GetChildren_ThreeFiles) { TEST(GcsFileSystemTest, GetChildren_SelfDirectoryMarker) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2Cprefixes%2CnextPageToken&delimiter=%2F&prefix=" "path%2F\n" "Auth Token: fake_token\n" @@ -1802,7 +1804,7 @@ TEST(GcsFileSystemTest, GetChildren_SelfDirectoryMarker) { TEST(GcsFileSystemTest, GetChildren_ThreeFiles_NoSlash) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2Cprefixes%2CnextPageToken&delimiter=%2F&prefix=" "path%2F\n" "Auth Token: fake_token\n" @@ -1831,7 +1833,7 @@ TEST(GcsFileSystemTest, GetChildren_ThreeFiles_NoSlash) { TEST(GcsFileSystemTest, GetChildren_Root) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket-a-b-c/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket-a-b-c/o?" "fields=items%2Fname%2Cprefixes%2CnextPageToken&delimiter=%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -1855,7 +1857,7 @@ TEST(GcsFileSystemTest, GetChildren_Root) { TEST(GcsFileSystemTest, GetChildren_Empty) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2Cprefixes%2CnextPageToken&delimiter=%2F&prefix=" "path%2F\n" "Auth Token: fake_token\n" @@ -1881,7 +1883,7 @@ TEST(GcsFileSystemTest, GetChildren_Empty) { TEST(GcsFileSystemTest, GetChildren_Pagination) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2Cprefixes%2CnextPageToken&delimiter=%2F&" "prefix=path%2F\n" "Auth Token: fake_token\n" @@ -1892,7 +1894,7 @@ TEST(GcsFileSystemTest, GetChildren_Pagination) { " { \"name\": \"path/file3.txt\" }]," "\"prefixes\": [\"path/subpath/\"]}"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2Cprefixes%2CnextPageToken&delimiter=%2F&" "prefix=path%2F" "&pageToken=ABCD==\n" @@ -1923,7 +1925,7 @@ TEST(GcsFileSystemTest, GetChildren_Pagination) { TEST(GcsFileSystemTest, GetMatchingPaths_NoWildcard) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsubpath%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -1949,7 +1951,7 @@ TEST(GcsFileSystemTest, GetMatchingPaths_NoWildcard) { TEST(GcsFileSystemTest, GetMatchingPaths_BucketAndWildcard) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -1978,7 +1980,7 @@ TEST(GcsFileSystemTest, GetMatchingPaths_BucketAndWildcard) { TEST(GcsFileSystemTest, GetMatchingPaths_FolderAndWildcard_Matches) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2006,7 +2008,7 @@ TEST(GcsFileSystemTest, GetMatchingPaths_FolderAndWildcard_Matches) { TEST(GcsFileSystemTest, GetMatchingPaths_SelfDirectoryMarker) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2031,7 +2033,7 @@ TEST(GcsFileSystemTest, GetMatchingPaths_SelfDirectoryMarker) { TEST(GcsFileSystemTest, GetMatchingPaths_SlashInObjectName) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2056,7 +2058,7 @@ TEST(GcsFileSystemTest, GetMatchingPaths_SlashInObjectName) { TEST(GcsFileSystemTest, GetMatchingPaths_SlashInObjectNameEscaped) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2081,7 +2083,7 @@ TEST(GcsFileSystemTest, GetMatchingPaths_SlashInObjectNameEscaped) { TEST(GcsFileSystemTest, GetMatchingPaths_FolderAndWildcard_NoMatches) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2127,14 +2129,14 @@ TEST(GcsFileSystemTest, GetMatchingPaths_OnlyWildcard) { TEST(GcsFileSystemTest, GetMatchingPaths_Cache) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsubpath%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{\"items\": [ " " { \"name\": \"path/subpath/file2.txt\" }]}"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2172,14 +2174,14 @@ TEST(GcsFileSystemTest, GetMatchingPaths_Cache) { TEST(GcsFileSystemTest, GetMatchingPaths_Cache_Flush) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsubpath%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{\"items\": [ " " { \"name\": \"path/subpath/file2.txt\" }]}"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsubpath%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2218,33 +2220,33 @@ TEST(GcsFileSystemTest, GetMatchingPaths_Cache_Flush) { TEST(GcsFileSystemTest, DeleteFile) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Ffile1.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"8\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/path%2Ffile1.txt\n" + "Uri: https://storage.googleapis.com./bucket/path%2Ffile1.txt\n" "Auth Token: fake_token\n" "Range: 0-15\n" "Timeouts: 5 1 20\n", "01234567"), - new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" "/bucket/o/path%2Ffile1.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", ""), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Ffile1.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"8\",\"generation\": \"2\"," "\"updated\": \"2016-04-29T23:19:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/path%2Ffile1.txt\n" + "Uri: https://storage.googleapis.com./bucket/path%2Ffile1.txt\n" "Auth Token: fake_token\n" "Range: 0-15\n" "Timeouts: 5 1 20\n", @@ -2296,26 +2298,26 @@ TEST(GcsFileSystemTest, DeleteFile_NoObjectName) { TEST(GcsFileSystemTest, DeleteFile_StatCacheRemoved) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "file.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), - new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" "/bucket/o/file.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", ""), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "file.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=file.txt%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -2347,7 +2349,7 @@ TEST(GcsFileSystemTest, DeleteFile_StatCacheRemoved) { TEST(GcsFileSystemTest, DeleteDir_Empty) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F&maxResults=2\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2369,13 +2371,13 @@ TEST(GcsFileSystemTest, DeleteDir_Empty) { TEST(GcsFileSystemTest, DeleteDir_OnlyDirMarkerLeft) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F&maxResults=2\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{\"items\": [ " " { \"name\": \"path/\" }]}"), - new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" "/bucket/o/path%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -2397,7 +2399,7 @@ TEST(GcsFileSystemTest, DeleteDir_OnlyDirMarkerLeft) { TEST(GcsFileSystemTest, DeleteDir_BucketOnly) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?fields=items%2F" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?fields=items%2F" "name%2CnextPageToken&maxResults=2\nAuth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}")}); @@ -2417,7 +2419,7 @@ TEST(GcsFileSystemTest, DeleteDir_BucketOnly) { TEST(GcsFileSystemTest, DeleteDir_NonEmpty) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F&maxResults=2\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2440,7 +2442,7 @@ TEST(GcsFileSystemTest, DeleteDir_NonEmpty) { TEST(GcsFileSystemTest, GetFileSize) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "file.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2484,7 +2486,7 @@ TEST(GcsFileSystemTest, RenameFile_Folder) { std::vector requests( {// Check if this is a folder or an object. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path1%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -2493,7 +2495,7 @@ TEST(GcsFileSystemTest, RenameFile_Folder) { " { \"name\": \"path1/subfolder/file1.txt\" }]}"), // Requesting the full list of files in the folder. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path1%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2503,7 +2505,7 @@ TEST(GcsFileSystemTest, RenameFile_Folder) { " { \"name\": \"path1/file2.txt\" }]}"), // Copying the directory marker. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path1%2F/rewriteTo/b/bucket/o/path2%2F\n" "Auth Token: fake_token\n" "Post: yes\n" @@ -2511,7 +2513,7 @@ TEST(GcsFileSystemTest, RenameFile_Folder) { "{\"done\": true}"), // Deleting the original directory marker. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path1%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -2519,7 +2521,7 @@ TEST(GcsFileSystemTest, RenameFile_Folder) { ""), // Copying the first file. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path1%2Fsubfolder%2Ffile1.txt/rewriteTo/b/bucket/o/" "path2%2Fsubfolder%2Ffile1.txt\n" "Auth Token: fake_token\n" @@ -2528,7 +2530,7 @@ TEST(GcsFileSystemTest, RenameFile_Folder) { "{\"done\": true}"), // Deleting the first original file. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path1%2Fsubfolder%2Ffile1.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -2536,7 +2538,7 @@ TEST(GcsFileSystemTest, RenameFile_Folder) { ""), // Copying the second file. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path1%2Ffile2.txt/rewriteTo/b/bucket/o/path2%2Ffile2.txt\n" "Auth Token: fake_token\n" "Post: yes\n" @@ -2544,7 +2546,7 @@ TEST(GcsFileSystemTest, RenameFile_Folder) { "{\"done\": true}"), // Deleting the second original file. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path1%2Ffile2.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -2568,34 +2570,34 @@ TEST(GcsFileSystemTest, RenameFile_Folder) { TEST(GcsFileSystemTest, RenameFile_Object) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fsrc.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"8\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/path%2Fsrc.txt\n" + "Uri: https://storage.googleapis.com./bucket/path%2Fsrc.txt\n" "Auth Token: fake_token\n" "Range: 0-15\n" "Timeouts: 5 1 20\n", "01234567"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fdst.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"8\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/path%2Fdst.txt\n" + "Uri: https://storage.googleapis.com./bucket/path%2Fdst.txt\n" "Auth Token: fake_token\n" "Range: 0-15\n" "Timeouts: 5 1 20\n", "76543210"), // IsDirectory is checking whether there are children objects. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsrc.txt%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -2603,7 +2605,7 @@ TEST(GcsFileSystemTest, RenameFile_Object) { "{}"), // Copying to the new location. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fsrc.txt/rewriteTo/b/bucket/o/path%2Fdst.txt\n" "Auth Token: fake_token\n" "Post: yes\n" @@ -2611,34 +2613,34 @@ TEST(GcsFileSystemTest, RenameFile_Object) { "{\"done\": true}"), // Deleting the original file. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fsrc.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", ""), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fsrc.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"8\",\"generation\": \"2\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/path%2Fsrc.txt\n" + "Uri: https://storage.googleapis.com./bucket/path%2Fsrc.txt\n" "Auth Token: fake_token\n" "Range: 0-15\n" "Timeouts: 5 1 20\n", "89abcdef"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fdst.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"8\",\"generation\": \"2\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/path%2Fdst.txt\n" + "Uri: https://storage.googleapis.com./bucket/path%2Fdst.txt\n" "Auth Token: fake_token\n" "Range: 0-15\n" "Timeouts: 5 1 20\n", @@ -2681,7 +2683,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_FlushTargetStatCache) { std::vector requests( {// Stat the target file. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fdst.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2689,7 +2691,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_FlushTargetStatCache) { "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), // IsDirectory is checking whether there are children objects. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsrc.txt%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -2697,7 +2699,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_FlushTargetStatCache) { "{}"), // IsDirectory is checking if the path exists as an object. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fsrc.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2705,7 +2707,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_FlushTargetStatCache) { "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), // Copying to the new location. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fsrc.txt/rewriteTo/b/bucket/o/path%2Fdst.txt\n" "Auth Token: fake_token\n" "Post: yes\n" @@ -2713,14 +2715,14 @@ TEST(GcsFileSystemTest, RenameFile_Object_FlushTargetStatCache) { "{\"done\": true}"), // Deleting the original file. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fsrc.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", ""), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fdst.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2757,7 +2759,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_DeletionRetried) { std::vector requests( {// IsDirectory is checking whether there are children objects. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsrc.txt%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -2765,7 +2767,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_DeletionRetried) { "{}"), // IsDirectory is checking if the path exists as an object. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fsrc.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2773,7 +2775,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_DeletionRetried) { "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), // Copying to the new location. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fsrc.txt/rewriteTo/b/bucket/o/path%2Fdst.txt\n" "Auth Token: fake_token\n" "Post: yes\n" @@ -2781,7 +2783,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_DeletionRetried) { "{\"done\": true}"), // Deleting the original file - the deletion returns a failure. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fsrc.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -2789,7 +2791,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_DeletionRetried) { "", errors::Unavailable("503"), 503), // Deleting the original file again - the deletion returns NOT_FOUND. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fsrc.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -2815,7 +2817,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_Incomplete) { std::vector requests( {// IsDirectory is checking whether there are children objects. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsrc.txt%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -2823,7 +2825,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_Incomplete) { "{}"), // IsDirectory is checking if the path exists as an object. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fsrc.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2831,7 +2833,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_Incomplete) { "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), // Copying to the new location. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fsrc.txt/rewriteTo/b/bucket/o/path%2Fdst.txt\n" "Auth Token: fake_token\n" "Post: yes\n" @@ -2854,7 +2856,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_Incomplete) { TEST(GcsFileSystemTest, Stat_Object) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "file.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2881,13 +2883,13 @@ TEST(GcsFileSystemTest, Stat_Object) { TEST(GcsFileSystemTest, Stat_Folder) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "subfolder?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=subfolder%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -2915,13 +2917,13 @@ TEST(GcsFileSystemTest, Stat_Folder) { TEST(GcsFileSystemTest, Stat_ObjectOrFolderNotFound) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -2945,7 +2947,7 @@ TEST(GcsFileSystemTest, Stat_ObjectOrFolderNotFound) { TEST(GcsFileSystemTest, Stat_Bucket) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}")}); @@ -2969,7 +2971,7 @@ TEST(GcsFileSystemTest, Stat_Bucket) { TEST(GcsFileSystemTest, Stat_BucketNotFound) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404)}); @@ -2992,20 +2994,20 @@ TEST(GcsFileSystemTest, Stat_BucketNotFound) { TEST(GcsFileSystemTest, Stat_Cache) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "file.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "subfolder%2F?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=subfolder%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -3041,14 +3043,14 @@ TEST(GcsFileSystemTest, Stat_Cache) { TEST(GcsFileSystemTest, Stat_Cache_Flush) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "file.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "file.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3085,7 +3087,7 @@ TEST(GcsFileSystemTest, Stat_Cache_Flush) { TEST(GcsFileSystemTest, Stat_FilenameEndingWithSlash) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "dir%2F?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3111,14 +3113,14 @@ TEST(GcsFileSystemTest, Stat_FilenameEndingWithSlash) { TEST(GcsFileSystemTest, IsDirectory_NotFound) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=file.txt%2F" "&maxResults=1\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "file.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3141,14 +3143,14 @@ TEST(GcsFileSystemTest, IsDirectory_NotFound) { TEST(GcsFileSystemTest, IsDirectory_NotDirectoryButObject) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=file.txt%2F" "&maxResults=1\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "file.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3172,14 +3174,14 @@ TEST(GcsFileSystemTest, IsDirectory_NotDirectoryButObject) { TEST(GcsFileSystemTest, IsDirectory_Yes) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=subfolder%2F" "&maxResults=1\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{\"items\": [{\"name\": \"subfolder/\"}]}"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=subfolder%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -3203,12 +3205,12 @@ TEST(GcsFileSystemTest, IsDirectory_Yes) { TEST(GcsFileSystemTest, IsDirectory_Bucket) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}")}); @@ -3229,7 +3231,7 @@ TEST(GcsFileSystemTest, IsDirectory_Bucket) { TEST(GcsFileSystemTest, IsDirectory_BucketNotFound) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404)}); @@ -3254,14 +3256,14 @@ TEST(GcsFileSystemTest, CreateDir_Folder) { { // File doesn't exist. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "subpath%2F?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}"), // Simple upload. new FakeHttpRequest( - "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" "uploadType=media&name=subpath%2F&ifGenerationMatch=0\n" "Auth Token: fake_token\n" "Post: yes\n" @@ -3269,7 +3271,7 @@ TEST(GcsFileSystemTest, CreateDir_Folder) { ""), // File exists. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "subpath%2F?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3277,14 +3279,14 @@ TEST(GcsFileSystemTest, CreateDir_Folder) { "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), // File doesn't exist again. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "subpath%2F?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}"), // Simulate object uploaded in between. new FakeHttpRequest( - "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" "uploadType=media&name=subpath%2F&ifGenerationMatch=0\n" "Auth Token: fake_token\n" "Post: yes\n" @@ -3316,12 +3318,12 @@ TEST(GcsFileSystemTest, CreateDir_Folder) { TEST(GcsFileSystemTest, CreateDir_Bucket) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", ""), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "")}); @@ -3344,7 +3346,7 @@ TEST(GcsFileSystemTest, DeleteRecursively_Ok) { std::vector requests( {// IsDirectory is checking whether there are children objects. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -3353,7 +3355,7 @@ TEST(GcsFileSystemTest, DeleteRecursively_Ok) { " { \"name\": \"path/file1.txt\" }]}"), // GetChildren recursively. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3363,35 +3365,35 @@ TEST(GcsFileSystemTest, DeleteRecursively_Ok) { " { \"name\": \"path/subpath/file2.txt\" }," " { \"name\": \"path/file3.txt\" }]}"), // Delete the current directory's marker. - new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" "/bucket/o/path%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", ""), // Delete the object - fails and will be retried. - new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" "/bucket/o/path%2Ffile1.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", "", errors::Unavailable("500"), 500), // Delete the object again. - new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" "/bucket/o/path%2Ffile1.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", ""), // Delete the object. - new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" "/bucket/o/path%2Fsubpath%2Ffile2.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", ""), // Delete the object. - new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" "/bucket/o/path%2Ffile3.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -3419,7 +3421,7 @@ TEST(GcsFileSystemTest, DeleteRecursively_DeletionErrors) { std::vector requests( {// IsDirectory is checking whether there are children objects. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -3428,7 +3430,7 @@ TEST(GcsFileSystemTest, DeleteRecursively_DeletionErrors) { " { \"name\": \"path/file1.txt\" }]}"), // Calling GetChildren recursively. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3438,14 +3440,14 @@ TEST(GcsFileSystemTest, DeleteRecursively_DeletionErrors) { " { \"name\": \"path/subpath/file2.txt\" }," " { \"name\": \"path/file3.txt\" }]}"), // Deleting the object. - new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" "/bucket/o/path%2Ffile1.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", ""), // Deleting the directory marker gs://bucket/path/ - fails with 404. - new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" "/bucket/o/path%2Fsubpath%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -3453,7 +3455,7 @@ TEST(GcsFileSystemTest, DeleteRecursively_DeletionErrors) { "", errors::NotFound("404"), 404), // Checking if gs://bucket/path/subpath/ is a folder - it is. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsubpath%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -3461,14 +3463,14 @@ TEST(GcsFileSystemTest, DeleteRecursively_DeletionErrors) { strings::StrCat("{\"items\": [ " " { \"name\": \"path/subpath/\" }]}")), // Deleting the object gs://bucket/path/subpath/file2.txt - new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" "/bucket/o/path%2Fsubpath%2Ffile2.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", ""), // Deleting the object s://bucket/path/file3.txt - fails with 404. - new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" "/bucket/o/path%2Ffile3.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -3476,7 +3478,7 @@ TEST(GcsFileSystemTest, DeleteRecursively_DeletionErrors) { "", errors::NotFound("404"), 404), // Checking if gs://bucket/path/file3.txt/ is a folder - it's not. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Ffile3.txt%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -3484,7 +3486,7 @@ TEST(GcsFileSystemTest, DeleteRecursively_DeletionErrors) { "{}"), // Checking if gs://bucket/path/file3.txt is an object - fails with 404. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Ffile3.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3512,7 +3514,7 @@ TEST(GcsFileSystemTest, DeleteRecursively_NotAFolder) { std::vector requests( {// IsDirectory is checking whether there are children objects. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -3520,7 +3522,7 @@ TEST(GcsFileSystemTest, DeleteRecursively_NotAFolder) { "{}"), // IsDirectory is checking if the path exists as an object. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3604,7 +3606,7 @@ TEST(GcsFileSystemTest, AdditionalRequestHeaderTest) { std::vector requests( {// IsDirectory is checking whether there are children objects. - new FakeHttpRequest("Uri: https://www.googleapis.com/fake\n" + new FakeHttpRequest("Uri: https://www.googleapis.com./fake\n" "Auth Token: fake_token\n" "Header mynewheader: newheadercontents\n" "Header Hello: world\n", @@ -3622,7 +3624,7 @@ TEST(GcsFileSystemTest, AdditionalRequestHeaderTest) { std::unique_ptr request; TF_EXPECT_OK(fs7.CreateHttpRequest(&request)); - request->SetUri("https://www.googleapis.com/fake"); + request->SetUri("https://www.googleapis.com./fake"); request->AddHeader("Hello", "world"); TF_EXPECT_OK(request->Send()); } @@ -3684,7 +3686,7 @@ TEST(GcsFileSystemTest, OverrideCacheParameters) { TEST(GcsFileSystemTest, CreateHttpRequest) { std::vector requests( {// IsDirectory is checking whether there are children objects. - new FakeHttpRequest("Uri: https://www.googleapis.com/fake\n" + new FakeHttpRequest("Uri: https://www.googleapis.com./fake\n" "Auth Token: fake_token\n" "Header Hello: world\n", "{}")}); @@ -3701,7 +3703,7 @@ TEST(GcsFileSystemTest, CreateHttpRequest) { std::unique_ptr request; TF_EXPECT_OK(fs.CreateHttpRequest(&request)); - request->SetUri("https://www.googleapis.com/fake"); + request->SetUri("https://www.googleapis.com./fake"); request->AddHeader("Hello", "world"); TF_EXPECT_OK(request->Send()); } @@ -3745,7 +3747,7 @@ class TestGcsStats : public GcsStatsInterface { TEST(GcsFileSystemTest, Stat_StatsRecording) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "file.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3773,7 +3775,7 @@ TEST(GcsFileSystemTest, Stat_StatsRecording) { TEST(GcsFileSystemTest, NewRandomAccessFile_StatsRecording) { std::vector requests({new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-5\n" "Timeouts: 5 1 20\n", @@ -3815,7 +3817,7 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithCompose) { // Fetch the file (stats and then content) new FakeHttpRequest( "Uri: " - "https://www.googleapis.com/storage/v1/b/bucket/o/" + "https://www.googleapis.com./storage/v1/b/bucket/o/" "some%2Fpath%2Fappendable?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3823,14 +3825,14 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithCompose) { "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( "Uri: " - "https://storage.googleapis.com/bucket/some%2Fpath%2Fappendable\n" + "https://storage.googleapis.com./bucket/some%2Fpath%2Fappendable\n" "Auth Token: fake_token\n" "Range: 0-1048575\n" "Timeouts: 5 1 20\n", contents[0]), // Upload entire file new FakeHttpRequest( - "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=some%2Fpath%2Fappendable\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 18\n" @@ -3848,7 +3850,7 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithCompose) { // Upload new part to a temporary object new FakeHttpRequest( "Uri: " - "https://www.googleapis.com/upload/storage/v1/b/bucket/" + "https://www.googleapis.com./upload/storage/v1/b/bucket/" "o?uploadType=resumable&name=some%2Fpath%2F.tmpcompose%2Fappendable." "18\n" "Auth Token: fake_token\n" @@ -3870,7 +3872,7 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithCompose) { // Fetch generation new FakeHttpRequest( "Uri: " - "https://www.googleapis.com/storage/v1/b/bucket/o/" + "https://www.googleapis.com./storage/v1/b/bucket/o/" "some%2Fpath%2Fappendable?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3878,7 +3880,7 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithCompose) { "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), // Compose the new part at the end of the original object. new FakeHttpRequest("Uri: " - "https://www.googleapis.com/storage/v1/b/bucket/o/" + "https://www.googleapis.com./storage/v1/b/bucket/o/" "some%2Fpath%2Fappendable/compose\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -3891,14 +3893,14 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithCompose) { ""), // Delete the temporary object. new FakeHttpRequest("Uri: " - "https://www.googleapis.com/storage/v1/b/bucket/o/" + "https://www.googleapis.com./storage/v1/b/bucket/o/" "some%2Fpath%2F.tmpcompose%2Fappendable.18\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", ""), new FakeHttpRequest( - "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=some%2Fpath%2F.tmpcompose%2Fappendable." "27\n" "Auth Token: fake_token\n" @@ -3917,14 +3919,14 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithCompose) { // Fetch generation new FakeHttpRequest( "Uri: " - "https://www.googleapis.com/storage/v1/b/bucket/o/" + "https://www.googleapis.com./storage/v1/b/bucket/o/" "some%2Fpath%2Fappendable?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"8\",\"generation\": \"4567\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest("Uri: " - "https://www.googleapis.com/storage/v1/b/bucket/o/" + "https://www.googleapis.com./storage/v1/b/bucket/o/" "some%2Fpath%2Fappendable/compose\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -3936,7 +3938,7 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithCompose) { "'some/path/.tmpcompose/appendable.27'}]}\n", ""), new FakeHttpRequest("Uri: " - "https://www.googleapis.com/storage/v1/b/bucket/o/" + "https://www.googleapis.com./storage/v1/b/bucket/o/" "some%2Fpath%2F.tmpcompose%2Fappendable." "27\n" "Auth Token: fake_token\n" @@ -3973,20 +3975,20 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithoutCompose) { {"content0,", "content1,", "content2,", "content3,"}); std::vector requests({ new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fappendable?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"8\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/path%2Fappendable\n" + "Uri: https://storage.googleapis.com./bucket/path%2Fappendable\n" "Auth Token: fake_token\n" "Range: 0-1048575\n" "Timeouts: 5 1 20\n", contents[0]), new FakeHttpRequest( - "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=path%2Fappendable\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 18\n" @@ -4003,7 +4005,7 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithoutCompose) { contents[0], contents[1], "\n"), ""), new FakeHttpRequest("Uri: " - "https://www.googleapis.com/upload/storage/v1/b/" + "https://www.googleapis.com./upload/storage/v1/b/" "bucket/o?" "uploadType=resumable&name=path%2Fappendable\n" "Auth Token: fake_token\n" @@ -4024,7 +4026,7 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithoutCompose) { contents[0], contents[1], contents[2], "\n"), ""), new FakeHttpRequest( - "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=path%2Fappendable\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 36\n" diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/testdata/BUILD b/third_party/xla/third_party/tsl/tsl/platform/cloud/testdata/BUILD index 13102f42701e05..1cafb7b63b75c6 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/testdata/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/testdata/BUILD @@ -3,7 +3,6 @@ load("//tsl:tsl.default.bzl", "filegroup") package( - default_visibility = ["//visibility:public"], # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) @@ -13,7 +12,7 @@ filegroup( srcs = [ "application_default_credentials.json", ], - visibility = ["//visibility:public"], + visibility = ["//tsl/platform/cloud:__pkg__"], ) filegroup( @@ -21,7 +20,7 @@ filegroup( srcs = [ "service_account_credentials.json", ], - visibility = ["//visibility:public"], + visibility = ["//tsl/platform/cloud:__pkg__"], ) filegroup( @@ -29,5 +28,5 @@ filegroup( srcs = [ "service_account_public_key.txt", ], - visibility = ["//visibility:public"], + visibility = ["//tsl/platform/cloud:__pkg__"], ) diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/BUILD b/third_party/xla/third_party/tsl/tsl/platform/default/BUILD index 846d38e2e2680d..8e6a52db7d1854 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/default/BUILD @@ -1,12 +1,23 @@ -load("//tsl:tsl.bzl", "if_not_fuchsia", "if_not_windows", "set_external_visibility", "tsl_copts") -load("//tsl:tsl.default.bzl", "filegroup") +load( + "//tsl:tsl.bzl", + "if_not_fuchsia", + "if_not_windows", + "internal_visibility", + "tsl_copts", +) +load("//tsl:tsl.default.bzl", "filegroup", "tsl_grpc_cc_dependencies") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") # Tensorflow default + linux implementations of tensorflow/core/platform libraries. load("@bazel_skylib//:bzl_library.bzl", "bzl_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//tensorflow/core/lib/jpeg:__pkg__", + "//tensorflow/core/platform:__pkg__", + "//tsl/platform:__pkg__", + ]), licenses = ["notice"], ) @@ -18,7 +29,6 @@ cc_library( "no_oss", "nobuilder", ], - visibility = ["//visibility:public"], ) cc_library( @@ -30,7 +40,6 @@ cc_library( "nobuilder", ], textual_hdrs = ["context.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform", ], @@ -40,7 +49,6 @@ cc_library( name = "criticality", hdrs = ["//tsl/platform:criticality.h"], textual_hdrs = ["criticality.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform", ], @@ -56,7 +64,6 @@ cc_library( "no_oss", "nobuilder", ], - visibility = ["//visibility:public"], deps = [ "//tsl/platform", "//tsl/platform:logging", @@ -80,7 +87,6 @@ cc_library( "manual", "nobuilder", ], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:load_library", "//tsl/platform:logging", @@ -113,12 +119,12 @@ cc_library( "//tsl/platform:ram_file_system.h", "//tsl/platform:threadpool.h", ], + copts = tsl_copts(), tags = [ "manual", "no_oss", "nobuilder", ], - visibility = ["//visibility:public"], deps = [ "//tsl/platform", "//tsl/platform:blocking_counter", @@ -166,7 +172,6 @@ cc_library( "no_oss", "nobuilder", ], - visibility = ["//visibility:public"], deps = [ ":env", "//tsl/platform:load_library", @@ -186,7 +191,6 @@ cc_library( "no_oss", "nobuilder", ], - visibility = ["//visibility:public"], deps = ["//tsl/platform:types"], ) @@ -199,13 +203,27 @@ cc_library( "no_oss", "nobuilder", ], - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/status", "@com_google_absl//absl/strings", ], ) +cc_library( + name = "grpc_credentials", + srcs = ["grpc_credentials.cc"], + hdrs = ["//tsl/platform:grpc_credentials.h"], + tags = [ + "manual", + "no_oss", + "nobuilder", + ], + deps = [ + "//tsl/platform:logging", + "@com_google_absl//absl/log:check", + ] + tsl_grpc_cc_dependencies(), +) + cc_library( name = "human_readable_json", srcs = ["human_readable_json.cc"], @@ -215,7 +233,6 @@ cc_library( "no_oss", "nobuilder", ], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:errors", "//tsl/platform:protobuf", @@ -233,7 +250,6 @@ cc_library( "no_oss", "nobuilder", ], - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/status", ], @@ -249,7 +265,6 @@ cc_library( "nobuilder", ], textual_hdrs = ["logging.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform", "//tsl/platform:env_time", @@ -266,7 +281,6 @@ filegroup( srcs = [ "integral_types.h", ] + if_not_windows(["env_time.cc"]), - visibility = ["//visibility:public"], ) cc_library( @@ -282,7 +296,6 @@ cc_library( "nobuilder", ], textual_hdrs = ["mutex.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform", "//tsl/platform:macros", @@ -303,7 +316,6 @@ cc_library( "no_oss", "nobuilder", ], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:logging", "//tsl/platform:strcat", @@ -338,7 +350,6 @@ cc_library( "no_oss", "nobuilder", ], - visibility = ["//visibility:public"], deps = [ "//tsl/platform", "//tsl/platform:byte_order", @@ -368,7 +379,6 @@ cc_library( "no_oss", "nobuilder", ], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:stringpiece", ], @@ -384,7 +394,6 @@ cc_library( "no_oss", "nobuilder", ], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:logging", "//tsl/platform:path", @@ -397,7 +406,6 @@ cc_library( name = "stacktrace", hdrs = ["stacktrace.h"], linkopts = ["-ldl"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform", "//tsl/platform:abi", @@ -409,7 +417,6 @@ cc_library( srcs = ["stacktrace_handler.cc"], hdrs = ["//tsl/platform:stacktrace_handler_hdrs"], linkstatic = 1, - visibility = ["//visibility:public"], deps = [ "//tsl/platform", "//tsl/platform:stacktrace", @@ -427,7 +434,6 @@ cc_library( "nobuilder", ], textual_hdrs = ["subprocess.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform", "//tsl/platform:logging", @@ -451,7 +457,6 @@ cc_library( "nobuilder", ], textual_hdrs = ["tracing_impl.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform", "//tsl/platform:hash", @@ -473,7 +478,6 @@ cc_library( "nobuilder", ], textual_hdrs = ["integral_types.h"], - visibility = ["//visibility:public"], ) cc_library( @@ -485,7 +489,6 @@ cc_library( "no_oss", "nobuilder", ], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:env", "//tsl/platform:mutex", @@ -507,7 +510,7 @@ cc_library( "nobuilder", ], textual_hdrs = ["crash_analysis.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow:__subpackages__"]), deps = [ "//tsl/platform", "//tsl/platform:protobuf", @@ -522,7 +525,7 @@ cc_library( "nobuilder", ], textual_hdrs = ["status.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow:__subpackages__"]), ) cc_library( @@ -533,7 +536,7 @@ cc_library( "nobuilder", ], textual_hdrs = ["statusor.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow:__subpackages__"]), deps = [ "//tsl/platform:macros", "//tsl/platform:status", @@ -544,19 +547,18 @@ cc_library( bzl_library( name = "cuda_build_defs_bzl", srcs = ["cuda_build_defs.bzl"], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow:__subpackages__"]), ) bzl_library( name = "rules_cc_bzl", srcs = ["rules_cc.bzl"], - visibility = ["//visibility:public"], ) # Export source files needed for mobile builds, which do not use granular targets. filegroup( name = "additional_mobile_srcs_no_runtime", - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow/core/platform:__pkg__"]), ) filegroup( @@ -577,7 +579,10 @@ filegroup( "//tsl/platform/profile_utils:cpu_utils.h", "//tsl/platform/profile_utils:i_cpu_utils_helper.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core/platform:__pkg__", + "//tsl/platform:__pkg__", + ]), ) filegroup( @@ -593,7 +598,7 @@ filegroup( "subprocess.cc", "subprocess.h", ]), - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow/core/platform:__pkg__"]), ) exports_files( @@ -605,7 +610,7 @@ exports_files( "test.cc", ], ), - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow/core/platform:__pkg__"]), ) exports_files( @@ -614,5 +619,10 @@ exports_files( "logging.h", "test.cc", ], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + "//tensorflow/core/lib/gif:__pkg__", + "//tensorflow/core/lib/jpeg:__pkg__", + "//tensorflow/core/platform:__pkg__", + ]), ) diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/build_config.bzl b/third_party/xla/third_party/tsl/tsl/platform/default/build_config.bzl index c0826ed15177fe..bbd6211fe102bd 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/build_config.bzl +++ b/third_party/xla/third_party/tsl/tsl/platform/default/build_config.bzl @@ -732,7 +732,7 @@ def tf_lib_proto_parsing_deps(): return [ ":protos_all_cc", clean_dep("@eigen_archive//:eigen3"), - clean_dep("//tsl/platform/default/build_config:proto_parsing"), + clean_dep("//tsl/protobuf:protos_all_cc"), ] def tf_py_clif_cc(name, visibility = None, **kwargs): @@ -830,6 +830,9 @@ def tf_logging_deps(): def tf_error_logging_deps(): return [clean_dep("//tsl/platform/default:error_logging")] +def tsl_grpc_credentials_deps(): + return [clean_dep("//tsl/platform/default:grpc_credentials")] + def tf_resource_deps(): return [clean_dep("//tsl/platform/default:resource")] diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/build_config/BUILD b/third_party/xla/third_party/tsl/tsl/platform/default/build_config/BUILD deleted file mode 100644 index 2d5d252efe7ba6..00000000000000 --- a/third_party/xla/third_party/tsl/tsl/platform/default/build_config/BUILD +++ /dev/null @@ -1,134 +0,0 @@ -# Description: -# Platform-specific build configurations. - -load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") -load("//tsl:tsl.bzl", "set_external_visibility", "tsl_copts") - -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # Apache 2.0 - -exports_files( - ["LICENSE"], - visibility = ["//visibility:public"], -) - -cc_library( - name = "gtest", - testonly = 1, - copts = tsl_copts(), - visibility = ["//visibility:public"], - deps = [ - "@com_google_googletest//:gtest", - ], -) - -cc_library( - name = "tensorflow_platform_specific", - copts = tsl_copts(), - linkstatic = 1, - visibility = ["//visibility:public"], - deps = [], -) - -cc_library( - name = "_empty_lib", - visibility = ["//visibility:public"], -) - -# Dummy stream executor cuda plugins. -cc_library( - name = "cublas_plugin", - srcs = [], - visibility = ["//visibility:public"], -) - -cc_library( - name = "cufft_plugin", - srcs = [], - visibility = ["//visibility:public"], -) - -cc_library( - name = "cudnn_plugin", - srcs = [], - visibility = ["//visibility:public"], -) - -# Minimal lib so that tools used for mobile compilation -# don't have to depend on platformlib. -cc_library( - name = "proto_parsing", - copts = tsl_copts(), - visibility = ["//visibility:public"], - deps = [ - "//tsl/protobuf:protos_all_cc", - ], -) - -# Minimal lib to be used by tensorflow/core:framework_lite. -# This provides minimal support for writing operator implementations (kernels), -# and excludes anything that can bloat binary size if used. -cc_library( - name = "minimal", - srcs = [], - copts = tsl_copts(), - visibility = ["//visibility:public"], -) - -cc_library( - name = "gif", - copts = tsl_copts(), - visibility = ["//visibility:public"], - deps = [ - "@gif", - ], -) - -cc_library( - name = "jpeg", - copts = tsl_copts(), - visibility = ["//visibility:public"], - deps = [ - "@libjpeg_turbo//:jpeg", - ], -) - -cc_library( - name = "png", - copts = tsl_copts(), - visibility = ["//visibility:public"], - deps = [ - "@png", - "@zlib", - ], -) - -cc_library( - name = "test_main", - testonly = 1, - linkstatic = 1, - visibility = ["//visibility:public"], - deps = [], -) - -cc_library( - name = "cuda", - data = [ - "@local_config_cuda//cuda:cudart", - ], - linkopts = select({ - "//tsl:macos": [ - "-Wl,-rpath,../local_config_cuda/cuda/lib", - "-Wl,-rpath,../local_config_cuda/cuda/extras/CUPTI/lib", - ], - "//conditions:default": [ - "-Wl,-rpath,../local_config_cuda/cuda/lib64", - "-Wl,-rpath,../local_config_cuda/cuda/extras/CUPTI/lib64", - ], - }), - visibility = ["//visibility:public"], - deps = [ - "@local_config_cuda//cuda:cudart", - ], -) diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/cuda_build_defs.bzl b/third_party/xla/third_party/tsl/tsl/platform/default/cuda_build_defs.bzl index ad89515cd34c0b..1f7f52a627b769 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/cuda_build_defs.bzl +++ b/third_party/xla/third_party/tsl/tsl/platform/default/cuda_build_defs.bzl @@ -1,6 +1,10 @@ """Open source build configurations for CUDA.""" -load("@local_config_cuda//cuda:build_defs.bzl", _if_cuda_is_configured = "if_cuda_is_configured") +load( + "@local_config_cuda//cuda:build_defs.bzl", + _if_cuda_is_configured = "if_cuda_is_configured", + _if_cuda_newer_than = "if_cuda_newer_than", +) # We perform this indirection so that the copybara tool can distinguish this # macro from others provided by the same file. @@ -16,3 +20,6 @@ def cuda_rpath_flags(relpath): "-Wl,-rpath='$$ORIGIN/../../" + relpath + "'", "-Wl,-rpath='$$ORIGIN/../" + relpath + "'", ] + +def if_cuda_newer_than(wanted_ver, if_true, if_false = []): + return _if_cuda_newer_than(wanted_ver, if_true, if_false) diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/grpc_credentials.cc b/third_party/xla/third_party/tsl/tsl/platform/default/grpc_credentials.cc new file mode 100644 index 00000000000000..44850f56e05195 --- /dev/null +++ b/third_party/xla/third_party/tsl/tsl/platform/default/grpc_credentials.cc @@ -0,0 +1,42 @@ +// Copyright 2024 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tsl/platform/grpc_credentials.h" + +#include + +#include "absl/log/check.h" +#include "grpcpp/security/credentials.h" +#include "grpcpp/security/server_credentials.h" +#include "tsl/platform/logging.h" + +namespace tsl { + +std::shared_ptr GetClientCredentials( + bool verify_secure_credentials) { + CHECK(!verify_secure_credentials) + << "Insecure gRPC credentials are unexpectedly used!"; + LOG(INFO) << "gRPC insecure client credentials are used."; + return grpc::InsecureChannelCredentials(); +} + +std::shared_ptr GetServerCredentials( + bool verify_secure_credentials) { + CHECK(!verify_secure_credentials) + << "Insecure gRPC credentials are unexpectedly used!"; + LOG(INFO) << "gRPC insecure server credentials are used."; + return grpc::InsecureServerCredentials(); +} + +} // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/platform/grpc_credentials.h b/third_party/xla/third_party/tsl/tsl/platform/grpc_credentials.h new file mode 100644 index 00000000000000..5625811c0fdde9 --- /dev/null +++ b/third_party/xla/third_party/tsl/tsl/platform/grpc_credentials.h @@ -0,0 +1,38 @@ +/* + * Copyright 2024 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TENSORFLOW_TSL_PLATFORM_GRPC_CREDENTIALS_H_ +#define TENSORFLOW_TSL_PLATFORM_GRPC_CREDENTIALS_H_ + +#include + +#include "grpcpp/security/credentials.h" +#include "grpcpp/security/server_credentials.h" + +namespace tsl { + +// Get credentials to use in the client gRPC. +// If `verify_secure_credentials`, crash if insecure credentials are used. +std::shared_ptr<::grpc::ChannelCredentials> GetClientCredentials( + bool verify_secure_credentials = true); + +// Get credentials to use in the server gRPC. +// If `verify_secure_credentials`, crash if insecure credentials are used. +std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials( + bool verify_secure_credentials = true); +} // namespace tsl + +#endif // TENSORFLOW_TSL_PLATFORM_GRPC_CREDENTIALS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/png.h b/third_party/xla/third_party/tsl/tsl/platform/png.h deleted file mode 100644 index c66e88a5dcd6af..00000000000000 --- a/third_party/xla/third_party/tsl/tsl/platform/png.h +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_TSL_PLATFORM_PNG_H_ -#define TENSORFLOW_TSL_PLATFORM_PNG_H_ - -#include "tsl/platform/platform.h" - -#if defined(PLATFORM_GOOGLE) && !defined(IS_MOBILE_PLATFORM) -#include "png.h" // from @png // IWYU pragma: export -#elif defined(PLATFORM_POSIX) || defined(PLATFORM_WINDOWS) || \ - defined(PLATFORM_POSIX_ANDROID) || defined(IS_MOBILE_PLATFORM) -#include // IWYU pragma: export -#else -#error Define the appropriate PLATFORM_ macro for this platform -#endif - -#endif // TENSORFLOW_TSL_PLATFORM_PNG_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/profile_utils/BUILD b/third_party/xla/third_party/tsl/tsl/platform/profile_utils/BUILD index 0832fcaedd51b7..1421f3c7e143f4 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/profile_utils/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/profile_utils/BUILD @@ -1,7 +1,7 @@ # Description: # profile_utils targets. -load("//tsl:tsl.bzl", "set_external_visibility") +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl:tsl.default.bzl", "filegroup") load( "@local_tsl//tsl/platform:rules_cc.bzl", @@ -13,22 +13,25 @@ load( ) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "@local_xla//xla/stream_executor:__subpackages__", + "//tensorflow/core/platform:__subpackages__", + "//tsl:__pkg__", + "//tsl/platform/default:__pkg__", + ]), licenses = ["notice"], ) -exports_files( - srcs = [ - "android_armv7a_cpu_utils_helper.cc", - "android_armv7a_cpu_utils_helper.h", - "clock_cycle_profiler.h", - "cpu_utils.cc", - "cpu_utils.h", - "cpu_utils_test.cc", - "i_cpu_utils_helper.h", - ], - visibility = ["//visibility:public"], -) +exports_files(srcs = [ + "android_armv7a_cpu_utils_helper.cc", + "android_armv7a_cpu_utils_helper.h", + "clock_cycle_profiler.h", + "cpu_utils.cc", + "cpu_utils.h", + "cpu_utils_test.cc", + "i_cpu_utils_helper.h", +]) filegroup( name = "legacy_lib_internal_srcs", @@ -36,7 +39,10 @@ filegroup( "android_armv7a_cpu_utils_helper.cc", "clock_cycle_profiler.cc", ], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core/platform:__subpackages__", + "//tsl/platform:__pkg__", + ]), ) cc_library( @@ -50,7 +56,6 @@ cc_library( "i_cpu_utils_helper.h", ], copts = tsl_copts(), - visibility = ["//visibility:public"], deps = [ "//tsl/platform:logging", "//tsl/platform:macros", diff --git a/third_party/xla/third_party/tsl/tsl/platform/regexp.h b/third_party/xla/third_party/tsl/tsl/platform/regexp.h index dbab940e81a0dd..fac545c266aae3 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/regexp.h +++ b/third_party/xla/third_party/tsl/tsl/platform/regexp.h @@ -19,9 +19,9 @@ limitations under the License. #include "tsl/platform/platform.h" #if TSL_IS_IN_OSS -#include "re2/re2.h" +#include "re2/re2.h" // IWYU pragma: export #else -#include "third_party/re2/re2.h" -#endif // TSL_IS_IN_OSS +#include "third_party/re2/re2.h" // IWYU pragma: export +#endif // TSL_IS_IN_OSS #endif // TENSORFLOW_TSL_PLATFORM_REGEXP_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/status.cc b/third_party/xla/third_party/tsl/tsl/platform/status.cc index 063ad314568813..2fb124322c1098 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/status.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/status.cc @@ -166,12 +166,6 @@ const char* NullTerminatedMessage(const Status& status) { } -Status OkStatus() { return Status(); } - -Status FromAbslStatus(const absl::Status& s) { return s; } - -absl::Status ToAbslStatus(const ::absl::Status& s) { return s; } - std::string* TfCheckOpHelperOutOfLine(const ::tsl::Status& v, const char* msg) { std::string r("Non-OK-status: "); r += msg; diff --git a/third_party/xla/third_party/tsl/tsl/platform/status.h b/third_party/xla/third_party/tsl/tsl/platform/status.h index 3b5cc2ef4e24da..812ac1a0d6adfc 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/status.h +++ b/third_party/xla/third_party/tsl/tsl/platform/status.h @@ -46,7 +46,7 @@ limitations under the License. #include "tsl/platform/default/status.h" // IWYU pragma: export #endif -// This macro should eventually be provided by Abseil. +// TODO: b/323943471 - This macro should eventually be provided by Abseil. #ifndef ABSL_DEPRECATE_AND_INLINE #define ABSL_DEPRECATE_AND_INLINE() #endif @@ -105,10 +105,14 @@ namespace tsl { // // Returns an OK status, equivalent to a default constructed instance. Prefer // usage of `OkStatus()` when constructing such an OK status. -Status OkStatus(); +ABSL_DEPRECATE_AND_INLINE() inline absl::Status OkStatus() { + return absl::OkStatus(); +}; -absl::Status FromAbslStatus(const absl::Status& s); -absl::Status ToAbslStatus(const ::absl::Status& s); +ABSL_DEPRECATE_AND_INLINE() +inline absl::Status FromAbslStatus(const absl::Status& s) { return s; } +ABSL_DEPRECATE_AND_INLINE() +inline absl::Status ToAbslStatus(const ::absl::Status& s) { return s; } // Given `Status.message()` does not guarantee to be always backed by a // null-terminated string, we have this utility function when it's needed for diff --git a/third_party/xla/third_party/tsl/tsl/platform/statusor.h b/third_party/xla/third_party/tsl/tsl/platform/statusor.h index 0db4e733112c8c..6c49be5132fc9d 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/statusor.h +++ b/third_party/xla/third_party/tsl/tsl/platform/statusor.h @@ -69,6 +69,7 @@ limitations under the License. #define TENSORFLOW_TSL_PLATFORM_STATUSOR_H_ #include "absl/base/attributes.h" +#include "absl/base/macros.h" #include "absl/status/statusor.h" #include "tsl/platform/errors.h" #include "tsl/platform/macros.h" @@ -82,9 +83,15 @@ limitations under the License. #include "tsl/platform/default/statusor.h" // IWYU pragma: export #endif +// TODO: b/323943471 - This macro should eventually be provided by Abseil. +#ifndef ABSL_DEPRECATE_AND_INLINE +#define ABSL_DEPRECATE_AND_INLINE() +#endif + namespace tsl { -using absl::StatusOr; +template +using StatusOr ABSL_DEPRECATE_AND_INLINE() = absl::StatusOr; } // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/platform/testdata/BUILD b/third_party/xla/third_party/tsl/tsl/platform/testdata/BUILD index 665398eb19e8ff..d50e456b83f5da 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/testdata/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/testdata/BUILD @@ -3,7 +3,10 @@ # Thus helping write cross platform tests. package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//tsl/platform:__pkg__", + ], licenses = ["notice"], ) diff --git a/third_party/xla/third_party/tsl/tsl/platform/threadpool.cc b/third_party/xla/third_party/tsl/tsl/platform/threadpool.cc index 218226611b13f1..8b2c850331e944 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/threadpool.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/threadpool.cc @@ -28,6 +28,10 @@ limitations under the License. #include "tsl/platform/setround.h" #include "tsl/platform/tracing.h" +#ifdef DNNL_AARCH64_USE_ACL +#include "tsl/platform/cpu_info.h" +#endif // DNNL_AARCH64_USE_ACL + #ifdef TENSORFLOW_THREADSCALING_EXPERIMENTAL ABSL_FLAG(float, tensorflow_num_threads_scale_factor, 1.0, "Allows to scale all Tensorflow ThreadPools. Total number of threads " @@ -107,6 +111,14 @@ ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options, bool low_latency_hint, Eigen::Allocator* allocator) { CHECK_GE(num_threads, 1); +#ifdef DNNL_AARCH64_USE_ACL + // To avoid cost of swapping in and out threads from running processes + // we do not use all available cores to parallelise TF operations. + if (num_threads == tsl::port::NumTotalCPUs() && num_threads >= 16) { + num_threads = num_threads - 1; + } +#endif // DNNL_AARCH64_USE_ACL + #ifdef TENSORFLOW_THREADSCALING_EXPERIMENTAL CHECK_GT(absl::GetFlag(FLAGS_tensorflow_num_threads_scale_factor), 0); num_threads *= absl::GetFlag(FLAGS_tensorflow_num_threads_scale_factor); diff --git a/third_party/xla/third_party/tsl/tsl/platform/windows/BUILD b/third_party/xla/third_party/tsl/tsl/platform/windows/BUILD index a7955cdd5bfdea..fda40c1e211f02 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/windows/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/windows/BUILD @@ -1,6 +1,7 @@ # Tensorflow windows-specific implementations of tensorflow/core/platform libraries. load( "//tsl:tsl.bzl", + "internal_visibility", "tsl_copts", ) load("//tsl:tsl.default.bzl", "filegroup") @@ -10,7 +11,11 @@ load( ) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//tensorflow/core/platform:__pkg__", + "//tsl/platform:__pkg__", + ]), licenses = ["notice"], ) @@ -36,7 +41,6 @@ cc_library( "no_oss", "nobuilder", ], - visibility = ["//visibility:public"], deps = [ ":error_windows", ":wide_char", @@ -84,7 +88,6 @@ cc_library( "no_oss", "nobuilder", ], - visibility = ["//visibility:public"], deps = [ ":env", ], @@ -99,7 +102,6 @@ cc_library( "no_oss", "nobuilder", ], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:types", ], @@ -115,7 +117,9 @@ cc_library( "no_oss", "nobuilder", ], - visibility = ["//visibility:public"], + # This code is highly windows specific and should only be used with care + # from this package. + visibility = ["//visibility:private"], ) cc_library( @@ -126,7 +130,6 @@ cc_library( "no_oss", "nobuilder", ], - visibility = ["//visibility:public"], deps = ["//tsl/platform:types"], ) @@ -139,7 +142,6 @@ cc_library( "no_oss", "nobuilder", ], - visibility = ["//visibility:public"], deps = [ ":wide_char", "//tsl/platform:errors", @@ -159,7 +161,6 @@ cc_library( "no_oss", "nobuilder", ], - visibility = ["//visibility:public"], deps = [ ":error_windows", "//tsl/platform:errors", @@ -189,7 +190,6 @@ cc_library( "no_oss", "nobuilder", ], - visibility = ["//visibility:public"], deps = [ "//tsl/platform", "//tsl/platform:byte_order", @@ -209,7 +209,6 @@ cc_library( "no_oss", "nobuilder", ], - visibility = ["//visibility:public"], deps = ["//tsl/platform:mutex"], ) @@ -222,7 +221,6 @@ cc_library( "no_oss", "nobuilder", ], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:mutex", "//tsl/platform:stacktrace", @@ -240,7 +238,6 @@ cc_library( "nobuilder", ], textual_hdrs = ["subprocess.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform", "//tsl/platform:logging", @@ -260,16 +257,14 @@ cc_library( "no_oss", "nobuilder", ], - visibility = ["//visibility:public"], ) filegroup( name = "xla_cpu_runtime_srcs", srcs = ["env_time.cc"], - visibility = ["//visibility:public"], ) exports_files( srcs = ["intrinsics_port.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow/core/platform:__pkg__"]), ) diff --git a/third_party/xla/third_party/tsl/tsl/profiler/BUILD b/third_party/xla/third_party/tsl/tsl/profiler/BUILD index af9a004e299ecd..130527c3bae006 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/BUILD +++ b/third_party/xla/third_party/tsl/tsl/profiler/BUILD @@ -2,29 +2,16 @@ package_group( name = "friends", - includes = ["//tsl:internal"], ) package_group( name = "internal", - packages = [ - "//tensorflow/core/profiler/...", - "//tensorflow/python/eager/...", - "//tensorflow/python/profiler/...", - "//tensorflow/python/tpu/profiler/...", - "//tsl/profiler/...", - "//xla/backends/profiler/...", - ], ) package_group( name = "xla_profiler_backends", - packages = ["//xla/backends/profiler/..."], ) package_group( name = "xla_internal", - packages = [ - "//xla/...", - ], ) diff --git a/third_party/xla/third_party/tsl/tsl/profiler/backends/cpu/BUILD b/third_party/xla/third_party/tsl/tsl/profiler/backends/cpu/BUILD index c1f6ae3ebd6ce8..e7d90f1c8caecb 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/backends/cpu/BUILD +++ b/third_party/xla/third_party/tsl/tsl/profiler/backends/cpu/BUILD @@ -1,4 +1,4 @@ -load("//tsl:tsl.bzl", "set_external_visibility") +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl/platform:build_config_root.bzl", "if_static") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load("//tsl/profiler/builds:build_config.bzl", "tf_profiler_copts") @@ -10,7 +10,10 @@ cc_library( name = "traceme_recorder", hdrs = ["traceme_recorder.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tsl/profiler:internal", + "//tsl/profiler:xla_profiler_backends", + ]), deps = [ "//tsl/platform:macros", "//tsl/platform:mutex", @@ -29,7 +32,13 @@ cc_library( ], hdrs = ["traceme_recorder.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/python:__pkg__", + "//tsl/platform/cloud:__pkg__", + "//tsl/profiler:__pkg__", + "//tsl/profiler:internal", + "//tsl/profiler:xla_internal", + ]), deps = [ "//tsl/platform:env", "//tsl/platform:logging", @@ -45,7 +54,6 @@ cc_library( tsl_cc_test( name = "traceme_recorder_test", srcs = ["traceme_recorder_test.cc"], - visibility = ["//visibility:public"], deps = [ ":traceme_recorder", ":traceme_recorder_impl", @@ -68,7 +76,9 @@ cc_library( name = "annotation_stack", hdrs = ["annotation_stack.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tsl/profiler:internal", + ]), deps = [ "//tsl/platform:macros", "//tsl/platform:types", @@ -85,7 +95,10 @@ cc_library( "annotation_stack.h", ], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "@local_xla//xla:__subpackages__", + "//tsl/profiler:internal", + ]), deps = [ "//tsl/platform:macros", "//tsl/platform:types", @@ -99,7 +112,10 @@ cc_library( srcs = ["host_tracer_utils.cc"], hdrs = ["host_tracer_utils.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tsl/profiler:internal", + "//tsl/profiler:xla_internal", + ]), deps = [ ":traceme_recorder", "//tsl/platform:types", diff --git a/third_party/xla/third_party/tsl/tsl/profiler/backends/cpu/annotation_stack.cc b/third_party/xla/third_party/tsl/tsl/profiler/backends/cpu/annotation_stack.cc index 97b4c5daeb373e..1e4b99bf79a846 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/backends/cpu/annotation_stack.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/backends/cpu/annotation_stack.cc @@ -16,33 +16,69 @@ limitations under the License. #include "tsl/profiler/backends/cpu/annotation_stack.h" #include +#include +#include +#include +#include +#include +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "tsl/platform/types.h" namespace tsl { namespace profiler { -namespace internal { - -#ifdef _WIN32 -#define DECL_DLL_EXPORT __declspec(dllexport) -#else -#define DECL_DLL_EXPORT -#endif -// DLL imported variables cannot be initialized on Windows. This file is -// included only on DLL exports. -DECL_DLL_EXPORT std::atomic g_annotation_enabled(0); - -// g_annotation_enabled implementation must be lock-free for faster execution of -// the ScopedAnnotation API. This can be commented (if compilation is failing) -// but execution might be slow (even when tracing is disabled). -static_assert(ATOMIC_INT_LOCK_FREE == 2, "Assumed atomic was lock free"); -} // namespace internal +// Returns the annotation data for the given generation. +static auto GetAnnotationData(const std::atomic& atomic) { + static thread_local struct { + int generation = 0; + std::vector stack; + std::string string; + } data; + int generation = atomic.load(std::memory_order_acquire); + if (generation != data.generation) { + data = {generation}; + } + return std::make_pair(&data.stack, &data.string); +}; + +void AnnotationStack::PushAnnotation(std::string_view name) { + auto [stack, string] = GetAnnotationData(generation_); + stack->push_back(string->size()); + if (!string->empty()) { + return absl::StrAppend( + string, "::", absl::string_view(name.data(), name.size()) // NOLINT + ); + } + string->assign(name); +} + +void AnnotationStack::PopAnnotation() { + auto [stack, string] = GetAnnotationData(generation_); + if (stack->empty()) { + return string->clear(); + } + string->resize(stack->back()); + stack->pop_back(); +} -/*static*/ string* AnnotationStack::ThreadAnnotationStack() { - static thread_local string annotation_stack; - return &annotation_stack; +const string& AnnotationStack::Get() { + return *std::get(GetAnnotationData(generation_)); } +void AnnotationStack::Enable(bool enable) { + int generation = generation_.load(std::memory_order_relaxed); + while (!generation_.compare_exchange_weak( + generation, enable ? generation | 1 : generation + 1 & ~1, + std::memory_order_release)) { + } +} + +// AnnotationStack::generation_ implementation must be lock-free for faster +// execution of the ScopedAnnotation API. +std::atomic AnnotationStack::generation_{0}; +static_assert(ATOMIC_INT_LOCK_FREE == 2, "Assumed atomic was lock free"); + } // namespace profiler } // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/profiler/backends/cpu/annotation_stack.h b/third_party/xla/third_party/tsl/tsl/profiler/backends/cpu/annotation_stack.h index 23bd5236f185bf..44d1626e6a5cb3 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/backends/cpu/annotation_stack.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/backends/cpu/annotation_stack.h @@ -15,80 +15,42 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PROFILER_BACKENDS_CPU_ANNOTATION_STACK_H_ #define TENSORFLOW_TSL_PROFILER_BACKENDS_CPU_ANNOTATION_STACK_H_ -#include - #include -#include +#include -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "tsl/platform/macros.h" #include "tsl/platform/types.h" namespace tsl { namespace profiler { -namespace internal { - -// Whether annotations are enabled. -// Static atomic so Annotation::IsEnabled can be fast and non-blocking. -TF_EXPORT extern std::atomic g_annotation_enabled; - -} // namespace internal // Backend for ScopedAnnotation. class AnnotationStack { public: - // Appends name to the annotation for the current thread and returns the - // original length of the annotation. - // Append name to the current annotation, separated by "::". - // The choice of separator "::" is based on characters not used by - // TensorFlow for its TensorOps. - static size_t PushAnnotation(absl::string_view name) { - string* annotation_stack = ThreadAnnotationStack(); - size_t old_length = annotation_stack->size(); - if (old_length != 0) { - absl::StrAppend(annotation_stack, "::", name); - } else { - *annotation_stack = string(name); - } - return old_length; - } + // Appends name to the annotations for the current thread, separated by "::". + // The choice of separator "::" is based on characters not used by TensorFlow + // for its TensorOps. + static void PushAnnotation(std::string_view name); - static size_t PushAnnotation(string&& name) { - string* annotation_stack = ThreadAnnotationStack(); - size_t old_length = annotation_stack->size(); - if (old_length != 0) { - absl::StrAppend(annotation_stack, "::", name); - } else { - *annotation_stack = std::move(name); - } - return old_length; - } + // Resizes the annotation stack for the current thread. + static void PopAnnotation(); // Returns the annotation stack for the current thread. - static const string& Get() { return *ThreadAnnotationStack(); } + static const string& Get(); - // Resizes the annotation stack for the current thread to its old length. - static void PopAnnotation(size_t old_length) { - ThreadAnnotationStack()->resize(old_length); - } - - static void Enable(bool enable) { - internal::g_annotation_enabled.store(enable, std::memory_order_release); - } + // Enables or disables the annotation stack. + static void Enable(bool enable); + // Returns whether the annotation stack is enabled. static bool IsEnabled() { - return internal::g_annotation_enabled.load(std::memory_order_acquire); + return generation_.load(std::memory_order_acquire) & 1; } private: AnnotationStack() = default; - AnnotationStack(const AnnotationStack&) = delete; - void operator=(const AnnotationStack&) = delete; - - // Returns a reference to the annotation for the current thread. - static string* ThreadAnnotationStack(); + // Enabled if odd, disabled if even. The value is incremented for every call + // to Enable() which changes the enabled state. + static std::atomic generation_; }; } // namespace profiler diff --git a/third_party/xla/third_party/tsl/tsl/profiler/builds/BUILD b/third_party/xla/third_party/tsl/tsl/profiler/builds/BUILD index b5bac93cb20919..050103eec5570b 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/builds/BUILD +++ b/third_party/xla/third_party/tsl/tsl/profiler/builds/BUILD @@ -1,5 +1,8 @@ +load("//tsl:tsl.bzl", "internal_visibility") + package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility(["//tsl/profiler:internal"]), licenses = ["notice"], ) @@ -7,5 +10,4 @@ package( config_setting( name = "profiler_build_oss", define_values = {"profiler_build": "oss"}, - visibility = ["//visibility:public"], ) diff --git a/third_party/xla/third_party/tsl/tsl/profiler/builds/oss/BUILD b/third_party/xla/third_party/tsl/tsl/profiler/builds/oss/BUILD index 9ea44126a3c89a..d406ee17088ad7 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/builds/oss/BUILD +++ b/third_party/xla/third_party/tsl/tsl/profiler/builds/oss/BUILD @@ -1,7 +1,6 @@ # Tensorflow default + linux implementations of tensorflow/core/profiler libraries. package( - default_visibility = ["//visibility:public"], # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) diff --git a/third_party/xla/third_party/tsl/tsl/profiler/convert/BUILD b/third_party/xla/third_party/tsl/tsl/profiler/convert/BUILD index a92c9f0cdf7300..12273bf96f1964 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/convert/BUILD +++ b/third_party/xla/third_party/tsl/tsl/profiler/convert/BUILD @@ -1,13 +1,14 @@ -load("//tsl:tsl.bzl", "set_external_visibility") +load("//tsl:tsl.bzl", "internal_visibility") +load("//tsl/platform:build_config.bzl", "tsl_cc_test") load( "@local_tsl//tsl/platform:rules_cc.bzl", "cc_library", ) -load("//tsl/platform:build_config.bzl", "tsl_cc_test") load("//tsl/profiler/builds:build_config.bzl", "tf_profiler_copts") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility(["//tsl/profiler:internal"]), licenses = ["notice"], ) @@ -16,7 +17,9 @@ cc_library( srcs = ["trace_container.cc"], hdrs = ["trace_container.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = [ + "//tsl/profiler:internal", + ], deps = [ "//tsl/platform:protobuf", "//tsl/profiler/protobuf:trace_events_proto_cc", @@ -26,7 +29,11 @@ cc_library( cc_library( name = "xla_op_utils", hdrs = ["xla_op_utils.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tsl/profiler:internal", + "//tsl/profiler:xla_profiler_backends", + "@local_xla//xla/python:__pkg__", + ]), deps = ["@com_google_absl//absl/strings"], ) @@ -34,7 +41,6 @@ tsl_cc_test( name = "xla_op_utils_test", size = "small", srcs = ["xla_op_utils_test.cc"], - visibility = ["//visibility:public"], deps = [ ":xla_op_utils", "//tsl/platform:test", @@ -47,10 +53,11 @@ cc_library( srcs = ["post_process_single_host_xplane.cc"], hdrs = ["post_process_single_host_xplane.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tsl/profiler:internal"]), deps = [ "//tsl/platform:types", "//tsl/profiler/protobuf:xplane_proto_cc", + "//tsl/profiler/utils:timestamp_utils", "//tsl/profiler/utils:xplane_schema", "//tsl/profiler/utils:xplane_utils", ], @@ -61,7 +68,9 @@ cc_library( srcs = ["trace_events_to_json.cc"], hdrs = ["trace_events_to_json.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tsl/profiler:internal", + ]), deps = [ ":trace_container", "//tsl/platform:protobuf", @@ -78,7 +87,6 @@ cc_library( tsl_cc_test( name = "trace_container_test", srcs = ["trace_container_test.cc"], - visibility = ["//visibility:public"], deps = [ ":trace_container", "//tsl/platform:protobuf", @@ -91,7 +99,6 @@ tsl_cc_test( tsl_cc_test( name = "trace_events_to_json_test", srcs = ["trace_events_to_json_test.cc"], - visibility = ["//visibility:public"], deps = [ ":trace_container", ":trace_events_to_json", @@ -108,7 +115,9 @@ cc_library( srcs = ["xplane_to_trace_events.cc"], hdrs = ["xplane_to_trace_events.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tsl/profiler:internal", + ]), deps = [ ":trace_container", "//tsl/platform:types", @@ -128,7 +137,6 @@ tsl_cc_test( name = "xplane_to_trace_events_test", size = "small", srcs = ["xplane_to_trace_events_test.cc"], - visibility = ["//visibility:public"], deps = [ ":xplane_to_trace_events", "//tsl/platform:test", diff --git a/third_party/xla/third_party/tsl/tsl/profiler/convert/post_process_single_host_xplane.cc b/third_party/xla/third_party/tsl/tsl/profiler/convert/post_process_single_host_xplane.cc index fbba8c2eb840ab..49e2f7dbda2ae3 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/convert/post_process_single_host_xplane.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/convert/post_process_single_host_xplane.cc @@ -17,7 +17,9 @@ limitations under the License. #include #include +#include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "tsl/profiler/utils/timestamp_utils.h" #include "tsl/profiler/utils/xplane_schema.h" #include "tsl/profiler/utils/xplane_utils.h" @@ -43,7 +45,7 @@ void MergeHostPlanesAndSortLines(tensorflow::profiler::XSpace* space) { } // namespace void PostProcessSingleHostXSpace(tensorflow::profiler::XSpace* space, - uint64 start_time_ns) { + uint64 start_time_ns, uint64 stop_time_ns) { VLOG(3) << "Post processing local profiler XSpace."; // Post processing the collected XSpace without hold profiler lock. // 1. Merge all host planes and sorts lines by name. @@ -51,7 +53,10 @@ void PostProcessSingleHostXSpace(tensorflow::profiler::XSpace* space, // 2. Normalize all timestamps by shifting timeline to profiling start time. // NOTE: this have to be done before sorting XSpace due to timestamp overflow. NormalizeTimestamps(space, start_time_ns); - // 3. Sort each plane of the XSpace + // 3. Add information regarding profiling start_time_ns_ and stop_time_ns_ to + // taskEnv. + SetSessionTimestamps(start_time_ns, stop_time_ns, *space); + // 4. Sort each plane of the XSpace SortXSpace(space); } diff --git a/third_party/xla/third_party/tsl/tsl/profiler/convert/post_process_single_host_xplane.h b/third_party/xla/third_party/tsl/tsl/profiler/convert/post_process_single_host_xplane.h index d0183f2dba188d..0b413931e989fd 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/convert/post_process_single_host_xplane.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/convert/post_process_single_host_xplane.h @@ -23,7 +23,7 @@ namespace profiler { // Post process XSpaces collected locally from multiple profilers. void PostProcessSingleHostXSpace(tensorflow::profiler::XSpace* space, - uint64 start_time_ns); + uint64 start_time_ns, uint64 stop_time_ns); } // namespace profiler } // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/BUILD b/third_party/xla/third_party/tsl/tsl/profiler/lib/BUILD index 48427c0b61999d..cf9e41e245eed5 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/BUILD +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/BUILD @@ -1,8 +1,8 @@ -load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") -load("//tsl/platform:build_config_root.bzl", "if_static") +load("//tsl:tsl.bzl", "if_not_android", "internal_visibility", "nvtx_headers") load("//tsl:tsl.default.bzl", "filegroup") -load("//tsl:tsl.bzl", "if_not_android", "set_external_visibility") load("//tsl/platform:build_config.bzl", "tsl_cc_test") +load("//tsl/platform:build_config_root.bzl", "if_static") +load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load( "//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", @@ -53,7 +53,9 @@ cc_library( name = "profiler_controller", srcs = ["profiler_controller.cc"], hdrs = ["profiler_controller.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tsl/profiler:internal", + ]), deps = [ ":profiler_interface", "//tsl/platform:errors", @@ -66,7 +68,12 @@ cc_library( cc_library( name = "profiler_factory", hdrs = ["profiler_factory.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tsl/profiler:internal", + "//tsl/profiler:xla_profiler_backends", + "@local_xla//xla/python:__pkg__", + "//learning/brain/tfrc/executor/stream_executor:__pkg__", + ]), deps = [ ":profiler_interface", "//tsl/profiler/protobuf:profiler_options_proto_cc", @@ -82,7 +89,10 @@ cc_library( "profiler_factory.h", ], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tsl/profiler:internal", + "//learning/brain/tfrc/executor/stream_executor:__pkg__", + ]), deps = [ ":profiler_interface", "//tsl/platform:mutex", @@ -95,7 +105,6 @@ cc_library( tsl_cc_test( name = "profiler_factory_test", srcs = ["profiler_factory_test.cc"], - visibility = ["//visibility:public"], deps = [ ":profiler_factory", ":profiler_factory_impl", @@ -114,7 +123,11 @@ cc_library( name = "profiler_interface", hdrs = ["profiler_interface.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tsl:internal", + "//tsl/profiler:internal", + "//tsl/profiler:xla_profiler_backends", + ]), deps = [ "//tsl/platform:status", "//tsl/profiler/protobuf:xplane_proto_cc", @@ -126,7 +139,10 @@ cc_library( srcs = ["profiler_lock.cc"], hdrs = ["profiler_lock.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tsl/profiler:internal", + "//tsl/profiler:xla_internal", + ]), deps = [ "//tsl/platform:errors", "//tsl/platform:macros", @@ -138,7 +154,6 @@ cc_library( tsl_cc_test( name = "profiler_lock_test", srcs = ["profiler_lock_test.cc"], - visibility = ["//visibility:public"], deps = [ ":profiler_lock", "//tsl/platform:test", @@ -149,7 +164,7 @@ tsl_cc_test( cc_library( name = "profiler_session", hdrs = ["profiler_session.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tsl:internal"]), deps = [ "//tsl/platform", "//tsl/platform:errors", @@ -174,7 +189,10 @@ cc_library( "profiler_session.h", ], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/python:__pkg__", + "//tsl/profiler:internal", + ]), deps = [ "//tsl/platform:errors", "//tsl/platform:logging", @@ -213,11 +231,11 @@ cc_library( tsl_cc_test( name = "traceme_encode_test", srcs = ["traceme_encode_test.cc"], - visibility = ["//visibility:public"], deps = [ ":traceme_encode", "//tsl/platform", "//tsl/platform:test", + "//tsl/platform:test_benchmark", "//tsl/platform:test_main", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -227,7 +245,7 @@ tsl_cc_test( tf_profiler_pybind_cc_library_wrapper( name = "traceme_for_pybind", actual = ":traceme", - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tsl/profiler:xla_internal"]), ) cc_library( @@ -250,15 +268,14 @@ cc_library( cc_library( name = "nvtx_utils", hdrs = ["nvtx_utils.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), visibility = ["//visibility:public"], deps = [ "//tsl/platform:logging", "//tsl/platform:macros", "//tsl/platform:types", "@com_google_absl//absl/strings", - ] + if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", # NVTX headers - ]), + ] + if_cuda_is_configured(nvtx_headers()), ) cc_library( @@ -275,26 +292,12 @@ cc_library( ]), ) -cc_library( - name = "scoped_annotation_stack", - hdrs = ["scoped_annotation_stack.h"], - visibility = ["//visibility:public"], - deps = [ - "@com_google_absl//absl/strings", - ] + if_not_android([ - ":nvtx_utils", - "//tsl/profiler/backends/cpu:annotation_stack", - ]), -) - tsl_cc_test( name = "scoped_annotation_test", size = "small", srcs = ["scoped_annotation_test.cc"], - visibility = ["//visibility:public"], deps = [ ":scoped_annotation", - ":scoped_annotation_stack", "//tsl/platform:test", "//tsl/platform:test_benchmark", "//tsl/platform:test_main", @@ -321,7 +324,10 @@ cc_library( name = "profiler_collection", srcs = ["profiler_collection.cc"], hdrs = ["profiler_collection.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "@local_xla//xla/backends/profiler/plugin:__pkg__", + "//learning/brain/tfrc/executor/stream_executor:__pkg__", + ]), deps = [ ":profiler_interface", "//tsl/platform:status", diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/context_types.cc b/third_party/xla/third_party/tsl/tsl/profiler/lib/context_types.cc index 9379885c4de76b..371631c10ba882 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/context_types.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/context_types.cc @@ -46,6 +46,8 @@ const char* GetContextTypeString(ContextType context_type) { return "tpu_launch"; case ContextType::kPathwaysExecutor: return "pathways_exec"; + case ContextType::kPjrtLibraryCall: + return "pjrt_library_call"; } } diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/context_types.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/context_types.h index 6f65454354a1dc..621f35462fdae2 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/context_types.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/context_types.h @@ -36,6 +36,7 @@ enum class ContextType : int { kTpuStream, kTpuLaunch, kPathwaysExecutor, + kPjrtLibraryCall, kLastContextType = ContextType::kTpuLaunch, }; diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.h index e3eaaa08af79e8..8713550b3a20c8 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.h @@ -17,42 +17,28 @@ limitations under the License. #define TENSORFLOW_TSL_PROFILER_LIB_NVTX_UTILS_H_ #include - -#include "absl/strings/string_view.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/macros.h" +#include #if GOOGLE_CUDA #include "nvtx3/nvToolsExt.h" +#include "nvtx3/nvToolsExtPayload.h" #else // Some typedef to help build without NVTX. -typedef void* nvtxEventAttributes_t; typedef void* nvtxDomainHandle_t; typedef void* nvtxStringHandle_t; #endif namespace tsl { namespace profiler { -namespace nvtx { // A helper function that return the domains to use if NVTX profiling // is enabled. inline std::optional GetNVTXDomain() { #if GOOGLE_CUDA - static nvtxDomainHandle_t domain; - static bool is_enabled = [] { - bool _is_enabled = false; - // Force NVTX marker if a tool triggered the profiler. - domain = nvtxDomainCreateA("TSL"); - if (domain) { - _is_enabled = true; - } - VLOG(1) << "Is NVTX marker enabled? " << _is_enabled; - return _is_enabled; - }(); - if (is_enabled) return domain; + static nvtxDomainHandle_t domain = nvtxDomainCreateA("TSL"); + if (domain != nullptr) return domain; #endif - return {}; + return std::nullopt; } // A helper function to decide whether to enable CUDA NVTX profiling ranges. @@ -64,42 +50,39 @@ inline bool RangesEnabled() { #endif } -// Two types of NVTX range annotation are supported, the older/simpler option -// is to use std::string and have the NVTX implementation copy a C-style -// string every time. The other option is to pass a struct implementing two -// methods: -// -// std::string_view Title() const; -// nvtxStringHandle_t NvtxRegisteredTitle() const; -// -// in which case NvtxRegisteredTitle() will be used when starting NVTX ranges, -// avoiding this string copy. -// The Title() method is needed because AnnotationStack::PushAnnotation(...) is -// the backend for some annotations when NVTX is not enabled, and it does not -// recognise registered strings. has_annotation_api_v -// distinguishes between the two types of annotation. -template -inline constexpr bool has_annotation_api_v = - !std::is_same_v; +// Older/simpler version; NVTX implementation copies a C-style string each time +inline void RangePush(nvtxDomainHandle_t domain, const char* ascii) { +#if GOOGLE_CUDA + nvtxEventAttributes_t attrs{}; + attrs.version = NVTX_VERSION; + attrs.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE; + attrs.messageType = NVTX_MESSAGE_TYPE_ASCII; + attrs.message.ascii = ascii; + ::nvtxDomainRangePushEx(domain, &attrs); +#endif +} +inline void RangePush(nvtxDomainHandle_t domain, const std::string& str) { + RangePush(domain, str.c_str()); +} -template -void RangePush(nvtxDomainHandle_t domain, const AnnotationType& annotation) { +// More powerful version: pass a registered string instead of a C-style string, +// and attach a generic payload. The Annotation type must implement a method +// called NvtxSchemaId() that allows the NVTX backend to interpret the payload. +template +void RangePush(nvtxDomainHandle_t domain, nvtxStringHandle_t handle, + const Annotation& annotation) { #if GOOGLE_CUDA nvtxEventAttributes_t attrs{}; attrs.version = NVTX_VERSION; attrs.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE; - if constexpr (has_annotation_api_v>) { - attrs.messageType = NVTX_MESSAGE_TYPE_REGISTERED; - attrs.message.registered = annotation.NvtxRegisteredTitle(); - } else { - attrs.messageType = NVTX_MESSAGE_TYPE_ASCII; - attrs.message.ascii = annotation.c_str(); - } + attrs.messageType = NVTX_MESSAGE_TYPE_REGISTERED; + attrs.message.registered = handle; + NVTX_PAYLOAD_EVTATTR_SET(attrs, annotation.NvtxSchemaId(), &annotation, + sizeof(Annotation)); ::nvtxDomainRangePushEx(domain, &attrs); #endif } -} // namespace nvtx } // namespace profiler } // namespace tsl #endif // TENSORFLOW_TSL_PROFILER_LIB_NVTX_UTILS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_session.cc b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_session.cc index f5ba7e1ee9281e..2edf404b64e4e4 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_session.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_session.cc @@ -70,6 +70,7 @@ Status ProfilerSession::CollectDataInternal(XSpace* space) { LOG(INFO) << "Profiler session collecting data."; if (profilers_ != nullptr) { profilers_->Stop().IgnoreError(); + stop_time_ns_ = profiler::GetCurrentTimeNanos(); profilers_->CollectData(space).IgnoreError(); profilers_.reset(); // data has been collected. } @@ -83,7 +84,7 @@ Status ProfilerSession::CollectData(XSpace* space) { #if !defined(IS_MOBILE_PLATFORM) space->add_hostnames(port::Hostname()); TF_RETURN_IF_ERROR(CollectDataInternal(space)); - profiler::PostProcessSingleHostXSpace(space, start_time_ns_); + profiler::PostProcessSingleHostXSpace(space, start_time_ns_, stop_time_ns_); #endif return OkStatus(); } diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_session.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_session.h index 424e5c87d0b4ef..e6fb67218ac5fe 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_session.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_session.h @@ -83,6 +83,7 @@ class ProfilerSession { std::unique_ptr profilers_ TF_GUARDED_BY(mutex_); uint64 start_time_ns_; + uint64 stop_time_ns_; tensorflow::ProfileOptions options_; #endif tsl::Status status_ TF_GUARDED_BY(mutex_); diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation.h index f047fafc4ebe3a..b779d600959466 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation.h @@ -18,123 +18,84 @@ limitations under the License. #include #include -#include #include #include #include -#include "absl/strings/string_view.h" #include "tsl/platform/macros.h" -#include "tsl/platform/types.h" #if !defined(IS_MOBILE_PLATFORM) #include "tsl/profiler/backends/cpu/annotation_stack.h" +#endif + +#if GOOGLE_CUDA #include "tsl/profiler/lib/nvtx_utils.h" #endif namespace tsl { namespace profiler { -// Adds an annotation to all activities for the duration of the instance -// lifetime through the currently registered TraceCollector. -// -// Usage: { -// ScopedAnnotation annotation("my kernels"); -// Kernel1<<>>; -// LaunchKernel2(); // Launches a CUDA kernel. -// } -// This will add 'my kernels' to both kernels in the profiler UI -template -class ScopedAnnotationT { - public: - explicit ScopedAnnotationT(absl::string_view name) { -#if !defined(IS_MOBILE_PLATFORM) +// Adds an annotation to all activities through the currently registered +// TraceCollector until PopAnnotation() is called. +template +inline void PushAnnotation(const T& generator) { #if GOOGLE_CUDA - std::optional domain = - tsl::profiler::nvtx::GetNVTXDomain(); - if (TF_PREDICT_FALSE(domain.has_value())) { - tsl::profiler::nvtx::RangePush(domain.value(), std::string{name}); - } else // NOLINT -#endif - if (always_annotate || TF_PREDICT_FALSE(AnnotationStack::IsEnabled())) { - old_length_ = AnnotationStack::PushAnnotation(name); - } -#endif + if (auto domain = GetNVTXDomain(); TF_PREDICT_FALSE(domain.has_value())) { + return RangePush(*domain, generator()); } +#endif - explicit ScopedAnnotationT(const char* name) - : ScopedAnnotationT(absl::string_view(name)) {} - - explicit ScopedAnnotationT(const string& name) { #if !defined(IS_MOBILE_PLATFORM) -#if GOOGLE_CUDA - std::optional domain = - tsl::profiler::nvtx::GetNVTXDomain(); - if (TF_PREDICT_FALSE(domain.has_value())) { - tsl::profiler::nvtx::RangePush(domain.value(), name); - } else // NOLINT -#endif - if (always_annotate || TF_PREDICT_FALSE(AnnotationStack::IsEnabled())) { - old_length_ = AnnotationStack::PushAnnotation(name); - } -#endif + if (TF_PREDICT_FALSE(AnnotationStack::IsEnabled())) { + AnnotationStack::PushAnnotation(static_cast(generator())); } +#endif +} + +inline void PushAnnotation(const char* name) { + PushAnnotation([&] { return name; }); +} +inline void PushAnnotation(const std::string& name) { + PushAnnotation([&] { return name; }); +} + +inline void PopAnnotation() { + // TODO(b/137971921): without this memory fence, two presubmit tests will + // fail probably due to compiler in that presubmit config. + std::atomic_thread_fence(std::memory_order_acquire); - explicit ScopedAnnotationT(string&& name) { -#if !defined(IS_MOBILE_PLATFORM) #if GOOGLE_CUDA - std::optional domain = - tsl::profiler::nvtx::GetNVTXDomain(); - if (TF_PREDICT_FALSE(domain.has_value())) { - tsl::profiler::nvtx::RangePush(domain.value(), name); - } else // NOLINT -#endif - if (always_annotate || TF_PREDICT_FALSE(AnnotationStack::IsEnabled())) { - old_length_ = AnnotationStack::PushAnnotation(std::move(name)); - } -#endif + if (auto domain = GetNVTXDomain(); TF_PREDICT_FALSE(domain.has_value())) { + ::nvtxDomainRangePop(*domain); + return; } +#endif - template - explicit ScopedAnnotationT(NameGeneratorT name_generator) { #if !defined(IS_MOBILE_PLATFORM) -#if GOOGLE_CUDA - std::optional domain = - tsl::profiler::nvtx::GetNVTXDomain(); - if (TF_PREDICT_FALSE(domain.has_value())) { - tsl::profiler::nvtx::RangePush(domain.value(), name_generator()); - } else // NOLINT -#endif - if (always_annotate || TF_PREDICT_FALSE(AnnotationStack::IsEnabled())) { - auto annotation = name_generator(); - if constexpr (tsl::profiler::nvtx::has_annotation_api_v< - std::decay_t>) { - old_length_ = AnnotationStack::PushAnnotation(annotation.Title()); - } else { - old_length_ = AnnotationStack::PushAnnotation(std::move(annotation)); - } - } + if (TF_PREDICT_FALSE(AnnotationStack::IsEnabled())) { + AnnotationStack::PopAnnotation(); + } #endif +} + +// Adds an annotation to all activities for the duration of the instance +// lifetime through the currently registered TraceCollector. +// +// Usage: { +// ScopedAnnotation annotation("my kernels"); +// Kernel1<<>>; +// LaunchKernel2(); // Launches a CUDA kernel. +// } +// This will add 'my kernels' to both kernels in the profiler UI +class ScopedAnnotation { + public: + template + explicit ScopedAnnotation(T&& annotation) { + PushAnnotation(std::forward(annotation)); } // Pops the name passed in the constructor from the current annotation. - ~ScopedAnnotationT() { - // TODO(b/137971921): without this memory fence, two presubmit tests will - // fail probably due to compiler in that presubmit config. - std::atomic_thread_fence(std::memory_order_acquire); -#if !defined(IS_MOBILE_PLATFORM) -#if GOOGLE_CUDA - std::optional domain = - tsl::profiler::nvtx::GetNVTXDomain(); - if (TF_PREDICT_FALSE(domain.has_value())) { - ::nvtxDomainRangePop(domain.value()); - } else // NOLINT -#endif - if (TF_PREDICT_FALSE(old_length_ != kInvalidLength)) { - AnnotationStack::PopAnnotation(old_length_); - } -#endif - } + ~ScopedAnnotation() { PopAnnotation(); } static bool IsEnabled() { #if !defined(IS_MOBILE_PLATFORM) @@ -145,18 +106,10 @@ class ScopedAnnotationT { } private: - // signals that annotation is disabled at the constructor. - static constexpr size_t kInvalidLength = static_cast(-1); - - ScopedAnnotationT(const ScopedAnnotationT&) = delete; - void operator=(const ScopedAnnotationT&) = delete; - - size_t old_length_ = kInvalidLength; + ScopedAnnotation(const ScopedAnnotation&) = delete; + ScopedAnnotation& operator=(const ScopedAnnotation&) = delete; }; -using ScopedAnnotation = ScopedAnnotationT; -using ScopedAnnotationAlways = ScopedAnnotationT; - } // namespace profiler } // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation_stack.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation_stack.h deleted file mode 100644 index db46f7c99135e4..00000000000000 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation_stack.h +++ /dev/null @@ -1,118 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_LIB_SCOPED_ANNOTATION_STACK_H_ -#define TENSORFLOW_TSL_PROFILER_LIB_SCOPED_ANNOTATION_STACK_H_ - -#include - -#include -#include -#include -#include - -#include "absl/strings/string_view.h" -#if !defined(IS_MOBILE_PLATFORM) -#include "tsl/profiler/backends/cpu/annotation_stack.h" -#include "tsl/profiler/lib/nvtx_utils.h" -#endif - -namespace tsl { -namespace profiler { - -// ScopedAnnotation for clients that can't use RAII for managing the lifetime -// of annotations. It provides an API similar to the `TraceMe::ActivityStart` -// and `TraceMe::ActivityEnd`. -// -// Usage: -// int64_t id = ScopedAnnotationStack::ActivityStart("foo"); -// foo(); -// ScopedAnnotationStack::ActivityEnd(id); -// -// Prefer a regular `ScopedAnnotation`. The name of this class is a misnomer, -// because it doesn't do any automatic destruction at the scope end, it's just -// for the sake of consistency. -class ScopedAnnotationStack { - static constexpr size_t kInvalidActivity = static_cast(-1); - - public: - static bool IsEnabled() { return AnnotationStack::IsEnabled(); } - - static int64_t ActivityStart(std::string name) { -#if !defined(IS_MOBILE_PLATFORM) -#if GOOGLE_CUDA - std::optional domain = - tsl::profiler::nvtx::GetNVTXDomain(); - if (TF_PREDICT_FALSE(domain.has_value())) { - tsl::profiler::nvtx::RangePush(domain.value(), name); - } else // NOLINT -#endif - if (TF_PREDICT_FALSE(AnnotationStack::IsEnabled())) { - return AnnotationStack::PushAnnotation(std::move(name)); - } -#endif - return kInvalidActivity; - } - - static int64_t ActivityStart(std::string_view name) { - return ActivityStart(std::string(name)); - } - - static int64_t ActivityStart(const char* name) { - return ActivityStart(std::string_view(name)); - } - - template - static int64_t ActivityStart(NameGeneratorT name_generator) { -#if !defined(IS_MOBILE_PLATFORM) -#if GOOGLE_CUDA - std::optional domain = - tsl::profiler::nvtx::GetNVTXDomain(); - if (TF_PREDICT_FALSE(domain.has_value())) { - tsl::profiler::nvtx::RangePush(domain.value(), name_generator()); - } else // NOLINT -#endif - if (TF_PREDICT_FALSE(AnnotationStack::IsEnabled())) { - auto annotation = name_generator(); - if constexpr (tsl::profiler::nvtx::has_annotation_api_v< - std::decay_t>) { - return AnnotationStack::PushAnnotation(annotation.Title()); - } else { - return AnnotationStack::PushAnnotation(std::move(annotation)); - } - } -#endif - return kInvalidActivity; - } - - static void ActivityEnd(int64_t activity_id) { -#if !defined(IS_MOBILE_PLATFORM) -#if GOOGLE_CUDA - std::optional domain = - tsl::profiler::nvtx::GetNVTXDomain(); - if (TF_PREDICT_FALSE(domain.has_value())) { - ::nvtxDomainRangePop(domain.value()); - } else // NOLINT -#endif - if (TF_PREDICT_FALSE(activity_id != kInvalidActivity)) { - AnnotationStack::PopAnnotation(activity_id); - } -#endif - } -}; - -} // namespace profiler -} // namespace tsl - -#endif // TENSORFLOW_TSL_PROFILER_LIB_SCOPED_ANNOTATION_STACK_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation_test.cc b/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation_test.cc index dab3e91f2ed4cd..0ae9d3276375f1 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation_test.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include "tsl/platform/test.h" #include "tsl/platform/test_benchmark.h" #include "tsl/profiler/backends/cpu/annotation_stack.h" -#include "tsl/profiler/lib/scoped_annotation_stack.h" namespace tsl { namespace profiler { @@ -50,11 +49,11 @@ TEST(ScopedAnnotation, Simple) { { AnnotationStack::Enable(true); - int64_t id0 = ScopedAnnotationStack::ActivityStart("foo"); - int64_t id1 = ScopedAnnotationStack::ActivityStart("bar"); + PushAnnotation("foo"); + PushAnnotation("bar"); EXPECT_EQ(AnnotationStack::Get(), "foo::bar"); // enabled - ScopedAnnotationStack::ActivityEnd(id1); - ScopedAnnotationStack::ActivityEnd(id0); + PopAnnotation(); + PopAnnotation(); AnnotationStack::Enable(false); } diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme_encode_test.cc b/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme_encode_test.cc index ea64d28f9e2b48..4827bee4d820b6 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme_encode_test.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme_encode_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "tsl/platform/platform.h" #include "tsl/platform/test.h" +#include "tsl/platform/test_benchmark.h" namespace tsl { namespace profiler { @@ -81,5 +82,25 @@ TEST(TraceMeEncodeTest, NoNameTest) { } } // namespace + +void BM_TraceMeEncode(::testing::benchmark::State& state) { + for (auto s : state) { + TraceMeEncode( + "MyTestEvent", + {{"Lorem ipsum dolor sit amet", 1}, + {"consectetur adipiscing elit", 2}, + {"sed do eiusmod tempor incididunt", 3.52}, + {"ut labore et dolore magna aliqua", "Ut enim ad minim veniam"}, + {"quis nostrud exercitation ullamco", "laboris nisi ut aliquip ex"}, + {"ea commodo consequat.", 11111.1111}, + {"Duis aute", 1234567890}, + {"irure dolor in", " reprehenderit in voluptate"}, + {"velit esse cillum dolore", "eu fugiat nulla pariatur."}, + {"Excepteur sint", "occaecat cupidatat non proident, sunt in"}, + {"culpa qui officia", "deserunt mollit anim id est laborum."}}); + } +} +BENCHMARK(BM_TraceMeEncode); + } // namespace profiler } // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/profiler/protobuf/BUILD b/third_party/xla/third_party/tsl/tsl/profiler/protobuf/BUILD index a8194cdb3e182f..f43bc94f52f585 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/protobuf/BUILD +++ b/third_party/xla/third_party/tsl/tsl/profiler/protobuf/BUILD @@ -1,6 +1,6 @@ # Placeholder: load py_proto_library # copybara:uncomment(oss-unused) load("//net/grpc/go/build_defs:go_grpc_library.bzl", "go_grpc_library") -load("//tsl:tsl.bzl", "set_external_visibility") +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl/platform:build_config.bzl", "tf_proto_library") # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) @@ -17,7 +17,7 @@ tf_proto_library( srcs = ["xplane.proto"], cc_api_version = 2, make_default_target_header_only = True, - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), ) tf_proto_library( @@ -71,7 +71,7 @@ tf_proto_library( name = "trace_events_proto", srcs = ["trace_events.proto"], cc_api_version = 2, - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), ) # copybara:uncomment_begin(google-only) @@ -86,14 +86,20 @@ tf_proto_library( # This is needed because of how tf_android_core_proto_sources parses proto paths. exports_files( srcs = ["xplane.proto"], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + "//tsl:__pkg__", + ]), ) tf_proto_library( name = "profile_proto", srcs = ["profile.proto"], cc_api_version = 2, - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "@local_xla//xla/python:__pkg__", + "//tsl/profiler:internal", + ]), ) tf_proto_library( @@ -113,7 +119,7 @@ tf_proto_library( # py_proto_library( # name = "xplane_py_pb2", # api_version = 2, -# visibility = set_external_visibility([":friends"]), +# visibility = internal_visibility([":friends"]), # deps = [":xplane_proto"], # ) # copybara:uncomment_end @@ -123,5 +129,5 @@ tf_proto_library( srcs = ["profiled_instructions.proto"], cc_api_version = 2, make_default_target_header_only = True, - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), ) diff --git a/third_party/xla/third_party/tsl/tsl/profiler/rpc/BUILD b/third_party/xla/third_party/tsl/tsl/profiler/rpc/BUILD index 5f7f5dbc3205be..debedf53094de4 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/rpc/BUILD +++ b/third_party/xla/third_party/tsl/tsl/profiler/rpc/BUILD @@ -1,14 +1,15 @@ +load("//tsl:tsl.bzl", "internal_visibility") +load("//tsl:tsl.default.bzl", "tsl_grpc_cc_dependencies") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load( "//tsl/profiler/builds:build_config.bzl", "tf_profiler_copts", "tf_profiler_pybind_cc_library_wrapper", ) -load("//tsl:tsl.default.bzl", "tsl_grpc_cc_dependencies") -load("//tsl:tsl.bzl", "set_external_visibility") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility(["//tsl/profiler:internal"]), licenses = ["notice"], ) @@ -18,7 +19,14 @@ cc_library( srcs = ["profiler_service_impl.cc"], hdrs = ["profiler_service_impl.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core/data/service:__pkg__", + "//tensorflow/core/distributed_runtime/rpc:__pkg__", + "//tensorflow/core/profiler/rpc:__pkg__", + "//tensorflow/python:__pkg__", + "//tsl/profiler/rpc/client:__pkg__", + "//tensorflow_serving/model_servers:__pkg__", + ]), deps = [ "//tsl/platform:env", "//tsl/platform:env_time", @@ -45,7 +53,7 @@ cc_library( tf_profiler_pybind_cc_library_wrapper( name = "profiler_server_for_pybind", actual = ":profiler_server_impl", - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow/python/profiler/internal:__pkg__"]), ) cc_library( @@ -53,7 +61,14 @@ cc_library( srcs = ["profiler_server.cc"], hdrs = ["profiler_server.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "@local_xla//xla:__subpackages__", + "//tensorflow/core/profiler/rpc:__pkg__", + "//tensorflow/python:__pkg__", + "//tensorflow/python/profiler/internal:__pkg__", + "//tsl/profiler:internal", + "//tsl/profiler/rpc/client:__pkg__", + ]), deps = [ ":profiler_service_impl", "//tsl/platform:logging", diff --git a/third_party/xla/third_party/tsl/tsl/profiler/rpc/client/BUILD b/third_party/xla/third_party/tsl/tsl/profiler/rpc/client/BUILD index 8dc488fc78c5aa..deb9383157a594 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/rpc/client/BUILD +++ b/third_party/xla/third_party/tsl/tsl/profiler/rpc/client/BUILD @@ -1,6 +1,6 @@ -load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") -load("//tsl:tsl.bzl", "set_external_visibility") +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl:tsl.default.bzl", "tsl_grpc_cc_dependencies") +load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load( "//tsl/platform:build_config.bzl", "tf_protos_profiler_service", @@ -13,7 +13,10 @@ load( ) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//tsl/profiler:internal", + ]), licenses = ["notice"], ) @@ -22,7 +25,11 @@ cc_library( srcs = ["capture_profile.cc"], hdrs = ["capture_profile.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "@local_xla//xla/python:__pkg__", + "//tensorflow/core/profiler/rpc/client:__pkg__", + "//tensorflow/python/profiler/internal:__pkg__", + ]), deps = [ ":profiler_client_for_pybind", ":remote_profiler_session_manager", @@ -49,7 +56,12 @@ cc_library( srcs = ["save_profile.cc"], hdrs = ["save_profile.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core/profiler/rpc/client:__pkg__", + "@local_xla//xla/python:__pkg__", + "//tsl/profiler:internal", + "//tsl/profiler/rpc:__pkg__", + ]), deps = [ "//tsl/lib/io:zlib_compression_options", "//tsl/lib/io:zlib_outputbuffer", @@ -69,13 +81,20 @@ cc_library( tf_profiler_pybind_cc_library_wrapper( name = "profiler_client_for_pybind", actual = ":profiler_client", - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core/profiler/rpc/client:__pkg__", + "//tensorflow/python/profiler/internal:__pkg__", + ]), ) cc_library( name = "profiler_client", hdrs = ["profiler_client.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "@local_xla//xla:__subpackages__", + "//tensorflow/core/profiler/rpc/client:__pkg__", + "//tensorflow/python/profiler/internal:__pkg__", + ]), deps = [ ":profiler_client_impl", "//tsl/platform:status", @@ -94,7 +113,12 @@ cc_library( "profiler_client.h", ], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "@local_xla//xla/python:__pkg__", + "//tensorflow/core/profiler/rpc/client:__pkg__", + "//tensorflow/python:__pkg__", + "//tensorflow/python/profiler/internal:__pkg__", + ]), deps = [ "//tsl/platform:errors", "//tsl/platform:logging", @@ -114,7 +138,6 @@ cc_library( name = "profiler_client_test_util", testonly = 1, hdrs = ["profiler_client_test_util.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:logging", "//tsl/platform:test", @@ -131,7 +154,6 @@ cc_library( tsl_cc_test( name = "profiler_client_test", srcs = ["profiler_client_test.cc"], - visibility = ["//visibility:public"], deps = [ ":profiler_client", ":profiler_client_impl", # for oss @@ -156,7 +178,6 @@ cc_library( srcs = ["remote_profiler_session_manager.cc"], hdrs = ["remote_profiler_session_manager.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], deps = [ ":profiler_client_for_pybind", "//tsl/platform:env_time", @@ -177,7 +198,6 @@ cc_library( tsl_cc_test( name = "remote_profiler_session_manager_test", srcs = ["remote_profiler_session_manager_test.cc"], - visibility = ["//visibility:public"], deps = [ ":profiler_client_impl", # for oss ":profiler_client_test_util", diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD b/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD index 5c1f5c5742339e..527723044ae302 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD @@ -1,11 +1,14 @@ -load("//tsl:tsl.bzl", "set_external_visibility") +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl/platform:build_config.bzl", "tsl_cc_test") load("//tsl/platform:build_config_root.bzl", "if_static") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load("//tsl/profiler/builds:build_config.bzl", "tf_profiler_copts") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//tsl/profiler:internal", + ]), licenses = ["notice"], ) @@ -19,13 +22,11 @@ package_group( cc_library( name = "math_utils", hdrs = ["math_utils.h"], - visibility = ["//visibility:public"], ) cc_library( name = "format_utils", hdrs = ["format_utils.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:logging", ], @@ -35,7 +36,7 @@ cc_library( name = "time_utils", hdrs = ["time_utils.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = [ ":math_utils", ] + if_static([ @@ -50,7 +51,11 @@ cc_library( "time_utils.h", ], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "@local_xla//xla:__subpackages__", + "//tsl/platform/cloud:__pkg__", + "//tsl/profiler:internal", + ]), deps = [ ":math_utils", "@com_google_absl//absl/time", @@ -62,7 +67,6 @@ cc_library( name = "timespan", hdrs = ["timespan.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], deps = [ ":math_utils", "//tsl/platform:logging", @@ -74,7 +78,6 @@ cc_library( tsl_cc_test( name = "timespan_test", srcs = ["timespan_test.cc"], - visibility = ["//visibility:public"], deps = [ ":timespan", "//tsl/platform:test", @@ -87,7 +90,6 @@ cc_library( srcs = ["tf_op_utils.cc"], hdrs = ["tf_op_utils.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], deps = [ "//tsl/platform:macros", "//tsl/platform:regexp", @@ -99,7 +101,6 @@ tsl_cc_test( name = "tf_op_utils_test", size = "small", srcs = ["tf_op_utils_test.cc"], - visibility = ["//visibility:public"], deps = [ ":tf_op_utils", "//tsl/platform:test", @@ -113,7 +114,7 @@ cc_library( srcs = ["xplane_schema.cc"], hdrs = ["xplane_schema.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = [ ":tf_op_utils", "//tsl/lib/gtl:map_util", @@ -133,7 +134,7 @@ cc_library( srcs = ["xplane_visitor.cc"], hdrs = ["xplane_visitor.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = [ ":timespan", "//tsl/platform:logging", @@ -150,7 +151,7 @@ cc_library( srcs = ["xplane_builder.cc"], hdrs = ["xplane_builder.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = [ ":math_utils", ":timespan", @@ -170,7 +171,6 @@ tsl_cc_test( name = "xplane_builder_test", size = "small", srcs = ["xplane_builder_test.cc"], - visibility = ["//visibility:public"], deps = [ ":xplane_builder", ":xplane_visitor", @@ -185,7 +185,10 @@ cc_library( name = "trace_utils", hdrs = ["trace_utils.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "@local_xla//xla/backends/profiler/gpu:__pkg__", + "//tsl/profiler:internal", + ]), deps = [ "//tsl/platform:types", "@com_google_absl//absl/strings", @@ -197,7 +200,7 @@ cc_library( srcs = ["xplane_utils.cc"], hdrs = ["xplane_utils.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = [ ":math_utils", ":tf_xplane_visitor", @@ -222,7 +225,6 @@ cc_library( tsl_cc_test( name = "xplane_utils_test", srcs = ["xplane_utils_test.cc"], - visibility = ["//visibility:public"], deps = [ ":math_utils", ":xplane_builder", @@ -244,7 +246,7 @@ cc_library( name = "tf_xplane_visitor", hdrs = ["tf_xplane_visitor.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = [ ":xplane_schema", ":xplane_visitor", @@ -257,7 +259,7 @@ cc_library( srcs = ["parse_annotation.cc"], hdrs = ["parse_annotation.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = [ "@com_google_absl//absl/strings", ], @@ -266,7 +268,6 @@ cc_library( tsl_cc_test( name = "parse_annotation_test", srcs = ["parse_annotation_test.cc"], - visibility = ["//visibility:public"], deps = [ ":parse_annotation", "//tsl/platform:test", @@ -280,7 +281,7 @@ cc_library( srcs = ["group_events.cc"], hdrs = ["group_events.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = [ ":tf_xplane_visitor", ":xplane_builder", @@ -308,7 +309,7 @@ cc_library( srcs = ["xplane_test_utils.cc"], hdrs = ["xplane_test_utils.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = [ ":xplane_builder", ":xplane_schema", @@ -324,7 +325,6 @@ cc_library( tsl_cc_test( name = "group_events_test", srcs = ["group_events_test.cc"], - visibility = ["//visibility:public"], deps = [ ":group_events", ":tf_xplane_visitor", @@ -346,7 +346,6 @@ cc_library( name = "tpu_xplane_utils", srcs = ["tpu_xplane_utils.cc"], hdrs = ["tpu_xplane_utils.h"], - visibility = ["//visibility:public"], deps = [ ":xplane_schema", ":xplane_utils", @@ -360,7 +359,6 @@ cc_library( tsl_cc_test( name = "tpu_xplane_utils_test", srcs = ["tpu_xplane_utils_test.cc"], - visibility = ["//visibility:public"], deps = [ ":tpu_xplane_utils", ":xplane_schema", @@ -376,7 +374,10 @@ cc_library( name = "file_system_utils", hdrs = ["file_system_utils.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "@local_xla//xla/python:__pkg__", + "//tsl/profiler:internal", + ]), deps = [ "//tsl/platform", "@com_google_absl//absl/strings", @@ -388,7 +389,10 @@ cc_library( srcs = ["buffer_pool.cc"], hdrs = ["buffer_pool.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "@local_xla//xla/backends/profiler/gpu:__pkg__", + "//tsl/profiler:internal", + ]), deps = [ "//tsl/platform:logging", "//tsl/platform:mutex", @@ -400,7 +404,6 @@ cc_library( tsl_cc_test( name = "buffer_pool_test", srcs = ["buffer_pool_test.cc"], - visibility = ["//visibility:public"], deps = [ ":buffer_pool", "//tsl/platform:test", @@ -413,7 +416,7 @@ cc_library( srcs = ["preprocess_xplane.cc"], hdrs = ["preprocess_xplane.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = [ ":tpu_xplane_utils", ":trace_utils", @@ -432,7 +435,6 @@ cc_library( tsl_cc_test( name = "preprocess_xplane_test", srcs = ["preprocess_xplane_test.cc"], - visibility = ["//visibility:public"], deps = [ ":preprocess_xplane", ":tf_xplane_visitor", @@ -452,7 +454,6 @@ cc_library( name = "session_manager", srcs = ["session_manager.cc"], hdrs = ["session_manager.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:errors", "//tsl/platform:status", @@ -461,3 +462,29 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", ], ) + +cc_library( + name = "timestamp_utils", + srcs = ["timestamp_utils.cc"], + hdrs = ["timestamp_utils.h"], + deps = [ + ":xplane_builder", + ":xplane_schema", + ":xplane_utils", + "//tsl/profiler/protobuf:xplane_proto_cc", + "@com_google_absl//absl/log", + ], +) + +tsl_cc_test( + name = "timestamp_utils_test", + srcs = ["timestamp_utils_test.cc"], + deps = [ + ":timestamp_utils", + ":xplane_schema", + ":xplane_utils", + ":xplane_visitor", + "//tsl/platform:test", + "//tsl/platform:test_main", + ], +) diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/timestamp_utils.cc b/third_party/xla/third_party/tsl/tsl/profiler/utils/timestamp_utils.cc new file mode 100644 index 00000000000000..ea208ed309c468 --- /dev/null +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/timestamp_utils.cc @@ -0,0 +1,49 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tsl/profiler/utils/timestamp_utils.h" + +#include + +#include "absl/log/log.h" +#include "tsl/profiler/protobuf/xplane.pb.h" +#include "tsl/profiler/utils/xplane_builder.h" +#include "tsl/profiler/utils/xplane_schema.h" +#include "tsl/profiler/utils/xplane_utils.h" + +namespace tsl { +namespace profiler { + +void SetSessionTimestamps(uint64_t start_walltime_ns, uint64_t stop_walltime_ns, + tensorflow::profiler::XSpace& space) { + if (start_walltime_ns != 0 && stop_walltime_ns != 0) { + tsl::profiler::XPlaneBuilder plane( + tsl::profiler::FindOrAddMutablePlaneWithName( + &space, tsl::profiler::kTaskEnvPlaneName)); + plane.AddStatValue(*plane.GetOrCreateStatMetadata( + GetTaskEnvStatTypeStr(kEnvProfileStartTime)), + start_walltime_ns); + plane.AddStatValue(*plane.GetOrCreateStatMetadata( + GetTaskEnvStatTypeStr(kEnvProfileStopTime)), + stop_walltime_ns); + } else { + LOG(WARNING) << "Not Setting Session Timestamps, (start_walltime_ns, " + "stop_walltime_ns) : " + << start_walltime_ns << ", " << stop_walltime_ns; + } +} + +} // namespace profiler +} // namespace tsl diff --git a/tensorflow/core/profiler/backends/gpu/cupti_wrapper.h b/third_party/xla/third_party/tsl/tsl/profiler/utils/timestamp_utils.h similarity index 51% rename from tensorflow/core/profiler/backends/gpu/cupti_wrapper.h rename to third_party/xla/third_party/tsl/tsl/profiler/utils/timestamp_utils.h index 36e840e6d5bdb2..87013c97a6f5b0 100644 --- a/tensorflow/core/profiler/backends/gpu/cupti_wrapper.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/timestamp_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,17 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_BACKENDS_GPU_CUPTI_WRAPPER_H_ -#define TENSORFLOW_CORE_PROFILER_BACKENDS_GPU_CUPTI_WRAPPER_H_ +#ifndef TENSORFLOW_TSL_PROFILER_UTILS_TIMESTAMP_UTILS_H_ +#define TENSORFLOW_TSL_PROFILER_UTILS_TIMESTAMP_UTILS_H_ -#include "xla/backends/profiler/gpu/cupti_wrapper.h" +#include -namespace tensorflow { -namespace profiler { +#include "tsl/profiler/protobuf/xplane.pb.h" -using xla::profiler::CuptiWrapper; // NOLINT +namespace tsl { +namespace profiler { +// Add metadata regarding profile start_time and stop_time to xspace. +// This function won't have an effect if either of the timestamps is zero. +void SetSessionTimestamps(uint64_t start_walltime_ns, uint64_t stop_walltime_ns, + tensorflow::profiler::XSpace& space); } // namespace profiler -} // namespace tensorflow +} // namespace tsl -#endif // TENSORFLOW_CORE_PROFILER_BACKENDS_GPU_CUPTI_WRAPPER_H_ +#endif // TENSORFLOW_TSL_PROFILER_UTILS_TIMESTAMP_UTILS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/timestamp_utils_test.cc b/third_party/xla/third_party/tsl/tsl/profiler/utils/timestamp_utils_test.cc new file mode 100644 index 00000000000000..893e31ebb5ec59 --- /dev/null +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/timestamp_utils_test.cc @@ -0,0 +1,45 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tsl/profiler/utils/timestamp_utils.h" + +#include "tsl/platform/test.h" +#include "tsl/profiler/utils/xplane_schema.h" +#include "tsl/profiler/utils/xplane_utils.h" +#include "tsl/profiler/utils/xplane_visitor.h" + +namespace tsl { +namespace profiler { +using ::testing::Eq; + +TEST(TimestampUtilsTest, StartAndStopTimestampAreAdded) { + XSpace xspace; + + SetSessionTimestamps(1000, 2000, xspace); + + const XPlane* xplane = FindPlaneWithName(xspace, kTaskEnvPlaneName); + + XPlaneVisitor visitor(xplane, {}, {FindTaskEnvStatType}); + + auto start_time = visitor.GetStat(TaskEnvStatType::kEnvProfileStartTime); + auto stop_time = visitor.GetStat(TaskEnvStatType::kEnvProfileStopTime); + + EXPECT_THAT(start_time->IntOrUintValue(), Eq(1000)); + EXPECT_THAT(stop_time->IntOrUintValue(), Eq(2000)); +} + +} // namespace profiler + +} // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc index 0f330f7653ca31..5794a3f490ba19 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc @@ -59,6 +59,8 @@ const absl::string_view kCounterEventsLineName = "_counters_"; const absl::string_view kDeviceVendorNvidia = "Nvidia"; const absl::string_view kDeviceVendorAMD = "AMD"; +const absl::string_view kTaskEnvPlaneName = "Task Environment"; + namespace { constexpr int kNumHostEventTypes = @@ -330,6 +332,7 @@ const StatTypeMap& GetStatTypeMap() { {"dcn_destination_per_slice_device_id", kDcnDestinationPerSliceDeviceId}, {"dcn_chunk", kDcnChunk}, {"dcn_loop_index", kDcnLoopIndex}, + {"dropped_traces", kDroppedTraces}, }); DCHECK_EQ(stat_type_map->size(), kNumStatTypes); return *stat_type_map; @@ -395,6 +398,29 @@ const LineIdTypeStrMap& GetLineIdTypeStrMap() { return *line_id_type_str_map; } +using TaskEnvStatTypeMap = + absl::flat_hash_map; +using TaskEnvStatTypeStrMap = + absl::flat_hash_map; + +constexpr int kNumTaskEnvStatTypes = TaskEnvStatType::kLastTaskEnvStatType - + TaskEnvStatType::kFirstTaskEnvStatType + 1; + +const TaskEnvStatTypeMap& GetTaskEnvStatTypeMap() { + static auto* task_env_stat_type_map = new TaskEnvStatTypeMap({ + {"profile_start_time", kEnvProfileStartTime}, + {"profile_stop_time", kEnvProfileStopTime}, + }); + DCHECK_EQ(task_env_stat_type_map->size(), kNumTaskEnvStatTypes); + return *task_env_stat_type_map; +} + +const TaskEnvStatTypeStrMap& GetTaskEnvStatTypeStrMap() { + static auto* task_env_stat_type_str_map = new TaskEnvStatTypeStrMap( + gtl::ReverseMap(GetTaskEnvStatTypeMap())); + return *task_env_stat_type_str_map; +} + } // namespace absl::string_view GetHostEventTypeStr(HostEventType event_type) { @@ -443,6 +469,17 @@ std::optional FindMegaScaleStatType(absl::string_view stat_name) { return std::nullopt; } +absl::string_view GetTaskEnvStatTypeStr(TaskEnvStatType stat_type) { + return GetTaskEnvStatTypeStrMap().at(stat_type); +} + +std::optional FindTaskEnvStatType(absl::string_view stat_name) { + if (auto stat_type = gtl::FindOrNull(GetTaskEnvStatTypeMap(), stat_name)) { + return *stat_type; + } + return std::nullopt; +} + absl::string_view GetLineIdTypeStr(LineIdType line_id_type) { return GetLineIdTypeStrMap().at(line_id_type); } diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h index 6b8fa485ccfa1f..255099aef57aa5 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h @@ -77,6 +77,9 @@ TF_CONST_INIT extern const absl::string_view kCounterEventsLineName; TF_CONST_INIT extern const absl::string_view kDeviceVendorNvidia; TF_CONST_INIT extern const absl::string_view kDeviceVendorAMD; +// Name of Xplane that contains environment information +TF_CONST_INIT extern const absl::string_view kTaskEnvPlaneName; + // Max collectives to display per TPU. // Since in most cases there will be more than 9 collectives, the last line // contains all collectives that did not qualify to get their own line. @@ -314,7 +317,8 @@ enum StatType { kEdgeTpuModelInfo, kEdgeTpuModelProfileInfo, kEdgeTpuMlir, - kLastStatType = kEdgeTpuMlir, + kDroppedTraces, + kLastStatType = kDroppedTraces, }; enum MegaScaleStatType : uint8_t { @@ -341,6 +345,13 @@ enum MegaScaleStatType : uint8_t { kLastMegaScaleStatType = kMegaScaleGraphProtos, }; +enum TaskEnvStatType { + kFirstTaskEnvStatType = 1, + kEnvProfileStartTime = kFirstTaskEnvStatType, + kEnvProfileStopTime, + kLastTaskEnvStatType = kEnvProfileStopTime, +}; + static constexpr uint32_t kLineIdOffset = 10000; enum LineIdType { @@ -401,6 +412,10 @@ bool IsInternalEvent(std::optional event_type); // Returns true if the given stat shouldn't be shown in the trace viewer. bool IsInternalStat(std::optional stat_type); +absl::string_view GetTaskEnvStatTypeStr(TaskEnvStatType stat_type); + +std::optional FindTaskEnvStatType(absl::string_view stat_name); + // Support for flow events: // This class enables encoding/decoding the flow id and direction, stored as // XStat value. The flow id are limited to 56 bits. diff --git a/third_party/xla/third_party/tsl/tsl/protobuf/BUILD b/third_party/xla/third_party/tsl/tsl/protobuf/BUILD index a8c7f93c8999b4..db185588785a46 100644 --- a/third_party/xla/third_party/tsl/tsl/protobuf/BUILD +++ b/third_party/xla/third_party/tsl/tsl/protobuf/BUILD @@ -2,7 +2,7 @@ load( "//tsl:tsl.bzl", "if_google", - "set_external_visibility", + "internal_visibility", ) load( "//tsl/platform:build_config.bzl", @@ -10,7 +10,12 @@ load( ) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//tensorflow/core:__subpackages__", + "//tsl:internal", + "//tensorflow_models:__subpackages__", + ]), features = if_google(["-parse_headers"]), licenses = ["notice"], ) @@ -98,7 +103,10 @@ tf_proto_library( name = "test_log_proto", srcs = ["test_log.proto"], make_default_target_header_only = True, - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core:__subpackages__", + "//tsl/util:__pkg__", + ]), ) tf_proto_library( diff --git a/third_party/xla/third_party/tsl/tsl/protobuf/dnn.proto b/third_party/xla/third_party/tsl/tsl/protobuf/dnn.proto index b349115292e43a..cc16b2141e0e7a 100644 --- a/third_party/xla/third_party/tsl/tsl/protobuf/dnn.proto +++ b/third_party/xla/third_party/tsl/tsl/protobuf/dnn.proto @@ -179,6 +179,13 @@ message ConvolutionDescriptorProto { string name = 7; } +// NormKind kind +enum NormKind { + LAYER_FWD_INFER = 0; + LAYER_FWD_TRAIN = 1; + LAYER_BWD = 2; +} + // FusedMHAKind kind enum FusedMHAKind { BMM1_OUTPUT_UNKNOWN = 0; diff --git a/third_party/xla/third_party/tsl/tsl/python/lib/core/BUILD b/third_party/xla/third_party/tsl/tsl/python/lib/core/BUILD index df1482766a08d5..421930c4b0161f 100644 --- a/third_party/xla/third_party/tsl/tsl/python/lib/core/BUILD +++ b/third_party/xla/third_party/tsl/tsl/python/lib/core/BUILD @@ -2,20 +2,21 @@ # Implementation of custom numpy floats. package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//visibility:public", + ], licenses = ["notice"], ) filegroup( name = "numpy_hdr", srcs = ["numpy.h"], - visibility = ["//visibility:public"], ) filegroup( name = "basic_hdrs", srcs = ["numpy.h"], - visibility = ["//visibility:public"], ) cc_library( @@ -31,7 +32,6 @@ cc_library( "-use_header_modules", # Required for pybind11. "-parse_headers", ], - visibility = ["//visibility:public"], deps = [ ":numpy", "//third_party/py/numpy:headers", @@ -47,7 +47,6 @@ cc_library( name = "numpy", srcs = ["numpy.cc"], hdrs = ["numpy.h"], - visibility = ["//visibility:public"], deps = [ "//third_party/py/numpy:headers", "//third_party/python_runtime:headers", diff --git a/third_party/xla/third_party/tsl/tsl/tsl.bzl b/third_party/xla/third_party/tsl/tsl/tsl.bzl index cc668741f78537..5e81264ce4c7ca 100644 --- a/third_party/xla/third_party/tsl/tsl/tsl.bzl +++ b/third_party/xla/third_party/tsl/tsl/tsl.bzl @@ -101,6 +101,14 @@ def if_google(google_value, oss_value = []): """ return oss_value # copybara:comment_replace return google_value +def internal_visibility(internal_targets): + """Returns internal_targets in g3, but returns public in OSS. + + Useful for targets that are part of the XLA/TSL API surface but want finer-grained visibilites + internally. + """ + return if_google(internal_targets, ["//visibility:public"]) + # TODO(jakeharmon): Use this to replace if_static def if_tsl_link_protobuf(if_true, if_false = []): return select({ @@ -432,6 +440,9 @@ check_deps = rule( def get_compatible_with_portable(): return [] +def get_compatible_with_libtpu_portable(): + return [] + def filegroup(**kwargs): native.filegroup(**kwargs) @@ -466,7 +477,7 @@ _transitive_hdrs = rule( def transitive_hdrs(name, deps = [], **kwargs): _transitive_hdrs(name = name + "_gather", deps = deps) - native.filegroup(name = name, srcs = [":" + name + "_gather"]) + native.filegroup(name = name, srcs = [":" + name + "_gather"], **kwargs) # Create a header only library that includes all the headers exported by # the libraries in deps. @@ -761,7 +772,5 @@ def tsl_pybind_extension_opensource( compatible_with = compatible_with, ) -# Used for specifying external visibility constraints. In non-monorepo situations, this needs to be -# public, but monorepos can have more precise constraints. -def set_external_visibility(monorepo_paths): - return if_oss(["//visibility:public"], monorepo_paths) +def nvtx_headers(): + return if_oss(["@nvtx_archive//:headers"], ["@local_config_cuda//cuda:cuda_headers"]) diff --git a/third_party/xla/third_party/tsl/tsl/tsl.default.bzl b/third_party/xla/third_party/tsl/tsl/tsl.default.bzl index 1759e5106320d5..912939245725ab 100644 --- a/third_party/xla/third_party/tsl/tsl/tsl.default.bzl +++ b/third_party/xla/third_party/tsl/tsl/tsl.default.bzl @@ -3,6 +3,7 @@ load( "//tsl:tsl.bzl", _filegroup = "filegroup", + _get_compatible_with_libtpu_portable = "get_compatible_with_libtpu_portable", _get_compatible_with_portable = "get_compatible_with_portable", _if_not_mobile_or_arm_or_lgpl_restricted = "if_not_mobile_or_arm_or_lgpl_restricted", _internal_hlo_deps = "internal_hlo_deps", @@ -11,6 +12,7 @@ load( ) get_compatible_with_portable = _get_compatible_with_portable +get_compatible_with_libtpu_portable = _get_compatible_with_libtpu_portable filegroup = _filegroup if_not_mobile_or_arm_or_lgpl_restricted = _if_not_mobile_or_arm_or_lgpl_restricted internal_hlo_deps = _internal_hlo_deps diff --git a/third_party/xla/third_party/tsl/tsl/util/BUILD b/third_party/xla/third_party/tsl/tsl/util/BUILD index 827826f13d3651..cb70feac45c080 100644 --- a/third_party/xla/third_party/tsl/tsl/util/BUILD +++ b/third_party/xla/third_party/tsl/tsl/util/BUILD @@ -4,14 +4,10 @@ # The libraries in this package are not allowed to have ANY dependencies # to other TF components outside of TSL. -load( - "@local_tsl//tsl/platform:rules_cc.bzl", - "cc_library", -) load( "//tsl:tsl.bzl", "check_deps", - "set_external_visibility", + "internal_visibility", "tsl_copts", ) load("//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") @@ -23,9 +19,16 @@ load( "//tsl/platform:build_config.bzl", "tsl_cc_test", ) +load( + "@local_tsl//tsl/platform:rules_cc.bzl", + "cc_library", +) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//visibility:public", + ], licenses = ["notice"], ) @@ -35,7 +38,6 @@ filegroup( "byte_swap_array.cc", "byte_swap_array.h", ], - visibility = ["//visibility:public"], ) filegroup( @@ -52,7 +54,6 @@ filegroup( "use_cudnn.cc", "use_cudnn.h", ], - visibility = ["//visibility:public"], ) filegroup( @@ -61,7 +62,10 @@ filegroup( "determinism.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow:__subpackages__", + "//tensorflow/core/util:__pkg__", + ]), ) filegroup( @@ -75,7 +79,6 @@ filegroup( "stats_calculator.h", "use_cudnn.h", ], - visibility = ["//visibility:public"], ) filegroup( @@ -83,7 +86,6 @@ filegroup( srcs = [ "use_cudnn.cc", ], - visibility = ["//visibility:public"], ) filegroup( @@ -93,7 +95,10 @@ filegroup( "env_var.h", "use_cudnn.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + "//tensorflow/core/util:__pkg__", + ]), ) filegroup( @@ -102,7 +107,10 @@ filegroup( "determinism.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow:__subpackages__", + "//tensorflow/core/util:__pkg__", + ]), ) filegroup( @@ -112,14 +120,12 @@ filegroup( "stat_summarizer_options.h", "use_cudnn.h", ], - visibility = ["//visibility:public"], ) cc_library( name = "byte_swap_array", srcs = ["byte_swap_array.cc"], hdrs = ["byte_swap_array.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:byte_order", "//tsl/platform:errors", @@ -131,6 +137,7 @@ cc_library( name = "determinism_hdr_lib", hdrs = [":determinism_hdr"], compatible_with = get_compatible_with_portable(), + # TODO(b/298501506): narrow this in a way that won't break TAP visibility = ["//visibility:public"], ) @@ -141,7 +148,7 @@ cc_library( srcs = ["determinism.cc"], hdrs = ["determinism.h"], copts = tsl_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow:__subpackages__"]), deps = [ ":env_var", "//tsl/platform:mutex", @@ -164,7 +171,7 @@ cc_library( alias( name = "determinism_for_kernels", actual = if_static(":determinism", ":determinism_hdr_lib"), - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow:__subpackages__"]), ) check_deps( @@ -185,7 +192,6 @@ cc_library( # whenever determinism tests are run. ":determinism_check_deps", ], - visibility = ["//visibility:public"], deps = [":determinism"], ) @@ -193,7 +199,6 @@ cc_library( name = "env_var", srcs = ["env_var.cc"], hdrs = ["env_var.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:errors", "//tsl/platform:logging", @@ -210,7 +215,10 @@ cc_library( name = "reporter", srcs = ["reporter.cc"], hdrs = ["reporter.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core:__subpackages__", + "//tsl:__subpackages__", + ]), deps = [ "//tsl/platform:env", "//tsl/platform:env_impl", @@ -233,13 +241,14 @@ cc_library( "stats_calculator.h", ], copts = tsl_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tsl:internal", + ]), ) tsl_cc_test( name = "stats_calculator_test", srcs = ["stats_calculator_test.cc"], - visibility = ["//visibility:public"], deps = [ ":stats_calculator_portable", "//tsl/platform:test", @@ -251,7 +260,6 @@ cc_library( name = "device_name_utils", srcs = ["device_name_utils.cc"], hdrs = ["device_name_utils.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:errors", "//tsl/platform:status", @@ -263,7 +271,6 @@ tsl_cc_test( name = "device_name_utils_test", size = "small", srcs = ["device_name_utils_test.cc"], - visibility = ["//visibility:public"], deps = [ ":device_name_utils", "//tsl/lib/core:status_test_util", @@ -279,7 +286,6 @@ cc_library( name = "command_line_flags", srcs = ["command_line_flags.cc"], hdrs = ["command_line_flags.h"], - visibility = ["//visibility:public"], deps = [ "//tsl/platform:logging", "//tsl/platform:str_util", @@ -296,7 +302,7 @@ filegroup( srcs = [ "reporter.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow/core/util:__pkg__"]), ) filegroup( @@ -304,7 +310,12 @@ filegroup( srcs = [ "onednn_threadpool.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "@local_xla//xla:__subpackages__", + "//tensorflow/core:__pkg__", + "//tensorflow/core/framework:__pkg__", + "//tensorflow/core/util:__pkg__", + ]), ) filegroup( @@ -313,7 +324,10 @@ filegroup( srcs = [ "reporter.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + "//tensorflow/core/util:__pkg__", + ]), ) filegroup( @@ -323,5 +337,8 @@ filegroup( "reporter.cc", ":android_test_hdrs", ], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + "//tensorflow/core/util:__pkg__", + ]), ) diff --git a/third_party/xla/third_party/tsl/tsl/util/onednn_threadpool.h b/third_party/xla/third_party/tsl/tsl/util/onednn_threadpool.h index 82fbec738f00ee..7d8a093ae89fa6 100644 --- a/third_party/xla/third_party/tsl/tsl/util/onednn_threadpool.h +++ b/third_party/xla/third_party/tsl/tsl/util/onednn_threadpool.h @@ -151,6 +151,16 @@ class OneDnnThreadPool : public threadpool_iface { ~OneDnnThreadPool() {} + static void set_onednn_max_threads(int num_threads) { +#if DNNL_VERSION_MAJOR >= 3 || \ + (DNNL_VERSION_MAJOR == 2 && DNNL_VERSION_MINOR >= 7) +#ifndef DNNL_AARCH64_USE_ACL + dnnl_threadpool_interop_set_max_concurrency(num_threads); +#endif // DNNL_AARCH64_USE_ACL +#endif // DNNL_VERSION_MAJOR >= 3 || + // (DNNL_VERSION_MAJOR == 2 && DNNL_VERSION_MINOR >= 7) + } + private: Eigen::ThreadPoolInterface* eigen_interface_ = nullptr; int num_threads_ = 1; // Execute in caller thread. @@ -159,13 +169,7 @@ class OneDnnThreadPool : public threadpool_iface { inline void set_num_and_max_threads(int num_threads) { num_threads_ = num_threads == -1 ? eigen_interface_->NumThreads() : num_threads; -#if DNNL_VERSION_MAJOR >= 3 || \ - (DNNL_VERSION_MAJOR == 2 && DNNL_VERSION_MINOR >= 7) -#ifndef DNNL_AARCH64_USE_ACL - dnnl_threadpool_interop_set_max_concurrency(num_threads_); -#endif // DNNL_AARCH64_USE_ACL -#endif // DNNL_VERSION_MAJOR >= 3 || - // (DNNL_VERSION_MAJOR == 2 && DNNL_VERSION_MINOR >= 7) + set_onednn_max_threads(num_threads_); } }; @@ -178,6 +182,7 @@ class OneDnnThreadPool { OneDnnThreadPool(Eigen::ThreadPoolInterface* eigen_interface) {} OneDnnThreadPool(Eigen::ThreadPoolInterface* eigen_interface, bool can_use_caller_thread, int num_threads = -1) {} + static void set_onednn_max_threads(int num_threads) {} }; #endif // !ENABLE_ONEDNN_OPENMP diff --git a/third_party/xla/third_party/tsl/tsl/util/proto/BUILD b/third_party/xla/third_party/tsl/tsl/util/proto/BUILD index 07f4f10718ca03..2752d1f13e07d0 100644 --- a/third_party/xla/third_party/tsl/tsl/util/proto/BUILD +++ b/third_party/xla/third_party/tsl/tsl/util/proto/BUILD @@ -4,14 +4,16 @@ load( ) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//visibility:public", + ], licenses = ["notice"], ) cc_library( name = "proto_utils", hdrs = ["proto_utils.h"], - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/time", "@com_google_protobuf//:protobuf_headers", diff --git a/third_party/xla/third_party/tsl/workspace2.bzl b/third_party/xla/third_party/tsl/workspace2.bzl index 5bfb6fd033db27..e23dcc3a4c7ad6 100644 --- a/third_party/xla/third_party/tsl/workspace2.bzl +++ b/third_party/xla/third_party/tsl/workspace2.bzl @@ -22,7 +22,6 @@ load("//third_party/gpus:cuda_configure.bzl", "cuda_configure") load("//third_party/gpus:rocm_configure.bzl", "rocm_configure") load("//third_party/hwloc:workspace.bzl", hwloc = "repo") load("//third_party/implib_so:workspace.bzl", implib_so = "repo") -load("//third_party/jpeg:workspace.bzl", jpeg = "repo") load("//third_party/llvm:setup.bzl", "llvm_setup") load("//third_party/nasm:workspace.bzl", nasm = "repo") load("//third_party/nccl:nccl_configure.bzl", "nccl_configure") @@ -51,7 +50,6 @@ def _initialize_third_party(): gemmlowp() hwloc() implib_so() - jpeg() ml_dtypes() nasm() pybind11_abseil() @@ -250,26 +248,6 @@ def _tf_repositories(): urls = tf_mirror_urls("https://github.com/GoogleCloudPlatform/tensorflow-gcp-tools/archive/2643d8caeba6ca2a6a0b46bb123953cb95b7e7d5.tar.gz"), ) - tf_http_archive( - name = "png", - build_file = "//third_party:png.BUILD", - patch_file = ["//third_party:png_fix_rpi.patch"], - sha256 = "a00e9d2f2f664186e4202db9299397f851aea71b36a35e74910b8820e380d441", - strip_prefix = "libpng-1.6.39", - system_build_file = "//third_party/systemlibs:png.BUILD", - urls = tf_mirror_urls("https://github.com/glennrp/libpng/archive/v1.6.39.tar.gz"), - ) - - tf_http_archive( - name = "gif", - build_file = "//third_party:gif.BUILD", - patch_file = ["//third_party:gif_fix_strtok_r.patch"], - sha256 = "31da5562f44c5f15d63340a09a4fd62b48c45620cd302f77a6d9acf0077879bd", - strip_prefix = "giflib-5.2.1", - system_build_file = "//third_party/systemlibs:gif.BUILD", - urls = tf_mirror_urls("https://pilotfiber.dl.sourceforge.net/project/giflib/giflib-5.2.1.tar.gz"), - ) - tf_http_archive( name = "six_archive", build_file = "//third_party:six.BUILD", @@ -403,7 +381,6 @@ def _tf_repositories(): urls = tf_mirror_urls("https://github.com/open-source-parsers/jsoncpp/archive/1.9.5.tar.gz"), ) - # Note: if you update this, you have to update libpng too. See cl/437813808 tf_http_archive( name = "zlib", build_file = "//third_party:zlib.BUILD", diff --git a/third_party/xla/tools/ci_build/gpu_build/BUILD b/third_party/xla/tools/ci_build/gpu_build/BUILD index 29177db0620a8b..743f8f19dba0e8 100644 --- a/third_party/xla/tools/ci_build/gpu_build/BUILD +++ b/third_party/xla/tools/ci_build/gpu_build/BUILD @@ -3,7 +3,6 @@ # learning applications. package( - default_visibility = ["//visibility:public"], # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) @@ -22,5 +21,5 @@ filegroup( "**/OWNERS", ], ), - visibility = ["//visibility:public"], + visibility = ["//visibility:private"], ) diff --git a/third_party/xla/tools/toolchains/cpus/py/BUILD b/third_party/xla/tools/toolchains/cpus/py/BUILD index 54feb1695a21e3..1235988abb7fa9 100644 --- a/third_party/xla/tools/toolchains/cpus/py/BUILD +++ b/third_party/xla/tools/toolchains/cpus/py/BUILD @@ -22,7 +22,6 @@ cc_library( name = "python_headers", hdrs = [":python_include"], includes = ["python_include"], - visibility = ["//visibility:public"], deps = select({ ":windows": [":python_lib"], "//conditions:default": [], @@ -33,7 +32,6 @@ cc_library( name = "numpy_headers", hdrs = [":numpy_include"], includes = ["numpy_include"], - visibility = ["//visibility:public"], ) config_setting( diff --git a/third_party/xla/tools/toolchains/cpus/py3/BUILD b/third_party/xla/tools/toolchains/cpus/py3/BUILD index 5dc47b98284c89..d47256ebef88fa 100644 --- a/third_party/xla/tools/toolchains/cpus/py3/BUILD +++ b/third_party/xla/tools/toolchains/cpus/py3/BUILD @@ -22,7 +22,6 @@ cc_library( name = "python_headers", hdrs = [":python_include"], includes = ["python_include"], - visibility = ["//visibility:public"], deps = select({ ":windows": [":python_lib"], "//conditions:default": [], @@ -33,7 +32,6 @@ cc_library( name = "numpy_headers", hdrs = [":numpy_include"], includes = ["numpy_include"], - visibility = ["//visibility:public"], ) config_setting( diff --git a/third_party/xla/tools/toolchains/cross_compile/cc/BUILD b/third_party/xla/tools/toolchains/cross_compile/cc/BUILD index 976e57b777f15d..7cf6d8c3747b27 100644 --- a/third_party/xla/tools/toolchains/cross_compile/cc/BUILD +++ b/third_party/xla/tools/toolchains/cross_compile/cc/BUILD @@ -15,10 +15,7 @@ cc_toolchain_suite( }, ) -filegroup( - name = "empty", - visibility = ["//visibility:public"], -) +filegroup(name = "empty") # We define a wraper ("cc_wrapper.sh") around the compiler to replace all paths # in the binary (bazel-out/.../path/to/original/library.so) by the paths @@ -27,7 +24,6 @@ filegroup( filegroup( name = "cc_wrapper_and_macos_sysroot", srcs = ["cc_wrapper.sh"] + glob(["MacOSX.sdk/**"]), - visibility = ["//visibility:public"], ) cc_toolchain( diff --git a/third_party/xla/tools/toolchains/remote_config/configs.bzl b/third_party/xla/tools/toolchains/remote_config/configs.bzl index aab6ac89e37c67..7bbfb8b2854ca4 100644 --- a/third_party/xla/tools/toolchains/remote_config/configs.bzl +++ b/third_party/xla/tools/toolchains/remote_config/configs.bzl @@ -627,7 +627,7 @@ def initialize_rbe_configs(): sigbuild_tf_configs( name_container_map = { - "sigbuild-r2.16": "docker://gcr.io/tensorflow-sigs/build@sha256:22d863e6fe3f98946015b9e1264b2eeb8e56e504535a6c1d5e564cae65ae5d37", + "sigbuild-r2.16": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", "sigbuild-r2.16-python3.9": "docker://gcr.io/tensorflow-sigs/build@sha256:22d863e6fe3f98946015b9e1264b2eeb8e56e504535a6c1d5e564cae65ae5d37", "sigbuild-r2.16-python3.10": "docker://gcr.io/tensorflow-sigs/build@sha256:da15288c8464153eadd35da720540a544b76aa9d78cceb42a6821b2f3e70a0fa", "sigbuild-r2.16-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", @@ -667,7 +667,7 @@ def initialize_rbe_configs(): sigbuild_tf_configs( name_container_map = { - "sigbuild-r2.16-clang": "docker://gcr.io/tensorflow-sigs/build@sha256:22d863e6fe3f98946015b9e1264b2eeb8e56e504535a6c1d5e564cae65ae5d37", + "sigbuild-r2.16-clang": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", "sigbuild-r2.16-clang-python3.9": "docker://gcr.io/tensorflow-sigs/build@sha256:22d863e6fe3f98946015b9e1264b2eeb8e56e504535a6c1d5e564cae65ae5d37", "sigbuild-r2.16-clang-python3.10": "docker://gcr.io/tensorflow-sigs/build@sha256:da15288c8464153eadd35da720540a544b76aa9d78cceb42a6821b2f3e70a0fa", "sigbuild-r2.16-clang-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", diff --git a/third_party/xla/tools/toolchains/win/bazel_211/BUILD b/third_party/xla/tools/toolchains/win/bazel_211/BUILD index 07aff97390d02e..cc23c8ecb22680 100644 --- a/third_party/xla/tools/toolchains/win/bazel_211/BUILD +++ b/third_party/xla/tools/toolchains/win/bazel_211/BUILD @@ -22,31 +22,26 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "malloc", - visibility = ["//visibility:public"], ) filegroup( name = "empty", srcs = [], - visibility = ["//visibility:public"], ) filegroup( name = "mingw_compiler_files", srcs = [":builtin_include_directory_paths_mingw"], - visibility = ["//visibility:public"], ) filegroup( name = "clangcl_compiler_files", srcs = [":builtin_include_directory_paths_clangcl"], - visibility = ["//visibility:public"], ) filegroup( name = "msvc_compiler_files", srcs = [":builtin_include_directory_paths_msvc"], - visibility = ["//visibility:public"], ) # Hardcoded toolchain, legacy behaviour. @@ -358,5 +353,4 @@ toolchain( filegroup( name = "link_dynamic_library", srcs = ["link_dynamic_library.sh"], - visibility = ["//visibility:public"], ) diff --git a/third_party/xla/tools/toolchains/win/tf_win_05022023/BUILD b/third_party/xla/tools/toolchains/win/tf_win_05022023/BUILD index 4d6c76ca644f07..f245f6d0789c9d 100644 --- a/third_party/xla/tools/toolchains/win/tf_win_05022023/BUILD +++ b/third_party/xla/tools/toolchains/win/tf_win_05022023/BUILD @@ -22,31 +22,26 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "malloc", - visibility = ["//visibility:public"], ) filegroup( name = "empty", srcs = [], - visibility = ["//visibility:public"], ) filegroup( name = "mingw_compiler_files", srcs = [":builtin_include_directory_paths_mingw"], - visibility = ["//visibility:public"], ) filegroup( name = "clangcl_compiler_files", srcs = [":builtin_include_directory_paths_clangcl"], - visibility = ["//visibility:public"], ) filegroup( name = "msvc_compiler_files", srcs = [":builtin_include_directory_paths_msvc"], - visibility = ["//visibility:public"], ) # Hardcoded toolchain, legacy behaviour. diff --git a/third_party/xla/tools/toolchains/win_1803/py38/BUILD b/third_party/xla/tools/toolchains/win_1803/py38/BUILD index 3efde36a05097c..9aa4d82e6daca6 100644 --- a/third_party/xla/tools/toolchains/win_1803/py38/BUILD +++ b/third_party/xla/tools/toolchains/win_1803/py38/BUILD @@ -39,7 +39,6 @@ cc_library( name = "python_headers", hdrs = [":python_include"], includes = ["python_include"], - visibility = ["//visibility:public"], deps = select({ ":windows": [":python_lib"], "//conditions:default": [], @@ -50,7 +49,6 @@ cc_library( name = "numpy_headers", hdrs = [":numpy_include"], includes = ["numpy_include"], - visibility = ["//visibility:public"], ) config_setting( diff --git a/third_party/xla/tools/toolchains/win_1803/py39/BUILD b/third_party/xla/tools/toolchains/win_1803/py39/BUILD index f3892df8ef4fae..f5b545cb161b3a 100644 --- a/third_party/xla/tools/toolchains/win_1803/py39/BUILD +++ b/third_party/xla/tools/toolchains/win_1803/py39/BUILD @@ -66,7 +66,6 @@ cc_library( name = "python_headers", hdrs = [":python_include"], includes = ["python_include"], - visibility = ["//visibility:public"], deps = select({ ":windows": [":python_lib"], "//conditions:default": [], @@ -77,7 +76,6 @@ cc_library( name = "numpy_headers", hdrs = [":numpy_include"], includes = ["numpy_include"], - visibility = ["//visibility:public"], ) config_setting( diff --git a/third_party/xla/workspace2.bzl b/third_party/xla/workspace2.bzl index 587d7874d8ac60..b0ebdf48ff8f08 100644 --- a/third_party/xla/workspace2.bzl +++ b/third_party/xla/workspace2.bzl @@ -1,10 +1,10 @@ """TensorFlow workspace initialization. Consult the WORKSPACE on how to use it.""" -# Import TSL Workspaces -load("@local_tsl//:workspace2.bzl", "tsl_workspace2") - # Import third party config rules. load("@bazel_skylib//lib:versions.bzl", "versions") + +# Import TSL Workspaces +load("@local_tsl//:workspace2.bzl", "tsl_workspace2") load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # Import third party repository rules. See go/tfbr-thirdparty. @@ -34,9 +34,9 @@ def _tf_repositories(): name = "cudnn_frontend_archive", build_file = "//third_party:cudnn_frontend.BUILD", patch_file = ["//third_party:cudnn_frontend_header_fix.patch"], - sha256 = "015ea933139a30e9ccd177b5e0dbfb16f3d08df78334aaacea57880275df734b", - strip_prefix = "cudnn-frontend-1.0.0", - urls = tf_mirror_urls("https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.0.0.zip"), + sha256 = "c2f5373ddf84e33d289dad5766667f52de652dfbbb1dccb2fada9cfcf2d774cf", + strip_prefix = "cudnn-frontend-1.1.0", + urls = tf_mirror_urls("https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.1.0.zip"), ) tf_http_archive( @@ -94,6 +94,13 @@ def _tf_repositories(): #url = "http://www.tcs.hut.fi/Software/bliss/bliss-0.73.zip", ) + tf_http_archive( + name = "pybind11_protobuf", + urls = tf_mirror_urls("https://github.com/pybind/pybind11_protobuf/archive/80f3440cd8fee124e077e2e47a8a17b78b451363.zip"), + sha256 = "c7ab64b1ccf9a678694a89035a8c865a693e4e872803778f91f0965c2f281d78", + strip_prefix = "pybind11_protobuf-80f3440cd8fee124e077e2e47a8a17b78b451363", + ) + # buildifier: disable=function-docstring # buildifier: disable=unnamed-macro def workspace(): diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index e2306f203aef36..5d113506e817b5 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -1,6 +1,7 @@ # Placeholder: load py_proto_library load("//xla:xla.bzl", "xla_cc_test", "xla_py_proto_library") load("//third_party/compute_library:build_defs.bzl", "if_enable_acl") +load("@local_tsl//tsl:tsl.bzl", "internal_visibility") load("@local_tsl//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") load( "@local_tsl//tsl/platform:build_config.bzl", @@ -9,7 +10,8 @@ load( load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility(["//xla:internal"]), licenses = ["notice"], ) @@ -47,12 +49,9 @@ package_group( ], ) -exports_files( - [ - "lit.cfg.py", - ], - visibility = ["//visibility:public"], -) +exports_files([ + "lit.cfg.py", +]) # Filegroup used to collect source files for dependency checking. filegroup( @@ -61,7 +60,6 @@ filegroup( "**/*.cc", "**/*.h", ]), - visibility = ["//visibility:public"], ) filegroup( @@ -70,7 +68,7 @@ filegroup( "cpu_function_runtime.cc", "executable_run_options.cc", ], - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), ) filegroup( @@ -80,7 +78,7 @@ filegroup( "executable_run_options.h", "types.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), ) tf_proto_library( @@ -108,7 +106,7 @@ tf_proto_library( cc_library( name = "bit_cast", hdrs = ["bit_cast.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = [ ":types", "@com_google_absl//absl/base", @@ -138,7 +136,7 @@ cc_library( "comparison_util.h", "primitive_util.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = [ ":shape_util", ":statusor", @@ -168,14 +166,14 @@ xla_cc_test( cc_library( name = "compiler_macros", hdrs = ["compiler_macros.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), ) cc_library( name = "ef57", srcs = ["ef57.cc"], hdrs = ["ef57.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = [ ":compiler_macros", "@com_google_absl//absl/types:span", @@ -205,7 +203,7 @@ cc_library( hdrs = [ "execution_options_util.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = [ ":debug_options_flags", ":xla_proto_cc", @@ -220,7 +218,7 @@ cc_library( hdrs = [ "frontend_attributes.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = ["//xla/hlo/ir:hlo"], ) @@ -228,7 +226,7 @@ cc_library( name = "test", testonly = 1, hdrs = ["test.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = [ "@local_tsl//tsl/platform", "@local_tsl//tsl/platform:test", @@ -239,11 +237,10 @@ cc_library( name = "types", hdrs = ["types.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = [ - "@com_google_absl//absl/strings:str_format", "@eigen_archive//:eigen3", - "@ml_dtypes//:int4", + "@local_tsl//tsl/platform:ml_dtypes", ], ) @@ -265,7 +262,7 @@ cc_library( name = "service_interface", srcs = [], hdrs = ["service_interface.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = [ ":status", ":xla_data_proto_cc", @@ -389,7 +386,6 @@ cc_library( name = "permutation_util", srcs = ["permutation_util.cc"], hdrs = ["permutation_util.h"], - visibility = ["//visibility:public"], deps = [ ":types", "@com_google_absl//absl/container:inlined_vector", @@ -685,14 +681,14 @@ cc_library( cc_library( name = "error_spec", hdrs = ["error_spec.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), ) cc_library( name = "literal_comparison", srcs = ["literal_comparison.cc"], hdrs = ["literal_comparison.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = [ ":error_spec", ":literal", @@ -741,7 +737,7 @@ cc_library( name = "array", srcs = ["array.cc"], hdrs = ["array.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = [ ":status", ":types", @@ -787,7 +783,7 @@ xla_cc_test( cc_library( name = "array3d", hdrs = ["array3d.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = [ ":array", ":types", @@ -809,7 +805,7 @@ xla_cc_test( cc_library( name = "array4d", hdrs = ["array4d.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = [ ":array", ":array2d", @@ -847,7 +843,7 @@ cc_library( name = "packed_literal_reader", srcs = ["packed_literal_reader.cc"], hdrs = ["packed_literal_reader.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = [ ":literal", ":shape_util", @@ -867,7 +863,7 @@ cc_library( name = "test_helpers", testonly = 1, hdrs = ["test_helpers.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = [ ":statusor", ":types", @@ -882,7 +878,7 @@ cc_library( name = "text_literal_reader", srcs = ["text_literal_reader.cc"], hdrs = ["text_literal_reader.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = [ ":literal", ":shape_util", @@ -919,7 +915,7 @@ cc_library( name = "text_literal_writer", srcs = ["text_literal_writer.cc"], hdrs = ["text_literal_writer.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = [ ":literal", ":shape_util", @@ -1068,7 +1064,6 @@ cc_library( name = "parse_flags_from_env", srcs = ["parse_flags_from_env.cc"], hdrs = ["parse_flags_from_env.h"], - visibility = ["//visibility:public"], deps = [ ":types", @@ -1104,11 +1099,12 @@ cc_library( ], hdrs = ["debug_options_flags.h"], copts = if_enable_acl(["-DXLA_CPU_USE_ACL=1"]), - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = [ ":parse_flags_from_env", ":xla_proto_cc", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:node_hash_map", @@ -1125,7 +1121,7 @@ cc_library( srcs = ["cpu_function_runtime.cc"], hdrs = ["cpu_function_runtime.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = [ "@com_google_absl//absl/base:dynamic_annotations", ], @@ -1150,7 +1146,6 @@ xla_cc_test( cc_library( name = "refcounting_hash_map", hdrs = ["refcounting_hash_map.h"], - visibility = ["//visibility:public"], deps = [ ":statusor", "@com_google_absl//absl/base:core_headers", @@ -1173,20 +1168,17 @@ xla_cc_test( cc_library( name = "union_find", hdrs = ["union_find.h"], - visibility = ["//visibility:public"], ) cc_library( name = "side_effect_util", srcs = ["side_effect_util.cc"], hdrs = ["side_effect_util.h"], - visibility = ["//visibility:public"], ) cc_library( name = "lazy", hdrs = ["lazy.h"], - visibility = ["//visibility:public"], deps = ["@com_google_absl//absl/functional:any_invocable"], ) @@ -1213,7 +1205,6 @@ tf_proto_library( srcs = ["autotuning.proto"], make_default_target_header_only = True, protodeps = ["@local_tsl//tsl/protobuf:dnn_proto"], - visibility = ["//visibility:public"], ) cc_library( @@ -1234,7 +1225,7 @@ cc_library( # py_proto_library( # name = "xla_data_proto_py_pb2", # api_version = 2, -# visibility = [":friends"], +# visibility = internal_visibility([":friends"]), # deps = [":xla_data_proto"], # ) # @@ -1243,7 +1234,7 @@ cc_library( # testonly = 0, # api_version = 2, # compatible_with = ["//buildenv/target:non_prod"], -# visibility = [":friends"], +# visibility = internal_visibility([":friends"]), # deps = [":xla_proto"], # ) # copybara:uncomment_end diff --git a/third_party/xla/xla/autotuning.proto b/third_party/xla/xla/autotuning.proto index b177fe3e2895ea..9a867ab7e21776 100644 --- a/third_party/xla/xla/autotuning.proto +++ b/third_party/xla/xla/autotuning.proto @@ -82,6 +82,7 @@ message AutotuneResult { int64 split_k = 4; int64 num_stages = 5; int64 num_warps = 6; + int64 num_ctas = 7; } int64 scratch_bytes = 8; diff --git a/third_party/xla/xla/backends/interpreter/BUILD b/third_party/xla/xla/backends/interpreter/BUILD index d7ccdee29ceddc..e065ca884b578c 100644 --- a/third_party/xla/xla/backends/interpreter/BUILD +++ b/third_party/xla/xla/backends/interpreter/BUILD @@ -5,6 +5,7 @@ load( load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], licenses = ["notice"], ) @@ -13,7 +14,6 @@ cc_library( name = "interpreter_transfer_manager", srcs = ["interpreter_transfer_manager.cc"], hdrs = ["interpreter_transfer_manager.h"], - visibility = ["//visibility:public"], deps = [ ":platform_id", "//xla/service:generic_transfer_manager", @@ -26,7 +26,6 @@ cc_library( name = "compiler", srcs = ["compiler.cc"], hdrs = ["compiler.h"], - visibility = ["//visibility:public"], deps = [ ":executable", ":platform_id", @@ -69,7 +68,6 @@ cc_library( name = "platform_id", srcs = ["platform_id.cc"], hdrs = ["platform_id.h"], - visibility = ["//visibility:public"], deps = ["//xla/stream_executor"] + if_static( ["@com_google_protobuf//:protobuf"], ["@com_google_protobuf//:protobuf_headers"], @@ -80,7 +78,6 @@ cc_library( name = "executable_base", srcs = ["executable_base.cc"], hdrs = ["executable_base.h"], - visibility = ["//visibility:public"], deps = [ "//xla:literal", "//xla:shape_tree", @@ -103,7 +100,6 @@ cc_library( name = "executable", srcs = ["executable.cc"], hdrs = ["executable.h"], - visibility = ["//visibility:public"], deps = [ ":executable_base", ":executor", @@ -133,7 +129,6 @@ cc_library( name = "platform", srcs = ["platform.cc"], hdrs = ["platform.h"], - visibility = ["//visibility:public"], deps = [ ":executor", ":platform_id", @@ -143,14 +138,13 @@ cc_library( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", ], - alwayslink = True, # Registers itself with the MultiPlatformManager. + alwayslink = True, # Registers itself with the PlatformManager. ) cc_library( name = "executor", srcs = ["executor.cc"], hdrs = ["executor.h"], - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla:status_macros", diff --git a/third_party/xla/xla/backends/interpreter/compiler.cc b/third_party/xla/xla/backends/interpreter/compiler.cc index bc38a84fa462c2..e48f63f77929ae 100644 --- a/third_party/xla/xla/backends/interpreter/compiler.cc +++ b/third_party/xla/xla/backends/interpreter/compiler.cc @@ -61,7 +61,7 @@ namespace { // Handles custom_call ops during evaluation by routing them through the global // CPU registry used by other CPU-based backends. -StatusOr HandleEvaluatorCustomCall( +absl::StatusOr HandleEvaluatorCustomCall( const HloInstruction* custom_call, absl::Span operands) { // Find the target C function in the global registry. auto* registry = CustomCallTargetRegistry::Global(); @@ -92,12 +92,13 @@ StatusOr HandleEvaluatorCustomCall( Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { HloPassPipeline pipeline("Interpreter"); + // The TopkDecomposer generates a compare op with type=TOTALORDER and must + // run before the ComparisonExpander which rewrites such comparisons. pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); - pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass( /*rewrite_training_op=*/true, @@ -109,7 +110,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { return pipeline.Run(hlo_module).status(); } -StatusOr> InterpreterCompiler::RunHloPasses( +absl::StatusOr> InterpreterCompiler::RunHloPasses( std::unique_ptr hlo_module, se::StreamExecutor* /*stream_exec*/, const CompileOptions& /*options*/) { VLOG(1) << "Run hlo passes on graph " << hlo_module->name(); @@ -117,7 +118,7 @@ StatusOr> InterpreterCompiler::RunHloPasses( return std::move(hlo_module); } -StatusOr> InterpreterCompiler::RunBackend( +absl::StatusOr> InterpreterCompiler::RunBackend( std::unique_ptr hlo_module, se::StreamExecutor* stream_exec, const CompileOptions& /*options*/) { TF_RET_CHECK(stream_exec != nullptr); @@ -146,7 +147,8 @@ StatusOr> InterpreterCompiler::RunBackend( return std::move(executable); } -StatusOr>> InterpreterCompiler::Compile( +absl::StatusOr>> +InterpreterCompiler::Compile( std::unique_ptr module_group, std::vector> stream_exec, const CompileOptions& options) { @@ -170,7 +172,7 @@ StatusOr>> InterpreterCompiler::Compile( return std::move(ret); } -StatusOr>> +absl::StatusOr>> InterpreterCompiler::CompileAheadOfTime( std::unique_ptr module_group, const AotCompilationOptions& aot_options) { diff --git a/third_party/xla/xla/backends/interpreter/compiler.h b/third_party/xla/xla/backends/interpreter/compiler.h index 3fcf2a849e655f..cfdbaa4fd23928 100644 --- a/third_party/xla/xla/backends/interpreter/compiler.h +++ b/third_party/xla/xla/backends/interpreter/compiler.h @@ -42,18 +42,18 @@ class InterpreterCompiler : public Compiler { InterpreterCompiler() {} ~InterpreterCompiler() override {} - StatusOr> RunHloPasses( + absl::StatusOr> RunHloPasses( std::unique_ptr hlo_module, se::StreamExecutor* stream_exec, const CompileOptions& options) override; - StatusOr> RunBackend( + absl::StatusOr> RunBackend( std::unique_ptr hlo_module, se::StreamExecutor* stream_exec, const CompileOptions& options) override; - StatusOr>> Compile( + absl::StatusOr>> Compile( std::unique_ptr module_group, std::vector> stream_exec, const CompileOptions& options) override; - StatusOr>> + absl::StatusOr>> CompileAheadOfTime(std::unique_ptr module_group, const AotCompilationOptions& aot_options) override; diff --git a/third_party/xla/xla/backends/interpreter/executable.cc b/third_party/xla/xla/backends/interpreter/executable.cc index f429cfc9a17fad..bd7b6261a85158 100644 --- a/third_party/xla/xla/backends/interpreter/executable.cc +++ b/third_party/xla/xla/backends/interpreter/executable.cc @@ -51,7 +51,7 @@ InterpreterExecutable::InterpreterExecutable( } } -StatusOr InterpreterExecutable::Evaluate( +absl::StatusOr InterpreterExecutable::Evaluate( const ServiceExecutableRunOptions* run_options, const HloComputation& computation, absl::Span arg_literals) { // Execute the graph using the HloEvaluator. diff --git a/third_party/xla/xla/backends/interpreter/executable.h b/third_party/xla/xla/backends/interpreter/executable.h index 3ac500b2b8ea6b..4a66f3bb375e29 100644 --- a/third_party/xla/xla/backends/interpreter/executable.h +++ b/third_party/xla/xla/backends/interpreter/executable.h @@ -48,9 +48,10 @@ class InterpreterExecutable : public InterpreterExecutableBase { static int64_t ShapeSizeBytes(const Shape& shape); protected: - StatusOr Evaluate(const ServiceExecutableRunOptions* run_options, - const HloComputation& computation, - absl::Span arg_literals) override + absl::StatusOr Evaluate( + const ServiceExecutableRunOptions* run_options, + const HloComputation& computation, + absl::Span arg_literals) override ABSL_LOCKS_EXCLUDED(evaluator_lock_); // The interpreter interprets executables with an HloEvaluator. diff --git a/third_party/xla/xla/backends/interpreter/executable_base.cc b/third_party/xla/xla/backends/interpreter/executable_base.cc index aef3732a003a89..7329c5d09f3f30 100644 --- a/third_party/xla/xla/backends/interpreter/executable_base.cc +++ b/third_party/xla/xla/backends/interpreter/executable_base.cc @@ -38,7 +38,7 @@ InterpreterExecutableBase::InterpreterExecutableBase( : Executable(std::move(hlo_module), /*hlo_profile_printer_data=*/nullptr, /*hlo_profile_index_map=*/nullptr) {} -StatusOr InterpreterExecutableBase::ExecuteAsyncOnStream( +absl::StatusOr InterpreterExecutableBase::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, std::vector arguments, HloExecutionProfile* hlo_execution_profile) { @@ -150,7 +150,7 @@ StatusOr InterpreterExecutableBase::ExecuteAsyncOnStream( return std::move(result); } -StatusOr +absl::StatusOr InterpreterExecutableBase::AllocateOutputMemoryWithInputReuse( const Shape& shape, const HloInputOutputAliasConfig& alias_config, se::DeviceMemoryAllocator* allocator, diff --git a/third_party/xla/xla/backends/interpreter/executable_base.h b/third_party/xla/xla/backends/interpreter/executable_base.h index 41681f4a86ba1f..fa55e567464435 100644 --- a/third_party/xla/xla/backends/interpreter/executable_base.h +++ b/third_party/xla/xla/backends/interpreter/executable_base.h @@ -37,19 +37,19 @@ class InterpreterExecutableBase : public Executable { public: explicit InterpreterExecutableBase(std::unique_ptr hlo_module); - StatusOr ExecuteAsyncOnStream( + absl::StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, std::vector arguments, HloExecutionProfile* hlo_execution_profile) override; protected: - virtual StatusOr Evaluate( + virtual absl::StatusOr Evaluate( const ServiceExecutableRunOptions* run_options, const HloComputation& computation, absl::Span arg_literals) = 0; private: - StatusOr AllocateOutputMemoryWithInputReuse( + absl::StatusOr AllocateOutputMemoryWithInputReuse( const Shape& shape, const HloInputOutputAliasConfig& alias_config, se::DeviceMemoryAllocator* allocator, std::vector* arguments, stream_executor::Stream* stream); diff --git a/third_party/xla/xla/backends/interpreter/executor.cc b/third_party/xla/xla/backends/interpreter/executor.cc index dea705fd907e9f..3aa8616033c908 100644 --- a/third_party/xla/xla/backends/interpreter/executor.cc +++ b/third_party/xla/xla/backends/interpreter/executor.cc @@ -43,9 +43,9 @@ bool XlaInterpreterExecutor::Memcpy(Stream *stream, void *host_dst, uint64_t size) { AsExecutorStream(stream)->EnqueueTask([this, host_dst, dev_src, size]() { // Ignore errors. - tsl::Status ok = SynchronousMemcpy(host_dst, dev_src, size); + absl::Status ok = SynchronousMemcpy(host_dst, dev_src, size); }); - tsl::Status status = AsExecutorStream(stream)->BlockUntilDone(); + absl::Status status = AsExecutorStream(stream)->BlockUntilDone(); if (status.ok()) { return true; } @@ -60,9 +60,9 @@ bool XlaInterpreterExecutor::Memcpy(Stream *stream, DeviceMemoryBase *dev_dst, const void *host_src, uint64_t size) { AsExecutorStream(stream)->EnqueueTask([this, dev_dst, host_src, size]() { // Ignore errors. - tsl::Status ok = SynchronousMemcpy(dev_dst, host_src, size); + absl::Status ok = SynchronousMemcpy(dev_dst, host_src, size); }); - tsl::Status status = AsExecutorStream(stream)->BlockUntilDone(); + absl::Status status = AsExecutorStream(stream)->BlockUntilDone(); if (status.ok()) { return true; } @@ -73,21 +73,20 @@ bool XlaInterpreterExecutor::Memcpy(Stream *stream, DeviceMemoryBase *dev_dst, return false; } -tsl::Status XlaInterpreterExecutor::SynchronousMemcpy(DeviceMemoryBase *dev_dst, - const void *host_src, - uint64_t size) { +absl::Status XlaInterpreterExecutor::SynchronousMemcpy( + DeviceMemoryBase *dev_dst, const void *host_src, uint64_t size) { memcpy(dev_dst->opaque(), host_src, size); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status XlaInterpreterExecutor::SynchronousMemcpy( +absl::Status XlaInterpreterExecutor::SynchronousMemcpy( void *host_dst, const DeviceMemoryBase &dev_src, uint64_t size) { memcpy(host_dst, dev_src.opaque(), size); - return ::tsl::OkStatus(); + return absl::OkStatus(); } bool XlaInterpreterExecutor::HostCallback( - Stream *stream, absl::AnyInvocable callback) { + Stream *stream, absl::AnyInvocable callback) { AsExecutorStream(stream)->EnqueueTaskWithStatus(std::move(callback)); return true; } @@ -96,7 +95,7 @@ bool XlaInterpreterExecutor::CreateStreamDependency(Stream *dependent, Stream *other) { AsExecutorStream(dependent)->EnqueueTaskWithStatus( [other]() { return other->BlockHostUntilDone(); }); - tsl::Status status = AsExecutorStream(dependent)->BlockUntilDone(); + absl::Status status = AsExecutorStream(dependent)->BlockUntilDone(); if (status.ok()) { return true; } @@ -107,11 +106,11 @@ bool XlaInterpreterExecutor::CreateStreamDependency(Stream *dependent, return false; } -tsl::Status XlaInterpreterExecutor::BlockHostUntilDone(Stream *stream) { +absl::Status XlaInterpreterExecutor::BlockHostUntilDone(Stream *stream) { return AsExecutorStream(stream)->BlockUntilDone(); } -tsl::StatusOr> +absl::StatusOr> XlaInterpreterExecutor::CreateDeviceDescription(int device_ordinal) { internal::DeviceDescriptionBuilder builder; diff --git a/third_party/xla/xla/backends/interpreter/executor.h b/third_party/xla/xla/backends/interpreter/executor.h index 7de385986b70dd..79964a093ba9b6 100644 --- a/third_party/xla/xla/backends/interpreter/executor.h +++ b/third_party/xla/xla/backends/interpreter/executor.h @@ -47,19 +47,19 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { public: XlaInterpreterExecutor() = default; - tsl::Status Init(int device_ordinal, DeviceOptions device_options) override { + absl::Status Init(int device_ordinal, DeviceOptions device_options) override { device_ordinal_ = device_ordinal; - return ::tsl::OkStatus(); + return absl::OkStatus(); } int device_ordinal() const override { return device_ordinal_; }; - tsl::Status GetKernel(const MultiKernelLoaderSpec &spec, - Kernel *kernel) override { + absl::Status GetKernel(const MultiKernelLoaderSpec &spec, + Kernel *kernel) override { return tsl::errors::Unimplemented("Not Implemented"); } - tsl::Status Launch(Stream *stream, const ThreadDim &thread_dims, - const BlockDim &block_dims, const Kernel &kernel, - const KernelArgs &args) override { + absl::Status Launch(Stream *stream, const ThreadDim &thread_dims, + const BlockDim &block_dims, const Kernel &kernel, + const KernelArgs &args) override { return tsl::errors::Unimplemented("Not Implemented"); } @@ -83,56 +83,57 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { return false; } - tsl::Status MemZero(Stream *stream, DeviceMemoryBase *location, - uint64_t size) override { + absl::Status MemZero(Stream *stream, DeviceMemoryBase *location, + uint64_t size) override { return tsl::errors::Internal("Interpreter can not memzero"); } - tsl::Status Memset(Stream *stream, DeviceMemoryBase *location, - uint8_t pattern, uint64_t size) override { + absl::Status Memset(Stream *stream, DeviceMemoryBase *location, + uint8_t pattern, uint64_t size) override { return tsl::errors::Internal("Interpreter can not memset"); } - tsl::Status Memset32(Stream *stream, DeviceMemoryBase *location, - uint32_t pattern, uint64_t size) override { + absl::Status Memset32(Stream *stream, DeviceMemoryBase *location, + uint32_t pattern, uint64_t size) override { return tsl::errors::Internal("Interpreter can not memset"); } // No "synchronize all activity" implemented for this platform at the moment. bool SynchronizeAllActivity() override { return true; } - tsl::Status SynchronousMemZero(DeviceMemoryBase *location, - uint64_t size) override { + absl::Status SynchronousMemZero(DeviceMemoryBase *location, + uint64_t size) override { return tsl::errors::Internal("Interpreter can not memzero"); } - tsl::Status SynchronousMemSet(DeviceMemoryBase *location, int value, - uint64_t size) override { + absl::Status SynchronousMemSet(DeviceMemoryBase *location, int value, + uint64_t size) override { return tsl::errors::Internal("Interpreter can not memset"); } - tsl::Status SynchronousMemcpy(DeviceMemoryBase *dev_dst, const void *host_src, - uint64_t size) override; - tsl::Status SynchronousMemcpy(void *host_dst, const DeviceMemoryBase &dev_src, - uint64_t size) override; - tsl::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase *pop_dst, - const DeviceMemoryBase &pop_src, - uint64_t size) override { - return tsl::Status{absl::StatusCode::kUnimplemented, ""}; + absl::Status SynchronousMemcpy(DeviceMemoryBase *dev_dst, + const void *host_src, uint64_t size) override; + absl::Status SynchronousMemcpy(void *host_dst, + const DeviceMemoryBase &dev_src, + uint64_t size) override; + absl::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase *pop_dst, + const DeviceMemoryBase &pop_src, + uint64_t size) override { + return absl::Status{absl::StatusCode::kUnimplemented, ""}; } bool HostCallback(Stream *stream, - absl::AnyInvocable callback) override; + absl::AnyInvocable callback) override; - tsl::Status AllocateEvent(Event *event) override { return ::tsl::OkStatus(); } + absl::Status AllocateEvent(Event *event) override { return absl::OkStatus(); } - tsl::Status DeallocateEvent(Event *event) override { - return ::tsl::OkStatus(); + absl::Status DeallocateEvent(Event *event) override { + return absl::OkStatus(); } - tsl::Status RecordEvent(Stream *stream, Event *event) override { - return tsl::Status{absl::StatusCode::kUnimplemented, "RecordEvent"}; + absl::Status RecordEvent(Stream *stream, Event *event) override { + return absl::Status{absl::StatusCode::kUnimplemented, "RecordEvent"}; } - tsl::Status WaitForEvent(Stream *stream, Event *event) override { - return tsl::Status{absl::StatusCode::kUnimplemented, "WaitForEvent"}; + absl::Status WaitForEvent(Stream *stream, Event *event) override { + return absl::Status{absl::StatusCode::kUnimplemented, "WaitForEvent"}; } Event::Status PollForEventStatus(Event *event) override { @@ -143,22 +144,22 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { void DeallocateStream(Stream *stream) override {} bool CreateStreamDependency(Stream *dependent, Stream *other) override; - tsl::Status BlockHostUntilDone(Stream *stream) override; + absl::Status BlockHostUntilDone(Stream *stream) override; bool DeviceMemoryUsage(int64_t *free, int64_t *total) const override { return false; } - tsl::StatusOr> CreateDeviceDescription() + absl::StatusOr> CreateDeviceDescription() const override { return CreateDeviceDescription(0); } - static tsl::StatusOr> + static absl::StatusOr> CreateDeviceDescription(int device_ordinal); - tsl::Status EnablePeerAccessTo(StreamExecutorInterface *other) override { - return ::tsl::OkStatus(); + absl::Status EnablePeerAccessTo(StreamExecutorInterface *other) override { + return absl::OkStatus(); } bool CanEnablePeerAccessTo(StreamExecutorInterface *other) override { @@ -170,11 +171,6 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { return nullptr; } - std::unique_ptr CreateKernelImplementation() - override { - return nullptr; - } - std::unique_ptr GetStreamImplementation() override { return std::unique_ptr( @@ -188,7 +184,8 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { DeviceMemoryBase AllocateSingleOutput(const xla::Shape &shape); - tsl::StatusOr AllocateOutputBuffer(const xla::Shape &shape); + absl::StatusOr AllocateOutputBuffer( + const xla::Shape &shape); }; } // namespace interpreter diff --git a/third_party/xla/xla/backends/interpreter/platform.cc b/third_party/xla/xla/backends/interpreter/platform.cc index c0e1f65fd78bb5..823ba66ba005fa 100644 --- a/third_party/xla/xla/backends/interpreter/platform.cc +++ b/third_party/xla/xla/backends/interpreter/platform.cc @@ -21,9 +21,9 @@ limitations under the License. #include "absl/strings/str_format.h" #include "xla/backends/interpreter/executor.h" #include "xla/stream_executor/device_options.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform/initialize.h" +#include "xla/stream_executor/platform_manager.h" #include "tsl/platform/status.h" namespace stream_executor { @@ -41,12 +41,12 @@ int XlaInterpreterPlatform::VisibleDeviceCount() const { return 1; } const std::string& XlaInterpreterPlatform::Name() const { return name_; } -tsl::StatusOr> +absl::StatusOr> XlaInterpreterPlatform::DescriptionForDevice(int ordinal) const { return XlaInterpreterExecutor::CreateDeviceDescription(ordinal); } -tsl::StatusOr XlaInterpreterPlatform::ExecutorForDevice( +absl::StatusOr XlaInterpreterPlatform::ExecutorForDevice( int ordinal) { StreamExecutorConfig config; config.ordinal = ordinal; @@ -54,20 +54,20 @@ tsl::StatusOr XlaInterpreterPlatform::ExecutorForDevice( return GetExecutor(config); } -tsl::StatusOr XlaInterpreterPlatform::GetExecutor( +absl::StatusOr XlaInterpreterPlatform::GetExecutor( const StreamExecutorConfig& config) { return executor_cache_.GetOrCreate( config, [&]() { return GetUncachedExecutor(config); }); } -tsl::StatusOr> +absl::StatusOr> XlaInterpreterPlatform::GetUncachedExecutor( const StreamExecutorConfig& config) { auto executor = std::make_unique( this, std::make_unique(), config.ordinal); auto init_status = executor->Init(config.device_options); if (!init_status.ok()) { - return tsl::Status{ + return absl::Status{ absl::StatusCode::kInternal, absl::StrFormat( "failed initializing StreamExecutor for device ordinal %d: %s", @@ -79,12 +79,12 @@ XlaInterpreterPlatform::GetUncachedExecutor( static void InitializeXlaInterpreterPlatform() { std::unique_ptr platform(new XlaInterpreterPlatform); - TF_CHECK_OK(MultiPlatformManager::RegisterPlatform(std::move(platform))); + TF_CHECK_OK(PlatformManager::RegisterPlatform(std::move(platform))); } } // namespace interpreter } // namespace stream_executor -REGISTER_MODULE_INITIALIZER( +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER( interpreter_platform, stream_executor::interpreter::InitializeXlaInterpreterPlatform()); diff --git a/third_party/xla/xla/backends/interpreter/platform.h b/third_party/xla/xla/backends/interpreter/platform.h index e29444c8854b29..c81f7f7f2fd60f 100644 --- a/third_party/xla/xla/backends/interpreter/platform.h +++ b/third_party/xla/xla/backends/interpreter/platform.h @@ -39,15 +39,15 @@ class XlaInterpreterPlatform : public Platform { const std::string& Name() const override; - tsl::StatusOr> DescriptionForDevice( + absl::StatusOr> DescriptionForDevice( int ordinal) const override; - tsl::StatusOr ExecutorForDevice(int ordinal) override; + absl::StatusOr ExecutorForDevice(int ordinal) override; - tsl::StatusOr GetExecutor( + absl::StatusOr GetExecutor( const StreamExecutorConfig& config) override; - tsl::StatusOr> GetUncachedExecutor( + absl::StatusOr> GetUncachedExecutor( const StreamExecutorConfig& config) override; private: diff --git a/third_party/xla/xla/backends/profiler/BUILD b/third_party/xla/xla/backends/profiler/BUILD index d1fa77c7c61210..5978f53341c13f 100644 --- a/third_party/xla/xla/backends/profiler/BUILD +++ b/third_party/xla/xla/backends/profiler/BUILD @@ -1,4 +1,9 @@ -load("@local_tsl//tsl:tsl.bzl", "if_with_tpu_support", "tsl_gpu_library") +load( + "@local_tsl//tsl:tsl.bzl", + "if_with_tpu_support", + "internal_visibility", + "tsl_gpu_library", +) # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) @@ -13,7 +18,7 @@ package_group( tsl_gpu_library( name = "profiler_backends", - visibility = ["//visibility:public"], + visibility = internal_visibility(["//xla:internal"]), deps = [ "//xla/backends/profiler/cpu:host_tracer", "//xla/backends/profiler/cpu:metadata_collector", diff --git a/third_party/xla/xla/backends/profiler/cpu/BUILD b/third_party/xla/xla/backends/profiler/cpu/BUILD index 97d643f637e75f..0820bfcc1155f0 100644 --- a/third_party/xla/xla/backends/profiler/cpu/BUILD +++ b/third_party/xla/xla/backends/profiler/cpu/BUILD @@ -1,3 +1,4 @@ +load("@local_tsl//tsl:tsl.bzl", "internal_visibility") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load("@local_tsl//tsl/profiler/builds:build_config.bzl", "tf_profiler_copts") @@ -6,7 +7,10 @@ load("@local_tsl//tsl/profiler/builds:build_config.bzl", "tf_profiler_copts") cc_library( name = "host_tracer", srcs = ["host_tracer_factory.cc"], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//xla/backends/profiler:__pkg__", + # copybara:uncomment "//tensorflow/core/profiler:internal", + ]), deps = [ ":host_tracer_impl", "@local_tsl//tsl/profiler/lib:profiler_factory", @@ -20,7 +24,9 @@ cc_library( srcs = ["host_tracer.cc"], hdrs = ["host_tracer.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + # copybara:uncomment "//tensorflow/core/profiler:internal", + ]), deps = [ "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", @@ -38,7 +44,10 @@ cc_library( cc_library( name = "python_tracer", srcs = ["python_tracer_factory.cc"], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//xla/python:__pkg__", + # copybara:uncomment "//tensorflow/core/profiler:internal", + ]), deps = [ ":python_tracer_impl", "@local_tsl//tsl/profiler/lib:profiler_factory", @@ -53,7 +62,9 @@ cc_library( hdrs = ["python_tracer.h"], copts = tf_profiler_copts() + ["-fexceptions"], features = ["-use_header_modules"], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + # copybara:uncomment "//tensorflow/core/profiler:internal", + ]), deps = [ "//xla/python/profiler/internal:python_hooks", "@local_tsl//tsl/platform:errors", @@ -69,7 +80,10 @@ cc_library( name = "metadata_collector", srcs = ["metadata_collector.cc"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//xla/backends/profiler:__pkg__", + # copybara:uncomment "//tensorflow/core/profiler:internal", + ]), deps = [ ":metadata_utils", "//xla/service:hlo_proto_cc", @@ -89,7 +103,9 @@ cc_library( cc_library( name = "metadata_utils", hdrs = ["metadata_utils.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + # copybara:uncomment "//tensorflow/core/profiler:internal", + ]), deps = [ "//xla/service:hlo_proto_cc", "@local_tsl//tsl/profiler/convert:xla_op_utils", diff --git a/third_party/xla/xla/backends/profiler/cpu/host_tracer.cc b/third_party/xla/xla/backends/profiler/cpu/host_tracer.cc index 750dcb2d568cd9..c656afe0ffb688 100644 --- a/third_party/xla/xla/backends/profiler/cpu/host_tracer.cc +++ b/third_party/xla/xla/backends/profiler/cpu/host_tracer.cc @@ -42,11 +42,11 @@ class HostTracer : public tsl::profiler::ProfilerInterface { explicit HostTracer(int host_trace_level); ~HostTracer() override; - tsl::Status Start() override; // TENSORFLOW_STATUS_OK + absl::Status Start() override; // TENSORFLOW_STATUS_OK - tsl::Status Stop() override; // TENSORFLOW_STATUS_OK + absl::Status Stop() override; // TENSORFLOW_STATUS_OK - tsl::Status CollectData( // TENSORFLOW_STATUS_OK + absl::Status CollectData( // TENSORFLOW_STATUS_OK tensorflow::profiler::XSpace* space) override; private: @@ -68,7 +68,7 @@ HostTracer::HostTracer(int host_trace_level) HostTracer::~HostTracer() { Stop().IgnoreError(); } // NOLINT -tsl::Status HostTracer::Start() { // TENSORFLOW_STATUS_OK +absl::Status HostTracer::Start() { // TENSORFLOW_STATUS_OK if (recording_) { return tsl::errors::Internal("TraceMeRecorder already started"); } @@ -81,33 +81,33 @@ tsl::Status HostTracer::Start() { // TENSORFLOW_STATUS_OK if (!recording_) { return tsl::errors::Internal("Failed to start TraceMeRecorder"); } - return tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status HostTracer::Stop() { // TENSORFLOW_STATUS_OK +absl::Status HostTracer::Stop() { // TENSORFLOW_STATUS_OK if (!recording_) { return tsl::errors::Internal("TraceMeRecorder not started"); } events_ = tsl::profiler::TraceMeRecorder::Stop(); recording_ = false; - return tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status HostTracer::CollectData( // TENSORFLOW_STATUS_OK +absl::Status HostTracer::CollectData( // TENSORFLOW_STATUS_OK tensorflow::profiler::XSpace* space) { VLOG(2) << "Collecting data to XSpace from HostTracer."; if (recording_) { return tsl::errors::Internal("TraceMeRecorder not stopped"); } if (events_.empty()) { - return tsl::OkStatus(); + return absl::OkStatus(); } tensorflow::profiler::XPlane* plane = tsl::profiler::FindOrAddMutablePlaneWithName( space, tsl::profiler::kHostThreadsPlaneName); ConvertCompleteEventsToXPlane(start_timestamp_ns_, std::exchange(events_, {}), plane); - return tsl::OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/third_party/xla/xla/backends/profiler/cpu/python_tracer.cc b/third_party/xla/xla/backends/profiler/cpu/python_tracer.cc index 0430b0eb1fb63b..4649fcd6e6d3ff 100644 --- a/third_party/xla/xla/backends/profiler/cpu/python_tracer.cc +++ b/third_party/xla/xla/backends/profiler/cpu/python_tracer.cc @@ -35,11 +35,11 @@ class PythonTracer : public tsl::profiler::ProfilerInterface { : options_(options) {} ~PythonTracer() override; - tsl::Status Start() override; // TENSORFLOW_STATUS_OK + absl::Status Start() override; // TENSORFLOW_STATUS_OK - tsl::Status Stop() override; // TENSORFLOW_STATUS_OK + absl::Status Stop() override; // TENSORFLOW_STATUS_OK - tsl::Status CollectData( // TENSORFLOW_STATUS_OK + absl::Status CollectData( // TENSORFLOW_STATUS_OK tensorflow::profiler::XSpace* space) override; private: @@ -53,34 +53,34 @@ class PythonTracer : public tsl::profiler::ProfilerInterface { PythonTracer::~PythonTracer() { Stop().IgnoreError(); } // NOLINT -tsl::Status PythonTracer::Start() { // TENSORFLOW_STATUS_OK +absl::Status PythonTracer::Start() { // TENSORFLOW_STATUS_OK if (recording_) { return tsl::errors::Internal("PythonTracer already started"); } VLOG(1) << __FUNCTION__; recording_ = true; PythonHooks::GetSingleton()->Start(options_); - return tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status PythonTracer::Stop() { // TENSORFLOW_STATUS_OK +absl::Status PythonTracer::Stop() { // TENSORFLOW_STATUS_OK if (!recording_) { return tsl::errors::Internal("PythonTracer not started"); } VLOG(1) << __FUNCTION__; context_ = PythonHooks::GetSingleton()->Stop(); recording_ = false; - return tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status PythonTracer::CollectData( // TENSORFLOW_STATUS_OK +absl::Status PythonTracer::CollectData( // TENSORFLOW_STATUS_OK tensorflow::profiler::XSpace* space) { VLOG(2) << "Collecting data to XSpace from PythonTracer."; if (context_) { context_->Finalize(space); context_.reset(); } - return tsl::OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/third_party/xla/xla/backends/profiler/gpu/BUILD b/third_party/xla/xla/backends/profiler/gpu/BUILD index 64eb006c083457..91c036c9fece9c 100644 --- a/third_party/xla/xla/backends/profiler/gpu/BUILD +++ b/third_party/xla/xla/backends/profiler/gpu/BUILD @@ -6,6 +6,7 @@ load( load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") load( "@local_tsl//tsl:tsl.bzl", + "internal_visibility", "tsl_copts", "tsl_gpu_library", ) @@ -17,7 +18,6 @@ load( "@local_tsl//tsl/platform:build_config_root.bzl", "tf_cuda_tests_tags", ) -load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load( "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", @@ -25,7 +25,8 @@ load( load("@local_tsl//tsl/profiler/builds:build_config.bzl", "tf_profiler_copts") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility(["//xla:internal"]), ) tsl_gpu_library( @@ -38,7 +39,6 @@ tsl_gpu_library( ":cupti_wrapper", ":rocm_tracer", ], - visibility = ["//visibility:public"], deps = [ ":cupti_utils", "@com_google_absl//absl/container:fixed_array", @@ -75,7 +75,6 @@ tsl_gpu_library( cuda_deps = [ ":cupti_interface", ], - visibility = ["//visibility:public"], deps = [ "@local_tsl//tsl/platform:test", ], @@ -139,7 +138,7 @@ cuda_library( visibility = ["//visibility:public"], deps = [ "@local_config_cuda//cuda:cuda_headers", - "@local_config_cuda//cuda:cudart", + "@local_config_cuda//cuda:cudart_static", "@local_tsl//tsl/platform:test", ], ) @@ -152,7 +151,10 @@ cuda_library( # that the wrapper is about the only direct user. tsl_gpu_library( name = "cupti_wrapper", - srcs = if_cuda(["cupti_wrapper.cc"]), + srcs = if_cuda([ + "cupti_wrapper.cc", + "cupti_wrapper_stub.cc", + ]), hdrs = if_cuda(["cupti_wrapper.h"]), copts = tf_profiler_copts() + tsl_copts(), linkstatic = 1, @@ -246,7 +248,6 @@ tsl_gpu_library( srcs = if_cuda(["nvtx_utils.cc"]), hdrs = if_cuda(["nvtx_utils.h"]), copts = tf_profiler_copts() + tsl_copts(), - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/strings", "@local_tsl//tsl/platform", @@ -257,7 +258,7 @@ tsl_gpu_library( tsl_gpu_library( name = "cupti_collector", srcs = if_cuda(["cupti_collector.cc"]), - hdrs = if_cuda(["cupti_collector.h"]), + hdrs = ["cupti_collector.h"], copts = tf_profiler_copts() + tsl_copts(), visibility = ["//visibility:public"], deps = [ @@ -282,20 +283,6 @@ tsl_gpu_library( ] + if_cuda(["@local_tsl//tsl/cuda:cupti"]), ) -cc_library( - name = "cupti_collector_header", - hdrs = ["cupti_collector.h"], - visibility = ["//visibility:public"], - deps = [ - "@com_google_absl//absl/container:fixed_array", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:node_hash_set", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:types", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - ], -) - tsl_gpu_library( name = "cupti_utils", srcs = if_cuda(["cupti_utils.cc"]), @@ -304,7 +291,11 @@ tsl_gpu_library( ":cupti_error_manager", ":cupti_interface", ":cupti_wrapper", + "@com_google_absl//absl/base", "@com_google_absl//absl/memory", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:stringpiece", + "@local_tsl//tsl/util:env_var", ], visibility = ["//visibility:public"], alwayslink = 1, diff --git a/third_party/xla/xla/backends/profiler/gpu/cupti_tracer.cc b/third_party/xla/xla/backends/profiler/gpu/cupti_tracer.cc index 4ac130b6f71b23..a7ebab58e6e3f5 100644 --- a/third_party/xla/xla/backends/profiler/gpu/cupti_tracer.cc +++ b/third_party/xla/xla/backends/profiler/gpu/cupti_tracer.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/container/node_hash_set.h" #include "third_party/gpus/cuda/extras/CUPTI/include/cupti_activity.h" #include "third_party/gpus/cuda/extras/CUPTI/include/generated_nvtx_meta.h" +#include "third_party/gpus/cuda/include/cuda.h" #include "xla/backends/profiler/gpu/cupti_collector.h" #include "xla/backends/profiler/gpu/nvtx_utils.h" #include "tsl/platform/env.h" @@ -1061,629 +1062,6 @@ class CuptiDriverApiHookWithActivityApi : public CuptiDriverApiHook { void operator=(const CuptiDriverApiHookWithActivityApi &) = delete; }; -struct KernelRecord { - const char *kernel_name; - // TODO(csigg): cuStreamGetCtx introduced in CUDA 9.2 would allow us to only - // record the stream and infer the context during collection. - CUcontext context; - CUstream stream; - uint32_t correlation_id; - CUevent start_event; - CUevent stop_event; - KernelDetails details; - uint64_t start_timestamp; -}; - -struct MemcpyRecord { - CuptiTracerEventType type; - size_t size_bytes; - CUcontext context; - CUstream stream; - uint32_t correlation_id; - bool async; - CUevent start_event; - CUevent stop_event; - uint64_t start_timestamp; -}; - -Status CreateAndRecordEvent(CUevent *event, CUstream stream) { - CuptiApiTracingDisabler disabler; - TF_RETURN_IF_ERROR(ToStatus(cuEventCreate(event, CU_EVENT_DEFAULT))); - return ToStatus(cuEventRecord(*event, stream)); -} - -// Maintain and restore current thread's CUDA context. -// Note: cuStreamGetCtx only available after CUDA 9.2. -class ScopedCudaContext { - public: - explicit ScopedCudaContext(CUstream stream) : stream_(stream) { - CuptiApiTracingDisabler disabler; // don't trace cuda call in this func. - CUcontext context; - if (cuStreamGetCtx(stream, &context) != CUDA_SUCCESS) return; - context_ = context; - uint32_t device_ordinal; - if (cuptiGetDeviceId(context, &device_ordinal) != CUPTI_SUCCESS) return; - device_ordinal_ = device_ordinal; - context_pushed_ = cuCtxPushCurrent(context) == CUDA_SUCCESS; - } - ~ScopedCudaContext() { - if (!context_pushed_) return; - CuptiApiTracingDisabler disabler; // don't trace cuda call in this func. - cuCtxPopCurrent(&*context_); - } - - // If successful, return the device ordinal of the relevant cuda stream. - // Otherwise std::nullopt; non-std ok - std::optional GetDeviceOrdinal() { return device_ordinal_; } - - // If successful, return the cuda context of the relevant cuda stream. - // Otherwise std::nullopt; - std::optional GetContext() { return context_; } - - private: - CUstream stream_; - std::optional context_; - std::optional device_ordinal_; - bool context_pushed_ = false; -}; - -// Stores a series of kernel and memcpy records. -class CudaEventRecorder { - public: - CudaEventRecorder(CuptiInterface *cupti_interface, - CuptiTraceCollector *collector, int ordinal) - : cupti_interface_(cupti_interface), - collector_(collector), - ordinal_(ordinal) { - device_name_ = absl::StrCat("gpu ", ordinal); // default. - CUdevice device; - if (cuDeviceGet(&device, ordinal) == CUDA_SUCCESS) { - char name[100]; - if (cuDeviceGetName(name, sizeof(name), device) == CUDA_SUCCESS) { - device_name_ = name; - } - } - } - - // Registers the start of a kernel launch. The returned index should be passed - // to StopKernel() after the kernel launch has completed. - template - size_t StartKernel(const char *kernel_name, CUcontext context, - uint32_t correlation_id, const T *params) { - CUstream stream = params->hStream; - KernelRecord record = {kernel_name, context, stream, correlation_id}; - record.details.registers_per_thread = 0; // unknown. - record.details.static_shared_memory_usage = params->sharedMemBytes; - record.details.dynamic_shared_memory_usage = 0; // unknown - record.details.block_x = params->blockDimX; - record.details.block_y = params->blockDimY; - record.details.block_z = params->blockDimZ; - record.details.grid_x = params->gridDimX; - record.details.grid_y = params->gridDimY; - record.details.grid_z = params->gridDimZ; - record.start_timestamp = CuptiTracer::GetTimestamp(); - LogIfError(CreateAndRecordEvent(&record.start_event, stream)); - absl::MutexLock lock(&mutex_); - if (stopped_) return -1; - kernel_records_.push_back(record); - return kernel_records_.size() - 1; - } - uint64_t StopKernel(size_t index) { - absl::MutexLock lock(&mutex_); - if (index >= kernel_records_.size()) return 0; - auto &record = kernel_records_[index]; - LogIfError(CreateAndRecordEvent(&record.stop_event, record.stream)); - return record.start_timestamp; - } - - // Registers the start of a copy operation. The returned index should be - // passed to StopMemcpy() after the memcpy has completed. - size_t StartMemcpy(CuptiTracerEventType type, size_t size_bytes, - CUcontext context, CUstream stream, - uint32_t correlation_id, bool async) { - MemcpyRecord record = {type, size_bytes, context, - stream, correlation_id, async}; - record.start_timestamp = CuptiTracer::GetTimestamp(); - LogIfError(CreateAndRecordEvent(&record.start_event, stream)); - absl::MutexLock lock(&mutex_); - if (stopped_) return -1; - memcpy_records_.push_back(record); - return memcpy_records_.size() - 1; - } - uint64_t StopMemcpy(size_t index) { - absl::MutexLock lock(&mutex_); - if (index >= memcpy_records_.size()) return 0; - auto &record = memcpy_records_[index]; - LogIfError(CreateAndRecordEvent(&record.stop_event, record.stream)); - return record.start_timestamp; - } - - Status Stop() { - { - absl::MutexLock lock(&mutex_); - stopped_ = true; - LOG(INFO) << "Collecting " << kernel_records_.size() - << " kernel records, " << memcpy_records_.size() - << " memcpy records."; - - // Gather all profiled streams and contexts. - for (const auto &record : kernel_records_) { - TF_RETURN_IF_ERROR( - AddStreamInfo(record.context, record.stream, "Kernel")); - } - for (const auto &record : memcpy_records_) { - TF_RETURN_IF_ERROR(AddStreamInfo(record.context, record.stream, - GetTraceEventTypeName(record.type))); - } - } - - // Synchronize all contexts, record end events, synchronize again. - // This scheme is an unreliable measure to associate a event with the wall - // time. There are chances that other threads might enque kernels which - // delay the second synchronization. - TF_RETURN_IF_ERROR(Synchronize()); - for (auto &pair : context_infos_) { - TF_RETURN_IF_ERROR(ToStatus(cuCtxSetCurrent(pair.first))); - TF_RETURN_IF_ERROR(CreateAndRecordEvent(&pair.second.end_event, nullptr)); - } - - TF_RETURN_IF_ERROR(Synchronize()); - end_walltime_us_ = Env::Default()->NowMicros(); - return OkStatus(); - } - - Status Flush(AnnotationMap *annotation_map) { - auto kernel_records = ConsumeKernelRecords(); - auto memcpy_records = ConsumeMemcpyRecords(); - for (const auto &record : kernel_records) { - TF_RETURN_IF_ERROR(SaveRecord(record, annotation_map)); - } - for (const auto &record : memcpy_records) { - TF_RETURN_IF_ERROR(SaveRecord(record, annotation_map)); - } - return OkStatus(); - } - - std::vector ConsumeKernelRecords() { - absl::MutexLock lock(&mutex_); - return std::move(kernel_records_); - } - std::vector ConsumeMemcpyRecords() { - absl::MutexLock lock(&mutex_); - return std::move(memcpy_records_); - } - - private: - struct ContextInfo { - uint32_t context_id = 0; - int num_streams = 0; - CUevent end_event; - }; - - struct StreamInfo { - uint32_t stream_id = 0; - std::string name; - int index; // 0 is reserved for null stream. - const ContextInfo *ctx_info; - }; - - // Synchronizes all contexts. - Status Synchronize() const { - CuptiApiTracingDisabler disabler; - for (const auto &pair : context_infos_) { - TF_RETURN_IF_ERROR(ToStatus(cuCtxSetCurrent(pair.first))); - TF_RETURN_IF_ERROR(ToStatus(cuCtxSynchronize())); - } - return OkStatus(); - } - - // Returns element from context_infos_, adding it if not yet present. - Status GetContextInfo(CUcontext context, ContextInfo **ctx_info_ptr) { - auto it = context_infos_.find(context); - - if (it == context_infos_.end()) { - uint32_t context_id = 0; - RETURN_IF_CUPTI_ERROR( - cupti_interface_->GetContextId(context, &context_id)); - ContextInfo ctx_info = {context_id}; - it = context_infos_.emplace(context, ctx_info).first; - } - - *ctx_info_ptr = &it->second; - return OkStatus(); - } - - // Adds element to stream_infos_ if not yet present. If present, clear name - // if it doesn't match parameter. - Status AddStreamInfo(CUcontext context, CUstream stream, - absl::string_view name) { - StreamKey key(context, stream); - auto it = stream_infos_.find(key); - if (it != stream_infos_.end()) { - if (it->second.name != name) { - it->second.name.clear(); // Stream with inconsistent names, clear it. - } - return OkStatus(); - } - - ContextInfo *ctx_info; - TF_RETURN_IF_ERROR(GetContextInfo(context, &ctx_info)); - int index = stream ? ++ctx_info->num_streams : 0; - uint32_t stream_id = 0; -#if defined(CUDA_API_PER_THREAD_DEFAULT_STREAM) - RETURN_IF_CUPTI_ERROR( - cupti_interface_->GetStreamIdEx(context, stream, 1, &stream_id)); -#else - RETURN_IF_CUPTI_ERROR( - cupti_interface_->GetStreamIdEx(context, stream, 0, &stream_id)); -#endif - - StreamInfo stream_info = {stream_id, static_cast(name), index, - ctx_info}; - stream_infos_.emplace(key, stream_info); - return OkStatus(); - } - - // Returns time in microseconds between events recorded on the GPU. - static uint64_t GetElapsedTimeUs(CUevent start, CUevent stop) { - CuptiApiTracingDisabler disabler; - float elapsed_ms = 0.0f; - LogIfError(ToStatus(cuEventElapsedTime(&elapsed_ms, start, stop))); - return static_cast( - std::llroundf(1000 * std::max(elapsed_ms, 0.0f))); - } - - Status SaveRecord(const KernelRecord &record, - AnnotationMap *annotation_map) const { - if (!record.start_event || !record.stop_event) { - return OkStatus(); - } - const auto &stream_info = - stream_infos_.at(StreamKey(record.context, record.stream)); - auto start_us = - GetElapsedTimeUs(record.start_event, stream_info.ctx_info->end_event); - auto elapsed_us = GetElapsedTimeUs(record.start_event, record.stop_event); - - std::string annotation; - - CuptiTracerEvent event{}; - event.type = CuptiTracerEventType::Kernel; - event.source = CuptiTracerEventSource::Activity; // on gpu device. - event.name = record.kernel_name; - event.start_time_ns = (end_walltime_us_ - start_us) * 1000; - event.end_time_ns = event.start_time_ns + elapsed_us * 1000; - event.device_id = ordinal_; - event.context_id = stream_info.ctx_info->context_id; - event.stream_id = stream_info.stream_id; - event.correlation_id = record.correlation_id; - AnnotationMap::AnnotationInfo info = collector_->annotation_map()->LookUp( - event.device_id, event.correlation_id); - event.annotation = info.annotation; - event.kernel_info = record.details; - collector_->AddEvent(std::move(event)); - return OkStatus(); - } - - Status SaveRecord(const MemcpyRecord &record, - AnnotationMap *annotation_map) const { - if (!record.start_event || !record.stop_event) { - return OkStatus(); - } - const auto &stream_info = - stream_infos_.at(StreamKey(record.context, record.stream)); - auto start_us = - GetElapsedTimeUs(record.start_event, stream_info.ctx_info->end_event); - auto elapsed_us = GetElapsedTimeUs(record.start_event, record.stop_event); - - CuptiTracerEvent event{}; - event.type = record.type; - event.name = GetTraceEventTypeName(event.type); - event.source = CuptiTracerEventSource::Activity; - event.start_time_ns = (end_walltime_us_ - start_us) * 1000; - event.end_time_ns = event.start_time_ns + elapsed_us * 1000; - event.device_id = ordinal_; - event.context_id = stream_info.ctx_info->context_id; - event.stream_id = stream_info.stream_id; - event.correlation_id = record.correlation_id; - AnnotationMap::AnnotationInfo info = collector_->annotation_map()->LookUp( - event.device_id, event.correlation_id); - event.annotation = info.annotation; - event.memcpy_info.num_bytes = record.size_bytes; - // TODO: support MemcpyD2D where destination != source; - event.memcpy_info.destination = ordinal_; - event.memcpy_info.async = record.async; - // TODO: set src_mem_kind and dst_mem_kind. - collector_->AddEvent(std::move(event)); - return OkStatus(); - } - - absl::Mutex mutex_; - bool stopped_ TF_GUARDED_BY(mutex_) = false; - std::vector kernel_records_ TF_GUARDED_BY(mutex_); - std::vector memcpy_records_ - TF_GUARDED_BY(mutex_); // non std ok - - CuptiInterface *cupti_interface_; - CuptiTraceCollector *collector_; - const int ordinal_; - std::string device_name_; - uint64_t end_walltime_us_; - // Include context in key to distinguish null streams. - using StreamKey = std::pair; - - absl::node_hash_map context_infos_; - absl::flat_hash_map stream_infos_; -}; - -// This hook uses cuda events to measure device side activities. -class CuptiDriverApiHookWithCudaEvent : public CuptiDriverApiHook { - public: - CuptiDriverApiHookWithCudaEvent(const CuptiTracerOptions &option, - CuptiInterface *cupti_interface, - CuptiTraceCollector *collector) - : option_(option), - cupti_interface_(cupti_interface), - collector_(collector) { - int num_gpus = CuptiTracer::NumGpus(); - cuda_event_recorders_.reserve(num_gpus); - for (int i = 0; i < num_gpus; ++i) { - cuda_event_recorders_.emplace_back( - std::make_unique(cupti_interface, collector, i)); - } - } - ~CuptiDriverApiHookWithCudaEvent() { - for (auto *callback_context : callback_contexts_) delete callback_context; - } - - Status OnDriverApiEnter(int device_id, CUpti_CallbackDomain domain, - CUpti_CallbackId cbid, - const CUpti_CallbackData *cbdata) override { - auto *recorder = cuda_event_recorders_[device_id].get(); - switch (cbid) { - case CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel: { - DCHECK_NE(cbdata->symbolName, nullptr); - const auto *params = - static_cast(cbdata->functionParams); - *cbdata->correlationData = recorder->StartKernel( - cbdata->symbolName, cbdata->context, cbdata->correlationId, params); - break; - } -#if CUDA_VERSION >= 11080 // CUDA 11.8 - case CUPTI_DRIVER_TRACE_CBID_cuLaunchKernelEx: { - DCHECK_NE(cbdata->symbolName, nullptr); - const auto *params = static_cast( - cbdata->functionParams); - *cbdata->correlationData = recorder->StartKernel( - cbdata->symbolName, cbdata->context, cbdata->correlationId, - params->config); - break; - } -#endif // CUDA_VERSION >= 11080 - case CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernel: { - DCHECK_NE(cbdata->symbolName, nullptr); - const auto *params = - static_cast( - cbdata->functionParams); - *cbdata->correlationData = - recorder->StartKernel( - cbdata->symbolName, cbdata->context, cbdata->correlationId, - params); - break; - } - case CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernelMultiDevice: { - const auto *params = - static_cast( - cbdata->functionParams); - std::vector record_indices; - record_indices.reserve(params->numDevices); - *cbdata->correlationData = -1; // Invalid value. - const auto &annotation = AnnotationStack::Get(); - for (int i = 0; i < params->numDevices; ++i) { - CUstream stream = params->launchParamsList[i].hStream; - ScopedCudaContext scoped_cuda_context(stream); - auto dev_id = scoped_cuda_context.GetDeviceOrdinal(); - auto context = scoped_cuda_context.GetContext(); - if (!dev_id) return tsl::errors::Internal("Invalid CUDA stream"); - // Because annotation are per device, therefore we need to populate - // annotation for each device involved. - collector_->annotation_map()->Add(*dev_id, cbdata->correlationId, - annotation, ""); - record_indices.push_back( - cuda_event_recorders_[*dev_id]->StartKernel( - "CooperativeKernelMultiDevice", *context, - cbdata->correlationId, &(params->launchParamsList[i]))); - } - auto *callback_context = - new CuptiApiCallbackContext(std::move(record_indices)); - callback_contexts_.insert(callback_context); - *cbdata->correlationData = reinterpret_cast(callback_context); - } break; - case CUPTI_DRIVER_TRACE_CBID_cuMemcpy: { - const auto *params = - static_cast(cbdata->functionParams); - StartMemcpy(GetMemcpyType(params->src, params->dst), - cbdata, recorder); - break; - } - case CUPTI_DRIVER_TRACE_CBID_cuMemcpyAsync: { - const auto *params = - static_cast(cbdata->functionParams); - StartMemcpyAsync( - GetMemcpyType(params->src, params->dst), cbdata, recorder); - break; - } - case CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoD_v2: - StartMemcpy(CuptiTracerEventType::MemcpyH2D, - cbdata, recorder); - break; - case CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoDAsync_v2: - StartMemcpyAsync( - CuptiTracerEventType::MemcpyH2D, cbdata, recorder); - break; - case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoH_v2: - StartMemcpy(CuptiTracerEventType::MemcpyD2H, - cbdata, recorder); - break; - case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoHAsync_v2: - StartMemcpyAsync( - CuptiTracerEventType::MemcpyD2H, cbdata, recorder); - break; - case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoD_v2: - StartMemcpy(CuptiTracerEventType::MemcpyD2D, - cbdata, recorder); - break; - case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoDAsync_v2: - StartMemcpyAsync( - CuptiTracerEventType::MemcpyD2D, cbdata, recorder); - break; - default: - VLOG(1) << "Unexpected callback id: " << cbid; - break; - } - return OkStatus(); - } - - Status OnDriverApiExit(int device_id, CUpti_CallbackDomain domain, - CUpti_CallbackId cbid, - const CUpti_CallbackData *cbdata) override { - auto *recorder = cuda_event_recorders_[device_id].get(); - if (*cbdata->correlationData == static_cast(-1)) return OkStatus(); - uint64_t start_tsc = 0; - switch (cbid) { - case CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel: -#if CUDA_VERSION >= 11080 // CUDA 11.8 - case CUPTI_DRIVER_TRACE_CBID_cuLaunchKernelEx: -#endif // CUDA_VERSION >= 11080 - case CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernel: - start_tsc = recorder->StopKernel(*cbdata->correlationData); - break; - case CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernelMultiDevice: { - auto *callback_context = reinterpret_cast( - *cbdata->correlationData); - callback_contexts_.erase(callback_context); - auto record_indices = std::move(callback_context->record_indices); - delete callback_context; - const auto *params = - static_cast( - cbdata->functionParams); - if (record_indices.size() != params->numDevices) - return tsl::errors::Internal("Invalid correlation data"); - for (int i = 0; i < params->numDevices; ++i) { - CUstream stream = params->launchParamsList[i].hStream; - ScopedCudaContext scoped_cuda_context(stream); - auto dev_id = scoped_cuda_context.GetDeviceOrdinal(); - if (!dev_id) return tsl::errors::Internal("Invalid CUDA stream"); - start_tsc = - cuda_event_recorders_[*dev_id]->StopKernel(record_indices[i]); - } - } break; - case CUPTI_DRIVER_TRACE_CBID_cuMemcpy: - case CUPTI_DRIVER_TRACE_CBID_cuMemcpyAsync: - case CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoD_v2: - case CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoDAsync_v2: - case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoH_v2: - case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoHAsync_v2: - case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoD_v2: - case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoDAsync_v2: - start_tsc = recorder->StopMemcpy(*cbdata->correlationData); - break; - default: - VLOG(1) << "Unexpected callback id: " << cbid; - // TODO: figure out how to get start timestamp in this case. - return OkStatus(); - } - // If we are not collecting CPU events from Callback API, we can return now. - if (!option_.required_callback_api_events) { - return OkStatus(); - } - - // Grab timestamp for API exit. API entry timestamp saved in cbdata. - uint64_t end_tsc = CuptiTracer::GetTimestamp(); - return AddDriverApiCallbackEvent(collector_, cupti_interface_, device_id, - start_tsc, end_tsc, domain, cbid, cbdata); - } - Status SyncAndFlush() override { - for (auto &recorder : cuda_event_recorders_) { - TF_RETURN_IF_ERROR(recorder->Stop()); - } - for (auto &recorder : cuda_event_recorders_) { - TF_RETURN_IF_ERROR(recorder->Flush(collector_->annotation_map())); - } - return OkStatus(); - } - - private: - template - static void StartMemcpy(CuptiTracerEventType type, - const CUpti_CallbackData *cbdata, - CudaEventRecorder *recorder) { - const auto *params = static_cast(cbdata->functionParams); - *cbdata->correlationData = - recorder->StartMemcpy(type, params->ByteCount, cbdata->context, nullptr, - cbdata->correlationId, /*async*/ false); - } - - template - static void StartMemcpyAsync(CuptiTracerEventType type, - const CUpti_CallbackData *cbdata, - CudaEventRecorder *recorder) { - const auto *params = static_cast(cbdata->functionParams); - *cbdata->correlationData = recorder->StartMemcpy( - type, params->ByteCount, cbdata->context, params->hStream, - cbdata->correlationId, /*async*/ true); - } - - static CUmemorytype GetMemoryType(CUdeviceptr ptr) { - CuptiApiTracingDisabler disabler; - CUmemorytype mem_type = CU_MEMORYTYPE_HOST; - auto status = - cuPointerGetAttribute(&mem_type, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, ptr); - if (status == CUDA_ERROR_INVALID_VALUE) { - // Pointer not registered with CUDA, must be host memory. - return CU_MEMORYTYPE_HOST; - } - LogIfError(ToStatus(status)); - return mem_type; - } - - static CuptiTracerEventType GetMemcpyType(CUdeviceptr src, CUdeviceptr dst) { - CUmemorytype src_type = GetMemoryType(src); - CUmemorytype dst_type = GetMemoryType(dst); - // TODO: handle CU_MEMORYTYPE_ARRAY case - if (src_type == CU_MEMORYTYPE_HOST && dst_type == CU_MEMORYTYPE_DEVICE) { - return CuptiTracerEventType::MemcpyH2D; - } else if (src_type == CU_MEMORYTYPE_DEVICE && - dst_type == CU_MEMORYTYPE_HOST) { - return CuptiTracerEventType::MemcpyD2H; - } else if (src_type == CU_MEMORYTYPE_DEVICE && - dst_type == CU_MEMORYTYPE_DEVICE) { - return CuptiTracerEventType::MemcpyD2D; - } - return CuptiTracerEventType::MemcpyOther; - } - - // Each cuLaunchCooperativeKernelMultiDevice will need to add an entry in - // each corresponding device, therefore we need to keep records of all - // the record indices in each device's record array. - // We allocate such data structure during API entry and free during API exit. - // However there is no guarantee that we receive such callbacks in pairs, we - // maintain a on-going API calls to make sure no memory leaks. - struct CuptiApiCallbackContext { - explicit CuptiApiCallbackContext(std::vector &&r) - : record_indices(std::move(r)) {} - std::vector record_indices; - }; - - const CuptiTracerOptions option_; - CuptiInterface *cupti_interface_; - CuptiTraceCollector *collector_; - absl::node_hash_set callback_contexts_; - std::vector> cuda_event_recorders_; - CuptiDriverApiHookWithCudaEvent(const CuptiDriverApiHookWithCudaEvent &) = - delete; - void operator=(const CuptiDriverApiHookWithCudaEvent &) = delete; -}; - /*static*/ std::string ErrorWithHostname(absl::string_view error_message) { return absl::StrCat(tsl::port::Hostname(), ": ", error_message); } @@ -1868,30 +1246,21 @@ void CuptiTracer::Enable(const CuptiTracerOptions &option, CuptiTraceCollector *collector) { option_ = option; collector_ = collector; - if (option_->enable_event_based_activity) { - option_->enable_activity_api = false; - cupti_driver_api_hook_.reset(new CuptiDriverApiHookWithCudaEvent( - option, cupti_interface_, collector)); - } else { - cupti_driver_api_hook_.reset(new CuptiDriverApiHookWithActivityApi( - option, cupti_interface_, collector)); - } + + cupti_driver_api_hook_.reset(new CuptiDriverApiHookWithActivityApi( + option, cupti_interface_, collector)); Status status = EnableApiTracing(); need_root_access_ |= status.code() == tsl::error::PERMISSION_DENIED; if (!status.ok()) return; - if (option_->enable_activity_api) { - EnableActivityTracing().IgnoreError(); - } + EnableActivityTracing().IgnoreError(); tsl::profiler::AnnotationStack::Enable(true); } void CuptiTracer::Disable() { DisableApiTracing().IgnoreError(); - if (option_->enable_activity_api) { - DisableActivityTracing().IgnoreError(); - } + DisableActivityTracing().IgnoreError(); cupti_interface_->CleanUp(); Finalize().IgnoreError(); cupti_driver_api_hook_->SyncAndFlush().IgnoreError(); diff --git a/third_party/xla/xla/backends/profiler/gpu/cupti_tracer.h b/third_party/xla/xla/backends/profiler/gpu/cupti_tracer.h index 419df5756c1893..59583453779b75 100644 --- a/third_party/xla/xla/backends/profiler/gpu/cupti_tracer.h +++ b/third_party/xla/xla/backends/profiler/gpu/cupti_tracer.h @@ -29,13 +29,6 @@ namespace xla { namespace profiler { struct CuptiTracerOptions { - bool enable_activity_api = true; - - // Use cuda events to enclose the kernel/memcpy to measure device activity. - // enable_event_based_activity, if true, will override the enable_activity_api - // setting. - bool enable_event_based_activity = false; - bool required_callback_api_events = true; // The callback ids that will be enabled and monitored, if empty, all // Callback ids to be enabled using Callback API. diff --git a/third_party/xla/xla/backends/profiler/gpu/cupti_utils.cc b/third_party/xla/xla/backends/profiler/gpu/cupti_utils.cc index 04f400483842df..ee9a542485a48c 100644 --- a/third_party/xla/xla/backends/profiler/gpu/cupti_utils.cc +++ b/third_party/xla/xla/backends/profiler/gpu/cupti_utils.cc @@ -12,17 +12,43 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/base/call_once.h" #include "absl/memory/memory.h" #include "xla/backends/profiler/gpu/cupti_error_manager.h" #include "xla/backends/profiler/gpu/cupti_interface.h" #include "xla/backends/profiler/gpu/cupti_wrapper.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/stringpiece.h" +#include "tsl/util/env_var.h" namespace xla { namespace profiler { +bool IsCuptiUseStubInterface() { + // TODO: b/149634979: Remove this after NVIDIA issue 4459155 resolved. + static constexpr tsl::StringPiece cupti_use_stub_interface_env = + "TF_GPU_CUPTI_USE_STUB_INTERFACE"; + static absl::once_flag once; // NOLINT(clang-diagnostic-unreachable-code) + static bool cupti_use_stub_interface = false; + absl::call_once(once, [&] { + tsl::ReadBoolFromEnvVar(cupti_use_stub_interface_env, false, + &cupti_use_stub_interface) + .IgnoreError(); + if (cupti_use_stub_interface) { + LOG(INFO) << cupti_use_stub_interface_env << " is set to true, " + << "XLA Profiler is using stub CUPTI interface to work around " + << "potential serious bug in CUPTI lib. Such control may be " + << "removed/disabled in future if the known issue is resolved!"; + } + }); + return cupti_use_stub_interface; +} + CuptiInterface* GetCuptiInterface() { static CuptiInterface* cupti_interface = - new CuptiErrorManager(std::make_unique()); + IsCuptiUseStubInterface() + ? new CuptiErrorManager(std::make_unique()) + : new CuptiErrorManager(std::make_unique()); return cupti_interface; } diff --git a/third_party/xla/xla/backends/profiler/gpu/cupti_wrapper.h b/third_party/xla/xla/backends/profiler/gpu/cupti_wrapper.h index bcea89a113245e..ada4af91e6d48e 100644 --- a/third_party/xla/xla/backends/profiler/gpu/cupti_wrapper.h +++ b/third_party/xla/xla/backends/profiler/gpu/cupti_wrapper.h @@ -93,6 +93,76 @@ class CuptiWrapper : public xla::profiler::CuptiInterface { void operator=(const CuptiWrapper&) = delete; }; +// This is an implementation of CuptiWrapper that implements all load bearing +// APIs as no-op. This is a stub that keeps XLA profiler functional, but all +// collected profiles will be empty. +class CuptiWrapperStub : public xla::profiler::CuptiInterface { + public: + CuptiWrapperStub() {} + + ~CuptiWrapperStub() override {} + + // CUPTI activity API + CUptiResult ActivityDisable(CUpti_ActivityKind kind) override; + + CUptiResult ActivityEnable(CUpti_ActivityKind kind) override; + + CUptiResult ActivityFlushAll(uint32_t flag) override; + + CUptiResult ActivityGetNextRecord(uint8_t* buffer, + size_t valid_buffer_size_bytes, + CUpti_Activity** record) override; + + CUptiResult ActivityGetNumDroppedRecords(CUcontext context, + uint32_t stream_id, + size_t* dropped) override; + + CUptiResult ActivityConfigureUnifiedMemoryCounter( + CUpti_ActivityUnifiedMemoryCounterConfig* config, + uint32_t count) override; + + CUptiResult ActivityRegisterCallbacks( + CUpti_BuffersCallbackRequestFunc func_buffer_requested, + CUpti_BuffersCallbackCompleteFunc func_buffer_completed) override; + + CUptiResult GetDeviceId(CUcontext context, uint32_t* deviceId) override; + + CUptiResult GetTimestamp(uint64_t* timestamp) override; + + // cuptiFinalize is only defined in CUDA8 and above. + // To enable it in CUDA8, the environment variable CUPTI_ENABLE_FINALIZE must + // be set to 1. + CUptiResult Finalize() override; + + // CUPTI callback API + CUptiResult EnableCallback(uint32_t enable, CUpti_SubscriberHandle subscriber, + CUpti_CallbackDomain domain, + CUpti_CallbackId cbid) override; + + CUptiResult EnableDomain(uint32_t enable, CUpti_SubscriberHandle subscriber, + CUpti_CallbackDomain domain) override; + + CUptiResult Subscribe(CUpti_SubscriberHandle* subscriber, + CUpti_CallbackFunc callback, void* userdata) override; + + CUptiResult Unsubscribe(CUpti_SubscriberHandle subscriber) override; + + CUptiResult GetResultString(CUptiResult result, const char** str) override; + + CUptiResult GetContextId(CUcontext context, uint32_t* context_id) override; + + CUptiResult GetStreamIdEx(CUcontext context, CUstream stream, + uint8_t per_thread_stream, + uint32_t* stream_id) override; + + void CleanUp() override {} + bool Disabled() const override { return false; } + + private: + CuptiWrapperStub(const CuptiWrapperStub&) = delete; + void operator=(const CuptiWrapperStub&) = delete; +}; + } // namespace profiler } // namespace xla diff --git a/third_party/xla/xla/backends/profiler/gpu/cupti_wrapper_stub.cc b/third_party/xla/xla/backends/profiler/gpu/cupti_wrapper_stub.cc new file mode 100644 index 00000000000000..945fe49e853ded --- /dev/null +++ b/third_party/xla/xla/backends/profiler/gpu/cupti_wrapper_stub.cc @@ -0,0 +1,109 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "xla/backends/profiler/gpu/cupti_wrapper.h" + +namespace xla { +namespace profiler { + +CUptiResult CuptiWrapperStub::ActivityDisable(CUpti_ActivityKind kind) { + return CUPTI_SUCCESS; +} + +CUptiResult CuptiWrapperStub::ActivityEnable(CUpti_ActivityKind kind) { + return CUPTI_SUCCESS; +} + +CUptiResult CuptiWrapperStub::ActivityFlushAll(uint32_t flag) { + return CUPTI_SUCCESS; +} + +CUptiResult CuptiWrapperStub::ActivityGetNextRecord( + uint8_t* buffer, size_t valid_buffer_size_bytes, CUpti_Activity** record) { + return CUPTI_ERROR_MAX_LIMIT_REACHED; +} + +CUptiResult CuptiWrapperStub::ActivityGetNumDroppedRecords(CUcontext context, + uint32_t stream_id, + size_t* dropped) { + *dropped = 0; + return CUPTI_SUCCESS; +} + +CUptiResult CuptiWrapperStub::ActivityConfigureUnifiedMemoryCounter( + CUpti_ActivityUnifiedMemoryCounterConfig* config, uint32_t count) { + return CUPTI_SUCCESS; +} + +CUptiResult CuptiWrapperStub::ActivityRegisterCallbacks( + CUpti_BuffersCallbackRequestFunc func_buffer_requested, + CUpti_BuffersCallbackCompleteFunc func_buffer_completed) { + return CUPTI_SUCCESS; +} + +CUptiResult CuptiWrapperStub::GetDeviceId(CUcontext context, + uint32_t* deviceId) { + return cuptiGetDeviceId(context, deviceId); +} + +CUptiResult CuptiWrapperStub::GetTimestamp(uint64_t* timestamp) { + return cuptiGetTimestamp(timestamp); +} + +CUptiResult CuptiWrapperStub::Finalize() { return CUPTI_SUCCESS; } + +CUptiResult CuptiWrapperStub::EnableCallback(uint32_t enable, + CUpti_SubscriberHandle subscriber, + CUpti_CallbackDomain domain, + CUpti_CallbackId cbid) { + return CUPTI_SUCCESS; +} + +CUptiResult CuptiWrapperStub::EnableDomain(uint32_t enable, + CUpti_SubscriberHandle subscriber, + CUpti_CallbackDomain domain) { + return CUPTI_SUCCESS; +} + +CUptiResult CuptiWrapperStub::Subscribe(CUpti_SubscriberHandle* subscriber, + CUpti_CallbackFunc callback, + void* userdata) { + return CUPTI_SUCCESS; +} + +CUptiResult CuptiWrapperStub::Unsubscribe(CUpti_SubscriberHandle subscriber) { + return CUPTI_SUCCESS; +} + +CUptiResult CuptiWrapperStub::GetResultString(CUptiResult result, + const char** str) { + return cuptiGetResultString(result, str); +} + +CUptiResult CuptiWrapperStub::GetContextId(CUcontext context, + uint32_t* context_id) { + return cuptiGetContextId(context, context_id); +} + +CUptiResult CuptiWrapperStub::GetStreamIdEx(CUcontext context, CUstream stream, + uint8_t per_thread_stream, + uint32_t* stream_id) { + return cuptiGetStreamIdEx(context, stream, per_thread_stream, stream_id); +} + +} // namespace profiler +} // namespace xla diff --git a/third_party/xla/xla/backends/profiler/gpu/device_tracer_cuda.cc b/third_party/xla/xla/backends/profiler/gpu/device_tracer_cuda.cc index b6bd098be21fc4..70530221b68123 100644 --- a/third_party/xla/xla/backends/profiler/gpu/device_tracer_cuda.cc +++ b/third_party/xla/xla/backends/profiler/gpu/device_tracer_cuda.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/container/fixed_array.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "third_party/gpus/cuda/include/cuda.h" #include "xla/backends/profiler/gpu/cupti_collector.h" #include "xla/backends/profiler/gpu/cupti_tracer.h" #include "xla/backends/profiler/gpu/cupti_wrapper.h" @@ -130,12 +131,6 @@ Status GpuTracer::DoStart() { CUPTI_DRIVER_TRACE_CBID_cuStreamSynchronize, }; - bool use_cupti_activity_api = true; - ReadBoolFromEnvVar("TF_GPU_CUPTI_USE_ACTIVITY_API", true, - &use_cupti_activity_api) - .IgnoreError(); - options_.enable_event_based_activity = !use_cupti_activity_api; - bool trace_concurrent_kernels = false; ReadBoolFromEnvVar("TF_GPU_CUPTI_FORCE_CONCURRENT_KERNEL", true, &trace_concurrent_kernels) diff --git a/third_party/xla/xla/backends/profiler/plugin/BUILD b/third_party/xla/xla/backends/profiler/plugin/BUILD index 02f20827ac2379..28419c97d06d28 100644 --- a/third_party/xla/xla/backends/profiler/plugin/BUILD +++ b/third_party/xla/xla/backends/profiler/plugin/BUILD @@ -6,7 +6,6 @@ load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load("@local_tsl//tsl/profiler/builds:build_config.bzl", "tf_profiler_copts") package( - default_visibility = ["//visibility:public"], # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) @@ -24,7 +23,7 @@ cc_library( srcs = ["plugin_tracer.cc"], hdrs = ["plugin_tracer.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = ["//xla:internal"], deps = [ ":profiler_c_api_hdrs", "//xla:status", @@ -34,6 +33,7 @@ cc_library( "@local_tsl//tsl/profiler/lib:profiler_interface", "@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", + "@local_tsl//tsl/profiler/utils:xplane_schema", ], alwayslink = True, ) @@ -56,7 +56,10 @@ cc_library( name = "plugin_tracer_impl", srcs = ["plugin_tracer_impl.cc"], hdrs = ["plugin_tracer_impl.h"], - visibility = ["//visibility:public"], + visibility = [ + "//learning/brain/research/pjrt:__pkg__", + "//xla/pjrt/c:__pkg__", + ], deps = [ ":profiler_c_api_hdrs", ":profiler_error", diff --git a/third_party/xla/xla/backends/profiler/plugin/plugin_tracer.cc b/third_party/xla/xla/backends/profiler/plugin/plugin_tracer.cc index 8176ac55f553e3..078187beeddcf1 100644 --- a/third_party/xla/xla/backends/profiler/plugin/plugin_tracer.cc +++ b/third_party/xla/xla/backends/profiler/plugin/plugin_tracer.cc @@ -21,15 +21,18 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "xla/backends/profiler/plugin/profiler_c_api.h" #include "xla/status.h" #include "tsl/platform/logging.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "tsl/profiler/utils/xplane_schema.h" namespace xla { namespace profiler { +using tensorflow::profiler::XLine; using tensorflow::profiler::XPlane; using tensorflow::profiler::XSpace; @@ -169,6 +172,11 @@ Status PluginTracer::CollectData(XSpace* space) { xspace.ParseFromArray(args.buffer, args.buffer_size_in_bytes); for (XPlane& tpu_plane : *xspace.mutable_planes()) { XPlane* plane = space->add_planes(); + if (tpu_plane.name() == tsl::profiler::kHostThreadsPlaneName) { + for (XLine& xline : *tpu_plane.mutable_lines()) { + xline.set_display_name(absl::StrCat("libtpu:", xline.name())); + } + } plane->Swap(&tpu_plane); } } diff --git a/third_party/xla/xla/backends/profiler/tpu/BUILD b/third_party/xla/xla/backends/profiler/tpu/BUILD index dbc527ed2c3e4e..6f47d30c29b919 100644 --- a/third_party/xla/xla/backends/profiler/tpu/BUILD +++ b/third_party/xla/xla/backends/profiler/tpu/BUILD @@ -3,7 +3,6 @@ load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load("@local_tsl//tsl/profiler/builds:build_config.bzl", "tf_profiler_copts") package( - default_visibility = ["//visibility:public"], # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) @@ -12,7 +11,7 @@ cc_library( name = "tpu_tracer", srcs = if_with_tpu_support(["tpu_tracer.cc"]), copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = ["//xla:internal"], deps = [ "//xla/stream_executor/tpu:tpu_api", "//xla/stream_executor/tpu:tpu_api_dlsym_set_fn", diff --git a/third_party/xla/xla/c/BUILD b/third_party/xla/xla/c/BUILD index 770aa03124c4d8..2a2288a13c9395 100644 --- a/third_party/xla/xla/c/BUILD +++ b/third_party/xla/xla/c/BUILD @@ -1,7 +1,12 @@ +load("@local_tsl//tsl:tsl.bzl", "internal_visibility") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//learning/brain/tfrt/tpu_plugin:__subpackages__", + "//tensorflow/core/common_runtime/next_pluggable_device:__subpackages__", + ]), licenses = ["notice"], ) @@ -10,7 +15,6 @@ cc_library( hdrs = [ "c_api_decl.h", ], - visibility = ["//visibility:public"], deps = [ ], ) diff --git a/third_party/xla/xla/client/BUILD b/third_party/xla/xla/client/BUILD index 241c66874814ab..6c3b7790d88f21 100644 --- a/third_party/xla/xla/client/BUILD +++ b/third_party/xla/xla/client/BUILD @@ -6,6 +6,7 @@ load("@local_tsl//tsl:tsl.default.bzl", "filegroup") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], licenses = ["notice"], ) @@ -24,14 +25,12 @@ filegroup( "**/*.cc", "**/*.h", ]), - visibility = ["//visibility:public"], ) cc_library( name = "global_data", srcs = ["global_data.cc"], hdrs = ["global_data.h"], - visibility = ["//visibility:public"], deps = [ "//xla:service_interface", "//xla:types", @@ -47,7 +46,6 @@ cc_library( name = "padding", srcs = ["padding.cc"], hdrs = ["padding.h"], - visibility = ["//visibility:public"], deps = [ "//xla:statusor", "//xla:types", @@ -72,7 +70,6 @@ cc_library( name = "client", srcs = ["client.cc"], hdrs = ["client.h"], - visibility = ["//visibility:public"], deps = [ ":global_data", ":xla_computation", @@ -98,7 +95,6 @@ cc_library( name = "executable_build_options", srcs = ["executable_build_options.cc"], hdrs = ["executable_build_options.h"], - visibility = ["//visibility:public"], deps = [ "//xla:debug_options_flags", "//xla:execution_options_util", @@ -124,7 +120,6 @@ cc_library( name = "local_client", srcs = ["local_client.cc"], hdrs = ["local_client.h"], - visibility = ["//visibility:public"], deps = [ ":client", ":executable_build_options", @@ -153,7 +148,6 @@ cc_library( name = "compile_only_client", srcs = ["compile_only_client.cc"], hdrs = ["compile_only_client.h"], - visibility = ["//visibility:public"], deps = [ ":client", ":xla_computation", @@ -174,7 +168,6 @@ cc_library( name = "client_library", srcs = ["client_library.cc"], hdrs = ["client_library.h"], - visibility = ["//visibility:public"], deps = [ ":compile_only_client", ":local_client", @@ -197,7 +190,6 @@ cc_library( name = "sharding_builder", srcs = ["sharding_builder.cc"], hdrs = ["sharding_builder.h"], - visibility = ["//visibility:public"], deps = [ "//xla:array", "//xla:shape_tree", diff --git a/third_party/xla/xla/client/executable_build_options.cc b/third_party/xla/xla/client/executable_build_options.cc index 32e7fc717ee16d..168885fe7aa26b 100644 --- a/third_party/xla/xla/client/executable_build_options.cc +++ b/third_party/xla/xla/client/executable_build_options.cc @@ -176,6 +176,12 @@ StatusOr ExecutableBuildOptions::ToProto() const { } output.set_alias_passthrough_params(alias_passthrough_params()); output.set_run_backend_only(run_backend_only()); + if (!allow_spmd_sharding_propagation_to_parameters().empty()) { + output.mutable_allow_spmd_sharding_propagation_to_parameters()->Clear(); + for (bool v : allow_spmd_sharding_propagation_to_parameters()) { + output.mutable_allow_spmd_sharding_propagation_to_parameters()->Add(v); + } + } if (!allow_spmd_sharding_propagation_to_output().empty()) { output.mutable_allow_spmd_sharding_propagation_to_output()->Clear(); for (bool v : allow_spmd_sharding_propagation_to_output()) { @@ -224,6 +230,8 @@ StatusOr ExecutableBuildOptionsFromProto( } output.set_alias_passthrough_params(input.alias_passthrough_params()); output.set_run_backend_only(input.run_backend_only()); + output.set_allow_spmd_sharding_propagation_to_parameters( + input.allow_spmd_sharding_propagation_to_parameters()); output.set_allow_spmd_sharding_propagation_to_output( input.allow_spmd_sharding_propagation_to_output()); *output.mutable_fdo_profile() = input.fdo_profile(); @@ -266,6 +274,15 @@ ExecutionOptions CreateExecutionOptions( execution_options.mutable_auto_spmd_partitioning_mesh_ids()->Add(t); } execution_options.set_deduplicate_hlo(build_options.deduplicate_hlo()); + if (!build_options.allow_spmd_sharding_propagation_to_parameters().empty()) { + execution_options.mutable_allow_spmd_sharding_propagation_to_parameters() + ->Clear(); + for (bool v : + build_options.allow_spmd_sharding_propagation_to_parameters()) { + execution_options.mutable_allow_spmd_sharding_propagation_to_parameters() + ->Add(v); + } + } if (!build_options.allow_spmd_sharding_propagation_to_output().empty()) { execution_options.mutable_allow_spmd_sharding_propagation_to_output() ->Clear(); diff --git a/third_party/xla/xla/client/executable_build_options.h b/third_party/xla/xla/client/executable_build_options.h index 94f76d0c1fbe66..2fdbb5a84eb77e 100644 --- a/third_party/xla/xla/client/executable_build_options.h +++ b/third_party/xla/xla/client/executable_build_options.h @@ -153,13 +153,34 @@ class ExecutableBuildOptions { return *this; } + absl::Span allow_spmd_sharding_propagation_to_parameters() const { + return allow_spmd_sharding_propagation_to_parameters_; + } absl::Span allow_spmd_sharding_propagation_to_output() const { return allow_spmd_sharding_propagation_to_output_; } + bool any_allow_spmd_sharding_propagation_to_parameters() const { + return absl::c_linear_search(allow_spmd_sharding_propagation_to_parameters_, + true); + } bool any_allow_spmd_sharding_propagation_to_output() const { return absl::c_linear_search(allow_spmd_sharding_propagation_to_output_, true); } + // Allows sharding propagation to propagate to the inputs. This changes the + // input shape of the computation (which is undesirable), but it can be used + // to allow to run partial compilation to determine what would be the input + // sharding of a computation if XLA would be allowed to propagate the sharding + // which can be used by higher level framework as a way to query intermediate + // sharding of operations when multiple computation would be chained and + // merged together. + ExecutableBuildOptions& set_allow_spmd_sharding_propagation_to_parameters( + absl::Span allow_spmd_sharding_propagation_to_parameters) { + allow_spmd_sharding_propagation_to_parameters_.assign( + allow_spmd_sharding_propagation_to_parameters.begin(), + allow_spmd_sharding_propagation_to_parameters.end()); + return *this; + } // Allows sharding propagation to propagate to the outputs. This changes the // output shape of the computation (which is undesirable), but it can be used // to allow to run partial compilation to determine what would be the output @@ -233,6 +254,8 @@ class ExecutableBuildOptions { std::optional device_assignment_; bool alias_passthrough_params_ = false; bool run_backend_only_ = false; + absl::InlinedVector allow_spmd_sharding_propagation_to_parameters_ = + {false}; absl::InlinedVector allow_spmd_sharding_propagation_to_output_ = { false}; tsl::thread::ThreadPool* compile_thread_pool_ = nullptr; diff --git a/third_party/xla/xla/client/lib/BUILD b/third_party/xla/xla/client/lib/BUILD index 25e83a7ed8ea83..65b29dcbda7cd4 100644 --- a/third_party/xla/xla/client/lib/BUILD +++ b/third_party/xla/xla/client/lib/BUILD @@ -1,11 +1,13 @@ # Common computation builders for XLA. load("//xla/tests:build_defs.bzl", "generate_backend_suites", "xla_test") +load("@local_tsl//tsl:tsl.bzl", "internal_visibility") load("@local_tsl//tsl:tsl.default.bzl", "filegroup") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility(["//xla/client:friends"]), licenses = ["notice"], ) @@ -16,7 +18,6 @@ filegroup( "**/*.cc", "**/*.h", ]), - visibility = ["//visibility:public"], ) # Generate test_suites for all backends, named "${backend}_tests". @@ -26,7 +27,6 @@ cc_library( name = "arithmetic", srcs = ["arithmetic.cc"], hdrs = ["arithmetic.h"], - visibility = ["//visibility:public"], deps = [ ":constants", "//xla:shape_util", @@ -61,7 +61,6 @@ cc_library( hdrs = [ "comparators.h", ], - visibility = ["//visibility:public"], deps = [ ":constants", "//xla:shape_util", @@ -96,7 +95,6 @@ cc_library( name = "constants", srcs = ["constants.cc"], hdrs = ["constants.h"], - visibility = ["//visibility:public"], deps = [ "//xla:literal_util", "//xla:shape_util", @@ -112,7 +110,6 @@ cc_library( name = "broadcast", srcs = ["broadcast.cc"], hdrs = ["broadcast.h"], - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla:status_macros", @@ -144,7 +141,6 @@ cc_library( name = "conv_grad_size_util", srcs = ["conv_grad_size_util.cc"], hdrs = ["conv_grad_size_util.h"], - visibility = ["//visibility:public"], deps = [ "//xla:status_macros", "//xla/client:padding", @@ -156,7 +152,6 @@ cc_library( name = "dynamic_shaped_ops", srcs = ["dynamic_shaped_ops.cc"], hdrs = ["dynamic_shaped_ops.h"], - visibility = ["//visibility:public"], deps = [ ":constants", "//xla:shape_util", @@ -174,7 +169,6 @@ cc_library( name = "loops", srcs = ["loops.cc"], hdrs = ["loops.h"], - visibility = ["//visibility:public"], deps = [ ":constants", "//xla:shape_util", @@ -191,7 +185,6 @@ cc_library( name = "math", srcs = ["math.cc"], hdrs = ["math.h"], - visibility = ["//visibility:public"], deps = [ ":arithmetic", ":constants", @@ -229,7 +222,6 @@ cc_library( name = "matrix", srcs = ["matrix.cc"], hdrs = ["matrix.h"], - visibility = ["//visibility:public"], deps = [ ":arithmetic", ":constants", @@ -275,7 +267,6 @@ cc_library( name = "pooling", srcs = ["pooling.cc"], hdrs = ["pooling.h"], - visibility = ["//visibility:public"], deps = [ ":arithmetic", ":constants", @@ -302,7 +293,6 @@ cc_library( name = "prng", srcs = ["prng.cc"], hdrs = ["prng.h"], - visibility = ["//visibility:public"], deps = [ ":constants", "//xla:shape_util", @@ -335,7 +325,6 @@ cc_library( name = "qr", srcs = ["qr.cc"], hdrs = ["qr.h"], - visibility = ["//visibility:public"], deps = [ ":arithmetic", ":constants", @@ -381,7 +370,6 @@ cc_library( name = "lu_decomposition", srcs = ["lu_decomposition.cc"], hdrs = ["lu_decomposition.h"], - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla:statusor", @@ -395,7 +383,6 @@ cc_library( name = "approx_topk", srcs = ["approx_topk.cc"], hdrs = ["approx_topk.h"], - visibility = ["//visibility:public"], deps = [ ":approx_topk_shape", "//xla:shape_util", @@ -412,7 +399,6 @@ cc_library( name = "approx_topk_shape", srcs = ["approx_topk_shape.cc"], hdrs = ["approx_topk_shape.h"], - visibility = ["//visibility:public"], deps = [ "//xla:statusor", "//xla:util", @@ -423,7 +409,6 @@ cc_library( name = "slicing", srcs = ["slicing.cc"], hdrs = ["slicing.h"], - visibility = ["//visibility:public"], deps = [ ":arithmetic", ":constants", @@ -458,7 +443,6 @@ cc_library( name = "sorting", srcs = ["sorting.cc"], hdrs = ["sorting.h"], - visibility = ["//visibility:public"], deps = [ ":comparators", ":constants", @@ -489,7 +473,6 @@ xla_test( cc_library( name = "quantize", hdrs = ["quantize.h"], - visibility = ["//visibility:public"], deps = [ ":constants", "//xla:types", @@ -524,7 +507,6 @@ cc_library( name = "testing", srcs = ["testing.cc"], hdrs = ["testing.h"], - visibility = ["//visibility:public"], deps = [ "//xla:execution_options_util", "//xla:literal", @@ -547,7 +529,6 @@ cc_library( name = "self_adjoint_eig", srcs = ["self_adjoint_eig.cc"], hdrs = ["self_adjoint_eig.h"], - visibility = ["//visibility:public"], deps = [ ":arithmetic", ":comparators", @@ -599,7 +580,6 @@ cc_library( name = "svd", srcs = ["svd.cc"], hdrs = ["svd.h"], - visibility = ["//visibility:public"], deps = [ ":arithmetic", ":comparators", @@ -650,7 +630,6 @@ cc_library( name = "tridiagonal", srcs = ["tridiagonal.cc"], hdrs = ["tridiagonal.h"], - visibility = ["//visibility:public"], deps = [ ":constants", ":loops", @@ -688,7 +667,6 @@ cc_library( name = "logdet", srcs = ["logdet.cc"], hdrs = ["logdet.h"], - visibility = ["//visibility:public"], deps = [ ":arithmetic", ":constants", @@ -733,7 +711,6 @@ cc_library( name = "tuple", srcs = ["tuple.cc"], hdrs = ["tuple.h"], - visibility = ["//visibility:public"], deps = [ "//xla:shape_tree", "//xla:shape_util", diff --git a/third_party/xla/xla/client/lib/math.cc b/third_party/xla/xla/client/lib/math.cc index 39e58cd77e2947..545d4449a72938 100644 --- a/third_party/xla/xla/client/lib/math.cc +++ b/third_party/xla/xla/client/lib/math.cc @@ -336,26 +336,6 @@ static XlaOp ErfImpl32(XlaOp x) { EvaluatePolynomial(x2, kBeta); } -XlaOp Erf(XlaOp x) { - auto& b = *x.builder(); - return b.ReportErrorOrReturn([&]() -> StatusOr { - TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Erf", x)); - TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x)); - // erf(x) = - // erf_impl(x) if x < 1 - // 1 - erfc_impl(x) otherwise - if (shape.element_type() == F64) { - return Select(Lt(Abs(x), ScalarLike(x, 1)), ErfImpl64(x), - ScalarLike(x, 1) - ErfcImpl64(x)); - } - // Erf(c)Impl don't have enough precision when run with bf16 intermediates - // (not surprising!), so upcast to f32 in this case. - return DoWithUpcastToF32( - x, {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}, - [](XlaOp x) { return ErfImpl32(x); }); - }); -} - namespace { // Approximation for the inverse error function from diff --git a/third_party/xla/xla/client/lib/math.h b/third_party/xla/xla/client/lib/math.h index 6c0d0ae05388a9..74b8a387a416de 100644 --- a/third_party/xla/xla/client/lib/math.h +++ b/third_party/xla/xla/client/lib/math.h @@ -46,9 +46,6 @@ XlaOp Reciprocal(XlaOp operand); // Computes an approximation of the error function complement (1 - erf(x)). XlaOp Erfc(XlaOp x); -// Computes an approximation of the error function. -XlaOp Erf(XlaOp x); - // Computes an approximation of the inverse of the error function. XlaOp ErfInv(XlaOp x); diff --git a/third_party/xla/xla/client/value_inference.cc b/third_party/xla/xla/client/value_inference.cc index 179f232862a83d..8d09bc57b5759b 100644 --- a/third_party/xla/xla/client/value_inference.cc +++ b/third_party/xla/xla/client/value_inference.cc @@ -1113,6 +1113,7 @@ StatusOr PostorderDFSVisitor::AnalyzeIsDynamic( case HloOpcode::kCollectivePermuteDone: case HloOpcode::kCos: case HloOpcode::kClz: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: diff --git a/third_party/xla/xla/client/xla_builder.cc b/third_party/xla/xla/client/xla_builder.cc index e3ee1aedbec6ed..64a09fe51edba9 100644 --- a/third_party/xla/xla/client/xla_builder.cc +++ b/third_party/xla/xla/client/xla_builder.cc @@ -837,6 +837,58 @@ StatusOr XlaBuilder::Build(int64_t root_id, return OkStatus(); } +XlaOp XlaBuilder::DynamicBroadcastInDim( + const XlaOp operand, const XlaOp output_dimensions, + absl::Span broadcast_dimensions, const Shape& output_shape) { + return ReportErrorOrReturn([&]() -> absl::StatusOr { + TF_RET_CHECK(!output_shape.is_dynamic()); + TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); + + int64_t operand_rank = operand_shape->rank(); + int64_t result_rank = output_shape.rank(); + int64_t broadcast_dimensions_size = broadcast_dimensions.size(); + if (broadcast_dimensions_size != operand_rank) { + return InvalidArgument( + "broadcast_dimensions size (%d) does not match operand rank (%d)", + broadcast_dimensions_size, operand_rank); + } + + if (result_rank < operand_rank) { + return InvalidArgument("result rank (%d) is less than operand rank (%d)", + result_rank, operand_rank); + } + + for (int64_t i = 0; i != broadcast_dimensions_size; ++i) { + int64_t dim_index = broadcast_dimensions[i]; + if (dim_index < 0 || dim_index >= result_rank) { + return InvalidArgument( + "broadcast_dimensions contains invalid value %d for result with " + "rank %d", + dim_index, result_rank); + } + + int64_t dim_size = operand_shape->dimensions(i); + int64_t result_dim_size = output_shape.dimensions(dim_index); + + if (dim_size != 1 && dim_size != result_dim_size && + dim_size != Shape::kUnboundedSize) { + return InvalidArgument( + "size of operand dimension %d (%d) is not compatible with size of " + "result dimension %d (%d)", + i, dim_size, dim_index, result_dim_size); + } + } + + return xla::CustomCall( + operand.builder(), "mhlo.dynamic_broadcast_in_dim", + /*operands=*/{operand, output_dimensions}, + /*shape=*/output_shape, + /*opaque=*/ + absl::StrCat("{broadcast_dimensions=[", + absl::StrJoin(broadcast_dimensions, ","), "]}")); + }); +} + StatusOr XlaBuilder::InDimBroadcast( const Shape& shape, XlaOp operand, absl::Span broadcast_dimensions) { @@ -862,8 +914,8 @@ StatusOr XlaBuilder::InDimBroadcast( << " i: " << i << ", shape: " << shape.ToString() << ", operand_shape: " << operand_shape->ToString(); } else { - // Non-broadcast dimensions must not be dynamic. - TF_RET_CHECK(!shape.is_dynamic_dimension(i)); + // Non-broadcast dimensions must be static. + TF_RET_CHECK(shape.is_static_dimension(i)); } } return AddInstruction(std::move(instr), HloOpcode::kBroadcast, {operand}); @@ -898,7 +950,7 @@ StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, operand_shape->is_dynamic_dimension(i)); } else { TF_RET_CHECK(operand_shape->dimensions(i) == 1 && - !operand_shape->is_dynamic_dimension(i)) + operand_shape->is_static_dimension(i)) << "An explicit broadcast sequence requires the broadcasted " "dimensions to be trivial; operand shape: " << *operand_shape << "; output_shape: " << output_shape; @@ -935,6 +987,57 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, XlaOp operand) { }); } +namespace { + +// Broadcasts an origin XLA op to the rank of target_shape. +// Does not broadcast rank dimensions to match, only expands rank. +// Is identity function if origin rank matches target rank. +StatusOr BroadcastToTargetRank( + XlaOp origin, const Shape& origin_shape, const Shape& target_shape, + absl::Span broadcast_dimensions) { + const int64_t origin_rank = origin_shape.rank(); + const int64_t target_rank = target_shape.rank(); + + // Identity op if ranks match, shold never be larger than target. + if (origin_rank >= target_rank) { + return origin; + } + + // Update target_size with origin sizes using broadcast_dimensions + absl::Span target_dimensions = target_shape.dimensions(); + std::vector target_size{target_dimensions.begin(), + target_dimensions.end()}; + for (int64_t origin_dim = 0; origin_dim < origin_rank; origin_dim++) { + int64_t target_dim = broadcast_dimensions[origin_dim]; + target_size[target_dim] = origin_shape.dimensions(origin_dim); + } + return xla::BroadcastInDim(origin, target_size, broadcast_dimensions); +} + +// For ternary ops, only scalar broadcasting is supported. +// Return the non-scalar shape that all scalars should be broadcasted too +// Returns status if non-scalar operands do not match. +StatusOr> InferScalarBroadcastShape( + const Shape* lhs_shape, const Shape* rhs_shape, const Shape* ehs_shape) { + // The shape is not scalar, it may have unbounded/bounded dynamic + // dimensions. + std::optional broadcasted_shape; + for (const Shape* shape : {lhs_shape, rhs_shape, ehs_shape}) { + if (!shape->IsArray() || shape->rank() == 0) continue; + if (!broadcasted_shape.has_value()) { + broadcasted_shape = ShapeUtil::MakeStaticShape(*shape); + } + // TODO(jpienaar): The case where we need to compute the broadcasted + // shape by considering multiple of the shapes is not implemented. + // Consider reusing getBroadcastedType from mlir/Dialect/Traits.h. + TF_RET_CHECK(ShapeUtil::SameDimensions(broadcasted_shape.value(), *shape)) + << "Unimplemented implicit broadcast."; + } + return broadcasted_shape; +} + +} // namespace + XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions, std::optional direction, @@ -946,40 +1049,12 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs, Shape shape, ShapeInference::InferBinaryOpShape( binop, *lhs_shape, *rhs_shape, broadcast_dimensions)); - const int64_t lhs_rank = lhs_shape->rank(); - const int64_t rhs_rank = rhs_shape->rank(); - - XlaOp updated_lhs = lhs; - XlaOp updated_rhs = rhs; - if (!broadcast_dimensions.empty() && lhs_rank != rhs_rank) { - const bool should_broadcast_lhs = lhs_rank < rhs_rank; - XlaOp from = should_broadcast_lhs ? lhs : rhs; - const Shape& from_shape = should_broadcast_lhs ? *lhs_shape : *rhs_shape; - - std::vector to_size; - std::vector to_size_is_dynamic; - const auto rank = shape.rank(); - to_size.reserve(rank); - to_size_is_dynamic.reserve(rank); - for (int i = 0; i < rank; i++) { - to_size.push_back(shape.dimensions(i)); - to_size_is_dynamic.push_back(false); - } - for (int64_t from_dim = 0; from_dim < from_shape.rank(); from_dim++) { - int64_t to_dim = broadcast_dimensions[from_dim]; - to_size[to_dim] = from_shape.dimensions(from_dim); - to_size_is_dynamic[to_dim] = from_shape.is_dynamic_dimension(from_dim); - } - - const Shape& broadcasted_shape = ShapeUtil::MakeShape( - from_shape.element_type(), to_size, to_size_is_dynamic); - TF_ASSIGN_OR_RETURN( - XlaOp broadcasted_operand, - InDimBroadcast(broadcasted_shape, from, broadcast_dimensions)); - - updated_lhs = should_broadcast_lhs ? broadcasted_operand : lhs; - updated_rhs = !should_broadcast_lhs ? broadcasted_operand : rhs; - } + TF_ASSIGN_OR_RETURN( + XlaOp updated_lhs, + BroadcastToTargetRank(lhs, *lhs_shape, shape, broadcast_dimensions)); + TF_ASSIGN_OR_RETURN( + XlaOp updated_rhs, + BroadcastToTargetRank(rhs, *rhs_shape, shape, broadcast_dimensions)); TF_ASSIGN_OR_RETURN(const Shape* updated_lhs_shape, GetShapePtr(updated_lhs)); @@ -1058,23 +1133,11 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs) { TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs)); TF_ASSIGN_OR_RETURN(const Shape* ehs_shape, GetShapePtr(ehs)); - // The shape is not scalar, it may have unbounded/bounded dynamic - // dimensions. - std::optional non_scalar_shape; - for (const Shape* shape : {lhs_shape, rhs_shape, ehs_shape}) { - if (shape->IsArray() && shape->rank() != 0) { - if (non_scalar_shape.has_value()) { - // TODO(jpienaar): The case where we need to compute the broadcasted - // shape by considering multiple of the shapes is not implemented. - // Consider reusing getBroadcastedType from mlir/Dialect/Traits.h. - TF_RET_CHECK( - ShapeUtil::SameDimensions(non_scalar_shape.value(), *shape)) - << "Unimplemented implicit broadcast."; - } else { - non_scalar_shape = ShapeUtil::MakeStaticShape(*shape); - } - } - } + TF_ASSIGN_OR_RETURN( + std::optional non_scalar_shape, + InferScalarBroadcastShape(lhs_shape, rhs_shape, ehs_shape)); + + // Scalar broadcast if mix of scalars and non-scalars if (non_scalar_shape.has_value()) { bool is_unbounded_dynamic = non_scalar_shape->is_unbounded_dynamic(); if (ShapeUtil::IsScalar(*lhs_shape)) { @@ -3921,7 +3984,7 @@ XlaOp XlaBuilder::GetDimensionSize(XlaOp operand, int64_t dimension) { *operand_shape, dimension)); // Calling GetDimensionSize on a static dimension returns a constant // instruction. - if (!operand_shape->is_dynamic_dimension(dimension)) { + if (operand_shape->is_static_dimension(dimension)) { return ConstantR0(this, operand_shape->dimensions(dimension)); } *instr.mutable_shape() = shape.ToProto(); @@ -4423,6 +4486,13 @@ XlaOp BroadcastInDim(const XlaOp operand, broadcast_dimensions); } +XlaOp DynamicBroadcastInDim(const XlaOp operand, const XlaOp output_dimensions, + absl::Span broadcast_dimensions, + const Shape& output_shape) { + return operand.builder()->DynamicBroadcastInDim( + operand, output_dimensions, broadcast_dimensions, output_shape); +} + XlaOp Copy(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kCopy, operand); } @@ -5165,6 +5235,9 @@ XlaOp Log(const XlaOp operand) { XlaOp Log1p(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kLog1p, operand); } +XlaOp Erf(const XlaOp operand) { + return operand.builder()->UnaryOp(HloOpcode::kErf, operand); +} XlaOp Logistic(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kLogistic, operand); } diff --git a/third_party/xla/xla/client/xla_builder.h b/third_party/xla/xla/client/xla_builder.h index 8bba0d85537b9c..83366428da2572 100644 --- a/third_party/xla/xla/client/xla_builder.h +++ b/third_party/xla/xla/client/xla_builder.h @@ -512,6 +512,14 @@ class XlaBuilder { XlaOp BroadcastInDim(XlaOp operand, absl::Span out_dim_size, absl::Span broadcast_dimensions); + // This is an experimental API for creating the mhlo.dynamic_broadcast_in_dim + // op from the XlaBuilder. This is only intended for export to MHLO or + // StableHLO, and cannot be compiled. Only static output_dimensions are + // allowed, and broadcast_dimensions is verified. + XlaOp DynamicBroadcastInDim(XlaOp operand, XlaOp output_dimensions, + absl::Span broadcast_dimensions, + const Shape& output_shape); + XlaOp Pad(XlaOp operand, XlaOp padding_value, const PaddingConfig& padding_config); XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno, @@ -1177,6 +1185,11 @@ class XlaBuilder { absl::Span out_dim_size, absl::Span broadcast_dimensions); + friend XlaOp DynamicBroadcastInDim( + XlaOp operand, XlaOp output_dimensions, + absl::Span broadcast_dimensions, + const Shape& output_shape); + friend XlaOp Copy(XlaOp operand); friend XlaOp Pad(XlaOp operand, XlaOp padding_value, @@ -1494,6 +1507,7 @@ class XlaBuilder { friend XlaOp Abs(XlaOp operand); friend XlaOp Atan2(XlaOp y, XlaOp x, absl::Span broadcast_dimensions); + friend XlaOp Erf(XlaOp operand); friend XlaOp Exp(XlaOp operand); friend XlaOp Expm1(XlaOp operand); friend XlaOp Floor(XlaOp operand); @@ -1859,6 +1873,15 @@ XlaOp Broadcast(XlaOp operand, absl::Span broadcast_sizes); XlaOp BroadcastInDim(XlaOp operand, absl::Span out_dim_size, absl::Span broadcast_dimensions); +// This is an experimental API for creating the mhlo.dynamic_broadcast_in_dim +// op from the XlaBuilder. This is only intended for export to MHLO or +// StableHLO, and cannot be compiled. See +// https://www.tensorflow.org/mlir/hlo_ops#mhlodynamic_broadcast_in_dim_mhlodynamicbroadcastindimop. +// for the op semantics. +XlaOp DynamicBroadcastInDim(XlaOp operand, XlaOp output_dimensions, + absl::Span broadcast_dimensions, + const Shape& output_shape); + // Copies the input operand to the output. This operation is for internal // purpose and is only used by the compiler for optimization purposes or to // ensure correctness. The XLA client should never have to generate this @@ -2545,6 +2568,9 @@ XlaOp Abs(XlaOp operand); XlaOp Atan2(XlaOp y, XlaOp x, absl::Span broadcast_dimensions = {}); +// Enqueues an erf instruction onto the computation. +XlaOp Erf(XlaOp operand); + // Enqueues an exp instruction onto the computation. XlaOp Exp(XlaOp operand); diff --git a/third_party/xla/xla/client/xla_builder_test.cc b/third_party/xla/xla/client/xla_builder_test.cc index e030b9ac351cbe..87a91f87f064b1 100644 --- a/third_party/xla/xla/client/xla_builder_test.cc +++ b/third_party/xla/xla/client/xla_builder_test.cc @@ -249,6 +249,30 @@ TEST(XlaBuilderTest, XPlusX) { EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(0), m::Parameter(0)))); } +TEST(XlaBuilderTest, TestBinaryOpImplicitBroadcast) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(Shape lhs, ParseShape("f32[1]")); + TF_ASSERT_OK_AND_ASSIGN(Shape rhs, ParseShape("f32[2, 2]")); + TF_ASSERT_OK_AND_ASSIGN(Shape expected, ParseShape("f32[2,2]")); + Add(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), + /*broadcast_dimensions=*/{1}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, TestBinaryOpImplicitBroadcastBounded) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(Shape lhs, ParseShape("f32[1]")); + TF_ASSERT_OK_AND_ASSIGN(Shape rhs, ParseShape("f32[<=2, <=2]")); + TF_ASSERT_OK_AND_ASSIGN(Shape expected, ParseShape("f32[<=2, <=2]")); + Add(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), + /*broadcast_dimensions=*/{1}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + TEST(XlaBuilderTest, ShapeInferenceError) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(U32, {2, 4, 6}), "x"); @@ -1584,6 +1608,76 @@ TEST(XlaBuilderTest, TopKDimensions) { EXPECT_EQ(root->shape().tuple_shapes(1).dimensions(1), k); } +//============================================================================// +// Experimental Test +//============================================================================// + +TEST(XlaBuilderTest, DynamicBroadcastInDimExportSuccess) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape& operand, ParseShape("f32[1, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape& output_dimensions, ParseShape("s32[3]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape& output_shape, + ParseShape("f32[1, 2, 3]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape& expected, ParseShape("f32[1, 2, 3]")); + DynamicBroadcastInDim( + Parameter(&b, 0, operand, "operand"), + Parameter(&b, 1, output_dimensions, "output_dimensions"), + /*broadcast_dimensions=*/{1, 2}, output_shape); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(b)); + EXPECT_THAT(module->ToString(), HasSubstr("mhlo.dynamic_broadcast_in_dim")); + EXPECT_THAT(module->ToString(), HasSubstr("broadcast_dimensions=[1,2]")); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, DynamicBroadcastInDimNonBroadcastDimSizeGreaterThanOne) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape& operand, ParseShape("f32[2, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape& output_dimensions, ParseShape("s32[3]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape& output_shape, + ParseShape("f32[2, 2, 3]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape& expected, ParseShape("f32[2, 2, 3]")); + DynamicBroadcastInDim( + Parameter(&b, 0, operand, "operand"), + Parameter(&b, 1, output_dimensions, "output_dimensions"), + /*broadcast_dimensions=*/{1, 2}, output_shape); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(b)); + EXPECT_THAT(module->ToString(), HasSubstr("mhlo.dynamic_broadcast_in_dim")); + EXPECT_THAT(module->ToString(), HasSubstr("broadcast_dimensions=[1,2]")); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, DynamicBroadcastInDimIncompatibleBroadcastSize) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape& operand, ParseShape("f32[2, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape& output_dimensions, ParseShape("s32[3]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape& output_shape, + ParseShape("f32[2, 3, 3]")); + DynamicBroadcastInDim( + Parameter(&b, 0, operand, "operand"), + Parameter(&b, 1, output_dimensions, "output_dimensions"), + /*broadcast_dimensions=*/{1, 2}, output_shape); + EXPECT_THAT( + BuildHloModule(b), + StatusIs(_, HasSubstr("size of operand dimension 0 (2) is not compatible " + "with size of result dimension 1 (3)"))); +} + +TEST(XlaBuilderTest, DynamicBroadcastInDimUnsupportedDynamicResultSize) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape& operand, ParseShape("f32[1, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape& output_dimensions, ParseShape("s32[3]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape& output_shape, + ParseShape("f32[1, 2, ?]")); + DynamicBroadcastInDim( + Parameter(&b, 0, operand, "operand"), + Parameter(&b, 1, output_dimensions, "output_dimensions"), + /*broadcast_dimensions=*/{1, 2}, output_shape); + EXPECT_THAT(BuildHloModule(b), + StatusIs(_, HasSubstr("!output_shape.is_dynamic()"))); +} + //============================================================================// // Unbounded Dynamism Test //============================================================================// @@ -1774,11 +1868,8 @@ TEST(XlaBuilderTest, UnboundedClampUnsupportedImplicitBroadcast1) { TF_ASSERT_OK_AND_ASSIGN(Shape ehs, ParseShape("f32[?, 10]")); Clamp(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), Parameter(&b, 2, ehs, "ehs")); - EXPECT_THAT( - BuildHloModule(b), - StatusIs(_, - HasSubstr("ShapeUtil::SameDimensions(non_scalar_shape.value(), " - "*shape) Unimplemented implicit broadcast."))); + EXPECT_THAT(BuildHloModule(b), + StatusIs(_, HasSubstr("Unimplemented implicit broadcast."))); } TEST(XlaBuilderTest, UnboundedClampUnsupportedImplicitBroadcast2) { @@ -1788,11 +1879,8 @@ TEST(XlaBuilderTest, UnboundedClampUnsupportedImplicitBroadcast2) { TF_ASSERT_OK_AND_ASSIGN(Shape ehs, ParseShape("f32[]")); Clamp(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), Parameter(&b, 2, ehs, "ehs")); - EXPECT_THAT( - BuildHloModule(b), - StatusIs(_, - HasSubstr( - "!is_unbounded_dynamic Unimplemented implicit broadcast."))); + EXPECT_THAT(BuildHloModule(b), + StatusIs(_, HasSubstr("Unimplemented implicit broadcast."))); } TEST(XlaBuilderTest, UnboundedClampUnsupportedImplicitBroadcast3) { @@ -2196,11 +2284,8 @@ TEST(XlaBuilderTest, UnboundedSelectUnsupportedImplicitBroadcast1) { TF_ASSERT_OK_AND_ASSIGN(Shape ehs, ParseShape("f32[?, 10]")); Select(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), Parameter(&b, 2, ehs, "ehs")); - EXPECT_THAT( - BuildHloModule(b), - StatusIs(_, - HasSubstr("ShapeUtil::SameDimensions(non_scalar_shape.value(), " - "*shape) Unimplemented implicit broadcast."))); + EXPECT_THAT(BuildHloModule(b), + StatusIs(_, HasSubstr("Unimplemented implicit broadcast."))); } TEST(XlaBuilderTest, UnboundedSelectUnsupportedImplicitBroadcast2) { @@ -2301,6 +2386,7 @@ INSTANTIATE_TEST_SUITE_P(UnboundedDynamism, XlaBuilderUnboundedUnaryOpTest, {"f32[?]", "f32[?]", &Ceil}, {"u32[?]", "u32[?]", &Clz}, {"f32[?]", "f32[?]", &Cos}, + {"f32[?]", "f32[?]", &Erf}, {"f32[?]", "f32[?]", &Exp}, {"f32[?]", "f32[?]", &Expm1}, {"f32[?]", "f32[?]", &Floor}, diff --git a/third_party/xla/xla/comparison_util.h b/third_party/xla/xla/comparison_util.h index 49bbc77040b2f5..1127fa17ea1134 100644 --- a/third_party/xla/xla/comparison_util.h +++ b/third_party/xla/xla/comparison_util.h @@ -61,6 +61,13 @@ class Comparison { kPartial, }; + friend absl::string_view ComparisonOrderToString(Comparison::Order order); + + template + friend void AbslStringify(Sink& sink, const Order& p) { + absl::Format(&sink, "%s", ComparisonOrderToString(p)); + } + // Represents different comparison operations. enum class Direction : uint8_t { kEq, @@ -228,7 +235,6 @@ inline std::ostream& operator<<(std::ostream& os, const Comparison& cmp) { std::string ComparisonDirectionToString(Comparison::Direction direction); std::string ComparisonTypeToString(Comparison::Type type); absl::string_view ComparisonPrimitiveTypeToString(PrimitiveType type); -absl::string_view ComparisonOrderToString(Comparison::Order order); StatusOr StringToComparisonDirection( absl::string_view direction); diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index bffd2033211862..496943c6dd5199 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -22,10 +22,12 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/base/call_once.h" #include "absl/container/flat_hash_map.h" #include "absl/container/node_hash_map.h" #include "absl/strings/ascii.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" @@ -85,7 +87,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_cpu_fast_math_honor_division(true); // TODO(AyanmoI): Remove this flag when cuDNN FMHA is fully supported. - opts.set_xla_gpu_enable_cudnn_fmha(false); + opts.set_xla_gpu_enable_cudnn_fmha(true); opts.set_xla_gpu_fused_attention_use_cudnn_rng(false); @@ -98,9 +100,6 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { // flag. opts.set_xla_gpu_enable_cublaslt(false); -#if CUDA_VERSION >= 12000 - // The new GPU runtime causes occasional hangs on cuda 11, possibly due to the - // use of cuda graphs. Disable using cuda graphs. opts.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION); opts.add_xla_gpu_enable_command_buffer(DebugOptions::CUBLAS); opts.add_xla_gpu_enable_command_buffer(DebugOptions::CUSTOM_CALL); @@ -108,7 +107,6 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_graph_min_graph_size(5); opts.set_xla_gpu_graph_enable_concurrent_region(false); opts.set_xla_gpu_graph_eviction_timeout_seconds(60); -#endif // Despite the name, fast min/max on GPUs does not seem to be any faster, and // adds very counter-intuitive "NaN-swallowing" behavior. @@ -141,9 +139,11 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_enable_xla_runtime_executable(false); opts.set_xla_gpu_enable_custom_fusions(false); + opts.set_xla_gpu_enable_address_computation_fusion(true); opts.set_xla_gpu_nccl_termination_timeout_seconds(-1); opts.set_xla_gpu_enable_shared_constants(true); opts.set_xla_gpu_enable_nccl_user_buffers(false); + opts.set_xla_gpu_enable_nccl_comm_splitting(false); // Set 4GB space limit for redzone scratch allocator. opts.set_xla_gpu_redzone_scratch_max_megabytes(1LL << 12); @@ -193,7 +193,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_exhaustive_tiling_search(false); - opts.set_xla_gpu_enable_priority_fusion(false); + opts.set_xla_gpu_enable_priority_fusion(true); opts.set_xla_gpu_auto_spmd_partitioning_memory_budget_gb(0); opts.set_xla_gpu_auto_spmd_partitioning_memory_budget_ratio(1.1); @@ -223,6 +223,12 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_enable_libnvptxcompiler(false); + opts.set_xla_gpu_enable_dot_strength_reduction(true); + + opts.set_xla_gpu_enable_bf16_6way_gemm(false); + opts.set_xla_gpu_nccl_collective_max_nchannels(0); + opts.set_xla_gpu_nccl_p2p_max_nchannels(0); + return opts; } @@ -402,17 +408,76 @@ void MakeDebugOptionsFlags(std::vector* flag_list, // Custom "sub-parser" lambda for xla_gpu_enable_command_buffer. auto setter_for_xla_gpu_enable_command_buffer = - [debug_options](const std::string& values) { - debug_options->clear_xla_gpu_enable_command_buffer(); - for (const absl::string_view value : absl::StrSplit(values, ',')) { + [debug_options](const std::string& input) { + auto is_command_type = [](absl::string_view value) { + DebugOptions::CommandBufferCmdType cmd_type; + return DebugOptions::CommandBufferCmdType_Parse( + absl::AsciiStrToUpper(value), &cmd_type); + }; + + auto is_add_or_remove_command_type = [&](absl::string_view value) { + if (absl::StartsWith(value, "+") || absl::StartsWith(value, "-")) { + return (is_command_type(value.substr(1))); + } + return false; + }; + + auto parse_command_type = [](absl::string_view value) { DebugOptions::CommandBufferCmdType cmd_type; - if (!DebugOptions::CommandBufferCmdType_Parse( - absl::AsciiStrToUpper(value), &cmd_type)) { - return false; + DebugOptions::CommandBufferCmdType_Parse(absl::AsciiStrToUpper(value), + &cmd_type); + return cmd_type; + }; + + auto erase_command_type = [](tsl::protobuf::RepeatedField* enabled, + DebugOptions::CommandBufferCmdType type) { + auto it = enabled->begin(); + while (it != enabled->end()) { + if (*it == type) { + it = enabled->erase(it); + } else { + it++; + } } - debug_options->add_xla_gpu_enable_command_buffer(cmd_type); + }; + + // Disable command buffers by clearing a set of supported commands. + if (input.empty()) { + debug_options->clear_xla_gpu_enable_command_buffer(); + return true; } - return true; + + std::vector values = absl::StrSplit(input, ','); + + // Overwrite a set of supported commands with a flag. + if (absl::c_all_of(values, is_command_type)) { + debug_options->clear_xla_gpu_enable_command_buffer(); + for (const absl::string_view value : values) { + debug_options->add_xla_gpu_enable_command_buffer( + parse_command_type(value)); + } + return true; + } + + // Add or remove a commands from a default set. + if (absl::c_all_of(values, is_add_or_remove_command_type)) { + for (const absl::string_view value : values) { + DebugOptions::CommandBufferCmdType cmd_type = + parse_command_type(value.substr(1)); + if (absl::StartsWith(value, "+")) { + debug_options->add_xla_gpu_enable_command_buffer(cmd_type); + } else if (absl::StartsWith(value, "-")) { + tsl::protobuf::RepeatedField* enabled = + debug_options->mutable_xla_gpu_enable_command_buffer(); + erase_command_type(enabled, cmd_type); + } + return true; + } + } + + // Return an error if flag value was not recognized as one of the + // supported modes. + return false; }; // Custom "sub-parser" for xla_fuel. Note that ConsumeFuel does not do any @@ -1027,7 +1092,10 @@ void MakeDebugOptionsFlags(std::vector* flag_list, flag_list->push_back(tsl::Flag( "xla_gpu_enable_command_buffer", setter_for_xla_gpu_enable_command_buffer, command_types_to_string(debug_options->xla_gpu_enable_command_buffer()), - "The types of the commands that are recorded into command buffers")); + "The types of the commands that are recorded into command buffers. It" + " can either be a list of command types or a list of command types with" + " + and - as prefix, which indicate adding or removing a command type" + " to/from the default list.")); flag_list->push_back(tsl::Flag( "xla_gpu_graph_num_runs_to_instantiate", int32_setter_for( @@ -1094,6 +1162,12 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "Limits custom fusion only to fusions which match this regular " "expression. Default is all custom fusions registerered in a current " "process.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_address_computation_fusion", + bool_setter_for( + &DebugOptions::set_xla_gpu_enable_address_computation_fusion), + debug_options->xla_gpu_enable_address_computation_fusion(), + "Whether to enable XLA address computation fusion")); flag_list->push_back(tsl::Flag( "xla_gpu_nccl_termination_timeout_seconds", int64_setter_for( @@ -1112,6 +1186,12 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "Enables NCCL User Buffer Registration. collective_memory_size in the " "allocator config must also be set to a non-zero value that is large " "enough to meet peak collective memory usage.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_nccl_comm_splitting", + bool_setter_for(&DebugOptions::set_xla_gpu_enable_nccl_comm_splitting), + debug_options->xla_gpu_enable_nccl_comm_splitting(), + "Enables NCCL communicator splitting which allows sharing NCCL resources " + "between different NCCL cliques.")); flag_list->push_back(tsl::Flag( "xla_gpu_redzone_scratch_max_megabytes", int64_setter_for( @@ -1453,6 +1533,26 @@ void MakeDebugOptionsFlags(std::vector* flag_list, debug_options->xla_gpu_enable_libnvptxcompiler(), "Use libnvptxcompiler for PTX-to-GPU-assembly compilation instead of " "calling ptxas.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_dot_strength_reduction", + bool_setter_for(&DebugOptions::set_xla_gpu_enable_dot_strength_reduction), + debug_options->xla_gpu_enable_dot_strength_reduction(), + "Enable rewriting matmuls with a vector into reductions.")); + flag_list->push_back( + tsl::Flag("xla_gpu_nccl_collective_max_nchannels", + int64_setter_for( + &DebugOptions::set_xla_gpu_nccl_collective_max_nchannels), + debug_options->xla_gpu_nccl_collective_max_nchannels(), + "Specify the maximum number of channels(SMs) NCCL will use " + "for collective operations. Default is 0 which is to let " + "NCCL decide.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_nccl_p2p_max_nchannels", + int64_setter_for(&DebugOptions::set_xla_gpu_nccl_p2p_max_nchannels), + debug_options->xla_gpu_nccl_p2p_max_nchannels(), + "Specify the maximum number of channels(SMs) NCCL will use " + "for p2p operations. Default is 0 which is to let " + "NCCL decide.")); } // NOLINT(readability/fn_size) // Allocates flag_values and flag_objects; this function must not be called more diff --git a/third_party/xla/xla/experiments/BUILD b/third_party/xla/xla/experiments/BUILD index c9aa052a1493ce..d298feaf3f0a7a 100644 --- a/third_party/xla/xla/experiments/BUILD +++ b/third_party/xla/xla/experiments/BUILD @@ -1,6 +1,8 @@ # Various experiments related to the compiler that are not a part of the final XLA binary. package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + # keep visibility private, if you need to depend on this, move it out of experiments + default_visibility = ["//visibility:private"], licenses = ["notice"], ) diff --git a/third_party/xla/xla/experiments/sm_bandwidth_benchmark/BUILD b/third_party/xla/xla/experiments/sm_bandwidth_benchmark/BUILD index 003c091632c73f..79174d282658bd 100644 --- a/third_party/xla/xla/experiments/sm_bandwidth_benchmark/BUILD +++ b/third_party/xla/xla/experiments/sm_bandwidth_benchmark/BUILD @@ -6,7 +6,6 @@ load("//xla:xla.bzl", "xla_cc_test") cc_library( name = "sm_bw_utils", hdrs = ["sm_bw_utils.h"], - visibility = ["//visibility:public"], deps = [ "@local_tsl//tsl/platform:logging", ] + if_cuda([ @@ -18,7 +17,6 @@ cuda_library( name = "sm_bw_kernels", srcs = ["sm_bw_kernels.cu.cc"], hdrs = ["sm_bw_kernels.h"], - visibility = ["//visibility:public"], deps = [ ":sm_bw_utils", ], diff --git a/third_party/xla/xla/experiments/triton_autotuning/BUILD b/third_party/xla/xla/experiments/triton_autotuning/BUILD index b32c9e1d2926a7..cdadad94af0bb3 100644 --- a/third_party/xla/xla/experiments/triton_autotuning/BUILD +++ b/third_party/xla/xla/experiments/triton_autotuning/BUILD @@ -2,7 +2,7 @@ package( default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//visibility:public"], + default_visibility = ["//visibility:private"], licenses = ["notice"], ) diff --git a/third_party/xla/xla/ffi/BUILD b/third_party/xla/xla/ffi/BUILD index 6ed7f263f08eb3..f4a29a469ba8df 100644 --- a/third_party/xla/xla/ffi/BUILD +++ b/third_party/xla/xla/ffi/BUILD @@ -2,13 +2,14 @@ load("//xla:xla.bzl", "xla_cc_test") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], ) cc_library( name = "api", hdrs = ["//xla/ffi/api:api_headers"], - visibility = ["//visibility:public"], + visibility = ["//visibility:private"], deps = ["//xla/ffi/api:c_api"], ) @@ -16,7 +17,6 @@ cc_library( name = "call_frame", srcs = ["call_frame.cc"], hdrs = ["call_frame.h"], - visibility = ["//visibility:public"], deps = [ "//xla:types", "//xla:xla_data_proto_cc", @@ -33,7 +33,6 @@ cc_library( cc_library( name = "ffi", hdrs = ["ffi.h"], - visibility = ["//visibility:public"], deps = [ ":api", "//xla:shape_util", @@ -54,7 +53,6 @@ cc_library( name = "ffi_api", srcs = ["ffi_api.cc"], hdrs = ["ffi_api.h"], - visibility = ["//visibility:public"], deps = [ ":api", ":call_frame", diff --git a/third_party/xla/xla/ffi/api/BUILD b/third_party/xla/xla/ffi/api/BUILD index 3c195162a7d1a5..399f233a211f38 100644 --- a/third_party/xla/xla/ffi/api/BUILD +++ b/third_party/xla/xla/ffi/api/BUILD @@ -3,6 +3,7 @@ load("@local_tsl//tsl:tsl.default.bzl", "filegroup") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], ) @@ -25,39 +26,34 @@ package( filegroup( name = "api_headers", srcs = ["api.h"], - visibility = ["//visibility:public"], ) filegroup( name = "c_api_headers", srcs = ["c_api.h"], - visibility = ["//visibility:public"], ) cc_library( name = "api", hdrs = [":api_headers"], - visibility = ["//visibility:public"], + visibility = ["//visibility:private"], deps = [":c_api"], ) cc_library( name = "c_api", hdrs = ["c_api.h"], - visibility = ["//visibility:public"], ) cc_library( name = "c_api_internal", hdrs = ["c_api_internal.h"], - visibility = ["//visibility:public"], deps = [":c_api"], ) cc_library( name = "ffi", hdrs = ["ffi.h"], - visibility = ["//visibility:public"], deps = [ ":api", ":c_api", diff --git a/third_party/xla/xla/ffi/api/ffi.h b/third_party/xla/xla/ffi/api/ffi.h index 57233ee301b7af..200cb3e2f2f5ab 100644 --- a/third_party/xla/xla/ffi/api/ffi.h +++ b/third_party/xla/xla/ffi/api/ffi.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef XLA_FFI_API_FFI_H_ #define XLA_FFI_API_FFI_H_ -#ifdef TENSORFLOW_COMPILER_XLA_FFI_FFI_H_ +#ifdef XLA_FFI_FFI_H_ #error Two different XLA FFI implementations cannot be included together -#endif // XLA_FFI_API_H_ +#endif // XLA_FFI_FFI_H_ #include #include diff --git a/third_party/xla/xla/ffi/ffi.h b/third_party/xla/xla/ffi/ffi.h index 37734a11f1304c..c600e8ad86a769 100644 --- a/third_party/xla/xla/ffi/ffi.h +++ b/third_party/xla/xla/ffi/ffi.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef XLA_FFI_FFI_H_ #define XLA_FFI_FFI_H_ -#ifdef TENSORFLOW_COMPILER_XLA_FFI_API_FFI_H_ +#ifdef XLA_FFI_API_FFI_H_ #error Two different XLA FFI implementations cannot be included together #endif // XLA_FFI_API_FFI_H_ diff --git a/third_party/xla/xla/ffi/ffi_api.cc b/third_party/xla/xla/ffi/ffi_api.cc index 2bf48bbb11ba72..07031757eff859 100644 --- a/third_party/xla/xla/ffi/ffi_api.cc +++ b/third_party/xla/xla/ffi/ffi_api.cc @@ -108,8 +108,8 @@ static Status RegisterHandler(std::string_view name, std::string_view platform, return OkStatus(); } -StatusOr FindHandler(std::string_view name, - std::string_view platform) { +absl::StatusOr FindHandler(std::string_view name, + std::string_view platform) { auto it = GetHandlerRegistry().find(MakeHandlerKey(name, platform)); if (it == GetHandlerRegistry().end()) return absl::NotFoundError(absl::StrCat("No FFI handler registered for ", diff --git a/third_party/xla/xla/ffi/ffi_api.h b/third_party/xla/xla/ffi/ffi_api.h index 70a40bfa608e7d..b8a99100531c24 100644 --- a/third_party/xla/xla/ffi/ffi_api.h +++ b/third_party/xla/xla/ffi/ffi_api.h @@ -62,8 +62,8 @@ Status Call(XLA_FFI_Handler* handler, CallFrame& call_frame, // Returns registered FFI handler for a given name and platform, or an error if // it's not found in the static registry. -StatusOr FindHandler(std::string_view name, - std::string_view platform); +absl::StatusOr FindHandler(std::string_view name, + std::string_view platform); //===----------------------------------------------------------------------===// // XLA FFI Api Implementation diff --git a/third_party/xla/xla/hlo/evaluator/BUILD b/third_party/xla/xla/hlo/evaluator/BUILD index 5f041a91f8c9ab..0d6b67083576f7 100644 --- a/third_party/xla/xla/hlo/evaluator/BUILD +++ b/third_party/xla/xla/hlo/evaluator/BUILD @@ -5,7 +5,8 @@ load("//xla:xla.bzl", "xla_cc_test") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], licenses = ["notice"], ) @@ -40,7 +41,6 @@ cc_library( "hlo_evaluator_typed_visitor_uint8.cc", ], hdrs = ["hlo_evaluator.h"], - visibility = ["//visibility:public"], deps = [ "//xla:array2d", "//xla:comparison_util", diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc index eb8c26a0b4cc7e..b1d20bd7867ccf 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc @@ -96,9 +96,10 @@ namespace { using primitive_util::NativeTypeOf; template -StatusOr Compare(const Shape& shape, Comparison comparison, - LiteralSlice lhs_literal, LiteralSlice rhs_literal) { - auto populate = [&](auto compare_op) -> StatusOr { +absl::StatusOr Compare(const Shape& shape, Comparison comparison, + LiteralSlice lhs_literal, + LiteralSlice rhs_literal) { + auto populate = [&](auto compare_op) -> absl::StatusOr { Literal result(shape); TF_RETURN_IF_ERROR(result.PopulateParallel( [&](absl::Span multi_index, int /*thread_id*/) { @@ -147,7 +148,7 @@ StatusOr Compare(const Shape& shape, Comparison comparison, std::optional GetInstructionStaticValueAsBool( const HloInstruction* instruction) { HloEvaluator evaluator; - StatusOr static_value = evaluator.Evaluate( + absl::StatusOr static_value = evaluator.Evaluate( instruction, /*recursively_evaluate_nonconstant_operands=*/true); if (static_value.ok()) { return static_value->GetFirstElement(); @@ -251,7 +252,7 @@ struct DynamicOrStaticInteger { std::optional GetInstructionValueAsInteger( const HloInstruction* instruction) { HloEvaluator evaluator; - StatusOr static_value = evaluator.Evaluate( + absl::StatusOr static_value = evaluator.Evaluate( instruction, /*recursively_evaluate_nonconstant_operands=*/true); if (static_value.ok()) { if (instruction->shape().element_type() == PrimitiveType::PRED) { @@ -859,7 +860,7 @@ HloEvaluator::HloEvaluator(int64_t max_loop_iterations) }); } -StatusOr HloEvaluator::Evaluate( +absl::StatusOr HloEvaluator::Evaluate( const HloComputation& computation, absl::Span arg_literals) { CHECK(computation.parent() != nullptr); @@ -920,7 +921,7 @@ StatusOr HloEvaluator::Evaluate( return result.Clone(); } -StatusOr HloEvaluator::Evaluate( +absl::StatusOr HloEvaluator::Evaluate( const HloInstruction* instruction, bool recursively_evaluate_nonconstant_operands) { arg_literals_.clear(); @@ -955,7 +956,7 @@ bool HloEvaluator::TryEvaluate(const HloInstruction* instruction, return true; } -StatusOr HloEvaluator::EvaluateWithSubstitutions( +absl::StatusOr HloEvaluator::EvaluateWithSubstitutions( const HloInstruction* instruction, const absl::flat_hash_map& substitutions) { @@ -983,7 +984,7 @@ StatusOr HloEvaluator::EvaluateWithSubstitutions( return result; } -StatusOr HloEvaluator::EvaluateElementwiseBinaryOp( +absl::StatusOr HloEvaluator::EvaluateElementwiseBinaryOp( HloOpcode opcode, const Literal& lhs, const Literal& rhs) { std::unique_ptr lhs_instr = HloInstruction::CreateConstant(lhs.Clone()); @@ -998,7 +999,7 @@ StatusOr HloEvaluator::EvaluateElementwiseBinaryOp( return result; } -StatusOr HloEvaluator::EvaluateElementwiseTernaryOp( +absl::StatusOr HloEvaluator::EvaluateElementwiseTernaryOp( HloOpcode opcode, const Literal& lhs, const Literal& rhs, const Literal& ehs) { std::unique_ptr lhs_instr = @@ -1016,7 +1017,7 @@ StatusOr HloEvaluator::EvaluateElementwiseTernaryOp( return Evaluate(cloned_instruction.get()); } -StatusOr HloEvaluator::EvaluateElementwiseCompareOp( +absl::StatusOr HloEvaluator::EvaluateElementwiseCompareOp( ComparisonDirection direction, const Literal& lhs, const Literal& rhs) { std::unique_ptr lhs_instr = HloInstruction::CreateConstant(lhs.Clone()); @@ -1032,7 +1033,7 @@ StatusOr HloEvaluator::EvaluateElementwiseCompareOp( return result; } -StatusOr HloEvaluator::EvaluateElementwiseUnaryOp( +absl::StatusOr HloEvaluator::EvaluateElementwiseUnaryOp( HloOpcode opcode, const Literal& operand) { std::unique_ptr operand_instr = HloInstruction::CreateConstant(operand.Clone()); @@ -1046,7 +1047,7 @@ StatusOr HloEvaluator::EvaluateElementwiseUnaryOp( return result; } -StatusOr HloEvaluator::EvaluateDotOp( +absl::StatusOr HloEvaluator::EvaluateDotOp( const DotDimensionNumbers& dim_numbers, const PrecisionConfig& precision_config, const Literal& lhs, const Literal& rhs) { @@ -1189,7 +1190,7 @@ Status HloEvaluator::EvaluateInternal( } if (!tuple_points_to_analysis_cache_) { HloModule* module = instruction->GetModule(); - StatusOr> + absl::StatusOr> tuple_points_to_analysis = TuplePointsToAnalysis::Run(module); if (tuple_points_to_analysis.ok()) { tuple_points_to_analysis_cache_ = @@ -2347,7 +2348,7 @@ class OutputBatchIndexToInputIndex { // same storage for all invocations. // // This returns a Span into memory owned by the class. - StatusOr> operator()( + absl::StatusOr> operator()( absl::Span output_index) { PropagateOutputIndexGatherDimsToIndexVectorIndex(output_index); TF_RETURN_IF_ERROR(FetchIndexVector()); @@ -2467,7 +2468,7 @@ class OutputOffsetIndexToInputIndex { // result (input_index_), mutating it in place. // // This returns a Span into memory owned by the class. - StatusOr> operator()( + absl::StatusOr> operator()( absl::Span output_index) { PropagateOutputIndexWindowDimsToInputIndex(output_index); return absl::Span(input_index_); @@ -2507,9 +2508,9 @@ class OutputOffsetIndexToInputIndex { // Reshapes the gather indices input to have a trailing degenerate `1` dimension // if necessary. Hands over the ownership of the newly created literal (if // there is one) to `reshaped_start_indices`. -static StatusOr> ReshapedGatherIndices( - int64_t index_vector_dim, const Literal& start_indices, - Literal* reshaped_start_indices) { +static absl::StatusOr> +ReshapedGatherIndices(int64_t index_vector_dim, const Literal& start_indices, + Literal* reshaped_start_indices) { if (start_indices.shape().dimensions_size() != index_vector_dim) { return std::cref(start_indices); } @@ -2574,7 +2575,8 @@ Status HloEvaluator::HandleGather(const HloInstruction* gather) { auto gather_inner_loop_body = [&](absl::Span output_window_index, absl::Span input_gather_index, - absl::Span output_gather_index) -> StatusOr { + absl::Span output_gather_index) + -> absl::StatusOr { TF_ASSIGN_OR_RETURN( absl::Span input_window_index, output_offset_index_to_input_index(output_window_index)); @@ -2608,7 +2610,8 @@ Status HloEvaluator::HandleGather(const HloInstruction* gather) { }; auto gather_outer_loop_body = - [&](absl::Span output_gather_index) -> StatusOr { + [&](absl::Span output_gather_index) + -> absl::StatusOr { TF_ASSIGN_OR_RETURN(absl::Span input_gather_index, output_batch_index_to_input_index(output_gather_index)); TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( @@ -2628,7 +2631,7 @@ namespace { // Reshapes the scatter indices input to have a trailing degenerate `1` // dimension if necessary. Hands over the ownership of the newly created // literal (if there is one) to `reshaped_indices`. -StatusOr> ReshapedScatterIndices( +absl::StatusOr> ReshapedScatterIndices( int64_t index_vector_dim, const Literal& indices, Literal* reshaped_indices) { if (indices.shape().dimensions_size() != index_vector_dim) { @@ -2750,7 +2753,7 @@ class UpdateScatterIndexToInputIndex { // same storage for all invocations. // // This returns a Span into memory owned by the class. - StatusOr> operator()( + absl::StatusOr> operator()( absl::Span update_index) { PropagateUpdateIndexScatterDimsToIndexVectorIndex(update_index); TF_RETURN_IF_ERROR(FetchIndexVector()); @@ -2873,7 +2876,7 @@ class UpdateWindowIndexToInputIndex { // result (input_index_), mutating it in place. // // This returns a Span into memory owned by the class. - StatusOr> operator()( + absl::StatusOr> operator()( absl::Span update_index) { PropagateUpdateIndexWindowDimsToInputIndex(update_index); return absl::Span(input_index_); @@ -2966,7 +2969,8 @@ Status HloEvaluator::HandleScatter(const HloInstruction* hlo) { auto scatter_inner_loop_body = [&](absl::Span update_window_index, absl::Span input_scatter_index, - absl::Span update_scatter_index) -> StatusOr { + absl::Span update_scatter_index) + -> absl::StatusOr { TF_ASSIGN_OR_RETURN( absl::Span input_window_index, update_window_index_to_input_index(update_window_index)); @@ -3018,7 +3022,8 @@ Status HloEvaluator::HandleScatter(const HloInstruction* hlo) { }; auto scatter_outer_loop_body = - [&](absl::Span update_scatter_index) -> StatusOr { + [&](absl::Span update_scatter_index) + -> absl::StatusOr { TF_ASSIGN_OR_RETURN( absl::Span input_scatter_index, update_scatter_index_to_input_index(update_scatter_index)); @@ -3115,13 +3120,14 @@ Status HloEvaluator::HandleAsyncStart(const HloInstruction* async_start) { arg_literals.push_back(&arg_literal); } - HloEvaluator embedded_evaluator; - embedded_evaluator.set_dynamic_dimension_inference( + std::unique_ptr embedded_evaluator = + CreateEmbedded(max_loop_iterations_); + embedded_evaluator->set_dynamic_dimension_inference( dynamic_dimension_inference_); TF_ASSIGN_OR_RETURN( Literal result, - embedded_evaluator.Evaluate(*async_start->async_wrapped_computation(), - arg_literals)); + embedded_evaluator->Evaluate(*async_start->async_wrapped_computation(), + arg_literals)); evaluated_[async_start] = Literal(async_start->shape()); // Copy the operand values to the index {0, i} of the output. @@ -3415,10 +3421,10 @@ Status HloEvaluator::HandleSelect(const HloInstruction* select) { namespace { -StatusOr CreateScalarLiteral(int64_t value, - PrimitiveType element_type) { +absl::StatusOr CreateScalarLiteral(int64_t value, + PrimitiveType element_type) { return primitive_util::PrimitiveTypeSwitch>( - [&](auto primitive_type_constant) -> StatusOr { + [&](auto primitive_type_constant) -> absl::StatusOr { if constexpr (primitive_util::IsIntegralType(primitive_type_constant)) { return LiteralUtil::CreateR0( static_cast>(value)); @@ -3431,7 +3437,7 @@ StatusOr CreateScalarLiteral(int64_t value, // Parses the while loop if it matches one of the known patterns. Returns the // value of the loop induction variable after the loop execution if the loop is // static. -StatusOr TryParseAndEvaluateWhileInductionVar( +absl::StatusOr TryParseAndEvaluateWhileInductionVar( const HloInstruction* while_hlo) { std::optional parsed_while_loop = PatternMatchParseWhileLoop(while_hlo); @@ -3506,7 +3512,7 @@ Status HloEvaluator::HandleWhile(const HloInstruction* while_hlo) { dynamic_dimension_inference_); while (keep_going) { if (max_loop_iterations_ >= 0 && iteration_count++ > max_loop_iterations_) { - StatusOr result = + absl::StatusOr result = TryParseAndEvaluateWhileInductionVar(while_hlo); if (result.ok()) { lcv = std::move(result).value(); @@ -3545,11 +3551,11 @@ Literal ExtractLiteralFromIndexPositions(const Literal& from, return LiteralUtil::CreateR1(values); } -StatusOr ExtractFromIndexPositions(const Literal& from, - absl::Span indices) { +absl::StatusOr ExtractFromIndexPositions( + const Literal& from, absl::Span indices) { PrimitiveType type = from.shape().element_type(); return primitive_util::PrimitiveTypeSwitch>( - [&](auto primitive_type_constant) -> StatusOr { + [&](auto primitive_type_constant) -> absl::StatusOr { if constexpr (primitive_util::IsArrayType(primitive_type_constant)) { return ExtractLiteralFromIndexPositions< NativeTypeOf>(from, indices); @@ -3608,9 +3614,9 @@ void IterateThroughWindow( } template -StatusOr StochasticConvertOp(const Literal& operand_literal, - const Literal& random_literal, - const Shape& result_shape) { +absl::StatusOr StochasticConvertOp(const Literal& operand_literal, + const Literal& random_literal, + const Shape& result_shape) { std::function stochastic_convert_op = [](Fp operand, Uint random) -> ResultT { bool is_negative = static_cast(Eigen::numext::signbit(operand)); @@ -3672,9 +3678,9 @@ StatusOr StochasticConvertOp(const Literal& operand_literal, // Converts from primitive types to native types. template -StatusOr StochasticConvertOp(const Literal& operand_literal, - const Literal& random_literal, - const Shape& result_shape) { +absl::StatusOr StochasticConvertOp(const Literal& operand_literal, + const Literal& random_literal, + const Shape& result_shape) { return StochasticConvertOp< typename primitive_util::PrimitiveTypeToNative::type, typename primitive_util::PrimitiveTypeToNative::type, @@ -3684,11 +3690,11 @@ StatusOr StochasticConvertOp(const Literal& operand_literal, // Evaluates all possible paths of converting to different integers. template -StatusOr StochasticConvertOp(const Literal& operand_literal, - const Literal& random_literal, - const Shape& result_shape) { +absl::StatusOr StochasticConvertOp(const Literal& operand_literal, + const Literal& random_literal, + const Shape& result_shape) { return primitive_util::PrimitiveTypeSwitch>( - [&](auto primitive_type_constant) -> StatusOr { + [&](auto primitive_type_constant) -> absl::StatusOr { if constexpr (primitive_util::IsSignedIntegralType( primitive_type_constant)) { return StochasticConvertOp StochasticConvertOp(const Literal& operand_literal, result_shape.element_type()); } -StatusOr StochasticConvertOp(const Literal& operand_literal, - const Literal& random_literal, - const Shape& result_shape) { +absl::StatusOr StochasticConvertOp(const Literal& operand_literal, + const Literal& random_literal, + const Shape& result_shape) { return primitive_util::PrimitiveTypeSwitch>( - [&](auto primitive_type_constant) -> StatusOr { + [&](auto primitive_type_constant) -> absl::StatusOr { if constexpr (primitive_util::IsFloatingPointType( primitive_type_constant)) { return StochasticConvertOp< @@ -3924,9 +3930,9 @@ Status HloEvaluator::HandleSort(const HloInstruction* sort) { << " accessing increment of size " << increment.size(); increment[sort_dim] = sort_dim_elements; - auto comparator = [sort](absl::Span literals_to_sort, - int64_t a, int64_t b, - HloEvaluator* embedded_evaluator) -> StatusOr { + auto comparator = + [sort](absl::Span literals_to_sort, int64_t a, int64_t b, + HloEvaluator* embedded_evaluator) -> absl::StatusOr { absl::InlinedVector literals; literals.reserve(2 * sort->operand_count()); for (int64_t i = 0; i < sort->operand_count(); ++i) { @@ -3947,10 +3953,10 @@ Status HloEvaluator::HandleSort(const HloInstruction* sort) { embedded_evaluator->ResetVisitStates(); return computed_result.Get({}); }; - auto less_than = [&comparator]( - absl::Span literals_to_sort, int64_t a, - int64_t b, - HloEvaluator* embedded_evaluator) -> StatusOr { + auto less_than = + [&comparator](absl::Span literals_to_sort, int64_t a, + int64_t b, + HloEvaluator* embedded_evaluator) -> absl::StatusOr { TF_ASSIGN_OR_RETURN(bool a_is_smaller, comparator(literals_to_sort, a, b, embedded_evaluator)); #ifndef NDEBUG @@ -4100,7 +4106,7 @@ Status HloEvaluator::HandleSort(const HloInstruction* sort) { // Iterate through each dimension except 'sort_dim'. TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( key_shape, zero_base, key_shape.dimensions(), increment, - [&](absl::Span indices) -> StatusOr { + [&](absl::Span indices) -> absl::StatusOr { // Extract a slice from each operand literal that corresponds to // exactly the row in dimension 'sort_dim'. std::vector limit_indices(indices.begin(), indices.end()); @@ -4185,7 +4191,7 @@ static bool IsScalarAdd(HloComputation* computation) { // the user-provided computation on the accumulator and the output element // (until the reduction is completed, the output element is also used as // an accumulator). -static StatusOr PerformReductionStep( +static absl::StatusOr PerformReductionStep( bool is_tuple, absl::Span input_index, absl::Span output_index, absl::Span input_args, absl::Span results, @@ -4235,7 +4241,7 @@ static StatusOr PerformReductionStep( return true; } -static StatusOr GenerateReduceOutputElement( +static absl::StatusOr GenerateReduceOutputElement( bool is_tuple, absl::Span output_index, absl::Span init_values, diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h index 7c13ab01c9ec18..ed6accfacf96e9 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h @@ -89,7 +89,9 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { // instance of the subclass instead. virtual std::unique_ptr CreateEmbedded( int64_t max_loop_iterations) { - return std::make_unique(max_loop_iterations); + auto result = std::make_unique(max_loop_iterations); + result->set_custom_call_handler(custom_call_handler_); + return result; } // Enables subclasses to be notified when a new computation is being @@ -105,13 +107,13 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { // // (Dummy template arg is to reduce the overloading priority of one overload // so that Evaluate(module, {}) resolves unambiguously.) - StatusOr Evaluate(const HloModule& module, - absl::Span arg_literals) { + absl::StatusOr Evaluate( + const HloModule& module, absl::Span arg_literals) { return Evaluate(*module.entry_computation(), arg_literals); } template - StatusOr Evaluate(const HloModule& module, - absl::Span arg_literals) { + absl::StatusOr Evaluate(const HloModule& module, + absl::Span arg_literals) { return Evaluate(*module.entry_computation(), arg_literals); } @@ -134,11 +136,12 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { // // (Dummy template arg is to reduce the overloading priority of one overload // so that Evaluate(module, {}) resolves unambiguously.) - StatusOr Evaluate(const HloComputation& computation, - absl::Span arg_literals); + absl::StatusOr Evaluate( + const HloComputation& computation, + absl::Span arg_literals); template - StatusOr Evaluate(const HloComputation& computation, - absl::Span arg_literals) { + absl::StatusOr Evaluate(const HloComputation& computation, + absl::Span arg_literals) { std::vector arg_literal_ptrs; for (const auto& l : arg_literals) { arg_literal_ptrs.push_back(&l); @@ -152,7 +155,7 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { // within its parent computation until it encounters something that cannot be // evaluated, such as an Infeed or a Parameter instruction. // It makes best effort to partially evaluate a dependency if possible. - StatusOr Evaluate( + absl::StatusOr Evaluate( const HloInstruction* instruction, bool recursively_evaluate_nonconstant_operands = false); @@ -166,30 +169,29 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { // // For example, given instruction = op(A, B, C) and the map // {A = x, C = y}, this evaluates op(x, B, y). - StatusOr EvaluateWithSubstitutions( + absl::StatusOr EvaluateWithSubstitutions( const HloInstruction* instruction, const absl::flat_hash_map& substitutions); - StatusOr EvaluateElementwiseBinaryOp(HloOpcode opcode, - const Literal& lhs, - const Literal& rhs); + absl::StatusOr EvaluateElementwiseBinaryOp(HloOpcode opcode, + const Literal& lhs, + const Literal& rhs); - StatusOr EvaluateElementwiseUnaryOp(HloOpcode opcode, - const Literal& operand); + absl::StatusOr EvaluateElementwiseUnaryOp(HloOpcode opcode, + const Literal& operand); - StatusOr EvaluateElementwiseTernaryOp(HloOpcode opcode, - const Literal& lhs, - const Literal& rhs, - const Literal& ehs); + absl::StatusOr EvaluateElementwiseTernaryOp(HloOpcode opcode, + const Literal& lhs, + const Literal& rhs, + const Literal& ehs); - StatusOr EvaluateElementwiseCompareOp(ComparisonDirection direction, - const Literal& lhs, - const Literal& rhs); + absl::StatusOr EvaluateElementwiseCompareOp( + ComparisonDirection direction, const Literal& lhs, const Literal& rhs); - StatusOr EvaluateDotOp(const DotDimensionNumbers& dim_numbers, - const PrecisionConfig& precision_config, - const Literal& lhs, const Literal& rhs); + absl::StatusOr EvaluateDotOp(const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config, + const Literal& lhs, const Literal& rhs); void set_dynamic_dimension_inference( DynamicDimensionInference* dynamic_dimension_inference) { @@ -206,7 +208,7 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { // Handles evaluation of a custom-call op. // Operand literals are provided in |operands| and implementations must // populate |output| before returning. - using CustomCallHandler = std::function( + using CustomCallHandler = std::function( const HloInstruction* custom_call, absl::Span operands)>; // Sets a handler that is called during evaluation for custom-call ops. @@ -434,7 +436,7 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { private: template - static StatusOr ElementWiseUnaryOpImpl( + static absl::StatusOr ElementWiseUnaryOpImpl( const HloInstruction* instruction, const std::function& unary_op, const Literal& operand_literal) { diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc index 1a01e8f1b34539..ef3cb8ad8cee90 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc @@ -77,7 +77,7 @@ class HloEvaluatorTest : public HloTestBase { public: HloEvaluatorTest() : use_bfloat16_(false) { InitializeFftData(); } - StatusOr Evaluate( + absl::StatusOr Evaluate( absl::Span arg_literals = {}) { if (use_bfloat16_) { HloElementTypeConverter(F32, BF16).Run(m_.get()).value(); @@ -155,7 +155,7 @@ class HloEvaluatorTest : public HloTestBase { } void TestEvaluationFailure(HloInstruction* instruction) { - StatusOr result = evaluator_.Evaluate(instruction); + absl::StatusOr result = evaluator_.Evaluate(instruction); EXPECT_TRUE(!result.ok()); } @@ -170,7 +170,7 @@ class HloEvaluatorTest : public HloTestBase { } void TestRecursiveEvaluationFailure(HloInstruction* instruction) { - StatusOr result = evaluator_.Evaluate( + absl::StatusOr result = evaluator_.Evaluate( instruction, /*recursively_evaluate_nonconstant_operands=*/true); EXPECT_TRUE(!result.ok()); } @@ -4605,6 +4605,30 @@ TEST_F(HloEvaluatorTest, EvaluateCustomCall_ManyInputs) { EXPECT_TRUE(absl::c_equal(expected_data, actual_literal.data())); } +TEST_F(HloEvaluatorTest, EvaluateCustomCallInFusion) { + const absl::string_view hlo_text = R"( +fusion1 { + p = f32[] parameter(0) + ROOT c = f32[] custom-call(p), custom_call_target="__cchandler1" +} + +ENTRY e { + p = f32[] parameter(0) + ROOT f = f32[] fusion(p), kind=kCustom, calls=fusion1 +})"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto input = LiteralUtil::CreateR0(0); + HloEvaluator evaluator; + evaluator.set_custom_call_handler([](const HloInstruction* custom_call, + absl::Span operands) { + return LiteralUtil::CreateR0(1 - + operands[0]->GetFirstElement()); + }); + TF_ASSERT_OK_AND_ASSIGN(auto output, evaluator.Evaluate(*m_, {&input})); + EXPECT_EQ(output, LiteralUtil::CreateR0(1)); +} + TEST_F(HloEvaluatorTest, IsFiniteF16) { const absl::string_view hlo_text = R"( HloModule test diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h index ef85750aa488c5..68b79d25bf5d24 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h @@ -245,6 +245,18 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { return UnsupportedTypeError(ceil); } + Status HandleErf(const HloInstruction* erf) override { + if constexpr (!is_complex_v) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[erf], + ElementWiseUnaryOp(erf, [](ElementwiseT elem_operand) { + return std::erf(elem_operand); + })); + return OkStatus(); + } + return UnsupportedTypeError(erf); + } + Status HandleExp(const HloInstruction* exp) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp], ElementWiseUnaryOp(exp, [](ElementwiseT elem_operand) { @@ -1593,7 +1605,7 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { } private: - StatusOr ElementWiseUnaryOp( + absl::StatusOr ElementWiseUnaryOp( const HloInstruction* instruction, const std::function& unary_op) { const Literal& operand_literal = @@ -1606,7 +1618,7 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { return std::move(result_literal); } - StatusOr ElementWiseBinaryOp( + absl::StatusOr ElementWiseBinaryOp( const HloInstruction* instruction, const std::function& binary_op) { @@ -1631,7 +1643,7 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { } template - StatusOr ElementwiseTernaryOp( + absl::StatusOr ElementwiseTernaryOp( const HloInstruction* instruction, const std::function& ternary_op) { const auto& shape = instruction->shape(); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD index f7ff3c3e023f06..7a3cb4f6976c4d 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD @@ -2,10 +2,12 @@ load("@bazel_skylib//rules:build_test.bzl", "build_test") load("//xla:xla.bzl", "auto_sharding_deps", "auto_sharding_solver_deps", "xla_cc_binary", "xla_cc_test") +load("@local_tsl//tsl:tsl.default.bzl", "get_compatible_with_libtpu_portable") load("@local_tsl//tsl/platform:build_config.bzl", "tf_proto_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], ) package_group( @@ -27,7 +29,7 @@ cc_library( hdrs = [ "auto_sharding.h", ], - visibility = ["//visibility:public"], + compatible_with = get_compatible_with_libtpu_portable(), deps = [ ":auto_sharding_cost_graph", ":auto_sharding_option", @@ -82,7 +84,7 @@ cc_library( cc_library( name = "auto_sharding_solver_impl", srcs = ["auto_sharding_solver_impl.cc"], - visibility = ["//visibility:public"], + compatible_with = get_compatible_with_libtpu_portable(), deps = [ ":auto_sharding_proto_cc", ":auto_sharding_strategy", @@ -94,7 +96,7 @@ cc_library( cc_library( name = "auto_sharding_solver", srcs = ["auto_sharding_solver.cc"], - visibility = ["//visibility:public"], + compatible_with = get_compatible_with_libtpu_portable(), deps = [ ":auto_sharding_proto_cc", ":auto_sharding_strategy", @@ -121,7 +123,7 @@ cc_library( "auto_sharding_solver.h", "auto_sharding_strategy.h", ], - visibility = ["//visibility:public"], + compatible_with = get_compatible_with_libtpu_portable(), deps = [ ":auto_sharding_proto_cc", "//xla:shape_util", @@ -138,13 +140,14 @@ cc_library( cc_library( name = "auto_sharding_cost_graph", hdrs = ["auto_sharding_cost_graph.h"], - visibility = ["//visibility:public"], + compatible_with = get_compatible_with_libtpu_portable(), deps = [ ":auto_sharding_strategy", ":matrix", "//xla:shape_util", "//xla/hlo/ir:hlo", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -155,7 +158,7 @@ cc_library( name = "auto_sharding_option", srcs = ["auto_sharding_option.cc"], hdrs = ["auto_sharding_option.h"], - visibility = ["//visibility:public"], + compatible_with = get_compatible_with_libtpu_portable(), deps = [ ":auto_sharding_util", "@com_google_absl//absl/algorithm:container", @@ -168,7 +171,7 @@ cc_library( cc_library( name = "auto_sharding_wrapper", hdrs = ["auto_sharding_wrapper.h"], - visibility = ["//visibility:public"], + compatible_with = get_compatible_with_libtpu_portable(), deps = [ ":auto_sharding_cost_graph", ":auto_sharding_option", @@ -184,7 +187,7 @@ cc_library( cc_library( name = "auto_sharding_impl", srcs = ["auto_sharding_impl.cc"], - visibility = ["//visibility:public"], + compatible_with = get_compatible_with_libtpu_portable(), deps = [ ":auto_sharding_cost_graph", ":auto_sharding_option", @@ -201,7 +204,7 @@ cc_library( cc_library( name = "matrix", hdrs = ["matrix.h"], - visibility = ["//visibility:public"], + compatible_with = get_compatible_with_libtpu_portable(), deps = [ "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:logging", @@ -212,7 +215,7 @@ cc_library( name = "cluster_environment", srcs = ["cluster_environment.cc"], hdrs = ["cluster_environment.h"], - visibility = ["//visibility:public"], + compatible_with = get_compatible_with_libtpu_portable(), deps = [ ":auto_sharding_option", ":auto_sharding_strategy", @@ -226,7 +229,7 @@ cc_library( cc_library( name = "profiling_result", hdrs = ["profiling_result.h"], - visibility = ["//visibility:public"], + compatible_with = get_compatible_with_libtpu_portable(), deps = [":auto_sharding_strategy"], ) @@ -234,7 +237,7 @@ cc_library( name = "auto_sharding_util", srcs = ["auto_sharding_util.cc"], hdrs = ["auto_sharding_util.h"], - visibility = ["//visibility:public"], + compatible_with = get_compatible_with_libtpu_portable(), deps = [ ":auto_sharding_strategy", "//xla:array", @@ -267,13 +270,14 @@ cc_library( name = "metrics", srcs = ["metrics.cc"], hdrs = ["metrics.h"], - visibility = ["//visibility:public"], + compatible_with = get_compatible_with_libtpu_portable(), deps = ["@local_tsl//tsl/lib/monitoring:counter"], ) xla_cc_binary( name = "auto_sharding_runner", srcs = ["auto_sharding_runner.cc"], + compatible_with = get_compatible_with_libtpu_portable(), deps = [ ":auto_sharding", "//xla:status", @@ -287,7 +291,6 @@ xla_cc_binary( tf_proto_library( name = "auto_sharding_proto", srcs = ["auto_sharding.proto"], - visibility = ["//visibility:public"], ) build_test( diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index aed4423833ba71..4fc3d6320d47db 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -238,20 +238,24 @@ void FollowArrayOrTokenStrategyGroup( const StrategyGroup& src_strategy_group, const Shape& shape, const size_t instruction_id, const bool have_memory_cost, const ClusterEnvironment& cluster_env, - StableHashMap>& + const StableHashMap>& pretrimmed_strategy_map, StrategyGroup& strategy_group) { CHECK(shape.IsArray() || shape.IsToken()); + std::vector pretrimmed_strategies; // Only follows the given strategy when there is no other strategy to be // restored. - if (!pretrimmed_strategy_map.contains(src_strategy_group.node_idx)) { + auto pretrimmed_strategy_map_it = + pretrimmed_strategy_map.find(src_strategy_group.node_idx); + if (pretrimmed_strategy_map_it != pretrimmed_strategy_map.end()) { + pretrimmed_strategies = pretrimmed_strategy_map_it->second; + } else { strategy_group.following = &src_strategy_group; } + strategy_group.strategies.reserve(src_strategy_group.strategies.size()); // Creates the sharding strategies and restores trimmed strategies, if any. - std::vector& pretrimmed_strategies = - pretrimmed_strategy_map[src_strategy_group.node_idx]; for (int64_t sid = 0; sid < src_strategy_group.strategies.size() + pretrimmed_strategies.size(); ++sid) { @@ -288,7 +292,7 @@ std::unique_ptr MaybeFollowInsStrategyGroup( const StrategyGroup* src_strategy_group, const Shape& shape, const size_t instruction_id, const bool have_memory_cost, StrategyGroups& strategy_groups, const ClusterEnvironment& cluster_env, - StableHashMap>& + const StableHashMap>& pretrimmed_strategy_map) { std::unique_ptr strategy_group; if (src_strategy_group->is_tuple) { @@ -315,7 +319,7 @@ std::unique_ptr MaybeFollowInsStrategyGroup( return strategy_group; } -StatusOr> FollowReduceStrategy( +absl::StatusOr> FollowReduceStrategy( const HloInstruction* ins, const Shape& output_shape, const HloInstruction* operand, const HloInstruction* unit, const size_t instruction_id, StrategyMap& strategy_map, @@ -1175,7 +1179,7 @@ void FillAllStrategiesForArray( } } -StatusOr> CreateAllStrategiesGroup( +absl::StatusOr> CreateAllStrategiesGroup( const HloInstruction* ins, const Shape& shape, const size_t instruction_id, StrategyGroups& strategy_groups, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, const AutoShardingOption& option, @@ -1521,7 +1525,7 @@ std::unique_ptr CreateElementwiseOperatorStrategies( const size_t instruction_id, const HloInstruction* ins, const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env, const InstructionDepthMap& depth_map, const AliasMap& alias_map, - StableHashMap>& + const StableHashMap>& pretrimmed_strategy_map, const int64_t max_depth, StrategyGroups& strategy_groups, AssociativeDotPairs& associative_dot_pairs) { @@ -1703,11 +1707,14 @@ AutoShardingSolverResult CallSolver( int num_nodes_without_default = 0; for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { const StrategyGroup* strategy_group = strategy_groups[node_idx]; - auto instruction_name = - instructions.at(strategy_group->instruction_id)->name(); + const auto instruction = instructions.at(strategy_group->instruction_id); + const auto instruction_name = instruction->name(); + const auto opcode = HloOpcodeString(instruction->opcode()); request.add_instruction_names( absl::StrCat(instruction_name, " (id: ", node_idx, ")")); + request.add_opcodes(std::string(opcode)); AutoShardingSolverRequest_Costs ci, di, mi, pi; + AutoShardingSolverRequest_Names strategy_names; std::optional default_strategy; auto iter = sharding_propagation_solution.find(instruction_name); if (iter != sharding_propagation_solution.end()) { @@ -1728,6 +1735,7 @@ AutoShardingSolverResult CallSolver( cost_graph.extra_node_costs_[node_idx][j]); mi.add_costs(strategy.memory_cost); pi.add_costs(default_strategy && sharding == *default_strategy ? 0 : 1); + strategy_names.add_names(sharding.ToString()); } if (option.use_sharding_propagation_for_default_shardings && *std::min_element(pi.costs().begin(), pi.costs().end()) > 0) { @@ -1740,6 +1748,7 @@ AutoShardingSolverResult CallSolver( request.mutable_communication_costs()->Add(std::move(di)); request.mutable_memory_costs()->Add(std::move(mi)); request.mutable_departure_costs()->Add(std::move(pi)); + request.mutable_strategy_names()->Add(std::move(strategy_names)); } LOG(INFO) << "Total nodes without default: " << num_nodes_without_default; @@ -1931,7 +1940,9 @@ void SetHloSharding(const HloInstructionSequence& sequence, const std::vector& instructions = sequence.instructions(); for (HloInstruction* inst : instructions) { - if (inst->opcode() == HloOpcode::kOutfeed) { + if (inst->opcode() == HloOpcode::kOutfeed || + inst->opcode() == HloOpcode::kSend || + inst->opcode() == HloOpcode::kSendDone) { continue; } auto iter = strategy_map.find(inst); @@ -2076,11 +2087,17 @@ Status SetHloShardingPostProcessing( device_mesh, resharding_cache); } } - } else if (inst->opcode() == HloOpcode::kOutfeed) { - // Outfeed operand shardings are handled in downstream passes and so we - // ignore outfeed ops here. However, we need to ensure that outfeed ops - // which have user shardings have their shardings restored at the end. If - // not, this can lead to errors downstream in the spmd_partitioner pass. + } else if (inst->opcode() == HloOpcode::kOutfeed || + inst->opcode() == HloOpcode::kSendDone) { + // Outfeed: Outfeed operand shardings are handled in downstream passes and + // so we ignore outfeed ops here. However, we need to ensure that outfeed + // ops which have user shardings have their shardings restored at the + // end. If not, this can lead to errors downstream in the spmd_partitioner + // pass. + + // In the analysis itself, we use replicated strategies as a stand-in for + // the (expected) maximal sharding annotations that send-done ops usually + // have. Here we restore these maximal shardings if present. auto preserved_sharding_iter = preserve_shardings->find(inst->name()); if (preserved_sharding_iter != preserve_shardings->end()) { const auto& preserved_sharding = preserved_sharding_iter->second; @@ -2102,7 +2119,22 @@ Status SetHloShardingPostProcessing( inst->set_sharding(preserved_sharding.at(0)); } } - + continue; + } else if (inst->opcode() == HloOpcode::kSend) { + // In the analysis itself, we use replicated strategies as a stand-in for + // the (expected) maximal sharding annotations that send ops usually + // have. Here we restore these maximal shardings if present. + auto preserved_sharding_iter = preserve_shardings->find(inst->name()); + if (preserved_sharding_iter != preserve_shardings->end()) { + const auto& preserved_sharding = preserved_sharding_iter->second; + if (preserved_sharding.size() > 1) { + inst->set_sharding( + HloSharding::Tuple(inst->shape(), preserved_sharding)); + } else { + CHECK_EQ(preserved_sharding.size(), 1); + inst->set_sharding(preserved_sharding[0]); + } + } continue; } else { if (inst->shape().IsTuple()) { @@ -2390,7 +2422,7 @@ void SaveShardingForInstruction( } } -// Check whether the shardings that need to be perserved are preserved. +// Check whether the shardings that need to be preserved are preserved. void CheckUserShardingPreservation( HloModule* module, const absl::flat_hash_map>& @@ -2450,7 +2482,7 @@ int64_t MemoryBudgetLowerBound(const HloModule& module, buffer_to_sharded_value_mapping; for (LivenessIdx time_idx = 0; time_idx < liveness_set.size(); ++time_idx) { for (const HloValue* value : liveness_set[time_idx]) { - auto buffer = alias_analysis->GetBufferContainingValue(*value); + const auto& buffer = alias_analysis->GetBufferContainingValue(*value); if (value->instruction()->has_sharding()) { auto this_value_sharding = get_value_sharding(value); auto iter = buffer_to_sharded_value_mapping.find(buffer.id()); @@ -2482,7 +2514,7 @@ int64_t MemoryBudgetLowerBound(const HloModule& module, } Shape shape = ShapeUtil::GetSubshape(value->instruction()->shape(), value->index()); - auto buffer = alias_analysis->GetBufferContainingValue(*value); + const auto& buffer = alias_analysis->GetBufferContainingValue(*value); auto iter = buffer_to_sharded_value_mapping.find(buffer.id()); std::optional optional_sharding = std::nullopt; if (iter != buffer_to_sharded_value_mapping.end()) { @@ -2531,6 +2563,7 @@ void RecoverShardingsFromPartialMesh( } } } + // DFS to find the replicated set starting from cur instruction. void FindReplicateSet( HloInstruction* cur, const AliasMap& alias_map, const CostGraph& cost_graph, @@ -3202,7 +3235,9 @@ AutoShardingImplementation::SaveAndRemoveShardingAnnotation( for (const HloComputation* computation : module->computations(execution_threads)) { for (const auto inst : computation->instructions()) { - if (inst->opcode() == HloOpcode::kOutfeed) { + if (inst->opcode() == HloOpcode::kOutfeed || + inst->opcode() == HloOpcode::kSend || + inst->opcode() == HloOpcode::kSendDone) { spmd::SaveShardingForInstruction(inst, /* save_for_copy_users */ false, preserve_shardings); @@ -3258,6 +3293,12 @@ AutoShardingImplementation::SaveAndRemoveShardingAnnotation( continue; } + if (ins->opcode() == HloOpcode::kOutfeed || + ins->opcode() == HloOpcode::kSend || + ins->opcode() == HloOpcode::kSendDone) { + continue; + } + if (ins->has_sharding()) { module_is_changed |= true; ins->clear_sharding(); @@ -3296,7 +3337,7 @@ AutoShardingImplementation::AutoShardingImplementation( const AutoShardingOption& option) : option_(option) {} -StatusOr AutoShardingImplementation::RunAutoSharding( +absl::StatusOr AutoShardingImplementation::RunAutoSharding( HloModule* module, const absl::flat_hash_set& replicated_small_tensors, const absl::flat_hash_set& execution_threads, @@ -3313,7 +3354,7 @@ StatusOr AutoShardingImplementation::RunAutoSharding( // shardings to their input ops. absl::flat_hash_map> unspecified_dims; - StatusOr changed = ProcessShardingInstruction( + absl::StatusOr changed = ProcessShardingInstruction( module, execution_threads, /*replace_sharding_with_copy=*/true, &unspecified_dims, /*saved_root_shardings=*/nullptr, /*saved_parameter_shardings=*/nullptr); @@ -3422,7 +3463,7 @@ StatusOr AutoShardingImplementation::RunAutoSharding( total_devices *= i; } if (mesh_idx != partial_mesh_shapes.size() - 1) { - StatusOr changed = spmd::AdjustShardingsWithPartialMeshShape( + absl::StatusOr changed = spmd::AdjustShardingsWithPartialMeshShape( sequence.instructions(), mesh_shape, total_devices, /* crash_on_error */ !option_.try_multiple_mesh_shapes); if (changed.ok()) { @@ -3647,7 +3688,7 @@ std::unique_ptr CloneModule(const HloModule* module) { return module_clone; } -StatusOr AutoSharding::Run( +absl::StatusOr AutoSharding::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { if (!option_.enable) { @@ -3656,7 +3697,7 @@ StatusOr AutoSharding::Run( LOG(INFO) << "Starting the auto sharding pass"; if (IsModuleManuallySharded(module)) { - LOG(ERROR) + LOG(FATAL) << "Auto-sharding on partially manually sharded modules is not yet " "supported. Please fall back on the sharding propagation pass."; return false; @@ -3725,8 +3766,7 @@ StatusOr AutoSharding::Run( /*is_spmd */ true, /*propagate_metadata */ false, /*allow_spmd_sharding_propagation_to_output*/ module->config().allow_spmd_sharding_propagation_to_output(), - /*allow_spmd_sharding_propagation_to_parameters */ - absl::InlinedVector{false}, + module->config().allow_spmd_sharding_propagation_to_parameters(), /*cse_prevention_only */ false, /*sharding_helper*/ nullptr); @@ -3787,7 +3827,7 @@ StatusOr AutoSharding::Run( } } - StatusOr module_is_changed; + absl::StatusOr module_is_changed; if (skip_auto_sharding) { VLOG(1) << "Solver timed out. Will now rely on sharding propagation to " "perform sharding."; @@ -3866,7 +3906,7 @@ StatusOr AutoSharding::Run( return module_is_changed; } -StatusOr DummyAutoSharding::Run( +absl::StatusOr DummyAutoSharding::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { // ----- Set Dummy Replicated Sharding ----- diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h index 54b896c61caec7..1d911c83ed71f1 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h @@ -55,7 +55,7 @@ class DummyAutoSharding : public HloModulePass { absl::string_view name() const override { return "dummy_auto_sharding"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; @@ -71,7 +71,7 @@ class AutoShardingImplementation { explicit AutoShardingImplementation(const AutoShardingOption& option); ~AutoShardingImplementation() = default; - StatusOr RunAutoSharding( + absl::StatusOr RunAutoSharding( HloModule* module, const absl::flat_hash_set& replicated_small_tensors, const absl::flat_hash_set& execution_threads, @@ -115,7 +115,7 @@ class AutoSharding : public HloModulePass { absl::string_view name() const override { return "auto_sharding"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; @@ -166,6 +166,8 @@ Status FilterStrategy(const HloInstruction* ins, const Shape& shape, Status HandleDot(std::unique_ptr& strategy_group, StrategyGroups& strategy_groups, StrategyMap& strategy_map, const HloInstruction* ins, size_t instruction_id, + const HloInstructionSequence& instruction_sequence, + const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph); @@ -173,6 +175,8 @@ Status HandleDot(std::unique_ptr& strategy_group, Status HandleConv(std::unique_ptr& strategy_group, StrategyGroups& strategy_groups, StrategyMap& strategy_map, const HloInstruction* ins, size_t instruction_id, + const HloInstructionSequence& instruction_sequence, + const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, @@ -254,7 +258,7 @@ void FillAllStrategiesForArray( bool create_replicated_strategies, bool create_partially_replicated_strategies); -StatusOr> CreateAllStrategiesGroup( +absl::StatusOr> CreateAllStrategiesGroup( const HloInstruction* ins, const Shape& shape, size_t instruction_id, StrategyGroups& strategy_groups, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, const AutoShardingOption& option, @@ -269,7 +273,7 @@ std::unique_ptr CreateElementwiseOperatorStrategies( size_t instruction_id, const HloInstruction* ins, const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env, const InstructionDepthMap& depth_map, const AliasMap& alias_map, - StableHashMap>& + const StableHashMap>& pretrimmed_strategy_map, int64_t max_depth, StrategyGroups& strategy_groups, AssociativeDotPairs& associative_dot_pairs); @@ -313,7 +317,7 @@ void EnumerateAllPartition(const HloInstruction* ins, const Shape& shape, int64_t partition_dimensions, const std::vector& tensor_dims = {}); -StatusOr> FollowReduceStrategy( +absl::StatusOr> FollowReduceStrategy( const HloInstruction* ins, const Shape& output_shape, const HloInstruction* operand, const HloInstruction* unit, size_t instruction_id, StrategyMap& strategy_map, @@ -341,7 +345,7 @@ std::unique_ptr MaybeFollowInsStrategyGroup( const StrategyGroup* src_strategy_group, const Shape& shape, size_t instruction_id, bool have_memory_cost, StrategyGroups& strategy_groups, const ClusterEnvironment& cluster_env, - StableHashMap>& + const StableHashMap>& pretrimmed_strategy_map); void RemoveInvalidShardingsWithShapes(const Shape& shape, @@ -363,7 +367,7 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( const CallGraph& call_graph, bool strict); // Build possible sharding strategies and their costs for all instructions. -StatusOr> +absl::StatusOr> BuildStrategyAndCost(const HloInstructionSequence& sequence, const HloModule* module, const absl::flat_hash_map& diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.proto b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.proto index 64b6c158fb5d5f..4d1ebdfd0f5747 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.proto +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.proto @@ -31,6 +31,9 @@ message AutoShardingSolverRequest { message Edges { repeated int64 edges = 1; } + message Names { + repeated string names = 1; + } message SolverTimeout { int64 solver_timeout_in_seconds = 1; } @@ -56,6 +59,8 @@ message AutoShardingSolverRequest { repeated Pair aliases = 14; repeated Costs value_costs = 15; repeated string instruction_names = 16; + repeated string opcodes = 33; + repeated Names strategy_names = 32; optional SolverTimeout solver_timeout = 17; optional Coeff overbudget_coeff = 18; optional Coeff makespan_coeff = 19; diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h index cd41ca48aaa28c..9b5599defde477 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h @@ -26,6 +26,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" @@ -46,48 +47,57 @@ class CostGraph { adjacency_.assign(strategy_groups.size(), StableHashSet()); // Build the cost graph - for (const auto& strategies : strategy_groups) { - node_lens_.push_back(strategies->strategies.size()); + for (StrategyGroup* strategy_group : strategy_groups) { + node_lens_.push_back(strategy_group->strategies.size()); extra_node_costs_.push_back( - std::vector(strategies->strategies.size(), 0.0)); - - for (size_t i = 0; i < strategies->in_nodes.size(); ++i) { - if (!strategies->in_nodes[i]->is_tuple) { - NodeIdx src_idx = strategies->in_nodes[i]->node_idx; - NodeIdx dst_idx = strategies->node_idx; - Matrix edge_cost = CreateEdgeCost(src_idx, dst_idx, i, strategies); + std::vector(strategy_group->strategies.size(), 0.0)); + + const auto& in_nodes = strategy_group->in_nodes; + for (size_t i = 0; i < in_nodes.size(); ++i) { + if (!in_nodes[i]->is_tuple) { + NodeIdx src_idx = in_nodes[i]->node_idx; + NodeIdx dst_idx = strategy_group->node_idx; + Matrix edge_cost = + CreateEdgeCost(src_idx, dst_idx, i, strategy_group); AddEdgeCost(src_idx, dst_idx, edge_cost); - } else if (strategies->in_nodes[i]->is_tuple && - strategies->in_nodes.size() > 1) { - for (size_t l = 0; l < strategies->in_nodes[i]->childs.size(); l++) { - NodeIdx src_idx = strategies->in_nodes[i]->childs.at(l)->node_idx; - NodeIdx dst_idx = strategies->node_idx; + } else if (in_nodes[i]->is_tuple && in_nodes.size() > 1) { + for (size_t l = 0; l < in_nodes[i]->childs.size(); l++) { + NodeIdx src_idx = in_nodes[i]->childs.at(l)->node_idx; + NodeIdx dst_idx = strategy_group->node_idx; Matrix edge_cost = - CreateEdgeCost(src_idx, dst_idx, i, strategies, true); + CreateEdgeCost(src_idx, dst_idx, i, strategy_group, true); AddEdgeCost(src_idx, dst_idx, edge_cost); } } else { - CHECK_EQ(strategies->in_nodes.size(), 1) + CHECK_EQ(in_nodes.size(), 1) << "Do not support instructions with more than one tuple " "operand. If this CHECK fails, we will need to fix " "b/233412625."; - for (size_t l = 0; l < strategies->in_nodes[i]->childs.size(); l++) { - NodeIdx src_idx = strategies->in_nodes[i]->childs.at(l)->node_idx; - NodeIdx dst_idx = strategies->node_idx; + for (size_t l = 0; l < in_nodes[i]->childs.size(); l++) { + NodeIdx src_idx = in_nodes[i]->childs.at(l)->node_idx; + NodeIdx dst_idx = strategy_group->node_idx; // TODO(b/233412625) Support more general case, e.g., multiple tuple // operands. If there is only one operand and it's a tuple, the // first index of resharding_costs is for the tuple element. - Matrix edge_cost = - CreateEdgeCost(src_idx, dst_idx, /*in_node_idx=*/l, strategies); + Matrix edge_cost = CreateEdgeCost( + src_idx, dst_idx, /*in_node_idx=*/l, strategy_group); AddEdgeCost(src_idx, dst_idx, edge_cost); } } } - if (strategies->following) { - to_merge_pairs_.push_back( - {strategies->node_idx, strategies->following->node_idx}); + if (strategy_group->following) { + if (strategy_group->strategies.size() == + strategy_group->following->strategies.size()) { + to_merge_pairs_.push_back( + {strategy_group->node_idx, strategy_group->following->node_idx}); + } else { + LOG(WARNING) << "Different strategy counts for instruction ID " + << strategy_group->instruction_id + << " and following instruction ID " + << strategy_group->following->instruction_id; + } } } diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc index d0d71870024d2a..2d6a1d03a3ccd5 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc @@ -33,14 +33,17 @@ limitations under the License. #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h" #include "xla/hlo/experimental/auto_sharding/cluster_environment.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/service/call_graph.h" #include "xla/service/dot_as_convolution_util.h" +#include "xla/service/hlo_cost_analysis.h" #include "xla/service/sharding_propagation.h" #include "xla/status.h" #include "tsl/platform/errors.h" @@ -63,12 +66,18 @@ class HandlerBase { protected: HandlerBase(std::unique_ptr& strategy_group, StrategyMap& strategy_map, const HloInstruction* ins, + const int64_t instruction_id, + const HloInstructionSequence& instruction_sequence, + const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph) : strategy_group_(strategy_group), strategy_map_(strategy_map), ins_(ins), + instruction_id_(instruction_id), + instruction_sequence_(instruction_sequence), + hlo_cost_analysis_(hlo_cost_analysis), cluster_env_(cluster_env), batch_map_(batch_map), option_(option), @@ -145,6 +154,9 @@ class HandlerBase { std::unique_ptr& strategy_group_; StrategyMap& strategy_map_; const HloInstruction* ins_; + const int64_t instruction_id_; + const HloInstructionSequence& instruction_sequence_; + const HloCostAnalysis& hlo_cost_analysis_; const ClusterEnvironment& cluster_env_; const InstructionBatchDimMap& batch_map_; const AutoShardingOption& option_; @@ -160,13 +172,18 @@ class DotHandler : public HandlerBase { public: DotHandler(std::unique_ptr& strategy_group, StrategyMap& strategy_map, const HloDotInstruction* ins, + int64_t instruction_id, + const HloInstructionSequence& instruction_sequence, + const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph); DotHandler( std::unique_ptr& strategy_group, StrategyMap& strategy_map, - const HloConvolutionInstruction* ins, + const HloConvolutionInstruction* ins, int64_t instruction_id, + const HloInstructionSequence& instruction_sequence, + const HloCostAnalysis& hlo_cost_analysis, const dot_as_convolution_util::DotConvolutionDimsInfo& conv_as_dot_dims, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, @@ -216,6 +233,9 @@ class ConvHandler : public HandlerBase { public: ConvHandler(std::unique_ptr& strategy_group, StrategyMap& strategy_map, const HloInstruction* ins, + int64_t instruction_id, + const HloInstructionSequence& instruction_sequence, + const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph); @@ -353,12 +373,16 @@ std::optional HandlerBase::GetShardingFromUser( DotHandler::DotHandler(std::unique_ptr& strategy_group, StrategyMap& strategy_map, const HloDotInstruction* ins, + const int64_t instruction_id, + const HloInstructionSequence& instruction_sequence, + const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph) - : HandlerBase(strategy_group, strategy_map, ins, cluster_env, batch_map, - option, call_graph), + : HandlerBase(strategy_group, strategy_map, ins, instruction_id, + instruction_sequence, hlo_cost_analysis, cluster_env, + batch_map, option, call_graph), is_dot_(true), space_base_dim_(ins->dot_dimension_numbers().lhs_batch_dimensions_size()), lhs_con_dims_(ins->dot_dimension_numbers().lhs_contracting_dimensions()), @@ -373,13 +397,16 @@ DotHandler::DotHandler(std::unique_ptr& strategy_group, DotHandler::DotHandler( std::unique_ptr& strategy_group, StrategyMap& strategy_map, - const HloConvolutionInstruction* ins, + const HloConvolutionInstruction* ins, const int64_t instruction_id, + const HloInstructionSequence& instruction_sequence, + const HloCostAnalysis& hlo_cost_analysis, const dot_as_convolution_util::DotConvolutionDimsInfo& conv_as_dot_dims, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph) - : HandlerBase(strategy_group, strategy_map, ins, cluster_env, batch_map, - option, call_graph), + : HandlerBase(strategy_group, strategy_map, ins, instruction_id, + instruction_sequence, hlo_cost_analysis, cluster_env, + batch_map, option, call_graph), is_dot_(false), space_base_dim_(-1) { CHECK(conv_as_dot_dims.conv_spatial_dims.empty()); @@ -652,6 +679,9 @@ void DotHandler::RecomputeSplitBothContract() { if (device_mesh_.dim(e.mesh_dims[0]) <= 1 || device_mesh_.dim(e.mesh_dims[1]) <= 1) return; + if (!option_.allow_recompute_heavy_op) { + return; + } std::string name = absl::StrFormat("RR = RS x SR @ {%d} (allreduce @ %d)", e.mesh_dims[0], e.mesh_dims[0]); const DimMap lhs_dim_map = {{lhs_con_dims_[e.i], e.mesh_dims[0]}}; @@ -660,7 +690,10 @@ void DotHandler::RecomputeSplitBothContract() { if (is_dot_) { out_dim_map = DimMap{}; } - double compute_cost = cluster_env_.DotCost(lhs_->shape(), rhs_->shape()); + double compute_cost = GetDotConvReplicationPenalty( + ins_, instruction_id_, /* window */ 10, + instruction_sequence_, hlo_cost_analysis_) / + device_mesh_.dim(e.mesh_dims[0]); auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[0]); @@ -838,12 +871,16 @@ Status DotHandler::RegisterStrategies() { ConvHandler::ConvHandler(std::unique_ptr& strategy_group, StrategyMap& strategy_map, const HloInstruction* ins, + const int64_t instruction_id, + const HloInstructionSequence& instruction_sequence, + const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph) - : HandlerBase(strategy_group, strategy_map, ins, cluster_env, batch_map, - option, call_graph), + : HandlerBase(strategy_group, strategy_map, ins, instruction_id, + instruction_sequence, hlo_cost_analysis, cluster_env, + batch_map, option, call_graph), conv_dnums_(ins->convolution_dimension_numbers()) { lhs_batch_dim_ = conv_dnums_.input_batch_dimension(); lhs_in_channel_dim_ = conv_dnums_.input_feature_dimension(); @@ -1011,6 +1048,8 @@ void ConvHandler::SplitDepthwise(bool forward) { Status HandleDot(std::unique_ptr& strategy_group, StrategyGroups& strategy_groups, StrategyMap& strategy_map, const HloInstruction* ins, size_t instruction_id, + const HloInstructionSequence& instruction_sequence, + const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, @@ -1019,6 +1058,7 @@ Status HandleDot(std::unique_ptr& strategy_group, strategy_groups); DotHandler handler(strategy_group, strategy_map, Cast(ins), + instruction_id, instruction_sequence, hlo_cost_analysis, cluster_env, batch_map, option, call_graph); TF_RETURN_IF_ERROR(handler.RegisterStrategies()); return OkStatus(); @@ -1028,6 +1068,8 @@ Status HandleDot(std::unique_ptr& strategy_group, Status HandleConv(std::unique_ptr& strategy_group, StrategyGroups& strategy_groups, StrategyMap& strategy_map, const HloInstruction* ins, size_t instruction_id, + const HloInstructionSequence& instruction_sequence, + const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, @@ -1038,13 +1080,15 @@ Status HandleConv(std::unique_ptr& strategy_group, auto conv_as_dot_dims = dot_as_convolution_util::ParseConvolutionDimsInfo(ins); if (conv_as_dot_dims.conv_spatial_dims.empty()) { - DotHandler handler(strategy_group, strategy_map, - Cast(ins), conv_as_dot_dims, - cluster_env, batch_map, option, call_graph); + DotHandler handler( + strategy_group, strategy_map, Cast(ins), + instruction_id, instruction_sequence, hlo_cost_analysis, + conv_as_dot_dims, cluster_env, batch_map, option, call_graph); TF_RETURN_IF_ERROR(handler.RegisterStrategies()); } else { - ConvHandler handler(strategy_group, strategy_map, ins, cluster_env, + ConvHandler handler(strategy_group, strategy_map, ins, instruction_id, + instruction_sequence, hlo_cost_analysis, cluster_env, batch_map, option, call_graph); TF_RETURN_IF_ERROR(handler.RegisterStrategies()); } diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc index dc3f77634360fd..208e11ea3cb4a9 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc @@ -107,9 +107,6 @@ std::string AutoShardingOption::ToString() const { lines.push_back(absl::StrCat("nd_sharding_iteratively_strict_search_space: ", nd_sharding_iteratively_strict_search_space)); - lines.push_back(absl::StrCat("allow_replicated_strategy_for_dot_and_conv: ", - allow_replicated_strategy_for_dot_and_conv)); - lines.push_back(absl::StrCat("device_mesh_shape: [", absl::StrJoin(device_mesh_shape, ","), "]")); lines.push_back(absl::StrCat("device_mesh_alpha: [", diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h index 682c6f16ae585f..1dd8f8454d1f7a 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h @@ -104,9 +104,12 @@ struct AutoShardingOption { // 2d mesh case. bool batch_matmul_always_split_batch = false; - // If true, allow strategies that recompute heavy operators (e.g., dot) - // to reduce communication. - bool allow_recompute_heavy_op = false; + // If true, allow strategies that recompute heavy operators (e.g., dot) to + // reduce communication. This will generate generate replicated or partially + // replicated strategies for dot/conv ops. Generating these seems to be + // beneficial for LLM serving models, but can increase the search space, so + // this feature is exposed as an option. + bool allow_recompute_heavy_op = true; // If true, allow adding 1d strategies in 2d logical mesh. bool allow_mixed_mesh_shape = false; @@ -143,11 +146,6 @@ struct AutoShardingOption { // space more scalable. Therefore leaving it as an option. bool nd_sharding_iteratively_strict_search_space = false; - // Whether or not to generate replicated strategies for dot/conv - // ops. Generating these seems to be beneficial for LLM serving models, but - // can increase the search space, so this feature is exposed as an option. - bool allow_replicated_strategy_for_dot_and_conv = true; - // Device mesh shape. std::vector device_mesh_shape; // Device IDs in the mesh. diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index e19acb04f6abef..e983179128f405 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -58,6 +58,10 @@ using ::operations_research::MPConstraint; using ::operations_research::MPSolver; using ::operations_research::MPVariable; +// We need to nudge the maximum cost (if present) slightly, since the constraint +// solver cannot guarantee exact numerical precision. +constexpr double kMaxCostEpsilon = 1.0001; + bool AutoShardingSolverResult::operator==( const AutoShardingSolverResult& other) const { return status == other.status && @@ -294,23 +298,19 @@ AutoShardingSolverResult CallORToolsSolver( MPVariable* makespan_var = nullptr; size_t unique_nodes = 0; - const auto strat_follow = StratFollow(request); for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { if (request.s_follow(node_idx) < 0) { unique_nodes += 1; // Creates variables for instructions that do not follow others. - for (NodeStrategyIdx j = 0; j < request.s_len(node_idx); ++j) { - MPVariable* var = - strat_follow[node_idx][j] >= 0 - ? s[node_idx][strat_follow[node_idx][j]] - : solver->MakeBoolVar(absl::StrCat("s[", node_idx, "]", j)); - s[node_idx].push_back(var); - } + solver->MakeBoolVarArray(request.s_len(node_idx), + absl::StrCat("s[", node_idx, "]"), &s[node_idx]); } } for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { if (request.s_follow(node_idx) >= 0) { + CHECK_EQ(request.s_len(node_idx), + request.s_len(request.s_follow(node_idx))); // Copies the variable of followed instruction to the following // instruction. s[node_idx] = s[request.s_follow(node_idx)]; @@ -332,26 +332,9 @@ AutoShardingSolverResult CallORToolsSolver( continue; } unique_edges += 1; - for (NodeStrategyIdx j = 0; j < request.s_len(edge.first); ++j) { - for (NodeStrategyIdx k = 0; k < request.s_len(edge.second); ++k) { - NodeStrategyIdx j_follow = strat_follow[followed_edge.first][j] >= 0 - ? strat_follow[followed_edge.first][j] - : j; - NodeStrategyIdx k_follow = strat_follow[followed_edge.second][k] >= 0 - ? strat_follow[followed_edge.second][k] - : k; - EdgeStrategyIdx edge_strategy_idx = - j_follow * request.s_len(followed_edge.second) + k_follow; - MPVariable* var = - (strat_follow[followed_edge.first][j] >= 0 || - strat_follow[followed_edge.second][k] >= 0) - ? e[edge_idx][edge_strategy_idx] - : solver->MakeBoolVar(absl::StrCat("e[", followed_edge.first, - ",", followed_edge.second, - "]", edge_strategy_idx)); - e[edge_idx].push_back(var); - } - } + solver->MakeBoolVarArray( + request.s_len(edge.first) * request.s_len(edge.second), + absl::StrCat("e[", edge.first, ",", edge.second, "]"), &e[edge_idx]); edge_map.insert({followed_edge, edge_idx}); } @@ -366,16 +349,17 @@ AutoShardingSolverResult CallORToolsSolver( // Construct objective function. // Node costs + absl::flat_hash_set infinity_vars; for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { - absl::flat_hash_set visited_vars; for (NodeStrategyIdx j = 0; j < s[node_idx].size(); ++j) { - if (visited_vars.contains(s[node_idx][j])) continue; - visited_vars.insert(s[node_idx][j]); double accumulated_coefficient = solver->MutableObjective()->GetCoefficient(s[node_idx][j]); double coefficient = request.computation_costs(node_idx).costs(j) + request.communication_costs(node_idx).costs(j); - if (coefficient >= kInfinityCost) continue; + if (coefficient >= kInfinityCost) { + infinity_vars.insert(s[node_idx][j]); + continue; + } AddSalt(absl::StrCat(node_idx, "S", j), request.saltiplier(), &coefficient); solver->MutableObjective()->SetCoefficient( @@ -384,31 +368,34 @@ AutoShardingSolverResult CallORToolsSolver( } // Edge costs for (EdgeIdx edge_idx = 0; edge_idx < num_edges; ++edge_idx) { - absl::flat_hash_set visited_vars; for (EdgeStrategyIdx j = 0; j < e[edge_idx].size(); ++j) { - if (visited_vars.contains(e[edge_idx][j])) continue; - visited_vars.insert(e[edge_idx][j]); double accumulated_coefficient = solver->MutableObjective()->GetCoefficient(e[edge_idx][j]); double coefficient = request.resharding_costs(edge_idx).costs(j); - if (coefficient >= kInfinityCost) continue; + if (coefficient >= kInfinityCost) { + infinity_vars.insert(e[edge_idx][j]); + continue; + } AddSalt(absl::StrCat(edge_idx, "E", j), request.saltiplier(), &coefficient); solver->MutableObjective()->SetCoefficient( e[edge_idx][j], accumulated_coefficient + coefficient); } } + LOG(INFO) << "Number of infinity terms: " << infinity_vars.size(); // Add constraints. // 0. Do not choose solutions with infinity costs, as it will make the // objective value so large that other solution choices do not matter anymore. + // Also eliminate strategies that are known to be dominated by others. + const NodeStrategies shaved_strategies = + StrategyShaver(request).FindShavedStrategies(); for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { if (s[node_idx].empty() || request.s_follow(node_idx) >= 0) continue; bool all_infinity = true; for (NodeStrategyIdx j = 0; j < s[node_idx].size(); ++j) { - const double node_cost = request.computation_costs(node_idx).costs(j) + - request.communication_costs(node_idx).costs(j); - if (node_cost >= kInfinityCost) { + if (infinity_vars.contains(s[node_idx][j]) || + shaved_strategies.contains({node_idx, j})) { MPConstraint* constraint = solver->MakeRowConstraint( 0.0, 0.0, absl::StrCat("infinitycost: s[", node_idx, "][", j, "] = 0")); @@ -425,8 +412,7 @@ AutoShardingSolverResult CallORToolsSolver( if (e[edge_idx].empty() || e_follow[edge_idx] >= 0) continue; bool all_infinity = true; for (EdgeStrategyIdx j = 0; j < e[edge_idx].size(); ++j) { - const double edge_cost = request.resharding_costs(edge_idx).costs(j); - if (edge_cost >= kInfinityCost) { + if (infinity_vars.contains(e[edge_idx][j])) { MPConstraint* constraint = solver->MakeRowConstraint( 0.0, 0.0, absl::StrCat("infinitycost: e[", edge_idx, "][", j, "] = 0")); @@ -456,10 +442,7 @@ AutoShardingSolverResult CallORToolsSolver( 1.0, 1.0, absl::StrCat("sum(s[", node_idx, "][j] for j = [0 .. ", s[node_idx].size(), ")) = 1")); - absl::flat_hash_set visited_vars; for (NodeStrategyIdx j = 0; j < s[node_idx].size(); ++j) { - if (visited_vars.contains(s[node_idx][j])) continue; - visited_vars.insert(s[node_idx][j]); constraint->SetCoefficient(s[node_idx][j], 1.0); } } @@ -477,10 +460,7 @@ AutoShardingSolverResult CallORToolsSolver( absl::StrCat("mem[", time_idx, "]")); if (overbudget_var) constraint->SetCoefficient(overbudget_var, -1.0); for (NodeIdx node_idx : request.live(time_idx).nodes()) { - absl::flat_hash_set visited_vars; for (NodeStrategyIdx j = 0; j < s[node_idx].size(); ++j) { - if (visited_vars.contains(s[node_idx][j])) continue; - visited_vars.insert(s[node_idx][j]); const double accumulated_coefficient = constraint->GetCoefficient(s[node_idx][j]); const double memory_cost = request.memory_costs(node_idx).costs(j); @@ -490,10 +470,7 @@ AutoShardingSolverResult CallORToolsSolver( } if (request.live_edges().empty()) continue; for (EdgeIdx edge_idx : request.live_edges(time_idx).edges()) { - absl::flat_hash_set visited_vars; for (EdgeStrategyIdx j = 0; j < e[edge_idx].size(); ++j) { - if (visited_vars.contains(e[edge_idx][j])) continue; - visited_vars.insert(e[edge_idx][j]); const double accumulated_coefficient = constraint->GetCoefficient(e[edge_idx][j]); const double memory_cost = @@ -522,10 +499,7 @@ AutoShardingSolverResult CallORToolsSolver( MPConstraint* constraint = solver->MakeRowConstraint( 1.0, 1.0, absl::StrCat("sum(e[", edge.first(), "][", edge.second(), "][*]) = 1")); - absl::flat_hash_set visited_vars; for (EdgeStrategyIdx j = 0; j < e[edge_idx].size(); ++j) { - if (visited_vars.contains(e[edge_idx][j])) continue; - visited_vars.insert(e[edge_idx][j]); constraint->SetCoefficient(e[edge_idx][j], 1.0); } } @@ -538,11 +512,8 @@ AutoShardingSolverResult CallORToolsSolver( -MPSolver::infinity(), 0, absl::StrCat("f for i = ", edge_idx, ", p = ", p)); constraint->SetCoefficient(s[edge.first()][p], -1.0); - absl::flat_hash_set visited_vars; for (NodeStrategyIdx q = 0; q < s[edge.second()].size(); ++q) { const EdgeStrategyIdx j = p * s[edge.second()].size() + q; - if (visited_vars.contains(e[edge_idx][j])) continue; - visited_vars.insert(e[edge_idx][j]); constraint->SetCoefficient(e[edge_idx][j], 1.0); } } @@ -556,11 +527,8 @@ AutoShardingSolverResult CallORToolsSolver( -MPSolver::infinity(), 0, absl::StrCat("g for i = ", edge_idx, ", q = ", q)); constraint->SetCoefficient(s[edge.second()][q], -1.0); - absl::flat_hash_set visited_vars; for (NodeStrategyIdx p = 0; p < s[edge.first()].size(); ++p) { const EdgeStrategyIdx j = p * s[edge.second()].size() + q; - if (visited_vars.contains(e[edge_idx][j])) continue; - visited_vars.insert(e[edge_idx][j]); constraint->SetCoefficient(e[edge_idx][j], 1.0); } } @@ -603,10 +571,10 @@ AutoShardingSolverResult CallORToolsSolver( } } if (request.has_max_cost()) { + double max_cost = kMaxCostEpsilon * request.max_cost().coeff(); + max_cost -= solver->Objective().offset(); MPConstraint* cost_constraint = solver->MakeRowConstraint( - -MPSolver::infinity(), - request.max_cost().coeff() - solver->Objective().offset(), - "cost_constraint"); + -MPSolver::infinity(), max_cost, "cost_constraint"); for (const auto [var, coeff] : solver->Objective().terms()) { cost_constraint->SetCoefficient(var, coeff); } @@ -616,10 +584,7 @@ AutoShardingSolverResult CallORToolsSolver( std::vector> hint; for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { if (request.s_follow(node_idx) >= 0) continue; - absl::flat_hash_set visited_vars; for (NodeStrategyIdx j = 0; j < s[node_idx].size(); ++j) { - if (visited_vars.contains(s[node_idx][j])) continue; - visited_vars.insert(s[node_idx][j]); double hint_val = (request.s_hint(node_idx) == j) ? 1.0 : 0.0; hint.push_back({s[node_idx][j], hint_val}); } @@ -674,6 +639,7 @@ AutoShardingSolverResult CallORToolsSolver( << "Unique nodes: " << unique_nodes << "\n" << "Unique edges: " << unique_edges << "\n" << "Total instructions: " << request.num_nodes() << "\n" + << "Total edges: " << request.edges_size() << "\n" << "Memory budget: " << request.memory_budget() / (1024 * 1024 * 1024) << "GB\n" << "Number variables for ILP: " << solver->NumVariables() << "\n" diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h index ba3a1c06498e48..4315dda0f28d70 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h @@ -33,13 +33,13 @@ namespace spmd { struct AutoShardingSolverResult { public: AutoShardingSolverResult( - StatusOr, - std::vector, double>> + absl::StatusOr, + std::vector, double>> status, bool skip_auto_sharding) : status(status), skip_auto_sharding(skip_auto_sharding) {} bool operator==(const AutoShardingSolverResult& other) const; - StatusOr, std::vector, double>> + absl::StatusOr, std::vector, double>> status; bool skip_auto_sharding; }; @@ -110,19 +110,31 @@ double EvaluateMakespan(const AutoShardingSolverRequest& request, AutoShardingSolverRequest ScaleRequest( const AutoShardingSolverRequest& request); -// Determines if two strategies are equivalent (i.e., share identical node -// costs, edge costs, and alias mappings). -bool CheckEquivalent(const AutoShardingSolverRequest& request, - const std::vector& src_edges, - const std::vector& dst_edges, - const std::vector& src_aliases, - const std::vector& dst_aliases, NodeIdx node_idx, - NodeStrategyIdx first, NodeStrategyIdx second); - -// For every node, examine each sharding strategy to see if it is equivalent to -// another (which, if so, would allow the reusing of strategy variables). -std::vector> StratFollow( - const AutoShardingSolverRequest& request); +// Determines if strategy 'first' is dominated by strategy 'second' (i.e., its +// costs are all equal or worse, and it has identical alias mappings). +bool CheckDominance(const AutoShardingSolverRequest& request, + const std::vector& src_edges, + const std::vector& dst_edges, + const std::vector& src_aliases, + const std::vector& dst_aliases, NodeIdx node_idx, + NodeStrategyIdx first, NodeStrategyIdx second); + +class StrategyShaver { + public: + explicit StrategyShaver(const AutoShardingSolverRequest& request); + + // For every node, examine each sharding strategy to see if it is dominated by + // another. + NodeStrategies FindShavedStrategies() const; + + private: + const AutoShardingSolverRequest& request_; // NOLINT + std::vector> src_edge_map_; + std::vector> dst_edge_map_; + std::vector> src_alias_map_; + std::vector> dst_alias_map_; + std::vector> followers_; +}; } // namespace spmd } // namespace xla diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_impl.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_impl.cc index 2120bbac5bd586..4be54f98a0a496 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_impl.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_impl.cc @@ -15,7 +15,6 @@ limitations under the License. #include -#include "absl/log/check.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding.pb.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_solver.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" @@ -39,15 +38,11 @@ double EvaluateMakespan(const AutoShardingSolverRequest& request, return 0.0; // TODO(moffitt): Implement this. } -std::vector> StratFollow( - const AutoShardingSolverRequest& request) { - CHECK_EQ(request.num_nodes(), request.s_len_size()); - std::vector> strat_follow(request.num_nodes()); - for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { - if (request.s_follow(node_idx) >= 0) continue; - strat_follow[node_idx].resize(request.s_len(node_idx), -1); - } - return strat_follow; +StrategyShaver::StrategyShaver(const AutoShardingSolverRequest& request) + : request_(request) {} + +NodeStrategies StrategyShaver::FindShavedStrategies() const { + return {}; // TODO(moffitt): Implement this. } } // namespace spmd diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc index 99b9a0b34b680a..62bc888e69943e 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc @@ -67,7 +67,7 @@ bool LeafVectorsAreConsistent(const std::vector& one, // NOLINTBEGIN(readability/fn_size) // TODO(zhuohan): Decompose this function into smaller pieces -StatusOr> +absl::StatusOr> BuildStrategyAndCost(const HloInstructionSequence& sequence, const HloModule* module, const absl::flat_hash_map& @@ -101,6 +101,8 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, max_depth = std::max(max_depth, iter.second); } + absl::flat_hash_map + while_body_args_to_input_tuple; // Register strategies and their costs for each instruction. for (size_t instruction_id = 0; instruction_id < instructions.size(); ++instruction_id) { @@ -126,8 +128,45 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, only_allow_divisible = option.only_allow_divisible_intermediate; } + bool is_follow_necessary_for_correctness = false; switch (opcode) { - case HloOpcode::kParameter: + case HloOpcode::kParameter: { + auto it = while_body_args_to_input_tuple.find(ins); + if (it != while_body_args_to_input_tuple.end()) { + const HloInstruction* while_input_tuple = it->second; + const StrategyGroup* while_input_tuple_strategy_group = + strategy_map.at(while_input_tuple).get(); + + VLOG(5) << "Following while input " << while_input_tuple->name(); + strategy_group = CreateTupleStrategyGroup(instruction_id); + strategy_group->childs.reserve(ins->shape().tuple_shapes_size()); + // We use this following relationship to ensure that the input tuple + // of the while loop, and the parameter of the body of that while + // loop. Therefore, this followinf relationship is necessary for + // correctness, and is not merely an optmization. + is_follow_necessary_for_correctness = true; + for (size_t i = 0; i < ins->shape().tuple_shapes_size(); ++i) { + std::unique_ptr child_strategies = + MaybeFollowInsStrategyGroup( + while_input_tuple_strategy_group->childs[i].get(), + ins->shape().tuple_shapes().at(i), instruction_id, + /* have_memory_cost= */ true, strategy_groups, cluster_env, + pretrimmed_strategy_map); + child_strategies->tuple_element_idx = i; + strategy_group->childs.push_back(std::move(child_strategies)); + } + } else { + strategy_group = + CreateAllStrategiesGroup( + ins, ins->shape(), instruction_id, strategy_groups, + cluster_env, strategy_map, option, replicated_penalty, + batch_dim_map, call_graph, only_allow_divisible, + option.allow_replicated_parameters, + /* create_partially_replicated_strategies */ true) + .value(); + } + break; + } case HloOpcode::kRngBitGenerator: case HloOpcode::kRng: { strategy_group = @@ -345,8 +384,17 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, // Find output shardings. switch (opcode) { + case HloOpcode::kSlice: { + bool is_1d_sharding = + VectorGreaterThanOneElementCount( + input_spec.tile_assignment().dimensions()) == 1; + output_spec = PropagateDimwiseShardingSlice( + input_spec, operand->shape(), ins->shape(), + is_1d_sharding ? cluster_env.device_mesh_1d_ + : cluster_env.device_mesh_); + break; + } case HloOpcode::kPad: - case HloOpcode::kSlice: case HloOpcode::kConcatenate: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: @@ -434,6 +482,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, case HloOpcode::kBitcastConvert: case HloOpcode::kCopy: case HloOpcode::kCos: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: @@ -495,10 +544,12 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, break; } case HloOpcode::kDot: { - TF_RETURN_IF_ERROR(HandleDot( - strategy_group, strategy_groups, strategy_map, ins, instruction_id, - cluster_env, batch_dim_map, option, call_graph)); - if (option.allow_replicated_strategy_for_dot_and_conv) { + TF_RETURN_IF_ERROR(HandleDot(strategy_group, strategy_groups, + strategy_map, ins, instruction_id, + sequence, hlo_cost_analysis, cluster_env, + batch_dim_map, option, call_graph)); + + if (option.allow_recompute_heavy_op) { AddReplicatedStrategy( ins, ins->shape(), cluster_env, strategy_map, strategy_group, GetDotConvReplicationPenalty(ins, instruction_id, /* window */ 10, @@ -507,10 +558,11 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, break; } case HloOpcode::kConvolution: { - TF_RETURN_IF_ERROR(HandleConv( - strategy_group, strategy_groups, strategy_map, ins, instruction_id, - cluster_env, batch_dim_map, option, call_graph)); - if (option.allow_replicated_strategy_for_dot_and_conv) { + TF_RETURN_IF_ERROR(HandleConv(strategy_group, strategy_groups, + strategy_map, ins, instruction_id, + sequence, hlo_cost_analysis, cluster_env, + batch_dim_map, option, call_graph)); + if (option.allow_recompute_heavy_op) { AddReplicatedStrategy( ins, ins->shape(), cluster_env, strategy_map, strategy_group, GetDotConvReplicationPenalty(ins, instruction_id, /* window */ 10, @@ -553,6 +605,16 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, child_strategies->tuple_element_idx = i; strategy_group->childs.push_back(std::move(child_strategies)); } + + if (ins->users().size() == 1 && + ins->users()[0]->opcode() == HloOpcode::kWhile) { + const HloInstruction* while_op = ins->users()[0]; + while_body_args_to_input_tuple[while_op->while_body() + ->parameter_instruction(0)] = ins; + while_body_args_to_input_tuple[while_op->while_condition() + ->parameter_instruction(0)] = ins; + } + break; } case HloOpcode::kGetTupleElement: { @@ -677,6 +739,27 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, strategy_group, replicated_penalty); break; } + case HloOpcode::kSend: { + strategy_group = CreateTupleStrategyGroup(instruction_id); + strategy_group->childs.reserve(ins->shape().tuple_shapes_size()); + for (size_t i = 0; i < ins->shape().tuple_shapes_size(); ++i) { + std::unique_ptr child_strategies = + CreateLeafStrategyGroup(instruction_id, ins, strategy_map, + strategy_groups); + AddReplicatedStrategy(ins, ins->shape().tuple_shapes(i), cluster_env, + strategy_map, child_strategies, 0); + child_strategies->tuple_element_idx = i; + strategy_group->childs.push_back(std::move(child_strategies)); + } + break; + } + case HloOpcode::kSendDone: { + strategy_group = CreateLeafStrategyGroup(instruction_id, ins, + strategy_map, strategy_groups); + AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, + strategy_group, 0); + break; + } case HloOpcode::kAfterAll: { strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, strategy_groups); @@ -700,10 +783,12 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, if (!LeafVectorsAreConsistent(strategy_group->strategies, strategy_group->following->strategies)) { // It confuses the solver if two instructions have different number of - // sharding strategies but share the same ILP variable. The solver - // would run much longer and/or return infeasible solutions. - // So if two strategies' strategiess are inconsistent, we unfollow - // them. + // sharding strategies but share the same ILP variable. The solver would + // run much longer and/or return infeasible solutions. So if two + // strategies are inconsistent, we unfollow them. + CHECK(!is_follow_necessary_for_correctness) + << "Reverting a following decision that is necessary for " + "correctness. Please report this as a bug."; strategy_group->following = nullptr; } } else if (strategy_group->is_tuple) { @@ -712,6 +797,9 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, !LeafVectorsAreConsistent( strategy_group->childs.at(i)->strategies, strategy_group->childs.at(i)->following->strategies)) { + CHECK(!is_follow_necessary_for_correctness) + << "Reverting a following decision that is necessary for " + "correctness. Please report this as a bug."; strategy_group->childs.at(i)->following = nullptr; } } diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h index 8a72b08dd7f1f4..7b8b991939d2c2 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h @@ -130,6 +130,10 @@ using EdgeStrategyIdx = int64_t; // An index into an edge's strategy vector. using LivenessIdx = int64_t; // An index into the liveness vector. using AliasIdx = int64_t; // An index into the alias vector. +// Various classes needed to support strategy shaving. +using NodeStrategy = std::pair; +using NodeStrategies = StableHashSet; + // A group of strategy choices (along with details like index values) // for each instruction. struct StrategyGroup { diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc index 5611d25c32281f..62209f6bbb5179 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc @@ -59,7 +59,7 @@ using ::testing::UnorderedElementsAre; using DummyAutoShardingTest = HloTestBase; TEST_F(DummyAutoShardingTest, ReplicatedShardingDummy) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module ENTRY %elementwise { %param0 = f32[5,7,11,13]{3,2,1,0} parameter(0) @@ -79,14 +79,14 @@ ENTRY %elementwise { class AutoShardingTest : public HloTestBase { protected: - const char* const dot_hlo_string_ = R"( + absl::string_view dot_hlo_string_ = R"( HloModule module ENTRY matmul { parameter.1 = f32[32,64]{1,0} parameter(0) parameter.2 = f32[64,128]{1,0} parameter(1) ROOT root = f32[32,128]{1,0} dot(parameter.1, parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} })"; - const char* const add_hlo_string_ = R"( + absl::string_view add_hlo_string_ = R"( HloModule module ENTRY %elementwise { %param0 = f32[16,32,64]{2,1,0} parameter(0) @@ -163,7 +163,7 @@ ENTRY %elementwise { }; TEST_F(AutoShardingTest, DISABLED_ElementWiseOperator) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module ENTRY %elementwise { %param0 = f32[128,128]{0,1} parameter(0) @@ -189,7 +189,7 @@ ENTRY %elementwise { } TEST_F(AutoShardingTest, NDIterativeSolveTest) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module ENTRY %elementwise { @@ -217,8 +217,38 @@ ENTRY %elementwise { EXPECT_THAT(slice, op::Sharding("{devices=[256,1]<=[256]}")); } +TEST_F(AutoShardingTest, SliceDeviceMeshTest) { + constexpr absl::string_view hlo_string = R"( +HloModule module + +ENTRY %elementwise { + param = s32[512,3084]{1,0} parameter(0) + slice = s32[512,2048]{1,0} slice(param), slice={[0:512], [0:2048]} + ROOT copy = s32[512,2048]{1,0} copy(slice) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + AutoSharding(/* option */ {.enable = true, + .solve_nd_sharding_iteratively = true, + .device_mesh_shape = {2, 2}, + .device_mesh_alpha = {1.0, 1.0}, + .device_mesh_beta = {0.01, 1.0}}) + .Run(module.get())); + VLOG(10) << module->ToString(); + EXPECT_TRUE(changed); + const HloInstruction* slice = FindInstruction(module.get(), "slice"); + ASSERT_NE(slice, nullptr); + EXPECT_THAT( + slice, + AnyOf(op::Sharding("{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}"), + op::Sharding("{devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}"))); +} + TEST_F(AutoShardingTest, RngBitGeneratorArrayInput) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule rng_bit_generator ENTRY %RngBitGenerator (p0: u64[2]) -> (u64[2], u32[16,16]) { @@ -243,7 +273,7 @@ ENTRY %RngBitGenerator (p0: u64[2]) -> (u64[2], u32[16,16]) { } TEST_F(AutoShardingTest, RngBitGeneratorTupleInput) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule rng_bit_generator ENTRY %RngBitGenerator { @@ -273,7 +303,7 @@ ENTRY %RngBitGenerator { } TEST_F(AutoShardingTest, DotLHSTwoNonContractingDims) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module ENTRY %entry { %param0 = f32[4,256,64]{2,1,0} parameter(0) @@ -325,7 +355,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, DotRHSTwoNonContractingDims) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module ENTRY %entry { %param0 = f32[4,256,32]{2,1,0} parameter(0) @@ -377,7 +407,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, DotTwoContractingDims) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module ENTRY %entry { %param0 = f32[4,256,64]{2,1,0} parameter(0) @@ -418,7 +448,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, TwoMatmul) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module ENTRY twomatmul { parameter.1 = f32[64,64]{1,0} parameter(0) @@ -432,7 +462,7 @@ ENTRY twomatmul { ParseAndReturnVerifiedModule(hlo_string)); AutoShardingOption option; option.enable = true; - option.allow_replicated_strategy_for_dot_and_conv = false; + option.allow_recompute_heavy_op = false; option.device_mesh_shape = {2, 2}; option.device_mesh_ids = {0, 1, 2, 3}; option.device_mesh_alpha = {1.0, 1.0}; @@ -463,7 +493,7 @@ ENTRY twomatmul { // Test with replicated strategies on for dot TF_ASSERT_OK_AND_ASSIGN(module, ParseAndReturnVerifiedModule(hlo_string)); option.enable = true; - option.allow_replicated_strategy_for_dot_and_conv = true; + option.allow_recompute_heavy_op = true; option.device_mesh_shape = {2, 2}; option.device_mesh_ids = {0, 1, 2, 3}; option.device_mesh_alpha = {1.0, 1.0}; @@ -497,7 +527,7 @@ ENTRY twomatmul { } TEST_F(AutoShardingTest, ProcessCustomCallShardings) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module ENTRY %entry { @@ -527,7 +557,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, SaveAndRemoveShardingAnnotationKeepAll) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { @@ -559,7 +589,7 @@ ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { IsTrue()))); auto verified_parse_sharding = [](const absl::string_view sharding_str) { - StatusOr sharding = ParseSharding(sharding_str); + absl::StatusOr sharding = ParseSharding(sharding_str); CHECK_OK(sharding); return *sharding; }; @@ -581,7 +611,7 @@ ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { TEST_F(AutoShardingTest, SaveAndRemoveShardingAnnotationKeepInputOutputSmallTensor) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { @@ -613,7 +643,7 @@ ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { IsTrue()))); auto verified_parse_sharding = [](const absl::string_view sharding_str) { - StatusOr sharding = ParseSharding(sharding_str); + absl::StatusOr sharding = ParseSharding(sharding_str); CHECK_OK(sharding); return *sharding; }; @@ -631,7 +661,7 @@ ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { } TEST_F(AutoShardingTest, SaveAndRemoveShardingAnnotationKeepInputOutput) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { @@ -709,7 +739,7 @@ ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { } TEST_F(AutoShardingTest, SaveAndRemoveShardingAnnotationRemoveAll) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { @@ -746,7 +776,7 @@ ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { } TEST_F(AutoShardingTest, SaveAndRemoveShardingAnnotationRemoveAllSmallTensor) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { @@ -799,7 +829,7 @@ ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { } TEST_F(AutoShardingTest, TupleReduceTest) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module %func (lhs_value: f32[], lhs_index: s32[], rhs_value: f32[], rhs_index: s32[]) -> (f32[], s32[]) { %lhs_value = f32[] parameter(0) @@ -845,7 +875,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, ReduceTest) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module %func (x: f32[], y: f32[]) -> f32[] { @@ -888,7 +918,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, ScatterTest2D) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module region { @@ -927,7 +957,7 @@ ENTRY %Scatter { } TEST_F(AutoShardingTest, ScatterTest3D) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module region { @@ -970,7 +1000,7 @@ ENTRY %Scatter { } TEST_F(AutoShardingTest, GatherTest) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module ENTRY %entry { %param0 = f32[256,1024]{0,1} parameter(0) @@ -1001,7 +1031,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, GatherTestNoReshard) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module ENTRY %entry { get-tuple-element = s8[1000,128]{1,0} parameter(0) @@ -1032,7 +1062,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, GatherConvTest) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module ENTRY %entry { %param0 = f32[1024,1024]{0,1} parameter(0) @@ -1331,7 +1361,7 @@ TEST_F(AutoShardingTest, InvalidOptions) { TEST_F(AutoShardingTest, AutoShardingKeepUserShardingInputOutput) { // An HLO Module with sharding for all instructions. - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { @@ -1363,7 +1393,7 @@ ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { TEST_F(AutoShardingTest, AutoShardingKeepUserShardingAdd) { // An HLO Module with sharding for all instructions. - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module ENTRY %elementwise { %param0 = f32[128,128]{0,1} parameter(0) @@ -1397,7 +1427,7 @@ ENTRY %elementwise { TEST_F(AutoShardingTest, AutoShardingKeepUserShardingDot) { // An HLO Module with sharding for all instructions. - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { @@ -1443,7 +1473,7 @@ ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { } TEST_F(AutoShardingTest, DISABLED_AutoShardingKeepUserShardingTupleReduce) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module %func (lhs_value: f32[], lhs_index: s32[], rhs_value: f32[], rhs_index: s32[]) -> (f32[], s32[]) { %lhs_value = f32[] parameter(0) @@ -1491,7 +1521,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, DISABLED_TupleParameter) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module ENTRY %tupleparameter { %tuple_param = (f32[16,32,64]{2,1,0}, f32[16,32,64]{2,1,0}) parameter(0) @@ -1521,7 +1551,7 @@ ENTRY %tupleparameter { // CRASHES TEST_F(AutoShardingTest, DISABLED_GetTupleElementWithUserShardingTest) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module %while_cond { @@ -1568,7 +1598,7 @@ ENTRY %entry (param0: f32[16,256,256], param1: f32[16,256,256]) -> f32[16,256,25 } TEST_F(AutoShardingTest, While) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module %cond { @@ -1648,7 +1678,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, DynamicSlice) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module ENTRY %entry { %param0 = s32[] parameter(0) @@ -1677,7 +1707,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, Alias) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module, input_output_alias={ {0}: (0, {}, may-alias), {1}: (1, {}, may-alias), {2}: (2, {}, may-alias), {3}: (3, {}, may-alias)} ENTRY %entry { @@ -1703,7 +1733,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, AliasTupleParameter) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module, input_output_alias={ {0}: (0, {0}, may-alias), {1}: (0, {1}, may-alias), {2}: (0, {2}, may-alias), {3}: (0, {3}, may-alias)} ENTRY %entry { @@ -1730,7 +1760,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, JaxRandomUniform) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module clone { lhs.1 = u32[] parameter(0) @@ -1781,7 +1811,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, Reshape) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module ENTRY %entry { @@ -1809,7 +1839,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, ReshapeWithInvalidUserSharding) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module ENTRY %entry { @@ -1835,7 +1865,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, Broadcast) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module ENTRY %entry { @@ -1854,7 +1884,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, TestReshardingCostsForUserAnnotatedSharding) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module ENTRY %entry { @@ -1880,7 +1910,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, AllowAliasToFollowerConversion) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module, input_output_alias={ {0}: (0, {}, may-alias), {1}: (1, {}, may-alias), {2}: (2, {}, may-alias), {3}: (3, {}, may-alias)} ENTRY %entry { @@ -1907,7 +1937,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, DisallowAliasToFollowerConversion) { - const char* const hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule module, input_output_alias={ {0}: (0, {}, may-alias), {1}: (1, {}, may-alias), {2}: (2, {}, may-alias), {3}: (3, {}, may-alias)} ENTRY %entry { diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc index 112f07979b187a..c32932b727521a 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc @@ -152,6 +152,32 @@ std::optional PropagateDimwiseSharding( return input_spec; } +HloSharding PropagateDimwiseShardingSlice(const HloSharding& input_spec, + const Shape& old_shape, + const Shape& new_shape, + const Array& device_mesh) { + if (input_spec.IsReplicated()) { + return input_spec; + } + + CHECK(old_shape.IsArray()); + + std::vector tensor_to_mesh_dim = + GetTensorDimToMeshDim(new_shape.rank(), input_spec, device_mesh, + /* consider_reverse_device_meshes */ false); + + std::vector tensor_dims; + std::vector mesh_dims; + for (size_t i = 0; i < new_shape.rank(); ++i) { + if (new_shape.dimensions(i) == old_shape.dimensions(i) && + tensor_to_mesh_dim[i] > -1) { + tensor_dims.push_back(i); + mesh_dims.push_back(tensor_to_mesh_dim[i]); + } + } + return Tile(new_shape, tensor_dims, mesh_dims, device_mesh); +} + // Propagate sharding for ReduceWindow-like operations. // The sharding can successfully propagate if the window operation only happens // on tensor dimensions that are not tiled. @@ -356,6 +382,7 @@ void BatchDimMapForward(const std::vector& instructions, case HloOpcode::kBitcastConvert: case HloOpcode::kCopy: case HloOpcode::kCos: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: @@ -615,6 +642,7 @@ void BatchDimMapBackward(const std::vector& instructions, case HloOpcode::kBitcastConvert: case HloOpcode::kCopy: case HloOpcode::kCos: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: @@ -1046,8 +1074,8 @@ void UseAllReduceForGradAcc(StableHashSet& replicated_set, // dimensions at 0, e.g., array is 2D and dim = 1, this returns array[0, 1], // array[1, 1], array [2, 1], .... // Returns error status if dim >= array.num_dimensions(). -StatusOr> GetValuesAlongOneDim(const Array& array, - int dim) { +absl::StatusOr> GetValuesAlongOneDim( + const Array& array, int dim) { if (dim >= array.num_dimensions()) { return absl::OutOfRangeError(absl::StrCat( "Input dim (", dim, @@ -1064,7 +1092,8 @@ StatusOr> GetValuesAlongOneDim(const Array& array, } // Check whether a sequence is an arithmetic sequence. -StatusOr CheckArithmeticSequence(absl::Span sequence) { +absl::StatusOr CheckArithmeticSequence( + absl::Span sequence) { if (sequence.size() < 2) { return absl::OutOfRangeError( "Invalid device id assignment: sequence.size() < 2"); @@ -1427,7 +1456,8 @@ void FixMixedMeshShapeResharding(HloInstruction* inst, int operand_num, const Array& device_mesh, ReshardingCache* resharding_cache) { HloInstruction* operand = inst->mutable_operand(operand_num); - if (operand->opcode() == HloOpcode::kOutfeed) { + if (operand->opcode() == HloOpcode::kOutfeed || + operand->opcode() == HloOpcode::kSendDone) { return; } @@ -1556,7 +1586,6 @@ HloSharding Tile(const Shape& tensor_shape, tile_assignment_dimensions[tensor_dims[i]] = device_mesh.dim(mesh_dims[i]); split_prod *= device_mesh.dim(mesh_dims[i]); } - // Replicate on remaining mesh dimensions bool replicate_on_last_tile_dim = false; if (split_prod < device_mesh.num_elements()) { @@ -1733,19 +1762,14 @@ AliasSet BuildAliasSet(const HloModule* module, for (const HloComputation* computation : module->computations()) { for (const HloInstruction* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kWhile) { + // Aliasing between the while op, and the parameters of its body and + // conditional computations is handled by making the latter follow the + // input tuple to thew while loop in the function + // BuildStrategyAndCost(). traverse_tuple_alias( strategy_map.at(instruction).get(), strategy_map.at(instruction->while_body()->root_instruction()) .get()); - traverse_tuple_alias( - strategy_map.at(instruction).get(), - strategy_map.at(instruction->while_body()->parameter_instruction(0)) - .get()); - traverse_tuple_alias( - strategy_map.at(instruction).get(), - strategy_map - .at(instruction->while_condition()->parameter_instruction(0)) - .get()); } else if (instruction->opcode() == HloOpcode::kConditional) { auto branch_computations = instruction->branch_computations(); for (size_t i = 0; i < branch_computations.size(); ++i) { @@ -1798,8 +1822,7 @@ Status CheckAliasSetCompatibility(const AliasSet& alias_set, "tensors and may result in large memory consumption: " << "(" << instructions.at(src_strategy_group->instruction_id)->name() << ", " << instructions.at(dst_strategy_group->instruction_id)->name() - << ")" - << "\n" + << ")" << "\n" << "(" << src_strategy_group->node_idx << ", " << dst_strategy_group->node_idx << ")\n" << src_strategy_group->ToString() << "\n" @@ -1941,7 +1964,7 @@ double ReshardingCostMixedMeshShape( return resharding_costs; } -StatusOr> +absl::StatusOr> AdjustShardingWithPartialMeshShapePerElement( const HloSharding& sharding, const absl::flat_hash_set& valid_shards, int64_t total_num_devices, @@ -2031,7 +2054,7 @@ AdjustShardingWithPartialMeshShapePerElement( return std::nullopt; } -StatusOr AdjustShardingsWithPartialMeshShape( +absl::StatusOr AdjustShardingsWithPartialMeshShape( const std::vector& instructions, const std::vector& mesh_shape, int64_t total_num_devices, bool crash_on_error) { @@ -2052,7 +2075,7 @@ StatusOr AdjustShardingsWithPartialMeshShape( for (size_t i = 0; i < inst->shape().tuple_shapes_size(); i++) { auto shape = inst->shape().tuple_shapes(i); auto sharding = inst->sharding().tuple_elements()[i]; - StatusOr> new_sharding_result = + absl::StatusOr> new_sharding_result = AdjustShardingWithPartialMeshShapePerElement( sharding, valid_shards, total_num_devices, crash_on_error); if (new_sharding_result.ok()) { @@ -2071,7 +2094,7 @@ StatusOr AdjustShardingsWithPartialMeshShape( } inst->set_sharding(HloSharding::Tuple(output_tuple_sharding)); } else { - StatusOr> sharding_result = + absl::StatusOr> sharding_result = AdjustShardingWithPartialMeshShapePerElement( inst->sharding(), valid_shards, total_num_devices, crash_on_error); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h index 4162c5d0d1c78b..67403cb7f0a401 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h @@ -428,6 +428,11 @@ std::optional PropagateDimwiseSharding( const HloSharding& input_spec, const Shape& old_shape, const Shape& new_shape); +HloSharding PropagateDimwiseShardingSlice(const HloSharding& input_spec, + const Shape& old_shape, + const Shape& new_shape, + const Array& device_mesh); + // Propagate sharding for ReduceWindow-like operations. // The sharding can successfully propagate if the window operation only happens // on tensor dimensions that are not tiled. @@ -528,10 +533,11 @@ std::vector> GetReplicaGroupsAlongOneDimension( // dimensions at 0, e.g., array is 2D and dim = 1, this returns array[0, 1], // array[1, 1], array [2, 1], .... // Returns error status if dim >= array.num_dimensions(). -StatusOr> GetValuesAlongOneDim(const Array& array, - int dim); +absl::StatusOr> GetValuesAlongOneDim( + const Array& array, int dim); -StatusOr CheckArithmeticSequence(absl::Span sequence); +absl::StatusOr CheckArithmeticSequence( + absl::Span sequence); // Checks if the number of sharded dimensions in the tile assignment matches the // device mesh. @@ -617,7 +623,7 @@ double ReshardingCostMixedMeshShape( // If a sharding is [8, 4] for the complete mesh shape, we convert it to [8, 1] // given [1, 8, 1] as the partial mesh shape. // total_num_devices should equal to the product of mesh_shape elements. -StatusOr AdjustShardingsWithPartialMeshShape( +absl::StatusOr AdjustShardingsWithPartialMeshShape( const std::vector& instructions, const std::vector& mesh_shape, int64_t total_num_devices, bool crash_on_error); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.cc index 906068c9c51bc3..a5e873059726f6 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.cc @@ -120,19 +120,6 @@ double ClusterEnvironment::AllToAllCost(double num_bytes, int mesh_dim) const { mesh_beta_); } -double ClusterEnvironment::DotCost(const Shape& lhs_shape, - const Shape& rhs_shape) const { - if (!auto_sharding_option_.allow_recompute_heavy_op) { - return kInfinityCost; - } - - // TODO(zhuohan): When profiling data is not available, it is not easy to - // align the scale of compute cost and communication cost. Here we just use - // a simple heuristic to compute the compute cost with communication cost. - double num_bytes = GetBytes(lhs_shape) + GetBytes(rhs_shape); - return AllReduceCost(num_bytes, 0) + AllReduceCost(num_bytes, 1); -} - double ClusterEnvironment::CollectivePermuteCost( double num_bytes, const std::vector>& src_dst_pairs) const { diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.h b/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.h index 47db1ab7a1918f..0a18ec6d884279 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.h @@ -50,7 +50,7 @@ class ClusterEnvironment { mesh_beta_(mesh_beta.begin(), mesh_beta.end()), prof_result_(prof_result), total_devices_(device_mesh.num_elements()), - device_mesh_1d_(original_device_mesh), + device_mesh_1d_(device_mesh), auto_sharding_option_(auto_sharding_option) { // Build replica group for each dimension. non_zero_mesh_dims_ = @@ -68,10 +68,8 @@ class ClusterEnvironment { original_device_mesh_shape.end()); size_t largest_dim_idx = std::distance(original_device_mesh_shape.begin(), max_dim_iterator); - - std::vector device_mesh_1d_shape( - original_device_mesh.num_dimensions(), 1); - device_mesh_1d_shape[largest_dim_idx] = original_device_mesh.num_elements(); + std::vector device_mesh_1d_shape(device_mesh.num_dimensions(), 1); + device_mesh_1d_shape[largest_dim_idx] = device_mesh.num_elements(); device_mesh_1d_.Reshape(device_mesh_1d_shape); } @@ -133,8 +131,6 @@ class ClusterEnvironment { const HloSharding& src_spec, const HloSharding& dst_spec) const; - double DotCost(const Shape& lhs_shape, const Shape& rhs_shape) const; - // This function attempts to overestimate the cost of replicating a tensor of // shape `shape` sharded according to `src_spec`. double OverestimateReplicationCost(const Shape& shape, diff --git a/third_party/xla/xla/hlo/ir/BUILD b/third_party/xla/xla/hlo/ir/BUILD index 537cc657443174..bbd9e17120c7ec 100644 --- a/third_party/xla/xla/hlo/ir/BUILD +++ b/third_party/xla/xla/hlo/ir/BUILD @@ -1,10 +1,12 @@ # Description: # XLA’s HLO Intermediate Representation implementation. +load("@local_tsl//tsl:tsl.bzl", "internal_visibility") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) @@ -53,7 +55,6 @@ cc_library( "hlo_sharding.h", "hlo_sharding_metadata.h", ], - visibility = ["//visibility:public"], deps = [ ":ptrvec", ":tile_assignment", @@ -111,7 +112,6 @@ cc_library( name = "hlo_module_group", srcs = ["hlo_module_group.cc"], hdrs = ["hlo_module_group.h"], - visibility = ["//visibility:public"], deps = [ ":hlo", "//xla/service:hlo_proto_cc", @@ -124,7 +124,6 @@ cc_library( name = "hlo_reachability", srcs = ["hlo_reachability.cc"], hdrs = ["hlo_reachability.h"], - visibility = ["//visibility:public"], deps = [ ":hlo", "//xla:types", @@ -137,7 +136,6 @@ cc_library( cc_library( name = "ptrvec", hdrs = ["ptrvec.h"], - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/log:check", "@local_tsl//tsl/platform:logging", @@ -162,7 +160,6 @@ cc_library( name = "tile_assignment", srcs = ["tile_assignment.cc"], hdrs = ["tile_assignment.h"], - visibility = ["//visibility:public"], deps = [ "//xla:array", "//xla:printer", diff --git a/third_party/xla/xla/hlo/ir/dfs_hlo_visitor.h b/third_party/xla/xla/hlo/ir/dfs_hlo_visitor.h index 81654d524b8c10..01f825dc6ff97e 100644 --- a/third_party/xla/xla/hlo/ir/dfs_hlo_visitor.h +++ b/third_party/xla/xla/hlo/ir/dfs_hlo_visitor.h @@ -175,6 +175,9 @@ class DfsHloVisitorBase { virtual Status HandleRoundNearestEven(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } + virtual Status HandleErf(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } virtual Status HandleLogistic(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.cc b/third_party/xla/xla/hlo/ir/hlo_computation.cc index d84efb9a2e1bd2..1891617fc6addd 100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.cc +++ b/third_party/xla/xla/hlo/ir/hlo_computation.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -44,6 +43,7 @@ limitations under the License. #include "xla/map_util.h" #include "xla/printer.h" #include "xla/service/mapped_ptr_container_sorter.h" +#include "xla/service/name_uniquer.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/util.h" @@ -115,19 +115,9 @@ HloComputation::HloComputation( const std::string& name, int parameter_count, std::vector>* instructions, HloInstruction* root_instruction) - : name_(NameUniquer::GetSanitizedName(name)), - unique_id_(-1), + : unique_id_(-1), root_instruction_(root_instruction), - fusion_instruction_(nullptr), - is_fusion_computation_(false), - custom_call_instruction_(nullptr), - is_custom_call_computation_(false), - collective_call_instruction_(nullptr), - is_collective_called_computation_(false), - while_call_instruction_(nullptr), - is_while_call_body_computation_(false), - conditional_call_instruction_(nullptr), - is_conditional_branch_computation_(false) { + name_(NameUniquer::GetSanitizedName(name)) { param_instructions_.resize(parameter_count, nullptr); bool root_found = false; for (auto& instruction : *instructions) { @@ -149,10 +139,9 @@ HloComputation::HloComputation( } HloComputation::~HloComputation() { - if (fusion_instruction_ != nullptr) { - CHECK(fusion_instruction_->fused_instructions_computation() == this); - fusion_instruction_->ClearCalledComputations(); - fusion_instruction_ = nullptr; + if (FusionInstruction() != nullptr) { + CHECK(FusionInstruction()->fused_instructions_computation() == this); + FusionInstruction()->ClearCalledComputations(); } if (IsAsyncComputation()) { CHECK(async_start_->async_wrapped_computation() == this); @@ -161,6 +150,29 @@ HloComputation::~HloComputation() { for (const auto& i : instructions_) { delete i.inst(); } + Cleanup(); +} + +void HloComputation::SetInstruction(HloInstruction* instruction, + InstructionType type) { + static_assert(alignof(HloInstruction) == kInstructionTypeMask + 1, + "HloInstruction should be aligned as a QWORD"); + + DCHECK(type != InstructionType::kUnset) + << "Set instruction must be called with a valid type, not kUnset."; + DCHECK(instruction_type() == InstructionType::kUnset || + instruction_type() == type) + << "Unexpected instruction type. Current type is " + << static_cast(instruction_type()) << " and it cannot be reset to " + << static_cast(type); + + // If `instruction` is nullptr, we need to preserve the existing type. + if (instruction == nullptr) { + type = instruction_type(); + } + + instruction_and_type_ = + reinterpret_cast(instruction) | static_cast(type); } HloInstruction* HloComputation::AddInstruction( @@ -206,7 +218,7 @@ HloInstruction* HloComputation::AddParameter( std::unique_ptr instruction) { CHECK(instruction->opcode() == HloOpcode::kParameter); CHECK(!IsFusionComputation() || - fusion_instruction_->operand_count() == param_instructions_.size()); + FusionInstruction()->operand_count() == param_instructions_.size()); instruction->set_parent(this); param_instructions_.push_back(instruction.get()); AddInstructionInternal(std::move(instruction)); @@ -280,7 +292,7 @@ HloInstruction* HloComputation::ReplaceParameter( CHECK_LT(param_no, param_instructions_.size()); CHECK(instruction->opcode() == HloOpcode::kParameter); CHECK(!IsFusionComputation() || - fusion_instruction_->operand_count() == param_instructions_.size()); + FusionInstruction()->operand_count() == param_instructions_.size()); instruction->set_parent(this); HloInstruction* new_instruction = @@ -425,7 +437,7 @@ Status HloComputation::RemoveInstructionImpl(HloInstruction* instruction, TF_RET_CHECK(inst_it != instruction_indices_.end()); HloInstructionInfo* info = &instructions_[inst_it->second]; info->inst()->set_parent(nullptr); - to_be_deleted_.emplace_back(info->inst()); // Takes ownership + to_be_deleted_.push_back(info->inst()); // Takes ownership to_be_deleted_.back()->DetachFromOperandsAndUsers(); // Clear all operands to avoid Null operands. to_be_deleted_.back()->RemoveAllOperands(); @@ -878,7 +890,7 @@ HloComputationProto HloComputation::ToProto() const { } proto.set_root_id(root_instruction()->unique_id()); *proto.mutable_program_shape() = ComputeProgramShape().ToProto(); - proto.set_is_fusion_computation(is_fusion_computation_); + proto.set_is_fusion_computation(IsFusionComputation()); proto.set_execution_thread(IsMainThread() ? "" : std::string(execution_thread())); return proto; @@ -942,7 +954,10 @@ HloComputation::CreateFromProto( auto computation = absl::WrapUnique( new HloComputation(proto.name(), parameter_count, &instructions, root)); computation->unique_id_ = proto.id(); - computation->is_fusion_computation_ = proto.is_fusion_computation(); + if (proto.is_fusion_computation()) { + computation->instruction_and_type_ = + static_cast(InstructionType::kFusion); + } if (!proto.execution_thread().empty()) { computation->SetExecutionThread(proto.execution_thread()); } @@ -1053,7 +1068,13 @@ StatusOr HloComputation::CreateAsyncInstructions( async_start->CopyBackendConfigFrom(instruction); async_done->set_metadata(instruction->metadata()); async_done->CopyBackendConfigFrom(instruction); - TF_RETURN_IF_ERROR(async_done->CopyAllControlDepsFrom(instruction)); + for (HloInstruction* control_pred : instruction->control_predecessors()) { + TF_RETURN_IF_ERROR(control_pred->AddControlDependencyTo(async_start)); + } + for (HloInstruction* control_successor : instruction->control_successors()) { + TF_RETURN_IF_ERROR(async_done->AddControlDependencyTo(control_successor)); + } + if (replace) { TF_RETURN_IF_ERROR(instruction->DropAllControlDeps()); TF_RETURN_IF_ERROR(ReplaceInstruction(instruction, async_done)); @@ -1284,11 +1305,7 @@ StatusOr HloComputation::ReplaceInstructionWithDifferentShape( // But still this seems to be better than nothing. bool overwrite_op_name = new_instruction->metadata().op_name().empty() && !old_instruction->metadata().op_name().empty(); - bool overwrite_pass_id = - new_instruction->metadata().op_name().empty() && - new_instruction->metadata().logical_creation_pass_id() == 0 && - old_instruction->metadata().logical_creation_pass_id() != 0; - if (overwrite_op_name || overwrite_pass_id) { + if (overwrite_op_name) { new_instruction->set_metadata(old_instruction->metadata()); } if (new_instruction->frontend_attributes().map().empty()) { diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.h b/third_party/xla/xla/hlo/ir/hlo_computation.h index 9dd7a6fc78d28a..f2df3c7b19a48a 100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.h +++ b/third_party/xla/xla/hlo/ir/hlo_computation.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_HLO_IR_HLO_COMPUTATION_H_ #define XLA_HLO_IR_HLO_COMPUTATION_H_ +#include #include #include #include @@ -34,6 +35,8 @@ limitations under the License. #include "xla/hlo/ir/dfs_hlo_visitor.h" #include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/ptrvec.h" #include "xla/iterator_util.h" #include "xla/printer.h" #include "xla/service/hlo.pb.h" @@ -294,8 +297,20 @@ class HloComputation { absl::string_view name() const { return name_; } + // Sets the string identifier for this computation. Name will be sanitized to + // match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*". + // + // See also HloModule::SetAndUniquifyComputationName(), which does this plus + // UniqufyName(). + void SetAndSanitizeName(absl::string_view name) { + name_ = NameUniquer::GetSanitizedName(name); + } + // Use the given NameUniquer to select a unique name for the computation based // on the computation's existing name. + // + // See also HloModule::SetAndUniquifyComputationName(), which does this plus + // SetAndSanitizeName(). void UniquifyName(NameUniquer* name_uniquer); // Prints a string representation of the computation. @@ -698,100 +713,90 @@ class HloComputation { // Returns if this computation is a fusion computation. // Do not use this method to determine if fusion_instruction_ != nullptr. // Instead, directly do: FusionInstruction() != nullptr - bool IsFusionComputation() const { return is_fusion_computation_; } + bool IsFusionComputation() const { + return instruction_type() == InstructionType::kFusion; + } // Returns if this computation is the entry computation of the module. bool IsEntryComputation() const; // Returns the owning fusion instruction, or nullptr if this is not a fusion // computation. - HloInstruction* FusionInstruction() const { return fusion_instruction_; } + HloInstruction* FusionInstruction() const { + return instruction_type() == InstructionType::kFusion ? instruction() + : nullptr; + } void SetFusionInstruction(HloInstruction* fusion_instruction) { - CHECK(!IsCustomCallComputation() && !IsAsyncComputation() && - !IsCollectiveCalledComputation() && !IsWhileBodyComputation() && - !IsConditionalBranchComputation()); - fusion_instruction_ = fusion_instruction; - is_fusion_computation_ |= (fusion_instruction != nullptr); + SetInstruction(fusion_instruction, InstructionType::kFusion); } // Returns if this computation is a custom-call computation. - bool IsCustomCallComputation() const { return is_custom_call_computation_; } + bool IsCustomCallComputation() const { + return instruction_type() == InstructionType::kCustomCall; + } // Returns the owning custom call instruction, or nullptr if this is not a // custom call computation. HloInstruction* CustomCallInstruction() const { - return custom_call_instruction_; + return instruction_type() == InstructionType::kCustomCall ? instruction() + : nullptr; } void SetCustomCallInstruction(HloInstruction* custom_call_instruction) { - CHECK(!IsFusionComputation() && !IsAsyncComputation() && - !IsCollectiveCalledComputation() && !IsWhileBodyComputation() && - !IsConditionalBranchComputation()); - custom_call_instruction_ = custom_call_instruction; - is_custom_call_computation_ |= (custom_call_instruction != nullptr); + SetInstruction(custom_call_instruction, InstructionType::kCustomCall); } // Returns if this computation is a to_apply region of a collective. bool IsCollectiveCalledComputation() const { - return is_collective_called_computation_; + return instruction_type() == InstructionType::kCollective; } // Returns the owning collective call instruction, or nullptr if this is not a // collective call computation. HloInstruction* CollectiveCallInstruction() const { - return collective_call_instruction_; + return instruction_type() == InstructionType::kCollective ? instruction() + : nullptr; } void SetCollectiveCallInstruction( HloInstruction* collective_call_instruction) { - CHECK(!IsFusionComputation() && !IsAsyncComputation() && - !IsCustomCallComputation() && !IsWhileBodyComputation() && - !IsConditionalBranchComputation()); - collective_call_instruction_ = collective_call_instruction; - is_collective_called_computation_ |= - (collective_call_instruction != nullptr); + SetInstruction(collective_call_instruction, InstructionType::kCollective); } // Returns if this computation is a body computation of a while. bool IsWhileBodyComputation() const { - return is_while_call_body_computation_; + return instruction_type() == InstructionType::kWhile; } // Returns the owning while call instruction, or nullptr if this is not a // while call body computation. HloInstruction* WhileCallInstruction() const { - return while_call_instruction_; + return instruction_type() == InstructionType::kWhile ? instruction() + : nullptr; } void SetWhileCallInstruction(HloInstruction* while_call_instruction) { - CHECK(!IsFusionComputation() && !IsAsyncComputation() && - !IsCustomCallComputation() && !IsCollectiveCalledComputation() && - !IsConditionalBranchComputation()); CHECK(while_call_instruction != nullptr); CHECK(while_call_instruction->opcode() == HloOpcode::kWhile); - while_call_instruction_ = while_call_instruction; - is_while_call_body_computation_ = true; + SetInstruction(while_call_instruction, InstructionType::kWhile); } // Returns if this computation is a branch computation of a conditional. bool IsConditionalBranchComputation() const { - return is_conditional_branch_computation_; + return instruction_type() == InstructionType::kConditional; } // Returns the owning conditional call instruction, or nullptr if this is not // a conditional branch computation. HloInstruction* ConditionalCallInstruction() const { - return conditional_call_instruction_; + return instruction_type() == InstructionType::kConditional ? instruction() + : nullptr; } void SetConditionalCallInstruction( HloInstruction* conditional_call_instruction) { - CHECK(!IsFusionComputation() && !IsAsyncComputation() && - !IsCustomCallComputation() && !IsCollectiveCalledComputation() && - !IsWhileBodyComputation()); CHECK(conditional_call_instruction != nullptr); CHECK(conditional_call_instruction->opcode() == HloOpcode::kConditional); - conditional_call_instruction_ = conditional_call_instruction; - is_conditional_branch_computation_ = true; + SetInstruction(conditional_call_instruction, InstructionType::kConditional); } // Returns if this computation is an async computation. @@ -802,18 +807,14 @@ class HloComputation { HloInstruction* AsyncStart() const { return async_start_; } void AddAsyncStart(HloInstruction* async_instruction) { - CHECK(!IsCalledComputation()); + // TODO: Add instruction type for async instructions. + CHECK(instruction_type() == InstructionType::kUnset); CHECK(async_instruction->opcode() == HloOpcode::kAsyncStart); async_start_ = async_instruction; } void RemoveAsyncStart() { async_start_ = nullptr; } - // Returns if this computation is invoked by an Hlo instruction. - bool IsCalledComputation() const { - return IsFusionComputation() || IsCustomCallComputation(); - } - // Clear the unique ID of the computation so that it can be re-assigned, such // as for the purpose of compacting the unique IDs. void ClearUniqueIdInternal() { unique_id_ = -1; } @@ -845,7 +846,12 @@ class HloComputation { // stage clean up process is designed such that HloPass can have stable // internal pointers to HloInstructions while we create and remove // HloInstructions in a pass. - void Cleanup() { to_be_deleted_.clear(); } + void Cleanup() { + for (HloInstruction* it : to_be_deleted_) { + delete it; + } + to_be_deleted_.clear(); + } // Returns true if a given instruction is marked dead in this computation. bool IsMarkedAsDead(const HloInstruction* inst); @@ -904,50 +910,48 @@ class HloComputation { Status RemoveInstructionImpl(HloInstruction* instruction, bool ignore_safety_check); - std::string name_; - int64_t unique_id_; - HloInstruction* root_instruction_; - - // If this computation is a fusion computation, this field points to the - // corresponding fusion instruction (if it is live). Otherwise, this is null. - HloInstruction* fusion_instruction_; - - // Determines whether this computation is a fusion computation. A fusion - // computation ordinarily also has a non-null fusion_instruction_. However, if - // a fusion instruction is removed during compilation, the fusion computation - // becomes unreachable, and its fusion_instruction_ is set to null. We still - // need to regard such computations as fusion computations for HLO scheduling - // purposes. - bool is_fusion_computation_; - - // If this computation is a custom-call computation, this field points to the - // corresponding custom-call instruction (if it is live). Otherwise, this is - // null. - HloInstruction* custom_call_instruction_; - - // Determines whether this computation is a custom-call computation. - bool is_custom_call_computation_; + enum class InstructionType : uint8_t { + kUnset, + // This computation is a fusion computation. A fusion computation ordinarily + // also has a non-null instruction. However, if a fusion instruction + // is removed during compilation, the fusion computation becomes + // unreachable, and its instruction is set to null. We still need to regard + // such computations as fusion computations for HLO scheduling purposes. + kFusion, + // This computation is a custom-call computation. + kCustomCall, + // This computation is a while body computation. + kCollective, + // This computation is a while body computation. + kWhile, + // This computation is a conditional branch computation. + kConditional, + }; + static constexpr uintptr_t kInstructionTypeMask = 0b111; + static_assert(static_cast(InstructionType::kUnset) == 0, + "kUnset must be 0."); - // If this computation is a collective sub-computation, this field points to - // the corresponding collective instruction. Otherwise, this is null. - HloInstruction* collective_call_instruction_; + InstructionType instruction_type() const { + return static_cast(instruction_and_type_ & + kInstructionTypeMask); + } - // Determines whether this computation is a collective sub-computation. - bool is_collective_called_computation_; + HloInstruction* instruction() const { + return reinterpret_cast(instruction_and_type_ & + ~kInstructionTypeMask); + } - // If this computation is a while body computation, this field points to - // the corresponding while instruction. Otherwise, this is null. - HloInstruction* while_call_instruction_; + void SetInstruction(HloInstruction* instruction, InstructionType type); - // Determines whether this computation is a while body computation. - bool is_while_call_body_computation_; + int64_t unique_id_; + HloInstruction* root_instruction_; - // If this computation is a conditional branch computation, this field points - // to the corresponding conditional instruction. Otherwise, this is null. - HloInstruction* conditional_call_instruction_; + // Module containing this computation. + HloModule* parent_ = nullptr; - // Determines whether this computation is a conditional branch computation. - bool is_conditional_branch_computation_; + // Contains HloInstruction* and its type. + // The respective type in the least significant three bits. + uintptr_t instruction_and_type_ = 0; // If this computation is an async computation, this field points to the // first async instruction (async-start) in the asynchronous op chain that @@ -955,11 +959,7 @@ class HloComputation { // Otherwise, this is empty. HloInstruction* async_start_ = nullptr; - // Execution thread of this computation. By default, it's main thread. - std::string execution_thread_ = HloInstruction::kMainExecutionThread; - - // Module containing this computation. - HloModule* parent_ = nullptr; + HloInstruction::InstructionVector param_instructions_; // Store instructions in std::vector as they can be added and removed // arbitrarily and we want a stable iteration order. Keep a map from @@ -967,11 +967,14 @@ class HloComputation { HloInstructionList instructions_; absl::flat_hash_map instruction_indices_; + // Execution thread of this computation. By default, it's main thread. + std::string execution_thread_ = HloInstruction::kMainExecutionThread; + // Removed instructions are moved into to_be_deleted_ first and then // deallocated when Cleanup is called. - std::vector> to_be_deleted_; + PtrVec to_be_deleted_; - HloInstruction::InstructionVector param_instructions_; + std::string name_; HloComputation(const HloComputation&) = delete; HloComputation& operator=(const HloComputation&) = delete; diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc index 11a68c2f06e498..2c07b6e7ca5511 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc @@ -1298,6 +1298,7 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state, case HloOpcode::kCos: case HloOpcode::kOptimizationBarrier: case HloOpcode::kClz: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: @@ -1634,6 +1635,12 @@ HloInstruction::CreateCollectivePermuteStart( is_host_transfer); } +/* static */ std::unique_ptr HloInstruction::CreateRecvDone( + HloInstruction* operand, int64_t channel_id, bool is_host_transfer) { + return std::make_unique(operand, channel_id, + is_host_transfer); +} + /* static */ std::unique_ptr HloInstruction::CreateReverse( const Shape& shape, HloInstruction* operand, absl::Span dimensions) { @@ -2348,6 +2355,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kOptimizationBarrier: case HloOpcode::kCopyDone: case HloOpcode::kCos: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kImag: @@ -2451,6 +2459,8 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 0); clone = CreatePartitionId(shape); break; + default: + CHECK(0) << "Unsupported opcode: " << opcode_; } // SetupDerivedInstruction will setup the precision_config_ field. SetupDerivedInstruction(clone.get()); @@ -2776,6 +2786,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kCos: case HloOpcode::kDivide: case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: @@ -2878,6 +2889,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kReduceScatter: case HloOpcode::kAllReduceStart: case HloOpcode::kAllToAll: + case HloOpcode::kCollectiveBroadcast: case HloOpcode::kCollectivePermute: case HloOpcode::kCollectivePermuteStart: case HloOpcode::kConvolution: @@ -3144,7 +3156,6 @@ bool HloInstruction::has_to_apply() const { case HloOpcode::kReduceWindow: case HloOpcode::kScatter: case HloOpcode::kSort: - case HloOpcode::kTopK: return true; case HloOpcode::kCustomCall: // CustomCall can have a to_apply computation, but it is not required to @@ -3325,6 +3336,7 @@ bool HloInstruction::IsOpElementwise(HloOpcode opcode) { case HloOpcode::kBitcastConvert: case HloOpcode::kCopy: case HloOpcode::kCos: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: @@ -3481,7 +3493,8 @@ void HloInstruction::PrintWithCanonicalNameMap( (!metadata_->op_type().empty() || !metadata_->op_name().empty() || !metadata_->source_file().empty())) { printer->Append(", metadata={"); - printer->Append(xla::OpMetadataToString(*metadata_)); + printer->Append(xla::OpMetadataToString( + *metadata_, options.print_metadata_only_op_name())); printer->Append("}"); } if (options.print_backend_config() && !backend_config_.empty()) { @@ -3944,6 +3957,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleBatchNormInference(this); case HloOpcode::kBatchNormGrad: return visitor->HandleBatchNormGrad(this); + case HloOpcode::kErf: + return visitor->HandleErf(this); case HloOpcode::kLogistic: return visitor->HandleLogistic(this); case HloOpcode::kSign: @@ -4168,11 +4183,12 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleCholesky(this); case HloOpcode::kOptimizationBarrier: return visitor->HandleOptimizationBarrier(this); + default: + return Internal( + "Unhandled HloOpcode for DfsHloVisitor: %s. This should not happen - " + "please file a bug for XLA.", + HloOpcodeString(opcode_)); } - return Internal( - "Unhandled HloOpcode for DfsHloVisitor: %s. This should not happen - " - "please file a bug for XLA.", - HloOpcodeString(opcode_)); } // Explicit instantiations. diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.h b/third_party/xla/xla/hlo/ir/hlo_instruction.h index 2ded22753c7469..490181fcb1d544 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.h +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.h @@ -95,6 +95,7 @@ class HloPrintOptions { print_large_constants_(false), print_only_essential_constants_(false), print_metadata_(true), + print_metadata_only_op_name_(false), print_backend_config_(true), print_infeed_outfeed_config_(true), compact_operands_(false), @@ -203,6 +204,13 @@ class HloPrintOptions { return *this; } + // If true and print_metadata is true, metadata op name will be printed. Other + // metadata values will be omitted. + HloPrintOptions& set_print_metadata_only_op_name(bool value) { + print_metadata_only_op_name_ = value; + return *this; + } + // If true, backend_config will be printed. HloPrintOptions& set_print_backend_config(bool value) { print_backend_config_ = value; @@ -369,6 +377,9 @@ class HloPrintOptions { return print_subcomputation_mode_; } bool print_metadata() const { return print_metadata_; } + bool print_metadata_only_op_name() const { + return print_metadata_only_op_name_; + } bool print_backend_config() const { return print_backend_config_; } bool print_infeed_outfeed_config() const { return print_infeed_outfeed_config_; @@ -408,6 +419,7 @@ class HloPrintOptions { bool print_large_constants_; bool print_only_essential_constants_; bool print_metadata_; + bool print_metadata_only_op_name_; bool print_backend_config_; bool print_infeed_outfeed_config_; bool compact_operands_; @@ -1036,6 +1048,10 @@ class HloInstruction { // and returns the receive buffer. The operand must be kRecv. static std::unique_ptr CreateRecvDone( HloInstruction* operand, bool is_host_transfer = false); + // Similar to the above, but the operand doesn't have to be a kRecv. + static std::unique_ptr CreateRecvDone( + HloInstruction* operand, int64_t channel_id, + bool is_host_transfer = false); // Creates a slice instruction, where the operand is sliced by the given // start/limit indices. @@ -1357,7 +1373,7 @@ class HloInstruction { // Returns true if this instruction has a side effect. An instruction has a // side effect if it uses certain opcodes or calls a computation with a side // effect. - bool HasSideEffect() const; + virtual bool HasSideEffect() const; // Returns the result shape of this instruction. const Shape& shape() const; @@ -1656,6 +1672,13 @@ class HloInstruction { // a bitcast. bool IsEffectiveBitcast() const; + // Returns true if this instruction is asynchronous with the + // async_execution_thread set to `execution_thread`. + bool IsAsyncInstructionWithExecutionThread( + absl::string_view execution_thread) const { + return IsAsynchronous() && async_execution_thread() == execution_thread; + }; + // Gets/sets the to_apply HloComputation for Call, Map, Reduce, etc. // The setter should only be called by HloModule or HloComputation methods. // @@ -2085,11 +2108,7 @@ class HloInstruction { // Sets the debug metadata for this instruction, excluding creation_pass_id, // which should never be copied anywhere. - void set_metadata(const OpMetadata& metadata) { - int64_t creation_pass_id = metadata_->creation_pass_id(); - *metadata_ = metadata; - metadata_->set_creation_pass_id(creation_pass_id); - } + void set_metadata(const OpMetadata& metadata) { *metadata_ = metadata; } void set_size_of_generated_code_in_bytes(int64_t code_size_in_bytes) { metadata_->set_size_of_generated_code_in_bytes(code_size_in_bytes); @@ -2099,15 +2118,9 @@ class HloInstruction { metadata_->set_size_of_memory_working_set_in_bytes( working_set_size_in_bytes); } - void set_creation_pass_id(int64_t pass_id) { - metadata_->set_creation_pass_id(pass_id); - } void set_metadata_op_name(const std::string& name) { metadata_->set_op_name(name); } - void set_logical_creation_pass_id(int64_t pass_id) { - metadata_->set_logical_creation_pass_id(pass_id); - } void set_metadata_deduplicated_name(std::string deduplicated_name) { metadata_->set_deduplicated_name(std::move(deduplicated_name)); } diff --git a/third_party/xla/xla/hlo/ir/hlo_instructions.cc b/third_party/xla/xla/hlo/ir/hlo_instructions.cc index 87dcbb80657d80..4c1c38d5389def 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instructions.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instructions.cc @@ -839,13 +839,30 @@ HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand, AppendOperand(operand); } +HloRecvDoneInstruction::HloRecvDoneInstruction(HloInstruction* operand, + int64_t channel_id, + bool is_host_transfer) + : HloSendRecvInstruction( + HloOpcode::kRecvDone, + ShapeUtil::MakeTupleShape( + {ShapeUtil::GetTupleElementShape(operand->shape(), 0), + ShapeUtil::MakeTokenShape()}), + channel_id, is_host_transfer) { + AppendOperand(operand); +} + std::unique_ptr HloRecvDoneInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); + HloRecvInstruction* recv = dynamic_cast(new_operands[0]); + if (recv != nullptr) { + return std::make_unique(recv, is_host_transfer()); + } + return std::make_unique( - Cast(new_operands[0]), is_host_transfer()); + new_operands[0], channel_id().value(), is_host_transfer()); } HloCollectiveInstruction::HloCollectiveInstruction( diff --git a/third_party/xla/xla/hlo/ir/hlo_instructions.h b/third_party/xla/xla/hlo/ir/hlo_instructions.h index fa114df17ff0e1..def030ffc3cdbe 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instructions.h +++ b/third_party/xla/xla/hlo/ir/hlo_instructions.h @@ -264,6 +264,10 @@ class HloAsyncInstruction : public HloInstruction { // *end(GetAsyncChain()) is the async-done op. std::vector GetAsyncChain() const; + bool HasSideEffect() const override { + return async_wrapped_instruction()->HasSideEffect(); + } + protected: // Helper to constructs async-{start,update,done}. HloAsyncInstruction(HloOpcode opcode, const Shape& shape, @@ -607,6 +611,8 @@ class HloRecvDoneInstruction : public HloSendRecvInstruction { public: explicit HloRecvDoneInstruction(HloRecvInstruction* operand, bool is_host_transfer); + explicit HloRecvDoneInstruction(HloInstruction* operand, int64_t channel_id, + bool is_host_transfer); static bool ClassOf(const HloInstruction* hlo) { return hlo->opcode() == HloOpcode::kRecvDone; diff --git a/third_party/xla/xla/hlo/ir/hlo_module.cc b/third_party/xla/xla/hlo/ir/hlo_module.cc index 5ba604e1618ef2..e7a6edbd193843 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module.cc +++ b/third_party/xla/xla/hlo/ir/hlo_module.cc @@ -370,6 +370,15 @@ void HloModule::Print(Printer* printer, const HloPrintOptions& options) const { entry_computation_layout().Print(printer); printer->Append("}"); } + if (config.allow_spmd_sharding_propagation_to_parameters().size() != 1 || + config.allow_spmd_sharding_propagation_to_parameters().back()) { + printer->Append(", allow_spmd_sharding_propagation_to_parameters={"); + AppendJoin(printer, config.allow_spmd_sharding_propagation_to_parameters(), + ",", [](Printer* printer, bool i) { + printer->Append(i ? "true" : "false"); + }); + printer->Append("}"); + } if (config.allow_spmd_sharding_propagation_to_output().size() != 1 || config.allow_spmd_sharding_propagation_to_output().back()) { printer->Append(", allow_spmd_sharding_propagation_to_output={"); @@ -702,6 +711,11 @@ StatusOr HloModule::CreateModuleConfigFromShape( } module_config.set_auto_spmd_partitioning_mesh_ids(mesh_ids); module_config.set_deduplicate_hlo(execution_options->deduplicate_hlo()); + if (!execution_options->allow_spmd_sharding_propagation_to_parameters() + .empty()) { + module_config.set_allow_spmd_sharding_propagation_to_parameters( + execution_options->allow_spmd_sharding_propagation_to_parameters()); + } if (!execution_options->allow_spmd_sharding_propagation_to_output() .empty()) { module_config.set_allow_spmd_sharding_propagation_to_output( diff --git a/third_party/xla/xla/hlo/ir/hlo_module.h b/third_party/xla/xla/hlo/ir/hlo_module.h index 6a3c93b88c81cb..6311929ac533a5 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module.h +++ b/third_party/xla/xla/hlo/ir/hlo_module.h @@ -546,6 +546,12 @@ class HloModule { instr->UniquifyName(&instruction_name_uniquer_); } + void SetAndUniquifyComputationName(HloComputation* computation, + absl::string_view name) { + computation->SetAndSanitizeName(name); + computation->UniquifyName(&computation_name_uniquer_); + } + Status CheckUniqueNamesAndIdsForComputationsAndInstructions() const; // Checks if this config has a list of entry parameters' HLO shardings for diff --git a/third_party/xla/xla/hlo/ir/hlo_module_metadata.cc b/third_party/xla/xla/hlo/ir/hlo_module_metadata.cc index 5ad3f264bf55d8..3af6a63cacec9a 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module_metadata.cc +++ b/third_party/xla/xla/hlo/ir/hlo_module_metadata.cc @@ -18,7 +18,10 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "xla/util.h" #include "tsl/platform/env.h" +#include "tsl/platform/protobuf.h" namespace xla { @@ -84,4 +87,16 @@ void HloModuleMetadata::set_prepartitioning_metadata( } } +Status HloModuleMetadata::set_custom_metadata( + const ::tsl::protobuf::Message& message) { + TF_ASSIGN_OR_RETURN(HloPassMetadata * pass_metadata, + GetCurrentHloPassMetadata()); + if (!pass_metadata->mutable_custom_metadata()->PackFrom(message)) { + LOG(WARNING) << "failed to pack custom metadata for " + << pass_metadata->pass_id(); + return Internal("failed to pack custom metadata"); + }; + return OkStatus(); +} + } // namespace xla diff --git a/third_party/xla/xla/hlo/ir/hlo_module_metadata.h b/third_party/xla/xla/hlo/ir/hlo_module_metadata.h index 1c4f2a774e1bc4..0fc4f3169ad3ce 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module_metadata.h +++ b/third_party/xla/xla/hlo/ir/hlo_module_metadata.h @@ -25,6 +25,7 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/util.h" #include "tsl/platform/env.h" +#include "tsl/platform/protobuf.h" namespace xla { @@ -62,6 +63,7 @@ class HloModuleMetadata { void add_partitioned_module_id(int64_t id) { module_metadata_.add_partitioned_module_ids(id); } + Status set_custom_metadata(const ::tsl::protobuf::Message& message); StatusOr current_pass_id() { TF_ASSIGN_OR_RETURN(HloPassMetadata * pass_metadata, diff --git a/third_party/xla/xla/hlo/ir/hlo_op_metadata.cc b/third_party/xla/xla/hlo/ir/hlo_op_metadata.cc index aeebd0cfce20fb..b45f23ba96d432 100644 --- a/third_party/xla/xla/hlo/ir/hlo_op_metadata.cc +++ b/third_party/xla/xla/hlo/ir/hlo_op_metadata.cc @@ -24,8 +24,16 @@ limitations under the License. namespace xla { -std::string OpMetadataToString(const OpMetadata& metadata) { +std::string OpMetadataToString(const OpMetadata& metadata, bool only_op_name) { std::vector result; + if (only_op_name) { + if (!metadata.op_name().empty()) { + return absl::StrCat("op_name=\"", absl::CEscape(metadata.op_name()), + "\""); + } else { + return ""; + } + } if (!metadata.op_type().empty()) { result.push_back( absl::StrCat("op_type=\"", absl::CEscape(metadata.op_type()), "\"")); diff --git a/third_party/xla/xla/hlo/ir/hlo_op_metadata.h b/third_party/xla/xla/hlo/ir/hlo_op_metadata.h index 26311c812a1f7f..acbd34c84af2f8 100644 --- a/third_party/xla/xla/hlo/ir/hlo_op_metadata.h +++ b/third_party/xla/xla/hlo/ir/hlo_op_metadata.h @@ -21,7 +21,8 @@ limitations under the License. #include "xla/xla_data.pb.h" namespace xla { -std::string OpMetadataToString(const OpMetadata& metadata); +std::string OpMetadataToString(const OpMetadata& metadata, + bool only_op_name = false); } // namespace xla #endif // XLA_HLO_IR_HLO_OP_METADATA_H_ diff --git a/third_party/xla/xla/hlo/ir/hlo_opcode.h b/third_party/xla/xla/hlo/ir/hlo_opcode.h index 338346b80070c4..42be90377f959d 100644 --- a/third_party/xla/xla/hlo/ir/hlo_opcode.h +++ b/third_party/xla/xla/hlo/ir/hlo_opcode.h @@ -75,6 +75,7 @@ namespace xla { V(kCholesky, "cholesky", 1) \ V(kClamp, "clamp", 3) \ V(kClz, "count-leading-zeros", 1) \ + V(kCollectiveBroadcast, "collective-broadcast", kHloOpcodeIsVariadic) \ V(kCollectivePermute, "collective-permute", kHloOpcodeIsVariadic) \ V(kCollectivePermuteDone, "collective-permute-done", 1) \ V(kCollectivePermuteStart, "collective-permute-start", kHloOpcodeIsVariadic) \ @@ -96,6 +97,7 @@ namespace xla { V(kDynamicReshape, "dynamic-reshape", kHloOpcodeIsVariadic) \ V(kDynamicSlice, "dynamic-slice", kHloOpcodeIsVariadic) \ V(kDynamicUpdateSlice, "dynamic-update-slice", kHloOpcodeIsVariadic) \ + V(kErf, "erf", 1) \ V(kExp, "exponential", 1) \ V(kExpm1, "exponential-minus-one", 1) \ V(kFft, "fft", 1) \ diff --git a/third_party/xla/xla/hlo/transforms/BUILD b/third_party/xla/xla/hlo/transforms/BUILD index 2482904a2582be..16b4c0cf384094 100644 --- a/third_party/xla/xla/hlo/transforms/BUILD +++ b/third_party/xla/xla/hlo/transforms/BUILD @@ -5,7 +5,8 @@ load("//xla:xla.bzl", "xla_cc_test") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], licenses = ["notice"], ) @@ -20,7 +21,6 @@ cc_library( name = "hlo_constant_splitter", srcs = ["hlo_constant_splitter.cc"], hdrs = ["hlo_constant_splitter.h"], - visibility = ["//visibility:public"], deps = [ "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", diff --git a/third_party/xla/xla/hlo/utils/BUILD b/third_party/xla/xla/hlo/utils/BUILD index 92ae2ccc70f2bf..cec66c2bd7be4f 100644 --- a/third_party/xla/xla/hlo/utils/BUILD +++ b/third_party/xla/xla/hlo/utils/BUILD @@ -8,7 +8,8 @@ load( load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], licenses = ["notice"], ) @@ -23,7 +24,6 @@ cc_library( name = "hlo_live_range", srcs = ["hlo_live_range.cc"], hdrs = ["hlo_live_range.h"], - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla:statusor", @@ -65,7 +65,6 @@ cc_library( testonly = 1, srcs = ["hlo_matchers.cc"], hdrs = ["hlo_matchers.h"], - visibility = ["//visibility:public"], deps = [ "//xla:test", "//xla/hlo/ir:hlo", @@ -94,7 +93,6 @@ cc_library( hdrs = [ "hlo_sharding_util.h", ], - visibility = ["//visibility:public"], deps = [ "//xla:array", "//xla:literal_util", @@ -144,7 +142,6 @@ cc_library( name = "hlo_query", srcs = ["hlo_query.cc"], hdrs = ["hlo_query.h"], - visibility = ["//visibility:public"], deps = [ "//xla:literal", "//xla:shape_util", diff --git a/third_party/xla/xla/hlo/utils/hlo_live_range.cc b/third_party/xla/xla/hlo/utils/hlo_live_range.cc index 674f548c149b59..670bf8f570889a 100644 --- a/third_party/xla/xla/hlo/utils/hlo_live_range.cc +++ b/third_party/xla/xla/hlo/utils/hlo_live_range.cc @@ -41,7 +41,7 @@ limitations under the License. namespace xla { /*static*/ -StatusOr> HloLiveRange::Run( +absl::StatusOr> HloLiveRange::Run( const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis, const HloComputation* computation, bool module_scoped_analysis) { std::unique_ptr hlo_live_range( diff --git a/third_party/xla/xla/hlo/utils/hlo_live_range.h b/third_party/xla/xla/hlo/utils/hlo_live_range.h index 47332d8f0ee9e8..eb1530503ab121 100644 --- a/third_party/xla/xla/hlo/utils/hlo_live_range.h +++ b/third_party/xla/xla/hlo/utils/hlo_live_range.h @@ -37,7 +37,7 @@ class HloLiveRange { public: // Constructs a hlo live range object for the given module and computation // assuming the given HLO instruction ordering. - static StatusOr> Run( + static absl::StatusOr> Run( const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis, const HloComputation* computation, bool module_scoped_analysis = true); diff --git a/third_party/xla/xla/hlo/utils/hlo_matchers.h b/third_party/xla/xla/hlo/utils/hlo_matchers.h index 2111cbb70de440..d912c81381f3d1 100644 --- a/third_party/xla/xla/hlo/utils/hlo_matchers.h +++ b/third_party/xla/xla/hlo/utils/hlo_matchers.h @@ -285,6 +285,7 @@ HLO_MATCHER(Divide); HLO_MATCHER(Domain); HLO_MATCHER(DynamicSlice); HLO_MATCHER(DynamicUpdateSlice); +HLO_MATCHER(Erf); HLO_MATCHER(Exp); HLO_MATCHER(Fft); HLO_MATCHER(Floor); diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc index f5792d58d70b2a..3a56fd7a3e5fc6 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc @@ -1807,7 +1807,7 @@ HloSharding GatherOutputOrScatterUpdateShardingFromIndicesParallelDimensions( indices_sharding.metadata()); } -StatusOr, HloOpcode>> +absl::StatusOr, HloOpcode>> IdentityValueAndHloOpcodeForScatterReduceComputation( const HloScatterInstruction& scatter) { auto computation = scatter.to_apply(); diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.h b/third_party/xla/xla/hlo/utils/hlo_sharding_util.h index 0fbd71cee7be31..bbf074c408a4a7 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.h +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.h @@ -260,7 +260,7 @@ HloSharding GatherOutputOrScatterUpdateShardingFromIndicesParallelDimensions( // - If computation is min/max, return max value/min value with corresponding op // code. // - Otherwise, return error status. -StatusOr, HloOpcode>> +absl::StatusOr, HloOpcode>> IdentityValueAndHloOpcodeForScatterReduceComputation( const HloScatterInstruction& scatter); diff --git a/third_party/xla/xla/layout.cc b/third_party/xla/xla/layout.cc index 55097bd1e7fdb4..9c586f6db619d8 100644 --- a/third_party/xla/xla/layout.cc +++ b/third_party/xla/xla/layout.cc @@ -84,8 +84,8 @@ Layout::Layout(absl::Span minor_to_major, int64_t dynamic_shape_metadata_prefix_bytes) : index_primitive_type_(index_primitive_type), pointer_primitive_type_(element_primitive_type), - element_size_in_bits_(element_size_in_bits), memory_space_(memory_space), + element_size_in_bits_(element_size_in_bits), minor_to_major_(minor_to_major.begin(), minor_to_major.end()), tiles_(tiles.begin(), tiles.end()), tail_padding_alignment_in_elements_(tail_padding_alignment_in_elements), @@ -116,8 +116,8 @@ Layout::Layout(const Layout& other) n_dim_ordered_(other.n_dim_ordered_), index_primitive_type_(other.index_primitive_type_), pointer_primitive_type_(other.pointer_primitive_type_), - element_size_in_bits_(other.element_size_in_bits_), memory_space_(other.memory_space_), + element_size_in_bits_(other.element_size_in_bits_), minor_to_major_(other.minor_to_major_), tiles_(other.tiles_), tail_padding_alignment_in_elements_( diff --git a/third_party/xla/xla/layout.h b/third_party/xla/xla/layout.h index 70b05f129ee3a9..f36da76e037172 100644 --- a/third_party/xla/xla/layout.h +++ b/third_party/xla/xla/layout.h @@ -389,13 +389,13 @@ class Layout { PrimitiveType index_primitive_type_ : 8; PrimitiveType pointer_primitive_type_ : 8; + // The assigned memory space. + int8_t memory_space_ = 0; + // The number of bits used to store an individual array element. // When the value is 0, default to ShapeUtil::ByteSizeOfPrimitiveType. uint16_t element_size_in_bits_ = 0; - // The assigned memory space. - int8_t memory_space_ = 0; - // A map from physical dimension numbers to logical dimension numbers. // The first element is the most minor physical dimension (fastest varying // index) and the last the most major (slowest varying index). The contents of diff --git a/third_party/xla/xla/literal_util.h b/third_party/xla/xla/literal_util.h index 4bfa97fbfa0c13..452f34291b31a1 100644 --- a/third_party/xla/xla/literal_util.h +++ b/third_party/xla/xla/literal_util.h @@ -76,6 +76,8 @@ class LiteralUtil { // literal's linear representation in memory. template static Literal CreateR0(NativeT value); + template + static Literal CreateR0(PrimitiveType primitive_type, T value); template static Literal CreateR1(absl::Span values); static Literal CreateR1(const tsl::core::Bitmap& values); @@ -297,6 +299,17 @@ template return literal; } +template +/* static */ Literal LiteralUtil::CreateR0(PrimitiveType primitive_type, + T value) { + return primitive_util::ArrayTypeSwitch( + [&value](auto type) { + using NativeT = primitive_util::NativeTypeOf; + return CreateR0(static_cast(value)); + }, + primitive_type); +} + template /* static */ Literal LiteralUtil::CreateR1(absl::Span values) { Literal literal( diff --git a/third_party/xla/xla/mlir/backends/cpu/BUILD b/third_party/xla/xla/mlir/backends/cpu/BUILD index b4623938eb2cba..f26817898f4f31 100644 --- a/third_party/xla/xla/mlir/backends/cpu/BUILD +++ b/third_party/xla/xla/mlir/backends/cpu/BUILD @@ -2,7 +2,8 @@ load("@bazel_skylib//rules:build_test.bzl", "build_test") load("//xla:xla.bzl", "xla_cc_binary") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//xla/mlir:__subpackages__"], licenses = ["notice"], ) diff --git a/third_party/xla/xla/mlir/backends/cpu/transforms/BUILD b/third_party/xla/xla/mlir/backends/cpu/transforms/BUILD index 6cae37652fb39c..889bb5106136bf 100644 --- a/third_party/xla/xla/mlir/backends/cpu/transforms/BUILD +++ b/third_party/xla/xla/mlir/backends/cpu/transforms/BUILD @@ -3,7 +3,8 @@ load("@local_tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//xla:internal"], licenses = ["notice"], ) @@ -38,7 +39,6 @@ cc_library( "xla_rewrite_realloc_to_alloc.cc", ], hdrs = ["passes.h"], - visibility = ["//visibility:public"], deps = [ ":passes_inc_gen", "//xla/mlir/runtime/transforms:type_converter", diff --git a/third_party/xla/xla/mlir/backends/cpu/transforms/sparse_rewrite_passes.cc b/third_party/xla/xla/mlir/backends/cpu/transforms/sparse_rewrite_passes.cc index 028e565b5eb5ee..915e8c57d48598 100644 --- a/third_party/xla/xla/mlir/backends/cpu/transforms/sparse_rewrite_passes.cc +++ b/third_party/xla/xla/mlir/backends/cpu/transforms/sparse_rewrite_passes.cc @@ -460,7 +460,9 @@ struct SparseSDDMMCallRewriter { iteratorTypes.push_back(utils::IteratorType::parallel); iteratorTypes.push_back(utils::IteratorType::reduction); using MapList = ArrayRef>; - auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); }; + auto infer = [&](MapList m) { + return AffineMap::inferFromExprList(m, rewriter.getContext()); + }; AffineExpr i, j, k; bindDims(op.getContext(), i, j, k); auto indexingMaps = infer({{i, k}, {k, j}, {i, j}}); @@ -522,7 +524,9 @@ struct Sparse2To4SpMMCallRewriter { iteratorTypes.push_back(utils::IteratorType::parallel); iteratorTypes.push_back(utils::IteratorType::reduction); using MapList = ArrayRef>; - auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); }; + auto infer = [&](MapList m) { + return AffineMap::inferFromExprList(m, rewriter.getContext()); + }; AffineExpr i, j, k; bindDims(op.getContext(), i, j, k); auto indexing_maps = infer({{i, k}, {k, j}, {i, j}}); diff --git a/third_party/xla/xla/mlir/backends/cpu/transforms/tests/BUILD b/third_party/xla/xla/mlir/backends/cpu/transforms/tests/BUILD index e329737f131c98..cdb29227e237cf 100644 --- a/third_party/xla/xla/mlir/backends/cpu/transforms/tests/BUILD +++ b/third_party/xla/xla/mlir/backends/cpu/transforms/tests/BUILD @@ -1,7 +1,6 @@ load("//xla:lit.bzl", "enforce_glob", "lit_test_suite") package( - default_visibility = ["//visibility:public"], # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) diff --git a/third_party/xla/xla/mlir/backends/gpu/BUILD b/third_party/xla/xla/mlir/backends/gpu/BUILD deleted file mode 100644 index b14530d06af5d0..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/BUILD +++ /dev/null @@ -1,29 +0,0 @@ -load("@bazel_skylib//rules:build_test.bzl", "build_test") -load("//xla:xla.bzl", "xla_cc_binary") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], -) - -build_test( - name = "xla-gpu-opt_build_test", - targets = [ - ":xla-gpu-opt", - ], -) - -xla_cc_binary( - name = "xla-gpu-opt", - srcs = ["xla-gpu-opt.cc"], - deps = [ - "//xla/mlir/backends/gpu/transforms:passes", - "//xla/mlir_hlo:lhlo", - "//xla/mlir_hlo:lhlo_gpu", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:FuncExtensions", - "@llvm-project//mlir:GPUDialect", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:MlirOptLib", - ], -) diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/BUILD b/third_party/xla/xla/mlir/backends/gpu/transforms/BUILD deleted file mode 100644 index abbf17d1c8171c..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/BUILD +++ /dev/null @@ -1,102 +0,0 @@ -load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") -load("@local_tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") -load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], -) - -gentbl_cc_library( - name = "passes_inc_gen", - compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=GpuTransforms", - ], - "passes.h.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "passes.td", - deps = ["@llvm-project//mlir:PassBaseTdFiles"], -) - -cc_library( - name = "dataflow_analysis", - srcs = ["dataflow_analysis.cc"], - hdrs = ["dataflow_analysis.h"], - compatible_with = [], - visibility = ["//visibility:public"], - deps = [ - "//xla/mlir_hlo:lhlo_gpu", - "@com_google_absl//absl/strings", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:GPUDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:MemRefDialect", - ], -) - -cc_library( - name = "passes", - srcs = [ - "add_concurrent_regions.cc", - "add_hlo_trace_annotations.cc", - "gpu_to_gpu_runtime.cc", - "lmhlo_gpu_to_gpu_runtime.cc", - "lmhlo_to_gpu_launch.cc", - "lmhlo_to_gpu_runtime.cc", - "memref_get_global_to_arg.cc", - "outline_cuda_graphs.cc", - "passes.cc", - "stream_assignment.cc", - "uid_generator.h", - ], - hdrs = ["passes.h"], - # Override cc_library()'s internal default value of ["//buildenv/target:gce"].` - # TODO(ezhulenev): Do not depend on NCCL thunks in compiler passes. - compatible_with = [], - visibility = ["//visibility:public"], - deps = [ - ":dataflow_analysis", - ":passes_inc_gen", - "//xla:debug_options_flags", - "//xla:xla_proto_cc", - "//xla/mlir/runtime/ir:rt", - "//xla/mlir/runtime/utils:custom_calls", - "//xla/mlir_hlo:lhlo", - "//xla/mlir_hlo:lhlo_gpu", - "//xla/service/gpu:backend_configs_cc", - "//xla/service/gpu:gpu_executable", - "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu:nccl_collective_thunks", - "//xla/service/gpu/runtime3:conditional_thunk", - "//xla/service/gpu/runtime3:copy_thunk", - "//xla/service/gpu/runtime3:kernel_thunk", - "//xla/service/gpu/runtime3:memset_thunk", - "//xla/service/gpu/runtime3:sequential_thunk", - "//xla/service/gpu/runtime3:while_thunk", - "//xla/stream_executor:blas", - "//xla/stream_executor:device_description", - "//xla/translate/mhlo_to_hlo:location_exporter", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:ControlFlowDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:GPUDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:Transforms", - "@local_tsl//tsl/platform:env", - ], -) diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/add_concurrent_regions.cc b/third_party/xla/xla/mlir/backends/gpu/transforms/add_concurrent_regions.cc deleted file mode 100644 index caed7dc88d08f9..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/add_concurrent_regions.cc +++ /dev/null @@ -1,236 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include -#include -#include -#include -#include - -#include "absl/strings/match.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "xla/mlir/backends/gpu/transforms/dataflow_analysis.h" -#include "xla/mlir/runtime/utils/custom_calls.h" -#include "tsl/platform/env.h" - -namespace xla { -namespace gpu { - -namespace { - -#define GEN_PASS_DEF_ADDCONCURRENTREGIONSPASS -#include "xla/mlir/backends/gpu/transforms/passes.h.inc" - -using namespace mlir; // NOLINT -using mlir::func::FuncOp; -using xla::runtime::CustomCallDeclarations; - -class AddConcurrentRegionsPass - : public impl::AddConcurrentRegionsPassBase { - void runOnOperation() override; -}; - -//===----------------------------------------------------------------------===// - -struct RegionInfo { - Operation* start; - Operation* end; - int size; -}; - -bool IsNoOp(Operation* op) { - return isa(op); -} - -int GetKernelCount(llvm::ArrayRef region) { - int kernel_count = 0; - for (const DataflowAnalysis::Node& node : region) { - Operation* op = node.operation; - if (!IsNoOp(op)) { - kernel_count++; - } - } - return kernel_count; -} - -// We use the size of the inputs to the kernel as a heuristic to avoid -// adding memory bound kernels to the concurrent region. -// The memory bandwidth on A100 is 2MB/us, so a data movement less than 10MB -// is hidden by the kernel launch overhead, which is 5us. -static constexpr int64_t kInputSizeThreshold = 10'000'000; - -bool IsKernelMemoryBound(Operation* op) { - if (auto launch_func = dyn_cast(op)) { - size_t size = 0; - - for (Value operand : launch_func.getOperands()) { - if (auto memref_type = dyn_cast(operand.getType())) { - size += (memref_type.getNumElements() * - memref_type.getElementTypeBitWidth() + - 7) / - 8; - } - } - - if (size > kInputSizeThreshold) { - return true; - } - } - - return false; -} - -// -// Return a list of pairs of operations, in which the first element is the -// first operation in the region, and the second is the last operation in the -// region. -// -// We currently use a greedy algorithm to determine region starting point: -// regions = [] -// region = {first operation} -// for operation in the capture function -// if HasDependency(region, operation) -// regions.add(region) -// region = new region -// else -// region.add(operation) -// -llvm::SmallVector GetRegionInfos( - FuncOp capture_func, DataflowAnalysis& dataflow_analysis) { - llvm::SmallVector region_infos; - DataflowAnalysis::DataflowGraph dataflow_graph = - dataflow_analysis.GetDataflowGraph(capture_func); - - // If verbose logging is enabled print the dataflow graph as a DOT graph. - if (VLOG_IS_ON(100)) { - std::cout << "Dependency graph for graph capture function " - << capture_func.getName().str() << ":\n" - << dataflow_analysis.ToDot(dataflow_graph); - } - - llvm::SmallVector region; - - auto store_region_and_start_new_region = [&]() { - int kernel_count = GetKernelCount(region); - if (kernel_count >= 2) { - RegionInfo region_info = {region.front().operation, - region.back().operation, kernel_count}; - region_infos.push_back(region_info); - } - region.clear(); - }; - - auto append_node_to_region = [&](const DataflowAnalysis::Node& node) { - if (region.empty()) { - if (!IsNoOp(node.operation)) { - region.push_back(node); - } - } else { - region.push_back(node); - } - }; - - for (const DataflowAnalysis::Node& node : dataflow_graph) { - if (isa(node.operation)) { - break; - } - - bool has_dependency = false; - for (const DataflowAnalysis::Node& node_in_region : region) { - std::vector children = node_in_region.children; - if (std::find(children.begin(), children.end(), node.index) != - children.end()) { - has_dependency = true; - break; - } - } - - if (IsKernelMemoryBound(node.operation)) { - store_region_and_start_new_region(); - } else if (has_dependency) { - store_region_and_start_new_region(); - append_node_to_region(node); - } else { - append_node_to_region(node); - } - } - - store_region_and_start_new_region(); - return region_infos; -} - -void InsertConcurrentRegions(FuncOp capture_func, - CustomCallDeclarations& custom_calls, - DataflowAnalysis& dataflow_analysis) { - llvm::SmallVector region_infos = - GetRegionInfos(capture_func, dataflow_analysis); - auto sym_table = custom_calls.sym_table(); - - for (RegionInfo region_info : region_infos) { - Operation* start = region_info.start; - Operation* end = region_info.end; - - ImplicitLocOpBuilder b(start->getLoc(), sym_table.getOp()); - func::FuncOp begin_marker = custom_calls.GetOrCreate( - b, "xla.gpu.concurrent_region.begin", TypeRange(), TypeRange()); - b.setInsertionPoint(start); - auto call = b.create(begin_marker.getName(), TypeRange()); - call->setAttr(b.getStringAttr("size"), - IntegerAttr::get(b.getIntegerType(64), region_info.size)); - - func::FuncOp end_marker = custom_calls.GetOrCreate( - b, "xla.gpu.concurrent_region.end", TypeRange(), TypeRange()); - b.setInsertionPointAfter(end); - b.create(end_marker.getName(), TypeRange()); - } -} - -//===----------------------------------------------------------------------===// - -void AddConcurrentRegionsPass::runOnOperation() { - ModuleOp module = getOperation(); - SymbolTable sym_table(module); - CustomCallDeclarations custom_calls(std::move(sym_table)); - - auto func_ops = llvm::to_vector(module.getOps()); - - for (auto func_op : func_ops) { - // Find the gpu graph capture function. - if (absl::StrContains(func_op.getSymNameAttr().str(), - "xla.gpu.graph.capture")) { - InsertConcurrentRegions(func_op, custom_calls, - getAnalysis()); - } - } -} - -} // namespace - -std::unique_ptr> createAddConcurrentRegionsPass() { - return std::make_unique(); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/add_hlo_trace_annotations.cc b/third_party/xla/xla/mlir/backends/gpu/transforms/add_hlo_trace_annotations.cc deleted file mode 100644 index c4ba4e3008acb2..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/add_hlo_trace_annotations.cc +++ /dev/null @@ -1,88 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "absl/strings/match.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "xla/mlir/backends/gpu/transforms/passes.h" -#include "xla/mlir/runtime/ir/rt_dialect.h" -#include "xla/translate/mhlo_to_hlo/location_exporter.h" - -namespace xla { -namespace gpu { - -#define GEN_PASS_DEF_ADDHLOTRACEANNOTATIONSPASS -#include "xla/mlir/backends/gpu/transforms/passes.h.inc" - -using namespace mlir; // NOLINT - -using xla::runtime::HloTraceAttr; - -class AddHloTraceAnnotationsPass - : public impl::AddHloTraceAnnotationsPassBase { - void runOnOperation() override; - - void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); - } -}; - -//===----------------------------------------------------------------------===// - -void AddHloTraceAnnotationsPass::runOnOperation() { - MLIRContext* ctx = &getContext(); - - ModuleOp module = getOperation(); - SymbolTable sym_table(module); - - getOperation().walk([&](func::CallOp call) { - // Check if the callee is a custom call. - auto callee = sym_table.lookup(call.getCallee()); - if (!callee->hasAttr("rt.custom_call")) return; - - // Drop multi-op trace for CUDA graphs since they are too large for xprof to - // display. - // TODO(b/275240695): Report the graph content once the Xprof team provides - // an API. - if (absl::StrContains(call.getCalleeAttr().getValue(), - "xla.gpu.graph.launch")) { - auto capture = call->getAttr("capture").cast(); - std::string op_name = "cuda_graph/" + capture.getValue().str(); - auto annotation = HloTraceAttr::get(ctx, std::move(op_name)); - call->setAttr("rt.trace", annotation); - return; - } - - // HLO operation name is encoded in the operation location. - std::string hlo_op = mlir::mhlo::GetDebugNameFromLocation(call->getLoc()); - auto annotation = HloTraceAttr::get(ctx, std::move(hlo_op)); - call->setAttr("rt.trace", annotation); - }); -} - -std::unique_ptr> -createAddHloTraceAnnotationsPass() { - return std::make_unique(); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/dataflow_analysis.cc b/third_party/xla/xla/mlir/backends/gpu/transforms/dataflow_analysis.cc deleted file mode 100644 index 59d70def88789f..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/dataflow_analysis.cc +++ /dev/null @@ -1,277 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/mlir/backends/gpu/transforms/dataflow_analysis.h" - -#include -#include -#include -#include - -#include "absl/strings/str_cat.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/IR/Visitors.h" // from @llvm-project -#include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" - -namespace xla { -namespace gpu { - -namespace { - -using namespace mlir; // NOLINT -using mlir::BlockArgument; -using mlir::Operation; -using mlir::func::FuncOp; - -// Represents a slice of the buffer argument to the graph capture function. -struct BufferUse { - BlockArgument arg; - size_t offset; - size_t byte_len; - - // The buffer is only read by the operation. - bool read_only; -}; - -BufferUse GetBufferUse(Value operand, bool read_only = false) { - Operation* defining_op = operand.getDefiningOp(); - if (!defining_op) { - auto block_argument = cast(operand); - auto memref_type = cast(block_argument.getType()); - size_t byte_len = - (memref_type.getNumElements() * memref_type.getElementTypeBitWidth() + - 7) / - 8; - return {block_argument, 0, byte_len, read_only}; - } - - if (isa(defining_op)) { - auto view_op = cast(defining_op); - auto buffer_use = GetBufferUse(view_op.getSource()); - - IntegerAttr offset_attr; - bool is_constant = - matchPattern(view_op.getByteShift(), m_Constant(&offset_attr)); - if (!is_constant) { - // Failed to refine the BufferUse. - return buffer_use; - } - size_t offset = offset_attr.getInt(); - - // Get len. - auto memref_type = cast(view_op.getType()); - // TODO(b/274157088): Handle the case where elements are complex numbers. - if (!memref_type.getElementType().isIntOrFloat()) { - return buffer_use; - } - - size_t byte_len = - (memref_type.getNumElements() * memref_type.getElementTypeBitWidth() + - 7) / - 8; - - return {buffer_use.arg, buffer_use.offset + offset, byte_len, read_only}; - } - - if (auto cast = dyn_cast(defining_op)) { - return GetBufferUse(cast.getSource(), read_only); - } - - return {}; -} - -llvm::SmallVector GetBufferUses(Operation& operation) { - llvm::SmallVector operand_buffer_uses; - if (auto launch_func = dyn_cast(operation)) { - auto kernel_func = - SymbolTable::lookupNearestSymbolFrom( - &operation, launch_func.getKernel()); - auto kernel_operands = launch_func.getKernelOperands(); - for (auto it : llvm::enumerate(kernel_operands)) { - BufferUse buffer_use = GetBufferUse( - it.value(), - /*read_only=*/!kernel_func.getArgAttrOfType( - it.index(), "lmhlo.written")); - operand_buffer_uses.push_back(buffer_use); - } - } else if (auto gemm = dyn_cast(operation)) { - BufferUse buffer_use_0 = GetBufferUse(gemm.getA(), /*read_only=*/true); - BufferUse buffer_use_1 = GetBufferUse(gemm.getB(), /*read_only=*/true); - BufferUse buffer_use_2 = GetBufferUse(gemm.getC(), /*read_only=*/false); - operand_buffer_uses.push_back(buffer_use_0); - operand_buffer_uses.push_back(buffer_use_1); - operand_buffer_uses.push_back(buffer_use_2); - } else if (auto memcpy = dyn_cast(operation)) { - BufferUse src_buffer = GetBufferUse(memcpy.getSrc(), /*read_only=*/true); - BufferUse dst_buffer = GetBufferUse(memcpy.getDst(), /*read_only=*/false); - operand_buffer_uses.push_back(src_buffer); - operand_buffer_uses.push_back(dst_buffer); - } - - return operand_buffer_uses; -} - -// Arguments to the graph capture function may have the "lmhlo.constant_name" -// attribute, which indicates that the passed-in buffer is constant. -bool IsConstant(BlockArgument block_argument) { - // Check if the input buffer is marked as constant. - Region* parent_region = block_argument.getParentRegion(); - auto parent_func = parent_region->getParentOfType(); - unsigned parent_func_arg_index = block_argument.getArgNumber(); - auto cst = parent_func.getArgAttrOfType(parent_func_arg_index, - "lmhlo.constant_name"); - return cst != nullptr; -} - -// Check if two buffer_uses overlap. -bool HasDependency(BufferUse buffer_use_a, BufferUse buffer_use_b) { - if (buffer_use_a.arg.getArgNumber() != buffer_use_b.arg.getArgNumber()) - return false; - if (IsConstant(buffer_use_a.arg) || IsConstant(buffer_use_b.arg)) - return false; - if (buffer_use_a.read_only && buffer_use_b.read_only) return false; - - // Check if two buffer slices overlap. - size_t start1 = buffer_use_a.offset; - size_t end1 = buffer_use_a.offset + buffer_use_a.byte_len; - size_t start2 = buffer_use_b.offset; - size_t end2 = buffer_use_b.offset + buffer_use_b.byte_len; - if (std::max(start1, start2) < std::min(end1, end2)) { - return true; - } - return false; -} - -bool HasDependency(llvm::ArrayRef buffer_uses_a, - llvm::ArrayRef buffer_uses_b) { - for (auto buffer_use_a : buffer_uses_a) { - for (auto buffer_use_b : buffer_uses_b) { - if (HasDependency(buffer_use_a, buffer_use_b)) return true; - } - } - return false; -} - -// Remove edges that are redundant for determining the execution order of -// kernels. We use the following algorithm to compute the transitive reduction: -// -// For source node in graph: -// For each edge (source -> target) -// longest_distance = the length of the longest path from source to target -// if (longest_distance > 1): -// remove (source -> target) -// -void TransitiveReduction(DataflowAnalysis::DataflowGraph& graph) { - std::vector> parents(graph.size(), std::vector()); - for (const DataflowAnalysis::Node& node : graph) { - for (size_t child_index : node.children) { - parents[child_index].push_back(node.index); - } - } - - std::vector longest_distance(graph.size()); - for (DataflowAnalysis::Node& source : graph) { - if (source.children.empty()) { - continue; - } - - std::fill(longest_distance.begin(), longest_distance.end(), 0); - size_t farthest_child = source.children.back(); - for (size_t target = source.index + 1; target <= farthest_child; target++) { - for (size_t mid : parents[target]) { - // If the mid node is before source in the topological order, no path - // source -> mid -> target can exits and we can skip it. - if (mid >= source.index) { - // If source -> mid -> target is longer than the longest path so far - // from source -> target, update the longest distance. - int candidate_longest_distance = longest_distance[mid] + 1; - if (candidate_longest_distance > longest_distance[target]) { - longest_distance[target] = candidate_longest_distance; - } - } - } - } - - source.children.erase( - std::remove_if( - source.children.begin(), source.children.end(), - [&](size_t target) { return longest_distance[target] > 1; }), - source.children.end()); - } -} - -} // namespace - -DataflowAnalysis::DataflowGraph DataflowAnalysis::GetDataflowGraph( - FuncOp graph_capture_function) { - std::vector graph; - for (auto [index, op] : llvm::enumerate(graph_capture_function.getOps())) { - graph.push_back(Node{&op, index, {}}); - } - - // A vector that stores the buffer used by each operation in the graph. The - // i-th operation's buffer uses are stored as the vector buffer_uses[i]; - std::vector> buffer_uses; - for (Operation& operation : graph_capture_function.getOps()) { - buffer_uses.push_back(GetBufferUses(operation)); - } - - for (int i = 0; i < graph.size(); ++i) { - Node& node_i = graph[i]; - llvm::ArrayRef buffer_uses_i = buffer_uses[i]; - for (int j = i + 1; j < graph.size(); ++j) { - llvm::ArrayRef buffer_uses_j = buffer_uses[j]; - if (HasDependency(buffer_uses_i, buffer_uses_j)) { - node_i.children.push_back(j); - } - } - } - - TransitiveReduction(graph); - return graph; -} - -std::string DataflowAnalysis::ToDot(const DataflowGraph& graph) { - std::string pad; - std::string res; - auto indent = [&] { pad.append(2, ' '); }; - auto outdent = [&] { pad.resize(pad.size() - 2); }; - auto addline = [&](auto&&... args) { - absl::StrAppend(&res, pad, args..., "\n"); - }; - auto get_name = [](const Node& node) -> std::string { - return absl::StrCat("\"", node.operation->getName().getStringRef().str(), - "_", node.index, "\""); - }; - - addline("digraph {"); - indent(); - for (const Node& node : graph) { - for (size_t child_index : node.children) { - Node child = graph[child_index]; - addline(get_name(node), " -> ", get_name(child)); - } - } - outdent(); - addline("}"); - return res; -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/dataflow_analysis.h b/third_party/xla/xla/mlir/backends/gpu/transforms/dataflow_analysis.h deleted file mode 100644 index 32bc001f03d211..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/dataflow_analysis.h +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_MLIR_BACKENDS_GPU_TRANSFORMS_DATAFLOW_ANALYSIS_H_ -#define XLA_MLIR_BACKENDS_GPU_TRANSFORMS_DATAFLOW_ANALYSIS_H_ - -#include -#include - -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project - -namespace xla { -namespace gpu { - -class DataflowAnalysis { - public: - explicit DataflowAnalysis(mlir::Operation* op) {} - - struct Node { - mlir::Operation* operation; - size_t index; - std::vector children; - }; - - using DataflowGraph = std::vector; - - // This function creates a dataflow graph that represent data dependencies in - // the graph capture function. The analysis relies on some properties of the - // IR in XLA: - // (1) Buffer arguments do not alias. It is guaranteed that two buffer - // arguments to the graph capture function do not overlap. - // (2) XLA operations do not have any side effects beyond writing to its - // buffer arguments. So it is safe to reorder operations if they do not - // have write-conflicts. - // (3) We have information about read-only and read-write buffer arguments. - DataflowGraph GetDataflowGraph(mlir::func::FuncOp graph_capture_function); - - std::string ToDot(const DataflowGraph& graph); -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_MLIR_BACKENDS_GPU_TRANSFORMS_DATAFLOW_ANALYSIS_H_ diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/gpu_to_gpu_runtime.cc b/third_party/xla/xla/mlir/backends/gpu/transforms/gpu_to_gpu_runtime.cc deleted file mode 100644 index 02e97984c564a3..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/gpu_to_gpu_runtime.cc +++ /dev/null @@ -1,252 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/IR/TypeRange.h" // from @llvm-project -#include "mlir/IR/ValueRange.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "xla/mlir/backends/gpu/transforms/uid_generator.h" -#include "xla/mlir/runtime/utils/custom_calls.h" - -namespace xla { -namespace gpu { - -#define GEN_PASS_DEF_CONVERTGPUTOGPURUNTIMEPASS -#include "xla/mlir/backends/gpu/transforms/passes.h.inc" - -using namespace mlir; // NOLINT - -using mlir::gpu::GPUModuleOp; -using mlir::gpu::LaunchFuncOp; -using mlir::gpu::MemcpyOp; -using mlir::gpu::MemsetOp; - -using xla::runtime::CustomCallDeclarations; - -class ConvertGpuToGpuRuntimePass - : public impl::ConvertGpuToGpuRuntimePassBase { - void runOnOperation() override; - - void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); - } -}; - -//===----------------------------------------------------------------------===// - -class GpuModuleOpLowering : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(GPUModuleOp op, - PatternRewriter& rewriter) const override { - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// - -class MemcpyOpLowering : public OpRewritePattern { - public: - MemcpyOpLowering(MLIRContext* ctx, CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), custom_calls_(custom_calls) {} - - // We use a heuristic to identify the direction of the memcpy operation, if - // the operand was allocated by alloca op or is a global memref, then it must - // be a memref on the host. - static bool IsHostMemRef(Value value) { - auto* op = value.getDefiningOp(); - return llvm::isa_and_nonnull(op); - } - - // Identify the direction of the memcpy operation. - static StringRef Target(MemcpyOp op) { - if (IsHostMemRef(op.getDst())) return "xla.gpu.memcpy.d2h"; - if (IsHostMemRef(op.getSrc())) return "xla.gpu.memcpy.h2d"; - return "xla.gpu.memcpy.d2d"; - } - - LogicalResult matchAndRewrite(MemcpyOp op, - PatternRewriter& rewriter) const override { - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = custom_calls_.GetOrCreate(b, Target(op), op); - - auto stream = op->getAttrOfType("stream"); - - // Create a function launch call operation. - auto call = rewriter.replaceOpWithNewOp( - op, callee.getName(), TypeRange(), op.getOperands()); - - if (stream) { - call->setAttr(b.getStringAttr("stream"), stream); - } else { - call->setAttr(b.getStringAttr("stream"), b.getI64IntegerAttr(0)); - } - - return success(); - } - - private: - CustomCallDeclarations& custom_calls_; -}; - -//===----------------------------------------------------------------------===// - -class MemsetOpLowering : public OpRewritePattern { - private: - static constexpr const char kCustomCallTarget[] = "xla.gpu.memset"; - - public: - MemsetOpLowering(MLIRContext* ctx, CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(MemsetOp op, - PatternRewriter& rewriter) const override { - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = custom_calls_.GetOrCreate(b, kCustomCallTarget, op); - - // Create a function launch call operation. - rewriter.replaceOpWithNewOp(op, callee.getName(), TypeRange(), - op.getOperands()); - - return success(); - } - - private: - CustomCallDeclarations& custom_calls_; -}; - -//===----------------------------------------------------------------------===// - -class LaunchFuncOpLowering : public OpRewritePattern { - private: - static constexpr const char kCustomCallTarget[] = "xla.gpu.func.launch"; - - public: - LaunchFuncOpLowering(MLIRContext* ctx, UidGenerator& uid, - CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), uid_(uid), custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(LaunchFuncOp op, - PatternRewriter& rewriter) const override { - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - - // Cast grid and block dimensions to i32 before passing to the custom call. - auto cast = [&](mlir::Value value) { - return b.create(b.getI32Type(), value); - }; - - // Prepare arguments for the custom call. - llvm::SmallVector args = { - cast(op.getGridSizeX()), cast(op.getGridSizeY()), - cast(op.getGridSizeZ()), cast(op.getBlockSizeX()), - cast(op.getBlockSizeY()), cast(op.getBlockSizeZ())}; - - // Shared memory size is optional for the `gpu.launch` but mandatory for the - // Xla runtime kernel launch custom call. - if (op.getDynamicSharedMemorySize()) { - args.insert(args.begin(), op.getDynamicSharedMemorySize()); - } else { - auto zero = b.create(0, b.getI32Type()); - args.insert(args.begin(), zero); - } - - // Add kernel arguments. - llvm::copy(op.getKernelOperands(), std::back_inserter(args)); - - auto computation = op->getAttr("__custom_fusion_computation"); - - // Get or create a custom call function declaration. - func::FuncOp callee = custom_calls_.GetOrCreate( - b, computation ? "xla.gpu.func.custom_launch" : "xla.gpu.func.launch", - TypeRange(ValueRange(args)), TypeRange()); - - // Create a function launch call operation. - auto call = b.create(callee.getName(), TypeRange(), args); - call->setAttr(b.getStringAttr("kernel"), op.getKernelName()); - - // Assign a unique id to this instance of a kernel launch operation. - call->setAttr(b.getStringAttr("uid"), b.getI64IntegerAttr(uid_.uid())); - - // Set assigned stream for the kernel launch. - auto stream = op->getAttrOfType("stream"); - if (stream) { - call->setAttr(b.getStringAttr("stream"), stream); - } else { - call->setAttr(b.getStringAttr("stream"), b.getI64IntegerAttr(0)); - } - - // Copy custom fusion computation. - if (computation) { - call->setAttr("__custom_fusion_computation", computation); - } - - // Erase the original gpu launch operation. - rewriter.eraseOp(op); - - return success(); - } - - private: - UidGenerator& uid_; - CustomCallDeclarations& custom_calls_; -}; - -//===----------------------------------------------------------------------===// - -void ConvertGpuToGpuRuntimePass::runOnOperation() { - ModuleOp module = getOperation(); - MLIRContext* ctx = module.getContext(); - - // Keep track of the custom calls created from the lowered operations. - SymbolTable sym_table(module); - CustomCallDeclarations custom_calls(std::move(sym_table)); - - // Each kernel launch operation gets a unique id. - UidGenerator kernel_uid; - - // Convert gpu operations to XLA gpu runtime custom calls. - RewritePatternSet patterns(ctx); - patterns.insert(ctx); - patterns.insert(ctx, kernel_uid, custom_calls); - patterns.insert(ctx, custom_calls); - - if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) - return signalPassFailure(); -} - -std::unique_ptr> -createConvertGpuToGpuRuntimePass() { - return std::make_unique(); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/lmhlo_gpu_to_gpu_runtime.cc b/third_party/xla/xla/mlir/backends/gpu/transforms/lmhlo_gpu_to_gpu_runtime.cc deleted file mode 100644 index 7eba728ed7aef1..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/lmhlo_gpu_to_gpu_runtime.cc +++ /dev/null @@ -1,1112 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include -#include -#include -#include - -#include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "xla/mlir/backends/gpu/transforms/uid_generator.h" -#include "xla/mlir/runtime/ir/rt_dialect.h" -#include "xla/mlir/runtime/utils/custom_calls.h" -#include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" -#include "xla/stream_executor/blas.h" - -namespace xla { -namespace gpu { - -#define GEN_PASS_DEF_CONVERTLMHLOGPUTOGPURUNTIMEPASS -#include "xla/mlir/backends/gpu/transforms/passes.h.inc" - -using namespace mlir; // NOLINT - -using mlir::lmhlo_gpu::CholeskyOp; -using mlir::lmhlo_gpu::ConvBackwardFilterOp; -using mlir::lmhlo_gpu::ConvBackwardInputOp; -using mlir::lmhlo_gpu::ConvForwardFusedOp; -using mlir::lmhlo_gpu::ConvForwardFusedSideInputOp; -using mlir::lmhlo_gpu::ConvForwardGraphOp; -using mlir::lmhlo_gpu::ConvForwardOp; -using mlir::lmhlo_gpu::CublasLtMatmulF8Op; -using mlir::lmhlo_gpu::CublasLtMatmulOp; -using mlir::lmhlo_gpu::CudnnConvReorderFilterAndBiasOp; -using mlir::lmhlo_gpu::CudnnConvReorderFilterOp; -using mlir::lmhlo_gpu::CudnnNormOp; -using mlir::lmhlo_gpu::GEMMOp; -using mlir::lmhlo_gpu::RadixSortOp; - -using xla::runtime::CustomCallDeclarations; - -class ConvertLmhloGpuToGpuRuntimePass - : public impl::ConvertLmhloGpuToGpuRuntimePassBase< - ConvertLmhloGpuToGpuRuntimePass> { - void runOnOperation() override; - - void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); - } -}; - -//===----------------------------------------------------------------------===// - -class GemmOpLowering : public OpRewritePattern { - static constexpr const char kCustomCallTarget[] = "xla.gpu.gemm"; - - public: - GemmOpLowering(MLIRContext* ctx, UidGenerator& uid, - CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), uid_(uid), custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(GEMMOp op, - PatternRewriter& rewriter) const override { - { - // Set requires_blas attribute to true. The runtime pass will add cuBLAS - // initialization custom call to the entry function if the attribute is - // set to true. - auto module = op.getOperation()->getParentOfType(); - ImplicitLocOpBuilder b(module.getLoc(), rewriter); - module->setAttr(b.getStringAttr(runtime::kRequiresBlasAttrName), - BoolAttr::get(b.getContext(), true)); - } - - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = custom_calls_.GetOrCreate(b, kCustomCallTarget, op); - - // Convert Gemm to a function call. - auto call = rewriter.create(op.getLoc(), callee.getName(), - TypeRange(), op.getOperands()); - - // Assign a unique id to this instance of a gemm operation. - call->setAttr(b.getStringAttr("uid"), b.getI64IntegerAttr(uid_.uid())); - - // Copy backend specific attributes. - auto algorithm_attr = - op.getAlgorithm() - ? op.getAlgorithmAttr() - : b.getI64IntegerAttr(stream_executor::blas::kDefaultGemmAlgo); - call->setAttr(b.getStringAttr("algorithm"), algorithm_attr); - call->setAttr(b.getStringAttr("alpha_imag"), op.getAlphaImagAttr()); - call->setAttr(b.getStringAttr("alpha_real"), op.getAlphaRealAttr()); - call->setAttr(b.getStringAttr("beta"), op.getBetaAttr()); - call->setAttr(b.getStringAttr("dot_dims"), op.getDotDimensionNumbers()); - - if (auto precisions = op.getPrecisionConfig()) { - llvm::SmallVector values; - for (auto precision : *precisions) { - auto value = precision.cast().getValue(); - values.push_back(static_cast(value)); - } - call->setAttr(b.getStringAttr("precision"), b.getI32TensorAttr(values)); - } else { - call->setAttr(b.getStringAttr("precision"), b.getI32TensorAttr({0, 0})); - } - - // Erase the original gemm operation. - rewriter.eraseOp(op); - - return success(); - } - - private: - UidGenerator& uid_; - CustomCallDeclarations& custom_calls_; -}; - -//===----------------------------------------------------------------------===// - -class CublasLtMatmulOpLowering : public OpRewritePattern { - private: - static constexpr const char kCustomCallTarget[] = "xla.gpu.cublas.lt.matmul"; - - public: - CublasLtMatmulOpLowering(MLIRContext* ctx, UidGenerator& uid, - CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), - uid_(uid), - custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(CublasLtMatmulOp op, - PatternRewriter& rewriter) const override { - // Get the custom call target. - std::string matmul = kCustomCallTarget; - - switch (op.getEpilogue()) { - case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::Default: - case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::Relu: - case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::Gelu: - if (op.getNumOperands() != 4) { - return op.emitOpError("unexpected number of operands for matmul"); - } - break; - case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::Bias: - case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::BiasRelu: - case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::BiasGelu: - if (op.getNumOperands() != 5) { - return op.emitOpError("unexpected number of operands for matmul"); - } - matmul += ".bias"; - break; - case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::GeluAux: - if (op.getNumOperands() != 5) { - return op.emitOpError("unexpected number of operands for matmul"); - } - matmul += ".aux"; - break; - case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::BiasGeluAux: - if (op.getNumOperands() != 6) { - return op.emitOpError("unexpected number of operands for matmul"); - } - matmul += ".bias.aux"; - break; - } - - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = custom_calls_.GetOrCreate(b, matmul, op); - - // Convert matmul to a function call. - auto call = rewriter.create(op.getLoc(), callee.getName(), - TypeRange(), op.getOperands()); - - // Assign a unique id to this instance of a matmul operation. - call->setAttr(b.getStringAttr("uid"), b.getI64IntegerAttr(uid_.uid())); - - // Copy backend specific attributes. - call->setAttr(b.getStringAttr("algorithm"), op.getAlgorithmAttr()); - call->setAttr(b.getStringAttr("alpha_imag"), op.getAlphaImagAttr()); - call->setAttr(b.getStringAttr("alpha_real"), op.getAlphaRealAttr()); - call->setAttr(b.getStringAttr("beta"), op.getBetaAttr()); - call->setAttr(b.getStringAttr("dot_dims"), op.getDotDimensionNumbers()); - call->setAttr(b.getStringAttr("epilogue"), op.getEpilogueAttr()); - - // TODO(ezhulenev): Today we can't pass an array of enum attributes to the - // custom call. Also we do not have a corresponding precision enum on the - // SE/XLA side, so we encode it as an i32 array (tensor). - if (auto precisions = op.getPrecisionConfig()) { - llvm::SmallVector values; - for (auto precision : *precisions) { - auto value = precision.cast().getValue(); - values.push_back(static_cast(value)); - } - call->setAttr(b.getStringAttr("precision"), b.getI32TensorAttr(values)); - } else { - call->setAttr(b.getStringAttr("precision"), b.getI32TensorAttr({0, 0})); - } - - // Erase the original matmul operation. - rewriter.eraseOp(op); - - return success(); - } - - private: - UidGenerator& uid_; - CustomCallDeclarations& custom_calls_; -}; - -// As above for FP8 Custom Calls. -class CublasLtMatmulF8OpLowering : public OpRewritePattern { - private: - static constexpr const char kCustomCallTarget[] = - "xla.gpu.cublas.lt.matmul.f8"; - - public: - CublasLtMatmulF8OpLowering(MLIRContext* ctx, UidGenerator& uid, - CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), - uid_(uid), - custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(CublasLtMatmulF8Op op, - PatternRewriter& rewriter) const override { - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = custom_calls_.GetOrCreate(b, kCustomCallTarget, op); - - // Convert matmul to a function call. - auto call = rewriter.create(op.getLoc(), callee.getName(), - TypeRange(), op.getOperands()); - - // Assign a unique id to this instance of a matmul operation. - call->setAttr(b.getStringAttr("uid"), b.getI64IntegerAttr(uid_.uid())); - - // Copy backend specific attributes. - call->setAttr(b.getStringAttr("algorithm"), op.getAlgorithmAttr()); - call->setAttr(b.getStringAttr("alpha_imag"), op.getAlphaImagAttr()); - call->setAttr(b.getStringAttr("alpha_real"), op.getAlphaRealAttr()); - call->setAttr(b.getStringAttr("beta"), op.getBetaAttr()); - call->setAttr(b.getStringAttr("dot_dims"), op.getDotDimensionNumbers()); - call->setAttr(b.getStringAttr("epilogue"), op.getEpilogueAttr()); - - // TODO(ezhulenev): Today we can't pass an array of enum attributes to the - // custom call. Also we do not have a corresponding precision enum on the - // SE/XLA side, so we encode it as an i32 array (tensor). - if (auto precisions = op.getPrecisionConfig()) { - llvm::SmallVector values; - for (auto precision : *precisions) { - auto value = precision.cast().getValue(); - values.push_back(static_cast(value)); - } - call->setAttr(b.getStringAttr("precision"), b.getI32TensorAttr(values)); - } else { - call->setAttr(b.getStringAttr("precision"), b.getI32TensorAttr({0, 0})); - } - - // Erase the original matmul operation. - rewriter.eraseOp(op); - - return success(); - } - - private: - UidGenerator& uid_; - CustomCallDeclarations& custom_calls_; -}; - -//===----------------------------------------------------------------------===// - -template -class ConvOpLowering : public OpRewritePattern { - private: - static StringRef CustomCallTarget(ConvForwardOp) { - return "xla.gpu.conv.forward"; - } - static StringRef CustomCallTarget(ConvForwardFusedOp) { - return "xla.gpu.conv.forward.fused"; - } - static StringRef CustomCallTarget(ConvForwardFusedSideInputOp) { - return "xla.gpu.conv.forward.fused.side_input"; - } - static StringRef CustomCallTarget(ConvBackwardFilterOp) { - return "xla.gpu.conv.backward.filter"; - } - static StringRef CustomCallTarget(ConvBackwardInputOp) { - return "xla.gpu.conv.backward.input"; - } - static StringRef CustomCallTarget(ConvForwardGraphOp) { - return "xla.gpu.conv.forward.graph"; - } - - public: - explicit ConvOpLowering(MLIRContext* ctx, UidGenerator& uid, - CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), uid_(uid), custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(Conv op, - PatternRewriter& rewriter) const override { - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = - custom_calls_.GetOrCreate(b, CustomCallTarget(op), op); - - // Convert Conv to a function call. - auto call = rewriter.create(op.getLoc(), callee.getName(), - TypeRange(), op.getOperands()); - - // Helper functins to copy attributes from the conv op to the custom call. - auto set_attr = [&](StringRef name, Attribute attr) { - call->setAttr(b.getStringAttr(name), attr); - }; - - auto set_xi64 = [&](StringRef name, - std::optional attr) { - SmallVector values; - if (attr.has_value()) - values = llvm::to_vector(attr->getValues()); - set_attr(name, b.getI64TensorAttr(values)); - }; - - // Convert `BoolElementsAttr` to i64 before passing to the runtime. - // TODO(ezhulenev): Allow passing boolean tensors to the XLA custom calls. - auto set_xi1 = [&](StringRef name, std::optional attr) { - SmallVector values; - if (attr.has_value()) - values.assign(attr->getValues().begin(), - attr->getValues().end()); - set_attr(name, b.getI64TensorAttr(values)); - }; - - // Assign a unique id to this instance of a conv operation. - call->setAttr(b.getStringAttr("uid"), b.getI64IntegerAttr(uid_.uid())); - - // Copy dimension number attributes. - call->setAttr(b.getStringAttr("conv_dims"), op.getDimensionNumbers()); - - // Copy convolution window attributes. - set_xi1("window_reversal", op.getWindowReversal()); - set_xi64("window_strides", op.getWindowStrides()); - set_xi64("lhs_dilation", op.getLhsDilation()); - set_xi64("rhs_dilation", op.getRhsDilation()); - set_xi64("padding", op.getPadding()); - - // Copy backend config. - call->setAttr(b.getStringAttr("backend_config"), op.getBackendConfig()); - - // Copy remaining attributes. - set_attr("feature_group_count", op.getFeatureGroupCountAttr()); - set_attr("result_scale", op.getResultScaleAttr()); - - // Copy attributes specific for fused convolutions. - if (auto fused = dyn_cast(op.getOperation())) { - call->setAttr(b.getStringAttr("activation_mode"), - fused.getActivationModeAttr()); - set_attr("leakyrelu_alpha", fused.getLeakyreluAlphaAttr()); - } - - // Copy attributes specific for fused convolutions with side input. - if (auto fused = dyn_cast(op.getOperation())) { - call->setAttr(b.getStringAttr("activation_mode"), - fused.getActivationModeAttr()); - set_attr("side_input_scale", fused.getSideInputScaleAttr()); - } - - // Copy attributes specific for graph convolutions. - if (auto fused = dyn_cast(op.getOperation())) { - call->setAttr(b.getStringAttr("n_aux_outputs"), - fused.getNAuxOutputsAttr()); - call->setAttr(b.getStringAttr("serialized_graph"), - fused.getSerializedGraphAttr()); - } - - // Erase the original conv operation. - rewriter.eraseOp(op); - - return success(); - } - - private: - UidGenerator& uid_; - CustomCallDeclarations& custom_calls_; -}; - -class ConvForwardOpLowering : public ConvOpLowering { - public: - using ConvOpLowering::ConvOpLowering; -}; - -class ConvForwardFusedOpLowering : public ConvOpLowering { - public: - using ConvOpLowering::ConvOpLowering; -}; - -class ConvBackwardFilterOpLowering - : public ConvOpLowering { - public: - using ConvOpLowering::ConvOpLowering; -}; - -class ConvBackwardInputOpLowering : public ConvOpLowering { - public: - using ConvOpLowering::ConvOpLowering; -}; - -class ConvForwardFusedSideInputOpLowering - : public ConvOpLowering { - public: - using ConvOpLowering::ConvOpLowering; -}; - -class ConvForwardGraphOpLowering : public ConvOpLowering { - public: - using ConvOpLowering::ConvOpLowering; -}; - -//===----------------------------------------------------------------------===// - -template -class CudnnConvReorderOpLowering : public OpRewritePattern { - private: - static StringRef CustomCallTarget(CudnnConvReorderFilterOp) { - return "xla.gpu.conv.reorder.filter"; - } - static StringRef CustomCallTarget(CudnnConvReorderFilterAndBiasOp) { - return "xla.gpu.conv.reorder.filter_and_bias"; - } - - public: - explicit CudnnConvReorderOpLowering(MLIRContext* ctx, - CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(ConvReorder op, - PatternRewriter& rewriter) const override { - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = - custom_calls_.GetOrCreate(b, CustomCallTarget(op), op); - - auto filterDims = rewriter.getDenseI64ArrayAttr( - llvm::to_vector(op.getFilterDims().template getValues())); - - // Replace ConvOp with an equivalent custom call. - auto call = rewriter.replaceOpWithNewOp( - op, callee.getName(), TypeRange(), op.getOperands()); - call->setAttr(b.getStringAttr("filter_dims"), filterDims); - - return success(); - } - - private: - CustomCallDeclarations& custom_calls_; -}; - -class CudnnConvReorderFilterOpLowering - : public CudnnConvReorderOpLowering { - public: - using CudnnConvReorderOpLowering::CudnnConvReorderOpLowering; -}; - -class CudnnConvReorderFilterAndBiasOpLowering - : public CudnnConvReorderOpLowering { - public: - using CudnnConvReorderOpLowering::CudnnConvReorderOpLowering; -}; - -//===----------------------------------------------------------------------===// - -class CholeskyOpLowering : public OpRewritePattern { - private: - static constexpr const char kCustomCallTarget[] = "xla.gpu.cholesky"; - - public: - explicit CholeskyOpLowering(MLIRContext* ctx, - CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(CholeskyOp op, - PatternRewriter& rewriter) const override { - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = custom_calls_.GetOrCreate(b, kCustomCallTarget, op); - - // Convert Cholesky to a function call. - auto call = rewriter.create(op.getLoc(), callee.getName(), - TypeRange(), op.getOperands()); - - const auto& dims = - op.getInput().getType().cast().getShape(); - if (dims.size() < 2) - return op.emitOpError() << "Input's dimension count (" << dims.size() - << ") must be 2 or greater."; - int64_t n = dims[dims.size() - 1]; - int64_t batch_size = - std::accumulate(dims.begin(), dims.end() - 2, int64_t{1}, - [](int64_t a, int64_t b) { return a * b; }); - - // Copy backend specific attributes. - call->setAttr(b.getStringAttr("batch_size"), - b.getI64IntegerAttr(batch_size)); - call->setAttr(b.getStringAttr("n"), b.getI64IntegerAttr(n)); - call->setAttr(b.getStringAttr("is_lower"), op.getIsLowerAttr()); - - // Erase the original Cholesky operation. - rewriter.eraseOp(op); - - return success(); - } - - private: - CustomCallDeclarations& custom_calls_; -}; - -class NormOpLowering : public OpRewritePattern { - private: - static constexpr const char kCustomCallTarget[] = "xla.gpu.norm"; - - public: - NormOpLowering(MLIRContext* ctx, UidGenerator& uid, - CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), - uid_(uid), - custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(CudnnNormOp op, - PatternRewriter& rewriter) const override { - // Get or create a Custom Call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = custom_calls_.GetOrCreate(b, kCustomCallTarget, op); - - // Convert norm to a function call. - auto call = rewriter.create(op.getLoc(), callee.getName(), - TypeRange(), op.getOperands()); - - // Assign a unique id to this instance of a norm operation. - call->setAttr(b.getStringAttr("uid"), b.getI64IntegerAttr(uid_.uid())); - - // Copy backend specific attributes. - call->setAttr(b.getStringAttr("norm_algorithm_config"), - op.getAlgorithmConfigAttr()); - call->setAttr(b.getStringAttr("epsilon"), op.getEpsilonAttr()); - - mlir::ArrayAttr array = op.getOperandLayouts(); - SmallVector values; - for (auto array_elem : array) { - mlir::IntegerAttr attr = array_elem.dyn_cast(); - values.push_back(attr.getInt()); - } - call->setAttr(b.getStringAttr("operand_layouts"), - b.getI64TensorAttr(values)); - - // Erase the original norm operation. - rewriter.eraseOp(op); - - return success(); - } - - private: - UidGenerator& uid_; - CustomCallDeclarations& custom_calls_; -}; - -using mlir::lmhlo_gpu::fusedMHAOp; - -template -class FusedAttentionForwardLowering - : public OpRewritePattern { - private: - static constexpr const char kCustomCallTarget[] = "xla.gpu.fused.attention."; - - public: - explicit FusedAttentionForwardLowering(MLIRContext* ctx, UidGenerator& uid, - CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), - uid_(uid), - custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(FusedDotAttentionForward op, - PatternRewriter& rewriter) const override { - // Get the custom call target. - std::string fused_attention = kCustomCallTarget; - auto num_operands = op.getNumOperands(); - switch (op.getFusedMhaDag()) { - case mlir::lmhlo_gpu::FusedMhaDagSignature::Default: - if (num_operands == 5) { - fused_attention += "bmm.bmm.inference"; - } else if (num_operands == 6) { - fused_attention += "bmm.bmm.forward"; - } else { - return op.emitOpError( - "unexpected number of operands for fused dot attention - BMMBMM"); - } - break; - case mlir::lmhlo_gpu::FusedMhaDagSignature::Softmax: - if (num_operands == 5) { - fused_attention += "softmax.inference"; - } else if (num_operands == 6) { - fused_attention += "softmax.forward"; - } else { - return op.emitOpError( - "unexpected number of operands for fused dot attention - " - "BMM_Softmax_BMM"); - } - break; - case mlir::lmhlo_gpu::FusedMhaDagSignature::SoftmaxDropout: - if (num_operands == 5) { - fused_attention += "softmax.dropout.inference"; - } else if (num_operands == 6) { - fused_attention += "softmax.dropout.forward"; - } else { - return op.emitOpError( - "unexpected number of operands for fused dot attention - " - "BMM_Softmax_Dropout_BMM"); - } - break; - - case mlir::lmhlo_gpu::FusedMhaDagSignature::ScaleBiasMaskSoftmax: - if (num_operands == 7) { - fused_attention += "scale.bias.mask.softmax.inference"; - } else if (num_operands == 8) { - fused_attention += "scale.bias.mask.softmax.forward"; - } else { - return op.emitOpError( - "unexpected number of operands for fused dot attention - " - "BMM_Bias_Mask_Softmax_BMM"); - } - break; - - case mlir::lmhlo_gpu::FusedMhaDagSignature::ScaleBiasMaskSoftmaxDropout: - if (num_operands == 7) { - fused_attention += "scale.bias.mask.softmax.dropout.inference"; - } else if (num_operands == 8) { - fused_attention += "scale.bias.mask.softmax.dropout.forward"; - } else { - return op.emitOpError( - "unexpected number of operands for fused dot attention - " - "BMM_Bias_Mask_Softmax_Dropout_BMM"); - } - break; - - case mlir::lmhlo_gpu::FusedMhaDagSignature::ScaleMaskSoftmax: - if (num_operands == 6) { - fused_attention += "scale.mask.softmax.inference"; - } else if (num_operands == 7) { - fused_attention += "scale.mask.softmax.forward"; - } else { - return op.emitOpError( - "unexpected number of operands for fused dot attention - " - "BMM_mask_Softmax_BMM"); - } - break; - - case mlir::lmhlo_gpu::FusedMhaDagSignature::ScaleMaskSoftmaxDropout: - if (num_operands == 6) { - fused_attention += "scale.mask.softmax.dropout.inference"; - } else if (num_operands == 7) { - fused_attention += "scale.mask.softmax.dropout.forward"; - } else { - return op.emitOpError( - "unexpected number of operands for fused dot attention - " - "BMM_mask_Softmax_Dropout_BMM"); - } - break; - - case mlir::lmhlo_gpu::FusedMhaDagSignature::ScaleBiasSoftmax: - if (num_operands == 6) { - fused_attention += "scale.bias.softmax.inference"; - } else if (num_operands == 7) { - fused_attention += "scale.bias.softmax.forward"; - } else { - return op.emitOpError( - "unexpected number of operands for fused dot attention - " - "BMM_bias_Softmax_BMM"); - } - break; - - case mlir::lmhlo_gpu::FusedMhaDagSignature::ScaleBiasSoftmaxDropout: - if (num_operands == 6) { - fused_attention += "scale.bias.softmax.dropout.inference"; - } else if (num_operands == 7) { - fused_attention += "scale.bias.softmax.dropout.forward"; - } else { - return op.emitOpError( - "unexpected number of operands for fused dot attention - " - "BMM_bias_Softmax_Dropout_BMM"); - } - break; - - default: - return op.emitOpError("Undefined fused dot attention DAG signature"); - } - - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = custom_calls_.GetOrCreate(b, fused_attention, op); - - // Convert fused_attention to a function call. - auto call = rewriter.create(op.getLoc(), callee.getName(), - TypeRange(), op.getOperands()); - - // Assign a unique id to this instance of a fused_attention operation. - call->setAttr(b.getStringAttr("uid"), b.getI64IntegerAttr(uid_.uid())); - - // Helper functins to copy attributes from the conv op to the custom call. - auto set_attr = [&](StringRef name, Attribute attr) { - if (attr) { - call->setAttr(b.getStringAttr(name), attr); - } - }; - - set_attr("fmha_scale", op.getFmhaScaleAttr()); - set_attr("dropout_rate", op.getDropoutRateAttr()); - set_attr("seed", op.getSeedAttr()); - set_attr("is_flash_attention", op.getIsFlashAttentionAttr()); - set_attr("is_causal_mask", op.getIsCausalMaskAttr()); - set_attr("fused_mha_dag", op.getFusedMhaDagAttr()); - set_attr("algorithm_config", op.getAlgorithmConfigAttr()); - set_attr("bmm1_dot_dimension_numbers", op.getBmm1DotDimensionNumbers()); - set_attr("bmm2_dot_dimension_numbers", op.getBmm2DotDimensionNumbers()); - - auto set_xi64 = [&](StringRef name, mlir::ArrayAttr array) { - int rank = array.size(); - SmallVector values; - for (int i = 0; i < rank; i++) { - mlir::IntegerAttr attr = array[i].dyn_cast(); - values.push_back(attr.getInt()); - } - set_attr(name, b.getI64TensorAttr(values)); - }; - - set_xi64("intermediate_tensor_dimensions", - op.getIntermediateTensorDimensions()); - set_xi64("intermediate_tensor_layout", op.getIntermediateTensorLayout()); - - // Erase the original fused dot attention operation. - rewriter.eraseOp(op); - - return success(); - } - - private: - UidGenerator& uid_; - CustomCallDeclarations& custom_calls_; -}; - -class FusedAttentionForwardOpLowering - : public FusedAttentionForwardLowering { - public: - using FusedAttentionForwardLowering::FusedAttentionForwardLowering; -}; - -using mlir::lmhlo_gpu::fusedMHABackwardOp; - -template -class FusedAttentionBackwardLowering - : public OpRewritePattern { - private: - static constexpr const char kFusedAttentionCustomCallTarget[] = - "xla.gpu.fused.attention.backward."; - static constexpr const char kFlashAttentionCustomCallTarget[] = - "xla.gpu.flash.attention.backward."; - - public: - explicit FusedAttentionBackwardLowering(MLIRContext* ctx, UidGenerator& uid, - CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), - uid_(uid), - custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(FusedDotAttentionBackward op, - PatternRewriter& rewriter) const override { - // Get the custom call target. - bool is_flash_attention = op.getIsFlashAttention(); - std::string fused_attention = is_flash_attention - ? kFlashAttentionCustomCallTarget - : kFusedAttentionCustomCallTarget; - auto num_operands = op.getNumOperands(); - switch (op.getFusedMhaDag()) { - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardSoftmax: - if (is_flash_attention) { - if (num_operands == 12) { - fused_attention += "scale.softmax"; - } else { - return op.emitOpError( - "unexpected number of operands for flash attention backward - " - "BMM_Softmax_BMM"); - } - } - break; - - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardSoftmaxDropout: - if (is_flash_attention) { - if (num_operands == 12) { - fused_attention += "scale.softmax.dropout"; - } else { - return op.emitOpError( - "unexpected number of operands for flash attention backward - " - "BMM_Softmax_Dropout_BMM"); - } - } - break; - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleBiasSoftmax: - if (is_flash_attention) { - if (num_operands == 13) { - fused_attention += "scale.bias.softmax"; - } else { - return op.emitOpError( - "unexpected number of operands for flash attention backward - " - "BMM_Bias_Softmax_BMM"); - } - break; - } - if (num_operands == 10) { - fused_attention += "scale.softmax"; - } else if (num_operands == 11) { - fused_attention += "scale.dbias.softmax"; - } else { - return op.emitOpError( - "unexpected number of operands for fused attention backward - " - "BMM_Bias_Softmax_BMM"); - } - break; - - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleBiasSoftmaxDropout: - if (is_flash_attention) { - if (num_operands == 13) { - fused_attention += "scale.bias.softmax.dropout"; - } else { - return op.emitOpError( - "unexpected number of operands for flash attention backward - " - "BMM_Bias_Softmax_Dropout_BMM"); - } - break; - } - if (num_operands == 10) { - fused_attention += "scale.softmax.dropout"; - } else if (num_operands == 11) { - fused_attention += "scale.dbias.softmax.dropout"; - } else { - return op.emitOpError( - "unexpected number of operands for fused attention backward - " - "BMM_Bias_Softmax_Dropout_BMM"); - } - break; - - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleBiasMaskSoftmax: - if (is_flash_attention) { - if (num_operands == 14) { - fused_attention += "scale.bias.mask.softmax"; - } else { - return op.emitOpError( - "unexpected number of operands for flash attention backward - " - "BMM_Bias_Mask_Softmax_BMM"); - } - break; - } - if (num_operands == 11) { - fused_attention += "scale.mask.softmax"; - } else if (num_operands == 12) { - fused_attention += "scale.dbias.mask.softmax"; - } else { - return op.emitOpError( - "unexpected number of operands for fused attention backward - " - "BMM_Bias_Mask_Softmax_BMM"); - } - break; - - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleBiasMaskSoftmaxDropout: - if (is_flash_attention) { - if (num_operands == 14) { - fused_attention += "scale.bias.mask.softmax.dropout"; - } else { - return op.emitOpError( - "unexpected number of operands for flash attention backward - " - "BMM_Bias_Mask_Softmax_Dropout_BMM"); - } - break; - } - if (num_operands == 11) { - fused_attention += "scale.mask.softmax.dropout"; - } else if (num_operands == 12) { - fused_attention += "scale.dbias.mask.softmax.dropout"; - } else { - return op.emitOpError( - "unexpected number of operands for fused attention backward - " - "BMM_Bias_Mask_Softmax_Dropout_BMM"); - } - break; - - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleMaskSoftmax: - if (is_flash_attention) { - if (num_operands == 13) { - fused_attention += "scale.mask.softmax"; - } else { - return op.emitOpError( - "unexpected number of operands for flash attention backward - " - "BMM_Mask_Softmax_BMM"); - } - break; - } - break; - - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleMaskSoftmaxDropout: - if (is_flash_attention) { - if (num_operands == 13) { - fused_attention += "scale.mask.softmax.dropout"; - } else { - return op.emitOpError( - "unexpected number of operands for flash attention backward - " - "BMM_Mask_Softmax_Dropout_BMM"); - } - break; - } - break; - default: - return op.emitOpError("Undefined fused attention DAG signature"); - } - - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = custom_calls_.GetOrCreate(b, fused_attention, op); - - // Convert fused_attention to a function call. - auto call = rewriter.create(op.getLoc(), callee.getName(), - TypeRange(), op.getOperands()); - - // Assign a unique id to this instance of a fused_attention operation. - call->setAttr(b.getStringAttr("uid"), b.getI64IntegerAttr(uid_.uid())); - - // Helper functins to copy attributes from the conv op to the custom call. - auto set_attr = [&](StringRef name, Attribute attr) { - if (attr) { - call->setAttr(b.getStringAttr(name), attr); - } - }; - - set_attr("fmha_scale", op.getFmhaScaleAttr()); - set_attr("dropout_rate", op.getDropoutRateAttr()); - set_attr("seed", op.getSeedAttr()); - set_attr("is_flash_attention", op.getIsFlashAttentionAttr()); - set_attr("is_causal_mask", op.getIsCausalMaskAttr()); - set_attr("fused_mha_dag", op.getFusedMhaDagAttr()); - set_attr("algorithm_config", op.getAlgorithmConfigAttr()); - set_attr("bmm1_grad_gemm1_dot_dimension_numbers", - op.getBmm1GradGemm1DotDimensionNumbers()); - set_attr("bmm1_grad_gemm2_dot_dimension_numbers", - op.getBmm1GradGemm2DotDimensionNumbers()); - set_attr("bmm2_grad_gemm1_dot_dimension_numbers", - op.getBmm2GradGemm1DotDimensionNumbers()); - set_attr("bmm2_grad_gemm2_dot_dimension_numbers", - op.getBmm2GradGemm2DotDimensionNumbers()); - - auto set_xi64 = [&](StringRef name, mlir::ArrayAttr array) { - int rank = array.size(); - SmallVector values; - for (int i = 0; i < rank; i++) { - mlir::IntegerAttr attr = array[i].dyn_cast(); - values.push_back(attr.getInt()); - } - set_attr(name, b.getI64TensorAttr(values)); - }; - - set_xi64("intermediate_tensor_dimensions", - op.getIntermediateTensorDimensions()); - set_xi64("intermediate_tensor_layout", op.getIntermediateTensorLayout()); - - // Erase the original fused dot attention operation. - rewriter.eraseOp(op); - - return success(); - } - - private: - UidGenerator& uid_; - CustomCallDeclarations& custom_calls_; -}; - -class FusedAttentionBackwardOpLowering - : public FusedAttentionBackwardLowering { - public: - using FusedAttentionBackwardLowering::FusedAttentionBackwardLowering; -}; - -class RadixSortOpLowering : public OpRewritePattern { - private: - static constexpr const char kSortKeysTarget[] = "xla.gpu.radix_sort_keys"; - static constexpr const char kSortPairsTarget[] = "xla.gpu.radix_sort_pairs"; - - public: - explicit RadixSortOpLowering(MLIRContext* ctx, - CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(RadixSortOp op, - PatternRewriter& rewriter) const override { - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = custom_calls_.GetOrCreate( - b, op.getOperands().size() == 3 ? kSortKeysTarget : kSortPairsTarget, - op); - - // Convert radix sort to a function call. - auto call = rewriter.create(op.getLoc(), callee.getName(), - TypeRange(), op.getOperands()); - call->setAttr(b.getStringAttr("descending"), op.getDescendingAttr()); - - // Erase the original operation. - rewriter.eraseOp(op); - - return success(); - } - - private: - CustomCallDeclarations& custom_calls_; -}; - -//===----------------------------------------------------------------------===// - -void ConvertLmhloGpuToGpuRuntimePass::runOnOperation() { - ModuleOp module = getOperation(); - MLIRContext* ctx = module.getContext(); - - // Keep track of the custom calls created from the lowered operations. - SymbolTable sym_table(module); - CustomCallDeclarations custom_calls(std::move(sym_table)); - - // Convert lmhlo_gpu operations to XLA gpu runtime custom calls. - RewritePatternSet patterns(ctx); - - // Each unique Gemm/Matmul operation in the module will get assigned a uid. - UidGenerator matmul_uid; - patterns.insert(ctx, matmul_uid, custom_calls); - - // Each unique Conv operation in the module will get assigned a uid. - UidGenerator conv_uid; - patterns - .insert( - ctx, conv_uid, custom_calls); - - // Patterns for every other Gpu operation. - patterns.insert(ctx, custom_calls); - patterns.insert(ctx, custom_calls); - patterns.insert(ctx, custom_calls); - patterns.insert(ctx, custom_calls); - - // Each unique Norm operation in the module will get assigned a uid. - UidGenerator norm_uid; - patterns.insert(ctx, norm_uid, custom_calls); - - // Each unique fused_attention operation in the module will get assigned a - // uid. - UidGenerator fused_attention_uid; - patterns.insert(ctx, fused_attention_uid, - custom_calls); - - // Each unique fused_attention_backward operation in the module will get - // assigned a uid. - UidGenerator fused_attention_backward_uid; - patterns.insert( - ctx, fused_attention_backward_uid, custom_calls); - - if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) - return signalPassFailure(); -} - -std::unique_ptr> -createConvertLmhloGpuToGpuRuntimePass() { - return std::make_unique(); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/lmhlo_to_gpu_launch.cc b/third_party/xla/xla/mlir/backends/gpu/transforms/lmhlo_to_gpu_launch.cc deleted file mode 100644 index 39964a3467cdcc..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/lmhlo_to_gpu_launch.cc +++ /dev/null @@ -1,488 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/status/statusor.h" -#include "llvm/ADT/APFloat.h" -#include "llvm/ADT/STLExtras.h" -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "xla/mlir/runtime/ir/rt_ops.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/runtime3/conditional_thunk.h" -#include "xla/service/gpu/runtime3/copy_thunk.h" -#include "xla/service/gpu/runtime3/kernel_thunk.h" -#include "xla/service/gpu/runtime3/memset_thunk.h" -#include "xla/service/gpu/runtime3/sequential_thunk.h" -#include "xla/service/gpu/runtime3/while_thunk.h" - -namespace xla { -namespace gpu { - -#define GEN_PASS_DEF_CONVERTLMHLOTOGPULAUNCHPASS -#include "xla/mlir/backends/gpu/transforms/passes.h.inc" - -using namespace mlir; // NOLINT - -using mlir::gpu::GPUDialect; -using mlir::gpu::GPUFuncOp; -using mlir::gpu::GPUModuleOp; -using mlir::gpu::KernelDim3; -using mlir::gpu::LaunchFuncOp; -using mlir::gpu::MemcpyOp; -using mlir::gpu::MemsetOp; -using mlir::gpu::ReturnOp; - -class ConvertLmhloToGpuLaunchPass - : public impl::ConvertLmhloToGpuLaunchPassBase< - ConvertLmhloToGpuLaunchPass> { - public: - explicit ConvertLmhloToGpuLaunchPass(ThunkSequence* thunk_sequence) - : thunk_sequence_(thunk_sequence) {} - - void runOnOperation() override; - - void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); - } - - private: - ThunkSequence* thunk_sequence_; -}; - -// XLA some times (ab)uses custom calls to represent operations for which we do -// not want to define a separate `HloOpcode`. These operations emitted as device -// kernels (similar to fusions), and we detect such custom calls by name, and -// handle them similar to how we handle fusions. -static std::array kCustomCallIntrinsics = { - "SliceToDynamic", "PadToStatic"}; - -//===-----------------------------------------------------------------------===/ - -static Value MakeBitPatternConstant(OpBuilder& b, Location loc, Type type, - uint32_t bit_pattern) { - mlir::MLIRContext* ctx = type.getContext(); - - // For zero bit pattern always memset with a zero value of the same type. - if (bit_pattern == 0) { - // Because `arith` dialect doesn't support unsigned constants, we have to - // create signless constant first, and then use `rt.unsigned_cast` operation - // to make it unsigned. When lowering to LLVM and function calls, this - // casting operation will be erased. - if (type.isUnsignedInteger()) { - auto signless = IntegerType::get(ctx, type.getIntOrFloatBitWidth()); - auto zero = b.create(loc, b.getZeroAttr(signless)); - return b.create(loc, type, zero.getResult()); - } - - return b.create(loc, b.getZeroAttr(type)); - } - - // In XLA a 1-byte bit pattern copied to fill a 32-byte word when - // `Memset32BitValueThunk` is constructed, so to get back an `i1` constant we - // only need to check if any bit is set to `1`. - if (type.isInteger(1)) { - return b.create(loc, b.getBoolAttr(bit_pattern)); - } - - // Xla IR emitter copies integers of smaller width to fill 32 bits, so we can - // safely truncate the bit pattern. For integers larger than 32 bits we can - // construct a wider integer, as Xla guarantees that all 32-bit words are - // equal. - if (auto integer = type.dyn_cast()) { - llvm::APInt i32(32, bit_pattern); - - assert(integer.getWidth() <= 64 && "integer value must be <= 64 bits"); - llvm::APInt value = integer.getWidth() <= 32 ? i32.trunc(integer.getWidth()) - : i32.concat(i32); - - // See unsigned-to-signed cast documentation above. - if (integer.isUnsigned()) { - auto signless = IntegerType::get(ctx, integer.getWidth()); - auto cst = - b.create(loc, b.getIntegerAttr(signless, value)); - return b.create(loc, type, cst.getResult()); - } - - return b.create(loc, b.getIntegerAttr(integer, value)); - } - - // Similar to integer type we can safely truncate or concat bit pattern. - if (auto fp = type.dyn_cast()) { - llvm::APInt i32(32, bit_pattern); - - assert(fp.getWidth() <= 64 && "floating point value must be <= 64 bits"); - llvm::APInt ivalue = - fp.getWidth() <= 32 ? i32.trunc(fp.getWidth()) : i32.concat(i32); - - llvm::APFloat fvalue = [&]() -> llvm::APFloat { - if (fp.isBF16()) return {llvm::APFloat::BFloat(), ivalue}; - if (fp.isF16()) return {llvm::APFloat::IEEEhalf(), ivalue}; - if (fp.isF32()) return {llvm::APFloat::IEEEsingle(), ivalue}; - if (fp.isF64()) return {llvm::APFloat::IEEEdouble(), ivalue}; - - assert(false && "unsupported floating point type"); - return llvm::APFloat::getZero(llvm::APFloat::IEEEsingle()); - }(); - - return b.create(loc, fvalue, fp); - } - - // Return a constant index value, that will safely fail verification (there is - // no memset operation for `index` type), so that we do not accidentally crash - // the binary in optimized builds. - assert(false && "unsupported memset type"); - return b.create(loc, 0); -} - -static void ExtractThunksForOp(Operation* op, ThunkSequence& thunk_sequence, - ThunkSequence* thunks_for_op) { - for (std::unique_ptr& thunk : thunk_sequence) { - if (thunk == nullptr) { - // This thunk has already been std::move()'ed out of the ThunkSequence - // (see below). Do nothing. - } else if (thunk->kind() == Thunk::kWhile) { - // Search for thunks for the op in while loop. - auto* while_thunk = static_cast(thunk.get()); - ExtractThunksForOp(op, while_thunk->condition_thunk_sequence()->thunks(), - thunks_for_op); - ExtractThunksForOp(op, while_thunk->body_thunk_sequence()->thunks(), - thunks_for_op); - } else if (thunk->kind() == Thunk::kConditional) { - // Search for thunks for the op in conditional branches. - auto* cond_thunk = static_cast(thunk.get()); - for (const std::unique_ptr& branch_thunks : - cond_thunk->branch_thunks()) { - ExtractThunksForOp(op, branch_thunks->thunks(), thunks_for_op); - } - } else if (thunk->op() == op) { - // Found a thunk for the op. - thunks_for_op->push_back(std::move(thunk)); - } else { - // Thunk is not relevant to the op. Do nothing. - } - } -} - -// Returns the data to rewrite op without changing the IR. -static absl::StatusOr> Match( - Operation* op, ThunkSequence& thunk_sequence) { - auto thunks_for_op = std::make_unique(); - ExtractThunksForOp(op, thunk_sequence, thunks_for_op.get()); - - // Check if we know how to lower a Thunk to Gpu operation(s). - auto is_supported = [](const std::unique_ptr& thunk) -> bool { - Thunk::Kind kinds[] = {Thunk::kKernel, Thunk::kCustomKernel, - Thunk::kCopy, Thunk::kMemset32BitValue, - Thunk::kMemzero, Thunk::kSequential}; - return llvm::any_of( - kinds, [&](Thunk::Kind kind) { return thunk->kind() == kind; }); - }; - - if (!llvm::all_of(*thunks_for_op, is_supported)) { - return absl::InternalError("Unsupported Thunk kind"); - } - - return std::move(thunks_for_op); -} - -static void LowerThunkToGpuOp(Operation* op, OpBuilder& b, - GPUModuleOp gpu_module, Thunk* thunk); - -// Replaces op with gpu.launch_func, gpu.memcpy, gpu.memset ops. -static void Rewrite(Operation* op, OpBuilder& b, SymbolTable& symbol_table, - ThunkSequence* thunks) { - OpBuilder::InsertionGuard guard(b); - auto loc = op->getLoc(); - - b.setInsertionPoint(op->getParentOfType()); - auto gpu_module = b.create(loc, "gpu_module"); - symbol_table.insert(gpu_module); - - for (const std::unique_ptr& thunk : *thunks) { - LowerThunkToGpuOp(op, b, gpu_module, thunk.get()); - } - - op->erase(); -} - -static void LowerKernelThunkToGpuOp( - Operation* op, OpBuilder& b, GPUModuleOp gpu_module, - const KernelThunk& thunk, const SmallVector& kernel_args, - const SmallVector& kernel_args_written) { - mlir::Location loc = op->getLoc(); - b.setInsertionPointToStart(gpu_module.getBody()); - - auto func_type = - b.getType(TypeRange(ValueRange(kernel_args)), TypeRange()); - - gpu::GPUFuncOp kernel_func = - b.create(loc, thunk.kernel_name(), func_type); - kernel_func->setAttr(GPUDialect::getKernelFuncAttrName(), b.getUnitAttr()); - - for (int i = 0; i < kernel_args.size(); ++i) { - if (kernel_args_written[i]) { - kernel_func.setArgAttr(i, "lmhlo.written", b.getUnitAttr()); - } - } - - b.setInsertionPointToEnd(&kernel_func.getBody().back()); - b.create(loc); - - auto make_const_idx = [&](int64_t value) { - auto attr = b.getIndexAttr(value); - return b.create(loc, attr).getResult(); - }; - - auto make_kernel_dim3 = [&](const auto& dim3) { - return KernelDim3{make_const_idx(dim3.x), make_const_idx(dim3.y), - make_const_idx(dim3.z)}; - }; - - b.setInsertionPoint(op); - const auto& launch_dims = thunk.launch_dimensions(); - auto grid_size = make_kernel_dim3(launch_dims.block_counts()); - auto block_size = make_kernel_dim3(launch_dims.thread_counts_per_block()); - auto shmem_size = b.create( - loc, b.getI32IntegerAttr(thunk.shmem_bytes())); - - b.create(loc, kernel_func, grid_size, block_size, shmem_size, - kernel_args); -} - -static void LowerCustomKernelThunkToGpuOp( - Operation* op, OpBuilder& b, GPUModuleOp gpu_module, - const CustomKernelThunk& thunk, const SmallVector& kernel_args, - const SmallVector& kernel_args_written) { - mlir::Location loc = op->getLoc(); - b.setInsertionPointToStart(gpu_module.getBody()); - - auto func_type = - b.getType(TypeRange(ValueRange(kernel_args)), TypeRange()); - - gpu::GPUFuncOp kernel_func = - b.create(loc, thunk.custom_kernel_name(), func_type); - kernel_func->setAttr(GPUDialect::getKernelFuncAttrName(), b.getUnitAttr()); - - for (int i = 0; i < kernel_args.size(); ++i) { - if (kernel_args_written[i]) { - kernel_func.setArgAttr(i, "lmhlo.written", b.getUnitAttr()); - } - } - - b.setInsertionPointToEnd(&kernel_func.getBody().back()); - b.create(loc); - - auto make_const_idx = [&](int64_t value) { - auto attr = b.getIndexAttr(value); - return b.create(loc, attr).getResult(); - }; - - auto make_kernel_dim3 = [&](const auto& dim3) { - return KernelDim3{make_const_idx(dim3.x), make_const_idx(dim3.y), - make_const_idx(dim3.z)}; - }; - - b.setInsertionPoint(op); - auto launch_dims = thunk.launch_dimensions(); - auto grid_size = make_kernel_dim3(launch_dims.block_counts()); - auto block_size = make_kernel_dim3(launch_dims.thread_counts_per_block()); - auto shmem_size = b.create( - loc, b.getI32IntegerAttr(thunk.shmem_bytes())); - - auto launch_func = b.create( - loc, kernel_func, grid_size, block_size, shmem_size, kernel_args); - - if (auto computation = op->getAttr("__custom_fusion_computation")) { - launch_func->setAttr("__custom_fusion_computation", computation); - } else { - launch_func->setAttr("__custom_fusion_computation", - b.getStringAttr("")); - } -} - -static void LowerThunkToGpuOp(Operation* op, OpBuilder& b, - GPUModuleOp gpu_module, Thunk* thunk) { - auto loc = op->getLoc(); - - if (thunk->kind() == Thunk::kSequential) { - const auto* seq_thunk = static_cast(thunk); - for (const std::unique_ptr& thunk : seq_thunk->thunks()) { - LowerThunkToGpuOp(op, b, gpu_module, thunk.get()); - } - return; - } - - if (thunk->kind() == Thunk::kCopy) { - const auto* copy_thunk = static_cast(thunk); - b.setInsertionPoint(op); - b.create(loc, TypeRange(), ValueRange(), - copy_thunk->destination_value(), - copy_thunk->source_value()); - return; - } - - auto rewrite_memset = [&](const xla::BufferAllocation::Slice& slice, - uint32_t memset_value, Value buffer_arg) { - auto element_type = - buffer_arg.getType().cast().getElementType(); - b.setInsertionPoint(op); - Value value = MakeBitPatternConstant(b, loc, element_type, memset_value); - b.create(loc, TypeRange(), ValueRange(), buffer_arg, value); - }; - - if (thunk->kind() == Thunk::kMemset32BitValue) { - const auto* memset_thunk = static_cast(thunk); - rewrite_memset(memset_thunk->destination(), memset_thunk->value(), - memset_thunk->dest_value()); - return; - } - if (thunk->kind() == Thunk::kMemzero) { - const auto* memzero_thunk = static_cast(thunk); - rewrite_memset(memzero_thunk->destination(), 0, - memzero_thunk->dest_value()); - return; - } - - if (thunk->kind() == Thunk::kKernel) { - const auto* kernel_thunk = static_cast(thunk); - - SmallVector kernel_args; - for (auto kernel_arg : kernel_thunk->values()) - kernel_args.push_back(kernel_arg); - - SmallVector kernel_args_written; - for (auto written : kernel_thunk->written()) { - kernel_args_written.push_back(written); - } - - LowerKernelThunkToGpuOp(op, b, gpu_module, *kernel_thunk, kernel_args, - kernel_args_written); - return; - } - - if (thunk->kind() == Thunk::kCustomKernel) { - const auto* kernel_thunk = static_cast(thunk); - - SmallVector kernel_args; - for (auto kernel_arg : kernel_thunk->values()) - kernel_args.push_back(kernel_arg); - - SmallVector kernel_args_written; - for (auto written : kernel_thunk->written()) { - kernel_args_written.push_back(written); - } - - LowerCustomKernelThunkToGpuOp(op, b, gpu_module, *kernel_thunk, kernel_args, - kernel_args_written); - return; - } - - CHECK(false) << "Thunk kind not handled: " << thunk->kind(); -} - -// An overload set for defining predicates for operations that should -// conditionally go through the XLA GPU code emitters. -template -static bool HasGpuEmitter(OpTy) { - return true; -} - -// Select custom calls that have corresponding GPU emitters. -static bool HasGpuEmitter(lmhlo::CustomCallOp custom_call) { - return llvm::any_of(kCustomCallIntrinsics, [&](std::string_view name) { - return custom_call.getCallTargetName().equals(name); - }); -} - -//===-----------------------------------------------------------------------===/ - -void ConvertLmhloToGpuLaunchPass::runOnOperation() { - ModuleOp module = getOperation(); - - // No thunks to lower from. Skip pass. - if (thunk_sequence_ == nullptr) return signalPassFailure(); - - // Collect thunks for rewriting each compatible operation in the module into - // the sequence of device kernel launches. Some operation might have an empty - // thunk sequence (e.g. redundant copy operation that does not require running - // anything on device). - absl::flat_hash_map> rewrites; - - // Get data to rewrite kernel ops without changing the IR. - auto walk = [&](auto op_type_tag) { - return module.walk([&](decltype(op_type_tag) op) -> WalkResult { - if (!HasGpuEmitter(op)) return success(); - - auto data = Match(op, *thunk_sequence_); - if (!data.ok()) return op.emitOpError(data.status().message()); - - rewrites[op] = std::move(*data); - return success(); - }); - }; - - // Collect all operations that have GPU code emitters. - if (walk(lmhlo::FusionOp()).wasInterrupted() || - walk(lmhlo::RngGetAndUpdateStateOp()).wasInterrupted() || - walk(lmhlo::ScatterOp()).wasInterrupted() || - walk(lmhlo::SelectAndScatterOp()).wasInterrupted() || - walk(lmhlo::SortOp()).wasInterrupted() || - walk(lmhlo::CustomCallOp()).wasInterrupted() || - walk(LaunchFuncOp()).wasInterrupted()) - return signalPassFailure(); - - // No operations that should be lowered to sequence of device launches. - if (rewrites.empty()) return; - - OpBuilder b(module); - SymbolTable symbol_table(module); - - // Replace matched operations with gpu.launch_func's. - for (const auto& [op, thunks] : rewrites) { - Rewrite(op, b, symbol_table, thunks.get()); - } - - // Mark module as gpu.container_module. - module->setAttr(GPUDialect::getContainerModuleAttrName(), b.getUnitAttr()); -} - -std::unique_ptr> -createConvertLmhloToGpuLaunchPass(ThunkSequence* thunk_sequence) { - return std::make_unique(thunk_sequence); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/lmhlo_to_gpu_runtime.cc b/third_party/xla/xla/mlir/backends/gpu/transforms/lmhlo_to_gpu_runtime.cc deleted file mode 100644 index 68b2946dac4046..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/lmhlo_to_gpu_runtime.cc +++ /dev/null @@ -1,1239 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include -#include -#include - -#include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" // from @llvm-project -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/IRMapping.h" // from @llvm-project -#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "xla/mlir/backends/gpu/transforms/uid_generator.h" -#include "xla/mlir/runtime/utils/custom_calls.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" -#include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" -#include "xla/service/gpu/nccl_all_gather_thunk.h" -#include "xla/service/gpu/nccl_all_reduce_thunk.h" -#include "xla/service/gpu/nccl_all_to_all_thunk.h" -#include "xla/service/gpu/nccl_collective_permute_thunk.h" -#include "xla/service/gpu/nccl_collective_thunk.h" -#include "xla/service/gpu/nccl_recv_thunk.h" -#include "xla/service/gpu/nccl_send_thunk.h" - -namespace xla { -namespace gpu { - -#define GEN_PASS_DEF_CONVERTLMHLOTOGPURUNTIMEPASS -#include "xla/mlir/backends/gpu/transforms/passes.h.inc" - -using namespace mlir; // NOLINT - -using mlir::gpu::MemcpyOp; - -using mlir::lmhlo::CaseOp; -using mlir::lmhlo::CustomCallOp; -using mlir::lmhlo::FftOp; -using mlir::lmhlo::InfeedOp; -using mlir::lmhlo::OutfeedOp; -using mlir::lmhlo::TerminatorOp; -using mlir::lmhlo::WhileOp; - -using xla::runtime::AppendCustomCallAttrs; -using xla::runtime::CustomCallDeclarations; - -// helper template to check T is any of the types listed in Ts. -template -inline constexpr bool is_any = std::disjunction_v...>; - -class ConvertLmhloToGpuRuntimePass - : public impl::ConvertLmhloToGpuRuntimePassBase< - ConvertLmhloToGpuRuntimePass> { - void runOnOperation() override; - - void getDependentDialects(DialectRegistry& registry) const override { - registry - .insert(); - } -}; - -//===----------------------------------------------------------------------===// - -class TerminatorOpLowering : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TerminatorOp op, - PatternRewriter& rewriter) const override { - rewriter.replaceOpWithNewOp(op); - return mlir::success(); - } -}; - -//===----------------------------------------------------------------------===// - -template -class IoFeedOpLowering : public OpRewritePattern { - static StringRef Target(InfeedOp) { return "xla.gpu.infeed"; } - static StringRef Target(OutfeedOp) { return "xla.gpu.outfeed"; } - - public: - IoFeedOpLowering(MLIRContext* ctx, CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(IoFeedOp op, - PatternRewriter& rewriter) const override { - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = custom_calls_.GetOrCreate(b, Target(op), op); - - llvm::SmallVector custom_call_attrs = { - {b.getStringAttr("config"), op.getConfigAttr()}}; - - // Call the runtime intrinsic with the original operands. - auto call = rewriter.replaceOpWithNewOp( - op, callee.getName(), TypeRange(), op.getOperands()); - AppendCustomCallAttrs(call, custom_call_attrs); - - return success(); - } - - private: - CustomCallDeclarations& custom_calls_; -}; - -class InfeedOpLowering : public IoFeedOpLowering { - public: - using IoFeedOpLowering::IoFeedOpLowering; -}; - -class OutfeedOpLowering : public IoFeedOpLowering { - public: - using IoFeedOpLowering::IoFeedOpLowering; -}; - -//===----------------------------------------------------------------------===// - -class CustomCallOpLowering : public OpRewritePattern { - private: - static constexpr const char kCustomCallTarget[] = "xla.gpu.custom_call"; - - public: - CustomCallOpLowering(MLIRContext* ctx, CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), custom_calls_(custom_calls) {} - - // Rewrite custom call with `API_VERSION_TYPED_FFI` version into XLA runtime - // custom calls bypassing custom call adaptor. - LogicalResult rewriteTypedCustomCall(CustomCallOp op, - PatternRewriter& rewriter) const { - // TODO(ezhulenev): Support target arg mapping, or explain why we do not - // need them for typed custom calls. - if (op.getTargetArgMapping()) - return op.emitOpError( - "API_VERSION_TYPED_FFI custom calls do not " - "support target arg mapping"); - - // Create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = - custom_calls_.GetOrCreate(b, op.getCallTargetName(), op); - // Custom calls starting with the __gpu$ prefix are considered internal and - // statically linked (e.g. __gpu$TopK). - if (!op.getCallTargetName().starts_with("__gpu$")) { - callee->setAttr("rt.dynamic", UnitAttr::get(b.getContext())); - } - - // Forward backend config to the custom call implementation. - auto dict = op.getBackendConfig() - ? op.getBackendConfig()->cast() - : nullptr; - llvm::SmallVector backend_config(dict.begin(), dict.end()); - - // Call the custom call function forwarding user-defined attributes. - auto call = rewriter.replaceOpWithNewOp( - op, callee.getName(), TypeRange(), op.getOperands()); - AppendCustomCallAttrs(call, backend_config); - - return success(); - } - - LogicalResult matchAndRewrite(CustomCallOp op, - PatternRewriter& rewriter) const override { - // Typed custom calls lowered directly to XLA runtime custom calls. - if (op.getApiVersion() == mhlo::CustomCallApiVersion::API_VERSION_TYPED_FFI) - return rewriteTypedCustomCall(op, rewriter); - - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - - // By default all operands passed to the custom call handler. - llvm::SmallVector operands = op.getOperands(); - - // If custom call has target arguments mapping, then we need to pass `i64` - // scalars in place of holes to detect them in custom call handler. - // - // TODO(ezhulenev): We need an `xla` dialect to model Xla framework - // semantics including holes for custom call. As a work around we pass `i64` - // values because xla custom call do not support scalar arguments, and we - // can disambiguate holes from buffers. - if (op.getTargetArgMapping().has_value()) { - auto mapping = *op.getTargetArgMapping(); - int64_t num_args = mapping.getNumArgs(); - int64_t num_results = mapping.getNumResults(); - - // We represent holes as an arbitrary `i64` constant. - Value hole = b.create(b.getI64IntegerAttr(-1)); - operands = llvm::SmallVector(num_args + num_results, hole); - - // Update operands to mapped custom call arguments. - auto args = mapping.getArgsToTargetArgs(); - for (const auto& indexed : llvm::enumerate(args)) - operands[indexed.value()] = op.getArgs()[indexed.index()]; - - // Update operands to mapped custom call results. - auto res = mapping.getResultsToTargetResults(); - for (const auto& indexed : llvm::enumerate(res)) - operands[num_args + indexed.value()] = op.getOutput()[indexed.index()]; - } - - // Create a custom call function declaration. - func::FuncOp callee = custom_calls_.GetOrCreate( - b, kCustomCallTarget, TypeRange(ValueRange(operands)), TypeRange()); - - llvm::SmallVector custom_call_attrs = { - {b.getStringAttr("api_version"), op.getApiVersionAttr()}, - {b.getStringAttr("backend_config"), op.getBackendConfigAttr()}, - {b.getStringAttr("call_target_name"), op.getCallTargetNameAttr()}}; - - // Call the runtime intrinsic with the original operands. - auto call = rewriter.replaceOpWithNewOp( - op, callee.getName(), TypeRange(), operands); - AppendCustomCallAttrs(call, custom_call_attrs); - - return success(); - } - - private: - CustomCallDeclarations& custom_calls_; -}; - -//===----------------------------------------------------------------------===// - -class FftOpLowering : public OpRewritePattern { - private: - static constexpr const char kCustomCallTarget[] = "xla.gpu.fft"; - - public: - FftOpLowering(MLIRContext* ctx, UidGenerator& uid, - CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), uid_(uid), custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(FftOp op, - PatternRewriter& rewriter) const override { - // Create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = custom_calls_.GetOrCreate(b, kCustomCallTarget, op); - - llvm::SmallVector custom_call_attrs = { - {b.getStringAttr("fft_length"), op.getFftLengthAttr()}, - {b.getStringAttr("fft_type"), op.getFftTypeAttr()}, - {b.getStringAttr("uid"), b.getI64IntegerAttr(uid_.uid())}}; - - // Convert Fft to a function call. - auto call = rewriter.replaceOpWithNewOp( - op, callee.getName(), TypeRange(), op.getOperands()); - AppendCustomCallAttrs(call, custom_call_attrs); - return success(); - } - - private: - UidGenerator& uid_; - CustomCallDeclarations& custom_calls_; -}; - -//===----------------------------------------------------------------------===// - -class CaseOpLowering : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(CaseOp op, - PatternRewriter& rewriter) const override { - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - - // Copy index buffer to the host ... - auto index_type = op.getIndex().getType().dyn_cast(); - - // Always create an `alloca` in the parent function entry block. - // See: https://llvm.org/docs/Frontend/PerformanceTips.html#use-of-allocas - Value index_on_host = [&]() -> Value { - OpBuilder::InsertionGuard guard(b); - b.setInsertionPointToStart(&op->getParentOfType().front()); - return b.create(index_type); - }(); - - b.create(TypeRange(), ValueRange({index_on_host, op.getIndex()})); - - // Get the index value from the buffer. - Value index = b.create(index_type.getElementType(), - index_on_host, ValueRange()); - - bool is_predicate = index_type.getElementType().isInteger(1); - - // For binary index (predicate) convert i1 to i32 index. - if (is_predicate) { - Value c0 = b.create(b.getI32IntegerAttr(0)); - Value c1 = b.create(b.getI32IntegerAttr(1)); - index = b.create(index, c0, c1); - } - - // For integer index make sure that it is within range. - if (!is_predicate) { - unsigned n = op.getNumRegions() - 1; - Value c0 = b.create(b.getI32IntegerAttr(0)); - Value cN = b.create(b.getI32IntegerAttr(n)); - - Value too_small = b.create( - b.getI1Type(), arith::CmpIPredicate::slt, index, c0); - Value too_large = b.create( - b.getI1Type(), arith::CmpIPredicate::sgt, index, cN); - - Value out_of_range = b.create(too_small, too_large); - index = b.create(out_of_range, cN, index); - } - - // Wrap the CFG constructed from the `lmhlo.case` operation in an - // `scf.execute_region` operation, so that we do not introduce the CFG - // into regions that expect a single block (e.g. inside the loop body). - auto execute = b.create(TypeRange()); - - // Add an entry block to the execute region operation. - Block& entry = execute.getRegion().emplaceBlock(); - - // Create a block with `scf.yield` terminator. - Block& yield = execute.getRegion().emplaceBlock(); - b.setInsertionPointToStart(&yield); - b.create(); - - // Prepare case destinations for the `scf.switch` operation. - llvm::SmallVector case_values; - llvm::SmallVector case_blocks; - llvm::SmallVector case_operands; - - // Create blocks from each of the case regions. - for (Region& region : op->getRegions()) { - // Move `lmhlo.case` block into the execute region. - Block& block = region.front(); - block.moveBefore(&yield); - - // Erase original `lmhlo.terminator`. - rewriter.eraseOp(block.getTerminator()); - - // Branch into the yield block. - b.setInsertionPointToEnd(&block); - b.create(&yield); - - // Add a `cf.switch` case. - int32_t idx = case_blocks.size(); - case_values.push_back(b.getI32IntegerAttr(idx).getValue()); - case_blocks.push_back(&block); - case_operands.push_back({}); - } - - // Create a `cf.switch` operation in the execute region entry block. - b.setInsertionPointToEnd(&entry); - b.create(index, &yield, ValueRange(), case_values, - case_blocks, case_operands); - - // Erase the original case operation. - rewriter.eraseOp(op); - - return success(); - } -}; - -//===----------------------------------------------------------------------===// - -class WhileOpLowering : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - // Rewrite while loop with known trip count to `scf.for` operation. - LogicalResult rewriteForLoop(WhileOp op, PatternRewriter& rewriter) const { - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - - Value lb = b.create(0); - Value ub = b.create(*op.getTripCount()); - Value c1 = b.create(1); - - // Create an `scf.for` loop in place of `lmhlo.while` loop. - auto loop = b.create(lb, ub, c1, ValueRange()); - - // Move body region into the new loop operation. - IRMapping mapping; - rewriter.eraseOp(op.getBody().front().getTerminator()); - rewriter.inlineBlockBefore(&op.getBody().front(), - loop.getBody()->getTerminator()); - - // Erase the original while loop. - rewriter.eraseOp(op); - - return success(); - } - - // Rewrite while loop with unknown trip count to `scf.while` operation. - LogicalResult rewriteWhileLoop(WhileOp op, PatternRewriter& rewriter) const { - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - - // Create an `scf.while` loop in place of `lmhlo.while` loop. - auto loop = b.create(TypeRange(), ValueRange()); - - // Predicate buffer placed on the device. - Value pred = op.getOperand(0); - - // Inline condition and body regions into the new loop operation. - IRMapping mapping; - rewriter.inlineRegionBefore(op.getCond(), loop.getBefore(), - loop.getBefore().begin()); - rewriter.inlineRegionBefore(op.getBody(), loop.getAfter(), - loop.getAfter().begin()); - - { // Replace loop condition terminator. - auto* terminator = loop.getBefore().back().getTerminator(); - b.setInsertionPointAfter(terminator); - - auto i1 = b.getI1Type(); - - // Always create an `alloca` in the parent function entry block. - // See: https://llvm.org/docs/Frontend/PerformanceTips.html#use-of-allocas - Value pred_on_host = [&]() -> Value { - OpBuilder::InsertionGuard guard(b); - b.setInsertionPointToStart( - &op->getParentOfType().front()); - return b.create(MemRefType::get({}, i1)); - }(); - - // Copy predicate buffer to the host ... - b.create(TypeRange(), ValueRange({pred_on_host, pred})); - - // .. and check if we need to continue loop iteration. - Value cond = b.create(i1, pred_on_host, ValueRange()); - b.create(cond, ValueRange()); - rewriter.eraseOp(terminator); - } - - { // Replace loop body terminator. - auto* terminator = loop.getAfter().back().getTerminator(); - b.setInsertionPointAfter(terminator); - b.create(TypeRange(), ValueRange()); - rewriter.eraseOp(terminator); - } - - // Erase the original while loop. - rewriter.eraseOp(op); - - return success(); - } - - LogicalResult matchAndRewrite(WhileOp op, - PatternRewriter& rewriter) const override { - assert(op.getNumOperands() == 1 && "expected single lmhlo.while operand"); - return op.getTripCount().has_value() ? rewriteForLoop(op, rewriter) - : rewriteWhileLoop(op, rewriter); - } -}; - -//===----------------------------------------------------------------------===// -// Collective operations lowerings. -//===----------------------------------------------------------------------===// - -using mlir::lmhlo::PartitionIdOp; -using mlir::lmhlo::ReplicaIdOp; -using mlir::lmhlo_gpu::AllGatherDoneOp; -using mlir::lmhlo_gpu::AllGatherStartOp; -using mlir::lmhlo_gpu::AllReduceDoneOp; -using mlir::lmhlo_gpu::AllReduceStartOp; -using mlir::lmhlo_gpu::AllToAllDoneOp; -using mlir::lmhlo_gpu::AllToAllStartOp; -using mlir::lmhlo_gpu::CollectivePermuteDoneOp; -using mlir::lmhlo_gpu::CollectivePermuteStartOp; -using mlir::lmhlo_gpu::ReduceScatterDoneOp; -using mlir::lmhlo_gpu::ReduceScatterStartOp; - -using lmhlo::RecvDoneOp; -using lmhlo::RecvOp; -using lmhlo::SendDoneOp; -using lmhlo::SendOp; - -// We assign unique id to all collective operations in the module, so that we -// can efficiently access per-op state at run time. Exception to this rule are -// asynchronous collective operations, that share the same unique id by the pair -// of corresponding `start` and `done` operations. -// -// Asynchronous collective operations pass HLO Token to represent the dependency -// between the `Start` and `Done` operations. When we lower to XLA runtime -// custom calls we rely on assigning each unique pair of `Start` and `Done` -// operations a unique event id, and use shared "context" owned by the -// GpuExecutable to pass Gpu events from `Start` to `Done` custom call handlers. -// -// TODO(ezhulenev): Once XLA runtime custom calls support returning values, we -// should explicitly return event id from the `Start` custom call, and pass it -// to the `Done` custom call. Longer term this should become an `!async.token` -// and rely on XLA runtime asynchronous execution. -class CollectiveUidGenerator { - public: - CollectiveUidGenerator() : cnt_(0) {} - - // Assigns a unique event id to the pair of start and done operations. - int32_t AssignUid(Operation* start, Operation* done) { - int32_t id = next(); - uids_[start] = id; - uids_[done] = id; - return id; - } - - FailureOr AssignedUid(Operation* op) { - // Async operations must be assigned uid ahead of time. - if (isa(op)) { - auto it = uids_.find(op); - if (it == uids_.end()) return failure(); - return it->second; - } - // For every other operation we just assign a next id. - return next(); - } - - private: - int32_t next() { return cnt_++; } - - int32_t cnt_; - llvm::DenseMap uids_; -}; - -// Filters out host send/recv which do not participate in collective op -// lowerings. -struct CollectiveFilter { - template - static std::enable_if_t, bool> ShouldHandle( - OpT) { - return true; - } - - // We only handle send/recv that is not a host transfer. - template - static std::enable_if_t, bool> ShouldHandle( - OpT op) { - return !op.getIsHostTransfer(); - } -}; - -template -NcclCollectiveConfig GetNcclCollectiveConfigForP2POps(OpT op, int replica_count, - int num_partitions) { - return ThunkT::GetNcclP2PConfig(op, replica_count, num_partitions).config; -} - -template -class CollectiveOpLowering : public OpRewritePattern { - // Define target custom call for lowering of collective ops. - static StringRef Target(AllGatherStartOp) { return "xla.gpu.all_gather"; } - static StringRef Target(AllReduceStartOp) { return "xla.gpu.all_reduce"; } - static StringRef Target(AllToAllStartOp) { return "xla.gpu.all_to_all"; } - static StringRef Target(ReduceScatterStartOp) { - return "xla.gpu.reduce_scatter"; - } - static StringRef Target(CollectivePermuteStartOp) { - return "xla.gpu.collective_permute"; - } - static StringRef Target(SendOp) { return "xla.gpu.send"; } - static StringRef Target(RecvOp) { return "xla.gpu.recv"; } - - template - static std::enable_if_t< - is_any, - NcclCollectiveConfig> - GetNcclCollectiveConfig(OpT op, int /*replica_count*/, - int /*num_partitions*/) { - return GetNcclCollectiveConfigForMlir(op, op.getUseGlobalDeviceIds()); - } - - static NcclCollectiveConfig GetNcclCollectiveConfig(AllToAllStartOp op, - int /*replica_count*/, - int /*num_partitions*/) { - // TODO(b/180174349): LMHLO AllToAll incorrectly has use_global_device_ids - // attribute and it should be removed. - return GetNcclCollectiveConfigForMlir(op, std::nullopt); - } - - static NcclCollectiveConfig GetNcclCollectiveConfig( - CollectivePermuteStartOp op, int replica_count, int num_partitions) { - return GetNcclCollectiveConfigForP2POps( - op, replica_count, num_partitions); - } - - static NcclCollectiveConfig GetNcclCollectiveConfig(SendOp op, - int replica_count, - int num_partitions) { - return GetNcclCollectiveConfigForP2POps( - op, replica_count, num_partitions); - } - - static NcclCollectiveConfig GetNcclCollectiveConfig(RecvOp op, - int replica_count, - int num_partitions) { - return GetNcclCollectiveConfigForP2POps( - op, replica_count, num_partitions); - } - - template - static std::enable_if_t, - LogicalResult> - TryDegenerateToMemCopy(NonCollectivePermuteOp op, - const NcclCollectiveConfig& config, int replica_count, - int num_partitions, PatternRewriter& rewriter) { - if (!config.IsDegenerate(replica_count, num_partitions)) { - return failure(); - } - - for (int64_t i = 0; i < op.getInputs().size(); i++) { - rewriter.create( - op.getLoc(), TypeRange(), - ValueRange({op.getOutputs()[i], op.getOperands()[i]})); - } - - return success(); - } - - // Send/Recv is never degenerate by itself, so returns failure(). - template - static std::enable_if_t, LogicalResult> - TryDegenerateToMemCopy(OpT op, const NcclCollectiveConfig& config, - int replica_count, int num_partitions, - PatternRewriter& rewriter) { - return failure(); - } - - static LogicalResult TryDegenerateToMemCopy( - CollectivePermuteStartOp op, const NcclCollectiveConfig& config, - int replica_count, int num_partitions, PatternRewriter& rewriter) { - if (!NcclCollectivePermuteStartThunk::IsDegenerate(op, replica_count, - num_partitions)) { - return failure(); - } - - rewriter.create( - op.getLoc(), TypeRange(), - ValueRange({op.getOutput(), op.getOperand()})); - - return success(); - } - - static Status CheckImplementable(AllGatherStartOp op, int64_t replica_count, - int64_t num_partitions) { - return NcclAllGatherStartThunk::CheckImplementable(op, replica_count, - num_partitions); - } - - static Status CheckImplementable(AllReduceStartOp op, int64_t replica_count, - int64_t num_partitions) { - return NcclAllReduceStartThunk::CheckImplementable(op, replica_count, - num_partitions); - } - - static Status CheckImplementable(AllToAllStartOp op, int64_t replica_count, - int64_t num_partitions) { - return NcclAllToAllStartThunk::CheckImplementable(op, replica_count, - num_partitions); - } - - static Status CheckImplementable(CollectivePermuteStartOp op, - int64_t replica_count, - int64_t num_partitions) { - return NcclCollectivePermuteStartThunk::CheckImplementable( - op, replica_count, num_partitions); - } - - static Status CheckImplementable(SendOp op, int64_t replica_count, - int64_t num_partitions) { - return NcclSendThunk::CheckImplementable(op, replica_count, num_partitions); - } - - static Status CheckImplementable(RecvOp op, int64_t replica_count, - int64_t num_partitions) { - return NcclRecvThunk::CheckImplementable(op, replica_count, num_partitions); - } - - static Status CheckImplementable(ReduceScatterStartOp op, - int64_t replica_count, - int64_t num_partitions) { - return NcclReduceScatterStartThunk::CheckImplementable(op, replica_count, - num_partitions); - } - - template - static typename std::enable_if_t< - is_any, LogicalResult> - SetSpecificAttrs(ImplicitLocOpBuilder& b, OpT op, func::CallOp call) { - std::optional reduction_kind = - NcclAllReduceReduceScatterThunkBase::MatchAllReduceComputation( - op.getComputation()); - if (!reduction_kind.has_value()) - return op.emitOpError() - << "Failed to determine reduction computation for AllReduce"; - - call->setAttr( - b.getStringAttr("reduction_kind"), - b.getI64IntegerAttr(static_cast(reduction_kind.value()))); - - return success(); - } - - static LogicalResult SetSpecificAttrs(ImplicitLocOpBuilder& b, - AllGatherStartOp op, - func::CallOp call) { - return success(); - } - - static LogicalResult SetSpecificAttrs(ImplicitLocOpBuilder& b, - AllToAllStartOp op, func::CallOp call) { - call->setAttr(b.getStringAttr("has_split_dimension"), - b.getBoolAttr(op.getSplitDimension().has_value())); - return success(); - } - - static void SetSourceTargetPeersAttrs( - ImplicitLocOpBuilder& b, - const std::vector>& source_target_pairs, - func::CallOp call) { - std::vector source_peers; - std::vector target_peers; - source_peers.reserve(source_target_pairs.size()); - target_peers.reserve(source_target_pairs.size()); - for (const auto& source_target_pair : source_target_pairs) { - source_peers.push_back(source_target_pair.first); - target_peers.push_back(source_target_pair.second); - } - - auto source_peers_attr = b.getI64TensorAttr(source_peers); - auto target_peers_attr = b.getI64TensorAttr(target_peers); - call->setAttr(b.getStringAttr("source_peers"), source_peers_attr); - call->setAttr(b.getStringAttr("target_peers"), target_peers_attr); - } - - static LogicalResult SetSpecificAttrs(ImplicitLocOpBuilder& b, - CollectivePermuteStartOp op, - func::CallOp call) { - auto source_target_pairs_or = - ConvertNx2Attribute(op.getSourceTargetPairs()); - if (!source_target_pairs_or.ok()) { - return op.emitOpError() << source_target_pairs_or.status().message(); - } - SetSourceTargetPeersAttrs(b, source_target_pairs_or.value(), call); - return success(); - } - - template - static typename std::enable_if_t, LogicalResult> - SetSpecificAttrs(ImplicitLocOpBuilder& b, OpT op, func::CallOp call) { - auto source_target_pairs_or = - GetSourceTargetPairs(op.getFrontendAttributes()); - if (!source_target_pairs_or.ok()) { - return op.emitOpError() << source_target_pairs_or.status().message(); - } - SetSourceTargetPeersAttrs(b, source_target_pairs_or.value(), call); - return success(); - } - - template - static typename std::enable_if_t, bool> getIsSync( - OpT) { - return false; - } - - template - static typename std::enable_if_t, bool> - getIsSync(OpT op) { - return op.getIsSync(); - } - - template - static typename std::enable_if_t, bool> - noParallelCustomCall(OpT) { - return false; - } - - template - static typename std::enable_if_t, bool> - noParallelCustomCall(OpT op) { - return op.getNoParallelCustomCall(); - } - - // For async collective erase all corresponding done operations. - template - void eraseDoneOp(PatternRewriter& rewriter, CollectiveOp op) const { - if (auto start = dyn_cast(op.getOperation())) { - auto users = llvm::to_vector(start.getToken().getUsers()); - llvm::for_each(users, [&](Operation* user) { - if (isa(user)) rewriter.eraseOp(user); - }); - } - } - - public: - CollectiveOpLowering(MLIRContext* ctx, CollectiveUidGenerator& uid, - CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), - uid_(uid), - custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(CollectiveOp op, - PatternRewriter& rewriter) const override { - if (!CollectiveFilter::ShouldHandle(op)) { - return failure(); - } - - // Construct an NCCL collective config from the parent func attributes. - func::FuncOp fn = op->template getParentOfType(); - auto replica_count_attr = fn->getAttrOfType("replica_count"); - auto num_partitions_attr = fn->getAttrOfType("num_partitions"); - const int64_t replica_count = replica_count_attr.getInt(); - const int64_t num_partitions = num_partitions_attr.getInt(); - - NcclCollectiveConfig config = - GetNcclCollectiveConfig(op, replica_count, num_partitions); - - // For async collective erase all corresponding done operations. - auto erase_done_op = [&]() { - eraseDoneOp(rewriter, op); - eraseDoneOp(rewriter, op); - eraseDoneOp(rewriter, - op); - eraseDoneOp(rewriter, op); - eraseDoneOp(rewriter, op); - eraseDoneOp(rewriter, op); - eraseDoneOp(rewriter, op); - }; - - // A given collective op can be degenerate if across all groups formed - // by it are singleton. In such a case, we don't need to do any - // communication and we can just copy the input to the output. - if (succeeded(TryDegenerateToMemCopy(op, config, replica_count, - num_partitions, rewriter))) { - // For async collective erase all corresponding done operations. - erase_done_op(); - - // Erase the original collective operation. - rewriter.eraseOp(op); - - return success(); - } - - Status implementable_status = - CheckImplementable(op, replica_count, num_partitions); - if (!implementable_status.ok()) { - return op.emitOpError() << implementable_status.message(); - } - - // Check that we have and assigned unique collective operation id. - auto uid = uid_.AssignedUid(op); - if (failed(uid)) { - return op.emitOpError("failed to get a unique collective operation id"); - } - - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - - // We always drop the return value from the signature, because for - // AllReduceStart operation we pass dependency through the collective - // operation id. - func::FuncOp callee = custom_calls_.GetOrCreate( - b, Target(op), TypeRange(op.getOperands()), TypeRange()); - - // Convert collective op to a function call. - auto call = rewriter.create(op.getLoc(), callee.getName(), - TypeRange(), op.getOperands()); - - // Copy backend specific attributes. - call->setAttr(b.getStringAttr("group_mode"), - b.getI64IntegerAttr(static_cast(config.group_mode))); - call->setAttr(b.getStringAttr("op_id"), b.getI64IntegerAttr(config.op_id)); - - // TODO(b/233930690): Pass the attribute below as a nested array. - // Pass an array of arrays using two vectors; one specifying all the values - // and another specifying the (ending) offsets of each array in the other - // vector. Example: [ [10, 20, 30, 40], [50, 60], [70, 80, 90] ] turns into - // offsets=[4, 6, 9] values=[10, 20, 30, 40, 50, 60, 70, 80, 90]. - std::vector replica_group_offsets; - std::vector replica_group_values; - replica_group_offsets.reserve(config.replica_groups.size()); - int replica_group_offset = 0; - for (const auto& replica_group : config.replica_groups) { - replica_group_offset += replica_group.replica_ids_size(); - replica_group_offsets.push_back(replica_group_offset); - replica_group_values.reserve(replica_group_offset); - for (auto replica_id : replica_group.replica_ids()) { - replica_group_values.push_back(replica_id); - } - } - call->setAttr(b.getStringAttr("replica_group_offsets"), - b.getI64TensorAttr(replica_group_offsets)); - call->setAttr(b.getStringAttr("replica_group_values"), - b.getI64TensorAttr(replica_group_values)); - - // Assign a unique collective operation id. - call->setAttr(b.getStringAttr("uid"), b.getI32IntegerAttr(*uid)); - - // Set attributes specific to the type of collective operation. - auto result = SetSpecificAttrs(b, op, call); - if (failed(result)) return result; - - bool is_async = !getIsSync(op); - call->setAttr(b.getStringAttr("is_async"), b.getBoolAttr(is_async)); - - call->setAttr(b.getStringAttr("no_parallel_custom_call"), - b.getBoolAttr(noParallelCustomCall(op))); - - // If the collective will not execute asynchronously, erase the associated - // done op. - if (!is_async) { - erase_done_op(); - } else { - // For asynchronous start operation we need to produce a fake token, that - // will be later removed, because corresponding `done` operation doesn't - // have a token argument. We rely on the `unrealized_conversion_cast` - // operation to create a fake token from the `i8` constant, and on the - // dead code elimination pass that will remove unused fake tokens. - Value token = op.getToken(); - Value c0 = b.create(b.getI8IntegerAttr(0)); - auto fake = b.create(token.getType(), c0); - token.replaceAllUsesWith(fake.getResult(0)); - } - - // Erase the original collective operation. - rewriter.eraseOp(op); - - return success(); - } - - private: - CollectiveUidGenerator& uid_; - CustomCallDeclarations& custom_calls_; -}; - -#define DEFINE_COLLECTIVE_OP_LOWERING(OP) \ - class OP##Lowering : public CollectiveOpLowering { \ - public: \ - using CollectiveOpLowering::CollectiveOpLowering; \ - } - -DEFINE_COLLECTIVE_OP_LOWERING(AllGatherStartOp); -DEFINE_COLLECTIVE_OP_LOWERING(AllReduceStartOp); -DEFINE_COLLECTIVE_OP_LOWERING(AllToAllStartOp); -DEFINE_COLLECTIVE_OP_LOWERING(CollectivePermuteStartOp); -DEFINE_COLLECTIVE_OP_LOWERING(ReduceScatterStartOp); -DEFINE_COLLECTIVE_OP_LOWERING(SendOp); -DEFINE_COLLECTIVE_OP_LOWERING(RecvOp); - -#undef DEFINE_COLLECTIVE_OP_LOWERING - -template -class AsyncDoneOpLowering : public OpRewritePattern { - public: - AsyncDoneOpLowering(MLIRContext* ctx, CollectiveUidGenerator& uid, - CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), uid_(uid), custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(OpT op, - PatternRewriter& rewriter) const override { - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = custom_calls_.GetOrCreate( - b, "xla.gpu.collective_done", TypeRange(), TypeRange()); - - // Get a unique collective operation id. - FailureOr uid = uid_.AssignedUid(op); - if (failed(uid)) - return op.emitOpError("failed to get a unique collective operation id"); - - llvm::SmallVector custom_call_attributes = { - {b.getStringAttr("uid"), b.getI32IntegerAttr(*uid)}, - {b.getStringAttr("done_type"), b.getStringAttr(Derived::kDoneType)}}; - - // Convert AllReduceDone to a function call. - auto call = rewriter.replaceOpWithNewOp(op, callee.getName(), - TypeRange()); - AppendCustomCallAttrs(call, custom_call_attributes); - - return success(); - } - - private: - CollectiveUidGenerator& uid_; - CustomCallDeclarations& custom_calls_; -}; - -#define DEFINE_COLLECTIVE_DONE_OP_LOWERING(OP, done_type) \ - struct OP##Lowering : public AsyncDoneOpLowering { \ - static constexpr const char kDoneType[] = done_type; \ - using AsyncDoneOpLowering::AsyncDoneOpLowering; \ - } - -DEFINE_COLLECTIVE_DONE_OP_LOWERING(AllGatherDoneOp, "all_gather_done"); -DEFINE_COLLECTIVE_DONE_OP_LOWERING(AllReduceDoneOp, "all_reduce_done"); -DEFINE_COLLECTIVE_DONE_OP_LOWERING(AllToAllDoneOp, "all_to_all_done"); -DEFINE_COLLECTIVE_DONE_OP_LOWERING(CollectivePermuteDoneOp, - "collective_permute_done"); -DEFINE_COLLECTIVE_DONE_OP_LOWERING(ReduceScatterDoneOp, "reduce_scatter_done"); -DEFINE_COLLECTIVE_DONE_OP_LOWERING(SendDoneOp, "send_done"); -DEFINE_COLLECTIVE_DONE_OP_LOWERING(RecvDoneOp, "recv_done"); - -#undef DEFINE_COLLECTIVE_DONE_OP_LOWERING - -template -class CollectiveIdOpLowering : public OpRewritePattern { - static StringRef Target(ReplicaIdOp) { return "xla.gpu.replica_id"; } - static StringRef Target(PartitionIdOp) { return "xla.gpu.partition_id"; } - - public: - CollectiveIdOpLowering(MLIRContext* ctx, CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(CollectiveIdOp op, - PatternRewriter& rewriter) const override { - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = custom_calls_.GetOrCreate(b, Target(op), op); - - // Call the runtime intrinsic with the original operands. - rewriter.replaceOpWithNewOp(op, callee.getName(), TypeRange(), - op->getOperands()); - return success(); - } - - private: - CustomCallDeclarations& custom_calls_; -}; - -class ReplicaIdOpLowering : public CollectiveIdOpLowering { - public: - using CollectiveIdOpLowering::CollectiveIdOpLowering; -}; - -class PartitionIdOpLowering : public CollectiveIdOpLowering { - public: - using CollectiveIdOpLowering::CollectiveIdOpLowering; -}; - -//===----------------------------------------------------------------------===// -// Host<->Device communication ops lowering (Send/Recv). -//===----------------------------------------------------------------------===// - -template -class HostSendRecvOpLowering : public OpRewritePattern { - public: - HostSendRecvOpLowering(MLIRContext* ctx, CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(OpT op, - PatternRewriter& rewriter) const override { - if (!op.getIsHostTransfer()) { - return failure(); - } - - constexpr bool is_done_op = - is_any; - - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - - // For done ops, drop the token input. - TypeRange input_types = - is_done_op ? TypeRange() : TypeRange(op->getOperands()); - func::FuncOp callee = custom_calls_.GetOrCreate( - b, Derived::kCustomCallTarget, input_types, TypeRange()); - - llvm::SmallVector custom_call_attributes = { - {b.getStringAttr("channel_handle"), op.getChannelHandleAttr()}}; - if constexpr (!is_done_op) { - custom_call_attributes.push_back(NamedAttribute( - b.getStringAttr("frontend_attributes"), op.getFrontendAttributes())); - } - - // Convert Send/Recv/SendDone/RecvDone to a function call. - ValueRange inputs = - is_done_op ? ValueRange() : ValueRange(op->getOperands()); - auto call = rewriter.create(op.getLoc(), callee.getName(), - TypeRange(), inputs); - AppendCustomCallAttrs(call, custom_call_attributes); - - if constexpr (!is_done_op) { - // For communication operation we need to produce a fake token, that will - // be later removed, because corresponding `done` operation doesn't have - // the token argument. We rely on the `unrealized_conversion_cast` - // operation to create a fake token from the `i8` constant. - Value token = op.getResult(); - Value c0 = b.create(b.getI8IntegerAttr(0)); - auto fake = b.create(token.getType(), c0); - token.replaceAllUsesWith(fake.getResult(0)); - } - - // Erase the original operation. - rewriter.eraseOp(op); - - return success(); - } - - private: - CustomCallDeclarations& custom_calls_; -}; - -#define DEFINE_HOST_SENDRECV_OP_LOWERING(OP, custom_call) \ - struct Host##OP##Lowering \ - : public HostSendRecvOpLowering { \ - static constexpr const char kCustomCallTarget[] = custom_call; \ - using HostSendRecvOpLowering::HostSendRecvOpLowering; \ - } - -DEFINE_HOST_SENDRECV_OP_LOWERING(SendOp, "xla.gpu.send_host"); -DEFINE_HOST_SENDRECV_OP_LOWERING(SendDoneOp, "xla.gpu.send_done_host"); -DEFINE_HOST_SENDRECV_OP_LOWERING(RecvOp, "xla.gpu.recv_host"); -DEFINE_HOST_SENDRECV_OP_LOWERING(RecvDoneOp, "xla.gpu.recv_done_host"); - -//===----------------------------------------------------------------------===// - -template -static WalkResult AssignAsyncUid(Operation* op, - CollectiveUidGenerator& collective_uid) { - auto start = dyn_cast(op); - if (!start) { - if constexpr (sizeof...(Remaining) != 0) { - return AssignAsyncUid(op, collective_uid); - } else { - return WalkResult::advance(); - } - } - - if (!CollectiveFilter::ShouldHandle(start)) { - return WalkResult::advance(); - } - - Value token = start.getToken(); - - // We expect the token to be consumed just once. - if (!token.hasOneUse()) return start.emitOpError("token has multiple uses"); - - // Token must be consumed by the corresponding done operation. - auto done = dyn_cast(*token.getUsers().begin()); - if (!done) return start.emitOpError("illegal token user"); - - collective_uid.AssignUid(start, done); - return WalkResult::advance(); -} - -void ConvertLmhloToGpuRuntimePass::runOnOperation() { - ModuleOp module = getOperation(); - MLIRContext* ctx = module.getContext(); - - // Keep track of the custom calls created from the lowered operations. - SymbolTable sym_table(module); - CustomCallDeclarations custom_calls(std::move(sym_table)); - - // Convert lmhlo operations to XLA gpu runtime custom calls. - RewritePatternSet patterns(ctx); - patterns.insert(ctx); - patterns.insert( - ctx, custom_calls); - - UidGenerator fft_uid; - patterns.insert(ctx, fft_uid, custom_calls); - - // Assign shared unique id to each unique pair of async start-done operations, - // all other collective operations will get assigned uid. - CollectiveUidGenerator collective_uid; - auto walked = module.walk([&collective_uid](Operation* op) { - return AssignAsyncUid< - std::pair, - std::pair, - std::pair, - std::pair, - std::pair, - std::pair, std::pair>( - op, collective_uid); - }); - if (walked.wasInterrupted()) return signalPassFailure(); - - // Convert lmhlo collective operations to XLA gpu runtime custom calls. - patterns.insert(ctx, - custom_calls); - patterns.insert( - ctx, collective_uid, custom_calls); - - // Convert lmhlo host<->device point-to-point communication operations to XLA - // gpu runtime. - patterns.insert(ctx, - custom_calls); - - if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) - return signalPassFailure(); - - // TODO(ezhulenev): We must run `done` op lowering after the `start` op - // lowering to ensure that all redundant collective operations will be - // safely replaced by a `memcpy` operations. - // - // This should be a part of lmhlo operation canonicalization. - { - RewritePatternSet patterns(ctx); - patterns.insert(ctx, collective_uid, custom_calls); - if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) - return signalPassFailure(); - } -} - -std::unique_ptr> -createConvertLmhloToGpuRuntimePass() { - return std::make_unique(); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/memref_get_global_to_arg.cc b/third_party/xla/xla/mlir/backends/gpu/transforms/memref_get_global_to_arg.cc deleted file mode 100644 index 5333b1cdb51fb8..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/memref_get_global_to_arg.cc +++ /dev/null @@ -1,168 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "xla/mlir/backends/gpu/transforms/passes.h" - -namespace xla { -namespace gpu { - -#define GEN_PASS_DEF_CONVERTMEMREFGETGLOBALTOARGPASS -#include "xla/mlir/backends/gpu/transforms/passes.h.inc" - -using namespace mlir; // NOLINT - -class ConvertMemrefGetGlobalToArgPass - : public impl::ConvertMemrefGetGlobalToArgPassBase< - ConvertMemrefGetGlobalToArgPass> { - public: - ConvertMemrefGetGlobalToArgPass() = default; - - explicit ConvertMemrefGetGlobalToArgPass(int64_t min_num_elements) { - this->min_num_elements_ = min_num_elements; - } - - void runOnOperation() override; - - void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); - } -}; - -//===----------------------------------------------------------------------===// - -using GlobalConstantsArgs = - llvm::DenseMap>; - -// Returns a mapping from a global constant name to the function argument. -// -// Example: -// -// memref.global "private" constant @cst : memref<2x3xf32> -// func @get_global(%arg0: memref<24xi8> {lmhlo.constant_name = "cst"}) -// -// All memref.get_global operations will be replaced by constant arguments -// corresponding to the global constant. -static GlobalConstantsArgs GetConstantArgs(ModuleOp m) { - GlobalConstantsArgs mapping; - - m.walk([&](func::FuncOp func) { - for (unsigned i = 0; i < func.getNumArguments(); ++i) { - auto cst = func.getArgAttrOfType(i, "lmhlo.constant_name"); - if (cst) mapping[func][cst] = func.getArgument(i); - } - }); - - return mapping; -} - -class GetGlobalOpLowering : public OpRewritePattern { - public: - GetGlobalOpLowering(MLIRContext* ctx, const GlobalConstantsArgs& cst_args) - : OpRewritePattern(ctx), cst_args_(cst_args) {} - - LogicalResult matchAndRewrite(memref::GetGlobalOp op, - PatternRewriter& rewriter) const override { - // Find global constants mapping for the parent function. - auto func_mapping = cst_args_.find(op->getParentOfType()); - if (func_mapping == cst_args_.end()) return failure(); - - // Check if the global operation corresponds to the LMHLO constant arg. - auto arg = func_mapping->second.find(op.getName()); - if (arg == func_mapping->second.end()) return failure(); - - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - MemRefType memref = op->getResult(0).getType().cast(); - - // For identity layouts we can replace all loads from a global with the - // corresponding argument. - if (memref.getLayout().isIdentity()) { - Value c0 = b.create(rewriter.getIndexAttr(0)); - rewriter.replaceOpWithNewOp(op, memref, arg->second, c0, - ValueRange()); - return success(); - } - - // For non-identity type we first view constant argument as a flat memref - // with the correct element type, and then cast it to the strided memref - // corresponding to the original memref layout. - - // Get the strides and offset from the original memref type. - int64_t offset; - llvm::SmallVector strides; - if (failed(getStridesAndOffset(memref, strides, offset))) - return op.emitOpError("failed to compute strides and offset"); - - // Create a 1d view into the corresponding argument. - Value c0 = b.create(rewriter.getIndexAttr(0)); - Value flat_view = b.create( - MemRefType::get({memref.getNumElements()}, memref.getElementType()), - arg->second, c0, ValueRange()); - - // Cast flat memref view into the original memref type. - rewriter.replaceOpWithNewOp( - op, memref, flat_view, offset, memref.getShape(), strides); - - return success(); - } - - private: - const GlobalConstantsArgs& cst_args_; -}; - -void ConvertMemrefGetGlobalToArgPass::runOnOperation() { - ModuleOp module = getOperation(); - MLIRContext* ctx = module.getContext(); - - // Replace memref loads from globals corresponding to the constant arguments. - RewritePatternSet patterns(ctx); - GlobalConstantsArgs cst_args = GetConstantArgs(module); - patterns.insert(ctx, cst_args); - - // Set up conversion target to rewrite only GetGlobalOp larger than the - // threshold and avoid any other canonicalizations that can break later - // passes. - ConversionTarget target(*ctx); - target.addDynamicallyLegalOp( - [&](memref::GetGlobalOp op) { - auto memref = op.getType(); - return memref.getNumElements() < min_num_elements_; - }); - target.addLegalOp(); - - if (failed(applyPartialConversion(module, target, std::move(patterns)))) - signalPassFailure(); -} - -std::unique_ptr> -createConvertMemrefGetGlobalToArgPass() { - return std::make_unique(); -} - -std::unique_ptr> -createConvertMemrefGetGlobalToArgPass(int64_t min_num_elements) { - return std::make_unique(min_num_elements); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc b/third_party/xla/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc deleted file mode 100644 index 5e8393cad39dc0..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc +++ /dev/null @@ -1,518 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "llvm/ADT/STLExtras.h" -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/Dominance.h" // from @llvm-project -#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/IR/TypeRange.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Transforms/RegionUtils.h" // from @llvm-project -#include "xla/debug_options_flags.h" -#include "xla/mlir/backends/gpu/transforms/passes.h" -#include "xla/mlir/runtime/ir/rt_dialect.h" -#include "xla/mlir/runtime/ir/rt_ops.h" -#include "xla/mlir/runtime/utils/custom_calls.h" -#include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" -#include "xla/service/gpu/backend_configs.pb.h" -#include "xla/stream_executor/blas.h" -#include "xla/xla.pb.h" - -namespace xla { -namespace gpu { - -#define GEN_PASS_DEF_OUTLINEGPUGRAPHSPASS -#include "xla/mlir/backends/gpu/transforms/passes.h.inc" - -using namespace mlir; // NOLINT - -using mlir::gpu::LaunchFuncOp; - -class OutlineGpuGraphsPass - : public impl::OutlineGpuGraphsPassBase { - public: - OutlineGpuGraphsPass() = default; - explicit OutlineGpuGraphsPass( - absl::flat_hash_set command_types, - int min_graph_size) - : command_types_(std::move(command_types)) { - this->min_graph_size_ = min_graph_size; - } - - void runOnOperation() override; - - void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); - } - - private: - absl::flat_hash_set command_types_ = { - DebugOptions::FUSION, DebugOptions::CUBLAS, DebugOptions::CUDNN}; - int gpu_graph_level_ = 3; -}; - -//===----------------------------------------------------------------------===// - -struct OpCapturePattern { - // CUDA-graph-compatible operations can be either moved or cloned into the - // graph capture function. Most of the operations should be moved, as they - // have side effects, however small constants and pure operations like - // `memref.view` can be safely cloned into the graph region. We rely on later - // dead code elimination to erase them from the "main" function if they are - // not used by any other operations. - enum class Capture { kMove, kClone }; - - virtual ~OpCapturePattern() = default; - virtual FailureOr match(Operation* op) = 0; -}; - -using OpCapturePatternSet = std::vector>; - -// A sequence of operations to be outlined into cuda graph capture function. -using CaptureSequence = - llvm::SmallVector>; - -//===----------------------------------------------------------------------===// - -template -struct OpCapture : public OpCapturePattern { - FailureOr match(Operation* op) final { - if (isa(op)) return capture; - return failure(); - } -}; - -static constexpr auto kMove = OpCapturePattern::Capture::kMove; -static constexpr auto kClone = OpCapturePattern::Capture::kClone; - -template -using MoveOp = OpCapture; -template -using CloneOp = OpCapture; - -// Capture gpu operations by moving them into graph capture function. -struct LaunchFuncOpCapture : public MoveOp {}; - -template -struct ConvOpCapture : public OpCapturePattern { - FailureOr match(Operation* op) final { - if (auto conv = llvm::dyn_cast(op)) { - // Convolution that does runtime autotuning should not be captured, since - // CUDA graphs do not support operations that allocate memory. - lmhlo_gpu::ConvolutionBackendConfigAttr backend_config = - conv.getBackendConfig(); - if (backend_config.getAlgorithm() != -1) { - return kMove; - } - } - return failure(); - } -}; - -// TODO(b/270426911): Right now GEMM/Convolution with runtime autotuning can't -// be captured by a cuda graph. However, longer term the proper fix is to make -// autotuning "cuda-graph-aware", and run autotuning on a separate stream that -// is not in capture mode. -struct ConvForwardOpCapture : public ConvOpCapture {}; -struct ConvBackwardInputOpCapture - : public ConvOpCapture {}; -struct ConvBackwardFilterOpCapture - : public ConvOpCapture {}; -struct ConvForwardFusedOpCapture - : public ConvOpCapture {}; -struct ConvForwardFusedSideInputOpCapture - : public ConvOpCapture {}; - -struct GemmOpCapture : public OpCapturePattern { - FailureOr match(Operation* op) final { - if (auto gemm = llvm::dyn_cast(op)) { - // GEMM that does runtime autotuning should not be captured, since CUDA - // graph does not support operations that allocate memory. - if (!gemm.getAlgorithm().has_value() || - gemm.getAlgorithm().value() != - stream_executor::blas::kRuntimeAutotuning) { - return kMove; - } - } - return failure(); - } -}; - -struct MemcpyOpCapture : public OpCapturePattern { - FailureOr match(Operation* op) final { - if (auto memcpy = llvm::dyn_cast(op)) { - // We use a heuristic to identify the direction of the memcpy operation, - // if the operand was allocated by alloca op or is a global memref, then - // it must be a memref on the host. - auto IsHostMemRef = [](Value value) { - auto* op = value.getDefiningOp(); - return llvm::isa_and_nonnull(op); - }; - - auto IsDeviceToDevice = [&](mlir::gpu::MemcpyOp op) { - return !IsHostMemRef(op.getDst()) && !IsHostMemRef(op.getSrc()); - }; - - // Device-to-host Memcpy cannot be captured by CUDA graphs. - if (IsDeviceToDevice(memcpy)) { - return kMove; - } - } - return failure(); - } -}; - -// Capture pure operations by cloning them into graph capture function. -struct ConstantOpCapture : public CloneOp {}; -struct ViewOpCapture : public CloneOp {}; -struct ReinterpretCastOpCapture : public CloneOp {}; - -//===----------------------------------------------------------------------===// - -// Collect sequences of operations that can be outlined into Cuda Graphs. -static std::vector CollectCaptureSequences( - DominanceInfo& dominance, ModuleOp module, OpCapturePatternSet& patterns) { - std::vector seqs; - - // Match given operation with all capture patterns. - auto match = [&](Operation* op) -> FailureOr { - for (auto& pattern : patterns) { - if (auto matched = pattern->match(op); succeeded(matched)) return matched; - } - return failure(); - }; - - // Find graph-compatible sequences of operations in every block. - module.walk([&](Block* block) { - CaptureSequence* seq = &seqs.emplace_back(); - - for (Operation& op : *block) { - FailureOr matched = match(&op); - // Append matched operation to the current sequence. We only append - // operations that must be moved into the graph capture function (ops with - // side effects), and add cloneable operations later. - if (succeeded(matched) && *matched == kMove) - seq->emplace_back(&op, *matched); - - // Skip unsupported operation and start a new sequence. - if (failed(matched) && !seq->empty()) seq = &seqs.emplace_back(); - } - - // Remove the last sequence if it's empty. - if (seq->empty()) seqs.pop_back(); - }); - - // Remove cloneable operations accidentally captured by the sequence of ops, - // e.g. we can have `memref.view` between two kernel launch operations that - // is not used by operations in the captured sequence. - for (CaptureSequence& seq : seqs) { - llvm::DenseSet moveable_ops; - for (auto& [op, capture] : seq) - if (capture == kMove) moveable_ops.insert(op); - - llvm::erase_if(seq, [&](auto& pair) { - return pair.second == kClone && - llvm::none_of(pair.first->getUsers(), [&](Operation* user) { - return moveable_ops.contains(user); - }); - }); - } - - // Try to extend discovered sequences of ops following operands use-def chains - // and pulling cloneable operations defining operands into the graph capture - // sequence. In practice we just clone `arith.constant` and `memref.view` - // operations into the graph capture function, to make it cheaper to compute - // the hash of the arguments at run time. - for (CaptureSequence& seq : seqs) { - llvm::DenseSet seq_ops; // operations already in `seq` - llvm::SmallVector worklist; - - // Add operations that define `op` arguments to the worklist. - auto populate_worklist = [&](Operation* op) { - for (Value arg : op->getOperands()) - if (Operation* op = arg.getDefiningOp()) worklist.push_back(op); - }; - - for (auto& [op, _] : seq) { - seq_ops.insert(op); - populate_worklist(op); - } - - // Find cloneable ops and group them by block where they are defined. - llvm::DenseMap> cloneable; - - // Traverse use-def chains to collect all cloneable operations. - while (!worklist.empty()) { - Operation* op = worklist.pop_back_val(); - if (seq_ops.contains(op)) continue; - - // Check if operation can be cloned into graph capture function. - if (auto matched = match(op); - succeeded(matched) && *matched == OpCapturePattern::Capture::kClone) { - cloneable[op->getBlock()].push_back(op); - seq_ops.insert(op); - populate_worklist(op); - } - } - - // Traverse blocks according to their dominance to avoid used-before-defined - // invalid SSA region construction in graph capture function. - llvm::SmallVector blocks; - for (auto& [block, _] : cloneable) blocks.push_back(block); - llvm::sort(blocks, [&](Block* a, Block* b) { - return dominance.properlyDominates(a, b); - }); - - for (Block* block : llvm::reverse(blocks)) { - // Sort operations according to their original position in the block. - llvm::sort(cloneable[block], [](Operation* a, Operation* b) { - return a->isBeforeInBlock(b); - }); - - // Prepend all cloneable operations to the discovered ops sequence. - auto cloned = llvm::map_range(cloneable[block], [](Operation* op) { - return std::make_pair(op, OpCapturePattern::Capture::kClone); - }); - seq.insert(seq.begin(), cloned.begin(), cloned.end()); - } - } - - return seqs; -} - -//===----------------------------------------------------------------------===// - -using xla::runtime::CustomCallDeclarations; - -static std::vector GetGraphCaptureFuncArgs(const CaptureSequence& seq) { - llvm::SetVector args; - - // Values defined by operations in the capture sequence. - llvm::DenseSet defined_by_seq; - for (auto& [op, _] : seq) - defined_by_seq.insert(op->result_begin(), op->result_end()); - - // Add arguments defined outside of the capture sequence. - for (auto& [op, _] : seq) { - auto external_args = llvm::make_filter_range( - op->getOperands(), - [&](Value arg) { return !defined_by_seq.contains(arg); }); - args.insert(external_args.begin(), external_args.end()); - } - llvm::SmallVector args_sv = args.takeVector(); - std::vector args_tv(args_sv.begin(), args_sv.end()); - return args_tv; -} - -// Given a sequence of operations, outline them into a graph capture function -// and replace them with an XLA Gpu runtime function call. -static LogicalResult Outline(unsigned ordinal, - CustomCallDeclarations& custom_calls, - CaptureSequence& seq, int min_graph_size) { - // Only operations that have to be moved into the graph capture function - // represent Gpu computations. - unsigned num_move_captures = llvm::count_if(seq, [](auto capture) { - return capture.second == OpCapturePattern::Capture::kMove; - }); - DebugOptions debug_options = GetDebugOptionsFromFlags(); - if (num_move_captures < min_graph_size) return failure(); - - SymbolTable& sym_table = custom_calls.sym_table(); - MLIRContext* ctx = sym_table.getOp()->getContext(); - - // Create a fused location out of LaunchFuncOp operations. - llvm::SmallVector locations; - for (auto& op : seq) locations.push_back(op.first->getLoc()); - ImplicitLocOpBuilder b(FusedLoc::get(ctx, locations), sym_table.getOp()); - - // Arguments of the graph capture function. - std::vector args = GetGraphCaptureFuncArgs(seq); - - // Create a function in the compiled module. - auto func = b.create( - "xla.gpu.graph.capture", - FunctionType::get(ctx, TypeRange(ValueRange(args)), TypeRange())); - - Operation* first_op = seq.front().first; - auto parent_func = first_op->getParentOfType(); - - // If an argument to parent_func has the "lmhlo.constant_name" attribute and - // is passed to the graph capture function, we propagate the attribute the - // graph capture function. - // - // We also annotate all arguments with "rt.allocation_index" attribute that - // allows us to forward correct arguments to graph capture function during - // Gpu executable initialization (see `InstantiateAllGraphs` implementation). - for (unsigned i = 0; i < args.size(); ++i) { - Value arg = args[i]; - - // Check if arg is a function argument of parent_func. - if (!isa(arg)) continue; - - // Function arguments are passed in as block arguments to the entry block. - auto block_arg = cast(arg); - Block* parent_block = block_arg.getParentBlock(); - if (!parent_block->isEntryBlock()) continue; - - // If this is an argument to the entry block of the parent function, it - // means that it's the XLA allocation, and we forward index to the capture - // function. - func.setArgAttr(i, "rt.allocation_index", - b.getIndexAttr(block_arg.getArgNumber())); - - // Check that the parent_block is in the SSACFG region of parent_func. - Region& parent_func_region = parent_func.getRegion(); - if (parent_block->getParent() != &parent_func_region) continue; - - unsigned parent_func_arg_index = block_arg.getArgNumber(); - auto cst = parent_func.getArgAttrOfType(parent_func_arg_index, - "lmhlo.constant_name"); - if (cst) { - func.setArgAttr(i, "lmhlo.constant_name", cst); - } - } - - for (auto op : seq) { - mlir::Operation* captured_op = op.first; - if (isa(captured_op)) { - func->setAttr(b.getStringAttr(runtime::kRequiresBlasAttrName), - BoolAttr::get(ctx, true)); - break; - } - } - - // Add graph capture function to the module. - sym_table.insert(func); - - // Export graph capture function to the runtime. - b.setInsertionPoint(func); - b.create(func, ordinal); - - // Create a custom call declaration corresponding to the outlined graph - // capture function. - func::FuncOp graph_launch = custom_calls.GetOrCreate( - b, "xla.gpu.graph.launch", TypeRange(ValueRange(args)), TypeRange()); - - // Call the cuda graph launch custom call right before the first moved op. - auto insertion_point = llvm::find_if(seq, [](auto capture) { - return capture.second == OpCapturePattern::Capture::kMove; - }); - b.setInsertionPoint(insertion_point->first); - - auto call = b.create(graph_launch.getName(), TypeRange(), args); - call->setAttr(b.getStringAttr("capture"), FlatSymbolRefAttr::get(func)); - - // At this point we successfully added new functions to the module, so we can - // move or clone captured operations from their original location to the graph - // capture function. - Block* body = func.addEntryBlock(); - - // We'll need to replace operands of cloned/moved operations inside the graph - // capture function. - llvm::SmallVector> mappings; // {from, to} mappings - for (auto mapping : llvm::zip(args, func.getArguments())) - mappings.emplace_back(std::get<0>(mapping), std::get<1>(mapping)); - - // Move or clone operations into the graph capture function. - for (auto& [op, capture] : seq) { - if (capture == OpCapturePattern::Capture::kMove) - op->moveBefore(body, body->end()); - - if (capture == OpCapturePattern::Capture::kClone) { - Operation* clone = op->clone(); - OpBuilder::atBlockEnd(body).insert(clone); - - for (auto mapping : llvm::zip(op->getResults(), clone->getResults())) - mappings.emplace_back(std::get<0>(mapping), std::get<1>(mapping)); - } - } - - // Update def-use chains inside the graph capture function. - for (auto mapping : mappings) { - replaceAllUsesInRegionWith(mapping.first, mapping.second, func.getBody()); - } - - // Add a return operation to the graph capture function. - b.setInsertionPointToEnd(body); - b.create(ValueRange()); - - return success(); -} - -//===----------------------------------------------------------------------===// - -void OutlineGpuGraphsPass::runOnOperation() { - SymbolTable sym_table(getOperation()); - CustomCallDeclarations custom_calls(std::move(sym_table)); - - OpCapturePatternSet patterns; - - if (command_types_.contains(DebugOptions::FUSION)) { - // Enable capturing fusions and memcpies. - patterns.emplace_back(new LaunchFuncOpCapture()); - patterns.emplace_back(new ConstantOpCapture()); - patterns.emplace_back(new ViewOpCapture()); - patterns.emplace_back(new MemcpyOpCapture()); - patterns.emplace_back(new ReinterpretCastOpCapture()); - } - - if (command_types_.contains(DebugOptions::CUBLAS)) { - // Enable capturing gemms. - patterns.emplace_back(new GemmOpCapture()); - } - - if (command_types_.contains(DebugOptions::CUDNN)) { - // Enable capturing convolutions. - patterns.emplace_back(new ConvForwardOpCapture()); - patterns.emplace_back(new ConvBackwardInputOpCapture()); - patterns.emplace_back(new ConvBackwardFilterOpCapture()); - patterns.emplace_back(new ConvForwardFusedOpCapture()); - patterns.emplace_back(new ConvForwardFusedSideInputOpCapture()); - } - - unsigned ordinal = 1; // entry point will be exported with ordinal 0 - for (auto& seq : CollectCaptureSequences(getAnalysis(), - getOperation(), patterns)) { - if (succeeded(Outline(ordinal, custom_calls, seq, min_graph_size_))) - ordinal++; - } -} - -std::unique_ptr> createOutlineGpuGraphsPass() { - return std::make_unique(); -} - -std::unique_ptr> createOutlineGpuGraphsPass( - absl::flat_hash_set command_types, - int min_graph_size) { - return std::make_unique(command_types, min_graph_size); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/passes.cc b/third_party/xla/xla/mlir/backends/gpu/transforms/passes.cc deleted file mode 100644 index a453e037334931..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/passes.cc +++ /dev/null @@ -1,95 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/mlir/backends/gpu/transforms/passes.h" - -#include -#include - -#include "absl/log/log.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project -#include "xla/mlir/runtime/ir/rt_ops.h" - -namespace xla { -namespace gpu { - -using namespace mlir; // NOLINT - -std::vector> GetAllocationIndices(mlir::ModuleOp module) { - std::vector> res; - - SymbolTable sym_table(module); - for (auto op : module.getOps()) { - unsigned ordinal = *op.ordinal(); - if (ordinal >= res.size()) res.resize(ordinal + 1); - - auto func = sym_table.lookup(op.getFunctionRef()); - res[ordinal].resize(func.getNumArguments(), -1); - - for (unsigned i = 0; i < func.getNumArguments(); ++i) { - auto idx = func.getArgAttrOfType(i, "rt.allocation_index"); - if (idx) res[ordinal][i] = idx.getInt(); - } - } - - return res; -} - -void populateXlaGpuRuntimePasses(mlir::OpPassManager& pm, - ThunkSequence* thunk_sequence, - const GpuPipelineOpts& opts) { - // Lower operations with registered IR emitters to Gpu launches. - pm.addPass(createConvertLmhloToGpuLaunchPass(thunk_sequence)); - - // Clean up IR before converting it to the runtime operations. - pm.addPass(createCSEPass()); - - // Convert global memrefs corresponding to constant arguments. - pm.addPass(createConvertMemrefGetGlobalToArgPass()); - pm.addPass(createSymbolDCEPass()); // Clean up unused global constants. - - // Outline CUDA-Graph-compatible operations into graph capture functions. - pm.addPass( - createOutlineGpuGraphsPass(opts.command_types, opts.min_graph_size)); - if (opts.enable_concurrent_region) { - // Concurrent regions create repeated-fork-join topology inside CUDA graphs, - // which is not optimized by architectures prior to Ampere and may cause - // regression. So we enable concurrent regions only on Ampere GPUs. - if (auto cc = std::get_if( - &opts.compute_capability); - !cc || cc->IsAtLeast(8, 0)) { - pm.addPass(createAddConcurrentRegionsPass()); - } else { - LOG(WARNING) - << "Multi-stream execution disabled on non-ampere architectures"; - } - } - - // Lower all Gpu operations to the XLA Gpu runtime custom calls. - pm.addPass(createConvertLmhloGpuToGpuRuntimePass()); - pm.addPass(createConvertLmhloToGpuRuntimePass()); - pm.addPass(createConvertGpuToGpuRuntimePass()); - - // Add performance tracing annotations. - pm.addPass(createAddHloTraceAnnotationsPass()); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/passes.h b/third_party/xla/xla/mlir/backends/gpu/transforms/passes.h deleted file mode 100644 index fa6fae39205f3f..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/passes.h +++ /dev/null @@ -1,150 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_MLIR_BACKENDS_GPU_TRANSFORMS_PASSES_H_ -#define XLA_MLIR_BACKENDS_GPU_TRANSFORMS_PASSES_H_ - -#include -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "xla/stream_executor/device_description.h" -#include "xla/xla.pb.h" - -namespace xla { -namespace gpu { - -#define GEN_PASS_DECL_ADDHLOTRACEANNOTATIONSPASS -#define GEN_PASS_DECL_CONVERTGPUTOGPURUNTIMEPASS -#define GEN_PASS_DECL_CONVERTLMHLOGPUTOGPURUNTIMEPASS -#define GEN_PASS_DECL_CONVERTLMHLOTOGPULAUNCHPASS -#define GEN_PASS_DECL_CONVERTLMHLOTOGPURUNTIMEPASS -#define GEN_PASS_DECL_CONVERTMEMREFGETGLOBALTOARGPASS -#define GEN_PASS_DECL_OUTLINEGPUGRAPHSPASS -#define GEN_PASS_DECL_ADDCONCURRENTREGIONSPASS -#define GEN_PASS_DECL_STREAMASSIGNMENTPASS -#include "xla/mlir/backends/gpu/transforms/passes.h.inc" - -class ThunkSequence; // forward declare - -// Collects `rt.allocation_index` attributes from all exported functions. -// -// auto result = GetAllocationIndices(); -// result[ordinal][argument_index] == allocation_index; -// -// Returns `-1` for all arguments that do not have `rt.allocation_index` -// attribute. -// -// TODO(ezhulenev): This is a very ugly hack for graph capture integration, but -// given that we are moving towards a new runtime and command buffers, it's -// supposed to be a very short lived hack. -std::vector> GetAllocationIndices(mlir::ModuleOp module); - -struct GpuPipelineOpts { - // Enable experimental pass that outlines parts of the XLA computation into - // CUDA Graphs, which allows us to amortize the cost of launching multiple - // device kernels. - absl::flat_hash_set command_types; - int32_t min_graph_size = 0; - bool enable_concurrent_region = false; - stream_executor::GpuComputeCapability compute_capability; -}; - -// Populate passes that lower MLIR modules from a combination of LMHLO and -// LMHLO_GPU dialects to the XLA Gpu runtime. This pipeline is composed from -// the passes defined below, and few builtin MLIR passes. -void populateXlaGpuRuntimePasses(mlir::OpPassManager& pm, - ThunkSequence* thunk_sequence, - const GpuPipelineOpts& opts = {}); - -//===----------------------------------------------------------------------===// -// Auxiliary passes for lowering to XLA Gpu runtime. -//===----------------------------------------------------------------------===// - -std::unique_ptr> -createConvertMemrefGetGlobalToArgPass(); - -std::unique_ptr> -createConvertMemrefGetGlobalToArgPass(int64_t min_num_elements); - -//===-----------------------------------------------------------------------===/ -// Passes for lowering from the `gpu` dialect. -//===-----------------------------------------------------------------------===/ - -std::unique_ptr> -createConvertGpuToGpuRuntimePass(); - -//===----------------------------------------------------------------------===// -// Passes for lowering from the `lmhlo` dialect. -//===----------------------------------------------------------------------===// - -std::unique_ptr> -createConvertLmhloToGpuLaunchPass(ThunkSequence* thunk_sequence = nullptr); - -std::unique_ptr> -createConvertLmhloToGpuRuntimePass(); - -//===----------------------------------------------------------------------===// -// Passes for lowering from the `lmhlo_gpu` dialect. -//===----------------------------------------------------------------------===// - -std::unique_ptr> -createConvertLmhloGpuToGpuRuntimePass(); - -//===----------------------------------------------------------------------===// -// XLA runtime performance tracing passes. -//===----------------------------------------------------------------------===// - -std::unique_ptr> -createAddHloTraceAnnotationsPass(); - -//===----------------------------------------------------------------------===// -// XLA runtime <-> Cuda Graphs integration. -//===----------------------------------------------------------------------===// - -std::unique_ptr> -createOutlineGpuGraphsPass(); - -std::unique_ptr> createOutlineGpuGraphsPass( - absl::flat_hash_set command_types, - int32_t min_graph_size); - -//===----------------------------------------------------------------------===// -// Passes for marking concurrent region in CUDA graph capture function. -//===----------------------------------------------------------------------===// - -std::unique_ptr> -createAddConcurrentRegionsPass(); - -//===----------------------------------------------------------------------===// -// Passes for assigning kernels to streams in CUDA graph capture function. -//===----------------------------------------------------------------------===// - -std::unique_ptr> -createStreamAssignmentPass(); - -//===-----------------------------------------------------------------------===/ - -#define GEN_PASS_REGISTRATION -#include "xla/mlir/backends/gpu/transforms/passes.h.inc" - -} // namespace gpu -} // namespace xla - -#endif // XLA_MLIR_BACKENDS_GPU_TRANSFORMS_PASSES_H_ diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/passes.td b/third_party/xla/xla/mlir/backends/gpu/transforms/passes.td deleted file mode 100644 index a522e7eeb735bf..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/passes.td +++ /dev/null @@ -1,302 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_GPU_PASSES -#define XLA_GPU_PASSES - -include "mlir/Pass/PassBase.td" - -//===----------------------------------------------------------------------===// -// Auxiliary passes for lowering to XLA Gpu runtime. -//===----------------------------------------------------------------------===// - -def ConvertMemrefGetGlobalToArgPass : - Pass<"xla-memref-get-global-to-arg", "mlir::ModuleOp"> { - let summary = "Converts memref.get_global corresponding to lmhlo constants"; - - let description = [{ - Replaces `memref.get_global` operations corresponding to the lmhlo constant - arguments (arguments marked with `lmhlo.constant_name` attribute) to use - the constant arguments directly. - - Once we used global constants for constant folding, we no longer need to - keep them in the module, because they'll be in the binary constant section - on the host, and we need them on the device. - }]; - - let constructor = "createConvertMemrefGetGlobalToArgPass()"; - - let options = [ - Option<"min_num_elements_", "min-num-elements", "int64_t", /*default=*/"0", - "Do not convert `memref.get_global` operation if the number of " - "elements is smaller than the given value.">, - ]; -} - -//===----------------------------------------------------------------------===// -// Passes for lowering from the `gpu` dialect. -//===----------------------------------------------------------------------===// - -def ConvertGpuToGpuRuntimePass : - Pass<"xla-gpu-to-gpu-runtime", "mlir::ModuleOp"> { - let summary = "Converts gpu operations to XLA Gpu runtime custom calls"; - - let description = [{ - Converts gpu operations (function launch, memcpy, etc...) to the XLA Gpu - runtime custom calls. - }]; - - let constructor = "createConvertGpuToGpuRuntimePass()"; -} - -//===----------------------------------------------------------------------===// -// Passes for lowering from the `lmhlo` dialect. -//===----------------------------------------------------------------------===// - -def ConvertLmhloToGpuLaunchPass : - Pass<"xla-lmhlo-to-gpu-launch", "mlir::ModuleOp"> { - let summary = "Converts lmhlo fusions to Gpu dialect kernel launch"; - - let description = [{ - Converts lmhlo operations that have registered IR emitters (e.g. fusions) to - Gpu dialect kernel launch operations (and trivial memory operations like - memcpy or memset). This pass relies on a pre-compiled ThunkSequence with an - associated device module (PTX and cubin) to find device kernels - corresponding to lmhlo operation in the input module. - - Created Gpu kernel launch operations can be further lowered to the Gpu - runtime by the `xla-gpu-to-gpu-runtime` pass. - }]; - - let constructor = "createConvertLmhloToGpuLaunchPass()"; -} - -def ConvertLmhloToGpuRuntimePass : - Pass<"xla-lmhlo-to-gpu-runtime", "mlir::ModuleOp"> { - let summary = "Converts lmhlo operations to XLA Gpu runtime custom calls"; - - let description = [{ - Converts lmhlo dialect operations (infeed, outfeed, collectives, etc...) to - the XLA Gpu runtime custom calls. - }]; - - let constructor = "createConvertLmhloToGpuRuntimePass()"; -} - -//===----------------------------------------------------------------------===// -// Passes for lowering from the `lmhlo_gpu` dialect. -//===----------------------------------------------------------------------===// - -def ConvertLmhloGpuToGpuRuntimePass : - Pass<"xla-lmhlo-gpu-to-gpu-runtime", "mlir::ModuleOp"> { - let summary = "Converts lmhlo_gpu operations to XLA Gpu runtime custom calls"; - - let description = [{ - Converts lmhlo_gpu dialect operations (gemm, convolution, etc...) to - the XLA Gpu runtime custom calls. - }]; - - let constructor = "createConvertLmhloGpuToGpuRuntimePass()"; -} - -//===----------------------------------------------------------------------===// -// XLA runtime performance tracing passes. -//===----------------------------------------------------------------------===// - -// TODO(ezhulenev): This pass should be generic for all backends, consider -// moving it to the `transforms/runtime` folder once it will be used by CPU -// compiler. - -def AddHloTraceAnnotationsPass : - Pass<"xla-add-hlo-trace-annotations", "mlir::ModuleOp"> { - let summary = "Adds HLO trace annotations to the supported operations"; - - let description = [{ - Adds HLO trace annotations to the operations that result from compiling - an input HLO module, e.g. it adds HLO trace annotations to all runtime custom - calls that are constructed from the corresponding HLO operations. - - Example: - - ```mlir - call @local_xla.gpu.gemm(...) : (...) -> memref - ``` - - becomes: - - ```mlir - call @local_xla.gpu.gemm(...) { rt.trace = #rt.hlo<"gemm.1", "xla_module", 0> } - : (...) -> memref - ``` - - XLA compilation pipeline wraps traced operations into the `rt.trace` - operation, and eventually lowers them to the tracing API calls. - }]; - - let constructor = "createAddHloTraceAnnotationsPass()"; -} - -//===----------------------------------------------------------------------===// -// Xla Gpu <-> Cuda Graphs integration. -//===----------------------------------------------------------------------===// - -def OutlineGpuGraphsPass : - Pass<"xla-gpu-outline-gpu-graphs", "mlir::ModuleOp"> { - let summary = "Outline sequences of Xla Gpu operations into CUDA Graphs"; - - let description = [{ - Converts sequences of supported Xla Gpu operations to Cuda Graph capture - functions, and replaces the original sequences with calls to the Xla Cuda - Graph runtime API. - - Example: - - ```mlir - gpu.launch_func @compute::foo args(%arg0: memref) - gpu.launch_func @compute::bar args(%arg1: memref) - ``` - - becomes: - - ```mlir - // Export cuda graph capture function to Xla runtime. - rt.export @capture ordinal 1 - func.func @capture(@arg0: memref, %arg1: memref) { - ... capture a graph corresponding to a sequence of `gpu.launch_func` ops - } - - // Replace a sequence of graph launch operations with a call to runtime API. - call @local_xla.gpu.graph.launch(%arg0: memref, - %arg1: memref) - attributes { capture = @capture } - ``` - }]; - - let constructor = "createOutlineGpuGraphsPass()"; - - let options = [ - Option<"min_graph_size_", "min_graph_size", "int64_t", /*default=*/"2", - "The minimum size of the outlined CUDA graph function.">, - ]; -} - -//===----------------------------------------------------------------------===// -// Add concurrent regions to CUDA graph capture functions. -//===----------------------------------------------------------------------===// - -def AddConcurrentRegionsPass: - Pass<"xla-gpu-add-concurrent-regions", "mlir::ModuleOp"> { - let summary = "Identify and mark concurrent regions in CUDA graph capture " - "functions"; - - let description = [{ - Add concurent region markers to indicate a region of operations that can be - executed concurrently. - - Example: - - ```mlir - func.func @capture.cuda.graph() { - call @local_xla.gpu.launch.func - call @local_xla.gpu.launch.func - - // Everything here can run concurrently - call @local_xla.gpu.launch.func - call @local_xla.gpu.launch.func - call @local_xla.gpu.launch.func - call @local_xla.gpu.launch.func - // Back to sequential execution - - call @local_xla.gpu.launch.func - func.return - } - ``` - - becomes: - - ```mlir - func.func @capture.cuda.graph() { - call @local_xla.gpu.launch.func - call @local_xla.gpu.launch.func - - call @local_xla.gpu.concurrent_region.begin() - call @local_xla.gpu.launch.func - call @local_xla.gpu.launch.func - call @local_xla.gpu.launch.func - call @local_xla.gpu.launch.func - call @local_xla.gpu.concurrent_region.end() - - call @local_xla.gpu.launch.func - func.return - } - ``` - - }]; - - let constructor = "createAddConcurrentRegionsPass()"; -} - -//===----------------------------------------------------------------------===// -// Stream assignment. -//===----------------------------------------------------------------------===// - -def StreamAssignmentPass: - Pass<"xla-gpu-stream-assignment", "mlir::ModuleOp"> { - let summary = "Identify and mark concurrent regions in CUDA graph capture " - "functions"; - - let description = [{ - Assign a stream to each kernel launch in the capture function. Streams are - assigned to exploit parallelism, so that we can build parallel GPU graphs - duing graph capture. - - Example: - - ```mlir - func.func @capture.cuda.graph() { - // func1, func2, func3 can run in parallel - call @local_xla.gpu.launch.func1 - call @local_xla.gpu.launch.func2 - call @local_xla.gpu.launch.func3 - - // Depends on xla.gpu.launc.func1 and xla.gpu.launch.func2 to finish. - call @local_xla.gpu.launch.func - func.return - } - ``` - - becomes: - - ```mlir - func.func @capture.cuda.graph() { - // func1, func2, func3 can run in parallel - call @local_xla.gpu.launch.func1 {stream = 0 : i64} - call @local_xla.gpu.launch.func2 {stream = 1 : i64} - call @local_xla.gpu.launch.func3 {stream = 2 : i64} - - // Add explicit synchronization to wait for stream 1 to finish executing - // func2. - call @local_xla.stream.await {from = 0 : i64, to = [1]} - call @local_xla.gpu.launch.func {stream = 0: i64} - func.return - } - ``` - - }]; - - let constructor = "createStreamAssignmentPass()"; -} - -#endif // XLA_GPU_PASSES diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/stream_assignment.cc b/third_party/xla/xla/mlir/backends/gpu/transforms/stream_assignment.cc deleted file mode 100644 index 08ab2e230dbe66..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/stream_assignment.cc +++ /dev/null @@ -1,271 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/strings/match.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "xla/mlir/backends/gpu/transforms/dataflow_analysis.h" -#include "xla/mlir/runtime/utils/custom_calls.h" - -namespace xla { -namespace gpu { - -namespace { - -#define GEN_PASS_DEF_STREAMASSIGNMENTPASS -#include "xla/mlir/backends/gpu/transforms/passes.h.inc" - -using namespace mlir; // NOLINT -using mlir::func::FuncOp; -using DataflowGraph = DataflowAnalysis::DataflowGraph; -using Node = DataflowAnalysis::Node; - -class StreamAssignmentPass - : public impl::StreamAssignmentPassBase { - void runOnOperation() override; - - void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); - } -}; - -static constexpr int kNumStreams = 10; - -//===----------------------------------------------------------------------===// - -bool IsParallelizableOp(Operation* op) { - return isa(op); -} - -// -// A simple algorithm to assign streams using the dependency information -// provided by the dataflow graph. -// Pseudocode: -// stream = 0 -// while there exists op such that it is unassigned: -// assign op to stream -// while op has a child: -// op = the last child in the order of execution in the capture function -// assign op to stream -// stream++ -// -// When assigning a stream to a dependency chain, we find the next op in the -// chain by finding the last child of the current op. For example, in the -// following dependency graph, A and C are assigned to stream 0, while B is -// assigned to 1. -// -// A-->B C -// | ^ -// +------| -// -std::vector AssignStreams(const DataflowGraph& graph, int num_streams) { - std::vector stream_assignment(graph.size(), -1); - size_t current_stream = 0; - - auto get_current_stream = [&]() -> size_t { - size_t assigned_stream = current_stream; - current_stream++; - if (current_stream == num_streams) { - current_stream = 0; - } - return assigned_stream; - }; - - auto get_first_unassigned_node = [&stream_assignment = - std::as_const(stream_assignment), - &graph]() -> std::optional { - for (auto [index, stream] : llvm::enumerate(stream_assignment)) { - if (stream == -1 && IsParallelizableOp(graph[index].operation)) { - return index; - } - } - return std::nullopt; - }; - - auto get_last_unassigned_child = [&stream_assignment = - std::as_const(stream_assignment), - &graph](Node node) -> std::optional { - for (int i = node.children.size() - 1; i >= 0; i--) { - Node child = graph[node.children[i]]; - if (!IsParallelizableOp(child.operation)) continue; - if (stream_assignment[child.index] == -1) { - return child; - } - } - return std::nullopt; - }; - - std::function assign_stream_to_dependency_chain = - [&](Node node, size_t stream) { - stream_assignment[node.index] = stream; - - if (auto child = get_last_unassigned_child(node)) { - assign_stream_to_dependency_chain(child.value(), stream); - } - }; - - while (std::optional unassigned_index = get_first_unassigned_node()) { - Node unassigned_node = graph[unassigned_index.value()]; - size_t assigned_stream = get_current_stream(); - assign_stream_to_dependency_chain(unassigned_node, assigned_stream); - } - - // next: Assign all non parallelizable ops to stream 0. - - return stream_assignment; -} - -std::optional GetAssignedStream(Operation* op) { - if (op->hasAttr("stream")) { - return op->getAttrOfType("stream").getInt(); - } - return std::nullopt; -} - -// -// Add synchronizations between assigned streams. The added custom call -// xla.streams.await() {from = A, to = [B, C, ...]} makes future work submitted -// to A wait for work that are already submitted to streams B, C, ... -// -// Pseudo code: -// For each node in the dependency graph -// If the node has a stream A assigned -// parents = A's parents -// to_streams = the assigned streams of its parents -// add xla.streams.await() {from = A, to = to_streams} before node -// -// TODO(anlunx): Handle the case where the cuda graph contains non -// parallelizable ops (cuBLAS, cuDNN). -// -void AddSynchronization(FuncOp await_op, - runtime::CustomCallDeclarations custom_calls, - const DataflowGraph& graph) { - for (const Node& node : graph) { - Operation* op = node.operation; - std::optional op_stream = GetAssignedStream(op); - if (!op_stream.has_value()) { - continue; - } - int from_stream = op_stream.value(); - - std::array dependent_streams; - dependent_streams.fill(false); - for (int i = 0; i < node.index; i++) { - if (std::find(graph[i].children.begin(), graph[i].children.end(), - node.index) != graph[i].children.end()) { - if (std::optional to_stream = - GetAssignedStream(graph[i].operation)) { - if (to_stream.value() != from_stream) { - dependent_streams[to_stream.value()] = true; - } - } - } - } - - ImplicitLocOpBuilder b(op->getLoc(), custom_calls.sym_table().getOp()); - llvm::SmallVector to_streams; - for (int i = 0; i < kNumStreams; i++) { - if (dependent_streams[i]) { - to_streams.push_back(b.getI64IntegerAttr(i)); - } - } - - if (to_streams.empty()) { - continue; - } - - b.setInsertionPoint(op); - auto call = b.create(await_op.getName(), TypeRange()); - call->setAttr(b.getStringAttr("from"), b.getI64IntegerAttr(from_stream)); - call->setAttr(b.getStringAttr("to"), b.getArrayAttr(to_streams)); - } -} - -//===----------------------------------------------------------------------===// - -void StreamAssignmentPass::runOnOperation() { - ModuleOp module = getOperation(); - SymbolTable sym_table(module); - runtime::CustomCallDeclarations custom_calls(std::move(sym_table)); - - auto func_ops = llvm::to_vector(module.getOps()); - ImplicitLocOpBuilder b(module->getLoc(), custom_calls.sym_table().getOp()); - func::FuncOp begin_marker = custom_calls.GetOrCreate( - b, "xla.gpu.concurrent_region.begin", TypeRange(), TypeRange()); - func::FuncOp end_marker = custom_calls.GetOrCreate( - b, "xla.gpu.concurrent_region.end", TypeRange(), TypeRange()); - func::FuncOp await_op = custom_calls.GetOrCreate(b, "xla.streams.await", - TypeRange(), TypeRange()); - - for (auto func_op : func_ops) { - if (!absl::StrContains(func_op.getSymNameAttr().str(), - "xla.gpu.graph.capture")) { - continue; - } - - DataflowAnalysis dataflow_analysis(func_op); - DataflowGraph graph = dataflow_analysis.GetDataflowGraph(func_op); - std::vector stream_assignment = AssignStreams(graph, kNumStreams); - - size_t stream_count = 0; - for (auto [index, stream] : llvm::enumerate(stream_assignment)) { - stream_count = std::max(stream_count, stream + 1); - Node node = graph[index]; - Operation* op = node.operation; - ImplicitLocOpBuilder b(op->getLoc(), custom_calls.sym_table().getOp()); - if (stream != -1) { - op->setAttr(b.getStringAttr("stream"), b.getI64IntegerAttr(stream)); - } - } - - AddSynchronization(await_op, custom_calls, graph); - - ImplicitLocOpBuilder b(func_op->getLoc(), custom_calls.sym_table().getOp()); - auto first_op = &(*func_op.getOps().begin()); - b.setInsertionPoint(first_op); - auto call = b.create(begin_marker.getName(), TypeRange()); - call->setAttr(b.getStringAttr("size"), b.getI64IntegerAttr(stream_count)); - - auto op_it = func_op.getOps().begin(); - while (!isa(*op_it)) { - op_it++; - } - Operation* return_op = &(*op_it); - b.setInsertionPoint(return_op); - b.create(end_marker.getName(), TypeRange()); - } -} - -} // namespace - -std::unique_ptr> createStreamAssignmentPass() { - return std::make_unique(); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/BUILD b/third_party/xla/xla/mlir/backends/gpu/transforms/tests/BUILD deleted file mode 100644 index ed7e97f057fd19..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/BUILD +++ /dev/null @@ -1,40 +0,0 @@ -load("//xla:lit.bzl", "enforce_glob", "lit_test_suite") - -package( - default_visibility = ["//visibility:public"], - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - licenses = ["notice"], -) - -lit_test_suite( - name = "all_tests", - srcs = enforce_glob( - [ - "add_concurrent_regions.mlir", - "add_hlo_trace.mlir", - "gpu_launch.mlir", - "gpu_memcpy.mlir", - "gpu_memset.mlir", - "lmhlo_case.mlir", - "lmhlo_custom_call.mlir", - "lmhlo_fft.mlir", - "lmhlo_gpu_cholesky.mlir", - "lmhlo_gpu_conv.mlir", - "lmhlo_gpu_cublas_lt_matmul.mlir", - "lmhlo_gpu_gemm.mlir", - "lmhlo_infeed.mlir", - "lmhlo_outfeed.mlir", - "lmhlo_send_recv.mlir", - "lmhlo_while.mlir", - "memref_get_global_to_arg.mlir", - "outline_cuda_graphs.mlir", - "stream_assignment.mlir", - ], - include = ["*.mlir"], - ), - cfg = "//xla:lit.cfg.py", - tools = [ - "//xla/mlir/backends/gpu:xla-gpu-opt", - "@llvm-project//llvm:FileCheck", - ], -) diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/add_concurrent_regions.mlir b/third_party/xla/xla/mlir/backends/gpu/transforms/tests/add_concurrent_regions.mlir deleted file mode 100644 index 5cf7738273d556..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/add_concurrent_regions.mlir +++ /dev/null @@ -1,348 +0,0 @@ -// RUN: xla-gpu-opt %s --split-input-file -xla-gpu-add-concurrent-regions \ -// RUN: | FileCheck %s - - -// ----- -// Check that two consecutive launch_funcs using different buffers is captured -// by a concurrent_region. - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - gpu.func @fn1(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - } - - - // CHECK: func @local_xla.gpu.graph.capture - func.func @local_xla.gpu.graph.capture(%arg0: memref<72xi8>, %arg1: memref<72xi8>, %arg2: memref<328xi8>, %arg3: memref<72xi8>, %arg4: memref<72xi8>, %arg5: memref<72xi8>, %arg6: memref<72xi8>, %arg7: memref<72xi8>, %arg8: memref<72xi8>, %arg9: memref<72xi8>) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %view = memref.view %arg0[%c0][] : memref<72xi8> to memref<3x3xi64> - %view_0 = memref.view %arg1[%c0][] : memref<72xi8> to memref<3x3xi64> - - // CHECK: call @local_xla.gpu.concurrent_region.begin() - // CHECK-NEXT: gpu.launch_func - // CHECK-NEXT: gpu.launch_func - // CHECK-NEXT: call @local_xla.gpu.concurrent_region.end() - // CHECK-NEXT: return - gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<3x3xi64>) - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>) - return - } -} - -// ----- -// Check that two consecutive launch_funcs using the same buffer is not -// captured. - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - gpu.func @fn1(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - } - - - // CHECK: func @local_xla.gpu.graph.capture - func.func @local_xla.gpu.graph.capture(%arg0: memref<72xi8>, %arg1: memref<72xi8>, %arg2: memref<328xi8>, %arg3: memref<72xi8>, %arg4: memref<72xi8>, %arg5: memref<72xi8>, %arg6: memref<72xi8>, %arg7: memref<72xi8>, %arg8: memref<72xi8>, %arg9: memref<72xi8>) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %view = memref.view %arg0[%c0][] : memref<72xi8> to memref<3x3xi64> - %view_0 = memref.view %arg0[%c0][] : memref<72xi8> to memref<3x3xi64> - - // CHECK: gpu.launch_func - // CHECK-NEXT: gpu.launch_func - // CHECK-NEXT: return - gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<3x3xi64>) - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>) - return - } -} - -// ----- -// Check that there is no dependency from launch_funcs that do not write to -// buffers. - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<3x3xi64> ) kernel { gpu.return } - gpu.func @fn1(%arg0: memref<3x3xi64> ) kernel { gpu.return } - } - - - // CHECK: func @local_xla.gpu.graph.capture - func.func @local_xla.gpu.graph.capture(%arg0: memref<72xi8>, %arg1: memref<72xi8>, %arg2: memref<328xi8>, %arg3: memref<72xi8>, %arg4: memref<72xi8>, %arg5: memref<72xi8>, %arg6: memref<72xi8>, %arg7: memref<72xi8>, %arg8: memref<72xi8>, %arg9: memref<72xi8>) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %view = memref.view %arg0[%c0][] : memref<72xi8> to memref<3x3xi64> - %view_0 = memref.view %arg0[%c0][] : memref<72xi8> to memref<3x3xi64> - - // CHECK: call @local_xla.gpu.concurrent_region.begin() - // CHECK-NEXT: gpu.launch_func - // CHECK-NEXT: gpu.launch_func - // CHECK-NEXT: call @local_xla.gpu.concurrent_region.end() - // CHECK-NEXT: return - gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<3x3xi64>) - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>) - return - } -} - -// ----- -// Check that the i1 data type is handled correctly. -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<3x3xi1> {lmhlo.written} ) kernel { gpu.return } - gpu.func @fn1(%arg0: memref<3x3xi1> {lmhlo.written} ) kernel { gpu.return } - } - - - // CHECK: func @local_xla.gpu.graph.capture - func.func @local_xla.gpu.graph.capture(%arg0: memref<72xi8>, %arg1: memref<72xi8>, %arg2: memref<328xi8>, %arg3: memref<72xi8>, %arg4: memref<72xi8>, %arg5: memref<72xi8>, %arg6: memref<72xi8>, %arg7: memref<72xi8>, %arg8: memref<72xi8>, %arg9: memref<72xi8>) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %view = memref.view %arg0[%c0][] : memref<72xi8> to memref<3x3xi1> - %view_0 = memref.view %arg0[%c0][] : memref<72xi8> to memref<3x3xi1> - - // CHECK-NOT: xla.gpu.concurrent_region.begin() - // CHECK: gpu.launch_func - // CHECK-NEXT: gpu.launch_func - // CHECK-NEXT: return - gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<3x3xi1>) - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi1>) - return - } -} - -// ----- -// Check that disjoint buffer slices does not introduce dependency. - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - gpu.func @fn1(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - } - - - // CHECK: func @local_xla.gpu.graph.capture - func.func @local_xla.gpu.graph.capture(%arg0: memref<144xi8>) { - %c0 = arith.constant 0 : index - %c72 = arith.constant 72 : index - %c1 = arith.constant 1 : index - %view = memref.view %arg0[%c0][] : memref<144xi8> to memref<3x3xi64> - %view_0 = memref.view %arg0[%c72][] : memref<144xi8> to memref<3x3xi64> - - // CHECK: call @local_xla.gpu.concurrent_region.begin() - // CHECK-NEXT: gpu.launch_func - // CHECK-NEXT: gpu.launch_func - // CHECK-NEXT: call @local_xla.gpu.concurrent_region.end() - // CHECK-NEXT: return - gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<3x3xi64>) - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>) - return - } -} - -// ----- -// Check that overlapping buffer slices creates dependency. - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - gpu.func @fn1(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - } - - - // CHECK: func @local_xla.gpu.graph.capture - func.func @local_xla.gpu.graph.capture(%arg0: memref<144xi8>) { - %c0 = arith.constant 0 : index - %c36 = arith.constant 36 : index - %c1 = arith.constant 1 : index - %view = memref.view %arg0[%c0][] : memref<144xi8> to memref<3x3xi64> - %view_0 = memref.view %arg0[%c36][] : memref<144xi8> to memref<3x3xi64> - - // CHECK: gpu.launch_func - // CHECK-NEXT: gpu.launch_func - // CHECK-NEXT: return - gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<3x3xi64>) - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>) - return - } -} - -// ----- -// Check that constant input buffer does not create dependency. - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - gpu.func @fn1(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - } - - - // CHECK: func @local_xla.gpu.graph.capture - func.func @local_xla.gpu.graph.capture(%arg0: memref<144xi8> {lmhlo.constant_name = "cst0"}) { - %c0 = arith.constant 0 : index - %c36 = arith.constant 36 : index - %c1 = arith.constant 1 : index - %view = memref.view %arg0[%c0][] : memref<144xi8> to memref<3x3xi64> - %view_0 = memref.view %arg0[%c36][] : memref<144xi8> to memref<3x3xi64> - - // CHECK: call @local_xla.gpu.concurrent_region.begin() - // CHECK-NEXT: gpu.launch_func - // CHECK-NEXT: gpu.launch_func - // CHECK-NEXT: call @local_xla.gpu.concurrent_region.end() - // CHECK-NEXT: return - gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<3x3xi64>) - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>) - return - } -} - -// ----- -// Check that two gemms that read the same buffer are moved into a concurrent -// region. - -module attributes {gpu.container_module} { - - // CHECK: func @local_xla.gpu.graph.capture - func.func @local_xla.gpu.graph.capture(%arg0: memref<16xi8>, - %arg1: memref<16xi8>, - %arg2: memref<16xi8>, - %arg3: memref<16xi8>) { - %c0 = arith.constant 0 : index - %view_0 = memref.view %arg0[%c0][] : memref<16xi8> to memref<2x2xf32> - %c1 = arith.constant 0 : index - %view_1 = memref.view %arg1[%c1][] : memref<16xi8> to memref<2x2xf32> - %c2 = arith.constant 0 : index - %view_2 = memref.view %arg2[%c2][] : memref<16xi8> to memref<2x2xf32> - %view_3 = memref.view %arg3[%c2][] : memref<16xi8> to memref<2x2xf32> - - // CHECK: call @local_xla.gpu.concurrent_region.begin() - // CHECK-NEXT: lmhlo_gpu.gemm - // CHECK-NEXT: lmhlo_gpu.gemm - // CHECK-NEXT: call @local_xla.gpu.concurrent_region.end() - // CHECK-NEXT: return - "lmhlo_gpu.gemm"(%view_0, %view_1, %view_2) {alpha_imag = 0.000000e+00 : f64, alpha_real = 1.000000e+00 : f64, beta = 0.000000e+00 : f64, batch_size = 1 : i64, lhs_stride = 4 : i64, rhs_stride = 4 : i64, dot_dimension_numbers = #mhlo.dot} : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () - "lmhlo_gpu.gemm"(%view_0, %view_1, %view_3) {alpha_imag = 0.000000e+00 : f64, alpha_real = 1.000000e+00 : f64, beta = 0.000000e+00 : f64, batch_size = 1 : i64, lhs_stride = 4 : i64, rhs_stride = 4 : i64, dot_dimension_numbers = #mhlo.dot} : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () - return - } - - func.func private @external() -} - -// ----- -// Check that lmhlo_gpu.gemm is not moved into the concurrent region if it -// uses a buffer used by a kernel launch. - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<16xi8> {lmhlo.written} ) kernel { gpu.return } - } - - // CHECK: func @local_xla.gpu.graph.capture - func.func @local_xla.gpu.graph.capture(%arg0: memref<16xi8>, - %arg1: memref<16xi8>, - %arg2: memref<16xi8>) { - %c0 = arith.constant 0 : index - %view_0 = memref.view %arg0[%c0][] : memref<16xi8> to memref<2x2xf32> - %c1 = arith.constant 0 : index - %view_1 = memref.view %arg1[%c1][] : memref<16xi8> to memref<2x2xf32> - %c2 = arith.constant 0 : index - %view_2 = memref.view %arg2[%c2][] : memref<16xi8> to memref<2x2xf32> - - // CHECK-NOT: @local_xla.gpu.concurrent_region.begin() - // CHECK: lmhlo_gpu.gemm - // CHECK-NEXT: gpu.launch_func - // CHECK-NEXT: return - "lmhlo_gpu.gemm"(%view_0, %view_1, %view_2) {alpha_imag = 0.000000e+00 : f64, alpha_real = 1.000000e+00 : f64, beta = 0.000000e+00 : f64, batch_size = 1 : i64, lhs_stride = 4 : i64, rhs_stride = 4 : i64, dot_dimension_numbers = #mhlo.dot} : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () - gpu.launch_func @gpu_module::@fn0 blocks in (%c0, %c0, %c0) - threads in (%c0, %c0, %c0) args(%arg0: memref<16xi8>) - return - } - - func.func private @external() -} - -// ----- -// Check that memcpies are added to concurrent regions. - -module attributes {gpu.container_module} { - - // CHECK: func @local_xla.gpu.graph.capture - func.func @local_xla.gpu.graph.capture(%arg0: memref<16xi8>, - %arg1: memref<16xi8>, - %arg2: memref<16xi8>) { - %c0 = arith.constant 0 : index - %view_0 = memref.view %arg0[%c0][] : memref<16xi8> to memref<2x2xf32> - %c1 = arith.constant 0 : index - %view_1 = memref.view %arg1[%c1][] : memref<16xi8> to memref<2x2xf32> - %c2 = arith.constant 0 : index - %view_2 = memref.view %arg2[%c2][] : memref<16xi8> to memref<2x2xf32> - - // CHECK: @local_xla.gpu.concurrent_region.begin() - // CHECK-NEXT: gpu.memcpy - // CHECK-NEXT: gpu.memcpy - // CHECK-NEXT: @local_xla.gpu.concurrent_region.end() - // CHECK-NEXT: return - gpu.memcpy %view_1, %view_0 : memref<2x2xf32>, memref<2x2xf32> - gpu.memcpy %view_2, %view_0 : memref<2x2xf32>, memref<2x2xf32> - return - } - - func.func private @external() -} - -// ----- -// Check that region size is set correctly. - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - gpu.func @fn1(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - } - - - // CHECK: func @local_xla.gpu.graph.capture - func.func @local_xla.gpu.graph.capture(%arg0: memref<72xi8>, %arg1: memref<72xi8>, %arg2: memref<328xi8>, %arg3: memref<72xi8>, %arg4: memref<72xi8>, %arg5: memref<72xi8>, %arg6: memref<72xi8>, %arg7: memref<72xi8>, %arg8: memref<72xi8>, %arg9: memref<72xi8>) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %view = memref.view %arg0[%c0][] : memref<72xi8> to memref<3x3xi64> - %view_0 = memref.view %arg1[%c0][] : memref<72xi8> to memref<3x3xi64> - - // CHECK: call @local_xla.gpu.concurrent_region.begin() {size = 2 : i64} - // CHECK-NEXT: gpu.launch_func - // CHECK-NEXT: memref.view - // CHECK-NEXT: gpu.launch_func - // CHECK-NEXT: call @local_xla.gpu.concurrent_region.end() - // CHECK-NEXT: return - gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<3x3xi64>) - %view_1 = memref.view %arg1[%c0][] : memref<72xi8> to memref<3x3xi64> - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>) - return - } -} diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/add_hlo_trace.mlir b/third_party/xla/xla/mlir/backends/gpu/transforms/tests/add_hlo_trace.mlir deleted file mode 100644 index 864a4b0433ddc1..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/add_hlo_trace.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: xla-gpu-opt %s -xla-add-hlo-trace-annotations | FileCheck %s - -module attributes { mhlo.unique_id = 42 : i64 } { - -func.func private @local_xla.foo() attributes { rt.custom_call = "xla.foo" } - -// CHECK: func @func() { -func.func @func() { - // CHECK: call @local_xla.foo() - // CHECK-SAME: rt.trace = #rt.hlo_trace<"gemm.name.42"> - call @local_xla.foo() : () -> () loc("gemm.name.42") - return -} - -} loc("module-name") diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/gpu_launch.mlir b/third_party/xla/xla/mlir/backends/gpu/transforms/tests/gpu_launch.mlir deleted file mode 100644 index e05ff982bb3c48..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/gpu_launch.mlir +++ /dev/null @@ -1,64 +0,0 @@ -// RUN: xla-gpu-opt %s -xla-gpu-to-gpu-runtime | FileCheck %s - -module attributes {gpu.container_module} { - -// CHECK-NOT: gpu.module -gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<4x4xf32>, %arg1: memref<4x4xf32>) kernel { - gpu.return - } - gpu.func @fn1(%arg0: memref<4x4xf32>, %arg1: memref<4x4xf32>) kernel { - gpu.return - } -} - -// CHECK: @func( -// CHECK: %[[ARG0:.*]]: memref<4x4xf32>, -// CHECK: %[[ARG1:.*]]: memref<4x4xf32> -// CHECK: ) -func.func @func(%arg0: memref<4x4xf32>, %arg1: memref<4x4xf32>) { - // Launch dimensions converted to i32 as a part of the lowering. - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32 - // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : i32 - // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : i32 - // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : i32 - // CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32 - // CHECK-DAG: %[[C6:.*]] = arith.constant 6 : i32 - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32 - // CHECK-DAG: %[[C256:.*]] = arith.constant 256 : i32 - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c3 = arith.constant 3 : index - %c4 = arith.constant 4 : index - %c5 = arith.constant 5 : index - %c6 = arith.constant 6 : index - %c256 = arith.constant 256 : i32 - - // CHECK: call @[[LAUNCH:[_a-z.]+]](%[[C0]], %[[C1]], %[[C2]], %[[C3]], - // CHECK-SAME: %[[C4]], %[[C5]], %[[C6]], %[[ARG0]], %[[ARG1]]) - // CHECK-SAME: kernel = "fn0" - gpu.launch_func @gpu_module::@fn0 - blocks in (%c1, %c2, %c3) - threads in (%c4, %c5, %c6) - args(%arg0 : memref<4x4xf32>, %arg1 : memref<4x4xf32>) - - // CHECK: call @[[LAUNCH]](%[[C256]], %[[C3]], %[[C2]], %[[C1]], %[[C6]], - // CHECK-SAME: %[[C5]], %[[C4]], %[[ARG0]], %[[ARG1]]) - // CHECK-DAG: kernel = "fn1" - gpu.launch_func @gpu_module::@fn1 - blocks in (%c3, %c2, %c1) - threads in (%c6, %c5, %c4) - dynamic_shared_memory_size %c256 - args(%arg0 : memref<4x4xf32>, %arg1 : memref<4x4xf32>) - - func.return -} - -// CHECK: func private @[[LAUNCH]](i32, i32, i32, i32, i32, i32, -// CHECK-SAME: memref<4x4xf32>, memref<4x4xf32>) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.func.launch"} - -// Check that we have a single custom call declaration in the module. -// CHECK-NOT: rt.custom_call - -} diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/gpu_memcpy.mlir b/third_party/xla/xla/mlir/backends/gpu/transforms/tests/gpu_memcpy.mlir deleted file mode 100644 index 410a94c489ed16..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/gpu_memcpy.mlir +++ /dev/null @@ -1,46 +0,0 @@ -// RUN: xla-gpu-opt %s --split-input-file -xla-gpu-to-gpu-runtime | FileCheck %s - -// CHECK: func @gpu_memcpy_d2d( -// CHECK: %[[DST:[a-z0-9]+]]: memref, -// CHECK: %[[SRC:[a-z0-9]+]]: memref -// CHECK: ) -func.func @gpu_memcpy_d2d(%dst: memref, %src: memref) { - // CHECK: call @[[MEMCPY:.*]](%[[DST]], %[[SRC]]) - gpu.memcpy %dst, %src : memref, memref - return -} - -// CHECK: func private @[[MEMCPY]](memref, memref) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.memcpy.d2d"} - -// ----- - -// CHECK: func @gpu_memcpy_h2d( -// CHECK: %[[DST:[a-z0-9]+]]: memref -// CHECK: ) -func.func @gpu_memcpy_h2d(%dst: memref, %dim: index) { - // CHECK: %[[SRC:.*]] = memref.alloca - %src = memref.alloca(%dim) : memref - // CHECK: call @[[MEMCPY:.*]](%[[DST]], %[[SRC]]) - gpu.memcpy %dst, %src : memref, memref - return -} - -// CHECK: func private @[[MEMCPY]](memref, memref) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.memcpy.h2d"} - -// ----- - -// CHECK: func @gpu_memcpy_d2h( -// CHECK: %[[SRC:[a-z0-9]+]]: memref -// CHECK: ) -func.func @gpu_memcpy_d2h(%src: memref, %dim: index) { - // CHECK: %[[DST:.*]] = memref.alloca - %dst = memref.alloca(%dim) : memref - // CHECK: call @[[MEMCPY:.*]](%[[DST]], %[[SRC]]) - gpu.memcpy %dst, %src : memref, memref - return -} - -// CHECK: func private @[[MEMCPY]](memref, memref) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.memcpy.d2h"} diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/gpu_memset.mlir b/third_party/xla/xla/mlir/backends/gpu/transforms/tests/gpu_memset.mlir deleted file mode 100644 index 33b2232fc8519d..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/gpu_memset.mlir +++ /dev/null @@ -1,31 +0,0 @@ -// RUN: xla-gpu-opt %s --split-input-file -xla-gpu-to-gpu-runtime | FileCheck %s - -// CHECK: func @gpu_memset_i32( -// CHECK: %[[DST:[a-z0-9]+]]: memref -// CHECK: ) -func.func @gpu_memset_i32(%dst: memref) { - // CHECK: %[[CST:.*]] = arith.constant 0 : i32 - %cst = arith.constant 0 : i32 - // CHECK: call @[[MEMSET:.*]](%[[DST]], %[[CST]]) - gpu.memset %dst, %cst : memref, i32 - return -} - -// CHECK: func private @[[MEMSET]](memref, i32) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.memset"} - -// ----- - -// CHECK: func @gpu_memset_f32( -// CHECK: %[[DST:[a-z0-9]+]]: memref -// CHECK: ) -func.func @gpu_memset_f32(%dst: memref) { - // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 - %cst = arith.constant 0.000000e+00 : f32 - // CHECK: call @[[MEMSET:.*]](%[[DST]], %[[CST]]) - gpu.memset %dst, %cst : memref, f32 - return -} - -// CHECK: func private @[[MEMSET]](memref, f32) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.memset"} diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_case.mlir b/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_case.mlir deleted file mode 100644 index 9e4120cda6328e..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_case.mlir +++ /dev/null @@ -1,116 +0,0 @@ -// RUN: xla-gpu-opt %s -xla-lmhlo-to-gpu-runtime | FileCheck %s - -module attributes {gpu.container_module} { - memref.global "private" constant @constant : memref = dense<0> - - gpu.module @case0 attributes {binary = "ptx"} { - gpu.func @fn(%arg0: memref) kernel { - gpu.return - } - } - - gpu.module @case1 attributes {binary = "ptx"} { - gpu.func @fn(%arg0: memref) kernel { - gpu.return - } - } - - // CHECK: @case_true_false( - // CHECK-SAME: %[[ARG0:.*]]: memref, - // CHECK-SAME: %[[ARG1:.*]]: memref - // CHECK-SAME: ) - func.func @case_true_false(%arg0: memref, %arg1: memref) { - %c1 = arith.constant 1 : index - - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32 - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32 - - // CHECK: %[[HOST:.*]] = memref.alloca() : memref - // CHECK: gpu.memcpy %[[HOST]], %[[ARG1]] - - // CHECK: %[[PRED:.*]] = memref.load %[[HOST]][] : memref - // CHECK: %[[IDX:.*]] = arith.select %[[PRED]], %[[C0]], %[[C1]] - - // CHECK: scf.execute_region - // CHECK: cf.switch %[[IDX]] : i32 - // CHECK: default: ^[[YIELD:.*]], - // CHECK: 0: ^[[CASE0:.*]], - // CHECK: 1: ^[[CASE1:.*]] - "lmhlo.case"(%arg1) ({ - gpu.launch_func @case0::@fn blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg0 : memref) - "lmhlo.terminator"() : () -> () - }, { - gpu.launch_func @case1::@fn blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg0 : memref) - "lmhlo.terminator"() : () -> () - }) : (memref) -> () - - // CHECK: ^[[CASE0]]: - // CHECK: gpu.launch_func @case0::@fn - // CHECK: cf.br ^[[YIELD]] - - // CHECK: ^[[CASE1]]: - // CHECK: gpu.launch_func @case1::@fn - // CHECK: cf.br ^[[YIELD]] - - // CHECK: ^[[YIELD]]: - // CHECK-NEXT: scf.yield - - // CHECK: return - "lmhlo.terminator"() : () -> () - } - - // CHECK: @case_index( - // CHECK-SAME: %[[ARG0:.*]]: memref, - // CHECK-SAME: %[[ARG1:.*]]: memref - // CHECK-SAME: ) - func.func @case_index(%arg0: memref, %arg1: memref) { - %c1 = arith.constant 1 : index - - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32 - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32 - - // CHECK: %[[HOST:.*]] = memref.alloca() : memref - // CHECK: gpu.memcpy %[[HOST]], %[[ARG1]] - - // CHECK: %[[PRED:.*]] = memref.load %[[HOST]][] : memref - // CHECK: %[[SMALL:.*]] = arith.cmpi slt, %[[PRED]], %[[C0]] : i32 - // CHECK: %[[LARGE:.*]] = arith.cmpi sgt, %[[PRED]], %[[C1]] : i32 - // CHECK: %[[OOR:.*]] = arith.ori %[[SMALL]], %[[LARGE]] : i1 - // CHECK: %[[IDX:.*]] = arith.select %[[OOR]], %[[C1]], %[[PRED]] : i32 - - // CHECK: scf.execute_region - // CHECK: cf.switch %[[IDX]] : i32 - // CHECK: default: ^[[YIELD:.*]], - // CHECK: 0: ^[[CASE0:.*]], - // CHECK: 1: ^[[CASE1:.*]] - "lmhlo.case"(%arg1) ({ - gpu.launch_func @case0::@fn blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg0 : memref) - "lmhlo.terminator"() : () -> () - }, { - gpu.launch_func @case1::@fn blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg0 : memref) - "lmhlo.terminator"() : () -> () - }) : (memref) -> () - - // CHECK: ^[[CASE0]]: - // CHECK: gpu.launch_func @case0::@fn - // CHECK: cf.br ^[[YIELD]] - - // CHECK: ^[[CASE1]]: - // CHECK: gpu.launch_func @case1::@fn - // CHECK: cf.br ^[[YIELD]] - - // CHECK: ^[[YIELD]]: - // CHECK-NEXT: scf.yield - - // CHECK: return - "lmhlo.terminator"() : () -> () - } -} diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_custom_call.mlir b/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_custom_call.mlir deleted file mode 100644 index 6b333f90f77dc3..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_custom_call.mlir +++ /dev/null @@ -1,64 +0,0 @@ -// RUN: xla-gpu-opt %s -split-input-file -xla-lmhlo-to-gpu-runtime \ -// RUN: | FileCheck %s - -// CHECK: func @test -// CHECK: %[[ARG0:.*]]: memref -// CHECK: ) -func.func @test(%arg0: memref) { - // CHECK: call @[[CUSTOM_CALL:.*]](%[[ARG0]]) - // CHECK-SAME: api_version = 2 : i32 - // CHECK-SAME: backend_config = "" - // CHECK-SAME: call_target_name = "target" - // CHECK-SAME: : (memref) -> () - "lmhlo.custom_call"(%arg0) ({}) { - api_version = 2 : i32, - backend_config = "", - call_target_name = "target", - operandSegmentSizes = array - } : (memref) -> () - return -} - -// CHECK: func.func private @[[CUSTOM_CALL]](memref) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.custom_call"} - -// ----- - -// CHECK: func @test_with_mapping -// CHECK: %[[ARG0:[0-9a-z]*]]: memref, -// CHECK: %[[ARG1:[0-9a-z]*]]: memref, -// CHECK: %[[ARG2:[0-9a-z]*]]: memref, -// CHECK: %[[ARG3:[0-9a-z]*]]: memref, -// CHECK: %[[ARG4:[0-9a-z]*]]: memref -// CHECK: ) -func.func @test_with_mapping( - %arg0: memref, - %arg1: memref, - %arg2: memref, - %arg3: memref, - %arg4: memref) { - // CHECK: %[[HOLE:.*]] = arith.constant -1 : i64 - - // CHECK: call @[[CUSTOM_CALL:.*]](%[[ARG0]], %[[HOLE]], %[[ARG1]], %[[HOLE]], - // CHECK-SAME: %[[ARG2]], %[[ARG3]], %[[HOLE]], %[[ARG4]]) - // CHECK-SAME: api_version = 1 : i32 - // CHECK-SAME: backend_config = "" - // CHECK-SAME: call_target_name = "target" - "lmhlo.custom_call"(%arg0, %arg1, %arg2, %arg3, %arg4) ({}) { - api_version = 1 : i32, - backend_config = "", - call_target_name = "target", - operandSegmentSizes = array, - target_arg_mapping = #lmhlo.custom_call_target_arg_mapping< - num_args = 4, - num_results = 4, - args_to_target_args = [0, 2], - results_to_target_results = [0, 1, 3]> - } : (memref, memref, memref, memref, memref) -> () - - return -} - -// CHECK: func.func private @[[CUSTOM_CALL]](memref, i64, memref, i64, -// CHECK-SAME: memref, memref, i64, memref) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.custom_call"} diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_fft.mlir b/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_fft.mlir deleted file mode 100644 index aeb19228d013e3..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_fft.mlir +++ /dev/null @@ -1,25 +0,0 @@ -// RUN: xla-gpu-opt %s -xla-lmhlo-to-gpu-runtime | FileCheck %s - -// CHECK: @compute( -// CHECK: %[[ARG0:[a-z0-9]+]]: memref<3x5x16x5xcomplex -// CHECK: %[[ARG1:[a-z0-9]+]]: memref<3x5x16x8xf32> -// CHECK: ) -func.func @compute(%arg0: memref<3x5x16x5xcomplex>, - %arg1: memref<3x5x16x8xf32>) { - - // CHECK: call @[[FFT:.*]](%[[ARG0]], %[[ARG1]]) - // CHECK-SAME: fft_length = dense<[16, 8]> : tensor<2xi64> - // CHECK-SAME: fft_type = #mhlo - // CHECK-SAME: uid = 0 : i64 - "lmhlo.fft"(%arg0, %arg1) { - fft_length = dense<[16, 8]> : tensor<2xi64>, - fft_type = #mhlo - } : (memref<3x5x16x5xcomplex>, memref<3x5x16x8xf32>) -> () - - // CHECK-NEXT: return - func.return -} - -// CHECK: func private @[[FFT]](memref<3x5x16x5xcomplex>, -// CHECK-SAME: memref<3x5x16x8xf32>) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.fft"} diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_cholesky.mlir b/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_cholesky.mlir deleted file mode 100644 index 9dbc0ee2eb260c..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_cholesky.mlir +++ /dev/null @@ -1,28 +0,0 @@ -// RUN: xla-gpu-opt %s -xla-lmhlo-gpu-to-gpu-runtime | FileCheck %s - -// CHECK: @compute( -// CHECK: %[[ARG0:[a-z0-9]+]]: memref<4x4xi32> -// CHECK: %[[ARG1:[a-z0-9]+]]: memref<4x4xi32> -// CHECK: %[[ARG2:[a-z0-9]+]]: memref<4x4xi32> -// CHECK: %[[ARG3:[a-z0-9]+]]: memref<4x4xi32> -// CHECK: ) -func.func @compute(%operand: memref<4x4xi32>, %a: memref<4x4xi32>, - %workspace: memref<4x4xi32>, %info: memref<4x4xi32>) { - - // CHECK: call @[[CHOLESKY:.*]](%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]]) - // CHECK-SAME: batch_size = 1 : i64 - // CHECK-SAME: is_lower = true - // CHECK-SAME: n = 4 : i64 - "lmhlo_gpu.cholesky"(%operand, %a, %workspace, %info) { - batch_size = 1 : i64, - is_lower = true, - n = 4 : i64 - } : (memref<4x4xi32>, memref<4x4xi32>, memref<4x4xi32>, memref<4x4xi32>) -> () - - // CHECK-NEXT: return - func.return -} - -// CHECK: func private @[[CHOLESKY]](memref<4x4xi32>, memref<4x4xi32>, -// CHECK-SAME: memref<4x4xi32>, memref<4x4xi32>) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.cholesky"} diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_conv.mlir b/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_conv.mlir deleted file mode 100644 index e1173d06d707f7..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_conv.mlir +++ /dev/null @@ -1,380 +0,0 @@ -// RUN: xla-gpu-opt %s -split-input-file -xla-lmhlo-gpu-to-gpu-runtime \ -// RUN: | FileCheck %s - -#map0 = affine_map<(d0, d1, d2, d3) -> (d0 * 3 + d1 + d2 * 9 + d3 * 9)> -#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * 16384 + d1 * 4 + d2 + d3 * 16)> -#map2 = affine_map<(d0, d1, d2, d3) -> (d0 * 4096 + d1 * 2 + d2 + d3 * 4)> - -// CHECK: @conv_forward( -// CHECK: %[[INPUT:[a-z0-9]+]]: memref -// CHECK: %[[FILTER:[a-z0-9]+]]: memref -// CHECK: %[[OUTPUT:[a-z0-9]+]]: memref -// CHECK: %[[SCRATCH:[a-z0-9]+]]: memref -// CHECK: ) -func.func @conv_forward(%input: memref<1x4x4x1024xf16, #map1>, - %filter: memref<3x3x1x1024xf16, #map0>, - %output: memref<1x2x2x1024xf16, #map2>, - %scratch: memref<0xui8>) { - - // CHECK: call @local_xla.gpu.conv.forward( - // CHECK-SAME: %[[INPUT]], %[[FILTER]], %[[OUTPUT]], %[[SCRATCH]]) - - // CHECK-DAG: uid = 0 : i64 - // CHECK-DAG: conv_dims = #mhlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]> - - // CHECK-DAG: window_strides = dense<1> : tensor<2xi64> - // CHECK-DAG: lhs_dilation = dense<1> : tensor<2xi64> - // CHECK-DAG: rhs_dilation = dense<1> : tensor<2xi64> - // CHECK-DAG: window_reversal = dense<0> : tensor<2xi64> - // CHECK-DAG: padding = dense<> : tensor<0xi64> - - // CHECK-DAG: backend_config = #lmhlo_gpu.convolution_backend_config< - // CHECK-DAG: algorithm = 0 - // CHECK-DAG: is_cudnn_frontend = true - // CHECK-DAG: knob_ids = [] - // CHECK-DAG: knob_values = [] - // CHECK-DAG: operand_0_layout = [2, 1, 3, 0] - // CHECK-DAG: operand_1_layout = [1, 0, 2, 3] - // CHECK-DAG: tensor_ops_enabled = false - // CHECK-DAG: workspace_size = 0 - - // CHECK-DAG: feature_group_count = 1024 : i64 - // CHECK-DAG: result_scale = 1.000000e+00 : f64 - lmhlo_gpu.conv_forward(%input, %filter, %output, %scratch) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = { stride = [1, 1], - lhs_dilate = [1, 1], - rhs_dilate = [1, 1], - reverse = [0, 0] - } - { backend_config = #lmhlo_gpu.convolution_backend_config< - algorithm = 0, - is_cudnn_frontend = true, - is_cudnn_reordered_int8 = false, - knob_ids = [], - knob_values = [], - operand_0_layout = [2, 1, 3, 0], - operand_1_layout = [1, 0, 2, 3], - result_layout = [2, 1, 3, 0], - tensor_ops_enabled = false, - workspace_size = 0 - >, - batch_group_count = 1 : i64, - feature_group_count = 1024 : i64, - precision_config = [], - result_scale = 1.000000e+00 : f64 - } : (memref<1x4x4x1024xf16, #map1>, - memref<3x3x1x1024xf16, #map0>, - memref<1x2x2x1024xf16, #map2>, - memref<0xui8>) -> () - - return -} - -// CHECK: func private @local_xla.gpu.conv.forward( -// CHECK-SAME: memref<1x4x4x1024xf16, #map{{[0-9]*}}>, memref<3x3x1x1024xf16, #map{{[0-9]*}}>, -// CHECK-SAME: memref<1x2x2x1024xf16, #map{{[0-9]*}}>, memref<0xui8>) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.conv.forward"} - -// ----- - -#map0 = affine_map<(d0, d1, d2, d3) -> (d0 * 9 + d1 * 3 + d2 + d3 * 9)> -#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * 3 + d1 + d2 * 27 + d3 * 9)> -#map2 = affine_map<(d0, d1, d2, d3) -> (d0 + d1 + d2 + d3)> - -// CHECK: @conv_backwardfilter( -// CHECK: %[[INPUT:[a-z0-9]+]]: memref -// CHECK: %[[D_OUTPUT:[a-z0-9]+]]: memref -// CHECK: %[[D_FILTER:[a-z0-9]+]]: memref -// CHECK: %[[SCRATCH:[a-z0-9]+]]: memref -// CHECK: ) -func.func @conv_backwardfilter(%input: memref<1x3x3x5xf16, #map0>, - %d_output: memref<3x3x5x3xf16, #map1>, - %d_filter: memref<1x1x1x3xf16, #map2>, - %scratch: memref<0xui8>) { - // CHECK: call @local_xla.gpu.conv.backward.filter( - // CHECK-SAME: %[[INPUT]], %[[D_OUTPUT]], %[[D_FILTER]], %[[SCRATCH]]) - lmhlo_gpu.conv_backwardfilter(%input, %d_output, %d_filter, %scratch) - dim_numbers = [f, 0, 1, b]x[i, 0, 1, o]->[0, 1, b, f], - window = { stride = [1, 1], - lhs_dilate = [1, 1], - rhs_dilate = [1, 1], - reverse = [0, 0] - } - { backend_config = #lmhlo_gpu.convolution_backend_config< - algorithm = 0, - is_cudnn_frontend = true, - is_cudnn_reordered_int8 = false, - knob_ids = [], - knob_values = [], - operand_0_layout = [2, 1, 0, 3], - operand_1_layout = [1, 0, 3, 2], - result_layout = [2, 1, 0, 3], - tensor_ops_enabled = false, - workspace_size = 0 - >, - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [], - result_scale = 1.000000e+00 : f64 - } : (memref<1x3x3x5xf16, #map0>, - memref<3x3x5x3xf16, #map1>, - memref<1x1x1x3xf16, #map2>, - memref<0xui8>) -> () - return -} - -// CHECK: func private @local_xla.gpu.conv.backward.filter( -// CHECK-SAME: memref<1x3x3x5xf16, #map{{[0-9]*}}>, memref<3x3x5x3xf16, #map{{[0-9]*}}>, -// CHECK-SAME: memref<1x1x1x3xf16, #map{{[0-9]*}}>, memref<0xui8> -// CHECK-SAME: ) attributes {rt.custom_call = -// CHECK-SAME: "xla.gpu.conv.backward.filter"} - -// ----- - -// CHECK: @conv_backwardinput( -// CHECK: %[[D_OUTPUT:[a-z0-9]+]]: memref -// CHECK: %[[FILTER:[a-z0-9]+]]: memref -// CHECK: %[[D_INPUT:[a-z0-9]+]]: memref -// CHECK: %[[SCRATCH:[a-z0-9]+]]: memref -// CHECK: ) -func.func @conv_backwardinput(%d_output: memref<4x5x16x16xf64>, - %filter: memref<5x3x7x7xf64>, - %d_input: memref<4x3x16x16xf64>, - %scratch: memref<0xui8>) { - // CHECK: call @local_xla.gpu.conv.backward.input( - // CHECK-SAME: %[[D_OUTPUT]], %[[FILTER]], %[[D_INPUT]], %[[SCRATCH]]) - lmhlo_gpu.conv_backwardinput(%d_output, %filter, %d_input, %scratch) - dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], - window = { stride = [1, 1], - lhs_dilate = [1, 1], - rhs_dilate = [1, 1], - reverse = [0, 0] - } - { backend_config = #lmhlo_gpu.convolution_backend_config< - algorithm = 2, - is_cudnn_frontend = true, - is_cudnn_reordered_int8 = false, - knob_ids = [3, 2], - knob_values = [0, 3], - operand_0_layout = [3, 2, 1, 0], - operand_1_layout = [3, 2, 1, 0], - result_layout = [3, 2, 1, 0], - tensor_ops_enabled = false, - workspace_size = 0 - >, - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [], - result_scale = 1.000000e+00 : f64 - } : (memref<4x5x16x16xf64>, - memref<5x3x7x7xf64>, - memref<4x3x16x16xf64>, - memref<0xui8>) -> () - return -} - -// CHECK: func private @local_xla.gpu.conv.backward.input( -// CHECK-SAME: memref<4x5x16x16xf64>, memref<5x3x7x7xf64>, -// CHECK-SAME: memref<4x3x16x16xf64>, memref<0xui8> -// CHECK-SAME: ) attributes {rt.custom_call = -// CHECK-SAME: "xla.gpu.conv.backward.input"} - -// ----- - -#map0 = affine_map<(d0, d1, d2, d3) -> (d0 * 3 + d1 + d2 * 9 + d3 * 9)> -#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * 25 + d1 * 5 + d2 + d3 * 25)> -#map2 = affine_map<(d0, d1, d2, d3) -> (d0 * 800 + d1 * 5 + d2 + d3 * 25)> - -// CHECK: @conv_forward_fused( -// CHECK: %[[INPUT:[a-z0-9]+]]: memref -// CHECK: %[[FILTER:[a-z0-9]+]]: memref -// CHECK: %[[BIAS:[a-z0-9]+]]: memref -// CHECK: %[[OUTPUT:[a-z0-9]+]]: memref -// CHECK: %[[SCRATCH:[a-z0-9]+]]: memref -// CHECK: ) -func.func @conv_forward_fused(%input: memref<8x5x5x1xf32, #map1>, - %filter: memref<3x3x1x32xf32, #map0>, - %bias: memref<32xf32>, - %output: memref<8x5x5x32xf32, #map2>, - %scratch: memref<0xui8>) { - // CHECK: call @local_xla.gpu.conv.forward.fused( - // CHECK-SAME: %[[INPUT]], %[[FILTER]], %[[BIAS]], %[[OUTPUT]], %[[SCRATCH]]) - - // CHECK-DAG: activation_mode = #lmhlo_gpu - // CHECK-DAG: knob_ids = [2, 3] - // CHECK-DAG: knob_values = [4, 0] - lmhlo_gpu.conv_forward_fused(%input, %filter, %bias, %output, %scratch) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = { stride = [1, 1], - lhs_dilate = [1, 1], - rhs_dilate = [1, 1], - reverse = [0, 0] - } - { activation_mode = #lmhlo_gpu, - leakyrelu_alpha = 0.0 : f64, - backend_config = #lmhlo_gpu.convolution_backend_config< - algorithm = 11, - is_cudnn_frontend = true, - is_cudnn_reordered_int8 = false, - knob_ids = [2, 3], - knob_values = [4, 0], - operand_0_layout = [2, 1, 3, 0], - operand_1_layout = [1, 0, 2, 3], - result_layout = [2, 1, 3, 0], - tensor_ops_enabled = false, - workspace_size = 0 - >, - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [], - result_scale = 1.000000e+00 : f64 - } : (memref<8x5x5x1xf32, #map1>, - memref<3x3x1x32xf32, #map0>, - memref<32xf32>, - memref<8x5x5x32xf32, #map2>, - memref<0xui8>) -> () - - return -} - -// CHECK: func private @local_xla.gpu.conv.forward.fused( -// CHECK-SAME: memref<8x5x5x1xf32, #map{{[0-9]*}}>, memref<3x3x1x32xf32, #map{{[0-9]*}}>, -// CHECK-SAME: memref<32xf32>, memref<8x5x5x32xf32, #map{{[0-9]*}}>, memref<0xui8> -// CHECK-SAME: ) attributes {rt.custom_call = -// CHECK-SAME: "xla.gpu.conv.forward.fused"} - -// ----- - -#map0 = affine_map<(d0, d1, d2, d3) -> (d0 * 576 + d1 * 3 + d2 + d3 * 9)> -#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * 3 + d1 + d2 * 9 + d3 * 576)> - -// CHECK: @conv_forward_fused_with_side_input( -// CHECK: %[[INPUT:[a-z0-9]+]]: memref -// CHECK: %[[FILTER:[a-z0-9]+]]: memref -// CHECK: %[[BIAS:[a-z0-9]+]]: memref -// CHECK: %[[SIDE_INPUT:[a-z0-9]+]]: memref -// CHECK: %[[OUTPUT:[a-z0-9]+]]: memref -// CHECK: %[[SCRATCH:[a-z0-9]+]]: memref -// CHECK: ) -func.func @conv_forward_fused_with_side_input( - %input: memref<1x3x3x64xf64, #map0>, - %filter: memref<3x3x64x64xf64, #map1>, - %bias: memref<64xf64>, - %side_input: memref<1x3x3x64xf64, #map0>, - %output: memref<1x3x3x64xf64, #map0>, - %scratch: memref<0xui8>) { - - // CHECK: call @local_xla.gpu.conv.forward.fused.side_input( - // CHECK-SAME: %[[INPUT]], %[[FILTER]], %[[BIAS]], %[[SIDE_INPUT]], - // CHECK-SAME: %[[OUTPUT]], %[[SCRATCH]]) - - // CHECK-DAG: activation_mode = #lmhlo_gpu - // CHECK-DAG: side_input_scale = 1.000000e+00 : f64 - lmhlo_gpu.conv_forward_fused_with_side_input( - %input, %filter, %bias, %side_input, %output, %scratch) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = { stride = [1, 1], - lhs_dilate = [1, 1], - rhs_dilate = [1, 1], - reverse = [0, 0] - } - { activation_mode = #lmhlo_gpu, - backend_config = #lmhlo_gpu.convolution_backend_config< - algorithm = 0, - is_cudnn_frontend = true, - is_cudnn_reordered_int8 = false, - knob_ids = [], - knob_values = [], - operand_0_layout = [2, 1, 3, 0], - operand_1_layout = [1, 0, 2, 3], - result_layout = [2, 1, 3, 0], - tensor_ops_enabled = false, - workspace_size = 0 - >, - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [], - result_scale = 1.000000e+00 : f64, - side_input_scale = 1.000000e+00 : f64 - } : (memref<1x3x3x64xf64, #map0>, - memref<3x3x64x64xf64, #map1>, - memref<64xf64>, - memref<1x3x3x64xf64, #map0>, - memref<1x3x3x64xf64, #map0>, - memref<0xui8>) -> () - - return -} - -// CHECK: func private @local_xla.gpu.conv.forward.fused.side_input( -// CHECK-SAME: memref<1x3x3x64xf64, #map{{[0-9]*}}>, memref<3x3x64x64xf64, #map{{[0-9]*}}>, -// CHECK-SAME: memref<64xf64>, memref<1x3x3x64xf64, #map{{[0-9]*}}>, -// CHECK-SAME: memref<1x3x3x64xf64, #map{{[0-9]*}}>, memref<0xui8> -// CHECK-SAME: ) attributes {rt.custom_call = -// CHECK-SAME: "xla.gpu.conv.forward.fused.side_input"} - -// ----- - -#map0 = affine_map<(d0, d1, d2, d3, d4) -> (d0 + d1 + d2 + d3 * 3 + d4 * 9)> - -// CHECK: @conv_reorder_filter( -// CHECK: %[[INPUT:[a-z0-9]+]]: memref -// CHECK: %[[OUTPUT:[a-z0-9]+]]: memref -// CHECK: ) -func.func @conv_reorder_filter( - %input: memref<1x1x3x3x32xi8, #map0>, - %output: memref<1x1x3x3x32xi8, #map0>) { - - // CHECK: call @local_xla.gpu.conv.reorder.filter( - // CHECK-SAME: %[[INPUT]], %[[OUTPUT]] - // CHECK-DAG: filter_dims = array - "lmhlo_gpu.cudnn_conv_reorder_filter"(%input, %output) { - filter_dims = dense<[1, 32, 3, 3]> : tensor<4xi64> - }: (memref<1x1x3x3x32xi8, #map0>, - memref<1x1x3x3x32xi8, #map0>) -> () - - return -} - -// CHECK: func private @local_xla.gpu.conv.reorder.filter( -// CHECK-SAME: memref<1x1x3x3x32xi8, #map{{[0-9]*}}>, -// CHECK-SAME: memref<1x1x3x3x32xi8, #map{{[0-9]*}}> -// CHECK-SAME: ) attributes {rt.custom_call = -// CHECK-SAME: "xla.gpu.conv.reorder.filter"} - -// ----- - -#map0 = affine_map<(d0, d1, d2, d3, d4) -> (d0 + d1 + d2 + d3 * 3 + d4 * 9)> - -// CHECK: @conv_reorder_filter_and_bias( -// CHECK: %[[FILTER_INPUT:[a-z0-9]+]]: memref -// CHECK: %[[BIAS_INPUT:[a-z0-9]+]]: memref -// CHECK: %[[FILTER_OUTPUT:[a-z0-9]+]]: memref -// CHECK: %[[BIAS_OUTPUT:[a-z0-9]+]]: memref -// CHECK: ) -func.func @conv_reorder_filter_and_bias( - %filter_input: memref<1x1x3x3x32xi8, #map0>, - %bias_input: memref<32xf32>, - %filter_output: memref<1x1x3x3x32xi8, #map0>, - %bias_output: memref<32xf32>) { - - // CHECK: call @local_xla.gpu.conv.reorder.filter_and_bias( - // CHECK-SAME: %[[FILTER_INPUT]], %[[BIAS_INPUT]], %[[FILTER_OUTPUT]], %[[BIAS_OUTPUT]] - // CHECK-DAG: filter_dims = array - "lmhlo_gpu.cudnn_conv_reorder_filter_and_bias"( - %filter_input, %bias_input, %filter_output, %bias_output) { - filter_dims = dense<[1, 32, 3, 3]> : tensor<4xi64> - }: (memref<1x1x3x3x32xi8, #map0>, memref<32xf32>, - memref<1x1x3x3x32xi8, #map0>, memref<32xf32>) -> () - - return -} - -// CHECK: func private @local_xla.gpu.conv.reorder.filter_and_bias( -// CHECK-SAME: memref<1x1x3x3x32xi8, #map{{[0-9]*}}>, -// CHECK-SAME: memref<32xf32>, -// CHECK-SAME: memref<1x1x3x3x32xi8, #map{{[0-9]*}}>, -// CHECK-SAME: memref<32xf32> -// CHECK-SAME: ) attributes {rt.custom_call = -// CHECK-SAME: "xla.gpu.conv.reorder.filter_and_bias"} diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_cublas_lt_matmul.mlir b/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_cublas_lt_matmul.mlir deleted file mode 100644 index 6255a255411c5a..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_cublas_lt_matmul.mlir +++ /dev/null @@ -1,100 +0,0 @@ -// RUN: xla-gpu-opt %s -split-input-file -xla-lmhlo-gpu-to-gpu-runtime \ -// RUN: | FileCheck %s - -// CHECK: @compute( -// CHECK: %[[A:[a-z0-9]+]]: memref<2x6x2x2xf32>, -// CHECK: %[[B:[a-z0-9]+]]: memref<2x6x2x2xf32>, -// CHECK: %[[C:[a-z0-9]+]]: memref<2x6x2x2xf32>, -// CHECK: %[[D:[a-z0-9]+]]: memref<2x6x2x2xf32> -// CHECK: ) -func.func @compute(%a: memref<2x6x2x2xf32>, - %b: memref<2x6x2x2xf32>, - %c: memref<2x6x2x2xf32>, - %d: memref<2x6x2x2xf32>) { - - // CHECK: @local_xla.gpu.cublas.lt.matmul(%[[A]], %[[B]], %[[C]], %[[D]]) - // CHECK-SAME: alpha_imag = 0.000000e+00 : f64 - // CHECK-SAME: alpha_real = 1.000000e+00 : f64 - // CHECK-SAME: beta = 0.000000e+00 : f64 - // CHECK-SAME: dot_dims = #mhlo.dot - // CHECK-SAME: epilogue = #lmhlo_gpu - // CHECK-SAME: precision = dense<0> : tensor<2xi32> - // CHECK-SAME: uid = 0 : i64 - "lmhlo_gpu.cublas.lt.matmul"(%a, %b, %c, %d) { - algorithm = 0 : i64, - alpha_imag = 0.000000e+00 : f64, - alpha_real = 1.000000e+00 : f64, - beta = 0.000000e+00 : f64, - dot_dimension_numbers = #mhlo.dot< - lhs_batching_dimensions = [0, 1], - rhs_batching_dimensions = [0, 1], - lhs_contracting_dimensions = [3], - rhs_contracting_dimensions = [2]>, - epilogue = #lmhlo_gpu, - precision_config = [#mhlo, #mhlo], - operandSegmentSizes = array - } : (memref<2x6x2x2xf32>, memref<2x6x2x2xf32>, - memref<2x6x2x2xf32>, memref<2x6x2x2xf32>) -> () - - return -} - -// CHECK: func private @local_xla.gpu.cublas.lt.matmul( -// CHECK-SAME: memref<2x6x2x2xf32>, memref<2x6x2x2xf32>, -// CHECK-SAME: memref<2x6x2x2xf32>, memref<2x6x2x2xf32> -// CHECK-SAME: ) attributes {rt.custom_call = "xla.gpu.cublas.lt.matmul"} - -// ----- - -// CHECK: @compute( -// CHECK: %[[A:[a-z0-9]+]]: memref<2x6x2x2xf32>, -// CHECK: %[[B:[a-z0-9]+]]: memref<2x6x2x2xf32>, -// CHECK: %[[C:[a-z0-9]+]]: memref<2x6x2x2xf32>, -// CHECK: %[[D:[a-z0-9]+]]: memref<2x6x2x2xf32>, -// CHECK: %[[BIAS:[a-z0-9]+]]: memref<2x6x2x2xf32> -// CHECK: ) -func.func @compute(%a: memref<2x6x2x2xf32>, - %b: memref<2x6x2x2xf32>, - %c: memref<2x6x2x2xf32>, - %d: memref<2x6x2x2xf32>, - %bias: memref<2x6x2x2xf32>) { - - // CHECK: @local_xla.gpu.cublas.lt.matmul.bias(%[[A]], %[[B]], %[[C]], %[[D]], - // CHECK-SAME: %[[BIAS]]) - // CHECK-SAME: alpha_imag = 0.000000e+00 : f64 - // CHECK-SAME: alpha_real = 1.000000e+00 : f64 - // CHECK-SAME: beta = 0.000000e+00 : f64 - // CHECK-SAME: dot_dims = #mhlo.dot - // CHECK-SAME: epilogue = #lmhlo_gpu - // CHECK-SAME: precision = dense<0> : tensor<2xi32> - // CHECK-SAME: uid = 0 : i64 - "lmhlo_gpu.cublas.lt.matmul"(%a, %b, %c, %d, %bias) { - algorithm = 0 : i64, - alpha_imag = 0.000000e+00 : f64, - alpha_real = 1.000000e+00 : f64, - beta = 0.000000e+00 : f64, - dot_dimension_numbers = #mhlo.dot< - lhs_batching_dimensions = [0, 1], - rhs_batching_dimensions = [0, 1], - lhs_contracting_dimensions = [3], - rhs_contracting_dimensions = [2]>, - epilogue = #lmhlo_gpu, - precision_config = [#mhlo, #mhlo], - operandSegmentSizes = array - } : (memref<2x6x2x2xf32>, memref<2x6x2x2xf32>, memref<2x6x2x2xf32>, - memref<2x6x2x2xf32>, memref<2x6x2x2xf32>) -> () - - return -} - -// CHECK: func private @local_xla.gpu.cublas.lt.matmul.bias( -// CHECK-SAME: memref<2x6x2x2xf32>, memref<2x6x2x2xf32>, memref<2x6x2x2xf32>, -// CHECK-SAME: memref<2x6x2x2xf32>, memref<2x6x2x2xf32> -// CHECK-SAME: ) attributes {rt.custom_call = -// CHECK-SAME: "xla.gpu.cublas.lt.matmul.bias"} diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_gemm.mlir b/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_gemm.mlir deleted file mode 100644 index 51d2247049d3c2..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_gemm.mlir +++ /dev/null @@ -1,41 +0,0 @@ -// RUN: xla-gpu-opt %s -split-input-file -xla-lmhlo-gpu-to-gpu-runtime \ -// RUN: | FileCheck %s - -// CHECK: @compute( -// CHECK: %[[LHS:[a-z0-9]+]]: memref<4x4xf32>, -// CHECK: %[[RHS:[a-z0-9]+]]: memref<4x4xf32>, -// CHECK: %[[OUT:[a-z0-9]+]]: memref<4x4xf32> -// CHECK: ) -func.func @compute(%lhs: memref<4x4xf32>, %rhs: memref<4x4xf32>, - %out: memref<4x4xf32>) { - - // CHECK: call @[[GEMM:[_a-z.]+]](%[[LHS]], %[[RHS]], %[[OUT]]) - // CHECK-SAME: algorithm = 13 : i64 - // CHECK-SAME: alpha_imag = 0.000000e+00 : f64 - // CHECK-SAME: alpha_real = 1.000000e+00 : f64 - // CHECK-SAME: beta = 0.000000e+00 : f64 - // CHECK-SAME: dot_dims = #mhlo.dot - // CHECK-SAME: uid = 0 : i64 - // CHECK-SAME: (memref<4x4xf32>, memref<4x4xf32>, memref<4x4xf32>) -> () - "lmhlo_gpu.gemm"(%lhs, %rhs, %out) - { - algorithm = 13 : i64, - alpha_imag = 0.000000e+00 : f64, - alpha_real = 1.000000e+00 : f64, - batch_size = 1 : i64, - beta = 0.000000e+00 : f64, - dot_dimension_numbers = #mhlo.dot, - lhs_stride = 16 : i64, - rhs_stride = 16 : i64 - } - : (memref<4x4xf32>, memref<4x4xf32>, memref<4x4xf32>) -> () - - // CHECK-NEXT: return - func.return -} - -// CHECK: func private @[[GEMM:[_a-z.]+]](memref<4x4xf32>, memref<4x4xf32>, -// CHECK-SAME: memref<4x4xf32>) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.gemm"} diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_infeed.mlir b/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_infeed.mlir deleted file mode 100644 index 089a7e6ac3351a..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_infeed.mlir +++ /dev/null @@ -1,14 +0,0 @@ -// RUN: xla-gpu-opt %s -xla-lmhlo-to-gpu-runtime | FileCheck %s - -// CHECK: func @gpu_infeed( -// CHECK: %[[ARG0:[a-z0-9]+]]: memref -// CHECK: ) -func.func @gpu_infeed(%arg0: memref) { - // CHECK: call @[[INFEED:.*]](%[[ARG0]]) - // CHECK-SAME: {config = "abc"} : (memref) -> () - "lmhlo.infeed"(%arg0) {config = "abc"} : (memref) -> () - return -} - -// CHECK: func private @[[INFEED]](memref) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.infeed"} diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_outfeed.mlir b/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_outfeed.mlir deleted file mode 100644 index 32cf254a7ff99b..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_outfeed.mlir +++ /dev/null @@ -1,14 +0,0 @@ -// RUN: xla-gpu-opt %s -xla-lmhlo-to-gpu-runtime | FileCheck %s - -// CHECK: func @gpu_infeed( -// CHECK: %[[ARG0:[a-z0-9]+]]: memref -// CHECK: ) -func.func @gpu_infeed(%arg0: memref) { - // CHECK: call @[[OUTFEED:.*]](%[[ARG0]]) - // CHECK-SAME: {config = "abc"} : (memref) -> () - "lmhlo.outfeed"(%arg0) {config = "abc"} : (memref) -> () - return -} - -// CHECK: func private @[[OUTFEED]](memref) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.outfeed"} diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_send_recv.mlir b/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_send_recv.mlir deleted file mode 100644 index 51130537522712..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_send_recv.mlir +++ /dev/null @@ -1,88 +0,0 @@ -// RUN: xla-gpu-opt %s -split-input-file -xla-lmhlo-to-gpu-runtime \ -// RUN: | FileCheck %s - -// CHECK: func @send( -// CHECK: %[[ARG0:[a-z0-9]+]]: memref<4xf32> -// CHECK: ) -func.func @send(%arg0: memref<4xf32>) { - // CHECK: call @local_xla.gpu.send_host(%[[ARG0]]) { - // CHECK-SAME: channel_handle = #mhlo.channel_handle, - // CHECK-SAME: frontend_attributes = { - // CHECK-SAME: _xla_dcn_recv_channel = "2", - // CHECK-SAME: _xla_host_transfer_handler_name = "undef", - // CHECK-SAME: _xla_host_transfer_rendezvous = "undef" - // CHECK-SAME: }} : (memref<4xf32>) -> () - "lmhlo.send"(%arg0) { - channel_handle = #mhlo.channel_handle, - frontend_attributes = {_xla_dcn_recv_channel = "2", - _xla_host_transfer_handler_name = "undef", - _xla_host_transfer_rendezvous = "undef"}, - is_host_transfer = true - } : (memref<4xf32>) -> !mhlo.token - return -} - -// CHECK: func private @local_xla.gpu.send_host(memref<4xf32>) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.send_host"} - -// ----- - -// CHECK: func @recv( -// CHECK: %[[ARG0:[a-z0-9]+]]: memref<4xf32> -// CHECK: ) -func.func @recv(%arg0: memref<4xf32>) { - // CHECK: call @local_xla.gpu.recv_host(%[[ARG0]]) { - // CHECK-SAME: channel_handle = #mhlo.channel_handle, - // CHECK-SAME: frontend_attributes = { - // CHECK-SAME: _xla_host_transfer_handler_name = "undef", - // CHECK-SAME: _xla_host_transfer_rendezvous = "undef" - // CHECK-SAME: }} : (memref<4xf32>) -> () - "lmhlo.recv"(%arg0) { - channel_handle = #mhlo.channel_handle, - frontend_attributes = {_xla_host_transfer_handler_name = "undef", - _xla_host_transfer_rendezvous = "undef"}, - is_host_transfer = true - } : (memref<4xf32>) -> !mhlo.token - return -} - -// CHECK: func private @local_xla.gpu.recv_host(memref<4xf32>) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.recv_host"} - -// ----- - -// CHECK: func @send_done( -// CHECK: %[[ARG0:[a-z0-9]+]]: !mhlo.token -// CHECK: ) -func.func @send_done(%arg0: !mhlo.token) { - // CHECK: call @local_xla.gpu.send_done_host() { - // CHECK-SAME: channel_handle = #mhlo.channel_handle - // CHECK-SAME: } : () -> () - "lmhlo.send_done"(%arg0) { - channel_handle = #mhlo.channel_handle, - is_host_transfer = true - } : (!mhlo.token) -> () - return -} - -// CHECK: func private @local_xla.gpu.send_done_host() -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.send_done_host"} - -// ----- - -// CHECK: func @recv_done( -// CHECK: %[[ARG0:[a-z0-9]+]]: !mhlo.token -// CHECK: ) -func.func @recv_done(%arg0: !mhlo.token) { - // CHECK: call @local_xla.gpu.recv_done_host() { - // CHECK-SAME: channel_handle = #mhlo.channel_handle - // CHECK-SAME: } : () -> () - "lmhlo.recv_done"(%arg0) { - channel_handle = #mhlo.channel_handle, - is_host_transfer = true - } : (!mhlo.token) -> () - return -} - -// CHECK: func private @local_xla.gpu.recv_done_host() -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.recv_done_host"} diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_while.mlir b/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_while.mlir deleted file mode 100644 index 9dac0c42d4b00d..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/lmhlo_while.mlir +++ /dev/null @@ -1,97 +0,0 @@ -// RUN: xla-gpu-opt %s --split-input-file -xla-lmhlo-to-gpu-runtime \ -// RUN: | FileCheck %s - -module attributes {gpu.container_module} { - memref.global "private" constant @constant : memref = dense<0> - - gpu.module @cond attributes {binary = "ptx"} { - gpu.func @fn(%arg0: memref, %arg1: memref) kernel { - gpu.return - } - } - - gpu.module @body attributes {binary = "ptx"} { - gpu.func @fn(%arg0: memref) kernel { - gpu.return - } - } - - // CHECK: @while_loop( - // CHECK-SAME: %[[ARG0:.*]]: memref, - // CHECK-SAME: %[[ARG1:.*]]: memref - // CHECK-SAME: ) - func.func @while_loop(%arg0: memref, %arg1: memref) { - %c1 = arith.constant 1 : index - %0 = memref.get_global @constant : memref - gpu.memcpy %arg0, %0 : memref, memref - - // CHECK: %[[HOST_PRED:.*]] = memref.alloca() : memref - // CHECK: scf.while : () -> () - "lmhlo.while"(%arg1) ({ - // CHECK: gpu.launch_func @cond::@fn - // CHECK: gpu.memcpy %[[HOST_PRED]], %[[ARG1]] - // CHECK: %[[COND:.*]] = memref.load %[[HOST_PRED]][] : memref - // CHECK: scf.condition(%[[COND]]) - gpu.launch_func @cond::@fn blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg0 : memref, %arg1 : memref) - "lmhlo.terminator"() : () -> () - }, { - // CHECK: gpu.launch_func @body::@fn - // CHECK: scf.yield - gpu.launch_func @body::@fn blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg0 : memref) - "lmhlo.terminator"() : () -> () - }) : (memref) -> () - "lmhlo.terminator"() : () -> () - } -} - -// ----- -// Check that while loops with known trip counts lower to `scf.for` loops. - -module attributes {gpu.container_module} { - memref.global "private" constant @constant : memref = dense<0> - - gpu.module @cond attributes {binary = "ptx"} { - gpu.func @fn(%arg0: memref, %arg1: memref) kernel { - gpu.return - } - } - - gpu.module @body attributes {binary = "ptx"} { - gpu.func @fn(%arg0: memref) kernel { - gpu.return - } - } - - // CHECK: @for_loop( - // CHECK-SAME: %[[ARG0:.*]]: memref, - // CHECK-SAME: %[[ARG1:.*]]: memref - // CHECK-SAME: ) - func.func @for_loop(%arg0: memref, %arg1: memref) { - // CHECK-DAG: %[[LB:.*]] = arith.constant 0 - // CHECK-DAG: %[[UB:.*]] = arith.constant 3000 - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 - %c1 = arith.constant 1 : index - - // CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[C1]] - // CHECK-NEXT: gpu.launch_func @body::@fn - // CHECK-NOT: gpu.launch.func - - "lmhlo.while"(%arg1) ({ - gpu.launch_func @cond::@fn blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg0 : memref, %arg1 : memref) - "lmhlo.terminator"() : () -> () - }, { - gpu.launch_func @body::@fn blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg0 : memref) - "lmhlo.terminator"() : () -> () - }) {trip_count = 3000 : i64} : (memref) -> () - - "lmhlo.terminator"() : () -> () - } -} diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/memref_get_global_to_arg.mlir b/third_party/xla/xla/mlir/backends/gpu/transforms/tests/memref_get_global_to_arg.mlir deleted file mode 100644 index 6361c77f145f1c..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/memref_get_global_to_arg.mlir +++ /dev/null @@ -1,43 +0,0 @@ -// RUN: xla-gpu-opt %s -xla-memref-get-global-to-arg=min-num-elements=2 \ -// RUN: | FileCheck %s - -#map = affine_map<(d0, d1) -> (d0 + 2 * d1)> - -memref.global "private" constant @cst0 : memref<2x3xf32> = - dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], - [4.000000e+00, 5.000000e+00, 6.000000e+00]]> - -memref.global "private" constant @cst1 : memref = - dense<1.000000e+00> - -memref.global "private" constant @cst2 : memref<2x3xf32, #map> = - dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], - [4.000000e+00, 5.000000e+00, 6.000000e+00]]> - -// CHECK: func.func @get_global( -// CHECK-SAME: %[[ARG0:.*]]: memref<24xi8> {lmhlo.constant_name = "cst0"}, -// CHECK-SAME: %[[ARG1:.*]]: memref<4xi8> {lmhlo.constant_name = "cst1"}, -// CHECK-SAME: %[[ARG2:.*]]: memref<24xi8> {lmhlo.constant_name = "cst2"} -// CHECK-SAME: ) -func.func @get_global(%arg0: memref<24xi8> {lmhlo.constant_name = "cst0"}, - %arg1: memref<4xi8> {lmhlo.constant_name = "cst1"}, - %arg2: memref<24xi8> {lmhlo.constant_name = "cst2"}) - -> (memref<2x3xf32>, memref, memref<2x3xf32, #map>) { - - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[V0:.*]] = memref.view %[[ARG0]][%[[C0]]][] {{.*}} memref<2x3xf32> - %0 = memref.get_global @cst0 : memref<2x3xf32> - - // CHECK: %[[V1:.*]] = memref.get_global {{.*}} : memref - %1 = memref.get_global @cst1 : memref - - // CHECK: %[[C0_1:.*]] = arith.constant 0 : index - // CHECK: %[[F:.*]] = memref.view %[[ARG2]][%[[C0_1]]][] {{.*}} memref<6xf32> - // CHECK: %[[V2:.*]] = memref.reinterpret_cast %[[F]] - // CHECK-SAME: to offset: [0], sizes: [2, 3], strides: [1, 2] - %2 = memref.get_global @cst2 : memref<2x3xf32, #map> - - // CHECK: return %[[V0]], %[[V1]], %[[V2]] - // CHECK-SAME: : memref<2x3xf32>, memref, memref<2x3xf32, #map{{[0-9]*}}> - return %0, %1, %2 : memref<2x3xf32>, memref, memref<2x3xf32, #map> -} diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/outline_cuda_graphs.mlir b/third_party/xla/xla/mlir/backends/gpu/transforms/tests/outline_cuda_graphs.mlir deleted file mode 100644 index cd286419770ce5..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/outline_cuda_graphs.mlir +++ /dev/null @@ -1,686 +0,0 @@ -// RUN: xla-gpu-opt %s --split-input-file -xla-gpu-outline-gpu-graphs \ -// RUN: | FileCheck %s - -module attributes {gpu.container_module} { - -gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref) kernel { - gpu.return - } - gpu.func @fn1(%arg0: memref) kernel { - gpu.return - } -} - -// CHECK: @func( -// CHECK: %[[ARG0:.*]]: memref, -// CHECK: %[[ARG1:.*]]: memref -// CHECK: ) -func.func @func(%arg0: memref, %arg1: memref) { - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c3 = arith.constant 3 : index - %c4 = arith.constant 4 : index - %c5 = arith.constant 5 : index - %c6 = arith.constant 6 : index - - // CHECK: call @local_xla.gpu.graph.launch(%[[ARG0]], %[[ARG1]]) - // CHECK-SAME: {capture = @local_xla.gpu.graph.capture} - // CHECK-NEXT: return - - gpu.launch_func @gpu_module::@fn0 - blocks in (%c1, %c2, %c3) - threads in (%c4, %c5, %c6) - args(%arg0 : memref) - - gpu.launch_func @gpu_module::@fn1 - blocks in (%c3, %c2, %c1) - threads in (%c6, %c5, %c4) - args(%arg1 : memref) - - func.return -} - -// CHECK: func @local_xla.gpu.graph.capture -// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 -// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 -// CHECK-NEXT: %[[C3:.*]] = arith.constant 3 -// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 -// CHECK-NEXT: %[[C5:.*]] = arith.constant 5 -// CHECK-NEXT: %[[C6:.*]] = arith.constant 6 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 -// CHECK-SAME: blocks in (%[[C1]], %[[C2]], %[[C3]]) -// CHECK-SAME: threads in (%[[C4]], %[[C5]], %[[C6]]) -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 -// CHECK-SAME: blocks in (%[[C3]], %[[C2]], %[[C1]]) -// CHECK-SAME: threads in (%[[C6]], %[[C5]], %[[C4]]) -// CHECK-NEXT: return - -// CHECK: func private @local_xla.gpu.graph.launch(memref, memref) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.graph.launch"} -} - -// ----- -// Check that single function launch was not outlined into graph capture. - -module attributes {gpu.container_module} { - -gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref) kernel { - gpu.return - } -} - -// CHECK: @func(%[[ARG0:.*]]: memref) -func.func @func(%arg0: memref) { - %c1 = arith.constant 1 : index - - // CHECK: gpu.launch_func {{.*}} args(%[[ARG0]] : memref) - // CHECK-NOT: call @local_xla.gpu.graph.launch - gpu.launch_func @gpu_module::@fn0 - blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg0 : memref) - - func.return -} - -} - -// ----- -// Check that two different sequences are outlined in different capture -// functions. - -module attributes {gpu.container_module} { - -gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref) kernel { - gpu.return - } - gpu.func @fn1(%arg0: memref) kernel { - gpu.return - } -} - -// CHECK: @func(%[[ARG0:.*]]: memref) -func.func @func(%arg0: memref) { - // CHECK: %[[C1:.*]] = arith.constant 1 - %c1 = arith.constant 1 : index - - // CHECK: call @local_xla.gpu.graph.launch(%[[ARG0]]) - // CHECK-SAME: {capture = @[[CAPTURE:.*]]} - - gpu.launch_func @gpu_module::@fn0 - blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg0 : memref) - - gpu.launch_func @gpu_module::@fn1 - blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg0 : memref) - - // CHECK: %[[C2:.*]] = arith.constant 2 - %c2 = arith.constant 2 : index - - // Use function call to break the captured ops sequence. - // CHECK: call @external - call @external(): () -> () - - // CHECK: call @local_xla.gpu.graph.launch(%[[ARG0]]) - // CHECK-SAME: {capture = @[[CAPTURE_0:.*]]} - - gpu.launch_func @gpu_module::@fn1 - blocks in (%c2, %c2, %c2) - threads in (%c2, %c2, %c2) - args(%arg0 : memref) - - gpu.launch_func @gpu_module::@fn0 - blocks in (%c2, %c2, %c2) - threads in (%c2, %c2, %c2) - args(%arg0 : memref) - - func.return -} - -func.func private @external() - -// CHECK: rt.export @[[CAPTURE]] -// CHECK: func.func @[[CAPTURE]]( -// CHECK: %arg0: memref -// CHECK: ) -// CHECK-NEXT: arith.constant 1 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 - -// CHECK: rt.export @[[CAPTURE_0]] -// CHECK: func.func @[[CAPTURE_0]]( -// CHECK: %arg0: memref -// CHECK: ) -// CHECK-NEXT: arith.constant 2 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 - -} - -// ----- -// Check that constants from the different basic blocks are cloned into the -// graph capture function. - -module attributes {gpu.container_module} { - -gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref) kernel { - gpu.return - } - gpu.func @fn1(%arg0: memref) kernel { - gpu.return - } -} - -// CHECK: @func( -// CHECK: %[[ARG0:.*]]: memref, -// CHECK: %[[ARG1:.*]]: memref -// CHECK: ) -func.func @func(%arg0: memref, %arg1: memref) { - cf.br ^bb2 -^bb1: - // CHECK: call @local_xla.gpu.graph.launch(%[[ARG0]], %[[ARG1]]) - // CHECK-SAME: {capture = @local_xla.gpu.graph.capture} - // CHECK-NEXT: return - - gpu.launch_func @gpu_module::@fn0 - blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg0 : memref) - - gpu.launch_func @gpu_module::@fn1 - blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg1 : memref) - - func.return - -^bb2: - %c1 = arith.constant 1 : index - cf.br ^bb1 -} -} - -// CHECK: func @local_xla.gpu.graph.capture -// CHECK-NEXT: arith.constant 1 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 -// CHECK-NEXT: return - -// ----- -// Check that memref.view operations are cloned into the graph capture function. - -module attributes {gpu.container_module} { - -gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<4xf32>) kernel { gpu.return } - gpu.func @fn1(%arg0: memref<4xf32>) kernel { gpu.return } -} - -// CHECK: @func(%[[ARG0:.*]]: memref<16xi8>) -func.func @func(%arg0: memref<16xi8>) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %view = memref.view %arg0[%c0][] : memref<16xi8> to memref<4xf32> - - call @external() : () -> () - - // CHECK: call @local_xla.gpu.graph.launch(%[[ARG0]]) - // CHECK-SAME: {capture = @local_xla.gpu.graph.capture} - // CHECK-NEXT: return - gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<4xf32>) - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<4xf32>) - - func.return -} - -func.func private @external() -} - -// CHECK: func @local_xla.gpu.graph.capture -// CHECK-NEXT: arith.constant 0 -// CHECK-NEXT: arith.constant 1 -// CHECK-NEXT: memref.view -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 -// CHECK-NEXT: return - -// ----- -// Check that memref.view not used by operations in the captured graph will not -// be moved into the graph capture function. - -module attributes {gpu.container_module} { - -gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<16xi8>) kernel { gpu.return } - gpu.func @fn1(%arg0: memref<16xi8>) kernel { gpu.return } -} - -// CHECK: @func(%[[ARG0:.*]]: memref<16xi8>) -func.func @func(%arg0: memref<16xi8>) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - - call @external() : () -> () - - // CHECK: call @local_xla.gpu.graph.launch(%[[ARG0]]) - // CHECK-SAME: {capture = @local_xla.gpu.graph.capture} - // CHECK-NEXT: memref.view - // CHECK-NEXT: return - gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%arg0 : memref<16xi8>) - %view = memref.view %arg0[%c0][] : memref<16xi8> to memref<4xf32> - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%arg0 : memref<16xi8>) - - func.return -} - -func.func private @external() -} - -// CHECK: func @local_xla.gpu.graph.capture -// CHECK-NEXT: arith.constant 1 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 -// CHECK-NEXT: return - -// ----- -// Check that lmhlo_gpu.gemm is moved into the graph capture function. - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<16xi8>) kernel { gpu.return } - } - - // CHECK: @func(%[[ARG0:.*]]: memref<16xi8> {lmhlo.params = 0 : index} - // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> {lmhlo.params = 1 : index} - // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> - func.func @func(%raw_arg0: memref<16xi8> {lmhlo.params = 0 : index}, - %raw_arg1: memref<16xi8> {lmhlo.params = 1 : index}, - %raw_arg2: memref<16xi8> {lmhlo.output_index = dense<[0]> : tensor<1xindex>}) attributes { - result_xla_shape = "(f32[4]) " - } { - %c0 = arith.constant 0 : index - %arg0 = memref.view %raw_arg0[%c0][] : memref<16xi8> to memref<2x2xf32> - %c1 = arith.constant 0 : index - %arg1 = memref.view %raw_arg1[%c1][] : memref<16xi8> to memref<2x2xf32> - %c2 = arith.constant 0 : index - %arg2 = memref.view %raw_arg2[%c2][] : memref<16xi8> to memref<2x2xf32> - - // CHECK: call @local_xla.gpu.graph.launch(%[[ARG0]], %[[ARG1]], %[[ARG2]]) - // CHECK-SAME: {capture = @local_xla.gpu.graph.capture} - "lmhlo_gpu.gemm"(%arg0, %arg1, %arg2) {alpha_imag = 0.000000e+00 : f64, alpha_real = 1.000000e+00 : f64, beta = 0.000000e+00 : f64, batch_size = 1 : i64, lhs_stride = 4 : i64, rhs_stride = 4 : i64, dot_dimension_numbers = #mhlo.dot} : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () - gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%raw_arg0 : memref<16xi8>) - "lmhlo.terminator"() : () -> () - } - - func.func private @external() -} - -// CHECK: func @local_xla.gpu.graph.capture -// CHECK-NEXT: arith.constant 0 -// CHECK-NEXT: memref.view -// CHECK-NEXT: arith.constant 0 -// CHECK-NEXT: memref.view -// CHECK-NEXT: arith.constant 0 -// CHECK-NEXT: memref.view -// CHECK-NEXT: "lmhlo_gpu.gemm" -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 -// CHECK-NEXT: return - -// ----- -// Check that lmhlo_gpu.gemm with runtime autotuning is not captured by a CUDA -// graph. - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<16xi8>) kernel { gpu.return } - } - - // CHECK: @func(%[[ARG0:.*]]: memref<16xi8> {lmhlo.params = 0 : index} - // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> {lmhlo.params = 1 : index} - // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> - func.func @func(%raw_arg0: memref<16xi8> {lmhlo.params = 0 : index}, - %raw_arg1: memref<16xi8> {lmhlo.params = 1 : index}, - %raw_arg2: memref<16xi8> {lmhlo.output_index = dense<[0]> : tensor<1xindex>}) attributes { - result_xla_shape = "(f32[4]) " - } { - %c0 = arith.constant 0 : index - %arg0 = memref.view %raw_arg0[%c0][] : memref<16xi8> to memref<2x2xf32> - %c1 = arith.constant 0 : index - %arg1 = memref.view %raw_arg1[%c1][] : memref<16xi8> to memref<2x2xf32> - %c2 = arith.constant 0 : index - %arg2 = memref.view %raw_arg2[%c2][] : memref<16xi8> to memref<2x2xf32> - - - // CHECK-NOT: call @local_xla.gpu.graph.launch - // CHECK: "lmhlo_gpu.gemm" - "lmhlo_gpu.gemm"(%arg0, %arg1, %arg2) {algorithm = -5, alpha_imag = 0.000000e+00 : f64, alpha_real = 1.000000e+00 : f64, beta = 0.000000e+00 : f64, batch_size = 1 : i64, lhs_stride = 4 : i64, rhs_stride = 4 : i64, dot_dimension_numbers = #mhlo.dot} : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () - gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%raw_arg0 : memref<16xi8>) - "lmhlo.terminator"() : () -> () - } - - func.func private @external() -} - -// ----- -// Check that convolution with runtime autotuning is not captured by a CUDA -// graph. - -#map0 = affine_map<(d0, d1, d2, d3) -> (d0 * 3 + d1 + d2 * 9 + d3 * 9)> -#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * 16384 + d1 * 4 + d2 + d3 * 16)> -#map2 = affine_map<(d0, d1, d2, d3) -> (d0 * 4096 + d1 * 2 + d2 + d3 * 4)> - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<16xi8>) kernel { gpu.return } - } - - - // CHECK: @func(%[[ARG0:.*]]: memref<8x5x5x1xf32, #map> - // CHECK-SAME: %[[ARG1:.*]]: memref<3x3x1x32xf32, #map1> - // CHECK-SAME: %[[ARG2:.*]]: memref<32xf32> - // CHECK-SAME: %[[ARG3:.*]]: memref<8x5x5x32xf32, #map2> - // CHECK-SAME: %[[ARG4:.*]]: memref<0xui8> - // CHECK-SAME: %[[ARG5:.*]]: memref<16xi8> - func.func @func(%input: memref<8x5x5x1xf32, #map1>, - %filter: memref<3x3x1x32xf32, #map0>, - %bias: memref<32xf32>, - %output: memref<8x5x5x32xf32, #map2>, - %scratch: memref<0xui8>, - %raw_arg0: memref<16xi8> {lmhlo.params = 0 : index} - ) { - %c0 = arith.constant 0 : index - - // CHECK-NOT: call @local_xla.g.cuda.graph.launch - // CHECK: lmhlo_gpu.conv_forward_fused - lmhlo_gpu.conv_forward_fused(%input, %filter, %bias, %output, %scratch) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = { stride = [1, 1], - lhs_dilate = [1, 1], - rhs_dilate = [1, 1], - reverse = [0, 0] - } - { activation_mode = #lmhlo_gpu, - leakyrelu_alpha = 0.0 : f64, - backend_config = #lmhlo_gpu.convolution_backend_config< - algorithm = -1, - is_cudnn_frontend = true, - is_cudnn_reordered_int8 = false, - knob_ids = [2, 3], - knob_values = [4, 0], - operand_0_layout = [2, 1, 3, 0], - operand_1_layout = [1, 0, 2, 3], - result_layout = [2, 1, 3, 0], - tensor_ops_enabled = false, - workspace_size = 0 - >, - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [], - result_scale = 1.000000e+00 : f64 - } : (memref<8x5x5x1xf32, #map1>, - memref<3x3x1x32xf32, #map0>, - memref<32xf32>, - memref<8x5x5x32xf32, #map2>, - memref<0xui8>) -> () - gpu.launch_func @gpu_module::@fn0 blocks in (%c0, %c0, %c0) - threads in (%c0, %c0, %c0) args(%raw_arg0 : memref<16xi8>) - return - } - func.func private @external() -} - -// ----- -// Check that convolutions are captured by cuda graphs. - -#map0 = affine_map<(d0, d1, d2, d3) -> (d0 * 3 + d1 + d2 * 9 + d3 * 9)> -#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * 16384 + d1 * 4 + d2 + d3 * 16)> -#map2 = affine_map<(d0, d1, d2, d3) -> (d0 * 4096 + d1 * 2 + d2 + d3 * 4)> - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<16xi8>) kernel { gpu.return } - } - - - // CHECK: @func(%[[ARG0:.*]]: memref<1x4x4x1024xf16, #map> - // CHECK-SAME: %[[ARG1:.*]]: memref<3x3x1x1024xf16, #map1> - // CHECK-SAME: %[[ARG2:.*]]: memref<1x2x2x1024xf16, #map2> - // CHECK-SAME: %[[ARG3:.*]]: memref<0xui8> - // CHECK-SAME: %[[ARG4:.*]]: memref<16xi8> - func.func @func(%input: memref<1x4x4x1024xf16, #map1>, - %filter: memref<3x3x1x1024xf16, #map0>, - %output: memref<1x2x2x1024xf16, #map2>, - %scratch: memref<0xui8>, - %raw_arg0: memref<16xi8> {lmhlo.params = 0 : index} - ) { - %c0 = arith.constant 0 : index - - // CHECK: call @local_xla.gpu.graph.launch(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) - // CHECK-SAME: {capture = @local_xla.gpu.graph.capture} - lmhlo_gpu.conv_forward(%input, %filter, %output, %scratch) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = { stride = [1, 1], - lhs_dilate = [1, 1], - rhs_dilate = [1, 1], - reverse = [0, 0] - } - { backend_config = #lmhlo_gpu.convolution_backend_config< - algorithm = 0, - is_cudnn_frontend = true, - is_cudnn_reordered_int8 = false, - knob_ids = [], - knob_values = [], - operand_0_layout = [2, 1, 3, 0], - operand_1_layout = [1, 0, 2, 3], - result_layout = [2, 1, 3, 0], - tensor_ops_enabled = false, - workspace_size = 0 - >, - batch_group_count = 1 : i64, - feature_group_count = 1024 : i64, - precision_config = [], - result_scale = 1.000000e+00 : f64 - } : (memref<1x4x4x1024xf16, #map1>, - memref<3x3x1x1024xf16, #map0>, - memref<1x2x2x1024xf16, #map2>, - memref<0xui8>) -> () - gpu.launch_func @gpu_module::@fn0 blocks in (%c0, %c0, %c0) - threads in (%c0, %c0, %c0) args(%raw_arg0 : memref<16xi8>) - return - } - func.func private @external() -} - -// CHECK: func @local_xla.gpu.graph.capture -// CHECK-NEXT: arith.constant 0 -// CHECK-NEXT: lmhlo_gpu.conv_forward -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 -// CHECK-NEXT: return - -// ----- -// Check that d2d memcpy are captured. - -module attributes {gpu.container_module} { - - // CHECK: @func(%[[ARG0:.*]]: memref<100xi8>) - func.func @func(%arg0: memref<100xi8>) { - %c0 = arith.constant 0 : index - %dst = memref.view %arg0[%c0][] : memref<100xi8> to memref<10xf32> - %src = memref.view %arg0[%c0][] : memref<100xi8> to memref<10xf32> - - // CHECK: call @local_xla.gpu.graph.launch(%[[ARG0]]) - // CHECK-SAME: {capture = @local_xla.gpu.graph.capture} - gpu.memcpy %dst, %src : memref<10xf32>, memref<10xf32> - gpu.memcpy %dst, %src : memref<10xf32>, memref<10xf32> - - // CHECK: return - return - } - func.func private @external() -} - -// CHECK: func @local_xla.gpu.graph.capture -// CHECK: gpu.memcpy -// CHECK: gpu.memcpy -// CHECK-NEXT: return - -// ----- -// Check that memref.reinterpret_cast operations are cloned into the graph -// capture function. - -module attributes {gpu.container_module} { - -gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<16xi8, strided<[1], offset: 0>>) kernel { gpu.return } - gpu.func @fn1(%arg0: memref<16xi8, strided<[1], offset: 0>>) kernel { gpu.return } -} - -// CHECK: @func(%[[ARG0:.*]]: memref<16xi8>) -func.func @func(%arg0: memref<16xi8>) { - %c1 = arith.constant 1 : index - %view = memref.reinterpret_cast %arg0 to offset: [0], sizes: [16], strides: [1]: memref<16xi8> to memref<16xi8, strided<[1], offset: 0>> - - call @external() : () -> () - - // CHECK: call @local_xla.gpu.graph.launch(%[[ARG0]]) - // CHECK-SAME: {capture = @local_xla.gpu.graph.capture} - // CHECK-NEXT: return - gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<16xi8, strided<[1], offset: 0>>) - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<16xi8, strided<[1], offset: 0>>) - - func.return -} - -func.func private @external() -} - -// CHECK: func @local_xla.gpu.graph.capture -// CHECK-NEXT: arith.constant 1 -// CHECK-NEXT: memref.reinterpret_cast -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 -// CHECK-NEXT: return - -// ----- -// Check that the loop body of lmhlo.while is cloned into the graph. - -module attributes {gpu.container_module} { - -gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<16xi8>) kernel { gpu.return } - gpu.func @fn1(%arg0: memref<16xi8>) kernel { gpu.return } -} - -// CHECK: @func(%[[ARG0:.*]]: memref<16xi8> -func.func @func(%arg0: memref<16xi8>, %cond: memref) { - %c1 = arith.constant 1 : index - - call @external() : () -> () - - "lmhlo.while"(%cond) ({ - // CHECK: func.call @local_xla.gpu.graph.launch(%[[ARG0]]) - // CHECK-SAME: {capture = @local_xla.gpu.graph.capture} - gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%arg0: memref<16xi8>) - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%arg0: memref<16xi8>) - "lmhlo.terminator"() : () -> () }, { - // CHECK: func.call @local_xla.gpu.graph.launch(%[[ARG0]]) - // CHECK-SAME: {capture = @local_xla.gpu.graph.capture_0} - gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%arg0: memref<16xi8>) - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%arg0: memref<16xi8>) - "lmhlo.terminator"() : () -> () - }) : (memref) -> () - func.return -} - -func.func private @external() -} - -// CHECK: func @local_xla.gpu.graph.capture -// CHECK-NEXT: arith.constant 1 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 -// CHECK-NEXT: return - -// CHECK: func @local_xla.gpu.graph.capture_0 -// CHECK-NEXT: arith.constant 1 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 -// CHECK-NEXT: return - -// ----- -// Check that lmhlo.constant_name is propogated to the graph capture function -module attributes {gpu.container_module} { - -gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref) kernel { - gpu.return - } - gpu.func @fn1(%arg0: memref) kernel { - gpu.return - } -} - -// CHECK: @func( -// CHECK: %[[ARG0:.*]]: memref {lmhlo.constant_name = "cst0"}, -// CHECK: %[[ARG1:.*]]: memref {lmhlo.constant_name = "cst1"} -// CHECK: ) -func.func @func(%arg0: memref {lmhlo.constant_name = "cst0"}, - %arg1: memref {lmhlo.constant_name = "cst1"}) { - %c1 = arith.constant 1 : index - - // CHECK: call @local_xla.gpu.graph.launch(%[[ARG0]], %[[ARG1]]) - // CHECK-SAME: {capture = @local_xla.gpu.graph.capture} - // CHECK-NEXT: return - - gpu.launch_func @gpu_module::@fn0 - blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg0 : memref) - - gpu.launch_func @gpu_module::@fn1 - blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg1 : memref) - - func.return -} - -// CHECK: func @local_xla.gpu.graph.capture( -// CHECK-SAME: %[[ARG0]]: memref {lmhlo.constant_name = "cst0", -// CHECK-SAME: %[[ARG1]]: memref {lmhlo.constant_name = "cst1", -// CHECK-SAME: ) -// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 -// CHECK-SAME: blocks in (%[[C1]], %[[C1]], %[[C1]]) -// CHECK-SAME: threads in (%[[C1]], %[[C1]], %[[C1]]) -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 -// CHECK-SAME: blocks in (%[[C1]], %[[C1]], %[[C1]]) -// CHECK-SAME: threads in (%[[C1]], %[[C1]], %[[C1]]) -// CHECK-NEXT: return - -// CHECK: func private @local_xla.gpu.graph.launch(memref, memref) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.graph.launch"} -} diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/stream_assignment.mlir b/third_party/xla/xla/mlir/backends/gpu/transforms/tests/stream_assignment.mlir deleted file mode 100644 index a3543dca28bad9..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/stream_assignment.mlir +++ /dev/null @@ -1,190 +0,0 @@ -// RUN: xla-gpu-opt %s --split-input-file -xla-gpu-stream-assignment \ -// RUN: | FileCheck %s - -// ----- -// Check that independent kernels are assigned to different streams. -// A B--->C -// | ^ -// | | -// +--------+ -// -// Stream assignment: A->0 B->1 C->0 - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn1(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - gpu.func @fn2(%arg0: memref<3x3xi64>, %arg1: memref<3x3xi64>) kernel { gpu.return } - } - - // CHECK: func @local_xla.gpu.graph.capture - func.func @local_xla.gpu.graph.capture(%arg0: memref<72xi8>, %arg1: memref<72xi8>) { - // CHECK: call @local_xla.gpu.concurrent_region.begin() - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %view = memref.view %arg0[%c0][] : memref<72xi8> to memref<3x3xi64> - %view_0 = memref.view %arg1[%c0][] : memref<72xi8> to memref<3x3xi64> - - // CHECK: gpu.launch_func @gpu_module::@fn1 - // CHECK-SAME: {stream = 0 : i64} - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<3x3xi64>) - // CHECK: gpu.launch_func @gpu_module::@fn1 - // CHECK-SAME: {stream = 1 : i64} - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>) - // CHECK: call @local_xla.streams.await() {from = 0 : i64, to = [1]} - // CHECK: gpu.launch_func @gpu_module::@fn2 - // CHECK-SAME: {stream = 0 : i64} - gpu.launch_func @gpu_module::@fn2 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<3x3xi64>, %view_0 : memref<3x3xi64>) - // CHECK: call @local_xla.gpu.concurrent_region.end() - // CHECK: return - return - } -} - -// ----- -// Check that the assignment for the following pattern correctly exploits -// parallelism. -// A--->B C -// | ^ -// | | -// +--------+ -// -// Stream assignment: A->0 B->1 C->0 -// - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn1(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - gpu.func @fn2(%arg0: memref<3x3xi64> {lmhlo.written}, %arg1: memref<3x3xi64> {lmhlo.written}) kernel { gpu.return } - } - - - // CHECK: func @local_xla.gpu.graph.capture - func.func @local_xla.gpu.graph.capture(%arg0: memref<72xi8>, %arg1: memref<72xi8>) { - // CHECK: call @local_xla.gpu.concurrent_region.begin() - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %view = memref.view %arg0[%c0][] : memref<72xi8> to memref<3x3xi64> - %view_0 = memref.view %arg1[%c0][] : memref<72xi8> to memref<3x3xi64> - - // CHECK: gpu.launch_func @gpu_module::@fn2 - // CHECK-SAME: {stream = 0 : i64} - gpu.launch_func @gpu_module::@fn2 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<3x3xi64>, %view_0 : memref<3x3xi64>) - // CHECK: call @local_xla.streams.await() {from = 1 : i64, to = [0]} - // CHECK: gpu.launch_func @gpu_module::@fn1 - // CHECK-SAME: {stream = 1 : i64} - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<3x3xi64>) - // CHECK: gpu.launch_func @gpu_module::@fn1 - // CHECK-SAME: {stream = 0 : i64} - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>) - // CHECK: call @local_xla.gpu.concurrent_region.end() - // CHECK: return - return - } -} - -// ----- -// Check that stream with multiple dependencies is handled correctly. -// A B C-->D -// | | ^ -// | |--------| -// +-------------+ -// -// Stream assignment: A->0 B->1 C->2 D->0 -// - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn1(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - gpu.func @fn2(%arg0: memref<3x3xi64> {lmhlo.written}, %arg1: memref<3x3xi64> {lmhlo.written}, %arg3: memref<3x3xi64>) kernel { gpu.return } - } - - - // CHECK: func @local_xla.gpu.graph.capture - func.func @local_xla.gpu.graph.capture(%arg0: memref<72xi8>, %arg1: memref<72xi8>, %arg2: memref<72xi8>) { - // CHECK: call @local_xla.gpu.concurrent_region.begin() - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %view_0 = memref.view %arg0[%c0][] : memref<72xi8> to memref<3x3xi64> - %view_1 = memref.view %arg1[%c0][] : memref<72xi8> to memref<3x3xi64> - %view_2 = memref.view %arg2[%c0][] : memref<72xi8> to memref<3x3xi64> - - // CHECK: gpu.launch_func @gpu_module::@fn1 - // CHECK-SAME: {stream = 0 : i64} - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>) - // CHECK: gpu.launch_func @gpu_module::@fn1 - // CHECK-SAME: {stream = 1 : i64} - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_1 : memref<3x3xi64>) - // CHECK: gpu.launch_func @gpu_module::@fn1 - // CHECK-SAME: {stream = 2 : i64} - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_2 : memref<3x3xi64>) - // CHECK: call @local_xla.streams.await() {from = 0 : i64, to = [1, 2]} - // CHECK: gpu.launch_func @gpu_module::@fn2 - // CHECK-SAME: {stream = 0 : i64} - gpu.launch_func @gpu_module::@fn2 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>, %view_1 : memref<3x3xi64>, %view_2 : memref<3x3xi64>) - // CHECK: call @local_xla.gpu.concurrent_region.end() - // CHECK: return - return - } -} - -// ----- -// Check that stream synchronization only happens when two streams joins. -// A B--->C-->D -// | ^ -// | | -// +---------+ -// -// Stream assignment: A->0 B->1 C->0 D->0 -// - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn1(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - gpu.func @fn2(%arg0: memref<3x3xi64> {lmhlo.written}, %arg1: memref<3x3xi64>) kernel { gpu.return } - } - - - // CHECK: func @local_xla.gpu.graph.capture - func.func @local_xla.gpu.graph.capture(%arg0: memref<72xi8>, %arg1: memref<72xi8>) { - // CHECK: call @local_xla.gpu.concurrent_region.begin() - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %view_0 = memref.view %arg0[%c0][] : memref<72xi8> to memref<3x3xi64> - %view_1 = memref.view %arg1[%c0][] : memref<72xi8> to memref<3x3xi64> - - // CHECK: gpu.launch_func @gpu_module::@fn1 - // CHECK-SAME: {stream = 0 : i64} - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>) - // CHECK: gpu.launch_func @gpu_module::@fn1 - // CHECK-SAME: {stream = 1 : i64} - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_1 : memref<3x3xi64>) - // CHECK: call @local_xla.streams.await() {from = 0 : i64, to = [1]} - // CHECK: gpu.launch_func @gpu_module::@fn2 - // CHECK-SAME: {stream = 0 : i64} - gpu.launch_func @gpu_module::@fn2 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>, %view_1 : memref<3x3xi64>) - // CHECK-NEXT: gpu.launch_func @gpu_module::@fn2 - // CHECK-SAME: {stream = 0 : i64} - gpu.launch_func @gpu_module::@fn2 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>, %view_1 : memref<3x3xi64>) - // CHECK: call @local_xla.gpu.concurrent_region.end() - // CHECK: return - return - } -} diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/uid_generator.h b/third_party/xla/xla/mlir/backends/gpu/transforms/uid_generator.h deleted file mode 100644 index 6e03655ede558b..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/uid_generator.h +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_MLIR_BACKENDS_GPU_TRANSFORMS_UID_GENERATOR_H_ -#define XLA_MLIR_BACKENDS_GPU_TRANSFORMS_UID_GENERATOR_H_ - -#include - -namespace xla { -namespace gpu { - -// Every stateful operation in the module gets assigned a unique id, that is -// passed to the custom call handler. This id is used for caching resources -// between the different invocations of the same custom call (e.g. cache -// convolution descriptors). -// -// TODO(b/255600288): Improve stateful custom calls in Xla runtime. -class UidGenerator { - public: - UidGenerator() : uid_(0) {} - int64_t uid() { return uid_.fetch_add(1); } - - private: - std::atomic uid_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_MLIR_BACKENDS_GPU_TRANSFORMS_UID_GENERATOR_H_ diff --git a/third_party/xla/xla/mlir/backends/gpu/xla-gpu-opt.cc b/third_party/xla/xla/mlir/backends/gpu/xla-gpu-opt.cc deleted file mode 100644 index ccea4d69d64b9c..00000000000000 --- a/third_party/xla/xla/mlir/backends/gpu/xla-gpu-opt.cc +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project -#include "xla/mlir/backends/gpu/transforms/passes.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" -#include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" - -int main(int argc, char **argv) { - mlir::DialectRegistry registry; - registry - .insert(); - mlir::func::registerAllExtensions(registry); - - xla::gpu::registerGpuTransformsPasses(); - - return failed(MlirOptMain(argc, argv, "Xla Gpu Pass Driver\n", registry)); -} diff --git a/third_party/xla/xla/mlir/framework/ir/BUILD b/third_party/xla/xla/mlir/framework/ir/BUILD index b410c5b71a2e3d..9db56a6a5e11a5 100644 --- a/third_party/xla/xla/mlir/framework/ir/BUILD +++ b/third_party/xla/xla/mlir/framework/ir/BUILD @@ -1,9 +1,11 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("@local_tsl//tsl:tsl.bzl", "internal_visibility") load("@local_tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility(["//learning/brain/mlir:xla_friends"]), licenses = ["notice"], ) @@ -62,7 +64,6 @@ cc_library( "xla_framework.h.inc", ], hdrs = ["xla_framework.h"], - visibility = ["//visibility:public"], deps = [ ":xla_framework_inc_gen", "@llvm-project//llvm:Support", diff --git a/third_party/xla/xla/mlir/framework/tests/BUILD b/third_party/xla/xla/mlir/framework/tests/BUILD index 8f6df845d20816..e0311ea4ac362d 100644 --- a/third_party/xla/xla/mlir/framework/tests/BUILD +++ b/third_party/xla/xla/mlir/framework/tests/BUILD @@ -1,7 +1,6 @@ load("//xla:lit.bzl", "enforce_glob", "lit_test_suite") package( - default_visibility = ["//visibility:public"], # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) @@ -18,7 +17,7 @@ lit_test_suite( ), cfg = "//xla:lit.cfg.py", tools = [ - "//xla/translate/mhlo_to_lhlo_with_xla:xla-translate-opt", + "//xla/translate:xla-translate-opt", "@llvm-project//llvm:FileCheck", ], ) diff --git a/third_party/xla/xla/mlir/framework/transforms/BUILD b/third_party/xla/xla/mlir/framework/transforms/BUILD index be60927699ced5..26cbb6afcf960b 100644 --- a/third_party/xla/xla/mlir/framework/transforms/BUILD +++ b/third_party/xla/xla/mlir/framework/transforms/BUILD @@ -1,9 +1,11 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("@local_tsl//tsl:tsl.bzl", "internal_visibility") load("@local_tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility(["//learning/brain/mlir:xla_friends"]), licenses = ["notice"], ) @@ -39,7 +41,6 @@ cc_library( hdrs = [ "passes.h", ], - visibility = ["//visibility:public"], deps = [ ":passes_inc_gen", "//xla/mlir/framework/ir:xla_framework", diff --git a/third_party/xla/xla/mlir/math/BUILD b/third_party/xla/xla/mlir/math/BUILD index 7b4bde174949ea..15f4a35f9a210a 100644 --- a/third_party/xla/xla/mlir/math/BUILD +++ b/third_party/xla/xla/mlir/math/BUILD @@ -10,6 +10,7 @@ package_group( ) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], licenses = ["notice"], ) diff --git a/third_party/xla/xla/mlir/math/transforms/BUILD b/third_party/xla/xla/mlir/math/transforms/BUILD index 5a7436ac2a9b17..cfc7ff68c1e73c 100644 --- a/third_party/xla/xla/mlir/math/transforms/BUILD +++ b/third_party/xla/xla/mlir/math/transforms/BUILD @@ -3,7 +3,8 @@ load("@local_tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//xla/mlir/math:friends"], licenses = ["notice"], ) @@ -32,7 +33,6 @@ cc_library( ], hdrs = ["passes.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":passes_inc_gen", "@llvm-project//mlir:ArithDialect", diff --git a/third_party/xla/xla/mlir/math/transforms/tests/BUILD b/third_party/xla/xla/mlir/math/transforms/tests/BUILD index 6fa63434706dbf..307c66af3b0561 100644 --- a/third_party/xla/xla/mlir/math/transforms/tests/BUILD +++ b/third_party/xla/xla/mlir/math/transforms/tests/BUILD @@ -1,7 +1,6 @@ load("//xla:lit.bzl", "enforce_glob", "lit_test_suite") package( - default_visibility = ["//visibility:public"], # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) diff --git a/third_party/xla/xla/mlir/memref/BUILD b/third_party/xla/xla/mlir/memref/BUILD index 7b4bde174949ea..15f4a35f9a210a 100644 --- a/third_party/xla/xla/mlir/memref/BUILD +++ b/third_party/xla/xla/mlir/memref/BUILD @@ -10,6 +10,7 @@ package_group( ) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], licenses = ["notice"], ) diff --git a/third_party/xla/xla/mlir/memref/transforms/BUILD b/third_party/xla/xla/mlir/memref/transforms/BUILD index 0c09155bb1eb9c..adc78fc5b6a1cc 100644 --- a/third_party/xla/xla/mlir/memref/transforms/BUILD +++ b/third_party/xla/xla/mlir/memref/transforms/BUILD @@ -3,7 +3,8 @@ load("@local_tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//xla/mlir/memref:friends"], licenses = ["notice"], ) @@ -29,10 +30,10 @@ cc_library( srcs = ["aligned_allocations.cc"], hdrs = ["passes.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":passes_inc_gen", "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Pass", ], diff --git a/third_party/xla/xla/mlir/memref/transforms/aligned_allocations.cc b/third_party/xla/xla/mlir/memref/transforms/aligned_allocations.cc index 6bb9a925496dd5..8cee23aec11e35 100644 --- a/third_party/xla/xla/mlir/memref/transforms/aligned_allocations.cc +++ b/third_party/xla/xla/mlir/memref/transforms/aligned_allocations.cc @@ -19,6 +19,8 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "xla/mlir/memref/transforms/passes.h" diff --git a/third_party/xla/xla/mlir/memref/transforms/tests/BUILD b/third_party/xla/xla/mlir/memref/transforms/tests/BUILD index 9a2cd8be569ced..57836f22d167fd 100644 --- a/third_party/xla/xla/mlir/memref/transforms/tests/BUILD +++ b/third_party/xla/xla/mlir/memref/transforms/tests/BUILD @@ -1,7 +1,6 @@ load("//xla:lit.bzl", "enforce_glob", "lit_test_suite") package( - default_visibility = ["//visibility:public"], # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) diff --git a/third_party/xla/xla/mlir/runtime/BUILD b/third_party/xla/xla/mlir/runtime/BUILD index 7309674b1e7d78..7538e87ec6ba8e 100644 --- a/third_party/xla/xla/mlir/runtime/BUILD +++ b/third_party/xla/xla/mlir/runtime/BUILD @@ -22,7 +22,8 @@ package_group( ) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], licenses = ["notice"], ) @@ -40,6 +41,7 @@ xla_cc_binary( deps = [ "//xla/mlir/math/transforms:passes", "//xla/mlir/memref/transforms:passes", + "//xla/mlir/runtime/ir:rt", "//xla/mlir/runtime/ir/tests:testlib", "//xla/mlir/runtime/transforms:compilation_pipeline_cpu", "//xla/mlir/runtime/transforms:compilation_pipeline_gpu", @@ -50,5 +52,6 @@ xla_cc_binary( "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:Support", ], ) diff --git a/third_party/xla/xla/mlir/runtime/ir/BUILD b/third_party/xla/xla/mlir/runtime/ir/BUILD index cbe2154a3bc8a5..50e28fcf876897 100644 --- a/third_party/xla/xla/mlir/runtime/ir/BUILD +++ b/third_party/xla/xla/mlir/runtime/ir/BUILD @@ -3,7 +3,8 @@ load("@local_tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//xla/mlir/runtime:friends"], licenses = ["notice"], ) @@ -97,7 +98,6 @@ cc_library( "rt_ops.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":rt_inc_gen", ":rt_interfaces_inc_gen", diff --git a/third_party/xla/xla/mlir/runtime/ir/tests/BUILD b/third_party/xla/xla/mlir/runtime/ir/tests/BUILD index b7f5b4bd6029ca..8baa4ef9124e44 100644 --- a/third_party/xla/xla/mlir/runtime/ir/tests/BUILD +++ b/third_party/xla/xla/mlir/runtime/ir/tests/BUILD @@ -83,7 +83,7 @@ cc_library( srcs = ["testlib.cc"], hdrs = ["testlib.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], + visibility = ["//xla/mlir/runtime:friends"], deps = [ ":testlib_inc_gen", "@llvm-project//llvm:Support", diff --git a/third_party/xla/xla/mlir/runtime/transforms/BUILD b/third_party/xla/xla/mlir/runtime/transforms/BUILD index e7c650711671d4..4a82bebe33cc55 100644 --- a/third_party/xla/xla/mlir/runtime/transforms/BUILD +++ b/third_party/xla/xla/mlir/runtime/transforms/BUILD @@ -5,7 +5,8 @@ load("@local_tsl//tsl/platform:build_config.bzl", "if_llvm_system_z_available") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//xla/mlir/runtime:friends"], licenses = ["notice"], ) @@ -39,7 +40,6 @@ cc_library( ], hdrs = ["passes.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":custom_call_encoding", ":passes_inc_gen", @@ -70,7 +70,6 @@ cc_library( srcs = ["calling_convention.cc"], hdrs = ["calling_convention.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "//xla/mlir/runtime/ir:rt", "@llvm-project//mlir:IR", @@ -195,7 +194,6 @@ cc_library( name = "compilation_pipeline_options", hdrs = ["compilation_pipeline_options.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":custom_call_encoding", "//xla/runtime:type_id", @@ -208,7 +206,6 @@ cc_library( srcs = ["custom_call_encoding.cc"], hdrs = ["custom_call_encoding.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla/mlir/runtime/ir:rt", @@ -232,7 +229,6 @@ cc_library( srcs = ["jit_compiler.cc"], hdrs = ["jit_compiler.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":calling_convention", ":compiler", @@ -283,7 +279,6 @@ cc_library( srcs = ["specialization.cc"], hdrs = ["specialization.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":type_converter", "//xla/mlir/runtime/utils:constraints", @@ -307,7 +302,6 @@ cc_library( srcs = ["type_converter.cc"], hdrs = ["type_converter.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla/mlir/runtime/ir:rt", @@ -337,7 +331,6 @@ cc_library( name = "compiler", hdrs = ["compiler.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", diff --git a/third_party/xla/xla/mlir/runtime/transforms/tests/BUILD b/third_party/xla/xla/mlir/runtime/transforms/tests/BUILD index 1ddcf60bab5adc..966f4837132247 100644 --- a/third_party/xla/xla/mlir/runtime/transforms/tests/BUILD +++ b/third_party/xla/xla/mlir/runtime/transforms/tests/BUILD @@ -1,7 +1,6 @@ load("//xla:lit.bzl", "enforce_glob", "lit_test_suite") package( - default_visibility = ["//visibility:public"], # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) @@ -32,7 +31,7 @@ cc_library( testonly = 1, srcs = ["testlib_pipeline.cc"], hdrs = ["testlib_pipeline.h"], - visibility = ["//visibility:public"], + visibility = ["//xla:runtime"], deps = [ "//xla/mlir/runtime/transforms:compiler", "//xla/mlir/runtime/transforms:passes", diff --git a/third_party/xla/xla/mlir/runtime/utils/BUILD b/third_party/xla/xla/mlir/runtime/utils/BUILD index 78c4a00ae7f216..08021de8b0da53 100644 --- a/third_party/xla/xla/mlir/runtime/utils/BUILD +++ b/third_party/xla/xla/mlir/runtime/utils/BUILD @@ -2,7 +2,8 @@ load("@local_tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//xla/mlir/runtime:friends"], licenses = ["notice"], ) @@ -11,7 +12,6 @@ cc_library( srcs = ["async_runtime_api.cc"], hdrs = ["async_runtime_api.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "//xla/runtime:async_runtime", "@com_google_absl//absl/base:dynamic_annotations", @@ -27,7 +27,6 @@ cc_library( name = "c_runner_utils", hdrs = ["c_runner_utils.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "@llvm-project//llvm:OrcJIT", "@llvm-project//mlir:mlir_c_runner_utils", @@ -39,7 +38,6 @@ cc_library( srcs = ["constraints.cc"], hdrs = ["constraints.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "//xla/runtime:constraints", "@com_google_absl//absl/status", @@ -57,7 +55,6 @@ cc_library( srcs = ["custom_calls.cc"], hdrs = ["custom_calls.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", @@ -69,6 +66,5 @@ cc_library( name = "float_16bits", hdrs = ["float_16bits.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = ["@llvm-project//llvm:OrcJIT"], ) diff --git a/third_party/xla/xla/mlir/runtime/xla-runtime-opt.cc b/third_party/xla/xla/mlir/runtime/xla-runtime-opt.cc index 355c6072609f4e..a710634d721000 100644 --- a/third_party/xla/xla/mlir/runtime/xla-runtime-opt.cc +++ b/third_party/xla/xla/mlir/runtime/xla-runtime-opt.cc @@ -18,9 +18,11 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project #include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project #include "xla/mlir/math/transforms/passes.h" #include "xla/mlir/memref/transforms/passes.h" +#include "xla/mlir/runtime/ir/rt_dialect.h" #include "xla/mlir/runtime/ir/tests/testlib.h" #include "xla/mlir/runtime/transforms/passes.h" diff --git a/third_party/xla/xla/mlir/utils/BUILD b/third_party/xla/xla/mlir/utils/BUILD index d6dc30baf24e8b..f413f6c76f7ade 100644 --- a/third_party/xla/xla/mlir/utils/BUILD +++ b/third_party/xla/xla/mlir/utils/BUILD @@ -1,8 +1,13 @@ +load("@local_tsl//tsl:tsl.bzl", "internal_visibility") load("@local_tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//third_party/golang/github_com/gomlx/gomlx:__subpackages__", + "//xla:internal", + ]), licenses = ["notice"], ) @@ -11,7 +16,6 @@ cc_library( srcs = ["error_util.cc"], hdrs = ["error_util.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/status", "@llvm-project//llvm:Support", diff --git a/third_party/xla/xla/mlir/xla_cpu/ir/BUILD b/third_party/xla/xla/mlir/xla_cpu/ir/BUILD index 3252e387c8be8e..9cb3de61b2118e 100644 --- a/third_party/xla/xla/mlir/xla_cpu/ir/BUILD +++ b/third_party/xla/xla/mlir/xla_cpu/ir/BUILD @@ -1,9 +1,14 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("@local_tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") +load("@local_tsl//tsl:tsl.bzl", "internal_visibility") +load( + "@local_tsl//tsl:tsl.default.bzl", + "get_compatible_with_portable", +) load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility(["//learning/brain/mlir:xla_friends"]), ) td_library( @@ -92,7 +97,6 @@ cc_library( "xla_cpu.cc", ], hdrs = ["xla_cpu.h"], - visibility = ["//visibility:public"], deps = [ ":xla_cpu_dialect_inc_gen", ":xla_cpu_enums_inc_gen", diff --git a/third_party/xla/xla/mlir_hlo/BUILD b/third_party/xla/xla/mlir_hlo/BUILD index 8d0f2b2b5aabc4..42f2a8ea0c2bb8 100644 --- a/third_party/xla/xla/mlir_hlo/BUILD +++ b/third_party/xla/xla/mlir_hlo/BUILD @@ -1,31 +1,26 @@ load("@bazel_skylib//rules:build_test.bzl", "build_test") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "gentbl_filegroup", "td_library") +load("@local_tsl//tsl:tsl.bzl", "internal_visibility") load("@local_tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility(["//learning/brain/mlir:mhlo_friends"]), licenses = ["notice"], ) -exports_files( - [ - "mhlo/IR/hlo_ops.td", - "lhlo/IR/lhlo_ops.td", - ], - visibility = ["//visibility:public"], -) +exports_files([ + "mhlo/IR/hlo_ops.td", + "lhlo/IR/lhlo_ops.td", +]) # Python extension sources. -exports_files( - ["bindings/python/MlirHloModule.cc"], - visibility = ["//visibility:public"], -) +exports_files(["bindings/python/MlirHloModule.cc"]) filegroup( name = "hlo_ops_td_filegroup", srcs = glob(["mhlo/IR/*.td"]), - visibility = ["//visibility:public"], ) td_library( @@ -321,7 +316,6 @@ cc_library( srcs = ["mhlo/IR/hlo_ops_common.cc"], hdrs = ["mhlo/IR/hlo_ops_common.h"], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -384,7 +378,6 @@ cc_library( "deallocation/transforms/passes.h", ], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ ":deallocation_passes_inc_gen", ":deallocation_utils", @@ -428,7 +421,6 @@ cc_library( srcs = ["deallocation/utils/util.cc"], hdrs = ["deallocation/utils/util.h"], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ "@llvm-project//llvm:Support", "@llvm-project//mlir:ControlFlowInterfaces", @@ -484,7 +476,6 @@ cc_library( "lhlo/IR/lhlo_structured_interface.h.inc", ], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ ":lhlo_structured_interface_inc_gen", "@llvm-project//mlir:IR", @@ -497,7 +488,6 @@ cc_library( srcs = ["utils/convert_op_folder.cc"], hdrs = ["utils/convert_op_folder.h"], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -525,7 +515,6 @@ cc_library( "utils/hlo_utils.h", ], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ ":canonicalize_inc_gen", ":convert_op_folder", @@ -573,7 +562,6 @@ cc_library( "lhlo/utils/lhlo_utils.h", ], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ ":hlo_ops_common", ":lhlo_ops_inc_gen", @@ -598,7 +586,6 @@ cc_library( srcs = ["lhlo_gpu/IR/lhlo_gpu_ops.cc"], hdrs = ["lhlo_gpu/IR/lhlo_gpu_ops.h"], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ ":hlo_ops_common", ":lhlo", @@ -619,7 +606,6 @@ cc_library( srcs = ["lhlo_gpu/IR/lhlo_gpu_ops.cc.inc"], hdrs = ["lhlo_gpu/IR/lhlo_gpu_ops.h.inc"], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", @@ -643,7 +629,6 @@ cc_library( srcs = ["mhlo/IR/init.cc"], hdrs = ["mhlo/IR/register.h"], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ ":mlir_hlo", "@llvm-project//mlir:IR", @@ -697,10 +682,10 @@ cc_library( "mhlo/transforms/mhlo_canonicalize_scatter/mhlo_canonicalize_scatter.cc", "mhlo/transforms/mhlo_flatten_tuple/mhlo_flatten_tuple.cc", "mhlo/transforms/mhlo_passes.h.inc", + "mhlo/transforms/mhlo_quant_legalize_to_int/mhlo_quant_legalize_to_int.cc", "mhlo/transforms/optimize_mhlo/optimize_mhlo.cc", "mhlo/transforms/optimize_mhlo/optimize_mhlo_pass.cc", "mhlo/transforms/prepare_for_export/prepare_for_export.cc", - "mhlo/transforms/rank_specialization/rank_specialization.cc", "mhlo/transforms/restrict_max_rank/restrict_max_rank.cc", "mhlo/transforms/shape_legalize_to_hlo/shape_legalize_to_hlo.cc", "mhlo/transforms/shape_reification/shape_reification_pass.cc", @@ -720,7 +705,6 @@ cc_library( "mhlo/utils/mhlo_rng_utils.h", ], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ ":chlo_legalize_to_hlo", ":hlo_legalize_to_stablehlo", @@ -758,6 +742,7 @@ cc_library( "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:ShapeTransforms", @@ -778,7 +763,6 @@ cc_library( srcs = ["mhlo/utils/type_conversion.cc"], hdrs = ["mhlo/utils/type_conversion.h"], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ ":mlir_hlo", "@llvm-project//mlir:FuncDialect", @@ -794,7 +778,6 @@ cc_library( name = "map_lmhlo_to_scalar_op", hdrs = ["lhlo/transforms/map_lmhlo_to_scalar_op.h"], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ ":map_lhlo_to_hlo_op", ":map_mhlo_to_scalar_op", @@ -805,7 +788,6 @@ cc_library( name = "map_mhlo_to_scalar_op", hdrs = ["mhlo/transforms/map_mhlo_to_scalar_op.h"], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ ":mlir_hlo", "@llvm-project//llvm:Support", @@ -821,7 +803,6 @@ cc_library( name = "map_chlo_to_hlo_op", hdrs = ["mhlo/transforms/map_chlo_to_hlo_op.h"], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ ":mlir_hlo", "@llvm-project//mlir:IR", @@ -833,7 +814,6 @@ cc_library( name = "map_hlo_to_lhlo_op", hdrs = ["lhlo/transforms/map_hlo_to_lhlo_op.h"], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ ":lhlo", ":mlir_hlo", @@ -844,7 +824,6 @@ cc_library( name = "map_lhlo_to_hlo_op", hdrs = ["lhlo/transforms/map_lhlo_to_hlo_op.h"], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ ":lhlo", ":mlir_hlo", @@ -855,7 +834,6 @@ cc_library( name = "map_stablehlo_to_hlo_op", hdrs = ["mhlo/transforms/map_stablehlo_to_hlo_op.h"], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ ":mlir_hlo", "@stablehlo//:stablehlo_ops", @@ -873,7 +851,6 @@ cc_library( ], hdrs = ["lhlo/transforms/passes.h"], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ ":lhlo", ":lmhlo_pass_inc_gen", @@ -905,7 +882,6 @@ cc_library( srcs = ["utils/codegen_utils.cc"], hdrs = ["utils/codegen_utils.h"], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", @@ -921,7 +897,6 @@ cc_library( name = "placement_utils", hdrs = ["utils/placement_utils.h"], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = ["@llvm-project//llvm:Support"], ) @@ -930,7 +905,6 @@ cc_library( srcs = ["mhlo/utils/legalize_to_linalg_utils.cc"], hdrs = ["mhlo/utils/legalize_to_linalg_utils.h"], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ ":map_mhlo_to_scalar_op", ":mlir_hlo", @@ -958,7 +932,6 @@ cc_library( srcs = ["mhlo/utils/mhlo_rng_utils.cc"], hdrs = ["mhlo/utils/mhlo_rng_utils.h"], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ ":mlir_hlo", "@llvm-project//llvm:Support", @@ -978,7 +951,6 @@ cc_library( srcs = ["mhlo/utils/mhlo_scatter_gather_utils.cc"], hdrs = ["mhlo/utils/mhlo_scatter_gather_utils.h"], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ ":mlir_hlo", "@llvm-project//mlir:DialectUtils", @@ -1029,7 +1001,6 @@ cc_library( srcs = ["mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm.cc"], hdrs = ["mhlo/transforms/rewriters.h"], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ ":mlir_hlo", "@llvm-project//llvm:Support", @@ -1047,7 +1018,6 @@ cc_library( srcs = ["mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc"], hdrs = ["mhlo/transforms/rewriters.h"], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ ":chlo_legalize_to_hlo_inc_gen", ":map_chlo_to_hlo_op", @@ -1085,7 +1055,6 @@ cc_library( srcs = ["mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc"], hdrs = ["mhlo/transforms/rewriters.h"], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ ":map_stablehlo_to_hlo_op", ":mlir_hlo", @@ -1104,7 +1073,6 @@ cc_library( srcs = ["mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc"], hdrs = ["mhlo/transforms/rewriters.h"], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ ":map_stablehlo_to_hlo_op", ":mlir_hlo", @@ -1135,7 +1103,6 @@ cc_library( "transforms/passes.h", ], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ ":chlo_legalize_to_hlo", ":deallocation_passes", @@ -1148,7 +1115,6 @@ cc_library( ":stablehlo_legalize_to_hlo", ":transforms_passes", ":transforms_passes_inc_gen", - ":userange_analysis", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -1161,10 +1127,8 @@ cc_library( cc_library( name = "transforms_passes", srcs = [ - "analysis/test_userange_analysis.cc", "mhlo/analysis/test_shape_component_analysis.cc", "transforms/alloc_to_arg_pass.cc", - "transforms/buffer_packing.cc", "transforms/bufferize.cc", "transforms/bufferize_pass.cc", "transforms/collapse_parallel_loops_to_1d_pass.cc", @@ -1185,7 +1149,6 @@ cc_library( "transforms/rewriters.h", ], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ ":deallocation_passes", ":lhlo", @@ -1194,7 +1157,6 @@ cc_library( ":shape_component_analysis", ":transforms_passes_inc_gen", ":type_conversion", - ":userange_analysis", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AffineToStandard", @@ -1267,7 +1229,6 @@ cc_library( ], hdrs = ["transforms/gpu_passes.h"], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ ":gpu_transforms_passes_inc_gen", ":lhlo", @@ -1349,27 +1310,11 @@ gentbl_cc_library( deps = ["@llvm-project//mlir:PassBaseTdFiles"], ) -cc_library( - name = "userange_analysis", - srcs = ["analysis/userange_analysis.cc"], - hdrs = ["analysis/userange_analysis.h"], - strip_include_prefix = ".", - visibility = ["//visibility:public"], - deps = [ - "@llvm-project//llvm:Support", - "@llvm-project//mlir:Analysis", - "@llvm-project//mlir:BufferizationTransforms", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LoopLikeInterface", - ], -) - cc_library( name = "shape_component_analysis", srcs = ["mhlo/analysis/shape_component_analysis.cc"], hdrs = ["mhlo/analysis/shape_component_analysis.h"], strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ ":mlir_hlo", "@llvm-project//llvm:Support", @@ -1401,7 +1346,6 @@ cc_library( srcs = CAPI_SOURCES, hdrs = CAPI_HEADERS, strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ ":all_passes", ":mlir_hlo", @@ -1414,7 +1358,6 @@ cc_library( name = "CAPIHeaders", hdrs = CAPI_HEADERS, strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = ["@llvm-project//mlir:CAPIIRHeaders"], ) @@ -1424,7 +1367,6 @@ cc_library( srcs = CAPI_SOURCES, hdrs = CAPI_HEADERS, strip_include_prefix = ".", - visibility = ["//visibility:public"], deps = [ ":all_passes", ":mlir_hlo", @@ -1487,7 +1429,6 @@ filegroup( "bindings/python/mlir/dialects/mhlo.py", ":MhloOpsPyGen", ], - visibility = ["//visibility:public"], ) # A light-weight runtime support library, used by MLIR code that results diff --git a/third_party/xla/xla/mlir_hlo/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/CMakeLists.txt index d52ec092a08f6a..27c03298b33c1e 100644 --- a/third_party/xla/xla/mlir_hlo/CMakeLists.txt +++ b/third_party/xla/xla/mlir_hlo/CMakeLists.txt @@ -158,7 +158,6 @@ set(MLIR_HLO_TOOLS_DIR ${MLIR_HLO_BINARY_DIR}/bin) set(MLIR_HLO_LIB_DIR ${MLIR_HLO_BINARY_DIR}/lib) add_custom_target(check-mlir-hlo) -add_subdirectory(analysis) add_subdirectory(bindings) add_subdirectory(deallocation) add_subdirectory(lhlo) diff --git a/third_party/xla/xla/mlir_hlo/analysis/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/analysis/CMakeLists.txt deleted file mode 100644 index 88a0fec149403f..00000000000000 --- a/third_party/xla/xla/mlir_hlo/analysis/CMakeLists.txt +++ /dev/null @@ -1,28 +0,0 @@ -add_mlir_library(MLIRHLOAnalysis - userange_analysis.cc - - DEPENDS - mlir-headers - - LINK_LIBS PUBLIC - MLIRAnalysis - MLIRIR -) - -add_mlir_library(MLIRHLOTestAnalysis - test_userange_analysis.cc - - DEPENDS - LMHLOTransformsPassIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - LmhloDialect - LmhloGPUDialect - MLIRHLOAnalysis - MLIRAnalysis - MLIRPass - MLIRTransforms -) diff --git a/third_party/xla/xla/mlir_hlo/analysis/test_userange_analysis.cc b/third_party/xla/xla/mlir_hlo/analysis/test_userange_analysis.cc deleted file mode 100644 index c4581a8a0a9a7e..00000000000000 --- a/third_party/xla/xla/mlir_hlo/analysis/test_userange_analysis.cc +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright 2021 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include - -#include "analysis/userange_analysis.h" -#include "lhlo/IR/lhlo_ops.h" -#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" -#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { - -#define GEN_PASS_DEF_TESTUSERANGE -#include "transforms/passes.h.inc" - -namespace { - -struct TestUserangePass : public impl::TestUserangeBase { - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - llvm::outs() << "Testing : " << getOperation().getName() << "\n"; - UserangeAnalysis(getOperation(), - bufferization::BufferPlacementAllocs(getOperation()), - BufferViewFlowAnalysis(getOperation())) - .dump(llvm::outs()); - } -}; - -} // end anonymous namespace - -std::unique_ptr> createTestUserangePass() { - return std::make_unique(); -} - -} // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/analysis/userange_analysis.cc b/third_party/xla/xla/mlir_hlo/analysis/userange_analysis.cc deleted file mode 100644 index f4aca6fcc59b1e..00000000000000 --- a/third_party/xla/xla/mlir_hlo/analysis/userange_analysis.cc +++ /dev/null @@ -1,625 +0,0 @@ -/* Copyright 2021 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "analysis/userange_analysis.h" - -#include -#include -#include -#include - -#include "llvm/ADT/SetOperations.h" -#include "mlir/IR/Block.h" -#include "mlir/IR/Region.h" -#include "mlir/Interfaces/LoopLikeInterface.h" - -using namespace mlir; - -namespace { -/// Builds a userange information from the given value and its liveness. The -/// information includes all operations that are within the userange. -struct UserangeInfoBuilder { - using OperationListT = Liveness::OperationListT; - using ValueSetT = BufferViewFlowAnalysis::ValueSetT; - - public: - /// Constructs an Userange builder. - UserangeInfoBuilder(Liveness liveness, ValueSetT values, - OperationListT opList) - : values(std::move(values)), - opList(std::move(opList)), - liveness(std::move(liveness)) {} - - /// Computes the userange of the current value by iterating over all of its - /// uses. - Liveness::OperationListT computeUserange() { - Region *topRegion = findTopRegion(); - // Iterate over all associated uses. - for (Operation *use : opList) { - // If one of the parents implements a LoopLikeOpInterface we need to add - // all operations inside of its regions to the userange. - Operation *loopParent = use->getParentOfType(); - if (loopParent && topRegion->isProperAncestor(use->getParentRegion())) - addAllOperationsInRegion(loopParent); - - // Check if the parent block has already been processed. - Block *useBlock = findTopLiveBlock(use); - if (!startBlocks.insert(useBlock).second || visited.contains(useBlock)) - continue; - - // Add all operations inside the block that are within the userange. - findOperationsInUse(useBlock); - } - return currentUserange; - } - - private: - /// Find the top most Region of all values stored in the values set. - Region *findTopRegion() const { - Region *topRegion = nullptr; - llvm::for_each(values, [&](Value v) { - Region *other = v.getParentRegion(); - if (!topRegion || topRegion->isAncestor(other)) topRegion = other; - }); - return topRegion; - } - - /// Finds the highest level block that has the current value in its liveOut - /// set. - Block *findTopLiveBlock(Operation *op) const { - Operation *topOp = op; - while (const LivenessBlockInfo *blockInfo = - liveness.getLiveness(op->getBlock())) { - if (llvm::any_of(values, - [&](Value v) { return blockInfo->isLiveOut(v); })) - topOp = op; - op = op->getParentOp(); - } - return topOp->getBlock(); - } - - /// Adds all operations from start to end to the userange of the current - /// value. If an operation implements a nested region all operations inside of - /// it are included as well. If includeEnd is false the end operation is not - /// added. - void addAllOperationsBetween(Operation *start, Operation *end) { - currentUserange.push_back(start); - addAllOperationsInRegion(start); - - while (start != end) { - start = start->getNextNode(); - addAllOperationsInRegion(start); - currentUserange.push_back(start); - } - } - - /// Adds all operations that are uses of the value in the given block to the - /// userange of the current value. Additionally iterate over all successors - /// where the value is live. - void findOperationsInUse(Block *block) { - SmallVector blocksToProcess; - addOperationsInBlockAndFindSuccessors( - block, block, getStartOperation(block), blocksToProcess); - while (!blocksToProcess.empty()) { - Block *toProcess = blocksToProcess.pop_back_val(); - addOperationsInBlockAndFindSuccessors( - block, toProcess, &toProcess->front(), blocksToProcess); - } - } - - /// Adds the operations between the given start operation and the computed end - /// operation to the userange. If the current value is live out, add all - /// successor blocks that have the value live in to the process queue. If we - /// find a loop, add the operations before the first use in block to the - /// userange (if any). The startBlock is the block where the iteration over - /// all successors started and is propagated further to find potential loops. - void addOperationsInBlockAndFindSuccessors( - const Block *startBlock, Block *toProcess, Operation *start, - SmallVector &blocksToProcess) { - const LivenessBlockInfo *blockInfo = liveness.getLiveness(toProcess); - Operation *end = getEndOperation(toProcess); - - addAllOperationsBetween(start, end); - - // If the value is live out we need to process all successors at which the - // value is live in. - if (!llvm::any_of(values, [&](Value v) { return blockInfo->isLiveOut(v); })) - return; - for (Block *successor : toProcess->getSuccessors()) { - // If the successor is the startBlock, we found a loop and only have to - // add the operations from the block front to the first use of the - // value. - if (!llvm::any_of(values, [&](Value v) { - return liveness.getLiveness(successor)->isLiveIn(v); - })) - continue; - if (successor == startBlock) { - start = &successor->front(); - end = getStartOperation(successor); - if (start != end) addAllOperationsBetween(start, end->getPrevNode()); - // Else we need to check if the value is live in and the successor - // has not been visited before. If so we also need to process it. - } else if (visited.insert(successor).second) { - blocksToProcess.push_back(successor); - } - } - } - - /// Iterates over all regions of a given operation and adds all operations - /// inside those regions to the userange of the current value. - void addAllOperationsInRegion(Operation *parentOp) { - // Iterate over all regions of the parentOp. - for (Region ®ion : parentOp->getRegions()) { - // Iterate over blocks inside the region. - for (Block &block : region) { - // If the blocks have been used as a startBlock before, we need to add - // all operations between the block front and the startOp of the value. - if (startBlocks.contains(&block)) { - Operation *start = &block.front(); - Operation *end = getStartOperation(&block); - if (start != end) addAllOperationsBetween(start, end->getPrevNode()); - - // If the block has never been seen before, we need to add all - // operations inside. - } else if (visited.insert(&block).second) { - for (Operation &op : block) { - addAllOperationsInRegion(&op); - currentUserange.push_back(&op); - } - continue; - } - // If the block has either been visited before or was used as a - // startBlock, we need to add all operations between the endOp of the - // value and the end of the block. - Operation *end = getEndOperation(&block); - if (end == &block.back()) continue; - addAllOperationsBetween(end->getNextNode(), &block.back()); - } - } - } - - /// Find the start operation of the current value inside the given block. - Operation *getStartOperation(Block *block) { - Operation *startOperation = &block->back(); - for (Operation *useOp : opList) { - // Find the associated operation in the current block (if any). - useOp = block->findAncestorOpInBlock(*useOp); - // Check whether the use is in our block and after the current end - // operation. - if (useOp && useOp->isBeforeInBlock(startOperation)) - startOperation = useOp; - } - return startOperation; - } - - /// Find the end operation of the current value inside the given block. - Operation *getEndOperation(Block *block) { - const LivenessBlockInfo *blockInfo = liveness.getLiveness(block); - if (llvm::any_of(values, [&](Value v) { return blockInfo->isLiveOut(v); })) - return &block->back(); - - Operation *endOperation = &block->front(); - for (Operation *useOp : opList) { - // Find the associated operation in the current block (if any). - useOp = block->findAncestorOpInBlock(*useOp); - // Check whether the use is in our block and after the current end - // operation. - if (useOp && endOperation->isBeforeInBlock(useOp)) endOperation = useOp; - } - return endOperation; - } - - /// The current Value. - ValueSetT values; - - /// The list of all operations used by the values. - OperationListT opList; - - /// The result list of the userange computation. - OperationListT currentUserange; - - /// The set of visited blocks during the userange computation. - SmallPtrSet visited; - - /// The set of blocks that the userange computation started from. - SmallPtrSet startBlocks; - - /// The current liveness info. - Liveness liveness; -}; -} // namespace - -/// Empty UseInterval Constructor. -UseInterval::UseInterval() - : start(std::numeric_limits::max()), - end(std::numeric_limits::min()) {} - -/// Performs an interval subtraction => A = A - B. -void UseInterval::intervalSubtract(UseInterval::Vector &a, - const UseInterval::Vector &b) { - const auto *iterB = b.begin(); - const auto *endB = b.end(); - for (auto *iterA = a.begin(); iterA != a.end() && iterB != endB;) { - // iterA is strictly before iterB => increment iterA. - if (*iterA < *iterB) { - ++iterA; - // iterB is strictly before iterA => increment iterB. - } else if (*iterA > *iterB) { - ++iterB; - // iterB overlaps with the start of iterA, but iterA has some values that - // go beyond those of iterB. We have to set the start of iterA to the end - // of iterB + 1 and increment iterB. A(3, 100) - B(3, 5) => A(6,100) - } else if (iterA->start >= iterB->start && iterA->end > iterB->end) { - iterA->start = iterB->end + 1; - ++iterB; - // iterB overlaps with the end of iterA, but iterA has some values that - // come before iterB. We have to set the end of iterA to the start of - // iterB - 1 and increment iterA. A(4, 50) - B(40, 50) => A(4, 39) - } else if (iterA->end <= iterB->end && iterA->start < iterB->start) { - iterA->end = iterB->start - 1; - ++iterA; - // iterB is in the middle of iterA. We have to split iterA and increment - // iterB. - // A(2, 10) - B(5, 7) => (2, 4), (8, 10) - } else if (iterA->start < iterB->start && iterA->end > iterB->end) { - size_t endA = iterA->end; - iterA->end = iterB->start - 1; - iterA = a.insert(iterA, UseInterval(iterB->end + 1, endA)); - ++iterB; - // Both intervals are equal. We have to erase the whole interval. - // A(5, 5) - B(5, 5) => {} - } else { - iterA = a.erase(iterA); - ++iterB; - } - } -} - -/// Performs an interval intersection => A = A ^ B. -void UseInterval::intervalIntersect(UseInterval::Vector &a, - const UseInterval::Vector &b) { - const auto *iterB = b.begin(); - const auto *endB = b.end(); - for (auto *iterA = a.begin(); iterA != a.end();) { - // iterB points to the end, therefore the remaining UseIntervals from A must - // be erased or iterA is strictly before iterB => erase iterA. - if (iterB == endB || *iterA < *iterB) { - iterA = a.erase(iterA); - // iterB is strictly before iterA => increment iterB. - } else if (*iterA > *iterB) { - ++iterB; - // iterB overlaps with iterA => reduce the interval to the overlap and - // insert the ending split-off to vector A again. - } else { - size_t currentEndA = iterA->end; - iterA->start = std::max(iterA->start, iterB->start); - iterA->end = std::min(currentEndA, iterB->end); - if (currentEndA > iterB->end) { - iterA = a.insert(std::next(iterA), - UseInterval(iterB->end + 1, currentEndA)); - ++iterB; - } else { - ++iterA; - } - } - } -} - -/// Performs an interval merge => A = A u B. -/// Note: All overlapping and contiguous UseIntervals are merged. -void UseInterval::intervalMerge(UseInterval::Vector &a, - const UseInterval::Vector &b) { - const auto *iterB = b.begin(); - const auto *endB = b.end(); - // Iterate over UseInterval::Vector a and b. - for (auto *iterA = a.begin(); iterA != a.end() && iterB != endB;) { - // Let A be the UseInterval of iterA and B the UseInterval of iterB. - // Check if A is before B. - if (*iterA < *iterB) { - // Check if A and B can be merged if they are contiguous. If the merge - // result contains the next elements of A, we can erase them. - if (iterA->isContiguous(*iterB)) { - mergeAndEraseContiguousIntervals(a, iterA, *iterB); - ++iterB; - } - ++iterA; - // Check if B is before A. - } else if (*iterA > *iterB) { - // Check if A and B can be merged if they are contiguous, else add B - // to the Vector of A. - if (iterB->isContiguous(*iterA)) - iterA->mergeWith(*iterB); - else - iterA = a.insert(iterA, *iterB); - ++iterB; - // The UseIntervals interfere and must be merged. - } else { - mergeAndEraseContiguousIntervals(a, iterA, *iterB); - ++iterB; - } - } - // If there are remaining UseIntervals in b, add them to a. - if (iterB != endB) a.insert(a.end(), iterB, endB); -} - -/// Merge the UseIntervals and erase overlapping and contiguouse UseIntervals -/// of the UseInterval::Vector. -void UseInterval::mergeAndEraseContiguousIntervals( - UseInterval::Vector &interval, UseInterval *iter, - const UseInterval &toMerge) { - // Return if the iter points to the end. - if (iter == interval.end()) return; - - // Merge the UseIntervals. - iter->mergeWith(toMerge); - - // Find the next UseInterval from iter that is not contiguous with the merged - // iter. - UseInterval *next = std::next(iter); - while (next != interval.end() && iter->isContiguous(*next)) { - if (iter->end < next->end) iter->end = next->end; - ++next; - } - // Remove contiguous UseIntervals. - if (std::next(iter) != next) iter = interval.erase(std::next(iter), next); -} - -UserangeAnalysis::UserangeAnalysis( - Operation *op, const bufferization::BufferPlacementAllocs &allocs, - const BufferViewFlowAnalysis &aliases) - : liveness(op) { - // Walk over all operations and map them to an ID. - op->walk([&](Operation *operation) { - gatherMemoryEffects(operation); - operationIds.insert({operation, operationIds.size()}); - operations.push_back(operation); - }); - - // Compute the use range for every allocValue and its aliases. Merge them - // and compute an interval. Add all computed intervals to the useIntervalMap. - for (const bufferization::BufferPlacementAllocs::AllocEntry &entry : allocs) { - Value allocValue = std::get<0>(entry); - const Value::use_range &allocUses = allocValue.getUses(); - size_t dist = std::distance(allocUses.begin(), allocUses.end()); - OperationListT useList; - useList.reserve(dist); - for (auto &use : allocUses) useList.push_back(use.getOwner()); - computeUsePositions(allocValue); - - UserangeInfoBuilder builder(liveness, {allocValue}, useList); - OperationListT liveOperations = builder.computeUserange(); - - // Sort the operation list by ids. - std::sort(liveOperations.begin(), liveOperations.end(), - [&](Operation *left, Operation *right) { - return operationIds[left] < operationIds[right]; - }); - - UseInterval::Vector allocInterval = - computeInterval(allocValue, liveOperations); - // Iterate over all aliases and add their useranges to the userange of the - // current value. Also add the useInterval of each alias to the - // useIntervalMap. - ValueSetT aliasSet = aliases.resolve(allocValue); - for (Value alias : aliasSet) { - if (alias == allocValue) continue; - if (!aliasUseranges.count(alias)) { - OperationListT aliasOperations; - // If the alias is a BlockArgument then the value is live with the first - // operation inside that block. Otherwise the liveness analysis is - // sufficient for the use range. - if (alias.isa()) { - aliasOperations.push_back(&alias.getParentBlock()->front()); - for (auto &use : alias.getUses()) - aliasOperations.push_back(use.getOwner()); - // Compute the use range for the alias and sort the operations - // afterwards. - UserangeInfoBuilder aliasBuilder(liveness, {alias}, aliasOperations); - aliasOperations = aliasBuilder.computeUserange(); - std::sort(aliasOperations.begin(), aliasOperations.end(), - [&](Operation *left, Operation *right) { - return operationIds[left] < operationIds[right]; - }); - } else { - aliasOperations = liveness.resolveLiveness(alias); - } - - aliasUseranges.insert({alias, aliasOperations}); - useIntervalMap.insert( - {alias, computeInterval(alias, aliasUseranges[alias])}); - computeUsePositions(alias); - } - UseInterval::intervalMerge(allocInterval, useIntervalMap[alias]); - mergeUsePositions(usePositionMap[allocValue], usePositionMap[alias]); - } - aliasCache.insert(std::make_pair(allocValue, aliasSet)); - - // Map the current allocValue to the computed useInterval. - useIntervalMap.insert(std::make_pair(allocValue, allocInterval)); - } -} - -/// Computes the doubled Id for the given value inside the operation based on -/// the program sequence. If the value has only read effects, the returning ID -/// will be even, otherwise odd. -size_t UserangeAnalysis::computeId(Value v, Operation *op) const { - size_t doubledID = (operationIds.find(op)->second + 1) * 2 - 1; - auto mapIter = opReadWriteMap.find(op); - if (mapIter == opReadWriteMap.end()) return doubledID; - auto reads = mapIter->second.first; - auto writes = mapIter->second.second; - if (reads.contains(v) && !writes.contains(v)) return doubledID - 1; - return doubledID; -} - -/// Computes the UsePositions of the given Value, sorts and inserts them into -/// the usePositionMap. -void UserangeAnalysis::computeUsePositions(Value v) { - // Get the uses of v. - const Value::use_range &uses = v.getUses(); - - // Create a UsePositionList. - UsePositionList usePosList; - size_t dist = std::distance(uses.begin(), uses.end()); - usePosList.reserve(dist); - - // Add all ids and Operations to the UsePositionList. - for (auto &use : uses) { - Operation *useOwner = use.getOwner(); - usePosList.emplace_back(computeId(v, useOwner), useOwner); - } - - // Sort the UsePositions by ascending Ids. - std::sort(usePosList.begin(), usePosList.end(), - [](const UsePosition &a, const UsePosition &b) { - return a.first < b.first; - }); - - // Insert the UsePositionList into the usePositionMap. - usePositionMap.insert(std::make_pair(v, usePosList)); -} - -/// Merges listB into listA, sorts the result and removes all duplicates. -void UserangeAnalysis::mergeUsePositions(UsePositionList &listA, - const UsePositionList &listB) { - // Insert listB into listA. - listA.insert(listA.end(), listB.begin(), listB.end()); - - // Sort the resulting listA. - std::sort(listA.begin(), listA.end(), - [](const UsePosition &a, const UsePosition &b) { - return a.first < b.first; - }); - - // Remove duplicates. - listA.erase(std::unique(listA.begin(), listA.end()), listA.end()); -} - -/// Checks if the use intervals of the given values interfere. -bool UserangeAnalysis::rangesInterfere(Value itemA, Value itemB) const { - ValueSetT intersect = aliasCache.find(itemA)->second; - llvm::set_intersect(intersect, aliasCache.find(itemB)->second); - UseInterval::Vector tmpIntervalA = useIntervalMap.find(itemA)->second; - const UseInterval::Vector &intervalsB = useIntervalMap.find(itemB)->second; - - // If the two values share a common alias, then the alias does not count as an - // interference and should be removed. - if (!intersect.empty()) { - for (Value alias : intersect) { - const UseInterval::Vector &aliasInterval = - useIntervalMap.find(alias)->second; - UseInterval::intervalSubtract(tmpIntervalA, aliasInterval); - } - } - - // Iterate over both UseInterval::Vector and check if they interfere. - const auto *iterB = intervalsB.begin(); - const auto *endB = intervalsB.end(); - for (auto iterA = tmpIntervalA.begin(), endA = tmpIntervalA.end(); - iterA != endA && iterB != endB;) { - if (*iterA < *iterB) - ++iterA; - else if (*iterA > *iterB) - ++iterB; - else - return true; - } - return false; -} - -/// Merges the userange of itemB into the userange of itemA. -void UserangeAnalysis::unionRanges(Value itemA, Value itemB) { - UseInterval::intervalMerge(useIntervalMap[itemA], useIntervalMap[itemB]); -} - -/// Builds an UseInterval::Vector corresponding to the given OperationList. -UseInterval::Vector UserangeAnalysis::computeInterval( - Value value, const Liveness::OperationListT &operationList) { - assert(!operationList.empty() && "Operation list must not be empty"); - size_t start = computeId(value, *operationList.begin()); - size_t last = start; - UseInterval::Vector intervals; - // Iterate over all operations in the operationList. If the gap between the - // respective operationIds is greater 1 create a new interval. - for (auto opIter = ++operationList.begin(), e = operationList.end(); - opIter != e; ++opIter) { - size_t current = computeId(value, *opIter); - if (current - last > 2) { - intervals.emplace_back(start, last); - start = current; - } - last = current; - } - intervals.emplace_back(start, last); - return intervals; -} - -/// Checks each operand within the operation for its memory effects and -/// separates them into read and write. -void UserangeAnalysis::gatherMemoryEffects(Operation *op) { - if (OpTrait::hasElementwiseMappableTraits(op)) { - if (auto effectInterface = dyn_cast(op)) { - SmallPtrSet readEffectSet; - SmallPtrSet writeEffectSet; - SmallVector effects; - for (auto operand : op->getOperands()) { - effects.clear(); - effectInterface.getEffectsOnValue(operand, effects); - for (auto effect : effects) { - if (isa(effect.getEffect())) - writeEffectSet.insert(operand); - else if (isa(effect.getEffect())) - readEffectSet.insert(operand); - } - } - opReadWriteMap.insert( - {op, std::make_pair(readEffectSet, writeEffectSet)}); - } - } -} - -/// Computes the doubled Id back to the OperationId. -size_t UserangeAnalysis::unwrapId(size_t id) const { return id / 2; } - -void UserangeAnalysis::dump(raw_ostream &os) { - os << "// ---- UserangeAnalysis -----\n"; - llvm::SmallVector values; - values.reserve(useIntervalMap.size()); - for (auto const &item : useIntervalMap) { - values.push_back(item.first); - } - std::sort(values.begin(), values.end(), [&](Value left, Value right) { - if (left.getDefiningOp()) { - if (right.getDefiningOp()) - return operationIds[left.getDefiningOp()] < - operationIds[right.getDefiningOp()]; - return true; - } - if (right.getDefiningOp()) return false; - return operationIds[&left.getParentBlock()->front()] < - operationIds[&right.getParentBlock()->front()]; - }); - for (auto value : values) { - os << "Value: " << value << (value.getDefiningOp() ? "\n" : ""); - auto *rangeIt = useIntervalMap[value].begin(); - os << "Userange: {(" << rangeIt->start << ", " << rangeIt->end << ")"; - rangeIt++; - for (auto *e = useIntervalMap[value].end(); rangeIt != e; ++rangeIt) { - os << ", (" << rangeIt->start << ", " << rangeIt->end << ")"; - } - os << "}\n"; - } - os << "// ---------------------------\n"; -} diff --git a/third_party/xla/xla/mlir_hlo/analysis/userange_analysis.h b/third_party/xla/xla/mlir_hlo/analysis/userange_analysis.h deleted file mode 100644 index 3e04e4bbc7bf71..00000000000000 --- a/third_party/xla/xla/mlir_hlo/analysis/userange_analysis.h +++ /dev/null @@ -1,206 +0,0 @@ -/* Copyright 2021 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef MLIR_HLO_ANALYSIS_USERANGE_ANALYSIS_H -#define MLIR_HLO_ANALYSIS_USERANGE_ANALYSIS_H - -#include -#include - -#include "mlir/Analysis/Liveness.h" -#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/Value.h" - -namespace mlir { - -/// Representation of an inclusive Interval for the Userange. -struct UseInterval { - using Vector = SmallVector; - - public: - /// UseInterval Constructor. - UseInterval(); - /// Empty UseInterval Constructor. - UseInterval(size_t start, size_t end) : start(start), end(end) {} - - /// Checks if the given UseInterval overlaps with this UseInterval. - bool isOverlapping(const UseInterval &other) const { - return start <= other.end && end >= other.start; - } - - /// Checks if the given UseInterval is contiguous with this UseInterval in - /// terms of doubled Ids. - /// For example: (0, 2) and (4, 6) are contiguous where (0, 2) and (5, 6) are - /// not. - bool isContiguous(const UseInterval &other) const { - return start <= other.end + 2 && end + 2 >= other.start; - } - - /// Checks if the given position is inside this UseInterval. - bool contains(size_t position) const { - return start <= position && end >= position; - } - - /// Merges this UseInterval with the given UseInterval by updating start and - /// end. - bool mergeWith(const UseInterval &other) { - if (!isContiguous(other)) return false; - start = std::min(start, other.start); - end = std::max(end, other.end); - return true; - } - - /// Performs an interval subtraction => A = A - B. - static void intervalSubtract(Vector &a, const Vector &b); - - /// Performs an interval intersection => A = A ^ B. - static void intervalIntersect(Vector &a, const Vector &b); - - /// Performs an interval merge => A = A u B. - /// Note: All overlapping and contiguous UseIntervals are merged. - static void intervalMerge(Vector &a, const Vector &b); - - /// Merge the UseIntervals and erase overlapping and contiguouse UseIntervals - /// of the UseInterval::Vector. - static void mergeAndEraseContiguousIntervals(Vector &interval, - UseInterval *iter, - const UseInterval &toMerge); - - bool operator<(const UseInterval &other) const { return end < other.start; } - - bool operator>(const UseInterval &other) const { return start > other.end; } - - bool operator==(const UseInterval &other) const { - return start == other.start && end == other.end; - } - - /// The start of this UseInterval. - size_t start; - - /// The end of this UseInterval. - size_t end; -}; - -/// Represents an analysis for computing the useranges of all alloc values -/// inside a given function operation. The analysis uses liveness information to -/// compute intervals starting at the first and ending with the last use of -/// every alloc value. -class UserangeAnalysis { - public: - using UsePosition = std::pair; - using UsePositionList = std::vector; - - UserangeAnalysis(Operation *op, - const bufferization::BufferPlacementAllocs &allocs, - const BufferViewFlowAnalysis &aliases); - - /// Returns the index of the first operation that uses the given value or an - /// empty Optional if the value has no uses. - std::optional getFirstUseIndex(Value value) const { - auto &intervals = useIntervalMap.find(value)->second; - if (intervals.empty()) return std::nullopt; - return intervals.begin()->start; - } - - /// Returns the UseInterval::Vector of the given value. - std::optional getUserangeInterval( - Value value) const { - auto intervals = useIntervalMap.find(value); - if (intervals == useIntervalMap.end()) return std::nullopt; - return &intervals->second; - } - - /// Returns an UsePositionList* of the given value or an empty Optional - /// if the value has no uses. - std::optional getUserangePositions( - Value value) const { - auto usePosition = usePositionMap.find(value); - if (usePosition == usePositionMap.end() || usePosition->second.empty()) - return std::nullopt; - return &usePosition->second; - } - - /// Returns the operation associated with a given Id. - Operation *getOperation(size_t id) const { return operations[unwrapId(id)]; }; - - /// Computes the doubled Id for the given value inside the operation based on - /// the program sequence. If the value has only read effects, the returning ID - /// will be even, otherwise odd. - size_t computeId(Value v, Operation *op) const; - - /// Checks if the use intervals of the given values interfere. - bool rangesInterfere(Value itemA, Value itemB) const; - - /// Merges the userange of itemB into the userange of itemA. - void unionRanges(Value itemA, Value itemB); - - /// Merges listB into listA, sorts the result and removes all duplicates. - static void mergeUsePositions(UsePositionList &listA, - const UsePositionList &listB); - - /// Dumps the liveness information to the given stream. - void dump(raw_ostream &os); - - private: - using ValueSetT = BufferViewFlowAnalysis::ValueSetT; - using OperationListT = Liveness::OperationListT; - - /// Builds an UseInterval::Vector corresponding to the given OperationList. - UseInterval::Vector computeInterval( - Value value, const Liveness::OperationListT &operationList); - - /// Computes the UsePositions of the given Value, sorts and inserts them into - /// the usePositionMap. - void computeUsePositions(Value v); - - /// Checks each operand within the operation for its memory effects and - /// separates them into read and write. - void gatherMemoryEffects(Operation *op); - - /// Computes the doubled Id back to the OperationId. - size_t unwrapId(size_t id) const; - - /// Maps each Operation to a unique ID according to the program sequence. - DenseMap operationIds; - - /// Stores all operations according to the program sequence. - std::vector operations; - - /// Maps a value to its UseInterval::Vector. - DenseMap useIntervalMap; - - /// Maps an Operation to a pair of read and write Operands. - DenseMap, SmallPtrSet>> - opReadWriteMap; - - /// Maps aliasValues to their use ranges. This is necessary to prevent - /// recomputations of the use range intervals of the aliases. - DenseMap aliasUseranges; - - /// Maps a Value to a UsePostionList which contains all uses of the Value and - /// their userange position. - DenseMap usePositionMap; - - /// Cache the alias lists for all values to avoid recomputation. - BufferViewFlowAnalysis::ValueMapT aliasCache; - - /// The current liveness info. - Liveness liveness; -}; - -} // namespace mlir - -#endif // MLIR_HLO_ANALYSIS_USERANGE_ANALYSIS_H diff --git a/third_party/xla/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_parallel_loops/lhlo_legalize_to_parallel_loops.cc b/third_party/xla/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_parallel_loops/lhlo_legalize_to_parallel_loops.cc index 81adcc7e9634cd..f3c6b3940aced5 100644 --- a/third_party/xla/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_parallel_loops/lhlo_legalize_to_parallel_loops.cc +++ b/third_party/xla/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_parallel_loops/lhlo_legalize_to_parallel_loops.cc @@ -453,12 +453,14 @@ class ReduceWindowOpConverter loc, inputType.getElementType(), mappedIvs.inBounds, /*withElseRegion=*/true); - OpBuilder thenBuilder = elemOrInit.getThenBodyBuilder(rewriter); + OpBuilder thenBuilder = + elemOrInit.getThenBodyBuilder(rewriter->getListener()); Value elem = thenBuilder.create(loc, input, mappedIvs.ivs); thenBuilder.create(loc, elem); - OpBuilder elseBuilder = elemOrInit.getElseBodyBuilder(rewriter); + OpBuilder elseBuilder = + elemOrInit.getElseBodyBuilder(rewriter->getListener()); elseBuilder.create(loc, *windowLoop.getInitVals().begin()); return rewriter->create(loc, diff --git a/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td b/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td index 7b6de62b1eab0f..e1c305979e9496 100644 --- a/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td +++ b/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td @@ -333,16 +333,20 @@ def LHLOGPU_AllToAllDoneOp: LHLOGPU_Op<"all_to_all_done"> { def LHLOGPU_CudnnNormOp : LHLOGPU_Op<"Norm", [AttrSizedOperandSegments]> { let arguments = (ins - Arg:$input, + Arg:$x, Arg:$scale, - Arg:$bias, - Arg:$output, - Arg, "", [MemWrite]>:$expectation, - Arg, "", [MemWrite]>:$norm_factor, + Arg:$y_or_dx, + Arg, "", [MemRead]>:$bias, + Arg, "", [MemRead]>:$dy, + Arg, "", [MemRead, MemWrite]>:$expectation, + Arg, "", [MemRead, MemWrite]>:$norm_factor, + Arg, "", [MemWrite]>:$dscale, + Arg, "", [MemWrite]>:$dbias, Arg:$scratch, NormAlgorithmConfigAttr:$algorithm_config, F64Attr:$epsilon, - I64ArrayAttr:$operand_layouts + I64ArrayAttr:$operand_layouts, + CudnnNormKindAttr:$kind ); } diff --git a/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_enums.td b/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_enums.td index 2c365d94deecc2..480025bfe3c2e9 100644 --- a/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_enums.td +++ b/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_enums.td @@ -121,6 +121,19 @@ def NormAlgorithmConfigAttr : AttrDef< let summary = "GPU Norm Algorithm configuration"; } +def CudnnNormKindLayerFwdInfer : I32EnumAttrCase<"LayerFwdInfer", 0>; +def CudnnNormKindLayerFwdTrain : I32EnumAttrCase<"LayerFwdTrain", 1>; +def CudnnNormKindLayerBwd : I32EnumAttrCase<"LayerBwd", 2>; + +def CudnnNormKind: I32EnumAttr<"CudnnNormKind", + "Mode of cuDNN norm", + [CudnnNormKindLayerFwdInfer, CudnnNormKindLayerFwdTrain, CudnnNormKindLayerBwd]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::lmhlo_gpu"; +} + +def CudnnNormKindAttr : EnumAttr; + def FusedMHAAlgorithmConfigAttr : AttrDef< LmhloGpuDialect, "FusedMHAAlgorithmConfig"> { let mnemonic = "fHMA_algorithm_config"; diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc index 5c4ab30e7fe021..78c2ada2639b3b 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc @@ -169,16 +169,6 @@ hlo::HloDialectInterface* getMhloDialect(MLIRContext* context) { return dialect->getRegisteredInterface(); } -void createArgs(ArrayRef operands, - ArrayRef types, - SmallVector& args) { - for (auto argAndType : llvm::zip(operands, types)) { - auto& arg = args.emplace_back(); - arg.ssaName = std::get<0>(argAndType); - arg.type = std::get<1>(argAndType); - } -} - //===----------------------------------------------------------------------===// // Utilities for the canonicalize patterns //===----------------------------------------------------------------------===// @@ -389,6 +379,7 @@ INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CosineOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CrossReplicaSumOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(DivOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(DomainOp) +INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ErfOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ExpOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Expm1Op) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(FloorOp) @@ -2392,6 +2383,16 @@ void BroadcastInDimOp::getCanonicalizationPatterns(RewritePatternSet& results, //===----------------------------------------------------------------------===// LogicalResult DynamicBroadcastInDimOp::verify() { + // Check for unranked dynamism. Unranked dynamism is not supported by + // StableHLO (hlo::verifyReshapeOp will fail) and we can't verify + // anything statically in that case anyway. + auto outputdimensionsType = + getOutputDimensions().getType().cast(); + auto resultType = getResult().getType().cast(); + if (!outputdimensionsType.hasRank() || !resultType.hasRank()) { + return success(); + } + return hlo::verifyDynamicBroadcastInDimOp( getLoc(), getOperand(), getOutputDimensions(), llvm::to_vector(getBroadcastDimensions().getValues()), @@ -2854,6 +2855,13 @@ LogicalResult ConcatenateOp::reifyReturnTypeShapes( //===----------------------------------------------------------------------===// LogicalResult DynamicReshapeOp::verify() { + // Check for unranked dynamism. Unranked dynamism is not supported by + // StableHLO (hlo::verifyDynamicReshapeOp will fail) and we can't verify + // anything statically in that case anyway. + auto resultType = getResult().getType().cast(); + auto outputShapeType = getOutputShape().getType().cast(); + if (!resultType.hasRank() || !outputShapeType.hasStaticShape()) + return success(); return hlo::verifyDynamicReshapeOp(getLoc(), getOutputShape(), getResult()); } @@ -3603,304 +3611,17 @@ bool hasSameOperandAndResultTypes(Operation& op) { llvm::all_of(op.getResultTypes(), typeMatch); } -// Checks the following eligibility criteria for compact printing of -// mhlo.reduce: -// E1. The reduce-op wraps a single inner-op in the associated region. -// E2. The single operation is a commutative binary-op from mhlo dialect, zero -// region, producing single result such that the operands and result all -// have the same type. -// E3. The reduce-op consist of at least one input-operand; The operand-types of -// inner-op should be derived trivially from the element-type of reduce-op's -// first input-operand. -// E4. The arguments of the region's only basic block are forwarded perfectly -// to inner-op's operands. -// E5. The reduce-op, inner-op, blocks arguments, and the return-op all have the -// same location. -// E6. The single operation result is perfectly forwarded to the reduce op -// return. -static bool isEligibleForCompactPrint(ReduceOp op) { - // Check E1. - auto& block = op.getBody().front(); - if (!hasSingleElement(block.without_terminator())) return false; - - Operation& innerOp = *block.begin(); - - // Check E2. - if (innerOp.getDialect() != op->getDialect()) return false; - - if (innerOp.getNumOperands() != 2 || - !innerOp.hasTrait() || - !hasSameOperandAndResultTypes(innerOp) || - !innerOp.hasTrait() || - !innerOp.hasTrait()) - return false; - - // Check E3. - if (op.getInputs().empty()) return false; - - auto elemType = - op.getInputs()[0].getType().cast().getElementType(); - auto expectedInnerOpType = RankedTensorType::get(/*shape=*/{}, elemType); - if (innerOp.getOperands()[0].getType() != expectedInnerOpType) return false; - - // Check E4. - if (!llvm::equal(block.getArguments(), innerOp.getOperands())) return false; - - // Check E5. - auto retOp = dyn_cast(block.getTerminator()); - if (!retOp) return false; - - auto blockArgLoc = block.getArgument(0).getLoc(); - if (blockArgLoc != block.getArgument(1).getLoc()) return false; - - if (innerOp.getLoc() != op.getLoc() || retOp.getLoc() != op.getLoc() || - blockArgLoc != op.getLoc()) - return false; - - // Check E6. - return llvm::equal(innerOp.getResults(), retOp.getOperands()); -} - void ReduceOp::print(OpAsmPrinter& p) { - { - // Print the pairs of operands under the form: - // (%arg0 init: %arg3), (%arg1 init: %arg4), (%arg2 init: %arg5) - StringRef comma = ""; - int numOperandPairs = getNumOperands() / 2; - for (int opId : llvm::seq(0, numOperandPairs)) { - p << comma << "(" << getOperand(opId) - << " init: " << getOperand(opId + numOperandPairs) << ")"; - comma = ", "; - } - } - - // If the reduce-op is eligible for compact printing, we emit the one-liner: - // mhlo.reduce applies across dimensions = [...] : - // Note: We are not printing the function type of reduction operation. We - // have some simplifying assumptions (refer to IsEligibleForCompactPrint::E3) - // to derive the type from that of reduce-op. - if (isEligibleForCompactPrint(*this)) { - Operation& innerOp = getBody().front().front(); - p << " applies "; - printEscapedString(innerOp.getName().getStringRef(), p.getStream()); - - p << " across dimensions = ["; - llvm::interleaveComma(getDimensions().getValues(), p); - p << "]"; - p << " : "; - p.printFunctionalType(*this); - } else { - p << " across dimensions = ["; - llvm::interleaveComma(getDimensions().getValues(), p); - p << "]"; - p.printOptionalAttrDict(getOperation()->getAttrs(), {"dimensions"}); - p << " : "; - p.printFunctionalType(*this); - p.printNewline(); - p << " reducer"; - { - // Print the pairs of block operands under the form: - // (%arg0_elt, %arg0_acc) (%arg1_elt, %arg1_acc): - Block& reducer = getBody().front(); - int numOperandPairs = getNumOperands() / 2; - for (int opId : llvm::seq(0, numOperandPairs)) { - p << "("; - p.printRegionArgument(reducer.getArgument(opId)); - p << ", "; - p.printRegionArgument(reducer.getArgument(opId + numOperandPairs)); - p << ") "; - } - } - p << ' '; - p.printRegion(getBody(), /*printEntryBlockArgs=*/false); - } + auto dimensions = llvm::to_vector(getDimensions().getValues()); + hlo::printReduceOp(p, getOperation(), getInputs(), dimensions, getBody()); } ParseResult ReduceOp::parse(OpAsmParser& parser, OperationState& result) { - llvm::SMLoc loc = parser.getCurrentLocation(); - Location currLocation = parser.getEncodedSourceLoc(loc); - - // Parse the operands of reduce-op, this is a list of pair under the form: - // (%arg0 init: %arg3), (%arg1 init: %arg4), (%arg2 init: %arg5) - // Each input to reduce is paired with its init value, even though in memory - // they are stored with the input first and the init values after. - SmallVector operands; - SmallVector initOperands; - do { - (void)parser.parseOptionalComma(); - if (parser.parseOptionalLParen()) break; - OpAsmParser::UnresolvedOperand operand, initOperand; - if (parser.parseOperand(operand) || parser.parseKeyword("init") || - parser.parseColon() || parser.parseOperand(initOperand) || - parser.parseRParen()) - return failure(); - operands.push_back(operand); - initOperands.push_back(initOperand); - } while (true); - operands.append(initOperands); - - // Check if we are parsing the compact version of reduce-op: - // mhlo.reduce applies across dimensions = [...] : - // else parse the "region-based" variant. - if (failed(parser.parseOptionalKeyword("applies"))) { - // Parse the inner-op dimensions, reduce-op's function-type and - // optional location. - SmallVector dimensions; - auto parseDim = [&]() -> ParseResult { - if (parser.parseInteger(dimensions.emplace_back())) return failure(); - return success(); - }; - - FunctionType reduceOpFntype; - if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") || - parser.parseEqual() || - parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, - parseDim) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColon() || parser.parseType(reduceOpFntype) || - parser.parseKeyword("reducer")) - return failure(); - OpBuilder builder(parser.getBuilder().getContext()); - result.addAttribute("dimensions", builder.getI64TensorAttr(dimensions)); - - // Parse the "reducer" region now. - SmallVector reducerOperands; - SmallVector reducerInitOperands; - SmallVector reducerTypes; - SmallVector reducerInitTypes; - SmallVector, 2> reducerLocs; - SmallVector, 2> reducerInitLocs; - auto parseBlockOperand = - [&](SmallVectorImpl& operands, - SmallVectorImpl& types, - SmallVectorImpl>& locs) -> ParseResult { - OpAsmParser::UnresolvedOperand operand; - Type type; - std::optional loc; - if (parser.parseOperand(operand, /*allowResultNumber=*/false) || - parser.parseColon() || parser.parseType(type) || - parser.parseOptionalLocationSpecifier(loc)) - return failure(); - operands.push_back(operand); - types.push_back(type); - locs.push_back(loc); - return success(); - }; - do { - if (failed(parser.parseOptionalLParen())) break; - if (parseBlockOperand(reducerOperands, reducerTypes, reducerLocs) || - parser.parseComma() || - parseBlockOperand(reducerInitOperands, reducerInitTypes, - reducerInitLocs) || - parser.parseRParen()) - return failure(); - } while (true); - reducerOperands.append(reducerInitOperands); - reducerTypes.append(reducerInitTypes); - reducerLocs.append(reducerInitLocs); - result.addTypes(reduceOpFntype.getResults()); - SmallVector reducerArgs; - createArgs(reducerOperands, reducerTypes, reducerArgs); - - // Derive the SSA-values for reduce-op's operands and parse the region, and - // the optional trailing location. - std::optional trailingLoc; - if (parser.resolveOperands(operands, reduceOpFntype.getInputs(), loc, - result.operands) || - parser.parseRegion(*result.addRegion(), reducerArgs)) - return failure(); - // Set the individual block arguments. - for (auto argAndLoc : - llvm::zip(result.regions.front()->front().getArguments(), reducerLocs)) - if (std::get<1>(argAndLoc)) - std::get<0>(argAndLoc).setLoc(std::get<1>(argAndLoc).value()); - result.location = trailingLoc.value_or(currLocation); - return success(); - } - - // Parse the inner-op name and check if the contract on inner-op - // mentioned in "isEligibleForCompactPrint::E2" for pretty-priting is met. - FailureOr innerOpNameInfo = parser.parseCustomOperationName(); - if (failed(innerOpNameInfo)) return failure(); - - StringRef innerOpName = innerOpNameInfo->getStringRef(); - Dialect* innerOpDialect = innerOpNameInfo->getDialect(); - if (!innerOpDialect || !innerOpDialect->getNamespace().equals("mhlo") || - !innerOpNameInfo->hasTrait::Impl>() || - !innerOpNameInfo->hasTrait() || - !innerOpNameInfo->hasTrait() || - !innerOpNameInfo->hasTrait()) { - parser.emitError(loc, - "expected the inner-op to be a commutative binary-op from " - "mhlo dialect, zero region, producing single result"); - return failure(); - } - - // Parse the inner-op dimensions, reduce-op's function-type and - // optional location. - SmallVector dimensions; - auto parseDim = [&]() -> ParseResult { - if (parser.parseInteger(dimensions.emplace_back())) return failure(); - return success(); + auto parseDenseElements = [](OpBuilder& b, + ArrayRef dims) -> Attribute { + return b.getI64TensorAttr(dims); }; - - std::optional explicitLoc; - FunctionType reduceOpFntype; - if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") || - parser.parseEqual() || - parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, parseDim) || - parser.parseColon() || parser.parseType(reduceOpFntype) || - parser.parseOptionalLocationSpecifier(explicitLoc)) - return failure(); - - if (!reduceOpFntype || reduceOpFntype.getInputs().empty()) { - if (!reduceOpFntype) return parser.emitError(loc, "expected function type"); - return parser.emitError(loc, - "input types missing in reduce-op function type"); - } - - // If location of reduce-op is explicitly provided, then use it; Else use - // the parser's current location. - Location reduceOpLoc = explicitLoc.value_or(currLocation); - - // Derive the SSA-values for reduce-op's operands. - if (parser.resolveOperands(operands, reduceOpFntype.getInputs(), loc, - result.operands)) - return failure(); - - // Derive the type of inner-op from that of reduce-op's input operand. - auto innerOpType = RankedTensorType::get( - /*shape=*/{}, getElementTypeOrSelf(reduceOpFntype.getInput(0))); - - // Add a region for reduce-op. - Region& region = *result.addRegion(); - - // Create a basic-block inside reduce-op's region. - Block& block = region.emplaceBlock(); - auto lhs = block.addArgument(innerOpType, reduceOpLoc); - auto rhs = block.addArgument(innerOpType, reduceOpLoc); - - // Create and insert an "inner-op" operation in the block. - OpBuilder builder(parser.getBuilder().getContext()); - builder.setInsertionPointToStart(&block); - - OperationState innerOpState(reduceOpLoc, innerOpName); - innerOpState.operands.push_back(lhs); - innerOpState.operands.push_back(rhs); - innerOpState.addTypes(innerOpType); - - Operation* innerOp = builder.create(innerOpState); - - // Insert a return statement in the block returning the inner-op's result. - builder.create(innerOp->getLoc(), innerOp->getResults()); - - // Populate the reduce-op operation-state with result-type, location, and - // dimension attribute. - result.addTypes(reduceOpFntype.getResults()); - result.location = innerOp->getLoc(); - result.addAttribute("dimensions", builder.getI64TensorAttr(dimensions)); - - return success(); + return hlo::parseReduceOp(parser, result, parseDenseElements); } LogicalResult ReduceOp::inferReturnTypeComponents( @@ -4604,6 +4325,14 @@ LogicalResult DynamicPadOp::reifyReturnTypeShapes( //===----------------------------------------------------------------------===// LogicalResult ReshapeOp::verify() { + // Check for unranked dynamism. Unranked dynamism is not supported by + // StableHLO (hlo::verifyReshapeOp will fail) and we can't verify + // anything statically in that case anyway. + auto operandType = getOperand().getType().cast(); + auto resultType = getResult().getType().cast(); + if (!operandType.hasRank() || !resultType.hasRank()) { + return success(); + } return hlo::verifyReshapeOp(getLoc(), getOperand(), getResult()); } @@ -4915,6 +4644,7 @@ UNARY_FOLDER_FLOAT(RoundNearestEvenOp, RoundNearestEven) UNARY_FOLDER_FLOAT(RoundOp, Round) UNARY_FOLDER_UPCAST_TO_F64(CosineOp, std::cos, AnyValue) +UNARY_FOLDER_UPCAST_TO_F64(ErfOp, std::erf, AnyValue) UNARY_FOLDER_UPCAST_TO_F64(ExpOp, std::exp, AnyValue) UNARY_FOLDER_UPCAST_TO_F64(LogisticOp, logistic, AnyValue) UNARY_FOLDER_UPCAST_TO_F64(LogOp, std::log, PositiveValue) @@ -6226,70 +5956,12 @@ LogicalResult WhileOp::verify() { return hlo::verifyWhileOp(getLoc(), getOperand(), getCond(), getBody()); } -/// Print a `while` op. -/// -/// op ::= `mhlo.while` `(` assignment-list `)` `:` types attribute-dict -/// `cond` region -/// `do` region -/// assignment-list ::= assignment | assignment `,` assignment-list -/// assignment ::= ssa-value `=` ssa-value void WhileOp::print(OpAsmPrinter& p) { - p << '('; - llvm::interleaveComma( - llvm::zip(SingleBlock::getBody()->getArguments(), getOperands()), p, - [&](auto zip) { - p.printOperand(std::get<0>(zip)); - p << " = "; - p.printOperand(std::get<1>(zip)); - }); - p << ")"; - if (getNumOperands()) { - p << " : "; - llvm::interleaveComma(getOperandTypes(), p); - } - p.printOptionalAttrDictWithKeyword(getOperation()->getAttrs()); - p.printNewline(); - p << " cond "; - p.printRegion(getRegion(0), /*printEntryBlockArgs=*/false); - p << " do "; - p.printRegion(getRegion(1), /*printEntryBlockArgs=*/false); + hlo::printWhileOp(p, getOperation(), getCond(), getBody()); } ParseResult WhileOp::parse(OpAsmParser& parser, OperationState& result) { - llvm::SMLoc loc = parser.getCurrentLocation(); - // Parse the operands of the while: these are of the form: - // %iter_arg = %init_val - // where %iter_arg is the name of the block argument in the cond/body blocks - // and %init_val is the actual operand. - SmallVector operands; - SmallVector iterArgs; - if (parser.parseLParen()) return failure(); - do { - if (succeeded(parser.parseOptionalRParen())) break; - OpAsmParser::UnresolvedOperand operand, iterArg; - if (parser.parseOperand(iterArg) || parser.parseEqual() || - parser.parseOperand(operand)) - return failure(); - iterArgs.push_back(iterArg); - operands.push_back(operand); - if (succeeded(parser.parseOptionalRParen())) break; - if (failed(parser.parseComma())) return failure(); - } while (true); - if (!operands.empty()) { - if (parser.parseColon() || parser.parseTypeList(result.types)) - return failure(); - } - - SmallVector args; - createArgs(iterArgs, result.types, args); - if (parser.resolveOperands(operands, result.types, loc, result.operands) || - parser.parseOptionalAttrDictWithKeyword(result.attributes) || - parser.parseKeyword("cond") || - parser.parseRegion(*result.addRegion(), args) || - parser.parseKeyword("do") || - parser.parseRegion(*result.addRegion(), args)) - return failure(); - return success(); + return hlo::parseWhileOp(parser, result); } LogicalResult WhileOp::fold(FoldAdaptor /*adaptor*/, diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td index 307b4652559aba..680b80fc3e94c7 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td @@ -305,6 +305,23 @@ def MHLO_CosineOp: MHLO_UnaryElementwiseOp<"cosine", let hasCustomHLOConverter = 1; } +def MHLO_ErfOp: MHLO_UnaryElementwiseOp<"erf", + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_FpTensor> { + let summary = "Erf operation"; + let description = [{ + Performs element-wise erf operation on `operand` tensor and produces a + `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#erf + + Example: + ```mlir + %result = mhlo.erf %operand : tensor<2x2xf32> + ``` + }]; + let hasFolder = 1; +} def MHLO_ExpOp: MHLO_UnaryElementwiseOp<"exponential", [Pure, HLO_CompatibleOperandsAndResultType], MHLO_FpComplexOrQuantizedIntTensor> { let summary = "Exp operation"; diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_utils.td b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_utils.td index 09d80cd743e928..079e363b826f5d 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_utils.td +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_utils.td @@ -44,6 +44,8 @@ def MHLO_ConstantLikeNegInfValue : NativeCodeCall< def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">; +def NullDenseI64ArrayAttr : NativeCodeCall<"DenseI64ArrayAttr()">; + def BinBroadcastDimensions : NativeCodeCall< "hlo::getBroadcastDimensionsAttr(&$_builder, $0, $1)">; diff --git a/third_party/xla/xla/mlir_hlo/mhlo/analysis/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/mhlo/analysis/CMakeLists.txt index 4fb958c6d687b0..de8c9c223c8254 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/analysis/CMakeLists.txt +++ b/third_party/xla/xla/mlir_hlo/mhlo/analysis/CMakeLists.txt @@ -20,7 +20,6 @@ add_mlir_library(MhloTestAnalysis LMHLOTransformsPassIncGen LINK_LIBS PUBLIC - MLIRHLOAnalysis MLIRAnalysis MLIRPass MLIRTransforms diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt index aa44eaca01a782..cfa9453226b794 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt @@ -70,7 +70,6 @@ add_mlir_library(MhloPasses prepare_for_export/prepare_for_export.cc optimize_mhlo/optimize_mhlo.cc optimize_mhlo/optimize_mhlo_pass.cc - rank_specialization/rank_specialization.cc restrict_max_rank/restrict_max_rank.cc shape_legalize_to_hlo/shape_legalize_to_hlo.cc shape_reification/shape_reification_pass.cc @@ -107,6 +106,26 @@ add_mlir_library(MhloPasses StablehloBroadcastUtils ) +add_mlir_library(MhloQuantToIntConversion + mhlo_quant_legalize_to_int/mhlo_quant_legalize_to_int.cc + + DEPENDS + MLIRhlo_opsIncGen + MLIRMhloPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + LmhloDialect + MhloDialect + MhloTypeConversion + MLIRIR + MLIRPass + MLIRMathDialect + MLIRTransforms + MLIRTransformUtils +) add_mlir_library(MhloToMemrefConversion hlo_legalize_to_memref/hlo_legalize_to_memref.cc @@ -291,6 +310,7 @@ add_library(AllMhloPasses INTERFACE) target_link_libraries(AllMhloPasses INTERFACE ChloPasses MhloPasses + MhloQuantToIntConversion MhloToArithmeticConversion MhloToMemrefConversion MhloToStandard diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc index ad9b94627f3d9d..cb7a6ba5114990 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc @@ -589,7 +589,7 @@ Value materializeErfcApproximationF32(ConversionPatternRewriter &rewriter, erfcApprox); } -struct ConvertErfOp : public OpConversionPattern { +struct BasisConvertErfOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( ErfOp op, OpAdaptor adaptor, @@ -1553,10 +1553,7 @@ struct ConvertSinhOp : public OpConversionPattern { // (tensor<16x16xf32>) -> tensor<16x8xf32> // %6 = "mhlo.slice"(%4) ... // -// TODO(b/284078162): Decide what to do with this pattern given that we now -// have mhlo::TopKOp. No action needed for now given that mhlo::TopKOp is -// currently categorized as `hasPrivateFeaturesNotInStablehlo`. -struct ConvertTopKOp : public OpConversionPattern { +struct BasisConvertTopKOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( TopKOp op, OpAdaptor /*adaptor*/, @@ -1941,6 +1938,14 @@ void populateChloBroadcastingPatterns(MLIRContext *context, context); } +void populateChloLegalizeToHloBasisOpsPatterns(MLIRContext *context, + RewritePatternSet *patterns) { + // Patterns that decompose to a basis set of HLOs + // These are guaranteed to be convertible to StableHLO, but discard some + // higher level information that is useful to XLA compilation. + patterns->add(context); +} + void populateDecomposeChloPatterns(MLIRContext *context, RewritePatternSet *patterns) { populateWithGenerated(*patterns); @@ -1950,14 +1955,12 @@ void populateDecomposeChloPatterns(MLIRContext *context, patterns->add(context); // clang-format on } diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_pass.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_pass.cc index 643d74493ef616..058e6db42fc321 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_pass.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_pass.cc @@ -25,12 +25,15 @@ limitations under the License. #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" #include "stablehlo/dialect/ChloOps.h" namespace mlir { namespace mhlo { #define GEN_PASS_DEF_CHLOLEGALIZETOHLOPASS +#define GEN_PASS_DEF_CHLOLEGALIZETOHLOBASISOPSPASS #include "mhlo/transforms/mhlo_passes.h.inc" namespace { @@ -45,10 +48,6 @@ struct ChloLegalizeToHloPass this->expand_compositions_ = expandCompositions; } - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - void runOnOperation() override { ConversionTarget conversionTarget(getContext()); RewritePatternSet conversionPatterns(&getContext()); @@ -81,6 +80,37 @@ struct ChloLegalizeToHloPass } }; +struct ChloLegalizeToHloBasisOpsPass + : public impl::ChloLegalizeToHloBasisOpsPassBase< + ChloLegalizeToHloBasisOpsPass> { + using ChloLegalizeToHloBasisOpsPassBase::ChloLegalizeToHloBasisOpsPassBase; + + void runOnOperation() override { + ConversionTarget conversionTarget(getContext()); + RewritePatternSet conversionPatterns(&getContext()); + + // Patterns will only be applied to these ops + conversionTarget.addIllegalOp(); + + // Programs with MHLO equivalents to the StableHLO ops are likely bugs + // for users of this expander pass, so best to disallow. + conversionTarget.addIllegalOp(); // TODO: Add ErfOp + + // Given that the resulting patterns should be convertible to StableHLO + // Only MHLO should be legal. + conversionTarget + .addLegalDialect(); + + chlo::populateChloLegalizeToHloBasisOpsPatterns(&getContext(), + &conversionPatterns); + + if (failed(applyPartialConversion(getOperation(), conversionTarget, + std::move(conversionPatterns)))) { + return signalPassFailure(); + } + } +}; + } // namespace std::unique_ptr> createChloLegalizeToHloPass( @@ -89,5 +119,10 @@ std::unique_ptr> createChloLegalizeToHloPass( expandCompositions); } +std::unique_ptr> +createChloLegalizeToHloBasisOpsPass() { + return std::make_unique(); +} + } // namespace mhlo } // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_patterns.td b/third_party/xla/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_patterns.td index 4b05cb1e6fa9a9..937739f2851454 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_patterns.td +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_patterns.td @@ -355,3 +355,9 @@ def : Pat<(CHLO_ConstantOp $v), def : Pat<(CHLO_TanOp $v), (MHLO_TanOp $v)>; + +def : Pat<(CHLO_ErfOp $v), + (MHLO_ErfOp $v)>; + +def : Pat<(CHLO_TopKOp AnyRankedTensor:$v, $k), + (MHLO_TopKOp $v, $k, ConstBoolAttrTrue)>; diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_arithmetic/hlo_legalize_to_arithmetic.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_arithmetic/hlo_legalize_to_arithmetic.cc index a137f36298652b..fead7be62bf1c2 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_arithmetic/hlo_legalize_to_arithmetic.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_arithmetic/hlo_legalize_to_arithmetic.cc @@ -207,6 +207,7 @@ void populateScalarHloToArithmeticConversionPatterns( ScalarHloToArithmeticPattern, ScalarHloToArithmeticPattern, ScalarHloToArithmeticPattern, + ScalarHloToArithmeticPattern, ScalarHloToArithmeticPattern, ScalarHloToArithmeticPattern, ScalarHloToArithmeticPattern, diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc index 0840db93ffa854..8de15bf6c65c1e 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc @@ -147,6 +147,12 @@ std::optional getPublicFeaturesNotInStablehlo(HloOpTy hloOp) { // Version 1: Initial version for TopK. return 1; } + // StableHLO doesn't support TopK yet. + // Proposal: https://github.com/openxla/stablehlo/pull/1593 + if constexpr (std::is_same::value) { + // Version 1: Initial version for ErfOp. + return 1; + } return std::nullopt; } @@ -649,7 +655,8 @@ void populateHloToStablehloPatterns(RewritePatternSet* patterns, #include "stablehlo/dialect/StablehloOps.cpp.inc" >(patterns, converter, context, allowExperimentalFeatures); - populateHloToStablehloCustomCallPatterns( + populateHloToStablehloCustomCallPatterns( patterns, converter, context, allowExperimentalFeatures); } diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_to_linalg/legalize_to_linalg.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_to_linalg/legalize_to_linalg.cc index 7abb36de697b41..a3c786179169c3 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_to_linalg/legalize_to_linalg.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_to_linalg/legalize_to_linalg.cc @@ -3263,7 +3263,7 @@ struct ConvolutionOpGeneralConversion // Finally, create the computation auto inferredMaps = - AffineMap::inferFromExprList({srcExprs, windowExprs, dstExprs}); + AffineMap::inferFromExprList({srcExprs, windowExprs, dstExprs}, ctx); Value emptyTensor = rewriter.create( loc, reshapedResultShape, resultType.getElementType()); @@ -3578,7 +3578,7 @@ struct ReduceWindowOpOnTensorsGenericConversion SmallVector inferredMaps(3, AffineMap::get(ctx)); if (rank > 0) inferredMaps = - AffineMap::inferFromExprList({srcExprs, windowExprs, dstExprs}); + AffineMap::inferFromExprList({srcExprs, windowExprs, dstExprs}, ctx); SmallVector indexingMaps; @@ -4504,6 +4504,7 @@ void populateHloToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgMapConverter, PointwiseToLinalgMapConverter, PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, PointwiseToLinalgMapConverter, PointwiseToLinalgMapConverter, PointwiseToLinalgMapConverter, @@ -4563,6 +4564,7 @@ void populateHloToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h b/third_party/xla/xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h index 5fcb5a6bf8eaab..f724059ff4982e 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h @@ -81,6 +81,10 @@ struct MhloToScalarOp { using COp = ::mlir::complex::CosOp; }; template <> +struct MhloToScalarOp { + using FOp = ::mlir::math::ErfOp; +}; +template <> struct MhloToScalarOp { using FOp = ::mlir::math::ExpOp; using COp = ::mlir::complex::ExpOp; @@ -470,10 +474,9 @@ inline Value mapMhloOpToStdScalarOp( mlir::ImplicitLocOpBuilder b(loc, *builder); // Integer and float types for casting and constant generation. - auto floatType = - argTypes.front().cast().getElementType().cast(); + auto floatType = getElementTypeOrSelf(argTypes.front()).cast(); int64_t nbits = floatType.getWidth(); - auto intType = mlir::IntegerType::get(loc.getContext(), floatType.getWidth()); + auto intType = mlir::IntegerType::get(loc.getContext(), nbits); Value xAsInt = b.create(intType, adaptor.getOperand()); @@ -1072,13 +1075,18 @@ template <> inline Value mapMhloOpToStdScalarOp( Location loc, ArrayRef resultTypes, ArrayRef /*argTypes*/, mhlo::LogisticOp::Adaptor adaptor, OpBuilder* b) { - // 1.0 / (1.0 - exp(-x)) + // 1.0 / (1.0 + exp(-x)) Value negX = mapMhloOpToStdScalarOp( loc, resultTypes, resultTypes, {adaptor.getOperand()}, b); Value expNegX = mapMhloOpToStdScalarOp(loc, resultTypes, resultTypes, {{negX}}, b); - Value oneFloat = b->create(loc, b->getF32FloatAttr(1.0)); + Type type = getElementTypeOrSelf(resultTypes[0]); + Value oneFloat = + type.isa() + ? b->create(loc, b->getF32FloatAttr(1.0)) + : getConstantOrSplat(b, loc, resultTypes[0], + FloatAttr::get(type, 1.0f)); Value one = mapConvertOpToStdScalarOp(loc, resultTypes, resultTypes, {oneFloat.getType()}, {{oneFloat}}, b); Value oneAddExprNegX = mapMhloOpToStdScalarOp( diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td b/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td index a6cb43bf1ad412..4eb55548710b3b 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td @@ -18,6 +18,8 @@ include "mlir/Pass/PassBase.td" def ChloLegalizeToHloPass : Pass<"chlo-legalize-to-hlo", "func::FuncOp"> { let summary = "Legalize CHLO to HLO."; let constructor = "createChloLegalizeToHloPass()"; + let dependentDialects = ["mhlo::MhloDialect", "chlo::ChloDialect", + "shape::ShapeDialect", "scf::SCFDialect"]; let options = [ Option<"legalize_broadcasts_", "legalize-broadcasts", "bool", /*default=*/"true", "Legalize implicit broadcasts to explicit HLO broadcasting forms">, @@ -26,6 +28,18 @@ def ChloLegalizeToHloPass : Pass<"chlo-legalize-to-hlo", "func::FuncOp"> { ]; } +def ChloLegalizeToHloBasisOpsPass : Pass<"chlo-legalize-to-hlo-basis-ops", "func::FuncOp"> { + let summary = "Legalize specific CHLO ops (e.g. ErfOf and TopKOp) to basis MHLO ops."; + let description = [{ + XLA has specialization for certain CHLO ops (ErfOp, TopKOp), and other + backends still require decomposition of these ops into the basis set which + can be converted safely to StableHLO. This pass is needed until we have + direct CHLO to StableHLO lowerings. + }]; + let constructor = "createChloLegalizeToHloBasisOpsPass()"; + let dependentDialects = ["mhlo::MhloDialect", "chlo::ChloDialect"]; +} + def LegalizeSparseOpsPass : Pass<"legalize-sparse-ops", "func::FuncOp"> { let summary = "Legalize from sparse ops before convert MLIR to XLA computation."; let constructor = "createLegalizeSparseOperationsPass()"; @@ -320,23 +334,6 @@ def SparseRewritingPass : Pass<"mhlo-sparse-rewriting", "func::FuncOp"> { let constructor = "createSparseRewritingPass()"; } -/// Rank specialization passes. - -def RankSpecializationClusterPass - : Pass<"mhlo-rank-specialization-cluster", "func::FuncOp"> { - let constructor = "createRankSpecializationClusterPass()"; -} - -def RankSpecializationToSCFPass - : Pass<"mhlo-rank-specialization-to-scf", "func::FuncOp"> { - let constructor = "createRankSpecializationToSCFPass()"; - let options = [ - Option<"max_target_rank_", "max-target-rank", "int", /*default=*/"8", - "The maximum supported rank after rank specialization. Any argument " - "of greater rank may result in a runtime failure.">, - ]; -} - def MhloExpandOpsSimplifierPass : Pass<"mhlo-expand-ops-simplifier", "func::FuncOp"> { let summary = "Expand feature rich mhlo ops into a set of simpler mhlo ops."; @@ -397,3 +394,16 @@ def ShapeLegalizeToHloPass : Pass<"shape-legalize-to-hlo", "func::FuncOp"> { /*default=*/"false", "Whether to legalize Cstr Ops to shape_assertion custom_call"> ]; } + +def MhloQuantLegalizeToInt : Pass<"mhlo-quant-legalize-to-int", "mlir::func::FuncOp"> { + let summary = "Convert from MHLO quantized ops to MHLO primitive ops."; + + let description = [{ + Convert from MHLO quantized ops with MHLO quant types to MHLO primitive ops + like int ops. + }]; + let constructor = "createMhloQuantLegalizeToIntPass()"; + let dependentDialects = ["chlo::ChloDialect", "mhlo::MhloDialect", + "quant::QuantizationDialect", + "func::FuncDialect"]; +} diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_quant_legalize_to_int/mhlo_quant_legalize_to_int.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_quant_legalize_to_int/mhlo_quant_legalize_to_int.cc new file mode 100644 index 00000000000000..3ff1edf08f520d --- /dev/null +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_quant_legalize_to_int/mhlo_quant_legalize_to_int.cc @@ -0,0 +1,1375 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "stablehlo/dialect/ChloOps.h" + +namespace mlir::mhlo { +namespace { + +// TODO: b/311218165 - consider extract this to common utils and better ways to +// handle polymorphism. +using QuantType = std::variant; +FailureOr getQuantType(Type type) { + if (auto quantType = + getElementTypeOrSelf(type).dyn_cast()) { + return QuantType(quantType); + } + if (auto quantType = getElementTypeOrSelf(type) + .dyn_cast()) { + return QuantType(quantType); + } + return failure(); +} + +bool isPerTensorType(QuantType quantType) { + return std::holds_alternative(quantType); +} + +bool isPerChannelType(QuantType quantType) { + return std::holds_alternative(quantType); +} + +quant::UniformQuantizedType getPerTensorType(QuantType quantType) { + return std::get(quantType); +} + +quant::UniformQuantizedPerAxisType getPerChannelType(QuantType quantType) { + return std::get(quantType); +} + +// Extracts scale and zero point info from input quant type info. +void getQuantizationParams(OpBuilder &builder, Location loc, + QuantType quantType, Value &scales, + Value &zeroPoints, bool outputZeroPointInFp, + DenseI64ArrayAttr &broadcastDims) { + // Get scales/zero points for per-tensor and per-axis quantization cases. + if (auto *quantPerTensorType = + std::get_if(&quantType)) { + scales = builder.create( + loc, builder.getF32FloatAttr(quantPerTensorType->getScale())); + if (outputZeroPointInFp) { + zeroPoints = builder.create( + loc, builder.getF32FloatAttr( + static_cast(quantPerTensorType->getZeroPoint()))); + } else { + zeroPoints = builder.create( + loc, builder.getI32IntegerAttr( + static_cast(quantPerTensorType->getZeroPoint()))); + } + } else { + auto &quantPerChannelType = + std::get(quantType); + SmallVector scalesVec; + for (auto scale : quantPerChannelType.getScales()) + scalesVec.push_back(scale); + scales = builder.create( + loc, + DenseFPElementsAttr::get( + RankedTensorType::get( + {static_cast(quantPerChannelType.getScales().size())}, + builder.getF32Type()), + scalesVec)); + if (outputZeroPointInFp) { + SmallVector zeroPointsVec; + for (auto zeroPoint : quantPerChannelType.getZeroPoints()) + zeroPointsVec.push_back(zeroPoint); + zeroPoints = builder.create( + loc, DenseFPElementsAttr::get( + RankedTensorType::get( + {static_cast( + quantPerChannelType.getZeroPoints().size())}, + builder.getF32Type()), + zeroPointsVec)); + } else { + SmallVector zeroPointsVec; + for (auto zeroPoint : quantPerChannelType.getZeroPoints()) + zeroPointsVec.push_back(zeroPoint); + zeroPoints = builder.create( + loc, DenseIntElementsAttr::get( + RankedTensorType::get( + {static_cast( + quantPerChannelType.getZeroPoints().size())}, + builder.getI32Type()), + zeroPointsVec)); + } + broadcastDims = DenseI64ArrayAttr::get( + builder.getContext(), + {static_cast(quantPerChannelType.getQuantizedDimension())}); + } +} + +// Extracts storage min/max from input quant type info. +void getQuantizationStorageInfo(OpBuilder &builder, Location loc, + QuantType quantType, Value &storageMin, + Value &storageMax) { + if (auto *quantPerTensorType = + std::get_if(&quantType)) { + storageMin = builder.create( + loc, builder.getF32FloatAttr( + static_cast(quantPerTensorType->getStorageTypeMin()))); + storageMax = builder.create( + loc, builder.getF32FloatAttr( + static_cast(quantPerTensorType->getStorageTypeMax()))); + } else { + auto &quantPerChannelType = + std::get(quantType); + storageMin = builder.create( + loc, builder.getF32FloatAttr( + static_cast(quantPerChannelType.getStorageTypeMin()))); + storageMax = builder.create( + loc, builder.getF32FloatAttr( + static_cast(quantPerChannelType.getStorageTypeMax()))); + } +} + +// Extracts storage type of a UQ type. Return original type if it is no UQ type. +Type getQuantStorageType(Type type) { + if (auto shaped = type.dyn_cast()) { + return shaped.clone(getQuantStorageType(shaped.getElementType())); + } + + if (auto elementType = + getElementTypeOrSelf(type).dyn_cast()) { + return elementType.getStorageType(); + } + if (auto elementType = getElementTypeOrSelf(type) + .dyn_cast()) { + return elementType.getStorageType(); + } + return type; +} + +Type getQuantStorageType(QuantType type) { + if (isPerTensorType(type)) { + return getPerTensorType(type).getStorageType(); + } + return getPerChannelType(type).getStorageType(); +} + +Value applyMergedScalesAndZps(OpBuilder &builder, Location loc, + QuantType inputQuantType, + QuantType outputQuantType, + Value inputFloatTensor) { + // Use single merged scale and merged zp if both input and output are + // per-tensor quantized. Otherwise use a vector. + if (isPerTensorType(inputQuantType) && isPerTensorType(outputQuantType)) { + quant::UniformQuantizedType inputPerTensorType = + getPerTensorType(inputQuantType); + quant::UniformQuantizedType outputPerTensorType = + getPerTensorType(outputQuantType); + double mergedScaleFp = + inputPerTensorType.getScale() / outputPerTensorType.getScale(); + auto mergedScale = builder.create( + loc, builder.getF32FloatAttr(static_cast(mergedScaleFp))); + inputFloatTensor = + builder.create(loc, inputFloatTensor, mergedScale, + /*broadcast_dimensions=*/nullptr); + // Add merged_zp only when it is non-zero. + double mergedZpFp = outputPerTensorType.getZeroPoint() - + inputPerTensorType.getZeroPoint() * mergedScaleFp; + if (mergedZpFp != 0) { + Value mergedZp = builder.create( + loc, builder.getF32FloatAttr(static_cast(mergedZpFp))); + inputFloatTensor = builder.create( + loc, inputFloatTensor, mergedZp, /*broadcast_dimensions=*/nullptr); + } + } else { + int64_t channelSize = + isPerChannelType(outputQuantType) + ? getPerChannelType(outputQuantType).getScales().size() + : getPerChannelType(inputQuantType).getScales().size(); + int64_t quantizedDimension = + isPerChannelType(outputQuantType) + ? getPerChannelType(outputQuantType).getQuantizedDimension() + : getPerChannelType(inputQuantType).getQuantizedDimension(); + SmallVector mergedScaleDouble, mergedZpDouble; + mergedScaleDouble.resize(channelSize); + mergedZpDouble.resize(channelSize); + for (int i = 0; i < channelSize; ++i) { + mergedScaleDouble[i] = + (isPerChannelType(inputQuantType) + ? getPerChannelType(inputQuantType).getScales()[i] + : getPerTensorType(inputQuantType).getScale()) / + (isPerChannelType(outputQuantType) + ? getPerChannelType(outputQuantType).getScales()[i] + : getPerTensorType(outputQuantType).getScale()); + mergedZpDouble[i] = + (isPerChannelType(outputQuantType) + ? getPerChannelType(outputQuantType).getZeroPoints()[i] + : getPerTensorType(outputQuantType).getZeroPoint()) - + (isPerChannelType(inputQuantType) + ? getPerChannelType(inputQuantType).getZeroPoints()[i] + : getPerTensorType(inputQuantType).getZeroPoint()) * + mergedScaleDouble[i]; + } + SmallVector mergedScaleFloat(mergedScaleDouble.begin(), + mergedScaleDouble.end()), + mergedZpFloat(mergedZpDouble.begin(), mergedZpDouble.end()); + + auto broadcastDims = + DenseI64ArrayAttr::get(builder.getContext(), {quantizedDimension}); + Value mergedScale = builder.create( + loc, DenseFPElementsAttr::get( + RankedTensorType::get({channelSize}, builder.getF32Type()), + mergedScaleFloat)); + inputFloatTensor = builder.create( + loc, inputFloatTensor, mergedScale, broadcastDims); + if (llvm::any_of(mergedZpFloat, [](double zp) { return zp != 0; })) { + Value mergedZp = builder.create( + loc, DenseFPElementsAttr::get( + RankedTensorType::get({channelSize}, builder.getF32Type()), + mergedZpFloat)); + inputFloatTensor = builder.create( + loc, inputFloatTensor, mergedZp, broadcastDims); + } + } + return inputFloatTensor; +} + +// This helper function create ops to requantize `input` tensor and returns the +// output tensor. Clamping is done if output integer bit-width < i32. It assumes +// that if both input and output tensor are per-channel quantized, they have the +// same quantization axis. +// +// Requantization is essentially dequantize --> quantize. +// +// Dequantize: (input - zp) * scale +// Quantize: input / scale + zp +// +// Hence, +// output = (input - input_zp) * input_scale / output_scale + output_zp +// +// This is simplified as: +// output = input * merged_scale + merged_zp +// where: +// merged_zp = output_zp - input_zp * merged_scale. +// merged_scale = input_scale / output_scale. +Value requantize(mlir::OpState op, Value input, QuantType inputQuantType, + QuantType outputQuantType, TensorType outputTensorType, + ConversionPatternRewriter &rewriter) { + // Skip requantization when input and result have the same type. + if (inputQuantType == outputQuantType) { + return rewriter.create(op->getLoc(), outputTensorType, + input); + } + + auto floatTensorType = outputTensorType.clone(rewriter.getF32Type()); + Value outputFloat = + rewriter.create(op->getLoc(), floatTensorType, input); + + outputFloat = applyMergedScalesAndZps(rewriter, op->getLoc(), inputQuantType, + outputQuantType, outputFloat); + + // Clamp output if the output integer bit-width <32. + if (outputTensorType.getElementType().cast().getWidth() < 32) { + Value quantizationMin, quantizationMax; + getQuantizationStorageInfo(rewriter, op->getLoc(), outputQuantType, + quantizationMin, quantizationMax); + // Clamp results by [quantizationMin, quantizationMax]. + outputFloat = rewriter.create(op->getLoc(), quantizationMin, + outputFloat, quantizationMax); + } + + outputFloat = rewriter.create( + op->getLoc(), floatTensorType, outputFloat); + return rewriter.create(op->getLoc(), outputTensorType, + outputFloat); +} + +class ConvertUniformQuantizeOp + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::UniformQuantizeOp op, mhlo::UniformQuantizeOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto inputElementType = getElementTypeOrSelf(op.getOperand().getType()); + if (inputElementType.isF32()) { + auto quantType = getQuantType(op.getResult().getType()); + if (succeeded(quantType)) { + return matchAndRewriteQuantize(op, adaptor, rewriter, *quantType); + } + } else if (inputElementType.isa()) { + auto inputQuantType = getQuantType(inputElementType); + auto outputQuantType = getQuantType(op.getResult().getType()); + if (succeeded(inputQuantType) && succeeded(outputQuantType)) { + if (isPerChannelType(*inputQuantType) && + isPerChannelType(*outputQuantType) && + getPerChannelType(*inputQuantType).getQuantizedDimension() != + getPerChannelType(*outputQuantType).getQuantizedDimension()) { + op->emitError("Cannot requantize while changing quantization_axis"); + return failure(); + } + return matchAndRewriteRequantize(op, adaptor, rewriter, *inputQuantType, + *outputQuantType); + } + } + op->emitError("Unsupported input element type."); + return failure(); + } + + LogicalResult matchAndRewriteQuantize(mhlo::UniformQuantizeOp op, + mhlo::UniformQuantizeOpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + QuantType quantType) const { + Value scales, zeroPoints; + DenseI64ArrayAttr broadcastDims; + getQuantizationParams(rewriter, op->getLoc(), quantType, scales, zeroPoints, + /*outputZeroPointInFp=*/true, broadcastDims); + + Value quantizationMin, quantizationMax; + getQuantizationStorageInfo(rewriter, op->getLoc(), quantType, + quantizationMin, quantizationMax); + + auto resFloatTensorType = + op.getOperand().getType().clone(rewriter.getF32Type()); + Value resFloat = rewriter.create( + op->getLoc(), resFloatTensorType, adaptor.getOperand(), scales, + broadcastDims); + resFloat = rewriter.create( + op->getLoc(), resFloatTensorType, resFloat, zeroPoints, broadcastDims); + + resFloat = rewriter.create(op->getLoc(), resFloatTensorType, + quantizationMin, resFloat, + quantizationMax); + resFloat = rewriter.create( + op->getLoc(), resFloatTensorType, resFloat); + auto resFinalTensorType = resFloatTensorType.clone( + getQuantStorageType(op.getResult().getType().getElementType())); + rewriter.replaceOpWithNewOp(op, resFinalTensorType, + resFloat); + return success(); + } + + LogicalResult matchAndRewriteRequantize( + mhlo::UniformQuantizeOp op, mhlo::UniformQuantizeOpAdaptor adaptor, + ConversionPatternRewriter &rewriter, QuantType inputQuantType, + QuantType outputQuantType) const { + rewriter.replaceOp( + op, + requantize(op, adaptor.getOperand(), inputQuantType, outputQuantType, + /*outputTensorType=*/ + op.getResult().getType().cast().clone( + getQuantStorageType(outputQuantType)), + rewriter)); + return success(); + } +}; + +class ConvertUniformDequantizeOp + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::UniformDequantizeOp op, mhlo::UniformDequantizeOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto quantType = getQuantType(op.getOperand().getType()); + if (failed(quantType)) { + return failure(); + } + Value scales, zeroPoints; + DenseI64ArrayAttr broadcastDims; + getQuantizationParams(rewriter, op->getLoc(), *quantType, scales, + zeroPoints, + /*outputZeroPointInFp=*/false, broadcastDims); + + Value input = adaptor.getOperand(); + // TODO: b/260280919 - Consider avoiding conversion to int32. + auto resInt32TensorType = + input.getType().cast().clone(rewriter.getI32Type()); + Value resInt32 = rewriter.create( + op->getLoc(), resInt32TensorType, input); + resInt32 = rewriter.create( + op->getLoc(), resInt32TensorType, resInt32, zeroPoints, broadcastDims); + auto resFloatTensorType = + resInt32.getType().cast().clone(rewriter.getF32Type()); + Value resFloat = rewriter.create( + op->getLoc(), resFloatTensorType, resInt32); + resFloat = rewriter.replaceOpWithNewOp( + op, resFloatTensorType, resFloat, scales, broadcastDims); + return success(); + } +}; + +class ConvertUniformQuantizedAddOp : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::AddOp op, mhlo::AddOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto lhsQuantType = + getQuantType(getElementTypeOrSelf(op.getLhs().getType())); + auto rhsQuantType = + getQuantType(getElementTypeOrSelf(op.getRhs().getType())); + auto resQuantType = + getQuantType(getElementTypeOrSelf(op.getResult().getType())); + + // We only handle cases where lhs, rhs and results all have quantized + // element type. + if (failed(lhsQuantType) || failed(rhsQuantType) || failed(resQuantType)) { + op->emitError( + "AddOp requires the quantized element type for all operands and " + "results"); + return failure(); + } + + if (isPerChannelType(*lhsQuantType) || isPerChannelType(*rhsQuantType) || + isPerChannelType(*resQuantType)) { + // Handle Per-Channel Quantized Types. We only support lhs/rhs/result with + // exact same per-channel quantized types with I32 storage type. + if (!isPerChannelType(*lhsQuantType) || + !isPerChannelType(*rhsQuantType) || + !isPerChannelType(*resQuantType) || + getPerChannelType(*lhsQuantType) != + getPerChannelType(*rhsQuantType) || + getPerChannelType(*lhsQuantType) != + getPerChannelType(*resQuantType)) { + op->emitError( + "Per-channel quantized AddOp requires the same quantized element " + "type for all operands and results"); + return failure(); + } + if (!getPerChannelType(*lhsQuantType).getStorageType().isInteger(32)) { + // For server-side StableHLO Quantization, add is quantized only when + // fused with conv/dot ops, whose output must be i32. + op->emitError("Per-channel quantized AddOp requires i32 storage type"); + return failure(); + } + return matchAndRewritePerChannel(op, adaptor, rewriter, + getPerChannelType(*lhsQuantType)); + } + + // TODO: b/260280919 - Consider avoiding conversion to int32. + auto resInt32TensorType = + op.getResult().getType().clone(rewriter.getI32Type()); + + // When lhs, rhs and result have different scale and zps, requantize them to + // be the same as the result. + // TODO: b/260280919 - Consider avoiding conversion to int32. + Value lhs = adaptor.getLhs(); + Value lhsInt32Tensor = requantize(op, lhs, *lhsQuantType, *resQuantType, + resInt32TensorType, rewriter); + + Value rhs = adaptor.getRhs(); + Value rhsInt32Tensor = requantize(op, rhs, *rhsQuantType, *resQuantType, + resInt32TensorType, rewriter); + + Value zeroPoint = rewriter.create( + op->getLoc(), rewriter.getI32IntegerAttr(static_cast( + getPerTensorType(*resQuantType).getZeroPoint()))); + + // Now the lhs and rhs have been coverted to the same scale and zps. + // Given: + // lhs_fp = (lhs_quant - zp) * scale + // rhs_fp = (rhs_quant - zp) * scale + // res_fp = lhs_fp + rhs_fp + // = ((lhs_quant + rhs_quant - zp) - zp) * scale + // res_quant = res_fp / scale + zp + // = lhs_quant + rhs_quant - zp + // The following add the inputs and then substract by zero point. + Value addResult = rewriter.create( + op->getLoc(), resInt32TensorType, lhsInt32Tensor, rhsInt32Tensor, + nullptr); + Value resInt32 = rewriter.create( + op->getLoc(), resInt32TensorType, addResult, zeroPoint, nullptr); + + if (getQuantStorageType(*resQuantType).isInteger(32)) { + // For i32, clamping is not needed. + rewriter.replaceOp(op, resInt32); + } else { + // Clamp results by [quantizationMin, quantizationMax] when storage type + // is not i32. + Value resultQuantizationMin = rewriter.create( + op->getLoc(), + rewriter.getI32IntegerAttr(static_cast( + getPerTensorType(*resQuantType).getStorageTypeMin()))); + Value resultQuantizationMax = rewriter.create( + op->getLoc(), + rewriter.getI32IntegerAttr(static_cast( + getPerTensorType(*resQuantType).getStorageTypeMax()))); + resInt32 = rewriter.create( + op->getLoc(), resInt32TensorType, resultQuantizationMin, resInt32, + resultQuantizationMax); + // Convert results back to result storage type. + auto resFinalTensorType = + resInt32TensorType.clone(getQuantStorageType(*resQuantType)); + rewriter.replaceOpWithNewOp(op, resFinalTensorType, + resInt32); + } + + return success(); + } + + LogicalResult matchAndRewritePerChannel( + mhlo::AddOp op, mhlo::AddOpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + quant::UniformQuantizedPerAxisType quantType) const { + // We assume lhs/rhs/result have the same quantized type with i32 storage. + Value addResult = rewriter.create( + op->getLoc(), adaptor.getLhs(), adaptor.getRhs()); + // Add zp contribution if it is non-zero for any channel. + if (llvm::any_of(quantType.getZeroPoints(), + [](int64_t zp) { return zp != 0; })) { + SmallVector zpsVec(quantType.getZeroPoints().begin(), + quantType.getZeroPoints().end()); + Value zps = rewriter.create( + op->getLoc(), + DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(zpsVec.size())}, + rewriter.getI32Type()), + zpsVec)); + addResult = rewriter.create( + op->getLoc(), addResult, zps, + rewriter.getDenseI64ArrayAttr( + {static_cast(quantType.getQuantizedDimension())})); + } + rewriter.replaceOp(op, addResult); + return success(); + } +}; + +// This is a convenient struct for holding dimension numbers for dot-like ops +// including DotGeneral and Convolution. So that we can share code for all +// dot-like ops. +// For Convolution, only NHWC format is supported. +// For DotGeneral, there is no contracting dims. The batching and contracting +// dimensions are defined in +// https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dot_general. +struct DotLikeDimensionNumbers { + SmallVector lhsBatchingDims; + SmallVector lhsSpatialDims; + SmallVector lhsContractingDims; + SmallVector rhsBatchingDims; + SmallVector rhsSpatialDims; + SmallVector rhsContractingDims; +}; + +// A shared matchAndRewrite implementation for dot-like hybrid quantized +// operators. Hybrid ops are currently only interpreted as weight-only +// quantization ops, this might change in the future. +// +// All attrs of the original op are preserved after the conversion. +template +LogicalResult matchAndRewriteDotLikeHybridOp( + OpType &op, OpAdaptorType &adaptor, ConversionPatternRewriter &rewriter) { + // For dot like hybrid ops, lhs is float type, rhs is uniform + // quantized type and result is float type. + // For weight-only quantization: + // result = hybridOp(lhs, dequant(rhs)) + Value lhsFloat32Tensor = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + quant::UniformQuantizedType rhsElementType = + getElementTypeOrSelf(op.getRhs().getType()) + .template cast(); + auto resFloat32TensorType = + op.getResult().getType().template cast(); + auto rhsFloat32TensorType = + op.getRhs().getType().template cast().clone( + rewriter.getF32Type()); + + // Get scales and zero points for rhs. + Value rhsZeroPoint = rewriter.create( + op->getLoc(), rewriter.getF32FloatAttr((rhsElementType.getZeroPoint()))); + Value rhsScaleConstant = rewriter.create( + op->getLoc(), + rewriter.getF32FloatAttr(static_cast(rhsElementType.getScale()))); + + // Dequantize rhs_float32_tensor. + Value rhsFloat32Tensor = + rewriter.create(op->getLoc(), rhsFloat32TensorType, rhs); + rhsFloat32Tensor = rewriter.create( + op->getLoc(), rhsFloat32TensorType, rhsFloat32Tensor, rhsZeroPoint, + nullptr); + rhsFloat32Tensor = rewriter.create( + op->getLoc(), rhsFloat32TensorType, rhsFloat32Tensor, rhsScaleConstant, + nullptr); + + // Execute conversion target op. + SmallVector operands{lhsFloat32Tensor, rhsFloat32Tensor}; + rewriter.replaceOpWithNewOp(op, resFloat32TensorType, operands, + op->getAttrs()); + return success(); +} + +Value createZeroPointPartialOffset(OpBuilder &builder, Location loc, + Value tensor, const int64_t otherTensorZp, + SmallVector reductionDims) { + // This function calculates part of the zero-point-offset by using + // mhlo::Reduce to sum over the contracting dims of the tensor, and then + // multiply by zp of the other tensor. + auto outputElementType = builder.getI32Type(); + + // Calculate the output tensor shape. This is input tensor dims minus + // contracting dims. + auto rankedTensor = tensor.getType().cast(); + SmallVector outputDims; + for (int64_t i = 0; i < rankedTensor.getRank(); ++i) { + if (llvm::count(reductionDims, i) == 0) { + outputDims.push_back(rankedTensor.getDimSize(i)); + } + } + + // Convert input tensor to output type since mhlo::Reduce only supports same + // element type for input/output. + tensor = builder.create( + loc, tensor.getType().cast().clone(outputElementType), + tensor); + auto reducerTensorType = RankedTensorType::get({}, outputElementType); + + // Initial value for reduced tensor. This is set 0. + Value initValues = builder.create( + loc, DenseIntElementsAttr::get(reducerTensorType, {0})); + mhlo::ReduceOp reduce = builder.create( + loc, RankedTensorType::get(outputDims, outputElementType), tensor, + initValues, builder.getI64TensorAttr(reductionDims)); + // Define reducer function to compute sum. + Region ®ion = reduce.getBody(); + Block &block = region.emplaceBlock(); + block.addArgument(reducerTensorType, loc); + block.addArgument(reducerTensorType, loc); + auto *firstArgument = block.args_begin(); + auto secondArgument = block.args_rbegin(); + { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(&block); + Value sum = + builder.create(loc, *firstArgument, *secondArgument); + builder.create(loc, sum); + } + Value zp = builder.create( + loc, builder.getI32IntegerAttr(otherTensorZp)); + Value mulOp = builder.create(loc, reduce.getResult(0), + zp, nullptr); + return mulOp; +} + +Value getDimValue(OpBuilder &builder, Location loc, Value tensor, + mlir::ShapedType tensorShape, int64_t idx) { + if (tensorShape.isDynamicDim(idx)) { + // Get dynamic dim using GetDimensionSizeOp and convert result from to + // <1xi64>. + Value dynamicDim = builder.create( + loc, tensor, builder.getI64IntegerAttr(idx)); + dynamicDim = builder.create( + loc, RankedTensorType::get(ArrayRef{}, builder.getI64Type()), + dynamicDim); + return builder.create( + loc, RankedTensorType::get({1}, builder.getI64Type()), dynamicDim); + } + return builder.create( + loc, DenseIntElementsAttr::get( + RankedTensorType::get({1}, builder.getI64Type()), + {tensorShape.getDimSize(idx)})); +} + +Value calculateDynamicOutputDims(OpBuilder &builder, Location loc, Value output, + ShapedType outputTensorType) { + // Calculate each output tensor dim and concatenate into a 1D tensor. + SmallVector outputDims; + for (int64_t i = 0; i < outputTensorType.getRank(); ++i) { + outputDims.push_back( + getDimValue(builder, loc, output, outputTensorType, i)); + } + return builder.create(loc, outputDims, + builder.getI64IntegerAttr(0)); +} + +Value broadcastZpContribution(OpBuilder &builder, Location loc, + Value zpContribution, + ArrayRef reductionDims, + ArrayRef batchingDims, + int64_t nonBatchingStartingIdx, Value output, + TensorType outputTensorType, + Value &outputDimsValue) { + // This function calculates the dims for broadcasting from the + // zero-point-offset tensor to the final output tensor, and then do the + // broadcast. + auto zpContributionRank = + zpContribution.getType().cast().getRank(); + SmallVector broadcastDims; + broadcastDims.resize(zpContributionRank, 0); + // Result tensor will have batching dims first, then LHS result dims, then + // RHS result dims. So non-batching result dims index doesn't start from 0. + // The arg non_batching_starting_idx is used to distinguish LHS and RHS. + int64_t resultBatchingIdx = 0; + int64_t resultNonBatchingIdx = nonBatchingStartingIdx; + for (int64_t idx = 0, originalIdx = 0; idx < zpContributionRank; + ++idx, ++originalIdx) { + // zp_contribution has removed contracting/spatial dims from the tensor + // after reduction. The following recovers the index in the original tensor. + while (llvm::count(reductionDims, originalIdx) != 0) { + originalIdx++; + } + if (llvm::count(batchingDims, originalIdx) == 0) { + broadcastDims[idx] = resultNonBatchingIdx++; + } else { + broadcastDims[idx] = resultBatchingIdx++; + } + } + // Use broadcast_in_dim or dyanmic_broadcast_in_dim based on output shape + // dynamism. + if (outputTensorType.cast().hasStaticShape()) { + zpContribution = builder.create( + loc, outputTensorType, zpContribution, + DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(broadcastDims.size())}, + builder.getI64Type()), + broadcastDims)); + } else { + if (!outputDimsValue) { + outputDimsValue = + calculateDynamicOutputDims(builder, loc, output, outputTensorType); + } + zpContribution = builder.create( + loc, outputTensorType, zpContribution, outputDimsValue, + DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(broadcastDims.size())}, + builder.getI64Type()), + broadcastDims)); + } + return zpContribution; +} + +Value calculateZeroPointOffset(OpBuilder &builder, Location loc, Value lhs, + Value rhs, Value output, int64_t lhsZp, + int64_t rhsZp, TensorType outputTensorType, + const DotLikeDimensionNumbers &dims) { + mlir::ShapedType lhsShape = lhs.getType().cast(); + mlir::ShapedType rhsShape = rhs.getType().cast(); + Value result = nullptr; + Value outputDimsValue = nullptr; + // Calculate LHS contribution when RHS zp is non-zero. + if (rhsZp != 0) { + SmallVector reductionDims = to_vector(llvm::concat( + dims.lhsSpatialDims, dims.lhsContractingDims)); + Value lhsZpContribution = + createZeroPointPartialOffset(builder, loc, lhs, rhsZp, reductionDims); + // Broadcast lhs ZP contribution to result tensor shape. + lhsZpContribution = broadcastZpContribution( + builder, loc, lhsZpContribution, reductionDims, dims.lhsBatchingDims, + dims.lhsBatchingDims.size(), output, outputTensorType, outputDimsValue); + result = lhsZpContribution; + } + // Calculate RHS contribution when LHS zp is non-zero. + if (lhsZp != 0) { + SmallVector reductionDims = to_vector(llvm::concat( + dims.rhsSpatialDims, dims.rhsContractingDims)); + Value rhsZpContribution = + createZeroPointPartialOffset(builder, loc, rhs, lhsZp, reductionDims); + // Broadcast rhs ZP contribution to result tensor shape. + rhsZpContribution = broadcastZpContribution( + builder, loc, rhsZpContribution, reductionDims, dims.rhsBatchingDims, + lhsShape.getRank() - dims.lhsContractingDims.size(), output, + outputTensorType, outputDimsValue); + if (result) { + result = builder.create(loc, result, rhsZpContribution); + } else { + result = rhsZpContribution; + } + } + + if (lhsZp != 0 && rhsZp != 0) { + // Contributions from LHS_ZP * RHS_ZP. + // This is multiplied by the product of all contracting dimensions. + int32_t contractingDimTotalInt = 1; + bool hasDynamicContractingDim = false; + Value dynamicContractingDimTotal = builder.create( + loc, builder.getI32IntegerAttr(static_cast(1))); + // Calculate the product for static/dynamic dims separately. + for (int64_t rhsIdx : llvm::concat( + dims.rhsSpatialDims, dims.rhsContractingDims)) { + if (rhsShape.isDynamicDim(rhsIdx)) { + hasDynamicContractingDim = true; + auto dim = builder.create( + loc, rhs, builder.getI64IntegerAttr(rhsIdx)); + dynamicContractingDimTotal = + builder.create(loc, dynamicContractingDimTotal, dim); + } else { + contractingDimTotalInt *= rhsShape.getDimSize(rhsIdx); + } + } + Value zpOffsetValue = builder.create( + loc, builder.getI32IntegerAttr(static_cast(lhsZp) * + static_cast(rhsZp) * + contractingDimTotalInt)); + // Multiply the static dims contribution by the dynamic one if needed. + if (hasDynamicContractingDim) { + zpOffsetValue = builder.create(loc, zpOffsetValue, + dynamicContractingDimTotal); + } + result = builder.create(loc, result, zpOffsetValue, + nullptr); + } + return result; +} + +// Generic function to create DotGeneral kernel for Dot/DotGeneral ops. +template +Value createDotLikeKernel(OpBuilder &builder, Location loc, DotLikeOp, + Type resultType, Value &lhs, Value &rhs, + ArrayRef attrs) { + return builder.create(loc, resultType, + ArrayRef{lhs, rhs}, attrs); +} + +// Template specialization for Convolution op. +// This function may pad LHS if needed. If so, lhs is updated in place. +template <> +Value createDotLikeKernel(OpBuilder &builder, Location loc, + mhlo::ConvolutionOp op, + Type resultType, Value &lhs, + Value &rhs, + ArrayRef attrs) { + // We only handle the case where RHS zp is zero. + // Explicitly pad LHS with zp and update LHS value. + SmallVector newAttrs(attrs); + if (op.getPadding().has_value() && + llvm::any_of(op.getPaddingAttr().getValues(), + [](int64_t x) { return x != 0; })) { + auto originalPadding = op.getPaddingAttr().getValues(); + + Value zp = builder.create( + loc, + DenseIntElementsAttr::get( + RankedTensorType::get({}, builder.getI8Type()), + {static_cast(getElementTypeOrSelf(op.getLhs().getType()) + .cast() + .getZeroPoint())})); + // Convert Padding attributes from mhlo::Convolution to mhlo::Pad. Note that + // Padding is applied for spatial dimensions [1...rank-1) only for + // mhlo::Convolution. But mhlo::Pad require those for all dimensions. Hence + // we add 0 to the beginning and end of the padding vectors. + int64_t rank = lhs.getType().cast().getRank(); + SmallVector paddingLow(rank, 0), paddingHigh(rank, 0), + paddingInterior(rank, 0); + for (int64_t i = 1; i < rank - 1; ++i) { + paddingLow[i] = originalPadding[i * 2 - 2]; + paddingHigh[i] = originalPadding[i * 2 - 1]; + } + lhs = builder.create( + loc, lhs, zp, + DenseIntElementsAttr::get( + RankedTensorType::get({rank}, builder.getI64Type()), paddingLow), + DenseIntElementsAttr::get( + RankedTensorType::get({rank}, builder.getI64Type()), paddingHigh), + DenseIntElementsAttr::get( + RankedTensorType::get({rank}, builder.getI64Type()), + paddingInterior)); + + // After explicitly padding/dilating LHS, update attributes so that LHS is + // not padded/dilated again during Convolution. + for (auto &attr : newAttrs) { + if (attr.getName().getValue() == "padding") { + attr.setValue(SplatElementsAttr::get( + RankedTensorType::get({rank - 2, 2}, builder.getI64Type()), + builder.getI64IntegerAttr(0))); + } + } + } + return builder.create( + loc, resultType, ArrayRef{lhs, rhs}, newAttrs); +} + +template +LogicalResult matchAndRewriteDotLikeOp(DotLikeOp op, DotLikeOpAdaptor adaptor, + ArrayRef attrs, + const DotLikeDimensionNumbers &dims, + ConversionPatternRewriter &rewriter) { + // Lower Dot/DotGeneral UQ ops to DotGeneral int. + // Assumes that operands and results are uq types. + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + auto resInt32TensorType = + op.getResult().getType().clone(rewriter.getI32Type()); + + // Dot result + // = dot((lhs - zp_l) * scale_l, (rhs - zp_r) * scale_r) / scale_res + // + zp_res + // = dot(lhs - zp_l, rhs - zp_r) * scale_l * scale_r / scale_res + zp_res + // = dot(lhs, rhs) * combined_scale + combined_zp + // where: + // combined_scale = scale_l * scale_r / scale_res + // combined_zp = res_zp - zp_offset * combined_scale + // zp_offset = zp_l*rhs + zp_r*lhs - zp_l*zp_r + Value resI32 = createDotLikeKernel(rewriter, op->getLoc(), op, + resInt32TensorType, lhs, rhs, attrs); + + auto lhsElementQuantType = getElementTypeOrSelf(op.getLhs().getType()) + .template cast(); + auto rhsElementQuantType = + getElementTypeOrSelf(op.getRhs().getType()) + .template dyn_cast(); + auto rhsElementQuantPerChannelType = + getElementTypeOrSelf(op.getRhs().getType()) + .template dyn_cast(); + auto resElementQuantType = + getElementTypeOrSelf(op.getResult()) + .template dyn_cast(); + auto resElementQuantPerChannelType = + getElementTypeOrSelf(op.getResult()) + .template dyn_cast(); + + // Here we assume LHS must be per-tensor quantized. + // If RHS is per-channel quantized, it must has 0 zp. + Value zpOffset = calculateZeroPointOffset( + rewriter, op->getLoc(), lhs, rhs, resI32, + lhsElementQuantType.getZeroPoint(), + (rhsElementQuantType ? rhsElementQuantType.getZeroPoint() : 0), + resInt32TensorType, dims); + + // For per-channel quantization, we assume that result scales are proportional + // to rhs scales for each channels. + double combinedScaleFp = + rhsElementQuantType + ? lhsElementQuantType.getScale() * rhsElementQuantType.getScale() / + resElementQuantType.getScale() + : lhsElementQuantType.getScale() * + rhsElementQuantPerChannelType.getScales()[0] / + resElementQuantPerChannelType.getScales()[0]; + + // Multiply dot result and zp_offset by combined_scale only if it is not 1.0. + if (std::abs(combinedScaleFp - 1.0) > 0.001) { + Value combinedScale = rewriter.create( + op->getLoc(), rewriter.getF32FloatAttr(combinedScaleFp)); + + auto resFloat32TensorType = + op.getResult().getType().clone(rewriter.getF32Type()); + Value resF32 = rewriter.create( + op->getLoc(), resFloat32TensorType, resI32); + resF32 = rewriter.create( + op->getLoc(), resFloat32TensorType, resF32, combinedScale, nullptr); + resI32 = rewriter.create(op->getLoc(), resInt32TensorType, + resF32); + + // Skip zp_offset if it is 0. + if (zpOffset) { + auto zpOffsetFloat32TensorType = + zpOffset.getType().cast().clone(rewriter.getF32Type()); + zpOffset = rewriter.create( + op->getLoc(), zpOffsetFloat32TensorType, zpOffset); + zpOffset = rewriter.create( + op->getLoc(), zpOffsetFloat32TensorType, zpOffset, combinedScale, + nullptr); + zpOffset = rewriter.create( + op->getLoc(), zpOffsetFloat32TensorType.clone(rewriter.getI32Type()), + zpOffset); + } + } + + // If result is per-channel quantized, it must has 0 zp. + Value combinedZp = rewriter.create( + op->getLoc(), + rewriter.getI32IntegerAttr( + resElementQuantType ? resElementQuantType.getZeroPoint() : 0)); + if (zpOffset) { + combinedZp = rewriter.create( + op->getLoc(), resInt32TensorType, combinedZp, zpOffset, nullptr); + } + rewriter.replaceOpWithNewOp( + op, resInt32TensorType, resI32, combinedZp, nullptr); + return success(); +} + +template +FailureOr isDotLikeOpHybrid(DotLikeOp op) { + // Checks whether a dot-like op is hybrid by looking at input/output types. + // Returns failure() when the type is not supported. + bool isLhsQuant = isa( + getElementTypeOrSelf(op.getLhs().getType())); + bool isLhsQuantPerChannel = isa( + getElementTypeOrSelf(op.getLhs().getType())); + bool isRhsQuant = isa( + getElementTypeOrSelf(op.getRhs().getType())); + bool isRhsQuantPerChannel = isa( + getElementTypeOrSelf(op.getRhs().getType())); + bool isResQuant = + isa(getElementTypeOrSelf(op.getResult())); + bool isResQuantPerChannel = isa( + getElementTypeOrSelf(op.getResult())); + + if (isLhsQuant && ((isRhsQuant && isResQuant) || + (isa(op) && isRhsQuantPerChannel && + isResQuantPerChannel))) { + // For quantized ops, RHS and result must be both per-channel quantized. + // For Convolution, we also support per-channel quantized RHS/result. + return false; + } + if (!isLhsQuant && !isLhsQuantPerChannel && isRhsQuant && !isResQuant && + !isResQuantPerChannel) { + return true; + } + op->emitError("Invalid input/output type for Dot/Convolution op"); + return failure(); +} + +class ConvertUniformQuantizedDotOp : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::DotOp op, mhlo::DotOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto isHybrid = isDotLikeOpHybrid(op); + if (failed(isHybrid)) { + return failure(); + } + if (*isHybrid) { + return matchAndRewriteDotLikeHybridOp(op, adaptor, rewriter); + } // DotOp is a special case of DotGeneralOp, where LHS and RHS are both + // rank-2 tensors and have contracting dims of 1 and 0 respectively. + auto dims = mhlo::DotDimensionNumbersAttr::get( + rewriter.getContext(), /*lhsBatchingDimensions=*/{}, + /*rhsBatchingDimensions=*/{}, /*lhsContractingDimensions=*/{1}, + /*rhsContractingDimensions=*/{0}); + SmallVector attrs(op->getAttrs()); + attrs.push_back( + {StringAttr::get(rewriter.getContext(), "dot_dimension_numbers"), + dims}); + return matchAndRewriteDotLikeOp( + op, adaptor, attrs, + DotLikeDimensionNumbers{/*lhs_batching_dims=*/{}, + /*lhs_spatial_dims=*/{}, + /*lhs_contracting_dims=*/{1}, + /*rhs_batching_dims=*/{}, + /*rhs_spatial_dims=*/{}, + /*rhs_contracting_dims=*/{0}}, + rewriter); + } +}; + +class ConvertUniformQuantizedDotGeneralOp + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::DotGeneralOp op, mhlo::DotGeneralOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto isHybrid = isDotLikeOpHybrid(op); + if (failed(isHybrid)) { + return failure(); + } + if (*isHybrid) { + return matchAndRewriteDotLikeHybridOp(op, adaptor, rewriter); + } + return matchAndRewriteDotLikeOp( + op, adaptor, op->getAttrs(), + DotLikeDimensionNumbers{ + to_vector(op.getDotDimensionNumbers().getLhsBatchingDimensions()), + /*lhs_spatial_dims=*/{}, + to_vector( + op.getDotDimensionNumbers().getLhsContractingDimensions()), + to_vector(op.getDotDimensionNumbers().getRhsBatchingDimensions()), + /*rhs_spatial_dims=*/{}, + to_vector( + op.getDotDimensionNumbers().getRhsContractingDimensions())}, + rewriter); + } +}; + +bool isConvNhwc(const mhlo::ConvDimensionNumbersAttr &dims) { + return dims.getInputBatchDimension() == 0 && + dims.getInputFeatureDimension() == 3 && + dims.getInputSpatialDimensions().size() == 2 && + dims.getInputSpatialDimensions()[0] == 1 && + dims.getInputSpatialDimensions()[1] == 2 && + dims.getKernelInputFeatureDimension() == 2 && + dims.getKernelOutputFeatureDimension() == 3 && + dims.getKernelSpatialDimensions().size() == 2 && + dims.getKernelSpatialDimensions()[0] == 0 && + dims.getKernelSpatialDimensions()[1] == 1 && + dims.getOutputBatchDimension() == 0 && + dims.getOutputFeatureDimension() == 3 && + dims.getOutputSpatialDimensions().size() == 2 && + dims.getOutputSpatialDimensions()[0] == 1 && + dims.getOutputSpatialDimensions()[1] == 2; +} + +bool isConvNDHWC(const mhlo::ConvDimensionNumbersAttr &dims) { + return dims.getInputBatchDimension() == 0 && + dims.getInputFeatureDimension() == 4 && + dims.getInputSpatialDimensions().size() == 3 && + dims.getInputSpatialDimensions()[0] == 1 && + dims.getInputSpatialDimensions()[1] == 2 && + dims.getInputSpatialDimensions()[2] == 3 && + dims.getKernelInputFeatureDimension() == 3 && + dims.getKernelOutputFeatureDimension() == 4 && + dims.getKernelSpatialDimensions().size() == 3 && + dims.getKernelSpatialDimensions()[0] == 0 && + dims.getKernelSpatialDimensions()[1] == 1 && + dims.getKernelSpatialDimensions()[2] == 2 && + dims.getOutputBatchDimension() == 0 && + dims.getOutputFeatureDimension() == 4 && + dims.getOutputSpatialDimensions().size() == 3 && + dims.getOutputSpatialDimensions()[0] == 1 && + dims.getOutputSpatialDimensions()[1] == 2 && + dims.getOutputSpatialDimensions()[2] == 3; +} + +FailureOr verifyAndConstructDims( + mhlo::ConvolutionOp op) { + // RHS (weight) must have zero zp. + // Here assumes RHS/result must be both per-tensor or both per-channel + // quantized. + auto failedOr = getQuantType(op.getRhs().getType()); + if (failed(failedOr)) { + return failure(); + } + QuantType rhsElementQuantType = *failedOr; + bool isRhsQuantPerTensor = + std::get_if(&rhsElementQuantType); + + if (isRhsQuantPerTensor + ? (std::get(rhsElementQuantType) + .getZeroPoint() != 0) + : llvm::any_of(llvm::concat( + std::get( + rhsElementQuantType) + .getZeroPoints(), + getElementTypeOrSelf(op.getResult()) + .cast() + .getZeroPoints()), + [](int64_t zp) { return zp != 0; })) { + op->emitError("RHS/result UQ type must have zero zp."); + return failure(); + } + // For per-channel quantization, RHS quantized axis must be out channel axis. + if (!isRhsQuantPerTensor && + (std::get(rhsElementQuantType) + .getQuantizedDimension() != + op.getRhs().getType().cast().getRank() - 1)) { + op->emitError("Conv quantized axis must be out channel axis"); + return failure(); + } + // For per-channel quantization, ratio between RHS and Result scales must be + // the same for each channel. + if (!isRhsQuantPerTensor) { + auto resElementQuantPerChannelType = + getElementTypeOrSelf(op.getResult()) + .cast(); + SmallVector scaleRatios( + resElementQuantPerChannelType.getScales().size()); + for (size_t i = 0; i < scaleRatios.size(); ++i) { + scaleRatios[i] = + resElementQuantPerChannelType.getScales()[i] / + std::get(rhsElementQuantType) + .getScales()[i]; + auto diff = (scaleRatios[i] - scaleRatios[0]) / scaleRatios[0]; + // Check all ratios within a threshold. + if (std::abs(diff) > 0.001) { + op->emitError( + "Per-channel quantizated Conv must have same RHS/Result scale " + "ratio for each channel"); + return failure(); + } + } + } + // lhs_dilation must not exist. + if (op.getLhsDilation().has_value() && + llvm::any_of(op.getLhsDilationAttr().getValues(), + [](int64_t dilate) { return dilate != 1; })) { + op->emitError("lhs_dilation must be 1."); + return failure(); + } + + // We only support NHWC Conv2D and NDHWC Conv3D. + auto dims = op.getDimensionNumbers(); + if (isConvNhwc(dims)) { + // 2D Convolution. + return DotLikeDimensionNumbers{/*lhs_batching_dims=*/{0}, + /*lhs_spatial_dims=*/{1, 2}, + /*lhs_contracting_dims=*/{3}, + /*rhs_batching_dims=*/{}, + /*rhs_spatial_dims=*/{0, 1}, + /*rhs_contracting_dims=*/{2}}; + } + if (isConvNDHWC(dims)) { + // 3D Convolution. + return DotLikeDimensionNumbers{/*lhs_batching_dims=*/{0}, + /*lhs_spatial_dims=*/{1, 2, 3}, + /*lhs_contracting_dims=*/{4}, + /*rhs_batching_dims=*/{}, + /*rhs_spatial_dims=*/{0, 1, 2}, + /*rhs_contracting_dims=*/{3}}; + } + op->emitError("Convolution data format must be NHWC."); + return failure(); +} + +class ConvertUniformQuantizedConvolutionOp + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::ConvolutionOp op, mhlo::ConvolutionOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto isHybrid = isDotLikeOpHybrid(op); + if (failed(isHybrid)) { + return failure(); + } + if (*isHybrid) { + return matchAndRewriteDotLikeHybridOp(op, adaptor, rewriter); + } + auto dims = verifyAndConstructDims(op); + if (failed(dims)) return failure(); + return matchAndRewriteDotLikeOp(op, adaptor, op->getAttrs(), *dims, + rewriter); + } +}; + +// This pattern lowers a generic MHLO op for uq->int. +// This pattern essentially just performs type change, with no algorithm change. +// TODO: b/310685906 - Add operand/result type validations. +class ConvertGenericOp : public ConversionPattern { + public: + explicit ConvertGenericOp(MLIRContext *ctx, TypeConverter &converter) + : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, ctx) {} + + LogicalResult matchAndRewrite( + Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + // This pattern only handle selected ops. + if (!isa(op)) { + return failure(); + } + + // Determine new result type: use storage type for uq types; use original + // type otherwise. + SmallVector newResultTypes; + for (auto resultType : op->getResultTypes()) { + newResultTypes.push_back(getQuantStorageType(resultType)); + } + + OperationState state(op->getLoc(), op->getName().getStringRef(), operands, + newResultTypes, op->getAttrs(), op->getSuccessors()); + for (Region ®ion : op->getRegions()) { + Region &newRegion = *state.addRegion(); + rewriter.inlineRegionBefore(region, newRegion, newRegion.begin()); + if (failed( + rewriter.convertRegionTypes(&newRegion, *getTypeConverter()))) { + return failure(); + } + } + Operation *newOp = rewriter.create(state); + rewriter.replaceOp(op, newOp); + return success(); + } +}; + +// TypeConverter for converting UQ type to int type. +class UniformQuantizedToIntTypeConverter : public TypeConverter { + public: + UniformQuantizedToIntTypeConverter() { + addConversion([](Type type) -> Type { return getQuantStorageType(type); }); + } +}; + +#define GEN_PASS_DEF_MHLOQUANTLEGALIZETOINT +#include "mhlo/transforms/mhlo_passes.h.inc" + +class MhloQuantLegalizeToInt + : public impl::MhloQuantLegalizeToIntBase { + public: + // Performs conversion of MHLO quant ops to primitive ops. + void runOnOperation() override { + Operation *op = getOperation(); + MLIRContext *context = op->getContext(); + RewritePatternSet patterns(context); + + // Populate MHLO quant ops conversion patterns. + patterns.add(context); + + // uq->int convert patterns for func.func, func.return and generic ops. + UniformQuantizedToIntTypeConverter converter; + patterns.add(context, converter); + populateFunctionOpInterfaceTypeConversionPattern(patterns, + converter); + populateReturnOpTypeConversionPattern(patterns, converter); + + ConversionTarget target(*op->getContext()); + target.addIllegalDialect(); + auto isLegal = [&converter](Operation *op) { + return converter.isLegal(op); + }; + target.addDynamicallyLegalDialect(isLegal); + target.addDynamicallyLegalDialect(isLegal); + target.addDynamicallyLegalDialect( + [&converter](Operation *op) { + if (auto func = dyn_cast(op)) { + return converter.isSignatureLegal(func.getFunctionType()); + } + return converter.isLegal(op); + }); + + LogicalResult result = + applyPartialConversion(op, target, std::move(patterns)); + if (failed(result)) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> +createMhloQuantLegalizeToIntPass() { + return std::make_unique(); +} + +} // namespace mlir::mhlo diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/passes.h b/third_party/xla/xla/mlir_hlo/mhlo/transforms/passes.h index 1198d9edf6cd00..a52c5b4ee7b2a4 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/passes.h +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/passes.h @@ -53,6 +53,10 @@ std::unique_ptr> createLegalizeToStdPass(); std::unique_ptr> createChloLegalizeToHloPass( bool legalizeBroadcasts = true, bool expandCompositions = true); +/// Lowers specific ops from the CHLO dialect to an HLO basis opset +std::unique_ptr> +createChloLegalizeToHloBasisOpsPass(); + // Lowers from sparse ops in CHLO dialect to Linalg dialect. std::unique_ptr> createLegalizeSparseOperationsPass( bool legalizeToCustomCalls = true); @@ -144,15 +148,6 @@ std::unique_ptr> createConstraintFusionPass(); std::unique_ptr> createGroupReductionDimensionsPass( bool preferColumnsReductions = true); -/// Rank specialization passes: -/// - Find compatible operations and group them together in one rank -/// specialization cluster. -/// - Lower rank specialization clusters to SCF and ranked operations. -std::unique_ptr> -createRankSpecializationClusterPass(); -std::unique_ptr> createRankSpecializationToSCFPass( - int64_t maxTargetRank = 5); - std::unique_ptr> createOptimizeMhloPass(); std::unique_ptr> createLowerComplexPass(); @@ -199,6 +194,10 @@ std::unique_ptr> createStablehloLegalizeToHloPass(); std::unique_ptr> createShapeLegalizeToHloPass( bool legalizeConstraints = false); +// Legalizes from MHLO quantized ops with MHLO quant types to MHLO primitive ops +// like int ops. +std::unique_ptr> createMhloQuantLegalizeToIntPass(); + // Test passes. std::unique_ptr createTestInferShapedTypeMethodsPass(); std::unique_ptr createTestMaterializeBroadcastsPass(); diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/prepare_for_export/prepare_for_export.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/prepare_for_export/prepare_for_export.cc index 5c594d9db2b500..3aceda3e7de692 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/prepare_for_export/prepare_for_export.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/prepare_for_export/prepare_for_export.cc @@ -15,6 +15,7 @@ limitations under the License. // This file implements logic for some optimizations to reduce size on export. +#include #include #include #include @@ -23,6 +24,7 @@ limitations under the License. #include "mhlo/IR/hlo_ops.h" #include "mhlo/transforms/passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" @@ -30,7 +32,9 @@ limitations under the License. #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/Region.h" #include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/RegionUtils.h" @@ -153,6 +157,36 @@ void prepareBroadcastInDim(BroadcastInDimOp bcast) { DenseIntElementsAttr::get(dims.getType(), transposedDim)); } +// Make implicitly captured constant explicit before exporting +void prepareExplicitCapturedConstants(Operation *op) { + for (Region ®ion : op->getRegions()) { + assert(region.getBlocks().size() == 1 && + "Only OPs with single block regions are allowed"); + llvm::SetVector implicitInputs; + // Get implicit inputs, i.e. those are used in the region + // but defined outside + getUsedValuesDefinedAbove(region, implicitInputs); + Block &block = region.getBlocks().front(); + OpBuilder builder(&block.front()); + for (Value input : implicitInputs) { + // If the captured value is defined by a constant OP, + // Create a clone constant OP within a block to make + // it explicit and replace uses within the block + Operation *definingOp = input.getDefiningOp(); + mlir::DenseElementsAttr attr; + if (matchPattern(input, m_Constant(&attr))) { + Operation *clonedOp = builder.clone(*definingOp); + // Find which uses belong to the block and replace + // with the cloned/explicit one + input.replaceUsesWithIf( + clonedOp->getResult(0), [&block](OpOperand &use) { + return block.getParentOp()->isProperAncestor(use.getOwner()); + }); + } + } + } +} + void PrepareForExportPass::runOnOperation() { getOperation().walk([&](Operation *op) { mlir::SplatElementsAttr attr; @@ -161,6 +195,11 @@ void PrepareForExportPass::runOnOperation() { if (auto whileOp = dyn_cast(op)) return prepareWhileOp(whileOp); if (auto bcastOp = dyn_cast(op)) return prepareBroadcastInDim(bcastOp); + // IfOp, CaseOp, WhileOp are already being handled during + // mhlo --> hlo translation. MapOp soon be deprecated. + if (mlir::isa(op)) + return prepareExplicitCapturedConstants(op); }); } diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/rank_specialization/rank_specialization.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/rank_specialization/rank_specialization.cc deleted file mode 100644 index 1444c86031b0ac..00000000000000 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/rank_specialization/rank_specialization.cc +++ /dev/null @@ -1,976 +0,0 @@ -/* Copyright 2021 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - -==============================================================================*/ - -#include -#include -#include -#include -#include - -#include "llvm/ADT/EquivalenceClasses.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallSet.h" -#include "llvm/ADT/SmallVector.h" -#include "mhlo/IR/hlo_ops.h" -#include "mhlo/transforms/passes.h" -#include "mhlo/transforms/rewriters.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Shape/IR/Shape.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/Block.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Interfaces/InferTypeOpInterface.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "stablehlo/dialect/ChloOps.h" - -namespace mlir { - -/// Needed to build `llvm::SmallSet`s and `llvm::EquivalenceClasses` of -/// `mlir::Value`s. -static bool operator<(const Value &lhs, const Value &rhs) { - return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer(); -} - -namespace mhlo { - -#define GEN_PASS_DEF_RANKSPECIALIZATIONCLUSTERPASS -#define GEN_PASS_DEF_RANKSPECIALIZATIONTOSCFPASS -#include "mhlo/transforms/mhlo_passes.h.inc" - -namespace { - -/// Identify clusters of operations that can be rank-specialized together. The -/// required traits for clustered operations are: -/// - Element-wise: All operations in the group must be element-wise. This -/// allows to reshape operands before applying the operations as well as -/// reshaping the result to the desired shape afterwards. This way, we can, -/// e.g., apply unary ops to a completely flattened operand and restore the -/// original shape afterwards. -/// - Broadcasting semantics: All operations must implement broadcasting -/// semantics. Most importantly, this allows extending operand shapes such -/// that they match in rank. Operations that require all their operands to -/// be of the same shape also fulfill this requirement. -/// - Shape reification: All operations must implement -/// `InferShapedTypeOpInterface`. This is later needed to compute and to -/// restore the desired result shape. - -bool isClusterable(Operation *op) { - if (!llvm::isa(op)) return false; - if (op->getNumOperands() == 0) return false; - return (op->hasTrait() && - op->hasTrait()) || - op->hasTrait(); -} - -struct RankSpecializationClusterPattern : public RewritePattern { - explicit RankSpecializationClusterPattern(MLIRContext *ctx) - : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {} - - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { - // Only apply to operations that have not been clustered yet. - if (op->getParentOfType()) { - return failure(); - } - - // Only cluster when rank specialization is needed. - if (!isClusterable(op) || !llvm::any_of(op->getOperandTypes(), [](Type ty) { - return ty.isa(); - })) { - return failure(); - } - - // Collect all collectively rank specializable ops. - SmallVector cluster; - llvm::SmallSet operandSet; - llvm::SmallSet resultSet; - - Operation *rootOp = op; - while (rootOp->getNextNode() != nullptr && - isClusterable(rootOp->getNextNode())) - rootOp = rootOp->getNextNode(); - - Operation *it = rootOp; - while (it != nullptr && isClusterable(it)) { - // Find results that escape the cluster. - for (OpOperand &use : it->getUses()) { - if (!llvm::is_contained(cluster, use.getOwner())) - resultSet.insert(use.get()); - } - - // Update cluster operands. - for (OpResult v : it->getResults()) operandSet.erase(Value(v)); - for (OpOperand &v : it->getOpOperands()) operandSet.insert(v.get()); - - cluster.push_back(it); - it = it->getPrevNode(); - } - - // Create `RankSpecializationClusterOp`. - auto operands = llvm::to_vector<16>(operandSet); - auto results = llvm::to_vector<16>(resultSet); - auto resultTypes = llvm::to_vector<16>( - llvm::map_range(resultSet, [](Value v) { return v.getType(); })); - Location loc = op->getLoc(); - auto clusterOp = rewriter.create( - loc, resultTypes, operands); - - // Create body block. - auto operandTypes = llvm::to_vector<16>( - llvm::map_range(operandSet, [](Value v) { return v.getType(); })); - Block *block = - rewriter.createBlock(&clusterOp.getBody(), {}, operandTypes, - SmallVector(operandTypes.size(), loc)); - - // Copy operations into the body. - IRMapping bvm; - for (auto it : llvm::zip(operands, block->getArguments())) - bvm.map(std::get<0>(it), std::get<1>(it)); - rewriter.setInsertionPointToStart(block); - for (Operation *it : llvm::reverse(cluster)) rewriter.clone(*it, bvm); - - // Create `RankSpecializationClusterYieldOp`. - auto mappedResults = llvm::to_vector<16>( - llvm::map_range(results, [&](Value v) { return bvm.lookup(v); })); - rewriter.create(loc, mappedResults); - - // Replace original ops with the new results. - for (auto it : llvm::zip(results, clusterOp.getResults())) - bvm.map(std::get<0>(it), std::get<1>(it)); - for (Operation *it : cluster) { - if (it->getUses().empty()) { - rewriter.eraseOp(it); - continue; - } - auto replacements = llvm::to_vector<16>(llvm::map_range( - it->getResults(), [&](Value v) { return bvm.lookup(v); })); - rewriter.replaceOp(it, replacements); - } - - return success(); - } -}; - -struct MergeRankSpecializationClusterOpsPattern - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op, - PatternRewriter &rewriter) const override { - auto precedingOp = - llvm::dyn_cast_or_null( - op->getPrevNode()); - if (!precedingOp) return failure(); - Block *body = op.SingleBlock::getBody(); - Block *precedingBody = precedingOp.SingleBlock::getBody(); - auto yieldOp = llvm::dyn_cast( - op.SingleBlock::getBody()->getTerminator()); - auto precedingYieldOp = - llvm::dyn_cast( - precedingOp.SingleBlock::getBody()->getTerminator()); - - // Merge cluster operands. Consider only those operands of the second - // cluster that do not originate in the preceding cluster. - SmallVector newOperands; - for (Value v : precedingOp.getOperands()) newOperands.push_back(v); - for (Value v : op.getOperands()) { - if (v.getDefiningOp() != precedingOp && - !llvm::is_contained(precedingOp.getOperands(), v)) { - newOperands.push_back(v); - } - } - - // Merge cluster results. Consider only those results of the preceding - // cluster that are not exclusively used as operands to the second cluster. - SmallVector newUnmappedResults; - for (auto it : - llvm::zip(precedingOp.getResults(), precedingYieldOp.getResults())) { - Value result, innerResult; - std::tie(result, innerResult) = it; - if (!llvm::all_of(result.getUsers(), - [&](Operation *user) { return user == op; })) { - newUnmappedResults.push_back(innerResult); - } - } - for (Value v : yieldOp.getResults()) newUnmappedResults.push_back(v); - - // Create merged cluster op. - rewriter.setInsertionPoint(precedingOp); - auto loc = op.getLoc(); - auto resultTypes = llvm::to_vector<16>(llvm::map_range( - newUnmappedResults, [](Value v) { return v.getType(); })); - auto newOp = rewriter.create( - loc, resultTypes, newOperands); - auto operandTypes = llvm::to_vector<16>( - llvm::map_range(newOperands, [](Value v) { return v.getType(); })); - Block *newBody = - rewriter.createBlock(&newOp.getBody(), {}, operandTypes, - SmallVector(operandTypes.size(), loc)); - rewriter.setInsertionPointToStart(newBody); - - // Map operands and copy operations of the preceding cluster into the new - // body. - IRMapping bvm; - for (const auto &it : llvm::enumerate(precedingBody->getArguments())) - bvm.map(it.value(), newBody->getArgument(it.index())); - for (Operation &nestedOp : precedingBody->without_terminator()) - rewriter.clone(nestedOp, bvm); - - // Map operands and copy operations of the second cluster. If they result - // from the preceeding cluster, we can simply map the corresponding value - // internally. - for (auto it : llvm::zip(body->getArguments(), op.getOperands())) { - Value blockArg, operand; - std::tie(blockArg, operand) = it; - if (operand.getDefiningOp() == precedingOp) { - auto where = llvm::find(precedingOp.getResults(), operand); - assert(where.getBase() != nullptr && "expected to find "); - bvm.map(blockArg, - bvm.lookup(precedingYieldOp.getOperand(where.getIndex()))); - } else { - auto where = llvm::find(newOp.getOperands(), operand); - bvm.map(blockArg, newBody->getArgument(where.getIndex())); - } - } - for (Operation &nestedOp : body->without_terminator()) { - rewriter.clone(nestedOp, bvm); - } - - // Yield inner results. - rewriter.create( - loc, - llvm::to_vector<16>(llvm::map_range(newUnmappedResults, [&](Value v) { - return bvm.lookupOrDefault(v); - }))); - - // Replace the two cluster ops with the new corresponding results. - SmallVector precedingOpReplacements; - int64_t i = 0; - for (Value result : precedingOp.getResults()) { - Value replacement = nullptr; - if (!llvm::all_of(result.getUsers(), - [&](Operation *user) { return user == op; })) { - replacement = newOp->getResult(i++); - } - precedingOpReplacements.push_back(replacement); - } - ValueRange opReplacements = - newOp.getResults().take_back(op.getNumResults()); - rewriter.replaceOp(op, opReplacements); - rewriter.replaceOp(precedingOp, precedingOpReplacements); - - return success(); - } -}; - -struct RankSpecializationClusterPass - : public impl::RankSpecializationClusterPassBase< - RankSpecializationClusterPass> { - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - MLIRContext *ctx = &getContext(); - RewritePatternSet patterns(ctx); - mhlo::populateRankSpecializationClusterPatterns(ctx, &patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { - return signalPassFailure(); - } - } -}; - -/// Lower rank specialization cluster to SCF. - -bool isScalarTensorType(Type ty) { - auto rankedTy = ty.dyn_cast(); - return rankedTy && rankedTy.getRank() == 0; -} - -bool isScalarShapeType(Type ty) { - return ty.cast().getDimSize(0) == 0; -} - -Type deriveRankedTensorTypes(Type ty, int64_t rank) { - auto tensorTy = ty.dyn_cast(); - if (!tensorTy) return ty; - SmallVector shape(rank, ShapedType::kDynamic); - return RankedTensorType::get(shape, tensorTy.getElementType()); -} - -Type deriveUnrankedTensorTypes(Type ty) { - if (auto rankedTy = ty.dyn_cast()) - return UnrankedTensorType::get(rankedTy.getElementType()); - return ty; -} - -SmallVector materializeRankedOperations( - OpBuilder &b, Location loc, IRMapping &bvm, - chlo::RankSpecializationClusterOp op) { - // Create ranked operations. - for (Operation &nestedOp : op.SingleBlock::getBody()->without_terminator()) { - auto mappedOperands = llvm::to_vector<4>(llvm::map_range( - nestedOp.getOperands(), [&](Value v) { return bvm.lookup(v); })); - int64_t targetRank = 0; - for (Value v : mappedOperands) { - targetRank = - std::max(targetRank, v.getType().cast().getRank()); - } - auto rankedResultTypes = llvm::to_vector<2>( - llvm::map_range(nestedOp.getResultTypes(), [targetRank](Type ty) { - return deriveRankedTensorTypes(ty, targetRank); - })); - OperationState rankedOpState(loc, nestedOp.getName().getStringRef(), - mappedOperands, rankedResultTypes, - nestedOp.getAttrs()); - Operation *rankedOp = b.create(rankedOpState); - for (auto it : llvm::zip(nestedOp.getResults(), rankedOp->getResults())) - bvm.map(std::get<0>(it), std::get<1>(it)); - } - - // Collect ranked results. - auto yieldOp = llvm::cast( - op.SingleBlock::getBody()->getTerminator()); - return llvm::to_vector<8>(llvm::map_range( - yieldOp.getResults(), [&](Value v) { return bvm.lookup(v); })); -} - -SmallVector materializeFinalReshape( - PatternRewriter &rewriter, Location loc, - chlo::RankSpecializationClusterOp op, ValueRange unshapedResults) { - auto yieldOp = llvm::cast( - op.SingleBlock::getBody()->getTerminator()); - assert(unshapedResults.size() == 1 && yieldOp.getResults().size() == 1 && - "Currently, rank specialization supports only one result."); - - // Reify result shape. - Operation *lastOpBeforeShapeReification = op->getPrevNode(); - SmallVector resultShape; - Value originalResult = yieldOp.getResults().front(); - auto originalResultIface = - llvm::cast(originalResult.getDefiningOp()); - if (failed(originalResultIface.reifyReturnTypeShapes( - rewriter, originalResultIface->getOperands(), resultShape))) { - return {}; - } - - // Materialize final reshape. - Value unshapedResult = unshapedResults.front(); - Value result = rewriter.create( - loc, deriveUnrankedTensorTypes(unshapedResult.getType()), unshapedResult, - resultShape.front()); - - // Reify shapes until they are independent of operations in the original - // cluster. - { - Operation *it = resultShape.front().getDefiningOp(); - while (it != nullptr && it != lastOpBeforeShapeReification) { - bool advanced = false; - if (auto shapeOfOp = llvm::dyn_cast(it)) { - Operation *def = shapeOfOp.getArg().getDefiningOp(); - if (def && def->getBlock() == op.SingleBlock::getBody()) { - // Resolve `shape_of` op because it still depends on operation in the - // original cluster. - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(shapeOfOp); - SmallVector tmpShape; - auto iface = llvm::cast(def); - if (failed(iface.reifyReturnTypeShapes(rewriter, iface->getOperands(), - tmpShape))) - return {}; - rewriter.replaceOp(shapeOfOp, tmpShape.front()); - - // Continue, including the newly created operations. - it = tmpShape.front().getDefiningOp(); - advanced = true; - } - } - - // Skip op, otherwise. - if (!advanced) it = it->getPrevNode(); - } - } - - // Replace all remaining uses of the original cluster's block args. - for (auto it : - llvm::zip(op.getOperands(), op.SingleBlock::getBody()->getArguments())) { - Value operand, barg; - std::tie(operand, barg) = it; - barg.replaceUsesWithIf(operand, [&](OpOperand &operand) { - return operand.getOwner()->getBlock() != op.SingleBlock::getBody(); - }); - } - - return {result}; -} - -Value materializeFlatShape(OpBuilder &b, Location loc, ValueRange sameShapes) { - assert(!sameShapes.empty() && "Expected at least one shape."); - Value shape = sameShapes.size() == 1 - ? sameShapes.front() - : b.create(loc, sameShapes.front().getType(), - sameShapes); - return b.create( - loc, - b.create(loc, b.getIndexType(), shape).getResult()); -} - -Value materializeScalarRankSpecializationCase( - OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op, - const SmallVector &shapes, ValueRange nonScalarsOfSameShape, - function_ref elseBuilderFn) { - // Materialize predicate: All operands are scalars, except the expected - // non-scalars. - Value one = b.create(loc, 1); - Value allOthersAreScalar; - for (auto it : llvm::zip(op.getOperands(), shapes)) { - Value operand, shape; - std::tie(operand, shape) = it; - if (llvm::is_contained(nonScalarsOfSameShape, operand) || - isScalarTensorType(operand.getType())) { - continue; - } - auto literal = b.create( - loc, arith::CmpIPredicate::eq, - b.create(loc, shape), one); - allOthersAreScalar = - allOthersAreScalar - ? b.create(loc, allOthersAreScalar, literal) - .getResult() - : literal.getResult(); - } - - auto ifOp = b.create( - loc, allOthersAreScalar, - [&](OpBuilder &b, Location loc) { - // Compute flat non-scalar shape. - SmallVector nonScalarShapes; - for (auto it : llvm::zip(op.getOperands(), shapes)) { - Value operand, shape; - std::tie(operand, shape) = it; - if (llvm::is_contained(nonScalarsOfSameShape, operand)) - nonScalarShapes.push_back(shape); - } - Value flatShape = materializeFlatShape(b, loc, nonScalarShapes); - - // Derive ranked operands. - auto rankedOperands = llvm::to_vector<8>( - llvm::map_range(op.getOperands(), [&](Value v) -> Value { - if (isScalarTensorType(v.getType())) return v; - if (!llvm::is_contained(nonScalarsOfSameShape, v)) { - return b - .create( - loc, deriveRankedTensorTypes(v.getType(), /*rank=*/0), - v) - .getResult(); - } - return b - .create( - loc, deriveRankedTensorTypes(v.getType(), /*rank=*/1), v, - flatShape) - .getResult(); - })); - - // Materialize ranked variants for the element-wise operations. - IRMapping bvm; - for (auto it : llvm::zip(op.SingleBlock::getBody()->getArguments(), - rankedOperands)) - bvm.map(std::get<0>(it), std::get<1>(it)); - Value unshapedResult = - materializeRankedOperations(b, loc, bvm, op).front(); - - // Return as unranked tensor for compatibility with the other cases. - b.create( - loc, b.create( - loc, deriveUnrankedTensorTypes(unshapedResult.getType()), - unshapedResult) - .getDest()); - }, - elseBuilderFn); - - return ifOp.getResults().front(); -} - -Value materializeEqualShapesRankSpecializationCase( - OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op, - const SmallVector &shapes, - function_ref elseBuilderFn) { - // Materialize all shapes equal predicate. - Value allShapesEqOrScalar; - auto nonScalarShapes = llvm::to_vector<8>(llvm::make_filter_range( - shapes, [](Value v) { return !isScalarShapeType(v.getType()); })); - assert( - nonScalarShapes.size() >= 2 && - "Equal shapes strategy requires at least two non-scalar operand shapes."); - for (Value s : llvm::drop_begin(nonScalarShapes)) { - auto literal = b.create(loc, nonScalarShapes.front(), s); - allShapesEqOrScalar = - allShapesEqOrScalar - ? b.create(loc, allShapesEqOrScalar, literal) - .getResult() - : literal; - } - - auto ifOp = b.create( - loc, allShapesEqOrScalar, - [&](OpBuilder &b, Location loc) { - // Flatten non-scalar operands. - Value flatShape = materializeFlatShape(b, loc, nonScalarShapes); - auto flatOperands = llvm::to_vector<8>( - llvm::map_range(op.getOperands(), [&](Value v) -> Value { - if (isScalarTensorType(v.getType())) return v; - return b.create( - loc, deriveRankedTensorTypes(v.getType(), /*rank=*/1), v, - flatShape); - })); - - // Materialize ranked variants for the element-wise operations. - IRMapping bvm; - for (auto it : - llvm::zip(op.SingleBlock::getBody()->getArguments(), flatOperands)) - bvm.map(std::get<0>(it), std::get<1>(it)); - Value unshapedResult = - materializeRankedOperations(b, loc, bvm, op).front(); - - // Return as unranked tensor for compatibility with the other cases. - b.create( - loc, b.create( - loc, deriveUnrankedTensorTypes(unshapedResult.getType()), - unshapedResult) - .getDest()); - }, - elseBuilderFn); - - return ifOp.getResults().front(); -} - -Value materializeTargetRankSpecializationCase( - OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op, - const SmallVector &shapes, int64_t targetRank) { - // Reshape unranked operands to match the target rank. - RankedTensorType extentTensorTy = - shape::getExtentTensorType(b.getContext(), targetRank); - Value allOnesShape = b.create( - loc, extentTensorTy, - mlir::DenseIntElementsAttr::get(extentTensorTy, - SmallVector(targetRank, 1))); - SmallVector rankedOperands; - for (auto it : llvm::zip(op.getOperands(), shapes)) { - Value operand, shape; - std::tie(operand, shape) = it; - if (operand.getType().isa()) { - rankedOperands.push_back(operand); - continue; - } - Value rankedShape = b.create( - loc, extentTensorTy, - b.create(loc, - shape::getExtentTensorType(b.getContext()), - shape, allOnesShape, - /*error=*/nullptr)); - rankedOperands.push_back(b.create( - loc, deriveRankedTensorTypes(operand.getType(), targetRank), operand, - rankedShape)); - } - - // Materialize ranked versions of the element-wise operations. - IRMapping bvm; - for (auto it : llvm::zip(op.getBody().front().getArguments(), rankedOperands)) - bvm.map(std::get<0>(it), std::get<1>(it)); - - // Return as unranked for compatibility with other target ranks. - auto unshapedResult = materializeRankedOperations(b, loc, bvm, op).front(); - return b.create( - loc, deriveUnrankedTensorTypes(unshapedResult.getType()), unshapedResult); -} - -Value recusivelyMaterializeTargetRankSpecializationCases( - OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op, - const SmallVector &shapes, Value maxRank, int64_t minTargetRank, - int64_t maxTargetRank) { - Value condition = b.create( - loc, arith::CmpIPredicate::ule, maxRank, - b.create(loc, minTargetRank)); - - // If only a unique target rank is left, we can lower to an assert instead - // of the usual if operation. - if (minTargetRank == maxTargetRank) { - b.create( - loc, condition, - "Input for dynamic binary or n-ary op lowering was of " - "a rank greater than " + - std::to_string(maxTargetRank)); - return materializeTargetRankSpecializationCase(b, loc, op, shapes, - minTargetRank); - } - - // Materialize IR for the smallest considered target rank. - auto ifOp = b.create(loc, op->getResultTypes(), condition, - /*withElseRegion=*/true); - auto thenBuilder = ifOp.getThenBodyBuilder(); - thenBuilder.create( - loc, materializeTargetRankSpecializationCase(thenBuilder, loc, op, shapes, - minTargetRank)); - - // Recurse for all remaining target ranks. - auto elseBuilder = ifOp.getElseBodyBuilder(); - elseBuilder.create( - loc, recusivelyMaterializeTargetRankSpecializationCases( - elseBuilder, loc, op, shapes, maxRank, minTargetRank + 1, - maxTargetRank)); - - return ifOp.getResults().front(); -} - -Value materializeGenericRankSpecializationCases( - OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op, - const SmallVector &shapes, int64_t maxTargetRank) { - // Get the minimum broadcast shapes of the operands. - auto nonScalarShapes = llvm::to_vector<8>(llvm::make_filter_range( - shapes, [](Value v) { return !isScalarShapeType(v.getType()); })); - auto minBcastShapesOp = b.create( - loc, - SmallVector(nonScalarShapes.size(), - shape::getExtentTensorType(b.getContext())), - nonScalarShapes); - - // Find the maximum rank among the reduced operand shapes. - Value maxRank; - for (Value shape : minBcastShapesOp.getResults()) { - Value rank = b.create(loc, b.getIndexType(), shape); - if (!maxRank) { - maxRank = rank; - } else { - maxRank = b.create( - loc, - b.create(loc, arith::CmpIPredicate::sgt, maxRank, - rank), - maxRank, rank); - } - } - - // Collect reduced shapes. - SmallVector reducedShapes; - auto it = minBcastShapesOp.result_begin(); - for (Value s : shapes) { - if (isScalarShapeType(s.getType())) { - reducedShapes.push_back(s); - } else { - reducedShapes.push_back(*it++); - } - } - - // Materialize rank specialization for ranks 1, ... - return recusivelyMaterializeTargetRankSpecializationCases( - b, loc, op, reducedShapes, maxRank, /*minTargetRank=*/1, maxTargetRank); -} - -Value materializeDefaultRankSpecializationCases( - OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op, - const SmallVector &shapes, int64_t maxTargetRank) { - return materializeEqualShapesRankSpecializationCase( - b, loc, op, shapes, [&](OpBuilder &b, Location loc) { - b.create(loc, materializeGenericRankSpecializationCases( - b, loc, op, shapes, maxTargetRank)); - }); -} - -SmallVector -materializeRankSpecializationForSingleNonScalarShapeEquivalenceClass( - PatternRewriter &rewriter, Location loc, - chlo::RankSpecializationClusterOp op, ValueRange nonScalarsOfSameShape) { - // Compute flat operand shape. - auto nonScalarShapes = - llvm::to_vector<4>(llvm::map_range(nonScalarsOfSameShape, [&](Value v) { - return rewriter.create(loc, v).getResult(); - })); - Value flatShape = materializeFlatShape(rewriter, loc, nonScalarShapes); - - // Materialize ranked variants for the element-wise operations. - IRMapping bvm; - for (auto it : - llvm::zip(op.SingleBlock::getBody()->getArguments(), op.getOperands())) { - Value operand; - Value bbArg; - std::tie(bbArg, operand) = it; - if (!isScalarTensorType(operand.getType())) { - assert(llvm::is_contained(nonScalarsOfSameShape, operand) && - "Expected all non-scalars in the same shape equivalence class."); - operand = rewriter.create( - loc, deriveRankedTensorTypes(operand.getType(), /*rank=*/1), operand, - flatShape); - } - bvm.map(bbArg, operand); - } - SmallVector unshapedResults = - materializeRankedOperations(rewriter, loc, bvm, op); - - // Restore the results' expected shape. - Value shape = nonScalarShapes.front(); - return llvm::to_vector<8>( - llvm::map_range(unshapedResults, [&](Value v) -> Value { - return rewriter.create( - loc, deriveUnrankedTensorTypes(v.getType()), v, shape); - })); -} - -Value materializeRankSpecializationForTwoNonScalarShapeEquivalenceClasses( - PatternRewriter &rewriter, Location loc, - chlo::RankSpecializationClusterOp op, - SmallVector, 4> nonScalarEqs, int64_t maxTargetRank) { - assert(nonScalarEqs.size() == 2 && - "Expect two non-scalar equivalence classes."); - auto shapes = - llvm::to_vector<8>(llvm::map_range(op.getOperands(), [&](Value v) { - return rewriter.create(loc, v).getResult(); - })); - ValueRange lhsNonScalarEqs = nonScalarEqs[0]; - ValueRange rhsNonScalarEqs = nonScalarEqs[1]; - - // Materialize all the different cases. - Value unshapedResult = materializeScalarRankSpecializationCase( - rewriter, loc, op, shapes, rhsNonScalarEqs, - [&](OpBuilder &b, Location loc) { - b.create( - loc, materializeScalarRankSpecializationCase( - b, loc, op, shapes, lhsNonScalarEqs, - [&](OpBuilder &b, Location loc) { - b.create( - loc, materializeDefaultRankSpecializationCases( - b, loc, op, shapes, maxTargetRank)); - })); - }); - - // Materialize final reshape once and for all rank specialization cases. - return materializeFinalReshape(rewriter, loc, op, unshapedResult).front(); -} - -// Materialize rank generic rank specialization. -Value materializeDefaultRankSpecialization(PatternRewriter &rewriter, - Location loc, - chlo::RankSpecializationClusterOp op, - int64_t maxTargetRank) { - auto shapes = - llvm::to_vector<8>(llvm::map_range(op.getOperands(), [&](Value v) { - return rewriter.create(loc, v).getResult(); - })); - - // Materialize all the different cases. - Value unshapedResult = materializeDefaultRankSpecializationCases( - rewriter, loc, op, shapes, maxTargetRank); - - // Materialize final reshape once and for all rank specialization cases. - return materializeFinalReshape(rewriter, loc, op, unshapedResult).front(); -} - -// This is a very limited form of shape inference. It is correct but incomplete. -SmallVector, 4> findNonScalarShapeEquivalences( - chlo::RankSpecializationClusterOp op) { - llvm::EquivalenceClasses eqs; - - // Bridge the equivalences between operands and block arguments. - for (auto it : - llvm::zip(op.getOperands(), op.SingleBlock::getBody()->getArguments())) - eqs.unionSets(std::get<0>(it), std::get<1>(it)); - - // Find equalities through `SameOperandsAndResultShape` trait. - auto unionSets = [&](ValueRange vs) { - if (vs.empty()) return; - Value repr = vs.front(); - for (Value v : vs.drop_front()) eqs.unionSets(repr, v); - }; - for (Operation &nestedOp : op.SingleBlock::getBody()->without_terminator()) { - if (nestedOp.hasTrait()) { - unionSets(nestedOp.getOperands()); - unionSets(nestedOp.getResults()); - if (!nestedOp.getOperands().empty() && !nestedOp.getResults().empty()) - eqs.unionSets(nestedOp.getResult(0), nestedOp.getOperand(0)); - } - } - - // Find shape equalities through surrounding constraints. - if (auto assumingOp = op->getParentOfType()) { - SmallVector queue; - auto appendIfNotNull = [&](Operation *op) { - if (op != nullptr) queue.push_back(op); - }; - appendIfNotNull(assumingOp.getWitness().getDefiningOp()); - while (!queue.empty()) { - Operation *it = queue.pop_back_val(); - if (auto assumingAllOp = llvm::dyn_cast(it)) { - for (Value v : assumingAllOp.getInputs()) - appendIfNotNull(v.getDefiningOp()); - } else if (auto cstrEqOp = llvm::dyn_cast(it)) { - Value refArg; - for (Value v : cstrEqOp.getShapes()) { - if (auto shapeOfOp = - dyn_cast_or_null(v.getDefiningOp())) { - if (!refArg) { - refArg = shapeOfOp.getArg(); - } else { - eqs.unionSets(refArg, shapeOfOp.getArg()); - } - } - } - } - } - } - - // Find equalities through special knowledge of ops. - // TODO(frgossen): Remove this when these shape equalities can be inferred - // from surrounding shape constraints. - for (Operation &nestedOp : op.SingleBlock::getBody()->without_terminator()) { - if (auto selectOp = llvm::dyn_cast(nestedOp)) { - unionSets( - {selectOp.getOnTrue(), selectOp.getOnFalse(), selectOp.getResult()}); - } else if (auto clampOp = llvm::dyn_cast(nestedOp)) { - unionSets({clampOp.getOperand(), clampOp.getResult()}); - } - } - - // Convert to a list-like equivalence class representation. - SmallVector, 4> nonScalarEqs; - for (Value v : op.getOperands()) { - if (isScalarTensorType(v.getType())) continue; - bool inserted = false; - for (auto &eqClass : nonScalarEqs) { - if (eqs.isEquivalent(eqClass.front(), v)) { - eqClass.push_back(v); - inserted = true; - break; - } - } - if (!inserted) nonScalarEqs.push_back(SmallVector({v})); - } - - return nonScalarEqs; -} - -struct LowerRankSpecializationClusterPattern - : public OpRewritePattern { - LowerRankSpecializationClusterPattern(MLIRContext *ctx, int64_t maxTargetRank) - : OpRewritePattern(ctx, /*benefit=*/1), - maxTargetRank(maxTargetRank) {} - - LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op, - PatternRewriter &rewriter) const override { - // Restoring the result shape currently relies on all operands being used - // for a single result. The result shape is then the broadcasted shape of - // all operands. - if (op.getNumResults() != 1) return failure(); - - // If there is only a single non-scalar shape equivalence class, we can - // flatten that operands completely. - SmallVector, 4> nonScalarEqs = - findNonScalarShapeEquivalences(op); - Location loc = op.getLoc(); - if (nonScalarEqs.size() == 1) { - rewriter.replaceOp( - op, - materializeRankSpecializationForSingleNonScalarShapeEquivalenceClass( - rewriter, loc, op, nonScalarEqs.front())); - return success(); - } - - // If there are exactly two non-scalar shape equivalence classes, we can - // consider two extra cases: If either of the operand classes turns out to - // be all-scalars at runtime, we can, again, flatten all operands. - if (nonScalarEqs.size() == 2) { - rewriter.replaceOp( - op, - materializeRankSpecializationForTwoNonScalarShapeEquivalenceClasses( - rewriter, loc, op, nonScalarEqs, maxTargetRank)); - return success(); - } - - // For all other cases, reshape the operands to match in rank, apply the - // operation, and restore the expected shape. - rewriter.replaceOp(op, materializeDefaultRankSpecialization( - rewriter, loc, op, maxTargetRank)); - return success(); - } - - private: - int64_t maxTargetRank; -}; - -struct RankSpecializationToSCFPass - : public impl::RankSpecializationToSCFPassBase< - RankSpecializationToSCFPass> { - explicit RankSpecializationToSCFPass(int64_t maxTargetRank) - : RankSpecializationToSCFPassBase< - RankSpecializationToSCFPass>::RankSpecializationToSCFPassBase() { - this->max_target_rank_ = maxTargetRank; - } - - void getDependentDialects(DialectRegistry ®istry) const override { - registry - .insert(); - } - - void runOnOperation() override { - MLIRContext *ctx = &getContext(); - RewritePatternSet patterns(ctx); - populateRankSpecializationToSCFPatterns(ctx, &patterns, - this->max_target_rank_); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { - return signalPassFailure(); - } - } -}; - -} // namespace - -void populateRankSpecializationClusterPatterns(MLIRContext *context, - RewritePatternSet *patterns) { - patterns->add(context); -} - -void populateRankSpecializationToSCFPatterns(MLIRContext *context, - RewritePatternSet *patterns, - int64_t maxTargetRank) { - patterns->add(context, maxTargetRank); - shape::BroadcastOp::getCanonicalizationPatterns(*patterns, context); - shape::ShapeOfOp::getCanonicalizationPatterns(*patterns, context); - shape::AnyOp::getCanonicalizationPatterns(*patterns, context); -} - -std::unique_ptr> -createRankSpecializationClusterPass() { - return std::make_unique(); -} - -std::unique_ptr> createRankSpecializationToSCFPass( - int64_t maxTargetRank) { - return std::make_unique(maxTargetRank); -} - -} // namespace mhlo -} // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/rewriters.h b/third_party/xla/xla/mlir_hlo/mhlo/transforms/rewriters.h index c8f8aa4b885dd8..4de69f27d268e0 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/rewriters.h +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/rewriters.h @@ -203,6 +203,12 @@ void populateChloBroadcastingPatterns(MLIRContext *context, void populateDecomposeChloPatterns(MLIRContext *context, RewritePatternSet *patterns); +// Adds pattern to decompose specific CHLO ops like ErfOp and TopKOp to their +// basis set of operations. These ops have 1:1 corresponding MHLO ops, but for +// certain backends, they need to be expanded. +void populateChloLegalizeToHloBasisOpsPatterns(MLIRContext *context, + RewritePatternSet *patterns); + } // namespace chlo namespace stablehlo { diff --git a/third_party/xla/xla/mlir_hlo/tests/BUILD b/third_party/xla/xla/mlir_hlo/tests/BUILD index 5e88c0e5ebc584..4db81efeb2dd52 100644 --- a/third_party/xla/xla/mlir_hlo/tests/BUILD +++ b/third_party/xla/xla/mlir_hlo/tests/BUILD @@ -3,7 +3,6 @@ load("@bazel_skylib//rules:expand_template.bzl", "expand_template") load("@llvm-project//llvm:lit_test.bzl", "lit_test", "package_path") package( - default_visibility = ["//visibility:public"], # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) @@ -35,7 +34,6 @@ package( cc_library( name = "capi_test", srcs = ["capi_test.c"], - visibility = ["//visibility:public"], deps = [ "//xla/mlir_hlo:CAPI", ], diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_hlo_broadcasts.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_hlo_broadcasts.mlir index ce5a493fd6e894..512e1cea6ff752 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_hlo_broadcasts.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_hlo_broadcasts.mlir @@ -151,7 +151,7 @@ func.func @selectv2_dynamic_ranked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32> // CHECK-LABEL: @dynamicNonScalarBroadcastDimensions func.func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { // CHECK: mhlo.add - %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> + %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = array} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> func.return %0 : tensor<1x4xf32> } @@ -160,7 +160,7 @@ func.func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: te // CHECK-LABEL: @dynamicNonScalarByScalarBroadcastDimensions func.func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor) -> tensor<1x4xf32> { // CHECK: mhlo.add - %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x4xf32>, tensor) -> tensor<1x4xf32> + %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = array} : (tensor<1x4xf32>, tensor) -> tensor<1x4xf32> func.return %0 : tensor<1x4xf32> } @@ -169,7 +169,7 @@ func.func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, % func.func @dynamicNonScalarBroadcastDimensionsSizeMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { // expected-warning @+2 {{unsupported non prefix-padded dynamic rank broadcast_dimensions}} // expected-error @+1 {{failed to legalize operation}} - %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> + %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = array} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> func.return %0 : tensor<1x4xf32> } @@ -178,7 +178,7 @@ func.func @dynamicNonScalarBroadcastDimensionsSizeMismatch(%arg0: tensor<1x4xf32 func.func @dynamicNonScalarBroadcastDimensionsMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { // expected-warning @+2 {{unsupported non prefix-padded dynamic rank broadcast_dimensions}} // expected-error @+1 {{failed to legalize operation}} - %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> + %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = array} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> func.return %0 : tensor<1x4xf32> } diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir index 3e72efd620c746..d30a4a8136e9d7 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir @@ -262,147 +262,7 @@ func.func @conj(%arg0: tensor<3xcomplex>) -> tensor<3xcomplex> { // CHECK-LABEL: @erf_f64 // CHECK-SAME: %[[ARG:.*]]: tensor func.func @erf_f64(%arg : tensor) -> tensor { - // CHECK: %[[TMP_0:.*]] = mhlo.multiply %[[ARG]], %[[ARG]] - // CHECK: %[[TMP_3:.*]] = mhlo.constant dense<9.6049737398705161> - // CHECK: %[[TMP_5:.*]] = mhlo.multiply %[[TMP_3]], %[[TMP_0]] - // CHECK: %[[TMP_6:.*]] = mhlo.constant dense<90.026019720384269> - // CHECK: %[[TMP_7:.*]] = mhlo.add %[[TMP_5]], %[[TMP_6]] - // CHECK: %[[TMP_8:.*]] = mhlo.multiply %[[TMP_7]], %[[TMP_0]] - // CHECK: %[[TMP_9:.*]] = mhlo.constant dense<2232.0053459468431> - // CHECK: %[[TMP_10:.*]] = mhlo.add %[[TMP_8]], %[[TMP_9]] - // CHECK: %[[TMP_11:.*]] = mhlo.multiply %[[TMP_10]], %[[TMP_0]] - // CHECK: %[[TMP_12:.*]] = mhlo.constant dense<7003.3251411280507> - // CHECK: %[[TMP_13:.*]] = mhlo.add %[[TMP_11]], %[[TMP_12]] - // CHECK: %[[TMP_14:.*]] = mhlo.multiply %[[TMP_13]], %[[TMP_0]] - // CHECK: %[[TMP_15:.*]] = mhlo.constant dense<55592.301301039493> - // CHECK: %[[TMP_16:.*]] = mhlo.add %[[TMP_14]], %[[TMP_15]] - // CHECK: %[[TMP_17:.*]] = mhlo.multiply %[[ARG]], %[[TMP_16]] - // CHECK: %[[TMP_20:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_22:.*]] = mhlo.multiply %[[TMP_20]], %[[TMP_0]] - // CHECK: %[[TMP_23:.*]] = mhlo.constant dense<33.561714164750313> - // CHECK: %[[TMP_24:.*]] = mhlo.add %[[TMP_22]], %[[TMP_23]] - // CHECK: %[[TMP_25:.*]] = mhlo.multiply %[[TMP_24]], %[[TMP_0]] - // CHECK: %[[TMP_26:.*]] = mhlo.constant dense<521.35794978015269> - // CHECK: %[[TMP_27:.*]] = mhlo.add %[[TMP_25]], %[[TMP_26]] - // CHECK: %[[TMP_28:.*]] = mhlo.multiply %[[TMP_27]], %[[TMP_0]] - // CHECK: %[[TMP_29:.*]] = mhlo.constant dense<4594.3238297098014> - // CHECK: %[[TMP_30:.*]] = mhlo.add %[[TMP_28]], %[[TMP_29]] - // CHECK: %[[TMP_31:.*]] = mhlo.multiply %[[TMP_30]], %[[TMP_0]] - // CHECK: %[[TMP_32:.*]] = mhlo.constant dense<22629.000061389095> - // CHECK: %[[TMP_33:.*]] = mhlo.add %[[TMP_31]], %[[TMP_32]] - // CHECK: %[[TMP_34:.*]] = mhlo.multiply %[[TMP_33]], %[[TMP_0]] - // CHECK: %[[TMP_35:.*]] = mhlo.constant dense<49267.394260863592> - // CHECK: %[[TMP_36:.*]] = mhlo.add %[[TMP_34]], %[[TMP_35]] - // CHECK: %[[TMP_37:.*]] = mhlo.divide %[[TMP_17]], %[[TMP_36]] - // CHECK: %[[TMP_38:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_39:.*]] = mhlo.multiply %[[ARG]], %[[ARG]] - // CHECK: %[[TMP_40:.*]] = mhlo.negate %[[TMP_39]] - // CHECK: %[[TMP_41:.*]] = mhlo.exponential %[[TMP_40]] - // CHECK: %[[TMP_42:.*]] = mhlo.abs %[[ARG]] - // CHECK: %[[TMP_45:.*]] = mhlo.constant dense<2.4619698147353052E-10> - // CHECK: %[[TMP_47:.*]] = mhlo.multiply %[[TMP_45]], %[[TMP_42]] - // CHECK: %[[TMP_48:.*]] = mhlo.constant dense<0.56418956483106886> - // CHECK: %[[TMP_49:.*]] = mhlo.add %[[TMP_47]], %[[TMP_48]] - // CHECK: %[[TMP_50:.*]] = mhlo.multiply %[[TMP_49]], %[[TMP_42]] - // CHECK: %[[TMP_51:.*]] = mhlo.constant dense<7.4632105644226989> - // CHECK: %[[TMP_52:.*]] = mhlo.add %[[TMP_50]], %[[TMP_51]] - // CHECK: %[[TMP_53:.*]] = mhlo.multiply %[[TMP_52]], %[[TMP_42]] - // CHECK: %[[TMP_54:.*]] = mhlo.constant dense<48.637197098568137> - // CHECK: %[[TMP_55:.*]] = mhlo.add %[[TMP_53]], %[[TMP_54]] - // CHECK: %[[TMP_56:.*]] = mhlo.multiply %[[TMP_55]], %[[TMP_42]] - // CHECK: %[[TMP_57:.*]] = mhlo.constant dense<196.5208329560771> - // CHECK: %[[TMP_58:.*]] = mhlo.add %[[TMP_56]], %[[TMP_57]] - // CHECK: %[[TMP_59:.*]] = mhlo.multiply %[[TMP_58]], %[[TMP_42]] - // CHECK: %[[TMP_60:.*]] = mhlo.constant dense<526.44519499547732> - // CHECK: %[[TMP_61:.*]] = mhlo.add %[[TMP_59]], %[[TMP_60]] - // CHECK: %[[TMP_62:.*]] = mhlo.multiply %[[TMP_61]], %[[TMP_42]] - // CHECK: %[[TMP_63:.*]] = mhlo.constant dense<934.52852717195765> - // CHECK: %[[TMP_64:.*]] = mhlo.add %[[TMP_62]], %[[TMP_63]] - // CHECK: %[[TMP_65:.*]] = mhlo.multiply %[[TMP_64]], %[[TMP_42]] - // CHECK: %[[TMP_66:.*]] = mhlo.constant dense<1027.5518868951572> - // CHECK: %[[TMP_67:.*]] = mhlo.add %[[TMP_65]], %[[TMP_66]] - // CHECK: %[[TMP_68:.*]] = mhlo.multiply %[[TMP_67]], %[[TMP_42]] - // CHECK: %[[TMP_69:.*]] = mhlo.constant dense<557.53533536939938> - // CHECK: %[[TMP_70:.*]] = mhlo.add %[[TMP_68]], %[[TMP_69]] - // CHECK: %[[TMP_71:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_70]] - // CHECK: %[[TMP_74:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_76:.*]] = mhlo.multiply %[[TMP_74]], %[[TMP_42]] - // CHECK: %[[TMP_77:.*]] = mhlo.constant dense<13.228195115474499> - // CHECK: %[[TMP_78:.*]] = mhlo.add %[[TMP_76]], %[[TMP_77]] - // CHECK: %[[TMP_79:.*]] = mhlo.multiply %[[TMP_78]], %[[TMP_42]] - // CHECK: %[[TMP_80:.*]] = mhlo.constant dense<86.707214088598973> - // CHECK: %[[TMP_81:.*]] = mhlo.add %[[TMP_79]], %[[TMP_80]] - // CHECK: %[[TMP_82:.*]] = mhlo.multiply %[[TMP_81]], %[[TMP_42]] - // CHECK: %[[TMP_83:.*]] = mhlo.constant dense<354.93777888781989> - // CHECK: %[[TMP_84:.*]] = mhlo.add %[[TMP_82]], %[[TMP_83]] - // CHECK: %[[TMP_85:.*]] = mhlo.multiply %[[TMP_84]], %[[TMP_42]] - // CHECK: %[[TMP_86:.*]] = mhlo.constant dense<975.70850174320549> - // CHECK: %[[TMP_87:.*]] = mhlo.add %[[TMP_85]], %[[TMP_86]] - // CHECK: %[[TMP_88:.*]] = mhlo.multiply %[[TMP_87]], %[[TMP_42]] - // CHECK: %[[TMP_89:.*]] = mhlo.constant dense<1823.9091668790973> - // CHECK: %[[TMP_90:.*]] = mhlo.add %[[TMP_88]], %[[TMP_89]] - // CHECK: %[[TMP_91:.*]] = mhlo.multiply %[[TMP_90]], %[[TMP_42]] - // CHECK: %[[TMP_92:.*]] = mhlo.constant dense<2246.3376081871097> - // CHECK: %[[TMP_93:.*]] = mhlo.add %[[TMP_91]], %[[TMP_92]] - // CHECK: %[[TMP_94:.*]] = mhlo.multiply %[[TMP_93]], %[[TMP_42]] - // CHECK: %[[TMP_95:.*]] = mhlo.constant dense<1656.6630919416134> - // CHECK: %[[TMP_96:.*]] = mhlo.add %[[TMP_94]], %[[TMP_95]] - // CHECK: %[[TMP_97:.*]] = mhlo.multiply %[[TMP_96]], %[[TMP_42]] - // CHECK: %[[TMP_98:.*]] = mhlo.constant dense<557.53534081772773> - // CHECK: %[[TMP_99:.*]] = mhlo.add %[[TMP_97]], %[[TMP_98]] - // CHECK: %[[TMP_100:.*]] = mhlo.divide %[[TMP_71]], %[[TMP_99]] - // CHECK: %[[TMP_103:.*]] = mhlo.constant dense<0.56418958354775506> - // CHECK: %[[TMP_105:.*]] = mhlo.multiply %[[TMP_103]], %[[TMP_42]] - // CHECK: %[[TMP_106:.*]] = mhlo.constant dense<1.275366707599781> - // CHECK: %[[TMP_107:.*]] = mhlo.add %[[TMP_105]], %[[TMP_106]] - // CHECK: %[[TMP_108:.*]] = mhlo.multiply %[[TMP_107]], %[[TMP_42]] - // CHECK: %[[TMP_109:.*]] = mhlo.constant dense<5.0190504225118051> - // CHECK: %[[TMP_110:.*]] = mhlo.add %[[TMP_108]], %[[TMP_109]] - // CHECK: %[[TMP_111:.*]] = mhlo.multiply %[[TMP_110]], %[[TMP_42]] - // CHECK: %[[TMP_112:.*]] = mhlo.constant dense<6.160210979930536> - // CHECK: %[[TMP_113:.*]] = mhlo.add %[[TMP_111]], %[[TMP_112]] - // CHECK: %[[TMP_114:.*]] = mhlo.multiply %[[TMP_113]], %[[TMP_42]] - // CHECK: %[[TMP_115:.*]] = mhlo.constant dense<7.4097426995044895> - // CHECK: %[[TMP_116:.*]] = mhlo.add %[[TMP_114]], %[[TMP_115]] - // CHECK: %[[TMP_117:.*]] = mhlo.multiply %[[TMP_116]], %[[TMP_42]] - // CHECK: %[[TMP_118:.*]] = mhlo.constant dense<2.9788666537210022> - // CHECK: %[[TMP_119:.*]] = mhlo.add %[[TMP_117]], %[[TMP_118]] - // CHECK: %[[TMP_120:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_119]] - // CHECK: %[[TMP_123:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_125:.*]] = mhlo.multiply %[[TMP_123]], %[[TMP_42]] - // CHECK: %[[TMP_126:.*]] = mhlo.constant dense<2.2605286322011726> - // CHECK: %[[TMP_127:.*]] = mhlo.add %[[TMP_125]], %[[TMP_126]] - // CHECK: %[[TMP_128:.*]] = mhlo.multiply %[[TMP_127]], %[[TMP_42]] - // CHECK: %[[TMP_129:.*]] = mhlo.constant dense<9.3960352493800147> - // CHECK: %[[TMP_130:.*]] = mhlo.add %[[TMP_128]], %[[TMP_129]] - // CHECK: %[[TMP_131:.*]] = mhlo.multiply %[[TMP_130]], %[[TMP_42]] - // CHECK: %[[TMP_132:.*]] = mhlo.constant dense<12.048953980809666> - // CHECK: %[[TMP_133:.*]] = mhlo.add %[[TMP_131]], %[[TMP_132]] - // CHECK: %[[TMP_134:.*]] = mhlo.multiply %[[TMP_133]], %[[TMP_42]] - // CHECK: %[[TMP_135:.*]] = mhlo.constant dense<17.081445074756591> - // CHECK: %[[TMP_136:.*]] = mhlo.add %[[TMP_134]], %[[TMP_135]] - // CHECK: %[[TMP_137:.*]] = mhlo.multiply %[[TMP_136]], %[[TMP_42]] - // CHECK: %[[TMP_138:.*]] = mhlo.constant dense<9.6089680906328585> - // CHECK: %[[TMP_139:.*]] = mhlo.add %[[TMP_137]], %[[TMP_138]] - // CHECK: %[[TMP_140:.*]] = mhlo.multiply %[[TMP_139]], %[[TMP_42]] - // CHECK: %[[TMP_141:.*]] = mhlo.constant dense<3.3690764510008151> - // CHECK: %[[TMP_142:.*]] = mhlo.add %[[TMP_140]], %[[TMP_141]] - // CHECK: %[[TMP_143:.*]] = mhlo.divide %[[TMP_120]], %[[TMP_142]] - // CHECK: %[[TMP_144:.*]] = mhlo.constant dense<8.000000e+00> - // CHECK: %[[TMP_145:.*]] = mhlo.compare LT, %[[TMP_42]], %[[TMP_144]], NOTYPE - // CHECK: %[[TMP_146:.*]] = mhlo.select %[[TMP_145]], %[[TMP_100]], %[[TMP_143]] - // CHECK: %[[TMP_147:.*]] = mhlo.constant dense<-709.78271289338397> - // CHECK: %[[TMP_148:.*]] = mhlo.compare LT, %[[TMP_40]], %[[TMP_147]], NOTYPE - // CHECK: %[[TMP_149:.*]] = mhlo.constant dense<0.000000e+00> - // CHECK: %[[TMP_150:.*]] = mhlo.select %[[TMP_148]], %[[TMP_149]], %[[TMP_146]] - // CHECK: %[[TMP_152:.*]] = mhlo.compare LT, %[[ARG]], %[[TMP_149]], NOTYPE - // CHECK: %[[TMP_153:.*]] = mhlo.constant dense<2.000000e+00> - // CHECK: %[[TMP_154:.*]] = mhlo.subtract %[[TMP_153]], %[[TMP_150]] - // CHECK: %[[TMP_155:.*]] = mhlo.select %[[TMP_152]], %[[TMP_154]], %[[TMP_150]] - // CHECK: %[[TMP_156:.*]] = mhlo.subtract %[[TMP_38]], %[[TMP_155]] - // CHECK: %[[TMP_157:.*]] = mhlo.abs %[[ARG]] - // CHECK: %[[TMP_159:.*]] = mhlo.compare LT, %[[TMP_157]], %[[TMP_38]], NOTYPE - // CHECK: %[[RESULT:.*]] = mhlo.select %[[TMP_159]], %[[TMP_37]], %[[TMP_156]] + // CHECK: %[[RESULT:.*]] = mhlo.erf %[[ARG]] // CHECK: return %[[RESULT]] %1 = "chlo.erf"(%arg) : (tensor) -> tensor func.return %1 : tensor @@ -413,47 +273,7 @@ func.func @erf_f64(%arg : tensor) -> tensor { // CHECK-LABEL: @erf_f32 // CHECK-SAME: %[[ARG:.*]]: tensor func.func @erf_f32(%arg : tensor) -> tensor { - // CHECK-DAG: %[[TMP_0:.*]] = mhlo.constant dense<-4.000000e+00> - // CHECK-DAG: %[[TMP_1:.*]] = mhlo.constant dense<4.000000e+00> - // CHECK: %[[TMP_2:.*]] = mhlo.clamp %[[TMP_0]], %[[ARG]], %[[TMP_1]] - // CHECK: %[[TMP_3:.*]] = mhlo.multiply %[[TMP_2]], %[[TMP_2]] - // CHECK: %[[TMP_6:.*]] = mhlo.constant dense<-2.72614237E-10> - // CHECK: %[[TMP_8:.*]] = mhlo.multiply %[[TMP_6]], %[[TMP_3]] - // CHECK: %[[TMP_9:.*]] = mhlo.constant dense<2.77068146E-8> - // CHECK: %[[TMP_10:.*]] = mhlo.add %[[TMP_8]], %[[TMP_9]] - // CHECK: %[[TMP_11:.*]] = mhlo.multiply %[[TMP_10]], %[[TMP_3]] - // CHECK: %[[TMP_12:.*]] = mhlo.constant dense<-2.10102394E-6> - // CHECK: %[[TMP_13:.*]] = mhlo.add %[[TMP_11]], %[[TMP_12]] - // CHECK: %[[TMP_14:.*]] = mhlo.multiply %[[TMP_13]], %[[TMP_3]] - // CHECK: %[[TMP_15:.*]] = mhlo.constant dense<-5.69250624E-5> - // CHECK: %[[TMP_16:.*]] = mhlo.add %[[TMP_14]], %[[TMP_15]] - // CHECK: %[[TMP_17:.*]] = mhlo.multiply %[[TMP_16]], %[[TMP_3]] - // CHECK: %[[TMP_18:.*]] = mhlo.constant dense<-7.34990637E-4> - // CHECK: %[[TMP_19:.*]] = mhlo.add %[[TMP_17]], %[[TMP_18]] - // CHECK: %[[TMP_20:.*]] = mhlo.multiply %[[TMP_19]], %[[TMP_3]] - // CHECK: %[[TMP_21:.*]] = mhlo.constant dense<-2.954600e-03> - // CHECK: %[[TMP_22:.*]] = mhlo.add %[[TMP_20]], %[[TMP_21]] - // CHECK: %[[TMP_23:.*]] = mhlo.multiply %[[TMP_22]], %[[TMP_3]] - // CHECK: %[[TMP_24:.*]] = mhlo.constant dense<-0.0160960332> - // CHECK: %[[TMP_25:.*]] = mhlo.add %[[TMP_23]], %[[TMP_24]] - // CHECK: %[[TMP_28:.*]] = mhlo.constant dense<-1.45660715E-5> - // CHECK: %[[TMP_30:.*]] = mhlo.multiply %[[TMP_28]], %[[TMP_3]] - // CHECK: %[[TMP_31:.*]] = mhlo.constant dense<-2.13374049E-4> - // CHECK: %[[TMP_32:.*]] = mhlo.add %[[TMP_30]], %[[TMP_31]] - // CHECK: %[[TMP_33:.*]] = mhlo.multiply %[[TMP_32]], %[[TMP_3]] - // CHECK: %[[TMP_34:.*]] = mhlo.constant dense<-0.00168282702> - // CHECK: %[[TMP_35:.*]] = mhlo.add %[[TMP_33]], %[[TMP_34]] - // CHECK: %[[TMP_36:.*]] = mhlo.multiply %[[TMP_35]], %[[TMP_3]] - // CHECK: %[[TMP_37:.*]] = mhlo.constant dense<-0.00737332925> - // CHECK: %[[TMP_38:.*]] = mhlo.add %[[TMP_36]], %[[TMP_37]] - // CHECK: %[[TMP_39:.*]] = mhlo.multiply %[[TMP_38]], %[[TMP_3]] - // CHECK: %[[TMP_40:.*]] = mhlo.constant dense<-0.0142647391> - // CHECK: %[[TMP_41:.*]] = mhlo.add %[[TMP_39]], %[[TMP_40]] - // CHECK: %[[TMP_42:.*]] = mhlo.multiply %[[TMP_2]], %[[TMP_25]] - // CHECK: %[[TMP_43:.*]] = mhlo.divide %[[TMP_42]], %[[TMP_41]] - // CHECK-DAG: %[[TMP_44:.*]] = mhlo.constant dense<-1.000000e+00> - // CHECK-DAG: %[[TMP_45:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[RESULT:.*]] = mhlo.clamp %[[TMP_44]], %[[TMP_43]], %[[TMP_45]] + // CHECK: %[[RESULT:.*]] = mhlo.erf %[[ARG]] // CHECK: return %[[RESULT]] %1 = "chlo.erf"(%arg) : (tensor) -> tensor func.return %1 : tensor @@ -464,8 +284,7 @@ func.func @erf_f32(%arg : tensor) -> tensor { // CHECK-LABEL: @erf_f16 // CHECK-SAME: %[[ARG:.*]]: tensor func.func @erf_f16(%arg : tensor) -> tensor { - // CHECK: mhlo.convert %[[ARG]] : (tensor) -> tensor - // CHECK: %[[RESULT:.*]] = mhlo.convert %{{.*}} : (tensor) -> tensor + // CHECK: %[[RESULT:.*]] = mhlo.erf %[[ARG]] // CHECK: return %[[RESULT]] %1 = "chlo.erf"(%arg) : (tensor) -> tensor func.return %1 : tensor @@ -476,8 +295,7 @@ func.func @erf_f16(%arg : tensor) -> tensor { // CHECK-LABEL: @erf_bf16 // CHECK-SAME: %[[ARG:.*]]: tensor func.func @erf_bf16(%arg : tensor) -> tensor { - // CHECK: mhlo.convert %[[ARG]] : (tensor) -> tensor - // CHECK: %[[RESULT:.*]] = mhlo.convert %{{.*}} : (tensor) -> tensor + // CHECK: %[[RESULT:.*]] = mhlo.erf %[[ARG]] // CHECK: return %[[RESULT]] %1 = "chlo.erf"(%arg) : (tensor) -> tensor func.return %1 : tensor @@ -2466,15 +2284,7 @@ func.func @tan_f32(%arg : tensor) -> tensor { // CHECK-LABEL: @top_k // CHECK-SAME: (%[[ARG:.*]]: tensor<16x16xf32>) func.func @top_k(%arg : tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) { - // CHECK: %[[IOTA:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} - // CHECK-NEXT: %[[SORT:.*]]:2 = "mhlo.sort"(%[[ARG]], %[[IOTA]]) ({ - // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor, %{{.*}}: tensor, %{{.*}}: tensor): - // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare GT, %[[LHS]], %[[RHS]], TOTALORDER - // CHECK-NEXT: mhlo.return %[[CMP]] - // CHECK-NEXT: }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) - // CHECK-NEXT: %[[VAL:.*]] = "mhlo.slice"(%[[SORT]]#0) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} - // CHECK-NEXT: %[[IDX:.*]] = "mhlo.slice"(%[[SORT]]#1) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} - // CHECK-NEXT: return %[[VAL]], %[[IDX]] + // CHECK: %values, %indices = mhlo.topk(%arg0, k = 8, largest = true) : tensor<16x16xf32> -> (tensor<16x8xf32>, tensor<16x8xi32>) %1:2 = chlo.top_k(%arg, k=8) : tensor<16x16xf32> -> (tensor<16x8xf32>, tensor<16x8xi32>) func.return %1#0, %1#1 : tensor<16x8xf32>, tensor<16x8xi32> } @@ -2485,28 +2295,7 @@ func.func @top_k(%arg : tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32 // CHECK-SAME: ([[ARG:%.*]]: tensor // CHECK-SAME: -> (tensor, tensor) func.func @dyn_top_k(%arg0: tensor) -> (tensor, tensor) { - // CHECK-NEXT: [[DIM_0_I32:%.*]] = "mhlo.get_dimension_size"([[ARG]]) {dimension = 0 : i64} : (tensor) -> tensor - // CHECK-NEXT: [[DIM_0_I32x1:%.*]] = mhlo.reshape [[DIM_0_I32]] : (tensor) -> tensor<1xi32> - // CHECK-NEXT: [[DIM_1_I32:%.*]] = "mhlo.get_dimension_size"([[ARG]]) {dimension = 1 : i64} : (tensor) -> tensor - // CHECK-NEXT: [[DIM_1_I32x1:%.*]] = mhlo.reshape [[DIM_1_I32]] : (tensor) -> tensor<1xi32> - // CHECK-NEXT: [[DIM_2_I32:%.*]] = "mhlo.get_dimension_size"([[ARG]]) {dimension = 2 : i64} : (tensor) -> tensor - // CHECK-NEXT: [[DIM_2_I32x1:%.*]] = mhlo.reshape [[DIM_2_I32]] : (tensor) -> tensor<1xi32> - // CHECK-NEXT: [[IOTA_SHAPE:%.*]] = "mhlo.concatenate"([[DIM_0_I32x1]], [[DIM_1_I32x1]], [[DIM_2_I32x1]]) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> - // CHECK-NEXT: [[K_I32:%.*]] = mhlo.constant dense<2> : tensor - // CHECK-NEXT: [[K_I32x1:%.*]] = mhlo.reshape [[K_I32]] : (tensor) -> tensor<1xi32> - // CHECK-NEXT: [[RESULT_SHAPE:%.*]] = "mhlo.concatenate"([[DIM_0_I32x1]], [[DIM_1_I32x1]], [[K_I32x1]]) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> - // CHECK-NEXT: [[IOTA:%.*]] = "mhlo.dynamic_iota"([[IOTA_SHAPE]]) {iota_dimension = 2 : i64} : (tensor<3xi32>) -> tensor - // CHECK-NEXT: [[SORT:%.*]]:2 = "mhlo.sort"([[ARG]], [[IOTA]]) ({ - // CHECK-NEXT: ^bb0([[ARG_1:%.*]]: tensor, [[ARG_2:%.*]]: tensor, [[ARG_3:%.*]]: tensor, [[ARG_4:%.*]]: tensor): - // CHECK-NEXT: [[CMP:%.*]] = mhlo.compare GT, [[ARG_1]], [[ARG_2]], NOTYPE : (tensor, tensor) -> tensor - // CHECK-NEXT: mhlo.return [[CMP]] : tensor - // CHECK-NEXT: }) {dimension = 2 : i64, is_stable = true} : (tensor, tensor) -> (tensor, tensor) - // CHECK-NEXT: [[STARTS:%.*]] = mhlo.constant dense<0> : tensor<3xi64> - // CHECK-NEXT: [[LIMITS:%.*]] = mhlo.convert [[RESULT_SHAPE]] : (tensor<3xi32>) -> tensor<3xi64> - // CHECK-NEXT: [[STRIDES:%.*]] = mhlo.constant dense<1> : tensor<3xi64> - // CHECK-NEXT: [[VAL:%.*]] = mhlo.real_dynamic_slice [[SORT]]#0, [[STARTS]], [[LIMITS]], [[STRIDES]] : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor - // CHECK-NEXT: [[IDX:%.*]] = mhlo.real_dynamic_slice [[SORT]]#1, [[STARTS]], [[LIMITS]], [[STRIDES]] : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor - // CHECK-NEXT: return [[VAL]], [[IDX]] : tensor, tensor + // CHECK: %values, %indices = mhlo.topk(%arg0, k = 2, largest = true) : tensor -> (tensor, tensor) %values, %indices = chlo.top_k(%arg0, k = 2) : tensor -> (tensor, tensor) return %values, %indices : tensor, tensor } diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo_basis_ops.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo_basis_ops.mlir new file mode 100644 index 00000000000000..0a9b7eef41657e --- /dev/null +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo_basis_ops.mlir @@ -0,0 +1,276 @@ +// RUN: mlir-hlo-opt --chlo-legalize-to-hlo-basis-ops --chlo-legalize-to-hlo --split-input-file -verify-diagnostics %s | FileCheck %s + +// ----- + +// CHECK-LABEL: @erf_f64 +// CHECK-SAME: %[[ARG:.*]]: tensor +func.func @erf_f64(%arg : tensor) -> tensor { + // CHECK: %[[TMP_0:.*]] = mhlo.multiply %[[ARG]], %[[ARG]] + // CHECK: %[[TMP_3:.*]] = mhlo.constant dense<9.6049737398705161> + // CHECK: %[[TMP_5:.*]] = mhlo.multiply %[[TMP_3]], %[[TMP_0]] + // CHECK: %[[TMP_6:.*]] = mhlo.constant dense<90.026019720384269> + // CHECK: %[[TMP_7:.*]] = mhlo.add %[[TMP_5]], %[[TMP_6]] + // CHECK: %[[TMP_8:.*]] = mhlo.multiply %[[TMP_7]], %[[TMP_0]] + // CHECK: %[[TMP_9:.*]] = mhlo.constant dense<2232.0053459468431> + // CHECK: %[[TMP_10:.*]] = mhlo.add %[[TMP_8]], %[[TMP_9]] + // CHECK: %[[TMP_11:.*]] = mhlo.multiply %[[TMP_10]], %[[TMP_0]] + // CHECK: %[[TMP_12:.*]] = mhlo.constant dense<7003.3251411280507> + // CHECK: %[[TMP_13:.*]] = mhlo.add %[[TMP_11]], %[[TMP_12]] + // CHECK: %[[TMP_14:.*]] = mhlo.multiply %[[TMP_13]], %[[TMP_0]] + // CHECK: %[[TMP_15:.*]] = mhlo.constant dense<55592.301301039493> + // CHECK: %[[TMP_16:.*]] = mhlo.add %[[TMP_14]], %[[TMP_15]] + // CHECK: %[[TMP_17:.*]] = mhlo.multiply %[[ARG]], %[[TMP_16]] + // CHECK: %[[TMP_20:.*]] = mhlo.constant dense<1.000000e+00> + // CHECK: %[[TMP_22:.*]] = mhlo.multiply %[[TMP_20]], %[[TMP_0]] + // CHECK: %[[TMP_23:.*]] = mhlo.constant dense<33.561714164750313> + // CHECK: %[[TMP_24:.*]] = mhlo.add %[[TMP_22]], %[[TMP_23]] + // CHECK: %[[TMP_25:.*]] = mhlo.multiply %[[TMP_24]], %[[TMP_0]] + // CHECK: %[[TMP_26:.*]] = mhlo.constant dense<521.35794978015269> + // CHECK: %[[TMP_27:.*]] = mhlo.add %[[TMP_25]], %[[TMP_26]] + // CHECK: %[[TMP_28:.*]] = mhlo.multiply %[[TMP_27]], %[[TMP_0]] + // CHECK: %[[TMP_29:.*]] = mhlo.constant dense<4594.3238297098014> + // CHECK: %[[TMP_30:.*]] = mhlo.add %[[TMP_28]], %[[TMP_29]] + // CHECK: %[[TMP_31:.*]] = mhlo.multiply %[[TMP_30]], %[[TMP_0]] + // CHECK: %[[TMP_32:.*]] = mhlo.constant dense<22629.000061389095> + // CHECK: %[[TMP_33:.*]] = mhlo.add %[[TMP_31]], %[[TMP_32]] + // CHECK: %[[TMP_34:.*]] = mhlo.multiply %[[TMP_33]], %[[TMP_0]] + // CHECK: %[[TMP_35:.*]] = mhlo.constant dense<49267.394260863592> + // CHECK: %[[TMP_36:.*]] = mhlo.add %[[TMP_34]], %[[TMP_35]] + // CHECK: %[[TMP_37:.*]] = mhlo.divide %[[TMP_17]], %[[TMP_36]] + // CHECK: %[[TMP_38:.*]] = mhlo.constant dense<1.000000e+00> + // CHECK: %[[TMP_39:.*]] = mhlo.multiply %[[ARG]], %[[ARG]] + // CHECK: %[[TMP_40:.*]] = mhlo.negate %[[TMP_39]] + // CHECK: %[[TMP_41:.*]] = mhlo.exponential %[[TMP_40]] + // CHECK: %[[TMP_42:.*]] = mhlo.abs %[[ARG]] + // CHECK: %[[TMP_45:.*]] = mhlo.constant dense<2.4619698147353052E-10> + // CHECK: %[[TMP_47:.*]] = mhlo.multiply %[[TMP_45]], %[[TMP_42]] + // CHECK: %[[TMP_48:.*]] = mhlo.constant dense<0.56418956483106886> + // CHECK: %[[TMP_49:.*]] = mhlo.add %[[TMP_47]], %[[TMP_48]] + // CHECK: %[[TMP_50:.*]] = mhlo.multiply %[[TMP_49]], %[[TMP_42]] + // CHECK: %[[TMP_51:.*]] = mhlo.constant dense<7.4632105644226989> + // CHECK: %[[TMP_52:.*]] = mhlo.add %[[TMP_50]], %[[TMP_51]] + // CHECK: %[[TMP_53:.*]] = mhlo.multiply %[[TMP_52]], %[[TMP_42]] + // CHECK: %[[TMP_54:.*]] = mhlo.constant dense<48.637197098568137> + // CHECK: %[[TMP_55:.*]] = mhlo.add %[[TMP_53]], %[[TMP_54]] + // CHECK: %[[TMP_56:.*]] = mhlo.multiply %[[TMP_55]], %[[TMP_42]] + // CHECK: %[[TMP_57:.*]] = mhlo.constant dense<196.5208329560771> + // CHECK: %[[TMP_58:.*]] = mhlo.add %[[TMP_56]], %[[TMP_57]] + // CHECK: %[[TMP_59:.*]] = mhlo.multiply %[[TMP_58]], %[[TMP_42]] + // CHECK: %[[TMP_60:.*]] = mhlo.constant dense<526.44519499547732> + // CHECK: %[[TMP_61:.*]] = mhlo.add %[[TMP_59]], %[[TMP_60]] + // CHECK: %[[TMP_62:.*]] = mhlo.multiply %[[TMP_61]], %[[TMP_42]] + // CHECK: %[[TMP_63:.*]] = mhlo.constant dense<934.52852717195765> + // CHECK: %[[TMP_64:.*]] = mhlo.add %[[TMP_62]], %[[TMP_63]] + // CHECK: %[[TMP_65:.*]] = mhlo.multiply %[[TMP_64]], %[[TMP_42]] + // CHECK: %[[TMP_66:.*]] = mhlo.constant dense<1027.5518868951572> + // CHECK: %[[TMP_67:.*]] = mhlo.add %[[TMP_65]], %[[TMP_66]] + // CHECK: %[[TMP_68:.*]] = mhlo.multiply %[[TMP_67]], %[[TMP_42]] + // CHECK: %[[TMP_69:.*]] = mhlo.constant dense<557.53533536939938> + // CHECK: %[[TMP_70:.*]] = mhlo.add %[[TMP_68]], %[[TMP_69]] + // CHECK: %[[TMP_71:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_70]] + // CHECK: %[[TMP_74:.*]] = mhlo.constant dense<1.000000e+00> + // CHECK: %[[TMP_76:.*]] = mhlo.multiply %[[TMP_74]], %[[TMP_42]] + // CHECK: %[[TMP_77:.*]] = mhlo.constant dense<13.228195115474499> + // CHECK: %[[TMP_78:.*]] = mhlo.add %[[TMP_76]], %[[TMP_77]] + // CHECK: %[[TMP_79:.*]] = mhlo.multiply %[[TMP_78]], %[[TMP_42]] + // CHECK: %[[TMP_80:.*]] = mhlo.constant dense<86.707214088598973> + // CHECK: %[[TMP_81:.*]] = mhlo.add %[[TMP_79]], %[[TMP_80]] + // CHECK: %[[TMP_82:.*]] = mhlo.multiply %[[TMP_81]], %[[TMP_42]] + // CHECK: %[[TMP_83:.*]] = mhlo.constant dense<354.93777888781989> + // CHECK: %[[TMP_84:.*]] = mhlo.add %[[TMP_82]], %[[TMP_83]] + // CHECK: %[[TMP_85:.*]] = mhlo.multiply %[[TMP_84]], %[[TMP_42]] + // CHECK: %[[TMP_86:.*]] = mhlo.constant dense<975.70850174320549> + // CHECK: %[[TMP_87:.*]] = mhlo.add %[[TMP_85]], %[[TMP_86]] + // CHECK: %[[TMP_88:.*]] = mhlo.multiply %[[TMP_87]], %[[TMP_42]] + // CHECK: %[[TMP_89:.*]] = mhlo.constant dense<1823.9091668790973> + // CHECK: %[[TMP_90:.*]] = mhlo.add %[[TMP_88]], %[[TMP_89]] + // CHECK: %[[TMP_91:.*]] = mhlo.multiply %[[TMP_90]], %[[TMP_42]] + // CHECK: %[[TMP_92:.*]] = mhlo.constant dense<2246.3376081871097> + // CHECK: %[[TMP_93:.*]] = mhlo.add %[[TMP_91]], %[[TMP_92]] + // CHECK: %[[TMP_94:.*]] = mhlo.multiply %[[TMP_93]], %[[TMP_42]] + // CHECK: %[[TMP_95:.*]] = mhlo.constant dense<1656.6630919416134> + // CHECK: %[[TMP_96:.*]] = mhlo.add %[[TMP_94]], %[[TMP_95]] + // CHECK: %[[TMP_97:.*]] = mhlo.multiply %[[TMP_96]], %[[TMP_42]] + // CHECK: %[[TMP_98:.*]] = mhlo.constant dense<557.53534081772773> + // CHECK: %[[TMP_99:.*]] = mhlo.add %[[TMP_97]], %[[TMP_98]] + // CHECK: %[[TMP_100:.*]] = mhlo.divide %[[TMP_71]], %[[TMP_99]] + // CHECK: %[[TMP_103:.*]] = mhlo.constant dense<0.56418958354775506> + // CHECK: %[[TMP_105:.*]] = mhlo.multiply %[[TMP_103]], %[[TMP_42]] + // CHECK: %[[TMP_106:.*]] = mhlo.constant dense<1.275366707599781> + // CHECK: %[[TMP_107:.*]] = mhlo.add %[[TMP_105]], %[[TMP_106]] + // CHECK: %[[TMP_108:.*]] = mhlo.multiply %[[TMP_107]], %[[TMP_42]] + // CHECK: %[[TMP_109:.*]] = mhlo.constant dense<5.0190504225118051> + // CHECK: %[[TMP_110:.*]] = mhlo.add %[[TMP_108]], %[[TMP_109]] + // CHECK: %[[TMP_111:.*]] = mhlo.multiply %[[TMP_110]], %[[TMP_42]] + // CHECK: %[[TMP_112:.*]] = mhlo.constant dense<6.160210979930536> + // CHECK: %[[TMP_113:.*]] = mhlo.add %[[TMP_111]], %[[TMP_112]] + // CHECK: %[[TMP_114:.*]] = mhlo.multiply %[[TMP_113]], %[[TMP_42]] + // CHECK: %[[TMP_115:.*]] = mhlo.constant dense<7.4097426995044895> + // CHECK: %[[TMP_116:.*]] = mhlo.add %[[TMP_114]], %[[TMP_115]] + // CHECK: %[[TMP_117:.*]] = mhlo.multiply %[[TMP_116]], %[[TMP_42]] + // CHECK: %[[TMP_118:.*]] = mhlo.constant dense<2.9788666537210022> + // CHECK: %[[TMP_119:.*]] = mhlo.add %[[TMP_117]], %[[TMP_118]] + // CHECK: %[[TMP_120:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_119]] + // CHECK: %[[TMP_123:.*]] = mhlo.constant dense<1.000000e+00> + // CHECK: %[[TMP_125:.*]] = mhlo.multiply %[[TMP_123]], %[[TMP_42]] + // CHECK: %[[TMP_126:.*]] = mhlo.constant dense<2.2605286322011726> + // CHECK: %[[TMP_127:.*]] = mhlo.add %[[TMP_125]], %[[TMP_126]] + // CHECK: %[[TMP_128:.*]] = mhlo.multiply %[[TMP_127]], %[[TMP_42]] + // CHECK: %[[TMP_129:.*]] = mhlo.constant dense<9.3960352493800147> + // CHECK: %[[TMP_130:.*]] = mhlo.add %[[TMP_128]], %[[TMP_129]] + // CHECK: %[[TMP_131:.*]] = mhlo.multiply %[[TMP_130]], %[[TMP_42]] + // CHECK: %[[TMP_132:.*]] = mhlo.constant dense<12.048953980809666> + // CHECK: %[[TMP_133:.*]] = mhlo.add %[[TMP_131]], %[[TMP_132]] + // CHECK: %[[TMP_134:.*]] = mhlo.multiply %[[TMP_133]], %[[TMP_42]] + // CHECK: %[[TMP_135:.*]] = mhlo.constant dense<17.081445074756591> + // CHECK: %[[TMP_136:.*]] = mhlo.add %[[TMP_134]], %[[TMP_135]] + // CHECK: %[[TMP_137:.*]] = mhlo.multiply %[[TMP_136]], %[[TMP_42]] + // CHECK: %[[TMP_138:.*]] = mhlo.constant dense<9.6089680906328585> + // CHECK: %[[TMP_139:.*]] = mhlo.add %[[TMP_137]], %[[TMP_138]] + // CHECK: %[[TMP_140:.*]] = mhlo.multiply %[[TMP_139]], %[[TMP_42]] + // CHECK: %[[TMP_141:.*]] = mhlo.constant dense<3.3690764510008151> + // CHECK: %[[TMP_142:.*]] = mhlo.add %[[TMP_140]], %[[TMP_141]] + // CHECK: %[[TMP_143:.*]] = mhlo.divide %[[TMP_120]], %[[TMP_142]] + // CHECK: %[[TMP_144:.*]] = mhlo.constant dense<8.000000e+00> + // CHECK: %[[TMP_145:.*]] = mhlo.compare LT, %[[TMP_42]], %[[TMP_144]], NOTYPE + // CHECK: %[[TMP_146:.*]] = mhlo.select %[[TMP_145]], %[[TMP_100]], %[[TMP_143]] + // CHECK: %[[TMP_147:.*]] = mhlo.constant dense<-709.78271289338397> + // CHECK: %[[TMP_148:.*]] = mhlo.compare LT, %[[TMP_40]], %[[TMP_147]], NOTYPE + // CHECK: %[[TMP_149:.*]] = mhlo.constant dense<0.000000e+00> + // CHECK: %[[TMP_150:.*]] = mhlo.select %[[TMP_148]], %[[TMP_149]], %[[TMP_146]] + // CHECK: %[[TMP_152:.*]] = mhlo.compare LT, %[[ARG]], %[[TMP_149]], NOTYPE + // CHECK: %[[TMP_153:.*]] = mhlo.constant dense<2.000000e+00> + // CHECK: %[[TMP_154:.*]] = mhlo.subtract %[[TMP_153]], %[[TMP_150]] + // CHECK: %[[TMP_155:.*]] = mhlo.select %[[TMP_152]], %[[TMP_154]], %[[TMP_150]] + // CHECK: %[[TMP_156:.*]] = mhlo.subtract %[[TMP_38]], %[[TMP_155]] + // CHECK: %[[TMP_157:.*]] = mhlo.abs %[[ARG]] + // CHECK: %[[TMP_159:.*]] = mhlo.compare LT, %[[TMP_157]], %[[TMP_38]], NOTYPE + // CHECK: %[[RESULT:.*]] = mhlo.select %[[TMP_159]], %[[TMP_37]], %[[TMP_156]] + // CHECK: return %[[RESULT]] + %1 = "chlo.erf"(%arg) : (tensor) -> tensor + func.return %1 : tensor +} + +// ----- + +// CHECK-LABEL: @erf_f32 +// CHECK-SAME: %[[ARG:.*]]: tensor +func.func @erf_f32(%arg : tensor) -> tensor { + // CHECK-DAG: %[[TMP_0:.*]] = mhlo.constant dense<-4.000000e+00> + // CHECK-DAG: %[[TMP_1:.*]] = mhlo.constant dense<4.000000e+00> + // CHECK: %[[TMP_2:.*]] = mhlo.clamp %[[TMP_0]], %[[ARG]], %[[TMP_1]] + // CHECK: %[[TMP_3:.*]] = mhlo.multiply %[[TMP_2]], %[[TMP_2]] + // CHECK: %[[TMP_6:.*]] = mhlo.constant dense<-2.72614237E-10> + // CHECK: %[[TMP_8:.*]] = mhlo.multiply %[[TMP_6]], %[[TMP_3]] + // CHECK: %[[TMP_9:.*]] = mhlo.constant dense<2.77068146E-8> + // CHECK: %[[TMP_10:.*]] = mhlo.add %[[TMP_8]], %[[TMP_9]] + // CHECK: %[[TMP_11:.*]] = mhlo.multiply %[[TMP_10]], %[[TMP_3]] + // CHECK: %[[TMP_12:.*]] = mhlo.constant dense<-2.10102394E-6> + // CHECK: %[[TMP_13:.*]] = mhlo.add %[[TMP_11]], %[[TMP_12]] + // CHECK: %[[TMP_14:.*]] = mhlo.multiply %[[TMP_13]], %[[TMP_3]] + // CHECK: %[[TMP_15:.*]] = mhlo.constant dense<-5.69250624E-5> + // CHECK: %[[TMP_16:.*]] = mhlo.add %[[TMP_14]], %[[TMP_15]] + // CHECK: %[[TMP_17:.*]] = mhlo.multiply %[[TMP_16]], %[[TMP_3]] + // CHECK: %[[TMP_18:.*]] = mhlo.constant dense<-7.34990637E-4> + // CHECK: %[[TMP_19:.*]] = mhlo.add %[[TMP_17]], %[[TMP_18]] + // CHECK: %[[TMP_20:.*]] = mhlo.multiply %[[TMP_19]], %[[TMP_3]] + // CHECK: %[[TMP_21:.*]] = mhlo.constant dense<-2.954600e-03> + // CHECK: %[[TMP_22:.*]] = mhlo.add %[[TMP_20]], %[[TMP_21]] + // CHECK: %[[TMP_23:.*]] = mhlo.multiply %[[TMP_22]], %[[TMP_3]] + // CHECK: %[[TMP_24:.*]] = mhlo.constant dense<-0.0160960332> + // CHECK: %[[TMP_25:.*]] = mhlo.add %[[TMP_23]], %[[TMP_24]] + // CHECK: %[[TMP_28:.*]] = mhlo.constant dense<-1.45660715E-5> + // CHECK: %[[TMP_30:.*]] = mhlo.multiply %[[TMP_28]], %[[TMP_3]] + // CHECK: %[[TMP_31:.*]] = mhlo.constant dense<-2.13374049E-4> + // CHECK: %[[TMP_32:.*]] = mhlo.add %[[TMP_30]], %[[TMP_31]] + // CHECK: %[[TMP_33:.*]] = mhlo.multiply %[[TMP_32]], %[[TMP_3]] + // CHECK: %[[TMP_34:.*]] = mhlo.constant dense<-0.00168282702> + // CHECK: %[[TMP_35:.*]] = mhlo.add %[[TMP_33]], %[[TMP_34]] + // CHECK: %[[TMP_36:.*]] = mhlo.multiply %[[TMP_35]], %[[TMP_3]] + // CHECK: %[[TMP_37:.*]] = mhlo.constant dense<-0.00737332925> + // CHECK: %[[TMP_38:.*]] = mhlo.add %[[TMP_36]], %[[TMP_37]] + // CHECK: %[[TMP_39:.*]] = mhlo.multiply %[[TMP_38]], %[[TMP_3]] + // CHECK: %[[TMP_40:.*]] = mhlo.constant dense<-0.0142647391> + // CHECK: %[[TMP_41:.*]] = mhlo.add %[[TMP_39]], %[[TMP_40]] + // CHECK: %[[TMP_42:.*]] = mhlo.multiply %[[TMP_2]], %[[TMP_25]] + // CHECK: %[[TMP_43:.*]] = mhlo.divide %[[TMP_42]], %[[TMP_41]] + // CHECK-DAG: %[[TMP_44:.*]] = mhlo.constant dense<-1.000000e+00> + // CHECK-DAG: %[[TMP_45:.*]] = mhlo.constant dense<1.000000e+00> + // CHECK: %[[RESULT:.*]] = mhlo.clamp %[[TMP_44]], %[[TMP_43]], %[[TMP_45]] + // CHECK: return %[[RESULT]] + %1 = "chlo.erf"(%arg) : (tensor) -> tensor + func.return %1 : tensor +} + +// ----- + +// CHECK-LABEL: @erf_f16 +// CHECK-SAME: %[[ARG:.*]]: tensor +func.func @erf_f16(%arg : tensor) -> tensor { + // CHECK: mhlo.convert %[[ARG]] : (tensor) -> tensor + // CHECK: %[[RESULT:.*]] = mhlo.convert %{{.*}} : (tensor) -> tensor + // CHECK: return %[[RESULT]] + %1 = "chlo.erf"(%arg) : (tensor) -> tensor + func.return %1 : tensor +} + +// ----- + +// CHECK-LABEL: @erf_bf16 +// CHECK-SAME: %[[ARG:.*]]: tensor +func.func @erf_bf16(%arg : tensor) -> tensor { + // CHECK: mhlo.convert %[[ARG]] : (tensor) -> tensor + // CHECK: %[[RESULT:.*]] = mhlo.convert %{{.*}} : (tensor) -> tensor + // CHECK: return %[[RESULT]] + %1 = "chlo.erf"(%arg) : (tensor) -> tensor + func.return %1 : tensor +} + + +// CHECK-LABEL: @top_k +// CHECK-SAME: (%[[ARG:.*]]: tensor<16x16xf32>) +func.func @top_k(%arg : tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) { + // CHECK: %[[IOTA:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} + // CHECK-NEXT: %[[SORT:.*]]:2 = "mhlo.sort"(%[[ARG]], %[[IOTA]]) ({ + // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor, %{{.*}}: tensor, %{{.*}}: tensor): + // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare GT, %[[LHS]], %[[RHS]], TOTALORDER + // CHECK-NEXT: mhlo.return %[[CMP]] + // CHECK-NEXT: }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) + // CHECK-NEXT: %[[VAL:.*]] = "mhlo.slice"(%[[SORT]]#0) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + // CHECK-NEXT: %[[IDX:.*]] = "mhlo.slice"(%[[SORT]]#1) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + // CHECK-NEXT: return %[[VAL]], %[[IDX]] + %1:2 = chlo.top_k(%arg, k=8) : tensor<16x16xf32> -> (tensor<16x8xf32>, tensor<16x8xi32>) + func.return %1#0, %1#1 : tensor<16x8xf32>, tensor<16x8xi32> +} + +// ----- + +// CHECK-LABEL: @dyn_top_k +// CHECK-SAME: ([[ARG:%.*]]: tensor +// CHECK-SAME: -> (tensor, tensor) +func.func @dyn_top_k(%arg0: tensor) -> (tensor, tensor) { + // CHECK-NEXT: [[DIM_0_I32:%.*]] = "mhlo.get_dimension_size"([[ARG]]) {dimension = 0 : i64} : (tensor) -> tensor + // CHECK-NEXT: [[DIM_0_I32x1:%.*]] = mhlo.reshape [[DIM_0_I32]] : (tensor) -> tensor<1xi32> + // CHECK-NEXT: [[DIM_1_I32:%.*]] = "mhlo.get_dimension_size"([[ARG]]) {dimension = 1 : i64} : (tensor) -> tensor + // CHECK-NEXT: [[DIM_1_I32x1:%.*]] = mhlo.reshape [[DIM_1_I32]] : (tensor) -> tensor<1xi32> + // CHECK-NEXT: [[DIM_2_I32:%.*]] = "mhlo.get_dimension_size"([[ARG]]) {dimension = 2 : i64} : (tensor) -> tensor + // CHECK-NEXT: [[DIM_2_I32x1:%.*]] = mhlo.reshape [[DIM_2_I32]] : (tensor) -> tensor<1xi32> + // CHECK-NEXT: [[IOTA_SHAPE:%.*]] = "mhlo.concatenate"([[DIM_0_I32x1]], [[DIM_1_I32x1]], [[DIM_2_I32x1]]) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> + // CHECK-NEXT: [[K_I32:%.*]] = mhlo.constant dense<2> : tensor + // CHECK-NEXT: [[K_I32x1:%.*]] = mhlo.reshape [[K_I32]] : (tensor) -> tensor<1xi32> + // CHECK-NEXT: [[RESULT_SHAPE:%.*]] = "mhlo.concatenate"([[DIM_0_I32x1]], [[DIM_1_I32x1]], [[K_I32x1]]) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> + // CHECK-NEXT: [[IOTA:%.*]] = "mhlo.dynamic_iota"([[IOTA_SHAPE]]) {iota_dimension = 2 : i64} : (tensor<3xi32>) -> tensor + // CHECK-NEXT: [[SORT:%.*]]:2 = "mhlo.sort"([[ARG]], [[IOTA]]) ({ + // CHECK-NEXT: ^bb0([[ARG_1:%.*]]: tensor, [[ARG_2:%.*]]: tensor, [[ARG_3:%.*]]: tensor, [[ARG_4:%.*]]: tensor): + // CHECK-NEXT: [[CMP:%.*]] = mhlo.compare GT, [[ARG_1]], [[ARG_2]], NOTYPE : (tensor, tensor) -> tensor + // CHECK-NEXT: mhlo.return [[CMP]] : tensor + // CHECK-NEXT: }) {dimension = 2 : i64, is_stable = true} : (tensor, tensor) -> (tensor, tensor) + // CHECK-NEXT: [[STARTS:%.*]] = mhlo.constant dense<0> : tensor<3xi64> + // CHECK-NEXT: [[LIMITS:%.*]] = mhlo.convert [[RESULT_SHAPE]] : (tensor<3xi32>) -> tensor<3xi64> + // CHECK-NEXT: [[STRIDES:%.*]] = mhlo.constant dense<1> : tensor<3xi64> + // CHECK-NEXT: [[VAL:%.*]] = mhlo.real_dynamic_slice [[SORT]]#0, [[STARTS]], [[LIMITS]], [[STRIDES]] : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor + // CHECK-NEXT: [[IDX:%.*]] = mhlo.real_dynamic_slice [[SORT]]#1, [[STARTS]], [[LIMITS]], [[STRIDES]] : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor + // CHECK-NEXT: return [[VAL]], [[IDX]] : tensor, tensor + %values, %indices = chlo.top_k(%arg0, k = 2) : tensor -> (tensor, tensor) + return %values, %indices : tensor, tensor +} diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/convert-mhlo-quant-to-int.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/convert-mhlo-quant-to-int.mlir new file mode 100644 index 00000000000000..cb19c112bfe05d --- /dev/null +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/convert-mhlo-quant-to-int.mlir @@ -0,0 +1,2132 @@ +// RUN: mlir-hlo-opt --mhlo-quant-legalize-to-int -split-input-file %s -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func @uniform_quantize_and_dequantize +func.func @uniform_quantize_and_dequantize(%arg0: tensor) -> tensor { + // CHECK-DAG: %[[SCALES:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK-DAG: %[[ZPS:.*]] = mhlo.constant dense<3.000000e+00> : tensor + // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-1.280000e+02> : tensor + // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<1.270000e+02> : tensor + // CHECK: %[[VAL0:.*]] = chlo.broadcast_divide %arg0, %[[SCALES]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL1:.*]] = chlo.broadcast_add %[[VAL0]], %[[ZPS]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL2:.*]] = mhlo.clamp %[[QUANT_MIN]], %[[VAL1]], %[[QUANT_MAX]] : (tensor, tensor, tensor) -> tensor + // CHECK: %[[VAL3:.*]] = mhlo.round_nearest_even %[[VAL2]] : tensor + // CHECK: %[[VAL4:.*]] = mhlo.convert %[[VAL3]] : (tensor) -> tensor + %0 = mhlo.uniform_quantize %arg0 : (tensor) -> tensor> + + // CHECK-DAG: %[[SCALES_DQ:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK-DAG: %[[ZPS_DQ:.*]] = mhlo.constant dense<3> : tensor + // CHECK: %[[VAL5:.*]] = mhlo.convert %[[VAL4]] : (tensor) -> tensor + // CHECK: %[[VAL6:.*]] = chlo.broadcast_subtract %[[VAL5]], %[[ZPS_DQ]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL7:.*]] = mhlo.convert %[[VAL6]] : (tensor) -> tensor + // CHECK: %[[VAL8:.*]] = chlo.broadcast_multiply %[[VAL7]], %[[SCALES_DQ]] : (tensor, tensor) -> tensor + // CHECK: return %[[VAL8]] : tensor + %1 = mhlo.uniform_dequantize %0 : (tensor>) -> tensor + return %1 : tensor +} + +// ----- + +// CHECK-LABEL: func @uniform_quantize_convert_dequantize +func.func @uniform_quantize_convert_dequantize(%arg0: tensor) -> tensor { + // CHECK-DAG: %[[SCALES:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK-DAG: %[[ZPS:.*]] = mhlo.constant dense<3.000000e+00> : tensor + // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-1.280000e+02> : tensor + // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<1.270000e+02> : tensor + // CHECK: %[[VAL0:.*]] = chlo.broadcast_divide %arg0, %[[SCALES]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL1:.*]] = chlo.broadcast_add %[[VAL0]], %[[ZPS]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL2:.*]] = mhlo.clamp %[[QUANT_MIN]], %[[VAL1]], %[[QUANT_MAX]] : (tensor, tensor, tensor) -> tensor + // CHECK: %[[VAL3:.*]] = mhlo.round_nearest_even %[[VAL2]] : tensor + // CHECK: %[[VAL4:.*]] = mhlo.convert %[[VAL3]] : (tensor) -> tensor + %0 = mhlo.uniform_quantize %arg0 : (tensor) -> tensor> + + // CHECK: %[[VAL5:.*]] = mhlo.bitcast_convert %[[VAL4]] : (tensor) -> tensor + %1 = mhlo.bitcast_convert %0 : (tensor>) -> tensor + + // CHECK: %[[VAL6:.*]] = mhlo.bitcast_convert %[[VAL5]] : (tensor) -> tensor + %2 = mhlo.bitcast_convert %1 : (tensor) -> tensor> + + // CHECK-DAG: %[[SCALES_DQ:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK-DAG: %[[ZPS_DQ:.*]] = mhlo.constant dense<3> : tensor + // CHECK: %[[VAL7:.*]] = mhlo.convert %[[VAL6]] : (tensor) -> tensor + // CHECK: %[[VAL8:.*]] = chlo.broadcast_subtract %[[VAL7]], %[[ZPS_DQ]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL9:.*]] = mhlo.convert %[[VAL8]] : (tensor) -> tensor + // CHECK: %[[VAL10:.*]] = chlo.broadcast_multiply %[[VAL9]], %[[SCALES_DQ]] : (tensor, tensor) -> tensor + // CHECK: return %[[VAL10]] : tensor + %3 = mhlo.uniform_dequantize %2 : (tensor>) -> tensor + return %3 : tensor +} + +// ----- + +// CHECK-LABEL: func @uniform_quantize_and_dequantize_int4 +func.func @uniform_quantize_and_dequantize_int4(%arg0: tensor) -> tensor { + // CHECK-DAG: %[[SCALES:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK-DAG: %[[ZPS:.*]] = mhlo.constant dense<3.000000e+00> : tensor + // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-8.000000e+00> : tensor + // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<7.000000e+00> : tensor + // CHECK: %[[VAL0:.*]] = chlo.broadcast_divide %arg0, %[[SCALES]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL1:.*]] = chlo.broadcast_add %[[VAL0]], %[[ZPS]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL2:.*]] = mhlo.clamp %[[QUANT_MIN]], %[[VAL1]], %[[QUANT_MAX]] : (tensor, tensor, tensor) -> tensor + // CHECK: %[[VAL3:.*]] = mhlo.round_nearest_even %[[VAL2]] : tensor + // CHECK: %[[VAL4:.*]] = mhlo.convert %[[VAL3]] : (tensor) -> tensor + %0 = mhlo.uniform_quantize %arg0 : (tensor) -> tensor> + + // CHECK-DAG: %[[SCALES_DQ:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK-DAG: %[[ZPS_DQ:.*]] = mhlo.constant dense<3> : tensor + // CHECK: %[[VAL5:.*]] = mhlo.convert %[[VAL4]] : (tensor) -> tensor + // CHECK: %[[VAL6:.*]] = chlo.broadcast_subtract %[[VAL5]], %[[ZPS_DQ]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL7:.*]] = mhlo.convert %[[VAL6]] : (tensor) -> tensor + // CHECK: %[[VAL8:.*]] = chlo.broadcast_multiply %[[VAL7]], %[[SCALES_DQ]] : (tensor, tensor) -> tensor + // CHECK: return %[[VAL8]] : tensor + %1 = mhlo.uniform_dequantize %0 : (tensor>) -> tensor + return %1 : tensor +} + +// ----- + +// CHECK-LABEL: func @uniform_quantize_and_dequantize_type_exensions +func.func @uniform_quantize_and_dequantize_type_exensions(%arg0: tensor>) -> () { + // CHECK: %[[QUANTIZED:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor>) -> tensor> + %0 = mhlo.uniform_quantize %arg0 : (tensor>) -> tensor, #mhlo.type_extensions> + // CHECK: %[[DEQUANTIZED:.*]] = chlo.broadcast_multiply %[[VAL1:.*]], %[[CONST_SCALE:.*]] : (tensor>, tensor) -> tensor> + %1 = mhlo.uniform_dequantize %0 : (tensor, #mhlo.type_extensions>) -> tensor> + return +} + +// ----- + +#SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }> + +// CHECK: #[[$SV:.*]] = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }> +// CHECK-LABEL: func @uniform_quantize_and_dequantize_sparse_tensor_encoding +func.func @uniform_quantize_and_dequantize_sparse_tensor_encoding(%arg0: tensor) -> () { + // CHECK: %[[QUANTIZED:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor) -> tensor + %0 = mhlo.uniform_quantize %arg0 : (tensor) -> tensor, #SV> + // CHECK: %[[DEQUANTIZED:.*]] = chlo.broadcast_multiply %[[VAL1:.*]], %[[CONST_SCALE:.*]] : (tensor, tensor) -> tensor + %1 = mhlo.uniform_dequantize %0 : (tensor, #SV>) -> tensor + return +} + +// ----- + +// CHECK-LABEL: func @quantize_per_channel +func.func @quantize_per_channel(%arg0: tensor<26x26x3x2xf32> + ) -> tensor<26x26x3x2x!quant.uniform> { + // CHECK-DAG: %[[SCALES:.*]] = mhlo.constant dense<[1.100000e+00, 1.100000e-01]> + // CHECK-DAG: %[[ZPS:.*]] = mhlo.constant dense<[-1.000000e+01, 2.000000e+00]> + // CHECK-DAG: %[[QMIN:.*]] = mhlo.constant dense<-2.14748365E+9> : tensor + // CHECK-DAG: %[[QMAX:.*]] = mhlo.constant dense<2.14748365E+9> : tensor + // CHECK: %[[DIVIDE:.*]] = chlo.broadcast_divide %arg0, %[[SCALES]] + // CHECK-SAME: {broadcast_dimensions = array} + // CHECK-SAME: (tensor<26x26x3x2xf32>, tensor<2xf32>) -> tensor<26x26x3x2xf32> + // CHECK: %[[ADD:.*]] = chlo.broadcast_add %[[DIVIDE]], %[[ZPS]] + // CHECK-SAME: {broadcast_dimensions = array} + // CHECK-SAME: (tensor<26x26x3x2xf32>, tensor<2xf32>) -> tensor<26x26x3x2xf32> + // CHECK: %[[CLAMP:.*]] = mhlo.clamp %[[QMIN]], %[[ADD]], %[[QMAX]] + // CHECK: %[[ROUND:.*]] = mhlo.round_nearest_even %[[CLAMP]] + // CHECK: %[[RESULT:.*]] = mhlo.convert %[[ROUND]] + // CHECK-SAME: (tensor<26x26x3x2xf32>) -> tensor<26x26x3x2xi32> + %0 = mhlo.uniform_quantize %arg0 : (tensor<26x26x3x2xf32> + ) -> tensor<26x26x3x2x!quant.uniform> + return %0 : tensor<26x26x3x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @dequantize_per_channel +func.func @dequantize_per_channel( + %arg0: tensor<26x26x3x2x!quant.uniform> + ) -> tensor<26x26x3x2xf32> { + // CHECK-DAG: %[[SCALES:.*]] = mhlo.constant dense<[1.100000e+00, 1.100000e-01]> + // CHECK-DAG: %[[ZPS:.*]] = mhlo.constant dense<[-10, 2]> : tensor<2xi32> + // CHECK: %[[SUBTRACT:.*]] = chlo.broadcast_subtract + // CHECK-SAME: %[[INPUT:.*]], %[[ZPS]] + // CHECK-SAME: {broadcast_dimensions = array} + // CHECK-SAME: (tensor<26x26x3x2xi32>, tensor<2xi32>) -> tensor<26x26x3x2xi32> + // CHECK: %[[FLOAT:.*]] = mhlo.convert %[[SUBTRACT]] + // CHECK: %[[RESULT:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[FLOAT]], %[[SCALES]] + // CHECK-SAME: {broadcast_dimensions = array} + // CHECK-SAME: (tensor<26x26x3x2xf32>, tensor<2xf32>) -> tensor<26x26x3x2xf32> + %0 = mhlo.uniform_dequantize %arg0 : ( + tensor<26x26x3x2x!quant.uniform> + ) -> tensor<26x26x3x2xf32> + return %0 : tensor<26x26x3x2xf32> +} + +// ----- + +// CHECK-LABEL: func @add +func.func @add( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // CHECK: %[[VAL1:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor) -> tensor + // CHECK: %[[VAL3:.*]] = mhlo.convert %[[VAL2:.*]] : (tensor) -> tensor + // CHECK-DAG: %[[VAL5:.*]] = mhlo.constant dense<3> : tensor + // CHECK: %[[VAL4:.*]] = chlo.broadcast_add %[[VAL1]], %[[VAL3]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL6:.*]] = chlo.broadcast_subtract %[[VAL4]], %[[VAL5]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL9:.*]] = mhlo.clamp %[[VAL7:.*]], %[[VAL6]], %[[VAL8:.*]] : (tensor, tensor, tensor) -> tensor + // CHECK: %[[VAL10:.*]] = mhlo.convert %[[VAL9]] : (tensor) -> tensor + %0 = mhlo.add %arg0, %arg1: ( + tensor>, + tensor> + ) -> tensor> + return %0: tensor> +} + +// ----- + +// CHECK-LABEL: func @add_i32 +func.func @add_i32( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // CHECK: %[[VAL1:.*]] = mhlo.convert %[[VAL0:.*]] : tensor + // CHECK: %[[VAL3:.*]] = mhlo.convert %[[VAL2:.*]] : tensor + // CHECK-DAG: %[[VAL5:.*]] = mhlo.constant dense<3> : tensor + // CHECK: %[[VAL4:.*]] = chlo.broadcast_add %[[VAL1:.*]], %[[VAL3:.*]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL6:.*]] = chlo.broadcast_subtract %[[VAL4]], %[[VAL5]] : (tensor, tensor) -> tensor + // CHECK-NEXT: return + %2 = mhlo.add %arg0, %arg1: ( + tensor>, + tensor> + ) -> tensor> + return %2 : tensor> +} + +// ----- + +// CHECK-LABEL: func @add_int4 +func.func @add_int4( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // CHECK: %[[VAL1:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor) -> tensor + // CHECK: %[[VAL3:.*]] = mhlo.convert %[[VAL2:.*]] : (tensor) -> tensor + // CHECK-DAG: %[[VAL5:.*]] = mhlo.constant dense<3> : tensor + // CHECK: %[[VAL4:.*]] = chlo.broadcast_add %[[VAL1]], %[[VAL3]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL6:.*]] = chlo.broadcast_subtract %[[VAL4]], %[[VAL5]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL9:.*]] = mhlo.clamp %[[VAL7:.*]], %[[VAL6]], %[[VAL8:.*]] : (tensor, tensor, tensor) -> tensor + // CHECK: %[[VAL10:.*]] = mhlo.convert %[[VAL9]] : (tensor) -> tensor + %0 = mhlo.add %arg0, %arg1: ( + tensor>, + tensor> + ) -> tensor> + return %0 : tensor> +} + +// ----- + +// CHECK-LABEL: @add_different_lhs_type +func.func @add_different_lhs_type( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // CHECK-DAG: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK-DAG: %[[LHS:.*]] = mhlo.convert %arg0 : (tensor) -> tensor + // CHECK-DAG: %[[MUL:.*]] = chlo.broadcast_multiply %[[LHS]], %[[COMBINED_SCALE]] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[COMBINED_ZP:.*]] = mhlo.constant dense<-5.000000e+00> + // CHECK: %[[LHS_32:.*]] = chlo.broadcast_add %[[MUL]], %[[COMBINED_ZP]] : (tensor, tensor) -> tensor + + // CHECK-DAG: %[[RHS_32:.*]] = mhlo.convert %[[RHS:.*]] : (tensor) -> tensor + // CHECK-DAG: %[[RES_ZPS:.*]] = mhlo.constant dense<1> : tensor + // CHECK-DAG: %[[VAL7:.*]] = chlo.broadcast_add %[[LHS_32_REQ:.*]], %[[RHS_32:.*]] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[VAL9:.*]] = chlo.broadcast_subtract %[[VAL7:.*]], %[[RES_ZPS:.*]] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-128> : tensor + // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<127> : tensor + // CHECK: %[[VAL10:.*]] = mhlo.clamp %[[QUANT_MIN:.*]], %[[VAL9:.*]], %[[QUANT_MAX:.*]] : (tensor, tensor, tensor) -> tensor + // CHECK: %[[VAL11:.*]] = mhlo.convert %[[VAL10:.*]] : (tensor) -> tensor + %2 = mhlo.add %arg0, %arg1: ( + tensor>, + tensor> + ) -> tensor> + return %2 : tensor> +} + +// ----- + +// CHECK-LABEL: @add_different_rhs_type +func.func @add_different_rhs_type( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // CHECK-DAG: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK-DAG: %[[RHS:.*]] = mhlo.convert %arg1 : (tensor) -> tensor + // CHECK-DAG: %[[MUL:.*]] = chlo.broadcast_multiply %[[RHS]], %[[COMBINED_SCALE]] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[COMBINED_ZP:.*]] = mhlo.constant dense<-5.000000e+00> + // CHECK: %[[RHS_32:.*]] = chlo.broadcast_add %[[MUL]], %[[COMBINED_ZP]] : (tensor, tensor) -> tensor + + // CHECK-DAG: %[[RES_ZPS:.*]] = mhlo.constant dense<1> : tensor + // CHECK-DAG: %[[VAL7:.*]] = chlo.broadcast_add %[[LHS_32:.*]], %[[RHS_32_REQ:.*]] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[VAL9:.*]] = chlo.broadcast_subtract %[[VAL7:.*]], %[[RES_ZPS:.*]] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-128> : tensor + // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<127> : tensor + // CHECK: %[[VAL10:.*]] = mhlo.clamp %[[QUANT_MIN:.*]], %[[VAL9:.*]], %[[QUANT_MAX:.*]] : (tensor, tensor, tensor) -> tensor + // CHECK: %[[VAL11:.*]] = mhlo.convert %[[VAL10:.*]] : (tensor) -> tensor + %0 = mhlo.add %arg0, %arg1: ( + tensor>, + tensor> + ) -> tensor> + return %0 : tensor> +} + +// CHECK-LABEL: @add_different_res_type +func.func @add_different_res_type( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // CHECK-DAG: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK-DAG: %[[LHS:.*]] = mhlo.convert %arg0 : (tensor) -> tensor + // CHECK-DAG: %[[MUL:.*]] = chlo.broadcast_multiply %[[LHS]], %[[COMBINED_SCALE]] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[COMBINED_ZP:.*]] = mhlo.constant dense<-5.000000e+00> + // CHECK: %[[LHS_32_REQ:.*]] = chlo.broadcast_add %[[MUL]], %[[COMBINED_ZP]] : (tensor, tensor) -> tensor + + // CHECK-DAG: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK-DAG: %[[RHS:.*]] = mhlo.convert %arg1 : (tensor) -> tensor + // CHECK-DAG: %[[MUL:.*]] = chlo.broadcast_multiply %[[RHS]], %[[COMBINED_SCALE]] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[COMBINED_ZP:.*]] = mhlo.constant dense<-5.000000e+00> + // CHECK: %[[RHS_32_REQ:.*]] = chlo.broadcast_add %[[MUL]], %[[COMBINED_ZP]] : (tensor, tensor) -> tensor + + // CHECK-DAG: %[[RES_ZPS:.*]] = mhlo.constant dense<1> : tensor + // CHECK-DAG: %[[VAL11:.*]] = chlo.broadcast_add %[[LHS_32_REQ:.*]], %[[RHS_32_REQ:.*]] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[VAL12:.*]] = chlo.broadcast_subtract %[[VAL11:.*]], %[[RES_ZPS:.*]] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-128> : tensor + // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<127> : tensor + // CHECK: %[[VAL13:.*]] = mhlo.clamp %[[QUANT_MIN:.*]], %[[VAL12:.*]], %[[QUANT_MAX:.*]] : (tensor, tensor, tensor) -> tensor + // CHECK: %[[VAL14:.*]] = mhlo.convert %[[VAL13:.*]] : (tensor) -> tensor + %0 = mhlo.add %arg0, %arg1: ( + tensor>, + tensor> + ) -> tensor> + return %0 : tensor> +} + +// ----- + +// CHECK-LABEL: func @add_per_channel +func.func @add_per_channel( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // CHECK: %[[ADD:.*]] = mhlo.add {{.*}} : tensor + // CHECK: %[[ZPS:.*]] = mhlo.constant dense<[3, 2]> : tensor<2xi32> + // CHECK: %[[BCAST_SUB:.*]] = chlo.broadcast_subtract %[[ADD]], %[[ZPS]] + // CHECK-SAME: {broadcast_dimensions = array} + // CHECK-SAME: (tensor, tensor<2xi32>) -> tensor + // CHECK: return %[[BCAST_SUB]] : tensor + %11 = mhlo.add %arg0, %arg1 : tensor> + return %11 : tensor> +} + +// ----- + +// CHECK-LABEL: func @add_per_channel_no_zp +func.func @add_per_channel_no_zp( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // CHECK: %[[ADD:.*]] = mhlo.add {{.*}} : tensor + // CHECK: return %[[ADD]] : tensor + %11 = mhlo.add %arg0, %arg1 : tensor> + return %11 : tensor> +} + +// ----- + +func.func @add_per_channel_i8( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // expected-error@+2 {{Per-channel quantized AddOp requires i32 storage type}} + // expected-error@+1 {{failed to legalize operation 'mhlo.add' that was explicitly marked illegal}} + %11 = mhlo.add %arg0, %arg1 : tensor> + return %11 : tensor> +} + +// ----- + +func.func @add_per_channel_different_quant_types( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // expected-error@+2 {{Per-channel quantized AddOp requires the same quantized element type for all operands and results}} + // expected-error@+1 {{failed to legalize operation 'mhlo.add' that was explicitly marked illegal}} + %11 = mhlo.add %arg0, %arg1 : ( + tensor>, + tensor> + ) -> tensor> + return %11 : tensor> +} + +// ----- + +func.func @add_per_channel_per_tensor_mix( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // expected-error@+2 {{Per-channel quantized AddOp requires the same quantized element type for all operands and results}} + // expected-error@+1 {{failed to legalize operation 'mhlo.add' that was explicitly marked illegal}} + %11 = mhlo.add %arg0, %arg1 : ( + tensor>, + tensor> + ) -> tensor> + return %11 : tensor> +} + +// ----- + +// CHECK-LABEL: func @requantize +func.func @requantize( + %arg0: tensor> + ) -> tensor> { + // CHECK-DAG: %[[MERGED_ZP:.*]] = mhlo.constant dense<-5.000000e+00> : tensor + // CHECK-DAG: %[[MERGED_SCALE:.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK-DAG: %[[VAL1:.*]] = mhlo.convert %arg0 : (tensor) -> tensor + // CHECK-DAG: %[[VAL2:.*]] = chlo.broadcast_multiply %[[VAL1]], %[[MERGED_SCALE]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL3:.*]] = chlo.broadcast_add %[[VAL2]], %[[MERGED_ZP]] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-1.280000e+02> : tensor + // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<1.270000e+02> : tensor + // CHECK: %[[VAL4:.*]] = mhlo.clamp %[[QUANT_MIN]], %[[VAL3]], %[[QUANT_MAX]] : (tensor, tensor, tensor) -> tensor + // CHECK: %[[VAL5:.*]] = mhlo.round_nearest_even %[[VAL4]] : tensor + // CHECK: %[[VAL6:.*]] = mhlo.convert %[[VAL5]] : (tensor) -> tensor + %0 = mhlo.uniform_quantize %arg0 : ( + tensor> + ) -> tensor> + return %0 : tensor> +} + +// ----- + +// CHECK-LABEL: func @requantize_merged_zp_zero +func.func @requantize_merged_zp_zero( + %arg0: tensor> + ) -> tensor> { + // CHECK-DAG: %[[MERGED_SCALE:.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK-DAG: %[[VAL1:.*]] = mhlo.convert %arg0 : (tensor) -> tensor + // CHECK: %[[VAL2:.*]] = chlo.broadcast_multiply %[[VAL1]], %[[MERGED_SCALE]] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-1.280000e+02> : tensor + // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<1.270000e+02> : tensor + // CHECK: %[[VAL3:.*]] = mhlo.clamp %[[QUANT_MIN]], %[[VAL2]], %[[QUANT_MAX]] : (tensor, tensor, tensor) -> tensor + // CHECK: %[[VAL4:.*]] = mhlo.round_nearest_even %[[VAL3]] : tensor + // CHECK: %[[VAL5:.*]] = mhlo.convert %[[VAL4]] : (tensor) -> tensor + %0 = mhlo.uniform_quantize %arg0 : (tensor>) -> tensor> + return %0 : tensor> +} + +// ----- + +// CHECK-LABEL: func @requantize_per_channel +func.func @requantize_per_channel( + %arg0: tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> { + // CHECK-DAG: %[[VAL1:.*]] = mhlo.convert %arg0 : (tensor<2x2xi8>) -> tensor<2x2xf32> + // CHECK-DAG: %[[MERGED_SCALE:.*]] = mhlo.constant dense<[2.000000e+00, 5.000000e-01]> : tensor<2xf32> + // CHECK: %[[VAL2:.*]] = chlo.broadcast_multiply %[[VAL1]], %[[MERGED_SCALE]] + // CHECK-SAME: broadcast_dimensions = array + // CHECK-DAG: %[[MERGED_ZP:.*]] = mhlo.constant dense<[-5.000000e+00, -2.000000e+00]> : tensor<2xf32> + // CHECK: %[[VAL3:.*]] = chlo.broadcast_add %[[VAL2]], %[[MERGED_ZP]] + // CHECK-SAME: broadcast_dimensions = array + // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-1.280000e+02> : tensor + // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<1.270000e+02> : tensor + // CHECK: %[[VAL4:.*]] = mhlo.clamp %[[QUANT_MIN]], %[[VAL3]], %[[QUANT_MAX]] + // CHECK: %[[VAL5:.*]] = mhlo.round_nearest_even %[[VAL4]] : tensor<2x2xf32> + // CHECK: %[[VAL6:.*]] = mhlo.convert %[[VAL5]] : (tensor<2x2xf32>) -> tensor<2x2xi8> + %0 = mhlo.uniform_quantize %arg0 : ( + tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @requantize_per_channel_to_per_tensor +func.func @requantize_per_channel_to_per_tensor( + %arg0: tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> { + // CHECK-DAG: %[[VAL1:.*]] = mhlo.convert %arg0 : (tensor<2x2xi8>) -> tensor<2x2xf32> + // CHECK-DAG: %[[MERGED_SCALE:.*]] = mhlo.constant dense<[2.000000e+00, 1.000000e+00]> : tensor<2xf32> + // CHECK: %[[VAL2:.*]] = chlo.broadcast_multiply %[[VAL1]], %[[MERGED_SCALE]] + // CHECK-SAME: broadcast_dimensions = array + // CHECK-DAG: %[[MERGED_ZP:.*]] = mhlo.constant dense<[-5.000000e+00, -1.000000e+00]> : tensor<2xf32> + // CHECK: %[[VAL3:.*]] = chlo.broadcast_add %[[VAL2]], %[[MERGED_ZP]] + // CHECK-SAME: broadcast_dimensions = array + // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-1.280000e+02> : tensor + // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<1.270000e+02> : tensor + // CHECK: %[[VAL4:.*]] = mhlo.clamp %[[QUANT_MIN]], %[[VAL3]], %[[QUANT_MAX]] + // CHECK: %[[VAL5:.*]] = mhlo.round_nearest_even %[[VAL4]] : tensor<2x2xf32> + // CHECK: %[[VAL6:.*]] = mhlo.convert %[[VAL5]] : (tensor<2x2xf32>) -> tensor<2x2xi8> + %0 = mhlo.uniform_quantize %arg0 : ( + tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @requantize_per_tensor_to_per_channel +func.func @requantize_per_tensor_to_per_channel( + %arg0: tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> { + // CHECK-DAG: %[[VAL1:.*]] = mhlo.convert %arg0 : (tensor<2x2xi8>) -> tensor<2x2xf32> + // CHECK-DAG: %[[MERGED_SCALE:.*]] = mhlo.constant dense<[1.000000e+00, 5.000000e-01]> : tensor<2xf32> + // CHECK: %[[VAL2:.*]] = chlo.broadcast_multiply %[[VAL1]], %[[MERGED_SCALE]] + // CHECK-SAME: broadcast_dimensions = array + // CHECK-DAG: %[[MERGED_ZP:.*]] = mhlo.constant dense<[-1.000000e+00, -2.000000e+00]> : tensor<2xf32> + // CHECK: %[[VAL3:.*]] = chlo.broadcast_add %[[VAL2]], %[[MERGED_ZP]] + // CHECK-SAME: broadcast_dimensions = array + // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-1.280000e+02> : tensor + // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<1.270000e+02> : tensor + // CHECK: %[[VAL4:.*]] = mhlo.clamp %[[QUANT_MIN]], %[[VAL3]], %[[QUANT_MAX]] + // CHECK: %[[VAL5:.*]] = mhlo.round_nearest_even %[[VAL4]] : tensor<2x2xf32> + // CHECK: %[[VAL6:.*]] = mhlo.convert %[[VAL5]] : (tensor<2x2xf32>) -> tensor<2x2xi8> + %0 = mhlo.uniform_quantize %arg0 : ( + tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} + +// ----- + +func.func @requantize_per_channel_change_axis( + %arg0: tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> { + // expected-error@+2 {{Cannot requantize while changing quantization_axis}} + // expected-error@+1 {{failed to legalize operation 'mhlo.uniform_quantize' that was explicitly marked illegal}} + %0 = mhlo.uniform_quantize %arg0 : ( + tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @dot +func.func @dot(%arg0: tensor<2x2x!quant.uniform>, + %arg1: tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> { + // CHECK: "mhlo.dot_general" + // CHECK-SAME: lhs_contracting_dimensions = [1] + // CHECK-SAME: rhs_contracting_dimensions = [0] + // CHECK-SAME: (tensor<2x2xi8>, tensor<2x2xi8>) -> tensor<2x2xi32> + %0 = "mhlo.dot" (%arg0, %arg1) : ( + tensor<2x2x!quant.uniform>, + tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @dot_int4 +func.func @dot_int4( + %arg0: tensor<2x2x!quant.uniform>, + %arg1: tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> { + // CHECK: "mhlo.dot_general" + // CHECK-SAME: lhs_contracting_dimensions = [1] + // CHECK-SAME: rhs_contracting_dimensions = [0] + // CHECK-SAME: (tensor<2x2xi4>, tensor<2x2xi4>) -> tensor<2x2xi32> + %0 = "mhlo.dot" (%arg0, %arg1): ( + tensor<2x2x!quant.uniform>, + tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @dot_dynamic +func.func @dot_dynamic( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // CHECK: %[[DOT:.*]] = "mhlo.dot_general" + // CHECK-SAME: lhs_contracting_dimensions = [1] + // CHECK-SAME: rhs_contracting_dimensions = [0] + // CHECK-SAME: (tensor, tensor) -> tensor + + // CHECK: mhlo.reduce + // CHECK-SAME: applies mhlo.add across dimensions = [1] + // CHECK-SAME: (tensor, tensor) -> tensor + // CHECK: "mhlo.get_dimension_size"(%[[DOT]]) + // CHECK-SAME: {dimension = 0 : i64} : (tensor) -> tensor + // CHECK: "mhlo.get_dimension_size"(%[[DOT]]) + // CHECK-SAME: {dimension = 1 : i64} : (tensor) -> tensor + // CHECK: %[[DYN_DIMS:.*]] = "mhlo.concatenate" + // CHECK-SAME: {dimension = 0 : i64} + // CHECK: mhlo.dynamic_broadcast_in_dim + // CHECK-SAME: %[[DYN_DIMS]]) + // CHECK-SAME: broadcast_dimensions = dense<0> + // CHECK-SAME: (tensor, tensor<2xi64>) -> tensor + + // CHECK: mhlo.reduce + // CHECK-SAME: applies mhlo.add across dimensions = [0] + // CHECK-SAME: (tensor, tensor) -> tensor + // CHECK: mhlo.dynamic_broadcast_in_dim + // CHECK-SAME: %[[DYN_DIMS]]) + // CHECK-SAME: broadcast_dimensions = dense<1> + // CHECK-SAME: (tensor, tensor<2xi64>) -> tensor + %0 = "mhlo.dot" (%arg0, %arg1) : ( + tensor>, + tensor> + ) -> tensor> + return %0 : tensor> +} + +// ----- + +// CHECK-LABEL: func @dot_dynamic_int4 +func.func @dot_dynamic_int4( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // CHECK: mhlo.dot_general + // CHECK-SAME: lhs_contracting_dimensions = [1] + // CHECK-SAME: rhs_contracting_dimensions = [0] + // CHECK-SAME: (tensor, tensor) -> tensor + %0 = "mhlo.dot" (%arg0, %arg1) : ( + tensor>, + tensor> + ) -> tensor> + return %0 : tensor> +} + +// ----- + +// CHECK-LABEL: func @dot_dynamic_contracting_dim +func.func @dot_dynamic_contracting_dim( + %arg0: tensor<2x?x!quant.uniform>, + %arg1: tensor> + ) -> tensor<2x2x!quant.uniform> { + // CHECK: "mhlo.dot_general" + // CHECK-SAME: lhs_contracting_dimensions = [1] + // CHECK-SAME: rhs_contracting_dimensions = [0] + // CHECK-SAME: (tensor<2x?xi8>, tensor) -> tensor<2x2xi32> + + // CHECK: mhlo.reduce + // CHECK-SAME: applies mhlo.add across dimensions = [1] + // CHECK-SAME: (tensor<2x?xi32>, tensor) -> tensor<2xi32> + + // CHECK: mhlo.reduce + // CHECK-SAME: applies mhlo.add across dimensions = [0] + // CHECK-SAME: (tensor, tensor) -> tensor<2xi32> + + // CHECK: %[[DYNAMIC_DIM_INIT:.*]] = mhlo.constant dense<1> : tensor + // CHECK: %[[DYNAMIC_DIM:.*]] = "mhlo.get_dimension_size" + // CHECK-SAME: {dimension = 0 : i64} : (tensor) -> tensor + // CHECK: %[[DYNAMIC_DIM_TOTAL:.*]] = mhlo.multiply + // CHECK-SAME: %[[DYNAMIC_DIM_INIT]], %[[DYNAMIC_DIM]] + // CHECK: %[[DIMS:.*]] = mhlo.constant dense<9> : tensor + // CHECK: %[[DIMS_1:.*]] = mhlo.multiply %[[DIMS]], %[[DYNAMIC_DIM_TOTAL]] + // CHECK: chlo.broadcast_subtract %[[ZP_OFFSET:.*]], %[[DIMS:.*]] + %0 = "mhlo.dot" (%arg0, %arg1) : ( + tensor<2x?x!quant.uniform>, + tensor> + ) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @dot_dynamic_result_dim +func.func @dot_dynamic_result_dim( + %arg0: tensor>, + %arg1: tensor<2x?x!quant.uniform> + ) -> tensor> { + // CHECK: "mhlo.dot_general" + // CHECK-SAME: lhs_contracting_dimensions = [1] + // CHECK-SAME: rhs_contracting_dimensions = [0] + // CHECK-SAME: (tensor, tensor<2x?xi8>) -> tensor + + // CHECK: mhlo.reduce + // CHECK-SAME: applies mhlo.add across dimensions = [1] + // CHECK-SAME: (tensor, tensor) -> tensor + // CHECK: mhlo.dynamic_broadcast_in_dim + // CHECK-SAME: broadcast_dimensions = dense<0> + // CHECK-SAME: (tensor, tensor<2xi64>) -> tensor + + // CHECK: mhlo.reduce + // CHECK-SAME: applies mhlo.add across dimensions = [0] + // CHECK-SAME: (tensor<2x?xi32>, tensor) -> tensor + // CHECK: mhlo.dynamic_broadcast_in_dim + // CHECK-SAME: broadcast_dimensions = dense<1> + // CHECK-SAME: (tensor, tensor<2xi64>) -> tensor + + %0 = "mhlo.dot" (%arg0, %arg1) : ( + tensor>, + tensor<2x?x!quant.uniform> + ) -> tensor> + return %0 : tensor> +} + +// ----- + +// CHECK-LABEL: func @dot_dynamic_batch_dim +func.func @dot_dynamic_batch_dim( + %arg0: tensor>, + %arg1: tensor<2x2x!quant.uniform> + ) -> tensor> { + // CHECK: "mhlo.dot_general" + // CHECK-SAME: lhs_contracting_dimensions = [1] + // CHECK-SAME: rhs_contracting_dimensions = [0] + // CHECK-SAME: (tensor, tensor<2x2xi8>) -> tensor + + // CHECK: mhlo.reduce + // CHECK-SAME: applies mhlo.add across dimensions = [1] + // CHECK-SAME: (tensor, tensor) -> tensor + // CHECK: mhlo.dynamic_broadcast_in_dim + // CHECK-SAME: broadcast_dimensions = dense<0> + // CHECK-SAME: (tensor, tensor<2xi64>) -> tensor + + // CHECK: mhlo.reduce + // CHECK-SAME: applies mhlo.add across dimensions = [0] + // CHECK-SAME: (tensor<2x2xi32>, tensor) -> tensor<2xi32> + // CHECK: mhlo.dynamic_broadcast_in_dim + // CHECK-SAME: broadcast_dimensions = dense<1> + // CHECK-SAME: (tensor<2xi32>, tensor<2xi64>) -> tensor + + %0 = "mhlo.dot" (%arg0, %arg1) : ( + tensor>, + tensor<2x2x!quant.uniform> + ) -> tensor> + return %0 : tensor> +} + +// ----- + +// CHECK-LABEL: func @dot_general +func.func @dot_general( + %arg0: tensor<2x5x6x!quant.uniform>, + %arg1: tensor<6x8x2x!quant.uniform> + ) -> tensor<2x5x8x!quant.uniform> { + // CHECK: %[[DOT_RES:.*]] = "mhlo.dot_general" + // CHECK-SAME: lhs_batching_dimensions = [0] + // CHECK-SAME: rhs_batching_dimensions = [2] + // CHECK-SAME: lhs_contracting_dimensions = [2] + // CHECK-SAME: rhs_contracting_dimensions = [0] + + // Zero point offset contribution from LHS tensor * RHS ZP. + + // CHECK: %[[LHS_I32:.*]] = mhlo.convert %[[LHS:.*]] : (tensor<2x5x6xi8>) + // CHECK-SAME: -> tensor<2x5x6xi32> + // CHECK: %[[LHS_REDUCE_INIT:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[LHS_REDUCE:.*]] = mhlo.reduce(%[[LHS_I32]] init: %[[LHS_REDUCE_INIT]]) + // CHECK-SAME: applies mhlo.add across dimensions = [2] + // CHECK-SAME: (tensor<2x5x6xi32>, tensor) + // CHECK-SAME: -> tensor<2x5xi32> + // CHECK: %[[RHS_ZP:.*]] = mhlo.constant dense<5> : tensor + // CHECK: %[[LHS_ZP_CONTRIB:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[LHS_REDUCE]], %[[RHS_ZP]] : + // CHECK-SAME: (tensor<2x5xi32>, tensor) -> tensor<2x5xi32> + // CHECK: %[[LHS_ZP_BCAST:.*]] = "mhlo.broadcast_in_dim"(%[[LHS_ZP_CONTRIB]]) + // CHECK-SAME: broadcast_dimensions = dense<[0, 1]> + // CHECK-SAME: (tensor<2x5xi32>) -> tensor<2x5x8xi32> + + // Zero point offset contribution from RHS tensor * LHS ZP. + + // CHECK: %[[RHS_I32:.*]] = mhlo.convert %[[RHS:.*]] : (tensor<6x8x2xi8>) + // CHECK-SAME: -> tensor<6x8x2xi32> + // CHECK: %[[RHS_REDUCE_INIT:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[RHS_REDUCE:.*]] = mhlo.reduce(%[[RHS_I32]] init: %[[RHS_REDUCE_INIT]]) + // CHECK-SAME: applies mhlo.add across dimensions = [0] + // CHECK-SAME: (tensor<6x8x2xi32>, tensor) + // CHECK-SAME: -> tensor<8x2xi32> + // CHECK: %[[RHS_ZP:.*]] = mhlo.constant dense<3> : tensor + // CHECK: %[[RHS_ZP_CONTRIB:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[RHS_REDUCE]], %[[RHS_ZP]] : + // CHECK-SAME: (tensor<8x2xi32>, tensor) -> tensor<8x2xi32> + // CHECK: %[[RHS_ZP_BCAST:.*]] = "mhlo.broadcast_in_dim"(%[[RHS_ZP_CONTRIB]]) + // CHECK-SAME: broadcast_dimensions = dense<[2, 0]> + // CHECK-SAME: (tensor<8x2xi32>) -> tensor<2x5x8xi32> + // CHECK: %[[ZP_TOTAL_1:.*]] = mhlo.add %[[LHS_ZP_BCAST]], %[[RHS_ZP_BCAST]] + + // Zero point offset contribution from LHS ZP * RHS ZP. + + // CHECK: %[[ZPS:.*]] = mhlo.constant dense<90> : tensor + // CHECK: %[[ZP_TOTAL_2:.*]] = chlo.broadcast_subtract %[[ZP_TOTAL_1]], %[[ZPS]] + // CHECK-SAME: (tensor<2x5x8xi32>, tensor) -> tensor<2x5x8xi32> + + // Combine dot result with zero point offset and output final result. + + // CHECK: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<5.000000e-01> : tensor + // CHECK: %[[RES_FP:.*]] = mhlo.convert %[[DOT_RES]] + // CHECK-SAME: (tensor<2x5x8xi32>) -> tensor<2x5x8xf32> + // CHECK: %[[RES_FP_1:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[RES_FP:.*]], %[[COMBINED_SCALE]] + // CHECK: %[[RES_INT:.*]] = mhlo.convert %[[RES_FP_1]] + // CHECK-SAME: (tensor<2x5x8xf32>) -> tensor<2x5x8xi32> + + // CHECK: %[[ZP_TOTAL_3:.*]] = mhlo.convert %[[ZP_TOTAL_2]] + // CHECK-SAME: (tensor<2x5x8xi32>) -> tensor<2x5x8xf32> + // CHECK: %[[ZP_TOTAL_4:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[ZP_TOTAL_3:.*]], %[[COMBINED_SCALE]] + // CHECK: %[[ZP_TOTAL_5:.*]] = mhlo.convert %[[ZP_TOTAL_4]] + // CHECK-SAME: (tensor<2x5x8xf32>) -> tensor<2x5x8xi32> + + // CHECK: %[[RES_ZP:.*]] = mhlo.constant dense<7> : tensor + // CHECK: %[[ZP_TOTAL_6:.*]] = chlo.broadcast_subtract %[[RES_ZP]], %[[ZP_TOTAL_5]] + // CHECK-SAME: (tensor, tensor<2x5x8xi32>) -> tensor<2x5x8xi32> + // CHECK: chlo.broadcast_add %[[RES_INT]], %[[ZP_TOTAL_6]] + + %0 = "mhlo.dot_general" (%arg0, %arg1) { + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [2], + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [0] + >} : ( + tensor<2x5x6x!quant.uniform>, + tensor<6x8x2x!quant.uniform> + ) -> tensor<2x5x8x!quant.uniform> + return %0 : tensor<2x5x8x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @dot_general_combined_scale_1 +func.func @dot_general_combined_scale_1( + %arg0: tensor<2x5x6x!quant.uniform>, + %arg1: tensor<6x8x2x!quant.uniform> + ) -> tensor<2x5x8x!quant.uniform> { + // CHECK: %[[DOT_RES:.*]] = "mhlo.dot_general" + // CHECK-SAME: lhs_batching_dimensions = [0] + // CHECK-SAME: rhs_batching_dimensions = [2] + // CHECK-SAME: lhs_contracting_dimensions = [2] + // CHECK-SAME: rhs_contracting_dimensions = [0] + + // Zero point offset contribution from LHS tensor * RHS ZP. + + // CHECK: %[[LHS_I32:.*]] = mhlo.convert %[[LHS:.*]] : (tensor<2x5x6xi8>) + // CHECK-SAME: -> tensor<2x5x6xi32> + // CHECK: %[[LHS_REDUCE_INIT:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[LHS_REDUCE:.*]] = mhlo.reduce(%[[LHS_I32]] init: %[[LHS_REDUCE_INIT]]) + // CHECK-SAME: applies mhlo.add across dimensions = [2] + // CHECK-SAME: (tensor<2x5x6xi32>, tensor) + // CHECK-SAME: -> tensor<2x5xi32> + // CHECK: %[[RHS_ZP:.*]] = mhlo.constant dense<5> : tensor + // CHECK: %[[LHS_ZP_CONTRIB:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[LHS_REDUCE]], %[[RHS_ZP]] : + // CHECK-SAME: (tensor<2x5xi32>, tensor) -> tensor<2x5xi32> + // CHECK: %[[LHS_ZP_BCAST:.*]] = "mhlo.broadcast_in_dim"(%[[LHS_ZP_CONTRIB]]) + // CHECK-SAME: broadcast_dimensions = dense<[0, 1]> + // CHECK-SAME: (tensor<2x5xi32>) -> tensor<2x5x8xi32> + + // Zero point offset contribution from RHS tensor * LHS ZP. + + // CHECK: %[[RHS_I32:.*]] = mhlo.convert %[[RHS:.*]] : (tensor<6x8x2xi8>) + // CHECK-SAME: -> tensor<6x8x2xi32> + // CHECK: %[[RHS_REDUCE_INIT:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[RHS_REDUCE:.*]] = mhlo.reduce(%[[RHS_I32]] init: %[[RHS_REDUCE_INIT]]) + // CHECK-SAME: applies mhlo.add across dimensions = [0] + // CHECK-SAME: (tensor<6x8x2xi32>, tensor) + // CHECK-SAME: -> tensor<8x2xi32> + // CHECK: %[[RHS_ZP:.*]] = mhlo.constant dense<3> : tensor + // CHECK: %[[RHS_ZP_CONTRIB:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[RHS_REDUCE]], %[[RHS_ZP]] : + // CHECK-SAME: (tensor<8x2xi32>, tensor) -> tensor<8x2xi32> + // CHECK: %[[RHS_ZP_BCAST:.*]] = "mhlo.broadcast_in_dim"(%[[RHS_ZP_CONTRIB]]) + // CHECK-SAME: broadcast_dimensions = dense<[2, 0]> + // CHECK-SAME: (tensor<8x2xi32>) -> tensor<2x5x8xi32> + // CHECK: %[[ZP_TOTAL_1:.*]] = mhlo.add %[[LHS_ZP_BCAST]], %[[RHS_ZP_BCAST]] + + // Zero point offset contribution from LHS ZP * RHS ZP. + + // CHECK: %[[ZPS:.*]] = mhlo.constant dense<90> : tensor + // CHECK: %[[ZP_TOTAL_2:.*]] = chlo.broadcast_subtract %[[ZP_TOTAL_1]], %[[ZPS]] + // CHECK-SAME: (tensor<2x5x8xi32>, tensor) -> tensor<2x5x8xi32> + + // Combine dot result with zero point offset and output final result. + // Do not multiply by combined scale since it is 1.0 and thus no-op. + + // CHECK: %[[RES_ZP:.*]] = mhlo.constant dense<7> : tensor + // CHECK: %[[ZP_TOTAL_3:.*]] = chlo.broadcast_subtract %[[RES_ZP]], %[[ZP_TOTAL_2]] + // CHECK-SAME: (tensor, tensor<2x5x8xi32>) -> tensor<2x5x8xi32> + // CHECK: chlo.broadcast_add %[[DOT_RES]], %[[ZP_TOTAL_3]] + + %0 = "mhlo.dot_general" (%arg0, %arg1) { + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [2], + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [0] + >} : ( + tensor<2x5x6x!quant.uniform>, + tensor<6x8x2x!quant.uniform> + ) -> tensor<2x5x8x!quant.uniform> + return %0 : tensor<2x5x8x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @dot_general_multiple_batching_dims +func.func @dot_general_multiple_batching_dims( + %arg0: tensor<2x5x3x7x6x!quant.uniform>, + %arg1: tensor<6x2x7x8x3x!quant.uniform> + ) -> tensor<2x3x5x8x!quant.uniform> { + // CHECK: %[[DOT_RES:.*]] = "mhlo.dot_general" + // CHECK-SAME: lhs_batching_dimensions = [0, 2] + // CHECK-SAME: rhs_batching_dimensions = [1, 4] + // CHECK-SAME: lhs_contracting_dimensions = [4, 3] + // CHECK-SAME: rhs_contracting_dimensions = [0, 2]>} + + // Zero point offset contribution from LHS tensor * RHS ZP. + + // CHECK: %[[LHS_I32:.*]] = mhlo.convert %[[LHS:.*]] : (tensor<2x5x3x7x6xi8>) + // CHECK-SAME: -> tensor<2x5x3x7x6xi32> + // CHECK: %[[LHS_REDUCE_INIT:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[LHS_REDUCE:.*]] = mhlo.reduce(%[[LHS_I32]] init: %[[LHS_REDUCE_INIT]]) + // CHECK-SAME: applies mhlo.add across dimensions = [4, 3] + // CHECK-SAME: (tensor<2x5x3x7x6xi32>, tensor) + // CHECK-SAME: -> tensor<2x5x3xi32> + // CHECK: %[[RHS_ZP:.*]] = mhlo.constant dense<5> : tensor + // CHECK: %[[LHS_ZP_CONTRIB:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[LHS_REDUCE]], %[[RHS_ZP]] : + // CHECK-SAME: (tensor<2x5x3xi32>, tensor) -> tensor<2x5x3xi32> + // CHECK: %[[LHS_ZP_BCAST:.*]] = "mhlo.broadcast_in_dim"(%[[LHS_ZP_CONTRIB]]) + // CHECK-SAME: broadcast_dimensions = dense<[0, 2, 1]> + // CHECK-SAME: (tensor<2x5x3xi32>) -> tensor<2x3x5x8xi32> + + // Zero point offset contribution from RHS tensor * LHS ZP. + + // CHECK: %[[RHS_I32:.*]] = mhlo.convert %[[RHS:.*]] : (tensor<6x2x7x8x3xi8>) + // CHECK-SAME: -> tensor<6x2x7x8x3xi32> + // CHECK: %[[RHS_REDUCE_INIT:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[RHS_REDUCE:.*]] = mhlo.reduce(%[[RHS_I32]] init: %[[RHS_REDUCE_INIT]]) + // CHECK-SAME: applies mhlo.add across dimensions = [0, 2] + // CHECK-SAME: (tensor<6x2x7x8x3xi32>, tensor) + // CHECK-SAME: -> tensor<2x8x3xi32> + // CHECK: %[[RHS_ZP:.*]] = mhlo.constant dense<3> : tensor + // CHECK: %[[RHS_ZP_CONTRIB:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[RHS_REDUCE]], %[[RHS_ZP]] : + // CHECK-SAME: (tensor<2x8x3xi32>, tensor) -> tensor<2x8x3xi32> + // CHECK: %[[RHS_ZP_BCAST:.*]] = "mhlo.broadcast_in_dim"(%[[RHS_ZP_CONTRIB]]) + // CHECK-SAME: broadcast_dimensions = dense<[0, 3, 1]> + // CHECK-SAME: (tensor<2x8x3xi32>) -> tensor<2x3x5x8xi32> + // CHECK: %[[ZP_TOTAL_1:.*]] = mhlo.add %[[LHS_ZP_BCAST]], %[[RHS_ZP_BCAST]] + + // Zero point offset contribution from LHS ZP * RHS ZP. + + // CHECK: %[[ZPS:.*]] = mhlo.constant dense<630> : tensor + // CHECK: %[[ZP_TOTAL_2:.*]] = chlo.broadcast_subtract %[[ZP_TOTAL_1]], %[[ZPS]] + // CHECK-SAME: (tensor<2x3x5x8xi32>, tensor) -> tensor<2x3x5x8xi32> + + // Combine dot result with zero point offset and output final result. + + // CHECK: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<5.000000e-01> : tensor + // CHECK: %[[RES_FP:.*]] = mhlo.convert %[[DOT_RES]] + // CHECK-SAME: (tensor<2x3x5x8xi32>) -> tensor<2x3x5x8xf32> + // CHECK: %[[RES_FP_1:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[RES_FP:.*]], %[[COMBINED_SCALE]] + // CHECK: %[[RES_INT:.*]] = mhlo.convert %[[RES_FP_1]] + // CHECK-SAME: (tensor<2x3x5x8xf32>) -> tensor<2x3x5x8xi32> + + // CHECK: %[[ZP_TOTAL_3:.*]] = mhlo.convert %[[ZP_TOTAL_2]] + // CHECK-SAME: (tensor<2x3x5x8xi32>) -> tensor<2x3x5x8xf32> + // CHECK: %[[ZP_TOTAL_4:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[ZP_TOTAL_3:.*]], %[[COMBINED_SCALE]] + // CHECK: %[[ZP_TOTAL_5:.*]] = mhlo.convert %[[ZP_TOTAL_4]] + // CHECK-SAME: (tensor<2x3x5x8xf32>) -> tensor<2x3x5x8xi32> + + // CHECK: %[[RES_ZP:.*]] = mhlo.constant dense<7> : tensor + // CHECK: %[[ZP_TOTAL_6:.*]] = chlo.broadcast_subtract %[[RES_ZP]], %[[ZP_TOTAL_5]] + // CHECK-SAME: (tensor, tensor<2x3x5x8xi32>) -> tensor<2x3x5x8xi32> + // CHECK: chlo.broadcast_add %[[RES_INT]], %[[ZP_TOTAL_6]] + + %0 = "mhlo.dot_general" (%arg0, %arg1) { + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [0, 2], + rhs_batching_dimensions = [1, 4], + lhs_contracting_dimensions = [4, 3], + rhs_contracting_dimensions = [0, 2] + >} : ( + tensor<2x5x3x7x6x!quant.uniform>, + tensor<6x2x7x8x3x!quant.uniform> + ) -> tensor<2x3x5x8x!quant.uniform> + return %0 : tensor<2x3x5x8x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @dot_general_rhs_zero_zp +func.func @dot_general_rhs_zero_zp( + %arg0: tensor<2x5x6x!quant.uniform>, + %arg1: tensor<6x8x2x!quant.uniform> + ) -> tensor<2x5x8x!quant.uniform> { + // CHECK: %[[DOT_RES:.*]] = "mhlo.dot_general" + // CHECK-SAME: lhs_batching_dimensions = [0] + // CHECK-SAME: rhs_batching_dimensions = [2] + // CHECK-SAME: lhs_contracting_dimensions = [2] + // CHECK-SAME: rhs_contracting_dimensions = [0] + + // Zero point offset contribution from LHS tensor * RHS ZP is 0 and skipped. + + // Zero point offset contribution from RHS tensor * LHS ZP. + + // CHECK: %[[RHS_I32:.*]] = mhlo.convert %[[RHS:.*]] : (tensor<6x8x2xi8>) + // CHECK-SAME: -> tensor<6x8x2xi32> + // CHECK: %[[RHS_REDUCE_INIT:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[RHS_REDUCE:.*]] = mhlo.reduce(%[[RHS_I32]] init: %[[RHS_REDUCE_INIT]]) + // CHECK-SAME: applies mhlo.add across dimensions = [0] + // CHECK-SAME: (tensor<6x8x2xi32>, tensor) + // CHECK-SAME: -> tensor<8x2xi32> + // CHECK: %[[RHS_ZP:.*]] = mhlo.constant dense<3> : tensor + // CHECK: %[[RHS_ZP_CONTRIB:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[RHS_REDUCE]], %[[RHS_ZP]] : + // CHECK-SAME: (tensor<8x2xi32>, tensor) -> tensor<8x2xi32> + // CHECK: %[[RHS_ZP_BCAST:.*]] = "mhlo.broadcast_in_dim"(%[[RHS_ZP_CONTRIB]]) + // CHECK-SAME: broadcast_dimensions = dense<[2, 0]> + // CHECK-SAME: (tensor<8x2xi32>) -> tensor<2x5x8xi32> + + // Zero point offset contribution from LHS ZP * RHS ZP is 0 and skipped. + + // Combine dot result with zero point offset and output final result. + + // CHECK: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<5.000000e-01> : tensor + // CHECK: %[[RES_FP:.*]] = mhlo.convert %[[DOT_RES]] + // CHECK-SAME: (tensor<2x5x8xi32>) -> tensor<2x5x8xf32> + // CHECK: %[[RES_FP_1:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[RES_FP:.*]], %[[COMBINED_SCALE]] + // CHECK: %[[RES_INT:.*]] = mhlo.convert %[[RES_FP_1]] + // CHECK-SAME: (tensor<2x5x8xf32>) -> tensor<2x5x8xi32> + + // CHECK: %[[ZP_TOTAL_1:.*]] = mhlo.convert %[[RHS_ZP_BCAST]] + // CHECK-SAME: (tensor<2x5x8xi32>) -> tensor<2x5x8xf32> + // CHECK: %[[ZP_TOTAL_2:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[ZP_TOTAL_1:.*]], %[[COMBINED_SCALE]] + // CHECK: %[[ZP_TOTAL_3:.*]] = mhlo.convert %[[ZP_TOTAL_2]] + // CHECK-SAME: (tensor<2x5x8xf32>) -> tensor<2x5x8xi32> + + // CHECK: %[[RES_ZP:.*]] = mhlo.constant dense<7> : tensor + // CHECK: %[[ZP_TOTAL_4:.*]] = chlo.broadcast_subtract %[[RES_ZP]], %[[ZP_TOTAL_3]] + // CHECK-SAME: (tensor, tensor<2x5x8xi32>) -> tensor<2x5x8xi32> + // CHECK: chlo.broadcast_add %[[RES_INT]], %[[ZP_TOTAL_4]] + + %0 = "mhlo.dot_general" (%arg0, %arg1) { + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [2], + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [0] + >} : ( + tensor<2x5x6x!quant.uniform>, + tensor<6x8x2x!quant.uniform> + ) -> tensor<2x5x8x!quant.uniform> + return %0 : tensor<2x5x8x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @dot_general_zero_zp +func.func @dot_general_zero_zp( + %arg0: tensor<2x5x6x!quant.uniform>, + %arg1: tensor<6x8x2x!quant.uniform> + ) -> tensor<2x5x8x!quant.uniform> { + // CHECK: %[[DOT_RES:.*]] = "mhlo.dot_general" + // CHECK-SAME: lhs_batching_dimensions = [0] + // CHECK-SAME: rhs_batching_dimensions = [2] + // CHECK-SAME: lhs_contracting_dimensions = [2] + // CHECK-SAME: rhs_contracting_dimensions = [0] + + // Both LHS/RHS have zero zp. No zp contribution. + + // CHECK-DAG: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<1.500000e+00> : tensor + // CHECK: %[[RES_FP:.*]] = mhlo.convert %[[DOT_RES]] : + // CHECK-SAME: (tensor<2x5x8xi32>) -> tensor<2x5x8xf32> + // CHECK: %[[RES_FP_1:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[RES_FP:.*]], %[[COMBINED_SCALE]] + // CHECK: %[[RES_INT:.*]] = mhlo.convert %[[RES_FP_1]] + // CHECK-SAME: (tensor<2x5x8xf32>) -> tensor<2x5x8xi32> + + // CHECK: %[[RES_ZP:.*]] = mhlo.constant dense<7> : tensor + // CHECK: chlo.broadcast_add %[[RES_INT]], %[[RES_ZP]] + + %0 = "mhlo.dot_general" (%arg0, %arg1) { + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [2], + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [0] + >} : ( + tensor<2x5x6x!quant.uniform>, + tensor<6x8x2x!quant.uniform> + ) -> tensor<2x5x8x!quant.uniform> + return %0 : tensor<2x5x8x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @dot_general_multiple_dynamic_dims +func.func @dot_general_multiple_dynamic_dims( + %arg0: tensor>, + %arg1: tensor<6x?x?x8x3x!quant.uniform> + ) -> tensor> { + // CHECK: %[[DOT_RES:.*]] = "mhlo.dot_general" + // CHECK-SAME: lhs_batching_dimensions = [0, 2] + // CHECK-SAME: rhs_batching_dimensions = [1, 4] + // CHECK-SAME: lhs_contracting_dimensions = [4, 3] + // CHECK-SAME: rhs_contracting_dimensions = [0, 2]>} + + // Zero point offset contribution from LHS tensor * RHS ZP. + + // CHECK: %[[LHS_I32:.*]] = mhlo.convert %[[LHS:.*]] : (tensor) + // CHECK-SAME: -> tensor + // CHECK: %[[LHS_REDUCE_INIT:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[LHS_REDUCE:.*]] = mhlo.reduce(%[[LHS_I32]] init: %[[LHS_REDUCE_INIT]]) + // CHECK-SAME: applies mhlo.add across dimensions = [4, 3] + // CHECK-SAME: (tensor, tensor) + // CHECK-SAME: -> tensor + // CHECK: %[[RHS_ZP:.*]] = mhlo.constant dense<5> : tensor + // CHECK: %[[LHS_ZP_CONTRIB:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[LHS_REDUCE]], %[[RHS_ZP]] : + // CHECK-SAME: (tensor, tensor) -> tensor + + // Calculate output dynamic dims. + // CHECK: %[[DIM_1_1:.*]] = "mhlo.get_dimension_size"(%[[DOT_RES]]) + // CHECK-SAME: {dimension = 0 : i64} + // CHECK: %[[DIM_1_2:.*]] = mhlo.convert %[[DIM_1_1]] : (tensor) -> tensor + // CHECK: %[[DIM_1:.*]] = mhlo.reshape %[[DIM_1_2]] : (tensor) -> tensor<1xi64> + // CHECK: %[[DIM_2:.*]] = mhlo.constant dense<3> : tensor<1xi64> + // CHECK: %[[DIM_3_1:.*]] = "mhlo.get_dimension_size"(%[[DOT_RES]]) + // CHECK-SAME: {dimension = 2 : i64} + // CHECK: %[[DIM_3_2:.*]] = mhlo.convert %[[DIM_3_1]] : (tensor) -> tensor + // CHECK: %[[DIM_3:.*]] = mhlo.reshape %[[DIM_3_2]] : (tensor) -> tensor<1xi64> + // CHECK: %[[DIM_4:.*]] = mhlo.constant dense<8> : tensor<1xi64> + // CHECK: %[[OUTPUT_DIMS:.*]] = "mhlo.concatenate" + // CHECK-SAME: %[[DIM_1]], %[[DIM_2]], %[[DIM_3]], %[[DIM_4]] + + // CHECK: %[[LHS_ZP_BCAST:.*]] = "mhlo.dynamic_broadcast_in_dim" + // CHECK-SAME: (%[[LHS_ZP_CONTRIB]], %[[OUTPUT_DIMS]]) + // CHECK-SAME: broadcast_dimensions = dense<[0, 2, 1]> + // CHECK-SAME: (tensor, tensor<4xi64>) -> tensor + + // Zero point offset contribution from RHS tensor * LHS ZP. + + // CHECK: %[[RHS_I32:.*]] = mhlo.convert %[[RHS:.*]] : (tensor<6x?x?x8x3xi8>) + // CHECK-SAME: -> tensor<6x?x?x8x3xi32> + // CHECK: %[[RHS_REDUCE_INIT:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[RHS_REDUCE:.*]] = mhlo.reduce(%[[RHS_I32]] init: %[[RHS_REDUCE_INIT]]) + // CHECK-SAME: applies mhlo.add across dimensions = [0, 2] + // CHECK-SAME: (tensor<6x?x?x8x3xi32>, tensor) + // CHECK-SAME: -> tensor + // CHECK: %[[RHS_ZP:.*]] = mhlo.constant dense<3> : tensor + // CHECK: %[[RHS_ZP_CONTRIB:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[RHS_REDUCE]], %[[RHS_ZP]] : + // CHECK-SAME: (tensor, tensor) -> tensor + + // CHECK: %[[RHS_ZP_BCAST:.*]] = "mhlo.dynamic_broadcast_in_dim" + // CHECK-SAME: (%[[RHS_ZP_CONTRIB]], %[[OUTPUT_DIMS]]) + // CHECK-SAME: broadcast_dimensions = dense<[0, 3, 1]> + // CHECK-SAME: (tensor, tensor<4xi64>) -> tensor + // CHECK: %[[ZP_TOTAL_1:.*]] = mhlo.add %[[LHS_ZP_BCAST]], %[[RHS_ZP_BCAST]] + + // Zero point offset contribution from LHS ZP * RHS ZP. + + // CHECK: %[[ZPS_INIT:.*]] = mhlo.constant dense<1> : tensor + // CHECK: %[[DYN_DIM:.*]] = "mhlo.get_dimension_size"(%[[RHS]]) + // CHECK: %[[ZPS_1:.*]] = mhlo.multiply %[[ZPS_INIT]], %[[DYN_DIM]] + // CHECK: %[[STATIC_DIM:.*]] = mhlo.constant dense<90> : tensor + // CHECK: %[[ZPS:.*]] = mhlo.multiply %[[STATIC_DIM]], %[[ZPS_1]] + // CHECK: %[[ZP_TOTAL_2:.*]] = chlo.broadcast_subtract %[[ZP_TOTAL_1]], %[[ZPS]] + // CHECK-SAME: (tensor, tensor) -> tensor + + // Combine dot result with zero point offset and output final result. + + // CHECK: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<5.000000e-01> : tensor + // CHECK: %[[RES_FP:.*]] = mhlo.convert %[[DOT_RES]] + // CHECK-SAME: (tensor) -> tensor + // CHECK: %[[RES_FP_1:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[RES_FP:.*]], %[[COMBINED_SCALE]] + // CHECK: %[[RES_INT:.*]] = mhlo.convert %[[RES_FP_1]] + // CHECK-SAME: (tensor) -> tensor + + // CHECK: %[[ZP_TOTAL_3:.*]] = mhlo.convert %[[ZP_TOTAL_2]] + // CHECK-SAME: (tensor) -> tensor + // CHECK: %[[ZP_TOTAL_4:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[ZP_TOTAL_3:.*]], %[[COMBINED_SCALE]] + // CHECK: %[[ZP_TOTAL_5:.*]] = mhlo.convert %[[ZP_TOTAL_4]] + // CHECK-SAME: (tensor) -> tensor + + // CHECK: %[[RES_ZP:.*]] = mhlo.constant dense<7> : tensor + // CHECK: %[[ZP_TOTAL_6:.*]] = chlo.broadcast_subtract %[[RES_ZP]], %[[ZP_TOTAL_5]] + // CHECK-SAME: (tensor, tensor) -> tensor + // CHECK: chlo.broadcast_add %[[RES_INT]], %[[ZP_TOTAL_6]] + + %0 = "mhlo.dot_general" (%arg0, %arg1) { + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [0, 2], + rhs_batching_dimensions = [1, 4], + lhs_contracting_dimensions = [4, 3], + rhs_contracting_dimensions = [0, 2] + >} : ( + tensor>, + tensor<6x?x?x8x3x!quant.uniform> + ) -> tensor> + return %0 : tensor> +} + +// ----- + +// CHECK-LABEL: func @conv2d_dynamic +func.func @conv2d_dynamic( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // CHECK-NOT: mhlo.pad + + // CHECK: %[[CONV:.*]] = mhlo.convolution + // CHECK-SAME: (%[[LHS:.*]], %[[RHS:.{1,4}]]) + // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] + // CHECK-SAME: window = {stride = [1, 2], pad = {{\[}}[0, 0], [0, 0]], + // CHECK-SAME: lhs_dilate = [1, 1], rhs_dilate = [2, 2]} + // CHECK-SAME: {batch_group_count = 1 : i64, feature_group_count = 1 : i64} + // CHECK-SAME: (tensor, tensor) -> tensor + + // Zero point offset contribution from LHS ZP * RHS. + + // CHECK: %[[RHS_I32:.*]] = mhlo.convert %[[RHS]] + // CHECK-SAME: (tensor) -> tensor + // CHECK: %[[RHS_REDUCE:.*]] = mhlo.reduce(%[[RHS_I32]] + // CHECK-SAME: applies mhlo.add across dimensions = [0, 1, 2] + // CHECK-SAME: (tensor, tensor) + // CHECK-SAME: -> tensor + // CHECK: %[[LHS_ZP:.*]] = mhlo.constant dense<4> : tensor + // CHECK: %[[RHS_ZP_CONTRIB:.*]] = chlo.broadcast_multiply %[[RHS_REDUCE]], %[[LHS_ZP]] + // CHECK-SAME: (tensor, tensor) -> tensor + // CHECK: %[[RHS_ZP_BCAST:.*]] = "mhlo.dynamic_broadcast_in_dim" + // CHECK-SAME: %[[RHS_ZP_CONTRIB]] + // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} + // CHECK-SAME: (tensor, tensor<4xi64>) -> tensor + + // Combine conv result with zero point offset and output final result. + + // CHECK: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<6.000000e+00> : tensor + // CHECK: %[[RES_FP:.*]] = mhlo.convert %[[CONV]] + // CHECK-SAME: (tensor) -> tensor + // CHECK: %[[RES_FP_1:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[RES_FP:.*]], %[[COMBINED_SCALE]] + // CHECK: %[[RES_INT:.*]] = mhlo.convert %[[RES_FP_1]] + // CHECK-SAME: (tensor) -> tensor + + // CHECK: %[[ZP_TOTAL_1:.*]] = mhlo.convert %[[RHS_ZP_BCAST]] + // CHECK-SAME: (tensor) -> tensor + // CHECK: %[[ZP_TOTAL_2:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[ZP_TOTAL_1:.*]], %[[COMBINED_SCALE]] + // CHECK: %[[ZP_TOTAL_3:.*]] = mhlo.convert %[[ZP_TOTAL_2]] + // CHECK-SAME: (tensor) -> tensor + + // CHECK: %[[RES_ZP:.*]] = mhlo.constant dense<5> : tensor + // CHECK: %[[ZP_TOTAL_4:.*]] = chlo.broadcast_subtract %[[RES_ZP]], %[[ZP_TOTAL_3]] + // CHECK-SAME: (tensor, tensor) -> tensor + // CHECK: chlo.broadcast_add %[[RES_INT]], %[[ZP_TOTAL_4]] + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 2], pad = [[0, 0], [0, 0]], + lhs_dilate = [1, 1], + rhs_dilate = [2, 2] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor>, tensor>) + -> tensor> + return %0 : tensor> +} + +// ----- + +// CHECK-LABEL: func @conv2d_static +func.func @conv2d_static( + %arg0: tensor<128x28x28x1x!quant.uniform>, + %arg1: tensor<3x3x1x128x!quant.uniform> + ) -> tensor<128x26x26x128x!quant.uniform> { + // CHECK-NOT: mhlo.pad + + // CHECK: %[[CONV:.*]] = mhlo.convolution + // CHECK-SAME: (%[[LHS:.*]], %[[RHS:.{1,4}]]) + // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] + // CHECK-SAME: window = {stride = [1, 1], pad = {{\[}}[0, 0], [0, 0]], + // CHECK-SAME: lhs_dilate = [1, 1], rhs_dilate = [1, 1]} + // CHECK-SAME: {batch_group_count = 1 : i64, feature_group_count = 1 : i64} + // CHECK-SAME: (tensor<128x28x28x1xi8>, tensor<3x3x1x128xi8>) -> tensor<128x26x26x128xi32> + + // Zero point offset contribution from LHS ZP * RHS. + + // CHECK: %[[RHS_I32:.*]] = mhlo.convert %[[RHS]] + // CHECK-SAME: (tensor<3x3x1x128xi8>) -> tensor<3x3x1x128xi32> + // CHECK: %[[RHS_REDUCE:.*]] = mhlo.reduce(%[[RHS_I32]] + // CHECK-SAME: applies mhlo.add across dimensions = [0, 1, 2] + // CHECK-SAME: (tensor<3x3x1x128xi32>, tensor) + // CHECK-SAME: -> tensor<128xi32> + // CHECK: %[[LHS_ZP:.*]] = mhlo.constant dense<4> : tensor + // CHECK: %[[RHS_ZP_CONTRIB:.*]] = chlo.broadcast_multiply %[[RHS_REDUCE]], %[[LHS_ZP]] + // CHECK-SAME: (tensor<128xi32>, tensor) -> tensor<128xi32> + // CHECK: %[[RHS_ZP_BCAST:.*]] = "mhlo.broadcast_in_dim" + // CHECK-SAME: %[[RHS_ZP_CONTRIB]] + // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} + // CHECK-SAME: (tensor<128xi32>) -> tensor<128x26x26x128xi32> + + // Combine conv result with zero point offset and output final result. + + // CHECK: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<6.000000e+00> : tensor + // CHECK: %[[RES_FP:.*]] = mhlo.convert %[[CONV]] + // CHECK-SAME: (tensor<128x26x26x128xi32>) -> tensor<128x26x26x128xf32> + // CHECK: %[[RES_FP_1:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[RES_FP:.*]], %[[COMBINED_SCALE]] + // CHECK: %[[RES_INT:.*]] = mhlo.convert %[[RES_FP_1]] + // CHECK-SAME: (tensor<128x26x26x128xf32>) -> tensor<128x26x26x128xi32> + + // CHECK: %[[ZP_TOTAL_1:.*]] = mhlo.convert %[[RHS_ZP_BCAST]] + // CHECK-SAME: (tensor<128x26x26x128xi32>) -> tensor<128x26x26x128xf32> + // CHECK: %[[ZP_TOTAL_2:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[ZP_TOTAL_1:.*]], %[[COMBINED_SCALE]] + // CHECK: %[[ZP_TOTAL_3:.*]] = mhlo.convert %[[ZP_TOTAL_2]] + // CHECK-SAME: (tensor<128x26x26x128xf32>) -> tensor<128x26x26x128xi32> + + // CHECK: %[[RES_ZP:.*]] = mhlo.constant dense<5> : tensor + // CHECK: %[[ZP_TOTAL_4:.*]] = chlo.broadcast_subtract %[[RES_ZP]], %[[ZP_TOTAL_3]] + // CHECK-SAME: (tensor, tensor<128x26x26x128xi32>) -> tensor<128x26x26x128xi32> + // CHECK: chlo.broadcast_add %[[RES_INT]], %[[ZP_TOTAL_4]] + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 1], pad = [[0, 0], [0, 0]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<128x28x28x1x!quant.uniform>, tensor<3x3x1x128x!quant.uniform>) + -> tensor<128x26x26x128x!quant.uniform> + return %0 : tensor<128x26x26x128x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @conv2d_default_attr +func.func @conv2d_default_attr( + %arg0: tensor<128x28x28x1x!quant.uniform>, + %arg1: tensor<3x3x1x128x!quant.uniform> + ) -> tensor<128x26x26x128x!quant.uniform> { + // CHECK: mhlo.convolution + // CHECK-NOT: quant.uniform + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<128x28x28x1x!quant.uniform>, tensor<3x3x1x128x!quant.uniform>) + -> tensor<128x26x26x128x!quant.uniform> + return %0 : tensor<128x26x26x128x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @conv2d_static_padding +func.func @conv2d_static_padding( + %arg0: tensor<128x28x28x1x!quant.uniform>, + %arg1: tensor<3x3x1x128x!quant.uniform> + ) -> tensor<128x29x33x128x!quant.uniform> { + // Explicitly pad LHS with ZP. + + // CHECK: %[[LHS_ZP_i8:.*]] = mhlo.constant dense<4> : tensor + // CHECK: %[[LHS_PAD:.*]] = "mhlo.pad"(%[[LHS:.*]], %[[LHS_ZP_i8]]) + // CHECK-SAME: edge_padding_high = dense<[0, 2, 4, 0]> + // CHECK-SAME: edge_padding_low = dense<[0, 1, 3, 0]> + // CHECK-SAME: interior_padding = dense<0> + // CHECK-SAME: (tensor<128x28x28x1xi8>, tensor) -> tensor<128x31x35x1xi8> + + // Convolution with padding removed. + + // CHECK: %[[CONV:.*]] = mhlo.convolution + // CHECK-SAME: (%[[LHS_PAD]], %[[RHS:.{1,4}]]) + // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] + // CHECK-SAME: window = {stride = [1, 1], pad = {{\[}}[0, 0], [0, 0]], + // CHECK-SAME: lhs_dilate = [1, 1], rhs_dilate = [1, 1]} + // CHECK-SAME: {batch_group_count = 1 : i64, feature_group_count = 1 : i64} + // CHECK-SAME: (tensor<128x31x35x1xi8>, tensor<3x3x1x128xi8>) -> tensor<128x29x33x128xi32> + + // Zero point offset contribution from LHS ZP * RHS. + + // CHECK: %[[RHS_I32:.*]] = mhlo.convert %[[RHS]] + // CHECK-SAME: (tensor<3x3x1x128xi8>) -> tensor<3x3x1x128xi32> + // CHECK: %[[RHS_REDUCE:.*]] = mhlo.reduce(%[[RHS_I32]] + // CHECK-SAME: applies mhlo.add across dimensions = [0, 1, 2] + // CHECK-SAME: (tensor<3x3x1x128xi32>, tensor) + // CHECK-SAME: -> tensor<128xi32> + // CHECK: %[[LHS_ZP:.*]] = mhlo.constant dense<4> : tensor + // CHECK: %[[RHS_ZP_CONTRIB:.*]] = chlo.broadcast_multiply %[[RHS_REDUCE]], %[[LHS_ZP]] + // CHECK-SAME: (tensor<128xi32>, tensor) -> tensor<128xi32> + // CHECK: %[[RHS_ZP_BCAST:.*]] = "mhlo.broadcast_in_dim" + // CHECK-SAME: %[[RHS_ZP_CONTRIB]] + // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} + // CHECK-SAME: (tensor<128xi32>) -> tensor<128x29x33x128xi32> + + // Combine conv result with zero point offset and output final result. + + // CHECK: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<6.000000e+00> : tensor + // CHECK: %[[RES_FP:.*]] = mhlo.convert %[[CONV]] + // CHECK-SAME: (tensor<128x29x33x128xi32>) -> tensor<128x29x33x128xf32> + // CHECK: %[[RES_FP_1:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[RES_FP:.*]], %[[COMBINED_SCALE]] + // CHECK: %[[RES_INT:.*]] = mhlo.convert %[[RES_FP_1]] + // CHECK-SAME: (tensor<128x29x33x128xf32>) -> tensor<128x29x33x128xi32> + + // CHECK: %[[ZP_TOTAL_1:.*]] = mhlo.convert %[[RHS_ZP_BCAST]] + // CHECK-SAME: (tensor<128x29x33x128xi32>) -> tensor<128x29x33x128xf32> + // CHECK: %[[ZP_TOTAL_2:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[ZP_TOTAL_1:.*]], %[[COMBINED_SCALE]] + // CHECK: %[[ZP_TOTAL_3:.*]] = mhlo.convert %[[ZP_TOTAL_2]] + // CHECK-SAME: (tensor<128x29x33x128xf32>) -> tensor<128x29x33x128xi32> + + // CHECK: %[[RES_ZP:.*]] = mhlo.constant dense<5> : tensor + // CHECK: %[[ZP_TOTAL_4:.*]] = chlo.broadcast_subtract %[[RES_ZP]], %[[ZP_TOTAL_3]] + // CHECK-SAME: (tensor, tensor<128x29x33x128xi32>) -> tensor<128x29x33x128xi32> + // CHECK: chlo.broadcast_add %[[RES_INT]], %[[ZP_TOTAL_4]] + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 1], pad = [[1, 2], [3, 4]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<128x28x28x1x!quant.uniform>, tensor<3x3x1x128x!quant.uniform>) + -> tensor<128x29x33x128x!quant.uniform> + return %0 : tensor<128x29x33x128x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @conv2d_per_channel +func.func @conv2d_per_channel( + %arg0: tensor<128x28x28x1x!quant.uniform>, + %arg1: tensor<3x3x1x2x!quant.uniform> + ) -> tensor<128x26x26x2x!quant.uniform> { + // CHECK: %[[CONV:.*]] = mhlo.convolution(%arg0, %arg1) + // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + // CHECK-SAME: window = {stride = [1, 1], pad = {{\[}}[0, 0], [0, 0]], + // CHECK-SAME: lhs_dilate = [1, 1], rhs_dilate = [1, 1] + // CHECK-SAME: {batch_group_count = 1 : i64, feature_group_count = 1 : i64} + // CHECK-SAME: (tensor<128x28x28x1xi8>, tensor<3x3x1x2xi8>) -> tensor<128x26x26x2xi32> + + // CHECK: %[[RHS:.*]] = mhlo.convert %arg1 : (tensor<3x3x1x2xi8>) -> tensor<3x3x1x2xi32> + // CHECK: %[[REDUCE:.*]] = mhlo.reduce(%[[RHS]] + // CHECK-SAME: applies mhlo.add across dimensions = [0, 1, 2] + // CHECK: %[[LHS_ZP:.*]] = mhlo.constant dense<4> : tensor + // CHECK: %[[ZP_OFFSET:.*]] = chlo.broadcast_multiply %[[REDUCE]], %[[LHS_ZP]] + // CHECK: %[[ZP_OFFSET_BCAST:.*]] = "mhlo.broadcast_in_dim"(%[[ZP_OFFSET]]) + // CHECK: %[[RES_ZP:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[ZP_OFFSET_TOTAL:.*]] = chlo.broadcast_subtract %[[RES_ZP:.*]], %[[ZP_OFFSET_BCAST]] + // CHECK: chlo.broadcast_add %[[CONV]], %[[ZP_OFFSET_TOTAL]] + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 1], pad = [[0, 0], [0, 0]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : ( + tensor<128x28x28x1x!quant.uniform>, + tensor<3x3x1x2x!quant.uniform>) + -> tensor<128x26x26x2x!quant.uniform> + return %0 : tensor<128x26x26x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @conv3d_static +func.func @conv3d_static( + %arg0: tensor<128x28x28x28x1x!quant.uniform>, + %arg1: tensor<3x3x3x1x128x!quant.uniform> + ) -> tensor<128x26x26x26x128x!quant.uniform>{ + // CHECK-NOT: mhlo.pad + + // CHECK: mhlo.convolution + // CHECK-SAME: dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f] + // CHECK-SAME: window = {stride = [1, 1, 1], pad = {{\[}}[0, 0], [0, 0], [0, 0]], + // CHECK-SAME: lhs_dilate = [1, 1, 1], rhs_dilate = [1, 1, 1]} + // CHECK-SAME: {batch_group_count = 1 : i64, feature_group_count = 1 : i64} + // CHECK-SAME: (tensor<128x28x28x28x1xi8>, tensor<3x3x3x1x128xi8>) -> tensor<128x26x26x26x128xi32> + + // CHECK: mhlo.reduce + // CHECK-SAME: applies mhlo.add across dimensions = [0, 1, 2, 3] + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f], + window = { + stride = [1, 1, 1], pad = [[0, 0], [0, 0], [0, 0]], + lhs_dilate = [1, 1, 1], + rhs_dilate = [1, 1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<128x28x28x28x1x!quant.uniform>, tensor<3x3x3x1x128x!quant.uniform>) + -> tensor<128x26x26x26x128x!quant.uniform> + return %0 : tensor<128x26x26x26x128x!quant.uniform> +} + +// ----- + +func.func @conv3d_rhs_zp_not_zero( + %arg0: tensor<128x28x28x28x1x!quant.uniform>, + %arg1: tensor<3x3x3x1x128x!quant.uniform>) { + // expected-error@+2 {{RHS/result UQ type must have zero zp}} + // expected-error@+1 {{failed to legalize operation 'mhlo.convolution' that was explicitly marked illegal}} + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f], + window = { + stride = [1, 1, 1], pad = [[0, 0], [0, 0], [0, 0]], + lhs_dilate = [1, 1, 1], + rhs_dilate = [1, 1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<128x28x28x28x1x!quant.uniform>, tensor<3x3x3x1x128x!quant.uniform>) + -> tensor<128x26x26x26x128x!quant.uniform> + return +} + +// ----- + +func.func @conv3d_rhs_invalid_dilate( + %arg0: tensor<128x28x28x28x1x!quant.uniform>, + %arg1: tensor<3x3x3x1x128x!quant.uniform>) { + // expected-error@+2 {{lhs_dilation must be 1}} + // expected-error@+1 {{failed to legalize operation 'mhlo.convolution' that was explicitly marked illegal}} + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f], + window = { + stride = [1, 1, 1], pad = [[0, 0], [0, 0], [0, 0]], + lhs_dilate = [2, 2, 2], + rhs_dilate = [1, 1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<128x28x28x28x1x!quant.uniform>, tensor<3x3x3x1x128x!quant.uniform>) + -> tensor<128x53x53x53x128x!quant.uniform> + return +} + +// ----- + +func.func @conv3d_non_nhwc( + %arg0: tensor<128x1x28x28x28x!quant.uniform>, + %arg1: tensor<3x3x3x1x128x!quant.uniform>) { + // expected-error@+2 {{Convolution data format must be NHWC}} + // expected-error@+1 {{failed to legalize operation 'mhlo.convolution' that was explicitly marked illegal}} + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, f, 0, 1, 2]x[0, 1, 2, i, o]->[b, f, 0, 1, 2], + window = { + stride = [1, 1, 1], pad = [[0, 0], [0, 0], [0, 0]], + lhs_dilate = [1, 1, 1], + rhs_dilate = [1, 1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<128x1x28x28x28x!quant.uniform>, tensor<3x3x3x1x128x!quant.uniform>) + -> tensor<128x128x26x26x26x!quant.uniform> + return +} + +// ----- + +func.func @conv2d_non_nhwc( + %arg0: tensor<128x1x28x28x!quant.uniform>, + %arg1: tensor<3x3x1x128x!quant.uniform>) { + // expected-error@+2 {{Convolution data format must be NHWC}} + // expected-error@+1 {{failed to legalize operation 'mhlo.convolution' that was explicitly marked illegal}} + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, f, 0, 1]x[0, 1, i, o]->[b, f, 0, 1], + window = { + stride = [1, 1], pad = [[0, 0], [0, 0]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<128x1x28x28x!quant.uniform>, tensor<3x3x1x128x!quant.uniform>) + -> tensor<128x128x26x26x!quant.uniform> + return +} + +// ----- + +func.func @conv2d_per_channel_rhs_zp_not_zero( + %arg0: tensor<128x28x28x1x!quant.uniform>, + %arg1: tensor<3x3x1x2x!quant.uniform> + ) -> tensor<128x26x26x2x!quant.uniform> { + // expected-error@+2 {{RHS/result UQ type must have zero zp.}} + // expected-error@+1 {{failed to legalize operation 'mhlo.convolution' that was explicitly marked illegal}} + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 1], pad = [[0, 0], [0, 0]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : ( + tensor<128x28x28x1x!quant.uniform>, + tensor<3x3x1x2x!quant.uniform>) + -> tensor<128x26x26x2x!quant.uniform> + return %0 : tensor<128x26x26x2x!quant.uniform> +} + +// ----- + +func.func @conv2d_per_channel_res_zp_not_zero( + %arg0: tensor<128x28x28x1x!quant.uniform>, + %arg1: tensor<3x3x1x2x!quant.uniform> + ) -> tensor<128x26x26x2x!quant.uniform> { + // expected-error@+2 {{RHS/result UQ type must have zero zp.}} + // expected-error@+1 {{failed to legalize operation 'mhlo.convolution' that was explicitly marked illegal}} + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 1], pad = [[0, 0], [0, 0]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : ( + tensor<128x28x28x1x!quant.uniform>, + tensor<3x3x1x2x!quant.uniform>) + -> tensor<128x26x26x2x!quant.uniform> + return %0 : tensor<128x26x26x2x!quant.uniform> +} + +// ----- + +func.func @conv2d_per_channel_rhs_only( + %arg0: tensor<128x28x28x1x!quant.uniform>, + %arg1: tensor<3x3x1x2x!quant.uniform> + ) -> tensor<128x26x26x2x!quant.uniform> { + // expected-error@+2 {{Invalid input/output type for Dot/Convolution op}} + // expected-error@+1 {{failed to legalize operation 'mhlo.convolution' that was explicitly marked illegal}} + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 1], pad = [[0, 0], [0, 0]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : ( + tensor<128x28x28x1x!quant.uniform>, + tensor<3x3x1x2x!quant.uniform>) + -> tensor<128x26x26x2x!quant.uniform> + return %0 : tensor<128x26x26x2x!quant.uniform> +} + +// ----- + +func.func @conv2d_per_channel_res_only( + %arg0: tensor<128x28x28x1x!quant.uniform>, + %arg1: tensor<3x3x1x2x!quant.uniform> + ) -> tensor<128x26x26x2x!quant.uniform> { + // expected-error@+2 {{Invalid input/output type for Dot/Convolution op}} + // expected-error@+1 {{failed to legalize operation 'mhlo.convolution' that was explicitly marked illegal}} + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 1], pad = [[0, 0], [0, 0]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : ( + tensor<128x28x28x1x!quant.uniform>, + tensor<3x3x1x2x!quant.uniform>) + -> tensor<128x26x26x2x!quant.uniform> + return %0 : tensor<128x26x26x2x!quant.uniform> +} + +// ----- + +func.func @conv2d_per_channel_unsupported_channel( + %arg0: tensor<128x28x28x1x!quant.uniform>, + %arg1: tensor<3x3x1x2x!quant.uniform> + ) -> tensor<128x26x26x2x!quant.uniform> { + // expected-error@+2 {{Conv quantized axis must be out channel axis}} + // expected-error@+1 {{failed to legalize operation 'mhlo.convolution' that was explicitly marked illegal}} + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 1], pad = [[0, 0], [0, 0]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<128x28x28x1x!quant.uniform>, tensor<3x3x1x2x!quant.uniform>) + -> tensor<128x26x26x2x!quant.uniform> + return %0 : tensor<128x26x26x2x!quant.uniform> +} + +// ----- + +func.func @conv2d_per_channel_rhs_result_scale_ratio_different( + %arg0: tensor<128x28x28x1x!quant.uniform>, + %arg1: tensor<3x3x1x2x!quant.uniform> + ) -> tensor<128x26x26x2x!quant.uniform> { + // expected-error@+2 {{Per-channel quantizated Conv must have same RHS/Result scale ratio for each channel}} + // expected-error@+1 {{failed to legalize operation 'mhlo.convolution' that was explicitly marked illegal}} + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 1], pad = [[0, 0], [0, 0]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : ( + tensor<128x28x28x1x!quant.uniform>, + tensor<3x3x1x2x!quant.uniform>) + -> tensor<128x26x26x2x!quant.uniform> + return %0 : tensor<128x26x26x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @dot_hybrid +func.func @dot_hybrid( + %arg0: tensor, + %arg1: tensor>) -> tensor { + // CHECK: %[[VAL1:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor) -> tensor + // CHECK: %[[VAL3:.*]] = chlo.broadcast_subtract %[[VAL1]], %[[VAL2:.*]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL5:.*]] = chlo.broadcast_multiply %[[VAL3]], %[[VAL4:.*]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL7:.*]] = "mhlo.dot"(%arg0, %[[VAL5]]) : (tensor, tensor) -> tensor + %1 = "mhlo.dot" (%arg0, %arg1): ( + tensor, tensor>) -> tensor + return %1: tensor +} + +// ----- + +func.func @dot_hybrid_result_type_not_float( + %arg0: tensor, + %arg1: tensor>) { + // expected-error@+2 {{Invalid input/output type for Dot/Convolution op}} + // expected-error@+1 {{failed to legalize operation 'mhlo.dot' that was explicitly marked illegal}} + %1 = "mhlo.dot" (%arg0, %arg1): ( + tensor, tensor> + ) -> tensor> + return +} + +// ----- + +func.func @dot_hybrid_lhs_type_not_float( + %arg0: tensor>, + %arg1: tensor) { + // expected-error@+2 {{Invalid input/output type for Dot/Convolution op}} + // expected-error@+1 {{failed to legalize operation 'mhlo.dot' that was explicitly marked illegal}} + %1 = "mhlo.dot" (%arg0, %arg1): ( + tensor>, tensor + ) -> tensor> + return +} + +// ----- + +// CHECK-LABEL: func @conv2d_static_hybrid +func.func @conv2d_static_hybrid( + %arg0: tensor<128x28x28x1xf32>, + %arg1: tensor<3x3x1x128x!quant.uniform> + ) -> tensor<128x26x26x128xf32> { + // CHECK-DAG: %[[ZP:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK-DAG: %[[SCALE:.*]] = mhlo.constant dense<3.000000e+00> : tensor + // CHECK: %[[RHS:.*]] = mhlo.convert %arg1 : (tensor<3x3x1x128xi8>) -> tensor<3x3x1x128xf32> + // CHECK: %[[SUB:.*]] = chlo.broadcast_subtract %[[RHS]], %[[ZP]] + // CHECK: %[[MUL:.*]] = chlo.broadcast_multiply %[[SUB]], %[[SCALE]] + // CHECK: mhlo.convolution(%arg0, %[[MUL]]) + // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] + // CHECK-SAME: stride = [1, 1], pad = {{\[}}[0, 0], [0, 0]] + // CHECK-SAME: lhs_dilate = [1, 1], rhs_dilate = [1, 1]} + // CHECK-SAME: {batch_group_count = 1 : i64, feature_group_count = 1 : i64} + // CHECK-SAME: : (tensor<128x28x28x1xf32>, tensor<3x3x1x128xf32>) -> tensor<128x26x26x128xf32> + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 1], pad = [[0, 0], [0, 0]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<128x28x28x1xf32>, tensor<3x3x1x128x!quant.uniform>) + -> tensor<128x26x26x128xf32> + return %0 : tensor<128x26x26x128xf32> +} + +// ----- + +func.func @conv2d_hybrid_result_not_float( + %arg0: tensor<128x28x28x1xf32>, + %arg1: tensor<3x3x1x128x!quant.uniform>) { + // expected-error@+2 {{Invalid input/output type for Dot/Convolution op}} + // expected-error@+1 {{failed to legalize operation 'mhlo.convolution' that was explicitly marked illegal}} + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 1], pad = [[0, 0], [0, 0]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<128x28x28x1xf32>, tensor<3x3x1x128x!quant.uniform>) + -> tensor<128x26x26x128x!quant.uniform> + return +} + +// ----- + +func.func @dot_general_hybrid_result_not_float( + %arg0: tensor<2x5x6xf32>, + %arg1: tensor<6x8x2x!quant.uniform>) { + // expected-error@+2 {{Invalid input/output type for Dot/Convolution op}} + // expected-error@+1 {{failed to legalize operation 'mhlo.dot_general' that was explicitly marked illegal}} + %0 = "mhlo.dot_general" (%arg0, %arg1) { + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [2], + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [0] + >} : ( + tensor<2x5x6xf32>, + tensor<6x8x2x!quant.uniform> + ) -> tensor<2x5x8x!quant.uniform> + return +} + +// ----- + +// CHECK-LABEL: func @mhlo_constant_uniform_quantized +func.func @mhlo_constant_uniform_quantized() -> tensor<1x!quant.uniform> { + // CHECK: mhlo.constant dense<9> : tensor<1xi8> + %0 = mhlo.constant() {value = dense<9> : tensor<1xi8>} : () -> tensor<1x!quant.uniform> + return %0 : tensor<1x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @mhlo_constant_uniform_quantized_per_channel +func.func @mhlo_constant_uniform_quantized_per_channel() -> () { + // CHECK: mhlo.constant dense<[9, 4]> : tensor<2xi8> + %0 = mhlo.constant() {value = dense<[9, 4]> : tensor<2xi8>} : () + -> tensor<2x!quant.uniform> + return +} + + +// ----- + +// CHECK-LABEL: func @mhlo_constant_int +func.func @mhlo_constant_int() -> tensor { + // CHECK: mhlo.constant dense<-128> : tensor + %0 = mhlo.constant() {value = dense<-128> : tensor} : () -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @broadcast +func.func @broadcast( + %arg0: tensor<1x2x!quant.uniform> + ) -> tensor<2x3x1x!quant.uniform> { + // CHECK: "mhlo.broadcast_in_dim" + // CHECK-SAME: broadcast_dimensions = dense<[2, 0]> : tensor<2xi64> + // CHECK-SAME: (tensor<1x2xi8>) -> tensor<2x3x1xi8> + %0 = "mhlo.broadcast_in_dim"(%arg0) { + broadcast_dimensions = dense<[2, 0]> : tensor<2xi64> + } : (tensor<1x2x!quant.uniform>) -> tensor<2x3x1x!quant.uniform> + return %0 : tensor<2x3x1x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @broadcast_per_channel +func.func @broadcast_per_channel( + %arg0: tensor<2x!quant.uniform> + ) -> tensor<128x26x26x2x!quant.uniform> { + // CHECK: "mhlo.broadcast_in_dim" + // CHECK-SAME: broadcast_dimensions = dense<3> : tensor<1xi64> + // CHECK-SAME: (tensor<2xi32>) -> tensor<128x26x26x2xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<3> : tensor<1xi64>}: ( + tensor<2x!quant.uniform> + ) -> tensor<128x26x26x2x!quant.uniform> + return %0 : tensor<128x26x26x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @dynamic_broadcast +func.func @dynamic_broadcast( + %arg0: tensor<1x2x!quant.uniform>, + %arg1: tensor<3xi32> + ) -> tensor> { + // CHECK: "mhlo.dynamic_broadcast_in_dim" + // CHECK-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64> + // CHECK-SAME: (tensor<1x2xi8>, tensor<3xi32>) -> tensor + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { + broadcast_dimensions = dense<[1, 2]> : tensor<2xi64> + } : ( + tensor<1x2x!quant.uniform>, tensor<3xi32> + ) -> tensor> + return %0 : tensor> +} + +// ----- + +// CHECK-LABEL: func @max +func.func @max( + %arg0: tensor<1x2x!quant.uniform> + ) -> tensor<1x2x!quant.uniform> { + // CHECK: mhlo.maximum + // CHECK-SAME: tensor<1x2xi8> + %0 = "mhlo.maximum"(%arg0, %arg0) : ( + tensor<1x2x!quant.uniform>, + tensor<1x2x!quant.uniform> + ) -> tensor<1x2x!quant.uniform> + return %0 : tensor<1x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @max_per_channel +func.func @max_per_channel( + %arg0: tensor<1x2x!quant.uniform> + ) -> tensor<1x2x!quant.uniform> { + // CHECK: mhlo.maximum + // CHECK-SAME: tensor<1x2xi8> + %0 = "mhlo.maximum"(%arg0, %arg0) : ( + tensor<1x2x!quant.uniform>, + tensor<1x2x!quant.uniform> + ) -> tensor<1x2x!quant.uniform> + return %0 : tensor<1x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @min +func.func @min( + %arg0: tensor<1x2x!quant.uniform> + ) -> tensor<1x2x!quant.uniform> { + // CHECK: mhlo.minimum + // CHECK-SAME: tensor<1x2xi8> + %0 = "mhlo.minimum"(%arg0, %arg0) : ( + tensor<1x2x!quant.uniform>, + tensor<1x2x!quant.uniform> + ) -> tensor<1x2x!quant.uniform> + return %0 : tensor<1x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @min_per_channel +func.func @min_per_channel( + %arg0: tensor<1x2x!quant.uniform> + ) -> tensor<1x2x!quant.uniform> { + // CHECK: mhlo.minimum + // CHECK-SAME: tensor<1x2xi8> + %0 = "mhlo.minimum"(%arg0, %arg0) : ( + tensor<1x2x!quant.uniform>, + tensor<1x2x!quant.uniform> + ) -> tensor<1x2x!quant.uniform> + return %0 : tensor<1x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @function(%arg0: tensor<1x2xi8>) -> tensor<1x2xi8> +func.func @function( + %arg0: tensor<1x2x!quant.uniform> + ) -> tensor<1x2x!quant.uniform> { + // CHECK: return %arg0 : tensor<1x2xi8> + return %arg0 : tensor<1x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @concatenate +func.func @concatenate( + %arg0: tensor<3x2x!quant.uniform:f32, 5.000000e-03>>, + %arg1: tensor<1x2x!quant.uniform:f32, 5.000000e-03>> + ) -> tensor<4x2x!quant.uniform:f32, 5.000000e-03>> { + // CHECK: mhlo.concatenate + // CHECK-SAME: (tensor<3x2xi8>, tensor<1x2xi8>) -> tensor<4x2xi8> + %0 = "mhlo.concatenate"(%arg0, %arg1) {dimension = 0 : i64} : ( + tensor<3x2x!quant.uniform:f32, 5.000000e-03>>, + tensor<1x2x!quant.uniform:f32, 5.000000e-03>> + ) -> tensor<4x2x!quant.uniform:f32, 5.000000e-03>> + return %0 : tensor<4x2x!quant.uniform:f32, 5.000000e-03>> +} + +// ----- + +// CHECK-LABEL: func @pad +func.func @pad( + %arg0: tensor<2x3x!quant.uniform:f32, 5.000000e-03>>, + %arg1: tensor:f32, 5.000000e-03>> + ) -> tensor<5x9x!quant.uniform:f32, 5.000000e-03>> { + // CHECK: mhlo.pad + // CHECK-SAME: (tensor<2x3xi8>, tensor) -> tensor<5x9xi8> + %0 = "mhlo.pad"(%arg0, %arg1) { + edge_padding_low = dense<[0, 1]> : tensor<2xi64>, + edge_padding_high = dense<[2, 1]> : tensor<2xi64>, + interior_padding = dense<[1, 2]> : tensor<2xi64> + }: ( + tensor<2x3x!quant.uniform:f32, 5.000000e-03>>, + tensor:f32, 5.000000e-03>> + ) -> tensor<5x9x!quant.uniform:f32, 5.000000e-03>> + return %0 : tensor<5x9x!quant.uniform:f32, 5.000000e-03>> +} + +// ----- + +// CHECK-LABEL: func @reshape +func.func @reshape( + %arg0: tensor<1x3x!quant.uniform> + ) -> tensor<3x1x!quant.uniform> { + // CHECK: mhlo.reshape + // CHECK-SAME: (tensor<1x3xi8>) -> tensor<3x1xi8> + %0 = "mhlo.reshape"(%arg0) : ( + tensor<1x3x!quant.uniform> + ) -> tensor<3x1x!quant.uniform> + return %0 : tensor<3x1x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @select +func.func @select( + %arg0: tensor<1x3xi1>, + %arg1: tensor<1x3x!quant.uniform>, + %arg2: tensor<1x3x!quant.uniform> + ) -> tensor<1x3x!quant.uniform> { + // CHECK: mhlo.select + // CHECK-SAME: tensor<1x3xi8> + %0 = "mhlo.select"(%arg0, %arg1, %arg2) : ( + tensor<1x3xi1>, + tensor<1x3x!quant.uniform>, + tensor<1x3x!quant.uniform> + ) -> tensor<1x3x!quant.uniform> + return %0 : tensor<1x3x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @transpose +func.func @transpose( + %arg0: tensor<3x1x!quant.uniform> + ) -> tensor<1x3x!quant.uniform> { + // CHECK: mhlo.transpose + // CHECK-SAME: (tensor<3x1xi8>) -> tensor<1x3xi8> + %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : ( + tensor<3x1x!quant.uniform> + ) -> tensor<1x3x!quant.uniform> + return %0 : tensor<1x3x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @gather +func.func @gather( + %arg0: tensor<3x4x2x!quant.uniform>, + %arg1: tensor<2x3x2xi64> + ) -> tensor<2x3x2x2x!quant.uniform> { + // CHECK: mhlo.gather + // CHECK-SAME: (tensor<3x4x2xi8>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xi8> + %0 = "mhlo.gather"(%arg0, %arg1) { + dimension_numbers = #mhlo.gather< + offset_dims = [2, 3], + collapsed_slice_dims = [0], + start_index_map = [1, 0], + index_vector_dim = 2>, + slice_sizes = dense<[1, 2, 2]> : tensor<3xi64>, + indices_are_sorted = false + } : ( + tensor<3x4x2x!quant.uniform>, + tensor<2x3x2xi64> + ) -> tensor<2x3x2x2x!quant.uniform> + return %0 : tensor<2x3x2x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @slice +func.func @slice( + %arg0: tensor<3x4x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> { + // CHECK: mhlo.slice + // CHECK-SAME: (tensor<3x4xi8>) -> tensor<2x2xi8> + %0 = "mhlo.slice"(%arg0) { + start_indices = dense<[1, 2]> : tensor<2xi64>, + limit_indices = dense<[3, 4]> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64> + } : ( + tensor<3x4x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @get_dimension_size +func.func @get_dimension_size( + %arg0: tensor> + ) -> tensor { + // CHECK: mhlo.get_dimension_size + // CHECK-SAME: (tensor) -> tensor + %0 = "mhlo.get_dimension_size"(%arg0) {dimension = 0 : i64} : ( + tensor>) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: reduce_window +func.func @reduce_window( + %arg0: tensor<2x3x10x3x!quant.uniform>, + %arg1: tensor> + ) -> tensor<2x3x10x3x!quant.uniform> { + // CHECK: mhlo.reduce_window + // CHECK: %[[ARG2:.*]]: tensor, %[[ARG3:.*]]: tensor + // CHECK: %[[MAX:.*]] = mhlo.maximum %[[ARG2]], %[[ARG3]] : tensor + // CHECK: mhlo.return %[[MAX]] : tensor + // CHECK: (tensor<2x3x10x3xi8>, tensor) -> tensor<2x3x10x3xi8> + %0 = "mhlo.reduce_window"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor>, %arg3: tensor>): + %1 = mhlo.maximum %arg2, %arg3 : tensor> + mhlo.return %1 : tensor> + }) {padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>} : (tensor<2x3x10x3x!quant.uniform>, tensor>) -> tensor<2x3x10x3x!quant.uniform> + return %0 : tensor<2x3x10x3x!quant.uniform> +} diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir index 023686dad92240..3f7266eddcd2d2 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir @@ -272,20 +272,20 @@ func.func @attr_rng_algorithm_philox(%arg0: tensor) -> (tensor, tensor } // CHECK-LABEL: "attr_rng_distribution_uniform" -func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { %0 = "mhlo.rng"(%arg0, %arg1, %arg2) { // CHECK: rng_distribution = #stablehlo rng_distribution = #mhlo.rng_distribution - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } // CHECK-LABEL: "attr_rng_distribution_normal" -func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { %0 = "mhlo.rng"(%arg0, %arg1, %arg2) { // CHECK: rng_distribution = #stablehlo rng_distribution = #mhlo.rng_distribution - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } @@ -863,9 +863,9 @@ func.func @op_dynamic_pad(%arg0: tensor, %arg1: tensor, %arg2: tenso } // CHECK-LABEL: "op_dynamic_reshape" -func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor { - // CHECK: "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor) -> tensor - %0 = "mhlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor) -> tensor +func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor + %0 = "mhlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor func.return %0 : tensor } @@ -1327,13 +1327,13 @@ func.func @op_rng_bit_generator(%arg0: tensor) -> (tensor, tensor } // CHECK-LABEL: "op_rng" -func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { // CHECK: "stablehlo.rng"(%arg0, %arg1, %arg2) { // CHECK-SAME: rng_distribution = #stablehlo - // CHECK-SAME: } : (tensor, tensor, tensor) -> tensor + // CHECK-SAME: } : (tensor, tensor, tensor<0xindex>) -> tensor %0 = "mhlo.rng"(%arg0, %arg1, %arg2) { rng_distribution = #mhlo.rng_distribution - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_reduce_pretty_print.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_reduce_pretty_print.mlir index 8435e14fee55a0..4471026910f8e1 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_reduce_pretty_print.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_reduce_pretty_print.mlir @@ -25,45 +25,6 @@ func.func @reduce_one_op_all_locs_same(%arg0: tensor, %arg1 : tensor } -// The test case is not eligible for pretty-printing reduce-op. The location of -// reduce-op is different. - -// CHECK-LABEL: func @reduce_one_op_all_locs_not_same_1 -// CHECK-NEXT: mhlo.reduce(%arg0 init: %arg1) -// CHECK-SAME: across dimensions = [1] {foo = "bar"} -// CHECK-SAME: : (tensor, tensor) -> tensor -// CHECK-NEXT: reducer(%arg[[x:.+]]: tensor loc("foo"), %arg[[y:.+]]: tensor loc("foo")) -// CHECK-NEXT: mhlo.add %arg[[x]], %arg[[y]] : tensor loc("foo") -// CHECK-NEXT: mhlo.return %{{[0-9]+}} : tensor loc("foo") -// CHECK-NEXT: loc("not_foo") - -func.func @reduce_one_op_all_locs_not_same_1(%arg0: tensor, %arg1 : tensor) -> (tensor) { - %0 = "mhlo.reduce"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor loc("foo"), %arg3: tensor loc("foo")): - %1 = "mhlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor loc("foo") - "mhlo.return"(%1) : (tensor) -> () loc("foo") - }) {dimensions = dense<[1]> : tensor<1xi64>, foo = "bar"} : (tensor, tensor) -> tensor loc("not_foo") - - func.return %0: tensor -} - -// The test case is not eligible for pretty-printing reduce-op. The location of -// block-arguments are different. - -// CHECK-LABEL: func @reduce_one_op_all_locs_not_same_2 -// CHECK-NOT: applies - -func.func @reduce_one_op_all_locs_not_same_2(%arg0: tensor, %arg1 : tensor) -> (tensor) { - %0 = "mhlo.reduce"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor loc("foo"), %arg3: tensor loc("not_foo")): - %1 = "mhlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor loc("foo") - "mhlo.return"(%1) : (tensor) -> () loc("foo") - }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor, tensor) -> tensor loc("foo") - - func.return %0: tensor -} - - // The test case is not eligible for pretty-printing reduce-op. More than two // block-arguments which are not perfectly forwarded to inner-op. @@ -168,3 +129,16 @@ func.func @reduce_innerop_type_not_trivially_derived(%arg0: tensor<4x4xf32>, %ar func.return %0: tensor<4xf32> } + + +// The test case makes sure any custom attrs set on the reduce-op are +// printed/parsed when pretty-printed. + +// CHECK-LABEL: func @pretty_print_with_custom_attr +// CHECK: applies mhlo.add across dimensions = [1] {custom_user_attr = 1 : i64} + +func.func @pretty_print_with_custom_attr(%arg0: tensor<2x64x13xf32>) -> tensor<2x13xf32> { + %0 = mhlo.constant dense<0.000000e+00> : tensor + %1 = mhlo.reduce(%arg0 init: %0) applies mhlo.add across dimensions = [1] {custom_user_attr = 1 : i64} : (tensor<2x64x13xf32>, tensor) -> tensor<2x13xf32> + return %1 : tensor<2x13xf32> +} diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/prepare-for-export.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/prepare-for-export.mlir index 7b300ecad14662..0872ff45825723 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/prepare-for-export.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/prepare-for-export.mlir @@ -99,3 +99,156 @@ func.func @broadcast_in_dim_dimension_unsorted(%arg0: tensor<1x2xi32>) -> tensor %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[2, 1]> : tensor<2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> func.return %0 : tensor<1x2x3xi32> } + +// ----- + +// CHECK-LABEL: @reduce_with_multiple_implicit_captures +func.func @reduce_with_multiple_implicit_captures(%arg0: tensor<2x2xf32>) -> tuple> { + %0 = mhlo.constant dense<1.000000e+00> : tensor + %1 = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: mhlo.reduce + %2 = mhlo.reduce(%arg0 init: %1) across dimensions = [0, 1] : (tensor<2x2xf32>, tensor) -> tensor + reducer(%arg1: tensor, %arg2: tensor) { + // CHECK-DAG: mhlo.constant dense<0.000000e+00> : tensor + // CHECK-DAG: mhlo.constant dense<1.000000e+00> : tensor + // CHECK: mhlo.compare + %5 = mhlo.compare NE, %arg1, %1 : (tensor, tensor) -> tensor + %6 = mhlo.compare NE, %arg2, %1 : (tensor, tensor) -> tensor + %7 = mhlo.or %5, %6 : tensor + %8 = mhlo.select %7, %0, %1 : tensor, tensor + mhlo.return %8 : tensor + } + %3 = mhlo.compare NE, %2, %1 : (tensor, tensor) -> tensor + %4 = mhlo.tuple %3 {xla_shape = "(pred[])"} : tuple> + return %4 : tuple> +} + +// ----- + +// CHECK-LABEL: @all_reduce_with_implicit_capture +func.func @all_reduce_with_implicit_capture(%arg0: tensor) -> tensor { + %c = mhlo.constant dense<0.0> : tensor + // CHECK: mhlo.all_reduce + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: tensor, %[[ARG2:arg.*]]: tensor): + %0 = "mhlo.all_reduce"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + // CHECK: %[[VAL1:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: mhlo.add + // CHECK-SAME: %[[ARG1]], %[[VAL1]] + %1 = mhlo.add %arg1, %c : tensor + mhlo.return %1 : tensor + }) {replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>} : (tensor) -> tensor + return %0 : tensor + } + +// ----- + +// CHECK-LABEL: @reduce_scatter_with_implicit_capture +func.func @reduce_scatter_with_implicit_capture(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { + %c = mhlo.constant dense<0.0> : tensor + // CHECK: mhlo.reduce_scatter + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: tensor, %[[ARG2:arg.*]]: tensor): + %0 = "mhlo.reduce_scatter"(%data) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + // CHECK: %[[VAL1:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: mhlo.add + // CHECK-SAME: %[[ARG1]], %[[VAL1]] + %1 = mhlo.add %arg2, %c : tensor + "mhlo.return"(%1) : (tensor) -> () + }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64, + channel_handle = #mhlo.channel_handle, + use_global_device_ids} : (tensor<4x16xf32>) -> tensor<4x4xf32> + func.return %0 : tensor<4x4xf32> +} + +// ----- + +// CHECK-LABEL: @reduce_window_with_implicit_capture +func.func @reduce_window_with_implicit_capture(%arg0: tensor<2x17x31x7xf32>, %arg1: tensor) -> tensor<2x16x30x7xf32> { + %c = mhlo.constant dense<0.0> : tensor + // CHECK: mhlo.reduce_window + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG2:arg.*]]: tensor, %[[ARG3:arg.*]]: tensor): + %0 = "mhlo.reduce_window"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + // CHECK: %[[VAL1:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: mhlo.maximum + // CHECK-SAME: %[[ARG2]], %[[VAL1]] + %1 = mhlo.maximum %arg2, %c : tensor + mhlo.return %1 : tensor + }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<2x17x31x7xf32>, tensor) -> tensor<2x16x30x7xf32> + return %0 : tensor<2x16x30x7xf32> + } + +// ----- + +// CHECK-LABEL: @scatter_with_implicit_capture +func.func @scatter_with_implicit_capture(%arg0: tensor<3xi32>, %arg1: tensor<1x1xi32>, + %arg2: tensor<1xi32>) -> tensor<3xi32> { + %c = mhlo.constant dense<0> : tensor + // CHECK: mhlo.scatter + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG3:arg.*]]: tensor, %[[ARG4:arg.*]]: tensor): + %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + // CHECK: %[[VAL1:.*]] = mhlo.constant dense<0> : tensor + // CHECK: mhlo.add + // CHECK-SAME: %[[ARG4]], %[[VAL1]] + %x = mhlo.add %arg4, %c : tensor + "mhlo.return"(%x) : (tensor) -> () + }) { + indices_are_sorted = false, + scatter_dimension_numbers = #mhlo.scatter< + update_window_dims = [], + inserted_window_dims = [0], + scatter_dims_to_operand_dims = [0], + index_vector_dim = 1, + >, + unique_indices = false + } : (tensor<3xi32>, tensor<1x1xi32>, tensor<1xi32>) -> tensor<3xi32> + func.return %0 : tensor<3xi32> +} + +// ----- + +// CHECK-LABEL: @select_and_scatter_with_implicit_capture +func.func @select_and_scatter_with_implicit_capture(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x23x23x64xf32>, %arg2: tensor) -> tensor<10x24x24x64xf32> { + %c1 = mhlo.constant dense<0.0> : tensor + %c2 = mhlo.constant dense<0.0> : tensor + // CHECK: mhlo.select_and_scatter + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG3:arg.*]]: tensor, %[[ARG4:arg.*]]: tensor): + %0 = "mhlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + // CHECK: %[[VAL1:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: mhlo.compare + // CHECK-SAME: %[[ARG3]], %[[VAL1]] + %1 = mhlo.compare GE, %arg3, %c1, TOTALORDER : (tensor, tensor) -> tensor + mhlo.return %1 : tensor + }, { + // CHECK: ^[[BB:bb.*]](%[[ARG3:arg.*]]: tensor, %[[ARG4:arg.*]]: tensor): + ^bb0(%arg3: tensor, %arg4: tensor): + // CHECK: %[[VAL2:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: mhlo.add + // CHECK-SAME: %[[ARG4]], %[[VAL2]] + %1 = mhlo.add %arg4, %c2 : tensor + mhlo.return %1 : tensor + }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<10x24x24x64xf32>, tensor<10x23x23x64xf32>, tensor) -> tensor<10x24x24x64xf32> + return %0 : tensor<10x24x24x64xf32> + } + +// ----- + +// CHECK-LABEL: @sort_with_implicit_capture +func.func @sort_with_implicit_capture(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { + %c = mhlo.constant dense<0.0> : tensor + // CHECK: mhlo.sort + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG0:arg.*]]: tensor, %[[ARG1:arg.*]]: tensor, %[[ARG2:arg.*]]: tensor, %[[ARG3:arg.*]]: tensor): + %0:2 = "mhlo.sort"(%input0, %input1) ({ + ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): + // CHECK: %[[VAL1:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: mhlo.compare + // CHECK-SAME: %[[ARG0]], %[[VAL1]] + %7 = "mhlo.compare"(%arg0, %c) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + "mhlo.return"(%7) : (tensor) -> () + }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) + func.return +} diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/sparse_ops.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/sparse_ops.mlir index e5ea921f6df2a7..842dd4e8b2a38d 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/sparse_ops.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/sparse_ops.mlir @@ -259,7 +259,7 @@ func.func @dot3(%arg0: tensor<4xf64, #SV>, // CHECK-LABEL: func @sparse_reduce( // CHECK-SAME: %[[A:.*]]: tensor<10xi64, #{{.*}}>) -> tensor { // CHECK: %[[C:.*]] = mhlo.constant dense<0> : tensor -// CHECK: %[[T:.*]] = mhlo.reduce(%[[A]] init: %[[C]]) across dimensions = [0] : (tensor<10xi64, #{{.*}}>) -> tensor +// CHECK: %[[T:.*]] = mhlo.reduce(%[[A]] init: %[[C]]) applies mhlo.add across dimensions = [0] : (tensor<10xi64, #{{.*}}>) -> tensor // CHECK: return %[[T]] : tensor func.func @sparse_reduce(%arg0: tensor<10xi64, #SV>) -> tensor { %0 = mhlo.constant dense<0> : tensor diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir index a0d6b9d8d241e8..cde68d60e020f4 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir @@ -250,20 +250,20 @@ func.func @attr_rng_algorithm_philox(%arg0: tensor) -> (tensor, tensor } // CHECK-LABEL: "attr_rng_distribution_uniform" -func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { // CHECK: rng_distribution = #mhlo.rng_distribution rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } // CHECK-LABEL: "attr_rng_distribution_normal" -func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { // CHECK: rng_distribution = #mhlo.rng_distribution rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } @@ -872,9 +872,9 @@ func.func @op_dynamic_pad(%arg0: tensor, %arg1: tensor, %arg2: tenso } // CHECK-LABEL: "op_dynamic_reshape" -func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor { - // CHECK: "mhlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor) -> tensor - %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor) -> tensor +func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "mhlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor + %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor func.return %0 : tensor } @@ -1334,13 +1334,13 @@ func.func @op_rng_bit_generator(%arg0: tensor) -> (tensor, tensor } // CHECK-LABEL: "op_rng" -func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { // CHECK: "mhlo.rng"(%arg0, %arg1, %arg2) { // CHECK-SAME: rng_distribution = #mhlo.rng_distribution - // CHECK-SAME: } : (tensor, tensor, tensor) -> tensor + // CHECK-SAME: } : (tensor, tensor, tensor<0xindex>) -> tensor %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/verifier_reduce_op.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/verifier_reduce_op.mlir index baab3c2bb83678..f335f18102d105 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/verifier_reduce_op.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/verifier_reduce_op.mlir @@ -521,7 +521,7 @@ func.func @reduce_verify_rettype(%arg0: tensor, %arg1 : tensor) // ----- func.func @reduce_parsing_pretty_reduce_non_commutative(%arg0: tensor , %arg1: tensor ) -> tensor { - // expected-error@+1 {{expected the inner-op to be a commutative binary-op from mhlo dialect, zero region, producing single result}} + // expected-error@+1 {{expected the inner-op to be a commutative binary-op that matching the reduce op dialect, with zero region, producing single result}} %0 = mhlo.reduce(%arg0 init: %arg1) applies mhlo.divide across dimensions = [1] : (tensor, tensor) -> tensor loc("foo") func.return %0 : tensor } @@ -529,7 +529,7 @@ func.func @reduce_parsing_pretty_reduce_non_commutative(%arg0: tensor , // ----- func.func @reduce_parsing_pretty_reduce_wrong_dialect(%arg0: tensor , %arg1: tensor ) -> tensor { - // expected-error@+1 {{expected the inner-op to be a commutative binary-op from mhlo dialect, zero region, producing single result}} + // expected-error@+1 {{expected the inner-op to be a commutative binary-op that matching the reduce op dialect, with zero region, producing single result}} %0 = mhlo.reduce(%arg0 init: %arg1) applies std.add across dimensions = [1] : (tensor, tensor) -> tensor loc("foo") func.return %0 : tensor } @@ -537,7 +537,7 @@ func.func @reduce_parsing_pretty_reduce_wrong_dialect(%arg0: tensor , % // ----- func.func @reduce_parsing_pretty_reduce_non_binary(%arg0: tensor , %arg1: tensor ) -> tensor { - // expected-error@+1 {{expected the inner-op to be a commutative binary-op from mhlo dialect, zero region, producing single result}} + // expected-error@+1 {{expected the inner-op to be a commutative binary-op that matching the reduce op dialect, with zero region, producing single result}} %0 = mhlo.reduce(%arg0 init: %arg1) applies mhlo.reshape across dimensions = [1] : (tensor, tensor) -> tensor loc("foo") func.return %0 : tensor } diff --git a/third_party/xla/xla/mlir_hlo/tests/buffer_packing.mlir b/third_party/xla/xla/mlir_hlo/tests/buffer_packing.mlir deleted file mode 100644 index 737dad4148de1b..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/buffer_packing.mlir +++ /dev/null @@ -1,164 +0,0 @@ -// RUN: mlir-hlo-opt -buffer-packing -split-input-file %s | FileCheck %s - -// CHECK-LABEL: @noPackingSameLiveRange -func.func @noPackingSameLiveRange() -> (f32, f32) { - // CHECK: memref.alloc - // CHECK: memref.alloc - %c1 = arith.constant 1 : index - %c2 = arith.constant 2.0 : f32 - %0 = memref.alloc() : memref<42xf32> - %1 = memref.alloc() : memref<42xf32> - memref.store %c2, %0[%c1] : memref<42xf32> - memref.store %c2, %1[%c1] : memref<42xf32> - %2 = memref.load %0[%c1] : memref<42xf32> - %3 = memref.load %1[%c1] : memref<42xf32> - return %2, %3 : f32, f32 -} - -// ----- - -// CHECK-LABEL: @packingScfIfSameSize -func.func @packingScfIfSameSize(%pred : i1) -> (f32) { - // CHECK: %[[MEM:.*]] = memref.alloc() : memref<192xi8> - // CHECK: %[[VIEW1:.*]] = memref.view %[[MEM]][%{{.*}}][] : memref<192xi8> to memref<42xf32> - // CHECK: %[[VIEW2:.*]] = memref.view %[[MEM]][%{{.*}}][] : memref<192xi8> to memref<42xf32> - // CHECK: scf.if - // CHECK: memref.load %[[VIEW1]] - // CHECK: else - // CHECK: memref.load %[[VIEW2]] - %c1 = arith.constant 1 : index - %c2 = arith.constant 2.0 : f32 - %0 = memref.alloc() : memref<42xf32> - %1 = memref.alloc() : memref<42xf32> - %2 = scf.if %pred -> f32 { - memref.store %c2, %0[%c1] : memref<42xf32> - %2 = memref.load %0[%c1] : memref<42xf32> - scf.yield %2 : f32 - } else { - memref.store %c2, %1[%c1] : memref<42xf32> - %2 = memref.load %1[%c1] : memref<42xf32> - scf.yield %2 : f32 - } - return %2 : f32 -} - -// ----- - -// CHECK-LABEL: @packingScfIfDifferentSize -func.func @packingScfIfDifferentSize(%pred : i1) -> (f32) { - // CHECK: %[[MEM:.*]] = memref.alloc() : memref<192xi8> - // CHECK: scf.if - // CHECK: %[[VIEW1:.*]] = memref.view %[[MEM]][%{{.*}}][] : memref<192xi8> to memref<42xf32> - // CHECK: memref.load %[[VIEW1]] - // CHECK: else - // CHECK: %[[VIEW2:.*]] = memref.view %[[MEM]][%{{.*}}][] : memref<192xi8> to memref<16xf32> - // CHECK: memref.load %[[VIEW2]] - %c1 = arith.constant 1 : index - %c2 = arith.constant 2.0 : f32 - %0 = scf.if %pred -> f32 { - %0 = memref.alloc() : memref<42xf32> - memref.store %c2, %0[%c1] : memref<42xf32> - %1 = memref.load %0[%c1] : memref<42xf32> - scf.yield %1 : f32 - } else { - %0 = memref.alloc() : memref<16xf32> - memref.store %c2, %0[%c1] : memref<16xf32> - %1 = memref.load %0[%c1] : memref<16xf32> - scf.yield %1 : f32 - } - return %0 : f32 -} - -// ----- - -// CHECK-LABEL: @packingScfIfDifferentElementType -func.func @packingScfIfDifferentElementType(%pred : i1) -> (f32) { - // CHECK: %[[MEM:.*]] = memref.alloc() : memref<128xi8> - // CHECK: scf.if - // CHECK: %[[VIEW1:.*]] = memref.view %[[MEM]][%{{.*}}][] : memref<128xi8> to memref<42xf16> - // CHECK: memref.load %[[VIEW1]] - // CHECK: else - // CHECK: %[[VIEW2:.*]] = memref.view %[[MEM]][%{{.*}}][] : memref<128xi8> to memref<16xf32> - // CHECK: memref.load %[[VIEW2]] - %c1 = arith.constant 1 : index - %0 = scf.if %pred -> f32 { - %c2 = arith.constant 2.0 : f16 - %0 = memref.alloc() : memref<42xf16> - memref.store %c2, %0[%c1] : memref<42xf16> - %1 = memref.load %0[%c1] : memref<42xf16> - %2 = arith.extf %1 : f16 to f32 - scf.yield %2 : f32 - } else { - %c2 = arith.constant 2.0 : f32 - %0 = memref.alloc() : memref<16xf32> - memref.store %c2, %0[%c1] : memref<16xf32> - %1 = memref.load %0[%c1] : memref<16xf32> - scf.yield %1 : f32 - } - return %0 : f32 -} - -// ----- - -// CHECK-LABEL: @packWithOutsideControlFlow -func.func @packWithOutsideControlFlow(%pred : i1) -> (f32, f32) { - // CHECK: %[[MEM:.*]] = memref.alloc() : memref<192xi8> - // CHECK: %[[VIEW0:.*]] = memref.view %[[MEM]][%{{.*}}][] : memref<192xi8> to memref<42xf32> - // CHECK: memref.load %[[VIEW0]] - // CHECK: scf.if - // CHECK: %[[VIEW1:.*]] = memref.view %[[MEM]][%{{.*}}][] : memref<192xi8> to memref<42xf32> - // CHECK: memref.load %[[VIEW1]] - // CHECK: else - // CHECK: %[[VIEW2:.*]] = memref.view %[[MEM]][%{{.*}}][] : memref<192xi8> to memref<42xf32> - // CHECK: memref.load %[[VIEW2]] - %c1 = arith.constant 1 : index - %c2 = arith.constant 2.0 : f32 - %0 = memref.alloc() : memref<42xf32> - memref.store %c2, %0[%c1] : memref<42xf32> - %1 = memref.load %0[%c1] : memref<42xf32> - %2 = scf.if %pred -> f32 { - %3 = memref.alloc() : memref<42xf32> - memref.store %c2, %3[%c1] : memref<42xf32> - %4 = memref.load %3[%c1] : memref<42xf32> - scf.yield %4 : f32 - } else { - %3 = memref.alloc() : memref<42xf32> - memref.store %c2, %3[%c1] : memref<42xf32> - %4 = memref.load %3[%c1] : memref<42xf32> - scf.yield %4 : f32 - } - return %1, %2 : f32, f32 -} - -// ----- - -// CHECK-LABEL: @packTwoInOne -func.func @packTwoInOne(%pred : i1) -> (f32) { - // CHECK: %[[MEM:.*]] = memref.alloc() : memref<192xi8> - // CHECK: scf.if - // CHECK: %[[VIEW1:.*]] = memref.view %[[MEM]][%{{.*}}][] : memref<192xi8> to memref<42xf32> - // CHECK: memref.load %[[VIEW1]] - // CHECK: else - // CHECK: %[[VIEW2:.*]] = memref.view %[[MEM]][%{{.*}}][] : memref<192xi8> to memref<16xf32> - // CHECK: %[[VIEW3:.*]] = memref.view %[[MEM]][%{{.*}}][] : memref<192xi8> to memref<8xf32> - // CHECK: memref.load %[[VIEW2]] - // CHECK: memref.load %[[VIEW3]] - %c1 = arith.constant 1 : index - %c2 = arith.constant 2.0 : f32 - %0 = scf.if %pred -> f32 { - %0 = memref.alloc() : memref<42xf32> - memref.store %c2, %0[%c1] : memref<42xf32> - %1 = memref.load %0[%c1] : memref<42xf32> - scf.yield %1 : f32 - } else { - %0 = memref.alloc() : memref<16xf32> - %1 = memref.alloc() : memref<8xf32> - memref.store %c2, %0[%c1] : memref<16xf32> - %2 = memref.load %0[%c1] : memref<16xf32> - memref.store %c2, %1[%c1] : memref<8xf32> - %3 = memref.load %1[%c1] : memref<8xf32> - %4 = arith.addf %2, %3 : f32 - scf.yield %4 : f32 - } - return %0 : f32 -} diff --git a/third_party/xla/xla/mlir_hlo/tests/naive_copy_removal.mlir b/third_party/xla/xla/mlir_hlo/tests/naive_copy_removal.mlir index bec5a0822ac7b5..2b6a5b191c4634 100644 --- a/third_party/xla/xla/mlir_hlo/tests/naive_copy_removal.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/naive_copy_removal.mlir @@ -65,9 +65,9 @@ func.func @target_is_subview_of_subview(%arg0: memref<8x8xf32>) %subview_5 = memref.subview %alloc_4[0, 0] [%c4, %c4] [1, 1] : memref<8x8xf32> to memref> %subview_6 = memref.subview %subview_5[0, 0] [%c4, %c4] [1, 1] : - memref> to memref> + memref> to memref> memref.copy %arg0, %subview_6 : - memref<8x8xf32> to memref> + memref<8x8xf32> to memref> return %arg0 : memref<8x8xf32> } @@ -79,32 +79,6 @@ func.func @target_is_subview_of_subview(%arg0: memref<8x8xf32>) // ----- -func.func @do_not_simplify_subview_of_subview(%arg0: memref<8x8xf32>) - -> vector<8x8xf32> { - %c4 = arith.constant 4 : index - %c0 = arith.constant 0 : index - %cst_0 = arith.constant 0.000000e+00 : f32 - %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x8xf32> - %subview_5 = memref.subview %alloc_4[0, 0] [%c4, %c4] [1, 1] : - memref<8x8xf32> to memref> - %subview_6 = memref.subview %subview_5[0, 0] [%c4, %c4] [1, 1] : - memref> to memref> - memref.copy %arg0, %subview_6 : - memref<8x8xf32> to memref> - %27 = vector.transfer_read %subview_5[%c0, %c0], %cst_0 : - memref>, vector<8x8xf32> - return %27 : vector<8x8xf32> -} - -// CHECK-LABEL: func @do_not_simplify_subview_of_subview( - -// CHECK: memref.alloc -// CHECK: memref.subview -// CHECK: memref.subview -// CHECK: memref.copy - -// ----- - func.func @do_not_simplify_subview(%arg0: memref<8x8xf32>) -> vector<8x8xf32> { %c4 = arith.constant 4 : index %c0 = arith.constant 0 : index diff --git a/third_party/xla/xla/mlir_hlo/tests/rank-specialization.mlir b/third_party/xla/xla/mlir_hlo/tests/rank-specialization.mlir deleted file mode 100644 index 3627fbebedc894..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/rank-specialization.mlir +++ /dev/null @@ -1,702 +0,0 @@ -// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster | FileCheck %s -// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster --mhlo-rank-specialization-to-scf=max-target-rank=3 | FileCheck %s --check-prefix CHECK-SCF - -// CHECK-LABEL: @add_mul -// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>) -func.func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, - %arg2 : tensor<*xf32>) -> tensor<*xf32> { - // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG2]], %[[ARG0]], %[[ARG1]]) ({ - // CHECK: ^bb0(%[[ARG2_:.*]]: tensor<*xf32>, %[[ARG0_:.*]]: tensor<*xf32>, %[[ARG1_:.*]]: tensor<*xf32>): - // CHECK: %[[TMP:.*]] = chlo.broadcast_multiply %[[ARG0_]], %[[ARG1_]] - // CHECK: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[ARG2_]] - // CHECK: "chlo.rank_specialization_cluster_yield"(%[[INNER_RES]]) - // CHECK: }) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - // CHECK: return %[[RES]] - %0 = chlo.broadcast_multiply %arg0, %arg1 - : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - %1 = chlo.broadcast_add %0, %arg2 - : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - func.return %1 : tensor<*xf32> -} - -// CHECK-SCF-LABEL: @add_mul -// CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>) -// CHECK-SCF-DAG: %[[C1:.*]] = arith.constant 1 -// CHECK-SCF-DAG: %[[C2:.*]] = arith.constant 2 -// CHECK-SCF-DAG: %[[C3:.*]] = arith.constant 3 -// CHECK-SCF-DAG: %[[ONE_SHAPE_1:.*]] = shape.const_shape [1] -// CHECK-SCF-DAG: %[[ONE_SHAPE_2:.*]] = shape.const_shape [1, 1] -// CHECK-SCF-DAG: %[[ONE_SHAPE_3:.*]] = shape.const_shape [1, 1, 1] -// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]] -// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]] -// CHECK-SCF-DAG: %[[SHAPE_ARG2:.*]] = shape.shape_of %[[ARG2]] -// Equal shapes case: -// CHECK-SCF-DAG: %[[EQ20:.*]] = shape.shape_eq %[[SHAPE_ARG2]], %[[SHAPE_ARG0]] -// CHECK-SCF-DAG: %[[EQ21:.*]] = shape.shape_eq %[[SHAPE_ARG2]], %[[SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[SHAPES_EQ:.*]] = arith.andi %[[EQ20]], %[[EQ21]] -// CHECK-SCF: %[[UNSHAPED_RES_EQ_SHAPES:.*]] = scf.if %[[SHAPES_EQ]] -// CHECK-SCF-DAG: %[[ANY_SHAPE:.*]] = shape.any %[[SHAPE_ARG2]], %[[SHAPE_ARG0]], %[[SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[ANY_SHAPE]] -// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] -// CHECK-SCF-DAG: %[[FLAT_ARG0:.*]] = mhlo.dynamic_reshape %[[ARG0]], %[[FLAT_SHAPE]] -// CHECK-SCF-DAG: %[[FLAT_ARG1:.*]] = mhlo.dynamic_reshape %[[ARG1]], %[[FLAT_SHAPE]] -// CHECK-SCF-DAG: %[[FLAT_ARG2:.*]] = mhlo.dynamic_reshape %[[ARG2]], %[[FLAT_SHAPE]] -// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[FLAT_ARG0]], %[[FLAT_ARG1]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[FLAT_ARG2]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] -// CHECK-SCF: scf.yield %[[INNER_RES_]] -// CHECK-SCF: else -// Find maximum reduced rank. -// CHECK-SCF-DAG: %[[REDUCED_SHAPES:.*]]:3 = chlo.minimum_broadcast_shapes %[[SHAPE_ARG2]], %[[SHAPE_ARG0]], %[[SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[REDUCED_RANK0:.*]] = shape.rank %[[REDUCED_SHAPES]]#1 -// CHECK-SCF-DAG: %[[REDUCED_RANK1:.*]] = shape.rank %[[REDUCED_SHAPES]]#2 -// CHECK-SCF-DAG: %[[REDUCED_RANK2:.*]] = shape.rank %[[REDUCED_SHAPES]]#0 -// CHECK-SCF-DAG: %[[R2_GT_R0:.*]] = arith.cmpi sgt, %[[REDUCED_RANK2]], %[[REDUCED_RANK0]] -// CHECK-SCF-DAG: %[[R20:.*]] = arith.select %[[R2_GT_R0]], %[[REDUCED_RANK2]], %[[REDUCED_RANK0]] -// CHECK-SCF-DAG: %[[R20_GT_R1:.*]] = arith.cmpi sgt, %[[R20]], %[[REDUCED_RANK1]] -// CHECK-SCF-DAG: %[[MAX_RED_RANK:.*]] = arith.select %[[R20_GT_R1]], %[[R20]], %[[REDUCED_RANK1]] -// Generic case 1: -// CHECK-SCF: %[[MAX_RED_RANK_LE_1:.*]] = arith.cmpi ule, %[[MAX_RED_RANK]], %[[C1]] -// CHECK-SCF: %[[UNSHAPED_RES_1:.*]] = scf.if %[[MAX_RED_RANK_LE_1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]] -// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = mhlo.dynamic_reshape %[[ARG0]], %[[EXT_SHAPE_ARG0_]] -// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = mhlo.dynamic_reshape %[[ARG1]], %[[EXT_SHAPE_ARG1_]] -// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = mhlo.dynamic_reshape %[[ARG2]], %[[EXT_SHAPE_ARG2_]] -// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] -// CHECK-SCF: scf.yield %[[INNER_RES_]] -// CHECK-SCF: else -// Generic case 2: -// CHECK-SCF: %[[MAX_RED_RANK_LE_2:.*]] = arith.cmpi ule, %[[MAX_RED_RANK]], %[[C2]] -// CHECK-SCF: %[[UNSHAPED_RES_2:.*]] = scf.if %[[MAX_RED_RANK_LE_2]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_2]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_2]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_2]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]] -// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = mhlo.dynamic_reshape %[[ARG0]], %[[EXT_SHAPE_ARG0_]] -// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = mhlo.dynamic_reshape %[[ARG1]], %[[EXT_SHAPE_ARG1_]] -// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = mhlo.dynamic_reshape %[[ARG2]], %[[EXT_SHAPE_ARG2_]] -// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] -// CHECK-SCF: scf.yield %[[INNER_RES_]] -// CHECK-SCF: else -// Generic case 3: -// CHECK-SCF: %[[MAX_RED_RANK_LE_3:.*]] = arith.cmpi ule, %[[MAX_RED_RANK]], %[[C3]] -// CHECK-SCF: assert %[[MAX_RED_RANK_LE_3]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 3" -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_3]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_3]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_3]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]] -// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = mhlo.dynamic_reshape %[[ARG0]], %[[EXT_SHAPE_ARG0_]] -// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = mhlo.dynamic_reshape %[[ARG1]], %[[EXT_SHAPE_ARG1_]] -// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = mhlo.dynamic_reshape %[[ARG2]], %[[EXT_SHAPE_ARG2_]] -// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] -// CHECK-SCF: scf.yield %[[INNER_RES_]] -// CHECK-SCF: scf.yield %[[UNSHAPED_RES_2]] -// CHECK-SCF: scf.yield %[[UNSHAPED_RES_1]] -// Reshape the result. -// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]] -// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]] -// CHECK-SCF-DAG: %[[TMP:.*]] = shape.broadcast %[[SHAPE_ARG0]], %[[SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[SHAPE_ARG2:.*]] = shape.shape_of %[[ARG2]] -// CHECK-SCF-DAG: %[[RES_SHAPE:.*]] = shape.broadcast %[[TMP]], %[[SHAPE_ARG2]] -// CHECK-SCF-DAG: %[[RES:.*]] = mhlo.dynamic_reshape %[[UNSHAPED_RES_EQ_SHAPES]], %[[RES_SHAPE]] -// CHECK-SCF: return %[[RES]] - -// ----- - -// CHECK-LABEL: @compare_const_like -// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>) -func.func @compare_const_like(%arg0 : tensor<*xf32>) -> tensor<*xi1> { - // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG0]]) ({ - // CHECK: ^bb0(%[[ARG1:.*]]: tensor<*xf32>): - // CHECK: %[[ZERO:.*]] = "chlo.constant_like"(%[[ARG1]]) {value = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> - // CHECK: %[[CMP_GT:.*]] = chlo.broadcast_compare %[[ARG1]], %[[ZERO]] {comparison_direction = #chlo} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xi1> - // CHECK: "chlo.rank_specialization_cluster_yield"(%[[CMP_GT]]) : (tensor<*xi1>) -> () - // CHECK: }) : (tensor<*xf32>) -> tensor<*xi1> - // CHECK: return %[[RES]] : tensor<*xi1> - %0 = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> - %1 = chlo.broadcast_compare %arg0, %0 {comparison_direction = #chlo} - : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xi1> - func.return %1 : tensor<*xi1> -} - -// ----- - -// Unary MHLO operation. -// CHECK-LABEL: @sqrt -// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -func.func @sqrt(%arg : tensor<*xf32>) -> tensor<*xf32> { - // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]]) - // CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xf32>): - // CHECK: %[[TMP0:.*]] = mhlo.sqrt %[[ARG_]] - // CHECK: %[[TMP1:.*]] = mhlo.sqrt %[[TMP0]] - // CHECK: %[[TMP2:.*]] = mhlo.sqrt %[[TMP1]] - // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP2]]) - // CHECK: return %[[RES]] - %0 = mhlo.sqrt %arg : (tensor<*xf32>) -> tensor<*xf32> - %1 = mhlo.sqrt %0 : (tensor<*xf32>) -> tensor<*xf32> - %2 = mhlo.sqrt %1 : (tensor<*xf32>) -> tensor<*xf32> - func.return %2 : tensor<*xf32> -} - -// CHECK-SCF-LABEL: @sqrt -// CHECK-SCF-SAME: (%[[ARG:.*]]: tensor<*xf32>) -// CHECK-SCF: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] -// CHECK-SCF: %[[N:.*]] = shape.num_elements %[[SHAPE]] -// CHECK-SCF: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] -// CHECK-SCF: %[[FLAT_ARG:.*]] = mhlo.dynamic_reshape %[[ARG]], %[[FLAT_SHAPE]] : (tensor<*xf32>, tensor<1xindex>) -> tensor -// CHECK-SCF: %[[TMP0:.*]] = mhlo.sqrt %[[FLAT_ARG]] : tensor -// CHECK-SCF: %[[TMP1:.*]] = mhlo.sqrt %[[TMP0]] : tensor -// CHECK-SCF: %[[UNSHAPED_RES:.*]] = mhlo.sqrt %[[TMP1]] : tensor -// CHECK-SCF: %[[RES:.*]] = mhlo.dynamic_reshape %[[UNSHAPED_RES]], %[[SHAPE]] : (tensor, tensor) -> tensor<*xf32> -// CHECK-SCF: return %[[RES]] - -// ----- - -// Don't cluster ranked operations. -// CHECK-LABEL: @sqrt_ranked -// CHECK-SAME: (%[[ARG:.*]]: tensor<3x?xf32>) -func.func @sqrt_ranked(%arg: tensor<3x?xf32>) -> tensor<3x?xf32> { - // CHECK-NOT: rank_specialization_cluster - %0 = mhlo.sqrt %arg : (tensor<3x?xf32>) -> tensor<3x?xf32> - %1 = mhlo.sqrt %0 : (tensor<3x?xf32>) -> tensor<3x?xf32> - %2 = mhlo.sqrt %1 : (tensor<3x?xf32>) -> tensor<3x?xf32> - func.return %2 : tensor<3x?xf32> -} - -// CHECK-SCF-LABEL: @sqrt_ranked -// CHECK-SCF-NOT: dynamic_reshape -// CHECK-SCF: return - -// ----- - -// Operation with mixed ranked and unranked operands. -// CHECK-LABEL: @select_mixed -// CHECK-SAME: (%[[PRED:.*]]: tensor<*xi1>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<2xf32>) -func.func @select_mixed(%pred: tensor<*xi1>, %arg1: tensor<*xf32>, - %arg2: tensor<2xf32>) -> tensor<*xf32> { - // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[PRED]], %[[ARG1]], %[[ARG2]]) - // CHECK: ^bb0(%[[PRED_:.*]]: tensor<*xi1>, %[[ARG1_:.*]]: tensor<*xf32>, %[[ARG2_:.*]]: tensor<2xf32>) - // CHECK: %[[TMP:.*]] = chlo.broadcast_select %[[PRED_]], %[[ARG1_]], %[[ARG2_]] - // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP]]) - // CHECK: return %[[RES]] - %0 = "chlo.broadcast_select"(%pred, %arg1, %arg2) - : (tensor<*xi1>, tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -// CHECK-SCF-LABEL: @select_mixed -// CHECK-SCF: chlo.broadcast_select %{{.*}}, %{{.*}}, %{{.*}} : (tensor, tensor, tensor) -// CHECK-SCF: return - -// ----- - -// Unary CHLO operation. -// CHECK-LABEL: @tan -// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -func.func @tan(%arg : tensor<*xf32>) -> tensor<*xf32> { - // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]]) ({ - // CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xf32>) - // CHECK: %[[TMP0:.*]] = chlo.tan %[[ARG_]] - // CHECK: %[[TMP1:.*]] = chlo.tan %[[TMP0]] - // CHECK: %[[TMP2:.*]] = chlo.tan %[[TMP1]] - // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP2]]) - // CHECK: return %[[RES]] - %0 = chlo.tan %arg : tensor<*xf32> -> tensor<*xf32> - %1 = chlo.tan %0 : tensor<*xf32> -> tensor<*xf32> - %2 = chlo.tan %1 : tensor<*xf32> -> tensor<*xf32> - func.return %2 : tensor<*xf32> -} - -// CHECK-SCF-LABEL: @tan -// CHECK-SCF-SAME: (%[[ARG:.*]]: tensor<*xf32>) -// CHECK-SCF: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] -// CHECK-SCF: %[[N:.*]] = shape.num_elements %[[SHAPE]] -// CHECK-SCF: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] -// CHECK-SCF: %[[FLAT_ARG:.*]] = mhlo.dynamic_reshape %[[ARG]], %[[FLAT_SHAPE]] : (tensor<*xf32>, tensor<1xindex>) -> tensor -// CHECK-SCF: %[[TMP0:.*]] = chlo.tan %[[FLAT_ARG]] : tensor -// CHECK-SCF: %[[TMP1:.*]] = chlo.tan %[[TMP0]] : tensor -// CHECK-SCF: %[[UNSHAPED_RES:.*]] = chlo.tan %[[TMP1]] : tensor -// CHECK-SCF: %[[RES:.*]] = mhlo.dynamic_reshape %[[UNSHAPED_RES]], %[[SHAPE]] : (tensor, tensor) -> tensor<*xf32> -// CHECK-SCF: return %[[RES]] - -// ----- - -// Composition of unary/binary CHLO and unary MHLO ops. -// CHECK-LABEL: @mixed -// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>) -func.func @mixed(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>) - -> tensor<*xf32> { - // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG2]], %[[ARG1]], %[[ARG0]]) - // CHECK: ^bb0(%[[ARG2_:.*]]: tensor<*xf32>, %[[ARG1_:.*]]: tensor<*xf32>, %[[ARG0_:.*]]: tensor<*xf32>) - // CHECK: %[[TMP0:.*]] = chlo.tan %[[ARG0_]] - // CHECK: %[[TMP1:.*]] = mhlo.sqrt %[[ARG1_]] - // CHECK: %[[TMP2:.*]] = chlo.broadcast_multiply %[[TMP0]], %[[TMP1]] - // CHECK: %[[TMP3:.*]] = chlo.broadcast_add %[[TMP2]], %[[ARG2_]] - // CHECK: %[[TMP4:.*]] = mhlo.sqrt %[[TMP3]] - // CHECK: %[[TMP5:.*]] = chlo.tan %[[TMP4]] - // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP5]]) - // CHECK: return %[[RES]] - %0 = chlo.tan %arg0 : tensor<*xf32> -> tensor<*xf32> - %1 = mhlo.sqrt %arg1 : (tensor<*xf32>) -> tensor<*xf32> - %2 = chlo.broadcast_multiply %0, %1 - : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - %3 = chlo.broadcast_add %2, %arg2 - : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - %4 = mhlo.sqrt %3 : (tensor<*xf32>) -> tensor<*xf32> - %5 = chlo.tan %4 : tensor<*xf32> -> tensor<*xf32> - func.return %5 : tensor<*xf32> -} - -// CHECK-SCF-LABEL: @mixed -// CHECK-SCF-DAG: %[[TMP0:.*]] = chlo.tan %{{.*}} : tensor -// CHECK-SCF-DAG: %[[TMP1:.*]] = mhlo.sqrt %{{.*}} : tensor -// CHECK-SCF-DAG: %[[TMP2:.*]] = chlo.broadcast_multiply %[[TMP0]], %[[TMP1]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[TMP3:.*]] = chlo.broadcast_add %[[TMP2]], %{{.*}} : (tensor, tensor) -// CHECK-SCF-DAG: %[[TMP4:.*]] = mhlo.sqrt %[[TMP3]] : tensor -// CHECK-SCF: chlo.tan %[[TMP4]] : tensor - -// ----- - -// Constant cluster operand. -// CHECK-LABEL: @relu -// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -func.func @relu(%arg : tensor<*xf32>) -> tensor<*xf32> { - // CHECK: %[[C0:.*]] = mhlo.constant dense<0.000000e+00> - // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]], %[[C0]]) - // CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xf32>, %[[C0_:.*]]: tensor): - // CHECK: %[[TMP:.*]] = chlo.broadcast_maximum %[[ARG_]], %[[C0_]] - // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP]]) - // CHECK: return %[[RES]] - %0 = mhlo.constant dense<0.000000e+00> : tensor - %1 = chlo.broadcast_maximum %0, %arg - : (tensor, tensor<*xf32>) -> tensor<*xf32> - func.return %1 : tensor<*xf32> -} - -// CHECK-SCF-LABEL: @relu -// CHECK-SCF-SAME: (%[[ARG:.*]]: tensor<*xf32>) -// CHECK-SCF: %[[C0:.*]] = mhlo.constant dense<0.000000e+00> -// CHECK-SCF: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] -// CHECK-SCF: %[[N:.*]] = shape.num_elements %[[SHAPE]] -// CHECK-SCF: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] -// CHECK-SCF: %[[FLAT_ARG:.*]] = mhlo.dynamic_reshape %[[ARG]], %[[FLAT_SHAPE]] : (tensor<*xf32>, tensor<1xindex>) -> tensor -// CHECK-SCF: %[[UNSHAPED_RES:.*]] = chlo.broadcast_maximum %[[FLAT_ARG]], %[[C0]] : (tensor, tensor) -// CHECK-SCF: %[[RES:.*]] = mhlo.dynamic_reshape %[[UNSHAPED_RES]], %[[SHAPE]] : (tensor, tensor) -> tensor<*xf32> -// CHECK-SCF: return %[[RES]] - -// ----- - -// Cluster with binary non-broadcasting operation. -// CHECK-LABEL: @angle -// CHECK-SAME: (%[[ARG:.*]]: tensor<*xcomplex>) -func.func @angle(%arg : tensor<*xcomplex>) -> tensor<*xf32> { - // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]]) - // CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xcomplex>): - // CHECK: %[[IMAG:.*]] = mhlo.imag %[[ARG_]] - // CHECK: %[[REAL:.*]] = mhlo.real %[[ARG_]] - // CHECK: %[[TMP:.*]] = mhlo.atan2 %[[IMAG]], %[[REAL]] - // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP]]) - // CHECK: return %[[RES]] - %0 = mhlo.imag %arg : (tensor<*xcomplex>) -> tensor<*xf32> - %1 = mhlo.real %arg : (tensor<*xcomplex>) -> tensor<*xf32> - %2 = mhlo.atan2 %0, %1 : tensor<*xf32> - func.return %2 : tensor<*xf32> -} - -// CHECK-SCF-LABEL: @angle -// CHECK-SCF-SAME: (%[[ARG:.*]]: tensor<*xcomplex>) -// CHECK-SCF: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] -// CHECK-SCF: %[[N:.*]] = shape.num_elements %[[SHAPE]] -// CHECK-SCF: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] -// CHECK-SCF: %[[FLAT_ARG:.*]] = mhlo.dynamic_reshape %[[ARG]], %[[FLAT_SHAPE]] : (tensor<*xcomplex>, tensor<1xindex>) -> tensor> -// CHECK-SCF: %[[IMAG:.*]] = mhlo.imag %[[FLAT_ARG]] : (tensor>) -// CHECK-SCF: %[[REAL:.*]] = mhlo.real %[[FLAT_ARG]] : (tensor>) -// CHECK-SCF: %[[UNSHAPED_RES:.*]] = mhlo.atan2 %[[IMAG]], %[[REAL]] : tensor - // CHECK-SCF: %[[RES:.*]] = mhlo.dynamic_reshape %[[UNSHAPED_RES]], %[[SHAPE]] : (tensor, tensor) -> tensor<*xf32> -// CHECK-SCF: return %[[RES]] - -// ----- - -// Scalar cluster operand. -// CHECK-LABEL: @xlogy -// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>) -func.func @xlogy(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> { - // CHECK: %[[C0:.*]] = mhlo.constant dense<0.000000e+00> - // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[C0]], %[[ARG0]], %[[ARG1]]) - // CHECK: ^bb0(%[[C0_:.*]]: tensor, %[[ARG0_:.*]]: tensor<*xf32>, %[[ARG1_:.*]]: tensor<*xf32>): - // CHECK: %[[TMP0:.*]] = chlo.broadcast_compare %[[ARG0_]], %[[C0_]] {comparison_direction = #chlo} - // CHECK: %[[TMP1:.*]] = mhlo.log %[[ARG1_]] - // CHECK: %[[TMP2:.*]] = chlo.broadcast_multiply %[[ARG0_]], %[[TMP1]] - // CHECK: %[[TMP3:.*]] = chlo.broadcast_select %[[TMP0]], %[[C0_]], %[[TMP2]] - // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP3]]) - // CHECK: return %[[RES]] - %0 = mhlo.constant dense<0.000000e+00> : tensor - %1 = tensor.cast %0 : tensor to tensor - %2 = chlo.broadcast_compare %arg0, %1 {comparison_direction = #chlo} - : (tensor<*xf32>, tensor) -> tensor<*xi1> - %3 = mhlo.log %arg1 : (tensor<*xf32>) -> tensor<*xf32> - %4 = chlo.broadcast_multiply %arg0, %3 - : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - %5 = chlo.broadcast_select %2, %1, %4 - : (tensor<*xi1>, tensor, tensor<*xf32>) -> tensor<*xf32> - func.return %5 : tensor<*xf32> -} - -// CHECK-SCF: @xlogy -// CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>) -// CHECK-SCF-DAG: %[[C1:.*]] = arith.constant 1 -// CHECK-SCF-DAG: %[[ONE_SHAPE_1:.*]] = shape.const_shape [1] -// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]] -// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]] -// CHECK-SCF-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.00{{.*}}> -// Lhs scalar case: -// CHECK-SCF-DAG: %[[LHS_N:.*]] = shape.num_elements %[[SHAPE_ARG0]] -// CHECK-SCF-DAG: %[[LHS_SCALAR:.*]] = arith.cmpi eq, %[[LHS_N]], %[[C1]] -// CHECK-SCF: %[[UNSHAPED_RES:.*]] = scf.if %[[LHS_SCALAR]] -// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] -// CHECK-SCF-DAG: %[[FLAT_NON_SCALAR:.*]] = mhlo.dynamic_reshape %[[ARG1]], %[[FLAT_SHAPE]] -// CHECK-SCF-DAG: %[[SCALAR:.*]] = mhlo.reshape %[[ARG0]] -// CHECK-SCF-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[SCALAR]], %[[ZERO]] {comparison_direction = #chlo} : (tensor, tensor) -// CHECK-SCF-DAG: %[[TMP0:.*]] = mhlo.log %[[FLAT_NON_SCALAR]] : tensor -// CHECK-SCF-DAG: %[[TMP1:.*]] = chlo.broadcast_multiply %[[SCALAR]], %[[TMP0]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_select %[[PRED]], %[[ZERO]], %[[TMP1]] : (tensor, tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] -// CHECK-SCF: scf.yield %[[INNER_RES_]] -// CHECK-SCF: else -// Rhs scalar case: -// CHECK-SCF-DAG: %[[RHS_N:.*]] = shape.num_elements %[[SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[RHS_SCALAR:.*]] = arith.cmpi eq, %[[RHS_N]], %[[C1]] -// CHECK-SCF: %{{.*}} = scf.if %[[RHS_SCALAR]] -// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE_ARG0]] -// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] -// CHECK-SCF-DAG: %[[FLAT_NON_SCALAR:.*]] = mhlo.dynamic_reshape %[[ARG0]], %[[FLAT_SHAPE]] -// CHECK-SCF-DAG: %[[SCALAR:.*]] = mhlo.reshape %[[ARG1]] -// CHECK-SCF-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FLAT_NON_SCALAR]], %[[ZERO]] {comparison_direction = #chlo} : (tensor, tensor) -// CHECK-SCF-DAG: %[[TMP0:.*]] = mhlo.log %[[SCALAR]] : tensor -// CHECK-SCF-DAG: %[[TMP1:.*]] = chlo.broadcast_multiply %[[FLAT_NON_SCALAR]], %[[TMP0]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_select %[[PRED]], %[[ZERO]], %[[TMP1]] : (tensor, tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] -// CHECK-SCF: scf.yield %[[INNER_RES_]] -// CHECK-SCF: else -// Equal shapes case: -// CHECK-SCF-DAG: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[SHAPE_ARG0]], %[[SHAPE_ARG1]] -// CHECK-SCF: %{{.*}} = scf.if %[[SHAPES_EQ]] -// CHECK-SCF-DAG: %[[SHAPE:.*]] = shape.any %[[SHAPE_ARG0]], %[[SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE]] -// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] -// CHECK-SCF-DAG: %[[FLAT_ARG0:.*]] = mhlo.dynamic_reshape %[[ARG0]], %[[FLAT_SHAPE]] -// CHECK-SCF-DAG: %[[FLAT_ARG1:.*]] = mhlo.dynamic_reshape %[[ARG1]], %[[FLAT_SHAPE]] -// CHECK-SCF-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FLAT_ARG0]], %[[ZERO]] {comparison_direction = #chlo} : (tensor, tensor) -// CHECK-SCF-DAG: %[[TMP0:.*]] = mhlo.log %[[FLAT_ARG1]] : tensor -// CHECK-SCF-DAG: %[[TMP1:.*]] = chlo.broadcast_multiply %[[FLAT_ARG0]], %[[TMP0]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_select %[[PRED]], %[[ZERO]], %[[TMP1]] : (tensor, tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] -// CHECK-SCF: scf.yield %[[INNER_RES_]] -// CHECK-SCF: else -// Find maximum reduced rank. -// CHECK-SCF-DAG: %[[REDUCED_SHAPES:.*]]:2 = chlo.minimum_broadcast_shapes %[[SHAPE_ARG0]], %[[SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[REDUCED_RANK0:.*]] = shape.rank %[[REDUCED_SHAPES]]#0 -// CHECK-SCF-DAG: %[[REDUCED_RANK1:.*]] = shape.rank %[[REDUCED_SHAPES]]#1 -// CHECK-SCF-DAG: %[[R0_GT_R1:.*]] = arith.cmpi sgt, %[[REDUCED_RANK0]], %[[REDUCED_RANK1]] -// CHECK-SCF-DAG: %[[MAX_RED_RANK:.*]] = arith.select %[[R0_GT_R1]], %[[REDUCED_RANK0]], %[[REDUCED_RANK1]] -// Generic case 1: -// CHECK-SCF: %[[MAX_RED_RANK_LE_1:.*]] = arith.cmpi ule, %[[MAX_RED_RANK]], %[[C1]] -// CHECK-SCF: %{{.*}} = scf.if %[[MAX_RED_RANK_LE_1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = mhlo.dynamic_reshape %[[ARG0]], %[[EXT_SHAPE_ARG0_]] -// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = mhlo.dynamic_reshape %[[ARG1]], %[[EXT_SHAPE_ARG1_]] -// CHECK-SCF-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[REDUCED_ARG0]], %[[ZERO]] {comparison_direction = #chlo} : (tensor, tensor) -// CHECK-SCF-DAG: %[[TMP0:.*]] = mhlo.log %[[REDUCED_ARG1]] : tensor -// CHECK-SCF-DAG: %[[TMP1:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[TMP0]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_select %[[PRED]], %[[ZERO]], %[[TMP1]] : (tensor, tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] -// CHECK-SCF: scf.yield %[[INNER_RES_]] -// CHECK-SCF: else -// ... -// Reshape the result. -// CHECK-SCF: %[[S0:.*]] = shape.shape_of %[[ARG0]] -// CHECK-SCF: %[[S0_:.*]] = shape.shape_of %[[ARG0]] -// CHECK-SCF: %[[S1:.*]] = shape.shape_of %[[ARG1]] -// CHECK-SCF: %[[TMP:.*]] = shape.broadcast %[[S0_]], %[[S1]] -// CHECK-SCF: %[[RES_SHAPE:.*]] = shape.broadcast %[[S0]], %[[TMP]] -// CHECK-SCF: %[[RES:.*]] = mhlo.dynamic_reshape %[[UNSHAPED_RES]], %[[RES_SHAPE]] : (tensor<*xf32>, tensor) -> tensor<*xf32> -// CHECK-SCF: return %[[RES]] - -// ----- - -// CHECK-LABEL: @mul -// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>) -func.func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> { - // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG0]], %[[ARG1]]) - // CHECK: ^bb0(%[[ARG0_:.*]]: tensor<*xf32>, %[[ARG1_:.*]]: tensor<*xf32>): - // CHECK: %[[TMP:.*]] = chlo.broadcast_multiply %[[ARG0_]], %[[ARG1_]] - // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP]]) - // CHECK: return %[[RES]] - %0 = chlo.broadcast_multiply %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -// CHECK-SCF-LABEL: @mul -// CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>) -// CHECK-SCF-DAG: %[[C1:.*]] = arith.constant 1 -// CHECK-SCF-DAG: %[[C2:.*]] = arith.constant 2 -// CHECK-SCF-DAG: %[[C3:.*]] = arith.constant 3 -// CHECK-SCF-DAG: %[[ONE_SHAPE_1:.*]] = shape.const_shape [1] -// CHECK-SCF-DAG: %[[ONE_SHAPE_2:.*]] = shape.const_shape [1, 1] -// CHECK-SCF-DAG: %[[ONE_SHAPE_3:.*]] = shape.const_shape [1, 1, 1] -// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]] -// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]] -// Lhs scalar case: -// CHECK-SCF-DAG: %[[LHS_N:.*]] = shape.num_elements %[[SHAPE_ARG0]] -// CHECK-SCF-DAG: %[[LHS_SCALAR:.*]] = arith.cmpi eq, %[[LHS_N]], %[[C1]] -// CHECK-SCF: %[[UNSHAPED_RES_LHS_SCALAR:.*]] = scf.if %[[LHS_SCALAR]] -// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] -// CHECK-SCF-DAG: %[[FLAT_NON_SCALAR:.*]] = mhlo.dynamic_reshape %[[ARG1]], %[[FLAT_SHAPE]] -// CHECK-SCF-DAG: %[[SCALAR:.*]] = mhlo.reshape %[[ARG0]] -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[SCALAR]], %[[FLAT_NON_SCALAR]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] -// CHECK-SCF: scf.yield %[[INNER_RES_]] -// CHECK-SCF: else -// Rhs scalar case: -// CHECK-SCF-DAG: %[[RHS_N:.*]] = shape.num_elements %[[SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[RHS_SCALAR:.*]] = arith.cmpi eq, %[[RHS_N]], %[[C1]] -// CHECK-SCF: %[[UNSHAPED_RES_RHS_SCALAR:.*]] = scf.if %[[RHS_SCALAR]] -// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE_ARG0]] -// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] -// CHECK-SCF-DAG: %[[FLAT_NON_SCALAR:.*]] = mhlo.dynamic_reshape %[[ARG0]], %[[FLAT_SHAPE]] -// CHECK-SCF-DAG: %[[SCALAR:.*]] = mhlo.reshape %[[ARG1]] -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[FLAT_NON_SCALAR]], %[[SCALAR]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] -// CHECK-SCF: scf.yield %[[INNER_RES_]] -// CHECK-SCF: else -// Equal shapes case: -// CHECK-SCF-DAG: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[SHAPE_ARG0]], %[[SHAPE_ARG1]] -// CHECK-SCF: %[[UNSHAPED_RES_EQ_SHAPES:.*]] = scf.if %[[SHAPES_EQ]] -// CHECK-SCF-DAG: %[[SHAPE:.*]] = shape.any %[[SHAPE_ARG0]], %[[SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE]] -// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] -// CHECK-SCF-DAG: %[[FLAT_ARG0:.*]] = mhlo.dynamic_reshape %[[ARG0]], %[[FLAT_SHAPE]] -// CHECK-SCF-DAG: %[[FLAT_ARG1:.*]] = mhlo.dynamic_reshape %[[ARG1]], %[[FLAT_SHAPE]] -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[FLAT_ARG0]], %[[FLAT_ARG1]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] -// CHECK-SCF: scf.yield %[[INNER_RES_]] -// CHECK-SCF: else -// Find maximum reduced rank. -// CHECK-SCF-DAG: %[[REDUCED_SHAPES:.*]]:2 = chlo.minimum_broadcast_shapes %[[SHAPE_ARG0]], %[[SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[REDUCED_RANK0:.*]] = shape.rank %[[REDUCED_SHAPES]]#0 -// CHECK-SCF-DAG: %[[REDUCED_RANK1:.*]] = shape.rank %[[REDUCED_SHAPES]]#1 -// CHECK-SCF-DAG: %[[R0_GT_R1:.*]] = arith.cmpi sgt, %[[REDUCED_RANK0]], %[[REDUCED_RANK1]] -// CHECK-SCF-DAG: %[[MAX_RED_RANK:.*]] = arith.select %[[R0_GT_R1]], %[[REDUCED_RANK0]], %[[REDUCED_RANK1]] -// Generic case 1: -// CHECK-SCF: %[[MAX_RED_RANK_LE_1:.*]] = arith.cmpi ule, %[[MAX_RED_RANK]], %[[C1]] -// CHECK-SCF: %[[UNSHAPED_RES_1:.*]] = scf.if %[[MAX_RED_RANK_LE_1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = mhlo.dynamic_reshape %[[ARG0]], %[[EXT_SHAPE_ARG0_]] -// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = mhlo.dynamic_reshape %[[ARG1]], %[[EXT_SHAPE_ARG1_]] -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] -// CHECK-SCF: scf.yield %[[INNER_RES_]] -// CHECK-SCF: else -// Generic case 2: -// CHECK-SCF: %[[MAX_RED_RANK_LE_2:.*]] = arith.cmpi ule, %[[MAX_RED_RANK]], %[[C2]] -// CHECK-SCF: %[[UNSHAPED_RES_2:.*]] = scf.if %[[MAX_RED_RANK_LE_2]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_2]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_2]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = mhlo.dynamic_reshape %[[ARG0]], %[[EXT_SHAPE_ARG0_]] -// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = mhlo.dynamic_reshape %[[ARG1]], %[[EXT_SHAPE_ARG1_]] -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] -// CHECK-SCF: scf.yield %[[INNER_RES_]] -// CHECK-SCF: else -// Generic case 3: -// CHECK-SCF: %[[MAX_RED_RANK_LE_3:.*]] = arith.cmpi ule, %[[MAX_RED_RANK]], %[[C3]] -// CHECK-SCF: assert %[[MAX_RED_RANK_LE_3]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 3" -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_3]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_3]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = mhlo.dynamic_reshape %[[ARG0]], %[[EXT_SHAPE_ARG0_]] -// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = mhlo.dynamic_reshape %[[ARG1]], %[[EXT_SHAPE_ARG1_]] -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] -// CHECK-SCF: scf.yield %[[INNER_RES_]] -// CHECK-SCF: scf.yield %[[UNSHAPED_RES_2]] -// CHECK-SCF: scf.yield %[[UNSHAPED_RES_1]] -// CHECK-SCF: scf.yield %[[UNSHAPED_RES_EQ_SHAPES]] -// CHECK-SCF: scf.yield %[[UNSHAPED_RES_RHS_SCALAR]] -// Reshape the result. -// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]] -// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]] -// CHECK-SCF-DAG: %[[RES_SHAPE:.*]] = shape.broadcast %[[SHAPE_ARG0]], %[[SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[RES:.*]] = mhlo.dynamic_reshape %[[UNSHAPED_RES_LHS_SCALAR]], %[[RES_SHAPE]] -// CHECK-SCF: return %[[RES]] - -// ----- - -// CHECK-LABEL: @merge_clusters -// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf64>, %[[ARG1:.*]]: tensor<*xf64>) -func.func @merge_clusters(%arg0: tensor<*xf64>, %arg1 : tensor<*xf64>) - -> tensor<*xf64> { - // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG0]], %[[ARG1]]) - // CHECK: ^bb0(%[[ARG0_:.*]]: tensor<*xf64>, %[[ARG1_:.*]]: tensor<*xf64>): - // CHECK: %[[TMP0:.*]] = mhlo.tanh %[[ARG0_]] - // CHECK: %[[TMP1:.*]] = chlo.broadcast_add %[[TMP0]], %[[ARG0_]] - // CHECK: %[[TMP2:.*]] = chlo.broadcast_add %[[TMP1]], %[[ARG1_]] - // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP2]]) - // CHECK: return %[[RES]] - %0 = "chlo.rank_specialization_cluster"(%arg0) ({ - ^bb0(%arg0_: tensor<*xf64>): - %1 = mhlo.tanh %arg0_ : (tensor<*xf64>) -> tensor<*xf64> - "chlo.rank_specialization_cluster_yield"(%1) : (tensor<*xf64>) -> () - }) : (tensor<*xf64>) -> (tensor<*xf64>) - %2 = "chlo.rank_specialization_cluster"(%0, %arg0, %arg1) ({ - ^bb0(%3: tensor<*xf64>, %4: tensor<*xf64>, %5: tensor<*xf64>): - %6 = "chlo.broadcast_add"(%3, %4) - : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64> - %7 = "chlo.broadcast_add"(%6, %5) - : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64> - "chlo.rank_specialization_cluster_yield"(%7) : (tensor<*xf64>) -> () - }) : (tensor<*xf64>, tensor<*xf64>, tensor<*xf64>) -> (tensor<*xf64>) - func.return %2 : tensor<*xf64> -} - -// ----- - -// CHECK-LABEL: @all_equal_shapes_inferrable -// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf64>, %[[ARG1:.*]]: tensor<*xf64>) -func.func @all_equal_shapes_inferrable(%arg0: tensor<*xf64>, %arg1 : tensor<*xf64>) - -> tensor<*xf64> { - // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG0]], %[[ARG1]]) - // CHECK: ^bb0(%[[ARG0_:.*]]: tensor<*xf64>, %[[ARG1_:.*]]: tensor<*xf64>) - // CHECK: %[[INNER_RES:.*]] = mhlo.add %[[ARG0_]], %[[ARG1_]] - // CHECK: "chlo.rank_specialization_cluster_yield"(%[[INNER_RES]]) - // CHECK: return %[[RES]] - %0 = "mhlo.add"(%arg0, %arg1) - : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64> - func.return %0 : tensor<*xf64> -} - -// CHECK-SCF-LABEL: @all_equal_shapes_inferrable -// CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf64>, %[[ARG1:.*]]: tensor<*xf64>) -// CHECK-SCF-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]] -// CHECK-SCF-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]] -// CHECK-SCF-DAG: %[[S:.*]] = shape.any %[[S0]], %[[S1]] -// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[S]] -// CHECK-SCF-DAG: %[[FLAT_S:.*]] = tensor.from_elements %[[N]] -// CHECK-SCF-DAG: %[[FLAT0:.*]] = mhlo.dynamic_reshape %[[ARG0]], %[[FLAT_S]] -// CHECK-SCF-DAG: %[[FLAT1:.*]] = mhlo.dynamic_reshape %[[ARG1]], %[[FLAT_S]] -// CHECK-SCF: %[[FLAT_RES:.*]] = mhlo.add %[[FLAT0]], %[[FLAT1]] -// CHECK-SCF-DAG: %[[RES:.*]] = mhlo.dynamic_reshape %[[FLAT_RES]], %[[S0]] -// CHECK-SCF: return %[[RES]] - -// ----- - -// All shapes are equal, which is inferrable through the select op. -// CHECK-LABEL: @relu_grad -// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>) -func.func @relu_grad(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG1]], %[[ARG0]]) - // CHECK: ^bb0(%[[ARG1_:.*]]: tensor<*xf32>, %[[ARG0_:.*]]: tensor<*xf32>) - // CHECK: %[[TMP0:.*]] = "chlo.constant_like"(%[[ARG0_]]) {value = 0.0{{.*}}e+00 : f32} - // CHECK: %[[TMP1:.*]] = mhlo.compare GT, %[[ARG0_]], %[[TMP0]] - // CHECK: %[[TMP2:.*]] = mhlo.select %[[TMP1]], %[[ARG1_]], %[[TMP0]] - // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP2]]) - // CHECK: return %[[RES]] - %0 = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> - %1 = "mhlo.compare"(%arg0, %0) {comparison_direction = #mhlo} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xi1> - %2 = "mhlo.select"(%1, %arg1, %0) : (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - func.return %2 : tensor<*xf32> -} - -// CHECK-SCF-LABEL: @relu_grad -// CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>) -// CHECK-SCF-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]] -// CHECK-SCF-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]] -// CHECK-SCF-DAG: %[[S:.*]] = shape.any %[[S1]], %[[S0]] -// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[S]] -// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] -// CHECK-SCF-DAG: %[[FLAT0:.*]] = mhlo.dynamic_reshape %[[ARG0]], %[[FLAT_SHAPE]] -// CHECK-SCF-DAG: %[[FLAT1:.*]] = mhlo.dynamic_reshape %[[ARG1]], %[[FLAT_SHAPE]] -// CHECK-SCF-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%[[FLAT0]]) {value = 0.0{{.*}}+00 : f32} -// CHECK-SCF-DAG: %[[PRED:.*]] = mhlo.compare GT, %[[FLAT0]], %[[ZERO]] -// CHECK-SCF: %[[UNSHAPED_RES:.*]] = mhlo.select %[[PRED]], %[[FLAT1]], %[[ZERO]] -// CHECK-SCF-DAG: %[[RES:.*]] = mhlo.dynamic_reshape %[[UNSHAPED_RES]], %[[S1]] -// CHECK-SCF: return %[[RES]] - -// ----- - -// Find shape equivalences through surrounding constraints. -// CHECK-LABEL: @relu_grad -// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>) -func.func @relu_grad(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { - // CHECK-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]] - // CHECK-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]] - // CHECK-DAG: %[[CSTR_EQ:.*]] = shape.cstr_eq %[[S0]], %[[S1]] - // CHECK: %[[RES:.*]] = shape.assuming %[[CSTR_EQ]] - // CHECK: %[[INNER_RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG1]], %[[ARG0]]) - // CHECK: ^bb0(%[[ARG1_:.*]]: tensor<*xf32>, %[[ARG0_:.*]]: tensor<*xf32>): - // CHECK-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%[[ARG0_]]) {value = 0.0{{.*}}+00 : f32} - // CHECK-DAG: %[[PRED:.*]] = mhlo.compare GT, %[[ARG0_]], %[[ZERO]] - // CHECK-DAG: %[[INNER_INNER_RES:.*]] = mhlo.select %[[PRED]], %[[ARG1_]], %[[ZERO]] - // CHECK: "chlo.rank_specialization_cluster_yield"(%[[INNER_INNER_RES]]) - // CHECK: shape.assuming_yield %[[INNER_RES]] - // CHECK: return %[[RES]] - %0 = shape.shape_of %arg0 : tensor<*xf32> -> tensor - %1 = shape.shape_of %arg1 : tensor<*xf32> -> tensor - %2 = shape.cstr_eq %0, %1 : tensor, tensor - %3 = shape.assuming %2 -> tensor<*xf32> { - %4 = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} - : (tensor<*xf32>) -> tensor<*xf32> - %5 = "mhlo.compare"(%arg0, %4) {comparison_direction = #mhlo} - : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xi1> - %6 = "mhlo.select"(%5, %arg1, %4) - : (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - shape.assuming_yield %6 : tensor<*xf32> - } - func.return %3 : tensor<*xf32> -} - -// CHECK-SCF-LABEL: @relu_grad -// CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>) -// CHECK-SCF-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]] -// CHECK-SCF-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]] -// CHECK-SCF-DAG: %[[CSTR_EQ:.*]] = shape.cstr_eq %0, %1 -// CHECK-SCF: %[[RES:.*]] = shape.assuming %[[CSTR_EQ]] -// CHECK-SCF-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]] -// CHECK-SCF-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]] -// CHECK-SCF-DAG: %[[S:.*]] = shape.any %[[S1]], %[[S0]] -// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[S]] -// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] -// CHECK-SCF-DAG: %[[FLAT0:.*]] = mhlo.dynamic_reshape %[[ARG0]], %[[FLAT_SHAPE]] -// CHECK-SCF-DAG: %[[FLAT1:.*]] = mhlo.dynamic_reshape %[[ARG1]], %[[FLAT_SHAPE]] -// CHECK-SCF-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%[[FLAT0]]) {value = 0.0{{.*}}+00 : f32} -// CHECK-SCF-DAG: %[[PRED:.*]] = mhlo.compare GT, %[[FLAT0]], %[[ZERO]] -// CHECK-SCF: %[[UNSHAPED_RES:.*]] = mhlo.select %[[PRED]], %[[FLAT1]], %[[ZERO]] -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = mhlo.dynamic_reshape %[[UNSHAPED_RES]], %[[S1]] -// CHECK-SCF: shape.assuming_yield %[[INNER_RES]] -// CHECK-SCF: return %[[RES]] diff --git a/third_party/xla/xla/mlir_hlo/tests/test_userange.mlir b/third_party/xla/xla/mlir_hlo/tests/test_userange.mlir deleted file mode 100644 index 88b01dcf8c9acb..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/test_userange.mlir +++ /dev/null @@ -1,118 +0,0 @@ -// RUN: mlir-hlo-opt -test-print-userange -split-input-file %s | FileCheck %s - -// CHECK-LABEL: Testing : func_empty -func.func @func_empty() { - func.return -} -// CHECK: ---- UserangeAnalysis ----- -// CHECK-NEXT: --------------------------- - -// ----- - -// CHECK-LABEL: Testing : useRangeGap -func.func @useRangeGap(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) -{ - %0 = memref.alloc() : memref<2xf32> - %1 = memref.alloc() : memref<2xf32> - cf.cond_br %arg0, ^bb1, ^bb2 -^bb1: - "lmhlo.negate"(%arg1, %0) : (memref<2xf32>, memref<2xf32>) -> () - "lmhlo.negate"(%arg1, %1) : (memref<2xf32>, memref<2xf32>) -> () - cf.br ^bb3 -^bb2: - "lmhlo.negate"(%arg2, %0) : (memref<2xf32>, memref<2xf32>) -> () - "lmhlo.negate"(%arg2, %1) : (memref<2xf32>, memref<2xf32>) -> () - cf.br ^bb3 -^bb3: - func.return -} -// CHECK: Value: %[[A0:.*]] = memref.alloc -// CHECK-NEXT: Userange: {(7, 7), (13, 13)} -// CHECK: Value: %[[A1:.*]] = memref.alloc -// CHECK-NEXT: Userange: {(9, 9), (15, 15)} -// CHECK: %[[A0]] = memref.alloc -// CHECK: %[[A1]] = memref.alloc - -// ----- - -// CHECK-LABEL: Testing : loopWithNestedRegion -func.func @loopWithNestedRegion(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) -{ - %0 = memref.alloc() : memref<2xf32> - %1 = memref.alloc() : memref<2xf32> - %2 = memref.alloc() : memref<2xf32> - %3 = memref.alloc() : memref<2xf32> - cf.br ^bb1 -^bb1: - %4 = scf.if %arg0 -> (memref<2xf32>) { - "lmhlo.negate"(%arg1, %0) : (memref<2xf32>, memref<2xf32>) -> () - scf.yield %2 : memref<2xf32> - } else { - "lmhlo.negate"(%arg1, %1) : (memref<2xf32>, memref<2xf32>) -> () - scf.yield %2 : memref<2xf32> - } - cf.br ^bb2 -^bb2: - cf.cond_br %arg0, ^bb1, ^bb3 -^bb3: - "lmhlo.negate"(%arg1, %2) : (memref<2xf32>, memref<2xf32>) -> () - "lmhlo.negate"(%arg1, %3) : (memref<2xf32>, memref<2xf32>) -> () - func.return -} -// CHECK: Value: %[[A0:.*]] = memref.alloc -// CHECK-NEXT: Userange: {(11, 23)} -// CHECK: Value: %[[A1:.*]] = memref.alloc -// CHECK-NEXT: Userange: {(11, 23)} -// CHECK: Value: %[[A2:.*]] = memref.alloc -// CHECK-NEXT: Userange: {(11, 25)} -// CHECK: Value: %[[A3:.*]] = memref.alloc -// CHECK-NEXT: Userange: {(27, 27)} -// CHECK: Value: %[[A4:.*]] = scf.if -// CHECK: Userange: {(19, 19)} -// CHECK: %[[A0]] = memref.alloc -// CHECK: %[[A1]] = memref.alloc -// CHECK: %[[A2]] = memref.alloc -// CHECK: %[[A3]] = memref.alloc -// CHECK: %[[A4]] = scf.if - -// ----- - -// CHECK-LABEL: Testing : condBranchWithAlias -func.func @condBranchWithAlias(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) -{ - %0 = memref.alloc() : memref<2xf32> - cf.cond_br %arg0, ^bb1, ^bb2 -^bb1: - "lmhlo.negate"(%arg1, %0) : (memref<2xf32>, memref<2xf32>) -> () - cf.br ^bb3(%0 : memref<2xf32>) -^bb2: - %1 = memref.alloc() : memref<2xf32> - "lmhlo.negate"(%arg1, %1) : (memref<2xf32>, memref<2xf32>) -> () - cf.br ^bb3(%1 : memref<2xf32>) -^bb3(%2 : memref<2xf32>): - %3 = memref.alloc() : memref<2xf32> - "lmhlo.copy"(%2, %arg2) : (memref<2xf32>, memref<2xf32>) -> () - "lmhlo.copy"(%3, %arg2) : (memref<2xf32>, memref<2xf32>) -> () - %4 = memref.alloc() : memref<2xf32> - "lmhlo.copy"(%4, %arg2) : (memref<2xf32>, memref<2xf32>) -> () - cf.br ^bb4(%0 : memref<2xf32>) -^bb4(%5 : memref<2xf32>): - "lmhlo.copy"(%5, %arg2) : (memref<2xf32>, memref<2xf32>) -> () - func.return -} -// CHECK: Value: %[[A0:.*]] = memref.alloc -// CHECK-NEXT: Userange: {(5, 7), (15, 27)} -// CHECK: Value: %[[A1:.*]] = memref.alloc -// CHECK-NEXT: Userange: {(11, 17)} -// CHECK: Value: %[[A2:.*]] = memref.alloc -// CHECK-NEXT: Userange: {(19, 19)} -// CHECK: Value: %[[A3:.*]] = memref.alloc -// CHECK-NEXT: Userange: {(23, 23)} -// CHECK: Value: of type 'memref<2xf32>' at index: 0 -// CHECK-SAME: Userange: {(15, 17)} -// CHECK: Value: of type 'memref<2xf32>' at index: 0 -// CHECK-SAME: Userange: {(27, 27)} -// CHECK: %[[A0]] = memref.alloc -// CHECK: %[[A1]] = memref.alloc -// CHECK: %[[A2]] = memref.alloc -// CHECK: %[[A3]] = memref.alloc diff --git a/third_party/xla/xla/mlir_hlo/tests/vectorize_copy.mlir b/third_party/xla/xla/mlir_hlo/tests/vectorize_copy.mlir index 8c57281a7041c9..fa852549c90578 100644 --- a/third_party/xla/xla/mlir_hlo/tests/vectorize_copy.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/vectorize_copy.mlir @@ -1,7 +1,6 @@ // RUN: mlir-hlo-opt %s --vectorize-copy --split-input-file | FileCheck %s -func.func @vectorize_copy(%arg: memref<2x2xf32>) -> memref<2x2xf32> { - %subview = memref.subview %arg[0, 0] [2, 2] [1, 1] : memref<2x2xf32> to memref<2x2xf32, strided<[16, 1]>> +func.func @vectorize_copy(%subview: memref<2x2xf32, strided<[16, 1]>>) -> memref<2x2xf32> { %alloc = memref.alloc() : memref<2x2xf32> memref.copy %subview, %alloc : memref<2x2xf32, strided<[16, 1]>> to memref<2x2xf32> return %alloc : memref<2x2xf32> diff --git a/third_party/xla/xla/mlir_hlo/tools/mlir-hlo-opt/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/tools/mlir-hlo-opt/CMakeLists.txt index db2dbf2e60ac47..48c132706fc2a5 100644 --- a/third_party/xla/xla/mlir_hlo/tools/mlir-hlo-opt/CMakeLists.txt +++ b/third_party/xla/xla/mlir_hlo/tools/mlir-hlo-opt/CMakeLists.txt @@ -28,9 +28,7 @@ set(LIBS LmhloGPUDialect LmhloPasses MLIRBufferTransforms - MLIRHLOAnalysis MLIRHLOGPUTransforms - MLIRHLOTestAnalysis MhloRegisterDialects MhloTestAnalysis ) diff --git a/third_party/xla/xla/mlir_hlo/transforms/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/transforms/CMakeLists.txt index f1e9d6004c967e..45c8f24a796f5f 100644 --- a/third_party/xla/xla/mlir_hlo/transforms/CMakeLists.txt +++ b/third_party/xla/xla/mlir_hlo/transforms/CMakeLists.txt @@ -24,7 +24,6 @@ add_public_tablegen_target(LMHLOGPUTransformsPassIncGen) add_mlir_library(MLIRBufferTransforms alloc_to_arg_pass.cc - buffer_packing.cc bufferize.cc bufferize_pass.cc collapse_parallel_loops_to_1d_pass.cc @@ -49,7 +48,6 @@ add_mlir_library(MLIRBufferTransforms LINK_LIBS PUBLIC ChloOps MLIRGPUDialect - MLIRHLOAnalysis MLIRIR MLIRMathTransforms MLIRPass @@ -75,7 +73,6 @@ add_mlir_library(MLIRHLOGPUTransforms LINK_LIBS PUBLIC MLIRArithTransforms MLIRGPUDialect - MLIRHLOAnalysis MLIRIR MLIRMemRefTransforms MLIRPass diff --git a/third_party/xla/xla/mlir_hlo/transforms/buffer_packing.cc b/third_party/xla/xla/mlir_hlo/transforms/buffer_packing.cc deleted file mode 100644 index 6bf8c1899608b2..00000000000000 --- a/third_party/xla/xla/mlir_hlo/transforms/buffer_packing.cc +++ /dev/null @@ -1,494 +0,0 @@ -/* Copyright 2021 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include -#include -#include -#include - -#include "analysis/userange_analysis.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" -#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/IR/Operation.h" -#include "mlir/Pass/Pass.h" -#include "transforms/passes.h" -#include "utils/hlo_utils.h" - -namespace mlir { - -#define GEN_PASS_DEF_BUFFERPACKING -#define GEN_PASS_DEF_MEMORYCOUNT -#include "transforms/passes.h.inc" - -namespace { - -/// Returns the length of an userange interval. -size_t computeUserangeSize(const UseInterval &interval) { - return interval.end - interval.start + 1; -} - -/// Compute the byte size of a given Value. -size_t computeByteSize(const Value &v) { - auto type = v.getType().cast(); - return type.getNumElements() * type.getElementTypeBitWidth() / 8; -} - -/// Compute the 64 byte alinged segments of a given Value. -size_t computeAlignedSegments(const Value &v) { - size_t padding = 64; - size_t bytes = computeByteSize(v); - return std::ceil(bytes / (double)padding); -} - -/// The buffer offset information. -struct AllocBufferOffset { - public: - AllocBufferOffset(Value source, size_t offset) - : source(source), offset(offset) {} - - Value source; - size_t offset; -}; - -/// Contains the information to create a new buffer, that is used to pack -/// other buffers. -struct PackedBuffer { - public: - PackedBuffer(size_t numSegments, - std::vector &packedBuffers) - : numSegments(numSegments), allocBufferOffsets(packedBuffers) {} - - size_t numSegments; - std::vector allocBufferOffsets; -}; - -/// Contains the information about a buffers allocation for sorting and checking -/// if it fits into other buffers and vise versa. -/// This structure contains the allocation value, the first and last userangeid -/// of a buffer, the window id, the number of alligned 64 byte segments and all -/// userange intervals. -struct AllocationInfo { - public: - AllocationInfo(Value alloc, size_t allocUserangeId, size_t firstUse, - size_t lastUse, size_t numSegments, size_t windowId, - const UseInterval::Vector *userangeIntervals) - : alloc(alloc), - allocUserangeId(allocUserangeId), - firstUse(firstUse), - lastUse(lastUse), - numSegments(numSegments), - windowId(windowId), - userangeIntervals(userangeIntervals) {} - - /// The allocation value. - Value alloc; - - /// The id of allocation based on the Userange Analysis. - size_t allocUserangeId; - - /// The first use of the buffer. - size_t firstUse; - - /// The last use of the buffer based on the Userange Analysis. - size_t lastUse; - - /// The number of 64 byte aligned segments of contigous memory. - size_t numSegments; - - /// The window id of the allocation position. - size_t windowId; - - /// The userange intervals of the buffer. - const UseInterval::Vector *userangeIntervals; - - /// Compute the gaps of the alloc userange with the number of segments. The - /// maxUserangeId is used to add a dummy gap from the last used id to the - /// maxUserangeId. By default the maxUserangeId is zero and no gap is added. - std::list> computeGaps( - size_t maxUserangeId = 0) { - std::list> gaps; - - // The previous gap ending, initially set to 0. - size_t gapEnd = 0; - - for (const auto *useRangeIter = userangeIntervals->begin(); - useRangeIter < userangeIntervals->end(); ++useRangeIter) { - // Add a gap if the end is not equal to the start. - if (gapEnd < useRangeIter->start) - gaps.emplace_back(UseInterval(gapEnd, useRangeIter->start - 1), - numSegments); - gapEnd = useRangeIter->end + 1; - } - - // Add a dummy gap behind the last use of the buffer. - if (gapEnd < maxUserangeId) { - gaps.emplace_back(UseInterval(gapEnd, maxUserangeId), numSegments); - } - - return gaps; - } - - /// Compute the userange size. - size_t getUserangeSize() const { return lastUse - firstUse + 1; } -}; - -// Comparator to sort allocation informations by window id, userange and by -// number of memory segments. -class AllocInfoWinIdComparator { - public: - bool operator()(const AllocationInfo &a, const AllocationInfo &b) { - if (a.windowId == b.windowId) { - if (a.allocUserangeId == b.allocUserangeId) - return a.numSegments > b.numSegments; - return a.allocUserangeId > b.allocUserangeId; - } - return a.windowId < b.windowId; - } -}; - -// Comparator to sort the allocation informations by number of segments. -class AllocInfoMemSizeCompare { - public: - bool operator()(const AllocationInfo &a, const AllocationInfo &b) { - return a.numSegments > b.numSegments; - } -}; - -/// This approach computes an allocation information list and sorts it by -/// a given comparator. From top to bottom the algortihm tries to fill userange -/// gaps with appropriate buffers behind it, to optimze the memory. It is a bin -/// packing approach. -template -class SortedPackingStrategy { - public: - using AllocInfoList = std::vector; - - public: - /// Constructs the Sorted Packing Strategy. The window size is used as sliding - /// window size. Allocation userangepositions that are in the same range are - /// mapped to the same window id. So the information of the allocation - /// starting position is blured. - SortedPackingStrategy(size_t windowSize, CompareT compare) - : windowSize(windowSize), compare(compare) {} - - /// Optimize the buffer allocations. - void optimze(const mlir::bufferization::BufferPlacementAllocs &allocs, - const UserangeAnalysis &userangeAnalysis, - std::vector &packedBuffers) { - AllocInfoList allocInfos; - allocInfos.reserve(std::distance(allocs.begin(), allocs.end())); - - // Create allocInformations and store them in allocInfos. - size_t maxUserangeId = - computeAllocationInfos(allocInfos, userangeAnalysis, allocs); - - // Sort the allocation infos. - std::sort(allocInfos.begin(), allocInfos.end(), compare); - - for (auto currentIter = allocInfos.begin(); currentIter != allocInfos.end(); - ++currentIter) { - std::vector allocBufferOffsets{ - AllocBufferOffset(currentIter->alloc, 0)}; - - // Compute userange gaps. - std::list> gaps = - currentIter->computeGaps(maxUserangeId); - - if (gaps.empty()) continue; - - for (auto checkedAllocInfoIter = std::next(currentIter); - checkedAllocInfoIter != allocInfos.end();) { - // Check if a gap exists to pack the memory into. - // If not continue. - if (!findGapAndUpdate(gaps, allocBufferOffsets, *checkedAllocInfoIter, - *currentIter)) { - ++checkedAllocInfoIter; - continue; - } - checkedAllocInfoIter = allocInfos.erase(checkedAllocInfoIter); - } - // Add the current buffer offets to the packed infos. - packedBuffers.emplace_back(currentIter->numSegments * 64, - allocBufferOffsets); - } - } - - private: - const size_t windowSize; - const CompareT compare; - - /// We try to find an appropriate userange gap to pack the buffer into it. - /// If we find one we update only the gaps and the buffer offset map. - bool findGapAndUpdate(std::list> &gaps, - std::vector &allocBufferOffsets, - const AllocationInfo &allocToPack, - const AllocationInfo &allocToPackInto) { - // Check if the buffer to pack into has enough memory. - if (allocToPackInto.numSegments < allocToPack.numSegments) return false; - for (auto gapIter = gaps.begin(); gapIter != gaps.end();) { - // The list is sorted, so we can break here. - if (gapIter->first.start > allocToPack.firstUse) break; - - // Checks if enough contiguous memory segments are free or if the current - // gap is out of bounds. - if (gapIter->second < allocToPack.numSegments || - allocToPack.firstUse < gapIter->first.start || - allocToPack.lastUse > gapIter->first.end) { - ++gapIter; - continue; - } - - // Stores the packed buffer with the offset. - allocBufferOffsets.emplace_back( - allocToPack.alloc, - (allocToPackInto.numSegments - gapIter->second) * 64); - - // Update gap segments, will removed later if no free contigous memory - // exists. It is needed to split the interval, if not the full gap is - // used. - size_t freeContiguousMemory = gapIter->second; - gapIter->second = freeContiguousMemory - allocToPack.numSegments; - - // Check if the gap must be splitted. If so, then the current gap must be - // trimmed accordingly. Therefore, new gaps are created in front and after - // the current gap. - if (computeUserangeSize(gapIter->first) > allocToPack.getUserangeSize()) { - size_t oldStart = gapIter->first.start; - size_t oldEnd = gapIter->first.end; - gapIter->first.end = allocToPack.lastUse; - gapIter->first.start = allocToPack.firstUse; - - // Insert a new gap behind. - if (allocToPack.lastUse < oldEnd) - gaps.insert( - std::next(gapIter), - std::make_pair(UseInterval(allocToPack.lastUse + 1, oldEnd), - freeContiguousMemory)); - // Insert a new gap before. - if (allocToPack.firstUse > oldStart) - gaps.insert( - gapIter, - std::make_pair(UseInterval(oldStart, allocToPack.firstUse - 1), - freeContiguousMemory)); - } - - // If a gap interval has no free contiguous memory anymore, erease it from - // list. - if (gapIter->second <= 0) gapIter = gaps.erase(gapIter); - - return true; - } - return false; - } - - /// Aggreagtes the allocation informations of the allocs and returns the - /// maximal userange. - size_t computeAllocationInfos( - AllocInfoList &allocInfos, const UserangeAnalysis &userangeAnalysis, - const mlir::bufferization::BufferPlacementAllocs &allocs) { - // Create allocInformations and store them in allocInfos. - size_t maxUserangeId = 0; - - for (auto &allocEntry : allocs) { - Value v = std::get<0>(allocEntry); - auto userangeIntervals = userangeAnalysis.getUserangeInterval(v); - - if (!userangeIntervals) continue; - - // Computes the userange id of the allocation. - size_t allocUserangeId = userangeAnalysis.computeId(v, v.getDefiningOp()); - - // Computes the last use of the allocated buffer. - size_t lastUse = std::prev((*userangeIntervals.value()).end())->end; - - // Computes the first use of the allocated buffer. - size_t firstUse = (*userangeIntervals.value()).begin()->start; - - // Computes the number of aligend segments of the buffer. - size_t numSegments = computeAlignedSegments(v); - maxUserangeId = std::max(maxUserangeId, lastUse); - allocInfos.emplace_back(v, allocUserangeId, firstUse, lastUse, - numSegments, 0, userangeIntervals.value()); - } - - // If the window size is zero we need no sorting anymore. - if (windowSize == 0) return maxUserangeId; - // Sorts the allocation informations to compute the window id. The window id - // is used to blur the userange starting position of an allocation. - std::sort(allocInfos.begin(), allocInfos.end(), - [](const AllocationInfo &a, const AllocationInfo &b) { - return a.allocUserangeId < b.allocUserangeId; - }); - - // resize window id - size_t windowId = 0; - size_t lastAllocUserangeId = 0; - for (auto &allocationInfo : allocInfos) { - if (allocationInfo.allocUserangeId > lastAllocUserangeId + windowSize) - ++windowId; - - lastAllocUserangeId = allocationInfo.allocUserangeId; - allocationInfo.windowId = windowId; - } - return maxUserangeId; - } -}; - -/// Pass to pack buffer together to optimize the memeory consumption and to -/// save allocation operations. A strategy must be passed as a template -/// argument. -class BufferPacking : bufferization::BufferPlacementTransformationBase { - public: - template - BufferPacking(Operation *op, StrategyT strategy) - : BufferPlacementTransformationBase(op), - userangeAnalysis(op, allocs, aliases), - dominators(op) { - std::vector packedBuffers; - strategy.optimze(allocs, userangeAnalysis, packedBuffers); - - for (auto &packedBuffer : packedBuffers) { - // Find common dominators. - Block *block = findAllocationsDominator(packedBuffer.allocBufferOffsets); - // Find alloc position operation. - mlir::OpBuilder packBuilder(&(block->front())); - auto location = block->front().getLoc(); - auto memrefType = - MemRefType::get({static_cast(packedBuffer.numSegments)}, - packBuilder.getIntegerType(8)); - Value targetBuffer = - packBuilder.create(location, memrefType); - - for (auto &packInfo : packedBuffer.allocBufferOffsets) { - Value currentAlloc = packInfo.source; - size_t offset = packInfo.offset; - Operation *viewDefOp = currentAlloc.getDefiningOp(); - Location loc = viewDefOp->getLoc(); - mlir::OpBuilder viewBuilder(viewDefOp); - - // Create a arithmetic ConstantOp with the aligned offset. - Value constantOp = viewBuilder.create( - loc, viewBuilder.getIndexType(), - viewBuilder.getIntegerAttr(viewBuilder.getIndexType(), offset)); - - // Store the operands for the ViewOp. - SmallVector newOperands{targetBuffer}; - newOperands.push_back(constantOp); - - auto shape = currentAlloc.getType().cast(); - - // Create a ViewOp with the shape of the old alloc and use the created - // packed alloc and the constant for the operands. - Value viewOp = - viewBuilder.create(loc, shape, newOperands); - - // Replace all old allocs references with the created ViewOp and - // afterwards remove the old allocs. - currentAlloc.replaceAllUsesWith(viewOp); - viewDefOp->erase(); - } - } - } - - private: - UserangeAnalysis userangeAnalysis; - /// The current dominance info. - DominanceInfo dominators; - - /// Find the block that dominates all buffer allocations. - Block *findAllocationsDominator( - const std::vector &packingInfos) { - SmallPtrSet allocValues; - for (auto &packInfo : packingInfos) { - allocValues.insert(packInfo.source); - } - - // Find common dominators. - return bufferization::findCommonDominator(packingInfos.begin()->source, - allocValues, dominators); - } -}; - -/// Tries to pack allocated buffer together to save allocation operations and -/// memory. The window size is used as sliding window size. Allocation -/// userangepoitions that are in the same range are mapped to the same window -/// id. The information of the allocation starting position is blured. -struct BufferPackingPass : public impl::BufferPackingBase { - explicit BufferPackingPass(unsigned windowSize) { - this->window_size_ = windowSize; - } - - void runOnOperation() override { - if (window_size_ == 0) { - SortedPackingStrategy strategy( - window_size_, AllocInfoMemSizeCompare()); - BufferPacking packing(getOperation(), strategy); - } else { - SortedPackingStrategy strategy( - window_size_, AllocInfoWinIdComparator()); - BufferPacking packing(getOperation(), strategy); - } - } -}; - -/// Pass to find all allocations and to compute memory usage. -struct MemoryCountPass : impl::MemoryCountBase { - void runOnOperation() override { - Operation *op = getOperation(); - std::vector allocs; - op->walk([&](MemoryEffectOpInterface opInterface) { - // Try to find a single allocation result. - SmallVector effects; - opInterface.getEffects(effects); - - SmallVector allocateResultEffects; - llvm::copy_if( - effects, std::back_inserter(allocateResultEffects), - [=](MemoryEffects::EffectInstance &it) { - Value value = it.getValue(); - return isa(it.getEffect()) && value && - value.isa() && - it.getResource() != - SideEffects::AutomaticAllocationScopeResource::get(); - }); - - if (allocateResultEffects.size() != 1) return; - // Insert allocation. - allocs.push_back(allocateResultEffects[0].getValue()); - }); - auto output = mlir::hlo::computeMemory(allocs); - llvm::outs() << "Memory Count Pass:\n" - << output.first << ";" << output.second << "\n"; - } -}; - -} // namespace - -std::unique_ptr> createBufferPackingPass( - unsigned windowSize) { - return std::make_unique(windowSize); -} - -std::unique_ptr> createMemoryCountPass() { - return std::make_unique(); -} - -} // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/transforms/passes.h b/third_party/xla/xla/mlir_hlo/transforms/passes.h index 6012a587d50e64..a668cc250dd81c 100644 --- a/third_party/xla/xla/mlir_hlo/transforms/passes.h +++ b/third_party/xla/xla/mlir_hlo/transforms/passes.h @@ -44,7 +44,6 @@ using BufferizePatternsCallback = std::function> createBufferPackingPass( - unsigned windowSize = 5); - -/// Creates a pass that tests the useranges of the UserangeAnalysis. -std::unique_ptr> createTestUserangePass(); - /// Creates a pass that prints the analysis results of ShapeComponentsAnalysis. std::unique_ptr> createTestShapeComponentAnalysisPass(); -/// Creates a pass that computes the allocated memory. -std::unique_ptr> createMemoryCountPass(); - // Pass to lower index cast on tensors to tensor dialect. std::unique_ptr> createLowerIndexCastPass(); diff --git a/third_party/xla/xla/mlir_hlo/transforms/passes.td b/third_party/xla/xla/mlir_hlo/transforms/passes.td index 2d979c8508eb35..e5a36ae6ca03af 100644 --- a/third_party/xla/xla/mlir_hlo/transforms/passes.td +++ b/third_party/xla/xla/mlir_hlo/transforms/passes.td @@ -18,26 +18,6 @@ limitations under the License. include "mlir/Pass/PassBase.td" -def BufferPacking : Pass<"buffer-packing", "func::FuncOp"> { - let summary = "Pass to pack allocated buffer to reduce memory consumption."; - let description = [{The pass tries to pack smaller buffers into larger buffers. - To do this, it sorts all allocated buffers by multiple criteria depends on the - selected window-size. - After this sorting, the buffers are checked whether subsequent buffers can be - packed into them.}]; - let dependentDialects = ["func::FuncDialect","memref::MemRefDialect", - "arith::ArithDialect"]; - let constructor = "createBufferPackingPass()"; - let options = [ - Option<"window_size_", "window-size", "unsigned", - /*default=*/"5", "The window size blurs the start position of an" - "allocated buffer. Buffers allocated in the same sliding window area" - "are treated equally in terms of starting position, withing the" - "sliding window area they are sorted by memory size." - "A window size of zero sorts the buffers only by memory size.">, - ]; -} - def CollapseParallelLoopsTo1DPass : Pass<"collapse-parallel-loops-to-1d"> { let summary = "Collapses multidimensional loops."; let description = [{ The pass converts a multidimensional `scf.parallel` loop @@ -71,18 +51,6 @@ def TileLoopsPass : Pass<"tile-loops", "func::FuncOp"> { let dependentDialects = ["affine::AffineDialect"]; } -def MemoryCount : Pass<"memory-count", "func::FuncOp"> { - let summary = "Test pass to count the allocated memory of a module."; - let description = [{A test pass that prints the size of allocated memory of a - module.}]; - let constructor = "createMemoryCountPass()"; -} - -def TestUserange : Pass<"test-print-userange", "func::FuncOp"> { - let summary = "Test pass for checking userange intervals."; - let constructor = "createTestUserangePass()"; -} - def TestShapeComponentAnalysis : Pass<"test-print-shape-components", "func::FuncOp"> { let summary = "Test pass for analyzing shape components."; diff --git a/third_party/xla/xla/mlir_hlo/transforms/rewriters.h b/third_party/xla/xla/mlir_hlo/transforms/rewriters.h index 8e5fc500cebe24..6e484b22370507 100644 --- a/third_party/xla/xla/mlir_hlo/transforms/rewriters.h +++ b/third_party/xla/xla/mlir_hlo/transforms/rewriters.h @@ -31,11 +31,6 @@ void populateExtraBufferizePatterns( MLIRContext *context, bufferization::BufferizeTypeConverter *converter, RewritePatternSet *patterns); -/// Populate pattern to bufferize `linalg.tiled_loop`. -void populateTiledLoopBufferizePattern( - MLIRContext *context, bufferization::BufferizeTypeConverter *converter, - RewritePatternSet *patterns); - } // namespace mlir #endif // MLIR_HLO_TRANSFORMS_REWRITERS_H diff --git a/third_party/xla/xla/mlir_hlo/utils/hlo_utils.cc b/third_party/xla/xla/mlir_hlo/utils/hlo_utils.cc index 96a67fca751b8e..fc02ada8840de2 100644 --- a/third_party/xla/xla/mlir_hlo/utils/hlo_utils.cc +++ b/third_party/xla/xla/mlir_hlo/utils/hlo_utils.cc @@ -24,14 +24,19 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" namespace mlir { namespace hlo { static constexpr size_t kPaddingSize = 64; -DenseIntElementsAttr getBroadcastDimensionsAttr(Builder* b, Value x, Value y, - bool allowEmpty) { +DenseI64ArrayAttr getBroadcastDimensionsAttr(Builder* b, Value x, Value y, + bool allowEmpty) { TensorType xType = x.getType().dyn_cast(); TensorType yType = y.getType().dyn_cast(); if (!xType || !yType) return {}; @@ -57,9 +62,7 @@ DenseIntElementsAttr getBroadcastDimensionsAttr(Builder* b, Value x, Value y, std::iota(broadcastDimensions.begin(), broadcastDimensions.end(), maxRank - minRank); - RankedTensorType type = - RankedTensorType::get({minRank}, b->getIntegerType(64)); - return DenseIntElementsAttr::get(type, broadcastDimensions); + return b->getDenseI64ArrayAttr(broadcastDimensions); } DenseElementsAttr getScalarOfType(Type ty, int64_t rawValue) { diff --git a/third_party/xla/xla/mlir_hlo/utils/hlo_utils.h b/third_party/xla/xla/mlir_hlo/utils/hlo_utils.h index d4a3fb141f2597..72f22992b4a944 100644 --- a/third_party/xla/xla/mlir_hlo/utils/hlo_utils.h +++ b/third_party/xla/xla/mlir_hlo/utils/hlo_utils.h @@ -36,10 +36,9 @@ namespace hlo { // between two ranked tensors. // If `allow_empty` is true, then null can be returned to mean that the // broadcast is an "identity". -mlir::DenseIntElementsAttr getBroadcastDimensionsAttr(mlir::Builder* b, - mlir::Value x, - mlir::Value y, - bool allowEmpty = true); +mlir::DenseI64ArrayAttr getBroadcastDimensionsAttr(mlir::Builder* b, + mlir::Value x, mlir::Value y, + bool allowEmpty = true); // Get a constant splat for the given value of type. Requires value to be of // type static shaped RankedTensorType. diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index c12143b22c4644..4dca26b8a5c6a6 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -1,5 +1,6 @@ # Placeholder: load py_proto_library load("//xla:xla.bzl", "xla_cc_test") +load("@local_tsl//tsl:tsl.bzl", "internal_visibility") load( "@local_tsl//tsl/platform:build_config.bzl", "tf_proto_library", @@ -7,7 +8,8 @@ load( load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility(["//xla:internal"]), licenses = ["notice"], ) @@ -28,7 +30,6 @@ cc_library( name = "worker_thread", srcs = ["worker_thread.cc"], hdrs = ["worker_thread.h"], - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform:env", @@ -39,7 +40,6 @@ cc_library( name = "event_pool", srcs = ["event_pool.cc"], hdrs = ["event_pool.h"], - visibility = ["//visibility:public"], deps = [ "//xla:status_macros", "//xla:statusor", @@ -53,7 +53,6 @@ cc_library( name = "semaphore", srcs = ["semaphore.cc"], hdrs = ["semaphore.h"], - visibility = ["//visibility:public"], deps = [ "//xla:types", "@com_google_absl//absl/synchronization", @@ -77,7 +76,6 @@ cc_library( name = "tracked_device_buffer", srcs = ["tracked_device_buffer.cc"], hdrs = ["tracked_device_buffer.h"], - visibility = ["//visibility:public"], deps = [ ":event_pool", ":local_device_state", @@ -118,7 +116,6 @@ cc_library( name = "local_device_state", srcs = ["local_device_state.cc"], hdrs = ["local_device_state.h"], - visibility = ["//visibility:public"], deps = [ ":event_pool", ":pjrt_common", @@ -128,7 +125,9 @@ cc_library( "//xla:util", "//xla/client:local_client", "//xla/stream_executor", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/synchronization", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/profiler/lib:traceme", "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", ], @@ -138,7 +137,6 @@ cc_library( name = "pjrt_api", srcs = ["pjrt_api.cc"], hdrs = ["pjrt_api.h"], - visibility = ["//visibility:public"], deps = [ "//xla:status", "//xla:statusor", @@ -170,7 +168,7 @@ cc_library( name = "pjrt_client", srcs = ["pjrt_client.cc"], hdrs = ["pjrt_client.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//xla:friends"]), deps = [ ":pjrt_common", ":pjrt_compiler", @@ -209,7 +207,6 @@ cc_library( testonly = 1, srcs = ["pjrt_client_test.cc"], hdrs = ["pjrt_client_test.h"], - visibility = ["//visibility:public"], deps = [ ":pjrt_client", "//xla:shape_util", @@ -231,7 +228,7 @@ cc_library( name = "pjrt_executable", srcs = ["pjrt_executable.cc"], hdrs = ["pjrt_executable.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = [ ":compile_options_proto_cc", ":executable_metadata_proto_cc", @@ -279,7 +276,6 @@ xla_cc_test( cc_library( name = "pjrt_device_description", hdrs = ["pjrt_device_description.h"], - visibility = ["//visibility:public"], deps = [ ":pjrt_common", "@com_google_absl//absl/container:flat_hash_map", @@ -290,7 +286,7 @@ cc_library( name = "pjrt_compiler", srcs = ["pjrt_compiler.cc"], hdrs = ["pjrt_compiler.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = [ ":metrics", ":pjrt_device_description", @@ -329,7 +325,7 @@ xla_cc_test( cc_library( name = "pjrt_common", hdrs = ["pjrt_common.h"], - visibility = ["//visibility:public"], + visibility = [":friends"], deps = [ "@local_tsl//tsl/lib/gtl:int_type", ], @@ -339,7 +335,7 @@ cc_library( name = "utils", srcs = ["utils.cc"], hdrs = ["utils.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//xla:friends"]), deps = [ ":layout_mode", "//xla:shape_util", @@ -370,7 +366,7 @@ cc_library( name = "layout_mode", srcs = ["layout_mode.cc"], hdrs = ["layout_mode.h"], - visibility = ["//visibility:public"], + visibility = ["//xla:friends"], deps = [ "//xla:shape_util", "//xla:status", @@ -384,7 +380,6 @@ cc_library( name = "metrics", srcs = ["metrics.cc"], hdrs = ["metrics.h"], - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", @@ -408,7 +403,6 @@ cc_library( name = "stream_executor_executable", srcs = ["stream_executor_executable.cc"], hdrs = ["stream_executor_executable.h"], - visibility = ["//visibility:public"], deps = [ ":pjrt_executable", ":stream_executor_executable_proto_cc", @@ -424,7 +418,7 @@ cc_library( name = "pjrt_stream_executor_client", srcs = ["pjrt_stream_executor_client.cc"], hdrs = ["pjrt_stream_executor_client.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//xla:friends"]), deps = [ ":event_pool", ":local_device_state", @@ -526,7 +520,6 @@ cc_library( name = "interpreter_device", srcs = ["interpreter_device.cc"], hdrs = ["interpreter_device.h"], - visibility = ["//visibility:public"], deps = [ ":pjrt_stream_executor_client", "//xla:statusor", @@ -541,7 +534,7 @@ cc_library( name = "mlir_to_hlo", srcs = ["mlir_to_hlo.cc"], hdrs = ["mlir_to_hlo.h"], - visibility = ["//visibility:public"], + visibility = [":friends"], deps = [ "//xla:status", "//xla:statusor", @@ -579,7 +572,7 @@ cc_library( cc_library( name = "pjrt_future", hdrs = ["pjrt_future.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), deps = [ "@com_google_absl//absl/functional:any_invocable", "@local_tsl//tsl/concurrency:async_value", @@ -591,7 +584,9 @@ cc_library( cc_library( name = "tfrt_cpu_pjrt_client", hdrs = ["tfrt_cpu_pjrt_client.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//xla:friends", + ]), deps = [ "//xla/pjrt/cpu:cpu_client", ], @@ -600,7 +595,6 @@ cc_library( cc_library( name = "lru_cache", hdrs = ["lru_cache.h"], - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/container:node_hash_map", "@local_tsl//tsl/platform:logging", @@ -624,7 +618,7 @@ cc_library( "transpose_kernels.h", ], hdrs = ["transpose.h"], - visibility = ["//visibility:public"], + visibility = [":friends"], deps = [ ":lru_cache", "//xla:compiler_macros", @@ -670,7 +664,6 @@ cc_library( name = "pjrt_c_api_client", srcs = ["pjrt_c_api_client.cc"], hdrs = ["pjrt_c_api_client.h"], - visibility = ["//visibility:public"], deps = [ ":compile_options_proto_cc", ":mlir_to_hlo", @@ -790,7 +783,7 @@ cc_library( name = "host_callback", srcs = ["host_callback.cc"], hdrs = ["host_callback.h"], - visibility = ["//visibility:public"], + visibility = [":friends"], deps = [ ":pjrt_client", ":pjrt_future", @@ -853,7 +846,6 @@ cc_library( "-fno-strict-aliasing", ], features = ["-use_header_modules"], - visibility = ["//visibility:public"], deps = [ "//xla:status", "@com_google_absl//absl/status", @@ -869,7 +861,7 @@ cc_library( "-fno-strict-aliasing", ], features = ["-use_header_modules"], - visibility = ["//visibility:public"], + visibility = [":friends"], deps = [ ":exceptions", "//xla:status", diff --git a/third_party/xla/xla/pjrt/c/BUILD b/third_party/xla/xla/pjrt/c/BUILD index d89a30fda57258..ef12625d7d5e5d 100644 --- a/third_party/xla/xla/pjrt/c/BUILD +++ b/third_party/xla/xla/pjrt/c/BUILD @@ -169,15 +169,22 @@ cc_library( "//xla/backends/profiler/plugin:plugin_tracer_impl", "//xla/backends/profiler/plugin:profiler_c_api_hdrs", "//xla/backends/profiler/plugin:profiler_error", + "//xla/client:local_client", "//xla/ffi", "//xla/ffi:ffi_api", "//xla/ffi/api:c_api", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_common", + "//xla/pjrt:pjrt_compiler", + "//xla/pjrt:pjrt_device_description", "//xla/pjrt/gpu:gpu_helpers", "//xla/pjrt/gpu:se_gpu_pjrt_client", + "//xla/pjrt/gpu:se_gpu_pjrt_compiler", # To register GPU AOT compiler "//xla/python:inspect_sharding", # To register "InspectSharding" custom partitioning handler. + "//xla/service:compiler", "//xla/service:custom_call_target_registry", + "//xla/stream_executor:device_description", + "//xla/stream_executor:stream_executor_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", @@ -189,7 +196,6 @@ cc_library( name = "pjrt_c_api_gpu", srcs = ["pjrt_c_api_gpu.cc"], hdrs = ["pjrt_c_api_gpu.h"], - visibility = ["//visibility:public"], deps = [ ":pjrt_c_api_gpu_internal", ":pjrt_c_api_hdrs", diff --git a/third_party/xla/xla/pjrt/c/CHANGELOG.md b/third_party/xla/xla/pjrt/c/CHANGELOG.md index 5bd553df96126c..4680cf8f70aa67 100644 --- a/third_party/xla/xla/pjrt/c/CHANGELOG.md +++ b/third_party/xla/xla/pjrt/c/CHANGELOG.md @@ -1,5 +1,12 @@ # PJRT C API changelog +## 0.42 +* Renamed all ``priv`` fields to ``extension_start`` + +## 0.41 +* Renamed PJRT_Structure_Base to PJRT_Extension_Base +* Renamed PJRT_Structure_Type to PJRT_Extension_Type (and similarly for enum fields) + ## 0.40 (Nov 27, 2023) * Added PJRT_Executable_GetCompiledMemoryStats. @@ -21,7 +28,7 @@ PJRT_ExecuteOptions. * Deprecated PJRT_LoadedExecutable_Fingerprint ## 0.34 (Oct 9, 2023) -* Added PJRT_Structure_Type::PJRT_Structure_Type_Profiler. +* Added PJRT_Extension_Type::PJRT_Extension_Type_Profiler. ## 0.33 (Oct 3, 2023) * Added PJRT_Client_CreateViewOfDeviceBuffer. @@ -30,9 +37,9 @@ PJRT_ExecuteOptions. * Added PJRT_Buffer_CopyToMemory. ## 0.31 (Sep 22, 2023) -* Added PJRT_Structure_Base. -* Added PJRT_Structure_Type. -* Renamed PJRT_Api.priv to PJRT_Api.extension_start. +* Added PJRT_Extension_Base. +* Added PJRT_Extension_Type. +* Renamed PJRT_Api.extension_start to PJRT_Api.extension_start. ## 0.30 (Sep 14, 2023) * Added PJRT_NamedValue_Type::PJRT_NamedValue_kBool. diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api.h b/third_party/xla/xla/pjrt/c/pjrt_c_api.h index c86d6813b586a2..b7450e44f4f6c9 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api.h @@ -53,14 +53,14 @@ extern "C" { // Changes include: // * Adding a new field to the PJRT_Api or argument structs // * Renaming a method or argument (doesn't affect ABI) -#define PJRT_API_MINOR 40 +#define PJRT_API_MINOR 42 // The plugin should set the major_version and minor_version of // PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in // this header that the implementation was compiled with. struct PJRT_Api_Version { size_t struct_size; - void* priv; + void* extension_start; int major_version; // out int minor_version; // out }; @@ -77,7 +77,7 @@ typedef struct PJRT_Error PJRT_Error; struct PJRT_Error_Destroy_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Error* error; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_Error_Destroy_Args, error); @@ -87,7 +87,7 @@ typedef void PJRT_Error_Destroy(PJRT_Error_Destroy_Args* args); struct PJRT_Error_Message_Args { size_t struct_size; - void* priv; + void* extension_start; const PJRT_Error* error; // Has the lifetime of `error`. const char* message; // out @@ -121,7 +121,7 @@ typedef enum { struct PJRT_Error_GetCode_Args { size_t struct_size; - void* priv; + void* extension_start; const PJRT_Error* error; PJRT_Error_Code code; // out }; @@ -151,7 +151,7 @@ typedef enum { // Named value for key-value pairs. struct PJRT_NamedValue { size_t struct_size; - void* priv; + void* extension_start; const char* name; size_t name_size; PJRT_NamedValue_Type type; @@ -172,16 +172,16 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_NamedValue, value_size); struct PJRT_Plugin_Initialize_Args { size_t struct_size; - void* priv; + void* extension_start; }; -PJRT_DEFINE_STRUCT_TRAITS(PJRT_Plugin_Initialize_Args, priv); +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Plugin_Initialize_Args, extension_start); // One-time plugin setup. Must be called before any other functions are called. typedef PJRT_Error* PJRT_Plugin_Initialize(PJRT_Plugin_Initialize_Args* args); struct PJRT_Plugin_Attributes_Args { size_t struct_size; - void* priv; + void* extension_start; // Returned attributes have the lifetime of the process. const PJRT_NamedValue* attributes; // out size_t num_attributes; // out @@ -205,7 +205,7 @@ typedef struct PJRT_Event PJRT_Event; struct PJRT_Event_Destroy_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Event* event; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_Event_Destroy_Args, event); @@ -215,7 +215,7 @@ typedef PJRT_Error* PJRT_Event_Destroy(PJRT_Event_Destroy_Args* args); struct PJRT_Event_IsReady_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Event* event; bool is_ready; // out }; @@ -227,7 +227,7 @@ typedef PJRT_Error* PJRT_Event_IsReady(PJRT_Event_IsReady_Args* args); struct PJRT_Event_Error_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Event* event; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_Event_Error_Args, event); @@ -245,7 +245,7 @@ typedef PJRT_Error* PJRT_Event_Error(PJRT_Event_Error_Args* args); struct PJRT_Event_Await_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Event* event; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_Event_Await_Args, event); @@ -263,7 +263,7 @@ typedef void (*PJRT_Event_OnReadyCallback)(PJRT_Error* error, void* user_arg); struct PJRT_Event_OnReady_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Event* event; PJRT_Event_OnReadyCallback callback; // `user_arg` allows `callback` to be called with arbitrary arguments (e.g. @@ -297,7 +297,7 @@ typedef void (*PJRT_KeyValueGetCallback_ValueDeleter)(char* value); struct PJRT_KeyValueGetCallback_Args { size_t struct_size; - void* priv; + void* extension_start; const char* key; size_t key_size; int timeout_in_ms; @@ -323,7 +323,7 @@ typedef PJRT_Error* (*PJRT_KeyValueGetCallback)( struct PJRT_KeyValuePutCallback_Args { size_t struct_size; - void* priv; + void* extension_start; const char* key; size_t key_size; // Only needs to stay alive for the duration of the PJRT_KeyValuePutCallback @@ -344,7 +344,7 @@ typedef PJRT_Error* (*PJRT_KeyValuePutCallback)( struct PJRT_Client_Create_Args { size_t struct_size; - void* priv; + void* extension_start; // Extra platform-specific options to create a client. const PJRT_NamedValue* create_options; size_t num_options; @@ -367,7 +367,7 @@ typedef PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args); struct PJRT_Client_Destroy_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Client* client; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_Destroy_Args, client); @@ -377,7 +377,7 @@ typedef PJRT_Error* PJRT_Client_Destroy(PJRT_Client_Destroy_Args* args); struct PJRT_Client_PlatformName_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Client* client; // `platform_name` has the same lifetime as `client`. It is owned by `client`. const char* platform_name; // out @@ -391,7 +391,7 @@ typedef PJRT_Error* PJRT_Client_PlatformName( struct PJRT_Client_ProcessIndex_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Client* client; int process_index; // out }; @@ -404,7 +404,7 @@ typedef PJRT_Error* PJRT_Client_ProcessIndex( struct PJRT_Client_PlatformVersion_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Client* client; // `platform_version` has the same lifetime as `client`. It's owned by // `client`. @@ -421,7 +421,7 @@ typedef PJRT_Error* PJRT_Client_PlatformVersion( struct PJRT_Client_TopologyDescription_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Client* client; // Is owned by and has the same lifetime as `client`. PJRT_TopologyDescription* topology; // out @@ -435,7 +435,7 @@ typedef PJRT_Error* PJRT_Client_TopologyDescription( struct PJRT_Client_Devices_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Client* client; PJRT_Device* const* devices; // out size_t num_devices; // out @@ -448,7 +448,7 @@ typedef PJRT_Error* PJRT_Client_Devices(PJRT_Client_Devices_Args* args); struct PJRT_Client_AddressableDevices_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Client* client; PJRT_Device* const* addressable_devices; // out size_t num_addressable_devices; // out @@ -464,7 +464,7 @@ typedef PJRT_Error* PJRT_Client_AddressableDevices( struct PJRT_Client_LookupDevice_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Client* client; int id; // `device` has the same lifetime as `client`. It is owned by `client`. @@ -479,7 +479,7 @@ typedef PJRT_Error* PJRT_Client_LookupDevice( struct PJRT_Client_LookupAddressableDevice_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Client* client; int local_hardware_id; // `addressable_device` has the same lifetime as `client`. It is owned by @@ -496,7 +496,7 @@ typedef PJRT_Error* PJRT_Client_LookupAddressableDevice( struct PJRT_Client_AddressableMemories_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Client* client; PJRT_Memory* const* addressable_memories; // out size_t num_addressable_memories; // out @@ -512,7 +512,7 @@ typedef PJRT_Error* PJRT_Client_AddressableMemories( struct PJRT_Program { size_t struct_size; - void* priv; + void* extension_start; // Serialized code in the specified format below. // String is owned by the caller. char* code; // in/out depending on usage @@ -529,7 +529,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Program, format_size); struct PJRT_Client_Compile_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Client* client; // Only needs to stay alive for the duration of the Compile call. // `program->format` and `program->format_size` are owned by the caller. @@ -549,7 +549,7 @@ typedef PJRT_Error* PJRT_Client_Compile(PJRT_Client_Compile_Args* args); struct PJRT_Client_DefaultDeviceAssignment_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Client* client; int num_replicas; int num_partitions; @@ -643,7 +643,7 @@ typedef enum { struct PJRT_Buffer_MemoryLayout_Tiled { size_t struct_size; - void* priv; + void* extension_start; // A map from physical dimension numbers to logical dimension numbers. // The first element is the most minor physical dimension (fastest varying // index) and the last the most major (slowest varying index). The contents of @@ -661,7 +661,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_MemoryLayout_Tiled, num_tiles); struct PJRT_Buffer_MemoryLayout_Strides { size_t struct_size; - void* priv; + void* extension_start; // Number of bytes to traverse per dimension. Must be the same size as // the number of dimensions of the data. Caution: `byte_strides` are allowed // to be negative, in which case data may need to point to the interior of @@ -676,7 +676,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_MemoryLayout_Strides, num_byte_strides); // strides. struct PJRT_Buffer_MemoryLayout { size_t struct_size; - void* priv; + void* extension_start; union { PJRT_Buffer_MemoryLayout_Tiled tiled; PJRT_Buffer_MemoryLayout_Strides strides; @@ -687,7 +687,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_MemoryLayout, type); struct PJRT_Client_BufferFromHostBuffer_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Client* client; // Pointer to the host buffer const void* data; @@ -735,7 +735,7 @@ typedef PJRT_Error* PJRT_Client_BufferFromHostBuffer( struct PJRT_Client_CreateViewOfDeviceBuffer_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Client* client; // A pointer to a non-owned device buffer. A PJRT_Buffer that is a non-owned // view of this device buffer will be created. @@ -781,7 +781,7 @@ typedef PJRT_Error* PJRT_Client_CreateViewOfDeviceBuffer( struct PJRT_DeviceDescription_Id_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_DeviceDescription* device_description; int id; // out }; @@ -795,7 +795,7 @@ typedef PJRT_Error* PJRT_DeviceDescription_Id( struct PJRT_DeviceDescription_ProcessIndex_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_DeviceDescription* device_description; int process_index; // out }; @@ -812,7 +812,7 @@ typedef PJRT_Error* PJRT_DeviceDescription_ProcessIndex( struct PJRT_DeviceDescription_Attributes_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_DeviceDescription* device_description; size_t num_attributes; // out const PJRT_NamedValue* attributes; // out @@ -826,7 +826,7 @@ typedef PJRT_Error* PJRT_DeviceDescription_Attributes( struct PJRT_DeviceDescription_Kind_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_DeviceDescription* device_description; // `device_kind` string is owned by `device` and has same lifetime as // `device`. @@ -842,7 +842,7 @@ typedef PJRT_Error* PJRT_DeviceDescription_Kind( struct PJRT_DeviceDescription_DebugString_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_DeviceDescription* device_description; const char* debug_string; // out size_t debug_string_size; // out @@ -857,7 +857,7 @@ typedef PJRT_Error* PJRT_DeviceDescription_DebugString( struct PJRT_DeviceDescription_ToString_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_DeviceDescription* device_description; const char* to_string; // out size_t to_string_size; // out @@ -873,7 +873,7 @@ typedef PJRT_Error* PJRT_DeviceDescription_ToString( struct PJRT_Device_GetDescription_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Device* device; PJRT_DeviceDescription* device_description; // out }; @@ -885,7 +885,7 @@ typedef PJRT_Error* PJRT_Device_GetDescription( struct PJRT_Device_IsAddressable_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Device* device; bool is_addressable; // out }; @@ -897,7 +897,7 @@ typedef PJRT_Error* PJRT_Device_IsAddressable( struct PJRT_Device_LocalHardwareId_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Device* device; int local_hardware_id; // out }; @@ -910,7 +910,7 @@ typedef PJRT_Error* PJRT_Device_LocalHardwareId( struct PJRT_Device_AddressableMemories_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Device* device; // Has the lifetime of `device`. PJRT_Memory* const* memories; // out @@ -924,7 +924,7 @@ typedef PJRT_Error* PJRT_Device_AddressableMemories( struct PJRT_Device_DefaultMemory_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Device* device; // `memory` has the same lifetime as `device`. PJRT_Memory* memory; // out @@ -938,7 +938,7 @@ typedef PJRT_Error* PJRT_Device_DefaultMemory( struct PJRT_Device_MemoryStats_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Device* device; // Number of bytes in use. @@ -989,7 +989,7 @@ typedef PJRT_Error* PJRT_Device_MemoryStats(PJRT_Device_MemoryStats_Args* args); struct PJRT_Memory_Id_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Memory* memory; int id; // out }; @@ -1000,7 +1000,7 @@ typedef PJRT_Error* PJRT_Memory_Id(PJRT_Memory_Id_Args* args); struct PJRT_Memory_Kind_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Memory* memory; // `memory_kind` has same lifetime as `memory`. const char* memory_kind; // out @@ -1013,7 +1013,7 @@ typedef PJRT_Error* PJRT_Memory_Kind(PJRT_Memory_Kind_Args* args); struct PJRT_Memory_DebugString_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Memory* memory; const char* debug_string; // out size_t debug_string_size; // out @@ -1026,7 +1026,7 @@ typedef PJRT_Error* PJRT_Memory_DebugString(PJRT_Memory_DebugString_Args* args); struct PJRT_Memory_ToString_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Memory* memory; const char* to_string; // out size_t to_string_size; // out @@ -1038,7 +1038,7 @@ typedef PJRT_Error* PJRT_Memory_ToString(PJRT_Memory_ToString_Args* args); struct PJRT_Memory_AddressableByDevices_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Memory* memory; PJRT_Device* const* devices; // out size_t num_devices; // out @@ -1053,7 +1053,7 @@ typedef PJRT_Error* PJRT_Memory_AddressableByDevices( struct PJRT_Executable_Destroy_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Executable* executable; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_Destroy_Args, executable); @@ -1063,7 +1063,7 @@ typedef PJRT_Error* PJRT_Executable_Destroy(PJRT_Executable_Destroy_Args* args); struct PJRT_LoadedExecutable_Destroy_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_LoadedExecutable* executable; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_LoadedExecutable_Destroy_Args, executable); @@ -1075,7 +1075,7 @@ typedef PJRT_Error* PJRT_LoadedExecutable_Destroy( struct PJRT_LoadedExecutable_GetExecutable_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_LoadedExecutable* loaded_executable; PJRT_Executable* executable; // out }; @@ -1088,7 +1088,7 @@ typedef PJRT_Error* PJRT_LoadedExecutable_GetExecutable( struct PJRT_Executable_Name_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Executable* executable; // `executable_name` has the same lifetime as `executable`. It is owned by // `executable`. @@ -1103,7 +1103,7 @@ typedef PJRT_Error* PJRT_Executable_Name(PJRT_Executable_Name_Args* args); // TODO(b/269178731): Revisit whether num_replicas is needed. struct PJRT_Executable_NumReplicas_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Executable* executable; size_t num_replicas; // out }; @@ -1115,7 +1115,7 @@ typedef PJRT_Error* PJRT_Executable_NumReplicas( struct PJRT_Executable_NumPartitions_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Executable* executable; size_t num_partitions; // out }; @@ -1127,7 +1127,7 @@ typedef PJRT_Error* PJRT_Executable_NumPartitions( struct PJRT_LoadedExecutable_AddressableDevices_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_LoadedExecutable* executable; PJRT_Device* const* addressable_devices; // out size_t num_addressable_devices; // out @@ -1141,7 +1141,7 @@ typedef PJRT_Error* PJRT_LoadedExecutable_AddressableDevices( struct PJRT_Executable_OptimizedProgram_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Executable* executable; PJRT_Program* program; // out, but read below }; @@ -1175,7 +1175,7 @@ typedef PJRT_Error* PJRT_Executable_OptimizedProgram( struct PJRT_LoadedExecutable_Delete_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_LoadedExecutable* executable; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_LoadedExecutable_Delete_Args, executable); @@ -1190,7 +1190,7 @@ typedef PJRT_Error* PJRT_LoadedExecutable_Delete( struct PJRT_LoadedExecutable_IsDeleted_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_LoadedExecutable* executable; bool is_deleted; // out }; @@ -1247,7 +1247,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_RecvCallbackInfo, recv_callback); struct PJRT_ExecuteOptions { size_t struct_size; - void* priv; + void* extension_start; // Callbacks for when send/recv ops are executed. The outer lists correspond // to each device returned by `PJRT_Executable_AddressableDevices` for // `executable` (i.e. they will have length `num_devices`). Each inner list @@ -1279,7 +1279,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_ExecuteOptions, launch_id); struct PJRT_LoadedExecutable_Execute_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_LoadedExecutable* executable; // Only needs to stay alive for the duration of the Execute call. PJRT_ExecuteOptions* options; @@ -1318,7 +1318,7 @@ typedef PJRT_Error* PJRT_LoadedExecutable_Execute( struct PJRT_Executable_NumOutputs_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Executable* executable; size_t num_outputs; // out }; @@ -1330,7 +1330,7 @@ typedef PJRT_Error* PJRT_Executable_NumOutputs( struct PJRT_Executable_SizeOfGeneratedCodeInBytes_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Executable* executable; int64_t size_in_bytes; // out }; @@ -1342,7 +1342,7 @@ typedef PJRT_Error* PJRT_Executable_SizeOfGeneratedCodeInBytes( struct PJRT_Executable_Fingerprint_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Executable* executable; // Has the lifetime of `executable` const char* executable_fingerprint; // out @@ -1360,7 +1360,7 @@ typedef PJRT_Error* PJRT_Executable_Fingerprint( struct PJRT_Executable_GetCostAnalysis_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Executable* executable; size_t num_properties; // out // `properties` and any embedded data are owned by and have the same lifetime @@ -1378,7 +1378,7 @@ typedef PJRT_Error* PJRT_Executable_GetCostAnalysis( struct PJRT_Executable_GetCompiledMemoryStats_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Executable* executable; // Mirrors xla::CompiledMemoryStats. @@ -1399,7 +1399,7 @@ typedef PJRT_Error* PJRT_Executable_GetCompiledMemoryStats( struct PJRT_Executable_OutputElementTypes_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Executable* executable; PJRT_Buffer_Type* output_types; // out size_t num_output_types; // out @@ -1413,7 +1413,7 @@ typedef PJRT_Error* PJRT_Executable_OutputElementTypes( struct PJRT_Executable_OutputDimensions_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Executable* executable; size_t num_outputs; // Has length: sum of all elements in the list `dim_sizes`. @@ -1431,7 +1431,7 @@ typedef PJRT_Error* PJRT_Executable_OutputDimensions( struct PJRT_Executable_OutputMemoryKinds_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Executable* executable; size_t num_outputs; // Has length `num_outputs`. @@ -1450,7 +1450,7 @@ typedef struct PJRT_SerializedExecutable PJRT_SerializedExecutable; struct PJRT_Executable_Serialize_Args { size_t struct_size; - void* priv; + void* extension_start; const PJRT_Executable* executable; // Lives only as long as serialized_executable @@ -1473,7 +1473,7 @@ typedef PJRT_Error* PJRT_Executable_Serialize( struct PJRT_Executable_DeserializeAndLoad_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Client* client; const char* serialized_executable; size_t serialized_executable_size; @@ -1490,7 +1490,7 @@ typedef PJRT_Error* PJRT_Executable_DeserializeAndLoad( struct PJRT_LoadedExecutable_Fingerprint_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_LoadedExecutable* executable; // Has the lifetime of `executable` const char* executable_fingerprint; // out @@ -1510,7 +1510,7 @@ typedef PJRT_Error* PJRT_LoadedExecutable_Fingerprint( struct PJRT_Buffer_Destroy_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Buffer* buffer; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_Destroy_Args, buffer); @@ -1521,7 +1521,7 @@ typedef PJRT_Error* PJRT_Buffer_Destroy(PJRT_Buffer_Destroy_Args* args); struct PJRT_Buffer_ElementType_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Buffer* buffer; PJRT_Buffer_Type type; // out }; @@ -1532,7 +1532,7 @@ typedef PJRT_Error* PJRT_Buffer_ElementType(PJRT_Buffer_ElementType_Args* args); struct PJRT_Buffer_Dimensions_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Buffer* buffer; // Has the lifetime of `buffer` and length `num_dims`. const int64_t* dims; // out @@ -1545,7 +1545,7 @@ typedef PJRT_Error* PJRT_Buffer_Dimensions(PJRT_Buffer_Dimensions_Args* args); struct PJRT_Buffer_UnpaddedDimensions_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Buffer* buffer; // Has the lifetime of `buffer` and length `num_dims`. const int64_t* unpadded_dims; // out @@ -1565,7 +1565,7 @@ typedef PJRT_Error* PJRT_Buffer_UnpaddedDimensions( struct PJRT_Buffer_DynamicDimensionIndices_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Buffer* buffer; // Has the lifetime of `buffer` and length `num_dynamic_dims`. const size_t* dynamic_dim_indices; // out @@ -1583,7 +1583,7 @@ typedef PJRT_Error* PJRT_Buffer_DynamicDimensionIndices( struct PJRT_Buffer_GetMemoryLayout_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Buffer* buffer; // Layout data is owned by and has the lifetime of `buffer`. PJRT_Buffer_MemoryLayout layout; // out @@ -1596,7 +1596,7 @@ typedef PJRT_Error* PJRT_Buffer_GetMemoryLayout( struct PJRT_Buffer_ToHostBuffer_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Buffer* src; // The caller can specify an optional host layout. If nullptr, the layout of @@ -1622,7 +1622,7 @@ typedef PJRT_Error* PJRT_Buffer_ToHostBuffer( struct PJRT_Buffer_OnDeviceSizeInBytes_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Buffer* buffer; size_t on_device_size_in_bytes; // out }; @@ -1635,7 +1635,7 @@ typedef PJRT_Error* PJRT_Buffer_OnDeviceSizeInBytes( struct PJRT_Buffer_Delete_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Buffer* buffer; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_Delete_Args, buffer); @@ -1649,7 +1649,7 @@ typedef PJRT_Error* PJRT_Buffer_Delete(PJRT_Buffer_Delete_Args* args); struct PJRT_Buffer_IsDeleted_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Buffer* buffer; bool is_deleted; // out }; @@ -1660,7 +1660,7 @@ typedef PJRT_Error* PJRT_Buffer_IsDeleted(PJRT_Buffer_IsDeleted_Args* args); struct PJRT_Buffer_CopyToDevice_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Buffer* buffer; PJRT_Device* dst_device; PJRT_Buffer* dst_buffer; // out @@ -1675,7 +1675,7 @@ typedef PJRT_Error* PJRT_Buffer_CopyToDevice( struct PJRT_Buffer_CopyToMemory_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Buffer* buffer; PJRT_Memory* dst_memory; PJRT_Buffer* dst_buffer; // out @@ -1690,7 +1690,7 @@ typedef PJRT_Error* PJRT_Buffer_CopyToMemory( struct PJRT_Buffer_IsOnCpu_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Buffer* buffer; bool is_on_cpu; // out }; @@ -1701,7 +1701,7 @@ typedef PJRT_Error* PJRT_Buffer_IsOnCpu(PJRT_Buffer_IsOnCpu_Args* args); struct PJRT_Buffer_Device_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Buffer* buffer; PJRT_Device* device; // out }; @@ -1712,7 +1712,7 @@ typedef PJRT_Error* PJRT_Buffer_Device(PJRT_Buffer_Device_Args* args); struct PJRT_Buffer_Memory_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Buffer* buffer; PJRT_Memory* memory; // out }; @@ -1723,7 +1723,7 @@ typedef PJRT_Error* PJRT_Buffer_Memory(PJRT_Buffer_Memory_Args* args); struct PJRT_Buffer_ReadyEvent_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Buffer* buffer; // The caller is responsible for calling PJRT_Event_Destroy on `event`. PJRT_Event* event; // out @@ -1743,7 +1743,7 @@ typedef PJRT_Error* PJRT_Buffer_ReadyEvent(PJRT_Buffer_ReadyEvent_Args* args); struct PJRT_Buffer_UnsafePointer_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Buffer* buffer; uintptr_t buffer_pointer; // out }; @@ -1756,7 +1756,7 @@ typedef PJRT_Error* PJRT_Buffer_UnsafePointer( struct PJRT_Buffer_IncreaseExternalReferenceCount_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Buffer* buffer; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_IncreaseExternalReferenceCount_Args, @@ -1772,7 +1772,7 @@ typedef PJRT_Error* PJRT_Buffer_IncreaseExternalReferenceCount( struct PJRT_Buffer_DecreaseExternalReferenceCount_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Buffer* buffer; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_DecreaseExternalReferenceCount_Args, @@ -1786,7 +1786,7 @@ typedef PJRT_Error* PJRT_Buffer_DecreaseExternalReferenceCount( struct PJRT_Buffer_OpaqueDeviceMemoryDataPointer_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_Buffer* buffer; void* device_memory_ptr; // out }; @@ -1803,7 +1803,7 @@ typedef PJRT_Error* PJRT_Buffer_OpaqueDeviceMemoryDataPointer( struct PJRT_CopyToDeviceStream_Destroy_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_CopyToDeviceStream* stream; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_CopyToDeviceStream_Destroy_Args, stream); @@ -1814,7 +1814,7 @@ typedef PJRT_Error* PJRT_CopyToDeviceStream_Destroy( struct PJRT_CopyToDeviceStream_AddChunk_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_CopyToDeviceStream* stream; // Takes ownership of `chunk` (i.e. implementation will call chunk.deleter). PJRT_Chunk* chunk; @@ -1835,7 +1835,7 @@ typedef PJRT_Error* PJRT_CopyToDeviceStream_AddChunk( struct PJRT_CopyToDeviceStream_TotalBytes_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_CopyToDeviceStream* stream; int64_t total_bytes; // out }; @@ -1847,7 +1847,7 @@ typedef PJRT_Error* PJRT_CopyToDeviceStream_TotalBytes( struct PJRT_CopyToDeviceStream_GranuleSize_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_CopyToDeviceStream* stream; int64_t granule_size_in_bytes; // out }; @@ -1861,7 +1861,7 @@ typedef PJRT_Error* PJRT_CopyToDeviceStream_GranuleSize( struct PJRT_CopyToDeviceStream_CurrentBytes_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_CopyToDeviceStream* stream; int64_t current_bytes; // out }; @@ -1877,7 +1877,7 @@ typedef PJRT_Error* PJRT_CopyToDeviceStream_CurrentBytes( struct PJRT_TopologyDescription_Create_Args { size_t struct_size; - void* priv; + void* extension_start; const char* topology_name; size_t topology_name_size; // Extra platform-specific options to create a client. @@ -1894,7 +1894,7 @@ typedef PJRT_Error* PJRT_TopologyDescription_Create( struct PJRT_TopologyDescription_Destroy_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_TopologyDescription* topology; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_TopologyDescription_Destroy_Args, topology); @@ -1905,7 +1905,7 @@ typedef PJRT_Error* PJRT_TopologyDescription_Destroy( struct PJRT_TopologyDescription_PlatformVersion_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_TopologyDescription* topology; // `platform_version` has the same lifetime as `topology`. It's owned by // `topology`. @@ -1922,7 +1922,7 @@ typedef PJRT_Error* PJRT_TopologyDescription_PlatformVersion( struct PJRT_TopologyDescription_PlatformName_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_TopologyDescription* topology; // `platform_name` has the same lifetime as `topology`. It is owned by // `topology`. @@ -1938,7 +1938,7 @@ typedef PJRT_Error* PJRT_TopologyDescription_PlatformName( struct PJRT_TopologyDescription_GetDeviceDescriptions_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_TopologyDescription* topology; // Has the same lifetime as topology. PJRT_DeviceDescription* const* descriptions; // out @@ -1957,7 +1957,7 @@ typedef struct PJRT_SerializedTopology PJRT_SerializedTopology; struct PJRT_TopologyDescription_Serialize_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_TopologyDescription* topology; // Lives only as long as serialized_topology. @@ -1979,7 +1979,7 @@ typedef PJRT_Error* PJRT_TopologyDescription_Serialize( struct PJRT_TopologyDescription_Attributes_Args { size_t struct_size; - void* priv; + void* extension_start; PJRT_TopologyDescription* topology; // Only lives as long as topology. @@ -1995,7 +1995,7 @@ typedef PJRT_Error* PJRT_TopologyDescription_Attributes( struct PJRT_Compile_Args { size_t struct_size; - void* priv; + void* extension_start; const PJRT_TopologyDescription* topology; // Only needs to stay alive for the duration of the Compile call. // `program->format` and `program->format_size` are owned by the caller. @@ -2019,17 +2019,17 @@ typedef PJRT_Error* PJRT_Compile(PJRT_Compile_Args* args); // -------------------------------- Extension ---------------------------------- typedef enum { - PJRT_Structure_Type_Gpu_Custom_Call = 0, - PJRT_Structure_Type_Profiler, -} PJRT_Structure_Type; + PJRT_Extension_Type_Gpu_Custom_Call = 0, + PJRT_Extension_Type_Profiler, +} PJRT_Extension_Type; -// PJRT_Structure_Base contains a type and a pointer to next -// PJRT_Structure_Base. The framework can go through this chain to find +// PJRT_Extension_Base contains a type and a pointer to next +// PJRT_Extension_Base. The framework can go through this chain to find // structure and identify it with the type. -typedef struct PJRT_Structure_Base { - PJRT_Structure_Type type; - const struct PJRT_Structure_Base* next; -} PJRT_Structure_Base; +typedef struct PJRT_Extension_Base { + PJRT_Extension_Type type; + const struct PJRT_Extension_Base* next; +} PJRT_Extension_Base; // -------------------------------- API access --------------------------------- diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_extension.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_extension.h index f6dbbc92e7fe78..341eadcd15a25b 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_extension.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_extension.h @@ -41,7 +41,7 @@ typedef PJRT_Error* PJRT_Gpu_Register_Custom_Call( PJRT_Gpu_Register_Custom_Call_Args* args); typedef struct PJRT_Gpu_Custom_Call { - PJRT_Structure_Type type; + PJRT_Extension_Type type; const void* next; PJRT_Gpu_Register_Custom_Call* custom_call; } PJRT_Gpu_Custom_Call; diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc index fa10334d506fc8..bcd1a4410830ef 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc @@ -27,6 +27,7 @@ limitations under the License. #include "xla/backends/profiler/plugin/plugin_tracer_impl.h" #include "xla/backends/profiler/plugin/profiler_c_api.h" #include "xla/backends/profiler/plugin/profiler_error.h" +#include "xla/client/local_client.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/ffi.h" #include "xla/ffi/ffi_api.h" @@ -39,7 +40,12 @@ limitations under the License. #include "xla/pjrt/gpu/se_gpu_pjrt_client.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_device_description.h" +#include "xla/service/compiler.h" #include "xla/service/custom_call_target_registry.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/stream_executor_pimpl.h" #include "tsl/platform/errors.h" namespace pjrt { @@ -137,8 +143,32 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) { PJRT_Error* PJRT_GpuDeviceTopology_Create( PJRT_TopologyDescription_Create_Args* args) { - return new PJRT_Error{tsl::errors::Unimplemented( - "Topology not supported for GPU compilation.")}; + PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( + "PJRT_TopologyDescription_Create_Args", + PJRT_TopologyDescription_Create_Args_STRUCT_SIZE, args->struct_size)); + + PJRT_ASSIGN_OR_RETURN(xla::LocalClient * xla_client, + xla::GetGpuXlaClient(/*platform_name=*/std::nullopt, + /*allowed_devices=*/std::nullopt)); + stream_executor::StreamExecutor* executor = + xla_client->backend().default_stream_executor(); + const stream_executor::DeviceDescription& description = + executor->GetDeviceDescription(); + std::vector device_ids; + device_ids.reserve(xla_client->backend().stream_executors().size()); + for (stream_executor::StreamExecutor* executor : + xla_client->backend().stream_executors()) { + device_ids.push_back(executor->device_ordinal()); + } + auto gpu_target_config = xla::Compiler::TargetConfig(executor); + auto pjrt_topology = + std::make_unique( + xla::CudaId(), xla::CudaName(), description.name(), device_ids, + absl::flat_hash_map{ + {"target_config", + gpu_target_config.ToProto().SerializeAsString()}}); + args->topology = CreateWrapperDeviceTopology(std::move(pjrt_topology)); + return nullptr; } PLUGIN_Profiler_Api profiler_api{ @@ -155,7 +185,7 @@ PLUGIN_Profiler_Api profiler_api{ }; PJRT_Profiler_Extension profiler_extension{ - /*type=*/PJRT_Structure_Type::PJRT_Structure_Type_Profiler, + /*type=*/PJRT_Extension_Type::PJRT_Extension_Type_Profiler, /*next=*/nullptr, /*profiler_api=*/&profiler_api, }; @@ -187,7 +217,7 @@ PJRT_Error* PJRT_Gpu_Register_Custom_Call( } PJRT_Gpu_Custom_Call custom_call{ - /*type=*/PJRT_Structure_Type::PJRT_Structure_Type_Gpu_Custom_Call, + /*type=*/PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call, /*next=*/&profiler_extension, /*custom_call=*/PJRT_Gpu_Register_Custom_Call, }; diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc index 444f00fb79b917..e60e2126a73b08 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc @@ -84,7 +84,7 @@ TEST_F(PjrtCApiGpuTest, CreateViewOfDeviceBuffer) { PJRT_Buffer_OpaqueDeviceMemoryDataPointer_Args device_buffer_ptr_args; device_buffer_ptr_args.struct_size = PJRT_Buffer_OpaqueDeviceMemoryDataPointer_Args_STRUCT_SIZE; - device_buffer_ptr_args.priv = nullptr; + device_buffer_ptr_args.extension_start = nullptr; device_buffer_ptr_args.buffer = buffer.get(); PJRT_Error* device_buffer_ptr_error = api_->PJRT_Buffer_OpaqueDeviceMemoryDataPointer(&device_buffer_ptr_args); @@ -92,7 +92,7 @@ TEST_F(PjrtCApiGpuTest, CreateViewOfDeviceBuffer) { // Looks up a device. PJRT_Buffer_Device_Args device_args = PJRT_Buffer_Device_Args{ /*struct_size=*/PJRT_Buffer_Device_Args_STRUCT_SIZE, - /*priv=*/nullptr, + /*extension_start=*/nullptr, /*buffer=*/buffer.get(), }; PJRT_Error* device_error = api_->PJRT_Buffer_Device(&device_args); @@ -102,7 +102,7 @@ TEST_F(PjrtCApiGpuTest, CreateViewOfDeviceBuffer) { PJRT_Client_CreateViewOfDeviceBuffer_Args create_view_args; create_view_args.struct_size = PJRT_Client_CreateViewOfDeviceBuffer_Args_STRUCT_SIZE; - create_view_args.priv = nullptr; + create_view_args.extension_start = nullptr; create_view_args.client = client_; create_view_args.device_buffer_ptr = device_buffer_ptr_args.device_memory_ptr; xla::Shape shape = xla::ShapeUtil::MakeShape(xla::S32, {4}); @@ -136,7 +136,7 @@ TEST_F(PjrtCApiGpuTest, CreateViewOfDeviceBuffer) { // Transfers view_buffer to host to verify. PJRT_Buffer_ToHostBuffer_Args to_host_args; to_host_args.struct_size = PJRT_Buffer_ToHostBuffer_Args_STRUCT_SIZE; - to_host_args.priv = nullptr; + to_host_args.extension_start = nullptr; to_host_args.src = view_buffer.get(); xla::Shape host_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); auto literal = std::make_shared(host_shape); @@ -163,7 +163,7 @@ absl::StatusOr BuildCreateArg( std::vector& c_options) { PJRT_Client_Create_Args args; args.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.create_options = c_options.data(); args.num_options = c_options.size(); args.kv_get_callback = kv_callback_data->c_kv_get; @@ -201,7 +201,7 @@ TEST(PjrtCApiGpuKVStoreTest, CreateClientWithKVCallback) { PJRT_Client_Devices_Args device_args; device_args.struct_size = PJRT_Client_Devices_Args_STRUCT_SIZE; - device_args.priv = nullptr; + device_args.extension_start = nullptr; device_args.client = create_arg.client; PJRT_Error* device_error = api->PJRT_Client_Devices(&device_args); @@ -211,7 +211,7 @@ TEST(PjrtCApiGpuKVStoreTest, CreateClientWithKVCallback) { PJRT_Client_AddressableDevices_Args addressable_device_args; addressable_device_args.struct_size = PJRT_Client_AddressableDevices_Args_STRUCT_SIZE; - addressable_device_args.priv = nullptr; + addressable_device_args.extension_start = nullptr; addressable_device_args.client = create_arg.client; PJRT_Error* addressable_device_error = @@ -221,7 +221,7 @@ TEST(PjrtCApiGpuKVStoreTest, CreateClientWithKVCallback) { PJRT_Client_Destroy_Args destroy_args; destroy_args.struct_size = PJRT_Client_Destroy_Args_STRUCT_SIZE; - destroy_args.priv = nullptr; + destroy_args.extension_start = nullptr; destroy_args.client = create_arg.client; PJRT_Error* destroy_error = api->PJRT_Client_Destroy(&destroy_args); @@ -253,7 +253,7 @@ TEST(PjrtCApiGpuAllocatorTest, ValidOptionsParsing) { ::pjrt::ConvertToPjRtNamedValueList(options, /*api_minor_version=*/30)); PJRT_Client_Create_Args create_arg; create_arg.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE; - create_arg.priv = nullptr; + create_arg.extension_start = nullptr; create_arg.client = nullptr; create_arg.create_options = c_options.data(); create_arg.num_options = c_options.size(); @@ -262,7 +262,7 @@ TEST(PjrtCApiGpuAllocatorTest, ValidOptionsParsing) { PJRT_Client_Destroy_Args destroy_args; destroy_args.struct_size = PJRT_Client_Destroy_Args_STRUCT_SIZE; - destroy_args.priv = nullptr; + destroy_args.extension_start = nullptr; destroy_args.client = create_arg.client; PJRT_Error* destroy_error = api->PJRT_Client_Destroy(&destroy_args); @@ -282,7 +282,7 @@ TEST(PjrtCApiGpuAllocatorTest, InvalidAllocatorOptionsParsing) { ::pjrt::ConvertToPjRtNamedValueList(options, /*api_minor_version=*/30)); PJRT_Client_Create_Args create_arg; create_arg.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE; - create_arg.priv = nullptr; + create_arg.extension_start = nullptr; create_arg.client = nullptr; create_arg.create_options = c_options.data(); create_arg.num_options = c_options.size(); @@ -297,7 +297,7 @@ TEST(PjrtCApiGpuAllocatorTest, InvalidAllocatorOptionsParsing) { PJRT_Error_Destroy_Args error_destroy_args; error_destroy_args.struct_size = PJRT_Error_Destroy_Args_STRUCT_SIZE; - error_destroy_args.priv = nullptr; + error_destroy_args.extension_start = nullptr; error_destroy_args.error = error; api->PJRT_Error_Destroy(&error_destroy_args); @@ -317,7 +317,7 @@ TEST(PjrtCApiPlatformNameTest, AvailablePlatformName) { ::pjrt::ConvertToPjRtNamedValueList(options, /*api_minor_version=*/30)); PJRT_Client_Create_Args create_arg; create_arg.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE; - create_arg.priv = nullptr; + create_arg.extension_start = nullptr; create_arg.client = nullptr; create_arg.create_options = c_options.data(); create_arg.num_options = c_options.size(); @@ -326,7 +326,7 @@ TEST(PjrtCApiPlatformNameTest, AvailablePlatformName) { PJRT_Client_PlatformName_Args platform_name_args; platform_name_args.struct_size = PJRT_Client_PlatformName_Args_STRUCT_SIZE; - platform_name_args.priv = nullptr; + platform_name_args.extension_start = nullptr; platform_name_args.client = create_arg.client; PJRT_Error* platform_name_error = @@ -340,7 +340,7 @@ TEST(PjrtCApiPlatformNameTest, AvailablePlatformName) { PJRT_Client_Destroy_Args destroy_args; destroy_args.struct_size = PJRT_Client_Destroy_Args_STRUCT_SIZE; - destroy_args.priv = nullptr; + destroy_args.extension_start = nullptr; destroy_args.client = create_arg.client; PJRT_Error* destroy_error = api->PJRT_Client_Destroy(&destroy_args); @@ -359,7 +359,7 @@ TEST(PjrtCApiPlatformNameTest, UnavailablePlatformName) { ::pjrt::ConvertToPjRtNamedValueList(options, /*api_minor_version=*/30)); PJRT_Client_Create_Args create_arg; create_arg.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE; - create_arg.priv = nullptr; + create_arg.extension_start = nullptr; create_arg.client = nullptr; create_arg.create_options = c_options.data(); create_arg.num_options = c_options.size(); @@ -374,7 +374,7 @@ TEST(PjrtCApiPlatformNameTest, UnavailablePlatformName) { PJRT_Error_Destroy_Args error_destroy_args; error_destroy_args.struct_size = PJRT_Error_Destroy_Args_STRUCT_SIZE; - error_destroy_args.priv = nullptr; + error_destroy_args.extension_start = nullptr; error_destroy_args.error = error; api->PJRT_Error_Destroy(&error_destroy_args); @@ -382,7 +382,7 @@ TEST(PjrtCApiPlatformNameTest, UnavailablePlatformName) { void TestCustomCallV2() {} -TEST(PjrtCApiGpuPrivTest, CustomCallUntyped) { +TEST(PjrtCApiGpuExtensionTest, CustomCallUntyped) { PJRT_Gpu_Register_Custom_Call_Args args; args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE; std::string function_name = "untyped_function_name"; @@ -391,11 +391,11 @@ TEST(PjrtCApiGpuPrivTest, CustomCallUntyped) { args.api_version = 0; args.custom_call_function = reinterpret_cast(&TestCustomCallV2); auto api = GetPjrtApi(); - const PJRT_Structure_Base* next = - reinterpret_cast(api->extension_start); + const PJRT_Extension_Base* next = + reinterpret_cast(api->extension_start); while (next != nullptr && next->type != - PJRT_Structure_Type::PJRT_Structure_Type_Gpu_Custom_Call) { + PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call) { next = next->next; } ASSERT_NE(next, nullptr); @@ -413,7 +413,7 @@ static void* kNoop = xla::ffi::Ffi::Bind() .To([]() { return xla::ffi::Error::Success(); }) .release(); -TEST(PjrtCApiGpuPrivTest, CustomCallTyped) { +TEST(PjrtCApiGpuExtensionTest, CustomCallTyped) { PJRT_Gpu_Register_Custom_Call_Args args; args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE; std::string function_name = "typed_function_name"; @@ -422,11 +422,11 @@ TEST(PjrtCApiGpuPrivTest, CustomCallTyped) { args.api_version = 1; args.custom_call_function = kNoop; auto api = GetPjrtApi(); - const PJRT_Structure_Base* next = - reinterpret_cast(api->extension_start); + const PJRT_Extension_Base* next = + reinterpret_cast(api->extension_start); while (next != nullptr && next->type != - PJRT_Structure_Type::PJRT_Structure_Type_Gpu_Custom_Call) { + PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call) { next = next->next; } ASSERT_NE(next, nullptr); diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc index c655dec9dfc922..4263ca7277eec8 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc @@ -61,7 +61,7 @@ PJRT_ClientDeleter MakeClientDeleter(const PJRT_Api* api) { return [api](PJRT_Client* client) -> void { PJRT_Client_Destroy_Args destroy_args; destroy_args.struct_size = PJRT_Client_Destroy_Args_STRUCT_SIZE; - destroy_args.priv = nullptr; + destroy_args.extension_start = nullptr; destroy_args.client = client; PJRT_Error* error = api->PJRT_Client_Destroy(&destroy_args); @@ -74,7 +74,7 @@ PJRT_ErrorDeleter MakeErrorDeleter(const PJRT_Api* api) { return [api](PJRT_Error* error) -> void { PJRT_Error_Destroy_Args destroy_args; destroy_args.struct_size = PJRT_Error_Destroy_Args_STRUCT_SIZE; - destroy_args.priv = nullptr; + destroy_args.extension_start = nullptr; destroy_args.error = error; api->PJRT_Error_Destroy(&destroy_args); @@ -85,7 +85,7 @@ PJRT_BufferDeleter MakeBufferDeleter(const PJRT_Api* api) { return [api](PJRT_Buffer* buffer) -> void { PJRT_Buffer_Destroy_Args destroy_args; destroy_args.struct_size = PJRT_Buffer_Destroy_Args_STRUCT_SIZE; - destroy_args.priv = nullptr; + destroy_args.extension_start = nullptr; destroy_args.buffer = buffer; pjrt::LogFatalIfPjrtError(api->PJRT_Buffer_Destroy(&destroy_args), api); @@ -96,7 +96,7 @@ PJRT_ExecutableDeleter MakeExecutableDeleter(const PJRT_Api* api) { return [api](PJRT_Executable* executable) -> void { PJRT_Executable_Destroy_Args args; args.struct_size = PJRT_Executable_Destroy_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.executable = executable; pjrt::LogFatalIfPjrtError(api->PJRT_Executable_Destroy(&args), api); }; @@ -106,7 +106,7 @@ PJRT_LoadedExecutableDeleter MakeLoadedExecutableDeleter(const PJRT_Api* api) { return [api](PJRT_LoadedExecutable* executable) -> void { PJRT_LoadedExecutable_Destroy_Args args; args.struct_size = PJRT_LoadedExecutable_Destroy_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.executable = executable; pjrt::LogFatalIfPjrtError(api->PJRT_LoadedExecutable_Destroy(&args), api); }; @@ -127,7 +127,7 @@ PJRT_TopologyDescriptionDeleter MakeTopologyDescriptionDeleter( PJRT_TopologyDescription_Destroy_Args destroy_args; destroy_args.struct_size = PJRT_TopologyDescription_Destroy_Args_STRUCT_SIZE; - destroy_args.priv = nullptr; + destroy_args.extension_start = nullptr; destroy_args.topology = topology; pjrt::LogFatalIfPjrtError( @@ -138,7 +138,7 @@ PJRT_TopologyDescriptionDeleter MakeTopologyDescriptionDeleter( PJRT_Error_Code GetErrorCode(const PJRT_Error* error, const PJRT_Api* api) { PJRT_Error_GetCode_Args args; args.struct_size = PJRT_Error_GetCode_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.error = error; pjrt::LogFatalIfPjrtError(api->PJRT_Error_GetCode(&args), api); return args.code; @@ -205,7 +205,7 @@ absl::string_view GetPjrtErrorMessage(const PJRT_Error* error, const PJRT_Api* api) { PJRT_Error_Message_Args message_args; message_args.struct_size = PJRT_Error_Message_Args_STRUCT_SIZE; - message_args.priv = nullptr; + message_args.extension_start = nullptr; message_args.error = error; api->PJRT_Error_Message(&message_args); return absl::string_view(message_args.message, message_args.message_size); @@ -225,7 +225,7 @@ PJRT_EventDeleter MakeEventDeleter(const PJRT_Api* api) { return [api](PJRT_Event* managed) { PJRT_Event_Destroy_Args args; args.struct_size = PJRT_Event_Destroy_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.event = managed; LogFatalIfPjrtError(api->PJRT_Event_Destroy(&args), api); @@ -389,7 +389,7 @@ xla::PjRtFuture ConvertCEventToCppFuture(PJRT_Event* c_event, using xla::Status, xla::PjRtFuture; PJRT_Event_OnReady_Args event_onready_args; event_onready_args.struct_size = PJRT_Event_OnReady_Args_STRUCT_SIZE; - event_onready_args.priv = nullptr; + event_onready_args.extension_start = nullptr; event_onready_args.event = c_event; PjRtFuture::Promise promise = PjRtFuture::CreatePromise(); @@ -400,7 +400,7 @@ xla::PjRtFuture ConvertCEventToCppFuture(PJRT_Event* c_event, promise.Set(s); ::pjrt::MakeErrorDeleter(c_api)(error); } else { - promise.Set(tsl::OkStatus()); + promise.Set(absl::OkStatus()); } ::pjrt::MakeEventDeleter(c_api)(c_event); }); @@ -419,12 +419,12 @@ xla::PjRtFuture ConvertCEventToCppFuture(PJRT_Event* c_event, return PjRtFuture(std::move(promise)); } -static xla::StatusOr ConvertToPjRtNamedValue( +static absl::StatusOr ConvertToPjRtNamedValue( const std::string& name, const xla::PjRtValueType& value, int api_minor_version) { PJRT_NamedValue c_value; c_value.struct_size = PJRT_NamedValue_STRUCT_SIZE; - c_value.priv = nullptr; + c_value.extension_start = nullptr; c_value.name = name.c_str(); c_value.name_size = name.size(); @@ -468,7 +468,7 @@ static xla::StatusOr ConvertToPjRtNamedValue( return c_value; } -xla::StatusOr> ConvertToPjRtNamedValueList( +absl::StatusOr> ConvertToPjRtNamedValueList( const absl::flat_hash_map& cpp_value_map, int api_minor_version) { std::vector c_value_list; @@ -524,7 +524,7 @@ ConvertFromPjRtNamedValueList(const PJRT_NamedValue* c_value_list, return cpp_value_map; } -static xla::StatusOr GetPjrtNamedValueType( +static absl::StatusOr GetPjrtNamedValueType( xla::PjRtValueType cpp_value) { if (std::holds_alternative(cpp_value)) { return PJRT_NamedValue_Type::PJRT_NamedValue_kString; @@ -564,7 +564,7 @@ xla::Status ValidateCreateOptions( it->second); } } - return tsl::OkStatus(); + return absl::OkStatus(); } static std::string StructSizeErrorMsg(absl::string_view struct_name, @@ -590,13 +590,13 @@ xla::Status ActualStructSizeIsGreaterOrEqual(absl::string_view struct_name, if (actual_size > expected_size) { VLOG(2) << StructSizeErrorMsg(struct_name, expected_size, actual_size); } - return tsl::OkStatus(); + return absl::OkStatus(); } absl::string_view GetPlatformVersion(PJRT_Client* client, const PJRT_Api* api) { PJRT_Client_PlatformVersion_Args args; args.struct_size = PJRT_Client_PlatformVersion_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.client = client; LogFatalIfPjrtError(api->PJRT_Client_PlatformVersion(&args), api); @@ -609,18 +609,18 @@ absl::string_view GetPlatformName(PJRT_Client* client, const PJRT_Api* api) { PJRT_Client_PlatformName_Args args; args.client = client; args.struct_size = PJRT_Client_PlatformName_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; pjrt::LogFatalIfPjrtError(api->PJRT_Client_PlatformName(&args), api); absl::string_view platform_name(args.platform_name, args.platform_name_size); return platform_name; } -xla::StatusOr GetTopologyDescription( +absl::StatusOr GetTopologyDescription( PJRT_Client* client, const PJRT_Api* api) { PJRT_Client_TopologyDescription_Args args; args.struct_size = PJRT_Client_TopologyDescription_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.client = client; RETURN_STATUS_IF_PJRT_ERROR(api->PJRT_Client_TopologyDescription(&args), api); return args.topology; @@ -659,7 +659,7 @@ PJRT_DeviceDescription* GetDeviceDescription(const PJRT_Api* api, PJRT_Device* device) { PJRT_Device_GetDescription_Args args; args.struct_size = PJRT_Device_GetDescription_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.device = device; pjrt::LogFatalIfPjrtError(api->PJRT_Device_GetDescription(&args), api); return args.device_description; @@ -669,7 +669,7 @@ absl::Span GetAddressableMemories(const PJRT_Api* api, PJRT_Device* device) { PJRT_Device_AddressableMemories_Args args; args.struct_size = PJRT_Device_AddressableMemories_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.device = device; pjrt::LogFatalIfPjrtError(api->PJRT_Device_AddressableMemories(&args), api); return absl::MakeSpan(args.memories, args.num_memories); @@ -687,7 +687,7 @@ static void PjRtValueDeleterCallback(char* value) { delete[] value; } static PJRT_KeyValueGetCFunc ToKVGetCFunc( xla::KeyValueStoreInterface* kv_store) { return [kv_store](PJRT_KeyValueGetCallback_Args* args) -> PJRT_Error* { - xla::StatusOr output = + absl::StatusOr output = kv_store->Get(std::string_view(args->key, args->key_size), absl::Milliseconds(args->timeout_in_ms)); if (!output.ok()) { @@ -807,7 +807,7 @@ PJRT_RecvCallbackInfo CppRecvCallbackToCRecvCallback( }}; } -xla::StatusOr ConvertToBufferMemoryLayoutData( +absl::StatusOr ConvertToBufferMemoryLayoutData( const xla::Layout& cpp_layout) { BufferMemoryLayoutData layout_data; layout_data.c_layout.type = @@ -834,7 +834,7 @@ xla::StatusOr ConvertToBufferMemoryLayoutData( return layout_data; } -xla::StatusOr ConvertToBufferMemoryLayoutData( +absl::StatusOr ConvertToBufferMemoryLayoutData( absl::Span byte_strides) { BufferMemoryLayoutData layout_data; layout_data.c_layout.type = @@ -844,7 +844,7 @@ xla::StatusOr ConvertToBufferMemoryLayoutData( return layout_data; } -xla::StatusOr ConvertToLayout( +absl::StatusOr ConvertToLayout( const PJRT_Buffer_MemoryLayout_Tiled& c_tiled) { absl::Span minor_to_major(c_tiled.minor_to_major, c_tiled.minor_to_major_size); @@ -864,7 +864,7 @@ xla::StatusOr ConvertToLayout( PJRT_Buffer_Type GetElementType(const PJRT_Api* api, PJRT_Buffer* buffer) { PJRT_Buffer_ElementType_Args args; args.struct_size = PJRT_Buffer_ElementType_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer; LogFatalIfPjrtError(api->PJRT_Buffer_ElementType(&args), api); return args.type; @@ -874,7 +874,7 @@ absl::Span GetDimensions(const PJRT_Api* api, PJRT_Buffer* buffer) { PJRT_Buffer_Dimensions_Args args; args.struct_size = PJRT_Buffer_Dimensions_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer; LogFatalIfPjrtError(api->PJRT_Buffer_Dimensions(&args), api); return {args.dims, args.num_dims}; @@ -884,16 +884,15 @@ PJRT_Buffer_MemoryLayout GetMemoryLayout(const PJRT_Api* api, PJRT_Buffer* buffer) { PJRT_Buffer_GetMemoryLayout_Args args; args.struct_size = PJRT_Buffer_GetMemoryLayout_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer; LogFatalIfPjrtError(api->PJRT_Buffer_GetMemoryLayout(&args), api); return args.layout; } -xla::StatusOr BuildXlaShapeFromC(PJRT_Buffer_Type element_type, - const int64_t* dims, - size_t num_dims, - PJRT_Buffer_MemoryLayout* layout) { +absl::StatusOr BuildXlaShapeFromC( + PJRT_Buffer_Type element_type, const int64_t* dims, size_t num_dims, + PJRT_Buffer_MemoryLayout* layout) { xla::Shape shape = xla::ShapeUtil::MakeShape(ConvertFromPjRtBufferType(element_type), absl::Span(dims, num_dims)); @@ -925,7 +924,7 @@ absl::string_view PlatformName(const PJRT_Api* api, const PJRT_TopologyDescription* topo_desc) { PJRT_TopologyDescription_PlatformName_Args args; args.struct_size = PJRT_TopologyDescription_PlatformName_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.topology = const_cast(topo_desc); LogFatalIfPjrtError(api->PJRT_TopologyDescription_PlatformName(&args), api); return {args.platform_name, args.platform_name_size}; @@ -936,7 +935,7 @@ absl::Span DeviceDescriptions( PJRT_TopologyDescription_GetDeviceDescriptions_Args args; args.struct_size = PJRT_TopologyDescription_GetDeviceDescriptions_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.topology = const_cast(topo_desc); LogFatalIfPjrtError( api->PJRT_TopologyDescription_GetDeviceDescriptions(&args), api); @@ -947,7 +946,7 @@ absl::StatusOr GetCompiledMemoryStats( const PJRT_Api* api, PJRT_Executable* executable) { PJRT_Executable_GetCompiledMemoryStats_Args args; args.struct_size = PJRT_Executable_GetCompiledMemoryStats_Args_STRUCT_SIZE; - args.priv = 0; + args.extension_start = nullptr; args.executable = executable; RETURN_STATUS_IF_PJRT_ERROR( api->PJRT_Executable_GetCompiledMemoryStats(&args), api); diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.h index 720ef577d5b218..d619ed64ac3631 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.h @@ -138,7 +138,7 @@ xla::PjRtFuture ConvertCEventToCppFuture(PJRT_Event* c_event, // The data of returned variable-length PJRT_NamedValue list is backed by // `cpp_value_map`, so `cpp_value_map` must outlive the returned list. It will // raise errors for unsupported PjRtValueType. -xla::StatusOr> ConvertToPjRtNamedValueList( +absl::StatusOr> ConvertToPjRtNamedValueList( const absl::flat_hash_map& cpp_value_map, int api_minor_version); @@ -165,7 +165,7 @@ xla::Status ActualStructSizeIsGreaterOrEqual(absl::string_view struct_name, absl::string_view GetPlatformVersion(PJRT_Client* client, const PJRT_Api* api); absl::string_view GetPlatformName(PJRT_Client* client, const PJRT_Api* api); -xla::StatusOr GetTopologyDescription( +absl::StatusOr GetTopologyDescription( PJRT_Client* client, const PJRT_Api* api); // Releases `chunk`. @@ -245,12 +245,12 @@ struct BufferMemoryLayoutData { std::vector tile_dims; std::vector tile_dim_sizes; }; -xla::StatusOr ConvertToBufferMemoryLayoutData( +absl::StatusOr ConvertToBufferMemoryLayoutData( const xla::Layout& cpp_layout); -xla::StatusOr ConvertToBufferMemoryLayoutData( +absl::StatusOr ConvertToBufferMemoryLayoutData( absl::Span byte_strides); -xla::StatusOr ConvertToLayout( +absl::StatusOr ConvertToLayout( const PJRT_Buffer_MemoryLayout_Tiled& c_tiled); PJRT_Buffer_Type GetElementType(const PJRT_Api* api, PJRT_Buffer* buffer); @@ -259,10 +259,10 @@ absl::Span GetDimensions(const PJRT_Api* api, PJRT_Buffer_MemoryLayout GetMemoryLayout(const PJRT_Api* api, PJRT_Buffer* buffer); -xla::StatusOr BuildXlaShapeFromC(PJRT_Buffer_Type element_type, - const int64_t* dims, - size_t num_dims, - PJRT_Buffer_MemoryLayout* layout); +absl::StatusOr BuildXlaShapeFromC(PJRT_Buffer_Type element_type, + const int64_t* dims, + size_t num_dims, + PJRT_Buffer_MemoryLayout* layout); absl::string_view PlatformName(const PJRT_Api* api, const PJRT_TopologyDescription* topo_desc); diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers_test.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers_test.cc index 39b25f5a3741b9..2575fa4ff9eb5d 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers_test.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers_test.cc @@ -110,7 +110,7 @@ TEST(PjRtCApiHelperTest, InvalidOptionName) { auto status = ValidateCreateOptions(invalid_map, expected); - EXPECT_NE(status, tsl::OkStatus()); + EXPECT_NE(status, absl::OkStatus()); EXPECT_THAT(status.message(), HasSubstr("Unexpected option name passed to PJRT_Client_Create")); } @@ -125,7 +125,7 @@ TEST(PjRtCApiHelperTest, InvalidOptionTypeIndex) { auto status = ValidateCreateOptions(invalid_map, expected); - EXPECT_NE(status, tsl::OkStatus()); + EXPECT_NE(status, absl::OkStatus()); EXPECT_THAT(status.message(), HasSubstr("Option passed to PJRT_Client_Create with name string " "has type index 2 but expected type index is 0")); @@ -149,7 +149,7 @@ TEST(PjRtCApiHelperTest, Callback) { TEST(PjRtCApiHelperTest, ConvertToCLayoutFromStrides) { std::vector strides = {4, 8}; - xla::StatusOr layout_data = + absl::StatusOr layout_data = ConvertToBufferMemoryLayoutData(strides); EXPECT_TRUE(layout_data.ok()); diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_profiler_extension.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_profiler_extension.h index 8cd54a92645bc4..684c699feae554 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_profiler_extension.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_profiler_extension.h @@ -26,7 +26,7 @@ extern "C" { #define PJRT_API_PROFILER_EXTENSION_VERSION 0 typedef struct PJRT_Profiler_Extension { - PJRT_Structure_Type type; + PJRT_Extension_Type type; const void* next; PLUGIN_Profiler_Api* profiler_api; } PJRT_Profiler_Extension; diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_test.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_test.cc index 8bfd4bfc82c732..573a43c7aa6865 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_test.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_test.cc @@ -149,7 +149,7 @@ TEST_F(PjrtCApiTest, PlatformName) { PJRT_Client_PlatformName_Args args; args.client = client_; args.struct_size = PJRT_Client_PlatformName_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; PJRT_Error* error = api_->PJRT_Client_PlatformName(&args); ASSERT_EQ(error, nullptr); absl::string_view platform_name(args.platform_name, args.platform_name_size); @@ -160,7 +160,7 @@ TEST_F(PjrtCApiTest, ClientProcessIndex) { PJRT_Client_ProcessIndex_Args process_index_args = PJRT_Client_ProcessIndex_Args{ .struct_size = PJRT_Client_ProcessIndex_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .client = client_, .process_index = -1, }; @@ -199,7 +199,7 @@ TEST_F(PjrtCApiTest, LookupDevice) { PJRT_Client_LookupDevice_Args lookup_device_args = PJRT_Client_LookupDevice_Args{ .struct_size = PJRT_Client_LookupDevice_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .client = client_, .id = 0, .device = nullptr, @@ -217,7 +217,7 @@ TEST_F(PjrtCApiTest, LookupAddressableDevice) { PJRT_Client_LookupAddressableDevice_Args lookup_addressable_device_args = PJRT_Client_LookupAddressableDevice_Args{ .struct_size = PJRT_Client_LookupAddressableDevice_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .client = client_, .local_hardware_id = 0, .addressable_device = nullptr, @@ -239,7 +239,7 @@ TEST_F(PjrtCApiTest, GetDefaultDeviceAssignmentNominal) { std::vector assignment_buffer(kNumReplicas * kNumPartitions); PJRT_Client_DefaultDeviceAssignment_Args args{ .struct_size = PJRT_Client_DefaultDeviceAssignment_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .client = client_, .num_replicas = kNumReplicas, .num_partitions = kNumPartitions, @@ -257,7 +257,7 @@ TEST_F(PjrtCApiTest, GetDefaultDeviceAssignmentBufferTooSmall) { std::vector assignment_buffer(kBufferSize); PJRT_Client_DefaultDeviceAssignment_Args args{ .struct_size = PJRT_Client_DefaultDeviceAssignment_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .client = client_, .num_replicas = kNumReplicas, .num_partitions = kNumPartitions, @@ -276,7 +276,7 @@ TEST_F(PjrtCApiTest, GetDefaultDeviceAssignmentBufferTooSmall) { TEST_F(PjrtCApiTest, LookupDeviceNegativeId) { PJRT_Client_LookupDevice_Args args = PJRT_Client_LookupDevice_Args{ .struct_size = PJRT_Client_LookupDevice_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .client = client_, .id = -1, .device = nullptr, @@ -296,7 +296,7 @@ TEST_F(PjrtCApiTest, LookupDeviceOutOfRangeId) { int out_of_range_id = GetNumDevices(); PJRT_Client_LookupDevice_Args args = PJRT_Client_LookupDevice_Args{ .struct_size = PJRT_Client_LookupDevice_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .client = client_, .id = out_of_range_id, .device = nullptr, @@ -318,7 +318,7 @@ void destroy_executable(PJRT_LoadedExecutable* executable, const PJRT_Api* api) { PJRT_LoadedExecutable_Destroy_Args args{ .struct_size = PJRT_LoadedExecutable_Destroy_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .executable = executable, }; PJRT_Error* error = api->PJRT_LoadedExecutable_Destroy(&args); @@ -345,7 +345,7 @@ TEST_F(PjrtCApiTest, BufferTransferImmutableUntilTransferCompletes) { PJRT_Event_Await_Args await_args; await_args.struct_size = PJRT_Event_Await_Args_STRUCT_SIZE; - await_args.priv = nullptr; + await_args.extension_start = nullptr; await_args.event = event.get(); PJRT_Error* event_error = api_->PJRT_Event_Await(&await_args); ASSERT_EQ(event_error, nullptr); @@ -354,7 +354,7 @@ TEST_F(PjrtCApiTest, BufferTransferImmutableUntilTransferCompletes) { TEST_F(PjrtCApiTest, Compile) { PJRT_Client_Compile_Args args = PJRT_Client_Compile_Args{ .struct_size = PJRT_Client_Compile_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .client = client_, }; std::string options_str = BuildSingleDeviceCompileOptionStr(); @@ -365,7 +365,7 @@ TEST_F(PjrtCApiTest, Compile) { std::string program_code{module_add_one}; PJRT_Program program = PJRT_Program{ .struct_size = PJRT_Program_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .code = program_code.data(), .code_size = program_code.length(), .format = format.c_str(), @@ -383,7 +383,7 @@ TEST_F(PjrtCApiTest, Compile) { TEST_F(PjrtCApiTest, CompileXlaComputation) { PJRT_Client_Compile_Args args = PJRT_Client_Compile_Args{ .struct_size = PJRT_Client_Compile_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .client = client_, }; xla::DeviceAssignment device_assignment(1, 1); @@ -403,7 +403,7 @@ TEST_F(PjrtCApiTest, CompileXlaComputation) { std::string format(::pjrt::kHloFormat); PJRT_Program program = PJRT_Program{ .struct_size = PJRT_Program_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .code = module_str.data(), .code_size = module_str.size(), .format = format.c_str(), @@ -421,7 +421,7 @@ TEST_F(PjrtCApiTest, CompileXlaComputation) { TEST_F(PjrtCApiTest, CompileInvalidOption) { PJRT_Client_Compile_Args args = PJRT_Client_Compile_Args{ .struct_size = PJRT_Client_Compile_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .client = client_, }; std::string options_str = "invalid compile options"; @@ -432,7 +432,7 @@ TEST_F(PjrtCApiTest, CompileInvalidOption) { std::string program_code{module_add_one}; PJRT_Program program = PJRT_Program{ .struct_size = PJRT_Program_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .code = program_code.data(), .code_size = program_code.length(), .format = format.c_str(), @@ -453,7 +453,7 @@ TEST_F(PjrtCApiTest, CompileInvalidOption) { TEST_F(PjrtCApiTest, CompileInvalidProgramFormat) { PJRT_Client_Compile_Args args = PJRT_Client_Compile_Args{ .struct_size = PJRT_Client_Compile_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .client = client_, }; xla::DeviceAssignment device_assignment(1, 1); @@ -468,7 +468,7 @@ TEST_F(PjrtCApiTest, CompileInvalidProgramFormat) { std::string format("invalid"); PJRT_Program program = PJRT_Program{ .struct_size = PJRT_Program_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .code = nullptr, .code_size = 0, .format = format.c_str(), @@ -498,7 +498,7 @@ TEST_F(PjrtCApiTest, DeviceProcessIndex) { PJRT_DeviceDescription_ProcessIndex_Args args = PJRT_DeviceDescription_ProcessIndex_Args{ .struct_size = PJRT_DeviceDescription_ProcessIndex_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .device_description = ::pjrt::GetDeviceDescription(api_, GetClientDevices()[0]), .process_index = -1, @@ -512,7 +512,7 @@ TEST_F(PjrtCApiTest, DeviceProcessIndex) { TEST_F(PjrtCApiTest, DeviceIsAddressable) { PJRT_Device_IsAddressable_Args args = PJRT_Device_IsAddressable_Args{ .struct_size = PJRT_Device_IsAddressable_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .device = GetClientDevices()[0], .is_addressable = false, }; @@ -525,7 +525,7 @@ TEST_F(PjrtCApiTest, DeviceIsAddressable) { TEST_F(PjrtCApiTest, DeviceLocalHardwareId) { PJRT_Device_LocalHardwareId_Args args = PJRT_Device_LocalHardwareId_Args{ .struct_size = PJRT_Device_LocalHardwareId_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .device = GetClientDevices()[0], .local_hardware_id = -1, }; @@ -565,7 +565,7 @@ class PjrtCApiBufferTest : public PjrtCApiTest { TEST_F(PjrtCApiBufferTest, IsDeleted) { PJRT_Buffer_IsDeleted_Args is_deleted_args; is_deleted_args.struct_size = PJRT_Buffer_IsDeleted_Args_STRUCT_SIZE; - is_deleted_args.priv = nullptr; + is_deleted_args.extension_start = nullptr; is_deleted_args.buffer = buffer_.get(); PJRT_Error* is_deleted_error = api_->PJRT_Buffer_IsDeleted(&is_deleted_args); ASSERT_EQ(is_deleted_error, nullptr); @@ -573,7 +573,7 @@ TEST_F(PjrtCApiBufferTest, IsDeleted) { PJRT_Buffer_Delete_Args delete_args; delete_args.struct_size = PJRT_Buffer_Delete_Args_STRUCT_SIZE; - delete_args.priv = nullptr; + delete_args.extension_start = nullptr; delete_args.buffer = buffer_.get(); PJRT_Error* delete_error = api_->PJRT_Buffer_Delete(&delete_args); ASSERT_EQ(delete_error, nullptr); @@ -586,7 +586,7 @@ TEST_F(PjrtCApiBufferTest, IsDeleted) { TEST_F(PjrtCApiBufferTest, GetOnDeviceSizeInBytes) { PJRT_Buffer_OnDeviceSizeInBytes_Args args; args.struct_size = PJRT_Buffer_OnDeviceSizeInBytes_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); PJRT_Error* on_device_size_bytes_error = api_->PJRT_Buffer_OnDeviceSizeInBytes(&args); @@ -598,7 +598,7 @@ TEST_F(PjrtCApiBufferTest, GetOnDeviceSizeInBytes) { TEST_F(PjrtCApiBufferTest, ReadyEvent) { PJRT_Buffer_ReadyEvent_Args get_event_args; get_event_args.struct_size = PJRT_Buffer_ReadyEvent_Args_STRUCT_SIZE; - get_event_args.priv = nullptr; + get_event_args.extension_start = nullptr; get_event_args.buffer = buffer_.get(); auto error = ToUniquePtr(api_->PJRT_Buffer_ReadyEvent(&get_event_args)); ASSERT_EQ(error, nullptr); @@ -609,7 +609,7 @@ TEST_F(PjrtCApiBufferTest, ReadyEvent) { // Wait for `buffer_`'s data transfer to complete (if it hasn't already) PJRT_Event_Await_Args await_args; await_args.struct_size = PJRT_Event_Await_Args_STRUCT_SIZE; - await_args.priv = nullptr; + await_args.extension_start = nullptr; await_args.event = event; error.reset(api_->PJRT_Event_Await(&await_args)); ASSERT_EQ(error, nullptr); @@ -617,7 +617,7 @@ TEST_F(PjrtCApiBufferTest, ReadyEvent) { // Must be ready when `PJRT_Event_Await` completes PJRT_Event_IsReady_Args ready_args; ready_args.struct_size = PJRT_Event_IsReady_Args_STRUCT_SIZE; - ready_args.priv = nullptr; + ready_args.extension_start = nullptr; ready_args.event = event; error.reset(api_->PJRT_Event_IsReady(&ready_args)); ASSERT_EQ(error, nullptr); @@ -626,7 +626,7 @@ TEST_F(PjrtCApiBufferTest, ReadyEvent) { // Clean up PJRT_Event_Destroy_Args destroy_args; destroy_args.struct_size = PJRT_Event_Destroy_Args_STRUCT_SIZE; - destroy_args.priv = nullptr; + destroy_args.extension_start = nullptr; destroy_args.event = event; error.reset(api_->PJRT_Event_Destroy(&destroy_args)); EXPECT_EQ(error, nullptr); @@ -635,7 +635,7 @@ TEST_F(PjrtCApiBufferTest, ReadyEvent) { TEST_F(PjrtCApiBufferTest, ToHostBufferNoHostLayout) { PJRT_Buffer_ToHostBuffer_Args args; args.struct_size = PJRT_Buffer_ToHostBuffer_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.src = buffer_.get(); xla::Shape host_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); auto literal = std::make_shared(host_shape); @@ -661,7 +661,7 @@ TEST_F(PjrtCApiBufferTest, IncreaseAndDecreaseReferenceCount) { PJRT_Buffer_IncreaseExternalReferenceCount_Args increase_reference_count_args; increase_reference_count_args.struct_size = PJRT_Buffer_IncreaseExternalReferenceCount_Args_STRUCT_SIZE; - increase_reference_count_args.priv = nullptr; + increase_reference_count_args.extension_start = nullptr; increase_reference_count_args.buffer = buffer_.get(); PJRT_Error* increase_reference_count_error = api_->PJRT_Buffer_IncreaseExternalReferenceCount( @@ -671,7 +671,7 @@ TEST_F(PjrtCApiBufferTest, IncreaseAndDecreaseReferenceCount) { PJRT_Buffer_DecreaseExternalReferenceCount_Args decrease_reference_count_args; decrease_reference_count_args.struct_size = PJRT_Buffer_DecreaseExternalReferenceCount_Args_STRUCT_SIZE; - decrease_reference_count_args.priv = nullptr; + decrease_reference_count_args.extension_start = nullptr; decrease_reference_count_args.buffer = buffer_.get(); PJRT_Error* decrease_reference_error = api_->PJRT_Buffer_DecreaseExternalReferenceCount( @@ -683,7 +683,7 @@ TEST_F(PjrtCApiBufferTest, DecreaseReferenceCountReturnsError) { PJRT_Buffer_DecreaseExternalReferenceCount_Args args; args.struct_size = PJRT_Buffer_DecreaseExternalReferenceCount_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); auto error = ToUniquePtr(api_->PJRT_Buffer_DecreaseExternalReferenceCount(&args)); @@ -698,7 +698,7 @@ TEST_F(PjrtCApiBufferTest, DecreaseReferenceCountReturnsError) { TEST_F(PjrtCApiBufferTest, OpaqueDeviceMemoryDataPointer) { PJRT_Buffer_OpaqueDeviceMemoryDataPointer_Args args; args.struct_size = PJRT_Buffer_OpaqueDeviceMemoryDataPointer_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); PJRT_Error* error = api_->PJRT_Buffer_OpaqueDeviceMemoryDataPointer(&args); EXPECT_EQ(error, nullptr); diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_test_base.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_test_base.cc index 8eafe9382fb15e..aa383745be2794 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_test_base.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_test_base.cc @@ -42,7 +42,7 @@ namespace { PJRT_Client* CreateClient(const PJRT_Api* api) { PJRT_Client_Create_Args create_args; create_args.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE; - create_args.priv = nullptr; + create_args.extension_start = nullptr; create_args.create_options = nullptr; create_args.num_options = 0; create_args.kv_get_callback = nullptr; @@ -67,7 +67,7 @@ PjrtCApiTestBase::~PjrtCApiTestBase() { destroy_client(client_); } void PjrtCApiTestBase::destroy_client(PJRT_Client* client) { PJRT_Client_Destroy_Args destroy_args; destroy_args.struct_size = PJRT_Client_Destroy_Args_STRUCT_SIZE; - destroy_args.priv = nullptr; + destroy_args.extension_start = nullptr; destroy_args.client = client; PJRT_Error* error = api_->PJRT_Client_Destroy(&destroy_args); CHECK_EQ(error, nullptr); @@ -76,7 +76,7 @@ void PjrtCApiTestBase::destroy_client(PJRT_Client* client) { int PjrtCApiTestBase::GetDeviceId(PJRT_DeviceDescription* device_desc) const { PJRT_DeviceDescription_Id_Args args = PJRT_DeviceDescription_Id_Args{ .struct_size = PJRT_DeviceDescription_Id_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .device_description = device_desc, .id = -1, }; @@ -96,7 +96,7 @@ bool PjrtCApiTestBase::IsValidDeviceId(PJRT_Device* device) const { int PjrtCApiTestBase::GetLocalHardwareId(PJRT_Device* device) const { PJRT_Device_LocalHardwareId_Args args = PJRT_Device_LocalHardwareId_Args{ .struct_size = PJRT_Device_LocalHardwareId_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .device = device, .local_hardware_id = -1, }; @@ -108,7 +108,7 @@ int PjrtCApiTestBase::GetLocalHardwareId(PJRT_Device* device) const { absl::Span PjrtCApiTestBase::GetClientDevices() const { PJRT_Client_Devices_Args dev_args; dev_args.struct_size = PJRT_Client_Devices_Args_STRUCT_SIZE; - dev_args.priv = nullptr; + dev_args.extension_start = nullptr; dev_args.client = client_; PJRT_Error* error = api_->PJRT_Client_Devices(&dev_args); CHECK(error == nullptr); @@ -136,7 +136,7 @@ absl::Span PjrtCApiTestBase::GetClientAddressableDevices() const { PJRT_Client_AddressableDevices_Args addr_args; addr_args.struct_size = PJRT_Client_AddressableDevices_Args_STRUCT_SIZE; - addr_args.priv = nullptr; + addr_args.extension_start = nullptr; addr_args.client = client_; PJRT_Error* error = api_->PJRT_Client_AddressableDevices(&addr_args); CHECK(error == nullptr); @@ -151,7 +151,7 @@ PjrtCApiTestBase::CreateBufferFromHostBufferArgs( PJRT_Device* device) { PJRT_Client_BufferFromHostBuffer_Args args; args.struct_size = PJRT_Client_BufferFromHostBuffer_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.data = data.data(); args.type = ::pjrt::ConvertToPjRtBufferType(shape.element_type()); @@ -195,7 +195,7 @@ PjrtCApiTestBase::create_buffer(PJRT_Device* device) { PJRT_Buffer_ReadyEvent_Args get_event_args; get_event_args.struct_size = PJRT_Buffer_ReadyEvent_Args_STRUCT_SIZE; - get_event_args.priv = nullptr; + get_event_args.extension_start = nullptr; get_event_args.buffer = buffer.get(); auto ready_event_error = ToUniquePtr(api_->PJRT_Buffer_ReadyEvent(&get_event_args)); diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc index ef8387d7365d5d..e405a5b2e1a207 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc @@ -120,7 +120,7 @@ static xla::Status PopulateExecutableCostAnalysis(PJRT_Executable* executable) { std::string& property_name = cost_analysis_names[i]; cost_analysis_property.struct_size = PJRT_NamedValue_STRUCT_SIZE; - cost_analysis_property.priv = nullptr; + cost_analysis_property.extension_start = nullptr; property_name = property.first; cost_analysis_property.name = property_name.c_str(); @@ -516,7 +516,7 @@ static void PopulatePjrtExecutableAddressableDevices( namespace { -xla::StatusOr ParseCompileOptions( +absl::StatusOr ParseCompileOptions( absl::string_view options_str) { xla::CompileOptionsProto options_proto; // Open source ParseFromString doesn't support string_view. @@ -529,7 +529,7 @@ xla::StatusOr ParseCompileOptions( using ProgramVariant = std::variant, xla::XlaComputation>; -xla::StatusOr< +absl::StatusOr< std::variant, xla::XlaComputation>> ParsePjrtProgram(std::optional& context, const PJRT_Program* program) { @@ -1075,8 +1075,8 @@ static xla::Status VerifyOptimizedProgramArgs( return xla::OkStatus(); } -static xla::StatusOr> GetOptimizedProgramModule( - const PJRT_Executable_OptimizedProgram_Args* args) { +static absl::StatusOr> +GetOptimizedProgramModule(const PJRT_Executable_OptimizedProgram_Args* args) { TF_ASSIGN_OR_RETURN(std::vector> hlo_modules, args->executable->get()->GetHloModules()); if (hlo_modules.empty()) { @@ -1272,7 +1272,7 @@ static xla::SendCallback CSendCallbackToCpp( std::unique_ptr error(callback( &c_chunk, &c_callback_error, total_size_in_bytes, done, user_arg)); if (error == nullptr) { - return tsl::OkStatus(); + return absl::OkStatus(); } return error->status; }}; @@ -1936,7 +1936,7 @@ PJRT_Error* PJRT_Event_Error(PJRT_Event_Error_Args* args) { if (!event->status.has_value()) { PJRT_Event_Await_Args await_args; await_args.struct_size = PJRT_Event_Await_Args_STRUCT_SIZE; - await_args.priv = nullptr; + await_args.extension_start = nullptr; await_args.event = event; return PJRT_Event_Await(&await_args); } @@ -2077,7 +2077,7 @@ static std::vector PopulatePjrtAttributes( for (auto const& [name, value] : attributes) { PJRT_NamedValue& cur_attribute = c_attributes[ind]; cur_attribute.struct_size = PJRT_NamedValue_STRUCT_SIZE; - cur_attribute.priv = nullptr; + cur_attribute.extension_start = nullptr; cur_attribute.name = name.c_str(); cur_attribute.name_size = name.size(); if (const std::string* string_val = std::get_if(&value)) { @@ -2173,9 +2173,9 @@ static void AttachDevicesAndMemories(PJRT_Client* c_client) { } } -static xla::StatusOr> +static absl::StatusOr> GetStatusOrTopologyDescription(const xla::PjRtClient& cpp_client) { - xla::StatusOr status_or_cpp_topo = + absl::StatusOr status_or_cpp_topo = cpp_client.GetTopologyDescription(); if (!status_or_cpp_topo.ok()) { return status_or_cpp_topo.status(); diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.h index d277b9fbecc7a9..e679c1d4f27a28 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.h @@ -78,7 +78,7 @@ struct PJRT_Client { // `owned_memories`. absl::flat_hash_map c_memory_from_cpp_memory; - xla::StatusOr> topology; + absl::StatusOr> topology; explicit PJRT_Client(std::unique_ptr cpp_client); }; @@ -112,7 +112,7 @@ struct PJRT_Executable { // Must be shared_ptr so that we can share with PJRT_LoadedExecutable. std::shared_ptr executable; - xla::StatusOr fingerprint; + absl::StatusOr fingerprint; // Used to synchronize concurrent setting of cached values. mutable absl::Mutex mutex; diff --git a/third_party/xla/xla/pjrt/compile_options.proto b/third_party/xla/xla/pjrt/compile_options.proto index 3cb7195c2e2d5d..4ea4af933e9367 100644 --- a/third_party/xla/xla/pjrt/compile_options.proto +++ b/third_party/xla/xla/pjrt/compile_options.proto @@ -7,7 +7,7 @@ import "xla/xla.proto"; import "xla/xla_data.proto"; // A serialization of xla::ExecutableBuildOptions. -// Next id: 18. +// Next id: 19. message ExecutableBuildOptionsProto { // If set, this is the device to build the computation for. Valid // device_ordinal values are: 0 to # of devices - 1. These values are @@ -65,6 +65,18 @@ message ExecutableBuildOptionsProto { // which can be used to compile post-optimizations HLO modules. bool run_backend_only = 11; + // Allows sharding propagation to propagate to the parameters. This changes + // the input shape of the computation (which is undesirable), but it can be + // used to allow to run partial compilation to determine what would be the + // input sharding of a computation if XLA would be allowed to propagate the + // sharding which can be used by higher level framework as a way to query + // intermediate sharding of operations when multiple computation would be + // chained and merged together. + // This is a vector of bool, because the user can control which parameters can + // have the sharding substituted. If only one boolean value is passed in the + // vector that is interpreted as the value to be applied for every parameter. + repeated bool allow_spmd_sharding_propagation_to_parameters = 18; + // Allows sharding propagation to propagate to the outputs. This changes the // output shape of the computation (which is undesirable), but it can be used // to allow to run partial compilation to determine what would be the output diff --git a/third_party/xla/xla/pjrt/cpu/BUILD b/third_party/xla/xla/pjrt/cpu/BUILD index 1b8d2aca5f208c..f736d66bfeb4bf 100644 --- a/third_party/xla/xla/pjrt/cpu/BUILD +++ b/third_party/xla/xla/pjrt/cpu/BUILD @@ -3,7 +3,8 @@ load("@local_tsl//tsl/platform:build_config.bzl", "tf_proto_library") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//xla:internal"], licenses = ["notice"], ) @@ -21,7 +22,6 @@ cc_library( name = "tracked_tfrt_cpu_device_buffer", srcs = ["tracked_tfrt_cpu_device_buffer.cc"], hdrs = ["tracked_tfrt_cpu_device_buffer.h"], - visibility = ["//visibility:public"], deps = [ "//xla:cpu_function_runtime", "//xla:shape_util", @@ -54,7 +54,9 @@ cc_library( name = "abstract_tfrt_cpu_buffer", srcs = ["abstract_tfrt_cpu_buffer.cc"], hdrs = ["abstract_tfrt_cpu_buffer.h"], - visibility = ["//visibility:public"], + visibility = [ + "//xla:friends", + ], deps = [ ":tracked_tfrt_cpu_device_buffer", "//xla:cpu_function_runtime", @@ -106,7 +108,6 @@ cc_library( name = "cpu_topology", srcs = ["cpu_topology.cc"], hdrs = ["cpu_topology.h"], - visibility = ["//visibility:public"], deps = [ ":cpu_topology_proto_cc", "@com_google_absl//absl/types:span", @@ -128,7 +129,9 @@ cc_library( name = "cpu_client", srcs = ["cpu_client.cc"], hdrs = ["cpu_client.h"], - visibility = ["//visibility:public"], + visibility = [ + "//xla:friends", + ], deps = [ ":abstract_tfrt_cpu_buffer", ":cpu_topology", @@ -245,7 +248,6 @@ cc_library( "-fno-strict-aliasing", ], features = ["-use_header_modules"], - visibility = ["//visibility:public"], deps = [ "//third_party/gloo", "//xla/pjrt:status_casters", @@ -263,7 +265,6 @@ cc_library( "-fno-strict-aliasing", ], features = ["-use_header_modules"], - visibility = ["//visibility:public"], deps = [ "//third_party/gloo", "//xla:shape_util", diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client.cc b/third_party/xla/xla/pjrt/cpu/cpu_client.cc index 1e3a3e65ce9754..ad45573a9f0e33 100644 --- a/third_party/xla/xla/pjrt/cpu/cpu_client.cc +++ b/third_party/xla/xla/pjrt/cpu/cpu_client.cc @@ -645,15 +645,15 @@ static StatusOr> JitCompile( const XlaComputation& computation, const absl::Span argument_layouts, const ExecutableBuildOptions& build_options, - const ExecutionOptions& execution_options) { + const ExecutionOptions& execution_options, + const xla::Compiler::CompileOptions& compile_options, int num_threads) { TF_ASSIGN_OR_RETURN(ProgramShape program_shape, computation.GetProgramShape()); // Unoptimized HloModuleConfig. TF_ASSIGN_OR_RETURN( std::unique_ptr hlo_module_config, CreateModuleConfig(program_shape, argument_layouts, &execution_options, - execution_options.num_replicas(), - /*num_threads=*/std::nullopt, + execution_options.num_replicas(), num_threads, /*aot_options=*/nullptr)); // Unoptimized HloModule. @@ -669,14 +669,13 @@ static StatusOr> JitCompile( bool allow_sparse_shapes = hlo_module->config().debug_options().xla_cpu_use_xla_runtime(); cpu::CpuCompiler compiler(allow_sparse_shapes); - xla::Compiler::CompileOptions dummy; - TF_ASSIGN_OR_RETURN(hlo_module, - compiler.RunHloPasses(std::move(hlo_module), - /*stream_exec=*/nullptr, dummy)); + TF_ASSIGN_OR_RETURN(hlo_module, compiler.RunHloPasses(std::move(hlo_module), + /*stream_exec=*/nullptr, + compile_options)); // Run backend. return compiler.RunBackend(std::move(hlo_module), /*stream_exec=*/nullptr, - dummy); + compile_options); } StatusOr> TfrtCpuClient::Compile( @@ -758,9 +757,14 @@ StatusOr> TfrtCpuClient::Compile( computation.GetProgramShape()); ExecutionOptions execution_options = CreateExecutionOptions(build_options, &program_shape); - TF_ASSIGN_OR_RETURN(std::unique_ptr cpu_executable, - JitCompile(computation, argument_layout_pointers, - build_options, execution_options)); + xla::Compiler::CompileOptions compile_options{ + build_options.device_allocator(), build_options.compile_thread_pool(), + build_options.layout_canonicalization_callback()}; + TF_ASSIGN_OR_RETURN( + std::unique_ptr cpu_executable, + JitCompile(computation, argument_layout_pointers, build_options, + execution_options, compile_options, + eigen_intraop_device()->getPool()->NumThreads())); auto cpu_executable_ptr = tensorflow::down_cast(cpu_executable.get()); diff --git a/third_party/xla/xla/pjrt/distributed/BUILD b/third_party/xla/xla/pjrt/distributed/BUILD index 907329eb8327e8..2e20fc7b2e8f70 100644 --- a/third_party/xla/xla/pjrt/distributed/BUILD +++ b/third_party/xla/xla/pjrt/distributed/BUILD @@ -6,21 +6,20 @@ load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") licenses(["notice"]) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//xla:internal"], ) tf_proto_library( name = "protocol_proto", srcs = ["protocol.proto"], cc_api_version = 2, - visibility = ["//visibility:public"], ) cc_library( name = "service", srcs = ["service.cc"], hdrs = ["service.h"], - visibility = ["//visibility:public"], deps = [ ":topology_util", ":util", @@ -68,7 +67,6 @@ cc_library( hdrs = [ "client.h", ], - visibility = ["//visibility:public"], deps = [ ":key_value_store_interface", ":util", @@ -96,7 +94,6 @@ cc_library( cc_library( name = "util", hdrs = ["util.h"], - visibility = ["//visibility:public"], deps = [ "//xla:status", ] + tsl_grpc_cc_dependencies(), @@ -106,11 +103,11 @@ cc_library( name = "distributed", srcs = ["distributed.cc"], hdrs = ["distributed.h"], - visibility = ["//visibility:public"], deps = [ ":client", ":service", "//xla:statusor", + "@local_tsl//tsl/platform:grpc_credentials", ] + tsl_grpc_cc_dependencies(), ) @@ -118,7 +115,6 @@ cc_library( name = "topology_util", srcs = ["topology_util.cc"], hdrs = ["topology_util.h"], - visibility = ["//visibility:public"], deps = [ ":key_value_store_interface", ":protocol_proto_cc", @@ -166,7 +162,6 @@ xla_cc_test( cc_library( name = "key_value_store_interface", hdrs = ["key_value_store_interface.h"], - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -178,7 +173,6 @@ cc_library( name = "in_memory_key_value_store", srcs = ["in_memory_key_value_store.cc"], hdrs = ["in_memory_key_value_store.h"], - visibility = ["//visibility:public"], deps = [ ":key_value_store_interface", "@com_google_absl//absl/container:flat_hash_map", diff --git a/third_party/xla/xla/pjrt/distributed/distributed.cc b/third_party/xla/xla/pjrt/distributed/distributed.cc index 2c3977dd238de7..dfd6b09e9d3690 100644 --- a/third_party/xla/xla/pjrt/distributed/distributed.cc +++ b/third_party/xla/xla/pjrt/distributed/distributed.cc @@ -15,27 +15,32 @@ limitations under the License. #include "xla/pjrt/distributed/distributed.h" +#include #include -#include "grpcpp/grpcpp.h" +#include "grpcpp/channel.h" +#include "grpcpp/create_channel.h" #include "xla/pjrt/distributed/client.h" #include "xla/pjrt/distributed/service.h" +#include "xla/statusor.h" +#include "tsl/platform/grpc_credentials.h" namespace xla { +// In OSS, insecure credentials are used as default. +constexpr bool kVerifySecureCredentials = false; + StatusOr> GetDistributedRuntimeService(std::string address, const CoordinationServiceImpl::Options& options) { - auto credentials = ::grpc::InsecureServerCredentials(); - return DistributedRuntimeService::Get(address, credentials, options); + return DistributedRuntimeService::Get( + address, tsl::GetServerCredentials(kVerifySecureCredentials), options); } std::shared_ptr GetDistributedRuntimeClient( std::string address, const DistributedRuntimeClient::Options& options) { - std::shared_ptr<::grpc::ChannelCredentials> creds = - ::grpc::InsecureChannelCredentials(); - std::shared_ptr<::grpc::Channel> channel = - ::grpc::CreateChannel(address, creds); + std::shared_ptr channel = grpc::CreateChannel( + address, tsl::GetClientCredentials(kVerifySecureCredentials)); return GetDistributedRuntimeClient(channel, options); } diff --git a/third_party/xla/xla/pjrt/event_pool.cc b/third_party/xla/xla/pjrt/event_pool.cc index 1db130633c2226..fb3364917967a7 100644 --- a/third_party/xla/xla/pjrt/event_pool.cc +++ b/third_party/xla/xla/pjrt/event_pool.cc @@ -33,7 +33,7 @@ EventPool::Handle::~Handle() { EventPool::EventPool(bool allow_reuse) : allow_reuse_(allow_reuse), next_sequence_number_(1) {} -StatusOr EventPool::AllocateEvent( +absl::StatusOr EventPool::AllocateEvent( se::StreamExecutor* executor) { Handle event; @@ -58,7 +58,7 @@ void EventPool::ThenRecordEvent(se::Stream* stream, EventPool::Handle& handle) { handle.sequence_number_ = next_sequence_number_++; } -StatusOr EventPool::ThenAllocateAndRecordEvent( +absl::StatusOr EventPool::ThenAllocateAndRecordEvent( se::Stream* stream) { TF_ASSIGN_OR_RETURN(EventPool::Handle handle, AllocateEvent(stream->parent())); diff --git a/third_party/xla/xla/pjrt/event_pool.h b/third_party/xla/xla/pjrt/event_pool.h index 1286cdbf73541c..89b8f6d8161a7b 100644 --- a/third_party/xla/xla/pjrt/event_pool.h +++ b/third_party/xla/xla/pjrt/event_pool.h @@ -76,11 +76,11 @@ class EventPool { // such as cudaStreamWaitEvent capture the state of the event at the time of // the host-side call and are not affected by a later host-side // cudaEventRecord. - StatusOr ThenAllocateAndRecordEvent(se::Stream* stream); + absl::StatusOr ThenAllocateAndRecordEvent(se::Stream* stream); // Version of ThenAllocateAndRecordEvent split into two phases; this is // sometimes helpful if we want to avoid failures by preallocating events. - StatusOr AllocateEvent(se::StreamExecutor* executor); + absl::StatusOr AllocateEvent(se::StreamExecutor* executor); void ThenRecordEvent(se::Stream* stream, EventPool::Handle& handle); private: diff --git a/third_party/xla/xla/pjrt/gpu/BUILD b/third_party/xla/xla/pjrt/gpu/BUILD index 2a3b5805808bfa..646a94a504e26b 100644 --- a/third_party/xla/xla/pjrt/gpu/BUILD +++ b/third_party/xla/xla/pjrt/gpu/BUILD @@ -2,12 +2,14 @@ load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("//xla:xla.bzl", "xla_cc_test") load("//xla/stream_executor:build_defs.bzl", "if_cuda_or_rocm", "if_gpu_is_configured") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") +load("@local_tsl//tsl:tsl.bzl", "internal_visibility") load("@local_tsl//tsl/platform:build_config.bzl", "tf_proto_library") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load("@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility(["//xla:internal"]), licenses = ["notice"], ) @@ -15,7 +17,7 @@ cc_library( name = "gpu_helpers", srcs = ["gpu_helpers.cc"], hdrs = ["gpu_helpers.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//xla/pjrt:friends"]), deps = [ "//xla:statusor", "//xla:types", @@ -37,7 +39,7 @@ cc_library( srcs = ["se_gpu_pjrt_client.cc"], hdrs = ["se_gpu_pjrt_client.h"], defines = if_cuda(["GOOGLE_CUDA=1"]) + if_rocm(["TENSORFLOW_USE_ROCM=1"]), - visibility = ["//visibility:public"], + visibility = internal_visibility(["//xla/pjrt:friends"]), deps = [ ":gpu_helpers", ":gpu_metrics", @@ -162,7 +164,6 @@ cc_library( name = "nccl_id_store", srcs = ["nccl_id_store.cc"], hdrs = ["nccl_id_store.h"], - visibility = ["//visibility:public"], deps = [ "//xla:status_macros", "//xla:statusor", @@ -212,7 +213,6 @@ cc_library( name = "gpu_topology", srcs = ["gpu_topology.cc"], hdrs = ["gpu_topology.h"], - visibility = ["//visibility:public"], deps = [ ":gpu_topology_proto_cc", ], @@ -223,7 +223,6 @@ cc_library( srcs = ["se_gpu_pjrt_compiler.cc"], hdrs = ["se_gpu_pjrt_compiler.h"], defines = if_cuda(["GOOGLE_CUDA=1"]) + if_rocm(["TENSORFLOW_USE_ROCM=1"]), - visibility = ["//visibility:public"], deps = [ ":se_gpu_pjrt_client", "//xla:status_macros", @@ -268,7 +267,6 @@ cc_library( name = "gpu_metrics", srcs = ["gpu_metrics.cc"], hdrs = ["gpu_metrics.h"], - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", diff --git a/third_party/xla/xla/pjrt/gpu/gpu_helpers.cc b/third_party/xla/xla/pjrt/gpu/gpu_helpers.cc index 6b69a465bf4b0f..efd5fbf8c17f60 100644 --- a/third_party/xla/xla/pjrt/gpu/gpu_helpers.cc +++ b/third_party/xla/xla/pjrt/gpu/gpu_helpers.cc @@ -34,7 +34,7 @@ limitations under the License. namespace xla { // Builds an xla::LocalClient for the GPU platform. -StatusOr GetGpuXlaClient( +absl::StatusOr GetGpuXlaClient( const std::optional& platform_name, const std::optional>& allowed_devices) { TF_ASSIGN_OR_RETURN( @@ -71,7 +71,7 @@ void EnablePeerAccess(absl::Span executors) { } // Builds a BFCAllocator for all local GPUs. -StatusOr> CreateBFCAllocator( +absl::StatusOr> CreateBFCAllocator( se::StreamExecutor* executor, double memory_fraction, bool preallocate) { bool enable_unified_memory; Status status = tsl::ReadBoolFromEnvVar("TF_FORCE_UNIFIED_MEMORY", false, @@ -119,8 +119,9 @@ StatusOr> CreateBFCAllocator( } // Builds a BFCAllocator for all local GPUs that uses collective memory. -StatusOr> CreateCollectiveBFCAllocator( - se::StreamExecutor* executor, size_t allocator_memory, bool preallocate) { +absl::StatusOr> CreateCollectiveBFCAllocator( + se::StreamExecutor* executor, double memory_fraction, + size_t collective_memory_size) { int device_ordinal = executor->device_ordinal(); auto sub_allocator = std::make_unique( executor, tsl::PlatformDeviceId(device_ordinal), @@ -128,6 +129,16 @@ StatusOr> CreateCollectiveBFCAllocator( /*alloc_visitors=*/std::vector(), /*free_visitors=*/std::vector()); + int64_t free_memory; + int64_t total_memory; + if (!executor->DeviceMemoryUsage(&free_memory, &total_memory)) { + return Unavailable("Failed to query available memory from device %i", + device_ordinal); + } + bool preallocate = collective_memory_size != 0; + size_t allocator_memory = + preallocate ? collective_memory_size : total_memory * memory_fraction; + if (preallocate) { LOG(INFO) << "XLA backend allocating " << allocator_memory << " bytes on device " << device_ordinal @@ -153,8 +164,18 @@ std::unique_ptr GetGpuHostAllocator( new se::DeviceHostAllocator(executor, /*numa_node=*/0, /*alloc_visitors=*/{}, /*free_visitors=*/{})); - // TODO(phawkins): allow the user to tune this. - const int64_t kGpuHostMemoryLimitBytes = 64 * (1LL << 30); + + int64_t xla_pjrt_gpu_host_memory_limit_gb; + Status status = + tsl::ReadInt64FromEnvVar("XLA_PJRT_GPU_HOST_MEMORY_LIMIT_GB", 64, + &xla_pjrt_gpu_host_memory_limit_gb); + if (!status.ok()) { + LOG(ERROR) << "Unable to read XLA_PJRT_GPU_HOST_MEMORY_LIMIT_GB: " + << status.message(); + } + + const int64_t kGpuHostMemoryLimitBytes = + xla_pjrt_gpu_host_memory_limit_gb * (1LL << 30); tsl::BFCAllocator::Options opts; opts.allow_growth = true; diff --git a/third_party/xla/xla/pjrt/gpu/gpu_helpers.h b/third_party/xla/xla/pjrt/gpu/gpu_helpers.h index ab0ee2a42e56ed..6ee4cc0c886b3e 100644 --- a/third_party/xla/xla/pjrt/gpu/gpu_helpers.h +++ b/third_party/xla/xla/pjrt/gpu/gpu_helpers.h @@ -31,7 +31,7 @@ limitations under the License. namespace xla { // Builds an xla::LocalClient for the GPU platform. -StatusOr GetGpuXlaClient( +absl::StatusOr GetGpuXlaClient( const std::optional& platform_name, const std::optional>& allowed_devices); @@ -58,10 +58,11 @@ struct GpuAllocatorConfig { // allocator will allocate more memory as allocations are requested. bool preallocate = true; - // Amount of collective memory (ncclMemAlloc) to reserve. Must be set when - // using `xla_gpu_enable_nccl_user_buffers=true`. If this value is 0, - // collective memory will not be allocated. Should be set to a multiple of - // 512MB to avoid wasting memory due to granularity requirements. + // Amount of collective memory (ncclMemAlloc) to preallocate. If this value is + // 0, collective memory space will be grown as needed to fit the application's + // usage, with the drawback of potentially higher fragmentation. If set, + // should be set to a multiple of 512MB to avoid wasting memory due to + // granularity requirements. size_t collective_memory_size = 0; }; @@ -69,12 +70,13 @@ std::unique_ptr GetGpuHostAllocator( se::StreamExecutor* executor); // Builds a BFCAllocator for all local GPUs. -StatusOr> CreateBFCAllocator( +absl::StatusOr> CreateBFCAllocator( se::StreamExecutor* executor, double memory_fraction, bool preallocate); // Builds a BFCAllocator for all local GPUs that uses collective memory. -StatusOr> CreateCollectiveBFCAllocator( - se::StreamExecutor* executor, size_t allocator_memory, bool preallocate); +absl::StatusOr> CreateCollectiveBFCAllocator( + se::StreamExecutor* executor, double memory_fraction, + size_t collective_memory_size); } // namespace xla diff --git a/third_party/xla/xla/pjrt/gpu/nccl_id_store.cc b/third_party/xla/xla/pjrt/gpu/nccl_id_store.cc index 13c18bbd8c9f2b..facdecada09afd 100644 --- a/third_party/xla/xla/pjrt/gpu/nccl_id_store.cc +++ b/third_party/xla/xla/pjrt/gpu/nccl_id_store.cc @@ -29,7 +29,7 @@ limitations under the License. namespace xla { -StatusOr NcclIdStore::GetNcclUniqueId( +absl::StatusOr NcclIdStore::GetNcclUniqueId( const gpu::NcclCliqueKey& key) { // The caller must ensure that threads calling this method concurrently have // unique keys, otherwise the global key-value store may hold the wrong value. diff --git a/third_party/xla/xla/pjrt/gpu/nccl_id_store.h b/third_party/xla/xla/pjrt/gpu/nccl_id_store.h index bdfced219d8ee9..70060e242b1506 100644 --- a/third_party/xla/xla/pjrt/gpu/nccl_id_store.h +++ b/third_party/xla/xla/pjrt/gpu/nccl_id_store.h @@ -41,7 +41,8 @@ class NcclIdStore { device_to_node_(std::move(device_to_node)), kv_store_(std::move(kv_store)) {} - StatusOr GetNcclUniqueId(const gpu::NcclCliqueKey& key); + absl::StatusOr GetNcclUniqueId( + const gpu::NcclCliqueKey& key); private: const int node_id_; diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index 9b8cf9e6465980..f8068b0a755f8a 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include #include +#include +#include #include #include #include @@ -108,9 +110,9 @@ namespace xla { class AsyncHostToDeviceTransferManager : public xla::PjRtClient::AsyncHostToDeviceTransferManager { public: - static StatusOr> Create( - absl::Span shapes, PjRtStreamExecutorDevice* device, - PjRtStreamExecutorClient* client) { + static absl::StatusOr> + Create(absl::Span shapes, PjRtStreamExecutorDevice* device, + PjRtStreamExecutorClient* client) { absl::InlinedVector, 4> buffers; absl::InlinedVector, 4> buffer_ptrs; absl::InlinedVector, 4> @@ -298,6 +300,27 @@ class AsyncHostToDeviceTransferManager bool is_last_transfer, absl::AnyInvocable on_done) override { auto* stream = device_->local_device_state()->host_to_device_stream(); + auto* client = + tensorflow::down_cast(device_->client()); + bool should_stage_host_to_device_transfers = + client->should_stage_host_to_device_transfers(); + std::shared_ptr staging_buffer; + if (should_stage_host_to_device_transfers) { + auto* host_memory_allocator = client->host_memory_allocator(); + if (host_memory_allocator == nullptr) { + return InvalidArgument( + "host_memory_allocator should be initialized for staging buffer " + "transfer."); + } + + void* ptr = host_memory_allocator->AllocateRaw( + tsl::Allocator::kAllocatorAlignment, transfer_size); + staging_buffer = std::shared_ptr( + ptr, [host_memory_allocator = host_memory_allocator](void* ptr) { + host_memory_allocator->DeallocateRaw(ptr); + }); + } + absl::ReleasableMutexLock l(&mu_); DCHECK_LT(buffer_index, buffer_ptrs_.size()); if (last_transfer_started_[buffer_index]) { @@ -329,16 +352,36 @@ class AsyncHostToDeviceTransferManager } ++transfers_in_flight_; + // Release the lock before transfer in case transfer or cleanup could be + // called on this thread, to avoid deadlock. + l.Release(); + auto event = device_->local_device_state()->event_pool().AllocateEvent( stream->parent()); + if (transfer_size != 0) { - stream->ThenMemcpy(&sub_buffer, data, transfer_size); + if (staging_buffer != nullptr) { + auto copy_to_staging_buffer = [data, transfer_size, + staging_buffer]() mutable { + std::memcpy(staging_buffer.get(), data, transfer_size); + }; + if (auto status = + stream->DoHostCallback(std::move(copy_to_staging_buffer)); + !status.ok()) { + return status; + } + if (auto status = stream->Memcpy(&sub_buffer, staging_buffer.get(), + transfer_size); + !status.ok()) { + return status; + } + } else if (auto status = stream->Memcpy(&sub_buffer, data, transfer_size); + !status.ok()) { + return status; + } } device_->local_device_state()->event_pool().ThenRecordEvent(stream, event.value()); - // Release the lock before calling ThenDoHostCallback in case cleanup - // could be called on this thread, to avoid deadlock. - l.Release(); auto cleanup = [this, buffer_index, event = std::move(event).value(), stream, is_last_transfer, @@ -346,8 +389,7 @@ class AsyncHostToDeviceTransferManager CleanUp(buffer_index, std::move(event), stream, is_last_transfer, std::move(on_done)); }; - stream->ThenDoHostCallback(std::move(cleanup)); - return OkStatus(); + return stream->DoHostCallback(std::move(cleanup)); } void SetBufferError(int buffer_index, Status error) override { @@ -432,7 +474,7 @@ absl::string_view StreamExecutorGpuClient::platform_version() const { #endif // TENSORFLOW_USE_ROCM && defined(TF_ROCM_VERSION) } -StatusOr> +absl::StatusOr> StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice( absl::Span shapes, PjRtDevice* device) { auto* stream_executor_device = @@ -441,7 +483,7 @@ StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice( shapes, stream_executor_device, this); } -xla::StatusOr +absl::StatusOr StreamExecutorGpuClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) const { if (num_partitions == 1 && num_replicas <= addressable_devices().size()) { @@ -543,7 +585,7 @@ PjRtFuture StreamExecutorGpuClient::CopyRawSubBufferToHost( }); } -StatusOr> +absl::StatusOr> StreamExecutorGpuClient::Compile(const XlaComputation& computation, CompileOptions options) { auto executable = PjRtStreamExecutorClient::Compile(computation, options); @@ -571,7 +613,7 @@ StreamExecutorGpuClient::Compile(const XlaComputation& computation, namespace { #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -StatusOr> FromProto( +absl::StatusOr> FromProto( const StreamExecutorExecutableProto& proto) { TF_ASSIGN_OR_RETURN(CompileOptions compile_options, CompileOptions::FromProto(proto.compile_options())); @@ -591,7 +633,7 @@ StatusOr> FromProto( #endif } // namespace -StatusOr> +absl::StatusOr> StreamExecutorGpuClient::LoadSerialized(absl::string_view serialized, std::optional options, const LoadOptions& load_options) { @@ -614,8 +656,26 @@ StreamExecutorGpuClient::LoadSerialized(absl::string_view serialized, return absl::InternalError("LoadSerialized only works with cuda or rocm."); } -StatusOr> StreamExecutorGpuClient::Load( - std::unique_ptr executable) { +absl::StatusOr> +StreamExecutorGpuClient::DeserializeExecutable( + absl::string_view serialized, std::optional options) { + if (serialized.size() > std::numeric_limits::max()) { + return Internal( + "StreamExecutorGpuClient::DeserializeExecutable proto too large " + "(>2GB)"); + } +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + StreamExecutorExecutableProto proto; + if (proto.ParseFromArray(serialized.data(), serialized.size())) { + TF_ASSIGN_OR_RETURN(auto se_executable, FromProto(proto)); + return Load(std::move(se_executable)); + } +#endif + return PjRtStreamExecutorClient::DeserializeExecutable(serialized, options); +} + +absl::StatusOr> +StreamExecutorGpuClient::Load(std::unique_ptr executable) { auto se_executable = absl::WrapUnique( tensorflow::down_cast(executable.release())); @@ -652,7 +712,7 @@ namespace { #if defined(GOOGLE_CUDA) && CUDA_VERSION >= 11020 -StatusOr> +absl::StatusOr> CreateCudaAsyncAllocator( se::Platform* platform, const std::map>& addressable_devices, @@ -700,7 +760,7 @@ CreateCudaAsyncAllocator( #else // defined(GOOGLE_CUDA) && CUDA_VERSION >= 11020 -StatusOr> +absl::StatusOr> CreateCudaAsyncAllocator( se::Platform* platform, const std::map>& addressable_devices, @@ -711,7 +771,7 @@ CreateCudaAsyncAllocator( #endif // defined(GOOGLE_CUDA) && CUDA_VERSION >= 11020 // Builds a LocalDeviceState for each GPU present. -StatusOr>> +absl::StatusOr>> BuildLocalDeviceStates(LocalClient* xla_client) { std::map> addressable_devices; for (se::StreamExecutor* executor : @@ -728,7 +788,7 @@ BuildLocalDeviceStates(LocalClient* xla_client) { // Constructs a GPU device memory allocator to use, according to the allocator // configuration the client requested. -StatusOr> +absl::StatusOr> GetStreamExecutorGpuDeviceAllocator( se::Platform* platform, const GpuAllocatorConfig& allocator_config, const std::map>& @@ -778,20 +838,17 @@ GetStreamExecutorGpuDeviceAllocator( } // Add any additional allocators for alternate memory spaces. - if (allocator_config.collective_memory_size != 0) { - for (const auto& ordinal_and_device : addressable_devices) { - TF_ASSIGN_OR_RETURN( - auto collective_bfc_allocator, - CreateCollectiveBFCAllocator( - ordinal_and_device.second->executor(), - /*allocator_memory=*/allocator_config.collective_memory_size, - /*preallocate=*/true)); - allocators.emplace_back(std::move(collective_bfc_allocator), - ordinal_and_device.second->compute_stream(), - /*memory_space=*/1); - } + for (const auto& ordinal_and_device : addressable_devices) { + TF_ASSIGN_OR_RETURN( + auto collective_bfc_allocator, + CreateCollectiveBFCAllocator( + ordinal_and_device.second->executor(), + /*memory_fraction=*/1.0 - allocator_config.memory_fraction, + allocator_config.collective_memory_size)); + allocators.emplace_back(std::move(collective_bfc_allocator), + ordinal_and_device.second->compute_stream(), + /*memory_space=*/1); } - return std::make_unique(platform, std::move(allocators)); } @@ -929,8 +986,9 @@ absl::StatusOr StreamExecutorGpuDevice::GetAllocatorStats() auto* allocator_adapter = dynamic_cast( tensorflow::down_cast(client())->allocator()); if (!allocator_adapter) { - return FailedPrecondition( - "GetAllocatorStats() only works with MultiDeviceAdapter allocator"); + return Unimplemented( + "GetAllocatorStats() is only implemented with MultiDeviceAdapter " + "allocator"); } TF_ASSIGN_OR_RETURN(auto allocator, allocator_adapter->GetAllocator( @@ -949,7 +1007,7 @@ int StreamExecutorGpuDevice::core_on_chip() const { return description().core_on_chip(); } -StatusOr> GetStreamExecutorGpuClient( +absl::StatusOr> GetStreamExecutorGpuClient( const GpuClientOptions& options) { #if TENSORFLOW_USE_ROCM auto pjrt_platform_name = xla::RocmName(); @@ -975,8 +1033,6 @@ StatusOr> GetStreamExecutorGpuClient( if (options.enable_mock_nccl) { gpu_run_options->set_enable_mock_nccl_collectives(); } - absl::flat_hash_map device_maps; - absl::Mutex mu; std::shared_ptr kv_store = options.kv_store; if (options.enable_mock_nccl) { kv_store = std::make_shared(); diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h index 066f453392dfba..10778be11a7454 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h @@ -60,14 +60,17 @@ class StreamExecutorGpuTopologyDescription : public PjRtTopologyDescription { } // `gpu_device_ids` is the list of logical device ids for the GPU devices and // will be used to initialize the GPU topology. - StreamExecutorGpuTopologyDescription(const PjRtPlatformId platform_id, - const absl::string_view platform_name, - const absl::string_view platform_version, - const std::vector& gpu_device_ids) + StreamExecutorGpuTopologyDescription( + const PjRtPlatformId platform_id, const absl::string_view platform_name, + const absl::string_view platform_version, + const std::vector& gpu_device_ids, + const absl::flat_hash_map& attributes = + {}) : platform_id_(platform_id), platform_name_(platform_name), platform_version_(platform_version), - gpu_topology_(gpu_device_ids) {} + gpu_topology_(gpu_device_ids), + attributes_(attributes) {} bool operator==(const StreamExecutorGpuTopologyDescription& other) const { return this->platform_id() == other.platform_id() && @@ -179,12 +182,12 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient { tsl::Fingerprint64(platform_name), platform_name, devices_.back()->device_kind(), devices_)) {} - xla::StatusOr GetDefaultDeviceAssignment( + absl::StatusOr GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const override; absl::string_view platform_version() const override; - StatusOr> + absl::StatusOr> CreateBuffersForAsyncHostToDevice(absl::Span shapes, PjRtDevice* device) override; @@ -192,13 +195,13 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient { int64_t offset, int64_t transfer_size) override; - StatusOr GetTopologyDescription() + absl::StatusOr GetTopologyDescription() const override { return &topology_; } // TODO(b/285385306): Enable loading a non-loaded PjRtExecutable. - StatusOr> Load( + absl::StatusOr> Load( std::unique_ptr executable, const LoadOptions& load_options) override { return absl::WrapUnique( @@ -207,16 +210,20 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient { // TODO(b/296466237): Unify `Load` method after (de)serialization and tests on // existing use cases are done. - StatusOr> Load( + absl::StatusOr> Load( std::unique_ptr executable); // TODO(b/296466237): Unify `LoadSerializedExecutable` after fixing existing // tests. - StatusOr> LoadSerialized( + absl::StatusOr> LoadSerialized( absl::string_view serialized, std::optional options, const LoadOptions& load_options); - StatusOr> Compile( + absl::StatusOr> DeserializeExecutable( + absl::string_view serialized, + std::optional options) override; + + absl::StatusOr> Compile( const XlaComputation& computation, CompileOptions options) override; private: @@ -246,7 +253,7 @@ struct GpuClientOptions { bool enable_mock_nccl = false; }; -StatusOr> GetStreamExecutorGpuClient( +absl::StatusOr> GetStreamExecutorGpuClient( const GpuClientOptions& options); } // namespace xla diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index 748b45b0e5a738..ecf04945364deb 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -55,7 +55,7 @@ using ::testing::ElementsAre; using ::testing::HasSubstr; using ::tsl::testing::StatusIs; -StatusOr> CompileExecutable( +absl::StatusOr> CompileExecutable( absl::string_view program, xla::PjRtClient& client, xla::CompileOptions compile_options = xla::CompileOptions()) { TF_ASSIGN_OR_RETURN(auto hlo_module, @@ -67,8 +67,8 @@ StatusOr> CompileExecutable( // Given the result of a PjrtExecutable::Execute call (TF-status of vectors of // vectors), extract the zeroth result from the zeroth device. -StatusOr> ExtractSingleResult( - xla::StatusOr>>>& +absl::StatusOr> ExtractSingleResult( + absl::StatusOr>>>& result) { TF_RETURN_IF_ERROR(result.status()); TF_RET_CHECK(result->size() == 1); @@ -328,11 +328,6 @@ TEST(StreamExecutorGpuClientTest, FromHostAsync) { buffers.emplace_back(transfer_manager->RetrieveBuffer(i)); } - absl::Mutex mu; - std::vector> literals; - int got_literal_count = 0; - int got_callback_count = 0; - for (int i = 0; i < src_shapes.size(); ++i) { TF_ASSERT_OK(transfer_manager->TransferRawDataToBuffer( i, @@ -341,6 +336,11 @@ TEST(StreamExecutorGpuClientTest, FromHostAsync) { [&]() {})); } + absl::Mutex mu; + std::vector> literals; + int got_literal_count = 0; + int got_callback_count = 0; + for (auto& buffer : buffers) { literals.push_back(std::make_shared( ShapeUtil::DeviceShapeToHostShape(buffer->on_device_shape()))); diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc index 751e9d177829b7..7a4f99bed3e10b 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc @@ -105,13 +105,24 @@ StreamExecutorGpuCompiler::Compile(CompileOptions options, CompileOptions input_options = options; if (!options.target_config) { - if (!client) { + if (client != nullptr) { + TF_RETURN_IF_ERROR(IsValidTopologyAndClientForCompile(topology, client)); + return client->Compile(computation, options); + } + auto attr = topology.Attributes(); + if (auto it = attr.find("target_config"); it != attr.end()) { + auto target_config_str = std::get(it->second); + stream_executor::GpuTargetConfigProto gpu_target_config_proto; + if (!gpu_target_config_proto.ParseFromString(target_config_str)) { + return FailedPrecondition("Failed to parse GpuTargetConfigProto"); + } + options.target_config.emplace( + Compiler::TargetConfig(gpu_target_config_proto)); + } else { return absl::UnimplementedError( "Compilation without client and without target_config specified is " "not implemented"); } - TF_RETURN_IF_ERROR(IsValidTopologyAndClientForCompile(topology, client)); - return client->Compile(computation, options); } TF_RETURN_IF_ERROR(options.ApplyAllOptionOverrides()); std::vector argument_layout_pointers; @@ -185,7 +196,7 @@ StreamExecutorGpuCompiler::Compile(CompileOptions options, #endif } -REGISTER_MODULE_INITIALIZER(pjrt_register_se_gpu_compiler, { +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(pjrt_register_se_gpu_compiler, { PjRtRegisterCompiler(CudaName(), std::make_unique()); }); diff --git a/third_party/xla/xla/pjrt/local_device_state.cc b/third_party/xla/xla/pjrt/local_device_state.cc index 20ef5352793118..d2149087e19f00 100644 --- a/third_party/xla/xla/pjrt/local_device_state.cc +++ b/third_party/xla/xla/pjrt/local_device_state.cc @@ -22,9 +22,11 @@ limitations under the License. #include #include +#include "absl/log/check.h" #include "absl/synchronization/mutex.h" #include "xla/stream_executor/stream.h" #include "xla/util.h" +#include "tsl/platform/errors.h" #include "tsl/profiler/lib/traceme.h" #include "tsl/protobuf/error_codes.pb.h" @@ -137,11 +139,10 @@ Status LocalDeviceState::SynchronizeAllActivity() { Status LocalDeviceState::ThenMemcpyDeviceToDevice( se::Stream* transfer_stream, se::Stream* dst_stream, se::DeviceMemoryBase src_buffer, se::DeviceMemoryBase dst_buffer) { - // The default implementation simply calls ThenMemcpyD2D, and assumes that + // The default implementation simply calls MemcpyD2D, and assumes that // the buffer addresses identify the devices. This does not work // on all platforms; this method is virtual so it can be overridden. - transfer_stream->ThenMemcpyD2D(&dst_buffer, src_buffer, dst_buffer.size()); - return OkStatus(); + return transfer_stream->MemcpyD2D(&dst_buffer, src_buffer, dst_buffer.size()); } void LocalDeviceState::ThenExecuteCallback(se::Stream* stream, @@ -216,21 +217,24 @@ std::vector LocalDeviceState::GetDeviceToDeviceStreams() { } std::unique_ptr LocalDeviceState::BorrowStreamFromPool() { - absl::MutexLock lock(&mu_); - if (usage_stream_pool_.empty()) { - auto stream = std::make_unique(compute_stream_->parent()); - stream->Init(); - return stream; - } else { - std::unique_ptr stream = std::move(usage_stream_pool_.top()); - usage_stream_pool_.pop(); - auto status = stream->RefreshStatus(); // Can return error::Unimplemented - // Stream may fail with "ABORTED: Bad connection". - if (status.code() != tsl::error::ABORTED) { - CHECK(stream->ok()) << status; + { + absl::MutexLock lock(&stream_pool_mu_); + if (!usage_stream_pool_.empty()) { + std::unique_ptr stream = std::move(usage_stream_pool_.top()); + usage_stream_pool_.pop(); + auto status = stream->RefreshStatus(); // Can return error::Unimplemented + // Stream may fail with "ABORTED: Bad connection". + if (status.code() != tsl::error::ABORTED) { + CHECK(stream->ok()) << status; + } + return stream; } - return stream; } + + // The stream pool is empty, create a new stream. + auto stream = std::make_unique(compute_stream_->parent()); + stream->Init(); + return stream; } void LocalDeviceState::ReturnStreamToPool(std::unique_ptr stream) { @@ -239,7 +243,7 @@ void LocalDeviceState::ReturnStreamToPool(std::unique_ptr stream) { if (status.code() != tsl::error::ABORTED) { CHECK(stream->ok()) << status; } - absl::MutexLock lock(&mu_); + absl::MutexLock lock(&stream_pool_mu_); usage_stream_pool_.push(std::move(stream)); } diff --git a/third_party/xla/xla/pjrt/local_device_state.h b/third_party/xla/xla/pjrt/local_device_state.h index 231571a75203c0..09bbf84348cf81 100644 --- a/third_party/xla/xla/pjrt/local_device_state.h +++ b/third_party/xla/xla/pjrt/local_device_state.h @@ -223,13 +223,15 @@ class LocalDeviceState { int next_device_to_host_stream_ ABSL_GUARDED_BY(mu_) = 0; int next_device_to_device_stream_ ABSL_GUARDED_BY(mu_) = 0; int next_external_ready_event_stream_ ABSL_GUARDED_BY(mu_) = 0; - std::stack> usage_stream_pool_ - ABSL_GUARDED_BY(mu_); std::random_device prng_seed_device_ ABSL_GUARDED_BY(mu_); std::mt19937 prng_seed_generator_ ABSL_GUARDED_BY(mu_); std::uniform_int_distribution<> prng_seed_distribution_ ABSL_GUARDED_BY(mu_); + absl::Mutex stream_pool_mu_; + std::stack> usage_stream_pool_ + ABSL_GUARDED_BY(stream_pool_mu_); + // Callback map pairs callback stream with a device stream and is used for // running short host-side callbacks after device side events, without // preventing the device-side stream from doing useful work. diff --git a/third_party/xla/xla/pjrt/pjrt_api.cc b/third_party/xla/xla/pjrt/pjrt_api.cc index 391eacf60b56e0..a710d8b62135b4 100644 --- a/third_party/xla/xla/pjrt/pjrt_api.cc +++ b/third_party/xla/xla/pjrt/pjrt_api.cc @@ -77,7 +77,7 @@ xla::Status SetPjrtApi(absl::string_view device_type, const PJRT_Api* api) { (*pjrt_apis)[canonicalize_device_type] = std::make_pair(api, /*is_initialized=*/false); LOG(INFO) << "PJRT_Api is set for device type " << canonicalize_device_type; - return tsl::OkStatus(); + return absl::OkStatus(); } typedef const PJRT_Api* (*PjrtApiInitFn)(); @@ -177,7 +177,7 @@ xla::Status InitializePjrtPlugin(absl::string_view device_type) { } PJRT_Plugin_Initialize_Args args; args.struct_size = PJRT_Plugin_Initialize_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; RETURN_STATUS_IF_PJRT_ERROR(pjrt_api->PJRT_Plugin_Initialize(&args), pjrt_api); iter->second.second = true; diff --git a/third_party/xla/xla/pjrt/pjrt_c_api_client.cc b/third_party/xla/xla/pjrt/pjrt_c_api_client.cc index 79ff23580c207a..d7e3a222d40c8d 100644 --- a/third_party/xla/xla/pjrt/pjrt_c_api_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_c_api_client.cc @@ -141,7 +141,7 @@ void PjRtCApiClient::InitDevicesAndMemorySpaces() { // Initialize devices. PJRT_Client_Devices_Args devices_args; devices_args.struct_size = PJRT_Client_Devices_Args_STRUCT_SIZE; - devices_args.priv = nullptr; + devices_args.extension_start = nullptr; devices_args.client = c_client_.get(); pjrt::LogFatalIfPjrtError(c_api_->PJRT_Client_Devices(&devices_args), c_api_); @@ -162,7 +162,7 @@ void PjRtCApiClient::InitDevicesAndMemorySpaces() { // Initialize addressable devices. PJRT_Client_AddressableDevices_Args address_args; address_args.struct_size = PJRT_Client_AddressableDevices_Args_STRUCT_SIZE; - address_args.priv = nullptr; + address_args.extension_start = nullptr; address_args.client = c_client_.get(); pjrt::LogFatalIfPjrtError( @@ -180,7 +180,7 @@ void PjRtCApiClient::InitDevicesAndMemorySpaces() { // TODO(yueshengys): Initialize global memory spaces when supported. PJRT_Client_AddressableMemories_Args memory_args; memory_args.struct_size = PJRT_Client_AddressableMemories_Args_STRUCT_SIZE; - memory_args.priv = nullptr; + memory_args.extension_start = nullptr; memory_args.client = c_client_.get(); std::unique_ptr client_error( @@ -212,7 +212,7 @@ void PjRtCApiClient::InitDevicesAndMemorySpaces() { PJRT_Device* c_device = cpp_device->c_device(); PJRT_Device_AddressableMemories_Args args; args.struct_size = PJRT_Device_AddressableMemories_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.device = c_device; std::unique_ptr device_error( @@ -241,7 +241,7 @@ void PjRtCApiClient::InitDevicesAndMemorySpaces() { PJRT_Memory* c_memory = cpp_memory->c_memory(); PJRT_Memory_AddressableByDevices_Args args; args.struct_size = PJRT_Memory_AddressableByDevices_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.memory = c_memory; pjrt::LogFatalIfPjrtError(c_api_->PJRT_Memory_AddressableByDevices(&args), c_api_); @@ -272,7 +272,7 @@ absl::Span PjRtCApiClient::addressable_devices() const { int PjRtCApiClient::process_index() const { PJRT_Client_ProcessIndex_Args process_index_args; process_index_args.struct_size = PJRT_Client_ProcessIndex_Args_STRUCT_SIZE; - process_index_args.priv = nullptr; + process_index_args.extension_start = nullptr; process_index_args.client = c_client_.get(); pjrt::LogFatalIfPjrtError( c_api_->PJRT_Client_ProcessIndex(&process_index_args), c_api_); @@ -301,7 +301,7 @@ StatusOr PjRtCApiClient::GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const { PJRT_Client_DefaultDeviceAssignment_Args args; args.struct_size = PJRT_Client_DefaultDeviceAssignment_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.client = c_client_.get(); args.num_replicas = num_replicas; args.num_partitions = num_partitions; @@ -324,7 +324,7 @@ StatusOr PjRtCApiClient::LookupDevice( PjRtGlobalDeviceId global_device_id) const { PJRT_Client_LookupDevice_Args args; args.struct_size = PJRT_Client_LookupDevice_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.client = c_client_.get(); args.id = global_device_id.value(); RETURN_STATUS_IF_PJRT_ERROR(c_api_->PJRT_Client_LookupDevice(&args), c_api_); @@ -340,7 +340,7 @@ StatusOr PjRtCApiClient::LookupAddressableDevice( PjRtLocalDeviceId local_device_id) const { PJRT_Client_LookupAddressableDevice_Args args; args.struct_size = PJRT_Client_LookupAddressableDevice_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.client = c_client_.get(); args.local_hardware_id = local_device_id.value(); RETURN_STATUS_IF_PJRT_ERROR( @@ -360,7 +360,7 @@ static StatusOr> InitializeArgsAndCompile( const std::string& format) { PJRT_Client_Compile_Args args; args.struct_size = PJRT_Client_Compile_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.client = client; TF_ASSIGN_OR_RETURN(const CompileOptionsProto options_proto, options.ToProto()); @@ -370,7 +370,7 @@ static StatusOr> InitializeArgsAndCompile( PJRT_Program program; program.struct_size = PJRT_Program_STRUCT_SIZE; - program.priv = nullptr; + program.extension_start = nullptr; program.code = const_cast(code.c_str()); program.code_size = code.size(); program.format = format.c_str(); @@ -407,7 +407,7 @@ PjRtCApiClient::DeserializeExecutable(absl::string_view serialized, PJRT_Executable_DeserializeAndLoad_Args des_args; des_args.struct_size = PJRT_Executable_DeserializeAndLoad_Args_STRUCT_SIZE; - des_args.priv = nullptr; + des_args.extension_start = nullptr; des_args.client = c_client_.get(); des_args.serialized_executable = serialized.data(); des_args.serialized_executable_size = serialized.length(); @@ -447,7 +447,7 @@ StatusOr PjRtCApiClient::UnsafeBufferPointer( PJRT_Buffer_UnsafePointer_Args args; args.struct_size = PJRT_Buffer_UnsafePointer_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = tensorflow::down_cast(buffer)->c_buffer(); @@ -477,7 +477,7 @@ PjRtCApiClient::BufferFromHostBufferInternalImpl( PJRT_Client_BufferFromHostBuffer_Args args; args.struct_size = PJRT_Client_BufferFromHostBuffer_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.client = c_client_.get(); args.data = data; args.type = ::pjrt::ConvertToPjRtBufferType(type); @@ -527,7 +527,7 @@ PjRtCApiClient::BufferFromHostBufferInternalImpl( if (on_done_with_host_buffer) { PJRT_Event_OnReady_Args event_args; event_args.struct_size = PJRT_Event_OnReady_Args_STRUCT_SIZE; - event_args.priv = nullptr; + event_args.extension_start = nullptr; event_args.event = event.get(); event_args.user_arg = new absl::AnyInvocable( [on_done_with_host_buffer = std::move(on_done_with_host_buffer), @@ -590,7 +590,7 @@ StatusOr> PjRtCApiClient::CreateViewOfDeviceBuffer( std::optional stream) { PJRT_Client_CreateViewOfDeviceBuffer_Args args; args.struct_size = PJRT_Client_CreateViewOfDeviceBuffer_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.client = c_client_.get(); args.device_buffer_ptr = device_ptr; args.dims = shape.dimensions().data(); @@ -644,7 +644,7 @@ PjRtCApiDeviceDescription::PjRtCApiDeviceDescription( int PjRtCApiDeviceDescription::id() const { PJRT_DeviceDescription_Id_Args args; args.struct_size = PJRT_DeviceDescription_Id_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.device_description = device_description_; pjrt::LogFatalIfPjrtError(c_api_->PJRT_DeviceDescription_Id(&args), c_api_); return args.id; @@ -653,7 +653,7 @@ int PjRtCApiDeviceDescription::id() const { int PjRtCApiDeviceDescription::process_index() const { PJRT_DeviceDescription_ProcessIndex_Args args; args.struct_size = PJRT_DeviceDescription_ProcessIndex_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.device_description = device_description_; pjrt::LogFatalIfPjrtError(c_api_->PJRT_DeviceDescription_ProcessIndex(&args), c_api_); @@ -664,7 +664,7 @@ void PjRtCApiDeviceDescription::InitAttributes() { attributes_ = {}; PJRT_DeviceDescription_Attributes_Args args; args.struct_size = PJRT_DeviceDescription_Attributes_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.device_description = device_description_; pjrt::LogFatalIfPjrtError(c_api_->PJRT_DeviceDescription_Attributes(&args), c_api_); @@ -715,7 +715,7 @@ PjRtCApiDeviceDescription::Attributes() const { absl::string_view PjRtCApiDeviceDescription::device_kind() const { PJRT_DeviceDescription_Kind_Args args; args.struct_size = PJRT_DeviceDescription_Kind_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.device_description = device_description_; pjrt::LogFatalIfPjrtError(c_api_->PJRT_DeviceDescription_Kind(&args), c_api_); @@ -727,7 +727,7 @@ absl::string_view PjRtCApiDeviceDescription::device_kind() const { absl::string_view PjRtCApiDeviceDescription::DebugString() const { PJRT_DeviceDescription_DebugString_Args args; args.struct_size = PJRT_DeviceDescription_DebugString_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.device_description = device_description_; pjrt::LogFatalIfPjrtError(c_api_->PJRT_DeviceDescription_DebugString(&args), c_api_); @@ -738,7 +738,7 @@ absl::string_view PjRtCApiDeviceDescription::DebugString() const { absl::string_view PjRtCApiDeviceDescription::ToString() const { PJRT_DeviceDescription_ToString_Args args; args.struct_size = PJRT_DeviceDescription_ToString_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.device_description = device_description_; pjrt::LogFatalIfPjrtError(c_api_->PJRT_DeviceDescription_ToString(&args), c_api_); @@ -757,7 +757,7 @@ PjRtClient* PjRtCApiDevice::client() const { return client_; } bool PjRtCApiDevice::IsAddressable() const { PJRT_Device_IsAddressable_Args args; args.struct_size = PJRT_Device_IsAddressable_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.device = device_; const PJRT_Api* api = client_->pjrt_c_api(); pjrt::LogFatalIfPjrtError(api->PJRT_Device_IsAddressable(&args), api); @@ -771,7 +771,7 @@ int PjRtCApiDevice::local_hardware_id() const { PjRtLocalHardwareId PjRtCApiDevice::local_hardware_id_typed() const { PJRT_Device_LocalHardwareId_Args args; args.struct_size = PJRT_Device_LocalHardwareId_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.device = device_; const PJRT_Api* api = client_->pjrt_c_api(); pjrt::LogFatalIfPjrtError(api->PJRT_Device_LocalHardwareId(&args), api); @@ -781,7 +781,7 @@ PjRtLocalHardwareId PjRtCApiDevice::local_hardware_id_typed() const { StatusOr PjRtCApiDevice::default_memory_space() const { PJRT_Device_DefaultMemory_Args args; args.struct_size = PJRT_Device_DefaultMemory_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.device = device_; const PJRT_Api* api = client_->pjrt_c_api(); RETURN_STATUS_IF_PJRT_ERROR(api->PJRT_Device_DefaultMemory(&args), api); @@ -791,7 +791,7 @@ StatusOr PjRtCApiDevice::default_memory_space() const { StatusOr PjRtCApiDevice::GetAllocatorStats() const { PJRT_Device_MemoryStats_Args args; args.struct_size = PJRT_Device_MemoryStats_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.device = device_; const PJRT_Api* api = client_->pjrt_c_api(); RETURN_STATUS_IF_PJRT_ERROR(api->PJRT_Device_MemoryStats(&args), api); @@ -859,7 +859,7 @@ PjRtClient* PjRtCApiMemorySpace::client() const { return client_; } int PjRtCApiMemorySpace::id() const { PJRT_Memory_Id_Args args; args.struct_size = PJRT_Memory_Id_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.memory = c_memory_; pjrt::LogFatalIfPjrtError(pjrt_c_api()->PJRT_Memory_Id(&args), pjrt_c_api()); return args.id; @@ -868,7 +868,7 @@ int PjRtCApiMemorySpace::id() const { absl::string_view PjRtCApiMemorySpace::memory_space_kind() const { PJRT_Memory_Kind_Args args; args.struct_size = PJRT_Memory_Kind_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.memory = c_memory_; pjrt::LogFatalIfPjrtError(pjrt_c_api()->PJRT_Memory_Kind(&args), @@ -880,7 +880,7 @@ absl::string_view PjRtCApiMemorySpace::memory_space_kind() const { absl::string_view PjRtCApiMemorySpace::DebugString() const { PJRT_Memory_DebugString_Args args; args.struct_size = PJRT_Memory_DebugString_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.memory = c_memory_; pjrt::LogFatalIfPjrtError(pjrt_c_api()->PJRT_Memory_DebugString(&args), pjrt_c_api()); @@ -890,7 +890,7 @@ absl::string_view PjRtCApiMemorySpace::DebugString() const { absl::string_view PjRtCApiMemorySpace::ToString() const { PJRT_Memory_ToString_Args args; args.struct_size = PJRT_Memory_ToString_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.memory = c_memory_; pjrt::LogFatalIfPjrtError(pjrt_c_api()->PJRT_Memory_ToString(&args), pjrt_c_api()); @@ -910,7 +910,7 @@ absl::string_view PjRtCApiExecutable::name() const { PJRT_Executable_Name_Args args; args.executable = executable; args.struct_size = PJRT_Executable_Name_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; pjrt::LogFatalIfPjrtError(c_api->PJRT_Executable_Name(&args), c_api); return absl::string_view(args.executable_name, args.executable_name_size); @@ -922,7 +922,7 @@ int PjRtCApiExecutable::num_replicas() const { PJRT_Executable_NumReplicas_Args args; args.executable = executable; args.struct_size = PJRT_Executable_NumReplicas_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; pjrt::LogFatalIfPjrtError(c_api->PJRT_Executable_NumReplicas(&args), c_api); return args.num_replicas; @@ -934,7 +934,7 @@ int PjRtCApiExecutable::num_partitions() const { PJRT_Executable_NumPartitions_Args args; args.executable = executable; args.struct_size = PJRT_Executable_NumPartitions_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; pjrt::LogFatalIfPjrtError(c_api->PJRT_Executable_NumPartitions(&args), c_api); return args.num_partitions; @@ -946,7 +946,7 @@ int64_t PjRtCApiExecutable::SizeOfGeneratedCodeInBytes() const { PJRT_Executable_SizeOfGeneratedCodeInBytes_Args args; args.struct_size = PJRT_Executable_SizeOfGeneratedCodeInBytes_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.executable = executable; pjrt::LogFatalIfPjrtError( @@ -959,7 +959,7 @@ PjRtCApiExecutable::GetCostAnalysis() const { // Initialize function call args PJRT_Executable_GetCostAnalysis_Args args; args.struct_size = PJRT_Executable_GetCostAnalysis_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.executable = c_executable(); // Make PJRT C API call @@ -976,7 +976,7 @@ StatusOr>> PjRtCApiExecutable::GetOutputElementTypes() const { PJRT_Executable_OutputElementTypes_Args args; args.struct_size = PJRT_Executable_OutputElementTypes_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.executable = c_executable(); const PJRT_Api* c_api = pjrt_c_api(); @@ -996,7 +996,7 @@ StatusOr>> PjRtCApiExecutable::GetOutputDimensions() const { PJRT_Executable_OutputDimensions_Args args; args.struct_size = PJRT_Executable_OutputDimensions_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.executable = c_executable(); const PJRT_Api* c_api = pjrt_c_api(); @@ -1022,7 +1022,7 @@ StatusOr>> PjRtCApiExecutable::GetOutputMemoryKinds() const { PJRT_Executable_OutputMemoryKinds_Args args; args.struct_size = PJRT_Executable_OutputMemoryKinds_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.executable = c_executable(); const PJRT_Api* c_api = pjrt_c_api(); @@ -1044,11 +1044,11 @@ PjRtCApiExecutable::GetHloModules() const { auto* executable = c_executable(); PJRT_Executable_OptimizedProgram_Args args; args.struct_size = PJRT_Executable_OptimizedProgram_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.executable = executable; PJRT_Program program; program.struct_size = PJRT_Program_STRUCT_SIZE; - program.priv = nullptr; + program.extension_start = nullptr; program.code = nullptr; args.program = &program; @@ -1113,7 +1113,7 @@ StatusOr PjRtCApiExecutable::SerializeExecutable() const { auto* executable = c_executable(); PJRT_Executable_Serialize_Args ser_args; ser_args.struct_size = PJRT_Executable_Serialize_Args_STRUCT_SIZE; - ser_args.priv = nullptr; + ser_args.extension_start = nullptr; ser_args.executable = executable; ser_args.serialized_executable = nullptr; @@ -1137,7 +1137,7 @@ StatusOr PjRtCApiExecutable::FingerprintExecutable() const { PJRT_Executable_Fingerprint_Args args; args.struct_size = PJRT_Executable_Fingerprint_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.executable = c_executable(); RETURN_STATUS_IF_PJRT_ERROR(c_api_->PJRT_Executable_Fingerprint(&args), c_api_); @@ -1154,7 +1154,7 @@ PjRtCApiLoadedExecutable::PjRtCApiLoadedExecutable( client->pjrt_c_api())) { PJRT_LoadedExecutable_GetExecutable_Args args; args.struct_size = PJRT_LoadedExecutable_GetExecutable_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.loaded_executable = c_loaded_executable(); args.executable = nullptr; pjrt::LogFatalIfPjrtError( @@ -1167,7 +1167,7 @@ PjRtCApiLoadedExecutable::PjRtCApiLoadedExecutable( void PjRtCApiLoadedExecutable::InitDevices() { PJRT_LoadedExecutable_AddressableDevices_Args args; args.struct_size = PJRT_LoadedExecutable_AddressableDevices_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.executable = c_loaded_executable(); args.addressable_devices = nullptr; args.num_addressable_devices = 0; @@ -1272,7 +1272,7 @@ CApiCopyToDeviceStream::CApiCopyToDeviceStream( PJRT_CopyToDeviceStream_TotalBytes_Args total_bytes_args; total_bytes_args.struct_size = PJRT_CopyToDeviceStream_TotalBytes_Args_STRUCT_SIZE; - total_bytes_args.priv = nullptr; + total_bytes_args.extension_start = nullptr; total_bytes_args.stream = c_stream_; pjrt::LogFatalIfPjrtError( c_api_->PJRT_CopyToDeviceStream_TotalBytes(&total_bytes_args), c_api_); @@ -1281,7 +1281,7 @@ CApiCopyToDeviceStream::CApiCopyToDeviceStream( PJRT_CopyToDeviceStream_GranuleSize_Args granule_size_args; granule_size_args.struct_size = PJRT_CopyToDeviceStream_GranuleSize_Args_STRUCT_SIZE; - granule_size_args.priv = nullptr; + granule_size_args.extension_start = nullptr; granule_size_args.stream = c_stream_; pjrt::LogFatalIfPjrtError( c_api_->PJRT_CopyToDeviceStream_GranuleSize(&granule_size_args), c_api_); @@ -1291,7 +1291,7 @@ CApiCopyToDeviceStream::CApiCopyToDeviceStream( CApiCopyToDeviceStream::~CApiCopyToDeviceStream() { PJRT_CopyToDeviceStream_Destroy_Args destroy_args; destroy_args.struct_size = PJRT_CopyToDeviceStream_Destroy_Args_STRUCT_SIZE; - destroy_args.priv = nullptr; + destroy_args.extension_start = nullptr; destroy_args.stream = c_stream_; pjrt::LogFatalIfPjrtError( c_api_->PJRT_CopyToDeviceStream_Destroy(&destroy_args), c_api_); @@ -1303,14 +1303,14 @@ PjRtFuture CApiCopyToDeviceStream::AddChunk(PjRtChunk chunk) { PJRT_CopyToDeviceStream_AddChunk_Args add_chunk_args; add_chunk_args.struct_size = PJRT_CopyToDeviceStream_AddChunk_Args_STRUCT_SIZE; - add_chunk_args.priv = nullptr; + add_chunk_args.extension_start = nullptr; add_chunk_args.stream = c_stream_; add_chunk_args.chunk = &c_chunk; PJRT_CopyToDeviceStream_CurrentBytes_Args current_bytes_args; current_bytes_args.struct_size = PJRT_CopyToDeviceStream_CurrentBytes_Args_STRUCT_SIZE; - current_bytes_args.priv = nullptr; + current_bytes_args.extension_start = nullptr; current_bytes_args.stream = c_stream_; { @@ -1430,7 +1430,7 @@ PjRtCApiLoadedExecutable::GetCommonExecuteArgs( PJRT_LoadedExecutable_Execute_Args args; args.struct_size = PJRT_LoadedExecutable_Execute_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.executable = c_loaded_executable(); args.options = &c_options; args.options->struct_size = PJRT_ExecuteOptions_STRUCT_SIZE; @@ -1465,7 +1465,7 @@ PjRtCApiLoadedExecutable::GetCommonExecuteArgs( PJRT_Executable_NumOutputs_Args numoutputs_args; numoutputs_args.struct_size = PJRT_Executable_NumOutputs_Args_STRUCT_SIZE; - numoutputs_args.priv = nullptr; + numoutputs_args.extension_start = nullptr; numoutputs_args.executable = c_executable(); RETURN_STATUS_IF_PJRT_ERROR( pjrt_c_api()->PJRT_Executable_NumOutputs(&numoutputs_args), pjrt_c_api()); @@ -1637,7 +1637,7 @@ PjRtCApiLoadedExecutable::ExecutePortable( void PjRtCApiLoadedExecutable::Delete() { PJRT_LoadedExecutable_Delete_Args args; args.struct_size = PJRT_LoadedExecutable_Delete_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.executable = c_loaded_executable(); const PJRT_Api* c_api = pjrt_c_api(); pjrt::LogFatalIfPjrtError(c_api->PJRT_LoadedExecutable_Delete(&args), c_api); @@ -1646,7 +1646,7 @@ void PjRtCApiLoadedExecutable::Delete() { bool PjRtCApiLoadedExecutable::IsDeleted() { PJRT_LoadedExecutable_IsDeleted_Args args; args.struct_size = PJRT_LoadedExecutable_IsDeleted_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.executable = c_loaded_executable(); const PJRT_Api* c_api = pjrt_c_api(); @@ -1670,7 +1670,7 @@ StatusOr PjRtCApiLoadedExecutable::FingerprintExecutable() const { // TODO(yeounoh): To be removed after 01/20/2024. PJRT_LoadedExecutable_Fingerprint_Args args; args.struct_size = PJRT_LoadedExecutable_Fingerprint_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.executable = c_loaded_executable(); const PJRT_Api* c_api = pjrt_c_api(); std::unique_ptr error( @@ -1694,7 +1694,7 @@ PjRtCApiBuffer::PjRtCApiBuffer(PjRtCApiClient* client, PJRT_Buffer* buffer) PrimitiveType PjRtCApiBuffer::element_type() const { PJRT_Buffer_ElementType_Args args; args.struct_size = PJRT_Buffer_ElementType_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); pjrt::LogFatalIfPjrtError(pjrt_c_api()->PJRT_Buffer_ElementType(&args), pjrt_c_api()); @@ -1704,7 +1704,7 @@ PrimitiveType PjRtCApiBuffer::element_type() const { absl::Span PjRtCApiBuffer::dimensions() const { PJRT_Buffer_Dimensions_Args args; args.struct_size = PJRT_Buffer_Dimensions_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); pjrt::LogFatalIfPjrtError(pjrt_c_api()->PJRT_Buffer_Dimensions(&args), pjrt_c_api()); @@ -1717,7 +1717,7 @@ const Layout& PjRtCApiBuffer::layout() const { if (!layout_.has_value()) { PJRT_Buffer_GetMemoryLayout_Args args; args.struct_size = PJRT_Buffer_GetMemoryLayout_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); pjrt::LogFatalIfPjrtError( pjrt_c_api()->PJRT_Buffer_GetMemoryLayout(&args), pjrt_c_api()); @@ -1735,7 +1735,7 @@ const Layout& PjRtCApiBuffer::layout() const { bool PjRtCApiBuffer::has_dynamic_dimensions() const { PJRT_Buffer_DynamicDimensionIndices_Args args; args.struct_size = PJRT_Buffer_DynamicDimensionIndices_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); const PJRT_Api* api = pjrt_c_api(); @@ -1760,7 +1760,7 @@ absl::Span PjRtCApiBuffer::is_dynamic_dimension() const { PJRT_Buffer_DynamicDimensionIndices_Args args; args.struct_size = PJRT_Buffer_DynamicDimensionIndices_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); const PJRT_Api* api = pjrt_c_api(); std::unique_ptr error( @@ -1781,7 +1781,7 @@ absl::Span PjRtCApiBuffer::is_dynamic_dimension() const { StatusOr> PjRtCApiBuffer::logical_dimensions() { PJRT_Buffer_UnpaddedDimensions_Args args; args.struct_size = PJRT_Buffer_UnpaddedDimensions_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); RETURN_STATUS_IF_PJRT_ERROR( pjrt_c_api()->PJRT_Buffer_UnpaddedDimensions(&args), pjrt_c_api()); @@ -1792,7 +1792,7 @@ StatusOr> PjRtCApiBuffer::logical_dimensions() { PjRtFuture PjRtCApiBuffer::ToLiteral(MutableLiteralBase* literal) { PJRT_Buffer_ToHostBuffer_Args args; args.struct_size = PJRT_Buffer_ToHostBuffer_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.src = buffer_.get(); const xla::Shape& shape = literal->shape(); @@ -1834,7 +1834,7 @@ PjRtFuture PjRtCApiBuffer::ToLiteral(MutableLiteralBase* literal) { StatusOr PjRtCApiBuffer::GetOnDeviceSizeInBytes() const { PJRT_Buffer_OnDeviceSizeInBytes_Args args; args.struct_size = PJRT_Buffer_OnDeviceSizeInBytes_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); RETURN_STATUS_IF_PJRT_ERROR( client_->pjrt_c_api()->PJRT_Buffer_OnDeviceSizeInBytes(&args), @@ -1846,7 +1846,7 @@ StatusOr PjRtCApiBuffer::GetOnDeviceSizeInBytes() const { PjRtMemorySpace* PjRtCApiBuffer::memory_space() const { PJRT_Buffer_Memory_Args args; args.struct_size = PJRT_Buffer_Memory_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); const PJRT_Api* api = pjrt_c_api(); std::unique_ptr error( @@ -1863,7 +1863,7 @@ PjRtMemorySpace* PjRtCApiBuffer::memory_space() const { PjRtDevice* PjRtCApiBuffer::device() const { PJRT_Buffer_Device_Args args; args.struct_size = PJRT_Buffer_Device_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); const PJRT_Api* api = pjrt_c_api(); pjrt::LogFatalIfPjrtError(api->PJRT_Buffer_Device(&args), api); @@ -1873,7 +1873,7 @@ PjRtDevice* PjRtCApiBuffer::device() const { void PjRtCApiBuffer::Delete() { PJRT_Buffer_Delete_Args args; args.struct_size = PJRT_Buffer_Delete_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); const PJRT_Api* api = pjrt_c_api(); pjrt::LogFatalIfPjrtError(api->PJRT_Buffer_Delete(&args), api); @@ -1882,7 +1882,7 @@ void PjRtCApiBuffer::Delete() { bool PjRtCApiBuffer::IsDeleted() { PJRT_Buffer_IsDeleted_Args args; args.struct_size = PJRT_Buffer_IsDeleted_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); const PJRT_Api* api = pjrt_c_api(); pjrt::LogFatalIfPjrtError(api->PJRT_Buffer_IsDeleted(&args), api); @@ -1894,7 +1894,7 @@ StatusOr> PjRtCApiBuffer::CopyToDevice( if (dst_device->client() == client_) { PJRT_Buffer_CopyToDevice_Args args; args.struct_size = PJRT_Buffer_CopyToDevice_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); args.dst_device = tensorflow::down_cast(dst_device)->c_device(); @@ -1927,7 +1927,7 @@ StatusOr> PjRtCApiBuffer::CopyToMemorySpace( if (dst_memory->client() == client_) { PJRT_Buffer_CopyToMemory_Args args; args.struct_size = PJRT_Buffer_CopyToMemory_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); args.dst_memory = tensorflow::down_cast(dst_memory)->c_memory(); @@ -1956,7 +1956,7 @@ StatusOr> PjRtCApiBuffer::CopyToMemorySpace( bool PjRtCApiBuffer::IsOnCpu() const { PJRT_Buffer_IsOnCpu_Args args; args.struct_size = PJRT_Buffer_IsOnCpu_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); const PJRT_Api* api = pjrt_c_api(); pjrt::LogFatalIfPjrtError(api->PJRT_Buffer_IsOnCpu(&args), api); @@ -1968,7 +1968,7 @@ PJRT_Event* PjRtCApiBuffer::GetReadyEvent() { const PJRT_Api* api = pjrt_c_api(); PJRT_Buffer_ReadyEvent_Args args; args.struct_size = PJRT_Buffer_ReadyEvent_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); pjrt::LogFatalIfPjrtError(api->PJRT_Buffer_ReadyEvent(&args), api); readiness_event_.reset(args.event); @@ -1981,7 +1981,7 @@ void PjRtCApiBuffer::MakePromiseTrackEvent() { const PJRT_Api* api = pjrt_c_api(); PJRT_Event_OnReady_Args args; args.struct_size = PJRT_Event_OnReady_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.event = GetReadyEvent(); args.user_arg = new std::function( [promise = readiness_promise_, api](PJRT_Error* error) -> void { @@ -2019,7 +2019,7 @@ PjRtCApiBuffer::AcquireExternalReference() { increase_reference_count_args.buffer = c_buffer(); increase_reference_count_args.struct_size = PJRT_Buffer_IncreaseExternalReferenceCount_Args_STRUCT_SIZE; - increase_reference_count_args.priv = nullptr; + increase_reference_count_args.extension_start = nullptr; RETURN_STATUS_IF_PJRT_ERROR( pjrt_c_api()->PJRT_Buffer_IncreaseExternalReferenceCount( &increase_reference_count_args), @@ -2029,7 +2029,7 @@ PjRtCApiBuffer::AcquireExternalReference() { opaque_device_memory_data_pointer_args; opaque_device_memory_data_pointer_args.struct_size = PJRT_Buffer_OpaqueDeviceMemoryDataPointer_Args_STRUCT_SIZE; - opaque_device_memory_data_pointer_args.priv = nullptr; + opaque_device_memory_data_pointer_args.extension_start = nullptr; opaque_device_memory_data_pointer_args.buffer = c_buffer(); RETURN_STATUS_IF_PJRT_ERROR( pjrt_c_api()->PJRT_Buffer_OpaqueDeviceMemoryDataPointer( @@ -2046,7 +2046,7 @@ PjRtCApiExternalReference::~PjRtCApiExternalReference() { PJRT_Buffer_DecreaseExternalReferenceCount_Args args; args.struct_size = PJRT_Buffer_DecreaseExternalReferenceCount_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_->c_buffer(); pjrt::LogFatalIfPjrtError( client_->pjrt_c_api()->PJRT_Buffer_DecreaseExternalReferenceCount(&args), @@ -2072,7 +2072,7 @@ absl::string_view PjRtCApiTopologyDescription::platform_name() const { PJRT_TopologyDescription_PlatformName_Args args; args.topology = c_topology_; args.struct_size = PJRT_TopologyDescription_PlatformName_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; pjrt::LogFatalIfPjrtError( c_api_->PJRT_TopologyDescription_PlatformName(&args), c_api_); return absl::string_view(args.platform_name, args.platform_name_size); @@ -2081,7 +2081,7 @@ absl::string_view PjRtCApiTopologyDescription::platform_name() const { absl::string_view PjRtCApiTopologyDescription::platform_version() const { PJRT_TopologyDescription_PlatformVersion_Args args; args.struct_size = PJRT_TopologyDescription_PlatformVersion_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.topology = c_topology_; pjrt::LogFatalIfPjrtError( c_api_->PJRT_TopologyDescription_PlatformVersion(&args), c_api_); @@ -2093,7 +2093,7 @@ PjRtCApiTopologyDescription::DeviceDescriptions() const { PJRT_TopologyDescription_GetDeviceDescriptions_Args args; args.struct_size = PJRT_TopologyDescription_GetDeviceDescriptions_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.topology = c_topology_; pjrt::LogFatalIfPjrtError( c_api_->PJRT_TopologyDescription_GetDeviceDescriptions(&args), c_api_); @@ -2111,7 +2111,7 @@ PjRtCApiTopologyDescription::DeviceDescriptions() const { StatusOr PjRtCApiTopologyDescription::Serialize() const { PJRT_TopologyDescription_Serialize_Args args; args.struct_size = PJRT_TopologyDescription_Serialize_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.topology = c_topology_; RETURN_STATUS_IF_PJRT_ERROR(c_api_->PJRT_TopologyDescription_Serialize(&args), c_api_); @@ -2123,7 +2123,7 @@ StatusOr PjRtCApiTopologyDescription::Serialize() const { void PjRtCApiTopologyDescription::InitAttributes() { PJRT_TopologyDescription_Attributes_Args args; args.struct_size = PJRT_TopologyDescription_Attributes_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.topology = c_topology_; pjrt::LogFatalIfPjrtError(c_api_->PJRT_TopologyDescription_Attributes(&args), c_api_); @@ -2139,7 +2139,7 @@ static StatusOr> InitializeArgsAndCompileAot( const std::string& format) { PJRT_Compile_Args args; args.struct_size = PJRT_Compile_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; if (client == nullptr) { args.client = nullptr; } else { @@ -2157,7 +2157,7 @@ static StatusOr> InitializeArgsAndCompileAot( PJRT_Program program; program.struct_size = PJRT_Program_STRUCT_SIZE; - program.priv = nullptr; + program.extension_start = nullptr; program.code = const_cast(code.c_str()); program.code_size = code.size(); program.format = format.c_str(); @@ -2203,7 +2203,7 @@ StatusOr> GetCApiClient( PJRT_Client_Create_Args init_args; init_args.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE; - init_args.priv = nullptr; + init_args.extension_start = nullptr; TF_ASSIGN_OR_RETURN( std::vector c_options, pjrt::ConvertToPjRtNamedValueList(create_options, @@ -2237,7 +2237,7 @@ StatusOr> GetCApiTopology( PJRT_TopologyDescription_Create_Args init_args; init_args.struct_size = PJRT_TopologyDescription_Create_Args_STRUCT_SIZE; - init_args.priv = nullptr; + init_args.extension_start = nullptr; TF_ASSIGN_OR_RETURN( std::vector c_options, pjrt::ConvertToPjRtNamedValueList(create_options, diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client_test.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client_test.cc index 4c9ec7c5b8e915..b308ff82bb1b52 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client_test.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client_test.cc @@ -150,7 +150,7 @@ TEST(PjRtStreamExecutorClientTest, DonateWithControlDependency) { EXPECT_FALSE(got_literal); - avr.emplace(tsl::OkStatus()); + avr.emplace(absl::OkStatus()); EXPECT_TRUE(future.IsReady()); { diff --git a/third_party/xla/xla/pjrt/plugin/BUILD b/third_party/xla/xla/pjrt/plugin/BUILD index e2504e175fefc9..4b9aa12b80b0a1 100644 --- a/third_party/xla/xla/pjrt/plugin/BUILD +++ b/third_party/xla/xla/pjrt/plugin/BUILD @@ -30,12 +30,12 @@ load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") # ** Please don't remove this file - it is supporting some 3rd party plugins ** package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], ) cc_library( name = "plugin", - visibility = ["//visibility:public"], deps = [ #"//xla/pjrt/plugin/example:example_lib", ], diff --git a/third_party/xla/xla/pjrt/tf_pjrt_client.h b/third_party/xla/xla/pjrt/tf_pjrt_client.h index a8df9dc59e7d8d..7aaed8f25e90ab 100644 --- a/third_party/xla/xla/pjrt/tf_pjrt_client.h +++ b/third_party/xla/xla/pjrt/tf_pjrt_client.h @@ -340,6 +340,10 @@ class TfPjRtClient : public PjRtClient { StatusOr CreateHostToDeviceChannelHandle() override { return wrapped_->CreateHostToDeviceChannelHandle(); } + StatusOr GetTopologyDescription() + const override { + return wrapped_->GetTopologyDescription(); + } Status Defragment() override { return wrapped_->Defragment(); } PjRtClient* wrapped() const { return wrapped_.get(); } diff --git a/third_party/xla/xla/pjrt/tracked_device_buffer.cc b/third_party/xla/xla/pjrt/tracked_device_buffer.cc index d812054b66e131..c1a699e2de4b81 100644 --- a/third_party/xla/xla/pjrt/tracked_device_buffer.cc +++ b/third_party/xla/xla/pjrt/tracked_device_buffer.cc @@ -79,7 +79,7 @@ void BufferSequencingEvent::WaitForEventOnStream(se::Stream* stream) { return; } - stream->ThenWaitFor(event_.event()); + stream->WaitFor(event_.event()).IgnoreError(); streams_defined_on_.push_back(stream); } diff --git a/third_party/xla/xla/primitive_util.cc b/third_party/xla/xla/primitive_util.cc index 98168390720d5e..75f263ced59939 100644 --- a/third_party/xla/xla/primitive_util.cc +++ b/third_party/xla/xla/primitive_util.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/strings/ascii.h" #include "absl/strings/string_view.h" #include "xla/statusor.h" +#include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" @@ -93,6 +94,17 @@ bool HasInfinity(PrimitiveType type) { return false; } +bool HasNegativeZero(PrimitiveType type) { + if (ABSL_PREDICT_TRUE(IsFloatingPointType(type))) { + return FloatingPointTypeSwitch( + [&](auto constant_type) -> bool { + return has_negative_zero_v>; + }, + type); + } + return false; +} + xla::PrimitiveType SignedIntegralTypeForBitWidth(int64_t src_bitwidth) { switch (src_bitwidth) { case 4: @@ -166,7 +178,7 @@ GetPrimitiveTypeStringMap() { } // namespace -StatusOr StringToPrimitiveType(absl::string_view name) { +absl::StatusOr StringToPrimitiveType(absl::string_view name) { const auto& map = GetPrimitiveTypeStringMap(); auto found = map.find(name); if (found == map.end()) { diff --git a/third_party/xla/xla/primitive_util.h b/third_party/xla/xla/primitive_util.h index d9e932ba673e9b..627a0fa5a44a81 100644 --- a/third_party/xla/xla/primitive_util.h +++ b/third_party/xla/xla/primitive_util.h @@ -68,6 +68,9 @@ int ExponentBias(PrimitiveType type); // Returns whether the type has a value for infinity. bool HasInfinity(PrimitiveType type); +// Returns whether the type has a value for negative zero. +bool HasNegativeZero(PrimitiveType type); + // Returns the XLA primitive type (eg, F32) corresponding to the given // template parameter native type (eg, float). template @@ -705,7 +708,7 @@ const std::string& LowercasePrimitiveTypeName(PrimitiveType s); // Returns the PrimitiveType matching the given name. The given name is expected // to be lower-case. -StatusOr StringToPrimitiveType(absl::string_view name); +absl::StatusOr StringToPrimitiveType(absl::string_view name); // Returns true if the given name is a primitive type string (lower-case). bool IsPrimitiveTypeName(absl::string_view name); diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index dae405d80036f0..c6367e24e1b59f 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -11,6 +11,7 @@ load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") load( "@local_tsl//tsl:tsl.bzl", "if_cuda_or_rocm", + "internal_visibility", ) load("@local_tsl//tsl:tsl.default.bzl", "tsl_pybind_extension") load("@local_tsl//tsl/platform:build_config.bzl", "pyx_library", "tf_proto_library") @@ -21,7 +22,10 @@ load( ) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + ":friends", + ]), licenses = ["notice"], ) @@ -44,13 +48,10 @@ pytype_strict_library( ], ) -exports_files( - [ - "xla_client.py", - "xla_client.pyi", - ], - visibility = ["//visibility:public"], -) +exports_files([ + "xla_client.py", + "xla_client.pyi", +]) pyx_library( name = "custom_call_for_test", @@ -178,7 +179,7 @@ cc_library( "-fno-strict-aliasing", ], features = ["-use_header_modules"], - visibility = ["//visibility:public"], + visibility = [":friends"], deps = [ "//xla:literal", "//xla:shape_util", @@ -206,7 +207,7 @@ cc_library( "-fno-strict-aliasing", ], features = ["-use_header_modules"], - visibility = ["//visibility:public"], + visibility = [":friends"], deps = [ "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", @@ -224,7 +225,6 @@ cc_library( "-fno-strict-aliasing", ], features = ["-use_header_modules"], - visibility = ["//visibility:public"], deps = [ "//xla:status_macros", "//xla:util", @@ -243,7 +243,7 @@ cc_library( "-fno-strict-aliasing", ], features = ["-use_header_modules"], - visibility = ["//visibility:public"], + visibility = [":friends"], deps = [ ":python_ref_manager", # placeholder for index annotation deps @@ -251,7 +251,9 @@ cc_library( "@com_google_absl//absl/hash", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@local_config_python//:python_headers", # buildcleaner: keep "//xla/pjrt:exceptions", + "@local_tsl//tsl/platform", "@local_tsl//tsl/platform:logging", "@pybind11", ], @@ -267,7 +269,6 @@ cc_library( "-fno-strict-aliasing", ], features = ["-use_header_modules"], - visibility = ["//visibility:public"], deps = [ ":traceback", "//xla:statusor", @@ -313,7 +314,6 @@ cc_library( "TENSORFLOW_USE_ROCM=1", ]), features = ["-use_header_modules"], - visibility = ["//visibility:public"], deps = [ ":callback", ":pprof_profile_builder", @@ -373,6 +373,7 @@ cc_library( "@pybind11_abseil//pybind11_abseil:absl_casters", ] + if_cuda([ "@local_config_cuda//cuda:cuda_headers", + "//xla/stream_executor/cuda:cuda_driver", ]) + if_rocm([ "@local_config_rocm//rocm:rocm_headers", ]), @@ -391,7 +392,6 @@ cc_library( "-fno-strict-aliasing", ], features = ["-use_header_modules"], - visibility = ["//visibility:public"], deps = [ ":python_ref_manager", "//xla:comparison_util", @@ -422,7 +422,6 @@ cc_library( "TENSORFLOW_USE_ROCM=1", ]), features = ["-use_header_modules"], - visibility = ["//visibility:public"], deps = [ ":callback", "//xla:comparison_util", @@ -449,7 +448,6 @@ cc_library( "-fno-strict-aliasing", ], features = ["-use_header_modules"], - visibility = ["//visibility:public"], deps = [ ":py_client", ":python_ref_manager", @@ -479,7 +477,7 @@ cc_library( "-fno-strict-aliasing", ], features = ["-use_header_modules"], - visibility = ["//visibility:public"], + visibility = [":friends"], # For the functions to access C++ flags/thread-local variables deps = [ ":py_client", ":python_ref_manager", @@ -504,7 +502,6 @@ cc_library( name = "inspect_sharding", srcs = ["inspect_sharding.cc"], hdrs = ["inspect_sharding.h"], - visibility = ["//visibility:public"], deps = [ "//xla/hlo/ir:hlo", "//xla/service:custom_call_sharding_helper", @@ -524,7 +521,7 @@ cc_library( "-fno-strict-aliasing", ], features = ["-use_header_modules"], - visibility = ["//visibility:public"], + visibility = ["//visibility:private"], deps = [ ":inspect_sharding", # placeholder for index annotation deps @@ -552,7 +549,6 @@ cc_library( "-fno-strict-aliasing", ], features = ["-use_header_modules"], - visibility = ["//visibility:public"], deps = [ ":types", # placeholder for index annotation deps @@ -578,7 +574,6 @@ cc_library( name = "outfeed_receiver", srcs = ["outfeed_receiver.cc"], hdrs = ["outfeed_receiver.h"], - visibility = ["//visibility:public"], deps = [ "//xla:literal", "//xla:shape_util", @@ -606,7 +601,7 @@ cc_library( "-fno-strict-aliasing", ], features = ["-use_header_modules"], - visibility = ["//visibility:public"], + visibility = ["//visibility:private"], deps = [ ":jax_jit", ":py_client", @@ -635,7 +630,7 @@ cc_library( "-fno-strict-aliasing", ], features = ["-use_header_modules"], - visibility = ["//visibility:public"], + visibility = ["//visibility:private"], deps = [ ":jax_jit", ":py_client", @@ -692,7 +687,6 @@ cc_library( "-fno-strict-aliasing", ], features = ["-use_header_modules"], - visibility = ["//visibility:public"], deps = [ ":outfeed_receiver", ":py_client", @@ -740,7 +734,7 @@ cc_library( "-fno-strict-aliasing", ], features = ["-use_header_modules"], - visibility = ["//visibility:public"], + visibility = [":friends"], deps = [ ":pytree_proto_cc", # placeholder for index annotation deps @@ -767,7 +761,6 @@ cc_library( "-fno-strict-aliasing", ], features = ["-use_header_modules"], - visibility = ["//visibility:public"], deps = [ ":refine_polymorphic_shapes", ":types", @@ -801,7 +794,6 @@ cc_library( name = "refine_polymorphic_shapes", srcs = ["refine_polymorphic_shapes.cc"], hdrs = ["refine_polymorphic_shapes.h"], - visibility = ["//visibility:public"], deps = [ "//xla/mlir/utils:error_util", "@com_google_absl//absl/status", @@ -831,7 +823,6 @@ cc_library( "-fno-strict-aliasing", ], features = ["-use_header_modules"], - visibility = ["//visibility:public"], deps = [ ":types", ":xplane_to_profile_instructions", @@ -872,7 +863,7 @@ cc_library( "-fno-strict-aliasing", ], features = ["-use_header_modules"], - visibility = ["//visibility:public"], + visibility = [":friends"], deps = [ # placeholder for index annotation deps "@com_google_absl//absl/base:core_headers", @@ -894,7 +885,6 @@ cc_library( "-fno-strict-aliasing", ], features = ["-use_header_modules"], - visibility = ["//visibility:public"], deps = [ "//xla:status", "//xla:util", @@ -916,7 +906,7 @@ cc_library( "-fno-strict-aliasing", ], features = ["-use_header_modules"], - visibility = ["//visibility:public"], + visibility = ["//visibility:private"], deps = [ # placeholder for index annotation deps "@com_google_absl//absl/cleanup", @@ -937,13 +927,14 @@ cc_library( "-fno-strict-aliasing", ], features = ["-use_header_modules"], - visibility = ["//visibility:public"], deps = [ ":py_client", ":types", # placeholder for index annotation deps "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "//xla:array", @@ -956,6 +947,9 @@ cc_library( "//xla/client:executable_build_options", "//xla/client:xla_builder", "//xla/client:xla_computation", + "//xla/ffi", + "//xla/ffi:ffi_api", + "//xla/ffi/api:c_api", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", "//xla/pjrt:exceptions", @@ -981,7 +975,6 @@ tf_proto_library( name = "py_host_callback_proto", srcs = ["py_host_callback.proto"], cc_api_version = 2, - visibility = ["//visibility:public"], ) # TODO(phawkins): the configuration settings here are overly confusing. The right fix is to split @@ -990,7 +983,6 @@ tf_proto_library( config_setting( name = "link_gpu_plugin", define_values = {"xla_python_enable_gpu": "true"}, - visibility = ["//visibility:public"], ) bool_flag( @@ -1003,7 +995,6 @@ config_setting( flag_values = { ":enable_gpu": "True", }, - visibility = ["//visibility:public"], ) # If this flag is enabled, it sets RPATH on the xla_extension to values that are suitable for @@ -1018,14 +1009,12 @@ config_setting( flag_values = { ":jax_cuda_pip_rpaths": "True", }, - visibility = ["//visibility:public"], ) # We cannot nest select and if_cuda_is_configured so we introduce # a standalone cc_library target. cc_library( name = "gpu_plugin_deps", - visibility = ["//visibility:public"], deps = [ "//xla/service:gpu_plugin", ] + if_cuda_is_configured([ @@ -1079,7 +1068,6 @@ cc_library( "//conditions:default": [], }), features = ["-use_header_modules"], - visibility = ["//visibility:public"], deps = [ ":custom_call_sharding", ":dlpack", @@ -1166,7 +1154,6 @@ cc_library( name = "xplane_to_profile_instructions", srcs = ["xplane_to_profile_instructions.cc"], hdrs = ["xplane_to_profile_instructions.h"], - visibility = ["//visibility:public"], deps = [ "//xla:status", "//xla:xla_proto_cc", @@ -1211,7 +1198,6 @@ xla_cc_test( cc_library( name = "status_casters", hdrs = ["status_casters.h"], - visibility = ["//visibility:public"], deps = [ "//xla/pjrt:status_casters", ], diff --git a/third_party/xla/xla/python/ifrt/BUILD b/third_party/xla/xla/python/ifrt/BUILD index 025c70f94b776b..c82387ee095616 100644 --- a/third_party/xla/xla/python/ifrt/BUILD +++ b/third_party/xla/xla/python/ifrt/BUILD @@ -1,4 +1,5 @@ load("//xla:xla.bzl", "xla_cc_test") +load("@local_tsl//tsl:tsl.bzl", "internal_visibility") load("@local_tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") load("@local_tsl//tsl/platform:build_config.bzl", "tf_proto_library") @@ -20,15 +21,16 @@ package_group( ) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + ":friends", + ":internal", + ]), ) -exports_files( - [ - "BUILD", - ], - visibility = ["//visibility:public"], -) +exports_files([ + "BUILD", +]) cc_library( name = "ifrt", @@ -67,7 +69,6 @@ cc_library( "value.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":serdes", ":types_proto_cc", @@ -186,7 +187,6 @@ cc_library( testonly = True, srcs = ["test_util.cc"], hdrs = ["test_util.h"], - visibility = ["//visibility:public"], deps = [ ":ifrt", "//xla:statusor", @@ -205,7 +205,6 @@ cc_library( testonly = True, srcs = ["sharding_test_util.cc"], hdrs = ["sharding_test_util.h"], - visibility = ["//visibility:public"], deps = [ ":ifrt", ":mock", @@ -219,7 +218,6 @@ cc_library( name = "no_impl_test_main", testonly = True, srcs = ["no_impl_test_main.cc"], - visibility = ["//visibility:public"], deps = [ "@com_google_googletest//:gtest", ], @@ -229,7 +227,6 @@ cc_library( name = "array_impl_test_lib", testonly = True, srcs = ["array_impl_test_lib.cc"], - visibility = ["//visibility:public"], deps = [ ":ifrt", ":test_util", @@ -254,7 +251,6 @@ cc_library( name = "client_impl_test_lib", testonly = True, srcs = ["client_impl_test_lib.cc"], - visibility = ["//visibility:public"], deps = [ ":ifrt", ":test_util", @@ -277,7 +273,6 @@ cc_library( name = "tuple_impl_test_lib", testonly = True, srcs = ["tuple_impl_test_lib.cc"], - visibility = ["//visibility:public"], deps = [ ":ifrt", ":test_util", @@ -304,7 +299,6 @@ cc_library( testonly = True, srcs = ["mock.cc"], hdrs = ["mock.h"], - visibility = ["//visibility:public"], deps = [ ":ifrt", "//xla:literal", @@ -326,7 +320,6 @@ cc_library( srcs = ["serdes.cc"], hdrs = ["serdes.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":serdes_proto_cc", "@com_google_absl//absl/base:core_headers", @@ -360,7 +353,6 @@ xla_cc_test( tf_proto_library( name = "serdes_proto", srcs = ["serdes.proto"], - visibility = ["//visibility:public"], ) cc_library( @@ -368,7 +360,6 @@ cc_library( srcs = ["sharding_serdes.cc"], hdrs = ["sharding_serdes.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":ifrt", ":serdes", @@ -399,12 +390,10 @@ xla_cc_test( tf_proto_library( name = "types_proto", srcs = ["types.proto"], - visibility = ["//visibility:public"], ) tf_proto_library( name = "sharding_proto", srcs = ["sharding.proto"], protodeps = [":types_proto"], - visibility = ["//visibility:public"], ) diff --git a/third_party/xla/xla/python/ifrt/client.h b/third_party/xla/xla/python/ifrt/client.h index f9639b9c206945..42dd3026264830 100644 --- a/third_party/xla/xla/python/ifrt/client.h +++ b/third_party/xla/xla/python/ifrt/client.h @@ -124,6 +124,14 @@ class Client : public llvm::RTTIExtends { virtual absl::string_view platform_version() const = 0; virtual PlatformId platform_id() const = 0; + // Returns the attributes of the client. In principle, these try to describe + // capabilities of a client rather than being a "feature flag". + // + // List of officially supported attributes: + // + // * supports_executable_serialization (bool; default = true): Whether IFRT + // executables produced by this client are serializable. If false, all + // executables from this client are considered not serializable. using ClientAttribute = xla::PjRtValueType; virtual absl::flat_hash_map attributes() const = 0; diff --git a/third_party/xla/xla/python/ifrt/ir/BUILD b/third_party/xla/xla/python/ifrt/ir/BUILD index 04d0808b8bc5e7..4084880cbb77dc 100644 --- a/third_party/xla/xla/python/ifrt/ir/BUILD +++ b/third_party/xla/xla/python/ifrt/ir/BUILD @@ -2,7 +2,6 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") load("@local_tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") package( - default_visibility = ["//visibility:public"], # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) @@ -130,7 +129,7 @@ cc_library( "sharding_param.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], + visibility = ["//xla/python/ifrt:friends"], deps = [ ":ifrt_dialect_inc_gen", ":ifrt_interfaces_inc_gen", @@ -147,7 +146,7 @@ cc_library( srcs = ["compiler.cc"], hdrs = ["compiler.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], + visibility = ["//xla/python/ifrt:friends"], deps = [ "//xla/python/ifrt", "@com_google_absl//absl/container:flat_hash_map", diff --git a/third_party/xla/xla/python/ifrt/ir/tests/BUILD b/third_party/xla/xla/python/ifrt/ir/tests/BUILD index e5f261d6088d8a..75bb1618e6e320 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/BUILD +++ b/third_party/xla/xla/python/ifrt/ir/tests/BUILD @@ -2,7 +2,6 @@ load("//xla:lit.bzl", "enforce_glob", "lit_test_suite") load("//xla:xla.bzl", "xla_cc_test") package( - default_visibility = ["//visibility:public"], # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) @@ -52,7 +51,7 @@ cc_library( testonly = True, srcs = ["executable_impl_test_base.cc"], hdrs = ["executable_impl_test_base.h"], - visibility = ["//visibility:public"], + visibility = ["//xla/python/ifrt:friends"], deps = [ "//xla:status_macros", "//xla/mlir_hlo:hlo_dialect_registration", @@ -76,7 +75,7 @@ cc_library( name = "executable_impl_test_lib", testonly = True, srcs = ["executable_impl_test_lib.cc"], - visibility = ["//visibility:public"], + visibility = ["//xla/python/ifrt:friends"], deps = [ ":executable_impl_test_base", "//xla/pjrt:pjrt_executable", diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/BUILD b/third_party/xla/xla/python/ifrt/ir/transforms/BUILD index 72a6895c28bd4a..4905efe8d33189 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/BUILD +++ b/third_party/xla/xla/python/ifrt/ir/transforms/BUILD @@ -2,7 +2,8 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") load("@local_tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//xla/python/ifrt:friends"], licenses = ["notice"], ) @@ -34,7 +35,6 @@ cc_library( ], hdrs = ["passes.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":constants", ":passes_inc_gen", @@ -57,7 +57,6 @@ cc_library( srcs = ["built_in_spmd_expansions.cc"], hdrs = ["built_in_spmd_expansions.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "//xla/python/ifrt/ir/transforms/spmd_expanders:spmd_expander", "@llvm-project//mlir:FuncDialect", @@ -69,6 +68,5 @@ cc_library( name = "constants", hdrs = ["constants.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = ["@llvm-project//llvm:Support"], ) diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/spmd_expanders/BUILD b/third_party/xla/xla/python/ifrt/ir/transforms/spmd_expanders/BUILD index 1af96031fadd64..090ca21d09dcbc 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/spmd_expanders/BUILD +++ b/third_party/xla/xla/python/ifrt/ir/transforms/spmd_expanders/BUILD @@ -1,7 +1,6 @@ load("@local_tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") package( - default_visibility = ["//visibility:public"], # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) @@ -15,7 +14,7 @@ cc_library( "*spmd_expander.h", ]), compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], + visibility = ["//xla/python/ifrt:friends"], deps = [ "//xla/python/ifrt/ir", "@llvm-project//llvm:Support", diff --git a/third_party/xla/xla/python/ifrt/support/BUILD b/third_party/xla/xla/python/ifrt/support/BUILD index a2b245904e8360..1fde84af53aed4 100644 --- a/third_party/xla/xla/python/ifrt/support/BUILD +++ b/third_party/xla/xla/python/ifrt/support/BUILD @@ -1,7 +1,6 @@ load("//xla:xla.bzl", "xla_cc_test") package( - default_visibility = ["//visibility:public"], # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) @@ -10,7 +9,7 @@ cc_library( name = "sharding_conversions", srcs = ["sharding_conversions.cc"], hdrs = ["sharding_conversions.h"], - visibility = ["//visibility:public"], + visibility = ["//xla/python/ifrt:friends"], deps = [ "//xla:statusor", "//xla:xla_data_proto_cc", diff --git a/third_party/xla/xla/python/ifrt_proxy/client/BUILD b/third_party/xla/xla/python/ifrt_proxy/client/BUILD new file mode 100644 index 00000000000000..7a6071ad0b2aee --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/BUILD @@ -0,0 +1,519 @@ +# Copyright 2023 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//xla/python/ifrt_proxy/common:ifrt_proxy.bzl", "default_ifrt_proxy_visibility", "ifrt_proxy_cc_test") +load("@local_tsl//tsl:tsl.bzl", "if_google") +load("@local_tsl//tsl:tsl.default.bzl", "tsl_pybind_extension") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = default_ifrt_proxy_visibility, +) + +cc_library( + name = "grpc_client_session", + srcs = [ + "grpc_client_session.cc", + ], + hdrs = ["grpc_client_session.h"], + deps = [ + ":client_session", + "//xla/pjrt/distributed:util", + "//xla/python/ifrt", + "//xla/python/ifrt_proxy/common:grpc_credentials", + "//xla/python/ifrt_proxy/common:grpc_ifrt_service_cc_grpc_proto", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "@com_github_grpc_grpc//:grpc", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:bind_front", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:unbounded_work_queue", + ], +) + +ifrt_proxy_cc_test( + name = "grpc_client_session_test", + srcs = [ + "grpc_client_session_test.cc", + ], + deps = [ + ":grpc_client_session", + ":version", + "//xla/python/ifrt_proxy/common:grpc_credentials", + "//xla/python/ifrt_proxy/common:grpc_ifrt_service_cc_grpc_proto", + "//xla/python/ifrt_proxy/common:grpc_ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "@com_github_grpc_grpc//:gpr", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/log:log_sink_registry", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "rpc_helper", + srcs = [ + "rpc_helper.cc", + ], + hdrs = ["rpc_helper.h"], + deps = [ + ":client_session", + ":host_buffer", + "//xla/python/ifrt", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@local_tsl//tsl/platform:status_to_from_proto", + ] + if_google(["@com_google_absl//absl/types:source_location"]), +) + +cc_library( + name = "client", + srcs = ["client.cc"], + hdrs = ["client.h"], + deps = [ + ":array", + ":compiler", + ":device", + ":memory", + ":rpc_helper", + "//xla:xla_data_proto_cc", + "//xla/pjrt:pjrt_compiler", + "//xla/pjrt:pjrt_device_description", + "//xla/python/ifrt", + "//xla/python/ifrt_proxy/common:common_serdes", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:types", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@local_tsl//tsl/concurrency:ref_count", + "@local_tsl//tsl/platform:statusor", + ], +) + +ifrt_proxy_cc_test( + name = "client_test", + srcs = ["client_test.cc"], + deps = [ + ":client", + ":client_session", + ":host_buffer", + ":mock_client_session", + ":mock_host_buffer", + ":rpc_helper", + ":version", + "//xla/pjrt:pjrt_device_description", + "//xla/python/ifrt", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "//xla/service:computation_placer_hdr", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform", + "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "device", + srcs = ["device.cc"], + hdrs = ["device.h"], + deps = [ + "//xla:literal", + "//xla/pjrt:pjrt_client", + "//xla/pjrt:pjrt_device_description", + "//xla/pjrt:pjrt_future", + "//xla/python/ifrt", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "array", + srcs = ["array.cc"], + hdrs = ["array.h"], + deps = [ + ":rpc_helper", + "//xla:status_macros", + "//xla/python/ifrt", + "//xla/python/ifrt_proxy/common:array_util", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:types", + "//xla/python/ifrt_proxy/common:types_proto_cc", + "//xla/python/pjrt_ifrt", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@local_tsl//tsl/concurrency:ref_count", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +ifrt_proxy_cc_test( + name = "array_test", + srcs = ["array_test.cc"], + deps = [ + ":array", + ":client_session", + ":host_buffer", + ":mock_client_session", + ":mock_host_buffer", + ":rpc_helper", + ":version", + "//xla/python/ifrt", + "//xla/python/ifrt:mock", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:types", + "//xla/python/ifrt_proxy/common:types_proto_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/concurrency:ref_count", + "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "client_session", + hdrs = ["client_session.h"], + deps = [ + "//xla/python/ifrt", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "mock_client_session", + testonly = True, + hdrs = ["mock_client_session.h"], + deps = [ + ":client_session", + "//xla/python/ifrt", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "compiler", + srcs = ["compiler.cc"], + hdrs = ["compiler.h"], + deps = [ + ":executable", + ":rpc_helper", + "//xla/pjrt:host_callback", + "//xla/python/ifrt", + "//xla/python/ifrt:serdes", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/server:host_callback", + "//xla/python/pjrt_ifrt", + "//xla/python/pjrt_ifrt:xla_ifrt", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@local_tsl//tsl/concurrency:ref_count", + "@local_tsl//tsl/platform:status_to_from_proto", + "@local_tsl//tsl/platform:statusor", + ], +) + +ifrt_proxy_cc_test( + name = "compiler_test", + srcs = ["compiler_test.cc"], + deps = [ + ":client_session", + ":compiler", + ":host_buffer", + ":mock_client_session", + ":mock_host_buffer", + ":rpc_helper", + ":version", + "//xla/python/ifrt", + "//xla/python/ifrt:mock", + "//xla/python/ifrt:serdes", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "executable", + srcs = ["executable.cc"], + hdrs = ["executable.h"], + deps = [ + ":array", + ":host_buffer", + ":rpc_helper", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/pjrt:host_callback", + "//xla/pjrt:pjrt_executable", + "//xla/python/ifrt", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:types", + "//xla/python/pjrt_ifrt", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:node_hash_set", + "@com_google_absl//absl/functional:bind_front", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@local_tsl//tsl/concurrency:ref_count", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status_to_from_proto", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "host_buffer", + hdrs = ["host_buffer.h"], + deps = [ + "//xla/python/ifrt", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + ], +) + +cc_library( + name = "mock_host_buffer", + testonly = True, + hdrs = ["mock_host_buffer.h"], + deps = [ + ":host_buffer", + "//xla/python/ifrt", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "grpc_host_buffer", + srcs = ["grpc_host_buffer.cc"], + hdrs = ["grpc_host_buffer.h"], + deps = [ + ":host_buffer", + "//xla/pjrt/distributed:util", + "//xla/python/ifrt", + "//xla/python/ifrt_proxy/common:grpc_ifrt_service_cc_grpc_proto", + "//xla/python/ifrt_proxy/common:grpc_ifrt_service_proto_cc", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:unbounded_work_queue", + "@local_tsl//tsl/protobuf:status_proto_cc", + ], +) + +cc_library( + name = "grpc_client", + srcs = ["grpc_client.cc"], + deps = [ + ":client", + ":grpc_client_session", + ":grpc_host_buffer", + ":registry", + ":rpc_helper", + ":version", + "//xla/pjrt/distributed:util", + "//xla/python/ifrt", + "//xla/python/ifrt_proxy/common:grpc_ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/log:log_entry", + "@com_google_absl//absl/log:log_sink", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], + alwayslink = True, +) + +cc_library( + name = "registry", + srcs = ["registry.cc"], + hdrs = ["registry.h"], + deps = [ + "//xla/python/ifrt", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + ], +) + +cc_library( + name = "memory", + hdrs = ["memory.h"], + deps = [ + "//xla/pjrt:pjrt_client", + "//xla/python/ifrt", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "version", + hdrs = ["version.h"], +) + +ifrt_proxy_cc_test( + name = "executable_test", + srcs = ["executable_test.cc"], + deps = [ + ":array", + ":client_session", + ":executable", + ":host_buffer", + ":mock_client_session", + ":mock_host_buffer", + ":rpc_helper", + ":version", + "//xla:shape_util", + "//xla/pjrt:pjrt_common", + "//xla/python/ifrt", + "//xla/python/ifrt:mock", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:types", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@local_tsl//tsl/concurrency:ref_count", + "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + +tsl_pybind_extension( + name = "py_module", + srcs = ["py_module.cc"], + deps = [ + ":grpc_client", + ":registry", + "//xla/pjrt:status_casters", + "//xla/python:py_client", + "//xla/python/ifrt", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/log:log_entry", + "@com_google_absl//absl/log:log_sink", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:statusor", + "@pybind11", + "@pybind11_abseil//pybind11_abseil:absl_casters", + "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", + ], +) diff --git a/third_party/xla/xla/python/ifrt_proxy/client/README.md b/third_party/xla/xla/python/ifrt_proxy/client/README.md new file mode 100644 index 00000000000000..b97be206439a1c --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/README.md @@ -0,0 +1,11 @@ +This directory implements the IFRT proxy client. + +## Expected behavior when connection to the IFRT proxy server fails + +If a connection to the proxy server fails abruptly, any in-progress or further +IFRT API calls and `Future`s are expected to either return valid values (if the +value was already fetched from the server and is being cached locally) or an +error from `rpc_helper.cc`'s `WrapAsConnectionError()`. They are expected to +neither "hang" beyond the brief period required to determine whether the +connection has failed nor crash the process internally within the proxy client +library. diff --git a/third_party/xla/xla/python/ifrt_proxy/client/array.cc b/third_party/xla/xla/python/ifrt_proxy/client/array.cc new file mode 100644 index 00000000000000..58c01c1d062906 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/array.cc @@ -0,0 +1,345 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/client/array.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/cleanup/cleanup.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt_proxy/client/rpc_helper.h" +#include "xla/python/ifrt_proxy/common/array_util.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/types.h" +#include "xla/python/ifrt_proxy/common/types.pb.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/status_macros.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +char Array::ID = 0; + +absl::StatusOr> +Array::MakeArrayFromHostBuffer( + xla::ifrt::Client* client, std::shared_ptr rpc_helper, + const void* data, DType dtype, Shape shape, + std::optional> byte_strides, + std::shared_ptr sharding, + xla::ifrt::Client::HostBufferSemantics semantics, + std::function on_done_with_host_buffer) { + TF_ASSIGN_OR_RETURN(const auto array_mem_region, + ArrayMemRegion::FromZerothElementPointer( + /*zeroth_element=*/data, dtype, shape, byte_strides)); + + const uint64_t host_buffer_handle = + rpc_helper->host_buffer_store()->NextHandle(); + TF_RETURN_IF_ERROR( + rpc_helper->host_buffer_store() + ->Store(host_buffer_handle, array_mem_region.mem_region()) + .Await()); + + auto req = std::make_unique(); + req->set_host_buffer_handle(host_buffer_handle); + req->set_dtype(ToDTypeProto(dtype)); + *req->mutable_shape() = ToShapeProto(shape); + TF_ASSIGN_OR_RETURN(*req->mutable_sharding(), ToShardingProto(*sharding)); + if (byte_strides.has_value()) { + *req->mutable_byte_strides() = ToByteStridesProto(*byte_strides); + } + + TF_ASSIGN_OR_RETURN( + auto response, + rpc_helper->MakeArrayFromHostBuffer(std::move(req)).Await()); + const ArrayHandle handle{.handle = response->array_handle()}; + + if (on_done_with_host_buffer != nullptr) { + std::move(on_done_with_host_buffer)(); + } + + return tsl::RCReference( + tsl::MakeRef(client, std::move(rpc_helper), dtype, + std::move(shape), std::move(sharding), handle)); +} + +void Array::Destruct(RpcHelper* rpc_helper, ArrayHandle handle) { + auto req = std::make_unique(); + req->set_array_handle(handle.handle); + rpc_helper->DestructArray(std::move(req)) + .OnReady( + [](absl::StatusOr> response) { + if (!response.ok()) { + LOG(WARNING) + << "Server returned an error when asked to destruct array: " + << response.status(); + } + }); +} + +Future Array::GetReadyFuture() const { + auto req = std::make_unique(); + req->set_array_handle(handle_.handle); + + auto promise = Future::CreatePromise(); + rpc_helper_->CheckArrayReady(std::move(req)) + .OnReady( + [promise](absl::StatusOr> + resp) mutable -> void { promise.Set(resp.status()); }); + return Future(std::move(promise)); +} + +Future Array::Delete() { + auto req = std::make_unique(); + req->set_array_handle(handle_.handle); + + absl::StatusOr> response = + rpc_helper_->DeleteArray(std::move(req)).Await(); + if (!response.ok()) { + return Future(response.status()); + } + + // TODO(b/266635130): So that the caller is not blocked until the server + // replies with the deletion's response, from within + // `Future(status_handle_promise).OnReady()`, schedule `CheckFuture()` on a + // separate thread. + return rpc_helper_->CheckFuture((*response)->deletion_future_handle()); +} + +bool Array::IsDeleted() const { + auto req = std::make_unique(); + req->set_array_handle(handle_.handle); + + absl::StatusOr> response = + rpc_helper_->IsArrayDeleted(std::move(req)).Await(); + if (response.ok()) { + return (*response)->deleted(); + } else { + LOG(ERROR) << "Internal error from proxy server during Array::IsDeleted(): " + << response.status(); + // Return false so that the user likely queries the array with some + // method that returns an absl::Status, and ends up with the real + // error being returned to them by that method. + return false; + } +} + +absl::StatusOr> +Array::AssembleArrayFromSingleDeviceArrays( + xla::ifrt::Client* client, std::shared_ptr rpc_helper, + Shape shape, std::shared_ptr sharding, + absl::Span> arrays, + ArrayCopySemantics semantics) { + auto req = std::make_unique(); + TF_RET_CHECK(!arrays.empty()); + *req->mutable_shape() = ToShapeProto(shape); + TF_ASSIGN_OR_RETURN(*req->mutable_sharding(), ToShardingProto(*sharding)); + req->set_copy_semantics(ToArrayCopySemanticsProto(semantics)); + for (const tsl::RCReference& rcref : arrays) { + Array* array = llvm::dyn_cast(rcref.get()); + if (array == nullptr) { + return absl::InvalidArgumentError(absl::Substitute( + "Array at $0 supplied to AssembleArrayFromSingleDeviceArrays() is " + "not a xla::ifrt::proxy::Array.", + rcref.get())); + } + req->add_single_device_array_handles(array->handle_.handle); + } + + TF_ASSIGN_OR_RETURN( + std::shared_ptr response, + rpc_helper->AssembleArrayFromSingleDeviceArrays(std::move(req)).Await()); + ArrayHandle handle{.handle = response->array_handle()}; + + return tsl::RCReference( + tsl::MakeRef(client, std::move(rpc_helper), arrays[0]->dtype(), + std::move(shape), std::move(sharding), handle)); +} + +absl::StatusOr>> +Array::DisassembleIntoSingleDeviceArrays(ArrayCopySemantics semantics) { + auto req = std::make_unique(); + req->set_array_handle(handle_.handle); + req->set_copy_semantics(ToArrayCopySemanticsProto(semantics)); + + TF_ASSIGN_OR_RETURN( + std::shared_ptr response, + rpc_helper_->DisassembleIntoSingleDeviceArrays(std::move(req)).Await()); + std::vector handles; + for (auto& handle : response->single_device_array_handles()) { + handles.push_back(ArrayHandle{.handle = handle}); + } + + TF_ASSIGN_OR_RETURN(auto shape_and_shardings, sharding_->Disassemble(shape_)); + CHECK_EQ(handles.size(), shape_and_shardings.size()) + << " " << absl::StrJoin(handles, ",") << " " << shape_ << " " + << *sharding_ << " "; + + std::vector> result; + result.reserve(handles.size()); + for (int i = 0; i < handles.size(); ++i) { + result.push_back(tsl::RCReference(tsl::MakeRef( + client_, rpc_helper_, dtype_, std::move(shape_and_shardings[i].first), + std::move(shape_and_shardings[i].second), handles[i]))); + } + + return result; +} + +absl::StatusOr> Array::FullyReplicatedShard( + ArrayCopySemantics semantics) { + auto req = std::make_unique(); + req->set_array_handle(handle_.handle); + req->set_copy_semantics(ToArrayCopySemanticsProto(semantics)); + + TF_ASSIGN_OR_RETURN( + std::shared_ptr response, + rpc_helper_->FullyReplicatedShard(std::move(req)).Await()); + + ArrayHandle handle{.handle = response->array_handle()}; + + // We are making the assumption the Array returned by the server corresponds + // to the first device. Revisit this when IFRT supports: (1) an inexpensive + // way to derive a SingleDeviceSharding from a fully replicated Array's + // sharding and (2) A generalized `Reshard` API that allows the user to + // request an Array to be made out of a specific single shard. + std::unique_ptr single_device_sharding = + xla::ifrt::SingleDeviceSharding::Create(sharding_->devices()[0], + sharding_->memory_kind()); + + return tsl::RCReference( + tsl::MakeRef(client_, rpc_helper_, dtype_, shape_, + std::move(single_device_sharding), handle)); +} + +absl::StatusOr> Array::Reshard( + std::shared_ptr new_sharding, + ArrayCopySemantics semantics) { + auto req = std::make_unique(); + req->set_array_handle(handle_.handle); + TF_ASSIGN_OR_RETURN(*req->mutable_sharding(), ToShardingProto(*new_sharding)); + req->set_copy_semantics(ToArrayCopySemanticsProto(semantics)); + + TF_ASSIGN_OR_RETURN(std::shared_ptr response, + rpc_helper_->Reshard(std::move(req)).Await()); + ArrayHandle handle{.handle = response->array_handle()}; + + return tsl::RCReference(tsl::MakeRef( + client_, rpc_helper_, dtype_, shape_, std::move(new_sharding), handle)); +} + +Future Array::CopyToHostBuffer( + void* data, std::optional> byte_strides, + ArrayCopySemantics semantics) { + const auto mem_region = ArrayMemRegion::FromZerothElementPointer( + /*zeroth_element=*/data, dtype_, shape_, byte_strides); + if (!mem_region.ok()) { + return Future(mem_region.status()); + } + + auto req = std::make_unique(); + req->set_array_handle(handle_.handle); + if (byte_strides.has_value()) { + *req->mutable_byte_strides() = ToByteStridesProto(*byte_strides); + } + const uint64_t host_buffer_handle = + rpc_helper_->host_buffer_store()->NextHandle(); + req->set_host_buffer_handle(host_buffer_handle); + + auto promise = Future::CreatePromise(); + auto on_ready = [host_buffer_store = rpc_helper_->host_buffer_store(), + promise, host_buffer_handle, + mem_region = mem_region->mem_region()]( + absl::StatusOr> + resp) mutable { + if (!resp.ok()) { + promise.Set(resp.status()); + return; + } + + auto host_buffer = host_buffer_store->Lookup(host_buffer_handle); + host_buffer.OnReady( + [promise, mem_region, host_buffer_store, + host_buffer_handle](absl::StatusOr data) mutable { + absl::Cleanup cleanup = [&]() { + host_buffer_store->Delete(host_buffer_handle) + .OnReady([buffer_status = data.status()](absl::Status status) { + if (!status.ok()) { + LOG(WARNING) << "Failed to delete host buffer: " << status + << " (buffer status: " << buffer_status << ")"; + } + }); + }; + + if (!data.ok()) { + promise.Set(data.status()); + return; + } + if (data->size() != mem_region.size()) { + auto status = absl::InternalError( + absl::StrCat("During CopyToHostBuffer, size mismatch in " + "response from proxy: ", + mem_region.size(), " vs ", data->size())); + LOG(ERROR) << status; + promise.Set(status); + return; + } +#if defined(PLATFORM_GOOGLE) + data->CopyToArray(const_cast(mem_region.data())); +#else + std::memcpy(const_cast(mem_region.data()), + data->Flatten().data(), data->size()); +#endif + promise.Set(absl::OkStatus()); + }); + }; + rpc_helper_->CopyToHostBuffer(std::move(req)).OnReady(std::move(on_ready)); + return Future(std::move(promise)); +} + +xla::ifrt::Client* Array::client() const { return client_; } + +std::string Array::DebugString() const { + return absl::Substitute("proxy::Array, this=$0, handle=$1", this, + handle_.handle); +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/client/array.h b/third_party/xla/xla/python/ifrt_proxy/client/array.h new file mode 100644 index 00000000000000..c17f497b1ce81d --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/array.h @@ -0,0 +1,144 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_ARRAY_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_ARRAY_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt/tuple.h" +#include "xla/python/ifrt/value.h" +#include "xla/python/ifrt_proxy/client/rpc_helper.h" +#include "xla/python/ifrt_proxy/common/types.h" +#include "tsl/concurrency/ref_count.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// Implementation of the xla::ifrt::Array interface. +class Array final : public llvm::RTTIExtends { + public: + // `Array::MakeArrayFromHostBuffer()` implements + // `Client::MakeArrayFromHostBuffer()`. + // TODO(b/261226026): Implement logic directly in client.cc. + static absl::StatusOr> + MakeArrayFromHostBuffer(xla::ifrt::Client* client, + std::shared_ptr rpc_helper, + const void* data, DType dtype, Shape shape, + std::optional> byte_strides, + std::shared_ptr sharding, + xla::ifrt::Client::HostBufferSemantics semantics, + std::function on_done_with_host_buffer); + + // `Array::AssembleArrayFromSingleDeviceArrays()` implements + // `Client::AssembleArrayFromSingleDeviceArrays()`. + // TODO(b/261226026): Implement logic directly in client.cc. + static absl::StatusOr> + AssembleArrayFromSingleDeviceArrays( + xla::ifrt::Client* client, std::shared_ptr rpc_helper, + Shape shape, std::shared_ptr sharding, + absl::Span> arrays, + ArrayCopySemantics semantics); + + // Destructs the array associated with the given handle. The corresponding + // array becomes unusable afterwards. + static void Destruct(RpcHelper* rpc_helper, ArrayHandle handle); + + Array(xla::ifrt::Client* const client, std::shared_ptr rpc_helper, + DType dtype, Shape shape, std::shared_ptr sharding, + ArrayHandle handle) + : client_(client), + rpc_helper_(std::move(rpc_helper)), + dtype_(dtype), + shape_(std::move(shape)), + sharding_(std::move(sharding)), + handle_(handle) {} + + ~Array() override { Destruct(rpc_helper_.get(), handle_); } + + ArrayHandle handle() const { return handle_; } + + xla::ifrt::Client* client() const override; + Future GetReadyFuture() const override; + Future Delete() override; + bool IsDeleted() const override; + std::string DebugString() const override; + + DType dtype() const override { return dtype_; } + const Shape& shape() const override { return shape_; } + const Sharding& sharding() const override { return *sharding_; } + std::shared_ptr shared_ptr_sharding() const override { + return sharding_; + } + + absl::StatusOr>> + DisassembleIntoSingleDeviceArrays(ArrayCopySemantics semantics) override; + + absl::StatusOr> FullyReplicatedShard( + xla::ifrt::ArrayCopySemantics semantics) override; + + ABSL_MUST_USE_RESULT + Future CopyToHostBuffer( + void* data, std::optional> byte_strides, + ArrayCopySemantics semantics) override; + + absl::StatusOr> Reshard( + std::shared_ptr new_sharding, + ArrayCopySemantics semantics) override; + + static char ID; // NOLINT + + private: + template + friend tsl::RCReference tsl::MakeRef(Args&&... args); + + // Not owned. Used only for implementing `client()` interface method. Note + // that `client()` will still return the pointer even if the pointed-to memory + // is freed; this unfortunate behavior currently exists in all IFRT + // implementations. + xla::ifrt::Client* const client_; + + const std::shared_ptr rpc_helper_; + const DType dtype_; + const Shape shape_; + const std::shared_ptr sharding_; + const ArrayHandle handle_; +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_ARRAY_H_ diff --git a/third_party/xla/xla/python/ifrt_proxy/client/array_test.cc b/third_party/xla/xla/python/ifrt_proxy/client/array_test.cc new file mode 100644 index 00000000000000..0d80bba73cd9dc --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/array_test.cc @@ -0,0 +1,139 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/client/array.h" + +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/mock.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt_proxy/client/client_session.h" +#include "xla/python/ifrt_proxy/client/host_buffer.h" +#include "xla/python/ifrt_proxy/client/mock_client_session.h" +#include "xla/python/ifrt_proxy/client/mock_host_buffer.h" +#include "xla/python/ifrt_proxy/client/rpc_helper.h" +#include "xla/python/ifrt_proxy/client/version.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/types.h" +#include "xla/python/ifrt_proxy/common/types.pb.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/test.h" + +using ::testing::_; +using ::testing::Pointee; +using ::testing::Return; +using ::tsl::protobuf::TextFormat; +using ::tsl::testing::IsOk; + +#if defined(PLATFORM_GOOGLE) +using ::testing::EquivToProto; +using ::testing::proto::Partially; +#endif + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +IfrtProxyVersion Version() { + IfrtProxyVersion version; + version.set_protocol_version(kClientMinVersion); + return version; +} + +class ArrayTest : public ::testing::Test { + protected: + void SetUp() override { + session_ = std::make_shared(); + rpc_helper_ = std::make_shared(Version(), session_); + + host_buffer_store_ = std::make_shared(); + rpc_helper_->set_host_buffer_store(host_buffer_store_); + + // Default handler that ignores all uninteresting requests, but still + // invokes the callback in order to avoid hanging the caller forever. + EXPECT_CALL(*session_, Enqueue(_)) + .WillRepeatedly(Return(Future( + absl::InternalError("Request has no mock handlers")))); + } + + std::shared_ptr session_; + std::shared_ptr rpc_helper_; + std::shared_ptr host_buffer_store_; +}; + +// TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS +#if defined(PLATFORM_GOOGLE) +TEST_F(ArrayTest, Destruction) { + IfrtResponse response; + EXPECT_CALL( + *session_, + Enqueue(Pointee(Partially(EquivToProto(R"pb(destruct_array_request { + array_handle: 1234 + })pb"))))) + .WillOnce(MockClientSessionReturnResponse(response)); + + MockClient client; + tsl::MakeRef(&client, rpc_helper_, DType(DType::Kind::kBF16), + Shape({}), /*sharding=*/nullptr, + ArrayHandle{.handle = 1234}); +} +#endif + +// TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS +#if defined(PLATFORM_GOOGLE) +TEST_F(ArrayTest, FullyReplicatedShard) { + IfrtResponse response; + ASSERT_TRUE(TextFormat::ParseFromString( + R"pb(response_metadata {} + fully_replicated_shard_response { array_handle: 5678 })pb", + &response)); + + EXPECT_CALL(*session_, Enqueue(Pointee(Partially(EquivToProto( + R"pb(fully_replicated_shard_request { + array_handle: 1234 + })pb"))))) + .WillOnce(MockClientSessionReturnResponse(response)); + + MockClient client; + MockDevice mock_device; + + auto sharding = xla::ifrt::SingleDeviceSharding::Create( + &mock_device, xla::ifrt::MemoryKind()); + + auto array = tsl::MakeRef( + &client, rpc_helper_, DType(DType::Kind::kBF16), Shape({}), + std::move(sharding), ArrayHandle{.handle = 1234}); + + ASSERT_THAT(array->FullyReplicatedShard(ArrayCopySemantics::kAlwaysCopy), + IsOk()); +} +#endif + +} // namespace +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/client/client.cc b/third_party/xla/xla/python/ifrt_proxy/client/client.cc new file mode 100644 index 00000000000000..7153ad21420c1e --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/client.cc @@ -0,0 +1,209 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/client/client.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "xla/pjrt/pjrt_device_description.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt_proxy/client/array.h" +#include "xla/python/ifrt_proxy/client/device.h" +#include "xla/python/ifrt_proxy/client/memory.h" +#include "xla/python/ifrt_proxy/client/rpc_helper.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/types.h" +#include "xla/xla_data.pb.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +char Client::ID = 0; + +absl::StatusOr> Client::Create( + std::shared_ptr rpc_helper, InitResponse init_response) { + absl::flat_hash_set addressable_device_ids( + init_response.addressable_device_ids().begin(), + init_response.addressable_device_ids().end()); + + absl::flat_hash_map> memories; + for (const auto& m : init_response.memories()) { + auto memory = std::make_unique(m.id(), m.memory_space_kind(), + m.debug_string(), m.to_string()); + memories.insert({m.id(), std::move(memory)}); + } + + absl::flat_hash_map> devices; + std::vector device_ptrs; + std::vector addressable_device_ptrs; + + for (const auto& d : init_response.devices()) { + absl::flat_hash_map attributes; + for (const auto& [key, attr] : d.attributes()) { + TF_ASSIGN_OR_RETURN(xla::PjRtDeviceAttribute value, + FromVariantProto(attr)); + attributes.insert({key, std::move(value)}); + } + + DeviceDescription desc(d.id(), init_response.process_index(), + d.device_kind(), d.debug_string(), d.to_string(), + std::move(attributes)); + bool is_addressable = addressable_device_ids.contains(d.id()); + + auto device = + std::make_unique(std::move(desc), d.local_device_id(), + d.local_hardware_id(), is_addressable); + device_ptrs.push_back(device.get()); + if (is_addressable) { + addressable_device_ptrs.push_back(device.get()); + } + + if (d.has_default_memory_id()) { + const auto it = memories.find(d.default_memory_id()); + if (it == memories.end()) { + return absl::NotFoundError( + absl::StrCat("Memory ", d.default_memory_id(), " not found")); + } + device->default_memory_space_ = it->second.get(); + } + for (const int memory_id : d.memory_ids()) { + const auto it = memories.find(memory_id); + if (it == memories.end()) { + return absl::NotFoundError( + absl::StrCat("Memory ", memory_id, " not found")); + } + device->memory_spaces_.push_back(it->second.get()); + } + + devices.insert({d.id(), std::move(device)}); + } + + for (const auto& m : init_response.memories()) { + Memory* memory = memories.at(m.id()).get(); + for (const int device_id : m.device_ids()) { + const auto device = devices.find(device_id); + if (device == devices.end()) { + return absl::NotFoundError( + absl::StrCat("Device ", device_id, " not found")); + } + memory->devices_.push_back(device->second.get()); + } + } + + // Prefix the runtime_type string received from the server with "proxy/" so + // that the users (of this proxy client, such as JAX) do not erroneously + // conclude that they are talking with the backend runtime directly. + std::string runtime_type = + absl::StrCat("proxy/", init_response.runtime_type()); + + return absl::WrapUnique(new Client( + std::move(rpc_helper), init_response.session_id(), + init_response.platform_name(), init_response.platform_version(), + init_response.platform_id(), init_response.process_index(), runtime_type, + std::move(devices), std::move(device_ptrs), + std::move(addressable_device_ptrs), std::move(memories))); +} + +Client::Client(std::shared_ptr rpc_helper, uint64_t session_id, + std::string platform_name, std::string platform_version, + uint64_t platform_id, uint64_t process_index, + std::string runtime_type, + absl::flat_hash_map> devices, + std::vector device_ptrs, + std::vector addressable_device_ptrs, + absl::flat_hash_map> memories) + : rpc_helper_(rpc_helper), + platform_name_(std::move(platform_name)), + platform_version_(std::move(platform_version)), + platform_id_(platform_id), + process_index_(process_index), + runtime_type_(std::move(runtime_type)), + devices_(std::move(devices)), + device_ptrs_(device_ptrs), + addressable_device_ptrs_(std::move(addressable_device_ptrs)), + memories_(std::move(memories)), + default_compiler_(this, rpc_helper) {} + +Client::~Client() { rpc_helper_->Disconnect(); } + +absl::StatusOr Client::LookupDevice(int device_id) const { + auto it = devices_.find(device_id); + if (it == devices_.end()) { + return absl::NotFoundError( + absl::StrCat("Device ", device_id, " not found.")); + } + return it->second.get(); +} + +absl::StatusOr> +Client::MakeArrayFromHostBuffer( + const void* data, DType dtype, Shape shape, + std::optional> byte_strides, + std::shared_ptr sharding, + xla::ifrt::Client::HostBufferSemantics semantics, + std::function on_done_with_host_buffer) { + return Array::MakeArrayFromHostBuffer( + this, rpc_helper_, data, dtype, std::move(shape), std::move(byte_strides), + std::move(sharding), semantics, std::move(on_done_with_host_buffer)); +} + +absl::StatusOr> +Client::AssembleArrayFromSingleDeviceArrays( + Shape shape, std::shared_ptr sharding, + absl::Span> arrays, + ArrayCopySemantics semantics) { + return Array::AssembleArrayFromSingleDeviceArrays( + this, rpc_helper_, std::move(shape), sharding, arrays, semantics); +} + +absl::StatusOr Client::GetDefaultDeviceAssignment( + int num_replicas, int num_partitions) const { + auto req = std::make_unique(); + req->set_num_replicas(num_replicas); + req->set_num_partitions(num_partitions); + + auto future = rpc_helper_->GetDefaultDeviceAssignment(std::move(req)); + TF_ASSIGN_OR_RETURN(auto response, future.Await()); + + TF_ASSIGN_OR_RETURN( + auto assignment_to_return, + DeviceAssignment::Deserialize(response->device_assignment())); + + return *std::move(assignment_to_return); +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/client/client.h b/third_party/xla/xla/python/ifrt_proxy/client/client.h new file mode 100644 index 00000000000000..8cca1cbf0258af --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/client.h @@ -0,0 +1,157 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_CLIENT_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_CLIENT_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt/tuple.h" +#include "xla/python/ifrt/value.h" +#include "xla/python/ifrt_proxy/client/compiler.h" +#include "xla/python/ifrt_proxy/client/device.h" +#include "xla/python/ifrt_proxy/client/memory.h" +#include "xla/python/ifrt_proxy/client/rpc_helper.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "tsl/concurrency/ref_count.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// Implementation of the xla::ifrt::Client interface. +class Client final : public llvm::RTTIExtends { + public: + static absl::StatusOr> Create( + std::shared_ptr rpc_helper, InitResponse init_response); + + ~Client() override; + + absl::StatusOr> MakeArrayFromHostBuffer( + const void* data, DType dtype, Shape shape, + std::optional> byte_strides, + std::shared_ptr sharding, HostBufferSemantics semantics, + std::function on_done_with_host_buffer) override; + + absl::StatusOr> + AssembleArrayFromSingleDeviceArrays( + Shape shape, std::shared_ptr sharding, + absl::Span> arrays, + ArrayCopySemantics semantics) override; + + absl::StatusOr> MakeTuple( + absl::Span> values) override { + return absl::UnimplementedError( + "MakeTuple is not supported for the IFRT proxy client."); + } + + absl::string_view runtime_type() const override { return runtime_type_; } + absl::string_view platform_name() const override { return platform_name_; } + absl::string_view platform_version() const override { + return platform_version_; + } + PlatformId platform_id() const override { return platform_id_; } + absl::flat_hash_map attributes() + const override { + // TODO(b/309059940): Forward the backend attributes to the client. + return {}; + } + int device_count() const override { return devices().size(); } + int addressable_device_count() const override { + return addressable_devices().size(); + } + absl::Span devices() const override { + return device_ptrs_; + } + absl::Span addressable_devices() const override { + return addressable_device_ptrs_; + } + int process_index() const override { return process_index_; } + absl::StatusOr GetDefaultDeviceAssignment( + int num_replicas, int num_partitions) const override; + absl::StatusOr LookupDevice(int device_id) const override; + absl::StatusOr LookupAddressableDevice( + int local_hardware_id) const override { + return absl::UnimplementedError( + "LookupAddressableDevice is not supported for the IFRT proxy client."); + } + xla::ifrt::Compiler* GetDefaultCompiler() override { + return &default_compiler_; + } + absl::StatusOr> + GetTopologyForDevices( + absl::Span devices) const override { + return absl::UnimplementedError( + "GetTopologyForDevices is not supported for the IFRT proxy client."); + } + + // For llvm::RTTIExtends. + static char ID; // NOLINT + + private: + Client(std::shared_ptr rpc_helper, uint64_t session_id, + std::string platform_name, std::string platform_version, + uint64_t platform_id, uint64_t process_index, std::string runtime_type, + absl::flat_hash_map> devices, + std::vector device_ptrs, + std::vector addressable_device_ptrs, + absl::flat_hash_map> memories); + + // rpc_helper_ will be referenced by various IFRT objects whose lifetime is + // managed by the layer above the IFRT interface, so shared_ptr is + // appropriate. + const std::shared_ptr rpc_helper_; + + const std::string platform_name_; + const std::string platform_version_; + const uint64_t platform_id_; + const uint64_t process_index_; + const std::string runtime_type_; + + const absl::flat_hash_map> devices_; + const std::vector device_ptrs_; + const std::vector addressable_device_ptrs_; + + const absl::flat_hash_map> memories_; + + Compiler default_compiler_; +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_CLIENT_H_ diff --git a/third_party/xla/xla/python/ifrt_proxy/client/client_session.h b/third_party/xla/xla/python/ifrt_proxy/client/client_session.h new file mode 100644 index 00000000000000..9bd795825e50cf --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/client_session.h @@ -0,0 +1,59 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_CLIENT_SESSION_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_CLIENT_SESSION_H_ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// Base class that defines the interface between IFRT service protocol and the +// stream implementation that is responsible for sending requests and receiving +// responses. +// +// `ClientSession` implementation must be thread-safe. +class ClientSession { + public: + // `Response` represents either an `IfrtResponse` value, or an `absl::Status` + // value corresponding to termination of the session stream. Value will never + // be a nullptr with OK status. + using Response = absl::StatusOr>; + + virtual ~ClientSession() = default; + + // Enqueues `request` to be sent via the stream; enqueued requests are sent in + // FIFO order. The caller must ensure that `request->op_id()` is unique + // throughout the stream's lifetime. The returned future becomes ready when a + // response for the given op id becomes ready. + virtual Future Enqueue(std::unique_ptr request) = 0; + + // Terminates the `ClientSession` if it has not already been terminated. + virtual void Finish(const absl::Status& s) {} +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_CLIENT_SESSION_H_ diff --git a/third_party/xla/xla/python/ifrt_proxy/client/client_test.cc b/third_party/xla/xla/python/ifrt_proxy/client/client_test.cc new file mode 100644 index 00000000000000..c6569a3a28f049 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/client_test.cc @@ -0,0 +1,218 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/client/client.h" + +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "xla/pjrt/pjrt_device_description.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt_proxy/client/client_session.h" +#include "xla/python/ifrt_proxy/client/host_buffer.h" +#include "xla/python/ifrt_proxy/client/mock_client_session.h" +#include "xla/python/ifrt_proxy/client/mock_host_buffer.h" +#include "xla/python/ifrt_proxy/client/rpc_helper.h" +#include "xla/python/ifrt_proxy/client/version.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/service/computation_placer.h" +#include "tsl/platform/platform.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +using ::testing::ElementsAre; +using ::testing::Not; +using ::testing::Pair; +using ::testing::Pointee; +using ::testing::Return; +using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; +using ::tsl::testing::IsOk; +using ::tsl::testing::IsOkAndHolds; + +#if defined(PLATFORM_GOOGLE) +using ::testing::EquivToProto; +using ::testing::proto::Partially; +#endif + +IfrtProxyVersion Version() { + IfrtProxyVersion version; + version.set_protocol_version(kClientMinVersion); + return version; +} + +class ClientTest : public ::testing::Test { + protected: + void SetUp() override { + session_ = std::make_shared(); + rpc_helper_ = std::make_shared(Version(), session_); + + host_buffer_store_ = std::make_shared(); + rpc_helper_->set_host_buffer_store(host_buffer_store_); + + InitResponse response; + ASSERT_TRUE(tsl::protobuf::TextFormat::ParseFromString( + R"pb( + platform_name: "ifrt-service" + platform_version: "n/a" + platform_id: 42 + process_index: 1 + runtime_type: "ifrt-service" + devices { + id: 0 + local_hardware_id: 1234 + device_kind: "mock" + default_memory_id: 0 + memory_ids: [ 0 ] + attributes { + key: "name" + value { string_value: "device0" } + } + } + devices { + id: 1 + local_hardware_id: 1234 + device_kind: "mock" + default_memory_id: 1 + memory_ids: [ 1 ] + attributes { + key: "name" + value { string_value: "device1" } + } + } + addressable_device_ids: 1 + memories { + id: 0 + memory_space_kind: "mock" + device_ids: [ 0 ] + } + memories { + id: 1 + memory_space_kind: "mock" + device_ids: [ 1 ] + } + )pb", + &response)); + TF_ASSERT_OK_AND_ASSIGN(client_, Client::Create(rpc_helper_, response)); + } + + std::shared_ptr session_; + std::shared_ptr rpc_helper_; + std::shared_ptr host_buffer_store_; + std::unique_ptr client_; +}; + +TEST_F(ClientTest, Init) { + if (tsl::testing::kIsOpenSource) { + // TODO(b/324824974): Fix this. + GTEST_SKIP() << "Non-rootcaused bug: xla::PjRtDeviceAttribute does not " + << "work properly with IFRT proxy in open-source."; + } + EXPECT_EQ(client_->platform_name(), "ifrt-service"); + EXPECT_EQ(client_->platform_version(), "n/a"); + EXPECT_EQ(client_->platform_id(), 42); + EXPECT_EQ(client_->process_index(), 1); + EXPECT_EQ(client_->runtime_type(), "proxy/ifrt-service"); + + ASSERT_EQ(client_->device_count(), 2); + ASSERT_EQ(client_->addressable_device_count(), 1); + + TF_ASSERT_OK_AND_ASSIGN(auto* const device0, client_->LookupDevice(0)); + EXPECT_EQ(device0->id(), 0); + EXPECT_EQ(device0->local_hardware_id(), 1234); + EXPECT_EQ(device0->device_kind(), "mock"); + EXPECT_THAT(device0->Attributes(), + ElementsAre(Pair("name", xla::PjRtDeviceAttribute("device0")))); + + ASSERT_THAT(device0->memory_spaces(), SizeIs(1)); + auto* const memory0 = device0->memory_spaces()[0]; + EXPECT_EQ(memory0->id(), 0); + EXPECT_EQ(memory0->memory_space_kind(), "mock"); + EXPECT_THAT(memory0->devices(), UnorderedElementsAre(device0)); + EXPECT_THAT(device0->default_memory_space(), IsOkAndHolds(memory0)); + + TF_ASSERT_OK_AND_ASSIGN(auto* const device1, client_->LookupDevice(1)); + EXPECT_EQ(device1->id(), 1); + EXPECT_EQ(device1->local_hardware_id(), 1234); + EXPECT_EQ(device1->device_kind(), "mock"); + EXPECT_THAT(device1->Attributes(), + ElementsAre(Pair("name", xla::PjRtDeviceAttribute("device1")))); + + ASSERT_THAT(device1->memory_spaces(), SizeIs(1)); + auto* const memory1 = device1->memory_spaces()[0]; + EXPECT_EQ(memory1->id(), 1); + EXPECT_EQ(memory1->memory_space_kind(), "mock"); + EXPECT_THAT(memory1->devices(), UnorderedElementsAre(device1)); + EXPECT_THAT(device1->default_memory_space(), IsOkAndHolds(memory1)); + + EXPECT_THAT(client_->addressable_devices(), ElementsAre(device1)); +} + +// TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS +#if defined(PLATFORM_GOOGLE) +TEST_F(ClientTest, GetDefaultDeviceAssignmentSuccess) { + IfrtResponse response; + xla::DeviceAssignment assignment(1, 3); + ASSERT_THAT(assignment.Serialize( + response.mutable_get_default_device_assignment_response() + ->mutable_device_assignment()), + IsOk()); + + EXPECT_CALL(*session_, Enqueue(Pointee(Partially(EquivToProto( + R"pb( + get_default_device_assignment_request { + num_replicas: 1 + num_partitions: 3 + } + )pb"))))) + .WillOnce(MockClientSessionReturnResponse(response)); + + TF_ASSERT_OK_AND_ASSIGN(auto assignment_got, + client_->GetDefaultDeviceAssignment(1, 3)); + EXPECT_EQ(assignment_got.replica_count(), 1); + EXPECT_EQ(assignment_got.computation_count(), 3); +} +#endif + +// TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS +#if defined(PLATFORM_GOOGLE) +TEST_F(ClientTest, GetDefaultDeviceAssignmentFailure) { + EXPECT_CALL(*session_, Enqueue(Pointee(Partially(EquivToProto( + R"pb( + get_default_device_assignment_request { + num_replicas: 1 + num_partitions: 3 + } + )pb"))))) + .WillOnce(Return(Future( + absl::InternalError("injected from test")))); + + EXPECT_THAT(client_->GetDefaultDeviceAssignment(1, 3), Not(IsOk())); +} +#endif + +} // namespace +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/client/compiler.cc b/third_party/xla/xla/python/ifrt_proxy/client/compiler.cc new file mode 100644 index 00000000000000..f0934734c0af37 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/compiler.cc @@ -0,0 +1,156 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/client/compiler.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/Casting.h" +#include "xla/pjrt/host_callback.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt/serdes.h" +#include "xla/python/ifrt_proxy/client/executable.h" +#include "xla/python/ifrt_proxy/client/rpc_helper.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/server/host_callback.h" +#include "xla/python/pjrt_ifrt/pjrt_host_callback.h" +#include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/status_to_from_proto.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +Compiler::Compiler(xla::ifrt::Client* client, + std::shared_ptr rpc_helper) + : client_(client), rpc_helper_(std::move(rpc_helper)) {} + +absl::StatusOr> Compiler::Compile( + std::unique_ptr program, + std::unique_ptr options) { + auto request = std::make_unique(); + TF_ASSIGN_OR_RETURN(*request->mutable_program(), Serialize(*program)); + + // Extract host callbacks from the XLA compile options. `XlaCompileOptions`'s + // SerDes fails when it contains host callbacks, so the following + // implementation handles host callback serialization out of band until we can + // natively support IFRT host callback on IFRT proxy. + std::vector> + loaded_host_callbacks; + if (auto* xla_options = + llvm::dyn_cast(options.get())) { + for (const auto& loaded_host_callback : + xla_options->loaded_host_callbacks) { + auto* pjrt_host_callback = + llvm::dyn_cast( + loaded_host_callback.get()); + if (pjrt_host_callback == nullptr) { + return absl::UnimplementedError("Unsupported host callback type"); + } + + const xla::HostCallback& xla_host_callback = + pjrt_host_callback->host_callback(); + + // The proxy server runs `RemoteLoadedHostCallback` that delegates actual + // host callback execution to the proxy client. + auto remote_loaded_host_callback = tsl::MakeRef( + client_, xla_host_callback.operands, xla_host_callback.results, + /*queue=*/nullptr); + TF_ASSIGN_OR_RETURN(*request->add_host_callbacks(), + remote_loaded_host_callback->Serialize()); + } + + loaded_host_callbacks.swap(xla_options->loaded_host_callbacks); + } + + TF_ASSIGN_OR_RETURN(*request->mutable_compile_options(), Serialize(*options)); + + // TODO(b/266635130): Avoid blocking the caller. + TF_ASSIGN_OR_RETURN(std::shared_ptr response, + rpc_helper_->Compile(std::move(request)).Await()); + + std::vector + addressable_device_logical_device_ids; + addressable_device_logical_device_ids.reserve( + response->addressable_device_logical_ids_size()); + for (const auto& logical_device_id : + response->addressable_device_logical_ids()) { + addressable_device_logical_device_ids.push_back({ + .replica = logical_device_id.replica(), + .partition = logical_device_id.partition(), + }); + } + + std::vector addressable_devices; + addressable_devices.reserve(response->addressable_device_ids_size()); + for (const int32_t device_id : response->addressable_device_ids()) { + TF_ASSIGN_OR_RETURN(xla::ifrt::Device* const device, + client_->LookupDevice(device_id)); + addressable_devices.push_back(device); + } + + absl::StatusOr> fingerprint; + switch (response->fingerprint_case()) { + case CompileResponse::kFingerprintValue: + fingerprint = response->fingerprint_value(); + break; + case CompileResponse::kFingerprintError: + fingerprint = tsl::StatusFromProto(response->fingerprint_error()); + break; + default: + fingerprint = std::nullopt; + break; + } + + std::vector loaded_host_callback_handles( + response->loaded_host_callback_handles().begin(), + response->loaded_host_callback_handles().end()); + + return std::make_unique( + client_, rpc_helper_, response->loaded_executable_handle(), + response->name(), response->num_devices(), + std::move(addressable_device_logical_device_ids), + std::move(addressable_devices), std::move(fingerprint), + std::move(loaded_host_callbacks), + std::move(loaded_host_callback_handles)); +} + +absl::StatusOr> +Compiler::DeserializeLoadedExecutable( + absl::string_view serialized, + std::unique_ptr options) { + return absl::UnimplementedError( + "IFRT service compiler does not support `DeserializeLoadedExecutable` " + "since the underlying serialization format is not stable"); +} + +char Compiler::ID = 0; + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/client/compiler.h b/third_party/xla/xla/python/ifrt_proxy/client/compiler.h new file mode 100644 index 00000000000000..6bfc814766d111 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/compiler.h @@ -0,0 +1,57 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_COMPILER_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_COMPILER_H_ + +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt_proxy/client/rpc_helper.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +class Compiler final : public llvm::RTTIExtends { + public: + Compiler(xla::ifrt::Client* client, std::shared_ptr rpc_helper); + + absl::StatusOr> Compile( + std::unique_ptr program, + std::unique_ptr options) override; + + absl::StatusOr> + DeserializeLoadedExecutable( + absl::string_view serialized, + std::unique_ptr options) + override; + + static char ID; // NOLINT + + private: + xla::ifrt::Client* client_; + std::shared_ptr rpc_helper_; +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_COMPILER_H_ diff --git a/third_party/xla/xla/python/ifrt_proxy/client/compiler_test.cc b/third_party/xla/xla/python/ifrt_proxy/client/compiler_test.cc new file mode 100644 index 00000000000000..3692e7893b81e1 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/compiler_test.cc @@ -0,0 +1,206 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/client/compiler.h" + +#include +#include +#include + +#include +#include +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/mock.h" +#include "xla/python/ifrt/serdes.h" +#include "xla/python/ifrt_proxy/client/client_session.h" +#include "xla/python/ifrt_proxy/client/host_buffer.h" +#include "xla/python/ifrt_proxy/client/mock_client_session.h" +#include "xla/python/ifrt_proxy/client/mock_host_buffer.h" +#include "xla/python/ifrt_proxy/client/rpc_helper.h" +#include "xla/python/ifrt_proxy/client/version.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +using ::testing::_; +using ::testing::ElementsAre; +using ::testing::FieldsAre; +using ::testing::Invoke; +using ::testing::Optional; +using ::testing::Pointee; +using ::testing::Return; +using ::tsl::protobuf::TextFormat; +using ::tsl::testing::IsOkAndHolds; + +#if defined(PLATFORM_GOOGLE) +using ::testing::EquivToProto; +using ::testing::proto::Partially; +#endif + +struct TestProgram : llvm::RTTIExtends { + static char ID; // NOLINT +}; + +[[maybe_unused]] char TestProgram::ID = 0; // NOLINT + +class TestProgramSerDes : public llvm::RTTIExtends { + public: + absl::string_view type_name() const override { + return "xla::ifrt::proxy::TestProgram"; + } + + absl::StatusOr Serialize(Serializable& serializable) override { + CHECK(llvm::isa(serializable)); + return ""; + } + + absl::StatusOr> Deserialize( + const std::string& serialized, + std::unique_ptr options) override { + return std::make_unique(); + } + + static char ID; // NOLINT +}; + +[[maybe_unused]] char TestProgramSerDes::ID = 0; // NOLINT + +struct TestCompileOptions + : llvm::RTTIExtends { + static char ID; // NOLINT +}; + +[[maybe_unused]] char TestCompileOptions::ID = 0; // NOLINT + +class TestCompileOptionsSerDes + : public llvm::RTTIExtends { + public: + absl::string_view type_name() const override { + return "xla::ifrt::proxy::TestCompileOptions"; + } + + absl::StatusOr Serialize(Serializable& serializable) override { + CHECK(llvm::isa(serializable)); + return ""; + } + + absl::StatusOr> Deserialize( + const std::string& serialized, + std::unique_ptr options) override { + return std::make_unique(); + } + + static char ID; // NOLINT +}; + +[[maybe_unused]] char TestCompileOptionsSerDes::ID = 0; // NOLINT + +IfrtProxyVersion Version() { + IfrtProxyVersion version; + version.set_protocol_version(kClientMinVersion); + return version; +} + +class CompilerTest : public testing::Test { + protected: + static void SetUpTestSuite() { + RegisterSerDes(std::make_unique()); + RegisterSerDes( + std::make_unique()); + } + + void SetUp() override { + session_ = std::make_shared(); + rpc_helper_ = std::make_shared(Version(), session_); + + host_buffer_store_ = std::make_shared(); + rpc_helper_->set_host_buffer_store(host_buffer_store_); + + // Default handler that ignores all uninteresting requests but still + // invokes the callback in order to avoid hanging the caller forever. + EXPECT_CALL(*session_, Enqueue(_)) + .WillRepeatedly(Return(Future( + absl::InternalError("Request has no mock handlers")))); + } + + std::shared_ptr session_; + std::shared_ptr rpc_helper_; + std::shared_ptr host_buffer_store_; +}; + +// TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS +#if defined(PLATFORM_GOOGLE) +TEST_F(CompilerTest, Compile) { + std::vector devices(2); + + MockClient client; + ON_CALL(client, LookupDevice(_)).WillByDefault(Invoke([&](int id) { + return &devices[id]; + })); + + Compiler compiler(&client, rpc_helper_); + + IfrtResponse response; + ASSERT_TRUE(TextFormat::ParseFromString( + R"pb(compile_response { + loaded_executable_handle: 1234 + name: "foo-executable" + num_devices: 2 + addressable_device_logical_ids { replica: 0 partition: 0 } + addressable_device_logical_ids { replica: 0 partition: 1 } + addressable_device_ids: [ 0, 1 ] + fingerprint_value: "fingerprint" + })pb", + &response)); + EXPECT_CALL(*session_, + Enqueue(Pointee(Partially(EquivToProto( + R"pb(compile_request { + program { type_name: "xla::ifrt::proxy::TestProgram" } + })pb"))))) + .WillOnce(MockClientSessionReturnResponse(response)); + + TF_ASSERT_OK_AND_ASSIGN( + auto executable, + compiler.Compile(std::make_unique(), + std::make_unique())); + + EXPECT_EQ(executable->name(), "foo-executable"); + EXPECT_EQ(executable->num_devices(), 2); + EXPECT_THAT(executable->addressable_device_logical_ids(), + ElementsAre(FieldsAre(0, 0), FieldsAre(0, 1))); + EXPECT_THAT(executable->addressable_devices(), + ElementsAre(&devices[0], &devices[1])); + EXPECT_THAT(executable->Fingerprint(), + IsOkAndHolds(Optional(std::string("fingerprint")))); +} +#endif + +} // namespace +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/client/device.cc b/third_party/xla/xla/python/ifrt_proxy/client/device.cc new file mode 100644 index 00000000000000..f43d9aec101f7a --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/device.cc @@ -0,0 +1,61 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/client/device.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/literal.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_future.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +std::unique_ptr Device::CreateAsyncTrackingEvent( + absl::string_view description) const { + return nullptr; +} + +absl::Status Device::TransferToInfeed(const xla::LiteralSlice& literal) { + return absl::UnimplementedError("Device does not support TransferToInfeed"); +} + +absl::Status Device::TransferFromOutfeed(xla::MutableBorrowingLiteral literal) { + return absl::UnimplementedError( + "Device does not support TransferFromOutfeed"); +} + +absl::Span Device::memory_spaces() const { + return memory_spaces_; +} + +absl::StatusOr Device::default_memory_space() const { + if (default_memory_space_ == nullptr) { + return absl::UnimplementedError( + "Device does not support default_memory_space"); + } + return default_memory_space_; +} + +char Device::ID = 0; // NOLINT + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/client/device.h b/third_party/xla/xla/python/ifrt_proxy/client/device.h new file mode 100644 index 00000000000000..6922488d7fa006 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/device.h @@ -0,0 +1,137 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_DEVICE_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_DEVICE_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/literal.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_device_description.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/memory.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +class DeviceDescription final : public xla::PjRtDeviceDescription { + public: + DeviceDescription( + int id, int process_index, std::string device_kind, + std::string debug_string, std::string to_string, + absl::flat_hash_map attributes) + : id_(id), + process_index_(process_index), + device_kind_(device_kind), + debug_string_(std::move(debug_string)), + to_string_(std::move(to_string)), + attributes_(std::move(attributes)) {} + + int id() const override { return id_; } + + int process_index() const override { return process_index_; } + + absl::string_view device_kind() const override { return device_kind_; } + + absl::string_view DebugString() const override { return debug_string_; } + + absl::string_view ToString() const override { return to_string_; } + + const absl::flat_hash_map& Attributes() + const override { + return attributes_; + } + + private: + int id_; + int process_index_; + std::string device_kind_; + std::string debug_string_; + std::string to_string_; + absl::flat_hash_map attributes_; +}; + +class Device final : public xla::ifrt::Device { + public: + Device(DeviceDescription description, int local_device_id, + int local_hardware_id, bool is_addressable) + : description_(std::move(description)), + local_device_id_(local_device_id), + local_hardware_id_(local_hardware_id), + is_addressable_(is_addressable) {} + + xla::PjRtClient* client() const override { return nullptr; } + + bool IsAddressable() const override { return is_addressable_; } + + const xla::PjRtDeviceDescription& description() const override { + return description_; + } + + int local_hardware_id() const override { + return local_hardware_id_typed().value(); + } + + PjRtLocalDeviceId local_device_id() const override { + return PjRtLocalDeviceId(local_device_id_); + } + + PjRtLocalHardwareId local_hardware_id_typed() const override { + return PjRtLocalHardwareId(local_hardware_id_); + } + + std::unique_ptr CreateAsyncTrackingEvent( + absl::string_view description) const override; + + absl::Status TransferToInfeed(const xla::LiteralSlice& literal) override; + + absl::Status TransferFromOutfeed( + xla::MutableBorrowingLiteral literal) override; + + absl::Span memory_spaces() const override; + + absl::StatusOr default_memory_space() const override; + + static char ID; // NOLINT + + private: + friend class Client; // For `memory_spaces_` initialization. + + const DeviceDescription description_; + const int local_device_id_; + const int local_hardware_id_; + const bool is_addressable_; + + std::vector memory_spaces_; + xla::ifrt::Memory* default_memory_space_ = nullptr; +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_DEVICE_H_ diff --git a/third_party/xla/xla/python/ifrt_proxy/client/executable.cc b/third_party/xla/xla/python/ifrt_proxy/client/executable.cc new file mode 100644 index 00000000000000..07d57d00d5fabd --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/executable.cc @@ -0,0 +1,542 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/client/executable.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/cleanup/cleanup.h" +#include "absl/container/flat_hash_map.h" +#include "absl/functional/bind_front.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/layout.h" +#include "xla/pjrt/host_callback.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt_proxy/client/array.h" +#include "xla/python/ifrt_proxy/client/host_buffer.h" +#include "xla/python/ifrt_proxy/client/rpc_helper.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/types.h" +#include "xla/python/pjrt_ifrt/pjrt_host_callback.h" +#include "xla/shape_util.h" +#include "xla/xla_data.pb.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/env.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status_to_from_proto.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +namespace { + +// Locally executes the loaded host callback with given operand buffer from the +// IFRT proxy server and returns a result buffer to be sent back. +absl::StatusOr ExecuteLoadedHostCallback( + xla::ifrt::LoadedHostCallback* loaded_host_callback, + absl::Cord operand_buffer) { +#if defined(PLATFORM_GOOGLE) + auto* pjrt_host_callback = + llvm::dyn_cast( + loaded_host_callback); + if (pjrt_host_callback == nullptr) { + return absl::UnimplementedError( + "Non-PjRt host callbacks cannot be executed"); + } + const xla::HostCallback& xla_host_callback = + pjrt_host_callback->host_callback(); + + // The following allocates both operands and results using `aligned_alloc` in + // order to (loosely) emulate the XLA implementation where host callbacks are + // often called with aligned operand/result buffers. While this may not be + // strictly necessary for some callbacks, this reduces the chances of proxied + // callbacks behaving differently on a best-effort basis. + constexpr int kAlignment = 32; + + struct Deleter { + void operator()(void* p) { free(p); } + }; + + std::vector> operands; + operands.reserve(xla_host_callback.operands.size()); + std::vector operand_ptrs; + operand_ptrs.reserve(xla_host_callback.operands.size()); + + absl::CordReader reader(operand_buffer); + for (const auto& spec : xla_host_callback.operands) { + const int64_t size = xla::ShapeUtil::ByteSizeOf(spec.shape); + void* p; + CHECK_EQ(posix_memalign(&p, kAlignment, size), 0); + std::unique_ptr buffer(reinterpret_cast(p)); + + if (reader.Available() < size) { + return absl::InternalError(absl::StrCat( + "Buffer overflow while reading host callback execution operands; ", + "range: [", reader.Position(), ", ", reader.Position() + size, "), ", + "buffer size: ", operand_buffer.size())); + } + reader.ReadN(size, buffer.get()); + + operand_ptrs.push_back(buffer.get()); + operands.push_back(std::move(buffer)); + } + if (reader.Available() > 0) { + return absl::InternalError(absl::StrCat( + "Host callback execution did not consume the entire operand buffer; " + "size: ", + operand_buffer.size(), "; consumed: ", reader.Available())); + } + + absl::Cord result_buffer; + std::vector result_ptrs; + result_ptrs.reserve(xla_host_callback.results.size()); + + for (const auto& spec : xla_host_callback.results) { + const int64_t size = xla::ShapeUtil::ByteSizeOf(spec.shape); + void* data; + CHECK_EQ(posix_memalign(&data, kAlignment, size), 0); + + result_ptrs.push_back(data); + result_buffer.AppendExternalMemory( + absl::string_view(reinterpret_cast(data), size), data, &free); + } + + TF_RETURN_IF_ERROR( + xla_host_callback.callback(result_ptrs.data(), operand_ptrs.data())); + + return result_buffer; +#else + return absl::UnimplementedError("ExecuteLoadedHostCallback is unsupported."); +#endif +} + +// Same as `ExecuteLoadedHostCallback`, except that it uses host buffer store to +// retrieve operands and store results. +absl::StatusOr PrepareAndExecuteLoadedHostCallback( + ClientHostBufferStore* host_buffer_store, + xla::ifrt::LoadedHostCallback* loaded_host_callback, + uint64_t operand_handle) { + TF_ASSIGN_OR_RETURN(absl::Cord operands, + host_buffer_store->Lookup(operand_handle).Await()); + absl::Cleanup cleanup = [&]() { + host_buffer_store->Delete(operand_handle).OnReady([](absl::Status status) { + if (!status.ok()) { + LOG(ERROR) << "Failed to delete host callback operands: " << status; + } + }); + }; + + TF_ASSIGN_OR_RETURN( + absl::Cord results, + ExecuteLoadedHostCallback(loaded_host_callback, std::move(operands))); + + const uint64_t result_handle = host_buffer_store->NextHandle(); + TF_RETURN_IF_ERROR(host_buffer_store->Store(result_handle, results).Await()); + return result_handle; +} + +} // namespace + +LoadedExecutable::LoadedExecutable( + xla::ifrt::Client* client, std::shared_ptr rpc_helper, + uint64_t handle, std::string name, int num_devices, + std::vector + addressable_device_logical_device_ids, + std::vector addressable_devices, + absl::StatusOr> fingerprint, + std::vector> + loaded_host_callbacks, + std::vector loaded_host_callback_handles) + : client_(client), + rpc_helper_(std::move(rpc_helper)), + handle_(handle), + name_(std::move(name)), + num_devices_(num_devices), + addressable_device_logical_device_ids_( + std::move(addressable_device_logical_device_ids)), + addressable_devices_(std::move(addressable_devices)), + fingerprint_(std::move(fingerprint)) { + // Start host callback pollers. + CHECK_EQ(loaded_host_callbacks.size(), loaded_host_callback_handles.size()); + if (!loaded_host_callbacks.empty()) { + for (int i = 0; i < loaded_host_callbacks.size(); ++i) { + PollLoadedHostCallback(loaded_host_callback_handles[i], + loaded_host_callbacks[i]); + } + } + + // Asynchronously fetch shardings. Since users of `LoadedExecutable` typically + // require sharding information to invoke the executable, it is beneficial to + // eagerly schedule this fetch since, in some implementations, it may take a + // long time for sharding information to be available. + + auto promise = + Future>>::CreatePromise(); + metadata_future_ = Future>>(promise); + + auto req = std::make_unique(); + req->set_loaded_executable_handle(handle_); + + auto on_done = + [promise]( + absl::StatusOr> + response) mutable { + if (!response.ok()) { + LOG(ERROR) << "LoadedExecutableMetadata: Got " << response.status(); + promise.Set(response.status()); + return; + } + + auto info = std::make_shared(); + + if (response.value()->has_parameter_shardings()) { + const auto& p = response.value()->parameter_shardings().shardings(); + info->parameter_shardings.emplace(p.begin(), p.end()); + } + if (response.value()->has_output_shardings()) { + const auto& o = response.value()->output_shardings().shardings(); + info->output_shardings.emplace(o.begin(), o.end()); + } + + auto parse_layouts = + [](const LoadedExecutableMetadataResponse::LayoutList& list) { + std::vector layouts; + layouts.reserve(list.layouts_size()); + for (const auto& layout : list.layouts()) { + layouts.push_back(xla::Layout::CreateFromProto(layout)); + } + return layouts; + }; + + if (response.value()->has_parameter_layouts_list()) { + info->parameter_layouts = + parse_layouts(response.value()->parameter_layouts_list()); + } else if (response.value()->has_parameter_layouts_error()) { + info->parameter_layouts = + tsl::StatusFromProto(response.value()->parameter_layouts_error()); + } else { + info->parameter_layouts = absl::UnimplementedError( + "IFRT Proxy server did not return parameter layouts"); + } + if (response.value()->has_output_layouts_list()) { + info->output_layouts = + parse_layouts(response.value()->output_layouts_list()); + } else if (response.value()->has_output_layouts_error()) { + info->output_layouts = + tsl::StatusFromProto(response.value()->output_layouts_error()); + } else { + info->output_layouts = absl::UnimplementedError( + "IFRT Proxy server did not return output layouts"); + } + + if (const absl::Status s = tsl::StatusFromProto( + response.value()->output_memory_kinds().status()); + !s.ok()) { + info->output_memory_kinds = s; + } else { + std::vector> output_memory_kinds; + for (const auto& list : + response.value()->output_memory_kinds().memory_kind_lists()) { + std::vector kinds; + kinds.reserve(list.memory_kinds_size()); + for (const absl::string_view kind : list.memory_kinds()) { + const auto it = + info->memory_kinds.insert(std::string(kind)).first; + kinds.push_back(*it); + } + output_memory_kinds.push_back(std::move(kinds)); + } + info->output_memory_kinds = std::move(output_memory_kinds); + } + + promise.Set(std::move(info)); + }; + rpc_helper_->LoadedExecutableMetadata(std::move(req)) + .OnReady(std::move(on_done)); +} + +LoadedExecutable::~LoadedExecutable() { + auto req = std::make_unique(); + req->set_loaded_executable_handle(handle_); + + rpc_helper_->LoadedExecutableDestruct(std::move(req)) + .OnReady( + [](absl::StatusOr> + response) { + if (!response.ok()) { + LOG(ERROR) << "Failed to destroy `LoadedExecutable`: " + << response.status(); + } + }); +} + +xla::ifrt::Client* LoadedExecutable::client() const { return client_; } + +absl::string_view LoadedExecutable::name() const { return name_; } + +absl::StatusOr> LoadedExecutable::Fingerprint() + const { + return fingerprint_; +} + +absl::StatusOr LoadedExecutable::Serialize() const { + return absl::UnimplementedError( + "IFRT service executable does not support `Serialize` since the " + "underlying serialization format is not stable"); +} + +int LoadedExecutable::num_devices() const { return num_devices_; } + +int64_t LoadedExecutable::SizeOfGeneratedCodeInBytes() const { + LOG(FATAL) << "Unimplemented"; +} + +absl::StatusOr LoadedExecutable::GetCompiledMemoryStats() + const { + return absl::UnimplementedError("Unimplemented"); +} + +std::optional> LoadedExecutable::GetParameterShardings() + const { + auto info = metadata_future_.Await(); + if (!info.ok()) { + return std::nullopt; + } + return (*info)->parameter_shardings; +} + +std::optional> LoadedExecutable::GetOutputShardings() + const { + auto info = metadata_future_.Await(); + if (!info.ok()) { + return std::nullopt; + } + return (*info)->output_shardings; +} + +absl::StatusOr> LoadedExecutable::GetParameterLayouts() + const { + TF_ASSIGN_OR_RETURN(auto info, metadata_future_.Await()); + return info->parameter_layouts; +} + +absl::StatusOr> LoadedExecutable::GetOutputLayouts() const { + TF_ASSIGN_OR_RETURN(auto info, metadata_future_.Await()); + return info->output_layouts; +} + +absl::StatusOr>> +LoadedExecutable::GetOutputMemoryKinds() const { + TF_ASSIGN_OR_RETURN(auto info, metadata_future_.Await()); + return info->output_memory_kinds; +} + +absl::StatusOr>> +LoadedExecutable::GetHloModules() const { + return absl::UnimplementedError( + "IFRT service does not support LoadedExecutable::GetHloModules() since " + "HloModule does not provide stable serialization"); +} + +absl::StatusOr< + absl::flat_hash_map> +LoadedExecutable::GetCostAnalysis() const { + return absl::UnimplementedError("Unimplemented"); +} + +absl::StatusOr +LoadedExecutable::Execute(absl::Span> args, + const ExecuteOptions& options, + std::optional devices) { + auto req = std::make_unique(); + req->set_loaded_executable_handle(handle_); + for (const auto& arg : args) { + auto* array = llvm::dyn_cast_or_null(arg.get()); + if (array == nullptr) { + return absl::InvalidArgumentError( + "Invalid IFRT array type provided to `LoadedExecutable::Execute`"); + } + req->add_args_handles(array->handle().handle); + } + TF_ASSIGN_OR_RETURN(*req->mutable_execute_options(), options.ToProto()); + if (devices.has_value()) { + for (const auto* device : *devices) { + req->add_device_ids(device->id()); + } + } + + TF_ASSIGN_OR_RETURN( + std::shared_ptr response, + rpc_helper_->LoadedExecutableExecute(std::move(req)).Await()); + + // NOTE: All future and array handles in `response` must have an owner + // locally, or be requested to be destructed remotely, before returning. + + xla::ifrt::LoadedExecutable::ExecuteResult result; + + // Populate the execution status future. `CheckFuture` deletes the server-side + // futures after its completion. + result.status = rpc_helper_->CheckFuture(response->status_handle()); + + // Create output arrays. The cleanup logic ensures that all handles are + // properly cleaned up on early return. + absl::Cleanup cleanup = [&]() { + int index = result.outputs.size(); + result.outputs.clear(); // Cleaned up by `~Array()`. + + for (; index < response->outputs_size(); ++index) { + Array::Destruct( + rpc_helper_.get(), + ArrayHandle{.handle = response->outputs(index).array_handle()}); + } + }; + const auto lookup_device = absl::bind_front(&Client::LookupDevice, client()); + for (const auto& output : response->outputs()) { + DType dtype = FromDTypeProto(output.dtype()); + Shape shape = FromShapeProto(output.shape()); + TF_ASSIGN_OR_RETURN(auto sharding, + FromShardingProto(lookup_device, output.sharding())); + result.outputs.push_back(tsl::MakeRef( + client(), rpc_helper_, dtype, std::move(shape), std::move(sharding), + ArrayHandle{.handle = output.array_handle()})); + } + std::move(cleanup).Cancel(); + + return result; +} + +Future LoadedExecutable::Delete() { + auto req = std::make_unique(); + req->set_loaded_executable_handle(handle_); + + absl::StatusOr> response = + rpc_helper_->LoadedExecutableDelete(std::move(req)).Await(); + if (!response.ok()) { + return Future(response.status()); + } + return rpc_helper_->CheckFuture((*response)->future_handle()); +} + +bool LoadedExecutable::IsDeleted() const { + auto req = std::make_unique(); + req->set_loaded_executable_handle(handle_); + + absl::StatusOr> response = + rpc_helper_->LoadedExecutableIsDeleted(std::move(req)).Await(); + if (!response.ok()) { + LOG(ERROR) << "Failed to query the deletion status of `LoadedExecutable`: " + << response.status(); + return false; + } + return (*response)->is_deleted(); +} + +absl::Span +LoadedExecutable::addressable_device_logical_ids() const { + return addressable_device_logical_device_ids_; +} + +absl::Span LoadedExecutable::addressable_devices() + const { + return addressable_devices_; +} + +void LoadedExecutable::PollLoadedHostCallback( + uint64_t handle, + tsl::RCReference loaded_host_callback) { + // Note: individual host callbacks may live longer than the executable as the + // destruction of an IFRT executable is not required to block until all + // in-flight executions are complete. Therefore, the following lambda must not + // capture `this` and is scheduled on the default thread pool. + auto f = [rpc_helper = rpc_helper_, handle, + loaded_host_callback = std::move(loaded_host_callback)]() { + while (true) { + const uint64_t operand_handle = + rpc_helper->host_buffer_store()->NextHandle(); + + auto poll_req = std::make_unique(); + poll_req->set_loaded_host_callback_handle(handle); + poll_req->set_operand_host_buffer_handle(operand_handle); + auto response = + rpc_helper->LoadedHostCallbackPoll(std::move(poll_req)).Await(); + + if (!response.ok()) { + LOG_EVERY_N_SEC(ERROR, 60) + << "Failed to poll host callback execution: " << response.status(); + continue; + } + + if (!(*response)->has_host_callback_execution_handle()) { + // The host callback is destructed from the server. + break; + } + + auto ret_req = std::make_unique(); + ret_req->set_host_callback_execution_handle( + (*response)->host_callback_execution_handle()); + + absl::StatusOr result_handle = + PrepareAndExecuteLoadedHostCallback( + rpc_helper->host_buffer_store().get(), loaded_host_callback.get(), + operand_handle); + if (result_handle.ok()) { + ret_req->set_result_host_buffer_handle(*result_handle); + } else { + *ret_req->mutable_error() = tsl::StatusToProto(result_handle.status()); + } + + rpc_helper->LoadedHostCallbackReturn(std::move(ret_req)) + .OnReady([](absl::StatusOr< + std::shared_ptr> + response) { + if (!response.ok()) { + LOG(ERROR) << "Failed to return host callback results: " + << response.status(); + } + }); + } + }; + tsl::Env::Default()->SchedClosure(std::move(f)); +} + +char LoadedExecutable::ID = 0; // NOLINT + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/client/executable.h b/third_party/xla/xla/python/ifrt_proxy/client/executable.h new file mode 100644 index 00000000000000..41932886c3fed3 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/executable.h @@ -0,0 +1,142 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_EXECUTABLE_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_EXECUTABLE_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/node_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/layout.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt_proxy/client/rpc_helper.h" +#include "xla/xla_data.pb.h" +#include "tsl/concurrency/ref_count.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +class LoadedExecutable final + : public llvm::RTTIExtends { + public: + LoadedExecutable(xla::ifrt::Client* client, + std::shared_ptr rpc_helper, uint64_t handle, + std::string name, int num_devices, + std::vector + addressable_device_logical_device_ids, + std::vector addressable_devices, + absl::StatusOr> fingerprint, + std::vector> + loaded_host_callbacks, + std::vector loaded_host_callback_handles); + + ~LoadedExecutable() override; + + xla::ifrt::Client* client() const override; + absl::string_view name() const override; + absl::StatusOr> Fingerprint() const override; + absl::StatusOr Serialize() const override; + + int num_devices() const override; + int64_t SizeOfGeneratedCodeInBytes() const override; + absl::StatusOr GetCompiledMemoryStats() const override; + + std::optional> GetParameterShardings() const override; + std::optional> GetOutputShardings() const override; + absl::StatusOr> GetParameterLayouts() const override; + absl::StatusOr> GetOutputLayouts() const override; + absl::StatusOr>> + GetOutputMemoryKinds() const override; + absl::StatusOr>> GetHloModules() + const override; + + absl::StatusOr> + GetCostAnalysis() const override; + + absl::StatusOr Execute( + absl::Span> args, + const ExecuteOptions& options, + std::optional devices) override; + + Future Delete() override; + bool IsDeleted() const override; + + absl::Span addressable_device_logical_ids() + const override; + absl::Span addressable_devices() const override; + + static char ID; // NOLINT + + private: + struct Metadata { + std::optional> parameter_shardings; + std::optional> output_shardings; + + absl::StatusOr> parameter_layouts; + absl::StatusOr> output_layouts; + + // Elements in `output_memory_kinds` point to elements in `memory_kinds`. + // Required since `GetOutputMemoryKinds()` returns `absl::string_view`. + // `memory_kinds` uses `absl::node_hash_set` for pointer stability. + absl::node_hash_set memory_kinds; + absl::StatusOr>> + output_memory_kinds; + }; + + void PollLoadedHostCallback( + uint64_t handle, + tsl::RCReference loaded_host_callback); + + xla::ifrt::Client* client_; + std::shared_ptr rpc_helper_; + + const uint64_t handle_; + const std::string name_; + const int num_devices_; + const std::vector + addressable_device_logical_device_ids_; + const std::vector addressable_devices_; + const absl::StatusOr> fingerprint_; + + // Metadata queried when the executable is created. Declared as `mutable` + // since `Future::Await()` is not const. + mutable Future>> metadata_future_; +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_EXECUTABLE_H_ diff --git a/third_party/xla/xla/python/ifrt_proxy/client/executable_test.cc b/third_party/xla/xla/python/ifrt_proxy/client/executable_test.cc new file mode 100644 index 00000000000000..778ec1fc745433 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/executable_test.cc @@ -0,0 +1,341 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/client/executable.h" + +#include +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "xla/layout_util.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/mock.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt_proxy/client/array.h" +#include "xla/python/ifrt_proxy/client/client_session.h" +#include "xla/python/ifrt_proxy/client/host_buffer.h" +#include "xla/python/ifrt_proxy/client/mock_client_session.h" +#include "xla/python/ifrt_proxy/client/mock_host_buffer.h" +#include "xla/python/ifrt_proxy/client/rpc_helper.h" +#include "xla/python/ifrt_proxy/client/version.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/types.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +using ::testing::_; +using ::testing::ElementsAre; +using ::testing::Optional; +using ::testing::Pointee; +using ::testing::Return; +using ::testing::SizeIs; +using ::testing::StrEq; +using ::tsl::protobuf::TextFormat; +using ::tsl::testing::IsOkAndHolds; +using ::tsl::testing::StatusIs; + +#if defined(PLATFORM_GOOGLE) +using ::testing::EquivToProto; +using ::testing::proto::Partially; +#endif + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +IfrtProxyVersion Version() { + IfrtProxyVersion version; + version.set_protocol_version(kClientMinVersion); + return version; +} + +class LoadedExecutableTest : public ::testing::Test { + protected: + void SetUp() override { + session_ = std::make_shared(); + rpc_helper_ = std::make_shared(Version(), session_); + + host_buffer_store_ = std::make_shared(); + rpc_helper_->set_host_buffer_store(host_buffer_store_); + + // Default handler that ignores all uninteresting requests, but still + // invokes the callback in order to avoid hanging the caller forever. + EXPECT_CALL(*session_, Enqueue(_)) + .WillRepeatedly(Return(Future( + absl::InternalError("Request has no mock handlers")))); + } + + std::shared_ptr session_; + std::shared_ptr rpc_helper_; + std::shared_ptr host_buffer_store_; +}; + +// TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS +#if defined(PLATFORM_GOOGLE) +TEST_F(LoadedExecutableTest, Metadata) { + IfrtResponse response; + ASSERT_TRUE(TextFormat::ParseFromString( + R"pb( + loaded_executable_metadata_response { + parameter_shardings { + shardings { type: REPLICATED } + shardings { + type: OTHER + tile_shape { + element_type: BF16 + dimensions: [ 2, 2 ] + } + tile_assignment_dimensions: [ 0, 1 ] + } + } + output_shardings { shardings { type: REPLICATED } } + parameter_layouts_list { + layouts { minor_to_major: 0 } + layouts { minor_to_major: [ 1, 0 ] } + } + output_layouts_list { layouts { minor_to_major: [ 1, 0 ] } } + output_memory_kinds { memory_kind_lists { memory_kinds: [ "foo" ] } } + } + )pb", + &response)); + EXPECT_CALL(*session_, Enqueue(Pointee(Partially(EquivToProto( + R"pb(loaded_executable_metadata_request { + loaded_executable_handle: 1234 + })pb"))))) + .WillOnce(MockClientSessionReturnResponse(response)); + + MockClient client; + LoadedExecutable executable( + &client, rpc_helper_, /*handle=*/1234, /*name=*/"foo", + /*num_devices=*/2, /*addressable_device_logical_device_ids=*/{}, + /*addressable_devices=*/{}, /*fingerprint=*/"fingerprint", + /*loaded_host_callbacks=*/{}, /*loaded_host_callback_handles=*/{}); + + EXPECT_THAT( + executable.GetParameterShardings(), + Optional(ElementsAre( + EquivToProto(R"pb(type: REPLICATED)pb"), + EquivToProto(R"pb(type: OTHER + tile_shape { + element_type: BF16 + dimensions: [ 2, 2 ] + } + tile_assignment_dimensions: [ 0, 1 ])pb")))); + EXPECT_THAT(executable.GetOutputShardings(), + Optional(ElementsAre(EquivToProto(R"pb(type: REPLICATED)pb")))); + EXPECT_THAT(executable.GetParameterLayouts(), + IsOkAndHolds(ElementsAre( + xla::LayoutUtil::MakeDescendingLayout(/*rank=*/1), + xla::LayoutUtil::MakeDescendingLayout(/*rank=*/2)))); + EXPECT_THAT(executable.GetOutputLayouts(), + IsOkAndHolds(ElementsAre( + xla::LayoutUtil::MakeDescendingLayout(/*rank=*/2)))); + EXPECT_THAT(executable.GetOutputMemoryKinds(), + IsOkAndHolds(ElementsAre(ElementsAre("foo")))); +} +#endif + +// TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS +#if defined(PLATFORM_GOOGLE) +TEST_F(LoadedExecutableTest, Execute) { + MockDevice device; + ON_CALL(device, global_device_id()) + .WillByDefault(Return(xla::PjRtGlobalDeviceId(1))); + + MockClient client; + ON_CALL(client, LookupDevice(1)).WillByDefault(Return(&device)); + + LoadedExecutable executable( + &client, rpc_helper_, /*handle=*/1234, /*name=*/"foo", + /*num_devices=*/2, /*addressable_device_logical_device_ids=*/{}, + /*addressable_devices=*/{}, /*fingerprint=*/"fingerprint", + /*loaded_host_callbacks=*/{}, /*loaded_host_callback_handles=*/{}); + + IfrtResponse response; + ASSERT_TRUE(TextFormat::ParseFromString(R"pb( + loaded_executable_execute_response { + status_handle: 2000 + outputs { + dtype: DTYPE_F32 + shape { dimensions: [ 4, 4 ] } + array_handle: 3000 + } + outputs { + dtype: DTYPE_F16 + shape { dimensions: [ 8 ] } + array_handle: 3001 + } + } + )pb", + &response)); + { + auto* outputs = response.mutable_loaded_executable_execute_response() + ->mutable_outputs(); + TF_ASSERT_OK_AND_ASSIGN( + *(*outputs)[0].mutable_sharding(), + ToShardingProto(*SingleDeviceSharding::Create(&device, MemoryKind()))); + TF_ASSERT_OK_AND_ASSIGN( + *(*outputs)[1].mutable_sharding(), + ToShardingProto(*SingleDeviceSharding::Create(&device, MemoryKind()))); + } + EXPECT_CALL(*session_, Enqueue(Pointee(Partially(EquivToProto( + R"pb(loaded_executable_execute_request { + loaded_executable_handle: 1234 + args_handles: [ 1000, 1001 ] + device_ids: [ 1 ] + })pb"))))) + .WillOnce(MockClientSessionReturnResponse(response)); + + ASSERT_TRUE(TextFormat::ParseFromString(R"pb( + response_metadata { + status { + code: 2 # UNKNOWN + message: "injected error" + } + } + )pb", + &response)); + EXPECT_CALL(*session_, + Enqueue(Pointee(Partially(EquivToProto(R"pb(check_future_request { + future_handle: 2000 + })pb"))))) + .WillOnce(MockClientSessionReturnResponse(response)); + + DeviceList devices({&device}); + + std::vector> args; + for (const uint64_t handle : {1000, 1001}) { + args.push_back(tsl::MakeRef( + &client, rpc_helper_, DType(DType::kF32), Shape({2, 2}), + OpaqueSharding::Create(devices, MemoryKind()), + ArrayHandle{.handle = handle})); + } + + TF_ASSERT_OK_AND_ASSIGN( + auto result, executable.Execute( + absl::MakeSpan(args), + xla::ifrt::LoadedExecutable::ExecuteOptions(), devices)); + + EXPECT_THAT(result.status.Await(), + StatusIs(absl::StatusCode::kUnknown, "injected error")); + + ASSERT_THAT(result.outputs, SizeIs(2)); + + const auto output0 = result.outputs[0]; + EXPECT_EQ(output0->dtype(), DType(DType::kF32)); + EXPECT_EQ(output0->shape(), Shape({4, 4})); + EXPECT_EQ(llvm::cast(output0.get())->handle().handle, 3000); + + const auto output1 = result.outputs[1]; + EXPECT_EQ(output1->dtype(), DType(DType::kF16)); + EXPECT_EQ(output1->shape(), Shape({8})); + EXPECT_EQ(llvm::cast(output1.get())->handle().handle, 3001); +} +#endif + +// TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS +#if defined(PLATFORM_GOOGLE) +TEST_F(LoadedExecutableTest, Delete) { + MockClient client; + LoadedExecutable executable( + &client, rpc_helper_, /*handle=*/1234, /*name=*/"foo", + /*num_devices=*/2, /*addressable_device_logical_device_ids=*/{}, + /*addressable_devices=*/{}, /*fingerprint=*/"fingerprint", + /*loaded_host_callbacks=*/{}, /*loaded_host_callback_handles=*/{}); + + { + IfrtResponse response; + ASSERT_TRUE(TextFormat::ParseFromString( + R"pb( + loaded_executable_delete_response { future_handle: 2000 } + )pb", + &response)); + EXPECT_CALL(*session_, Enqueue(Pointee(Partially(EquivToProto( + R"pb(loaded_executable_delete_request { + loaded_executable_handle: 1234 + })pb"))))) + .WillOnce(MockClientSessionReturnResponse(response)); + + ASSERT_TRUE(TextFormat::ParseFromString( + R"pb( + response_metadata { + status { + code: 2 # UNKNOWN + message: "injected error" + } + } + )pb", + &response)); + EXPECT_CALL( + *session_, + Enqueue(Pointee(Partially(EquivToProto(R"pb(check_future_request { + future_handle: 2000 + })pb"))))) + .WillOnce(MockClientSessionReturnResponse(response)); + + Future result = executable.Delete(); + EXPECT_THAT(result.Await(), + StatusIs(absl::StatusCode::kUnknown, StrEq("injected error"))); + } + + { + IfrtResponse response; + ASSERT_TRUE(TextFormat::ParseFromString( + R"pb( + loaded_executable_is_deleted_response { is_deleted: true } + )pb", + &response)); + EXPECT_CALL(*session_, Enqueue(Pointee(Partially(EquivToProto( + R"pb(loaded_executable_is_deleted_request { + loaded_executable_handle: 1234 + })pb"))))) + .WillOnce(MockClientSessionReturnResponse(response)); + + EXPECT_TRUE(executable.IsDeleted()); + } + + IfrtResponse response; + ASSERT_TRUE(TextFormat::ParseFromString( + R"pb( + loaded_executable_destruct_response {} + )pb", + &response)); + EXPECT_CALL(*session_, Enqueue(Pointee(Partially(EquivToProto( + R"pb(loaded_executable_destruct_request { + loaded_executable_handle: 1234 + })pb"))))) + .WillOnce(MockClientSessionReturnResponse(response)); +} +#endif + +} // namespace +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/client/grpc_client.cc b/third_party/xla/xla/python/ifrt_proxy/client/grpc_client.cc new file mode 100644 index 00000000000000..4279ab7d4d430c --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/grpc_client.cc @@ -0,0 +1,189 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/log/log_entry.h" +#include "absl/log/log_sink.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "grpcpp/client_context.h" +#include "xla/pjrt/distributed/util.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt_proxy/client/client.h" +#include "xla/python/ifrt_proxy/client/grpc_client_session.h" +#include "xla/python/ifrt_proxy/client/grpc_host_buffer.h" +#include "xla/python/ifrt_proxy/client/registry.h" +#include "xla/python/ifrt_proxy/client/rpc_helper.h" +#include "xla/python/ifrt_proxy/client/version.h" +#include "xla/python/ifrt_proxy/common/grpc_ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +namespace { + +// Attempts to establish a session to the proxy-server and returns a `Client` +// based on the session if successful. `on_disconnect` will be invoked exactly +// once if this function returns successfully, and not invoked if this function +// returns a non-OK status. +absl::StatusOr> AttemptConnection( + absl::string_view server_address, + std::function on_disconnect, int attempt_no, + absl::AnyInvocable log_initial_connection) { + std::unique_ptr rpc_helper; + auto init_response_promise = + Future>>::CreatePromise(); + + if (on_disconnect == nullptr) { + on_disconnect = [](absl::Status s) { + LOG(WARNING) << "IFRT proxy server disconnected: " << s; + }; + } + + // TODO(b/266635130): Move gRPC stub creation to be outside of `Client` so + // that we can pass mock `ClientSession` to the client. + auto stub = CreateGrpcStub(server_address); + + auto session_disconnect_cb = + [init_response = Future>>( + init_response_promise), + on_disconnect = std::move(on_disconnect), + attempt_no](absl::Status s) mutable { + // If the `rpc_helper->Init().OnReady(cb)` statement below has returned, + // the callback cb in that statement (which sets `init_response`) is + // guaranteed by `GrpcClientSession::Create()` to be called before + // `session_disconnect_cb`. + // TODO(madthanu): The above statement is false (even if we wanted to, + // we cannot meaningfully enforce or document the guarantee of + // the returned Future's OnReady being called before another callback), + // although the exact way init_response_promise is set below makes it + // work most of the time. + if (init_response.IsReady() && init_response.Await().ok()) { + // If the init RPC has already completed successfully, we have + // already or will be returning OK from the `AttemptConnection` call. + // So, invoke `on_disconnect`. + on_disconnect(s); + } else { + // Otherwise, we are going to return an error from + // `AttemptConnection`. So do not invoke `on_disconnect`. + VLOG(0) << "GrpcClientSession attempt " << attempt_no + << " failed: " << s; + } + }; + + GrpcIfrtSessionMetadata metadata; + { + GrpcGetVersionRequest request; + request.mutable_min_version()->set_protocol_version(kClientMinVersion); + request.mutable_max_version()->set_protocol_version(kClientMaxVersion); + + ::grpc::ClientContext context; + GrpcGetVersionResponse response; + TF_RETURN_IF_ERROR( + xla::FromGrpcStatus(stub->GetVersion(&context, request, &response))); + + CHECK_GE(response.version().protocol_version(), kClientMinVersion); + CHECK_LE(response.version().protocol_version(), kClientMaxVersion); + *metadata.mutable_version() = response.version(); + } + + auto session = + GrpcClientSession::Create(stub, metadata, session_disconnect_cb); + rpc_helper = + std::make_unique(metadata.version(), std::move(session)); + + log_initial_connection(absl::StrCat("Sending InitRequest and waiting for ", + "response (attempt ", attempt_no, ").")); + + // TODO(b/282757875): Use a separate Request that will indicate quickly + // whether the grpc_client<->grpc_server session has been established or + // not, instead of combining it with the Request that will fetch device + // information (which can take a while, depending on the IFRT backend). + rpc_helper->Init(std::make_unique()) + .OnReady([&](auto resp) mutable { init_response_promise.Set(resp); }); + + TF_ASSIGN_OR_RETURN(auto init_response, + Future>>( + init_response_promise) + .Await()); + + auto host_buffer_store = std::make_unique( + stub, metadata.version(), init_response->session_id()); + rpc_helper->set_host_buffer_store(std::move(host_buffer_store)); + + return Client::Create(std::move(rpc_helper), std::move(*init_response)); +} + +absl::StatusOr> CreateGrpcClient( + absl::string_view server_address, const ClientConnectionOptions& options) { + auto log_initial_connection = + [f = std::move(options.on_connection_update)](absl::string_view msg) { + VLOG(0) << msg; + if (f) { + f(absl::StrCat(absl::Now(), ": ", msg)); + } + }; + + absl::Time start_time = absl::Now(); + absl::Status last_status; + for (int i = 0; absl::Now() - start_time < options.connection_timeout; ++i) { + log_initial_connection(absl::StrCat("Connecting to IFRT proxy server at ", + server_address, ", attempt #", i, + "...")); + absl::StatusOr> result = AttemptConnection( + server_address, options.on_disconnect, i, log_initial_connection); + if (result.ok()) { + log_initial_connection(absl::StrCat("Connected to IFRT proxy server on ", + "attempt #", i, ".")); + return result; + } else { + last_status = result.status(); + log_initial_connection( + absl::StrCat("Connection to IFRT proxy server attempt #", i, + "failed: ", last_status.ToString())); + } + absl::SleepFor(absl::Seconds(1)); + } + + // We want to prepend a human-friendly error message to status before + // returning. + auto err_msg = + absl::StrCat("Unable to establish connection to ifrt_proxy server, ", + "please check provided address '", server_address, + "'; detailed error: ", last_status.message()); + log_initial_connection(err_msg); + return tsl::errors::CreateWithUpdatedMessage(last_status, err_msg); +} + +} // namespace + +bool register_client_factory = + ([] { RegisterClientFactory("grpc", CreateGrpcClient); }(), true); + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session.cc b/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session.cc new file mode 100644 index 00000000000000..a70e633d6767e5 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session.cc @@ -0,0 +1,266 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/client/grpc_client_session.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/base/call_once.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/functional/bind_front.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" +#include "grpc/grpc.h" +#include "grpcpp/channel.h" +#include "grpcpp/client_context.h" +#include "grpcpp/create_channel.h" +#include "grpcpp/security/credentials.h" +#include "grpcpp/support/channel_arguments.h" +#include "xla/pjrt/distributed/util.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt_proxy/client/client_session.h" +#include "xla/python/ifrt_proxy/common/grpc_credentials.h" +#include "xla/python/ifrt_proxy/common/grpc_ifrt_service.grpc.pb.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "tsl/platform/env.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/threadpool.h" +#include "tsl/platform/unbounded_work_queue.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +using OpId = int64_t; + +// Logically equivalent to a map, but thread-safe and +// with various convenience functions. +class GrpcClientSession::ResponseCallbackTable { + public: + absl::Status Add(OpId op_id, ResponseCallback callback) { + absl::MutexLock l(&mu_); + const bool inserted = table_.insert({op_id, std::move(callback)}).second; + if (!inserted) { + return absl::AlreadyExistsError( + absl::StrCat("Op id ", op_id, " already exists")); + } + return absl::OkStatus(); + } + + std::optional Pop(OpId op_id) { + absl::MutexLock l(&mu_); + auto it = table_.find(op_id); + if (it == table_.end()) { + return std::nullopt; + } + auto cb = std::move(it->second); + table_.erase(it); + return std::move(cb); + } + + absl::flat_hash_map PopAll() { + absl::flat_hash_map result; + absl::MutexLock l(&mu_); + result = std::move(table_); + table_ = absl::flat_hash_map(); + return result; + } + + private: + absl::Mutex mu_; + absl::flat_hash_map table_ ABSL_GUARDED_BY(mu_); +}; + +std::shared_ptr GrpcClientSession::Create( + std::shared_ptr stub, + GrpcIfrtSessionMetadata metadata, + StreamTerminatedCallback stream_terminated_cb) { + auto context = std::make_unique<::grpc::ClientContext>(); + context->AddMetadata("ifrt-proxy-grpc-ifrt-session-metadata-bin", + metadata.SerializeAsString()); + std::shared_ptr result(new GrpcClientSession( + std::move(stub), std::move(context), std::move(stream_terminated_cb))); + return result; +} + +GrpcClientSession::GrpcClientSession( + std::shared_ptr stub, + std::unique_ptr<::grpc::ClientContext> context, + StreamTerminatedCallback stream_terminated_cb) + : response_callbacks_(std::make_unique()), + reader_thread_(std::make_unique( + tsl::Env::Default(), "ifrt_proxy_client_grpc_reader", + /*num_threads=*/1)), + stub_(std::move(stub)), + context_(std::move(context)), + stream_(stub_->IfrtSession(context_.get())), + stream_terminated_cb_(std::move(stream_terminated_cb)), + user_futures_work_queue_(std::make_unique( + tsl::Env::Default(), "GrpcClientSessionUserFuturesWorkQueue")) { + reader_thread_->Schedule( + absl::bind_front(&GrpcClientSession::ReadLoop, this)); +} + +Future GrpcClientSession::Enqueue( + std::unique_ptr request) { + auto promise = Future::CreatePromise(); + absl::Status status = Enqueue( + std::move(request), [promise, queue = user_futures_work_queue_.get()]( + Response response) mutable { + queue->Schedule([promise = std::move(promise), + response = std::move(response)]() mutable -> void { + promise.Set(std::move(response)); + }); + }); + if (!status.ok()) { + user_futures_work_queue_->Schedule([promise, status]() mutable -> void { + promise.Set(std::move(status)); + }); + } + return Future(std::move(promise)); +} + +absl::Status GrpcClientSession::Enqueue(std::unique_ptr req, + ResponseCallback callback) { + const OpId op_id = req->request_metadata().op_id(); + + absl::MutexLock l(&writer_mu_); + if (writes_stopped_) { + return absl::FailedPreconditionError( + "GrpcClientSession: writes no longer allowed."); + } + + TF_RETURN_IF_ERROR(response_callbacks_->Add(op_id, std::move(callback))); + + if (!stream_->Write(*req)) { + CHECK(response_callbacks_->Pop(op_id).has_value()); + return absl::UnknownError("GrpcClientSession: writing to stream failed."); + } + + return absl::OkStatus(); +} + +void GrpcClientSession::ReadLoop() { + while (true) { + auto read_buffer = std::make_unique(); + if (!stream_->Read(read_buffer.get())) { + LOG(INFO) << "GrpcClientSession: reader loop is exiting."; + break; + } + + const OpId op_id = read_buffer->response_metadata().op_id(); + std::optional callback = response_callbacks_->Pop(op_id); + + if (callback.has_value()) { + VLOG(1) << "GrpcClientSession: Issuing callback for " << op_id; + (*callback)(std::move(read_buffer)); + VLOG(1) << "GrpcClientSession: Done with callback for " << op_id; + } else { + LOG(ERROR) << "Received response with no remaining registered callback: " + << read_buffer->DebugString(); + } + } + + reader_thread_stopped_.Notify(); + Finish(absl::OkStatus()); +} + +void GrpcClientSession::Finish(const absl::Status& client_status) { + LOG(INFO) << "GrpcClientSession: Finish() called with client status " + << client_status; + + absl::call_once(finish_once_, [&] { + context_->TryCancel(); + + LOG(INFO) << "GrpcClientSession: Waiting for reader thread to stop."; + reader_thread_stopped_.WaitForNotification(); + + auto finish_stream_and_get_server_status = [&]() -> absl::Status { + LOG(INFO) << "GrpClientSession: Attempting to call stream->Finish()"; + absl::MutexLock l(&writer_mu_); + // Note: stream_->Finish() counts as a write, and needs to be serialized + // with stream->Write(). + LOG(INFO) << "GrpClientSession: Attempting to call stream->Finish(), " + "mutex acquired"; + absl::Status server_status = xla::FromGrpcStatus(stream_->Finish()); + LOG(INFO) << "GrpClientSession: stream->Finish() returned server status " + << server_status; + + CHECK(!writes_stopped_); + writes_stopped_ = true; + + return server_status; + }; + + absl::Status combined_status = finish_stream_and_get_server_status(); + combined_status.Update(client_status); + + auto all_callbacks = response_callbacks_->PopAll(); + for (auto& [_, cb] : all_callbacks) { + if (combined_status.ok()) { + cb(absl::AbortedError("Finish(OK) called.")); + } else { + cb(combined_status); + } + } + + LOG(INFO) << "GrpClientSession::Finish(): calling terminated cb with " + << combined_status; + stream_terminated_cb_(combined_status); + }); +} + +GrpcClientSession::~GrpcClientSession() { + GrpcClientSession::Finish(absl::CancelledError("~GrpcClientSession called.")); + reader_thread_.reset(); // Wait until the reader thread exits. + LOG(INFO) << "Deleting GrpcClientSession.user_futures_work_queue_ ..."; + user_futures_work_queue_.reset(); + LOG(INFO) << "Deleted GrpcClientSession.user_futures_work_queue_."; +} + +std::shared_ptr CreateGrpcStub( + absl::string_view server_address) { + ::grpc::ChannelArguments args; + // Remove message size limit to accommodate large messages exchanged during + // model compilation. + args.SetInt(GRPC_ARG_MAX_SEND_MESSAGE_LENGTH, -1); + args.SetInt(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, -1); + std::shared_ptr<::grpc::Channel> channel = ::grpc::CreateCustomChannel( + std::string(server_address), GetClientCredentials(), args); + VLOG(0) << " Established channel."; + CHECK(channel != nullptr); + + std::shared_ptr stub = + grpc::GrpcIfrtService::NewStub(channel); + VLOG(0) << " Created stub."; + CHECK(stub != nullptr); + + return stub; +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session.h b/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session.h new file mode 100644 index 00000000000000..9ca8219760a15b --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session.h @@ -0,0 +1,144 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_GRPC_CLIENT_SESSION_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_GRPC_CLIENT_SESSION_H_ + +#include +#include + +#include "absl/base/call_once.h" +#include "absl/base/thread_annotations.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" +#include "grpcpp/client_context.h" +#include "grpcpp/support/client_callback.h" +#include "grpcpp/support/sync_stream.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt_proxy/client/client_session.h" +#include "xla/python/ifrt_proxy/common/grpc_ifrt_service.grpc.pb.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "tsl/platform/threadpool.h" +#include "tsl/platform/unbounded_work_queue.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// `GrpcClientSession` implements the client side of an `IfrtSession` +// stream(ing RPC) and allows users to enqueue `IfrtRequest`s on the +// stream and register callbacks for when `IfrtResponse`s are received. +class GrpcClientSession : public ClientSession { + public: + // `StreamTerminatedCallback` represents a function that will be called when + // the underlying streaming RPC is terminated permanently. The callback may be + // invoked by the "primary" thread and with various mutex locks held, so the + // callback should both return soon and not block on any events (deadlocks may + // happen otherwise). + using StreamTerminatedCallback = std::function; + + // Returns an instantiation of GrpcClientSession on the given `stub`. + // `stream_terminated_cb` is guaranteed to be called exactly once (unless the + // process terminates beforehand). It is guaranteed that no registered + // `ResponseCallback` (see below) will be called after `stream_terminated_cb`. + static std::shared_ptr Create( + std::shared_ptr stub, + GrpcIfrtSessionMetadata metadata, + StreamTerminatedCallback stream_terminated_cb); + + Future Enqueue(std::unique_ptr request) override; + + // `ResponseCallback` represents a function that can be invoked when + // `ClientSession` receives an `IfrtResponse`. May be invoked by the "primary" + // thread and with various mutex locks held. + using ResponseCallback = std::function; + + absl::Status Enqueue(std::unique_ptr req, + ResponseCallback callback); + + // Terminates the `GrpcClientSession` if it has not already been terminated. + // Waits until `stream_terminated_cb` returns. + void Finish(const absl::Status& client_status) override; + + // Not copyable (or moveable) + GrpcClientSession(const GrpcClientSession&) = delete; + GrpcClientSession& operator=(const GrpcClientSession&) = delete; + + // Calls `Finish()`. Also waits for the destruction of + // `user_futures_work_queue_` (see below) and thus can block on user-level + // callbacks. + ~GrpcClientSession() override; + + private: + class ResponseCallbackTable; + + GrpcClientSession(std::shared_ptr stub, + std::unique_ptr<::grpc::ClientContext> context, + StreamTerminatedCallback stream_terminated_cb); + + // Repeatedly waits for a `IfrtResponse` message to arrive; for each message, + // looks up the corresponding callback registered in `response_callbacks_` and + // invokes it inline. + void ReadLoop(); + + // Thread-safe table that logically maps from RequestMetadata.OpId to + // ResponseCallback. + const std::unique_ptr response_callbacks_; + + // Thread that invokes `ReadLoop()`. + std::unique_ptr reader_thread_; + + // A notification (waited on by `Finish()`) for when `ReadLoop()` exits. + absl::Notification reader_thread_stopped_; + + // Set by `Finish()`, respected by `Enqueue()` calls. + bool writes_stopped_ ABSL_GUARDED_BY(writer_mu_) = false; + + // A mutex that ensures serialization between various `Enqueue()` calls, since + // only one thread is allowed to write to the gRPC stream at a time. + absl::Mutex writer_mu_; + + // Ensures logic inside `Finish()` is internally called only once. + absl::once_flag finish_once_; + + // References to gRPC objects used to read and write to the stream. + const std::shared_ptr stub_; + const std::unique_ptr<::grpc::ClientContext> context_; + const std::unique_ptr< + ::grpc::ClientReaderWriterInterface> + stream_; + + const StreamTerminatedCallback stream_terminated_cb_; + + // Threadpool used to perform `Future<>::Promise::Set()` for Futures returned + // to callers of `Enqueue(std::unique_ptr request)`. We do this + // because `Set()` may block on arbitrary `OnReady` callbacks set by those + // callers. + std::unique_ptr user_futures_work_queue_; +}; + +// Creates a gRPC stub that connects to `server_address`. It can be used for +// `GrpcClientSession`. The same stub can be reused across multiple sessions. +std::shared_ptr CreateGrpcStub( + absl::string_view server_address); + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_GRPC_CLIENT_SESSION_H_ diff --git a/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session_test.cc b/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session_test.cc new file mode 100644 index 00000000000000..18f1bb1328de2d --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session_test.cc @@ -0,0 +1,481 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/client/grpc_client_session.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/log/log_sink_registry.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" +#include "absl/time/time.h" +#include "grpc/support/time.h" +#include "grpcpp/channel.h" +#include "grpcpp/create_channel.h" +#include "grpcpp/server_builder.h" +#include "grpcpp/server_context.h" +#include "grpcpp/support/status.h" +#include "grpcpp/support/sync_stream.h" +#include "xla/python/ifrt_proxy/client/version.h" +#include "xla/python/ifrt_proxy/common/grpc_credentials.h" +#include "xla/python/ifrt_proxy/common/grpc_ifrt_service.grpc.pb.h" +#include "xla/python/ifrt_proxy/common/grpc_ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +namespace { + +using ::testing::Not; +using ::tsl::testing::IsOk; + +constexpr int kOp1 = 1; +constexpr int kOp2 = 2; + +// Sufficient time for all processing (that are not explicitly waiting for +// further input) to have finished. +constexpr absl::Duration kSufficientTime = absl::Seconds(5); + +GrpcIfrtSessionMetadata Metadata() { + GrpcIfrtSessionMetadata metadata; + metadata.mutable_version()->set_protocol_version(kClientMaxVersion); + return metadata; +} + +absl::Status TestError() { return absl::UnknownError("test error"); } + +// A thread-safe queue of `absl::Status` values. +class Queue { + public: + void Push(absl::Status t) { + absl::MutexLock l(&mu_); + queue_.push_back(std::move(t)); + } + + std::optional PopOrTimeout( + absl::Duration timeout = kSufficientTime) { + absl::MutexLock l(&mu_); + auto cond = [this]() ABSL_SHARED_LOCKS_REQUIRED(mu_) -> bool { + return !queue_.empty(); + }; + mu_.AwaitWithTimeout(absl::Condition(&cond), timeout); + if (queue_.empty()) { + return std::nullopt; + } + absl::Status result = std::move(queue_.front()); + queue_.pop_front(); + return result; + } + + absl::Status Pop(absl::Duration timeout = kSufficientTime) { + auto result = PopOrTimeout(timeout); + CHECK(result.has_value()) << "Timeout!"; + return *result; + } + + void PopAllDuringDestruction() { + absl::MutexLock l(&mu_); + allow_non_empty_destruction_ = true; + } + + ~Queue() { + absl::MutexLock l(&mu_); + if (!allow_non_empty_destruction_) CHECK(queue_.empty()) << " " << this; + } + + private: + absl::Mutex mu_; + std::deque queue_ ABSL_GUARDED_BY(mu_); + bool allow_non_empty_destruction_ ABSL_GUARDED_BY(mu_) = false; +}; + +// Checks that the input is a list of zero-or-more OK statuses followed by +// zero-or-more NOT-OK statuses. Succeeds for {OK, NOT_OK, NOT_OK}, but fails +// for {OK, NOT_OK, OK}. +void ExpectHeadAndTail( + std::vector, absl::Status>> var_list) { + std::vector status_list; + for (const auto& v : var_list) { + if (std::holds_alternative>(v)) { + status_list.push_back(std::get>(v).status()); + } else { + status_list.push_back(std::get(v)); + } + } + bool seen_not_ok = false; + std::string str; + for (const auto& s : status_list) { + absl::StrAppend(&str, "\n", s.ToString(), "\n-----\n"); + } + for (const auto& s : status_list) { + if (!s.ok()) seen_not_ok = true; + if (seen_not_ok) { + EXPECT_THAT(s, Not(IsOk())) << str; + } + } +} + +using ServerStream = ::grpc::ServerReaderWriter; +using SessionAction = bool; +constexpr SessionAction kContinueSession = true; +constexpr SessionAction kStopSession = false; +using OnSessionStart = std::function; +using OnReqReceived = + std::function; + +// A simple implementation of IfrtService with various test-hooks. +class SimpleIfrtService : public grpc::GrpcIfrtService::Service { + public: + SimpleIfrtService(OnReqReceived on_req_received, + OnSessionStart on_session_start) + : on_req_received_(std::move(on_req_received)), + on_session_start_(std::move(on_session_start)) {} + + ::grpc::Status IfrtSession(::grpc::ServerContext* context, + ServerStream* stream) override { + if (on_session_start_ && on_session_start_() == kStopSession) { + return ::grpc::Status::OK; + } + + { + absl::MutexLock l(&mu_); + CHECK(contexts_.insert(context).second); + } + + while (true) { + IfrtRequest request; + LOG(INFO) << "Server: waiting on Read()."; + if (!stream->Read(&request)) { + LOG(INFO) << "Server: Read() returned false."; + break; + } + LOG(INFO) << "Server: Read() returned true."; + if (!on_req_received_) { + IfrtResponse response; + response.mutable_response_metadata()->set_op_id( + request.request_metadata().op_id()); + stream->Write(response); + } else if (on_req_received_(request, stream) == kStopSession) { + break; + } + } + { + absl::MutexLock l(&mu_); + CHECK_EQ(contexts_.erase(context), 1); + } + + LOG(INFO) << "Finishing IFRT session"; + return ::grpc::Status::OK; + } + + void CancelAllServerSessions() { + absl::MutexLock l(&mu_); + for (const auto& context : contexts_) { + context->TryCancel(); + } + } + + private: + const OnReqReceived on_req_received_; + const OnSessionStart on_session_start_; + + // Keeps track of `::grpc::ServerContext` for all ongoing sessions. + absl::Mutex mu_; + absl::flat_hash_set<::grpc::ServerContext*> contexts_ ABSL_GUARDED_BY(mu_); +}; + +// Encapsulates objects related to a client and server instance of +// `grpc::GrpcIfrtService`. +class ClientAndServer { + public: + explicit ClientAndServer(OnReqReceived on_req_received = nullptr, + OnSessionStart on_session_start = nullptr) { + std::string address = + absl::StrCat("localhost:", tsl::testing::PickUnusedPortOrDie()); + ::grpc::ServerBuilder builder; + builder.AddListeningPort(address, GetServerCredentials()); + ifrt_service_ = + std::make_unique(on_req_received, on_session_start); + builder.RegisterService(ifrt_service_.get()); + server_ = builder.BuildAndStart(); + + LOG(INFO) << "Server started and listening on " << address; + absl::FlushLogSinks(); + + std::shared_ptr<::grpc::Channel> channel = + ::grpc::CreateChannel(address, GetClientCredentials()); + channel->WaitForConnected(gpr_time_add( + gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(10, GPR_TIMESPAN))); + LOG(INFO) << "conn_state = " << channel->GetState(/*try_to_connect=*/false); + + auto stub = grpc::GrpcIfrtService::NewStub(channel); + CHECK(stub != nullptr); + + client_session_ = GrpcClientSession::Create( + std::move(stub), Metadata(), [this](absl::Status s) { + client_finished_q_.Push(s); + client_finished_notification_.Notify(); + }); + + client_finished_q_.PopAllDuringDestruction(); + } + + void StopServer() { + ifrt_service_->CancelAllServerSessions(); + server_->Shutdown(); + server_->Wait(); + } + + ~ClientAndServer() { + StopServer(); + client_session_->Finish(absl::CancelledError("~ClientAndServer")); + client_finished_notification_.WaitForNotificationWithTimeout( + kSufficientTime); + CHECK(client_finished_notification_.HasBeenNotified()); + } + + GrpcClientSession* client_session() { return client_session_.get(); } + + Queue* client_finished_q() { return &client_finished_q_; } + + absl::StatusOr SendSimpleRequest(int op_id) { + owned_queues_.push_back(std::make_unique()); + Queue* q = owned_queues_.back().get(); + + auto req = std::make_unique(); + req->mutable_request_metadata()->set_op_id(op_id); + TF_RETURN_IF_ERROR(client_session_->Enqueue( + std::move(req), + [q](GrpcClientSession::Response resp) { q->Push(resp.status()); })); + + return q; + } + + private: + std::vector> owned_queues_; + Queue client_finished_q_; + absl::Notification client_finished_notification_; + std::shared_ptr client_session_; + + std::unique_ptr<::grpc::Server> server_; + std::unique_ptr ifrt_service_; +}; + +TEST(GrpcClientSessionTest, HappyCaseOneRequestWithServerTermination) { + ClientAndServer cs; + + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q, cs.SendSimpleRequest(kOp1)); + + EXPECT_THAT(response_q->Pop(), IsOk()); + + EXPECT_EQ(cs.client_finished_q()->PopOrTimeout(), std::nullopt); + + cs.StopServer(); + EXPECT_THAT(cs.client_finished_q()->Pop(), Not(IsOk())); +} + +TEST(GrpcClientSessionTest, HappyCaseTwoRequestsWithClientFinish) { + ClientAndServer cs; + + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_1, cs.SendSimpleRequest(kOp1)); + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_2, cs.SendSimpleRequest(kOp2)); + + EXPECT_THAT(response_q_1->Pop(), IsOk()); + EXPECT_THAT(response_q_2->Pop(), IsOk()); + + EXPECT_EQ(cs.client_finished_q()->PopOrTimeout(), std::nullopt); + + cs.client_session()->Finish(TestError()); + EXPECT_THAT(cs.client_finished_q()->Pop(), Not(IsOk())); +} + +TEST(GrpcClientSessionTest, ServerFinishesDuringFirstRead) { + ClientAndServer cs( + /*on_req_received=*/[](auto, auto) { return kStopSession; }); + + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_1, cs.SendSimpleRequest(kOp1)); + EXPECT_THAT(response_q_1->Pop(), Not(IsOk())); + + absl::StatusOr response_q_2 = cs.SendSimpleRequest(kOp2); + EXPECT_THAT(response_q_2.status(), Not(IsOk())); + + EXPECT_THAT(cs.client_finished_q()->Pop(), Not(IsOk())); +} + +TEST(GrpcClientSessionTest, ServerFinishesDuringConstruction) { + ClientAndServer cs(/*on_req_received=*/nullptr, + /*on_session_start=*/[]() { return kStopSession; }); + + absl::StatusOr response_q_1 = cs.SendSimpleRequest(kOp1); + absl::StatusOr response_q_2 = cs.SendSimpleRequest(kOp2); + + ExpectHeadAndTail({response_q_1, response_q_2}); + if (response_q_1.ok()) EXPECT_THAT(response_q_1.value()->Pop(), Not(IsOk())); + if (response_q_2.ok()) EXPECT_THAT(response_q_2.value()->Pop(), Not(IsOk())); + + EXPECT_THAT(cs.client_finished_q()->Pop(), Not(IsOk())); +} + +TEST(GrpcClientSessionTest, ClientFinishesAfterServerConsumesFirstRequest) { + std::atomic session_ptr; + ClientAndServer cs( + /*on_req_received=*/[session_ptr = &session_ptr](auto, auto) { + session_ptr->load()->Finish(TestError()); + return kContinueSession; + }); + session_ptr.store(cs.client_session()); + + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_1, cs.SendSimpleRequest(kOp1)); + EXPECT_THAT(response_q_1->Pop(), Not(IsOk())); + + absl::StatusOr response_q_2 = cs.SendSimpleRequest(kOp2); + EXPECT_THAT(response_q_2.status(), Not(IsOk())); + + EXPECT_THAT(cs.client_finished_q()->Pop(), Not(IsOk())); +} + +TEST(GrpcClientSessionTest, ClientFinishesAfterServerWritesFirstResponse) { + std::atomic session_ptr; + ClientAndServer cs( + /*on_req_received=*/[session_ptr = &session_ptr](const IfrtRequest& r, + ServerStream* s) { + IfrtResponse response; + response.mutable_response_metadata()->set_op_id( + r.request_metadata().op_id()); + s->Write(response); + session_ptr->load()->Finish(TestError()); + return kContinueSession; + }); + session_ptr.store(cs.client_session()); + + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_1, cs.SendSimpleRequest(kOp1)); + absl::StatusOr response_q_2 = cs.SendSimpleRequest(kOp2); + + // The client may or may not terminate before the first response arrives. + response_q_1->Pop().IgnoreError(); + + // The client may or may not terminate before the second request could be + // enqueued. If it could be enqueued, the client will die without the server + // sending the corresponding response. + if (response_q_2.ok()) { + EXPECT_THAT(response_q_2.value()->Pop(), Not(IsOk())); + } + + EXPECT_THAT(cs.client_finished_q()->Pop(), Not(IsOk())); +} + +TEST(GrpcClientSessionTest, ClientFinishesDuringServerConstruction) { + std::atomic session_ptr; + absl::Notification init_done; + ClientAndServer cs(/*on_req_received=*/nullptr, + /*on_session_start=*/[session_ptr = &session_ptr, + init_done = &init_done]() { + init_done->WaitForNotification(); + session_ptr->load()->Finish(TestError()); + return kContinueSession; + }); + session_ptr.store(cs.client_session()); + init_done.Notify(); + + absl::StatusOr response_q_1 = cs.SendSimpleRequest(kOp1); + absl::StatusOr response_q_2 = cs.SendSimpleRequest(kOp2); + + if (response_q_1.ok()) { + EXPECT_THAT(response_q_1.value()->Pop(), Not(IsOk())); + } + if (response_q_2.ok()) { + EXPECT_THAT(response_q_2.value()->Pop(), Not(IsOk())); + } + + ExpectHeadAndTail({response_q_1, response_q_2}); + + EXPECT_THAT(cs.client_finished_q()->Pop(), Not(IsOk())); +} + +TEST(GrpcClientSessionTest, MethodsAfterFinishReturnError) { + ClientAndServer cs; + + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_1, cs.SendSimpleRequest(kOp1)); + cs.client_session()->Finish(TestError()); + + EXPECT_THAT(cs.SendSimpleRequest(kOp2), Not(IsOk())); + + response_q_1->PopAllDuringDestruction(); +} + +TEST(GrpcClientSessionTest, ReceivingBadIfrtResponseDoesNotCrash) { + ClientAndServer cs( + /*on_req_received=*/[](const IfrtRequest& r, ServerStream* s) mutable { + IfrtResponse resp; + resp.mutable_response_metadata()->set_op_id(kOp2); + s->Write(resp); + resp.mutable_response_metadata()->set_op_id( + r.request_metadata().op_id()); + s->Write(resp); + return kContinueSession; + }); + + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q, cs.SendSimpleRequest(kOp1)); + + EXPECT_THAT(response_q->Pop(), IsOk()); +} + +TEST(GrpcClientSessionTest, BadInitialChannelFailsPromptly) { + std::string address = + absl::StrCat("localhost:", tsl::testing::PickUnusedPortOrDie()); + + std::shared_ptr<::grpc::Channel> channel = + ::grpc::CreateChannel(address, GetClientCredentials()); + + std::unique_ptr stub = + grpc::GrpcIfrtService::NewStub(channel); + EXPECT_TRUE(stub != nullptr); + + auto session_finished = std::make_shared(); + auto session = GrpcClientSession::Create( + std::move(stub), Metadata(), + [session_finished](absl::Status s) { session_finished->Push(s); }); + + EXPECT_THAT(session_finished->Pop(), Not(IsOk())); +} + +} // namespace + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/client/grpc_host_buffer.cc b/third_party/xla/xla/python/ifrt_proxy/client/grpc_host_buffer.cc new file mode 100644 index 00000000000000..c5a69737f057d5 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/grpc_host_buffer.cc @@ -0,0 +1,182 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/client/grpc_host_buffer.h" + +#include +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "grpcpp/client_context.h" +#include "grpcpp/support/client_callback.h" +#include "grpcpp/support/status.h" +#include "grpcpp/support/sync_stream.h" +#include "xla/pjrt/distributed/util.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt_proxy/common/grpc_ifrt_service.grpc.pb.h" +#include "xla/python/ifrt_proxy/common/grpc_ifrt_service.pb.h" +#include "tsl/platform/env.h" +#include "tsl/platform/unbounded_work_queue.h" +#include "tsl/protobuf/status.pb.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +static constexpr int64_t kChunkSize = 1024 * 1024; + +GrpcClientHostBufferStore::GrpcClientHostBufferStore( + std::shared_ptr stub, + IfrtProxyVersion version, uint64_t session_id) + : stub_(std::move(stub)), + version_(std::move(version)), + session_id_(session_id), + lookup_work_queue_(std::make_unique( + tsl::Env::Default(), "HostBufferStoreLookupsWorkQueue")) {} + +GrpcClientHostBufferStore::~GrpcClientHostBufferStore() { + LOG(INFO) << "Waiting for destruction of HostBufferStoreLookupsWorkQueue..."; + lookup_work_queue_.reset(); + LOG(INFO) << "Destructed HostBufferStoreLookupsWorkQueue."; +} + +uint64_t GrpcClientHostBufferStore::NextHandle() { + return next_handle_.fetch_add(1, std::memory_order_relaxed); +} + +Future GrpcClientHostBufferStore::Store(uint64_t handle, + absl::string_view data) { + // The current implementation synchronously sends host buffer chunks. We may + // consider making it asynchronous if the caller can leverage such asynchrony. + + GrpcHostBufferStoreMetadata metadata; + metadata.set_session_id(session_id_); + metadata.set_handle(handle); + metadata.set_buffer_size(data.size()); + + ::grpc::ClientContext context; + context.AddMetadata("ifrt-proxy-grpc-host-buffer-store-metadata-bin", + metadata.SerializeAsString()); + + GrpcHostBufferStoreResponse response; + auto writer = stub_->HostBufferStore(&context, &response); + + for (int64_t offset = 0; offset < data.size(); offset += kChunkSize) { + GrpcHostBufferStoreRequest request; +#if defined(PLATFORM_GOOGLE) + request.set_alias_data(data.substr(offset, kChunkSize)); +#else + // TODO(b/325306748): Find a way to not do a memory-copy. + request.set_data(std::string(data.substr(offset, kChunkSize))); +#endif + writer->Write(request); + } + + if (!writer->WritesDone()) { + return Future( + absl::InternalError("Failed to write all host buffer chunks")); + } + + return Future(xla::FromGrpcStatus(writer->Finish())); +} + +Future GrpcClientHostBufferStore::Store(uint64_t handle, + const absl::Cord& data) { + // The current implementation synchronously sends host buffer chunks. We may + // consider making it asynchronous if the caller can leverage such asynchrony. + + GrpcHostBufferStoreMetadata metadata; + metadata.set_session_id(session_id_); + metadata.set_handle(handle); + metadata.set_buffer_size(data.size()); + + ::grpc::ClientContext context; + context.AddMetadata("ifrt-proxy-grpc-host-buffer-store-metadata-bin", + metadata.SerializeAsString()); + + GrpcHostBufferStoreResponse response; + auto writer = stub_->HostBufferStore(&context, &response); + + for (absl::string_view chunk : data.Chunks()) { + for (int64_t offset = 0; offset < chunk.size(); offset += kChunkSize) { + GrpcHostBufferStoreRequest request; +#if defined(PLATFORM_GOOGLE) + request.set_alias_data(chunk.substr(offset, kChunkSize)); +#else + // TODO(b/325306748): Find a way to not do a memory-copy. + request.set_data(std::string(chunk.substr(offset, kChunkSize))); +#endif + writer->Write(request); + } + } + if (!writer->WritesDone()) { + return Future( + absl::InternalError("Failed to write all host buffer chunks")); + } + + return Future(xla::FromGrpcStatus(writer->Finish())); +} + +Future> GrpcClientHostBufferStore::Lookup( + uint64_t handle) { + auto promise = Future>::CreatePromise(); + + lookup_work_queue_->Schedule([this, handle, promise]() mutable -> void { + GrpcHostBufferLookupRequest request; + request.set_handle(handle); + request.set_session_id(session_id_); + + ::grpc::ClientContext context; + + std::unique_ptr<::grpc::ClientReaderInterface> + stream = stub_->HostBufferLookup(&context, request); + + absl::Cord data; + GrpcHostBufferLookupResponse response; + while (stream->Read(&response)) { + data.Append(response.data()); + } + + absl::Status status = xla::FromGrpcStatus(stream->Finish()); + if (status.ok()) { + promise.Set(std::move(data)); + } else { + promise.Set(status); + } + }); + + return Future>(promise); +} + +Future GrpcClientHostBufferStore::Delete(uint64_t handle) { + GrpcHostBufferDeleteRequest request; + request.set_session_id(session_id_); + request.set_handle(handle); + + ::grpc::ClientContext context; + GrpcHostBufferDeleteResponse response; + return Future(xla::FromGrpcStatus( + stub_->HostBufferDelete(&context, request, &response))); +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/client/grpc_host_buffer.h b/third_party/xla/xla/python/ifrt_proxy/client/grpc_host_buffer.h new file mode 100644 index 00000000000000..bbf9b9eecfeefe --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/grpc_host_buffer.h @@ -0,0 +1,71 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_GRPC_HOST_BUFFER_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_GRPC_HOST_BUFFER_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt_proxy/client/host_buffer.h" +#include "xla/python/ifrt_proxy/common/grpc_ifrt_service.grpc.pb.h" +#include "tsl/platform/unbounded_work_queue.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +class GrpcClientHostBufferStore : public ClientHostBufferStore { + public: + GrpcClientHostBufferStore( + std::shared_ptr stub, + IfrtProxyVersion version, uint64_t session_id); + + ~GrpcClientHostBufferStore() override; + + // Implements ClientHostBufferStore. + + uint64_t NextHandle() override; + Future Store(uint64_t handle, absl::string_view data) override; + Future Store(uint64_t handle, const absl::Cord& data) override; + Future> Lookup(uint64_t handle) override; + Future Delete(uint64_t handle) override; + + private: + const std::shared_ptr stub_; + const IfrtProxyVersion version_; + const uint64_t session_id_; + std::atomic next_handle_ = 0; + + // Implementation note: `lookup_work_queue_` may have closures that invoke + // user-defined code. Each `Lookup()` call is associated with a scheduled + // closure, and the closure is used to first perform synchronous reads of the + // streaming RPC, and then to do `promise.Set()` for the Future returned to + // the caller. + std::unique_ptr lookup_work_queue_; +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_GRPC_HOST_BUFFER_H_ diff --git a/third_party/xla/xla/python/ifrt_proxy/client/host_buffer.h b/third_party/xla/xla/python/ifrt_proxy/client/host_buffer.h new file mode 100644 index 00000000000000..ceaf51debc7d8f --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/host_buffer.h @@ -0,0 +1,62 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_HOST_BUFFER_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_HOST_BUFFER_H_ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "xla/python/ifrt/future.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +class ClientHostBufferStore { + public: + virtual ~ClientHostBufferStore() = default; + + virtual uint64_t NextHandle() = 0; + + // Stores the data associated with the given handle. Returns an error if the + // handle already exists. + virtual Future Store(uint64_t handle, + absl::string_view data) = 0; + + // Stores the data associated with the given handle. Returns an error if the + // handle already exists. + // TODO(b/315023499) Find a way to increase the chunk size + virtual Future Store(uint64_t handle, + const absl::Cord& data) = 0; + + // Retrieves the data associated with the handle. Returns an error if the + // handle does not exist. + virtual Future> Lookup(uint64_t handle) = 0; + + // Deletes the host buffer associated with the handle. Returns an error if the + // handle does not exist. + virtual Future Delete(uint64_t handle) = 0; +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_HOST_BUFFER_H_ diff --git a/third_party/xla/xla/python/ifrt_proxy/client/memory.h b/third_party/xla/xla/python/ifrt_proxy/client/memory.h new file mode 100644 index 00000000000000..1bf8c584cd3dd3 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/memory.h @@ -0,0 +1,78 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_MEMORY_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_MEMORY_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/memory.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +class Memory : public xla::ifrt::Memory { + public: + Memory(int id, std::string memory_space_kind, std::string debug_string, + std::string to_string) + : id_(id), + memory_space_kind_(std::move(memory_space_kind)), + debug_string_(std::move(debug_string)), + to_string_(std::move(to_string)) {} + + // Not copyable or movable: IFRT expects `string_view` from + // `memory_space_kind()` to be stable throughout the client's lifetime. + Memory(const Memory& other) = delete; + Memory& operator=(const Memory& other) = delete; + + PjRtClient* client() const override { return nullptr; } + + absl::Span devices() const override { + return devices_; + } + + int id() const override { return id_; } + + absl::string_view memory_space_kind() const override { + return memory_space_kind_; + } + + absl::string_view DebugString() const override { return debug_string_; } + + absl::string_view ToString() const override { return to_string_; } + + private: + friend class Client; // For `devices_` initialization. + + int id_; + std::vector devices_; + std::string memory_space_kind_; + std::string debug_string_; + std::string to_string_; +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_MEMORY_H_ diff --git a/third_party/xla/xla/python/ifrt_proxy/client/mock_client_session.h b/third_party/xla/xla/python/ifrt_proxy/client/mock_client_session.h new file mode 100644 index 00000000000000..6b2a5bda249898 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/mock_client_session.h @@ -0,0 +1,50 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_MOCK_CLIENT_SESSION_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_MOCK_CLIENT_SESSION_H_ + +#include + +#include +#include "absl/status/status.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt_proxy/client/client_session.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +class MockClientSession final : public ClientSession { + public: + MOCK_METHOD(Future, Enqueue, (std::unique_ptr req), + (override)); + MOCK_METHOD(void, Finish, (const absl::Status& s), (override)); +}; + +ACTION_P(MockClientSessionReturnResponse, response_proto) { + auto response = std::make_unique(response_proto); + response->mutable_response_metadata()->set_op_id( + arg0->request_metadata().op_id()); + return Future(std::move(response)); +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_MOCK_CLIENT_SESSION_H_ diff --git a/third_party/xla/xla/python/ifrt_proxy/client/mock_host_buffer.h b/third_party/xla/xla/python/ifrt_proxy/client/mock_host_buffer.h new file mode 100644 index 00000000000000..81d70cc4e93012 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/mock_host_buffer.h @@ -0,0 +1,50 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_MOCK_HOST_BUFFER_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_MOCK_HOST_BUFFER_H_ + +#include + +#include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt_proxy/client/host_buffer.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +class MockClientHostBufferStore final : public ClientHostBufferStore { + public: + MOCK_METHOD(uint64_t, NextHandle, (), (override)); + MOCK_METHOD(Future, Store, + (uint64_t handle, absl::string_view data), (override)); + MOCK_METHOD(Future, Store, + (uint64_t handle, const absl::Cord& data), (override)); + MOCK_METHOD(Future>, Lookup, (uint64_t handle), + (override)); + MOCK_METHOD(Future, Delete, (uint64_t handle), (override)); +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_MOCK_HOST_BUFFER_H_ diff --git a/third_party/xla/xla/python/ifrt_proxy/client/py_module.cc b/third_party/xla/xla/python/ifrt_proxy/client/py_module.cc new file mode 100644 index 00000000000000..4b407bb438bb71 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/py_module.cc @@ -0,0 +1,119 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/log/log_entry.h" +#include "absl/log/log_sink.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "pybind11/cast.h" // from @pybind11 +#include "pybind11/detail/common.h" // from @pybind11 +#include "pybind11/functional.h" // from @pybind11 // NOLINT // IWYU pragma: keep +#include "pybind11/gil.h" // from @pybind11 +#include "pybind11/pybind11.h" // from @pybind11 +#include "pybind11/pytypes.h" // from @pybind11 +#include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil // NOLINT // IWYU pragma: keep +#include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt_proxy/client/registry.h" +#include "xla/python/py_client.h" +#include "tsl/platform/env.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +struct PyClientConnectionOptions { + std::function on_disconnect; + std::function on_connection_update; +}; + +absl::StatusOr> GetClient( + std::string proxy_server_address, + const PyClientConnectionOptions& py_options) { + DCHECK(PyGILState_Check()); + std::unique_ptr client; + + ClientConnectionOptions options; + if (py_options.on_disconnect) { + // While it is possible to pass around `py_options.on_disconnect` without + // wrapping it via a shared_ptr, copying the `py_options.on_disconnect` + // object can internally attempt to acquire the GIL [1], and can thus block + // or even deadlock. A unique_ptr or `absl::AnyInvocable` is not sufficient + // because downstream code can make copies. Reference: + // https://pybind11.readthedocs.io/en/stable/advanced/misc.html#common-sources-of-global-interpreter-lock-errors + auto py_on_disconnect = std::make_shared>( + std::move(py_options.on_disconnect)); + + options.on_disconnect = + [on_disconnect = std::move(py_on_disconnect)](absl::Status s) mutable { + LOG(WARNING) << "Connection to server failed, calling supplied " + << "`on_disconnect` function: " << s; + tsl::Env::Default()->SchedClosure([s, on_disconnect]() mutable { + pybind11::gil_scoped_acquire gil_acquire; + (*on_disconnect)(s); + on_disconnect = nullptr; + }); + }; + } + + if (py_options.on_connection_update) { + auto fn = std::make_shared>( + std::move(py_options.on_connection_update)); + options.on_connection_update = [fn](absl::string_view log_line) -> void { + tsl::Env::Default()->SchedClosure([fn, str = std::string(log_line)] { + pybind11::gil_scoped_acquire gil_acquire; + (*fn)(std::string(str)); + }); + }; + } + + { + pybind11::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(client, CreateClient(proxy_server_address, options)); + } + + // Constructing `xla::PyClient` requires GIL as it may dec-ref Python objects. + return std::make_shared(std::move(client)); +} + +} // namespace +} // namespace proxy +} // namespace ifrt +} // namespace xla + +PYBIND11_MODULE(py_module, m) { + pybind11_protobuf::ImportNativeProtoCasters(); + + using ::xla::ifrt::proxy::PyClientConnectionOptions; + pybind11::class_(m, "ClientConnectionOptions") + .def(pybind11::init<>()) + .def_readwrite("on_disconnect", &PyClientConnectionOptions::on_disconnect) + .def_readwrite("on_connection_update", + &PyClientConnectionOptions::on_connection_update); + + m.def("get_client", xla::ValueOrThrowWrapper(xla::ifrt::proxy::GetClient), + pybind11::arg("proxy_server_address"), pybind11::arg("options")); +} diff --git a/third_party/xla/xla/python/ifrt_proxy/client/registry.cc b/third_party/xla/xla/python/ifrt_proxy/client/registry.cc new file mode 100644 index 00000000000000..11680771b8b49b --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/registry.cc @@ -0,0 +1,102 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/client/registry.h" + +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "xla/python/ifrt/client.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +namespace { + +using FactoryFn = + std::function>( + absl::string_view, const ClientConnectionOptions&)>; + +struct Registry { + absl::Mutex mu; + absl::flat_hash_map factories ABSL_GUARDED_BY(mu); +}; + +Registry* registry() { + static auto* r = new Registry(); + return r; +} + +} // namespace + +void RegisterClientFactory(absl::string_view transport_name, + FactoryFn factory) { + absl::MutexLock l(®istry()->mu); + const bool inserted = + registry() + ->factories.insert({std::string(transport_name), factory}) + .second; + CHECK(inserted) << "IFRT proxy transport '" << transport_name + << "' already registered"; +} + +absl::StatusOr> CreateClient( + absl::string_view proxy_server_address, + const ClientConnectionOptions& options) { + const size_t pos = proxy_server_address.find("://"); + if (pos == std::string::npos) { + return absl::InvalidArgumentError( + absl::StrCat("IFRT proxy server address must be " + "'://' (e.g., " + "'grpc://localhost'), but got ", + proxy_server_address)); + } + + const absl::string_view transport_name = proxy_server_address.substr(0, pos); + const absl::string_view address = proxy_server_address.substr(pos + 3); + + FactoryFn factory; + { + absl::MutexLock l(®istry()->mu); + const auto it = registry()->factories.find(transport_name); + if (it == registry()->factories.end()) { + return absl::NotFoundError( + absl::StrCat("IFRT proxy transport '", transport_name, + "' not found; available transports are: ", + absl::StrJoin(registry()->factories, ", ", + [](std::string* out, const auto& it) { + out->append(it.first); + }))); + } + factory = it->second; + } + + return factory(address, options); +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/client/registry.h b/third_party/xla/xla/python/ifrt_proxy/client/registry.h new file mode 100644 index 00000000000000..ebf04532b278eb --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/registry.h @@ -0,0 +1,68 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_REGISTRY_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_REGISTRY_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "xla/python/ifrt/client.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +struct ClientConnectionOptions { + // Timeout for establishing the connection. + absl::Duration connection_timeout = absl::Minutes(2); + + // A callback that (if it is not set to nullptr) will be called if there was a + // successful connection to the proxy server, but there was a later + // disconnect. The callback may be called synchronously from a thread that + // performs various important activities, and therefore should not block on + // any events (or deadlocks may happen). + std::function on_disconnect = nullptr; + + // Captures logs related to establishing the connection. Logs may be generated + // synchronously from a thread that performs various important activities, + // so the function should not block (or deadlocks may happen). + std::function on_connection_update = nullptr; +}; + +// Registers a new factory for client backend implementation. Crashes if the +// same backend name is registered more than once. +void RegisterClientFactory( + absl::string_view transport_name, + std::function>( + absl::string_view address, const ClientConnectionOptions& options)> + factory); + +// Creates a client for the given backend target. The backend target string must +// be in the form of `:`. +absl::StatusOr> CreateClient( + absl::string_view proxy_server_address, + const ClientConnectionOptions& options = ClientConnectionOptions()); + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_REGISTRY_H_ diff --git a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc new file mode 100644 index 00000000000000..e07f689513e6a7 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc @@ -0,0 +1,175 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/client/rpc_helper.h" + +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" +#if defined(PLATFORM_GOOGLE) +#include "absl/types/source_location.h" +#endif +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt_proxy/client/client_session.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "tsl/platform/status_to_from_proto.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// DoRpc is a templated function that implements the logic of all RPC-wrapping +// functions of `RpcHelper`, such as `RpcHelper::MakeArrayFromHostBuffer()`. +template +Future>> DoRpc( + ClientSession* session, RequestMetadata metadata, + void (IfrtRequest::*set_req)(Req*), Resp* (IfrtResponse::*get_resp)(), + bool (IfrtResponse::*has_resp)() const, std::unique_ptr req) { + auto ifrt_req = std::make_unique(); + *ifrt_req->mutable_request_metadata() = metadata; + (ifrt_req.get()->*set_req)(req.release()); + + auto promise = Future>>::CreatePromise(); + auto on_ready = [promise, has_resp, + get_resp](ClientSession::Response r) mutable { + if (!r.ok()) { + LOG(ERROR) << "Connection to IFRT proxy server was terminated: " + << r.status(); + promise.Set(absl::UnavailableError( + absl::StrCat("Connection to IFRT proxy server was terminated: ", + r.status().ToString()))); + return; + } + + std::shared_ptr response = *std::move(r); + if (!response->has_response_metadata()) { + promise.Set(absl::InternalError( + absl::StrCat("IFRT server sent a message without metadata: ", + response->DebugString()))); + return; + } + + const absl::Status metadata_status = + tsl::StatusFromProto(response->response_metadata().status()); + const bool has_expected_response = (response.get()->*has_resp)(); + const auto has_some_response = + response->response_case() != IfrtResponse::RESPONSE_NOT_SET; + + if (metadata_status.ok() && !has_some_response) { + promise.Set(absl::InternalError( + absl::StrCat("OK response with no actual response set: ", + response->DebugString()))); + return; + } + + if (!has_expected_response && has_some_response) { + promise.Set(absl::InternalError(absl::StrCat( + "Response with wrong type (expected ", Resp::GetDescriptor()->name(), + "): ", response->DebugString()))); + return; + } + + // If the metadata_status is not-OK, according to ifrt_service.proto, + // there may be an error _instead_ of an actual response value. So, check if + // an actual response value exists, and if so return it irrespective of what + // the metadata_status says. + if (!has_some_response) { + promise.Set(metadata_status); + } else { + promise.Set( + std::make_shared(*std::move((response.get()->*get_resp)()))); + } + }; + session->Enqueue(std::move(ifrt_req)).OnReady(on_ready); + + return Future>>(promise); +} + +RequestMetadata RpcHelper::ManufactureRequestMetadata() { + RequestMetadata result; + { + absl::MutexLock l(&mu_); + result.set_op_id(next_op_id_++); + } + int prev_op_id = result.op_id() - 1; + if (prev_op_id != 0) { + // TODO(b/266635130): Depend only on necessary prior operations. + result.add_dependencies(prev_op_id); + } + // TODO(b/282757875): Add a ClearOps RPC for old dependencies. + return result; +} + +void RpcHelper::Disconnect() { + session_->Finish(absl::CancelledError("Disconnected by client")); +} + +// TODO(b/266635130): Remove this preprocessor macro. Preprocessor macros +// go against the style guide, but are convenient as we are introducing more +// RPCs and are making changes to the exact signature of the DoRpc function. +#define RPC(METHOD, PROPERTY) \ + RpcHelper::ResponseFuture RpcHelper::METHOD( \ + std::unique_ptr req) { \ + return DoRpc(session_.get(), ManufactureRequestMetadata(), \ + &IfrtRequest::set_allocated_##PROPERTY##_request, \ + &IfrtResponse::mutable_##PROPERTY##_response, \ + &IfrtResponse::has_##PROPERTY##_response, std::move(req)); \ + } + +RPC(Init, init); +RPC(GetDefaultDeviceAssignment, get_default_device_assignment); +RPC(CheckFuture, check_future); +RPC(MakeArrayFromHostBuffer, make_array_from_host_buffer); +RPC(AssembleArrayFromSingleDeviceArrays, + assemble_array_from_single_device_arrays); +RPC(DisassembleIntoSingleDeviceArrays, disassemble_into_single_device_arrays); +RPC(CopyToHostBuffer, copy_to_host_buffer); +RPC(CheckArrayReady, check_array_ready); +RPC(IsArrayDeleted, is_array_deleted); +RPC(DestructArray, destruct_array) +RPC(Reshard, reshard); +RPC(FullyReplicatedShard, fully_replicated_shard); +RPC(DeleteArray, delete_array); +RPC(Compile, compile); +RPC(LoadedExecutableMetadata, loaded_executable_metadata); +RPC(LoadedExecutableExecute, loaded_executable_execute); +RPC(LoadedExecutableDelete, loaded_executable_delete); +RPC(LoadedExecutableIsDeleted, loaded_executable_is_deleted); +RPC(LoadedExecutableDestruct, loaded_executable_destruct); +RPC(LoadedHostCallbackPoll, loaded_host_callback_poll); +RPC(LoadedHostCallbackReturn, loaded_host_callback_return); + +Future RpcHelper::CheckFuture(uint64_t handle) { + auto req = std::make_unique(); + req->set_future_handle(handle); + + auto promise = Future::CreatePromise(); + CheckFuture(std::move(req)) + .OnReady( + [promise](absl::StatusOr> + response) mutable { promise.Set(response.status()); }); + + return Future(promise); +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.h b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.h new file mode 100644 index 00000000000000..b5c3bf6340241d --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.h @@ -0,0 +1,150 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_RPC_HELPER_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_RPC_HELPER_H_ + +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt_proxy/client/client_session.h" +#include "xla/python/ifrt_proxy/client/host_buffer.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// RpcHelper helps establish a connection with the IFRT server and perform +// logical RPCs on the connection. +// +// TODO(b/266635130): RpcHelper currently makes each logical RPC order-dependent +// on the previous RPC it was asked to make. Instead, allow users of RpcHelper +// specify the necessary dependency. +class RpcHelper { + public: + RpcHelper(IfrtProxyVersion version, std::shared_ptr session) + : version_(std::move(version)), session_(std::move(session)) {} + + void Disconnect(); + + RpcHelper(const RpcHelper&) = delete; + RpcHelper& operator=(const RpcHelper&) = delete; + ~RpcHelper() { Disconnect(); } + + // IFRT Proxy version negotiated between the client and the server. + const IfrtProxyVersion& version() const { return version_; } + + // Initializes the host buffer store for this RpcHelper instance. This must be + // called exactly once during initialization before `host_buffer_store()` is + // called. + void set_host_buffer_store( + std::shared_ptr host_buffer_store) { + CHECK(host_buffer_store_ == nullptr); + host_buffer_store_ = std::move(host_buffer_store); + } + + const std::shared_ptr& host_buffer_store() const { + return host_buffer_store_; + } + + template + using ResponseFuture = Future>>; + + // Wrapper function for various logical RPCs defined in ifrt_service.proto. + // Whenever the RPC finishes, `on_done` will be called with the result or the + // return status. `on_done` can be called with various locks held and should + // return quickly without blocking on any event. `on_done` is guaranteed to be + // called exactly once. + // + // The functions can be invoked after the connection is broken, but will + // result in `on_done` getting called with an error (see + // "WrapAsConnectionError" in `rpc_helper.cc`). + + ResponseFuture Init(std::unique_ptr req); + ResponseFuture GetDefaultDeviceAssignment( + std::unique_ptr req); + + ResponseFuture CheckFuture( + std::unique_ptr req); + + ResponseFuture MakeArrayFromHostBuffer( + std::unique_ptr req); + ResponseFuture + AssembleArrayFromSingleDeviceArrays( + std::unique_ptr req); + ResponseFuture + DisassembleIntoSingleDeviceArrays( + std::unique_ptr req); + ResponseFuture CopyToHostBuffer( + std::unique_ptr req); + ResponseFuture CheckArrayReady( + std::unique_ptr req); + ResponseFuture Reshard(std::unique_ptr req); + ResponseFuture FullyReplicatedShard( + std::unique_ptr req); + ResponseFuture IsArrayDeleted( + std::unique_ptr req); + ResponseFuture DeleteArray( + std::unique_ptr req); + ResponseFuture DestructArray( + std::unique_ptr req); + + ResponseFuture Compile(std::unique_ptr req); + + ResponseFuture LoadedExecutableMetadata( + std::unique_ptr req); + ResponseFuture LoadedExecutableExecute( + std::unique_ptr req); + ResponseFuture LoadedExecutableDelete( + std::unique_ptr req); + ResponseFuture LoadedExecutableIsDeleted( + std::unique_ptr req); + ResponseFuture LoadedExecutableDestruct( + std::unique_ptr req); + + ResponseFuture LoadedHostCallbackPoll( + std::unique_ptr req); + ResponseFuture LoadedHostCallbackReturn( + std::unique_ptr req); + + // Utility functions for common functions. + + Future CheckFuture(uint64_t handle); + + private: + RequestMetadata ManufactureRequestMetadata() ABSL_LOCKS_EXCLUDED(mu_); + + const IfrtProxyVersion version_; + const std::shared_ptr session_; + std::shared_ptr host_buffer_store_; + + absl::Mutex mu_; + uint64_t next_op_id_ ABSL_GUARDED_BY(mu_) = 1; +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_RPC_HELPER_H_ diff --git a/third_party/xla/xla/python/ifrt_proxy/client/version.h b/third_party/xla/xla/python/ifrt_proxy/client/version.h new file mode 100644 index 00000000000000..06df1e0c70b005 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/version.h @@ -0,0 +1,32 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_VERSION_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_VERSION_H_ + +namespace xla { +namespace ifrt { +namespace proxy { + +// TODO(b/296144873): Document the version upgrade policy. +inline constexpr int kClientMinVersion = 1; +inline constexpr int kClientMaxVersion = 1; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_VERSION_H_ diff --git a/third_party/xla/xla/python/ifrt_proxy/common/BUILD b/third_party/xla/xla/python/ifrt_proxy/common/BUILD new file mode 100644 index 00000000000000..39f0c998f73d6c --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/common/BUILD @@ -0,0 +1,180 @@ +# Copyright 2023 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//xla/python/ifrt_proxy/common:ifrt_proxy.bzl", "default_ifrt_proxy_visibility", "ifrt_proxy_cc_test") +load("@local_tsl//tsl:tsl.bzl", "if_google") +load("@local_tsl//tsl/platform:build_config.bzl", "tf_proto_library") +# copybara:uncomment load("@bazel_skylib//:bzl_library.bzl", "bzl_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = default_ifrt_proxy_visibility, +) + +# Export headers referenced by the google-internal-version of grpc_credentials. +exports_files( + ["grpc_credentials.h"], + visibility = if_google( + ["//xla/python/ifrt_proxy/common/google:__pkg__"], + ["//visibility:private"], + ), +) + +cc_library( + name = "grpc_credentials", + hdrs = ["grpc_credentials.h"], + deps = if_google( + ["//xla/python/ifrt_proxy/common/google:grpc_credentials_lib"], + [":grpc_credentials_oss_lib"], + ) + ["@com_github_grpc_grpc//:grpc++"], +) + +cc_library( + name = "grpc_credentials_oss_lib", + srcs = [ + "grpc_credentials.cc", + "grpc_credentials.h", + ], + visibility = ["//visibility:private"], + deps = [ + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@local_tsl//tsl/platform", + ], + alwayslink = True, +) + +tf_proto_library( + name = "types_proto", + srcs = ["types.proto"], + protodeps = ["//xla/python/ifrt:serdes_proto"], +) + +tf_proto_library( + name = "ifrt_service_proto", + srcs = ["ifrt_service.proto"], + protodeps = [ + ":types_proto", + "//xla:xla_data_proto", + "//xla/pjrt:execute_options_proto", + "//xla/python/ifrt:serdes_proto", + "@local_tsl//tsl/protobuf:status_proto", + ], +) + +tf_proto_library( + name = "grpc_ifrt_service_proto", + srcs = ["grpc_ifrt_service.proto"], + has_services = True, + create_go_proto = False, + create_grpc_library = True, + create_java_proto = False, + create_kotlin_proto = False, + protodeps = [":ifrt_service_proto"], +) + +cc_library( + name = "types", + srcs = ["types.cc"], + hdrs = ["types.h"], + deps = [ + ":ifrt_service_proto_cc", + ":types_proto_cc", + "//xla/pjrt:pjrt_common", + "//xla/python/ifrt", + "//xla/python/ifrt:serdes", + "//xla/python/ifrt:sharding_serdes", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@local_tsl//tsl/platform:statusor", + ] + if_google(["@com_google_absl//absl/types:source_location"]), +) + +ifrt_proxy_cc_test( + name = "types_test", + srcs = ["types_test.cc"], + deps = [ + ":types", + ":types_proto_cc", + "//xla/pjrt:pjrt_common", + "//xla/python/ifrt", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "array_util", + srcs = ["array_util.cc"], + hdrs = ["array_util.h"], + deps = [ + "//xla/python/ifrt", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:statusor", + ], +) + +ifrt_proxy_cc_test( + name = "array_util_test", + srcs = ["array_util_test.cc"], + deps = [ + ":array_util", + "//xla/python/ifrt", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + ], +) + +# common_serdes is a collection of all common libraries that register SerDes implementations. +cc_library( + name = "common_serdes", + deps = ["//xla/python/pjrt_ifrt:xla_program_serdes"], + alwayslink = True, +) + +cc_library( + name = "proto_util", + srcs = ["proto_util.cc"], + hdrs = ["proto_util.h"], + deps = [ + ":ifrt_service_proto_cc", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:status_to_from_proto", + ], +) + +# copybara:uncomment_begin +# bzl_library( +# name = "ifrt_proxy_bzl", +# srcs = ["ifrt_proxy.bzl"], +# parse_tests = False, +# visibility = ["//visibility:private"], +# ) +# copybara:uncomment_end diff --git a/third_party/xla/xla/python/ifrt_proxy/common/array_util.cc b/third_party/xla/xla/python/ifrt_proxy/common/array_util.cc new file mode 100644 index 00000000000000..bdcf8a13dfcc8e --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/common/array_util.cc @@ -0,0 +1,156 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/common/array_util.h" + +#include +#include + +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/shape.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +namespace { + +std::string StridesAsStr(const ArrayMemRegion::ByteStrides& strides) { + if (!strides.has_value()) return "strides{nullopt}"; + return absl::StrCat("strides{", absl::StrJoin(*strides, ","), "}"); +} + +} // namespace + +absl::StatusOr> DefaultByteStrides(const DType dtype, + const Shape& shape) { + if (!dtype.byte_size().has_value()) { + return absl::InvalidArgumentError( + absl::StrCat("Unsupported data type to query byte-strides for: ", + dtype.DebugString())); + } + std::vector result(shape.dims().size()); + int64_t stride = *dtype.byte_size(); + for (int i = static_cast(shape.dims().size()) - 1; i >= 0; --i) { + result[i] = stride; + stride *= shape.dims()[i]; + } + return result; +} + +absl::StatusOr ArrayMemRegion::FromZerothElementPointer( + const void* zeroth_element, const DType dtype, const Shape& shape, + ByteStrides byte_strides) { + if (!dtype.byte_size().has_value()) { + return absl::InvalidArgumentError( + absl::StrCat("Unsupported data type to construct ArrayMemRegion: ", + dtype.DebugString())); + } + // Below, we return an error for all situations where the zeroth_element + // is different from mem_region_start. + void* const mem_region_start = const_cast(zeroth_element); + + if (!byte_strides.has_value() || + (byte_strides->empty() && shape.dims().empty())) { + return ArrayMemRegion(mem_region_start, + dtype.byte_size().value() * shape.num_elements()); + } + if (shape.num_elements() == 0) { + return ArrayMemRegion(mem_region_start, 0); + } + if (shape.dims().size() != byte_strides->size()) { + return absl::InvalidArgumentError( + absl::StrCat("Shape has different dimensions from byte_strides: ", + shape.DebugString(), " vs ", StridesAsStr(byte_strides))); + } + // Logic based on + // https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html + // + // So long as all strides are positive, the array's memory region begins at + // the zeroth element, and the last element of the array is farthest off from + // the beginning. We use the offset of the last element of the array to + // calculate the memory region. Note that this reasoning does not apply to + // negative strides, since the zeroth element can then be in the middle of the + // memory region (as an example, consider shape=[10, 10] and + // element_strides=[10,-1]). + uint64_t last_element_byte_offset = 0; + for (int i = 0; i < byte_strides->size(); ++i) { + int stride = (*byte_strides)[i]; + if (shape.dims()[i] < 0) { + return absl::InvalidArgumentError( + absl::StrCat("A shape dimension is negative: ", shape.DebugString())); + } else if (shape.dims()[i] == 1) { + // The stride shouldn't matter in this case, so continue without checking + // validity of the given stride. + continue; + } else if (stride <= 0) { + return absl::UnimplementedError( + absl::StrCat("Negative or zero strides are not fully supported: ", + StridesAsStr(byte_strides))); + } else if (stride % dtype.byte_size().value() != 0) { + return absl::UnimplementedError(absl::StrCat( + "byte_stride[", i, "] is not a multiple of the data-type's size: ", + StridesAsStr(byte_strides), ", dtype=", dtype.DebugString())); + } else { + // `shape.dims()[i]` cannot be negative (we explicitly check for this + // above) or zero (we return early for `shape.num_elements() == 0`). + DCHECK_GT(shape.dims()[i], 0); + last_element_byte_offset += (stride * (shape.dims()[i] - 1)); + } + } + return ArrayMemRegion(mem_region_start, + last_element_byte_offset + dtype.byte_size().value()); +} + +absl::StatusOr ArrayMemRegion::FromMinimalMemRegion( + absl::string_view mem_region, const DType dtype, const Shape& shape, + ByteStrides byte_strides) { + // FromZerothElementPointer() currently returns an error for any situation + // where the zeroth_element will is not equal to the place where the minimal + // memory region starts. + TF_ASSIGN_OR_RETURN( + auto result, + FromZerothElementPointer(mem_region.data(), dtype, shape, byte_strides)); + + if (result.mem_region().size() != mem_region.size()) { + return absl::InvalidArgumentError( + absl::StrCat("Incorrect size ", result.mem_region().size(), " vs ", + mem_region.size(), "; is provided memory region minimal? ", + dtype.DebugString(), " ", shape.DebugString(), " ", + StridesAsStr(byte_strides))); + } + CHECK_EQ(result.mem_region().data(), mem_region.data()); + return result; +} + +absl::string_view ArrayMemRegion::mem_region() const { + return absl::string_view(static_cast(mem_region_start_), nbytes_); +} + +void* ArrayMemRegion::zeroth_element() const { + // ArrayMemRegion cannot yet be constructed for situations where the + // zeroth element pointer is different from mem_region_start_. + return mem_region_start_; +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/common/array_util.h b/third_party/xla/xla/python/ifrt_proxy/common/array_util.h new file mode 100644 index 00000000000000..2ba8ff7ce42567 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/common/array_util.h @@ -0,0 +1,78 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_COMMON_ARRAY_UTIL_H_ +#define XLA_PYTHON_IFRT_PROXY_COMMON_ARRAY_UTIL_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/shape.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// Returns the byte-strides corresponding to the compact major-to-minor layout. +absl::StatusOr> DefaultByteStrides(DType dtype, + const Shape& shape); + +// Denotes a chunk of contiguous memory that contains all elements of the +// in-host (RAM) representation of an Array. +class ArrayMemRegion { + public: + // Nullopt implies compact major-to-minor layout, as returned by + // `DefaultByteStrides()`. + using ByteStrides = std::optional>; + + // Constructs an ArrayMemRegion given `mem_region`, where `mem_region` is + // minimal, i.e., the lower-most and upper-most addresses of `mem_region` are + // necessary to retrieve elements from the array. + static absl::StatusOr FromMinimalMemRegion( + absl::string_view mem_region, DType dtype, const Shape& shape, + ByteStrides byte_strides); + + // Constructs an ArrayMemRegion given a pointer to the zeroth-element of the + // (in-host representation of the) Array. + static absl::StatusOr FromZerothElementPointer( + const void* zeroth_element, DType dtype, const Shape& shape, + ByteStrides byte_strides); + + // Returns a region of memory whose lower-most and upper-most addresses are + // necessary to retrieve elements of the (in-host representation of) the + // array. + absl::string_view mem_region() const; + + // Returns a pointer to the zeroth-element of the (in-host representation of + // the) Array. + void* zeroth_element() const; + + private: + ArrayMemRegion(void* mem_region_start, size_t nbytes) + : mem_region_start_(mem_region_start), nbytes_(nbytes) {} + + void* const mem_region_start_; + const size_t nbytes_; +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_COMMON_ARRAY_UTIL_H_ diff --git a/third_party/xla/xla/python/ifrt_proxy/common/array_util_test.cc b/third_party/xla/xla/python/ifrt_proxy/common/array_util_test.cc new file mode 100644 index 00000000000000..51e189bb9ffc77 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/common/array_util_test.cc @@ -0,0 +1,201 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/common/array_util.h" + +#include +#include +#include +#include +#include + +#include +#include +#include "absl/strings/string_view.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/shape.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +namespace { + +using ::testing::ElementsAre; +using ::testing::Not; +using ::testing::TestWithParam; +using ::tsl::testing::IsOk; +using ::tsl::testing::IsOkAndHolds; + +constexpr DType::Kind kF64 = DType::Kind::kF64; +constexpr DType::Kind kS32 = DType::Kind::kS32; +constexpr DType::Kind kString = DType::Kind::kString; +using Strides = std::vector; + +TEST(DefaultByteStrides, ErrorsIfBadDtype) { + EXPECT_THAT(DefaultByteStrides(DType(kString), Shape({1})), Not(IsOk())); +} + +TEST(DefaultByteStrides, HappyCase) { + EXPECT_THAT(DefaultByteStrides(DType(kF64), Shape({4, 3, 5})), + IsOkAndHolds(ElementsAre(120, 40, 8))); +} + +// TC represents a testcase. +struct TC { + const std::string test_name; + const DType::Kind dtype_kind; + const std::vector shape; + const std::optional> byte_strides; + const std::optional expected_size; +}; +std::string PrintToString(const TC& tc) { return tc.test_name; } + +class ArrayMemRegionSuccess : public TestWithParam {}; +INSTANTIATE_TEST_SUITE_P( + Tests, ArrayMemRegionSuccess, + testing::Values( + // F64 + TC{"DefaultF64", kF64, {4, 3, 5}, std::nullopt}, + TC{"MajorToMinorStridesF64", kF64, {4, 3, 5}, Strides({120, 40, 8})}, + TC{"NotMajorToMinorF64", kF64, {3, 4, 5}, Strides({40, 120, 8})}, + TC{"TransposedF64", kF64, {5, 3, 4}, Strides({8, 40, 120})}, + // S32 + TC{"DefaultS32", kS32, {4, 3, 5}, std::nullopt}, + TC{"MajorToMinorStridesS32", kS32, {4, 3, 5}, Strides({60, 20, 4})}, + TC{"NotMajorToMinorS32", kS32, {3, 4, 5}, Strides({20, 60, 4})}, + TC{"TransposedS32", kS32, {5, 3, 4}, Strides({4, 20, 60})}, + // Scalar + TC{"ScalarF64DefaultStrides", kF64, {}, std::nullopt}, + TC{"ScalarF64EmptyStrides", kF64, {}, Strides({})}, + // Zero elements + TC{"NoColsDefaultStrides", kF64, {5, 0}, std::nullopt}, + TC{"NoColsStridesNonZero", kF64, {5, 0}, Strides({40, 4})}, + TC{"NoColsStridesZero", kF64, {5, 0}, Strides({0, 0})}, + TC{"NoRowsDefaultStrides", kF64, {0, 5}, std::nullopt}, + TC{"NoRowsStridesNonZero", kF64, {0, 5}, Strides({40, 4})}, + TC{"NoRowsStridesZero", kF64, {0, 5}, Strides({0, 0})}, + // Dimension with size 1 + TC{"SingleElementArbitraryStrides", kF64, {1, 1}, Strides({100, 100})}, + TC{"OneRowArbitraryColStride", kF64, {1, 5}, Strides({100, 8})}, + TC{"OneColArbitraryRowStride", kF64, {5, 1}, Strides({8, 100})}, + TC{"OneRowZeroColStride", kF64, {1, 5}, Strides({0, 8})}, + TC{"OneColZeroRowStride", kF64, {5, 1}, Strides({8, 0})}, + // Non-compact strides. + TC{"NonCompactSingleDimension", kS32, {5}, Strides({16}), 68}, + TC{"NonCompactDim0", kS32, {4, 3, 5}, Strides({120, 20, 4}), 420}, + TC{"PaddedElements", kS32, {4, 3, 5}, Strides({120, 40, 8}), 476}), + testing::PrintToStringParamName()); +TEST_P(ArrayMemRegionSuccess, TestCase) { + const TC tc = GetParam(); + const DType dtype(tc.dtype_kind); + const Shape shape(tc.shape); + const size_t expected_size = tc.expected_size.value_or( + dtype.byte_size().value() * shape.num_elements()); + std::string data(expected_size, 'a'); + + TF_ASSERT_OK_AND_ASSIGN(auto mem_region1, + ArrayMemRegion::FromZerothElementPointer( + data.data(), dtype, shape, tc.byte_strides)); + EXPECT_EQ(mem_region1.zeroth_element(), data.data()); + // Note: `EXPECT_EQ(mem_region.mem_region(), absl::string_view(data))` can + // cause asan to complain if the expectation fails. + EXPECT_EQ(mem_region1.mem_region().data(), data.data()); + EXPECT_EQ(mem_region1.mem_region().size(), data.size()); + + TF_ASSERT_OK_AND_ASSIGN( + auto mem_region2, ArrayMemRegion::FromMinimalMemRegion(data, dtype, shape, + tc.byte_strides)); + EXPECT_EQ(mem_region2.zeroth_element(), data.data()); + EXPECT_EQ(mem_region2.mem_region().data(), data.data()); + EXPECT_EQ(mem_region2.mem_region().size(), data.size()); +} + +class ArrayMemRegionFailure : public TestWithParam {}; +INSTANTIATE_TEST_SUITE_P( + Tests, ArrayMemRegionFailure, + testing::Values( + // Will not be supported + TC{"OneString", kString, {}, std::nullopt}, + TC{"ManyStrings", kString, {5}, std::nullopt}, + // Currently unimplemented + TC{"NegativeByteStrides", kS32, {4, 3, 5}, Strides({-60, -20, -4})}, + TC{"ZeroByteStride", kS32, {5, 5}, Strides({0, 0})}, + TC{"SmallerByteStrideThanDataType", kS32, {5, 5}, Strides({1, 1})}, + TC{"ByteStrideIndivisibleByDataType", kS32, {5, 5}, Strides({7, 7})}, + // Bad arguments + TC{"NegativeShapeDimension", kS32, {-5, -5}, Strides({20, 4})}), + testing::PrintToStringParamName()); +TEST_P(ArrayMemRegionFailure, TestCase) { + const TC tc = GetParam(); + const DType dtype(tc.dtype_kind); + const Shape shape(tc.shape); + char const* kSomeAddr = reinterpret_cast(1UL << 48); + + auto mem_region1 = ArrayMemRegion::FromZerothElementPointer( + /*zeroth_element=*/kSomeAddr, dtype, shape, tc.byte_strides); + EXPECT_THAT(mem_region1.status(), Not(IsOk())); + + const size_t kSomeSize = 1024; + auto mem_region2 = ArrayMemRegion::FromMinimalMemRegion( + absl::string_view(kSomeAddr, kSomeSize), dtype, shape, tc.byte_strides); + EXPECT_THAT(mem_region2.status(), Not(IsOk())); +} + +TEST(ArrayMemRegion, FromBadMemRegionSizeFails) { + const DType kDType(kS32); + const Shape kShape({5, 5}); + const size_t kDataBytes = kDType.byte_size().value() * kShape.num_elements(); + + const size_t kExtraSuffixBytes = 10; + std::string data_with_extra_suffix(kDataBytes + kExtraSuffixBytes, 'a'); + + // If we know that the zeroth_element is at the beginning, then we + // can construct the ArrayMemoryRegion; the constructed ArrayMemoryRegion + // will not contain the suffix. + TF_ASSERT_OK_AND_ASSIGN( + auto mem_region1, + ArrayMemRegion::FromZerothElementPointer( + /*zeroth_element=*/data_with_extra_suffix.data(), kDType, kShape, + /*byte_strides=*/std::nullopt)); + EXPECT_EQ(mem_region1.mem_region().data(), data_with_extra_suffix.data()); + EXPECT_EQ(mem_region1.zeroth_element(), data_with_extra_suffix.data()); + EXPECT_LT(mem_region1.mem_region().size(), data_with_extra_suffix.size()); + EXPECT_EQ(mem_region1.mem_region().size(), kDataBytes); + + // But given the data_with_extra_suffix region, we cannot discover where + // within it the zeroth-element points to, so we cannot construct an + // ArrayMemoryRegion from it. + auto mem_region2 = ArrayMemRegion::FromMinimalMemRegion( + data_with_extra_suffix, kDType, kShape, + /*byte_strides=*/std::nullopt); + EXPECT_THAT(mem_region2.status(), Not(IsOk())); + + // Similarly, if we provided `FromMinimalMemRegion` a `data` that was smaller + // than what the constructed `ArrayMemoryRegion` should point to, that will + // be detected as an error. + std::string data_without_some_bytes(kDataBytes - kExtraSuffixBytes, 'a'); + auto mem_region3 = ArrayMemRegion::FromMinimalMemRegion( + data_without_some_bytes, kDType, kShape, + /*byte_strides=*/std::nullopt); + EXPECT_THAT(mem_region3.status(), Not(IsOk())); +} + +} // namespace + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/common/grpc_credentials.cc b/third_party/xla/xla/python/ifrt_proxy/common/grpc_credentials.cc new file mode 100644 index 00000000000000..f72424b859577f --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/common/grpc_credentials.cc @@ -0,0 +1,71 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/common/grpc_credentials.h" + +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "grpcpp/security/credentials.h" +#include "grpcpp/security/server_credentials.h" +#include "tsl/platform/platform.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +namespace { + +bool UseInsecureCredentials() { + // Use insecure only with `bazel test`. + const bool insecure = (getenv("TEST_UNDECLARED_OUTPUTS_DIR") != nullptr); + + if (insecure) { + // We should not be getting to this point at all in the google-internal + // code, but check to be sure. + CHECK_EQ(TSL_IS_IN_OSS, 1); + } + + return insecure; +} + +} // namespace + +std::shared_ptr<::grpc::ChannelCredentials> GetClientCredentials() { + if (UseInsecureCredentials()) { + LOG(WARNING) << "Using insecure client credentials for gRPC."; + return ::grpc::InsecureChannelCredentials(); // NOLINT + } else { + LOG(INFO) << "Using ALTS client credentials for gRPC."; + return ::grpc::experimental::AltsCredentials( + ::grpc::experimental::AltsCredentialsOptions()); + } +} + +std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials() { + if (UseInsecureCredentials()) { + LOG(WARNING) << "Using insecure server credentials for gRPC."; + return ::grpc::InsecureServerCredentials(); // NOLINT + } else { + LOG(INFO) << "Using ALTS server credentials for gRPC."; + return ::grpc::experimental::AltsServerCredentials( + ::grpc::experimental::AltsServerCredentialsOptions()); + } +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/common/grpc_credentials.h b/third_party/xla/xla/python/ifrt_proxy/common/grpc_credentials.h new file mode 100644 index 00000000000000..46435a4adebefd --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/common/grpc_credentials.h @@ -0,0 +1,41 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_COMMON_GRPC_CREDENTIALS_H_ +#define XLA_PYTHON_IFRT_PROXY_COMMON_GRPC_CREDENTIALS_H_ + +#include + +#include "grpcpp/security/credentials.h" +#include "grpcpp/security/server_credentials.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// Get credentials to use in the client gRPC. +// TODO(b/323079791): Migrate to use utility library from tsl/platform. +std::shared_ptr<::grpc::ChannelCredentials> GetClientCredentials(); + +// Get credentials to use in the server gRPC. +// TODO(b/323079791): Migrate to use utility library from tsl/platform. +std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials(); + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_COMMON_GRPC_CREDENTIALS_H_ diff --git a/third_party/xla/xla/python/ifrt_proxy/common/grpc_ifrt_service.proto b/third_party/xla/xla/python/ifrt_proxy/common/grpc_ifrt_service.proto new file mode 100644 index 00000000000000..6741e5d98af8a7 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/common/grpc_ifrt_service.proto @@ -0,0 +1,107 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package xla.ifrt.proxy; + +import "xla/python/ifrt_proxy/common/ifrt_service.proto"; + +service GrpcIfrtService { + // Returns the IFRT Proxy version that both the client and the server + // supports. Returns an error if there's no such version. + rpc GetVersion(GrpcGetVersionRequest) returns (GrpcGetVersionResponse) {} + + // IfrtSession is a stream of IFRT requests (from the client) and responses + // from the server. + // + // Clients can optionally start the stream with an InitRequest to configure + // startup options and to retrieve basic run-time system details such as the + // number and handles of the available devices (see InitResponse). But clients + // that are fine with the default options and do not immediately need the info + // from the InitResponse can start with any other request. + // + // TODO(b/282757875): Investigate if there are useful details that client + // should supply to the server even before the first InitRequest message - may + // be via gRPC metadata. + rpc IfrtSession(stream IfrtRequest) returns (stream IfrtResponse) {} + + // Sends a host buffer from the client to the server. Uses client-side + // streaming to allow sending buffers that exceed the 2GiB protobuf + // serialization limit. + rpc HostBufferStore(stream GrpcHostBufferStoreRequest) + returns (GrpcHostBufferStoreResponse); + + // Reads a host buffer from the server to the client. Uses server-side + // streaming to allow >2GiB host buffer transfer. + rpc HostBufferLookup(GrpcHostBufferLookupRequest) + returns (stream GrpcHostBufferLookupResponse); + + // Deletes a host buffer from the server. + rpc HostBufferDelete(GrpcHostBufferDeleteRequest) + returns (GrpcHostBufferDeleteResponse); +} + +message GrpcGetVersionRequest { + IfrtProxyVersion min_version = 1; + IfrtProxyVersion max_version = 2; +} + +message GrpcGetVersionResponse { + IfrtProxyVersion version = 1; +} + +// Metadata for `IfrtSession` requests, sent as client metadata associated with +// key "ifrt-proxy-grpc-ifrt-session-metadata-bin". +message GrpcIfrtSessionMetadata { + IfrtProxyVersion version = 1; +} + +// Metadata for `Store` requests, sent as client metadata associated with key +// "ifrt-proxy-grpc-host-buffer-store-metadata-bin". +message GrpcHostBufferStoreMetadata { + fixed64 session_id = 1; + fixed64 handle = 2; + int64 buffer_size = 3; +} + +// `Store` request that contains actual data, potentially chunked. All requests +// in a transfer must be sent in order and the server simply concatenate `bytes` +// in the response under this assumption. +message GrpcHostBufferStoreRequest { + bytes data = 1; // copybara_removed [ctype = STRING_PIECE] +} + +message GrpcHostBufferStoreResponse {} + +// `Lookup` request that specifies which host buffer in the server to read. +message GrpcHostBufferLookupRequest { + fixed64 session_id = 1; + fixed64 handle = 2; +} + +// `Lookup` response that returns the (potentially chunked) host buffer +// contents. As in `GrpcHostBufferStoreRequest`, all responses must be sent in +// order and the client simply concatenates `data`. +message GrpcHostBufferLookupResponse { + bytes data = 1; // copybara_removed [ctype = STRING_PIECE] +} + +// `Delete` request that specifies the host buffer to delete. +message GrpcHostBufferDeleteRequest { + fixed64 session_id = 1; + fixed64 handle = 2; +} + +message GrpcHostBufferDeleteResponse {} diff --git a/third_party/xla/xla/python/ifrt_proxy/common/ifrt_proxy.bzl b/third_party/xla/xla/python/ifrt_proxy/common/ifrt_proxy.bzl new file mode 100644 index 00000000000000..9dd5c3e3ad996e --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/common/ifrt_proxy.bzl @@ -0,0 +1,8 @@ +"""Common libraries for IFRT proxy.""" + +load("//xla:xla.bzl", "xla_cc_test") + +def ifrt_proxy_cc_test(**kwargs): + xla_cc_test(**kwargs) + +default_ifrt_proxy_visibility = ["//xla/python/ifrt_proxy:__subpackages__"] diff --git a/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto b/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto new file mode 100644 index 00000000000000..925bd76ecdea55 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto @@ -0,0 +1,483 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package xla.ifrt.proxy; + +import "xla/pjrt/execute_options.proto"; +import "xla/python/ifrt/serdes.proto"; +import "xla/python/ifrt_proxy/common/types.proto"; +import "xla/xla_data.proto"; +import "tsl/protobuf/status.proto"; + +option cc_enable_arenas = true; + +message IfrtProxyVersion { + int32 protocol_version = 1; +} + +message IfrtRequest { + RequestMetadata request_metadata = 1; + + oneof request { + InitRequest init_request = 2; + + // ===== Future ===== + CheckFutureRequest check_future_request = 3; + + // ===== Array ===== + MakeArrayFromHostBufferRequest make_array_from_host_buffer_request = 4; + AssembleArrayFromSingleDeviceArraysRequest + assemble_array_from_single_device_arrays_request = 5; + CopyToHostBufferRequest copy_to_host_buffer_request = 6; + DisassembleIntoSingleDeviceArraysRequest + disassemble_into_single_device_arrays_request = 7; + CheckArrayReadyRequest check_array_ready_request = 8; + DeleteArrayRequest delete_array_request = 9; + ReshardRequest reshard_request = 10; + FullyReplicatedShardRequest fully_replicated_shard_request = 20; + IsArrayDeletedRequest is_array_deleted_request = 11; + DestructArrayRequest destruct_array_request = 12; + + // ==== Compiler ==== + CompileRequest compile_request = 13; + + // ===== LoadedExecutable ===== + LoadedExecutableMetadataRequest loaded_executable_metadata_request = 14; + LoadedExecutableExecuteRequest loaded_executable_execute_request = 15; + LoadedExecutableDeleteRequest loaded_executable_delete_request = 16; + LoadedExecutableIsDeletedRequest loaded_executable_is_deleted_request = 17; + LoadedExecutableDestructRequest loaded_executable_destruct_request = 18; + + // ===== LoadedHostCallback ===== + LoadedHostCallbackPollRequest loaded_host_callback_poll_request = 21; + LoadedHostCallbackReturnRequest loaded_host_callback_return_request = 22; + + // ===== Client ===== + GetDefaultDeviceAssignmentRequest get_default_device_assignment_request = + 19; + } +} + +message IfrtResponse { + ResponseMetadata response_metadata = 1; + + oneof response { + InitResponse init_response = 2; + + // ===== Future ===== + CheckFutureResponse check_future_response = 3; + + // ===== Array ===== + MakeArrayFromHostBufferResponse make_array_from_host_buffer_response = 4; + AssembleArrayFromSingleDeviceArraysResponse + assemble_array_from_single_device_arrays_response = 5; + CopyToHostBufferResponse copy_to_host_buffer_response = 6; + DisassembleIntoSingleDeviceArraysResponse + disassemble_into_single_device_arrays_response = 7; + CheckArrayReadyResponse check_array_ready_response = 8; + DeleteArrayResponse delete_array_response = 9; + ReshardResponse reshard_response = 10; + FullyReplicatedShardResponse fully_replicated_shard_response = 20; + IsArrayDeletedResponse is_array_deleted_response = 11; + DestructArrayResponse destruct_array_response = 12; + + // ===== Compiler ===== + CompileResponse compile_response = 13; + + // ===== LoadedExecutable ===== + LoadedExecutableMetadataResponse loaded_executable_metadata_response = 14; + LoadedExecutableExecuteResponse loaded_executable_execute_response = 15; + LoadedExecutableDeleteResponse loaded_executable_delete_response = 16; + LoadedExecutableIsDeletedResponse loaded_executable_is_deleted_response = + 17; + LoadedExecutableDestructResponse loaded_executable_destruct_response = 18; + + // ===== LoadedHostCallback ===== + LoadedHostCallbackPollResponse loaded_host_callback_poll_response = 21; + LoadedHostCallbackReturnResponse loaded_host_callback_return_response = 22; + + // ===== Client ===== + GetDefaultDeviceAssignmentResponse get_default_device_assignment_response = + 19; + } +} + +// Metadata of an IFRT Request. +message RequestMetadata { + // Identifies a logical IFRT Operation (equivalent to an IFRT API call). + // + // For the operations that require chunking (e.g.: MakeArrayFromHostBuffer) + // all the request proto messages share the same op_id. + // + // Must be unique and monotonically increasing across the life of a client - + // may stretch across multiple successive IfrtSessions used to reconnect and + // resync after transient connectivity failures. + fixed64 op_id = 1; + + // List of one or more prior ops this current op is "dependent" + // upon. Currently this allows the client to define the order in which the + // server starts the execution of requests. Future versions may add other + // types of dependencies. For instance, a separate list of dependencies that + // must *complete* executing before the current one can start to execute. + // + // An op_id that has not yet been seen by the server is treated as an error + // that fails the op. + repeated fixed64 dependencies = 2; + + // UserContext is a basic provenance mechanism that allows the server-side + // actions and artifacts (say, allocating a buffer) to be associated with the + // corresponding client-side context that triggered those actions. + // + // The optional UserContextId is generated by the client and are used as an + // opaque label by the server and the run-time systems behind it. + // TODO(b/282757875): Add a pointer to Usercontext bugs/design doc. + fixed64 user_context_id = 3; +} + +// Metadata of an IFRT Response. + +message ResponseMetadata { + // ID of the operation this response belongs to. + fixed64 op_id = 1; + + // Status of the operation. + // + // In case of "chunked" responses (i.e., the full logical response is + // spread across a sequence of IfrtResponse protos), the actual sequence of + // IfrtResponse messages will follow only if this Status is OK in the very + // first message. That is, in case of errors, server sends a single + // IfrtResponse with the appropriate error included. + // + // In case of "batched" operations (i.e., where the response is carrying + // the outcomes of multiple requests that were "batched" in the same + // IfrtRequest proto - such as deleting a bunch of Arrays) this Status + // field provides a way to quickly check if none of the individual + // operations encountered errors. Clients should not rely on specific error + // type or string when this is not OK, they should check the response + // message for individual Statuses. + tensorflow.StatusProto status = 2; +} + +// InitRequest allows the client to specify the optional startup configuration +// parameters such as an idle timeout for this `IfrtSession`, backend servers +// addresses, and whether to turn on tracing, etc. +// +// Initialization of a a session is optional, but if a client chooses to do it, +// it must be the very first op i.e., the InitRequest must be the very first +// request of the session. +message InitRequest {} + +// InitResponse contains basic runtime system info (such as the available +// devices, and name and type of the platform) that most clients can immediately +// make use of. It may also carry the status for whether the optional +// configuration requested by the InitRequest has been successfully applied. +message InitResponse { + uint64 session_id = 8; + + string platform_name = 1; // == ifrt::Client::platform_name() + string platform_version = 2; // == ifrt::Client::platform_version() + uint64 platform_id = 3; // == ifrt::Client::platform_id() + uint64 process_index = 4; // == ifrt::Client::process_index() + string runtime_type = 5; // == ifrt::Client::runtime_type() + + message Device { + int32 id = 1; + int32 local_device_id = 9; + int32 local_hardware_id = 2; + string device_kind = 3; + optional int32 default_memory_id = 7; + repeated int32 memory_ids = 8; + string debug_string = 4; + string to_string = 5; + map attributes = 6; + } + + repeated Device devices = 6; // == ifrt::Client::devices() + repeated int32 addressable_device_ids = + 7; // == ifrt::Client::addressable_devices() + + message Memory { + int32 id = 1; + string memory_space_kind = 2; + repeated int32 device_ids = 3; + string debug_string = 4; + string to_string = 5; + } + + repeated Memory memories = 9; +} + +// ================ Future-related operations ================ + +// Checks if the given Futures are ready on the server. This is a destructive +// read, i.e., the given future will no longer be able to be referenced. +message CheckFutureRequest { + fixed64 future_handle = 1; +} +message CheckFutureResponse {} + +// ================ Array-related operations ================ + +// In the current context of the IFRT proxy service, the term `Host` in the +// proto names below refers to the host where the proxy client and the user code +// (e.g.: a Jax application) are running. + +// Makes an IFRT Array from the contents of a HostBuffer. +// Equivalent to `ifrt::Client::MakeArrayFromHostBuffer`. +message MakeArrayFromHostBufferRequest { + proto.DType dtype = 1; + proto.Shape shape = 2; + proto.Sharding sharding = 3; + fixed64 host_buffer_handle = 4; + optional proto.ByteStrides byte_strides = 5; +} +message MakeArrayFromHostBufferResponse { + fixed64 array_handle = 1; +} + +// Makes an IFRT Array from a set of single-device Arrays. +// Equivalent to ifrt::Client::AssembleArrayFromSingleDeviceArrays. +message AssembleArrayFromSingleDeviceArraysRequest { + proto.Shape shape = 1; + proto.Sharding sharding = 2; + repeated fixed64 single_device_array_handles = 3; + proto.ArrayCopySemantics copy_semantics = 4; +} +message AssembleArrayFromSingleDeviceArraysResponse { + fixed64 array_handle = 1; +} + +// Reads the contents of a given IFRT Array. +// Equivalent to ifrt::Array::CopyToHostBuffer. +message CopyToHostBufferRequest { + fixed64 array_handle = 2; + optional proto.ByteStrides byte_strides = 3; + fixed64 host_buffer_handle = 1; +} +message CopyToHostBufferResponse {} + +// Breaks the given Array into its constituent per-device Arrays. +// Equivalent to ifrt::Array::DisassmebleIntoSingleDeviceArrays. +message DisassembleIntoSingleDeviceArraysRequest { + fixed64 array_handle = 1; + proto.ArrayCopySemantics copy_semantics = 2; +} +message DisassembleIntoSingleDeviceArraysResponse { + repeated fixed64 single_device_array_handles = 1; +} + +message ReshardRequest { + fixed64 array_handle = 1; + proto.Sharding sharding = 2; + proto.ArrayCopySemantics copy_semantics = 3; +} +message ReshardResponse { + fixed64 array_handle = 1; +} + +message FullyReplicatedShardRequest { + fixed64 array_handle = 1; + proto.ArrayCopySemantics copy_semantics = 2; +} +message FullyReplicatedShardResponse { + fixed64 array_handle = 1; +} + +// Checks if the given Arrays are ready on the server. +message CheckArrayReadyRequest { + fixed64 array_handle = 1; +} +message CheckArrayReadyResponse {} + +// Deletes the given Array. Response contains the handle for a Future that +// becomes ready when the deletion completes. +message DeleteArrayRequest { + fixed64 array_handle = 1; +} +message DeleteArrayResponse { + fixed64 deletion_future_handle = 1; +} + +message IsArrayDeletedRequest { + fixed64 array_handle = 1; +} +message IsArrayDeletedResponse { + bool deleted = 1; +} + +message DestructArrayRequest { + fixed64 array_handle = 1; +} +message DestructArrayResponse {} + +// ================ Compiler-related operations ================ + +// Modeled after `xla::PjRtLoadedExecutable::LogicalDeviceIds`. +// +// TODO(hyeontaek): this XLA-specific type is temporary and will be removed when +// `addressable_device_logical_ids()` is removed from `LoadedExecutable` or +// moved to a type-erased proto field. +message LogicalDeviceIds { + int32 replica = 1; + int32 partition = 2; +} + +// Compiles `mlir_module` and returns a `LoadedExecutable`. +message CompileRequest { + xla.ifrt.Serialized program = 1; + xla.ifrt.Serialized compile_options = 2; + repeated bytes host_callbacks = 3; +} +message CompileResponse { + fixed64 loaded_executable_handle = 1; + repeated fixed64 loaded_host_callback_handles = 8; + + // A subset of LoadedExecutable's fields that are cheap to calculate. See + // `LoadedExecutableMetadataResponse` for the rest of metadata. + string name = 2; + int32 num_devices = 3; + repeated LogicalDeviceIds addressable_device_logical_ids = 4; + repeated int32 addressable_device_ids = 5; + oneof fingerprint { + bytes fingerprint_value = 6; + tensorflow.StatusProto fingerprint_error = 7; + } +} + +// ================ LoadedExecutable-related operations ================ + +// Reads `LoadedExecutable`'s metadata that's typically available only after +// compilation. Metadata fields that are cheaper to calculate are available +// immediately as part of `CompileResponse`. +message LoadedExecutableMetadataRequest { + fixed64 loaded_executable_handle = 1; +} +message LoadedExecutableMetadataResponse { + message ShardingList { + repeated xla.OpSharding shardings = 1; + } + + optional ShardingList parameter_shardings = 1; + optional ShardingList output_shardings = 2; + + message LayoutList { + repeated xla.LayoutProto layouts = 1; + } + + oneof parameter_layouts { + LayoutList parameter_layouts_list = 4; + tensorflow.StatusProto parameter_layouts_error = 5; + } + oneof output_layouts { + LayoutList output_layouts_list = 6; + tensorflow.StatusProto output_layouts_error = 7; + } + + message MemoryKindList { + repeated string memory_kinds = 1; + } + + message OutputMemoryKind { + tensorflow.StatusProto status = 1; + repeated MemoryKindList memory_kind_lists = 2; + } + + OutputMemoryKind output_memory_kinds = 3; +} + +// Mirrors `LoadedExecutable::Execute`. Returns output array handles and a +// future handle that becomes ready when the execution completes. The latter can +// be checked by issuing `CheckFutureRequest`. +message LoadedExecutableExecuteRequest { + fixed64 loaded_executable_handle = 1; + repeated fixed64 args_handles = 2; + xla.ExecuteOptionsProto execute_options = 3; + repeated int32 device_ids = 4; +} +message LoadedExecutableExecuteResponse { + fixed64 status_handle = 1; + + message Output { + proto.DType dtype = 1; + proto.Shape shape = 2; + proto.Sharding sharding = 3; + fixed64 array_handle = 4; + } + + repeated Output outputs = 2; +} + +// Mirrors `LoadedExecutable::Delete`. Returns a handle of a future that becomes +// ready when the deletion completes. +message LoadedExecutableDeleteRequest { + fixed64 loaded_executable_handle = 1; +} +message LoadedExecutableDeleteResponse { + fixed64 future_handle = 1; +} + +// Mirrors `LoadedExecutable::IsDeleted`. +message LoadedExecutableIsDeletedRequest { + fixed64 loaded_executable_handle = 1; +} +message LoadedExecutableIsDeletedResponse { + bool is_deleted = 1; +} + +// Mirrors `LoadedExecutable::~LoadedExecutable`. The LoadedExecutable handle +// becomes unusable after this request. +message LoadedExecutableDestructRequest { + fixed64 loaded_executable_handle = 1; +} +message LoadedExecutableDestructResponse {} + +// ================ LoadedHostCallback-related operations ================ + +// Waits for the given host callback on the server to have any pending execution +// and retrieves its execution identifier and operands. The server serializes +// all operands, concatenates them in the argument order, stores it as a single +// host buffer assocatiated with the given handle. +message LoadedHostCallbackPollRequest { + fixed64 loaded_host_callback_handle = 1; + fixed64 operand_host_buffer_handle = 2; +} +message LoadedHostCallbackPollResponse { + optional fixed64 host_callback_execution_handle = 1; +} + +// Returns the results of a client-side host callback execution, requested by +// `LoadedHostCallbackPollResponse`. The client concatenates all serialized +// results and stores them as a single host buffer associated with the given +// handle. +message LoadedHostCallbackReturnRequest { + fixed64 host_callback_execution_handle = 1; + oneof result { + fixed64 result_host_buffer_handle = 3; + tensorflow.StatusProto error = 2; + } +} +message LoadedHostCallbackReturnResponse {} + +// ============= Operations supported by the IFRT `Client` class ============= + +// Mirrors Client::GetDefaultDeviceAssignment. +message GetDefaultDeviceAssignmentRequest { + fixed64 num_replicas = 1; + fixed64 num_partitions = 2; +} +message GetDefaultDeviceAssignmentResponse { + xla.DeviceAssignmentProto device_assignment = 1; +} diff --git a/third_party/xla/xla/python/ifrt_proxy/common/proto_util.cc b/third_party/xla/xla/python/ifrt_proxy/common/proto_util.cc new file mode 100644 index 00000000000000..a9d057c2139a59 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/common/proto_util.cc @@ -0,0 +1,38 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/common/proto_util.h" + +#include +#include + +#include "absl/status/status.h" +#include "tsl/platform/status_to_from_proto.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +std::unique_ptr NewIfrtResponse(uint64_t op_id, + absl::Status status) { + auto ifrt_resp = std::make_unique(); + auto* response_metadata = ifrt_resp->mutable_response_metadata(); + response_metadata->set_op_id(op_id); + *response_metadata->mutable_status() = tsl::StatusToProto(status); + return ifrt_resp; +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/common/proto_util.h b/third_party/xla/xla/python/ifrt_proxy/common/proto_util.h new file mode 100644 index 00000000000000..d999d14f978367 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/common/proto_util.h @@ -0,0 +1,57 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_COMMON_PROTO_UTIL_H_ +#define XLA_PYTHON_IFRT_PROXY_COMMON_PROTO_UTIL_H_ + +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// Makes an IfrtResponse proto with the given metadata. +std::unique_ptr NewIfrtResponse( + uint64_t op_id, absl::Status status = absl::OkStatus()); + +// Converts an `absl::string_view` into a type that is appropriate for doing +// `proto->set_string_field(...)`. This type can be absl::string_view in the +// newest versions of protobuf, but needs to be std::string for previous +// versions. (As of Feb 2024, OpenXLA uses an old version.) +#if defined(PLATFORM_GOOGLE) +inline absl::string_view AsProtoStringData( + absl::string_view s ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return s; +} +#else +inline std::string AsProtoStringData(absl::string_view s) { + LOG_FIRST_N(WARNING, 5) << "AsProtoStringData(): copying string_view->string"; + return std::string(s); +} +#endif + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_COMMON_PROTO_UTIL_H_ diff --git a/third_party/xla/xla/python/ifrt_proxy/common/types.cc b/third_party/xla/xla/python/ifrt_proxy/common/types.cc new file mode 100644 index 00000000000000..1474a6dad1fd77 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/common/types.cc @@ -0,0 +1,231 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/common/types.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/serdes.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt/sharding_serdes.h" +#include "xla/python/ifrt_proxy/common/types.pb.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +DType FromDTypeProto(proto::DType dtype_proto) { + switch (dtype_proto) { + case proto::DType::DTYPE_PRED: + return DType(DType::Kind::kPred); + case proto::DType::DTYPE_TOKEN: + return DType(DType::Kind::kToken); +#define CASE(X) \ + case proto::DType::DTYPE_##X: \ + return DType(DType::Kind::k##X); + CASE(S4); + CASE(S8); + CASE(S16); + CASE(S32); + CASE(S64); + CASE(U4); + CASE(U8); + CASE(U16); + CASE(U32); + CASE(U64); + CASE(F16); + CASE(F32); + CASE(F64); + CASE(BF16); + CASE(C64); + CASE(C128); + CASE(F8E4M3FN); + CASE(F8E4M3B11FNUZ); + CASE(F8E4M3FNUZ); + CASE(F8E5M2); + CASE(F8E5M2FNUZ); +#undef CASE + default: + return DType(DType::Kind::kInvalid); + } +} + +proto::DType ToDTypeProto(DType dtype) { + switch (dtype.kind()) { + case DType::Kind::kPred: + return proto::DType::DTYPE_PRED; + case DType::Kind::kToken: + return proto::DType::DTYPE_TOKEN; +#define CASE(X) \ + case DType::Kind::k##X: \ + return proto::DType::DTYPE_##X; + CASE(S4); + CASE(S8); + CASE(S16); + CASE(S32); + CASE(S64); + CASE(U4); + CASE(U8); + CASE(U16); + CASE(U32); + CASE(U64); + CASE(F16); + CASE(F32); + CASE(F64); + CASE(BF16); + CASE(C64); + CASE(C128); + CASE(F8E4M3FN); + CASE(F8E4M3B11FNUZ); + CASE(F8E4M3FNUZ); + CASE(F8E5M2); + CASE(F8E5M2FNUZ); +#undef CASE + default: + return proto::DType::DTYPE_UNSPECIFIED; + } +} + +Shape FromShapeProto(const proto::Shape& shape_proto) { + return Shape(shape_proto.dimensions()); +} + +proto::Shape ToShapeProto(const Shape& shape) { + proto::Shape shape_proto; + for (int64_t dim : shape.dims()) { + shape_proto.add_dimensions(dim); + } + return shape_proto; +} + +absl::StatusOr FromVariantProto( + const proto::Variant& variant_proto) { + switch (variant_proto.value_case()) { + case proto::Variant::kStringValue: + return variant_proto.string_value(); + case proto::Variant::kInt64Value: + return variant_proto.int64_value(); + case proto::Variant::kInt64List: { + const auto& values = variant_proto.int64_list().values(); + return std::vector(values.begin(), values.end()); + } + case proto::Variant::kFloatValue: + return variant_proto.float_value(); + default: + return absl::UnimplementedError(absl::StrCat( + "Unknown xla.ifrt.proto.Variant case: ", variant_proto.value_case())); + } +} + +absl::StatusOr ToVariantProto(const xla::PjRtValueType& value) { + proto::Variant variant; + if (auto* s = std::get_if(&value)) { + variant.set_string_value(*s); + } else if (auto* i = std::get_if(&value)) { + variant.set_int64_value(*i); + } else if (auto* is = std::get_if>(&value)) { + for (const int64_t i : *is) { + variant.mutable_int64_list()->add_values(i); + } + } else if (auto* f = std::get_if(&value)) { + variant.set_float_value(*f); + } else { + return absl::UnimplementedError("Unknown xla::PjRtValueType type"); + } + return variant; +} + +absl::StatusOr> FromShardingProto( + DeviceList::LookupDeviceFunc lookup_device, + const proto::Sharding& sharding_proto) { + TF_ASSIGN_OR_RETURN(std::unique_ptr sharding, + Deserialize(sharding_proto.serialized_sharding(), + std::make_unique( + std::move(lookup_device)))); + return std::shared_ptr( + llvm::cast(sharding.release())); +} + +absl::StatusOr ToShardingProto(const Sharding& sharding) { + proto::Sharding sharding_proto; + TF_ASSIGN_OR_RETURN(*sharding_proto.mutable_serialized_sharding(), + Serialize(const_cast(sharding))); + return sharding_proto; +} + +proto::ArrayCopySemantics ToArrayCopySemanticsProto(ArrayCopySemantics s) { + switch (s) { + case ArrayCopySemantics::kAlwaysCopy: + return proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY; + case ArrayCopySemantics::kDonateInput: + return proto::ARRAY_COPY_SEMANTICS_DONATE_INPUT; + case ArrayCopySemantics::kReuseInput: + return proto::ARRAY_COPY_SEMANTICS_REUSE_INPUT; + } +} + +absl::StatusOr FromArrayCopySemanticsProto( + proto::ArrayCopySemantics s) { + MakeArrayFromHostBufferRequest req; + switch (s) { + case proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY: + return ArrayCopySemantics::kAlwaysCopy; + case proto::ARRAY_COPY_SEMANTICS_DONATE_INPUT: + return ArrayCopySemantics::kDonateInput; + case proto::ARRAY_COPY_SEMANTICS_REUSE_INPUT: + return ArrayCopySemantics::kReuseInput; + default: + return absl::InvalidArgumentError( + absl::StrCat("Unhandled proto-enum value ", s, ":", + proto::ArrayCopySemantics_Name(s))); + } +} + +std::vector FromByteStridesProto(const proto::ByteStrides& strides) { + std::vector result; + result.reserve(strides.strides_size()); + for (auto x : strides.strides()) { + result.push_back(x); + } + return result; +} + +proto::ByteStrides ToByteStridesProto(const absl::Span strides) { + proto::ByteStrides result; + for (auto x : strides) { + result.add_strides(x); + } + return result; +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/common/types.h b/third_party/xla/xla/python/ifrt_proxy/common/types.h new file mode 100644 index 00000000000000..bd894b70a68d6b --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/common/types.h @@ -0,0 +1,74 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_COMMON_TYPES_H_ +#define XLA_PYTHON_IFRT_PROXY_COMMON_TYPES_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/types.pb.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +struct ArrayHandle { + uint64_t handle; + + template + friend void AbslStringify(Sink& sink, const ArrayHandle& h) { + absl::Format(&sink, "arr_%v", h.handle); + } +}; + +DType FromDTypeProto(proto::DType dtype_proto); +proto::DType ToDTypeProto(DType dtype); + +Shape FromShapeProto(const proto::Shape& shape_proto); +proto::Shape ToShapeProto(const Shape& shape); + +absl::StatusOr> FromShardingProto( + DeviceList::LookupDeviceFunc lookup_device, + const proto::Sharding& sharding_proto); +absl::StatusOr ToShardingProto(const Sharding& sharding); + +absl::StatusOr FromArrayCopySemanticsProto( + proto::ArrayCopySemantics s); +proto::ArrayCopySemantics ToArrayCopySemanticsProto(ArrayCopySemantics s); + +absl::StatusOr FromVariantProto( + const proto::Variant& variant_proto); +absl::StatusOr ToVariantProto(const xla::PjRtValueType& value); + +std::vector FromByteStridesProto(const proto::ByteStrides& strides); +proto::ByteStrides ToByteStridesProto(absl::Span strides); + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_COMMON_TYPES_H_ diff --git a/third_party/xla/xla/python/ifrt_proxy/common/types.proto b/third_party/xla/xla/python/ifrt_proxy/common/types.proto new file mode 100644 index 00000000000000..7ef48aed10c54b --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/common/types.proto @@ -0,0 +1,104 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package xla.ifrt.proto; + +import "xla/python/ifrt/serdes.proto"; + +// Sharding of an Array or Executable parameter/output. +// TODO(b/266635130): Remove `Sharding` and use `xla.ifrt.Serialized` directly +// if caching of sharding is unnecessary. +message Sharding { + xla.ifrt.Serialized serialized_sharding = 1; +} + +// Shape of an Array. Currently we only support static shapes with all +// dimension sizes greater than or equal to 0. +message Shape { + repeated int64 dimensions = 1; +} + +// Data types currently supported. Mirrors `xla::ifrt::DType`. +enum DType { + DTYPE_UNSPECIFIED = 0; + + // Predicates are two-state booleans. + DTYPE_PRED = 1; + + // Signed integral values of fixed width. + DTYPE_S4 = 21; + DTYPE_S8 = 2; + DTYPE_S16 = 3; + DTYPE_S32 = 4; + DTYPE_S64 = 5; + + // Unsigned integral values of fixed width. + DTYPE_U4 = 22; + DTYPE_U8 = 6; + DTYPE_U16 = 7; + DTYPE_U32 = 8; + DTYPE_U64 = 9; + + // Floating-point values of fixed width. + DTYPE_F16 = 10; + DTYPE_F32 = 11; + DTYPE_F64 = 12; + + // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit + // floating-point format, but uses 1 bit for the sign, 8 bits for the + // exponent and 7 bits for the mantissa. + DTYPE_BF16 = 16; + + // Complex values of fixed width. + DTYPE_C64 = 15; // Paired F32 (real, imag), as in std::complex. + DTYPE_C128 = 18; // Paired F64 (real, imag), as in std::complex. + + // A token type threaded between side-effecting operations. Shapes of this + // dtype will have empty dimensions. + DTYPE_TOKEN = 17; + + DTYPE_F8E4M3FN = 20; + DTYPE_F8E4M3B11FNUZ = 23; + DTYPE_F8E4M3FNUZ = 25; + DTYPE_F8E5M2 = 19; + DTYPE_F8E5M2FNUZ = 24; +} + +// Mirrors `xla::PjRtValueType`, which is used in IFRT to model +// polymorphic-typed values, e.g., `xla::ifrt::Executable::CostAnalysisValue`. +message Variant { + message Int64List { + repeated sfixed64 values = 1; + } + + oneof value { + string string_value = 1; + sfixed64 int64_value = 2; + Int64List int64_list = 3; + float float_value = 4; + } +} + +enum ArrayCopySemantics { + ARRAY_COPY_SEMANTICS_UNSPECIFIED = 0; + ARRAY_COPY_SEMANTICS_ALWAYS_COPY = 1; + ARRAY_COPY_SEMANTICS_REUSE_INPUT = 2; + ARRAY_COPY_SEMANTICS_DONATE_INPUT = 3; +} + +message ByteStrides { + repeated int64 strides = 1; +} diff --git a/third_party/xla/xla/python/ifrt_proxy/common/types_test.cc b/third_party/xla/xla/python/ifrt_proxy/common/types_test.cc new file mode 100644 index 00000000000000..5d0df922f1a5a6 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/common/types_test.cc @@ -0,0 +1,167 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/common/types.h" + +#include +#include +#include + +#include +#include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt_proxy/common/types.pb.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +using ::tsl::testing::IsOkAndHolds; + +#if defined(PLATFORM_GOOGLE) +using ::testing::EquivToProto; +#endif + +TEST(DTypeTest, ToFromProto) { + for (int i = 0; i < proto::DType_descriptor()->value_count(); ++i) { + const proto::DType dtype = static_cast( + proto::DType_descriptor()->value(i)->number()); + EXPECT_EQ(ToDTypeProto(FromDTypeProto(dtype)), dtype); + } +} + +// TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS +#if defined(PLATFORM_GOOGLE) +class ShapeTest + : public testing::TestWithParam> {}; + +TEST_P(ShapeTest, FromShapeProto) { + const auto& [shape, shape_proto] = GetParam(); + EXPECT_EQ(FromShapeProto(shape_proto), shape); +} + +TEST_P(ShapeTest, ToShapeProto) { + const auto& [shape, shape_proto] = GetParam(); + EXPECT_THAT(ToShapeProto(shape), EquivToProto(shape_proto)); +} + +proto::Shape MakeProtoShape(absl::Span dims) { + auto shape = proto::Shape(); + for (auto dim : dims) { + shape.add_dimensions(dim); + } + return shape; +} + +INSTANTIATE_TEST_SUITE_P(Shape, ShapeTest, + testing::ValuesIn({ + std::make_pair(Shape({}), MakeProtoShape({})), + std::make_pair(Shape({1, 2}), + MakeProtoShape({1, 2})), + })); +#endif + +// TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS +#if defined(PLATFORM_GOOGLE) +class VariantTest : public testing::TestWithParam< + std::pair> {}; + +TEST_P(VariantTest, FromVariantProto) { + const auto& [variant, variant_proto] = GetParam(); + EXPECT_THAT(FromVariantProto(variant_proto), IsOkAndHolds(variant)); +} + +TEST_P(VariantTest, ToVariantProto) { + const auto& [variant, variant_proto] = GetParam(); + EXPECT_THAT(ToVariantProto(variant), + IsOkAndHolds(EquivToProto(variant_proto))); +} + +proto::Variant MakeProtoVariantString(absl::string_view arg) { + auto variant = proto::Variant(); + variant.set_string_value(arg); + return variant; +} + +proto::Variant MakeProtoVariantInt64(int64_t arg) { + auto variant = proto::Variant(); + variant.set_int64_value(arg); + return variant; +} + +proto::Variant MakeProtoVariantInt64List(absl::Span arg) { + auto variant = proto::Variant(); + for (auto arg : arg) { + variant.mutable_int64_list()->add_values(arg); + } + return variant; +} + +proto::Variant MakeProtoVariantFloat(float arg) { + auto variant = proto::Variant(); + variant.set_float_value(arg); + return variant; +} + +INSTANTIATE_TEST_SUITE_P( + Variant, VariantTest, + testing::ValuesIn({ + std::make_pair(xla::PjRtValueType("foo"), + MakeProtoVariantString("foo")), + std::make_pair(xla::PjRtValueType(static_cast(1234)), + MakeProtoVariantInt64(1234)), + std::make_pair(xla::PjRtValueType(std::vector{1, 2}), + MakeProtoVariantInt64List({1, 2})), + std::make_pair(xla::PjRtValueType(3.14f), MakeProtoVariantFloat(3.14f)), + })); +#endif + +class ByteStridesTest : public testing::TestWithParam> {}; + +TEST_P(ByteStridesTest, ToFromProto) { + std::vector strides = GetParam(); + EXPECT_EQ(FromByteStridesProto(ToByteStridesProto(strides)), strides); +} + +INSTANTIATE_TEST_SUITE_P( + ByteStrides, ByteStridesTest, + testing::ValuesIn(std::vector>{ + {}, {1}, {0}, {4, 8}, {8, 4}, {1, 2, 3, 4}, {0, 4}, {4, 0}})); + +TEST(ArrayCopySemantics, ToFromProtoTest) { + // NOLINTNEXTLINE readability-proto-enum-for-loop + for (int proto_enum_int = proto::ArrayCopySemantics_MIN; + proto_enum_int <= proto::ArrayCopySemantics_MAX; ++proto_enum_int) { + const auto proto_enum = + static_cast(proto_enum_int); + if (proto_enum == proto::ARRAY_COPY_SEMANTICS_UNSPECIFIED) { + continue; + } + TF_ASSERT_OK_AND_ASSIGN(const auto cpp_enum, + FromArrayCopySemanticsProto(proto_enum)); + EXPECT_EQ(proto_enum, ToArrayCopySemanticsProto(cpp_enum)); + } +} + +} // namespace +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/integration_tests/BUILD b/third_party/xla/xla/python/ifrt_proxy/integration_tests/BUILD new file mode 100644 index 00000000000000..23a56803a43fd7 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/integration_tests/BUILD @@ -0,0 +1,107 @@ +# Copyright 2023 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +load("//xla/python/ifrt_proxy/common:ifrt_proxy.bzl", "ifrt_proxy_cc_test") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:private"], +) + +cc_library( + name = "register_pjrt_cpu_for_ifrt_api_tests", + testonly = True, + srcs = ["register_pjrt_cpu_for_ifrt_api_tests.cc"], + deps = [ + "//xla/pjrt:pjrt_client", + "//xla/pjrt/cpu:cpu_client", + "//xla/python/ifrt", + "//xla/python/ifrt:test_util", + "//xla/python/ifrt_proxy/client:grpc_client", + "//xla/python/ifrt_proxy/client:registry", + "//xla/python/ifrt_proxy/server:grpc_server", + "//xla/python/pjrt_ifrt", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], + alwayslink = True, +) + +ifrt_proxy_cc_test( + name = "client_impl_test_tfrt_cpu", + deps = [ + ":register_pjrt_cpu_for_ifrt_api_tests", + "//xla/python/ifrt:client_impl_test_lib", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + ], +) + +ifrt_proxy_cc_test( + name = "array_impl_test_tfrt_cpu", + srcs = ["array_impl_test_tfrt_cpu.cc"], + deps = [ + ":register_pjrt_cpu_for_ifrt_api_tests", + "//xla/python/ifrt:array_impl_test_lib", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + ], +) + +ifrt_proxy_cc_test( + name = "executable_impl_test_tfrt_cpu", + timeout = "moderate", + srcs = ["executable_impl_test_tfrt_cpu.cc"], + deps = [ + ":register_pjrt_cpu_for_ifrt_api_tests", # buildcleaner: keep + "//xla/python/ifrt:test_util", + "//xla/python/ifrt/ir/tests:executable_impl_test_lib", + "//xla/python/pjrt_ifrt:xla_executable_impl_test_lib", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + ], +) + +ifrt_proxy_cc_test( + name = "mock_array_test", + size = "small", + srcs = ["mock_array_test.cc"], + deps = [ + "//xla:status", + "//xla/pjrt/cpu:cpu_client", + "//xla/python/ifrt", + "//xla/python/ifrt:mock", + "//xla/python/ifrt_proxy/client", + "//xla/python/ifrt_proxy/client:grpc_client", + "//xla/python/ifrt_proxy/client:registry", + "//xla/python/ifrt_proxy/server:grpc_server", + "//xla/python/pjrt_ifrt", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/concurrency:ref_count", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) diff --git a/third_party/xla/xla/python/ifrt_proxy/integration_tests/array_impl_test_tfrt_cpu.cc b/third_party/xla/xla/python/ifrt_proxy/integration_tests/array_impl_test_tfrt_cpu.cc new file mode 100644 index 00000000000000..4ce343a1cb72df --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/integration_tests/array_impl_test_tfrt_cpu.cc @@ -0,0 +1,42 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" + +int main(int argc, char** argv) { + const std::string disabled[] = { + // TfrtCpuBuffer::ToLiteral() currently does not respect the layout of the + // destination literal. + "ArrayImplTest.MakeArrayFromHostBufferAndCopyToHostBufferWithByteStrides", + + // `ShardingParamSharding` does not support serialization yet. + // TODO(b/282757875): Enable the test once IFRT implements + // `ShardingParamShardingSerDes`. + "ArrayImplTest.AssembleAndDisassembleArray", + }; + + const std::string filter = absl::StrCat("-", absl::StrJoin(disabled, ":")); +#ifdef GTEST_FLAG_SET + GTEST_FLAG_SET(filter, filter.c_str()); +#else + testing::GTEST_FLAG(filter) = filter.c_str(); +#endif + + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/third_party/xla/xla/python/ifrt_proxy/integration_tests/executable_impl_test_tfrt_cpu.cc b/third_party/xla/xla/python/ifrt_proxy/integration_tests/executable_impl_test_tfrt_cpu.cc new file mode 100644 index 00000000000000..2f9b1e47319166 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/integration_tests/executable_impl_test_tfrt_cpu.cc @@ -0,0 +1,48 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "xla/python/ifrt/test_util.h" + +int main(int argc, char** argv) { + const std::string disabled[] = { + // Executable::IsDeleted always returns false with TFRT CPU backend. + "LoadedExecutableImplTest.IsDeleted", + + // Enable this when Serialization support for IFRT IR is available. + "IfrtIrExecutableImplTest.CallXla", + "IfrtIrExecutableImplTest.Reshard", + "IfrtIrExecutableImplTest.ZeroInput", + "IfrtIrExecutableImplTest.ZeroOutput", + "IfrtIrExecutableImplTest.BufferDonation", + "IfrtIrExecutableImplTest.LoadedExecBinding", + "ProgramLoadedExecutableImplTest.MultipleAtomProgramsNeedDummyInputs", + }; + + const std::string filter = absl::StrCat("-", absl::StrJoin(disabled, ":")); + +#ifdef GTEST_FLAG_SET + GTEST_FLAG_SET(filter, filter.c_str()); +#else + testing::GTEST_FLAG(filter) = filter.c_str(); +#endif + + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/third_party/xla/xla/python/ifrt_proxy/integration_tests/mock_array_test.cc b/third_party/xla/xla/python/ifrt_proxy/integration_tests/mock_array_test.cc new file mode 100644 index 00000000000000..ba2dc5315f319d --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/integration_tests/mock_array_test.cc @@ -0,0 +1,274 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include "absl/base/thread_annotations.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "xla/pjrt/cpu/cpu_client.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/mock.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt_proxy/client/client.h" +#include "xla/python/ifrt_proxy/client/registry.h" +#include "xla/python/ifrt_proxy/server/grpc_server.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/status.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/env.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" +#include "tsl/platform/threadpool.h" + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +using ::tsl::testing::IsOk; +using ::tsl::testing::StatusIs; + +constexpr absl::StatusCode kInternal = absl::StatusCode::kInternal; + +constexpr absl::Duration kSomeTime = absl::Seconds(1); + +class MockArrayTest : public testing::Test { + public: + void SetUp() override { + std::string address = + absl::StrCat("localhost:", tsl::testing::PickUnusedPortOrDie()); + TF_ASSERT_OK_AND_ASSIGN( + server_, GrpcServer::CreateFromIfrtClientFactory( + address, [this] { return CreateMockBackend(); })); + TF_ASSERT_OK_AND_ASSIGN(client_, + CreateClient(absl::StrCat("grpc://", address))); + } + + struct ArrayPair { + // IFRT array exposed to the proxy's user. Not a mock. + tsl::RCReference proxy_client_array; + // IFRT array owned by the proxy server whose behavior should be + // reflected by proxy_client_array. Mock but delegated. + tsl::RCReference backend_array; + }; + + absl::StatusOr NewArray() { + DType dtype(DType::kF32); + Shape shape({2, 3}); + auto data = std::make_unique>(6); + std::iota(data->begin(), data->end(), 0); + xla::ifrt::Device* device = client_->addressable_devices().at(0); + std::shared_ptr sharding = + SingleDeviceSharding::Create(device, MemoryKind()); + + TF_ASSIGN_OR_RETURN( + auto client_arr, + client_->MakeArrayFromHostBuffer( + data->data(), dtype, shape, + /*byte_strides=*/std::nullopt, sharding, + Client::HostBufferSemantics::kImmutableOnlyDuringCall, + /*on_done_with_host_buffer=*/nullptr)); + + // When the above `MakeArrayFromHostBuffer` results in the server issuing a + // `MakeArrayFromHostBuffer()` to the underlying mock backend, the mock + // backend enqueues the returned mock array onto `mock_arrays_` (this code + // is in `CreateMockBackend()`). + absl::MutexLock l(&mu_); + CHECK_EQ(mock_arrays_.size(), 1); + auto mock = mock_arrays_.back(); + mock_arrays_.pop_back(); + return ArrayPair{client_arr, mock}; + } + + std::unique_ptr server_; + std::unique_ptr client_; + + private: + absl::StatusOr> CreateMockBackend() { + // TODO(b/292339723): Use reference backend as the delegate while mocking. + CpuClientOptions options; + options.asynchronous = true; + options.cpu_device_count = 2; + TF_ASSIGN_OR_RETURN(auto tfrt_cpu_client, xla::GetTfrtCpuClient(options)); + auto mock_backend = std::make_unique( + /*delegate=*/xla::ifrt::PjRtClient::Create(std::move(tfrt_cpu_client))); + + ON_CALL(*mock_backend, MakeArrayFromHostBuffer) + .WillByDefault( + [this, mock_backend = mock_backend.get()]( + const void* data, DType dtype, Shape shape, + std::optional> byte_strides, + std::shared_ptr sharding, + Client::HostBufferSemantics semantics, + std::function on_done_with_host_buffer) + -> absl::StatusOr> { + TF_ASSIGN_OR_RETURN( + auto delegated, + mock_backend->delegated()->MakeArrayFromHostBuffer( + data, dtype, shape, byte_strides, sharding, semantics, + on_done_with_host_buffer)); + auto result = tsl::MakeRef(delegated); + + absl::MutexLock l(&mu_); + mock_arrays_.push_back(result); + return result; + }); + + return mock_backend; + } + + absl::Mutex mu_; + std::vector> mock_arrays_ ABSL_GUARDED_BY(mu_); +}; + +TEST_F(MockArrayTest, ReadyFutureWaitsUntilReady) { + TF_ASSERT_OK_AND_ASSIGN(ArrayPair arr, NewArray()); + + absl::Notification wait_ready; + + EXPECT_CALL(*arr.backend_array, GetReadyFuture).WillOnce([&] { + wait_ready.WaitForNotification(); + return arr.backend_array->delegated()->GetReadyFuture(); + }); + + auto ready = arr.proxy_client_array->GetReadyFuture(); + + absl::SleepFor(kSomeTime); + EXPECT_FALSE(ready.IsReady()); + + wait_ready.Notify(); + EXPECT_THAT(ready.Await(), IsOk()); +} + +TEST_F(MockArrayTest, ReadyFuturePropagatesError) { + TF_ASSERT_OK_AND_ASSIGN(ArrayPair arr, NewArray()); + + EXPECT_CALL(*arr.backend_array, GetReadyFuture).WillOnce([&] { + return Future(absl::InternalError("testing")); + }); + + EXPECT_THAT(arr.proxy_client_array->GetReadyFuture().Await(), + StatusIs(kInternal)); +} + +TEST_F(MockArrayTest, DeletionFutureWaitsUntilDeleted) { + TF_ASSERT_OK_AND_ASSIGN(ArrayPair arr, NewArray()); + + tsl::thread::ThreadPool threads(tsl::Env::Default(), "t", /*num_threads=*/1); + absl::Notification wait_ready; + + EXPECT_CALL(*arr.backend_array, Delete).WillOnce([&] { + // TODO(b/266635130): Write a version of this testcase where the Delete() + // call of the MockArray blocks on `wait_ready`, instead of the Future it + // returns being blocked on `wait_ready`. That version of the testcase does + // not currently work since both the client and the server synchronously + // block until the MockArray's Delete() returns. + auto promise = Future::CreatePromise(); + threads.Schedule([&, promise]() mutable { + wait_ready.WaitForNotification(); + promise.Set(arr.backend_array->delegated()->Delete().Await()); + }); + return Future(promise); + }); + + EXPECT_FALSE(arr.proxy_client_array->IsDeleted()); + auto deleted_future = arr.proxy_client_array->Delete(); + + absl::SleepFor(kSomeTime); + EXPECT_FALSE(deleted_future.IsReady()); + EXPECT_FALSE(arr.proxy_client_array->IsDeleted()); + + wait_ready.Notify(); + EXPECT_THAT(deleted_future.Await(), IsOk()); + EXPECT_TRUE(arr.proxy_client_array->IsDeleted()); +} + +TEST_F(MockArrayTest, DeletionPropagatesError) { + TF_ASSERT_OK_AND_ASSIGN(ArrayPair arr, NewArray()); + + EXPECT_CALL(*arr.backend_array, Delete).WillOnce([&] { + return Future(absl::InternalError("testing")); + }); + + EXPECT_FALSE(arr.proxy_client_array->IsDeleted()); + EXPECT_THAT(arr.proxy_client_array->Delete().Await(), StatusIs(kInternal)); +} + +TEST_F(MockArrayTest, CopyToHostFutureWaitsUntilCopied) { + TF_ASSERT_OK_AND_ASSIGN(ArrayPair arr, NewArray()); + + absl::Notification wait_ready; + + EXPECT_CALL(*arr.backend_array, CopyToHostBuffer) + .WillOnce([&](auto data, auto byte_strides, auto semantics) { + wait_ready.WaitForNotification(); + return arr.backend_array->delegated()->CopyToHostBuffer( + data, byte_strides, semantics); + }); + + char data[1000]; + auto copied = arr.proxy_client_array->CopyToHostBuffer( + &data[0], /*byte_strides=*/std::nullopt, ArrayCopySemantics::kAlwaysCopy); + + absl::SleepFor(kSomeTime); + EXPECT_FALSE(copied.IsReady()); + + wait_ready.Notify(); + EXPECT_THAT(copied.Await(), IsOk()); +} + +TEST_F(MockArrayTest, CopyToHostFuturePropagatesError) { + TF_ASSERT_OK_AND_ASSIGN(ArrayPair arr, NewArray()); + + absl::Notification wait_ready; + + EXPECT_CALL(*arr.backend_array, CopyToHostBuffer).WillOnce([&] { + return Future(absl::InternalError("testing")); + }); + + char data[1000]; + auto copied = arr.proxy_client_array->CopyToHostBuffer( + &data[0], /*byte_strides=*/std::nullopt, ArrayCopySemantics::kAlwaysCopy); + + EXPECT_THAT(copied.Await(), StatusIs(kInternal)); +} + +} // namespace +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/integration_tests/register_pjrt_cpu_for_ifrt_api_tests.cc b/third_party/xla/xla/python/ifrt_proxy/integration_tests/register_pjrt_cpu_for_ifrt_api_tests.cc new file mode 100644 index 00000000000000..6b344e011ede78 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/integration_tests/register_pjrt_cpu_for_ifrt_api_tests.cc @@ -0,0 +1,80 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file registers a factory with `xla::ifrt::test_util` that will spawn a +// IFRT proxy client connected to an instance of a proxy server that is backed +// by the IFRT-PjRt-CPU backend. +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "xla/pjrt/cpu/cpu_client.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/test_util.h" +#include "xla/python/ifrt/tuple.h" +#include "xla/python/ifrt/value.h" +#include "xla/python/ifrt_proxy/client/registry.h" +#include "xla/python/ifrt_proxy/server/grpc_server.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace ifrt { +namespace proxy { +namespace test_util { +namespace { + +absl::StatusOr> CreateIfrtBackendClient() { + TF_ASSIGN_OR_RETURN(std::unique_ptr tfrt_cpu_client, + xla::GetTfrtCpuClient(/*asynchronous=*/true, + /*cpu_device_count=*/2)); + return xla::ifrt::PjRtClient::Create(std::move(tfrt_cpu_client)); +} + +const bool kUnused = + (xla::ifrt::test_util::RegisterClientFactory( + []() -> absl::StatusOr> { + std::string address = + absl::StrCat("localhost:", tsl::testing::PickUnusedPortOrDie()); + TF_ASSIGN_OR_RETURN(auto server, + GrpcServer::CreateFromIfrtClientFactory( + address, CreateIfrtBackendClient)); + + TF_ASSIGN_OR_RETURN(std::unique_ptr client, + CreateClient(absl::StrCat("grpc://", address))); + + return std::shared_ptr( + client.release(), /*deleter=*/ + [server = server.release()](xla::ifrt::Client* client) { + // Client has to be destructed before the server since the + // server's destructor (as of Jul 2023) waits for the client to + // end its session. + // TODO(b/282757875): Make the server cancel the client's + // session if the server is getting destructed. + delete client; + delete server; + }); + }), + true); + +} // namespace +} // namespace test_util +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/jax/BUILD b/third_party/xla/xla/python/ifrt_proxy/jax/BUILD new file mode 100644 index 00000000000000..b86f65e9c3596a --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/jax/BUILD @@ -0,0 +1,52 @@ +# Copyright 2023 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Jax library for IFRT proxy. +load("//xla:pytype.default.bzl", "pytype_strict_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:public"], +) + +pytype_strict_library( + name = "ifrt_proxy_internal", + srcs = ["ifrt_proxy_internal.py"], + # copybara:uncomment_begin + # visibility = [ + # "//xla/python/ifrt_proxy/common/google:friends", + # "//xla/python/ifrt_proxy/common/google:jax_users", + # ], + # copybara:uncomment_end + deps = [ + "//xla/python:xla_client", + "//xla/python/ifrt_proxy/client:py_module", + "@pybind11_abseil//pybind11_abseil:status", + ], +) + +# copybara:uncomment_begin(ifrt_proxy.py is not exported to github) +# pytype_strict_library( +# name = "ifrt_proxy", +# srcs = ["ifrt_proxy.py"], +# visibility = [ +# "//xla/python/ifrt_proxy/common/google:friends", +# "//xla/python/ifrt_proxy/common/google:jax_users", +# ], +# deps = [ +# ":ifrt_proxy_internal", +# "//third_party/py/jax", +# ], +# ) +# copybara:uncomment_end diff --git a/third_party/xla/xla/python/ifrt_proxy/jax/ifrt_proxy_internal.py b/third_party/xla/xla/python/ifrt_proxy/jax/ifrt_proxy_internal.py new file mode 100644 index 00000000000000..746575cdd61135 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/jax/ifrt_proxy_internal.py @@ -0,0 +1,74 @@ +# Copyright 2023 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Library to help create a IFRT proxy client.""" + +import dataclasses +from typing import Callable, Optional + +from pybind11_abseil import status +from xla.python import xla_client +from xla.python.ifrt_proxy.client import py_module + + +@dataclasses.dataclass +class ConnectionOptions: + """Various connection options. + + Attributes: + on_disconnect: Optional, a callback that will be called if there was a + successful connection to the proxy server and Jax commands could be + issued, but there was a later disconnect before the Client is destroyed. + on_connection_update: Optional, a callback that will be called with status + updates about initial connection establishment. The updates will be + provided as human-readable strings, and an end-user may find them helpful. + """ + + on_disconnect: Optional[Callable[[status.Status], None]] = None + on_connection_update: Optional[Callable[[str], None]] = None + + +_backend_created: bool = False +_connection_options: ConnectionOptions = ConnectionOptions() + + +def get_client(proxy_server_address: str) -> xla_client.Client: + """Creates an IFRT Proxy client for the given server address.""" + global _backend_created + _backend_created = True + cpp_options = py_module.ClientConnectionOptions() + cpp_options.on_disconnect = _connection_options.on_disconnect + cpp_options.on_connection_update = _connection_options.on_connection_update + client = py_module.get_client(proxy_server_address, cpp_options) + return client + + +def set_connection_options( + options: ConnectionOptions, +) -> None: + """Sets the connection options for the "proxy" jax_platforms. + + Args: + options: See documentation for ConnectionOptions class. + + Raises: + ValueError: If this function is called after the proxy backend has already + been created. + """ + global _connection_options + if _backend_created: + raise ValueError( + "set_connection_options() called after proxy backend was created." + ) + _connection_options = options diff --git a/third_party/xla/xla/python/ifrt_proxy/server/BUILD b/third_party/xla/xla/python/ifrt_proxy/server/BUILD new file mode 100644 index 00000000000000..70dfde6ecc0210 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/server/BUILD @@ -0,0 +1,320 @@ +# Copyright 2023 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +load("//xla/python/ifrt_proxy/common:ifrt_proxy.bzl", "default_ifrt_proxy_visibility", "ifrt_proxy_cc_test") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = default_ifrt_proxy_visibility, +) + +cc_library( + name = "grpc_server", + srcs = ["grpc_server.cc"], + hdrs = ["grpc_server.h"], + deps = [ + ":grpc_service_impl", + ":host_buffer", + ":ifrt_backend", + "//xla/python/ifrt", + "//xla/python/ifrt_proxy/common:grpc_credentials", + "//xla/python/ifrt_proxy/common:grpc_ifrt_service_cc_grpc_proto", + "@com_github_grpc_grpc//:grpc", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:statusor", + ], +) + +ifrt_proxy_cc_test( + name = "grpc_server_test", + srcs = ["grpc_server_test.cc"], + tags = ["no_aarch64"], # TODO(b/326080238): Fix this. + deps = [ + ":grpc_server", + "//xla/python/ifrt_proxy/common:grpc_ifrt_service_cc_grpc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "grpc_service_impl", + srcs = ["grpc_service_impl.cc"], + hdrs = ["grpc_service_impl.h"], + deps = [ + ":host_buffer", + ":ifrt_backend", + ":ifrt_session_handler", + ":version", + "//xla/pjrt/distributed:util", + "//xla/python/ifrt_proxy/common:grpc_ifrt_service_cc_grpc_proto", + "//xla/python/ifrt_proxy/common:grpc_ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:proto_util", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/log:die_if_null", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", + ], +) + +ifrt_proxy_cc_test( + name = "grpc_service_impl_test", + size = "small", + srcs = ["grpc_service_impl_test.cc"], + tags = ["no_aarch64"], # TODO(b/326080238): Fix this. + deps = [ + ":grpc_server", + ":grpc_service_impl", + ":host_buffer", + ":version", + "//xla/python/ifrt_proxy/client:grpc_host_buffer", + "//xla/python/ifrt_proxy/common:grpc_ifrt_service_cc_grpc_proto", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "ifrt_session_handler", + srcs = ["ifrt_session_handler.cc"], + hdrs = ["ifrt_session_handler.h"], + deps = [ + ":ifrt_backend", + "//xla/python/ifrt", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:proto_util", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + "@local_tsl//tsl/platform:statusor", + ], +) + +ifrt_proxy_cc_test( + name = "ifrt_session_handler_test", + srcs = ["ifrt_session_handler_test.cc"], + deps = [ + ":ifrt_backend", + ":ifrt_session_handler", + "//xla/python/ifrt", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:status_matchers", + ], +) + +cc_library( + name = "ifrt_backend", + srcs = ["ifrt_backend.cc"], + hdrs = ["ifrt_backend.h"], + deps = [ + ":host_buffer", + ":host_callback", + ":version", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/pjrt:pjrt_client", + "//xla/python/ifrt", + "//xla/python/ifrt:serdes", + "//xla/python/ifrt_proxy/common:array_util", + "//xla/python/ifrt_proxy/common:common_serdes", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:proto_util", + "//xla/python/ifrt_proxy/common:types", + "//xla/python/ifrt_proxy/common:types_proto_cc", + "//xla/python/pjrt_ifrt:xla_ifrt", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:bind_front", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@local_tsl//tsl/concurrency:ref_count", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status_to_from_proto", + "@local_tsl//tsl/platform:statusor", + ], +) + +ifrt_proxy_cc_test( + name = "ifrt_backend_test", + srcs = ["ifrt_backend_test.cc"], + deps = [ + ":host_buffer", + ":host_callback", + ":ifrt_backend", + ":version", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla:test", + "//xla:xla_data_proto_cc", + "//xla/pjrt:host_callback", + "//xla/pjrt:pjrt_common", + "//xla/pjrt:pjrt_device_description", + "//xla/python/ifrt", + "//xla/python/ifrt:mock", + "//xla/python/ifrt:serdes", + "//xla/python/ifrt:sharding_serdes", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:types", + "//xla/python/ifrt_proxy/common:types_proto_cc", + "//xla/python/pjrt_ifrt:xla_ifrt", + "//xla/service:computation_placer_hdr", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@local_tsl//tsl/concurrency:ref_count", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:status_to_from_proto", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", + "@local_tsl//tsl/protobuf:status_proto_cc", + ], +) + +cc_library( + name = "mock_ifrt_backend", + testonly = True, + hdrs = ["mock_ifrt_backend.h"], + deps = [ + ":ifrt_backend", + "//xla/python/ifrt", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "host_buffer", + srcs = ["host_buffer.cc"], + hdrs = ["host_buffer.h"], + deps = [ + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "host_callback", + srcs = ["host_callback.cc"], + hdrs = ["host_callback.h"], + deps = [ + "//xla:shape_util", + "//xla/pjrt:host_callback", + "//xla/python/ifrt", + "//xla/python/ifrt_proxy/common:proto_util", + "//xla/python/pjrt_ifrt", + "//xla/python/pjrt_ifrt:xla_host_callback_proto_cc", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/functional:bind_front", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@local_tsl//tsl/concurrency:ref_count", + "@local_tsl//tsl/platform:errors", + ], +) + +cc_library( + name = "version", + srcs = ["version.cc"], + hdrs = ["version.h"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +ifrt_proxy_cc_test( + name = "version_test", + srcs = ["version_test.cc"], + deps = [ + ":version", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:status_matchers", + ], +) + +ifrt_proxy_cc_test( + name = "host_buffer_test", + srcs = ["host_buffer_test.cc"], + deps = [ + ":host_buffer", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:status_matchers", + ], +) diff --git a/third_party/xla/xla/python/ifrt_proxy/server/grpc_server.cc b/third_party/xla/xla/python/ifrt_proxy/server/grpc_server.cc new file mode 100644 index 00000000000000..fbd3b6952eb103 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/server/grpc_server.cc @@ -0,0 +1,99 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/server/grpc_server.h" + +#include +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "grpc/grpc.h" +#include "grpcpp/completion_queue.h" +#include "grpcpp/grpcpp.h" +#include "grpcpp/server_builder.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt_proxy/common/grpc_credentials.h" +#include "xla/python/ifrt_proxy/common/grpc_ifrt_service.grpc.pb.h" +#include "xla/python/ifrt_proxy/server/grpc_service_impl.h" +#include "xla/python/ifrt_proxy/server/host_buffer.h" +#include "xla/python/ifrt_proxy/server/ifrt_backend.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +GrpcServer::~GrpcServer() { + server_->Shutdown(); + server_->Wait(); +} + +absl::StatusOr> GrpcServer::Create( + absl::string_view address, + std::unique_ptr impl) { + if (impl == nullptr) { + return absl::InvalidArgumentError( + "Service implementation cannot be a nullptr."); + } + + ::grpc::ServerBuilder builder; + // Remove message size limit to accommodate large messages exchanged during + // model compilation. + builder.AddChannelArgument(GRPC_ARG_MAX_SEND_MESSAGE_LENGTH, -1); + builder.AddChannelArgument(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, -1); + builder.RegisterService(impl.get()); + builder.AddListeningPort(std::string(address), GetServerCredentials()); + auto server = builder.BuildAndStart(); + if (server == nullptr) { + return absl::UnavailableError( + absl::StrCat("Failed to initialize gRPC server at address:", address)); + } + + return absl::WrapUnique( + new GrpcServer(address, std::move(impl), std::move(server))); +} + +absl::StatusOr> +GrpcServer::CreateFromIfrtClientFactory( + absl::string_view address, + absl::AnyInvocable>()> + backend_ifrt_client_factory) { + if (backend_ifrt_client_factory == nullptr) { + return absl::InvalidArgumentError( + "backend_ifrt_client_factory cannot be nullptr."); + } + + auto service = std::make_unique( + [ifrt_client_factory = std::move(backend_ifrt_client_factory)]( + IfrtProxyVersion version, uint64_t session_id, + std::shared_ptr host_buffer_store) mutable + -> absl::StatusOr> { + TF_ASSIGN_OR_RETURN(auto ifrt_client, ifrt_client_factory()); + return IfrtBackend::Create(version, session_id, std::move(ifrt_client), + std::move(host_buffer_store)); + }); + + return Create(address, std::move(service)); +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/server/grpc_server.h b/third_party/xla/xla/python/ifrt_proxy/server/grpc_server.h new file mode 100644 index 00000000000000..d9bd31dcee376f --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/server/grpc_server.h @@ -0,0 +1,79 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_SERVER_GRPC_SERVER_H_ +#define XLA_PYTHON_IFRT_PROXY_SERVER_GRPC_SERVER_H_ + +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "grpcpp/server.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt_proxy/common/grpc_ifrt_service.grpc.pb.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// Makes and runs a gRPC server with the given implementation and address. +// Destroying this object shuts down the underlying gRPC server, and so can +// block. +class GrpcServer { + public: + // The address parameter must be in the standard URI format - as needed by the + // ::grpc::ServerBuilder::AddListentingPort. See the ::grpc::ServerBuilder + // documentation for more details. + static absl::StatusOr> Create( + absl::string_view address, + std::unique_ptr impl); + + static absl::StatusOr> + CreateFromIfrtClientFactory( + absl::string_view address, + absl::AnyInvocable>()> + backend_ifrt_client_factory); + + // Starts shutting down the server and waits until it properly shuts down. + ~GrpcServer(); + + // Address this server is listening on. + std::string address() const { return address_; } + + // Blocks until the server shuts down. + void Wait() { server_->Wait(); } + + private: + GrpcServer(absl::string_view address, + std::unique_ptr impl, + std::unique_ptr<::grpc::Server> server) + : address_(address), impl_(std::move(impl)), server_(std::move(server)) {} + + const std::string address_; // Address this server is listening on. + + // Make sure that impl_ outlives the server_. + std::unique_ptr impl_; + std::unique_ptr<::grpc::Server> server_; +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_SERVER_GRPC_SERVER_H_ diff --git a/third_party/xla/xla/python/ifrt_proxy/server/grpc_server_test.cc b/third_party/xla/xla/python/ifrt_proxy/server/grpc_server_test.cc new file mode 100644 index 00000000000000..40216bca7876c8 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/server/grpc_server_test.cc @@ -0,0 +1,72 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/server/grpc_server.h" + +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "xla/python/ifrt_proxy/common/grpc_ifrt_service.grpc.pb.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +using ::testing::Not; +using ::tsl::testing::IsOk; +using ::tsl::testing::StatusIs; + +// A fake IFRT service that fails all the Session creation attempts. +class FakeIfrtService : public grpc::GrpcIfrtService::Service {}; + +TEST(GrpcServerTest, CreationTest) { + auto addr = absl::StrCat("[::1]:", tsl::testing::PickUnusedPortOrDie()); + auto grpc_service_impl = std::make_unique(); + ASSERT_THAT(GrpcServer::Create(addr, std::move(grpc_service_impl)), IsOk()); + // Also implicitly tests that the destruction of the GrpcServer object. +} + +TEST(GrpcServerTest, CreationFailsIfImplIsNullptr) { + auto addr = absl::StrCat("[::1]:", tsl::testing::PickUnusedPortOrDie()); + EXPECT_THAT(GrpcServer::Create(addr, nullptr), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(GrpcServerTest, CreationFailsWithInvalidAddress) { + auto grpc_service_impl = std::make_unique(); + EXPECT_THAT(GrpcServer::Create(/*address=*/"invalid-address", + std::move(grpc_service_impl)), + Not(IsOk())); +} + +TEST(GrpcServerTest, RetrievingServerAddressWorks) { + auto addr = absl::StrCat("[::1]:", tsl::testing::PickUnusedPortOrDie()); + auto grpc_service_impl = std::make_unique(); + TF_ASSERT_OK_AND_ASSIGN( + auto grpc_server, GrpcServer::Create(addr, std::move(grpc_service_impl))); + EXPECT_EQ(grpc_server->address(), addr); +} + +} // namespace +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/server/grpc_service_impl.cc b/third_party/xla/xla/python/ifrt_proxy/server/grpc_service_impl.cc new file mode 100644 index 00000000000000..8f89d253affcbe --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/server/grpc_service_impl.cc @@ -0,0 +1,241 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/server/grpc_service_impl.h" + +#include +#include +#include +#include +#include + +#include "absl/cleanup/cleanup.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "grpcpp/server_context.h" +#include "grpcpp/support/status.h" +#include "grpcpp/support/sync_stream.h" +#include "xla/pjrt/distributed/util.h" +#include "xla/python/ifrt_proxy/common/grpc_ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/proto_util.h" +#include "xla/python/ifrt_proxy/server/host_buffer.h" +#include "xla/python/ifrt_proxy/server/ifrt_session_handler.h" +#include "xla/python/ifrt_proxy/server/version.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +::grpc::Status GrpcServiceImpl::GetVersion(::grpc::ServerContext* context, + const GrpcGetVersionRequest* request, + GrpcGetVersionResponse* response) { + auto protocol_version = + ChooseVersion(request->min_version().protocol_version(), + request->max_version().protocol_version()); + if (!protocol_version.ok()) { + return xla::ToGrpcStatus(protocol_version.status()); + } + response->mutable_version()->set_protocol_version(*protocol_version); + return ::grpc::Status::OK; +} + +::grpc::Status GrpcServiceImpl::IfrtSession( + ::grpc::ServerContext* context, + ::grpc::ServerReaderWriter* stream) { + GrpcIfrtSessionMetadata metadata; + { + const auto it = context->client_metadata().find( + "ifrt-proxy-grpc-ifrt-session-metadata-bin"); + if (it == context->client_metadata().end()) { + return ::grpc::Status(::grpc::StatusCode::INVALID_ARGUMENT, + "Missing metadata for GrpcIfrtService.IfrtSession: " + "ifrt-proxy-grpc-ifrt-session-metadata-bin"); + } + if (!metadata.ParseFromString(AsProtoStringData( + absl::string_view(it->second.data(), it->second.size())))) { + return ::grpc::Status(::grpc::StatusCode::INVALID_ARGUMENT, + "Unable to parse GrpcIfrtSessionMetadata"); + } + } + + const uint64_t session_id = + next_session_id_.fetch_add(1, std::memory_order_relaxed); + + VLOG(0) << "Starting a new IFRT session with session_id=" << session_id; + + // Create a host buffer store for the session. + auto host_buffer_store = + std::make_shared(); + { + absl::MutexLock l(&host_buffer_store_mu_); + CHECK(host_buffer_stores_.insert({session_id, host_buffer_store}).second); + } + absl::Cleanup cleanup = [&] { + absl::MutexLock l(&host_buffer_store_mu_); + CHECK_GT(host_buffer_stores_.erase(session_id), 0); + }; + + absl::Mutex writer_mu; + + auto session_handler = IfrtSessionHandler::Create( + session_id, + [this, version = metadata.version(), + host_buffer_store = std::move(host_buffer_store)](uint64_t session_id) { + return backend_factory_(version, session_id, host_buffer_store); + }); + + if (!session_handler.ok()) { + LOG(INFO) << "Creating session " << session_id + << " failed: " << session_handler.status(); + return xla::ToGrpcStatus(session_handler.status()); + } + + bool first_request_read = false; + while (true) { + auto request = std::make_unique(); + if (!stream->Read(request.get())) { + break; + } + if (!first_request_read) { + VLOG(0) << "First request read for session " << session_id; + first_request_read = true; + } + (*session_handler) + ->NewIncomingRequest(std::move(request), + [&](std::shared_ptr response) { + absl::MutexLock l(&writer_mu); + stream->Write(*response); + }); + } + + VLOG(0) << "Finishing IFRT session " << session_id; + return ::grpc::Status::OK; +} + +::grpc::Status GrpcServiceImpl::HostBufferStore( + ::grpc::ServerContext* context, + ::grpc::ServerReader* stream, + GrpcHostBufferStoreResponse* response) { + const auto it = context->client_metadata().find( + "ifrt-proxy-grpc-host-buffer-store-metadata-bin"); + if (it == context->client_metadata().end()) { + return ::grpc::Status( + ::grpc::StatusCode::INTERNAL, + "Missing gRPC metadata for GrpcHostBufferService.Store"); + } + + GrpcHostBufferStoreMetadata metadata; + if (!metadata.ParseFromString(AsProtoStringData( + absl::string_view(it->second.data(), it->second.size())))) { + return ::grpc::Status(::grpc::StatusCode::DATA_LOSS, + "Unable to parse GrpcHostBufferStoreMetadata"); + } + + std::string data; + data.reserve(metadata.buffer_size()); + + GrpcHostBufferStoreRequest request; + while (stream->Read(&request)) { + data.append(request.data()); + } + if (data.size() != metadata.buffer_size()) { + return ::grpc::Status( + ::grpc::StatusCode::DATA_LOSS, + absl::StrCat("Potential data loss for host buffers: expected ", + metadata.buffer_size(), " bytes but got ", data.size(), + " bytes")); + } + + auto store = GetHostBufferStore(metadata.session_id()); + if (!store.ok()) { + return xla::ToGrpcStatus(store.status()); + } + return xla::ToGrpcStatus((*store)->Store(metadata.handle(), std::move(data))); +} + +::grpc::Status GrpcServiceImpl::HostBufferLookup( + ::grpc::ServerContext* context, const GrpcHostBufferLookupRequest* request, + ::grpc::ServerWriter* stream) { + static constexpr int64_t kChunkSize = 1024 * 1024; + + auto store = GetHostBufferStore(request->session_id()); + if (!store.ok()) { + return xla::ToGrpcStatus(store.status()); + } + auto data = (*store)->Lookup(request->handle()); + if (!data.ok()) { + return xla::ToGrpcStatus(data.status()); + } + + GrpcHostBufferLookupResponse response; + if (!(*data)->empty()) { + for (int64_t offset = 0; offset < (*data)->size(); offset += kChunkSize) { +#if defined(PLATFORM_GOOGLE) + response.set_alias_data( + absl::string_view(**data).substr(offset, kChunkSize)); +#else + // TODO(b/325306748): Find a way to not do a memory-copy. + response.set_data((*data)->substr(offset, kChunkSize)); +#endif + stream->Write(response); + response.Clear(); + } + } else { + // Send at least one response even if the buffer is empty. + stream->Write(response); + } + + return ::grpc::Status::OK; +} + +::grpc::Status GrpcServiceImpl::HostBufferDelete( + ::grpc::ServerContext* context, const GrpcHostBufferDeleteRequest* request, + GrpcHostBufferDeleteResponse* response) { + auto store = GetHostBufferStore(request->session_id()); + if (!store.ok()) { + return xla::ToGrpcStatus(store.status()); + } + return xla::ToGrpcStatus((*store)->Delete(request->handle())); +} + +bool GrpcServiceImpl::Test_InsertHostBufferStore( + uint64_t session_id, + std::shared_ptr store) { + absl::MutexLock l(&host_buffer_store_mu_); + return host_buffer_stores_.insert({session_id, std::move(store)}).second; +} + +bool GrpcServiceImpl::Test_DeleteHostBufferStore(uint64_t session_id) { + absl::MutexLock l(&host_buffer_store_mu_); + return host_buffer_stores_.erase(session_id) > 0; +} + +absl::StatusOr> +GrpcServiceImpl::GetHostBufferStore(uint64_t session_id) { + absl::MutexLock l(&host_buffer_store_mu_); + const auto it = host_buffer_stores_.find(session_id); + if (it == host_buffer_stores_.end()) { + return absl::NotFoundError( + absl::StrCat("Session id ", session_id, " does not exist")); + } + return it->second; +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/server/grpc_service_impl.h b/third_party/xla/xla/python/ifrt_proxy/server/grpc_service_impl.h new file mode 100644 index 00000000000000..c75709b4e6ff99 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/server/grpc_service_impl.h @@ -0,0 +1,107 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_SERVER_GRPC_SERVICE_IMPL_H_ +#define XLA_PYTHON_IFRT_PROXY_SERVER_GRPC_SERVICE_IMPL_H_ + +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/functional/any_invocable.h" +#include "absl/log/die_if_null.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "grpcpp/server_context.h" +#include "grpcpp/support/status.h" +#include "grpcpp/support/sync_stream.h" +#include "xla/python/ifrt_proxy/common/grpc_ifrt_service.grpc.pb.h" +#include "xla/python/ifrt_proxy/common/grpc_ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/server/host_buffer.h" +#include "xla/python/ifrt_proxy/server/ifrt_backend.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// Implementation for `GrpcIfrtService`. +class GrpcServiceImpl : public grpc::GrpcIfrtService::Service { + public: + using BackendFactory = + absl::AnyInvocable>( + IfrtProxyVersion version, uint64_t session_id, + std::shared_ptr + host_buffer_store)>; + + explicit GrpcServiceImpl(BackendFactory backend_factory) + : backend_factory_(ABSL_DIE_IF_NULL(std::move(backend_factory))) {} + + ::grpc::Status GetVersion(::grpc::ServerContext* context, + const GrpcGetVersionRequest* request, + GrpcGetVersionResponse* response) override; + + ::grpc::Status IfrtSession( + ::grpc::ServerContext* context, + ::grpc::ServerReaderWriter* stream) override; + + ::grpc::Status HostBufferStore( + ::grpc::ServerContext* context, + ::grpc::ServerReader* stream, + GrpcHostBufferStoreResponse* response) override; + + ::grpc::Status HostBufferLookup( + ::grpc::ServerContext* context, + const GrpcHostBufferLookupRequest* request, + ::grpc::ServerWriter* stream) override; + + ::grpc::Status HostBufferDelete( + ::grpc::ServerContext* context, + const GrpcHostBufferDeleteRequest* request, + GrpcHostBufferDeleteResponse* response) override; + + // Test-only method that adds a new session in the host buffer store map. + // Returns false if the session id already exists. + bool Test_InsertHostBufferStore( + uint64_t session_id, + std::shared_ptr store); + + // Test-only method that removes the given session id from the host buffer + // store map. Returns false if the session id does not exist. + bool Test_DeleteHostBufferStore(uint64_t session_id); + + private: + absl::StatusOr> + GetHostBufferStore(uint64_t session_id) + ABSL_LOCKS_EXCLUDED(host_buffer_store_mu_); + + BackendFactory backend_factory_; + std::atomic next_session_id_ = 1; + + absl::Mutex host_buffer_store_mu_; + absl::flat_hash_map> + host_buffer_stores_ ABSL_GUARDED_BY(host_buffer_store_mu_); +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_SERVER_GRPC_SERVICE_IMPL_H_ diff --git a/third_party/xla/xla/python/ifrt_proxy/server/grpc_service_impl_test.cc b/third_party/xla/xla/python/ifrt_proxy/server/grpc_service_impl_test.cc new file mode 100644 index 00000000000000..2f8f553794dc47 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/server/grpc_service_impl_test.cc @@ -0,0 +1,184 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/server/grpc_service_impl.h" + +#include +#include +#include + +#include +#include +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "grpcpp/server.h" +#include "grpcpp/server_builder.h" +#include "grpcpp/support/channel_arguments.h" +#include "grpcpp/support/status.h" +#include "xla/python/ifrt_proxy/client/grpc_host_buffer.h" +#include "xla/python/ifrt_proxy/common/grpc_ifrt_service.grpc.pb.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/server/grpc_server.h" +#include "xla/python/ifrt_proxy/server/host_buffer.h" +#include "xla/python/ifrt_proxy/server/version.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +using ::tsl::testing::IsOk; +using ::tsl::testing::IsOkAndHolds; +using ::tsl::testing::StatusIs; + +IfrtProxyVersion Version() { + IfrtProxyVersion version; + version.set_protocol_version(kServerMaxVersion); + return version; +} + +// Sets up fresh GrpcServer for testing. +absl::StatusOr> MakeGrpcServer() { + // TODO(b/282993619): For external/GKE uses, we may need to find (or build) + // a utility function that works similar to PickUnusedPortorDie(). + auto addr = absl::StrCat("[::1]:", tsl::testing::PickUnusedPortOrDie()); + return GrpcServer::CreateFromIfrtClientFactory(addr, []() { + return absl::UnimplementedError( + "IFRT client creation fails. This test is not expected to " + "instantiate any IFRT client"); + }); +} + +TEST(GrpcServiceImplTest, CanBeUsedToSetupAnGrpcServer) { + ASSERT_THAT(MakeGrpcServer(), IsOk()); + // Also implicitly tests that destruction of both the server and the + // implementation objects. +} + +class GrpcIfrtServiceImplHostBufferTest + : public testing::TestWithParam { + protected: + GrpcIfrtServiceImplHostBufferTest() + : impl_([](IfrtProxyVersion version, uint64_t session_id, + std::shared_ptr host_buffer_store) { + return absl::UnimplementedError( + "IFRT backend creation is not implemented"); + }) { + ::grpc::ServerBuilder builder; + builder.RegisterService(&impl_); + server_ = builder.BuildAndStart(); + + stub_ = grpc::GrpcIfrtService::NewStub( + server_->InProcessChannel(::grpc::ChannelArguments())); + } + + // Returns a string to be stored as a host buffer. The length is parameterized + // so that we can test chunking. + std::string GetTestData() const { + std::string data; + for (int i = 0; i < GetParam(); ++i) { + data.push_back(i % 7); + } + return data; + } + + GrpcServiceImpl impl_; + std::unique_ptr<::grpc::Server> server_; + std::shared_ptr stub_; +}; + +TEST_P(GrpcIfrtServiceImplHostBufferTest, StoreAndLookupStringView) { + static constexpr uint64_t kSessionId = 1; + + auto store = std::make_shared(); + ASSERT_TRUE(impl_.Test_InsertHostBufferStore(kSessionId, store)); + GrpcClientHostBufferStore client(stub_, Version(), kSessionId); + + constexpr uint64_t kHandle = 2; + const std::string data = GetTestData(); + absl::string_view source(data); + + ASSERT_THAT(client.Store(kHandle, source).Await(), IsOk()); + EXPECT_THAT(client.Lookup(kHandle).Await(), IsOkAndHolds(data)); + + EXPECT_TRUE(impl_.Test_DeleteHostBufferStore(kSessionId)); +} + +TEST_P(GrpcIfrtServiceImplHostBufferTest, StoreAndLookupCord) { + static constexpr uint64_t kSessionId = 1; + + auto store = std::make_shared(); + ASSERT_TRUE(impl_.Test_InsertHostBufferStore(kSessionId, store)); + GrpcClientHostBufferStore client(stub_, Version(), kSessionId); + + constexpr uint64_t kHandle = 2; + const std::string data = GetTestData(); + + absl::Cord source(data); + ASSERT_THAT(client.Store(kHandle, source).Await(), IsOk()); + EXPECT_THAT(client.Lookup(kHandle).Await(), IsOkAndHolds(data)); + + EXPECT_TRUE(impl_.Test_DeleteHostBufferStore(kSessionId)); +} + +TEST_P(GrpcIfrtServiceImplHostBufferTest, Lookup) { + static constexpr uint64_t kSessionId = 1; + + auto store = std::make_shared(); + ASSERT_TRUE(impl_.Test_InsertHostBufferStore(kSessionId, store)); + GrpcClientHostBufferStore client(stub_, Version(), kSessionId); + + constexpr uint64_t kHandle = 2; + const std::string data = GetTestData(); + ASSERT_THAT(store->Store(kHandle, data), IsOk()); + + EXPECT_THAT(client.Lookup(kHandle).Await(), IsOkAndHolds(data)); + + EXPECT_TRUE(impl_.Test_DeleteHostBufferStore(kSessionId)); +} + +TEST_P(GrpcIfrtServiceImplHostBufferTest, Delete) { + static constexpr uint64_t kSessionId = 1; + + auto store = std::make_shared(); + ASSERT_TRUE(impl_.Test_InsertHostBufferStore(kSessionId, store)); + GrpcClientHostBufferStore client(stub_, Version(), kSessionId); + + constexpr uint64_t kHandle = 2; + const std::string data = GetTestData(); + ASSERT_THAT(store->Store(kHandle, data), IsOk()); + + ASSERT_THAT(client.Delete(kHandle).Await(), IsOk()); + EXPECT_THAT(client.Lookup(kHandle).Await(), + StatusIs(absl::StatusCode::kNotFound)); + + EXPECT_TRUE(impl_.Test_DeleteHostBufferStore(kSessionId)); +} + +INSTANTIATE_TEST_SUITE_P( + DataSize, GrpcIfrtServiceImplHostBufferTest, + testing::Values(0, // Empty host buffer. + 16, // Small enough to fit in one chunk. + 3 * 1024 * 1024)); // Requires multiple chunks + +} // namespace +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/server/host_buffer.cc b/third_party/xla/xla/python/ifrt_proxy/server/host_buffer.cc new file mode 100644 index 00000000000000..4b9dd7391ec81f --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/server/host_buffer.cc @@ -0,0 +1,65 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/server/host_buffer.h" + +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +absl::Status HostBufferStore::Store(uint64_t handle, std::string data) { + absl::MutexLock lock(&mu_); + const bool inserted = + buffers_.insert({handle, std::make_shared(std::move(data))}) + .second; + if (!inserted) { + return absl::AlreadyExistsError( + absl::StrCat("Host buffer handle ", handle, " already exists")); + } + return absl::OkStatus(); +} + +absl::StatusOr> HostBufferStore::Lookup( + uint64_t handle) { + absl::MutexLock lock(&mu_); + const auto it = buffers_.find(handle); + if (it == buffers_.end()) { + return absl::NotFoundError( + absl::StrCat("Host buffer handle ", handle, " not found")); + } + return it->second; +} + +absl::Status HostBufferStore::Delete(uint64_t handle) { + absl::MutexLock lock(&mu_); + if (buffers_.erase(handle) == 0) { + return absl::NotFoundError( + absl::StrCat("Host buffer handle ", handle, " not found")); + } + return absl::OkStatus(); +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/server/host_buffer.h b/third_party/xla/xla/python/ifrt_proxy/server/host_buffer.h new file mode 100644 index 00000000000000..f9b07a40f30e91 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/server/host_buffer.h @@ -0,0 +1,61 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_SERVER_HOST_BUFFER_H_ +#define XLA_PYTHON_IFRT_PROXY_SERVER_HOST_BUFFER_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// Keeps host buffers transferred from the client so that `IfrtBackend` can +// access them when requests with pointers to host buffers arrive. +// +// We expect one `HostBufferStore` to exist per session (i.e., per `IfrtBackend` +// instance) so that host buffers are cleaned up on session termination. +class HostBufferStore { + public: + // Stores the data associated with the given handle. Returns an error if the + // handle already exists. + absl::Status Store(uint64_t handle, std::string data); + + // Retrieves the data associated with the handle. Returns an error if the + // handle does not exist. + absl::StatusOr> Lookup(uint64_t handle); + + // Deletes the host buffer associated with the handle. Returns an error if the + // handle does not exist. + absl::Status Delete(uint64_t handle); + + private: + absl::Mutex mu_; + absl::flat_hash_map> buffers_ + ABSL_GUARDED_BY(mu_); +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_SERVER_HOST_BUFFER_H_ diff --git a/third_party/xla/xla/python/ifrt_proxy/server/host_buffer_test.cc b/third_party/xla/xla/python/ifrt_proxy/server/host_buffer_test.cc new file mode 100644 index 00000000000000..7adc31658dda38 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/server/host_buffer_test.cc @@ -0,0 +1,57 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/server/host_buffer.h" + +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "tsl/platform/status_matchers.h" + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +using ::testing::Pointee; +using ::tsl::testing::IsOk; +using ::tsl::testing::IsOkAndHolds; +using ::tsl::testing::StatusIs; + +TEST(HostBufferStoreTest, ReadAfterWrite) { + HostBufferStore store; + const uint64_t kHandle = 1; + + ASSERT_THAT(store.Store(kHandle, "foo"), IsOk()); + EXPECT_THAT(store.Lookup(kHandle), IsOkAndHolds(Pointee(std::string("foo")))); + + ASSERT_THAT(store.Delete(kHandle), IsOk()); + EXPECT_THAT(store.Lookup(kHandle), StatusIs(absl::StatusCode::kNotFound)); +} + +TEST(HostBufferStoreTest, UnknownHandle) { + HostBufferStore store; + const uint64_t kHandle = 1; + + EXPECT_THAT(store.Lookup(kHandle), StatusIs(absl::StatusCode::kNotFound)); + EXPECT_THAT(store.Delete(kHandle), StatusIs(absl::StatusCode::kNotFound)); +} + +} // namespace +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/server/host_callback.cc b/third_party/xla/xla/python/ifrt_proxy/server/host_callback.cc new file mode 100644 index 00000000000000..a675ab40c66ef0 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/server/host_callback.cc @@ -0,0 +1,195 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/server/host_callback.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/functional/bind_front.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "xla/pjrt/host_callback.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt_proxy/common/proto_util.h" +#include "xla/python/pjrt_ifrt/pjrt_host_callback.h" +#include "xla/python/pjrt_ifrt/xla_host_callback.pb.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/errors.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +RemoteLoadedHostCallbackQueue::~RemoteLoadedHostCallbackQueue() { Close(); } + +absl::Status RemoteLoadedHostCallbackQueue::Push(ExecutionRequest request) { + absl::MutexLock l(&mu_); + if (closed_) { + return absl::CancelledError( + "RemoteLoadedHostCallback has stopped accepting new execution " + "requests"); + } + requests_.push_back(std::move(request)); + return absl::OkStatus(); +} + +std::optional +RemoteLoadedHostCallbackQueue::Pop() { + auto not_empty = [this]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + return !requests_.empty() || closed_; + }; + absl::MutexLock l(&mu_, absl::Condition(¬_empty)); + if (closed_) { + return std::nullopt; + } + ExecutionRequest request = std::move(requests_.front()); + requests_.pop_front(); + return request; +} + +void RemoteLoadedHostCallbackQueue::Close() { + std::deque requests; + { + absl::MutexLock l(&mu_); + if (!closed_) { + requests.swap(requests_); + } + closed_ = true; + } + for (auto& request : requests) { + request.status.Set(absl::CancelledError( + "RemoteLoadedHostCallback execution has been cancelled")); + } +} + +absl::StatusOr> +RemoteLoadedHostCallback::CreateFromSerialized( + xla::ifrt::Client* client, absl::string_view serialized, + std::shared_ptr queue) { + xla::ifrt::XlaHostCallbackProto proto; + if (!proto.ParseFromString(AsProtoStringData(serialized))) { + return absl::DataLossError( + "Unable to deserialize RemoteLoadedHostCallback"); + } + + auto from_proto = + [](const auto& arg_protos) -> std::vector { + std::vector args; + args.reserve(arg_protos.size()); + for (const xla::ifrt::XlaHostCallbackProto::ArgInfo& arg_proto : + arg_protos) { + xla::HostCallbackArgInfo& arg = args.emplace_back(); + arg.channel_id = static_cast(arg_proto.channel_id()); + arg.shape = xla::Shape(arg_proto.shape()); + } + return args; + }; + + return tsl::MakeRef( + client, from_proto(proto.operands()), from_proto(proto.results()), + std::move(queue)); +} + +RemoteLoadedHostCallback::RemoteLoadedHostCallback( + xla::ifrt::Client* client, std::vector operands, + std::vector results, + std::shared_ptr queue) + : llvm::RTTIExtends( + client, + [&]() { + auto xla_host_callback = std::make_unique(); + xla_host_callback->operands = std::move(operands); + xla_host_callback->results = std::move(results); + xla_host_callback->callback = + absl::bind_front(&RemoteLoadedHostCallback::Execute, this); + return xla_host_callback; + }()), + queue_(std::move(queue)) {} + +RemoteLoadedHostCallback::~RemoteLoadedHostCallback() { + if (queue_ != nullptr) { + queue_->Close(); + } +} + +absl::Status RemoteLoadedHostCallback::Execute(void** result_ptrs, + void** operand_ptrs) { + if (queue_ == nullptr) { + return absl::FailedPreconditionError( + "RemoteLoadedHostCallback without queue cannot be executed"); + } + + RemoteLoadedHostCallbackQueue::ExecutionRequest request; + + auto to_buffer = + [&](absl::Span args, void** ptrs, + std::vector& buffers) { + buffers.reserve(args.size()); + for (int i = 0; i < args.size(); ++i) { + const int64_t size = xla::ShapeUtil::ByteSizeOf(args[i].shape); + buffers.push_back(RemoteLoadedHostCallbackQueue::Buffer{ + .data = ptrs[i], .size = size}); + } + }; + to_buffer(host_callback().operands, operand_ptrs, request.operands); + to_buffer(host_callback().results, result_ptrs, request.results); + + request.status = Future::CreatePromise(); + Future status(request.status); + + // Enqueue the execution request. `IfrtBackend` retrieves this by calling + // `PopExecutionRequest` and fulfills the `results` promise. + TF_RETURN_IF_ERROR(queue_->Push(std::move(request))); + + // Block until the execution finishes and return its status. + return status.Await(); +} + +absl::StatusOr RemoteLoadedHostCallback::Serialize() const { + xla::ifrt::XlaHostCallbackProto proto; + + auto to_proto = [](absl::Span args, + auto* args_proto) { + args_proto->Reserve(args.size()); + for (const auto& arg : args) { + auto* arg_proto = args_proto->Add(); + arg_proto->set_channel_id(arg.channel_id); + *arg_proto->mutable_shape() = arg.shape.ToProto(); + } + }; + to_proto(host_callback().operands, proto.mutable_operands()); + to_proto(host_callback().results, proto.mutable_results()); + + return proto.SerializeAsString(); +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/server/host_callback.h b/third_party/xla/xla/python/ifrt_proxy/server/host_callback.h new file mode 100644 index 00000000000000..e2d6ea834e7d60 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/server/host_callback.h @@ -0,0 +1,126 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_SERVER_HOST_CALLBACK_H_ +#define XLA_PYTHON_IFRT_PROXY_SERVER_HOST_CALLBACK_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "xla/pjrt/host_callback.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/pjrt_ifrt/pjrt_host_callback.h" +#include "tsl/concurrency/ref_count.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// Command queue interface between `RemoteLoadedHostCallback` and `IfrtBackend`. +// Responsible for keeping track of in-flight execution requests. +class RemoteLoadedHostCallbackQueue { + public: + struct Buffer { + void* data; + int64_t size; + }; + + // Encapsulates a host buffer execution. Operand and result buffers are + // pre-allocated and the caller is expected to fill them in-place before + // fulfilling the `status` promise. + struct ExecutionRequest { + std::vector operands; + std::vector results; + Future::Promise status; + }; + + ~RemoteLoadedHostCallbackQueue(); + + // Pushes a new execution request to the queue. Returns an error if the queue + // has already been closed. + absl::Status Push(ExecutionRequest request); + + // Blocks until this host callback queue has at least one pending execution + // and returns its information needed to perform execution. Returns nullopt if + // the request queue has already been closed by `Close()`. + std::optional Pop(); + + // Closes this request queue. After this call, all pending executions are + // unblocked with an error and no more executions can be enqueued. + void Close(); + + private: + absl::Mutex mu_; + bool closed_ ABSL_GUARDED_BY(mu_) = false; + std::deque requests_ ABSL_GUARDED_BY(mu_); +}; + +// Host callback that delegates its execution to an external executor. The +// executor waits for execution requests to be enqueued to the given +// `RemoteLoadedHostCallbackQueue` and returns results after execution by +// fulfilling the returned promise. +// +// This class is thread-safe. +// +// Note: The current implementation inherits from PjRt's host callback +// implementation. Even though this is a violation of the IFRT proxy's layering +// principle, it is unavoidable right now because the base `LoadedHostCallback` +// in IFRT has no associated execution semantics. For now, the IFRT proxy +// focuses on supporting host callbacks on PjRt-like IFRT implementations. +class RemoteLoadedHostCallback + : public llvm::RTTIExtends { + public: + // Creates from a serialized string returned by `Serialize()`. + static absl::StatusOr> + CreateFromSerialized(xla::ifrt::Client* client, absl::string_view serialized, + std::shared_ptr queue); + + // Create from operand/result specs. + RemoteLoadedHostCallback( + xla::ifrt::Client* client, std::vector operands, + std::vector results, + std::shared_ptr queue); + + ~RemoteLoadedHostCallback() override; + + // Serializes the remote host callback instance. The returned string can be + // deserialized into `RmeoteLoadedHostCallback` using `CreateFromSerialized`. + absl::StatusOr Serialize() const override; + + private: + // Implements the interface required by `xla::HostCallback`. + absl::Status Execute(void** result_ptrs, void** operand_ptrs); + + std::shared_ptr queue_; +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_SERVER_HOST_CALLBACK_H_ diff --git a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc new file mode 100644 index 00000000000000..cae94e7fa86cc5 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc @@ -0,0 +1,1174 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/server/ifrt_backend.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/cleanup/cleanup.h" +#include "absl/container/flat_hash_map.h" +#include "absl/functional/bind_front.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "xla/layout.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/serdes.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt_proxy/common/array_util.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/proto_util.h" +#include "xla/python/ifrt_proxy/common/types.h" +#include "xla/python/ifrt_proxy/common/types.pb.h" +#include "xla/python/ifrt_proxy/server/host_buffer.h" +#include "xla/python/ifrt_proxy/server/host_callback.h" +#include "xla/python/ifrt_proxy/server/version.h" +#include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "xla/xla_data.pb.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/env.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status_to_from_proto.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/threadpool.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +namespace { + +// Convenient wrapper for `xla::ifrt::Deserialize()`. +template +absl::StatusOr> Deserialize( + const Serialized& serialized, + std::unique_ptr options = nullptr) { + TF_ASSIGN_OR_RETURN(auto deserialized, + Deserialize(serialized, std::move(options))); + auto obj = absl::WrapUnique(llvm::dyn_cast(deserialized.release())); + if (obj == nullptr) { + return absl::InvalidArgumentError("Deserialization type mismatch"); + } + return obj; +} + +} // namespace + +IfrtBackend::IfrtBackend(IfrtProxyVersion version, uint64_t session_id, + std::unique_ptr ifrt_client, + std::shared_ptr host_buffer_store) + : version_(std::move(version)), + session_id_(session_id), + client_(std::move(ifrt_client)), + host_buffer_store_(std::move(host_buffer_store)), + compile_thread_pool_( + tsl::Env::Default(), + []() { + tsl::ThreadOptions options; + // Use a larger stack size since XLA often requires larger stacks + // for compilation. + options.stack_size = 240 * 1024; + return options; + }(), + "IfrtBackend", + // TODO(b/282757875): Consider making this configurable. + /*num_threads=*/32) {} + +absl::StatusOr> IfrtBackend::Create( + IfrtProxyVersion version, uint64_t session_id, + std::unique_ptr ifrt_client, + std::shared_ptr host_buffer_store) { + if (ifrt_client == nullptr) { + return absl::InvalidArgumentError("ifrt_client cannot be a nullptr."); + } + if (version.protocol_version() < kServerMinVersion || + version.protocol_version() > kServerMaxVersion) { + return absl::FailedPreconditionError(absl::StrCat( + "Protocol version ", version.protocol_version(), + " is unsupported by IFRT Proxy server; supported versions: [", + kServerMinVersion, ",", kServerMaxVersion, "]")); + } + return absl::WrapUnique( + new IfrtBackend(std::move(version), session_id, std::move(ifrt_client), + std::move(host_buffer_store))); +} + +IfrtBackend::~IfrtBackend() { + // Cancel all in-flight host callback executions. + absl::flat_hash_map + host_callback_executions; + { + absl::MutexLock lock(&host_callback_executions_mutex_); + host_callback_executions.swap(host_callback_executions_); + } + for (auto& [handle, execution_request] : host_callback_executions) { + std::move(execution_request) + .status.Set(absl::CancelledError("IFRT backend has shut down")); + } + + // Wait until all async work from `AsyncExecute` finishes execution. + { + auto done = [this]() ABSL_SHARED_LOCKS_REQUIRED(in_flight_count_mutex_) { + return in_flight_count_ == 0; + }; + absl::MutexLock lock(&in_flight_count_mutex_, absl::Condition(&done)); + } +} + +Future IfrtBackend::Process( + std::unique_ptr request) { + switch (request->request_case()) { + case IfrtRequest::RequestCase::kInitRequest: + return Future(HandleInit(std::move(request))); + case IfrtRequest::RequestCase::kCheckFutureRequest: + return HandleCheckFutureRequest(std::move(request)); + case IfrtRequest::RequestCase::kMakeArrayFromHostBufferRequest: + return Future( + HandleMakeArrayFromHostBufferRequest(std::move(request))); + case IfrtRequest::RequestCase::kAssembleArrayFromSingleDeviceArraysRequest: + return Future( + HandleAssembleArrayFromSingleDeviceArraysRequest(std::move(request))); + case IfrtRequest::RequestCase::kCopyToHostBufferRequest: + return HandleCopyToHostBufferRequest(std::move(request)); + case IfrtRequest::RequestCase::kDisassembleIntoSingleDeviceArraysRequest: + return Future( + HandleDisassembleIntoSingleDeviceArraysRequest(std::move(request))); + case IfrtRequest::RequestCase::kCheckArrayReadyRequest: + return Future(HandleCheckArrayReadyRequest(std::move(request))); + case IfrtRequest::RequestCase::kReshardRequest: + return Future(HandleReshardRequest(std::move(request))); + case IfrtRequest::RequestCase::kFullyReplicatedShardRequest: + return Future( + HandleFullyReplicatedShardRequest(std::move(request))); + case IfrtRequest::RequestCase::kDeleteArrayRequest: + return Future(HandleDeleteArrayRequest(std::move(request))); + case IfrtRequest::RequestCase::kIsArrayDeletedRequest: + return Future(HandleIsArrayDeletedRequest(std::move(request))); + case IfrtRequest::RequestCase::kDestructArrayRequest: + return Future(HandleDestructArrayRequest(std::move(request))); + case IfrtRequest::RequestCase::kCompileRequest: + return Future(HandleCompileRequest(std::move(request))); + case IfrtRequest::RequestCase::kLoadedExecutableMetadataRequest: + return HandleLoadedExecutableMetadataRequest(std::move(request)); + case IfrtRequest::RequestCase::kLoadedExecutableExecuteRequest: + return Future( + HandleLoadedExecutableExecuteRequest(std::move(request))); + case IfrtRequest::RequestCase::kLoadedExecutableDeleteRequest: + return Future( + HandleLoadedExecutableDeleteRequest(std::move(request))); + case IfrtRequest::RequestCase::kLoadedExecutableIsDeletedRequest: + return Future( + HandleLoadedExecutableIsDeletedRequest(std::move(request))); + case IfrtRequest::RequestCase::kLoadedExecutableDestructRequest: + return Future( + HandleLoadedExecutableDestructRequest(std::move(request))); + case IfrtRequest::RequestCase::kLoadedHostCallbackPollRequest: + return HandleLoadedHostCallbackPollRequest(std::move(request)); + case IfrtRequest::RequestCase::kLoadedHostCallbackReturnRequest: + return Future( + HandleLoadedHostCallbackReturnRequest(std::move(request))); + case IfrtRequest::RequestCase::kGetDefaultDeviceAssignmentRequest: + return Future( + HandleGetDefaultDeviceAssignmentRequest(std::move(request))); + default: + return Future(absl::UnimplementedError(absl::StrCat( + "Got unimplemented request type: ", request->request_case()))); + } +} + +uint64_t IfrtBackend::HandleGenerator::New() { + absl::MutexLock lock(&mu_); + return current_++; +} + +void IfrtBackend::HandleGenerator::BulkNew(absl::Span handles) { + absl::MutexLock lock(&mu_); + std::iota(handles.begin(), handles.end(), current_); + current_ += handles.size(); +} + +Future IfrtBackend::AsyncExecute( + std::function handle_fn, tsl::thread::ThreadPool* thread_pool) { + { + absl::MutexLock lock(&in_flight_count_mutex_); + ++in_flight_count_; + } + auto promise = Future::CreatePromise(); + auto f = [this, promise, handle_fn = std::move(handle_fn)]() mutable { + promise.Set(handle_fn()); + { + absl::MutexLock lock(&in_flight_count_mutex_); + --in_flight_count_; + } + }; + if (thread_pool != nullptr) { + thread_pool->Schedule(std::move(f)); + } else { + tsl::Env::Default()->SchedClosure(std::move(f)); + } + return Future(std::move(promise)); +} + +///////////////////////////////////////////////////////////////////////////// +// +// Handlers for individual request types +// + +BackendInterface::Response IfrtBackend::HandleInit( + std::unique_ptr request) { + std::unique_ptr response = + NewIfrtResponse(request->request_metadata().op_id()); + auto* init_resp = response->mutable_init_response(); + init_resp->set_session_id(session_id_); + init_resp->set_platform_name(AsProtoStringData(client_->platform_name())); + init_resp->set_platform_version( + AsProtoStringData(client_->platform_version())); + init_resp->set_platform_id(client_->platform_id()); + init_resp->set_runtime_type(AsProtoStringData(client_->runtime_type())); + init_resp->set_process_index(client_->process_index()); + + for (auto* device : client_->devices()) { + InitResponse::Device* d = init_resp->add_devices(); + d->set_id(device->id()); + d->set_local_device_id(device->local_device_id().value()); + d->set_local_hardware_id(device->local_hardware_id_typed().value()); + d->set_device_kind(AsProtoStringData(device->device_kind())); + if (auto default_memory_space = device->default_memory_space(); + default_memory_space.ok()) { + d->set_default_memory_id((*default_memory_space)->id()); + } + for (const auto* memory : device->memory_spaces()) { + d->add_memory_ids(memory->id()); + } + d->set_debug_string(AsProtoStringData(device->DebugString())); + d->set_to_string(AsProtoStringData(device->ToString())); + for (const auto& [name, attr] : device->Attributes()) { + TF_ASSIGN_OR_RETURN((*d->mutable_attributes())[name], + ToVariantProto(attr)); + } + } + for (auto* addressable_device : client_->addressable_devices()) { + init_resp->add_addressable_device_ids(addressable_device->id()); + } + + absl::flat_hash_map memories; + for (auto* device : client_->devices()) { + for (xla::ifrt::Memory* memory : device->memory_spaces()) { + const auto [it, inserted] = memories.insert({memory->id(), memory}); + if (!inserted && it->second != memory) { + return absl::FailedPreconditionError(absl::StrCat( + "Two memories cannot have the same id: ", memory->ToString(), + " vs. ", it->second->ToString())); + } + } + } + for (const auto& [id, memory] : memories) { + auto* m = init_resp->add_memories(); + m->set_id(id); + m->set_memory_space_kind(AsProtoStringData(memory->memory_space_kind())); + for (const auto* device : memory->devices()) { + m->add_device_ids(device->id()); + } + m->set_debug_string(AsProtoStringData(memory->DebugString())); + m->set_to_string(AsProtoStringData(memory->ToString())); + } + + return response; +} + +Future IfrtBackend::HandleCheckFutureRequest( + std::unique_ptr request) { + const CheckFutureRequest& check_request = request->check_future_request(); + + Future future; + { + absl::MutexLock lock(&futures_mutex_); + const auto it = futures_.find(check_request.future_handle()); + if (it == futures_.end()) { + return Future(absl::NotFoundError(absl::StrCat( + "Unknown future handle: ", check_request.future_handle()))); + } + future = std::move(it->second); + futures_.erase(it); + } + + auto promise = Future::CreatePromise(); + // With PjRtFuture, the `Future` needs to be owned by one or more owners until + // `OnReady()`'s lambda gets executed. So, capture a copy of `future` in the + // lambda, making the lambda itself an owner of `future`. + future.OnReady([op_id = request->request_metadata().op_id(), promise, + hold = future](absl::Status status) mutable { + if (!status.ok()) { + promise.Set(std::move(status)); + return; + } + auto ifrt_resp = NewIfrtResponse(op_id); + ifrt_resp->mutable_check_future_response(); + promise.Set(std::move(ifrt_resp)); + }); + + return Future(std::move(promise)); +} + +BackendInterface::Response IfrtBackend::HandleMakeArrayFromHostBufferRequest( + std::unique_ptr request) { + if (!request->has_make_array_from_host_buffer_request()) { + return absl::InternalError( + "MakeArrayFromHostBuffer got an IfrtRequest with no " + "MakeArrayFromHostBufferRequest in it."); + } + auto* make_array_request = + request->mutable_make_array_from_host_buffer_request(); + + TF_ASSIGN_OR_RETURN( + auto sharding, + FromShardingProto(absl::bind_front(&Client::LookupDevice, client_.get()), + make_array_request->sharding())); + + const auto byte_strides = [&]() -> std::optional> { + if (!make_array_request->has_byte_strides()) return std::nullopt; + return FromByteStridesProto(make_array_request->byte_strides()); + }(); + const auto shape = FromShapeProto(make_array_request->shape()); + const auto dtype = FromDTypeProto(make_array_request->dtype()); + + const uint64_t host_buffer_handle = make_array_request->host_buffer_handle(); + absl::Cleanup cleanup = [&] { + CHECK_OK(host_buffer_store_->Delete(host_buffer_handle)); + }; + TF_ASSIGN_OR_RETURN(std::shared_ptr host_buffer, + host_buffer_store_->Lookup(host_buffer_handle)); + std::move(cleanup).Invoke(); + + TF_ASSIGN_OR_RETURN(const auto mem_region, + ArrayMemRegion::FromMinimalMemRegion( + *host_buffer, dtype, shape, byte_strides)); + + TF_ASSIGN_OR_RETURN( + auto array, + client_->MakeArrayFromHostBuffer( + mem_region.zeroth_element(), dtype, std::move(shape), + std::move(byte_strides), std::move(sharding), + xla::ifrt::Client::HostBufferSemantics:: + kImmutableUntilTransferCompletes, + [hold = std::move(host_buffer)]() mutable { hold.reset(); })); + + // TODO(b/282757875): Consider merging the handle_generator with the + // arrays_. + uint64_t handle = handle_generator_.New(); + { + absl::MutexLock lock(&arrays_mutex_); + arrays_.insert({handle, std::move(array)}); + } + + std::unique_ptr response = + NewIfrtResponse(request->request_metadata().op_id()); + auto* make_array_resp = + response->mutable_make_array_from_host_buffer_response(); + make_array_resp->set_array_handle(handle); + + return response; +} + +BackendInterface::Response +IfrtBackend::HandleAssembleArrayFromSingleDeviceArraysRequest( + std::unique_ptr request) { + const auto& assemble_request = + request->assemble_array_from_single_device_arrays_request(); + + std::vector> arrays; + { + absl::ReaderMutexLock lock(&arrays_mutex_); + for (const uint64_t handle : + assemble_request.single_device_array_handles()) { + TF_ASSIGN_OR_RETURN(arrays.emplace_back(), GetArrayLocked(handle)); + } + } + + Shape shape = FromShapeProto(assemble_request.shape()); + TF_ASSIGN_OR_RETURN( + auto sharding, + FromShardingProto(absl::bind_front(&Client::LookupDevice, client_.get()), + assemble_request.sharding())); + TF_ASSIGN_OR_RETURN(auto semantics, FromArrayCopySemanticsProto( + assemble_request.copy_semantics())); + + TF_ASSIGN_OR_RETURN(auto array, client_->AssembleArrayFromSingleDeviceArrays( + std::move(shape), std::move(sharding), + absl::MakeSpan(arrays), semantics)); + + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); + + uint64_t handle = handle_generator_.New(); + ifrt_resp->mutable_assemble_array_from_single_device_arrays_response() + ->set_array_handle(handle); + { + absl::MutexLock lock(&arrays_mutex_); + arrays_.insert({handle, std::move(array)}); + } + + return ifrt_resp; +} + +Future IfrtBackend::HandleCopyToHostBufferRequest( + std::unique_ptr request) { + const CopyToHostBufferRequest& copy_to_host = + request->copy_to_host_buffer_request(); + + auto array = GetArray(copy_to_host.array_handle()); + if (!array.ok()) { + return Future(array.status()); + } + + // Determine the size and allocate the host buffer. + // TODO(b/282757875): We may need to redo this to account for byte_strides, + // padding, and alignment requirements. + std::optional element_size = (*array)->dtype().byte_size(); + if (element_size == std::nullopt) { + return Future( + absl::InternalError("Array element size is unknown.")); + } + int64_t host_buffer_size = + (*array)->shape().num_elements() * element_size.value(); + // Use `std::unique_ptr` for pointer stability. + auto host_buffer = std::make_unique(); + host_buffer->resize(host_buffer_size); + + const auto byte_strides = [&]() -> std::optional> { + if (!copy_to_host.has_byte_strides()) { + return std::nullopt; + } + return FromByteStridesProto(copy_to_host.byte_strides()); + }(); + const auto mem_region = ArrayMemRegion::FromMinimalMemRegion( + absl::string_view(*host_buffer), (*array)->dtype(), (*array)->shape(), + byte_strides); + if (!mem_region.ok()) { + return Future(mem_region.status()); + } + + // TODO(b/282757875): Consider other ArrayCopySemantics. + Future copy_status = + (*array)->CopyToHostBuffer(mem_region->zeroth_element(), byte_strides, + ArrayCopySemantics::kAlwaysCopy); + + auto resp_promise = Future::CreatePromise(); + Future resp_future(resp_promise); + auto on_ready = [this, op_id = request->request_metadata().op_id(), + host_buffer = std::move(host_buffer), + host_buffer_handle = copy_to_host.host_buffer_handle()]( + absl::Status status) mutable + -> absl::StatusOr> { + TF_RETURN_IF_ERROR(status); + + TF_RETURN_IF_ERROR( + host_buffer_store_->Store(host_buffer_handle, *std::move(host_buffer))); + + std::unique_ptr response = NewIfrtResponse(op_id); + response->mutable_copy_to_host_buffer_response(); + return response; + }; + copy_status.OnReady( + [promise = std::move(resp_promise), on_ready = std::move(on_ready)]( + absl::Status status) mutable { promise.Set(on_ready(status)); }); + + return resp_future; +} + +BackendInterface::Response +IfrtBackend::HandleDisassembleIntoSingleDeviceArraysRequest( + std::unique_ptr request) { + TF_ASSIGN_OR_RETURN( + auto array, + GetArray(request->disassemble_into_single_device_arrays_request() + .array_handle())); + + // TODO(b/282757875): Consider other ArrayCopySemantics. + TF_ASSIGN_OR_RETURN(auto single_device_arrays, + array->DisassembleIntoSingleDeviceArrays( + xla::ifrt::ArrayCopySemantics::kAlwaysCopy)); + + // Set up an IfrtResponse with pre-allocated space for the right number of + // single device array handles. + int64_t num_arrays = single_device_arrays.size(); + auto response = NewIfrtResponse(request->request_metadata().op_id()); + + // Pre-allocate space in the response proto and fill it in with bulk allocated + // new handles. + auto* handles = + response->mutable_disassemble_into_single_device_arrays_response() + ->mutable_single_device_array_handles(); + handles->Reserve(num_arrays); + uint64_t* handles_buf = handles->AddNAlreadyReserved(num_arrays); + handle_generator_.BulkNew(absl::MakeSpan(handles_buf, num_arrays)); + + // Install the newly created arrays into the arrays_. + { + absl::MutexLock lock(&arrays_mutex_); + for (int i = 0; i < num_arrays; ++i) { + arrays_.insert({handles_buf[i], single_device_arrays[i]}); + } + } + + return response; +} + +Future IfrtBackend::HandleCheckArrayReadyRequest( + std::unique_ptr request) { + auto array = GetArray(request->check_array_ready_request().array_handle()); + if (!array.ok()) { + return Future(array.status()); + } + + auto ifrt_response_promise = + Future::CreatePromise(); + Future ifrt_response_future( + ifrt_response_promise); + + (*array)->GetReadyFuture().OnReady( + [op_id = request->request_metadata().op_id(), + promise = std::move(ifrt_response_promise)]( + absl::Status status) mutable -> void { + if (!status.ok()) { + promise.Set(std::move(status)); + return; + } + auto ifrt_response = NewIfrtResponse(op_id); + ifrt_response->mutable_check_array_ready_response(); + promise.Set(std::move(ifrt_response)); + }); + return ifrt_response_future; +} + +BackendInterface::Response IfrtBackend::HandleReshardRequest( + std::unique_ptr request) { + const auto& reshard_request = request->reshard_request(); + TF_ASSIGN_OR_RETURN(auto array, GetArray(reshard_request.array_handle())); + TF_ASSIGN_OR_RETURN( + auto sharding, + FromShardingProto(absl::bind_front(&Client::LookupDevice, client_.get()), + reshard_request.sharding())); + TF_ASSIGN_OR_RETURN(auto semantics, FromArrayCopySemanticsProto( + reshard_request.copy_semantics())); + + TF_ASSIGN_OR_RETURN(auto resharded_array, + array->Reshard(sharding, semantics)); + + uint64_t resharded_array_handle = handle_generator_.New(); + { + absl::MutexLock lock(&arrays_mutex_); + arrays_.insert({resharded_array_handle, std::move(resharded_array)}); + } + + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); + ifrt_resp->mutable_reshard_response()->set_array_handle( + resharded_array_handle); + return ifrt_resp; +} + +BackendInterface::Response IfrtBackend::HandleFullyReplicatedShardRequest( + std::unique_ptr request) { + const auto& fully_replicated_shard_request = + request->fully_replicated_shard_request(); + TF_ASSIGN_OR_RETURN(auto array, + GetArray(fully_replicated_shard_request.array_handle())); + TF_ASSIGN_OR_RETURN(auto semantics, + FromArrayCopySemanticsProto( + fully_replicated_shard_request.copy_semantics())); + + // Here we are making the assumption that the `FullyReplicatedShard` returns + // the Array corresponding to the first device in the sharding - as needed by + // the proxy client for making the SingleDeviceSharding corresponding to the + // newly created array. Revisit this when IFRT supports: (1) an inexpensive + // way to derive a SingleDeviceSharding from a fully replicated Array's + // sharding and (2) A generalized Reshard API that allows the user to request + // an Array to be made out of a specific single shard. + TF_ASSIGN_OR_RETURN(auto new_array, array->FullyReplicatedShard(semantics)); + + uint64_t new_array_handle = handle_generator_.New(); + { + absl::MutexLock lock(&arrays_mutex_); + arrays_.insert({new_array_handle, std::move(new_array)}); + } + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); + ifrt_resp->mutable_fully_replicated_shard_response()->set_array_handle( + new_array_handle); + return ifrt_resp; +} + +BackendInterface::Response IfrtBackend::HandleDeleteArrayRequest( + std::unique_ptr request) { + TF_ASSIGN_OR_RETURN(auto array, + GetArray(request->delete_array_request().array_handle())); + + auto deletion_future = array->Delete(); + uint64_t future_handle = handle_generator_.New(); + { + absl::MutexLock lock(&futures_mutex_); + futures_.insert({future_handle, std::move(deletion_future)}); + } + + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); + ifrt_resp->mutable_delete_array_response()->set_deletion_future_handle( + future_handle); + return ifrt_resp; +} + +BackendInterface::Response IfrtBackend::HandleIsArrayDeletedRequest( + std::unique_ptr request) { + TF_ASSIGN_OR_RETURN( + auto array, GetArray(request->is_array_deleted_request().array_handle())); + + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); + ifrt_resp->mutable_is_array_deleted_response()->set_deleted( + array->IsDeleted()); + return ifrt_resp; +} + +BackendInterface::Response IfrtBackend::HandleDestructArrayRequest( + std::unique_ptr request) { + { + absl::MutexLock lock(&arrays_mutex_); + bool deleted = + arrays_.erase(request->destruct_array_request().array_handle()); + if (!deleted) { + return absl::NotFoundError( + absl::StrCat("Unknown array handle: ", + request->destruct_array_request().array_handle())); + } + } + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); + + // Currently DestructArrayResponse is an empty message, but proxy clients may + // rely on its presence for correct demuxing. + ifrt_resp->mutable_destruct_array_response(); + return ifrt_resp; +} + +Future IfrtBackend::HandleCompileRequest( + std::unique_ptr request) { + // Perform compilation on a thread pool in order to (1) avoid blocking the RPC + // thread during compilation and (2) run compilation with bigger stacks (often + // necessary for XLA). + auto f = [this, request = std::shared_ptr( + std::move(request))]() -> Response { + const CompileRequest& compile_request = request->compile_request(); + + TF_ASSIGN_OR_RETURN(auto program, Deserialize( + compile_request.program())); + TF_ASSIGN_OR_RETURN(auto options, Deserialize( + compile_request.compile_options())); + + // Deserialize host callbacks. IFRT proxy currently allows only one type of + // host callbacks from the client (`RemoteLoadedHostCallback`) and this is + // serialized out of band into its own field in the request proto. + std::vector> + host_callback_queues; + { + std::vector> + loaded_host_callbacks; + for (int i = 0; i < compile_request.host_callbacks_size(); ++i) { + host_callback_queues.emplace_back( + std::make_shared()); + TF_ASSIGN_OR_RETURN( + loaded_host_callbacks.emplace_back(), + RemoteLoadedHostCallback::CreateFromSerialized( + client_.get(), compile_request.host_callbacks(i), + host_callback_queues.back())); + } + if (!loaded_host_callbacks.empty()) { + if (auto xla_options = + llvm::dyn_cast(options.get())) { + xla_options->loaded_host_callbacks = std::move(loaded_host_callbacks); + } else { + return absl::UnimplementedError( + "Host callbacks are supported only for XLA-like IFRT " + "implementations using `xla::ifrt::XlaCompileOptions`"); + } + } + } + + TF_ASSIGN_OR_RETURN(auto executable, + client_->GetDefaultCompiler()->Compile( + std::move(program), std::move(options))); + + std::unique_ptr ifrt_resp = + NewIfrtResponse(request->request_metadata().op_id()); + auto* compile_resp = ifrt_resp->mutable_compile_response(); + + uint64_t handle = handle_generator_.New(); + compile_resp->set_loaded_executable_handle(handle); + + std::vector host_callback_handles(host_callback_queues.size()); + handle_generator_.BulkNew(absl::MakeSpan(host_callback_handles)); + compile_resp->mutable_loaded_host_callback_handles()->Add( + host_callback_handles.begin(), host_callback_handles.end()); + + // Populate executable metadata. + compile_resp->set_name(AsProtoStringData(executable->name())); + compile_resp->set_num_devices(executable->num_devices()); + for (const auto& logical_device_id : + executable->addressable_device_logical_ids()) { + LogicalDeviceIds* proto = + compile_resp->add_addressable_device_logical_ids(); + proto->set_replica(logical_device_id.replica); + proto->set_partition(logical_device_id.partition); + } + for (const auto* device : executable->addressable_devices()) { + compile_resp->add_addressable_device_ids(device->id()); + } + // TODO(b/282757875): Consider making fingerprint calculation asynchronous + // if it is expected to take long. + auto fingerprint = executable->Fingerprint(); + if (!fingerprint.ok()) { + *compile_resp->mutable_fingerprint_error() = + tsl::StatusToProto(fingerprint.status()); + } else if (fingerprint->has_value()) { + compile_resp->set_fingerprint_value(std::move(fingerprint)->value()); + } + + { + absl::MutexLock lock(&executables_mutex_); + executables_.insert({handle, std::move(executable)}); + } + { + absl::MutexLock lock(&host_callback_queues_mutex_); + for (int i = 0; i < host_callback_queues.size(); ++i) { + host_callback_queues_.insert( + {host_callback_handles[i], std::move(host_callback_queues[i])}); + } + } + + return ifrt_resp; + }; + return AsyncExecute(std::move(f), &compile_thread_pool_); +} + +Future +IfrtBackend::HandleLoadedExecutableMetadataRequest( + std::unique_ptr request) { + // Call `GetParameterShardings` and `GetOutputShardings` on a thread pool + // since some implementations may block until compilation completes. + return AsyncExecute([this, request = std::shared_ptr( + std::move(request))]() -> Response { + const uint64_t handle = request->loaded_executable_metadata_request() + .loaded_executable_handle(); + TF_ASSIGN_OR_RETURN(std::shared_ptr executable, + GetLoadedExecutable(handle)); + + std::unique_ptr ifrt_resp = + NewIfrtResponse(request->request_metadata().op_id()); + auto* metadata_resp = + ifrt_resp->mutable_loaded_executable_metadata_response(); + + if (auto parameter_shardings = executable->GetParameterShardings(); + parameter_shardings.has_value()) { + metadata_resp->mutable_parameter_shardings()->mutable_shardings()->Add( + parameter_shardings->begin(), parameter_shardings->end()); + } + if (auto output_shardings = executable->GetOutputShardings(); + output_shardings.has_value()) { + metadata_resp->mutable_output_shardings()->mutable_shardings()->Add( + output_shardings->begin(), output_shardings->end()); + } + + if (auto parameter_layouts = executable->GetParameterLayouts(); + parameter_layouts.ok()) { + auto* const layouts = + metadata_resp->mutable_parameter_layouts_list()->mutable_layouts(); + for (const xla::Layout& layout : *parameter_layouts) { + layouts->Add(layout.ToProto()); + } + } else { + *metadata_resp->mutable_parameter_layouts_error() = + tsl::StatusToProto(parameter_layouts.status()); + } + if (auto output_layouts = executable->GetOutputLayouts(); + output_layouts.ok()) { + auto* const layouts = + metadata_resp->mutable_output_layouts_list()->mutable_layouts(); + for (const xla::Layout& layout : *output_layouts) { + layouts->Add(layout.ToProto()); + } + } else { + *metadata_resp->mutable_output_layouts_error() = + tsl::StatusToProto(output_layouts.status()); + } + + auto output_memory_kinds = executable->GetOutputMemoryKinds(); + if (output_memory_kinds.ok()) { + for (const auto& memory_kinds : *output_memory_kinds) { + auto* const list = metadata_resp->mutable_output_memory_kinds() + ->add_memory_kind_lists() + ->mutable_memory_kinds(); + list->Reserve(memory_kinds.size()); + list->Add(memory_kinds.begin(), memory_kinds.end()); + } + } else { + *metadata_resp->mutable_output_memory_kinds()->mutable_status() = + tsl::StatusToProto(output_memory_kinds.status()); + } + + return ifrt_resp; + }); +} + +BackendInterface::Response IfrtBackend::HandleLoadedExecutableExecuteRequest( + std::unique_ptr request) { + const LoadedExecutableExecuteRequest& execute = + request->loaded_executable_execute_request(); + TF_ASSIGN_OR_RETURN(std::shared_ptr executable, + GetLoadedExecutable(execute.loaded_executable_handle())); + + std::vector> args; + args.reserve(execute.args_handles_size()); + { + absl::ReaderMutexLock lock(&arrays_mutex_); + for (const uint64_t handle : execute.args_handles()) { + TF_ASSIGN_OR_RETURN(args.emplace_back(), GetArrayLocked(handle)); + } + } + + TF_ASSIGN_OR_RETURN(auto execute_options, + xla::ifrt::LoadedExecutable::ExecuteOptions::FromProto( + execute.execute_options())); + + std::optional devices; + if (!execute.device_ids().empty()) { + DeviceList::Devices d; + d.reserve(execute.device_ids_size()); + for (const int32_t device_id : execute.device_ids()) { + TF_ASSIGN_OR_RETURN(d.emplace_back(), client_->LookupDevice(device_id)); + } + devices = DeviceList(std::move(d)); + } + + TF_ASSIGN_OR_RETURN( + xla::ifrt::LoadedExecutable::ExecuteResult result, + executable->Execute(absl::MakeSpan(args), execute_options, devices)); + + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); + LoadedExecutableExecuteResponse* execute_response = + ifrt_resp->mutable_loaded_executable_execute_response(); + + // Register the future to `futures_`. Caller is expected to call + // `CheckFuture` exactly once to check for its status and erase it. In future, + // we may introduce separate mechanisms to remove futures from `futures_` + // without checking its status for situations where futures are not used. + { + absl::MutexLock lock(&futures_mutex_); + execute_response->set_status_handle(handle_generator_.New()); + futures_.insert( + {execute_response->status_handle(), std::move(result.status)}); + } + + // Register output arrays. At this point, we should never early return because + // doing so will leak futures or output arrays registered so far. + std::vector output_handles(result.outputs.size()); + handle_generator_.BulkNew(absl::MakeSpan(output_handles)); + { + absl::MutexLock lock(&arrays_mutex_); + for (int i = 0; i < result.outputs.size(); ++i) { + tsl::RCReference& array = result.outputs[i]; + + LoadedExecutableExecuteResponse::Output* output = + execute_response->add_outputs(); + output->set_dtype(ToDTypeProto(array->dtype())); + *output->mutable_shape() = ToShapeProto(array->shape()); + TF_ASSIGN_OR_RETURN(*output->mutable_sharding(), + ToShardingProto(array->sharding())); + output->set_array_handle(output_handles[i]); + + arrays_.insert({output_handles[i], std::move(array)}); + } + } + + return ifrt_resp; +} + +BackendInterface::Response IfrtBackend::HandleLoadedExecutableDeleteRequest( + std::unique_ptr request) { + const auto& del = request->loaded_executable_delete_request(); + TF_ASSIGN_OR_RETURN(std::shared_ptr executable, + GetLoadedExecutable(del.loaded_executable_handle())); + + Future future = executable->Delete(); + + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); + auto* del_response = ifrt_resp->mutable_loaded_executable_delete_response(); + + { + absl::MutexLock lock(&futures_mutex_); + del_response->set_future_handle(handle_generator_.New()); + futures_.insert({del_response->future_handle(), std::move(future)}); + } + + return ifrt_resp; +} + +BackendInterface::Response IfrtBackend::HandleLoadedExecutableIsDeletedRequest( + std::unique_ptr request) { + const auto& is_deleted = request->loaded_executable_is_deleted_request(); + TF_ASSIGN_OR_RETURN( + std::shared_ptr executable, + GetLoadedExecutable(is_deleted.loaded_executable_handle())); + + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); + auto* is_deleted_response = + ifrt_resp->mutable_loaded_executable_is_deleted_response(); + is_deleted_response->set_is_deleted(executable->IsDeleted()); + + return ifrt_resp; +} + +BackendInterface::Response IfrtBackend::HandleLoadedExecutableDestructRequest( + std::unique_ptr request) { + const auto& destruct = request->loaded_executable_destruct_request(); + + std::shared_ptr executable; + { + absl::MutexLock lock(&executables_mutex_); + const auto it = executables_.find(destruct.loaded_executable_handle()); + if (it == executables_.end()) { + return absl::NotFoundError( + absl::StrCat("Unknown loaded executable handle: ", + destruct.loaded_executable_handle())); + } + executable = std::move(it->second); + executables_.erase(it); + } + executable.reset(); + + // `RemoteLoadedHostCallback`'s request queue is closed when the host callback + // objects are destroyed by the underlying IFRT implementation when there are + // no more host callback executions to be done. + + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); + ifrt_resp->mutable_loaded_executable_destruct_response(); + return ifrt_resp; +} + +Future +IfrtBackend::HandleLoadedHostCallbackPollRequest( + std::unique_ptr request) { + return AsyncExecute([this, request = std::shared_ptr( + std::move(request))]() -> Response { + const auto& poll = request->loaded_host_callback_poll_request(); + const uint64_t handle = poll.loaded_host_callback_handle(); + + // Find the host callback queue associated with the given handle. + std::shared_ptr queue; + { + absl::MutexLock lock(&host_callback_queues_mutex_); + auto it = host_callback_queues_.find(handle); + if (it == host_callback_queues_.end()) { + return absl::NotFoundError( + absl::StrCat("Unknown loaded host callback handle: ", handle)); + } + queue = it->second; + } + + // Block until the host callback has any pending execution and pop its + // execution info. May return a nullopt if the host callback has been + // deleted by the underlying IFRT implementation. + auto execution_request = queue->Pop(); + if (!execution_request.has_value()) { + { + absl::MutexLock lock(&host_callback_queues_mutex_); + host_callback_queues_.erase(handle); + } + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); + ifrt_resp->mutable_loaded_host_callback_poll_response(); + return ifrt_resp; + } + + // After this point, we must fulfill the promise eventually in order to + // avoid deadlock (`absl::Cleanup` ensures this). + + absl::Cleanup cleanup = [&] { + std::move(execution_request) + ->status.Set(absl::UnknownError( + "Unable to enqueue the host callback execution")); + }; + + // Store the operands as a single contiguous buffer in the host buffer + // store. The client retrieves it by invoking `HostBufferLookup`. + { + std::string buffer; + for (const auto& operand : execution_request->operands) { + buffer.append(static_cast(operand.data), operand.size); + } + TF_RETURN_IF_ERROR(host_buffer_store_->Store( + poll.operand_host_buffer_handle(), std::move(buffer))); + } + + const uint64_t execution_handle = handle_generator_.New(); + { + absl::MutexLock lock(&host_callback_executions_mutex_); + host_callback_executions_.insert( + {execution_handle, *std::move(execution_request)}); + } + std::move(cleanup).Cancel(); + + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); + auto* poll_response = + ifrt_resp->mutable_loaded_host_callback_poll_response(); + poll_response->set_host_callback_execution_handle(execution_handle); + return ifrt_resp; + }); +} + +BackendInterface::Response IfrtBackend::HandleLoadedHostCallbackReturnRequest( + std::unique_ptr request) { + const auto& ret = request->loaded_host_callback_return_request(); + + RemoteLoadedHostCallbackQueue::ExecutionRequest execution_request; + { + absl::MutexLock lock(&host_callback_executions_mutex_); + const auto it = + host_callback_executions_.find(ret.host_callback_execution_handle()); + if (it == host_callback_executions_.end()) { + return absl::NotFoundError( + absl::StrCat("Unknown host callback execution: ", + ret.host_callback_execution_handle())); + } + execution_request = std::move(it->second); + host_callback_executions_.erase(it); + } + absl::Cleanup cleanup = [&] { + std::move(execution_request) + .status.Set(absl::UnknownError( + "Unable to process the host callback execution results")); + }; + + // Copy the results from the host buffer store to the preallocated result + // buffers from `RemoteLoadedHostCallback`. Must be done before fulfilling the + // promise since the buffers may not be alive after that. + absl::Status status; + if (ret.has_result_host_buffer_handle()) { + TF_ASSIGN_OR_RETURN( + std::shared_ptr buffer, + host_buffer_store_->Lookup(ret.result_host_buffer_handle())); + absl::Cleanup cleanup = [&] { + CHECK_OK(host_buffer_store_->Delete(ret.result_host_buffer_handle())); + }; + + int64_t offset = 0; + for (const auto& result : execution_request.results) { + if (offset + result.size > buffer->size()) { + return absl::InternalError( + absl::StrCat("Buffer overflow while reading host callback " + "execution results; ", + "range: [", offset, ", ", offset + result.size, "), ", + "buffer size: ", buffer->size())); + } + std::memcpy(result.data, buffer->data() + offset, result.size); + offset += result.size; + } + if (offset != buffer->size()) { + return absl::InternalError( + absl::StrCat("Host callback execution did not consume the entire " + "result buffer; size: ", + buffer->size(), "; consumed: ", offset)); + } + } else { + status = tsl::StatusFromProto(ret.error()); + } + + // Fulfill the result promise. This unblocks the execution of the associated + // `RemoteLoadedHostCallback`. It is unsafe to access `execution_request` + // after this since the buffers may not be alive. + std::move(execution_request).status.Set(std::move(status)); + std::move(cleanup).Cancel(); + + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); + ifrt_resp->mutable_loaded_host_callback_return_response(); + return ifrt_resp; +} + +BackendInterface::Response IfrtBackend::HandleGetDefaultDeviceAssignmentRequest( + std::unique_ptr request) { + const auto& get_default_device_assignment_request = + request->get_default_device_assignment_request(); + TF_ASSIGN_OR_RETURN( + auto assignment, + client_->GetDefaultDeviceAssignment( + get_default_device_assignment_request.num_replicas(), + get_default_device_assignment_request.num_partitions())); + + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); + + // Currently, the xla::DeviceAssignment::Serialize does not fail. If test + // coverage for this error is needed, consider using testing::test_value to + // inject one. + TF_RETURN_IF_ERROR(assignment.Serialize( + ifrt_resp->mutable_get_default_device_assignment_response() + ->mutable_device_assignment())); + + return ifrt_resp; +} + +absl::StatusOr> +IfrtBackend::GetLoadedExecutable(uint64_t handle) { + absl::MutexLock lock(&executables_mutex_); + auto it = executables_.find(handle); + if (it == executables_.end()) { + return absl::NotFoundError( + absl::StrCat("Unknown loaded executable handle: ", handle)); + } + return it->second; +} + +absl::StatusOr> IfrtBackend::GetArray( + uint64_t array_handle) { + absl::ReaderMutexLock lock(&arrays_mutex_); + return GetArrayLocked(array_handle); +} + +absl::StatusOr> IfrtBackend::GetArrayLocked( + uint64_t array_handle) { + auto it = arrays_.find(array_handle); + if (it == arrays_.end()) { + return absl::NotFoundError( + absl::StrCat("Unknown array handle: ", array_handle)); + } + return it->second; +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.h b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.h new file mode 100644 index 00000000000000..9dd57c66dd2a1a --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.h @@ -0,0 +1,207 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_SERVER_IFRT_BACKEND_H_ +#define XLA_PYTHON_IFRT_PROXY_SERVER_IFRT_BACKEND_H_ + +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/server/host_buffer.h" +#include "xla/python/ifrt_proxy/server/host_callback.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/threadpool.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// The abstract class `BackendInterface` defines the interface used by the IFRT +// service to interact with a variety of backend runtime system it can utilize. +class BackendInterface { + public: + virtual ~BackendInterface() = default; + + // Currently, responses (particularly those that carry buffer contents) can be + // of non-trivial size. Once we figured out how best to move the data, we may + // want to revise the shared_ptr below to the `IfrtResponse` proto itself. + // Also, if and when we have a move-only Future in xla::ifrt, we may consider + // changing it to std::unique_ptr. + using Response = absl::StatusOr>; + + // Processes a given IFRT Request and returns a Future of an IfrtResponse. + virtual Future Process(std::unique_ptr request) = 0; +}; + +// IfrtBackend implements a backend that already has a linkable C++ client that +// conforms to the xla::ifrt API. +class IfrtBackend final : public BackendInterface { + public: + // Creates an returns an IfrtBackend that uses the given IFRT Client to + // process the incoming proxy client requests. The `ifrt_client` param cannot + // be a nullptr. + static absl::StatusOr> Create( + IfrtProxyVersion version, uint64_t session_id, + std::unique_ptr ifrt_client, + std::shared_ptr host_buffer_store); + + ~IfrtBackend() override; + + // IFRT Proxy version negotiated between the client and the server. + const IfrtProxyVersion& version() const { return version_; } + + Future Process(std::unique_ptr request) override; + + private: + // Generates unique handles for returning to the client. All object types + // currently use this single "handle space". + class HandleGenerator { + public: + uint64_t New(); + + // Bulk allocates a given number of handles and saves them into the provided + // Span. + void BulkNew(absl::Span handles); + + private: + absl::Mutex mu_; + uint64_t current_ ABSL_GUARDED_BY(mu_) = 1; + }; + + IfrtBackend(IfrtProxyVersion version, uint64_t session_id, + std::unique_ptr ifrt_client, + std::shared_ptr host_buffer_store); + + // Executes the given function on the given thread pool and returns a future + // that becomes ready when the function returns. If the thread pool is not + // given, uses a default thread pool implementation that does not limit the + // maximum number of threads. + Future AsyncExecute(std::function handle_fn, + tsl::thread::ThreadPool* thread_pool = nullptr); + + ////////////////////////////////////////////////////////////////////// + // Handlers for individual requests + // + + Response HandleInit(std::unique_ptr request); + + Future HandleCheckFutureRequest( + std::unique_ptr request); + + Response HandleMakeArrayFromHostBufferRequest( + std::unique_ptr request); + Response HandleAssembleArrayFromSingleDeviceArraysRequest( + std::unique_ptr request); + Future HandleCopyToHostBufferRequest( + std::unique_ptr request); + Response HandleDisassembleIntoSingleDeviceArraysRequest( + std::unique_ptr request); + Response HandleReshardRequest(std::unique_ptr request); + Response HandleFullyReplicatedShardRequest( + std::unique_ptr request); + Future HandleCheckArrayReadyRequest( + std::unique_ptr request); + Response HandleDeleteArrayRequest(std::unique_ptr request); + Response HandleIsArrayDeletedRequest(std::unique_ptr request); + Response HandleDestructArrayRequest(std::unique_ptr request); + + Future HandleCompileRequest(std::unique_ptr request); + + Future HandleLoadedExecutableMetadataRequest( + std::unique_ptr request); + Response HandleLoadedExecutableExecuteRequest( + std::unique_ptr request); + Response HandleLoadedExecutableDeleteRequest( + std::unique_ptr request); + Response HandleLoadedExecutableIsDeletedRequest( + std::unique_ptr request); + Response HandleLoadedExecutableDestructRequest( + std::unique_ptr request); + + Future HandleLoadedHostCallbackPollRequest( + std::unique_ptr request); + Response HandleLoadedHostCallbackReturnRequest( + std::unique_ptr request); + + Response HandleGetDefaultDeviceAssignmentRequest( + std::unique_ptr request); + + ////////////////////////////////////////////////////////////////////// + // Convenient methods for object lookups + // + + absl::StatusOr> + GetLoadedExecutable(uint64_t handle); + + absl::StatusOr> GetArray(uint64_t handle); + absl::StatusOr> GetArrayLocked( + uint64_t handle) ABSL_SHARED_LOCKS_REQUIRED(arrays_mutex_); + + HandleGenerator handle_generator_; + + // Must not change during the life of this object. + const IfrtProxyVersion version_; + const uint64_t session_id_; + const std::unique_ptr client_; + const std::shared_ptr host_buffer_store_; + + absl::Mutex futures_mutex_; + absl::flat_hash_map> futures_ + ABSL_GUARDED_BY(futures_mutex_); + + absl::Mutex arrays_mutex_; + absl::flat_hash_map> arrays_ + ABSL_GUARDED_BY(arrays_mutex_); + + absl::Mutex executables_mutex_; + absl::flat_hash_map> + executables_ ABSL_GUARDED_BY(executables_mutex_); + + absl::Mutex host_callback_queues_mutex_; + absl::flat_hash_map> + host_callback_queues_ ABSL_GUARDED_BY(host_callback_queues_mutex_); + + absl::Mutex host_callback_executions_mutex_; + absl::flat_hash_map + host_callback_executions_ + ABSL_GUARDED_BY(host_callback_executions_mutex_); + + absl::Mutex in_flight_count_mutex_; + int64_t in_flight_count_ ABSL_GUARDED_BY(in_flight_count_mutex_) = 0; + + // Use a separate thread pool for compilation as XLA compilation often + // requires a bigger stack. + tsl::thread::ThreadPool compile_thread_pool_; +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_SERVER_IFRT_BACKEND_H_ diff --git a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc new file mode 100644 index 00000000000000..b7b35560be74e3 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc @@ -0,0 +1,1406 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/server/ifrt_backend.h" + +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/pjrt/host_callback.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_device_description.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/mock.h" +#include "xla/python/ifrt/serdes.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/types.h" +#include "xla/python/ifrt_proxy/common/types.pb.h" +#include "xla/python/ifrt_proxy/server/host_buffer.h" +#include "xla/python/ifrt_proxy/server/host_callback.h" +#include "xla/python/ifrt_proxy/server/version.h" +#include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "xla/service/computation_placer.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/test.h" +#include "xla/xla_data.pb.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/env.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/status_to_from_proto.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" +#include "tsl/protobuf/error_codes.pb.h" +#include "tsl/protobuf/status.pb.h" + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +using ::testing::_; +using ::testing::ByMove; +using ::testing::DoAll; +using ::testing::ElementsAreArray; +using ::testing::HasSubstr; +using ::testing::Invoke; +using ::testing::Not; +using ::testing::NotNull; +using ::testing::Pointee; +using ::testing::Return; +using ::testing::ReturnRef; +using ::testing::SizeIs; +using ::testing::StrEq; +using ::tsl::protobuf::TextFormat; +using ::tsl::testing::IsOk; +using ::tsl::testing::IsOkAndHolds; +using ::tsl::testing::StatusIs; + +#if defined(PLATFORM_GOOGLE) +using ::testing::EquivToProto; +using ::testing::proto::IgnoringRepeatedFieldOrdering; +using ::testing::proto::Partially; +#endif + +constexpr uint64_t kSessionId = 12345; + +IfrtProxyVersion Version() { + IfrtProxyVersion version; + version.set_protocol_version(kServerMaxVersion); + return version; +} + +// Makes an empty request with the given op_id. Does not fail. +std::unique_ptr NewIfrtRequest(uint64_t op_id) { + auto ifrt_request = std::make_unique(); + auto* request_metadata = ifrt_request->mutable_request_metadata(); + request_metadata->set_op_id(op_id); + return ifrt_request; +} + +TEST(IfrtBackendTest, CreationFailsWithNullIfrtClient) { + EXPECT_THAT(IfrtBackend::Create(Version(), kSessionId, nullptr, nullptr), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(IfrtBackendTest, SuccessfulCreation) { + auto ifrt_client = std::make_unique(); + ASSERT_THAT(IfrtBackend::Create(Version(), kSessionId, std::move(ifrt_client), + std::make_shared()), + IsOk()); +} + +TEST(IfrtBackendTest, ShutdownSucceeds) { + auto ifrt_client = std::make_unique(); + TF_ASSERT_OK_AND_ASSIGN( + auto ifrt_backend, + IfrtBackend::Create(Version(), kSessionId, std::move(ifrt_client), + std::make_shared())); +} + +TEST(IfrtBackendTest, ProcessFailsWithNoRequestSet) { + auto ifrt_client = std::make_unique(); + TF_ASSERT_OK_AND_ASSIGN( + auto ifrt_backend, + IfrtBackend::Create(Version(), kSessionId, std::move(ifrt_client), + std::make_shared())); + + // Make a new request but leave the `OneOf` `request` field unset. And, that + // should fail the Process call. + auto request = std::make_unique(); + auto process_status = ifrt_backend->Process(std::move(request)).Await(); + ASSERT_THAT(process_status, Not(IsOk())); +} + +struct TestProgram : llvm::RTTIExtends { + static char ID; // NOLINT +}; + +[[maybe_unused]] char TestProgram::ID = 0; // NOLINT + +class TestProgramSerDes : public llvm::RTTIExtends { + public: + absl::string_view type_name() const override { + return "xla::ifrt::proxy::TestProgram"; + } + + absl::StatusOr Serialize(Serializable& serializable) override { + CHECK(llvm::isa(serializable)); + return ""; + } + + absl::StatusOr> Deserialize( + const std::string& serialized, + std::unique_ptr options) override { + return std::make_unique(); + } + + static char ID; // NOLINT +}; + +[[maybe_unused]] char TestProgramSerDes::ID = 0; // NOLINT + +struct TestCompileOptions + : llvm::RTTIExtends { + static char ID; // NOLINT +}; + +[[maybe_unused]] char TestCompileOptions::ID = 0; // NOLINT + +class TestCompileOptionsSerDes + : public llvm::RTTIExtends { + public: + absl::string_view type_name() const override { + return "xla::ifrt::proxy::TestCompileOptions"; + } + + absl::StatusOr Serialize(Serializable& serializable) override { + CHECK(llvm::isa(serializable)); + return ""; + } + + absl::StatusOr> Deserialize( + const std::string& serialized, + std::unique_ptr options) override { + return std::make_unique(); + } + + static char ID; // NOLINT +}; + +[[maybe_unused]] char TestCompileOptionsSerDes::ID = 0; // NOLINT + +class IfrtBackendHandlerTest : public testing::Test { + protected: + static void SetUpTestSuite() { + RegisterSerDes(std::make_unique()); + RegisterSerDes( + std::make_unique()); + } + + void SetUp() override { + auto mock_client = std::make_unique(); + + std::vector raw_device_ptrs; + for (int i = 0; i < 2; ++i) { + auto mock_device = std::make_unique(); + ON_CALL(*mock_device, global_device_id()) + .WillByDefault(Return(xla::PjRtGlobalDeviceId(i))); + raw_device_ptrs.push_back(mock_device.get()); + mock_devices_.push_back(std::move(mock_device)); + } + + ON_CALL(*mock_client, devices()).WillByDefault(Return(raw_device_ptrs)); + ON_CALL(*mock_client, LookupDevice(_)) + .WillByDefault( + Invoke([this](int id) -> absl::StatusOr { + if (id < 0 || id >= mock_devices_.size()) { + return absl::NotFoundError( + absl::StrCat("Unknown device id: ", id)); + } + return mock_devices_[id].get(); + })); + + // Remembering a raw pointer to the mock client here is OK, since most tests + // anyway have to make the basic and tacit assumption that the backend will + // call into the mock client --and thus keep it alive-- for the duration of + // the test. + mock_client_ = mock_client.get(); + + EXPECT_CALL(*mock_client_, GetDefaultCompiler) + .WillRepeatedly(Return(&mock_compiler_)); + + host_buffer_store_ = std::make_shared(); + TF_ASSERT_OK_AND_ASSIGN( + backend_, + IfrtBackend::Create(Version(), kSessionId, std::move(mock_client), + host_buffer_store_)); + } + + absl::StatusOr> CallBackend( + std::unique_ptr request) { + auto response_future = backend_->Process(std::move(request)); + return std::move(response_future).Await(); + } + + uint64_t NewOpId() { + absl::MutexLock lock(&mu_); + return current_op_id_++; + } + + uint64_t NewHostBufferHandle() { return current_host_buffer_handle_++; } + + // Utility method to set up a given MockArray (in the backend) that can then + // be the target of the other Array-specific methods. Returns the array + // handle. + absl::StatusOr MakeTestArray(tsl::RCReference mock_array) { + EXPECT_CALL(*mock_client_, MakeArrayFromHostBuffer(_, _, _, _, _, _, _)) + .WillOnce(Return(std::move(mock_array))); + + auto ifrt_request = NewIfrtRequest(NewOpId()); + { + const uint64_t host_buffer_handle = NewHostBufferHandle(); + TF_RETURN_IF_ERROR( + host_buffer_store_->Store(host_buffer_handle, "01234567")); + + auto* make_array = + ifrt_request->mutable_make_array_from_host_buffer_request(); + make_array->set_dtype(proto::DTYPE_S32); + make_array->mutable_shape()->add_dimensions(2); + make_array->set_host_buffer_handle(host_buffer_handle); + + TF_ASSIGN_OR_RETURN(auto* device, mock_client_->LookupDevice(1)); + TF_ASSIGN_OR_RETURN( + *make_array->mutable_sharding(), + ToShardingProto(*SingleDeviceSharding::Create(device, MemoryKind()))); + } + TF_ASSIGN_OR_RETURN(auto make_array_response, + CallBackend(std::move(ifrt_request))); + + TF_RETURN_IF_ERROR(tsl::StatusFromProto( + make_array_response->response_metadata().status())); + return make_array_response->make_array_from_host_buffer_response() + .array_handle(); + } + + absl::StatusOr CompileTestLoadedExecutable( + absl::StatusOr> loaded_executable) { + auto request = NewIfrtRequest(NewOpId()); + CompileRequest* compile_request = request->mutable_compile_request(); + TestProgram program; + TF_ASSIGN_OR_RETURN(*compile_request->mutable_program(), + Serialize(program)); + TestCompileOptions compile_options; + TF_ASSIGN_OR_RETURN(*compile_request->mutable_compile_options(), + Serialize(compile_options)); + + EXPECT_CALL(mock_compiler_, Compile(_, _)) + .WillOnce(Return(ByMove(std::move(loaded_executable)))); + + TF_ASSIGN_OR_RETURN(std::shared_ptr response, + CallBackend(std::move(request))); + + TF_RET_CHECK(response->has_compile_response()); + return response->compile_response(); + } + + absl::Status CheckFuture(uint64_t handle) { + auto request = NewIfrtRequest(NewOpId()); + request->mutable_check_future_request()->set_future_handle(handle); + TF_ASSIGN_OR_RETURN(std::shared_ptr response, + CallBackend(std::move(request))); + return tsl::StatusFromProto(response->response_metadata().status()); + } + + xla::ifrt::MockClient* mock_client_; + xla::ifrt::MockCompiler mock_compiler_; + std::vector> mock_devices_; + std::shared_ptr host_buffer_store_; + + private: + absl::Mutex mu_; + uint64_t current_op_id_ ABSL_GUARDED_BY(mu_) = 1; + uint64_t current_host_buffer_handle_ = 1; + + std::unique_ptr backend_; +}; + +// TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS +#if defined(PLATFORM_GOOGLE) +TEST_F(IfrtBackendHandlerTest, Init) { + EXPECT_CALL(*mock_client_, platform_name()) + .WillRepeatedly(Return("ifrt_backend")); + EXPECT_CALL(*mock_client_, platform_version()).WillRepeatedly(Return("n/a")); + EXPECT_CALL(*mock_client_, platform_id()).WillRepeatedly(Return(42)); + EXPECT_CALL(*mock_client_, process_index()).WillRepeatedly(Return(1)); + EXPECT_CALL(*mock_client_, runtime_type()) + .WillRepeatedly(Return("ifrt-service")); + + std::vector> mock_memory_devices; + mock_memory_devices.reserve(mock_devices_.size()); + for (const auto& mock_device : mock_devices_) { + mock_memory_devices.push_back({mock_device.get()}); + } + + std::vector mock_memories(mock_devices_.size()); + for (int i = 0; i < mock_memories.size(); ++i) { + MockMemory& memory = mock_memories[i]; + EXPECT_CALL(memory, devices()) + .WillRepeatedly(Return(mock_memory_devices[i])); + EXPECT_CALL(memory, id()).WillRepeatedly(Return(i)); + EXPECT_CALL(memory, memory_space_kind()).WillRepeatedly(Return("mock")); + } + + std::vector> device_memories; + device_memories.reserve(mock_devices_.size()); + for (int i = 0; i < mock_devices_.size(); ++i) { + device_memories.push_back({&mock_memories[i]}); + } + + using AttributeMap = + absl::flat_hash_map; + std::vector device_attributes(mock_devices_.size()); + + const uint32_t kLocalHardwareId = 1234; + for (int i = 0; i < mock_devices_.size(); ++i) { + device_attributes[i].insert({"name", absl::StrCat("device", i)}); + + MockDevice& mock_device = *mock_devices_[i]; + // TODO(b/314368788): Clean up PJRT device ID APIs. + EXPECT_CALL(mock_device, local_hardware_id_typed()) + .WillRepeatedly(Return(xla::PjRtLocalHardwareId(kLocalHardwareId))); + EXPECT_CALL(mock_device, local_hardware_id()) + .WillRepeatedly(Return(kLocalHardwareId)); + EXPECT_CALL(mock_device, local_device_id()) + .WillRepeatedly(Return(xla::PjRtLocalDeviceId(kLocalHardwareId))); + EXPECT_CALL(mock_device, device_kind()).WillRepeatedly(Return("mock")); + EXPECT_CALL(mock_device, memory_spaces()) + .WillRepeatedly(Return(device_memories[i])); + EXPECT_CALL(mock_device, default_memory_space()) + .WillRepeatedly(Return(&mock_memories[i])); + EXPECT_CALL(mock_device, Attributes()) + .WillRepeatedly(ReturnRef(device_attributes[i])); + } + + auto request = NewIfrtRequest(NewOpId()); + request->mutable_init_request(); + + EXPECT_THAT(CallBackend(std::move(request)), + IsOkAndHolds(Pointee( + Partially(IgnoringRepeatedFieldOrdering(EquivToProto(R"pb( + init_response { + session_id: 12345 + platform_name: "ifrt_backend" + platform_version: "n/a" + platform_id: 42 + process_index: 1 + runtime_type: "ifrt-service" + devices { + id: 0 + local_device_id: 1234 + local_hardware_id: 1234 + device_kind: "mock" + default_memory_id: 0 + memory_ids: [ 0 ] + attributes { + key: "name" + value { string_value: "device0" } + } + } + devices { + id: 1 + local_device_id: 1234 + local_hardware_id: 1234 + device_kind: "mock" + default_memory_id: 1 + memory_ids: [ 1 ] + attributes { + key: "name" + value { string_value: "device1" } + } + } + memories { + id: 0 + memory_space_kind: "mock" + device_ids: [ 0 ] + } + memories { + id: 1 + memory_space_kind: "mock" + device_ids: [ 1 ] + } + } + )pb")))))); +} +#endif + +// TODO(b/282757875): Use the MockRuntime fixture to cover the error cases for +// MakeArrayFromHostBuffer and CopyToHostBuffer methods as well. + +// Consider redoing the happy-path test below with PjRt CPU-only backend for +// non-SingleDeviceSharding. +TEST_F(IfrtBackendHandlerTest, DisassembleIntoSingleDeviceArraysSucceeds) { + // Set up a mock source array that returns two single device arrays on + // disassembly. + std::vector> single_device_arrays; + single_device_arrays.push_back(tsl::MakeRef()); + single_device_arrays.push_back(tsl::MakeRef()); + tsl::RCReference source_mock_array = + tsl::MakeRef(); + EXPECT_CALL(*source_mock_array, DisassembleIntoSingleDeviceArrays(_)) + .WillOnce(Return(std::move(single_device_arrays))); + + // Inject the mock_array. + TF_ASSERT_OK_AND_ASSIGN(auto array_handle, + MakeTestArray(std::move(source_mock_array))); + + // Disassemble. + auto disassemble_request = NewIfrtRequest(NewOpId()); + disassemble_request->mutable_disassemble_into_single_device_arrays_request() + ->set_array_handle(array_handle); + TF_ASSERT_OK_AND_ASSIGN(auto disassemble_response, + CallBackend(std::move(disassemble_request))); + + // We must have gotten back two handles corresponding to the two single device + // arrays we injected. + EXPECT_THAT( + disassemble_response->disassemble_into_single_device_arrays_response() + .single_device_array_handles(), + SizeIs(2)); +} + +TEST_F(IfrtBackendHandlerTest, MakeArrayFromHostBufferSuccess) { + // Given the below shape, dtype, and compact byte_strides, the size of the + // array data needs to be 480 bytes. + const uint64_t kHostBufferHandle = 1234; + ASSERT_THAT( + host_buffer_store_->Store(kHostBufferHandle, std::string(480, 'a')), + IsOk()); + + auto ifrt_request = NewIfrtRequest(NewOpId()); + { + auto* make_array = + ifrt_request->mutable_make_array_from_host_buffer_request(); + ASSERT_TRUE( + TextFormat::ParseFromString(R"pb( + dtype: DTYPE_F64 + shape { dimensions: [ 5, 3, 4 ] } + byte_strides { strides: [ 8, 40, 120 ] } + )pb", + make_array)); + make_array->set_host_buffer_handle(kHostBufferHandle); + TF_ASSERT_OK_AND_ASSIGN(auto* device, mock_client_->LookupDevice(1)); + TF_ASSERT_OK_AND_ASSIGN( + *make_array->mutable_sharding(), + ToShardingProto(*SingleDeviceSharding::Create(device, MemoryKind()))); + } + + const Shape expected_shape({5, 3, 4}); + const std::vector expected_byte_strides_vec = {8, 40, 120}; + const std::optional> expected_byte_strides = + absl::Span(expected_byte_strides_vec); + + tsl::RCReference mock_array = + tsl::MakeRef(); + + EXPECT_CALL(*mock_client_, + MakeArrayFromHostBuffer(_, DType(DType::kF64), expected_shape, + expected_byte_strides, _, _, _)) + .WillOnce(Return(std::move(mock_array))); + + TF_ASSERT_OK_AND_ASSIGN(auto response, CallBackend(std::move(ifrt_request))); + EXPECT_NE(response->make_array_from_host_buffer_response().array_handle(), 0); +} + +TEST_F(IfrtBackendHandlerTest, AssembleArrayFromSingleDeviceArrays) { + auto ifrt_request = NewIfrtRequest(NewOpId()); + { + ASSERT_TRUE(TextFormat::ParseFromString( + R"pb( + shape { dimensions: [ 2, 2 ] } + copy_semantics: ARRAY_COPY_SEMANTICS_ALWAYS_COPY + )pb", + ifrt_request + ->mutable_assemble_array_from_single_device_arrays_request())); + TF_ASSERT_OK_AND_ASSIGN(auto* device, mock_client_->LookupDevice(1)); + TF_ASSERT_OK_AND_ASSIGN( + *ifrt_request + ->mutable_assemble_array_from_single_device_arrays_request() + ->mutable_sharding(), + ToShardingProto(*SingleDeviceSharding::Create(device, MemoryKind()))); + } + + std::vector> single_device_arrays; + for (int i = 0; i < 2; ++i) { + auto array = tsl::MakeRef(); + single_device_arrays.push_back(array); + + TF_ASSERT_OK_AND_ASSIGN(uint64_t array_handle, MakeTestArray(array)); + ifrt_request->mutable_assemble_array_from_single_device_arrays_request() + ->add_single_device_array_handles(array_handle); + } + + tsl::RCReference result = + tsl::MakeRef(); + const Shape expected_shape({2, 2}); + + EXPECT_CALL(*mock_client_, + AssembleArrayFromSingleDeviceArrays( + expected_shape, _, ElementsAreArray(single_device_arrays), _)) + .WillOnce(Return(std::move(result))); + + TF_ASSERT_OK_AND_ASSIGN(auto response, CallBackend(std::move(ifrt_request))); + EXPECT_NE(response->assemble_array_from_single_device_arrays_response() + .array_handle(), + 0); +} + +TEST_F(IfrtBackendHandlerTest, CopyToHostSuccess) { + Shape shape({5, 3, 4}); + tsl::RCReference array = + tsl::MakeRef(); + ON_CALL(*array, shape()).WillByDefault(ReturnRef(shape)); + ON_CALL(*array, dtype()).WillByDefault(Return(DType(DType::kF64))); + + TF_ASSERT_OK_AND_ASSIGN(auto array_handle, MakeTestArray(array)); + + auto ifrt_request = NewIfrtRequest(NewOpId()); + auto* copy_to_host = ifrt_request->mutable_copy_to_host_buffer_request(); + ASSERT_TRUE( + TextFormat::ParseFromString(R"pb( + byte_strides { strides: [ 8, 40, 120 ] } + )pb", + copy_to_host)); + copy_to_host->set_array_handle(array_handle); + const uint64_t host_buffer_handle = NewHostBufferHandle(); + copy_to_host->set_host_buffer_handle(host_buffer_handle); + + const std::vector expected_byte_strides_vec = {8, 40, 120}; + const std::optional> expected_byte_strides = + absl::Span(expected_byte_strides_vec); + EXPECT_CALL(*array, CopyToHostBuffer(_, expected_byte_strides, _)) + .WillOnce(Return(Future(absl::OkStatus()))); + + TF_ASSERT_OK_AND_ASSIGN(auto response, CallBackend(std::move(ifrt_request))); + // Given the above shape, dtype, and compact byte_strides, the size of the + // array data needs to be 480 bytes. + EXPECT_THAT(host_buffer_store_->Lookup(host_buffer_handle), + IsOkAndHolds(Pointee(SizeIs(480)))); +} + +TEST_F(IfrtBackendHandlerTest, CopyToHostFailsWithNonExistentArrays) { + auto ifrt_request = NewIfrtRequest(NewOpId()); + ASSERT_TRUE(TextFormat::ParseFromString( + R"pb( + byte_strides { strides: [ 8, 40, 120 ] } + )pb", + ifrt_request->mutable_copy_to_host_buffer_request())); + ifrt_request->mutable_copy_to_host_buffer_request()->set_array_handle(0); + + EXPECT_THAT(CallBackend(std::move(ifrt_request)), + StatusIs(absl::StatusCode::kNotFound)); +} + +TEST_F(IfrtBackendHandlerTest, + DisassembleIntoSingleArrayFailsWhenBackendRuntimeFails) { + // Set up a mock source array that fails the disassembly. + constexpr absl::string_view kDisassembleErrorMessage = + "Some test-injected error message that is unlikely to match other error " + "messages - 1234"; + tsl::RCReference source_mock_array = + tsl::MakeRef(); + EXPECT_CALL(*source_mock_array, DisassembleIntoSingleDeviceArrays(_)) + .WillOnce(Return(absl::UnknownError(kDisassembleErrorMessage))); + + // Set up the mock client to return the source_mock_array when the test tries + // to MakeArrayFromHostBuffer. + TF_ASSERT_OK_AND_ASSIGN(auto array_handle, + MakeTestArray(std::move(source_mock_array))); + + // Disassembly must fail with the error we injected. + auto disassemble_request = NewIfrtRequest(NewOpId()); + disassemble_request->mutable_disassemble_into_single_device_arrays_request() + ->set_array_handle(array_handle); + ASSERT_THAT( + CallBackend(std::move(disassemble_request)), + StatusIs(absl::StatusCode::kUnknown, StrEq(kDisassembleErrorMessage))); +} + +TEST_F(IfrtBackendHandlerTest, ReshardSuccess) { + auto src_mock_array = tsl::MakeRef(); + auto resharded_mock_array = tsl::MakeRef(); + EXPECT_CALL(*src_mock_array, Reshard(_, _)) + .WillOnce(Return(std::move(resharded_mock_array))); + TF_ASSERT_OK_AND_ASSIGN(auto src_array_handle, + MakeTestArray(std::move(src_mock_array))); + + auto ifrt_request = NewIfrtRequest(NewOpId()); + auto* reshard_request = ifrt_request->mutable_reshard_request(); + reshard_request->set_array_handle(src_array_handle); + reshard_request->set_copy_semantics(proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY); + TF_ASSERT_OK_AND_ASSIGN(auto* device, mock_client_->LookupDevice(1)); + TF_ASSERT_OK_AND_ASSIGN( + *ifrt_request->mutable_reshard_request()->mutable_sharding(), + ToShardingProto(*SingleDeviceSharding::Create(device, MemoryKind()))); + + TF_ASSERT_OK_AND_ASSIGN(auto response, CallBackend(std::move(ifrt_request))); + + EXPECT_THAT(tsl::StatusFromProto(response->response_metadata().status()), + IsOk()); + EXPECT_NE(response->reshard_response().array_handle(), 0); +} + +TEST_F(IfrtBackendHandlerTest, FullyReplicatedShardSuccess) { + auto fully_replicated_mock_array = tsl::MakeRef(); + auto resultant_array = tsl::MakeRef(); + EXPECT_CALL(*fully_replicated_mock_array, FullyReplicatedShard(_)) + .WillOnce(Return(std::move(resultant_array))); + TF_ASSERT_OK_AND_ASSIGN( + auto fully_replicated_array_handle, + MakeTestArray(std::move(fully_replicated_mock_array))); + + auto ifrt_request = NewIfrtRequest(NewOpId()); + auto* fully_replicated_shard_request = + ifrt_request->mutable_fully_replicated_shard_request(); + fully_replicated_shard_request->set_array_handle( + fully_replicated_array_handle); + fully_replicated_shard_request->set_copy_semantics( + proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY); + + TF_ASSERT_OK_AND_ASSIGN(auto response, CallBackend(std::move(ifrt_request))); + EXPECT_NE(response->fully_replicated_shard_response().array_handle(), 0); +} + +TEST_F(IfrtBackendHandlerTest, FullyReplicatedShardFailure) { + auto fully_replicated_mock_array = tsl::MakeRef(); + EXPECT_CALL(*fully_replicated_mock_array, FullyReplicatedShard(_)) + .WillOnce(Return(absl::UnknownError("injected error"))); + TF_ASSERT_OK_AND_ASSIGN( + auto fully_replicated_array_handle, + MakeTestArray(std::move(fully_replicated_mock_array))); + + auto ifrt_request = NewIfrtRequest(NewOpId()); + auto* fully_replicated_shard_request = + ifrt_request->mutable_fully_replicated_shard_request(); + fully_replicated_shard_request->set_array_handle( + fully_replicated_array_handle); + fully_replicated_shard_request->set_copy_semantics( + proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY); + + EXPECT_THAT(CallBackend(std::move(ifrt_request)), + StatusIs(absl::StatusCode::kUnknown, StrEq("injected error"))); +} + +TEST_F(IfrtBackendHandlerTest, + FullyReplicatedShardFailsWithNonExistentArrayHandle) { + auto ifrt_request = NewIfrtRequest(NewOpId()); + auto* fully_replicated_shard_request = + ifrt_request->mutable_fully_replicated_shard_request(); + fully_replicated_shard_request->set_array_handle(0); + fully_replicated_shard_request->set_copy_semantics( + proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY); + + EXPECT_THAT(CallBackend(std::move(ifrt_request)), + StatusIs(absl::StatusCode::kNotFound)); +} + +TEST_F(IfrtBackendHandlerTest, ReshardFailsWhenTheBackendFails) { + auto mock_array = tsl::MakeRef(); + EXPECT_CALL(*mock_array, Reshard(_, _)) + .WillOnce(Return(absl::UnknownError("injected error"))); + TF_ASSERT_OK_AND_ASSIGN(auto array_handle, + MakeTestArray(std::move(mock_array))); + + auto ifrt_request = NewIfrtRequest(NewOpId()); + auto* reshard_request = ifrt_request->mutable_reshard_request(); + reshard_request->set_array_handle(array_handle); + reshard_request->set_copy_semantics(proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY); + TF_ASSERT_OK_AND_ASSIGN(auto* device, mock_client_->LookupDevice(1)); + TF_ASSERT_OK_AND_ASSIGN( + *ifrt_request->mutable_reshard_request()->mutable_sharding(), + ToShardingProto(*SingleDeviceSharding::Create(device, MemoryKind()))); + + EXPECT_THAT(CallBackend(std::move(ifrt_request)), + StatusIs(absl::StatusCode::kUnknown, StrEq("injected error"))); +} + +TEST_F(IfrtBackendHandlerTest, ReshardFailsWithNonExistentArrayHandle) { + auto ifrt_request = NewIfrtRequest(NewOpId()); + auto* reshard_request = ifrt_request->mutable_reshard_request(); + reshard_request->set_array_handle(0); + reshard_request->set_copy_semantics(proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY); + reshard_request->mutable_sharding(); + + EXPECT_THAT(CallBackend(std::move(ifrt_request)), + StatusIs(absl::StatusCode::kNotFound)); +} + +TEST_F(IfrtBackendHandlerTest, + CheckArrayReadyRequestRelaysTheResultFromBackend) { + auto mock_array = tsl::MakeRef(); + EXPECT_CALL(*mock_array, GetReadyFuture()) + .WillOnce(Return(Future(absl::OkStatus()))) + .WillOnce( + Return(Future(absl::UnknownError("injected error")))); + TF_ASSERT_OK_AND_ASSIGN(auto array_handle, + MakeTestArray(std::move(mock_array))); + + { + auto ifrt_request = NewIfrtRequest(NewOpId()); + ifrt_request->mutable_check_array_ready_request()->set_array_handle( + array_handle); + TF_ASSERT_OK_AND_ASSIGN(auto ifrt_response, + CallBackend(std::move(ifrt_request))); + + EXPECT_THAT(ifrt_response->response_metadata().status().code(), + tensorflow::error::OK); + EXPECT_TRUE(ifrt_response->has_check_array_ready_response()); + } + + { + auto ifrt_request = NewIfrtRequest(NewOpId()); + ifrt_request->mutable_check_array_ready_request()->set_array_handle( + array_handle); + EXPECT_THAT(CallBackend(std::move(ifrt_request)), + StatusIs(absl::StatusCode::kUnknown, StrEq("injected error"))); + } +} + +TEST_F(IfrtBackendHandlerTest, + CheckArrayReadyRequestFailsWithNonExistentArrayHandle) { + auto ifrt_request = NewIfrtRequest(NewOpId()); + ifrt_request->mutable_check_array_ready_request()->set_array_handle(0); + EXPECT_THAT(CallBackend(std::move(ifrt_request)), + StatusIs(absl::StatusCode::kNotFound)); +} + +TEST_F(IfrtBackendHandlerTest, DeleteArraySuccess) { + tsl::RCReference mock_array = + tsl::MakeRef(); + EXPECT_CALL(*mock_array, Delete()) + .WillOnce(Return(Future(absl::OkStatus()))); + TF_ASSERT_OK_AND_ASSIGN(auto array_handle, + MakeTestArray(std::move(mock_array))); + + uint64_t op_id = NewOpId(); + auto ifrt_request = NewIfrtRequest(op_id); + ifrt_request->mutable_delete_array_request()->set_array_handle(array_handle); + TF_ASSERT_OK_AND_ASSIGN(auto resp, CallBackend(std::move(ifrt_request))); + EXPECT_THAT(tsl::StatusFromProto(resp->response_metadata().status()), IsOk()); + EXPECT_NE(resp->delete_array_response().deletion_future_handle(), 0); +} + +TEST_F(IfrtBackendHandlerTest, DeleteArrayFailsWithNonExistentArrayHandle) { + auto ifrt_request = NewIfrtRequest(NewOpId()); + ifrt_request->mutable_delete_array_request()->set_array_handle(0); + EXPECT_THAT(CallBackend(std::move(ifrt_request)), + StatusIs(absl::StatusCode::kNotFound)); +} + +TEST_F(IfrtBackendHandlerTest, + IsDeleteRelaysBackTheReturnValueFromBackendRuntime) { + tsl::RCReference mock_array = + tsl::MakeRef(); + + EXPECT_CALL(*mock_array, IsDeleted()) + .WillOnce(Return(true)) + .WillOnce(Return(false)); + + TF_ASSERT_OK_AND_ASSIGN(auto array_handle, + MakeTestArray(std::move(mock_array))); + + auto ifrt_request = NewIfrtRequest(NewOpId()); + ifrt_request->mutable_is_array_deleted_request()->set_array_handle( + array_handle); + TF_ASSERT_OK_AND_ASSIGN(auto resp, CallBackend(std::move(ifrt_request))); + EXPECT_TRUE(resp->is_array_deleted_response().deleted()); + + ifrt_request = NewIfrtRequest(NewOpId()); + ifrt_request->mutable_is_array_deleted_request()->set_array_handle( + array_handle); + TF_ASSERT_OK_AND_ASSIGN(resp, CallBackend(std::move(ifrt_request))); + EXPECT_FALSE(resp->is_array_deleted_response().deleted()); +} + +TEST_F(IfrtBackendHandlerTest, IsDeleteFailsForNonExistentArrays) { + auto ifrt_request = NewIfrtRequest(NewOpId()); + ifrt_request->mutable_is_array_deleted_request()->set_array_handle(0); + EXPECT_THAT(CallBackend(std::move(ifrt_request)), + StatusIs(absl::StatusCode::kNotFound)); +} + +TEST_F(IfrtBackendHandlerTest, DestructArrayTest) { + tsl::RCReference mock_array = + tsl::MakeRef(); + TF_ASSERT_OK_AND_ASSIGN(auto array_handle, + MakeTestArray(std::move(mock_array))); + + auto ifrt_request = NewIfrtRequest(NewOpId()); + ifrt_request->mutable_destruct_array_request()->set_array_handle( + array_handle); + TF_ASSERT_OK_AND_ASSIGN(auto ifrt_resp, CallBackend(std::move(ifrt_request))); + EXPECT_TRUE(ifrt_resp->has_destruct_array_response()); + + // Retrying DestructArray should fail. And, this establishes that: (1) the + // handle no longer exists on the server, (2) DestructArray fails for + // non-existent arrays and (3) DestructArray is not idempotent. + ifrt_request = NewIfrtRequest(NewOpId()); + ifrt_request->mutable_destruct_array_request()->set_array_handle( + array_handle); + EXPECT_THAT(CallBackend(std::move(ifrt_request)), + StatusIs(absl::StatusCode::kNotFound)); +} + +// TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS +#if defined(PLATFORM_GOOGLE) +TEST_F(IfrtBackendHandlerTest, CompileSuccess) { + std::vector devices(4); + for (int i = 0; i < 4; ++i) { + EXPECT_CALL(devices[i], global_device_id()) + .WillOnce(Return(xla::PjRtGlobalDeviceId(i))); + } + + std::vector + addressable_device_logical_ids; + std::vector addressable_devices; + for (int i = 0; i < 4; ++i) { + addressable_device_logical_ids.push_back( + {.replica = i / 2, .partition = i % 2}); + addressable_devices.push_back(&devices[i]); + } + + auto executable = std::make_unique(); + EXPECT_CALL(*executable, name()).WillOnce(Return("executable_name")); + EXPECT_CALL(*executable, num_devices()).WillOnce(Return(4)); + EXPECT_CALL(*executable, addressable_device_logical_ids()) + .WillOnce(Return(absl::MakeSpan(addressable_device_logical_ids))); + EXPECT_CALL(*executable, addressable_devices()) + .WillOnce(Return(absl::MakeSpan(addressable_devices))); + EXPECT_CALL(*executable, Fingerprint()).WillOnce(Return("fingerprint")); + + EXPECT_THAT(CompileTestLoadedExecutable(std::move(executable)), + IsOkAndHolds(Partially(EquivToProto(R"pb( + name: "executable_name" + num_devices: 4 + addressable_device_logical_ids { replica: 0 partition: 0 } + addressable_device_logical_ids { replica: 0 partition: 1 } + addressable_device_logical_ids { replica: 1 partition: 0 } + addressable_device_logical_ids { replica: 1 partition: 1 } + addressable_device_ids: [ 0, 1, 2, 3 ] + fingerprint_value: "fingerprint" + )pb")))); +} +#endif + +TEST_F(IfrtBackendHandlerTest, CompileFailure) { + ASSERT_THAT( + CompileTestLoadedExecutable(absl::InternalError("injected error")), + StatusIs(absl::StatusCode::kInternal, StrEq("injected error"))); +} + +// TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS +#if defined(PLATFORM_GOOGLE) +TEST_F(IfrtBackendHandlerTest, LoadedExecutableMetadata) { + MockLoadedExecutable* executable; + uint64_t handle; + { + auto e = std::make_unique(); + executable = e.get(); + TF_ASSERT_OK_AND_ASSIGN(CompileResponse response, + CompileTestLoadedExecutable(std::move(e))); + handle = response.loaded_executable_handle(); + } + + { + OpSharding op_sharding1; + ASSERT_TRUE( + TextFormat::ParseFromString(R"pb(type: REPLICATED)pb", &op_sharding1)); + + OpSharding op_sharding2; + ASSERT_TRUE(TextFormat::ParseFromString( + R"pb(type: OTHER + tile_shape { + element_type: BF16 + dimensions: [ 2, 2 ] + } + tile_assignment_dimensions: [ 0, 1 ])pb", + &op_sharding2)); + + EXPECT_CALL(*executable, GetParameterShardings()) + .WillOnce(Return(std::vector{op_sharding1, op_sharding2})); + + EXPECT_CALL(*executable, GetOutputShardings()) + .WillOnce(Return(std::vector{op_sharding1})); + + EXPECT_CALL(*executable, GetParameterLayouts()) + .WillOnce(Return(std::vector{ + xla::LayoutUtil::MakeDescendingLayout(/*rank=*/1), + xla::LayoutUtil::MakeDescendingLayout(/*rank=*/2), + })); + EXPECT_CALL(*executable, GetOutputLayouts()) + .WillOnce(Return(std::vector{ + xla::LayoutUtil::MakeDescendingLayout(/*rank=*/2), + })); + EXPECT_CALL(*executable, GetOutputMemoryKinds()) + .WillOnce(Return(std::vector>{{"foo"}})); + + auto request = NewIfrtRequest(NewOpId()); + LoadedExecutableMetadataRequest* metadata_request = + request->mutable_loaded_executable_metadata_request(); + metadata_request->set_loaded_executable_handle(handle); + + EXPECT_THAT(CallBackend(std::move(request)), + IsOkAndHolds(Pointee(Partially(EquivToProto(R"pb( + loaded_executable_metadata_response { + parameter_shardings { + shardings { type: REPLICATED } + shardings { + type: OTHER + tile_shape { + element_type: BF16 + dimensions: [ 2, 2 ] + } + tile_assignment_dimensions: [ 0, 1 ] + } + } + output_shardings { shardings { type: REPLICATED } } + parameter_layouts_list { + layouts { minor_to_major: 0 } + layouts { minor_to_major: [ 1, 0 ] } + } + output_layouts_list { layouts { minor_to_major: [ 1, 0 ] } } + output_memory_kinds { + memory_kind_lists { memory_kinds: [ "foo" ] } + } + } + )pb"))))); + } + + { + EXPECT_CALL(*executable, GetParameterShardings()) + .WillOnce(Return(std::nullopt)); + EXPECT_CALL(*executable, GetOutputShardings()) + .WillOnce(Return(std::nullopt)); + EXPECT_CALL(*executable, GetParameterLayouts()) + .WillOnce(Return(absl::UnimplementedError("unimplemented"))); + EXPECT_CALL(*executable, GetOutputLayouts()) + .WillOnce(Return(absl::UnimplementedError("unimplemented"))); + EXPECT_CALL(*executable, GetOutputMemoryKinds()) + .WillOnce(Return(std::vector>{})); + + auto request = NewIfrtRequest(NewOpId()); + LoadedExecutableMetadataRequest* metadata_request = + request->mutable_loaded_executable_metadata_request(); + metadata_request->set_loaded_executable_handle(handle); + + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr response, + CallBackend(std::move(request))); + const auto& metadata_response = + response->loaded_executable_metadata_response(); + EXPECT_FALSE(metadata_response.has_parameter_shardings()); + EXPECT_FALSE(metadata_response.has_output_shardings()); + EXPECT_TRUE(metadata_response.has_parameter_layouts_error()); + EXPECT_TRUE(metadata_response.has_output_layouts_error()); + } +} +#endif + +// TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS +#if defined(PLATFORM_GOOGLE) +TEST_F(IfrtBackendHandlerTest, LoadedExecutableExecute) { + MockDevice device; + ON_CALL(device, global_device_id()) + .WillByDefault(Return(xla::PjRtGlobalDeviceId(0))); + + MockLoadedExecutable* executable; + uint64_t handle; + { + auto e = std::make_unique(); + executable = e.get(); + TF_ASSERT_OK_AND_ASSIGN(CompileResponse response, + CompileTestLoadedExecutable(std::move(e))); + handle = response.loaded_executable_handle(); + } + + constexpr int kNumArgs = 3; + constexpr int kNumOutputs = 2; + + Shape shape({2, 2}); + auto sharding = SingleDeviceSharding::Create(&device, MemoryKind()); + + auto make_array = [&]() { + auto array = tsl::MakeRef(); + ON_CALL(*array, dtype()).WillByDefault(Return(DType(DType::kF32))); + ON_CALL(*array, shape()).WillByDefault(ReturnRef(shape)); + ON_CALL(*array, sharding()).WillByDefault(ReturnRef(*sharding)); + return array; + }; + + std::vector> outputs; + outputs.reserve(kNumOutputs); + for (int i = 0; i < kNumOutputs; ++i) { + outputs.push_back(make_array()); + } + + EXPECT_CALL(*executable, Execute(SizeIs(kNumArgs), _, _)) + .WillOnce( + Invoke([&](absl::Span> args, + const xla::ifrt::LoadedExecutable::ExecuteOptions& options, + std::optional devices) + -> absl::StatusOr { + return LoadedExecutable::ExecuteResult{ + .status = + Future(absl::InternalError("injected error")), + .outputs = outputs, + }; + })); + + auto request = NewIfrtRequest(NewOpId()); + LoadedExecutableExecuteRequest* execute_request = + request->mutable_loaded_executable_execute_request(); + for (int i = 0; i < kNumArgs; ++i) { + TF_ASSERT_OK_AND_ASSIGN(uint64_t arg_handle, MakeTestArray(make_array())); + execute_request->add_args_handles(arg_handle); + } + execute_request->set_loaded_executable_handle(handle); + TF_ASSERT_OK_AND_ASSIGN( + *execute_request->mutable_execute_options(), + xla::ifrt::LoadedExecutable::ExecuteOptions().ToProto()); + + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr response, + CallBackend(std::move(request))); + EXPECT_THAT(response, Pointee(Partially(EquivToProto(R"pb( + loaded_executable_execute_response { + outputs { + dtype: DTYPE_F32 + shape { dimensions: [ 2, 2 ] } + } + outputs { + dtype: DTYPE_F32 + shape { dimensions: [ 2, 2 ] } + } + } + )pb")))); + TF_ASSERT_OK_AND_ASSIGN( + auto sharding_proto, + ToShardingProto(*SingleDeviceSharding::Create(&device, MemoryKind()))); + for (const auto& output : + response->loaded_executable_execute_response().outputs()) { + EXPECT_THAT(output.sharding(), EquivToProto(sharding_proto)); + EXPECT_NE(output.array_handle(), 0); + } + + EXPECT_THAT( + CheckFuture( + response->loaded_executable_execute_response().status_handle()), + StatusIs(absl::StatusCode::kInternal, StrEq("injected error"))); + + // The second call to `CheckFuture` fails since `CheckFuture` above performs a + // destructive read. + EXPECT_THAT( + CheckFuture( + response->loaded_executable_execute_response().status_handle()), + StatusIs(absl::StatusCode::kNotFound, + HasSubstr("Unknown future handle"))); +} +#endif + +// TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS +#if defined(PLATFORM_GOOGLE) +TEST_F(IfrtBackendHandlerTest, LoadedExecutableDelete) { + MockLoadedExecutable* executable; + uint64_t handle; + { + auto e = std::make_unique(); + executable = e.get(); + TF_ASSERT_OK_AND_ASSIGN(CompileResponse response, + CompileTestLoadedExecutable(std::move(e))); + handle = response.loaded_executable_handle(); + } + + { + EXPECT_CALL(*executable, Delete()) + .WillOnce(Return(Future(absl::OkStatus()))); + + auto request = NewIfrtRequest(NewOpId()); + LoadedExecutableDeleteRequest* delete_request = + request->mutable_loaded_executable_delete_request(); + delete_request->set_loaded_executable_handle(handle); + + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr response, + CallBackend(std::move(request))); + ASSERT_TRUE(response->has_loaded_executable_delete_response()); + + EXPECT_THAT( + CheckFuture( + response->loaded_executable_delete_response().future_handle()), + IsOk()); + } + + { + EXPECT_CALL(*executable, IsDeleted()).WillOnce(Return(true)); + + auto request = NewIfrtRequest(NewOpId()); + LoadedExecutableIsDeletedRequest* is_deleted_request = + request->mutable_loaded_executable_is_deleted_request(); + is_deleted_request->set_loaded_executable_handle(handle); + + EXPECT_THAT(CallBackend(std::move(request)), + IsOkAndHolds(Pointee(Partially(EquivToProto(R"pb( + loaded_executable_is_deleted_response { is_deleted: true } + )pb"))))); + } +} +#endif + +TEST_F(IfrtBackendHandlerTest, LoadedExecutableDestruct) { + MockLoadedExecutable* executable; + uint64_t handle; + { + auto e = std::make_unique(); + executable = e.get(); + TF_ASSERT_OK_AND_ASSIGN(CompileResponse response, + CompileTestLoadedExecutable(std::move(e))); + handle = response.loaded_executable_handle(); + } + + { + auto request = NewIfrtRequest(NewOpId()); + LoadedExecutableDestructRequest* destruct_request = + request->mutable_loaded_executable_destruct_request(); + destruct_request->set_loaded_executable_handle(handle); + + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr response, + CallBackend(std::move(request))); + ASSERT_TRUE(response->has_loaded_executable_destruct_response()); + } + + // Any attempt to access the loaded executable handle should now return an + // error. + { + auto request = NewIfrtRequest(NewOpId()); + LoadedExecutableDestructRequest* destruct_request = + request->mutable_loaded_executable_destruct_request(); + destruct_request->set_loaded_executable_handle(handle); + + EXPECT_THAT(CallBackend(std::move(request)), + StatusIs(absl::StatusCode::kNotFound, + HasSubstr("Unknown loaded executable handle"))); + } +} + +TEST_F(IfrtBackendHandlerTest, LoadedHostCallbackExecute) { + // Build a remote host callback with one F32 argument and one F32 result. + std::vector hcb_args = {{ + .channel_id = 1, + .shape = xla::ShapeUtil::MakeShape(xla::F32, {}), + }}; + std::vector hcb_results = {{ + .channel_id = 2, + .shape = xla::ShapeUtil::MakeShape(xla::F32, {}), + }}; + auto hcb = tsl::MakeRef( + mock_client_, std::move(hcb_args), std::move(hcb_results), + /*queue=*/nullptr); + + // Compile an executable with the above host callback. The resulting loaded + // host callback handle and `xla::HostCallback` are kept for triggering host + // callback execution. + // + // The setup code must use `xla::ifrt::XlaCompileOptions` for now since this + // is the only allowed compile options type that is currently recognized as + // supporting host callbacks. + MockLoadedExecutable* executable; + tsl::RCReference loaded_host_callback; + uint64_t loaded_host_callback_handle; + { + auto request = NewIfrtRequest(NewOpId()); + CompileRequest* compile_request = request->mutable_compile_request(); + + TestProgram program; + TF_ASSERT_OK_AND_ASSIGN(*compile_request->mutable_program(), + Serialize(program)); + xla::ifrt::XlaCompileOptions compile_options; + TF_ASSERT_OK_AND_ASSIGN(*compile_request->mutable_compile_options(), + Serialize(compile_options)); + + TF_ASSERT_OK_AND_ASSIGN(std::string host_callback_serialized, + hcb->Serialize()); + compile_request->add_host_callbacks(std::move(host_callback_serialized)); + + auto e = std::make_unique(); + executable = e.get(); + + EXPECT_CALL(mock_compiler_, Compile(_, _)) + .WillOnce(DoAll( + Invoke( + [&](const std::unique_ptr& program, + const std::unique_ptr& options) { + auto* xla_compile_options = + llvm::cast(options.get()); + auto& loaded_host_callbacks = + xla_compile_options->loaded_host_callbacks; + ASSERT_EQ(loaded_host_callbacks.size(), 1); + loaded_host_callback = loaded_host_callbacks.front(); + }), + Return(ByMove(std::move(e))))); + + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr response, + CallBackend(std::move(request))); + + ASSERT_TRUE(response->has_compile_response()); + CompileResponse compile_response = response->compile_response(); + + loaded_host_callback_handle = + compile_response.loaded_host_callback_handles(0); + ASSERT_THAT(loaded_host_callback, NotNull()); + } + + // Enqueue a host callback execution. This is done on a separate thread since + // `LoadedHostCallbackPollRequest` blocks until there is a pending execution. + auto host_callback_thread = absl::WrapUnique(tsl::Env::Default()->StartThread( + tsl::ThreadOptions(), "HostCallback", [&]() { + xla::Literal x = xla::LiteralUtil::CreateR0(1.0f); + + std::vector operands; + operands.push_back(x.untyped_data()); + + xla::Literal out = xla::LiteralUtil::CreateR0(0.0f); + std::vector results; + results.push_back(out.untyped_data()); + + const xla::HostCallback* xla_host_callback = + &llvm::cast(loaded_host_callback.get()) + ->host_callback(); + ASSERT_THAT( + xla_host_callback->callback(results.data(), operands.data()), + IsOk()); + EXPECT_EQ(out, xla::LiteralUtil::CreateR0(2.0f)); + })); + + // Poll for a host callback execution and verify its argument against the one + // passed by the execution thread above. + uint64_t host_callback_execution_handle; + { + const uint64_t operand_host_buffer_handle = NewHostBufferHandle(); + + auto request = NewIfrtRequest(NewOpId()); + LoadedHostCallbackPollRequest* poll_request = + request->mutable_loaded_host_callback_poll_request(); + poll_request->set_loaded_host_callback_handle(loaded_host_callback_handle); + poll_request->set_operand_host_buffer_handle(operand_host_buffer_handle); + + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr response, + CallBackend(std::move(request))); + + ASSERT_TRUE(response->has_loaded_host_callback_poll_response()); + const LoadedHostCallbackPollResponse& poll_response = + response->loaded_host_callback_poll_response(); + host_callback_execution_handle = + poll_response.host_callback_execution_handle(); + + TF_ASSERT_OK_AND_ASSIGN( + const std::shared_ptr operands, + host_buffer_store_->Lookup(operand_host_buffer_handle)); + EXPECT_EQ(xla::BorrowingLiteral(operands->data(), + xla::ShapeUtil::MakeShape(xla::F32, {})), + xla::LiteralUtil::CreateR0(1.0f)); + } + + // Return the execution result. This will unblock the execution thread above, + // which also verifies the result. + { + auto result = xla::LiteralUtil::CreateR0(2.0f); + std::string result_buffer(absl::string_view( + static_cast(result.untyped_data()), result.size_bytes())); + + const uint64_t result_host_buffer_handle = NewHostBufferHandle(); + ASSERT_THAT(host_buffer_store_->Store(result_host_buffer_handle, + std::move(result_buffer)), + IsOk()); + + auto request = NewIfrtRequest(NewOpId()); + LoadedHostCallbackReturnRequest* ret_request = + request->mutable_loaded_host_callback_return_request(); + ret_request->set_host_callback_execution_handle( + host_callback_execution_handle); + ret_request->set_result_host_buffer_handle(result_host_buffer_handle); + + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr response, + CallBackend(std::move(request))); + ASSERT_TRUE(response->has_loaded_host_callback_return_response()); + } +} + +TEST_F(IfrtBackendHandlerTest, GetDefaultDeviceAssignmentSuccess) { + const int kNumReplicas = 1; + const int kNumPartitions = 3; + + EXPECT_CALL(*mock_client_, + GetDefaultDeviceAssignment(kNumReplicas, kNumPartitions)) + .WillOnce(Return(xla::DeviceAssignment(kNumReplicas, kNumPartitions))); + + auto request = NewIfrtRequest(NewOpId()); + auto* default_device_assignment_request = + request->mutable_get_default_device_assignment_request(); + default_device_assignment_request->set_num_replicas(kNumReplicas); + default_device_assignment_request->set_num_partitions(kNumPartitions); + + TF_ASSERT_OK_AND_ASSIGN(auto response, CallBackend(std::move(request))); + TF_ASSERT_OK_AND_ASSIGN(auto assignment_got, + xla::DeviceAssignment::Deserialize( + response->get_default_device_assignment_response() + .device_assignment())); + EXPECT_EQ(assignment_got->replica_count(), kNumReplicas); + EXPECT_EQ(assignment_got->computation_count(), kNumPartitions); +} + +TEST_F(IfrtBackendHandlerTest, + GetDefaultDeviceAssignmentFailsIfTheBackendFails) { + const int kNumReplicas = 1; + const int kNumPartitions = 3; + + EXPECT_CALL(*mock_client_, + GetDefaultDeviceAssignment(kNumReplicas, kNumPartitions)) + .WillOnce(Return(absl::UnknownError("injected error"))); + + auto request = NewIfrtRequest(NewOpId()); + auto* default_device_assignment_request = + request->mutable_get_default_device_assignment_request(); + default_device_assignment_request->set_num_replicas(kNumReplicas); + default_device_assignment_request->set_num_partitions(kNumPartitions); + + EXPECT_THAT(CallBackend(std::move(request)), + StatusIs(absl::StatusCode::kUnknown, StrEq("injected error"))); +} + +} // namespace +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_session_handler.cc b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_session_handler.cc new file mode 100644 index 00000000000000..a4b0a95f2866f3 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_session_handler.cc @@ -0,0 +1,117 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/server/ifrt_session_handler.h" + +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/proto_util.h" + +// The tsl include below is needed only for the Status macros such as +// ASSIGN_OR_RETURN, since the OSS absl package does not have the counterparts +// yet. +#include "tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +absl::StatusOr> IfrtSessionHandler::Create( + uint64_t id, BackendFactory backend_factory) { + if (backend_factory == nullptr) { + return absl::InvalidArgumentError("BackendFactory cannot be nullptr."); + } + return absl::WrapUnique( + new IfrtSessionHandler(id, std::move(backend_factory))); +} + +IfrtSessionHandler::IfrtSessionHandler(uint64_t id, + BackendFactory backend_factory) + : session_id_(id), backend_factory_(std::move(backend_factory)) {} + +void IfrtSessionHandler::NewIncomingRequest( + std::unique_ptr request, + std::function)> on_done) { + VLOG(2) << "NewIncomingRequest: " << request->DebugString(); + + const uint64_t op_id = request->request_metadata().op_id(); + + // The current implementation exploits the async nature of the backend_ IFRT + // client to minimize the amount of work we do per request. However, using a + // threadpool here might make sense as a performance optimization. + + auto result = [&]() -> Future { + if (request->has_init_request()) { + return ProcessInitRequest(std::move(request)); + } + if (auto status = SetupBackendIfNeeded(); !status.ok()) { + return Future(status); + } + absl::ReaderMutexLock read_lock(&backend_mu_); + return backend_->Process(std::move(request)); + }(); + + // Consider maintaining a count of in-flight requests (that won't complete + // until the following OnReady callback happens) so we can safely deleting the + // reactor_. + result.OnReady([op_id, on_done = std::move(on_done)]( + absl::StatusOr> result) { + if (result.ok()) { + on_done(*std::move(result)); + } else { + on_done(NewIfrtResponse(op_id, result.status())); + } + }); +} + +Future IfrtSessionHandler::ProcessInitRequest( + std::unique_ptr request) { + absl::MutexLock lock(&backend_mu_); + if (backend_ != nullptr) { + // Currently backends cannot be reinitialized. + return Future(absl::FailedPreconditionError( + "This session has already been initialized.")); + } + + auto backend = backend_factory_(session_id_); + if (!backend.ok()) { + return Future(backend.status()); + } + backend_ = *std::move(backend); + + return backend_->Process(std::move(request)); +} + +absl::Status IfrtSessionHandler::SetupBackendIfNeeded() { + absl::MutexLock lock(&backend_mu_); + if (backend_ != nullptr) { + return absl::OkStatus(); + } + TF_ASSIGN_OR_RETURN(backend_, backend_factory_(session_id_)); + return absl::OkStatus(); +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_session_handler.h b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_session_handler.h new file mode 100644 index 00000000000000..505341a6934958 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_session_handler.h @@ -0,0 +1,82 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_SERVER_IFRT_SESSION_HANDLER_H_ +#define XLA_PYTHON_IFRT_PROXY_SERVER_IFRT_SESSION_HANDLER_H_ + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/server/ifrt_backend.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// IfrtSessionHandler glues an incoming stream to a stack of backend runtimes +// abstracted out by a `BackendInterface`. It utilizes the provided `Backend` to +// process the incoming client requests after ensuring that dependencies as +// specified by the client are honored and the chunked requests are fully +// re-assembled. +class IfrtSessionHandler { + public: + using BackendFactory = + absl::AnyInvocable>( + uint64_t session_id)>; + + using Response = BackendInterface::Response; + + // Makes a new IfrtSessionHandler with the given Session ID that uniquely + // identifies this session. The backend_factory cannot be a nullptr. + static absl::StatusOr> Create( + uint64_t id, BackendFactory backend_factory); + + uint64_t session_id() const { return session_id_; } + + // Top-level handler the transport implementation calls to hand off a new + // incoming request. `on_done` is called asynchronously to return responses. + void NewIncomingRequest( + std::unique_ptr request, + std::function)> on_done); + + private: + IfrtSessionHandler(uint64_t id, BackendFactory backend_factory); + + // InitRequest is treated somewhat differently than the rest since it triggers + // the creation of the backend_ + Future ProcessInitRequest(std::unique_ptr request) + ABSL_LOCKS_EXCLUDED(backend_mu_); + + // Sets up the backaned_ only if needed - i.e., only if it is a nullptr. + absl::Status SetupBackendIfNeeded() ABSL_LOCKS_EXCLUDED(backend_mu_); + + const uint64_t session_id_; // Unique ID of this Session. + + // The backend_ runtime(s) this session relies on for processing the incoming + // requests. It is instantiated at the start of a new Bidi stream, and + // currently does not change for the life of this object. + BackendFactory backend_factory_; + absl::Mutex backend_mu_; + std::unique_ptr backend_ ABSL_GUARDED_BY(backend_mu_); +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_SERVER_IFRT_SESSION_HANDLER_H_ diff --git a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_session_handler_test.cc b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_session_handler_test.cc new file mode 100644 index 00000000000000..b5a1e0bc316d57 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_session_handler_test.cc @@ -0,0 +1,70 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/server/ifrt_session_handler.h" + +#include +#include +#include + +#include +#include +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt_proxy/server/ifrt_backend.h" +#include "tsl/platform/status_matchers.h" + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +using ::testing::Not; +using ::tsl::testing::IsOk; + +// FakeBackend. Currently: Fails or returns illegal values where possible. +// All other methods return dummy strings or empty vectors. Individual tests +// can make derived classes that override specific methods as needed. +class FakeBackend : public BackendInterface { + public: + FakeBackend() = default; + ~FakeBackend() override = default; + + Future Process( + std::unique_ptr request) override { + return Future(std::make_unique()); + } +}; + +TEST(IfrtSessionHandlerTest, NullptrForBackendMakerFails) { + EXPECT_THAT(IfrtSessionHandler::Create(1234, nullptr), Not(IsOk())); +} + +TEST(IfrtSessionHandlerTest, SuccessfulCreation) { + std::unique_ptr backend = std::make_unique(); + EXPECT_THAT( + IfrtSessionHandler::Create( + 1234, [&](uint64_t session_id) { return std::move(backend); }), + IsOk()); +} + +// TODO(b/282757875) Add "end-to-end" tests that cover the entire path from the +// Server/BidiReactor to the backend. Since IfrtSessionHandler writes the +// responses (IfrtResponse messages) directly to the Bidi Reactor, tests for the +// actual processing of requests need a full server and a fake client that +// allows us retrieve and examine the responses. + +} // namespace +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/server/mock_ifrt_backend.h b/third_party/xla/xla/python/ifrt_proxy/server/mock_ifrt_backend.h new file mode 100644 index 00000000000000..620808f912be87 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/server/mock_ifrt_backend.h @@ -0,0 +1,42 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_SERVER_MOCK_IFRT_BACKEND_H_ +#define XLA_PYTHON_IFRT_PROXY_SERVER_MOCK_IFRT_BACKEND_H_ + +#include + +#include +#include "absl/status/status.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/server/ifrt_backend.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +class MockIfrtBackend final : public BackendInterface { + public: + MOCK_METHOD(Future, Process, (std::unique_ptr request), + (final)); +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_SERVER_MOCK_IFRT_BACKEND_H_ diff --git a/third_party/xla/xla/python/ifrt_proxy/server/version.cc b/third_party/xla/xla/python/ifrt_proxy/server/version.cc new file mode 100644 index 00000000000000..b4f5298203a5ab --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/server/version.cc @@ -0,0 +1,48 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "xla/python/ifrt_proxy/server/version.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +absl::StatusOr ChooseVersion(int client_min_version, + int client_max_version, + int server_min_version, + int server_max_version) { + const int version = std::min(server_max_version, client_max_version); + + if (version < server_min_version || version < client_min_version) { + return absl::InvalidArgumentError(absl::StrCat( + "IFRT Proxy client and server failed to agree on the " + "protocol version; supported versions: client = [", + client_min_version, ", ", client_max_version, "], server = [", + server_min_version, ", ", server_max_version, "]")); + } + + return version; +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/server/version.h b/third_party/xla/xla/python/ifrt_proxy/server/version.h new file mode 100644 index 00000000000000..2556b5656f6188 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/server/version.h @@ -0,0 +1,41 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_SERVER_VERSION_H_ +#define XLA_PYTHON_IFRT_PROXY_SERVER_VERSION_H_ + +#include "absl/status/statusor.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// TODO(b/296144873): Document the version upgrade policy. +inline constexpr int kServerMinVersion = 1; +inline constexpr int kServerMaxVersion = 1; + +// Returns a version that both the client and the server support, or an error if +// there is no such a version. +absl::StatusOr ChooseVersion(int client_min_version, + int client_max_version, + int server_min_version = kServerMinVersion, + int server_max_version = kServerMaxVersion); + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_SERVER_VERSION_H_ diff --git a/third_party/xla/xla/python/ifrt_proxy/server/version_test.cc b/third_party/xla/xla/python/ifrt_proxy/server/version_test.cc new file mode 100644 index 00000000000000..efebcad9d65d9a --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/server/version_test.cc @@ -0,0 +1,69 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "xla/python/ifrt_proxy/server/version.h" + +#include +#include +#include "absl/status/status.h" +#include "tsl/platform/status_matchers.h" + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +using ::tsl::testing::IsOk; +using ::tsl::testing::StatusIs; + +struct Param { + int client_min_version; + int client_max_version; + int server_min_version; + int server_max_version; +}; + +class CompatibleVersionTest : public ::testing::TestWithParam {}; + +TEST_P(CompatibleVersionTest, Verify) { + const Param& param = GetParam(); + EXPECT_THAT(ChooseVersion(param.client_min_version, param.client_max_version, + param.server_min_version, param.server_max_version), + IsOk()); +} + +INSTANTIATE_TEST_SUITE_P(CompatibleVersionTest, CompatibleVersionTest, + ::testing::Values(Param{1, 1, 1, 1}, Param{1, 2, 2, 2}, + Param{2, 2, 1, 2}, + Param{1, 3, 3, 4})); + +class IncompatibleVersionTest : public ::testing::TestWithParam {}; + +TEST_P(IncompatibleVersionTest, Verify) { + const Param& param = GetParam(); + EXPECT_THAT(ChooseVersion(param.client_min_version, param.client_max_version, + param.server_min_version, param.server_max_version), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +INSTANTIATE_TEST_SUITE_P(IncompatibleVersionTest, IncompatibleVersionTest, + ::testing::Values(Param{1, 2, 3, 3}, Param{1, 3, 4, 6}, + Param{1, 1, 2, 2})); + +} // namespace +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/jax_jit.cc b/third_party/xla/xla/python/jax_jit.cc index 41dd2ff5ed5c55..57b907c6643bf4 100644 --- a/third_party/xla/xla/python/jax_jit.cc +++ b/third_party/xla/xla/python/jax_jit.cc @@ -311,7 +311,7 @@ xla::Status ParseArguments(absl::Span positional_args, } } } - return ::tsl::OkStatus(); + return absl::OkStatus(); } void BuildJaxjitSubmodule(py::module& m) { diff --git a/third_party/xla/xla/python/pjrt_ifrt/BUILD b/third_party/xla/xla/python/pjrt_ifrt/BUILD index 8f26cc8906d177..c635ded4633be5 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/BUILD +++ b/third_party/xla/xla/python/pjrt_ifrt/BUILD @@ -1,4 +1,5 @@ load("//xla:xla.bzl", "xla_cc_test") +load("@local_tsl//tsl:tsl.bzl", "internal_visibility") load("@local_tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") load("@local_tsl//tsl/platform:build_config.bzl", "tf_proto_library") @@ -20,15 +21,16 @@ package_group( ) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + ":friends", + ":internal", + ]), ) -exports_files( - [ - "BUILD", - ], - visibility = ["//visibility:public"], -) +exports_files([ + "BUILD", +]) # TODO(hyeontaek): Move this target out of pjrt_ifrt. cc_library( @@ -42,7 +44,6 @@ cc_library( "xla_sharding.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":xla_compiler_proto_cc", "//xla:util", @@ -67,21 +68,18 @@ tf_proto_library( srcs = ["xla_host_callback.proto"], cc_api_version = 2, protodeps = ["//xla:xla_data_proto"], - visibility = ["//visibility:public"], ) tf_proto_library( name = "xla_compiler_proto", srcs = ["xla_compiler.proto"], protodeps = ["//xla/pjrt:compile_options_proto"], - visibility = ["//visibility:public"], ) cc_library( name = "xla_program_serdes", srcs = ["xla_program_serdes.cc"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":xla_ifrt", "//xla/mlir_hlo:mhlo_passes", @@ -124,13 +122,11 @@ tf_proto_library( "//xla:xla_data_proto", "//xla/python/ifrt:types_proto", ], - visibility = ["//visibility:public"], ) cc_library( name = "xla_sharding_serdes", srcs = ["xla_sharding_serdes.cc"], - visibility = ["//visibility:public"], deps = [ ":xla_ifrt", ":xla_sharding_proto_cc", @@ -161,7 +157,6 @@ cc_library( name = "xla_executable_impl_test_lib", testonly = True, srcs = ["xla_executable_impl_test_lib.cc"], - visibility = ["//visibility:public"], deps = [ ":xla_ifrt", "//xla/pjrt:mlir_to_hlo", @@ -222,7 +217,6 @@ cc_library( "pjrt_tuple.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":xla_ifrt", "//xla:literal", @@ -266,7 +260,6 @@ cc_library( name = "tfrt_cpu_client_test_lib", testonly = True, srcs = ["tfrt_cpu_client_test_lib.cc"], - visibility = ["//visibility:public"], deps = [ ":pjrt_ifrt", "//xla/pjrt/cpu:cpu_client", diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc index fe128317a69457..d0f4929bf7c0a0 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc @@ -21,10 +21,13 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/functional/any_invocable.h" #include "absl/memory/memory.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/python/ifrt/client.h" #include "xla/python/ifrt/sharding.h" #include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/python/pjrt_ifrt/pjrt_tuple.h" @@ -54,6 +57,25 @@ std::unique_ptr PjRtClient::Create( return absl::WrapUnique(new PjRtClient(std::move(pjrt_client))); } +absl::flat_hash_map +PjRtClient::attributes() const { + absl::flat_hash_map attributes; + attributes.insert({"supports_executable_serialization", true}); + + if (std::optional plugin_attributes = + pjrt_client_->plugin_attributes(); + plugin_attributes.has_value()) { + attributes.insert( + {"pjrt_c_api_major_version", + ClientAttribute(plugin_attributes->pjrt_c_api_major_version)}); + attributes.insert( + {"pjrt_c_api_minor_version", + ClientAttribute(plugin_attributes->pjrt_c_api_minor_version)}); + } + + return attributes; +} + StatusOr> PjRtClient::CreatePjRtArray( std::shared_ptr pjrt_buffer) { TF_ASSIGN_OR_RETURN(auto array, diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.h b/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.h index 0266c67153838a..0db769072e5e2e 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.h +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.h @@ -111,18 +111,8 @@ class PjRtClient final DCHECK(this); return pjrt_client_->platform_id(); } - absl::flat_hash_map attributes() - const override { - std::optional attributes = - pjrt_client_->plugin_attributes(); - if (!attributes.has_value()) { - return {}; - } - return {{"pjrt_c_api_major_version", - ClientAttribute(attributes->pjrt_c_api_major_version)}, - {"pjrt_c_api_minor_version", - ClientAttribute(attributes->pjrt_c_api_minor_version)}}; - } + + absl::flat_hash_map attributes() const override; int device_count() const override { DCHECK(this); diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc index 66b59d8abf8349..a3c330c7bae1d6 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc @@ -261,6 +261,7 @@ StatusOr> PjRtLoadedExecutable::Create( build_options.use_spmd_partitioning() && build_options.num_partitions() > 1 && (build_options.use_auto_spmd_partitioning() || + build_options.any_allow_spmd_sharding_propagation_to_parameters() || build_options.any_allow_spmd_sharding_propagation_to_output()); TF_ASSIGN_OR_RETURN( auto pjrt_loaded_executable, diff --git a/third_party/xla/xla/python/profiler.cc b/third_party/xla/xla/python/profiler.cc index 53521e25bfbb7c..3a6bdfaf582953 100644 --- a/third_party/xla/xla/python/profiler.cc +++ b/third_party/xla/xla/python/profiler.cc @@ -60,10 +60,10 @@ tensorflow::ProfileOptions DefaultPythonProfileOptions() { } const PLUGIN_Profiler_Api* FindProfilerApi(const PJRT_Api* pjrt_api) { - const PJRT_Structure_Base* next = - reinterpret_cast(pjrt_api->extension_start); + const PJRT_Extension_Base* next = + reinterpret_cast(pjrt_api->extension_start); while (next != nullptr && - next->type != PJRT_Structure_Type::PJRT_Structure_Type_Profiler) { + next->type != PJRT_Extension_Type::PJRT_Extension_Type_Profiler) { next = next->next; } if (next == nullptr) { diff --git a/third_party/xla/xla/python/profiler/internal/BUILD b/third_party/xla/xla/python/profiler/internal/BUILD index 8798ee6602b858..0d9ed5cad4a3ec 100644 --- a/third_party/xla/xla/python/profiler/internal/BUILD +++ b/third_party/xla/xla/python/profiler/internal/BUILD @@ -1,9 +1,11 @@ +load("@local_tsl//tsl:tsl.bzl", "internal_visibility") load("@local_tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") load("@local_tsl//tsl/profiler/builds:build_config.bzl", "tf_profiler_copts") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//xla/python/profiler:__subpackages__"], licenses = ["notice"], ) @@ -14,7 +16,10 @@ cc_library( compatible_with = get_compatible_with_portable(), copts = tf_profiler_copts() + ["-fexceptions"], features = ["-use_header_modules"], # Incompatible with -fexceptions. - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//xla/backends/profiler:__subpackages__", + "//tensorflow/python/profiler/internal:__subpackages__", + ]), deps = [ "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", @@ -37,7 +42,10 @@ cc_library( name = "traceme_wrapper", hdrs = ["traceme_wrapper.h"], copts = tf_profiler_copts(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//xla/python:__pkg__", + "//tensorflow/python/profiler/internal:__pkg__", + ]), deps = [ "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:macros", diff --git a/third_party/xla/xla/python/py_array.cc b/third_party/xla/xla/python/py_array.cc index bed3b172ddbb53..64e4f0fbe866a5 100644 --- a/third_party/xla/xla/python/py_array.cc +++ b/third_party/xla/xla/python/py_array.cc @@ -52,6 +52,9 @@ limitations under the License. #include "xla/python/types.h" #include "xla/python/util.h" #include "xla/shape.h" +#if GOOGLE_CUDA +#include "xla/stream_executor/cuda/cuda_driver.h" +#endif #include "xla/util.h" #include "tsl/platform/statusor.h" @@ -706,6 +709,121 @@ py::dict PyArray::CudaArrayInterface() { return result; } +StatusOr CudaArrayInterfaceToBuffer( + const pybind11::dict& cai, std::shared_ptr client) { +#ifndef GOOGLE_CUDA + throw XlaRuntimeError("This operation requires CUDA support."); +#else + if (!cai.contains("data")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `data`"); + } + if (!cai.contains("shape")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `shape`"); + } + if (!cai.contains("typestr")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `typestr`"); + } + if (!cai.contains("version")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `version`"); + } + auto version = py::cast(cai["version"]); + if (version < 2 || version > 3) { + LOG(WARNING) << "CUDA Array Interface version " << version + << " support is undefined"; + } + auto data = py::cast(cai["data"]); + auto data_value = pybind11::cast(data[0]); + void* data_ptr = reinterpret_cast(data_value); + auto dimensions = pybind11::cast>(cai["shape"]); + if (data_value == 0 && absl::c_find(dimensions, 0) == dimensions.end()) { + return absl::InvalidArgumentError( + "CUDA Array Interface `data`(=NULL) and `shape`(no zero-valued " + "dimensions) are inconsistent"); + } + auto ndim = dimensions.size(); + TF_ASSIGN_OR_RETURN( + PrimitiveType element_type, + DtypeToPrimitiveType(py::dtype::from_args(cai["typestr"]))); + + // cannot determine device_id/stream when device pointer is NULL. + int device_id = + (data_value == 0 + ? 0 + : stream_executor::gpu::CreatedContexts::GetDeviceOrdinal(data_ptr)); + TF_ASSIGN_OR_RETURN(auto device, + client->DeviceFromLocalHardwareId(device_id)); + bool is_default_stream = + data_value == 0 || version == 2 || + (version == 3 && (!cai.contains("stream") || cai["stream"].is_none())); + TF_ASSIGN_OR_RETURN( + std::intptr_t stream, + ([is_default_stream, cai, device]() -> StatusOr { + if (is_default_stream) { + return device->GetStreamForExternalReadyEvents(); + } else { + auto stream_ = py::cast(cai["stream"]); + if (stream_ == 0) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not allow zero stream value"); + } + return stream_; + } + }())); + + std::vector minor_to_major(ndim); + if (cai.contains("strides") && !cai["strides"].is_none() && data_value != 0) { + std::iota(minor_to_major.begin(), minor_to_major.end(), 0); + auto strides = pybind11::cast>(cai["strides"]); + if (strides.size() != ndim) { + return absl::InvalidArgumentError( + "CUDA Array Interface `shape` and `strides` dimensionalities are " + "inconsistent"); + } + absl::c_sort(minor_to_major, [&](int a, int b) { + // If two dimensions have the same stride, prefer the major-to-minor + // interpretation of the ordering, since that's what JAX wants. + return (strides[a] == strides[b] ? b < a : strides[a] < strides[b]); + }); + int64_t stride = ShapeUtil::ByteSizeOfPrimitiveType(element_type); + for (int64_t d : minor_to_major) { + if (dimensions[d] > 1 && strides[d] != stride) { + return absl::UnimplementedError(absl::StrCat( + "Only arrays with trivial (compact) striding are supported; " + "i.e., arrays whose striding represents a transposition of the " + "underlying buffer but not broadcasting. Dimensions were: [%s], " + "strides were [%s].", + absl::StrJoin(dimensions, ","), absl::StrJoin(strides, ","))); + } + stride *= dimensions[d]; + } + } else { + std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0); + } + Shape shape = ShapeUtil::MakeShapeWithDenseLayout(element_type, dimensions, + minor_to_major); + std::function on_delete_callback = []() {}; + TF_ASSIGN_OR_RETURN( + auto pjrt_buffer, + device->client()->CreateViewOfDeviceBuffer( + static_cast(data_ptr), shape, device.get(), on_delete_callback, + stream <= 2 ? std::nullopt : std::make_optional(stream))); + auto* ifrt_client = + llvm::dyn_cast_or_null(client->ifrt_client()); + if (ifrt_client == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + TF_ASSIGN_OR_RETURN(auto ifrt_array, + ifrt_client->CreatePjRtArray(std::move(pjrt_buffer))); + return PyArray::MakeFromSingleDeviceArray(std::move(client), Traceback::Get(), + std::move(ifrt_array), false, true); +#endif // GOOGLE_CUDA +} + Status PyArray::Delete() { for (auto& arr : py_arrays()) { TF_RETURN_IF_ERROR(arr.Delete()); diff --git a/third_party/xla/xla/python/py_array.h b/third_party/xla/xla/python/py_array.h index 745f647913b5f6..bc4e687c7ad646 100644 --- a/third_party/xla/xla/python/py_array.h +++ b/third_party/xla/xla/python/py_array.h @@ -276,6 +276,9 @@ class PyArrayResultHandler { std::vector shape_; }; +StatusOr CudaArrayInterfaceToBuffer( + const pybind11::dict& cai, std::shared_ptr cuda_client); + } // namespace xla #endif // XLA_PYTHON_PY_ARRAY_H_ diff --git a/third_party/xla/xla/python/sharding.cc b/third_party/xla/xla/python/sharding.cc index a34fdba14f61c7..9ca04a54a1b5eb 100644 --- a/third_party/xla/xla/python/sharding.cc +++ b/third_party/xla/xla/python/sharding.cc @@ -284,18 +284,10 @@ GSPMDSharding::GSPMDSharding(py::tuple devices, xla::HloSharding op_sharding, } void RegisterSharding(py::module& m) { - py::object abc_module = py::module::import("abc"); - py::object abc_meta = abc_module.attr("ABCMeta"); - py::object abc_init = abc_module.attr("_abc_init"); - - // NOLINTNEXTLINE(bugprone-unused-raii) - py::class_(m, "Sharding", py::metaclass(abc_meta)); - abc_init(py::type::of()); - - // NOLINTNEXTLINE(bugprone-unused-raii) - py::class_(m, "XLACompatibleSharding", - py::metaclass(abc_meta)); - abc_init(py::type::of()); + py::class_(m, "Sharding").def(py::init<>()); + + py::class_(m, "XLACompatibleSharding") + .def(py::init<>()); py::class_(m, "NamedSharding", py::dynamic_attr()) diff --git a/third_party/xla/xla/python/tpu_driver/BUILD b/third_party/xla/xla/python/tpu_driver/BUILD deleted file mode 100644 index caaf24e3c78cd0..00000000000000 --- a/third_party/xla/xla/python/tpu_driver/BUILD +++ /dev/null @@ -1,131 +0,0 @@ -load( - "//xla/python/tpu_driver:platform/external/tools.bzl", - "external_deps", - "go_grpc_library", -) -load("@local_tsl//tsl:tsl.default.bzl", "tsl_grpc_cc_dependencies") -load("@local_tsl//tsl/platform:build_config.bzl", "tf_proto_library") -load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") - -licenses(["notice"]) - -package( - default_visibility = ["//visibility:public"], -) - -tf_proto_library( - name = "tpu_driver_proto", - srcs = ["tpu_driver.proto"], - cc_api_version = 2, - protodeps = [], - visibility = ["//visibility:public"], -) - -tf_proto_library( - name = "tpu_service_proto", - srcs = ["tpu_service.proto"], - has_services = 1, - cc_api_version = 2, - create_grpc_library = True, - protodeps = [ - ":tpu_driver_proto", - "//xla:xla_data_proto", - "//xla:xla_proto", - "//xla/service:hlo_proto", - ], - use_grpc_namespace = True, - visibility = ["//visibility:public"], -) - -cc_library( - name = "tpu_driver", - srcs = [ - "tpu_driver.cc", - ], - hdrs = [ - "event_id.h", - "platform/external/compat.h", - "tpu_driver.h", - ], - visibility = ["//visibility:public"], - deps = [ - ":tpu_driver_proto_cc", - "//xla:status", - "//xla:statusor", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla:xla_proto_cc", - "//xla/service:hlo_proto_cc", - "@local_tsl//tsl/platform:logging", - ] + external_deps(), -) - -cc_library( - name = "grpc_tpu_driver", - srcs = [ - "grpc_tpu_driver.cc", - ], - hdrs = ["grpc_tpu_driver.h"], - visibility = ["//visibility:public"], - deps = [ - ":tpu_driver", - ":tpu_driver_proto_cc", - ":tpu_service_cc_grpc_proto", - ":tpu_service_proto_cc", - "//xla:status", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/service:hlo_proto_cc", - "@local_tsl//tsl/platform:logging", - ] + tsl_grpc_cc_dependencies() + external_deps(), - alwayslink = 1, -) - -cc_library( - name = "recording_tpu_driver", - srcs = [ - "recording_tpu_driver.cc", - ], - visibility = ["//visibility:public"], - deps = [ - ":tpu_driver", - ":tpu_driver_proto_cc", - ":tpu_service_cc_grpc_proto", - ":tpu_service_proto_cc", - "//xla:status", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla:xla_proto_cc", - "//xla/service:hlo_proto_cc", - "@com_google_absl//absl/base", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:logging", - ] + external_deps(), - alwayslink = 1, -) - -cc_library( - name = "pod_tpu_driver", - srcs = ["pod_tpu_driver.cc"], - visibility = ["//visibility:public"], - deps = [ - ":grpc_tpu_driver", - ":tpu_driver", - ":tpu_driver_proto_cc", - "//xla/pjrt:semaphore", - "//xla/pjrt:worker_thread", - "@com_google_absl//absl/container:btree", - "@com_google_absl//absl/container:flat_hash_set", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - ] + tsl_grpc_cc_dependencies() + external_deps(), - alwayslink = 1, -) - -go_grpc_library( - name = "tpu_service_go_grpc", - srcs = [":tpu_service_proto"], - compatible_with = ["//buildenv/target:non_prod"], - deps = [":tpu_service_go_proto"], -) diff --git a/third_party/xla/xla/python/tpu_driver/README.md b/third_party/xla/xla/python/tpu_driver/README.md deleted file mode 100644 index 5b31df30ecc20b..00000000000000 --- a/third_party/xla/xla/python/tpu_driver/README.md +++ /dev/null @@ -1,22 +0,0 @@ -# TPU Driver API - -This repository contains the TPU driver API and network (gRPC) transport -implementation for high-performance access to TPU hardware. - -# Building - -Bazel is used to build the driver library and tests. Remote tests will require -access to a Cloud TPU. - -## Fetching Bazel - -Download the latest copy of Bazel from -https://github.com/bazelbuild/bazel/releases. - -## Building - -`bazel build ...` - -## Testing - -`bazel test ...` diff --git a/third_party/xla/xla/python/tpu_driver/client/BUILD b/third_party/xla/xla/python/tpu_driver/client/BUILD deleted file mode 100644 index b72ae894c12336..00000000000000 --- a/third_party/xla/xla/python/tpu_driver/client/BUILD +++ /dev/null @@ -1,22 +0,0 @@ -load("@local_tsl//tsl:tsl.default.bzl", "filegroup") -load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], -) - -filegroup( - name = "header_and_client", - srcs = glob([ - "c_api*", - "libtpu*", - ]), - visibility = ["//visibility:public"], -) - -cc_library( - name = "libtpu", - hdrs = ["libtpu.h"], - visibility = ["//visibility:public"], -) diff --git a/third_party/xla/xla/python/tpu_driver/client/libtpu.h b/third_party/xla/xla/python/tpu_driver/client/libtpu.h deleted file mode 100644 index dc63a8015a0413..00000000000000 --- a/third_party/xla/xla/python/tpu_driver/client/libtpu.h +++ /dev/null @@ -1,309 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_PYTHON_TPU_DRIVER_CLIENT_LIBTPU_H_ -#define XLA_PYTHON_TPU_DRIVER_CLIENT_LIBTPU_H_ - -#include -#include - -#define TPUDRIVER_CAPI_EXPORT __attribute__((visibility("default"))) - -#ifdef __cplusplus -extern "C" { -#endif - -// ------------------- TPU Driver Support ----------------------- - -struct TpuDriverFn; - -typedef struct TpuDriver TpuDriver; - -typedef struct TpuEvent TpuEvent; - -typedef struct TpuBufferHandleInternal TpuBufferHandleInternal; - -typedef struct TpuCompiledProgramHandleInternal - TpuCompiledProgramHandleInternal; - -typedef struct TpuLoadedProgramHandleInternal TpuLoadedProgramHandleInternal; - -typedef struct TpuBufferHandle { - TpuBufferHandleInternal* internal_handle; - TpuEvent* event; - int64_t size_in_bytes; -} TpuBufferHandle; - -typedef struct TpuCompiledProgramHandle { - TpuCompiledProgramHandleInternal* internal_handle; - TpuEvent* event; -} TpuCompiledProgramHandle; - -typedef struct TpuLoadedProgramHandle { - TpuLoadedProgramHandleInternal* internal_handle; - TpuEvent* event; -} TpuLoadedProgramHandle; - -// HloProto is a serialized xla::HloProto buffer. -typedef struct HloProto { - void* buffer; - int32_t size; -} HloProto; - -typedef struct DebugOptions { - void* buffer; - int32_t size; -} DebugOptions; - -// DeviceAssignment is a serialized xla::DeviceAssignmentProto buffer. -typedef struct DeviceAssignment { - void* bytes; - int32_t size; -} DeviceAssignment; - -typedef struct TpuStatus { - int32_t code; - char* msg; -} TpuStatus; - -typedef struct CompiledProgramShape { - struct TpuStatus* status; - void* bytes; - int32_t size; -} CompiledProgramShape; - -typedef struct TpuAllocationShape { - void* bytes; - int32_t size; -} TpuAllocationShape; - -typedef struct TpuSystemInfo { - void* bytes; - int32_t size; -} TpuSystemInfo; - -typedef void(PrototypeTpuDriver_Initialize)(struct TpuDriverFn* driver_fn, - bool initialize); -typedef struct TpuDriver*(PrototypeTpuDriver_Open)(const char* worker); -typedef void(PrototypeTpuDriver_Close)(struct TpuDriver* driver); -typedef struct TpuStatus*(PrototypeTpuDriver_Reset)(struct TpuDriver* driver); - -typedef struct TpuSystemInfo*(PrototypeTpuDriver_QuerySystemInfo)( - struct TpuDriver* driver); - -typedef void(PrototypeTpuDriver_FreeSystemInfo)(struct TpuSystemInfo* info); - -typedef int64_t(PrototypeTpuDriver_ComputeLinearizedBytesFromShape)( - struct TpuDriver* driver, const struct TpuAllocationShape shape); - -typedef struct TpuStatus*(PrototypeTpuDriver_LinearizeShape)( - struct TpuDriver* driver, void* dst, const void* src, - const struct TpuAllocationShape shape); - -typedef struct TpuStatus*(PrototypeTpuDriver_DelinearizeShape)( - struct TpuDriver* driver, void* dst, const void* src, - const struct TpuAllocationShape shape); - -typedef struct TpuCompiledProgramHandle*( - PrototypeTpuDriver_CompileProgram)(struct TpuDriver* driver, - const struct HloProto hlo_proto, - int32_t num_replicas, - const struct DebugOptions debug_options, - int32_t eventc, - struct TpuEvent** eventv); - -typedef struct TpuCompiledProgramHandle*( - PrototypeTpuDriver_CompileProgramFromText)(struct TpuDriver* driver, - const char* hlo_text, - int32_t num_replicas, - int32_t eventc, - struct TpuEvent** eventv); - -/* Note: We are not responsible for freeing the event within the - * TpuCompiledProgramHandle. You have to call FreeEvent separately to ensure - * that memory does not leak. - */ -typedef void(PrototypeTpuDriver_FreeCompiledProgramHandle)( - struct TpuCompiledProgramHandle* handle); - -typedef struct TpuLoadedProgramHandle*(PrototypeTpuDriver_LoadProgram)( - struct TpuDriver* driver, int32_t core_id, - const struct TpuCompiledProgramHandle* compiled_program_handle, - int32_t eventc, struct TpuEvent** eventv); - -/* Note: We are not responsible for freeing the event within the - * TpuLoadedProgramHandle. You have to call FreeEvent separately to ensure that - * memory does not leak. - */ -typedef struct TpuEvent*(PrototypeTpuDriver_UnloadProgram)( - struct TpuDriver* driver, - struct TpuLoadedProgramHandle* loaded_program_handle, int32_t eventc, - struct TpuEvent** eventv); - -typedef struct TpuEvent*(PrototypeTpuDriver_ExecuteProgram)( - struct TpuDriver* driver, struct TpuLoadedProgramHandle* handle, - int32_t inputc, struct TpuBufferHandle** input_buffer_handle, - int32_t outputc, struct TpuBufferHandle** output_buffer_handle, - struct DeviceAssignment device_assignment, int32_t eventc, - struct TpuEvent** eventv); - -typedef struct TpuBufferHandle*(PrototypeTpuDriver_AllocateTuple)( - struct TpuDriver* driver, int32_t core_id, int32_t memory_region, - int32_t bufferc, struct TpuBufferHandle** buffer_handle, int32_t eventc, - struct TpuEvent** eventv); - -typedef struct TpuBufferHandle*(PrototypeTpuDriver_Allocate)( - struct TpuDriver* driver, int32_t core_id, int32_t memory_region, - int64_t num_bytes, int32_t eventc, struct TpuEvent** eventv); - -typedef struct TpuBufferHandle*(PrototypeTpuDriver_AllocateShape)( - struct TpuDriver* driver, int32_t core_id, int32_t memory_region, - const struct TpuAllocationShape shape, int32_t eventc, - struct TpuEvent** eventv); - -/* Note: We are not responsible for freeing the event within the - * TpuBufferHandle. You have to call FreeEvent separately to ensure that memory - * does not leak. - */ -typedef struct TpuEvent*(PrototypeTpuDriver_Deallocate)( - struct TpuDriver* driver, struct TpuBufferHandle* buffer_handle, - int32_t eventc, struct TpuEvent** eventv); - -typedef struct TpuEvent*(PrototypeTpuDriver_TransferToDevice)( - struct TpuDriver* driver, const void* src, struct TpuBufferHandle* dst, - int32_t eventc, struct TpuEvent** eventv); - -typedef struct TpuEvent*(PrototypeTpuDriver_TransferFromDevice)( - struct TpuDriver* driver, struct TpuBufferHandle* src, void* dst, - int32_t eventc, struct TpuEvent** eventv); - -typedef struct TpuEvent*(PrototypeTpuDriver_TransferFromDeviceToDevice)( - struct TpuDriver* driver, struct TpuBufferHandle* src, - struct TpuBufferHandle* dst, int32_t eventc, struct TpuEvent** eventv); - -typedef struct CompiledProgramShape*( - PrototypeTpuDriver_GetCompiledProgramShape)( - struct TpuCompiledProgramHandle* handle); - -typedef void(PrototypeTpuDriver_FreeCompiledProgramShape)( - struct CompiledProgramShape* shape); - -typedef void(PrototypeTpuDriver_EventAddCallback)( - struct TpuEvent* event, - void (*callback_fn)(struct TpuStatus*, void* additional_info), - void* additional_info); - -typedef struct TpuStatus*(PrototypeTpuDriver_EventAwait)(struct TpuEvent* event, - int64_t timeout_in_us); - -typedef void(PrototypeTpuDriver_FreeEvent)(struct TpuEvent* event); - -typedef void(PrototypeTpuDriver_FreeStatus)(struct TpuStatus* status); - -typedef const char*(PrototypeTpuDriver_Version)(); - -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Initialize TpuDriver_Initialize; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Open TpuDriver_Open; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Close TpuDriver_Close; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Reset TpuDriver_Reset; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_QuerySystemInfo - TpuDriver_QuerySystemInfo; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_FreeSystemInfo - TpuDriver_FreeSystemInfo; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_ComputeLinearizedBytesFromShape - TpuDriver_ComputeLinearizedBytesFromShape; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_LinearizeShape - TpuDriver_LinearizeShape; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_DelinearizeShape - TpuDriver_DelinearizeShape; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_CompileProgram - TpuDriver_CompileProgram; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_CompileProgramFromText - TpuDriver_CompileProgramFromText; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_FreeCompiledProgramHandle - TpuDriver_FreeCompiledProgramHandle; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_LoadProgram - TpuDriver_LoadProgram; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_UnloadProgram - TpuDriver_UnloadProgram; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_ExecuteProgram - TpuDriver_ExecuteProgram; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_AllocateTuple - TpuDriver_AllocateTuple; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Allocate TpuDriver_Allocate; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_AllocateShape - TpuDriver_AllocateShape; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Deallocate TpuDriver_Deallocate; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_TransferToDevice - TpuDriver_TransferToDevice; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_TransferFromDevice - TpuDriver_TransferFromDevice; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_TransferFromDeviceToDevice - TpuDriver_TransferFromDeviceToDevice; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_GetCompiledProgramShape - TpuDriver_GetCompiledProgramShape; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_FreeCompiledProgramShape - TpuDriver_FreeCompiledProgramShape; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_EventAddCallback - TpuDriver_EventAddCallback; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_EventAwait TpuDriver_EventAwait; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_FreeEvent TpuDriver_FreeEvent; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_FreeStatus TpuDriver_FreeStatus; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Version TpuDriver_Version; - -#ifdef __cplusplus -} -#endif - -struct TpuDriverFn { - PrototypeTpuDriver_Open* TpuDriver_Open; // NOLINT - PrototypeTpuDriver_Close* TpuDriver_Close; // NOLINT - PrototypeTpuDriver_Reset* TpuDriver_Reset; // NOLINT - PrototypeTpuDriver_ComputeLinearizedBytesFromShape* - TpuDriver_ComputeLinearizedBytesFromShape; // NOLINT - PrototypeTpuDriver_QuerySystemInfo* TpuDriver_QuerySystemInfo; // NOLINT - PrototypeTpuDriver_FreeSystemInfo* TpuDriver_FreeSystemInfo; // NOLINT - PrototypeTpuDriver_LinearizeShape* TpuDriver_LinearizeShape; // NOLINT - PrototypeTpuDriver_DelinearizeShape* TpuDriver_DelinearizeShape; // NOLINT - PrototypeTpuDriver_CompileProgram* TpuDriver_CompileProgram; // NOLINT - PrototypeTpuDriver_CompileProgramFromText* - TpuDriver_CompileProgramFromText; // NOLINT - PrototypeTpuDriver_FreeCompiledProgramHandle* - TpuDriver_FreeCompiledProgramHandle; // NOLINT - PrototypeTpuDriver_LoadProgram* TpuDriver_LoadProgram; // NOLINT - PrototypeTpuDriver_UnloadProgram* TpuDriver_UnloadProgram; // NOLINT - PrototypeTpuDriver_ExecuteProgram* TpuDriver_ExecuteProgram; // NOLINT - PrototypeTpuDriver_AllocateTuple* TpuDriver_AllocateTuple; // NOLINT - PrototypeTpuDriver_Allocate* TpuDriver_Allocate; // NOLINT - PrototypeTpuDriver_AllocateShape* TpuDriver_AllocateShape; // NOLINT - PrototypeTpuDriver_Deallocate* TpuDriver_Deallocate; // NOLINT - PrototypeTpuDriver_TransferToDevice* TpuDriver_TransferToDevice; // NOLINT - PrototypeTpuDriver_TransferFromDevice* - TpuDriver_TransferFromDevice; // NOLINT - PrototypeTpuDriver_TransferFromDeviceToDevice* - TpuDriver_TransferFromDeviceToDevice; // NOLINT - PrototypeTpuDriver_GetCompiledProgramShape* - TpuDriver_GetCompiledProgramShape; // NOLINT - PrototypeTpuDriver_FreeCompiledProgramShape* - TpuDriver_FreeCompiledProgramShape; // NOLINT - PrototypeTpuDriver_EventAddCallback* TpuDriver_EventAddCallback; // NOLINT - PrototypeTpuDriver_EventAwait* TpuDriver_EventAwait; // NOLINT - PrototypeTpuDriver_FreeEvent* TpuDriver_FreeEvent; // NOLINT - PrototypeTpuDriver_FreeStatus* TpuDriver_FreeStatus; // NOLINT - - PrototypeTpuDriver_Version* TpuDriver_Version; // NOLINT -}; - -#endif // XLA_PYTHON_TPU_DRIVER_CLIENT_LIBTPU_H_ diff --git a/third_party/xla/xla/python/tpu_driver/client/libtpu_client.c b/third_party/xla/xla/python/tpu_driver/client/libtpu_client.c deleted file mode 100644 index 11282e8731ab9f..00000000000000 --- a/third_party/xla/xla/python/tpu_driver/client/libtpu_client.c +++ /dev/null @@ -1,167 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Before you start, make sure libtpu.so, libtpu.h and libtpu_client.c are in -// the same working directory. -// -// To compile: gcc -o libtpu_client libtpu_client.c -ldl -// To run: sudo ./libtpu_client - -#include -#include -#include - -#include "libtpu.h" - -void* LoadAndInitializeDriver(const char* shared_lib, - struct TpuDriverFn* driver_fn) { - void* handle; - handle = dlopen(shared_lib, RTLD_NOW); - if (!handle) { - fprintf(stderr, "Error: %s\n", dlerror()); - exit(EXIT_FAILURE); - } - - PrototypeTpuDriver_Initialize* initialize_fn; - *(void**)(&initialize_fn) = dlsym(handle, "TpuDriver_Initialize"); - initialize_fn(driver_fn, true); - - return handle; -} - -int main(int argc, char** argv) { - char* api_path = "libtpu.so"; - if (argc == 2) { - api_path = argv[1]; - } - - struct TpuDriverFn driver_fn; - void* handle = LoadAndInitializeDriver(api_path, &driver_fn); - - fprintf(stdout, "------ Going to Query Version ------\n"); - fprintf(stdout, "TPU Driver Version: %s\n", driver_fn.TpuDriver_Version()); - - fprintf(stdout, "------ Going to Open a TPU Driver ------\n"); - struct TpuDriver* driver = driver_fn.TpuDriver_Open("local://"); - - fprintf(stdout, "------ Going to Query for System Information ------\n"); - struct TpuSystemInfo* info = driver_fn.TpuDriver_QuerySystemInfo(driver); - driver_fn.TpuDriver_FreeSystemInfo(info); - - // An example of simple program to sum two parameters. - const char* hlo_module_text = R"(HloModule add_vec_module - ENTRY %add_vec (a: s32[256], b: s32[256]) -> s32[256] { - %a = s32[256] parameter(0) - %b = s32[256] parameter(1) - ROOT %sum = s32[256] add(%a, %b) - } - )"; - - fprintf(stdout, "------ Going to Compile a TPU program ------\n"); - struct TpuCompiledProgramHandle* cph = - driver_fn.TpuDriver_CompileProgramFromText(driver, hlo_module_text, - /*num_replicas=*/1, /*eventc=*/0, /*eventv*/NULL); - - TpuEvent* compile_events[] = {cph->event}; - fprintf(stdout, "------ Going to Load a TPU program ------\n"); - struct TpuLoadedProgramHandle* lph = - driver_fn.TpuDriver_LoadProgram(driver, /*core_id=*/0, cph, - /*eventc=*/1, /*eventv=*/compile_events); - - const int size = 1024; - - fprintf(stdout, "------ Going to Allocate a TPU Buffer ------\n"); - struct TpuBufferHandle* buf_a_handle = - driver_fn.TpuDriver_Allocate(driver, /*core-id=*/0, /*memory_region=*/1, - /*bytes=*/size, /*eventc=*/0, /*eventv=*/NULL); - fprintf(stdout, "------ Going to Allocate a TPU Buffer ------\n"); - struct TpuBufferHandle* buf_b_handle = - driver_fn.TpuDriver_Allocate(driver, /*core-id=*/0, /*memory_region=*/1, - /*bytes=*/size, /*eventc=*/0, /*eventv=*/NULL); - fprintf(stdout, "------ Going to Allocate a TPU Buffer ------\n"); - struct TpuBufferHandle* buf_sum_handle = - driver_fn.TpuDriver_Allocate(driver, /*core-id=*/0, /*memory_region=*/1, - /*bytes=*/size, /*eventc=*/0, /*eventv=*/NULL); - - char a_src[size], b_src[size], sum_src[size]; - for (int i = 0; i < size; ++i) { - a_src[i] = 1; - b_src[i] = 2; - sum_src[i] = 0; - } - - TpuEvent* allocate_buf_a_events[] = {buf_a_handle->event}; - fprintf(stdout, "------ Going to Transfer To Device ------\n"); - struct TpuEvent* transfer_ev1 = - driver_fn.TpuDriver_TransferToDevice(driver, a_src, buf_a_handle, - /*eventc=*/1, /*eventv=*/allocate_buf_a_events); - TpuEvent* allocate_buf_b_events[] = {buf_a_handle->event}; - fprintf(stdout, "------ Going to Transfer To Device ------\n"); - struct TpuEvent* transfer_ev2 = - driver_fn.TpuDriver_TransferToDevice(driver, b_src, buf_b_handle, - /*eventc=*/1, /*eventv=*/allocate_buf_b_events); - - fprintf(stdout, "------ Going to Execute a TPU program ------\n"); - DeviceAssignment device_assignment = {NULL, 0}; - TpuBufferHandle* input_buffer_handle[] = {buf_a_handle, buf_b_handle}; - TpuBufferHandle* output_buffer_handle[] = {buf_sum_handle}; - TpuEvent* transfer_events[] = {transfer_ev1, transfer_ev2}; - struct TpuEvent* execute_event = - driver_fn.TpuDriver_ExecuteProgram(driver, lph, - /*inputc=*/2, /*input_buffer_handle=*/input_buffer_handle, - /*outputc=*/1, /*output_buffer_handle=*/output_buffer_handle, - device_assignment, - /*eventc=*/2, /*eventv*/transfer_events); - - fprintf(stdout, "------ Going to Transfer From Device ------\n"); - TpuEvent* execute_events[] = {execute_event}; - struct TpuEvent* transfer_sum_event = - driver_fn.TpuDriver_TransferFromDevice(driver, buf_sum_handle, sum_src, - /*eventc=*/1, /*eventv=*/execute_events); - - TpuStatus* status = driver_fn.TpuDriver_EventAwait(transfer_sum_event, - 10000000); - if (status->code != 0) { - fprintf(stdout, "Transfer Event Await: Code: %d, Message: %s\n", - status->code, status->msg); - } - - fprintf(stdout, "------ Going to Unload a TPU program ------\n"); - struct TpuEvent* unload_program_event = driver_fn.TpuDriver_UnloadProgram( - driver, lph, /*eventc=*/1, /*eventv=*/execute_events); - - fprintf(stdout, "------ Going to Deallocate a TPU Buffer ------\n"); - struct TpuEvent* dealloc_ev1 = driver_fn.TpuDriver_Deallocate(driver, - buf_a_handle, /*eventc=*/0, /*eventv=*/NULL); - driver_fn.TpuDriver_FreeEvent(dealloc_ev1); - - fprintf(stdout, "------ Going to Deallocate a TPU Buffer ------\n"); - struct TpuEvent* dealloc_ev2 = driver_fn.TpuDriver_Deallocate(driver, - buf_b_handle, /*eventc=*/0, /*eventv=*/NULL); - driver_fn.TpuDriver_FreeEvent(dealloc_ev2); - - fprintf(stdout, "------ Going to Deallocate a TPU Buffer ------\n"); - struct TpuEvent* dealloc_ev3 = driver_fn.TpuDriver_Deallocate(driver, - buf_sum_handle, /*eventc=*/0, /*eventv=*/NULL); - driver_fn.TpuDriver_FreeEvent(dealloc_ev3); - - fprintf(stdout, "sum:\n"); - for (size_t i = 0; i < size; ++i) { - fprintf(stdout, "%d ", sum_src[i]); - } - - dlclose(handle); - exit(EXIT_SUCCESS); -} diff --git a/third_party/xla/xla/python/tpu_driver/event_id.h b/third_party/xla/xla/python/tpu_driver/event_id.h deleted file mode 100644 index ac86aac8ead540..00000000000000 --- a/third_party/xla/xla/python/tpu_driver/event_id.h +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright 2019 The OpenXLA Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================== -#ifndef XLA_PYTHON_TPU_DRIVER_EVENT_ID_H_ -#define XLA_PYTHON_TPU_DRIVER_EVENT_ID_H_ - -#include -#include -#include -#include - -#include "absl/strings/str_cat.h" - -namespace tpu_driver { - -// For gRPC serialization, events are represented as a pair of -// {client, operation} ids. To simplify serialization, these are encoded as a -// single integer field. -// -// This class provides a typed interface for these values as well as support for -// hashing and ostreams (for logging). -struct EventId { - uint64_t client_id; - uint64_t operation_id; - - template - friend H AbslHashValue(H h, const EventId& c) { - return H::combine(std::move(h), c.client_id, c.operation_id); - } - - bool operator==(const EventId& r) const { - return r.client_id == client_id && r.operation_id == operation_id; - } - - friend std::ostream& operator<<(std::ostream& os, EventId r) { - return os << r.client_id << ":" << r.operation_id; - } - - std::string ToString() const { - return absl::StrCat(client_id, ":", operation_id); - } - - uint64_t AsInt() const { return client_id << 44 | operation_id; } - - static EventId FromInt(uint64_t value) { - return EventId{value >> 44, value & 0xfffffffffff}; - } -}; - -} // namespace tpu_driver - -#endif // XLA_PYTHON_TPU_DRIVER_EVENT_ID_H_ diff --git a/third_party/xla/xla/python/tpu_driver/grpc_tpu_driver.cc b/third_party/xla/xla/python/tpu_driver/grpc_tpu_driver.cc deleted file mode 100644 index 8b9c466d2c28a5..00000000000000 --- a/third_party/xla/xla/python/tpu_driver/grpc_tpu_driver.cc +++ /dev/null @@ -1,1108 +0,0 @@ -// Copyright 2019 The OpenXLA Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================== - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/base/thread_annotations.h" -#include "absl/strings/strip.h" -#include "absl/synchronization/mutex.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "absl/types/span.h" -#include "grpcpp/grpcpp.h" -#include "xla/python/tpu_driver/event_id.h" -#include "xla/python/tpu_driver/platform/external/compat.h" -#include "xla/python/tpu_driver/tpu_driver.h" -#include "xla/python/tpu_driver/tpu_driver.pb.h" -#include "xla/python/tpu_driver/tpu_service.grpc.pb.h" -#include "xla/util.h" - -namespace tpu_driver { -namespace { - -using xla::OkStatus; -using xla::Status; - -const int64_t kMaxStreamWriteSize = 10 * 1000 * 1000; -const absl::Duration kWriteEpochDuration = absl::Microseconds(10); - -constexpr char kGrpcProtocol[] = "grpc://"; - -class GrpcTpuStream; -class GrpcTpuDriver; - -class GrpcEvent : public Event { - public: - explicit GrpcEvent(EventId id, GrpcTpuStream* stream) - : id_(id), stream_(stream) {} - ~GrpcEvent() override; - - xla::Status Await() override; - std::optional AwaitWithTimeout(absl::Duration duration) override; - void AddCallback(std::function callback) override; - - EventId id() const { return id_; } - GrpcTpuStream* stream() const { return stream_; } - - private: - const EventId id_; - GrpcTpuStream* stream_; -}; - -class ErrorEvent : public GrpcEvent { - public: - explicit ErrorEvent(Status status) : GrpcEvent(EventId{0, 0}, nullptr) { - status_ = status; - } - - xla::Status Await() override { return status_; } - std::optional AwaitWithTimeout( - absl::Duration duration) override { - return status_; - } - void AddCallback(std::function callback) override { - callback(status_); - } - - private: - Status status_; -}; - -class GrpcBufferHandle : public BufferHandle { - public: - explicit GrpcBufferHandle(EventId id, std::shared_ptr event, - int64_t bytes, - std::optional shape = std::nullopt) - : id_(id), - stream_(event->stream()), - event_(std::move(event)), - bytes_(bytes), - shape_(shape) {} - - std::shared_ptr OnReady() override { return event_; } - int64_t size_in_bytes() override { return bytes_; } - - EventId id() const { return id_; } - GrpcTpuStream* stream() const { return stream_; } - - std::optional shape() override { return shape_; } - - private: - const EventId id_; - GrpcTpuStream* stream_; - std::shared_ptr event_; - int64_t bytes_; - std::optional shape_; -}; - -class GrpcCompiledProgramHandle : public CompiledProgramHandle { - public: - explicit GrpcCompiledProgramHandle(EventId id, - std::shared_ptr event) - : id_(id), - stream_(event->stream()), - event_(std::move(event)), - metadata_(std::make_shared()) {} - - std::shared_ptr OnReady() override { return event_; } - - EventId id() const { return id_; } - GrpcTpuStream* stream() const { return stream_; } - - Status program_shape(xla::ProgramShapeProto* program_shape) override { - auto opt_status = OnReady()->AwaitWithTimeout(absl::Hours(1)); - if (!opt_status.has_value()) { - return xla::Internal("Compile failed to finish within 1 hour."); - } - - Status status = opt_status.value(); - if (!status.ok()) { - return status; - } - *program_shape = metadata_->program_shape(); - return OkStatus(); - } - - std::shared_ptr metadata() { return metadata_; } - - private: - const EventId id_; - GrpcTpuStream* stream_; - std::shared_ptr event_; - - // Using a shared pointer here because the program handle can go out of scope - // before we get a response back, but we want a valid location to write things - // into regardless. - std::shared_ptr metadata_; -}; - -class GrpcLoadedProgramHandle : public LoadedProgramHandle { - public: - explicit GrpcLoadedProgramHandle(EventId id, std::shared_ptr event) - : id_(id), stream_(event->stream()), event_(std::move(event)) {} - - std::shared_ptr OnReady() override { return event_; } - - EventId id() const { return id_; } - GrpcTpuStream* stream() const { return stream_; } - - private: - const EventId id_; - GrpcTpuStream* stream_; - std::shared_ptr event_; -}; - -class GrpcTpuStream { - public: - explicit GrpcTpuStream(int32_t id, GrpcTpuDriver* driver, - std::unique_ptr stub); - virtual ~GrpcTpuStream(); - - std::unique_ptr Allocate(int32_t core_id, MemoryRegion region, - int64_t num_bytes, - absl::Span wait_for); - std::unique_ptr Allocate(int32_t core_id, MemoryRegion region, - const xla::ShapeProto& shape, - absl::Span wait_for); - std::unique_ptr AllocateTuple( - int32_t core_id, MemoryRegion region, - absl::Span children, - absl::Span wait_for); - std::shared_ptr Deallocate(std::unique_ptr handle, - absl::Span wait_for); - - std::shared_ptr TransferToDevice(const void* src, BufferHandle* dst, - absl::Span wait_for); - std::shared_ptr TransferFromDevice(const BufferHandle* src, void* dst, - absl::Span wait_for); - - std::shared_ptr TransferFromDeviceToDevice( - const BufferHandle* src, BufferHandle* dst, - absl::Span wait_for); - - std::unique_ptr CompileProgram( - const xla::HloProto& source, int32_t num_replicas, - absl::Span wait_for, - const xla::DebugOptions& debug_options); - std::unique_ptr LoadProgram( - int32_t core_id, const CompiledProgramHandle* handle, - absl::Span wait_for); - std::shared_ptr UnloadProgram( - std::unique_ptr handle, - absl::Span wait_for); - std::shared_ptr ExecuteProgram( - LoadedProgramHandle* program, absl::Span inputs, - absl::Span outputs, - const xla::DeviceAssignmentProto& device_assignment, - absl::Span wait_for); - - private: - friend class GrpcEvent; - friend class GrpcTpuDriver; - - struct EventInfo { - bool all_deps_done = false; - bool done = false; // response received - bool deleted = false; // deleted by the user - Status status; - absl::InlinedVector, 1> callbacks; - // Most events should have <= 2 requirement events. - absl::InlinedVector deps; - }; - - struct TransferInfo { - explicit TransferInfo(void* dst, int64_t num_bytes) - : dst(dst), num_bytes(num_bytes) {} - - void* const dst; - const uint64_t num_bytes; - }; - - struct CompileMetadataInfo { - explicit CompileMetadataInfo( - std::shared_ptr metadata) { - compiled_metadata = metadata; - } - std::shared_ptr compiled_metadata; - }; - - // Every public method above should call this first. - void InitializeRequest(StreamRequest::Entry* req, - absl::Span wait_for) - ABSL_LOCKS_EXCLUDED(events_mutex_); - - // The first update to an event marks it done and calls registered callbacks. - // All subsequent updates must have the same OK-ness as the first update. - // Among non-OK updates, only the first error status is remembered. - void UpdateEventStatus(EventId id, Status status) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(events_mutex_); - - // To ensure callbacks are still triggered, after this is called, we do not - // remove the event from the event mapping until a response is received from - // the server. - void DeleteEvent(EventId id) ABSL_LOCKS_EXCLUDED(events_mutex_); - - // Wait at most `duration` for event `id` to complete. Returns the event - // status or an empty optional if the event does not complete in time. - std::optional WaitForEvent(EventId id, absl::Duration duration) - ABSL_LOCKS_EXCLUDED(events_mutex_); - - void AddEventCallback(EventId id, std::function callback) - ABSL_LOCKS_EXCLUDED(events_mutex_); - - void AddWriteRequest(std::unique_ptr req) { - absl::MutexLock m(&request_lock_); - VLOG(2) << "Adding request: " << req->DebugString(); - requests_.push_back(std::move(req)); - } - - // Unique identifier for this stream. - int32_t id_; - // The parent driver that created this stream. - GrpcTpuDriver* driver_; - - std::unique_ptr stub_; - ::grpc::ClientContext ctx_; - std::unique_ptr< - ::grpc::ClientReaderWriterInterface> - stream_; - - absl::Mutex request_lock_; - std::deque> requests_ - ABSL_GUARDED_BY(request_lock_); - int64_t num_pending_requests_ ABSL_GUARDED_BY(request_lock_) = 0; - - bool shutting_down_ ABSL_GUARDED_BY(request_lock_) = false; - - void StreamWriterFn(); - Thread writer_thread_; - - void StreamReaderFn(); - Thread reader_thread_; - - // Map from operation ID to event information. - absl::Mutex events_mutex_; - absl::flat_hash_map events_ - ABSL_GUARDED_BY(events_mutex_); - - // Map from operation ID to transfer information. - // When a D2H transfer completes, received data is copied into the `dst` - // pointer in `TransferInfo`. - absl::Mutex transfers_mutex_; - absl::flat_hash_map transfers_ - ABSL_GUARDED_BY(transfers_mutex_); - - absl::Mutex compiles_mutex_; - absl::flat_hash_map compiles_ - ABSL_GUARDED_BY(compiles_mutex_); -}; - -class GrpcTpuDriver : public TpuDriver { - public: - explicit GrpcTpuDriver(const TpuDriverConfig& config, - std::shared_ptr<::grpc::ChannelCredentials> creds, - int32_t client_id) - : config_(config), creds_(creds), client_id_(client_id) { - SystemInfo system_info; - QuerySystemInfo(&system_info); - for (auto& chip_info : system_info.tpu_chip()) { - for (auto& core_info : chip_info.core()) { - int32_t core_id = core_info.id(); - // We have one stream per core, so use core ID as stream ID. - streams_[core_id] = AllocateStream(core_id); - } - } - CHECK_GT(streams_.size(), 0) << "Can't find any TPU chip in the system."; - - host_stream_ = AllocateStream(-1); - } - - ~GrpcTpuDriver() override { - if (closed_) { - return; - } - auto status = Close(); - if (!status.ok()) { - LOG(ERROR) << status; - } - } - - void QuerySystemInfo(SystemInfo* system_info) override; - Status Reset() override; - - std::unique_ptr Allocate( - int32_t core_id, MemoryRegion region, int64_t num_bytes, - absl::Span wait_for) override { - return streams_[core_id]->Allocate(core_id, region, num_bytes, wait_for); - } - std::unique_ptr Allocate( - int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape, - absl::Span wait_for) override { - return streams_[core_id]->Allocate(core_id, region, shape, wait_for); - } - std::unique_ptr AllocateTuple( - int32_t core_id, MemoryRegion region, - absl::Span children, - absl::Span wait_for) override { - return streams_[core_id]->AllocateTuple(core_id, region, children, - wait_for); - } - std::shared_ptr Deallocate( - std::unique_ptr handle, - absl::Span wait_for) override { - auto* stream = static_cast(handle.get())->stream(); - return stream->Deallocate(std::move(handle), wait_for); - } - - std::shared_ptr TransferToDevice( - const void* src, BufferHandle* dst, - absl::Span wait_for) override { - auto* stream = static_cast(dst)->stream(); - return stream->TransferToDevice(src, dst, wait_for); - } - std::shared_ptr TransferFromDevice( - const BufferHandle* src, void* dst, - absl::Span wait_for) override { - auto* stream = static_cast(src)->stream(); - return stream->TransferFromDevice(src, dst, wait_for); - } - - std::shared_ptr TransferFromDeviceToDevice( - const BufferHandle* src, BufferHandle* dst, - absl::Span wait_for) override { - auto* stream = static_cast(src)->stream(); - return stream->TransferFromDeviceToDevice(src, dst, wait_for); - } - - std::unique_ptr CompileProgram( - const xla::HloProto& source, int32_t num_replicas, - absl::Span wait_for, - const xla::DebugOptions& debug_options) override { - // Always compile using the first/default core's stream. - return streams_[0]->CompileProgram(source, num_replicas, wait_for, - debug_options); - } - std::unique_ptr LoadProgram( - int32_t core_id, const CompiledProgramHandle* handle, - absl::Span wait_for) override { - return streams_[core_id]->LoadProgram(core_id, handle, wait_for); - } - std::shared_ptr UnloadProgram( - std::unique_ptr handle, - absl::Span wait_for) override { - auto* stream = - static_cast(handle.get())->stream(); - return stream->UnloadProgram(std::move(handle), wait_for); - } - std::shared_ptr ExecuteProgram( - LoadedProgramHandle* program, absl::Span inputs, - absl::Span outputs, - const xla::DeviceAssignmentProto& device_assignment, - absl::Span wait_for) override { - auto* stream = - static_cast(program)->stream(); - return stream->ExecuteProgram(program, inputs, outputs, device_assignment, - wait_for); - } - - EventId NewOperationId() { return EventId{client_id_, ++operation_id_}; } - - static std::unique_ptr CreateTpuDriverStub( - const TpuDriverConfig& config, - std::shared_ptr<::grpc::ChannelCredentials> creds); - - uint32_t client_id() const { return client_id_; } - - private: - Status Close(); - std::unique_ptr AllocateStream(int32_t core_id); - - const TpuDriverConfig config_; - std::shared_ptr<::grpc::ChannelCredentials> creds_; - const uint32_t client_id_; - // Map from stream IDs to streams. - absl::flat_hash_map> streams_; - std::unique_ptr host_stream_; - // Shared by all streams. - std::atomic operation_id_{0}; - std::atomic closed_{false}; -}; // namespace - -GrpcEvent::~GrpcEvent() { stream_->DeleteEvent(id_); } - -Status GrpcEvent::Await() { - auto opt_status = stream_->WaitForEvent(id_, absl::InfiniteDuration()); - return opt_status.value(); -} - -std::optional GrpcEvent::AwaitWithTimeout(absl::Duration duration) { - return stream_->WaitForEvent(id_, duration); -} - -void GrpcEvent::AddCallback(std::function callback) { - stream_->AddEventCallback(id_, std::move(callback)); -} - -GrpcTpuStream::GrpcTpuStream(int32_t id, GrpcTpuDriver* driver, - std::unique_ptr stub) - : id_(id), - driver_(driver), - stub_(std::move(stub)), - stream_(stub_->StreamExecute(&ctx_)), - writer_thread_(&GrpcTpuStream::StreamWriterFn, this), - reader_thread_(&GrpcTpuStream::StreamReaderFn, this) {} - -GrpcTpuStream::~GrpcTpuStream() { - { - absl::MutexLock lock(&request_lock_); - shutting_down_ = true; - } - - VLOG(1) << "Shutting down stream."; - { - // Mark all remaining events invalid. - absl::MutexLock lock(&events_mutex_); - for (const auto& e : events_) { - if (!e.second.done) { - LOG(ERROR) << "Resetting: " << e.first; - UpdateEventStatus(e.first, xla::Status(absl::StatusCode::kAborted, - "Driver was closed.")); - } - } - } - VLOG(1) << "Closing stream."; - stream_->WritesDone(); - stream_->Finish().IgnoreError(); - VLOG(1) << "Waiting for writer."; - writer_thread_.join(); - VLOG(1) << "Waiting for reader."; - reader_thread_.join(); -} - -void GrpcTpuStream::InitializeRequest(StreamRequest::Entry* req, - absl::Span wait_for) { - auto operation_id = driver_->NewOperationId(); - EventInfo event_info; - - req->set_operation_id(operation_id.AsInt()); - if (wait_for.empty()) { - event_info.all_deps_done = true; - } else { - event_info.deps.reserve(wait_for.size()); - for (auto* event : wait_for) { - auto grpc_event = static_cast(event); - req->add_wait_for_id(grpc_event->id().AsInt()); - event_info.deps.push_back(grpc_event->id()); - } - } - - absl::MutexLock lock(&events_mutex_); - events_[operation_id] = event_info; -} - -void GrpcTpuStream::UpdateEventStatus(EventId id, Status status) { - auto it = events_.find(id); - - // These should only happen when the server shuts down, and our local event - // cancellation interleaves with server responses. It should be safe to ignore - // the second updates in these situations. - if (it == events_.end()) { - VLOG(1) << "Received a status update: " << status - << ", but cannot find GrpcEvent " << id; - return; - } - if (it->second.done) { - // Done and deleted events must have already been removed. - CHECK(!it->second.deleted); - VLOG(1) << "Received a second status update: " << status.message() - << ", for GrpcEvent " << id - << " already done with status: " << it->second.status.message(); - return; - } - - // This is the first time this event finishes. Remember the results and call - // the callbacks. - VLOG(1) << "Response received for GrpcEvent " << id << ". " << status - << ". Firing " << it->second.callbacks.size() << " callbacks."; - it->second.done = true; - it->second.status = status; - for (const auto& callback : it->second.callbacks) { - callback(status); - } - - // Truly remove the event if it's both done and deleted. - if (it->second.deleted) { - events_.erase(it); - } -} - -void GrpcTpuStream::DeleteEvent(EventId id) { - absl::MutexLock lock(&events_mutex_); - auto it = events_.find(id); - CHECK(it != events_.end()); - CHECK(!it->second.deleted); - it->second.deleted = true; - // Truly remove the event if it's both done and deleted. - if (it->second.done) { - events_.erase(it); - } -} - -std::optional GrpcTpuStream::WaitForEvent(EventId id, - absl::Duration duration) { - events_mutex_.Lock(); - auto it = events_.find(id); - - if (it == events_.end()) { - // This event has already been marked as done and deleted. Assume success. - events_mutex_.Unlock(); - return OkStatus(); - } - - if (!it->second.all_deps_done) { - absl::InlinedVector deps = it->second.deps; - events_mutex_.Unlock(); - for (auto dep : deps) { - // If a requirement event timed out, no point in any further waiting. - if (!WaitForEvent(dep, duration)) { - return std::nullopt; - } - } - events_mutex_.Lock(); - } - - // Set the flag here, as we're guaranteed they have all completed at this - // point. This helps terminate recursion on a chain of completed events as - // soon as possible, at this event. - it = events_.find(id); - if (it != events_.end()) { - it->second.all_deps_done = true; - } - - auto done = [this, id]() { - events_mutex_.AssertHeld(); - return !events_.contains(id) || events_[id].done; - }; - if (events_mutex_.AwaitWithTimeout(absl::Condition(&done), duration)) { - auto status = events_.contains(id) ? events_[id].status : OkStatus(); - events_mutex_.Unlock(); - return status; - } - events_mutex_.Unlock(); - return std::nullopt; -} - -void GrpcTpuStream::AddEventCallback(EventId id, - std::function callback) { - absl::MutexLock lock(&events_mutex_); - auto it = events_.find(id); - if (it == events_.end()) { - callback(Status()); - return; - } - if (it->second.done) { - callback(it->second.status); - return; - } - it->second.callbacks.push_back(std::move(callback)); -} - -static bool ShouldBeginWriting(int64_t* pending_requests) { - return *pending_requests > 32; -} - -void GrpcTpuStream::StreamWriterFn() { - while (true) { - request_lock_.LockWhenWithTimeout( - absl::Condition(&ShouldBeginWriting, &num_pending_requests_), - kWriteEpochDuration); - if (shutting_down_) { - request_lock_.Unlock(); - return; - } - - if (requests_.empty()) { - request_lock_.Unlock(); - continue; - } - - std::vector reqs; - int64_t request_bytes = 0; - while (!requests_.empty()) { - StreamRequest::Entry* e = requests_.front().release(); - requests_.pop_front(); - const int64_t entry_bytes = e->ByteSizeLong(); - if (reqs.empty() || request_bytes + entry_bytes > kMaxStreamWriteSize) { - reqs.push_back(StreamRequest()); - request_bytes = 0; - } - VLOG(1) << "Sending request: " << EventId::FromInt(e->operation_id()); - VLOG(2) << "Sending request: " << e->DebugString(); - reqs.back().mutable_entry()->AddAllocated(e); - } - num_pending_requests_ = 0; - request_lock_.Unlock(); - - for (const auto& r : reqs) { - TraceMe activity("GrpcTpuStream::Send "); - ::grpc::WriteOptions opts; - opts.set_no_compression().clear_buffer_hint(); - stream_->Write(r, opts); - } - } -} - -void GrpcTpuStream::StreamReaderFn() { - StreamResponse resp; - while (stream_->Read(&resp)) { - VLOG(2) << "Received response: " << resp.DebugString(); - for (const StreamResponse::Entry& entry : resp.entry()) { - EventId event_id = EventId::FromInt(entry.operation_id()); - VLOG(1) << "Received response for: " << event_id; - - TraceMe activity("GrpcTpuStream::RequestComplete"); - if (entry.has_transfer_from()) { - TraceMe activity("GrpcTpuStream::TransferFromComplete"); - absl::MutexLock lock(&transfers_mutex_); - auto it = transfers_.find(event_id); - CHECK(it != transfers_.end()); - VLOG(1) << "Copying: " << it->second.num_bytes << " to position " - << it->second.dst; - if (entry.transfer_from().data().size() != it->second.num_bytes) { - absl::MutexLock lock(&events_mutex_); - UpdateEventStatus( - event_id, - Status( - absl::StatusCode::kDataLoss, - absl::StrCat("Expected ", it->second.num_bytes, " received ", - entry.transfer_from().data().size()))); - continue; - } - memcpy(it->second.dst, entry.transfer_from().data().data(), - it->second.num_bytes); - } - - if (entry.has_compile()) { - TraceMe activity("GrpcTpuStream::CompileComplete"); - absl::MutexLock lock(&compiles_mutex_); - auto it = compiles_.find(event_id); - CHECK(it != compiles_.end()); - *it->second.compiled_metadata = entry.compile().metadata(); - } - - absl::MutexLock lock(&events_mutex_); - if (entry.status().code() != tsl::error::Code::OK) { - UpdateEventStatus( - event_id, - Status(static_cast(entry.status().code()), - entry.status().message())); - } else { - UpdateEventStatus(event_id, OkStatus()); - } - } - } -} - -std::unique_ptr GrpcTpuStream::Allocate( - int32_t core_id, MemoryRegion region, int64_t num_bytes, - absl::Span wait_for) { - auto req = std::make_unique(); - InitializeRequest(req.get(), wait_for); - TraceMe activity("GrpcTpuStream::Allocate(num_bytes)"); - req->mutable_alloc()->set_core_id(core_id); - req->mutable_alloc()->set_region(region); - req->mutable_alloc()->set_num_bytes(num_bytes); - auto event = - std::make_shared(EventId::FromInt(req->operation_id()), this); - AddWriteRequest(std::move(req)); - return std::make_unique(event->id(), std::move(event), - num_bytes); -} - -std::unique_ptr GrpcTpuStream::Allocate( - int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape, - absl::Span wait_for) { - auto req = std::make_unique(); - InitializeRequest(req.get(), wait_for); - TraceMe activity("GrpcTpuStream::Allocate(shape)"); - req->mutable_alloc()->set_core_id(core_id); - req->mutable_alloc()->set_region(region); - *req->mutable_alloc()->mutable_shape() = shape; - auto event = - std::make_shared(EventId::FromInt(req->operation_id()), this); - AddWriteRequest(std::move(req)); - return std::make_unique( - event->id(), std::move(event), ComputeBytesFromShape(shape), shape); -} - -std::unique_ptr GrpcTpuStream::AllocateTuple( - int32_t core_id, MemoryRegion region, - absl::Span children, - absl::Span wait_for) { - auto req = std::make_unique(); - InitializeRequest(req.get(), wait_for); - TraceMe activity("GrpcTpuStream::AllocateTuple"); - req->mutable_alloc_tuple()->set_core_id(core_id); - req->mutable_alloc_tuple()->set_region(region); - for (auto child : children) { - auto grpc_child = static_cast(child); - req->mutable_alloc_tuple()->add_children(grpc_child->id().AsInt()); - } - auto event = - std::make_shared(EventId::FromInt(req->operation_id()), this); - AddWriteRequest(std::move(req)); - return std::make_unique(event->id(), std::move(event), 0); -} - -std::shared_ptr GrpcTpuStream::Deallocate( - std::unique_ptr handle, absl::Span wait_for) { - auto req = std::make_unique(); - InitializeRequest(req.get(), wait_for); - TraceMe activity("GrpcTpuStream::Deallocate"); - auto grpc_handle = static_cast(handle.get()); - req->mutable_dealloc()->set_handle(grpc_handle->id().AsInt()); - auto event = - std::make_shared(EventId::FromInt(req->operation_id()), this); - AddWriteRequest(std::move(req)); - return event; -} - -std::shared_ptr GrpcTpuStream::TransferToDevice( - const void* src, BufferHandle* dst, absl::Span wait_for) { - auto req = std::make_unique(); - InitializeRequest(req.get(), wait_for); - TraceMe activity("GrpcTpuStream::TransferToDevice"); - req->mutable_transfer_to()->mutable_data()->assign( - static_cast(src), dst->size_in_bytes()); - req->mutable_transfer_to()->set_target_handle( - static_cast(dst)->id().AsInt()); - auto event = - std::make_shared(EventId::FromInt(req->operation_id()), this); - AddWriteRequest(std::move(req)); - return event; -} - -std::shared_ptr GrpcTpuStream::TransferFromDevice( - const BufferHandle* src, void* dst, absl::Span wait_for) { - auto req = std::make_unique(); - InitializeRequest(req.get(), wait_for); - TraceMe activity("GrpcTpuStream::TransferFromDevice"); - req->mutable_transfer_from()->set_source_handle( - static_cast(src)->id().AsInt()); - EventId event_id = EventId::FromInt(req->operation_id()); - { - absl::MutexLock lock(&transfers_mutex_); - TransferInfo info(dst, const_cast(src)->size_in_bytes()); - transfers_.insert(std::make_pair(event_id, info)); - } - auto event = std::make_shared(event_id, this); - AddWriteRequest(std::move(req)); - return event; -} - -std::shared_ptr GrpcTpuStream::TransferFromDeviceToDevice( - const BufferHandle* src, BufferHandle* dst, - absl::Span wait_for) { - auto req = std::make_unique(); - InitializeRequest(req.get(), wait_for); - TraceMe activity([&req] { - return absl::StrCat("GrpcTpuStream::TransferFromDeviceToDevice", - req->operation_id()); - }); - - req->mutable_transfer_from_to()->set_source_handle( - static_cast(src)->id().AsInt()); - req->mutable_transfer_from_to()->set_target_handle( - static_cast(dst)->id().AsInt()); - EventId event_id = EventId::FromInt(req->operation_id()); - auto event = std::make_shared(event_id, this); - AddWriteRequest(std::move(req)); - return event; -} - -std::unique_ptr GrpcTpuStream::CompileProgram( - const xla::HloProto& source, int32_t num_replicas, - absl::Span wait_for, const xla::DebugOptions& debug_options) { - auto req = std::make_unique(); - InitializeRequest(req.get(), wait_for); - TraceMe activity("GrpcTpuStream::CompileProgram"); - *req->mutable_compile()->mutable_hlo_program() = source; - req->mutable_compile()->set_num_replicas(num_replicas); - *req->mutable_compile()->mutable_debug_options() = debug_options; - EventId event_id = EventId::FromInt(req->operation_id()); - - auto event = - std::make_shared(EventId::FromInt(req->operation_id()), this); - - auto handle = std::make_unique(event->id(), - std::move(event)); - { - absl::MutexLock lock(&compiles_mutex_); - CompileMetadataInfo info(handle->metadata()); - compiles_.insert(std::make_pair(event_id, info)); - } - - AddWriteRequest(std::move(req)); - return std::move(handle); -} - -std::unique_ptr GrpcTpuStream::LoadProgram( - int32_t core_id, const CompiledProgramHandle* handle, - absl::Span wait_for) { - auto req = std::make_unique(); - InitializeRequest(req.get(), wait_for); - TraceMe activity("GrpcTpuStream::LoadProgram"); - req->mutable_load()->set_core_id(core_id); - auto grpc_handle = static_cast(handle); - if (grpc_handle->id().client_id != driver_->client_id()) { - auto event = std::make_shared( - xla::InvalidArgument("Invalid program handle (wrong client id). Did " - "you restart the server or use a stale handle?")); - return std::make_unique(event->id(), - std::move(event)); - } - req->mutable_load()->set_compiled_program_handle(grpc_handle->id().AsInt()); - auto event = - std::make_shared(EventId::FromInt(req->operation_id()), this); - AddWriteRequest(std::move(req)); - return std::make_unique(event->id(), - std::move(event)); -} - -std::shared_ptr GrpcTpuStream::UnloadProgram( - std::unique_ptr handle, - absl::Span wait_for) { - auto req = std::make_unique(); - InitializeRequest(req.get(), wait_for); - TraceMe activity("GrpcTpuStream::UnloadProgram"); - req->mutable_unload()->set_loaded_program_handle( - static_cast(handle.get())->id().AsInt()); - auto event = - std::make_shared(EventId::FromInt(req->operation_id()), this); - AddWriteRequest(std::move(req)); - return event; -} - -std::shared_ptr GrpcTpuStream::ExecuteProgram( - LoadedProgramHandle* program, absl::Span inputs, - absl::Span outputs, - const xla::DeviceAssignmentProto& device_assignment, - absl::Span wait_for) { - auto req = std::make_unique(); - InitializeRequest(req.get(), wait_for); - auto program_handle = static_cast(program); - if (program_handle->id().client_id != driver_->client_id()) { - return std::make_shared( - xla::InvalidArgument("Invalid program handle (wrong client id). Did " - "you restart the server or use a stale handle?")); - } - - req->mutable_execute()->set_loaded_program_handle( - program_handle->id().AsInt()); - - for (BufferHandle* input : inputs) { - auto* grpc_handle = static_cast(input); - if (grpc_handle->id().client_id != driver_->client_id()) { - return std::make_shared(xla::InvalidArgument( - "Invalid input buffer (wrong client id). Did you restart the server " - "or use a stale handle?")); - } - req->mutable_execute()->add_input_handle(grpc_handle->id().AsInt()); - } - - for (BufferHandle* output : outputs) { - auto* grpc_handle = static_cast(output); - if (grpc_handle->id().client_id != driver_->client_id()) { - return std::make_shared(xla::InvalidArgument( - "Invalid output buffer (wrong client id). Did you restart the server " - "or use a stale handle?")); - } - req->mutable_execute()->add_output_handle( - static_cast(output)->id().AsInt()); - } - // Only pass along device_assignment if it's not default constructed. - if (!(device_assignment.replica_count() == 0 && - device_assignment.computation_count() == 0)) { - *req->mutable_execute()->mutable_device_assignment() = device_assignment; - } - auto event = - std::make_shared(EventId::FromInt(req->operation_id()), this); - AddWriteRequest(std::move(req)); - return event; -} - -/*static*/ std::unique_ptr -GrpcTpuDriver::CreateTpuDriverStub( - const TpuDriverConfig& config, - std::shared_ptr<::grpc::ChannelCredentials> creds) { - ::grpc::ChannelArguments args; - args.SetMaxReceiveMessageSize(std::numeric_limits::max()); - args.SetMaxSendMessageSize(std::numeric_limits::max()); - - // Send at least 20 keep-alives before giving up. - int keepalive_timeout_ms = config.grpc().keepalive_timeout_secs() * 1000; - int keepalive_interval_ms = keepalive_timeout_ms / 20; - - grpc_arg client_arg_vals[] = { - {.type = GRPC_ARG_INTEGER, - .key = const_cast( - GRPC_ARG_HTTP2_MIN_RECV_PING_INTERVAL_WITHOUT_DATA_MS), - .value = {.integer = keepalive_interval_ms}}, - {.type = GRPC_ARG_INTEGER, - .key = const_cast(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA), - .value = {.integer = 0}}, // unlimited - {.type = GRPC_ARG_INTEGER, - .key = const_cast(GRPC_ARG_KEEPALIVE_TIME_MS), - .value = {.integer = keepalive_interval_ms}}, - {.type = GRPC_ARG_INTEGER, - .key = const_cast(GRPC_ARG_KEEPALIVE_TIMEOUT_MS), - .value = {.integer = keepalive_timeout_ms}}, - {.type = GRPC_ARG_INTEGER, - .key = const_cast(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS), - .value = {.integer = 1}}, - {.type = GRPC_ARG_INTEGER, - .key = const_cast(GRPC_ARG_HTTP2_WRITE_BUFFER_SIZE), - .value = {.integer = 64 * 1000 * 1000}}}; - - grpc_channel_args client_args = {.num_args = 6, .args = client_arg_vals}; - args.SetChannelArgs(&client_args); - - // strips out 'grpc://' - auto worker_addr = absl::StripPrefix(config.worker(), kGrpcProtocol); - std::shared_ptr<::grpc::Channel> channel = - ::grpc::CreateCustomChannel(std::string(worker_addr), creds, args); - return grpc::CloudTpuDriver::NewStub(channel); -} - -std::unique_ptr GrpcTpuDriver::AllocateStream(int32_t id) { - auto stub = CreateTpuDriverStub(config_, creds_); - ::grpc::ClientContext ctx; - ctx.set_fail_fast(false); - ctx.set_deadline(std::chrono::system_clock::now() + std::chrono::seconds(10)); - return std::make_unique(id, this, std::move(stub)); -} - -void GrpcTpuDriver::QuerySystemInfo(SystemInfo* system_info) { - auto stub = CreateTpuDriverStub(config_, creds_); - ::grpc::ClientContext ctx; - ctx.set_fail_fast(false); - ctx.set_deadline(std::chrono::system_clock::now() + std::chrono::seconds(10)); - - QuerySystemInfoRequest req; - QuerySystemInfoResponse resp; - ::grpc::Status status = stub->QuerySystemInfo(&ctx, req, &resp); - if (!status.ok()) { - LOG(ERROR) << "QuerySystemInfo request failed: " << status.error_code() - << ": " << status.error_message() << ": " - << status.error_details(); - return; - } - *system_info = resp.system_info(); -} - -Status GrpcTpuDriver::Reset() { - auto stub = CreateTpuDriverStub(config_, creds_); - ::grpc::ClientContext ctx; - ctx.set_fail_fast(false); - ctx.set_deadline(std::chrono::system_clock::now() + std::chrono::seconds(10)); - ResetRequest req; - ResetResponse resp; - ::grpc::Status status = stub->Reset(&ctx, req, &resp); - if (!status.ok()) { - LOG(ERROR) << "Failed to reset the gRPC driver: " << status.error_code() - << ": " << status.error_message() << ": " - << status.error_details(); - return xla::Status(absl::StatusCode(status.error_code()), - absl::StrCat("Failed to reset TPU driver. Error was: ", - status.error_message(), - ". Details: ", status.error_details())); - } - streams_.clear(); - host_stream_.reset(); - return Close(); -} - -Status GrpcTpuDriver::Close() { - auto stub = CreateTpuDriverStub(config_, creds_); - ::grpc::ClientContext ctx; - ctx.set_fail_fast(false); - ctx.set_deadline(std::chrono::system_clock::now() + std::chrono::seconds(10)); - CloseRequest req; - req.set_client_id(client_id_); - CloseResponse resp; - ::grpc::Status status = stub->Close(&ctx, req, &resp); - if (!status.ok()) { - return xla::Status(absl::StatusCode(status.error_code()), - absl::StrCat("Failed to close TPU driver. Error was: ", - status.error_message(), - ". Details: ", status.error_details())); - } - closed_ = true; - return OkStatus(); -} -} // namespace - -xla::StatusOr> CreateGrpcTpuDriver( - const TpuDriverConfig& config, - std::shared_ptr<::grpc::ChannelCredentials> creds) { - auto stub = GrpcTpuDriver::CreateTpuDriverStub(config, creds); - ::grpc::ClientContext ctx; - ctx.set_fail_fast(false); - ctx.set_deadline( - std::chrono::system_clock::now() + - std::chrono::seconds(config.grpc().connection_timeout_secs())); - OpenRequest req; - OpenResponse resp; - ::grpc::Status status = stub->Open(&ctx, req, &resp); - if (!status.ok()) { - LOG(ERROR) << "Failed to open the gRPC driver: " << status.error_code() - << ": " << status.error_message() << ": " - << status.error_details(); - return xla::Status( - absl::StatusCode(status.error_code()), - absl::StrCat( - "Failed to connect to remote server at address: ", config.worker(), - ". Error from gRPC: ", status.error_message(), - ". Details: ", status.error_details())); - } - return std::unique_ptr( - new GrpcTpuDriver(config, creds, resp.client_id())); -} - -REGISTER_TPU_DRIVER( - "grpc://", - [](const TpuDriverConfig& config) - -> xla::StatusOr> { - if (absl::StartsWith(config.worker(), "grpc://localhost")) { - LOG(INFO) << "Using local credentials for localhost: connection."; - return CreateGrpcTpuDriver( - config, ::grpc::experimental::LocalCredentials(LOCAL_TCP)); - } else { - return CreateGrpcTpuDriver(config, - ::grpc::InsecureChannelCredentials()); - } - }); - -} // namespace tpu_driver diff --git a/third_party/xla/xla/python/tpu_driver/grpc_tpu_driver.h b/third_party/xla/xla/python/tpu_driver/grpc_tpu_driver.h deleted file mode 100644 index 11ed3431feee12..00000000000000 --- a/third_party/xla/xla/python/tpu_driver/grpc_tpu_driver.h +++ /dev/null @@ -1,33 +0,0 @@ -#ifndef XLA_PYTHON_TPU_DRIVER_GRPC_TPU_DRIVER_H_ -#define XLA_PYTHON_TPU_DRIVER_GRPC_TPU_DRIVER_H_ - -// Copyright 2019 The OpenXLA Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================== - -#include - -#include "grpcpp/grpcpp.h" -#include "xla/python/tpu_driver/tpu_driver.h" -#include "xla/python/tpu_driver/tpu_driver.pb.h" - -namespace tpu_driver { - -xla::StatusOr> CreateGrpcTpuDriver( - const TpuDriverConfig& config, - std::shared_ptr credentials); - -} // namespace tpu_driver - -#endif // XLA_PYTHON_TPU_DRIVER_GRPC_TPU_DRIVER_H_ diff --git a/third_party/xla/xla/python/tpu_driver/platform/external/compat.h b/third_party/xla/xla/python/tpu_driver/platform/external/compat.h deleted file mode 100644 index 07abc507aa5e5f..00000000000000 --- a/third_party/xla/xla/python/tpu_driver/platform/external/compat.h +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2019 The OpenXLA Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================== - -#ifndef XLA_PYTHON_TPU_DRIVER_PLATFORM_EXTERNAL_COMPAT_H_ -#define XLA_PYTHON_TPU_DRIVER_PLATFORM_EXTERNAL_COMPAT_H_ - -#include // NOLINT - -#include "absl/strings/string_view.h" - -namespace tpu_driver { - -class Thread { - public: - template - explicit Thread(Function&& f, Args&&... args) - : thread_(std::forward(f), std::forward(args)...) {} - void join() { thread_.join(); } - - private: - std::thread thread_; -}; - -class TraceMe { - public: - explicit TraceMe(absl::string_view name, int level = 1) {} - explicit TraceMe(std::string&& name, int level = 1) = delete; - explicit TraceMe(const std::string& name, int level = 1) = delete; - explicit TraceMe(const char* raw, int level = 1) - : TraceMe(absl::string_view(raw), level) {} - template - explicit TraceMe(NameGeneratorT name_generator, int level = 1) {} - ~TraceMe() {} -}; - -} // namespace tpu_driver - -#endif // XLA_PYTHON_TPU_DRIVER_PLATFORM_EXTERNAL_COMPAT_H_ diff --git a/third_party/xla/xla/python/tpu_driver/pod_tpu_driver.cc b/third_party/xla/xla/python/tpu_driver/pod_tpu_driver.cc deleted file mode 100644 index 4b423d03f938c2..00000000000000 --- a/third_party/xla/xla/python/tpu_driver/pod_tpu_driver.cc +++ /dev/null @@ -1,991 +0,0 @@ -// Copyright 2020 The OpenXLA Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/btree_map.h" -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/strings/str_split.h" -#include "absl/synchronization/mutex.h" -#include "xla/pjrt/semaphore.h" -#include "xla/pjrt/worker_thread.h" -#include "xla/python/tpu_driver/grpc_tpu_driver.h" -#include "xla/python/tpu_driver/tpu_driver.h" -#include "xla/python/tpu_driver/tpu_driver.pb.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" - -namespace tpu_driver { -namespace { - -#define CHECK_EXISTS_OR_RETURN(container, target_op_id, operation_id) \ - { \ - auto p = CheckHandleExists(container, target_op_id, operation_id); \ - if (p != nullptr) return p; \ - } - -using xla::OkStatus; -using xla::Status; -using xla::WorkerThread; - -const char kPodTpuDriverPrefix[] = "grpc+pod://"; - -class PodTpuDriver; - -class PodEvent : public Event { - public: - explicit PodEvent(PodTpuDriver* driver, int64_t operation_id) - : driver_(driver), operation_id_(operation_id) {} - int64_t operation_id() const { return operation_id_; } - - xla::Status Await() override; - - std::optional AwaitWithTimeout(absl::Duration duration) override; - - void AddCallback(std::function callback) override; - - private: - PodTpuDriver* driver_; - const int64_t operation_id_; -}; - -class ErrorEvent : public PodEvent { - public: - explicit ErrorEvent(PodTpuDriver* driver, int64_t operation_id, Status status) - : PodEvent(driver, operation_id) { - status_ = status; - } - - xla::Status Await() override { return status_; } - std::optional AwaitWithTimeout( - absl::Duration duration) override { - return status_; - } - void AddCallback(std::function callback) override { - callback(status_); - } - - private: - Status status_; -}; - -class CombinedEvent : public PodEvent { - public: - explicit CombinedEvent(PodTpuDriver* driver, int64_t operation_id, - std::vector> events) - : PodEvent(driver, operation_id), events_(events) { - for (auto& event : events_) { - event->AddCallback([this](Status s) { IncrementAndCheckComplete(s); }); - } - } - - xla::Status Await() override { - for (auto& event : events_) { - TF_RETURN_IF_ERROR(event->Await()); - } - return OkStatus(); - } - - std::optional AwaitWithTimeout( - absl::Duration duration) override { - for (auto& event : events_) { - auto start_time = absl::Now(); - auto status = event->AwaitWithTimeout(duration); - duration -= absl::Now() - start_time; - if (status == std::nullopt) { - return std::nullopt; - } else { - TF_RETURN_IF_ERROR(status.value()); - } - } - return OkStatus(); - } - - void AddCallback(std::function callback) - ABSL_LOCKS_EXCLUDED(mu_) override { - bool all_events_completed = false; - { - absl::MutexLock l(&mu_); - all_events_completed = events_completed_ == events_.size(); - } - if (all_events_completed) { - callback(event_status_); - } else { - absl::MutexLock l(&mu_); - callbacks_.push_back(std::move(callback)); - } - } - - private: - void IncrementAndCheckComplete(Status s) ABSL_LOCKS_EXCLUDED(mu_) { - std::vector> callbacks; - { - absl::MutexLock l(&mu_); - - event_status_ = s; - events_completed_++; - if (events_completed_ == events_.size()) { - // Copy callbacks to a temporary to be invoked outside the mutex. - callbacks.assign(callbacks_.begin(), callbacks_.end()); - callbacks_.clear(); - } else { - return; - } - } - - for (const auto& callback : callbacks) { - callback(event_status_); - } - } - - absl::Mutex mu_; - std::vector> events_; - std::vector> callbacks_ ABSL_GUARDED_BY(mu_); - int64_t events_completed_ ABSL_GUARDED_BY(mu_) = 0; - Status event_status_; -}; - -class PodBufferHandle : public BufferHandle { - public: - explicit PodBufferHandle(PodTpuDriver* driver, int64_t operation_id, - int64_t size_in_bytes, - std::optional shape, - int64_t core_id) - : driver_(driver), - operation_id_(operation_id), - size_in_bytes_(size_in_bytes), - shape_(shape), - event_(std::make_shared(driver_, operation_id_)), - core_id_(core_id) {} - - std::shared_ptr OnReady() override { return event_; } - int64_t size_in_bytes() override { return size_in_bytes_; } - std::optional shape() override { return shape_; } - - int64_t operation_id() const { return operation_id_; } - int64_t core_id() const { return core_id_; } - - private: - PodTpuDriver* driver_; - const int64_t operation_id_; - const int64_t size_in_bytes_; - const std::optional shape_; - std::shared_ptr event_; - const int64_t core_id_; -}; - -class PodCompiledProgramHandle : public CompiledProgramHandle { - public: - explicit PodCompiledProgramHandle(PodTpuDriver* driver, int64_t operation_id) - : driver_(driver), - operation_id_(operation_id), - event_(std::make_shared(driver_, operation_id_)) {} - - std::shared_ptr OnReady() override { return event_; } - - xla::Status program_shape(xla::ProgramShapeProto* program_shape) override; - - int64_t operation_id() const { return operation_id_; } - - private: - PodTpuDriver* driver_; - const int64_t operation_id_; - std::shared_ptr event_; -}; - -class PodLoadedProgramHandle : public LoadedProgramHandle { - public: - explicit PodLoadedProgramHandle(PodTpuDriver* driver, int64_t operation_id, - int64_t core_id) - : driver_(driver), - operation_id_(operation_id), - core_id_(core_id), - event_(std::make_shared(driver_, operation_id_)) {} - - std::shared_ptr OnReady() override { return event_; } - - int64_t operation_id() const { return operation_id_; } - int64_t core_id() const { return core_id_; } - - private: - PodTpuDriver* driver_; - const int64_t operation_id_; - const int64_t core_id_; - std::shared_ptr event_; -}; - -struct EventInFlight { - EventInFlight() - : underlying_event(nullptr), - create_fn(nullptr), - incomplete_deps(), - callbacks() {} - - std::shared_ptr underlying_event; - std::function(void)> create_fn; - - absl::flat_hash_set incomplete_deps; - std::vector> callbacks; -}; - -class PodTpuDriver : public TpuDriver { - public: - explicit PodTpuDriver(const TpuDriverConfig& config, - std::shared_ptr<::grpc::ChannelCredentials> creds) - : config_(config), - creds_(creds), - event_thread_(tsl::Env::Default(), "grpc_pod_event_thread") { - std::vector workers = absl::StrSplit( - absl::StripPrefix(config.worker(), kPodTpuDriverPrefix), ','); - - int worker_count = 0; - - // Flag for environments where local core # == all cores in TPU system #, - // which means that we are connecting to separate TPU systems or we are in - // a test environment. - bool in_local_core_environment = false; - - for (const auto& worker : workers) { - TpuDriverConfig worker_config(config_); - *(worker_config.mutable_worker()) = absl::StrCat("grpc://", worker); - auto tpu_driver = CreateGrpcTpuDriver(worker_config, creds_).value(); - - SystemInfo driver_info; - tpu_driver->QuerySystemInfo(&driver_info); - - if (driver_info.core_count() == driver_info.local_core_size()) { - drivers_.insert({worker_count, std::move(tpu_driver)}); - in_local_core_environment = true; - } else { - drivers_.insert({driver_info.host_id(), std::move(tpu_driver)}); - } - - worker_count++; - } - - absl::flat_hash_set> processed_chips; - - for (int driver_num = 0; driver_num < workers.size(); ++driver_num) { - SystemInfo driver_info; - drivers_[driver_num]->QuerySystemInfo(&driver_info); - - for (const auto& tpu_chip : driver_info.tpu_chip()) { - std::tuple coord{tpu_chip.chip_coord().x(), - tpu_chip.chip_coord().y(), - tpu_chip.chip_coord().z()}; - // We only want to add chips that we have not seen before if we are in a - // TPU pod slice, or we are only seeing local cores (e.g. we are - // connected to individual TPUs or we are in a test environment). - if (!processed_chips.contains(coord) || - driver_info.core_count() == driver_info.local_core_size()) { - *(pod_info_.add_tpu_chip()) = tpu_chip; - processed_chips.insert(coord); - } - } - - *(pod_info_.mutable_cpu()) = driver_info.cpu(); - } - - // Process all the unique chips that we have seen. - int core_count = 0; - for (auto& tpu_chip : *pod_info_.mutable_tpu_chip()) { - for (auto& tpu_core : *tpu_chip.mutable_core()) { - int current_core = tpu_core.id(); - if (in_local_core_environment) { - current_core = core_count; - } - - core_to_driver_.insert( - {current_core, drivers_[tpu_chip.host_id()].get()}); - core_to_driver_id_.insert({current_core, tpu_chip.host_id()}); - core_to_driver_core_.insert({current_core, tpu_core.id()}); - - tpu_core.set_id(current_core); - tpu_core.set_core_on_host_index(current_core); - *(pod_info_.add_local_core()) = tpu_core; - - core_count++; - } - - // We are setting host_id to zero because we want this to look like one - // host with many cores from the perspective of tpu_client.cc. - tpu_chip.set_host_id(0); - } - - pod_info_.set_chip_count(pod_info_.tpu_chip_size()); - pod_info_.set_core_count(pod_info_.local_core_size()); - - // We want this to look like one host with many TPU chips/cores connected. - pod_info_.set_host_count(1); - pod_info_.set_host_id(0); - } - - ~PodTpuDriver() override { - // TODO(frankchn): Unload all handles, and wait for all events to finish. - } - - void QuerySystemInfo(SystemInfo* system_info) override { - *system_info = pod_info_; - } - - xla::Status Reset() override { - for (auto& driver : drivers_) { - TF_RETURN_IF_ERROR(driver.second->Reset()); - } - return OkStatus(); - } - - std::unique_ptr Allocate( - int32_t core_id, MemoryRegion region, int64_t num_bytes, - absl::Span wait_for) override { - int64_t operation_id = GetOperationId(); - auto deps = GetDependencyOperationIds(wait_for); - - ScheduleRequest( - operation_id, - [this, core_id, region, num_bytes, operation_id]() - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - underlying_buffers_.insert( - {operation_id, - core_to_driver_[core_id]->Allocate( - core_to_driver_core_[core_id], region, num_bytes, {})}); - return underlying_buffers_[operation_id]->OnReady(); - }, - deps); - - return std::make_unique(this, operation_id, num_bytes, - std::nullopt, core_id); - } - - std::unique_ptr Allocate( - int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape, - absl::Span wait_for) override { - int64_t operation_id = GetOperationId(); - auto deps = GetDependencyOperationIds(wait_for); - - ScheduleRequest( - operation_id, - [this, core_id, region, shape, operation_id]() - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - underlying_buffers_.insert( - {operation_id, - core_to_driver_[core_id]->Allocate( - core_to_driver_core_[core_id], region, shape, {})}); - return underlying_buffers_[operation_id]->OnReady(); - }, - deps); - - return std::make_unique( - this, operation_id, ComputeBytesFromShape(shape), shape, core_id); - } - - std::unique_ptr AllocateTuple( - int32_t core_id, MemoryRegion region, - absl::Span children, - absl::Span wait_for) override { - int64_t operation_id = GetOperationId(); - auto deps = GetDependencyOperationIds(wait_for); - - std::vector children_ids; - const size_t children_ids_size = children.size(); - children_ids.reserve(children_ids_size); - for (size_t i = 0; i < children_ids_size; ++i) { - auto child_op_id = - static_cast(children[i])->operation_id(); - deps.insert(child_op_id); - children_ids.push_back(child_op_id); - } - - ScheduleRequest( - operation_id, - [this, core_id, region, children_ids, operation_id]() - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr { - std::vector child_buffers; - child_buffers.reserve(children_ids.size()); - for (size_t i = 0; i < children_ids.size(); ++i) { - CHECK_EXISTS_OR_RETURN(underlying_buffers_, children_ids[i], - operation_id); - child_buffers.push_back( - underlying_buffers_[children_ids[i]].get()); - } - - underlying_buffers_.insert( - {operation_id, core_to_driver_[core_id]->AllocateTuple( - core_to_driver_core_[core_id], region, - child_buffers, {})}); - return underlying_buffers_[operation_id]->OnReady(); - }, - deps); - - return std::make_unique(this, operation_id, 0, - std::nullopt, core_id); - } - - std::shared_ptr Deallocate( - std::unique_ptr handle, - absl::Span wait_for) override { - int64_t operation_id = GetOperationId(); - auto deps = GetDependencyOperationIds(wait_for); - deps.insert(static_cast(handle.get())->operation_id()); - - auto op_id = static_cast(handle.get())->operation_id(); - auto core_id = static_cast(handle.get())->core_id(); - - ScheduleRequest( - operation_id, - [this, operation_id, op_id, core_id]() - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr { - CHECK_EXISTS_OR_RETURN(underlying_buffers_, op_id, operation_id); - - auto buf_iter = underlying_buffers_.find(op_id); - auto underlying_hn = std::move(buf_iter->second); - underlying_buffers_.erase(buf_iter); - - return core_to_driver_[core_id]->Deallocate( - std::move(underlying_hn), {}); - }, - deps); - - return std::make_shared(this, operation_id); - } - - std::shared_ptr TransferToDevice( - const void* src, BufferHandle* dst, - absl::Span wait_for) override { - int64_t operation_id = GetOperationId(); - auto deps = GetDependencyOperationIds(wait_for); - deps.insert(static_cast(dst)->operation_id()); - - auto op_id = static_cast(dst)->operation_id(); - auto core_id = static_cast(dst)->core_id(); - - ScheduleRequest( - operation_id, - [this, src, operation_id, op_id, core_id]() - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr { - CHECK_EXISTS_OR_RETURN(underlying_buffers_, op_id, operation_id); - - auto buf_iter = underlying_buffers_.find(op_id); - return core_to_driver_[core_id]->TransferToDevice( - src, buf_iter->second.get(), {}); - }, - deps); - - return std::make_shared(this, operation_id); - } - - std::shared_ptr TransferFromDevice( - const BufferHandle* src, void* dst, - absl::Span wait_for) override { - int64_t operation_id = GetOperationId(); - auto deps = GetDependencyOperationIds(wait_for); - deps.insert(static_cast(src)->operation_id()); - - auto op_id = static_cast(src)->operation_id(); - auto core_id = static_cast(src)->core_id(); - - ScheduleRequest( - operation_id, - [this, dst, operation_id, op_id, core_id]() - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr { - CHECK_EXISTS_OR_RETURN(underlying_buffers_, op_id, operation_id); - auto buf_iter = underlying_buffers_.find(op_id); - return core_to_driver_[core_id]->TransferFromDevice( - buf_iter->second.get(), dst, {}); - }, - deps); - - return std::make_shared(this, operation_id); - } - - std::shared_ptr TransferFromDeviceToDevice( - const BufferHandle* src, BufferHandle* dst, - absl::Span wait_for) override { - auto src_core_id = static_cast(src)->core_id(); - auto dst_core_id = static_cast(dst)->core_id(); - - auto src_driver_id = core_to_driver_id_[src_core_id]; - auto dst_driver_id = core_to_driver_id_[dst_core_id]; - - if (src_driver_id == dst_driver_id) { - // They are in the same host, we can schedule it normally - int64_t operation_id = GetOperationId(); - auto deps = GetDependencyOperationIds(wait_for); - deps.insert(static_cast(src)->operation_id()); - deps.insert(static_cast(dst)->operation_id()); - - auto src_op_id = static_cast(src)->operation_id(); - auto dst_op_id = static_cast(dst)->operation_id(); - - ScheduleRequest( - operation_id, - [this, operation_id, src_op_id, dst_op_id, dst_core_id]() - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr { - CHECK_EXISTS_OR_RETURN(underlying_buffers_, src_op_id, - operation_id); - CHECK_EXISTS_OR_RETURN(underlying_buffers_, dst_op_id, - operation_id); - - auto src_iter = underlying_buffers_.find(src_op_id); - auto dst_iter = underlying_buffers_.find(dst_op_id); - return core_to_driver_[dst_core_id]->TransferFromDeviceToDevice( - src_iter->second.get(), dst_iter->second.get(), {}); - }, - deps); - return std::make_shared(this, operation_id); - } else { - // src and dst are on different hosts, we have to bounce through us. - auto dst_size = dst->size_in_bytes(); - char* host_buf = new char[dst_size]; - - auto src_event = TransferFromDevice(src, host_buf, wait_for); - auto dst_event = TransferToDevice(host_buf, dst, {src_event.get()}); - dst_event->AddCallback( - [src_event, host_buf](xla::Status status) { delete[] host_buf; }); - return dst_event; - } - } - - std::unique_ptr CompileProgram( - const xla::HloProto& source, int32_t num_replicas, - absl::Span wait_for, - const xla::DebugOptions& debug_options) override { - int64_t operation_id = GetOperationId(); - auto deps = GetDependencyOperationIds(wait_for); - - ScheduleRequest( - operation_id, - [this, operation_id, source, num_replicas, - debug_options]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - auto cph_iterator = - underlying_cph_ - .insert( - {operation_id, - std::vector>()}) - .first; - - std::vector> collected_events; - for (int i = 0; i < drivers_.size(); ++i) { - auto current_cph = drivers_[i]->CompileProgram(source, num_replicas, - {}, debug_options); - cph_iterator->second.push_back(std::move(current_cph)); - collected_events.push_back(cph_iterator->second[i]->OnReady()); - } - return std::make_shared(this, operation_id, - collected_events); - }, - deps); - - return std::make_unique(this, operation_id); - } - - std::unique_ptr LoadProgram( - int32_t core_id, const CompiledProgramHandle* handle, - absl::Span wait_for) override { - int64_t operation_id = GetOperationId(); - auto deps = GetDependencyOperationIds(wait_for); - deps.insert( - static_cast(handle)->operation_id()); - auto cph_op_id = - static_cast(handle)->operation_id(); - - ScheduleRequest( - operation_id, - [this, operation_id, cph_op_id, core_id]() - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr { - CHECK_EXISTS_OR_RETURN(underlying_cph_, cph_op_id, operation_id); - auto cph_iter = underlying_cph_.find(cph_op_id); - - underlying_lph_.insert( - {operation_id, - core_to_driver_[core_id]->LoadProgram( - core_to_driver_core_[core_id], - cph_iter->second[core_to_driver_id_[core_id]].get(), - {})}); - - return underlying_lph_[operation_id]->OnReady(); - }, - deps); - - return std::make_unique(this, operation_id, - core_id); - } - - std::shared_ptr UnloadProgram( - std::unique_ptr handle, - absl::Span wait_for) override { - int64_t operation_id = GetOperationId(); - auto deps = GetDependencyOperationIds(wait_for); - deps.insert( - static_cast(handle.get())->operation_id()); - auto op_id = - static_cast(handle.get())->operation_id(); - auto core_id = - static_cast(handle.get())->core_id(); - - ScheduleRequest( - operation_id, - [this, operation_id, op_id, core_id]() - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr { - CHECK_EXISTS_OR_RETURN(underlying_lph_, op_id, operation_id); - auto lph_iter = underlying_lph_.find(op_id); - auto event = core_to_driver_[core_id]->UnloadProgram( - std::move(lph_iter->second), {}); - underlying_lph_.erase(lph_iter); - - return event; - }, - deps); - - return std::make_shared(this, operation_id); - } - - std::shared_ptr ExecuteProgram( - LoadedProgramHandle* program, absl::Span inputs, - absl::Span outputs, - const xla::DeviceAssignmentProto& device_assignment, - absl::Span wait_for) override { - int64_t operation_id = GetOperationId(); - - auto deps = GetDependencyOperationIds(wait_for); - deps.insert(static_cast(program)->operation_id()); - - auto op_id = static_cast(program)->operation_id(); - auto core_id = static_cast(program)->core_id(); - - std::vector input_op_ids; - std::vector output_op_ids; - input_op_ids.reserve(inputs.size()); - output_op_ids.reserve(outputs.size()); - - for (auto* input : inputs) { - auto input_dep = - static_cast(input)->operation_id(); - input_op_ids.push_back(input_dep); - deps.insert(input_dep); - } - for (auto* output : outputs) { - auto output_dep = - static_cast(output)->operation_id(); - output_op_ids.push_back(output_dep); - deps.insert(output_dep); - } - - ScheduleRequest( - operation_id, - [this, operation_id, core_id, op_id, input_op_ids, output_op_ids, - device_assignment]() - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr { - std::vector underlying_inputs; - std::vector underlying_outputs; - - underlying_inputs.reserve(input_op_ids.size()); - for (auto input_op_id : input_op_ids) { - CHECK_EXISTS_OR_RETURN(underlying_buffers_, input_op_id, - operation_id); - underlying_inputs.push_back( - underlying_buffers_[input_op_id].get()); - } - underlying_outputs.reserve(output_op_ids.size()); - for (auto output_op_id : output_op_ids) { - CHECK_EXISTS_OR_RETURN(underlying_buffers_, output_op_id, - operation_id); - underlying_outputs.push_back( - underlying_buffers_[output_op_id].get()); - } - - CHECK_EXISTS_OR_RETURN(underlying_lph_, op_id, operation_id); - LoadedProgramHandle* handle = underlying_lph_[op_id].get(); - return core_to_driver_[core_id]->ExecuteProgram( - handle, underlying_inputs, underlying_outputs, - device_assignment, {}); - }, - deps); - - return std::make_shared(this, operation_id); - } - - std::unique_ptr GetLinearizer() override { - return drivers_[0]->GetLinearizer(); - } - - // Helper methods for Event scheduling - - std::optional WaitForEvent(int64_t event_id, absl::Duration duration) - ABSL_LOCKS_EXCLUDED(mu_) { - std::shared_ptr underlying_event; - - { - absl::MutexLock l(&mu_); - auto event = events_.find(event_id); - - if (event == events_.end()) { - auto event_status = abnormal_event_status_.find(event_id); - if (event_status == abnormal_event_status_.end()) { - return OkStatus(); - } else { - return event_status->second; - } - } - - auto done = [this, event_id]() { - mu_.AssertHeld(); - // The event was either completed and erased from the map or we have - // an underlying event available to us. - return events_.count(event_id) == 0 || - (events_[event_id]->underlying_event != nullptr && - events_[event_id]->underlying_event.use_count() != 0); - }; - - auto status = mu_.AwaitWithTimeout(absl::Condition(&done), duration); - if (!status) { - return std::nullopt; - } - - if (events_.count(event_id) > 0) { - underlying_event = events_[event_id]->underlying_event; - } else { - underlying_event = nullptr; - } - } - - // Wait for the underlying event without holding on to the event_lock_, or - // else incoming events will not be processed. - if (underlying_event != nullptr) { - return underlying_event->AwaitWithTimeout(duration); - } else { - absl::MutexLock l(&mu_); - auto event_status = abnormal_event_status_.find(event_id); - if (event_status == abnormal_event_status_.end()) { - return OkStatus(); - } else { - return event_status->second; - } - } - } - - void AddCallbackForEvent(int64_t event_id, std::function fn) - ABSL_LOCKS_EXCLUDED(mu_) { - absl::MutexLock l(&mu_); - auto event = events_.find(event_id); - - if (event == events_.end()) { - auto event_status = abnormal_event_status_.find(event_id); - if (event_status == abnormal_event_status_.end()) { - fn(OkStatus()); - } else { - fn(event_status->second); - } - } else { - if (event->second->underlying_event != nullptr && - event->second->underlying_event.use_count() != 0) { - event->second->underlying_event->AddCallback(fn); - } else { - event->second->callbacks.push_back(std::move(fn)); - } - } - } - - xla::Status GetCompiledProgramShape(int64_t op_id, - xla::ProgramShapeProto* program_shape) - ABSL_LOCKS_EXCLUDED(mu_) { - absl::MutexLock l(&mu_); - - auto done = [this, op_id]() { - mu_.AssertHeld(); - return underlying_cph_.contains(op_id); - }; - mu_.Await(absl::Condition(&done)); - - return underlying_cph_[op_id][0]->program_shape(program_shape); - } - - private: - const TpuDriverConfig& config_; - std::shared_ptr<::grpc::ChannelCredentials> creds_; - - absl::flat_hash_map> drivers_; - absl::flat_hash_map core_to_driver_id_; - absl::flat_hash_map core_to_driver_; - absl::flat_hash_map core_to_driver_core_; - SystemInfo pod_info_; - - absl::Mutex mu_; - - absl::flat_hash_map> - underlying_buffers_ ABSL_GUARDED_BY(mu_); - absl::flat_hash_map>> - underlying_cph_ ABSL_GUARDED_BY(mu_); - absl::flat_hash_map> - underlying_lph_ ABSL_GUARDED_BY(mu_); - - absl::btree_map> events_ - ABSL_GUARDED_BY(mu_); - absl::flat_hash_map abnormal_event_status_ - ABSL_GUARDED_BY(mu_); - - std::atomic operation_id_counter_{0}; - - WorkerThread event_thread_; - - int64_t GetOperationId() { return operation_id_counter_++; } - - absl::flat_hash_set GetDependencyOperationIds( - absl::Span wait_for) { - absl::flat_hash_set deps; - for (auto* event : wait_for) { - deps.insert(static_cast(event)->operation_id()); - } - return deps; - } - - // EventCompleted is executed on the event_thread_ worker thread. We want - // to propagate the fact that the event is completed to any subsequent events - // that might depend on this event. - void EventCompleted(int64_t event_id, Status status) - ABSL_LOCKS_EXCLUDED(mu_) { - absl::MutexLock l(&mu_); - - absl::btree_map>::iterator - curr_event; - if (!status.ok()) abnormal_event_status_.insert({event_id, status}); - curr_event = events_.find(event_id); - - DCHECK(curr_event->second->callbacks.empty()); - DCHECK(curr_event->second->incomplete_deps.empty()); - - for (auto& event : events_) { - event.second->incomplete_deps.erase(event_id); - // The if statement conditions on both - // - all previous events have completed (incomplete_deps.empty()) - // - the op creating this event has not been called yet - // (event.second.create_fn != nullptr) - // We call the create_fn that creates the event and adds any relevant - // callbacks to the actual event, before setting create_fn to nullptr - // to indicate that it has already been called - if (event.second->incomplete_deps.empty() && - event.second->create_fn != nullptr) { - // We were the last unfilled dependency, all other dependencies are - // filled. We can now fire the create function. - event.second->underlying_event = event.second->create_fn(); - for (auto& fn : event.second->callbacks) { - event.second->underlying_event->AddCallback(std::move(fn)); - } - event.second->callbacks.clear(); - event.second->create_fn = nullptr; - } - } - - // We erase the current event to signal that it has finished. - events_.erase(curr_event); - } - - void ScheduleRequest(int64_t operation_id, - std::function(void)> fn, - const absl::flat_hash_set& deps) - ABSL_LOCKS_EXCLUDED(mu_) { - absl::MutexLock l(&mu_); - absl::btree_map>::iterator event; - absl::flat_hash_set incomplete_deps; - - event = - events_.insert({operation_id, std::make_unique()}).first; - for (const auto& dep : deps) { - if (events_.count(dep) > 0) incomplete_deps.insert(dep); - } - - if (incomplete_deps.empty()) { - // All dependencies have been fulfilled, we execute the request - // immediately and add a callback to inform our event fulfilled thread - // when it is done. - event->second->create_fn = nullptr; - event->second->underlying_event = fn(); - event->second->underlying_event->AddCallback( - [this, operation_id](Status status) { - event_thread_.Schedule([this, operation_id, status]() { - EventCompleted(operation_id, status); - }); - }); - } else { - // There are some dependencies that are not yet fulfilled. We attach - // the request to the event, and will execute it in the EventFulfilled - // worker thread when all its dependencies are fulfilled. - event->second->create_fn = std::move(fn); - event->second->incomplete_deps = std::move(incomplete_deps); - event->second->callbacks.push_back([this, operation_id](Status status) { - event_thread_.Schedule([this, operation_id, status]() { - EventCompleted(operation_id, status); - }); - }); - } - } - - template - std::shared_ptr CheckHandleExists( - absl::flat_hash_map& container, int64_t target_op_id, - int64_t operation_id) { - if (container.count(target_op_id) == 0) { - return std::make_shared( - this, operation_id, - tsl::errors::InvalidArgument("Handle ", target_op_id, - " does not exist.")); - } - return nullptr; - } -}; - -xla::Status PodEvent::Await() { - return driver_->WaitForEvent(operation_id_, absl::InfiniteDuration()).value(); -} - -std::optional PodEvent::AwaitWithTimeout(absl::Duration duration) { - return driver_->WaitForEvent(operation_id_, duration); -} - -void PodEvent::AddCallback(std::function callback) { - driver_->AddCallbackForEvent(operation_id_, std::move(callback)); -} - -xla::StatusOr> CreatePodTpuDriver( - const TpuDriverConfig& config, - std::shared_ptr<::grpc::ChannelCredentials> creds) { - return std::unique_ptr(new PodTpuDriver(config, creds)); -} - -xla::Status PodCompiledProgramHandle::program_shape( - xla::ProgramShapeProto* program_shape) { - return driver_->GetCompiledProgramShape(operation_id(), program_shape); -} - -} // namespace - -REGISTER_TPU_DRIVER(kPodTpuDriverPrefix, - [](const TpuDriverConfig& config) - -> xla::StatusOr> { - return CreatePodTpuDriver( - config, - ::grpc::InsecureChannelCredentials()); // NOLINT - }); - -} // namespace tpu_driver diff --git a/third_party/xla/xla/python/tpu_driver/recording_tpu_driver.cc b/third_party/xla/xla/python/tpu_driver/recording_tpu_driver.cc deleted file mode 100644 index 627ce469ffcd9e..00000000000000 --- a/third_party/xla/xla/python/tpu_driver/recording_tpu_driver.cc +++ /dev/null @@ -1,590 +0,0 @@ -// Copyright 2019 The OpenXLA Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -#include -#include -#include -#include -#include -#include -#include - -#include "absl/base/internal/sysinfo.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "xla/python/tpu_driver/platform/external/compat.h" -#include "xla/python/tpu_driver/tpu_driver.h" -#include "xla/python/tpu_driver/tpu_driver.pb.h" -#include "xla/python/tpu_driver/tpu_service.grpc.pb.h" -#include "tsl/platform/file_system.h" -#include "tsl/platform/threadpool.h" - -/* - * The ReplayDriver wraps a concrete TpuDriver implementation and records the - * stream of operations to a log file. This log can be later replayed and - * analyzed for debugging. - */ - -namespace tpu_driver { -namespace { - -static std::atomic id_counter(0); - -using xla::Status; - -class RecordingTpuDriver; - -class RecordingEvent : public Event { - public: - explicit RecordingEvent(std::shared_ptr event) - : shared_event_(std::move(event)), id_(id_counter++) {} - - explicit RecordingEvent(std::shared_ptr event, int64_t id) - : shared_event_(event), id_(id) {} - - ~RecordingEvent() override = default; - - xla::Status Await() override { return shared_event_->Await(); } - - std::optional AwaitWithTimeout( - absl::Duration duration) override { - return shared_event_->AwaitWithTimeout(duration); - } - - void AddCallback(std::function callback) override { - return shared_event_->AddCallback(callback); - } - - private: - std::shared_ptr shared_event_; - - int64_t id_; - friend class RecordingTpuDriver; -}; - -class RecordingBufferHandle : public BufferHandle { - public: - explicit RecordingBufferHandle(std::unique_ptr handle) - : handle_(std::move(handle)), - id_(id_counter++), - event_(std::make_shared(handle_->OnReady(), id_)) {} - std::shared_ptr OnReady() override { return event_; } - int64_t size_in_bytes() override { return handle_->size_in_bytes(); } - std::optional shape() override { return handle_->shape(); } - - private: - std::unique_ptr handle_; - int64_t id_; - std::shared_ptr event_; - friend class RecordingTpuDriver; -}; - -class RecordingCompiledProgramHandle : public CompiledProgramHandle { - public: - explicit RecordingCompiledProgramHandle( - std::unique_ptr handle) - : handle_(std::move(handle)), - id_(id_counter++), - event_(std::make_shared(handle_->OnReady(), id_)) {} - std::shared_ptr OnReady() override { return event_; } - int64_t size_in_bytes() override { return handle_->size_in_bytes(); } - xla::Status program_shape(xla::ProgramShapeProto* program_shape) override { - return handle_->program_shape(program_shape); - } - - private: - std::unique_ptr handle_; - int64_t id_; - std::shared_ptr event_; - friend class RecordingTpuDriver; -}; - -class RecordingLoadedProgramHandle : public LoadedProgramHandle { - public: - explicit RecordingLoadedProgramHandle( - std::unique_ptr handle) - : handle_(std::move(handle)), - id_(id_counter++), - event_(std::make_shared(handle_->OnReady(), id_)) {} - std::shared_ptr OnReady() override { return event_; } - int64_t size_in_bytes() override { return handle_->size_in_bytes(); } - - private: - std::unique_ptr handle_; - int64_t id_; - std::shared_ptr event_; - friend class RecordingTpuDriver; -}; - -class RecordingTpuDriver : public TpuDriver { - public: - explicit RecordingTpuDriver(std::unique_ptr driver, - const std::string recording_path, - const bool flush) - : driver_(std::move(driver)), - recording_path_(recording_path), - flush_(flush) { - auto file_status = - tsl::Env::Default()->NewAppendableFile(recording_path_, &log_file_); - if (!file_status.ok()) { - LOG(FATAL) << "Unable to open " << recording_path_ - << " for appending. Error: " << file_status; - } - } - ~RecordingTpuDriver() override { - { - log_file_->Flush().IgnoreError(); - log_file_->Close().IgnoreError(); - log_file_ = nullptr; - } - } - - void QuerySystemInfo(SystemInfo* system_info) override { - // TODO(frankchn): Should we even save this event, since it is out-of-band. - driver_->QuerySystemInfo(system_info); - } - - Status Reset() override { return driver_->Reset(); } - - std::unique_ptr Allocate( - int32_t core_id, MemoryRegion region, int64_t num_bytes, - absl::Span wait_for) override { - auto unwrapped_wait_for = UnwrapWaitFor(wait_for); - - auto thread_id = GetCurrentThreadId(); - auto handle = - driver_->Allocate(core_id, region, num_bytes, unwrapped_wait_for); - auto recording_handle = - std::make_unique(std::move(handle)); - auto handle_id = recording_handle->id_; - - { - StreamRequest::Entry r; - r.mutable_alloc()->set_core_id(core_id); - r.mutable_alloc()->set_region(region); - r.mutable_alloc()->set_num_bytes(num_bytes); - - PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id); - } - - return recording_handle; - } - - std::unique_ptr Allocate( - int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape, - absl::Span wait_for) override { - auto unwrapped_wait_for = UnwrapWaitFor(wait_for); - - auto thread_id = GetCurrentThreadId(); - auto handle = driver_->Allocate(core_id, region, shape, unwrapped_wait_for); - auto recording_handle = - std::make_unique(std::move(handle)); - auto handle_id = recording_handle->id_; - - { - StreamRequest::Entry r; - r.mutable_alloc()->set_core_id(core_id); - r.mutable_alloc()->set_region(region); - *(r.mutable_alloc()->mutable_shape()) = shape; - - PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id); - } - - return recording_handle; - } - - std::unique_ptr AllocateTuple( - int32_t core_id, MemoryRegion region, - absl::Span children, - absl::Span wait_for) override { - auto unwrapped_wait_for = UnwrapWaitFor(wait_for); - - std::vector unwrapped_children; - std::vector child_ids; - const auto children_size = children.size(); - unwrapped_children.reserve(children_size); - child_ids.reserve(children_size); - for (auto child : children) { - BufferHandle* unwrapped_child = - static_cast(child)->handle_.get(); - unwrapped_children.push_back(unwrapped_child); - child_ids.push_back( - static_cast(child)->id_); - } - - auto thread_id = GetCurrentThreadId(); - auto handle = driver_->AllocateTuple(core_id, region, unwrapped_children, - unwrapped_wait_for); - auto recording_handle = - std::make_unique(std::move(handle)); - auto handle_id = recording_handle->id_; - - { - StreamRequest::Entry r; - r.mutable_alloc_tuple()->set_core_id(core_id); - r.mutable_alloc_tuple()->set_region(region); - - for (auto child : child_ids) { - r.mutable_alloc_tuple()->add_children(child); - } - - PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id); - } - - return recording_handle; - } - - std::shared_ptr Deallocate( - std::unique_ptr handle, - absl::Span wait_for) override { - auto unwrapped_wait_for = UnwrapWaitFor(wait_for); - - auto thread_id = GetCurrentThreadId(); - auto recording_handle = static_cast(handle.get()); - int64_t recording_handle_id = recording_handle->id_; - auto event = driver_->Deallocate(std::move(recording_handle->handle_), - unwrapped_wait_for); - auto recording_event = std::make_shared(std::move(event)); - int64_t event_id = recording_event->id_; - - { - StreamRequest::Entry r; - r.mutable_dealloc()->set_handle(recording_handle_id); - PopulateAndSaveEntry(&r, wait_for, event_id, thread_id); - } - - return recording_event; - } - - std::shared_ptr TransferToDevice( - const void* src, BufferHandle* dst, - absl::Span wait_for) override { - int64_t num_bytes = dst->size_in_bytes(); - auto unwrapped_wait_for = UnwrapWaitFor(wait_for); - - auto thread_id = GetCurrentThreadId(); - auto recording_handle = static_cast(dst); - int64_t recording_handle_id = recording_handle->id_; - auto recording_event = - std::make_shared(driver_->TransferToDevice( - src, static_cast(dst)->handle_.get(), - unwrapped_wait_for)); - int64_t event_id = recording_event->id_; - - { - StreamRequest::Entry r; - r.mutable_transfer_to()->set_target_handle(recording_handle_id); - if (num_bytes > 0) { - r.mutable_transfer_to()->mutable_data()->assign( - static_cast(src), num_bytes); - } else { - *r.mutable_transfer_to()->mutable_data() = ""; - } - PopulateAndSaveEntry(&r, wait_for, event_id, thread_id); - } - - return recording_event; - } - - std::shared_ptr TransferFromDevice( - const BufferHandle* src, void* dst, - absl::Span wait_for) override { - auto unwrapped_wait_for = UnwrapWaitFor(wait_for); - - auto thread_id = GetCurrentThreadId(); - auto src_handle_id = static_cast(src)->id_; - auto recording_event = - std::make_shared(driver_->TransferFromDevice( - static_cast(src)->handle_.get(), dst, - unwrapped_wait_for)); - auto event_id = recording_event->id_; - - { - StreamRequest::Entry r; - r.mutable_transfer_from()->set_source_handle(src_handle_id); - PopulateAndSaveEntry(&r, wait_for, event_id, thread_id); - } - - return recording_event; - } - - std::shared_ptr TransferFromDeviceToDevice( - const BufferHandle* src, BufferHandle* dst, - absl::Span wait_for) override { - auto unwrapped_wait_for = UnwrapWaitFor(wait_for); - - auto thread_id = GetCurrentThreadId(); - auto src_handle_id = static_cast(src)->id_; - auto dst_handle_id = static_cast(dst)->id_; - auto recording_event = - std::make_shared(driver_->TransferFromDeviceToDevice( - static_cast(src)->handle_.get(), - static_cast(dst)->handle_.get(), - unwrapped_wait_for)); - auto event_id = recording_event->id_; - - { - StreamRequest::Entry r; - r.mutable_transfer_from_to()->set_source_handle(src_handle_id); - r.mutable_transfer_from_to()->set_target_handle(dst_handle_id); - PopulateAndSaveEntry(&r, wait_for, event_id, thread_id); - } - - return recording_event; - } - - std::unique_ptr CompileProgram( - const xla::HloProto& source, int32_t num_replicas, - absl::Span wait_for, - const xla::DebugOptions& debug_options) override { - auto unwrapped_wait_for = UnwrapWaitFor(wait_for); - - auto thread_id = GetCurrentThreadId(); - auto recording_handle = std::make_unique( - driver_->CompileProgram(source, num_replicas, unwrapped_wait_for, - debug_options)); - auto handle_id = recording_handle->id_; - - { - StreamRequest::Entry r; - *r.mutable_compile()->mutable_hlo_program() = source; - r.mutable_compile()->set_num_replicas(num_replicas); - PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id); - } - - return recording_handle; - } - - std::unique_ptr LoadProgram( - int32_t core_id, const CompiledProgramHandle* handle, - absl::Span wait_for) override { - auto unwrapped_wait_for = UnwrapWaitFor(wait_for); - - auto thread_id = GetCurrentThreadId(); - auto compiled_handle_id = - static_cast(handle)->id_; - auto recording_handle = - std::make_unique(driver_->LoadProgram( - core_id, - static_cast(handle) - ->handle_.get(), - unwrapped_wait_for)); - auto handle_id = recording_handle->id_; - { - StreamRequest::Entry r; - r.mutable_load()->set_core_id(core_id); - r.mutable_load()->set_compiled_program_handle(compiled_handle_id); - PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id); - } - - return recording_handle; - } - - std::shared_ptr UnloadProgram( - std::unique_ptr handle, - absl::Span wait_for) override { - auto unwrapped_wait_for = UnwrapWaitFor(wait_for); - - auto thread_id = GetCurrentThreadId(); - auto loaded_handle_id = - static_cast(handle.get())->id_; - auto recording_event = - std::make_shared(driver_->UnloadProgram( - std::move(static_cast(handle.get()) - ->handle_), - unwrapped_wait_for)); - auto event_id = recording_event->id_; - - { - StreamRequest::Entry r; - r.mutable_unload()->set_loaded_program_handle(loaded_handle_id); - PopulateAndSaveEntry(&r, wait_for, event_id, thread_id); - } - - return recording_event; - } - - std::shared_ptr ExecuteProgram( - LoadedProgramHandle* program, absl::Span inputs, - absl::Span outputs, - const xla::DeviceAssignmentProto& device_assignment, - absl::Span wait_for) override { - auto unwrapped_wait_for = UnwrapWaitFor(wait_for); - - auto thread_id = GetCurrentThreadId(); - auto program_handle_id = - static_cast(program)->id_; - - std::vector unwrapped_inputs; - std::vector input_ids; - const auto inputs_size = inputs.size(); - unwrapped_inputs.reserve(inputs_size); - input_ids.reserve(inputs_size); - for (auto input : inputs) { - BufferHandle* unwrapped_input = - static_cast(input)->handle_.get(); - unwrapped_inputs.push_back(unwrapped_input); - input_ids.push_back( - static_cast(input)->id_); - } - - std::vector unwrapped_outputs; - std::vector output_ids; - const auto output_size = outputs.size(); - unwrapped_outputs.reserve(output_size); - output_ids.reserve(output_size); - for (auto output : outputs) { - BufferHandle* unwrapped_output = - static_cast(output)->handle_.get(); - unwrapped_outputs.push_back(unwrapped_output); - output_ids.push_back( - static_cast(output)->id_); - } - - auto recording_event = - std::make_shared(driver_->ExecuteProgram( - static_cast(program)->handle_.get(), - unwrapped_inputs, unwrapped_outputs, device_assignment, - unwrapped_wait_for)); - auto event_id = recording_event->id_; - - { - StreamRequest::Entry r; - r.mutable_execute()->set_loaded_program_handle(program_handle_id); - for (auto input_id : input_ids) { - r.mutable_execute()->add_input_handle(input_id); - } - for (auto output_id : output_ids) { - r.mutable_execute()->add_output_handle(output_id); - } - *r.mutable_execute()->mutable_device_assignment() = device_assignment; - - PopulateAndSaveEntry(&r, wait_for, event_id, thread_id); - } - - return recording_event; - } - - std::unique_ptr GetLinearizer() override { - return driver_->GetLinearizer(); - } - - private: - std::unique_ptr driver_; - const std::string recording_path_; - const bool flush_; - - std::unique_ptr log_file_; - - void PopulateAndSaveEntry(StreamRequest::Entry* r, - absl::Span wait_for, - int64_t handle_id, int64_t thread_id) { - for (auto event : wait_for) { - auto recording_event = static_cast(event); - r->add_wait_for_id(recording_event->id_); - } - r->set_operation_id(handle_id); - r->set_thread_id(thread_id); - - uint64_t data_size = r->ByteSizeLong(); - std::vector buffer; - buffer.resize(sizeof(data_size) + data_size); - memcpy(buffer.data(), &data_size, sizeof(data_size)); - r->SerializeToArray(buffer.data() + sizeof(data_size), data_size); - - { - if (log_file_ == nullptr) { - LOG(WARNING) << "The TPU driver has been shut down before all logging " - "has been written."; - return; - } - - absl::string_view buffer_sp(buffer.data(), buffer.size()); - auto data_status = log_file_->Append(buffer_sp); - if (!data_status.ok()) { - LOG(WARNING) << "Unable to write data to log file. File possibly " - "corrupt. Error: " - << data_status; - } - - if (flush_) { - auto flush_status = log_file_->Flush(); - if (!flush_status.ok()) { - LOG(WARNING) << "Unable to flush data to log file. File possibly " - "corrupt. Error: " - << flush_status; - } - - auto sync_status = log_file_->Sync(); - if (!sync_status.ok()) { - LOG(WARNING) << "Unable to sync log file. File possibly " - "corrupt. Error: " - << sync_status; - } - } - } - } - - std::vector UnwrapWaitFor(absl::Span wait_for) { - std::vector unwrapped_events; - for (auto event : wait_for) { - Event* unwrapped_event = - static_cast(event)->shared_event_.get(); - unwrapped_events.push_back(unwrapped_event); - } - return unwrapped_events; - } - - int64_t GetCurrentThreadId() { return absl::base_internal::GetTID(); } -}; - -xla::StatusOr> RegisterRecordingTpuDriver( - const TpuDriverConfig& config) { - std::vector configs = absl::StrSplit(config.worker(), '|'); - - std::string file; - std::string worker; - bool flush = false; - - for (const auto& config : configs) { - std::vector kv = - absl::StrSplit(config, absl::MaxSplits('=', 1)); - if (kv[0] == "file") { - file = kv[1]; - } - if (kv[0] == "worker") { - worker = kv[1]; - } - if (kv[0] == "flush") { - if (kv[1] == "true" || kv[1] == "1") { - flush = true; - } - } - } - - TpuDriverConfig worker_config; - worker_config.set_worker(worker); - - auto driver_status = TpuDriverRegistry::Open(worker_config); - if (!driver_status.ok()) return driver_status.status(); - return std::unique_ptr( - new RecordingTpuDriver(std::move(driver_status).value(), file, flush)); -} - -// To record a sequence of operations, set the worker configuration string to -// record://|file=|worker=grpc://1.2.3.4:8470 (for GRPC). -REGISTER_TPU_DRIVER("record://", RegisterRecordingTpuDriver); - -} // namespace -} // namespace tpu_driver diff --git a/third_party/xla/xla/python/tpu_driver/tpu_driver.cc b/third_party/xla/xla/python/tpu_driver/tpu_driver.cc deleted file mode 100644 index dc064809c9ac50..00000000000000 --- a/third_party/xla/xla/python/tpu_driver/tpu_driver.cc +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright 2019 The OpenXLA Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -#include "xla/python/tpu_driver/tpu_driver.h" - -#include -#include -#include -#include - -#include "absl/strings/match.h" -#include "absl/synchronization/mutex.h" -#include "xla/util.h" - -namespace tpu_driver { - -namespace { - -typedef absl::flat_hash_map< - std::string, std::function>( - const TpuDriverConfig&)>> - DriverRegistryMap; - -DriverRegistryMap* GetDriverRegistryMap() { - static DriverRegistryMap* driver_registry = new DriverRegistryMap(); - return driver_registry; -} - -int64_t ByteSizeOfPrimitiveType(xla::PrimitiveType primitive_type) { - switch (primitive_type) { - case xla::PrimitiveType::PRED: - return sizeof(int8_t); - case xla::PrimitiveType::S8: - return sizeof(int8_t); - case xla::PrimitiveType::S16: - return sizeof(int16_t); - case xla::PrimitiveType::S32: - return sizeof(int32_t); - case xla::PrimitiveType::S64: - return sizeof(int64_t); - case xla::PrimitiveType::U8: - return sizeof(uint8_t); - case xla::PrimitiveType::U16: - return sizeof(uint16_t); - case xla::PrimitiveType::U32: - return sizeof(uint32_t); - case xla::PrimitiveType::U64: - return sizeof(uint64_t); - case xla::PrimitiveType::BF16: - return sizeof(float) / 2; - case xla::PrimitiveType::F16: - return sizeof(float) / 2; - case xla::PrimitiveType::F32: - return sizeof(float); - case xla::PrimitiveType::F64: - return sizeof(double); - case xla::PrimitiveType::C64: - return sizeof(std::complex); - case xla::PrimitiveType::C128: - return sizeof(std::complex); - case xla::PrimitiveType::TOKEN: - case xla::PrimitiveType::TUPLE: - case xla::PrimitiveType::OPAQUE_TYPE: - LOG(FATAL) << PrimitiveType_Name(primitive_type) - << " primitive type has no definitive size"; - default: - LOG(FATAL) << "Unhandled primitive type " << primitive_type; - } -} - -} // namespace - -/*static*/ int TpuDriverRegistry::RegisterDriver( - const std::string& prefix, - const std::function>( - const TpuDriverConfig&)>& creator) { - (*GetDriverRegistryMap())[prefix] = creator; - return GetDriverRegistryMap()->size(); -} - -/*static*/ xla::StatusOr> TpuDriverRegistry::Open( - const TpuDriverConfig& config) { - for (const auto& driver : *GetDriverRegistryMap()) { - if (absl::StartsWith(config.worker(), driver.first)) { - return driver.second(config); - } - } - return xla::NotFound("Unable to find driver in registry given worker: %s", - config.worker()); -} - -int64_t ComputeBytesFromShape(const xla::ShapeProto& shape) { - if (shape.tuple_shapes_size() > 0) { - LOG(FATAL) << "Tuples are not supported at the moment."; - } - - int64_t num_elems = 1; - for (auto dim : shape.dimensions()) { - num_elems *= dim; - } - - return ByteSizeOfPrimitiveType(shape.element_type()) * num_elems; -} - -} // namespace tpu_driver diff --git a/third_party/xla/xla/python/tpu_driver/tpu_driver.h b/third_party/xla/xla/python/tpu_driver/tpu_driver.h deleted file mode 100644 index 695a2c0a884b7a..00000000000000 --- a/third_party/xla/xla/python/tpu_driver/tpu_driver.h +++ /dev/null @@ -1,255 +0,0 @@ -// Copyright 2019 The OpenXLA Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================== - -#ifndef XLA_PYTHON_TPU_DRIVER_TPU_DRIVER_H_ -#define XLA_PYTHON_TPU_DRIVER_TPU_DRIVER_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/inlined_vector.h" -#include "absl/synchronization/mutex.h" -#include "absl/types/span.h" -#include "xla/python/tpu_driver/platform/external/compat.h" -#include "xla/python/tpu_driver/tpu_driver.pb.h" -#include "xla/service/hlo.pb.h" -#include "xla/status.h" -#include "xla/statusor.h" -#include "xla/xla.pb.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/logging.h" - -// This API is EXPERIMENTAL and under active development. It is subject to -// change without notice. - -namespace tpu_driver { - -int64_t ComputeBytesFromShape(const xla::ShapeProto& shape); - -// Represents the deferred completion of a scheduled operation. -// -// Events may be blocked on, or used as `wait_for` arguments to enforce -// inter-operation dependencies. -class Event { - public: - virtual ~Event() = default; - - // Blocks until the event completes and returns the result status. - virtual xla::Status Await() = 0; - // Returns an empty result if the wait times out. - virtual std::optional AwaitWithTimeout( - absl::Duration duration) = 0; - - // If the event is already done, the callback is called immediately. - virtual void AddCallback(std::function callback) = 0; -}; - -// Represents a device memory allocation. -class BufferHandle { - public: - virtual ~BufferHandle() = default; - - // This event completes after the device memory is actually allocated. - // - // Methods that take a buffer handle, such as ExecuteProgram and Transfer*, - // automatically add this event as a dependency. - virtual std::shared_ptr OnReady() = 0; - - virtual int64_t size_in_bytes() = 0; - virtual std::optional shape() = 0; -}; - -// Represents a compiled program on the host. -class CompiledProgramHandle { - public: - virtual ~CompiledProgramHandle() = default; - - // This Event completes after the program is actually compiled on the host. - // - // Methods that take a compiled program handle, including LoadProgram, - // automatically add this event as a dependency. - virtual std::shared_ptr OnReady() = 0; - - virtual int64_t size_in_bytes() { - LOG(FATAL) << "Unimplemented."; - return 0; - } - - // Returns the shape of the compiled program. Blocks until compile completes. - virtual xla::Status program_shape(xla::ProgramShapeProto* program_shape) = 0; -}; - -// Represents a program loaded on the device. -class LoadedProgramHandle { - public: - virtual ~LoadedProgramHandle() = default; - - // This Event completes after the program is actually loaded on the device. - // - // Methods that take a loaded program handle, including ExecuteProgram and - // UnloadProgram, automatically add this event as a dependency. - virtual std::shared_ptr OnReady() = 0; - - virtual int64_t size_in_bytes() { - LOG(FATAL) << "Unimplemented."; - return 0; - } -}; - -// A TpuLinearizer manages the linearization and delinearization of user buffers -// in the TPU driver. This interface is not yet implemented. -class TpuLinearizer { - public: - virtual ~TpuLinearizer() = default; - - int64_t ComputeBytesFromShape(const xla::ShapeProto& shape) { - return ::tpu_driver::ComputeBytesFromShape(shape); - } - virtual int64_t ComputeLinearizedBytesFromShape( - const xla::ShapeProto& shape) = 0; - - virtual xla::Status LinearizeShape(void* dst, const void* src, - const xla::ShapeProto& shape) = 0; - virtual xla::Status DelinearizeShape(void* dst, const void* src, - const xla::ShapeProto& shape) = 0; -}; - -// A TpuDriver manages a set of operations scheduled to run on a TPU system. -// -// By default, two independently scheduled operations may execute in any order. -// Ordering can be imposed in one of two ways: -// -// 1. Users can specify event dependencies via the `wait_for` argument. -// 2. Operations using buffer or program handles implicitly wait for the handles -// to become ready before executing. -// -// For returned handle objects, the user is responsible for calling the release -// methods (Deallocate, UnloadProgram, etc.) that consume the given unique_ptr -// arguments and free up device resources. For returned event objects, there is -// no release method; the user can let them go out of scope naturally. As soon -// as those methods accepting plain-pointer arguments return, the user can let -// the corresponding smart-pointer objects be released or go out of scope, -// regardless of whether the scheduled device operations have started execution. -class TpuDriver { - public: - virtual ~TpuDriver() = default; - - virtual void QuerySystemInfo(SystemInfo* system_info) = 0; - // Synchronous. Reset the state of the TPU driver. After Reset(), this TPU - // driver object is no longer usable. Users must destroy this object and - // create a new one. - // - // All running programs will be terminated and all allocations reset. All - // events and buffer handles created prior to Reset() will be invalid, and any - // use will result in undefined behavior. - virtual xla::Status Reset() = 0; - - virtual std::unique_ptr Allocate( - int32_t core_id, MemoryRegion region, int64_t num_bytes, - absl::Span wait_for) = 0; - virtual std::unique_ptr Allocate( - int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape, - absl::Span wait_for) = 0; - - // Allocate a buffer representing a tuple of `children` buffers. - // - // The returned tuple buffer handle does not manage the memory of `children`: - // all `children` buffer handles must outlive the last usage of this tuple - // buffer handle. One way to guarantee that is to deallocate the tuple buffer - // handle before deallocating any buffer handle in `children`. - // - // All `children` buffers must exist in the same `core_id` and `region`. - // If `children` is empty, a zero-sized tuple will be allocated in `region`. - virtual std::unique_ptr AllocateTuple( - int32_t core_id, MemoryRegion region, - absl::Span children, - absl::Span wait_for) = 0; - virtual std::shared_ptr Deallocate( - std::unique_ptr handle, - absl::Span wait_for) = 0; - - /* For buffers declared with an xla::ShapeProto rather than a raw size, - * `src` must be laid out in consecutive row-major format for ingestion, and - * each element must take up the number of bytes specified by the type. - * - * For example, for a [3,3,3] tensor with a Float32 type, the memory layout - * would be as follows: - * - * [0,0,0], [0,0,1], [0,0,2], [0,1,0], [0,1,1], ..., [0,2,2], [1,0,0], ... - * [1,2,2], [2,0,0], ..., [2,2,2], - * - * and the entire buffer will be 108 bytes (27 elements x 4 bytes). - * - * See - * https://eli.thegreenplace.net/2015/memory-layout-of-multi-dimensional-arrays - * for a more detailed description. - * - * `TransferFromDevice` will write out the shape back in this order as well. - */ - virtual std::shared_ptr TransferToDevice( - const void* src, BufferHandle* dst, - absl::Span wait_for) = 0; - virtual std::shared_ptr TransferFromDevice( - const BufferHandle* src, void* dst, - absl::Span wait_for) = 0; - - virtual std::shared_ptr TransferFromDeviceToDevice( - const BufferHandle* src, BufferHandle* dst, - absl::Span wait_for) = 0; - - virtual std::unique_ptr CompileProgram( - const xla::HloProto& source, int32_t num_replicas, - absl::Span wait_for, - const xla::DebugOptions& debug_options) = 0; - virtual std::unique_ptr LoadProgram( - int32_t core_id, const CompiledProgramHandle* handle, - absl::Span wait_for) = 0; - virtual std::shared_ptr UnloadProgram( - std::unique_ptr handle, - absl::Span wait_for) = 0; - virtual std::shared_ptr ExecuteProgram( - LoadedProgramHandle* program, absl::Span inputs, - absl::Span outputs, - const xla::DeviceAssignmentProto& device_assignment, - absl::Span wait_for) = 0; - - virtual std::unique_ptr GetLinearizer() { return nullptr; } -}; - -class TpuDriverRegistry { - public: - static xla::StatusOr> Open( - const TpuDriverConfig& config); - static int RegisterDriver( - const std::string& prefix, - const std::function>( - const TpuDriverConfig&)>& creator); -}; - -#define REGISTER_TPU_DRIVER(prefix, fn) \ - REGISTER_TPU_DRIVER_HELPER(__COUNTER__, prefix, fn) -#define REGISTER_TPU_DRIVER_HELPER(ctr, prefix, fn) \ - static int register_tpu_driver_count_unused_##ctr = \ - ::tpu_driver::TpuDriverRegistry::RegisterDriver(prefix, fn); - -} // namespace tpu_driver - -#endif // XLA_PYTHON_TPU_DRIVER_TPU_DRIVER_H_ diff --git a/third_party/xla/xla/python/tpu_driver/tpu_driver.proto b/third_party/xla/xla/python/tpu_driver/tpu_driver.proto deleted file mode 100644 index 98caebd1d4cf18..00000000000000 --- a/third_party/xla/xla/python/tpu_driver/tpu_driver.proto +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2019 The OpenXLA Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================== - -syntax = "proto2"; - -package tpu_driver; - -enum MemoryRegion { - HBM = 1; -} - -message ChipCoordinate { - required int32 x = 1; - required int32 y = 2; - required int32 z = 3; -} - -message TpuCoreInfo { - required int32 id = 1; - optional int32 core_on_chip_index = 2; - optional int32 core_on_host_index = 3; - optional int64 hbm_bytes_available = 100; - optional int64 hbm_bytes_allocatable = 101; -} - -message TpuChipInfo { - repeated TpuCoreInfo core = 1; - optional int32 host_id = 2; - optional ChipCoordinate chip_coord = 3; -} - -message CpuInfo { - required int32 num_cpu_cores = 1; - required float cpu_load_average_1min = 2; - required int64 ram_bytes_total = 100; - required int64 ram_bytes_available = 101; -} - -message SystemInfo { - repeated TpuChipInfo tpu_chip = 1; - required CpuInfo cpu = 2; - repeated TpuCoreInfo local_core = 3; - optional int32 host_id = 4; - optional int32 host_count = 5; - optional int32 chip_count = 6; - optional int32 core_count = 7; -} - -message TpuDriverConfig { - optional string worker = 1; - - message GrpcConfig { - // Time in seconds before the initial connection to the server will timeout. - optional int64 connection_timeout_secs = 1 [default = 30]; - - // Time in seconds the server may be unresponsive before terminating the - // connection. - optional int64 keepalive_timeout_secs = 2 [default = 30]; - } - - optional GrpcConfig grpc = 2; -} diff --git a/third_party/xla/xla/python/tpu_driver/tpu_service.proto b/third_party/xla/xla/python/tpu_driver/tpu_service.proto deleted file mode 100644 index 3ce49583f7ae8b..00000000000000 --- a/third_party/xla/xla/python/tpu_driver/tpu_service.proto +++ /dev/null @@ -1,187 +0,0 @@ -// Copyright 2019 The OpenXLA Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================== - -syntax = "proto2"; - -package tpu_driver; - -import "xla/python/tpu_driver/tpu_driver.proto"; -import "xla/service/hlo.proto"; -import "xla/xla.proto"; -import "xla/xla_data.proto"; - -option optimize_for = SPEED; - -message StatusMessage { - required int32 code = 1; - optional string message = 2; -} - -message AllocateRequest { - required int32 core_id = 1; - required MemoryRegion region = 2; - oneof size { - int64 num_bytes = 3; - xla.ShapeProto shape = 4; - } -} - -message AllocateTupleRequest { - required int32 core_id = 1; - required MemoryRegion region = 2; - repeated int64 children = 3; -} - -message DeallocateRequest { - required int64 handle = 1; -} - -message TransferToDeviceRequest { - required int64 target_handle = 1; - required bytes data = 2; -} - -message TransferFromDeviceRequest { - required int64 source_handle = 1; -} - -message TransferFromDeviceResponse { - required bytes data = 2; -} - -message TransferFromDeviceToDeviceRequest { - required int64 source_handle = 1; - required int64 target_handle = 2; -} - -message CompileRequest { - required xla.HloProto hlo_program = 1; - optional int64 num_replicas = 2; - optional xla.DebugOptions debug_options = 3; -} - -message CompiledProgramMetadata { - required xla.ProgramShapeProto program_shape = 1; -} - -message CompileResponse { - required CompiledProgramMetadata metadata = 1; -} - -message LoadProgramRequest { - required int32 core_id = 1; - required int64 compiled_program_handle = 2; -} - -message UnloadProgramRequest { - required int64 loaded_program_handle = 1; -} - -message ExecuteRequest { - required int64 loaded_program_handle = 1; - repeated int64 input_handle = 2; - repeated int64 output_handle = 3; - optional xla.DeviceAssignmentProto device_assignment = 4; -} - -message StreamRequest { - message Entry { - oneof request { - AllocateRequest alloc = 1; - AllocateTupleRequest alloc_tuple = 2; - DeallocateRequest dealloc = 3; - TransferToDeviceRequest transfer_to = 4; - TransferFromDeviceRequest transfer_from = 5; - TransferFromDeviceToDeviceRequest transfer_from_to = 10; - CompileRequest compile = 6; - LoadProgramRequest load = 7; - UnloadProgramRequest unload = 8; - ExecuteRequest execute = 9; - } - // If specified, a list of encoded EventId values. - repeated int64 wait_for_id = 20; - // A unique, encoded EventId value. - // For Allocate, Compile, and Load, this also defines the result handle. - required int64 operation_id = 21; - - // A unique identifier for the thread that issued this request. Currently - // for debugging purposes only. - optional int64 thread_id = 22; - } - - repeated Entry entry = 30; -} - -message StreamResponse { - message Entry { - oneof response { - TransferFromDeviceResponse transfer_from = 3; - CompileResponse compile = 4; - } - required StatusMessage status = 10; - // Echos the given encoded EventId value. - required int64 operation_id = 11; - } - - repeated Entry entry = 20; -} - -message OpenRequest { - // The version number for this client. Versions are bumped in case of - // backwards incompatible client-server protocol changes. Servers will reject - // clients with an unsupported version. - optional int32 client_version = 1 [default = 0]; -} - -message OpenResponse { - required uint32 client_id = 1; - - // Maximum time this client can be idle before it is GC'ed and all resources - // released. - optional int32 max_idle_time_seconds = 2 [default = 3600]; -} - -message CloseRequest { - required fixed32 client_id = 1; -} - -message CloseResponse {} - -message ResetRequest {} - -message ResetResponse {} - -message QuerySystemInfoRequest {} - -message QuerySystemInfoResponse { - required SystemInfo system_info = 1; -} - -service CloudTpuDriver { - // Open the driver. If the driver is already open, return an error. - rpc Open(OpenRequest) returns (OpenResponse); - - // Close the driver. Any outstanding requests will be terminated. - rpc Close(CloseRequest) returns (CloseResponse); - - // Reset the driver. All connected clients will be disconnected. - rpc Reset(ResetRequest) returns (ResetResponse); - - // Query the driver for current system performance information. - rpc QuerySystemInfo(QuerySystemInfoRequest) returns (QuerySystemInfoResponse); - - // Enqueue an operation to be executed when its dependencies are satisfied. - rpc StreamExecute(stream StreamRequest) returns (stream StreamResponse); -} diff --git a/third_party/xla/xla/python/traceback.cc b/third_party/xla/xla/python/traceback.cc index bed3c6a99d6c84..2990a24f2689ad 100644 --- a/third_party/xla/xla/python/traceback.cc +++ b/third_party/xla/xla/python/traceback.cc @@ -28,6 +28,13 @@ limitations under the License. #include "xla/pjrt/exceptions.h" #include "xla/python/python_ref_manager.h" #include "tsl/platform/logging.h" +#include "tsl/platform/platform.h" + +#ifdef PLATFORM_GOOGLE +#define Py_BUILD_CORE +#include "internal/pycore_frame.h" +#undef Py_BUILD_CORE +#endif // PLATFORM_GOOGLE namespace xla { @@ -54,7 +61,22 @@ Traceback::Traceback() { Py_INCREF(py_frame->f_code); frames_.emplace_back(py_frame->f_code, py_frame->f_lasti * kLastiWordBytes); } -#else // PY_VERSION_HEX < 0x030b0000 +#else // PY_VERSION_HEX < 0x030b0000 + +#ifdef PLATFORM_GOOGLE + // This code is equivalent to the version using public APIs, but it saves us + // an allocation of one object per stack frame. However, this is definitely + // violating the API contract of CPython, so we only use this where we can be + // confident we know exactly which CPython we are using (internal to Google). + // Feel free to turn this on if you like, but it might break at any time! + for (_PyInterpreterFrame* f = thread_state->cframe->current_frame; + f != nullptr; f = f->previous) { + if (_PyFrame_IsIncomplete(f)) continue; + Py_INCREF(f->f_code); + frames_.emplace_back(f->f_code, + _PyInterpreterFrame_LASTI(f) * sizeof(_Py_CODEUNIT)); + } +#else // PLATFORM_GOOGLE PyFrameObject* next; for (PyFrameObject* py_frame = PyThreadState_GetFrame(thread_state); py_frame != nullptr; py_frame = next) { @@ -62,6 +84,8 @@ Traceback::Traceback() { next = PyFrame_GetBack(py_frame); Py_XDECREF(py_frame); } +#endif // PLATFORM_GOOGLE + #endif // PY_VERSION_HEX < 0x030b0000 } diff --git a/third_party/xla/xla/python/transfer_guard_lib.cc b/third_party/xla/xla/python/transfer_guard_lib.cc index bebb2fe788e8b7..6f02c9210c7fc5 100644 --- a/third_party/xla/xla/python/transfer_guard_lib.cc +++ b/third_party/xla/xla/python/transfer_guard_lib.cc @@ -114,7 +114,7 @@ xla::Status ApplyTransferGuardToHostToDevice( return xla::InvalidArgument("Disallowed host-to-device transfer: %s", formatter()); } - return ::tsl::OkStatus(); + return absl::OkStatus(); } xla::Status ApplyTransferGuardToDeviceToDevice( @@ -129,7 +129,7 @@ xla::Status ApplyTransferGuardToDeviceToDevice( return xla::InvalidArgument("Disallowed device-to-device transfer: %s", formatter()); } - return ::tsl::OkStatus(); + return absl::OkStatus(); } xla::Status ApplyTransferGuardToDeviceToHost( @@ -144,7 +144,7 @@ xla::Status ApplyTransferGuardToDeviceToHost( return xla::InvalidArgument("Disallowed device-to-host transfer: %s", formatter()); } - return ::tsl::OkStatus(); + return absl::OkStatus(); } void BuildTransferGuardSubmodule(py::module& m) { diff --git a/third_party/xla/xla/python/types.cc b/third_party/xla/xla/python/types.cc index 96912da7db0139..3e90e1ee5d242c 100644 --- a/third_party/xla/xla/python/types.cc +++ b/third_party/xla/xla/python/types.cc @@ -71,7 +71,7 @@ const CustomDtypes& GetCustomDtypes() { } // namespace -xla::StatusOr DtypeToPrimitiveType(const py::dtype& np_type) { +absl::StatusOr DtypeToPrimitiveType(const py::dtype& np_type) { static auto& builtin_dtypes = *new absl::flat_hash_map, PrimitiveType>({ {{'?', 'b', 1}, PRED}, @@ -131,7 +131,7 @@ xla::StatusOr DtypeToPrimitiveType(const py::dtype& np_type) { np_type.char_(), np_type.kind(), np_type.itemsize()); } -xla::StatusOr PrimitiveTypeToDtype(PrimitiveType type) { +absl::StatusOr PrimitiveTypeToDtype(PrimitiveType type) { const CustomDtypes& custom_dtypes = GetCustomDtypes(); switch (type) { case PRED: @@ -184,7 +184,7 @@ xla::StatusOr PrimitiveTypeToDtype(PrimitiveType type) { } } -StatusOr IfrtDtypeToDtype(ifrt::DType dtype) { +absl::StatusOr IfrtDtypeToDtype(ifrt::DType dtype) { const CustomDtypes& custom_dtypes = GetCustomDtypes(); switch (dtype.kind()) { case ifrt::DType::kPred: @@ -318,7 +318,7 @@ const char* PEP3118FormatDescriptorForPrimitiveType(PrimitiveType type) { } } -StatusOr TypeDescriptorForPrimitiveType(PrimitiveType type) { +absl::StatusOr TypeDescriptorForPrimitiveType(PrimitiveType type) { #if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ #define ENDIAN_PREFIX "<" #else @@ -414,7 +414,8 @@ std::vector StridesForShape(PrimitiveType element_type, /*innermost_stride_size=*/1); } -StatusOr LiteralToPython(std::shared_ptr literal) { +absl::StatusOr LiteralToPython( + std::shared_ptr literal) { xla::Literal& m = *literal; if (m.shape().IsTuple()) { std::vector elems = m.DecomposeTuple(); @@ -440,7 +441,8 @@ StatusOr LiteralToPython(std::shared_ptr literal) { literal_object); } -StatusOr GetPythonBufferTree(const py::object& argument) { +absl::StatusOr GetPythonBufferTree( + const py::object& argument) { PythonBufferTree tree; if (py::isinstance(argument)) { py::tuple tuple = py::reinterpret_borrow(argument); diff --git a/third_party/xla/xla/python/types.h b/third_party/xla/xla/python/types.h index 8bdaa1dae877d5..9d6e5b094f7215 100644 --- a/third_party/xla/xla/python/types.h +++ b/third_party/xla/xla/python/types.h @@ -40,20 +40,22 @@ limitations under the License. namespace xla { // Converts a NumPy dtype to a PrimitiveType. -StatusOr DtypeToPrimitiveType(const pybind11::dtype& np_type); +absl::StatusOr DtypeToPrimitiveType( + const pybind11::dtype& np_type); // Converts a PrimitiveType to a Numpy dtype. -StatusOr PrimitiveTypeToDtype(PrimitiveType type); +absl::StatusOr PrimitiveTypeToDtype(PrimitiveType type); // Converts an IFRT dtype to a NumPy dtype. -StatusOr IfrtDtypeToDtype(ifrt::DType dtype); +absl::StatusOr IfrtDtypeToDtype(ifrt::DType dtype); // Returns a Python buffer protocol (PEP 3118) format descriptor string for // `type`. Return nullptr if there is no suitable choice of format string. const char* PEP3118FormatDescriptorForPrimitiveType(PrimitiveType type); // Returns a numpy-style typestr for `type`, as returned by np.dtype(...).str -StatusOr TypeDescriptorForPrimitiveType(PrimitiveType type); +absl::StatusOr TypeDescriptorForPrimitiveType( + PrimitiveType type); struct NumpyScalarTypes { pybind11::object np_bool; @@ -100,7 +102,8 @@ std::vector StridesForShape(PrimitiveType element_type, // buffers with the literals. Takes ownership of `literal` and keeps the // necessary pieces alive using Python reference counting. // Requires the GIL. -StatusOr LiteralToPython(std::shared_ptr literal); +absl::StatusOr LiteralToPython( + std::shared_ptr literal); // Converts a Python object into an XLA shape and a vector of leaf buffers. // The leaf buffers correspond to a depth-first, left-to-right traversal of @@ -113,7 +116,7 @@ struct PythonBufferTree { absl::InlinedVector leaves; Shape shape; }; -StatusOr GetPythonBufferTree( +absl::StatusOr GetPythonBufferTree( const pybind11::object& argument); // Converts a sequence of C++ ints to a Python tuple of ints. diff --git a/third_party/xla/xla/python/xla.cc b/third_party/xla/xla/python/xla.cc index 9c8e85a80d31c9..a35eee3b60436d 100644 --- a/third_party/xla/xla/python/xla.cc +++ b/third_party/xla/xla/python/xla.cc @@ -572,12 +572,26 @@ static void Init(py::module_& m) { xla::StatusOr pjrt_api = pjrt::PjrtApi(platform_name); return pjrt_api.ok(); }); - m.def("load_pjrt_plugin", - [](std::string platform_name, std::string library_path) -> py::capsule { + m.def( + "load_pjrt_plugin", + [](std::string platform_name, std::optional library_path, + std::optional c_api) -> py::capsule { + if (library_path.has_value()) { const PJRT_Api* api = xla::ValueOrThrow( - pjrt::LoadPjrtPlugin(platform_name, library_path)); + pjrt::LoadPjrtPlugin(platform_name, *library_path)); return py::capsule(absl::bit_cast(api), "pjrt_c_api"); - }); + } + if (absl::string_view(c_api->name()) != "pjrt_c_api") { + throw py::value_error( + "c_api argument to load_pjrt_plugin is not a pjrt_c_api " + "capsule."); + } + xla::ThrowIfError(pjrt::SetPjrtApi( + platform_name, static_cast(*c_api))); + return *c_api; + }, + py::arg("platform_name"), py::arg("library_path") = std::nullopt, + py::arg("c_api") = std::nullopt); m.def("pjrt_plugin_initialized", [](std::string platform_name) -> bool { return xla::ValueOrThrow(pjrt::IsPjrtPluginInitialized(platform_name)); }); @@ -810,7 +824,11 @@ static void Init(py::module_& m) { }, py::arg("dlpack"), py::arg("cpu_backend") = nullptr, py::arg("gpu_backend") = nullptr); - + m.def("cuda_array_interface_to_buffer", + [](const pybind11::dict& cai, std::shared_ptr cuda_client) { + return xla::ValueOrThrow( + CudaArrayInterfaceToBuffer(cai, std::move(cuda_client))); + }); BuildProfilerSubmodule(&m); BuildOpsSubmodule(&m); BuildOutfeedReceiverSubmodule(&m); @@ -910,6 +928,16 @@ static void Init(py::module_& m) { xla::ThrowIfError(client.KeyValueSet(key, value)); }, py::arg("key"), py::arg("value")) + // The key must be a string, but the value must a Python bytes object. + // Use `key_value_set_bytes()` and `blocking_key_value_get_bytes()`. + .def( + "key_value_set_bytes", + [](DistributedRuntimeClient& client, std::string key, + py::bytes value) { + py::gil_scoped_release gil_release; + xla::ThrowIfError(client.KeyValueSet(key, value)); + }, + py::arg("key"), py::arg("value")) // Assumes that all values in the directory are Python strings. .def( "key_value_dir_get", diff --git a/third_party/xla/xla/python/xla_client.py b/third_party/xla/xla/python/xla_client.py index c0e8e9c1f717a3..806f522733f12d 100644 --- a/third_party/xla/xla/python/xla_client.py +++ b/third_party/xla/xla/python/xla_client.py @@ -24,7 +24,7 @@ import logging import os import threading -from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, List, Mapping, Optional, Protocol, Sequence, Tuple, Union import ml_dtypes import numpy as np @@ -48,7 +48,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 234 +_version = 238 # Version number for MLIR:Python components. mlir_api_version = 55 @@ -145,7 +145,11 @@ def pjrt_plugin_loaded(plugin_name: str) -> bool: def load_pjrt_plugin_dynamically(plugin_name: str, library_path: str) -> Any: - return _xla.load_pjrt_plugin(plugin_name, library_path) + return _xla.load_pjrt_plugin(plugin_name, library_path, c_api=None) + + +def load_pjrt_plugin_with_c_api(plugin_name: str, c_api: Any) -> None: + return _xla.load_pjrt_plugin(plugin_name, None, c_api) def pjrt_plugin_initialized(plugin_name: str) -> bool: @@ -209,10 +213,13 @@ def generate_pjrt_gpu_plugin_options( options = {} if visible_devices != 'all': options['visible_devices'] = [int(x) for x in visible_devices.split(',')] - options['platform_name'] = 'cuda' + options['platform_name'] = 'cuda' allocator = os.getenv('XLA_PYTHON_CLIENT_ALLOCATOR', 'default').lower() memory_fraction = os.getenv('XLA_PYTHON_CLIENT_MEM_FRACTION', '') preallocate = os.getenv('XLA_PYTHON_CLIENT_PREALLOCATE', '') + collective_memory_size = os.getenv( + 'XLA_PYTHON_CLIENT_COLLECTIVE_MEM_SIZE_MB', '' + ) if allocator not in ('default', 'platform', 'bfc', 'cuda_async'): raise ValueError( 'XLA_PYTHON_CLIENT_ALLOCATOR env var must be "default", "platform", ' @@ -223,6 +230,8 @@ def generate_pjrt_gpu_plugin_options( options['memory_fraction'] = float(memory_fraction) if preallocate: options['preallocate'] = preallocate not in ('false', 'False', '0') + if collective_memory_size: + options['collective_memory_size'] = int(collective_memory_size) * (1 << 20) return options @@ -555,14 +564,22 @@ def LoadedExecutable_execute_with_token(self, arguments, device=None): LoadedExecutable.execute_with_token = LoadedExecutable_execute_with_token -_custom_callback_handler: dict[str, Any] = {} -# Key is xla_platform_name, value is (function_name, function) -_custom_callback: dict[str, list[Tuple[str, Any]]] = {} +class CustomCallHandler(Protocol): + + def __call__( + self, name: str, fn: Any, platform: str, /, api_version: int = ... + ) -> None: + ... + + +_custom_callback_handler: dict[str, CustomCallHandler] = {} +# Key is xla_platform_name, value is (function_name, function, api_version) +_custom_callback: dict[str, list[tuple[str, Any, int]]] = {} _custom_callback_lock = threading.Lock() def register_custom_call_target( - name: str, fn: Any, platform: str = 'cpu' + name: str, fn: Any, platform: str = 'cpu', api_version: int = 0 ) -> None: """Registers a custom call target. @@ -570,18 +587,26 @@ def register_custom_call_target( name: bytes containing the name of the function. fn: a PyCapsule object containing the function pointer. platform: the target platform. + api_version: the XLA FFI version to use. Supported versions are: 0 for the + untyped FFI and 1 for the typed FFI. """ # To support AMD GPUs, we need to have xla_platform_names["gpu"] == "ROCM" # Since that is hardcoded to CUDA, we are using the following as workaround. xla_platform_name = xla_platform_names.get(platform, platform) with _custom_callback_lock: if xla_platform_name in _custom_callback_handler: - _custom_callback_handler[xla_platform_name](name, fn, xla_platform_name) + _custom_callback_handler[xla_platform_name]( + name, fn, xla_platform_name, api_version + ) else: - _custom_callback.setdefault(xla_platform_name, []).append((name, fn)) + _custom_callback.setdefault(xla_platform_name, []).append( + (name, fn, api_version) + ) -def register_custom_call_handler(platform: str, handler: Any) -> None: +def register_custom_call_handler( + platform: str, handler: CustomCallHandler +) -> None: """Registers a custom handler and use it to register existing custom calls. If a custom call handler for the platform already exist, calling this method @@ -601,8 +626,8 @@ def register_custom_call_handler(platform: str, handler: Any) -> None: return _custom_callback_handler[xla_platform_name] = handler if xla_platform_name in _custom_callback: - for name, fn in _custom_callback[xla_platform_name]: - handler(name, fn, xla_platform_name) + for name, fn, api_version in _custom_callback[xla_platform_name]: + handler(name, fn, xla_platform_name, api_version) del _custom_callback[xla_platform_name] diff --git a/third_party/xla/xla/python/xla_client.pyi b/third_party/xla/xla/python/xla_client.pyi index f0625879d5dd87..758facc27f062f 100644 --- a/third_party/xla/xla/python/xla_client.pyi +++ b/third_party/xla/xla/python/xla_client.pyi @@ -122,6 +122,9 @@ def pjrt_plugin_loaded(plugin_name: str) -> bool: def load_pjrt_plugin_dynamically(plugin_name: str, library_path: str) -> Any: ... +def load_pjrt_plugin_with_c_api(plugin_name: str, c_api: Any) -> None: + ... + def pjrt_plugin_initialized(plugin_name: str) -> bool: ... @@ -235,7 +238,7 @@ def array_result_handler( ... def register_custom_call_target( - name: str, fn: Callable, platform: str = ... + name: str, fn: Callable, platform: str = ..., api_version: int = ... ) -> None: ... diff --git a/third_party/xla/xla/python/xla_client_test.py b/third_party/xla/xla/python/xla_client_test.py index 2926a8c7bf3927..995497d4cc4082 100644 --- a/third_party/xla/xla/python/xla_client_test.py +++ b/third_party/xla/xla/python/xla_client_test.py @@ -82,6 +82,8 @@ def jax_array_copy_to_host_async(self): # use widely for parameterizing tests. # pylint: disable=g-complex-comprehension +_CUSTOM_CALLS_REGISTERED = False + def TestFactory(xla_backend, cloud_tpu=False, @@ -109,6 +111,12 @@ def setUp(self): super(ComputationTest, self).setUp() self.backend = xla_backend() + global _CUSTOM_CALLS_REGISTERED + if self.backend.platform == "cpu" and not _CUSTOM_CALLS_REGISTERED: + for name, fn in custom_call_for_test.cpu_custom_call_targets.items(): + xla_client.register_custom_call_target(name, fn, platform="cpu") + _CUSTOM_CALLS_REGISTERED = True + def _NewComputation(self, name=None): if name is None: name = self.id() @@ -403,8 +411,6 @@ def testCustomCall(self): if self.backend.platform != "cpu": self.skipTest("Test requires cpu platform") c = self._NewComputation() - for name, fn in custom_call_for_test.cpu_custom_call_targets.items(): - xla_client.register_custom_call_target(name, fn, platform="cpu") ops.CustomCallWithLayout( c, b"test_subtract_f32", @@ -426,8 +432,6 @@ def testCustomCallWithUnifiedApi(self): if self.backend.platform != "cpu": self.skipTest("Test requires cpu platform") c = self._NewComputation() - for name, fn in custom_call_for_test.cpu_custom_call_targets.items(): - xla_client.register_custom_call_target(name, fn, platform="cpu") opaque_str = b"foo" ops.CustomCallWithLayout( diff --git a/third_party/xla/xla/python/xla_compiler.cc b/third_party/xla/xla/python/xla_compiler.cc index f7d0f03dd32ece..787c103030129d 100644 --- a/third_party/xla/xla/python/xla_compiler.cc +++ b/third_party/xla/xla/python/xla_compiler.cc @@ -23,7 +23,9 @@ limitations under the License. #include #include "absl/hash/hash.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" @@ -39,6 +41,9 @@ limitations under the License. #include "xla/client/xla_builder.h" #include "xla/client/xla_computation.h" #include "xla/debug_options_flags.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/ffi.h" +#include "xla/ffi/ffi_api.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" @@ -247,18 +252,32 @@ StatusOr IotaTileHelper( // 'fn_capsule' must be a void* pointer encapsulated in a PyCapsule object, // with name "xla._CUSTOM_CALL_TARGET". // 'platform' is an XLA platform name, e.g., "Host" or "CUDA". -Status PyRegisterCustomCallTarget(const std::string& fn_name, - py::capsule capsule, - const std::string& platform) { +absl::Status PyRegisterCustomCallTarget(const std::string& fn_name, + py::capsule capsule, + const std::string& platform, + int api_version) { static const char* const kName = "xla._CUSTOM_CALL_TARGET"; if (absl::string_view(capsule.name()) != kName) { return InvalidArgument( - "Argument to RegisterCustomCallTargetRegistry was not a " + "Argument to RegisterCustomCallTarget was not a " "xla._CUSTOM_CALL_TARGET capsule."); } - CustomCallTargetRegistry::Global()->Register( - fn_name, static_cast(capsule), platform); - return OkStatus(); + switch (api_version) { + case 0: + CustomCallTargetRegistry::Global()->Register( + fn_name, static_cast(capsule), platform); + return absl::OkStatus(); + case 1: + ffi::Ffi::RegisterStaticHandler( + xla::ffi::GetXlaFfiApi(), fn_name, platform, + reinterpret_cast(static_cast(capsule))); + return absl::OkStatus(); + default: + return absl::UnimplementedError(absl::StrFormat( + "API version %d is not supported by RegisterCustomCallTarget. " + "Supported versions are 0 and 1.", + api_version)); + } } template @@ -865,12 +884,15 @@ void BuildXlaCompilerSubmodule(py::module& m) { }); // Custom-call targets. - m.def("register_custom_call_target", - [](const std::string& fn_name, py::capsule capsule, - const std::string& platform) { - xla::ThrowIfError(PyRegisterCustomCallTarget( - fn_name, std::move(capsule), platform)); - }); + m.def( + "register_custom_call_target", + [](const std::string& fn_name, py::capsule capsule, + const std::string& platform, const int api_version) { + xla::ThrowIfError(PyRegisterCustomCallTarget( + fn_name, std::move(capsule), platform, api_version)); + }, + py::arg("fn_name"), py::arg("capsule"), py::arg("platform"), + py::arg("api_version") = 0); py::class_(m, "DebugOptions") .def("__repr__", &DebugOptions::DebugString) @@ -902,6 +924,16 @@ void BuildXlaCompilerSubmodule(py::module& m) { .def_property("xla_gpu_enable_fast_min_max", &DebugOptions::xla_gpu_enable_fast_min_max, &DebugOptions::set_xla_gpu_enable_fast_min_max) + .def_property("xla_gpu_dump_autotune_results_to", + &DebugOptions::xla_gpu_dump_autotune_results_to, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_dump_autotune_results_to(value); + }) + .def_property("xla_gpu_load_autotune_results_from", + &DebugOptions::xla_gpu_load_autotune_results_from, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_load_autotune_results_from(value); + }) .def_property("xla_gpu_cuda_data_dir", &DebugOptions::xla_gpu_cuda_data_dir, [](DebugOptions* self, std::string value) { @@ -1050,6 +1082,17 @@ void BuildXlaCompilerSubmodule(py::module& m) { "auto_spmd_partitioning_mesh_ids", &ExecutableBuildOptions::auto_spmd_partitioning_mesh_ids, &ExecutableBuildOptions::set_auto_spmd_partitioning_mesh_ids) + .def_property( + "allow_spmd_sharding_propagation_to_parameters", + [](const ExecutableBuildOptions& options) -> std::vector { + return std::vector( + options.allow_spmd_sharding_propagation_to_parameters().begin(), + options.allow_spmd_sharding_propagation_to_parameters().end()); + }, + [](ExecutableBuildOptions& options, std::vector values) { + absl::InlinedVector v(values.begin(), values.end()); + options.set_allow_spmd_sharding_propagation_to_parameters(v); + }) .def_property( "allow_spmd_sharding_propagation_to_output", [](const ExecutableBuildOptions& options) -> std::vector { diff --git a/third_party/xla/xla/python/xla_extension/__init__.pyi b/third_party/xla/xla/python/xla_extension/__init__.pyi index 3bad183072267f..a666143183a446 100644 --- a/third_party/xla/xla/python/xla_extension/__init__.pyi +++ b/third_party/xla/xla/python/xla_extension/__init__.pyi @@ -258,7 +258,7 @@ class CompileOptions: env_option_overrides: List[Tuple[str, str]] def register_custom_call_target( - fn_name: str, capsule: Any, platform: str + fn_name: str, capsule: Any, platform: str, api_version: int = ..., ) -> _Status: ... def register_custom_call_partitioner( name: str, @@ -308,6 +308,8 @@ class DebugOptions: xla_gpu_cuda_data_dir: str xla_detailed_logging: bool xla_enable_dumping: bool + xla_gpu_dump_autotune_results_to: str + xla_gpu_load_autotune_results_from: str class CompiledMemoryStats: generated_code_size_in_bytes: int @@ -562,7 +564,7 @@ def get_default_c_api_topology( options: Dict[str, Union[str, int, List[int], float]], ) -> DeviceTopology: ... def get_topology_for_devices(devices: List[Device]) -> DeviceTopology: ... -def load_pjrt_plugin(platform_name: str, library_path: str) -> _Status: ... +def load_pjrt_plugin(platform_name: str, library_path: Optional[str], c_api: Optional[Any]) -> _Status: ... def pjrt_plugin_loaded(plugin_name: str) -> bool: ... def pjrt_plugin_initialized(plugin_name: str) -> bool: ... def initialize_pjrt_plugin(platform_name: str) -> _Status: ... @@ -689,6 +691,16 @@ def dlpack_managed_tensor_to_buffer( tensor: Any, device: Device, stream: int | None ) -> ArrayImpl: ... +def cuda_array_interface_to_buffer( + cai: Dict[str, Union[ + str, int, None, + Tuple[int, ...], Tuple[int, bool], + List[Tuple[str, str]], + List[Tuple[str, str, Tuple[int, ...]]]] + ], + gpu_backend: Optional[Client] = ..., +) -> ArrayImpl: ... + # Legacy overload def dlpack_managed_tensor_to_buffer( tensor: Any, @@ -737,6 +749,7 @@ class DistributedRuntimeClient: def key_value_dir_get(self, key: str) -> _Status: ... def key_value_dir_get_bytes(self, key: str) -> _Status: ... def key_value_set(self, key: str, value: str) -> _Status: ... + def key_value_set_bytes(self, key: str, value: bytes) -> _Status: ... def key_value_delete(self, key: str) -> _Status: ... def wait_at_barrier(self, barrier_id: str, timeout_in_ms: int) -> _Status: ... diff --git a/third_party/xla/xla/python/xla_extension/pytree.pyi b/third_party/xla/xla/python/xla_extension/pytree.pyi index e021b6b972bef5..24421a857ca8dd 100644 --- a/third_party/xla/xla/python/xla_extension/pytree.pyi +++ b/third_party/xla/xla/python/xla_extension/pytree.pyi @@ -53,7 +53,6 @@ class PyTreeDef: def children(self) -> List[PyTreeDef]: ... @staticmethod def make_from_node_data_and_children( - self, registry: PyTreeRegistry, node_data: Optional[Tuple[Type, Any]], children: Iterable[PyTreeDef], @@ -71,7 +70,7 @@ class PyTreeDef: def serialize_using_proto(self) -> bytes: ... @staticmethod def deserialize_using_proto( - self, registry: PyTreeRegistry, data: bytes + registry: PyTreeRegistry, data: bytes ) -> PyTreeDef: ... diff --git a/third_party/xla/xla/python_api/BUILD b/third_party/xla/xla/python_api/BUILD index aaac07c9bcc78a..1cd1b6cdb834be 100644 --- a/third_party/xla/xla/python_api/BUILD +++ b/third_party/xla/xla/python_api/BUILD @@ -5,7 +5,6 @@ load("//xla:strict.default.bzl", "py_strict_library", "py_strict_test") load("//xla/tests:build_defs.bzl", "generate_backend_suites") package( - default_visibility = ["//visibility:public"], # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) diff --git a/third_party/xla/xla/runtime/BUILD b/third_party/xla/xla/runtime/BUILD index 324082c9f1cf69..b621f0354fcae2 100644 --- a/third_party/xla/xla/runtime/BUILD +++ b/third_party/xla/xla/runtime/BUILD @@ -1,10 +1,12 @@ load("//xla:xla.bzl", "xla_cc_test") +load("@local_tsl//tsl:tsl.bzl", "internal_visibility") load("@local_tsl//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") load("@local_tsl//tsl/platform:build_config.bzl", "if_llvm_system_z_available", "tf_platform_deps") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility(["//xla:internal"]), licenses = ["notice"], ) @@ -13,7 +15,6 @@ cc_library( srcs = ["arguments.cc"], hdrs = ["arguments.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":async_runtime", ":types", @@ -40,7 +41,6 @@ cc_library( srcs = ["async_runtime.cc"], hdrs = ["async_runtime.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/base:dynamic_annotations", "@local_tsl//tsl/concurrency:async_value", @@ -65,7 +65,6 @@ cc_library( name = "async_values_cache", hdrs = ["async_values_cache.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "@local_tsl//tsl/platform", ] + tf_platform_deps( @@ -79,7 +78,6 @@ cc_library( srcs = ["constraints.cc"], hdrs = ["constraints.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -91,7 +89,6 @@ cc_library( filegroup( name = "aot_ffi_execution_context_hdrs", srcs = ["aot_ffi_execution_context.h"], - visibility = ["//visibility:public"], ) cc_library( @@ -124,7 +121,6 @@ cc_library( srcs = ["custom_call.cc"], hdrs = ["custom_call.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":async_runtime", ":diagnostics", @@ -173,7 +169,6 @@ cc_library( srcs = ["custom_call_registry.cc"], hdrs = ["custom_call_registry.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":custom_call", "@llvm-project//llvm:Support", @@ -185,7 +180,6 @@ cc_library( srcs = ["diagnostics.cc"], hdrs = ["diagnostics.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":logical_result", "@com_google_absl//absl/status", @@ -209,7 +203,6 @@ cc_library( name = "errors", hdrs = ["errors.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", @@ -221,7 +214,6 @@ cc_library( srcs = ["executable.cc"], hdrs = ["executable.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":arguments", ":async_runtime", @@ -276,7 +268,6 @@ cc_library( srcs = ["execution_engine.cc"], hdrs = ["execution_engine.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":errors", "@com_google_absl//absl/log:check", @@ -317,7 +308,6 @@ cc_library( srcs = ["jit_executable.cc"], hdrs = ["jit_executable.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":async_values_cache", ":constraints", @@ -337,7 +327,6 @@ cc_library( name = "logical_result", hdrs = ["logical_result.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = ["@llvm-project//mlir:Support"], ) @@ -345,7 +334,6 @@ cc_library( name = "map_by_type", hdrs = ["map_by_type.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":type_id", "@llvm-project//llvm:Support", @@ -368,7 +356,6 @@ cc_library( srcs = ["memory_mapper.cc"], hdrs = ["memory_mapper.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "@llvm-project//llvm:ExecutionEngine", "@llvm-project//llvm:Support", @@ -383,7 +370,6 @@ cc_library( name = "memref_view", hdrs = ["memref_view.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "//xla:xla_data_proto_cc", "@com_google_absl//absl/types:span", @@ -394,7 +380,6 @@ cc_library( name = "module", hdrs = ["module.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":custom_call_registry", "@com_google_absl//absl/status", @@ -407,7 +392,6 @@ cc_library( srcs = ["module_registry.cc"], hdrs = ["module_registry.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":module", ], @@ -428,7 +412,6 @@ cc_library( name = "results", hdrs = ["results.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":logical_result", ":types", @@ -452,14 +435,12 @@ cc_library( name = "runtime", hdrs = ["runtime.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], ) cc_library( name = "state", hdrs = ["state.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -483,7 +464,6 @@ cc_library( srcs = ["symbolic_shape.cc"], hdrs = ["symbolic_shape.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":arguments", ":constraints", @@ -515,7 +495,6 @@ cc_library( srcs = ["types.cc"], hdrs = ["types.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla:xla_data_proto_cc", @@ -530,7 +509,6 @@ cc_library( name = "tracing", hdrs = ["tracing.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":custom_call", ":type_id", @@ -542,7 +520,6 @@ cc_library( srcs = ["type_id.cc"], hdrs = ["type_id.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/container:flat_hash_map", "@llvm-project//mlir:Support", @@ -553,13 +530,11 @@ cc_library( name = "compiler", hdrs = ["compiler.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], ) cc_library( name = "cpu_event", hdrs = ["cpu_event.h"], - visibility = ["//visibility:public"], ) xla_cc_test( diff --git a/third_party/xla/xla/runtime/default/BUILD b/third_party/xla/xla/runtime/default/BUILD index 6607072a8d40f7..0cce6f64235abb 100644 --- a/third_party/xla/xla/runtime/default/BUILD +++ b/third_party/xla/xla/runtime/default/BUILD @@ -2,7 +2,8 @@ load("@local_tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//xla/runtime:__pkg__"], licenses = ["notice"], ) @@ -10,7 +11,6 @@ cc_library( name = "async_values_cache", hdrs = ["async_values_cache.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/synchronization", "@llvm-project//llvm:Support", @@ -22,6 +22,5 @@ cc_library( name = "memory_mapper", hdrs = ["memory_mapper.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [], ) diff --git a/third_party/xla/xla/runtime/ffi/BUILD b/third_party/xla/xla/runtime/ffi/BUILD index c7b5092e88e026..f971eadcf16130 100644 --- a/third_party/xla/xla/runtime/ffi/BUILD +++ b/third_party/xla/xla/runtime/ffi/BUILD @@ -2,6 +2,7 @@ load("@local_tsl//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portab load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], ) @@ -12,21 +13,18 @@ filegroup( "ffi_api.h", "ffi_c_api.h", ], - visibility = ["//visibility:public"], ) cc_library( name = "ffi_abi", hdrs = ["ffi_abi.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], ) cc_library( name = "ffi_api", hdrs = ["ffi_api.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":ffi_abi", ":ffi_c_api_hdrs", @@ -37,5 +35,4 @@ cc_library( name = "ffi_c_api_hdrs", hdrs = ["ffi_c_api.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], ) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 2ecaefe817d033..fed47da7fdbf99 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -20,7 +20,7 @@ load( "if_rocm", "if_rocm_is_configured", ) -load("@local_tsl//tsl:tsl.bzl", "if_google", "if_libtpu") +load("@local_tsl//tsl:tsl.bzl", "if_google", "if_libtpu", "internal_visibility") load("@local_tsl//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable", "internal_hlo_deps") load( "@local_tsl//tsl/platform:build_config.bzl", @@ -34,7 +34,8 @@ load( ) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) @@ -65,7 +66,6 @@ tf_proto_library( name = "hlo_profile_printer_data", srcs = ["hlo_profile_printer_data.proto"], cc_api_version = 2, - visibility = ["//visibility:public"], ) tf_proto_library( @@ -73,20 +73,17 @@ tf_proto_library( srcs = ["hlo_execution_profile_data.proto"], cc_api_version = 2, protodeps = [":hlo_profile_printer_data"], - visibility = ["//visibility:public"], ) tf_proto_library( name = "metrics_proto", srcs = ["metrics.proto"], cc_api_version = 2, - visibility = ["//visibility:public"], ) xla_py_proto_library( name = "metrics_pb2", api_version = 2, - visibility = ["//visibility:public"], deps = [":metrics_proto"], ) @@ -97,14 +94,12 @@ filegroup( "**/*.cc", "**/*.h", ]), - visibility = ["//visibility:public"], ) cc_library( name = "collective_opt_utils", srcs = ["collective_opt_utils.cc"], hdrs = ["collective_opt_utils.h"], - visibility = ["//visibility:public"], deps = [ "//xla/hlo/ir:hlo", "@com_google_absl//absl/algorithm:container", @@ -115,11 +110,11 @@ cc_library( name = "async_collective_creator", srcs = ["async_collective_creator.cc"], hdrs = ["async_collective_creator.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", ":shape_inference", "//xla:frontend_attributes", + "//xla:util", "//xla/hlo/ir:hlo", "@com_google_absl//absl/container:flat_hash_map", "@local_tsl//tsl/platform:errors", @@ -145,7 +140,6 @@ cc_library( name = "all_reduce_key", srcs = ["all_reduce_key.cc"], hdrs = ["all_reduce_key.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_domain_map", "//xla/hlo/ir:hlo", @@ -156,7 +150,6 @@ cc_library( name = "all_reduce_promotion", srcs = ["all_reduce_promotion.cc"], hdrs = ["all_reduce_promotion.h"], - visibility = ["//visibility:public"], deps = [":change_op_data_type"], ) @@ -176,7 +169,6 @@ cc_library( name = "all_reduce_reassociate", srcs = ["all_reduce_reassociate.cc"], hdrs = ["all_reduce_reassociate.h"], - visibility = ["//visibility:public"], deps = [ ":all_reduce_key", ":collective_ops_utils", @@ -212,7 +204,6 @@ cc_library( name = "all_reduce_folder", srcs = ["all_reduce_folder.cc"], hdrs = ["all_reduce_folder.h"], - visibility = ["//visibility:public"], deps = [ ":all_reduce_key", ":hlo_pass", @@ -242,7 +233,6 @@ cc_library( name = "float_support", srcs = ["float_support.cc"], hdrs = ["float_support.h"], - visibility = ["//visibility:public"], deps = [ "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", @@ -253,7 +243,6 @@ cc_library( name = "broadcast_canonicalizer", srcs = ["broadcast_canonicalizer.cc"], hdrs = ["broadcast_canonicalizer.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_creation_utils", ":hlo_pass", @@ -277,7 +266,6 @@ cc_library( name = "bfloat16_conversion_folding", srcs = ["bfloat16_conversion_folding.cc"], hdrs = ["bfloat16_conversion_folding.h"], - visibility = ["//visibility:public"], deps = [ ":float_support", ":hlo_dataflow_analysis", @@ -310,8 +298,8 @@ cc_library( name = "float_normalization", srcs = ["float_normalization.cc"], hdrs = ["float_normalization.h"], - visibility = ["//visibility:public"], deps = [ + ":call_graph", ":float_support", ":hlo_dce", ":hlo_pass", @@ -350,7 +338,6 @@ cc_library( name = "bfloat16_propagation", srcs = ["bfloat16_propagation.cc"], hdrs = ["bfloat16_propagation.h"], - visibility = ["//visibility:public"], deps = [ ":float_support", ":hlo_dataflow_analysis", @@ -391,7 +378,6 @@ cc_library( name = "collective_permute_decomposer", srcs = ["collective_permute_decomposer.cc"], hdrs = ["collective_permute_decomposer.h"], - visibility = ["//visibility:public"], deps = [ ":collective_ops_utils", ":hlo_pass", @@ -421,7 +407,6 @@ cc_library( name = "constant_value", srcs = ["constant_value.cc"], hdrs = ["constant_value.h"], - visibility = ["//visibility:public"], deps = [ "//xla:literal", "//xla:statusor", @@ -443,7 +428,6 @@ cc_library( name = "convert_async_collectives_to_sync", srcs = ["convert_async_collectives_to_sync.cc"], hdrs = ["convert_async_collectives_to_sync.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", "//xla:util", @@ -460,7 +444,6 @@ cc_library( name = "value_range", srcs = ["value_range.cc"], hdrs = ["value_range.h"], - visibility = ["//visibility:public"], deps = [ ":constant_value", "//xla/hlo/ir:hlo", @@ -498,7 +481,6 @@ cc_library( name = "collective_pipeliner", srcs = ["collective_pipeliner.cc"], hdrs = ["collective_pipeliner.h"], - visibility = ["//visibility:public"], deps = [ ":constant_value", ":hlo_dce", @@ -546,7 +528,6 @@ cc_library( name = "dump", srcs = ["dump.cc"], hdrs = ["dump.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_graph_dumper", ":hlo_proto_util", @@ -575,7 +556,6 @@ cc_library( name = "shape_inference", srcs = ["shape_inference.cc"], hdrs = ["shape_inference.h"], - visibility = ["//visibility:public"], deps = [ "//xla:permutation_util", "//xla:shape_util", @@ -646,7 +626,6 @@ cc_library( hdrs = [ "sharding_propagation.h", ], - visibility = ["//visibility:public"], deps = [ ":call_graph", ":custom_call_sharding_helper", @@ -706,7 +685,6 @@ cc_library( hdrs = [ "sharding_remover.h", ], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", "//xla:statusor", @@ -742,7 +720,6 @@ cc_library( hdrs = [ "dot_as_convolution_util.h", ], - visibility = ["//visibility:public"], deps = [ ":shape_inference", "//xla:status_macros", @@ -801,7 +778,6 @@ xla_cc_test( cc_library( name = "pattern_matcher", hdrs = ["pattern_matcher.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_parser", "//xla:shape_util", @@ -835,7 +811,6 @@ cc_library( name = "pattern_matcher_gmock", testonly = 1, hdrs = ["pattern_matcher_gmock.h"], - visibility = ["//visibility:public"], deps = [ ":pattern_matcher", "//xla:test", @@ -873,7 +848,7 @@ xla_cc_test( xla_cc_test( name = "hlo_instruction_test", srcs = ["hlo_instruction_test.cc"], - tags = ["no_arm64"], + tags = ["no_aarch64"], deps = [ "//xla:literal", "//xla:protobuf_util", @@ -911,7 +886,6 @@ cc_library( name = "call_graph", srcs = ["call_graph.cc"], hdrs = ["call_graph.h"], - visibility = ["//visibility:public"], deps = [ "//xla:util", "//xla/hlo/ir:hlo", @@ -948,7 +922,6 @@ cc_library( name = "flatten_call_graph", srcs = ["flatten_call_graph.cc"], hdrs = ["flatten_call_graph.h"], - visibility = ["//visibility:public"], deps = [ ":call_graph", ":hlo_pass", @@ -963,7 +936,6 @@ cc_library( name = "call_inliner", srcs = ["call_inliner.cc"], hdrs = ["call_inliner.h"], - visibility = ["//visibility:public"], deps = [ ":call_graph", ":hlo_dce", @@ -1000,7 +972,6 @@ cc_library( name = "hlo_computation_deduplicator", srcs = ["hlo_computation_deduplicator.cc"], hdrs = ["hlo_computation_deduplicator.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", "//xla/hlo/ir:hlo", @@ -1052,7 +1023,6 @@ cc_library( name = "platform_util", srcs = ["platform_util.cc"], hdrs = ["platform_util.h"], - visibility = ["//visibility:public"], deps = [ ":compiler", "//xla:debug_options_flags", @@ -1061,6 +1031,7 @@ cc_library( "//xla:types", "//xla:util", "//xla/stream_executor", + "//xla/stream_executor:platform_manager", "//xla/stream_executor/cuda:cuda_platform_id", "//xla/stream_executor/host:host_platform_id", "//xla/stream_executor/rocm:rocm_platform_id", @@ -1074,7 +1045,6 @@ cc_library( name = "backend", srcs = ["backend.cc"], hdrs = ["backend.h"], - visibility = ["//visibility:public"], deps = [ ":compiler", ":computation_placer", @@ -1100,7 +1070,6 @@ cc_library( srcs = ["service.cc"], hdrs = ["service.h"], local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - visibility = ["//visibility:public"], deps = [ ":allocation_tracker", ":backend", @@ -1156,7 +1125,6 @@ cc_library( name = "local_service", srcs = ["local_service.cc"], hdrs = ["local_service.h"], - visibility = ["//visibility:public"], deps = [ ":backend", ":compiler", @@ -1194,7 +1162,6 @@ cc_library( name = "local_service_utils", srcs = ["local_service_utils.cc"], hdrs = ["local_service_utils.h"], - visibility = ["//visibility:public"], deps = [ ":backend", ":hlo_module_config", @@ -1218,7 +1185,6 @@ cc_library( name = "latency_hiding_scheduler", srcs = ["latency_hiding_scheduler.cc"], hdrs = ["latency_hiding_scheduler.h"], - visibility = ["//visibility:public"], deps = [ ":dump", ":hlo_alias_analysis", @@ -1259,7 +1225,6 @@ cc_library( name = "p2p_schedule_preparation", srcs = ["p2p_schedule_preparation.cc"], hdrs = ["p2p_schedule_preparation.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", "//xla:status", @@ -1299,7 +1264,6 @@ cc_library( name = "profile_guided_latency_estimator", srcs = ["profile_guided_latency_estimator.cc"], hdrs = ["profile_guided_latency_estimator.h"], - visibility = ["//visibility:public"], deps = [ ":latency_hiding_scheduler", "//xla/hlo/ir:hlo", @@ -1329,7 +1293,6 @@ cc_library( name = "compile_only_service", srcs = ["compile_only_service.cc"], hdrs = ["compile_only_service.h"], - visibility = ["//visibility:public"], deps = [ ":backend", ":compiler", @@ -1352,7 +1315,6 @@ cc_library( cc_library( name = "cpu_plugin", compatible_with = [], - visibility = ["//visibility:public"], deps = [ ":service", "//xla/service/cpu:cpu_compiler", @@ -1365,7 +1327,6 @@ cc_library( cc_library( name = "gpu_plugin_impl", compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "//xla/stream_executor", ] + if_gpu_is_configured([ @@ -1383,24 +1344,20 @@ cc_library( cc_library( name = "gpu_plugin_stub", - visibility = ["//visibility:public"], ) alias( name = "gpu_plugin_noncuda", actual = if_libtpu("gpu_plugin_stub", "gpu_plugin_impl"), - visibility = ["//visibility:public"], ) alias( name = "gpu_plugin", actual = if_cuda("gpu_plugin_impl", "gpu_plugin_noncuda"), - visibility = ["//visibility:public"], ) cc_library( name = "interpreter_plugin", - visibility = ["//visibility:public"], deps = [ ":service", "//xla/backends/interpreter:compiler", @@ -1456,7 +1413,6 @@ cc_library( "executable.h", "service_executable_run_options.h", ], - visibility = ["//visibility:public"], deps = [ ":computation_layout", ":dump", @@ -1497,7 +1453,6 @@ cc_library( name = "compiler", srcs = ["compiler.cc"], hdrs = ["compiler.h"], - visibility = ["//visibility:public"], deps = [ ":buffer_assignment", ":buffer_value", @@ -1552,7 +1507,6 @@ cc_library( name = "llvm_compiler", srcs = ["llvm_compiler.cc"], hdrs = ["llvm_compiler.h"], - visibility = ["//visibility:public"], deps = [ ":compiler", "@llvm-project//llvm:Core", @@ -1564,27 +1518,29 @@ cc_library( name = "transfer_manager", srcs = ["transfer_manager.cc"], hdrs = ["transfer_manager.h"], - visibility = ["//visibility:public"], deps = [ ":compiler", - ":executable", ":maybe_owning_device_memory", ":shaped_buffer", "//xla:literal", + "//xla:shape_tree", "//xla:shape_util", + "//xla:status", "//xla:status_macros", "//xla:statusor", - "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", "//xla/stream_executor", "//xla/stream_executor:device_memory", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:notification", + "@local_tsl//tsl/platform:statusor", ], ) @@ -1592,7 +1548,6 @@ cc_library( name = "allocation_tracker", srcs = ["allocation_tracker.cc"], hdrs = ["allocation_tracker.h"], - visibility = ["//visibility:public"], deps = [ ":backend", ":transfer_manager", @@ -1614,7 +1569,6 @@ cc_library( name = "execution_tracker", srcs = ["execution_tracker.cc"], hdrs = ["execution_tracker.h"], - visibility = ["//visibility:public"], deps = [ ":backend", ":stream_pool", @@ -1631,7 +1585,6 @@ cc_library( name = "channel_tracker", srcs = ["channel_tracker.cc"], hdrs = ["channel_tracker.h"], - visibility = ["//visibility:public"], deps = [ "//xla:statusor", "//xla:util", @@ -1643,7 +1596,6 @@ cc_library( name = "name_uniquer", srcs = ["name_uniquer.cc"], hdrs = ["name_uniquer.h"], - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla:types", @@ -1672,7 +1624,6 @@ cc_library( hdrs = [ "buffer_assignment.h", ], - visibility = ["//visibility:public"], deps = [ ":buffer_assignment_proto_cc", ":buffer_value_containers", @@ -1743,7 +1694,6 @@ cc_library( name = "hlo_ordering", srcs = ["hlo_ordering.cc"], hdrs = ["hlo_ordering.h"], - visibility = ["//visibility:public"], deps = [ ":call_graph", ":hlo_dataflow_analysis", @@ -1802,7 +1752,6 @@ cc_library( name = "hlo_module_group_metadata", srcs = ["hlo_module_group_metadata.cc"], hdrs = ["hlo_module_group_metadata.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_alias_analysis", ":tuple_points_to_analysis", @@ -1824,7 +1773,6 @@ cc_library( name = "hlo_module_util", srcs = ["hlo_module_util.cc"], hdrs = ["hlo_module_util.h"], - visibility = ["//visibility:public"], deps = [ ":compiler", ":hlo_module_config", @@ -1839,7 +1787,6 @@ cc_library( name = "hlo_module_group_util", srcs = ["hlo_module_group_util.cc"], hdrs = ["hlo_module_group_util.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_module_group_metadata", "//xla:status", @@ -1906,7 +1853,6 @@ cc_library( name = "hlo_memory_scheduler", srcs = ["hlo_memory_scheduler.cc"], hdrs = ["hlo_memory_scheduler.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_alias_analysis", ":hlo_pass", @@ -1950,7 +1896,6 @@ xla_cc_test( cc_library( name = "fusion_queue", hdrs = ["fusion_queue.h"], - visibility = ["//visibility:public"], deps = [ "//xla/hlo/ir:hlo", "@com_google_absl//absl/strings", @@ -1961,7 +1906,6 @@ cc_library( name = "instruction_fusion", srcs = ["instruction_fusion.cc"], hdrs = ["instruction_fusion.h"], - visibility = ["//visibility:public"], deps = [ ":fusion_queue", ":hlo_dataflow_analysis", @@ -1996,7 +1940,6 @@ cc_library( name = "multi_output_fusion", srcs = ["multi_output_fusion.cc"], hdrs = ["multi_output_fusion.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_dataflow_analysis", ":hlo_dce", @@ -2019,7 +1962,6 @@ cc_library( hdrs = [ "hlo_creation_utils.h", ], - visibility = ["//visibility:public"], deps = [ ":hlo_module_config", ":shape_inference", @@ -2046,7 +1988,6 @@ cc_library( name = "fusion_node_indexing_evaluation", srcs = ["fusion_node_indexing_evaluation.cc"], hdrs = ["fusion_node_indexing_evaluation.h"], - visibility = ["//visibility:public"], deps = [ ":elemental_ir_emitter", "//xla:types", @@ -2094,7 +2035,6 @@ cc_library( name = "batchnorm_expander", srcs = ["batchnorm_expander.cc"], hdrs = ["batchnorm_expander.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", "//xla:literal", @@ -2119,7 +2059,6 @@ cc_library( name = "op_expander_pass", srcs = ["op_expander_pass.cc"], hdrs = ["op_expander_pass.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_creation_utils", ":hlo_pass", @@ -2134,7 +2073,6 @@ cc_library( name = "gather_expander", srcs = ["gather_expander.cc"], hdrs = ["gather_expander.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_creation_utils", ":op_expander_pass", @@ -2151,7 +2089,6 @@ cc_library( name = "optimization_barrier_expander", srcs = ["optimization_barrier_expander.cc"], hdrs = ["optimization_barrier_expander.h"], - visibility = ["//visibility:public"], deps = [ ":op_expander_pass", ], @@ -2161,14 +2098,18 @@ cc_library( name = "comparison_expander", srcs = ["comparison_expander.cc"], hdrs = ["comparison_expander.h"], - visibility = ["//visibility:public"], deps = [ - ":hlo_creation_utils", - ":hlo_pass", ":op_expander_pass", + "//xla:comparison_util", + "//xla:literal_util", + "//xla:shape_util", + "//xla:statusor", "//xla:util", - "//xla/client/lib:comparators", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", ], ) @@ -2176,7 +2117,6 @@ cc_library( name = "scatter_expander", srcs = ["scatter_expander.cc"], hdrs = ["scatter_expander.h"], - visibility = ["//visibility:public"], deps = [ ":call_inliner", ":hlo_creation_utils", @@ -2210,7 +2150,6 @@ cc_library( name = "triangular_solve_expander", srcs = ["triangular_solve_expander.cc"], hdrs = ["triangular_solve_expander.h"], - visibility = ["//visibility:public"], deps = [ ":op_expander_pass", "//xla:literal", @@ -2250,7 +2189,6 @@ cc_library( name = "cholesky_expander", srcs = ["cholesky_expander.cc"], hdrs = ["cholesky_expander.h"], - visibility = ["//visibility:public"], deps = [ ":op_expander_pass", "//xla:literal", @@ -2274,7 +2212,6 @@ cc_library( name = "qr_expander", srcs = ["qr_expander.cc"], hdrs = ["qr_expander.h"], - visibility = ["//visibility:public"], deps = [ ":op_expander_pass", "//xla:literal", @@ -2299,7 +2236,6 @@ cc_library( name = "real_imag_expander", srcs = ["real_imag_expander.cc"], hdrs = ["real_imag_expander.h"], - visibility = ["//visibility:public"], deps = [ ":op_expander_pass", "//xla:literal_util", @@ -2331,7 +2267,6 @@ cc_library( name = "eigh_expander", srcs = ["eigh_expander.cc"], hdrs = ["eigh_expander.h"], - visibility = ["//visibility:public"], deps = [ ":op_expander_pass", "//xla:literal_util", @@ -2356,7 +2291,6 @@ cc_library( name = "convolution_4d_expander", srcs = ["convolution_4d_expander.cc"], hdrs = ["convolution_4d_expander.h"], - visibility = ["//visibility:public"], deps = [ ":op_expander_pass", "//xla:shape_util", @@ -2386,7 +2320,6 @@ cc_library( name = "convolution_pred_expander", srcs = ["convolution_pred_expander.cc"], hdrs = ["convolution_pred_expander.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_creation_utils", ":op_expander_pass", @@ -2440,7 +2373,6 @@ cc_library( name = "algebraic_simplifier", srcs = ["algebraic_simplifier.cc"], hdrs = ["algebraic_simplifier.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_cost_analysis", ":hlo_creation_utils", @@ -2481,7 +2413,6 @@ cc_library( name = "tree_reduction_rewriter", srcs = ["tree_reduction_rewriter.cc"], hdrs = ["tree_reduction_rewriter.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", ":shape_inference", @@ -2540,7 +2471,6 @@ cc_library( name = "simplify_fp_conversions", srcs = ["simplify_fp_conversions.cc"], hdrs = ["simplify_fp_conversions.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", "//xla:comparison_util", @@ -2574,7 +2504,6 @@ cc_library( name = "logistic_expander", srcs = ["logistic_expander.cc"], hdrs = ["logistic_expander.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_creation_utils", ":hlo_pass", @@ -2598,6 +2527,7 @@ xla_cc_test( name = "logistic_expander_test", srcs = ["logistic_expander_test.cc"], deps = [ + ":dynamic_padder", ":hlo_creation_utils", ":hlo_parser", ":hlo_pass", @@ -2608,6 +2538,7 @@ xla_cc_test( ":shape_inference", "//xla:literal", "//xla:shape_util", + "//xla:statusor", "//xla:test", "//xla:types", "//xla:window_util", @@ -2624,7 +2555,6 @@ cc_library( name = "collectives_schedule_linearizer", srcs = ["collectives_schedule_linearizer.cc"], hdrs = ["collectives_schedule_linearizer.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", "//xla:statusor", @@ -2656,7 +2586,6 @@ xla_cc_test( cc_library( name = "collective_combiner_utils", hdrs = ["collective_combiner_utils.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_domain_map", "//xla:shape_util", @@ -2675,7 +2604,6 @@ cc_library( name = "collective_decomposer_utils", srcs = ["collective_decomposer_utils.cc"], hdrs = ["collective_decomposer_utils.h"], - visibility = ["//visibility:public"], deps = [ ":collective_ops_utils", ":hlo_module_config", @@ -2690,7 +2618,6 @@ cc_library( name = "all_gather_broadcast_reorder", srcs = ["all_gather_broadcast_reorder.cc"], hdrs = ["all_gather_broadcast_reorder.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", "//xla:shape_util", @@ -2707,7 +2634,6 @@ cc_library( name = "bitcast_dtypes_expander", srcs = ["bitcast_dtypes_expander.cc"], hdrs = ["bitcast_dtypes_expander.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", ":op_expander_pass", @@ -2757,7 +2683,6 @@ cc_library( name = "all_gather_combiner", srcs = ["all_gather_combiner.cc"], hdrs = ["all_gather_combiner.h"], - visibility = ["//visibility:public"], deps = [ ":collective_combiner_utils", ":hlo_domain_map", @@ -2794,7 +2719,6 @@ cc_library( name = "all_reduce_combiner", srcs = ["all_reduce_combiner.cc"], hdrs = ["all_reduce_combiner.h"], - visibility = ["//visibility:public"], deps = [ ":all_reduce_key", ":collective_combiner_utils", @@ -2835,7 +2759,6 @@ cc_library( name = "all_reduce_contiguous", srcs = ["all_reduce_contiguous.cc"], hdrs = ["all_reduce_contiguous.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", "//xla:shape_util", @@ -2863,7 +2786,6 @@ cc_library( name = "reduce_scatter_combiner", srcs = ["reduce_scatter_combiner.cc"], hdrs = ["reduce_scatter_combiner.h"], - visibility = ["//visibility:public"], deps = [ ":all_reduce_key", ":collective_combiner_utils", @@ -2903,7 +2825,6 @@ cc_library( name = "all_reduce_simplifier", srcs = ["all_reduce_simplifier.cc"], hdrs = ["all_reduce_simplifier.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", ":hlo_replication_analysis", @@ -2937,7 +2858,6 @@ cc_library( name = "reduce_scatter_decomposer", srcs = ["reduce_scatter_decomposer.cc"], hdrs = ["reduce_scatter_decomposer.h"], - visibility = ["//visibility:public"], deps = [ ":collective_decomposer_utils", ":collective_ops_utils", @@ -2970,7 +2890,6 @@ cc_library( name = "reduce_scatter_reassociate", srcs = ["reduce_scatter_reassociate.cc"], hdrs = ["reduce_scatter_reassociate.h"], - visibility = ["//visibility:public"], deps = [ ":all_reduce_key", ":collective_ops_utils", @@ -2999,7 +2918,6 @@ cc_library( name = "batch_dot_simplification", srcs = ["batch_dot_simplification.cc"], hdrs = ["batch_dot_simplification.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_creation_utils", ":hlo_pass", @@ -3037,7 +2955,6 @@ cc_library( name = "conditional_simplifier", srcs = ["conditional_simplifier.cc"], hdrs = ["conditional_simplifier.h"], - visibility = ["//visibility:public"], deps = [ ":call_graph", ":call_inliner", @@ -3080,7 +2997,6 @@ cc_library( name = "conditional_code_motion", srcs = ["conditional_code_motion.cc"], hdrs = ["conditional_code_motion.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_cse", ":hlo_dce", @@ -3131,7 +3047,6 @@ cc_library( name = "convolution_group_converter", srcs = ["convolution_group_converter.cc"], hdrs = ["convolution_group_converter.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_creation_utils", ":hlo_pass", @@ -3169,7 +3084,6 @@ cc_library( name = "space_to_batch_converter", srcs = ["space_to_batch_converter.cc"], hdrs = ["space_to_batch_converter.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_creation_utils", ":hlo_pass", @@ -3202,7 +3116,6 @@ cc_library( name = "sparse_util", srcs = ["sparse_util.cc"], hdrs = ["sparse_util.h"], - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla/hlo/ir:hlo", @@ -3228,15 +3141,16 @@ cc_library( name = "while_loop_unroller", srcs = ["while_loop_unroller.cc"], hdrs = ["while_loop_unroller.h"], - visibility = ["//visibility:public"], deps = [ ":call_inliner", + ":collective_ops_utils", ":flatten_call_graph", ":hlo_cse", ":hlo_pass", ":tuple_simplifier", ":while_loop_analysis", ":while_loop_constant_sinking", + "//xla:comparison_util", "//xla:literal", "//xla:literal_util", "//xla:shape_util", @@ -3248,10 +3162,11 @@ cc_library( "//xla/hlo/utils:hlo_query", "@com_google_absl//absl/algorithm", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", @@ -3269,10 +3184,12 @@ xla_cc_test( "//xla/tests:literal_test_util", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", ], ) @@ -3280,7 +3197,6 @@ cc_library( name = "while_loop_analysis", srcs = ["while_loop_analysis.cc"], hdrs = ["while_loop_analysis.h"], - visibility = ["//visibility:public"], deps = [ ":pattern_matcher", "//xla:comparison_util", @@ -3319,7 +3235,6 @@ cc_library( name = "while_loop_simplifier", srcs = ["while_loop_simplifier.cc"], hdrs = ["while_loop_simplifier.h"], - visibility = ["//visibility:public"], deps = [ ":call_inliner", ":hlo_creation_utils", @@ -3368,7 +3283,6 @@ cc_library( name = "while_loop_trip_count_annotator", srcs = ["while_loop_trip_count_annotator.cc"], hdrs = ["while_loop_trip_count_annotator.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", ":while_loop_analysis", @@ -3397,7 +3311,6 @@ cc_library( name = "defuser", srcs = ["defuser.cc"], hdrs = ["defuser.h"], - visibility = ["//visibility:public"], deps = [ ":call_graph", ":hlo_pass", @@ -3444,7 +3357,6 @@ cc_library( name = "dot_decomposer", srcs = ["dot_decomposer.cc"], hdrs = ["dot_decomposer.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", ":sparse_util", @@ -3476,7 +3388,6 @@ cc_library( name = "dot_dimension_merger", srcs = ["dot_dimension_merger.cc"], hdrs = ["dot_dimension_merger.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_creation_utils", ":hlo_pass", @@ -3503,7 +3414,6 @@ cc_library( name = "dot_merger", srcs = ["dot_merger.cc"], hdrs = ["dot_merger.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", ":shape_inference", @@ -3532,7 +3442,6 @@ cc_library( name = "convert_mover", srcs = ["convert_mover.cc"], hdrs = ["convert_mover.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_creation_utils", ":hlo_pass", @@ -3557,7 +3466,6 @@ cc_library( name = "all_to_all_decomposer", srcs = ["all_to_all_decomposer.cc"], hdrs = ["all_to_all_decomposer.h"], - visibility = ["//visibility:public"], deps = [ ":op_expander_pass", "//xla:shape_util", @@ -3570,7 +3478,6 @@ cc_library( name = "all_gather_decomposer", srcs = ["all_gather_decomposer.cc"], hdrs = ["all_gather_decomposer.h"], - visibility = ["//visibility:public"], deps = [ ":collective_decomposer_utils", ":collective_ops_utils", @@ -3605,7 +3512,6 @@ cc_library( name = "tuple_simplifier", srcs = ["tuple_simplifier.cc"], hdrs = ["tuple_simplifier.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", "//xla/hlo/ir:hlo", @@ -3633,7 +3539,6 @@ cc_library( name = "reshape_mover", srcs = ["reshape_mover.cc"], hdrs = ["reshape_mover.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_creation_utils", ":hlo_pass", @@ -3650,7 +3555,6 @@ cc_library( name = "reshape_decomposer", srcs = ["reshape_decomposer.cc"], hdrs = ["reshape_decomposer.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_creation_utils", ":hlo_pass", @@ -3663,7 +3567,6 @@ cc_library( name = "reduce_decomposer", srcs = ["reduce_decomposer.cc"], hdrs = ["reduce_decomposer.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_creation_utils", ":hlo_pass", @@ -3704,7 +3607,6 @@ cc_library( name = "dynamic_window_utils", srcs = ["dynamic_window_utils.cc"], hdrs = ["dynamic_window_utils.h"], - visibility = ["//visibility:public"], deps = [ ":shape_inference", "//xla:literal", @@ -3720,7 +3622,6 @@ cc_library( name = "dynamic_dimension_inference", srcs = ["dynamic_dimension_inference.cc"], hdrs = ["dynamic_dimension_inference.h"], - visibility = ["//visibility:public"], deps = [ ":call_inliner", ":dynamic_window_utils", @@ -3761,7 +3662,6 @@ cc_library( name = "dynamic_dimension_simplifier", srcs = ["dynamic_dimension_simplifier.cc"], hdrs = ["dynamic_dimension_simplifier.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", "//xla:status_macros", @@ -3799,7 +3699,6 @@ cc_library( name = "dynamic_padder", srcs = ["dynamic_padder.cc"], hdrs = ["dynamic_padder.h"], - visibility = ["//visibility:public"], deps = [ ":call_graph", ":dynamic_dimension_inference", @@ -3925,7 +3824,6 @@ cc_library( name = "computation_placer", srcs = ["computation_placer.cc"], hdrs = ["computation_placer.h"], - visibility = ["//visibility:public"], deps = [ ":global_device_id", "//xla:array2d", @@ -3955,7 +3853,6 @@ cc_library( cc_library( name = "computation_placer_hdr", hdrs = ["computation_placer.h"], - visibility = ["//visibility:public"], deps = [ ":global_device_id", "//xla:array2d", @@ -3990,7 +3887,6 @@ cc_library( name = "human_readable_profile_builder", srcs = ["human_readable_profile_builder.cc"], hdrs = ["human_readable_profile_builder.h"], - visibility = ["//visibility:public"], deps = [ "//xla:metric_table_report", "//xla:types", @@ -4006,7 +3902,6 @@ cc_library( name = "generic_transfer_manager", srcs = ["generic_transfer_manager.cc"], hdrs = ["generic_transfer_manager.h"], - visibility = ["//visibility:public"], deps = [ ":transfer_manager", "//xla:literal", @@ -4016,12 +3911,19 @@ cc_library( "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/service:shaped_buffer", "//xla/stream_executor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:numbers", + "@local_tsl//tsl/platform:statusor", ], alwayslink = True, # Contains per-platform transfer manager registration ) @@ -4039,6 +3941,7 @@ xla_cc_test( "//xla:shape_util", "//xla:types", "//xla/stream_executor", + "//xla/stream_executor:platform_manager", "//xla/stream_executor/host:host_platform", "//xla/stream_executor/host:host_platform_id", "//xla/tests:literal_test_util", @@ -4055,7 +3958,6 @@ cc_library( name = "hlo_cost_analysis", srcs = ["hlo_cost_analysis.cc"], hdrs = ["hlo_cost_analysis.h"], - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla:status", @@ -4103,7 +4005,6 @@ cc_library( name = "hlo_execution_profile", srcs = ["hlo_execution_profile.cc"], hdrs = ["hlo_execution_profile.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_cost_analysis", ":hlo_execution_profile_data_cc", @@ -4194,7 +4095,6 @@ cc_library( name = "buffer_value", srcs = ["buffer_value.cc"], hdrs = ["buffer_value.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_proto_cc", "//xla:shape_util", @@ -4209,7 +4109,6 @@ cc_library( cc_library( name = "buffer_value_containers", hdrs = ["buffer_value_containers.h"], - visibility = ["//visibility:public"], deps = [ ":buffer_value", ":logical_buffer", @@ -4222,7 +4121,6 @@ cc_library( name = "logical_buffer", srcs = ["logical_buffer.cc"], hdrs = ["logical_buffer.h"], - visibility = ["//visibility:public"], deps = [ ":buffer_value", ":hlo_proto_cc", @@ -4240,7 +4138,6 @@ cc_library( name = "hlo_value", srcs = ["hlo_value.cc"], hdrs = ["hlo_value.h"], - visibility = ["//visibility:public"], deps = [ ":buffer_value", "//xla:lazy", @@ -4267,7 +4164,6 @@ cc_library( name = "hlo_dataflow_analysis", srcs = ["hlo_dataflow_analysis.cc"], hdrs = ["hlo_dataflow_analysis.h"], - visibility = ["//visibility:public"], deps = [ ":call_graph", ":hlo_phi_graph", @@ -4285,6 +4181,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -4326,7 +4223,6 @@ cc_library( name = "hlo_phi_graph", srcs = ["hlo_phi_graph.cc"], hdrs = ["hlo_phi_graph.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_value", "//xla/hlo/ir:hlo", @@ -4351,7 +4247,6 @@ cc_library( name = "hlo_value_semantics_analysis", srcs = ["hlo_value_semantics_analysis.cc"], hdrs = ["hlo_value_semantics_analysis.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_value", "//xla:shape_tree", @@ -4394,7 +4289,6 @@ cc_library( name = "hlo_replication_analysis", srcs = ["hlo_replication_analysis.cc"], hdrs = ["hlo_replication_analysis.h"], - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla:statusor", @@ -4426,7 +4320,6 @@ cc_library( name = "hlo_liveness_analysis", srcs = ["hlo_liveness_analysis.cc"], hdrs = ["hlo_liveness_analysis.h"], - visibility = ["//visibility:public"], deps = [ ":call_graph", ":hlo_value", @@ -4469,7 +4362,6 @@ cc_library( name = "hlo_buffer", srcs = ["hlo_buffer.cc"], hdrs = ["hlo_buffer.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_value", "//xla:shape_tree", @@ -4489,7 +4381,6 @@ cc_library( name = "hlo_alias_analysis", srcs = ["hlo_alias_analysis.cc"], hdrs = ["hlo_alias_analysis.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_buffer", ":hlo_dataflow_analysis", @@ -4541,7 +4432,6 @@ cc_library( name = "logical_buffer_analysis", srcs = ["logical_buffer_analysis.cc"], hdrs = ["logical_buffer_analysis.h"], - visibility = ["//visibility:public"], deps = [ ":logical_buffer", "//xla:shape_util", @@ -4557,7 +4447,6 @@ cc_library( name = "tuple_points_to_analysis", srcs = ["tuple_points_to_analysis.cc"], hdrs = ["tuple_points_to_analysis.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_dataflow_analysis", ":logical_buffer", @@ -4608,7 +4497,6 @@ cc_library( name = "compilation_cache", srcs = ["compilation_cache.cc"], hdrs = ["compilation_cache.h"], - visibility = ["//visibility:public"], deps = [ ":executable", ":hlo_module_config", @@ -4629,7 +4517,6 @@ cc_library( hdrs = [ "layout_assignment.h", ], - visibility = ["//visibility:public"], deps = [ ":call_graph", ":computation_layout", @@ -4672,7 +4559,6 @@ cc_library( "compile_time_cap.h", "copy_insertion.h", ], - visibility = ["//visibility:public"], deps = [ ":dump", ":hlo_alias_analysis", @@ -4700,7 +4586,6 @@ cc_library( name = "loop_schedule_linearizer", srcs = ["loop_schedule_linearizer.cc"], hdrs = ["loop_schedule_linearizer.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_alias_analysis", ":hlo_graph_dumper", @@ -4757,7 +4642,6 @@ cc_library( name = "memory_space_propagation", srcs = ["memory_space_propagation.cc"], hdrs = ["memory_space_propagation.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_dataflow_analysis", ":hlo_pass", @@ -4782,7 +4666,6 @@ cc_library( name = "hlo_dce", srcs = ["hlo_dce.cc"], hdrs = ["hlo_dce.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", "//xla:status", @@ -4803,7 +4686,6 @@ cc_library( name = "hlo_module_dce", srcs = ["hlo_module_dce.cc"], hdrs = ["hlo_module_dce.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_dce", ":hlo_liveness_analysis", @@ -4825,7 +4707,6 @@ cc_library( name = "hlo_verifier", srcs = ["hlo_verifier.cc"], hdrs = ["hlo_verifier.h"], - visibility = ["//visibility:public"], deps = [ ":collective_ops_utils", ":hlo_pass", @@ -4873,7 +4754,6 @@ cc_library( name = "cpu_gpu_shape_verifier", srcs = ["cpu_gpu_shape_verifier.cc"], hdrs = ["cpu_gpu_shape_verifier.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_verifier", "//xla:shape_util", @@ -4903,7 +4783,6 @@ cc_library( name = "hlo_rematerialization", srcs = ["hlo_rematerialization.cc"], hdrs = ["hlo_rematerialization.h"], - visibility = ["//visibility:public"], deps = [ ":call_graph", ":hlo_cost_analysis", @@ -4936,7 +4815,6 @@ cc_library( name = "hlo_rematerialization_test_utils", testonly = 1, hdrs = ["hlo_rematerialization_test_utils.h"], - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla:xla_data_proto_cc", @@ -5043,7 +4921,6 @@ cc_library( "hlo_pass_fix.h", "hlo_pass_interface.h", ], - visibility = ["//visibility:public"], deps = [ "//xla:status_macros", "//xla:statusor", @@ -5064,7 +4941,6 @@ cc_library( "hlo_pass_pipeline.h", ], local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - visibility = ["//visibility:public"], deps = [ ":compilation_stats", ":dump", @@ -5105,7 +4981,6 @@ cc_library( name = "hlo_cse", srcs = ["hlo_cse.cc"], hdrs = ["hlo_cse.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_domain_map", ":hlo_pass", @@ -5142,7 +5017,6 @@ cc_library( name = "hlo_constant_folding", srcs = ["hlo_constant_folding.cc"], hdrs = ["hlo_constant_folding.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", ":slow_operation_alarm", @@ -5181,7 +5055,6 @@ cc_library( name = "hlo_domain_map", srcs = ["hlo_domain_map.cc"], hdrs = ["hlo_domain_map.h"], - visibility = ["//visibility:public"], deps = [ "//xla:statusor", "//xla:types", @@ -5197,7 +5070,6 @@ cc_library( name = "hlo_domain_verifier", srcs = ["hlo_domain_verifier.cc"], hdrs = ["hlo_domain_verifier.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_domain_map", ":hlo_graph_dumper", @@ -5212,7 +5084,6 @@ cc_library( name = "hlo_domain_isolator", srcs = ["hlo_domain_isolator.cc"], hdrs = ["hlo_domain_isolator.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_domain_remover", ":hlo_pass", @@ -5225,7 +5096,6 @@ cc_library( name = "hlo_domain_remover", srcs = ["hlo_domain_remover.cc"], hdrs = ["hlo_domain_remover.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_domain_map", ":hlo_domain_verifier", @@ -5260,7 +5130,6 @@ cc_library( name = "hlo_element_type_converter", srcs = ["hlo_element_type_converter.cc"], hdrs = ["hlo_element_type_converter.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", "//xla:literal", @@ -5288,7 +5157,6 @@ cc_library( name = "conditional_canonicalizer", srcs = ["conditional_canonicalizer.cc"], hdrs = ["conditional_canonicalizer.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", "//xla:status_macros", @@ -5323,7 +5191,6 @@ cc_library( hdrs = [ "maybe_owning_device_memory.h", ], - visibility = ["//visibility:public"], deps = [ "//xla/stream_executor:device_memory_allocator", "@com_google_absl//absl/types:variant", @@ -5338,7 +5205,6 @@ cc_library( hdrs = [ "float8_fnuz_ir_emitter.h", ], - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla:status_macros", @@ -5353,7 +5219,6 @@ cc_library( name = "elemental_ir_emitter", srcs = ["elemental_ir_emitter.cc"], hdrs = ["elemental_ir_emitter.h"], - visibility = ["//visibility:public"], deps = [ ":float8_fnuz_ir_emitter", "//xla:permutation_util", @@ -5370,6 +5235,7 @@ cc_library( "//xla/service/llvm_ir:llvm_loop", "//xla/service/llvm_ir:llvm_util", "//xla/service/llvm_ir:loop_emitter", + "//xla/service/llvm_ir:math_ops", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", @@ -5409,7 +5275,6 @@ cc_library( name = "hlo_module_config", srcs = ["hlo_module_config.cc"], hdrs = ["hlo_module_config.h"], - visibility = ["//visibility:public"], deps = [ ":computation_layout", ":computation_placer", @@ -5445,7 +5310,6 @@ cc_library( name = "computation_layout", srcs = ["computation_layout.cc"], hdrs = ["computation_layout.h"], - visibility = ["//visibility:public"], deps = [ "//xla:printer", "//xla:shape_layout", @@ -5462,7 +5326,6 @@ cc_library( name = "hlo_graph_dumper", srcs = ["hlo_graph_dumper.cc"], hdrs = ["hlo_graph_dumper.h"], - visibility = ["//visibility:public"], deps = [ ":pattern_matcher", "//xla:literal", @@ -5513,7 +5376,6 @@ cc_library( name = "transpose_folding", srcs = ["transpose_folding.cc"], hdrs = ["transpose_folding.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", "//xla:shape_util", @@ -5557,7 +5419,6 @@ cc_library( name = "zero_sized_hlo_elimination", srcs = ["zero_sized_hlo_elimination.cc"], hdrs = ["zero_sized_hlo_elimination.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", "//xla:literal", @@ -5594,7 +5455,6 @@ cc_library( name = "stream_pool", srcs = ["stream_pool.cc"], hdrs = ["stream_pool.h"], - visibility = ["//visibility:public"], deps = [ "//xla/stream_executor", ], @@ -5607,6 +5467,7 @@ xla_cc_test( ":stream_pool", "//xla:test_helpers", "//xla/stream_executor", + "//xla/stream_executor:platform_manager", "//xla/stream_executor/host:host_platform", "//xla/tests:xla_internal_test_main", ], @@ -5616,7 +5477,6 @@ cc_library( name = "hlo_proto_util", srcs = ["hlo_proto_util.cc"], hdrs = ["hlo_proto_util.h"], - visibility = ["//visibility:public"], deps = [ ":buffer_assignment", ":hlo_proto_cc", @@ -5647,7 +5507,6 @@ cc_library( name = "hlo_runner_interface", srcs = ["hlo_runner_interface.cc"], hdrs = ["hlo_runner_interface.h"], - visibility = ["//visibility:public"], deps = [ ":computation_placer", ":executable", @@ -5666,7 +5525,6 @@ cc_library( name = "hlo_runner", srcs = ["hlo_runner.cc"], hdrs = ["hlo_runner.h"], - visibility = ["//visibility:public"], deps = [ ":backend", ":compiler", @@ -5696,7 +5554,6 @@ cc_library( name = "hlo_runner_pjrt", srcs = ["hlo_runner_pjrt.cc"], hdrs = ["hlo_runner_pjrt.h"], - visibility = ["//visibility:public"], deps = [ ":executable", ":hlo_module_util", @@ -5716,7 +5573,6 @@ cc_library( name = "hlo_profile_printer", srcs = ["hlo_profile_printer.cc"], hdrs = ["hlo_profile_printer.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_profile_printer_data_cc", ":human_readable_profile_builder", @@ -5730,7 +5586,6 @@ cc_library( name = "sort_simplifier", srcs = ["sort_simplifier.cc"], hdrs = ["sort_simplifier.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", "//xla:statusor", @@ -5760,7 +5615,6 @@ cc_library( name = "stable_sort_expander", srcs = ["stable_sort_expander.cc"], hdrs = ["stable_sort_expander.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", ":op_expander_pass", @@ -5792,7 +5646,6 @@ cc_library( name = "tuple_util", srcs = ["tuple_util.cc"], hdrs = ["tuple_util.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_value", "//xla:shape_tree", @@ -5829,7 +5682,6 @@ cc_library( name = "root_instruction_sinker", srcs = ["root_instruction_sinker.cc"], hdrs = ["root_instruction_sinker.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", ":tuple_util", @@ -5848,11 +5700,51 @@ xla_cc_test( ], ) +cc_library( + name = "host_memory_offload_annotations_hdr", + hdrs = ["host_memory_offload_annotations.h"], + deps = [ + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_library( + name = "convert_memory_placement_to_internal_annotations", + srcs = ["convert_memory_placement_to_internal_annotations.cc"], + hdrs = ["convert_memory_placement_to_internal_annotations.h"], + deps = [ + ":host_memory_offload_annotations_hdr", + "//xla:side_effect_util", + "//xla:statusor", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "convert_memory_placement_to_internal_annotations_test", + srcs = ["convert_memory_placement_to_internal_annotations_test.cc"], + deps = [ + ":convert_memory_placement_to_internal_annotations", + "//xla:statusor", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + ], +) + cc_library( name = "host_memory_transfer_asyncifier", srcs = ["host_memory_transfer_asyncifier.cc"], hdrs = ["host_memory_transfer_asyncifier.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", "//xla:shape_util", @@ -5863,6 +5755,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], @@ -5888,16 +5781,68 @@ xla_cc_test( ], ) +cc_library( + name = "host_offload_legalize", + srcs = ["host_offload_legalize.cc"], + hdrs = ["host_offload_legalize.h"], + deps = [ + ":hlo_alias_analysis", + ":hlo_buffer", + ":hlo_pass", + ":hlo_value", + ":host_memory_offload_annotations_hdr", + ":host_offloader", + ":pattern_matcher", + "//xla:literal_util", + "//xla:shape_util", + "//xla:status", + "//xla:statusor", + "//xla:util", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "host_offload_legalize_test", + srcs = ["host_offload_legalize_test.cc"], + deps = [ + ":host_memory_offload_annotations_hdr", + ":host_offload_legalize", + ":pattern_matcher", + ":pattern_matcher_gmock", + "//xla:shape_util", + "//xla:statusor", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:statusor", + ], +) + cc_library( name = "host_offloader", srcs = ["host_offloader.cc"], hdrs = ["host_offloader.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_alias_analysis", ":hlo_buffer", ":hlo_pass", ":hlo_value", + ":host_memory_offload_annotations_hdr", + ":pattern_matcher", "//xla:literal_util", "//xla:shape_util", "//xla:status", @@ -5918,7 +5863,10 @@ cc_library( xla_cc_test( name = "host_offloader_test", srcs = ["host_offloader_test.cc"], + shard_count = 12, deps = [ + ":host_memory_offload_annotations_hdr", + ":host_offload_legalize", ":host_offloader", ":pattern_matcher", ":pattern_matcher_gmock", @@ -5940,7 +5888,6 @@ cc_library( name = "while_util", srcs = ["while_util.cc"], hdrs = ["while_util.h"], - visibility = ["//visibility:public"], deps = [ ":call_inliner", ":hlo_creation_utils", @@ -5982,7 +5929,6 @@ cc_library( name = "while_loop_all_reduce_code_motion", srcs = ["while_loop_all_reduce_code_motion.cc"], hdrs = ["while_loop_all_reduce_code_motion.h"], - visibility = ["//visibility:public"], deps = [ ":call_graph", ":collective_ops_utils", @@ -6022,7 +5968,6 @@ cc_library( name = "while_loop_concat_code_motion", srcs = ["while_loop_concat_code_motion.cc"], hdrs = ["while_loop_concat_code_motion.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_dce", ":hlo_pass", @@ -6069,7 +6014,6 @@ cc_library( "compile_time_cap.h", "while_loop_invariant_code_motion.h", ], - visibility = ["//visibility:public"], deps = [ ":hlo_dce", ":hlo_pass", @@ -6106,7 +6050,6 @@ cc_library( name = "while_loop_expensive_invariant_code_motion", srcs = ["while_loop_expensive_invariant_code_motion.cc"], hdrs = ["while_loop_expensive_invariant_code_motion.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", ":while_loop_analysis", @@ -6138,7 +6081,6 @@ cc_library( name = "fusion_constant_sinking", srcs = ["fusion_constant_sinking.cc"], hdrs = ["fusion_constant_sinking.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_dce", ":hlo_pass", @@ -6173,7 +6115,6 @@ cc_library( name = "while_loop_constant_sinking", srcs = ["while_loop_constant_sinking.cc"], hdrs = ["while_loop_constant_sinking.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", ":while_util", @@ -6202,7 +6143,6 @@ cc_library( name = "while_loop_fusible_sinking", srcs = ["while_loop_fusible_sinking.cc"], hdrs = ["while_loop_fusible_sinking.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", ":while_util", @@ -6234,7 +6174,6 @@ cc_library( name = "despecializer", srcs = ["despecializer.cc"], hdrs = ["despecializer.h"], - visibility = ["//visibility:public"], deps = [ ":defuser", ":float_normalization", @@ -6252,7 +6191,6 @@ cc_library( name = "source_map_util", srcs = [], hdrs = ["source_map_util.h"], - visibility = ["//visibility:public"], deps = [ ":executable", "//xla:status", @@ -6264,7 +6202,6 @@ cc_library( name = "indexed_array_analysis", srcs = ["indexed_array_analysis.cc"], hdrs = ["indexed_array_analysis.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", "//xla:util", @@ -6294,7 +6231,6 @@ cc_library( name = "hlo_parser", srcs = ["hlo_parser.cc"], hdrs = ["hlo_parser.h"], - visibility = ["//visibility:public"], deps = [ ":computation_layout", ":hlo_lexer", @@ -6358,7 +6294,6 @@ cc_library( hdrs = [ "hlo_lexer.h", ], - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla:statusor", @@ -6377,7 +6312,6 @@ cc_library( name = "map_inliner", srcs = ["map_inliner.cc"], hdrs = ["map_inliner.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", "//xla:status_macros", @@ -6395,7 +6329,6 @@ cc_library( name = "optimize_input_output_buffer_alias", srcs = ["optimize_input_output_buffer_alias.cc"], hdrs = ["optimize_input_output_buffer_alias.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", "//xla:shape_util", @@ -6433,7 +6366,6 @@ cc_library( name = "ar_crs_combiner", srcs = ["ar_crs_combiner.cc"], hdrs = ["ar_crs_combiner.h"], - visibility = ["//visibility:public"], deps = [ ":call_graph", ":hlo_pass", @@ -6457,7 +6389,6 @@ cc_library( name = "compilation_stats", srcs = ["compilation_stats.cc"], hdrs = ["compilation_stats.h"], - visibility = ["//visibility:public"], deps = [ "//xla:types", "@com_google_absl//absl/container:flat_hash_map", @@ -6471,7 +6402,6 @@ cc_library( name = "dynamic_index_splitter", srcs = ["dynamic_index_splitter.cc"], hdrs = ["dynamic_index_splitter.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", "//xla:shape_util", @@ -6542,7 +6472,6 @@ cc_library( name = "conditional_to_select", srcs = ["conditional_to_select.cc"], hdrs = ["conditional_to_select.h"], - visibility = ["//visibility:public"], deps = [ ":call_graph", ":call_inliner", @@ -6575,7 +6504,6 @@ cc_library( name = "slice_sinker", srcs = ["slice_sinker.cc"], hdrs = ["slice_sinker.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", "//xla:shape_util", @@ -6591,6 +6519,17 @@ cc_library( visibility = ["//visibility:public"], ) +cc_test( + name = "custom_call_target_registry_test", + srcs = ["custom_call_target_registry_test.cc"], + deps = [ + ":custom_call_status", + ":custom_call_target_registry", + "//xla:test", + "@local_tsl//tsl/platform:test_main", + ], +) + # Exposes the public interface only and hides internal details. Suitable for # linking into a static library or binary. cc_library( @@ -6613,13 +6552,13 @@ filegroup( "custom_call_status.h", "custom_call_status_internal.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), ) filegroup( name = "custom_call_status_srcs", srcs = ["custom_call_status.cc"], - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), ) # Internal version that exposes internal details and private interfaces. For @@ -6630,7 +6569,10 @@ cc_library( "custom_call_status_internal.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], + visibility = internal_visibility([ + ":__subpackages__", + "//tensorflow/compiler/tf2xla:__pkg__", + ]), deps = [ ":custom_call_status", "@com_google_absl//absl/strings", @@ -6661,7 +6603,6 @@ cc_library( testonly = True, srcs = ["custom_call_status_test_c_caller.c"], hdrs = ["custom_call_status_test_c_caller.h"], - visibility = ["//visibility:public"], deps = [":custom_call_status"], ) @@ -6688,7 +6629,6 @@ cc_library( name = "rng_expander", srcs = ["rng_expander.cc"], hdrs = ["rng_expander.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_creation_utils", ":op_expander_pass", @@ -6703,7 +6643,6 @@ cc_library( name = "rng_bit_generator_expander", srcs = ["rng_bit_generator_expander.cc"], hdrs = ["rng_bit_generator_expander.h"], - visibility = ["//visibility:public"], deps = [ ":op_expander_pass", "//xla:shape_util", @@ -6721,7 +6660,6 @@ cc_library( name = "slow_operation_alarm", srcs = ["slow_operation_alarm.cc"], hdrs = ["slow_operation_alarm.h"], - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", @@ -6738,7 +6676,6 @@ cc_library( name = "collective_ops_utils", srcs = ["collective_ops_utils.cc"], hdrs = ["collective_ops_utils.h"], - visibility = ["//visibility:public"], deps = [ ":computation_placer", ":global_device_id", @@ -6764,7 +6701,6 @@ cc_library( name = "collective_transformation_reorderer", srcs = ["collective_transformation_reorderer.cc"], hdrs = ["collective_transformation_reorderer.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_dce", ":hlo_pass", @@ -6792,10 +6728,15 @@ xla_cc_test( ":collective_ops_utils", ":computation_placer", ":global_device_id", + ":hlo_parser", + "//xla:shape_util", "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], ) @@ -6804,11 +6745,11 @@ cc_library( name = "topk_rewriter", srcs = ["topk_rewriter.cc"], hdrs = ["topk_rewriter.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", ":pattern_matcher", "//xla:shape_util", + "//xla:util", "//xla/client:xla_builder", "//xla/client/lib:comparators", "//xla/hlo/ir:hlo", @@ -6844,7 +6785,6 @@ cc_library( name = "operand_upcaster", srcs = ["operand_upcaster.cc"], hdrs = ["operand_upcaster.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_creation_utils", ":op_expander_pass", @@ -6871,7 +6811,6 @@ cc_library( name = "result_caster", srcs = ["result_caster.cc"], hdrs = ["result_caster.h"], - visibility = ["//visibility:public"], deps = [ ":op_expander_pass", ":shape_inference", @@ -6896,7 +6835,6 @@ cc_library( name = "global_device_id", srcs = ["global_device_id.cc"], hdrs = ["global_device_id.h"], - visibility = ["//visibility:public"], deps = [ "//xla:types", "@com_google_absl//absl/strings", @@ -6909,7 +6847,6 @@ cc_library( name = "convert_operand_folding", srcs = ["convert_operand_folding.cc"], hdrs = ["convert_operand_folding.h"], - visibility = ["//visibility:public"], deps = [ ":op_expander_pass", "//xla:shape_util", @@ -6939,7 +6876,6 @@ cc_library( hdrs = [ "xla_debug_info_manager.h", ], - visibility = ["//visibility:public"], deps = [ ":hlo_proto_cc", ":hlo_proto_util", @@ -6991,7 +6927,6 @@ py_strict_test( cc_library( name = "mapped_ptr_container_sorter", hdrs = ["mapped_ptr_container_sorter.h"], - visibility = ["//visibility:public"], deps = [ "//xla:status", "//xla:statusor", @@ -7021,7 +6956,6 @@ xla_cc_test( cc_library( name = "lockable", hdrs = ["lockable.h"], - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings:str_format", @@ -7046,7 +6980,6 @@ cc_library( name = "rendezvous", srcs = ["rendezvous.cc"], hdrs = ["rendezvous.h"], - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -7068,6 +7001,7 @@ xla_cc_test( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:test", @@ -7080,7 +7014,6 @@ cc_library( name = "compilation_environments", srcs = ["compilation_environments.cc"], hdrs = ["compilation_environments.h"], - visibility = ["//visibility:public"], deps = [ "//xla:statusor", "//xla:xla_proto_cc", @@ -7102,7 +7035,6 @@ cc_library( name = "custom_call_sharding_helper", srcs = ["custom_call_sharding_helper.cc"], hdrs = ["custom_call_sharding_helper.h"], - visibility = ["//visibility:public"], deps = ["//xla/hlo/ir:hlo"], ) @@ -7111,7 +7043,6 @@ tf_proto_library( testonly = 1, srcs = ["test_compilation_environment.proto"], cc_api_version = 2, - visibility = ["//visibility:public"], ) xla_cc_test( @@ -7134,7 +7065,6 @@ cc_library( name = "layout_normalization", srcs = ["layout_normalization.cc"], hdrs = ["layout_normalization.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_creation_utils", ":hlo_pass", @@ -7155,7 +7085,6 @@ cc_library( name = "instruction_hoister", srcs = ["instruction_hoister.cc"], hdrs = ["instruction_hoister.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", "//xla/hlo/ir:hlo", @@ -7166,7 +7095,6 @@ cc_library( name = "scatter_simplifier", srcs = ["scatter_simplifier.cc"], hdrs = ["scatter_simplifier.h"], - visibility = ["//visibility:public"], deps = [ ":gather_scatter_utils", ":hlo_creation_utils", @@ -7198,7 +7126,6 @@ cc_library( name = "select_and_scatter_expander", srcs = ["select_and_scatter_expander.cc"], hdrs = ["select_and_scatter_expander.h"], - visibility = ["//visibility:public"], deps = [ ":call_inliner", ":op_expander_pass", @@ -7235,7 +7162,6 @@ cc_library( name = "change_op_data_type", srcs = ["change_op_data_type.cc"], hdrs = ["change_op_data_type.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_creation_utils", ":hlo_pass", @@ -7246,7 +7172,6 @@ cc_library( name = "gather_scatter_utils", srcs = ["gather_scatter_utils.cc"], hdrs = ["gather_scatter_utils.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_creation_utils", "//xla:permutation_util", @@ -7259,7 +7184,6 @@ cc_library( name = "gather_simplifier", srcs = ["gather_simplifier.cc"], hdrs = ["gather_simplifier.h"], - visibility = ["//visibility:public"], deps = [ ":gather_scatter_utils", ":hlo_creation_utils", @@ -7277,7 +7201,6 @@ cc_library( name = "stochastic_convert_decomposer", srcs = ["stochastic_convert_decomposer.cc"], hdrs = ["stochastic_convert_decomposer.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_creation_utils", ":hlo_pass", @@ -7309,7 +7232,6 @@ xla_cc_test( cc_library( name = "metrics_hook_interface", hdrs = ["metrics_hook_interface.h"], - visibility = ["//visibility:public"], deps = [ ":metrics_proto_cc", "@com_google_absl//absl/strings", @@ -7321,7 +7243,6 @@ cc_library( name = "sub_byte_normalization", srcs = ["sub_byte_normalization.cc"], hdrs = ["sub_byte_normalization.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", "//xla:shape_layout", @@ -7340,7 +7261,6 @@ cc_library( testonly = True, srcs = ["sharding_format_picker.cc"], hdrs = ["sharding_format_picker.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_pass", "//xla:statusor", @@ -7505,7 +7425,10 @@ xla_cc_test( name = "xla_aot_compile_cpu_test", srcs = ["xla_aot_compile_cpu_test.cc"], data = [":xla_aot_compile_test_cpu_executable"], - tags = ["no_oss"], + tags = [ + "no_oss", + "notap", + ], deps = [ ":cpu_plugin", ":platform_util", @@ -7587,13 +7510,11 @@ tf_proto_library( srcs = ["buffer_assignment.proto"], cc_api_version = 2, make_default_target_header_only = True, - visibility = ["//visibility:public"], ) cc_library( name = "export_hlo", hdrs = ["export_hlo.h"], - visibility = ["//visibility:public"], deps = [ "//xla/hlo/ir:hlo", "//xla/stream_executor:device_description_proto_cc", @@ -7605,7 +7526,6 @@ cc_library( name = "gpu_compilation_environment", srcs = ["gpu_compilation_environment.cc"], hdrs = ["gpu_compilation_environment.h"], - visibility = ["//visibility:public"], deps = [ ":compilation_environments", "//xla:parse_flags_from_env", @@ -7640,7 +7560,6 @@ xla_cc_test( cc_library( name = "symbol_repository", hdrs = ["symbol_repository.h"], - visibility = ["//visibility:public"], deps = [ ":compiler", "//xla:xla_proto_cc", @@ -7658,7 +7577,6 @@ cc_library( name = "time_utils", srcs = ["time_utils.cc"], hdrs = ["time_utils.h"], - visibility = ["//visibility:public"], deps = [], ) @@ -7673,7 +7591,4 @@ tf_proto_library( visibility = ["//visibility:public"], ) -exports_files( - ["xla_aot_compile_test_gpu_target_config.prototxt"], - visibility = ["//visibility:public"], -) +exports_files(["xla_aot_compile_test_gpu_target_config.prototxt"]) diff --git a/third_party/xla/xla/service/algebraic_simplifier.cc b/third_party/xla/xla/service/algebraic_simplifier.cc index 9c703b3046fd11..7659b5f9b4d513 100644 --- a/third_party/xla/xla/service/algebraic_simplifier.cc +++ b/third_party/xla/xla/service/algebraic_simplifier.cc @@ -1078,7 +1078,7 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { return OkStatus(); } -StatusOr AlgebraicSimplifierVisitor::TrySimplifyTautologicalCompare( +absl::StatusOr AlgebraicSimplifierVisitor::TrySimplifyTautologicalCompare( HloInstruction* conjunction) { HloInstruction *lhs, *rhs; if (!Match(conjunction, m::And(m::Op(&lhs), m::Op(&rhs)))) { @@ -1122,6 +1122,15 @@ StatusOr AlgebraicSimplifierVisitor::TrySimplifyTautologicalCompare( return false; } +Status AlgebraicSimplifierVisitor::HandleAllToAll(HloInstruction* all_to_all) { + if (all_to_all->shape().IsArray() && + Match(all_to_all->mutable_operand(0), + m::Broadcast(m::ConstantScalar()))) { + return ReplaceInstruction(all_to_all, all_to_all->mutable_operand(0)); + } + return OkStatus(); +} + Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) { HloInstruction *lhs, *rhs; CHECK(Match(logical_and, m::And(m::Op(&lhs), m::Op(&rhs)))); @@ -1798,7 +1807,7 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( return OkStatus(); } -StatusOr +absl::StatusOr AlgebraicSimplifierVisitor::TrySimplifyTautologicalBitcastConvert( HloInstruction* bitcast) { CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcastConvert); @@ -2240,7 +2249,8 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { return OkStatus(); } -StatusOr AlgebraicSimplifierVisitor::RemoveDegenerateDimensionFromDot( +absl::StatusOr +AlgebraicSimplifierVisitor::RemoveDegenerateDimensionFromDot( HloInstruction* dot) { const Shape& lhs_shape = dot->operand(0)->shape(); int64_t num_degenerate_lhs_dims = 0; @@ -2398,7 +2408,8 @@ Status AlgebraicSimplifierVisitor::SimplifyTransposeOfBroadcast( transpose->shape())); } -StatusOr AlgebraicSimplifierVisitor::RemoveTransposesFromDotOperands( +absl::StatusOr +AlgebraicSimplifierVisitor::RemoveTransposesFromDotOperands( HloInstruction* dot) { const int64_t rank = dot->shape().rank(); const auto& dnums = dot->dot_dimension_numbers(); @@ -2460,7 +2471,7 @@ StatusOr AlgebraicSimplifierVisitor::RemoveTransposesFromDotOperands( return true; } -StatusOr +absl::StatusOr AlgebraicSimplifierVisitor::NormalizeDotOperandToBatchMajorAndContractingMinor( HloInstruction* dot_operand, absl::Span batch_dimensions, absl::Span contracting_dimensions) { @@ -2493,7 +2504,7 @@ HloInstruction* AlgebraicSimplifierVisitor::AddReduce( shape, hlo, zero, dims, AddReduce_computation)); } -StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcat( +absl::StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcat( HloInstruction* dot) { const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); if (dnums.lhs_contracting_dimensions_size() != 1 || @@ -2519,7 +2530,8 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcat( lhs_contracting_dim, /*swapped=*/true); } -StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( +absl::StatusOr +AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( HloInstruction* dot, HloInstruction* lhs, int64_t lhs_contracting_dim, HloInstruction* rhs, int64_t rhs_contracting_dim, bool swapped) { bool can_optimize = lhs->opcode() == HloOpcode::kConcatenate && @@ -2636,7 +2648,7 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( return add_result; } -StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( +absl::StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( HloInstruction* dot) { const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); if (dnums.lhs_contracting_dimensions_size() != 1 || @@ -2776,7 +2788,7 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( // associative, so as long as we permute the elements of the contracting // dimensions on both sides of the dot in the same way, the result of the // dot is not affected. -StatusOr +absl::StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims( HloInstruction* dot) { // This transformation assumes layout is not assigned yet. @@ -2986,7 +2998,7 @@ AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims( // If appropriate, reorder operation on dot operand to the mirror operation on // the other dot operand -StatusOr +absl::StatusOr AlgebraicSimplifierVisitor::AssociativeReorderDotOperator(HloInstruction* dot) { DotDimensionNumbers dnums = dot->dot_dimension_numbers(); HloInstruction* lhs = dot->mutable_operand(0); @@ -3845,7 +3857,7 @@ Status AlgebraicSimplifierVisitor::HandleGather(HloInstruction* gather) { } namespace { -StatusOr> MinMaxToClamp( +absl::StatusOr> MinMaxToClamp( HloInstruction* clamp_lower_bound_bcast, HloInstruction* to_clamp, HloInstruction* clamp_upper_bound_bcast, AlgebraicSimplifier* simplifier) { HloInstruction* clamp_lower_bound; @@ -4703,6 +4715,37 @@ Status AlgebraicSimplifierVisitor::HandleCompare(HloInstruction* compare) { } } } + + // Gt(Max(a,b), a) -> Gt(b,a) + // Gt(Max(a,b), b) -> Gt(a,b) + // Gt(a, Min(a,b)) -> Gt(a,b) + // Gt(b, Min(a,b)) -> Gt(b,a) + if (compare->comparison_direction() == ComparisonDirection::kGt) { + HloInstruction* a; + HloInstruction* b; + if (Match(lhs, m::Maximum(m::Op(&a), m::Op(&b)))) { + if (rhs == a) { // Gt(Max(a,b), a) -> Gt(b,a) + TF_RETURN_IF_ERROR(compare->ReplaceOperandWith(0, b)); + MarkAsChanged(); + return OkStatus(); + } else if (rhs == b) { // Gt(Max(a,b), b) -> Gt(a,b) + TF_RETURN_IF_ERROR(compare->ReplaceOperandWith(0, a)); + MarkAsChanged(); + return OkStatus(); + } + } else if (Match(rhs, m::Minimum(m::Op(&a), m::Op(&b)))) { + if (lhs == a) { // Gt(a, Min(a,b)) -> Gt(a,b) + TF_RETURN_IF_ERROR(compare->ReplaceOperandWith(1, b)); + MarkAsChanged(); + return OkStatus(); + } else if (lhs == b) { // Gt(b, Min(a,b)) -> Gt(b,a) + TF_RETURN_IF_ERROR(compare->ReplaceOperandWith(1, a)); + MarkAsChanged(); + return OkStatus(); + } + } + } + return OkStatus(); } @@ -5047,7 +5090,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { return OkStatus(); } -StatusOr +absl::StatusOr AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( HloInstruction* broadcast) { TF_RET_CHECK(broadcast->opcode() == HloOpcode::kBroadcast); @@ -5673,7 +5716,7 @@ Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) { return OkStatus(); } -StatusOr AlgebraicSimplifierVisitor::TrySimplifyScalarSlice( +absl::StatusOr AlgebraicSimplifierVisitor::TrySimplifyScalarSlice( HloInstruction* slice) { // Only try to do this for effective scalars. We could do the same for slicing // out larger pieces of padding (replacing with a broadcast of the padding @@ -5725,7 +5768,7 @@ StatusOr AlgebraicSimplifierVisitor::TrySimplifyScalarSlice( return false; } -StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape( +absl::StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape( HloInstruction* slice) { CHECK_EQ(slice->opcode(), HloOpcode::kSlice); if (!IsUnstridedSlice(slice)) { @@ -5781,7 +5824,7 @@ StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape( // Allowing a slice to move through a reverse with any necessary updates to the // slice config. -StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReverse( +absl::StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReverse( HloInstruction* slice) { VLOG(2) << "Entered TryToReorderSliceAndReverse for slice:" << slice->ToString(); @@ -5988,22 +6031,18 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { // Here we build up the slice dimensions for lhs DimensionVector lhs_start_indices, lhs_limit_indices, lhs_strides; for (int64_t lhs_index = 0; lhs_index < lhs->shape().rank(); ++lhs_index) { - int64_t start = 0; - int64_t limit = lhs->shape().dimensions(lhs_index); - int64_t stride = 1; - if (map_lhs_dot[lhs_index] != -1) { - // If it is not a contracting dimension, we slice it according to the - // slicing of the corresponding dimension in dot - int64_t dot_index = map_lhs_dot[lhs_index]; - start = slice->slice_starts(dot_index); - limit = slice->slice_limits(dot_index); - stride = slice->slice_strides(dot_index); - } + int64_t size = lhs->shape().dimensions(lhs_index); + // If it is not a contracting dimension, we slice it according to the + // slicing of the corresponding dimension in dot + int64_t i = map_lhs_dot[lhs_index]; + int64_t start = i >= 0 ? slice->slice_starts(i) : 0; + int64_t limit = i >= 0 ? slice->slice_limits(i) : size; + int64_t stride = i >= 0 ? slice->slice_strides(i) : 1; lhs_start_indices.push_back(start); lhs_limit_indices.push_back(limit); lhs_strides.push_back(stride); // Record if any slicing occurs here - if (start != 0 || limit < lhs->shape().dimensions(lhs_index)) { + if (start != 0 || limit < size || stride != 1) { slice_lhs = true; } } @@ -6011,22 +6050,18 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { // Here we do the same for rhs DimensionVector rhs_start_indices, rhs_limit_indices, rhs_strides; for (int64_t rhs_index = 0; rhs_index < rhs->shape().rank(); ++rhs_index) { - int64_t start = 0; - int64_t limit = rhs->shape().dimensions(rhs_index); - int64_t stride = 1; - if (map_rhs_dot[rhs_index] != -1) { - // If it is not a contracting dimension, we slice it according to the - // slicing of the corresponding dimension in dot - int64_t dot_index = map_rhs_dot[rhs_index]; - start = slice->slice_starts(dot_index); - limit = slice->slice_limits(dot_index); - stride = slice->slice_strides(dot_index); - } + int64_t size = rhs->shape().dimensions(rhs_index); + // If it is not a contracting dimension, we slice it according to the + // slicing of the corresponding dimension in dot + int64_t i = map_rhs_dot[rhs_index]; + int64_t start = i >= 0 ? slice->slice_starts(i) : 0; + int64_t limit = i >= 0 ? slice->slice_limits(i) : size; + int64_t stride = i >= 0 ? slice->slice_strides(i) : 1; rhs_start_indices.push_back(start); rhs_limit_indices.push_back(limit); rhs_strides.push_back(stride); // Record if any slicing occurs here - if (start != 0 || limit < rhs->shape().dimensions(rhs_index)) { + if (start != 0 || limit < size || stride != 1) { slice_rhs = true; } } @@ -7212,6 +7247,29 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { } } + // For Computation equal to Min, Max, And or Or, replace Reduce(Broadcast(x), + // a, Computation()) with Computation(x, a) when x is a scalar and the + // broadcast is reduced to a scalar. + if (HloInstruction * broadcast_arg; + Match(arg, m::Broadcast(m::Op(&broadcast_arg))) && + (Match(function->root_instruction(), + m::MaximumAnyOrder(m::Parameter(0), m::Parameter(1))) || + Match(function->root_instruction(), + m::MinimumAnyOrder(m::Parameter(0), m::Parameter(1))) || + Match(function->root_instruction(), + m::AndAnyOrder(m::Parameter(0), m::Parameter(1))) || + Match(function->root_instruction(), + m::OrAnyOrder(m::Parameter(0), m::Parameter(1))))) { + if (broadcast_arg->shape().rank() == 0 && + reduce->dimensions().size() == arg->shape().rank()) { + return ReplaceWithNewInstruction( + reduce, + HloInstruction::CreateBinary( + reduce_result_shape, function->root_instruction()->opcode(), + broadcast_arg, reduce->mutable_operand(1))); + } + } + return OkStatus(); } @@ -7746,7 +7804,7 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { } // Convert transpose(dot(a,b)) to dot(b,a). - auto do_transpose_of_dot = [&]() -> StatusOr { + auto do_transpose_of_dot = [&]() -> absl::StatusOr { if (options_.supports_non_canonical_dots() || operand->opcode() != HloOpcode::kDot || operand->user_count() != 1) { return false; @@ -7803,7 +7861,7 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { if (options_.supports_non_canonical_dots() && Match(operand, m::Dot(&dot, m::Op(&lhs), m::Op(&rhs))) && dot->user_count() == 1) { - TF_ASSIGN_OR_RETURN(bool did_transform, [&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(bool did_transform, [&]() -> absl::StatusOr { const auto& dnums = dot->dot_dimension_numbers(); const int64_t num_batch_dims = dnums.lhs_batch_dimensions_size(); for (int64_t i = 0; i < num_batch_dims; ++i) { @@ -7884,7 +7942,7 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { HloInstruction* reshape_operand = operand->mutable_operand(0); HloInstruction* outer_reshape = transpose->users()[0]; TF_ASSIGN_OR_RETURN( - bool did_transform, ([&]() -> StatusOr { + bool did_transform, ([&]() -> absl::StatusOr { if (operand->shape().dimensions_size() != reshape_operand->shape().dimensions_size() + 1) { return false; @@ -8028,7 +8086,7 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { return OkStatus(); } -StatusOr AlgebraicSimplifierVisitor::FoldConvInputPad( +absl::StatusOr AlgebraicSimplifierVisitor::FoldConvInputPad( HloInstruction* convolution) { HloInstruction *lhs, *a, *b; if (Match(convolution, @@ -8089,7 +8147,7 @@ StatusOr AlgebraicSimplifierVisitor::FoldConvInputPad( return false; } -StatusOr AlgebraicSimplifierVisitor::FoldConvFilterPad( +absl::StatusOr AlgebraicSimplifierVisitor::FoldConvFilterPad( HloInstruction* convolution) { auto* lhs = convolution->mutable_operand(0); auto* rhs = convolution->mutable_operand(1); @@ -8155,7 +8213,7 @@ StatusOr AlgebraicSimplifierVisitor::FoldConvFilterPad( return true; } -StatusOr AlgebraicSimplifierVisitor::SwapConvOperands( +absl::StatusOr AlgebraicSimplifierVisitor::SwapConvOperands( HloInstruction* convolution) { if (!options_.enable_conv_operand_swap() || options_.is_layout_sensitive()) { return false; @@ -8292,7 +8350,7 @@ StatusOr AlgebraicSimplifierVisitor::SwapConvOperands( return true; } -StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( +absl::StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( HloInstruction* convolution) { auto* lhs = convolution->mutable_operand(0); auto* rhs = convolution->mutable_operand(1); @@ -8414,7 +8472,7 @@ StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( return true; } -StatusOr AlgebraicSimplifierVisitor::SimplifyConvToMultiply( +absl::StatusOr AlgebraicSimplifierVisitor::SimplifyConvToMultiply( HloInstruction* convolution) { if (options_.is_layout_sensitive() || absl::c_linear_search(convolution->precision_config().operand_precision(), @@ -8602,7 +8660,7 @@ Status AlgebraicSimplifierVisitor::HandleMap(HloInstruction* map) { return ReplaceWithNewInstruction(map, std::move(clone)); } -StatusOr AlgebraicSimplifier::Run( +absl::StatusOr AlgebraicSimplifier::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/third_party/xla/xla/service/algebraic_simplifier.h b/third_party/xla/xla/service/algebraic_simplifier.h index 893533c3761cc1..8acdeb10218181 100644 --- a/third_party/xla/xla/service/algebraic_simplifier.h +++ b/third_party/xla/xla/service/algebraic_simplifier.h @@ -71,6 +71,11 @@ class AlgebraicSimplifierOptions { return conv_is_lowerable_callback_(reverse_dims); } + void set_conv_is_lowerable_callback( + ConvIsLowerableCallback conv_is_lowerable_callback) { + conv_is_lowerable_callback_ = std::move(conv_is_lowerable_callback); + } + // If is_layout_sensitive is true, then the simplifier preserves layout during // transformation. Otherwise, layout is ignored. void set_is_layout_sensitive(bool is_layout_sensitive) { @@ -282,7 +287,7 @@ class AlgebraicSimplifier : public HloModulePass { // Run algebraic simplification on the given computation. Returns whether the // computation was changed. using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; @@ -313,6 +318,8 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { Status HandleAdd(HloInstruction* add) override; + Status HandleAllToAll(HloInstruction* all_to_all) override; + Status HandleAnd(HloInstruction* logical_and) override; Status HandleBitcast(HloInstruction* bitcast) override; @@ -443,7 +450,7 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { private: // Removes degenerate dimension from dot. - StatusOr RemoveDegenerateDimensionFromDot(HloInstruction* dot); + absl::StatusOr RemoveDegenerateDimensionFromDot(HloInstruction* dot); // Moves the transpose to the broadcast if possible. Can also be called with a // bitcast transpose. @@ -466,7 +473,8 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { // Transposes a dot operand such that the batch dimensions are the most major, // and the contracting dimensions are most minor. - StatusOr NormalizeDotOperandToBatchMajorAndContractingMinor( + absl::StatusOr + NormalizeDotOperandToBatchMajorAndContractingMinor( HloInstruction* dot_operand, absl::Span batch_dimensions, absl::Span contracting_dimensions); @@ -477,7 +485,7 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { // // LHS [batch dims..., non-contracting dim, contracting dim] // RHS [batch dims..., contracting dim, non-contracting dim]. - StatusOr RemoveTransposesFromDotOperands(HloInstruction* dot); + absl::StatusOr RemoveTransposesFromDotOperands(HloInstruction* dot); // Helper method to perform and add reduction on a list of dimensions. HloInstruction* AddReduce(HloInstruction* hlo, absl::Span dims, @@ -521,20 +529,21 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { // A Broadcast that feeds an element-wise operation with a unique non-scalar // operand can sink to after the operation. - StatusOr TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( + absl::StatusOr TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( HloInstruction* broadcast); - StatusOr OptimizeDotOfConcat(HloInstruction* dot); - StatusOr OptimizeDotOfConcatHelper( + absl::StatusOr OptimizeDotOfConcat(HloInstruction* dot); + absl::StatusOr OptimizeDotOfConcatHelper( HloInstruction* dot, HloInstruction* lhs, int64_t lhs_contracting_dim, HloInstruction* rhs, int64_t rhs_contracting_dim, bool swapped); - StatusOr OptimizeDotOfGather(HloInstruction* dot); + absl::StatusOr OptimizeDotOfGather(HloInstruction* dot); - StatusOr OptimizeDotOfReorderContractingDims( + absl::StatusOr OptimizeDotOfReorderContractingDims( HloInstruction* dot); - StatusOr AssociativeReorderDotOperator(HloInstruction* dot); + absl::StatusOr AssociativeReorderDotOperator( + HloInstruction* dot); HloComputation* GetOrCreateScalarAddComputation(PrimitiveType type) { HloComputation*& scalar_add_computation = scalar_add_computations_[type]; @@ -558,37 +567,39 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { // Tries to fold a kPad in the input or filter into the convolution // instruction's window. - virtual StatusOr FoldConvInputPad(HloInstruction* convolution); - StatusOr FoldConvFilterPad(HloInstruction* convolution); + virtual absl::StatusOr FoldConvInputPad(HloInstruction* convolution); + absl::StatusOr FoldConvFilterPad(HloInstruction* convolution); // Tries to swap convolution operands if they would result in a more efficient // convolution. - StatusOr SwapConvOperands(HloInstruction* convolution); + absl::StatusOr SwapConvOperands(HloInstruction* convolution); // Tries to use a kDot in place of the given convolution. - StatusOr SimplifyConvToDot(HloInstruction* convolution); + absl::StatusOr SimplifyConvToDot(HloInstruction* convolution); // Tries to use a multiplication in place of the given convolution. - StatusOr SimplifyConvToMultiply(HloInstruction* convolution); + absl::StatusOr SimplifyConvToMultiply(HloInstruction* convolution); // Tries to simplify a slice where the result of the slice is a scalar. - StatusOr TrySimplifyScalarSlice(HloInstruction* slice); + absl::StatusOr TrySimplifyScalarSlice(HloInstruction* slice); // Tries to convert slice(reshape(X)) into reshape(slice(X)) - StatusOr TryToReorderSliceAndReshape(HloInstruction* slice); + absl::StatusOr TryToReorderSliceAndReshape(HloInstruction* slice); // Tries to convert slice(reverse(X)) into reverse(slice(X)) - StatusOr TryToReorderSliceAndReverse(HloInstruction* slice); + absl::StatusOr TryToReorderSliceAndReverse(HloInstruction* slice); // Tries to simplify `(and (< a N) (< a K))` in cases where `N <= K` into // `(< a N)`. This is crucial for being able to figure out the loop trip // count. // // Assumes that the input is conjunction. - StatusOr TrySimplifyTautologicalCompare(HloInstruction* conjunction); + absl::StatusOr TrySimplifyTautologicalCompare( + HloInstruction* conjunction); // Tries to simlplify (bitcast-convert (concat (bitcast-convert A) ...)) where // the types of inner and outer bitcast-convert cancel out. - StatusOr TrySimplifyTautologicalBitcastConvert(HloInstruction* bitcast); + absl::StatusOr TrySimplifyTautologicalBitcastConvert( + HloInstruction* bitcast); // Tries to remove surrounding converts around a binary op where the op has a // more precise type than its inputs and output. diff --git a/third_party/xla/xla/service/algebraic_simplifier_test.cc b/third_party/xla/xla/service/algebraic_simplifier_test.cc index 096b1fbb79ea0d..8df8939babf899 100644 --- a/third_party/xla/xla/service/algebraic_simplifier_test.cc +++ b/third_party/xla/xla/service/algebraic_simplifier_test.cc @@ -952,6 +952,54 @@ TEST_F(AlgebraicSimplifierTest, ReduceOfNegate) { GmockMatch(m::Negate(m::Reduce(m::Parameter(0), m::ConstantScalar(0))))); } +TEST_F(AlgebraicSimplifierTest, ReduceBroadcastOfScalar) { + // Test Reduce(Broadcast(x), a, Max) + const char* kModuleStrForMax = R"( + HloModule m + max_f32 { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT r = f32[] maximum(p0, p1) + } + + ENTRY test { + p = f32[] parameter(0) + b = f32[1000,1000] broadcast(p), dimensions={} + ROOT reduce = f32[] reduce(b, f32[] constant(0)), dimensions={0,1}, to_apply=max_f32 + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, + ParseAndReturnVerifiedModule(kModuleStrForMax)); + AlgebraicSimplifierOptions options = default_options_; + ASSERT_TRUE(AlgebraicSimplifier(options).Run(m.get()).value()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::MaximumAnyOrder(m::Parameter(0), m::ConstantScalar(0)))); + + // Test Reduce(Broadcast(x), a, And) + const char* kModuleStrForAnd = R"( + HloModule m + and_u4 { + p0 = u4[] parameter(0) + p1 = u4[] parameter(1) + ROOT r = u4[] and(p0, p1) + } + + ENTRY test { + p = u4[] parameter(0) + b = u4[1000,1000] broadcast(p), dimensions={} + ROOT reduce = u4[] reduce(b, u4[] constant(0)), dimensions={0,1}, to_apply=and_u4 + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(kModuleStrForAnd)); + ASSERT_TRUE(AlgebraicSimplifier(options).Run(m.get()).value()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::AndAnyOrder(m::Parameter(0), m::ConstantScalar(0)))); +} + // Test that Const + A is canonicalized to A + Const. TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) { auto m = CreateNewVerifiedModule(); @@ -6275,6 +6323,30 @@ TEST_F(AlgebraicSimplifierTest, SliceDotReorder) { GmockMatch(m::Dot(m::Slice(m::Parameter(0)), m::Parameter(1)))); } +TEST_F(AlgebraicSimplifierTest, SliceDotReorderWithStrides) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + a = f32[2048,2] parameter(0) + b = f32[2,2048] parameter(1) + dot = f32[2048,2048] dot(a,b), + lhs_contracting_dims={1}, + rhs_contracting_dims={0} + ROOT slice = f32[16,256] slice(dot), slice={[0:128:8],[0:2048:8]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + options.set_use_associative_reordering(true); + EXPECT_TRUE(AlgebraicSimplifier(options).Run(module.get()).value()); + ASSERT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Dot(m::Slice(m::Parameter(0)), m::Slice(m::Parameter(1))))); +} + TEST_F(AlgebraicSimplifierTest, TransposeOfBatchDot) { const char* hlo_string = R"( HloModule module @@ -8096,6 +8168,78 @@ TEST_F(AlgebraicSimplifierTest, ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); } +TEST_F(AlgebraicSimplifierTest, CompareGtMaxA) { + // Gt(Max(a,b), a) -> Gt(b,a) + const char* kModuleStr = R"( + HloModule m + test { + a = f32[4] parameter(0) + b = f32[4] parameter(1) + m0 = f32[4] maximum(a, b) + ROOT compare = pred[4] compare(m0, a), direction=GT + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Compare(m::Parameter(1), m::Parameter(0)) + .WithComparisonDirection(ComparisonDirection::kGt))); +} + +TEST_F(AlgebraicSimplifierTest, CompareGtMaxB) { + // Gt(Max(a,b), b) -> Gt(a,b) + const char* kModuleStr = R"( + HloModule m + test { + a = f32[4] parameter(0) + b = f32[4] parameter(1) + m0 = f32[4] maximum(a, b) + ROOT compare = pred[4] compare(m0, b), direction=GT + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Compare(m::Parameter(0), m::Parameter(1)) + .WithComparisonDirection(ComparisonDirection::kGt))); +} + +TEST_F(AlgebraicSimplifierTest, CompareGtAMin) { + // Gt(a, Min(a,b)) -> Gt(a,b) + const char* kModuleStr = R"( + HloModule m + test { + a = f32[4] parameter(0) + b = f32[4] parameter(1) + m0 = f32[4] minimum(a, b) + ROOT compare = pred[4] compare(a, m0), direction=GT + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Compare(m::Parameter(0), m::Parameter(1)) + .WithComparisonDirection(ComparisonDirection::kGt))); +} + +TEST_F(AlgebraicSimplifierTest, CompareGtBMin) { + // Gt(b, Min(a,b)) -> Gt(b,a) + const char* kModuleStr = R"( + HloModule m + test { + a = f32[4] parameter(0) + b = f32[4] parameter(1) + m0 = f32[4] minimum(a, b) + ROOT compare = pred[4] compare(b, m0), direction=GT + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Compare(m::Parameter(1), m::Parameter(0)) + .WithComparisonDirection(ComparisonDirection::kGt))); +} + TEST_F(AlgebraicSimplifierTest, CompareIota) { const char* kModuleStr = R"( HloModule m diff --git a/third_party/xla/xla/service/all_gather_broadcast_reorder.cc b/third_party/xla/xla/service/all_gather_broadcast_reorder.cc index 97ba5a9ba3f84c..31a72c2d827506 100644 --- a/third_party/xla/xla/service/all_gather_broadcast_reorder.cc +++ b/third_party/xla/xla/service/all_gather_broadcast_reorder.cc @@ -27,7 +27,7 @@ limitations under the License. namespace xla { -StatusOr AllGatherBroadcastReorder::Run( +absl::StatusOr AllGatherBroadcastReorder::Run( HloModule *module, const absl::flat_hash_set &execution_threads) { if (hlo_query::ContainsLayoutConstrainedCollective(*module, diff --git a/third_party/xla/xla/service/all_gather_broadcast_reorder.h b/third_party/xla/xla/service/all_gather_broadcast_reorder.h index 018d48a04009a0..5746f2424fd95e 100644 --- a/third_party/xla/xla/service/all_gather_broadcast_reorder.h +++ b/third_party/xla/xla/service/all_gather_broadcast_reorder.h @@ -31,7 +31,7 @@ class AllGatherBroadcastReorder : public HloModulePass { absl::string_view name() const override { return "all-gather-bcast-reorder"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/third_party/xla/xla/service/all_gather_combiner.cc b/third_party/xla/xla/service/all_gather_combiner.cc index 29b37195e500be..ecb9fb42474d61 100644 --- a/third_party/xla/xla/service/all_gather_combiner.cc +++ b/third_party/xla/xla/service/all_gather_combiner.cc @@ -198,7 +198,7 @@ AllGatherCombiner::AllGatherCombiner(int64_t combine_threshold_in_bytes, combine_threshold_count_(combine_threshold_count), combine_by_dim_(combine_by_dim) {} -StatusOr AllGatherCombiner::Run( +absl::StatusOr AllGatherCombiner::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { VLOG(1) << "Running AllGatherCombiner with threshold of " diff --git a/third_party/xla/xla/service/all_gather_combiner.h b/third_party/xla/xla/service/all_gather_combiner.h index e09597e422e831..8e7a0e062799cf 100644 --- a/third_party/xla/xla/service/all_gather_combiner.h +++ b/third_party/xla/xla/service/all_gather_combiner.h @@ -36,7 +36,7 @@ class AllGatherCombiner : public HloModulePass { absl::string_view name() const override { return "all-gather-combiner"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/third_party/xla/xla/service/all_gather_decomposer.cc b/third_party/xla/xla/service/all_gather_decomposer.cc index 0b6b389ba99e35..24db8fd3599476 100644 --- a/third_party/xla/xla/service/all_gather_decomposer.cc +++ b/third_party/xla/xla/service/all_gather_decomposer.cc @@ -101,7 +101,7 @@ Status DecomposeAllGather(HloAllGatherInstruction* ag, HloComputation* comp) { return OkStatus(); } -StatusOr AllGatherDecomposer::Run( +absl::StatusOr AllGatherDecomposer::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/third_party/xla/xla/service/all_gather_decomposer.h b/third_party/xla/xla/service/all_gather_decomposer.h index fb7eb2ed5b75f9..da56d0c4023037 100644 --- a/third_party/xla/xla/service/all_gather_decomposer.h +++ b/third_party/xla/xla/service/all_gather_decomposer.h @@ -37,7 +37,7 @@ class AllGatherDecomposer : public HloModulePass { // Run AllGatherDecomposer pass on computations in 'module'. // Returns whether the 'module' was changed. using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/third_party/xla/xla/service/all_reduce_combiner.cc b/third_party/xla/xla/service/all_reduce_combiner.cc index e2b798b6d4d346..5d7b9b6ee4c5bb 100644 --- a/third_party/xla/xla/service/all_reduce_combiner.cc +++ b/third_party/xla/xla/service/all_reduce_combiner.cc @@ -108,7 +108,7 @@ AllReduceCombiner::AllReduceCombiner(int64_t combine_threshold_in_bytes, : combine_threshold_in_bytes_(combine_threshold_in_bytes), combine_threshold_count_(combine_threshold_count) {} -StatusOr AllReduceCombiner::Run( +absl::StatusOr AllReduceCombiner::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { VLOG(1) << "Running AllReduceCombiner with threshold of " diff --git a/third_party/xla/xla/service/all_reduce_combiner.h b/third_party/xla/xla/service/all_reduce_combiner.h index 279080f95d1596..4ef9e961258256 100644 --- a/third_party/xla/xla/service/all_reduce_combiner.h +++ b/third_party/xla/xla/service/all_reduce_combiner.h @@ -37,7 +37,7 @@ class AllReduceCombiner : public HloModulePass { absl::string_view name() const override { return "all-reduce-combiner"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/third_party/xla/xla/service/all_reduce_contiguous.cc b/third_party/xla/xla/service/all_reduce_contiguous.cc index f0805eb83e0c2a..7f07b2fa756df6 100644 --- a/third_party/xla/xla/service/all_reduce_contiguous.cc +++ b/third_party/xla/xla/service/all_reduce_contiguous.cc @@ -83,7 +83,7 @@ Status ReplaceWithContiguousAllReduce(HloAllReduceInstruction* all_reduce) { } } // namespace -StatusOr AllReduceContiguous::Run( +absl::StatusOr AllReduceContiguous::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { VLOG(1) << "Running AllReduceContiguous"; diff --git a/third_party/xla/xla/service/all_reduce_contiguous.h b/third_party/xla/xla/service/all_reduce_contiguous.h index f87aeb5cfd2e9b..d81582536fba40 100644 --- a/third_party/xla/xla/service/all_reduce_contiguous.h +++ b/third_party/xla/xla/service/all_reduce_contiguous.h @@ -29,7 +29,7 @@ class AllReduceContiguous : public HloModulePass { absl::string_view name() const override { return "all-reduce-contiguous"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/third_party/xla/xla/service/all_reduce_folder.cc b/third_party/xla/xla/service/all_reduce_folder.cc index 61b6fbb8039172..9d034dd45ef606 100644 --- a/third_party/xla/xla/service/all_reduce_folder.cc +++ b/third_party/xla/xla/service/all_reduce_folder.cc @@ -136,7 +136,7 @@ std::optional> FoldReplicaGroups( } // namespace -StatusOr AllReduceFolder::Run( +absl::StatusOr AllReduceFolder::Run( HloModule *module, const absl::flat_hash_set &execution_threads) { if (hlo_query::ContainsLayoutConstrainedAllReduce(*module)) { diff --git a/third_party/xla/xla/service/all_reduce_folder.h b/third_party/xla/xla/service/all_reduce_folder.h index 4a8f4cc677ed5a..e175a65677163b 100644 --- a/third_party/xla/xla/service/all_reduce_folder.h +++ b/third_party/xla/xla/service/all_reduce_folder.h @@ -38,7 +38,7 @@ class AllReduceFolder : public HloModulePass { absl::string_view name() const override { return "all-reduce-folder"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/third_party/xla/xla/service/all_reduce_folder_test.cc b/third_party/xla/xla/service/all_reduce_folder_test.cc index 990fc15174a825..57d2c7518838d2 100644 --- a/third_party/xla/xla/service/all_reduce_folder_test.cc +++ b/third_party/xla/xla/service/all_reduce_folder_test.cc @@ -30,15 +30,15 @@ using ::testing::HasSubstr; class AllReduceFolderTest : public HloTestBase { public: - StatusOr> RunPass(absl::string_view hlo_module, - bool expect_change) { + absl::StatusOr> RunPass( + absl::string_view hlo_module, bool expect_change) { TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_module)); auto changed = AllReduceFolder().Run(module.get()); if (!changed.ok()) { return changed.status(); } EXPECT_EQ(changed.value(), expect_change); - return StatusOr>(std::move(module)); + return absl::StatusOr>(std::move(module)); } size_t AllReduceCount(std::unique_ptr &module) { diff --git a/third_party/xla/xla/service/all_reduce_promotion.cc b/third_party/xla/xla/service/all_reduce_promotion.cc index 39290a8fc9e0ba..b0328759c7d310 100644 --- a/third_party/xla/xla/service/all_reduce_promotion.cc +++ b/third_party/xla/xla/service/all_reduce_promotion.cc @@ -61,7 +61,7 @@ AllReducePromotion::AllReducePromotion( absl::Span const> from_to_types) : pass_(from_to_types, IsAllReduce, CloneAllReduce) {} -StatusOr AllReducePromotion::Run( +absl::StatusOr AllReducePromotion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { return pass_.Run(module, execution_threads); diff --git a/third_party/xla/xla/service/all_reduce_promotion.h b/third_party/xla/xla/service/all_reduce_promotion.h index f2a7619156d9ad..a1ad33033187f1 100644 --- a/third_party/xla/xla/service/all_reduce_promotion.h +++ b/third_party/xla/xla/service/all_reduce_promotion.h @@ -28,7 +28,7 @@ class AllReducePromotion : public HloModulePass { absl::string_view name() const override { return "all-reduce-promotion"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/third_party/xla/xla/service/all_reduce_reassociate.cc b/third_party/xla/xla/service/all_reduce_reassociate.cc index 97918099a25be8..84d83b0b736c60 100644 --- a/third_party/xla/xla/service/all_reduce_reassociate.cc +++ b/third_party/xla/xla/service/all_reduce_reassociate.cc @@ -171,7 +171,7 @@ bool MatchOperandsToAllReduceWithOptionalConvert(HloInstruction* inst, } } // namespace -StatusOr AllReduceReassociate::Run( +absl::StatusOr AllReduceReassociate::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { if (hlo_query::ContainsLayoutConstrainedAllReduce(*module)) { diff --git a/third_party/xla/xla/service/all_reduce_reassociate.h b/third_party/xla/xla/service/all_reduce_reassociate.h index 5967e267cf056c..228d2f5cd15b5b 100644 --- a/third_party/xla/xla/service/all_reduce_reassociate.h +++ b/third_party/xla/xla/service/all_reduce_reassociate.h @@ -38,7 +38,7 @@ class AllReduceReassociate : public HloModulePass { absl::string_view name() const override { return "all-reduce-reassociate"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/third_party/xla/xla/service/all_reduce_reassociate_test.cc b/third_party/xla/xla/service/all_reduce_reassociate_test.cc index 5ee940e7a4ffc8..aa1f13eaf04a79 100644 --- a/third_party/xla/xla/service/all_reduce_reassociate_test.cc +++ b/third_party/xla/xla/service/all_reduce_reassociate_test.cc @@ -36,7 +36,7 @@ using ::testing::_; class AllReduceSimplifierTest : public HloTestBase { public: - StatusOr> RunPass( + absl::StatusOr> RunPass( absl::string_view hlo_module, bool expect_change, bool reassociate_converted_ar = false) { TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_module)); @@ -46,7 +46,7 @@ class AllReduceSimplifierTest : public HloTestBase { return changed.status(); } EXPECT_EQ(changed.value(), expect_change); - return StatusOr>(std::move(module)); + return absl::StatusOr>(std::move(module)); } size_t AllReduceCount(std::unique_ptr& module) { diff --git a/third_party/xla/xla/service/all_reduce_simplifier.cc b/third_party/xla/xla/service/all_reduce_simplifier.cc index 9289fd394896d9..67aadf41b9e988 100644 --- a/third_party/xla/xla/service/all_reduce_simplifier.cc +++ b/third_party/xla/xla/service/all_reduce_simplifier.cc @@ -27,7 +27,7 @@ limitations under the License. namespace xla { -StatusOr AllReduceSimplifier::Run( +absl::StatusOr AllReduceSimplifier::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { TF_ASSIGN_OR_RETURN( diff --git a/third_party/xla/xla/service/all_reduce_simplifier.h b/third_party/xla/xla/service/all_reduce_simplifier.h index 41de78c01aa746..72bc60923dc3ba 100644 --- a/third_party/xla/xla/service/all_reduce_simplifier.h +++ b/third_party/xla/xla/service/all_reduce_simplifier.h @@ -36,7 +36,7 @@ class AllReduceSimplifier : public HloModulePass { // Run all-reduce simplification on the given computation. Returns whether the // computation was changed. using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/third_party/xla/xla/service/all_to_all_decomposer.cc b/third_party/xla/xla/service/all_to_all_decomposer.cc index 2b82726c56fc70..241b242f693fc9 100644 --- a/third_party/xla/xla/service/all_to_all_decomposer.cc +++ b/third_party/xla/xla/service/all_to_all_decomposer.cc @@ -45,7 +45,7 @@ bool AllToAllDecomposer::InstructionMatchesPattern( } return all_to_all->shape().rank() < min_array_rank_; } -StatusOr AllToAllDecomposer::ExpandInstruction( +absl::StatusOr AllToAllDecomposer::ExpandInstruction( HloInstruction* instruction) { auto* all_to_all = Cast(instruction); int64_t split_dim = *all_to_all->split_dimension(); diff --git a/third_party/xla/xla/service/all_to_all_decomposer.h b/third_party/xla/xla/service/all_to_all_decomposer.h index d6cc09f93c93d5..3ef1891a412665 100644 --- a/third_party/xla/xla/service/all_to_all_decomposer.h +++ b/third_party/xla/xla/service/all_to_all_decomposer.h @@ -35,7 +35,7 @@ class AllToAllDecomposer : public OpExpanderPass { private: bool InstructionMatchesPattern(HloInstruction* instruction) override; - StatusOr ExpandInstruction( + absl::StatusOr ExpandInstruction( HloInstruction* instruction) override; bool decompose_to_tuple_; int64_t min_array_rank_; diff --git a/third_party/xla/xla/service/allocation_tracker.cc b/third_party/xla/xla/service/allocation_tracker.cc index 2849c241889019..a2cbad64b75964 100644 --- a/third_party/xla/xla/service/allocation_tracker.cc +++ b/third_party/xla/xla/service/allocation_tracker.cc @@ -31,7 +31,7 @@ limitations under the License. namespace xla { -StatusOr AllocationTracker::Register( +absl::StatusOr AllocationTracker::Register( ScopedShapedBuffer shaped_buffer, const std::string& tag) { absl::MutexLock lock(&mutex_); VLOG(2) << "Register"; @@ -40,7 +40,7 @@ StatusOr AllocationTracker::Register( return RegisterInternal(std::move(replicated_buffers), tag); } -StatusOr AllocationTracker::RegisterReplicatedBuffers( +absl::StatusOr AllocationTracker::RegisterReplicatedBuffers( std::vector replicated_buffers, const std::string& tag) { absl::MutexLock lock(&mutex_); @@ -57,7 +57,7 @@ static ShapedBuffer ReleaseIfScopedShapedBuffer(ScopedShapedBuffer b) { } template -StatusOr AllocationTracker::RegisterInternal( +absl::StatusOr AllocationTracker::RegisterInternal( std::vector replicated_buffers, const std::string& tag) { static_assert(std::is_same::value || std::is_same::value, @@ -126,8 +126,8 @@ Status AllocationTracker::Unregister(const GlobalDataHandle& data) { return OkStatus(); } -StatusOr> AllocationTracker::DeconstructTuple( - const GlobalDataHandle& data) { +absl::StatusOr> +AllocationTracker::DeconstructTuple(const GlobalDataHandle& data) { absl::MutexLock lock(&mutex_); TF_ASSIGN_OR_RETURN(std::vector replicated_buffers, @@ -164,13 +164,13 @@ StatusOr> AllocationTracker::DeconstructTuple( return std::move(element_handles); } -StatusOr> AllocationTracker::Resolve( +absl::StatusOr> AllocationTracker::Resolve( const GlobalDataHandle& data) const { absl::MutexLock lock(&mutex_); return AllocationTracker::ResolveInternal(data); } -StatusOr AllocationTracker::ResolveForReplica( +absl::StatusOr AllocationTracker::ResolveForReplica( const GlobalDataHandle& data, int replica_id) const { absl::MutexLock lock(&mutex_); TF_ASSIGN_OR_RETURN(std::vector replicated_buffers, @@ -184,8 +184,8 @@ StatusOr AllocationTracker::ResolveForReplica( return replicated_buffers[replica_id]; } -StatusOr> AllocationTracker::ResolveInternal( - const GlobalDataHandle& data) const { +absl::StatusOr> +AllocationTracker::ResolveInternal(const GlobalDataHandle& data) const { VLOG(2) << "resolve:" << data.handle(); auto it = handle_to_shaped_buffers_.find(data.handle()); if (it == handle_to_shaped_buffers_.end()) { diff --git a/third_party/xla/xla/service/allocation_tracker.h b/third_party/xla/xla/service/allocation_tracker.h index 7e3548b977f73f..c8359c0fc30971 100644 --- a/third_party/xla/xla/service/allocation_tracker.h +++ b/third_party/xla/xla/service/allocation_tracker.h @@ -42,12 +42,12 @@ class AllocationTracker { // Registers a shaped buffer of device memory, and returns a corresponding // handle that can be used for talking to XLA clients. The given shaped buffer // will be treated as the buffer corresponding to the only replica. - StatusOr Register(ScopedShapedBuffer shaped_buffer, - const std::string& tag); + absl::StatusOr Register(ScopedShapedBuffer shaped_buffer, + const std::string& tag); // Registers a vector of shaped buffers of device memory, one per replica, and // returns a corresponding handle that can be used for talking to XLA clients. - StatusOr RegisterReplicatedBuffers( + absl::StatusOr RegisterReplicatedBuffers( std::vector replicated_buffers, const std::string& tag); @@ -55,20 +55,20 @@ class AllocationTracker { Status Unregister(const GlobalDataHandle& data); // Returns a vector of global data handles that point to the tuple elements. - StatusOr> DeconstructTuple( + absl::StatusOr> DeconstructTuple( const GlobalDataHandle& Data); // Resolve a handle from an XLA client to a vector of shaped buffers, one per // replica, or provide an error status to say whether any of those buffers // were not found (or found, but found deallocated). - StatusOr> Resolve( + absl::StatusOr> Resolve( const GlobalDataHandle& data) const; // Resolves a handle from an XLA client and replica id to a shaped buffer, or // provide an error status to say whether it was not found (or found, but // found deallocated). - StatusOr ResolveForReplica(const GlobalDataHandle& data, - int replica_id) const; + absl::StatusOr ResolveForReplica( + const GlobalDataHandle& data, int replica_id) const; private: // Data structure encapsulating single memory allocation on the device. @@ -83,7 +83,7 @@ class AllocationTracker { // Internal helper which resolves the given GlobalDataHandle to a // list of ScopedShapedBuffers. - StatusOr> ResolveInternal( + absl::StatusOr> ResolveInternal( const GlobalDataHandle& data) const ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Internal helper which registers a vector of shaped buffers, one per @@ -91,7 +91,7 @@ class AllocationTracker { // it's ShapedBuffer, all of the given buffers must already be tracked by this // object -- presumably this is a call from DeconstructTuple. template - StatusOr RegisterInternal( + absl::StatusOr RegisterInternal( std::vector replicated_buffers, const std::string& tag) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); diff --git a/third_party/xla/xla/service/ar_crs_combiner.cc b/third_party/xla/xla/service/ar_crs_combiner.cc index faa6969b5b851c..1ad71ca8d7e56c 100644 --- a/third_party/xla/xla/service/ar_crs_combiner.cc +++ b/third_party/xla/xla/service/ar_crs_combiner.cc @@ -42,8 +42,8 @@ namespace { // divide by the number of partitions. Depending on the topology and the // implementation of the all-reduce for the backend, this may give a better // performance. -StatusOr ReplaceReplicatedAllReduce(HloModule* module, - int64_t partition_count) { +absl::StatusOr ReplaceReplicatedAllReduce(HloModule* module, + int64_t partition_count) { TF_ASSIGN_OR_RETURN( auto replication_analysis, HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/true)); @@ -534,7 +534,7 @@ Status ArCrsCombiner::KeepProvablyEqualInstructionGroupsSPMD( return OkStatus(); } -StatusOr ArCrsCombiner::RewriteGraph() { +absl::StatusOr ArCrsCombiner::RewriteGraph() { if (all_reduce_map_.empty()) { return false; } @@ -600,7 +600,7 @@ StatusOr ArCrsCombiner::RewriteGraph() { return true; } -StatusOr ArCrsCombiner::Run( +absl::StatusOr ArCrsCombiner::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { call_graph_ = CallGraph::Build(module); diff --git a/third_party/xla/xla/service/ar_crs_combiner.h b/third_party/xla/xla/service/ar_crs_combiner.h index 53bb9b85eec574..7b537b7dd87429 100644 --- a/third_party/xla/xla/service/ar_crs_combiner.h +++ b/third_party/xla/xla/service/ar_crs_combiner.h @@ -77,7 +77,7 @@ class ArCrsCombiner : public HloModulePass { spmd_partition_(spmd_partition) {} absl::string_view name() const override { return "ar-crs-combiner"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; @@ -161,7 +161,7 @@ class ArCrsCombiner : public HloModulePass { // Performs the graph rewrite that eliminates the early AllReduce and turns // the later CRS into an AllReduce. - StatusOr RewriteGraph(); + absl::StatusOr RewriteGraph(); int num_spatial_partitions_; diff --git a/third_party/xla/xla/service/async_collective_creator.cc b/third_party/xla/xla/service/async_collective_creator.cc index 067d33c202f550..1bdfdd995a1f94 100644 --- a/third_party/xla/xla/service/async_collective_creator.cc +++ b/third_party/xla/xla/service/async_collective_creator.cc @@ -24,8 +24,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/service/shape_inference.h" +#include "xla/util.h" #include "tsl/platform/errors.h" namespace xla { @@ -36,7 +38,8 @@ struct ReplacedAsync { HloInstruction* done; }; -StatusOr CreateAsyncAllReduce(HloInstruction* instruction) { +absl::StatusOr CreateAsyncAllReduce( + HloInstruction* instruction) { HloComputation* computation = instruction->parent(); auto* ar = Cast(instruction); HloInstruction* start = @@ -50,7 +53,8 @@ StatusOr CreateAsyncAllReduce(HloInstruction* instruction) { return ReplacedAsync{start, done}; } -StatusOr CreateAsyncAllGather(HloInstruction* instruction) { +absl::StatusOr CreateAsyncAllGather( + HloInstruction* instruction) { HloComputation* computation = instruction->parent(); auto* ag = Cast(instruction); std::vector operand_shapes; @@ -74,7 +78,7 @@ StatusOr CreateAsyncAllGather(HloInstruction* instruction) { return ReplacedAsync{start, done}; } -StatusOr CreateAsyncCollectivePermute( +absl::StatusOr CreateAsyncCollectivePermute( HloInstruction* instruction, absl::Span context_shapes) { HloComputation* computation = instruction->parent(); auto* cp = Cast(instruction); @@ -111,7 +115,7 @@ StatusOr CreateAsyncCollectivePermute( return ReplacedAsync{start, done}; } -StatusOr CreateAsyncStartDone( +absl::StatusOr CreateAsyncStartDone( HloInstruction* instruction, absl::Span context_shapes) { HloComputation* computation = instruction->parent(); TF_ASSIGN_OR_RETURN( @@ -125,93 +129,110 @@ StatusOr CreateAsyncStartDone( } // namespace -StatusOr AsyncCollectiveCreator::Run( +// Find all supported collective ops first as we can't modify the instructions +// while iterating through them. +std::vector AsyncCollectiveCreator::MatchCollectives( + HloComputation* computation) { + std::vector supported_collectives; + for (HloInstruction* instruction : computation->instructions()) { + const HloOpcode op = instruction->opcode(); + if ((op == HloOpcode::kAllReduce && + config_.convert_all_reduce(instruction)) || + (op == HloOpcode::kAllGather && + config_.convert_all_gather(instruction)) || + (op == HloOpcode::kCollectivePermute && + config_.convert_collective_permute(instruction)) || + (op == HloOpcode::kAllToAll && + config_.convert_all_to_all(instruction)) || + (op == HloOpcode::kReduceScatter && + config_.convert_reduce_scatter(instruction))) { + supported_collectives.push_back(instruction); + } + } + return supported_collectives; +} + +absl::StatusOr AsyncCollectiveCreator::ReplaceCollectives( + HloComputation* computation, + std::vector& supported_collectives) { + bool changed = false; + HloModule* module = computation->parent(); + absl::flat_hash_map replaced_pairs; + const bool should_update_schedule = + module->has_schedule() && + module->schedule().is_computation_scheduled(computation); + for (HloInstruction* instruction : supported_collectives) { + absl::StatusOr async_pair; + switch (instruction->opcode()) { + case HloOpcode::kAllReduce: + async_pair = CreateAsyncAllReduce(instruction); + break; + case HloOpcode::kAllGather: + async_pair = CreateAsyncAllGather(instruction); + break; + case HloOpcode::kCollectivePermute: + async_pair = CreateAsyncCollectivePermute( + instruction, config_.get_context_shapes(instruction)); + break; + case HloOpcode::kAllToAll: + case HloOpcode::kReduceScatter: + async_pair = CreateAsyncStartDone( + instruction, config_.get_context_shapes(instruction)); + break; + default: + return Internal("Unexpected opcode %s", + HloOpcodeString(instruction->opcode())); + } + TF_RETURN_IF_ERROR(async_pair.status()); + async_pair->start->set_metadata(instruction->metadata()); + async_pair->start->CopyBackendConfigFrom(instruction); + if (should_update_schedule) { + replaced_pairs[instruction] = *async_pair; + } + + // Update control dependencies if present. + TF_RETURN_IF_ERROR( + instruction->CopyAllControlDepsTo(async_pair->start, async_pair->done)); + TF_RETURN_IF_ERROR(instruction->DropAllControlDeps()); + + TF_RETURN_WITH_CONTEXT_IF_ERROR( + computation->ReplaceInstruction(instruction, async_pair->done), + "replacing ", instruction->ToShortString()); + changed = true; + } + if (should_update_schedule) { + std::vector new_sequence; + const HloInstructionSequence& sequence = + module->schedule().sequence(computation); + new_sequence.reserve(sequence.size() + replaced_pairs.size()); + for (HloInstruction* instr : sequence.instructions()) { + auto it = replaced_pairs.find(instr); + if (it != replaced_pairs.end()) { + new_sequence.push_back(it->second.start); + new_sequence.push_back(it->second.done); + continue; + } + new_sequence.push_back(instr); + } + module->schedule().set_sequence(computation, new_sequence); + } + return changed; +} + +absl::StatusOr AsyncCollectiveCreator::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { - // Find all supported collective ops first as we can't modify the - // instructions while iterating through them. - std::vector supported_collectives; - for (HloInstruction* instruction : computation->instructions()) { - const HloOpcode op = instruction->opcode(); - if ((op == HloOpcode::kAllReduce && - config_.convert_all_reduce(instruction)) || - (op == HloOpcode::kAllGather && - config_.convert_all_gather(instruction)) || - (op == HloOpcode::kCollectivePermute && - config_.convert_collective_permute(instruction)) || - (op == HloOpcode::kAllToAll && - config_.convert_all_to_all(instruction)) || - (op == HloOpcode::kReduceScatter && - config_.convert_reduce_scatter(instruction))) { - supported_collectives.push_back(instruction); - } - } + std::vector supported_collectives = + MatchCollectives(computation); if (supported_collectives.empty()) { continue; } - - absl::flat_hash_map replaced_pairs; - const bool should_update_schedule = - module->has_schedule() && - module->schedule().is_computation_scheduled(computation); - for (HloInstruction* instruction : supported_collectives) { - StatusOr async_pair; - switch (instruction->opcode()) { - case HloOpcode::kAllReduce: - async_pair = CreateAsyncAllReduce(instruction); - break; - case HloOpcode::kAllGather: - async_pair = CreateAsyncAllGather(instruction); - break; - case HloOpcode::kCollectivePermute: - async_pair = CreateAsyncCollectivePermute( - instruction, config_.get_context_shapes(instruction)); - break; - case HloOpcode::kAllToAll: - case HloOpcode::kReduceScatter: - async_pair = CreateAsyncStartDone( - instruction, config_.get_context_shapes(instruction)); - break; - default: - return Internal("Unexpected opcode %s", - HloOpcodeString(instruction->opcode())); - } - TF_RETURN_IF_ERROR(async_pair.status()); - async_pair->start->set_metadata(instruction->metadata()); - async_pair->start->CopyBackendConfigFrom(instruction); - if (should_update_schedule) { - replaced_pairs[instruction] = *async_pair; - } - - // Update control dependencies if present. - TF_RETURN_IF_ERROR(instruction->CopyAllControlDepsTo(async_pair->start, - async_pair->done)); - TF_RETURN_IF_ERROR(instruction->DropAllControlDeps()); - - TF_RETURN_WITH_CONTEXT_IF_ERROR( - computation->ReplaceInstruction(instruction, async_pair->done), - "replacing ", instruction->ToShortString()); - changed = true; - } - if (should_update_schedule) { - std::vector new_sequence; - const HloInstructionSequence& sequence = - module->schedule().sequence(computation); - new_sequence.reserve(sequence.size() + replaced_pairs.size()); - for (HloInstruction* instr : sequence.instructions()) { - auto it = replaced_pairs.find(instr); - if (it != replaced_pairs.end()) { - new_sequence.push_back(it->second.start); - new_sequence.push_back(it->second.done); - continue; - } - new_sequence.push_back(instr); - } - module->schedule().set_sequence(computation, new_sequence); - } + TF_ASSIGN_OR_RETURN(bool comp_changed, + ReplaceCollectives(computation, supported_collectives)); + changed |= comp_changed; } return changed; } diff --git a/third_party/xla/xla/service/async_collective_creator.h b/third_party/xla/xla/service/async_collective_creator.h index 43c8d1af8b0e93..72344181757a6c 100644 --- a/third_party/xla/xla/service/async_collective_creator.h +++ b/third_party/xla/xla/service/async_collective_creator.h @@ -47,9 +47,15 @@ class AsyncCollectiveCreator : public HloModulePass { absl::string_view name() const override { return "async-collective-creator"; } using HloPassInterface::Run; - StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; + absl::StatusOr Run( + HloModule *module, + const absl::flat_hash_set &execution_threads) override; + + std::vector MatchCollectives(HloComputation *computation); + absl::StatusOr ReplaceCollectives( + HloComputation *computation, + std::vector &supported_collectives); + const CollectiveCreatorConfig *config() const { return &config_; } private: CollectiveCreatorConfig config_; diff --git a/third_party/xla/xla/service/backend.cc b/third_party/xla/xla/service/backend.cc index 106533fdd3a559..459b0d37c0d521 100644 --- a/third_party/xla/xla/service/backend.cc +++ b/third_party/xla/xla/service/backend.cc @@ -78,7 +78,7 @@ struct Backend::IntraOpThreadPool { std::unique_ptr device; }; -/* static */ StatusOr> Backend::CreateBackend( +/* static */ absl::StatusOr> Backend::CreateBackend( const BackendOptions& options) { se::Platform* platform = options.platform(); TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform)); @@ -95,7 +95,7 @@ struct Backend::IntraOpThreadPool { return std::move(backend); } -/* static */ StatusOr> +/* static */ absl::StatusOr> Backend::CreateDefaultBackend() { TF_ASSIGN_OR_RETURN(se::Platform * platform, PlatformUtil::GetDefaultPlatform()); @@ -104,14 +104,14 @@ Backend::CreateDefaultBackend() { return CreateBackend(backend_options); } -StatusOr Backend::BorrowStream(int device_ordinal, - se::StreamPriority priority) { +absl::StatusOr Backend::BorrowStream( + int device_ordinal, se::StreamPriority priority) { TF_ASSIGN_OR_RETURN(auto executor, stream_executor(device_ordinal)); return BorrowStream(executor, priority); } -StatusOr Backend::BorrowStream(se::StreamExecutor* executor, - se::StreamPriority priority) { +absl::StatusOr Backend::BorrowStream( + se::StreamExecutor* executor, se::StreamPriority priority) { absl::MutexLock l(&mu_); if (!stream_pools_.contains(executor)) { stream_pools_.emplace(executor, std::make_unique()); @@ -119,7 +119,7 @@ StatusOr Backend::BorrowStream(se::StreamExecutor* executor, return stream_pools_.at(executor)->BorrowStream(executor, priority); } -StatusOr> Backend::BorrowStreams( +absl::StatusOr> Backend::BorrowStreams( int device_ordinal, int num_streams, se::StreamPriority priority) { absl::MutexLock l(&mu_); TF_ASSIGN_OR_RETURN(auto executor, stream_executor(device_ordinal)); @@ -181,7 +181,7 @@ tsl::thread::ThreadPool* Backend::eigen_intra_op_thread_pool() const { return intra_op_thread_pool_->pool.get(); } -StatusOr Backend::stream_executor( +absl::StatusOr Backend::stream_executor( int device_ordinal) const { if (device_ordinal < 0 || device_ordinal > stream_executors_.back()->device_ordinal()) { @@ -198,8 +198,8 @@ StatusOr Backend::stream_executor( device_name(device_ordinal)); } -StatusOr Backend::devices_equivalent(int device_ordinal_a, - int device_ordinal_b) { +absl::StatusOr Backend::devices_equivalent(int device_ordinal_a, + int device_ordinal_b) { // Use the name from device description to determine equivalence. This is a // bit crude but works for GPUs which is the important case where we compile // an executable for one GPU and want to know if it will run (well) on diff --git a/third_party/xla/xla/service/backend.h b/third_party/xla/xla/service/backend.h index cf397646195f25..fb8a324a4d320d 100644 --- a/third_party/xla/xla/service/backend.h +++ b/third_party/xla/xla/service/backend.h @@ -74,12 +74,12 @@ class BackendOptions { class Backend { public: // Creates a new backend. - static StatusOr> CreateBackend( + static absl::StatusOr> CreateBackend( const BackendOptions& options); // Creates a backend for the default platform. The default platform is defined // in PlatformUtil. - static StatusOr> CreateDefaultBackend(); + static absl::StatusOr> CreateDefaultBackend(); ~Backend(); @@ -109,7 +109,7 @@ class Backend { } // Returns the stream executor for the given device ordinal. - StatusOr stream_executor(int device_ordinal) const; + absl::StatusOr stream_executor(int device_ordinal) const; // Returns the stream executor for the default device ordinal. This stream // executor can only be used when the number of computations is 1 (replication @@ -122,13 +122,13 @@ class Backend { // Borrows a stream for use by the caller with a given priority, either by // grabbing it from an internal pool, or by constructing/initializating it, // and returns the result to the caller. - StatusOr BorrowStream( + absl::StatusOr BorrowStream( int device_ordinal, se::StreamPriority priority = se::StreamPriority::Default); - StatusOr BorrowStream( + absl::StatusOr BorrowStream( se::StreamExecutor* executor, se::StreamPriority priority = se::StreamPriority::Default); - StatusOr> BorrowStreams( + absl::StatusOr> BorrowStreams( int device_ordinal, int num_streams, se::StreamPriority priority = se::StreamPriority::Default); @@ -136,8 +136,8 @@ class Backend { // as `BorrowStreams` above does. // Purely for convenience, the caller could rather make this anonymous // function itself. - std::function>(int, int, - se::StreamPriority)> + std::function>( + int, int, se::StreamPriority)> StreamBorrowerWithPriority() { return [this](int device_ordinal, int num_streams, se::StreamPriority priority) { @@ -159,7 +159,8 @@ class Backend { // Returns true if the devices with the given ordinals are equivalent from // XLA's perspective. That is, an executable compiled for one device would // be equivalent to an executable compiled for the other. - StatusOr devices_equivalent(int device_ordinal_a, int device_ordinal_b); + absl::StatusOr devices_equivalent(int device_ordinal_a, + int device_ordinal_b); // For the host platform, returns the configured eigen threadpool device to be // used for scheduling work. For other platforms, returns NULL. diff --git a/third_party/xla/xla/service/batch_dot_simplification.cc b/third_party/xla/xla/service/batch_dot_simplification.cc index 78f1a27cce8f82..5e0484aaaa2a22 100644 --- a/third_party/xla/xla/service/batch_dot_simplification.cc +++ b/third_party/xla/xla/service/batch_dot_simplification.cc @@ -20,7 +20,7 @@ limitations under the License. #include "xla/service/hlo_creation_utils.h" namespace xla { -StatusOr +absl::StatusOr BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot( HloInstruction* batch_dot) { // This pass assumes the lhs and rhs batch dimensions are equal and strictly @@ -108,7 +108,7 @@ absl::string_view BatchDotSimplification::name() const { return "batch-dot-simplification"; } -StatusOr BatchDotSimplification::Run( +absl::StatusOr BatchDotSimplification::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/third_party/xla/xla/service/batch_dot_simplification.h b/third_party/xla/xla/service/batch_dot_simplification.h index 02baa9392f59db..0f5238386429dc 100644 --- a/third_party/xla/xla/service/batch_dot_simplification.h +++ b/third_party/xla/xla/service/batch_dot_simplification.h @@ -28,13 +28,13 @@ namespace xla { class BatchDotSimplification : public HloModulePass { public: using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; absl::string_view name() const override; private: - StatusOr ElideDegenerateBatchDimensionFromBatchDot( + absl::StatusOr ElideDegenerateBatchDimensionFromBatchDot( HloInstruction* batch_dot); }; } // namespace xla diff --git a/third_party/xla/xla/service/batchnorm_expander.cc b/third_party/xla/xla/service/batchnorm_expander.cc index 9da5e568697a83..592cb4a210cfe0 100644 --- a/third_party/xla/xla/service/batchnorm_expander.cc +++ b/third_party/xla/xla/service/batchnorm_expander.cc @@ -580,7 +580,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( return OkStatus(); } -StatusOr BatchNormExpander::Run( +absl::StatusOr BatchNormExpander::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { XLA_VLOG_LINES(2, "BatchNormExpander::Run(), before:\n" + module->ToString()); diff --git a/third_party/xla/xla/service/batchnorm_expander.h b/third_party/xla/xla/service/batchnorm_expander.h index 9c6d2e5e059448..ab2c13f56bc2c4 100644 --- a/third_party/xla/xla/service/batchnorm_expander.h +++ b/third_party/xla/xla/service/batchnorm_expander.h @@ -41,7 +41,7 @@ class BatchNormExpander : public HloModulePass { // Run operation expander on the given computation. Returns whether the // computation was changed. using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/third_party/xla/xla/service/bfloat16_conversion_folding.cc b/third_party/xla/xla/service/bfloat16_conversion_folding.cc index 49e4909a96784a..c8e94d576c5c0b 100644 --- a/third_party/xla/xla/service/bfloat16_conversion_folding.cc +++ b/third_party/xla/xla/service/bfloat16_conversion_folding.cc @@ -252,7 +252,7 @@ Status BFloat16ConversionFoldingVisitor::HandleAllReduce(HloInstruction* crs) { return OkStatus(); } -StatusOr BFloat16ConversionFolding::Run( +absl::StatusOr BFloat16ConversionFolding::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { XLA_VLOG_LINES( diff --git a/third_party/xla/xla/service/bfloat16_conversion_folding.h b/third_party/xla/xla/service/bfloat16_conversion_folding.h index bec82dad3da4a7..707738dd8491cd 100644 --- a/third_party/xla/xla/service/bfloat16_conversion_folding.h +++ b/third_party/xla/xla/service/bfloat16_conversion_folding.h @@ -44,7 +44,7 @@ class BFloat16ConversionFolding : public HloModulePass { // Run BF16 conversion folding on the given computation. Returns whether the // computation was changed. using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/third_party/xla/xla/service/bfloat16_conversion_folding_test.cc b/third_party/xla/xla/service/bfloat16_conversion_folding_test.cc index 9353b49ad8dcac..99cf031b565d5e 100644 --- a/third_party/xla/xla/service/bfloat16_conversion_folding_test.cc +++ b/third_party/xla/xla/service/bfloat16_conversion_folding_test.cc @@ -75,7 +75,7 @@ class BFloat16ConversionFoldingTest : public HloTestBase { bool FoldConversions(HloModule* module) { TestBFloat16Support bfloat16_support_; BFloat16ConversionFolding fold(&bfloat16_support_); - StatusOr result = fold.Run(module); + absl::StatusOr result = fold.Run(module); EXPECT_IS_OK(result.status()); return result.value(); } diff --git a/third_party/xla/xla/service/bfloat16_propagation.cc b/third_party/xla/xla/service/bfloat16_propagation.cc index 68d94ae29e94c6..38eb493081f302 100644 --- a/third_party/xla/xla/service/bfloat16_propagation.cc +++ b/third_party/xla/xla/service/bfloat16_propagation.cc @@ -831,7 +831,7 @@ Status BFloat16Propagation::SkipNoopConversions( // their users. During the backward pass, the potential changes are stored in // changes_to_bf16_ which are subject to further adjustments then applied to the // HLOs. -StatusOr BFloat16Propagation::Run( +absl::StatusOr BFloat16Propagation::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { consider_using_bfloat16_.clear(); diff --git a/third_party/xla/xla/service/bfloat16_propagation.h b/third_party/xla/xla/service/bfloat16_propagation.h index 43280766b41f3d..21625e7337573d 100644 --- a/third_party/xla/xla/service/bfloat16_propagation.h +++ b/third_party/xla/xla/service/bfloat16_propagation.h @@ -68,7 +68,7 @@ class BFloat16Propagation : public HloModulePass { // Runs the pass on the given module. Returns whether the module was changed // (precision reductions were added). using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/third_party/xla/xla/service/bfloat16_propagation_test.cc b/third_party/xla/xla/service/bfloat16_propagation_test.cc index b169105f5dc1c7..c52c6a37d67c3d 100644 --- a/third_party/xla/xla/service/bfloat16_propagation_test.cc +++ b/third_party/xla/xla/service/bfloat16_propagation_test.cc @@ -67,7 +67,7 @@ class BFloat16PropagationTest : public HloTestBase { bool PropagatePrecision(HloModule* module) { TestBFloat16Support bfloat16_support; BFloat16Propagation propagation(&bfloat16_support); - StatusOr result = propagation.Run(module); + absl::StatusOr result = propagation.Run(module); EXPECT_IS_OK(result.status()); return result.value(); } diff --git a/third_party/xla/xla/service/bitcast_dtypes_expander.cc b/third_party/xla/xla/service/bitcast_dtypes_expander.cc index 6e48959b5d95b4..f4cc6809599cdd 100644 --- a/third_party/xla/xla/service/bitcast_dtypes_expander.cc +++ b/third_party/xla/xla/service/bitcast_dtypes_expander.cc @@ -35,7 +35,7 @@ limitations under the License. namespace xla { -StatusOr BitcastDtypesExpander::ExpandInstruction( +absl::StatusOr BitcastDtypesExpander::ExpandInstruction( HloInstruction* instruction) { HloInstruction* input = instruction->mutable_operand(0); const Shape& from_shape = input->shape(); diff --git a/third_party/xla/xla/service/bitcast_dtypes_expander.h b/third_party/xla/xla/service/bitcast_dtypes_expander.h index eb7412bef08b45..ce7663d47fc0e6 100644 --- a/third_party/xla/xla/service/bitcast_dtypes_expander.h +++ b/third_party/xla/xla/service/bitcast_dtypes_expander.h @@ -33,7 +33,7 @@ class BitcastDtypesExpander : public OpExpanderPass { protected: bool InstructionMatchesPattern(HloInstruction* instruction) override; - StatusOr ExpandInstruction( + absl::StatusOr ExpandInstruction( HloInstruction* instruction) override; private: diff --git a/third_party/xla/xla/service/broadcast_canonicalizer.cc b/third_party/xla/xla/service/broadcast_canonicalizer.cc index 02c1af6fd25fd7..e763b4d60d0e23 100644 --- a/third_party/xla/xla/service/broadcast_canonicalizer.cc +++ b/third_party/xla/xla/service/broadcast_canonicalizer.cc @@ -21,7 +21,7 @@ namespace xla { BroadcastCanonicalizer::BroadcastCanonicalizer() {} -StatusOr BroadcastCanonicalizer::Run( +absl::StatusOr BroadcastCanonicalizer::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/third_party/xla/xla/service/broadcast_canonicalizer.h b/third_party/xla/xla/service/broadcast_canonicalizer.h index 30b182ce635cd1..0206d187942d86 100644 --- a/third_party/xla/xla/service/broadcast_canonicalizer.h +++ b/third_party/xla/xla/service/broadcast_canonicalizer.h @@ -30,7 +30,7 @@ class BroadcastCanonicalizer : public HloModulePass { absl::string_view name() const override { return "broadcast_canonicalizer"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/third_party/xla/xla/service/buffer_assignment.cc b/third_party/xla/xla/service/buffer_assignment.cc index 0543365a391454..c2cbdf5d001442 100644 --- a/third_party/xla/xla/service/buffer_assignment.cc +++ b/third_party/xla/xla/service/buffer_assignment.cc @@ -74,7 +74,7 @@ absl::flat_hash_map BuildIdToHloInstructionMap( return id_to_hlo_instruction; } -StatusOr> +absl::StatusOr> BuildIdToLogicalBufferMap( const BufferAssignmentProto& proto, const absl::flat_hash_map& @@ -470,7 +470,7 @@ bool BufferAssignment::HasTopLevelAllocation( return HasAllocationAt(instruction, /*index=*/{}); } -StatusOr BufferAssignment::GetUniqueSlice( +absl::StatusOr BufferAssignment::GetUniqueSlice( const HloInstruction* instruction, const ShapeIndex& index) const { VLOG(3) << "Trying to find unique slice for " << instruction->name() << " [" << index << "]"; @@ -502,7 +502,8 @@ StatusOr BufferAssignment::GetUniqueSlice( return result; } -StatusOr BufferAssignment::GetUniqueTopLevelSlice( +absl::StatusOr +BufferAssignment::GetUniqueTopLevelSlice( const HloInstruction* instruction) const { return GetUniqueSlice(instruction, /*index=*/{}); } @@ -548,7 +549,7 @@ bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a, }); } -StatusOr +absl::StatusOr BufferAssignment::GetUniqueTopLevelOutputSlice() const { return GetUniqueTopLevelSlice( module_->entry_computation()->root_instruction()); @@ -995,7 +996,7 @@ BufferAssignmentProto BufferAssignment::ToProto() const { } /* static */ -StatusOr> BufferAssignment::FromProto( +absl::StatusOr> BufferAssignment::FromProto( const BufferAssignmentProto& proto, const HloModule* module, BufferValue::SizeFunction buffer_size, HloDataflowAnalysis::CanShareBuffer can_share_buffer) { @@ -1072,7 +1073,7 @@ StatusOr> BufferAssignment::FromProto( } /* static */ -StatusOr> BufferAssigner::Run( +absl::StatusOr> BufferAssigner::Run( const HloModule* module, std::unique_ptr hlo_ordering, BufferValue::SizeFunction buffer_size, LogicalBuffer::AlignmentFunction color_alignment, @@ -1983,7 +1984,8 @@ void BufferAssigner::AssignBuffersFromHeapSimulator( } } -StatusOr> BufferAssigner::CreateAssignment( +absl::StatusOr> +BufferAssigner::CreateAssignment( const HloModule* module, std::unique_ptr hlo_ordering, BufferValue::SizeFunction buffer_size, LogicalBuffer::AlignmentFunction color_alignment, diff --git a/third_party/xla/xla/service/buffer_assignment.h b/third_party/xla/xla/service/buffer_assignment.h index 0fd9632319cd2d..367151df0528e7 100644 --- a/third_party/xla/xla/service/buffer_assignment.h +++ b/third_party/xla/xla/service/buffer_assignment.h @@ -421,16 +421,16 @@ class BufferAssignment { // Convenience function which returns the unique slice containing the buffer // at the given index of the given instruction. If a slice is not assigned or // the slice cannot be determined at compile time then an error is returned. - StatusOr GetUniqueSlice( + absl::StatusOr GetUniqueSlice( const HloInstruction* instruction, const ShapeIndex& index) const; // Like GetUniqueSlice but fixes the index to the top-level of the shape // (index = {}). - StatusOr GetUniqueTopLevelSlice( + absl::StatusOr GetUniqueTopLevelSlice( const HloInstruction* instruction) const; // Like GetUniqueTopLevelSlice but returns the slice for the output of the // entry computation of the HLO module (ie, the result of the XLA // computation). - StatusOr GetUniqueTopLevelOutputSlice() const; + absl::StatusOr GetUniqueTopLevelOutputSlice() const; // Returns the set BufferValues which may be the source of the value at the // given index and instruction. @@ -480,7 +480,7 @@ class BufferAssignment { // Convert BufferAssignment to or from a proto. BufferAssignmentProto ToProto() const; - static StatusOr> FromProto( + static absl::StatusOr> FromProto( const BufferAssignmentProto& proto, const HloModule* module, BufferValue::SizeFunction buffer_size, HloDataflowAnalysis::CanShareBuffer can_share_buffer); @@ -637,7 +637,7 @@ class BufferAssigner { // LogicalBuffer. If preset_assignments is provided, those pre-set assignment // offsets will be used. The caller guarantees that those assignments are // valid and they do not overwrite each other. - static StatusOr> Run( + static absl::StatusOr> Run( const HloModule* module, std::unique_ptr hlo_ordering, BufferValue::SizeFunction buffer_size, LogicalBuffer::AlignmentFunction color_alignment, @@ -665,7 +665,7 @@ class BufferAssigner { virtual ~BufferAssigner() = default; // Create a buffer assignment. - StatusOr> CreateAssignment( + absl::StatusOr> CreateAssignment( const HloModule* module, std::unique_ptr hlo_ordering, BufferValue::SizeFunction buffer_size, LogicalBuffer::AlignmentFunction color_alignment, diff --git a/third_party/xla/xla/service/buffer_assignment_test.cc b/third_party/xla/xla/service/buffer_assignment_test.cc index ef9bd1de1b4e3f..5380368cad491e 100644 --- a/third_party/xla/xla/service/buffer_assignment_test.cc +++ b/third_party/xla/xla/service/buffer_assignment_test.cc @@ -103,7 +103,7 @@ class BufferAssignmentTest : public HloTestBase { .value(); } - StatusOr> ConvertToProtoAndBack( + absl::StatusOr> ConvertToProtoAndBack( const BufferAssignment* buffers, const HloModule* module) { // Dump proto for buffer assignments. auto proto = buffers->ToProto(); diff --git a/third_party/xla/xla/service/call_inliner.cc b/third_party/xla/xla/service/call_inliner.cc index 92ccce30e148a0..7b1a46f778e0d3 100644 --- a/third_party/xla/xla/service/call_inliner.cc +++ b/third_party/xla/xla/service/call_inliner.cc @@ -92,7 +92,7 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { // Resolves the callee subcomputation_hlo to the new (inline) HLO in the // caller computation, or returns a NotFound error if that subcomputation HLO // has not been mapped. - StatusOr Resolve(HloInstruction* subcomputation_hlo) { + absl::StatusOr Resolve(HloInstruction* subcomputation_hlo) { auto it = subcomputation_hlo_to_new_hlo_.find(subcomputation_hlo); if (it == subcomputation_hlo_to_new_hlo_.end()) { return NotFound( @@ -123,8 +123,8 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { } // namespace -/* static */ StatusOr CallInliner::Inline( - HloInstruction* call) { +/* static */ absl::StatusOr +CallInliner::Inline(HloInstruction* call) { TF_RET_CHECK(call->opcode() == HloOpcode::kCall) << "Instruction was not a call op: " << call->opcode(); const auto& callees = call->called_computations(); @@ -136,7 +136,7 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { return visitor.ConsumeInstructionMap(); } -StatusOr CallInliner::Run( +absl::StatusOr CallInliner::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { std::unique_ptr call_graph = CallGraph::Build(module); diff --git a/third_party/xla/xla/service/call_inliner.h b/third_party/xla/xla/service/call_inliner.h index 7280e7dc0db82f..2ce5e7054a9235 100644 --- a/third_party/xla/xla/service/call_inliner.h +++ b/third_party/xla/xla/service/call_inliner.h @@ -31,7 +31,7 @@ class CallInliner : public HloModulePass { // Inlines one call instruction. Returns a mapping from the original // instructions to their inlined versions. - static StatusOr Inline(HloInstruction* call); + static absl::StatusOr Inline(HloInstruction* call); // If single_call_site is true, only functions with a single call site will be // inlined. @@ -44,7 +44,7 @@ class CallInliner : public HloModulePass { absl::string_view name() const override { return "CallInliner"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/third_party/xla/xla/service/change_op_data_type.cc b/third_party/xla/xla/service/change_op_data_type.cc index bfb6f56765dbc4..6765ebc7f3c62c 100644 --- a/third_party/xla/xla/service/change_op_data_type.cc +++ b/third_party/xla/xla/service/change_op_data_type.cc @@ -35,7 +35,7 @@ std::optional GetUniformOperandType( } } // namespace -StatusOr ChangeOpDataType::Run( +absl::StatusOr ChangeOpDataType::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/third_party/xla/xla/service/change_op_data_type.h b/third_party/xla/xla/service/change_op_data_type.h index 36fbbeadfa72de..1f3ed75bd5e407 100644 --- a/third_party/xla/xla/service/change_op_data_type.h +++ b/third_party/xla/xla/service/change_op_data_type.h @@ -63,7 +63,7 @@ class ChangeOpDataType : public HloModulePass { } absl::string_view name() const override { return "change-op-data-type"; } - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/third_party/xla/xla/service/channel_tracker.cc b/third_party/xla/xla/service/channel_tracker.cc index fde815dc50a857..8ad2445f082ef9 100644 --- a/third_party/xla/xla/service/channel_tracker.cc +++ b/third_party/xla/xla/service/channel_tracker.cc @@ -19,7 +19,7 @@ limitations under the License. namespace xla { -StatusOr ChannelTracker::NewChannel( +absl::StatusOr ChannelTracker::NewChannel( ChannelHandle::ChannelType type) { if (type != ChannelHandle::DEVICE_TO_DEVICE && type != ChannelHandle::HOST_TO_DEVICE && diff --git a/third_party/xla/xla/service/channel_tracker.h b/third_party/xla/xla/service/channel_tracker.h index e37c66a70b2a98..87b90a2c83c5f0 100644 --- a/third_party/xla/xla/service/channel_tracker.h +++ b/third_party/xla/xla/service/channel_tracker.h @@ -33,7 +33,7 @@ class ChannelTracker { // Creates a new Channel object and returns the corresponding // ChannelHandle for it. - StatusOr NewChannel(ChannelHandle::ChannelType type); + absl::StatusOr NewChannel(ChannelHandle::ChannelType type); private: // Guards the channel mapping. diff --git a/third_party/xla/xla/service/cholesky_expander.cc b/third_party/xla/xla/service/cholesky_expander.cc index e701ecebafdecd..95d46d0f323a52 100644 --- a/third_party/xla/xla/service/cholesky_expander.cc +++ b/third_party/xla/xla/service/cholesky_expander.cc @@ -52,7 +52,7 @@ namespace xla { // l = temp / l[..., j, j) * mask + l // return l // Returns a (result, error) pair. -StatusOr> CholeskyExpander::CholeskyUnblocked( +absl::StatusOr> CholeskyExpander::CholeskyUnblocked( XlaOp a, PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); @@ -73,8 +73,9 @@ StatusOr> CholeskyExpander::CholeskyUnblocked( XlaOp l = ZerosLike(a); // Construct the for loop body to iterate over rows. - auto body_fn = [&](XlaOp i, absl::Span loop_vars, - XlaBuilder* body_builder) -> StatusOr> { + auto body_fn = + [&](XlaOp i, absl::Span loop_vars, + XlaBuilder* body_builder) -> absl::StatusOr> { std::vector row_shape_dims(major_dims.begin(), major_dims.end()); std::vector col_shape_dims(major_dims.begin(), major_dims.end()); auto body_a = loop_vars[0]; @@ -126,7 +127,7 @@ StatusOr> CholeskyExpander::CholeskyUnblocked( XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64_t block_size, PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); const int ndims = a_shape.rank(); if (ndims < 2) { @@ -217,7 +218,7 @@ bool CholeskyExpander::InstructionMatchesPattern(HloInstruction* instruction) { return instruction->opcode() == HloOpcode::kCholesky; } -StatusOr CholeskyExpander::ExpandInstruction( +absl::StatusOr CholeskyExpander::ExpandInstruction( HloInstruction* instruction) { const CholeskyOptions& options = instruction->cholesky_options(); const std::string name = absl::StrFormat( diff --git a/third_party/xla/xla/service/cholesky_expander.h b/third_party/xla/xla/service/cholesky_expander.h index 70ea5247809446..3178d36e949b19 100644 --- a/third_party/xla/xla/service/cholesky_expander.h +++ b/third_party/xla/xla/service/cholesky_expander.h @@ -29,10 +29,10 @@ class CholeskyExpander : public OpExpanderPass { protected: bool InstructionMatchesPattern(HloInstruction* instruction) override; - StatusOr ExpandInstruction( + absl::StatusOr ExpandInstruction( HloInstruction* instruction) override; - virtual StatusOr> CholeskyUnblocked( + virtual absl::StatusOr> CholeskyUnblocked( XlaOp a, PrecisionConfig::Precision precision); private: diff --git a/third_party/xla/xla/service/collective_combiner_utils.h b/third_party/xla/xla/service/collective_combiner_utils.h index b6a01f95d7f0e7..b54296d764a87a 100644 --- a/third_party/xla/xla/service/collective_combiner_utils.h +++ b/third_party/xla/xla/service/collective_combiner_utils.h @@ -43,7 +43,7 @@ namespace xla { // together. Instructions will be combined until the threshold for output byte // size or instruction count is reached. template -StatusOr CombineInstructionsByKey( +absl::StatusOr CombineInstructionsByKey( HloComputation* computation, absl::FunctionRef(const HloInstruction*)> key_fn, absl::FunctionRef)> combine_fn, diff --git a/third_party/xla/xla/service/collective_decomposer_utils.cc b/third_party/xla/xla/service/collective_decomposer_utils.cc index aef3231325e1ec..d86c6b5ae4e917 100644 --- a/third_party/xla/xla/service/collective_decomposer_utils.cc +++ b/third_party/xla/xla/service/collective_decomposer_utils.cc @@ -30,7 +30,7 @@ limitations under the License. namespace xla { // Create the start indices for decompositing the given collective. -StatusOr> +absl::StatusOr> CreateStartIndicesForCollectiveDecomposition( CollectiveOpGroupMode group_mode, absl::Span replica_groups, const Shape &shard_shape, diff --git a/third_party/xla/xla/service/collective_decomposer_utils.h b/third_party/xla/xla/service/collective_decomposer_utils.h index bcf4ebc71c6541..905ab12c240698 100644 --- a/third_party/xla/xla/service/collective_decomposer_utils.h +++ b/third_party/xla/xla/service/collective_decomposer_utils.h @@ -23,7 +23,7 @@ limitations under the License. namespace xla { -StatusOr> +absl::StatusOr> CreateStartIndicesForCollectiveDecomposition( CollectiveOpGroupMode group_mode, absl::Span replica_groups, const Shape &shard_shape, diff --git a/third_party/xla/xla/service/collective_ops_utils.cc b/third_party/xla/xla/service/collective_ops_utils.cc index a1dcd47878ebc9..01678635f2d9b4 100644 --- a/third_party/xla/xla/service/collective_ops_utils.cc +++ b/third_party/xla/xla/service/collective_ops_utils.cc @@ -89,7 +89,7 @@ std::optional GetReductionIdentity(ReductionKind kind, } } -StatusOr> GetParticipatingIDs( +absl::StatusOr> GetParticipatingIDs( CollectiveOpGroupMode group_mode, int current_id, std::optional total_participant_count, absl::Span groups) { @@ -131,7 +131,7 @@ StatusOr> GetParticipatingIDs( // Returns the group formation mode implied by (a) whether the operation has // channel_id and (b) if it has use_global_device_ids and if yes, its value. -StatusOr GetCollectiveOpGroupMode( +absl::StatusOr GetCollectiveOpGroupMode( bool has_channel_id, std::optional use_global_device_ids) { if (!has_channel_id) { if (!use_global_device_ids.has_value() || !*use_global_device_ids) { @@ -165,7 +165,7 @@ absl::string_view CollectiveOpGroupModeToString( } } -StatusOr>> +absl::StatusOr>> GetParticipatingDevicesGroups(const DeviceAssignment& device_assignment, absl::Span replica_groups, CollectiveOpGroupMode group_mode) { @@ -274,7 +274,7 @@ GetParticipatingDevicesGroups(const DeviceAssignment& device_assignment, } } -StatusOr> GetParticipatingFlattenedIdGroups( +absl::StatusOr> GetParticipatingFlattenedIdGroups( const DeviceAssignment& device_assignment, absl::Span replica_groups, CollectiveOpGroupMode group_mode) { @@ -304,7 +304,7 @@ StatusOr> GetParticipatingFlattenedIdGroups( return flattened_id_groups; } -StatusOr> GetParticipatingFlattenedIdGroups( +absl::StatusOr> GetParticipatingFlattenedIdGroups( absl::Span replica_groups, CollectiveOpGroupMode replica_group_mode, int replica_count, int partition_count) { @@ -375,7 +375,7 @@ StatusOr> GetParticipatingFlattenedIdGroups( return flattened_replica_groups; } -StatusOr> GetParticipatingDevices( +absl::StatusOr> GetParticipatingDevices( GlobalDeviceId device_id, const DeviceAssignment& device_assignment, absl::Span replica_groups, CollectiveOpGroupMode group_mode) { @@ -479,7 +479,7 @@ StatusOr> GetParticipatingDevices( } } -StatusOr> GetPariticipantCountsForReplicaGroups( +absl::StatusOr> GetPariticipantCountsForReplicaGroups( int64_t num_replicas, int64_t num_partitions, absl::Span replica_groups, CollectiveOpGroupMode group_mode) { @@ -597,6 +597,32 @@ bool IsCollective(const HloInstruction* instruction) { } } return false; + case HloOpcode::kAsyncStart: + case HloOpcode::kAsyncUpdate: + case HloOpcode::kAsyncDone: + return IsCollective(instruction->async_wrapped_instruction()); + default: + return false; + } +} + +bool IsCollectiveWithChannelId(const HloInstruction* instruction) { + switch (instruction->opcode()) { + case HloOpcode::kAllReduce: + case HloOpcode::kAllReduceStart: + case HloOpcode::kAllGather: + case HloOpcode::kAllGatherStart: + case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: + case HloOpcode::kCollectivePermuteStart: + return instruction->channel_id().has_value(); + case HloOpcode::kFusion: + for (const auto* inner_inst : instruction->fused_instructions()) { + if (IsCollectiveWithChannelId(inner_inst)) { + return true; + } + } + return false; default: return false; } diff --git a/third_party/xla/xla/service/collective_ops_utils.h b/third_party/xla/xla/service/collective_ops_utils.h index 0e9985fa9dd351..6c6114770cd403 100644 --- a/third_party/xla/xla/service/collective_ops_utils.h +++ b/third_party/xla/xla/service/collective_ops_utils.h @@ -97,7 +97,7 @@ enum class CollectiveOpGroupMode { // An empty `groups` indicates that all [0, total_participant_count) IDs // are participating. Note that for CollectiveOpGroupMode::kFlattenedID, // groups cannot be empty, so `total_participant_count` is an optional. -StatusOr> GetParticipatingIDs( +absl::StatusOr> GetParticipatingIDs( CollectiveOpGroupMode group_mode, int current_id, std::optional total_participant_count, absl::Span groups); @@ -107,7 +107,7 @@ absl::string_view CollectiveOpGroupModeToString( // Returns the group formation mode implied by (a) whether the operation has // channel_id and (b) if it has use_global_device_ids and if yes, its value. -StatusOr GetCollectiveOpGroupMode( +absl::StatusOr GetCollectiveOpGroupMode( bool has_channel_id, std::optional use_global_device_ids); // Figures out subgroups of participating devices from given replica_groups and @@ -123,32 +123,32 @@ StatusOr GetCollectiveOpGroupMode( // // This functions returns {{33, 34}, {44, 45, 55, 56}} // There are 2 subgroups of participating devices {33, 34}, {44, 45, 55, 56}. -StatusOr>> +absl::StatusOr>> GetParticipatingDevicesGroups(const DeviceAssignment& device_assignment, absl::Span replica_groups, CollectiveOpGroupMode group_mode); // Same as above, except that it returns the flattened id in the replica groups // instead of device id. -StatusOr> GetParticipatingFlattenedIdGroups( +absl::StatusOr> GetParticipatingFlattenedIdGroups( const DeviceAssignment& device_assignment, absl::Span replica_groups, CollectiveOpGroupMode group_mode); // Same as above, but take replica/partition count instead of device assignment. -StatusOr> GetParticipatingFlattenedIdGroups( +absl::StatusOr> GetParticipatingFlattenedIdGroups( absl::Span replica_groups, CollectiveOpGroupMode replica_group_mode, int replica_count, int partition_count); // Figures out which devices are participating in the collective subgroup. -StatusOr> GetParticipatingDevices( +absl::StatusOr> GetParticipatingDevices( GlobalDeviceId device_id, const DeviceAssignment& device_assignment, absl::Span replica_groups, CollectiveOpGroupMode group_mode); // Figures out how many ranks are participating in each collective subgroup. -StatusOr> GetPariticipantCountsForReplicaGroups( +absl::StatusOr> GetPariticipantCountsForReplicaGroups( int64_t num_replicas, int64_t num_partitions, absl::Span replica_groups, CollectiveOpGroupMode group_mode); @@ -172,6 +172,10 @@ inline constexpr absl::string_view kNopReturnTokenCustomCallTarget = // Returns true if instruction is a collective op or a collective fusion. bool IsCollective(const HloInstruction* instruction); +// Returns true if instruction is a collective op (or a collective fusion) with +// channel_id. +bool IsCollectiveWithChannelId(const HloInstruction* instruction); + // Returns true if instruction is a synchronous collective op. bool IsSyncCollective(const HloInstruction* instr); diff --git a/third_party/xla/xla/service/collective_ops_utils_test.cc b/third_party/xla/xla/service/collective_ops_utils_test.cc index bf6779415fc34f..9bb6d2239e0d29 100644 --- a/third_party/xla/xla/service/collective_ops_utils_test.cc +++ b/third_party/xla/xla/service/collective_ops_utils_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/collective_ops_utils.h" +#include #include #include #include @@ -22,10 +23,16 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/computation_placer.h" #include "xla/service/global_device_id.h" +#include "xla/service/hlo_parser.h" +#include "xla/shape_util.h" #include "xla/xla_data.pb.h" #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { @@ -60,6 +67,52 @@ TEST(CollectiveOpsUtilsTest, GetParticipatingIDs_ReplicaGroups) { EXPECT_EQ(actual, expected); } +TEST(CollectiveOpsUtilsTest, CollectiveWithChannelId) { + absl::string_view hlo_string = R"( + HloModule module, is_scheduled=true + + ENTRY %cluster { + %param0 = f32[512]{0} parameter(0) + %copy0 = f32[512]{0} copy(param0) + %reshape0 = f32[1,1,512]{2,0,1} reshape(f32[512]{0} %copy0) + %all-gather = f32[1,4,512]{2,0,1} all-gather(f32[1,1,512]{2,0,1} %reshape0), channel_id=3621, replica_groups={{0,1,2,3}}, dimensions={1}, use_global_device_ids=true + %copy1 = f32[1,4,512]{2,0,1} copy(all-gather) + ROOT root = f32[1,4,512]{2,1,0} copy(%copy1) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + + HloInstruction *all_gather = + module->entry_computation()->GetInstructionWithName("all-gather"); + + EXPECT_TRUE(IsCollectiveWithChannelId(all_gather)); +} + +TEST(CollectiveOpsUtilsTest, CollectiveWithChannelId2) { + ReplicaGroup group; + for (int64_t i = 0; i < 8; i++) { + group.add_replica_ids(i); + } + + auto builder = HloComputation::Builder("CollectiveWithChannelId2"); + TF_ASSERT_OK_AND_ASSIGN( + HloInstruction * param_0, + builder.AddParameter(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(BF16, {1, 512, 4096}), "p0"))); + HloInstruction *instr = + builder.AddInstruction(HloInstruction::CreateAllGather( + ShapeUtil::MakeShape(BF16, {1, 4096, 4096}), {param_0}, 1, {group}, + true, 231, true)); + auto computation = builder.Build( + builder.AddInstruction(HloInstruction::CreateTuple({instr}))); + auto fusion = + HloInstruction::CreateFusion(ShapeUtil::MakeShape(BF16, {1, 4096, 4096}), + HloInstruction::FusionKind::kOutput, + {param_0}, computation.get(), "fusion"); + + EXPECT_TRUE(IsCollectiveWithChannelId(fusion.get())); +} + } // namespace // Tests for GetCollectOpGroupMode @@ -99,7 +152,7 @@ class GetCollectOpGroupModeTest : public testing::TestWithParam {}; TEST_P(GetCollectOpGroupModeTest, Test) { const TestCase &tc = GetParam(); - StatusOr actual = + absl::StatusOr actual = GetCollectiveOpGroupMode(tc.has_channel_id, tc.use_global_device_ids); if (tc.expected) { TF_ASSERT_OK(actual.status()); @@ -142,7 +195,7 @@ struct TestCase { // modes and their behavior. std::string TestCase::ToString() const { std::ostringstream s; - StatusOr group_mode = + absl::StatusOr group_mode = GetCollectiveOpGroupMode(has_channel_id, use_global_device_ids); if (group_mode.ok()) { s << CollectiveOpGroupModeToString(*group_mode); @@ -396,7 +449,7 @@ TEST_P(GetParticipatingDevicesTest, Test) { return group; }); - StatusOr group_mode = + absl::StatusOr group_mode = GetCollectiveOpGroupMode(tc.has_channel_id, tc.use_global_device_ids); if (!group_mode.ok()) { @@ -406,7 +459,7 @@ TEST_P(GetParticipatingDevicesTest, Test) { // Execute each sub-test. for (const TestCase::CurrentIdAndOutput &subtest : tc.subtests) { - StatusOr> actual = + absl::StatusOr> actual = GetParticipatingDevices(GlobalDeviceId(subtest.current_id), device_assignment, replica_groups, *group_mode); if (!actual.ok()) { @@ -420,9 +473,9 @@ TEST_P(GetParticipatingDevicesTest, Test) { EXPECT_EQ(*actual, expected); } - StatusOr>> actual_device_groups = - GetParticipatingDevicesGroups(device_assignment, replica_groups, - *group_mode); + absl::StatusOr>> + actual_device_groups = GetParticipatingDevicesGroups( + device_assignment, replica_groups, *group_mode); if (!actual_device_groups.ok()) { EXPECT_TRUE(tc.expected_failure); diff --git a/third_party/xla/xla/service/collective_permute_decomposer.cc b/third_party/xla/xla/service/collective_permute_decomposer.cc index 045013f8973ce1..fee09d023a80c6 100644 --- a/third_party/xla/xla/service/collective_permute_decomposer.cc +++ b/third_party/xla/xla/service/collective_permute_decomposer.cc @@ -157,7 +157,7 @@ Status DecomposeCollectivePermute( } } // namespace -StatusOr CollectivePermuteDecomposer::Run( +absl::StatusOr CollectivePermuteDecomposer::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/third_party/xla/xla/service/collective_permute_decomposer.h b/third_party/xla/xla/service/collective_permute_decomposer.h index 3663f9b9f5d127..26537d5ae2dbfc 100644 --- a/third_party/xla/xla/service/collective_permute_decomposer.h +++ b/third_party/xla/xla/service/collective_permute_decomposer.h @@ -54,7 +54,7 @@ class CollectivePermuteDecomposer : public HloModulePass { using HloPassInterface::Run; // Runs CollectivePermuteDecomposer pass on computations in 'module'. // Returns whether the 'module' was changed. - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/third_party/xla/xla/service/collective_pipeliner.cc b/third_party/xla/xla/service/collective_pipeliner.cc index 4ea84c293a889d..7aaf61427757fb 100644 --- a/third_party/xla/xla/service/collective_pipeliner.cc +++ b/third_party/xla/xla/service/collective_pipeliner.cc @@ -517,7 +517,7 @@ std::optional> CollectChainsToPushBackwards( HloInstruction* instr, int64_t loop_iter, const HloComputation* while_body, int64_t level_to_operate_on, const absl::flat_hash_set& loop_invariant_params) { - if (instr->user_count() != 1 || instr->HasControlDependencies()) { + if (instr->HasControlDependencies()) { return std::nullopt; } return CollectIndependentOperandChain(instr, loop_iter, @@ -607,11 +607,10 @@ void UpdateInstructionChannelId(HloInstruction* cloned_instr, // Clones a chain of instructions from a move_info for backward movement. template -StatusOr CloneBackwardChain(Comp& target_computation, - const WhileMoveInfo& move_info, - InstructionMap& clone_map, - int64_t loop_iter_idx, - int64_t& next_channel_id) { +absl::StatusOr CloneBackwardChain( + Comp& target_computation, const WhileMoveInfo& move_info, + InstructionMap& clone_map, int64_t loop_iter_idx, + int64_t& next_channel_id) { std::vector to_clone(move_info.formatting_ops.begin(), move_info.formatting_ops.end()); to_clone.push_back(move_info.collective_to_move); @@ -1361,7 +1360,7 @@ Status TransformLoopForward(const WhileLoopAnalysis& loop_analysis, [&next_channel_id, insert_non_alias_custom_call, level_to_operate_on]( HloInstruction* stacked_data, const InstructionMap& pipelined_values_map, - const WhileMoveInfo& move_info) -> StatusOr { + const WhileMoveInfo& move_info) -> absl::StatusOr { HloInstruction* processed = stacked_data->parent()->AddInstruction( move_info.collective_to_move->CloneWithNewOperands( move_info.collective_to_move->shape(), {stacked_data})); @@ -1388,11 +1387,11 @@ Status TransformLoopForward(const WhileLoopAnalysis& loop_analysis, return processed; }; auto extract_and_process_slice = - [&process_slice](HloInstruction* stacked_data, - HloInstruction* data_to_slice, - const WhileMoveInfo& move_info, - const InstructionMap& pipelined_values_map, - HloInstruction* dus_index) -> StatusOr { + [&process_slice]( + HloInstruction* stacked_data, HloInstruction* data_to_slice, + const WhileMoveInfo& move_info, + const InstructionMap& pipelined_values_map, + HloInstruction* dus_index) -> absl::StatusOr { HloComputation* computation = stacked_data->parent(); const Shape& slice_target_shape = move_info.collective_to_move->operand(0)->shape(); @@ -2329,7 +2328,7 @@ static Status TransformLoopBackward(const WhileLoopAnalysis& loop_analysis, return OkStatus(); } -StatusOr CollectivePipeliner::Run( +absl::StatusOr CollectivePipeliner::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { CHECK(config_.acceptable_formatting); diff --git a/third_party/xla/xla/service/collective_pipeliner.h b/third_party/xla/xla/service/collective_pipeliner.h index 46329630042c38..906f95d32041c4 100644 --- a/third_party/xla/xla/service/collective_pipeliner.h +++ b/third_party/xla/xla/service/collective_pipeliner.h @@ -113,7 +113,7 @@ class CollectivePipeliner : public HloModulePass { } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/third_party/xla/xla/service/collective_pipeliner_test.cc b/third_party/xla/xla/service/collective_pipeliner_test.cc index 4659cf11738f7c..68532638bdad8a 100644 --- a/third_party/xla/xla/service/collective_pipeliner_test.cc +++ b/third_party/xla/xla/service/collective_pipeliner_test.cc @@ -53,7 +53,7 @@ class CollectivePipelinerTest : public HloTestBase { HloModuleConfig config_; }; -StatusOr RunOptimizer( +absl::StatusOr RunOptimizer( HloModule* module, bool last_run, int64_t level_to_operate_on = 0, bool pipeline_use_tree = false, bool process_different_sized_ops = true, CollectivePipeliner::PipeliningDirection direction = diff --git a/third_party/xla/xla/service/collective_transformation_reorderer.cc b/third_party/xla/xla/service/collective_transformation_reorderer.cc index bfb1fcd9a9a015..a85e30e9d521a3 100644 --- a/third_party/xla/xla/service/collective_transformation_reorderer.cc +++ b/third_party/xla/xla/service/collective_transformation_reorderer.cc @@ -133,7 +133,8 @@ GetAllGatherTransformations(HloInstruction* all_gather) { } } // namespace -StatusOr CollectiveTransformationReorder::ReorderAllGatherTransformations( +absl::StatusOr +CollectiveTransformationReorder::ReorderAllGatherTransformations( HloModule* module, const absl::flat_hash_set& execution_threads) { // First, find all all-gathers and reshapes that are eligible for this @@ -211,7 +212,7 @@ StatusOr CollectiveTransformationReorder::ReorderAllGatherTransformations( return true; } -StatusOr CollectiveTransformationReorder::Run( +absl::StatusOr CollectiveTransformationReorder::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { return ReorderAllGatherTransformations(module, execution_threads); diff --git a/third_party/xla/xla/service/collective_transformation_reorderer.h b/third_party/xla/xla/service/collective_transformation_reorderer.h index 685323d085613f..f23c86bd3670ff 100644 --- a/third_party/xla/xla/service/collective_transformation_reorderer.h +++ b/third_party/xla/xla/service/collective_transformation_reorderer.h @@ -52,12 +52,12 @@ class CollectiveTransformationReorder : public HloModulePass { "collective-transformation-reorderer"; return kName; } - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; private: - StatusOr ReorderAllGatherTransformations( + absl::StatusOr ReorderAllGatherTransformations( HloModule* module, const absl::flat_hash_set& execution_threads); }; diff --git a/third_party/xla/xla/service/collective_transformation_reorderer_test.cc b/third_party/xla/xla/service/collective_transformation_reorderer_test.cc index bb5a23f21fa531..e9a5440605c4a5 100644 --- a/third_party/xla/xla/service/collective_transformation_reorderer_test.cc +++ b/third_party/xla/xla/service/collective_transformation_reorderer_test.cc @@ -26,7 +26,7 @@ namespace op = xla::testing::opcode_matchers; class CollectiveTransformationReordererTest : public HloTestBase { public: - StatusOr RunCollectiveTransformationReorderer(HloModule* module) { + absl::StatusOr RunCollectiveTransformationReorderer(HloModule* module) { CollectiveTransformationReorder reorderer; return reorderer.Run(module, {}); } diff --git a/third_party/xla/xla/service/collectives_schedule_linearizer.cc b/third_party/xla/xla/service/collectives_schedule_linearizer.cc index 69e1791771638b..a367831a1d0fec 100644 --- a/third_party/xla/xla/service/collectives_schedule_linearizer.cc +++ b/third_party/xla/xla/service/collectives_schedule_linearizer.cc @@ -34,7 +34,7 @@ limitations under the License. namespace xla { // TODO(b/181653482): Fix for interprocedural collectives as well. -StatusOr CollectivesScheduleLinearizer::Run( +absl::StatusOr CollectivesScheduleLinearizer::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { if (is_enabled_ && !is_enabled_(module)) { diff --git a/third_party/xla/xla/service/collectives_schedule_linearizer.h b/third_party/xla/xla/service/collectives_schedule_linearizer.h index fedf34808c39f3..ad722dc3958873 100644 --- a/third_party/xla/xla/service/collectives_schedule_linearizer.h +++ b/third_party/xla/xla/service/collectives_schedule_linearizer.h @@ -39,7 +39,7 @@ class CollectivesScheduleLinearizer : public HloModulePass { } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/third_party/xla/xla/service/comparison_expander.cc b/third_party/xla/xla/service/comparison_expander.cc index ad399fafd2b700..4a7ff3d5a44628 100644 --- a/third_party/xla/xla/service/comparison_expander.cc +++ b/third_party/xla/xla/service/comparison_expander.cc @@ -15,44 +15,59 @@ limitations under the License. #include "xla/service/comparison_expander.h" -#include "xla/client/lib/comparators.h" -#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include +#include + +#include "absl/algorithm/container.h" +#include "xla/comparison_util.h" +#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/hlo_creation_utils.h" +#include "xla/literal_util.h" +#include "xla/primitive_util.h" +#include "xla/shape_util.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" namespace xla { HloInstruction* BitcastConvertFloatingPointToIntegral( - HloComputation* computation, HloInstruction* value, - const Shape& signed_shape, const Shape& unsigned_shape, - HloInstruction* zero, HloInstruction* max_value) { + HloComputation* computation, HloInstruction* value, HloInstruction* zero, + HloInstruction* min_value, HloInstruction* max_value) { // Switch from a floating point value to a integer value in such a way that // when using the integer value to compare, we get the same result for normal // values, and -Nan is treated as the smallest value, and Nan is treated as // the largest value. // If f is a float, and // x = bit_cast(f); - // y = x < 0 ? numeric_limits::max() - x : x; + // y = x < 0 ? numeric_limits::max() ^ x : x; // then y is ordered as an int32_t such that finite values have the obvious // order, -0 is ordered before 0, and -NaN and NaN appear at the beginning // and end of the ordering. - // Note that in order to avoid -x to overflow, we calculate - // numeric_limits::max() - x as unsigned, and then convert back to - // signed. + auto signed_shape = max_value->shape(); auto signed_value = computation->AddInstruction( HloInstruction::CreateBitcastConvert(signed_shape, value)); - auto unsigned_value = computation->AddInstruction( - HloInstruction::CreateBitcastConvert(unsigned_shape, value)); - auto flipped_value = computation->AddInstruction(HloInstruction::CreateBinary( - unsigned_shape, HloOpcode::kSubtract, max_value, unsigned_value)); - flipped_value = computation->AddInstruction( - HloInstruction::CreateBitcastConvert(signed_shape, flipped_value)); - auto compare_shape = signed_shape; - compare_shape.set_element_type(PRED); + auto compare_shape = ShapeUtil::ChangeElementType(signed_shape, PRED); + HloInstruction* flipped_value; + if (primitive_util::HasNegativeZero(value->shape().element_type())) { + flipped_value = computation->AddInstruction(HloInstruction::CreateBinary( + signed_shape, HloOpcode::kXor, max_value, signed_value)); + } else { + // There is no -0 so min_denorm() must take its place, this is the same as + // adding one to flipped_value. + flipped_value = computation->AddInstruction(HloInstruction::CreateBinary( + signed_shape, HloOpcode::kSubtract, min_value, signed_value)); + + // NaN is the smallest value as it is negative. + auto nan_bit_pattern = min_value; + auto is_nan = computation->AddInstruction(HloInstruction::CreateCompare( + compare_shape, signed_value, nan_bit_pattern, + ComparisonDirection::kEq)); + flipped_value = computation->AddInstruction(HloInstruction::CreateTernary( + signed_shape, HloOpcode::kSelect, is_nan, min_value, flipped_value)); + } auto is_negative = computation->AddInstruction(HloInstruction::CreateCompare( compare_shape, signed_value, zero, ComparisonDirection::kLt)); return computation->AddInstruction( @@ -63,9 +78,9 @@ HloInstruction* BitcastConvertFloatingPointToIntegral( bool ComparisonExpander::InstructionMatchesPattern( HloInstruction* instruction) { if (HloCompareInstruction* compare = - dynamic_cast(instruction)) { + DynCast(instruction)) { HloInstruction* lhs = instruction->operands()[0]; - if (compare->type() == Comparison::Type::kFloatTotalOrder && + if (compare->order() == Comparison::Order::kTotal && primitive_util::IsFloatingPointType(lhs->shape().element_type())) { return true; } @@ -73,60 +88,64 @@ bool ComparisonExpander::InstructionMatchesPattern( return false; } -StatusOr ComparisonExpander::ExpandInstruction( +absl::StatusOr ComparisonExpander::ExpandInstruction( HloInstruction* instruction) { - CHECK(instruction->opcode() == HloOpcode::kCompare); + CHECK_EQ(instruction->opcode(), HloOpcode::kCompare); HloCompareInstruction* compare = static_cast(instruction); - CHECK(compare->type() == Comparison::Type::kFloatTotalOrder); + CHECK(compare->order() == Comparison::Order::kTotal) + << ComparisonOrderToString(compare->order()); HloComputation* computation = instruction->parent(); HloInstruction* lhs = instruction->operands()[0]; HloInstruction* rhs = instruction->operands()[1]; - Shape compare_shape = lhs->shape(); - PrimitiveType compare_type = compare_shape.element_type(); + PrimitiveType compare_type = lhs->shape().element_type(); CHECK(primitive_util::IsFloatingPointType(compare_type)); - // Special-case handling for BF16. We currently do not support direct - // comparisons with BF16, so we convert to F32 and then use the F32 - // comparison logic. - if (compare_type == BF16) { - compare_type = F32; - compare_shape.set_element_type(compare_type); - lhs = computation->AddInstruction( - HloInstruction::CreateConvert(compare_shape, lhs)); - rhs = computation->AddInstruction( - HloInstruction::CreateConvert(compare_shape, rhs)); + if (auto do_upcast = absl::c_find_if( + expand_via_upcast_, + [compare_type](std::pair upcast) { + return upcast.first == compare_type; + }); + do_upcast != expand_via_upcast_.end()) { + CHECK(primitive_util::CastPreservesValues(do_upcast->first, + do_upcast->second)); + compare_type = do_upcast->second; + lhs = computation->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(lhs->shape(), compare_type), lhs)); + rhs = computation->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(rhs->shape(), compare_type), rhs)); } - int64_t bit_width = primitive_util::BitWidth(compare_type); + int64_t bit_width = primitive_util::BitWidth(lhs->shape().element_type()); PrimitiveType signed_type = primitive_util::SignedIntegralTypeForBitWidth(bit_width); - PrimitiveType unsigned_type = - primitive_util::UnsignedIntegralTypeForBitWidth(bit_width); - auto signed_shape = compare_shape; - signed_shape.set_element_type(signed_type); - auto unsigned_shape = compare_shape; - unsigned_shape.set_element_type(unsigned_type); + auto signed_shape = ShapeUtil::ChangeElementType(lhs->shape(), signed_type); + auto zero_value = computation->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::Zero(signed_type))); - zero_value = computation->AddInstruction(HloInstruction::CreateBroadcast( - signed_shape, zero_value, zero_value->shape().dimensions())); - auto max_signed = computation->AddInstruction( + zero_value = computation->AddInstruction( + HloInstruction::CreateBroadcast(signed_shape, zero_value, {})); + + auto min_value = computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::MinValue(signed_shape.element_type()))); + min_value = computation->AddInstruction( + HloInstruction::CreateBroadcast(signed_shape, min_value, {})); + + auto max_value = computation->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::MaxValue(signed_type))); - auto max_shape = max_signed->shape(); - max_shape.set_element_type(unsigned_type); - auto max_unsigned = computation->AddInstruction( - HloInstruction::CreateConvert(max_shape, max_signed)); - auto max_value = computation->AddInstruction(HloInstruction::CreateBroadcast( - unsigned_shape, max_unsigned, max_shape.dimensions())); - lhs = BitcastConvertFloatingPointToIntegral( - computation, lhs, signed_shape, unsigned_shape, zero_value, max_value); - rhs = BitcastConvertFloatingPointToIntegral( - computation, rhs, signed_shape, unsigned_shape, zero_value, max_value); + max_value = computation->AddInstruction( + HloInstruction::CreateBroadcast(signed_shape, max_value, {})); + + lhs = BitcastConvertFloatingPointToIntegral(computation, lhs, zero_value, + min_value, max_value); + rhs = BitcastConvertFloatingPointToIntegral(computation, rhs, zero_value, + min_value, max_value); + auto new_compare = computation->AddInstruction(HloInstruction::CreateCompare( instruction->shape(), lhs, rhs, compare->direction(), Comparison::Type::kSigned)); + VLOG(2) << "New comparison instruction for total order:" - << new_compare->ToString() << "\n"; + << new_compare->ToString(); return new_compare; } diff --git a/third_party/xla/xla/service/comparison_expander.h b/third_party/xla/xla/service/comparison_expander.h index 17af0712b102c1..d95b6df78c123f 100644 --- a/third_party/xla/xla/service/comparison_expander.h +++ b/third_party/xla/xla/service/comparison_expander.h @@ -17,10 +17,15 @@ limitations under the License. #define XLA_SERVICE_COMPARISON_EXPANDER_H_ #include +#include -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo_pass_interface.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/primitive_util.h" #include "xla/service/op_expander_pass.h" +#include "xla/statusor.h" +#include "xla/xla_data.pb.h" namespace xla { @@ -28,7 +33,11 @@ namespace xla { // order comparison of floating point numbers. class ComparisonExpander : public OpExpanderPass { public: - explicit ComparisonExpander() = default; + explicit ComparisonExpander( + absl::Span> + expand_via_upcast = {}) + : expand_via_upcast_(expand_via_upcast.begin(), expand_via_upcast.end()) { + } ~ComparisonExpander() override = default; absl::string_view name() const override { return "comparison-expander"; } @@ -38,8 +47,10 @@ class ComparisonExpander : public OpExpanderPass { // Returns a replacement for `instruction`, or nullptr if no replacement is // needed (e.g. only the to_apply subcomputation of the instruction was // modified). - StatusOr ExpandInstruction( + absl::StatusOr ExpandInstruction( HloInstruction* instruction) override; + + std::vector> expand_via_upcast_; }; } // namespace xla diff --git a/third_party/xla/xla/service/compilation_cache.cc b/third_party/xla/xla/service/compilation_cache.cc index df45e89d1a4199..9187924b6d6a95 100644 --- a/third_party/xla/xla/service/compilation_cache.cc +++ b/third_party/xla/xla/service/compilation_cache.cc @@ -51,7 +51,7 @@ ExecutionHandle CompilationCache::Insert( return handle; } -StatusOr> CompilationCache::LookUp( +absl::StatusOr> CompilationCache::LookUp( const ExecutionHandle& handle) const { absl::MutexLock lock(&mutex_); diff --git a/third_party/xla/xla/service/compilation_cache.h b/third_party/xla/xla/service/compilation_cache.h index 3c01943a7fa216..65384bf8340c43 100644 --- a/third_party/xla/xla/service/compilation_cache.h +++ b/third_party/xla/xla/service/compilation_cache.h @@ -39,7 +39,7 @@ class CompilationCache { // Lookup the Executable for the specified handle in the cache. Return a // shared_ptr to the Executable if it exists in the cache. - StatusOr> LookUp( + absl::StatusOr> LookUp( const ExecutionHandle& handle) const; protected: diff --git a/third_party/xla/xla/service/compilation_environments.cc b/third_party/xla/xla/service/compilation_environments.cc index c7defcf7c53d0e..0c2569b92dfcfb 100644 --- a/third_party/xla/xla/service/compilation_environments.cc +++ b/third_party/xla/xla/service/compilation_environments.cc @@ -124,7 +124,7 @@ CompilationEnvironments& CompilationEnvironments::operator=( return *this; } -StatusOr> +absl::StatusOr> CompilationEnvironments::CreateFromProto( const CompilationEnvironmentsProto& proto) { auto envs = std::make_unique(); diff --git a/third_party/xla/xla/service/compilation_environments.h b/third_party/xla/xla/service/compilation_environments.h index bb44ca643e37f1..3ffea24bb53b11 100644 --- a/third_party/xla/xla/service/compilation_environments.h +++ b/third_party/xla/xla/service/compilation_environments.h @@ -47,7 +47,7 @@ namespace xla { class CompilationEnvironments { public: using ProcessNewEnvFn = - std::function>( + std::function>( std::unique_ptr)>; CompilationEnvironments() = default; @@ -56,8 +56,8 @@ class CompilationEnvironments { ~CompilationEnvironments() = default; // Deserializes the given CompilationEnvironments proto. - static StatusOr> CreateFromProto( - const CompilationEnvironmentsProto& proto); + static absl::StatusOr> + CreateFromProto(const CompilationEnvironmentsProto& proto); // Whenever an environment is added to CompilationEnvironments, even when // GetEnv() adds a lazily initialized one, it is passed to the function diff --git a/third_party/xla/xla/service/compile_only_service.cc b/third_party/xla/xla/service/compile_only_service.cc index c6dd110ecbd62e..ab3acea65840d8 100644 --- a/third_party/xla/xla/service/compile_only_service.cc +++ b/third_party/xla/xla/service/compile_only_service.cc @@ -33,14 +33,14 @@ limitations under the License. namespace xla { -/* static */ StatusOr> +/* static */ absl::StatusOr> CompileOnlyService::NewService(se::Platform* platform) { ServiceOptions default_options; default_options.set_platform(platform); return NewService(default_options); } -/* static */ StatusOr> +/* static */ absl::StatusOr> CompileOnlyService::NewService(const ServiceOptions& options) { se::Platform* platform = options.platform(); if (platform == nullptr) { @@ -58,7 +58,7 @@ CompileOnlyService::CompileOnlyService(const ServiceOptions& options, Compiler* compiler) : Service(options, /*execute_backend=*/nullptr), compiler_(compiler) {} -StatusOr>> +absl::StatusOr>> CompileOnlyService::CompileAheadOfTime( absl::Span computations, const AotCompilationOptions& options, diff --git a/third_party/xla/xla/service/compile_only_service.h b/third_party/xla/xla/service/compile_only_service.h index acbb09694ba85d..09ca0534454b46 100644 --- a/third_party/xla/xla/service/compile_only_service.h +++ b/third_party/xla/xla/service/compile_only_service.h @@ -33,9 +33,9 @@ class CompileOnlyService : public Service { // Factory for creating a CompileOnlyService. The parameter platform is the // platform that the service should target. If platform is null then the // default platform is used. - static StatusOr> NewService( + static absl::StatusOr> NewService( se::Platform* platform); - static StatusOr> NewService( + static absl::StatusOr> NewService( const ServiceOptions& options); // A description of a xla computation to compile using CompileAheadOfTime. @@ -48,7 +48,7 @@ class CompileOnlyService : public Service { // Compiles a list of xla computations for ahead-of-time execution. This is // intended for use in static compilation. See // |CompileOnlyClient::CompileAheadOfTime| for additional details. - StatusOr>> + absl::StatusOr>> CompileAheadOfTime(absl::Span computations, const AotCompilationOptions& options, std::unique_ptr* metadata); diff --git a/third_party/xla/xla/service/compiler.cc b/third_party/xla/xla/service/compiler.cc index dabe4692c8bffd..b0feb9a62ae811 100644 --- a/third_party/xla/xla/service/compiler.cc +++ b/third_party/xla/xla/service/compiler.cc @@ -71,7 +71,7 @@ std::unique_ptr Compiler::ComputeDefaultBackendConfig( } // Define a default version where metadata is not used. -StatusOr>> +absl::StatusOr>> Compiler::CompileAheadOfTime( std::unique_ptr module_group, const AotCompilationOptions& options, @@ -108,7 +108,7 @@ Compiler::GetPlatformCompilers() { (*factories)[platform_id] = std::move(compiler_factory); } -/* static */ StatusOr Compiler::GetForPlatform( +/* static */ absl::StatusOr Compiler::GetForPlatform( const se::Platform* platform) { absl::MutexLock lock(&platform_compiler_mutex_); diff --git a/third_party/xla/xla/service/compiler.h b/third_party/xla/xla/service/compiler.h index e9fb13d319ac6b..2f3ef5973bbe9e 100644 --- a/third_party/xla/xla/service/compiler.h +++ b/third_party/xla/xla/service/compiler.h @@ -44,6 +44,10 @@ limitations under the License. #include "tsl/platform/protobuf.h" #include "tsl/platform/threadpool.h" +namespace mlir { +class DialectRegistry; +} // namespace mlir + namespace xla { // The following types are used for ahead of time compilation. @@ -63,15 +67,20 @@ class AotCompilationResult { virtual ~AotCompilationResult() = default; - virtual StatusOr SerializeAsString() const { + virtual absl::StatusOr SerializeAsString() const { return Unimplemented("SerializeAsString unimplemented."); } - virtual StatusOr> LoadExecutable( + virtual absl::StatusOr> LoadExecutable( Compiler* compiler, const se::StreamExecutor* executor) const { return Unimplemented("LoadExecutable unimplemented."); } + // Returns the optimized HLO module if one was computed and the implementation + // supports it. + virtual const HloModule* optimized_module() const = 0; + virtual std::unique_ptr consume_optimized_module() = 0; + protected: AotCompilationResult() = default; }; @@ -136,7 +145,7 @@ class Compiler { // An optional thread pool for parallel compilation. tsl::thread::ThreadPool* thread_pool = nullptr; - std::function, Shape>>( + std::function, Shape>>( const HloModule& module)> layout_canonicalization_callback = {}; @@ -145,6 +154,10 @@ class Compiler { // AOT device description. If provided, used instead of querying the device // on which compilation is performed. std::optional target_config; + + // Registry of MLIR dialects and plugins to be loaded during optimization. + // If non-null, it will be used to construct relevant MLIR contexts. + mlir::DialectRegistry* registry = nullptr; }; virtual ~Compiler() = default; @@ -154,10 +167,10 @@ class Compiler { // Runs Hlo passes to optimize the given Hlo module, returns the optimized // module. - virtual StatusOr> RunHloPasses( + virtual absl::StatusOr> RunHloPasses( std::unique_ptr module, se::StreamExecutor* executor, const CompileOptions& options) = 0; - StatusOr> RunHloPasses( + absl::StatusOr> RunHloPasses( std::unique_ptr module, se::StreamExecutor* executor, se::DeviceMemoryAllocator* device_allocator) { return RunHloPasses(std::move(module), executor, @@ -168,7 +181,7 @@ class Compiler { // assignments. // The returned 'BufferAssignment' retains a pointer to the 'HloModule', so // the module must live at least as long as the buffer assignments. - virtual StatusOr> AssignBuffers( + virtual absl::StatusOr> AssignBuffers( HloModule* module, const se::StreamExecutor* executor) { return Unimplemented("This compiler does not support this method"); } @@ -181,10 +194,10 @@ class Compiler { // // The compiler may optionally specialize to the individual device // (not just type of device) indicated by the executor. - virtual StatusOr> RunBackend( + virtual absl::StatusOr> RunBackend( std::unique_ptr module, se::StreamExecutor* executor, const CompileOptions& options) = 0; - StatusOr> RunBackend( + absl::StatusOr> RunBackend( std::unique_ptr module, se::StreamExecutor* executor, se::DeviceMemoryAllocator* device_allocator) { return RunBackend(std::move(module), executor, @@ -197,7 +210,8 @@ class Compiler { // Note: The default implementation of the API here does not utilize the given // buffer assignment. Different backends are a expected to override the // following method to achieve this functionality. - virtual StatusOr> RunBackendWithBufferAssignment( + virtual absl::StatusOr> + RunBackendWithBufferAssignment( std::unique_ptr module, const BufferAssignmentProto* /*buffer_assignment_proto*/, se::StreamExecutor* executor, const CompileOptions& options) { @@ -205,7 +219,7 @@ class Compiler { return RunBackend(std::move(module), executor, options); } - StatusOr> RunBackendWithBufferAssignment( + absl::StatusOr> RunBackendWithBufferAssignment( std::unique_ptr module, const BufferAssignmentProto* buffer_assignment_proto, se::StreamExecutor* executor, @@ -217,7 +231,7 @@ class Compiler { // Returns a (deserialized) AotCompilationResult from a serialized // AotCompilationResult. - virtual StatusOr> + virtual absl::StatusOr> LoadAotCompilationResult(const std::string& serialized_aot_result) { return Unimplemented("LoadAotCompilationResult unimplemented."); } @@ -228,11 +242,11 @@ class Compiler { // // TODO(b/68666782): Remove this method after adding support for multiple // modules to RunHloPasses and RunBackends. - virtual StatusOr>> Compile( + virtual absl::StatusOr>> Compile( std::unique_ptr module_group, std::vector> stream_exec, const CompileOptions& options) = 0; - StatusOr>> Compile( + absl::StatusOr>> Compile( std::unique_ptr module_group, std::vector> stream_exec, se::DeviceMemoryAllocator* device_allocator) { @@ -261,13 +275,13 @@ class Compiler { // Compiles the HLO module group for ahead-of-time execution. This is // intended for use in static compilation. - virtual StatusOr>> + virtual absl::StatusOr>> CompileAheadOfTime(std::unique_ptr module_group, const AotCompilationOptions& options) = 0; // Similar to CompileAheadOfTime above but AotCompilationMetadata // has an argument that can be populated during compilation. - virtual StatusOr>> + virtual absl::StatusOr>> CompileAheadOfTime(std::unique_ptr module_group, const AotCompilationOptions& options, std::unique_ptr* metadata); @@ -287,7 +301,7 @@ class Compiler { // Returns the compiler singleton pointer if it is available for the given // platform, or an error status if it is not. - static StatusOr GetForPlatform(const se::Platform* platform); + static absl::StatusOr GetForPlatform(const se::Platform* platform); // Returns a function that computes the size in bytes of the logical // buffer that contains a shape. @@ -307,7 +321,7 @@ class Compiler { } // Returns an AotCompilationResult of the executable for serialization. - virtual StatusOr> Export( + virtual absl::StatusOr> Export( Executable* executable) const { return Unimplemented("Export unimplemented"); } diff --git a/third_party/xla/xla/service/computation_layout.cc b/third_party/xla/xla/service/computation_layout.cc index 3e7852f35e353e..117b04f3aabf7b 100644 --- a/third_party/xla/xla/service/computation_layout.cc +++ b/third_party/xla/xla/service/computation_layout.cc @@ -59,8 +59,8 @@ bool ComputationLayout::AnyLayoutSet() const { result_layout_.LayoutIsSet(); } -StatusOr> ComputationLayout::FlattenedParameterLayouts() - const { +absl::StatusOr> +ComputationLayout::FlattenedParameterLayouts() const { std::vector result; for (int i = 0; i < parameter_count(); ++i) { TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( @@ -88,7 +88,7 @@ StatusOr> ComputationLayout::FlattenedParameterLayouts() return result; } -StatusOr> ComputationLayout::FlattenedResultLayouts() +absl::StatusOr> ComputationLayout::FlattenedResultLayouts() const { std::vector result; TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( diff --git a/third_party/xla/xla/service/computation_layout.h b/third_party/xla/xla/service/computation_layout.h index 27e7876d6c0c19..b6c947b2b7c9e9 100644 --- a/third_party/xla/xla/service/computation_layout.h +++ b/third_party/xla/xla/service/computation_layout.h @@ -86,12 +86,12 @@ class ComputationLayout { // Returns a list of each parameter's layout. If the parameters are tupled, // returns an untupled list. Must only be called if all parameters have // layouts set (check with LayoutIsSet()). - StatusOr> FlattenedParameterLayouts() const; + absl::StatusOr> FlattenedParameterLayouts() const; // Returns a list of each output's layout. If the result shape is a tuple, // returns an untupled list. Must only be called if all outputs have layouts // set (check with LayoutIsSet()). - StatusOr> FlattenedResultLayouts() const; + absl::StatusOr> FlattenedResultLayouts() const; // Prints a string representation of this object. void Print(Printer* printer) const; diff --git a/third_party/xla/xla/service/computation_placer.cc b/third_party/xla/xla/service/computation_placer.cc index d0274f5630d941..b896c7d10cc408 100644 --- a/third_party/xla/xla/service/computation_placer.cc +++ b/third_party/xla/xla/service/computation_placer.cc @@ -42,8 +42,8 @@ using absl::StrCat; namespace xla { -StatusOr DeviceAssignment::LogicalIdForDevice( - GlobalDeviceId device_id) const { +absl::StatusOr +DeviceAssignment::LogicalIdForDevice(GlobalDeviceId device_id) const { std::optional logical_id; for (int r = 0; r < replica_count(); ++r) { for (int c = 0; c < computation_count(); ++c) { @@ -65,7 +65,7 @@ StatusOr DeviceAssignment::LogicalIdForDevice( } } -StatusOr DeviceAssignment::ReplicaIdForDevice( +absl::StatusOr DeviceAssignment::ReplicaIdForDevice( GlobalDeviceId device_id) const { TF_ASSIGN_OR_RETURN(const LogicalID logical_id, LogicalIdForDevice(device_id)); @@ -98,7 +98,7 @@ Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const { return OkStatus(); } -/* static */ StatusOr> +/* static */ absl::StatusOr> DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) { TF_RET_CHECK(proto.computation_devices_size() == proto.computation_count()); if (proto.replica_count() <= 0 || proto.computation_count() <= 0) { @@ -135,16 +135,16 @@ std::string DeviceAssignment::ToString() const { return output; } -StatusOr ComputationPlacer::DeviceId(int replica, int computation, - int replica_count, - int computation_count) { +absl::StatusOr ComputationPlacer::DeviceId(int replica, int computation, + int replica_count, + int computation_count) { TF_RET_CHECK(replica < replica_count); TF_RET_CHECK(computation < computation_count); return computation * replica_count + replica; } -StatusOr ComputationPlacer::AssignDevices( +absl::StatusOr ComputationPlacer::AssignDevices( int replica_count, int computation_count) { DeviceAssignment assignment(replica_count, computation_count); for (int replica = 0; replica < replica_count; ++replica) { @@ -165,7 +165,7 @@ StatusOr ComputationPlacer::AssignDevices( auto* computation_placers = GetPlatformComputationPlacers(); if (computation_placers->find(platform_id) != computation_placers->end()) { // TODO(b/282059652): Consider logging the platform name using - // MultiPlatformManager::PlatformWithId(). No doing that for now to avoid + // PlatformManager::PlatformWithId(). No doing that for now to avoid // introducing unwanted dependency. LOG(WARNING) << "computation placer already registered. Please check " "linkage and avoid linking the same target more than once."; @@ -173,8 +173,8 @@ StatusOr ComputationPlacer::AssignDevices( (*computation_placers)[platform_id].creation_function = creation_function; } -/* static */ StatusOr ComputationPlacer::GetForPlatform( - const se::Platform* platform) { +/* static */ absl::StatusOr +ComputationPlacer::GetForPlatform(const se::Platform* platform) { absl::MutexLock lock(&ComputationPlacer::platform_computation_placer_mutex_); auto* computation_placers = GetPlatformComputationPlacers(); diff --git a/third_party/xla/xla/service/computation_placer.h b/third_party/xla/xla/service/computation_placer.h index a603f6fcd7fa57..12facc79505621 100644 --- a/third_party/xla/xla/service/computation_placer.h +++ b/third_party/xla/xla/service/computation_placer.h @@ -55,9 +55,9 @@ class DeviceAssignment : public Array2D { }; // Finds the (replica ID, computation ID) pair for the given device. - StatusOr LogicalIdForDevice(GlobalDeviceId device_id) const; + absl::StatusOr LogicalIdForDevice(GlobalDeviceId device_id) const; // Finds the replica ID for the given device. - StatusOr ReplicaIdForDevice(GlobalDeviceId device_id) const; + absl::StatusOr ReplicaIdForDevice(GlobalDeviceId device_id) const; // Returns a map from device ID to logical ID. Querying this map is much more // efficient than `LogicalIdForDevice` if queried repeatedly. absl::flat_hash_map GetDeviceToLogicalIdMap() @@ -69,7 +69,7 @@ class DeviceAssignment : public Array2D { // Return a std::unique_ptr instead of a DeviceAssignment // directly because one of the supported TF platforms (mac) does not compile // due to a StatusOr of an incomplete type (DeviceAssignment). - static StatusOr> Deserialize( + static absl::StatusOr> Deserialize( const DeviceAssignmentProto& proto); std::string ToString() const; @@ -85,13 +85,14 @@ class ComputationPlacer { // Returns the device id assigned to the given replica and computation // instance for [replica_count x computation_count] setup. The returned device // id must match the assignment from PlaceReplicatedComputation(). - virtual StatusOr DeviceId(int replica, int computation, - int replica_count, int computation_count); + virtual absl::StatusOr DeviceId(int replica, int computation, + int replica_count, + int computation_count); // Returns the device ids assigned to a set of replicated computations, given // the number of replicas and the number of computations. - virtual StatusOr AssignDevices(int replica_count, - int computation_count); + virtual absl::StatusOr AssignDevices(int replica_count, + int computation_count); using ComputationPlacerCreationFunction = std::unique_ptr (*)(); @@ -103,7 +104,7 @@ class ComputationPlacer { // Returns the computation placer singleton pointer if it is available for the // given platform, or an error status if it is not. - static StatusOr GetForPlatform( + static absl::StatusOr GetForPlatform( const se::Platform* platform); private: diff --git a/third_party/xla/xla/service/conditional_canonicalizer.cc b/third_party/xla/xla/service/conditional_canonicalizer.cc index 587c1597e57d3b..22eb2cb6500fea 100644 --- a/third_party/xla/xla/service/conditional_canonicalizer.cc +++ b/third_party/xla/xla/service/conditional_canonicalizer.cc @@ -43,7 +43,7 @@ Status CanonicalizeNonTupleConditional(HloInstruction* conditional) { } } // namespace -StatusOr ConditionalCanonicalizer::Run( +absl::StatusOr ConditionalCanonicalizer::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { XLA_VLOG_LINES( diff --git a/third_party/xla/xla/service/conditional_canonicalizer.h b/third_party/xla/xla/service/conditional_canonicalizer.h index a8617513a6ecea..3446386f414f8f 100644 --- a/third_party/xla/xla/service/conditional_canonicalizer.h +++ b/third_party/xla/xla/service/conditional_canonicalizer.h @@ -32,7 +32,7 @@ class ConditionalCanonicalizer : public HloModulePass { } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/third_party/xla/xla/service/conditional_code_motion.cc b/third_party/xla/xla/service/conditional_code_motion.cc index 2c5af6c38c3bd0..1fed4183420b64 100644 --- a/third_party/xla/xla/service/conditional_code_motion.cc +++ b/third_party/xla/xla/service/conditional_code_motion.cc @@ -533,8 +533,8 @@ Status RestructureConditionalInstruction(HloComputation* computation, return OkStatus(); } -StatusOr ConvertSpecialMove(HloInstruction* conditional, - bool is_layout_sensitive) { +absl::StatusOr ConvertSpecialMove(HloInstruction* conditional, + bool is_layout_sensitive) { int branch_count = conditional->branch_count(); if (branch_count <= 0) { return false; @@ -673,7 +673,7 @@ StatusOr ConvertSpecialMove(HloInstruction* conditional, // are the shape of the operands are identical and their properties are // identical. Will start from the root instruction of each branch and get // the identical ops to hoist. -StatusOr ConditionalCodeMotion::MoveInstructionOut( +absl::StatusOr ConditionalCodeMotion::MoveInstructionOut( HloInstruction* conditional, std::vector& to_move_out, std::vector& new_boundaries) { if (to_move_out.empty()) { @@ -780,7 +780,7 @@ StatusOr ConditionalCodeMotion::MoveInstructionOut( } // Hoist conditional users from outside to inside the branches. -StatusOr ConditionalCodeMotion::MoveUserInstructionsIn( +absl::StatusOr ConditionalCodeMotion::MoveUserInstructionsIn( HloInstruction* conditional, std::vector& to_move_in) { if (to_move_in.empty()) { return false; @@ -1235,7 +1235,7 @@ class MoveOperandIntoBranch { }; // Hoist operands of a conditional from outside to inside the branches. -StatusOr ConditionalCodeMotion::MoveOperandInstructionsIn( +absl::StatusOr ConditionalCodeMotion::MoveOperandInstructionsIn( HloInstruction* conditional, std::vector& to_move_in) { // Mapping boundaries to be moved to their new representations. int64_t to_move_in_size = to_move_in.size(); @@ -1944,7 +1944,7 @@ ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion( return Decision(Decision::Direction::kNoChange, 0); } -StatusOr ConditionalCodeMotion::Run( +absl::StatusOr ConditionalCodeMotion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { VLOG(2) << "Begin a new pass of conditional code motion optimization.\n"; diff --git a/third_party/xla/xla/service/conditional_code_motion.h b/third_party/xla/xla/service/conditional_code_motion.h index 21fbe2b00b5e46..caf92342900c9a 100644 --- a/third_party/xla/xla/service/conditional_code_motion.h +++ b/third_party/xla/xla/service/conditional_code_motion.h @@ -181,7 +181,7 @@ class ConditionalCodeMotion : public HloModulePass { absl::string_view name() const override { return "conditional-code-motion"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; @@ -228,13 +228,13 @@ class ConditionalCodeMotion : public HloModulePass { // moved. int64_t memory_increase_allowance_ = 5000; int64_t memory_increase_ = 0; - StatusOr MoveInstructionOut(HloInstruction* conditional, - std::vector& to_move_out, - std::vector& new_boundaries); - StatusOr MoveUserInstructionsIn(HloInstruction* conditional, - std::vector& to_move_in); - StatusOr MoveOperandInstructionsIn(HloInstruction* conditional, - std::vector& to_move_in); + absl::StatusOr MoveInstructionOut( + HloInstruction* conditional, std::vector& to_move_out, + std::vector& new_boundaries); + absl::StatusOr MoveUserInstructionsIn( + HloInstruction* conditional, std::vector& to_move_in); + absl::StatusOr MoveOperandInstructionsIn( + HloInstruction* conditional, std::vector& to_move_in); void SetDefaultMoveConfig(); }; } // namespace conditional_opt diff --git a/third_party/xla/xla/service/conditional_simplifier.cc b/third_party/xla/xla/service/conditional_simplifier.cc index fdc8f423dc2109..d4d568dc6ed9ee 100644 --- a/third_party/xla/xla/service/conditional_simplifier.cc +++ b/third_party/xla/xla/service/conditional_simplifier.cc @@ -61,7 +61,7 @@ bool ComputationIsEmptyWithArrayRoot(const HloComputation* computation) { return empty_operations && contains_array; } -StatusOr TryRemoveUnusedConditionalOperands( +absl::StatusOr TryRemoveUnusedConditionalOperands( HloComputation* computation, const absl::flat_hash_set& calling_conditionals) { HloInstruction* param = computation->parameter_instruction(0); @@ -439,7 +439,7 @@ bool MergeDuplicateTupleElements(HloInstruction* conditional) { // inline that computation. // // Returns true if it made a change to the graph. -StatusOr ConditionalSimplifier::TryRemoveConditional( +absl::StatusOr ConditionalSimplifier::TryRemoveConditional( HloInstruction* conditional) { CHECK_EQ(conditional->opcode(), HloOpcode::kConditional); // Do not remove conditionals that contain side-effecting instructions or @@ -601,7 +601,7 @@ static bool InstructionCallsChannelInstructions( return false; } -StatusOr ConditionalSimplifier::Run( +absl::StatusOr ConditionalSimplifier::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { XLA_VLOG_LINES( diff --git a/third_party/xla/xla/service/conditional_simplifier.h b/third_party/xla/xla/service/conditional_simplifier.h index cde40ef66fc7f6..8eeab8279dd8f7 100644 --- a/third_party/xla/xla/service/conditional_simplifier.h +++ b/third_party/xla/xla/service/conditional_simplifier.h @@ -29,12 +29,12 @@ class ConditionalSimplifier : public HloModulePass { public: absl::string_view name() const override { return "simplify-conditional"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; private: - StatusOr TryRemoveConditional(HloInstruction* conditional); + absl::StatusOr TryRemoveConditional(HloInstruction* conditional); }; } // namespace xla diff --git a/third_party/xla/xla/service/conditional_to_select.cc b/third_party/xla/xla/service/conditional_to_select.cc index 81283604072cc2..3b2fd710038300 100644 --- a/third_party/xla/xla/service/conditional_to_select.cc +++ b/third_party/xla/xla/service/conditional_to_select.cc @@ -29,7 +29,7 @@ limitations under the License. namespace xla { -static StatusOr DoConditionalToSelect(HloInstruction* conditional) { +static absl::StatusOr DoConditionalToSelect(HloInstruction* conditional) { // Only allow conditional to select if the called computations // do not have side effects. if (conditional->true_computation()->HasSideEffect() || @@ -66,7 +66,7 @@ static StatusOr DoConditionalToSelect(HloInstruction* conditional) { return true; } -StatusOr ConditionalToSelect::Run( +absl::StatusOr ConditionalToSelect::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { std::unique_ptr call_graph = CallGraph::Build(module); diff --git a/third_party/xla/xla/service/conditional_to_select.h b/third_party/xla/xla/service/conditional_to_select.h index 9b939eafacdc35..cbc9cff571a907 100644 --- a/third_party/xla/xla/service/conditional_to_select.h +++ b/third_party/xla/xla/service/conditional_to_select.h @@ -31,7 +31,7 @@ class ConditionalToSelect : public HloModulePass { // Run conditional to select on the given computation. Returns whether the // computation was changed. using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/third_party/xla/xla/service/constant_value.cc b/third_party/xla/xla/service/constant_value.cc index 2c4ff11d39af10..a5b6c9c30f2f0b 100644 --- a/third_party/xla/xla/service/constant_value.cc +++ b/third_party/xla/xla/service/constant_value.cc @@ -19,10 +19,11 @@ limitations under the License. namespace xla { -StatusOr ConstantValue::FromLiteral(const Literal& literal) { +absl::StatusOr ConstantValue::FromLiteral( + const Literal& literal) { CHECK_EQ(literal.shape().dimensions_size(), 0) << "Expected scalar literal"; return primitive_util::PrimitiveTypeSwitch>( - [&](auto primitive_type_constant) -> StatusOr { + [&](auto primitive_type_constant) -> absl::StatusOr { if constexpr (primitive_util::IsIntegralType(primitive_type_constant)) { return ConstantValue( static_cast( diff --git a/third_party/xla/xla/service/constant_value.h b/third_party/xla/xla/service/constant_value.h index 7b38ecfefc08f3..2a88afc3e1b21c 100644 --- a/third_party/xla/xla/service/constant_value.h +++ b/third_party/xla/xla/service/constant_value.h @@ -53,7 +53,7 @@ class ConstantValue { static ConstantValue GetUnsigned(uint64_t value, int32_t bitwidth) { return ConstantValue(value, bitwidth, /*is_signed=*/false); } - static StatusOr FromLiteral(const Literal& literal); + static absl::StatusOr FromLiteral(const Literal& literal); ConstantValue add(const ConstantValue& other) const { return ConstantValue(value_ + other.value_, bitwidth_, is_signed_); } diff --git a/third_party/xla/xla/service/convert_async_collectives_to_sync.cc b/third_party/xla/xla/service/convert_async_collectives_to_sync.cc index 7718c35de9a927..60ec4a8788f685 100644 --- a/third_party/xla/xla/service/convert_async_collectives_to_sync.cc +++ b/third_party/xla/xla/service/convert_async_collectives_to_sync.cc @@ -30,8 +30,8 @@ limitations under the License. namespace xla { -StatusOr CreateSyncVariant(HloInstruction* async_start, - HloInstruction* async_done) { +absl::StatusOr CreateSyncVariant(HloInstruction* async_start, + HloInstruction* async_done) { HloInstruction* sync_instruction = nullptr; HloComputation* computation = async_start->parent(); @@ -144,7 +144,7 @@ ConvertAsyncCollectivesToSync::ReplaceAsyncInstructionsWithSync( return OkStatus(); } -StatusOr ConvertAsyncCollectivesToSync::RunOnComputation( +absl::StatusOr ConvertAsyncCollectivesToSync::RunOnComputation( HloComputation* computation) { HloModule* module = computation->parent(); std::vector> async_pairs; @@ -193,7 +193,7 @@ StatusOr ConvertAsyncCollectivesToSync::RunOnComputation( return true; } -StatusOr ConvertAsyncCollectivesToSync::Run( +absl::StatusOr ConvertAsyncCollectivesToSync::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { if (!module->has_schedule()) { diff --git a/third_party/xla/xla/service/convert_async_collectives_to_sync.h b/third_party/xla/xla/service/convert_async_collectives_to_sync.h index 0fa3df7ac22c31..2b37c6ee7fa469 100644 --- a/third_party/xla/xla/service/convert_async_collectives_to_sync.h +++ b/third_party/xla/xla/service/convert_async_collectives_to_sync.h @@ -36,7 +36,7 @@ class ConvertAsyncCollectivesToSync : public HloModulePass { } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; @@ -58,7 +58,7 @@ class ConvertAsyncCollectivesToSync : public HloModulePass { "async_collective_name"; private: - StatusOr RunOnComputation(HloComputation* computation); + absl::StatusOr RunOnComputation(HloComputation* computation); HloPredicate is_nop_; }; } // namespace xla diff --git a/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.cc b/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.cc new file mode 100644 index 00000000000000..a0d7887532651a --- /dev/null +++ b/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.cc @@ -0,0 +1,102 @@ +/* Copyright 2024 The OpenXLA Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ + +#include "xla/service/convert_memory_placement_to_internal_annotations.h" + +#include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/host_memory_offload_annotations.h" +#include "xla/side_effect_util.h" +#include "xla/statusor.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" + +namespace xla { + +absl::StatusOr ConvertMemoryPlacementToInternalAnnotations::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + bool changed = false; + for (HloComputation* c : module->MakeNonfusionComputations()) { + for (HloInstruction* instruction : c->MakeInstructionPostOrder()) { + if (instruction->IsCustomCall( + host_memory_offload_annotations::kDevicePlacement)) { + const auto& frontend_attributes = instruction->frontend_attributes(); + const auto it = frontend_attributes.map().find(kXlaBufferPlacementAttr); + if (it == frontend_attributes.map().end()) { + continue; + } + // XLA currently does not differentiate between pinned and unpinned host + // memory. + const bool is_to_host_case = + (it->second == + host_memory_offload_annotations::kMemoryTargetPinnedHost || + it->second == + host_memory_offload_annotations::kMemoryTargetUnpinnedHost); + const bool is_to_device_case = + (it->second == + host_memory_offload_annotations::kMemoryTargetDevice); + if (!is_to_host_case && !is_to_device_case) { + continue; + } + if (is_to_host_case) { + VLOG(1) << "Process forward case: " << instruction->ToString(); + if (instruction->users().size() != 1) { + VLOG(1) << "Skip because of too many users on instruction"; + continue; + } + if (instruction->operand_count() != 1) { + return Internal( + "Custom calls with target %s must have exactly one operand. %s " + "has %d.", + host_memory_offload_annotations::kDevicePlacement, + instruction->name(), instruction->operand_count()); + } + HloInstruction* input = instruction->mutable_operand(0); + TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith( + c->AddInstruction(HloInstruction::CreateCustomCall( + input->shape(), {input}, + host_memory_offload_annotations:: + kMoveToHostCustomCallTarget)))); + TF_RETURN_IF_ERROR( + c->RemoveInstructionAndUnusedOperands(instruction)); + changed = true; + } else if (is_to_device_case) { + VLOG(1) << "Process backward case: " << instruction->ToString(); + HloInstruction* custom_call_operand = instruction->mutable_operand(0); + if (custom_call_operand->users().size() != 1) { + VLOG(1) << "Skip because operand is used by more than one user"; + continue; + } + HloInstruction* new_result = + c->AddInstruction(HloInstruction::CreateCustomCall( + custom_call_operand->shape(), {custom_call_operand}, + host_memory_offload_annotations:: + kMoveToDeviceCustomCallTarget)); + TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(new_result)); + TF_RETURN_IF_ERROR( + c->RemoveInstructionAndUnusedOperands(instruction)); + changed = true; + } + } + } + } + return changed; +} + +} // namespace xla diff --git a/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.h b/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.h new file mode 100644 index 00000000000000..87fff9d715ec86 --- /dev/null +++ b/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.h @@ -0,0 +1,44 @@ +/* Copyright 2024 The OpenXLA Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ + +#ifndef XLA_SERVICE_CONVERT_MEMORY_PLACEMENT_TO_INTERNAL_ANNOTATIONS_H_ +#define XLA_SERVICE_CONVERT_MEMORY_PLACEMENT_TO_INTERNAL_ANNOTATIONS_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "xla/service/hlo_pass_interface.h" + +namespace xla { + +class ConvertMemoryPlacementToInternalAnnotations : public HloModulePass { + public: + ConvertMemoryPlacementToInternalAnnotations() = default; + + absl::string_view name() const override { + return "convert-memory-placement-to-internal-annotations"; + } + using HloPassInterface::Run; + StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_SERVICE_CONVERT_MEMORY_PLACEMENT_TO_INTERNAL_ANNOTATIONS_H_ diff --git a/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations_test.cc b/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations_test.cc new file mode 100644 index 00000000000000..7689a575c35996 --- /dev/null +++ b/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations_test.cc @@ -0,0 +1,483 @@ +/* Copyright 2024 The OpenXLA Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ + +#include "xla/service/convert_memory_placement_to_internal_annotations.h" + +#include +#include +#include +#include +#include + +#include +#include +#include "xla/statusor.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace { + +class ConvertMemoryPlacementToInternalAnnotationsTest : public HloTestBase { + public: + ConvertMemoryPlacementToInternalAnnotationsTest() = default; +}; + +TEST_F(ConvertMemoryPlacementToInternalAnnotationsTest, ConvertPinnedHostTest) { + const char* hlo_string = R"( +HloModule jit_f, entry_computation_layout={(f32[16]{0})->f32[16]{0}} + +region_0.9 { + arg_tuple.10 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}) parameter(0) + get-tuple-element.11 = s32[] get-tuple-element(arg_tuple.10), index=0 + constant.15 = s32[] constant(1) + add.33 = s32[] add(get-tuple-element.11, constant.15) + get-tuple-element.12 = f32[16]{0} get-tuple-element(arg_tuple.10), index=1 + sine.18 = f32[16]{0} sine(get-tuple-element.12) + sine.19 = f32[16]{0} sine(sine.18) + sine.20 = f32[16]{0} sine(sine.19) + get-tuple-element.13 = f32[16,16]{1,0} get-tuple-element(arg_tuple.10), index=2 + custom-call.21 = f32[16]{0} custom-call(sine.19), custom_call_target="annotate_device_placement", frontend_attributes={_xla_buffer_placement="pinned_host"} + reshape.23 = f32[1,16]{1,0} reshape(custom-call.21) + constant.17 = s32[] constant(0) + compare.24 = pred[] compare(get-tuple-element.11, constant.17), direction=LT + constant.16 = s32[] constant(16) + add.25 = s32[] add(get-tuple-element.11, constant.16) + select.26 = s32[] select(compare.24, add.25, get-tuple-element.11) + dynamic-update-slice.27 = f32[16,16]{1,0} dynamic-update-slice(get-tuple-element.13, reshape.23, select.26, constant.17) + get-tuple-element.14 = f32[16,16]{1,0} get-tuple-element(arg_tuple.10), index=3 + custom-call.22 = f32[16]{0} custom-call(sine.20), custom_call_target="annotate_device_placement", frontend_attributes={_xla_buffer_placement="pinned_host"} + reshape.28 = f32[1,16]{1,0} reshape(custom-call.22) + compare.29 = pred[] compare(get-tuple-element.11, constant.17), direction=LT + add.30 = s32[] add(get-tuple-element.11, constant.16) + select.31 = s32[] select(compare.29, add.30, get-tuple-element.11) + dynamic-update-slice.32 = f32[16,16]{1,0} dynamic-update-slice(get-tuple-element.14, reshape.28, select.31, constant.17) + ROOT tuple.34 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(add.33, sine.20, dynamic-update-slice.27, dynamic-update-slice.32) +} + +region_1.35 { + arg_tuple.36 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}) parameter(0) + get-tuple-element.38 = f32[16]{0} get-tuple-element(arg_tuple.36), index=1 + get-tuple-element.39 = f32[16,16]{1,0} get-tuple-element(arg_tuple.36), index=2 + get-tuple-element.40 = f32[16,16]{1,0} get-tuple-element(arg_tuple.36), index=3 + get-tuple-element.37 = s32[] get-tuple-element(arg_tuple.36), index=0 + constant.41 = s32[] constant(16) + ROOT compare.42 = pred[] compare(get-tuple-element.37, constant.41), direction=LT +} + +core_closed_call.43 { + constant.47 = s32[] constant(0) + Arg_0.44 = f32[16]{0} parameter(0) + constant.45 = f32[] constant(0) + broadcast.46 = f32[16,16]{1,0} broadcast(constant.45), dimensions={} + tuple.48 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(constant.47, Arg_0.44, broadcast.46, broadcast.46) + while.49 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}) while(tuple.48), condition=region_1.35, body=region_0.9 + get-tuple-element.50 = s32[] get-tuple-element(while.49), index=0 + get-tuple-element.51 = f32[16]{0} get-tuple-element(while.49), index=1 + get-tuple-element.52 = f32[16,16]{1,0} get-tuple-element(while.49), index=2 + get-tuple-element.53 = f32[16,16]{1,0} get-tuple-element(while.49), index=3 + ROOT tuple.54 = (f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(get-tuple-element.52, get-tuple-element.53) +} + +region_2.65 { + arg_tuple.66 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}, f32[16,16]{1,0}) parameter(0) + get-tuple-element.67 = s32[] get-tuple-element(arg_tuple.66), index=0 + constant.74 = s32[] constant(1) + add.108 = s32[] add(get-tuple-element.67, constant.74) + get-tuple-element.73 = f32[16,16]{1,0} get-tuple-element(arg_tuple.66), index=6 + constant.76 = s32[] constant(0) + compare.82 = pred[] compare(get-tuple-element.67, constant.76), direction=LT + constant.75 = s32[] constant(16) + add.83 = s32[] add(get-tuple-element.67, constant.75) + select.84 = s32[] select(compare.82, add.83, get-tuple-element.67) + dynamic-slice.85 = f32[1,16]{1,0} dynamic-slice(get-tuple-element.73, select.84, constant.76), dynamic_slice_sizes={1,16} + reshape.86 = f32[16]{0} reshape(dynamic-slice.85) + custom-call.87 = f32[16]{0} custom-call(reshape.86), custom_call_target="annotate_device_placement", frontend_attributes={_xla_buffer_placement="device"} + get-tuple-element.69 = f32[16,16]{1,0} get-tuple-element(arg_tuple.66), index=2 + get-tuple-element.68 = f32[16]{0} get-tuple-element(arg_tuple.66), index=1 + cosine.88 = f32[16]{0} cosine(get-tuple-element.68) + reshape.93 = f32[1,16]{1,0} reshape(cosine.88) + compare.94 = pred[] compare(get-tuple-element.67, constant.76), direction=LT + add.95 = s32[] add(get-tuple-element.67, constant.75) + select.96 = s32[] select(compare.94, add.95, get-tuple-element.67) + dynamic-update-slice.97 = f32[16,16]{1,0} dynamic-update-slice(get-tuple-element.69, reshape.93, select.96, constant.76) + get-tuple-element.70 = f32[16,16]{1,0} get-tuple-element(arg_tuple.66), index=3 + sine.89 = f32[16]{0} sine(get-tuple-element.68) + cosine.90 = f32[16]{0} cosine(sine.89) + reshape.98 = f32[1,16]{1,0} reshape(cosine.90) + compare.99 = pred[] compare(get-tuple-element.67, constant.76), direction=LT + add.100 = s32[] add(get-tuple-element.67, constant.75) + select.101 = s32[] select(compare.99, add.100, get-tuple-element.67) + dynamic-update-slice.102 = f32[16,16]{1,0} dynamic-update-slice(get-tuple-element.70, reshape.98, select.101, constant.76) + get-tuple-element.71 = f32[16,16]{1,0} get-tuple-element(arg_tuple.66), index=4 + get-tuple-element.72 = f32[16,16]{1,0} get-tuple-element(arg_tuple.66), index=5 + compare.77 = pred[] compare(get-tuple-element.67, constant.76), direction=LT + add.78 = s32[] add(get-tuple-element.67, constant.75) + select.79 = s32[] select(compare.77, add.78, get-tuple-element.67) + dynamic-slice.80 = f32[1,16]{1,0} dynamic-slice(get-tuple-element.72, select.79, constant.76), dynamic_slice_sizes={1,16} + reshape.81 = f32[16]{0} reshape(dynamic-slice.80) + custom-call.91 = f32[16]{0} custom-call(reshape.81), custom_call_target="annotate_device_placement", frontend_attributes={_xla_buffer_placement="device"} + cosine.92 = f32[16]{0} cosine(custom-call.91) + reshape.103 = f32[1,16]{1,0} reshape(cosine.92) + compare.104 = pred[] compare(get-tuple-element.67, constant.76), direction=LT + add.105 = s32[] add(get-tuple-element.67, constant.75) + select.106 = s32[] select(compare.104, add.105, get-tuple-element.67) + dynamic-update-slice.107 = f32[16,16]{1,0} dynamic-update-slice(get-tuple-element.71, reshape.103, select.106, constant.76) + ROOT tuple.109 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(add.108, custom-call.87, dynamic-update-slice.97, dynamic-update-slice.102, dynamic-update-slice.107, get-tuple-element.72, get-tuple-element.73) +} + +region_3.110 { + arg_tuple.111 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}, f32[16,16]{1,0}) parameter(0) + get-tuple-element.113 = f32[16]{0} get-tuple-element(arg_tuple.111), index=1 + get-tuple-element.114 = f32[16,16]{1,0} get-tuple-element(arg_tuple.111), index=2 + get-tuple-element.115 = f32[16,16]{1,0} get-tuple-element(arg_tuple.111), index=3 + get-tuple-element.116 = f32[16,16]{1,0} get-tuple-element(arg_tuple.111), index=4 + get-tuple-element.117 = f32[16,16]{1,0} get-tuple-element(arg_tuple.111), index=5 + get-tuple-element.118 = f32[16,16]{1,0} get-tuple-element(arg_tuple.111), index=6 + get-tuple-element.112 = s32[] get-tuple-element(arg_tuple.111), index=0 + constant.119 = s32[] constant(16) + ROOT compare.120 = pred[] compare(get-tuple-element.112, constant.119), direction=LT +} + +region_4.130 { + arg_tuple.131 = (s32[], f32[16]{0}, f32[], f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}) parameter(0) + get-tuple-element.132 = s32[] get-tuple-element(arg_tuple.131), index=0 + constant.140 = s32[] constant(1) + add.164 = s32[] add(get-tuple-element.132, constant.140) + get-tuple-element.133 = f32[16]{0} get-tuple-element(arg_tuple.131), index=1 + get-tuple-element.134 = f32[] get-tuple-element(arg_tuple.131), index=2 + broadcast.159 = f32[16]{0} broadcast(get-tuple-element.134), dimensions={} + add.160 = f32[16]{0} add(get-tuple-element.133, broadcast.159) + get-tuple-element.137 = f32[16,16]{1,0} get-tuple-element(arg_tuple.131), index=5 + constant.141 = s32[] constant(16) + subtract.142 = s32[] subtract(constant.141, get-tuple-element.132) + subtract.143 = s32[] subtract(subtract.142, constant.140) + constant.139 = s32[] constant(0) + compare.154 = pred[] compare(subtract.143, constant.139), direction=LT + add.155 = s32[] add(subtract.143, constant.141) + select.156 = s32[] select(compare.154, add.155, subtract.143) + dynamic-slice.157 = f32[1,16]{1,0} dynamic-slice(get-tuple-element.137, select.156, constant.139), dynamic_slice_sizes={1,16} + reshape.158 = f32[16]{0} reshape(dynamic-slice.157) + multiply.161 = f32[16]{0} multiply(add.160, reshape.158) + get-tuple-element.136 = f32[16,16]{1,0} get-tuple-element(arg_tuple.131), index=4 + compare.149 = pred[] compare(subtract.143, constant.139), direction=LT + add.150 = s32[] add(subtract.143, constant.141) + select.151 = s32[] select(compare.149, add.150, subtract.143) + dynamic-slice.152 = f32[1,16]{1,0} dynamic-slice(get-tuple-element.136, select.151, constant.139), dynamic_slice_sizes={1,16} + reshape.153 = f32[16]{0} reshape(dynamic-slice.152) + multiply.162 = f32[16]{0} multiply(multiply.161, reshape.153) + get-tuple-element.135 = f32[16,16]{1,0} get-tuple-element(arg_tuple.131), index=3 + compare.144 = pred[] compare(subtract.143, constant.139), direction=LT + add.145 = s32[] add(subtract.143, constant.141) + select.146 = s32[] select(compare.144, add.145, subtract.143) + dynamic-slice.147 = f32[1,16]{1,0} dynamic-slice(get-tuple-element.135, select.146, constant.139), dynamic_slice_sizes={1,16} + reshape.148 = f32[16]{0} reshape(dynamic-slice.147) + multiply.163 = f32[16]{0} multiply(multiply.162, reshape.148) + constant.138 = f32[] constant(0) + ROOT tuple.165 = (s32[], f32[16]{0}, f32[], f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}) tuple(add.164, multiply.163, constant.138, get-tuple-element.135, get-tuple-element.136, get-tuple-element.137) +} + +region_5.166 { + arg_tuple.167 = (s32[], f32[16]{0}, f32[], f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}) parameter(0) + get-tuple-element.169 = f32[16]{0} get-tuple-element(arg_tuple.167), index=1 + get-tuple-element.170 = f32[] get-tuple-element(arg_tuple.167), index=2 + get-tuple-element.171 = f32[16,16]{1,0} get-tuple-element(arg_tuple.167), index=3 + get-tuple-element.172 = f32[16,16]{1,0} get-tuple-element(arg_tuple.167), index=4 + get-tuple-element.173 = f32[16,16]{1,0} get-tuple-element(arg_tuple.167), index=5 + get-tuple-element.168 = s32[] get-tuple-element(arg_tuple.167), index=0 + constant.174 = s32[] constant(16) + ROOT compare.175 = pred[] compare(get-tuple-element.168, constant.174), direction=LT +} + +ENTRY main.183 { + constant.6 = s32[] constant(0) + Arg_0.1 = f32[16]{0} parameter(0), sharding={devices=[2]<=[2]} + call.55 = (f32[16,16]{1,0}, f32[16,16]{1,0}) call(Arg_0.1), to_apply=core_closed_call.43 + get-tuple-element.56 = f32[16,16]{1,0} get-tuple-element(call.55), index=0 + get-tuple-element.57 = f32[16,16]{1,0} get-tuple-element(call.55), index=1 + constant.7 = f32[] constant(1) + tuple.58 = (f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16]{0}, f32[]) tuple(get-tuple-element.56, get-tuple-element.57, Arg_0.1, constant.7) + opt-barrier.59 = (f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16]{0}, f32[]) opt-barrier(tuple.58) + get-tuple-element.62 = f32[16]{0} get-tuple-element(opt-barrier.59), index=2 + constant.4 = f32[] constant(0) + broadcast.5 = f32[16,16]{1,0} broadcast(constant.4), dimensions={} + get-tuple-element.60 = f32[16,16]{1,0} get-tuple-element(opt-barrier.59), index=0 + get-tuple-element.61 = f32[16,16]{1,0} get-tuple-element(opt-barrier.59), index=1 + tuple.64 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(constant.6, get-tuple-element.62, broadcast.5, broadcast.5, broadcast.5, get-tuple-element.60, get-tuple-element.61) + while.121 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}, f32[16,16]{1,0}) while(tuple.64), condition=region_3.110, body=region_2.65 + get-tuple-element.122 = s32[] get-tuple-element(while.121), index=0 + get-tuple-element.123 = f32[16]{0} get-tuple-element(while.121), index=1 + get-tuple-element.127 = f32[16,16]{1,0} get-tuple-element(while.121), index=5 + get-tuple-element.128 = f32[16,16]{1,0} get-tuple-element(while.121), index=6 + constant.2 = f32[] constant(0) + broadcast.3 = f32[16]{0} broadcast(constant.2), dimensions={} + get-tuple-element.63 = f32[] get-tuple-element(opt-barrier.59), index=3 + get-tuple-element.124 = f32[16,16]{1,0} get-tuple-element(while.121), index=2 + get-tuple-element.125 = f32[16,16]{1,0} get-tuple-element(while.121), index=3 + get-tuple-element.126 = f32[16,16]{1,0} get-tuple-element(while.121), index=4 + tuple.129 = (s32[], f32[16]{0}, f32[], f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}) tuple(constant.6, broadcast.3, get-tuple-element.63, get-tuple-element.124, get-tuple-element.125, get-tuple-element.126) + while.176 = (s32[], f32[16]{0}, f32[], f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}) while(tuple.129), condition=region_5.166, body=region_4.130 + get-tuple-element.177 = s32[] get-tuple-element(while.176), index=0 + ROOT get-tuple-element.178 = f32[16]{0} get-tuple-element(while.176), index=1 + get-tuple-element.179 = f32[] get-tuple-element(while.176), index=2 + get-tuple-element.180 = f32[16,16]{1,0} get-tuple-element(while.176), index=3 + get-tuple-element.181 = f32[16,16]{1,0} get-tuple-element(while.176), index=4 + get-tuple-element.182 = f32[16,16]{1,0} get-tuple-element(while.176), index=5 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + bool changed = + ConvertMemoryPlacementToInternalAnnotations().Run(module.get()).value(); + EXPECT_TRUE(changed); + XLA_VLOG_LINES(1, module->ToString()); + int64_t custom_calls_count = 0; + for (auto* c : module->computations()) { + for (auto* instr : c->instructions()) { + if (instr->IsCustomCall("PipelineForward") || + instr->IsCustomCall("PipelineBackward")) { + ++custom_calls_count; + } + } + } + EXPECT_EQ(custom_calls_count, 4); +} + +TEST_F(ConvertMemoryPlacementToInternalAnnotationsTest, + ConvertUnpinnedHostTest) { + const char* hlo_string = R"( +HloModule jit_f, entry_computation_layout={(f32[16]{0})->f32[16]{0}} + +region_0.9 { + arg_tuple.10 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}) parameter(0) + get-tuple-element.11 = s32[] get-tuple-element(arg_tuple.10), index=0 + constant.15 = s32[] constant(1) + add.33 = s32[] add(get-tuple-element.11, constant.15) + get-tuple-element.12 = f32[16]{0} get-tuple-element(arg_tuple.10), index=1 + sine.18 = f32[16]{0} sine(get-tuple-element.12) + sine.19 = f32[16]{0} sine(sine.18) + sine.20 = f32[16]{0} sine(sine.19) + get-tuple-element.13 = f32[16,16]{1,0} get-tuple-element(arg_tuple.10), index=2 + custom-call.21 = f32[16]{0} custom-call(sine.19), custom_call_target="annotate_device_placement", frontend_attributes={_xla_buffer_placement="unpinned_host"} + reshape.23 = f32[1,16]{1,0} reshape(custom-call.21) + constant.17 = s32[] constant(0) + compare.24 = pred[] compare(get-tuple-element.11, constant.17), direction=LT + constant.16 = s32[] constant(16) + add.25 = s32[] add(get-tuple-element.11, constant.16) + select.26 = s32[] select(compare.24, add.25, get-tuple-element.11) + dynamic-update-slice.27 = f32[16,16]{1,0} dynamic-update-slice(get-tuple-element.13, reshape.23, select.26, constant.17) + get-tuple-element.14 = f32[16,16]{1,0} get-tuple-element(arg_tuple.10), index=3 + custom-call.22 = f32[16]{0} custom-call(sine.20), custom_call_target="annotate_device_placement", frontend_attributes={_xla_buffer_placement="unpinned_host"} + reshape.28 = f32[1,16]{1,0} reshape(custom-call.22) + compare.29 = pred[] compare(get-tuple-element.11, constant.17), direction=LT + add.30 = s32[] add(get-tuple-element.11, constant.16) + select.31 = s32[] select(compare.29, add.30, get-tuple-element.11) + dynamic-update-slice.32 = f32[16,16]{1,0} dynamic-update-slice(get-tuple-element.14, reshape.28, select.31, constant.17) + ROOT tuple.34 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(add.33, sine.20, dynamic-update-slice.27, dynamic-update-slice.32) +} + +region_1.35 { + arg_tuple.36 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}) parameter(0) + get-tuple-element.38 = f32[16]{0} get-tuple-element(arg_tuple.36), index=1 + get-tuple-element.39 = f32[16,16]{1,0} get-tuple-element(arg_tuple.36), index=2 + get-tuple-element.40 = f32[16,16]{1,0} get-tuple-element(arg_tuple.36), index=3 + get-tuple-element.37 = s32[] get-tuple-element(arg_tuple.36), index=0 + constant.41 = s32[] constant(16) + ROOT compare.42 = pred[] compare(get-tuple-element.37, constant.41), direction=LT +} + +core_closed_call.43 { + constant.47 = s32[] constant(0) + Arg_0.44 = f32[16]{0} parameter(0) + constant.45 = f32[] constant(0) + broadcast.46 = f32[16,16]{1,0} broadcast(constant.45), dimensions={} + tuple.48 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(constant.47, Arg_0.44, broadcast.46, broadcast.46) + while.49 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}) while(tuple.48), condition=region_1.35, body=region_0.9 + get-tuple-element.50 = s32[] get-tuple-element(while.49), index=0 + get-tuple-element.51 = f32[16]{0} get-tuple-element(while.49), index=1 + get-tuple-element.52 = f32[16,16]{1,0} get-tuple-element(while.49), index=2 + get-tuple-element.53 = f32[16,16]{1,0} get-tuple-element(while.49), index=3 + ROOT tuple.54 = (f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(get-tuple-element.52, get-tuple-element.53) +} + +region_2.65 { + arg_tuple.66 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}, f32[16,16]{1,0}) parameter(0) + get-tuple-element.67 = s32[] get-tuple-element(arg_tuple.66), index=0 + constant.74 = s32[] constant(1) + add.108 = s32[] add(get-tuple-element.67, constant.74) + get-tuple-element.73 = f32[16,16]{1,0} get-tuple-element(arg_tuple.66), index=6 + constant.76 = s32[] constant(0) + compare.82 = pred[] compare(get-tuple-element.67, constant.76), direction=LT + constant.75 = s32[] constant(16) + add.83 = s32[] add(get-tuple-element.67, constant.75) + select.84 = s32[] select(compare.82, add.83, get-tuple-element.67) + dynamic-slice.85 = f32[1,16]{1,0} dynamic-slice(get-tuple-element.73, select.84, constant.76), dynamic_slice_sizes={1,16} + reshape.86 = f32[16]{0} reshape(dynamic-slice.85) + custom-call.87 = f32[16]{0} custom-call(reshape.86), custom_call_target="annotate_device_placement", frontend_attributes={_xla_buffer_placement="device"} + get-tuple-element.69 = f32[16,16]{1,0} get-tuple-element(arg_tuple.66), index=2 + get-tuple-element.68 = f32[16]{0} get-tuple-element(arg_tuple.66), index=1 + cosine.88 = f32[16]{0} cosine(get-tuple-element.68) + reshape.93 = f32[1,16]{1,0} reshape(cosine.88) + compare.94 = pred[] compare(get-tuple-element.67, constant.76), direction=LT + add.95 = s32[] add(get-tuple-element.67, constant.75) + select.96 = s32[] select(compare.94, add.95, get-tuple-element.67) + dynamic-update-slice.97 = f32[16,16]{1,0} dynamic-update-slice(get-tuple-element.69, reshape.93, select.96, constant.76) + get-tuple-element.70 = f32[16,16]{1,0} get-tuple-element(arg_tuple.66), index=3 + sine.89 = f32[16]{0} sine(get-tuple-element.68) + cosine.90 = f32[16]{0} cosine(sine.89) + reshape.98 = f32[1,16]{1,0} reshape(cosine.90) + compare.99 = pred[] compare(get-tuple-element.67, constant.76), direction=LT + add.100 = s32[] add(get-tuple-element.67, constant.75) + select.101 = s32[] select(compare.99, add.100, get-tuple-element.67) + dynamic-update-slice.102 = f32[16,16]{1,0} dynamic-update-slice(get-tuple-element.70, reshape.98, select.101, constant.76) + get-tuple-element.71 = f32[16,16]{1,0} get-tuple-element(arg_tuple.66), index=4 + get-tuple-element.72 = f32[16,16]{1,0} get-tuple-element(arg_tuple.66), index=5 + compare.77 = pred[] compare(get-tuple-element.67, constant.76), direction=LT + add.78 = s32[] add(get-tuple-element.67, constant.75) + select.79 = s32[] select(compare.77, add.78, get-tuple-element.67) + dynamic-slice.80 = f32[1,16]{1,0} dynamic-slice(get-tuple-element.72, select.79, constant.76), dynamic_slice_sizes={1,16} + reshape.81 = f32[16]{0} reshape(dynamic-slice.80) + custom-call.91 = f32[16]{0} custom-call(reshape.81), custom_call_target="annotate_device_placement", frontend_attributes={_xla_buffer_placement="device"} + cosine.92 = f32[16]{0} cosine(custom-call.91) + reshape.103 = f32[1,16]{1,0} reshape(cosine.92) + compare.104 = pred[] compare(get-tuple-element.67, constant.76), direction=LT + add.105 = s32[] add(get-tuple-element.67, constant.75) + select.106 = s32[] select(compare.104, add.105, get-tuple-element.67) + dynamic-update-slice.107 = f32[16,16]{1,0} dynamic-update-slice(get-tuple-element.71, reshape.103, select.106, constant.76) + ROOT tuple.109 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(add.108, custom-call.87, dynamic-update-slice.97, dynamic-update-slice.102, dynamic-update-slice.107, get-tuple-element.72, get-tuple-element.73) +} + +region_3.110 { + arg_tuple.111 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}, f32[16,16]{1,0}) parameter(0) + get-tuple-element.113 = f32[16]{0} get-tuple-element(arg_tuple.111), index=1 + get-tuple-element.114 = f32[16,16]{1,0} get-tuple-element(arg_tuple.111), index=2 + get-tuple-element.115 = f32[16,16]{1,0} get-tuple-element(arg_tuple.111), index=3 + get-tuple-element.116 = f32[16,16]{1,0} get-tuple-element(arg_tuple.111), index=4 + get-tuple-element.117 = f32[16,16]{1,0} get-tuple-element(arg_tuple.111), index=5 + get-tuple-element.118 = f32[16,16]{1,0} get-tuple-element(arg_tuple.111), index=6 + get-tuple-element.112 = s32[] get-tuple-element(arg_tuple.111), index=0 + constant.119 = s32[] constant(16) + ROOT compare.120 = pred[] compare(get-tuple-element.112, constant.119), direction=LT +} + +region_4.130 { + arg_tuple.131 = (s32[], f32[16]{0}, f32[], f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}) parameter(0) + get-tuple-element.132 = s32[] get-tuple-element(arg_tuple.131), index=0 + constant.140 = s32[] constant(1) + add.164 = s32[] add(get-tuple-element.132, constant.140) + get-tuple-element.133 = f32[16]{0} get-tuple-element(arg_tuple.131), index=1 + get-tuple-element.134 = f32[] get-tuple-element(arg_tuple.131), index=2 + broadcast.159 = f32[16]{0} broadcast(get-tuple-element.134), dimensions={} + add.160 = f32[16]{0} add(get-tuple-element.133, broadcast.159) + get-tuple-element.137 = f32[16,16]{1,0} get-tuple-element(arg_tuple.131), index=5 + constant.141 = s32[] constant(16) + subtract.142 = s32[] subtract(constant.141, get-tuple-element.132) + subtract.143 = s32[] subtract(subtract.142, constant.140) + constant.139 = s32[] constant(0) + compare.154 = pred[] compare(subtract.143, constant.139), direction=LT + add.155 = s32[] add(subtract.143, constant.141) + select.156 = s32[] select(compare.154, add.155, subtract.143) + dynamic-slice.157 = f32[1,16]{1,0} dynamic-slice(get-tuple-element.137, select.156, constant.139), dynamic_slice_sizes={1,16} + reshape.158 = f32[16]{0} reshape(dynamic-slice.157) + multiply.161 = f32[16]{0} multiply(add.160, reshape.158) + get-tuple-element.136 = f32[16,16]{1,0} get-tuple-element(arg_tuple.131), index=4 + compare.149 = pred[] compare(subtract.143, constant.139), direction=LT + add.150 = s32[] add(subtract.143, constant.141) + select.151 = s32[] select(compare.149, add.150, subtract.143) + dynamic-slice.152 = f32[1,16]{1,0} dynamic-slice(get-tuple-element.136, select.151, constant.139), dynamic_slice_sizes={1,16} + reshape.153 = f32[16]{0} reshape(dynamic-slice.152) + multiply.162 = f32[16]{0} multiply(multiply.161, reshape.153) + get-tuple-element.135 = f32[16,16]{1,0} get-tuple-element(arg_tuple.131), index=3 + compare.144 = pred[] compare(subtract.143, constant.139), direction=LT + add.145 = s32[] add(subtract.143, constant.141) + select.146 = s32[] select(compare.144, add.145, subtract.143) + dynamic-slice.147 = f32[1,16]{1,0} dynamic-slice(get-tuple-element.135, select.146, constant.139), dynamic_slice_sizes={1,16} + reshape.148 = f32[16]{0} reshape(dynamic-slice.147) + multiply.163 = f32[16]{0} multiply(multiply.162, reshape.148) + constant.138 = f32[] constant(0) + ROOT tuple.165 = (s32[], f32[16]{0}, f32[], f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}) tuple(add.164, multiply.163, constant.138, get-tuple-element.135, get-tuple-element.136, get-tuple-element.137) +} + +region_5.166 { + arg_tuple.167 = (s32[], f32[16]{0}, f32[], f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}) parameter(0) + get-tuple-element.169 = f32[16]{0} get-tuple-element(arg_tuple.167), index=1 + get-tuple-element.170 = f32[] get-tuple-element(arg_tuple.167), index=2 + get-tuple-element.171 = f32[16,16]{1,0} get-tuple-element(arg_tuple.167), index=3 + get-tuple-element.172 = f32[16,16]{1,0} get-tuple-element(arg_tuple.167), index=4 + get-tuple-element.173 = f32[16,16]{1,0} get-tuple-element(arg_tuple.167), index=5 + get-tuple-element.168 = s32[] get-tuple-element(arg_tuple.167), index=0 + constant.174 = s32[] constant(16) + ROOT compare.175 = pred[] compare(get-tuple-element.168, constant.174), direction=LT +} + +ENTRY main.183 { + constant.6 = s32[] constant(0) + Arg_0.1 = f32[16]{0} parameter(0), sharding={devices=[2]<=[2]} + call.55 = (f32[16,16]{1,0}, f32[16,16]{1,0}) call(Arg_0.1), to_apply=core_closed_call.43 + get-tuple-element.56 = f32[16,16]{1,0} get-tuple-element(call.55), index=0 + get-tuple-element.57 = f32[16,16]{1,0} get-tuple-element(call.55), index=1 + constant.7 = f32[] constant(1) + tuple.58 = (f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16]{0}, f32[]) tuple(get-tuple-element.56, get-tuple-element.57, Arg_0.1, constant.7) + opt-barrier.59 = (f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16]{0}, f32[]) opt-barrier(tuple.58) + get-tuple-element.62 = f32[16]{0} get-tuple-element(opt-barrier.59), index=2 + constant.4 = f32[] constant(0) + broadcast.5 = f32[16,16]{1,0} broadcast(constant.4), dimensions={} + get-tuple-element.60 = f32[16,16]{1,0} get-tuple-element(opt-barrier.59), index=0 + get-tuple-element.61 = f32[16,16]{1,0} get-tuple-element(opt-barrier.59), index=1 + tuple.64 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(constant.6, get-tuple-element.62, broadcast.5, broadcast.5, broadcast.5, get-tuple-element.60, get-tuple-element.61) + while.121 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}, f32[16,16]{1,0}) while(tuple.64), condition=region_3.110, body=region_2.65 + get-tuple-element.122 = s32[] get-tuple-element(while.121), index=0 + get-tuple-element.123 = f32[16]{0} get-tuple-element(while.121), index=1 + get-tuple-element.127 = f32[16,16]{1,0} get-tuple-element(while.121), index=5 + get-tuple-element.128 = f32[16,16]{1,0} get-tuple-element(while.121), index=6 + constant.2 = f32[] constant(0) + broadcast.3 = f32[16]{0} broadcast(constant.2), dimensions={} + get-tuple-element.63 = f32[] get-tuple-element(opt-barrier.59), index=3 + get-tuple-element.124 = f32[16,16]{1,0} get-tuple-element(while.121), index=2 + get-tuple-element.125 = f32[16,16]{1,0} get-tuple-element(while.121), index=3 + get-tuple-element.126 = f32[16,16]{1,0} get-tuple-element(while.121), index=4 + tuple.129 = (s32[], f32[16]{0}, f32[], f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}) tuple(constant.6, broadcast.3, get-tuple-element.63, get-tuple-element.124, get-tuple-element.125, get-tuple-element.126) + while.176 = (s32[], f32[16]{0}, f32[], f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}) while(tuple.129), condition=region_5.166, body=region_4.130 + get-tuple-element.177 = s32[] get-tuple-element(while.176), index=0 + ROOT get-tuple-element.178 = f32[16]{0} get-tuple-element(while.176), index=1 + get-tuple-element.179 = f32[] get-tuple-element(while.176), index=2 + get-tuple-element.180 = f32[16,16]{1,0} get-tuple-element(while.176), index=3 + get-tuple-element.181 = f32[16,16]{1,0} get-tuple-element(while.176), index=4 + get-tuple-element.182 = f32[16,16]{1,0} get-tuple-element(while.176), index=5 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + bool changed = + ConvertMemoryPlacementToInternalAnnotations().Run(module.get()).value(); + EXPECT_TRUE(changed); + XLA_VLOG_LINES(1, module->ToString()); + int64_t custom_calls_count = 0; + for (auto* c : module->computations()) { + for (auto* instr : c->instructions()) { + if (instr->IsCustomCall("PipelineForward") || + instr->IsCustomCall("PipelineBackward")) { + ++custom_calls_count; + } + } + } + EXPECT_EQ(custom_calls_count, 4); +} + +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/service/copy_insertion.cc b/third_party/xla/xla/service/copy_insertion.cc index 2c223ac41cbef7..df47694edb91fc 100644 --- a/third_party/xla/xla/service/copy_insertion.cc +++ b/third_party/xla/xla/service/copy_insertion.cc @@ -833,15 +833,14 @@ class ComputeRelativeLocation { // During live range analysis of results of `branch_0` this function will be // called when entry1 and entry2 are different outputs on `fusion` in // `branch_1`. `fusion` defines two buffers, but `value_definition` in - // LiveRangeRegions::InstructionInfo does not track output index. The + // LiveRangeRegions::InstructionInfo does not track the output index. The // analysis will say that they are not interfering and assign the same - // buffer to both. This will lead to incorrect numerical results during - // runtime. + // buffer to both. // // This check makes sure that outputs of multi-output instructions are // always interfering and can not be combined. It can be a false positive // when entry1 and entry2 correspond to the same output, but we prefer that - // over numerical issues. + // over correctness issues. // // A proper solution would be to track output index in // LiveRangeRegions::InstructionInfo. diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index b04077a80aa4f7..b5e578cf31382c 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -14,7 +14,7 @@ load( "acl_deps", "if_enable_acl", ) -load("@local_tsl//tsl:tsl.bzl", "tf_openmp_copts", "tsl_copts") +load("@local_tsl//tsl:tsl.bzl", "internal_visibility", "tf_openmp_copts", "tsl_copts") load("@local_tsl//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") load( "@local_tsl//tsl/mkl:build_defs.bzl", @@ -25,7 +25,8 @@ load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load(":build_defs.bzl", "runtime_copts") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) @@ -43,7 +44,6 @@ filegroup( "**/*.cc", "**/*.h", ]), - visibility = ["//visibility:public"], ) bool_flag( @@ -56,14 +56,12 @@ config_setting( flag_values = { ":experimental_mlir_gpu": "True", }, - visibility = ["//visibility:public"], ) cc_library( name = "test_header_helper", testonly = True, hdrs = ["test_target_triple_helper.h"], - visibility = ["//visibility:public"], ) # When using mlir based HloLowering, the following utils will sometimes be needed to define used symbols. @@ -107,7 +105,7 @@ filegroup( "runtime_matmul_s32.cc", "runtime_fork_join.cc", ], - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), ) filegroup( @@ -134,14 +132,13 @@ filegroup( "runtime_lightweight_check.h", "runtime_matmul.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility([":friends"]), ) cc_library( name = "cpu_xfeed", srcs = ["cpu_xfeed.cc"], hdrs = ["cpu_xfeed.h"], - visibility = ["//visibility:public"], deps = [ ":cpu_runtime", "//xla:literal", @@ -166,7 +163,6 @@ cc_library( name = "cpu_transfer_manager", srcs = ["cpu_transfer_manager.cc"], hdrs = ["cpu_transfer_manager.h"], - visibility = ["//visibility:public"], deps = [ ":cpu_runtime", ":cpu_xfeed", @@ -182,6 +178,7 @@ cc_library( "//xla/service:generic_transfer_manager", "//xla/service:transfer_manager", "//xla/stream_executor", + "//xla/stream_executor:platform_manager", "//xla/stream_executor/host:host_platform_id", "@com_google_absl//absl/base", "@com_google_absl//absl/types:span", @@ -195,7 +192,6 @@ cc_library( name = "buffer_info_util", srcs = ["buffer_info_util.cc"], hdrs = ["buffer_info_util.h"], - visibility = ["//visibility:public"], deps = [ "//xla:cpu_function_runtime", "//xla/hlo/ir:hlo", @@ -209,7 +205,6 @@ cc_library( srcs = ["cpu_compiler.cc"], hdrs = ["cpu_compiler.h"], copts = tsl_copts(), - visibility = ["//visibility:public"], deps = [ ":buffer_info_util", ":compiler_functor", @@ -426,7 +421,6 @@ cc_library( name = "cpu_compiler", srcs = ["cpu_compiler_registerer.cc"], hdrs = ["cpu_compiler.h"], - visibility = ["//visibility:public"], deps = [ "cpu_compiler_pure", ":executable_proto_cc", @@ -436,6 +430,7 @@ cc_library( "//xla:status", "//xla:statusor", "//xla:util", + "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", "//xla/service:buffer_assignment", @@ -459,21 +454,19 @@ tf_proto_library( protodeps = [ ":xla_framework_proto", "//xla/service:hlo_proto", + "//xla:xla_proto", ], - visibility = ["//visibility:public"], ) tf_proto_library( name = "xla_framework_proto", srcs = ["xla_framework.proto"], cc_api_version = 2, - visibility = ["//visibility:public"], ) cc_library( name = "xla_framework", hdrs = ["xla_framework.h"], - visibility = ["//visibility:public"], deps = [":xla_framework_proto_cc"], ) @@ -485,7 +478,6 @@ cc_library( ":experimental_mlir_gpu_enabled": ["EXPERIMENTAL_MLIR_GPU=1"], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ "//xla:status", "//xla/mlir/backends/cpu/transforms:passes", @@ -536,7 +528,6 @@ cc_library( ], hdrs = ["simple_orc_jit.h"], copts = if_enable_acl(["-DXLA_CPU_USE_ACL=1"]) + tsl_copts(), - visibility = ["//visibility:public"], deps = [ ":compiler_functor", ":cpu_runtime", @@ -583,7 +574,6 @@ cc_library( hdrs = ["runtime_lightweight_check.h"], compatible_with = get_compatible_with_portable(), copts = runtime_copts(), - visibility = ["//visibility:public"], ) cc_library( @@ -595,7 +585,6 @@ cc_library( "runtime_fp16.h", ], copts = runtime_copts(), - visibility = ["//visibility:public"], deps = ["@com_google_absl//absl/base:core_headers"], ) @@ -608,21 +597,18 @@ cc_library( "runtime_pow.h", ], copts = runtime_copts(), - visibility = ["//visibility:public"], deps = ["@com_google_absl//absl/base:core_headers"], ) cc_library( name = "buffer_desc", hdrs = ["buffer_desc.h"], - visibility = ["//visibility:public"], ) cc_library( name = "cpu_executable", srcs = ["cpu_executable.cc"], hdrs = ["cpu_executable.h"], - visibility = ["//visibility:public"], deps = [ ":buffer_desc", ":simple_orc_jit", @@ -675,7 +661,6 @@ cc_library( "ir_emitter.h", ], copts = tsl_copts(), - visibility = ["//visibility:public"], deps = [ ":backend_config_proto_cc", ":cpu_options", @@ -710,6 +695,7 @@ cc_library( "//xla/service/llvm_ir:llvm_type_conversion_util", "//xla/service/llvm_ir:llvm_util", "//xla/service/llvm_ir:loop_emitter", + "//xla/service/llvm_ir:math_ops", "//xla/service/llvm_ir:tuple_ops", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", @@ -733,7 +719,6 @@ cc_library( "target_machine_features.cc", ], hdrs = ["target_machine_features.h"], - visibility = ["//visibility:public"], deps = [ "//xla:cpu_function_runtime", "//xla:shape_util", @@ -748,7 +733,6 @@ cc_library( name = "target_machine_features_fake", testonly = 1, hdrs = ["target_machine_features_fake.h"], - visibility = ["//visibility:public"], deps = [ ":target_machine_features", ], @@ -758,7 +742,6 @@ cc_library( name = "ir_function", srcs = ["ir_function.cc"], hdrs = ["ir_function.h"], - visibility = ["//visibility:public"], deps = [ ":cpu_runtime", ":ir_emission_utils", @@ -779,7 +762,6 @@ cc_library( name = "parallel_loop_emitter", srcs = ["parallel_loop_emitter.cc"], hdrs = ["parallel_loop_emitter.h"], - visibility = ["//visibility:public"], deps = [ ":ir_emission_utils", "//xla/service/llvm_ir:ir_array", @@ -795,7 +777,6 @@ cc_library( name = "tiled_dot_emitter", srcs = ["tiled_dot_emitter.cc"], hdrs = ["tiled_dot_emitter.h"], - visibility = ["//visibility:public"], deps = [ ":vector_support_library", "//xla:xla_data_proto_cc", @@ -813,7 +794,6 @@ cc_library( hdrs = [ "dot_op_emitter.h", ], - visibility = ["//visibility:public"], deps = [ ":backend_config_proto_cc", ":cpu_options", @@ -879,7 +859,6 @@ cc_library( name = "compiler_functor", srcs = ["compiler_functor.cc"], hdrs = ["compiler_functor.h"], - visibility = ["//visibility:public"], deps = [ ":cpu_runtime", ":llvm_ir_runtime", @@ -914,7 +893,6 @@ cc_library( "xfeed_manager.h", ], copts = runtime_copts(), - visibility = ["//visibility:public"], deps = [ ":collectives_interface", ":cpu_executable_run_options", @@ -955,7 +933,6 @@ cc_library( hdrs = [ "llvm_ir_runtime.h", ], - visibility = ["//visibility:public"], deps = [ ":vector_support_library", "//xla/service/llvm_ir:llvm_util", @@ -1183,7 +1160,7 @@ cc_library( compatible_with = get_compatible_with_portable(), copts = runtime_copts(), linkstatic = 1, - visibility = ["//visibility:public"], + visibility = ["//visibility:private"], deps = [ "@com_google_absl//absl/base:core_headers", "@eigen_archive//:eigen3", @@ -1284,7 +1261,7 @@ xla_cc_test( xla_cc_test( name = "cpu_instruction_fusion_test", srcs = ["cpu_instruction_fusion_test.cc"], - tags = ["no_arm64"], + tags = ["no_aarch64"], deps = [ ":cpu_instruction_fusion", "//xla:shape_util", @@ -1317,7 +1294,6 @@ cc_library( name = "cpu_instruction_fusion", srcs = ["cpu_instruction_fusion.cc"], hdrs = ["cpu_instruction_fusion.h"], - visibility = ["//visibility:public"], deps = [ "//xla/hlo/ir:hlo", "//xla/service:fusion_node_indexing_evaluation", @@ -1331,7 +1307,6 @@ cc_library( name = "ir_emission_utils", srcs = ["ir_emission_utils.cc"], hdrs = ["ir_emission_utils.h"], - visibility = ["//visibility:public"], deps = [ ":cpu_runtime", ":target_machine_features", @@ -1358,7 +1333,6 @@ cc_library( name = "cpu_layout_assignment", srcs = ["cpu_layout_assignment.cc"], hdrs = ["cpu_layout_assignment.h"], - visibility = ["//visibility:public"], deps = [ ":dot_op_emitter", ":ir_emission_utils", @@ -1404,7 +1378,6 @@ cc_library( name = "conv_canonicalization", srcs = ["conv_canonicalization.cc"], hdrs = ["conv_canonicalization.h"], - visibility = ["//visibility:public"], deps = [ ":cpu_runtime", ":ir_emission_utils", @@ -1438,7 +1411,6 @@ cc_library( name = "shape_partition", srcs = ["shape_partition.cc"], hdrs = ["shape_partition.h"], - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", ], @@ -1460,7 +1432,6 @@ cc_library( name = "parallel_task_assignment", srcs = ["parallel_task_assignment.cc"], hdrs = ["parallel_task_assignment.h"], - visibility = ["//visibility:public"], deps = [ ":backend_config_proto_cc", ":ir_emission_utils", @@ -1501,7 +1472,6 @@ cc_library( name = "cpu_options", srcs = ["cpu_options.cc"], hdrs = ["cpu_options.h"], - visibility = ["//visibility:public"], deps = [ "//xla/service:hlo_module_config", "@com_google_absl//absl/strings", @@ -1512,7 +1482,6 @@ cc_library( name = "orc_jit_memory_mapper", srcs = ["orc_jit_memory_mapper.cc"], hdrs = ["orc_jit_memory_mapper.h"], - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/synchronization", @@ -1525,7 +1494,6 @@ cc_library( name = "vector_support_library", srcs = ["vector_support_library.cc"], hdrs = ["vector_support_library.h"], - visibility = ["//visibility:public"], deps = [ ":target_machine_features", "//xla:shape_util", @@ -1556,7 +1524,7 @@ xla_cc_test( name = "vectorized_reduce_with_no_vector_registers_test", size = "small", srcs = ["vectorized_reduce_with_no_vector_registers_test.cc"], - tags = ["no_arm64"], + tags = ["no_aarch64"], deps = [ ":cpu_compiler", ":cpu_transfer_manager", @@ -1574,7 +1542,6 @@ cc_library( name = "mlir_emitter", srcs = ["mlir_emitter.cc"], hdrs = ["mlir_emitter.h"], - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla:status", @@ -1599,7 +1566,6 @@ tf_proto_library( name = "backend_config_proto", srcs = ["backend_config.proto"], cc_api_version = 2, - visibility = ["//visibility:public"], ) cc_library( @@ -1610,6 +1576,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":runtime_lightweight_check", + "//xla:literal", "//xla:shape_util", "//xla:status_macros", "//xla:statusor", @@ -1645,6 +1612,7 @@ cc_library( ":onednn_memory_util", ":runtime_lightweight_check", "//xla:executable_run_options", + "//xla:shape_util", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:dynamic_annotations", "@eigen_archive//:eigen3", @@ -1706,21 +1674,29 @@ cc_library( name = "onednn_matmul_rewriter", srcs = ["onednn_matmul_rewriter.cc"], hdrs = [ + "onednn_matmul.h", "onednn_matmul_rewriter.h", "onednn_util.h", + "@local_tsl//tsl/util:onednn_util_hdrs", ], copts = tsl_copts(), - visibility = ["//visibility:public"], deps = [ ":backend_config_proto_cc", + ":onednn_matmul", ":onednn_memory_util", + "//xla:executable_run_options", + "//xla:shape_util", "//xla:status_macros", "//xla:xla_data_proto_cc", + "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", "//xla/service:hlo_creation_utils", "//xla/service:hlo_pass", "//xla/service:pattern_matcher", "@com_google_absl//absl/algorithm:container", + "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:blocking_counter", + "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:platform_port", ] + mkl_deps(), @@ -1734,7 +1710,6 @@ cc_library( "onednn_util.h", ], copts = tsl_copts(), - visibility = ["//visibility:public"], deps = [ ":backend_config_proto_cc", ":onednn_memory_util", @@ -1754,7 +1729,6 @@ cc_library( srcs = ["cpu_float_support.cc"], hdrs = ["cpu_float_support.h"], copts = tsl_copts(), - visibility = ["//visibility:public"], deps = [ ":onednn_matmul_rewriter", "//xla/service:float_support", @@ -1764,7 +1738,6 @@ cc_library( cc_library( name = "cpu_symbol_repository", hdrs = ["cpu_symbol_repository.h"], - visibility = ["//visibility:public"], deps = [ "//xla:xla_proto_cc", "//xla/service:symbol_repository", @@ -1774,7 +1747,6 @@ cc_library( cc_library( name = "collectives_interface", hdrs = ["collectives_interface.h"], - visibility = ["//visibility:public"], deps = [ "//xla:xla_data_proto_cc", "//xla/service:collective_ops_utils", @@ -1790,7 +1762,6 @@ cc_library( name = "in_process_collectives", srcs = ["in_process_collectives.cc"], hdrs = ["in_process_collectives.h"], - visibility = ["//visibility:public"], deps = [ ":collectives_interface", "//xla:refcounting_hash_map", @@ -1813,6 +1784,5 @@ cc_library( cc_library( name = "cpu_executable_run_options", hdrs = ["cpu_executable_run_options.h"], - visibility = ["//visibility:public"], deps = [":collectives_interface"], ) diff --git a/third_party/xla/xla/service/cpu/backend_config.proto b/third_party/xla/xla/service/cpu/backend_config.proto index bb4c9a390b4d3b..e32159c8044c5e 100644 --- a/third_party/xla/xla/service/cpu/backend_config.proto +++ b/third_party/xla/xla/service/cpu/backend_config.proto @@ -26,9 +26,13 @@ message OneDnnMatMulConfig { GELU_ERF = 4; GELU_TANH = 5; BINARY_ADD = 6; + LINEAR = 7; } repeated FusionKind fused_ops = 3; bool bias_broadcast = 4; + // To avoid protobuf failures for specific decimal values, + // the original float value alpha is type-casted to int32. + int32 alpha_typecast = 5; } message OneDnnLayerNormConfig { diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index 46c7044001ee4f..5ab9c75f00e1dd 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -354,20 +354,6 @@ se::Platform::Id CpuAotCompilationOptions::PlatformId() const { return se::host::kHostPlatformId; } -CpuXlaRuntimeAotCompilationResult::CpuXlaRuntimeAotCompilationResult( - HloModuleProto hlo, std::string_view obj_file, std::string_view mlir_module, - const XlaFrameworkMapping& xla_framework_mapping) { - XlaRuntimeExecutableProto xla_runtime_executable; - *xla_runtime_executable.mutable_hlo_module_proto() = hlo; - xla_runtime_executable.set_obj_file(std::string(obj_file)); - xla_runtime_executable.set_mlir_module(std::string(mlir_module)); - - *xla_runtime_cpu_executable_.mutable_xla_runtime_executable() = - xla_runtime_executable; - *xla_runtime_cpu_executable_.mutable_xla_framework_mapping() = - xla_framework_mapping.ToProto(); -} - namespace { namespace runtime = ::xla::runtime; @@ -394,7 +380,7 @@ class FlattenTuplesAndBufferizeTypeConverter : public mlir::TypeConverter { }; runtime::JitExecutable::Options GetXlaRuntimeJitExecutableOptions( - const HloModule& module) { + const HloModule& module, mlir::DialectRegistry* custom_registry) { runtime::CpuPipelineOptions copts; runtime::JitExecutable::Options opts; copts.xla_cpu_sparse_cuda_threads = @@ -403,10 +389,13 @@ runtime::JitExecutable::Options GetXlaRuntimeJitExecutableOptions( options::ExperimentalOverriddenPipeline(module.config()); opts.specialization = runtime::JitExecutable::Specialization::kDisabled; opts.compiler.register_dialects = - [](xla::runtime::DialectRegistry& dialects) { + [custom_registry](xla::runtime::DialectRegistry& dialects) { dialects->insert(); runtime::RegisterDefaultXlaCpuRuntimeDialects(dialects); RegisterHloXlaRuntimePipelineDialects(*dialects); + if (custom_registry) { + custom_registry->appendTo(*dialects); + } }; opts.compiler.symbols_binding = runtime::ToSymbolsBinding( [](runtime::DirectCustomCallRegistry& registry) { @@ -461,47 +450,24 @@ runtime::JitExecutable::Options GetXlaRuntimeJitExecutableOptions( } // namespace -StatusOr> -CpuXlaRuntimeAotCompilationResult::LoadExecutable( - Compiler* compiler, const se::StreamExecutor* executor) const { - XlaRuntimeExecutableProto xla_runtime_executable = - xla_runtime_cpu_executable_.xla_runtime_executable(); - TF_ASSIGN_OR_RETURN(HloModuleConfig hlo_module_config, - HloModule::CreateModuleConfigFromProto( - xla_runtime_executable.hlo_module_proto(), - GetDebugOptionsFromFlags())); - TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_module, - HloModule::CreateFromProto(xla_runtime_executable.hlo_module_proto(), - hlo_module_config)); - - XlaFrameworkMapping xla_framework_mapping; - xla_framework_mapping.FromProto( - xla_runtime_cpu_executable_.xla_framework_mapping()); - - TF_ASSIGN_OR_RETURN(std::unique_ptr buffer_assignment, - compiler->AssignBuffers(hlo_module.get(), executor)); - - // TODO(b/232263665): JitOptions should be used only for JIT case because it - // has details irrelevant to AOT. - runtime::JitExecutable::Options opts = - GetXlaRuntimeJitExecutableOptions(*hlo_module); - - return CpuExecutable::LoadFromObjFile( - std::move(hlo_module), xla_runtime_executable.obj_file(), - xla_runtime_executable.mlir_module(), std::move(buffer_assignment), - xla_framework_mapping, opts); -} - CpuAotCompilationResult::CpuAotCompilationResult( ObjectFileData object_file_data, std::vector buffer_infos, - int64_t result_buffer_index, + int64_t result_buffer_index, std::unique_ptr module, std::unique_ptr hlo_profile_printer_data) : object_file_data_(std::move(object_file_data)), buffer_infos_(std::move(buffer_infos)), result_buffer_index_(result_buffer_index), + module_(std::move(module)), hlo_profile_printer_data_(std::move(hlo_profile_printer_data)) {} +const HloModule* CpuAotCompilationResult::optimized_module() const { + return module_.get(); +} + +std::unique_ptr CpuAotCompilationResult::consume_optimized_module() { + return std::move(module_); +} + CpuCompiler::CpuCompiler(bool allow_sparse_shapes) : allow_sparse_shapes_(allow_sparse_shapes) { // Initialize LLVM the first time the CpuCompiler is initialized. @@ -678,9 +644,9 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( } { - // Int4Packer must be run before the rest of the pipeline since it modifies - // the layout of the entry computation inputs/outputs, which is passed to - // LayoutAssignment. + // Int4Packer must be run before the rest of the pipeline since it + // modifies the layout of the entry computation inputs/outputs, which is + // passed to LayoutAssignment. HloPassPipeline int4_packer_pipeline("Int4Packer pipeline"); int4_packer_pipeline.AddPass( SubByteNormalization::SET_ELEMENT_SIZE); @@ -708,6 +674,12 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(); pipeline.AddPass(); + // The TopkDecomposer generates a compare op with type=TOTALORDER and must + // run before the ComparisonExpander which rewrites such comparisons. + pipeline.AddPass([&](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kTopK; + }); + pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); @@ -854,9 +826,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(); }(); pipeline.AddPass(); - pipeline.AddPass([&](const HloInstruction* instr) { - return instr->opcode() == HloOpcode::kTopK; - }); // XLA lowers topk to a libcall while the MLIR based pipeline does not yet // support libcalls. Disable this for now. @@ -903,7 +872,8 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( Status CpuCompiler::RunHloPassesAfterLayoutAssn( HloModule* module, bool is_aot_compile, - LLVMTargetMachineFeatures* target_machine_features, bool is_mlir_compile) { + LLVMTargetMachineFeatures* target_machine_features, + const CompileOptions& compile_options, bool is_mlir_compile) { HloPassPipeline pipeline("HLO passes after layout assignment"); // CopyInsertion is still needed by BufferAssignment. MLIR passes will handle @@ -929,18 +899,22 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn( pipeline.AddPass(); + const int max_parallelism = + module->config().intra_op_parallelism_threads() > 0 + ? module->config().intra_op_parallelism_threads() + : tsl::port::NumSchedulableCPUs(); + #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) // AOT compiled code runs in single thread. if (!is_aot_compile) { // Run SimplifyFPConversions pass to simplify the BF16 pattern and make it // easier to match. - pipeline.AddPass( - SimplifyFPConversions::Scope::kSimplifyAllConversions); - pipeline.AddPass(); + pipeline.AddPass(); + pipeline.AddPass(max_parallelism, + compile_options.thread_pool); // Run SimplifyFPConversions pass again to remove redundant Convert ops // that may exist as a result of running OneDnnMatMulRewriter pass. - pipeline.AddPass( - SimplifyFPConversions::Scope::kSimplifyAllConversions); + pipeline.AddPass(); } #endif // INTEL_MKL && ENABLE_ONEDNN_V3 @@ -971,10 +945,6 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn( }(); // Outline ops in the entry computation into calls to subcomputations. - const int max_parallelism = - module->config().intra_op_parallelism_threads() > 0 - ? module->config().intra_op_parallelism_threads() - : tsl::port::NumSchedulableCPUs(); if (!is_aot_compile) { // Run ParallelTaskAssigner to assign parallel tasks to HLOs in module. // Note this is not run for AOT because it would bring in thread pool @@ -999,13 +969,15 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn( Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, llvm::TargetMachine* target_machine, + const CompileOptions& compile_options, bool is_mlir_compile) { LLVMTargetMachineFeatures target_machine_features(target_machine); TF_RETURN_IF_ERROR(RunHloPassesThroughLayoutAssn( module, is_aot_compile, &target_machine_features, is_mlir_compile)); return RunHloPassesAfterLayoutAssn(module, is_aot_compile, - &target_machine_features, is_mlir_compile); + &target_machine_features, compile_options, + is_mlir_compile); } namespace { @@ -1120,7 +1092,7 @@ Status CreateHloProfilingArtifacts( StatusOr> CpuCompiler::RunHloPasses( std::unique_ptr module, se::StreamExecutor* /*stream_exec*/, - const CompileOptions& /*options*/) { + const CompileOptions& options) { std::unique_ptr jit_target_machine = SimpleOrcJIT::InferTargetMachineForJIT( CompilerTargetOptions(module->config()), @@ -1128,6 +1100,7 @@ StatusOr> CpuCompiler::RunHloPasses( TF_RETURN_IF_ERROR(RunHloPasses( module.get(), /*is_aot_compile=*/false, jit_target_machine.get(), + /*compile_options=*/options, /*is_mlir_compile=*/ module->config().debug_options().xla_cpu_use_xla_runtime())); return std::move(module); @@ -1395,8 +1368,7 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { post_optimization_ir_hook, CreateOrcJITPostCompilationHook(module.get(), &obj_files)); if (!jit) { - return Internal("Creating JIT failed: %s", - llvm::toString(jit.takeError())); + return Internal("Creating JIT failed: %s", llvm::toString(jit.takeError())); } llvm_module->setDataLayout((*jit)->data_layout()); llvm_module->setTargetTriple((*jit)->target_triple().getTriple()); @@ -1531,16 +1503,17 @@ namespace { StatusOr> GetXlaRuntimeCpuExecutable( const HloModule& hlo_module, mlir::ModuleOp mlir_module, absl::string_view entry_point, - const XlaFrameworkMapping& xla_framework_mapping) { + const XlaFrameworkMapping& xla_framework_mapping, + mlir::DialectRegistry* registry) { runtime::JitExecutable::Options opts = - GetXlaRuntimeJitExecutableOptions(hlo_module); + GetXlaRuntimeJitExecutableOptions(hlo_module, registry); std::string serialized_mlir = llvm_ir::DumpToString(mlir_module); absl::StatusOr jit_executable = runtime::JitExecutable::Instantiate(serialized_mlir, entry_point, opts); if (!jit_executable.ok()) { return Internal("Failed to compile XLA Runtime program: %s", - jit_executable.status().message()); + jit_executable.status().message()); } return std::make_unique( @@ -1551,7 +1524,7 @@ StatusOr> GetXlaRuntimeCpuExecutable( StatusOr> CpuCompiler::CompileXlaRuntimeCpuExecutable( - std::unique_ptr hlo_module) { + std::unique_ptr hlo_module, mlir::DialectRegistry* registry) { // Select an order for emitting the HLO instructions for each // computation. Using this sequence enables tighter buffer liveness analysis // and reduced memory usage (as compared to using DependencyHloOrdering). @@ -1587,6 +1560,9 @@ CpuCompiler::CompileXlaRuntimeCpuExecutable( } mlir::MLIRContext mlir_context; + if (registry) { + mlir_context.appendDialectRegistry(*registry); + } XlaFrameworkMapping xla_framework_mapping; TF_ASSIGN_OR_RETURN( auto mlir_module, @@ -1596,7 +1572,7 @@ CpuCompiler::CompileXlaRuntimeCpuExecutable( TF_ASSIGN_OR_RETURN( auto xla_runtime_executable, GetXlaRuntimeCpuExecutable(*hlo_module, *mlir_module, "main", - xla_framework_mapping)); + xla_framework_mapping, registry)); if (DumpingEnabledForHloModule(*hlo_module)) { TF_ASSIGN_OR_RETURN(std::string_view obj_file, @@ -1614,7 +1590,7 @@ CpuCompiler::CompileXlaRuntimeCpuExecutable( StatusOr> CpuCompiler::RunBackend( std::unique_ptr module, [[maybe_unused]] se::StreamExecutor* stream_exec, - [[maybe_unused]] const CompileOptions& options) { + const CompileOptions& options) { VLOG(1) << "Compiling: " << module->name(); XLA_SCOPED_LOGGING_TIMER( absl::StrFormat("Compiling [%s] for CPU using JIT", module->name())); @@ -1627,8 +1603,9 @@ StatusOr> CpuCompiler::RunBackend( std::unique_ptr cpu_executable; if (module->config().debug_options().xla_cpu_use_xla_runtime()) { - TF_ASSIGN_OR_RETURN(cpu_executable, - CompileXlaRuntimeCpuExecutable(std::move(module))); + TF_ASSIGN_OR_RETURN( + cpu_executable, + CompileXlaRuntimeCpuExecutable(std::move(module), options.registry)); } else { TF_ASSIGN_OR_RETURN(cpu_executable, CompileLegacyCpuExecutable(std::move(module))); @@ -1741,6 +1718,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, TF_RETURN_IF_ERROR( RunHloPasses(module, /*is_aot_compile=*/true, target_machine.get(), + /*dummy*/ CompileOptions{}, /*is_mlir_compile=*/options.use_mlir_hlo_lowering())); TF_ASSIGN_OR_RETURN(HloSchedule schedule, @@ -1904,7 +1882,8 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, results.emplace_back(std::make_unique( std::move(object_file_data), std::move(buffer_infos), - result_slice.index(), std::move(hlo_profile_printer_data))); + result_slice.index(), std::move(modules[i]), + std::move(hlo_profile_printer_data))); } VLOG(1) << "Compilation finished"; @@ -1939,10 +1918,13 @@ class CpuExecutableAotCompilationResult : public AotCompilationResult { const BufferAssignment* buffer_assignment, std::string_view function_name, std::string_view obj_file) { - *proto_.mutable_hlo_module() = hlo_module->ToProto(); + *proto_.mutable_hlo_module()->mutable_hlo_module() = hlo_module->ToProto(); *proto_.mutable_buffer_assignment() = buffer_assignment->ToProto(); proto_.set_entry_function_name(std::string(function_name)); proto_.set_obj_file(std::string(obj_file)); + *proto_.mutable_hlo_module()->mutable_config() = + *hlo_module->config().ToProto(); + module_ = hlo_module->Clone(); } StatusOr SerializeAsString() const override { @@ -1956,18 +1938,31 @@ class CpuExecutableAotCompilationResult : public AotCompilationResult { return Internal( "Failed to parse serialized CpuExecutableAotCompilationResult."); } + + TF_ASSIGN_OR_RETURN( + std::unique_ptr module, + HloModule::CreateFromProtoWithConfig(proto.hlo_module())); + return std::unique_ptr( - new CpuExecutableAotCompilationResult(proto)); + new CpuExecutableAotCompilationResult(proto, std::move(module))); } StatusOr> LoadExecutable( Compiler* compiler, const se::StreamExecutor* stream_exec) const override; + const HloModule* optimized_module() const override { return module_.get(); } + + std::unique_ptr consume_optimized_module() override { + return std::move(module_); + } + private: - explicit CpuExecutableAotCompilationResult(CompilationResultProto proto) - : proto_(std::move(proto)) {} + explicit CpuExecutableAotCompilationResult(CompilationResultProto proto, + std::unique_ptr module) + : proto_(std::move(proto)), module_(std::move(module)) {} CompilationResultProto proto_; + std::unique_ptr module_; }; } // namespace @@ -1976,12 +1971,9 @@ StatusOr> CpuExecutableAotCompilationResult::LoadExecutable( Compiler* compiler, const se::StreamExecutor* stream_exec) const { // Recreate HloModule from proto. - TF_ASSIGN_OR_RETURN(HloModuleConfig hlo_module_config, - HloModule::CreateModuleConfigFromProto( - proto_.hlo_module(), GetDebugOptionsFromFlags())); TF_ASSIGN_OR_RETURN( std::unique_ptr module, - HloModule::CreateFromProto(proto_.hlo_module(), hlo_module_config)); + HloModule::CreateFromProtoWithConfig(proto_.hlo_module())); // Recreate BufferAssignment from proto. TF_ASSIGN_OR_RETURN( @@ -2000,8 +1992,7 @@ CpuExecutableAotCompilationResult::LoadExecutable( /*pre_optimization_hook=*/nullptr, /*post_optimization_hook=*/nullptr, /*post_codegen_hook=*/nullptr); if (!jit) { - return Internal("Creating JIT failed: %s", - llvm::toString(jit.takeError())); + return Internal("Creating JIT failed: %s", llvm::toString(jit.takeError())); } // Create a named buffer from compiled object file. diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.h b/third_party/xla/xla/service/cpu/cpu_compiler.h index 4f45dbce2a15f6..e28c9ed34cc8a2 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.h +++ b/third_party/xla/xla/service/cpu/cpu_compiler.h @@ -41,6 +41,10 @@ limitations under the License. #include "xla/stream_executor/stream_executor.h" #include "xla/util.h" +namespace mlir { +class DialectRegistry; +} // namespace mlir + namespace xla { namespace cpu { @@ -96,44 +100,12 @@ class CpuAotCompilationOptions : public AotCompilationOptions { bool use_mlir_hlo_lowering_ = false; }; -class CpuXlaRuntimeAotCompilationResult : public AotCompilationResult { - public: - CpuXlaRuntimeAotCompilationResult( - HloModuleProto hlo, std::string_view obj_file, - std::string_view mlir_module, - const XlaFrameworkMapping& xla_framework_mapping); - - explicit CpuXlaRuntimeAotCompilationResult( - XlaRuntimeCpuExecutableProto executable) - : xla_runtime_cpu_executable_(executable) {} - - StatusOr SerializeAsString() const override { - return xla_runtime_cpu_executable_.SerializeAsString(); - } - - static StatusOr> - FromString(const std::string& serialized) { - XlaRuntimeCpuExecutableProto xla_runtime_cpu_executable; - if (!xla_runtime_cpu_executable.ParseFromString(serialized)) { - return Internal("Failed to parse serialized JitRtExecutableProto."); - } - return std::make_unique( - xla_runtime_cpu_executable); - } - - StatusOr> LoadExecutable( - Compiler* compiler, const se::StreamExecutor* executor) const override; - - private: - XlaRuntimeCpuExecutableProto xla_runtime_cpu_executable_; -}; - class CpuAotCompilationResult : public AotCompilationResult { public: CpuAotCompilationResult( ObjectFileData object_file_data, std::vector buffer_infos, - int64_t result_buffer_index, + int64_t result_buffer_index, std::unique_ptr module, std::unique_ptr hlo_profile_printer_data); ~CpuAotCompilationResult() override = default; @@ -147,6 +119,9 @@ class CpuAotCompilationResult : public AotCompilationResult { } int64_t result_buffer_index() const { return result_buffer_index_; } + const HloModule* optimized_module() const override; + std::unique_ptr consume_optimized_module() override; + private: // Contains the compiled computation: an object file. const ObjectFileData object_file_data_; @@ -160,6 +135,9 @@ class CpuAotCompilationResult : public AotCompilationResult { // parameter when calling the compiled computation. const int64_t result_buffer_index_; + // Contains the optimized HLO module. + std::unique_ptr module_; + // Contains an instance of HloProfilePrinterData if HLO profiling is enabled, // otherwise is nullptr. std::unique_ptr hlo_profile_printer_data_; @@ -208,8 +186,12 @@ class CpuCompiler : public LLVMCompiler { StatusOr> LoadAotCompilationResult( const std::string& serialized_aot_result) override; + // The optional `registry` supports MLIR dialects and plugins to be loaded + // during optimization. If non-null, it will be used to construct relevant + // MLIR contexts. StatusOr> CompileXlaRuntimeCpuExecutable( - std::unique_ptr module); + std::unique_ptr module, + mlir::DialectRegistry* registry = nullptr); private: // Initialize the LLVM target. @@ -219,6 +201,7 @@ class CpuCompiler : public LLVMCompiler { // correctness. Status RunHloPasses(HloModule* module, bool is_aot_compile, llvm::TargetMachine* target_machine, + const CompileOptions& compile_options, bool is_mlir_compile = false); // Runs HLO passes up to and including layout assignment. @@ -230,7 +213,8 @@ class CpuCompiler : public LLVMCompiler { // Runs HLO passes after layout assignment. Status RunHloPassesAfterLayoutAssn( HloModule* module, bool is_aot_compile, - LLVMTargetMachineFeatures* target_machine_features, bool is_mlir_compile); + LLVMTargetMachineFeatures* target_machine_features, + const CompileOptions& compile_options, bool is_mlir_compile); StatusOr> CompileLegacyCpuExecutable( std::unique_ptr module); diff --git a/third_party/xla/xla/service/cpu/cpu_runtime.cc b/third_party/xla/xla/service/cpu/cpu_runtime.cc index 773a7e9452d4ad..789a8d86fc01b8 100644 --- a/third_party/xla/xla/service/cpu/cpu_runtime.cc +++ b/third_party/xla/xla/service/cpu/cpu_runtime.cc @@ -164,6 +164,8 @@ extern const char* const kOneDnnSoftmaxSymbolName = "__xla_cpu_runtime_OneDnnSoftmax"; extern const char* const kOneDnnLayerNormSymbolName = "__xla_cpu_runtime_OneDnnLayerNorm"; +extern const char* const kOneDnnMatMulReorderSymbolName = + "__xla_cpu_runtime_OneDnnMatMulReorder"; namespace { diff --git a/third_party/xla/xla/service/cpu/cpu_runtime.h b/third_party/xla/xla/service/cpu/cpu_runtime.h index 894559147f8cba..e4fc06fc85bd5a 100644 --- a/third_party/xla/xla/service/cpu/cpu_runtime.h +++ b/third_party/xla/xla/service/cpu/cpu_runtime.h @@ -89,6 +89,7 @@ extern const char* const kReduceScatterSymbolName; extern const char* const kOneDnnMatMulSymbolName; extern const char* const kOneDnnSoftmaxSymbolName; extern const char* const kOneDnnLayerNormSymbolName; +extern const char* const kOneDnnMatMulReorderSymbolName; // All symbol names for XLA CPU runtime functions need to start with this // prefix. diff --git a/third_party/xla/xla/service/cpu/cpu_transfer_manager.cc b/third_party/xla/xla/service/cpu/cpu_transfer_manager.cc index 101b3dde21d743..27e62fc1af723d 100644 --- a/third_party/xla/xla/service/cpu/cpu_transfer_manager.cc +++ b/third_party/xla/xla/service/cpu/cpu_transfer_manager.cc @@ -30,6 +30,7 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/statusor.h" #include "xla/stream_executor/host/host_platform_id.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/types.h" #include "xla/util.h" @@ -61,7 +62,7 @@ Status CpuTransferManager::ReadDynamicShapes(se::Stream* stream, device_shape); } TF_ASSIGN_OR_RETURN(auto platform, - se::MultiPlatformManager::PlatformWithId(PlatformId())); + se::PlatformManager::PlatformWithId(PlatformId())); TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform)); return ReadDynamicShapesOnCpu(device_buffer, device_shape, compiler->ShapeSizeBytesFunction()); diff --git a/third_party/xla/xla/service/cpu/dot_op_emitter.cc b/third_party/xla/xla/service/cpu/dot_op_emitter.cc index e746726a3dd4fa..c6bad1e54ec045 100644 --- a/third_party/xla/xla/service/cpu/dot_op_emitter.cc +++ b/third_party/xla/xla/service/cpu/dot_op_emitter.cc @@ -326,7 +326,7 @@ Status DotOpEmitter::EmitLinalgMatmul() { /*outputs=*/mlir::ValueRange{a}, /*indexingMaps=*/ mlir::AffineMap::inferFromExprList( - {b_exprs, c_exprs, parallel_exprs}), + {b_exprs, c_exprs, parallel_exprs}, context), /*iteratorTypes=*/iteratorTypes, [](mlir::OpBuilder& b, mlir::Location loc, mlir::ValueRange args) { mlir::ArithBuilder ab(b, loc); diff --git a/third_party/xla/xla/service/cpu/elemental_ir_emitter.cc b/third_party/xla/xla/service/cpu/elemental_ir_emitter.cc index 265f6a962e2eb6..6d1980208fc1dc 100644 --- a/third_party/xla/xla/service/cpu/elemental_ir_emitter.cc +++ b/third_party/xla/xla/service/cpu/elemental_ir_emitter.cc @@ -24,6 +24,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/llvm_ir/llvm_util.h" +#include "xla/service/llvm_ir/math_ops.h" #include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -105,5 +106,32 @@ StatusOr CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, return result; } +StatusOr CpuElementalIrEmitter::EmitErf(PrimitiveType prim_type, + llvm::Value* value) { + if (prim_type == F64) { + std::string function_name = "erf"; + // Create a function declaration. + llvm::Function* function = llvm::dyn_cast( + module() + ->getOrInsertFunction(function_name, value->getType(), + value->getType()) + .getCallee()); + function->setCallingConv(llvm::CallingConv::C); + function->setDoesNotThrow(); + function->setDoesNotAccessMemory(); + // Create an instruction to call the function. + llvm::Value* result = Call(function, value); + return result; + } + // Upcast F16 to F32 if necessary. + llvm::Type* type = prim_type == F16 ? b()->getFloatTy() : value->getType(); + if (type == b()->getFloatTy()) { + llvm::Value* x = FPCast(value, type); + auto* result = llvm_ir::EmitErfF32(b(), x); + return FPCast(result, value->getType()); + } + return Unimplemented("erf"); +} + } // namespace cpu } // namespace xla diff --git a/third_party/xla/xla/service/cpu/elemental_ir_emitter.h b/third_party/xla/xla/service/cpu/elemental_ir_emitter.h index 8a87f4ba69f38a..c4a3a0d291e648 100644 --- a/third_party/xla/xla/service/cpu/elemental_ir_emitter.h +++ b/third_party/xla/xla/service/cpu/elemental_ir_emitter.h @@ -41,6 +41,8 @@ class CpuElementalIrEmitter : public ElementalIrEmitter { absl::string_view name) override; StatusOr EmitTanh(PrimitiveType prim_type, llvm::Value* value) override; + StatusOr EmitErf(PrimitiveType prim_type, + llvm::Value* value) override; StatusOr> EmitThreadLocalCall( const HloComputation& callee, absl::Span parameters, diff --git a/third_party/xla/xla/service/cpu/executable.proto b/third_party/xla/xla/service/cpu/executable.proto index db4ad8a54d1765..2c48a51f4b0430 100644 --- a/third_party/xla/xla/service/cpu/executable.proto +++ b/third_party/xla/xla/service/cpu/executable.proto @@ -19,6 +19,7 @@ package xla.cpu; import "xla/service/cpu/xla_framework.proto"; import "xla/service/hlo.proto"; +import "xla/xla.proto"; message XlaRuntimeCpuExecutableProto { optional XlaRuntimeExecutableProto xla_runtime_executable = 1; @@ -26,7 +27,7 @@ message XlaRuntimeCpuExecutableProto { } message CompilationResultProto { - HloModuleProto hlo_module = 1; + HloModuleProtoWithConfig hlo_module = 1; BufferAssignmentProto buffer_assignment = 2; string entry_function_name = 3; bytes obj_file = 4; diff --git a/third_party/xla/xla/service/cpu/ir_emitter.cc b/third_party/xla/xla/service/cpu/ir_emitter.cc index e49e99f314e816..55cec03528fa34 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter.cc @@ -2517,7 +2517,8 @@ Status IrEmitter::HandleTopK(HloInstruction* hlo) { } #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) -Status IrEmitter::HandleOneDnnMatMul(HloInstruction* custom_call) { +Status IrEmitter::HandleOneDnnMatMulCalls(HloInstruction* custom_call, + std::string runtime_symbol_name) { // We would like to emit LLVM IR for the following function call // custom_call_target(void* result, void** args) // args can be thought of an array of pointers allocated on the stack, @@ -2592,7 +2593,7 @@ Status IrEmitter::HandleOneDnnMatMul(HloInstruction* custom_call) { llvm_ir::IrArray result_array = GetIrArrayFor(custom_call); auto result_stack_alloca = GetAllocaAndEmitMemrefInfo(b_, result_array); - EmitCallToFunc(runtime::kOneDnnMatMulSymbolName, + EmitCallToFunc(std::move(runtime_symbol_name), {result_stack_alloca.value, args_ptr}, b_.getVoidTy()); // Lifetime ends for all stack allocations. @@ -2705,7 +2706,6 @@ Status IrEmitter::HandleOneDnnSoftmax(HloInstruction* custom_call) { return OkStatus(); } - #endif // INTEL_MKL && ENABLE_ONEDNN_V3 Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { @@ -2720,7 +2720,8 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { } #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) if (custom_call->custom_call_target() == "__onednn$matmul") { - return HandleOneDnnMatMul(custom_call); + return HandleOneDnnMatMulCalls(custom_call, + runtime::kOneDnnMatMulSymbolName); } if (custom_call->custom_call_target() == "__onednn$softmax") { return HandleOneDnnSoftmax(custom_call); @@ -2728,6 +2729,10 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { if (custom_call->custom_call_target() == "__onednn$layernorm") { return HandleOneDnnLayerNorm(custom_call); } + if (custom_call->custom_call_target() == "__onednn$matmul_reorder") { + return HandleOneDnnMatMulCalls(custom_call, + runtime::kOneDnnMatMulReorderSymbolName); + } #endif // INTEL_MKL && ENABLE_ONEDNN_V3 absl::Span operands(custom_call->operands()); llvm::AllocaInst* operands_alloca = diff --git a/third_party/xla/xla/service/cpu/ir_emitter.h b/third_party/xla/xla/service/cpu/ir_emitter.h index ace9b36b7590f4..12beb8547fb09c 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.h +++ b/third_party/xla/xla/service/cpu/ir_emitter.h @@ -195,7 +195,8 @@ class IrEmitter : public DfsHloVisitorWithDefault, Status HandleAllReduceSingleReplica(HloInstruction* crs); Status HandleAllReduceMultipleReplica(HloInstruction* crs); #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) - Status HandleOneDnnMatMul(HloInstruction* hlo); + Status HandleOneDnnMatMulCalls(HloInstruction* hlo, + std::string runtime_symbol_name); Status HandleOneDnnSoftmax(HloInstruction* hlo); Status HandleOneDnnLayerNorm(HloInstruction* hlo); #endif // INTEL_MKL && ENABLE_ONEDNN_V3 diff --git a/third_party/xla/xla/service/cpu/onednn_matmul.cc b/third_party/xla/xla/service/cpu/onednn_matmul.cc index e0dfba57160de4..3686827198df7a 100644 --- a/third_party/xla/xla/service/cpu/onednn_matmul.cc +++ b/third_party/xla/xla/service/cpu/onednn_matmul.cc @@ -33,6 +33,8 @@ limitations under the License. #include "xla/service/cpu/backend_config.pb.h" #include "xla/service/cpu/onednn_memory_util.h" #include "xla/service/cpu/runtime_lightweight_check.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "tsl/platform/logging.h" #include "tsl/util/onednn_threadpool.h" @@ -43,8 +45,128 @@ using dnnl::engine; using dnnl::matmul; using dnnl::memory; using dnnl::stream; + +dnnl::memory::desc Transpose(const dnnl::memory::desc& md) { + int64_t ndims = md.get_ndims(); + // Do not transpose 1D + if (ndims == 1) { + return md; + } + + std::vector permutation(ndims); + std::iota(permutation.begin(), permutation.end(), 0); + std::swap(permutation[ndims - 1], permutation[ndims - 2]); + return md.permute_axes(permutation); +} + +dnnl::memory::desc ShapeToMemDesc(const Shape& shape, bool transpose = false) { + auto dimensions = shape.dimensions(); + if (dimensions.size() == 0) { + return dnnl::memory::desc{}; + } + + auto dims = dnnl::memory::dims(dimensions.begin(), dimensions.end()); + + dnnl::memory::dims strides(dims.size()); + dnnl::memory::dim stride = 1; + for (auto i : shape.layout().minor_to_major()) { + strides.at(i) = stride; + stride *= dims.at(i); + } + + auto dt = ToOneDnnDataType(static_cast(shape.element_type())); + + return transpose ? Transpose(dnnl::memory::desc(dims, dt, strides)) + : dnnl::memory::desc(dims, dt, strides); +} + +dnnl::memory::desc OneDnnMatMulOptWeightsDesc( + const dnnl::engine& engine, const dnnl::memory::desc& input_md, + const dnnl::memory::desc& weights_md, const dnnl::memory::desc& bias_md, + const dnnl::memory::desc& output_md) { + auto weights_any_md = + memory::desc(weights_md.get_dims(), weights_md.get_data_type(), + dnnl::memory::format_tag::any); + + auto matmul_pd = matmul::primitive_desc(engine, input_md, weights_any_md, + bias_md, output_md); + + return matmul_pd.weights_desc(); +} + +dnnl::memory::desc OneDnnMatMulOptWeightsDesc( + const dnnl::engine& engine, const Shape& input_shape, + const Shape& weights_shape, const Shape& bias_shape, + const Shape& output_shape, const OneDnnMatMulConfig* matmul_config) { + auto input_md = ShapeToMemDesc(input_shape, matmul_config->transpose_a()); + auto weights_md = ShapeToMemDesc(weights_shape, matmul_config->transpose_b()); + auto bias_md = + absl::c_count(matmul_config->fused_ops(), OneDnnMatMulConfig::BIAS) > 0 + ? ShapeToMemDesc(bias_shape) + : dnnl::memory::desc{}; + auto output_md = ShapeToMemDesc(output_shape); + + // extend bias rank to match result rank + auto missed_rank = output_md.get_ndims() - bias_md.get_ndims(); + XLA_LIGHTWEIGHT_CHECK(missed_rank >= 0); + if (!bias_md.is_zero() && missed_rank > 0) { + auto bias_dims = bias_md.get_dims(); + bias_dims.insert(bias_dims.begin(), missed_rank, 1); + bias_md = bias_md.reshape(bias_dims); + } + + return OneDnnMatMulOptWeightsDesc(engine, input_md, weights_md, bias_md, + output_md); +} + +Shape MemDescToXlaShape(const dnnl::memory::desc& md) { + auto dtype = md.get_data_type(); + auto element_size = dnnl::memory::data_type_size(dtype); + int64_t bytes_num = md.get_size(); + XLA_LIGHTWEIGHT_CHECK(bytes_num % element_size == 0); + int64_t elements_num = static_cast(bytes_num / element_size); + return ShapeUtil::MakeShape(ToXlaPrimitiveType(dtype), {elements_num}); +} + +std::unique_ptr CreateOneDnnThreadPool( + const xla::ExecutableRunOptions* run_options) { +#ifndef ENABLE_ONEDNN_OPENMP + if (run_options != nullptr && + run_options->intra_op_thread_pool() != nullptr) { + return std::make_unique( + run_options->intra_op_thread_pool()->getPool(), false); + } else { + return nullptr; + } +#else + return nullptr; +#endif // ENABLE_ONEDNN_OPENMP +} + +dnnl::stream MakeOneDnnStream( + const dnnl::engine& cpu_engine, + dnnl::threadpool_interop::threadpool_iface* thread_pool) { + if (thread_pool != nullptr) { + return dnnl::threadpool_interop::make_stream(cpu_engine, thread_pool); + } else { + return dnnl::stream(cpu_engine); + } +} + } // namespace +Shape OneDnnMatMulOptWeightsShape(const Shape& input_shape, + const Shape& weights_shape, + const Shape& bias_shape, + const Shape& output_shape, + const OneDnnMatMulConfig* matmul_config) { + engine cpu_engine(engine::kind::cpu, 0); + auto optimized_weights_md = + OneDnnMatMulOptWeightsDesc(cpu_engine, input_shape, weights_shape, + bias_shape, output_shape, matmul_config); + return MemDescToXlaShape(optimized_weights_md); +} + ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnMatMul( void* result, void** args) { // args[0]: ptr to nargs @@ -58,15 +180,10 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnMatMul( static_cast(args[arg_indx++]); XLA_LIGHTWEIGHT_CHECK(run_options != nullptr); XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr); - tsl::OneDnnThreadPool thread_pool( - run_options->intra_op_thread_pool()->getPool(), false); + + auto thread_pool = CreateOneDnnThreadPool(run_options); engine cpu_engine(engine::kind::cpu, 0); -#ifndef ENABLE_ONEDNN_OPENMP - auto onednn_stream = - stream(dnnl::threadpool_interop::make_stream(cpu_engine, &thread_pool)); -#else - auto onednn_stream = stream(cpu_engine); -#endif // ENABLE_ONEDNN_OPENMP + auto onednn_stream = MakeOneDnnStream(cpu_engine, thread_pool.get()); std::string config_str(static_cast(args[arg_indx++])); OneDnnMatMulConfig matmul_config; @@ -82,26 +199,14 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnMatMul( auto result_md = result_minfo.GetOneDnnMemDesc(); // Update dims and strides for transposed inputs. - bool transpose_a = matmul_config.transpose_a(); - if (transpose_a) { - int64_t ndims = lhs_md.get_ndims(); - std::vector permutation(ndims); - std::iota(permutation.begin(), permutation.end(), 0); - std::swap(permutation[ndims - 1], permutation[ndims - 2]); - lhs_md = lhs_md.permute_axes(permutation); + if (matmul_config.transpose_a()) { + lhs_md = Transpose(lhs_md); } - bool transpose_b = matmul_config.transpose_b(); - if (transpose_b) { - int64_t ndims = rhs_md.get_ndims(); - std::vector permutation(ndims); - std::iota(permutation.begin(), permutation.end(), 0); - std::swap(permutation[ndims - 1], permutation[ndims - 2]); - rhs_md = rhs_md.permute_axes(permutation); + + if (matmul_config.transpose_b()) { + rhs_md = Transpose(rhs_md); } - auto lhs_mem = memory(lhs_md, cpu_engine, lhs_minfo.Data()); - auto rhs_mem = memory(rhs_md, cpu_engine, rhs_minfo.Data()); auto bias_mem = memory(nullptr); - auto result_mem = memory(result_md, cpu_engine, result_minfo.Data()); std::vector> postop_args; // Currently, GELU/ReLU only fusion is supported. @@ -143,6 +248,13 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnMatMul( postop_args.emplace_back( arg_idx, dnnl::memory(binary_md, cpu_engine, binary_minfo.Data())); } break; + case OneDnnMatMulConfig::LINEAR: { + float const_float; + *(reinterpret_cast(&const_float)) = + matmul_config.alpha_typecast(); + post_ops.append_eltwise(dnnl::algorithm::eltwise_linear, const_float, + 0.f); + } break; default: LOG(FATAL) << __FILE__ << ":" << __LINE__ << " Attempt to call OneDNN MatMul runtime library with " @@ -158,6 +270,22 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnMatMul( attrs.set_post_ops(post_ops); } + bool weights_packed = rhs_md.get_ndims() == 1 && + rhs_md.get_dims().front() != lhs_md.get_dims().back(); + if (weights_packed) { + // expected 2D buffer with last dim of input and last dim of output + auto rhs_any_md = + memory::desc({lhs_md.get_dims().back(), result_md.get_dims().back()}, + rhs_md.get_data_type(), memory::format_tag::any); + + rhs_md = OneDnnMatMulOptWeightsDesc(cpu_engine, lhs_md, rhs_any_md, bias_md, + result_md); + } + + auto lhs_mem = memory(lhs_md, cpu_engine, lhs_minfo.Data()); + auto rhs_mem = memory(rhs_md, cpu_engine, rhs_minfo.Data()); + auto result_mem = memory(result_md, cpu_engine, result_minfo.Data()); + auto matmul_pd = matmul::primitive_desc(cpu_engine, lhs_md, rhs_md, bias_md, result_md, attrs); @@ -177,6 +305,78 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnMatMul( matmul_prim.execute(onednn_stream, matmul_args); } +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnMatMulReorder( + void* result, void** args) { + // args[0]: ptr to nargs + // args[1]: ptr to ExecutableRunOptions + // args[2]: ptr to OneDnnMatMulConfig + // args[3...]: ptrs to operands + int arg_indx = 0; + const int64_t num_args = *(static_cast(args[arg_indx++])); + + const xla::ExecutableRunOptions* run_options = + static_cast(args[arg_indx++]); + + auto thread_pool = CreateOneDnnThreadPool(run_options); + engine cpu_engine(engine::kind::cpu, 0); + auto onednn_stream = MakeOneDnnStream(cpu_engine, thread_pool.get()); + + std::string config_str(static_cast(args[arg_indx++])); + OneDnnMatMulConfig matmul_config; + matmul_config.ParseFromString(config_str); + + MemrefInfo input_minfo(args[arg_indx++]); + MemrefInfo weight_minfo(args[arg_indx++]); + MemrefInfo output_minfo(args[arg_indx++]); + MemrefInfo result_minfo(result); + + auto input_md = input_minfo.GetOneDnnMemDesc(); + auto weight_md = weight_minfo.GetOneDnnMemDesc(); + auto output_md = output_minfo.GetOneDnnMemDesc(); + + auto bias_md = dnnl::memory::desc{}; + if (absl::c_count(matmul_config.fused_ops(), OneDnnMatMulConfig::BIAS) > 0) { + MemrefInfo bias_minfo(args[arg_indx++]); + bias_md = bias_minfo.GetOneDnnMemDesc(); + } + + XLA_LIGHTWEIGHT_CHECK(num_args >= arg_indx); + + // Update dims and strides for transposed inputs. + bool transpose_a = matmul_config.transpose_a(); + if (transpose_a) { + input_md = Transpose(input_md); + } + bool transpose_b = matmul_config.transpose_b(); + if (transpose_b) { + weight_md = Transpose(weight_md); + } + + // extend bias rank to match result rank + if (!bias_md.is_zero()) { + auto missed_rank = output_md.get_ndims() - bias_md.get_ndims(); + XLA_LIGHTWEIGHT_CHECK(missed_rank >= 0); + if (missed_rank > 0) { + auto bias_dims = bias_md.get_dims(); + bias_dims.insert(bias_dims.begin(), missed_rank, 1); + bias_md = bias_md.reshape(bias_dims); + } + } + + auto result_md = OneDnnMatMulOptWeightsDesc(cpu_engine, input_md, weight_md, + bias_md, output_md); + + XLA_LIGHTWEIGHT_CHECK(result_minfo.GetOneDnnMemDesc().get_size() == + result_md.get_size()); + + auto weight_mem = dnnl::memory{weight_md, cpu_engine, weight_minfo.Data()}; + auto result_mem = dnnl::memory{result_md, cpu_engine, result_minfo.Data()}; + + dnnl::reorder rdr{weight_mem, result_mem}; + rdr.execute(onednn_stream, weight_mem, result_mem); + onednn_stream.wait(); +} + } // namespace cpu } // namespace xla diff --git a/third_party/xla/xla/service/cpu/onednn_matmul.h b/third_party/xla/xla/service/cpu/onednn_matmul.h index 449971f61a30cb..6647eee2621a90 100644 --- a/third_party/xla/xla/service/cpu/onednn_matmul.h +++ b/third_party/xla/xla/service/cpu/onednn_matmul.h @@ -17,11 +17,21 @@ limitations under the License. #define XLA_SERVICE_CPU_ONEDNN_MATMUL_H_ #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) +#include "xla/service/cpu/backend_config.pb.h" +#include "xla/shape.h" + namespace xla { namespace cpu { +Shape OneDnnMatMulOptWeightsShape(const Shape& input_shape, + const Shape& weights_shape, + const Shape& bias_shape, + const Shape& output_shape, + const OneDnnMatMulConfig* matmul_config); + extern "C" { extern void __xla_cpu_runtime_OneDnnMatMul(void* result, void** args); +extern void __xla_cpu_runtime_OneDnnMatMulReorder(void* result, void** args); } // extern "C" } // namespace cpu diff --git a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc index 59cc4848fdb32c..840661310f8744 100644 --- a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc +++ b/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc @@ -15,18 +15,24 @@ limitations under the License. #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) +#define EIGEN_USE_THREADS + #include "xla/service/cpu/onednn_matmul_rewriter.h" +#include "xla/executable_run_options.h" +#include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/cpu/backend_config.pb.h" +#include "xla/service/cpu/onednn_matmul.h" #include "xla/service/cpu/onednn_memory_util.h" #include "xla/service/cpu/onednn_util.h" #include "xla/service/pattern_matcher.h" #include "xla/status_macros.h" #include "tsl/platform/logging.h" // IWYU pragma: keep +#include "tsl/util/onednn_threadpool.h" namespace xla { namespace cpu { @@ -116,6 +122,26 @@ auto ConstScalarNear(double value) { }); } +bool IsScalar(const HloInstruction* instr) { + return ShapeUtil::IsEffectiveScalar(instr->shape()); +} + +std::optional GetConstantValueAsFloat32(const HloInstruction* inst) { + if (!IsScalar(inst)) { + return std::nullopt; + } + switch (inst->shape().element_type()) { + case F16: + return inst->literal().GetFirstElement(); + case BF16: + return inst->literal().GetFirstElement(); + case F32: + return inst->literal().GetFirstElement(); + default: + return std::nullopt; + } +} + inline auto BcastConstScalarNear(double value) { return m::Broadcast(ConstScalarNear(value)); } @@ -581,6 +607,36 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor { intermediate_instr); } } + + HloInstruction *dot, *constant; + auto pattern = m::Op(&instr) + .WithOpcode(HloOpcode::kMultiply) + .WithBinaryOperandsAnyOrder( + m::Op(&dot) + .WithOneUser() + .WithOpcode(HloOpcode::kCustomCall) + .WithCustomCallTarget({"__onednn$matmul"}), + m::Broadcast(m::Constant(&constant)).WithOneUser()); + + if (Match(instr, pattern)) { + std::vector new_operands; + auto constant_value = *GetConstantValueAsFloat32(constant); + + for (auto operand : dot->operands()) { + new_operands.push_back(operand); + } + auto matmul_call = Cast(instr->AddInstruction( + dot->CloneWithNewOperands(instr->shape(), new_operands))); + auto backend_config = matmul_call->backend_config(); + backend_config->mutable_onednn_matmul_config()->add_fused_ops( + OneDnnMatMulConfig::LINEAR); + // Casting to int32 because of issues in proto config for decimal types + // handling. + backend_config->mutable_onednn_matmul_config()->set_alpha_typecast( + *(reinterpret_cast(&constant_value))); + TF_RETURN_IF_ERROR(matmul_call->set_backend_config(*backend_config)); + TF_RETURN_IF_ERROR(ReplaceInstruction(instr, matmul_call)); + } return OkStatus(); } @@ -605,11 +661,150 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor { } }; +class OneDnnMatMulReorderVisitor : public DfsHloRewriteVisitor { + public: + OneDnnMatMulReorderVisitor(int intra_op_parallelism, + const tsl::thread::ThreadPool* compile_threadpool) + : intra_op_parallelism_(intra_op_parallelism > 0 + ? intra_op_parallelism + : tsl::port::MaxParallelism()), + evaluator_(/*max_loop_iterations=*/0) { + if (compile_threadpool) { + threadpool_device_.reset( + new Eigen::ThreadPoolDevice(compile_threadpool->AsEigenThreadPool(), + compile_threadpool->NumThreads())); + } else { + threadpool_handle_.reset(new tsl::thread::ThreadPool( + tsl::Env::Default(), "XLACpuCompile", tsl::port::MaxParallelism())); + threadpool_device_.reset( + new Eigen::ThreadPoolDevice(threadpool_handle_->AsEigenThreadPool(), + threadpool_handle_->NumThreads())); + } + + evaluator_.set_custom_call_handler( + [this](const HloInstruction* custom_call_instr, + absl::Span operands) -> StatusOr { + TF_ASSIGN_OR_RETURN( + auto backend_config, + custom_call_instr->backend_config()); + auto& matmul_config = backend_config.onednn_matmul_config(); + + auto output = Literal::CreateFromShape(custom_call_instr->shape()); + + int64_t nargs = operands.size() + 3; + std::vector args; + args.push_back(&nargs); + + ExecutableRunOptions run_options; + run_options.set_intra_op_thread_pool(threadpool_device_.get()); + args.push_back(&run_options); // No ExecutableRunOptions. + + // OneDnnMatMulConfig + std::string config; + matmul_config.SerializeToString(&config); + args.push_back(config.data()); + + std::vector minfo_ptrs(operands.size()); + std::transform(operands.begin(), operands.end(), minfo_ptrs.begin(), + CreateMemrefInfoFromLiteral); + for (auto& minfo_ptr : minfo_ptrs) { + args.push_back(static_cast(minfo_ptr.get())); + } + + auto result_ptr = CreateMemrefInfoFromLiteral(&output); + __xla_cpu_runtime_OneDnnMatMulReorder(result_ptr.get(), args.data()); + + return output; + }); + } + + Status HandleCustomCall(HloInstruction* custom_call) override { + HloInstruction* matmul; + if (Match(custom_call, OneDnnMatmulInstr(&matmul))) { + TF_ASSIGN_OR_RETURN(auto backend_config, + matmul->backend_config()); + auto& matmul_config = backend_config.onednn_matmul_config(); + + auto operands = custom_call->operands(); + auto input = operands[0]; + auto weight = operands[1]; // assuming weights is the second operand + + auto input_shape = input->shape(); + auto weight_shape = weight->shape(); + if (weight_shape.rank() != 2) { + // pre-pack only 2D weights + return DefaultAction(custom_call); + } + + auto bias_shape = + absl::c_count(matmul_config.fused_ops(), OneDnnMatMulConfig::BIAS) > 0 + ? operands.at(2)->shape() + : Shape(); + + auto output_shape = custom_call->shape(); + +#ifndef ENABLE_ONEDNN_OPENMP + // set oneDNN cuncurrency settings (which is thread-local) + tsl::OneDnnThreadPool::set_onednn_max_threads(intra_op_parallelism_); +#endif + auto new_weight_shape = OneDnnMatMulOptWeightsShape( + input_shape, weight_shape, bias_shape, output_shape, &matmul_config); + + auto cmpt = custom_call->parent(); + std::vector new_operands{ + cmpt->AddInstruction( + HloInstruction::CreateConstant(Literal(input_shape))), + weight, + cmpt->AddInstruction( + HloInstruction::CreateConstant(Literal(output_shape))), + }; + + if (ShapeUtil::IsInitialized(bias_shape)) { + new_operands.push_back(cmpt->AddInstruction( + HloInstruction::CreateConstant(Literal(bias_shape)))); + } + + HloInstruction* reorder_call = + custom_call->AddInstruction(HloInstruction::CreateCustomCall( + new_weight_shape, new_operands, "__onednn$matmul_reorder")); + + reorder_call->CopyBackendConfigFrom(custom_call); + + Literal result; + + if (evaluator_.TryEvaluate(reorder_call, &result, true)) { + HloInstruction* reordered_weight = custom_call->AddInstruction( + HloInstruction::CreateConstant(std::move(result))); + return custom_call->ReplaceOperandWithDifferentShape(1, + reordered_weight); + + } else { + return DefaultAction(custom_call); + } + } + return DefaultAction(custom_call); + } + + private: + int intra_op_parallelism_; + HloEvaluator evaluator_; + std::unique_ptr threadpool_handle_; + std::unique_ptr threadpool_device_; +}; + StatusOr OneDnnMatMulRewriter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { OneDnnMatMulRewriteVisitor visitor; - return visitor.RunOnModule(module, execution_threads); + TF_ASSIGN_OR_RETURN(auto result, + visitor.RunOnModule(module, execution_threads)); + + OneDnnMatMulReorderVisitor reorder_visitor(intra_op_parallelism_, + compile_threadpool_); + TF_ASSIGN_OR_RETURN(auto result2, + reorder_visitor.RunOnModule(module, execution_threads)); + + return {result || result2}; } } // namespace cpu diff --git a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.h b/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.h index 9a4e1aad8f1ec7..4c508e811aa473 100644 --- a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.h +++ b/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.h @@ -20,9 +20,11 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" +#include "tsl/platform/threadpool.h" namespace xla { namespace cpu { @@ -31,6 +33,11 @@ namespace cpu { // calls. class OneDnnMatMulRewriter : public HloModulePass { public: + OneDnnMatMulRewriter(int intra_op_parallelism, + const tsl::thread::ThreadPool* compile_threadpool) + : intra_op_parallelism_(intra_op_parallelism), + compile_threadpool_(compile_threadpool) {} + absl::string_view name() const override { return "onednn-matmul-rewriter"; } using HloPassInterface::Run; @@ -39,6 +46,10 @@ class OneDnnMatMulRewriter : public HloModulePass { const absl::flat_hash_set& execution_threads) override; static bool ShouldRewrite(const HloInstruction* dot_instr); + + private: + int intra_op_parallelism_; + const tsl::thread::ThreadPool* compile_threadpool_; }; } // namespace cpu diff --git a/third_party/xla/xla/service/cpu/onednn_memory_util.cc b/third_party/xla/xla/service/cpu/onednn_memory_util.cc index 602c526ccfc2ce..372ce97c278932 100644 --- a/third_party/xla/xla/service/cpu/onednn_memory_util.cc +++ b/third_party/xla/xla/service/cpu/onednn_memory_util.cc @@ -50,6 +50,27 @@ struct MemrefInfoPOD { void* data; }; +MemrefInfoHandler CreateMemrefInfoFromLiteral(const Literal* literal) { + MemrefInfoHandler result(new MemrefInfoPOD); + + const auto& shape = literal->shape(); + result->dtype = shape.element_type(); + result->rank = shape.rank(); + auto dimensions = shape.dimensions(); + std::copy(dimensions.begin(), dimensions.end(), + absl::MakeSpan(result->dims).begin()); + + int64_t stride = 1; + for (int i : shape.layout().minor_to_major()) { + result->strides[i] = stride; + stride *= dimensions.at(i); + } + + result->data = const_cast(literal->untyped_data()); + + return result; +} + StackAlloca GetAllocaAndEmitMemrefInfo(llvm::IRBuilder<>& builder, const llvm_ir::IrArray& ir_array) { const Shape& shape = ir_array.GetShape(); diff --git a/third_party/xla/xla/service/cpu/onednn_memory_util.h b/third_party/xla/xla/service/cpu/onednn_memory_util.h index 1556da60bd4c46..fb5292843b5a61 100644 --- a/third_party/xla/xla/service/cpu/onednn_memory_util.h +++ b/third_party/xla/xla/service/cpu/onednn_memory_util.h @@ -17,11 +17,14 @@ limitations under the License. #define XLA_SERVICE_CPU_ONEDNN_MEMORY_UTIL_H_ #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) +#include + #include "dnnl.hpp" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Value.h" +#include "xla/literal.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/xla_data.pb.h" @@ -40,6 +43,9 @@ struct StackAlloca { // Declare as opaque to put structure definition together with dependant code. struct MemrefInfoPOD; +using MemrefInfoHandler = std::shared_ptr; + +MemrefInfoHandler CreateMemrefInfoFromLiteral(const Literal* literal); StackAlloca GetAllocaAndEmitMemrefInfo(llvm::IRBuilder<>& builder, const llvm_ir::IrArray& ir_array); diff --git a/third_party/xla/xla/service/cpu/onednn_util.h b/third_party/xla/xla/service/cpu/onednn_util.h index ae975949fecdf2..0b8a7c65b0bf48 100644 --- a/third_party/xla/xla/service/cpu/onednn_util.h +++ b/third_party/xla/xla/service/cpu/onednn_util.h @@ -25,15 +25,21 @@ namespace cpu { inline bool IsSupportedType(xla::PrimitiveType dtype) { using tsl::port::CPUFeature; - static bool is_bf16_supported = TestCPUFeature(CPUFeature::AVX512_BF16) || - TestCPUFeature(CPUFeature::AMX_BF16); + // TODO(intel-tf): Enable more types. switch (dtype) { case F32: return true; case BF16: - return is_bf16_supported; + return TestCPUFeature(CPUFeature::AVX512F) || + TestCPUFeature(CPUFeature::AVX_NE_CONVERT) || + TestCPUFeature(CPUFeature::AMX_BF16); + case F16: + return TestCPUFeature(CPUFeature::AVX512BW) && + (TestCPUFeature(CPUFeature::AVX512_FP16) || + TestCPUFeature(CPUFeature::AMX_FP16) || + TestCPUFeature(CPUFeature::AVX_NE_CONVERT)); default: - break; + return false; } return false; } diff --git a/third_party/xla/xla/service/cpu/runtime/BUILD b/third_party/xla/xla/service/cpu/runtime/BUILD index 09d898ff206ae7..1169fdaa39d394 100644 --- a/third_party/xla/xla/service/cpu/runtime/BUILD +++ b/third_party/xla/xla/service/cpu/runtime/BUILD @@ -1,7 +1,8 @@ load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], licenses = ["notice"], ) @@ -23,7 +24,6 @@ cc_library( name = "collectives", srcs = ["collectives.cc"], hdrs = ["collectives.h"], - visibility = ["//visibility:public"], deps = [ "//xla:executable_run_options", "//xla:shape_util", @@ -31,7 +31,12 @@ cc_library( "//xla/runtime:custom_call", "//xla/runtime:custom_call_registry", "//xla/runtime:executable", + "//xla/runtime:memref_view", "//xla/service/cpu:cpu_runtime", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", "@llvm-project//mlir:Support", ], ) @@ -40,14 +45,15 @@ cc_library( name = "convolution", srcs = ["convolution.cc"], hdrs = ["convolution.h"], - visibility = ["//visibility:public"], deps = [ "//xla:executable_run_options", + "//xla:xla_data_proto_cc", "//xla/runtime:memref_view", "//xla/service/cpu:runtime_conv2d", "//xla/service/cpu:runtime_conv3d", "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", + "@eigen_archive//:eigen3", ], ) @@ -58,9 +64,14 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":convolution", + "//xla:xla_data_proto_cc", "//xla/runtime:aot_ffi", "//xla/runtime:aot_ffi_execution_context", + "//xla/runtime:memref_view", "//xla/runtime/ffi:ffi_api", + "//xla/runtime/ffi:ffi_c_api_hdrs", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", ], ) @@ -68,13 +79,15 @@ cc_library( name = "convolution_call", srcs = ["convolution_call.cc"], hdrs = ["convolution_call.h"], - visibility = ["//visibility:public"], deps = [ ":convolution", "//xla:executable_run_options", "//xla/runtime:custom_call", "//xla/runtime:custom_call_registry", "//xla/runtime:executable", + "//xla/runtime:memref_view", + "@com_google_absl//absl/types:span", + "@llvm-project//mlir:Support", ], ) @@ -82,16 +95,20 @@ cc_library( name = "custom_call", srcs = ["custom_call.cc"], hdrs = ["custom_call.h"], - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla:xla_proto_cc", "//xla/runtime:custom_call", "//xla/runtime:custom_call_registry", "//xla/runtime:executable", + "//xla/runtime:memref_view", "//xla/service:custom_call_status_internal", + "//xla/service:custom_call_status_public_headers", "//xla/service:custom_call_target_registry", "//xla/service:hlo_proto_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", "@llvm-project//mlir:Support", ], ) @@ -100,9 +117,9 @@ cc_library( name = "fft_call", srcs = ["fft_call.cc"], hdrs = ["fft_call.h"], - visibility = ["//visibility:public"], deps = [ "//xla:executable_run_options", + "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/runtime:custom_call", "//xla/runtime:custom_call_registry", @@ -113,6 +130,8 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//mlir:Support", ], ) @@ -120,15 +139,18 @@ cc_library( name = "xfeed", srcs = ["xfeed.cc"], hdrs = ["xfeed.h"], - visibility = ["//visibility:public"], deps = [ "//xla:executable_run_options", "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla/runtime:custom_call", "//xla/runtime:custom_call_registry", "//xla/runtime:executable", + "//xla/runtime:memref_view", "//xla/service/cpu:cpu_runtime", - "@llvm-project//mlir:IR", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", "@llvm-project//mlir:Support", ], ) @@ -137,9 +159,9 @@ cc_library( name = "rng", srcs = ["rng.cc"], hdrs = ["rng.h"], - visibility = ["//visibility:public"], deps = [ "//xla:executable_run_options", + "//xla:xla_data_proto_cc", "//xla/runtime:memref_view", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -150,13 +172,14 @@ cc_library( name = "rng_call", srcs = ["rng_call.cc"], hdrs = ["rng_call.h"], - visibility = ["//visibility:public"], deps = [ ":rng", "//xla:executable_run_options", "//xla/runtime:custom_call", "//xla/runtime:custom_call_registry", "//xla/runtime:executable", + "//xla/runtime:memref_view", + "@llvm-project//mlir:Support", ], ) @@ -167,8 +190,12 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":rng", + "//xla:xla_data_proto_cc", "//xla/runtime:aot_ffi", "//xla/runtime:aot_ffi_execution_context", + "//xla/runtime:memref_view", "//xla/runtime/ffi:ffi_api", + "//xla/runtime/ffi:ffi_c_api_hdrs", + "@com_google_absl//absl/status", ], ) diff --git a/third_party/xla/xla/service/cpu/runtime/collectives.cc b/third_party/xla/xla/service/cpu/runtime/collectives.cc index 5c84302321bfe4..6034cc600245b2 100644 --- a/third_party/xla/xla/service/cpu/runtime/collectives.cc +++ b/third_party/xla/xla/service/cpu/runtime/collectives.cc @@ -26,11 +26,19 @@ #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "xla/executable_run_options.h" #include "xla/runtime/custom_call.h" #include "xla/runtime/custom_call_registry.h" #include "xla/runtime/executable.h" +#include "xla/runtime/memref_view.h" #include "xla/service/cpu/cpu_runtime.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/service/cpu/runtime/convolution.cc b/third_party/xla/xla/service/cpu/runtime/convolution.cc index 6235cfa3b19e06..bc2c7ef29b2535 100644 --- a/third_party/xla/xla/service/cpu/runtime/convolution.cc +++ b/third_party/xla/xla/service/cpu/runtime/convolution.cc @@ -19,9 +19,13 @@ #include #include "absl/status/status.h" +#include "absl/types/span.h" +#include "Eigen/Core" // from @eigen_archive #include "xla/executable_run_options.h" +#include "xla/runtime/memref_view.h" #include "xla/service/cpu/runtime_conv2d.h" #include "xla/service/cpu/runtime_conv3d.h" +#include "xla/xla_data.pb.h" namespace xla { namespace cpu { diff --git a/third_party/xla/xla/service/cpu/runtime/convolution.h b/third_party/xla/xla/service/cpu/runtime/convolution.h index 03da05d7e6f3ad..fe4433774a7040 100644 --- a/third_party/xla/xla/service/cpu/runtime/convolution.h +++ b/third_party/xla/xla/service/cpu/runtime/convolution.h @@ -16,6 +16,7 @@ #include +#include "absl/status/status.h" #include "absl/types/span.h" #include "xla/executable_run_options.h" #include "xla/runtime/memref_view.h" diff --git a/third_party/xla/xla/service/cpu/runtime/convolution_call.cc b/third_party/xla/xla/service/cpu/runtime/convolution_call.cc index bf8bb3a597a2e5..793f6285da40c1 100644 --- a/third_party/xla/xla/service/cpu/runtime/convolution_call.cc +++ b/third_party/xla/xla/service/cpu/runtime/convolution_call.cc @@ -24,9 +24,13 @@ #include #include +#include "absl/types/span.h" +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "xla/executable_run_options.h" #include "xla/runtime/custom_call.h" +#include "xla/runtime/custom_call_registry.h" #include "xla/runtime/executable.h" +#include "xla/runtime/memref_view.h" #include "xla/service/cpu/runtime/convolution.h" namespace xla { diff --git a/third_party/xla/xla/service/cpu/runtime/convolution_ffi.cc b/third_party/xla/xla/service/cpu/runtime/convolution_ffi.cc index 7feafcfd1bc84c..9673938a05eea9 100644 --- a/third_party/xla/xla/service/cpu/runtime/convolution_ffi.cc +++ b/third_party/xla/xla/service/cpu/runtime/convolution_ffi.cc @@ -14,10 +14,15 @@ #include "xla/service/cpu/runtime/convolution_ffi.h" +#include "absl/status/status.h" +#include "absl/types/span.h" #include "xla/runtime/aot_ffi.h" #include "xla/runtime/aot_ffi_execution_context.h" #include "xla/runtime/ffi/ffi_api.h" +#include "xla/runtime/ffi/ffi_c_api.h" +#include "xla/runtime/memref_view.h" #include "xla/service/cpu/runtime/convolution.h" +#include "xla/xla_data.pb.h" namespace xla { struct ExecutableRunOptions; diff --git a/third_party/xla/xla/service/cpu/runtime/custom_call.cc b/third_party/xla/xla/service/cpu/runtime/custom_call.cc index 2b1d846edfdf01..6b45f3d1a36718 100644 --- a/third_party/xla/xla/service/cpu/runtime/custom_call.cc +++ b/third_party/xla/xla/service/cpu/runtime/custom_call.cc @@ -25,11 +25,17 @@ #include #include +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "xla/primitive_util.h" #include "xla/runtime/custom_call.h" #include "xla/runtime/custom_call_registry.h" #include "xla/runtime/executable.h" +#include "xla/runtime/memref_view.h" +#include "xla/service/custom_call_status.h" #include "xla/service/custom_call_status_internal.h" #include "xla/service/custom_call_target_registry.h" #include "xla/service/hlo.pb.h" diff --git a/third_party/xla/xla/service/cpu/runtime/fft_call.cc b/third_party/xla/xla/service/cpu/runtime/fft_call.cc index 0d0e07005e4a2f..c62b57422a1543 100644 --- a/third_party/xla/xla/service/cpu/runtime/fft_call.cc +++ b/third_party/xla/xla/service/cpu/runtime/fft_call.cc @@ -27,6 +27,8 @@ #include "absl/container/inlined_vector.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "xla/executable_run_options.h" #include "xla/runtime/custom_call.h" #include "xla/runtime/custom_call_registry.h" @@ -35,6 +37,7 @@ #include "xla/service/cpu/runtime_fft.h" #include "xla/service/hlo.pb.h" #include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" namespace xla { namespace cpu { diff --git a/third_party/xla/xla/service/cpu/runtime/rng.cc b/third_party/xla/xla/service/cpu/runtime/rng.cc index f8daf99fd965b7..7f2edd42b56b26 100644 --- a/third_party/xla/xla/service/cpu/runtime/rng.cc +++ b/third_party/xla/xla/service/cpu/runtime/rng.cc @@ -20,6 +20,8 @@ #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "xla/executable_run_options.h" +#include "xla/runtime/memref_view.h" +#include "xla/xla_data.pb.h" namespace xla { namespace cpu { diff --git a/third_party/xla/xla/service/cpu/runtime/rng_call.cc b/third_party/xla/xla/service/cpu/runtime/rng_call.cc index cd895a8113127f..6bcbe0fe0bf7e4 100644 --- a/third_party/xla/xla/service/cpu/runtime/rng_call.cc +++ b/third_party/xla/xla/service/cpu/runtime/rng_call.cc @@ -17,10 +17,12 @@ #include #include +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "xla/executable_run_options.h" #include "xla/runtime/custom_call.h" #include "xla/runtime/custom_call_registry.h" #include "xla/runtime/executable.h" +#include "xla/runtime/memref_view.h" #include "xla/service/cpu/runtime/rng.h" namespace xla { diff --git a/third_party/xla/xla/service/cpu/runtime/rng_ffi.cc b/third_party/xla/xla/service/cpu/runtime/rng_ffi.cc index 886be013f818e4..8efd9aabfade06 100644 --- a/third_party/xla/xla/service/cpu/runtime/rng_ffi.cc +++ b/third_party/xla/xla/service/cpu/runtime/rng_ffi.cc @@ -14,10 +14,14 @@ #include "xla/service/cpu/runtime/rng_ffi.h" +#include "absl/status/status.h" #include "xla/runtime/aot_ffi.h" #include "xla/runtime/aot_ffi_execution_context.h" #include "xla/runtime/ffi/ffi_api.h" +#include "xla/runtime/ffi/ffi_c_api.h" +#include "xla/runtime/memref_view.h" #include "xla/service/cpu/runtime/rng.h" +#include "xla/xla_data.pb.h" namespace xla { struct ExecutableRunOptions; diff --git a/third_party/xla/xla/service/cpu/runtime/xfeed.cc b/third_party/xla/xla/service/cpu/runtime/xfeed.cc index e33d08be8425fa..38bb2eb34644e3 100644 --- a/third_party/xla/xla/service/cpu/runtime/xfeed.cc +++ b/third_party/xla/xla/service/cpu/runtime/xfeed.cc @@ -26,15 +26,20 @@ #include #include -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "xla/executable_run_options.h" #include "xla/primitive_util.h" #include "xla/runtime/custom_call.h" #include "xla/runtime/custom_call_registry.h" #include "xla/runtime/executable.h" +#include "xla/runtime/memref_view.h" #include "xla/service/cpu/cpu_runtime.h" #include "xla/shape_util.h" +#include "xla/xla_data.pb.h" namespace xla { namespace cpu { diff --git a/third_party/xla/xla/service/cpu/simple_orc_jit.cc b/third_party/xla/xla/service/cpu/simple_orc_jit.cc index bc19cc53a35f5a..3bf08e4f300e6f 100644 --- a/third_party/xla/xla/service/cpu/simple_orc_jit.cc +++ b/third_party/xla/xla/service/cpu/simple_orc_jit.cc @@ -539,6 +539,7 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(OneDnnMatMul); REGISTER_CPU_RUNTIME_SYMBOL(OneDnnSoftmax); REGISTER_CPU_RUNTIME_SYMBOL(OneDnnLayerNorm); + REGISTER_CPU_RUNTIME_SYMBOL(OneDnnMatMulReorder); #endif // INTEL_MKL && ENABLE_ONEDNN_V3 registry->Register("__gnu_f2h_ieee", reinterpret_cast(__gnu_f2h_ieee), diff --git a/third_party/xla/xla/service/cpu/tests/BUILD b/third_party/xla/xla/service/cpu/tests/BUILD index 8da93b855babf5..c0044b17d31f58 100644 --- a/third_party/xla/xla/service/cpu/tests/BUILD +++ b/third_party/xla/xla/service/cpu/tests/BUILD @@ -6,7 +6,8 @@ load("@local_tsl//tsl:tsl.bzl", "tsl_copts") load("@local_tsl//tsl:tsl.default.bzl", "filegroup") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], licenses = ["notice"], ) @@ -24,14 +25,12 @@ filegroup( "**/*.cc", "**/*.h", ]), - visibility = ["//visibility:public"], ) cc_library( name = "cpu_codegen_test", testonly = True, hdrs = ["cpu_codegen_test.h"], - visibility = ["//visibility:public"], deps = [ "//xla/service:cpu_plugin", "//xla/tests:llvm_irgen_test_base", @@ -51,8 +50,8 @@ xla_cc_test( "//xla/service:platform_util", "//xla/service/cpu:cpu_compiler", "//xla/stream_executor", - "//xla/stream_executor:multi_platform_manager", "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", "//xla/tests:hlo_test_base", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", diff --git a/third_party/xla/xla/service/cpu/tests/cpu_aot_export_test.cc b/third_party/xla/xla/service/cpu/tests/cpu_aot_export_test.cc index 3fa10e9394e35e..39528634f7cd82 100644 --- a/third_party/xla/xla/service/cpu/tests/cpu_aot_export_test.cc +++ b/third_party/xla/xla/service/cpu/tests/cpu_aot_export_test.cc @@ -26,8 +26,8 @@ limitations under the License. #include "xla/service/compiler.h" #include "xla/service/executable.h" #include "xla/service/platform_util.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" @@ -53,7 +53,7 @@ TEST_F(CpuAotCompilationTest, ExportAndLoadExecutable) { auto name = absl::AsciiStrToUpper( PlatformUtil::CanonicalPlatformName("host").value()); TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform, - se::MultiPlatformManager::PlatformWithName(name)); + se::PlatformManager::PlatformWithName(name)); TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * stream_exec, platform->ExecutorForDevice(0)); diff --git a/third_party/xla/xla/service/custom_call_target_registry.cc b/third_party/xla/xla/service/custom_call_target_registry.cc index 8ceb64e87f4636..1c5f731fc2e79b 100644 --- a/third_party/xla/xla/service/custom_call_target_registry.cc +++ b/third_party/xla/xla/service/custom_call_target_registry.cc @@ -15,6 +15,12 @@ limitations under the License. #include "xla/service/custom_call_target_registry.h" +#include +#include +#include // NOLINT +#include +#include + namespace xla { CustomCallTargetRegistry* CustomCallTargetRegistry::Global() { @@ -26,7 +32,17 @@ void CustomCallTargetRegistry::Register(const std::string& symbol, void* address, const std::string& platform) { std::lock_guard lock(mu_); - registered_symbols_[std::make_pair(symbol, platform)] = address; + const auto [it, inserted] = + registered_symbols_.insert({{symbol, platform}, address}); + if (!inserted && it->second != address) { + std::cerr << "Duplicate custom call registration detected for symbol \"" + << symbol << "\" with different addresses " << address + << "(current) and " << it->second << " (previous) on platform " + << platform + << "Rejecting the registration to avoid confusion about which " + "symbol would actually get used at runtime.\n"; + std::exit(1); + } } void* CustomCallTargetRegistry::Lookup(const std::string& symbol, diff --git a/third_party/xla/xla/service/custom_call_target_registry.h b/third_party/xla/xla/service/custom_call_target_registry.h index 8488b5f0c99229..d2b87892778d14 100644 --- a/third_party/xla/xla/service/custom_call_target_registry.h +++ b/third_party/xla/xla/service/custom_call_target_registry.h @@ -35,7 +35,14 @@ namespace xla { // The XLA:CPU ahead-of-time (AOT) compiler links using a standard offline // linker; so when compiling in CPU AOT mode, you *also* need to make sure the // name of the callee (presumably implemented in C++) matches up with the -// symbolic name used in the CustomCall. +// symbolic name used in the CustomCall. Be careful with the name of the symbol +// you register with the macros: C++ namespaces are not included, including +// anonymous namespaces,so if two libraries attempt to register functions with +// the same name in separate namespaces the registrations will collide. Either +// call the registration macro from the global namespace so that you have to +// refer to the function in a fully-qualified manner (which also requires you to +// emit HLO-based calls to it by the fully-qualified name *and* complicates +// future refactoring!) or use C-style namespacing directly in the symbol name. // // We maintain the registry in both the JIT and the AOT cases for simplicity, // but we only use it when running in JIT mode. diff --git a/third_party/xla/xla/service/custom_call_target_registry_test.cc b/third_party/xla/xla/service/custom_call_target_registry_test.cc new file mode 100644 index 00000000000000..1c7a2b4ad0383e --- /dev/null +++ b/third_party/xla/xla/service/custom_call_target_registry_test.cc @@ -0,0 +1,58 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/custom_call_target_registry.h" + +#include "xla/service/custom_call_status.h" +#include "xla/test.h" + +namespace xla { +namespace { + +void custom_call(void*, const void**, XlaCustomCallStatus*) {} +void custom_call2(void*, const void**, XlaCustomCallStatus*) {} + +TEST(CustomCallRegistryTest, Registers) { + CustomCallTargetRegistry registry; + EXPECT_EQ(registry.Lookup("custom_call", "Host"), nullptr); + registry.Register("custom_call", reinterpret_cast(custom_call), + "Host"); + EXPECT_EQ(custom_call, registry.Lookup("custom_call", "Host")); + // A registration with a different name is fine. + registry.Register("custom_call2", reinterpret_cast(&custom_call), + "Host"); + + EXPECT_EQ(registry.Lookup("custom_call", "CUDA"), nullptr); + // A registration on a different platform is fine. + registry.Register("custom_call", reinterpret_cast(custom_call), + "CUDA"); + EXPECT_EQ(custom_call, registry.Lookup("custom_call", "CUDA")); + + // A second registration of the same function is fine. + registry.Register("custom_call", reinterpret_cast(custom_call), + "Host"); +} + +TEST(CustomCallRegistryDeathTest, RejectsDuplicateRegistrations) { + CustomCallTargetRegistry registry; + registry.Register("custom_call", reinterpret_cast(custom_call), + "Host"); + EXPECT_DEATH(registry.Register("custom_call", + reinterpret_cast(custom_call2), "Host"), + "Duplicate custom call"); +} + +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/service/elemental_ir_emitter.cc b/third_party/xla/xla/service/elemental_ir_emitter.cc index 10f07f92f806ad..b066c66454c98c 100644 --- a/third_party/xla/xla/service/elemental_ir_emitter.cc +++ b/third_party/xla/xla/service/elemental_ir_emitter.cc @@ -31,6 +31,7 @@ limitations under the License. #include "llvm/IR/Constants.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Value.h" #include "llvm/Support/MathExtras.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -43,6 +44,7 @@ limitations under the License. #include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/llvm_loop.h" #include "xla/service/llvm_ir/llvm_util.h" +#include "xla/service/llvm_ir/math_ops.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/statusor.h" @@ -902,6 +904,8 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( primitive_util::BitWidth(from_type), primitive_util::BitWidth(to_type)); } + case HloOpcode::kErf: + return EmitErf(op->shape().element_type(), operand_value); case HloOpcode::kExp: return EmitExp(op->shape().element_type(), operand_value, ""); case HloOpcode::kExpm1: @@ -2041,6 +2045,11 @@ StatusOr ElementalIrEmitter::EmitTanh(PrimitiveType prim_type, return Unimplemented("tanh"); } +StatusOr ElementalIrEmitter::EmitErf(PrimitiveType prim_type, + llvm::Value* value) { + return Unimplemented("erf"); +} + StatusOr ElementalIrEmitter::EmitTan(PrimitiveType prim_type, llvm::Value* value) { auto sin_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value}, @@ -2971,6 +2980,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kConvert: case HloOpcode::kBitcastConvert: case HloOpcode::kCos: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: diff --git a/third_party/xla/xla/service/elemental_ir_emitter.h b/third_party/xla/xla/service/elemental_ir_emitter.h index caef7ca795d3d8..705378237ed7b9 100644 --- a/third_party/xla/xla/service/elemental_ir_emitter.h +++ b/third_party/xla/xla/service/elemental_ir_emitter.h @@ -175,6 +175,9 @@ class ElementalIrEmitter : public IrBuilderMixin { llvm::Value* lhs, llvm::Value* rhs, absl::string_view name); + virtual StatusOr EmitErf(PrimitiveType prim_type, + llvm::Value* value); + virtual StatusOr EmitTanh(PrimitiveType prim_type, llvm::Value* value); @@ -233,9 +236,6 @@ class ElementalIrEmitter : public IrBuilderMixin { llvm::Value* accumulator, xla::PrimitiveType primitive_type); - // Identifier of the thread unique among all threads on the device - virtual llvm::Value* EmitThreadId() { return b_->getIntN(128, 0); } - StatusOr EmitElementalSelect( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, diff --git a/third_party/xla/xla/service/elemental_ir_emitter_test.cc b/third_party/xla/xla/service/elemental_ir_emitter_test.cc index 58fe971eab6800..90554af8f99023 100644 --- a/third_party/xla/xla/service/elemental_ir_emitter_test.cc +++ b/third_party/xla/xla/service/elemental_ir_emitter_test.cc @@ -698,8 +698,10 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } +// TODO(b/324385428): Failing on GPU at head due to an LLVM integrate. Re-enable +// once this has been fixed. XLA_TEST_F(ElementalIrEmitterExecutionTestWithoutFastMinMax, - MinimumHandlesNaNsOnTheRight) { + DISABLED_MinimumHandlesNaNsOnTheRight) { constexpr absl::string_view kHloText = R"( HloModule t diff --git a/third_party/xla/xla/service/float_normalization.cc b/third_party/xla/xla/service/float_normalization.cc index 4b574a9e284ff2..3541297c4cd628 100644 --- a/third_party/xla/xla/service/float_normalization.cc +++ b/third_party/xla/xla/service/float_normalization.cc @@ -22,6 +22,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/call_graph.h" #include "xla/service/hlo_dce.h" #include "xla/service/tuple_simplifier.h" #include "xla/shape_util.h" @@ -290,6 +291,14 @@ Status FloatNormalizationVisitor::ConvertCalledComputations( return OkStatus(); } +// Returns true if the called computations of the instruction should not +// be touched by float normalization. In particular, we must not introduce +// float conversions into collective reductions. +bool ShouldAvoidNormalizingComputationsForInstruction(HloInstruction* hlo) { + return hlo->opcode() == HloOpcode::kAllReduce || + hlo->opcode() == HloOpcode::kReduceScatter; +} + Status FloatNormalizationVisitor::HandleMultipleOutputs(HloInstruction* hlo) { std::vector operand_types(hlo->operand_count()); std::vector output_types(hlo->operand_count()); @@ -355,7 +364,7 @@ Status FloatNormalizationVisitor::HandleMultipleOutputs(HloInstruction* hlo) { std::vector low_precision_called_comps; for (auto* comp : hlo->called_computations()) { - if (comp->IsCollectiveCalledComputation()) { + if (ShouldAvoidNormalizingComputationsForInstruction(hlo)) { continue; } bool comp_has_low_precision = false; @@ -434,7 +443,7 @@ Status FloatNormalizationVisitor::HandleInstruction(HloInstruction* hlo) { std::vector low_precision_called_comps; for (auto* comp : hlo->called_computations()) { - if (comp->IsCollectiveCalledComputation()) { + if (ShouldAvoidNormalizingComputationsForInstruction(hlo)) { continue; } bool comp_has_low_precision = false; @@ -564,6 +573,56 @@ Status FloatNormalizationVisitor::Preprocess(HloInstruction* hlo) { return OkStatus(); } +// We must avoid normalizing computations that have non-normalizing users +// (e.g., all-reduce's computations should not be normalized). If a +// computation is shared between normalizing and non-normalizing users, we will +// clone the computation for the non-normalizing users so that it can be left +// unchanged. This function clones the shared computations and returns the set +// of non-normalizing computations that must be skipped by the visitor. +absl::flat_hash_set +CloneComputationsForNonNormalizingInstructions( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + std::unique_ptr call_graph = + CallGraph::Build(module, execution_threads); + + absl::flat_hash_set computations_to_skip; + for (const CallGraphNode& node : call_graph->nodes()) { + bool has_normalizing_users = false; + bool has_users_to_skip_normalization = false; + for (const CallSite& site : node.caller_callsites()) { + if (ShouldAvoidNormalizingComputationsForInstruction( + site.instruction())) { + has_users_to_skip_normalization = true; + } else { + has_normalizing_users = true; + } + } + // If the computation is only used by normalizing users or only by + // non-normalizing users, then we do not clone. + if (!has_users_to_skip_normalization) { + continue; + } + if (!has_normalizing_users) { + computations_to_skip.insert(node.computation()); + continue; + } + // Otherwise, we create a clone and replace the normalizing instructions' + // computations with the clone. + HloComputation* clone = module->DeepCloneComputation(node.computation()); + for (const CallSite& site : node.caller_callsites()) { + if (ShouldAvoidNormalizingComputationsForInstruction( + site.instruction())) { + site.instruction()->ReplaceCalledComputations( + [&](HloComputation* called) { + return called == node.computation() ? clone : called; + }); + } + } + computations_to_skip.insert(clone); + } + return computations_to_skip; +} } // namespace StatusOr FloatNormalization::Run( @@ -573,13 +632,14 @@ StatusOr FloatNormalization::Run( primitive_util::LowercasePrimitiveTypeName( float_support_->LowPrecisionType()) + ", before:\n" + module->ToString()); + auto computations_to_visit = + module->MakeComputationPostOrder(execution_threads); + auto computations_to_skip = + CloneComputationsForNonNormalizingInstructions(module, execution_threads); + FloatNormalizationVisitor visitor(float_support_, this); - for (auto* comp : module->MakeComputationPostOrder(execution_threads)) { - if (comp->IsCollectiveCalledComputation()) { - XLA_VLOG_LINES(2, "Skip processing collective called computation: " + - comp->ToString()); - continue; - } + for (auto* comp : computations_to_visit) { + if (computations_to_skip.contains(comp)) continue; TF_RETURN_IF_ERROR(comp->Accept(&visitor)); } XLA_VLOG_LINES(2, "FloatNormalization::Run() for " + diff --git a/third_party/xla/xla/service/float_normalization_test.cc b/third_party/xla/xla/service/float_normalization_test.cc index 679dc2934423a8..b9e91833acc6d8 100644 --- a/third_party/xla/xla/service/float_normalization_test.cc +++ b/third_party/xla/xla/service/float_normalization_test.cc @@ -586,6 +586,67 @@ TEST_F(FloatNormalizationNoComputeSupportTest, EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), BF16); } +TEST_F(FloatNormalizationNoComputeSupportTest, + NormalizationClonesSharedApplyAllReduceAndReduce) { + auto module = CreateNewVerifiedModule(); + HloComputation::Builder sum_builder("sum"); + auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(BF16, {}), "x")); + auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(BF16, {}), "y")); + sum_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(BF16, {}), HloOpcode::kAdd, x, y)); + HloComputation* reduction = + module->AddEmbeddedComputation(sum_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + + Shape bf16_shape_a = ShapeUtil::MakeShape(BF16, {2, 4}); + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, bf16_shape_a, "a")); + + Shape bf16_shape_b = ShapeUtil::MakeShape(BF16, {2, 4, 2}); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape_b, "b")); + + Shape bf16_scalar_shape = ShapeUtil::MakeShape(BF16, {}); + HloInstruction* init = builder.AddInstruction( + HloInstruction::CreateParameter(2, bf16_scalar_shape, "init")); + + HloInstruction* all_reduce = builder.AddInstruction( + HloInstruction::CreateAllReduce(bf16_shape_a, {a}, reduction, + /*replica_groups=*/{}, + /*constrain_layout=*/false, + /*channel_id=*/std::nullopt, + /*use_global_device_ids=*/false)); + + HloInstruction* reduce = builder.AddInstruction( + HloInstruction::CreateReduce(bf16_shape_a, b, init, {2}, reduction)); + builder.AddInstruction(HloInstruction::CreateBinary( + bf16_shape_a, HloOpcode::kAdd, all_reduce, reduce)); + + auto computation = module->AddEntryComputation(builder.Build()); + // Verify that the shared computation was cloned, the all-reduce instruction + // got the unchanged bf16 add, while the reduction was promoted to f32 + // together with its called computation. + EXPECT_TRUE(Normalize(module.get())); + EXPECT_EQ(computation->root_instruction()->shape().element_type(), BF16); + EXPECT_EQ(all_reduce->operand(0)->shape().element_type(), BF16); + EXPECT_EQ(all_reduce->to_apply()->root_instruction()->opcode(), + HloOpcode::kAdd); + EXPECT_EQ(all_reduce->to_apply()->root_instruction()->shape().element_type(), + BF16); + EXPECT_EQ(reduce->called_computations().size(), 1); + EXPECT_EQ(reduce->called_computations()[0] + ->root_instruction() + ->shape() + .element_type(), + F32); + EXPECT_EQ(reduce->called_computations()[0]->root_instruction()->opcode(), + HloOpcode::kConvert); + EXPECT_EQ(reduce->shape().element_type(), F32); +} + TEST_F(FloatNormalizationNoComputeSupportTest, NoNormalizationForToApplyAllReduce) { auto module = CreateNewVerifiedModule(); diff --git a/third_party/xla/xla/service/generic_transfer_manager.cc b/third_party/xla/xla/service/generic_transfer_manager.cc index 41aca77513974c..2d3675ae59e6cd 100644 --- a/third_party/xla/xla/service/generic_transfer_manager.cc +++ b/third_party/xla/xla/service/generic_transfer_manager.cc @@ -15,108 +15,35 @@ limitations under the License. #include "xla/service/generic_transfer_manager.h" +#include #include +#include +#include #include -#include #include #include #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" -#include "xla/layout_util.h" #include "xla/literal.h" #include "xla/primitive_util.h" +#include "xla/service/shaped_buffer.h" #include "xla/service/transfer_manager.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status.h" #include "xla/status_macros.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" #include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" namespace xla { -namespace { - -// Transfer a memory block of the given size from the device source into the -// 'destination' buffer. -// -// size is the size to transfer to destination in bytes. -Status TransferBufferFromDevice(se::Stream* stream, - const se::DeviceMemoryBase& source, - int64_t size, void* destination) { - if (source.size() < size) { - return absl::FailedPreconditionError(absl::StrFormat( - "Source allocation on device not large enough for data transfer: " - "%d < %d", - source.size(), size)); - } - stream->ThenMemcpy(destination, source, size); - return OkStatus(); -} - -// Transfer a memory block of the given size from 'source' buffer to the given -// destination of the device. -// -// size is the size to transfer from source in bytes. -Status TransferBufferToDevice(se::Stream* stream, int64_t size, - const void* source, - se::DeviceMemoryBase* destination) { - if (destination->size() < size) { - return absl::FailedPreconditionError(absl::StrFormat( - "Destination allocation on device not large enough for data transfer: " - "%d < %d", - destination->size(), size)); - } - stream->ThenMemcpy(destination, source, size); - return OkStatus(); -} - -// Transfers a buffer of packed int4 values from the device to the host, then -// unpacks them on the host. 'source' is a buffer with (num_elements+1)/2 bytes -// where each byte stores two int4 values. 'destination' is a buffer with -// num_elements bytes, where a single int4 value will be written to each byte -// in the lower 4 bits. -Status TransferInt4ArrayFromDevice(se::Stream* stream, - const se::DeviceMemoryBase& source, - int64_t num_elements, void* destination) { - int64_t packed_size = (num_elements + 1) / 2; - auto packed_dst_data = std::make_unique>(packed_size); - TF_RETURN_IF_ERROR(TransferBufferFromDevice(stream, source, packed_size, - packed_dst_data->data())); - stream->ThenDoHostCallback([destination, num_elements, - moved_dst_data = std::move(packed_dst_data)]() { - UnpackInt4(*moved_dst_data, - absl::MakeSpan(static_cast(destination), num_elements)); - }); - return OkStatus(); -} - -// Packs an array of int4 values then transfers the packed buffer from the host -// to the device. 'source' is a buffer with num_elements bytes, where the lower -// 4 bits of each byte stores an int4 value. 'destination' is a buffer with -// (num_elements+1)/2 bytes, where two int4 values will be written into each -// byte. -Status TransferInt4ArrayToDevice(se::Stream* stream, int64_t num_elements, - const void* source, - se::DeviceMemoryBase* destination) { - auto packed_src_data = std::make_unique>( - CeilOfRatio(num_elements, int64_t{2})); - PackInt4(absl::MakeSpan(static_cast(source), num_elements), - absl::MakeSpan(*packed_src_data)); - TF_RETURN_IF_ERROR(TransferBufferToDevice( - stream, packed_src_data->size(), packed_src_data->data(), destination)); - // Ensure the buffer is transferred before we destroy it - stream->ThenDoHostCallback([keep_alive = std::move(packed_src_data)] {}); - return OkStatus(); -} - -} // namespace - GenericTransferManager::GenericTransferManager(se::Platform::Id platform_id, size_t pointer_size) : platform_id_(platform_id), pointer_size_(pointer_size) {} @@ -286,6 +213,60 @@ Status GenericTransferManager::ResetDevices( "Device reset is not yet supported on this platform (b/30481585)"); } +Status GenericTransferManager::TransferBufferFromDevice( + se::Stream* stream, const se::DeviceMemoryBase& source, int64_t size, + void* destination) { + if (source.size() < size) { + return absl::FailedPreconditionError(absl::StrFormat( + "Source allocation on device not large enough for data transfer: " + "%d < %d", + source.size(), size)); + } + stream->ThenMemcpy(destination, source, size); + return OkStatus(); +} + +Status GenericTransferManager::TransferBufferToDevice( + se::Stream* stream, int64_t size, const void* source, + se::DeviceMemoryBase* destination) { + if (destination->size() < size) { + return absl::FailedPreconditionError(absl::StrFormat( + "Destination allocation on device not large enough for data transfer: " + "%d < %d", + destination->size(), size)); + } + stream->ThenMemcpy(destination, source, size); + return OkStatus(); +} + +Status GenericTransferManager::TransferInt4ArrayFromDevice( + se::Stream* stream, const se::DeviceMemoryBase& source, + int64_t num_elements, void* destination) { + int64_t packed_size = (num_elements + 1) / 2; + auto packed_dst_data = std::make_unique>(packed_size); + TF_RETURN_IF_ERROR(TransferBufferFromDevice(stream, source, packed_size, + packed_dst_data->data())); + stream->ThenDoHostCallback([destination, num_elements, + packed_dst_data = std::move(packed_dst_data)]() { + UnpackInt4(*packed_dst_data, + absl::MakeSpan(static_cast(destination), num_elements)); + }); + return OkStatus(); +} + +Status GenericTransferManager::TransferInt4ArrayToDevice( + se::Stream* stream, int64_t num_elements, const void* source, + se::DeviceMemoryBase* destination) { + auto packed_src_data = std::make_unique>( + CeilOfRatio(num_elements, int64_t{2})); + PackInt4(absl::MakeSpan(static_cast(source), num_elements), + absl::MakeSpan(*packed_src_data)); + TF_RETURN_IF_ERROR(TransferBufferToDevice( + stream, packed_src_data->size(), packed_src_data->data(), destination)); + stream->ThenDoHostCallback([keep_alive = std::move(packed_src_data)] {}); + return OkStatus(); +} + int64_t GenericTransferManager::GetByteSizeRequirement( const Shape& shape) const { if (shape.is_static() || shape.IsTuple()) { @@ -304,4 +285,5 @@ Shape GenericTransferManager::HostShapeToDeviceShape( } return device_shape; } + } // namespace xla diff --git a/third_party/xla/xla/service/generic_transfer_manager.h b/third_party/xla/xla/service/generic_transfer_manager.h index 1b66b1ad9a817b..bc376a6932c017 100644 --- a/third_party/xla/xla/service/generic_transfer_manager.h +++ b/third_party/xla/xla/service/generic_transfer_manager.h @@ -16,10 +16,24 @@ limitations under the License. #ifndef XLA_SERVICE_GENERIC_TRANSFER_MANAGER_H_ #define XLA_SERVICE_GENERIC_TRANSFER_MANAGER_H_ -#include - +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/node_hash_map.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "xla/literal.h" +#include "xla/service/shaped_buffer.h" #include "xla/service/transfer_manager.h" #include "xla/shape.h" +#include "xla/status.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/memory_allocation.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "xla/xla_data.pb.h" @@ -54,6 +68,7 @@ class GenericTransferManager : public TransferManager { Status TransferLiteralToInfeed(se::StreamExecutor* executor, const LiteralSlice& literal) override; + Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, MutableBorrowingLiteral literal) override; @@ -74,6 +89,42 @@ class GenericTransferManager : public TransferManager { // can only hold one value, but subclasses can override this. virtual bool PackSubbyteTypes() const { return false; } + // Transfer a memory block of the given size from the device source into the + // 'destination' buffer. + // + // size is the size to transfer to destination in bytes. + virtual Status TransferBufferFromDevice(se::Stream* stream, + const se::DeviceMemoryBase& source, + int64_t size, void* destination); + + // Transfer a memory block of the given size from 'source' buffer to the given + // destination of the device. + // + // size is the size to transfer from source in bytes. + virtual Status TransferBufferToDevice(se::Stream* stream, int64_t size, + const void* source, + se::DeviceMemoryBase* destination); + + // Transfers a buffer of packed int4 values from the device to the host, then + // unpacks them on the host. 'source' is a buffer with (num_elements+1)/2 + // bytes where each byte stores two int4 values. 'destination' is a buffer + // with num_elements bytes, where a single int4 value will be written to each + // byte in the lower 4 bits. + virtual Status TransferInt4ArrayFromDevice(se::Stream* stream, + const se::DeviceMemoryBase& source, + int64_t num_elements, + void* destination); + + // Packs an array of int4 values then transfers the packed buffer from the + // host to the device. 'source' is a buffer with num_elements bytes, where the + // lower 4 bits of each byte stores an int4 value. 'destination' is a buffer + // with (num_elements+1)/2 bytes, where two int4 values will be written into + // each byte. + virtual Status TransferInt4ArrayToDevice(se::Stream* stream, + int64_t num_elements, + const void* source, + se::DeviceMemoryBase* destination); + // The platform this transfer manager targets. const se::Platform::Id platform_id_; diff --git a/third_party/xla/xla/service/generic_transfer_manager_test.cc b/third_party/xla/xla/service/generic_transfer_manager_test.cc index 3e6cae2fb0bf5c..5ae01d81c97b1a 100644 --- a/third_party/xla/xla/service/generic_transfer_manager_test.cc +++ b/third_party/xla/xla/service/generic_transfer_manager_test.cc @@ -31,6 +31,7 @@ limitations under the License. #include "xla/shape_tree.h" #include "xla/shape_util.h" #include "xla/stream_executor/host/host_platform_id.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tests/literal_test_util.h" #include "xla/types.h" @@ -58,7 +59,7 @@ class GenericTransferManagerTest : public ::testing::Test { void SetUp() override { TF_ASSERT_OK_AND_ASSIGN( se::Platform * platform, - se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId)); + se::PlatformManager::PlatformWithId(se::host::kHostPlatformId)); TF_ASSERT_OK_AND_ASSIGN(stream_executor_, platform->ExecutorForDevice(0)); stream_.emplace(stream_executor_); stream_->Init(); diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 3dbb5f6d60deef..ffbff19cdfd567 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -9,6 +9,7 @@ load( "//xla/service/gpu:build_defs.bzl", "build_cub_sort_kernels", "get_cub_sort_kernel_types", + "gpu_kernel_library", ) load( "//xla/stream_executor:build_defs.bzl", @@ -20,7 +21,14 @@ load( "if_rocm_is_configured", "rocm_copts", ) -load("@local_tsl//tsl:tsl.bzl", "if_google", "if_nccl", "tsl_copts", "tsl_gpu_library") +load( + "@local_tsl//tsl:tsl.bzl", + "if_google", + "if_nccl", + "internal_visibility", + "tsl_copts", + "tsl_gpu_library", +) load("@local_tsl//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") load( "@local_tsl//tsl/platform:build_config.bzl", @@ -38,7 +46,8 @@ load( ) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) @@ -56,7 +65,6 @@ filegroup( "**/*.cc", "**/*.h", ]), - visibility = ["//visibility:public"], ) tf_proto_library( @@ -69,7 +77,6 @@ tf_proto_library( "//xla:autotuning_proto", "@local_tsl//tsl/protobuf:dnn_proto", ], - visibility = ["//visibility:public"], ) xla_cc_test( @@ -113,13 +120,11 @@ cc_library( cc_library( name = "gpu_constants", hdrs = ["gpu_constants.h"], - visibility = ["//visibility:public"], ) cc_library( name = "gpu_memory_space_assignment", hdrs = ["gpu_memory_space_assignment.h"], - visibility = ["//visibility:public"], deps = [ "//xla:status", "//xla/hlo/ir:hlo", @@ -139,7 +144,6 @@ cc_library( "launch_dimensions.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla:statusor", @@ -167,22 +171,14 @@ xla_cc_test( "//xla/ffi", "//xla/ffi:ffi_api", "//xla/hlo/ir:hlo", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/runtime:memref_view", - "//xla/runtime:module", - "//xla/runtime:module_registry", - "//xla/runtime/ffi:ffi_api", "//xla/service:custom_call_status", "//xla/service:custom_call_target_registry", "//xla/service:executable", "//xla/service:gpu_plugin", - "//xla/service/gpu/runtime:custom_call_registry", - "//xla/service/gpu/runtime:support", "//xla/stream_executor/gpu:gpu_types_header", "//xla/tests:client_library_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -215,7 +211,6 @@ cc_library( name = "hlo_to_ir_bindings", srcs = ["hlo_to_ir_bindings.cc"], hdrs = ["hlo_to_ir_bindings.h"], - visibility = ["//visibility:public"], deps = [ ":buffer_allocations", ":ir_emission_utils", @@ -239,7 +234,6 @@ cc_library( srcs = ["target_util.cc"], hdrs = ["target_util.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla:status", @@ -272,7 +266,6 @@ cc_library( srcs = ["gpu_device_info_for_tests.cc"], hdrs = ["gpu_device_info_for_tests.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "//xla/stream_executor:device_description", ], @@ -282,7 +275,6 @@ cc_library( name = "ir_emitter_context", srcs = ["ir_emitter_context.cc"], hdrs = ["ir_emitter_context.h"], - visibility = ["//visibility:public"], deps = [ ":gpu_constants", ":gpu_executable", @@ -293,6 +285,7 @@ cc_library( "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", "@llvm-project//llvm:ir_headers", "@llvm-project//mlir:IR", ], @@ -307,7 +300,6 @@ cc_library( ]) + if_rocm_hipblaslt([ "TF_HIPBLASLT=1", ]), - visibility = ["//visibility:public"], deps = [ ":backend_configs_cc", ":cublas_cudnn", @@ -333,6 +325,8 @@ cc_library( ":reduction_utils", ":target_util", ":thunk", + ":triton_call", + ":triton_fusion_analysis", "//xla:autotuning_proto_cc", "//xla:literal", "//xla:permutation_util", @@ -361,24 +355,28 @@ cc_library( "//xla/service/gpu/fusions:tiling_util", "//xla/service/gpu/kernels:custom_kernel", "//xla/service/gpu/kernels:topk_custom_kernel", - "//xla/service/gpu/runtime3:command_buffer_cmd", - "//xla/service/gpu/runtime3:command_buffer_cmd_emitter", - "//xla/service/gpu/runtime3:command_buffer_thunk", - "//xla/service/gpu/runtime3:conditional_thunk", - "//xla/service/gpu/runtime3:convolution_thunk", - "//xla/service/gpu/runtime3:copy_thunk", - "//xla/service/gpu/runtime3:custom_call_thunk", - "//xla/service/gpu/runtime3:fft_thunk", - "//xla/service/gpu/runtime3:fused_mha_thunk", - "//xla/service/gpu/runtime3:gemm_thunk", - "//xla/service/gpu/runtime3:infeed_thunk", - "//xla/service/gpu/runtime3:kernel_thunk", - "//xla/service/gpu/runtime3:norm_thunk", - "//xla/service/gpu/runtime3:outfeed_thunk", - "//xla/service/gpu/runtime3:replica_id_thunk", - "//xla/service/gpu/runtime3:send_recv_thunk", - "//xla/service/gpu/runtime3:sequential_thunk", - "//xla/service/gpu/runtime3:while_thunk", + "//xla/service/gpu/runtime:command_buffer_cmd", + "//xla/service/gpu/runtime:command_buffer_cmd_emitter", + "//xla/service/gpu/runtime:command_buffer_thunk", + "//xla/service/gpu/runtime:conditional_thunk", + "//xla/service/gpu/runtime:convolution_thunk", + "//xla/service/gpu/runtime:copy_thunk", + "//xla/service/gpu/runtime:custom_call_thunk", + "//xla/service/gpu/runtime:fft_thunk", + "//xla/service/gpu/runtime:fused_mha_thunk", + "//xla/service/gpu/runtime:gemm_thunk", + "//xla/service/gpu/runtime:infeed_thunk", + "//xla/service/gpu/runtime:kernel_thunk", + "//xla/service/gpu/runtime:nccl_all_gather_thunk", + "//xla/service/gpu/runtime:nccl_all_reduce_thunk", + "//xla/service/gpu/runtime:nccl_all_to_all_thunk", + "//xla/service/gpu/runtime:norm_thunk", + "//xla/service/gpu/runtime:outfeed_thunk", + "//xla/service/gpu/runtime:replica_id_thunk", + "//xla/service/gpu/runtime:send_recv_thunk", + "//xla/service/gpu/runtime:sequential_thunk", + "//xla/service/gpu/runtime:wait_for_streams_thunk", + "//xla/service/gpu/runtime:while_thunk", "//xla/service/llvm_ir:buffer_assignment_util", "//xla/service/llvm_ir:dynamic_update_slice_util", "//xla/service/llvm_ir:fused_ir_emitter", @@ -387,12 +385,12 @@ cc_library( "//xla/service/llvm_ir:llvm_util", "//xla/service/llvm_ir:sort_util", "//xla/stream_executor:device_description", + "//xla/stream_executor:launch_dim", "//xla/stream_executor/gpu:gpu_blas_lt", "//xla/translate/hlo_to_mhlo:hlo_utils", "//xla/translate/mhlo_to_hlo:attribute_exporter", "//xla/translate/mhlo_to_hlo:location_exporter", "//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo", - "//xla/translate/mhlo_to_lhlo_with_xla", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -416,8 +414,11 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:MemRefTransforms", "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Parser", "@llvm-project//mlir:ROCDLToLLVMIRTranslation", "@llvm-project//mlir:Support", "@llvm-project//mlir:ToLLVMIRTranslation", @@ -426,12 +427,15 @@ cc_library( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/protobuf:dnn_proto_cc", + "@triton//:TritonDialects", ] + if_gpu_is_configured([ ":ir_emitter_triton", - "//xla/service/gpu/runtime3:cholesky_thunk", - "//xla/service/gpu/runtime3:cub_sort_thunk", - "//xla/service/gpu/runtime3:gpublas_lt_matmul_thunk", - "//xla/service/gpu/runtime3:triangular_solve_thunk", + "//xla/service/gpu/runtime:cholesky_thunk", + "//xla/service/gpu/runtime:cub_sort_thunk", + "//xla/service/gpu/runtime:gpublas_lt_matmul_thunk", + "//xla/service/gpu/runtime:triangular_solve_thunk", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rocm_headers", ]), ) @@ -448,7 +452,6 @@ cc_library( "ir_emitter_nested.h", ], copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]), - visibility = ["//visibility:public"], deps = [ ":backend_configs_cc", ":hlo_to_ir_bindings", @@ -479,6 +482,7 @@ cc_library( "@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", "@local_tsl//tsl/platform:errors", ], ) @@ -489,7 +493,6 @@ cc_library( "ir_emitter_triton.cc", ]), hdrs = if_gpu_is_configured(["ir_emitter_triton.h"]), - visibility = ["//visibility:public"], deps = [ ":hlo_traversal", ":ir_emission_utils", @@ -513,9 +516,11 @@ cc_library( "//xla/mlir_hlo", "//xla/mlir_hlo:map_mhlo_to_scalar_op", "//xla/service:dump", + "//xla/service:hlo_module_config", "//xla/service/gpu/llvm_gpu_backend", "//xla/service/llvm_ir:llvm_util", "//xla/stream_executor:device_description", + "//xla/stream_executor:launch_dim", "//xla/translate/hlo_to_mhlo:hlo_module_importer", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -557,11 +562,10 @@ cc_library( "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:tensor_float_32_utils", "@triton//:TritonDialects", - "@triton//:TritonTmaMetadata", "@triton//:TritonTransforms", ] + if_cuda_is_configured([ - "@triton//:NVGPUToLLVM", - "@triton//:TritonGPUToLLVM", + "@triton//third_party/nvidia:NVGPUToLLVM", + "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", "@triton//:TritonGPUTransforms", "@triton//:TritonNvidiaGPUTransforms", "@triton//:TritonLLVMIR", @@ -577,18 +581,18 @@ xla_test( "gpu_v100", ], shard_count = 20, - tags = [ - "noasan", # TODO(b/319626241): Reenable once the signed-integer-overflow is fixed - "nomac", - ], + tags = ["nomac"], deps = [ ":backend_configs_cc", ":gpu_device_info_for_tests", ":ir_emission_utils", ":ir_emitter_triton", ":matmul_utils", + ":triton_fusion_analysis", "//xla:autotuning_proto_cc", "//xla:error_spec", + "//xla:literal", + "//xla:literal_util", "//xla:status_macros", "//xla:statusor", "//xla:xla_proto_cc", @@ -601,6 +605,7 @@ xla_test( "//xla/tests:filecheck", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", "@llvm-project//llvm:Support", @@ -672,7 +677,6 @@ cc_library( name = "triton_autotuner", srcs = if_cuda_is_configured(["triton_autotuner.cc"]), hdrs = if_cuda_is_configured(["triton_autotuner.h"]), - visibility = ["//visibility:public"], deps = if_cuda_is_configured([ ":autotuner_compile_util", ":autotuner_util", @@ -700,6 +704,7 @@ cc_library( "//xla:autotuning_proto_cc", "//xla:shape_util", "//xla:status_macros", + "//xla/tools:hlo_decomposer_lib", "//xla:status", "//xla:statusor", "//xla:util", @@ -761,6 +766,7 @@ xla_test( "//xla/tests:test_utils", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "//xla/tools:hlo_decomposer_lib", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", @@ -775,12 +781,26 @@ xla_test( ], ) +cc_library( + name = "triton_call", + srcs = if_gpu_is_configured(["triton_call.cc"]), + hdrs = ["triton_call.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), + deps = [ + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "parallel_loop_emitter", srcs = ["parallel_loop_emitter.cc"], hdrs = ["parallel_loop_emitter.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":launch_dimensions", ":target_util", @@ -802,7 +822,6 @@ cc_library( name = "buffer_allocations", srcs = ["buffer_allocations.cc"], hdrs = ["buffer_allocations.h"], - visibility = ["//visibility:public"], deps = [ "//xla:status", "//xla:statusor", @@ -820,18 +839,19 @@ cc_library( name = "thunk", srcs = ["thunk.cc"], hdrs = ["thunk.h"], - visibility = ["//visibility:public"], deps = [ + ":backend_configs_cc", ":buffer_allocations", ":gpu_executable_run_options", + ":nccl_api", ":nccl_clique", + ":nccl_clique_key", "//xla:executable_run_options", "//xla:status", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", "//xla/service:executable", "//xla/service:global_device_id", - "//xla/service/gpu:nccl_clique_key", "//xla/stream_executor", "//xla/translate/mhlo_to_hlo:location_exporter", "@com_google_absl//absl/algorithm:container", @@ -844,6 +864,7 @@ cc_library( "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@llvm-project//mlir:IR", + "@local_tsl//tsl/lib/gtl:int_type", "@local_tsl//tsl/platform:statusor", ], ) @@ -851,9 +872,6 @@ cc_library( cc_library( name = "nccl_collective_thunks", srcs = [ - "nccl_all_gather_thunk.cc", - "nccl_all_reduce_thunk.cc", - "nccl_all_to_all_thunk.cc", "nccl_collective_permute_thunk.cc", "nccl_collective_thunk.cc", "nccl_p2p_thunk_common.cc", @@ -861,19 +879,19 @@ cc_library( "nccl_send_thunk.cc", ], hdrs = [ - "nccl_all_gather_thunk.h", - "nccl_all_reduce_thunk.h", - "nccl_all_to_all_thunk.h", "nccl_collective_permute_thunk.h", "nccl_collective_thunk.h", "nccl_p2p_thunk_common.h", "nccl_recv_thunk.h", "nccl_send_thunk.h", ], - visibility = ["//visibility:public"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW=1", + ]), deps = [ ":backend_configs_cc", ":buffer_allocations", + ":gpu_executable_run_options", ":ir_emission_utils", ":nccl_api", ":nccl_clique", @@ -893,11 +911,13 @@ cc_library( "//xla/service:computation_placer", "//xla/service:global_device_id", "//xla/service:hlo_parser", - "//xla/service/gpu:gpu_executable_run_options", + "//xla/service:rendezvous", "//xla/service/llvm_ir:llvm_util", "//xla/stream_executor", "//xla/stream_executor/gpu:gpu_activation_header", + "//xla/stream_executor/gpu:gpu_driver_header", "//xla/stream_executor/gpu:gpu_stream", + "//xla/stream_executor/gpu:gpu_types_header", "//xla/translate/hlo_to_mhlo:hlo_utils", "//xla/translate/mhlo_to_hlo:attribute_exporter", "//xla/translate/mhlo_to_hlo:type_to_shape", @@ -914,6 +934,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:errors", @@ -934,42 +955,46 @@ cc_library( # have `if_nccl` and `if_gpu_configured` that do not compose. NCCL header included directly in # :nccl_api target and all other targets should use this header to launch collective operations. # This allows to minimize the spreading of #ifdef all over the XLA code base. - alias( name = "nccl_api", actual = if_nccl(":_nccl_api_impl", ":_nccl_api_stub"), - visibility = ["//visibility:public"], ) cc_library( name = "_nccl_api_impl", - srcs = if_cuda_is_configured( + srcs = if_gpu_is_configured( ["nccl_api.cc"], ["nccl_api_stub.cc"], ), hdrs = ["nccl_api.h"], compatible_with = get_compatible_with_portable(), - defines = if_cuda_is_configured(["XLA_ENABLE_XCCL"]), # TODO(ezhulenev): Remove! - visibility = ["//visibility:public"], deps = [ ":nccl_clique_key", "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/service:collective_ops_utils", "//xla/stream_executor", + "//xla/stream_executor/gpu:gpu_activation", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:btree", "@com_google_absl//absl/hash", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/concurrency:ref_count", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", ] + if_cuda_is_configured([ "@local_config_nccl//:nccl", "//xla/stream_executor/cuda:cuda_driver", "//xla/stream_executor/cuda:cuda_executor", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rccl", + "//xla/stream_executor/rocm:rocm_driver", + "//xla/stream_executor/rocm:rocm_executor", ]) + if_gpu_is_configured([ "//xla/stream_executor/gpu:gpu_stream", ]), @@ -980,7 +1005,6 @@ cc_library( srcs = ["nccl_api_stub.cc"], hdrs = ["nccl_api.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":nccl_clique_key", "//xla:shape_util", @@ -989,6 +1013,7 @@ cc_library( "//xla/stream_executor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/concurrency:ref_count", "@local_tsl//tsl/platform:logging", ], @@ -999,22 +1024,32 @@ cc_library( srcs = ["nccl_clique_key.cc"], hdrs = ["nccl_clique_key.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "//xla/service:global_device_id", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", ], ) +xla_cc_test( + name = "nccl_clique_key_test", + srcs = ["nccl_clique_key_test.cc"], + deps = [ + ":nccl_clique_key", + "//xla/service:global_device_id", + "@com_google_absl//absl/container:btree", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + cc_library( name = "nccl_clique", srcs = ["nccl_clique.cc"], hdrs = ["nccl_clique.h"], - visibility = ["//visibility:public"], deps = [ ":nccl_api", ":nccl_clique_key", @@ -1024,9 +1059,10 @@ cc_library( "//xla/service:global_device_id", "//xla/service:lockable", "//xla/service:rendezvous", + "//xla/stream_executor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/hash", @@ -1037,9 +1073,9 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/gtl:int_type", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:hash", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", ], @@ -1049,7 +1085,6 @@ cuda_library( name = "sleep_kernel", srcs = if_cuda_is_configured(["sleep_kernel.cu.cc"]), hdrs = if_cuda_is_configured(["sleep_kernel.h"]), - visibility = ["//visibility:public"], deps = ["@local_config_cuda//cuda:cuda_headers"], ) @@ -1057,9 +1092,11 @@ cc_library( name = "mock_nccl_xml_google", srcs = ["mock_nccl_xml.cc"], hdrs = ["mock_nccl_xml.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), tags = ["manual"], - visibility = ["//visibility:public"], + visibility = ["//visibility:private"], deps = [ "//xla:status", "@com_google_absl//absl/container:flat_hash_map", @@ -1073,6 +1110,8 @@ cc_library( "@local_tsl//tsl/platform:regexp", ] + if_cuda_is_configured([ "@local_config_nccl//:nccl", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rccl", ]), ) @@ -1080,7 +1119,9 @@ xla_cc_test( name = "mock_nccl_xml_test", size = "small", srcs = if_google(["mock_nccl_xml_test.cc"]), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), tags = tf_cuda_tests_tags(), deps = [ "//xla:status", @@ -1090,6 +1131,8 @@ xla_cc_test( ":mock_nccl_xml_google", ]) + if_cuda_is_configured([ "@local_config_nccl//:nccl", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rccl", ]), ) @@ -1097,7 +1140,6 @@ xla_cc_test( cc_library( name = "empty", compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], ) alias( @@ -1106,7 +1148,6 @@ alias( if_google(":_mock_nccl_utils_google", ":_mock_nccl_utils_default"), ":empty", ), - visibility = ["//visibility:public"], ) # Do not build mock_nccl_utils.cc in OSS. It uses the nccl internal cost model to estimate the @@ -1121,9 +1162,8 @@ cc_library( ]), # Override tsl_gpu_library()'s internal default value of ["//buildenv/target:gce"]. compatible_with = [], - defines = if_cuda_is_configured(["XLA_ENABLE_XCCL"]), # TODO(ezhulenev): Remove! tags = ["manual"], - visibility = ["//visibility:public"], + visibility = ["//visibility:private"], deps = if_cuda_is_configured([ ":gpu_executable_run_options", ":mock_nccl_xml_google", @@ -1157,11 +1197,13 @@ cc_library( "//xla/service:collective_ops_utils", "//xla/service:global_device_id", "//xla/service:rendezvous", + "//xla/service:lockable", "//xla/stream_executor", "//xla/stream_executor/gpu:gpu_activation", "//xla/stream_executor/gpu:gpu_stream", "//xla/stream_executor/gpu:gpu_types_header", "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/lib/gtl:int_type", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ]), @@ -1174,7 +1216,7 @@ cc_library( # Override tsl_gpu_library()'s internal default value of ["//buildenv/target:gce"]. compatible_with = [], tags = ["manual"], - visibility = ["//visibility:public"], + visibility = ["//visibility:private"], deps = if_gpu_is_configured([ ":gpu_executable_run_options", ":nccl_collective_thunks", @@ -1192,6 +1234,8 @@ cc_library( "//xla/stream_executor", ]) + if_cuda_is_configured([ "@local_config_nccl//:nccl", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rccl", ]), ) @@ -1201,28 +1245,6 @@ bool_flag( build_setting_default = if_google(True, False), ) -cc_library( - name = "non_atomically_upgradeable_rw_lock", - srcs = [], - hdrs = [ - "non_atomically_upgradeable_rw_lock.h", - ], - visibility = ["//visibility:public"], - deps = [ - "@com_google_absl//absl/synchronization", - ], -) - -xla_cc_test( - name = "non_atomically_upgradeable_rw_lock_test", - srcs = ["non_atomically_upgradeable_rw_lock_test.cc"], - deps = [ - ":non_atomically_upgradeable_rw_lock", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:test", - ], -) - cc_library( name = "gpu_executable", srcs = [ @@ -1234,36 +1256,33 @@ cc_library( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), - visibility = ["//visibility:public"], deps = [ ":buffer_allocations", ":gpu_constants", + ":gpu_executable_run_options", ":ir_emission_utils", ":nccl_clique", - ":non_atomically_upgradeable_rw_lock", + ":nccl_clique_key", ":stream_executor_util", ":thunk", + "//xla:executable_run_options", "//xla:shape_tree", "//xla:shape_util", "//xla:status", "//xla:status_macros", "//xla:util", "//xla/hlo/ir:hlo", - "//xla/mlir/runtime/ir:rt", - "//xla/mlir/runtime/transforms:compilation_pipeline_gpu", - "//xla/mlir/runtime/transforms:type_converter", - "//xla/runtime:executable", "//xla/service:buffer_assignment", "//xla/service:executable", "//xla/service:hlo_execution_profile", + "//xla/service:hlo_module_config", "//xla/service:hlo_parser", "//xla/service:rendezvous", "//xla/service:shaped_buffer", "//xla/service:stream_pool", "//xla/service:xla_debug_info_manager", - "//xla/service/gpu:nccl_clique_key", - "//xla/service/gpu/runtime:executable", - "//xla/service/gpu/runtime:tracing", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu/runtime:annotation", "//xla/stream_executor", "//xla/stream_executor:device_description", "//xla/stream_executor:device_memory", @@ -1275,11 +1294,15 @@ cc_library( "//xla/stream_executor/gpu:gpu_timer_header", "//xla/stream_executor/rocm:rocm_platform_id", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", @@ -1316,7 +1339,6 @@ cc_library( srcs = ["ir_emission_utils.cc"], hdrs = ["ir_emission_utils.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":hlo_traversal", ":target_util", @@ -1391,7 +1413,6 @@ cc_library( srcs = ["reduction_utils.cc"], hdrs = ["reduction_utils.h"], local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - visibility = ["//visibility:public"], deps = [ ":ir_emission_utils", "//xla:shape_util", @@ -1426,7 +1447,6 @@ cc_library( srcs = ["cublas_cudnn.cc"], hdrs = ["cublas_cudnn.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "//xla:status", "//xla:statusor", @@ -1438,35 +1458,21 @@ cc_library( ], ) -cuda_library( - name = "gpu_prim_cuda", - hdrs = ["gpu_prim_cuda.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - visibility = ["//visibility:public"], +gpu_kernel_library( + name = "gpu_prim", + hdrs = ["gpu_prim.h"], deps = [ "@eigen_archive//:eigen3", "@local_config_cuda//cuda:cuda_headers", "@local_tsl//tsl/platform:bfloat16", - ] + if_cuda_is_configured(xla_cub_deps()), + ] + if_cuda_is_configured(xla_cub_deps()) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rocprim", + ]), ) cc_library( name = "variant_visitor", hdrs = ["variant_visitor.h"], - visibility = ["//visibility:public"], -) - -cc_library( - name = "gpu_prim_rocm", - hdrs = ["gpu_prim_rocm.h"], - local_defines = if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), - visibility = ["//visibility:public"], - deps = [ - "@eigen_archive//:eigen3", - "@local_tsl//tsl/platform:bfloat16", - ] + if_rocm_is_configured([ - "@local_config_rocm//rocm:rocprim", - ]), ) build_cub_sort_kernels( @@ -1476,14 +1482,9 @@ build_cub_sort_kernels( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), - tags = [ - "ignore_for_dep=third_party/tensorflow/compiler/xla/service/gpu/gpu_prim_rocm.h", - ], types = get_cub_sort_kernel_types(), - deps = if_cuda_is_configured([ - ":gpu_prim_cuda", - ]) + if_rocm_is_configured([ - ":gpu_prim_rocm", + deps = if_gpu_is_configured([ + ":gpu_prim", ]), ) @@ -1494,7 +1495,6 @@ cc_library( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), - visibility = ["//visibility:public"], deps = [ ":backend_configs_cc", ":cublas_cudnn", @@ -1529,7 +1529,6 @@ cc_library( name = "triton_support", srcs = ["triton_support.cc"], hdrs = ["triton_support.h"], - visibility = ["//visibility:public"], deps = [ ":variant_visitor", "//xla:xla_data_proto_cc", @@ -1543,7 +1542,6 @@ cc_library( name = "triton_tiling_propagation", srcs = ["triton_tiling_propagation.cc"], hdrs = ["triton_tiling_propagation.h"], - visibility = ["//visibility:public"], deps = [ ":triton_support", "//xla:permutation_util", @@ -1566,7 +1564,6 @@ cc_library( name = "triton_fusion_analysis", srcs = ["triton_fusion_analysis.cc"], hdrs = ["triton_fusion_analysis.h"], - visibility = ["//visibility:public"], deps = [ ":matmul_utils", ":triton_tiling_propagation", @@ -1610,7 +1607,6 @@ cc_library( name = "gemm_rewriter_triton", srcs = ["gemm_rewriter_triton.cc"], hdrs = ["gemm_rewriter_triton.h"], - visibility = ["//visibility:public"], deps = [ ":backend_configs_cc", ":cublas_padding_requirements", @@ -1669,7 +1665,6 @@ cc_library( name = "split_k_gemm_rewriter", srcs = ["split_k_gemm_rewriter.cc"], hdrs = ["split_k_gemm_rewriter.h"], - visibility = ["//visibility:public"], deps = [ ":ir_emission_utils", ":matmul_utils", @@ -1731,7 +1726,6 @@ cc_library( name = "fusion_merger_triton", srcs = ["fusion_merger_triton.cc"], hdrs = ["fusion_merger_triton.h"], - visibility = ["//visibility:public"], deps = [ ":backend_configs_cc", ":gpu_fusible", @@ -1777,7 +1771,6 @@ cc_library( name = "softmax_rewriter_triton", srcs = ["softmax_rewriter_triton.cc"], hdrs = ["softmax_rewriter_triton.h"], - visibility = ["//visibility:public"], deps = [ ":backend_configs_cc", ":ir_emission_utils", @@ -1785,12 +1778,12 @@ cc_library( "//xla:shape_util", "//xla:status", "//xla:status_macros", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", "//xla/service:hlo_pass", + "//xla/service:instruction_fusion", "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -1812,7 +1805,6 @@ cc_library( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), - visibility = ["//visibility:public"], deps = if_gpu_is_configured([ ":backend_configs_cc", ":buffer_comparator", @@ -1848,7 +1840,6 @@ cc_library( name = "autotuner_util", srcs = if_gpu_is_configured(["autotuner_util.cc"]), hdrs = if_gpu_is_configured(["autotuner_util.h"]), - visibility = ["//visibility:public"], deps = if_gpu_is_configured([ ":gpu_asm_opts_util", ":stream_executor_util", @@ -1888,7 +1879,6 @@ cc_library( name = "autotuner_compile_util", srcs = if_gpu_is_configured(["autotuner_compile_util.cc"]), hdrs = if_gpu_is_configured(["autotuner_compile_util.h"]), - visibility = ["//visibility:public"], deps = if_gpu_is_configured([ ":autotuner_util", ":gpu_executable_run_options", @@ -1955,7 +1945,6 @@ cc_library( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), - visibility = ["//visibility:public"], deps = [ ":backend_configs_cc", ":ir_emission_utils", @@ -1969,11 +1958,11 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/mlir_hlo", - "//xla/mlir_hlo:lhlo_gpu", "//xla/stream_executor", "//xla/stream_executor/gpu:gpu_blas_lt", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:status", @@ -2011,7 +2000,6 @@ cc_library( name = "dot_dimension_sorter", srcs = ["dot_dimension_sorter.cc"], hdrs = ["dot_dimension_sorter.h"], - visibility = ["//visibility:public"], deps = [ "//xla:permutation_util", "//xla:shape_util", @@ -2049,7 +2037,6 @@ cc_library( name = "gpu_async_collective_annotator", srcs = ["gpu_async_collective_annotator.cc"], hdrs = ["gpu_async_collective_annotator.h"], - visibility = ["//visibility:public"], deps = [ ":backend_configs_cc", "//xla/hlo/utils:hlo_query", @@ -2076,7 +2063,6 @@ cc_library( name = "gpu_convert_async_collectives_to_sync", srcs = ["gpu_convert_async_collectives_to_sync.cc"], hdrs = ["gpu_convert_async_collectives_to_sync.h"], - visibility = ["//visibility:public"], deps = [ ":backend_configs_cc", "//xla/service:convert_async_collectives_to_sync", @@ -2106,7 +2092,6 @@ cc_library( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), - visibility = ["//visibility:public"], deps = if_gpu_is_configured([ ":autotuner_util", ":backend_configs_cc", @@ -2191,7 +2176,6 @@ cc_library( name = "gpu_conv_runner", srcs = ["gpu_conv_runner.cc"], hdrs = ["gpu_conv_runner.h"], - visibility = ["//visibility:public"], deps = [ ":backend_configs_cc", ":cublas_cudnn", @@ -2216,7 +2200,6 @@ cc_library( name = "gpu_norm_runner", srcs = ["gpu_norm_runner.cc"], hdrs = ["gpu_norm_runner.h"], - visibility = ["//visibility:public"], deps = [ ":backend_configs_cc", ":cublas_cudnn", @@ -2243,7 +2226,6 @@ cc_library( name = "gpu_fused_mha_runner", srcs = ["gpu_fused_mha_runner.cc"], hdrs = ["gpu_fused_mha_runner.h"], - visibility = ["//visibility:public"], deps = [ ":backend_configs_cc", ":cublas_cudnn", @@ -2259,6 +2241,7 @@ cc_library( "//xla/stream_executor", "//xla/stream_executor:dnn", "//xla/stream_executor:lazy_op_runner", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", ], @@ -2268,7 +2251,6 @@ cc_library( name = "gpu_conv_rewriter", srcs = ["gpu_conv_rewriter.cc"], hdrs = ["gpu_conv_rewriter.h"], - visibility = ["//visibility:public"], deps = [ ":backend_configs_cc", ":cublas_cudnn", @@ -2287,7 +2269,6 @@ cc_library( name = "gpu_sort_rewriter", srcs = if_gpu_is_configured(["gpu_sort_rewriter.cc"]), hdrs = if_gpu_is_configured(["gpu_sort_rewriter.h"]), - visibility = ["//visibility:public"], deps = [ ":cublas_cudnn", "//xla:comparison_util", @@ -2297,7 +2278,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", - "//xla/service/gpu/runtime3:cub_sort_thunk", + "//xla/service/gpu/runtime:cub_sort_thunk", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -2311,7 +2292,6 @@ cc_library( name = "move_copy_to_users", srcs = ["move_copy_to_users.cc"], hdrs = ["move_copy_to_users.h"], - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla:status", @@ -2386,7 +2366,6 @@ cc_library( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), - visibility = ["//visibility:public"], deps = [ "//xla:comparison_util", "//xla:status", @@ -2416,7 +2395,6 @@ cc_library( name = "cusolver_rewriter", srcs = if_gpu_is_configured(["cusolver_rewriter.cc"]), hdrs = if_gpu_is_configured(["cusolver_rewriter.h"]), - visibility = ["//visibility:public"], deps = if_gpu_is_configured([ ":cusolver_context", ":ir_emission_utils", @@ -2444,7 +2422,6 @@ cc_library( name = "instruction_fusion", srcs = ["instruction_fusion.cc"], hdrs = ["instruction_fusion.h"], - visibility = ["//visibility:public"], deps = [ ":gpu_fusible", "//xla:shape_util", @@ -2467,7 +2444,7 @@ xla_cc_test( name = "instruction_fusion_test", srcs = ["instruction_fusion_test.cc"], tags = [ - "no_arm64", + "no_aarch64", "nomsan", ], deps = [ @@ -2488,43 +2465,93 @@ tf_proto_library( name = "fusion_process_dump_proto", srcs = ["fusion_process_dump.proto"], cc_api_version = 2, - protodeps = [], - visibility = ["//visibility:public"], + protodeps = [ + "//xla/stream_executor:device_description_proto", + ], +) + +cc_library( + name = "fusion_process_dump", + srcs = ["fusion_process_dump.cc"], + hdrs = ["fusion_process_dump.h"], + deps = [ + ":fusion_process_dump_proto_cc", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_graph_dumper", + "//xla/stream_executor:stream_executor_headers", + "//xla/tools:hlo_module_loader", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "fusion_process_dump_test", + srcs = ["fusion_process_dump_test.cc"], + deps = [ + ":fusion_process_dump", + ":fusion_process_dump_proto_cc", + ":gpu_device_info_for_tests", + "//xla:test", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_parser", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + ], ) cc_library( name = "priority_fusion", srcs = ["priority_fusion.cc"], hdrs = ["priority_fusion.h"], - visibility = ["//visibility:public"], deps = [ ":fusion_process_dump_proto_cc", ":gpu_fusible", ":hlo_fusion_analysis", ":hlo_traversal", + "//xla:debug_options_flags", "//xla:shape_util", "//xla:statusor", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:dump", - "//xla/service:fusion_node_indexing_evaluation", "//xla/service:fusion_queue", "//xla/service:hlo_cost_analysis", + "//xla/service:hlo_graph_dumper", "//xla/service:hlo_pass", "//xla/service:instruction_fusion", "//xla/service/gpu/model:fusion_analysis_cache", "//xla/service/gpu/model:gpu_hlo_cost_analysis", "//xla/service/gpu/model:gpu_performance_model", + "//xla/service/gpu/model:gpu_performance_model_base", "//xla/stream_executor:device_description", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", + "@llvm-project//llvm:Support", "@local_tsl//tsl/platform:blocking_counter", "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:status", ], @@ -2547,6 +2574,7 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:status_matchers", ], @@ -2556,7 +2584,6 @@ cc_library( name = "multi_output_fusion", srcs = ["multi_output_fusion.cc"], hdrs = ["multi_output_fusion.h"], - visibility = ["//visibility:public"], deps = [ ":gpu_fusible", "//xla:debug_options_flags", @@ -2602,6 +2629,34 @@ xla_cc_test( ], ) +cc_library( + name = "rename_fusions", + srcs = ["rename_fusions.cc"], + hdrs = ["rename_fusions.h"], + deps = [ + ":hlo_traversal", + ":ir_emission_utils", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +xla_cc_test( + name = "rename_fusions_test", + srcs = ["rename_fusions_test.cc"], + deps = [ + ":rename_fusions", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + ], +) + xla_cc_test( name = "softmax_rewriter_triton_test", srcs = ["softmax_rewriter_triton_test.cc"], @@ -2610,13 +2665,13 @@ xla_cc_test( "//xla:shape_util", "//xla:statusor", "//xla/hlo/ir:hlo", + "//xla/service:instruction_fusion", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # build_cleaner: keep "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", @@ -2630,7 +2685,6 @@ cc_library( name = "gpu_sanitize_constant_names", srcs = ["gpu_sanitize_constant_names.cc"], hdrs = ["gpu_sanitize_constant_names.h"], - visibility = ["//visibility:public"], deps = [ "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", @@ -2660,7 +2714,6 @@ cc_library( name = "fusion_merger", srcs = ["fusion_merger.cc"], hdrs = ["fusion_merger.h"], - visibility = ["//visibility:public"], deps = [ ":gpu_fusible", "//xla:shape_util", @@ -2701,7 +2754,6 @@ cc_library( name = "gpu_conv_padding_legalization", srcs = ["gpu_conv_padding_legalization.cc"], hdrs = ["gpu_conv_padding_legalization.h"], - visibility = ["//visibility:public"], deps = [ ":cublas_cudnn", "//xla:literal", @@ -2738,7 +2790,6 @@ cc_library( name = "cudnn_support_utils", srcs = ["cudnn_support_utils.cc"], hdrs = ["cudnn_support_utils.h"], - visibility = ["//visibility:public"], deps = [ ":cublas_cudnn", "//xla:shape_util", @@ -2781,7 +2832,6 @@ cc_library( name = "cudnn_pad_for_convolutions", srcs = ["cudnn_pad_for_convolutions.cc"], hdrs = ["cudnn_pad_for_convolutions.h"], - visibility = ["//visibility:public"], deps = [ ":cublas_cudnn", ":cudnn_support_utils", @@ -2824,7 +2874,6 @@ cc_library( name = "cudnn_vectorize_convolutions", srcs = ["cudnn_vectorize_convolutions.cc"], hdrs = ["cudnn_vectorize_convolutions.h"], - visibility = ["//visibility:public"], deps = [ ":backend_configs_cc", ":cublas_cudnn", @@ -2877,7 +2926,6 @@ cc_library( name = "cudnn_simplify_padding", srcs = ["cudnn_simplify_padding.cc"], hdrs = ["cudnn_simplify_padding.h"], - visibility = ["//visibility:public"], deps = [ ":backend_configs_cc", ":cublas_cudnn", @@ -2934,7 +2982,6 @@ cc_library( name = "cublas_pad_for_gemms", srcs = ["cublas_pad_for_gemms.cc"], hdrs = ["cublas_pad_for_gemms.h"], - visibility = ["//visibility:public"], deps = [ ":gemm_rewriter_triton", ":ir_emission_utils", @@ -2957,7 +3004,6 @@ cc_library( name = "cublas_padding_requirements", srcs = ["cublas_padding_requirements.cc"], hdrs = ["cublas_padding_requirements.h"], - visibility = ["//visibility:public"], deps = [ ":variant_visitor", "//xla:shape_util", @@ -2992,20 +3038,17 @@ tf_proto_library( "//xla/service:hlo_proto", "//xla:xla_proto", ], - visibility = ["//visibility:public"], ) cc_library( name = "target_constants", hdrs = ["target_constants.h"], - visibility = ["//visibility:public"], ) cc_library( name = "gpu_transfer_manager", srcs = ["gpu_transfer_manager.cc"], hdrs = ["gpu_transfer_manager.h"], - visibility = ["//visibility:public"], deps = [ ":io_feed_manager", ":target_constants", @@ -3020,16 +3063,26 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/service:compiler", "//xla/service:generic_transfer_manager", + "//xla/service:shaped_buffer", "//xla/service:transfer_manager", "//xla/stream_executor", + "//xla/stream_executor:memory_allocation", "//xla/stream_executor/cuda:cuda_platform_id", "//xla/stream_executor/host:host_platform_id", "//xla/stream_executor/rocm:rocm_platform_id", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@llvm-project//llvm:Core", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:numbers", + "@local_tsl//tsl/platform:statusor", ], alwayslink = True, # Contains per-platform transfer manager registration ) @@ -3038,7 +3091,6 @@ cc_library( name = "gpu_reduce_scatter_creator", srcs = ["gpu_reduce_scatter_creator.cc"], hdrs = ["gpu_reduce_scatter_creator.h"], - visibility = ["//visibility:public"], deps = [ "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", @@ -3052,7 +3104,6 @@ cc_library( name = "gpu_all_gather_optimizer", srcs = ["gpu_all_gather_optimizer.cc"], hdrs = ["gpu_all_gather_optimizer.h"], - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla:statusor", @@ -3070,7 +3121,6 @@ cc_library( name = "gpu_float_support", srcs = ["gpu_float_support.cc"], hdrs = ["gpu_float_support.h"], - visibility = ["//visibility:public"], deps = [ "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", @@ -3087,7 +3137,6 @@ cc_library( hdrs = [ "compile_module_to_llvm_ir.h", ], - visibility = ["//visibility:public"], deps = [ ":executable_proto_cc", ":gpu_constants", @@ -3105,7 +3154,6 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/mlir/backends/gpu/transforms:passes", "//xla/mlir/runtime/transforms:compilation_pipeline_gpu", "//xla/mlir_hlo:transforms_gpu_passes", "//xla/service:buffer_assignment", @@ -3115,17 +3163,15 @@ cc_library( "//xla/service:hlo_ordering", "//xla/service:hlo_proto_cc", "//xla/service:logical_buffer", - "//xla/service/gpu/runtime:executable", - "//xla/service/gpu/runtime3:conditional_thunk", - "//xla/service/gpu/runtime3:sequential_thunk", - "//xla/service/gpu/runtime3:while_thunk", + "//xla/service/gpu/runtime:conditional_thunk", + "//xla/service/gpu/runtime:sequential_thunk", + "//xla/service/gpu/runtime:while_thunk", "//xla/service/llvm_ir:llvm_util", "//xla/stream_executor", "//xla/stream_executor:device_description", "//xla/stream_executor/rocm:rocm_platform_id", "//xla/translate/hlo_to_mhlo:hlo_utils", "//xla/translate/mhlo_to_hlo:location_exporter", - "//xla/translate/mhlo_to_lhlo_with_xla", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", @@ -3152,7 +3198,6 @@ cc_library( name = "command_buffer_scheduling", srcs = ["command_buffer_scheduling.cc"], hdrs = ["command_buffer_scheduling.h"], - visibility = ["//visibility:public"], deps = [ ":backend_configs_cc", ":cublas_cudnn", @@ -3164,6 +3209,8 @@ cc_library( "//xla:util", "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:hlo_traversal", "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -3185,6 +3232,7 @@ xla_cc_test( ":command_buffer_scheduling", "//xla/hlo/ir:hlo", "//xla/service:hlo_parser", + "//xla/stream_executor:device_description", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", @@ -3199,7 +3247,6 @@ cc_library( name = "custom_kernel_fusion_rewriter", srcs = ["custom_kernel_fusion_rewriter.cc"], hdrs = ["custom_kernel_fusion_rewriter.h"], - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla:statusor", @@ -3239,15 +3286,18 @@ cc_library( name = "address_computation_fusion_rewriter", srcs = ["address_computation_fusion_rewriter.cc"], hdrs = ["address_computation_fusion_rewriter.h"], - visibility = ["//visibility:public"], deps = [ ":backend_configs_cc", ":cublas_cudnn", ":hlo_traversal", ":ir_emission_utils", - "//xla:status_macros", + "//xla:shape_util", "//xla:statusor", + "//xla:util", + "//xla/ffi:ffi_api", + "//xla/ffi/api:c_api", "//xla/hlo/ir:hlo", + "//xla/service:custom_call_target_registry", "//xla/service:hlo_pass", "//xla/service/gpu/kernels:custom_fusion_library", "@com_google_absl//absl/algorithm:container", @@ -3264,13 +3314,27 @@ cc_library( xla_cc_test( name = "address_computation_fusion_rewriter_test", - srcs = ["address_computation_fusion_rewriter_test.cc"], + srcs = if_cuda_is_configured(["address_computation_fusion_rewriter_test.cc"]), deps = [ ":address_computation_fusion_rewriter", ":gpu_device_info_for_tests", + "//xla:shape_util", + "//xla/client:xla_builder", + "//xla/client/lib:constants", + "//xla/ffi", + "//xla/ffi:ffi_api", "//xla/hlo/ir:hlo", - "//xla/stream_executor:device_description", + "//xla/service:buffer_value", + "//xla/service:custom_call_target_registry", + "//xla/service:executable", + "//xla/service:hlo_memory_scheduler", + "//xla/service:hlo_module_config", + "//xla/stream_executor/gpu:gpu_types_header", "//xla/tests:hlo_test_base", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", ], @@ -3280,7 +3344,6 @@ cc_library( name = "fusion_pipeline", srcs = ["fusion_pipeline.cc"], hdrs = ["fusion_pipeline.h"], - visibility = ["//visibility:public"], deps = [ ":fusion_merger", ":horizontal_input_fusion", @@ -3288,6 +3351,7 @@ cc_library( ":instruction_fusion", ":multi_output_fusion", ":priority_fusion", + ":rename_fusions", ":variadic_op_splitter", "//xla:xla_proto_cc", "//xla/service:cpu_gpu_shape_verifier", @@ -3308,7 +3372,6 @@ cc_library( name = "prepare_hlo_for_ir_emitting_pipeline", srcs = ["prepare_hlo_for_ir_emitting_pipeline.cc"], hdrs = ["prepare_hlo_for_ir_emitting_pipeline.h"], - visibility = ["//visibility:public"], deps = [ ":alias_passthrough_params", ":copy_fusion", @@ -3338,8 +3401,8 @@ cc_library( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), - visibility = ["//visibility:public"], deps = if_gpu_is_configured([ + ":address_computation_fusion_rewriter", ":alias_passthrough_params", ":all_reduce_blueconnect", ":autotuner_util", @@ -3417,9 +3480,6 @@ cc_library( "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/transforms:hlo_constant_splitter", - "//xla/mlir/backends/gpu/transforms:passes", - "//xla/mlir/runtime/transforms:compilation_pipeline_gpu", - "//xla/runtime:jit_executable", "//xla/service:algebraic_simplifier", "//xla/service:all_gather_broadcast_reorder", "//xla/service:all_gather_combiner", @@ -3518,7 +3578,6 @@ cc_library( "//xla/stream_executor/cuda:cuda_platform_id", "//xla/translate/hlo_to_mhlo:hlo_utils", "//xla/translate/mhlo_to_hlo:location_exporter", - "//xla/translate/mhlo_to_lhlo_with_xla", "@local_tsl//tsl/platform:blocking_counter", "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:env", @@ -3540,15 +3599,12 @@ cc_library( "//xla:shape_util", "//xla/hlo/ir:hlo_module_group", "//xla/mlir/runtime/transforms:compilation_pipeline_options", - "//xla/runtime:compiler", - "//xla/runtime:executable", "//xla/service:buffer_value", "//xla/service:dynamic_dimension_inference", "//xla/service:hlo_cost_analysis", "//xla/service:hlo_ordering", "//xla/service:layout_assignment", "//xla/service:logical_buffer", - "//xla/service/gpu/runtime:executable", "//xla/stream_executor/rocm:rocm_platform_id", "@local_tsl//tsl/platform:numbers", ]) + xla_export_hlo_deps() + [ @@ -3558,22 +3614,24 @@ cc_library( ":ir_emitter_context", ":ir_emitter_unnested", ":prepare_hlo_for_ir_emitting_pipeline", + ":rename_fusions", ":thunk", + "//xla/stream_executor:platform_manager", "@llvm-project//mlir:FuncDialect", "@local_tsl//tsl/lib/monitoring:counter", ], ) -xla_cc_test( +xla_test( name = "gpu_compiler_test", srcs = ["gpu_compiler_test.cc"], - tags = tf_cuda_tests_tags(), + backends = ["gpu"], deps = [ ":horizontal_loop_fusion", ":metrics", "//xla:autotune_results_proto_cc", "//xla/service:buffer_assignment", - "//xla/service:gpu_plugin", + "//xla/service:hlo_module_config", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/service:xla_debug_info_manager", @@ -3609,7 +3667,6 @@ cc_library( srcs = if_cuda_is_configured([ "nvptx_compiler_registration.cc", ]), - visibility = ["//visibility:public"], deps = if_cuda_is_configured([ "//xla/stream_executor/cuda:cuda_platform_id", ":nvptx_compiler_impl", @@ -3632,7 +3689,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = if_cuda_is_configured([ ":autotuner_util", ":buffer_sharing", @@ -3783,8 +3839,8 @@ xla_cc_test( "//xla/service:gpu_plugin", "//xla/service:platform_util", "//xla/stream_executor", - "//xla/stream_executor:multi_platform_manager", "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # build_cleaner: keep "@com_google_absl//absl/strings", @@ -3799,7 +3855,6 @@ cc_library( "amdgpu_compiler_registration.cc", ], tags = ["manual"], - visibility = ["//visibility:public"], deps = [ ":amdgpu_compiler_impl", "//xla/stream_executor/rocm:rocm_platform_id", @@ -3816,7 +3871,6 @@ cc_library( "amdgpu_compiler.h", ], tags = ["manual"], - visibility = ["//visibility:public"], deps = [ ":conv_algorithm_picker", ":cublas_pad_for_gemms", @@ -3857,7 +3911,6 @@ cc_library( name = "all_reduce_blueconnect", srcs = ["all_reduce_blueconnect.cc"], hdrs = ["all_reduce_blueconnect.h"], - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla:status_macros", @@ -3900,7 +3953,6 @@ xla_cc_test( cc_library( name = "xfeed_queue", hdrs = ["xfeed_queue.h"], - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/synchronization", @@ -3920,7 +3972,6 @@ cc_library( "outfeed_manager.h", ], local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - visibility = ["//visibility:public"], deps = [ ":xfeed_queue", "//xla:literal", @@ -3942,7 +3993,6 @@ cc_library( name = "gpu_layout_assignment", srcs = ["gpu_layout_assignment.cc"], hdrs = ["gpu_layout_assignment.h"], - visibility = ["//visibility:public"], deps = [ ":backend_configs_cc", ":cublas_cudnn", @@ -3999,7 +4049,6 @@ cc_library( name = "gpu_schedule_postprocessing", srcs = ["gpu_schedule_postprocessing.cc"], hdrs = ["gpu_schedule_postprocessing.h"], - visibility = ["//visibility:public"], deps = [ ":backend_configs_cc", "//xla:statusor", @@ -4037,7 +4086,6 @@ cc_library( name = "gpu_hlo_schedule", srcs = ["gpu_hlo_schedule.cc"], hdrs = ["gpu_hlo_schedule.h"], - visibility = ["//visibility:public"], deps = [ ":backend_configs_cc", ":cublas_cudnn", @@ -4126,7 +4174,6 @@ cc_library( srcs = ["stream_executor_util.cc"], hdrs = ["stream_executor_util.h"], copts = tsl_copts(), - visibility = ["//visibility:public"], deps = [ ":cublas_cudnn", ":launch_dimensions", @@ -4139,6 +4186,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:hlo_module_config", "//xla/stream_executor", + "//xla/stream_executor:launch_dim", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -4146,6 +4194,7 @@ cc_library( "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/util:env_var", "@local_tsl//tsl/util/proto:proto_utils", ], @@ -4171,7 +4220,6 @@ cc_library( hdrs = ["gpu_asm_opts_util.h"], compatible_with = get_compatible_with_portable(), copts = tsl_copts(), - visibility = ["//visibility:public"], deps = [ "//xla:xla_proto_cc", "//xla/stream_executor/gpu:gpu_asm_opts", @@ -4184,7 +4232,6 @@ cc_library( srcs = ["hlo_fusion_analysis.cc"], hdrs = ["hlo_fusion_analysis.h"], local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - visibility = ["//visibility:public"], deps = [ ":backend_configs_cc", ":hlo_traversal", @@ -4227,7 +4274,6 @@ cc_library( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), - visibility = ["//visibility:public"], deps = if_gpu_is_configured([ ":buffer_comparator_kernel", ":gpu_asm_opts_util", @@ -4256,7 +4302,6 @@ cuda_library( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), - visibility = ["//visibility:public"], deps = if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", ]) + if_rocm_is_configured([ @@ -4277,7 +4322,7 @@ xla_cc_test( "//xla:types", "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", - "//xla/stream_executor:multi_platform_manager", + "//xla/stream_executor:platform_manager", "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:test", @@ -4292,7 +4337,6 @@ cc_library( name = "buffer_sharing", srcs = ["buffer_sharing.cc"], hdrs = ["buffer_sharing.h"], - visibility = ["//visibility:public"], deps = [ ":backend_configs_cc", ":cublas_cudnn", @@ -4312,7 +4356,6 @@ cc_library( name = "gpu_fusible", srcs = ["gpu_fusible.cc"], hdrs = ["gpu_fusible.h"], - visibility = ["//visibility:public"], deps = [ ":backend_configs_cc", ":hlo_traversal", @@ -4348,7 +4391,6 @@ cc_library( srcs = ["cudnn_fused_conv_rewriter.cc"], hdrs = ["cudnn_fused_conv_rewriter.h"], local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - visibility = ["//visibility:public"], deps = [ ":backend_configs_cc", ":cublas_cudnn", @@ -4431,7 +4473,6 @@ cc_library( srcs = ["cudnn_norm_rewriter.cc"], hdrs = ["cudnn_norm_rewriter.h"], local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - visibility = ["//visibility:public"], deps = [ ":backend_configs_cc", ":cublas_cudnn", @@ -4485,7 +4526,6 @@ cc_library( srcs = ["cudnn_fused_mha_rewriter.cc"], hdrs = ["cudnn_fused_mha_rewriter.h"], local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - visibility = ["//visibility:public"], deps = [ ":backend_configs_cc", ":cublas_cudnn", @@ -4520,7 +4560,6 @@ cc_library( name = "cudnn_fused_mha_transpose_fusion", srcs = ["cudnn_fused_mha_transpose_fusion.cc"], hdrs = ["cudnn_fused_mha_transpose_fusion.h"], - visibility = ["//visibility:public"], deps = [ ":backend_configs_cc", ":cublas_cudnn", @@ -4615,7 +4654,6 @@ cc_library( name = "variadic_op_splitter", srcs = ["variadic_op_splitter.cc"], hdrs = ["variadic_op_splitter.h"], - visibility = ["//visibility:public"], deps = [ "//xla:statusor", "//xla:util", @@ -4632,7 +4670,6 @@ cc_library( name = "gpu_scatter_expander", srcs = ["gpu_scatter_expander.cc"], hdrs = ["gpu_scatter_expander.h"], - visibility = ["//visibility:public"], deps = [ "//xla:statusor", "//xla/hlo/ir:hlo", @@ -4671,20 +4708,23 @@ tf_proto_library( "//xla/service:hlo_proto", "//xla:autotuning_proto", ], - visibility = ["//visibility:public"], ) cc_library( name = "hlo_algorithm_denylist", srcs = ["hlo_algorithm_denylist.cc"], hdrs = ["hlo_algorithm_denylist.h"], - visibility = ["//visibility:public"], deps = [ ":gpu_autotuning_proto_cc", "//xla:autotuning_proto_cc", "//xla:debug_options_flags", "//xla/stream_executor", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:status", ], ) @@ -4710,7 +4750,6 @@ cc_library( name = "alias_passthrough_params", srcs = ["alias_passthrough_params.cc"], hdrs = ["alias_passthrough_params.h"], - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla:statusor", @@ -4742,7 +4781,6 @@ cc_library( name = "horizontal_loop_fusion", srcs = ["horizontal_loop_fusion.cc"], hdrs = ["horizontal_loop_fusion.h"], - visibility = ["//visibility:public"], deps = [ ":gpu_fusible", "//xla:shape_util", @@ -4789,7 +4827,6 @@ cc_library( name = "horizontal_input_fusion", srcs = ["horizontal_input_fusion.cc"], hdrs = ["horizontal_input_fusion.h"], - visibility = ["//visibility:public"], deps = [ ":gpu_fusible", "//xla/hlo/ir:hlo", @@ -4822,7 +4859,6 @@ cc_library( name = "reduction_degenerate_dim_remover", srcs = ["reduction_degenerate_dim_remover.cc"], hdrs = ["reduction_degenerate_dim_remover.h"], - visibility = ["//visibility:public"], deps = [ ":ir_emission_utils", "//xla:shape_util", @@ -4842,7 +4878,6 @@ cc_library( name = "reduction_dimension_grouper", srcs = ["reduction_dimension_grouper.cc"], hdrs = ["reduction_dimension_grouper.h"], - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla:statusor", @@ -4858,7 +4893,6 @@ cc_library( name = "reduction_splitter", srcs = ["reduction_splitter.cc"], hdrs = ["reduction_splitter.h"], - visibility = ["//visibility:public"], deps = [ ":reduction_utils", "//xla:shape_util", @@ -4888,7 +4922,6 @@ cc_library( name = "reduction_layout_normalizer", srcs = ["reduction_layout_normalizer.cc"], hdrs = ["reduction_layout_normalizer.h"], - visibility = ["//visibility:public"], deps = [ ":ir_emission_utils", "//xla:shape_util", @@ -4909,7 +4942,6 @@ cc_library( name = "tree_reduction_rewriter", srcs = ["tree_reduction_rewriter.cc"], hdrs = ["tree_reduction_rewriter.h"], - visibility = ["//visibility:public"], deps = [ ":reduction_utils", "//xla:shape_util", @@ -4931,7 +4963,6 @@ cc_library( name = "gemm_broadcast_folding_rewriter", srcs = ["gemm_broadcast_folding_rewriter.cc"], hdrs = ["gemm_broadcast_folding_rewriter.h"], - visibility = ["//visibility:public"], deps = [ ":backend_configs_cc", ":cublas_cudnn", @@ -4951,7 +4982,6 @@ cc_library( name = "metrics", srcs = ["metrics.cc"], hdrs = ["metrics.h"], - visibility = ["//visibility:public"], deps = [ "@local_tsl//tsl/lib/monitoring:counter", "@local_tsl//tsl/lib/monitoring:gauge", @@ -4963,7 +4993,6 @@ cc_library( name = "dot_operand_converter", srcs = ["dot_operand_converter.cc"], hdrs = ["dot_operand_converter.h"], - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla/hlo/ir:hlo", @@ -5000,7 +5029,6 @@ cc_library( name = "make_batch_pointers", srcs = if_gpu_is_configured(["make_batch_pointers.cc"]), hdrs = if_gpu_is_configured(["make_batch_pointers.h"]), - visibility = ["//visibility:public"], deps = [ "//xla:status", "//xla:statusor", @@ -5022,7 +5050,6 @@ cc_library( cuda_library( name = "make_batch_pointers_kernel", srcs = if_cuda_is_configured(["make_batch_pointers.cu.cc"]), - visibility = ["//visibility:public"], deps = [ "@local_config_cuda//cuda:cuda_headers", # build_cleaner: keep ], @@ -5032,7 +5059,6 @@ cc_library( name = "triangular_solve_rewriter", srcs = ["triangular_solve_rewriter.cc"], hdrs = ["triangular_solve_rewriter.h"], - visibility = ["//visibility:public"], deps = [ ":cublas_cudnn", "//xla:statusor", @@ -5047,7 +5073,6 @@ tsl_gpu_library( name = "runtime_intrinsics", srcs = ["runtime_intrinsics.cc"], hdrs = ["runtime_intrinsics.h"], - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla:status", @@ -5088,7 +5113,6 @@ cc_library( name = "hlo_fusion_stats", srcs = ["hlo_fusion_stats.cc"], hdrs = ["hlo_fusion_stats.h"], - visibility = ["//visibility:public"], deps = [ "//xla:status", "//xla:statusor", @@ -5120,7 +5144,6 @@ cc_library( name = "scatter_slice_simplifier", srcs = ["scatter_slice_simplifier.cc"], hdrs = ["scatter_slice_simplifier.h"], - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla/hlo/ir:hlo", @@ -5148,7 +5171,6 @@ cc_library( name = "conv_layout_normalization", srcs = ["conv_layout_normalization.cc"], hdrs = ["conv_layout_normalization.h"], - visibility = ["//visibility:public"], deps = [ ":cublas_cudnn", "//xla:shape_util", @@ -5169,7 +5191,6 @@ cc_library( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), - visibility = ["//visibility:public"], deps = [ "//xla:executable_run_options", "//xla:shape_util", @@ -5195,7 +5216,6 @@ cc_library( name = "topk_splitter", srcs = ["topk_splitter.cc"], hdrs = ["topk_splitter.h"], - visibility = ["//visibility:public"], deps = [ "//xla:literal_util", "//xla:shape_util", @@ -5236,11 +5256,39 @@ xla_cc_test( ], ) +xla_cc_test( + name = "topk_test", + srcs = ["topk_test.cc"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), + tags = tf_gpu_tests_tags(), + deps = [ + "//xla:error_spec", + "//xla:shape_util", + "//xla:status", + "//xla:statusor", + "//xla:types", + "//xla/hlo/ir:hlo", + "//xla/service:gpu_plugin", + "//xla/service:hlo_pass", + "//xla/service:platform_util", + "//xla/service:topk_rewriter", + "//xla/service/gpu:topk_specializer", + "//xla/tests:hlo_test_base", + "//xla/tests:verified_hlo_module", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + cc_library( name = "copy_fusion", srcs = ["copy_fusion.cc"], hdrs = ["copy_fusion.h"], - visibility = ["//visibility:public"], deps = [ ":gpu_fusible", ":ir_emission_utils", @@ -5260,12 +5308,12 @@ cc_library( name = "kernel_reuse_cache", srcs = ["kernel_reuse_cache.cc"], hdrs = ["kernel_reuse_cache.h"], - visibility = ["//visibility:public"], deps = [ ":kernel_arguments", ":launch_dimensions", "//xla:util", "//xla/hlo/ir:hlo", + "//xla/stream_executor:launch_dim", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -5278,7 +5326,6 @@ cc_library( name = "kernel_arguments", srcs = ["kernel_arguments.cc"], hdrs = ["kernel_arguments.h"], - visibility = ["//visibility:public"], deps = [ ":gpu_constants", ":ir_emission_utils", @@ -5301,7 +5348,6 @@ cc_library( srcs = ["hlo_traversal.cc"], hdrs = ["hlo_traversal.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla/hlo/ir:hlo", @@ -5309,6 +5355,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", ], @@ -5329,7 +5376,6 @@ cc_library( name = "fusion_wrapper", srcs = ["fusion_wrapper.cc"], hdrs = ["fusion_wrapper.h"], - visibility = ["//visibility:public"], deps = [ ":gpu_fusible", "//xla:status_macros", @@ -5392,7 +5438,6 @@ cc_library( name = "loop_double_buffer_transformer", srcs = ["loop_double_buffer_transformer.cc"], hdrs = ["loop_double_buffer_transformer.h"], - visibility = ["//visibility:public"], deps = [ "//xla:status", "//xla:statusor", @@ -5460,7 +5505,6 @@ xla_cc_test( cc_library( name = "gpu_symbol_repository", hdrs = ["gpu_symbol_repository.h"], - visibility = ["//visibility:public"], deps = [ "//xla:xla_proto_cc", "//xla/service:symbol_repository", diff --git a/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc index 9d3ed700ea5e4c..6aec517ffe1a70 100644 --- a/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc +++ b/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc @@ -14,11 +14,13 @@ limitations under the License. ==============================================================================*/ #include "xla/service/gpu/address_computation_fusion_rewriter.h" -#include +#include #include #include #include +#include #include +#include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" @@ -28,13 +30,20 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/ffi_api.h" +#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_schedule.h" +#include "xla/service/custom_call_target_registry.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" -#include "xla/status_macros.h" +#include "xla/shape.h" +#include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -48,16 +57,61 @@ bool IsNoOp(const HloInstruction* hlo) { HloOpcode::kGetTupleElement>(hlo); } +bool IsCustomCall(const HloInstruction* hlo, absl::string_view platform_name) { + auto* custom_call = DynCast(hlo); + if (custom_call == nullptr) return false; + + // TODO(vuson): properly handle token by following + // `LhloDialectEmitter::EmitCustomCallOp`'s `CreateOperands` logic for + // `LhloDialectEmitter::EmitFusionOp`'s `RewriteFusionOperand` + if (custom_call->shape().IsTuple() && + absl::c_any_of( + custom_call->shape().tuple_shapes(), + [&](const Shape& sub_shape) { return sub_shape.IsToken(); })) + return false; + + const std::string call_target_name = custom_call->custom_call_target(); + + bool is_ffi_custom_call = + custom_call->api_version() == CustomCallApiVersion::API_VERSION_TYPED_FFI; + + void* call_target = CustomCallTargetRegistry::Global()->Lookup( + call_target_name, std::string(platform_name)); + + absl::StatusOr handler = + ffi::FindHandler(call_target_name, platform_name); + + // At least one implementation should be available at run time. + bool found_custom_call = !is_ffi_custom_call && call_target != nullptr; + bool found_ffi_handler = is_ffi_custom_call && handler.ok(); + + return found_custom_call || found_ffi_handler; +} + absl::InlinedVector GetSlicedOperandChains( const HloInstruction* instr) { absl::InlinedVector sliced_operand_chains = { const_cast(instr)}; auto fusion = HloFusionAdaptor::ForComputation(instr->parent()); + absl::flat_hash_set processed_sliced_chain_set; + + const auto& aliasing_pairs = + Cast(instr)->output_to_operand_aliasing(); + absl::flat_hash_set aliased_operands; + for (const auto& pair : aliasing_pairs) { + aliased_operands.insert(pair.second.first); + } + for (auto* operand : instr->operands()) { + // output_to_operand_aliasing means the operand is to be materialized, which + // is against the whole idea of address computation fusion. Skip this + // operand. + if (aliased_operands.contains(instr->operand_index(operand))) continue; absl::InlinedVector maybe_sliced_operand_chain; auto maybe_slice_adaptor = HloFindIf({HloInstructionAdaptor(*operand)}, *fusion, [&](auto node) { const HloInstruction* cur = &node.instruction(); + if (processed_sliced_chain_set.contains(cur)) return true; maybe_sliced_operand_chain.push_back( const_cast(cur)); // TODO(vuson): lift the first restriction by considering fusing other @@ -70,10 +124,13 @@ absl::InlinedVector GetSlicedOperandChains( }); if (maybe_slice_adaptor == std::nullopt) continue; const auto& maybe_slice_instr = maybe_slice_adaptor->instruction(); - if (IsContiguousSlice(maybe_slice_instr)) { + if (IsContiguousSlice(maybe_slice_instr) || + processed_sliced_chain_set.contains(&maybe_slice_instr)) { sliced_operand_chains.insert(sliced_operand_chains.end(), maybe_sliced_operand_chain.begin(), maybe_sliced_operand_chain.end()); + processed_sliced_chain_set.insert(maybe_sliced_operand_chain.begin(), + maybe_sliced_operand_chain.end()); } } return sliced_operand_chains; @@ -98,6 +155,48 @@ absl::InlinedVector GetPatternCaptures( return captures; } +absl::InlinedVector GetSortedMatched( + absl::Span matched) { + absl::InlinedVector sorted_matched; + absl::flat_hash_set instructions_set(matched.begin(), + matched.end()); + absl::flat_hash_set processed_set; + // Topologically sort `matched` + for (auto it = matched.rbegin(); it != matched.rend(); ++it) { + if (processed_set.contains(*it)) continue; + for (auto* operand : (*it)->operands()) { + if (!instructions_set.contains(operand)) { + continue; + } + if (!processed_set.contains(operand)) { + sorted_matched.emplace_back(operand); + processed_set.insert(operand); + } + } + sorted_matched.emplace_back(*it); + processed_set.insert(*it); + } + + return sorted_matched; +} + +void CreateRootTuple(HloInstruction* root, HloComputation::Builder& builder) { + std::vector elements; + elements.reserve(root->shape().tuple_shapes_size()); + for (size_t i = 0; i < root->shape().tuple_shapes_size(); ++i) { + if (root->shape().tuple_shapes(i).IsTuple()) { + HloInstruction* gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(root, i)); + CreateRootTuple(gte, builder); + elements.push_back(builder.last_added_instruction()); + } else { + elements.push_back(builder.AddInstruction( + HloInstruction::CreateGetTupleElement(root, i))); + } + } + builder.AddInstruction(HloInstruction::CreateTuple(elements)); +} + absl::StatusOr CreateFusionBody( HloModule* module, absl::Span matched, absl::Span captures) { @@ -131,16 +230,10 @@ absl::StatusOr CreateFusionBody( } HloInstruction* root = builder.last_added_instruction(); - - // If the custom call requires a workspace we wrap the produced values with a - // root tuple of "real" result and a workspace. - if (root->shape().IsTuple()) { - TF_RET_CHECK(root->shape().tuple_shapes_size() == 2); - HloInstruction* result = - builder.AddInstruction(HloInstruction::CreateGetTupleElement(root, 0)); - HloInstruction* workspace = - builder.AddInstruction(HloInstruction::CreateGetTupleElement(root, 1)); - builder.AddInstruction(HloInstruction::CreateTuple({result, workspace})); + // Create a root tuple if the root is a tuple to make sure there's a buffer + // assigned for each of the elements. Make sure the tuple is not nil first. + if (root->shape().IsTuple() && root->shape().tuple_shapes_size() > 0) { + CreateRootTuple(root, builder); } return module->AddComputationAndUnifyNamesAndIds(builder.Build(), false); @@ -157,6 +250,9 @@ absl::StatusOr CreateFusionInstruction( captures, body)); module->SetAndUniquifyInstrName(fusion, "address_computation"); + // We don't need to set/update output_to_operand_aliasing for the new fusion + // instruction because all buffers are already assigned at this point. + // Set backends config to a matched custom fusion config. GpuBackendConfig gpu_config; FusionBackendConfig& backend_config = @@ -175,7 +271,7 @@ absl::StatusOr CreateFusionInstruction( absl::StatusOr AddressComputationFusionRewriter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { - auto instructions = module->entry_computation()->MakeInstructionPostOrder(); + if (!module->has_schedule()) return Internal("module is not scheduled"); bool changed = false; absl::flat_hash_map> @@ -185,7 +281,7 @@ absl::StatusOr AddressComputationFusionRewriter::Run( for (HloComputation* computation : module->computations()) { if (computation->IsFusionComputation()) continue; for (HloInstruction* instr : computation->instructions()) { - if (IsLegacyCublasMatmul(*instr)) { + if (IsLegacyCublasMatmul(*instr) || IsCustomCall(instr, platform_name_)) { auto sliced_operand_chains = GetSlicedOperandChains(instr); if (!(sliced_operand_chains.size() == 1 && sliced_operand_chains.front() == instr)) { @@ -195,19 +291,36 @@ absl::StatusOr AddressComputationFusionRewriter::Run( } } + HloSchedule& schedule = module->schedule(); for (auto& kv : matches) { auto captures = GetPatternCaptures(kv.second); - std::reverse(kv.second.begin(), kv.second.end()); + auto sorted = GetSortedMatched(kv.second); + TF_ASSIGN_OR_RETURN(HloComputation * fusion_body, - CreateFusionBody(module, kv.second, captures)); + CreateFusionBody(module, sorted, captures)); TF_ASSIGN_OR_RETURN( HloInstruction * fusion, CreateFusionInstruction(module, kv.first, captures, fusion_body)); + + // As we are running after scheduling we have to keep it valid. HloComputation* parent = kv.first->parent(); + + // Update schedule to replace the custom call instruction with the fusion + // instruction. + // Removal of the rest of the instructions in the sequence is handled by + // schedule update below. + HloInstructionSequence& sequence = schedule.GetOrCreateSequence(parent); + sequence.replace_instruction(kv.first, fusion); + + // TODO(vuson): handle control dependencies TF_RETURN_IF_ERROR(parent->ReplaceInstruction(kv.first, fusion)); changed = true; } + if (changed) { + TF_RETURN_IF_ERROR(module->schedule().Update()); + } + return changed; } diff --git a/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.h b/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.h index 6731fad070e75e..d2fc6fac228f30 100644 --- a/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.h +++ b/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.h @@ -15,11 +15,14 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_ADDRESS_COMPUTATION_FUSION_REWRITER_H_ #define XLA_SERVICE_GPU_ADDRESS_COMPUTATION_FUSION_REWRITER_H_ +#include +#include + #include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" -#include "xla/statusor.h" namespace xla { namespace gpu { @@ -70,9 +73,15 @@ class AddressComputationFusionRewriter : public HloModulePass { return "address-computation-fusion-rewriter"; } + explicit AddressComputationFusionRewriter(std::string platform_name) + : platform_name_(std::move(platform_name)) {} + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; + + private: + std::string platform_name_; }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter_test.cc b/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter_test.cc index 22d7e8684c4cf7..8a420c647ab5ee 100644 --- a/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter_test.cc @@ -15,23 +15,46 @@ limitations under the License. #include "xla/service/gpu/address_computation_fusion_rewriter.h" +#include +#include +#include #include - +#include + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "xla/client/lib/constants.h" +#include "xla/client/xla_builder.h" +#include "xla/ffi/ffi.h" +#include "xla/ffi/ffi_api.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_schedule.h" +#include "xla/service/buffer_value.h" +#include "xla/service/custom_call_target_registry.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/hlo_memory_scheduler.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/service_executable_run_options.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/gpu/gpu_types.h" #include "xla/tests/hlo_test_base.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" +#define PLATFORM "GPU" namespace xla::gpu { class AddressComputationFusionRewriterTest : public HloTestBase {}; TEST_F(AddressComputationFusionRewriterTest, SimpleGemm) { const char* hlo = R"( - HloModule test + HloModule test, is_scheduled=true ENTRY %main.9 { - %p0 = f16[2,8,8]{2,1,0} parameter(0), sharding={replicated} - %p1 = f16[2,8,8]{2,1,0} parameter(1), sharding={replicated} + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %p1 = f16[2,8,8]{2,1,0} parameter(1) %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]} %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]} @@ -82,16 +105,20 @@ TEST_F(AddressComputationFusionRewriterTest, SimpleGemm) { )"; auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(), expected); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + expected, [](HloModule* module) { + EXPECT_TRUE(module->has_schedule()); + TF_CHECK_OK(module->schedule().Verify()); + }); } TEST_F(AddressComputationFusionRewriterTest, SimpleGemmWithWorkspace) { const char* hlo = R"( - HloModule test + HloModule test, is_scheduled=true ENTRY %main.9 { - %p0 = f16[2,8,8]{2,1,0} parameter(0), sharding={replicated} - %p1 = f16[2,8,8]{2,1,0} parameter(1), sharding={replicated} + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %p1 = f16[2,8,8]{2,1,0} parameter(1) %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]} %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]} @@ -146,16 +173,20 @@ TEST_F(AddressComputationFusionRewriterTest, SimpleGemmWithWorkspace) { )"; auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(), expected); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + expected, [](HloModule* module) { + EXPECT_TRUE(module->has_schedule()); + TF_CHECK_OK(module->schedule().Verify()); + }); } TEST_F(AddressComputationFusionRewriterTest, SimpleGemmNotRoot) { const char* hlo = R"( - HloModule test + HloModule test, is_scheduled=true ENTRY %main.9 { - %p0 = f16[2,8,8]{2,1,0} parameter(0), sharding={replicated} - %p1 = f16[2,8,8]{2,1,0} parameter(1), sharding={replicated} + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %p1 = f16[2,8,8]{2,1,0} parameter(1) %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]} %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]} @@ -208,17 +239,21 @@ TEST_F(AddressComputationFusionRewriterTest, SimpleGemmNotRoot) { )"; auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(), expected); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + expected, [](HloModule* module) { + EXPECT_TRUE(module->has_schedule()); + TF_CHECK_OK(module->schedule().Verify()); + }); } TEST_F(AddressComputationFusionRewriterTest, SimpleGemmOperandHasMultipleUsers) { const char* hlo = R"( - HloModule test + HloModule test, is_scheduled=true ENTRY %main.9 { - %p0 = f16[2,8,8]{2,1,0} parameter(0), sharding={replicated} - %p1 = f16[2,8,8]{2,1,0} parameter(1), sharding={replicated} + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %p1 = f16[2,8,8]{2,1,0} parameter(1) %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]} %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]} @@ -273,17 +308,21 @@ TEST_F(AddressComputationFusionRewriterTest, )"; auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(), expected); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + expected, [](HloModule* module) { + EXPECT_TRUE(module->has_schedule()); + TF_CHECK_OK(module->schedule().Verify()); + }); } TEST_F(AddressComputationFusionRewriterTest, SimpleGemmOperandsHaveMultipleUsers) { const char* hlo = R"( - HloModule test + HloModule test, is_scheduled=true ENTRY %main.9 { - %p0 = f16[2,8,8]{2,1,0} parameter(0), sharding={replicated} - %p1 = f16[2,8,8]{2,1,0} parameter(1), sharding={replicated} + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %p1 = f16[2,8,8]{2,1,0} parameter(1) %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]} %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]} @@ -332,17 +371,17 @@ TEST_F(AddressComputationFusionRewriterTest, )"; auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(), + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), std::nullopt); } TEST_F(AddressComputationFusionRewriterTest, SimpleGemmSlicingNotParameter) { const char* hlo = R"( - HloModule test + HloModule test, is_scheduled=true ENTRY %main.9 { - %p0 = f16[4,8,8]{2,1,0} parameter(0), sharding={replicated} - %p1 = f16[2,8,8]{2,1,0} parameter(1), sharding={replicated} + %p0 = f16[4,8,8]{2,1,0} parameter(0) + %p1 = f16[2,8,8]{2,1,0} parameter(1) %slice.12 = f16[2,8,8]{2,1,0} slice(%p0), slice={[0:2], [0:8], [0:8]} %slice.13 = f16[1,8,8]{2,1,0} slice(%slice.12), slice={[1:2], [0:8], [0:8]} %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) @@ -399,16 +438,20 @@ TEST_F(AddressComputationFusionRewriterTest, SimpleGemmSlicingNotParameter) { )"; auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(), expected); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + expected, [](HloModule* module) { + EXPECT_TRUE(module->has_schedule()); + TF_CHECK_OK(module->schedule().Verify()); + }); } TEST_F(AddressComputationFusionRewriterTest, SimpleGemmNotContiguousSlice) { const char* hlo = R"( - HloModule test + HloModule test, is_scheduled=true ENTRY %main.9 { - %p0 = f16[2,8,8]{2,1,0} parameter(0), sharding={replicated} - %p1 = f16[2,8,8]{2,1,0} parameter(1), sharding={replicated} + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %p1 = f16[2,8,8]{2,1,0} parameter(1) %slice.13 = f16[1,4,6]{2,1,0} slice(%p0), slice={[1:2], [0:4], [0:6]} %bitcast.41 = f16[4,6]{1,0} bitcast(%slice.13) %slice.14 = f16[1,6,4]{2,1,0} slice(%p1), slice={[1:2], [0:6], [0:4]} @@ -437,17 +480,17 @@ TEST_F(AddressComputationFusionRewriterTest, SimpleGemmNotContiguousSlice) { )"; auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(), + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), std::nullopt); } TEST_F(AddressComputationFusionRewriterTest, SimpleGemmNonNoOpInSliceChain) { const char* hlo = R"( - HloModule test + HloModule test, is_scheduled=true ENTRY %main.9 { - %p0 = f16[2,8,8]{2,1,0} parameter(0), sharding={replicated} - %p1 = f16[2,8,8]{2,1,0} parameter(1), sharding={replicated} + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %p1 = f16[2,8,8]{2,1,0} parameter(1) %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[0:1], [0:8], [0:8]} %slice.14 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]} %add.0 = f16[1,8,8]{2,1,0} add(%slice.13, %slice.14) @@ -480,8 +523,481 @@ TEST_F(AddressComputationFusionRewriterTest, SimpleGemmNonNoOpInSliceChain) { )"; auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(), + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), std::nullopt); } +TEST_F(AddressComputationFusionRewriterTest, SimpleGemmDuplicateOperand) { + const char* hlo = R"( + HloModule test, is_scheduled=true + + ENTRY %main { + %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0) + %get-tuple-element.240 = f32[100,100]{1,0} get-tuple-element(%p0), index=0 + %get-tuple-element.241 = f32[100,100]{1,0} get-tuple-element(%p0), index=1 + %concatenate.10 = f32[200,100]{1,0} concatenate(%get-tuple-element.240, %get-tuple-element.241), dimensions={0} + %custom-call.16 = (f32[200,100]{1,0}, s8[120000]{0}) custom-call(%concatenate.10, %get-tuple-element.240), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"20000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + %get-tuple-element.97 = f32[200,100]{1,0} get-tuple-element(%custom-call.16), index=0 + %slice.26 = f32[100,100]{1,0} slice(%get-tuple-element.97), slice={[0:100], [0:100]} + ROOT %custom-call.17 = (f32[100,100]{1,0}, s8[80000]{0}) custom-call(%slice.26, %slice.26), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"10000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + })"; + + const char* expected = R"( + ; CHECK: %address-computation {{.*}} { + ; CHECK: [[P0:%[^ ]+]] = f32[200,100]{1,0} parameter(0) + ; CHECK: [[S0:%[^ ]+]] = f32[100,100]{1,0} slice([[P0]]), slice={[0:100], [0:100]} + ; CHECK-NOT: slice + ; CHECK: [[CC:%[^ ]+]] = (f32[100,100]{1,0}, s8[80000]{0}) custom-call([[S0]], [[S0]]), + ; CHECK: custom_call_target="__cublas$gemm" + ; CHECK: } + + ; CHECK: ENTRY %main{{.*}} { + ; CHECK: ROOT [[FUSION:%[^ ]+]] = (f32[100,100]{1,0}, s8[80000]{0}) fusion + ; CHECK: kind=kCustom, calls=%address-computation, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: } + ; CHECK: } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + expected, [](HloModule* module) { + EXPECT_TRUE(module->has_schedule()); + TF_CHECK_OK(module->schedule().Verify()); + }); +} + +TEST_F(AddressComputationFusionRewriterTest, SimpleGemmReverseOperandOrder) { + const char* hlo = R"( + HloModule test, is_scheduled=true + + ENTRY %main.9 { + %p0 = f16[2,8,8]{2,1,0} parameter(1) + %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[0:1], [0:8], [0:8]} + %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) + %p1 = f16[2,8,8]{2,1,0} parameter(0) + %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]} + %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14) + + ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + } + )"; + + const char* expected = R"( + ; CHECK: %address-computation {{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0) + ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1) + ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[0:1], [0:8], [0:8]} + ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]]) + ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]} + ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]]) + ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]), + ; CHECK: custom_call_target="__cublas$gemm" + ; CHECK: } + + ; CHECK: ENTRY %main{{.*}} { + ; CHECK-DAG: [[A0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1) + ; CHECK-DAG: [[A1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0) + ; CHECK: ROOT [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion([[A0]], [[A1]]) + ; CHECK: kind=kCustom, calls=%address-computation, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: } + ; CHECK: } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + expected, [](HloModule* module) { + EXPECT_TRUE(module->has_schedule()); + TF_CHECK_OK(module->schedule().Verify()); + }); +} + +TEST_F(AddressComputationFusionRewriterTest, SimpleGemmReverseOperandOrder2) { + const char* hlo = R"( + HloModule test, is_scheduled=true + + ENTRY %main.9 { + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[0:1], [0:8], [0:8]} + %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) + %p1 = f16[2,8,8]{2,1,0} parameter(1) + %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]} + %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14) + + ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.42, %bitcast.41), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + } + )"; + + const char* expected = R"( + ; CHECK: %address-computation {{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0) + ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1) + ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]} + ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]]) + ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[0:1], [0:8], [0:8]} + ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]]) + ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]), + ; CHECK: custom_call_target="__cublas$gemm" + ; CHECK: } + + ; CHECK: ENTRY %main{{.*}} { + ; CHECK-DAG: [[A0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1) + ; CHECK-DAG: [[A1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0) + ; CHECK: ROOT [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion([[A0]], [[A1]]) + ; CHECK: kind=kCustom, calls=%address-computation, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: } + ; CHECK: } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + expected, [](HloModule* module) { + EXPECT_TRUE(module->has_schedule()); + TF_CHECK_OK(module->schedule().Verify()); + }); +} + +TEST_F(AddressComputationFusionRewriterTest, SimpleGemmOperandAliasingOutput) { + const char* hlo = R"( + HloModule test, is_scheduled=true + + ENTRY %main.9 { + %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0) + %get-tuple-element.287 = f32[100,100]{1,0} get-tuple-element(%p0), index=0 + %get-tuple-element.288 = f32[100,100]{1,0} get-tuple-element(%p0), index=1 + %concatenate.12 = f32[200,100]{1,0} concatenate(%get-tuple-element.287, %get-tuple-element.288), dimensions={0} + %slice.30 = f32[100,100]{1,0} slice(%concatenate.12), slice={[20:120], [0:100]} + %slice.34 = f32[100,100]{1,0} slice(%concatenate.12), slice={[99:199], [0:100]} + ROOT %cublas-gemm.15 = (f32[100,100]{1,0}, s8[120000]{0}) custom-call(%get-tuple-element.287, %slice.30, %slice.34), + custom_call_target="__cublas$gemm", + output_to_operand_aliasing={{0}: (2, {})}, + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":1, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"10000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + }} + } + )"; + + const char* expected = R"( + ; CHECK: %address-computation {{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = f32[100,100]{1,0} parameter(0) + ; CHECK-DAG: [[P1:%[^ ]+]] = f32[100,100]{1,0} parameter(1) + ; CHECK-DAG: [[P2:%[^ ]+]] = f32[200,100]{1,0} parameter(2) + ; CHECK-DAG: [[S1:%[^ ]+]] = f32[100,100]{1,0} slice([[P2]]), slice={[20:120], [0:100]} + ; CHECK: [[CC:%[^ ]+]] = (f32[100,100]{1,0}, s8[120000]{0}) custom-call([[P0]], [[S1]], [[P1]]), + ; CHECK: custom_call_target="__cublas$gemm" + ; CHECK: } + + ; CHECK: ENTRY %main{{.*}} { + ; CHECK: [[P:%[^ ]+]] = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0) + ; CHECK: [[GTE0:%[^ ]+]] = f32[100,100]{1,0} get-tuple-element([[P]]), index=0 + ; CHECK: [[GTE1:%[^ ]+]] = f32[100,100]{1,0} get-tuple-element([[P]]), index=1 + ; CHECK: [[CONCAT:%[^ ]+]] = f32[200,100]{1,0} concatenate([[GTE0]], [[GTE1]]), dimensions={0} + ; CHECK: [[S:%[^ ]+]] = f32[100,100]{1,0} slice([[CONCAT]]), slice={[99:199], [0:100]} + ; CHECK: ROOT [[FUSION:%[^ ]+]] = (f32[100,100]{1,0}, s8[120000]{0}) fusion([[GTE0]], [[S]], [[CONCAT]]) + ; CHECK: kind=kCustom, calls=%address-computation, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: } + ; CHECK: } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + expected, [](HloModule* module) { + EXPECT_TRUE(module->has_schedule()); + TF_CHECK_OK(module->schedule().Verify()); + }); +} + +TEST_F(AddressComputationFusionRewriterTest, SimpleGemmOperandsFromSameSlice) { + const char* hlo = R"( + HloModule test, is_scheduled=true + + ENTRY %main.9 { + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[0:1], [0:8], [0:8]} + %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) + %bitcast.42 = f16[8,8]{0,1} bitcast(%slice.13) + + ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + } + )"; + + const char* expected = R"( + ; CHECK: %address-computation {{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0) + ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[0:1], [0:8], [0:8]} + ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]]) + ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{0,1} bitcast([[S0]]) + ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]), + ; CHECK: custom_call_target="__cublas$gemm" + ; CHECK: } + + ; CHECK: ENTRY %main{{.*}} { + ; CHECK-DAG: [[A0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0) + ; CHECK: ROOT [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion([[A0]]) + ; CHECK: kind=kCustom, calls=%address-computation, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: } + ; CHECK: } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + expected, [](HloModule* module) { + EXPECT_TRUE(module->has_schedule()); + TF_CHECK_OK(module->schedule().Verify()); + }); +} + +static absl::Status Memcpy(const ServiceExecutableRunOptions* run_options, + ffi::BufferBase src, ffi::BufferBase dst) { + return run_options->stream()->MemcpyD2D( + &dst.data, src.data, + absl::c_accumulate(src.dimensions, 1.0, std::multiplies()) * + sizeof(float)); +} + +XLA_FFI_DEFINE_HANDLER(kMemcpy, Memcpy, + ffi::Ffi::Bind() + .Ctx() + .Arg() // src + .Arg() // dst +); +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$memcpy", PLATFORM, + kMemcpy); + +TEST_F(AddressComputationFusionRewriterTest, SimpleCustomCall) { + XlaBuilder b(TestName()); + CustomCall(&b, "__xla_test$$memcpy", + /*operands=*/ + {Slice(Broadcast(ConstantR0WithType(&b, F32, 42.0), {256}), {0}, + {128}, {1})}, + ShapeUtil::MakeShape(F32, {128}), /*opaque=*/"", + /*has_side_effect=*/false, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, + /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); + xla::HloModuleConfig hlo_config( + xla::ProgramShape(computation.proto().host_program_shape()), + /*ignore_layouts=*/false); + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_address_computation_fusion(false); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(hlo.get(), [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + })); + TF_CHECK_OK(hlo->set_schedule(std::move(schedule))); + + const char* expected = R"( + ; CHECK: %address-computation {{.*}} { + ; CHECK: [[P0:%[^ ]+]] = f32[256]{0} parameter(0) + ; CHECK: [[S0:%[^ ]+]] = f32[128]{0} slice([[P0]]), slice={[0:128]} + ; CHECK: ROOT [[CC:%[^ ]+]] = f32[128]{0} custom-call([[S0]]), + ; CHECK: custom_call_target="__xla_test$$memcpy", + ; CHECK: api_version=API_VERSION_TYPED_FFI + ; CHECK: } + + ; CHECK: ENTRY %{{.*}} { + ; CHECK: [[C0:%[^ ]+]] = f32[] constant(42) + ; CHECK: [[BC:%[^ ]+]] = f32[256]{0} broadcast([[C0]]) + ; CHECK: ROOT [[FUSION:%[^ ]+]] = f32[128]{0} fusion([[BC]]) + ; CHECK: kind=kCustom, calls=%address-computation, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: } + ; CHECK: } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo->ToString(), + AddressComputationFusionRewriter(PLATFORM), + expected, [](HloModule* module) { + EXPECT_TRUE(module->has_schedule()); + TF_CHECK_OK(module->schedule().Verify()); + }); +} + +void Callback_Void(se::gpu::GpuStreamHandle stream, void** buffers, + const char* /*opaque*/, size_t /*opaque_len*/) {} + +XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Void, PLATFORM); + +TEST_F(AddressComputationFusionRewriterTest, SimpleCustomCallLegacy) { + XlaBuilder b(TestName()); + CustomCall(&b, "Callback_Void", + /*operands=*/ + {Slice(Broadcast(ConstantR0WithType(&b, F32, 42.0), {256}), {0}, + {128}, {1})}, + ShapeUtil::MakeShape(F32, {128}), /*opaque=*/""); + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); + xla::HloModuleConfig hlo_config( + xla::ProgramShape(computation.proto().host_program_shape()), + /*ignore_layouts=*/false); + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_address_computation_fusion(false); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(hlo.get(), [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + })); + TF_CHECK_OK(hlo->set_schedule(std::move(schedule))); + + const char* expected = R"( + ; CHECK: %address-computation {{.*}} { + ; CHECK: [[P0:%[^ ]+]] = f32[256]{0} parameter(0) + ; CHECK: [[S0:%[^ ]+]] = f32[128]{0} slice([[P0]]), slice={[0:128]} + ; CHECK: ROOT [[CC:%[^ ]+]] = f32[128]{0} custom-call([[S0]]), + ; CHECK: custom_call_target="Callback_Void" + ; CHECK: } + + ; CHECK: ENTRY %{{.*}} { + ; CHECK: [[C0:%[^ ]+]] = f32[] constant(42) + ; CHECK: [[BC:%[^ ]+]] = f32[256]{0} broadcast([[C0]]) + ; CHECK: ROOT [[FUSION:%[^ ]+]] = f32[128]{0} fusion([[BC]]) + ; CHECK: kind=kCustom, calls=%address-computation, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: } + ; CHECK: } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo->ToString(), + AddressComputationFusionRewriter(PLATFORM), + expected, [](HloModule* module) { + EXPECT_TRUE(module->has_schedule()); + TF_CHECK_OK(module->schedule().Verify()); + }); +} + } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/amdgpu_compiler.cc b/third_party/xla/xla/service/gpu/amdgpu_compiler.cc index 30a1fd334aa059..0a985871914cd1 100644 --- a/third_party/xla/xla/service/gpu/amdgpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/amdgpu_compiler.cc @@ -76,7 +76,8 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( // tf2xla bridge, DepthwiseConvolutionConverter and GpuConvRewriter // introduces reshapes and transposes that can be eliminated using // AlgebraicSimplifier We run algsimp to a fixed point. - AlgebraicSimplifierOptions options; + AlgebraicSimplifierOptions options = + GetAlgebraicSimplifierOptions(hlo_module->config()); options.set_enable_conv_operand_swap(false); options.set_enable_unconditional_reduce_of_concat_replacement(false); pipeline.AddPass>(options); diff --git a/third_party/xla/xla/service/gpu/autotuner_util.cc b/third_party/xla/xla/service/gpu/autotuner_util.cc index 18b56b6a38a9f1..ebf9e999836e0b 100644 --- a/third_party/xla/xla/service/gpu/autotuner_util.cc +++ b/third_party/xla/xla/service/gpu/autotuner_util.cc @@ -65,17 +65,8 @@ static absl::Mutex autotune_cache_mu(absl::kConstInit); static auto& autotune_cache ABSL_GUARDED_BY(autotune_cache_mu) = *new AutotuneCacheMap(); -/*static*/ absl::Status AutotunerUtil::SerializeAutotuneResults( - AutotuneResults* results) { - absl::MutexLock lock(&autotune_cache_mu); - for (const auto& [k, result] : autotune_cache) { - auto& entry = *results->add_results(); - entry.set_device(std::string(k.GetModelStr())); - entry.set_hlo(std::string(k.GetHlo())); - *entry.mutable_result() = result; - } - - // Sort the results so that they're deterministic. +// Sort the results so that they're deterministic. +static void SortAutotuneResults(AutotuneResults* results) { std::sort(results->mutable_results()->pointer_begin(), results->mutable_results()->pointer_end(), [](const auto* a, const auto* b) { @@ -84,6 +75,40 @@ static auto& autotune_cache ABSL_GUARDED_BY(autotune_cache_mu) = std::make_pair(absl::string_view(b->device()), absl::string_view(b->hlo())); }); +} + +// Serialize `results` to string as a proto. +static absl::StatusOr AutotuneResultsToString( + const AutotuneResults& results, bool as_textproto) { + if (as_textproto) { + std::string textproto; + if (tsl::protobuf::TextFormat::PrintToString(results, &textproto)) { + return textproto; + } else { + return Internal("Failed to serialize autotune results."); + } + } + return results.SerializeAsString(); +} + +// Serialize a single entry to `results`. +static void SerializeAutotuneEntry(AutotuneResults* results, + const AutotuneCacheKey& k, + const AutotuneResult* res) { + auto& entry = *results->add_results(); + entry.set_device(std::string(k.GetModelStr())); + entry.set_hlo(std::string(k.GetHlo())); + *entry.mutable_result() = *res; +} + +/*static*/ absl::Status AutotunerUtil::SerializeAutotuneResults( + AutotuneResults* results) { + absl::MutexLock lock(&autotune_cache_mu); + for (const auto& [k, result] : autotune_cache) { + SerializeAutotuneEntry(results, k, &result); + } + + SortAutotuneResults(results); return absl::OkStatus(); } @@ -183,13 +208,14 @@ namespace { // Bump this version whenever you change the structure of the results. // LINT.IfChange(version) -constexpr int kVersion = 2; +constexpr int kVersion = 3; // LINT.ThenChange() bool IsTextProtoPath(absl::string_view file_path) { return absl::EndsWith(file_path, ".txt") || absl::EndsWith(file_path, ".textproto") || - absl::EndsWith(file_path, ".prototxt"); + absl::EndsWith(file_path, ".prototxt") || + absl::EndsWith(file_path, ".pbtxt"); } } // anonymous namespace @@ -221,15 +247,7 @@ bool IsTextProtoPath(absl::string_view file_path) { AutotuneResults results; results.set_version(kVersion); TF_RETURN_IF_ERROR(SerializeAutotuneResults(&results)); - if (as_textproto) { - std::string textproto; - if (tsl::protobuf::TextFormat::PrintToString(results, &textproto)) { - return textproto; - } else { - return Internal("Failed to serialize autotune results."); - } - } - return results.SerializeAsString(); + return AutotuneResultsToString(results, as_textproto); } /*static*/ absl::Status AutotunerUtil::SerializeAutotuneResultsToFile( @@ -275,42 +293,6 @@ bool IsTextProtoPath(absl::string_view file_path) { return absl::OkStatus(); } -/*static*/ std::unique_ptr -AutotunerUtil::ExtractInstructionIntoNewModule(const HloInstruction& hlo) { - auto new_hlo_module = std::make_unique( - "extracted", HloModuleConfig{}, - std::make_unique(hlo.GetModule()->comp_envs())); - int parameter_number = 0; - HloComputation::Builder builder("entry_computation"); - HloCloneContext clone_context(new_hlo_module.get()); - std::vector new_operands; - for (const HloInstruction* operand : hlo.operands()) { - std::unique_ptr new_parameter = - HloInstruction::CreateParameter(parameter_number, operand->shape(), - operand->name()); - ++parameter_number; - new_operands.push_back(builder.AddInstruction(std::move(new_parameter))); - } - std::unique_ptr new_instruction = - hlo.CloneWithNewOperands(hlo.shape(), new_operands, &clone_context); - builder.AddInstruction(std::move(new_instruction)); - new_hlo_module->AddEntryComputationWithLayouts(builder.Build()); - return new_hlo_module; -} - -/*static*/ std::unique_ptr -AutotunerUtil::ExtractComputationIntoNewModule( - const HloComputation& computation) { - auto new_hlo_module = - std::make_unique("extracted", HloModuleConfig{}, - std::make_unique( - computation.parent()->comp_envs())); - HloCloneContext clone_context(new_hlo_module.get()); - new_hlo_module->AddEntryComputationWithLayouts( - computation.CloneInContext(clone_context)); - return new_hlo_module; -} - /*static*/ absl::StatusOr AutotunerUtil::CreateRedzoneAllocator(const AutotuneConfig& config, const DebugOptions& opts, @@ -327,5 +309,24 @@ AutotunerUtil::CreateRedzoneAllocator(const AutotuneConfig& config, : 0); } +/*static*/ absl::StatusOr +AutotunerUtil::SerializeAutotuneResultsForModule( + const HloModule& module, const AutotuneConfig& autotune_config, + bool as_textproto) { + AutotuneResults results; + results.set_version(kVersion); + + for (const HloInstruction* instr : + module.entry_computation()->instructions()) { + AutotuneCacheKey k(autotune_config.GetModelStr(), *instr); + if (const AutotuneResult* res = TryFindInCache(k)) { + SerializeAutotuneEntry(&results, k, res); + } + } + + SortAutotuneResults(&results); + return AutotuneResultsToString(results, as_textproto); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/autotuner_util.h b/third_party/xla/xla/service/gpu/autotuner_util.h index 54f0d9d99a50cc..d78891def8f214 100644 --- a/third_party/xla/xla/service/gpu/autotuner_util.h +++ b/third_party/xla/xla/service/gpu/autotuner_util.h @@ -232,6 +232,16 @@ struct AutotunerUtil { static absl::StatusOr SerializeAutotuneResults( bool as_textproto = false); + // As above, but only performs serialization for instructions found in the + // module. + // + // Only serializes autotuning results for instructions found in the module: + // while this is more expensive than serializing all cache, this avoids + // quadratic blow-up when serializing cache for a large number of modules. + static absl::StatusOr SerializeAutotuneResultsForModule( + const HloModule& module, const AutotuneConfig& autotune_config, + bool as_textproto = false); + static absl::Status SerializeAutotuneResults(AutotuneResults* results); static absl::Status LoadAutotuneResults(absl::string_view data, bool as_textproto = false); @@ -253,16 +263,6 @@ struct AutotunerUtil { static absl::Status LoadAutotuneResultsFromFile(absl::string_view file_path); static void ClearAutotuneResults(); - - // Extracts an HLO instruction into a new HLO module replacing its operands - // with parameter instructions. - static std::unique_ptr ExtractInstructionIntoNewModule( - const HloInstruction& hlo); - - // Extracts an HLO computation into a new HLO module, using its clone as the - // root computation. - static std::unique_ptr ExtractComputationIntoNewModule( - const HloComputation& computation); }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/backend_configs.proto b/third_party/xla/xla/service/gpu/backend_configs.proto index 65cc7221fd0f6e..5f837d239d25ed 100644 --- a/third_party/xla/xla/service/gpu/backend_configs.proto +++ b/third_party/xla/xla/service/gpu/backend_configs.proto @@ -125,8 +125,22 @@ message CollectiveBackendConfig { bool no_parallel_custom_call = 2; } +// Backend config for cost model estimates. message ReificationCost { - double end_to_end_cycles = 1; // Total execution time of the reified op. + // Total execution time of the reified op. + double end_to_end_cycles = 1; + + // Estimated overall kernel execution in microseconds. + // + // GPU Cost Model estimates compute and memory access time separately. Exec + // time is a combined metric of the two. + double exec_time_us = 2; + + // Estimate for compute time in microseconds. + double compute_time_us = 3; + + // Estimate for memory access (read+write) time in microseconds. + double memory_access_time_us = 4; } // Backend config for a custom fusion (pre-compiled device kernel implementing a @@ -165,6 +179,14 @@ message CudnnNormBackendConfig { // Opaque algorithm number. stream_executor.dnn.AlgorithmProto algorithm = 2; + + // Norm type. + enum Kind { + LAYER_FWD_INFER = 0; + LAYER_FWD_TRAIN = 1; + LAYER_BWD = 2; + } + Kind kind = 3; } // Backend config for a fused Multi-Headed Attention (fMHA) that runs through diff --git a/third_party/xla/xla/service/gpu/buffer_comparator.cc b/third_party/xla/xla/service/gpu/buffer_comparator.cc index 35c763a38c2f8e..d73c692d3c8f3a 100644 --- a/third_party/xla/xla/service/gpu/buffer_comparator.cc +++ b/third_party/xla/xla/service/gpu/buffer_comparator.cc @@ -64,7 +64,7 @@ static absl::StatusOr DeviceCompare(se::Stream* stream, se::ScopedDeviceMemory out_param = executor->AllocateOwnedScalar(); - stream->ThenMemZero(out_param.ptr(), sizeof(uint64_t)); + TF_RETURN_IF_ERROR(stream->MemZero(out_param.ptr(), sizeof(uint64_t))); if (current.size() != expected.size()) { return Internal("Mismatched buffer size: %d bytes vs. %d bytes", current.size(), expected.size()); @@ -75,11 +75,12 @@ static absl::StatusOr DeviceCompare(se::Stream* stream, uint64_t buffer_size = current_typed.ElementCount(); TF_ASSIGN_OR_RETURN( - std::unique_ptr> comparison_kernel, - (executor->CreateTypedKernel, - se::DeviceMemory, float, uint64_t, - se::DeviceMemory>(kernel_name, - kernel_symbol))); + ComparisonKernelT comparison_kernel, + (se::TypedKernel, se::DeviceMemory, + float, uint64_t, + se::DeviceMemory>::Create(executor, + kernel_name, + kernel_symbol))); const se::DeviceDescription& gpu_device_info = executor->GetDeviceDescription(); @@ -88,13 +89,13 @@ static absl::StatusOr DeviceCompare(se::Stream* stream, CalculateLaunchDimensions(buffer_shape, gpu_device_info); TF_RETURN_IF_ERROR(stream->ThenLaunch( - dim.thread_counts_per_block(), dim.block_counts(), *comparison_kernel, + dim.thread_counts_per_block(), dim.block_counts(), comparison_kernel, current_typed, expected_typed, static_cast(kTolerance), buffer_size, out_param.cref())); uint64_t result = -1; CHECK_EQ(out_param->size(), sizeof(result)); - stream->ThenMemcpy(&result, *out_param, sizeof(result)); + TF_RETURN_IF_ERROR(stream->Memcpy(&result, *out_param, sizeof(result))); TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); return result == 0; } @@ -109,8 +110,10 @@ absl::StatusOr HostCompare(se::Stream* stream, se::DeviceMemoryBase expected) { int64_t n = current.size() / sizeof(ElementType); std::vector host_current(n), host_expected(n); - stream->ThenMemcpy(host_current.data(), current, current.size()); - stream->ThenMemcpy(host_expected.data(), expected, expected.size()); + TF_RETURN_IF_ERROR( + stream->Memcpy(host_current.data(), current, current.size())); + TF_RETURN_IF_ERROR( + stream->Memcpy(host_expected.data(), expected, expected.size())); TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); const auto canonicalize = [](ComparisonType a) -> ComparisonType { diff --git a/third_party/xla/xla/service/gpu/buffer_comparator_test.cc b/third_party/xla/xla/service/gpu/buffer_comparator_test.cc index 2f590dde2a6db6..291f7206f362b2 100644 --- a/third_party/xla/xla/service/gpu/buffer_comparator_test.cc +++ b/third_party/xla/xla/service/gpu/buffer_comparator_test.cc @@ -26,7 +26,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" -#include "xla/stream_executor/multi_platform_manager.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" #include "xla/types.h" #include "tsl/platform/ml_dtypes.h" @@ -41,9 +41,9 @@ class BufferComparatorTest : public testing::Test { protected: BufferComparatorTest() #if GOOGLE_CUDA - : platform_(se::MultiPlatformManager::PlatformWithName("CUDA").value()), + : platform_(se::PlatformManager::PlatformWithName("CUDA").value()), #elif TENSORFLOW_USE_ROCM - : platform_(se::MultiPlatformManager::PlatformWithName("ROCM").value()), + : platform_(se::PlatformManager::PlatformWithName("ROCM").value()), #endif stream_exec_(platform_->ExecutorForDevice(0).value()) { } @@ -53,17 +53,17 @@ class BufferComparatorTest : public testing::Test { bool CompareEqualBuffers(const std::vector& current, const std::vector& expected) { se::Stream stream(stream_exec_); - stream.Init(); + TF_CHECK_OK(stream.Initialize()); se::ScopedDeviceMemory current_buffer = stream_exec_->AllocateOwnedArray(current.size()); se::ScopedDeviceMemory expected_buffer = stream_exec_->AllocateOwnedArray(expected.size()); - stream.ThenMemcpy(current_buffer.ptr(), current.data(), - current_buffer->size()); - stream.ThenMemcpy(expected_buffer.ptr(), expected.data(), - expected_buffer->size()); + TF_CHECK_OK(stream.Memcpy(current_buffer.ptr(), current.data(), + current_buffer->size())); + TF_CHECK_OK(stream.Memcpy(expected_buffer.ptr(), expected.data(), + expected_buffer->size())); TF_CHECK_OK(stream.BlockHostUntilDone()); BufferComparator comparator( @@ -346,7 +346,7 @@ TEST_F(BufferComparatorTest, BF16) { int64_t rng_state = 0; se::Stream stream(stream_exec_); - stream.Init(); + TF_CHECK_OK(stream.Initialize()); se::ScopedDeviceMemory lhs = stream_exec_->AllocateOwnedArray(element_count); diff --git a/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc b/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc index 03a79c358cb840..2daec3fb2e01ad 100644 --- a/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc +++ b/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -39,6 +40,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_schedule.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/variant_visitor.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -98,7 +101,7 @@ static bool IsCommand(const HloInstruction*, const CommandBufferConfig&); template <> bool IsCommand(const HloInstruction* hlo, const CommandBufferConfig& config) { - return config.contains(DebugOptions::CONDITIONALS) && + return config.enabled_commands.contains(DebugOptions::CONDITIONALS) && IsCommand(hlo->while_body(), config) && IsCommand(hlo->while_condition(), config); } @@ -108,7 +111,7 @@ bool IsCommand(const HloInstruction* hlo, template <> bool IsCommand(const HloInstruction* hlo, const CommandBufferConfig& config) { - return config.contains(DebugOptions::CONDITIONALS) && + return config.enabled_commands.contains(DebugOptions::CONDITIONALS) && absl::c_all_of(hlo->branch_computations(), [&](const HloComputation* comp) { return IsCommand(comp, config); @@ -117,31 +120,45 @@ bool IsCommand(const HloInstruction* hlo, static bool IsCommand(const HloCustomCallInstruction* hlo, const CommandBufferConfig& config) { - if (config.contains(DebugOptions::CUBLAS) && IsLegacyCublasMatmul(*hlo)) { + if (config.enabled_commands.contains(DebugOptions::CUBLAS) && + IsLegacyCublasMatmul(*hlo)) { return true; } - if (config.contains(DebugOptions::CUSTOM_CALL)) { - if (hlo->custom_call_target() == "triton_kernel_call" || - hlo->custom_call_target() == "cu_threefry2x32") { - return true; - } - } + if (config.enabled_commands.contains(DebugOptions::CUSTOM_CALL) && + hlo->custom_call_target() == "triton_kernel_call") + return true; return false; } static bool IsCommand(const HloInstruction* hlo, const CommandBufferConfig& config) { - if (auto* fusion = DynCast(hlo)) - return config.contains(DebugOptions::FUSION); + if (auto* fusion = DynCast(hlo)) { + auto gpu_config = fusion->backend_config(); + const FusionBackendConfig& backend_config = + gpu_config->fusion_backend_config(); + const auto& custom_config = backend_config.custom_fusion_config(); + if (custom_config.name() == "address_computation") { + auto fusion_analysis = + HloFusionAnalysis::Create(fusion, &config.device_description); + const HloFusionAdaptor& adaptor = fusion_analysis.fusion(); + auto custom_call_adaptor = HloFindIf( + adaptor.GetRoots(), adaptor, + [](auto node) { return node.opcode() == HloOpcode::kCustomCall; }); + const auto* custom_call = static_cast( + &custom_call_adaptor->instruction()); + return IsCommand(custom_call, config); + } + return config.enabled_commands.contains(DebugOptions::FUSION); + } if (auto* sort = DynCast(hlo)) - return config.contains(DebugOptions::FUSION); + return config.enabled_commands.contains(DebugOptions::FUSION); if (hlo->opcode() == HloOpcode::kPartitionId || hlo->opcode() == HloOpcode::kReplicaId) { - return config.contains(DebugOptions::FUSION); + return config.enabled_commands.contains(DebugOptions::FUSION); } if (auto* custom_call = DynCast(hlo)) @@ -171,12 +188,12 @@ static bool IsAsyncStartCommand(const HloInstruction* hlo, const CommandBufferConfig& config) { if (hlo->opcode() == HloOpcode::kAllReduceStart || hlo->opcode() == HloOpcode::kAllGatherStart) { - return config.contains(DebugOptions::COLLECTIVES); + return config.enabled_commands.contains(DebugOptions::COLLECTIVES); } if (hlo->opcode() == HloOpcode::kAsyncStart) { if (hlo->async_wrapped_opcode() == HloOpcode::kReduceScatter) { - return config.contains(DebugOptions::COLLECTIVES); + return config.enabled_commands.contains(DebugOptions::COLLECTIVES); } } @@ -187,12 +204,12 @@ static bool IsAsyncDoneCommand(const HloInstruction* hlo, const CommandBufferConfig& config) { if (hlo->opcode() == HloOpcode::kAllReduceDone || hlo->opcode() == HloOpcode::kAllGatherDone) { - return config.contains(DebugOptions::COLLECTIVES); + return config.enabled_commands.contains(DebugOptions::COLLECTIVES); } if (hlo->opcode() == HloOpcode::kAsyncDone) { if (hlo->async_wrapped_opcode() == HloOpcode::kReduceScatter) { - return config.contains(DebugOptions::COLLECTIVES); + return config.enabled_commands.contains(DebugOptions::COLLECTIVES); } } @@ -571,9 +588,9 @@ absl::StatusOr CommandBufferScheduling::RewriteCommandBuffer( //===----------------------------------------------------------------------===// CommandBufferScheduling::CommandBufferScheduling( - const se::GpuComputeCapability& gpu_compute_comp, + const se::DeviceDescription& device_description, int32_t gpu_toolkit_version, int32_t gpu_driver_version) - : gpu_compute_comp_(gpu_compute_comp), + : device_description_(device_description), gpu_toolkit_version_(gpu_toolkit_version), gpu_driver_version_(gpu_driver_version) {} @@ -589,10 +606,11 @@ absl::StatusOr CommandBufferScheduling::Run( const DebugOptions& debug_options = module->config().debug_options(); - CommandBufferConfig config; + absl::flat_hash_set commands; for (auto cmd_type : debug_options.xla_gpu_enable_command_buffer()) { - config.insert(static_cast(cmd_type)); + commands.insert(static_cast(cmd_type)); } + CommandBufferConfig config{std::move(commands), device_description_}; // Erase command buffer cmd types that are not supported by the gpu runtime. static constexpr auto kRequireConditionals = {DebugOptions::CONDITIONALS}; @@ -601,7 +619,7 @@ absl::StatusOr CommandBufferScheduling::Run( auto erase = [&](absl::Span cmds) { for (auto cmd : cmds) { - if (config.erase(cmd)) { + if (config.enabled_commands.erase(cmd)) { VLOG(1) << "Removed command buffer support for " << DebugOptions::CommandBufferCmdType_Name(cmd) << " as it's not supported with gpu toolkit version " @@ -627,7 +645,8 @@ absl::StatusOr CommandBufferScheduling::Run( return true; // check for ROCM support }; - if (std::visit(VariantVisitor{check_cuda, check_rocm}, gpu_compute_comp_)) { + if (std::visit(VariantVisitor{check_cuda, check_rocm}, + device_description_.gpu_compute_capability())) { erase(kRequireTracing); // cuStreamBeginCaptureToGraph erase(kRequireConditionals); // on-device control flow } diff --git a/third_party/xla/xla/service/gpu/command_buffer_scheduling.h b/third_party/xla/xla/service/gpu/command_buffer_scheduling.h index ff114ee1b3d72a..79855f307d6003 100644 --- a/third_party/xla/xla/service/gpu/command_buffer_scheduling.h +++ b/third_party/xla/xla/service/gpu/command_buffer_scheduling.h @@ -24,12 +24,11 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/service/hlo_pass_interface.h" #include "xla/status.h" -#include "xla/statusor.h" +#include "xla/stream_executor/device_description.h" namespace xla::gpu { @@ -70,12 +69,14 @@ namespace xla::gpu { // custom call to a first class operation later. class CommandBufferScheduling : public HloModulePass { public: - // DebugOptions control which commands are enabled. Long term we want to - // remove that flag and enable all supported commands by default. - using CommandBufferConfig = - absl::flat_hash_set; + struct CommandBufferConfig { + // DebugOptions control which commands are enabled. Long term we want to + // remove that flag and enable all supported commands by default. + absl::flat_hash_set enabled_commands; + const se::DeviceDescription& device_description; + }; - CommandBufferScheduling(const se::GpuComputeCapability& gpu_compute_comp, + CommandBufferScheduling(const se::DeviceDescription& device_description, int32_t gpu_toolkit_version, int32_t gpu_driver_version); @@ -127,7 +128,7 @@ class CommandBufferScheduling : public HloModulePass { CommandBuffer command_buffer); private: - se::GpuComputeCapability gpu_compute_comp_; + se::DeviceDescription device_description_; // For NVIDIA gpus XLA can be compiled with a CUDA version that is larger than // the version supported by the driver, e.g. we can compile for CUDA 12.3 but // have 12.1 driver installed. When deciding what command buffer features we diff --git a/third_party/xla/xla/service/gpu/command_buffer_scheduling_test.cc b/third_party/xla/xla/service/gpu/command_buffer_scheduling_test.cc index 3463ac8001a248..9d75a60a70dda2 100644 --- a/third_party/xla/xla/service/gpu/command_buffer_scheduling_test.cc +++ b/third_party/xla/xla/service/gpu/command_buffer_scheduling_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/service/hlo_parser.h" +#include "xla/stream_executor/device_description.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" @@ -39,11 +40,8 @@ class CommandBufferSchedulingTest : public HloTestBase { // Use CUDA 12.3 version for testing as it has all the features we rely on. static constexpr int32_t kCudaVersion = 12030; - const auto& gpu_comp() { - return backend() - .default_stream_executor() - ->GetDeviceDescription() - .gpu_compute_capability(); + const se::DeviceDescription& device_desc() { + return backend().default_stream_executor()->GetDeviceDescription(); } DebugOptions GetDebugOptionsForTest() override { @@ -101,7 +99,7 @@ TEST_F(CommandBufferSchedulingTest, SingleCommandBuffer) { // CHECK: })"; RunAndFilecheckHloRewrite( - hlo, CommandBufferScheduling(gpu_comp(), kCudaVersion, kCudaVersion), + hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion), expected, [](HloModule* module) { EXPECT_TRUE(module->has_schedule()); TF_CHECK_OK(module->schedule().Verify()); @@ -179,7 +177,7 @@ TEST_F(CommandBufferSchedulingTest, MultipleCommandBuffers) { // CHECK: })"; RunAndFilecheckHloRewrite( - hlo, CommandBufferScheduling(gpu_comp(), kCudaVersion, kCudaVersion), + hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion), expected, [](HloModule* module) { EXPECT_TRUE(module->has_schedule()); TF_CHECK_OK(module->schedule().Verify()); @@ -218,7 +216,7 @@ TEST_F(CommandBufferSchedulingTest, AllReduceStartFollowedByDone) { CHECK: })"; RunAndFilecheckHloRewrite( - hlo, CommandBufferScheduling(gpu_comp(), kCudaVersion, kCudaVersion), + hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion), expected, [](HloModule* module) { EXPECT_TRUE(module->has_schedule()); TF_CHECK_OK(module->schedule().Verify()); @@ -253,7 +251,7 @@ TEST_F(CommandBufferSchedulingTest, AllGatherStartFollowedByDone) { CHECK: })"; RunAndFilecheckHloRewrite( - hlo, CommandBufferScheduling(gpu_comp(), kCudaVersion, kCudaVersion), + hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion), expected, [](HloModule* module) { EXPECT_TRUE(module->has_schedule()); TF_CHECK_OK(module->schedule().Verify()); @@ -294,7 +292,7 @@ TEST_F(CommandBufferSchedulingTest, ReduceScatterStartFollowedByDone) { CHECK: })"; RunAndFilecheckHloRewrite( - hlo, CommandBufferScheduling(gpu_comp(), kCudaVersion, kCudaVersion), + hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion), expected, [](HloModule* module) { EXPECT_TRUE(module->has_schedule()); TF_CHECK_OK(module->schedule().Verify()); @@ -351,8 +349,8 @@ TEST_F(CommandBufferSchedulingTest, CollectCommandBufferSequence) { } EXPECT_EQ(seq.size(), 10); - CommandBufferScheduling::CommandBufferConfig config; - config.insert(DebugOptions::FUSION); + CommandBufferScheduling::CommandBufferConfig config{{DebugOptions::FUSION}, + device_desc()}; std::vector command_buffer_sequences = CommandBufferScheduling::CollectCommandBufferSequences(seq, config); @@ -541,7 +539,7 @@ TEST_F(CommandBufferSchedulingTest, ForwardControlDependencies) { CHECK: })"; RunAndFilecheckHloRewrite( - hlo, CommandBufferScheduling(gpu_comp(), kCudaVersion, kCudaVersion), + hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion), expected, [](HloModule* module) { EXPECT_TRUE(module->has_schedule()); TF_CHECK_OK(module->schedule().Verify()); @@ -581,7 +579,7 @@ TEST_F(CommandBufferSchedulingTest, ForwardControlDependenciesToParams) { CHECK: })"; RunAndFilecheckHloRewrite( - hlo, CommandBufferScheduling(gpu_comp(), kCudaVersion, kCudaVersion), + hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion), expected, [](HloModule* module) { EXPECT_TRUE(module->has_schedule()); TF_CHECK_OK(module->schedule().Verify()); @@ -660,7 +658,7 @@ TEST_F(CommandBufferSchedulingTest, WhileNotCommand) { CHECK: })"; RunAndFilecheckHloRewrite( - hlo, CommandBufferScheduling(gpu_comp(), kCudaVersion, kCudaVersion), + hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion), expected, [](HloModule* module) { EXPECT_TRUE(module->has_schedule()); TF_CHECK_OK(module->schedule().Verify()); @@ -722,7 +720,7 @@ TEST_F(CommandBufferSchedulingTest, While) { CHECK: })"; RunAndFilecheckHloRewrite( - hlo, CommandBufferScheduling(gpu_comp(), kCudaVersion, kCudaVersion), + hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion), expected, [](HloModule* module) { EXPECT_TRUE(module->has_schedule()); TF_CHECK_OK(module->schedule().Verify()); @@ -798,7 +796,7 @@ TEST_F(CommandBufferSchedulingTest, Conditional) { CHECK: })"; RunAndFilecheckHloRewrite( - hlo, CommandBufferScheduling(gpu_comp(), kCudaVersion, kCudaVersion), + hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion), expected, [](HloModule* module) { EXPECT_TRUE(module->has_schedule()); TF_CHECK_OK(module->schedule().Verify()); diff --git a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc index 80f74bb9fce7fb..e3ba8a1688ec77 100644 --- a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc +++ b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc @@ -57,7 +57,6 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/mlir/backends/gpu/transforms/passes.h" #include "xla/mlir/runtime/transforms/compilation_pipeline_gpu.h" #include "xla/mlir_hlo/transforms/gpu_passes.h" #include "xla/service/buffer_assignment.h" @@ -69,10 +68,9 @@ limitations under the License. #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/ir_emitter_unnested.h" #include "xla/service/gpu/metrics.h" -#include "xla/service/gpu/runtime/executable.h" -#include "xla/service/gpu/runtime3/conditional_thunk.h" -#include "xla/service/gpu/runtime3/sequential_thunk.h" -#include "xla/service/gpu/runtime3/while_thunk.h" +#include "xla/service/gpu/runtime/conditional_thunk.h" +#include "xla/service/gpu/runtime/sequential_thunk.h" +#include "xla/service/gpu/runtime/while_thunk.h" #include "xla/service/gpu/thunk.h" #include "xla/service/hlo_dataflow_analysis.h" #include "xla/service/hlo_ordering.h" @@ -88,7 +86,6 @@ limitations under the License. #include "xla/stream_executor/stream_executor.h" #include "xla/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/translate/mhlo_to_hlo/location_exporter.h" -#include "xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/casts.h" @@ -147,53 +144,6 @@ class DumpAfterPassIfEnabled : public mlir::PassInstrumentation { int pass_counter_ = 0; }; -// Lowers MLIR module to the XLA Gpu runtime custom calls. -static absl::Status LowerToXlaGpuRuntime( - mlir::ModuleOp module, llvm::StringRef entry_function_name, - llvm::ArrayRef buffer_sizes, ThunkSequence* thunk_sequence, - const HloModule* hlo_module, se::GpuComputeCapability compute_capability) { - if (!module) { - return Internal("No MLIR module to lower."); - } - - const DebugOptions& debug_options = hlo_module->config().debug_options(); - bool should_verify = debug_options.xla_gpu_llvm_verification_level() >= 1; -#ifndef NDEBUG - should_verify = true; -#endif - - mlir::PassManager pm(module->getName(), mlir::PassManager::Nesting::Implicit); - pm.enableVerifier(should_verify); - if (hlo_module != nullptr && DumpingEnabledForHloModule(*hlo_module)) { - pm.addInstrumentation( - std::make_unique(hlo_module, &module)); - } - - absl::flat_hash_set command_types; - for (int command_type_num : debug_options.xla_gpu_enable_command_buffer()) { - if (!DebugOptions::CommandBufferCmdType_IsValid(command_type_num)) { - return Internal("Invalid command buffer command type"); - } - DebugOptions::CommandBufferCmdType command_type = - static_cast(command_type_num); - command_types.insert(command_type); - } - - GpuPipelineOpts opts; - opts.command_types = command_types; - opts.min_graph_size = debug_options.xla_gpu_graph_min_graph_size(); - opts.enable_concurrent_region = - debug_options.xla_gpu_graph_enable_concurrent_region(); - opts.compute_capability = compute_capability; - populateXlaGpuRuntimePasses(pm, thunk_sequence, opts); - - if (pm.run(module).failed()) { - return Internal("Failed to lower LMHLO to Gpu runtime custom calls."); - } - - return absl::OkStatus(); -} - } // namespace void ForAllThunks(const std::function& fn, @@ -228,36 +178,6 @@ static void ForwardCollectiveAttrs(mlir::ModuleOp module, func->setAttr("num_partitions", b.getI64IntegerAttr(config.num_partitions())); } -absl::StatusOr LowerToJitRt( - mlir::ModuleOp mlir_module, llvm::StringRef entry_function_name, - llvm::ArrayRef buffer_sizes, - std::unique_ptr thunk_sequence, const HloModule* hlo_module, - se::GpuComputeCapability compute_capability) { - const auto& module_config = hlo_module->config(); - // Forward collective (NCCL) attributes for use by the lowering pipeline. - ForwardCollectiveAttrs(mlir_module, entry_function_name, module_config); - - // Lower LMHLO operations to the XLA:GPU runtime custom calls. - TF_RETURN_IF_ERROR(LowerToXlaGpuRuntime( - mlir_module, {entry_function_name.data(), entry_function_name.size()}, - buffer_sizes, thunk_sequence.get(), hlo_module, compute_capability)); - - // TODO(b/232033540): Pass MLIR module directly to Gpu runtime executable - // without forcing serialization. - std::string module_str = llvm_ir::DumpToString(mlir_module); - - if (hlo_module != nullptr) { - DumpToFileInDirOrStdout(*hlo_module, "gpu_rt_host", "mlir", module_str); - } - - // Collect allocation indices for handling graph capture functions. - auto allocation_indices = GetAllocationIndices(mlir_module); - - return std::make_unique( - entry_function_name.str(), std::move(module_str), buffer_sizes.vec(), - std::move(allocation_indices), module_config.debug_options()); -} - // Analyze the function signature to reconstruct a vector of BufferAllocation // objects, as well as other output information. // @@ -349,70 +269,27 @@ absl::StatusOr CompileModuleToLlvmIr( << ": " << hlo_module->GetFingerprint128(); uint64_t start_usecs = tsl::Env::Default()->NowMicros(); + mlir::DialectRegistry registry; IrEmitterUnnested::GetDependentDialects(registry); - // Disable MLIR multi-threading to prevent creating too many threads when // compiling XLA executables concurrently (e.g. during auto-tuning). auto mlir_context = std::make_unique( registry, mlir::MLIRContext::Threading::DISABLED); - mlir_context->getDiagEngine().registerHandler(DiagnosticHandler); - mlir::OwningOpRef mlir_module = llvm_ir::CreateMlirModuleOp( - mlir::Builder(mlir_context.get()).getUnknownLoc(), hlo_module->name()); - absl::flat_hash_map - operation_map; - - // Store the allocations in the order of the LMHLO buffer arguments. - std::vector ordered_allocations; - TF_RETURN_IF_ERROR(HloToLhloModule(*results.buffer_assignment, *hlo_module, - *mlir_module, &ordered_allocations, - &operation_map)); - - results.module_name = - mlir::mhlo::GetDebugNameFromLocation(mlir_module->getLoc()); - - if (DumpingEnabledForHloModule(*hlo_module)) { - DumpToFileInDirOrStdout(*hlo_module, "lmhlo", mlir_module.get()); - } - - auto entry_function = mlir::cast( - mlir_module->lookupSymbol(hlo_module->entry_computation()->name())); - - bool emit_from_hlo = !IsXlaRuntimeExecutableEnabled(hlo_module->config()); - - std::vector mlir_allocations; - absl::flat_hash_map mlir_output_info; - Shape mlir_output_shape; - TF_RETURN_IF_ERROR(GetMlirAllocationInfo(entry_function, &mlir_allocations, - &mlir_output_info, - &mlir_output_shape)); + results.module_name = hlo_module->name(); IrEmitterContext ir_emitter_context( hlo_module, results.buffer_assignment.get(), platform_name, gpu_device_info, mlir_context.get(), results.llvm_module.get(), - emit_from_hlo, /*emit_kernels=*/true); + /*emit_kernels=*/true); std::vector allocations; - if (emit_from_hlo) { - results.output_shape = hlo_module->result_shape(); - TF_ASSIGN_OR_RETURN(results.output_info, - GetOutputInfo(*hlo_module, *results.buffer_assignment)); - TF_RET_CHECK(mlir_allocations.size() == ordered_allocations.size()); - ir_emitter_context.set_allocations(ordered_allocations); - results.use_original_allocations = true; - } else { - results.allocations = std::move(mlir_allocations); - results.output_shape = mlir_output_shape; - results.output_info = mlir_output_info; - allocations.reserve(results.allocations.size()); - for (auto& allocation : results.allocations) { - allocations.push_back(&allocation); - } - ir_emitter_context.set_allocations(allocations); - results.use_original_allocations = false; - } + results.output_shape = hlo_module->result_shape(); + TF_ASSIGN_OR_RETURN(results.output_info, + GetOutputInfo(*hlo_module, *results.buffer_assignment)); + results.use_original_allocations = true; auto ir_emitter = IrEmitterUnnested::Create(&ir_emitter_context); @@ -421,7 +298,7 @@ absl::StatusOr CompileModuleToLlvmIr( "GpuCompiler::RunBackend - IR emission for ", hlo_module->name())); TF_RETURN_IF_ERROR( - ir_emitter->EmitLmhloRegion(&entry_function.getBody(), operation_map)); + ir_emitter->EmitHloComputation(hlo_module->entry_computation())); bool supports_runtime_managed_constants = // TODO(b/218907125): Implement this feature for ROCm as well. @@ -442,27 +319,11 @@ absl::StatusOr CompileModuleToLlvmIr( RecordHloToLlvmDuration(end_usecs - start_usecs); } - // TODO(ezhulenev): Remove the FP8 check once https://reviews.llvm.org/D140088 - // is submitted. Currently we can't emit LLVM IR with fp8 types. - if (IsXlaRuntimeExecutableEnabled(hlo_module->config()) && - !HasFp8(*hlo_module)) { - // Sizes of all buffers required for running XLA module. - std::vector buffer_sizes; - llvm::transform( - results.allocations, std::back_inserter(buffer_sizes), - [](const BufferAllocation& allocation) { return allocation.size(); }); - - TF_ASSIGN_OR_RETURN( - results.executable, - LowerToJitRt(*mlir_module, entry_function.getName(), buffer_sizes, - ir_emitter->ConsumeThunkSequence(), hlo_module, - gpu_device_info.gpu_compute_capability())); - } else { - auto thunk_sequence = ir_emitter->ConsumeThunkSequence(); - ForAllThunks([](Thunk* thunk) { thunk->ClearCompileTimeInfo(); }, - thunk_sequence.get()); - results.executable = std::move(thunk_sequence); - } + auto thunk_sequence = ir_emitter->ConsumeThunkSequence(); + ForAllThunks([](Thunk* thunk) { thunk->ClearCompileTimeInfo(); }, + thunk_sequence.get()); + results.executable = std::move(thunk_sequence); + return results; } diff --git a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h index f39cba107e68de..d7f8f395849b0c 100644 --- a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h +++ b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h @@ -42,9 +42,7 @@ struct CompileModuleResults { std::unique_ptr llvm_module; std::unique_ptr buffer_assignment; std::vector allocations; - std::variant - executable; + GpuExecutable::OwnedThunkSequence executable; std::vector constants; absl::flat_hash_map output_info; Shape output_shape; diff --git a/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc b/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc index 6c8cdc8a179b8e..b1bdfa9ded9460 100644 --- a/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc +++ b/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc @@ -416,7 +416,7 @@ absl::StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmNoCache( se::DeviceMemoryAllocator* allocator = config_.GetAllocator(); TF_ASSIGN_OR_RETURN(se::Stream* const stream, config_.GetStream()); - StatusOr result_or(Internal("Unknown platform.")); + absl::StatusOr result_or(Internal("Unknown platform.")); // Check StreamExecutor on which platform it is. ROCm and Cuda implementation // have diverged. Specifically, we need to make sure redzone allocator related // utilities are not used in ROCm routine @@ -766,8 +766,9 @@ absl::StatusOr GpuConvAlgorithmPicker::AutotuneOneConvRunner( reference_result_buffers[i], runtime_arguments.input_output_allocator->AllocateBytes( result_buffers[i].size())); - stream->ThenMemcpy(&reference_result_buffers[i], result_buffers[i], - result_buffers[i].size()); + TF_RETURN_IF_ERROR(stream->Memcpy(&reference_result_buffers[i], + result_buffers[i], + result_buffers[i].size())); } (*reference_result) = {alg, reference_result_buffers}; } @@ -965,7 +966,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm( // before autotuning. It's conceivable that using uninitialized memory as // the inputs might affect performance if e.g. the inputs contain // denormals, and this is easy enough. - stream->ThenMemZero(&buffer, buffer.size()); + return stream->MemZero(&buffer, buffer.size()); }; // Allocate space for the input, filter, and output of the convolution. We @@ -975,7 +976,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm( TF_ASSIGN_OR_RETURN(auto buffer, input_output_allocator.AllocateBytes( ShapeUtil::ByteSizeOf(operand->shape()))); - initialize_buffer(buffer); + TF_RETURN_IF_ERROR(initialize_buffer(buffer)); operand_buffers.push_back(buffer); } @@ -987,14 +988,14 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm( result_buffers[i], input_output_allocator.AllocateBytes( ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(i)))); - initialize_buffer(result_buffers[i]); + TF_RETURN_IF_ERROR(initialize_buffer(result_buffers[i])); } } else { TF_ASSIGN_OR_RETURN( result_buffers[0], input_output_allocator.AllocateBytes( ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(0)))); - initialize_buffer(result_buffers[0]); + TF_RETURN_IF_ERROR(initialize_buffer(result_buffers[0])); } ScratchAllocator scratch_allocator(device_ordinal, allocator); diff --git a/third_party/xla/xla/service/gpu/conv_layout_normalization_test.cc b/third_party/xla/xla/service/gpu/conv_layout_normalization_test.cc index eec7e5c52aaba9..61386297ca7c9e 100644 --- a/third_party/xla/xla/service/gpu/conv_layout_normalization_test.cc +++ b/third_party/xla/xla/service/gpu/conv_layout_normalization_test.cc @@ -89,7 +89,8 @@ ENTRY %TestComputation { )"); } -TEST_F(ConvolutionLayoutNormalizationTest, FusedConv3D) { +// TODO(rocm): No Conv3D +TEST_F(ConvolutionLayoutNormalizationTest, DISABLED_ON_GPU_ROCM(FusedConv3D)) { const char* hlo = R"( HloModule TestModule diff --git a/third_party/xla/xla/service/gpu/cub_sort_kernel.cu.cc b/third_party/xla/xla/service/gpu/cub_sort_kernel.cu.cc index 4f694c735b5509..2efd6d09bffa9b 100644 --- a/third_party/xla/xla/service/gpu/cub_sort_kernel.cu.cc +++ b/third_party/xla/xla/service/gpu/cub_sort_kernel.cu.cc @@ -18,11 +18,7 @@ limitations under the License. #include #include -#if GOOGLE_CUDA -#include "xla/service/gpu/gpu_prim_cuda.h" -#elif TENSORFLOW_USE_ROCM -#include "xla/service/gpu/gpu_prim_rocm.h" -#endif // TENSORFLOW_USE_ROCM +#include "xla/service/gpu/gpu_prim.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/cublas_cudnn.cc b/third_party/xla/xla/service/gpu/cublas_cudnn.cc index fed9fe599b63b1..dc4a39bff19f76 100644 --- a/third_party/xla/xla/service/gpu/cublas_cudnn.cc +++ b/third_party/xla/xla/service/gpu/cublas_cudnn.cc @@ -75,43 +75,43 @@ const absl::string_view kCudnnConvReorderFilterAndBiasCallTarget = const absl::string_view kCudnnNormCallTarget = "__cudnn$norm"; // fMHA forward call targets. -const absl::string_view kCudnnfMHABmmBmmCallTarget = "__cudnn$fhmaBmmBmm"; -const absl::string_view kCudnnfMHASoftmaxCallTarget = "__cudnn$fhmaSoftmax"; +const absl::string_view kCudnnfMHABmmBmmCallTarget = "__cudnn$fmhaBmmBmm"; +const absl::string_view kCudnnfMHASoftmaxCallTarget = "__cudnn$fmhaSoftmax"; const absl::string_view kCudnnfMHAScaleBiasMaskSoftmaxCallTarget = - "__cudnn$fhmaScaleBiasMaskSoftmax"; + "__cudnn$fmhaScaleBiasMaskSoftmax"; const absl::string_view kCudnnfMHAScaleBiasMaskSoftmaxDropoutCallTarget = - "__cudnn$fhmaScaleBiasMaskSoftmaxDropout"; + "__cudnn$fmhaScaleBiasMaskSoftmaxDropout"; const absl::string_view kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget = - "__cudnn$fhmaScaleBiasSoftmaxDropout"; + "__cudnn$fmhaScaleBiasSoftmaxDropout"; const absl::string_view kCudnnfMHAScaleBiasSoftmaxCallTarget = - "__cudnn$fhmaScaleBiasSoftmax"; + "__cudnn$fmhaScaleBiasSoftmax"; const absl::string_view kCudnnfMHAScaleMaskSoftmaxCallTarget = - "__cudnn$fhmaScaleMaskSoftmax"; + "__cudnn$fmhaScaleMaskSoftmax"; const absl::string_view kCudnnfMHAScaleMaskSoftmaxDropoutCallTarget = - "__cudnn$fhmaScaleMaskSoftmaxDropout"; + "__cudnn$fmhaScaleMaskSoftmaxDropout"; const absl::string_view kCudnnfMHASoftmaxDropoutCallTarget = - "__cudnn$fhmaSoftmaxDropout"; + "__cudnn$fmhaSoftmaxDropout"; // fMHA backward call targets. const absl::string_view kCudnnfMHABmmBmmBackwardCallTarget = - "__cudnn$fhmaBmmBmmBackward"; + "__cudnn$fmhaBmmBmmBackward"; const absl::string_view kCudnnfMHASoftmaxBackwardCallTarget = - "__cudnn$fhmaSoftmaxBackward"; + "__cudnn$fmhaSoftmaxBackward"; const absl::string_view kCudnnfMHAScaleBiasMaskSoftmaxBackwardCallTarget = - "__cudnn$fhmaScaleBiasMaskSoftmaxBackward"; + "__cudnn$fmhaScaleBiasMaskSoftmaxBackward"; const absl::string_view kCudnnfMHAScaleBiasMaskSoftmaxDropoutBackwardCallTarget = - "__cudnn$fhmaScaleBiasMaskSoftmaxDropoutBackward"; + "__cudnn$fmhaScaleBiasMaskSoftmaxDropoutBackward"; const absl::string_view kCudnnfMHAScaleBiasSoftmaxDropoutBackwardCallTarget = - "__cudnn$fhmaScaleBiasSoftmaxDropoutBackward"; + "__cudnn$fmhaScaleBiasSoftmaxDropoutBackward"; const absl::string_view kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget = - "__cudnn$fhmaScaleBiasSoftmaxBackward"; + "__cudnn$fmhaScaleBiasSoftmaxBackward"; const absl::string_view kCudnnfMHAScaleMaskSoftmaxBackwardCallTarget = - "__cudnn$fhmaScaleMaskSoftmaxBackward"; + "__cudnn$fmhaScaleMaskSoftmaxBackward"; const absl::string_view kCudnnfMHAScaleMaskSoftmaxDropoutBackwardCallTarget = - "__cudnn$fhmaScaleMaskSoftmaxDropoutBackward"; + "__cudnn$fmhaScaleMaskSoftmaxDropoutBackward"; const absl::string_view kCudnnfMHASoftmaxDropoutBackwardCallTarget = - "__cudnn$fhmaSoftmaxDropoutBackward"; + "__cudnn$fmhaSoftmaxDropoutBackward"; const absl::string_view kCubDeviceRadixSortTarget = "__cub$DeviceRadixSort"; diff --git a/third_party/xla/xla/service/gpu/cublas_cudnn.h b/third_party/xla/xla/service/gpu/cublas_cudnn.h index cd9e78e805f688..e3937b8bf322fe 100644 --- a/third_party/xla/xla/service/gpu/cublas_cudnn.h +++ b/third_party/xla/xla/service/gpu/cublas_cudnn.h @@ -48,6 +48,12 @@ enum class CudnnConvKind { // => output }; +enum class CudnnNormKind { + kLayerForwardInfer, + kLayerForwardTrain, + kLayerBackward, +}; + enum class CudnnfMHAKind { kBmmBmm, kScaleBiasMaskSoftmax, diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc b/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc index 9694bdc6f8a076..04fe02e376ede0 100644 --- a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc +++ b/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc @@ -309,17 +309,15 @@ bool IsComputeCapabilityAndCudnnSupported( stream_executor::CudaComputeCapability cc, stream_executor::dnn::VersionInfo cudnn_version, stream_executor::dnn::VersionInfo supported_cudnn_version) { - if (!((cc.IsAtLeast(se::CudaComputeCapability::AMPERE) && cc.minor == 0) && - (cudnn_version >= supported_cudnn_version))) { - VLOG(2) << absl::StrFormat( - "CudnnFusedMHARewriter did not run. Unsupported compute " - "capability(==8.0) or cudnn version(>=%d.%d.%d)", - supported_cudnn_version.major_version(), - supported_cudnn_version.minor_version(), - supported_cudnn_version.patch()); - return false; + if (cc.IsAtLeastAmpere() && cudnn_version >= supported_cudnn_version) { + return true; } - return true; + VLOG(2) << absl::StrFormat( + "CudnnFusedMHARewriter did not run. Unsupported compute " + "capability(%s; should be >= 8.0) or cudnn version(%s; should be >= %s)", + cc.ToString(), cudnn_version.ToString(), + supported_cudnn_version.ToString()); + return false; } bool IsSupportedPrimitiveType(const HloInstruction* bmm) { @@ -440,7 +438,7 @@ absl::StatusOr IsSupportedBMM2(const HloInstruction* bmm_2, return true; } -StatusOr IsFlashAttention( +absl::StatusOr IsFlashAttention( HloInstruction* bmm_1, bool is_causal_mask, absl::string_view custom_call_name, stream_executor::CudaComputeCapability cc, @@ -473,17 +471,11 @@ StatusOr IsFlashAttention( TF_RET_CHECK(seq_q.size() == 1); TF_RET_CHECK(seq_k.size() == 1); TF_RET_CHECK(hidden_dim.size() == 1); - auto is_fixed_topology = - (custom_call_name == kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget || - custom_call_name == kCudnnfMHAScaleBiasSoftmaxCallTarget || - custom_call_name == kCudnnfMHAScaleBiasMaskSoftmaxDropoutCallTarget || - custom_call_name == kCudnnfMHAScaleBiasMaskSoftmaxCallTarget); auto is_seqlen_supported = seq_q[0] > 512 && seq_k[0] > 512 && seq_q[0] % 64 == 0 && seq_k[0] % 64 == 0; auto is_hidden_dim_supported = hidden_dim[0] == 64 || hidden_dim[0] == 128; - auto is_flash_attention = - is_seqlen_supported && is_hidden_dim_supported && is_fixed_topology; + auto is_flash_attention = is_seqlen_supported && is_hidden_dim_supported; auto is_cross_attention = seq_q[0] != seq_k[0]; // flash attention requires cuDNN 8.9.3 to run non-fused QKV @@ -671,10 +663,12 @@ MatchFwdResult MatchBmm1UnfusedBiasSoftmaxBmm2(MatchFwdResult previous_result, OptionalConvert(first_bmm_pattern.WithOneUse()), OptionalConvert( m::Broadcast(m::Constant(&scale).WithPredicate(IsScalar)))); - if (Match(softmax_input, - OptionalConvert(OptionalBitcast(first_bmm_pattern)))) { + OptionalConvert(OptionalBitcast(m::AnyOf( + first_bmm_pattern, unfused_scaled_bmm_subpattern))))) { + // bmm1 - (scale) - softmax match_result.matched_bmm_1 = bmm_1; + match_result.matched_scale = scale; match_result.matched_custom_call_name = has_dropout ? kCudnnfMHASoftmaxDropoutCallTarget : kCudnnfMHASoftmaxCallTarget; @@ -685,6 +679,7 @@ MatchFwdResult MatchBmm1UnfusedBiasSoftmaxBmm2(MatchFwdResult previous_result, unfused_scaled_bmm_subpattern.WithOneUse(), first_bmm_pattern.WithOneUse()))), m::Op(&bias))))) { + // bmm1 - (scale) - bias - softmax match_result.matched_bmm_1 = bmm_1; match_result.matched_scale = scale; match_result.matched_bias = bias; @@ -1227,11 +1222,8 @@ absl::StatusOr IsMHABlockSupported( return false; } - if (is_training && - (custom_call_name != kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget && - custom_call_name != kCudnnfMHAScaleBiasSoftmaxCallTarget && - custom_call_name != kCudnnfMHAScaleBiasMaskSoftmaxDropoutCallTarget && - custom_call_name != kCudnnfMHAScaleBiasMaskSoftmaxCallTarget)) { + // cuDNN FMHA requires softmax for backward + if (is_training && custom_call_name == kCudnnfMHABmmBmmCallTarget) { VLOG(3) << "Unsupported fused MHA training pattern.\n"; return false; } diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc b/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc index a5920afc015661..54376f827b091b 100644 --- a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc @@ -4012,6 +4012,220 @@ ENTRY main.92 { EXPECT_EQ(config.is_causal_mask(), false); } +TEST_F(CudnnFusedMhaRewriterTestHloTest, + FlashAttentionBF16TrainingBmm1SoftmaxBmm2Pattern) { + const char* module_str = R"( +HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,128,2048]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0})->(bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true} +region_0.32 { + Arg_0.33 = bf16[] parameter(0) + Arg_1.34 = bf16[] parameter(1) + ROOT maximum = bf16[] maximum(Arg_0.33, Arg_1.34) +} +region_1.44 { + Arg_0.45 = f32[] parameter(0) + Arg_1.46 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0.45, Arg_1.46) +} +region_2.66 { + Arg_0.67 = bf16[] parameter(0) + Arg_1.68 = bf16[] parameter(1) + ROOT add.1 = bf16[] add(Arg_0.67, Arg_1.68) +} +ENTRY main.92 { + Arg_0.1 = bf16[2,6,2048,128]{3,2,1,0} parameter(0), sharding={replicated} + Arg_1.2 = bf16[2,6,128,2048]{3,2,1,0} parameter(1), sharding={replicated} + dot.14 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + constant.17 = bf16[] constant(2) + broadcast.29 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(constant.17), dimensions={} + multiply.2 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.14, broadcast.29) + constant.10 = bf16[] constant(-inf) + constant.16 = bf16[] constant(0) + reduce.36 = bf16[2,6,2048]{2,1,0} reduce(multiply.2, constant.10), dimensions={3}, to_apply=region_0.32 + broadcast.21 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(reduce.36), dimensions={0,1,2} + subtract.1 = bf16[2,6,2048,2048]{3,2,1,0} subtract(multiply.2, broadcast.21) + exponential.1 = bf16[2,6,2048,2048]{3,2,1,0} exponential(subtract.1) + convert.5 = f32[2,6,2048,2048]{3,2,1,0} convert(exponential.1) + constant.14 = f32[] constant(0) + reduce.48 = f32[2,6,2048]{2,1,0} reduce(convert.5, constant.14), dimensions={3}, to_apply=region_1.44 + convert.9 = bf16[2,6,2048]{2,1,0} convert(reduce.48) + broadcast.32 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2} + divide.5 = bf16[2,6,2048,2048]{3,2,1,0} divide(exponential.1, broadcast.32) + Arg_2.3 = bf16[2,6,2048,128]{3,2,1,0} parameter(2), sharding={replicated} + dot.57 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + Arg_3.4 = bf16[2,6,2048,128]{3,2,1,0} parameter(3), sharding={replicated} + dot.60 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + divide.4 = bf16[2,6,2048,2048]{3,2,1,0} divide(dot.60, broadcast.32) + constant.15 = bf16[] constant(1) + broadcast.25 = bf16[2,6,2048]{2,1,0} broadcast(constant.15), dimensions={} + multiply.3 = bf16[2,6,2048]{2,1,0} multiply(convert.9, convert.9) + divide.3 = bf16[2,6,2048]{2,1,0} divide(broadcast.25, multiply.3) + broadcast.26 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2} + multiply.4 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.60, broadcast.26) + multiply.5 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.4, exponential.1) + reduce.70 = bf16[2,6,2048]{2,1,0} reduce(multiply.5, constant.16), dimensions={3}, to_apply=region_2.66 + negate.2 = bf16[2,6,2048]{2,1,0} negate(reduce.70) + broadcast.31 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2} + add.5 = bf16[2,6,2048,2048]{3,2,1,0} add(divide.4, broadcast.31) + multiply.8 = bf16[2,6,2048,2048]{3,2,1,0} multiply(add.5, exponential.1) + multiply.9 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.8, broadcast.29) + dot.90 = bf16[2,6,2048,128]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + dot = bf16[2,6,128,2048]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + dot.1 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + ROOT tuple.91 = (bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}) tuple(dot.57, dot.90, dot, dot.1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + CudnnFusedMHARewriter fusedMhaRewriter{ + GetCudaComputeCapability(), GetCudnnVersionWithFlashAttentionSupport()}; + TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status()); + HloDCE dce; + TF_ASSERT_OK(RunHloPass(&dce, m.get()).status()); + + ComputationLayout computation_layout( + m->entry_computation()->ComputeProgramShape()); + + const HloInstruction* fmha; + + SCOPED_TRACE(m->ToString()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Tuple( + m::GetTupleElement( + m::CustomCall(&fmha, {kCudnnfMHASoftmaxCallTarget}), 0) + .WithShape(BF16, {2, 6, 2048, 128}), + m::GetTupleElement( + m::CustomCall(&fmha, {kCudnnfMHASoftmaxBackwardCallTarget}), 0) + .WithShape(BF16, {2, 6, 2048, 128}), + m::Transpose( + m::GetTupleElement( + m::CustomCall({kCudnnfMHASoftmaxBackwardCallTarget}), 1)) + .WithShape(BF16, {2, 6, 128, 2048}), + m::GetTupleElement( + m::CustomCall({kCudnnfMHASoftmaxBackwardCallTarget}), 2) + .WithShape(BF16, {2, 6, 2048, 128})))); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); + EXPECT_EQ(fmha->operands().size(), 6); + EXPECT_NEAR(config.dropout_rate(), 0, 1e-2); + EXPECT_FLOAT_EQ(config.fmha_scale(), 2); + EXPECT_EQ(config.is_flash_attention(), true); + EXPECT_EQ(config.is_causal_mask(), false); +} + +TEST_F(CudnnFusedMhaRewriterTestHloTest, + FlashAttentionBF16TrainingBmm1ScaleMaskSoftmaxBmm2Pattern) { + const char* module_str = R"( +HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,2048,64]{3,2,1,0},bf16[2,6,64,2048]{3,2,1,0},bf16[2,6,2048,64]{3,2,1,0},bf16[2,6,2048,64]{3,2,1,0})->(bf16[2,6,2048,64]{3,2,1,0}, bf16[2,6,2048,64]{3,2,1,0}, bf16[2,6,64,2048]{3,2,1,0}, bf16[2,6,2048,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true} + +region_0.21 { + Arg_0.22 = bf16[] parameter(0) + Arg_1.23 = bf16[] parameter(1) + ROOT maximum = bf16[] maximum(Arg_0.22, Arg_1.23) +} + +region_1.33 { + Arg_0.34 = f32[] parameter(0) + Arg_1.35 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0.34, Arg_1.35) +} + +region_2.55 { + Arg_0.56 = bf16[] parameter(0) + Arg_1.57 = bf16[] parameter(1) + ROOT add.1 = bf16[] add(Arg_0.56, Arg_1.57) +} + +ENTRY main.82 { + constant.18 = pred[2,6,2048,2048]{3,2,1,0} constant({...}) + Arg_0.1 = bf16[2,6,2048,64]{3,2,1,0} parameter(0), sharding={replicated} + Arg_1.2 = bf16[2,6,64,2048]{3,2,1,0} parameter(1), sharding={replicated} + dot.17 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + constant.22 = bf16[] constant(2) + broadcast.24 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(constant.22), dimensions={} + multiply.2 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.17, broadcast.24) + constant.19 = bf16[] constant(1) + constant.21 = bf16[] constant(0) + broadcast.23 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(constant.21), dimensions={} + select.1 = bf16[2,6,2048,2048]{3,2,1,0} select(constant.18, multiply.2, broadcast.23) + constant.15 = bf16[] constant(-inf) + reduce.25 = bf16[2,6,2048]{2,1,0} reduce(select.1, constant.15), dimensions={3}, to_apply=region_0.21 + broadcast.17 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(reduce.25), dimensions={0,1,2} + subtract.1 = bf16[2,6,2048,2048]{3,2,1,0} subtract(select.1, broadcast.17) + exponential.1 = bf16[2,6,2048,2048]{3,2,1,0} exponential(subtract.1) + convert.5 = f32[2,6,2048,2048]{3,2,1,0} convert(exponential.1) + constant.17 = f32[] constant(0) + reduce.37 = f32[2,6,2048]{2,1,0} reduce(convert.5, constant.17), dimensions={3}, to_apply=region_1.33 + convert.9 = bf16[2,6,2048]{2,1,0} convert(reduce.37) + broadcast.26 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2} + divide.5 = bf16[2,6,2048,2048]{3,2,1,0} divide(exponential.1, broadcast.26) + Arg_2.3 = bf16[2,6,2048,64]{3,2,1,0} parameter(2), sharding={replicated} + dot.46 = bf16[2,6,2048,64]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + Arg_3.4 = bf16[2,6,2048,64]{3,2,1,0} parameter(3), sharding={replicated} + dot.49 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + divide.4 = bf16[2,6,2048,2048]{3,2,1,0} divide(dot.49, broadcast.26) + broadcast.20 = bf16[2,6,2048]{2,1,0} broadcast(constant.19), dimensions={} + multiply.3 = bf16[2,6,2048]{2,1,0} multiply(convert.9, convert.9) + divide.3 = bf16[2,6,2048]{2,1,0} divide(broadcast.20, multiply.3) + broadcast.21 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2} + multiply.4 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.49, broadcast.21) + multiply.5 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.4, exponential.1) + reduce.59 = bf16[2,6,2048]{2,1,0} reduce(multiply.5, constant.21), dimensions={3}, to_apply=region_2.55 + negate.2 = bf16[2,6,2048]{2,1,0} negate(reduce.59) + broadcast.25 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2} + add.5 = bf16[2,6,2048,2048]{3,2,1,0} add(divide.4, broadcast.25) + multiply.8 = bf16[2,6,2048,2048]{3,2,1,0} multiply(add.5, exponential.1) + select.3 = bf16[2,6,2048,2048]{3,2,1,0} select(constant.18, multiply.8, broadcast.23) + multiply.9 = bf16[2,6,2048,2048]{3,2,1,0} multiply(select.3, broadcast.24) + dot.80 = bf16[2,6,2048,64]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + dot = bf16[2,6,64,2048]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + dot.1 = bf16[2,6,2048,64]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + ROOT tuple.81 = (bf16[2,6,2048,64]{3,2,1,0}, bf16[2,6,2048,64]{3,2,1,0}, bf16[2,6,64,2048]{3,2,1,0}, bf16[2,6,2048,64]{3,2,1,0}) tuple(dot.46, dot.80, dot, dot.1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + CudnnFusedMHARewriter fusedMhaRewriter{ + GetCudaComputeCapability(), GetCudnnVersionWithFlashAttentionSupport()}; + TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status()); + HloDCE dce; + TF_ASSERT_OK(RunHloPass(&dce, m.get()).status()); + + ComputationLayout computation_layout( + m->entry_computation()->ComputeProgramShape()); + + const HloInstruction* fmha; + + SCOPED_TRACE(m->ToString()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Tuple( + m::GetTupleElement( + m::CustomCall(&fmha, {kCudnnfMHAScaleMaskSoftmaxCallTarget}), 0) + .WithShape(BF16, {2, 6, 2048, 64}), + m::GetTupleElement( + m::CustomCall(&fmha, + {kCudnnfMHAScaleMaskSoftmaxBackwardCallTarget}), + 0) + .WithShape(BF16, {2, 6, 2048, 64}), + m::Transpose( + m::GetTupleElement( + m::CustomCall({kCudnnfMHAScaleMaskSoftmaxBackwardCallTarget}), + 1)) + .WithShape(BF16, {2, 6, 64, 2048}), + m::GetTupleElement( + m::CustomCall({kCudnnfMHAScaleMaskSoftmaxBackwardCallTarget}), 2) + .WithShape(BF16, {2, 6, 2048, 64})))); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); + EXPECT_EQ(fmha->operands().size(), 7); + EXPECT_NEAR(config.dropout_rate(), 0, 1e-2); + EXPECT_EQ(config.is_flash_attention(), true); + EXPECT_EQ(config.is_causal_mask(), false); +} + TEST_F(CudnnFusedMhaRewriterTestHloTest, FlashAttentionF16Bmm1BiasSoftmaxBmm2PatternCrossAttention) { const char* module_str = R"( diff --git a/third_party/xla/xla/service/gpu/cudnn_norm_rewriter.cc b/third_party/xla/xla/service/gpu/cudnn_norm_rewriter.cc index 6fa4cc9d18c7db..17164251cf3edc 100644 --- a/third_party/xla/xla/service/gpu/cudnn_norm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/cudnn_norm_rewriter.cc @@ -57,6 +57,93 @@ namespace { namespace m = match; +// Traverses the graph upward starting at instr and returns the +// first instruction that is not a convert, bitcast or reshape. +const HloInstruction* SkipUnaryOps(const HloInstruction* instr) { + while (instr->opcode() == HloOpcode::kConvert || + instr->opcode() == HloOpcode::kBitcast || + instr->opcode() == HloOpcode::kReshape) { + instr = instr->operand(0); + } + return instr; +} + +// Recursively traverses the graph downward starting at instr and stores in +// instrs the users that are not a convert, bitcast or reshape. +void SkipUnaryOpsTopDownRecursive(HloInstruction* instr, + std::vector& instrs) { + if (instr->opcode() == HloOpcode::kConvert || + instr->opcode() == HloOpcode::kBitcast || + instr->opcode() == HloOpcode::kReshape) { + for (HloInstruction* user : instr->users()) { + SkipUnaryOpsTopDownRecursive(user, instrs); + } + } else { + instrs.emplace_back(instr); + } +} + +// Holds auxiliary information about individual layer norm patterns rewritten +// into a cuDNN Custom Call. +struct NormMetadata { + // Transposes applied to the input and output of the forward layer norm to + // order the normalization and non-normalization dimensions as required by + // cuDNN. Nullptr if no transposes were inserted. + HloInstruction *x_transpose, *y_transpose; + // The reduction and non-reduction dimensions of the input into the forward + // layer norm before the potential application of transposes. + std::vector norm_dims, non_norm_dims; +}; + +// Map from the instruction pointer of a layer norm Custom Call to its metadata. +using NormMetadataMap = absl::flat_hash_map; + +// Captures multiple HloInstruction pointers and verifies that their target +// is identical. +// +// Example: +// Pattern cos(x) / sin(x) with cos and sin intended to operate on the same +// HloInstruction: +// UniqueHloInstruction x; +// bool m = Match( +// instr, m::Divide(m::Cos(m::Op().WithPredicate(x.capture_and_verify)), +// m::Sin(m::Op().WithPredicate(x.capture_and_verify)))); +// m is true and x.Instr() returns an HloInstruction pointer to the operand of +// cosine and sine iff HloInstruction *instr points to a division of a cosine by +// a sine that operate on the same instruction. +class UniqueHloInstruction { + public: + UniqueHloInstruction() : is_set_(false), instr_(nullptr) {} + HloInstruction* Instr() const { return instr_; } + void SetInstr(HloInstruction* instr) { + is_set_ = true; + instr_ = instr; + } + + // Stores instr when invoked the first time. Otherwise, compares instr to the + // stored value and sets the stored value to nullptr if the comparison fails. + bool CaptureOrVerify(HloInstruction* instr) { + if (is_set_ && instr != instr_) { + instr_ = nullptr; + } + if (!is_set_) { + is_set_ = true; + instr_ = instr; + } + return instr_; + } + + // Lambda for capturing or verifying an instruction using WithPredicate. + const std::function capture_or_verify = + [this](const HloInstruction* instr) -> bool { + return CaptureOrVerify(const_cast(instr)); + }; + + private: + bool is_set_; + HloInstruction* instr_; +}; + // Returns an architecture-specific constant for the calculation of an upper // bound for the size of the scratch space for layer norm kernels. absl::StatusOr CConstant( @@ -78,11 +165,22 @@ bool CompatibleElementType(const HloInstruction* instr) { } // Returns whether the HLO Computation applied by instr calculates the sum of -// the elements. -bool AppliesAddReduce(const HloInstruction* instr) { +// the elements. When provided, compares reduce_dims to the dimensions of the +// reduction. +bool AppliesAddReduce(const HloInstruction* instr, + absl::Span reduce_dims = {}) { if (instr->opcode() != HloOpcode::kReduce) { return false; } + if (ShapeUtil::HasDegenerateDimensions(instr->operand(0)->shape())) { + VLOG(1) << "Reduction input must not have degenerate dimensions."; + return false; + } + // Verify the dimensions of the reduction. + if (!reduce_dims.empty() && instr->dimensions() != reduce_dims) { + return false; + } + HloComputation* reduce_comp = instr->to_apply(); HloInstruction* reduce_comp_root = reduce_comp->root_instruction(); return instr->operand_count() == 2 && @@ -97,23 +195,13 @@ bool AppliesAddReduce(const HloInstruction* instr) { // Returns whether instr multiplies the result of a reduction by one over the // number of reduced elements. bool CalculatesExpectation(const HloInstruction* instr) { - auto skip_convert_and_reshape = - [](const HloInstruction* instr) -> const HloInstruction* { - while (instr->opcode() == HloOpcode::kConvert || - instr->opcode() == HloOpcode::kReshape) { - instr = instr->operand(0); - } - return instr; - }; - - instr = skip_convert_and_reshape(instr); + instr = SkipUnaryOps(instr); if (instr->opcode() != HloOpcode::kMultiply) { return false; } bool bcast_operand = instr->operand(0)->opcode() != HloOpcode::kBroadcast; const HloInstruction *broadcast = instr->operand(bcast_operand), - *reduce = instr->operand(!bcast_operand); - reduce = skip_convert_and_reshape(reduce); + *reduce = SkipUnaryOps(instr->operand(!bcast_operand)); if (reduce->opcode() != HloOpcode::kReduce || broadcast->opcode() != HloOpcode::kBroadcast || broadcast->operand(0)->opcode() != HloOpcode::kConstant) { @@ -134,6 +222,167 @@ bool CalculatesExpectation(const HloInstruction* instr) { ((actual_r_nelems + r_nelems) * numerical_epsilon); } +// Returns whether target can be reached from instr by recursively traversing +// the graph across converts, bitcasts and reshapes. +bool FindTargetRecursive( + const HloInstruction* instr, const HloInstruction* target, + absl::flat_hash_set& visited_instrs, + const HloInstruction* transpose) { + visited_instrs.emplace(instr); + const absl::flat_hash_set supported_ops = { + HloOpcode::kConvert, HloOpcode::kBitcast, HloOpcode::kReshape}; + if (instr == target) { + return true; + } + // Look for target among the users of instr. + for (HloInstruction* user : instr->users()) { + if ((supported_ops.contains(user->opcode()) || user == transpose) && + !visited_instrs.contains(user)) { + return FindTargetRecursive(user, target, visited_instrs, transpose); + } + } + // Ascend the graph if target is not found and instr is a convert, bitcast + // or reshape. + if (supported_ops.contains(instr->opcode())) { + return FindTargetRecursive(instr->operand(0), target, visited_instrs, + transpose); + } + return false; +} + +bool FindTarget(const HloInstruction* custom_call, const HloInstruction* instr, + const HloInstruction* target, + const NormMetadataMap& norm_metadata) { + absl::flat_hash_set visited_instrs; + auto custom_call_metadata = norm_metadata.find(custom_call); + if (custom_call_metadata == norm_metadata.end()) { + return false; + } + return FindTargetRecursive(instr, target, visited_instrs, + custom_call_metadata->second.x_transpose); +} + +// Maps the dimension numbers in dimensions from shape original_shape to shape +// reshaped_shape, assuming that the shapes are related through a strict +// reshape. Returns an empty vector if a dimension mapping is not found. +std::vector MapDimensions(const Shape& original_shape, + const Shape& reshaped_shape, + const absl::Span dimensions) { + // The original and reshaped shape must not have degenerate dimensions. + if (ShapeUtil::HasDegenerateDimensions(original_shape) || + ShapeUtil::HasDegenerateDimensions(reshaped_shape)) { + return {}; + } + + auto dimension_product = + [](const Shape& shape, + absl::Span product_dimensions) -> int64_t { + int64_t product = 1; + for (int64_t product_dimension : product_dimensions) { + product *= shape.dimensions(product_dimension); + } + return product; + }; + // Construct the dimension mapping. + absl::flat_hash_map> dimensions_map; + std::vector original_dimensions, reshaped_dimensions; + for (int64_t original_dimension = 0, reshaped_dimension = 0; + original_dimension < original_shape.rank(); ++original_dimension) { + original_dimensions.emplace_back(original_dimension); + while (dimension_product(reshaped_shape, reshaped_dimensions) < + dimension_product(original_shape, original_dimensions) && + reshaped_dimension < reshaped_shape.rank()) { + reshaped_dimensions.emplace_back(reshaped_dimension++); + } + + // Many-to-many dimension mappings are not supported. + if (original_dimensions.size() > 1 && reshaped_dimensions.size() > 1) { + return {}; + } + + if (dimension_product(original_shape, original_dimensions) == + dimension_product(reshaped_shape, reshaped_dimensions)) { + std::vector original_dimensions_in_dimensions; + std::set_intersection( + original_dimensions.begin(), original_dimensions.end(), + dimensions.begin(), dimensions.end(), + std::back_inserter(original_dimensions_in_dimensions)); + // The unique mapping of dimensions requires either all or none of the + // entries of original_dimensions to be an element of dimensions. + if (original_dimensions_in_dimensions.size() != 0 && + original_dimensions_in_dimensions.size() != + original_dimensions.size()) { + return {}; + } + for (int64_t dimension : original_dimensions) { + dimensions_map.insert({dimension, reshaped_dimensions}); + } + original_dimensions.clear(); + reshaped_dimensions.clear(); + } + } + + // Map the dimensions numbers to the reshaped shape. + std::vector mapped_dimensions; + for (int64_t dimension : dimensions) { + auto mapped_dimension = dimensions_map.find(dimension); + if (mapped_dimension == dimensions_map.end()) { + return {}; + } + mapped_dimensions.insert(mapped_dimensions.end(), + mapped_dimension->second.begin(), + mapped_dimension->second.end()); + } + + // Eliminate duplicates in the mapped dimension numbers. + mapped_dimensions.erase( + std::unique(mapped_dimensions.begin(), mapped_dimensions.end()), + mapped_dimensions.end()); + return mapped_dimensions; +} + +// Recursively traverses the graph across converts, bitcasts and reshapes, +// starting from instr, and returns the first addition-reduction identified. +// Returns nullptr if no addition-reduction is found. +HloInstruction* FindAddReduceRecursive( + HloInstruction* instr, const Shape& orig_instr_shape, + const absl::Span reduce_dims, + absl::flat_hash_set& visited_instrs) { + visited_instrs.emplace(instr); + const absl::flat_hash_set supported_ops = { + HloOpcode::kConvert, HloOpcode::kBitcast, HloOpcode::kReshape}; + // Look for a reduction among the users of instr. + for (HloInstruction* user : instr->users()) { + if (user->opcode() == HloOpcode::kReduce) { + std::vector mapped_reduce_dims = + MapDimensions(orig_instr_shape, instr->shape(), reduce_dims); + if (!mapped_reduce_dims.empty() && + AppliesAddReduce(user, mapped_reduce_dims)) { + return user; + } + } + if (supported_ops.contains(user->opcode()) && + !visited_instrs.contains(user)) { + return FindAddReduceRecursive(user, orig_instr_shape, reduce_dims, + visited_instrs); + } + } + // Ascend the graph if the addition-reduction is not found and instr is a + // convert, bitcast or reshape. + if (supported_ops.contains(instr->opcode())) { + return FindAddReduceRecursive(instr->mutable_operand(0), orig_instr_shape, + reduce_dims, visited_instrs); + } + return nullptr; +} + +HloInstruction* FindAddReduce(HloInstruction* instr, + const absl::Span reduce_dims) { + absl::flat_hash_set visited_instrs; + return FindAddReduceRecursive(instr, instr->shape(), reduce_dims, + visited_instrs); +} + // Type conversion from and to any of BF16, FP16 and FP32. template auto SupportedConvert(Pattern pattern) { @@ -144,69 +393,94 @@ auto SupportedConvert(Pattern pattern) { return m::Convert(pattern).WithPredicate(supported_convert); } -// Reshape adding or removing degenerate dimensions. +// Bitcast or reshape adding or removing degenerate dimensions. template -auto SupportedReshape(Pattern pattern) { - auto supported_reshape = [](const HloInstruction* instr) -> bool { +auto SupportedBitcastOrReshape(Pattern pattern) { + auto supported_bitcast_or_reshape = [](const HloInstruction* instr) -> bool { return ShapeUtil::Equal( ShapeUtil::DropDegenerateDimensions(instr->shape()), ShapeUtil::DropDegenerateDimensions(instr->operand(0)->shape())); }; - return m::Reshape(pattern).WithPredicate(supported_reshape); + return m::AnyOf( + m::Bitcast(pattern).WithPredicate(supported_bitcast_or_reshape), + m::Reshape(pattern).WithPredicate(supported_bitcast_or_reshape)); } -// Matches pattern, SupportedConvert(pattern), SupportedReshape(pattern), -// SupportedConvert(SupportedReshape(pattern)) and -// SupportedReshape(SupportedConvert(pattern)). +// Matches pattern, SupportedConvert(pattern), +// SupportedBitcastOrReshape(pattern), +// SupportedConvert(SupportedBitcastOrReshape(pattern)) and +// SupportedBitcastOrReshape(SupportedConvert(pattern)). template -auto OptionalConvertAndOrReshape(Pattern pattern) { +auto OptionalSupportedTransform(Pattern pattern) { auto shared_subpattern = m::SharedSubpattern(pattern); return m::AnyOf( - SupportedConvert(SupportedReshape(shared_subpattern)), - SupportedReshape(SupportedConvert(shared_subpattern)), - SupportedConvert(shared_subpattern), SupportedReshape(shared_subpattern), - shared_subpattern); + SupportedConvert(SupportedBitcastOrReshape(shared_subpattern)), + SupportedBitcastOrReshape(SupportedConvert(shared_subpattern)), + SupportedConvert(shared_subpattern), + SupportedBitcastOrReshape(shared_subpattern), shared_subpattern); +} + +// Bitcast or reshape with optional supported type conversion and/or addition or +// removal of degenerate dimensions. +template +auto BitcastOrReshape(Pattern pattern) { + return OptionalSupportedTransform( + m::AnyOf(m::Bitcast(pattern), m::Reshape(pattern))); +} + +// Transpose with optional supported type conversion and/or addition or removal +// of degenerate dimensions. +template +auto Transpose(Pattern pattern) { + return OptionalSupportedTransform(m::Transpose(pattern)); } -// Rsqrt with optional convert and/or reshape. +// Rsqrt with optional supported type conversion and/or addition or removal of +// degenerate dimensions. template auto Rsqrt(HloInstruction** rsqrt, Pattern pattern) { - return OptionalConvertAndOrReshape(m::Rsqrt(rsqrt, pattern)); + return OptionalSupportedTransform(m::Rsqrt(rsqrt, pattern)); } -// AddAnyOrder with optional convert and/or reshape. +// AddAnyOrder with optional supported type conversion and/or addition or +// removal of degenerate dimensions. template auto AddAnyOrder(Pattern0 pattern0, Pattern1 pattern1) { - return OptionalConvertAndOrReshape(m::AddAnyOrder(pattern0, pattern1)); + return OptionalSupportedTransform(m::AddAnyOrder(pattern0, pattern1)); } -// Subtract with optional convert and/or reshape. +// Subtract with optional supported type conversion and/or addition or removal +// of degenerate dimensions. template auto Subtract(Pattern0 pattern0, Pattern1 pattern1) { - return OptionalConvertAndOrReshape(m::Subtract(pattern0, pattern1)); + return OptionalSupportedTransform(m::Subtract(pattern0, pattern1)); } -// Capturing subtract with optional convert and/or reshape. +// Capturing subtract with optional supported type conversion and/or addition or +// removal of degenerate dimensions. template auto Subtract(HloInstruction** subtract, Pattern0 pattern0, Pattern1 pattern1) { - return OptionalConvertAndOrReshape(m::Subtract(subtract, pattern0, pattern1)); + return OptionalSupportedTransform(m::Subtract(subtract, pattern0, pattern1)); } -// Multiply with optional convert and/or reshape. +// Multiply with optional supported type conversion and/or addition or removal +// of degenerate dimensions. template auto MultiplyAnyOrder(Pattern0 pattern0, Pattern1 pattern1) { - return OptionalConvertAndOrReshape(m::MultiplyAnyOrder(pattern0, pattern1)); + return OptionalSupportedTransform(m::MultiplyAnyOrder(pattern0, pattern1)); } -// Capturing multiply with optional convert and/or reshape. +// Capturing multiply with optional supported type conversion and/or addition or +// removal of degenerate dimensions. template auto MultiplyAnyOrder(HloInstruction** multiply, Pattern0 pattern0, Pattern1 pattern1) { - return OptionalConvertAndOrReshape( + return OptionalSupportedTransform( m::MultiplyAnyOrder(multiply, pattern0, pattern1)); } -// Multiplication of pattern by itself with optional convert and/or reshape. +// Multiplication of pattern by itself with optional supported type conversion +// and/or addition or removal of degenerate dimensions. template auto Square(Pattern pattern) { return MultiplyAnyOrder(pattern, pattern) @@ -215,28 +489,49 @@ auto Square(Pattern pattern) { }); } -// Addition-reduction of pattern with optional convert and/or reshape and -// constant 0 scalar. +// Multiplication of the square of pattern by pattern with optional supported +// type conversion and/or addition or removal of degenerate dimensions. The root +// instruction of pattern cannot be a multiplication. +template +auto Cube(Pattern pattern) { + auto unique_cube = [](const HloInstruction* instr) -> bool { + bool square_operand = instr->operand(0)->opcode() != HloOpcode::kMultiply; + return instr->operand(!square_operand)->opcode() != HloOpcode::kMultiply && + instr->operand(square_operand)->operand(0) == + instr->operand(!square_operand); + }; + return MultiplyAnyOrder(Square(pattern), pattern).WithPredicate(unique_cube); +} + +// Addition-reduction of pattern with optional supported type conversion and/or +// addition or removal of degenerate dimensions. template auto AddReduce(Pattern pattern) { - return OptionalConvertAndOrReshape( + return OptionalSupportedTransform( m::Reduce(pattern, m::Op()) .WithPredicate([](const HloInstruction* instr) { return AppliesAddReduce(instr); })); } -// Capturing addition-reduction of pattern with optional convert and/or reshape -// and constant 0 scalar. +// Capturing addition-reduction of pattern with optional supported type +// conversion and/or addition or removal of degenerate dimensions. template auto AddReduce(HloInstruction** reduction, Pattern pattern) { - return OptionalConvertAndOrReshape( + return OptionalSupportedTransform( m::Reduce(reduction, pattern, m::Op()) .WithPredicate([](const HloInstruction* instr) { return AppliesAddReduce(instr); })); } +// Negated addition-reduction. +template +auto NegateAddReduce(HloInstruction** reduction, Pattern pattern) { + return m::AnyOf(AddReduce(reduction, m::Negate(pattern)), + m::Negate(AddReduce(reduction, pattern))); +} + // Expected value, or mean, with optional broadcast. template auto Expectation(Pattern pattern) { @@ -251,65 +546,56 @@ auto Expectation(Pattern pattern) { // Expected value, or mean, with optional broadcast. template -auto Expectation(HloInstruction** expectation, Pattern pattern) { +auto Expectation(UniqueHloInstruction* expectation, Pattern pattern) { auto shared_subpattern = - MultiplyAnyOrder(expectation, m::Broadcast(m::ConstantScalar()), - AddReduce(pattern)) + MultiplyAnyOrder(m::Broadcast(m::ConstantScalar()), AddReduce(pattern)) .WithPredicate([](const HloInstruction* instr) { return CalculatesExpectation(instr); - }); + }) + .WithPredicate(expectation->capture_or_verify); return m::AnyOf(m::Broadcast(shared_subpattern), shared_subpattern); } // Expected value, or mean, with optional broadcast. template -auto Expectation(HloInstruction** expectation, HloInstruction** reduce, +auto Expectation(UniqueHloInstruction* expectation, HloInstruction** reduce, Pattern pattern) { - auto shared_subpattern = - MultiplyAnyOrder(expectation, m::Broadcast(m::ConstantScalar()), - AddReduce(reduce, pattern)) - .WithPredicate([](const HloInstruction* instr) { - return CalculatesExpectation(instr); - }); + auto shared_subpattern = MultiplyAnyOrder(m::Broadcast(m::ConstantScalar()), + AddReduce(reduce, pattern)) + .WithPredicate([](const HloInstruction* instr) { + return CalculatesExpectation(instr); + }) + .WithPredicate(expectation->capture_or_verify); return m::AnyOf(m::Broadcast(shared_subpattern), shared_subpattern); } // Variance, expressed as expectation(X^2) - expectation(X)^2 or -// expectation((X - expectation(X))^2). The simultaneous capture of input0 and -// input1 allows the caller to verify that they are identical. -auto Variance(HloInstruction** expectation, HloInstruction** input0, - HloInstruction** input1) { - return m::AnyOf( - Subtract(Expectation(Square(m::Op(input0))), - Square(Expectation(expectation, m::Op(input1)))), - Expectation(Square( - Subtract(m::Op(input0), Expectation(expectation, m::Op(input1)))))); -} - -// Variance, expressed as expectation(X^2) - expectation(X)^2 or -// expectation((X - expectation(X))^2). The simultaneous capture of input0 and -// input1 allows the caller to verify that they are identical. -auto Variance(HloInstruction** variance, HloInstruction** expectation, - HloInstruction** input0, HloInstruction** input1) { +// expectation((X - expectation(X))^2). +auto Variance(UniqueHloInstruction* variance, UniqueHloInstruction* expectation, + UniqueHloInstruction* x) { return m::AnyOf( - Subtract(variance, Expectation(Square(m::Op(input0))), - Square(Expectation(expectation, m::Op(input1)))), - Expectation(variance, - Square(Subtract(m::Op(input0), - Expectation(expectation, m::Op(input1)))))); + Subtract(Expectation(Square(m::Op().WithPredicate(x->capture_or_verify))), + Square(Expectation(expectation, + m::Op().WithPredicate(x->capture_or_verify)))) + .WithPredicate(variance->capture_or_verify), + Expectation( + Square(Subtract(m::Op().WithPredicate(x->capture_or_verify), + Expectation(expectation, m::Op().WithPredicate( + x->capture_or_verify))))) + .WithPredicate(variance->capture_or_verify)); } // Reciprocal of the square root of variance + epsilon with optional broadcast. -// The simultaneous capture of input0 and input1 allows the caller to verify -// that they are identical. -auto NormFactor(HloInstruction** norm_factor, HloInstruction** input0, - HloInstruction** input1, HloInstruction** variance, - HloInstruction** expectation, HloInstruction** epsilon) { +auto NormFactor(HloInstruction** norm_factor, UniqueHloInstruction* x, + UniqueHloInstruction* variance, + UniqueHloInstruction* expectation, + UniqueHloInstruction* epsilon) { auto shared_subpattern = m::SharedSubpattern(Rsqrt( - norm_factor, AddAnyOrder(Variance(variance, expectation, input0, input1), - m::Broadcast(m::ConstantScalar(epsilon))))); + norm_factor, AddAnyOrder(Variance(variance, expectation, x), + m::Broadcast(m::ConstantScalar().WithPredicate( + epsilon->capture_or_verify))))); return m::AnyOf(m::Broadcast(shared_subpattern), shared_subpattern); } @@ -323,6 +609,22 @@ auto MultiplyMultiplyAnyOrder(P0 p0, P1 p1, P2 p2) { MultiplyAnyOrder(p2, MultiplyAnyOrder(p0, p1))); } +// Any order of p0 + p1 + p2. +template +auto AddAddAnyOrder(P0 p0, P1 p1, P2 p2) { + return m::AnyOf(AddAnyOrder(p0, AddAnyOrder(p1, p2)), + AddAnyOrder(p1, AddAnyOrder(p0, p2)), + AddAnyOrder(p2, AddAnyOrder(p0, p1))); +} + +// Any order of p0 * (p1 + p2). +template +auto MultiplyAddAnyOrder(P0 p0, P1 p1, P2 p2) { + return m::AnyOf( + MultiplyAnyOrder(p0, AddAnyOrder(p1, p2)), + AddAnyOrder(MultiplyAnyOrder(p0, p1), MultiplyAnyOrder(p0, p2))); +} + // Any order of p0 - p1 + p2. template auto SubtractAddAnyOrder(P0 p0, P1 p1, P2 p2) { @@ -340,6 +642,185 @@ auto SubtractMultiplyAddAnyOrder(P0 p0, P1 p1, P2 p2, P3 p3, P4 p4) { AddAnyOrder(MultiplyMultiplyAnyOrder(Subtract(p0, p1), p2, p3), p4)); } +// Expectation fused into a layer norm Custom Call. +auto FusedExpectation(UniqueHloInstruction* custom_call) { + auto shared_subpattern = m::SharedSubpattern( + m::GetTupleElement(m::CustomCall({kCudnnNormCallTarget}) + .WithPredicate(custom_call->capture_or_verify), + 1)); + return m::AnyOf(shared_subpattern, + BitcastOrReshape(shared_subpattern)); +} + +// Expectation fused into a layer norm Custom Call. +auto FusedExpectation(UniqueHloInstruction* fused_expectation, + UniqueHloInstruction* custom_call) { + auto shared_subpattern = m::SharedSubpattern( + m::GetTupleElement(m::CustomCall({kCudnnNormCallTarget}) + .WithPredicate(custom_call->capture_or_verify), + 1) + .WithPredicate(fused_expectation->capture_or_verify)); + return m::AnyOf(shared_subpattern, + BitcastOrReshape(shared_subpattern)); +} + +// Norm factor fused into a layer norm Custom Call. +auto FusedNormFactor(UniqueHloInstruction* custom_call) { + auto shared_subpattern = m::SharedSubpattern( + m::GetTupleElement(m::CustomCall({kCudnnNormCallTarget}) + .WithPredicate(custom_call->capture_or_verify), + 2)); + return m::AnyOf(shared_subpattern, + BitcastOrReshape(shared_subpattern)); +} + +// Norm factor fused into a layer norm Custom Call. +auto FusedNormFactor(UniqueHloInstruction* fused_norm_factor, + UniqueHloInstruction* custom_call) { + auto shared_subpattern = m::SharedSubpattern( + m::GetTupleElement(m::CustomCall({kCudnnNormCallTarget}) + .WithPredicate(custom_call->capture_or_verify), + 2) + .WithPredicate(fused_norm_factor->capture_or_verify)); + return m::AnyOf(shared_subpattern, + BitcastOrReshape(shared_subpattern)); +} + +// Derivative of the norm factor w.r.t. variance + epsilon, +// d(norm_factor)/d(variance + epsilon) +// = d((variance + epsilon)^-1/2)/d(variance + epsilon) +// = -1/2 * norm_factor^3. +// Forwards custom_call to FusedNormFactor for verification. +auto DNormFactor(UniqueHloInstruction* custom_call) { + return MultiplyAnyOrder(m::Broadcast(m::ConstantScalar(-0.5)), + Cube(FusedNormFactor(custom_call))); +} + +// Zero-centered input of the layer norm, X - expectation(X). Verifies that +// custom_call is a forward layer norm fusing X. Forwards custom_call to +// FusedExpectation for verification. +auto XCenter(UniqueHloInstruction* x, UniqueHloInstruction* custom_call, + const NormMetadataMap& norm_metadata) { + auto capture_or_verify_x = + [x, custom_call, &norm_metadata](const HloInstruction* instr) -> bool { + return x->CaptureOrVerify( + FindTarget(custom_call->Instr(), instr->operand(0), + custom_call->Instr()->operand(0), norm_metadata) + ? custom_call->Instr()->mutable_operand(0) + : nullptr); + }; + return Subtract(m::Op(), m::Broadcast(FusedExpectation(custom_call))) + .WithPredicate(capture_or_verify_x); +} + +// Zero-centered input of the layer norm, X - expectation(X). Captures X in x if +// custom_call is a forward layer norm fusing X. Forwards custom_call to +// FusedExpectation for comparison. +auto XCenter(UniqueHloInstruction* x_center, UniqueHloInstruction* x, + UniqueHloInstruction* fused_expectation, + UniqueHloInstruction* custom_call, + const NormMetadataMap& norm_metadata) { + auto capture_or_verify_x = [x, x_center, custom_call, &norm_metadata]( + const HloInstruction* instr) -> bool { + return x->CaptureOrVerify( + FindTarget(custom_call->Instr(), instr->operand(0), + custom_call->Instr()->operand(0), norm_metadata) + ? custom_call->Instr()->mutable_operand(0) + : nullptr); + }; + return Subtract(m::Op(), m::Broadcast(FusedExpectation(fused_expectation, + custom_call))) + .WithPredicate(x_center->capture_or_verify) + .WithPredicate(capture_or_verify_x); +} + +// Addition-reduction of the product of XCenter, the broadcasted scale and DY, +// XCenter * scale * DY. Captures the scale in scale if custom_call is a forward +// layer norm fusing the scale. Forwards custom_call to XCenter for comparison. +auto F0(UniqueHloInstruction* custom_call, UniqueHloInstruction* scale, + UniqueHloInstruction* dy, UniqueHloInstruction* x, + HloInstruction** reduce, const NormMetadataMap& norm_metadata) { + auto capture_or_verify_scale = [scale, custom_call, &norm_metadata]( + const HloInstruction* instr) -> bool { + return scale->CaptureOrVerify(FindTarget(custom_call->Instr(), instr, + custom_call->Instr()->operand(1), + norm_metadata) + ? custom_call->Instr()->mutable_operand(1) + : nullptr); + }; + return AddReduce( + reduce, MultiplyMultiplyAnyOrder( + XCenter(x, custom_call, norm_metadata), + m::Broadcast(m::Op().WithPredicate(capture_or_verify_scale)), + m::Op().WithPredicate(dy->capture_or_verify))); +} + +// Product of XCenter and the scaled and broadcasted product of F0 and +// d(norm_factor)/d(variance + epsilon), XCenter * F0 * DNormFactor * 2 / +// nelems. Forwards custom_call to XCenter, F0 and DNormFactor for capture or +// verification. +auto F1(UniqueHloInstruction* x, UniqueHloInstruction* x_center, + UniqueHloInstruction* fused_expectation, + UniqueHloInstruction* custom_call, UniqueHloInstruction* scale, + UniqueHloInstruction* dy, HloInstruction** reduce, + const NormMetadataMap& norm_metadata) { + auto broadcasts_two_over_nelems = [](const HloInstruction* instr) -> bool { + const HloInstruction* multiply = SkipUnaryOps(instr->operand(0)); + bool bcast_operand = + multiply->operand(0)->opcode() != HloOpcode::kBroadcast; + + // The captured scalar must be two over the number of elements in the + // broadcasted dimensions. + float actual_two_over_nelems = multiply->operand(bcast_operand) + ->operand(0) + ->literal() + .GetAsDouble({}) + .value(); + int64_t nelems = 1; + for (int i = 0; i < instr->shape().dimensions_size(); ++i) { + if (!c_linear_search(instr->dimensions(), i)) { + nelems *= instr->shape().dimensions()[i]; + } + } + // The absolute of the difference between the actual scaling factor and the + // reference value must not exceed a prescribed threshold. + float two_over_nelems = 2. / static_cast(nelems); + float numerical_epsilon = std::numeric_limits::epsilon(); + return abs(actual_two_over_nelems - two_over_nelems) < + ((actual_two_over_nelems + two_over_nelems) * numerical_epsilon); + }; + + return MultiplyAnyOrder( + XCenter(x_center, x, fused_expectation, custom_call, norm_metadata), + m::Broadcast( + MultiplyAnyOrder(m::Broadcast(m::ConstantScalar()), + MultiplyAnyOrder(DNormFactor(custom_call), + F0(custom_call, scale, dy, x, + reduce, norm_metadata)))) + .WithPredicate(broadcasts_two_over_nelems)); +} + +// Product of the norm factor, scale and DY, NormFactor * scale * DY. Captures +// the scale in scale if custom_call is a forward layer norm fusing the scale. +// Forwards custom_call to FusedNormFactor for comparison. +auto F2(UniqueHloInstruction* fused_norm_factor, UniqueHloInstruction* scale, + UniqueHloInstruction* dy, UniqueHloInstruction* custom_call, + const NormMetadataMap& norm_metadata) { + auto capture_or_verify_scale = [scale, custom_call, &norm_metadata]( + const HloInstruction* instr) -> bool { + return scale->CaptureOrVerify( + FindTarget(custom_call->Instr(), instr->operand(0), + custom_call->Instr()->operand(1), norm_metadata) + ? custom_call->Instr()->mutable_operand(1) + : nullptr); + }; + return MultiplyAnyOrder( + m::Broadcast( + BitcastOrReshape(FusedNormFactor(fused_norm_factor, custom_call))), + MultiplyAnyOrder(m::Broadcast().WithPredicate(capture_or_verify_scale), + m::Op().WithPredicate(dy->capture_or_verify))); +} + class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { public: explicit CudnnNormRewriterVisitor( @@ -347,7 +828,9 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { : cuda_compute_capability_(cuda_compute_capability) {} absl::Status HandleAdd(HloInstruction* instr) override { - return MatchLayerNorm(instr); + TF_RETURN_IF_ERROR(MatchLayerNorm(instr)); + TF_RETURN_IF_ERROR(MatchLayerNormGradient(instr)); + return absl::OkStatus(); } absl::Status HandleSubtract(HloInstruction* instr) override { @@ -355,19 +838,21 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { } // Matches and rewrites layer norm patterns, - // (X - expectation(X))/(variance(X) + epsilon)^1/2 * scale + bias, + // Y = (X - expectation(X))/sqrt(variance(X) + epsilon) * scale + bias, // into Custom Calls to cuDNN. absl::Status MatchLayerNorm(HloInstruction* instr) { - HloInstruction *input, *input0, *input1, *input2, *scale, *bias, *epsilon, - *expectation, *expectation0, *reduce, *norm_factor, *variance, - *broadcast_scale, *broadcast_bias; - if (Match(instr, SubtractMultiplyAddAnyOrder( - m::Op(&input), - Expectation(&expectation, &reduce, m::Op(&input0)), - NormFactor(&norm_factor, &input1, &input2, &variance, - &expectation0, &epsilon), - m::Broadcast(&broadcast_scale, m::Op(&scale)), - m::Broadcast(&broadcast_bias, m::Op(&bias))))) { + UniqueHloInstruction x, expectation, variance, epsilon; + HloInstruction *scale, *bias, *reduce, *norm_factor, *broadcast_scale, + *broadcast_bias; + if (Match( + instr, + SubtractMultiplyAddAnyOrder( + m::Op().WithPredicate(x.capture_or_verify), + Expectation(&expectation, &reduce, + m::Op().WithPredicate(x.capture_or_verify)), + NormFactor(&norm_factor, &x, &variance, &expectation, &epsilon), + m::Broadcast(&broadcast_scale, m::Op(&scale)), + m::Broadcast(&broadcast_bias, m::Op(&bias))))) { #if CUDNN_VERSION < 8905 // Layer norm kernels are available with cuDNN 8.9.5 and above. VLOG(1) << "Layer norm Custom Calls require cuDNN 8.9.5."; @@ -391,25 +876,20 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { } // Verify the uniqueness of the inputs. - auto is_input = [input](HloInstruction* inputx) -> bool { - return inputx->unique_id() == input->unique_id() || - (inputx->opcode() == HloOpcode::kConvert && - inputx->operand(0)->unique_id() == input->unique_id()); - }; - if (!is_input(input0) || !is_input(input1) || !is_input(input2) || - expectation->unique_id() != expectation0->unique_id()) { + if (!x.Instr() || !expectation.Instr() || !variance.Instr() || + !epsilon.Instr()) { VLOG(1) << "Layer norm operands not unique."; return absl::OkStatus(); } // Skip initial convert, if present. - if (input->opcode() == HloOpcode::kConvert) { - input = input->mutable_operand(0); + if (x.Instr()->opcode() == HloOpcode::kConvert) { + x.SetInstr(x.Instr()->mutable_operand(0)); } // Verify the input and output layouts. // TODO(philipphack): Consider supporting more general cases. - if (!LayoutUtil::IsMonotonicWithDim0Major(input->shape().layout()) || + if (!LayoutUtil::IsMonotonicWithDim0Major(x.Instr()->shape().layout()) || !LayoutUtil::IsMonotonicWithDim0Major(scale->shape().layout()) || !LayoutUtil::IsMonotonicWithDim0Major(bias->shape().layout()) || !LayoutUtil::IsMonotonicWithDim0Major(instr->shape().layout())) { @@ -419,7 +899,7 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { // Verify the element types. The types and shapes of the scale and bias // must match. - if (!CompatibleElementType(input) || !CompatibleElementType(instr) || + if (!CompatibleElementType(x.Instr()) || !CompatibleElementType(instr) || !CompatibleElementType(scale) || !CompatibleElementType(bias) || !ShapeUtil::Equal(scale->shape(), bias->shape())) { VLOG(1) << "Layer norm input types or shapes not supported."; @@ -435,7 +915,7 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { return absl::OkStatus(); } for (int i = 0; i < norm_dims.size(); ++i) { - if (input->shape().dimensions(norm_dims[i]) != + if (x.Instr()->shape().dimensions(norm_dims[i]) != scale->shape().dimensions(i)) { VLOG(1) << "Layer norm input dimensions not supported."; return absl::OkStatus(); @@ -456,54 +936,55 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { // If necessary, transpose the input so that the dimensions not being // normalized are the leading dimensions. std::vector non_norm_dims; - for (int64_t input_dim = 0; input_dim < input->shape().rank(); - ++input_dim) { - if (std::find(norm_dims.begin(), norm_dims.end(), input_dim) == + for (int64_t x_dim = 0; x_dim < x.Instr()->shape().rank(); ++x_dim) { + if (std::find(norm_dims.begin(), norm_dims.end(), x_dim) == norm_dims.end()) { - non_norm_dims.emplace_back(input_dim); + non_norm_dims.emplace_back(x_dim); } } - std::vector transpose_order = non_norm_dims; - transpose_order.insert(transpose_order.end(), norm_dims.begin(), - norm_dims.end()); + std::vector x_transpose_order = non_norm_dims; + x_transpose_order.insert(x_transpose_order.end(), norm_dims.begin(), + norm_dims.end()); bool apply_transpose = false; - for (int i = 0; i < transpose_order.size(); ++i) { - if (transpose_order[i] != i) { + for (int i = 0; i < x_transpose_order.size(); ++i) { + if (x_transpose_order[i] != i) { apply_transpose = true; break; } } - std::optional transpose; - std::vector inverse_transpose_order(transpose_order.size()); + std::optional x_transpose; + // The transpose applied to the output is the inverse of the transpose + // applied to the input. + std::vector y_transpose_order(x_transpose_order.size()); if (apply_transpose) { - for (int k = 0; k < transpose_order.size(); ++k) { - inverse_transpose_order[transpose_order[k]] = k; + for (int k = 0; k < x_transpose_order.size(); ++k) { + y_transpose_order[x_transpose_order[k]] = k; } - TF_ASSIGN_OR_RETURN(transpose, - MakeTransposeHlo(input, transpose_order)); + TF_ASSIGN_OR_RETURN(x_transpose, + MakeTransposeHlo(x.Instr(), x_transpose_order)); } // Combine the dimensions not normalized into the first dimension of the // input as required by cuDNN. std::vector reshaped_dims = {1}; for (auto non_norm_dim : non_norm_dims) { - reshaped_dims[0] *= input->shape().dimensions(non_norm_dim); + reshaped_dims[0] *= x.Instr()->shape().dimensions(non_norm_dim); } for (auto norm_dim : norm_dims) { - reshaped_dims.emplace_back(input->shape().dimensions(norm_dim)); + reshaped_dims.emplace_back(x.Instr()->shape().dimensions(norm_dim)); } // cuDNN requires tensors to have at least four dimensions. while (reshaped_dims.size() < 4) { reshaped_dims.emplace_back(1); } - Shape reshaped_shape = - ShapeUtil::MakeShape(input->shape().element_type(), reshaped_dims); + Shape reshaped_shape = ShapeUtil::MakeShape( + x.Instr()->shape().element_type(), reshaped_dims); TF_ASSIGN_OR_RETURN( - HloInstruction * reshape, - MakeReshapeHlo(reshaped_shape, transpose.value_or(input))); + HloInstruction * x_reshape, + MakeReshapeHlo(reshaped_shape, x_transpose.value_or(x.Instr()))); // Reshape the scale and bias. std::vector reshaped_scale_dims(reshaped_dims.begin() + 1, @@ -514,14 +995,16 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { } Shape scale_bias_shape = ShapeUtil::MakeShape( scale->shape().element_type(), reshaped_scale_dims); - TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_scale, + TF_ASSIGN_OR_RETURN(HloInstruction * scale_reshape, MakeReshapeHlo(scale_bias_shape, scale)); - TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_bias, + TF_ASSIGN_OR_RETURN(HloInstruction * bias_reshape, MakeReshapeHlo(scale_bias_shape, bias)); - GpuBackendConfig gpu_config; + GpuBackendConfig gpu_backend_config; CudnnNormBackendConfig& backend_config = - *gpu_config.mutable_cudnn_norm_backend_config(); - backend_config.set_epsilon(epsilon->literal().GetAsDouble({}).value()); + *gpu_backend_config.mutable_cudnn_norm_backend_config(); + backend_config.set_epsilon( + epsilon.Instr()->literal().GetAsDouble({}).value()); + backend_config.set_kind(CudnnNormBackendConfig::LAYER_FWD_INFER); auto* algorithm = backend_config.mutable_algorithm(); algorithm->set_algo_id(0); algorithm->set_math_type(se::dnn::AlgorithmProto::TENSOR_OP_MATH); @@ -538,28 +1021,33 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { // The output of the Custom Call is a tuple, the second element of which // describes the scratch space. Shape custom_call_shape = ShapeUtil::MakeTupleShape( - {reshape->shape(), ShapeUtil::MakeShape(U8, {workspace_size})}); + {x_reshape->shape(), ShapeUtil::MakeShape(U8, {workspace_size})}); HloInstruction* custom_call = instr->AddInstruction(HloInstruction::CreateCustomCall( - custom_call_shape, {reshape, reshaped_scale, reshaped_bias}, + custom_call_shape, {x_reshape, scale_reshape, bias_reshape}, kCudnnNormCallTarget)); - TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_config)); + TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_backend_config)); TF_ASSIGN_OR_RETURN(HloInstruction * gte, MakeGetTupleElementHlo(custom_call, 0)); TF_ASSIGN_OR_RETURN( - HloInstruction * inverse_reshape, - MakeReshapeHlo(transpose.value_or(instr)->shape(), gte)); + HloInstruction * y_reshape, + MakeReshapeHlo(x_transpose.value_or(instr)->shape(), gte)); - if (!apply_transpose) { - TF_RETURN_IF_ERROR(ReplaceInstruction(instr, inverse_reshape)); - } else { - TF_ASSIGN_OR_RETURN( - HloInstruction * inverse_transpose, - MakeTransposeHlo(inverse_reshape, inverse_transpose_order)); - TF_RETURN_IF_ERROR(ReplaceInstruction(instr, inverse_transpose)); + std::optional y_transpose; + if (apply_transpose) { + TF_ASSIGN_OR_RETURN(y_transpose, + MakeTransposeHlo(y_reshape, y_transpose_order)); } + TF_RETURN_IF_ERROR( + ReplaceInstruction(instr, y_transpose.value_or(y_reshape))); + + // Store metadata for potential use in the backward graph. + norm_metadata_.insert( + {custom_call, NormMetadata({x_transpose.value_or(nullptr), + y_transpose.value_or(nullptr), norm_dims, + non_norm_dims})}); VLOG(1) << "Layer norm rewritten into Custom Call."; @@ -583,28 +1071,36 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { // into the layer norm Custom Call. absl::Status MatchNormFactor(HloInstruction* instr, HloInstruction* custom_call, - HloInstruction* variance, - HloInstruction* expectation, - HloInstruction* epsilon) { - HloInstruction *variance0, *epsilon0, *gte = custom_call->users()[0]; + UniqueHloInstruction& variance, + UniqueHloInstruction& expectation, + UniqueHloInstruction& epsilon) { + HloInstruction* gte = custom_call->users()[0]; if (Match(instr, - m::Divide(m::Op(), AddAnyOrder(m::Op(&variance0), - m::Broadcast(m::ConstantScalar( - &epsilon0)))))) { + m::Divide( + m::Op(), + AddAnyOrder(m::Op().WithPredicate(variance.capture_or_verify), + m::Broadcast(m::ConstantScalar().WithPredicate( + epsilon.capture_or_verify)))))) { // Verify the uniqueness of the operands. - if (variance->unique_id() != variance0->unique_id() || - epsilon->unique_id() != epsilon0->unique_id()) { + if (!variance.Instr() || !epsilon.Instr()) { VLOG(1) << "Layer norm operands not unique."; return absl::OkStatus(); } // Verify the element types. if (!CompatibleElementType(instr) || - !CompatibleElementType(expectation)) { + !CompatibleElementType(expectation.Instr())) { VLOG(1) << "Layer norm input types not compatible."; return absl::OkStatus(); } + // Retrieve metadata of the forward layer norm. + auto norm_metadata = norm_metadata_.extract(custom_call); + if (!norm_metadata) { + VLOG(1) << "Unable to retrieve norm metadata of forward Custom Call."; + return absl::OkStatus(); + } + // The shape of the expectation and norm factor return values of the // Custom Call is [nelems, 1, 1, 1], where nelems is the // number of elements in the expectation and norm factor shapes. @@ -613,7 +1109,8 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { {ShapeUtil::ElementsIn(shape), 1, 1, 1}); }; - Shape expectation_shape = make_compatible_shape(expectation->shape()); + Shape expectation_shape = + make_compatible_shape(expectation.Instr()->shape()); Shape norm_factor_shape = make_compatible_shape(instr->shape()); // The augmented Custom Call additionally returns the expectation and the @@ -627,17 +1124,21 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { HloInstruction* new_custom_call = instr->AddInstruction( custom_call->CloneWithNewShape(custom_call_shape)); + TF_ASSIGN_OR_RETURN( + GpuBackendConfig gpu_backend_config, + custom_call->backend_config()); + CudnnNormBackendConfig& backend_config = + *gpu_backend_config.mutable_cudnn_norm_backend_config(); + backend_config.set_kind(CudnnNormBackendConfig::LAYER_FWD_TRAIN); + // Update the workspace size. TF_ASSIGN_OR_RETURN(const int64_t c_constant, CConstant(cuda_compute_capability_)); const int64_t workspace_size = (2 * c_constant * (4 + 256)) + 32; - TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, - custom_call->backend_config()); - CudnnNormBackendConfig& backend_config = - *gpu_config.mutable_cudnn_norm_backend_config(); backend_config.mutable_algorithm()->mutable_workspace_size()->set_value( workspace_size); - TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_config)); + TF_RETURN_IF_ERROR( + new_custom_call->set_backend_config(gpu_backend_config)); auto replace_with_new_cc = [new_custom_call, this]( HloInstruction* old_instr, @@ -674,17 +1175,286 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { // Replace the result of the original Custom Call as well as the // expectation and the norm factor with the augmented Custom Call. TF_RETURN_IF_ERROR(replace_with_new_cc(gte, 0)); - TF_RETURN_IF_ERROR(replace_with_new_cc(expectation, 1)); + TF_RETURN_IF_ERROR(replace_with_new_cc(expectation.Instr(), 1)); TF_RETURN_IF_ERROR(replace_with_new_cc(instr, 2)); + // Update the Custom Call associated with the metadata of the forward + // norm. + norm_metadata.key() = new_custom_call; + norm_metadata_.insert(std::move(norm_metadata)); + VLOG(1) << "Expectation and norm factor fused into layer norm Custom Call."; } + + return absl::OkStatus(); + } + + // Matches and rewrites the backward graph of layer norm patterns into Custom + // Calls to cuDNN when the associated forward graph has been rewritten into a + // cuDNN Custom Call. The gradients are + // DX = F1 + F2 - AddReduce(F1 + F2) / nelems, + // Dscale = AddReduce(DY * XCenter * NormFactor), + // Dbias = AddReduce(DY), + // with + // F0 = XCenter * scale * DY, + // F1 = XCenter * F0 * DNormFactor * 2 / nelems, + // F2 = NormFactor * scale * DY, + // XCenter = X - expectation(X), + // NormFactor = (variance(X) + epsilon)^-1/2 and + // DNormFactor = -1/2 * NormFactor^3. + absl::Status MatchLayerNormGradient(HloInstruction* instr) { + UniqueHloInstruction fwd_custom_call, x, x_center, scale, dy, + fused_expectation, fused_norm_factor; + HloInstruction *broadcast, *scalar, *dscale, *dbias, *reduce0, *reduce1, + *reduce2, *reduce3; + if (Match(instr, + AddAddAnyOrder( + m::Broadcast( + &broadcast, + MultiplyAddAnyOrder( + m::Broadcast(m::ConstantScalar(&scalar)), + NegateAddReduce(&reduce0, + F1(&x, &x_center, &fused_expectation, + &fwd_custom_call, &scale, &dy, + &reduce2, norm_metadata_)), + NegateAddReduce( + &reduce1, F2(&fused_norm_factor, &scale, &dy, + &fwd_custom_call, norm_metadata_)))), + F2(&fused_norm_factor, &scale, &dy, &fwd_custom_call, + norm_metadata_), + F1(&x, &x_center, &fused_expectation, &fwd_custom_call, + &scale, &dy, &reduce3, norm_metadata_)))) { + // Skip initial convert, if present. + if (instr->user_count() == 1 && + instr->users()[0]->opcode() == HloOpcode::kConvert && + CompatibleElementType(instr->users()[0])) { + instr = instr->users()[0]; + } + + // Verify the uniqueness of the captured Custom Call and inputs. + if (!fwd_custom_call.Instr() || !x.Instr() || !dy.Instr() || + !x_center.Instr() || !scale.Instr() || !fused_expectation.Instr() || + !fused_norm_factor.Instr()) { + VLOG(1) << "Layer norm gradient inputs not unique."; + return absl::OkStatus(); + } + + // Retrieve metadata of the forward layer norm. + auto norm_metadata = norm_metadata_.find(fwd_custom_call.Instr()); + if (norm_metadata == norm_metadata_.end()) { + VLOG(1) << "Unable to retrieve norm metadata of forward Custom Call."; + return absl::OkStatus(); + } + + // Verify the dimensions of reductions in the backward graph. + if (reduce0->dimensions() != norm_metadata->second.norm_dims || + reduce1->dimensions() != norm_metadata->second.norm_dims || + reduce2->dimensions() != norm_metadata->second.norm_dims || + reduce3->dimensions() != norm_metadata->second.norm_dims) { + VLOG(1) << "Unexpected reductions dimensions in layer norm gradient."; + return absl::OkStatus(); + } + + // The captured scalar must be one over the number of elements in the + // broadcasted dimensions. + float actual_r_nelems = scalar->literal().GetAsDouble({}).value(); + int64_t nelems = 1; + for (int i = 0; i < broadcast->shape().dimensions_size(); ++i) { + if (!c_linear_search(broadcast->dimensions(), i)) { + nelems *= broadcast->shape().dimensions()[i]; + } + } + // The absolute of the difference between the actual scaling factor and + // the reference value must not exceed a prescribed threshold. + float r_nelems = 1. / static_cast(nelems); + float numerical_epsilon = std::numeric_limits::epsilon(); + if (!(abs(actual_r_nelems - r_nelems) < + ((actual_r_nelems + r_nelems) * numerical_epsilon))) { + VLOG(1) + << "Layer norm backward broadcast operand outside expected range."; + return absl::OkStatus(); + } + + // Identify Dscale = AddReduce(DY * XCenter * norm factor) with factor0 + // and factor1 intended to be XCenter and DY or DY and XCenter. + auto find_dscale = + [&fused_norm_factor, &norm_metadata]( + const UniqueHloInstruction& factor0, + const UniqueHloInstruction& factor1) -> HloInstruction* { + for (HloInstruction* factor0_user : factor0.Instr()->users()) { + std::vector users; + SkipUnaryOpsTopDownRecursive(factor0_user, users); + // One of the users of factor0 must be a chained multiplication by the + // fused norm factor and factor1. + for (HloInstruction* user : users) { + if (Match(user, + MultiplyAnyOrder( + m::Op(), MultiplyAnyOrder( + m::Broadcast(BitcastOrReshape(m::Op().Is( + fused_norm_factor.Instr()))), + m::Op().Is(factor1.Instr()))))) { + // Dscale is an addition-reduction of the product. + for (HloInstruction* multiply_user : user->users()) { + if (AppliesAddReduce(multiply_user, + norm_metadata->second.non_norm_dims)) { + return multiply_user; + } + } + } + } + } + return nullptr; + }; + if (!(dscale = find_dscale(x_center, dy)) && + !(dscale = find_dscale(dy, x_center))) { + VLOG(1) << "Unable to identify Dscale in graph."; + return absl::OkStatus(); + } + + // Find Dbias, i.e. an addition-reduction of DY, starting from DY. + // Rewriting proceeds without fusing Dbias if unsuccessful. + dbias = FindAddReduce(dy.Instr(), norm_metadata->second.non_norm_dims); + + // Verify the input and output layouts. + // TODO(philipphack): Consider supporting more general cases. + if (!LayoutUtil::IsMonotonicWithDim0Major(dy.Instr()->shape().layout()) || + !LayoutUtil::IsMonotonicWithDim0Major(instr->shape().layout()) || + !LayoutUtil::IsMonotonicWithDim0Major(dscale->shape().layout()) || + (dbias && + !LayoutUtil::IsMonotonicWithDim0Major(dbias->shape().layout()))) { + VLOG(1) << "Layer norm input and/or output layouts nor supported."; + return absl::OkStatus(); + } + + // The types of X and DX must match. + if (x.Instr()->shape().element_type() != instr->shape().element_type()) { + VLOG(1) << "The types of X and DX must match."; + return absl::OkStatus(); + } + + // The types and shapes of scale, Dscale and Dbias (if present) must + // match. + if (!ShapeUtil::Equal( + ShapeUtil::DropDegenerateDimensions(scale.Instr()->shape()), + ShapeUtil::DropDegenerateDimensions(dscale->shape())) || + (dbias && + !ShapeUtil::Equal( + ShapeUtil::DropDegenerateDimensions(scale.Instr()->shape()), + ShapeUtil::DropDegenerateDimensions(dbias->shape())))) { + VLOG(1) << "Backward layer norm types not supported."; + return absl::OkStatus(); + } + + // Verify the element types. + if (!CompatibleElementType(dy.Instr())) { + VLOG(1) << "Backward layer norm types not supported."; + return absl::OkStatus(); + } + + // cuDNN requires the byte size of the element type of X to be at least + // that of DY and scale. + if (ShapeUtil::ByteSizeOfPrimitiveType( + x.Instr()->shape().element_type()) < + ShapeUtil::ByteSizeOfPrimitiveType( + dy.Instr()->shape().element_type()) || + ShapeUtil::ByteSizeOfPrimitiveType( + x.Instr()->shape().element_type()) < + ShapeUtil::ByteSizeOfPrimitiveType( + scale.Instr()->shape().element_type())) { + VLOG(1) << "Backward layer norm types not supported."; + return absl::OkStatus(); + } + + // Transpose DY applying the stored transpose order of X from the forward + // graph. + HloInstruction* transposed_dy = dy.Instr(); + if (norm_metadata->second.x_transpose) { + TF_ASSIGN_OR_RETURN( + transposed_dy, + MakeTransposeHlo(dy.Instr(), + norm_metadata->second.x_transpose->dimensions())); + } + TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_dy, + MakeReshapeHlo(x.Instr()->shape(), transposed_dy)); + + Shape dx_shape = ShapeUtil::MakeShape(instr->shape().element_type(), + x.Instr()->shape().dimensions()); + + Shape dscale_dbias_shape = ShapeUtil::MakeShape( + dscale->shape().element_type(), scale.Instr()->shape().dimensions()); + + GpuBackendConfig gpu_backend_config; + CudnnNormBackendConfig& backend_config = + *gpu_backend_config.mutable_cudnn_norm_backend_config(); + backend_config.set_kind(CudnnNormBackendConfig::LAYER_BWD); + auto* algorithm = backend_config.mutable_algorithm(); + algorithm->set_algo_id(0); + algorithm->set_math_type(se::dnn::AlgorithmProto::TENSOR_OP_MATH); + algorithm->set_is_cudnn_frontend(true); + + // Set the workspace size to its upper bound. + // TODO(philipphack): Consider autotuning the norm kernels. + TF_ASSIGN_OR_RETURN(const int64_t c_constant, + CConstant(cuda_compute_capability_)); + const int64_t workspace_size = + (2 * c_constant * (4 + 256)) + + (2 * x.Instr()->shape().dimensions(0) * 4) + 64; + algorithm->mutable_workspace_size()->set_value(workspace_size); + + // The output of the Custom Call is a tuple. The output shape of Dscale + // and Dbias is that of scale. + Shape custom_call_shape = ShapeUtil::MakeTupleShape( + {dx_shape, dscale_dbias_shape, dscale_dbias_shape, + ShapeUtil::MakeShape(U8, {workspace_size})}); + + HloInstruction* custom_call = + instr->AddInstruction(HloInstruction::CreateCustomCall( + custom_call_shape, + {x.Instr(), scale.Instr(), reshaped_dy, fused_expectation.Instr(), + fused_norm_factor.Instr()}, + kCudnnNormCallTarget)); + TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_backend_config)); + + auto replace_with_cc = [custom_call, norm_metadata, transposed_dy, this]( + HloInstruction* old_instr, + int tuple_index) -> absl::Status { + TF_ASSIGN_OR_RETURN(HloInstruction * gte, + MakeGetTupleElementHlo(custom_call, tuple_index)); + HloInstruction* new_instr; + // Transpose DX applying the stored transpose order of Y from the + // forward graph. + if (tuple_index == 0 && norm_metadata->second.y_transpose) { + TF_ASSIGN_OR_RETURN(new_instr, + MakeReshapeHlo(transposed_dy->shape(), gte)); + TF_ASSIGN_OR_RETURN( + new_instr, + MakeTransposeHlo( + new_instr, norm_metadata->second.y_transpose->dimensions())); + } else { + TF_ASSIGN_OR_RETURN(new_instr, + MakeReshapeHlo(old_instr->shape(), gte)); + } + TF_RETURN_IF_ERROR(ReplaceInstruction(old_instr, new_instr)); + return absl::OkStatus(); + }; + + TF_RETURN_IF_ERROR(replace_with_cc(instr, 0)); + TF_RETURN_IF_ERROR(replace_with_cc(dscale, 1)); + if (dbias) { + TF_RETURN_IF_ERROR(replace_with_cc(dbias, 2)); + } + VLOG(1) << "Gradients w.r.t. x" + << (dbias ? ", scale and bias" : " and scale") + << " rewritten into layer norm backward Custom Call."; + } + return absl::OkStatus(); } private: se::CudaComputeCapability cuda_compute_capability_; + NormMetadataMap norm_metadata_; }; absl::StatusOr RunOnComputation( diff --git a/third_party/xla/xla/service/gpu/cudnn_norm_rewriter.h b/third_party/xla/xla/service/gpu/cudnn_norm_rewriter.h index 8a3981bf3c70ac..62b242513485b1 100644 --- a/third_party/xla/xla/service/gpu/cudnn_norm_rewriter.h +++ b/third_party/xla/xla/service/gpu/cudnn_norm_rewriter.h @@ -24,7 +24,7 @@ namespace xla { namespace gpu { // Rewrites norm patterns into Custom Calls to the cuDNN library. Currently, the -// forward pass of layer norm patterns is implemented. +// forward and backward passes of layer norm patterns are implemented. class CudnnNormRewriter : public HloModulePass { public: explicit CudnnNormRewriter(se::CudaComputeCapability cuda_compute_capability); diff --git a/third_party/xla/xla/service/gpu/cudnn_norm_rewriter_test.cc b/third_party/xla/xla/service/gpu/cudnn_norm_rewriter_test.cc index 4a46efffe09c9d..0df94397c096ec 100644 --- a/third_party/xla/xla/service/gpu/cudnn_norm_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/cudnn_norm_rewriter_test.cc @@ -283,9 +283,9 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D12) { ENTRY test { input = f32[2,4,6,8] parameter(0) - multiply3 = f32[2,4,6,8] multiply(input, input) + input_square = f32[2,4,6,8] multiply(input, input) c0 = f32[] constant(0) - input_square_sum = f32[2,8] reduce(multiply3, c0), dimensions={1,2}, to_apply=apply + input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply r_nelems = f32[] constant(0.041667) r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={} input_square_mean = f32[2,8] multiply(input_square_sum, r_nelems_bcast) @@ -410,9 +410,9 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrain2D1) { ENTRY test { input = f32[2,4] parameter(0) - multiply3 = f32[2,4] multiply(input, input) + input_square = f32[2,4] multiply(input, input) c0 = f32[] constant(0) - input_square_sum = f32[2] reduce(multiply3, c0), dimensions={1}, to_apply=apply + input_square_sum = f32[2] reduce(input_square, c0), dimensions={1}, to_apply=apply r_nelems = f32[] constant(0.25) r_nelems_bcast = f32[2] broadcast(r_nelems), dimensions={} input_square_mean = f32[2] multiply(input_square_sum,r_nelems_bcast) @@ -487,9 +487,9 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrain4D3) { ENTRY test { input = f32[2,4,6,8] parameter(0) - multiply3 = f32[2,4,6,8] multiply(input, input) + input_square = f32[2,4,6,8] multiply(input, input) c0 = f32[] constant(0) - input_square_sum = f32[2,4,6] reduce(multiply3, c0), dimensions={3}, to_apply=apply + input_square_sum = f32[2,4,6] reduce(input_square, c0), dimensions={3}, to_apply=apply r_nelems = f32[] constant(0.125) r_nelems_bcast = f32[2,4,6] broadcast(r_nelems), dimensions={} input_square_mean = f32[2,4,6] multiply(input_square_sum, r_nelems_bcast) @@ -564,9 +564,9 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrain4D12) { ENTRY test { input = f32[2,4,6,8] parameter(0) - multiply3 = f32[2,4,6,8] multiply(input, input) + input_square = f32[2,4,6,8] multiply(input, input) c0 = f32[] constant(0) - input_square_sum = f32[2,8] reduce(multiply3, c0), dimensions={1,2}, to_apply=apply + input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply r_nelems = f32[] constant(0.041667) r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={} input_square_mean = f32[2,8] multiply(input_square_sum, r_nelems_bcast) @@ -620,6 +620,710 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrain4D12) { TestNorm(hlo_text, optimized_hlo); } +TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward2D1) { +#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) + GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; +#endif + if (!(GetCudaComputeCapability().major == + se::CudaComputeCapability::AMPERE) && + !(GetCudaComputeCapability().major == + se::CudaComputeCapability::HOPPER)) { + GTEST_SKIP() + << "Layer norm kernels require Ampere or Hopper architectures."; + } + const char* hlo_text = R"( + HloModule test + + apply { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT c = f32[] add(a,b) + } + + ENTRY test { + input = f32[2,4] parameter(0) + input_square = f32[2,4] multiply(input, input) + c0 = f32[] constant(0) + input_square_sum = f32[2] reduce(input_square, c0), dimensions={1}, to_apply=apply + reduce = f32[2] reduce(input, c0), dimensions={1}, to_apply=apply + r_nelems = f32[] constant(0.25) + r_nelems_bcast = f32[2] broadcast(r_nelems), dimensions={} + input_square_mean = f32[2] multiply(input_square_sum,r_nelems_bcast) + input_mean = f32[2] multiply(reduce, r_nelems_bcast) + input_mean_square = f32[2] multiply(input_mean,input_mean) + variance = f32[2] subtract(input_square_mean,input_mean_square) + epsilon = f32[] constant(0.001) + epsilon_bcast = f32[2] broadcast(epsilon), dimensions={} + variance_plus_epsilon = f32[2] add(variance, epsilon_bcast) + norm_factor = f32[2] rsqrt(variance_plus_epsilon) + norm_factor_bcast = f32[2,4] broadcast(norm_factor), dimensions={0} + input_mean_bcast = f32[2,4] broadcast(input_mean), dimensions={0} + input_center = f32[2,4] subtract(input, input_mean_bcast) + norm = f32[2,4] multiply(input_center, norm_factor_bcast) + scale = f32[4] parameter(1) + scale_bcast = f32[2,4] broadcast(scale), dimensions={1} + norm_scale = f32[2,4] multiply(norm, scale_bcast) + bias = f32[4] parameter(2) + bias_bcast = f32[2,4] broadcast(bias), dimensions={1} + norm_scale_bias = f32[2,4] add(norm_scale, bias_bcast) + doutput = f32[2,4] parameter(3) + dbias = f32[4] reduce(doutput, c0), dimensions={0}, to_apply=apply + norm_doutput = f32[2,4] multiply(norm, doutput) + dscale = f32[4] reduce(norm_doutput, c0), dimensions={0}, to_apply=apply + scale_doutput = f32[2,4] multiply(scale_bcast, doutput) + input_center_scale_doutput = f32[2,4] multiply(input_center, scale_doutput) + f0 = f32[2] reduce(input_center_scale_doutput, c0), dimensions={1}, to_apply=apply + norm_factor_cube = f32[2] divide(norm_factor, variance_plus_epsilon) + c1 = f32[] constant(-0.5) + c1_bcast = f32[2] broadcast(c1), dimensions={} + dnorm_factor = f32[2] multiply(norm_factor_cube, c1_bcast) + f0_dnorm_factor = f32[2] multiply(f0, dnorm_factor) + c2 = f32[] constant(0.5) + c2_bcast = f32[2] broadcast(c2), dimensions={} + f0_dnorm_factor_scaled = f32[2] multiply(f0_dnorm_factor, c2_bcast) + f0_dnorm_factor_scaled_bcast = f32[2,4] broadcast(f0_dnorm_factor_scaled), dimensions={0} + f1 = f32[2,4] multiply(input_center, f0_dnorm_factor_scaled_bcast) + minus_f1 = f32[2,4] negate(f1) + minus_f1_sum = f32[2] reduce(minus_f1, c0), dimensions={1}, to_apply=apply + f2 = f32[2,4] multiply(norm_factor_bcast, scale_doutput) + minus_f2 = f32[2,4] negate(f2) + minus_f2_sum = f32[2] reduce(minus_f2, c0), dimensions={1}, to_apply=apply + minus_f1_f2_sum = f32[2] add(minus_f1_sum, minus_f2_sum) + minus_f1_f2_sum_scaled = f32[2] multiply(minus_f1_f2_sum, r_nelems_bcast) + minus_f1_f2_sum_scaled_bcast = f32[2,4] broadcast(minus_f1_f2_sum_scaled), dimensions={0} + f1_f2 = f32[2,4] add(f1, f2) + dinput = f32[2,4] add(f1_f2, minus_f1_f2_sum_scaled_bcast) + ROOT out = (f32[2,4], f32[2,4], f32[4], f32[4]) tuple(norm_scale_bias, dinput, dscale, dbias) + })"; + + const char* optimized_hlo = R"( + +; CHECK-LABEL: ENTRY %test (input: f32[2,4], scale: f32[4], bias: f32[4], doutput: f32[2,4]) -> (f32[2,4], f32[2,4], f32[4], f32[4]) { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4]{1,0} parameter(0) +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} bitcast([[P0]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4]{0} parameter(1) +; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} bitcast([[P1]]) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2) +; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} bitcast([[P2]]) +; CHECK-NEXT: [[CC0:%[^ ]+]] = (f32[2,4,1,1]{3,2,1,0}, f32[2,1,1,1]{3,2,1,0}, f32[2,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0.001 +; CHECK-DAG: "kind":"LAYER_FWD_TRAIN" +; CHECK: } +; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0 +; CHECK-DAG: [[GTE0_BITCAST:%[^ ]+]] = f32[2,4]{1,0} bitcast([[GTE0]]) +; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4]{1,0} parameter(3) +; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} bitcast([[P3]]) +; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[2,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1 +; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[2,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2 +; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[2,4,1,1]{3,2,1,0}, f32[4,1,1,1]{3,2,1,0}, f32[4,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0 +; CHECK-DAG: "kind":"LAYER_BWD" +; CHECK: } +; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0 +; CHECK-DAG: [[GTE3_BITCAST:%[^ ]+]] = f32[2,4]{1,0} bitcast([[GTE3]]) +; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1 +; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE4]]) +; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2 +; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE5]]) +; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, f32[2,4]{1,0}, f32[4]{0}, f32[4]{0}) tuple([[GTE0_BITCAST]], [[GTE3_BITCAST]], [[GTE4_BITCAST]], [[GTE5_BITCAST]]) + )"; + + TestNorm(hlo_text, optimized_hlo); +} + +TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D3) { +#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) + GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; +#endif + if (!(GetCudaComputeCapability().major == + se::CudaComputeCapability::AMPERE) && + !(GetCudaComputeCapability().major == + se::CudaComputeCapability::HOPPER)) { + GTEST_SKIP() + << "Layer norm kernels require Ampere or Hopper architectures."; + } + const char* hlo_text = R"( + HloModule test + + apply { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT c = f32[] add(a,b) + } + + ENTRY test { + input = f32[2,4,6,8] parameter(0) + input_square = f32[2,4,6,8] multiply(input, input) + c0 = f32[] constant(0) + input_square_sum = f32[2,4,6] reduce(input_square, c0), dimensions={3}, to_apply=apply + reduce = f32[2,4,6] reduce(input, c0), dimensions={3}, to_apply=apply + r_nelems = f32[] constant(0.125) + r_nelems_bcast = f32[2,4,6] broadcast(r_nelems), dimensions={} + input_square_mean = f32[2,4,6] multiply(input_square_sum,r_nelems_bcast) + input_mean = f32[2,4,6] multiply(reduce, r_nelems_bcast) + input_mean_square = f32[2,4,6] multiply(input_mean,input_mean) + variance = f32[2,4,6] subtract(input_square_mean,input_mean_square) + epsilon = f32[] constant(0.001) + epsilon_bcast = f32[2,4,6] broadcast(epsilon), dimensions={} + variance_plus_epsilon = f32[2,4,6] add(variance, epsilon_bcast) + norm_factor = f32[2,4,6] rsqrt(variance_plus_epsilon) + norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,2} + input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,2} + input_center = f32[2,4,6,8] subtract(input, input_mean_bcast) + norm = f32[2,4,6,8] multiply(input_center, norm_factor_bcast) + scale = f32[8] parameter(1) + scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={3} + norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast) + bias = f32[8] parameter(2) + bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={3} + norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast) + doutput = f32[2,4,6,8] parameter(3) + dbias = f32[8] reduce(doutput, c0), dimensions={0,1,2}, to_apply=apply + norm_doutput = f32[2,4,6,8] multiply(norm, doutput) + dscale = f32[8] reduce(norm_doutput, c0), dimensions={0,1,2}, to_apply=apply + scale_doutput = f32[2,4,6,8] multiply(scale_bcast, doutput) + input_center_scale_doutput = f32[2,4,6,8] multiply(input_center, scale_doutput) + f0 = f32[2,4,6] reduce(input_center_scale_doutput, c0), dimensions={3}, to_apply=apply + norm_factor_cube = f32[2,4,6] divide(norm_factor, variance_plus_epsilon) + c1 = f32[] constant(-0.5) + c1_bcast = f32[2,4,6] broadcast(c1), dimensions={} + dnorm_factor = f32[2,4,6] multiply(norm_factor_cube, c1_bcast) + f0_dnorm_factor = f32[2,4,6] multiply(f0, dnorm_factor) + c2 = f32[] constant(0.25) + c2_bcast = f32[2,4,6] broadcast(c2), dimensions={} + f0_dnorm_factor_scaled = f32[2,4,6] multiply(f0_dnorm_factor, c2_bcast) + f0_dnorm_factor_scaled_bcast = f32[2,4,6,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,1,2} + f1 = f32[2,4,6,8] multiply(input_center, f0_dnorm_factor_scaled_bcast) + minus_f1 = f32[2,4,6,8] negate(f1) + minus_f1_sum = f32[2,4,6] reduce(minus_f1, c0), dimensions={3}, to_apply=apply + f2 = f32[2,4,6,8] multiply(norm_factor_bcast, scale_doutput) + minus_f2 = f32[2,4,6,8] negate(f2) + minus_f2_sum = f32[2,4,6] reduce(minus_f2, c0), dimensions={3}, to_apply=apply + minus_f1_f2_sum = f32[2,4,6] add(minus_f1_sum, minus_f2_sum) + minus_f1_f2_sum_scaled = f32[2,4,6] multiply(minus_f1_f2_sum, r_nelems_bcast) + minus_f1_f2_sum_scaled_bcast = f32[2,4,6,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,1,2} + f1_f2 = f32[2,4,6,8] add(f1, f2) + dinput = f32[2,4,6,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast) + ROOT out = (f32[2,4,6,8], f32[2,4,6,8], f32[8], f32[8]) tuple(norm_scale_bias, dinput, dscale, dbias) + })"; + + const char* optimized_hlo = R"( + +; CHECK-LABEL: ENTRY %test (input: f32[2,4,6,8], scale: f32[8], bias: f32[8], doutput: f32[2,4,6,8]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[8], f32[8]) { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0) +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} bitcast([[P0]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[8]{0} parameter(1) +; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[8,1,1,1]{3,2,1,0} bitcast([[P1]]) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[8]{0} parameter(2) +; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[8,1,1,1]{3,2,1,0} bitcast([[P2]]) +; CHECK-NEXT: [[CC0:%[^ ]+]] = (f32[48,8,1,1]{3,2,1,0}, f32[48,1,1,1]{3,2,1,0}, f32[48,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0.001 +; CHECK-DAG: "kind":"LAYER_FWD_TRAIN" +; CHECK: } +; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0 +; CHECK-DAG: [[GTE0_BITCAST:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} bitcast([[GTE0]]) +; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(3) +; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} bitcast([[P3]]) +; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[48,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1 +; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[48,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2 +; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[48,8,1,1]{3,2,1,0}, f32[8,1,1,1]{3,2,1,0}, f32[8,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0 +; CHECK-DAG: "kind":"LAYER_BWD" +; CHECK: } +; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0 +; CHECK-DAG: [[GTE3_BITCAST:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} bitcast([[GTE3]]) +; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[8,1,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1 +; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[8]{0} bitcast([[GTE4]]) +; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[8,1,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2 +; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[8]{0} bitcast([[GTE5]]) +; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[8]{0}, f32[8]{0}) tuple([[GTE0_BITCAST]], [[GTE3_BITCAST]], [[GTE4_BITCAST]], [[GTE5_BITCAST]]) + )"; + + TestNorm(hlo_text, optimized_hlo); +} + +TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D2) { +#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) + GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; +#endif + if (!(GetCudaComputeCapability().major == + se::CudaComputeCapability::AMPERE) && + !(GetCudaComputeCapability().major == + se::CudaComputeCapability::HOPPER)) { + GTEST_SKIP() + << "Layer norm kernels require Ampere or Hopper architectures."; + } + const char* hlo_text = R"( + HloModule test + + apply { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT c = f32[] add(a,b) + } + + ENTRY test { + input = f32[2,4,6,8] parameter(0) + input_square = f32[2,4,6,8] multiply(input, input) + c0 = f32[] constant(0) + input_square_sum = f32[2,4,8] reduce(input_square, c0), dimensions={2}, to_apply=apply + reduce = f32[2,4,8] reduce(input, c0), dimensions={2}, to_apply=apply + r_nelems = f32[] constant(0.166667) + r_nelems_bcast = f32[2,4,8] broadcast(r_nelems), dimensions={} + input_square_mean = f32[2,4,8] multiply(input_square_sum,r_nelems_bcast) + input_mean = f32[2,4,8] multiply(reduce, r_nelems_bcast) + input_mean_square = f32[2,4,8] multiply(input_mean,input_mean) + variance = f32[2,4,8] subtract(input_square_mean,input_mean_square) + epsilon = f32[] constant(0.001) + epsilon_bcast = f32[2,4,8] broadcast(epsilon), dimensions={} + variance_plus_epsilon = f32[2,4,8] add(variance, epsilon_bcast) + norm_factor = f32[2,4,8] rsqrt(variance_plus_epsilon) + norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,3} + input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,3} + input_center = f32[2,4,6,8] subtract(input, input_mean_bcast) + norm = f32[2,4,6,8] multiply(input_center, norm_factor_bcast) + scale = f32[6] parameter(1) + scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={2} + norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast) + bias = f32[6] parameter(2) + bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={2} + norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast) + doutput = f32[2,4,6,8] parameter(3) + dbias = f32[6] reduce(doutput, c0), dimensions={0,1,3}, to_apply=apply + norm_doutput = f32[2,4,6,8] multiply(norm, doutput) + dscale = f32[6] reduce(norm_doutput, c0), dimensions={0,1,3}, to_apply=apply + scale_doutput = f32[2,4,6,8] multiply(scale_bcast, doutput) + input_center_scale_doutput = f32[2,4,6,8] multiply(input_center, scale_doutput) + f0 = f32[2,4,8] reduce(input_center_scale_doutput, c0), dimensions={2}, to_apply=apply + norm_factor_cube = f32[2,4,8] divide(norm_factor, variance_plus_epsilon) + c1 = f32[] constant(-0.5) + c1_bcast = f32[2,4,8] broadcast(c1), dimensions={} + dnorm_factor = f32[2,4,8] multiply(norm_factor_cube, c1_bcast) + f0_dnorm_factor = f32[2,4,8] multiply(f0, dnorm_factor) + c2 = f32[] constant(0.333333) + c2_bcast = f32[2,4,8] broadcast(c2), dimensions={} + f0_dnorm_factor_scaled = f32[2,4,8] multiply(f0_dnorm_factor, c2_bcast) + f0_dnorm_factor_scaled_bcast = f32[2,4,6,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,1,3} + f1 = f32[2,4,6,8] multiply(input_center, f0_dnorm_factor_scaled_bcast) + minus_f1 = f32[2,4,6,8] negate(f1) + minus_f1_sum = f32[2,4,8] reduce(minus_f1, c0), dimensions={2}, to_apply=apply + f2 = f32[2,4,6,8] multiply(norm_factor_bcast, scale_doutput) + minus_f2 = f32[2,4,6,8] negate(f2) + minus_f2_sum = f32[2,4,8] reduce(minus_f2, c0), dimensions={2}, to_apply=apply + minus_f1_f2_sum = f32[2,4,8] add(minus_f1_sum, minus_f2_sum) + minus_f1_f2_sum_scaled = f32[2,4,8] multiply(minus_f1_f2_sum, r_nelems_bcast) + minus_f1_f2_sum_scaled_bcast = f32[2,4,6,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,1,3} + f1_f2 = f32[2,4,6,8] add(f1, f2) + dinput = f32[2,4,6,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast) + ROOT out = (f32[2,4,6,8], f32[2,4,6,8], f32[6], f32[6]) tuple(norm_scale_bias, dinput, dscale, dbias) + })"; + + const char* optimized_hlo = R"( + +; CHECK-LABEL: ENTRY %test (input: f32[2,4,6,8], scale: f32[6], bias: f32[6], doutput: f32[2,4,6,8]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[6], f32[6]) { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0) +; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[2,4,8,6]{3,2,1,0} transpose([[P0]]), dimensions={0,1,3,2} +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE0]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[6]{0} parameter(1) +; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[6,1,1,1]{3,2,1,0} bitcast([[P1]]) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[6]{0} parameter(2) +; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[6,1,1,1]{3,2,1,0} bitcast([[P2]]) +; CHECK-NEXT: [[CC0:%[^ ]+]] = (f32[64,6,1,1]{3,2,1,0}, f32[64,1,1,1]{3,2,1,0}, f32[64,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0.001 +; CHECK-DAG: "kind":"LAYER_FWD_TRAIN" +; CHECK: } +; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0 +; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(3) +; CHECK-NEXT: [[TRANSPOSE1:%[^ ]+]] = f32[2,4,8,6]{3,2,1,0} transpose([[P3]]), dimensions={0,1,3,2} +; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE1]]) +; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[64,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1 +; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[64,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2 +; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[64,6,1,1]{3,2,1,0}, f32[6,1,1,1]{3,2,1,0}, f32[6,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0 +; CHECK-DAG: "kind":"LAYER_BWD" +; CHECK: } +; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0 +; CHECK-DAG: [[FUSION:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]] +; CHECK-DAG: [[GTEF0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=0 +; CHECK-DAG: [[GTEF1:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=1 +; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[6,1,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1 +; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[6]{0} bitcast([[GTE4]]) +; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[6,1,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2 +; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[6]{0} bitcast([[GTE5]]) +; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[6]{0}, f32[6]{0}) tuple([[GTEF0]], [[GTEF1]], [[GTE4_BITCAST]], [[GTE5_BITCAST]]) + )"; + + TestNorm(hlo_text, optimized_hlo); +} + +TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D12) { +#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) + GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; +#endif + if (!(GetCudaComputeCapability().major == + se::CudaComputeCapability::AMPERE) && + !(GetCudaComputeCapability().major == + se::CudaComputeCapability::HOPPER)) { + GTEST_SKIP() + << "Layer norm kernels require Ampere or Hopper architectures."; + } + const char* hlo_text = R"( + HloModule test + + apply { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT c = f32[] add(a,b) + } + + ENTRY test { + input = f32[2,4,6,8] parameter(0) + input_square = f32[2,4,6,8] multiply(input, input) + c0 = f32[] constant(0) + input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply + reduce = f32[2,8] reduce(input, c0), dimensions={1,2}, to_apply=apply + r_nelems = f32[] constant(0.041667) + r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={} + input_square_mean = f32[2,8] multiply(input_square_sum,r_nelems_bcast) + input_mean = f32[2,8] multiply(reduce, r_nelems_bcast) + input_mean_square = f32[2,8] multiply(input_mean,input_mean) + variance = f32[2,8] subtract(input_square_mean,input_mean_square) + epsilon = f32[] constant(0.001) + epsilon_bcast = f32[2,8] broadcast(epsilon), dimensions={} + variance_plus_epsilon = f32[2,8] add(variance, epsilon_bcast) + norm_factor = f32[2,8] rsqrt(variance_plus_epsilon) + norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,3} + input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,3} + input_center = f32[2,4,6,8] subtract(input, input_mean_bcast) + norm = f32[2,4,6,8] multiply(input_center, norm_factor_bcast) + scale = f32[4,6] parameter(1) + scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={1,2} + norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast) + bias = f32[4,6] parameter(2) + bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={1,2} + norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast) + doutput = f32[2,4,6,8] parameter(3) + dbias = f32[4,6] reduce(doutput, c0), dimensions={0,3}, to_apply=apply + norm_doutput = f32[2,4,6,8] multiply(norm, doutput) + dscale = f32[4,6] reduce(norm_doutput, c0), dimensions={0,3}, to_apply=apply + scale_doutput = f32[2,4,6,8] multiply(scale_bcast, doutput) + input_center_scale_doutput = f32[2,4,6,8] multiply(input_center, scale_doutput) + f0 = f32[2,8] reduce(input_center_scale_doutput, c0), dimensions={1,2}, to_apply=apply + norm_factor_cube = f32[2,8] divide(norm_factor, variance_plus_epsilon) + c1 = f32[] constant(-0.5) + c1_bcast = f32[2,8] broadcast(c1), dimensions={} + dnorm_factor = f32[2,8] multiply(norm_factor_cube, c1_bcast) + f0_dnorm_factor = f32[2,8] multiply(f0, dnorm_factor) + c2 = f32[] constant(0.083333) + c2_bcast = f32[2,8] broadcast(c2), dimensions={} + f0_dnorm_factor_scaled = f32[2,8] multiply(f0_dnorm_factor, c2_bcast) + f0_dnorm_factor_scaled_bcast = f32[2,4,6,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,3} + f1 = f32[2,4,6,8] multiply(input_center, f0_dnorm_factor_scaled_bcast) + minus_f1 = f32[2,4,6,8] negate(f1) + minus_f1_sum = f32[2,8] reduce(minus_f1, c0), dimensions={1,2}, to_apply=apply + f2 = f32[2,4,6,8] multiply(norm_factor_bcast, scale_doutput) + minus_f2 = f32[2,4,6,8] negate(f2) + minus_f2_sum = f32[2,8] reduce(minus_f2, c0), dimensions={1,2}, to_apply=apply + minus_f1_f2_sum = f32[2,8] add(minus_f1_sum, minus_f2_sum) + minus_f1_f2_sum_scaled = f32[2,8] multiply(minus_f1_f2_sum, r_nelems_bcast) + minus_f1_f2_sum_scaled_bcast = f32[2,4,6,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,3} + f1_f2 = f32[2,4,6,8] add(f1, f2) + dinput = f32[2,4,6,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast) + ROOT out = (f32[2,4,6,8], f32[2,4,6,8], f32[4,6], f32[4,6]) tuple(norm_scale_bias, dinput, dscale, dbias) + })"; + + const char* optimized_hlo = R"( + +; CHECK-LABEL: ENTRY %test (input: f32[2,4,6,8], scale: f32[4,6], bias: f32[4,6], doutput: f32[2,4,6,8]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[4,6], f32[4,6]) { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0) +; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[2,8,4,6]{3,2,1,0} transpose([[P0]]), dimensions={0,3,1,2} +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE0]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,6]{1,0} parameter(1) +; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[4,6,1,1]{3,2,1,0} bitcast([[P1]]) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4,6]{1,0} parameter(2) +; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[4,6,1,1]{3,2,1,0} bitcast([[P2]]) +; CHECK-NEXT: [[CC0:%[^ ]+]] = (f32[16,4,6,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0.001 +; CHECK-DAG: "kind":"LAYER_FWD_TRAIN" +; CHECK: } +; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0 +; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(3) +; CHECK-NEXT: [[TRANSPOSE1:%[^ ]+]] = f32[2,8,4,6]{3,2,1,0} transpose([[P3]]), dimensions={0,3,1,2} +; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE1]]) +; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1 +; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2 +; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[16,4,6,1]{3,2,1,0}, f32[4,6,1,1]{3,2,1,0}, f32[4,6,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0 +; CHECK-DAG: "kind":"LAYER_BWD" +; CHECK: } +; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0 +; CHECK-DAG: [[FUSION:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]] +; CHECK-DAG: [[GTEF0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=0 +; CHECK-DAG: [[GTEF1:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=1 +; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[4,6,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1 +; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[4,6]{1,0} bitcast([[GTE4]]) +; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[4,6,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2 +; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[4,6]{1,0} bitcast([[GTE5]]) +; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[4,6]{1,0}, f32[4,6]{1,0}) tuple([[GTEF0]], [[GTEF1]], [[GTE4_BITCAST]], [[GTE5_BITCAST]]) + )"; + + TestNorm(hlo_text, optimized_hlo); +} + +TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D1DoutputReshapeSplit) { +#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) + GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; +#endif + if (!(GetCudaComputeCapability().major == + se::CudaComputeCapability::AMPERE) && + !(GetCudaComputeCapability().major == + se::CudaComputeCapability::HOPPER)) { + GTEST_SKIP() + << "Layer norm kernels require Ampere or Hopper architectures."; + } + const char* hlo_text = R"( + HloModule test + + apply { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT c = f32[] add(a,b) + } + + ENTRY test { + input = f32[2,4,6,8] parameter(0) + input_square = f32[2,4,6,8] multiply(input, input) + c0 = f32[] constant(0) + input_square_sum = f32[2,6,8] reduce(input_square, c0), dimensions={1}, to_apply=apply + reduce = f32[2,6,8] reduce(input, c0), dimensions={1}, to_apply=apply + r_nelems = f32[] constant(0.25) + r_nelems_bcast = f32[2,6,8] broadcast(r_nelems), dimensions={} + input_square_mean = f32[2,6,8] multiply(input_square_sum,r_nelems_bcast) + input_mean = f32[2,6,8] multiply(reduce, r_nelems_bcast) + input_mean_square = f32[2,6,8] multiply(input_mean,input_mean) + variance = f32[2,6,8] subtract(input_square_mean,input_mean_square) + epsilon = f32[] constant(0.001) + epsilon_bcast = f32[2,6,8] broadcast(epsilon), dimensions={} + variance_plus_epsilon = f32[2,6,8] add(variance, epsilon_bcast) + norm_factor = f32[2,6,8] rsqrt(variance_plus_epsilon) + norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,2,3} + input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,2,3} + input_center = f32[2,4,6,8] subtract(input, input_mean_bcast) + norm = f32[2,4,6,8] multiply(input_center, norm_factor_bcast) + scale = f32[4] parameter(1) + scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={1} + norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast) + bias = f32[4] parameter(2) + bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={1} + norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast) + doutput = f32[2,4,48] parameter(3) + dbias = f32[4] reduce(doutput, c0), dimensions={0,2}, to_apply=apply + doutput_bitcast = f32[2,4,6,8] reshape(doutput) + norm_doutput = f32[2,4,6,8] multiply(norm, doutput_bitcast) + dscale = f32[4] reduce(norm_doutput, c0), dimensions={0,2,3}, to_apply=apply + scale_doutput = f32[2,4,6,8] multiply(scale_bcast, doutput_bitcast) + input_center_scale_doutput = f32[2,4,6,8] multiply(input_center, scale_doutput) + f0 = f32[2,6,8] reduce(input_center_scale_doutput, c0), dimensions={1}, to_apply=apply + norm_factor_cube = f32[2,6,8] divide(norm_factor, variance_plus_epsilon) + c1 = f32[] constant(-0.5) + c1_bcast = f32[2,6,8] broadcast(c1), dimensions={} + dnorm_factor = f32[2,6,8] multiply(norm_factor_cube, c1_bcast) + f0_dnorm_factor = f32[2,6,8] multiply(f0, dnorm_factor) + c2 = f32[] constant(0.5) + c2_bcast = f32[2,6,8] broadcast(c2), dimensions={} + f0_dnorm_factor_scaled = f32[2,6,8] multiply(f0_dnorm_factor, c2_bcast) + f0_dnorm_factor_scaled_bcast = f32[2,4,6,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,2,3} + f1 = f32[2,4,6,8] multiply(input_center, f0_dnorm_factor_scaled_bcast) + minus_f1 = f32[2,4,6,8] negate(f1) + minus_f1_sum = f32[2,6,8] reduce(minus_f1, c0), dimensions={1}, to_apply=apply + f2 = f32[2,4,6,8] multiply(norm_factor_bcast, scale_doutput) + minus_f2 = f32[2,4,6,8] negate(f2) + minus_f2_sum = f32[2,6,8] reduce(minus_f2, c0), dimensions={1}, to_apply=apply + minus_f1_f2_sum = f32[2,6,8] add(minus_f1_sum, minus_f2_sum) + minus_f1_f2_sum_scaled = f32[2,6,8] multiply(minus_f1_f2_sum, r_nelems_bcast) + minus_f1_f2_sum_scaled_bcast = f32[2,4,6,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,2,3} + f1_f2 = f32[2,4,6,8] add(f1, f2) + dinput = f32[2,4,6,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast) + ROOT out = (f32[2,4,6,8], f32[2,4,6,8], f32[4], f32[4]) tuple(norm_scale_bias, dinput, dscale, dbias) + })"; + + const char* optimized_hlo = R"( + +; CHECK-LABEL: ENTRY %test (input: f32[2,4,6,8], scale: f32[4], bias: f32[4], doutput: f32[2,4,48]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[4], f32[4]) { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0) +; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[2,6,8,4]{3,2,1,0} transpose([[P0]]), dimensions={0,2,3,1} +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE0]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4]{0} parameter(1) +; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} bitcast([[P1]]) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2) +; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} bitcast([[P2]]) +; CHECK-NEXT: [[CC0:%[^ ]+]] = (f32[96,4,1,1]{3,2,1,0}, f32[96,1,1,1]{3,2,1,0}, f32[96,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0.001 +; CHECK-DAG: "kind":"LAYER_FWD_TRAIN" +; CHECK: } +; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0 +; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4,48]{2,1,0} parameter(3) +; CHECK-DAG: [[FUSION0:%[^ ]+]] = f32[2,6,8,4]{3,2,1,0} fusion([[P3]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]] +; CHECK-DAG: [[FUSION0_BITCAST:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} bitcast([[FUSION0]]) +; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[96,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1 +; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[96,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2 +; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[96,4,1,1]{3,2,1,0}, f32[4,1,1,1]{3,2,1,0}, f32[4,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[FUSION0_BITCAST]], [[GTE1]], [[GTE2]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0 +; CHECK-DAG: "kind":"LAYER_BWD" +; CHECK: } +; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0 +; CHECK-DAG: [[FUSION1:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION1:%[^ ]+]] +; CHECK-DAG: [[GTEF1:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION1]]), index=0 +; CHECK-DAG: [[GTEF2:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION1]]), index=1 +; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1 +; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE4]]) +; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2 +; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE5]]) +; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[4]{0}, f32[4]{0}) tuple([[GTEF1]], [[GTEF2]], [[GTE4_BITCAST]], [[GTE5_BITCAST]]) + )"; + + TestNorm(hlo_text, optimized_hlo); +} + +TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D1DoutputReshapeCombine) { +#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) + GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; +#endif + if (!(GetCudaComputeCapability().major == + se::CudaComputeCapability::AMPERE) && + !(GetCudaComputeCapability().major == + se::CudaComputeCapability::HOPPER)) { + GTEST_SKIP() + << "Layer norm kernels require Ampere or Hopper architectures."; + } + const char* hlo_text = R"( + HloModule test + + apply { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT c = f32[] add(a,b) + } + + ENTRY test { + input = f32[2,4,6,8] parameter(0) + input_square = f32[2,4,6,8] multiply(input, input) + c0 = f32[] constant(0) + input_square_sum = f32[2,6,8] reduce(input_square, c0), dimensions={1}, to_apply=apply + reduce = f32[2,6,8] reduce(input, c0), dimensions={1}, to_apply=apply + r_nelems = f32[] constant(0.25) + r_nelems_bcast = f32[2,6,8] broadcast(r_nelems), dimensions={} + input_square_mean = f32[2,6,8] multiply(input_square_sum,r_nelems_bcast) + input_mean = f32[2,6,8] multiply(reduce, r_nelems_bcast) + input_mean_square = f32[2,6,8] multiply(input_mean,input_mean) + variance = f32[2,6,8] subtract(input_square_mean,input_mean_square) + epsilon = f32[] constant(0.001) + epsilon_bcast = f32[2,6,8] broadcast(epsilon), dimensions={} + variance_plus_epsilon = f32[2,6,8] add(variance, epsilon_bcast) + norm_factor = f32[2,6,8] rsqrt(variance_plus_epsilon) + norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,2,3} + input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,2,3} + input_center = f32[2,4,6,8] subtract(input, input_mean_bcast) + norm = f32[2,4,6,8] multiply(input_center, norm_factor_bcast) + scale = f32[4] parameter(1) + scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={1} + norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast) + bias = f32[4] parameter(2) + bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={1} + norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast) + doutput = f32[2,4,6,2,2,2] parameter(3) + dbias = f32[4] reduce(doutput, c0), dimensions={0,2,3,4,5}, to_apply=apply + doutput_bitcast = f32[2,4,6,8] reshape(doutput) + norm_doutput = f32[2,4,6,8] multiply(norm, doutput_bitcast) + dscale = f32[4] reduce(norm_doutput, c0), dimensions={0,2,3}, to_apply=apply + scale_doutput = f32[2,4,6,8] multiply(scale_bcast, doutput_bitcast) + input_center_scale_doutput = f32[2,4,6,8] multiply(input_center, scale_doutput) + f0 = f32[2,6,8] reduce(input_center_scale_doutput, c0), dimensions={1}, to_apply=apply + norm_factor_cube = f32[2,6,8] divide(norm_factor, variance_plus_epsilon) + c1 = f32[] constant(-0.5) + c1_bcast = f32[2,6,8] broadcast(c1), dimensions={} + dnorm_factor = f32[2,6,8] multiply(norm_factor_cube, c1_bcast) + f0_dnorm_factor = f32[2,6,8] multiply(f0, dnorm_factor) + c2 = f32[] constant(0.5) + c2_bcast = f32[2,6,8] broadcast(c2), dimensions={} + f0_dnorm_factor_scaled = f32[2,6,8] multiply(f0_dnorm_factor, c2_bcast) + f0_dnorm_factor_scaled_bcast = f32[2,4,6,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,2,3} + f1 = f32[2,4,6,8] multiply(input_center, f0_dnorm_factor_scaled_bcast) + minus_f1 = f32[2,4,6,8] negate(f1) + minus_f1_sum = f32[2,6,8] reduce(minus_f1, c0), dimensions={1}, to_apply=apply + f2 = f32[2,4,6,8] multiply(norm_factor_bcast, scale_doutput) + minus_f2 = f32[2,4,6,8] negate(f2) + minus_f2_sum = f32[2,6,8] reduce(minus_f2, c0), dimensions={1}, to_apply=apply + minus_f1_f2_sum = f32[2,6,8] add(minus_f1_sum, minus_f2_sum) + minus_f1_f2_sum_scaled = f32[2,6,8] multiply(minus_f1_f2_sum, r_nelems_bcast) + minus_f1_f2_sum_scaled_bcast = f32[2,4,6,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,2,3} + f1_f2 = f32[2,4,6,8] add(f1, f2) + dinput = f32[2,4,6,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast) + ROOT out = (f32[2,4,6,8], f32[2,4,6,8], f32[4], f32[4]) tuple(norm_scale_bias, dinput, dscale, dbias) + })"; + + const char* optimized_hlo = R"( + +; CHECK-LABEL: ENTRY %test (input: f32[2,4,6,8], scale: f32[4], bias: f32[4], doutput: f32[2,4,6,2,2,2]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[4], f32[4]) { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0) +; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[2,6,8,4]{3,2,1,0} transpose([[P0]]), dimensions={0,2,3,1} +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE0]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4]{0} parameter(1) +; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} bitcast([[P1]]) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2) +; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} bitcast([[P2]]) +; CHECK-NEXT: [[CC0:%[^ ]+]] = (f32[96,4,1,1]{3,2,1,0}, f32[96,1,1,1]{3,2,1,0}, f32[96,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0.001 +; CHECK-DAG: "kind":"LAYER_FWD_TRAIN" +; CHECK: } +; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0 +; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4,6,2,2,2]{5,4,3,2,1,0} parameter(3) +; CHECK-DAG: [[FUSION0:%[^ ]+]] = f32[2,6,8,4]{3,2,1,0} fusion([[P3]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]] +; CHECK-DAG: [[FUSION0_BITCAST:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} bitcast([[FUSION0]]) +; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[96,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1 +; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[96,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2 +; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[96,4,1,1]{3,2,1,0}, f32[4,1,1,1]{3,2,1,0}, f32[4,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[FUSION0_BITCAST]], [[GTE1]], [[GTE2]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0 +; CHECK-DAG: "kind":"LAYER_BWD" +; CHECK: } +; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0 +; CHECK-DAG: [[FUSION1:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION1:%[^ ]+]] +; CHECK-DAG: [[GTEF1:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION1]]), index=0 +; CHECK-DAG: [[GTEF2:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION1]]), index=1 +; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1 +; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE4]]) +; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2 +; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE5]]) +; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[4]{0}, f32[4]{0}) tuple([[GTEF1]], [[GTEF2]], [[GTE4_BITCAST]], [[GTE5_BITCAST]]) + )"; + + TestNorm(hlo_text, optimized_hlo); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/cusolver_context.cc b/third_party/xla/xla/service/gpu/cusolver_context.cc index d63f7a7cc18a52..4343fec8b19870 100644 --- a/third_party/xla/xla/service/gpu/cusolver_context.cc +++ b/third_party/xla/xla/service/gpu/cusolver_context.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include "absl/status/status.h" @@ -25,6 +26,7 @@ limitations under the License. #include "third_party/gpus/cuda/include/cusolverDn.h" #include "third_party/gpus/cuda/include/cusolver_common.h" #include "third_party/gpus/cuda/include/driver_types.h" +#include "third_party/gpus/cuda/include/library_types.h" #endif #include "xla/primitive_util.h" #include "xla/status.h" @@ -285,6 +287,8 @@ absl::Status ConvertStatus(rocblas_status status) { GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, Cpotrf_bufferSize) #define GpuSolverZpotrf_bufferSize \ GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, Zpotrf_bufferSize) +#define GpuSolverDnXpotrf_bufferSize \ + GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, Xpotrf_bufferSize) #if TENSORFLOW_USE_CUSOLVER_OR_HIPSOLVER #define GpuSolverSpotrf GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, Spotrf) #define GpuSolverDpotrf GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, Dpotrf) @@ -294,6 +298,7 @@ absl::Status ConvertStatus(rocblas_status status) { #define GpuSolverDpotrfBatched GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, DpotrfBatched) #define GpuSolverCpotrfBatched GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, CpotrfBatched) #define GpuSolverZpotrfBatched GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, ZpotrfBatched) +#define GpuSolverXpotrf GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, Xpotrf) #else // TENSORFLOW_USE_ROCSOLVER #define GpuSolverSpotrf GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, spotrf) #define GpuSolverDpotrf GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, dpotrf) @@ -338,35 +343,36 @@ absl::StatusOr GpuSolverContext::PotrfBufferSize( int batch_size) { #if TENSORFLOW_USE_CUSOLVER_OR_HIPSOLVER int size = -1; + size_t d_lwork = 0; /* size of workspace */ + size_t h_lwork = 0; /* size of workspace */ + + cudaDataType_t cuda_data_type; switch (type) { case F32: { - TF_RETURN_IF_ERROR(ConvertStatus( - GpuSolverSpotrf_bufferSize(handle_.get(), GpuBlasUpperLower(uplo), n, - /*A=*/nullptr, lda, &size))); + cuda_data_type = CUDA_R_32F; break; } case F64: { - TF_RETURN_IF_ERROR(ConvertStatus( - GpuSolverDpotrf_bufferSize(handle_.get(), GpuBlasUpperLower(uplo), n, - /*A=*/nullptr, lda, &size))); + cuda_data_type = CUDA_R_64F; break; } case C64: { - TF_RETURN_IF_ERROR(ConvertStatus( - GpuSolverCpotrf_bufferSize(handle_.get(), GpuBlasUpperLower(uplo), n, - /*A=*/nullptr, lda, &size))); + cuda_data_type = CUDA_C_32F; break; } case C128: { - TF_RETURN_IF_ERROR(ConvertStatus( - GpuSolverZpotrf_bufferSize(handle_.get(), GpuBlasUpperLower(uplo), n, - /*A=*/nullptr, lda, &size))); + cuda_data_type = CUDA_C_64F; break; } default: return InvalidArgument("Invalid type for cholesky decomposition: %s", PrimitiveType_Name(type)); } + TF_RETURN_IF_ERROR(ConvertStatus(GpuSolverDnXpotrf_bufferSize( + handle_.get(), nullptr, GpuBlasUpperLower(uplo), n, cuda_data_type, + nullptr, lda, cuda_data_type, &d_lwork, &h_lwork))); + size = static_cast(d_lwork); + // CUDA's potrfBatched needs space for the `as` array, which contains // batch_size pointers. Divide by sizeof(type) because this function returns // not bytes but a number of elements of `type`. @@ -428,5 +434,49 @@ absl::Status GpuSolverContext::PotrfBatched( ToDevicePointer(lapack_info), batch_size)); } +absl::Status GpuSolverContext::Potrf(se::blas::UpperLower uplo, int n, + se::DeviceMemory a, int lda, + se::DeviceMemory lapack_info, + se::DeviceMemory workspace) { + absl::Status status = ConvertStatus(GpuSolverXpotrf( + handle_.get(), nullptr, GpuBlasUpperLower(uplo), n, CUDA_R_64F, + ToDevicePointer(a), lda, CUDA_R_64F, ToDevicePointer(workspace), + workspace.ElementCount(), nullptr, 0, ToDevicePointer(lapack_info))); + return status; +} + +absl::Status GpuSolverContext::Potrf(se::blas::UpperLower uplo, int n, + se::DeviceMemory a, int lda, + se::DeviceMemory lapack_info, + se::DeviceMemory workspace) { + absl::Status status = ConvertStatus(GpuSolverXpotrf( + handle_.get(), nullptr, GpuBlasUpperLower(uplo), n, CUDA_R_32F, + ToDevicePointer(a), lda, CUDA_R_32F, ToDevicePointer(workspace), + workspace.ElementCount(), nullptr, 0, ToDevicePointer(lapack_info))); + return status; +} + +absl::Status GpuSolverContext::Potrf( + se::blas::UpperLower uplo, int n, se::DeviceMemory> a, + int lda, se::DeviceMemory lapack_info, + se::DeviceMemory> workspace) { + absl::Status status = ConvertStatus(GpuSolverXpotrf( + handle_.get(), nullptr, GpuBlasUpperLower(uplo), n, CUDA_C_32F, + ToDevicePointer(a), lda, CUDA_C_32F, ToDevicePointer(workspace), + workspace.ElementCount(), nullptr, 0, ToDevicePointer(lapack_info))); + return status; +} + +absl::Status GpuSolverContext::Potrf( + se::blas::UpperLower uplo, int n, se::DeviceMemory> a, + int lda, se::DeviceMemory lapack_info, + se::DeviceMemory> workspace) { + absl::Status status = ConvertStatus(GpuSolverXpotrf( + handle_.get(), nullptr, GpuBlasUpperLower(uplo), n, CUDA_C_64F, + ToDevicePointer(a), lda, CUDA_C_64F, ToDevicePointer(workspace), + workspace.ElementCount(), nullptr, 0, ToDevicePointer(lapack_info))); + return status; +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/cusolver_context.h b/third_party/xla/xla/service/gpu/cusolver_context.h index d17a570ef438c0..d72228ec21b04d 100644 --- a/third_party/xla/xla/service/gpu/cusolver_context.h +++ b/third_party/xla/xla/service/gpu/cusolver_context.h @@ -19,6 +19,8 @@ limitations under the License. #include #include +#include "absl/status/status.h" + #define TENSORFLOW_USE_HIPSOLVER \ (TENSORFLOW_USE_ROCM && (TF_ROCM_VERSION >= 40500)) #define TENSORFLOW_USE_ROCSOLVER \ @@ -76,6 +78,23 @@ class GpuSolverContext { se::DeviceMemory*> as, int lda, se::DeviceMemory lapack_info, int batch_size); + absl::Status Potrf(se::blas::UpperLower uplo, int n, + se::DeviceMemory a, int lda, + se::DeviceMemory lapack_info, + se::DeviceMemory workspace); + absl::Status Potrf(se::blas::UpperLower uplo, int n, + se::DeviceMemory a, int lda, + se::DeviceMemory lapack_info, + se::DeviceMemory workspace); + absl::Status Potrf(se::blas::UpperLower uplo, int n, + se::DeviceMemory> a, int lda, + se::DeviceMemory lapack_info, + se::DeviceMemory> workspace); + absl::Status Potrf(se::blas::UpperLower uplo, int n, + se::DeviceMemory> a, int lda, + se::DeviceMemory lapack_info, + se::DeviceMemory> workspace); + // Returns the max size of the `workspace` required by Potrf and PotrfBatched, // in number of elements of `type`. // diff --git a/third_party/xla/xla/service/gpu/custom_call_test.cc b/third_party/xla/xla/service/gpu/custom_call_test.cc index cf8adf9d4d0188..b06314b31e5947 100644 --- a/third_party/xla/xla/service/gpu/custom_call_test.cc +++ b/third_party/xla/xla/service/gpu/custom_call_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -22,6 +23,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/strings/str_format.h" #include "xla/shape.h" #include "tsl/platform/statusor.h" @@ -45,14 +47,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/runtime/executable.h" // IWYU pragma: keep -#include "xla/runtime/memref_view.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_target_registry.h" -#include "xla/service/gpu/runtime/custom_call_registry.h" -#include "xla/service/gpu/runtime/support.h" #include "xla/service/service_executable_run_options.h" #include "xla/shape_util.h" #include "xla/status.h" @@ -354,100 +350,21 @@ TEST_F(CustomCallTest, WithStatusFailed) { // XLA runtime custom calls provides type-safe custom call API //===----------------------------------------------------------------------===// -// WARNING: We currently rely on a magic custom call prefix `__gpu$` to detect -// "internal" custom calls that linked statically into the binary. Without this -// prefix custom calls expected to be registered as XLA:FFI custom calls, and -// this is not yet fully supported. -// -// TODO(ezhulenev): Unify runtime custom calls and XLA:FFI. - -// (1) Declare custom call implementations as static functions. - -static absl::Status AlwaysFailImpl(runtime::MemrefView arg, int32_t value) { - return absl::InternalError(absl::StrCat("Uh oh, wrong value: ", value)); -} - -static absl::Status MemcpyImpl(const ServiceExecutableRunOptions* run_options, - runtime::MemrefView src, - runtime::MemrefView dst) { - auto src_mem = gpu::GetDeviceAddress(src); - auto dst_mem = gpu::GetDeviceAddress(dst); - run_options->stream()->ThenMemcpyD2D(&dst_mem, src_mem, src_mem.size()); - return absl::OkStatus(); -} - -// (2) Declare custom call binding signature. At compile time we check that -// declared signature matches function handlers, and at run time we check that -// passed arguments match the signature (number of arguments and their types). - -// TODO(ezhulenev): Remove these custom calls once we switch to thunks runtime. - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - AlwaysFail, AlwaysFailImpl, runtime::CustomCall::RuntimeChecks::kDefault, - runtime::CustomCall::Bind("__gpu$xla.gpu.ext.always_fail") - .Arg() // arg - .Attr("value") // value -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - Memcpy, MemcpyImpl, runtime::CustomCall::RuntimeChecks::kDefault, - runtime::CustomCall::Bind("__gpu$xla.gpu.ext.memcpy") - .UserData() - .Arg() // src - .Arg() // dst -); - -// (3) Declare FFI handlers as adaptors for legacy XLA runtime custom calls. -// -// TODO(ezhulenev): This is a long term replacement for "legacy" custom calls -// (custom calls with void** arguments) and a type safe xla runtime custom -// calls (see above). XLA FFI unifies internal custom calls (static linking) -// with external custom calls (dynamically loaded libraries). Make this the only -// example, once it's fully supported. - -namespace impl { static absl::Status AlwaysFail(ffi::BufferBase arg, int32_t value) { - return AlwaysFailImpl(arg, value); -} - -static absl::Status Memcpy(const ServiceExecutableRunOptions* run_options, - ffi::BufferBase src, ffi::BufferBase dst) { - return MemcpyImpl(run_options, src, dst); + return absl::InternalError(absl::StrCat("Uh oh, wrong value: ", value)); } -} // namespace impl -XLA_FFI_DEFINE_HANDLER(kAlwaysFail, impl::AlwaysFail, +XLA_FFI_DEFINE_HANDLER(kAlwaysFail, AlwaysFail, ffi::Ffi::Bind() .Arg() // arg .Attr("value") // value ); - -XLA_FFI_DEFINE_HANDLER(kMemcpy, impl::Memcpy, - ffi::Ffi::Bind() - .Ctx() - .Arg() // src - .Arg() // dst -); - -// (4) Register custom calls handlers with XLA runtime. - -static void RegisterCustomCalls(runtime::DirectCustomCallRegistry& registry) { - registry.Register("__gpu$xla.gpu.ext.always_fail", AlwaysFail); - registry.Register("__gpu$xla.gpu.ext.memcpy", Memcpy); -} - -XLA_GPU_REGISTER_RUNTIME_CUSTOM_CALL(RegisterCustomCalls); - -// (5) Register XLA FFI handlers with XLA runtime. - -XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__gpu$xla.gpu.ext.always_fail", +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$always_fail", PLATFORM, kAlwaysFail); -XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__gpu$xla.gpu.ext.memcpy", - PLATFORM, kMemcpy); TEST_F(CustomCallTest, RuntimeCustomCallAlwaysFail) { XlaBuilder b(TestName()); - CustomCall(&b, "__gpu$xla.gpu.ext.always_fail", /*operands=*/{}, + CustomCall(&b, "__xla_test$$always_fail", /*operands=*/{}, ShapeUtil::MakeShape(F32, {}), /*opaque=*/"{value = 42 : i32}", /*has_side_effect=*/false, /*output_operand_aliasing=*/{}, /*literal=*/nullptr, @@ -458,9 +375,26 @@ TEST_F(CustomCallTest, RuntimeCustomCallAlwaysFail) { EXPECT_THAT(status.message(), ::testing::HasSubstr("Uh oh, wrong value: 42")); } +static absl::Status Memcpy(const ServiceExecutableRunOptions* run_options, + ffi::BufferBase src, ffi::BufferBase dst) { + return run_options->stream()->MemcpyD2D( + &dst.data, src.data, + absl::c_accumulate(src.dimensions, 1.0, std::multiplies()) * + sizeof(float)); +} + +XLA_FFI_DEFINE_HANDLER(kMemcpy, Memcpy, + ffi::Ffi::Bind() + .Ctx() + .Arg() // src + .Arg() // dst +); +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$memcpy", PLATFORM, + kMemcpy); + TEST_F(CustomCallTest, ExportedFfiMemcpy) { XlaBuilder b(TestName()); - CustomCall(&b, "__gpu$xla.gpu.ext.memcpy", + CustomCall(&b, "__xla_test$$memcpy", /*operands=*/{Broadcast(ConstantR0WithType(&b, F32, 42.0), {128})}, ShapeUtil::MakeShape(F32, {128}), /*opaque=*/"", /*has_side_effect=*/false, @@ -471,7 +405,6 @@ TEST_F(CustomCallTest, ExportedFfiMemcpy) { EXPECT_THAT(result.data(), ::testing::Each(42)); } -// Test passing arbitrary pointers as i64 attributes. static absl::Status HandleUserPointer(ffi::BufferBase, const std::string* str) { return absl::InternalError(*str); } @@ -501,6 +434,187 @@ TEST_F(CustomCallTest, PassUserPointerWithAttrs) { EXPECT_THAT(status.message(), ::testing::HasSubstr("User-defined message")); } +bool is_ffi_invoked = false; +static absl::Status IsInvoked(ffi::BufferBase) { + is_ffi_invoked = true; + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER( + kIsInvoked, IsInvoked, + ffi::Ffi::Bind().Arg()); // Buffer for result (unused). + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$isinvoked", PLATFORM, + kIsInvoked); + +TEST_F(CustomCallTest, ExportedFfiIsInvoked) { + XlaBuilder b(TestName()); + CustomCall(&b, "__xla_test$$isinvoked", /*operands=*/{}, + ShapeUtil::MakeShape(F32, {}), /*opaque=*/"", + /*has_side_effect=*/false, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, + /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); + TF_ASSERT_OK_AND_ASSIGN(auto result, ExecuteAndTransfer(&b, {})); + EXPECT_TRUE(is_ffi_invoked); +} + +TEST_F(CustomCallTest, ExportedFfiUnknownTarget) { + XlaBuilder b(TestName()); + CustomCall(&b, "__xla_test$$unknown_target", /*operands=*/{}, + ShapeUtil::MakeShape(F32, {}), /*opaque=*/"", + /*has_side_effect=*/false, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, + /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); + auto status = Execute(&b, {}).status(); + EXPECT_EQ(status.code(), absl::StatusCode::kUnimplemented); + EXPECT_THAT(status.message(), + ::testing::HasSubstr("No registered implementation")); +} + +// Memcpy and SubBuffers tests are already ported in +// fusions/address_computation_fusion_test.cc + +// Reusing kExpectedOpaque from the original test. +static absl::Status Opaque(ffi::BufferBase, const std::string* str) { + std::string opaque(*str); + if (opaque != kExpectedOpaque) + return absl::InternalError(absl::StrFormat( + "Opaque string does not match. Expected `%s` but got `%s`", + kExpectedOpaque, opaque)); + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER(kOpaque, Opaque, + ffi::Ffi::Bind() + .Arg() // Dummy result buffer. + .Attr>("opaque")); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$opaque", PLATFORM, + kOpaque); + +TEST_F(CustomCallTest, ExportedFfiOpaque) { + XlaBuilder b(TestName()); + const std::string opaque = absl::StrFormat( + "{opaque = %d : i64}", reinterpret_cast(&kExpectedOpaque)); + CustomCall(&b, "__xla_test$$opaque", /*operands=*/{}, + ShapeUtil::MakeShape(F32, {}), + /*opaque=*/opaque, + /*has_side_effect=*/false, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, + /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); + TF_ASSERT_OK(Execute(&b, {}).status()); +} + +static absl::Status TokensChecker(std::vector inputs, + const std::string* opaque) { + // TODO(penporn): Actually check the inputs when FFI handlers support tokens. + return absl::OkStatus(); +} + +static absl::Status Tokens1Input(ffi::BufferBase input1, ffi::BufferBase, + const std::string* opaque) { + return TokensChecker({input1}, opaque); +} + +static absl::Status Tokens2Inputs(ffi::BufferBase input1, + ffi::BufferBase input2, ffi::BufferBase, + const std::string* opaque) { + return TokensChecker({input1, input2}, opaque); +} + +static absl::Status Tokens3Inputs(ffi::BufferBase input1, + ffi::BufferBase input2, + ffi::BufferBase input3, ffi::BufferBase, + const std::string* opaque) { + return TokensChecker({input1, input2, input3}, opaque); +} + +XLA_FFI_DEFINE_HANDLER(kTokens1Input, Tokens1Input, + ffi::Ffi::Bind() + .Arg() // 1 input buffer. + .Arg() // Output buffer. + .Attr>("opaque")); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$tokens_1input", + PLATFORM, kTokens1Input); + +XLA_FFI_DEFINE_HANDLER(kTokens2Inputs, Tokens2Inputs, + ffi::Ffi::Bind() + .Arg() // 1st input buffer. + .Arg() // 2nd input buffer. + .Arg() // Output buffer. + .Attr>("opaque")); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$tokens_2inputs", + PLATFORM, kTokens2Inputs); + +XLA_FFI_DEFINE_HANDLER(kTokens3Inputs, Tokens3Inputs, + ffi::Ffi::Bind() + .Arg() // 1st input buffer. + .Arg() // 2nd input buffer. + .Arg() // 3rd input buffer. + .Arg() // Output buffer. + .Attr>("opaque")); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$tokens_3inputs", + PLATFORM, kTokens3Inputs); + +TEST_P(CustomCallTokensTest, ExportedFfiTokensTest) { + const TokenTestCase& tc = GetParam(); + XlaBuilder b(TestName()); + std::istringstream input(tc.input); + std::istringstream output(tc.output); + std::vector call_inputs = BuildInputs(b, input); + std::vector call_output = BuildOutputType(output); + ASSERT_GE(call_inputs.size(), 1); + ASSERT_LE(call_inputs.size(), 3); + ASSERT_EQ(call_output.size(), 1); + + const std::string custom_call_name = + absl::StrFormat("__xla_test$$tokens_%dinput%s", call_inputs.size(), + call_inputs.size() == 1 ? "" : "s"); + const std::string opaque = absl::StrFormat( + "{opaque = %d : i64}", reinterpret_cast(&tc.opaque)); + CustomCall(&b, custom_call_name, /*operands=*/call_inputs, + call_output.front(), + /*opaque=*/opaque, + /*has_side_effect=*/false, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, + /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); + + // TODO(penporn): Expect an OK status when FFI handlers support tokens. + auto status = Execute(&b, {}).status(); + EXPECT_EQ(status.code(), absl::StatusCode::kInternal); + EXPECT_THAT(status.message(), + ::testing::HasSubstr("FFI handlers do not support tokens")); +} + +INSTANTIATE_TEST_SUITE_P(CustomCallTokensTest, CustomCallTokensTest, + ::testing::ValuesIn(GetTokenTestCases())); + +static absl::Status AlwaysSucceed(ffi::BufferBase) { return absl::OkStatus(); } + +XLA_FFI_DEFINE_HANDLER(kAlwaysSucceed, AlwaysSucceed, + ffi::Ffi::Bind().Arg()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$always_succeed", + PLATFORM, kAlwaysSucceed); + +TEST_F(CustomCallTest, ExportedFfiWithStatusSucceeded) { + XlaBuilder b(TestName()); + CustomCall(&b, "__xla_test$$always_succeed", /*operands=*/{}, + ShapeUtil::MakeShape(F32, {}), /*opaque=*/"", + /*has_side_effect=*/false, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, + /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); + TF_ASSERT_OK(Execute(&b, {}).status()); +} + //===----------------------------------------------------------------------===// // XLA:FFI handler with attached HloComputation //===----------------------------------------------------------------------===// @@ -517,7 +631,7 @@ static absl::Status MemcpyWithCalledComputation( if (!DynCast(called_computation->root_instruction())) return absl::InternalError("ROOT must be a paremeter"); - return MemcpyImpl(run_options, src, dst); + return Memcpy(run_options, src, dst); } XLA_FFI_DEFINE_HANDLER(kMemcpyWithCalledComputation, @@ -529,13 +643,10 @@ XLA_FFI_DEFINE_HANDLER(kMemcpyWithCalledComputation, .Ctx()); XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), - "__gpu$xla.gpu.ext.memcpy_with_called_compuation", - PLATFORM, kMemcpyWithCalledComputation); + "xla.gpu.ext.memcpy_with_called_computation", PLATFORM, + kMemcpyWithCalledComputation); TEST_F(CustomCallTest, WithCalledComputation) { - // FFI handlers with called computations supported only with Thunks runtime. - mutable_debug_options()->set_xla_gpu_enable_xla_runtime_executable(false); - auto shape = ShapeUtil::MakeShape(F32, {128}); // Build a called computation which is just a copy instruction. @@ -546,7 +657,7 @@ TEST_F(CustomCallTest, WithCalledComputation) { XlaBuilder b(TestName()); CustomCallWithComputation( - &b, "__gpu$xla.gpu.ext.memcpy_with_called_compuation", + &b, "xla.gpu.ext.memcpy_with_called_computation", /*operands=*/{Broadcast(ConstantR0WithType(&b, F32, 42.0), {128})}, copy_computation, shape, /*opaque=*/"", /*has_side_effect=*/false, diff --git a/third_party/xla/xla/service/gpu/elemental_ir_emitter.cc b/third_party/xla/xla/service/gpu/elemental_ir_emitter.cc index c1519e6c2d60fd..51296c7146d8f9 100644 --- a/third_party/xla/xla/service/gpu/elemental_ir_emitter.cc +++ b/third_party/xla/xla/service/gpu/elemental_ir_emitter.cc @@ -49,18 +49,6 @@ namespace gpu { using absl::StrAppend; -namespace { -// Returns whether operand is a floating-point literal with the given value. -bool IsFPLiteralWithValue(const HloInstruction* operand, float value) { - if (operand->opcode() == HloOpcode::kConstant && - operand->literal().IsAllFloat(value)) { - return true; - } - return operand->opcode() == HloOpcode::kBroadcast && - IsFPLiteralWithValue(operand->operand(0), value); -} -} // namespace - GpuElementalIrEmitter::GpuElementalIrEmitter( IrEmitterContext& ir_emitter_context, llvm::IRBuilder<>* b) : ElementalIrEmitter(ir_emitter_context.llvm_module(), b), @@ -109,29 +97,6 @@ absl::StatusOr GpuElementalIrEmitter::EmitDeviceMathCall( return result; } -absl::StatusOr GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall( - const std::string& callee_name, absl::Span operands, - absl::Span input_types, PrimitiveType output_type) { - // llvm intrinsics differentiate between half/float/double functions via - // the suffixes ".f16", ".f32" and ".f64". - std::string munged_callee = callee_name; - switch (output_type) { - case F16: - StrAppend(&munged_callee, ".f16"); - break; - case F32: - StrAppend(&munged_callee, ".f32"); - break; - case F64: - StrAppend(&munged_callee, ".f64"); - break; - default: - return Unimplemented("Bad type for llvm intrinsic math call: %s", - PrimitiveType_Name(output_type)); - } - return EmitMathCall(munged_callee, operands, input_types, output_type); -} - absl::StatusOr GpuElementalIrEmitter::EmitMathCall( const std::string& callee_name, absl::Span operands, absl::Span input_types, PrimitiveType output_type, @@ -330,6 +295,22 @@ absl::StatusOr GpuElementalIrEmitter::EmitTanh( value->getType(), "tanh"); } +absl::StatusOr GpuElementalIrEmitter::EmitErf( + PrimitiveType prim_type, llvm::Value* value) { + if (prim_type == F64) { + return EmitDeviceMathCall(TargetDeviceFunctionID::kErf, {value}, + {prim_type}, prim_type); + } + // Upcast F16 to F32 if necessary. + llvm::Type* type = prim_type == F16 ? b()->getFloatTy() : value->getType(); + if (type == b()->getFloatTy()) { + llvm::Value* x = FPCast(value, type); + auto* result = llvm_ir::EmitErfF32(b(), x); + return FPCast(result, value->getType()); + } + return Unimplemented("erf"); +} + absl::StatusOr GpuElementalIrEmitter::EmitComplexAbs( PrimitiveType prim_type, llvm::Value* value) { return EmitDeviceMathCall(TargetDeviceFunctionID::kHypot, @@ -343,19 +324,6 @@ absl::StatusOr GpuElementalIrEmitter::EmitCbrt( prim_type); } -llvm::Value* GpuElementalIrEmitter::EmitThreadId() { - llvm::Value* block_id = IntCast( - EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b()), - b()->getIntNTy(128), /*isSigned=*/true, "block.id"); - llvm::Value* thread_id_in_block = IntCast( - EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b()), - b()->getIntNTy(128), /*isSigned=*/true, "thread.id"); - llvm::Value* threads_per_block = IntCast( - EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockDimx, {}, {}, b()), - b()->getIntNTy(128), /*isSigned=*/true, "threads_per_block"); - return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block); -} - absl::StatusOr GpuElementalIrEmitter::EmitF32ToBF16( llvm::Value* f32_value) { // sm_80 and up has an instruction to convert f32 into bf16. diff --git a/third_party/xla/xla/service/gpu/elemental_ir_emitter.h b/third_party/xla/xla/service/gpu/elemental_ir_emitter.h index 037005bd63b3d9..770cee379e4ff6 100644 --- a/third_party/xla/xla/service/gpu/elemental_ir_emitter.h +++ b/third_party/xla/xla/service/gpu/elemental_ir_emitter.h @@ -85,6 +85,9 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { absl::StatusOr EmitTanh(PrimitiveType prim_type, llvm::Value* value) override; + absl::StatusOr EmitErf(PrimitiveType prim_type, + llvm::Value* value) override; + absl::StatusOr EmitComplexAbs(PrimitiveType prim_type, llvm::Value* value) override; @@ -95,8 +98,6 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { const HloComputation& callee, absl::Span parameters, absl::string_view, bool /*is_reducer*/) override; - llvm::Value* EmitThreadId() override; - absl::StatusOr EmitF32ToBF16(llvm::Value* f32_value) override; bool fast_min_max() override { @@ -109,13 +110,6 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { llvm::Value* lhs_value, llvm::Value* rhs_value); - // Emits IR to call an LLVM intrinsic of type [T] -> T. Adjusts - // callee_name according to T. Returns the IR value that represents the - // return value of the function. - absl::StatusOr EmitLlvmIntrinsicMathCall( - const std::string& callee_name, absl::Span operands, - absl::Span input_types, PrimitiveType output_type); - // Emits IR to call a device function of type [T] -> T. Adjusts // callee_name according to T. Returns the IR value that represents the // return value of the function. diff --git a/third_party/xla/xla/service/gpu/fusion_process_dump.cc b/third_party/xla/xla/service/gpu/fusion_process_dump.cc new file mode 100644 index 00000000000000..9863a3a7b63ef8 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusion_process_dump.cc @@ -0,0 +1,219 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/fusion_process_dump.h" + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/fusion_process_dump.pb.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tools/hlo_module_loader.h" +#include "xla/util.h" +#include "tsl/platform/env.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/path.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +namespace { + +HloInstruction* AddFusionInstruction(HloInstruction* producer, + HloInstruction* consumer, + HloComputation* computation, + std::string_view fusion_name) { + if (consumer->opcode() == HloOpcode::kFusion) { + return consumer; + } + + // This is not true for all fusions, but the fusion kind isn't used in the + // cost model and fusion pipeline, so it doesn't matter here. Set kLoop for + // everything. + auto kind = HloInstruction::FusionKind::kLoop; + + auto fusion_instruction = computation->AddInstruction( + HloInstruction::CreateFusion(consumer->shape(), kind, consumer), + /*new_name=*/fusion_name); + TF_CHECK_OK(computation->ReplaceInstruction(consumer, fusion_instruction)); + + return fusion_instruction; +} + +HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer, + HloComputation* computation, + std::string_view fusion_name) { + HloInstruction* fusion_instruction = + AddFusionInstruction(producer, consumer, computation, fusion_name); + if (producer->opcode() == HloOpcode::kFusion) { + fusion_instruction->MergeFusionInstruction(producer); + } else { + fusion_instruction->FuseInstruction(producer); + } + + if (producer->user_count() == 0) { + TF_CHECK_OK(computation->RemoveInstruction(producer)); + } + + return fusion_instruction; +} + +absl::string_view GetProducerName(const FusionStep& step) { + if (step.has_fusion()) { + return step.fusion().producer_name(); + } + + if (step.has_update_priority()) { + return step.update_priority().producer_name(); + } + + if (step.has_producer_ineligible()) { + return step.producer_ineligible().producer_name(); + } + + LOG(FATAL) << "Producer name not found in the current step."; +} + +} // namespace + +absl::StatusOr FusionProcessDump::LoadFromFile( + const std::string& path) { + std::string format = std::string(tsl::io::Extension(path)); + std::string data; + TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), path, &data)); + return FusionProcessDump::LoadFromData(data, format); +} + +absl::StatusOr FusionProcessDump::LoadFromData( + const std::string& data, absl::string_view format) { + FusionProcessDumpProto fusion_process_dump_proto; + if (format == "txt" || format == "pbtxt") { + if (!tsl::protobuf::TextFormat::ParseFromString( + data, &fusion_process_dump_proto)) { + return InvalidArgument("Failed to parse input as HLO protobuf text"); + } + } else if (format == "pb") { + if (!fusion_process_dump_proto.ParseFromString(data)) { + return InvalidArgument("Failed to parse input as HLO protobuf binary"); + } + } else { + return InvalidArgument( + "Invalid format from file extension: '%s'. Expected: txt, pb, or pbtxt", + format); + } + + return FusionProcessDump::LoadFromProto(fusion_process_dump_proto); +} + +absl::StatusOr FusionProcessDump::LoadFromProto( + const FusionProcessDumpProto& fusion_process_dump_proto) { + TF_ASSIGN_OR_RETURN( + auto module, + LoadModuleFromData(fusion_process_dump_proto.hlo_module_before_fusion(), + /*format=*/"txt")); + + se::DeviceDescription gpu_device_info( + fusion_process_dump_proto.gpu_device_info()); + + absl::flat_hash_map + instruction_name_to_computation_map; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + for (HloInstruction* instr : computation->instructions()) { + instruction_name_to_computation_map[instr->name()] = computation; + } + } + + return FusionProcessDump(std::move(fusion_process_dump_proto), + std::move(module), std::move(gpu_device_info), + std::move(instruction_name_to_computation_map)); +} + +HloComputation* FusionProcessDump::GetCurrentComputation() { + return instruction_name_to_computation_map_.at( + GetProducerName(CurrentStep())); +} + +HloInstruction* FusionProcessDump::GetInstructionWithName( + absl::string_view name) { + return instruction_name_to_computation_map_[name]->GetInstructionWithName( + name); +} + +HloInstruction* FusionProcessDump::GetProducer() { + return GetInstructionWithName(GetProducerName(CurrentStep())); +} + +absl::InlinedVector FusionProcessDump::GetConsumers() { + auto& step = CurrentStep(); + + if (step.has_fusion()) { + return {GetInstructionWithName(step.fusion().consumer_name())}; + } + + if (step.has_update_priority()) { + absl::InlinedVector consumers; + for (const auto& consumer_name : step.update_priority().consumer_names()) { + consumers.push_back(GetInstructionWithName(consumer_name)); + } + return consumers; + } + + return {}; +} + +const FusionStep& FusionProcessDump::CurrentStep() { + CHECK(HasNext()); + return fusion_process_dump_proto_.fusion_steps(current_step_idx_); +} + +bool FusionProcessDump::HasNext() { + return current_step_idx_ < fusion_process_dump_proto_.fusion_steps_size(); +} + +void FusionProcessDump::Advance() { + auto step = CurrentStep(); + if (step.has_fusion()) { + const auto& fusion_step = step.fusion(); + + auto* computation = GetCurrentComputation(); + + HloInstruction* producer = + computation->GetInstructionWithName(fusion_step.producer_name()); + HloInstruction* consumer = + computation->GetInstructionWithName(fusion_step.consumer_name()); + + HloInstruction* fusion = + Fuse(producer, consumer, computation, fusion_step.fusion_name()); + + instruction_name_to_computation_map_[fusion->name()] = computation; + last_fusion_ = fusion; + } + ++current_step_idx_; +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusion_process_dump.h b/third_party/xla/xla/service/gpu/fusion_process_dump.h new file mode 100644 index 00000000000000..782702d116d19d --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusion_process_dump.h @@ -0,0 +1,119 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_FUSION_PROCESS_DUMP_H_ +#define XLA_SERVICE_GPU_FUSION_PROCESS_DUMP_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/fusion_process_dump.pb.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/stream_executor.h" + +namespace xla { +namespace gpu { + +// Helper class to work with fusion process dump. +class FusionProcessDump { + public: + static absl::StatusOr LoadFromFile( + const std::string& path); + static absl::StatusOr LoadFromData( + const std::string& data, absl::string_view format); + static absl::StatusOr LoadFromProto( + const FusionProcessDumpProto& fusion_process_dump_proto); + + const FusionProcessDumpProto& proto() { return fusion_process_dump_proto_; } + + HloModule* module() { return hlo_module_.get(); } + + const se::DeviceDescription& device_info() { return device_info_; } + + int64_t current_step_idx() { return current_step_idx_; } + + // Returns computation that contains producer (and other instructions) of the + // current step. + HloComputation* GetCurrentComputation(); + + // Returns the instruction with `name`. + HloInstruction* GetInstructionWithName(absl::string_view name); + + // Returns producer of the current step. Should not be null, since all step + // types have a producer. + HloInstruction* GetProducer(); + + // Returns a list of consumers of the current step. The list contains one + // instruction is the current step is fusion. The list is empty if the current + // step is `producer_ineligible`. + absl::InlinedVector GetConsumers(); + + // Returns result instruction of the last fusion step. Returns nullptr before + // the first fusion. + HloInstruction* GetLastFusion() { return last_fusion_; } + + // Returns current step. If current step is `fusion`, the `module` is in the + // state *before* the fusion. Next call to `FusionProcessDump::Advance` will + // actualy perform the fusion. + const FusionStep& CurrentStep(); + + // Returns true if there are fusion steps. + bool HasNext(); + + // Advances to the next fusion step. If current step is `fusion`, modifies the + // `module` accordingly. + void Advance(); + + private: + FusionProcessDump(FusionProcessDumpProto fusion_process_dump_proto, + std::unique_ptr hlo_module, + se::DeviceDescription device_info, + absl::flat_hash_map + instruction_name_to_computation_map) + : fusion_process_dump_proto_(std::move(fusion_process_dump_proto)), + hlo_module_(std::move(hlo_module)), + device_info_(std::move(device_info)), + instruction_name_to_computation_map_( + std::move(instruction_name_to_computation_map)) {} + + FusionProcessDumpProto fusion_process_dump_proto_; + std::unique_ptr hlo_module_; + se::DeviceDescription device_info_; + + // A map from instructions to computations. HLO module doesn't have a + // convenient way to get an instruction by name. This map saves the need to + // iterator over all computations in the module. + absl::flat_hash_map + instruction_name_to_computation_map_; + + // Index of the current step. + int64_t current_step_idx_ = 0; + + // Tracks result of the last fusion step. + HloInstruction* last_fusion_ = nullptr; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSION_PROCESS_DUMP_H_ diff --git a/third_party/xla/xla/service/gpu/fusion_process_dump.proto b/third_party/xla/xla/service/gpu/fusion_process_dump.proto index 0fc379441358b1..0c52edb46c09eb 100644 --- a/third_party/xla/xla/service/gpu/fusion_process_dump.proto +++ b/third_party/xla/xla/service/gpu/fusion_process_dump.proto @@ -2,6 +2,8 @@ syntax = "proto3"; package xla.gpu; +import "xla/stream_executor/device_description.proto"; + message FusionStep { message Fusion { // Name of the resulting fusion. Can be the same as producer or consumer. @@ -46,4 +48,13 @@ message FusionStep { message FusionProcessDumpProto { repeated FusionStep fusion_steps = 1; + + stream_executor.GpuDeviceInfoProto gpu_device_info = 2; + + // HLO module before fusion in short parsable string format. The string + // represantation is compacter than HloModuleProto in this case, especially + // when the fusion process dump is stored as text proto. + // + // TODO: Consider using base64 or gzip to decrease the size of the string. + string hlo_module_before_fusion = 3; } diff --git a/third_party/xla/xla/service/gpu/fusion_process_dump_test.cc b/third_party/xla/xla/service/gpu/fusion_process_dump_test.cc new file mode 100644 index 00000000000000..37eb3bee29c836 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusion_process_dump_test.cc @@ -0,0 +1,94 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/fusion_process_dump.h" + +#include + +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/fusion_process_dump.pb.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/hlo_parser.h" +#include "xla/service/pattern_matcher.h" +#include "xla/service/pattern_matcher_gmock.h" +#include "xla/test.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace m = ::xla::match; + +namespace xla { +namespace gpu { +namespace { + +using FusionProcessDumpTest = HloTestBase; + +void AddFusion(FusionProcessDumpProto& dump_proto, + const std::string& fusion_name, const std::string& producer_name, + const std::string& consumer_name) { + auto step = dump_proto.add_fusion_steps(); + auto fusion_step = step->mutable_fusion(); + fusion_step->set_fusion_name(fusion_name); + fusion_step->set_producer_name(producer_name); + fusion_step->set_consumer_name(consumer_name); +} + +TEST_F(FusionProcessDumpTest, MultipleFusionSteps) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + HloModule test_module + + ENTRY main { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + add = f32[] add(p0, p1) + subtract = f32[] subtract(p0, p1) + abs = f32[] abs(subtract) + ROOT multiply = f32[] multiply(add, abs) + })")); + + FusionProcessDumpProto dump_proto; + *dump_proto.mutable_gpu_device_info() = + TestGpuDeviceInfo::RTXA6000DeviceInfo().ToGpuProto(); + dump_proto.set_hlo_module_before_fusion( + module->ToString(HloPrintOptions::ShortParsable())); + + AddFusion(dump_proto, "fusion.1", "subtract", "abs"); + AddFusion(dump_proto, "fusion.2", "fusion.1", "multiply"); + AddFusion(dump_proto, "fusion.2", "add", "fusion.2"); + + TF_ASSERT_OK_AND_ASSIGN(auto fusion_process_dump, + FusionProcessDump::LoadFromProto(dump_proto)); + + fusion_process_dump.Advance(); + fusion_process_dump.Advance(); + fusion_process_dump.Advance(); + + EXPECT_FALSE(fusion_process_dump.HasNext()); + + auto root = + fusion_process_dump.module()->entry_computation()->root_instruction(); + EXPECT_EQ(root->name(), "fusion.2"); + ASSERT_THAT(root, GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); + EXPECT_THAT(root->fused_expression_root(), + GmockMatch(m::Multiply( + m::Add(m::Parameter(), m::Parameter()), + m::Abs(m::Subtract(m::Parameter(), m::Parameter()))))); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusion_wrapper.cc b/third_party/xla/xla/service/gpu/fusion_wrapper.cc index 7b6495e77f1bd7..3dcb448412fb25 100644 --- a/third_party/xla/xla/service/gpu/fusion_wrapper.cc +++ b/third_party/xla/xla/service/gpu/fusion_wrapper.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -67,6 +68,7 @@ absl::StatusOr FusionWrapper::Run( case HloOpcode::kDot: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: @@ -116,8 +118,13 @@ absl::StatusOr FusionWrapper::Run( computation->AddInstruction(HloInstruction::CreateFusion( instruction->shape(), ChooseFusionKind(*instruction), instruction)); - instruction->GetModule()->SetAndUniquifyInstrName( - fusion_instruction, absl::StrCat("wrapped_", instruction->name())); + const absl::string_view wrapped_opcode = + HloOpcodeString(instruction->opcode()); + module->SetAndUniquifyInstrName( + fusion_instruction, absl::StrCat("wrapped_", wrapped_opcode)); + module->SetAndUniquifyComputationName( + fusion_instruction->fused_instructions_computation(), + absl::StrCat("wrapped_", wrapped_opcode, "_computation")); if (module->has_schedule()) { module->schedule().replace_instruction(computation, instruction, fusion_instruction); diff --git a/third_party/xla/xla/service/gpu/fusion_wrapper_test.cc b/third_party/xla/xla/service/gpu/fusion_wrapper_test.cc index c7860cea8b1229..ad77ad99efb304 100644 --- a/third_party/xla/xla/service/gpu/fusion_wrapper_test.cc +++ b/third_party/xla/xla/service/gpu/fusion_wrapper_test.cc @@ -33,7 +33,7 @@ TEST_F(FusionWrapperTest, SimpleOp) { ROOT result = f16[60, 41] concatenate(p0, p1), dimensions={0} })", FusionWrapper(), R"( -// CHECK: %fused_computation (param_0: f16[30,41], param_1: f16[30,41]) -> f16[60,41] { +// CHECK: %wrapped_concatenate_computation (param_0: f16[30,41], param_1: f16[30,41]) -> f16[60,41] { // CHECK: %param_0 = f16[30,41]{1,0} parameter(0) // CHECK: %param_1 = f16[30,41]{1,0} parameter(1) // CHECK: ROOT %result.1 = f16[60,41]{1,0} concatenate(%param_0, %param_1), dimensions={0} @@ -42,7 +42,7 @@ TEST_F(FusionWrapperTest, SimpleOp) { // CHECK: ENTRY %TestComputation (p0: f16[30,41], p1: f16[30,41]) -> f16[60,41] { // CHECK: %p0 = f16[30,41]{1,0} parameter(0) // CHECK: %p1 = f16[30,41]{1,0} parameter(1) -// CHECK: ROOT %wrapped_result = f16[60,41]{1,0} fusion(%p0, %p1), kind=kLoop, calls=%fused_computation +// CHECK: ROOT %wrapped_concatenate = f16[60,41]{1,0} fusion(%p0, %p1), kind=kLoop, calls=%wrapped_concatenate_computation // CHECK: })"); } @@ -67,7 +67,7 @@ TEST_F(FusionWrapperTest, Scatter) { to_apply=update_s32 })", FusionWrapper(), R"( -// CHECK: fused_computation +// CHECK: wrapped_scatter_computation // CHECK: %[[param_0:.*]] = s32[] parameter(0) // CHECK: %[[param_1:.*]] = s32[0]{0} parameter(1) // CHECK: %[[param_2:.*]] = s32[] parameter(2) @@ -77,7 +77,7 @@ TEST_F(FusionWrapperTest, Scatter) { // CHECK: %[[p0:.*]] = s32[] parameter(0) // CHECK: %[[p1:.*]] = s32[0]{0} parameter(1) // CHECK: %[[p2:.*]] = s32[] parameter(2) -// CHECK: ROOT %{{.*}} = s32[] fusion(%[[p0]], %[[p1]], %[[p2]]), kind=kInput, calls=%fused_computation +// CHECK: ROOT %{{.*}} = s32[] fusion(%[[p0]], %[[p1]], %[[p2]]), kind=kInput, calls=%wrapped_scatter_computation // CHECK: })"); } @@ -123,28 +123,28 @@ TEST_F(FusionWrapperTest, While) { ROOT %while.19 = (f32[5]{0}) while((f32[5]{0}) %tuple), condition=%cond, body=%body })", FusionWrapper(), R"( -// CHECK: %fused_computation.1 {{.*}} { +// CHECK: %wrapped_broadcast_computation {{.*}} { // CHECK: %param_0.1 = f32[] parameter(0) // CHECK: ROOT %broadcast.0 = f32[5]{0} broadcast(%param_0.1), dimensions={} // CHECK: } // CHECK: %body {{.*}} { // CHECK: %parameter.5 = (f32[5]{0}) parameter(0) // CHECK: %constant_8 = f32[] constant(0) -// CHECK: %wrapped_broadcast.9 = f32[5]{0} fusion(%constant_8), kind=kLoop, calls=%fused_computation.1 -// CHECK: ROOT %tuple.2 = (f32[5]{0}) tuple(%wrapped_broadcast.9) +// CHECK: %wrapped_broadcast = f32[5]{0} fusion(%constant_8), kind=kLoop, calls=%wrapped_broadcast_computation +// CHECK: ROOT %tuple.2 = (f32[5]{0}) tuple(%wrapped_broadcast) // CHECK: } // CHECK: %cond {{.*}} { // CHECK: %parameter.12 = (f32[5]{0}) parameter(0) // CHECK: ROOT %constant_1 = pred[] constant(false) // CHECK: } -// CHECK: %fused_computation {{.*}} { +// CHECK: %wrapped_copy_computation {{.*}} { // CHECK: %param_0 = f32[5]{0} parameter(0) // CHECK: ROOT %copy.0 = f32[5]{0} copy(%param_0) // CHECK: } // CHECK: ENTRY %main {{.*}} { // CHECK: %parameter.1 = f32[5]{0} parameter(0) -// CHECK: %wrapped_copy.3 = f32[5]{0} fusion(%parameter.1), kind=kLoop, calls=%fused_computation -// CHECK: %tuple = (f32[5]{0}) tuple(%wrapped_copy.3) +// CHECK: %wrapped_copy = f32[5]{0} fusion(%parameter.1), kind=kLoop, calls=%wrapped_copy_computation +// CHECK: %tuple = (f32[5]{0}) tuple(%wrapped_copy) // CHECK: ROOT %while.19 = (f32[5]{0}) while(%tuple), condition=%cond, body=%body // CHECK: })"); } diff --git a/third_party/xla/xla/service/gpu/fusions/BUILD b/third_party/xla/xla/service/gpu/fusions/BUILD index 754305da368ba5..3d602cfb04c0fe 100644 --- a/third_party/xla/xla/service/gpu/fusions/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/BUILD @@ -1,9 +1,9 @@ load("//xla/tests:build_defs.bzl", "xla_test") load("//xla:xla.bzl", "xla_cc_test") +load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") load("@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") package( - default_visibility = ["//visibility:public"], # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) @@ -12,7 +12,6 @@ cc_library( name = "in_place_dynamic_update_slice", srcs = ["in_place_dynamic_update_slice.cc"], hdrs = ["in_place_dynamic_update_slice.h"], - visibility = ["//visibility:public"], deps = [ ":fusion_emitter", "//xla:status", @@ -36,7 +35,6 @@ cc_library( name = "copy", srcs = ["copy.cc"], hdrs = ["copy.h"], - visibility = ["//visibility:public"], deps = [ ":fusion_emitter", "//xla:statusor", @@ -45,7 +43,7 @@ cc_library( "//xla/service:buffer_assignment", "//xla/service/gpu:ir_emitter_context", "//xla/service/gpu:thunk", - "//xla/service/gpu/runtime3:copy_thunk", + "//xla/service/gpu/runtime:copy_thunk", "@llvm-project//mlir:IR", ], ) @@ -54,15 +52,20 @@ cc_library( name = "custom", srcs = ["custom.cc"], hdrs = ["custom.h"], - visibility = ["//visibility:public"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":fusion_emitter", "//xla:shape_util", - "//xla:status_macros", + "//xla:status", "//xla:statusor", + "//xla:util", + "//xla/ffi:ffi_api", + "//xla/ffi/api:c_api", "//xla/hlo/ir:hlo", "//xla/mlir_hlo:lhlo", "//xla/service:buffer_assignment", + "//xla/service:custom_call_status", + "//xla/service:custom_call_target_registry", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", "//xla/service/gpu:hlo_fusion_analysis", @@ -74,43 +77,63 @@ cc_library( "//xla/service/gpu:thunk", "//xla/service/gpu/kernels:custom_kernel", "//xla/service/gpu/kernels:custom_kernel_fusion", - "//xla/service/gpu/runtime3:gemm_thunk", - "//xla/service/gpu/runtime3:kernel_thunk", + "//xla/service/gpu/runtime:custom_call_thunk", + "//xla/service/gpu/runtime:gemm_thunk", + "//xla/service/gpu/runtime:kernel_thunk", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], ) xla_test( name = "address_computation_fusion_test", - srcs = ["address_computation_fusion_test.cc"], + srcs = if_cuda_is_configured(["address_computation_fusion_test.cc"]), backends = ["gpu"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ - "//xla:array3d", - "//xla:array4d", "//xla:error_spec", - "//xla:literal_util", - "//xla:types", + "//xla:shape_util", + "//xla/client:xla_builder", + "//xla/client:xla_computation", + "//xla/client/lib:constants", + "//xla/ffi", + "//xla/ffi:ffi_api", + "//xla/hlo/ir:hlo", + "//xla/service:custom_call_target_registry", + "//xla/service:executable", + "//xla/service:hlo_module_config", + "//xla/stream_executor:device_description", + "//xla/stream_executor/gpu:gpu_types_header", "//xla/tests:hlo_test_base", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", - ], + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rocm_headers", + ]), ) cc_library( name = "fusion_emitter", srcs = ["fusion_emitter.cc"], hdrs = ["fusion_emitter.h"], - visibility = ["//visibility:public"], + visibility = ["//xla/service/gpu:__subpackages__"], deps = [ "//xla:shape_util", "//xla:status", "//xla:status_macros", "//xla:statusor", + "//xla:util", "//xla/hlo/ir:hlo", "//xla/mlir_hlo:lhlo", "//xla/service/gpu:ir_emitter_context", @@ -121,7 +144,7 @@ cc_library( "//xla/service/gpu:thunk", "//xla/service/gpu/model:indexing_analysis", "//xla/service/gpu/model:indexing_map", - "//xla/service/gpu/runtime3:kernel_thunk", + "//xla/service/gpu/runtime:kernel_thunk", "//xla/service/llvm_ir:ir_array", "//xla/service/llvm_ir:llvm_util", "//xla/stream_executor:device_description", @@ -130,6 +153,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", "@llvm-project//llvm:ir_headers", "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:errors", @@ -141,7 +165,7 @@ cc_library( name = "fusions", srcs = ["fusions.cc"], hdrs = ["fusions.h"], - visibility = ["//visibility:public"], + visibility = ["//xla/service/gpu:__subpackages__"], deps = [ ":concatenate", ":copy", @@ -150,6 +174,7 @@ cc_library( ":in_place_dynamic_update_slice", ":input_slices", ":loop", + ":loop_mlir", ":reduction", ":scatter", ":transpose", @@ -162,6 +187,7 @@ cc_library( "//xla/service:buffer_assignment", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:span", @@ -175,7 +201,6 @@ cc_library( name = "loop", srcs = ["loop.cc"], hdrs = ["loop.h"], - visibility = ["//visibility:public"], deps = [ ":fusion_emitter", "//xla:shape_util", @@ -190,6 +215,7 @@ cc_library( "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:parallel_loop_emitter", "//xla/service/gpu/model:indexing_analysis", + "//xla/service/gpu/model:indexing_map", "//xla/service/llvm_ir:fused_ir_emitter", "//xla/service/llvm_ir:ir_array", "@com_google_absl//absl/numeric:bits", @@ -198,6 +224,64 @@ cc_library( ], ) +cc_library( + name = "loop_mlir", + srcs = ["loop_mlir.cc"], + hdrs = ["loop_mlir.h"], + deps = [ + ":loop", + "//xla:shape_util", + "//xla:status", + "//xla:status_macros", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu/fusions/mlir:computation_partitioner", + "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", + "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:TensorDialect", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "loop_mlir_test", + srcs = ["loop_mlir_test.cc"], + deps = [ + ":loop_mlir", + "//xla/hlo/ir:hlo", + "//xla/mlir_hlo", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/stream_executor:device_description", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncExtensions", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MemRefTransforms", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:TensorDialect", + "@local_tsl//tsl/platform:statusor", + ], +) + xla_cc_test( name = "loop_test", srcs = ["loop_test.cc"], @@ -223,7 +307,6 @@ cc_library( name = "scatter", srcs = ["scatter.cc"], hdrs = ["scatter.h"], - visibility = ["//visibility:public"], deps = [ ":fusion_emitter", "//xla:shape_util", @@ -268,10 +351,9 @@ cc_library( name = "tiling_util", srcs = ["tiling_util.cc"], hdrs = ["tiling_util.h"], - visibility = ["//visibility:public"], + visibility = ["//xla/service/gpu:__subpackages__"], deps = [ "//xla:shape_util", - "//xla:statusor", "//xla:util", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:target_util", @@ -280,6 +362,7 @@ cc_library( "//xla/service/llvm_ir:llvm_loop", "//xla/service/llvm_ir:llvm_util", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//llvm:ir_headers", @@ -292,7 +375,6 @@ cc_library( srcs = ["triton.cc"], hdrs = ["triton.h"], local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - visibility = ["//visibility:public"], deps = [ ":fusion_emitter", "//xla:statusor", @@ -308,7 +390,7 @@ cc_library( "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:matmul_utils", "//xla/service/gpu:triton_fusion_analysis", - "//xla/service/gpu/runtime3:kernel_thunk", + "//xla/service/gpu/runtime:kernel_thunk", "//xla/service/llvm_ir:llvm_util", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -341,20 +423,18 @@ cc_library( name = "thunk_util", srcs = ["thunk_util.cc"], hdrs = ["thunk_util.h"], - visibility = ["//visibility:public"], + visibility = ["//xla/service/gpu:__subpackages__"], deps = [ "//xla:literal", "//xla:shape_util", - "//xla:statusor", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", - "//xla/service/gpu:gpu_executable", "//xla/service/gpu:ir_emitter_context", "//xla/service/gpu:thunk", - "//xla/service/gpu/runtime3:memset_thunk", + "//xla/service/gpu/runtime:memset_thunk", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", - "@llvm-project//mlir:IR", ], ) @@ -362,7 +442,6 @@ cc_library( name = "reduction", srcs = ["reduction.cc"], hdrs = ["reduction.h"], - visibility = ["//visibility:public"], deps = [ ":fusion_emitter", ":thunk_util", @@ -389,7 +468,9 @@ cc_library( "//xla/service/gpu:reduction_utils", "//xla/service/gpu:target_util", "//xla/service/gpu:thunk", - "//xla/service/gpu/runtime3:kernel_thunk", + "//xla/service/gpu/model:indexing_analysis", + "//xla/service/gpu/model:indexing_map", + "//xla/service/gpu/runtime:kernel_thunk", "//xla/service/llvm_ir:fused_ir_emitter", "//xla/service/llvm_ir:ir_array", "//xla/service/llvm_ir:kernel_support_library", @@ -417,11 +498,30 @@ cc_library( ], ) +xla_cc_test( + name = "reduction_test", + srcs = ["reduction_test.cc"], + deps = [ + ":fusions", + ":reduction", + "//xla:status_macros", + "//xla:statusor", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu/model:indexing_test_utils", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:statusor", + ], +) + cc_library( name = "concatenate", srcs = ["concatenate.cc"], hdrs = ["concatenate.h"], - visibility = ["//visibility:public"], deps = [ ":fusion_emitter", "//xla:shape_util", @@ -450,7 +550,6 @@ cc_library( name = "transpose", srcs = ["transpose.cc"], hdrs = ["transpose.h"], - visibility = ["//visibility:public"], deps = [ ":fusion_emitter", ":tiling_util", @@ -464,17 +563,40 @@ cc_library( "//xla/service/gpu:ir_emitter_context", "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:target_util", + "//xla/service/gpu/model:indexing_analysis", + "//xla/service/gpu/model:indexing_map", "//xla/service/llvm_ir:fused_ir_emitter", "//xla/service/llvm_ir:ir_array", "//xla/service/llvm_ir:llvm_util", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:IR", + ], +) + +xla_cc_test( + name = "transpose_test", + srcs = ["transpose_test.cc"], + deps = [ + ":fusions", + ":transpose", + "//xla:status_macros", + "//xla:statusor", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu/model:indexing_test_utils", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:statusor", ], ) @@ -482,7 +604,6 @@ cc_library( name = "input_slices", srcs = ["input_slices.cc"], hdrs = ["input_slices.h"], - visibility = ["//visibility:public"], deps = [ ":fusion_emitter", "//xla:shape_util", diff --git a/third_party/xla/xla/service/gpu/fusions/address_computation_fusion_test.cc b/third_party/xla/xla/service/gpu/fusions/address_computation_fusion_test.cc index 14279c810a1190..1ec79b4622227a 100644 --- a/third_party/xla/xla/service/gpu/fusions/address_computation_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/address_computation_fusion_test.cc @@ -12,19 +12,78 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/array3d.h" -#include "xla/array4d.h" + +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "xla/client/lib/constants.h" +#include "xla/client/xla_builder.h" #include "xla/error_spec.h" -#include "xla/literal_util.h" +#include "xla/ffi/ffi.h" +#include "xla/ffi/ffi_api.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/custom_call_target_registry.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/service_executable_run_options.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/gpu/gpu_types.h" #include "xla/tests/hlo_test_base.h" -#include "xla/types.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" +#define PLATFORM "CUDA" +#if GOOGLE_CUDA +#include "third_party/gpus/cuda/include/cuda.h" // IWYU pragma: keep +#include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "third_party/gpus/cuda/include/driver_types.h" +#elif TENSORFLOW_USE_ROCM +#include "rocm/include/hip/hip_runtime.h" +#define PLATFORM "ROCM" +#endif + +#if GOOGLE_CUDA +#define gpuSuccess cudaSuccess +#define gpuMemcpyAsync cudaMemcpyAsync +#define gpuMemcpyDeviceToDevice cudaMemcpyDeviceToDevice +#define gpuMemcpy cudaMemcpy +#define gpuMemcpyDeviceToHost cudaMemcpyDeviceToHost +#define gpuMemcpyHostToDevice cudaMemcpyHostToDevice +#elif TENSORFLOW_USE_ROCM +#define gpuSuccess hipSuccess +#define gpuMemcpyAsync hipMemcpyAsync +#define gpuMemcpyDeviceToDevice hipMemcpyDeviceToDevice +#define gpuMemcpy hipMemcpy +#define gpuMemcpyDeviceToHost hipMemcpyDeviceToHost +#define gpuMemcpyHostToDevice hipMemcpyHostToDevice +#endif + namespace xla { namespace gpu { namespace { -class AddressComputationFusionTest : public HloTestBase {}; +class AddressComputationFusionTest : public HloTestBase { + public: + HloModuleConfig GetRefModuleConfig() { + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_address_computation_fusion(false); + HloModuleConfig config; + config.set_debug_options(debug_options); + return config; + } + + HloModuleConfig GetOptModuleConfig() { + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_address_computation_fusion(true); + HloModuleConfig config; + config.set_debug_options(debug_options); + return config; + } +}; TEST_F(AddressComputationFusionTest, CublasGemmSimple) { ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; @@ -100,15 +159,8 @@ TEST_F(AddressComputationFusionTest, CublasGemmSimple) { backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"address_computation"}}} })"; - Array3D arr0(2, 8, 8); // bf16[2,8,8] - Array3D arr1(2, 8, 8); // bf16[2,8,8] - arr0.FillIota(static_cast(1.0)); - arr1.FillRandom(bfloat16(0.01f), 0.02); - - auto a0 = LiteralUtil::CreateFromArray(arr0); - auto a1 = LiteralUtil::CreateFromArray(arr1); - - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, {&a0, &a1}, error_spec, + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), + GetOptModuleConfig(), error_spec, /*run_hlo_passes=*/false)); } @@ -189,15 +241,8 @@ TEST_F(AddressComputationFusionTest, CublasGemmWithWorkspace) { backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"address_computation"}}} })"; - Array3D arr0(2, 8, 8); // bf16[2,8,8] - Array3D arr1(2, 8, 8); // bf16[2,8,8] - arr0.FillRandom(bfloat16(0.01f), 0.02); - arr1.FillIota(static_cast(10.0)); - - auto a0 = LiteralUtil::CreateFromArray(arr0); - auto a1 = LiteralUtil::CreateFromArray(arr1); - - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, {&a0, &a1}, error_spec, + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), + GetOptModuleConfig(), error_spec, /*run_hlo_passes=*/false)); } @@ -275,15 +320,8 @@ TEST_F(AddressComputationFusionTest, ContiguousSlice) { backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"address_computation"}}} })"; - Array3D arr0(2, 8, 8); // bf16[2,8,8] - Array4D arr1(8, 8, 10, 8); // bf16[8,8,10,8] - arr0.FillIota(static_cast(1.0)); - arr1.FillRandom(bfloat16(0.01f), 0.02); - - auto a0 = LiteralUtil::CreateFromArray(arr0); - auto a1 = LiteralUtil::CreateFromArray(arr1); - - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, {&a0, &a1}, error_spec, + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), + GetOptModuleConfig(), error_spec, /*run_hlo_passes=*/false)); } @@ -361,15 +399,720 @@ TEST_F(AddressComputationFusionTest, ContiguousSliceNonDefaultLayout) { backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"address_computation"}}} })"; - Array3D arr0(2, 8, 8); // bf16[2,8,8] - Array4D arr1(8, 8, 10, 8); // bf16[8,8,10,8] - arr0.FillIota(static_cast(1.0)); - arr1.FillRandom(bfloat16(0.01f), 0.02); + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), + GetOptModuleConfig(), error_spec, + /*run_hlo_passes=*/false)); +} + +TEST_F(AddressComputationFusionTest, OperandIsSlicedGetTupleElement) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_ref = R"( + HloModule jit_slice + + ENTRY %main { + %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0) + %get-tuple-element.240 = f32[100,100]{1,0} get-tuple-element(%p0), index=0 + %get-tuple-element.241 = f32[100,100]{1,0} get-tuple-element(%p0), index=1 + %concatenate.10 = f32[200,100]{1,0} concatenate(%get-tuple-element.240, %get-tuple-element.241), dimensions={0} + %custom-call.16 = (f32[200,100]{1,0}, s8[120000]{0}) custom-call(%concatenate.10, %get-tuple-element.240), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"20000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + %get-tuple-element.97 = f32[200,100]{1,0} get-tuple-element(%custom-call.16), index=0 + %slice.26 = f32[100,100]{1,0} slice(%get-tuple-element.97), slice={[0:100], [0:100]} + ROOT %custom-call.17 = (f32[100,100]{1,0}, s8[80000]{0}) custom-call(%slice.26, %get-tuple-element.240), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"10000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + })"; + + const char* hlo_opt = R"( + HloModule jit_slice + + %address-computation { + %p0.3 = f32[200,100]{1,0} parameter(0) + %p1.3 = f32[100,100]{1,0} parameter(1) + %slice.56 = f32[100,100]{1,0} slice(%p0.3), slice={[0:100], [0:100]} + %cublas-gemm.23 = (f32[100,100]{1,0}, s8[80000]{0}) custom-call(%slice.56, %p1.3), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"10000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + %get-tuple-element.221 = f32[100,100]{1,0} get-tuple-element(%cublas-gemm.23), index=0 + %get-tuple-element.222 = s8[80000]{0} get-tuple-element(%cublas-gemm.23), index=1 + ROOT %tuple.58 = (f32[100,100]{1,0}, s8[80000]{0}) tuple(%get-tuple-element.221, %get-tuple-element.222) + } + + ENTRY %main { + %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0) + %get-tuple-element.240 = f32[100,100]{1,0} get-tuple-element(%p0), index=0 + %get-tuple-element.241 = f32[100,100]{1,0} get-tuple-element(%p0), index=1 + %concatenate.10 = f32[200,100]{1,0} concatenate(%get-tuple-element.240, %get-tuple-element.241), dimensions={0} + %custom-call.16 = (f32[200,100]{1,0}, s8[120000]{0}) custom-call(%concatenate.10, %get-tuple-element.240), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"20000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + %get-tuple-element.97 = f32[200,100]{1,0} get-tuple-element(%custom-call.16), index=0 + ROOT %address_computation.6 = (f32[100,100]{1,0}, s8[80000]{0}) fusion(%get-tuple-element.97, %get-tuple-element.240), + kind=kCustom, + calls=%address-computation, + backend_config={ + "fusion_backend_config":{ + "kind":"__custom_fusion","custom_fusion_config":{"name":"address_computation"} + } + } + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), + GetOptModuleConfig(), error_spec, + /*run_hlo_passes=*/false)); +} + +TEST_F(AddressComputationFusionTest, ReversedOperandOrder) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_ref = R"( + HloModule jit_slice + + ENTRY %main.9 { + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[0:1], [0:8], [0:8]} + %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) + %p1 = f16[2,8,8]{2,1,0} parameter(1) + %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]} + %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14) + + ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.42, %bitcast.41), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + })"; + + const char* hlo_opt = R"( + HloModule jit_slice + + %address-computation { + %p0.1 = f16[2,8,8]{2,1,0} parameter(0) + %slice.1 = f16[1,8,8]{2,1,0} slice(%p0.1), slice={[1:2], [0:8], [0:8]} + %bitcast.1 = f16[8,8]{1,0} bitcast(%slice.1) + %p1.1 = f16[2,8,8]{2,1,0} parameter(1) + %slice.0 = f16[1,8,8]{2,1,0} slice(%p1.1), slice={[0:1], [0:8], [0:8]} + %bitcast.0 = f16[8,8]{1,0} bitcast(%slice.0) + ROOT %custom-call.0 = f16[8,8]{1,0} custom-call(%bitcast.1, %bitcast.0), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + } + } + } + + ENTRY %main { + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %p1 = f16[2,8,8]{2,1,0} parameter(1) + ROOT %address_computation.6 = f16[8,8]{1,0} fusion(%p1, %p0), + kind=kCustom, + calls=%address-computation, + backend_config={ + "fusion_backend_config":{ + "kind":"__custom_fusion","custom_fusion_config":{"name":"address_computation"} + } + } + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), + GetOptModuleConfig(), error_spec, + /*run_hlo_passes=*/false)); +} + +TEST_F(AddressComputationFusionTest, SingleOperandComputation) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_ref = R"( + HloModule jit_slice + + ENTRY %main { + %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0) + %get-tuple-element.240 = f32[100,100]{1,0} get-tuple-element(%p0), index=0 + %get-tuple-element.241 = f32[100,100]{1,0} get-tuple-element(%p0), index=1 + %concatenate.10 = f32[200,100]{1,0} concatenate(%get-tuple-element.240, %get-tuple-element.241), dimensions={0} + %custom-call.16 = (f32[200,100]{1,0}, s8[120000]{0}) custom-call(%concatenate.10, %get-tuple-element.240), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"20000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + %get-tuple-element.97 = f32[200,100]{1,0} get-tuple-element(%custom-call.16), index=0 + %slice.26 = f32[100,100]{1,0} slice(%get-tuple-element.97), slice={[0:100], [0:100]} + ROOT %custom-call.17 = (f32[100,100]{1,0}, s8[80000]{0}) custom-call(%slice.26, %slice.26), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"10000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + })"; + + const char* hlo_opt = R"( + HloModule jit_slice + + %address-computation { + %p0.3 = f32[200,100]{1,0} parameter(0) + %slice.56 = f32[100,100]{1,0} slice(%p0.3), slice={[0:100], [0:100]} + %cublas-gemm.23 = (f32[100,100]{1,0}, s8[80000]{0}) custom-call(%slice.56, %slice.56), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"10000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + %get-tuple-element.221 = f32[100,100]{1,0} get-tuple-element(%cublas-gemm.23), index=0 + %get-tuple-element.222 = s8[80000]{0} get-tuple-element(%cublas-gemm.23), index=1 + ROOT %tuple.58 = (f32[100,100]{1,0}, s8[80000]{0}) tuple(%get-tuple-element.221, %get-tuple-element.222) + } + + ENTRY %main { + %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0) + %get-tuple-element.240 = f32[100,100]{1,0} get-tuple-element(%p0), index=0 + %get-tuple-element.241 = f32[100,100]{1,0} get-tuple-element(%p0), index=1 + %concatenate.10 = f32[200,100]{1,0} concatenate(%get-tuple-element.240, %get-tuple-element.241), dimensions={0} + %custom-call.16 = (f32[200,100]{1,0}, s8[120000]{0}) custom-call(%concatenate.10, %get-tuple-element.240), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"20000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + %get-tuple-element.97 = f32[200,100]{1,0} get-tuple-element(%custom-call.16), index=0 + ROOT %address_computation.6 = (f32[100,100]{1,0}, s8[80000]{0}) fusion(%get-tuple-element.97), + kind=kCustom, + calls=%address-computation, + backend_config={ + "fusion_backend_config":{ + "kind":"__custom_fusion","custom_fusion_config":{"name":"address_computation"} + } + } + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), + GetOptModuleConfig(), error_spec, + /*run_hlo_passes=*/false)); +} + +TEST_F(AddressComputationFusionTest, SlicedOperandAliasingOutput) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_ref = R"( + HloModule jit_slice + + ENTRY %main.9 { + %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0) + %get-tuple-element.287 = f32[100,100]{1,0} get-tuple-element(%p0), index=0 + %get-tuple-element.288 = f32[100,100]{1,0} get-tuple-element(%p0), index=1 + %concatenate.12 = f32[200,100]{1,0} concatenate(%get-tuple-element.287, %get-tuple-element.288), dimensions={0} + %slice.30 = f32[100,100]{1,0} slice(%concatenate.12), slice={[20:120], [0:100]} + %slice.34 = f32[100,100]{1,0} slice(%concatenate.12), slice={[99:199], [0:100]} + ROOT %cublas-gemm.15 = (f32[100,100]{1,0}, s8[120000]{0}) custom-call(%get-tuple-element.287, %slice.30, %slice.34), + custom_call_target="__cublas$gemm", + output_to_operand_aliasing={{0}: (2, {})}, + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":1, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"10000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + }} + })"; + + const char* hlo_opt = R"( + HloModule jit_slice + + %address-computation { + %p0.1 = f32[100,100]{1,0} parameter(0) + %p2 = f32[200,100]{1,0} parameter(2) + %slice.0 = f32[100,100]{1,0} slice(f32[200,100]{1,0} %p2), slice={[20:120], [0:100]} + %p1 = f32[100,100]{1,0} parameter(1) + %cublas-gemm.0 = (f32[100,100]{1,0}, s8[120000]{0}) custom-call(%p0.1, %slice.0, %p1), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":1, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"10000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + %get-tuple-element = f32[100,100]{1,0} get-tuple-element(%cublas-gemm.0), index=0 + %get-tuple-element.1 = s8[120000]{0} get-tuple-element(%cublas-gemm.0), index=1 + ROOT %tuple = (f32[100,100]{1,0}, s8[120000]{0}) tuple(%get-tuple-element, %get-tuple-element.1) + } + + ENTRY %main { + %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0) + %get-tuple-element.287 = f32[100,100]{1,0} get-tuple-element(%p0), index=0 + %get-tuple-element.288 = f32[100,100]{1,0} get-tuple-element(%p0), index=1 + %concatenate.12 = f32[200,100]{1,0} concatenate(%get-tuple-element.287, %get-tuple-element.288), dimensions={0} + %slice.34 = f32[100,100]{1,0} slice(%concatenate.12), slice={[99:199], [0:100]} + ROOT %address_computation.6 = (f32[100,100]{1,0}, s8[120000]{0}) fusion(%get-tuple-element.287, %slice.34, %concatenate.12), + kind=kCustom, + calls=%address-computation, + output_to_operand_aliasing={{0}: (1, {})}, + backend_config={ + "fusion_backend_config":{ + "kind":"__custom_fusion","custom_fusion_config":{"name":"address_computation"} + } + } + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), + GetOptModuleConfig(), error_spec, + /*run_hlo_passes=*/false)); +} + +static absl::Status Memcpy(const ServiceExecutableRunOptions* run_options, + ffi::BufferBase src, ffi::BufferBase dst) { + return run_options->stream()->MemcpyD2D( + &dst.data, src.data, + absl::c_accumulate(src.dimensions, 1.0, std::multiplies()) * + sizeof(float)); +} + +XLA_FFI_DEFINE_HANDLER(kMemcpy, Memcpy, + ffi::Ffi::Bind() + .Ctx() + .Arg() // src + .Arg() // dst +); +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$memcpy", PLATFORM, + kMemcpy); + +TEST_F(AddressComputationFusionTest, CustomCallSimple) { + XlaBuilder b(TestName()); + CustomCall(&b, "__xla_test$$memcpy", + /*operands=*/ + {Slice(Broadcast(ConstantR0WithType(&b, F32, 42.0), {256}), {0}, + {128}, {1})}, + ShapeUtil::MakeShape(F32, {128}), /*opaque=*/"", + /*has_side_effect=*/false, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, + /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); + xla::HloModuleConfig hlo_config( + xla::ProgramShape(computation.proto().host_program_shape()), + /*ignore_layouts=*/false); + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_address_computation_fusion(false); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + debug_options.set_xla_gpu_enable_address_computation_fusion(true); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt), + error_spec, + /*run_hlo_passes=*/false)); +} + +static absl::Status SubBuffers(const ServiceExecutableRunOptions* run_options, + ffi::BufferBase src0, ffi::BufferBase src1, + ffi::BufferBase src2, ffi::BufferBase src3, + ffi::BufferBase src4, ffi::BufferBase dst0, + ffi::BufferBase dst1, ffi::BufferBase dst2, + ffi::BufferBase dst3, ffi::BufferBase dst4) { + // src0: param 0 at tuple index {0}, shape f32[128] + // src1: param 0 at tuple index {1}, shape f32[256] + // src2: param 1 at tuple index {0}, shape f32[1024] + // src3: param 1 at tuple index {1}, shape f32[8] + // src4: param 2, shape f32[4,8] + // + // dst0: result at tuple index {0}, shape f32[8] + // dst1: result at tuple index {1, 0}, shape f32[128] + // dst2: result at tuple index {1, 1}, shape f32[256] + // dst3: result at tuple index {2}, shape f32[1024] + // dst4: result at tuple index {3}, shape f32[4,8] + + TF_RETURN_IF_ERROR(run_options->stream()->MemcpyD2D(&dst0.data, src3.data, + 8 * sizeof(float))); + TF_RETURN_IF_ERROR(run_options->stream()->MemcpyD2D(&dst1.data, src0.data, + 128 * sizeof(float))); + TF_RETURN_IF_ERROR(run_options->stream()->MemcpyD2D(&dst2.data, src1.data, + 256 * sizeof(float))); + TF_RETURN_IF_ERROR(run_options->stream()->MemcpyD2D(&dst3.data, src2.data, + 1024 * sizeof(float))); + TF_RETURN_IF_ERROR(run_options->stream()->MemcpyD2D(&dst4.data, src4.data, + 4 * 8 * sizeof(float))); + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER(kSubBuffers, SubBuffers, + ffi::Ffi::Bind() + .Ctx() + .Arg() // src0 + .Arg() // src1 + .Arg() // src2 + .Arg() // src3 + .Arg() // src4 + .Arg() // dst0 + .Arg() // dst1 + .Arg() // dst2 + .Arg() // dst3 + .Arg() // dst4 +); +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$subbuffers", + PLATFORM, kSubBuffers); + +TEST_F(AddressComputationFusionTest, CustomCallWithTuple) { + XlaBuilder b(TestName()); + CustomCall(&b, "__xla_test$$subbuffers", /*operands=*/ + { + Tuple(&b, + { + Broadcast(ConstantR0WithType(&b, F32, 1), {128}), + Broadcast(ConstantR0WithType(&b, F32, 2), {256}), + }), + Tuple(&b, + { + Broadcast(ConstantR0WithType(&b, F32, 3), {1024}), + Broadcast(ConstantR0WithType(&b, F32, 4), {8}), + }), + Slice(Broadcast(ConstantR0WithType(&b, F32, 5), {8, 8}), + {0, 0}, {4, 8}, {1, 1}), + }, + ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {8}), + ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {128}), + ShapeUtil::MakeShape(F32, {256}), + }), + ShapeUtil::MakeShape(F32, {1024}), + ShapeUtil::MakeShape(F32, {4, 8}), + }), + /*opaque=*/"", + /*has_side_effect=*/false, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, + /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); + xla::HloModuleConfig hlo_config( + xla::ProgramShape(computation.proto().host_program_shape()), + /*ignore_layouts=*/false); + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_address_computation_fusion(false); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + debug_options.set_xla_gpu_enable_address_computation_fusion(true); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt), + error_spec, + /*run_hlo_passes=*/false)); +} + +static absl::Status NoOp(const ServiceExecutableRunOptions* run_options, + ffi::BufferBase operand) { + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER(kNoOp, NoOp, + ffi::Ffi::Bind() + .Ctx() + .Arg() // operand +); +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$noop", PLATFORM, + kNoOp); + +TEST_F(AddressComputationFusionTest, NilTuple) { + XlaBuilder b(TestName()); + CustomCall(&b, "__xla_test$$noop", + /*operands=*/ + {Slice(Broadcast(ConstantR0WithType(&b, F32, 42.0), {256}), {0}, + {128}, {1})}, + ShapeUtil::MakeNil(), + /*opaque=*/"", + /*has_side_effect=*/false, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, + /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); + xla::HloModuleConfig hlo_config( + xla::ProgramShape(computation.proto().host_program_shape()), + /*ignore_layouts=*/false); + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_address_computation_fusion(false); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + debug_options.set_xla_gpu_enable_address_computation_fusion(true); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt), + error_spec, + /*run_hlo_passes=*/false)); +} + +void Callback_Memcpy(se::gpu::GpuStreamHandle stream, void** buffers, + const char* /*opaque*/, size_t /*opaque_len*/) { + void* src = buffers[0]; + void* dst = buffers[1]; + auto err = gpuMemcpyAsync(dst, src, /*count=*/sizeof(float) * 128, + gpuMemcpyDeviceToDevice, stream); + ASSERT_EQ(err, gpuSuccess); +} + +XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Memcpy, PLATFORM); + +TEST_F(AddressComputationFusionTest, CustomCallLegacyAPI) { + XlaBuilder b(TestName()); + CustomCall(&b, "Callback_Memcpy", + /*operands=*/ + {Slice(Broadcast(ConstantR0WithType(&b, F32, 42.0), {256}), {0}, + {128}, {1})}, + ShapeUtil::MakeShape(F32, {128}), /*opaque=*/""); + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); + xla::HloModuleConfig hlo_config( + xla::ProgramShape(computation.proto().host_program_shape()), + /*ignore_layouts=*/false); + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_address_computation_fusion(false); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + debug_options.set_xla_gpu_enable_address_computation_fusion(true); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt), + error_spec, + /*run_hlo_passes=*/false)); +} + +void Callback_Void(se::gpu::GpuStreamHandle /*stream*/, void** /*buffers*/, + const char* /*opaque*/, size_t /*opaque_len*/) {} + +XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Void, PLATFORM); + +TEST_F(AddressComputationFusionTest, NilTupleLegacyAPI) { + XlaBuilder b(TestName()); + CustomCall(&b, "Callback_Void", /*operands=*/ + {Slice(Broadcast(ConstantR0WithType(&b, F32, 42.0), {256}), {0}, + {128}, {1})}, + ShapeUtil::MakeNil(), + /*opaque=*/""); + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); + xla::HloModuleConfig hlo_config( + xla::ProgramShape(computation.proto().host_program_shape()), + /*ignore_layouts=*/false); + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_address_computation_fusion(false); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); - auto a0 = LiteralUtil::CreateFromArray(arr0); - auto a1 = LiteralUtil::CreateFromArray(arr1); + debug_options.set_xla_gpu_enable_address_computation_fusion(true); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, {&a0, &a1}, error_spec, + EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt), + error_spec, /*run_hlo_passes=*/false)); } diff --git a/third_party/xla/xla/service/gpu/fusions/concatenate.h b/third_party/xla/xla/service/gpu/fusions/concatenate.h index 5bd77cdcf66129..3fdd3b2878fb2e 100644 --- a/third_party/xla/xla/service/gpu/fusions/concatenate.h +++ b/third_party/xla/xla/service/gpu/fusions/concatenate.h @@ -38,6 +38,13 @@ class ConcatenateFusion : public KernelFusionEmitterBase { std::optional ComputeThreadIdToOutputIndexing( int64_t output_id, mlir::MLIRContext* ctx) const override; + std::optional ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const override { + // TODO(b/319081342): Implement this. + return std::nullopt; + } + protected: absl::Status EmitKernel(IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion, diff --git a/third_party/xla/xla/service/gpu/fusions/copy.cc b/third_party/xla/xla/service/gpu/fusions/copy.cc index eeb934c82085f4..eaa7a515cf8e8f 100644 --- a/third_party/xla/xla/service/gpu/fusions/copy.cc +++ b/third_party/xla/xla/service/gpu/fusions/copy.cc @@ -17,10 +17,9 @@ limitations under the License. #include #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/ir_emitter_context.h" -#include "xla/service/gpu/runtime3/copy_thunk.h" +#include "xla/service/gpu/runtime/copy_thunk.h" #include "xla/service/gpu/thunk.h" #include "xla/statusor.h" @@ -28,23 +27,16 @@ namespace xla { namespace gpu { absl::StatusOr MemcpyFusion::Emit( - IrEmitterContext& ir_emitter_context, mlir::lmhlo::FusionOp fusion_op, + IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion) const { FusionEmissionResult result; for (int i = 0; i < src_buffers_.size(); ++i) { if (src_buffers_[i] != dst_buffers_[i]) { result.thunks.emplace_back(std::make_unique( - ir_emitter_context.emit_ir_from_hlo() - ? Thunk::ThunkInfo::WithProfileAnnotation(&fusion) - : Thunk::ThunkInfo::WithProfileAnnotation(fusion_op), + Thunk::ThunkInfo::WithProfileAnnotation(&fusion), /*source_buffer=*/src_buffers_[i], /*destination_buffer=*/dst_buffers_[i], - /*mem_size=*/src_buffers_[i].size(), - /*source_value=*/ir_emitter_context.emit_ir_from_hlo() ? nullptr - : srcs_[i], - /*destination_value=*/ir_emitter_context.emit_ir_from_hlo() - ? nullptr - : dsts_[i])); + /*mem_size=*/src_buffers_[i].size())); } } return result; diff --git a/third_party/xla/xla/service/gpu/fusions/copy.h b/third_party/xla/xla/service/gpu/fusions/copy.h index d51c133bf20764..574f1eb454271a 100644 --- a/third_party/xla/xla/service/gpu/fusions/copy.h +++ b/third_party/xla/xla/service/gpu/fusions/copy.h @@ -38,7 +38,7 @@ class MemcpyFusion : public FusionInterface { dsts_(std::move(dsts)) {} absl::StatusOr Emit( - IrEmitterContext& ir_emitter_context, mlir::lmhlo::FusionOp fusion_op, + IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion) const final; private: diff --git a/third_party/xla/xla/service/gpu/fusions/custom.cc b/third_party/xla/xla/service/gpu/fusions/custom.cc index 437ea7d5f4fa2d..fa910bc2589cf6 100644 --- a/third_party/xla/xla/service/gpu/fusions/custom.cc +++ b/third_party/xla/xla/service/gpu/fusions/custom.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include #include +#include +#include #include #include #include @@ -26,13 +28,20 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/AsmParser/AsmParser.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/ffi_api.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "xla/service/buffer_assignment.h" +#include "xla/service/custom_call_status.h" +#include "xla/service/custom_call_target_registry.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/fusions/fusion_emitter.h" @@ -44,13 +53,15 @@ limitations under the License. #include "xla/service/gpu/kernels/custom_kernel.h" #include "xla/service/gpu/kernels/custom_kernel_fusion.h" #include "xla/service/gpu/matmul_utils.h" -#include "xla/service/gpu/runtime3/gemm_thunk.h" -#include "xla/service/gpu/runtime3/kernel_thunk.h" +#include "xla/service/gpu/runtime/custom_call_thunk.h" +#include "xla/service/gpu/runtime/gemm_thunk.h" +#include "xla/service/gpu/runtime/kernel_thunk.h" #include "xla/service/gpu/thunk.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/status_macros.h" -#include "xla/statusor.h" +#include "xla/status.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace xla { @@ -59,47 +70,104 @@ namespace { absl::StatusOr> BuildCustomKernelThunkForFusion( IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion, - mlir::lmhlo::FusionOp fusion_op, CustomKernel custom_kernel) { - TF_ASSIGN_OR_RETURN(auto kernel_arguments, - ir_emitter_context.emit_ir_from_hlo() - ? KernelArguments::Create( - ir_emitter_context.buffer_assignment(), &fusion) - : KernelArguments::Create( - ir_emitter_context.allocations(), fusion_op)); - - std::variant instr; - if (ir_emitter_context.emit_ir_from_hlo()) { - instr = &fusion; - } else { - instr = fusion_op; - } + CustomKernel custom_kernel) { + TF_ASSIGN_OR_RETURN( + auto kernel_arguments, + KernelArguments::Create(ir_emitter_context.buffer_assignment(), &fusion)); return std::make_unique( - instr, std::move(custom_kernel), std::move(kernel_arguments.args())); + &fusion, std::move(custom_kernel), std::move(kernel_arguments.args())); +} + +// TODO(vuson): this is duplicated from ir_emitter_unnested.cc +// Converts MLIR dictionary attribute attached to a custom call operation to a +// custom call thunk attributes that are forwarded to the FFI handler. +static absl::StatusOr BuildAttributesMap( + mlir::DictionaryAttr dict) { + CustomCallThunk::AttributesMap attributes; + for (auto& kv : dict) { + std::string_view name = kv.getName().strref(); + + auto integer = [&](mlir::IntegerAttr integer) { + switch (integer.getType().getIntOrFloatBitWidth()) { + case 32: + attributes[name] = static_cast(integer.getInt()); + return absl::OkStatus(); + case 64: + attributes[name] = static_cast(integer.getInt()); + return absl::OkStatus(); + default: + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported integer attribute bit width for attribute: ", name)); + } + }; + + auto fp = [&](mlir::FloatAttr fp) { + switch (fp.getType().getIntOrFloatBitWidth()) { + case 32: + attributes[name] = static_cast(fp.getValue().convertToFloat()); + return absl::OkStatus(); + default: + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported float attribute bit width for attribute: ", name)); + } + }; + + auto str = [&](mlir::StringAttr str) { + attributes[name] = str.getValue().str(); + return absl::OkStatus(); + }; + + TF_RETURN_IF_ERROR( + llvm::TypeSwitch(kv.getValue()) + .Case(integer) + .Case(fp) + .Case(str) + .Default([&](mlir::Attribute) { + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported attribute type for attribute: ", name)); + })); + } + return attributes; } absl::StatusOr GetSliceWithUpdatedOffsetAndSize( const BufferAssignment& buffer_assignment, const HloFusionAdaptor& fusion, - const HloInstruction* bufferized_instr, const HloInstruction& start) { - TF_ASSIGN_OR_RETURN( - BufferAllocation::Slice orig_slice, - GetAllocationSlice(buffer_assignment, bufferized_instr, {})); + const HloInstruction& fusion_instr, const HloInstruction& start, + const ShapeIndex& index) { + if (const auto* param = DynCast(&start)) { + return GetAllocationSlice(buffer_assignment, + fusion_instr.operand(param->parameter_number()), + index); + } - auto maybe_slice_adaptor = + auto slice_adaptor = HloFindIf({HloInstructionAdaptor(start)}, fusion, [](auto node) { return node.opcode() == HloOpcode::kSlice; }); - if (maybe_slice_adaptor == std::nullopt) return orig_slice; + if (!slice_adaptor.has_value()) { + return absl::InternalError( + "AddressComputationFusion expects at least one sliced operand"); + } - const auto& slice_instr = *static_cast( - &maybe_slice_adaptor->instruction()); + const auto& slice_instr = + *static_cast(&slice_adaptor->instruction()); - TF_RET_CHECK(IsContiguousSlice(slice_instr)) - << "AddressComputationFusion only handles contiguous slices currently"; + if (!IsContiguousSlice(slice_instr)) { + return absl::InternalError( + "AddressComputationFusion only handles contiguous slices currently"); + } const Shape& src_shape = slice_instr.operand(0)->shape(); const Shape& dst_shape = slice_instr.shape(); int64_t size = ShapeUtil::ByteSizeOf(dst_shape); + const auto* param = Cast(slice_instr.operand(0)); + TF_ASSIGN_OR_RETURN( + BufferAllocation::Slice orig_slice, + GetAllocationSlice(buffer_assignment, + fusion_instr.operand(param->parameter_number()), + index)); + // Given this slice // f16[1,4,8]{2,1,0} slice(f16[2,8,8]{2,1,0}), // slice={[1:2], [4:8], [0:8]} @@ -107,7 +175,7 @@ absl::StatusOr GetSliceWithUpdatedOffsetAndSize( // The offset of the slice should be: // slice_starts(0) * 8 * 8 * sizeof(f16) + // slice_starts(1) * 8 * sizeof(f16) - int64_t offset = 0; + int64_t offset = orig_slice.offset(); for (auto [start, stride] : llvm::zip(slice_instr.slice_starts(), *ShapeUtil::ByteStrides(src_shape))) { offset += start * stride; @@ -116,10 +184,217 @@ absl::StatusOr GetSliceWithUpdatedOffsetAndSize( return BufferAllocation::Slice(orig_slice.allocation(), offset, size); } +absl::StatusOr EmitGemm( + IrEmitterContext& ir_emitter_context, const HloFusionAdaptor& adaptor, + const HloFusionInstruction& fusion, + const HloCustomCallInstruction& custom_call) { + const BufferAssignment& buffer_assignment = + ir_emitter_context.buffer_assignment(); + + TF_ASSIGN_OR_RETURN( + BufferAllocation::Slice lhs_slice, + GetSliceWithUpdatedOffsetAndSize(buffer_assignment, adaptor, fusion, + *custom_call.operand(0), /*index=*/{})); + + TF_ASSIGN_OR_RETURN( + BufferAllocation::Slice rhs_slice, + GetSliceWithUpdatedOffsetAndSize(buffer_assignment, adaptor, fusion, + *custom_call.operand(1), /*index=*/{})); + + BufferAllocation::Slice output; + std::optional workspace; + + // Result of a legacy cuBLAS custom call can be a tuple if we explicitly + // allocate workspace buffer in HLO. If result is an array, it means that + // workspace is not available, and cuBLAS will allocate its own workspace. + if (custom_call.shape().IsArray()) { + TF_ASSIGN_OR_RETURN(output, + GetAllocationSlice(buffer_assignment, &fusion, {})); + } else { + TF_ASSIGN_OR_RETURN(output, + GetAllocationSlice(buffer_assignment, &fusion, {0})); + TF_ASSIGN_OR_RETURN(workspace, + GetAllocationSlice(buffer_assignment, &fusion, {1})); + } + + bool deterministic_ops = + ir_emitter_context.debug_options().xla_gpu_deterministic_ops(); + + TF_ASSIGN_OR_RETURN( + GemmConfig config, + GemmConfig::For(static_cast(&custom_call))); + auto thunk = std::make_unique( + Thunk::ThunkInfo::WithProfileAnnotation(&custom_call), std::move(config), + lhs_slice, rhs_slice, output, workspace, deterministic_ops); + + FusionEmissionResult result; + result.thunks.push_back(std::move(thunk)); + return result; +} + +absl::StatusOr EmitCustomCall( + IrEmitterContext& ir_emitter_context, const HloFusionAdaptor& adaptor, + const HloFusionInstruction& fusion, + const HloCustomCallInstruction& custom_call) { + const BufferAssignment& buffer_assignment = + ir_emitter_context.buffer_assignment(); + + const std::string call_target_name = custom_call.custom_call_target(); + + // Typed FFI custom calls is a replacement for legacy custom calls with + // a rich type safe API. It's under construction and not fully supported. + bool is_ffi_custom_call = + custom_call.api_version() == CustomCallApiVersion::API_VERSION_TYPED_FFI; + + void* call_target = CustomCallTargetRegistry::Global()->Lookup( + call_target_name, std::string(ir_emitter_context.platform_name())); + + absl::StatusOr handler = + ffi::FindHandler(call_target_name, ir_emitter_context.platform_name()); + + // At least one implementation should be available at run time. + bool found_custom_call = !is_ffi_custom_call && call_target != nullptr; + bool found_ffi_handler = is_ffi_custom_call && handler.ok(); + + if (!found_custom_call && !found_ffi_handler) { + return absl::InternalError( + "AddressComputationFusion expects custom calls that are emittable as " + "thunks"); + } + + using Slices = std::vector>; + + Slices operands; + // TODO(vuson): add test with custom call with tuple-typed operands + for (auto* operand : custom_call.operands()) { + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + operand->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (subshape.IsToken()) { + operands.push_back(std::nullopt); + return absl::OkStatus(); + } + if (!subshape.IsArray()) { + return absl::OkStatus(); + } + TF_ASSIGN_OR_RETURN(auto slice, GetSliceWithUpdatedOffsetAndSize( + buffer_assignment, adaptor, + fusion, *operand, index)); + operands.push_back(CustomCallThunk::Slice{slice, subshape}); + return absl::OkStatus(); + })); + } + + Slices results; + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + fusion.shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (subshape.IsToken()) { + results.push_back(std::nullopt); + return absl::OkStatus(); + } + if (!subshape.IsArray()) { + return absl::OkStatus(); + } + TF_ASSIGN_OR_RETURN( + auto slice, GetAllocationSlice(buffer_assignment, &fusion, index)); + results.push_back(CustomCallThunk::Slice{slice, subshape}); + return absl::OkStatus(); + })); + + // For legacy custom calls we convert all API versions into the latest + // status-returning one and pass backend config as an opaque string. + CustomCallThunk::CustomCallTarget custom_call_target; + std::string opaque; + + // For XLA FFI handlers we decode opaque backend config into attributes map + // at IR emission time, so that we do not need to parse MLIR at run time. For + // FFI handlers backend config must be a compatible MLIR dictionary. + CustomCallThunk::AttributesMap attributes; + + // For information about this calling convention, see + // xla/g3doc/custom_call.md. + switch (custom_call.api_version()) { + case CustomCallApiVersion::API_VERSION_ORIGINAL: + using original_call_type = + void (*)(CustomCallThunk::Stream /*stream*/, void** /*buffers*/, + const char* /*opaque*/, size_t /*opaque_len*/); + custom_call_target = [call_target](CustomCallThunk::Stream stream, + void** buffers, const char* opaque, + size_t opaque_len, + XlaCustomCallStatus*) { + auto typed_call_target = + reinterpret_cast(call_target); + typed_call_target(stream, buffers, opaque, opaque_len); + }; + break; + case CustomCallApiVersion::API_VERSION_STATUS_RETURNING: + case CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED: + using status_returning_call_type = + void (*)(CustomCallThunk::Stream /*stream*/, void** /*buffers*/, + const char* /*opaque*/, size_t /*opaque_len*/, + XlaCustomCallStatus* /*status*/); + custom_call_target = + reinterpret_cast(call_target); + break; + case CustomCallApiVersion::API_VERSION_TYPED_FFI: + // We already checked `handler` above. + break; + default: + return Internal("Unknown custom-call API version enum value: %d", + custom_call.api_version()); + } + + auto& backend_config_str = custom_call.raw_backend_config_string(); + switch (custom_call.api_version()) { + case CustomCallApiVersion::API_VERSION_ORIGINAL: + case CustomCallApiVersion::API_VERSION_STATUS_RETURNING: + case CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED: + if (!backend_config_str.empty()) { + opaque = backend_config_str; + } + break; + + case CustomCallApiVersion::API_VERSION_TYPED_FFI: + if (!backend_config_str.empty()) { + mlir::Attribute attr = mlir::parseAttribute( + backend_config_str, ir_emitter_context.mlir_context()); + if (auto dict = attr.dyn_cast_or_null()) { + TF_ASSIGN_OR_RETURN(attributes, BuildAttributesMap(dict)); + break; + } + return absl::InternalError( + "Unsupported backend config. Expected a string parsable into " + "dictionary attribute"); + } + break; + + default: + return Internal("Unknown custom-call API version enum value: %d", + custom_call.api_version()); + } + + auto ffi_thunk = [&] { + auto& called_computations = custom_call.called_computations(); + return std::make_unique( + Thunk::ThunkInfo::WithProfileAnnotation(&custom_call), *handler, + std::move(operands), std::move(results), std::move(attributes), + called_computations.empty() ? nullptr : called_computations[0]); + }; + + auto legacy_thunk = [&] { + return std::make_unique( + Thunk::ThunkInfo::WithProfileAnnotation(&custom_call), + std::move(custom_call_target), std::move(operands), std::move(results), + std::move(opaque)); + }; + FusionEmissionResult result; + result.thunks.push_back(found_ffi_handler ? ffi_thunk() : legacy_thunk()); + return result; +} + } // namespace absl::StatusOr CustomFusion::Emit( - IrEmitterContext& ir_emitter_context, mlir::lmhlo::FusionOp fusion_op, + IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion) const { TF_ASSIGN_OR_RETURN(auto gpu_config, fusion.backend_config()); @@ -159,9 +434,9 @@ absl::StatusOr CustomFusion::Emit( return absl::InternalError("Expected exactly one custom kernel"); } - TF_ASSIGN_OR_RETURN(auto thunk, BuildCustomKernelThunkForFusion( - ir_emitter_context, fusion, fusion_op, - std::move(kernels[0]))); + TF_ASSIGN_OR_RETURN( + auto thunk, BuildCustomKernelThunkForFusion(ir_emitter_context, fusion, + std::move(kernels[0]))); FusionEmissionResult result; result.thunks.push_back(std::move(thunk)); @@ -169,66 +444,25 @@ absl::StatusOr CustomFusion::Emit( } absl::StatusOr AddressComputationFusion::Emit( - IrEmitterContext& ir_emitter_context, mlir::lmhlo::FusionOp fusion_op, + IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion) const { - const BufferAssignment& buffer_assignment = - ir_emitter_context.buffer_assignment(); - const HloFusionAdaptor& adaptor = analysis_.fusion(); auto maybe_custom_call_adaptor = HloFindIf( adaptor.GetRoots(), adaptor, [](auto node) { return node.opcode() == HloOpcode::kCustomCall; }); - TF_RET_CHECK(maybe_custom_call_adaptor != std::nullopt) - << "AddressComputationFusion requires a CustomCall hero"; + if (maybe_custom_call_adaptor == std::nullopt) { + return absl::InternalError( + "AddressComputationFusion requires a CustomCall hero"); + } const auto& custom_call = *static_cast( &maybe_custom_call_adaptor->instruction()); + // TODO(vuson): these Emit* are mostly duplicated from ir_emitter_unnested if (IsLegacyCublasMatmul(custom_call)) { - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice lhs_slice, - GetSliceWithUpdatedOffsetAndSize( - buffer_assignment, adaptor, fusion.operand(0), - *custom_call.operand(0))); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice rhs_slice, - GetSliceWithUpdatedOffsetAndSize( - buffer_assignment, adaptor, fusion.operand(1), - *custom_call.operand(1))); - - BufferAllocation::Slice output; - std::optional workspace; - - // Result of a legacy cuBLAS custom call can be a tuple if we explicitly - // allocate workspace buffer in HLO. If result is an array, it means that - // workspace is not available, and cuBLAS will allocate its own workspace. - if (custom_call.shape().IsArray()) { - TF_ASSIGN_OR_RETURN(output, - GetAllocationSlice(buffer_assignment, &fusion, {})); - } else { - TF_ASSIGN_OR_RETURN(output, - GetAllocationSlice(buffer_assignment, &fusion, {0})); - TF_ASSIGN_OR_RETURN(workspace, - GetAllocationSlice(buffer_assignment, &fusion, {1})); - } - - bool deterministic_ops = - ir_emitter_context.debug_options().xla_gpu_deterministic_ops(); - - TF_ASSIGN_OR_RETURN( - GemmConfig config, - GemmConfig::For(static_cast(&custom_call))); - auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(&custom_call), - std::move(config), lhs_slice, rhs_slice, output, workspace, - deterministic_ops); - - FusionEmissionResult result; - result.thunks.push_back(std::move(thunk)); - return result; + return EmitGemm(ir_emitter_context, adaptor, fusion, custom_call); } - return absl::UnimplementedError( - absl::StrCat("No emission for AddressComputationFusion of custom call ", - custom_call.custom_call_target())); + return EmitCustomCall(ir_emitter_context, adaptor, fusion, custom_call); } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/fusions/custom.h b/third_party/xla/xla/service/gpu/fusions/custom.h index f227e270d82d0e..e5f763027f754b 100644 --- a/third_party/xla/xla/service/gpu/fusions/custom.h +++ b/third_party/xla/xla/service/gpu/fusions/custom.h @@ -16,7 +16,6 @@ limitations under the License. #define XLA_SERVICE_GPU_FUSIONS_CUSTOM_H_ #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emitter_context.h" @@ -31,7 +30,7 @@ namespace gpu { class CustomFusion : public FusionInterface { public: absl::StatusOr Emit( - IrEmitterContext& ir_emitter_context, mlir::lmhlo::FusionOp fusion_op, + IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion) const final; }; @@ -57,7 +56,7 @@ class AddressComputationFusion : public FusionInterface { : analysis_(analysis) {} absl::StatusOr Emit( - IrEmitterContext& ir_emitter_context, mlir::lmhlo::FusionOp fusion_op, + IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion) const final; private: diff --git a/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc index a6818e7107b3a3..34eb8e0a1fdad9 100644 --- a/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -39,6 +40,7 @@ limitations under the License. #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Metadata.h" +#include "llvm/TargetParser/Triple.h" #include "mlir/IR/AffineExpr.h" // from @llvm-project #include "mlir/IR/AffineMap.h" // from @llvm-project #include "xla/hlo/ir/hlo_instructions.h" @@ -48,17 +50,18 @@ limitations under the License. #include "xla/service/gpu/kernel_arguments.h" #include "xla/service/gpu/kernel_reuse_cache.h" #include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" -#include "xla/service/gpu/runtime3/kernel_thunk.h" +#include "xla/service/gpu/runtime/kernel_thunk.h" #include "xla/service/gpu/target_util.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/status.h" #include "xla/status_macros.h" #include "xla/statusor.h" #include "xla/stream_executor/device_description.h" +#include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -82,6 +85,8 @@ void AnnotateWithInt32Value(std::string name, int64_t value, llvm::IntegerType::get(llvm_context, /*NumBits=*/32), value))})); } +} // namespace + // Annotates the launch dimensions of the corresponding IR kernel in // `llvm_module`. absl::Status AnnotateKernelLaunchDimensions( @@ -108,13 +113,12 @@ absl::Status AnnotateKernelLaunchDimensions( AnnotateWithInt32Value("reqntidz", launch_dims.thread_counts_per_block().z, kernel_name, llvm_module); } - + // Maybe we want to set "reqnctapercluster" here, but not sure if needed or if + // LLVM supports that yet. Let's do that later when needed. return absl::OkStatus(); } -} // namespace - -mlir::AffineMap KernelFusionInterface::GetDefaultThreadIdToOutputIndexingMap( +IndexingMap KernelFusionInterface::GetDefaultThreadIdToOutputIndexingMap( const LaunchDimensions& launch_dims, int unroll_factor, const Shape& output_shape, mlir::MLIRContext* ctx) { std::vector output_dims(output_shape.rank()); @@ -140,7 +144,8 @@ mlir::AffineMap KernelFusionInterface::GetDefaultThreadIdToOutputIndexingMap( // This means that this code supports some launch grids that the parallel // loop emitter doesn't support. This is safe, since the latter CHECK fails // if its assumptions are not fulfilled. - mlir::AffineExpr linear_index = mlir::getAffineConstantExpr(0, ctx); + mlir::AffineExpr c0 = mlir::getAffineConstantExpr(0, ctx); + mlir::AffineExpr linear_index = c0; uint64_t stride = 1; for (int i = 0; i < 3; ++i) { auto coord = mlir::getAffineDimExpr(kIndexingMapThreadIdxDims[i], ctx) + @@ -150,11 +155,12 @@ mlir::AffineMap KernelFusionInterface::GetDefaultThreadIdToOutputIndexingMap( linear_index = linear_index + linear_component; stride *= total_sizes[i]; } + mlir::AffineExpr chunk_id = mlir::getAffineSymbolExpr(0, ctx); + mlir::AffineExpr unroll_elem_id = mlir::getAffineSymbolExpr(1, ctx); - if (unroll_factor > 1) { - linear_index = - linear_index * unroll_factor + mlir::getAffineSymbolExpr(0, ctx); - } + linear_index = linear_index * unroll_factor + + chunk_id * unroll_factor * launch_dims.launch_bound() + + unroll_elem_id; // See IndexUtil::LinearIndexToMultidimensionalIndex. uint64_t divisor = 1; @@ -165,13 +171,6 @@ mlir::AffineMap KernelFusionInterface::GetDefaultThreadIdToOutputIndexingMap( divisor *= output_shape.dimensions(dimension); } - return mlir::AffineMap::get(/*dimCount=*/6, - /*symbolCount=*/unroll_factor > 1 ? 1 : 0, - output_dims, ctx); -} - -Domain KernelFusionInterface::GetThreadIdDomain( - const LaunchDimensions& launch_dims, int unroll_factor) { std::vector dimension_ranges = { {0, static_cast(launch_dims.thread_counts_per_block().x) - 1}, {0, static_cast(launch_dims.thread_counts_per_block().y) - 1}, @@ -181,10 +180,26 @@ Domain KernelFusionInterface::GetThreadIdDomain( {0, static_cast(launch_dims.block_counts().z) - 1}, }; std::vector symbol_ranges; - if (unroll_factor > 1) { - symbol_ranges.push_back({0, unroll_factor - 1}); + int64_t num_elements = ShapeUtil::ElementsIn(output_shape); + symbol_ranges.push_back( + {0, CeilOfRatio(num_elements, + static_cast(launch_dims.launch_bound()) * + unroll_factor) - + 1}); + symbol_ranges.push_back({0, unroll_factor - 1}); + IndexingMap indexing_map( + mlir::AffineMap::get(/*dimCount=*/6, + /*symbolCount=*/2, output_dims, ctx), + dimension_ranges, symbol_ranges); + // Remove the unroll_elem_id symbol if unrolling divides num_elements. + if (num_elements % unroll_factor == 0) { + indexing_map.AddConstraint(linear_index.replace({{unroll_elem_id, c0}}), + Range{0, num_elements - unroll_factor}); + } else { + indexing_map.AddConstraint(linear_index, Range{0, num_elements - 1}); } - return Domain(dimension_ranges, symbol_ranges); + indexing_map.Simplify(); + return indexing_map; } absl::StatusOr, @@ -219,9 +234,11 @@ BuildKernelPrototype(IrEmitterContext& ir_emitter_context, // Create the kernel and add it to the module. auto* llvm_module = ir_emitter_context.llvm_module(); llvm::LLVMContext& context = llvm_module->getContext(); + // Explicitly set global addrspace for SPIR backend. + int addrspace = llvm::Triple(llvm_module->getTargetTriple()).isSPIR() ? 1 : 0; llvm::FunctionType* kernel_type = llvm::FunctionType::get( /*Result=*/llvm::Type::getVoidTy(context), - std::vector(kNumLlvmArgs, builder->getPtrTy()), + std::vector(kNumLlvmArgs, builder->getPtrTy(addrspace)), /*isVarArg=*/false); llvm::Function* kernel = llvm::Function::Create(kernel_type, llvm::GlobalValue::ExternalLinkage, @@ -285,22 +302,19 @@ BuildKernelPrototype(IrEmitterContext& ir_emitter_context, } absl::StatusOr KernelFusionEmitterBase::Emit( - IrEmitterContext& ir_emitter_context, mlir::lmhlo::FusionOp fusion_op, + IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion) const { llvm::IRBuilder<> builder(ir_emitter_context.llvm_module()->getContext()); std::string suggested_kernel_name = std::string(fusion.name()); - TF_ASSIGN_OR_RETURN(KernelArguments kernel_arguments, - ir_emitter_context.emit_ir_from_hlo() - ? KernelArguments::Create( - ir_emitter_context.buffer_assignment(), &fusion) - : KernelArguments::Create( - ir_emitter_context.allocations(), fusion_op)); + TF_ASSIGN_OR_RETURN( + KernelArguments kernel_arguments, + KernelArguments::Create(ir_emitter_context.buffer_assignment(), &fusion)); auto* fused_computation = fusion.fused_instructions_computation(); TF_ASSIGN_OR_RETURN(auto result, - EmitInitializers(ir_emitter_context, fusion_op, fusion)); + EmitInitializers(ir_emitter_context, fusion)); auto launch_dims = launch_dimensions(); std::vector inputs, outputs; auto [status_or_entry, cached] = @@ -325,6 +339,7 @@ absl::StatusOr KernelFusionEmitterBase::Emit( // TODO(jreiffers): Return shmem_bytes from EmitKernel when // converting the Triton emitters to this infrastructure. return KernelReuseCache::Entry{kernel->getName().str(), launch_dims, + /*cluster_dim=*/std::nullopt, /*shmem_bytes=*/0}; }); TF_ASSIGN_OR_RETURN(const KernelReuseCache::Entry* entry, status_or_entry); @@ -334,15 +349,9 @@ absl::StatusOr KernelFusionEmitterBase::Emit( << entry->kernel_name; } - if (ir_emitter_context.emit_ir_from_hlo()) { - result.thunks.emplace_back(std::make_unique( - &fusion, entry->kernel_name, kernel_arguments.args(), launch_dims, - entry->shmem_bytes)); - } else { - result.thunks.emplace_back(std::make_unique( - fusion_op, entry->kernel_name, kernel_arguments.args(), launch_dims, - entry->shmem_bytes)); - } + result.thunks.emplace_back(std::make_unique( + &fusion, entry->kernel_name, kernel_arguments.args(), launch_dims, + entry->cluster_dim, entry->shmem_bytes)); return result; } diff --git a/third_party/xla/xla/service/gpu/fusions/fusion_emitter.h b/third_party/xla/xla/service/gpu/fusions/fusion_emitter.h index 83bc88554df639..dbc8e8718debe0 100644 --- a/third_party/xla/xla/service/gpu/fusions/fusion_emitter.h +++ b/third_party/xla/xla/service/gpu/fusions/fusion_emitter.h @@ -30,11 +30,11 @@ limitations under the License. #include "mlir/IR/AffineMap.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/kernel_arguments.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_map.h" #include "xla/service/gpu/thunk.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/shape.h" @@ -53,7 +53,7 @@ class FusionInterface { virtual ~FusionInterface() = default; virtual absl::StatusOr Emit( - IrEmitterContext& ir_emitter_context, mlir::lmhlo::FusionOp fusion_op, + IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion) const = 0; }; @@ -65,7 +65,7 @@ class KernelFusionInterface : public FusionInterface { // Returns the fusion's launch dimensions. virtual LaunchDimensions launch_dimensions() const = 0; - // Computes an indexing map from thread to output element(s). + // Computes an indexing map from thread to output element(s) of the **hero**. // // The dimensions in the resulting map are // d0, d1, d2: threadIdx.{x,y,z} @@ -77,7 +77,14 @@ class KernelFusionInterface : public FusionInterface { // unsupported (scatter, in-place DUS). Implementations will return nullopt. // Note: Work in progress, not implemented for all emitters. virtual std::optional ComputeThreadIdToOutputIndexing( - int64_t output_id, mlir::MLIRContext* ctx) const = 0; + int64_t root_index, mlir::MLIRContext* ctx) const = 0; + + // Computes an indexing map from thread to input element(s) of the root's + // **hero**. Note that in many cases this is not computable from the output + // indexing. The indexing may only be known for some operands of the hero. + virtual std::optional ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const = 0; static constexpr std::array kIndexingMapThreadIdxDims = {0, 1, 2}; static constexpr std::array kIndexingMapBlockIdxDims = {3, 4, 5}; @@ -85,14 +92,11 @@ class KernelFusionInterface : public FusionInterface { protected: // Returns the default mapping for the given launch dimensions: linearizes // the thread index and then reshapes it into the output layout. - static mlir::AffineMap GetDefaultThreadIdToOutputIndexingMap( - const LaunchDimensions& launch_dims, int unroll_factor, - const Shape& output_shape, mlir::MLIRContext* ctx); - // Populates the ranges for d0, d1, d2, d3, d4, d5 from the thread counts and // block sizes in the given launch dimensions. - static Domain GetThreadIdDomain(const LaunchDimensions& launch_dims, - int unroll_factor); + static IndexingMap GetDefaultThreadIdToOutputIndexingMap( + const LaunchDimensions& launch_dims, int unroll_factor, + const Shape& output_shape, mlir::MLIRContext* ctx); }; // Base class for fusions that are implemented using a single kernel, which is @@ -100,13 +104,13 @@ class KernelFusionInterface : public FusionInterface { class KernelFusionEmitterBase : public KernelFusionInterface { public: absl::StatusOr Emit( - IrEmitterContext& ir_emitter_context, mlir::lmhlo::FusionOp fusion_op, + IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion) const final; protected: // Creates initializer thunks that need to run before the main kernel. virtual absl::StatusOr EmitInitializers( - IrEmitterContext& ir_emitter_context, mlir::lmhlo::FusionOp fusion_op, + IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion) const { // No initializers by default. return FusionEmissionResult{}; @@ -130,6 +134,11 @@ BuildKernelPrototype(IrEmitterContext& ir_emitter_context, const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* builder); +absl::Status AnnotateKernelLaunchDimensions( + const se::DeviceDescription& device_info, + const LaunchDimensions& launch_dims, const std::string& kernel_name, + llvm::Module* llvm_module); + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/fusions.cc b/third_party/xla/xla/service/gpu/fusions/fusions.cc index ef78788d28c74f..d5617503c1102d 100644 --- a/third_party/xla/xla/service/gpu/fusions/fusions.cc +++ b/third_party/xla/xla/service/gpu/fusions/fusions.cc @@ -36,6 +36,8 @@ limitations under the License. #include "xla/service/gpu/fusions/in_place_dynamic_update_slice.h" #include "xla/service/gpu/fusions/input_slices.h" #include "xla/service/gpu/fusions/loop.h" +#include "xla/service/gpu/fusions/loop_mlir.h" +#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" #include "xla/service/gpu/fusions/reduction.h" #include "xla/service/gpu/fusions/scatter.h" #include "xla/service/gpu/fusions/transpose.h" @@ -186,6 +188,18 @@ absl::StatusOr> GetFusionEmitter( if (auto copy_fusion = fusion_info.GetCopyFusion()) { return *std::move(copy_fusion); } + + if (analysis.fusion_roots() + .front() + ->GetModule() + ->config() + .debug_options() + .xla_gpu_enable_mlir_emitters() && + mlir_converter::IsHloConversionSupported( + analysis.fusion(), + fusion_info.analysis().device_info().gpu_compute_capability())) { + return std::make_unique(analysis); + } return std::make_unique(analysis); } case HloFusionAnalysis::EmitterFusionKind::kReduction: diff --git a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice.h b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice.h index 836750a1cb0816..12be8043b05ec1 100644 --- a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice.h +++ b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice.h @@ -67,12 +67,19 @@ class InPlaceDynamicUpdateSliceFusion : public KernelFusionEmitterBase { LaunchDimensions launch_dimensions() const override; std::optional ComputeThreadIdToOutputIndexing( - int64_t output_id, mlir::MLIRContext* ctx) const override { + int64_t root_index, mlir::MLIRContext* ctx) const override { // The mapping cannot be statically computed in general, since the offsets // are unknown. return std::nullopt; } + std::optional ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const override { + // TODO(b/319081342): Implement this. + return std::nullopt; + } + protected: absl::Status EmitKernel(IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion, diff --git a/third_party/xla/xla/service/gpu/fusions/input_slices.cc b/third_party/xla/xla/service/gpu/fusions/input_slices.cc index 9eac6b96ff08fb..85f661a8f125f5 100644 --- a/third_party/xla/xla/service/gpu/fusions/input_slices.cc +++ b/third_party/xla/xla/service/gpu/fusions/input_slices.cc @@ -34,7 +34,6 @@ limitations under the License. #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/parallel_loop_emitter.h" #include "xla/service/llvm_ir/fused_ir_emitter.h" #include "xla/service/llvm_ir/ir_array.h" @@ -191,11 +190,8 @@ std::optional InputSlicesFusion::ComputeThreadIdToOutputIndexing( // The implementation requires the shapes and layouts to be the same, but we // still use the requested output's shape for clarity. const auto& shape = analysis_.fusion_roots()[output_id]->shape(); - IndexingMap result{GetDefaultThreadIdToOutputIndexingMap( - launch_dims, unroll_factor_, shape, ctx), - GetThreadIdDomain(launch_dims, unroll_factor_)}; - result.Simplify(); - return result; + return GetDefaultThreadIdToOutputIndexingMap(launch_dims, unroll_factor_, + shape, ctx); } absl::Status InputSlicesFusion::EmitKernel( diff --git a/third_party/xla/xla/service/gpu/fusions/input_slices.h b/third_party/xla/xla/service/gpu/fusions/input_slices.h index 0cd3aa7601d685..90f4f4e4a24d03 100644 --- a/third_party/xla/xla/service/gpu/fusions/input_slices.h +++ b/third_party/xla/xla/service/gpu/fusions/input_slices.h @@ -50,6 +50,13 @@ class InputSlicesFusion : public KernelFusionEmitterBase { std::optional ComputeThreadIdToOutputIndexing( int64_t output_id, mlir::MLIRContext* ctx) const override; + std::optional ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const override { + // TODO(b/319081342): Implement this. + return std::nullopt; + } + protected: absl::Status EmitKernel(IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion, diff --git a/third_party/xla/xla/service/gpu/fusions/input_slices_test.cc b/third_party/xla/xla/service/gpu/fusions/input_slices_test.cc index 52de75d7971482..094bbfac7a27a9 100644 --- a/third_party/xla/xla/service/gpu/fusions/input_slices_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/input_slices_test.cc @@ -32,20 +32,13 @@ namespace xla { namespace gpu { namespace { -using ::testing::ElementsAre; -using ::testing::HasSubstr; -using ::testing::IsEmpty; - class InputSlicesTest : public HloTestBase { public: void SetUp() override { HloTestBase::SetUp(); - printer_.SetDimensionName(0, "th_x"); - printer_.SetDimensionName(1, "th_y"); - printer_.SetDimensionName(2, "th_z"); - printer_.SetDimensionName(3, "bl_x"); - printer_.SetDimensionName(4, "bl_y"); - printer_.SetDimensionName(5, "bl_z"); + printer_ = + AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, + {"chunk_id", "unroll_id"}); } protected: @@ -84,17 +77,23 @@ TEST_F(InputSlicesTest, ThreadIndexing) { auto thread_id_to_output_indexing = fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_); - EXPECT_THAT(printer_.ToString(thread_id_to_output_indexing->affine_map), - HasSubstr("(th_x, th_y, th_z, bl_x, bl_y, bl_z) -> " - "(0, " - "((th_x + bl_x * 128) floordiv 3) mod 2, " - "(th_x + bl_x * 128) mod 3, " - "((th_x + bl_x * 128) floordiv 6) mod 5)")); - EXPECT_THAT(thread_id_to_output_indexing->domain, - MatchDomain(ElementsAre(MatchRange(0, 127), MatchRange(0, 0), - MatchRange(0, 0), MatchRange(0, 1), - MatchRange(0, 0), MatchRange(0, 0)), - IsEmpty())); + EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_), + MatchIndexingString(R"( + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (0, + ((th_x + bl_x * 128) floordiv 3) mod 2, + (th_x + bl_x * 128) mod 3, + ((bl_x * 64 + th_x floordiv 2) floordiv 3) mod 5) + domain: + th_x in [0, 127] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 1] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + th_x + bl_x * 128 in [0, 29] + )")); } } // namespace diff --git a/third_party/xla/xla/service/gpu/fusions/loop.cc b/third_party/xla/xla/service/gpu/fusions/loop.cc index e635570d056221..e7a13200fe391f 100644 --- a/third_party/xla/xla/service/gpu/fusions/loop.cc +++ b/third_party/xla/xla/service/gpu/fusions/loop.cc @@ -139,6 +139,8 @@ std::pair RowVectorizationEnabled( num_big_inputs); } +} // namespace + LaunchDimensionsConfig ComputeLoopFusionConfig( const HloFusionAnalysis& analysis) { int unroll_factor = 1; @@ -209,20 +211,36 @@ LaunchDimensionsConfig ComputeLoopFusionConfig( return launch_config; } -} // namespace - LoopFusion::LoopFusion(const HloFusionAnalysis& analysis) : analysis_(analysis), config_(ComputeLoopFusionConfig(analysis)) {} std::optional LoopFusion::ComputeThreadIdToOutputIndexing( - int64_t output_id, mlir::MLIRContext* ctx) const { + int64_t root_index, mlir::MLIRContext* ctx) const { auto launch_dims = launch_dimensions(); - const auto& shape = analysis_.fusion_roots()[output_id]->shape(); - IndexingMap result{GetDefaultThreadIdToOutputIndexingMap( - launch_dims, config_.unroll_factor, shape, ctx), - GetThreadIdDomain(launch_dims, config_.unroll_factor)}; - result.Simplify(); - return result; + return GetDefaultThreadIdToOutputIndexingMap( + launch_dims, config_.unroll_factor, GetElementShape(analysis_), ctx); +} + +std::optional LoopFusion::ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const { + std::optional thread_id_to_output_indexing = + ComputeThreadIdToOutputIndexing(root_index, ctx); + if (!thread_id_to_output_indexing.has_value()) { + return std::nullopt; + } + const HloInstruction* fusion_root = analysis_.fusion_roots()[root_index]; + auto output_to_input_indexing = + ComputeOutputToInputIndexing(fusion_root, /*output_id=*/0, ctx); + IndexingMapSet output_to_input_indexing_set = + output_to_input_indexing.indexing_maps[hero_operand_index]; + // Since we are computing the indexing for a non-fusion op, there is only one + // indexing map per operand. + CHECK_EQ(output_to_input_indexing_set.size(), 1); + IndexingMap thread_id_to_input_indexing_map = ComposeIndexingMaps( + *thread_id_to_output_indexing, *output_to_input_indexing_set.begin()); + thread_id_to_input_indexing_map.Simplify(); + return thread_id_to_input_indexing_map; } absl::Status LoopFusion::EmitKernel(IrEmitterContext& ir_emitter_context, diff --git a/third_party/xla/xla/service/gpu/fusions/loop.h b/third_party/xla/xla/service/gpu/fusions/loop.h index 38112a6a209cbe..e466abe66a843f 100644 --- a/third_party/xla/xla/service/gpu/fusions/loop.h +++ b/third_party/xla/xla/service/gpu/fusions/loop.h @@ -26,7 +26,7 @@ limitations under the License. #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_map.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/status.h" @@ -40,7 +40,11 @@ class LoopFusion : public KernelFusionEmitterBase { LaunchDimensions launch_dimensions() const override; std::optional ComputeThreadIdToOutputIndexing( - int64_t output_id, mlir::MLIRContext* ctx) const override; + int64_t root_index, mlir::MLIRContext* ctx) const override; + + std::optional ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const override; protected: absl::Status EmitKernel(IrEmitterContext& ir_emitter_context, @@ -55,6 +59,9 @@ class LoopFusion : public KernelFusionEmitterBase { LaunchDimensionsConfig config_; }; +LaunchDimensionsConfig ComputeLoopFusionConfig( + const HloFusionAnalysis& analysis); + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc b/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc new file mode 100644 index 00000000000000..3f01ed334f17f6 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc @@ -0,0 +1,170 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/loop_mlir.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" +#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/shape.h" +#include "xla/status_macros.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +const Shape& GetFusionResultShape(const HloFusionAnalysis& analysis) { + const Shape* shape = &analysis.fusion_roots().front()->shape(); + while (shape->IsTuple()) { + shape = &shape->tuple_shapes(0); + } + return *shape; +} + +} // namespace + +std::optional MlirLoopFusion::ComputeThreadIdToOutputIndexing( + int64_t root_index, mlir::MLIRContext* ctx) const { + auto launch_dims = launch_dimensions(); + return GetDefaultThreadIdToOutputIndexingMap( + launch_dims, config_.unroll_factor, GetFusionResultShape(analysis_), ctx); +} + +std::optional MlirLoopFusion::ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const { + std::optional thread_id_to_output_indexing = + ComputeThreadIdToOutputIndexing(root_index, ctx); + if (!thread_id_to_output_indexing.has_value()) { + return std::nullopt; + } + const HloInstruction* fusion_root = analysis_.fusion_roots()[root_index]; + auto output_to_input_indexing = + ComputeOutputToInputIndexing(fusion_root, /*output_id=*/0, ctx); + IndexingMapSet output_to_input_indexing_set = + output_to_input_indexing.indexing_maps[hero_operand_index]; + // Since we are computing the indexing for a non-fusion op, there is only one + // indexing map per operand. + CHECK_EQ(output_to_input_indexing_set.size(), 1); + IndexingMap thread_id_to_input_indexing_map = ComposeIndexingMaps( + *thread_id_to_output_indexing, *output_to_input_indexing_set.begin()); + thread_id_to_input_indexing_map.Simplify(); + return thread_id_to_input_indexing_map; +} + +LaunchDimensions MlirLoopFusion::launch_dimensions() const { + return CalculateLaunchDimensions(GetFusionResultShape(analysis_), + analysis_.device_info(), config_); +} + +absl::Status MlirLoopFusion::EmitMlir( + mlir::ModuleOp module, mlir::func::FuncOp entry_function, + const HloFusionInstruction& fusion) const { + mlir_converter::PartitionedComputations computations( + fusion.fused_instructions_computation()); + + const auto& root_computation = computations.FindPartitionedComputation( + fusion.fused_instructions_computation()); + const auto& root_graph = root_computation.GetRootSubgraph(); + + auto subgraph_to_mlir_fn = computations.DeclareFunctions(module); + subgraph_to_mlir_fn.extract(&root_graph).mapped().erase(); + + auto call_target_lookup = [&](const HloInstruction* instr) { + return subgraph_to_mlir_fn[&computations + .FindPartitionedComputation(instr->parent()) + .FindSubgraph(instr)]; + }; + + for (const auto& comp : computations.partitioned_computations()) { + for (const auto& subgraph : comp.subgraphs()) { + if (&subgraph == &root_graph) { + // We inline the root subgraph. + continue; + } + TF_RETURN_IF_ERROR(mlir_converter::SubgraphToMlirFunction( + comp, subgraph, subgraph_to_mlir_fn[&subgraph], call_target_lookup)); + } + } + + mlir::ImplicitLocOpBuilder builder(entry_function.getLoc(), entry_function); + builder.setInsertionPointToStart(entry_function.addEntryBlock()); + + // We enforce that all the root shapes have identical dimensions in + // IsHloOpSupported. + auto indexing = ComputeThreadIdToOutputIndexing(0, module.getContext()); + TF_RET_CHECK(indexing) << "Indexing is never nullopt"; + + int num_inputs = fusion.fused_instructions_computation()->num_parameters(); + llvm::SmallVector input_tensors( + entry_function.getArguments().take_front(num_inputs)); + auto output_tensor_args = + entry_function.getArguments().drop_front(num_inputs); + + TF_ASSIGN_OR_RETURN( + auto result_tensors, + EmitLoopNest( + builder, output_tensor_args, *indexing, + [&](mlir::ValueRange output_tensors, mlir::ValueRange output_indices) + -> absl::StatusOr> { + llvm::SmallVector args(input_tensors); + absl::c_copy(output_indices, std::back_inserter(args)); + TF_ASSIGN_OR_RETURN( + auto result_scalars, + mlir_converter::SubgraphToMlir( + root_computation, root_graph, call_target_lookup, + input_tensors, output_indices, builder)); + + llvm::SmallVector result_tensors; + result_tensors.reserve(output_tensor_args.size()); + for (auto [tensor, value] : + llvm::zip(output_tensors, result_scalars)) { + result_tensors.push_back(builder + .create( + value, tensor, output_indices) + .getResult()); + } + return result_tensors; + })); + + builder.create(result_tensors); + + return absl::OkStatus(); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/loop_mlir.h b/third_party/xla/xla/service/gpu/fusions/loop_mlir.h new file mode 100644 index 00000000000000..dec08459a6a1f6 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/loop_mlir.h @@ -0,0 +1,60 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_FUSIONS_LOOP_MLIR_H_ +#define XLA_SERVICE_GPU_FUSIONS_LOOP_MLIR_H_ + +#include +#include + +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Interfaces/DataLayoutInterfaces.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/fusions/loop.h" +#include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/status.h" + +namespace xla { +namespace gpu { + +// Generic loop fusion. Lowers to LLVM via MLIR. +class MlirLoopFusion : public MlirFusionEmitterBase { + public: + explicit MlirLoopFusion(const HloFusionAnalysis& analysis) + : analysis_(analysis), config_(ComputeLoopFusionConfig(analysis)) {} + LaunchDimensions launch_dimensions() const override; + + std::optional ComputeThreadIdToOutputIndexing( + int64_t root_index, mlir::MLIRContext* ctx) const override; + + std::optional ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const override; + + protected: + absl::Status EmitMlir(mlir::ModuleOp module, + mlir::func::FuncOp entry_function, + const HloFusionInstruction& fusion) const override; + + private: + const HloFusionAnalysis& analysis_; + LaunchDimensionsConfig config_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_LOOP_MLIR_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc new file mode 100644 index 00000000000000..fd6b212b1f620c --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc @@ -0,0 +1,256 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/loop_mlir.h" + +#include + +#include +#include +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/Dialect/Func/Extensions/InlinerExtension.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project +#include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project +#include "mlir/Dialect/MemRef/Transforms/Passes.h" // from @llvm-project +#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/filecheck.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +class MlirLoopFusionTest : public HloTestBase { + public: + MlirLoopFusionTest() { + context_.loadDialect(); + mlir::DialectRegistry registry; + mlir::func::registerInlinerExtension(registry); + context_.appendDialectRegistry(registry); + } + + stream_executor::DeviceDescription device_info_ = + TestGpuDeviceInfo::RTXA6000DeviceInfo(); + mlir::MLIRContext context_; +}; + +TEST_F(MlirLoopFusionTest, NoCodeDuplication) { + // This test HLO is copied from + // xla/service/fusion_node_indexing_evaluation_test.cc. + auto module = ParseAndReturnVerifiedModule(R"( +HloModule test_module +%fused_computation (param: f32[6]) -> f32[2] { + %param = f32[6]{0} parameter(0) + %slice0.1 = f32[5]{0} slice(f32[6]{0} %param), slice={[0:5]} + %slice0.2 = f32[5]{0} slice(f32[6]{0} %param), slice={[1:6]} + %add0 = f32[5]{0} add(f32[5]{0} %slice0.1, f32[5]{0} %slice0.2) + %slice1.1 = f32[4]{0} slice(f32[5]{0} %add0), slice={[0:4]} + %slice1.2 = f32[4]{0} slice(f32[5]{0} %add0), slice={[1:5]} + %add1 = f32[4]{0} add(f32[4]{0} %slice1.1, f32[4]{0} %slice1.2) + %slice2.1 = f32[3]{0} slice(f32[4]{0} %add1), slice={[0:3]} + %slice2.2 = f32[3]{0} slice(f32[4]{0} %add1), slice={[1:4]} + %add2 = f32[3]{0} add(f32[3]{0} %slice2.1, f32[3]{0} %slice2.2) + %slice3.1 = f32[2]{0} slice(f32[3]{0} %add2), slice={[0:2]} + %slice3.2 = f32[2]{0} slice(f32[3]{0} %add2), slice={[1:3]} + ROOT %add3 = f32[2]{0} add(f32[2]{0} %slice3.1, f32[2]{0} %slice3.2) +} + +ENTRY entry_computation { + p0 = f32[] parameter(0) + add = f32[] add(p0, p0) + broadcast = f32[6]{0} broadcast(add), dimensions={} + ROOT %fusion = f32[2]{0} fusion(broadcast), kind=kLoop, calls=%fused_computation +})") + .value(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + + MlirLoopFusion fusion(analysis); + TF_ASSERT_OK_AND_ASSIGN( + auto mlir_module, + fusion.CreateMLIRModule(context_, *Cast(root), + "fused_computation", nullptr)); + + std::string out; + llvm::raw_string_ostream os(out); + mlir_module->print(os); + ASSERT_TRUE(RunFileCheck(out, R"( +// CHECK-COUNT-4: arith.add +// CHECK-NOT: arith.add +)") + .value()); +} + +TEST_F(MlirLoopFusionTest, TwoUsersConsistentIndexing) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule test_module +%fused_computation (param: f32[6]) -> f32[2] { + %p0 = f32[2]{0} parameter(0) + %p1 = f32[2]{0} parameter(1) + %add = f32[2] add(%p0, %p1) + %sub = f32[2] subtract(%p0, %p1) + %mul = f32[2] multiply(%add, %sub) + %div = f32[2] divide(%add, %sub) + ROOT %atan2 = f32[2] atan2(%mul, %div) +} + +ENTRY entry_computation { + p0 = f32[2] parameter(0) + p1 = f32[2] parameter(1) + ROOT %fusion = f32[2] fusion(p0, p1), kind=kLoop, calls=%fused_computation +})") + .value(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + + MlirLoopFusion fusion(analysis); + TF_ASSERT_OK_AND_ASSIGN( + auto mlir_module, + fusion.CreateMLIRModule(context_, *Cast(root), + "fused_computation", nullptr)); + + std::string out; + llvm::raw_string_ostream os(out); + mlir_module->print(os); + ASSERT_TRUE(RunFileCheck(out, R"( + // CHECK: func.func @fused_computation + // CHECK-NEXT: gpu.thread_id + // CHECK-NEXT: tensor.extract + // CHECK-NEXT: tensor.extract + // CHECK-NEXT: addf + // CHECK-NEXT: subf + // CHECK-NEXT: mulf + // CHECK-NEXT: divf + // CHECK-NEXT: atan2 + // CHECK-NEXT: tensor.insert + )") + .value()); +} + +TEST_F(MlirLoopFusionTest, IotaCopyBitcastBroadcastReshapeReverseTranspose) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule test_module +%fused_computation { + %iota = f32[10,20,30] iota(), iota_dimension=2 + %copy = f32[10,20,30] copy(%iota) + %bitcast = s32[10,20,30] bitcast(%copy) + %broadcast = s32[2,10,3,20,5,30,7] broadcast(%bitcast), dimensions={1,3,5} + %reshape = s32[20,60,150,7] reshape(%broadcast) + %reverse = s32[20,60,150,7] reverse(%reshape), dimensions={2,3} + ROOT %transpose = s32[60,20,7,150] transpose(%reverse), dimensions={1,0,3,2} +} + +ENTRY entry_computation { + ROOT %fusion = s32[60,20,7,150] fusion(), kind=kLoop, calls=%fused_computation +})") + .value(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + + MlirLoopFusion fusion(analysis); + TF_ASSERT_OK_AND_ASSIGN( + auto mlir_module, + fusion.CreateMLIRModule(context_, *Cast(root), + "fused_computation", nullptr)); + + std::string out; + llvm::raw_string_ostream os(out); + mlir_module->print(os); + + ASSERT_TRUE(RunFileCheck(out, R"( + // CHECK-COUNT-1: func.func + // CHECK-NOT: func.func + )") + .value()); +} + +TEST_F(MlirLoopFusionTest, VariadicReduce) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule Test, is_scheduled=true + +Add { + scalar_lhs.0 = f32[] parameter(0) + scalar_rhs.0 = f32[] parameter(1) + scalar_lhs.1 = f32[] parameter(2) + scalar_rhs.1 = f32[] parameter(3) + add.0 = f32[] add(scalar_lhs.0, scalar_lhs.1) + add.1 = f32[] add(scalar_rhs.0, scalar_rhs.1) + ROOT t = (f32[], f32[]) tuple(add.0, add.1) +} + +fused_computation { + param_0 = f32[5,200,300]{2,1,0} parameter(0) + param_1 = f32[5,200,300]{2,1,0} parameter(1) + param_2 = f32[] parameter(2) + ROOT d.1 = (f32[200]{0}, f32[200]{0}) reduce(f32[5,200,300]{2,1,0} param_0, f32[5,200,300]{2,1,0} %param_1, f32[] param_2, f32[] param_2), dimensions={0,2}, to_apply=Add +} + +ENTRY main { + a = f32[5, 200, 300]{2,1,0} parameter(0) + b = f32[5, 200, 300]{2,1,0} parameter(1) + c = f32[] constant(0) + ROOT fusion = (f32[200]{0}, f32[200]{0}) fusion(f32[5,200,300]{2,1,0} a, f32[5,200,300]{2,1,0} b, f32[] c), kind=kLoop, calls=fused_computation +} + )") + .value(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + + MlirLoopFusion fusion(analysis); + TF_ASSERT_OK_AND_ASSIGN( + auto mlir_module, + fusion.CreateMLIRModule(context_, *Cast(root), + "fused_computation", nullptr)); + + std::string out; + llvm::raw_string_ostream os(out); + mlir_module->print(os); + + ASSERT_TRUE(RunFileCheck(out, R"( + // CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> ((s0 + s1 * 128) mod 200)> + // CHECK: func @fused_computation( + // CHECK: %[[TID_X:.*]] = gpu.thread_id x + // CHECK: %[[BID_X:.*]] = gpu.block_id x + // CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[TID_X]], %[[BID_X]]] + // CHECK: %[[RET:.*]]:2 = func.call @Add_t + // CHECK: yield %[[RET]]#0, %[[RET]]#1 + // CHECK: %[[INSERTED_1:.*]] = tensor.insert %{{.*}}#0 into %{{.*}}[%[[IDX]]] + // CHECK: %[[INSERTED_2:.*]] = tensor.insert %{{.*}}#1 into %{{.*}}[%[[IDX]]] + // CHECK: yield %[[INSERTED_1]], %[[INSERTED_2]] +)") + .value()); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/loop_test.cc b/third_party/xla/xla/service/gpu/fusions/loop_test.cc index 1dc59082b8d251..e497b52ebaf843 100644 --- a/third_party/xla/xla/service/gpu/fusions/loop_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/loop_test.cc @@ -35,20 +35,14 @@ namespace xla { namespace gpu { namespace { -using ::testing::ElementsAre; -using ::testing::HasSubstr; -using ::testing::IsEmpty; - class LoopTest : public HloTestBase { public: void SetUp() override { HloTestBase::SetUp(); - printer_.SetDimensionName(0, "th_x"); - printer_.SetDimensionName(1, "th_y"); - printer_.SetDimensionName(2, "th_z"); - printer_.SetDimensionName(3, "bl_x"); - printer_.SetDimensionName(4, "bl_y"); - printer_.SetDimensionName(5, "bl_z"); + + printer_ = + AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, + {"chunk_id", "unroll_id"}); } protected: @@ -58,15 +52,15 @@ class LoopTest : public HloTestBase { mlir::MLIRContext mlir_context_; }; -absl::StatusOr> GetLoopFusion( +absl::StatusOr> GetFusion( const HloFusionAnalysis& analysis) { TF_ASSIGN_OR_RETURN( auto emitter, GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis})); - auto fusion = dynamic_cast(emitter.get()); + auto fusion = dynamic_cast(emitter.get()); TF_RET_CHECK(fusion != nullptr); emitter.release(); - return std::unique_ptr{fusion}; + return std::unique_ptr{fusion}; } TEST_F(LoopTest, ThreadIndexingUnrolled) { @@ -87,19 +81,30 @@ TEST_F(LoopTest, ThreadIndexingUnrolled) { auto* root = module->entry_computation()->root_instruction(); auto analysis = AnalyzeFusion(*root, device_info_); - TF_ASSERT_OK_AND_ASSIGN(auto loop_fusion, GetLoopFusion(analysis)); + TF_ASSERT_OK_AND_ASSIGN(auto loop_fusion, GetFusion(analysis)); auto thread_id_to_output_indexing = - loop_fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_); - EXPECT_THAT(printer_.ToString(thread_id_to_output_indexing->affine_map), - HasSubstr("(th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0] -> (" - "(th_x * 4 + bl_x * 512 + s0) floordiv 60000, " - "((th_x * 4 + bl_x * 512 + s0) floordiv 300) mod 200, " - "(th_x * 4 + bl_x * 512 + s0) mod 300)")); - EXPECT_THAT(thread_id_to_output_indexing->domain, - MatchDomain(ElementsAre(MatchRange(0, 127), MatchRange(0, 0), - MatchRange(0, 0), MatchRange(0, 1007), - MatchRange(0, 0), MatchRange(0, 0)), - ElementsAre(MatchRange(0, 3)))); + loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0, + &mlir_context_); + + EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_), + MatchIndexingString(R"( + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( + (((bl_x * 16 + th_x floordiv 8) floordiv 3 + chunk_id * 5376) floordiv 625) mod 100, + (((th_x + bl_x * 128) floordiv 3 + chunk_id * 43008) floordiv 25) mod 200, + th_x * 4 + bl_x * 512 + chunk_id * 516096 + unroll_id - + (((th_x + bl_x * 128) floordiv 3 + chunk_id * 43008) floordiv 25) * 300 + ) + domain: + th_x in [0, 127] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 1007] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 11] + unroll_id in [0, 3] + (th_x + bl_x * 128) * 4 + chunk_id * 516096 in [0, 5999996] +)")); } TEST_F(LoopTest, ThreadIndexingNotUnrolled) { @@ -120,16 +125,98 @@ TEST_F(LoopTest, ThreadIndexingNotUnrolled) { auto* root = module->entry_computation()->root_instruction(); auto analysis = AnalyzeFusion(*root, device_info_); - TF_ASSERT_OK_AND_ASSIGN(auto loop_fusion, GetLoopFusion(analysis)); + TF_ASSERT_OK_AND_ASSIGN(auto loop_fusion, GetFusion(analysis)); + auto thread_id_to_output_indexing = + loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0, + &mlir_context_); + EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_), + MatchIndexingString(R"( + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x) + domain: + th_x in [0, 19] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 0] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + )")); + auto thread_id_to_input_indexing = + loop_fusion->ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); + EXPECT_THAT(thread_id_to_input_indexing->ToString(printer_), + MatchIndexingString(R"( + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x) + domain: + th_x in [0, 19] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 0] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + )")); +} + +TEST_F(LoopTest, Broadcast) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule module + + bcast { + %input = f32[20] parameter(0) + ROOT bcast = f32[10, 20, 30] broadcast(%input), dimensions={1} + } + + ENTRY entry { + %input = f32[20] parameter(0) + ROOT %fusion = f32[10, 20, 30] fusion(%input), kind=kLoop, calls=bcast + })") + .value(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + + TF_ASSERT_OK_AND_ASSIGN(auto loop_fusion, GetFusion(analysis)); auto thread_id_to_output_indexing = - loop_fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_); - EXPECT_THAT(printer_.ToString(thread_id_to_output_indexing->affine_map), - HasSubstr("(th_x, th_y, th_z, bl_x, bl_y, bl_z) -> (th_x)")); - EXPECT_THAT(thread_id_to_output_indexing->domain, - MatchDomain(ElementsAre(MatchRange(0, 19), MatchRange(0, 0), - MatchRange(0, 0), MatchRange(0, 0), - MatchRange(0, 0), MatchRange(0, 0)), - IsEmpty())); + loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0, + &mlir_context_); + EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_), + MatchIndexingString(R"( + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( + ((bl_x * 16 + th_x floordiv 8) floordiv 75) mod 10, + ((bl_x * 64 + th_x floordiv 2) floordiv 15) mod 20, + (th_x + bl_x * 128) mod 30) + domain: + th_x in [0, 127] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 46] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + th_x + bl_x * 128 in [0, 5999] + )")); + auto thread_id_to_input_indexing = + loop_fusion->ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); + EXPECT_THAT(thread_id_to_input_indexing->ToString(printer_), + MatchIndexingString(R"( + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( + ((bl_x * 64 + th_x floordiv 2) floordiv 15) mod 20) + domain: + th_x in [0, 127] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 46] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + th_x + bl_x * 128 in [0, 5999] + )")); } } // namespace diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/BUILD b/third_party/xla/xla/service/gpu/fusions/mlir/BUILD new file mode 100644 index 00000000000000..570e26f2f3549a --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/mlir/BUILD @@ -0,0 +1,301 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("//xla:xla.bzl", "xla_cc_test") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = [ + "//xla:friends", + ], +) + +cc_library( + name = "computation_partitioner", + srcs = ["computation_partitioner.cc"], + hdrs = ["computation_partitioner.h"], + deps = [ + ":type_util", + "//xla:shape_util", + "//xla:union_find", + "//xla/hlo/ir:hlo", + "//xla/service/llvm_ir:llvm_util", + "//xla/translate/hlo_to_mhlo:hlo_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:TensorDialect", + ], +) + +xla_cc_test( + name = "computation_partitioner_test", + srcs = ["computation_partitioner_test.cc"], + deps = [ + ":computation_partitioner", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "elemental_hlo_to_mlir", + srcs = ["elemental_hlo_to_mlir.cc"], + hdrs = ["elemental_hlo_to_mlir.h"], + deps = [ + ":computation_partitioner", + "//xla:comparison_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla/hlo/ir:hlo", + "//xla/mlir_hlo", + "//xla/mlir_hlo:map_mhlo_to_scalar_op", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu/model:indexing_analysis", + "//xla/service/gpu/model:indexing_map", + "//xla/stream_executor:device_description", + "//xla/translate/hlo_to_mhlo:hlo_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:AffineUtils", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "elemental_hlo_to_mlir_test", + srcs = ["elemental_hlo_to_mlir_test.cc"], + deps = [ + ":computation_partitioner", + ":elemental_hlo_to_mlir", + "//xla:status_macros", + "//xla/hlo/ir:hlo", + "//xla/mlir_hlo", + "//xla/service:hlo_parser", + "//xla/service/llvm_ir:llvm_util", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "mlir_fusion_emitter", + srcs = ["mlir_fusion_emitter.cc"], + hdrs = ["mlir_fusion_emitter.h"], + deps = [ + ":elemental_hlo_to_mlir", + ":passes", + ":type_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla/hlo/ir:hlo", + "//xla/mlir_hlo", + "//xla/service:buffer_assignment", + "//xla/service/gpu:ir_emitter_context", + "//xla/service/gpu:kernel_arguments", + "//xla/service/gpu:target_util", + "//xla/service/gpu/fusions:fusion_emitter", + "//xla/service/gpu/model:indexing_map", + "//xla/service/gpu/runtime:kernel_thunk", + "//xla/service/llvm_ir:llvm_util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:Linker", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:AffineToStandard", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ArithTransforms", + "@llvm-project//mlir:BufferizationDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncExtensions", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MemRefTransforms", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:ToLLVMIRTranslation", + "@llvm-project//mlir:Transforms", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "mlir_fusion_emitter_test", + srcs = ["mlir_fusion_emitter_test.cc"], + deps = [ + ":mlir_fusion_emitter", + "//xla/hlo/ir:hlo", + "//xla/mlir_hlo", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu/model:indexing_map", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:BufferizationDialect", + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:TensorDialect", + "@local_tsl//tsl/platform:statusor", + ], +) + +gentbl_cc_library( + name = "passes_inc_gen", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=GpuFusionTransforms", + ], + "passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "passes.td", + visibility = ["//visibility:private"], + deps = ["@llvm-project//mlir:PassBaseTdFiles"], +) + +cc_library( + name = "passes", + srcs = [ + "expand_float_conversions.cc", + "lower_tensors.cc", + "lower_to_llvm.cc", + "merge_pointers_to_same_slice.cc", + "propagate_slice_indices.cc", + "simplify_affine.cc", + ], + hdrs = ["passes.h"], + deps = [ + ":elemental_hlo_to_mlir", + ":passes_inc_gen", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/mlir_hlo", + "//xla/mlir_hlo:map_mhlo_to_scalar_op", + "//xla/service/gpu/model:indexing_map", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:AffineToStandard", + "@llvm-project//mlir:AffineUtils", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:ArithTransforms", + "@llvm-project//mlir:ComplexDialect", + "@llvm-project//mlir:ComplexToLLVM", + "@llvm-project//mlir:ControlFlowToLLVM", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncToLLVM", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToNVVMTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MathToLLVM", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:VectorTransforms", + ], +) + +cc_library( + name = "type_util", + srcs = ["type_util.cc"], + hdrs = ["type_util.h"], + deps = [ + "//xla:shape_util", + "//xla/translate/hlo_to_mhlo:hlo_utils", + "@com_google_absl//absl/log:check", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + +xla_cc_test( + name = "type_util_test", + srcs = ["type_util_test.cc"], + deps = [ + ":type_util", + "//xla:shape_util", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/README.md b/third_party/xla/xla/service/gpu/fusions/mlir/README.md new file mode 100644 index 00000000000000..d692bd279bce98 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/mlir/README.md @@ -0,0 +1,157 @@ +# XLA MLIR fusion emitters + +This is a prototype of a new loop emitter. The main goals are: + +- Fixing exponential code size issues with the current emitter. We should be + able to generate reasonable code for any fusion (note that execution time may + still be bad, but that's a problem for priority fusion). +- Fixing compile time (as a result of the above). +- Make the code easier to understand thanks to gradual lowering. +- Eventually extend the concepts here to the other emitters (transpose, reduce + in particular) + +## High-level overview + +The code consists of the following big building blocks: + +- Computation partitioning - splitting an HLO computation into functions +- Elemental emission of XLA instructions +- Based on the above two: emission of functions +- The actual emitter +- Lowerings to LLVM + +## Partitioning + +See `computation_partitioner.h`. + +Non-elementwise HLO instructions cannot always be emitted together. Consider the +following HLO graph: + +``` + param + | + log + | \ + | transpose + | / + add +``` + +If we emit this in a single function, the `log` will be accessed at two +different indices for each element of the `add`. The old emitters solve this +problem by generating the `log` twice. For this particular graph, this is not +a problem, but when there are multiple splits, the code size grows +exponentially. + +Here, we solve this problem by partitioning the graph into pieces that can be +safely emitted as one function. The criteria are: + +- Instructions that have only one user are safe to emit together with their + user. +- Instructions that have multiple users are safe to emit together with their + users if they are accessed through the same indices by all users. + +In the example above, the `add` and `tranpose` access different indices of the +`log`, so it is not safe to emit it together with them. + +The graph is therefore partitioned into three functions (each containing just +one instruction). + +## Elemental emission + +See `elemental_hlo_to_mlir.h`. + +Elemental emission is based on `mlir_hlo` and reuses it for all element-wise +instructions. For the most part, this is straightforward, but there are some +interesting things going on here. + +### Indexing transformations + +Some instructions (`transpose`, `broadcast`, `reshape`, `slice`, `reverse` and +a few more) are purely transformations on indices: to produce an element of the +result, we need to produce some other element of the input. For this, we can +reuse XLA's `indexing_analysis`, which has functions to produce the output to +input mapping for an instruction. + +For example, for a `transpose` from `[20,40]` to `[40,20]`, it will produce the +following indexing map (one affine expression per input dimension; d0 and d1 are +the output dimensions): + +``` + (d0, d1) -> d1 + (d0, d1) -> d0 +``` + +So for these pure index transformation instructions, we can simply get the map, +apply it to the output indices, and produce the input at the resulting index. + +Similarly, the `pad` op uses indexing maps and constraints for most of the +implementation. `pad` is also an indexing transformation with some added checks +to see if we return an element of the input or the padding value. + +### Tuples + +We do not support internal `tuple`s. We also do not support nested tuple +outputs. All XLA graphs that use these features can be converted to graphs that +do not. + +### Gather + +We only support canonical gathers as produced by [`gather_simplifier`]( +https://github.com/openxla/xla/blob/main/xla/service/gather_simplifier.h). + +## Emission of functions + +For a subgraph of a computation with parameters `%p0` to `%p_n`, and subgraph +roots with rank `r` and element types (`e0` to `e_m`), we use the following MLIR +function signature: + +`````` +(%p0: tensor<...>, %p1: tensor<...>, ..., %pn: tensor<...>, + %i0: index, %i1: index, ..., %i_r-1: index) -> (e0, ..., e_m) +`````` + +That is, we have one tensor input per computation parameter, one index input per +dimension of the output, and one result per output. + +To emit a function, we simply use the elemental emitter above, and recursively +emit its operands until we reach the edge of the subgraph. Then, we: + +- emit a `tensor.extract` for parameters +- or emit a `func.call` for other subgraphs + +## Putting it together: the loop emitter + +The loop emitter first partitions its fusion computation and emits code for each +subgraph. Then, it has to generate an entry function. The entry function is +different from the functions above, since it has no indices as inputs (just the +thread and block IDs) and actually needs to write the output somewhere. For the +loop emitter, this is fairly straightforward, but the transpose and reduction +emitters have non-trivial write logic. + +The signature of the entry computation is: + +``` +(%p0: tensor<...>, ..., %pn: tensor<...>, + %r0: tensor<...>, ..., %rn: tensor<...>) -> (tensor<...>, ..., tensor<...>) +``` + +Where like before, the `%pn`s are the parameters of the computation, and the +`%rn`s are the results of the computation. The entry computation takes the +results as tensors, `tensor.insert`s updates into them, and then returns them. +No other uses of the output tensors are allowed. + +## Lowerings to LLVM + +We mostly use the standard LLVM lowerings, but there are a few special passes. +We cannot use the `memref` lowerings for tensors, since we don't bufferize the +IR and our ABI is not compatible with the `memref` ABI. Instead, we have a +custom lowering directly from tensors to `LLVM`. + +- The lowering of tensors is done in `lower_tensors.cc`. `tensor.extract` is + lowered to `llvm.load`, `tensor.insert` to `llvm.store`, in the obvious way. +- `propagate_slice_indices` and `merge_pointers_to_same_slice` together + implement a detail of buffer assignment and XLA's ABI: if two tensors share + the same buffer slice, they are only passed once. These passes deduplicate the + function arguments. + diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.cc b/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.cc new file mode 100644 index 00000000000000..dab586f685dff4 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.cc @@ -0,0 +1,271 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" + +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/node_hash_map.h" +#include "absl/log/check.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Interfaces/DataLayoutInterfaces.h" // from @llvm-project +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/fusions/mlir/type_util.h" +#include "xla/service/llvm_ir/llvm_util.h" +#include "xla/shape.h" +#include "xla/translate/hlo_to_mhlo/hlo_utils.h" +#include "xla/union_find.h" + +namespace xla { +namespace gpu { +namespace mlir_converter { +namespace { + +absl::flat_hash_map PartitionGraphByIndexing( + const HloComputation& computation) { + constexpr int kRootIndexing = 0; + int next_indexing = 1; + absl::flat_hash_map indexing; + + std::function indexing_for_instr; + indexing_for_instr = [&](const HloInstruction* instr) -> int { + auto it = indexing.find(instr); + if (it != indexing.end()) return it->second; + + if (instr->opcode() != HloOpcode::kTuple && + !HloInstruction::IsOpElementwise(instr->opcode())) { + return indexing[instr] = next_indexing++; + } + + if (instr->user_count() == 0) { + return indexing[instr] = kRootIndexing; + } + + // If all users have the same indexing, we can reuse it. + std::optional instr_indexing = std::nullopt; + for (auto* user : instr->users()) { + auto user_indexing = indexing_for_instr(user); + if (user->opcode() == HloOpcode::kConcatenate || + (instr_indexing && user_indexing != *instr_indexing)) { + instr_indexing = std::nullopt; + break; + } + instr_indexing = user_indexing; + } + return indexing[instr] = instr_indexing ? *instr_indexing : next_indexing++; + }; + + for (auto* instr : computation.instructions()) { + indexing_for_instr(instr); + } + + return indexing; +} + +} // namespace + +PartitionedComputation::PartitionedComputation( + const HloComputation* computation) + : computation_(computation) { + // For each instruction, figure out what function it goes in. Parameters don't + // count. + absl::node_hash_map> + disjoint_sets; + auto indexing = PartitionGraphByIndexing(*computation); + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kParameter) continue; + disjoint_sets[instruction].Get() = instruction; + } + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kParameter) continue; + bool can_merge = + instruction->user_count() == 1 || + (instruction->user_count() > 1 && + absl::c_all_of(instruction->users(), [&](const HloInstruction* user) { + return indexing.at(user) == indexing.at(instruction); + })); + auto is_bad_gather = [&](const HloInstruction* user) { + // Don't merge into a gather that would evaluate the index more than once. + return user->opcode() == HloOpcode::kGather && + user->operand_index(instruction) == 1 && + instruction->shape().dimensions(1) > 1; + }; + auto is_concat = [&](const HloInstruction* user) { + // Concat codegen doesn't work if any of a concat's transitive inputs is + // reused. Instead of checking, we just cut the function at the concat, + // which has the benefit of leading to slightly easier to read IR. + return user->opcode() == HloOpcode::kConcatenate; + }; + can_merge &= !absl::c_any_of(instruction->users(), is_bad_gather); + can_merge &= !absl::c_any_of(instruction->users(), is_concat); + if (can_merge) { + auto& set = disjoint_sets[instruction]; + for (auto* user : instruction->users()) { + set.Merge(&disjoint_sets[user]); + } + } + } + + ConstHloInstructionMap> functions; + for (auto* instruction : computation->MakeInstructionPostOrder()) { + if (instruction->opcode() == HloOpcode::kParameter) continue; + functions[disjoint_sets[instruction].Get()].push_back(instruction); + } + + subgraphs_.reserve(functions.size()); + for (auto& [cluster_id, instructions] : functions) { + std::vector roots; + for (auto* instruction : instructions) { + if (instruction->user_count() == 0 || + absl::c_any_of(instruction->users(), + [cluster_id = cluster_id, &disjoint_sets](auto* user) { + return disjoint_sets[user].Get() != cluster_id; + })) { + roots.push_back(instruction); + } + } + CHECK(!roots.empty()) << "No roots found"; + std::string name = llvm_ir::SanitizeFunctionName(absl::StrCat( + roots.front()->parent()->name(), "_", + absl::StrJoin(roots, "_", [](std::string* out, const auto* root) { + absl::StrAppend(out, root->name()); + }))); + subgraphs_.push_back( + Subgraph{.name = std::move(name), + .instructions_post_order = std::move(instructions), + .roots = std::move(roots)}); + } + + for (const auto& subgraph : subgraphs_) { + for (const auto* instruction : subgraph.instructions_post_order) { + instructions_to_subgraphs_[instruction] = &subgraph; + } + } +} + +PartitionedComputations::PartitionedComputations(const HloComputation* fusion) { + // Collect all transitively called computations (including the fusion itself). + absl::flat_hash_set seen; + std::vector computations; + std::function visit; + visit = [&](const HloComputation* computation) { + if (!seen.insert(computation).second) return; + computations.push_back(computation); + for (auto* instr : computation->instructions()) { + absl::c_for_each(instr->called_computations(), visit); + } + }; + visit(fusion); + + partitioned_computations_.reserve(computations.size()); + for (auto* computation : computations) { + computation_to_partitioning_[computation] = + &partitioned_computations_.emplace_back( + PartitionedComputation{computation}); + } +} + +absl::flat_hash_map +PartitionedComputations::DeclareFunctions(mlir::ModuleOp module) const { + absl::flat_hash_map + mapping; + mlir::ImplicitLocOpBuilder builder(module.getLoc(), module->getContext()); + builder.setInsertionPointToEnd(module.getBody()); + for (const auto& computation : partitioned_computations_) { + for (const auto& subgraph : computation.subgraphs()) { + auto func_op = CreateSubgraphMlirFunction(subgraph, builder); + func_op->setAttr("llvm.linkage", mlir::LLVM::LinkageAttr::get( + module->getContext(), + mlir::LLVM::Linkage::Internal)); + mapping[&subgraph] = func_op; + } + } + return mapping; +} + +mlir::func::FuncOp CreateSubgraphMlirFunction( + const PartitionedComputation::Subgraph& subgraph, + mlir::ImplicitLocOpBuilder& b) { + auto* computation = subgraph.roots.front()->parent(); + llvm::SmallVector parameter_types; + llvm::SmallVector result_types; + + auto element_type = [&](const auto& shape) { + return *ConvertPrimitiveTypeToMLIRType(shape.element_type(), b); + }; + + const xla::Shape* one_root_shape = nullptr; + for (auto* root : subgraph.roots) { + if (root->shape().IsTuple()) { + for (auto& shape : root->shape().tuple_shapes()) { + one_root_shape = &shape; + result_types.push_back(element_type(shape)); + } + } else { + one_root_shape = &root->shape(); + result_types.push_back(element_type(root->shape())); + } + } + + llvm::SmallVector arg_attrs; + // We support the entry computation here for convenience of testing. The entry + // computation is never code generated here. + if (computation->IsFusionComputation() || computation->IsEntryComputation()) { + for (auto* param : computation->parameter_instructions()) { + parameter_types.push_back(TensorShapeToMlirType(param->shape(), b)); + arg_attrs.emplace_back(); + } + for (int dim = 0; dim < one_root_shape->rank(); ++dim) { + parameter_types.push_back(b.getIndexType()); + arg_attrs.emplace_back(mlir::DictionaryAttr::get( + b.getContext(), + {b.getNamedAttr( + "xla.range", + b.getIndexArrayAttr({0, one_root_shape->dimensions(dim) - 1}))})); + } + } else { + for (auto* param : computation->parameter_instructions()) { + parameter_types.push_back(element_type(param->shape())); + } + } + auto ty = b.getFunctionType(parameter_types, result_types); + return b.create( + subgraph.name, ty, + /*attrs=*/llvm::ArrayRef{}, arg_attrs); +} + +} // namespace mlir_converter +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.h b/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.h new file mode 100644 index 00000000000000..605593801315f4 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.h @@ -0,0 +1,136 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_COMPUTATION_PARTITIONER_H_ +#define XLA_SERVICE_GPU_FUSIONS_MLIR_COMPUTATION_PARTITIONER_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/types/span.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Interfaces/DataLayoutInterfaces.h" // from @llvm-project +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" + +namespace xla { +namespace gpu { +namespace mlir_converter { + +// Partitions an HLO computation into subgraphs so that all users of a node have +// consistent indexing, i. e. when we compute a node `a` with users `b` and `c`, +// all three nodes will have the same indexing - neither of `b` or `c` will be a +// transpose, reshape, reduce, etc. +// +// Consider the following example, where we assume all nodes affect indexing: +// +// a b Here we create four subgraphs: `a,d,c,e`, `b`, `f` and `g`. If +// \ /| `f` and `g` didn't change the indexing, they would be included +// d c | in the `a,d,c,e` subgraph, so we'd have `b` and the rest. +// \ | | +// e | Note that if some users have the same indexing as a node (e.g. +// / \| `e` and `g` in the graph to the left), we still have to create +// f g separate subgraphs for `f` and `g`. +// +// The purpose of this partitioning is to allow us to generate code without ever +// having to duplicate instructions: users with incompatible indexing will be in +// different subgraphs, each of which will emit a call to the producer graph. +// +// Note that this partitioning will sometimes create silly subgraphs that should +// (and will) be inlined, e. g. containing only a constant or only a broadcast. +class PartitionedComputation { + public: + explicit PartitionedComputation(const HloComputation* computation); + + struct Subgraph { + // A unique name of the subgraph. Used for function names. + std::string name; + + // The instructions that make up this subgraph. + std::vector instructions_post_order; + + // The roots. These are guaranteed not to have users inside the subgraph. + std::vector roots; + }; + + absl::Span subgraphs() const { return subgraphs_; } + + const HloComputation& computation() const { return *computation_; } + + const Subgraph& GetRootSubgraph() const { + return FindSubgraph(computation_->root_instruction()); + } + + // Returns the subgraph containing the given instruction. + const Subgraph& FindSubgraph(const HloInstruction* instr) const { + return *instructions_to_subgraphs_.at(instr); + } + + private: + const HloComputation* computation_; + std::vector subgraphs_; + absl::flat_hash_map + instructions_to_subgraphs_; +}; + +// A collection of PartitionedComputations, starting at a fusion computation and +// including all transitively called computations. +class PartitionedComputations { + public: + explicit PartitionedComputations(const HloComputation* fusion); + + const PartitionedComputation& FindPartitionedComputation( + const HloComputation* computation) const { + return *computation_to_partitioning_.at(computation); + } + + absl::Span partitioned_computations() const { + return partitioned_computations_; + } + + // Declares func.func ops for each subgraph in each computation and returns a + // mapping from subgraph to declared function. + absl::flat_hash_map + DeclareFunctions(mlir::ModuleOp module) const; + + private: + std::vector partitioned_computations_; + absl::flat_hash_map + computation_to_partitioning_; +}; + +// Returns an MLIR function declaration for the given subgraph. For subgraphs of +// fusions, the signature is: +// (ptr, ptr, ..., index, index, ...) -> element type(s) +// For subgraphs of called computations, the signature is: +// (elemen type, ...) -> element type(s) +// +// Subgraphs of fusions will also have range (xla.range = [lower_bound, +// upper_bound], both bounds are inclusive) annotations on their index +// arguments. +mlir::func::FuncOp CreateSubgraphMlirFunction( + const PartitionedComputation::Subgraph& subgraph, + mlir::ImplicitLocOpBuilder& b); + +} // namespace mlir_converter +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_COMPUTATION_PARTITIONER_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner_test.cc b/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner_test.cc new file mode 100644 index 00000000000000..996a5c1196c457 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner_test.cc @@ -0,0 +1,194 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" + +#include + +#include +#include +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/tests/hlo_test_base.h" + +namespace xla { +namespace gpu { +namespace mlir_converter { +namespace { + +using ::testing::ElementsAre; +using ::testing::SizeIs; + +using ComputationPartitionerTest = HloTestBase; + +TEST_F(ComputationPartitionerTest, PartitionDiamonds) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule test_module + fused_computation { + %param = f32[6] parameter(0) + %slice0.1 = f32[5] slice(f32[6]{0} %param), slice={[0:5]} + %slice0.2 = f32[5] slice(f32[6]{0} %param), slice={[1:6]} + %add0 = f32[5] add(f32[5]{0} %slice0.1, f32[5]{0} %slice0.2) + %slice1.1 = f32[4] slice(f32[5]{0} %add0), slice={[0:4]} + %slice1.2 = f32[4] slice(f32[5]{0} %add0), slice={[1:5]} + %add1 = f32[4] add(f32[4]{0} %slice1.1, f32[4]{0} %slice1.2) + %slice2.1 = f32[3] slice(f32[4]{0} %add1), slice={[0:3]} + %slice2.2 = f32[3] slice(f32[4]{0} %add1), slice={[1:4]} + %add2 = f32[3] add(f32[3]{0} %slice2.1, f32[3]{0} %slice2.2) + %slice3.1 = f32[2] slice(f32[3]{0} %add2), slice={[0:2]} + %slice3.2 = f32[2] slice(f32[3]{0} %add2), slice={[1:3]} + ROOT %add3 = f32[2] add(f32[2]{0} %slice3.1, f32[2]{0} %slice3.2) + })") + .value(); + + auto* fusion = module->GetComputationWithName("fused_computation"); + ASSERT_NE(fusion, nullptr); + PartitionedComputation computation(fusion); + auto slice01 = fusion->GetInstructionWithName("slice0.1"); + auto slice02 = fusion->GetInstructionWithName("slice0.2"); + auto add0 = fusion->GetInstructionWithName("add0"); + auto slice11 = fusion->GetInstructionWithName("slice1.1"); + auto slice12 = fusion->GetInstructionWithName("slice1.2"); + auto add1 = fusion->GetInstructionWithName("add1"); + auto slice21 = fusion->GetInstructionWithName("slice2.1"); + auto slice22 = fusion->GetInstructionWithName("slice2.2"); + auto add2 = fusion->GetInstructionWithName("add2"); + auto slice31 = fusion->GetInstructionWithName("slice3.1"); + auto slice32 = fusion->GetInstructionWithName("slice3.2"); + auto add3 = fusion->GetInstructionWithName("add3"); + + const auto& graphs = computation.subgraphs(); + ASSERT_THAT(graphs, SizeIs(4)); + EXPECT_THAT(graphs[0].instructions_post_order, + ElementsAre(slice01, slice02, add0)); + EXPECT_THAT(graphs[1].instructions_post_order, + ElementsAre(slice11, slice12, add1)); + EXPECT_THAT(graphs[2].instructions_post_order, + ElementsAre(slice21, slice22, add2)); + EXPECT_THAT(graphs[3].instructions_post_order, + ElementsAre(slice31, slice32, add3)); + + EXPECT_THAT(graphs[0].roots, ElementsAre(add0)); + EXPECT_THAT(graphs[1].roots, ElementsAre(add1)); + EXPECT_THAT(graphs[2].roots, ElementsAre(add2)); + EXPECT_THAT(graphs[3].roots, ElementsAre(add3)); + + EXPECT_EQ(&computation.GetRootSubgraph(), &graphs[3]); + EXPECT_EQ(&computation.FindSubgraph(slice21), &graphs[2]); +} + +TEST_F(ComputationPartitionerTest, TupleRoot) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule test_module + fused_computation { + %p0 = f32[6] parameter(0) + %p1 = f32[6] parameter(1) + %add = f32[6] add(p0, p1) + %sub = f32[6] subtract(p0, p1) + ROOT %root = (f32[6], f32[6]) tuple(%add, %sub) + })") + .value(); + + auto* fusion = module->GetComputationWithName("fused_computation"); + ASSERT_NE(fusion, nullptr); + PartitionedComputation computation(fusion); + + ASSERT_THAT(computation.subgraphs(), SizeIs(1)); + EXPECT_THAT(computation.GetRootSubgraph().roots, SizeIs(1)); + EXPECT_THAT(computation.GetRootSubgraph().instructions_post_order, SizeIs(3)); +} + +TEST_F(ComputationPartitionerTest, PartiallyMergable) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule test_module + fused_computation { + %p0 = f32[10,10] parameter(0) + %p1 = f32[10,10] parameter(1) + %add = f32[10,10] add(%p0, %p1) + %transpose = f32[10,10] transpose(%add), dimensions={1,0} + ROOT %sub = f32[10,10] subtract(%add, %transpose) + })") + .value(); + + auto* fusion = module->GetComputationWithName("fused_computation"); + ASSERT_NE(fusion, nullptr); + PartitionedComputation computation(fusion); + + auto transpose = fusion->GetInstructionWithName("transpose"); + auto sub = fusion->GetInstructionWithName("sub"); + + ASSERT_THAT(computation.subgraphs(), SizeIs(2)); + EXPECT_THAT(computation.GetRootSubgraph().instructions_post_order, + ElementsAre(transpose, sub)); +} + +TEST_F(ComputationPartitionerTest, SubgraphSignatures) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule test_module + + add { + %p0 = f32[] parameter(0) + %p1 = f32[] parameter(1) + ROOT %add = f32[] add(%p0, %p1) + } + + fusion { + %p0 = f32[10,10]{0,1} parameter(0) + %p1 = f32[10,10]{1,0} parameter(1) + %c0 = f32[] constant(2) + %bc = f32[10,10]{0,1} bitcast(%p1) + %add = f32[10,10] add(%p0, %bc) + ROOT %reduce = f32[10] reduce(%add, %c0), dimensions={1}, to_apply=add + } + + ENTRY main { + %p0 = f32[10,10] parameter(0) + %p1 = f32[10,10] parameter(1) + ROOT %fusion = f32[10] fusion(%p0, %p1), kind=kLoop, calls=fusion + })") + .value(); + auto print = [](mlir::func::FuncOp func) { + // Set visibility to private so the function verifies. + func.setSymVisibility("private"); + std::string out; + llvm::raw_string_ostream os(out); + os << func; + func.erase(); + return out; + }; + + mlir::MLIRContext context; + context.loadDialect(); + mlir::ImplicitLocOpBuilder builder(mlir::UnknownLoc::get(&context), &context); + + PartitionedComputation fusion(module->GetComputationWithName("fusion")); + EXPECT_EQ( + print(CreateSubgraphMlirFunction(fusion.GetRootSubgraph(), builder)), + "func.func private @fusion_reduce(tensor<10x10xf32, dense<[0, 1]> : " + "tensor<2xi64>>, tensor<10x10xf32>, index {xla.range = [0 : index, 9 : " + "index]}) -> f32"); + + PartitionedComputation add(module->GetComputationWithName("add")); + EXPECT_EQ(print(CreateSubgraphMlirFunction(add.GetRootSubgraph(), builder)), + "func.func private @add_add(f32, f32) -> f32"); +} + +} // namespace +} // namespace mlir_converter +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc new file mode 100644 index 00000000000000..c321718f6b7e53 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc @@ -0,0 +1,952 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/node_hash_map.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/MathExtras.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/Dialect/Affine/LoopUtils.h" // from @llvm-project +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/TypeRange.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "xla/comparison_util.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h" +#include "xla/primitive_util.h" +#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/status_macros.h" +#include "xla/stream_executor/device_description.h" +#include "xla/translate/hlo_to_mhlo/hlo_utils.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace mlir_converter { +namespace { + +using mlir::Value; +using mlir::ValueRange; +using mlir::arith::AndIOp; +using mlir::arith::CmpFOp; +using mlir::arith::CmpFPredicate; +using mlir::arith::CmpIOp; +using mlir::arith::CmpIPredicate; +using mlir::arith::ConstantOp; +using mlir::arith::SelectOp; +using mlir::scf::ForOp; +using mlir::scf::IfOp; +using mlir::scf::YieldOp; + +namespace arith = ::mlir::arith; +namespace mhlo = ::mlir::mhlo; + +// HLO opcodes that we never support. +static auto& kUnsupportedOps = + *new absl::flat_hash_set{HloOpcode::kAddDependency, + HloOpcode::kAfterAll, + HloOpcode::kAllGather, + HloOpcode::kAllGatherDone, + HloOpcode::kAllGatherStart, + HloOpcode::kAllReduce, + HloOpcode::kAllReduceDone, + HloOpcode::kAllReduceStart, + HloOpcode::kAllToAll, + HloOpcode::kAsyncDone, + HloOpcode::kAsyncStart, + HloOpcode::kAsyncUpdate, + HloOpcode::kBatchNormGrad, + HloOpcode::kBatchNormInference, + HloOpcode::kBatchNormTraining, + HloOpcode::kCholesky, + HloOpcode::kCollectivePermute, + HloOpcode::kCollectivePermuteDone, + HloOpcode::kCollectivePermuteStart, + HloOpcode::kCopyDone, + HloOpcode::kCopyStart, + HloOpcode::kCustomCall, + HloOpcode::kDomain, + HloOpcode::kDynamicReshape, + HloOpcode::kDynamicSlice, + HloOpcode::kFft, + HloOpcode::kFusion, + HloOpcode::kGetDimensionSize, + HloOpcode::kOptimizationBarrier, + HloOpcode::kInfeed, + HloOpcode::kOutfeed, + HloOpcode::kParameter, + HloOpcode::kPartitionId, + HloOpcode::kRecv, + HloOpcode::kRecvDone, + HloOpcode::kReduceScatter, + HloOpcode::kReplicaId, + HloOpcode::kRng, + HloOpcode::kRngBitGenerator, + HloOpcode::kRngGetAndUpdateState, + HloOpcode::kScatter, + HloOpcode::kSelectAndScatter, + HloOpcode::kSend, + HloOpcode::kSendDone, + HloOpcode::kSetDimensionSize, + HloOpcode::kSort, + HloOpcode::kTopK, + HloOpcode::kTriangularSolve, + HloOpcode::kWhile, + HloOpcode::kConditional, + HloOpcode::kStochasticConvert, + HloOpcode::kCall}; + +static auto& kUnimplementedOps = *new absl::flat_hash_set{ + HloOpcode::kConvolution, HloOpcode::kDot, HloOpcode::kDynamicUpdateSlice, + HloOpcode::kMap, HloOpcode::kReduceWindow, + // Custom approximations in XLA: + HloOpcode::kErf, HloOpcode::kTanh, + // Incorrect NaN handling: + HloOpcode::kMaximum, HloOpcode::kMinimum, HloOpcode::kClamp}; + +bool IsUnsupportedConstant(const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kConstant && instr->shape().rank() != 0; +} + +bool IsUnsupportedTuple(const HloInstruction* instr) { + if (instr->opcode() != HloOpcode::kTuple) { + return false; + } + + if (instr->user_count() > 0) { + // Internal tuples are unsupported. + return true; + } + + // Nested tuples and tokens are unsupported. + if (absl::c_any_of(instr->operands(), + [&](auto* op) { return !op->shape().IsArray(); })) { + return true; + } + + // All tuple elements must have the same dimensions (element types may + // differ). + auto first_shape = instr->shape().tuple_shapes(0); + for (int i = 1; i < instr->operand_count(); ++i) { + if (instr->shape().tuple_shapes(i).dimensions() != + first_shape.dimensions()) { + return true; + } + } + return false; +} + +bool IsUnsupportedGather(const HloInstruction* instr) { + // We assume gather simplifier ran, so we don't need to support all gather + // forms. + if (instr->opcode() != HloOpcode::kGather) return false; + + auto* gather = Cast(instr); + const auto& dims = gather->gather_dimension_numbers(); + if (dims.index_vector_dim() != 1 || !dims.collapsed_slice_dims().empty() || + gather->operand(1)->shape().rank() != 2) { + return true; + } + + for (auto [index, val] : llvm::enumerate(dims.start_index_map())) { + if (index != val) return true; + } + for (auto [index, val] : llvm::enumerate(dims.offset_dims())) { + if (index + 1 != val) return true; + } + return false; +} + +absl::StatusOr GetSingleOperandValue( + const OperandProvider& operand_provider, const HloInstruction* instr, + int operand_index, ValueRange indices) { + TF_ASSIGN_OR_RETURN(auto operand, + operand_provider(instr, operand_index, indices)); + TF_RET_CHECK(operand.size() == 1) << "Expected operand to be a single value."; + return operand.front(); +} + +absl::StatusOr> EmitReduce( + const HloInstruction* instr, ValueRange indices, + const OperandProvider& operand_provider, + const CallTargetProvider& call_target_provider, + mlir::ImplicitLocOpBuilder& b) { + llvm::SmallVector reduction_indices(indices); + llvm::SmallVector accumulators; + for (int i = instr->operand_count() / 2; i < instr->operand_count(); ++i) { + TF_ASSIGN_OR_RETURN(accumulators.emplace_back(), + GetSingleOperandValue(operand_provider, instr, i, {})); + } + auto dims = llvm::to_vector(instr->dimensions()); + absl::c_sort(dims); + ForOp outermost_loop = nullptr; + for (int dim : dims) { + auto bound = instr->operands()[0]->shape().dimensions(dim); + auto loop = + b.create(b.create(b.getIndexAttr(0)), + b.create(b.getIndexAttr(bound)), + b.create(b.getIndexAttr(1)), accumulators); + if (outermost_loop == nullptr) { + outermost_loop = loop; + } else { + b.create(loop.getResults()); + } + b.setInsertionPointToStart(loop.getBody()); + reduction_indices.insert(reduction_indices.begin() + dim, + loop.getInductionVar()); + accumulators = {loop.getRegionIterArgs().begin(), + loop.getRegionIterArgs().end()}; + } + llvm::SmallVector args; + for (int i = 0; i < instr->operand_count() / 2; ++i) { + args.push_back(accumulators[i]); + TF_ASSIGN_OR_RETURN( + args.emplace_back(), + GetSingleOperandValue(operand_provider, instr, i, reduction_indices)); + } + auto reducer = call_target_provider( + instr->called_computations().front()->root_instruction()); + b.create(b.create(reducer, args).getResults()); + + b.setInsertionPointAfter(outermost_loop); + return outermost_loop.getResults(); +} + +absl::StatusOr> EmitConcat( + const HloInstruction* instr, ValueRange indices, + const OperandProvider& operand_provider, mlir::ImplicitLocOpBuilder& b) { + int concat_dim = + Cast(instr)->concatenate_dimension(); + auto ty = *ConvertPrimitiveTypeToMLIRType(instr->shape().element_type(), b); + int64_t offset = 0; + IfOp outermost_if = nullptr; + llvm::SmallVector operand_indices = indices; + for (auto [index, operand] : llvm::enumerate(instr->operands())) { + int64_t limit = offset + operand->shape().dimensions(concat_dim); + auto in_bounds = + b.create(CmpIPredicate::ult, indices[concat_dim], + b.create(b.getIndexAttr(limit))); + + auto generate_operand = [&, index = index]() { + operand_indices[concat_dim] = b.create( + indices[concat_dim], b.create(b.getIndexAttr(offset))); + TF_ASSIGN_OR_RETURN(auto operand, + operand_provider(instr, index, operand_indices)); + b.create(operand); + return absl::OkStatus(); + }; + + if (index < instr->operand_count() - 1) { + auto if_op = b.create(mlir::TypeRange{ty}, in_bounds, true, true); + if (outermost_if == nullptr) { + outermost_if = if_op; + } else { + b.create(if_op.getResults()); + } + + b.setInsertionPointToStart(if_op.getBody(0)); + TF_RETURN_IF_ERROR(generate_operand()); + b.setInsertionPointToStart(if_op.getBody(1)); + } else { + TF_RETURN_IF_ERROR(generate_operand()); + } + offset = limit; + } + + b.setInsertionPointAfter(outermost_if); + return outermost_if.getResults(); +} + +absl::StatusOr> EmitGather( + const HloInstruction* instr, ValueRange indices, + const OperandProvider& operand_provider, mlir::ImplicitLocOpBuilder& b) { + auto row = indices[0]; + auto zero = b.create(b.getIndexAttr(0)); + // Gather allows the index vector to contain fewer elements than the rank + // of the input. In that case, the remaining indices are 0. + llvm::SmallVector operand_indices(instr->operand(0)->shape().rank(), + zero); + + // Produce start indices. + int num_indices = instr->operand(1)->shape().dimensions(1); + for (int i = 0; i < num_indices; ++i) { + auto i_val = i == 0 ? zero : b.create(b.getIndexAttr(i)); + int64_t slice_size = instr->gather_slice_sizes()[i]; + int64_t input_size = instr->operand(0)->shape().dimensions()[i]; + if (slice_size == input_size) { + // We're reading the full dimension, so clamping would always result in a + // zero index. + operand_indices[i] = zero; + } else { + // Read and clamp index. + TF_ASSIGN_OR_RETURN(auto input_index, + operand_provider(instr, 1, {row, i_val})); + TF_RET_CHECK(input_index.size() == 1) + << "Expected operand to be a single value."; + mlir::Value index = + b.create(b.getIndexType(), input_index.front()); + auto max_minus_size = + b.create(b.getIndexAttr(input_size - slice_size)); + index = b.create(index, max_minus_size); + index = b.create(index, zero); + operand_indices[i] = index; + } + } + + // Add offsets. + for (int i = 0; i < operand_indices.size(); ++i) { + operand_indices[i] = + b.createOrFold(operand_indices[i], indices[i + 1]); + } + + return operand_provider(instr, 0, operand_indices); +} + +Value CheckConstraint(mlir::Value constrained_value, Range range, + mlir::ImplicitLocOpBuilder& b) { + auto lb = b.create(b.getIndexAttr(range.lower_bound)); + if (range.IsPoint()) { + return b.create(CmpIPredicate::eq, constrained_value, lb); + } + auto ub = b.create(b.getIndexAttr(range.upper_bound)); + return b.create( + b.create(CmpIPredicate::sge, constrained_value, lb), + b.create(CmpIPredicate::sle, constrained_value, ub)); +} + +// For a given instruction, deduces the indices of each parameter that are +// needed for a given output index. +llvm::SmallVector> GetInputIndices( + const HloInstructionIndexing& indexing, ValueRange output_indices, + mlir::ImplicitLocOpBuilder& b) { + llvm::SmallVector> indices; + for (auto& maps : indexing.indexing_maps) { + CHECK_EQ(maps.size(), 1); + auto map = maps.begin()->GetAffineMap(); + CHECK(!maps.begin()->IsUndefined()); + indices.emplace_back() = ApplyAffineMap(map, output_indices, {}, b); + } + return indices; +} + +absl::StatusOr> EmitPad( + const HloInstruction* instr, ValueRange indices, + const OperandProvider& operand_provider, mlir::ImplicitLocOpBuilder& b) { + auto indexing = ComputeOutputToInputIndexing(instr, 0, b.getContext()); + const auto& indexing_map = *indexing.indexing_maps[0].begin(); + mlir::Value is_in_bounds = CheckConstraints(indexing_map, indices, {}, b); + b.create(b.getIntegerAttr(b.getI1Type(), 1)); + for (auto&& [index, range] : + llvm::enumerate(indexing_map.GetDimensionRanges())) { + // If the range is the full output dimension, it's always in bounds. Sadly, + // this doesn't get optimized automatically. + if (range.lower_bound == 0 && + range.upper_bound == instr->shape().dimensions(index) - 1) { + continue; + } + is_in_bounds = b.create(is_in_bounds, + CheckConstraint(indices[index], range, b)); + } + + auto ty = *ConvertPrimitiveTypeToMLIRType(instr->shape().element_type(), b); + auto if_op = b.create(mlir::TypeRange{ty}, is_in_bounds, true, true); + b.setInsertionPointToStart(if_op.getBody(0)); + TF_ASSIGN_OR_RETURN(auto input_value, + GetSingleOperandValue( + operand_provider, instr, 0, + GetInputIndices(indexing, indices, + b)[0 /* indexing for operand 0 */])); + b.create(input_value); + + b.setInsertionPointToStart(if_op.getBody(1)); + TF_ASSIGN_OR_RETURN(auto padding_value, + GetSingleOperandValue(operand_provider, instr, 1, {})); + b.create(padding_value); + + b.setInsertionPointAfter(if_op); + return if_op.getResults(); +} + +template +llvm::SmallVector MapHloOp(llvm::ArrayRef result_types, + llvm::ArrayRef args, + mlir::ImplicitLocOpBuilder& b, + ExtraArgs&&... extra_args) { + return {mhlo::MhloOpToStdScalarOp::mapOpOfType( + b.getLoc(), result_types, llvm::to_vector(mlir::TypeRange(args)), + typename MhloOp::Adaptor(args, std::forward(extra_args)...), + &b)}; +} + +template +llvm::SmallVector MapElementwiseOp( + llvm::ArrayRef args, mlir::ImplicitLocOpBuilder& b) { + // We use the last argument's type because of select. + return MapHloOp({args.back().getType()}, args, b); +} + +} // namespace + +Value ApplyAffineExpr(mlir::AffineExpr expr, mlir::ValueRange dims, + mlir::ValueRange symbols, mlir::ImplicitLocOpBuilder& b) { + // For unknown (but undoubtedly good) reasons, affine.apply removes unused + // trailing dimensions, but only in the expression. + while (dims.size() > 0 && !expr.isFunctionOfDim(dims.size() - 1)) { + dims = dims.drop_back(); + } + while (symbols.size() > 0 && !expr.isFunctionOfSymbol(symbols.size() - 1)) { + symbols = symbols.drop_back(); + } + llvm::SmallVector args(dims); + absl::c_copy(symbols, std::back_inserter(args)); + return b.createOrFold(expr, args); +} + +llvm::SmallVector ApplyAffineMap(mlir::AffineMap map, + mlir::ValueRange dims, + mlir::ValueRange symbols, + mlir::ImplicitLocOpBuilder& b) { + llvm::SmallVector result; + result.reserve(map.getNumResults()); + for (auto expr : map.getResults()) { + result.push_back(ApplyAffineExpr(expr, dims, symbols, b)); + } + return result; +} + +Value CheckConstraints(const IndexingMap& map, ValueRange dims, + ValueRange symbols, mlir::ImplicitLocOpBuilder& b) { + mlir::Value ret = b.create(b.getIntegerAttr(b.getI1Type(), 1)); + for (auto&& [expression, range] : map.GetConstraints()) { + ret = b.create( + ret, CheckConstraint(ApplyAffineExpr(expression, dims, symbols, b), + range, b)); + } + return ret; +} + +absl::StatusOr> HloToMlir( + const HloInstruction* instr, ValueRange indices, + const OperandProvider& operand_provider, + const CallTargetProvider& call_target_provider, + mlir::ImplicitLocOpBuilder& builder) { + CHECK(!kUnsupportedOps.contains(instr->opcode())) << instr->ToShortString(); + CHECK(!kUnimplementedOps.contains(instr->opcode())) << instr->ToShortString(); + + auto element_type = instr->shape().element_type(); + // Handle ops that aren't elementwise and aren't just indexing + // transformations. + switch (instr->opcode()) { + case HloOpcode::kConcatenate: + return EmitConcat(instr, indices, operand_provider, builder); + case HloOpcode::kConstant: + if (instr->shape().rank() == 0) { + auto val = mlir::cast( + CreateDenseElementsAttrFromLiteral(instr->literal(), builder) + ->getValues()[0]); + return {{builder.create(val).getResult()}}; + } + return absl::UnimplementedError( + absl::StrCat("Unimplemented: ", instr->ToShortString())); + case HloOpcode::kGather: + return EmitGather(instr, indices, operand_provider, builder); + case HloOpcode::kIota: { + auto element_mlir_type = + *ConvertPrimitiveTypeToMLIRType(element_type, builder); + auto index = indices[Cast(instr)->iota_dimension()]; + if (element_mlir_type.getIntOrFloatBitWidth() == 32) { + index = + builder.create(builder.getI32Type(), index); + } else { + index = + builder.create(builder.getI64Type(), index); + } + return MapHloOp({element_mlir_type}, {index}, builder); + } + case HloOpcode::kPad: + return EmitPad(instr, indices, operand_provider, builder); + case HloOpcode::kReduce: + return EmitReduce(instr, indices, operand_provider, call_target_provider, + builder); + case HloOpcode::kTuple: { + CHECK(!IsUnsupportedTuple(instr)); + llvm::SmallVector operands; + for (int i = 0; i < instr->operand_count(); ++i) { + TF_ASSIGN_OR_RETURN( + operands.emplace_back(), + GetSingleOperandValue(operand_provider, instr, i, indices)); + } + return operands; + } + case HloOpcode::kGetTupleElement: { + // We have to generate the entire tuple, but since we don't support + // internal tuple operations (only root tuples), this will always be + // cached and computed together anyway (e.g. it'll be a variadic reduce). + TF_ASSIGN_OR_RETURN(auto tuple, operand_provider(instr, 0, indices)); + return {{tuple[instr->tuple_index()]}}; + } + default: + break; + } + + auto input_indices = GetInputIndices( + ComputeOutputToInputIndexing(instr, 0, builder.getContext()), indices, + builder); + llvm::SmallVector operands; + for (auto&& [operand_number, operand_indices] : + llvm::enumerate(input_indices)) { + TF_ASSIGN_OR_RETURN(operands.emplace_back(), + GetSingleOperandValue(operand_provider, instr, + operand_number, operand_indices)); + // Nulls can be pretty hard to debug, so guard against them here. The MHLO + // conversion functions like to return nullptr for errors. + TF_RET_CHECK(operands.back() != nullptr) + << "null operand at index " << operand_number << " for " + << instr->ToShortString(); + } + CHECK_NE(operands.size(), 0); + + auto element_mlir_type = + *ConvertPrimitiveTypeToMLIRType(element_type, builder); + switch (instr->opcode()) { + case HloOpcode::kAbs: + if (primitive_util::IsComplexType(element_type)) { + return {MapHloOp( + {*ConvertPrimitiveTypeToMLIRType( + primitive_util::ComplexComponentType(element_type), builder)}, + operands, builder)}; + } else { + return MapElementwiseOp(operands, builder); + } + case HloOpcode::kAdd: + if (element_type == PRED) { + return MapElementwiseOp(operands, builder); + } else { + return MapElementwiseOp(operands, builder); + } + case HloOpcode::kAnd: + return MapElementwiseOp(operands, builder); + case HloOpcode::kAtan2: + return MapElementwiseOp(operands, builder); + case HloOpcode::kCbrt: + return MapElementwiseOp(operands, builder); + case HloOpcode::kCeil: + return MapElementwiseOp(operands, builder); + case HloOpcode::kClamp: + return MapElementwiseOp(operands, builder); + case HloOpcode::kClz: + return MapElementwiseOp(operands, builder); + case HloOpcode::kCompare: { + auto* context = builder.getContext(); + auto dir = builder.getDictionaryAttr(builder.getNamedAttr( + "comparison_direction", + mhlo::ComparisonDirectionAttr::get( + context, + mhlo::symbolizeComparisonDirection( + ComparisonDirectionToString(instr->comparison_direction())) + .value()))); + auto result_types = llvm::to_vector(mlir::TypeRange{builder.getI1Type()}); + auto arg_types = llvm::to_vector(mlir::TypeRange(operands)); + return {{mhlo::MhloOpToStdScalarOp::mapOpOfType( + builder.getLoc(), result_types, arg_types, + mhlo::CompareOp::Adaptor(operands, dir), &builder)}}; + } + case HloOpcode::kComplex: + return MapHloOp({element_mlir_type}, operands, builder); + case HloOpcode::kCos: + return MapElementwiseOp(operands, builder); + case HloOpcode::kDivide: + return MapElementwiseOp(operands, builder); + case HloOpcode::kErf: + return MapElementwiseOp(operands, builder); + case HloOpcode::kExp: + return MapElementwiseOp(operands, builder); + case HloOpcode::kExpm1: + return MapElementwiseOp(operands, builder); + case HloOpcode::kFloor: + return MapElementwiseOp(operands, builder); + case HloOpcode::kIsFinite: + return MapHloOp({builder.getI1Type()}, operands, + builder); + case HloOpcode::kImag: + return MapHloOp({element_mlir_type}, operands, builder); + case HloOpcode::kLog: + return MapElementwiseOp(operands, builder); + case HloOpcode::kLog1p: + return MapElementwiseOp(operands, builder); + case HloOpcode::kLogistic: + return MapElementwiseOp(operands, builder); + case HloOpcode::kMaximum: + return MapElementwiseOp(operands, builder); + case HloOpcode::kMinimum: + return MapElementwiseOp(operands, builder); + case HloOpcode::kMultiply: + return MapElementwiseOp(operands, builder); + case HloOpcode::kNegate: + return MapElementwiseOp(operands, builder); + case HloOpcode::kNot: + return MapElementwiseOp(operands, builder); + case HloOpcode::kOr: + return MapElementwiseOp(operands, builder); + case HloOpcode::kPopulationCount: + return MapHloOp({element_mlir_type}, operands, + builder); + case HloOpcode::kPower: + return MapElementwiseOp(operands, builder); + case HloOpcode::kReal: + return MapHloOp({element_mlir_type}, operands, builder); + case HloOpcode::kReducePrecision: { + mlir::NamedAttribute exponent_bits( + builder.getStringAttr("exponent_bits"), + builder.getI32IntegerAttr(instr->exponent_bits())); + mlir::NamedAttribute mantissa_bits( + builder.getStringAttr("mantissa_bits"), + builder.getI32IntegerAttr(instr->mantissa_bits())); + return MapHloOp( + {operands.front().getType()}, operands, builder, + mlir::DictionaryAttr::get(builder.getContext(), + {exponent_bits, mantissa_bits})); + } + case HloOpcode::kRemainder: + return MapElementwiseOp(operands, builder); + case HloOpcode::kRoundNearestAfz: + return MapElementwiseOp(operands, builder); + case HloOpcode::kRoundNearestEven: + return MapElementwiseOp(operands, builder); + case HloOpcode::kRsqrt: + return MapElementwiseOp(operands, builder); + case HloOpcode::kSelect: + return MapElementwiseOp(operands, builder); + case HloOpcode::kShiftLeft: + return MapElementwiseOp(operands, builder); + case HloOpcode::kShiftRightArithmetic: + return MapElementwiseOp(operands, builder); + case HloOpcode::kShiftRightLogical: + return MapElementwiseOp(operands, builder); + case HloOpcode::kSign: + return MapElementwiseOp(operands, builder); + case HloOpcode::kSin: + return MapElementwiseOp(operands, builder); + case HloOpcode::kSqrt: + return MapElementwiseOp(operands, builder); + case HloOpcode::kSubtract: + return MapElementwiseOp(operands, builder); + case HloOpcode::kTan: + return MapElementwiseOp(operands, builder); + case HloOpcode::kTanh: + return MapElementwiseOp(operands, builder); + case HloOpcode::kXor: + return MapElementwiseOp(operands, builder); + case HloOpcode::kBitcastConvert: + return MapHloOp({element_mlir_type}, operands, + builder); + case HloOpcode::kConvert: { + if (operands[0].getType().isa() && + element_type == PRED) { + return { + builder + .create(CmpFPredicate::UNE, operands[0], + builder.create(builder.getFloatAttr( + operands[0].getType(), 0.0))) + ->getResults()}; + } + + auto out = + MapHloOp({element_mlir_type}, operands, builder) + .front(); + // Convert from float to int is saturating, but MHLO's conversion logic + // does not implement this. + // TODO(jreiffers): Is this a bug or a feature? + if (auto int_ty = out.getType().dyn_cast()) { + auto in = operands[0]; + if (auto float_ty = in.getType().dyn_cast()) { + auto cst_int = [&](int64_t x) { + return builder.create(x, int_ty); + }; + auto cst_float = [&](int64_t x) { + return builder.create( + builder.getFloatAttr(float_ty, x)); + }; + int64_t min = llvm::minIntN(int_ty.getWidth()); + int64_t max = llvm::maxIntN(int_ty.getWidth()); + // x <= static_cast(INT_MIN) ? INT_MIN : ... + out = builder.create( + builder.create(CmpFPredicate::OLE, in, cst_float(min)), + cst_int(min), out); + // x >= static_cast(INT_MAX) ? INT_MAX : ... + out = builder.create( + builder.create(CmpFPredicate::OGE, in, cst_float(max)), + cst_int(max), out); + // isnan(x) ? 0 : ... + out = builder.create( + builder.create(CmpFPredicate::UNO, in, in), cst_int(0), + out); + } + } + return {{out}}; + } + case HloOpcode::kBitcast: + if (instr->operands()[0]->shape().element_type() == element_type) { + return operands; + } + return MapHloOp({element_mlir_type}, operands, + builder); + case HloOpcode::kCopy: + case HloOpcode::kSlice: + case HloOpcode::kBroadcast: + case HloOpcode::kReshape: + case HloOpcode::kReverse: + case HloOpcode::kTranspose: + return operands; + default: + break; + } + + return absl::UnimplementedError(absl::StrCat("Unsupported: ", instr->name())); +} + +bool IsHloOpSupported(const HloInstruction* instr, + se::CudaComputeCapability compute_capability) { + auto is_unsupported_type = [](const HloInstruction* instr) { + auto e = instr->shape().element_type(); + // TODO(jreiffers): Convert to signless. + // TODO(jreiffers): Support complex. + // TODO(jreiffers): Support fp8, fp16, bf16. + // TODO(jreiffers): Support int4. + return (primitive_util::IsIntegralType(e) && + primitive_util::BitWidth(e) > 1 && + primitive_util::BitWidth(e) < 8) || + primitive_util::IsUnsignedIntegralType(e) || + primitive_util::IsComplexType(e) || + (primitive_util::IsFloatingPointType(e) && + primitive_util::BitWidth(e) < 32); + }; + if (is_unsupported_type(instr) || + absl::c_any_of(instr->operands(), is_unsupported_type)) { + return false; + } + + return !(kUnsupportedOps.contains(instr->opcode()) || + kUnimplementedOps.contains(instr->opcode()) || + IsUnsupportedConstant(instr) || IsUnsupportedTuple(instr) || + IsUnsupportedGather(instr)); +} + +bool IsHloConversionSupported(const HloComputation* computation, + se::GpuComputeCapability compute_capability) { + if (!std::holds_alternative(compute_capability)) { + // ROCM is not tested. + return false; + } + auto cuda_compute_capability = + std::get(compute_capability); + + return absl::c_all_of( + computation->instructions(), + [=](const HloInstruction* instr) { + return absl::c_all_of(instr->called_computations(), + [&](const HloComputation* called) { + return IsHloConversionSupported( + called, compute_capability); + }) && + IsHloOpSupported(instr, cuda_compute_capability); + }) && + (computation->IsFusionComputation() || + (absl::c_all_of( + computation->parameter_instructions(), [](auto* param) { + return param->shape().IsArray() && param->shape().rank() == 0; + }))); +} + +bool IsHloConversionSupported(const HloFusionAdaptor& fusion, + se::GpuComputeCapability compute_capability) { + if (!std::holds_alternative(compute_capability)) { + // ROCM is not tested. + return false; + } + auto cuda_compute_capability = + std::get(compute_capability); + + if (fusion.GetRoots().size() > 1) { + auto first_shape = fusion.GetRoots()[0].instruction().shape(); + for (int i = 1; i < fusion.GetRoots().size(); ++i) { + if (fusion.GetRoots()[i].instruction().shape().dimensions() != + first_shape.dimensions()) { + return false; + } + } + } + + return !HloFindIf( + fusion.GetRoots(), fusion, [=](HloInstructionAdaptor instr) { + return !absl::c_all_of(instr.instruction().called_computations(), + [&](const HloComputation* called) { + return IsHloConversionSupported( + called, compute_capability); + }) || + !IsHloOpSupported(&instr.instruction(), cuda_compute_capability); + }); +} + +absl::Status SubgraphToMlirFunction( + const PartitionedComputation& computation, + const PartitionedComputation::Subgraph& subgraph, mlir::func::FuncOp& func, + const CallTargetProvider& call_target_provider) { + TF_RET_CHECK(func != nullptr); + mlir::ImplicitLocOpBuilder builder(func.getLoc(), func->getContext()); + builder.setInsertionPointToStart(func.addEntryBlock()); + auto indices = func.getArguments().drop_front( + computation.computation().num_parameters()); + auto parameters = func.getArguments().take_front( + computation.computation().num_parameters()); + TF_ASSIGN_OR_RETURN( + auto results, SubgraphToMlir(computation, subgraph, call_target_provider, + parameters, indices, builder)); + builder.create(results); + return absl::OkStatus(); +} + +absl::StatusOr> SubgraphToMlir( + const PartitionedComputation& computation, + const PartitionedComputation::Subgraph& subgraph, + const CallTargetProvider& call_target_provider, mlir::ValueRange parameters, + mlir::ValueRange indices, mlir::ImplicitLocOpBuilder& builder) { + llvm::SmallVector results; + absl::node_hash_map>, + llvm::SmallVector> + cached_instructions; + + std::function>( + const HloInstruction* instr, mlir::ValueRange indices)> + emit_instr; + absl::flat_hash_map> + calls; + + auto provide_operand = [&](const HloInstruction* instr, int index, + mlir::ValueRange indices) + -> absl::StatusOr> { + auto* operand = instr->operand(index); + if (operand->opcode() == HloOpcode::kParameter) { + mlir::Value value = parameters[operand->parameter_number()]; + if (value.getType().isa()) { + value = builder.create(value, indices); + } else { + CHECK_EQ(indices.size(), 0); + } + return {{value}}; + } + + const auto& target_subgraph = computation.FindSubgraph(operand); + if (&target_subgraph == &subgraph) { + return emit_instr(operand, indices); + } + + auto callee = call_target_provider(operand); + llvm::SmallVector operands(parameters); + absl::c_copy(indices, std::back_inserter(operands)); + + // Check if we already have a call to this function in scope and reuse it + // if so. func.call is not pure, even if the call target is pure, so CSE + // won't clean this up. + auto& existing_calls = calls[callee]; + for (auto call : existing_calls) { + if (call.getOperands() == operands && + call->getParentRegion()->findAncestorBlockInRegion( + *builder.getInsertionBlock())) { + return call.getResults(); + } + } + return existing_calls + .emplace_back(builder.create( + call_target_provider(operand), operands)) + .getResults(); + }; + + emit_instr = [&](const HloInstruction* instr, mlir::ValueRange indices) + -> absl::StatusOr> { + // TODO(jreiffers): Check dominance, e.g.: + // + // padding_value = log(param) + // pad = pad(bar, padding_value) + // broadcast = broadcast(padding_value) + // pad + broadcasub + // + // If padding_value was first emitted in the context of pad, it'll be + // inside an scf.if. For now this doesn't matter, because the indexing + // is considered to be different, but once the partitioner is smarter, + // it will matter. + // + // Also, this caching should be combined with parameter caching. + std::vector indices_ptrs; + indices_ptrs.reserve(indices.size()); + for (auto index : indices) { + indices_ptrs.push_back(index.getAsOpaquePointer()); + } + auto& entry = cached_instructions[std::make_pair(instr, indices_ptrs)]; + if (!entry.empty()) { + return entry; + } + + TF_ASSIGN_OR_RETURN(entry, HloToMlir(instr, indices, provide_operand, + call_target_provider, builder)); + TF_RET_CHECK(!absl::c_any_of( + entry, [](const auto& entry) { return entry == nullptr; })) + << "null result for " << instr->ToShortString(); + return entry; + }; + + for (const auto* root : subgraph.roots) { + TF_ASSIGN_OR_RETURN(auto root_results, emit_instr(root, indices)); + results.append(root_results.begin(), root_results.end()); + } + return results; +} + +} // namespace mlir_converter +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h new file mode 100644 index 00000000000000..10c1cc23372f73 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h @@ -0,0 +1,103 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_ELEMENTAL_HLO_TO_MLIR_H_ +#define XLA_SERVICE_GPU_FUSIONS_MLIR_ELEMENTAL_HLO_TO_MLIR_H_ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" +#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/stream_executor/device_description.h" + +namespace xla { +namespace gpu { +namespace mlir_converter { + +using OperandProvider = + std::function>( + const HloInstruction* instr, int index, mlir::ValueRange indices)>; + +// Given a root of a subgraph, returns the corresponding function. +using CallTargetProvider = + std::function; + +// Emits MLIR to generate the given element of the HLO instruction. Required +// operands are accessed through the `operand_provider` function. +// CHECK fails if IsHloConversionSupported returns false. +absl::StatusOr> HloToMlir( + const HloInstruction* instr, mlir::ValueRange indices, + const OperandProvider& operand_provider, + const CallTargetProvider& call_target_provider, + mlir::ImplicitLocOpBuilder& builder); + +// Checks whether the given HLO instruction can be converted to MLIR. +bool IsHloOpSupported(const HloInstruction* instr, + se::CudaComputeCapability compute_capability); + +// Checks whether the given HLO computation is supported by the MLIR converter: +// - all instructions in it are supported +// - the signature is supported: if the computation is not a fusion computation, +// all arguments have rank 0. +bool IsHloConversionSupported(const HloComputation* computation, + se::GpuComputeCapability compute_capability); +bool IsHloConversionSupported(const HloFusionAdaptor& fusion, + se::GpuComputeCapability compute_capability); + +// Converts a function (subgraph) to an MLIR function producing one element of +// the result. The function must have the correct interface. +absl::Status SubgraphToMlirFunction( + const PartitionedComputation& computation, + const PartitionedComputation::Subgraph& subgraph, mlir::func::FuncOp& func, + const CallTargetProvider& call_target_provider); + +// Converts a function (subgraph) to MLIR that is emitted inline. +absl::StatusOr> SubgraphToMlir( + const PartitionedComputation& computation, + const PartitionedComputation::Subgraph& subgraph, + const CallTargetProvider& call_target_provider, mlir::ValueRange parameters, + mlir::ValueRange indices, mlir::ImplicitLocOpBuilder& builder); + +// Creates an affine.apply op for the given expression and values. +mlir::Value ApplyAffineExpr(mlir::AffineExpr expr, mlir::ValueRange dims, + mlir::ValueRange symbols, + mlir::ImplicitLocOpBuilder& b); + +// Creates affine.apply ops for each result of the given map. +llvm::SmallVector ApplyAffineMap(mlir::AffineMap map, + mlir::ValueRange dims, + mlir::ValueRange symbols, + mlir::ImplicitLocOpBuilder& b); + +// Checks all the **constraints** in the map (not the **ranges**). +mlir::Value CheckConstraints(const IndexingMap& map, mlir::ValueRange dims, + mlir::ValueRange symbols, + mlir::ImplicitLocOpBuilder& b); + +} // namespace mlir_converter +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_ELEMENTAL_HLO_TO_MLIR_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc new file mode 100644 index 00000000000000..82b07f03f1275e --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc @@ -0,0 +1,312 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" + +#include + +#include +#include +#include "absl/status/status.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project +#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" +#include "xla/service/hlo_parser.h" +#include "xla/service/llvm_ir/llvm_util.h" +#include "xla/status_macros.h" +#include "xla/tests/filecheck.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace mlir_converter { +namespace { + +class ElementalHloToMlirTest : public HloTestBase { + public: + ElementalHloToMlirTest() { + context_.loadDialect(); + } + + // Converts the root subgraph of the entry function of the given hlo module to + // MLIR. + absl::Status Run(const std::string& hlo, const std::string& filecheck_str) { + auto hlo_module = ParseAndReturnVerifiedModule(hlo).value(); + + mlir::ImplicitLocOpBuilder builder(mlir::UnknownLoc::get(&context_), + &context_); + auto module = llvm_ir::CreateMlirModuleOp(builder.getLoc()); + builder.setInsertionPointToStart(module->getBody()); + auto* entry_computation = hlo_module->entry_computation(); + mlir::func::FuncOp entry_func; + for (auto* computation : hlo_module->computations()) { + PartitionedComputation pc(computation); + TF_RET_CHECK(pc.subgraphs().size() == 1); + auto func = CreateSubgraphMlirFunction(pc.GetRootSubgraph(), builder); + func.setSymName(computation->name()); + if (computation == entry_computation) { + entry_func = func; + } else { + func.setSymVisibility("private"); + } + } + PartitionedComputation entry_pc(entry_computation); + TF_RETURN_IF_ERROR(SubgraphToMlirFunction( + entry_pc, entry_pc.GetRootSubgraph(), entry_func, + [&](const HloInstruction* instr) { + return module->lookupSymbol( + instr->parent()->name()); + })); + + // Canonicalize and CSE for better readability of check tests. + mlir::PassManager pm(&context_); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createCSEPass()); + TF_RET_CHECK(pm.run(module.get()).succeeded()); + + std::string out; + llvm::raw_string_ostream stream(out); + stream << entry_func; + + TF_ASSIGN_OR_RETURN(auto filecheck_result, + RunFileCheck(out, filecheck_str)); + TF_RET_CHECK(filecheck_result); + return absl::OkStatus(); + } + + mlir::MLIRContext context_; +}; + +TEST_F(ElementalHloToMlirTest, Reduce) { + TF_EXPECT_OK(Run(R"( + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT sum = f32[] add(p0, p1) + } + + ENTRY main { + p0 = f32[10,20,30,40] parameter(0) + p1 = f32[] parameter(1) + ROOT r = f32[10,30] reduce(p0, p1), dimensions={1,3}, + to_apply=add + })", + R"( + // CHECK: func.func @main( + // CHECK-SAME: %[[ARG0:.*]]: tensor<10x20x30x40xf32>, + // CHECK-SAME: %[[ARG1:.*]]: tensor, + // CHECK-SAME: %[[X:.*]]: index {{.*}}, %[[Y:.*]]: index {{.*}} -> f32 { + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 + // CHECK-DAG: %[[C20:.*]] = arith.constant 20 + // CHECK-DAG: %[[C40:.*]] = arith.constant 40 + // CHECK: %[[INIT:.*]] = tensor.extract %[[ARG1]][] + // CHECK: %[[RET:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C20]] + // CHECK-SAME: step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]]) + // CHECK: %[[RET_INNER:.*]] = scf.for %[[J:.*]] = %[[C0]] to %[[C40]] + // CHECK-SAME: iter_args(%[[ACC_INNER:.*]] = %[[ACC]]) + // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]] + // CHECK-SAME: [%[[X]], %[[I]], %[[Y]], %[[J]]] + // CHECK: %[[UPD:.*]] = func.call @add(%[[ACC_INNER]], %[[VAL]]) + // CHECK: scf.yield %[[UPD]] + // CHECK: } + // CHECK: scf.yield %[[RET_INNER]] + // CHECK: } + // CHECK: return %[[RET]] + )")); +} + +TEST_F(ElementalHloToMlirTest, Concatenate) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = f32[10,20,30] parameter(0) + p1 = f32[10,15,30] parameter(1) + p2 = f32[10,3,30] parameter(2) + ROOT r = f32[10,38,30] concatenate(p0, p1, p2), dimensions={1} + })", + R"( + // CHECK: func.func @main( + // CHECK-SAME: %[[ARG0:.*]]: tensor<10x20x30xf32>, + // CHECK-SAME: %[[ARG1:.*]]: tensor<10x15x30xf32>, + // CHECK-SAME: %[[ARG2:.*]]: tensor<10x3x30xf32>, + // CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}}, + // CHECK-SAME: %[[Z:.*]]: index {{{.*}}} + // CHECK-DAG: %[[C35:.*]] = arith.constant 35 + // CHECK-DAG: %[[C20:.*]] = arith.constant 20 + // CHECK: %[[IN_BOUNDS:.*]] = arith.cmpi ult, %[[Y]], %[[C20]] + // CHECK: %[[CONCAT:.*]] = scf.if %[[IN_BOUNDS]] + // CHECK: %[[P0_VAL:.*]] = tensor.extract %[[ARG0]] + // CHECK-SAME: [%[[X]], %[[Y]], %[[Z]]] + // CHECK: scf.yield %[[P0_VAL]] + // CHECK: } else { + // CHECK: %[[IN_BOUNDS:.*]] = arith.cmpi ult, %[[Y]], %[[C35]] + // CHECK: %[[CONCAT2:.*]] = scf.if %[[IN_BOUNDS]] + // CHECK: %[[OFFSET:.*]] = arith.subi %[[Y]], %[[C20]] + // CHECK: %[[P1_VAL:.*]] = tensor.extract %[[ARG1]] + // CHECK-SAME: [%[[X]], %[[OFFSET]], %[[Z]]] + // CHECK: scf.yield %[[P1_VAL]] + // CHECK: } else { + // CHECK: %[[OFFSET:.*]] = arith.subi %[[Y]], %[[C35]] + // CHECK: %[[P2_VAL:.*]] = tensor.extract %[[ARG2]] + // CHECK-SAME: [%[[X]], %[[OFFSET]], %[[Z]]] + // CHECK: scf.yield %[[P2_VAL]] + // CHECK: } + // CHECK: scf.yield %[[CONCAT2]] + // CHECK: } + // CHECK: return %[[CONCAT]] + )")); +} + +TEST_F(ElementalHloToMlirTest, Gather) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + operand = f32[33,34] parameter(0) + indices = s32[1806,1] parameter(1) + ROOT r = f32[1806,7,8] gather(operand, indices), offset_dims={1,2}, + collapsed_slice_dims={}, start_index_map={0}, + index_vector_dim=1, slice_sizes={7,8} + })", + R"( + // CHECK: func.func @main( + // CHECK-SAME: %[[ARG0:.*]]: tensor<33x34xf32>, + // CHECK-SAME: %[[ARG1:.*]]: tensor<1806x1xi32>, + // CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}}, + // CHECK-SAME: %[[Z:.*]]: index {{{.*}}} + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 + // CHECK-DAG: %[[C26:.*]] = arith.constant 26 + // CHECK: %[[IDX_I32:.*]] = tensor.extract %[[ARG1]][%[[X]], %[[C0]]] + // CHECK: %[[IDX:.*]] = arith.index_cast %[[IDX_I32]] : i32 to index + // CHECK: %[[CLAMP_HIGH:.*]] = arith.minsi %[[IDX]], %[[C26]] + // CHECK: %[[CLAMPED:.*]] = arith.maxsi %[[CLAMP_HIGH]], %[[C0]] + // CHECK: %[[X_IN:.*]] = arith.addi %[[CLAMPED]], %[[Y]] + // CHECK: %[[RET:.*]] = tensor.extract %[[ARG0]][%[[X_IN]], %[[Z]]] + // CHECK: return %[[RET]] + )")); +} + +TEST_F(ElementalHloToMlirTest, Pad) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = f32[4, 4] parameter(0) + p1 = f32[] parameter(1) + ROOT pad = f32[12, 16] pad(p0, p1), padding=1_4_1x4_8_0 + })", + R"( + // CHECK: func.func @main( + // CHECK-SAME: %[[ARG0:.*]]: tensor<4x4xf32>, + // CHECK-SAME: %[[ARG1:.*]]: tensor, + // CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}} + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 + // CHECK-DAG: %[[C4:.*]] = arith.constant 4 + // CHECK-DAG: %[[C7:.*]] = arith.constant 7 + // CHECK: %[[CONSTRAINT_VAL:.*]] = affine.apply + // CHECK-SAME: <()[s0] -> (s0 - ((s0 - 1) floordiv 2) * 2 - 1)> + // CHECK-SAME: ()[%[[X]]] + // CHECK: %[[CONSTRAINT:.*]] = arith.cmpi eq, %[[CONSTRAINT_VAL]], %[[C0]] + // CHECK: %[[X_L:.*]] = arith.cmpi sge, %[[X]], %[[C1]] + // CHECK: %[[X_H:.*]] = arith.cmpi sle, %[[X]], %[[C7]] + // CHECK: %[[X_BOUNDS:.*]] = arith.andi %[[X_L]], %[[X_H]] + // CHECK: %[[X_AND_CONSTRAINT:.*]] = arith.andi %[[CONSTRAINT]], %[[X_BOUNDS]] + // CHECK: %[[Y_L:.*]] = arith.cmpi sge, %[[Y]], %[[C4]] + // CHECK: %[[Y_H:.*]] = arith.cmpi sle, %[[Y]], %[[C7]] + // CHECK: %[[Y_BOUNDS:.*]] = arith.andi %[[Y_L]], %[[Y_H]] + // CHECK: %[[FROM_INPUT:.*]] = arith.andi %[[X_AND_CONSTRAINT]], %[[Y_BOUNDS]] + // CHECK: %[[RET:.*]] = scf.if %[[FROM_INPUT]] + // CHECK: %[[X_IN:.*]] = affine.apply + // CHECK-SAME: <()[s0] -> ((s0 - 1) floordiv 2)>()[%[[X]]] + // CHECK: %[[Y_IN:.*]] = affine.apply + // CHECK-SAME: <()[s0] -> (s0 - 4)>()[%[[Y]]] + // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]][%[[X_IN]], %[[Y_IN]]] + // CHECK: scf.yield %[[VAL]] + // CHECK: } else { + // CHECK: %[[PAD_VAL:.*]] = tensor.extract %[[ARG1]][] + // CHECK: scf.yield %[[PAD_VAL]] + // CHECK: } + // CHECK: return %[[RET]] + )")); +} + +TEST_F(ElementalHloToMlirTest, Transpose) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = f32[4,5,6] parameter(0) + ROOT transpose = f32[6,5,4] transpose(p0), dimensions={2,1,0} + })", + R"( + // CHECK: func.func @main( + // CHECK-SAME: %[[ARG0:.*]]: tensor<4x5x6xf32>, + // CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}}, + // CHECK-SAME: %[[Z:.*]]: index {{{.*}}} + // CHECK: %[[RET:.*]] = tensor.extract %[[ARG0]] + // CHECK-SAME: [%[[Z]], %[[Y]], %[[X]]] + // CHECK: return %[[RET]] + )")); +} + +TEST_F(ElementalHloToMlirTest, Broadcast) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = f32[4,5] parameter(0) + ROOT broadcast = f32[6,4,5] broadcast(p0), dimensions={1,2} + })", + R"( + // CHECK: func.func @main( + // CHECK-SAME: %[[ARG0:.*]]: tensor<4x5xf32>, + // CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}}, + // CHECK-SAME: %[[Z:.*]]: index {{{.*}}} + // CHECK: %[[RET:.*]] = tensor.extract %[[ARG0]] + // CHECK-SAME: [%[[Y]], %[[Z]]] + // CHECK: return %[[RET]] + )")); +} + +TEST_F(ElementalHloToMlirTest, Add) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = f32[4] parameter(0) + p1 = f32[4] parameter(1) + ROOT add = f32[4] add(p0, p1) + })", + R"( + // CHECK: func.func @main( + // CHECK-SAME: %[[ARG0:.*]]: tensor<4xf32>, %[[ARG1:.*]]: tensor<4xf32>, + // CHECK-SAME: %[[X:.*]]: index {{.*}} + // CHECK: %[[A:.*]] = tensor.extract %[[ARG0]][%[[X]]] + // CHECK: %[[B:.*]] = tensor.extract %[[ARG1]][%[[X]]] + // CHECK: %[[RET:.*]] = arith.addf %[[A]], %[[B]] + // CHECK: return %[[RET]] + )")); +} + +} // namespace +} // namespace mlir_converter +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/expand_float_conversions.cc b/third_party/xla/xla/service/gpu/fusions/mlir/expand_float_conversions.cc new file mode 100644 index 00000000000000..9094f722bc1813 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/mlir/expand_float_conversions.cc @@ -0,0 +1,209 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h" +#include "xla/primitive_util.h" +#include "xla/service/gpu/fusions/mlir/passes.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace gpu { + +#define GEN_PASS_DEF_EXPANDFLOATCONVERSIONSPASS +#include "xla/service/gpu/fusions/mlir/passes.h.inc" + +namespace { + +class ExpandFloatConversionsPass + : public impl::ExpandFloatConversionsPassBase { + public: + using ExpandFloatConversionsPassBase::ExpandFloatConversionsPassBase; + void runOnOperation() override; +}; + +template +struct RewriteIntToBF16 : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + Op op, mlir::PatternRewriter& rewriter) const override { + if (op.getResult().getType() != rewriter.getBF16Type()) { + return rewriter.notifyMatchFailure(op, "not a bf16 itofp"); + } + auto f32 = + rewriter.create(op.getLoc(), rewriter.getF32Type(), op.getIn()); + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), f32); + return mlir::success(); + } +}; + +template +struct RewriteBF16ToInt : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + Op op, mlir::PatternRewriter& rewriter) const override { + if (op.getIn().getType() != rewriter.getBF16Type()) { + return rewriter.notifyMatchFailure(op, "not a bf16 fptoi"); + } + auto f32 = rewriter.create( + op.getLoc(), rewriter.getF32Type(), op.getIn()); + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), f32); + return mlir::success(); + } +}; + +struct RewriteExtBF16ToF32 + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::arith::ExtFOp op, mlir::PatternRewriter& rewriter) const override { + if (op.getIn().getType() != rewriter.getBF16Type() || + op.getResult().getType() != rewriter.getF32Type()) { + return rewriter.notifyMatchFailure(op, "not a bf16 -> f32 extf"); + } + auto bitcast = rewriter.create( + op.getLoc(), rewriter.getI16Type(), op.getIn()); + auto exti = rewriter.create( + op.getLoc(), rewriter.getI32Type(), bitcast); + auto shl = rewriter.create( + op.getLoc(), exti, + rewriter.create(op.getLoc(), 16, + rewriter.getI32Type())); + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), shl); + return mlir::success(); + } +}; + +struct RewriteExtBF16ToF64 + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::arith::ExtFOp op, mlir::PatternRewriter& rewriter) const override { + if (op.getIn().getType() != rewriter.getBF16Type() || + op.getResult().getType() != rewriter.getF64Type()) { + return rewriter.notifyMatchFailure(op, "not a bf16 -> f64 extf"); + } + auto f32 = rewriter.create( + op.getLoc(), rewriter.getF32Type(), op.getIn()); + rewriter.replaceOpWithNewOp(op, op.getOut().getType(), + f32); + return mlir::success(); + } +}; + +struct RewriteTruncF32ToBF16 + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::arith::TruncFOp op, + mlir::PatternRewriter& rewriter) const override { + if (op.getIn().getType() != rewriter.getF32Type() || + op.getResult().getType() != rewriter.getBF16Type()) { + return rewriter.notifyMatchFailure(op, "not an f32 -> bf16 truncf"); + } + + // The default lowering for f32 -> f16 doesn't round correctly. + mlir::NamedAttribute exponent_bits( + rewriter.getStringAttr("exponent_bits"), + rewriter.getI32IntegerAttr(primitive_util::ExponentWidth(BF16))); + mlir::NamedAttribute mantissa_bits( + rewriter.getStringAttr("mantissa_bits"), + rewriter.getI32IntegerAttr(primitive_util::SignificandWidth(BF16) - 1)); + + auto reduced = mlir::mhlo::MhloOpToStdScalarOp::mapOpOfType< + mlir::mhlo::ReducePrecisionOp>( + op.getLoc(), rewriter.getF32Type(), {rewriter.getF32Type()}, + mlir::mhlo::ReducePrecisionOpAdaptor( + {op.getIn()}, + mlir::DictionaryAttr::get(rewriter.getContext(), + {exponent_bits, mantissa_bits})), + &rewriter); + auto bitcast = rewriter.create( + op.getLoc(), rewriter.getI32Type(), reduced); + auto shr = rewriter.create( + op.getLoc(), bitcast, + rewriter.create(op.getLoc(), 16, + rewriter.getI32Type())); + auto trunc = rewriter.create( + op.getLoc(), rewriter.getI16Type(), shr); + rewriter.replaceOpWithNewOp( + op, op.getOut().getType(), trunc); + return mlir::success(); + } +}; + +struct RewriteTruncF64ToBF16 + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::arith::TruncFOp op, + mlir::PatternRewriter& rewriter) const override { + if (op.getIn().getType() != rewriter.getF64Type() || + op.getResult().getType() != rewriter.getBF16Type()) { + return rewriter.notifyMatchFailure(op, "not an f64 -> bf16 truncf"); + } + auto f32 = rewriter.create( + op.getLoc(), rewriter.getF32Type(), op.getIn()); + rewriter.replaceOpWithNewOp( + op, op.getOut().getType(), f32); + return mlir::success(); + } +}; + +void ExpandFloatConversionsPass::runOnOperation() { + mlir::RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add>(&getContext()); + patterns.add>(&getContext()); + patterns.add>(&getContext()); + patterns.add>(&getContext()); + if (include_bf16_) { + patterns.add(&getContext()); + } + patterns.add(&getContext()); + if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + signalPassFailure(); + } +} + +} // namespace + +std::unique_ptr CreateExpandFloatConversionsPass(bool enable_bf16) { + return createExpandFloatConversionsPass( + ExpandFloatConversionsPassOptions{enable_bf16}); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/lower_tensors.cc b/third_party/xla/xla/service/gpu/fusions/mlir/lower_tensors.cc new file mode 100644 index 00000000000000..b6ba47f423323b --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/mlir/lower_tensors.cc @@ -0,0 +1,262 @@ + +/* Copyright 2024 The OpenXLA Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" // from @llvm-project +#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" // from @llvm-project +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeRange.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "xla/layout_util.h" +#include "xla/shape_util.h" + +namespace xla { +namespace gpu { + +#define GEN_PASS_DEF_LOWERTENSORSPASS +#include "xla/service/gpu/fusions/mlir/passes.h.inc" + +namespace { + +using mlir::failure; +using mlir::success; + +struct RewriteFunctionSignatures : mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::func::FuncOp op, mlir::PatternRewriter& rewriter) const override { + auto is_tensor = [](mlir::Type ty) { + return ty.isa(); + }; + if (!llvm::any_of(op.getFunctionType().getInputs(), is_tensor)) { + return rewriter.notifyMatchFailure(op, + "the function has no input tensors"); + } + + bool some_tensor_result = + llvm::any_of(op.getFunctionType().getResults(), is_tensor); + bool all_tensor_results = + llvm::all_of(op.getFunctionType().getResults(), is_tensor); + if (some_tensor_result && !all_tensor_results) { + op->emitOpError("function has a mix of tensor and non-tensor results"); + return failure(); + } + + mlir::TypeRange new_results = op.getFunctionType().getResults(); + if (some_tensor_result) { + new_results = {}; + auto terminator = op.getFunctionBody().front().getTerminator(); + rewriter.setInsertionPoint(terminator); + rewriter.replaceOpWithNewOp(terminator); + } + + llvm::SmallVector new_operands( + op.getFunctionType().getInputs()); + for (auto&& [index, operand] : llvm::enumerate(new_operands)) { + if (is_tensor(operand)) { + rewriter.setInsertionPointToStart(&op.getBody().front()); + auto cast = rewriter.create( + op.getLoc(), operand, op.getArgument(index)); + op.getArgument(index).replaceAllUsesExcept(cast.getResult(0), cast); + operand = mlir::LLVM::LLVMPointerType::get(op.getContext()); + } + } + + op.setFunctionType(rewriter.getFunctionType(new_operands, new_results)); + auto& entry = op->getRegion(0).front(); + for (auto [arg, arg_type] : llvm::zip(entry.getArguments(), new_operands)) { + arg.setType(arg_type); + } + + return success(); + } +}; + +mlir::Value CreateGep(mlir::Operation* op, + mlir::TypedValue tensor, + mlir::ValueRange indices, + mlir::PatternRewriter& rewriter) { + auto ptr = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); + auto byte_shape = ShapeUtil::MakeShape(U8, tensor.getType().getShape()); + if (auto encoding = tensor.getType().getEncoding()) { + *byte_shape.mutable_layout() = LayoutUtil::MakeLayout(llvm::to_vector( + encoding.cast().getValues())); + } + auto linearize_map = mlir::getAffineConstantExpr(0, rewriter.getContext()); + for (auto [dim, stride] : + llvm::enumerate(*ShapeUtil::ByteStrides(byte_shape))) { + linearize_map = linearize_map + + mlir::getAffineDimExpr(dim, rewriter.getContext()) * stride; + } + + rewriter.setInsertionPoint(op); + mlir::Value index = rewriter.create( + tensor.getLoc(), linearize_map, indices); + auto index_ty = + ShapeUtil::ElementsIn(byte_shape) < std::numeric_limits::max() + ? rewriter.getI32Type() + : rewriter.getI64Type(); + index = rewriter.create(tensor.getLoc(), index_ty, + index); + + auto tensor_ptr = rewriter + .create( + tensor.getLoc(), ptr, tensor) + .getResult(0); + auto gep = rewriter.create( + tensor.getLoc(), ptr, tensor.getType().getElementType(), tensor_ptr, + index); + gep.setInbounds(true); + return gep; +} + +struct RewriteTensorExtract : mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::tensor::ExtractOp op, + mlir::PatternRewriter& rewriter) const override { + auto gep = CreateGep(op, op.getTensor(), op.getIndices(), rewriter); + rewriter.replaceOpWithNewOp(op, op.getType(), gep); + return success(); + } +}; + +struct RewriteTensorInsert : mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::tensor::InsertOp op, + mlir::PatternRewriter& rewriter) const override { + mlir::Value dest = op.getDest(); + while (dest.getDefiningOp()) { + int result_number = dest.cast().getResultNumber(); + if (auto insert = dest.getDefiningOp()) { + dest = insert.getDest(); + } else if (auto scf_if = dest.getDefiningOp()) { + // Pick one of the branches, they're required to yield the same buffers. + dest = scf_if.getThenRegion().front().getTerminator()->getOperand( + result_number); + } else if (auto scf_for = dest.getDefiningOp()) { + dest = scf_for.getInitArgs()[result_number]; + } + } + + auto gep = + CreateGep(op, dest.cast>(), + op.getIndices(), rewriter); + rewriter.create(gep.getLoc(), op.getScalar(), gep); + + op.replaceAllUsesWith(op.getDest()); + op.erase(); + return success(); + } +}; + +struct RewriteCall : mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::func::CallOp op, mlir::PatternRewriter& rewriter) const override { + if (!llvm::any_of(op->getOperandTypes(), [](mlir::Type ty) { + return ty.isa(); + })) { + return rewriter.notifyMatchFailure(op, "the call has no input tensors"); + } + + for (const auto&& [index, arg] : llvm::enumerate(op.getOperands())) { + if (arg.getType().isa()) { + op.setOperand( + index, + rewriter + .create( + op.getLoc(), + mlir::LLVM::LLVMPointerType::get(op.getContext()), arg) + .getResult(0)); + } + } + return success(); + } +}; + +class LowerTensorsPass : public impl::LowerTensorsPassBase { + public: + void runOnOperation() override; +}; + +void LowerTensorsPass::runOnOperation() { + mlir::RewritePatternSet tensor_patterns(&getContext()); + tensor_patterns.add(&getContext()); + if (mlir::failed(mlir::applyPatternsAndFoldGreedily( + getOperation(), std::move(tensor_patterns)))) { + signalPassFailure(); + } + + mlir::RewritePatternSet function_patterns(&getContext()); + function_patterns.add(&getContext()); + mlir::scf::ForOp::getCanonicalizationPatterns(function_patterns, + &getContext()); + mlir::scf::IfOp::getCanonicalizationPatterns(function_patterns, + &getContext()); + if (mlir::failed(mlir::applyPatternsAndFoldGreedily( + getOperation(), std::move(function_patterns)))) { + signalPassFailure(); + } + + getOperation()->walk([this](mlir::LLVM::LoadOp load) { + mlir::Value addr = load.getAddr(); + if (auto gep = load.getAddr().getDefiningOp()) { + addr = gep.getBase(); + } + if (auto base = mlir::dyn_cast(addr)) { + if (auto func = mlir::dyn_cast( + base.getOwner()->getParentOp())) { + if (func.getArgAttr(base.getArgNumber(), "xla.invariant")) { + load.setInvariant(true); + } + return; + } + } + load.emitOpError("load op address is not (a GEP of) a function argument"); + signalPassFailure(); + }); +} + +} // namespace + +std::unique_ptr<::mlir::Pass> CreateLowerTensorsPass() { + return std::make_unique(); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/lower_to_llvm.cc b/third_party/xla/xla/service/gpu/fusions/mlir/lower_to_llvm.cc new file mode 100644 index 00000000000000..a1d7fe28549329 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/mlir/lower_to_llvm.cc @@ -0,0 +1,92 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" // from @llvm-project +#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" // from @llvm-project +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" // from @llvm-project +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" // from @llvm-project +#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" // from @llvm-project +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" // from @llvm-project +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" // from @llvm-project +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project +#include "mlir/Dialect/Arith/Transforms/Passes.h" // from @llvm-project +#include "mlir/Dialect/Complex/IR/Complex.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Interfaces/DataLayoutInterfaces.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace xla { +namespace gpu { + +#define GEN_PASS_DEF_LOWERTOLLVMPASS +#include "xla/service/gpu/fusions/mlir/passes.h.inc" + +namespace { + +class LowerToLLVMPass : public impl::LowerToLLVMPassBase { + public: + using LowerToLLVMPassBase::LowerToLLVMPassBase; + + void runOnOperation() override { + // Populate type conversions. + mlir::LLVMTypeConverter type_converter(getOperation().getContext()); + mlir::LLVMConversionTarget target(*getOperation().getContext()); + + // Populate patterns. + mlir::RewritePatternSet patterns(&getContext()); + mlir::populateAffineToStdConversionPatterns(patterns); + mlir::populateSCFToControlFlowConversionPatterns(patterns); + mlir::arith::populateArithExpandOpsPatterns(patterns); + mlir::arith::populateArithToLLVMConversionPatterns(type_converter, + patterns); + mlir::populateGpuToNVVMConversionPatterns(type_converter, patterns); + mlir::populateFuncToLLVMConversionPatterns(type_converter, patterns); + mlir::cf::populateControlFlowToLLVMConversionPatterns(type_converter, + patterns); + mlir::populateComplexToLLVMConversionPatterns(type_converter, patterns); + mlir::populateMathToLLVMConversionPatterns(type_converter, patterns); + + // Setup target. + mlir::configureGpuToNVVMConversionLegality(target); + target.addIllegalDialect(); + target.addLegalOp(); + + if (failed( + applyFullConversion(getOperation(), target, std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr CreateLowerToLLVMPass() { + return std::make_unique(); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/merge_pointers_to_same_slice.cc b/third_party/xla/xla/service/gpu/fusions/mlir/merge_pointers_to_same_slice.cc new file mode 100644 index 00000000000000..ce9b73648ef554 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/mlir/merge_pointers_to_same_slice.cc @@ -0,0 +1,117 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "llvm/ADT/BitVector.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace xla { +namespace gpu { + +#define GEN_PASS_DEF_MERGEPOINTERSTOSAMESLICEPASS +#include "xla/service/gpu/fusions/mlir/passes.h.inc" + +namespace { + +class MergePointersToSameSlicePass + : public impl::MergePointersToSameSlicePassBase< + MergePointersToSameSlicePass> { + public: + void runOnOperation() override; +}; + +struct PackedArgs { + llvm::BitVector args_to_erase; + // replacement_args[i] == i iff !args_to_erase[i]. + llvm::SmallVector replacement_args; + + PackedArgs() = default; + explicit PackedArgs(mlir::func::FuncOp func) { + absl::flat_hash_map> slice_to_operand; + args_to_erase.resize(func.getNumArguments()); + replacement_args.reserve(func.getNumArguments()); + for (int i = 0; i < func.getNumArguments(); ++i) { + replacement_args.push_back(i); + } + + for (auto [idx, operand] : llvm::enumerate(func.getArguments())) { + auto slice_index = func.getArgAttr(idx, "xla.slice_index"); + if (!slice_index) { + continue; + } + + auto& target_index = slice_to_operand[static_cast( + slice_index.cast().getInt())]; + if (target_index) { + replacement_args[idx] = *target_index; + args_to_erase[idx] = true; + } else { + target_index = idx; + } + } + } + + void Pack(mlir::func::FuncOp op) { + for (auto [idx, arg] : llvm::enumerate(op.getArguments())) { + if (replacement_args[idx] != idx) { + arg.replaceAllUsesWith(op.getArgument(replacement_args[idx])); + } + } + op.eraseArguments(args_to_erase); + for (int i = 0; i < op.getNumArguments(); ++i) { + if (op.getArgAttr(i, "xla.slice_index")) { + op.removeArgAttr(i, "xla.slice_index"); + op.setArgAttr(i, mlir::LLVM::LLVMDialect::getNoAliasAttrName(), + mlir::UnitAttr::get(op->getContext())); + } + } + } + + void Pack(mlir::func::CallOp op) { op->eraseOperands(args_to_erase); } +}; + +void MergePointersToSameSlicePass::runOnOperation() { + mlir::func::FuncOp entry; + + absl::flat_hash_map args_to_pack; + getOperation()->walk([&](mlir::func::FuncOp func) { + args_to_pack[func.getName()] = PackedArgs(func); + }); + getOperation()->walk([&](mlir::func::CallOp call) { + args_to_pack[call.getCallee()].Pack(call); + }); + getOperation()->walk([&](mlir::func::FuncOp func) { + args_to_pack[func.getName()].Pack(func); + }); +} + +} // namespace + +std::unique_ptr> +CreateMergePointersToSameSlicePass() { + return std::make_unique(); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc new file mode 100644 index 00000000000000..d42821697f3501 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc @@ -0,0 +1,388 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicsNVPTX.h" +#include "llvm/Linker/Linker.h" +#include "llvm/Support/Casting.h" +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project +#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Arith/Transforms/Passes.h" // from @llvm-project +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" // from @llvm-project +#include "mlir/Dialect/Func/Extensions/InlinerExtension.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" // from @llvm-project +#include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project +#include "mlir/Dialect/MemRef/Transforms/Passes.h" // from @llvm-project +#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Interfaces/DataLayoutInterfaces.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Target/LLVMIR/Export.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/fusions/fusion_emitter.h" +#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" +#include "xla/service/gpu/fusions/mlir/passes.h" +#include "xla/service/gpu/fusions/mlir/type_util.h" +#include "xla/service/gpu/ir_emitter_context.h" +#include "xla/service/gpu/kernel_arguments.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/runtime/kernel_thunk.h" +#include "xla/service/gpu/target_util.h" +#include "xla/service/llvm_ir/llvm_util.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "tsl/platform/errors.h" + +namespace xla { +namespace gpu { +namespace { + +void AddRanges(llvm::Function* func, const LaunchDimensions& launch_dims, + llvm::Module* module) { + for (auto& block : *func) { + for (auto& instr : block) { + if (auto* call = llvm::dyn_cast(&instr)) { + if (auto* callee = call->getCalledFunction()) { + switch (callee->getIntrinsicID()) { + case llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x: + llvm_ir::AddRangeMetadata( + 0, launch_dims.thread_counts_per_block().x, call, module); + break; + case llvm::Intrinsic::nvvm_read_ptx_sreg_tid_y: + llvm_ir::AddRangeMetadata( + 0, launch_dims.thread_counts_per_block().y, call, module); + break; + case llvm::Intrinsic::nvvm_read_ptx_sreg_tid_z: + llvm_ir::AddRangeMetadata( + 0, launch_dims.thread_counts_per_block().z, call, module); + break; + case llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x: + llvm_ir::AddRangeMetadata(0, launch_dims.block_counts().x, call, + module); + break; + case llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_y: + llvm_ir::AddRangeMetadata(0, launch_dims.block_counts().y, call, + module); + break; + case llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_z: + llvm_ir::AddRangeMetadata(0, launch_dims.block_counts().z, call, + module); + break; + } + } + } + } + } +} + +} // namespace + +mlir::Value MlirFusionEmitterBase::EmitBlockId( + mlir::ImplicitLocOpBuilder& builder, int dim) const { + const auto& counts = launch_dimensions().block_counts(); + int64_t count = dim == 0 ? counts.x : dim == 1 ? counts.y : counts.z; + auto block_id = builder.create( + static_cast(dim)); + block_id->setAttr("xla.range", builder.getIndexArrayAttr({0, count - 1})); + return block_id; +} + +mlir::Value MlirFusionEmitterBase::EmitThreadId( + mlir::ImplicitLocOpBuilder& builder, int dim) const { + const auto& counts = launch_dimensions().thread_counts_per_block(); + int64_t count = dim == 0 ? counts.x : dim == 1 ? counts.y : counts.z; + auto thread_id = builder.create( + static_cast(dim)); + thread_id->setAttr("xla.range", builder.getIndexArrayAttr({0, count - 1})); + return thread_id; +} + +absl::StatusOr MlirFusionEmitterBase::Emit( + IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion) const { + TF_ASSIGN_OR_RETURN( + auto args, + KernelArguments::Create(ir_emitter_context.buffer_assignment(), &fusion)); + auto launch_dims = launch_dimensions(); + auto [status_or_entry, cached] = + ir_emitter_context.kernel_cache().GetWithStatus( + fusion.fused_instructions_computation(), args.args(), + /*discriminator=*/"", + [&]() -> absl::StatusOr { + std::string kernel_name = + ir_emitter_context.name_uniquer()->GetUniqueName( + llvm_ir::SanitizeFunctionName(std::string(fusion.name()))); + if (ir_emitter_context.emit_kernels()) { + TF_ASSIGN_OR_RETURN( + auto module, + CreateLLVMModule( + *ir_emitter_context.mlir_context(), + ir_emitter_context.llvm_module()->getContext(), + ir_emitter_context.gpu_device_info(), fusion, kernel_name, + &ir_emitter_context.buffer_assignment())); + auto* kernel_func = module->getFunction(kernel_name); + AddRanges(kernel_func, launch_dims, module.get()); + + auto* target = ir_emitter_context.llvm_module(); + module->setDataLayout(target->getDataLayout()); + module->setTargetTriple(target->getTargetTriple()); + + llvm::IRBuilder<> builder(module->getContext()); + AnnotateFunctionAsGpuKernel(module.get(), kernel_func, &builder); + TF_RETURN_IF_ERROR(AnnotateKernelLaunchDimensions( + ir_emitter_context.gpu_device_info(), launch_dims, + kernel_name, module.get())); + + // Use override flag because libdevice functions can be present in + // both. + CHECK(!llvm::Linker::linkModules( + *target, std::move(module), + llvm::Linker::Flags::OverrideFromSrc)); + } else { + VLOG(3) << "Skipped kernel compilation."; + } + + return KernelReuseCache::Entry{kernel_name, launch_dims, + std::nullopt, + /*shmem_bytes=*/0}; + }); + TF_ASSIGN_OR_RETURN(const KernelReuseCache::Entry* entry, status_or_entry); + + if (cached) { + VLOG(3) << "Reuse: " << fusion.name() << " -> " << entry->kernel_name; + } + + FusionEmissionResult result; + result.thunks.emplace_back(std::make_unique( + &fusion, entry->kernel_name, args.args(), launch_dims, entry->cluster_dim, + entry->shmem_bytes)); + return result; +} + +absl::StatusOr> +MlirFusionEmitterBase::CreateLLVMModule( + mlir::MLIRContext& mlir_context, llvm::LLVMContext& llvm_context, + const se::DeviceDescription& device, const HloFusionInstruction& fusion, + const std::string& entry_function_name, + const BufferAssignment* buffer_assignment) const { + TF_RET_CHECK(device.cuda_compute_capability().major >= 1) + << "Unsupported device type: " << device.name(); + TF_ASSIGN_OR_RETURN( + auto module, CreateMLIRModule(mlir_context, fusion, entry_function_name, + buffer_assignment)); + + mlir::PassManager pm(&mlir_context); + // TODO(jreiffers): Proper inlining and CSE of function calls. + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createCSEPass()); + pm.addPass(CreatePropagateSliceIndicesPass()); + pm.addPass(CreateLowerTensorsPass()); + pm.addPass(CreateMergePointersToSameSlicePass()); + + // LowerTensors creates new affine.apply ops. Fold and CSE them so + // simplify-affine has maximally folded expressions to work with. + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createCSEPass()); + pm.addPass(CreateSimplifyAffinePass()); + + // simplify-affine lowers most affine.apply ops, but if it can't prove a + // division or modulo is unsigned, affine.apply ops will remain. + pm.addPass(mlir::createLowerAffinePass()); + + pm.addPass(mlir::createLoopInvariantCodeMotionPass()); + pm.addPass(mlir::createSymbolDCEPass()); + pm.addPass(mlir::createCSEPass()); + pm.addPass(CreateLowerTensorsPass()); + pm.addPass(CreateExpandFloatConversionsPass( + !device.cuda_compute_capability().IsAtLeastAmpere())); + pm.addPass(CreateLowerToLLVMPass()); + TF_RET_CHECK(pm.run(module.get()).succeeded()); + + auto llvm_module = mlir::translateModuleToLLVMIR(module.get(), llvm_context); + TF_RET_CHECK(llvm_module != nullptr) + << "Failed to translate module to LLVM IR."; + + return llvm_module; +} + +absl::StatusOr> +MlirFusionEmitterBase::CreateMLIRModule( + mlir::MLIRContext& context, const HloFusionInstruction& fusion, + const std::string& entry_function_name, + const BufferAssignment* buffer_assignment) const { + context.loadDialect(); + mlir::DialectRegistry registry; + mlir::func::registerInlinerExtension(registry); + context.appendDialectRegistry(registry); + + mlir::OpBuilder builder(&context); + auto loc = mlir::NameLoc::get(builder.getStringAttr(fusion.name())); + mlir::OwningOpRef module = llvm_ir::CreateMlirModuleOp(loc); + + // Create the entry function. + llvm::SmallVector param_types; + std::optional args; + if (buffer_assignment != nullptr) { + TF_ASSIGN_OR_RETURN(args, + KernelArguments::Create(*buffer_assignment, &fusion)); + } + // Annotate tensors with the buffer indices. This way, the buffer propagation + // pass can clean them up later. + int next_slice_index = 0; + absl::flat_hash_map> + slice_indices; + auto get_arg_attrs = [&](int index) -> absl::StatusOr { + if (!args) { + return builder.getDictionaryAttr({builder.getNamedAttr( + "xla.slice_index", builder.getIndexAttr(next_slice_index++))}); + } + + const auto& arg = args->args()[index]; + llvm::SmallVector attrs; + attrs.push_back(builder.getNamedAttr( + "xla.slice_index", builder.getIndexAttr(arg.llvm_arg_index()))); + attrs.push_back( + builder.getNamedAttr(mlir::LLVM::LLVMDialect::getAlignAttrName(), + builder.getIndexAttr(arg.alignment()))); + attrs.push_back(builder.getNamedAttr( + mlir::LLVM::LLVMDialect::getDereferenceableAttrName(), + builder.getIndexAttr(arg.slice().size()))); + if (!arg.written()) { + attrs.push_back( + builder.getNamedAttr("xla.invariant", builder.getUnitAttr())); + } + return builder.getDictionaryAttr(attrs); + }; + + llvm::SmallVector arg_attrs; + int arg_index = 0; + for (auto* param : fusion.operands()) { + param_types.push_back( + mlir_converter::TensorShapeToMlirType(param->shape(), builder)); + TF_ASSIGN_OR_RETURN(arg_attrs.emplace_back(), get_arg_attrs(arg_index++)); + } + + auto result_types = mlir_converter::ShapeToMlirTypes(fusion.shape(), builder); + param_types.append(result_types.begin(), result_types.end()); + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + fusion.shape(), [&](const auto& shape, const ShapeIndex& index) { + if (shape.IsArray()) { + TF_ASSIGN_OR_RETURN(arg_attrs.emplace_back(), + get_arg_attrs(arg_index++)); + } + return absl::OkStatus(); + })); + + builder.setInsertionPointToStart(module->getBody()); + auto entry_func = builder.create( + loc, entry_function_name, + mlir::FunctionType::get(&context, param_types, result_types), + /*sym_visibility=*/mlir::StringAttr{}, + mlir::ArrayAttr::get(&context, arg_attrs), + /*res_attrs=*/mlir::ArrayAttr{}); + entry_func->setAttr("xla.entry", mlir::UnitAttr::get(&context)); + + TF_RETURN_IF_ERROR(EmitMlir(module.get(), entry_func, fusion)); + + // Run a minimal simplification pipeline. + mlir::PassManager pm(&context); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createCSEPass()); + TF_RET_CHECK(pm.run(module.get()).succeeded()); + return module; +} + +absl::StatusOr> +MlirFusionEmitterBase::EmitLoopNest( + mlir::ImplicitLocOpBuilder& b, mlir::ValueRange output_tensors, + const IndexingMap& thread_to_output_map, + const std::function>( + mlir::ValueRange outputs_tensors, mlir::ValueRange output_indices)>& + create_body) const { + llvm::SmallVector map_dims{ + EmitThreadId(b, 0), EmitThreadId(b, 1), EmitThreadId(b, 2), + EmitBlockId(b, 0), EmitBlockId(b, 1), EmitBlockId(b, 2)}; + llvm::SmallVector map_symbols; + + auto cst = [&](int64_t v) { + return b.create(b.getIndexAttr(v)); + }; + + std::function>( + int, mlir::ValueRange)> + make_loops; + make_loops = [&](int i, mlir::ValueRange current_outputs) + -> absl::StatusOr> { + if (i < thread_to_output_map.GetAffineMap().getNumSymbols()) { + auto range = thread_to_output_map.GetSymbolRange(i); + auto for_op = b.create(cst(range.lower_bound), + cst(range.upper_bound + 1), + cst(1), current_outputs); + map_symbols.push_back(for_op.getInductionVar()); + b.setInsertionPointToStart(for_op.getBody()); + TF_ASSIGN_OR_RETURN(auto results, + make_loops(i + 1, for_op.getRegionIterArgs())); + b.create(results); + b.setInsertionPointAfter(for_op); + return for_op.getResults(); + } + auto is_in_bounds = mlir_converter::CheckConstraints( + thread_to_output_map, map_dims, map_symbols, b); + auto if_op = b.create(mlir::TypeRange{current_outputs}, + is_in_bounds, true, true); + b.setInsertionPointToStart(if_op.getBody(0)); + auto output_indices = mlir_converter::ApplyAffineMap( + thread_to_output_map.GetAffineMap(), map_dims, map_symbols, b); + TF_ASSIGN_OR_RETURN(auto results, + create_body(current_outputs, output_indices)); + b.create(results); + b.setInsertionPointToStart(if_op.getBody(1)); + b.create(current_outputs); + b.setInsertionPointAfter(if_op); + return if_op.getResults(); + }; + + return make_loops(0, output_tensors); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h new file mode 100644 index 00000000000000..b458fa7a6037c4 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h @@ -0,0 +1,82 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_MLIR_FUSION_EMITTER_H_ +#define XLA_SERVICE_GPU_FUSIONS_MLIR_MLIR_FUSION_EMITTER_H_ + +#include + +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/fusions/fusion_emitter.h" +#include "xla/service/gpu/ir_emitter_context.h" +#include "xla/service/gpu/model/indexing_map.h" + +namespace xla { +namespace gpu { + +class MlirFusionEmitterBase : public KernelFusionInterface { + public: + absl::StatusOr Emit( + IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion) const final; + + // Visible for testing. `buffer_assignment` is optional for testing (assigns + // a different buffer to each tensor). + absl::StatusOr> CreateLLVMModule( + mlir::MLIRContext& mlir_context, llvm::LLVMContext& llvm_context, + const se::DeviceDescription& device, const HloFusionInstruction& fusion, + const std::string& entry_function_name, + const BufferAssignment* buffer_assignment) const; + + // Visible for testing. `buffer_assignment` is optional for testing (assigns + // a different buffer to each tensor). + absl::StatusOr> CreateMLIRModule( + mlir::MLIRContext& context, const HloFusionInstruction& fusion, + const std::string& entry_function_name, + const BufferAssignment* buffer_assignment) const; + + protected: + // Emits MLIR for the given fusion. The entry function has one tensor argument + // per fusion parameter and output and one tensor result per fusion output. + // The fuson outputs may only be used with `tensor.insert` ops.a + virtual absl::Status EmitMlir(mlir::ModuleOp module, + mlir::func::FuncOp entry_function, + const HloFusionInstruction& fusion) const = 0; + + // Emit a loop nest for the symbols in the output map. The output map should + // have the dimensions specified in KernelFusionInterface. Loops are nested + // with the symbol 0 as the outermost loop. `output_indices` are the final + // output indices, not just the indices of the symbols. The return value of + // the function is the updated output tensors. + absl::StatusOr> EmitLoopNest( + mlir::ImplicitLocOpBuilder& b, mlir::ValueRange output_tensors, + const IndexingMap& thread_to_output_map, + const std::function>( + mlir::ValueRange output_tensors, mlir::ValueRange output_indices)>& + create_body) const; + + mlir::Value EmitBlockId(mlir::ImplicitLocOpBuilder& builder, int dim) const; + mlir::Value EmitThreadId(mlir::ImplicitLocOpBuilder& builder, int dim) const; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_MLIR_FUSION_EMITTER_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc new file mode 100644 index 00000000000000..e9ca8a7343bb8d --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc @@ -0,0 +1,180 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" + +#include +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" // from @llvm-project +#include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project +#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Interfaces/DataLayoutInterfaces.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" // from @llvm-project +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" // from @llvm-project +#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" // from @llvm-project +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/tests/filecheck.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +class DummyCopyFusionEmitter : public MlirFusionEmitterBase { + public: + LaunchDimensions launch_dimensions() const final { return {1, 100}; } + + std::optional ComputeThreadIdToOutputIndexing( + int64_t, mlir::MLIRContext*) const final { + return std::nullopt; + } + + std::optional ComputeThreadIdToInputIndexing( + int64_t, int64_t, mlir::MLIRContext*) const final { + return std::nullopt; + } + + protected: + absl::Status EmitMlir(mlir::ModuleOp module, + mlir::func::FuncOp entry_function, + const HloFusionInstruction& fusion) const final { + mlir::ImplicitLocOpBuilder b(module->getLoc(), entry_function); + b.setInsertionPointToStart(entry_function.addEntryBlock()); + auto thread_id = EmitThreadId(b, 0); + auto value = b.create( + entry_function.getArgument(0), mlir::ValueRange{thread_id}); + auto result = b.create( + value, entry_function.getArgument(1), mlir::ValueRange{thread_id}); + b.create(result->getResults()); + return absl::OkStatus(); + } +}; + +class MlirFusionEmitterTest : public HloTestBase { + protected: + MlirFusionEmitterTest() { + context_.loadDialect(); + mlir::DialectRegistry registry; + mlir::registerBuiltinDialectTranslation(registry); + mlir::registerLLVMDialectTranslation(registry); + mlir::registerNVVMDialectTranslation(registry); + context_.appendDialectRegistry(registry); + } + + mlir::MLIRContext context_; + stream_executor::DeviceDescription device_info_ = + TestGpuDeviceInfo::RTXA6000DeviceInfo(); +}; + +constexpr absl::string_view kModule = R"( + fused_computation { + ROOT %p0 = f32[100] parameter(0) + } + + ENTRY main { + %p0 = f32[100] parameter(0) + ROOT fusion = f32[100] fusion(%p0), kind=kLoop, calls=fused_computation + })"; + +TEST_F(MlirFusionEmitterTest, CreateMlirModule) { + auto module = ParseAndReturnVerifiedModule(kModule).value(); + DummyCopyFusionEmitter emitter; + TF_ASSERT_OK_AND_ASSIGN( + auto mlir_module, + emitter.CreateMLIRModule( + context_, + *Cast( + module->entry_computation()->root_instruction()), + "fusion", + /*buffer_assignment=*/nullptr)); + + std::string out; + llvm::raw_string_ostream stream(out); + stream << *mlir_module; + + TF_ASSERT_OK_AND_ASSIGN(auto filecheck_result, RunFileCheck(out, R"( + // CHECK: func.func @fusion( + // CHECK-SAME: %[[IN:.*]]: tensor<100xf32> {xla.slice_index = 0 + // CHECK-SAME: %[[OUT:.*]]: tensor<100xf32> {xla.slice_index = 1 + // CHECK: %[[TID:.*]] = gpu.thread_id x + // CHECK: %[[VAL:.*]] = tensor.extract %[[IN]][%[[TID]]] + // CHECK: %[[RET:.*]] = tensor.insert %[[VAL]] + // CHECK-SAME: into %[[OUT]][%[[TID]]] + // CHECK: return %[[RET]] + )")); + EXPECT_TRUE(filecheck_result); +} + +TEST_F(MlirFusionEmitterTest, CreateLLVMModule) { + llvm::LLVMContext llvm_context; + + auto module = ParseAndReturnVerifiedModule(kModule).value(); + DummyCopyFusionEmitter emitter; + TF_ASSERT_OK_AND_ASSIGN( + auto llvm_module, + emitter.CreateLLVMModule( + context_, llvm_context, device_info_, + *Cast( + module->entry_computation()->root_instruction()), + "fusion", + /*buffer_assignment=*/nullptr)); + + std::string out; + llvm::raw_string_ostream stream(out); + stream << *llvm_module; + + TF_ASSERT_OK_AND_ASSIGN(auto filecheck_result, RunFileCheck(out, R"( + // CHECK: define void @fusion(ptr noalias %[[IN:.*]], ptr noalias %[[OUT:.*]]) + // CHECK: %[[TID:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x() + // CHECK: %[[EXT:.*]] = sext i32 %[[TID]] to i64 + // CHECK: %[[TRUNC:.*]] = trunc i64 %[[EXT]] to i32 + // CHECK: %[[IN_PTR:.*]] = getelementptr inbounds float, ptr %[[IN]], i32 %[[TRUNC]] + // CHECK: %[[VAL:.*]] = load float, ptr %[[IN_PTR]], align 4 + // CHECK: %[[OUT_PTR:.*]] = getelementptr inbounds float, ptr %[[OUT]], i32 %[[TRUNC]] + // CHECK: store float %[[VAL]], ptr %[[OUT_PTR]], align 4 + // CHECK: ret void + )")); + EXPECT_TRUE(filecheck_result); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/passes.h b/third_party/xla/xla/service/gpu/fusions/mlir/passes.h new file mode 100644 index 00000000000000..12f582a9b013d1 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/mlir/passes.h @@ -0,0 +1,44 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_PASSES_H_ +#define XLA_SERVICE_GPU_FUSIONS_MLIR_PASSES_H_ + +#include + +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace xla { +namespace gpu { + +#define GEN_PASS_DECL +#include "xla/service/gpu/fusions/mlir/passes.h.inc" + +std::unique_ptr CreateLowerTensorsPass(); +std::unique_ptr CreateLowerToLLVMPass(); +std::unique_ptr CreatePropagateSliceIndicesPass(); +std::unique_ptr CreateMergePointersToSameSlicePass(); +std::unique_ptr CreateSimplifyAffinePass(); +std::unique_ptr CreateExpandFloatConversionsPass(bool enable_bf16); + +#define GEN_PASS_REGISTRATION +#include "xla/service/gpu/fusions/mlir/passes.h.inc" + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_PASSES_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/passes.td b/third_party/xla/xla/service/gpu/fusions/mlir/passes.td new file mode 100644 index 00000000000000..39501b940626ce --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/mlir/passes.td @@ -0,0 +1,124 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_PASSES_TD_ +#define XLA_SERVICE_GPU_FUSIONS_MLIR_PASSES_TD_ + +include "mlir/Pass/PassBase.td" + +def PropagateSliceIndicesPass : + Pass<"xla-gpu-propagate-slice-indices", "mlir::ModuleOp"> { + let summary = "Propagates slice indices from the entry function to all callees."; + + let description = [{ + Propagates xla.slice_index attributes from the function with the xla.entry + attribute to all other functions. + }]; + + let dependentDialects = [ + "mlir::func::FuncDialect" + ]; + + let constructor = "CreatePropagateSliceIndicesPass()"; +} + +def LowerTensorsPass : + Pass<"xla-gpu-lower-tensors", "mlir::ModuleOp"> { + let summary = "Lowers tensors to llvm pointers and loads/stores."; + + let description = [{ + Lowers tensors to LLVM. We cannot use the memref lowerings because they + are not compatible with XLA's ABI. + }]; + + let dependentDialects = [ + "mlir::func::FuncDialect", "mlir::LLVM::LLVMDialect", + "mlir::tensor::TensorDialect", + ]; + + let constructor = "CreateLowerTensorsPass()"; +} + +def MergePointersToSameSlicePass : + Pass<"xla-gpu-merge-pointers", "mlir::ModuleOp"> { + let summary = "Merges pointers that share slices."; + + let description = [{ + When a function has multiple pointer arguments with the same slice index, + merges them. + }]; + + let dependentDialects = [ + "mlir::func::FuncDialect" + ]; + + let constructor = "CreateMergePointersToSameSlicePass()"; +} + +def SimplifyAffinePass : Pass<"xla-gpu-simplify-affine", "mlir::ModuleOp"> { + let summary = "Simplifies affine.apply using XLA's range-aware simplifier."; + + let description = [{ + The standard affine canonicalizer cannot simplify all expressions, since + it is unaware of range information. This pass uses `xla.range` attributes + on arguments and ops for simplification. It also lowers floordiv and mod + to simpler expressions than lower-affine. This pass only works for + expressions for which we can prove the LHS of mod and div is nonnegative. + }]; + + let dependentDialects = [ + "mlir::affine::AffineDialect", "mlir::func::FuncDialect", + "mlir::scf::SCFDialect", + ]; + + let constructor = "CreateSimplifyAffinePass()"; +} + +def ExpandFloatConversionsPass : Pass<"xla-gpu-expand-conversions", "mlir::ModuleOp"> { + let summary = "Expands float conversions that are not natively supported."; + + let description = [{ + Not all float conversions are natively supported, so the ones that aren't + need to be emulated with bitwise operations. + + Currently, this pass only implements bf16 conversions. + }]; + + let dependentDialects = [ + "mlir::arith::ArithDialect", "mlir::mhlo::MhloDialect" + ]; + + let options = [ + Option<"include_bf16_", "include-bf16", "bool", /*default=*/"false", + "Enable the BF16 <-> F32 expansion patterns.">, + ]; +} + +def LowerToLLVMPass : + Pass<"xla-gpu-lower-to-llvm", "mlir::ModuleOp"> { + let summary = "Lowers to LLVM."; + + let description = [{ + Lowers the rest to LLVM + }]; + + let dependentDialects = [ + "mlir::func::FuncDialect", "mlir::LLVM::LLVMDialect" + ]; + + let constructor = "CreateLowerToLLVMPass()"; +} + +#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_PASSES_TD_ diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/propagate_slice_indices.cc b/third_party/xla/xla/service/gpu/fusions/mlir/propagate_slice_indices.cc new file mode 100644 index 00000000000000..bc4e74d914a99c --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/mlir/propagate_slice_indices.cc @@ -0,0 +1,80 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "xla/service/gpu/fusions/mlir/passes.h" + +namespace xla { +namespace gpu { + +#define GEN_PASS_DEF_PROPAGATESLICEINDICESPASS +#include "xla/service/gpu/fusions/mlir/passes.h.inc" + +namespace { + +class PropagateSliceIndicesPass + : public impl::PropagateSliceIndicesPassBase { + public: + void runOnOperation() override; +}; + +void PropagateSliceIndicesPass::runOnOperation() { + mlir::func::FuncOp entry; + for (auto func : getOperation().getOps()) { + if (func->getAttr("xla.entry")) { + entry = func; + break; + } + } + + if (!entry) { + getOperation()->emitOpError("No entry function found."); + signalPassFailure(); + return; + } + + for (auto func : getOperation().getOps()) { + if (func.getNumArguments() == 0 || func == entry) { + continue; + } + + for (int i = 0; i < func.getNumArguments(); ++i) { + if (mlir::isa(func.getArgument(i).getType())) { + if (auto index = entry.getArgAttr(i, "xla.slice_index")) { + func.setArgAttr(i, "xla.slice_index", index); + } + if (auto invariant = entry.getArgAttr(i, "xla.invariant")) { + func.setArgAttr(i, "xla.invariant", invariant); + } + } else { + break; + } + } + } +} + +} // namespace + +std::unique_ptr CreatePropagateSliceIndicesPass() { + return std::make_unique(); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/simplify_affine.cc b/third_party/xla/xla/service/gpu/fusions/mlir/simplify_affine.cc new file mode 100644 index 00000000000000..c1698ec79e8590 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/mlir/simplify_affine.cc @@ -0,0 +1,222 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/optimization.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/Dialect/Affine/LoopUtils.h" // from @llvm-project +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "xla/service/gpu/fusions/mlir/passes.h" +#include "xla/service/gpu/model/indexing_map.h" + +namespace xla { +namespace gpu { + +#define GEN_PASS_DEF_SIMPLIFYAFFINEPASS +#include "xla/service/gpu/fusions/mlir/passes.h.inc" + +namespace { + +class SimplifyAffinePass + : public impl::SimplifyAffinePassBase { + public: + void runOnOperation() override; +}; + +std::optional GetRange(mlir::Value value) { + auto attr_to_range = [](mlir::Attribute attr) -> std::optional { + if (!attr) { + return std::nullopt; + } + auto values = llvm::to_vector( + attr.cast().getAsValueRange()); + return {{values[0].getSExtValue(), values[1].getSExtValue()}}; + }; + + if (value.getDefiningOp()) { + return attr_to_range(value.getDefiningOp()->getAttr("xla.range")); + } + + auto bbarg = value.dyn_cast(); + if (!bbarg) { + return std::nullopt; + } + + auto parent = bbarg.getParentBlock()->getParentOp(); + if (auto func_op = mlir::dyn_cast(parent)) { + return attr_to_range(func_op.getArgAttr(bbarg.getArgNumber(), "xla.range")); + } + + if (auto for_op = mlir::dyn_cast(parent)) { + llvm::APInt lb, ub; + if (mlir::matchPattern(for_op.getLowerBound(), mlir::m_ConstantInt(&lb)) && + mlir::matchPattern(for_op.getUpperBound(), mlir::m_ConstantInt(&ub))) { + return {{lb.getSExtValue(), ub.getSExtValue() - 1}}; + } + } + + return std::nullopt; +} + +struct RewriteAffineApply + : mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::affine::AffineApplyOp op, + mlir::PatternRewriter& rewriter) const override { + auto affine_map = op.getAffineMap(); + std::vector dim_ranges(affine_map.getNumDims()); + std::vector symbol_ranges(affine_map.getNumSymbols()); + + for (int i = 0; i < affine_map.getNumInputs(); ++i) { + if (auto range = GetRange(op->getOperand(i))) { + if (i >= dim_ranges.size()) { + symbol_ranges[i - dim_ranges.size()] = *range; + } else { + dim_ranges[i] = *range; + } + } else { + return rewriter.notifyMatchFailure(op, "failed to deduce range"); + } + } + + IndexingMap map(op.getAffineMap(), dim_ranges, symbol_ranges); + map.Simplify(); + auto expr = map.GetAffineMap().getResult(0); + + RangeEvaluator range_evaluator(dim_ranges, symbol_ranges, op->getContext()); + std::function can_be_lowered; + bool fits_32_bits = true; + can_be_lowered = [&](mlir::AffineExpr expr) { + auto range = range_evaluator.ComputeExpressionRange(expr); + fits_32_bits &= range.upper_bound < std::numeric_limits::max(); + + auto bin_op = llvm::dyn_cast(expr); + if (!bin_op) { + return true; + } + + // Mod and div can be lowered if their LHS is >= 0 and their RHS is a + // constant. + if (expr.getKind() == mlir::AffineExprKind::Mod || + expr.getKind() == mlir::AffineExprKind::FloorDiv) { + if (!range_evaluator.IsAlwaysPositiveOrZero(bin_op.getLHS()) || + !range_evaluator.ComputeExpressionRange(bin_op.getRHS()) + .IsPoint()) { + return false; + } + } + if (expr.getKind() == mlir::AffineExprKind::CeilDiv) { + return false; + } + + return can_be_lowered(bin_op.getLHS()) && can_be_lowered(bin_op.getRHS()); + }; + + if (!can_be_lowered(expr)) { + return rewriter.notifyMatchFailure(op, + "unable to lower the affine apply"); + } + + std::function lower; + + mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); + auto int_ty = fits_32_bits ? b.getI32Type() : b.getI64Type(); + b.setInsertionPoint(op); + lower = [&](mlir::AffineExpr expr) -> mlir::Value { + if (auto bin_op = mlir::dyn_cast(expr)) { + auto lhs = lower(bin_op.getLHS()); + auto rhs = lower(bin_op.getRHS()); + switch (expr.getKind()) { + case mlir::AffineExprKind::Add: + return b.create(lhs, rhs); + case mlir::AffineExprKind::Mul: + return b.create(lhs, rhs); + case mlir::AffineExprKind::Mod: + return b.create(lhs, rhs); + case mlir::AffineExprKind::FloorDiv: + return b.create(lhs, rhs); + default: + ABSL_UNREACHABLE(); + } + } + + switch (expr.getKind()) { + case mlir::AffineExprKind::Constant: + return b.create( + mlir::cast(expr).getValue(), int_ty); + case mlir::AffineExprKind::DimId: + return b.create( + int_ty, op.getDimOperands()[mlir::cast(expr) + .getPosition()]); + case mlir::AffineExprKind::SymbolId: + return b.create( + int_ty, + op.getSymbolOperands()[mlir::cast(expr) + .getPosition()]); + default: + ABSL_UNREACHABLE(); + } + }; + + auto result = lower(map.GetAffineMap().getResult(0)); + rewriter.replaceOp( + op, b.create(b.getIndexType(), result)); + return mlir::success(); + } +}; + +void SimplifyAffinePass::runOnOperation() { + mlir::RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + mlir::GreedyRewriteConfig config; + // There's no point simplifying more than once. + config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps; + if (mlir::failed(mlir::applyPatternsAndFoldGreedily( + getOperation(), std::move(patterns), config))) { + signalPassFailure(); + } +} + +} // namespace + +std::unique_ptr CreateSimplifyAffinePass() { + return std::make_unique(); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/BUILD b/third_party/xla/xla/service/gpu/fusions/mlir/tests/BUILD new file mode 100644 index 00000000000000..b3048193cacc66 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/mlir/tests/BUILD @@ -0,0 +1,38 @@ +load("//xla:lit.bzl", "lit_test_suite") +load("//xla:xla.bzl", "xla_cc_binary") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) + +xla_cc_binary( + name = "mlir_fusions_opt", + srcs = ["mlir_fusions_opt.cc"], + deps = [ + "//xla/mlir_hlo", + "//xla/service/gpu/fusions/mlir:passes", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncExtensions", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + ], +) + +lit_test_suite( + name = "tests", + srcs = glob(["*.mlir"]), + cfg = "//xla:lit.cfg.py", + tools = [ + ":mlir_fusions_opt", + "@llvm-project//llvm:FileCheck", + ], +) diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/expand_float_conversions.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/expand_float_conversions.mlir new file mode 100644 index 00000000000000..19d739cfe1b247 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/mlir/tests/expand_float_conversions.mlir @@ -0,0 +1,53 @@ +// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-expand-conversions="include-bf16=true" -canonicalize | FileCheck %s -check-prefixes=CHECK,CHECK-BF16 -dump-input=always +// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-expand-conversions="include-bf16=false" -canonicalize | FileCheck %s -check-prefixes=CHECK,CHECK-NO-BF16 -dump-input=always + +module { + func.func @f64_to_bf16(%arg0: f64) -> bf16 { + %ret = arith.truncf %arg0 : f64 to bf16 + return %ret : bf16 + } +} + +// CHECK-LABEL: f64_to_bf16 +// CHECK-SAME: (%[[ARG:.*]]: f64) +// CHECK-BF16: arith.truncf %[[ARG]] : f64 to f32 +// CHECK-BF16-NOT: arith.truncf + +// CHECK-NO-BF16: %[[F32:.*]] = arith.truncf %[[ARG]] : f64 to f32 +// CHECK-NO-BF16: arith.truncf %[[F32]] : f32 to bf16 + + +module { + func.func @bf16_to_f64(%arg0: bf16) -> f64 { + %ret = arith.extf %arg0 : bf16 to f64 + return %ret : f64 + } +} + +// CHECK-LABEL: bf16_to_f64 +// CHECK: bitcast {{.*}} : i32 to f32 +// CHECK: arith.extf {{.*}} : f32 to f64 + +// ----- + +module { + func.func @bf16_to_int(%arg0: bf16) -> i32 { + %ret = arith.fptosi %arg0 : bf16 to i32 + return %ret : i32 + } +} + +// CHECK-LABEL: bf16_to_int +// CHECK: arith.fptosi {{.*}} : f32 to i32 + +// ----- + +module { + func.func @int_to_bf16(%arg0: i16) -> bf16 { + %ret = arith.sitofp %arg0 : i16 to bf16 + return %ret : bf16 + } +} + +// CHECK-LABEL: int_to_bf16 +// CHECK: arith.sitofp {{.*}} : i16 to f32 \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir new file mode 100644 index 00000000000000..d3437a0bec0290 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir @@ -0,0 +1,134 @@ +// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-lower-tensors | FileCheck %s + +module { + func.func @add(%arg0: f32, %arg1: f32) -> f32 { + %sum = arith.addf %arg0, %arg1 : f32 + func.return %sum : f32 + } + + func.func @tensorarg(%arg0: tensor<43xf32> {xla.invariant, xla.slice_index = 0}, %arg1: index) -> f32 { + %v1 = arith.constant 2.0 : f32 + %v2 = tensor.extract %arg0[%arg1] : tensor<43xf32> + %sum = func.call @add(%v1, %v2) : (f32, f32) -> f32 + func.return %sum : f32 + } + + func.func @tensorcall(%arg0: tensor<43xf32> {xla.slice_index = 0}, %arg1: index) -> f32 { + %call = func.call @tensorarg(%arg0, %arg1) : (tensor<43xf32>, index) -> f32 + func.return %call : f32 + } + + func.func @stores(%arg0: tensor<17xf32> {xla.slice_index = 0}, %arg1: tensor<43xf32> {xla.slice_index = 1}) -> tensor<43xf32> { + %c17 = arith.constant 17 : index + %c23 = arith.constant 23 : index + %cst = arith.constant 3.0 : f32 + %out = tensor.insert %cst into %arg1[%c17] : tensor<43xf32> + %out2 = tensor.insert %cst into %out[%c23] : tensor<43xf32> + func.return %out2 : tensor<43xf32> + } +} + +// CHECK: func.func @add(%{{.*}}: f32, %{{.*}}: f32) -> f32 { +// CHECK-NEXT: arith.addf +// CHECK-NEXT: return + +// CHECK: func.func @tensorarg(%[[ARG0:.*]]: !llvm.ptr +// CHECK-SAME: {xla.invariant, xla.slice_index = 0 : i64}, %[[ARG1:.*]]: index) -> f32 { +// CHECK-DAG: %[[C2:.*]] = arith.constant 2.000000e+00 +// CHECK-DAG: %[[IDX:.*]] = arith.index_castui %[[ARG1]] : index to i32 +// CHECK-DAG: %[[PTR:.*]] = llvm.getelementptr inbounds %[[ARG0]][%[[IDX]]] +// CHECK-DAG: %[[V2:.*]] = llvm.load %[[PTR]] invariant +// CHECK: %[[RET:.*]] = call @add(%[[C2]], %[[V2]]) +// CHECK: return %[[RET]] + +// CHECK: func.func @tensorcall(%[[ARG0:.*]]: !llvm.ptr +// CHECK-SAME: {xla.slice_index = 0 : i64}, %[[ARG1:.*]]: index) +// CHECK: %[[RET:.*]] = call @tensorarg(%[[ARG0]], %[[ARG1]]) +// CHECK: return %[[RET]] + +// CHECK: func.func @stores( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr {xla.slice_index = 0 : i64}, +// CHECK-SAME: %[[ARG1:.*]]: !llvm.ptr {xla.slice_index = 1 : i64}) +// CHECK-NEXT: %[[CST:.*]] = arith.constant 3.000000e+00 : f32 +// CHECK-NEXT: %[[PTR1:.*]] = llvm.getelementptr inbounds %[[ARG1]][17] +// CHECK-NEXT: llvm.store %[[CST]], %[[PTR1]] +// CHECK-NEXT: %[[PTR2:.*]] = llvm.getelementptr inbounds %[[ARG1]][23] +// CHECK-NEXT: llvm.store %[[CST]], %[[PTR2]] +// CHECK-NEXT: return + +// ----- + +module { + func.func @layout( + %arg0: tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>>, + %arg1: index, %arg2: index) -> f32 { + %v = tensor.extract %arg0[%arg1, %arg2] + : tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>> + func.return %v : f32 + } +} + +// CHECK: #[[MAP:.*]] = affine_map<(d0, d1) -> (d0 + d1 * 2)> +// CHECK: @layout(%[[ARG0:.*]]: !llvm.ptr, +// CHECK-SAME: %[[X:.*]]: index, %[[Y:.*]]: index +// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]](%[[X]], %[[Y]]) +// CHECK: %[[IDX_CAST:.*]] = arith.index_castui %[[IDX]] : index to i32 +// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %[[ARG0]][%[[IDX_CAST]]] +// CHECK: llvm.load %[[PTR]] + +// ----- + +module { + func.func @store_control_flow( + %arg0: tensor<2xf32>, + %arg1: index + ) -> tensor<2xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %cst = arith.constant 0.0 : f32 + %cst2 = arith.constant 1.0 : f32 + + %for = scf.for %i = %c0 to %c2 step %c1 iter_args(%arg2 = %arg0) -> tensor<2xf32> { + %new_out = tensor.insert %cst into %arg2[%i] : tensor<2xf32> + scf.yield %new_out : tensor<2xf32> + } + + %inbounds = arith.cmpi sle, %arg1, %c1 : index + %result = scf.if %inbounds -> tensor<2xf32> { + %if = tensor.insert %cst2 into %for[%arg1] : tensor<2xf32> + scf.yield %if : tensor<2xf32> + } else { + scf.yield %for : tensor<2xf32> + } + func.return %result : tensor<2xf32> + } +} + +// CHECK: @store_control_flow(%[[ARG0:.*]]: !llvm.ptr, %[[X:.*]]: index) { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C2]] step %[[C1]] { +// CHECK: %[[CAST:.*]] = arith.index_castui %[[I]] : index to i32 +// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %[[ARG0]][%[[CAST]]] +// CHECK: llvm.store {{.*}}, %[[PTR]] +// CHECK: %[[INBOUNDS:.*]] = arith.cmpi +// CHECK: scf.if %[[INBOUNDS]] { +// CHECK: llvm.store +// CHECK-NEXT: } +// CHECK-NEXT: return + +// ----- + +module { + func.func @large_tensor( + %arg0: tensor<1024x1024x1024x6xf32>, + %arg1: index) -> f32 { + %v = tensor.extract %arg0[%arg1, %arg1, %arg1, %arg1] : tensor<1024x1024x1024x6xf32> + func.return %v : f32 + } +} + +// CHECK: @large_tensor +// CHECK: arith.index_castui {{.*}} : index to i64 \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/merge_pointers_to_same_slice.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/merge_pointers_to_same_slice.mlir new file mode 100644 index 00000000000000..89c1d0c320cc85 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/mlir/tests/merge_pointers_to_same_slice.mlir @@ -0,0 +1,40 @@ +// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-lower-tensors -xla-gpu-merge-pointers | FileCheck %s + +module { + func.func @tensorargs(%arg0: tensor<43xf32> {xla.slice_index = 0}, + %arg1: tensor<43xf32> {xla.slice_index = 1, xla.invariant}, + %arg2: tensor<43xf32> {xla.slice_index = 0}, + %arg3: index) -> f32 { + %v0 = tensor.extract %arg0[%arg3] : tensor<43xf32> + %v1 = tensor.extract %arg1[%arg3] : tensor<43xf32> + %v2 = tensor.extract %arg2[%arg3] : tensor<43xf32> + %sum = arith.addf %v0, %v1 : f32 + %sum2 = arith.addf %sum, %v2 : f32 + func.return %sum2 : f32 + } + + func.func @tensorcall(%arg0: tensor<43xf32> {xla.slice_index = 0}, + %arg1: tensor<43xf32> {xla.slice_index = 1, xla.invariant}, + %arg2: tensor<43xf32> {xla.slice_index = 0}, + %arg3: index) -> f32 { + %call = func.call @tensorargs(%arg0, %arg1, %arg2, %arg3) : + (tensor<43xf32>, tensor<43xf32>, tensor<43xf32>, index) -> f32 + func.return %call : f32 + } +} + +// CHECK: func.func @tensorargs( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr {llvm.noalias}, +// CHECK-SAME: %[[ARG1:.*]]: !llvm.ptr {llvm.noalias, xla.invariant}, +// CHECK-SAME: %[[ARG2:.*]]: index) -> f32 { +// CHECK: %[[GEP0:.*]] = llvm.getelementptr inbounds %[[ARG0]] +// CHECK: llvm.load %[[GEP0]] : !llvm.ptr +// CHECK: %[[GEP1:.*]] = llvm.getelementptr inbounds %[[ARG1]] +// CHECK: llvm.load %[[GEP1]] invariant : !llvm.ptr +// CHECK: %[[GEP2:.*]] = llvm.getelementptr inbounds %[[ARG0]] + +// CHECK: func.func @tensorcall +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr {llvm.noalias}, +// CHECK-SAME: %[[ARG1:.*]]: !llvm.ptr {llvm.noalias, xla.invariant}, +// CHECK-SAME: %[[ARG2:.*]]: index) -> f32 { +// CHECK: call @tensorargs(%[[ARG0]], %[[ARG1]], %[[ARG2]]) diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/mlir_fusions_opt.cc b/third_party/xla/xla/service/gpu/fusions/mlir/tests/mlir_fusions_opt.cc new file mode 100644 index 00000000000000..4f0f123e2f9f55 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/mlir/tests/mlir_fusions_opt.cc @@ -0,0 +1,44 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project +#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/service/gpu/fusions/mlir/passes.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + registry.insert(); + mlir::func::registerAllExtensions(registry); + mlir::registerCanonicalizerPass(); + xla::gpu::registerGpuFusionTransformsPasses(); + + return mlir::failed( + MlirOptMain(argc, argv, "XLA MLIR Fusion Pass Driver\n", registry)); +} diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/propagate_slice_indices.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/propagate_slice_indices.mlir new file mode 100644 index 00000000000000..fa8d1623d7b5ff --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/mlir/tests/propagate_slice_indices.mlir @@ -0,0 +1,36 @@ +// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-propagate-slice-indices | FileCheck %s + +module { + func.func @add(%arg0: f32, %arg1: f32) -> f32 { + %sum = arith.addf %arg0, %arg1 : f32 + func.return %sum : f32 + } + + func.func @tensorarg(%arg0: tensor<43xf32>, %arg1: index) -> f32 { + %v1 = arith.constant 2.0 : f32 + %v2 = tensor.extract %arg0[%arg1] : tensor<43xf32> + %sum = func.call @add(%v1, %v2) : (f32, f32) -> f32 + func.return %sum : f32 + } + + func.func @tensorcall(%arg0: tensor<43xf32>, %arg1: index) -> f32 { + %call = func.call @tensorarg(%arg0, %arg1) : (tensor<43xf32>, index) -> f32 + func.return %call : f32 + } + + func.func @stores(%arg0: tensor<17xf32> {xla.invariant, xla.slice_index = 0}, + %arg1: tensor<43xf32> {xla.slice_index = 1}) -> tensor<43xf32> + attributes { xla.entry } { + %c17 = arith.constant 17 : index + %c23 = arith.constant 23 : index + %cst = arith.constant 3.0 : f32 + %out = tensor.insert %cst into %arg1[%c17] : tensor<43xf32> + %out2 = tensor.insert %cst into %out[%c23] : tensor<43xf32> + func.return %out2 : tensor<43xf32> + } +} + +// CHECK-DAG: @add(%{{.*}}: f32, %{{.*}}: f32) +// CHECK-DAG: @tensorarg(%{{.*}}: tensor<43xf32> {xla.invariant, xla.slice_index = 0 : i64}, %{{.*}}: index) +// CHECK-DAG: @tensorcall(%{{.*}}: tensor<43xf32> {xla.invariant, xla.slice_index = 0 : i64}, %{{.*}}: index) +// CHECK-DAG: @stores(%{{.*}}: tensor<17xf32> {xla.invariant, xla.slice_index = 0 : i64}, %{{.*}}: tensor<43xf32> {xla.slice_index = 1 : i64}) diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir new file mode 100644 index 00000000000000..11bbd9fd567132 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir @@ -0,0 +1,82 @@ +// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-simplify-affine | FileCheck %s + +module { + func.func @op_and_for_ranges(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %0 = gpu.thread_id x {xla.range = [0 : index, 127 : index]} + %1 = gpu.block_id x {xla.range = [0 : index, 3071 : index]} + scf.for %arg3 = %c0 to %c4 step %c1 { + %2 = affine.apply affine_map<()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 - ((s1 * 4 + s2) floordiv 256) * 256 + (s1 floordiv 64) * 256 - ((s0 * 2 + s1 floordiv 64) floordiv 3) * 768 + ((s0 * 128 + s1) floordiv 192) * 768 - (((s0 * 128 + s1) floordiv 192) floordiv 1024) * 786432 + (s0 floordiv 1536) * 786432)>()[%1, %0, %arg3] + %3 = arith.index_castui %2 : index to i64 + %4 = llvm.getelementptr %arg0[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32 + %5 = llvm.load %4 invariant : !llvm.ptr -> f32 + %8 = llvm.getelementptr %arg1[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32 + %9 = llvm.load %8 invariant : !llvm.ptr -> f32 + %10 = arith.cmpf oge, %5, %9 : f32 + %11 = llvm.getelementptr %arg2[%3] : (!llvm.ptr, i64) -> !llvm.ptr, i1 + llvm.store %10, %11 : i1, !llvm.ptr + } + return + } +} + +// CHECK: @op_and_for_ranges +// CHECK-DAG: %[[C512:.*]] = arith.constant 512 +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 +// CHECK-DAG: %[[TID_X:.*]] = gpu.thread_id x +// CHECK-DAG: %[[BID_X:.*]] = gpu.block_id x +// CHECK: scf.for %[[I:.*]] = +// CHECK: %[[BID_32:.*]] = arith.index_castui %[[BID_X]] : index to i32 +// CHECK: %[[BLOCK_OFFSET:.*]] = arith.muli %[[BID_32]], %[[C512]] +// CHECK: %[[TID_32:.*]] = arith.index_castui %[[TID_X]] : index to i32 +// CHECK: %[[THREAD_OFFSET:.*]] = arith.muli %[[TID_32]], %[[C4]] +// CHECK: %[[OFFSET:.*]] = arith.addi %[[BLOCK_OFFSET]], %[[THREAD_OFFSET]] +// CHECK: %[[I_32:.*]] = arith.index_castui %[[I]] : index to i32 +// CHECK: %[[UNROLL_OFFSET:.*]] = arith.addi %[[OFFSET]], %[[I_32]] +// CHECK: %[[UNROLL_INDEX:.*]] = arith.index_castui %[[UNROLL_OFFSET]] : i32 to index +// CHECK: arith.index_castui %[[UNROLL_INDEX]] : index to i64 + +// ----- + +module { + func.func @arg_ranges(%arg0: index {xla.range = [0 : index, 42 : index]}, %arg1: index {xla.range = [0 : index, 1000 : index]}) -> index { + %0 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100)>()[%arg0, %arg1] + return %0 : index + } +} + +// CHECK: @arg_ranges +// CHECK-NEXT: %[[C100:.*]] = arith.constant 100 +// CHECK-NEXT: %[[ARG0_32:.*]] = arith.index_castui {{.*}} : index to i32 +// CHECK-NEXT: %[[RET_32:.*]] = arith.divui %[[ARG0_32]], %[[C100]] +// CHECK-NEXT: %[[RET:.*]] = arith.index_castui %[[RET_32]] : i32 to index +// CHECK-NEXT: return %[[RET]] + + +// ----- + +module { + func.func @needs_i64(%arg0: index {xla.range = [0 : index, 1000000000000 : index]}, %arg1: index {xla.range = [0 : index, 10 : index]}) -> index { + %0 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%arg0, %arg1] + return %0 : index + } +} + +// CHECK: @needs_i64 +// CHECK: arith.index_castui {{.*}} : index to i64 +// CHECK: arith.index_castui {{.*}} : index to i64 +// CHECK: arith.index_castui {{.*}} : i64 to index + +// ----- + +module { + func.func @cant_lower(%arg0: index {xla.range = [-10 : index, 42 : index]}, %arg1: index {xla.range = [0 : index, 1000 : index]}) -> index { + %0 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100)>()[%arg0, %arg1] + return %0 : index + } +} + +// CHECK: @cant_lower +// CHECK: affine.apply diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/type_util.cc b/third_party/xla/xla/service/gpu/fusions/mlir/type_util.cc new file mode 100644 index 00000000000000..9de53000fd8394 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/mlir/type_util.cc @@ -0,0 +1,63 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/mlir/type_util.h" + +#include "absl/log/check.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "xla/layout_util.h" +#include "xla/shape.h" +#include "xla/translate/hlo_to_mhlo/hlo_utils.h" + +namespace xla { +namespace gpu { +namespace mlir_converter { + +mlir::Type TensorShapeToMlirType(const Shape& shape, mlir::OpBuilder& b) { + CHECK(shape.IsArray()); + + // Default layouts create a lot of clutter in the IR, so only add an + // encoding when needed. + mlir::Attribute layout = {}; + if (!LayoutUtil::IsMonotonicWithDim0Major(shape.layout())) { + layout = CreateDenseIntElementsAttrFromVector( + llvm::to_vector(shape.layout().minor_to_major()), b); + } + return mlir::RankedTensorType::get( + llvm::to_vector(shape.dimensions()), + *ConvertPrimitiveTypeToMLIRType(shape.element_type(), b), layout); +} + +llvm::SmallVector ShapeToMlirTypes(const Shape& shape, + mlir::OpBuilder& b) { + llvm::SmallVector types; + types.reserve(shape.IsTuple() ? shape.tuple_shapes_size() : 1); + if (shape.IsTuple()) { + types.reserve(shape.tuple_shapes_size()); + for (auto& tuple_shape : shape.tuple_shapes()) { + types.push_back(TensorShapeToMlirType(tuple_shape, b)); + } + } else { + types.push_back(TensorShapeToMlirType(shape, b)); + } + return types; +} + +} // namespace mlir_converter +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/type_util.h b/third_party/xla/xla/service/gpu/fusions/mlir/type_util.h new file mode 100644 index 00000000000000..d9d394ffa6c9b8 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/mlir/type_util.h @@ -0,0 +1,41 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_TYPE_UTIL_H_ +#define XLA_SERVICE_GPU_FUSIONS_MLIR_TYPE_UTIL_H_ + +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "xla/shape.h" + +namespace xla { +namespace gpu { +namespace mlir_converter { + +// Converts an XLA tensor to an MLIR ranked tensor. The layout is stored in the +// encoding attribute, if it is not the default layout. `shape` must be an +// array. +mlir::Type TensorShapeToMlirType(const Shape& shape, mlir::OpBuilder& b); + +// If `shape` is a tuple, returns the converted tuple shapes. Otherwise returns +// just the converted shape. Nested tuples are not supported. +llvm::SmallVector ShapeToMlirTypes(const Shape& shape, + mlir::OpBuilder& b); + +} // namespace mlir_converter +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_TYPE_UTIL_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/type_util_test.cc b/third_party/xla/xla/service/gpu/fusions/mlir/type_util_test.cc new file mode 100644 index 00000000000000..95dcdbf161725e --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/mlir/type_util_test.cc @@ -0,0 +1,88 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/mlir/type_util.h" + +#include + +#include +#include +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "xla/shape_util.h" + +namespace xla { +namespace gpu { +namespace mlir_converter { +namespace { + +using ::testing::ElementsAre; + +std::string TypeToString(mlir::Type type) { + std::string out; + llvm::raw_string_ostream stream(out); + stream << type; + return out; +} + +llvm::SmallVector TypesToString( + const llvm::SmallVector& types) { + return llvm::map_to_vector(types, TypeToString); +} + +TEST(TensorShapeTest, ConvertsShape) { + mlir::MLIRContext ctx; + mlir::OpBuilder b(&ctx); + EXPECT_EQ(TypeToString( + TensorShapeToMlirType(ShapeUtil::MakeShape(S32, {4, 5, 6}), b)), + "tensor<4x5x6xi32>"); +} + +TEST(TensorShapeTest, ConvertsLayout) { + mlir::MLIRContext ctx; + mlir::OpBuilder b(&ctx); + EXPECT_EQ( + TypeToString(TensorShapeToMlirType( + ShapeUtil::MakeShapeWithDenseLayout(S32, {4, 5, 6}, {0, 2, 1}), b)), + "tensor<4x5x6xi32, dense<[0, 2, 1]> : tensor<3xi64>>"); +} + +TEST(ShapeTest, ConvertsArray) { + mlir::MLIRContext ctx; + mlir::OpBuilder b(&ctx); + EXPECT_THAT( + TypesToString(ShapeToMlirTypes(ShapeUtil::MakeShape(S32, {4, 5, 6}), b)), + ElementsAre("tensor<4x5x6xi32>")); +} + +TEST(ShapeTest, ConvertsTuple) { + mlir::MLIRContext ctx; + mlir::OpBuilder b(&ctx); + + EXPECT_THAT( + TypesToString(ShapeToMlirTypes( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {4, 5, 6}), + ShapeUtil::MakeShape(F32, {})}), + b)), + ElementsAre("tensor<4x5x6xi32>", "tensor")); +} + +} // namespace +} // namespace mlir_converter +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/reduction.cc b/third_party/xla/xla/service/gpu/fusions/reduction.cc index 2be5a39fecef36..d0700aa0c49dfd 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction.cc @@ -15,7 +15,6 @@ limitations under the License. #include "xla/service/gpu/fusions/reduction.h" #include -#include #include #include #include @@ -34,9 +33,9 @@ limitations under the License. #include "absl/container/node_hash_map.h" #include "absl/log/check.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" @@ -47,6 +46,7 @@ limitations under the License. #include "llvm/IR/Value.h" #include "llvm/Support/AtomicOrdering.h" #include "llvm/Support/Casting.h" +#include "mlir/IR/AffineExpr.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "xla/hlo/ir/hlo_casting_utils.h" @@ -54,6 +54,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_query.h" +#include "xla/layout_util.h" #include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/elemental_ir_emitter.h" @@ -69,9 +70,11 @@ limitations under the License. #include "xla/service/gpu/kernel_arguments.h" #include "xla/service/gpu/kernel_reuse_cache.h" #include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_map.h" #include "xla/service/gpu/parallel_loop_emitter.h" #include "xla/service/gpu/reduction_utils.h" -#include "xla/service/gpu/runtime3/kernel_thunk.h" +#include "xla/service/gpu/runtime/kernel_thunk.h" #include "xla/service/gpu/target_util.h" #include "xla/service/gpu/thunk.h" #include "xla/service/llvm_ir/fused_ir_emitter.h" @@ -96,6 +99,15 @@ namespace xla { namespace gpu { namespace { +// These are the indices that GetReductionKindAndContiguousComponents uses. +constexpr int kRowMajorReducedDimension = 0; +constexpr int kRowKeptDimension = 1; +constexpr int kRowMinorReducedDimension = 2; + +constexpr int kColMajorKeptDimension = 0; +constexpr int kColReducedDimension = 1; +constexpr int kColMinorKeptDimension = 2; + using TypedPointer = std::pair; // Fusion root -> array of indexes, one per reduction output. @@ -111,13 +123,16 @@ int GetNumOutputs(const Shape& shape) { return 1; } +const Shape& OutputShape(const Shape& output_shape, int output_index) { + CHECK(output_index == 0 || output_shape.IsTuple()); + return output_shape.IsTuple() ? output_shape.tuple_shapes(output_index) + : output_shape; +} + llvm::Type* GetIndexType(const HloFusionInstruction& fusion, - const TilingScheme& tiling_scheme, - llvm::IRBuilder<>* builder) { - return GetIndexTypeForKernel(&fusion, - tiling_scheme.GetNumThreadsPerBlockPhysical() * - tiling_scheme.GetNumBlocksPhysical(), - builder); + const Tiling& tiling, llvm::IRBuilder<>* builder) { + return GetIndexTypeForKernel( + &fusion, tiling.GetNumThreadsPerBlock() * tiling.GetNumBlocks(), builder); } // For a row reduction, returns the number of rows we can process in parallel @@ -130,29 +145,15 @@ int RowReductionGetRowsPerWarp(int reduced_dimension_size) { return WarpSize() / reduced_dimension_size; } -int64_t NearestPowerOfTwo(int64_t v) { - if (v < 0) { - return 0; - } - int64_t upper = absl::bit_ceil(v); - int64_t lower = upper >> 1; - return upper - v < v - lower ? upper : lower; -} +} // namespace -// Divides `num_reduces` reduces into groups. Different groups will be executed -// in parallel. Generally speaking, we'd like to run the reduce instructions -// in parallel without incurring too much recomputation overhead. The current -// heuristic is to place reduce instructions who share nothing or only -// (broadcasted) scalars/constants into different groups; otherwise, they are -// placed in the same group. Non-reduce instructions always go with the reduce -// instructions into the same group so long as they share any predecessors. -std::vector> GroupDisjointReductions( +ReductionFusion::IndexGroups ReductionFusion::GroupDisjointReductions( const HloFusionAnalysis& analysis) { const int num_fusion_outputs = analysis.fusion_roots().size(); CHECK_NE(0, num_fusion_outputs); if (num_fusion_outputs == 1) { - return {{analysis.fusion_roots()[0]}}; + return {{{analysis.fusion_roots()[0]}}, {0}, {true}}; } absl::node_hash_map> GroupDisjointReductions( absl::flat_hash_set> reachable_outputs; absl::flat_hash_set roots_with_reduction; - auto roots = analysis.fusion().GetRoots(); + const auto& roots = analysis.fusion().GetRoots(); + ReductionFusion::IndexGroups result; + result.group_id_per_root.reserve(roots.size()); + result.is_reduction_root.reserve(roots.size()); for (auto [root, hero] : llvm::zip(roots, analysis.fusion_heroes())) { disjoint_sets[root].Get() = root; reachable_outputs[root].insert(root); - if (IsRealReductionHero(root.instruction(), *hero)) { + result.is_reduction_root.push_back( + IsRealReductionHero(root.instruction(), *hero)); + if (result.is_reduction_root.back()) { roots_with_reduction.insert(root); } else if (first_non_reduction_root) { disjoint_sets[*first_non_reduction_root].Merge(&disjoint_sets[root]); @@ -230,71 +236,63 @@ std::vector> GroupDisjointReductions( } // Place output instructions in the same set into the same group. - ConstHloInstructionMap> groups; + ConstHloInstructionMap> group_map; for (auto root : roots) { - groups[&disjoint_sets[root].Get().instruction()].push_back( + group_map[&disjoint_sets[root].Get().instruction()].push_back( &root.instruction()); } - std::vector> ret; - ret.reserve(groups.size()); - absl::c_for_each( - groups, [&](auto& iter) { ret.emplace_back(std::move(iter.second)); }); - return ret; -} + absl::flat_hash_map set_ids; + for (auto&& [id, disjoint_set] : llvm::enumerate(disjoint_sets)) { + set_ids[&disjoint_set.second.Get().instruction()] = id; + } -// Experimentally determined values to achieve optimal number of -// bytes-in-flight. With a bound of #warps/SM which can be concurrently -// scheduled, for small reduced values it can be hard to achieve optimal -// number of bytes-in-flight. In order to address it, we increase the # of -// threads/block (physically, while keeping logical mapping the same), which -// allows larger # of bytes-in-flight. -int CalculateVirtualThreadScalingFactorForReduction( - const HloFusionAnalysis& analysis, - const ReductionDimensions& reduction_dimensions) { - int64_t dimx = reduction_dimensions.dimensions[TilingScheme::DimX]; - if (reduction_dimensions.is_row_reduction && dimx <= 128) { - int rows_per_warp = RowReductionGetRowsPerWarp(dimx); - const auto* cuda_cc = std::get_if( - &analysis.device_info().gpu_compute_capability()); - if (cuda_cc != nullptr && - cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) { - return rows_per_warp * 3; - } - return rows_per_warp * 5; + for (auto root : roots) { + result.group_id_per_root.push_back( + set_ids[&disjoint_sets[root].Get().instruction()]); } - return 1; + + result.grouped_roots.reserve(group_map.size()); + absl::c_for_each(group_map, [&](auto& it) { + result.grouped_roots.emplace_back(std::move(it.second)); + }); + return result; } -bool CanVectorizeReduction(const HloFusionAnalysis& analysis, - const ReductionDimensions& reduction_dimensions, - int num_threads_x, Vector3 reduction_tiling) { +namespace { + +int GetVectorSize(const HloFusionAnalysis& analysis, + const ReductionDimensions& reduction_dimensions, + int num_threads, Vector3 reduction_tiling) { if (!reduction_dimensions.is_row_reduction) { - return false; + return 1; } - if (reduction_dimensions.dimensions[TilingScheme::DimX] % 2 != 0 || + if (reduction_dimensions.dimensions[kRowMinorReducedDimension] % 2 != 0 || MayPreventVectorization(analysis.fusion())) { - return false; + return 1; } // Enabling vectorization if number of threads is <= warpsize leads to half or // more of the threads not doing any work. - if (num_threads_x <= WarpSize()) { - return false; + if (num_threads <= WarpSize()) { + return 1; } const auto* cuda_cc = std::get_if( &analysis.device_info().gpu_compute_capability()); - if (cuda_cc == nullptr) return false; - if (cuda_cc->IsAtLeast(se::CudaComputeCapability::VOLTA)) return true; + if (cuda_cc == nullptr) return 1; + if (cuda_cc->IsAtLeast(se::CudaComputeCapability::VOLTA)) return 2; if (cuda_cc->IsAtLeast(se::CudaComputeCapability::PASCAL_)) { return analysis.input_output_info().smallest_input_dtype_bits <= 32 && - reduction_dimensions.dimensions[TilingScheme::DimX] % - (reduction_tiling[2] * num_threads_x) == - 0; + reduction_dimensions.dimensions[kRowMinorReducedDimension] % + (reduction_tiling[kRowMinorReducedDimension] * + num_threads) == + 0 + ? 2 + : 1; } - return false; + return 1; } llvm::Value* CastSharedToGlobal(llvm::IRBuilder<>* builder, llvm::Value* input, @@ -321,11 +319,20 @@ class ReductionFusion::ReductionEmitter { reduction_codegen_info_(reduction_codegen_info), ir_emitter_context_(ir_emitter_context), fusion_(fusion), - index_ty_(GetIndexType(fusion, reduction_codegen_info.GetTilingScheme(), - elemental_emitter_.builder())) {} + index_ty_(GetIndexType(fusion, reduction_codegen_info.GetTiling(), + elemental_emitter_.builder())) { + for (auto hero : analysis.fusion_heroes()) { + if (hero->opcode() == HloOpcode::kReduce) { + for (int i = 0; i < hero->operand_count() / 2; ++i) { + CHECK(LayoutUtil::IsMonotonicWithDim0Major( + hero->operand(i)->shape().layout())) + << "reduction-layout-normalizer must run before code generation"; + } + } + } + } - absl::StatusOr EmitInitializers( - mlir::lmhlo::FusionOp fusion_op); + absl::StatusOr EmitInitializers(); absl::Status EmitKernel(const LaunchDimensions& launch_dims, std::vector inputs, std::vector outputs); @@ -334,7 +341,6 @@ class ReductionFusion::ReductionEmitter { friend class ReductionGroupEmitter; absl::StatusOr> BuildKernelThunkForFusion( - mlir::lmhlo::FusionOp fusion_op, const LaunchDimensions& launch_dimensions, absl::string_view discriminator, std::function, @@ -342,8 +348,8 @@ class ReductionFusion::ReductionEmitter { kernel_builder_fn); absl::StatusOr> BuildFusedInitializerThunk( - mlir::lmhlo::FusionOp fusion_op, const HloInstruction* fusion_root, - mlir::Value dest, BufferAllocation::Slice dest_slice, int output_index); + const HloInstruction* fusion_root, BufferAllocation::Slice dest_slice, + int output_index); absl::Status EmitIRForReduction( absl::Span instr_index_group, @@ -354,7 +360,7 @@ class ReductionFusion::ReductionEmitter { void EmitSyncThreads(); int ReducedDimensionSize() const { - return reduction_codegen_info_.GetTilingScheme().GetShape()[2]; + return reduction_codegen_info_.GetTiling().GetShape()[2]; } llvm::IRBuilder<>* builder_; @@ -458,9 +464,7 @@ ReductionFusion::ReductionGroupEmitter::ReductionGroupEmitter( for (const HloReduceInstruction* reduce_hlo : reduce_instr_index_group) { for (int op_result_idx = 0; op_result_idx < GetNumOutputs(reduce_hlo->shape()); op_result_idx++) { - Shape result_shape = reduce_hlo->shape().IsTuple() - ? reduce_hlo->shape().tuple_shapes(op_result_idx) - : reduce_hlo->shape(); + Shape result_shape = OutputShape(reduce_hlo->shape(), op_result_idx); llvm::Type* element_type = llvm_ir::PrimitiveTypeToIrType( result_shape.element_type(), builder->GetInsertBlock()->getModule()); @@ -480,7 +484,7 @@ ReductionFusion::ReductionGroupEmitter::ReductionGroupEmitter( .value(); builder->CreateStore(init_ir_value, result_address); - const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme(); + const Tiling& tiling = reduction_info.GetTiling(); auto shared_cache = [&]() -> std::optional { auto* module = reduction_emitter.ir_emitter_context_.llvm_module(); if (reduction_info.IsRowReduction()) { @@ -489,23 +493,21 @@ ReductionFusion::ReductionGroupEmitter::ReductionGroupEmitter( reduction_emitter_.ReducedDimensionSize()) > 1) { return std::nullopt; } - CHECK_EQ(tiling_scheme.GetNumThreadsPerBlock() % WarpSize(), 0); - int num_warps = tiling_scheme.GetNumThreadsPerBlock() / WarpSize(); + // Allocate one shared memory element per warp. + auto block_size = tiling.GetThreadsPerBlock(); + CHECK_EQ(block_size[kRowMinorReducedDimension] % WarpSize(), 0); return llvm_ir::AllocateSharedMemoryTile( module, element_type, - {tiling_scheme.GetThreadIdScalingFactor(), num_warps}, + {block_size[kRowKeptDimension], + block_size[kRowMinorReducedDimension] / WarpSize()}, "shared_cache"); } - const auto& num_threads = tiling_scheme.GetThreadsPerBlock(); - // num_threads_x == num_threads_y. The "+1" is used to avoid bank - // conflicts. - CHECK_EQ(num_threads[TilingScheme::DimX], - num_threads[TilingScheme::DimY]); - return llvm_ir::AllocateSharedMemoryTile( - module, element_type, - {num_threads[TilingScheme::DimX], - num_threads[TilingScheme::DimX] + 1}, - "shared_cache"); + const auto& num_threads = tiling.GetThreadsPerBlock(); + int n = num_threads[kColReducedDimension]; + CHECK_EQ(n, num_threads[kColMinorKeptDimension]); + // The "+1" is used to avoid bank conflicts. + return llvm_ir::AllocateSharedMemoryTile(module, element_type, + {n, n + 1}, "shared_cache"); }(); llvm_ir::ElementGenerator input_gen = @@ -549,14 +551,12 @@ void ReductionFusion::ReductionEmitter::EmitSyncThreads() { // std::vector outputs) { ... }; // TF_ASSIGN_OR_RETURN( // auto thunk, -// BuildKernelThunkForFusion(..., fusion_op, launch_dimensions, builder_fn, -// ...)); +// BuildKernelThunkForFusion(..., launch_dimensions, builder_fn)); // AddThunkToThunkSequence(std::move(thunk)) // ``` absl::StatusOr> ReductionFusion::ReductionEmitter::BuildKernelThunkForFusion( - mlir::lmhlo::FusionOp fusion_op, const LaunchDimensions& launch_dimensions, - absl::string_view discriminator, + const LaunchDimensions& launch_dimensions, absl::string_view discriminator, std::function, std::vector)> kernel_builder_fn) { @@ -564,13 +564,9 @@ ReductionFusion::ReductionEmitter::BuildKernelThunkForFusion( fusion_.fused_instructions_computation(); std::string suggested_kernel_name = std::string(fusion_.name()); - TF_ASSIGN_OR_RETURN( - auto kernel_arguments, - ir_emitter_context_.emit_ir_from_hlo() - ? KernelArguments::Create(ir_emitter_context_.buffer_assignment(), - &fusion_) - : KernelArguments::Create(ir_emitter_context_.allocations(), - fusion_op)); + TF_ASSIGN_OR_RETURN(auto kernel_arguments, + KernelArguments::Create( + ir_emitter_context_.buffer_assignment(), &fusion_)); auto [status_or_entry, cached] = ir_emitter_context_.kernel_cache().GetWithStatus( @@ -586,7 +582,10 @@ ReductionFusion::ReductionEmitter::BuildKernelThunkForFusion( fusion_.operand_count(), launch_dimensions, builder_)); TF_RETURN_IF_ERROR(kernel_builder_fn(input_arrays, output_arrays)); - return {{kernel->getName().str(), launch_dimensions}}; + // Shared memory is allocated statically. + return {{kernel->getName().str(), launch_dimensions, + /*cluster_dim=*/std::nullopt, + /*shmem_bytes=*/0}}; }); TF_ASSIGN_OR_RETURN(const KernelReuseCache::Entry* entry, status_or_entry); if (cached) { @@ -594,18 +593,9 @@ ReductionFusion::ReductionEmitter::BuildKernelThunkForFusion( << entry->kernel_name; } - if (ir_emitter_context_.emit_ir_from_hlo()) { - return std::make_unique( - &fusion_, entry->kernel_name, kernel_arguments.args(), - launch_dimensions, - // Shared memory is allocated statically. - /*shmem_bytes=*/0); - } - return std::make_unique( - fusion_op, entry->kernel_name, kernel_arguments.args(), launch_dimensions, - // Shared memory is allocated statically. - /*shmem_bytes=*/0); + &fusion_, entry->kernel_name, kernel_arguments.args(), launch_dimensions, + entry->cluster_dim, entry->shmem_bytes); } absl::Status ReductionFusion::ReductionGroupEmitter::EmitExtraOutputsForReduce( @@ -646,8 +636,8 @@ absl::Status ReductionFusion::ReductionGroupEmitter::EmitExtraOutputsForReduce( absl::StatusOr> ReductionFusion::ReductionEmitter::BuildFusedInitializerThunk( - mlir::lmhlo::FusionOp fusion_op, const HloInstruction* fusion_root, - mlir::Value dest, BufferAllocation::Slice dest_slice, int output_index) { + const HloInstruction* fusion_root, BufferAllocation::Slice dest_slice, + int output_index) { const HloReduceInstruction* reduce = DynCast(fusion_root); TF_RET_CHECK(reduce); @@ -655,8 +645,8 @@ ReductionFusion::ReductionEmitter::BuildFusedInitializerThunk( const HloInstruction* init_value = reduce->init_values()[0]; TF_ASSIGN_OR_RETURN( std::optional> constant_init_thunk, - BuildConstantInitializerThunk(ir_emitter_context_, fusion_op, fusion_root, - init_value, dest, dest_slice)); + BuildConstantInitializerThunk(ir_emitter_context_, fusion_root, + init_value, dest_slice)); if (constant_init_thunk) { return *std::move(constant_init_thunk); } @@ -694,7 +684,7 @@ ReductionFusion::ReductionEmitter::BuildFusedInitializerThunk( return absl::OkStatus(); }; - return BuildKernelThunkForFusion(fusion_op, launch_dimensions, + return BuildKernelThunkForFusion(launch_dimensions, /*discriminator=*/ absl::StrCat("init_", output_index), builder_fn); @@ -767,63 +757,36 @@ ReductionFusion::ReductionGroupEmitter::GetOutputIndexForReduction( const HloReduceInstruction* reduction, const HloInstruction* root, int output_idx) const { auto* builder = reduction_emitter_.builder_; - const auto& reduction_info = reduction_emitter_.reduction_codegen_info_; - const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme(); - const TilingThreadIdInfo& thread_id_info = tiling_kernel_info.thread_id_info; - - llvm_ir::IrArray::Index index = [&] { - auto offsets = thread_id_info.start_offsets; - if (!reduction_info.IsRowReduction()) { - std::swap(offsets[TilingScheme::DimX], offsets[TilingScheme::DimY]); - } - return tiling_kernel_info.tile_origin.AddOffset(offsets, builder); - }(); - - const Shape& operand_shape = reduction->inputs()[output_idx]->shape(); - Shape reduction_kept_element_shape = - ShapeUtil::DeleteDimensions(reduction->dimensions(), operand_shape); - - // Given the llvm_ir::IrArray index of a reduction input, returns the linear - // address of the reduction output as if the reduction were going to keep - // the input shape with the dimensions being reduced moved. - llvm::Value* untransposed_output_linear_address = [&] { + auto* index_ty = reduction_emitter_.index_ty_; + + // 1d or 2d output index (for row/column reduction). + auto projected_index = [&]() -> llvm_ir::IrArray::Index { + const auto& reduction_info = reduction_emitter_.reduction_codegen_info_; + const auto& offset = tiling_kernel_info.tile_origin; + const auto& shape = reduction_info.GetTiling().GetXlaShape(); + const auto& thread_ids = tiling_kernel_info.thread_id_info.thread_ids; if (reduction_info.IsRowReduction()) { - // For row-reduction, y-coordinate determines which row we write into. - return index[TilingScheme::DimY]; + constexpr int kDim = kRowKeptDimension; + return {{builder->CreateAdd(offset[kDim], thread_ids[kDim])}, + {shape.dimensions(kDim)}, + index_ty}; } - // For column reduction, we get the transposed address. - absl::Span dims_in_elem = tiling_scheme.GetShape(); - llvm::Value* x_dim_size = - index.GetConstantWithIndexType(dims_in_elem[TilingScheme::DimX]); - llvm::Value* x_block_offset = - builder->CreateMul(index[TilingScheme::DimZ], x_dim_size); - return builder->CreateAdd(x_block_offset, index[TilingScheme::DimX]); + auto* major_idx = offset[kColMajorKeptDimension]; + auto* minor_idx = builder->CreateAdd(offset[kColMinorKeptDimension], + thread_ids[kColReducedDimension]); + return {{major_idx, minor_idx}, + ShapeUtil::DeleteDimension(kColReducedDimension, shape), + index_ty}; }(); - // A reduction is allowed to transpose its output. For example, suppose - // we are reducing the second dimension of f32[10,20,30]{3,2,1}. We are - // allowed to produce as output either f32[10,30]{1,0} (no transpose) or - // f32[10,30]{0,1} (transposing the two output dims). - // - // At this point in the function we have a "partial sum" of input elements - // (stored in partial_result_addresses), and we need to accumulate it into - // the correct output element. - llvm_ir::IrArray::Index element_index( - /*linear=*/untransposed_output_linear_address, - reduction_kept_element_shape, builder); - const Shape& output_shape = !reduction->shape().IsTuple() - ? reduction->shape() - : reduction->shape().tuple_shapes(output_idx); - llvm_ir::IrArray::Index output_index(element_index.multidim(), output_shape, - element_index.GetType()); - // We need to check for root == reduction separately, because for variadic - // reduce the root shape would be a tuple, while 'output_shape' is the - // subshape. - return (root == reduction || - ShapeUtil::EqualIgnoringElementType(output_shape, root->shape())) - ? output_index - : output_index.SourceIndexOfBitcast(output_shape, root->shape(), - builder); + auto physical_shape = ShapeUtil::DeleteDimensions( + reduction->dimensions(), reduction->operand(output_idx)->shape()); + auto physical_index = + projected_index.SourceIndexOfBitcast(physical_shape, builder); + return llvm_ir::IrArray::Index(physical_index.multidim(), + OutputShape(reduction->shape(), output_idx), + index_ty) + .SourceIndexOfBitcast(OutputShape(root->shape(), output_idx), builder); } void ReductionFusion::ReductionGroupEmitter::WriteReductionOutput( @@ -865,15 +828,14 @@ void ReductionFusion::ReductionGroupEmitter::WriteReductionOutput( } } -// `current_output`: the value the tile has calculated. -// `output_address`: address where the output value has to be written. void ReductionFusion::ReductionGroupEmitter::EmitReductionOutputForRowReduction( const TilingKernelInfo& tiling_kernel_info, const HloReduceInstruction* reduction, const std::vector& roots) const { const HloComputation* reducer = reduction->to_apply(); const auto& thread_id_info = tiling_kernel_info.thread_id_info; - auto* thread_id_x = thread_id_info.thread_ids[TilingScheme::DimX]; + const auto& thread_ids = thread_id_info.thread_ids; + auto* thread_id_x = thread_ids[kRowMinorReducedDimension]; auto constant = [&](uint64_t c) -> llvm::Constant* { return llvm::ConstantInt::get(reduction_emitter_.index_ty_, c); }; @@ -893,12 +855,12 @@ void ReductionFusion::ReductionGroupEmitter::EmitReductionOutputForRowReduction( } const auto& reduction_info = reduction_emitter_.reduction_codegen_info_; - const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme(); + const Tiling& tiling = reduction_info.GetTiling(); int num_rows_per_warp = RowReductionGetRowsPerWarp(reduction_emitter_.ReducedDimensionSize()); - EmitFullWarpShuffleDownLoopForReduce( - reducer, absl::MakeSpan(current_outputs), - tiling_scheme.GetNumThreadsPerBlockPhysical(), num_rows_per_warp); + EmitFullWarpShuffleDownLoopForReduce(reducer, absl::MakeSpan(current_outputs), + tiling.GetNumThreadsPerBlock(), + num_rows_per_warp); KernelSupportLibrary ksl(builder); llvm::Value* warp_id = builder->CreateUDiv(thread_id_x, constant(WarpSize())); @@ -910,67 +872,79 @@ void ReductionFusion::ReductionGroupEmitter::EmitReductionOutputForRowReduction( }); }; - if (num_rows_per_warp > 1) { - llvm::Value* is_writing_thread = is_zero(builder->CreateAnd( - thread_id_x, constant(reduction_emitter_.ReducedDimensionSize() - 1))); - emit_write_output(is_writing_thread, current_outputs); - return; - } + // The major kept dimension and vector dimension are not tiled, so they're + // always in bounds. + llvm::Value* is_in_bounds_y = builder->CreateICmpULT( + thread_ids[kRowKeptDimension], + tiling_kernel_info.output_tile_bounds[kRowKeptDimension]); - ksl.If("intra_warp_reduce_write", is_zero(thread_id_info.lane_id), [&] { - for (int oidx = 0; oidx < num_outputs; oidx++) { - auto& state = GetCalculationStateFor(reduction, oidx); - state.shared_cache->Store( - builder->CreateLoad(current_outputs[oidx].second, - current_outputs[oidx].first), - {thread_id_info.scaling_index, warp_id}, builder); - } - }); - - // TODO(cheshire): Don't we want to sync it once for everything in the - // output? Not once per each? - reduction_emitter_.EmitSyncThreads(); - ksl.If("inter_warp_reduce", is_zero(warp_id), [&] { - absl::InlinedVector selected_values; - for (int oidx = 0; oidx < num_outputs; oidx++) { - auto& state = GetCalculationStateFor(reduction, oidx); - llvm::Value* block_accum_addr = state.shared_cache->Address( - {thread_id_info.scaling_index, thread_id_info.lane_id}, builder); - - llvm::Type* element_type = - state.partial_result_address->getAllocatedType(); - - // Ensure initial value address is in generic, not scratch. - llvm::Value* initial_value_addr = - CastSharedToGlobal(builder, - llvm_ir::EmitAllocaAtFunctionEntry( - element_type, "initial_value_addr", builder), - element_type, /*name=*/""); - builder->CreateStore(state.initial_value, initial_value_addr); - - llvm::Value* warp_exists = builder->CreateICmpULT( + ksl.If("thread_in_bounds", is_in_bounds_y, [&] { + if (num_rows_per_warp > 1) { + llvm::Value* is_writing_thread = is_zero(builder->CreateAnd( thread_id_x, - constant(tiling_scheme.GetThreadsPerBlock()[TilingScheme::DimX] / - WarpSize())); + constant(reduction_emitter_.ReducedDimensionSize() - 1))); + emit_write_output(is_writing_thread, current_outputs); + return; + } - llvm::Value* selected_value = builder->CreateSelect( - warp_exists, block_accum_addr, initial_value_addr); + ksl.If("intra_warp_reduce_write", is_zero(thread_id_info.lane_id), [&] { + for (int oidx = 0; oidx < num_outputs; oidx++) { + auto& state = GetCalculationStateFor(reduction, oidx); + state.shared_cache->Store( + builder->CreateLoad(current_outputs[oidx].second, + current_outputs[oidx].first), + {thread_id_info.thread_ids[kRowKeptDimension], warp_id}, builder); + } + }); - selected_values.push_back({selected_value, element_type}); - } + // TODO(cheshire): Don't we want to sync it once for everything in the + // output? Not once per each? + reduction_emitter_.EmitSyncThreads(); + ksl.If("inter_warp_reduce", is_zero(warp_id), [&] { + absl::InlinedVector selected_values; + for (int oidx = 0; oidx < num_outputs; oidx++) { + auto& state = GetCalculationStateFor(reduction, oidx); + llvm::Value* block_accum_addr = state.shared_cache->Address( + {thread_id_info.thread_ids[kRowKeptDimension], + thread_id_info.lane_id}, + builder); + + llvm::Type* element_type = + state.partial_result_address->getAllocatedType(); + + // Ensure initial value address is in generic, not scratch. + llvm::Value* initial_value_addr = + CastSharedToGlobal(builder, + llvm_ir::EmitAllocaAtFunctionEntry( + element_type, "initial_value_addr", builder), + element_type, /*name=*/""); + builder->CreateStore(state.initial_value, initial_value_addr); + + llvm::Value* warp_exists = builder->CreateICmpULT( + thread_id_x, + constant(tiling.GetThreadsPerBlock()[kRowMinorReducedDimension] / + WarpSize())); + + llvm::Value* selected_value = builder->CreateSelect( + warp_exists, block_accum_addr, initial_value_addr); + + selected_values.push_back({selected_value, element_type}); + } - // If only one warp is present in the block, then we don't need inter-warp - // reduction. - // TODO(b/241414088) If only warp is present, then inter-warp - // communication using shared memory and synchronization using barrier is - // also unnecessary and should be removed. - if (tiling_scheme.GetNumThreadsPerBlock() > WarpSize()) { - EmitFullWarpShuffleDownLoopForReduce( - reducer, absl::MakeSpan(selected_values), - tiling_scheme.GetNumThreadsPerBlock(), /*num_results_per_warp=*/1); - } + // If only one warp produces the output element, we don't need to emit + // an inter warp reduce. In our tiling, DimX is the minor reduced + // dimension. The major reduced dimension is always emitted as a loop. + // TODO(b/241414088) If only warp is present, then inter-warp + // communication using shared memory and synchronization using barrier is + // also unnecessary and should be removed. + if (tiling.GetThreadsPerBlock()[kRowMinorReducedDimension] > WarpSize()) { + EmitFullWarpShuffleDownLoopForReduce( + reducer, absl::MakeSpan(selected_values), + tiling.GetNumThreadsPerBlock(), /*num_results_per_warp=*/1); + } - emit_write_output(is_zero(thread_id_x), selected_values); + emit_write_output(is_zero(thread_id_x), selected_values); + }); }); } @@ -993,20 +967,20 @@ void ReductionFusion::ReductionGroupEmitter:: return builder->CreateICmpEQ(value, constant(0)); }; const auto& reduction_info = reduction_emitter_.reduction_codegen_info_; - const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme(); + const Tiling& tiling = reduction_info.GetTiling(); int num_outputs = reducer->num_parameters() / 2; + auto* kept_index = thread_ids[kColMinorKeptDimension]; + auto* reduced_index = thread_ids[kColReducedDimension]; + // Store the transpose in shared memory. for (int output_idx = 0; output_idx < num_outputs; output_idx++) { const auto& state = GetCalculationStateFor(reduction, output_idx); - const auto& shared_cache = state.shared_cache; auto* current_output_value = builder->CreateLoad(state.partial_result_address->getAllocatedType(), state.partial_result_address); - shared_cache->Store( - current_output_value, - {thread_ids[TilingScheme::DimX], thread_ids[TilingScheme::DimY]}, - builder); + state.shared_cache->Store(current_output_value, {kept_index, reduced_index}, + builder); } reduction_emitter_.EmitSyncThreads(); @@ -1015,27 +989,26 @@ void ReductionFusion::ReductionGroupEmitter:: absl::InlinedVector shmem_transposed_addrs; for (int output_idx = 0; output_idx < num_outputs; output_idx++) { const auto& state = GetCalculationStateFor(reduction, output_idx); - auto* shmem_transposed_addr = state.shared_cache->Address( - {thread_ids[TilingScheme::DimY], thread_ids[TilingScheme::DimX]}, - builder); + auto* shmem_transposed_addr = + state.shared_cache->Address({reduced_index, kept_index}, builder); shmem_transposed_addrs.push_back( {shmem_transposed_addr, state.shared_cache->GetElementType()}); } EmitFullWarpShuffleDownLoopForReduce(reducer, absl::MakeSpan(shmem_transposed_addrs), - tiling_scheme.GetNumThreadsPerBlock(), + tiling.GetNumThreadsPerBlock(), /*num_results_per_warp=*/1); // Some warps in the block are completely outside of the bound of the // tensor, so they should not write any output at all. llvm::Value* has_output = builder->CreateAnd( builder->CreateICmpULT( - thread_ids[TilingScheme::DimY], - tiling_kernel_info.output_tile_bounds[TilingScheme::DimX]), + reduced_index, + tiling_kernel_info.output_tile_bounds[kColMinorKeptDimension]), builder->CreateICmpULT( - thread_ids[TilingScheme::DimX], - tiling_kernel_info.output_tile_bounds[TilingScheme::DimY])); + kept_index, + tiling_kernel_info.output_tile_bounds[kColReducedDimension])); ksl.If("reduction_write_output", builder->CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] { @@ -1059,10 +1032,8 @@ void ReductionFusion::ReductionGroupEmitter::GenerateElementForReducer( const auto& state = GetCalculationStateFor(reduction, red_idx); llvm::AllocaInst* input_address = state.input_address; - llvm_ir::IrArray::Index input_index = GetUnnormalizedIndex( - index, reduction->operand(0)->shape(), builder, - reduction_emitter_.reduction_codegen_info_.GetTilingScheme() - .GetShape()); + auto input_index = + index.SourceIndexOfBitcast(reduction->operand(0)->shape(), builder); llvm::Value* const input_ir_value = *state.input_gen(input_index); builder->CreateStore(input_ir_value, input_address); reduction_accumulators.push_back(state.partial_result_address); @@ -1122,36 +1093,36 @@ absl::Status ReductionFusion::ReductionEmitter::EmitIRForReduction( } CHECK(!heroes.empty()) << " expect at least one reduce instructions."; - const TilingScheme& tiling_scheme = reduction_codegen_info_.GetTilingScheme(); - CHECK_EQ(tiling_scheme.GetNumThreadsPerBlockPhysical() % WarpSize(), 0); + const Tiling& tiling = reduction_codegen_info_.GetTiling(); + CHECK_EQ(tiling.GetNumThreadsPerBlock() % WarpSize(), 0); ReductionGroupEmitter group_emitter(*this, heroes, result_ir_arrays, fused_emitter); TF_ASSIGN_OR_RETURN( TilingKernelInfo tiling_kernel_info, EmitTilingKernel( - builder_, tiling_scheme, index_ty_, + builder_, tiling, index_ty_, [&](const TilingThreadIdInfo& thread_id_info, const llvm_ir::IrArray::Index& tile_index, - std::array tile_dimensions) { - auto emit_element = [&](std::array index_in_tile) { - auto index = tile_index.AddOffset(index_in_tile, builder_); - - // Emit code to generate the input and perform the reduction - // computation for each reduction instruction. - for (const HloReduceInstruction* reduce : heroes) { - group_emitter.GenerateElementForReducer(reduce, index); - } - - // Emit code to generate the output for the non-reduction - // instructions in the fusion, if any. - TF_CHECK_OK(group_emitter.EmitExtraOutputsForReduce( - ShapeUtil::MakeShape( - F32, - reduction_codegen_info_.GetTilingScheme().GetShape()), - index, extra_output_gens)); - }; - EmitTile(builder_, reduction_codegen_info_.GetTilingScheme(), + absl::Span tile_dimensions) { + auto emit_element = + [&](absl::Span index_in_tile) { + auto index = tile_index.AddOffset(index_in_tile, builder_); + + // Emit code to generate the input and perform the reduction + // computation for each reduction instruction. + for (const HloReduceInstruction* reduce : heroes) { + group_emitter.GenerateElementForReducer(reduce, index); + } + + // Emit code to generate the output for the non-reduction + // instructions in the fusion, if any. + TF_CHECK_OK(group_emitter.EmitExtraOutputsForReduce( + ShapeUtil::MakeShape( + F32, reduction_codegen_info_.GetTiling().GetShape()), + index, extra_output_gens)); + }; + EmitTile(builder_, reduction_codegen_info_.GetTiling(), thread_id_info, tile_dimensions, emit_element); })); @@ -1170,8 +1141,7 @@ absl::Status ReductionFusion::ReductionEmitter::EmitIRForReduction( } absl::StatusOr -ReductionFusion::ReductionEmitter::EmitInitializers( - mlir::lmhlo::FusionOp fusion_op) { +ReductionFusion::ReductionEmitter::EmitInitializers() { FusionEmissionResult result; if (reduction_codegen_info_.IsRaceFree()) { return result; @@ -1192,44 +1162,29 @@ ReductionFusion::ReductionEmitter::EmitInitializers( // Therefore we can get the ordered slices by calling ForEachSubshape on the // result shape. std::vector slices; - if (ir_emitter_context_.emit_ir_from_hlo()) { - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( - fusion_.shape(), [&](const Shape& subshape, ShapeIndex index) { - if (!ShapeUtil::IsLeafIndex(fusion_.shape(), index)) { - return absl::OkStatus(); - } - - TF_ASSIGN_OR_RETURN( - BufferAllocation::Slice slice, - ir_emitter_context_.buffer_assignment().GetUniqueSlice(&fusion_, - index)); - slices.push_back(slice); + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + fusion_.shape(), [&](const Shape& subshape, ShapeIndex index) { + if (!ShapeUtil::IsLeafIndex(fusion_.shape(), index)) { return absl::OkStatus(); - })); - } + } + + TF_ASSIGN_OR_RETURN( + BufferAllocation::Slice slice, + ir_emitter_context_.buffer_assignment().GetUniqueSlice(&fusion_, + index)); + slices.push_back(slice); + return absl::OkStatus(); + })); absl::Span fusion_roots = analysis_.fusion_roots(); for (int i = 0; i < fusion_roots.size(); ++i) { const HloInstruction* fusion_root = fusion_roots[i]; - mlir::Value dest = ir_emitter_context_.emit_ir_from_hlo() - ? nullptr - : fusion_op.getOutputBuffers()[i]; - - BufferAllocation::Slice dest_slice; - if (ir_emitter_context_.emit_ir_from_hlo()) { - dest_slice = slices[i]; - } else { - TF_ASSIGN_OR_RETURN( - dest_slice, - GetAllocationSlice(dest, ir_emitter_context_.allocations())); - } - if (IsReductionFromOrToContiguousDimensions(*fusion_root)) { - TF_ASSIGN_OR_RETURN(result.thunks.emplace_back(), - BuildFusedInitializerThunk(fusion_op, fusion_root, - dest, dest_slice, i)); + TF_ASSIGN_OR_RETURN( + result.thunks.emplace_back(), + BuildFusedInitializerThunk(fusion_root, slices[i], i)); } } return result; @@ -1268,20 +1223,21 @@ absl::Status ReductionFusion::ReductionEmitter::EmitKernel( // block_id_y instead of block_id_x simplifies the index calculation // for reduction code generation as the block_id_y is orthogonal to // the indices used within the reductions. - const std::vector>& instr_index_groups = - reduction_codegen_info_.GetIndexGroups(); + const auto& instr_index_groups = + reduction_codegen_info_.GetIndexGroups().grouped_roots; Shape reduce_operand_shape = reduction_codegen_info_.GetReduceOperandShape(); - llvm::Value* raw_block_id_y = gpu::EmitCallToTargetIntrinsic( + llvm::Value* block_id_y = gpu::EmitCallToTargetIntrinsic( gpu::TargetIntrinsicID::kBlockIdy, {}, {}, builder_); llvm_ir::AddRangeMetadata(0, instr_index_groups.size(), - llvm::cast(raw_block_id_y)); - raw_block_id_y = builder_->CreateZExtOrTrunc( - raw_block_id_y, builder_->getInt32Ty(), "raw_block_id_y"); + llvm::cast(block_id_y), + builder_->GetInsertBlock()->getModule()); + block_id_y = builder_->CreateZExtOrTrunc(block_id_y, builder_->getInt32Ty()); + block_id_y->setName("block.id.y"); for (int i = 0; i < instr_index_groups.size(); ++i) { TF_RETURN_IF_ERROR(ksl.IfWithStatus( absl::StrCat("reduce-group-", i), - builder_->CreateICmpEQ(raw_block_id_y, builder_->getInt32(i)), [&] { + builder_->CreateICmpEQ(block_id_y, builder_->getInt32(i)), [&] { return EmitIRForReduction(instr_index_groups[i], fused_emitter, result_ir_arrays, reduce_operand_shape); })); @@ -1295,12 +1251,12 @@ ReductionFusion::ReductionFusion(const HloFusionAnalysis& analysis) reduction_codegen_info_(ComputeReductionCodegenInfo(analysis)) {} absl::StatusOr ReductionFusion::EmitInitializers( - IrEmitterContext& ir_emitter_context, mlir::lmhlo::FusionOp fusion_op, + IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion) const { llvm::IRBuilder<> builder(ir_emitter_context.llvm_module()->getContext()); return ReductionEmitter(analysis_, reduction_codegen_info_, ir_emitter_context, fusion, &builder) - .EmitInitializers(fusion_op); + .EmitInitializers(); } absl::Status ReductionFusion::EmitKernel(IrEmitterContext& ir_emitter_context, @@ -1315,11 +1271,12 @@ absl::Status ReductionFusion::EmitKernel(IrEmitterContext& ir_emitter_context, } LaunchDimensions ReductionFusion::launch_dimensions() const { - const TilingScheme& tiling_scheme = reduction_codegen_info_.GetTilingScheme(); - size_t blocks_y = reduction_codegen_info_.GetIndexGroups().size(); - return {se::BlockDim(/*x=*/tiling_scheme.GetNumBlocksPhysical(), + const Tiling& tiling = reduction_codegen_info_.GetTiling(); + size_t blocks_y = + reduction_codegen_info_.GetIndexGroups().grouped_roots.size(); + return {se::BlockDim(/*x=*/tiling.GetNumBlocks(), /*y=*/static_cast(blocks_y), /*z=*/1), - se::ThreadDim(/*x=*/tiling_scheme.GetNumThreadsPerBlockPhysical(), + se::ThreadDim(/*x=*/tiling.GetNumThreadsPerBlock(), /*y=*/1, /*z=*/1)}; } @@ -1331,54 +1288,189 @@ ReductionFusion::ComputeReductionCodegenInfo( Shape input_shape = hero_reduction->operand(0)->shape(); ReductionDimensions reduction_dimensions = GetReductionKindAndContiguousComponents(*hero_reduction); + auto shape = reduction_dimensions.dimensions; VLOG(10) << "is_row_reduction " << reduction_dimensions.is_row_reduction - << " " << reduction_dimensions.dimensions[0] << " " - << reduction_dimensions.dimensions[1] << " " - << reduction_dimensions.dimensions[2]; + << " " << shape[0] << " " << shape[1] << " " << shape[2]; Vector3 reduction_tiling = GetReductionTiling(reduction_dimensions); - int64_t fan_out = analysis.fusion_roots().size(); int64_t num_threads_y = reduction_dimensions.is_row_reduction ? 1 : WarpSize(); + int64_t rows_per_warp = + reduction_dimensions.is_row_reduction + ? RowReductionGetRowsPerWarp(shape[kRowMinorReducedDimension]) + : 1; int64_t num_threads_x = [&] { if (reduction_dimensions.is_row_reduction) { - if (RowReductionGetRowsPerWarp(reduction_dimensions.dimensions[2]) > 1) { - return reduction_dimensions.dimensions[2]; + if (rows_per_warp > 1) { + return shape[kRowMinorReducedDimension]; } - // Use 512 as default block size (threads per block) for row reductions. - // For multi-output fusions, reduce the block size further to decrease - // register pressure when multiple outputs are computed by each thread. - int64_t max_block_size = std::max( - MinThreadsXRowReduction(hero_reduction->GetModule()->config()), - static_cast(512LL / NearestPowerOfTwo(fan_out))); - return std::min(max_block_size, - RoundUpTo(CeilOfRatio(reduction_dimensions.dimensions[2], - reduction_tiling[2]), - WarpSize())); + int64_t max_block_size = + MinThreadsXRowReduction(hero_reduction->GetModule()->config()); + return std::min( + max_block_size, + RoundUpTo(CeilOfRatio(shape[kRowMinorReducedDimension], + reduction_tiling[kRowMinorReducedDimension]), + WarpSize())); } return WarpSize(); }(); - int vector_size = CanVectorizeReduction(analysis, reduction_dimensions, - num_threads_x, reduction_tiling) - ? 2 - : 1; + // If we're limited by the size of the x dimension, add additional parallelism + // in the y dimension. The code generator doesn't currently support + // parallelizing the z dimension (major reduced dimensions). The general + // recommendation is to use between 128 and 512 threads, so we just go for + // 256. See https://forums.developer.nvidia.com/t/55529 + constexpr int64_t kThreadsPerBlockTarget = 256; + if (reduction_dimensions.is_row_reduction && + num_threads_x * 2 <= kThreadsPerBlockTarget) { + int64_t kept_size = reduction_dimensions.dimensions[kRowKeptDimension]; + // Increase the size of the y dimension as long as there's remaining + // parallelism. + if (kept_size * num_threads_x <= kThreadsPerBlockTarget) { + num_threads_y = kept_size; + // num_threads_x is a power of two, but it may be less than 32. If dim_y + // is also small, we may have to increase the bound so the total number of + // threads is a multiple of 32. + while ((num_threads_x * num_threads_y) % 32) ++num_threads_y; + } else { + num_threads_y = kThreadsPerBlockTarget / num_threads_x; + } + } - Vector3 num_threads = {1, num_threads_y, num_threads_x}; - int virtual_thread_scaling_factor = - CalculateVirtualThreadScalingFactorForReduction(analysis, - reduction_dimensions); - VLOG(2) << "Using virtual thread scaling: " << virtual_thread_scaling_factor; + int vector_size = GetVectorSize(analysis, reduction_dimensions, num_threads_x, + reduction_tiling); + + absl::InlinedVector num_threads{1, num_threads_y, num_threads_x}; + absl::InlinedVector tiled_shape{shape[0], shape[1], + shape[2] / vector_size}; + absl::InlinedVector tile_per_thread{ + reduction_tiling[0], reduction_tiling[1], + reduction_tiling[2] / vector_size}; + if (rows_per_warp > 1) { + // If we produce more than one element per thread, that means the reduced + // dimension is small and it can't be tiled - we already have more threads + // in a warp than the size of the reduced dimension. The code generator + // doesn't currently support tiling the kept dimension, because it just + // uses the thread ID as the coordinate. + tile_per_thread[2] = 1; + } + if (vector_size != 1) { + num_threads.push_back(1); // The vector dimension is a loop. + tiled_shape.push_back(vector_size); + tile_per_thread.push_back(vector_size); + } - TilingScheme tiling_scheme(reduction_dimensions.dimensions, reduction_tiling, - num_threads, vector_size, - virtual_thread_scaling_factor); + Tiling tiling(tiled_shape, tile_per_thread, num_threads, + /*loops_to_unroll=*/{false, false, true, false}); bool reduction_is_race_free = ReductionIsRaceFree( hero_reduction->GetModule()->config(), reduction_dimensions); return ReductionCodegenInfo( - tiling_scheme, reduction_dimensions.is_row_reduction, - reduction_is_race_free, GroupDisjointReductions(analysis), - hero_reduction); + tiling, reduction_dimensions.is_row_reduction, reduction_is_race_free, + GroupDisjointReductions(analysis), hero_reduction); +} + +std::optional ReductionFusion::ComputeThreadIdToOutputIndexing( + int64_t root_index, mlir::MLIRContext* ctx) const { + const auto& groups = reduction_codegen_info_.GetIndexGroups(); + if (!groups.is_reduction_root[root_index]) { + // Non-transpose roots are elementwise by definition. + return ComputeThreadIdToInputIndexing(root_index, 0, ctx); + } + auto* root = analysis_.fusion_roots()[root_index]; + auto* hero = analysis_.fusion_heroes()[root_index]; + + const auto& tiling = reduction_codegen_info_.GetTiling(); + auto block_offsets = GetBlockOffsetsForTiling(tiling, ctx); + auto thread_ids = DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), + tiling.GetThreadsPerBlock(), + tiling.GetThreadStrides()); + + auto physical_shape = ShapeUtil::DeleteDimensions(hero->dimensions(), + hero->operand(0)->shape()); + + std::vector dimension_ranges{ + {0, tiling.GetNumThreadsPerBlock() - 1}, {}, {}, + {0, tiling.GetNumBlocks() - 1}, {}, {}, + }; + auto physical_index = [&]() { + if (reduction_codegen_info_.IsRowReduction()) { + IndexingMap linear_index( + mlir::AffineMap::get(6, 0, + block_offsets.getResult(kRowKeptDimension) + + thread_ids[kRowKeptDimension], + ctx), + dimension_ranges, {}); + int rows_per_warp = RowReductionGetRowsPerWarp( + tiling.GetShape()[kRowMinorReducedDimension]); + if (rows_per_warp > 1) { + linear_index.AddConstraint(thread_ids[kRowMinorReducedDimension] % + (WarpSize() / rows_per_warp), + {0, 0}); + } else { + linear_index.AddConstraint(thread_ids[kRowMinorReducedDimension], + {0, 0}); + } + return ComposeIndexingMaps( + linear_index, + GetBitcastMap(ShapeUtil::MakeShape( + PRED, {tiling.GetShape()[kRowKeptDimension]}), + physical_shape, ctx)); + } + + IndexingMap projected_index( + mlir::AffineMap::get(6, 0, + {block_offsets.getResult(kColMajorKeptDimension), + block_offsets.getResult(kColMinorKeptDimension) + + thread_ids[kColReducedDimension]}, + ctx), + dimension_ranges, {}); + + // TODO(b/319081342): Add constraints for the writing threads + // (`has_output`). + projected_index.AddConstraint( + mlir::getAffineDimExpr(kIndexingMapThreadIdxDims[0], ctx) % WarpSize(), + {0, 0}); + + return ComposeIndexingMaps( + projected_index, + GetBitcastMap(ShapeUtil::DeleteDimension(kColReducedDimension, + tiling.GetXlaShape()), + physical_shape, ctx)); + }(); + + auto map = ComposeIndexingMaps( + physical_index, GetBitcastMap(OutputShape(hero->shape(), 0), + OutputShape(root->shape(), 0), ctx)); + + int group_index = groups.group_id_per_root[root_index]; + map.AddConstraint(mlir::getAffineDimExpr(kIndexingMapBlockIdxDims[1], ctx), + {group_index, group_index}); + return map; +} + +std::optional ReductionFusion::ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const { + const auto& groups = reduction_codegen_info_.GetIndexGroups(); + + auto* hero = analysis_.fusion_heroes()[root_index]; + if (groups.is_reduction_root[root_index] && + hero_operand_index >= hero->operand_count() / 2) { + // We don't have indexing for the init values. + return std::nullopt; + } + + const auto& tiling = reduction_codegen_info_.GetTiling(); + auto map = ComposeIndexingMaps( + GetIndexingMapForTiling(tiling, ctx), + GetBitcastMap(tiling.GetXlaShape(), + hero->operand(hero_operand_index)->shape(), ctx)); + // Only threads with the right y block index actually do anything for this + // root. + int group_index = groups.group_id_per_root[root_index]; + map.AddConstraint(mlir::getAffineDimExpr(kIndexingMapBlockIdxDims[1], ctx), + {group_index, group_index}); + return map; } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/fusions/reduction.h b/third_party/xla/xla/service/gpu/fusions/reduction.h index ddfd2b0cda9c4b..0a30343f649463 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction.h +++ b/third_party/xla/xla/service/gpu/fusions/reduction.h @@ -109,14 +109,15 @@ class ReductionFusion : public KernelFusionEmitterBase { LaunchDimensions launch_dimensions() const override; std::optional ComputeThreadIdToOutputIndexing( - int64_t output_id, mlir::MLIRContext* ctx) const override { - // TODO(b/319081342): Implement this. - return std::nullopt; - } + int64_t root_index, mlir::MLIRContext* ctx) const override; + + std::optional ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const override; protected: absl::StatusOr EmitInitializers( - IrEmitterContext& ir_emitter_context, mlir::lmhlo::FusionOp fusion_op, + IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion) const override; absl::Status EmitKernel(IrEmitterContext& ir_emitter_context, @@ -130,20 +131,30 @@ class ReductionFusion : public KernelFusionEmitterBase { class ReductionEmitter; class ReductionGroupEmitter; + struct IndexGroups { + std::vector> grouped_roots; + + // For each root of the fusion, returns the index of the group it was placed + // in. + std::vector group_id_per_root; + + // For each root of the fusion, returns whether it is a reduction root, or + // an additional output. + std::vector is_reduction_root; + }; + class ReductionCodegenInfo { public: - using IndexGroups = std::vector>; - - ReductionCodegenInfo(TilingScheme mapping_scheme, bool is_row_reduction, + ReductionCodegenInfo(Tiling tiling, bool is_row_reduction, bool is_race_free, IndexGroups index_groups, const HloInstruction* first_reduce) - : tiling_scheme_(mapping_scheme), + : tiling_(tiling), is_row_reduction_(is_row_reduction), is_race_free_(is_race_free), index_groups_(std::move(index_groups)), first_reduce_(first_reduce) {} - const TilingScheme& GetTilingScheme() const { return tiling_scheme_; } + const Tiling& GetTiling() const { return tiling_; } const IndexGroups& GetIndexGroups() const { return index_groups_; } Shape GetReduceOperandShape() const { return first_reduce_->operand(0)->shape(); @@ -153,13 +164,21 @@ class ReductionFusion : public KernelFusionEmitterBase { bool IsRaceFree() const { return is_race_free_; } private: - TilingScheme tiling_scheme_; + Tiling tiling_; bool is_row_reduction_; bool is_race_free_; IndexGroups index_groups_; const HloInstruction* first_reduce_; }; + // Groups the roots of the fusion. Different groups will be executed in + // parallel. We run reduce instructions in parallel if we can without too + // much recomputation overhead. The current heuristic is to place reduce + // instructions that share nothing or only (broadcasted) scalars/constants + // into different groups; otherwise, they are placed in the same group. Non- + // reduce instructions are always grouped with reduces with which they share + // any predecessors. + static IndexGroups GroupDisjointReductions(const HloFusionAnalysis& analysis); static ReductionCodegenInfo ComputeReductionCodegenInfo( const HloFusionAnalysis& analysis); diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_test.cc b/third_party/xla/xla/service/gpu/fusions/reduction_test.cc new file mode 100644 index 00000000000000..7eb1be8247d9e8 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/reduction_test.cc @@ -0,0 +1,409 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/reduction.h" + +#include +#include + +#include +#include +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/service/gpu/fusions/fusions.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/model/indexing_test_utils.h" +#include "xla/status_macros.h" +#include "xla/statusor.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +class ReductionTest : public HloTestBase { + protected: + stream_executor::DeviceDescription device_info_ = + TestGpuDeviceInfo::RTXA6000DeviceInfo(); +}; + +StatusOr> GetReductionFusion( + const HloFusionAnalysis& analysis) { + TF_ASSIGN_OR_RETURN( + auto emitter, GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis})); + auto fusion = dynamic_cast(emitter.get()); + TF_RET_CHECK(fusion != nullptr); + + emitter.release(); + return std::unique_ptr{fusion}; +} + +TEST_F(ReductionTest, ThreadIndexingRowReduction) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule module + + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + + fusion { + %input = f32[100,64,512] parameter(0) + %c0 = f32[] constant(0) + ROOT reduce = f32[100,64] reduce(%input, %c0), dimensions={2}, to_apply=add + } + + ENTRY entry { + %input = f32[100,64,512] parameter(0) + ROOT %fusion = f32[100,64] fusion(%input), kind=kInput, calls=fusion + })") + .value(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + + TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetReductionFusion(analysis)); + mlir::MLIRContext mlir_context; + + EXPECT_THAT( + fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + (d3 * 8 + d0 floordiv 32) floordiv 64, + (d3 * 8 + d0 floordiv 32) mod 64, + d0 mod 32 + s2 * 32 + ) + domain: + d0 in [0, 255] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 799] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 0] + s1 in [0, 0] + s2 in [0, 15] + 0 in [0, 0] + d0 mod 32 + s2 * 32 in [0, 511] + d3 * 8 + d0 floordiv 32 in [0, 6399] + )")); + EXPECT_THAT( + fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5) -> ( + (d3 * 8 + d0 floordiv 32) floordiv 64, + (d3 * 8 + d0 floordiv 32) mod 64 + ) + domain: + d0 in [0, 255] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 799] + d4 in [0, 0] + d5 in [0, 0] + (d3 * 8 + d0 floordiv 32) mod 64 in [0, 63] + d0 mod 32 in [0, 0] + d3 * 8 + d0 floordiv 32 in [0, 6399] + )")); +} + +TEST_F(ReductionTest, ThreadIndexingMultiRowReduction) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule module + + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + + fusion { + %input = f32[100,64,4] parameter(0) + %c0 = f32[] constant(0) + ROOT reduce = f32[100,64] reduce(%input, %c0), dimensions={2}, to_apply=add + } + + ENTRY entry { + %input = f32[100,64,4] parameter(0) + ROOT %fusion = f32[100,64] fusion(%input), kind=kInput, calls=fusion + })") + .value(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + + TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetReductionFusion(analysis)); + mlir::MLIRContext mlir_context; + + EXPECT_THAT( + fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + d3 + (d0 floordiv 4) floordiv 64, + (d0 floordiv 4) mod 64, + d0 mod 4 + ) + domain: + d0 in [0, 255] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 99] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 0] + s1 in [0, 0] + s2 in [0, 0] + 0 in [0, 0] + d0 mod 4 in [0, 3] + d3 * 64 + d0 floordiv 4 in [0, 6399] + )")); + EXPECT_THAT( + fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5) -> ( + d3 + (d0 floordiv 4) floordiv 64, + (d0 floordiv 4) mod 64 + ) + domain: + d0 in [0, 255] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 99] + d4 in [0, 0] + d5 in [0, 0] + (d0 floordiv 4) mod 64 in [0, 63] + d0 mod 4 in [0, 0] + d3 * 64 + d0 floordiv 4 in [0, 6399] + d3 + (d0 floordiv 4) floordiv 64 in [0, 99] + )")); +} + +TEST_F(ReductionTest, ThreadIndexingColumnReduction) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule module + + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + + fusion { + %input = f32[100,64,32] parameter(0) + %c0 = f32[] constant(0) + ROOT reduce = f32[100,32] reduce(%input, %c0), dimensions={1}, to_apply=add + } + + ENTRY entry { + %input = f32[100,64,32] parameter(0) + ROOT %fusion = f32[100,32] fusion(%input), kind=kInput, calls=fusion + })") + .value(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + + TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetReductionFusion(analysis)); + mlir::MLIRContext mlir_context; + + EXPECT_THAT( + fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + d3, + d0 floordiv 32 + s1 * 32, + d0 mod 32 + ) + domain: + d0 in [0, 1023] d1 in [0, 0] d2 in [0, 0] + d3 in [0, 99] d4 in [0, 0] d5 in [0, 0] + s0 in [0, 0] s1 in [0, 127] s2 in [0, 0] + d0 floordiv 32 + s1 * 32 in [0, 63] + d0 mod 32 in [0, 31] + )")); + EXPECT_THAT( + fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5) -> ( + d3, + d0 floordiv 32 + ) + domain: + d0 in [0, 1023] d1 in [0, 0] d2 in [0, 0] + d3 in [0, 99] d4 in [0, 0] d5 in [0, 0] + d0 mod 32 in [0, 0] + )")); +} + +TEST_F(ReductionTest, ThreadIndexingOutputLayout) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule module + + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + + fusion { + %input = f32[100,64,512] parameter(0) + %c0 = f32[] constant(0) + ROOT reduce = f32[100,64]{0,1} reduce(%input, %c0), dimensions={2}, to_apply=add + } + + ENTRY entry { + %input = f32[100,64,512] parameter(0) + ROOT %fusion = f32[100,64]{0,1} fusion(%input), kind=kInput, calls=fusion + })") + .value(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + + TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetReductionFusion(analysis)); + mlir::MLIRContext mlir_context; + + EXPECT_THAT( + fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5) -> ( + (d3 * 8 + d0 floordiv 32) floordiv 64, + (d3 * 8 + d0 floordiv 32) mod 64 + ) + domain: + d0 in [0, 255] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 799] + d4 in [0, 0] + d5 in [0, 0] + (d3 * 8 + d0 floordiv 32) mod 64 in [0, 63] + d0 mod 32 in [0, 0] + d3 * 8 + d0 floordiv 32 in [0, 6399] + )")); +} + +TEST_F(ReductionTest, ThreadIndexingSideOutput) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule module + + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + + fusion { + %input = f32[100,64,512] parameter(0) + %c0 = f32[] constant(0) + %log = f32[100,64,512] log(%input) + %reduce = f32[100,64] reduce(%input, %c0), dimensions={2}, to_apply=add + ROOT tuple = (f32[100,64], f32[100,64,512]) tuple(%reduce, %log) + } + + ENTRY entry { + %input = f32[100,64,512] parameter(0) + ROOT %fusion = (f32[100,64], f32[100,64,512]) fusion(%input), kind=kInput, calls=fusion + })") + .value(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + + TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetReductionFusion(analysis)); + mlir::MLIRContext mlir_context; + + constexpr char kExpectedIndexing[] = R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + (d3 * 8 + d0 floordiv 32) floordiv 64, + (d3 * 8 + d0 floordiv 32) mod 64, + d0 mod 32 + s2 * 32 + ) + domain: + d0 in [0, 255] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 799] + d4 in [1, 0] + d5 in [0, 0] + s0 in [0, 0] + s1 in [0, 0] + s2 in [0, 15] + 0 in [0, 0] + d0 mod 32 + s2 * 32 in [0, 511] + d3 * 8 + d0 floordiv 32 in [0, 6399] + )"; + EXPECT_THAT( + fusion->ComputeThreadIdToInputIndexing(1, 0, &mlir_context)->ToString(), + MatchIndexingString(kExpectedIndexing)); + EXPECT_THAT( + fusion->ComputeThreadIdToOutputIndexing(1, &mlir_context)->ToString(), + MatchIndexingString(kExpectedIndexing)); +} + +TEST_F(ReductionTest, bla) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule module + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + fusion { + %input = f32[1024, 8192] parameter(0) + %c0 = f32[] constant(0) + ROOT reduce = f32[1024]{0} reduce(f32[1024, 8192] %input, f32[] %c0), + dimensions={1}, to_apply=add + } + ENTRY entry { + %input = f32[1024, 8192] parameter(0) + ROOT %fusion = f32[1024] fusion(%input), kind=kInput, calls=fusion + })") + .value(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + + TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetReductionFusion(analysis)); + mlir::MLIRContext mlir_context; + + EXPECT_THAT( + fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> ( + d3, + (d0 + s2 * 512) * 2 + s3 + ) + domain: + d0 in [0, 511] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 1023] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 0] + s1 in [0, 0] + s2 in [0, 7] + s3 in [0, 1] + 0 in [0, 0] + d0 + s2 * 512 in [0, 4095] + )")); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/scatter.h b/third_party/xla/xla/service/gpu/fusions/scatter.h index b90b7833acc5da..6982bbc8e6bd2c 100644 --- a/third_party/xla/xla/service/gpu/fusions/scatter.h +++ b/third_party/xla/xla/service/gpu/fusions/scatter.h @@ -44,12 +44,19 @@ class ScatterFusion : public KernelFusionEmitterBase { LaunchDimensions launch_dimensions() const override; std::optional ComputeThreadIdToOutputIndexing( - int64_t output_id, mlir::MLIRContext* ctx) const override { + int64_t root_index, mlir::MLIRContext* ctx) const override { // The kernel iterates over updates, whose correspondence to output // elements cannot be computed statically. return std::nullopt; } + std::optional ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const override { + // TODO(b/319081342): Implement this. + return std::nullopt; + } + protected: absl::Status EmitKernel(IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion, diff --git a/third_party/xla/xla/service/gpu/fusions/thunk_util.cc b/third_party/xla/xla/service/gpu/fusions/thunk_util.cc index 8ebdae0d850f88..36b5cb300453bd 100644 --- a/third_party/xla/xla/service/gpu/fusions/thunk_util.cc +++ b/third_party/xla/xla/service/gpu/fusions/thunk_util.cc @@ -20,30 +20,26 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/status/statusor.h" #include "absl/types/span.h" -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/literal.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/ir_emitter_context.h" -#include "xla/service/gpu/runtime3/memset_thunk.h" +#include "xla/service/gpu/runtime/memset_thunk.h" #include "xla/service/gpu/thunk.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/statusor.h" namespace xla { namespace gpu { absl::StatusOr>> BuildConstantInitializerThunk(IrEmitterContext& ir_emitter_context, - mlir::Operation* op, const HloInstruction* instr, + const HloInstruction* instr, const HloInstruction* init_value, - mlir::Value dest, BufferAllocation::Slice dest_slice) { if (const HloConstantInstruction* constant = DynCast(init_value)) { @@ -56,11 +52,9 @@ BuildConstantInitializerThunk(IrEmitterContext& ir_emitter_context, const Shape dest_shape = instr->shape(); Thunk::ThunkInfo thunk_info = - ir_emitter_context.emit_ir_from_hlo() - ? Thunk::ThunkInfo::WithProfileAnnotation(instr) - : Thunk::ThunkInfo::WithProfileAnnotation(op); + Thunk::ThunkInfo::WithProfileAnnotation(instr); if (absl::c_all_of(literal_bytes, [](uint8_t byte) { return byte == 0; })) { - return {{std::make_unique(thunk_info, dest_slice, dest)}}; + return {{std::make_unique(thunk_info, dest_slice)}}; } // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by @@ -77,7 +71,7 @@ BuildConstantInitializerThunk(IrEmitterContext& ir_emitter_context, } uint32_t pattern32 = uint32_t{pattern16} | (uint32_t{pattern16} << 16); return {{std::make_unique(thunk_info, pattern32, - dest_slice, dest)}}; + dest_slice)}}; } // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit @@ -88,7 +82,7 @@ BuildConstantInitializerThunk(IrEmitterContext& ir_emitter_context, uint32_t word; memcpy(&word, literal_bytes.data(), sizeof(word)); return {{std::make_unique(thunk_info, word, - dest_slice, dest)}}; + dest_slice)}}; } } return std::nullopt; diff --git a/third_party/xla/xla/service/gpu/fusions/thunk_util.h b/third_party/xla/xla/service/gpu/fusions/thunk_util.h index 6848f78ec5b8ae..32ba8e6a267a8c 100644 --- a/third_party/xla/xla/service/gpu/fusions/thunk_util.h +++ b/third_party/xla/xla/service/gpu/fusions/thunk_util.h @@ -18,12 +18,11 @@ limitations under the License. #include #include -#include "mlir/IR/Value.h" // from @llvm-project +#include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/thunk.h" -#include "xla/statusor.h" namespace xla { namespace gpu { @@ -32,9 +31,8 @@ namespace gpu { // empty optional if the value is not a constant. absl::StatusOr>> BuildConstantInitializerThunk(IrEmitterContext& ir_emitter_context, - mlir::Operation* op, const HloInstruction* instr, + const HloInstruction* instr, const HloInstruction* init_value, - mlir::Value dest, BufferAllocation::Slice dest_slice); } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/fusions/tiling_util.cc b/third_party/xla/xla/service/gpu/fusions/tiling_util.cc index 2e24d6c0eae166..24456209e521fb 100644 --- a/third_party/xla/xla/service/gpu/fusions/tiling_util.cc +++ b/third_party/xla/xla/service/gpu/fusions/tiling_util.cc @@ -15,14 +15,15 @@ limitations under the License. #include "xla/service/gpu/fusions/tiling_util.h" -#include #include #include #include #include #include "absl/log/check.h" +#include "absl/strings/str_cat.h" #include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/IRBuilder.h" @@ -35,7 +36,6 @@ limitations under the License. #include "xla/service/llvm_ir/kernel_support_library.h" #include "xla/service/llvm_ir/llvm_loop.h" #include "xla/service/llvm_ir/llvm_util.h" -#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" #include "tsl/platform/statusor.h" @@ -44,9 +44,8 @@ namespace xla { namespace gpu { namespace { -void EmitTileRec(const TilingThreadIdInfo& thread_id_info, - const TilingScheme& tiling_scheme, int dim, - std::array tile_idx, +void EmitTileRec(const TilingThreadIdInfo& thread_id_info, const Tiling& tiling, + int dim, absl::InlinedVector tile_idx, absl::Span tile_dimensions, llvm::IRBuilder<>* b, const TileElementGenerator& emit_elem) { llvm::Type* index_ty = thread_id_info.thread_id->getType(); @@ -55,61 +54,57 @@ void EmitTileRec(const TilingThreadIdInfo& thread_id_info, }; auto recurse = [&] { - EmitTileRec(thread_id_info, tiling_scheme, dim + 1, tile_idx, - tile_dimensions, b, emit_elem); + if (dim == tile_idx.size() - 1) { + emit_elem(tile_idx); + } else { + EmitTileRec(thread_id_info, tiling, dim + 1, tile_idx, tile_dimensions, b, + emit_elem); + } }; - KernelSupportLibrary ksl(b, dim == TilingScheme::DimX - ? llvm_ir::UnrollMode::kFullyUnroll - : llvm_ir::UnrollMode::kDefaultUnroll); + bool unroll = tiling.GetLoopsToUnroll()[dim]; + KernelSupportLibrary ksl(b, unroll ? llvm_ir::UnrollMode::kFullyUnroll + : llvm_ir::UnrollMode::kDefaultUnroll); - // TODO(jreiffers): Remove the dim==Z check, this is only here for historical - // reasons. - if (dim == TilingScheme::DimZ && tiling_scheme.GetBlockTileSize()[dim] == 1) { + if (tiling.GetBlockTileSize()[dim] == 1) { tile_idx[dim] = constant(0); recurse(); - } else if (dim == TilingScheme::DimX) { - int64_t vector_size = tiling_scheme.GetVectorSize(); - int64_t stride = tiling_scheme.GetThreadsPerBlock()[TilingScheme::DimX]; - int64_t last_dim_size = tiling_scheme.GetThreadTileSize()[2] / vector_size; + } else if (unroll) { + // TODO(jreiffers): Check if this unrolling does anything useful. + int64_t stride = tiling.GetThreadsPerBlock()[dim]; + int64_t dim_size = tiling.GetThreadTileSize()[dim]; auto make_loop = [&](bool emit_bounds_checks) { auto body = [&, emit_bounds_checks](llvm::Value* i) { - for (int64_t v = 0; v < vector_size; ++v) { - tile_idx[dim] = b->CreateAdd( - b->CreateAdd(b->CreateMul(i, constant(stride * vector_size)), - constant(v)), - thread_id_info.start_offsets[dim]); - if (emit_bounds_checks) { - auto* in_bounds = - b->CreateICmpULT(tile_idx[dim], tile_dimensions[dim]); - ksl.If("x_in_tile", in_bounds, [&] { emit_elem(tile_idx); }); - } else { - emit_elem(tile_idx); - } + tile_idx[dim] = b->CreateAdd(i, thread_id_info.thread_ids[dim]); + if (emit_bounds_checks) { + auto* in_bounds = + b->CreateICmpULT(tile_idx[dim], tile_dimensions[dim]); + ksl.If("x_in_tile", in_bounds, recurse); + } else { + recurse(); } }; return [&, body] { - ksl.For(absl::StrCat("loop", dim), constant(0), constant(last_dim_size), - constant(1), body); + ksl.For(absl::StrCat("loop", dim), constant(0), + constant(dim_size * stride), constant(stride), body); }; }; - if (stride > 1 && last_dim_size > 1) { + if (stride > 1 && dim_size > 1) { // Most tiles will be full, so we emit a single bounds check for those. - auto* is_full_tile = - b->CreateICmpEQ(constant(tiling_scheme.GetBlockTileSize()[dim]), - tile_dimensions[dim]); + auto* is_full_tile = b->CreateICmpEQ( + constant(tiling.GetBlockTileSize()[dim]), tile_dimensions[dim]); ksl.If("is_full_tile", is_full_tile, make_loop(false), make_loop(true)); } else { - // TODO(jreiffers): If last_dim_size is 1, we don't need the bounds check - // and actually we don't need any loop. That's a special case of the TODO - // above. make_loop(true)(); } } else { - ksl.For(absl::StrCat("loop", dim), thread_id_info.start_offsets[dim], - tile_dimensions[dim], thread_id_info.strides[dim], - [&](llvm::Value* i) { + // All dimensions are strided (thread 0 processes elements 0, num_threads, + // num_threads+2, ...; thread 1 processes elements 1, num_threads + 1 and so + // on). + ksl.For(absl::StrCat("loop", dim), /*start=*/thread_id_info.thread_ids[dim], + /*end=*/tile_dimensions[dim], + /*step=*/tiling.GetThreadsPerBlock()[dim], [&](llvm::Value* i) { tile_idx[dim] = i; recurse(); }); @@ -118,11 +113,12 @@ void EmitTileRec(const TilingThreadIdInfo& thread_id_info, } // namespace -void EmitTile(llvm::IRBuilder<>* builder, const TilingScheme& tiling_scheme, +void EmitTile(llvm::IRBuilder<>* builder, const Tiling& tiling, const TilingThreadIdInfo& thread_id_info, absl::Span tile_dimensions, const TileElementGenerator& emit_elem_function) { - EmitTileRec(thread_id_info, tiling_scheme, 0, {}, tile_dimensions, builder, + absl::InlinedVector tile_idx(tiling.GetShape().size()); + EmitTileRec(thread_id_info, tiling, 0, tile_idx, tile_dimensions, builder, emit_elem_function); } @@ -135,10 +131,12 @@ llvm::Value* EmitBlockId(llvm::IRBuilder<>* builder, int32_t num_blocks, EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, builder); if (num_blocks != 0) { llvm_ir::AddRangeMetadata(0, num_blocks, - llvm::cast(block_id)); + llvm::cast(block_id), + builder->GetInsertBlock()->getModule()); } - return builder->CreateIntCast(block_id, index_ty, /*isSigned=*/true, - "block.id.x"); + auto ret = builder->CreateIntCast(block_id, index_ty, /*isSigned=*/true); + ret->setName("block.id.x"); + return ret; } // Emits current thread id with the given type. @@ -148,168 +146,111 @@ llvm::Value* EmitThreadId(llvm::IRBuilder<>* builder, int64_t threads_per_block, llvm::Type* index_ty) { // Calculate (y, x) coordinates respectively in the 2D view of thread block, // defined by (num_thread_y, num_thread_x) from thread_id. - llvm::CallInst* thread_id_raw = + llvm::CallInst* thread_id = EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, builder); - llvm_ir::AddRangeMetadata(0, threads_per_block, thread_id_raw); - return builder->CreateIntCast(thread_id_raw, index_ty, - /*isSigned=*/true, "thread.id.x"); + llvm_ir::AddRangeMetadata(0, threads_per_block, thread_id, + builder->GetInsertBlock()->getModule()); + auto ret = builder->CreateIntCast(thread_id, index_ty, /*isSigned=*/true); + ret->setName("thread.id.x"); + return ret; } -// Emits the LLVM values for thread_id, thread_id.x, thread_id.y and lane -// id. -// -// Returns a struct containing these values. -// -// In the presence of thread scaling in tiling scheme may return early if the -// combination of thread_id/block_id does not correspond to a real block. -// Assumes the current function returns void. -absl::StatusOr EmitThreadIdInfo( - llvm::IRBuilder<>* builder, const TilingScheme& tiling_scheme, - llvm::Type* index_ty) { +// Emits the LLVM values for thread_id, block_id, coordinates of the current +// tile and strides of the loops to iterate over the current tile. +absl::StatusOr EmitThreadIdInfo(llvm::IRBuilder<>* builder, + const Tiling& tiling, + llvm::Type* index_ty) { auto constant = [&](uint64_t c) -> llvm::Constant* { return llvm::ConstantInt::get(index_ty, c); }; - llvm::Value* thread_id_physical = EmitThreadId( - builder, tiling_scheme.GetNumThreadsPerBlockPhysical(), index_ty); - int64_t num_blocks = tiling_scheme.GetNumBlocksPhysical(); + int64_t num_blocks = tiling.GetNumBlocks(); if (num_blocks > (int64_t)std::numeric_limits::max()) { return FailedPrecondition( "Number of physical blocks (%d) does not fit in an i32 in tiling " "scheme: %s", - num_blocks, tiling_scheme.ToString()); + num_blocks, tiling.ToString()); } - llvm::Value* block_id_physical = EmitBlockId(builder, num_blocks, index_ty); - // More than one thread in the z axis is currently not supported by the - // index computation. Since the indexing is a bit complicated (with respect to - // strides and starts and "virtual scaling"), there's no obvious way to extend - // it right now. - CHECK_EQ(tiling_scheme.GetThreadsPerBlock()[TilingScheme::DimZ], 1); + TilingThreadIdInfo info; + info.thread_id = + EmitThreadId(builder, tiling.GetNumThreadsPerBlock(), index_ty); + info.block_id = EmitBlockId(builder, num_blocks, index_ty); - llvm::Value* thread_id_logical = builder->CreateURem( - thread_id_physical, constant(tiling_scheme.GetNumThreadsPerBlock())); - llvm::Value* scaling = builder->CreateUDiv( - thread_id_physical, constant(tiling_scheme.GetNumThreadsPerBlock())); - llvm::Value* block_id_logical = builder->CreateAdd( - builder->CreateMul(block_id_physical, - constant(tiling_scheme.GetThreadIdScalingFactor())), - scaling); - - llvm::Value* num_threads_x_v = - constant(tiling_scheme.GetThreadsPerBlock()[TilingScheme::DimX]); - - llvm::Value* block_exists = builder->CreateICmpULT( - block_id_logical, constant(tiling_scheme.GetNumBlocks())); - llvm_ir::EmitEarlyReturn(block_exists, builder); - - std::array thread_ids{ - constant(0), // See above, there must be 1 thread in the z axis. - builder->CreateUDiv(thread_id_logical, num_threads_x_v, "thread_id.y"), - builder->CreateURem(thread_id_logical, num_threads_x_v, "thread_id.x")}; - std::array start_offsets{ - constant(0), thread_ids[TilingScheme::DimY], - builder->CreateMul(thread_ids[TilingScheme::DimX], - constant(tiling_scheme.GetVectorSize()))}; - std::array strides{ - constant(1), - constant(tiling_scheme.GetThreadsPerBlock()[TilingScheme::DimY]), - constant(1) // Not really, see EmitTileRec. - }; + for (auto [dim, stride] : llvm::enumerate(tiling.GetThreadStrides())) { + int64_t size = tiling.GetThreadsPerBlock()[dim]; + if (size == 1) { + info.thread_ids.emplace_back(constant(0)); + } else { + auto& dim_id = info.thread_ids.emplace_back(info.thread_id); + if (stride > 1) { + dim_id = builder->CreateUDiv(dim_id, constant(stride)); + } + if (dim) { + dim_id = builder->CreateURem(dim_id, constant(size)); + } + dim_id->setName(absl::StrCat("thread.id.", dim)); + } + } - auto* lane_id = - builder->CreateURem(thread_id_logical, constant(WarpSize()), "lane_id"); - return TilingThreadIdInfo{ - thread_id_logical, thread_ids, start_offsets, strides, - lane_id, block_id_logical, scaling}; + info.lane_id = + builder->CreateURem(info.thread_id, constant(WarpSize()), "lane_id"); + return info; } } // namespace absl::StatusOr EmitTilingKernel( - llvm::IRBuilder<>* builder, const TilingScheme& tiling_scheme, - llvm::Type* index_ty, const TileGenerator& tile_generator) { - absl::Span dims_in_elems = tiling_scheme.GetShape(); - Vector3 dims_in_blocks = tiling_scheme.GetBlockCounts(); + llvm::IRBuilder<>* builder, const Tiling& tiling, llvm::Type* index_ty, + const TileGenerator& tile_generator) { + absl::Span dims_in_elems = tiling.GetShape(); + const auto& block_counts = tiling.GetBlockCounts(); auto constant = [&](uint64_t c) -> llvm::Constant* { return llvm::ConstantInt::get(index_ty, c); }; TF_ASSIGN_OR_RETURN(TilingThreadIdInfo thread_id_info, - EmitThreadIdInfo(builder, tiling_scheme, index_ty)); + EmitThreadIdInfo(builder, tiling, index_ty)); KernelSupportLibrary ksl(builder, llvm_ir::UnrollMode::kDefaultUnroll); const llvm_ir::IrArray::Index block_coords( thread_id_info.block_id, - ShapeUtil::MakeShape(PRED /*arbitrary*/, dims_in_blocks), builder); + ShapeUtil::MakeShape(PRED /*arbitrary*/, block_counts), builder); - std::array tile_dimensions; - for (int i = 0; i < 3; ++i) { - int64_t block_tile_size = tiling_scheme.GetBlockTileSize()[i]; + absl::InlinedVector tile_dimensions; + for (int i = 0; i < block_counts.size(); ++i) { + int64_t block_tile_size = tiling.GetBlockTileSize()[i]; if (dims_in_elems[i] % block_tile_size == 0) { // The block tile size evenly divides the tiled shape -> no need to emit // the bounds check. - tile_dimensions[i] = constant(block_tile_size); + tile_dimensions.push_back(constant(block_tile_size)); } else { // Only the last tile in each dimension may not have full size. - llvm::Value* is_last = builder->CreateICmpEQ( - block_coords[i], constant(dims_in_blocks[i] - 1)); + llvm::Value* is_last = + builder->CreateICmpEQ(block_coords[i], constant(block_counts[i] - 1)); int64_t partial_row = - dims_in_elems[i] - (dims_in_blocks[i] - 1) * block_tile_size; - tile_dimensions[i] = - builder->CreateSelect(is_last, constant(partial_row), - constant(block_tile_size), "tile_bound"); + dims_in_elems[i] - (block_counts[i] - 1) * block_tile_size; + tile_dimensions.push_back(builder->CreateSelect( + is_last, constant(partial_row), constant(block_tile_size), + absl::StrCat("tile_bound.", i))); } } - llvm_ir::IrArray::Index tile_origin = [&] { + llvm_ir::IrArray::Index tile_offset = [&] { std::vector elem_multi_index = block_coords.multidim(); llvm::Type* index_ty = block_coords.GetType(); - for (int i = 0; i < TilingScheme::DimTot; ++i) { + for (int i = 0; i < block_counts.size(); ++i) { elem_multi_index[i] = builder->CreateMul( block_coords[i], - llvm::ConstantInt::get(index_ty, tiling_scheme.GetBlockTileSize()[i]), - "tile_origin." + std::to_string(i)); + llvm::ConstantInt::get(index_ty, tiling.GetBlockTileSize()[i]), + absl::StrCat("tile_origin.", i)); } - return llvm_ir::IrArray::Index(elem_multi_index, tiling_scheme.GetShape(), + return llvm_ir::IrArray::Index(elem_multi_index, tiling.GetShape(), index_ty); }(); - tile_generator(thread_id_info, tile_origin, tile_dimensions); - return {{tile_dimensions, tile_origin, thread_id_info}}; -} - -llvm_ir::IrArray::Index GetUnnormalizedIndex( - const llvm_ir::IrArray::Index& normalized_shape_index, - const Shape& unnormalized_shape, llvm::IRBuilder<>* builder, - absl::Span dims_in_elems) { - CHECK_EQ(normalized_shape_index.size(), 3); - // If the normalization only add a new dimensions of size 1, - // generate simpler indexing. LLVM doesn't always simplify the more - // complicated indexing and this prevents it from vectorizing some - // cases. We do this only for major_to_minor memory layout. - if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() && - unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[1] && - unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[2] && - unnormalized_shape.layout().minor_to_major(1) == 0) { - CHECK_EQ(normalized_shape_index.dims()[0], 1); - auto multidim = normalized_shape_index.multidim(); - return llvm_ir::IrArray::Index({multidim[1], multidim[2]}, - unnormalized_shape, - normalized_shape_index.GetType()); - } - if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() && - unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[2] && - unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[1] && - unnormalized_shape.layout().minor_to_major(1) == 1) { - CHECK_EQ(normalized_shape_index.dims()[0], 1); - auto multidim = normalized_shape_index.multidim(); - return llvm_ir::IrArray::Index({multidim[2], multidim[1]}, - unnormalized_shape, - normalized_shape_index.GetType()); - } - return normalized_shape_index.SourceIndexOfBitcast( - ShapeUtil::MakeShape(F32, dims_in_elems), unnormalized_shape, builder); + tile_generator(thread_id_info, tile_offset, tile_dimensions); + return {{tile_dimensions, tile_offset, thread_id_info}}; } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/fusions/tiling_util.h b/third_party/xla/xla/service/gpu/fusions/tiling_util.h index afc80c28a7160a..f06ae8ccab4280 100644 --- a/third_party/xla/xla/service/gpu/fusions/tiling_util.h +++ b/third_party/xla/xla/service/gpu/fusions/tiling_util.h @@ -22,6 +22,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/types/span.h" #include "xla/service/llvm_ir/ir_array.h" +#include "xla/shape_util.h" #include "xla/util.h" namespace xla { @@ -29,135 +30,113 @@ namespace gpu { // Describes tiling used by the kernel. // -// Used by reduction and transpose emitters. Both algorithms operate over -// "logical" 3D views over input arrays, hence tiling and number of threads -// information has only 3 dimensions. -// -// In the presence of virtual threadIdx/blockIdx scaling, all accessors are -// "logical", unless otherwise specified. -class TilingScheme { +// Used by reduction and transpose emitters. +class Tiling { public: - enum { DimZ = 0, DimY, DimX, DimTot }; - - TilingScheme(Vector3 dims_in_elems, Vector3 tile_sizes, Vector3 num_threads, - int vector_size, int scaling_factor) - : dims_in_elems_(dims_in_elems), + Tiling(absl::InlinedVector shape, + absl::InlinedVector tile_sizes, + absl::InlinedVector num_threads, + // By default, don't unroll anything. + absl::InlinedVector loops_to_unroll = {}) + : shape_(shape), tile_sizes_per_thread_(tile_sizes), - tile_sizes_per_block_{num_threads[0] * tile_sizes[0], - num_threads[1] * tile_sizes[1], - num_threads[2] * tile_sizes[2]}, + tile_sizes_per_block_(shape.size()), num_threads_(num_threads), - vector_size_(vector_size), - thread_id_virtual_scaling_(scaling_factor) { - CHECK_EQ(tile_sizes[2] % vector_size_, 0) - << "tile sizes = " << absl::StrJoin(tile_sizes, ", ") - << "; vector size = " << vector_size_; + num_blocks_(shape.size()), + loops_to_unroll_(loops_to_unroll) { + for (int64_t i = 0; i < shape.size(); ++i) { + tile_sizes_per_block_[i] = tile_sizes[i] * num_threads[i]; + CHECK_NE(tile_sizes_per_block_[i], 0); + num_blocks_[i] = CeilOfRatio(shape[i], tile_sizes_per_block_[i]); + CHECK_NE(num_blocks_[i], 0); + } + if (loops_to_unroll_.empty()) loops_to_unroll_.resize(shape.size()); } std::string ToString() const { return absl::StrJoin( - {absl::StrFormat("dims_in_elems = {%s}", - absl::StrJoin(dims_in_elems_, ", ")), + {absl::StrFormat("shape = {%s}", absl::StrJoin(shape_, ", ")), absl::StrFormat("tile_sizes = {%s}", absl::StrJoin(tile_sizes_per_thread_, ", ")), absl::StrFormat("num_threads = {%s}", - absl::StrJoin(num_threads_, ", ")), - absl::StrFormat("vector_size = %d", vector_size_), - absl::StrFormat("thread_id_virtual_scaling = %d", - thread_id_virtual_scaling_)}, + absl::StrJoin(num_threads_, ", "))}, ", "); } - // Number of elements in each dimension (Z/Y/X respectively). - const Vector3& GetShape() const { return dims_in_elems_; } + // Number of elements in each dimension. + const absl::InlinedVector& GetShape() const { return shape_; } + xla::Shape GetXlaShape(PrimitiveType element_type = F32) const { + return ShapeUtil::MakeShape(element_type, shape_); + } - Vector3 GetBlockCounts() const { - return {GetBlockCount(0), GetBlockCount(1), GetBlockCount(2)}; + const absl::InlinedVector& GetBlockCounts() const { + return num_blocks_; } // Tile size for each thread. // // Equals to the number of iterations in the loop each tile will make. - const Vector3& GetThreadTileSize() const { return tile_sizes_per_thread_; } + const absl::InlinedVector& GetThreadTileSize() const { + return tile_sizes_per_thread_; + } // Tile size for an entire thread block. - const Vector3& GetBlockTileSize() const { return tile_sizes_per_block_; } - - // Number of logical threads per block. - const Vector3& GetThreadsPerBlock() const { return num_threads_; } - int64_t GetNumThreadsPerBlock() const { - return num_threads_[0] * num_threads_[1] * num_threads_[2]; + const absl::InlinedVector& GetBlockTileSize() const { + return tile_sizes_per_block_; } - // Number of logical blocks. - int64_t GetNumBlocks() const { - auto counts = GetBlockCounts(); - return counts[0] * counts[1] * counts[2]; + const absl::InlinedVector& GetThreadsPerBlock() const { + return num_threads_; } - // Number of physical blocks launched (with scaling applied). - int64_t GetNumBlocksPhysical() const { - return CeilOfRatio(GetNumBlocks(), thread_id_virtual_scaling_); + // Returns the strides of the thread index dimensions wrt. the linear thread + // id. + absl::InlinedVector GetThreadStrides() const { + return *ShapeUtil::ByteStrides(ShapeUtil::MakeShape(U8, num_threads_)); } - // Number of physical threads per block launched (with scaling applied). - int64_t GetNumThreadsPerBlockPhysical() const { - return num_threads_[0] * num_threads_[1] * num_threads_[2] * - thread_id_virtual_scaling_; + // Returns the strides of the block index dimensions wrt. the linear block id. + absl::InlinedVector GetBlockStrides() const { + return *ShapeUtil::ByteStrides(ShapeUtil::MakeShape(U8, num_blocks_)); } - int GetVectorSize() const { return vector_size_; } + int64_t GetNumThreadsPerBlock() const { return Product(num_threads_); } - // Scaling factor for transforming physical threadId to logical. - int GetThreadIdScalingFactor() const { return thread_id_virtual_scaling_; } + int64_t GetNumBlocks() const { return Product(num_blocks_); } - private: - // Number of blocks required to "cover" the given dimension. - int64_t GetBlockCount(int d) const { - return CeilOfRatio(dims_in_elems_[d], tile_sizes_per_block_[d]); + const absl::InlinedVector& GetLoopsToUnroll() const { + return loops_to_unroll_; } + private: // The number of elements in each dimension. - Vector3 dims_in_elems_; + absl::InlinedVector shape_; // The number of elements for each dimension of a tile. - Vector3 tile_sizes_per_thread_; - Vector3 tile_sizes_per_block_; + absl::InlinedVector tile_sizes_per_thread_; + absl::InlinedVector tile_sizes_per_block_; - // Number of threads implicitly assigned to each dimension. - Vector3 num_threads_; + absl::InlinedVector num_threads_; + absl::InlinedVector num_blocks_; - // Vector size for dimension X. - int vector_size_; - - // Scaling apply to transform physical threadIdx into logical. - int64_t thread_id_virtual_scaling_ = 1; + absl::InlinedVector loops_to_unroll_; }; -// Contains threading information. Note that for performance we might apply -// thread id "scaling" where the physical thread id (to achieve good SM -// occupancy) will differ from logical thread id. This struct contains -// logical thread ids, along with meta-information about the scaling applied. struct TilingThreadIdInfo { llvm::Value* thread_id; - std::array thread_ids; - std::array start_offsets; - std::array strides; + absl::InlinedVector thread_ids; // Lane id: `thread_id % WarpSize` llvm::Value* lane_id; // Block id. llvm::Value* block_id; - - // The virtual scaling index: [0; thread_id_virtual_scaling). - llvm::Value* scaling_index; }; struct TilingKernelInfo { // Tiling bounds. - std::array output_tile_bounds; + absl::InlinedVector output_tile_bounds; // Starting tile, as calculated from block id only. llvm_ir::IrArray::Index tile_origin; @@ -173,39 +152,18 @@ struct TilingKernelInfo { using TileGenerator = std::function tile_dimensions)>; + absl::Span tile_dimensions)>; // A function object to generate code to process one element in a tile. // -// index_in_tile: the current [z, y, x] coordinate. +// index_in_tile: the current coordinates within the tile. To get the global +// coordinates, use `tile_start_index.AddOffset(index_in_tile, ...)`. using TileElementGenerator = - std::function index_in_tile)>; + std::function index_in_tile)>; -// Emits code to iterate through a 2-dimensional tile with a given tile -// dimensions and given strides, and call the callback at each iteration., -// -// thread_id_y` and `thread_id_x` are the intra-tile coordinates for -// the first element to process, and `index` is the index for the origin of -// the tile. Emits bounds check to ensure that each processed element -// is within the boundary defined by `tile_dimensions`. -// -// Rough pseudocode: -// -// Given: tile_dimensions, x_offset, y_offset -// -// for (y = 0; y < tile_dimensions[0]; y += num_threads_y) { -// for (x = 0; x < tile_dimensions[1]; x++) { -// -// y_pos = y_offset + y -// x_pos = x_offset + x * stride -// -// if (x_loc < tile_width) { -// emit_elem_function(y_offset + y, x_loc); -// } -// } -// } -// -void EmitTile(llvm::IRBuilder<>* builder, const TilingScheme& tiling_scheme, +// Emits code to iterate through a tile with given tile dimensions and generate +// elements using the callback. +void EmitTile(llvm::IRBuilder<>* builder, const Tiling& tiling, const TilingThreadIdInfo& thread_id_info, absl::Span tile_dimensions, const TileElementGenerator& emit_elem_function); @@ -213,13 +171,8 @@ void EmitTile(llvm::IRBuilder<>* builder, const TilingScheme& tiling_scheme, // Emits a kernel for the hlo instruction using the given kernel mapping // scheme. absl::StatusOr EmitTilingKernel( - llvm::IRBuilder<>* builder, const TilingScheme& tiling_scheme, - llvm::Type* index_ty, const TileGenerator& tile_element_generator); - -llvm_ir::IrArray::Index GetUnnormalizedIndex( - const llvm_ir::IrArray::Index& normalized_shape_index, - const Shape& unnormalized_shape, llvm::IRBuilder<>* builder, - absl::Span dims_in_elems); + llvm::IRBuilder<>* builder, const Tiling& tiling, llvm::Type* index_ty, + const TileGenerator& tile_element_generator); } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/transpose.cc b/third_party/xla/xla/service/gpu/fusions/transpose.cc index 3b0a912cd4d19b..80c0b93915cebc 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose.cc @@ -22,18 +22,18 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" #include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/Support/AtomicOrdering.h" +#include "mlir/IR/AffineMap.h" // from @llvm-project #include "xla/hlo/ir/hlo_instructions.h" #include "xla/permutation_util.h" #include "xla/service/gpu/elemental_ir_emitter.h" @@ -42,6 +42,8 @@ limitations under the License. #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_map.h" #include "xla/service/gpu/target_util.h" #include "xla/service/llvm_ir/fused_ir_emitter.h" #include "xla/service/llvm_ir/ir_array.h" @@ -53,8 +55,7 @@ namespace xla { namespace gpu { namespace { -TilingScheme ComputeTransposeTilingScheme( - const TransposeDescription& tiled_transpose) { +Tiling ComputeTransposeTiling(const TransposeDescription& tiled_transpose) { constexpr int kNumRows = 4; static_assert(WarpSize() % kNumRows == 0); @@ -66,52 +67,17 @@ TilingScheme ComputeTransposeTilingScheme( // always use the permutation, even when we want the inverse. CHECK((permutation == Vector3{0, 2, 1}) || (permutation == Vector3{2, 1, 0})); - Vector3 input_dims{transposed_dims[permutation[0]], - transposed_dims[permutation[1]], - transposed_dims[permutation[2]]}; - // The tiling corresponds to the two minor dimensions before and after the - // transpose. The remaining dimension is the batch dimension. - // The order is {batch, minor post-transpose, minor pre-transpose}. - // - // Examples for transposed_dims {200, 300, 700}: - // order {0, 2, 1} {2, 1, 0} - // input_dims {200, 700, 300} {700, 300, 200} - // tiled_shape {200, 700, 300} {300, 700, 200} - // tile -> input {0, 1, 2} {1, 0, 2} - Vector3 tiled_shape{input_dims[1 - permutation[2]], transposed_dims[2], - input_dims[2]}; - - Vector3 tile_sizes{1, WarpSize() / kNumRows, 1}; - Vector3 num_threads{1, kNumRows, WarpSize()}; - - return TilingScheme( - /*dims_in_elems=*/tiled_shape, - /*tile_sizes=*/tile_sizes, - /*num_threads=*/num_threads, - /*vector_size=*/1, - /*scaling_factor=*/1); -} + absl::InlinedVector input_dims{transposed_dims[permutation[0]], + transposed_dims[permutation[1]], + transposed_dims[permutation[2]]}; -Vector3 TileToInoutPermutation(Vector3 permutation) { - // See ComputeTransposeTilingScheme. - // Note: this is also the tile to output permutation because we swap the - // last two components. - return permutation[2] == 1 ? Vector3{0, 1, 2} : Vector3{1, 0, 2}; -} + // We tile along the minor dimensions pre- and post-transpose. + absl::InlinedVector tile_sizes{1, 1, 1}; + tile_sizes[permutation[2]] = WarpSize() / kNumRows; + absl::InlinedVector num_threads{1, 1, WarpSize()}; + num_threads[permutation[2]] = kNumRows; -llvm::GlobalVariable* AllocateShared( - llvm::IRBuilder<>* builder, const TilingScheme& tiling_scheme, - llvm::Type* element_type, - absl::Span dimensions_major_to_minor, - absl::string_view buffer_name) { - CHECK(!dimensions_major_to_minor.empty()); - llvm::Type* ty = element_type; - for (auto dim : llvm::reverse(dimensions_major_to_minor)) { - ty = llvm::ArrayType::get(ty, dim); - } - ty = llvm::ArrayType::get(ty, tiling_scheme.GetThreadIdScalingFactor()); - return llvm_ir::AllocateSharedMemoryTile( - builder->GetInsertBlock()->getModule(), ty, buffer_name); + return Tiling(input_dims, tile_sizes, num_threads); } void MaybeEmitFenceForAMDGPU(llvm::IRBuilder<>* builder, @@ -142,7 +108,14 @@ llvm_ir::IrArray::Index PermuteIndex(const llvm_ir::IrArray::Index& index, TransposeFusion::TransposeFusion(const HloFusionAnalysis& analysis) : analysis_(analysis), - tiling_scheme_(ComputeTransposeTilingScheme(analysis.tiled_transpose())) { + tiling_(ComputeTransposeTiling(analysis.tiled_transpose())) { + for (auto [root, hero] : + llvm::zip(analysis_.fusion_roots(), analysis_.fusion_heroes())) { + if (auto transpose = GetDescriptionForTiledTransposeEmitter(*root, *hero)) { + permutation_ = transpose->permutation; + break; + } + } } absl::Status TransposeFusion::EmitKernel(IrEmitterContext& ir_emitter_context, @@ -174,7 +147,7 @@ absl::Status TransposeFusion::EmitKernel(IrEmitterContext& ir_emitter_context, std::vector> extra_outputs; for (const auto& [output_idx, root] : llvm::enumerate(hlo_roots)) { - const auto& hero = FindNonTrivialHero(*root); + const auto& hero = *analysis_.fusion_heroes()[output_idx]; auto transpose_descr = GetDescriptionForTiledTransposeEmitter(*root, hero); if (transpose_descr.has_value()) { auto iterator_inserted = transposes_to_roots.insert(std::make_pair( @@ -194,7 +167,7 @@ absl::Status TransposeFusion::EmitKernel(IrEmitterContext& ir_emitter_context, Vector3 permutation; for (const auto& [tile_idx, tr] : llvm::enumerate(transposes)) { permutation = tr.permutation; - auto tile_size = tiling_scheme_.GetBlockTileSize(); + auto tile_size = tiling_.GetBlockTileSize(); ++tile_size.back(); // Prevent bank conflicts. auto* module = ir_emitter_context.llvm_module(); tiles[tr.instr] = llvm_ir::AllocateSharedMemoryTile( @@ -204,62 +177,56 @@ absl::Status TransposeFusion::EmitKernel(IrEmitterContext& ir_emitter_context, tile_size, absl::StrCat("tr_tile_", tile_idx)); } - auto tile_to_inout = TileToInoutPermutation(permutation); - auto input_shape = Permute(tiling_scheme_.GetShape(), tile_to_inout); auto tile_generator = [&](const TilingThreadIdInfo& thread_id_info, const llvm_ir::IrArray::Index& tile_start_index, - std::array tile_dimensions) { + absl::Span tile_dimensions) { // Copy input parameter values to shared memory buffers: // tile[thread_id_y, thread_id_x] = input[index] - EmitTile( - builder, tiling_scheme_, thread_id_info, tile_dimensions, - [&](std::array index_in_tile) { - auto index = - PermuteIndex(tile_start_index.AddOffset(index_in_tile, builder), - tile_to_inout); - for (const auto& tr : transposes) { - auto input_gen = *fused_emitter.GetGenerator(*tr.instr->operand(0)); - auto input_index = GetUnnormalizedIndex( - index, tr.instr->operand(0)->shape(), builder, input_shape); - llvm::Value* value = *input_gen(input_index); - tiles[tr.instr].Store(value, index_in_tile, builder); - } - - // Compute all extra output values before writing them. This - // avoids overwriting aliased input/output values before all reads - // occurred. - std::vector> - scheduled_writes; - for (const auto& [output_idx, root] : extra_outputs) { - llvm_ir::IrArray::Index extra_output_index = GetUnnormalizedIndex( - index, root->shape(), builder, input_shape); - auto output_gen = *fused_emitter.GetGenerator(*root); - llvm::Value* output_value = *output_gen(extra_output_index); - scheduled_writes.emplace_back(outputs[output_idx], - extra_output_index, output_value); - } - - for (const auto& [output, idx, value] : scheduled_writes) { - output.EmitWriteArrayElement(idx, value, builder); - } - }); + EmitTile(builder, tiling_, thread_id_info, tile_dimensions, + [&](absl::Span index_in_tile) { + auto index = tile_start_index.AddOffset(index_in_tile, builder); + for (const auto& tr : transposes) { + auto input_gen = + *fused_emitter.GetGenerator(*tr.instr->operand(0)); + auto input_index = index.SourceIndexOfBitcast( + tr.instr->operand(0)->shape(), builder); + llvm::Value* value = *input_gen(input_index); + tiles[tr.instr].Store(value, index_in_tile, builder); + } + + // Compute all extra output values before writing them. This + // avoids overwriting aliased input/output values before all + // reads occurred. + std::vector> + scheduled_writes; + for (const auto& [output_idx, root] : extra_outputs) { + auto extra_output_index = + index.SourceIndexOfBitcast(root->shape(), builder); + auto output_gen = *fused_emitter.GetGenerator(*root); + llvm::Value* output_value = *output_gen(extra_output_index); + scheduled_writes.emplace_back( + outputs[output_idx], extra_output_index, output_value); + } + + for (const auto& [output, idx, value] : scheduled_writes) { + output.EmitWriteArrayElement(idx, value, builder); + } + }); EmitSyncThreads(builder, ir_emitter_context); - auto output_tile_index = PermuteIndex(tile_start_index, {0, 2, 1}); - auto transposed_tile_dimensions = Permute(tile_dimensions, {0, 2, 1}); + auto output_tile_index = PermuteIndex(tile_start_index, permutation); + auto transposed_tile_dimensions = Permute(tile_dimensions, permutation); EmitTile( - builder, tiling_scheme_, thread_id_info, transposed_tile_dimensions, + builder, tiling_, thread_id_info, transposed_tile_dimensions, /*emit_elem_function=*/ - [&](std::array index_in_tile) { - auto index = - PermuteIndex(output_tile_index.AddOffset(index_in_tile, builder), - tile_to_inout); + [&](absl::Span index_in_tile) { + auto index = output_tile_index.AddOffset(index_in_tile, builder); for (const auto& tr : transposes) { llvm::Value* loaded = tiles[tr.instr].Load( - Permute(index_in_tile, {0, 2, 1}), builder); + Permute(index_in_tile, permutation), builder); FusedIrEmitter fused_emitter(elemental_emitter); fused_emitter.BindGenerator( @@ -291,9 +258,8 @@ absl::Status TransposeFusion::EmitKernel(IrEmitterContext& ir_emitter_context, // Both for emission and writing it should be // index-as-transformed by the computation. - llvm_ir::IrArray::Index untiled_index = - GetUnnormalizedIndex(index, root->shape(), builder, - Permute(input_shape, permutation)); + auto untiled_index = + index.SourceIndexOfBitcast(root->shape(), builder); TF_ASSIGN_OR_RETURN(llvm::Value * generated, gen(untiled_index)); scheduled_writes.emplace_back(outputs[output_idx], untiled_index, generated); @@ -308,13 +274,45 @@ absl::Status TransposeFusion::EmitKernel(IrEmitterContext& ir_emitter_context, llvm::Type* index_type = GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder); - return EmitTilingKernel(builder, tiling_scheme_, index_type, tile_generator) + return EmitTilingKernel(builder, tiling_, index_type, tile_generator) .status(); } LaunchDimensions TransposeFusion::launch_dimensions() const { - return LaunchDimensions(tiling_scheme_.GetNumBlocksPhysical(), - tiling_scheme_.GetNumThreadsPerBlockPhysical()); + return LaunchDimensions(tiling_.GetNumBlocks(), + tiling_.GetNumThreadsPerBlock()); +} + +std::optional TransposeFusion::ComputeThreadIdToOutputIndexing( + int64_t root_index, mlir::MLIRContext* ctx) const { + const auto& hero = *analysis_.fusion_heroes()[root_index]; + const auto& root = *analysis_.fusion_roots()[root_index]; + if (!GetDescriptionForTiledTransposeEmitter(root, hero)) { + // Non-transpose roots are elementwise by definition. + return ComputeThreadIdToInputIndexing(root_index, 0, ctx); + } + + // The block offsets are permuted, but the thread offsets remain the same. + auto block_offset = GetBlockOffsetsForTiling(tiling_, ctx) + .getSubMap(std::vector{permutation_.begin(), + permutation_.end()}); + auto thread_offset = GetThreadOffsetsForTiling(tiling_, ctx); + auto permuted_tiled_shape = + ShapeUtil::MakeShape(U8, Permute(tiling_.GetShape(), permutation_)); + + return ComposeIndexingMaps( + GetIndexingMapForTiling(block_offset, thread_offset, tiling_), + GetBitcastMap(permuted_tiled_shape, hero.shape(), ctx)); +} + +std::optional TransposeFusion::ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const { + const auto& hero = *analysis_.fusion_heroes()[root_index]; + + return ComposeIndexingMaps( + GetIndexingMapForTiling(tiling_, ctx), + GetBitcastMap(tiling_.GetXlaShape(), hero.operand(0)->shape(), ctx)); } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/fusions/transpose.h b/third_party/xla/xla/service/gpu/fusions/transpose.h index d973cc0f48af7d..899b1cb94390ae 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose.h +++ b/third_party/xla/xla/service/gpu/fusions/transpose.h @@ -15,17 +15,22 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_FUSIONS_TRANSPOSE_H_ #define XLA_SERVICE_GPU_FUSIONS_TRANSPOSE_H_ +#include #include #include #include "absl/status/status.h" #include "llvm/IR/IRBuilder.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/fusions/tiling_util.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/llvm_ir/ir_array.h" +#include "xla/util.h" namespace xla { namespace gpu { @@ -59,10 +64,11 @@ class TransposeFusion : public KernelFusionEmitterBase { LaunchDimensions launch_dimensions() const override; std::optional ComputeThreadIdToOutputIndexing( - int64_t output_id, mlir::MLIRContext* ctx) const override { - // TODO(b/319081342): Implement this. - return std::nullopt; - } + int64_t root_index, mlir::MLIRContext* ctx) const override; + + std::optional ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const override; protected: absl::Status EmitKernel(IrEmitterContext& ir_emitter_context, @@ -74,7 +80,8 @@ class TransposeFusion : public KernelFusionEmitterBase { private: const HloFusionAnalysis& analysis_; - TilingScheme tiling_scheme_; + Tiling tiling_; + Vector3 permutation_; }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_test.cc b/third_party/xla/xla/service/gpu/fusions/transpose_test.cc new file mode 100644 index 00000000000000..55f6f420ebf0be --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/transpose_test.cc @@ -0,0 +1,197 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/transpose.h" + +#include +#include + +#include +#include +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/service/gpu/fusions/fusions.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/model/indexing_test_utils.h" +#include "xla/status_macros.h" +#include "xla/statusor.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +using ::testing::HasSubstr; + +class TransposeTest : public HloTestBase { + protected: + stream_executor::DeviceDescription device_info_ = + TestGpuDeviceInfo::RTXA6000DeviceInfo(); +}; + +StatusOr> GetTransposeFusion( + const HloFusionAnalysis& analysis) { + TF_ASSIGN_OR_RETURN( + auto emitter, GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis})); + auto fusion = dynamic_cast(emitter.get()); + TF_RET_CHECK(fusion != nullptr); + + emitter.release(); + return std::unique_ptr{fusion}; +} + +TEST_F(TransposeTest, ThreadIndexing021) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule module + + fusion { + %input = f32[100,32,64] parameter(0) + ROOT transpose = f32[100,64,32] transpose(%input), dimensions={0,2,1} + } + + ENTRY entry { + %input = f32[100,32,64] parameter(0) + ROOT %fusion = f32[100,64,32] fusion(%input), kind=kInput, calls=fusion + })") + .value(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + + TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis)); + mlir::MLIRContext mlir_context; + + EXPECT_THAT( + fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + d3 floordiv 2, + d0 floordiv 32 + s1 * 4, + (d3 mod 2) * 32 + d0 mod 32 + ) + domain: + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 199] + d4 in [0, 0] + d5 in [0, 0] + + s0 in [0, 0] + s1 in [0, 7] + s2 in [0, 0] + + (d3 mod 2) * 32 + d0 mod 32 in [0, 63] + d0 floordiv 32 + s1 * 4 in [0, 31] + )")); + EXPECT_THAT( + fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + d3 floordiv 2, + (d3 mod 2) * 32 + d0 floordiv 32 + s1 * 4, + d0 mod 32 + ) + domain: + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 199] + d4 in [0, 0] + d5 in [0, 0] + + s0 in [0, 0] + s1 in [0, 7] + s2 in [0, 0] + + (d3 mod 2) * 32 + d0 floordiv 32 + s1 * 4 in [0, 63] + d0 mod 32 in [0, 31] + )")); +} + +TEST_F(TransposeTest, ThreadIndexing201) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule module + + fusion { + %input = f32[100,64,32] parameter(0) + ROOT transpose = f32[32,100,64] transpose(%input), dimensions={2,0,1} + } + + ENTRY entry { + %input = f32[100,64,32] parameter(0) + ROOT %fusion = f32[32,100,64] fusion(%input), kind=kInput, calls=fusion + })") + .value(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + + TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis)); + mlir::MLIRContext mlir_context; + EXPECT_THAT( + fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + (d3 * 32 + d0 floordiv 32 + s1 * 4) floordiv 64, + (d3 * 32 + d0 floordiv 32 + s1 * 4) mod 64, + d0 mod 32 + ) + domain: + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 199] + d4 in [0, 0] + d5 in [0, 0] + + s0 in [0, 0] + s1 in [0, 7] + s2 in [0, 0] + + 0 in [0, 0] + d0 mod 32 in [0, 31] + d3 * 32 + d0 floordiv 32 + s1 * 4 in [0, 6399] + )")); + EXPECT_THAT( + fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + d0 floordiv 32 + s1 * 4, + (d3 * 32 + d0 mod 32) floordiv 64, + (d3 * 32 + d0 mod 32) mod 64 + ) + domain: + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 199] + d4 in [0, 0] + d5 in [0, 0] + + s0 in [0, 0] + s1 in [0, 7] + s2 in [0, 0] + + 0 in [0, 0] + d0 floordiv 32 + s1 * 4 in [0, 31] + d3 * 32 + d0 mod 32 in [0, 6399] + )")); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/triton.cc b/third_party/xla/xla/service/gpu/fusions/triton.cc index b4eee54199c0f3..c79b03c2656cc3 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "xla/service/gpu/fusions/triton.h" +#include #include #include #include @@ -35,7 +36,7 @@ limitations under the License. #include "xla/service/gpu/kernel_reuse_cache.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/matmul_utils.h" -#include "xla/service/gpu/runtime3/kernel_thunk.h" +#include "xla/service/gpu/runtime/kernel_thunk.h" #include "xla/service/gpu/triton_fusion_analysis.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/statusor.h" @@ -95,26 +96,15 @@ LaunchDimensions CalculateSoftMaxLaunchDimensions( } // namespace absl::StatusOr TritonFusion::Emit( - IrEmitterContext& ir_emitter_context, mlir::lmhlo::FusionOp fusion_op, + IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion) const { llvm::IRBuilder builder(ir_emitter_context.llvm_module()->getContext()); #if GOOGLE_CUDA - if (!ir_emitter_context.emit_ir_from_hlo()) { - CHECK_NE(fusion_op, nullptr); - } - if (ir_emitter_context.emit_ir_from_hlo()) { - VLOG(3) << fusion.ToString(); - } else { - VLOG(3) << llvm_ir::DumpToString(fusion_op); - } + VLOG(3) << fusion.ToString(); std::string suggested_kernel_name = std::string(fusion.name()); - TF_ASSIGN_OR_RETURN(auto kernel_arguments, - ir_emitter_context.emit_ir_from_hlo() - ? KernelArguments::Create( - ir_emitter_context.buffer_assignment(), &fusion) - : KernelArguments::Create( - ir_emitter_context.allocations(), - mlir::cast(fusion_op))); + TF_ASSIGN_OR_RETURN( + auto kernel_arguments, + KernelArguments::Create(ir_emitter_context.buffer_assignment(), &fusion)); const HloComputation* hlo_computation = fusion.fused_instructions_computation(); @@ -135,12 +125,13 @@ absl::StatusOr TritonFusion::Emit( if (fusion_kind == kTritonSoftmaxFusionKind) { launch_dimensions = *this->launch_dimensions(); - auto& triton_config = *backend_config.mutable_triton_gemm_config(); - triton_config.set_num_stages(1); + // This is a hack, we use TritonGemmConfig for Softmax too, but we ignore + // most parameters. + TritonGemmConfig config; + config.num_stages = 1; // Thread count per block is always a multiple of WarpSize. - triton_config.set_num_warps(launch_dimensions.num_threads_per_block() / - WarpSize()); - TritonGemmConfig config = TritonGemmConfig::FromProto(triton_config); + config.num_warps = launch_dimensions.num_threads_per_block() / WarpSize(); + config.num_ctas = 1; TF_ASSIGN_OR_RETURN(auto analysis, TritonFusionAnalysis::Execute(*hlo_computation)); @@ -155,13 +146,8 @@ absl::StatusOr TritonFusion::Emit( } else { // Must be a MatMul CHECK_EQ(fusion_kind, kTritonGemmFusionKind); if (!backend_config.has_triton_gemm_config()) { - if (ir_emitter_context.emit_ir_from_hlo()) { - LOG(WARNING) << "Using fallback triton GEMM config for op " - << fusion.name(); - } else { - LOG(WARNING) << "Using fallback triton GEMM config for op " - << GetIrNameFromLoc(fusion_op->getLoc()); - } + LOG(WARNING) << "Using fallback triton GEMM config for op " + << fusion.name(); auto& triton_config = *backend_config.mutable_triton_gemm_config(); triton_config.set_block_m(64); triton_config.set_block_k(64); @@ -170,8 +156,9 @@ absl::StatusOr TritonFusion::Emit( triton_config.set_num_stages(1); triton_config.set_num_warps(2); } - TritonGemmConfig config = - TritonGemmConfig::FromProto(backend_config.triton_gemm_config()); + TF_ASSIGN_OR_RETURN( + TritonGemmConfig config, + TritonGemmConfig::FromProto(backend_config.triton_gemm_config())); TF_ASSIGN_OR_RETURN(auto analysis, TritonFusionAnalysis::Execute( *hlo_computation, config.split_k)); @@ -209,6 +196,7 @@ absl::StatusOr TritonFusion::Emit( impl_fn->eraseFromParent(); return {{kernel->getName().str(), launch_dimensions, + triton_wrapper_result.cluster_dim, triton_wrapper_result.shmem_bytes}}; }; @@ -218,17 +206,10 @@ absl::StatusOr TritonFusion::Emit( /*discriminator=*/"", generate); TF_ASSIGN_OR_RETURN(const KernelReuseCache::Entry* entry, status_or_entry); - std::variant fusion_op_or_hlo; - if (ir_emitter_context.emit_ir_from_hlo()) { - fusion_op_or_hlo = &fusion; - } else { - fusion_op_or_hlo = fusion_op; - } - FusionEmissionResult result; result.thunks.emplace_back(std::make_unique( - fusion_op_or_hlo, entry->kernel_name, kernel_arguments.args(), - entry->launch_dimensions, entry->shmem_bytes)); + &fusion, entry->kernel_name, kernel_arguments.args(), + entry->launch_dimensions, entry->cluster_dim, entry->shmem_bytes)); return result; #else diff --git a/third_party/xla/xla/service/gpu/fusions/triton.h b/third_party/xla/xla/service/gpu/fusions/triton.h index f774f16ec84454..938c23ccf8f210 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton.h +++ b/third_party/xla/xla/service/gpu/fusions/triton.h @@ -33,7 +33,7 @@ class TritonFusion : public FusionInterface { : analysis_(analysis) {} absl::StatusOr Emit( - IrEmitterContext& ir_emitter_context, mlir::lmhlo::FusionOp fusion_op, + IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion) const final; // Returns the launch dimensions for softmax fusions. Not supported for diff --git a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc b/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc index 2cbeb77ff43832..d3a10d3ef69665 100644 --- a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc +++ b/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc @@ -90,8 +90,7 @@ absl::StatusOr GetBestAlgorithm( output_buffer); } - TF_ASSIGN_OR_RETURN(se::blas::ProfileResult profile_result, - run_benchmark(algorithm)); + TF_ASSIGN_OR_RETURN(auto profile_result, run_benchmark(algorithm)); results.emplace_back(); AutotuneResult& result = results.back(); @@ -125,8 +124,8 @@ absl::StatusOr GetBestAlgorithm( } if (!reference_algorithm) { - stream->ThenMemcpy(&reference_buffer, output_buffer, - output_buffer.size()); + TF_RETURN_IF_ERROR(stream->Memcpy(&reference_buffer, output_buffer, + output_buffer.size())); reference_algorithm = profile_result.algorithm(); } else { // Perform the comparison. @@ -225,12 +224,12 @@ absl::StatusOr DoGemmAutotuneNoCache( TF_ASSIGN_OR_RETURN(se::Stream* const stream, autotune_config.GetStream()); GpuBackendConfig gpu_config = gemm->backend_config().value(); - const GemmBackendConfig& gemm_config = gpu_config.gemm_backend_config(); + const GemmBackendConfig& backend_config = gpu_config.gemm_backend_config(); const DebugOptions& debug_options = gemm->GetModule()->config().debug_options(); const bool deterministic_ops = debug_options.xla_gpu_deterministic_ops(); - TF_ASSIGN_OR_RETURN(GemmConfig config, GemmConfig::For(gemm)); + TF_ASSIGN_OR_RETURN(GemmConfig gemm_config, GemmConfig::For(gemm)); // Don't run autotuning concurrently on the same GPU. absl::MutexLock gpu_lock(&GetGpuMutex(stream->parent())); @@ -276,18 +275,18 @@ absl::StatusOr DoGemmAutotuneNoCache( HloModuleConfig& hlo_module_config = gemm->GetModule()->mutable_config(); AutotuneResult best_algorithm; if (IsCublasLtMatmul(*gemm)) { - bool has_matrix_bias = config.beta != 0.; + bool has_matrix_bias = gemm_config.beta != 0.; TF_ASSIGN_OR_RETURN( bool has_vector_bias, - gpublas_lt::EpilogueAddsVectorBias(gemm_config.epilogue())); + gpublas_lt::EpilogueAddsVectorBias(backend_config.epilogue())); TF_ASSIGN_OR_RETURN( bool has_aux_output, - gpublas_lt::EpilogueHasAuxiliaryOutput(gemm_config.epilogue())); + gpublas_lt::EpilogueHasAuxiliaryOutput(backend_config.epilogue())); TF_ASSIGN_OR_RETURN(auto epilogue, - AsBlasLtEpilogue(gemm_config.epilogue())); + AsBlasLtEpilogue(backend_config.epilogue())); se::DeviceMemoryBase bias_buffer; if (has_vector_bias) { @@ -309,7 +308,7 @@ absl::StatusOr DoGemmAutotuneNoCache( } TF_ASSIGN_OR_RETURN(auto plan, - BlasLt::GetMatmulPlan(stream, config, epilogue)); + BlasLt::GetMatmulPlan(stream, gemm_config, epilogue)); TF_ASSIGN_OR_RETURN(auto algorithms, plan->GetAlgorithms()); @@ -318,7 +317,7 @@ absl::StatusOr DoGemmAutotuneNoCache( GetBestAlgorithm( stream, buffer_allocator, gemm->ToString(), autotune_config, lhs_buffer, rhs_buffer, output_buffer, algorithms, output_shape, - hlo_module_config, gemm_config.beta(), + hlo_module_config, backend_config.beta(), [&](const BlasLt::MatmulAlgorithm& algorithm) -> absl::StatusOr { se::OwningScratchAllocator<> scratch_allocator( @@ -333,11 +332,21 @@ absl::StatusOr DoGemmAutotuneNoCache( })); } else { std::vector algorithms; - TF_RET_CHECK(stream->parent()->GetBlasGemmAlgorithms(stream, &algorithms)); + TF_ASSIGN_OR_RETURN(GemmConfig::DescriptorsTuple desc, + gemm_config.GetMatrixDescriptors(lhs_buffer, rhs_buffer, + output_buffer)); -#if TENSORFLOW_USE_ROCM // Blas gemm algorithms are not yet supported + auto blas = stream->parent()->AsBlas(); + if (blas == nullptr) { + return absl::InternalError("No BLAS support for stream"); + } + blas->GetBlasGemmAlgorithms(stream, desc.lhs, desc.rhs, &desc.output, + &gemm_config.alpha, &gemm_config.beta, + &algorithms); + +#if TENSORFLOW_USE_ROCM // Blas gemm algorithms can be empty for ROCM if (algorithms.empty()) { // nothing to autotune - VLOG(1) << "Skipping autotuning for ROCm.."; + LOG(WARNING) << "No solutions found: skipping autotuning for ROCM.."; best_algorithm.mutable_gemm()->set_algorithm(se::blas::kDefaultAlgorithm); return best_algorithm; } @@ -348,7 +357,7 @@ absl::StatusOr DoGemmAutotuneNoCache( GetBestBlasAlgorithm( stream, buffer_allocator, gemm->ToString(), autotune_config, lhs_buffer, rhs_buffer, output_buffer, algorithms, output_shape, - hlo_module_config, gemm_config.beta(), + hlo_module_config, backend_config.beta(), [&](const se::blas::AlgorithmType& algorithm) -> absl::StatusOr { se::blas::ProfileResult profile_result; @@ -359,7 +368,7 @@ absl::StatusOr DoGemmAutotuneNoCache( // should always return true, and the actual // success-ness is returned in // ProfileResult::is_valid. - TF_RETURN_IF_ERROR(RunGemm(config, lhs_buffer, rhs_buffer, + TF_RETURN_IF_ERROR(RunGemm(gemm_config, lhs_buffer, rhs_buffer, output_buffer, workspace_buffer, deterministic_ops, stream, algorithm, &profile_result)); @@ -383,11 +392,11 @@ absl::StatusOr RunOnInstruction(HloInstruction* gemm, GpuBackendConfig gpu_config = gemm->backend_config().value(); - GemmBackendConfig gemm_config = gpu_config.gemm_backend_config(); + GemmBackendConfig backend_config = gpu_config.gemm_backend_config(); // Degenerate gemms replaced with memzero operation, no need to auto tune it. - if (gemm_config.alpha_real() == 0.0 && gemm_config.alpha_imag() == 0.0 && - gemm_config.beta() == 0.0) { + if (backend_config.alpha_real() == 0.0 && + backend_config.alpha_imag() == 0.0 && backend_config.beta() == 0.0) { VLOG(3) << "Skip degenerate gemm instruction auto tuning"; return false; } @@ -399,7 +408,7 @@ absl::StatusOr RunOnInstruction(HloInstruction* gemm, return DoGemmAutotuneNoCache(gemm, key, config); })); - GemmBackendConfig updated_config = gemm_config; + GemmBackendConfig updated_config = backend_config; bool update_algorithm = std::visit( VariantVisitor{[](const se::CudaComputeCapability& cc) { @@ -423,7 +432,8 @@ absl::StatusOr RunOnInstruction(HloInstruction* gemm, } *gpu_config.mutable_gemm_backend_config() = updated_config; TF_RETURN_IF_ERROR(gemm->set_backend_config(gpu_config)); - return updated_config.SerializeAsString() != gemm_config.SerializeAsString(); + return updated_config.SerializeAsString() != + backend_config.SerializeAsString(); } absl::StatusOr RunOnComputation(HloComputation* computation, diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc index 95ae7ec9e17883..fbd6ecd4c67c3b 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc +++ b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc @@ -700,6 +700,9 @@ class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { dot->parent()->AddInstruction(HloInstruction::CreateFusion( computation->root_instruction()->shape(), HloInstruction::FusionKind::kCustom, fusion_inputs, computation)); + // Copy the metadata of the `dot` to the newly created `fusion` op. This + // is convenient for handling metadata in split-k rewriting subsequently. + dot_fusion->set_metadata(dot->metadata()); dot_fusion->GetModule()->SetAndUniquifyInstrName(dot_fusion, fusion_name); TF_ASSIGN_OR_RETURN(auto gpu_config, diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc b/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc index 603076e51bedd0..e604de3327fc95 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc @@ -1080,6 +1080,23 @@ e { m::Parameter(), m::Parameter())))); } +TEST_F(GemmRewriterTritonTest, CopiesDotMetadataToFusionOp) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule m + +ENTRY e { + p0 = f16[2,18] parameter(0) + p1 = f16[256,2] parameter(1) + ROOT d = f16[18,256] dot(p0, p1), + lhs_contracting_dims={0}, rhs_contracting_dims={1}, metadata={op_name="foo"} +})") + .value(); + EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); + EXPECT_EQ( + module->entry_computation()->root_instruction()->metadata().op_name(), + "foo"); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/gpu_aot_compilation_test.cc b/third_party/xla/xla/service/gpu/gpu_aot_compilation_test.cc index f61001dcd2361f..aad47e75728e27 100644 --- a/third_party/xla/xla/service/gpu/gpu_aot_compilation_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_aot_compilation_test.cc @@ -26,8 +26,8 @@ limitations under the License. #include "xla/service/compiler.h" #include "xla/service/executable.h" #include "xla/service/platform_util.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" @@ -53,7 +53,7 @@ ENTRY main { auto name = absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value()); TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform, - se::MultiPlatformManager::PlatformWithName(name)); + se::PlatformManager::PlatformWithName(name)); TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * stream_exec, platform->ExecutorForDevice(0)); @@ -94,7 +94,7 @@ ENTRY main { auto name = absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value()); TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform, - se::MultiPlatformManager::PlatformWithName(name)); + se::PlatformManager::PlatformWithName(name)); TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * stream_exec, platform->ExecutorForDevice(0)); diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 34d2b4c062a0bc..964b9318ce70f8 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -71,12 +71,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/transforms/hlo_constant_splitter.h" -#include "xla/mlir/backends/gpu/transforms/passes.h" -#include "xla/mlir/runtime/transforms/compilation_pipeline_gpu.h" -#include "xla/mlir/runtime/transforms/compilation_pipeline_options.h" -#include "xla/runtime/compiler.h" -#include "xla/runtime/executable.h" -#include "xla/runtime/jit_executable.h" #include "xla/service/algebraic_simplifier.h" #include "xla/service/all_gather_broadcast_reorder.h" #include "xla/service/all_gather_combiner.h" @@ -119,6 +113,7 @@ limitations under the License. #include "xla/service/float_support.h" #include "xla/service/gather_expander.h" #include "xla/service/gather_simplifier.h" +#include "xla/service/gpu/address_computation_fusion_rewriter.h" #include "xla/service/gpu/alias_passthrough_params.h" #include "xla/service/gpu/all_reduce_blueconnect.h" #include "xla/service/gpu/autotuner_util.h" @@ -164,7 +159,7 @@ limitations under the License. #include "xla/service/gpu/reduction_layout_normalizer.h" #include "xla/service/gpu/reduction_splitter.h" #include "xla/service/gpu/reduction_utils.h" -#include "xla/service/gpu/runtime/executable.h" +#include "xla/service/gpu/rename_fusions.h" #include "xla/service/gpu/runtime_intrinsics.h" #include "xla/service/gpu/scatter_slice_simplifier.h" #include "xla/service/gpu/softmax_rewriter_triton.h" @@ -234,8 +229,8 @@ limitations under the License. #include "xla/stream_executor/device_description.pb.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/gpu/gpu_driver.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h" #include "xla/util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" @@ -302,6 +297,11 @@ class MaybeOwningThreadPool { int default_parallelism) { CHECK_GE(parallelism, 0); CHECK_GE(default_parallelism, 1); + // CurrentThreadId() returns -1 if the current thread does not belong to the + // thread pool. If the current thread belongs to the thread pool, we should + // not be using it, because it can potentially cause deadlocks. + CHECK(default_thread_pool == nullptr || + default_thread_pool->CurrentThreadId() == -1); auto create_thread_pool = [&](int num_threads) { CHECK_GE(num_threads, 1); @@ -357,10 +357,6 @@ MaybeOwningThreadPool::operator bool() const { return get() != nullptr; } bool MaybeOwningThreadPool::operator!() const { return get() == nullptr; } -bool ConvIsLowerable(HloInstruction* conv) { - return GpuConvRewriter::ConvIsLowerable(conv); -} - absl::StatusOr GetAutotuneConfig( se::StreamExecutor* stream_exec, const DebugOptions& debug_options, const GpuCompiler::CompileOptions& options, @@ -379,63 +375,8 @@ se::GpuComputeCapability GetGpuVersion(const se::StreamExecutor* stream_exec) { return stream_exec->GetDeviceDescription().gpu_compute_capability(); } -// TODO(b/232263665): It should be shared between GPU and CPU. -class GpuAotCompilationResult : public AotCompilationResult { - public: - GpuAotCompilationResult( - HloModuleProto hlo, std::string_view obj_file, - std::string_view mlir_module, std::string_view gpu_asm_text, - absl::Span gpu_binary, - absl::Span constants = {}) { - XlaRuntimeExecutableProto xla_runtime_executable; - *xla_runtime_executable.mutable_hlo_module_proto() = hlo; - xla_runtime_executable.set_obj_file(std::string(obj_file)); - xla_runtime_executable.set_mlir_module(std::string(mlir_module)); - *xla_runtime_gpu_executable_.mutable_xla_runtime_executable() = - xla_runtime_executable; - - xla_runtime_gpu_executable_.set_gpu_asm_text(std::string(gpu_asm_text)); - xla_runtime_gpu_executable_.set_gpu_binary(gpu_binary.data(), - gpu_binary.size()); - - for (const GpuExecutable::ConstantInfo& cst : constants) { - auto* cst_proto = xla_runtime_gpu_executable_.add_constants(); - cst_proto->set_symbol_name(cst.symbol_name); - cst_proto->set_allocation_index(cst.allocation_index); - cst_proto->set_content(cst.content.span().data(), - cst.content.span().size()); - } - } - - explicit GpuAotCompilationResult(XlaRuntimeGpuExecutableProto executable) - : xla_runtime_gpu_executable_(executable) {} - - absl::StatusOr SerializeAsString() const override { - return xla_runtime_gpu_executable_.SerializeAsString(); - } - - static absl::StatusOr> FromString( - const std::string& serialized) { - XlaRuntimeGpuExecutableProto xla_runtime_gpu_executable; - if (!xla_runtime_gpu_executable.ParseFromString(serialized)) { - return Internal("Failed to parse serialized JitRtExecutableProto."); - } - return std::make_unique( - xla_runtime_gpu_executable); - } - - absl::StatusOr> LoadExecutable( - Compiler* compiler, const se::StreamExecutor* executor) const override; - - private: - XlaRuntimeGpuExecutableProto xla_runtime_gpu_executable_; -}; - class GpuThunkAotCompilationResult : public AotCompilationResult { public: - explicit GpuThunkAotCompilationResult(CompilationResultProto proto) - : proto_(std::move(proto)) {} - static absl::StatusOr> FromModule(const HloModule* hlo_module, const BufferAssignment* buffer_assignment, @@ -446,7 +387,9 @@ class GpuThunkAotCompilationResult : public AotCompilationResult { *proto.mutable_buffer_assignment() = buffer_assignment->ToProto(); proto.set_asm_text(std::string(asm_text)); proto.set_binary(binary.data(), binary.size()); - return std::make_unique(std::move(proto)); + return std::unique_ptr( + new GpuThunkAotCompilationResult(hlo_module->Clone(), + std::move(proto))); } static absl::StatusOr> @@ -457,7 +400,11 @@ class GpuThunkAotCompilationResult : public AotCompilationResult { "Failed to parse serialized GpuThunkAotCompilationResult."); } - return std::make_unique(std::move(proto)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module, + HloModule::CreateFromProtoWithConfig(proto.hlo_module_with_config())); + return std::unique_ptr( + new GpuThunkAotCompilationResult(std::move(module), std::move(proto))); } absl::StatusOr SerializeAsString() const override { @@ -467,43 +414,22 @@ class GpuThunkAotCompilationResult : public AotCompilationResult { absl::StatusOr> LoadExecutable( Compiler* compiler, const se::StreamExecutor* stream_exec) const override; + const HloModule* optimized_module() const override { return module_.get(); } + std::unique_ptr consume_optimized_module() override { + return std::move(module_); + } + private: + GpuThunkAotCompilationResult(std::unique_ptr module, + CompilationResultProto proto) + : module_(std::move(module)), proto_(std::move(proto)) {} + + std::unique_ptr module_; CompilationResultProto proto_; }; } // end anonymous namespace -absl::StatusOr> -GpuAotCompilationResult::LoadExecutable( - Compiler* compiler, const se::StreamExecutor* executor) const { - XlaRuntimeExecutableProto xla_runtime_executable = - xla_runtime_gpu_executable_.xla_runtime_executable(); - TF_ASSIGN_OR_RETURN(HloModuleConfig hlo_module_config, - HloModule::CreateModuleConfigFromProto( - xla_runtime_executable.hlo_module_proto(), - GetDebugOptionsFromFlags())); - TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_module, - HloModule::CreateFromProto(xla_runtime_executable.hlo_module_proto(), - hlo_module_config)); - std::vector constants; - for (auto& cst : xla_runtime_gpu_executable_.constants()) { - GpuExecutable::ConstantInfo constant = { - cst.symbol_name(), - DenseDataIntermediate::Own( - std::vector{cst.content().begin(), cst.content().end()}), - cst.allocation_index()}; - constants.push_back(std::move(constant)); - } - - return GpuExecutable::LoadFromObjFile( - std::move(hlo_module), xla_runtime_executable.obj_file(), - xla_runtime_executable.mlir_module(), GetDebugOptionsFromFlags(), - xla_runtime_gpu_executable_.gpu_asm_text(), - xla_runtime_gpu_executable_.gpu_binary(), std::move(constants), - GetGpuVersion(executor)); -} - absl::StatusOr> GpuThunkAotCompilationResult::LoadExecutable( Compiler* compiler, const se::StreamExecutor* stream_exec) const { @@ -524,7 +450,7 @@ GpuThunkAotCompilationResult::LoadExecutable( // Build the executable, which should be a thunk sequence. TF_ASSIGN_OR_RETURN( se::Platform * platform, - se::MultiPlatformManager::PlatformWithId(compiler->PlatformId())); + se::PlatformManager::PlatformWithId(compiler->PlatformId())); std::string platform_name = platform->Name(); se::DeviceDescription gpu_device_info = stream_exec->GetDeviceDescription(); mlir::DialectRegistry registry; @@ -541,25 +467,10 @@ GpuThunkAotCompilationResult::LoadExecutable( IrEmitterContext ir_emitter_context(hlo_module.get(), buffer_assignment.get(), platform_name, gpu_device_info, mlir_context.get(), llvm_module.get(), - /*emit_ir_from_hlo=*/true, /*emit_kernels=*/false); - mlir::OwningOpRef mlir_module = llvm_ir::CreateMlirModuleOp( - mlir::Builder(mlir_context.get()).getUnknownLoc(), hlo_module->name()); - std::vector ordered_allocations; - absl::flat_hash_map - operation_map; - TF_RETURN_IF_ERROR(HloToLhloModule(*buffer_assignment, *hlo_module, - *mlir_module, &ordered_allocations, - &operation_map)); - ir_emitter_context.set_allocations(ordered_allocations); auto ir_emitter = IrEmitterUnnested::Create(&ir_emitter_context); - auto entry_function = mlir::cast( - mlir_module->lookupSymbol(hlo_module->entry_computation()->name())); - // TODO(anlunx): EmitLmhloRegion emits fusion kernels. We need to make sure - // ptx and cubin already contain emission results and disable kernel emission - // here. TF_RETURN_IF_ERROR( - ir_emitter->EmitLmhloRegion(&entry_function.getBody(), operation_map)); + ir_emitter->EmitHloComputation(hlo_module->entry_computation())); std::unique_ptr thunk_sequence = ir_emitter->ConsumeThunkSequence(); ForAllThunks([](Thunk* thunk) { thunk->ClearCompileTimeInfo(); }, @@ -621,15 +532,6 @@ void AddHloVerifier(HloPassPipeline* pipeline, HloVerifierOpts&& opts = {}, "hlo verifier"); } } - -void SetInstructionMetadata(HloModule* module) { - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - instruction->set_creation_pass_id(-1); - instruction->set_logical_creation_pass_id(-1); - } - } -} } // namespace // Runs optimization passes on the given HLO module. @@ -658,8 +560,14 @@ absl::Status GpuCompiler::OptimizeHloModule( /*default_thread_pool=*/options.thread_pool, /*default_parallelism=*/tsl::port::MaxParallelism()); - AlgebraicSimplifierOptions layout_insensitive_algsimp_opts({}, - ConvIsLowerable); + AlgebraicSimplifierOptions layout_insensitive_algsimp_opts = + GetAlgebraicSimplifierOptions(hlo_module->config()); + layout_insensitive_algsimp_opts.set_conv_is_lowerable_callback( + GpuConvRewriter::ConvIsLowerable); + layout_insensitive_algsimp_opts.set_enable_dot_strength_reduction( + hlo_module->config() + .debug_options() + .xla_gpu_enable_dot_strength_reduction()); // GPU only supports canonical convolutions. layout_insensitive_algsimp_opts.set_supports_non_canonical_dots(false); @@ -682,8 +590,6 @@ absl::Status GpuCompiler::OptimizeHloModule( layout_insensitive_algsimp_opts .set_enable_unconditional_reduce_of_concat_replacement(false); - SetInstructionMetadata(hlo_module); - HloPassPipeline pre_spmd_pipeline("pre-spmd-partitioner"); // Run some IR cleanup passes before running the SPMD partitioning // passes. @@ -691,6 +597,8 @@ absl::Status GpuCompiler::OptimizeHloModule( pre_spmd_pipeline.AddPass(); pre_spmd_pipeline.AddPass(); + // The TopkDecomposer generates a compare op with type=TOTALORDER and must + // run before the ComparisonExpander which rewrites such comparisons. pre_spmd_pipeline.AddPass([&](const HloInstruction* instr) { return instr->opcode() == HloOpcode::kTopK; }); @@ -824,7 +732,7 @@ absl::Status GpuCompiler::OptimizeHloModule( pipeline.AddPass(RandomAlgorithm::RNG_PHILOX); // Comparison total order expander - pipeline.AddPass(); + pipeline.AddPass(std::array{std::make_pair(BF16, F32)}); // Remove zero-sized HLO from the input so that other passes don't have to // handle it. @@ -1137,6 +1045,7 @@ absl::Status GpuCompiler::OptimizeHloModule( { HloPassPipeline pipeline("post-fusion optimization"); + pipeline.AddPass(); pipeline.AddPass( hlo_module->config() .debug_options() @@ -1283,6 +1192,14 @@ absl::Status GpuCompiler::OptimizeHloModule( return absl::OkStatus(); } +AlgebraicSimplifierOptions GpuCompiler::GetAlgebraicSimplifierOptions( + const HloModuleConfig& config) { + AlgebraicSimplifierOptions opts; + opts.set_enable_dot_strength_reduction( + config.debug_options().xla_gpu_enable_dot_strength_reduction()); + return opts; +} + // Modifies the given HLO module so that it will be accepted by IrEmitter. // Unlike optimization passes, the passes are necessary for correctness. absl::Status GpuCompiler::PrepareHloModuleForIrEmitting(HloModule* hlo_module) { @@ -1300,7 +1217,8 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( const se::GpuComputeCapability gpu_version = gpu_target_config.device_description.gpu_compute_capability(); const AlgebraicSimplifierOptions simplifier_options = [&] { - AlgebraicSimplifierOptions opts; + AlgebraicSimplifierOptions opts = + GetAlgebraicSimplifierOptions(hlo_module->config()); opts.set_supports_non_canonical_dots(false); opts.set_is_layout_sensitive(true); opts.set_enable_conv_operand_swap(false); @@ -1329,9 +1247,9 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( sub_pipeline.AddPass(&f8e5m2fnuz_support); sub_pipeline.AddPass(&f8e4m3fnuz_support); // Remove `f32 -> bf16 -> f32` casts inserted by bf16 normalization. - if (debug_options.xla_gpu_simplify_all_fp_conversions()) { - sub_pipeline.AddPass( - SimplifyFPConversions::Scope::kSimplifyAllConversions); + if (debug_options.xla_allow_excess_precision() && + debug_options.xla_gpu_simplify_all_fp_conversions()) { + sub_pipeline.AddPass(); } }; @@ -1438,6 +1356,10 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( pipeline.AddPass(gpu_version); // Rewrite GEMMs with broadcasted inputs as strided GEMMs. pipeline.AddPass(); + if (debug_options.xla_gpu_normalize_layouts()) { + pipeline.AddPass(&NormalizeLayoutForGpuCustomCalls); + pipeline.AddPass>(simplifier_options); + } TF_RETURN_IF_ERROR(AddConvAndGemmAutotuningPasses( &pipeline, hlo_module, autotune_config, thread_pool)); @@ -1453,14 +1375,13 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( // duplicate or NOPs, so remove them with algebraic simplification and CSE. pipeline.AddPass>(simplifier_options); - if (debug_options.xla_gpu_simplify_all_fp_conversions()) { + if (debug_options.xla_allow_excess_precision() && + debug_options.xla_gpu_simplify_all_fp_conversions()) { // This pass cleans up chains of compiler-generated converts // (i.e. f32 -> bf16 -> f32) that have been produced by the algebraic // simplifier by rearranging ops (i.e. by pushing broadcasts towards the // root). - pipeline.AddPass( - SimplifyFPConversions::Scope:: - kOnlySimplifyCompilerGeneratedConversions); + pipeline.AddPass(); } // Since this CSE runs after collective schedule linearizer which inserts @@ -1469,8 +1390,32 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( pipeline.AddPass(/*is_layout_sensitive=*/true, /*only_fusion_computations=*/false, /*ignore_control_dependencies=*/true); + +#ifdef NDEBUG + // Verify the module in non-debug builds. For debug builds, the verifier + // already runs after every pass. + pipeline.AddPass( + std::make_unique( + HloVerifierOpts{} + .MakeLayoutSensitive() + .WithInstructionCanChangeLayout( + LayoutAssignment::InstructionCanChangeLayout) + .VerifyBroadcastDimensionsOrder() + .VerifyReshapeIsBitcast()), + "end-of-post-layout_assignment"); +#endif // NDEBUG + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + if (DumpingEnabledForHloModule(*hlo_module)) { + TF_ASSIGN_OR_RETURN( + std::string autotune_results, + AutotunerUtil::SerializeAutotuneResultsForModule( + *hlo_module, autotune_config, /*as_textproto=*/true)); + DumpToFileInDirOrStdout(*hlo_module, "", "autotune_results.pbtxt", + autotune_results); + } + return absl::OkStatus(); } @@ -1880,7 +1825,7 @@ GpuCompiler::CompileToBackendResult( module, schedule_metadata.scheduler_mem_limit, gpu_device_info)); TF_ASSIGN_OR_RETURN(se::Platform * platform, - se::MultiPlatformManager::PlatformWithId(PlatformId())); + se::PlatformManager::PlatformWithId(PlatformId())); // Compile the module TF_ASSIGN_OR_RETURN( @@ -1902,14 +1847,9 @@ GpuCompiler::CompileToBackendResult( module->config(), compile_module_results.llvm_module.get(), gpu_device_info.gpu_compute_capability(), executor, options, module)); RecordXlaDeviceBinarySize(backend_result.binary.size()); - if (DumpingEnabledForHloModule(*module) && - std::holds_alternative( - compile_module_results.executable)) { - const ThunkSequence& thunk_sequence = - *std::get( - compile_module_results.executable); + if (DumpingEnabledForHloModule(*module)) { DumpToFileInDirOrStdout(*module, "", "thunk_sequence.txt", - thunk_sequence.ToString()); + compile_module_results.executable->ToString()); } return CompileResultWithMetadata{std::move(backend_result), @@ -1973,11 +1913,9 @@ absl::StatusOr> GpuCompiler::RunBackend( CompileToBackendResult(module.get(), &llvm_context, stream_exec, options, gpu_device_info)); - if (auto thunk_sequence = std::get_if( - &res.compile_module_results.executable); - DumpingEnabledForHloModule(*module) && thunk_sequence) { + if (DumpingEnabledForHloModule(*module)) { DumpToFileInDirOrStdout(*module, "", "thunk_sequence.txt", - (*thunk_sequence)->ToString()); + res.compile_module_results.executable->ToString()); } // The module is being moved into the GpuExecutable below and we need to @@ -2062,73 +2000,14 @@ GpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, CompileToBackendResult(module.get(), &llvm_context, options.executor(), {options.device_allocator()}, gpu_device_info)); - if (!IsXlaRuntimeExecutableEnabled(module->config())) { - // Create GpuThunkAotCompilationResult if thunk runtime is enabled. - TF_ASSIGN_OR_RETURN( - results.emplace_back(), - GpuThunkAotCompilationResult::FromModule( - module.get(), res.compile_module_results.buffer_assignment.get(), - res.backend_result.asm_text, res.backend_result.binary)); - continue; - } - - const auto* program = std::get_if( - &res.compile_module_results.executable); - if (!program) { - return Internal("Gpu runtime program was not provided"); - } - - // TODO(ezhulenev): Unify AOT compilation with GpuRuntimeExecutable::Create - // (see `gpu/runtime/executable.h`). - - // Options for the default XLA runtime compilation pipeline. - runtime::CompilationPipelineOptions copts; - - // Populate mapping from XLA (SE) enums/structs type id to symbol names. - copts.populate_type_id_names = RegisterXlaGpuTypeIdNames; - - // For passing LMHLO attributes as XLA (SE) enums/structs to custom calls. - copts.populate_attr_encodings = RegisterXlaGpuAttrEncoding; - - // Options for constructing XLA runtime JitExecutable. - runtime::JitExecutable::Options opts; - opts.specialization = runtime::JitExecutable::Specialization::kDisabled; - opts.compiler.register_dialects = - runtime::RegisterDefaultXlaGpuRuntimeDialects; - - // Register XLA Gpu runtime custom calls with the linker. - opts.compiler.symbols_binding = runtime::ToSymbolsBinding( - RegisterXlaGpuRuntimeCustomCalls, RegisterXlaGpuTypeIdNames); - - opts.compiler.create_compilation_pipeline = - [copts](xla::runtime::PassManager& passes) { - runtime::CreateDefaultXlaGpuRuntimeCompilationPipeline(passes, copts); - return absl::OkStatus(); - }; - - // Instantiate new JitExecutable from the MLIR source. - auto jit_executable = runtime::JitExecutable::Instantiate( - (*program)->module, (*program)->entry_point, opts); - if (!jit_executable.ok()) - return Internal("Failed to compile XLA program: %s", - jit_executable.status().message()); - - // For static shapes we can always serialize only the default executable. - runtime::Executable& executable = jit_executable->DefaultExecutable().get(); - - // Check if XLA runtime executable saved the compilation result. - std::unique_ptr obj_file = executable.obj_file(); - if (!obj_file) - return Internal("XLA runtime executable didn't save the obj file"); - - std::string data(obj_file->getBuffer().data(), - obj_file->getBuffer().size()); - - results.emplace_back(std::make_unique( - module->ToProto(), data, (*program)->module, - res.backend_result.asm_text, res.backend_result.binary, - res.compile_module_results.constants)); + // Create GpuThunkAotCompilationResult if thunk runtime is enabled. + TF_ASSIGN_OR_RETURN( + results.emplace_back(), + GpuThunkAotCompilationResult::FromModule( + module.get(), res.compile_module_results.buffer_assignment.get(), + res.backend_result.asm_text, res.backend_result.binary)); } + return std::move(results); } @@ -2144,11 +2023,6 @@ absl::StatusOr> GpuCompiler::Export( auto* gpu_executable = tensorflow::down_cast(executable); if (!gpu_executable) return Internal("GpuExecutable is null"); - if (gpu_executable->IsXlaRuntimeEnabled()) { - return absl::InternalError( - "Exporting executables when XLA runtime is enabled is not supported"); - } - return GpuThunkAotCompilationResult::FromModule( &gpu_executable->module(), gpu_executable->buffer_assignment(), gpu_executable->text(), gpu_executable->binary()); @@ -2195,6 +2069,16 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines( } } + if (module->config() + .debug_options() + .xla_gpu_enable_address_computation_fusion()) { + HloPassPipeline pipeline("address-computation"); + TF_ASSIGN_OR_RETURN(se::Platform * platform, + se::PlatformManager::PlatformWithId(PlatformId())); + pipeline.AddPass(platform->Name()); + TF_RETURN_IF_ERROR(pipeline.Run(module).status()); + } + { HloPassPipeline pipeline("fusion-wrapper"); pipeline.AddPass(); @@ -2215,8 +2099,9 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines( constexpr int toolkit_version = TF_ROCM_VERSION; #endif pipeline.AddPass( - gpu_device_info.gpu_compute_capability(), toolkit_version, + gpu_device_info, toolkit_version, driver_version.value_or(toolkit_version)); + pipeline.AddPass(); TF_RETURN_IF_ERROR(pipeline.Run(module).status()); } diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.h b/third_party/xla/xla/service/gpu/gpu_compiler.h index b4c1ef377654aa..95a7faae24f4f8 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.h +++ b/third_party/xla/xla/service/gpu/gpu_compiler.h @@ -28,6 +28,7 @@ limitations under the License. #include "xla/autotune_results.pb.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" +#include "xla/service/algebraic_simplifier.h" #include "xla/service/buffer_assignment.h" #include "xla/service/compiler.h" #include "xla/service/executable.h" @@ -165,6 +166,9 @@ class GpuCompiler : public LLVMCompiler { return absl::OkStatus(); } + AlgebraicSimplifierOptions GetAlgebraicSimplifierOptions( + const HloModuleConfig& config); + private: struct CompileResultWithMetadata { BackendCompileResult backend_result; diff --git a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc index c8117397fec9ae..91622d7959aa39 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/horizontal_loop_fusion.h" #include "xla/service/gpu/metrics.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/service/xla_debug_info_manager.h" @@ -309,6 +310,60 @@ ENTRY main { HloOpcode::kAllGatherDone); } +TEST_F(GpuCompilerTest, + GemmRewriterTritonIsNoOpWhenTritonAutotunerFallsBackToCublas) { + const absl::string_view hlo_string = R"( +HloModule test + +ENTRY main { + param_0 = bf16[3,32,1024,4,1024]{4,3,2,1,0} parameter(0) + param_1 = bf16[4,3,32,1024]{3,2,1,0} parameter(1) + param_2 = s32[] parameter(2) + constant_0 = s32[] constant(0) + dynamic-slice_0 = bf16[1,3,32,1024]{3,2,1,0} dynamic-slice(param_1, param_2, constant_0, constant_0, constant_0), dynamic_slice_sizes={1,3,32,1024} + reshape_0 = bf16[3,32,1024]{2,1,0} reshape(dynamic-slice_0) + broadcast_0 = bf16[3,32,1024,4,1024]{2,1,4,3,0} broadcast(reshape_0), dimensions={0,1,2} + add_0 = bf16[3,32,1024,4,1024]{4,3,2,1,0} add(param_0, broadcast_0) + transpose_0 = bf16[3,4,1024,32,1024]{2,1,4,3,0} transpose(add_0), dimensions={0,3,4,1,2} + slice_0 = bf16[1,4,1024,32,1024]{4,3,2,1,0} slice(transpose_0), slice={[0:1], [0:4], [0:1024], [0:32], [0:1024]} + reshape_1 = bf16[4,1024,32,1024]{3,2,1,0} reshape(slice_0) + copy_0 = bf16[4,1024,32,1024]{3,2,1,0} copy(reshape_1) + constant_1 = bf16[] constant(0.08838) + broadcast_1 = bf16[4,1024,32,1024]{3,2,1,0} broadcast(constant_1), dimensions={} + multiply_0 = bf16[4,1024,32,1024]{3,2,1,0} multiply(copy_0, broadcast_1) + slice_1 = bf16[1,4,1024,32,1024]{4,3,2,1,0} slice(transpose_0), slice={[1:2], [0:4], [0:1024], [0:32], [0:1024]} + reshape_2 = bf16[4,1024,32,1024]{3,2,1,0} reshape(slice_1) + copy_1 = bf16[4,1024,32,1024]{3,2,1,0} copy(reshape_2) + ROOT dot_0 = bf16[4,32,1024,1024]{3,2,1,0} dot(multiply_0, copy_1), lhs_batch_dims={0,2}, lhs_contracting_dims={3}, rhs_batch_dims={0,2}, rhs_contracting_dims={3} +} +)"; + + HloModuleConfig config; + DebugOptions debug_options = GetDebugOptionsForTest(); + config.set_debug_options(GetDebugOptionsForTest()); + config.set_replica_count(1); + config.set_num_partitions(1); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string, config)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr triton_enabled_module, + GetOptimizedModule(std::move(module))); + debug_options.set_xla_gpu_enable_triton_gemm(false); + config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(module, + ParseAndReturnVerifiedModule(hlo_string, config)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr triton_disabled_module, + GetOptimizedModule(std::move(module))); + // Make sure autotuner falls back to cuBLAS when enabling triton gemm + const HloInstruction* root = + triton_enabled_module->entry_computation()->root_instruction(); + const HloInstruction* custom_op = root->operand(0)->operand(0); + EXPECT_TRUE(custom_op->IsCustomCall("__cublas$gemm")); + // Make sure that the module has the same number of computations with/without + // enabling triton gemm + EXPECT_EQ(triton_enabled_module->computation_count(), + triton_disabled_module->computation_count()); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/gpu_conv_rewriter.cc b/third_party/xla/xla/service/gpu/gpu_conv_rewriter.cc index 697b69306f127f..ca1388e132caae 100644 --- a/third_party/xla/xla/service/gpu/gpu_conv_rewriter.cc +++ b/third_party/xla/xla/service/gpu/gpu_conv_rewriter.cc @@ -43,6 +43,30 @@ namespace { using ConvolutionMatch = std::optional< std::tuple>; +// Determine whether conv2d is equal to conv1d. +bool MaybeConv1dToConv2d(HloInstruction* conv) { + if (conv->window().dimensions().size() != 2) { + return false; + } + if (conv->operand(1)->opcode() != HloOpcode::kReshape) { + return false; + } + auto filter = conv->operand(1); + std::optional reshape_degenerate = + filter->ReshapeMerelyInsertsOrDeletes1SizedDimensions(); + if (reshape_degenerate.has_value() && + reshape_degenerate->deleted_dimensions.empty() && + reshape_degenerate->inserted_dimensions.size() == 1) { + auto dnums = conv->convolution_dimension_numbers(); + for (auto dim : dnums.kernel_spatial_dimensions()) { + if (dim == reshape_degenerate->inserted_dimensions[0]) { + return true; + } + } + } + return false; +} + bool CanImplementAsGpuForwardConv(HloInstruction* conv) { const ConvolutionDimensionNumbers& dnums = conv->convolution_dimension_numbers(); @@ -145,14 +169,18 @@ ConvolutionMatch MatchBackwardFilter(HloInstruction* conv) { // convolutions have very small kernel dimensions, while in the backward pass // "kernel dimensions" are large. If kernel dimensions are smaller than the // output dimensions, return foward conv; otherwise proceed with backward - // filter conv. - bool exists_small_kernel_dimension = false; + // filter conv. But for conv1d, it is not same. Due to conv1d always reshape + // 1D-filter to 2D-filter, even backward or forward will exist one small + // kernel dimension. We should handle this special case. + int small_kernel_dimension_num = 0; for (int i = 0; i < kernel_spatial_dims.size(); ++i) { - exists_small_kernel_dimension |= - (conv->operand(1)->shape().dimensions(kernel_spatial_dims[i]) <= - conv->shape().dimensions(output_spatial_dims[i])); + if (conv->operand(1)->shape().dimensions(kernel_spatial_dims[i]) <= + conv->shape().dimensions(output_spatial_dims[i])) { + small_kernel_dimension_num += 1; + } } - if ((kernel_spatial_dims.empty() || exists_small_kernel_dimension) && + if ((kernel_spatial_dims.empty() || small_kernel_dimension_num > 1 || + (!MaybeConv1dToConv2d(conv) && small_kernel_dimension_num == 1)) && !window_util::HasWindowDilation(conv->window())) { VLOG(1) << conv->ToString() << " is a regular forward convolution. No need " @@ -313,10 +341,18 @@ ConvolutionMatch MatchBackwardInput(HloInstruction* conv) { reverse_filter->opcode() == HloOpcode::kReverse && absl::c_is_permutation(dnums.kernel_spatial_dimensions(), reverse_filter->dimensions()); + // For conv1d which reshape to conv2d, filter reverse pattern is + // reshape(reverse(filter)). It seems we can reuse conv2d backward input + // pattern matcher, but after algsimp pass, this pattern will change to + // reverse(reshape(filter)) and fail to match. So matching conv1d backward + // input need different processing logic. + bool is_reversed_conv1d_filter = + MaybeConv1dToConv2d(conv) && + reverse_filter->operand(0)->opcode() == HloOpcode::kReverse; bool is_1x1_filter = absl::c_all_of(conv->window().dimensions(), [](const WindowDimension& d) { return d.size() == 1; }); - if (!is_reversed_filter && + if (!is_reversed_filter && !is_reversed_conv1d_filter && !(window_util::HasBaseDilation(conv->window()) && (reverse_filter->IsConstant() || is_1x1_filter))) { VLOG(1) << "Can't match to backwards convolution. Either filter is not " @@ -488,6 +524,10 @@ ConvolutionMatch MatchBackwardInput(HloInstruction* conv) { // One reverse is subsumed by the cuDNN call. if (rhs->opcode() == HloOpcode::kReverse) { rhs = rhs->mutable_operand(0); + } else if (is_reversed_conv1d_filter) { + auto src = rhs->mutable_operand(0)->mutable_operand(0); + rhs = conv->parent()->AddInstruction( + HloInstruction::CreateReshape(rhs->shape(), src)); } if (conv->feature_group_count() == 1) { return std::make_tuple(new_window, dnums, rhs); diff --git a/third_party/xla/xla/service/gpu/gpu_conv_rewriter_test.cc b/third_party/xla/xla/service/gpu/gpu_conv_rewriter_test.cc index 0044aced926bfe..f161a3012d0b63 100644 --- a/third_party/xla/xla/service/gpu/gpu_conv_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_conv_rewriter_test.cc @@ -686,6 +686,52 @@ TEST_F(GpuConvRewriterTest, TestBackwardFilterPatternNoMatch) { 0))); } +TEST_F(GpuConvRewriterTest, TestConv1dBackwardFilterPatternMatch) { + // There exist one kernel dimension equal to output dimension, regard + // it as backward filter if conv is 1d. + const std::string module_str = absl::StrFormat(R"( + HloModule Test + + ENTRY Test { + input = f32[8,256,128] parameter(0) + filter = f32[8,254,128] parameter(1) + reshape.1 = f32[8,1,256,128] reshape(input) + reshape.2 = f32[8,1,254,128] reshape(filter) + ROOT conv = f32[1,3,128,128] convolution(reshape.1, reshape.2), window={size=1x254}, dim_labels=f01b_i01o->01bf + })"); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + + EXPECT_TRUE(RunPass(m.get())); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::GetTupleElement( + m::CustomCall({kCudnnConvBackwardFilterCallTarget}, + m::Reshape(), m::Reshape()), + 0))); +} + +TEST_F(GpuConvRewriterTest, TestConv1dBackwardInputPatternMatch) { + // For conv1d backward input, filter may reverse first and then reshape. + const std::string module_str = absl::StrFormat(R"( + HloModule Test + + ENTRY Test { + input = f32[8,254,128] parameter(0) + filter = f32[3,128,128] parameter(1) + reverse = f32[3,128,128] reverse(filter), dimensions={0} + reshape.1 = f32[8,1,254,128] reshape(input) + reshape.2 = f32[1,3,128,128] reshape(reverse) + ROOT conv = f32[8,1,256,128] convolution(reshape.1, reshape.2), window={size=1x3 pad=0_0x2_2}, dim_labels=b01f_01oi->b01f + })"); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + + EXPECT_TRUE(RunPass(m.get())); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::GetTupleElement( + m::CustomCall({kCudnnConvBackwardInputCallTarget}, + m::Reshape(), m::Reshape()), + 0))); +} + } // anonymous namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/gpu_executable.cc b/third_party/xla/xla/service/gpu/gpu_executable.cc index c786983385be86..d3d8c639e4a722 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.cc +++ b/third_party/xla/xla/service/gpu/gpu_executable.cc @@ -16,20 +16,19 @@ limitations under the License. #include "xla/service/gpu/gpu_executable.h" #include -#include #include +#include #include #include #include #include -#include #include #include #include -#include "absl/algorithm/container.h" #include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/status/status.h" @@ -38,21 +37,17 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/time/time.h" #include "absl/types/span.h" -#include "mlir/Parser/Parser.h" // from @llvm-project +#include "xla/executable_run_options.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/map_util.h" -#include "xla/mlir/runtime/ir/rt_ops.h" -#include "xla/mlir/runtime/transforms/compilation_pipeline_gpu.h" -#include "xla/mlir/runtime/transforms/type_converter.h" -#include "xla/runtime/executable.h" #include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/gpu_constants.h" +#include "xla/service/gpu/gpu_executable_run_options.h" #include "xla/service/gpu/nccl_clique.h" #include "xla/service/gpu/nccl_clique_key.h" -#include "xla/service/gpu/non_atomically_upgradeable_rw_lock.h" -#include "xla/service/gpu/runtime/executable.h" -#include "xla/service/gpu/runtime/tracing.h" +#include "xla/service/gpu/runtime/annotation.h" #include "xla/service/gpu/stream_executor_util.h" #include "xla/service/gpu/thunk.h" #include "xla/service/hlo_parser.h" @@ -98,25 +93,22 @@ class GpuTimer {}; namespace xla { namespace gpu { +using ::tsl::profiler::ScopedAnnotation; + bool IsXlaRuntimeExecutableEnabled(const HloModuleConfig& config) { bool enabled = config.debug_options().xla_gpu_enable_xla_runtime_executable(); if (enabled) { - LOG(WARNING) - << "XLA:GPU uses deprecated xla runtime by setting " - "--xla_gpu_enable_xla_runtime_executable flag to true. This flag " + LOG(ERROR) + << "XLA:GPU tried to use deprecated xla runtime by setting " + "--xla_gpu_enable_xla_runtime_executable flag to `true` but the " + "flag value was ignored as XLA:GPU uses default runtime. This flag " "together with the deprecated code will be removed soon. Please " - "check that your workloads can run with default XLA runtime and if " - "not report bugs to XLA team ASAP."; + "report bugs to XLA team if this breaks your workloads."; } - return enabled; + return false; } -namespace { - -using ::tsl::profiler::ScopedAnnotation; -using ::tsl::profiler::ScopedAnnotationAlways; - -bool NeedsAsyncCommsStream(Thunk& thunk) { +static bool NeedsAsyncCommsStream(Thunk& thunk) { switch (thunk.kind()) { case Thunk::Kind::kNcclAllReduceStart: case Thunk::Kind::kNcclAllReduceDone: @@ -126,27 +118,34 @@ bool NeedsAsyncCommsStream(Thunk& thunk) { } } -} // namespace +// Traverses operations in HLO module and collects execution stream ids +// requested by HLO operations. At run time thunks may use additional streams to +// launch compute operations in addition to a main one. +// +// TODO(ezhulenev): Execution stream requirements should be queried from thunks +// directly and not from HLO module that might be missing. +static absl::flat_hash_set GetExecutionStreamIds( + const HloModule& module) { + absl::flat_hash_set stream_ids; + for (const HloComputation* comp : module.computations()) { + for (const HloInstruction* hlo : comp->instructions()) { + if (hlo->has_backend_config() && + hlo->backend_config().ok()) { + int64_t op_queue_id = hlo->backend_config() + .value() + .operation_queue_id(); + if (op_queue_id > 0) { + stream_ids.insert(ExecutionStreamId(op_queue_id)); + } + } + } + } + return stream_ids; +} absl::StatusOr> GpuExecutable::Create( Params params) { - auto executable = std::move(params.executable); - std::unique_ptr result(new GpuExecutable(std::move(params))); - - if (std::holds_alternative(executable)) { - result->thunks_ = std::move(std::get(executable)); - return result; - } - - if (std::holds_alternative(executable)) { - auto& program = std::get(executable); - TF_ASSIGN_OR_RETURN( - result->gpu_runtime_executable_, - GpuRuntimeExecutable::Create(result->module_name_, std::move(program))); - return result; - } - - return Internal("No XLA gpu executable was provided"); + return std::unique_ptr(new GpuExecutable(std::move(params))); } // Implementation note: HLO profiling is always enabled for GPU executables, @@ -156,6 +155,10 @@ GpuExecutable::GpuExecutable(GpuExecutable::Params params) text_(std::move(params.asm_text)), binary_(std::move(params.binary)), gpu_version_(params.gpu_version), + thunks_(std::move(params.executable)), + execution_stream_ids_(has_module() + ? GetExecutionStreamIds(module()) + : absl::flat_hash_set()), module_name_(params.module_name), output_shape_(params.output_shape), allocations_(std::move(params.mlir_allocations)), @@ -172,9 +175,6 @@ GpuExecutable::GpuExecutable(GpuExecutable::Params params) *(uint64_t*)(&binary_[binary_.size() - 16]) = tsl::EnvTime::NowNanos(); *(uint64_t*)(&binary_[binary_.size() - 8]) = tsl::random::New64(); #endif - if (has_module()) { - annotation_info_.emplace(module()); - } if (has_module() && enable_debug_info_manager_) { XlaDebugInfoManager::Get()->RegisterModule(shared_module(), buffer_assignment_->ToProto()); @@ -248,7 +248,10 @@ class ResourceRequests : public Thunk::ResourceRequests { VLOG(2) << "Acquire " << cliques_.size() << " collective cliques for global device id " << params.global_device_id.value() - << "; run_id=" << params.run_id.ToInt(); + << "; run_id=" << params.run_id.ToInt() + << "; max number of channels for collectives " + << params.collective_max_nchannels + << "; max number of channels for p2p " << params.p2p_max_nchannels; tsl::profiler::TraceMe trace([&] { return tsl::profiler::TraceMeEncode("AcquireCollectiveCliques", @@ -257,7 +260,7 @@ class ResourceRequests : public Thunk::ResourceRequests { auto start_micros = tsl::Env::Default()->NowMicros(); - Thunk::CollectiveCliques::CliquesMap cliques_map; + NcclClique::AcquiredCliquesMap cliques_map; for (const auto& [clique_key, num_local_participants] : cliques_) { std::optional rank = clique_key.rank(params.global_device_id); @@ -273,10 +276,15 @@ class ResourceRequests : public Thunk::ResourceRequests { const NcclCliqueIdCallback* clique_id_callback, GetNcclCliqueIdCallback(params.nccl_clique_id_callback, is_local)); - TF_ASSIGN_OR_RETURN(std::shared_ptr clique, - AcquireNcclClique(params.run_id, OpId(0), clique_key, - *clique_id_callback, *rank, - num_local_participants, false)); + int64_t max_channels = + clique_key.stream_kind() == AsyncStreamKind::kCollective + ? params.collective_max_nchannels + : params.p2p_max_nchannels; + TF_ASSIGN_OR_RETURN( + std::shared_ptr clique, + AcquireNcclClique(params.executor, params.run_id, clique_key, + *clique_id_callback, *rank, num_local_participants, + cliques_map, max_channels)); cliques_map[clique_key] = std::move(clique); } @@ -293,8 +301,9 @@ class ResourceRequests : public Thunk::ResourceRequests { private: // Keep all clique requests in an ordered container so that we acquire cliques - // in the same order for all participants and do not create a deadlock. - absl::btree_map cliques_; + // in the same order for all participants and do not create a deadlock. We use + // greater ordering to acquire largest cliques first. + absl::btree_map> cliques_; }; absl::Status MaybeSyncAndProfile( @@ -302,19 +311,19 @@ absl::Status MaybeSyncAndProfile( std::optional execution_timer, se::Stream* stream_to_sync); -absl::Status MaybeRendezvousAfterInitialization( +absl::Status RendezvousAfterInitialization( + const ServiceExecutableRunOptions* run_options); + +absl::Status ExecuteThunks( + const std::string& module_name, ModuleIdentifier module_id, + const ThunkSequence& thunk_sequence, + Thunk::ExecutableSource executable_source, const ServiceExecutableRunOptions* run_options, - std::atomic* thunks_initialized); - -absl::Status ExecuteThunks(const std::string& module_name, - ModuleIdentifier module_id, - const ThunkSequence& thunk_sequence, - Thunk::ExecutableSource executable_source, - const ServiceExecutableRunOptions* run_options, - const BufferAllocations& buffer_allocations, - bool block_host_until_done, - bool use_highest_priority_for_async_stream, - std::atomic* thunks_initialized) { + const BufferAllocations& buffer_allocations, bool block_host_until_done, + bool use_highest_priority_for_async_stream, + const absl::flat_hash_set& execution_stream_ids, + int64_t collective_max_nchannels, int64_t p2p_max_nchannels, + const ModuleAnnotations& module_annotations) { se::Stream* main_stream = run_options->stream(); se::StreamExecutor* executor = main_stream->parent(); stream_executor::StreamPriority stream_priority = @@ -343,6 +352,22 @@ absl::Status ExecuteThunks(const std::string& module_name, command_buffer_trace_stream = borrowed_command_buffer_trace_stream->get(); } + // Borrow stream for additional compute streams + Thunk::ExecutionStreamIdMap additional_execution_streams; + std::vector additional_streams; + if (!execution_stream_ids.empty()) { + TF_ASSIGN_OR_RETURN(additional_streams, run_options->BorrowStreams( + executor->device_ordinal(), + execution_stream_ids.size())); + int64_t i = 0; + for (ExecutionStreamId stream_id : execution_stream_ids) { + additional_execution_streams[stream_id] = additional_streams.at(i).get(); + i++; + } + VLOG(2) << "Using " << additional_execution_streams.size() + << " additional compute streams."; + } + tsl::profiler::TraceMe hlo_module_activity( [&] { return absl::StrCat(module_name, ":XLA GPU module"); }, tsl::profiler::TraceMeLevel::kInfo); @@ -359,10 +384,10 @@ absl::Status ExecuteThunks(const std::string& module_name, #endif // Parameters for executing collective operations. - TF_ASSIGN_OR_RETURN( - Thunk::CollectiveExecuteParams collective_params, - Thunk::CollectiveExecuteParams::Create( - *run_options, main_stream->parent()->device_ordinal())); + TF_ASSIGN_OR_RETURN(Thunk::CollectiveExecuteParams collective_params, + Thunk::CollectiveExecuteParams::Create( + *run_options, main_stream->parent()->device_ordinal(), + collective_max_nchannels, p2p_max_nchannels)); ResourceRequests resource_requests; @@ -396,21 +421,21 @@ absl::Status ExecuteThunks(const std::string& module_name, // only in presence of collective cliques which means that we have collective // operations in the XLA operations that tend to cause deadlocks. if (!collective_cliques.empty()) { - TF_RETURN_IF_ERROR( - MaybeRendezvousAfterInitialization(run_options, thunks_initialized)); + TF_RETURN_IF_ERROR(RendezvousAfterInitialization(run_options)); } // Prepare parameters for thunks execution. Thunk::ExecuteParams execute_params = Thunk::ExecuteParams::Create( *run_options, buffer_allocations, main_stream, command_buffer_trace_stream, async_comms_streams, &collective_params, - &collective_cliques); + &collective_cliques, additional_execution_streams); for (const std::unique_ptr& thunk : thunk_sequence) { // Annotate execution of this op if tracing was enabled when we started // running this module. If tracing is enabled *while* we're running the // module, we won't get any data, but that's probably an OK trade-off. - ScopedAnnotation annotation([&] { return thunk->profile_annotation(); }); + auto scoped_annotation = + GetKernelAnnotation(&module_annotations, thunk->profile_annotation()); VLOG(3) << "Executing the thunk for " << thunk->profile_annotation(); if (NeedsAsyncCommsStream(*thunk)) { for (se::Stream* async_stream : async_comms_streams) { @@ -425,9 +450,25 @@ absl::Status ExecuteThunks(const std::string& module_name, block_host_until_done ? main_stream : nullptr); } -absl::Status MaybeRendezvousAfterInitialization( - const ServiceExecutableRunOptions* run_options, - std::atomic* thunks_initialized) { +namespace { +// Wrap RunId into a unique struct to guarantee we do not accidentally try to +// run multiple unrelated rendezvous for a same key. +struct InitializationKey { + RunId run_id; + + template + friend H AbslHashValue(H h, const InitializationKey& key) { + return H::combine(std::move(h), key.run_id); + } +}; + +bool operator==(const InitializationKey& a, const InitializationKey& b) { + return a.run_id == b.run_id; +} +} // namespace + +absl::Status RendezvousAfterInitialization( + const ServiceExecutableRunOptions* run_options) { // Thunk initialization can allocate new control data structures on device // that can lead to deadlocks if other replicas are executing concurrently // (i.e. this happens if we try to instantiate CUDA graph when other replica @@ -441,24 +482,6 @@ absl::Status MaybeRendezvousAfterInitialization( // are running in a single Gpu config and don't need a rendezvous. if (!gpu_opts || !device_assn) return absl::OkStatus(); - // If `thunks_initialized` value is `-1` it means that all thunks are - // initialized and we can go ahead and execute all of them. All other values - // signal how many threads are executing rendezvous (they can be from - // different run ids). - if (thunks_initialized->load() < 0) return absl::OkStatus(); - - // We rely on CAS operations to make sure that all participants of - // potentially multiple concurrent XLA executions join the rendezvous or - // none of them join, because otherwise we will get a dead lock. - int64_t participant_id = thunks_initialized->load(); - while (participant_id >= 0 && !thunks_initialized->compare_exchange_weak( - participant_id, participant_id + 1)) { - } - - // If we exited a CAS loop with participant id less than 0 it means that some - // other thread completed initialization rendezvous. - if (participant_id < 0) return absl::OkStatus(); - // Assume that all participants execute locally first, if we have a local // device id to global device id map we will use it to get the real number of // participating local devices. @@ -484,7 +507,14 @@ absl::Status MaybeRendezvousAfterInitialization( << num_local_participants << " local participants" << "; device_ordinal=" << run_options->device_ordinal(); - auto rendezvous_key = run_options->run_options().run_id(); + tsl::profiler::TraceMe trace([&] { + return tsl::profiler::TraceMeEncode( + "RendezvousAfterInitialization", + {{"run_id", run_options->run_options().run_id().ToInt()}, + {"num_local_participants", num_local_participants}}); + }); + + auto rendezvous_key = InitializationKey{run_options->run_options().run_id()}; auto rendezvous_name = absl::StrFormat( "thunk initialization completion for device ordinal %d; run_id=%d", run_options->device_ordinal(), @@ -493,23 +523,6 @@ absl::Status MaybeRendezvousAfterInitialization( RendezvousSingle(rendezvous_name, rendezvous_key, num_local_participants, absl::Seconds(10), absl::Seconds(30)); - // Reload participant_id and use CAS to decide if we are the one who - // should mark initialization completed. - participant_id = thunks_initialized->load(); - - // Check that no one completed initialization process without us, and the - // number of participants inside the critical section is greater than 0 (we - // are here, so it can't be 0). - CHECK_GT(participant_id, 0); // NOLINT - - // If we are the last one, we try to mark executable initialization as - // completed by writing `-1` into the flag. - while (!thunks_initialized->compare_exchange_weak( - participant_id, participant_id == 1 ? -1 : participant_id - 1)) { - // Check precondition for participant id after CAS failure reloaded it. - CHECK_GT(participant_id, 0); // NOLINT - } - return absl::OkStatus(); } @@ -597,8 +610,8 @@ GpuExecutable::ResolveConstantGlobals(se::Stream* stream) { if (!info.content.span().empty()) { // This means the constant did not have an initializer in the PTX and // therefore must be initialized by XLA here. - stream->ThenMemcpy(&global, info.content.span().data(), - info.content.span().size()); + TF_RETURN_IF_ERROR(stream->Memcpy(&global, info.content.span().data(), + info.content.span().size())); submitted_mem_copies = true; } } else { @@ -752,37 +765,6 @@ absl::StatusOr GpuExecutable::ExecuteAsyncOnStream( return out.ConsumeResult(); } -static absl::Status ExecuteXlaRuntime( - const std::string& module_name, ModuleIdentifier module_id, - GpuRuntimeExecutable& gpu_runtime_executable, - const ServiceExecutableRunOptions* run_options, const std::string& asm_text, - const std::vector& binary, - const BufferAllocations& buffer_allocations, - const BufferAllocation* temp_buffer, bool block_host_until_done, - NonAtomicallyUpgradeableRWLock& gpu_lock) { - tsl::profiler::TraceMe hlo_module_activity( - [&] { return absl::StrCat(module_name, ":XLA GPU module"); }, - tsl::profiler::TraceMeLevel::kInfo); - - std::optional execution_timer; -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - if (ExecutionProfile* profile = - run_options->run_options().execution_profile(); - profile) { - TF_ASSIGN_OR_RETURN( - execution_timer, - se::gpu::GpuTimer::Create(se::gpu::AsGpuStream(run_options->stream()))); - } -#endif - auto executed = gpu_runtime_executable.Execute( - run_options, asm_text, binary, buffer_allocations, gpu_lock, temp_buffer); - if (!executed.ok()) return executed; - - return MaybeSyncAndProfile( - run_options, std::move(execution_timer), - block_host_until_done ? run_options->stream() : nullptr); -} - absl::StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( const ServiceExecutableRunOptions* run_options, VariantArguments arguments) { @@ -806,12 +788,13 @@ absl::StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( // that may be running during JIT compilation while allowing multiple XLA // computations to use the same GPU simultaneously. We do not add locking for // "recursive" invocations, which are done when holding a lock already. - NonAtomicallyUpgradeableRWLock gpu_lock(&GetGpuMutex(executor)); - std::optional exclusive_gpu_lock; - const gpu::GpuExecutableRunOptions* gpu_opts = - run_options->run_options().gpu_executable_run_options(); - if (gpu_opts && gpu_opts->requires_exclusive_lock_on_gpu()) { - exclusive_gpu_lock.emplace(&gpu_lock); + std::variant gpu_lock( + std::in_place_index_t<0>{}, &GetGpuMutex(executor)); + + // Maybe update to a writer lock to get exlcusive acess to underlying GPU. + if (auto* gpu_opts = run_options->run_options().gpu_executable_run_options(); + gpu_opts && gpu_opts->requires_exclusive_lock_on_gpu()) { + gpu_lock.emplace<1>(&GetGpuMutex(executor)); } const GpuExecutable::BufferAllocToDeviceMemoryMap* globals; @@ -926,8 +909,8 @@ absl::StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( buffer_allocations.GetMutableDeviceAddress( output_info.allocation_index); CHECK_EQ(aliased_buffer.size(), result_buffer.size()); - run_options->stream()->ThenMemcpyD2D(&result_buffer, aliased_buffer, - aliased_buffer.size()); + TF_RETURN_IF_ERROR(run_options->stream()->MemcpyD2D( + &result_buffer, aliased_buffer, aliased_buffer.size())); aliased_buffer = result_buffer; } } @@ -947,8 +930,8 @@ absl::StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( buffers_in_result.insert(result_buffer); } - TF_RETURN_IF_ERROR(ExecuteThunksOrXlaRuntime( - run_options, buffer_allocations, block_host_until_done, gpu_lock)); + TF_RETURN_IF_ERROR(ExecuteThunksOrXlaRuntime(run_options, buffer_allocations, + block_host_until_done)); TF_RETURN_IF_ERROR( buffer_allocations.TearDown(buffers_in_result, GetAllocations())); @@ -960,48 +943,28 @@ absl::StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( return std::move(result); } -namespace { -struct ModuleAnnotationManager { - ModuleAnnotationManager(const std::optional& annotations) { - if (annotations.has_value()) { - m_old_annotations = SetCurrentModuleAnnotations(&(*annotations)); - } - } - ~ModuleAnnotationManager() { - if (m_old_annotations.has_value()) { - SetCurrentModuleAnnotations(*m_old_annotations); - } - } - - private: - std::optional m_old_annotations; -}; -} // namespace - absl::Status GpuExecutable::ExecuteThunksOrXlaRuntime( const ServiceExecutableRunOptions* run_options, - const BufferAllocations& buffer_allocations, bool block_host_until_done, - NonAtomicallyUpgradeableRWLock& gpu_lock) { + const BufferAllocations& buffer_allocations, bool block_host_until_done) { TF_RETURN_IF_ERROR( CheckCompatibilityWithServiceExecutableRunOptions(run_options)); - // There isn't always an HLO module. - ModuleIdentifier unique_id = -1; - if (has_module()) { - unique_id = module().unique_id(); - } + ScopedAnnotation annotation([&] { return module_annotations_.top_level; }); + ScopedModuleAnnotations module_annotations(&module_annotations_); - ScopedAnnotationAlways annotation([&]() -> ModuleAnnotation { - if (annotation_info_) { - return annotation_info_->top_level; - } else { - return {module_name_, unique_id}; - } - }); - ModuleAnnotationManager set_current_kernel_annotations{annotation_info_}; + ModuleIdentifier unique_id = has_module() ? module().unique_id() : -1; if (thunks_) { Thunk::ExecutableSource executable_source = {text_, binary_}; + int64_t collective_max_nchannels = + has_module() ? module_config() + .debug_options() + .xla_gpu_nccl_collective_max_nchannels() + : 0; + int64_t p2p_max_nchannels = + has_module() + ? module_config().debug_options().xla_gpu_nccl_p2p_max_nchannels() + : 0; return ExecuteThunks( module_name_, unique_id, *thunks_, executable_source, run_options, @@ -1011,23 +974,8 @@ absl::Status GpuExecutable::ExecuteThunksOrXlaRuntime( .debug_options() .xla_gpu_enable_highest_priority_async_stream() : false, - &thunks_initialized_); - } - - // Match IrEmitter's temp buffer allocation for kernel launches. See - // IrEmitterUnnested::BuildKernelThunkImpl(). - const BufferAllocation* temp_buffer = nullptr; - for (const BufferAllocation& alloc : GetAllocations()) { - if (alloc.IsPreallocatedTempBuffer()) { - // Retrieve the first seen temp buffer. - if (temp_buffer == nullptr) temp_buffer = &alloc; - } - } - - if (gpu_runtime_executable_) { - return ExecuteXlaRuntime(module_name_, unique_id, *gpu_runtime_executable_, - run_options, text_, binary_, buffer_allocations, - temp_buffer, block_host_until_done, gpu_lock); + execution_stream_ids_, collective_max_nchannels, p2p_max_nchannels, + module_annotations_); } return FailedPrecondition("Expected XLA gpu executable is not supplied."); @@ -1165,211 +1113,5 @@ GetOutputInfo(const HloModule& hlo_module, const BufferAssignment& assignment) { return output; } -GpuExecutable::GpuExecutable( - std::shared_ptr hlo_module, std::string asm_text, - std::vector binary, std::vector constants, - se::GpuComputeCapability gpu_version, absl::string_view module_name, - Shape xla_output_shape, std::vector allocations, - absl::flat_hash_map output_info, - std::unique_ptr gpu_runtime_executable) - : Executable(std::move(hlo_module)), - text_(std::move(asm_text)), - binary_(std::move(binary)), - gpu_version_(gpu_version), - gpu_runtime_executable_(std::move(gpu_runtime_executable)), - module_name_(module_name), - output_shape_(xla_output_shape), - allocations_(std::move(allocations)), - constants_(std::move(constants)), - output_info_(std::move(output_info)), - enable_debug_info_manager_(true) { - if (has_module()) { - annotation_info_.emplace(module()); - XlaDebugInfoManager::Get()->RegisterModule(shared_module(), - BufferAssignmentProto()); - } -} - -// Returns a list of functions exported from the `module` that should be loaded -// from the object file. Entrypoint functions always loaded with ordinal 0. -static absl::StatusOr> -GetFunctionsToLoad(mlir::ModuleOp module, std::string_view entry) { - std::vector functions; - - // Use canonical type converter because we currently do not support any - // user-defined types in XLA:GPU executables. - runtime::TypeConverter type_converter; - - // Converts function type and adds load function metadata. In XLA:GPU exported - // function runtime signature is the same as regular signature with an extra - // execution context argument at index 0. - auto convert = [&](mlir::func::FuncOp func) -> absl::Status { - auto signature = type_converter.Convert(func.getFunctionType()); - if (!signature.ok()) - return Internal("Failed to convert entry function type: %s", - signature.status().message()); - - // TODO(ezhulenev): Copy `signature` once FunctionType is copyable. - auto rt_signature = type_converter.Convert(func.getFunctionType()); - rt_signature->insert_operand( - 0, std::make_unique()); - - functions.push_back({func.getName().str(), std::move(*signature), - std::move(*rt_signature)}); - - return absl::OkStatus(); - }; - - mlir::SymbolTable sym_table(module); - - // Load entrypoint function first at ordinal 0. - TF_CHECK_OK(convert(module.lookupSymbol(entry))); - - // Load all functions explicitly exported from the module (in XLA:GPU it's - // always CUDA graph capture functions). We explicitly sort them by ordinal, - // to make sure they are loaded in correct order. - auto export_ops = llvm::to_vector(module.getOps()); - llvm::sort(export_ops, [](runtime::ExportOp a, runtime::ExportOp b) { - return a.getOrdinal()->getSExtValue() < b.getOrdinal()->getSExtValue(); - }); - for (runtime::ExportOp exported : export_ops) { - TF_CHECK_OK(convert( - sym_table.lookup(exported.getFunctionRef()))); - } - - return functions; -} - -// Get arguments buffer sizes from the entry function signature. -static absl::StatusOr> GetBufferSizes( - runtime::FunctionType& f) { - std::vector buffer_sizes; - for (unsigned i = 0; i < f.num_operands(); ++i) { - auto* memref = llvm::dyn_cast(f.operand(i)); - - // Entry function argument must be a statically shaped 1d I8 memref. - if (memref == nullptr || memref->element_type() != PrimitiveType::S8 || - memref->rank() != 1 || runtime::MemrefType::IsDynamic(memref->size(0))) - return Internal("Illegal buffer argument type: %s", - f.operand(0)->ToString()); - - buffer_sizes.push_back(memref->size(0)); - } - return buffer_sizes; -} - -// TODO(ezhulenev): This is a copy of `GetAllocationIndices` from -// `mlir/backends/gpu/transforms/passes.h`. We can't depend on that file because -// of a dependency cycle, and this is a short term work around the cuda graph -// capture bug. This code should not survive beyond Q1 2024. -static std::vector> GetAllocationIndices( - mlir::ModuleOp module) { - std::vector> res; - - mlir::SymbolTable sym_table(module); - for (auto op : module.getOps()) { - unsigned ordinal = *op.ordinal(); - if (ordinal >= res.size()) res.resize(ordinal + 1); - - auto func = sym_table.lookup(op.getFunctionRef()); - res[ordinal].resize(func.getNumArguments(), -1); - - for (unsigned i = 0; i < func.getNumArguments(); ++i) { - auto idx = - func.getArgAttrOfType(i, "rt.allocation_index"); - if (idx) res[ordinal][i] = idx.getInt(); - } - } - - return res; -} - -absl::StatusOr> GpuExecutable::LoadFromObjFile( - std::shared_ptr hlo_module, absl::string_view obj_file, - absl::string_view mlir_module, DebugOptions debug_options, - absl::string_view asm_text, absl::string_view binary, - std::vector constants, se::GpuComputeCapability gpu_version) { - VLOG(1) << "Load serialized Gpu executable from object file: module=" - << hlo_module->name(); - - std::string_view entry = hlo_module->entry_computation()->name(); - - // Load MLIR module behind the compiled object file to recover XLA allocations - // and output info details. Also recover buffer sizes from the entrypoint - // function signature. - mlir::MLIRContext context; - runtime::AppendXlaGpuDialectRegistry(context); - - auto module = mlir::parseSourceString(mlir_module, &context); - if (!module) return Internal("Failed to parse AOT compiled module"); - - // Get the list of functions to be loaded from the object file. - TF_ASSIGN_OR_RETURN(std::vector functions, - GetFunctionsToLoad(*module, entry)); - VLOG(2) << "Found " << functions.size() << " functions to load"; - - // Get the buffer sizes from the entry function signature. - TF_ASSIGN_OR_RETURN(std::vector buffer_sizes, - GetBufferSizes(functions[0].signature)); - - // Get allocation indices from graph capture functions. - auto allocation_indices = GetAllocationIndices(*module); - - // Get the XLA module entrypoint function. - auto func = mlir::cast(module->lookupSymbol(entry)); - - // Infer XLA allocations and output info from the MLIR module. - std::vector allocations; - absl::flat_hash_map output_info; - Shape result_xla_shape; - TF_RETURN_IF_ERROR(SetUpMlirAllocation(func, buffer_sizes, &allocations, - &output_info, &result_xla_shape)); - - // Create a named buffer from compiled object file. - llvm::StringRef data(obj_file.data(), obj_file.size()); - auto buffer = llvm::MemoryBuffer::getMemBuffer(data, hlo_module->name()); - - auto symbol_map = runtime::ToSymbolsBinding(RegisterXlaGpuRuntimeCustomCalls, - RegisterXlaGpuTypeIdNames); - - // Load XLA Runtime executable from an object file, and link it with Gpu - // runtime intrinsics implementing Gpu custom calls. - auto executable = runtime::Executable::LoadFromObjFile( - hlo_module->name(), std::move(buffer), std::move(functions), symbol_map); - - if (!executable.ok()) - return Internal("Failed to load XLA Runtime executable: %s", - executable.status().message()); - - // Move runtime::Executable ownership to the GpuRuntimeExecutable. - TF_ASSIGN_OR_RETURN(auto gpu_runtime_executable, - GpuRuntimeExecutable::Create( - hlo_module->name(), std::move(buffer_sizes), - std::move(allocation_indices), std::move(*executable), - std::move(debug_options))); - - // Construct GpuExecutable for the loaded XLA Runtime executable. - std::string name = hlo_module->name(); - std::string asm_text_string = std::string(asm_text); - std::vector binary_vector(binary.begin(), binary.end()); - return std::unique_ptr(new GpuExecutable( - std::move(hlo_module), std::move(asm_text_string), - std::move(binary_vector), std::move(constants), gpu_version, name, - result_xla_shape, std::move(allocations), std::move(output_info), - std::move(gpu_runtime_executable))); -} - -absl::StatusOr GpuExecutable::GetObjFile() const { - if (!gpu_runtime_executable_) - return Internal("gpu_runtime_executable is null"); - return gpu_runtime_executable_->GetObjFile(); -} - -absl::StatusOr GpuExecutable::GetMlirModule() const { - if (!gpu_runtime_executable_) - return Internal("gpu_runtime_executable is null"); - return gpu_runtime_executable_->GetMlirModule(); -} - } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/gpu_executable.h b/third_party/xla/xla/service/gpu/gpu_executable.h index b3acb1eb5abe8b..fb79d6c1330af4 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.h +++ b/third_party/xla/xla/service/gpu/gpu_executable.h @@ -16,35 +16,37 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_GPU_EXECUTABLE_H_ #define XLA_SERVICE_GPU_GPU_EXECUTABLE_H_ -#include #include -#include -#include #include #include #include -#include -#include #include #include +#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" -#include "absl/strings/string_view.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "absl/types/variant.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/buffer_assignment.h" #include "xla/service/executable.h" #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/non_atomically_upgradeable_rw_lock.h" #include "xla/service/gpu/runtime/annotation.h" -#include "xla/service/gpu/runtime/executable.h" #include "xla/service/gpu/thunk.h" #include "xla/service/hlo_execution_profile.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/rendezvous.h" +#include "xla/service/service_executable_run_options.h" #include "xla/service/shaped_buffer.h" #include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream_executor.h" @@ -63,7 +65,6 @@ bool IsXlaRuntimeExecutableEnabled(const HloModuleConfig& config); class GpuExecutable : public Executable { public: using OwnedThunkSequence = std::unique_ptr; - using OwnedGpuRuntimeProgram = std::unique_ptr; struct ConstantInfo { std::string symbol_name; @@ -87,10 +88,7 @@ class GpuExecutable : public Executable { std::string asm_text; std::vector binary; se::GpuComputeCapability gpu_version; - // The GpuExecutable will either execute Thunks, XLA runtime executable - // (native function) or experimental XLA runtime executable (IREE VM - // function) depending on which is supplied. - std::variant executable; + OwnedThunkSequence executable; std::vector constants; absl::flat_hash_map output_info; std::string module_name; @@ -104,36 +102,12 @@ class GpuExecutable : public Executable { // Analyze the entry function to construct buffer allocation and other output // information. - // - // TODO(ezhulenev): Once Xla runtime enabled by default, hide this method as - // an implementation detail of GpuExecutable. static absl::Status SetUpMlirAllocation( mlir::func::FuncOp func, llvm::ArrayRef buffer_sizes, std::vector* allocations, absl::flat_hash_map* output_info, Shape* output_shape); - // Returns an Executable that is loaded from an object file (XLA program - // compiled to a native function using the XLA Runtime stack). - static absl::StatusOr> LoadFromObjFile( - std::shared_ptr hlo_module, absl::string_view obj_file, - absl::string_view mlir_module, DebugOptions debug_options, - absl::string_view asm_text, absl::string_view binary, - std::vector constants, - se::GpuComputeCapability gpu_version); - - // Constructor to use when loading a GpuExecutable from an object file (native - // function compiled for XLA Runtime). Omits setting class members that aren't - // used in XLA Runtime execution mode. - GpuExecutable(std::shared_ptr hlo_module, std::string asm_text, - std::vector binary, - std::vector constants, - se::GpuComputeCapability gpu_version, - absl::string_view module_name, Shape xla_output_shape, - std::vector allocations, - absl::flat_hash_map output_info, - std::unique_ptr runtime_executable); - static absl::StatusOr> Create(Params params); ~GpuExecutable() override; @@ -197,15 +171,8 @@ class GpuExecutable : public Executable { : buffer_assignment_->Allocations(); } - bool IsXlaRuntimeEnabled() const { - return gpu_runtime_executable_ != nullptr; - } - const std::vector& constants() const { return constants_; } - absl::StatusOr GetObjFile() const; - absl::StatusOr GetMlirModule() const; - const BufferAssignment* buffer_assignment() const { return buffer_assignment_.get(); } @@ -221,8 +188,7 @@ class GpuExecutable : public Executable { // GPU execution completes. absl::Status ExecuteThunksOrXlaRuntime( const ServiceExecutableRunOptions* run_options, - const BufferAllocations& buffer_allocations, bool block_host_until_done, - NonAtomicallyUpgradeableRWLock& gpu_lock); + const BufferAllocations& buffer_allocations, bool block_host_until_done); using BufferAllocToDeviceMemoryMap = absl::flat_hash_map; @@ -272,11 +238,8 @@ class GpuExecutable : public Executable { // compute_capability_. // // May be empty, in which case we leave compilation up to the GPU driver. -#ifdef TENSORFLOW_USE_ROCM std::vector binary_; -#else - const std::vector binary_; -#endif + // The GPU version for compute compatibility check. se::GpuComputeCapability gpu_version_; @@ -284,18 +247,8 @@ class GpuExecutable : public Executable { // IrEmitter (null if XLA:GPU runtime is enabled). OwnedThunkSequence thunks_; - // A flag that signals if `thunks_` initialization is completed. Thunks - // initialization might allocate new control data structures on device, which - // can lead to deadlocks if executed concurrently with other replicas. - // - // We use atomic CAS operations to decide if a thread running XLA executable - // should join the rendezvous, see implementation for details. - std::atomic thunks_initialized_{0}; - - // Gpu runtime executable that encapsulates all the state for running Gpu - // runtime custom calls implementing gpu abstraction layer (available only if - // Xla runtime is enabled). - std::unique_ptr gpu_runtime_executable_; + // Additional execution streams requested by `thunks_`. + absl::flat_hash_set execution_stream_ids_; std::string module_name_; @@ -313,7 +266,12 @@ class GpuExecutable : public Executable { // This object is also used for dumping debug info. std::unique_ptr buffer_assignment_; - std::optional annotation_info_; + ModuleAnnotations module_annotations_ = [this] { + if (has_module()) { + return ModuleAnnotations(module()); + } + return ModuleAnnotations(module_name_); + }(); int64_t debug_buffer_assignment_show_max_; diff --git a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h b/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h index ce2f8c8a50e884..c9d19ab0e1ebd7 100644 --- a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h +++ b/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h @@ -23,10 +23,12 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" +#include "xla/shape.h" #include "xla/status.h" #include "xla/statusor.h" #include "xla/stream_executor/dnn.h" @@ -53,7 +55,7 @@ struct GpufMHADescriptor { Shape rhs_bmm2_shape; Shape intermediate_lhs_bmm2_shape; // This will contain both output shape and activation shape - std::vector output_shapes; + absl::InlinedVector output_shapes; DotDimensionNumbers bmm1_dnums; DotDimensionNumbers bmm2_dnums; diff --git a/third_party/xla/xla/service/gpu/gpu_fusible.cc b/third_party/xla/xla/service/gpu/gpu_fusible.cc index 2b880f0ab50320..c7baf4db8eb4f8 100644 --- a/third_party/xla/xla/service/gpu/gpu_fusible.cc +++ b/third_party/xla/xla/service/gpu/gpu_fusible.cc @@ -602,25 +602,6 @@ int64_t SharedMemoryUsage(const HloInstruction& instr, FusionInfoCache* cache) { return it->second; } -int64_t ReductionProjectedShmemUsageBytes( - const ReductionDimensions& reduction_dimensions, - const std::vector>& instr_index_groups) { - int64_t out = 0; - // Different groups are computed in parallel on different blocks, so they are - // not sharing the shmem budget. The overall usage is given by the largest - // one. - for (const auto& group : instr_index_groups) { - int64_t sum = 0; - for (const HloInstruction* root : group) { - if (IsReductionFromOrToContiguousDimensions(*root)) { - sum += SharedMemoryUsage(*root); - } - } - out = std::max(out, sum); - } - return out; -} - // Codegen'ing unnested reductions requires a lot of registers, so a MOF // combining many of those runs a high risk of spilling. constexpr int64_t kMaxUnnestedReductionOutputsPerFusion = 8; diff --git a/third_party/xla/xla/service/gpu/gpu_fusible.h b/third_party/xla/xla/service/gpu/gpu_fusible.h index e562759e3d9ae8..2ca3c1afa00a4d 100644 --- a/third_party/xla/xla/service/gpu/gpu_fusible.h +++ b/third_party/xla/xla/service/gpu/gpu_fusible.h @@ -64,11 +64,6 @@ struct FusionInfoCache { int64_t SharedMemoryUsage(const HloInstruction& instr, FusionInfoCache* cache = nullptr); -// Returns projected shared memory usage of a reduction fusion. -int64_t ReductionProjectedShmemUsageBytes( - const ReductionDimensions& reduction_dimensions, - const std::vector>& instr_index_groups); - inline constexpr int64_t MaxOperandsAndOutputsPerFusion() { return 64; } // Whether the op transposes the physical data layout. Fusing such ops may lead diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc index 295505b8432339..0aa23b43637f3e 100644 --- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc @@ -657,7 +657,7 @@ static int64_t GetSchedulerMemoryLimit( const HloModule* module, const se::DeviceDescription& gpu_device_info, int pointer_size); -StatusOr ScheduleGpuModule( +absl::StatusOr ScheduleGpuModule( HloModule* module, int64_t pointer_size, const se::DeviceDescription& gpu_device_info) { int64_t memory_limit = @@ -679,10 +679,9 @@ StatusOr ScheduleGpuModule( // instruction name with ids. std::string fingerprint = module->GetFingerprint128( HloPrintOptions::Canonical().set_print_backend_config(true)); - HloInstruction* root = module->entry_computation()->root_instruction(); FrontendAttributes attributes; (*attributes.mutable_map())[std::string(kFingerprintBeforeLHS)] = fingerprint; - root->add_frontend_attributes(attributes); + module->add_frontend_attributes(attributes); VLOG(1) << "Fingerprint before LHS for module " << module->name() << "(" << module->unique_id() << ") = " << fingerprint; diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.h b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.h index 21cae7bd8c11af..142f6655061499 100644 --- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.h +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.h @@ -29,7 +29,7 @@ struct ScheduleMetadata { }; // Determines the schedule of HLO instructions for a module run on the GPU. -StatusOr ScheduleGpuModule( +absl::StatusOr ScheduleGpuModule( HloModule* module, int64_t pointer_size, const se::DeviceDescription& gpu_device_info); diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc index fc9ceacf3c7b60..a9007a3cba8d5e 100644 --- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -87,9 +87,7 @@ class GpuHloScheduleTest : public HloTestBase { static bool HasValidFingerprint(HloModule* module) { // Verify that the fingerprint of HLO prior to LHS is present. - const HloInstruction* root = - module->entry_computation()->root_instruction(); - const FrontendAttributes& attrs = root->frontend_attributes(); + const FrontendAttributes& attrs = module->frontend_attributes(); auto it = attrs.map().find(kFingerprintBeforeLHS); // The fingerprint is 128 bits stored as a hex string (128/4 hex digits). diff --git a/third_party/xla/xla/service/gpu/gpu_norm_runner.cc b/third_party/xla/xla/service/gpu/gpu_norm_runner.cc index 65ce6b82e8c5de..676f00ea2229d9 100644 --- a/third_party/xla/xla/service/gpu/gpu_norm_runner.cc +++ b/third_party/xla/xla/service/gpu/gpu_norm_runner.cc @@ -27,38 +27,60 @@ namespace xla { namespace gpu { absl::Status RunGpuNorm(const gpu::GpuNormConfig& config, - const se::DeviceMemoryBase& input_buffer, + const se::DeviceMemoryBase& x_buffer, const se::DeviceMemoryBase& scale_buffer, - const se::DeviceMemoryBase& bias_buffer, - const se::DeviceMemoryBase& output_buffer, + const se::DeviceMemoryBase& y_or_dx_buffer, + std::optional bias_buffer, + std::optional dy_buffer, std::optional expectation_buffer, std::optional norm_factor_buffer, + std::optional dscale_buffer, + std::optional dbias_buffer, const se::DeviceMemoryBase& scratch_memory, se::Stream* stream, RunNormOptions options) { se::dnn::LazyOpRunner* lazy_runner = options.norm_runner->AsNormRunner(); std::optional> local_runner; - se::dnn::NormOp::Config ln_config{config.epsilon, - config.input_descriptor, + TF_ASSIGN_OR_RETURN(se::dnn::NormKind kind, + GetDNNNormKindFromCudnnNormKind(config.kind)); + + se::dnn::NormOp::Config ln_config{kind, + config.epsilon, + config.x_descriptor, config.scale_descriptor, + config.y_or_dx_descriptor, config.bias_descriptor, - config.output_descriptor, + config.dy_descriptor, config.expectation_descriptor, - config.norm_factor_descriptor}; + config.norm_factor_descriptor, + config.dscale_descriptor, + config.dbias_descriptor}; TF_ASSIGN_OR_RETURN(auto* runner, lazy_runner->GetOrCreateRunner(ln_config, stream)); std::vector operands; - operands.emplace_back(input_buffer); + operands.emplace_back(x_buffer); operands.emplace_back(scale_buffer); - operands.emplace_back(bias_buffer); - operands.emplace_back(output_buffer); - if (expectation_buffer) { + operands.emplace_back(y_or_dx_buffer); + + // The remaining operands are composed of inputs followed by outputs of the + // library call. The expectation and norm factor are outputs of the forward + // training layer norm, and inputs of the backward layer norm. + if (config.kind == CudnnNormKind::kLayerForwardInfer || + config.kind == CudnnNormKind::kLayerForwardTrain) { + operands.emplace_back(bias_buffer.value()); + } + if (config.kind == CudnnNormKind::kLayerForwardTrain) { operands.emplace_back(expectation_buffer.value()); + operands.emplace_back(norm_factor_buffer.value()); } - if (norm_factor_buffer) { + if (config.kind == CudnnNormKind::kLayerBackward) { + operands.emplace_back(dy_buffer.value()); + operands.emplace_back(expectation_buffer.value()); operands.emplace_back(norm_factor_buffer.value()); + operands.emplace_back(dscale_buffer.value()); + operands.emplace_back(dbias_buffer.value()); } return (*runner)(stream, options.profile_result, scratch_memory, operands); diff --git a/third_party/xla/xla/service/gpu/gpu_norm_runner.h b/third_party/xla/xla/service/gpu/gpu_norm_runner.h index 6452cae6c1844f..442fa2b1b67f7a 100644 --- a/third_party/xla/xla/service/gpu/gpu_norm_runner.h +++ b/third_party/xla/xla/service/gpu/gpu_norm_runner.h @@ -36,26 +36,45 @@ limitations under the License. namespace xla { namespace gpu { +inline absl::StatusOr AsCudnnNormKind( + xla::gpu::CudnnNormBackendConfig_Kind kind) { + switch (kind) { + case xla::gpu::CudnnNormBackendConfig::LAYER_FWD_INFER: + return xla::gpu::CudnnNormKind::kLayerForwardInfer; + case xla::gpu::CudnnNormBackendConfig::LAYER_FWD_TRAIN: + return xla::gpu::CudnnNormKind::kLayerForwardTrain; + case xla::gpu::CudnnNormBackendConfig::LAYER_BWD: + return xla::gpu::CudnnNormKind::kLayerBackward; + default: + return xla::Internal("Unknown norm kind."); + } +} + // Intermediate structure used as input to construct GpuNormConfig. struct GpuNormDescriptor { CudnnNormBackendConfig backend_config; - Shape input_shape; + Shape x_shape; Shape scale_shape; - Shape bias_shape; - Shape output_shape; + std::optional bias_shape; + Shape y_or_dx_shape; std::optional expectation_shape; std::optional norm_factor_shape; + std::optional dy_shape; + std::optional dscale_shape; + std::optional dbias_shape; size_t scratch_size; }; // Structure to describe static properties of a fused norm op. struct GpuNormConfig { static absl::StatusOr For(const GpuNormDescriptor& desc) { - std::vector output_types; + std::vector y_or_dx_types; GpuNormConfig config; config.epsilon = desc.backend_config.epsilon(); config.algorithm = se::dnn::AlgorithmDesc(desc.backend_config.algorithm()); + TF_ASSIGN_OR_RETURN(config.kind, + AsCudnnNormKind(desc.backend_config.kind())); auto tensor_descriptor_from_shape = [](Shape shape) -> absl::StatusOr { @@ -66,35 +85,49 @@ struct GpuNormConfig { shape.layout().minor_to_major()); }; - TF_ASSIGN_OR_RETURN(config.input_descriptor, - tensor_descriptor_from_shape(desc.input_shape)); + TF_ASSIGN_OR_RETURN(config.x_descriptor, + tensor_descriptor_from_shape(desc.x_shape)); TF_ASSIGN_OR_RETURN(config.scale_descriptor, tensor_descriptor_from_shape(desc.scale_shape)); - TF_ASSIGN_OR_RETURN(config.bias_descriptor, - tensor_descriptor_from_shape(desc.bias_shape)); - TF_ASSIGN_OR_RETURN(config.output_descriptor, - tensor_descriptor_from_shape(desc.output_shape)); + TF_ASSIGN_OR_RETURN(config.y_or_dx_descriptor, + tensor_descriptor_from_shape(desc.y_or_dx_shape)); + if (desc.bias_shape) { + TF_ASSIGN_OR_RETURN(config.bias_descriptor, tensor_descriptor_from_shape( + desc.bias_shape.value())); + } if (desc.expectation_shape) { TF_ASSIGN_OR_RETURN( config.expectation_descriptor, tensor_descriptor_from_shape(desc.expectation_shape.value())); - } - if (desc.norm_factor_shape) { TF_ASSIGN_OR_RETURN( config.norm_factor_descriptor, tensor_descriptor_from_shape(desc.norm_factor_shape.value())); } + if (desc.dscale_shape) { + TF_ASSIGN_OR_RETURN(config.dy_descriptor, + tensor_descriptor_from_shape(desc.dy_shape.value())); + TF_ASSIGN_OR_RETURN( + config.dscale_descriptor, + tensor_descriptor_from_shape(desc.dscale_shape.value())); + TF_ASSIGN_OR_RETURN( + config.dbias_descriptor, + tensor_descriptor_from_shape(desc.dbias_shape.value())); + } return config; } double epsilon; + CudnnNormKind kind; se::dnn::AlgorithmDesc algorithm; - se::dnn::TensorDescriptor input_descriptor; + se::dnn::TensorDescriptor x_descriptor; se::dnn::TensorDescriptor scale_descriptor; - se::dnn::TensorDescriptor bias_descriptor; - se::dnn::TensorDescriptor output_descriptor; + std::optional bias_descriptor; + se::dnn::TensorDescriptor y_or_dx_descriptor; std::optional expectation_descriptor; std::optional norm_factor_descriptor; + std::optional dy_descriptor; + std::optional dscale_descriptor; + std::optional dbias_descriptor; }; class NormRunner { @@ -127,13 +160,16 @@ struct RunNormOptions { NormRunner* norm_runner; }; -absl::Status RunGpuNorm(const gpu::GpuNormConfig& config, - const se::DeviceMemoryBase& input_buffer, +absl::Status RunGpuNorm(const GpuNormConfig& conv_config, + const se::DeviceMemoryBase& x_buffer, const se::DeviceMemoryBase& scale_buffer, - const se::DeviceMemoryBase& bias_buffer, - const se::DeviceMemoryBase& output_buffer, - std::optional expectation_buffer, + const se::DeviceMemoryBase& y_or_dx_buffer, + std::optional bias_buffer, + std::optional dy_buffer, + std::optional exepctation_buffer, std::optional norm_factor_buffer, + std::optional dscale_buffer, + std::optional dbias_buffer, const se::DeviceMemoryBase& scratch_memory, se::Stream* stream, RunNormOptions options = {}); diff --git a/third_party/xla/xla/service/gpu/gpu_prim_cuda.h b/third_party/xla/xla/service/gpu/gpu_prim.h similarity index 72% rename from third_party/xla/xla/service/gpu/gpu_prim_cuda.h rename to third_party/xla/xla/service/gpu/gpu_prim.h index aefa395d08ed0e..5e7daa3d86e02d 100644 --- a/third_party/xla/xla/service/gpu/gpu_prim_cuda.h +++ b/third_party/xla/xla/service/gpu/gpu_prim.h @@ -11,8 +11,8 @@ BASIS, WITHOUT OF ANY KIND WARRANTIES OR CONDITIONS, either express or implied. For the specific language governing permissions and limitations under the license, the license you must see. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPU_PRIM_CUDA_H_ -#define XLA_SERVICE_GPU_GPU_PRIM_CUDA_H_ +#ifndef XLA_SERVICE_GPU_GPU_PRIM_H_ +#define XLA_SERVICE_GPU_GPU_PRIM_H_ #include "tsl/platform/bfloat16.h" @@ -77,6 +77,42 @@ struct NumericTraits /*_NULL_TYPE=*/false, /*_UnsignedBits=*/uint16_t, /*T=*/tsl::bfloat16> {}; } // namespace cub -#endif // GOOGLE_CUDA +#elif TENSORFLOW_USE_ROCM -#endif // XLA_SERVICE_GPU_GPU_PRIM_CUDA_H_ +#include "rocm/include/hipcub/hipcub.hpp" +#include "rocm/rocm_config.h" +namespace gpuprim = ::hipcub; + +// Required for sorting Eigen::half and bfloat16. +namespace rocprim { +namespace detail { + +#if (TF_ROCM_VERSION >= 50200) +template <> +struct float_bit_mask { + static constexpr uint16_t sign_bit = 0x8000; + static constexpr uint16_t exponent = 0x7C00; + static constexpr uint16_t mantissa = 0x03FF; + using bit_type = uint16_t; +}; + +template <> +struct float_bit_mask { + static constexpr uint16_t sign_bit = 0x8000; + static constexpr uint16_t exponent = 0x7F80; + static constexpr uint16_t mantissa = 0x007F; + using bit_type = uint16_t; +}; +#endif // TF_ROCM_VERSION >= 50200 +template <> +struct radix_key_codec_base + : radix_key_codec_floating {}; +template <> +struct radix_key_codec_base + : radix_key_codec_floating {}; +}; // namespace detail +}; // namespace rocprim + +#endif // TENSORFLOW_USE_ROCM + +#endif // XLA_SERVICE_GPU_GPU_PRIM_H_ diff --git a/third_party/xla/xla/service/gpu/gpu_prim_rocm.h b/third_party/xla/xla/service/gpu/gpu_prim_rocm.h deleted file mode 100644 index 647f67d686bf97..00000000000000 --- a/third_party/xla/xla/service/gpu/gpu_prim_rocm.h +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -To in writing unless required by applicable law or agreed, -distributed on an, software distributed under the license is "AS IS" -BASIS, WITHOUT OF ANY KIND WARRANTIES OR CONDITIONS, either express -or implied. For the specific language governing permissions and -limitations under the license, the license you must see. -==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPU_PRIM_ROCM_H_ -#define XLA_SERVICE_GPU_GPU_PRIM_ROCM_H_ - -#include "tsl/platform/bfloat16.h" - -#if TENSORFLOW_USE_ROCM - -#include "rocm/include/hipcub/hipcub.hpp" -#include "rocm/rocm_config.h" -namespace gpuprim = ::hipcub; - -// Required for sorting Eigen::half and bfloat16. -namespace rocprim { -namespace detail { - -#if (TF_ROCM_VERSION >= 50200) -template <> -struct float_bit_mask { - static constexpr uint16_t sign_bit = 0x8000; - static constexpr uint16_t exponent = 0x7C00; - static constexpr uint16_t mantissa = 0x03FF; - using bit_type = uint16_t; -}; - -template <> -struct float_bit_mask { - static constexpr uint16_t sign_bit = 0x8000; - static constexpr uint16_t exponent = 0x7F80; - static constexpr uint16_t mantissa = 0x007F; - using bit_type = uint16_t; -}; -#endif // TF_ROCM_VERSION >= 50200 -template <> -struct radix_key_codec_base - : radix_key_codec_floating {}; -template <> -struct radix_key_codec_base - : radix_key_codec_floating {}; -}; // namespace detail -}; // namespace rocprim - -#endif // TENSORFLOW_USE_ROCM -#endif // XLA_SERVICE_GPU_GPU_PRIM_ROCM_H_ diff --git a/third_party/xla/xla/service/gpu/gpu_sort_rewriter.cc b/third_party/xla/xla/service/gpu/gpu_sort_rewriter.cc index 0cce7eebab7737..04d874979b9d43 100644 --- a/third_party/xla/xla/service/gpu/gpu_sort_rewriter.cc +++ b/third_party/xla/xla/service/gpu/gpu_sort_rewriter.cc @@ -32,7 +32,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/gpu/cublas_cudnn.h" -#include "xla/service/gpu/runtime3/cub_sort_thunk.h" +#include "xla/service/gpu/runtime/cub_sort_thunk.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/statusor.h" diff --git a/third_party/xla/xla/service/gpu/gpu_transfer_manager.cc b/third_party/xla/xla/service/gpu/gpu_transfer_manager.cc index 4ef4ece8dc957a..b890c8b63ca156 100644 --- a/third_party/xla/xla/service/gpu/gpu_transfer_manager.cc +++ b/third_party/xla/xla/service/gpu/gpu_transfer_manager.cc @@ -15,32 +15,43 @@ limitations under the License. #include "xla/service/gpu/gpu_transfer_manager.h" +#include +#include +#include +#include #include -#include #include #include #include "absl/cleanup/cleanup.h" +#include "absl/functional/function_ref.h" #include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" #include "llvm/IR/DataLayout.h" #include "xla/literal.h" -#include "xla/literal_util.h" #include "xla/service/compiler.h" +#include "xla/service/generic_transfer_manager.h" +#include "xla/service/gpu/infeed_manager.h" #include "xla/service/gpu/outfeed_manager.h" #include "xla/service/gpu/target_constants.h" +#include "xla/service/shaped_buffer.h" +#include "xla/service/transfer_manager.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" -#include "xla/stream_executor/host/host_platform_id.h" -#include "xla/stream_executor/multi_platform_manager.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/memory_allocation.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" #include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" +#include "tsl/platform/numbers.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -52,12 +63,6 @@ GpuTransferManager::GpuTransferManager(se::Platform::Id id, unsigned pointer_size) : GenericTransferManager(id, pointer_size) {} -GpuTransferManager::~GpuTransferManager() { - if (pinned_chunk_se_) { - pinned_chunk_se_->HostMemoryDeallocate(pinned_chunk_); - } -} - absl::Status GpuTransferManager::TransferLiteralToInfeed( se::StreamExecutor* executor, const LiteralSlice& literal) { return gpu::GetOrCreateInfeedManager(executor)->TransferLiteralToInfeed( @@ -70,21 +75,25 @@ absl::Status GpuTransferManager::TransferLiteralFromOutfeed( executor, literal); } -void GpuTransferManager::EnsurePinnedBuffersAllocated( +absl::Status GpuTransferManager::EnsurePinnedBuffersAllocated( se::StreamExecutor* executor) { if (pinned_chunk_ != nullptr) { - return; + return absl::OkStatus(); } + TF_ASSIGN_OR_RETURN(pinned_chunk_, + executor->HostMemoryAllocate(kPinnedChunkBytes)); pinned_chunk_se_ = executor; - pinned_chunk_ = - reinterpret_cast(executor->HostMemoryAllocate(kPinnedChunkBytes)); + static_assert(kPinnedChunkBytes % kPinnedBufferBytes == 0, "assumption of loop below"); - for (char* buf = pinned_chunk_; buf < pinned_chunk_ + kPinnedChunkBytes; + char* base = reinterpret_cast(pinned_chunk_->opaque()); + for (char* buf = base; buf < base + kPinnedChunkBytes; buf += kPinnedBufferBytes) { pinned_buffers_.push_back(buf); } + + return absl::OkStatus(); } absl::Status GpuTransferManager::ReadDynamicShapes( @@ -147,7 +156,7 @@ absl::Status GpuTransferManager::ReadDynamicShapes( { absl::MutexLock lock(&mu_); - EnsurePinnedBuffersAllocated(stream->parent()); + TF_RETURN_IF_ERROR(EnsurePinnedBuffersAllocated(stream->parent())); for (const auto& src_dst : copies) { se::DeviceMemoryBase src = src_dst.first; @@ -172,7 +181,7 @@ absl::Status GpuTransferManager::ReadDynamicShapes( for (int i = 0; i < copies.size(); i++) { se::DeviceMemoryBase src = copies[i].first; void* dst = h2d_memcpy_dsts[i]; - stream->ThenMemcpy(dst, src, src.size()); + TF_RETURN_IF_ERROR(stream->Memcpy(dst, src, src.size())); } // Wait for all the async copies to complete, then write into device_shape. @@ -191,6 +200,133 @@ absl::Status GpuTransferManager::ReadDynamicShapes( return absl::OkStatus(); } +// Chunks `size` into chunks of `chunk_size` and calls `callback` for each. +static absl::Status ForEachChunk( + size_t size, size_t chunk_size, + absl::FunctionRef + callback) { + int64_t num_chunks = CeilOfRatio(size, chunk_size); + + for (int64_t chunk_index = 0; chunk_index < num_chunks; ++chunk_index) { + TF_RETURN_IF_ERROR(callback( + /*chunk_offset=*/chunk_index * chunk_size, + /*chunk_size=*/std::min(chunk_size, size - chunk_index * chunk_size))); + } + return absl::OkStatus(); +} + +absl::Status GpuTransferManager::TransferBufferFromDevice( + se::Stream* stream, const se::DeviceMemoryBase& source, int64_t size, + void* destination) { + if (source.size() < size) { + return absl::FailedPreconditionError(absl::StrFormat( + "Source allocation on device not large enough for data transfer: " + "%d < %d", + source.size(), size)); + } + + VLOG(5) << "Transfer buffer from device: size=" + << tsl::strings::HumanReadableNumBytes(size); + + TF_ASSIGN_OR_RETURN(auto staging_buffer, + GetOrCreateStagingBuffer(stream->parent())); + + absl::MutexLock lock(&staging_buffer->mutex); + void* staging = staging_buffer->allocation->opaque(); + + // Transfer chunk of data from device to destination via staging buffer. + auto transfer_chunk = [&](size_t chunk_offset, + size_t chunk_size) -> absl::Status { + VLOG(5) << "Transfer buffer chunk from device: offset=" << chunk_offset + << " size=" << tsl::strings::HumanReadableNumBytes(chunk_size); + + se::DeviceMemoryBase chunk = source.GetByteSlice(chunk_offset, chunk_size); + TF_RETURN_IF_ERROR(stream->Memcpy(staging, chunk, chunk_size)); + + void* dst = reinterpret_cast(destination) + chunk_offset; + return stream->DoHostCallback( + [=] { std::memcpy(dst, staging, chunk_size); }); + }; + + TF_RETURN_IF_ERROR(stream->WaitFor(staging_buffer->transfer_completed.get())); + TF_RETURN_IF_ERROR(ForEachChunk(size, kStagingBufferSize, transfer_chunk)); + TF_RETURN_IF_ERROR( + stream->RecordEvent(staging_buffer->transfer_completed.get())); + + return absl::OkStatus(); +} + +absl::Status GpuTransferManager::TransferBufferToDevice( + se::Stream* stream, int64_t size, const void* source, + se::DeviceMemoryBase* destination) { + if (destination->size() < size) { + return absl::FailedPreconditionError(absl::StrFormat( + "Destination allocation on device not large enough for data transfer: " + "%d < %d", + destination->size(), size)); + } + + VLOG(5) << "Transfer buffer to device: size=" + << tsl::strings::HumanReadableNumBytes(size); + + TF_ASSIGN_OR_RETURN(auto staging_buffer, + GetOrCreateStagingBuffer(stream->parent())); + + absl::MutexLock lock(&staging_buffer->mutex); + void* staging = staging_buffer->allocation->opaque(); + + // Transfer chunk of data from device to destination. + auto transfer_chunk = [&](size_t chunk_offset, size_t chunk_size) { + VLOG(5) << "Transfer buffer chunk to device: offset=" << chunk_offset + << " size=" << tsl::strings::HumanReadableNumBytes(chunk_size); + + const void* src = reinterpret_cast(source) + chunk_offset; + TF_RETURN_IF_ERROR( + stream->DoHostCallback([=] { std::memcpy(staging, src, chunk_size); })); + + auto chunk = destination->GetByteSlice(chunk_offset, chunk_size); + return stream->Memcpy(&chunk, staging, chunk_size); + }; + + TF_RETURN_IF_ERROR(stream->WaitFor(staging_buffer->transfer_completed.get())); + TF_RETURN_IF_ERROR(ForEachChunk(size, kStagingBufferSize, transfer_chunk)); + TF_RETURN_IF_ERROR( + stream->RecordEvent(staging_buffer->transfer_completed.get())); + + return absl::OkStatus(); +} + +GpuTransferManager::StagingBuffer::StagingBuffer( + std::unique_ptr allocation, + std::unique_ptr transfer_completed) + : allocation(std::move(allocation)), + transfer_completed(std::move(transfer_completed)) {} + +absl::StatusOr +GpuTransferManager::GetOrCreateStagingBuffer(se::StreamExecutor* executor) { + absl::MutexLock lock(&mutex_); + if (auto it = staging_buffers_.find(executor); it != staging_buffers_.end()) { + return &it->second; + } + + VLOG(3) << absl::StreamFormat( + "Allocate staging buffer of %s for executor %p (device_ordinal=%d)", + tsl::strings::HumanReadableNumBytes(kStagingBufferSize), executor, + executor->device_ordinal()); + + TF_ASSIGN_OR_RETURN(auto staging_buffer, + executor->HostMemoryAllocate(kStagingBufferSize)); + + auto transfer_completed = std::make_unique(executor); + if (!transfer_completed->Init()) { + return absl::InternalError("Failed to initialize transfer completed event"); + } + + auto emplaced = staging_buffers_.try_emplace( + executor, std::move(staging_buffer), std::move(transfer_completed)); + return &emplaced.first->second; +} + } // namespace gpu } // namespace xla @@ -215,4 +351,5 @@ static bool InitModule() { stream_executor::rocm::kROCmPlatformId, &CreateAMDGPUTransferManager); return true; } + static bool module_initialized = InitModule(); diff --git a/third_party/xla/xla/service/gpu/gpu_transfer_manager.h b/third_party/xla/xla/service/gpu/gpu_transfer_manager.h index 67a94f09388e97..3ec0e6da5f816d 100644 --- a/third_party/xla/xla/service/gpu/gpu_transfer_manager.h +++ b/third_party/xla/xla/service/gpu/gpu_transfer_manager.h @@ -16,13 +16,23 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_GPU_TRANSFER_MANAGER_H_ #define XLA_SERVICE_GPU_GPU_TRANSFER_MANAGER_H_ +#include +#include #include +#include "absl/base/thread_annotations.h" +#include "absl/container/node_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "xla/literal.h" #include "xla/service/generic_transfer_manager.h" -#include "xla/service/gpu/infeed_manager.h" -#include "xla/service/transfer_manager.h" -#include "xla/shape_tree.h" -#include "xla/statusor.h" +#include "xla/service/shaped_buffer.h" +#include "xla/shape.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/memory_allocation.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "xla/xla_data.pb.h" @@ -34,7 +44,6 @@ namespace gpu { class GpuTransferManager : public GenericTransferManager { public: GpuTransferManager(se::Platform::Id id, unsigned pointer_size); - ~GpuTransferManager() override; absl::Status TransferLiteralToInfeed(se::StreamExecutor* executor, const LiteralSlice& literal) override; @@ -45,11 +54,43 @@ class GpuTransferManager : public GenericTransferManager { Shape* device_shape) override; private: + // We use a fixed-size staging buffers and split transfer into multiple + // operations if literal does not fit into it. + static constexpr int64_t kStagingBufferSize = 128 * 1024 * 1024; + + // We use host memory allocation (pinned host memory) as a staging buffer for + // transfering literals to and from device. We keep a separate staging + // allocation per device so we don't need to do cross-device synchronization. + // All transfers to and from a device are ordered via stream dependencies. + struct StagingBuffer { + StagingBuffer(std::unique_ptr allocation, + std::unique_ptr transfer_completed); + + absl::Mutex mutex; + std::unique_ptr allocation ABSL_GUARDED_BY(mutex); + std::unique_ptr transfer_completed ABSL_GUARDED_BY(mutex); + }; + GpuTransferManager(const GpuTransferManager&) = delete; GpuTransferManager& operator=(const GpuTransferManager&) = delete; bool PackSubbyteTypes() const override { return true; } + // Returns or creates the staging buffer for the given executor. + absl::StatusOr GetOrCreateStagingBuffer( + se::StreamExecutor* executor); + + absl::Status TransferBufferFromDevice(se::Stream* stream, + const se::DeviceMemoryBase& source, + int64_t size, + void* destination) override; + + absl::Status TransferBufferToDevice( + se::Stream* stream, int64_t size, const void* source, + se::DeviceMemoryBase* destination) override; + + // TODO(ezhulenev): Unify this with staged buffers for transfering literals. + // This class keeps a pool of pinned memory // (StreamExecutor::HostMemoryAllocate()) that serves ReadDynamicShapes(). // This is a bit of a hack: Callers like TensorFlow already have a full pinned @@ -82,7 +123,7 @@ class GpuTransferManager : public GenericTransferManager { // // Lazy initialization works around this, because at that point we have a // stream, and therefore we have an already-initialized StreamExecutor. - void EnsurePinnedBuffersAllocated(se::StreamExecutor* executor) + absl::Status EnsurePinnedBuffersAllocated(se::StreamExecutor* executor) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); static constexpr int64_t kPinnedChunkBytes = 128 * 1024; @@ -96,11 +137,16 @@ class GpuTransferManager : public GenericTransferManager { // Chunk of pinned memory of size kPinnedChunkBytes. The pointers in // pinned_buffers_ point into this chunk. Lazily initialized. - char* pinned_chunk_ ABSL_GUARDED_BY(mu_) = nullptr; + std::unique_ptr pinned_chunk_ ABSL_GUARDED_BY(mu_); // Host buffers for reading dynamic shapes. Each buffer has size // kPinnedBufferBytes. Lazily initialized. std::vector pinned_buffers_ ABSL_GUARDED_BY(mu_); + + // Staging buffers allocated for transfers to and from device. + absl::Mutex mutex_; + absl::node_hash_map staging_buffers_ + ABSL_GUARDED_BY(mutex_); }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/hlo_algorithm_denylist.cc b/third_party/xla/xla/service/gpu/hlo_algorithm_denylist.cc index 029ee5806edca1..a5619da7de8bf6 100644 --- a/third_party/xla/xla/service/gpu/hlo_algorithm_denylist.cc +++ b/third_party/xla/xla/service/gpu/hlo_algorithm_denylist.cc @@ -15,11 +15,20 @@ limitations under the License. #include "xla/service/gpu/hlo_algorithm_denylist.h" +#include #include +#include +#include #include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/types/span.h" #include "xla/debug_options_flags.h" #include "xla/service/gpu/gpu_autotuning.pb.h" +#include "xla/stream_executor/dnn.h" +#include "tsl/platform/env.h" +#include "tsl/platform/protobuf.h" +#include "tsl/platform/status.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc index 445740e9a05648..3ef668189998d5 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc +++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc @@ -165,29 +165,6 @@ HloFusionAnalysis HloFusionAnalysis::Create( tiled_transpose_hero, std::move(input_output_info)); } -// static -absl::string_view HloFusionAnalysis::GetEmitterFusionKindString( - EmitterFusionKind kind) { - switch (kind) { - case EmitterFusionKind::kLoop: - return "loop"; - case EmitterFusionKind::kCustomFusion: - return "custom"; - case EmitterFusionKind::kTriton: - return "triton"; - case EmitterFusionKind::kReduction: - return "reduction"; - case EmitterFusionKind::kTranspose: - return "transpose"; - case EmitterFusionKind::kConcatenate: - return "concatenate"; - case EmitterFusionKind::kInputSlices: - return "input_slices"; - case EmitterFusionKind::kScatter: - return "scatter"; - } -} - // static HloFusionAnalysis HloFusionAnalysis::Create( const HloFusionInstruction* fusion, @@ -262,10 +239,10 @@ HloFusionAnalysis::EmitterFusionKind HloFusionAnalysis::GetEmitterFusionKind() continue; } if (!IsRealReductionHero(*root, *hero)) { - // Needs to have a compatible shape to the reduce operand. - if (!ShapeUtil::IsReshapeOrTransposeBitcast( - root->shape(), hero_operand_shape, - /*ignore_element_type=*/true)) { + // Needs to have a compatible shape to the reduce operand (compatible + // meaning same number of elements). + if (ShapeUtil::ElementsIn(root->shape()) != + ShapeUtil::ElementsIn(hero_operand_shape)) { valid_shapes = false; break; } diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h index af5e57e2289632..011810a72b8532 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h +++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h @@ -91,8 +91,6 @@ class HloFusionAnalysis { return input_output_info_; } - static absl::string_view GetEmitterFusionKindString(EmitterFusionKind kind); - private: HloFusionAnalysis(FusionBackendConfig fusion_backend_config, std::vector fusion_roots, diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_stats.cc b/third_party/xla/xla/service/gpu/hlo_fusion_stats.cc index 0fdcfefeb15937..11ec2530836807 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_stats.cc +++ b/third_party/xla/xla/service/gpu/hlo_fusion_stats.cc @@ -48,6 +48,7 @@ class OpcodeCollector : public ConstDfsHloVisitorWithDefault { case HloOpcode::kCbrt: case HloOpcode::kCeil: case HloOpcode::kCos: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: diff --git a/third_party/xla/xla/service/gpu/hlo_to_ir_bindings.cc b/third_party/xla/xla/service/gpu/hlo_to_ir_bindings.cc index ea7a32ccb01b1a..33f9d9fedb4998 100644 --- a/third_party/xla/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/third_party/xla/xla/service/gpu/hlo_to_ir_bindings.cc @@ -97,20 +97,6 @@ void HloToIrBindings::EmitBasePointersForHlos( } } -llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte, - llvm::Value* base_ptr) { - // TODO(b/26344050): tighten the alignment based on the real element type. - if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) { - return llvm_ir::EmitGetTupleElement( - gte->shape(), gte->tuple_index(), /*alignment=*/1, base_ptr, - llvm_ir::ShapeToIrType(gte->operand(0)->shape(), module_), b_); - } - return llvm_ir::EmitGetTupleElement( - gte->shape(), gte->tuple_index(), /*alignment=*/1, - EmitGetTupleElement(gte->operand(0), base_ptr), - llvm_ir::ShapeToIrType(gte->operand(0)->shape(), module_), b_); -} - // Returns true if `value` has a name that should not be changed. static bool HasMeaningfulName(llvm::Value* value) { if (auto* global = llvm::dyn_cast(value)) { @@ -149,26 +135,9 @@ llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo, return ir_array; } -void HloToIrBindings::UnbindAllLocalIrValues() { - std::vector hlos_to_unbind; - for (auto& key_value : base_ptrs_) { - if (!llvm::isa( - (key_value.second.element({}))->stripPointerCasts())) { - hlos_to_unbind.push_back(key_value.first); - } - } - for (const HloInstruction* hlo_to_unbind : hlos_to_unbind) { - VLOG(2) << "Unbinding " << hlo_to_unbind->ToString(); - base_ptrs_.erase(hlo_to_unbind); - } -} - std::string HloToIrBindings::ToString() const { std::string s = StrCat("** HloToIrBindings **\n"); StrAppend(&s, " is_nested_=", is_nested_, "\n"); - StrAppend(&s, - " temp_buffer_base_=", llvm_ir::DumpToString(temp_buffer_base_), - "\n"); if (base_ptrs_.empty()) { return s; diff --git a/third_party/xla/xla/service/gpu/hlo_to_ir_bindings.h b/third_party/xla/xla/service/gpu/hlo_to_ir_bindings.h index 5750c17dc69993..56b347d33ad760 100644 --- a/third_party/xla/xla/service/gpu/hlo_to_ir_bindings.h +++ b/third_party/xla/xla/service/gpu/hlo_to_ir_bindings.h @@ -43,21 +43,11 @@ class HloToIrBindings { void BindHloToIrValue(const HloInstruction& hlo, llvm::Value* ir_value, ShapeIndexView shape_index = {}); - // Unbinds all IR values that's defined in an LLVM function, e.g., function - // arguments and stack variables. Global variables will be kept in bindings_. - // - // This method is called after emitting code for each top-level HLO. The local - // IR values are out of scope at that point and should not be used. - void UnbindAllLocalIrValues(); - // Returns whether `hlo` is bound to an LLVM IR value. bool BoundToIrValue(const HloInstruction& hlo) const { return base_ptrs_.contains(&hlo); } - llvm::Value* GetTempBufferBase() const { return temp_buffer_base_; } - void SetTempBufferBase(llvm::Value* v) { temp_buffer_base_ = v; } - // A helper method that returns the base pointer of the IrArray containing the // output of "inst".at the given ShapeIndex. llvm::Value* GetBasePointer(const HloInstruction& hlo, @@ -81,10 +71,6 @@ class HloToIrBindings { std::string ToString() const; private: - // Emits IR to resolve (possibly) recursive GetTupleElement instructions. - llvm::Value* EmitGetTupleElement(const HloInstruction* gte, - llvm::Value* base_ptr); - const bool is_nested_; llvm::IRBuilder<>* b_; @@ -96,9 +82,6 @@ class HloToIrBindings { // in the ShapeTree. absl::flat_hash_map> base_ptrs_; - - // The address of the memory block that contains all temporary buffers. - llvm::Value* temp_buffer_base_ = nullptr; }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/hlo_traversal.cc b/third_party/xla/xla/service/gpu/hlo_traversal.cc index 43f670751139f1..a70821b84c150b 100644 --- a/third_party/xla/xla/service/gpu/hlo_traversal.cc +++ b/third_party/xla/xla/service/gpu/hlo_traversal.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" @@ -77,6 +78,8 @@ class SingleInstructionFusion : public HloFusionAdaptor { return {instruction_}; } + std::string ToString() const override { return instruction_.ToString(); } + private: HloInstructionAdaptor instruction_; }; @@ -155,6 +158,8 @@ class HloComputationFusion : public HloFusionAdaptor { return result; } + std::string ToString() const override { return computation_->ToString(); } + private: const HloComputation* computation_; absl::InlinedVector roots_; diff --git a/third_party/xla/xla/service/gpu/hlo_traversal.h b/third_party/xla/xla/service/gpu/hlo_traversal.h index f4173905565635..fa5bc0f81817fb 100644 --- a/third_party/xla/xla/service/gpu/hlo_traversal.h +++ b/third_party/xla/xla/service/gpu/hlo_traversal.h @@ -24,6 +24,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -74,6 +75,7 @@ class HloFusionAdaptor { virtual absl::InlinedVector GetRoots() const = 0; virtual absl::InlinedVector MakeInstructionPostOrder() const = 0; + virtual std::string ToString() const = 0; static std::unique_ptr ForInstruction( const HloInstruction* instruction); @@ -114,6 +116,15 @@ class ProducerConsumerFusion : public HloFusionAdaptor { return producer_post_order; } + std::string ToString() const override { + // TODO: Add a parameter to indent output on nested adaptor for better + // visual representation. Nested producer-consumers fusion are not used in + // practice yet. + return absl::StrJoin({std::string("producer-consumer fusion:"), + producer_->ToString(), consumer_->ToString()}, + "\n"); + } + private: std::unique_ptr producer_; std::unique_ptr consumer_; diff --git a/third_party/xla/xla/service/gpu/infeed_manager.cc b/third_party/xla/xla/service/gpu/infeed_manager.cc index 4c24a73827030b..9f7dcd6904d22f 100644 --- a/third_party/xla/xla/service/gpu/infeed_manager.cc +++ b/third_party/xla/xla/service/gpu/infeed_manager.cc @@ -36,7 +36,7 @@ constexpr int kMaxInfeedsInFlight = 8; InfeedManager::InfeedManager(se::StreamExecutor* executor) : BlockingXfeedQueue(/*max_pending_xfeeds=*/kMaxInfeedsInFlight), stream_(std::make_unique(executor)) { - stream_->Init(); + stream_->Initialize().IgnoreError(); } static absl::StatusOr> CopyBufferToDevice( @@ -53,7 +53,7 @@ static absl::StatusOr> CopyBufferToDevice( se::StreamExecutor* executor = stream->parent(); se::ScopedDeviceMemory buffer( executor, executor->AllocateArray(size)); - stream->ThenMemcpy(buffer.ptr(), source, size); + TF_RETURN_IF_ERROR(stream->Memcpy(buffer.ptr(), source, size)); return std::move(buffer); } diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.cc b/third_party/xla/xla/service/gpu/ir_emission_utils.cc index edea96c1667cc8..ca299d71d1ae2b 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.cc +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.cc @@ -172,31 +172,6 @@ bool IsCustomCallToTopK(const HloInstruction& hlo) { hlo.custom_call_target() == kTopKCustomCallTarget; } -bool IsInputFusibleSlices(mlir::Operation* unnested_hlo, - bool verify_no_strides) { - auto fusion = mlir::dyn_cast(unnested_hlo); - if (!fusion) { - return false; - } - - auto is_non_strided = [](mlir::DenseIntElementsAttr strides) -> bool { - return absl::c_all_of( - strides, [](const llvm::APInt& stride) { return stride == 1; }); - }; - - for (mlir::Value value : fusion.getFusionResults()) { - auto slice = - mlir::dyn_cast_or_null(value.getDefiningOp()); - if (!slice) { - return false; - } - if (verify_no_strides && !is_non_strided(slice.getStrides())) { - return false; - } - } - return true; -} - bool IsSliceWithUnitStrides(const HloInstruction* instr) { auto slice = DynCast(instr); return slice && absl::c_all_of(slice->slice_strides(), @@ -221,62 +196,6 @@ bool IsContiguousSlice(const HloInstruction& instr) { return true; } -// This emits a device-side call to -// "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see -// http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls -llvm::Value* EmitPrintf(absl::string_view fmt, - absl::Span arguments, - llvm::IRBuilder<>* builder) { - std::vector argument_types; - - // Variadic arguments implicit promotion [1] converts float to double, - // and bool/char/short are converted to int. - // [1] https://en.cppreference.com/w/cpp/language/variadic_arguments - auto requires_int32_promotion = [](llvm::Type* type) { - return type->isIntegerTy(/*BitWidth=*/1) || - type->isIntegerTy(/*BitWidth=*/8) || - type->isIntegerTy(/*BitWidth=*/16); - }; - auto requires_double_promotion = [](llvm::Type* type) { - return type->isFloatingPointTy(); - }; - - for (auto argument : arguments) { - llvm::Type* type = argument->getType(); - if (requires_double_promotion(type)) { - argument_types.push_back(builder->getDoubleTy()); - } else if (requires_int32_promotion(type)) { - argument_types.push_back(builder->getInt32Ty()); - } else { - argument_types.push_back(type); - } - } - auto* arguments_type = llvm::StructType::create(argument_types); - llvm::Value* arguments_ptr = builder->CreateAlloca(arguments_type); - for (size_t i = 0; i < arguments.size(); ++i) { - llvm::Value* value = arguments[i]; - llvm::Type* type = value->getType(); - if (requires_double_promotion(type)) { - value = builder->CreateFPCast(value, builder->getDoubleTy()); - } else if (requires_int32_promotion(type)) { - value = builder->CreateIntCast(value, builder->getInt32Ty(), - /*isSigned=*/true); - } - builder->CreateStore( - value, - builder->CreateGEP(arguments_type, arguments_ptr, - {builder->getInt64(0), builder->getInt32(i)})); - } - llvm::Type* ptr_ty = builder->getPtrTy(); - return builder->CreateCall( - builder->GetInsertBlock()->getParent()->getParent()->getOrInsertFunction( - "vprintf", - llvm::FunctionType::get(builder->getInt32Ty(), {ptr_ty, ptr_ty}, - /*isVarArg=*/false)), - {builder->CreateGlobalStringPtr(llvm_ir::AsStringRef(fmt)), - builder->CreatePointerCast(arguments_ptr, ptr_ty)}); -} - // Helper function to emit call to AMDGPU shfl_down function. llvm::Value* EmitAMDGPUShflDown(llvm::Value* value, llvm::Value* offset, llvm::IRBuilder<>* b) { @@ -1043,10 +962,6 @@ bool IsIntermediate(const HloInstruction* instr, int allowed_operand_count, } } -static bool IsParameter(const HloInstruction& instr) { - return instr.opcode() == HloOpcode::kParameter; -} - static std::optional FindNonTrivialHero( HloInstructionAdaptor root, const HloFusionAdaptor& fusion, const std::function& predicate) { diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.h b/third_party/xla/xla/service/gpu/ir_emission_utils.h index 2cbe73fee1605d..5611a6e3ee7aa4 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.h +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.h @@ -90,12 +90,6 @@ bool IsCustomCallToTopK(const HloInstruction& hlo); // is a success/failure code per batch element. extern const char* const kCusolverCholeskyCallTarget; -// Returns whether unnested_hlo is an input fusion whose root is either a slice -// or a tuple of slices. If verify_no_strides is true, returns false unless all -// ROOT slices have no strides. -bool IsInputFusibleSlices(mlir::Operation* unnested_hlo, - bool verify_no_strides); - // Returns true if `instr` is a non-strided slice. bool IsSliceWithUnitStrides(const HloInstruction* instr); @@ -103,11 +97,6 @@ bool IsSliceWithUnitStrides(const HloInstruction* instr); // slice. bool IsContiguousSlice(const HloInstruction& instr); -// Emits call to "vprintf" with given format and arguments. -llvm::Value* EmitPrintf(absl::string_view fmt, - absl::Span arguments, - llvm::IRBuilder<>* builder); - // Emits code to shuffle data between threads of a warp. This has the same // semantics as the PTX "shfl.sync.down" instruction but works for values that // aren't 32 bits in size. The last operand of the emitted "shfl" is @@ -131,11 +120,6 @@ llvm::SmallVector GetHloOutputs(mlir::Operation* op); bool WritesMlirBuffer(mlir::Operation* op, mlir::Value operand); -template -std::vector ToStdVector(const llvm::SmallVectorImpl& v) { - return std::vector(v.begin(), v.end()); -} - absl::StatusOr GetAllocationSlice( mlir::Value v, absl::Span allocations, std::string* constant_name = nullptr); @@ -201,11 +185,6 @@ struct TransposeDescription { Vector3 permutation) : instr(instr), dimensions(dimensions), permutation(permutation) {} - std::string ToString() const { - return absl::StrCat("dimensions=", VectorString(dimensions), - ", permutation=", VectorString(permutation)); - } - // Transpose instruction input shape. const Shape& input_shape() const { return instr->operand(0)->shape(); } diff --git a/third_party/xla/xla/service/gpu/ir_emitter.h b/third_party/xla/xla/service/gpu/ir_emitter.h index 4dcec4c2844bce..70cd40fd307ee3 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter.h +++ b/third_party/xla/xla/service/gpu/ir_emitter.h @@ -141,13 +141,6 @@ class IrEmitter : public DfsHloVisitorWithDefault, const char* sync_scope_id); private: - // A helper method for HandleSort(). It adds the inner comparison loop where - // we compare elements pointed to by 'keys_index' and 'compare_keys_index'. - void EmitCompareLoop(int64_t dimension_to_sort, - const llvm_ir::IrArray::Index& keys_index, - const llvm_ir::IrArray::Index& compare_keys_index, - const llvm_ir::IrArray& keys_array); - // A convenience method to determine whether or not IR is emitted for AMDGPU. bool IsEmittingForAMDGPU() const; }; diff --git a/third_party/xla/xla/service/gpu/ir_emitter_context.cc b/third_party/xla/xla/service/gpu/ir_emitter_context.cc index f09520eaef2cf4..81eaef0ebe229d 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_context.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_context.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/TargetParser/Triple.h" #include "xla/service/gpu/gpu_constants.h" #include "xla/service/gpu/ir_emission_utils.h" @@ -66,6 +67,9 @@ void IrEmitterContext::emit_constant(int64_t num_elements, content.span().size())); }(); + // Explicitly set global addrspace for SPIR backend. + int addrspace = + llvm::Triple(llvm_module_->getTargetTriple()).isSPIR() ? 1 : 0; // These globals will be looked up by name by GpuExecutable so we need to // give them an external linkage. Not all of their uses are visible in // the LLVM IR so we can't give then a linkage that merely preserves their @@ -79,7 +83,7 @@ void IrEmitterContext::emit_constant(int64_t num_elements, llvm::GlobalValue::ExternalLinkage, /*Initializer=*/initializer, symbol_name, /*TLMode=*/llvm::GlobalValue::NotThreadLocal, - /*AddressSpace=*/0, + /*AddressSpace=*/addrspace, /*isExternallyInitialized=*/false); global_for_const->setAlignment(llvm::Align(kConstantBufferAlignBytes)); llvm_module_->insertGlobalVariable(global_for_const); diff --git a/third_party/xla/xla/service/gpu/ir_emitter_context.h b/third_party/xla/xla/service/gpu/ir_emitter_context.h index 4f76c4f367cad4..330406f31604da 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_context.h +++ b/third_party/xla/xla/service/gpu/ir_emitter_context.h @@ -43,14 +43,13 @@ class IrEmitterContext { std::string platform_name, const se::DeviceDescription& gpu_device_info, mlir::MLIRContext* mlir_context, llvm::Module* llvm_module, - bool emit_ir_from_hlo, bool emit_kernels) + bool emit_kernels) : hlo_module_(hlo_module), buffer_assignment_(buffer_assignment), platform_name_(std::move(platform_name)), gpu_device_info_(gpu_device_info), mlir_context_(mlir_context), llvm_module_(llvm_module), - emit_ir_from_hlo_(emit_ir_from_hlo), emit_kernels_(emit_kernels) {} // Disallow copy and assign. IrEmitterContext(const IrEmitterContext&) = delete; @@ -104,7 +103,6 @@ class IrEmitterContext { KernelReuseCache& kernel_cache() { return kernel_cache_; } - bool emit_ir_from_hlo() const { return emit_ir_from_hlo_; } bool emit_kernels() const { return emit_kernels_; } private: @@ -122,7 +120,6 @@ class IrEmitterContext { llvm::Module* llvm_module_; NameUniquer name_uniquer_; std::vector constants_; - const bool emit_ir_from_hlo_; KernelReuseCache kernel_cache_; // We should not emit kernels when loading thunks from a compilation result. diff --git a/third_party/xla/xla/service/gpu/ir_emitter_nested.cc b/third_party/xla/xla/service/gpu/ir_emitter_nested.cc index 3a54197f27ec1f..c888aff308782c 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_nested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_nested.cc @@ -25,6 +25,7 @@ limitations under the License. #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" +#include "llvm/TargetParser/Triple.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc index d535a52d42ff32..80a92e0c727c63 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -27,6 +28,8 @@ limitations under the License. #include #include +#include "nvidia/include/NVGPUToLLVM/NVGPUToLLVMPass.h" +#include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h" #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" @@ -103,12 +106,14 @@ limitations under the License. #include "xla/service/gpu/target_util.h" #include "xla/service/gpu/triton_fusion_analysis.h" #include "xla/service/gpu/triton_tiling_propagation.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/shape_util.h" #include "xla/status.h" #include "xla/status_macros.h" #include "xla/statusor.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/launch_dim.h" #include "xla/translate/hlo_to_mhlo/hlo_function_importer.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -118,15 +123,12 @@ limitations under the License. #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" #include "tsl/platform/tensor_float_32_utils.h" -#include "triton/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.h" -#include "triton/Conversion/TritonGPUToLLVM/Passes.h" #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/Triton/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" -#include "triton/Target/PTX/TmaMetadata.h" namespace xla { namespace gpu { @@ -140,6 +142,7 @@ namespace mt = ::mlir::triton; using ::llvm::SmallVector; using mlir::ArrayRef; using mlir::ImplicitLocOpBuilder; +using ::mlir::ShapedType; using ::mlir::Type; using ::mlir::Value; using mlir::ValueRange; @@ -219,7 +222,7 @@ ma::ConstantOp CreateConst(ImplicitLocOpBuilder& b, Type type, T value, } Value ZerosLike(ImplicitLocOpBuilder& b, Value x) { - if (auto src_shaped_ty = x.getType().dyn_cast()) { + if (auto src_shaped_ty = x.getType().dyn_cast()) { Type src_ty = src_shaped_ty.getElementType(); return CreateConst(b, src_ty, 0, src_shaped_ty.getShape()); } @@ -227,7 +230,7 @@ Value ZerosLike(ImplicitLocOpBuilder& b, Value x) { } Value OnesLike(ImplicitLocOpBuilder& b, Value x) { - if (auto src_shaped_ty = x.getType().dyn_cast()) { + if (auto src_shaped_ty = x.getType().dyn_cast()) { Type src_ty = src_shaped_ty.getElementType(); return CreateConst(b, src_ty, 1, src_shaped_ty.getShape()); } @@ -240,7 +243,7 @@ Value Cast(ImplicitLocOpBuilder& b, Value value, Type dst_element_ty) { Type src_element_ty = src_ty; Type fp32_ty = b.getF32Type(); Type dst_ty = dst_element_ty; - if (auto src_shaped_ty = src_ty.dyn_cast()) { + if (auto src_shaped_ty = src_ty.dyn_cast()) { src_element_ty = src_shaped_ty.getElementType(); dst_ty = src_shaped_ty.clone(src_shaped_ty.getShape(), dst_element_ty); fp32_ty = src_shaped_ty.clone(src_shaped_ty.getShape(), b.getF32Type()); @@ -731,20 +734,26 @@ absl::StatusOr EmitScope( return values[instructions.back()]; } -absl::Status CreateTritonPipeline(mlir::OpPassManager& pm, - const se::CudaComputeCapability& cc, - const TritonGemmConfig& config) { +// Create Triton pipeline. +// +// `out_cluster_info` must be kept alive at least until pm.run() is called. +// It should be read after that. We have to pass the cluster dims to +// LaunchDimensions. Triton currently uses this as an out-parameter to return +// the cluster dims determined based on `config.num_ctas` and a heuristic. There +// are some signs that show that this was intended to be used as an in-out +// parameter which would give a hint to Triton which cluster dims we prefer to +// use, but that's not the case currently. +absl::Status CreateTritonPipeline( + mlir::OpPassManager& pm, const se::CudaComputeCapability& cc, + const TritonGemmConfig& config, + mlir::triton::nvidia_gpu::ClusterInfo& out_cluster_info) { const int ccAsInt = cc.major * 10 + cc.minor; const int threadsPerWarp = 32; - mlir::triton::nvidia_gpu::ClusterInfo clusterInfo; - clusterInfo.clusterDimX = config.cluster_dims.x; - clusterInfo.clusterDimY = config.cluster_dims.y; - clusterInfo.clusterDimZ = config.cluster_dims.z; // Based on make_ttir() in // @triton//:third_party/nvidia/backend/compiler.py - pm.addPass(mt::createRewriteTensorPointerPass(ccAsInt)); pm.addPass(mlir::createInlinerPass()); + pm.addPass(mt::createRewriteTensorPointerPass()); pm.addPass(mt::createCombineOpsPass()); pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mt::createReorderBroadcastPass()); @@ -757,9 +766,7 @@ absl::Status CreateTritonPipeline(mlir::OpPassManager& pm, pm.addPass(mt::createConvertTritonToTritonGPUPass( config.num_warps, threadsPerWarp, config.num_ctas, ccAsInt)); pm.addPass(mt::gpu::createCoalescePass()); - pm.addPass(mlir::createTritonNvidiaGPUPlanCTAPass(&clusterInfo)); - pm.addPass(mlir::createTritonGPURewriteTensorPointerPass(ccAsInt)); - pm.addPass(mlir::createTritonNvidiaGPUPlanCTAPass(&clusterInfo)); + pm.addPass(mlir::createTritonNvidiaGPUPlanCTAPass(&out_cluster_info)); pm.addPass(mt::gpu::createRemoveLayoutConversionsPass()); pm.addPass(mt::gpu::createOptimizeThreadLocalityPass()); pm.addPass(mt::gpu::createAccelerateMatmulPass(ccAsInt)); @@ -767,65 +774,34 @@ absl::Status CreateTritonPipeline(mlir::OpPassManager& pm, pm.addPass(mt::gpu::createOptimizeDotOperandsPass()); pm.addPass(mlir::createCSEPass()); - if (cc.IsAtLeastHopper() && config.enable_warp_specialization) { - // Triton currently doesn't support warp specialization for num_warps != 4. - // TODO from Triton to add support here: - // https://github.com/openai/triton/blob/1bc9c0ea67e4cbec2c77d4acde3173aa7d51c8f9/python/triton/compiler/backends/cuda.py#L119 - if (config.num_warps != 4) { - return absl::UnimplementedError( - "Triton currently doesn't support warp specialization for " - "num_warps != 4."); - } - // Ideally, we should run - // 'mlir::createTritonNvidiaGPUWSFeasibilityCheckingPass(ccAsInt)' at this - // point on the IR to check if warp specialization is feasible. Instead, we - // are relying on failures as indication of infeasibility during - // auto-tuning. - pm.addPass(mlir::createTritonNvidiaGPUWSDecomposingPass(ccAsInt)); - pm.addPass(mlir::createTritonNvidiaGPUWSPipelinePass( - config.num_stages, config.num_warps, ccAsInt)); - pm.addPass(mlir::createTritonNvidiaGPUWSMutexPass(ccAsInt)); - pm.addPass(mlir::createTritonNvidiaGPUWSMaterializationPass(ccAsInt)); - pm.addPass(mlir::createLoopInvariantCodeMotionPass()); - pm.addPass(mlir::createCSEPass()); - } else { + if (cc.IsAtLeastAmpere()) { pm.addPass(mt::gpu::createPipelinePass(config.num_stages, config.num_warps, config.num_ctas, ccAsInt)); } - - pm.addPass(mlir::createTritonNvidiaGPUMaterializeLoadStorePass( - config.num_warps, ccAsInt)); - if (ccAsInt <= 80) { - pm.addPass(mlir::triton::gpu::createPrefetchPass()); + if (!cc.IsAtLeastHopper()) { + pm.addPass(mt::gpu::createPrefetchPass()); } + pm.addPass(mt::gpu::createOptimizeDotOperandsPass()); pm.addPass(mt::gpu::createRemoveLayoutConversionsPass()); pm.addPass(mt::gpu::createReduceDataDuplicationPass()); - pm.addPass(mlir::createTritonNvidiaGPUWSFixupMissingAttrs()); pm.addPass(mt::gpu::createReorderInstructionsPass()); pm.addPass(mlir::createCSEPass()); pm.addPass(mlir::createSymbolDCEPass()); if (cc.IsAtLeastHopper()) { pm.addPass(mlir::createTritonNvidiaGPUFenceInsertionPass(ccAsInt)); } - pm.addPass(mlir::createTritonNvidiaGPUWSFixupMissingAttrs()); pm.addPass(mlir::createCanonicalizerPass()); // Based on make_llir() in // @triton//:third_party/nvidia/backend/compiler.py - pm.addPass(mlir::createTritonNvidiaGPUAddDescriptorArgs()); pm.addPass(mlir::triton::gpu::createDecomposeUnsupportedConversionsPass()); pm.addPass(mlir::createConvertSCFToCFPass()); pm.addPass(mlir::createConvertIndexToLLVMPass()); - // // TODO(b/316566238): Use TMA info collected here in XLA runtime. - mlir::triton::gpu::TMAMetadataTy tma_infos; - pm.addPass(mt::createConvertTritonGPUToLLVMPass(ccAsInt, - /*target=*/mlir::triton::NVVM, - &tma_infos)); - if (cc.IsAtLeastHopper() && config.enable_warp_specialization) { - pm.addPass(mlir::createLoopInvariantCodeMotionPass()); - pm.addPass(mlir::createCSEPass()); - } + pm.addPass(mlir::triton::gpu::createAllocateSharedMemoryPass()); + pm.addPass( + mt::createConvertTritonGPUToLLVMPass(ccAsInt, + /*target=*/mlir::triton::NVVM)); pm.addPass(mt::createConvertNVGPUToLLVMPass()); pm.addPass(mlir::createArithToLLVMConversionPass()); pm.addPass(mlir::createCanonicalizerPass()); @@ -1548,6 +1524,94 @@ ConstHloInstructionSet ScopeInputs(const TritonFusionAnalysis& analysis, return result; } +// Truncates |input| of F32 type to the number representable in Bf16 toward +// zero. +// It is used for Emit6xBfloat16MatMul. +Value TruncateToBF16TowardsZero(ImplicitLocOpBuilder& b, Value input) { + ShapedType input_type = input.getType().dyn_cast(); + Type input_type_as_i32 = input_type.clone(b.getI32Type()); + Value input_as_i32 = b.create(input_type_as_i32, input); + Value mask = CreateConst(b, b.getI32Type(), 0xFFFF0000u, + input_type.getShape()); + Value high_bits = b.create(input_type_as_i32, input_as_i32, mask); + + return b.create(input_type, high_bits); +} + +// Finds the middle 8 bits of |input|'s mantissa. +// It is used for Emit6xBfloat16MatMul. +Value SoftMiddleEight(ImplicitLocOpBuilder& b, Value input) { + Value high = TruncateToBF16TowardsZero(b, input); + return b.create(input, high); +} + +// Finds the low 8 bits of |input|'s mantissa. +// It is used for Emit6xBfloat16MatMul. +Value SoftLowEight(ImplicitLocOpBuilder& b, Value input) { + // Find the middle bits of the middle bits, and these are the low eight + // bits. + return SoftMiddleEight(b, SoftMiddleEight(b, input)); +} + +// Rounds |input| to BF16 type. +// It is used for Emit6xBfloat16MatMul. +Value RoundToBF16(ImplicitLocOpBuilder& b, Value input) { + return Cast(b, input, b.getBF16Type()); +} + +// Leverages BF16 datatype for F32 matmul computation. It follows the guidance +// from https://arxiv.org/pdf/1904.06376.pdf. +absl::StatusOr Emit6xBfloat16MatMul(ImplicitLocOpBuilder& b, Value lhs, + Value rhs, Value acc) { + Type f32 = b.getF32Type(); + TF_RET_CHECK(lhs.getType().cast().getElementType() == f32); + TF_RET_CHECK(rhs.getType().cast().getElementType() == f32); + TF_RET_CHECK(acc.getType().cast().getElementType() == f32); + + Value lhs_high = RoundToBF16(b, TruncateToBF16TowardsZero(b, lhs)); + Value lhs_middle = + RoundToBF16(b, TruncateToBF16TowardsZero(b, SoftMiddleEight(b, lhs))); + Value lhs_low = + RoundToBF16(b, TruncateToBF16TowardsZero(b, SoftLowEight(b, lhs))); + + Value rhs_high = RoundToBF16(b, TruncateToBF16TowardsZero(b, rhs)); + Value rhs_middle = + RoundToBF16(b, TruncateToBF16TowardsZero(b, SoftMiddleEight(b, rhs))); + Value rhs_low = + RoundToBF16(b, TruncateToBF16TowardsZero(b, SoftLowEight(b, rhs))); + + auto bf16_dot = [&](Value lhs_bf16, Value rhs_bf16, + Value accumulator) -> Value { + return b.create(lhs_bf16, rhs_bf16, accumulator, + /*allowTF32=*/false, + /*maxNumImpreciseAcc=*/0); + }; + + Value local_acc = ZerosLike(b, acc); + Value result = bf16_dot(lhs_middle, rhs_middle, local_acc); + result = bf16_dot(lhs_low, rhs_high, result); + result = bf16_dot(lhs_high, rhs_low, result); + result = bf16_dot(lhs_middle, rhs_high, result); + result = bf16_dot(lhs_high, rhs_middle, result); + // If lhs is 1.0, we will have lhs_high = 1.0 and lhs_low = 0.0. + // If rhs is +infinity, we will have: + // +infinity * 1.0 = +infinity + // +infinity * 0.0 = NaN + // We would get the wrong result if we sum these partial products. Instead, we + // must override any accumulated result if the last partial product is + // non-finite. See b/115844437. + Value positive_inf = CreateConst( + b, b.getF32Type(), std::numeric_limits::infinity(), + result.getType().cast().getShape()); + Value abs_result = b.create(result); + Value is_finite = + b.create(ma::CmpFPredicate::OGT, positive_inf, abs_result); + result = b.create(is_finite, result, ZerosLike(b, result)); + result = bf16_dot(lhs_high, rhs_high, result); + result = b.create(acc, result); + return result; +} + // Variable naming: lhs [m, k] x rhs [k, n] -> out [m, n]. absl::Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path, @@ -1736,15 +1800,30 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, [](const int precision) { return precision != PrecisionConfig::DEFAULT; }); - - // Execute matrix multiplication of input tiles and pass the accumulator. - // TODO(manany): Should be looked into once we enable Hopper workloads. - // maxNumImpreciseAcc flag was introduced for Hopper to accumulate in a - // lower precision than the output type. The change was introduced here: - // https://github.com/openai/triton/commit/31b0c521427109a8eda609b58d756c380b21599a - Value accumulator_next = b.create(dot_input_lhs, dot_input_rhs, - iter_args.back(), allow_tf32, - /*maxNumImpreciseAcc=*/0); + const HloModule* hlo_module = computation->parent(); + Type f32 = b.getF32Type(); + // BF16 datatype is not supported before Ampere. + const bool use_bf16_6x = + device_info.cuda_compute_capability().IsAtLeastAmpere() && + hlo_module->config().debug_options().xla_gpu_enable_bf16_6way_gemm() && + dot_input_lhs.getType().cast().getElementType() == f32 && + dot_input_rhs.getType().cast().getElementType() == f32; + Value accumulator_next; + if (use_bf16_6x) { + absl::StatusOr accumulator_next_or = Emit6xBfloat16MatMul( + b, dot_input_lhs, dot_input_rhs, iter_args.back()); + TF_CHECK_OK(accumulator_next_or.status()); + accumulator_next = accumulator_next_or.value(); + } else { + // Execute matrix multiplication of input tiles and pass the accumulator. + // TODO(manany): Should be looked into once we enable Hopper workloads. + // maxNumImpreciseAcc flag was introduced for Hopper to accumulate in a + // lower precision than the output type. The change was introduced here: + // https://github.com/openai/triton/commit/31b0c521427109a8eda609b58d756c380b21599a + accumulator_next = b.create(dot_input_lhs, dot_input_rhs, + iter_args.back(), allow_tf32, + /*maxNumImpreciseAcc=*/0); + } iter_args_next.push_back(accumulator_next); b.create(iter_args_next); @@ -2000,11 +2079,9 @@ absl::StatusOr> TranslateLLVMToLLVMIR( namespace { -std::string GetLibdevicePath(const HloComputation* hlo_computation) { - return nvptx::LibDevicePath(hlo_computation->parent() - ->config() - .debug_options() - .xla_gpu_cuda_data_dir()); +std::string GetLibdevicePath(const HloModuleConfig& hlo_config) { + return nvptx::LibDevicePath( + hlo_config.debug_options().xla_gpu_cuda_data_dir()); } } // namespace @@ -2044,9 +2121,9 @@ absl::StatusOr> CreateTritonModule( fn.addEntryBlock(); b.setInsertionPointToStart(&fn.front()); - TF_RETURN_IF_ERROR(ir_emitter(b, GetLibdevicePath(hlo_computation), - device_info, analysis, hlo_computation, fn, - config)); + TF_RETURN_IF_ERROR( + ir_emitter(b, GetLibdevicePath(hlo_computation->parent()->config()), + device_info, analysis, hlo_computation, fn, config)); b.create(loc); @@ -2120,12 +2197,21 @@ absl::StatusOr TritonWrapper( VLOG(2) << config.ToString(); // Compile Triton kernel to LLVM. - std::optional log_stream; const HloModule* hlo_module = hlo_computation->parent(); + return CompileTritonToLLVM(hlo_module->config(), hlo_module->name(), cc, + device_info, config, triton_module.get(), + llvm_module, mlir_context); +} +// TODO(b/325220878): Replace TritonGemmConfig with a more generic abstraction. +absl::StatusOr CompileTritonToLLVM( + const HloModuleConfig& hlo_config, absl::string_view hlo_module_name, + const se::CudaComputeCapability& cc, + const se::DeviceDescription& device_info, const TritonGemmConfig& config, + mlir::ModuleOp triton_module, llvm::Module* llvm_module, + mlir::MLIRContext& mlir_context) { bool should_verify = - (hlo_module->config().debug_options().xla_gpu_llvm_verification_level() >= - 1); + (hlo_config.debug_options().xla_gpu_llvm_verification_level() >= 1); #ifndef NDEBUG should_verify = true; #endif @@ -2133,13 +2219,14 @@ absl::StatusOr TritonWrapper( mlir::PassManager pm(&mlir_context); pm.enableVerifier(should_verify); - if (hlo_module->config().debug_options().xla_gpu_dump_llvmir()) { + std::optional log_stream; + if (hlo_config.debug_options().xla_gpu_dump_llvmir()) { const std::string basename = - absl::StrCat(absl::string_view(tsl::io::Basename(hlo_module->name())), + absl::StrCat(absl::string_view(tsl::io::Basename(hlo_module_name)), ".triton-passes.log"); std::string outputs_dir; if (!tsl::io::GetTestUndeclaredOutputsDir(&outputs_dir)) { - outputs_dir = hlo_module->config().debug_options().xla_dump_to(); + outputs_dir = hlo_config.debug_options().xla_dump_to(); } if (!outputs_dir.empty()) { std::string path = tsl::io::JoinPath(outputs_dir, basename); @@ -2163,7 +2250,8 @@ absl::StatusOr TritonWrapper( } } - if (!CreateTritonPipeline(pm, cc, config).ok()) { + mlir::triton::nvidia_gpu::ClusterInfo cluster_info; + if (!CreateTritonPipeline(pm, cc, config, /*out*/ cluster_info).ok()) { return Internal("Failed to create Triton pipeline."); } if (log_stream.has_value()) { @@ -2176,7 +2264,7 @@ absl::StatusOr TritonWrapper( // llvm::Linker::linkModules() segfaults if we don't strip locations. pm.addPass(mlir::createStripDebugInfoPass()); - bool succeeded = mlir::succeeded(pm.run(*triton_module)); + bool succeeded = mlir::succeeded(pm.run(triton_module)); if (log_stream.has_value()) { log_stream->flush(); @@ -2187,8 +2275,7 @@ absl::StatusOr TritonWrapper( } const int shared_mem_bytes = - (*triton_module) - ->getAttrOfType("triton_gpu.shared") + triton_module->getAttrOfType("triton_gpu.shared") .getInt(); VLOG(2) << "Shared memory usage: " << shared_mem_bytes << " B"; if (shared_mem_bytes > device_info.shared_memory_per_block_optin()) { @@ -2199,8 +2286,8 @@ absl::StatusOr TritonWrapper( TF_ASSIGN_OR_RETURN( std::unique_ptr ll_triton_module, - TranslateLLVMToLLVMIR(&llvm_module->getContext(), *triton_module, - GetLibdevicePath(hlo_computation))); + TranslateLLVMToLLVMIR(&llvm_module->getContext(), triton_module, + GetLibdevicePath(hlo_config))); VLogModule(5, *ll_triton_module); if (should_verify) { VerifyModule(*ll_triton_module); @@ -2219,7 +2306,24 @@ absl::StatusOr TritonWrapper( VerifyModule(*llvm_module); } - return {{shared_mem_bytes}}; + // `cluster_info` must be read after pm.run(). + std::optional cluster_dim; + if (config.num_ctas > 1) { + VLOG(3) << "num_ctas: " << config.num_ctas + << ", cluster_info: " << cluster_info.clusterDimX << "," + << cluster_info.clusterDimY << "," << cluster_info.clusterDimZ; + if (cluster_info.clusterDimX > 1 || cluster_info.clusterDimY > 1 || + cluster_info.clusterDimZ > 1) { + cluster_dim = + se::ClusterDim(cluster_info.clusterDimX, cluster_info.clusterDimY, + cluster_info.clusterDimZ); + } + } else { + TF_RET_CHECK(cluster_info.clusterDimX == 1 && + cluster_info.clusterDimY == 1 && + cluster_info.clusterDimZ == 1); + } + return {{shared_mem_bytes, cluster_dim}}; } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton.h b/third_party/xla/xla/service/gpu/ir_emitter_triton.h index d733d192a36346..749c6aab3e58b9 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton.h +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton.h @@ -16,7 +16,9 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_IR_EMITTER_TRITON_H_ #define XLA_SERVICE_GPU_IR_EMITTER_TRITON_H_ +#include #include +#include #include "absl/strings/string_view.h" #include "llvm/IR/Module.h" @@ -32,13 +34,15 @@ limitations under the License. #include "xla/service/gpu/triton_fusion_analysis.h" #include "xla/statusor.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/launch_dim.h" #include "triton/Dialect/Triton/IR/Dialect.h" namespace xla { namespace gpu { struct TritonWrapperResult { - int64_t shmem_bytes; + int64_t shmem_bytes = 0; + std::optional cluster_dim; }; // Compute the launch dimensions for the given Triton MatMul. @@ -89,6 +93,14 @@ absl::StatusOr> CreateTritonModule( const se::DeviceDescription& device_info, const TritonGemmConfig& config, TritonIrEmitter ir_emitter, mlir::MLIRContext& mlir_context); +// Compiles a given Triton module to LLVM IR. +absl::StatusOr CompileTritonToLLVM( + const HloModuleConfig& hlo_config, absl::string_view hlo_module_name, + const se::CudaComputeCapability& cc, + const se::DeviceDescription& device_info, const TritonGemmConfig& config, + mlir::ModuleOp triton_module, llvm::Module* llvm_module, + mlir::MLIRContext& mlir_context); + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_large_test.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_large_test.cc index d00dcda9ba59d9..eaf73063fde421 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_large_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_large_test.cc @@ -62,7 +62,7 @@ ENTRY e { p0 = f16[65536,32800] parameter(0) p1 = f16[32800,32] parameter(1) ROOT _ = f16[65536,32] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config="{\"fusion_backend_config\": {kind: \"__triton_gemm\", triton_gemm_config: {\"block_m\":\"32\",\"block_n\":\"32\",\"block_k\":\"32\",\"split_k\":\"1\",\"num_stages\":\"1\",\"num_warps\":\"1\"}}}" + backend_config="{\"fusion_backend_config\": {kind: \"__triton_gemm\", triton_gemm_config: {\"block_m\":\"32\",\"block_n\":\"32\",\"block_k\":\"32\",\"split_k\":\"1\",\"num_stages\":\"1\",\"num_warps\":\"1\",\"num_ctas\":\"1\"}}}" } )"; diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_parametrized_test.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_parametrized_test.cc index 14291bc976d088..a9027ce8aac59b 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_parametrized_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_parametrized_test.cc @@ -199,9 +199,14 @@ ENTRY e { ROOT triton_gemm__ = f32[15,68]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm___computation, backend_config={"fusion_backend_config":{"kind":"__triton_gemm", - "triton_gemm_config":{"block_m":"32","block_n":"32", - "block_k":"32","split_k":"1", - "num_stages":"1","num_warps":"4"}}} + "triton_gemm_config": + {"block_m":"32", + "block_n":"32", + "block_k":"32", + "split_k":"1", + "num_stages":"1", + "num_warps":"4", + "num_ctas":"1"}}} })"; const std::string hlo_test = absl::Substitute( kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), @@ -310,9 +315,14 @@ ENTRY e { ROOT triton_gemm__ = f32[92,63]{1,0} fusion(p0, p1, p2), kind=kCustom, calls=triton_gemm___computation, backend_config={"fusion_backend_config":{"kind":"__triton_gemm", - "triton_gemm_config":{"block_m":"64","block_n":"32", - "block_k":"64","split_k":"1", - "num_stages":"2","num_warps":"2"}}} + "triton_gemm_config": + {"block_m":"64", + "block_n":"32", + "block_k":"64", + "split_k":"1", + "num_stages":"2", + "num_warps":"2", + "num_ctas":"1"}}} })"; const std::string hlo_test = absl::Substitute( kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), @@ -437,9 +447,14 @@ ENTRY e { ROOT triton_gemm__ = f32[92,63]{1,0} fusion(p0, p1, p2), kind=kCustom, calls=triton_gemm___computation, backend_config={"fusion_backend_config":{"kind":"__triton_gemm", - "triton_gemm_config":{"block_m":"16","block_n":"64", - "block_k":"16","split_k":"1", - "num_stages":"3","num_warps":"2"}}} + "triton_gemm_config": + {"block_m":"16", + "block_n":"64", + "block_k":"16", + "split_k":"1", + "num_stages":"3", + "num_warps":"2", + "num_ctas":"1"}}} })"; const std::string hlo_test = absl::Substitute( kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), @@ -542,9 +557,14 @@ ENTRY e { ROOT triton_gemm__ = $1[92,63]{1,0} fusion(p0, p1, p2, p3), kind=kCustom, calls=triton_gemm___computation, backend_config={"fusion_backend_config":{"kind":"__triton_gemm", - "triton_gemm_config":{"block_m":"16","block_n":"64", - "block_k":"16","split_k":"1", - "num_stages":"3","num_warps":"2"}}} + "triton_gemm_config": + {"block_m":"16", + "block_n":"64", + "block_k":"16", + "split_k":"1", + "num_stages":"3", + "num_warps":"2", + "num_ctas":"1"}}} })"; const std::string hlo_test = absl::Substitute( kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type1), @@ -635,9 +655,14 @@ ENTRY e { ROOT triton_gemm__ = f32[92,63]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm___computation, backend_config={"fusion_backend_config":{"kind":"__triton_gemm", - "triton_gemm_config":{"block_m":"16","block_n":"64", - "block_k":"16","split_k":"1", - "num_stages":"3","num_warps":"2"}}} + "triton_gemm_config": + {"block_m":"16", + "block_n":"64", + "block_k":"16", + "split_k":"1", + "num_stages":"3", + "num_warps":"2", + "num_ctas":"1"}}} })"; const std::string hlo_test = absl::Substitute( kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type)); diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc index 37188c9e8e566f..a16f10fe6d9226 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc @@ -15,6 +15,8 @@ limitations under the License. #include "xla/service/gpu/ir_emitter_triton.h" +#include +#include #include #include #include @@ -22,6 +24,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/strings/string_view.h" #include "llvm/IR/LLVMContext.h" #include "llvm/Support/raw_ostream.h" @@ -32,11 +35,14 @@ limitations under the License. #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/service/gpu/triton_fusion_analysis.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/status_macros.h" @@ -77,7 +83,20 @@ class TritonGemmTest : public TritonTest { public: DebugOptions GetDebugOptionsForTest() override { DebugOptions debug_options = TritonTest::GetDebugOptionsForTest(); + // Do not fall back to cuBLAS, we are testing Triton. debug_options.set_xla_gpu_cublas_fallback(false); + // Do not autotune split-k by default, since this prevents deterministically + // matching the optimized HLO. + debug_options.set_xla_gpu_enable_split_k_autotuning(false); + return debug_options; + } +}; + +class TritonGemmTestWithSplitK : public TritonGemmTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = TritonGemmTest::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_split_k_autotuning(true); return debug_options; } }; @@ -108,6 +127,7 @@ absl::StatusOr TritonFilecheckTest::CreateTritonIrAndFileCheck( auto* computation = verified_module->GetComputationWithName(triton_fusion_name); + TF_RET_CHECK(computation != nullptr); TF_ASSIGN_OR_RETURN(auto analysis, TritonFusionAnalysis::Execute(*computation)); @@ -143,7 +163,10 @@ ENTRY e { p0 = s8[80,115]{1,0} parameter(0) ROOT triton_gemm_r = f32[80,137]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm_r, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: {"block_m":16,"block_n":64,"block_k":32,"split_k":1,"num_stages":1,"num_warps":2}}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":16,"block_n":64,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":2, + "num_ctas":1}}} })"; TritonGemmConfig config(16, 64, 32, 1, 1, 1); ASSERT_THAT(CreateTritonIrAndFileCheck(kHloText, config, EmitMatMul, @@ -188,15 +211,15 @@ CHECK: %[[RHS_ITER_PTR_NEXT:.*]] = tt.advance %[[RHS_ITER_PTR]], [%[[TILE CHECK: %[[CONVERTED:.*]] = arith.sitofp %[[LHS_TILE]] : tensor<16x32xi8> to tensor<16x32xf32> CHECK: %[[TILE_K_LIMIT:.*]] = arith.subi %[[SIZE_K]], %[[BLOCK_K]] : i32 CHECK: %[[K_TILE_IOTA:.*]] = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> -CHECK: %[[K_OFFSETS_1K:.*]] = tt.expand_dims %[[K_TILE_IOTA]] {axis = 0 : i32} : (tensor<32xi32>) -> tensor<1x32xi32> -CHECK: %[[TILE_K_LIMIT_1K:.*]] = tt.splat %[[TILE_K_LIMIT]] : (i32) -> tensor<1x32xi32> +CHECK: %[[K_OFFSETS_1K:.*]] = tt.expand_dims %[[K_TILE_IOTA]] {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32> +CHECK: %[[TILE_K_LIMIT_1K:.*]] = tt.splat %[[TILE_K_LIMIT]] : i32 -> tensor<1x32xi32> CHECK: %[[LHS_INBOUNDS_1K:.*]] = arith.cmpi slt, %[[K_OFFSETS_1K]], %[[TILE_K_LIMIT_1K]] : tensor<1x32xi32> -CHECK: %[[LHS_INBOUNDS_MK:.*]] = tt.broadcast %[[LHS_INBOUNDS_1K]] : (tensor<1x32xi1>) -> tensor<16x32xi1> +CHECK: %[[LHS_INBOUNDS_MK:.*]] = tt.broadcast %[[LHS_INBOUNDS_1K]] : tensor<1x32xi1> -> tensor<16x32xi1> CHECK: %[[LHS_MASKED:.*]] = arith.select %[[LHS_INBOUNDS_MK]], %[[CONVERTED]], %[[ZERO_MK]] -CHECK: %[[K_OFFSETS_K1:.*]] = tt.expand_dims %[[K_TILE_IOTA]] {axis = 1 : i32} : (tensor<32xi32>) -> tensor<32x1xi32> -CHECK: %[[TILE_K_LIMIT_K1:.*]] = tt.splat %[[TILE_K_LIMIT]] : (i32) -> tensor<32x1xi32> +CHECK: %[[K_OFFSETS_K1:.*]] = tt.expand_dims %[[K_TILE_IOTA]] {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32> +CHECK: %[[TILE_K_LIMIT_K1:.*]] = tt.splat %[[TILE_K_LIMIT]] : i32 -> tensor<32x1xi32> CHECK: %[[RHS_INBOUNDS_K1:.*]] = arith.cmpi slt, %[[K_OFFSETS_K1]], %[[TILE_K_LIMIT_K1]] : tensor<32x1xi32> -CHECK: %[[RHS_INBOUNDS_KN:.*]] = tt.broadcast %[[RHS_INBOUNDS_K1]] : (tensor<32x1xi1>) -> tensor<32x64xi1> +CHECK: %[[RHS_INBOUNDS_KN:.*]] = tt.broadcast %[[RHS_INBOUNDS_K1]] : tensor<32x1xi1> -> tensor<32x64xi1> CHECK: %[[RHS_MASKED:.*]] = arith.select %[[RHS_INBOUNDS_KN]], %[[RHS_TILE]], %[[ZERO_KN]] : tensor<32x64xi1>, tensor<32x64xf32> CHECK: %[[ACC_NEXT:.*]] = tt.dot %[[LHS_MASKED]], %[[RHS_MASKED]], %[[ACC]] CHECK: scf.yield %[[LHS_ITER_PTR_NEXT]], %[[RHS_ITER_PTR_NEXT]], %[[ACC_NEXT]] : !tt.ptr, 1>, !tt.ptr, 1>, tensor<16x64xf32> @@ -675,7 +698,8 @@ ENTRY e { backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: { "block_m":16,"block_n":16,"block_k":16, - "split_k":1,"num_stages":1,"num_warps":1 + "split_k":1,"num_stages":1,"num_warps":1, + "num_ctas":1 } } } @@ -712,7 +736,10 @@ ENTRY e { parameter_2 = f32[2,10,256]{2,1,0} parameter(2) ROOT dot = f32[2,3,384]{2,1,0} fusion(parameter_0, parameter_1, parameter_2), kind=kCustom, calls=triton_gemm, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: {"block_m":16,"block_n":64,"block_k":32,"split_k":1,"num_stages":1,"num_warps":2}}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":16,"block_n":64,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":2, + "num_ctas":1}}} })"; TritonGemmConfig config(16, 64, 32, 1, 1, 2); @@ -753,7 +780,9 @@ ENTRY e { ROOT triton_gemm_r = f32[80,16]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm_r, backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: - {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":2}}} + {"block_m":32,"block_n":32,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":2, + "num_ctas":1}}} })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr verified_module, ParseAndReturnVerifiedModule(kHloText)); @@ -809,7 +838,9 @@ ENTRY e { ROOT triton_gemm_r = f32[80,16]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm_r, backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: - {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":2}}} + {"block_m":32,"block_n":32,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":2, + "num_ctas":1}}} })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr verified_module, ParseAndReturnVerifiedModule(kHloText)); @@ -883,7 +914,8 @@ ENTRY entry { EXPECT_GT(result.shmem_bytes, dev_info.shared_memory_per_block()); } -TEST_F(TritonGemmTest, WorksWhenKIsDivisibleByBlockKButNotByBlockKTimesSplitK) { +TEST_F(TritonGemmTestWithSplitK, + WorksWhenKIsDivisibleByBlockKButNotByBlockKTimesSplitK) { // The condition mentioned in the test name is fulfilled by // GemmKey(16, 64, 256, 8, 1, 4), which was part of the default configs for // Ampere at the time of the addition of this test case. @@ -971,10 +1003,11 @@ TEST_F(TritonGemmTest, SplitLhsNoncontractingTransposeRhs) { HloModule t ENTRY e { - p0 = s8[3,122,96,12]{3,2,1,0} parameter(0) + p0 = pred[3,122,96,12]{3,2,1,0} parameter(0) cp0 = f16[3,122,96,12]{3,2,1,0} convert(p0) - p1 = f16[1,5,122]{2,1,0} parameter(1) - ROOT _ = f16[3,96,12,1,5]{4,3,2,1,0} dot(cp0, p1), + p1 = pred[1,5,122]{2,1,0} parameter(1) + cp1 = f16[1,5,122]{2,1,0} convert(p1) + ROOT _ = f16[3,96,12,1,5]{4,3,2,1,0} dot(cp0, cp1), lhs_contracting_dims={1}, rhs_contracting_dims={2} })"; @@ -987,7 +1020,7 @@ ENTRY e { ; CHECK-SAME: "block_m": )"); - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/0, /*arel=*/0})); } TEST_F(TritonGemmTest, SplitLhsNoncontracting) { @@ -1403,7 +1436,8 @@ ENTRY e { backend_config={"fusion_backend_config": {"kind":"__triton_gemm", "triton_gemm_config":{"block_m":"16","block_n":"64", "block_k":"16","split_k":"1", - "num_stages":"3","num_warps":"2"}}} + "num_stages":"3","num_warps":"2", + "num_ctas":"1"}}} })") .status()); } @@ -1608,7 +1642,7 @@ ENTRY e { })"; MatchOptimizedHlo(kHloText, R"( -; CHECK: fused_computation +; CHECK: fused_subtract ; CHECK: negate ; CHECK: negate ; CHECK: ROOT @@ -1832,7 +1866,9 @@ ENTRY e { p1 = f16[92,75] parameter(1) ROOT _ = f16[92,67] fusion(p0, p1), kind=kCustom, calls=triton_dot, backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: - {"block_m":32,"block_n":64,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1}}} + {"block_m":32,"block_n":64,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":1, + "num_ctas":1}}} })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -2171,7 +2207,8 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -TEST_F(TritonGemmTest, SplitKDoesNotBreakSlicedFragmentedContractingDimension) { +TEST_F(TritonGemmTestWithSplitK, + SplitKDoesNotBreakSlicedFragmentedContractingDimension) { const std::string kHloText = R"( ENTRY e { p0 = f16[16,8,128]{2,1,0} parameter(0) @@ -2426,7 +2463,11 @@ ENTRY e { p0 = s8[101,202]{1,0} parameter(0) p1 = f32[202,303]{1,0} parameter(1) ROOT _ = f32[101,303] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: {"block_m":16,"block_n":64,"block_k":32,"split_k":1,"num_stages":3,"num_warps":8}}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":16,"block_n":64,"block_k":32, + "split_k":1,"num_stages":3,"num_warps":8, + "num_ctas":1}}} })"; const char* hlo_text_triton = R"( @@ -2444,7 +2485,10 @@ ENTRY e { p0 = s8[101,202]{1,0} parameter(0) p1 = f32[202,303]{1,0} parameter(1) ROOT _ = f32[101,303] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: {"block_m":32,"block_n":128,"block_k":32,"split_k":1,"num_stages":2,"num_warps":4}}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":32,"block_n":128,"block_k":32, + "split_k":1,"num_stages":2,"num_warps":4, + "num_ctas":1}}} })"; EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_ref, hlo_text_triton, @@ -2480,7 +2524,10 @@ ENTRY e { p0 = f16[5,7]{1,0} parameter(0) p1 = f16[7,33]{1,0} parameter(1) ROOT _ = f16[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1}}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":32,"block_n":32,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":1, + "num_ctas":1}}} } )"; @@ -2517,7 +2564,10 @@ ENTRY e { p0 = f32[5,7]{1,0} parameter(0) p1 = f32[7,33]{1,0} parameter(1) ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1}}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":32,"block_n":32,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":1, + "num_ctas":1}}} } )"; @@ -2559,7 +2609,10 @@ ENTRY e { arg0 = bf16[512,16]{1,0} parameter(0) arg1 = bf16[512,256]{1,0} parameter(1) ROOT _ = bf16[16,256]{1,0} fusion(arg0, arg1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: {"block_m":128,"block_n":32,"block_k":16,"split_k":1,"num_stages":2,"num_warps":4}}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":128,"block_n":32,"block_k":16, + "split_k":1,"num_stages":2,"num_warps":4, + "num_ctas":1}}} } )"; @@ -2595,7 +2648,10 @@ ENTRY e { p0 = s8[332,441]{1,0} parameter(0) p1 = f16[441,39]{1,0} parameter(1) ROOT _ = f16[332,39]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: {"block_m":128,"block_n":128,"block_k":128,"split_k":1,"num_stages":2,"num_warps":32}}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":128,"block_n":128,"block_k":128, + "split_k":1,"num_stages":2,"num_warps":32, + "num_ctas":1}}} })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, @@ -2613,12 +2669,14 @@ ENTRY e { ->root_instruction() ->backend_config()); const FusionBackendConfig& config = gpu_config.fusion_backend_config(); + TF_ASSERT_OK_AND_ASSIGN( + TritonGemmConfig triton_gemm_config, + TritonGemmConfig::FromProto(config.triton_gemm_config())); TF_ASSERT_OK_AND_ASSIGN( const auto result, TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation), "test_fn", triton_dot_computation, kTritonGemmFusionKind, - GetCudaComputeCapability(), dev_info, - TritonGemmConfig::FromProto(config.triton_gemm_config()), + GetCudaComputeCapability(), dev_info, triton_gemm_config, &llvm_module, &EmitMatMul, mlir_context)); // The config is chosen so that the used memory size is slightly above the // 48 kB boundary of standard / optin shared memory so that any GPU that @@ -2642,7 +2700,10 @@ ENTRY e { p0 = s8[332,441]{1,0} parameter(0) p1 = f16[441,39]{1,0} parameter(1) ROOT _ = f16[332,39]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":4}}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":32,"block_n":32,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":4, + "num_ctas":1}}} })"; EXPECT_TRUE(RunAndCompareTwoModules(kHloTextLowShmem, kHloTextOptinShmem, @@ -2678,7 +2739,10 @@ ENTRY e { arg0 = f16[128,32]{1,0} parameter(0) arg1 = f16[64,32]{1,0} parameter(1) ROOT _ = f16[128,64]{1,0} fusion(arg0, arg1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: {"block_m":128,"block_n":32,"block_k":64,"split_k":1,"num_stages":2,"num_warps":4}}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":128,"block_n":32,"block_k":64, + "split_k":1,"num_stages":2,"num_warps":4, + "num_ctas":1}}} } )"; @@ -2715,7 +2779,10 @@ ENTRY e { arg0 = f32[64,128]{1,0} parameter(0) arg1 = f32[1024,64]{1,0} parameter(1) ROOT _ = f32[128,1024]{1,0} fusion(arg0, arg1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: {"block_m":32,"block_n":32,"block_k":64,"split_k":1,"num_stages":2,"num_warps":4}}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":32,"block_n":32,"block_k":64, + "split_k":1,"num_stages":2,"num_warps":4, + "num_ctas":1}}} } )"; @@ -2763,7 +2830,10 @@ ENTRY e { p0 = s8[144,256]{1,0} parameter(0) p1 = bf16[256,122]{1,0} parameter(1) ROOT _ = bf16[144,122]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: {"block_m":64,"block_n":64,"block_k":64,"split_k":1,"num_stages":1,"num_warps":2}}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":64,"block_n":64,"block_k":64, + "split_k":1,"num_stages":1,"num_warps":2, + "num_ctas":1}}} } )"; @@ -2794,7 +2864,10 @@ ENTRY e { bitcast.4 = s8[480,120]{1,0} bitcast(p0) ROOT triton_gemm_r = bf16[480,16]{1,0} fusion(bitcast.4, p1), kind=kCustom, calls=triton_gemm_r, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: {"block_m":64,"block_n":32,"block_k":64,"split_k":1,"num_stages":4,"num_warps":4}}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":64,"block_n":32,"block_k":64, + "split_k":1,"num_stages":4,"num_warps":4, + "num_ctas":1}}} })"; const std::string hlo_text_splitk = R"( @@ -2833,7 +2906,10 @@ ENTRY e { bitcast.4 = s8[480,120]{1,0} bitcast(p0) triton_gemm_r = bf16[4,480,16]{2,1,0} fusion(bitcast.4, p1), kind=kCustom, calls=triton_gemm_r, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: {"block_m":32,"block_n":32,"block_k":128,"split_k":4,"num_stages":1,"num_warps":4}}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":32,"block_n":32,"block_k":128, + "split_k":4,"num_stages":1,"num_warps":4, + "num_ctas":1}}} ROOT fusion.1 = bf16[480,16]{1,0} fusion(triton_gemm_r), kind=kLoop, calls=fused_computation })"; @@ -2865,7 +2941,10 @@ ENTRY e { tmp_0 = bf16[1,1,800,5,128]{4,3,2,1,0} parameter(1) ROOT triton_gemm_dot.24 = f32[5,128,700]{2,1,0} fusion(tmp_3, tmp_0), kind=kCustom, calls=triton_gemm_dot.24, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: {"block_m":64,"block_n":32,"block_k":64,"split_k":1,"num_stages":2,"num_warps":8}}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":64,"block_n":32,"block_k":64, + "split_k":1,"num_stages":2,"num_warps":8, + "num_ctas":1}}} })"; const std::string kHloTextSplitK = R"( @@ -2893,7 +2972,10 @@ ENTRY e { tmp_0 = bf16[1,1,800,5,128]{4,3,2,1,0} parameter(1) triton_gemm_dot.24 = f32[8,5,128,700]{3,2,1,0} fusion(tmp_3, tmp_0), kind=kCustom, calls=triton_gemm_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: {"block_m":64,"block_n":32,"block_k":64,"split_k":8,"num_stages":1,"num_warps":4}}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":64,"block_n":32,"block_k":64, + "split_k":8,"num_stages":1,"num_warps":4, + "num_ctas":1}}} constant = f32[] constant(0) ROOT reduce = f32[5,128,700]{2,1,0} reduce(triton_gemm_dot.24, constant), dimensions={0}, to_apply=add })"; @@ -2926,7 +3008,10 @@ ENTRY entry { parameter_1.1 = bf16[16,4,128]{2,1,0} parameter(1) ROOT triton_gemm_dot.5316 = bf16[16,96]{1,0} fusion(bitcast.6, parameter_1.1), kind=kCustom, calls=triton_gemm_dot.5316, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: {"block_m":32,"block_n":32,"block_k":256,"split_k":1,"num_stages":1,"num_warps":4}}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":32,"block_n":32,"block_k":256, + "split_k":1,"num_stages":1,"num_warps":4, + "num_ctas":1}}} })"; const std::string kHloTextSplitK = R"( @@ -2966,7 +3051,10 @@ ENTRY entry { parameter_1.1 = bf16[16,4,128]{2,1,0} parameter(1) triton_gemm_dot.5316 = bf16[16,16,96]{2,1,0} fusion(bitcast.6, parameter_1.1), kind=kCustom, calls=triton_gemm_dot.5316, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: {"block_m":64,"block_n":32,"block_k":32,"split_k":16,"num_stages":1,"num_warps":4}}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":64,"block_n":32,"block_k":32, + "split_k":16,"num_stages":1,"num_warps":4, + "num_ctas":1}}} ROOT fusion.1 = bf16[16,96]{1,0} fusion(triton_gemm_dot.5316), kind=kLoop, calls=fused_computation })"; @@ -2998,7 +3086,11 @@ triton_gemm_dot.clone { ENTRY entry_computation { p0 = s8[3,129,5,32]{3,2,1,0} parameter(0) p1 = f16[16,129]{1,0} parameter(1) - ROOT fusion = f16[480,16]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm_dot.clone, backend_config={"fusion_backend_config": {"kind":"__triton_gemm","triton_gemm_config":{"block_m":"32","block_n":"32","block_k":"256","split_k":"1","num_stages":"1","num_warps":"4"}}} + ROOT fusion = f16[480,16]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm_dot.clone, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm", + "triton_gemm_config":{"block_m":"32","block_n":"32","block_k":"256", + "split_k":"1","num_stages":"1","num_warps":"4", + "num_ctas":"1"}}} } )"; @@ -3039,7 +3131,11 @@ fused_computation { ENTRY entry_computation { p0 = s8[3,129,5,32]{3,2,1,0} parameter(0) p1 = f16[16,129]{1,0} parameter(1) - fusion = f32[2,480,16]{2,1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm_dot.clone, backend_config={"fusion_backend_config": {"kind":"__triton_gemm","triton_gemm_config":{"block_m":"128","block_n":"128","block_k":"64","split_k":"2","num_stages":"1","num_warps":"8"}}} + fusion = f32[2,480,16]{2,1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm_dot.clone, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm", + "triton_gemm_config":{"block_m":"128","block_n":"128","block_k":"64", + "split_k":"2","num_stages":"1","num_warps":"8", + "num_ctas":"1"}}} ROOT fusion.1 = f16[480,16]{1,0} fusion(fusion), kind=kLoop, calls=fused_computation } )"; @@ -3067,7 +3163,10 @@ ENTRY entry_computation { p1 = f16[1,1023,128]{2,1,0} parameter(1) ROOT triton_gemm_dot.7103 = f16[1,8,4,128]{3,2,1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm_dot.7103_computation.clone, - backend_config={"fusion_backend_config": {"kind":"__triton_gemm","triton_gemm_config":{"block_m":"128","block_n":"128","block_k":"32","split_k":"1","num_stages":"4","num_warps":"4"}}} + backend_config={"fusion_backend_config": {"kind":"__triton_gemm", + "triton_gemm_config":{"block_m":"128","block_n":"128","block_k":"32", + "split_k":"1","num_stages":"4","num_warps":"4", + "num_ctas":"1"}}} } )"; @@ -3115,7 +3214,12 @@ fused_computation.1 { ENTRY entry_computation { p0 = f16[1,8,4,1023]{3,2,1,0} parameter(0) p1 = f16[1,1023,128]{2,1,0} parameter(1) - triton_gemm_dot.7103 = f16[8,1,8,4,128]{4,3,2,1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm_dot.7103_computation.clone, backend_config={"fusion_backend_config": {"kind":"__triton_gemm","triton_gemm_config":{"block_m":"16","block_n":"128","block_k":"32","split_k":"8","num_stages":"1","num_warps":"4"}}} + triton_gemm_dot.7103 = f16[8,1,8,4,128]{4,3,2,1,0} fusion(p0, p1), kind=kCustom, + calls=triton_gemm_dot.7103_computation.clone, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm", + "triton_gemm_config":{"block_m":"16","block_n":"128","block_k":"32", + "split_k":"8","num_stages":"1","num_warps":"4", + "num_ctas":"1"}}} ROOT fusion.1 = f16[1,8,4,128]{3,2,1,0} fusion(triton_gemm_dot.7103), kind=kLoop, calls=fused_computation.1 } )"; @@ -3141,7 +3245,12 @@ triton_gemm_dot.7103_computation.clone { ENTRY entry_computation { p0 = f16[1,8,4,1019]{3,2,1,0} parameter(0) p1 = f16[1,1019,128]{2,1,0} parameter(1) - ROOT triton_gemm_dot.7103 = f16[1,8,4,128]{3,2,1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm_dot.7103_computation.clone, backend_config={"fusion_backend_config": {"kind":"__triton_gemm","triton_gemm_config":{"block_m":"32","block_n":"32","block_k":"256","split_k":"1","num_stages":"1","num_warps":"4"}}} + ROOT triton_gemm_dot.7103 = f16[1,8,4,128]{3,2,1,0} fusion(p0, p1), kind=kCustom, + calls=triton_gemm_dot.7103_computation.clone, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm", + "triton_gemm_config":{"block_m":"32","block_n":"32","block_k":"256", + "split_k":"1","num_stages":"1","num_warps":"4", + "num_ctas":"1"}}} } )"; @@ -3189,7 +3298,12 @@ fused_computation.1 { ENTRY entry_computation { p0 = f16[1,8,4,1019]{3,2,1,0} parameter(0) p1 = f16[1,1019,128]{2,1,0} parameter(1) - triton_gemm_dot.7103 = f16[16,1,8,4,128]{4,3,2,1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm_dot.7103_computation.clone, backend_config={"fusion_backend_config": {"kind":"__triton_gemm","triton_gemm_config":{"block_m":"64","block_n":"32","block_k":"32","split_k":"16","num_stages":"1","num_warps":"4"}}} + triton_gemm_dot.7103 = f16[16,1,8,4,128]{4,3,2,1,0} fusion(p0, p1), kind=kCustom, + calls=triton_gemm_dot.7103_computation.clone, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm", + "triton_gemm_config":{"block_m":"64","block_n":"32","block_k":"32", + "split_k":"16","num_stages":"1","num_warps":"4", + "num_ctas":"1"}}} ROOT fusion.1 = f16[1,8,4,128]{3,2,1,0} fusion(triton_gemm_dot.7103), kind=kLoop, calls=fused_computation.1 } )"; @@ -3217,7 +3331,10 @@ ENTRY e { p1 = f32[32,50,104]{2,1,0} parameter(1) ROOT triton_gemm_dot.6 = f32[32,50,26]{2,0,1} fusion(p0, p1), kind=kCustom, calls=triton_gemm_dot.6, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: {"block_m":64,"block_n":16,"block_k":32,"split_k":1,"num_stages":1,"num_warps":4}}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":64,"block_n":16,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":4, + "num_ctas":1}}} })"; const std::string kHloTextRef = R"( @@ -3243,7 +3360,10 @@ ENTRY e { %parameter_1 = f32[32,50,104]{2,1,0} parameter(1) %triton_gemm_dot.127 = f32[32,50,26]{2,1,0} fusion(%parameter_0, %parameter_1), kind=kCustom, calls=%triton_gemm_dot.127, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: {"block_m":32,"block_n":128,"block_k":64,"split_k":1,"num_stages":2,"num_warps":4}}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":32,"block_n":128,"block_k":64, + "split_k":1,"num_stages":2,"num_warps":4, + "num_ctas":1}}} ROOT %fusion.1 = f32[32,50,26]{2,0,1} fusion(%triton_gemm_dot.127), kind=kLoop, calls=%fused_computation })"; @@ -3272,7 +3392,8 @@ ENTRY e { backend_config={"fusion_backend_config": {"kind":"__triton_gemm", "triton_gemm_config":{"block_m":"16","block_n":"64", "block_k":"16","split_k":"1", - "num_stages":"3","num_warps":"2"}}} + "num_stages":"3","num_warps":"2", + "num_ctas":"1"}}} })"; const std::string kHloTextRef = R"( @@ -3312,7 +3433,8 @@ ENTRY e { backend_config={"fusion_backend_config": {"kind":"__triton_gemm", "triton_gemm_config":{"block_m":"16","block_n":"64", "block_k":"16","split_k":"1", - "num_stages":"3","num_warps":"2"}}} + "num_stages":"3","num_warps":"2", + "num_ctas":"1"}}} })"; const std::string kHloTextRef = R"( @@ -3381,7 +3503,8 @@ ENTRY e { backend_config={"fusion_backend_config": {"kind":"__triton_gemm", "triton_gemm_config":{"block_m":"64","block_n":"64", "block_k":"64","split_k":"1", - "num_stages":"1","num_warps":"4"}}} + "num_stages":"1","num_warps":"4", + "num_ctas":"1"}}} })"; const std::string kHloTextRef = R"( @@ -3467,7 +3590,8 @@ ENTRY e { backend_config={"fusion_backend_config": {"kind":"__triton_gemm", "triton_gemm_config":{"block_m":"32","block_n":"16", "block_k":"32","split_k":"1", - "num_stages":"1","num_warps":"4"}}} + "num_stages":"1","num_warps":"4", + "num_ctas":"1"}}} })"; const std::string kHloTextRef = R"( @@ -3525,7 +3649,9 @@ ENTRY e { ROOT r = f16[9,32]{1,0} fusion(p0, p1, p2), kind=kCustom, calls=triton_dot, backend_config={"fusion_backend_config": {kind: "__triton_gemm", - triton_gemm_config: {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":2}}} + triton_gemm_config: {"block_m":32,"block_n":32,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":2, + "num_ctas":"1"}}} })"; const std::string kHloTextRef = R"( @@ -3661,10 +3787,328 @@ ENTRY e { ->fused_instructions_computation() ->root_instruction(), GmockMatch(m::Dot(m::Op().WithShape(BF16, {16, 32}, {1, 0}), - m::Op().WithShape(BF16, {40, 32}, {1, 0})) + m::Op().WithShape(BF16, {32, 40}, {1, 0})) .WithShape(BF16, {16, 40}, {1, 0}))); } +class Triton6xBF16GemmTest : public TritonFilecheckTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); + // Enable triton fusion for all supported gemms. + debug_options.set_xla_gpu_triton_gemm_any(true); + // Do not fall back to cuBLAS, we are testing Triton. + debug_options.set_xla_gpu_cublas_fallback(false); + // Do not autotune split-k by default, since this prevents deterministically + // matching the optimized HLO. + debug_options.set_xla_gpu_enable_split_k_autotuning(false); + // Enable bf16_6way gemm to compute F32 matmul. + debug_options.set_xla_gpu_enable_bf16_6way_gemm(true); + return debug_options; + } +}; + +TEST_F(Triton6xBF16GemmTest, Emit6xBF16GemmWhenBothInputsAreF32) { + if (!GetCudaComputeCapability().IsAtLeastAmpere()) { + GTEST_SKIP() << "No BF16 before Ampere."; + } + const char* kHloText = R"( +HloModule t + +triton_dot { + p0 = f32[5,7] parameter(0) + p1 = f32[7,33] parameter(1) + ROOT dot = f32[5,33] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f32[5,7]{1,0} parameter(0) + p1 = f32[7,33]{1,0} parameter(1) + ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}} +} +)"; + TritonGemmConfig config(32, 32, 32, 1, 1, 1); + ASSERT_THAT( + CreateTritonIrAndFileCheck(kHloText, config, EmitMatMul, "triton_dot", R"( +CHECK: tt.func @triton_fn(%[[LHS:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[RHS:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[OUT:.*]]: !tt.ptr {tt.divisibility = 16 : i32}) { +CHECK: %[[INFINITY:.*]] = arith.constant dense<0x7F800000> : tensor<32x32xf32> +CHECK: %[[C_MASK:.*]] = arith.constant dense<-65536> : tensor<32x32xi32> +CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<32x32xf32> +CHECK: %[[CAST_I32:.*]] = tt.bitcast %{{.*}} : tensor<32x32xf32> -> tensor<32x32xi32> +CHECK: %[[EXTRACT_HI:.*]] = arith.andi %[[CAST_I32]], %[[C_MASK]] : tensor<32x32xi32> +CHECK: %[[CAST_HI:.*]] = tt.bitcast %[[EXTRACT_HI]] : tensor<32x32xi32> -> tensor<32x32xf32> +CHECK: %[[TRUNC_TO_BF16:.*]] = arith.truncf %[[CAST_HI]] : tensor<32x32xf32> to tensor<32x32xbf16> +CHECK-COUNT-5: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} {allowTF32 = false, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> +CHECK: %[[ABS:.*]] = math.absf +CHECK: %[[CMP:.*]] = arith.cmpf ogt, %[[INFINITY]], %[[ABS]] : tensor<32x32xf32> +CHECK: %[[SELECT:.*]] = arith.select %[[CMP]], %{{.*}}, %[[C0]] : tensor<32x32xi1>, tensor<32x32xf32> +CHECK: %[[DOT_LAST:.*]] = tt.dot %{{.*}}, %{{.*}}, %[[SELECT]] {allowTF32 = false, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> +CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf32> + )"), + tsl::testing::IsOkAndHolds(true)); + + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/1e-6, + /*arel=*/1e-6})); +} + +TEST_F(Triton6xBF16GemmTest, Triton6xBF16GemmWorksForLongContractingDimension) { + if (!GetCudaComputeCapability().IsAtLeastAmpere()) { + GTEST_SKIP() << "No BF16 before Ampere."; + } + const char* kHloText = R"( +HloModule t + +triton_dot { + p0 = f32[5,2048] parameter(0) + p1 = f32[2048,33] parameter(1) + ROOT dot = f32[5,33] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f32[5,2048]{1,0} parameter(0) + p1 = f32[2048,33]{1,0} parameter(1) + ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":64,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":4, "num_ctas":1}}} +} +)"; + TritonGemmConfig config(64, 32, 32, 1, 1, 4); + ASSERT_THAT( + CreateTritonIrAndFileCheck(kHloText, config, EmitMatMul, "triton_dot", R"( +CHECK: tt.func @triton_fn(%[[LHS:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[RHS:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[OUT:.*]]: !tt.ptr {tt.divisibility = 16 : i32}) { +CHECK-COUNT-6: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} {allowTF32 = false, maxNumImpreciseAcc = 0 : i32} : tensor<64x32xbf16> * tensor<32x32xbf16> -> tensor<64x32xf32> + )"), + tsl::testing::IsOkAndHolds(true)); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/1e-5, + /*arel=*/1e-5})); +} + +TEST_F(Triton6xBF16GemmTest, Triton6xBF16GemmCanHandleInfinity) { + if (!GetCudaComputeCapability().IsAtLeastAmpere()) { + GTEST_SKIP() << "No BF16 before Ampere."; + } + const char* kHloText = R"( +HloModule t + +triton_dot { + p0 = f32[2,2] parameter(0) + p1 = f32[2,2] parameter(1) + ROOT dot = f32[2,2] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f32[2,2]{1, 0} parameter(0) + p1 = f32[2,2]{1, 0} parameter(1) + ROOT _ = f32[2,2] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1, "num_ctas":1}}} +} +)"; + TritonGemmConfig config(32, 32, 32, 1, 1, 1); + ASSERT_THAT( + CreateTritonIrAndFileCheck(kHloText, config, EmitMatMul, "triton_dot", R"( +CHECK: tt.func @triton_fn(%[[LHS:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[RHS:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[OUT:.*]]: !tt.ptr {tt.divisibility = 16 : i32}) { +CHECK-COUNT-6: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} {allowTF32 = false, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> + )"), + tsl::testing::IsOkAndHolds(true)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + std::vector arguments(2); + arguments[0] = + LiteralUtil::CreateR2({{+std::numeric_limits::infinity(), + +std::numeric_limits::infinity()}, + {+std::numeric_limits::infinity(), + +std::numeric_limits::infinity()}}); + arguments[1] = LiteralUtil::CreateR2({{1.0f, 1.0f}, {1.0f, 1.0f}}); + std::vector argument_ptrs; + absl::c_transform( + arguments, std::back_inserter(argument_ptrs), + [](const Literal& literal) { return const_cast(&literal); }); + + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), argument_ptrs, + ErrorSpec{/*aabs=*/0, /*arel=*/0})); +} + +TEST_F(Triton6xBF16GemmTest, Triton6xBF16GemmCanHandleNaN) { + if (!GetCudaComputeCapability().IsAtLeastAmpere()) { + GTEST_SKIP() << "No BF16 before Ampere."; + } + const char* kHloText = R"( +HloModule t + +triton_dot { + p0 = f32[2,2] parameter(0) + p1 = f32[2,2] parameter(1) + ROOT dot = f32[2,2] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f32[2,2]{1, 0} parameter(0) + p1 = f32[2,2]{1, 0} parameter(1) + ROOT _ = f32[2,2] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1, "num_ctas":1}}} +} +)"; + TritonGemmConfig config(32, 32, 32, 1, 1, 1); + ASSERT_THAT( + CreateTritonIrAndFileCheck(kHloText, config, EmitMatMul, "triton_dot", R"( +CHECK: tt.func @triton_fn(%[[LHS:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[RHS:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[OUT:.*]]: !tt.ptr {tt.divisibility = 16 : i32}) { +CHECK-COUNT-6: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} {allowTF32 = false, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> + )"), + tsl::testing::IsOkAndHolds(true)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + std::vector arguments(2); + arguments[0] = + LiteralUtil::CreateR2({{std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}}); + arguments[1] = LiteralUtil::CreateR2( + {{1.0f, +std::numeric_limits::infinity()}, + {1.0f, +std::numeric_limits::infinity()}}); + std::vector argument_ptrs; + absl::c_transform( + arguments, std::back_inserter(argument_ptrs), + [](const Literal& literal) { return const_cast(&literal); }); + + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), argument_ptrs, + ErrorSpec{/*aabs=*/0, /*arel=*/0})); +} + +// Test case shows that why we truncate the middle term instead of rounding. +// If we round the middle term, the splitted terms may disagree in sign. This +// could result in wrong results for extreme values. +// For example, consider: +// x = -3.40282347e+38 +// If we round the middle term, its decomposition would be: +// x_hi: -3.38953139e+38 +// x_mid: -1.3240357e+36 +// x_lo: 5.17201445e+33 +// The result of x*x would be NaN instead of positive infinity. +TEST_F(Triton6xBF16GemmTest, Triton6xBF16GemmWorksForExtremeInputs) { + if (!GetCudaComputeCapability().IsAtLeastAmpere()) { + GTEST_SKIP() << "No BF16 before Ampere."; + } + const char* kHloText = R"( +HloModule t + +triton_dot { + p0 = f32[2,2] parameter(0) + p1 = f32[2,2] parameter(1) + ROOT dot = f32[2,2] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f32[2,2]{1, 0} parameter(0) + p1 = f32[2,2]{1, 0} parameter(1) + ROOT _ = f32[2,2] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1, "num_ctas":1}}} +} +)"; + TritonGemmConfig config(32, 32, 32, 1, 1, 1); + ASSERT_THAT( + CreateTritonIrAndFileCheck(kHloText, config, EmitMatMul, "triton_dot", R"( +CHECK: tt.func @triton_fn(%[[LHS:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[RHS:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[OUT:.*]]: !tt.ptr {tt.divisibility = 16 : i32}) { +CHECK-COUNT-6: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} {allowTF32 = false, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> + )"), + tsl::testing::IsOkAndHolds(true)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + std::vector arguments(2); + arguments[0] = LiteralUtil::CreateR2( + {{0x1.0103p72f, 1.0f}, {-0x1.0103p72f, 1.0f}}); + arguments[1] = LiteralUtil::CreateR2( + {{0x1.0103p72f, 1.0f}, {-0x1.0103p72f, 1.0f}}); + std::vector argument_ptrs; + absl::c_transform( + arguments, std::back_inserter(argument_ptrs), + [](const Literal& literal) { return const_cast(&literal); }); + + EXPECT_TRUE( + RunAndCompareNoHloPasses(std::move(module), argument_ptrs, + ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6})); +} + +TEST_F(Triton6xBF16GemmTest, ShouldNotEmit6xBF16GemmForPreAmpere) { + if (GetCudaComputeCapability().IsAtLeastAmpere()) { + GTEST_SKIP() << "6xBF16Gemm should be emitted post-Ampere."; + } + const char* kHloText = R"( +HloModule t + +triton_dot { + p0 = f32[5,7] parameter(0) + p1 = f32[7,33] parameter(1) + ROOT dot = f32[5,33] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f32[5,7]{1,0} parameter(0) + p1 = f32[7,33]{1,0} parameter(1) + ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr verified_module, + ParseAndReturnVerifiedModule(kHloText)); + + CompileAndOptionallyVerifyPtx(std::move(verified_module), + R"( +CHECK-NOT: mma +CHECK: selp.f32 +CHECK: st.shared.f32 +CHECK: ld.shared.v4.f32 +CHECK: fma.rn.f32 +CHECK: st.shared.f32 +)"); +} + +TEST_F(Triton6xBF16GemmTest, Emit6xBF16GemmEndToEnd) { + const char* kHloText = R"( +HloModule t + +ENTRY e { + p0 = f32[5,32] parameter(0) + p1 = f32[32,7] parameter(1) + ROOT dot = f32[5,7] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr verified_module, + ParseAndReturnVerifiedModule(kHloText)); + if (GetCudaComputeCapability().IsAtLeastAmpere()) { + CompileAndOptionallyVerifyPtx(std::move(verified_module), + R"( +CHECK: mma.sync.aligned.{{.*}}.row.col.f32.bf16.bf16.f32 +CHECK-NOT: mma.sync.aligned.{{.*}}.row.col.f32.tf32.tf32.f32 +)"); + } else { + CompileAndOptionallyVerifyPtx(std::move(verified_module), + R"( +CHECK-NOT: mma +)"); + } + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-6, + /*arel=*/1e-6})); +} } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index 390ea8a60df831..ec27edd5fc1406 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -33,6 +32,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -47,6 +47,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/IR/Argument.h" +#include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" @@ -62,14 +63,18 @@ limitations under the License. #include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project #include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project +#include "mlir/Dialect/MemRef/Transforms/Passes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" // from @llvm-project @@ -111,38 +116,41 @@ limitations under the License. #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/ir_emitter_nested.h" #include "xla/service/gpu/kernel_arguments.h" +#include "xla/service/gpu/kernel_reuse_cache.h" #include "xla/service/gpu/kernels/custom_kernel.h" #include "xla/service/gpu/kernels/topk_custom_kernel.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/matmul_utils.h" -#include "xla/service/gpu/nccl_all_gather_thunk.h" -#include "xla/service/gpu/nccl_all_reduce_thunk.h" -#include "xla/service/gpu/nccl_all_to_all_thunk.h" #include "xla/service/gpu/nccl_api.h" #include "xla/service/gpu/nccl_collective_permute_thunk.h" #include "xla/service/gpu/nccl_collective_thunk.h" #include "xla/service/gpu/nccl_recv_thunk.h" #include "xla/service/gpu/nccl_send_thunk.h" #include "xla/service/gpu/parallel_loop_emitter.h" -#include "xla/service/gpu/runtime3/command_buffer_cmd.h" -#include "xla/service/gpu/runtime3/command_buffer_cmd_emitter.h" -#include "xla/service/gpu/runtime3/command_buffer_thunk.h" -#include "xla/service/gpu/runtime3/conditional_thunk.h" -#include "xla/service/gpu/runtime3/convolution_thunk.h" -#include "xla/service/gpu/runtime3/copy_thunk.h" -#include "xla/service/gpu/runtime3/custom_call_thunk.h" -#include "xla/service/gpu/runtime3/fft_thunk.h" -#include "xla/service/gpu/runtime3/fused_mha_thunk.h" -#include "xla/service/gpu/runtime3/gemm_thunk.h" -#include "xla/service/gpu/runtime3/infeed_thunk.h" -#include "xla/service/gpu/runtime3/kernel_thunk.h" -#include "xla/service/gpu/runtime3/norm_thunk.h" -#include "xla/service/gpu/runtime3/outfeed_thunk.h" -#include "xla/service/gpu/runtime3/replica_id_thunk.h" -#include "xla/service/gpu/runtime3/send_recv_thunk.h" -#include "xla/service/gpu/runtime3/sequential_thunk.h" -#include "xla/service/gpu/runtime3/while_thunk.h" +#include "xla/service/gpu/runtime/command_buffer_cmd.h" +#include "xla/service/gpu/runtime/command_buffer_cmd_emitter.h" +#include "xla/service/gpu/runtime/command_buffer_thunk.h" +#include "xla/service/gpu/runtime/conditional_thunk.h" +#include "xla/service/gpu/runtime/convolution_thunk.h" +#include "xla/service/gpu/runtime/copy_thunk.h" +#include "xla/service/gpu/runtime/custom_call_thunk.h" +#include "xla/service/gpu/runtime/fft_thunk.h" +#include "xla/service/gpu/runtime/fused_mha_thunk.h" +#include "xla/service/gpu/runtime/gemm_thunk.h" +#include "xla/service/gpu/runtime/infeed_thunk.h" +#include "xla/service/gpu/runtime/kernel_thunk.h" +#include "xla/service/gpu/runtime/nccl_all_gather_thunk.h" +#include "xla/service/gpu/runtime/nccl_all_reduce_thunk.h" +#include "xla/service/gpu/runtime/nccl_all_to_all_thunk.h" +#include "xla/service/gpu/runtime/norm_thunk.h" +#include "xla/service/gpu/runtime/outfeed_thunk.h" +#include "xla/service/gpu/runtime/replica_id_thunk.h" +#include "xla/service/gpu/runtime/send_recv_thunk.h" +#include "xla/service/gpu/runtime/sequential_thunk.h" +#include "xla/service/gpu/runtime/wait_for_streams_thunk.h" +#include "xla/service/gpu/runtime/while_thunk.h" #include "xla/service/gpu/thunk.h" +#include "xla/service/gpu/triton_call.h" #include "xla/service/llvm_ir/buffer_assignment_util.h" #include "xla/service/llvm_ir/fused_ir_emitter.h" #include "xla/service/llvm_ir/ir_array.h" @@ -157,11 +165,11 @@ limitations under the License. #include "xla/statusor.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/gpu/gpu_blas_lt.h" +#include "xla/stream_executor/launch_dim.h" #include "xla/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/translate/mhlo_to_hlo/attribute_exporter.h" #include "xla/translate/mhlo_to_hlo/location_exporter.h" #include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" -#include "xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" @@ -169,16 +177,17 @@ limitations under the License. #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" #include "tsl/protobuf/dnn.pb.h" +#include "triton/Dialect/Triton/IR/Dialect.h" #if GOOGLE_CUDA || TF_HIPBLASLT -#include "xla/service/gpu/runtime3/gpublas_lt_matmul_thunk.h" +#include "xla/service/gpu/runtime/gpublas_lt_matmul_thunk.h" #endif // GOOGLE_CUDA || TF_HIPBLASLT #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "xla/service/gpu/ir_emitter_triton.h" -#include "xla/service/gpu/runtime3/cholesky_thunk.h" -#include "xla/service/gpu/runtime3/cub_sort_thunk.h" -#include "xla/service/gpu/runtime3/triangular_solve_thunk.h" +#include "xla/service/gpu/runtime/cholesky_thunk.h" +#include "xla/service/gpu/runtime/cub_sort_thunk.h" +#include "xla/service/gpu/runtime/triangular_solve_thunk.h" #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM namespace xla { @@ -264,6 +273,13 @@ absl::StatusOr AsCudnnBackwardfMHAKind( case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardSoftmaxDropout: return xla::gpu::CudnnfMHAKind::kBackwardSoftmaxDropout; break; + case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: + BackwardScaleMaskSoftmax: + return xla::gpu::CudnnfMHAKind::kBackwardScaleMaskSoftmax; + case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: + BackwardScaleMaskSoftmaxDropout: + return xla::gpu::CudnnfMHAKind::kBackwardScaleMaskSoftmaxDropout; + break; default: return xla::Internal("Unsupported fused_mha_backward_dag_signature"); } @@ -282,54 +298,6 @@ std::unique_ptr IrEmitterUnnested::Create( new IrEmitterUnnested(ir_emitter_context)); } -absl::StatusOr IrEmitterUnnested::GetAllocationSlice( - mlir::Value v) { - return xla::gpu::GetAllocationSlice(v, ir_emitter_context_->allocations(), - nullptr); -} - -absl::StatusOr> -IrEmitterUnnested::GetAllocationSlices(mlir::OperandRange operands) { - std::vector slices; - slices.reserve(operands.size()); - for (mlir::Value operand : operands) { - TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(operand)); - slices.push_back(slice); - } - return slices; -} - -absl::Status IrEmitterUnnested::EmitUnreachable(mlir::Operation* op, - std::string error_message) { - AddThunkToThunkSequence(std::unique_ptr( - new UnreachableThunk(op, std::move(error_message)))); - return absl::OkStatus(); -} - -absl::Status IrEmitterUnnested::EmitConstant(mlir::Operation* op, - const Literal& literal) { - auto get_global = mlir::cast(op); - auto module = get_global->getParentOfType(); - auto global = mlir::cast( - module.lookupSymbol(get_global.getName())); - TF_ASSIGN_OR_RETURN(DenseDataIntermediate content, - LiteralToXlaFormat(literal)); - - int element_bytes = primitive_util::ByteWidth(literal.shape().element_type()); - TF_RET_CHECK(content.span().size() % element_bytes == 0); - // Treat int4 constant as int8 constant with half the number of elements. - int num_elements = content.span().size() / element_bytes; - - int64_t arg_index = - global->getAttrOfType("lmhlo.alloc").getInt(); - int allocation_index = ir_emitter_context_->allocations()[arg_index]->index(); - - ir_emitter_context_->emit_constant(num_elements, element_bytes, - global.getSymName(), allocation_index, - std::move(content), &b_); - return absl::OkStatus(); -} - absl::Status IrEmitterUnnested::EmitConstant( const HloConstantInstruction* instr) { TF_ASSIGN_OR_RETURN(DenseDataIntermediate content, @@ -371,37 +339,6 @@ static ConditionalThunkConfig GetConditionalThunkConfig( return config; } -absl::Status IrEmitterUnnested::EmitConditional( - mlir::Operation* op, - const absl::flat_hash_map& - hlo_for_lmhlo) { - if (ir_emitter_context_->emit_ir_from_hlo()) - return EmitConditional(hlo_for_lmhlo.at(op)); - - auto conditional = mlir::cast(op); - - std::vector branch_thunks; - - int branch_count = conditional.getBranches().size(); - branch_thunks.reserve(branch_count); - - for (int j = 0; j < branch_count; ++j) { - mlir::Region* branch_computation = &conditional.getBranches()[j]; - auto ir_emitter = IrEmitterUnnested::Create(ir_emitter_context_); - TF_RETURN_IF_ERROR( - ir_emitter->EmitLmhloRegion(branch_computation, hlo_for_lmhlo)); - branch_thunks.push_back(std::move(*ir_emitter->ConsumeThunkSequence())); - } - - ConditionalThunkConfig config = - GetConditionalThunkConfig(conditional, std::move(branch_thunks)); - - TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(conditional.getIndex())); - AddThunkToThunkSequence(std::unique_ptr(new ConditionalThunk( - Thunk::ThunkInfo::WithProfileAnnotation(op), std::move(config), slice))); - return absl::OkStatus(); -} - static ConditionalThunkConfig GetConditionalThunkConfig( const HloInstruction* instr, std::vector branch_thunk_sequences) { @@ -495,21 +432,20 @@ void IrEmitterUnnested::CreateStore(llvm::Value* data, llvm::Value* address, // Input = {dynamic array(with dynamic dimension meta data at the end)} // Output = {static array, dynamic_dim0, dynamic_dim1} -absl::Status IrEmitterUnnested::EmitPadToStatic(mlir::Operation* op) { - // TODO(jurahul): Create an op to represent PadToStatic. - auto pad_to_static = mlir::cast(op); +absl::Status IrEmitterUnnested::EmitPadToStatic( + const HloCustomCallInstruction* instr) { int unroll_factor = 1; - std::string ir_name = GetIrNameFromLoc(pad_to_static.getLoc()); + std::string ir_name = std::string(instr->name()); - const Shape& input_shape = GetShape(pad_to_static.getArgs().front()); + const Shape& input_shape = instr->operand(0)->shape(); LaunchDimensions launch_dimensions = CalculateLaunchDimensions( input_shape, ir_emitter_context_->gpu_device_info(), {unroll_factor}); std::vector input_arrays; std::vector output_arrays; - TF_ASSIGN_OR_RETURN( - std::tie(input_arrays, output_arrays), - BuildKernelThunkForNonFusionOp(pad_to_static, launch_dimensions)); + TF_ASSIGN_OR_RETURN(std::tie(input_arrays, output_arrays), + BuildKernelThunkForNonFusionOp(instr, instr->operands(), + launch_dimensions)); CHECK_EQ(output_arrays.size(), 0); const llvm_ir::IrArray source_array = input_arrays[0]; @@ -517,8 +453,8 @@ absl::Status IrEmitterUnnested::EmitPadToStatic(mlir::Operation* op) { auto output_dim_arrays = absl::Span(input_arrays).subspan(2); - llvm::Type* index_ty = GetIndexTypeForKernel( - pad_to_static, launch_dimensions.launch_bound(), &b_); + llvm::Type* index_ty = + GetIndexTypeForKernel(instr, launch_dimensions.launch_bound(), &b_); // pseudo code for PadToStatic on a 2d array // int* source_array = input[0]; @@ -536,10 +472,13 @@ absl::Status IrEmitterUnnested::EmitPadToStatic(mlir::Operation* op) { // int* dyn_dim1_size = source_array + meta_data_offset + sizeof(int); std::vector dynamic_dims; int alignment = raw_data_size % sizeof(int32_t); - for (int64_t i = 1; i < pad_to_static.getOutput().size(); ++i) { + std::vector output_shapes = + ShapeUtil::GetLeafShapes(instr->shape()); + + for (int64_t i = 1; i < output_shapes.size(); ++i) { // Dynamic size of each dimension is attached at the end of the source // array(operand(0)). We need to extract these value. - const Shape& dim_shape = GetShape(pad_to_static.getOutput()[i]); + const Shape& dim_shape = output_shapes[i].shape; TF_RET_CHECK(Shape::Equal()(dim_shape, ShapeUtil::MakeScalarShape(S32))); const int64_t dim_index = i - 1; @@ -559,7 +498,7 @@ absl::Status IrEmitterUnnested::EmitPadToStatic(mlir::Operation* op) { // *output[2] = *dyn_dim1_size; // } KernelSupportLibrary{&b_}.If("is_thread_0", IsBlock0Thread0(&b_), [&] { - for (int64_t i = 1; i < pad_to_static.getOutput().size(); ++i) { + for (int64_t i = 1; i < output_shapes.size(); ++i) { const int64_t dim_index = i - 1; llvm::Value* dest_dim_size_address = output_dim_arrays[dim_index].GetBasePointer(); @@ -609,7 +548,7 @@ absl::Status IrEmitterUnnested::EmitPadToStatic(mlir::Operation* op) { return absl::OkStatus(); }; - const Shape& data_shape = GetShape(pad_to_static.getOutput().front()); + const Shape& data_shape = instr->shape().tuple_shapes(0); TF_RETURN_IF_ERROR(ParallelLoopEmitter(body_generator, data_shape, launch_dimensions, &b_, {unroll_factor}) @@ -619,25 +558,25 @@ absl::Status IrEmitterUnnested::EmitPadToStatic(mlir::Operation* op) { // Input = {dynamic array(with dynamic dimension meta data at the end)} // Output = {static array, dynamic_dim0, dynamic_dim1} -absl::Status IrEmitterUnnested::EmitSliceToDynamic(mlir::Operation* op) { +absl::Status IrEmitterUnnested::EmitSliceToDynamic( + const HloCustomCallInstruction* instr) { // TODO(jurahul): Create an op to represent SliceToDynamic. - auto slice_to_dynamic = mlir::cast(op); int unroll_factor = 1; - std::string ir_name = GetIrNameFromLoc(slice_to_dynamic.getLoc()); + std::string ir_name = std::string(instr->name()); - const Shape& input_shape = GetShape(slice_to_dynamic.getArgs().front()); + const Shape& input_shape = instr->operand(0)->shape(); LaunchDimensions launch_dimensions = CalculateLaunchDimensions( input_shape, ir_emitter_context_->gpu_device_info(), {unroll_factor}); - llvm::Type* index_ty = GetIndexTypeForKernel( - slice_to_dynamic, launch_dimensions.launch_bound(), &b_); + llvm::Type* index_ty = + GetIndexTypeForKernel(instr, launch_dimensions.launch_bound(), &b_); std::vector input_arrays, output_arrays; - TF_ASSIGN_OR_RETURN( - std::tie(input_arrays, output_arrays), - BuildKernelThunkForNonFusionOp(slice_to_dynamic, launch_dimensions)); + TF_ASSIGN_OR_RETURN(std::tie(input_arrays, output_arrays), + BuildKernelThunkForNonFusionOp(instr, instr->operands(), + launch_dimensions)); - TF_RET_CHECK(slice_to_dynamic.getOutput().size() == 1); - const Shape& data_shape = GetShape(slice_to_dynamic.getOutput().front()); + const Shape& data_shape = ShapeUtil::MakeStaticShape(instr->shape()); + TF_RET_CHECK(data_shape.IsArray()); // TODO(jurahul): data_shape here is the static shape of the output (which has // a dynamic shape in XLA). Currently, we are mapping that to a static shaped @@ -659,7 +598,7 @@ absl::Status IrEmitterUnnested::EmitSliceToDynamic(mlir::Operation* op) { // Load dynamic dimensions from memory. std::vector dynamic_dims; int alignment = raw_data_size % sizeof(int32_t); - for (int64_t i = 1; i < slice_to_dynamic.getArgs().size(); ++i) { + for (int64_t i = 1; i < instr->operand_count(); ++i) { llvm::Value* source_buffer = input_arrays[i].GetBasePointer(); llvm::Type* source_buffer_pointee_type = input_arrays[i].GetBasePointeeType(); @@ -676,7 +615,7 @@ absl::Status IrEmitterUnnested::EmitSliceToDynamic(mlir::Operation* op) { // *dyn_dim1_size = *output[2]; // } KernelSupportLibrary{&b_}.If("is_thread_0", IsBlock0Thread0(&b_), [&] { - for (int64_t i = 1; i < slice_to_dynamic.getArgs().size(); ++i) { + for (int64_t i = 1; i < instr->operand_count(); ++i) { const int64_t dim_index = i - 1; llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32( b_.getInt8Ty(), dest_buffer, @@ -809,171 +748,6 @@ absl::Status IrEmitterUnnested::EmitConvolutionThunk( return OkStatus(); } -absl::Status IrEmitterUnnested::EmitConvolutionThunk(mlir::Operation* op) { - using mlir::dyn_cast; - using mlir::lmhlo_gpu::Activation; - using mlir::lmhlo_gpu::ConvBackwardFilterOp; - using mlir::lmhlo_gpu::ConvBackwardInputOp; - using mlir::lmhlo_gpu::ConvForwardFusedOp; - using mlir::lmhlo_gpu::ConvForwardFusedSideInputOp; - using mlir::lmhlo_gpu::ConvForwardGraphOp; - using mlir::lmhlo_gpu::ConvForwardOp; - - std::vector operand_slices, result_slices; - int32_t n_aux_outputs = 0; - if (auto conv = dyn_cast(op)) { - n_aux_outputs = conv.getNAuxOutputs(); - } - int64_t num_operands = op->getNumOperands(); - operand_slices.reserve(num_operands - n_aux_outputs - 2); - - // The operands describe inputs, the main result of the convolution, the - // scratch workspace and n_aux_outputs return values of ops fused into the - // convolution. - for (mlir::Value operand : op->getOperands().drop_back(2 + n_aux_outputs)) { - TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(operand)); - operand_slices.push_back(slice); - } - - result_slices.reserve(1 + n_aux_outputs); - for (mlir::Value result : op->getOperands() - .drop_front(num_operands - n_aux_outputs - 2) - .drop_back(1)) { - TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(result)); - result_slices.push_back(slice); - } - mlir::Value scratch_result = op->getOperand(num_operands - 1); - TF_ASSIGN_OR_RETURN(auto scratch_slice, GetAllocationSlice(scratch_result)); - - auto apply_layout = [](const Shape& shape, - mlir::ArrayRef minor_to_major) { - return ShapeUtil::MakeShapeWithDenseLayout( - shape.element_type(), shape.dimensions(), minor_to_major); - }; - - GpuConvDescriptor descriptor; - - auto fill_conv_descriptor = [&](auto op) { - descriptor.operand0_shape = - apply_layout(GetShape(op->getOperand(0)), - op.getBackendConfig().getOperand_0Layout()); - descriptor.operand1_shape = - apply_layout(GetShape(op->getOperand(1)), - op.getBackendConfig().getOperand_1Layout()); - descriptor.result_shape = - apply_layout(GetShape(op->getOperand(num_operands - n_aux_outputs - 2)), - op.getBackendConfig().getResultLayout()); - descriptor.dnums = ConvertConvDimensionNumbers(op.getDimensionNumbers()); - descriptor.scratch_size = scratch_slice.size(); - mlir::DenseIntElementsAttr window_strides = op.getWindowStrides().value(); - mlir::DenseIntElementsAttr padding = op.getPadding().value(); - mlir::DenseIntElementsAttr lhs_dilation = op.getLhsDilation().value(); - mlir::DenseIntElementsAttr rhs_dilation = op.getRhsDilation().value(); - mlir::DenseElementsAttr window_reversal = op.getWindowReversal().value(); - for (auto index : llvm::seq(0, window_strides.getNumElements())) { - WindowDimension* dim = descriptor.window.add_dimensions(); - // Window size for a convolution is the same as the kernel size. - // Kernel size of the convolution is operand1_shape. We need to look at - // the convolution dimension numbers kernel spatial dimensions to get - // the window size. - int kernel_dim = descriptor.dnums.kernel_spatial_dimensions(index); - dim->set_size(descriptor.operand0_shape.dimensions(kernel_dim)); - dim->set_stride(window_strides.getValues()[index]); - dim->set_padding_low(padding.getValues()[index]); - dim->set_padding_high(padding.getValues()[index]); - dim->set_base_dilation(lhs_dilation.getValues()[index]); - dim->set_window_dilation(rhs_dilation.getValues()[index]); - dim->set_window_reversal(window_reversal.getValues()[index]); - } - descriptor.feature_group_count = op.getFeatureGroupCount(); - { - auto* algorithm = descriptor.backend_config.mutable_algorithm(); - algorithm->set_algo_id(op.getBackendConfig().getAlgorithm()); - algorithm->set_math_type(op.getBackendConfig().getTensorOpsEnabled() - ? se::dnn::AlgorithmProto::TENSOR_OP_MATH - : se::dnn::AlgorithmProto::DEFAULT_MATH); - for (int i = 0; i < op.getBackendConfig().getKnobIds().size(); ++i) { - // N.B. tuning_knobs is a map rather than a repeated field, so this - // doesn't require reserving space up front. - (*algorithm - ->mutable_tuning_knobs())[op.getBackendConfig().getKnobIds()[i]] = - op.getBackendConfig().getKnobValues()[i]; - } - algorithm->set_is_cudnn_frontend( - op.getBackendConfig().getIsCudnnFrontend()); - auto workspace_size = op.getBackendConfig().getWorkspaceSize(); - if (workspace_size >= 0) { - algorithm->mutable_workspace_size()->set_value(workspace_size); - } - } - descriptor.backend_config.set_conv_result_scale( - op.getResultScale().convertToDouble()); - descriptor.backend_config.set_reordered_int8_nchw_vect( - op.getBackendConfig().getIsCudnnReorderedInt8()); - }; - - auto set_activation_mode = [&](auto op) -> absl::Status { - TF_ASSIGN_OR_RETURN(stream_executor::dnn::ActivationMode activation_mode, - ConvertConvActivationMode(op.getActivationMode())); - descriptor.backend_config.set_activation_mode(activation_mode); - return absl::OkStatus(); - }; - - if (auto conv = dyn_cast(op)) { - descriptor.kind = CudnnConvKind::kForward; - fill_conv_descriptor(conv); - } else if (auto conv = dyn_cast(op)) { - descriptor.kind = CudnnConvKind::kBackwardInput; - fill_conv_descriptor(conv); - } else if (auto conv = dyn_cast(op)) { - descriptor.kind = CudnnConvKind::kBackwardFilter; - fill_conv_descriptor(conv); - } else if (auto conv = dyn_cast(op)) { - descriptor.kind = CudnnConvKind::kForwardGraph; - fill_conv_descriptor(conv); - descriptor.backend_config.set_serialized_graph( - conv.getSerializedGraph().data()); - } else if (auto conv = dyn_cast(op)) { - descriptor.kind = CudnnConvKind::kForwardActivation; - fill_conv_descriptor(conv); - TF_RETURN_IF_ERROR(set_activation_mode(conv)); - descriptor.backend_config.set_leakyrelu_alpha( - conv.getLeakyreluAlpha().convertToDouble()); - } else if (auto conv = dyn_cast(op)) { - descriptor.kind = CudnnConvKind::kForwardActivation; - fill_conv_descriptor(conv); - TF_RETURN_IF_ERROR(set_activation_mode(conv)); - descriptor.backend_config.set_side_input_scale( - conv.getSideInputScale().convertToDouble()); - } else { - return Internal("EmitConvolutionThunk: Unexpected operation"); - } - TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(descriptor, "")); - AddThunkToThunkSequence(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), std::move(config), - std::move(operand_slices), std::move(result_slices), scratch_slice)); - return absl::OkStatus(); -} - -absl::Status IrEmitterUnnested::EmitGemmThunk(mlir::Operation* op) { - auto gemm = mlir::dyn_cast(op); - TF_RET_CHECK(gemm != nullptr); - - TF_ASSIGN_OR_RETURN(auto a, GetAllocationSlice(gemm.getA())); - TF_ASSIGN_OR_RETURN(auto b, GetAllocationSlice(gemm.getB())); - TF_ASSIGN_OR_RETURN(auto c, GetAllocationSlice(gemm.getC())); - bool deterministic_ops = - ir_emitter_context_->debug_options().xla_gpu_deterministic_ops(); - - TF_ASSIGN_OR_RETURN(GemmConfig config, GemmConfig::For(gemm)); - auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), std::move(config), a, b, c, - std::nullopt, deterministic_ops); - - AddThunkToThunkSequence(std::move(thunk)); - return absl::OkStatus(); -} - absl::Status IrEmitterUnnested::EmitGemmThunk( const HloCustomCallInstruction* instr) { TF_ASSIGN_OR_RETURN(BufferAllocation::Slice a, @@ -1074,36 +848,6 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunk( return absl::OkStatus(); } -absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunk(mlir::Operation* op) { - auto matmul = mlir::dyn_cast(op); - TF_RET_CHECK(matmul != nullptr); - - TF_ASSIGN_OR_RETURN(auto a, GetAllocationSlice(matmul.getA())); - TF_ASSIGN_OR_RETURN(auto b, GetAllocationSlice(matmul.getB())); - TF_ASSIGN_OR_RETURN(auto c, GetAllocationSlice(matmul.getC())); - TF_ASSIGN_OR_RETURN(auto d, GetAllocationSlice(matmul.getD())); - - BufferAllocation::Slice bias, a_scale, b_scale, c_scale, d_scale, d_amax; - if (matmul.getBias() != nullptr) { - TF_ASSIGN_OR_RETURN(bias, GetAllocationSlice(matmul.getBias())); - } - - BufferAllocation::Slice aux; - if (matmul.getAux() != nullptr) { - TF_ASSIGN_OR_RETURN(aux, GetAllocationSlice(matmul.getAux())); - } - - TF_ASSIGN_OR_RETURN(GemmConfig gemm_config, GemmConfig::For(matmul)); - TF_ASSIGN_OR_RETURN(auto epilogue, - gpublas_lt::AsBlasLtEpilogue(matmul.getEpilogue())); - auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), std::move(gemm_config), - epilogue, matmul.getAlgorithm(), a, b, c, d, bias, aux, a_scale, b_scale, - c_scale, d_scale, d_amax); - - AddThunkToThunkSequence(std::move(thunk)); - return absl::OkStatus(); -} #endif // GOOGLE_CUDA || TF_HIPBLASLT #if GOOGLE_CUDA @@ -1183,490 +927,374 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunkF8( return absl::OkStatus(); } -absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunkF8(mlir::Operation* op) { - auto matmul = mlir::dyn_cast(op); - TF_RET_CHECK(matmul != nullptr); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice a, - GetAllocationSlice(matmul.getA())); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice b, - GetAllocationSlice(matmul.getB())); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice c, - GetAllocationSlice(matmul.getC())); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d, - GetAllocationSlice(matmul.getD())); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice a_scale, - GetAllocationSlice(matmul.getAScale())); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice b_scale, - GetAllocationSlice(matmul.getBScale())); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice c_scale, - GetAllocationSlice(matmul.getCScale())); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_scale, - GetAllocationSlice(matmul.getDScale())); - BufferAllocation::Slice d_amax, bias; - if (matmul.getDAmax() != nullptr) { - TF_ASSIGN_OR_RETURN(d_amax, GetAllocationSlice(matmul.getDAmax())); - } - if (matmul.getBias() != nullptr) { - TF_ASSIGN_OR_RETURN(bias, GetAllocationSlice(matmul.getBias())); - } - - BufferAllocation::Slice aux; // Not used. - - TF_ASSIGN_OR_RETURN(GemmConfig gemm_config, GemmConfig::For(matmul)); - TF_ASSIGN_OR_RETURN(auto epilogue, - gpublas_lt::AsBlasLtEpilogue(matmul.getEpilogue())); - auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), std::move(gemm_config), - epilogue, matmul.getAlgorithm(), a, b, c, d, bias, aux, a_scale, b_scale, - c_scale, d_scale, d_amax); - - AddThunkToThunkSequence(std::move(thunk)); - return absl::OkStatus(); -} - absl::Status IrEmitterUnnested::EmitConvolutionReorderThunk( - mlir::Operation* op) { - using mlir::dyn_cast; - using mlir::lmhlo_gpu::CudnnConvReorderFilterAndBiasOp; - using mlir::lmhlo_gpu::CudnnConvReorderFilterOp; - - std::vector operand_slices; - std::vector result_slices; - std::vector filter_dims; - - auto set_filter_data = [&](auto op) -> absl::Status { - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice filter_input, - GetAllocationSlice(op.getFilterInput())); - operand_slices.push_back(filter_input); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice filter_output, - GetAllocationSlice(op.getFilterOutput())); - result_slices.push_back(filter_output); - - auto filter_dims_values = op.getFilterDims().template getValues(); - filter_dims.assign(filter_dims_values.begin(), filter_dims_values.end()); - return absl::OkStatus(); - }; - - if (auto reorder = dyn_cast(op)) { - TF_RETURN_IF_ERROR(set_filter_data(reorder)); - + const HloCustomCallInstruction* instr) { + bool has_bias = instr->operand_count() > 1; + Shape shape = has_bias ? instr->shape().tuple_shapes(0) : instr->shape(); + if (shape.rank() != 5 || shape.dimensions(4) != 32) { + return Internal("Unexpected shape for convolution reorder: %s", + instr->ToString()); + } + absl::InlinedVector filter_dims = { + shape.dimensions(0), shape.dimensions(1) * 32, shape.dimensions(2), + shape.dimensions(3)}; + + absl::InlinedVector operand_slices; + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice filter_input, + GetAllocationSliceForHlo(instr->operand(0))); + operand_slices.push_back(filter_input); + if (has_bias) { TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bias_input, - GetAllocationSlice(reorder.getBiasInput())); + GetAllocationSliceForHlo(instr->operand(1))); operand_slices.push_back(bias_input); + } + absl::InlinedVector result_slices; + if (has_bias) { + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice filter_output, + GetAllocationSliceForHlo(instr, {0})); + result_slices.push_back(filter_output); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bias_output, - GetAllocationSlice(reorder.getBiasOutput())); + GetAllocationSliceForHlo(instr, {1})); result_slices.push_back(bias_output); - } else if (auto reorder = dyn_cast(op)) { - TF_RETURN_IF_ERROR(set_filter_data(reorder)); } else { - return Internal("Unexpected operation"); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice filter_output, + GetAllocationSliceForHlo(instr)); + result_slices.push_back(filter_output); } auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), absl::MakeSpan(filter_dims), - std::move(operand_slices), std::move(result_slices)); - + Thunk::ThunkInfo::WithProfileAnnotation(instr), + absl::MakeSpan(filter_dims), operand_slices, result_slices); AddThunkToThunkSequence(std::move(thunk)); return absl::OkStatus(); } -absl::Status IrEmitterUnnested::EmitNormThunk(mlir::Operation* op) { - auto norm = mlir::dyn_cast(op); - TF_RET_CHECK(norm != nullptr); +absl::Status IrEmitterUnnested::EmitNormThunk( + const HloCustomCallInstruction* instr) { + TF_ASSIGN_OR_RETURN(auto const gpu_backend_config, + instr->backend_config()); + const xla::gpu::CudnnNormBackendConfig& backend_config = + gpu_backend_config.cudnn_norm_backend_config(); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice, - GetAllocationSlice(norm.getInput())); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice x_slice, + GetAllocationSliceForHlo(instr->operand(0))); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice scale_slice, - GetAllocationSlice(norm.getScale())); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bias_slice, - GetAllocationSlice(norm.getBias())); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice, - GetAllocationSlice(norm.getOutput())); + GetAllocationSliceForHlo(instr->operand(1))); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice y_or_dx_slice, + GetAllocationSliceForHlo(instr, {0})); - int64_t num_operands = op->getNumOperands(); - std::optional expectation_slice, norm_factor_slice; - if (num_operands == 7) { + std::optional bias_slice, expectation_slice, + norm_factor_slice, dy_slice, dscale_slice, dbias_slice; + + if (backend_config.kind() == + xla::gpu::CudnnNormBackendConfig::LAYER_FWD_INFER || + backend_config.kind() == + xla::gpu::CudnnNormBackendConfig::LAYER_FWD_TRAIN) { + TF_ASSIGN_OR_RETURN(bias_slice, + GetAllocationSliceForHlo(instr->operand(2))); + } + if (backend_config.kind() == + xla::gpu::CudnnNormBackendConfig::LAYER_FWD_TRAIN) { TF_ASSIGN_OR_RETURN(expectation_slice, - GetAllocationSlice(norm.getExpectation())); + GetAllocationSliceForHlo(instr, {1})); TF_ASSIGN_OR_RETURN(norm_factor_slice, - GetAllocationSlice(norm.getNormFactor())); + GetAllocationSliceForHlo(instr, {2})); + } + if (backend_config.kind() == xla::gpu::CudnnNormBackendConfig::LAYER_BWD) { + TF_ASSIGN_OR_RETURN(dy_slice, GetAllocationSliceForHlo(instr->operand(2))); + TF_ASSIGN_OR_RETURN(expectation_slice, + GetAllocationSliceForHlo(instr->operand(3))); + TF_ASSIGN_OR_RETURN(norm_factor_slice, + GetAllocationSliceForHlo(instr->operand(4))); + TF_ASSIGN_OR_RETURN(dscale_slice, GetAllocationSliceForHlo(instr, {1})); + TF_ASSIGN_OR_RETURN(dbias_slice, GetAllocationSliceForHlo(instr, {2})); } - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice scratch_slice, - GetAllocationSlice(norm.getScratch())); + GetAllocationSliceForHlo( + instr, {instr->shape().tuple_shapes_size() - 1})); GpuNormDescriptor descriptor; - auto* algorithm = descriptor.backend_config.mutable_algorithm(); - algorithm->set_algo_id(norm.getAlgorithmConfig().getAlgorithm()); - algorithm->set_is_cudnn_frontend(true); - auto workspace_size = norm.getAlgorithmConfig().getWorkspaceSize(); - algorithm->mutable_workspace_size()->set_value(workspace_size); - - descriptor.input_shape = GetShape(norm->getOperand(0)); - descriptor.scale_shape = GetShape(norm->getOperand(1)); - descriptor.bias_shape = GetShape(norm->getOperand(2)); - descriptor.output_shape = GetShape(norm->getOperand(3)); - if (num_operands == 7) { - descriptor.expectation_shape = GetShape(norm->getOperand(4)); - descriptor.norm_factor_shape = GetShape(norm->getOperand(5)); + descriptor.backend_config = backend_config; + + descriptor.x_shape = instr->operand(0)->shape(); + descriptor.scale_shape = instr->operand(1)->shape(); + descriptor.y_or_dx_shape = ShapeUtil::GetSubshape(instr->shape(), {0}); + if (backend_config.kind() == + xla::gpu::CudnnNormBackendConfig::LAYER_FWD_INFER || + backend_config.kind() == + xla::gpu::CudnnNormBackendConfig::LAYER_FWD_TRAIN) { + descriptor.bias_shape = instr->operand(2)->shape(); + } + if (backend_config.kind() == + xla::gpu::CudnnNormBackendConfig::LAYER_FWD_TRAIN) { + descriptor.expectation_shape = ShapeUtil::GetSubshape(instr->shape(), {1}); + descriptor.norm_factor_shape = ShapeUtil::GetSubshape(instr->shape(), {2}); + } + if (backend_config.kind() == xla::gpu::CudnnNormBackendConfig::LAYER_BWD) { + descriptor.dy_shape = instr->operand(2)->shape(); + descriptor.expectation_shape = instr->operand(3)->shape(); + descriptor.norm_factor_shape = instr->operand(4)->shape(); + descriptor.dscale_shape = ShapeUtil::GetSubshape(instr->shape(), {1}); + descriptor.dbias_shape = ShapeUtil::GetSubshape(instr->shape(), {2}); } - descriptor.backend_config.set_epsilon(norm.getEpsilon().convertToDouble()); TF_ASSIGN_OR_RETURN(GpuNormConfig config, GpuNormConfig::For(descriptor)); auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), std::move(config), - input_slice, scale_slice, bias_slice, output_slice, expectation_slice, - norm_factor_slice, scratch_slice); - + Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(config), + x_slice, scale_slice, y_or_dx_slice, bias_slice, expectation_slice, + norm_factor_slice, dy_slice, dscale_slice, dbias_slice, scratch_slice); AddThunkToThunkSequence(std::move(thunk)); - return absl::OkStatus(); } -absl::Status IrEmitterUnnested::EmitFusedMHAThunk(mlir::Operation* op) { - using mlir::dyn_cast; - using mlir::lmhlo_gpu::fusedMHAOp; - GpufMHADescriptor descriptor; - BufferAllocation::Slice lhs_bmm1_slice, rhs_bmm1_slice, rhs_bmm2_slice, - output_slice, scratch_slice, activation_slice, mask_slice, bias_slice; - - auto populate_common = [&](auto fmha) -> absl::Status { - descriptor.backend_config.set_fmha_scale( - fmha.getFmhaScale().convertToDouble()); - - if (fmha.getDropoutRate()) { - descriptor.backend_config.set_dropout_rate( - (*fmha.getDropoutRate()).convertToDouble()); - } - - if (fmha.getSeed()) { - descriptor.backend_config.set_seed((*fmha.getSeed())); - } - - auto* algorithm = descriptor.backend_config.mutable_algorithm(); - algorithm->set_algo_id(fmha.getAlgorithmConfig().getAlgorithm()); - for (int i = 0; i < fmha.getAlgorithmConfig().getKnobIds().size(); ++i) { - // N.B. tuning_knobs is a map rather than a repeated field, so this - // doesn't require reserving space up front. - (*algorithm->mutable_tuning_knobs())[fmha.getAlgorithmConfig() - .getKnobIds()[i]] = - fmha.getAlgorithmConfig().getKnobValues()[i]; - } - algorithm->set_is_cudnn_frontend(true); - auto workspace_size = fmha.getAlgorithmConfig().getWorkspaceSize(); - if (workspace_size >= 0) { - algorithm->mutable_workspace_size()->set_value(workspace_size); - } - - descriptor.bmm1_dnums = - ConvertDotDimensionNumbers(fmha.getBmm1DotDimensionNumbers()); - descriptor.bmm2_dnums = - ConvertDotDimensionNumbers(fmha.getBmm2DotDimensionNumbers()); - - descriptor.lhs_bmm1_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getLhsBmm1()).element_type(), - GetShape(fmha.getLhsBmm1()).dimensions(), - GetShape(fmha.getLhsBmm1()).layout().minor_to_major()); - TF_ASSIGN_OR_RETURN(lhs_bmm1_slice, GetAllocationSlice(fmha.getLhsBmm1())); - - descriptor.rhs_bmm1_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getRhsBmm1()).element_type(), - GetShape(fmha.getRhsBmm1()).dimensions(), - GetShape(fmha.getRhsBmm1()).layout().minor_to_major()); - TF_ASSIGN_OR_RETURN(rhs_bmm1_slice, GetAllocationSlice(fmha.getRhsBmm1())); - - descriptor.rhs_bmm2_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getRhsBmm2()).element_type(), - GetShape(fmha.getRhsBmm2()).dimensions(), - GetShape(fmha.getRhsBmm2()).layout().minor_to_major()); - TF_ASSIGN_OR_RETURN(rhs_bmm2_slice, GetAllocationSlice(fmha.getRhsBmm2())); - - descriptor.output_shapes.push_back(ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getOutput()).element_type(), - GetShape(fmha.getOutput()).dimensions(), - GetShape(fmha.getOutput()).layout().minor_to_major())); - TF_ASSIGN_OR_RETURN(output_slice, GetAllocationSlice(fmha.getOutput())); - - TF_ASSIGN_OR_RETURN(scratch_slice, GetAllocationSlice(fmha.getScratch())); - - TF_ASSIGN_OR_RETURN(auto intermediate_tensor_dims_array, - ConvertMlirArrayAttrToInt64Array( - fmha.getIntermediateTensorDimensions())); - if (fmha.getActivation() != nullptr) { - descriptor.output_shapes.push_back(ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getActivation()).element_type(), - GetShape(fmha.getActivation()).dimensions(), - GetShape(fmha.getActivation()).layout().minor_to_major())); - TF_ASSIGN_OR_RETURN(activation_slice, - GetAllocationSlice(fmha.getActivation())); - } - - if (fmha.getBias() != nullptr) { - descriptor.bias_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getBias()).element_type(), - GetShape(fmha.getBias()).dimensions(), - GetShape(fmha.getBias()).layout().minor_to_major()); - - TF_ASSIGN_OR_RETURN(bias_slice, GetAllocationSlice(fmha.getBias())); - } - - if (fmha.getMask() != nullptr) { - descriptor.mask_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getMask()).element_type(), - GetShape(fmha.getMask()).dimensions(), - GetShape(fmha.getMask()).layout().minor_to_major()); - - TF_ASSIGN_OR_RETURN(mask_slice, GetAllocationSlice(fmha.getMask())); +absl::Status IrEmitterUnnested::EmitFusedMHAThunk( + const HloCustomCallInstruction* instr) { + const HloInstruction* lhs_bmm1 = instr->operand(0); + const HloInstruction* rhs_bmm1 = instr->operand(1); + const HloInstruction* rhs_bmm2 = instr->operand(2); + + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice lhs_bmm1_slice, + GetAllocationSliceForHlo(lhs_bmm1)); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice rhs_bmm1_slice, + GetAllocationSliceForHlo(rhs_bmm1)); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice rhs_bmm2_slice, + GetAllocationSliceForHlo(rhs_bmm2)); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice, + GetAllocationSliceForHlo(instr, {0})); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice scratch_slice, + GetAllocationSliceForHlo(instr, {1})); + BufferAllocation::Slice activation_slice; + bool has_activation = xla::ShapeUtil::TupleElementCount(instr->shape()) == 3; + if (has_activation) { + TF_ASSIGN_OR_RETURN(activation_slice, GetAllocationSliceForHlo(instr, {2})); + } + + TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnfMHAKind kind, + xla::gpu::GetCudnnfMHAKind(instr)); + BufferAllocation::Slice mask_slice, bias_slice; + std::optional mask_shape, bias_shape; + { + bool has_mask = kind == CudnnfMHAKind::kScaleMaskSoftmax || + kind == CudnnfMHAKind::kScaleMaskSoftmaxDropout || + kind == CudnnfMHAKind::kScaleBiasMaskSoftmax || + kind == CudnnfMHAKind::kScaleBiasMaskSoftmaxDropout; + bool has_bias = kind == CudnnfMHAKind::kScaleBiasMaskSoftmax || + kind == CudnnfMHAKind::kScaleBiasSoftmaxDropout || + kind == CudnnfMHAKind::kScaleBiasSoftmax || + kind == CudnnfMHAKind::kScaleBiasSoftmaxDropout; + + if (has_mask) { + const HloInstruction* mask = instr->operand(3); + TF_ASSIGN_OR_RETURN(mask_slice, GetAllocationSliceForHlo(mask)); + mask_shape = mask->shape(); + if (has_bias) { + const HloInstruction* bias = instr->operand(4); + TF_ASSIGN_OR_RETURN(bias_slice, GetAllocationSliceForHlo(bias)); + bias_shape = bias->shape(); + } + } else if (has_bias) { + const HloInstruction* bias = instr->operand(3); + TF_ASSIGN_OR_RETURN(bias_slice, GetAllocationSliceForHlo(bias)); + bias_shape = bias->shape(); } - TF_ASSIGN_OR_RETURN( - auto intermediate_tensor_layout_array, - ConvertMlirArrayAttrToInt64Array(fmha.getIntermediateTensorLayout())); - - descriptor.intermediate_lhs_bmm2_shape = - ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getOutput()).element_type(), - intermediate_tensor_dims_array, intermediate_tensor_layout_array); - - // set if flash attention here - descriptor.is_flash_attention = fmha.getIsFlashAttention(); - // set if causal mask here - descriptor.is_causal_mask = fmha.getIsCausalMask(); - return absl::OkStatus(); - }; - - if (auto fmha_op = dyn_cast(op)) { - TF_RET_CHECK(fmha_op != nullptr); - TF_ASSIGN_OR_RETURN(CudnnfMHAKind kind, - AsCudnnfMHAKind(fmha_op.getFusedMhaDag())); - descriptor.kind = kind; - TF_RETURN_IF_ERROR(populate_common(fmha_op)); - } else { - return Internal("Unexpected operation"); } - TF_ASSIGN_OR_RETURN(GpufMHAConfig config, GpufMHAConfig::For(descriptor)); + + TF_ASSIGN_OR_RETURN(const auto gpu_config, + instr->backend_config()); + const xla::gpu::CudnnfMHABackendConfig& config = + gpu_config.cudnn_fmha_backend_config(); + Shape intermediate_tensor_shape(config.intermediate_tensor_shape()); + absl::InlinedVector output_shapes = { + ShapeUtil::GetSubshape(instr->shape(), {0})}; + if (has_activation) { + output_shapes.push_back(ShapeUtil::GetSubshape(instr->shape(), {2})); + } + + GpufMHADescriptor descriptor = {kind, + config, + config.is_flash_attention(), + config.is_causal_mask(), + lhs_bmm1->shape(), + rhs_bmm1->shape(), + rhs_bmm2->shape(), + intermediate_tensor_shape, + output_shapes, + config.bmm1_dot_dimension_numbers(), + config.bmm2_dot_dimension_numbers(), + mask_shape, + bias_shape}; + + TF_ASSIGN_OR_RETURN(GpufMHAConfig fmha_config, + GpufMHAConfig::For(descriptor)); AddThunkToThunkSequence(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), std::move(config), + Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(fmha_config), lhs_bmm1_slice, rhs_bmm1_slice, rhs_bmm2_slice, output_slice, scratch_slice, mask_slice, bias_slice, activation_slice)); return absl::OkStatus(); } -absl::Status IrEmitterUnnested::EmitFusedMHABackwardThunk(mlir::Operation* op) { - using mlir::dyn_cast; - using mlir::lmhlo_gpu::fusedMHABackwardOp; - - GpufMHABackwardDescriptor descriptor; - BufferAllocation::Slice bmm1_grad_gemm1_rhs_slice, bmm1_grad_gemm2_rhs_slice, - bmm2_grad_gemm1_lhs_slice, bmm2_grad_gemm2_rhs_slice, d_output_slice, - scratch_slice, mask_slice, fwd_output_slice, bias_slice; - BufferAllocation::Slice d_bmm1_lhs_slice, d_bmm1_rhs_slice, d_bmm2_rhs_slice, - d_s_slice, softmax_sum_slice, d_Q_accum_slice, d_bias_slice; - - auto populate_common = [&](auto fmha) -> absl::Status { - descriptor.backend_config.set_fmha_scale( - fmha.getFmhaScale().convertToDouble()); - - if (fmha.getDropoutRate()) { - descriptor.backend_config.set_dropout_rate( - (*fmha.getDropoutRate()).convertToDouble()); - } - - if (fmha.getSeed()) { - descriptor.backend_config.set_seed((*fmha.getSeed())); - } - - auto* algorithm = descriptor.backend_config.mutable_algorithm(); - algorithm->set_algo_id(fmha.getAlgorithmConfig().getAlgorithm()); - for (int i = 0; i < fmha.getAlgorithmConfig().getKnobIds().size(); ++i) { - // N.B. tuning_knobs is a map rather than a repeated field, so this - // doesn't require reserving space up front. - (*algorithm->mutable_tuning_knobs())[fmha.getAlgorithmConfig() - .getKnobIds()[i]] = - fmha.getAlgorithmConfig().getKnobValues()[i]; - } - algorithm->set_is_cudnn_frontend(true); - auto workspace_size = fmha.getAlgorithmConfig().getWorkspaceSize(); - if (workspace_size >= 0) { - algorithm->mutable_workspace_size()->set_value(workspace_size); - } - - // set if flash attention here - descriptor.is_flash_attention = fmha.getIsFlashAttention(); - // set if causal mask here - descriptor.is_causal_mask = fmha.getIsCausalMask(); - descriptor.bmm1_grad_gemm1_dnums = - ConvertDotDimensionNumbers(fmha.getBmm1GradGemm1DotDimensionNumbers()); - descriptor.bmm1_grad_gemm2_dnums = - ConvertDotDimensionNumbers(fmha.getBmm1GradGemm2DotDimensionNumbers()); - descriptor.bmm2_grad_gemm1_dnums = - ConvertDotDimensionNumbers(fmha.getBmm2GradGemm1DotDimensionNumbers()); - descriptor.bmm2_grad_gemm2_dnums = - ConvertDotDimensionNumbers(fmha.getBmm2GradGemm2DotDimensionNumbers()); - - descriptor.bmm1_grad_gemm1_rhs_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getBmm1GradGemm1Rhs()).element_type(), - GetShape(fmha.getBmm1GradGemm1Rhs()).dimensions(), - GetShape(fmha.getBmm1GradGemm1Rhs()).layout().minor_to_major()); - TF_ASSIGN_OR_RETURN(bmm1_grad_gemm1_rhs_slice, - GetAllocationSlice(fmha.getBmm1GradGemm1Rhs())); - - descriptor.bmm1_grad_gemm2_rhs_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getBmm1GradGemm2Rhs()).element_type(), - GetShape(fmha.getBmm1GradGemm2Rhs()).dimensions(), - GetShape(fmha.getBmm1GradGemm2Rhs()).layout().minor_to_major()); - TF_ASSIGN_OR_RETURN(bmm1_grad_gemm2_rhs_slice, - GetAllocationSlice(fmha.getBmm1GradGemm2Rhs())); - - // fwd activation - // fmha.getBmm2GradGemm1Lhs() could be bmm2_grad_gemm1_lhs for regular - // attention or softmax stats for flash attention here we set the shape to - // be bmm2_grad_gemm1_lhs even it is flash attention - if (descriptor.is_flash_attention) { - // flash attention TODO: make sure the layout is correct for - // bmm2_grad_gemm1_lhs - TF_ASSIGN_OR_RETURN(auto intermediate_tensor_dims_array, - ConvertMlirArrayAttrToInt64Array( - fmha.getIntermediateTensorDimensions())); - TF_ASSIGN_OR_RETURN( - auto intermediate_tensor_layout_array, - ConvertMlirArrayAttrToInt64Array(fmha.getIntermediateTensorLayout())); - - descriptor.bmm2_grad_gemm1_lhs_shape = - ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getDOutput()).element_type(), - intermediate_tensor_dims_array, intermediate_tensor_layout_array); - } else { - descriptor.bmm2_grad_gemm1_lhs_shape = - ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getBmm2GradGemm1Lhs()).element_type(), - GetShape(fmha.getBmm2GradGemm1Lhs()).dimensions(), - GetShape(fmha.getBmm2GradGemm1Lhs()).layout().minor_to_major()); - } - TF_ASSIGN_OR_RETURN(bmm2_grad_gemm1_lhs_slice, - GetAllocationSlice(fmha.getBmm2GradGemm1Lhs())); - - descriptor.bmm2_grad_gemm2_rhs_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getBmm2GradGemm2Rhs()).element_type(), - GetShape(fmha.getBmm2GradGemm2Rhs()).dimensions(), - GetShape(fmha.getBmm2GradGemm2Rhs()).layout().minor_to_major()); - TF_ASSIGN_OR_RETURN(bmm2_grad_gemm2_rhs_slice, - GetAllocationSlice(fmha.getBmm2GradGemm2Rhs())); - - descriptor.d_output_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getDOutput()).element_type(), - GetShape(fmha.getDOutput()).dimensions(), - GetShape(fmha.getDOutput()).layout().minor_to_major()); - TF_ASSIGN_OR_RETURN(d_output_slice, GetAllocationSlice(fmha.getDOutput())); - descriptor.d_bmm1_lhs_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getDBmm1Lhs()).element_type(), - GetShape(fmha.getDBmm1Lhs()).dimensions(), - GetShape(fmha.getDBmm1Lhs()).layout().minor_to_major()); - TF_ASSIGN_OR_RETURN(d_bmm1_lhs_slice, - GetAllocationSlice(fmha.getDBmm1Lhs())); - - descriptor.d_bmm1_rhs_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getDBmm1Rhs()).element_type(), - GetShape(fmha.getDBmm1Rhs()).dimensions(), - GetShape(fmha.getDBmm1Rhs()).layout().minor_to_major()); - TF_ASSIGN_OR_RETURN(d_bmm1_rhs_slice, - GetAllocationSlice(fmha.getDBmm1Rhs())); - - descriptor.d_bmm2_rhs_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getDBmm2Rhs()).element_type(), - GetShape(fmha.getDBmm2Rhs()).dimensions(), - GetShape(fmha.getDBmm2Rhs()).layout().minor_to_major()); - TF_ASSIGN_OR_RETURN(d_bmm2_rhs_slice, - GetAllocationSlice(fmha.getDBmm2Rhs())); - - TF_ASSIGN_OR_RETURN(scratch_slice, GetAllocationSlice(fmha.getScratch())); - - if (fmha.getD_S() != nullptr) { - descriptor.d_s_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getD_S()).element_type(), - GetShape(fmha.getD_S()).dimensions(), - GetShape(fmha.getD_S()).layout().minor_to_major()); - TF_ASSIGN_OR_RETURN(d_s_slice, GetAllocationSlice(fmha.getD_S())); - } - - if (fmha.getDBias() != nullptr) { - descriptor.d_bias_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getDBias()).element_type(), - GetShape(fmha.getDBias()).dimensions(), - GetShape(fmha.getDBias()).layout().minor_to_major()); - TF_ASSIGN_OR_RETURN(d_bias_slice, GetAllocationSlice(fmha.getDBias())); - } - - if (fmha.getMask() != nullptr) { - // has mask input - TF_RET_CHECK( - descriptor.kind != xla::gpu::CudnnfMHAKind::kBackwardBmmBmm && - descriptor.kind != xla::gpu::CudnnfMHAKind::kBackwardSoftmaxDropout && - descriptor.kind != xla::gpu::CudnnfMHAKind::kBackwardSoftmax); - - descriptor.mask_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getMask()).element_type(), - GetShape(fmha.getMask()).dimensions(), - GetShape(fmha.getMask()).layout().minor_to_major()); - - TF_ASSIGN_OR_RETURN(mask_slice, GetAllocationSlice(fmha.getMask())); - } - // add flash attention backward related slice here - if (fmha.getBias() != nullptr) { - descriptor.bias_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getBias()).element_type(), - GetShape(fmha.getBias()).dimensions(), - GetShape(fmha.getBias()).layout().minor_to_major()); - TF_ASSIGN_OR_RETURN(bias_slice, GetAllocationSlice(fmha.getBias())); - } - - if (fmha.getSoftmaxSum() != nullptr) { - TF_ASSIGN_OR_RETURN(softmax_sum_slice, - GetAllocationSlice(fmha.getSoftmaxSum())); - } - - if (fmha.getD_QAccum() != nullptr) { - TF_ASSIGN_OR_RETURN(d_Q_accum_slice, - GetAllocationSlice(fmha.getD_QAccum())); - } - - if (fmha.getFwdOutput() != nullptr) { - descriptor.fwd_output_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getFwdOutput()).element_type(), - GetShape(fmha.getFwdOutput()).dimensions(), - GetShape(fmha.getFwdOutput()).layout().minor_to_major()); - TF_ASSIGN_OR_RETURN(fwd_output_slice, - GetAllocationSlice(fmha.getFwdOutput())); - } - return absl::OkStatus(); - }; - - if (auto fmha_backward_op = dyn_cast(op)) { - TF_RET_CHECK(fmha_backward_op != nullptr); - TF_ASSIGN_OR_RETURN( - CudnnfMHAKind kind, - AsCudnnBackwardfMHAKind(fmha_backward_op.getFusedMhaDag())); - descriptor.kind = kind; - TF_RETURN_IF_ERROR(populate_common(fmha_backward_op)); +absl::Status IrEmitterUnnested::EmitFusedMHABackwardThunk( + const HloCustomCallInstruction* instr) { + TF_ASSIGN_OR_RETURN(const auto gpu_config, + instr->backend_config()); + const xla::gpu::CudnnfMHABackendConfig& config = + gpu_config.cudnn_fmha_backend_config(); + bool is_flash_attention = config.is_flash_attention(); + + int input_index = 0; + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm1_grad_gemm1_rhs_slice, + GetAllocationSliceForHlo(instr->operand(input_index))); + Shape bmm1_grad_gemm1_rhs_shape = instr->operand(input_index++)->shape(); + + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm1_grad_gemm2_rhs_slice, + GetAllocationSliceForHlo(instr->operand(input_index))); + Shape bmm1_grad_gemm2_rhs_shape = instr->operand(input_index++)->shape(); + + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm2_grad_gemm2_rhs_slice, + GetAllocationSliceForHlo(instr->operand(input_index))); + Shape bmm2_grad_gemm2_rhs_shape = instr->operand(input_index++)->shape(); + + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm2_grad_gemm1_lhs_slice, + GetAllocationSliceForHlo(instr->operand(input_index))); + Shape bmm2_grad_gemm1_lhs_shape; + + // fmha.getBmm2GradGemm1Lhs() could be bmm2_grad_gemm1_lhs for regular + // attention or softmax stats for flash attention here we set the shape to + // be bmm2_grad_gemm1_lhs even it is flash attention + if (is_flash_attention) { + // flash attention TODO: make sure the layout is correct for + // bmm2_grad_gemm1_lhs + Shape intermediate_tensor_shape(config.intermediate_tensor_shape()); + bmm2_grad_gemm1_lhs_shape = intermediate_tensor_shape; + input_index++; } else { - return Internal("Unexpected operation"); + bmm2_grad_gemm1_lhs_shape = instr->operand(input_index++)->shape(); + } + + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_output_slice, + GetAllocationSliceForHlo(instr->operand(input_index))); + Shape d_output_shape = instr->operand(input_index++)->shape(); + + TF_ASSIGN_OR_RETURN(const CudnnfMHAKind kind, GetCudnnfMHAKind(instr)); + bool has_mask = kind == CudnnfMHAKind::kBackwardScaleMaskSoftmax || + kind == CudnnfMHAKind::kBackwardScaleBiasMaskSoftmax || + kind == CudnnfMHAKind::kBackwardScaleMaskSoftmaxDropout || + kind == CudnnfMHAKind::kBackwardScaleBiasMaskSoftmaxDropout; + BufferAllocation::Slice mask_slice; + std::optional mask_shape; + if (has_mask) { + TF_ASSIGN_OR_RETURN(mask_slice, + GetAllocationSliceForHlo(instr->operand(input_index))); + mask_shape = instr->operand(input_index++)->shape(); + } + + bool has_bias = is_flash_attention && + (kind == CudnnfMHAKind::kBackwardScaleBiasSoftmax || + kind == CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout || + kind == CudnnfMHAKind::kBackwardScaleBiasMaskSoftmax || + kind == CudnnfMHAKind::kBackwardScaleBiasMaskSoftmaxDropout); + BufferAllocation::Slice bias_slice; + std::optional bias_shape; + if (has_bias) { + TF_ASSIGN_OR_RETURN(bias_slice, + GetAllocationSliceForHlo(instr->operand(input_index))); + bias_shape = instr->operand(input_index++)->shape(); + } + + BufferAllocation::Slice fwd_output_slice; + std::optional fwd_output_shape; + if (is_flash_attention) { + TF_ASSIGN_OR_RETURN(fwd_output_slice, + GetAllocationSliceForHlo(instr->operand(input_index))); + fwd_output_shape = instr->operand(input_index++)->shape(); + } + + TF_RET_CHECK(input_index == instr->operand_count()); + + int output_index = 0; + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_bmm1_lhs_slice, + GetAllocationSliceForHlo(instr, {output_index})); + Shape d_bmm1_lhs_shape = + ShapeUtil::GetSubshape(instr->shape(), {output_index++}); + + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_bmm1_rhs_slice, + GetAllocationSliceForHlo(instr, {output_index})); + Shape d_bmm1_rhs_shape = + ShapeUtil::GetSubshape(instr->shape(), {output_index++}); + + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_bmm2_rhs_slice, + GetAllocationSliceForHlo(instr, {output_index})); + Shape d_bmm2_rhs_shape = + ShapeUtil::GetSubshape(instr->shape(), {output_index++}); + + BufferAllocation::Slice d_s_slice, softmax_sum_slice, d_Q_accum_slice; + std::optional d_s_shape; + if (!is_flash_attention) { + TF_ASSIGN_OR_RETURN(d_s_slice, + GetAllocationSliceForHlo(instr, {output_index})); + d_s_shape = ShapeUtil::GetSubshape(instr->shape(), {output_index++}); + } else { + TF_ASSIGN_OR_RETURN(softmax_sum_slice, + GetAllocationSliceForHlo(instr, {output_index++})); + TF_ASSIGN_OR_RETURN(d_Q_accum_slice, + GetAllocationSliceForHlo(instr, {output_index++})); } - TF_ASSIGN_OR_RETURN(GpufMHABackwardConfig config, + + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice scratch_slice, + GetAllocationSliceForHlo(instr, {output_index++})); + + bool has_dbias = + instr->shape().tuple_shapes().size() == 6 && !is_flash_attention; + BufferAllocation::Slice d_bias_slice; + std::optional d_bias_shape; + if (has_dbias) { + TF_ASSIGN_OR_RETURN(d_bias_slice, + GetAllocationSliceForHlo(instr, {output_index})); + d_bias_shape = ShapeUtil::GetSubshape(instr->shape(), {output_index++}); + } + + TF_RET_CHECK(output_index == instr->shape().tuple_shapes().size()); + + GpufMHABackwardDescriptor descriptor = { + kind, + config, + is_flash_attention, + config.is_causal_mask(), + bmm1_grad_gemm1_rhs_shape, + bmm1_grad_gemm2_rhs_shape, + bmm2_grad_gemm1_lhs_shape, + bmm2_grad_gemm2_rhs_shape, + d_output_shape, + d_bmm1_lhs_shape, + d_bmm1_rhs_shape, + d_bmm2_rhs_shape, + config.bmm1_grad_gemm1_dot_dimension_numbers(), + config.bmm1_grad_gemm2_dot_dimension_numbers(), + config.bmm2_grad_gemm1_dot_dimension_numbers(), + config.bmm2_grad_gemm2_dot_dimension_numbers(), + d_s_shape, + fwd_output_shape, + mask_shape, + d_bias_shape, + bias_shape}; + + TF_ASSIGN_OR_RETURN(GpufMHABackwardConfig fmha_backward_config, GpufMHABackwardConfig::For(descriptor)); AddThunkToThunkSequence(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), std::move(config), - bmm1_grad_gemm1_rhs_slice, bmm1_grad_gemm2_rhs_slice, - bmm2_grad_gemm1_lhs_slice, bmm2_grad_gemm2_rhs_slice, d_output_slice, - scratch_slice, d_bmm1_lhs_slice, d_bmm1_rhs_slice, d_bmm2_rhs_slice, - d_s_slice, softmax_sum_slice, d_Q_accum_slice, mask_slice, d_bias_slice, + Thunk::ThunkInfo::WithProfileAnnotation(instr), + std::move(fmha_backward_config), bmm1_grad_gemm1_rhs_slice, + bmm1_grad_gemm2_rhs_slice, bmm2_grad_gemm1_lhs_slice, + bmm2_grad_gemm2_rhs_slice, d_output_slice, scratch_slice, + d_bmm1_lhs_slice, d_bmm1_rhs_slice, d_bmm2_rhs_slice, d_s_slice, + softmax_sum_slice, d_Q_accum_slice, mask_slice, d_bias_slice, fwd_output_slice, bias_slice)); return absl::OkStatus(); } + #endif // GOOGLE_CUDA absl::StatusOr @@ -1678,81 +1306,44 @@ IrEmitterUnnested::GetAllocationSliceForHlo(const HloInstruction* instr, #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -absl::Status IrEmitterUnnested::EmitCubDeviceRadixSort(mlir::Operation* op) { - auto radix_sort_op = mlir::cast(op); - if (radix_sort_op.getInputs().size() != 1 && - radix_sort_op.getInputs().size() != 2) { +absl::Status IrEmitterUnnested::EmitCubDeviceRadixSort( + const HloCustomCallInstruction* instr) { + if (instr->operand_count() != 1 && instr->operand_count() != 2) { return Internal("Invalid number of operands for radix sort"); } - TF_ASSIGN_OR_RETURN(std::vector operands, - GetAllocationSlices(radix_sort_op.getInputs())); - TF_ASSIGN_OR_RETURN(std::vector results, - GetAllocationSlices(radix_sort_op.getOutput())); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice scratch, - GetAllocationSlice(radix_sort_op.getScratch())); - - auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), - GetShape(op->getOperand(0)).element_type(), - radix_sort_op.getInputs().size() == 2 - ? std::optional(GetShape(op->getOperand(1)).element_type()) - : std::nullopt, - operands, results, scratch, radix_sort_op.getDescending()); - - AddThunkToThunkSequence(std::move(thunk)); - return absl::OkStatus(); -} - -absl::Status IrEmitterUnnested::EmitCholeskyThunk(mlir::Operation* op) { - auto cholesky_op = mlir::cast(op); - - const Shape shape = GetShape(cholesky_op.getInput()); - int ndim = shape.dimensions_size(); - CHECK_GE(ndim, 2); - int64_t n = shape.dimensions(ndim - 1); - - const auto& dims = shape.dimensions(); - int64_t batch_size = - std::accumulate(dims.begin(), dims.end() - 2, int64_t{1}, - [](int64_t a, int64_t b) { return a * b; }); - - TF_ASSIGN_OR_RETURN(auto operand_buffer, - GetAllocationSlice(cholesky_op.getInput())); - TF_ASSIGN_OR_RETURN(auto a_buffer, - GetAllocationSlice(cholesky_op.getOutput())); - TF_ASSIGN_OR_RETURN(auto workspace_buffer, - GetAllocationSlice(cholesky_op.getScratch())); - TF_ASSIGN_OR_RETURN(auto info_buffer, - GetAllocationSlice(cholesky_op.getInfo())); - - ThunkSequence thunks; - - if (operand_buffer != a_buffer) { - thunks.push_back(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), - /*source_buffer=*/operand_buffer, - /*destination_buffer=*/a_buffer, - /*mem_size=*/ShapeUtil::ByteSizeOf(shape), - /*source_value=*/cholesky_op.getInput(), - /*destination_value=*/cholesky_op.getOutput())); + absl::InlinedVector operands; + for (int i = 0; i < instr->operand_count(); ++i) { + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice operand, + GetAllocationSliceForHlo(instr->operand(i), {})); + operands.push_back(operand); } - CholeskyOptions options; - options.set_lower(cholesky_op.getIsLower()); - thunks.push_back(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), options, - PtxOptsFromDebugOptions(ir_emitter_context_->debug_options()), a_buffer, - workspace_buffer, info_buffer, shape.element_type(), batch_size, n)); + absl::InlinedVector results; + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result, + GetAllocationSliceForHlo(instr, {0})); + results.push_back(result); - // Elide the sequential thunk if there's no copy. - if (thunks.size() == 1) { - AddThunkToThunkSequence(std::move(thunks[0])); + BufferAllocation::Slice scratch; + if (instr->operand_count() == 1) { + TF_ASSIGN_OR_RETURN(scratch, GetAllocationSliceForHlo(instr, {1})); } else { - AddThunkToThunkSequence(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), std::move(thunks))); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result, + GetAllocationSliceForHlo(instr, {1})); + results.push_back(result); + TF_ASSIGN_OR_RETURN(scratch, GetAllocationSliceForHlo(instr, {2})); } + TF_ASSIGN_OR_RETURN(xla::SortOptions options, + instr->backend_config()); + auto thunk = std::make_unique( + Thunk::ThunkInfo::WithProfileAnnotation(instr), + instr->operand(0)->shape().element_type(), + instr->operand_count() == 2 + ? std::optional(instr->operand(1)->shape().element_type()) + : std::nullopt, + operands, results, scratch, options.descending()); + AddThunkToThunkSequence(std::move(thunk)); return absl::OkStatus(); } @@ -1785,9 +1376,7 @@ absl::Status IrEmitterUnnested::EmitCholeskyThunk(const HloInstruction* instr) { Thunk::ThunkInfo::WithProfileAnnotation(instr), /*source_buffer=*/operand_buffer, /*destination_buffer=*/a_buffer, - /*mem_size=*/ShapeUtil::ByteSizeOf(shape), - /*source_value=*/nullptr, - /*destination_value=*/nullptr)); + /*mem_size=*/ShapeUtil::ByteSizeOf(shape))); } thunks.push_back(std::make_unique( @@ -1859,190 +1448,13 @@ static absl::StatusOr BuildAttributesMap( } absl::Status IrEmitterUnnested::EmitCustomCallThunk( - mlir::Operation* op, const HloCustomCallInstruction* instr) { - if (ir_emitter_context_->emit_ir_from_hlo()) - return EmitCustomCallThunk(instr); - auto custom_call = mlir::cast(op); - const std::string call_target_name = custom_call.getCallTargetName().str(); + const HloCustomCallInstruction* instr) { + const std::string call_target_name = instr->custom_call_target(); // Typed FFI custom calls is a replacement for legacy custom calls with // a rich type safe API. It's under construction and not fully supported. bool is_ffi_custom_call = - custom_call.getApiVersion() == - mlir::mhlo::CustomCallApiVersion::API_VERSION_TYPED_FFI; - - void* call_target = CustomCallTargetRegistry::Global()->Lookup( - call_target_name, std::string(platform_name())); - - absl::StatusOr handler = - ffi::FindHandler(call_target_name, platform_name()); - - // At least one implementation should be available at run time. - bool found_custom_call = !is_ffi_custom_call && call_target != nullptr; - bool found_ffi_handler = is_ffi_custom_call && handler.ok(); - - if (!found_custom_call && !found_ffi_handler) { - auto& debug_options = ir_emitter_context_->debug_options(); - - // If true, then all custom calls that are not found in custom call or FFI - // registries will become no-op (we don't emit any thunks for them). - if (debug_options.xla_gpu_mock_custom_calls()) { - return absl::OkStatus(); - } - - // TODO(ezhulenev): Custom calls registered with an XLA runtime are not part - // of a legacy registry, or an FFI registry. For now we simply ignore them. - if (debug_options.xla_gpu_enable_xla_runtime_executable()) { - return absl::OkStatus(); - } - - return absl::UnimplementedError( - absl::StrCat("No registered implementation for custom call to ", - call_target_name, " for platform ", platform_name())); - } - - using Slices = std::vector>; - - // Initialize slices and shapes from the value range. - auto init_from_values = [&](mlir::ValueRange values, Slices* slices) { - for (mlir::Value value : values) { - TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(value)); - slices->push_back(CustomCallThunk::Slice{slice, GetShape(value)}); - } - return absl::OkStatus(); - }; - - // Initialize slices and shapes from the value range with token holes. - auto init_from_mapped_values = [&](mlir::ValueRange values, - absl::Span target_mapping, - int64_t target_size, Slices* slices) { - slices->resize(target_size); - for (auto [index, value] : llvm::zip(target_mapping, values)) { - TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(value)); - (*slices)[index] = CustomCallThunk::Slice{slice, GetShape(value)}; - } - return absl::OkStatus(); - }; - - Slices operands, results; - - // If we have a target mapping, than the number of operands and results of a - // custom call handler can be more than a number of operands and results in - // the IR. These holes are coming from the HLO token operands and results. - if (auto target_mapping = custom_call.getTargetArgMapping()) { - auto arg_mapping = target_mapping->getArgsToTargetArgs(); - auto res_mapping = target_mapping->getResultsToTargetResults(); - - TF_RETURN_IF_ERROR( - init_from_mapped_values(custom_call.getArgs(), arg_mapping, - target_mapping->getNumArgs(), &operands)); - TF_RETURN_IF_ERROR( - init_from_mapped_values(custom_call.getOutput(), res_mapping, - target_mapping->getNumResults(), &results)); - - } else { - TF_RETURN_IF_ERROR(init_from_values(custom_call.getArgs(), &operands)); - TF_RETURN_IF_ERROR(init_from_values(custom_call.getOutput(), &results)); - } - - // For legacy custom calls we convert all API versions into the the latest - // status-returning one and pass backend config as an opaque string. - CustomCallThunk::CustomCallTarget custom_call_target; - std::string opaque; - - // For XLA FFI handlers we decode opaque backend config into attributes map - // at IR emission time, so that we do not need to parse MLIR at run time. For - // FFI handlers backend config must be a compatible MLIR dictionary. - CustomCallThunk::AttributesMap attributes; - - // For information about this calling convention, see - // xla/g3doc/custom_call.md. - switch (custom_call.getApiVersion()) { - case mlir::mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL: - using original_call_type = - void (*)(CustomCallThunk::Stream /*stream*/, void** /*buffers*/, - const char* /*opaque*/, size_t /*opaque_len*/); - custom_call_target = [call_target](CustomCallThunk::Stream stream, - void** buffers, const char* opaque, - size_t opaque_len, - XlaCustomCallStatus*) { - auto typed_call_target = - reinterpret_cast(call_target); - typed_call_target(stream, buffers, opaque, opaque_len); - }; - break; - case mlir::mhlo::CustomCallApiVersion::API_VERSION_STATUS_RETURNING: - case mlir::mhlo::CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED: - using status_returning_call_type = - void (*)(CustomCallThunk::Stream /*stream*/, void** /*buffers*/, - const char* /*opaque*/, size_t /*opaque_len*/, - XlaCustomCallStatus* /*status*/); - custom_call_target = - reinterpret_cast(call_target); - break; - case mlir::mhlo::CustomCallApiVersion::API_VERSION_TYPED_FFI: - // We already checked `handler` above. - break; - default: - return Internal("Unknown custom-call API version enum value: %d", - custom_call.getApiVersion()); - } - - auto backend_config = - custom_call.getBackendConfig().value_or(mlir::Attribute()); - - switch (custom_call.getApiVersion()) { - case mlir::mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL: - case mlir::mhlo::CustomCallApiVersion::API_VERSION_STATUS_RETURNING: - case mlir::mhlo::CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED: - if (auto str = backend_config.dyn_cast_or_null()) { - opaque = str.str(); - break; - } - return absl::InternalError( - "Unsupported backend config. Expected a string attribute"); - - case mlir::mhlo::CustomCallApiVersion::API_VERSION_TYPED_FFI: - if (auto dict = backend_config.dyn_cast_or_null()) { - TF_ASSIGN_OR_RETURN(attributes, BuildAttributesMap(dict)); - break; - } - return absl::InternalError( - "Unsupported backend config. Expected a dictionary attribute"); - - default: - return Internal("Unknown custom-call API version enum value: %d", - custom_call.getApiVersion()); - } - - auto ffi_thunk = [&] { - auto& called_computations = instr->called_computations(); - return std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), *handler, - std::move(operands), std::move(results), std::move(attributes), - called_computations.empty() ? nullptr : called_computations[0]); - }; - - auto legacy_thunk = [&] { - return std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), - std::move(custom_call_target), std::move(operands), std::move(results), - std::move(opaque)); - }; - - AddThunkToThunkSequence(found_ffi_handler ? ffi_thunk() : legacy_thunk()); - - return absl::OkStatus(); -} - -absl::Status IrEmitterUnnested::EmitCustomCallThunk( - const HloCustomCallInstruction* instr) { - const std::string call_target_name = instr->custom_call_target(); - - // Typed FFI custom calls is a replacement for legacy custom calls with - // a rich type safe API. It's under construction and not fully supported. - bool is_ffi_custom_call = - instr->api_version() == CustomCallApiVersion::API_VERSION_TYPED_FFI; + instr->api_version() == CustomCallApiVersion::API_VERSION_TYPED_FFI; void* call_target = CustomCallTargetRegistry::Global()->Lookup( call_target_name, std::string(platform_name())); @@ -2201,33 +1613,6 @@ absl::Status IrEmitterUnnested::EmitCustomCallThunk( return absl::OkStatus(); } -absl::Status IrEmitterUnnested::EmitFftThunk(mlir::Operation* op) { - auto fft_op = mlir::cast(op); - const Shape operand_shape = GetShape(fft_op.getOperand()); - const Shape output_shape = GetShape(fft_op.getOutput()); - TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand_shape.layout())); - TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(output_shape.layout())); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice arg_slice, - GetAllocationSlice(fft_op.getOperand())); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dest_slice, - GetAllocationSlice(fft_op.getOutput())); - TF_ASSIGN_OR_RETURN( - xla::FftType fft_type, - ConvertFftType(mlir::mhlo::stringifyFftType(fft_op.getFftType()))); - auto fft_length_values = fft_op.getFftLength().getValues(); - std::vector fft_length(fft_length_values.begin(), - fft_length_values.end()); - - AddThunkToThunkSequence(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), fft_type, fft_length, - /*input_buffer=*/arg_slice, - /*output_buffer=*/dest_slice, - /*input_shape=*/operand_shape, - /*output_shape=*/output_shape)); - return absl::OkStatus(); -} - absl::Status IrEmitterUnnested::EmitFftThunk(const HloFftInstruction* instr) { TF_ASSIGN_OR_RETURN(BufferAllocation::Slice arg_slice, GetAllocationSliceForHlo(instr->operand(0))); @@ -2244,97 +1629,6 @@ absl::Status IrEmitterUnnested::EmitFftThunk(const HloFftInstruction* instr) { } #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -absl::Status IrEmitterUnnested::EmitTriangularSolveCustomCall( - mlir::Operation* op) { - auto custom_call = mlir::cast(op); - - auto operands = op->getOperands(); - TF_RET_CHECK(operands.size() == 4); - - // We expect Fortran layout for everything other than the temp buffer (the - // last operand). Fortran layout is not XLA default layout with elements 0 - // and 1 swapped. For example instead of default layout {3,2,1,0} we'd have - // Fortran layout {2,3,1,0}. - TF_RET_CHECK(absl::c_all_of(operands.drop_back(1), [&](mlir::Value v) { - const Shape& shape = GetShape(v); - const Layout& layout = shape.layout(); - int n = layout.minor_to_major_size(); - if (n < 2) { - return false; - } - // Unfortunately the HLO -> LMHLO -> HLO conversion loses layout information - // if the shape has any dimensions of size 1: In that case, the new HLO - // (which we see here) will have an arbitrary value for the location of the - // size-1 dimension. Just skip this assertion if the shape has any - // degenerate dimensions. - if (absl::c_any_of(shape.dimensions(), - [](int64_t dim) { return dim == 1; })) { - return true; - } - return layout.minor_to_major(0) == n - 2 && - layout.minor_to_major(1) == n - 1 && - std::is_sorted(layout.minor_to_major().begin() + 2, - layout.minor_to_major().end(), - std::greater()); - })); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice a_slice, - GetAllocationSlice(operands[0])); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice b_slice, - GetAllocationSlice(operands[1])); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice, - GetAllocationSlice(operands[2])); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice temp_slice, - GetAllocationSlice(operands[3])); - - const Shape b_shape = GetShape(operands[1]); - const PrimitiveType elem_ty = b_shape.element_type(); - - TriangularSolveOptions backend_config; - if (auto str = custom_call.getBackendConfig() - .value_or(mlir::Attribute()) - .dyn_cast_or_null()) - TF_RETURN_IF_ERROR( - tsl::HumanReadableJsonToProto(str.str(), &backend_config)); - - ThunkSequence thunks; - - // Triangular solve is in-place on 'b', so copy 'b' to the output if they - // aren't the same buffer. - if (b_slice != result_slice) { - thunks.push_back(std::make_unique( - Thunk::ThunkInfo(op), - /*source_buffer=*/b_slice, - /*destination_buffer=*/result_slice, - /*mem_size=*/ShapeUtil::ByteSizeOf(b_shape), - /*source_value=*/operands[1], - /*destination_value=*/operands[2])); - } - - int64_t m = b_shape.dimensions(b_shape.rank() - 2); - int64_t n = b_shape.dimensions(b_shape.rank() - 1); - int64_t batch_size = std::accumulate( - b_shape.dimensions().begin(), b_shape.dimensions().end() - 2, int64_t{1}, - [](int64_t a, int64_t b) { return a * b; }); - int64_t elem_size = ShapeUtil::ByteSizeOfPrimitiveType(elem_ty); - int64_t a_batch_stride = - backend_config.left_side() ? m * m * elem_size : n * n * elem_size; - int64_t b_batch_stride = m * n * elem_size; - thunks.push_back(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), backend_config, - PtxOptsFromDebugOptions(ir_emitter_context_->debug_options()), - /*a_buffer=*/a_slice, /*b_buffer=*/result_slice, temp_slice, elem_ty, - batch_size, m, n, a_batch_stride, b_batch_stride)); - - // Elide the sequential thunk if there's no copy. - if (thunks.size() == 1) { - AddThunkToThunkSequence(std::move(thunks[0])); - } else { - AddThunkToThunkSequence(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), std::move(thunks))); - } - return absl::OkStatus(); -} absl::Status IrEmitterUnnested::EmitTriangularSolveCustomCall( const HloInstruction* instr) { @@ -2384,9 +1678,7 @@ absl::Status IrEmitterUnnested::EmitTriangularSolveCustomCall( Thunk::ThunkInfo::WithProfileAnnotation(instr), /*source_buffer=*/b_slice, /*destination_buffer=*/result_slice, - /*mem_size=*/ShapeUtil::ByteSizeOf(b_shape), - /*source_value=*/nullptr, - /*destination_value=*/nullptr)); + /*mem_size=*/ShapeUtil::ByteSizeOf(b_shape))); } int64_t m = b_shape.dimensions(b_shape.rank() - 2); @@ -2408,8 +1700,11 @@ absl::Status IrEmitterUnnested::EmitTriangularSolveCustomCall( if (thunks.size() == 1) { AddThunkToThunkSequence(std::move(thunks[0])); } else { - AddThunkToThunkSequence(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(thunks))); + auto thunk_info = Thunk::ThunkInfo::WithProfileAnnotation(instr); + // Don't repeat the annotation from inside thunks + thunk_info.profile_annotation = {}; + AddThunkToThunkSequence( + std::make_unique(thunk_info, std::move(thunks))); } return absl::OkStatus(); } @@ -2461,6 +1756,98 @@ absl::Status IrEmitterUnnested::EmitTopKCustomCall( return absl::OkStatus(); } +absl::Status IrEmitterUnnested::EmitTritonCustomCall( + const HloCustomCallInstruction* instr) { +#if !GOOGLE_CUDA + return absl::UnimplementedError("Triton support requires CUDA"); +#else + auto generate = [this, &instr]() -> absl::StatusOr { + mlir::MLIRContext& mlir_context = *ir_emitter_context_->mlir_context(); + mlir_context.loadDialect(); + auto call = + TritonCall::Parse(instr->raw_backend_config_string(), &mlir_context); + auto kernel_name = + ir_emitter_context_->name_uniquer()->GetUniqueName(call.name); + VLOG(3) << "Generating: " << kernel_name; + + auto triton_module = + mlir::parseSourceString(call.ir, &mlir_context); + auto triton_fn = + triton_module->lookupSymbol(call.name); + triton_fn.setName(kernel_name); + + HloModule* hlo_module = instr->GetModule(); + auto gemm_config = TritonGemmConfig( + /*block_m=*/-1, /*block_n=*/-1, /*block_k=*/-1, /*split_k=*/-1, + call.num_stages, call.num_warps); + TF_ASSIGN_OR_RETURN( + auto result, + CompileTritonToLLVM(hlo_module->config(), hlo_module->name(), + ir_emitter_context_->cuda_compute_capability(), + ir_emitter_context_->gpu_device_info(), gemm_config, + triton_module.get(), + ir_emitter_context_->llvm_module(), mlir_context)); + + llvm::Function* impl_fn = + ir_emitter_context_->llvm_module()->getFunction(kernel_name); + TF_RET_CHECK(impl_fn); + impl_fn->setName(ir_emitter_context_->name_uniquer()->GetUniqueName( + kernel_name + "_impl")); + + TF_ASSIGN_OR_RETURN( + auto kernel_arguments, + KernelArguments::Create(ir_emitter_context_->buffer_assignment(), instr, + instr->operands())); + auto launch_dimensions = + LaunchDimensions(se::BlockDim(call.grid_x, call.grid_y, call.grid_z), + se::ThreadDim(call.num_warps * 32)); + + llvm::IRBuilder builder(ir_emitter_context_->llvm_module()->getContext()); + + llvm::Function* kernel; + std::vector inputs; + std::vector outputs; + TF_ASSIGN_OR_RETURN( + std::tie(kernel, inputs, outputs), + BuildKernelPrototype(*ir_emitter_context_, kernel_name, + kernel_arguments.args(), impl_fn->arg_size(), + launch_dimensions, &builder)); + + // Move function body into kernel prototype. + llvm::Function* prototype_func = builder.GetInsertBlock()->getParent(); + prototype_func->splice(prototype_func->begin(), impl_fn); + for (const auto& [kernel_arg, arg, input] : + llvm::zip(kernel_arguments.args(), impl_fn->args(), inputs)) { + // Remove the alignment and aliasing attributes to avoid recompiling the + // kernel for each alignment/aliasing combination. + arg.removeAttr(llvm::Attribute::Alignment); + arg.removeAttr(llvm::Attribute::NoAlias); + + arg.replaceAllUsesWith(input.GetBasePointer()); + } + impl_fn->eraseFromParent(); + + return {{kernel->getName().str(), launch_dimensions, result.cluster_dim, + result.shmem_bytes}}; + }; + + auto [status_or_entry, was_cached] = + ir_emitter_context_->kernel_cache().GetWithStatus( + instr->raw_backend_config_string(), generate); + TF_ASSIGN_OR_RETURN(const KernelReuseCache::Entry* entry, status_or_entry); + + TF_ASSIGN_OR_RETURN( + auto kernel_arguments, + KernelArguments::Create(ir_emitter_context_->buffer_assignment(), instr, + instr->operands())); + + AddThunkToThunkSequence(std::make_unique( + instr, entry->kernel_name, kernel_arguments.args(), + entry->launch_dimensions, entry->cluster_dim, entry->shmem_bytes)); + return absl::OkStatus(); +#endif // GOOGLE_CUDA +} + // Convert the following form of fusion region: // fusion() { // %0 = tensor_load %external_memref0 @@ -2526,28 +1913,7 @@ absl::Status IrEmitterUnnested::EmitFusion(const HloFusionInstruction* instr, std::unique_ptr emitter, GetFusionEmitter(HloFusionInfo( fusion_analysis, instr, &ir_emitter_context_->buffer_assignment()))); - return AddThunksToThunkSequence( - emitter->Emit(*ir_emitter_context_, nullptr, *instr)); -} - -absl::Status IrEmitterUnnested::EmitFusion( - mlir::Operation* op, - const absl::flat_hash_map& - hlo_for_lmhlo) { - auto fusion_op = mlir::cast(op); - auto* fusion = Cast(hlo_for_lmhlo.at(fusion_op)); - - // Create HloFusionAnalysis instance. - const se::DeviceDescription& device_info = - ir_emitter_context_->gpu_device_info(); - auto fusion_analysis = HloFusionAnalysis::Create(fusion, &device_info); - - TF_ASSIGN_OR_RETURN( - std::unique_ptr emitter, - GetFusionEmitter(LmhloFusionInfo(fusion_analysis, fusion_op, - ir_emitter_context_->allocations()))); - return AddThunksToThunkSequence( - emitter->Emit(*ir_emitter_context_, fusion_op, *fusion)); + return AddThunksToThunkSequence(emitter->Emit(*ir_emitter_context_, *instr)); } absl::Status IrEmitterUnnested::AssertNonDeterminismIsOkay( @@ -2563,46 +1929,36 @@ absl::Status IrEmitterUnnested::AssertNonDeterminismIsOkay( } absl::Status IrEmitterUnnested::EmitSelectAndScatter( - mlir::Operation* op, - const absl::flat_hash_map& - hlo_for_lmhlo) { - auto select_and_scatter_op = mlir::cast(op); - auto* select_and_scatter = - Cast(hlo_for_lmhlo.at(op)); - - const Shape source_shape = GetShape(select_and_scatter_op.getSource()); - const Shape operand_shape = GetShape(select_and_scatter_op.getOperand()); + const HloSelectAndScatterInstruction* instr) { + const HloInstruction* operand = instr->operand(0); + const HloInstruction* source = instr->operand(1); + const Shape source_shape = source->shape(); + const Shape operand_shape = operand->shape(); const int64_t rank = operand_shape.rank(); + Window window = instr->window(); + CHECK_EQ(rank, source_shape.rank()); - if (select_and_scatter_op.getWindowDimensions()) { - CHECK_EQ(rank, select_and_scatter_op.getWindowDimensions()->size()); - } + CHECK_EQ(rank, window.dimensions_size()); - TF_RETURN_IF_ERROR(AssertNonDeterminismIsOkay( - mlir::mhlo::GetDebugNameFromLocation(select_and_scatter_op.getLoc()))); + std::string name = llvm_ir::IrName(instr); - std::string name = GetIrNameFromLoc(select_and_scatter_op.getLoc()); + TF_RETURN_IF_ERROR(AssertNonDeterminismIsOkay(name)); - const HloInstruction* init_value = select_and_scatter->operand(2); + const HloInstruction* init_value = instr->operand(2); // IrEmitterUnnested implements kSelectAndScatter as a SequentialThunk // consisting of two thunks, an initializer KernelThunk that initializes // the output and another KernelThunk that accumulates the scattered // elements. - TF_RETURN_IF_ERROR(BuildInitializerThunk(op, select_and_scatter, init_value, - select_and_scatter_op.getInitValue(), - select_and_scatter_op.getOut())); + TF_RETURN_IF_ERROR(BuildInitializerThunk(instr, init_value)); LaunchDimensions launch_dimensions = CalculateLaunchDimensions( source_shape, ir_emitter_context_->gpu_device_info()); // Init value is not needed in IR emission. - TF_ASSIGN_OR_RETURN(auto ir_arrays, BuildKernelThunkForNonFusionOp( - select_and_scatter_op, - {select_and_scatter_op.getOperand(), - select_and_scatter_op.getSource(), - select_and_scatter_op.getOut()}, - launch_dimensions)); + TF_ASSIGN_OR_RETURN(auto ir_arrays, + BuildKernelThunkForNonFusionOp(instr, {operand, source}, + launch_dimensions)); auto& [inputs, outputs] = ir_arrays; CHECK_EQ(inputs.size(), 3); @@ -2611,8 +1967,8 @@ absl::Status IrEmitterUnnested::EmitSelectAndScatter( const llvm_ir::IrArray& source_array = inputs[1]; const llvm_ir::IrArray& out_array = inputs[2]; - llvm::Type* index_type = GetIndexTypeForKernel( - select_and_scatter_op, launch_dimensions.launch_bound(), &b_); + llvm::Type* index_type = + GetIndexTypeForKernel(instr, launch_dimensions.launch_bound(), &b_); auto index_typed_constant = [&](uint64_t c) -> llvm::Constant* { return llvm::ConstantInt::get(index_type, c); }; @@ -2661,11 +2017,10 @@ absl::Status IrEmitterUnnested::EmitSelectAndScatter( index_type); DimensionVector window_size; - mlir::DenseIntElementsAttr window_dimensions = - select_and_scatter_op.getWindowDimensions().value(); - for (const auto& dim : window_dimensions) { - window_size.push_back(dim.getSExtValue()); - CHECK_GT(dim.getSExtValue(), 0); + for (const WindowDimension& dim : window.dimensions()) { + auto size = static_cast(dim.size()); + window_size.push_back(size); + CHECK_GT(size, 0); } const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape( @@ -2680,14 +2035,9 @@ absl::Status IrEmitterUnnested::EmitSelectAndScatter( std::vector operand_multi_index(source_index.size()); llvm::Value* in_bounds_condition = b_.getInt1(true); - auto strides = *select_and_scatter_op.getWindowStrides(); - auto paddings = *select_and_scatter_op.getPadding(); - - for (const auto& stride_and_padding : - llvm::enumerate(llvm::zip(strides, paddings))) { - const int i = stride_and_padding.index(); - int64_t stride = std::get<0>(stride_and_padding.value()).getSExtValue(); - int64_t padding = std::get<1>(stride_and_padding.value()).getSExtValue(); + for (const auto [i, value] : llvm::enumerate(window.dimensions())) { + auto stride = static_cast(value.stride()); + auto padding = static_cast(value.padding_low()); llvm::Value* strided_index = NSWMul(source_index[i], index_typed_constant(stride)); @@ -2739,7 +2089,7 @@ absl::Status IrEmitterUnnested::EmitSelectAndScatter( llvm_ir::PrimitiveTypeToIrType(PRED, module_), "select_return_buffer", &b_); - const HloComputation* select_computation = select_and_scatter->select(); + const HloComputation* select_computation = instr->select(); TF_RETURN_IF_ERROR(CallNestedComputation( &b_, *ir_emitter_context_, *select_computation, {selected_value_address, operand_address}, select_return_buffer)); @@ -2784,7 +2134,7 @@ absl::Status IrEmitterUnnested::EmitSelectAndScatter( Load(selected_index_address->getAllocatedType(), selected_index_address_slot)); } - const Shape output_shape = GetShape(select_and_scatter_op.getOut()); + const Shape output_shape = instr->shape(); llvm::Value* source_value_address = source_array.EmitArrayElementAddress(source_index, &b_); llvm_ir::IrArray::Index selected_index(selected_multi_index, output_shape, @@ -2792,7 +2142,7 @@ absl::Status IrEmitterUnnested::EmitSelectAndScatter( llvm::Value* output_value_address = out_array.EmitArrayElementAddress(selected_index, &b_); - const HloComputation* scatter_computation = select_and_scatter->scatter(); + const HloComputation* scatter_computation = instr->scatter(); return EmitAtomicOperationForNestedComputation( &b_, *ir_emitter_context_, *scatter_computation, output_value_address, source_value_address, source_array.GetElementLlvmType()); @@ -2803,294 +2153,22 @@ absl::Status IrEmitterUnnested::EmitSelectAndScatter( .EmitLoop(name, index_type); } -absl::Status IrEmitterUnnested::EmitSelectAndScatter( - const HloSelectAndScatterInstruction* instr) { - const HloInstruction* operand = instr->operand(0); - const HloInstruction* source = instr->operand(1); - const Shape source_shape = source->shape(); - const Shape operand_shape = operand->shape(); - const int64_t rank = operand_shape.rank(); +absl::Status IrEmitterUnnested::EmitWhile(const HloInstruction* instr) { + TF_ASSIGN_OR_RETURN(auto config, + instr->backend_config()); - Window window = instr->window(); + std::optional trip_count = std::nullopt; + if (config.has_known_trip_count()) trip_count = config.known_trip_count().n(); - CHECK_EQ(rank, source_shape.rank()); - CHECK_EQ(rank, window.dimensions_size()); + TF_ASSIGN_OR_RETURN( + auto thunk, + BuildWhileThunk(instr, Thunk::ThunkInfo::WithProfileAnnotation(instr), + trip_count)); - std::string name = llvm_ir::IrName(instr); - - TF_RETURN_IF_ERROR(AssertNonDeterminismIsOkay(name)); - - const HloInstruction* init_value = instr->operand(2); - // IrEmitterUnnested implements kSelectAndScatter as a SequentialThunk - // consisting of two thunks, an initializer KernelThunk that initializes - // the output and another KernelThunk that accumulates the scattered - // elements. - TF_RETURN_IF_ERROR( - BuildInitializerThunk(nullptr, instr, init_value, nullptr, nullptr)); - - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - source_shape, ir_emitter_context_->gpu_device_info()); - - // Init value is not needed in IR emission. - TF_ASSIGN_OR_RETURN(auto ir_arrays, - BuildKernelThunkForNonFusionOp(instr, {operand, source}, - launch_dimensions)); - - auto& [inputs, outputs] = ir_arrays; - CHECK_EQ(inputs.size(), 3); - CHECK_EQ(outputs.size(), 0); - const llvm_ir::IrArray& operand_array = inputs[0]; - const llvm_ir::IrArray& source_array = inputs[1]; - const llvm_ir::IrArray& out_array = inputs[2]; - - llvm::Type* index_type = - GetIndexTypeForKernel(instr, launch_dimensions.launch_bound(), &b_); - auto index_typed_constant = [&](uint64_t c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_type, c); - }; - - // kSelectAndScatter is implemented as two kernel launches: the first launch - // initializes the output array to the given initial value, - // and the second accumulates the "source" matrix to the - // selected elements in the output array. The first launch is already - // implemented by the initializer thunk generated earlier, so this function - // only needs to take care of the select-and-scatter part. - // - // Pseudo code for select-and-scatter: - // - // for (coordinates S in the source): # This loop is parallel. - // initialized_flag = false - // for (coordinates W in the window): - // I = S * stride + W - pad_low - // if I within bounds of operand: - // if !(initialized_flag and select(selected_value, operand(I))): - // selected_value = operand(I) - // selected_index = I - // initialized_flag = true - // if initialized_flag: - // output(selected_index) = scatter(output(selected_index), source(S)) - auto loop_body_emitter = - [&](const llvm_ir::IrArray::Index& source_index) -> absl::Status { - // Allocate space to keep the currently selected value, its index, and a - // boolean flag if the value is initialized. The initialized_flag is set - // false. - llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(operand_shape.element_type(), module_), - "selected_value_address", &b_); - - llvm::AllocaInst* selected_index_address = - llvm_ir::EmitAllocaAtFunctionEntryWithCount( - index_type, index_typed_constant(rank), "selected_index_address", - &b_); - - llvm::AllocaInst* initialized_flag_address = - llvm_ir::EmitAllocaAtFunctionEntry(b_.getInt1Ty(), - "initialized_flag_address", &b_); - Store(b_.getInt1(false), initialized_flag_address); - - // Create the inner loop to iterate over the window. - llvm_ir::ForLoopNest window_loops(absl::StrCat(name, "inner"), &b_, - index_type); - - DimensionVector window_size; - for (const WindowDimension& dim : window.dimensions()) { - auto size = static_cast(dim.size()); - window_size.push_back(size); - CHECK_GT(size, 0); - } - - const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape( - ShapeUtil::MakeShape(operand_shape.element_type(), window_size), - "window"); - llvm_ir::SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(), - &b_); - - // Compute the operand index to visit and evaluate the condition whether the - // operand index is within the bounds. The unsigned comparison includes - // checking whether the operand index >= 0. - std::vector operand_multi_index(source_index.size()); - llvm::Value* in_bounds_condition = b_.getInt1(true); - - for (const auto [i, value] : llvm::enumerate(window.dimensions())) { - auto stride = static_cast(value.stride()); - auto padding = static_cast(value.padding_low()); - - llvm::Value* strided_index = - NSWMul(source_index[i], index_typed_constant(stride)); - operand_multi_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]), - index_typed_constant(padding)); - llvm::Value* index_condition = ICmpULT( - operand_multi_index[i], - index_typed_constant(ShapeUtil::GetDimension(operand_shape, i))); - in_bounds_condition = And(in_bounds_condition, index_condition); - } - - // Only need to do something if the operand index is within the bounds. - // First check if the initialized_flag is set. - llvm_ir::LlvmIfData if_in_bounds = - llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_); - llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &b_); - llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse( - Load(initialized_flag_address->getAllocatedType(), - initialized_flag_address), - "initialized", &b_); - - // If the initialized_flag is false, initialize the selected value and index - // with the currently visiting operand. - llvm_ir::SetToFirstInsertPoint(if_initialized.false_block, &b_); - const auto save_operand_index = - [&](const llvm_ir::IrArray::Index& operand_index) { - for (int64_t i = 0; i < rank; ++i) { - llvm::Value* selected_index_address_slot = - InBoundsGEP(selected_index_address->getAllocatedType(), - selected_index_address, {b_.getInt32(i)}); - Store(operand_index[i], selected_index_address_slot); - } - }; - llvm_ir::IrArray::Index operand_index(operand_multi_index, operand_shape, - index_type); - llvm::Value* operand_data = - operand_array.EmitReadArrayElement(operand_index, &b_); - Store(operand_data, selected_value_address); - save_operand_index(operand_index); - Store(b_.getInt1(true), initialized_flag_address); - - // If the initialized_flag is true, call the `select` function to - // potentially update the selected value and index with the currently - // visiting operand. - llvm_ir::SetToFirstInsertPoint(if_initialized.true_block, &b_); - llvm::Value* operand_address = - operand_array.EmitArrayElementAddress(operand_index, &b_); - llvm::AllocaInst* select_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(PRED, module_), "select_return_buffer", - &b_); - - const HloComputation* select_computation = instr->select(); - TF_RETURN_IF_ERROR(CallNestedComputation( - &b_, *ir_emitter_context_, *select_computation, - {selected_value_address, operand_address}, select_return_buffer)); - llvm::Value* result = - Load(select_return_buffer->getAllocatedType(), select_return_buffer); - - // If the 'select' function returns false, update the selected value and the - // index to the currently visiting operand. - llvm::Value* cond = - ICmpNE(result, - llvm::ConstantInt::get( - llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), - "boolean_predicate"); - llvm_ir::LlvmIfData if_select_lhs = - llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_); - llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &b_); - Store(Load(operand_array.GetElementLlvmType(), operand_address), - selected_value_address); - save_operand_index(operand_index); - - // If the initialized_flag is true, write to the selected index of the - // output; otherwise the window is outside the source (in the padding) and - // should be ignored. - llvm_ir::SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), - &b_); - llvm_ir::LlvmIfData if_should_store = llvm_ir::EmitIfThenElse( - Load(initialized_flag_address->getAllocatedType(), - initialized_flag_address), - "should-store", &b_, /*emit_else=*/false); - llvm_ir::SetToFirstInsertPoint(if_should_store.true_block, &b_); - - // After iterating over the window elements, scatter the source element to - // the selected index of the output. The value we store at the output - // location is computed by calling the `scatter` function with the source - // value and the current output value. - std::vector selected_multi_index; - for (int64_t i = 0; i < rank; ++i) { - llvm::Value* selected_index_address_slot = - InBoundsGEP(selected_index_address->getAllocatedType(), - selected_index_address, {b_.getInt32(i)}); - selected_multi_index.push_back( - Load(selected_index_address->getAllocatedType(), - selected_index_address_slot)); - } - const Shape output_shape = instr->shape(); - llvm::Value* source_value_address = - source_array.EmitArrayElementAddress(source_index, &b_); - llvm_ir::IrArray::Index selected_index(selected_multi_index, output_shape, - operand_index.GetType()); - llvm::Value* output_value_address = - out_array.EmitArrayElementAddress(selected_index, &b_); - - const HloComputation* scatter_computation = instr->scatter(); - return EmitAtomicOperationForNestedComputation( - &b_, *ir_emitter_context_, *scatter_computation, output_value_address, - source_value_address, source_array.GetElementLlvmType()); - }; - - return ParallelLoopEmitter(loop_body_emitter, source_shape, launch_dimensions, - &b_) - .EmitLoop(name, index_type); -} - -absl::Status IrEmitterUnnested::EmitWhile( - mlir::Operation* op, - const absl::flat_hash_map& - hlo_for_lmhlo) { - auto while_op = mlir::cast(op); - - auto cond_result = GetHloOutputs(while_op); - TF_RET_CHECK(cond_result.size() == 1); - TF_RET_CHECK(cond_result[0] - .getType() - .cast() - .getElementType() - .isInteger(/*width=*/1)) - << "While condition computation must return bool"; - - TF_ASSIGN_OR_RETURN( - auto thunk, - BuildWhileThunk(while_op, Thunk::ThunkInfo::WithProfileAnnotation(op), - hlo_for_lmhlo, while_op.getTripCount())); AddThunkToThunkSequence(std::move(thunk)); return absl::OkStatus(); } -absl::Status IrEmitterUnnested::EmitWhile(const HloInstruction* instr) { - TF_ASSIGN_OR_RETURN(auto config, - instr->backend_config()); - - std::optional trip_count = std::nullopt; - if (config.has_known_trip_count()) trip_count = config.known_trip_count().n(); - - TF_ASSIGN_OR_RETURN( - auto thunk, - BuildWhileThunk(instr, Thunk::ThunkInfo::WithProfileAnnotation(instr), - trip_count)); - - AddThunkToThunkSequence(std::move(thunk)); - return absl::OkStatus(); -} - -absl::Status IrEmitterUnnested::EmitRngGetAndUpdateState(mlir::Operation* op) { - auto rng_op = mlir::dyn_cast(op); - - // Emit a kernel to increment the global state for Philox RNG algorithm. - TF_ASSIGN_OR_RETURN(auto ir_arrays, - BuildKernelThunkForNonFusionOp( - rng_op /*, rng_op.getState(),*/, LaunchDimensions())); - auto& [inputs, outputs] = ir_arrays; - - llvm::Value* old_state = - llvm_ir::RngGetAndUpdateState(rng_op.getDelta(), module_, &b_); - - const Shape shape = GetShape(rng_op.getState()); - - llvm::Value* output_address = inputs[0].EmitArrayElementAddress( - llvm_ir::IrArray::Index( - /*linear=*/b_.getInt64(0), shape, &b_), - &b_, "rng_state_address"); - Store(old_state, output_address); - - return absl::OkStatus(); -} - absl::Status IrEmitterUnnested::EmitRngGetAndUpdateState( const HloRngGetAndUpdateStateInstruction* instr) { // Emit a kernel to increment the global state for Philox RNG algorithm. @@ -3107,13 +2185,7 @@ absl::Status IrEmitterUnnested::EmitRngGetAndUpdateState( return absl::OkStatus(); } -absl::Status IrEmitterUnnested::EmitSort(mlir::Operation* op, - const HloSortInstruction* sort) { - auto sort_op = mlir::dyn_cast_or_null(op); - if (!ir_emitter_context_->emit_ir_from_hlo() && !sort_op) { - return absl::InternalError("MLIR operations must be not null"); - } - +absl::Status IrEmitterUnnested::EmitSort(const HloSortInstruction* sort) { std::string op_name(sort->name()); const Shape& keys_shape = sort->operand(0)->shape(); int64_t dimension_to_sort = sort->sort_dimension(); @@ -3132,29 +2204,20 @@ absl::Status IrEmitterUnnested::EmitSort(mlir::Operation* op, // If possible, we share buffers. If that is not possible, we need to // copy the values, because the emitter does the sorting in-place. - if (ir_emitter_context_->emit_ir_from_hlo()) { - TF_ASSIGN_OR_RETURN(destination_buffer, - GetAllocationSliceForHlo(sort, shape_index)); - TF_ASSIGN_OR_RETURN(source_address, - GetAllocationSliceForHlo(sort->operand(i), {})); - } else { - TF_ASSIGN_OR_RETURN(destination_buffer, - GetAllocationSlice(sort_op.getOutput()[i])); - TF_ASSIGN_OR_RETURN(source_address, - GetAllocationSlice(sort_op.getOperands()[i])); - } + TF_ASSIGN_OR_RETURN(destination_buffer, + GetAllocationSliceForHlo(sort, shape_index)); + TF_ASSIGN_OR_RETURN(source_address, + GetAllocationSliceForHlo(sort->operand(i), {})); if (destination_buffer != source_address) { // TODO(b/26783907): Figure out why we never seem to share buffers for // key/value sort. VLOG(2) << op_name << " requires initial D2D copy for operand " << i; AddThunkToThunkSequence(std::make_unique( - Thunk::ThunkInfo(op), + Thunk::ThunkInfo::WithProfileAnnotation(sort), /*source_buffer=*/source_address, /*destination_buffer=*/destination_buffer, - /*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(i)->shape()), - /*source_value=*/sort_op ? sort_op.getOperands()[i] : nullptr, - /*destination_value=*/sort_op ? sort_op.getOutput()[i] : nullptr)); + /*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(i)->shape()))); } } @@ -3256,12 +2319,8 @@ absl::Status IrEmitterUnnested::EmitSort(mlir::Operation* op, LaunchDimensions launch_dimensions = xor_masks.size() > 1 ? tiled_launch_dimensions : standard_launch_dimensions; - TF_ASSIGN_OR_RETURN( - auto ir_arrays, - ir_emitter_context_->emit_ir_from_hlo() - ? BuildKernelThunkForNonFusionOp(sort, {}, launch_dimensions) - : BuildKernelThunkForNonFusionOp(sort_op, sort_op.getOutput(), - launch_dimensions)); + TF_ASSIGN_OR_RETURN(auto ir_arrays, BuildKernelThunkForNonFusionOp( + sort, {}, launch_dimensions)); auto& [inputs, outputs] = ir_arrays; auto* comparator = sort->called_computations().front(); @@ -3297,219 +2356,61 @@ absl::Status IrEmitterUnnested::EmitSort(mlir::Operation* op, } } if (!xor_masks.empty()) { - TF_RETURN_IF_ERROR(emit_kernel(xor_masks)); - } - return absl::OkStatus(); -} - -absl::Status IrEmitterUnnested::EmitSort(const HloSortInstruction* sort) { - CHECK(ir_emitter_context_->emit_ir_from_hlo()); // NOLINT - return EmitSort(nullptr, sort); -} - -template -absl::Status IrEmitterUnnested::EmitReplicaOrPartitionId(mlir::Operation* op) { - auto casted = mlir::cast(op); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice, - GetAllocationSlice(casted.getOperand())); - auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), result_slice); - AddThunkToThunkSequence(std::move(thunk)); - return absl::OkStatus(); -} - -template -absl::Status IrEmitterUnnested::EmitReplicaOrPartitionId( - const HloInstruction* instr) { - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice, - GetAllocationSliceForHlo(instr, {})); - auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(instr), result_slice); - AddThunkToThunkSequence(std::move(thunk)); - return absl::OkStatus(); -} - -Status IrEmitterUnnested::EmitCollectivePermute(mlir::Operation* op) { - auto collective_permute_op = - mlir::cast(op); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice source_slice, - GetAllocationSlice(collective_permute_op.getOperand())); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice, - GetAllocationSlice(collective_permute_op.getOutput())); - - const Shape shape = GetShape(collective_permute_op.getOperand()); - const auto& hlo_config = ir_emitter_context_->hlo_module().config(); - const int64_t replica_count = hlo_config.replica_count(); - const int64_t partition_count = hlo_config.num_partitions(); - - if (NcclCollectivePermuteStartThunk::IsDegenerate( - collective_permute_op, replica_count, partition_count)) { - // For a degenerate collective permute, just generate a copy thunk. - AddThunkToThunkSequence(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), - /*source_buffer=*/source_slice, - /*destination_buffer=*/result_slice, - /*mem_size=*/ShapeUtil::ByteSizeOf(shape), - /*source_value=*/collective_permute_op.getOperand(), - /*destination_value=*/collective_permute_op.getOutput())); - - // Signal that start thunk not created with nullptr. - collectives_async_events_.try_emplace(op, nullptr); - } else { - const NcclCollectiveThunk::Buffer buffer = { - /*element_count=*/ShapeUtil::ElementsIn(shape), - /*source_buffer=*/source_slice, - /*destination_buffer=*/result_slice}; - auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), NcclApi::Default(), - collective_permute_op, replica_count, partition_count, buffer); - collectives_async_events_.try_emplace(op, thunk->async_events()); - AddThunkToThunkSequence(std::move(thunk)); - } - return absl::OkStatus(); -} - -Status IrEmitterUnnested::EmitCollectivePermute( - const HloCollectivePermuteInstruction* instr) { - TF_RET_CHECK(instr->operand_count() == 1); - auto* operand = instr->operand(0); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice source_slice, - GetAllocationSliceForHlo(operand)); - // First output is aliased. - TF_RET_CHECK( - instr->shape().IsTuple() && instr->shape().tuple_shapes_size() == 2 && - instr->shape().tuple_shapes(0) == instr->shape().tuple_shapes(1)); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice, - GetAllocationSliceForHlo(instr, {1})); - - const Shape shape = operand->shape(); - const auto& hlo_config = ir_emitter_context_->hlo_module().config(); - const int64_t replica_count = hlo_config.replica_count(); - const int64_t partition_count = hlo_config.num_partitions(); - - if (NcclCollectivePermuteStartThunk::IsDegenerate(instr, replica_count, - partition_count)) { - // For a degenerate collective permute, just generate a copy thunk. - AddThunkToThunkSequence(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(instr), - /*source_buffer=*/source_slice, - /*destination_buffer=*/result_slice, - /*mem_size=*/ShapeUtil::ByteSizeOf(shape), - /*source_value=*/nullptr, - /*destination_value=*/nullptr)); - // Signal that start thunk not created with nullptr. - collectives_async_events_.try_emplace(instr, nullptr); - - } else { - const NcclCollectiveThunk::Buffer buffer = { - /*element_count=*/ShapeUtil::ElementsIn(shape), - /*source_buffer=*/source_slice, - /*destination_buffer=*/result_slice}; - auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(instr), NcclApi::Default(), - instr, replica_count, partition_count, buffer); - collectives_async_events_.try_emplace(instr, thunk->async_events()); - AddThunkToThunkSequence(std::move(thunk)); - } - return absl::OkStatus(); -} - -template -absl::Status IrEmitterUnnested::EmitNcclThunk(mlir::Operation* untyped_op) { - OpT op = mlir::cast(untyped_op); - const auto& hlo_config = ir_emitter_context_->hlo_module().config(); - int64_t replica_count = hlo_config.replica_count(); - int64_t partition_count = hlo_config.num_partitions(); - VLOG(2) << NcclThunkType::GetHloOpName() - << "; replica count: " << replica_count - << "; partition count: " << partition_count - << "; operand count: " << op.getOperands().size(); - - // A given collective op can be degenerate if across all groups formed - // by it are singleton. In such a case, we don't need to do any communication - // and we can just copy the input to the output. - bool is_degenerate = - GetNcclCollectiveConfigForMlir(op, op.getUseGlobalDeviceIds()) - .IsDegenerate(replica_count, partition_count); - absl::Status implementable_status = - NcclThunkType::CheckImplementable(op, replica_count, partition_count); - bool should_use_nccl_thunk = !is_degenerate && implementable_status.ok(); - - // Stash relevant information in NcclCollectiveThunk::Buffer even if we may - // not generate an NcclCollectiveThunk. - std::vector buffers; - buffers.reserve(op.getInputs().size()); - for (auto it : llvm::zip(op.getInputs(), op.getOutputs())) { - mlir::Value operand = std::get<0>(it); - mlir::Value result = std::get<1>(it); - const Shape shape = GetShape(operand); - TF_ASSIGN_OR_RETURN(auto source_slice, GetAllocationSlice(operand)); - TF_ASSIGN_OR_RETURN(auto dest_slice, GetAllocationSlice(result)); - buffers.push_back(NcclCollectiveThunk::Buffer{ - /*element_count=*/ShapeUtil::ElementsIn(shape), - /*source_buffer=*/source_slice, - /*destination_buffer=*/dest_slice, - /*source_memory_space=*/0, // always 0 for LMHLO - /*destination_memory_space=*/0, // always 0 for LMHLO - /*source_value=*/operand, - /*destination_value=*/result}); - } - - if (should_use_nccl_thunk) { - auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), NcclApi::Default(), op, - /*buffers=*/std::move(buffers)); - collectives_async_events_.try_emplace(untyped_op, thunk->async_events()); - AddThunkToThunkSequence(std::move(thunk)); - return absl::OkStatus(); - } - - if (!is_degenerate) { - return implementable_status; - } - - // Signal that start thunk not created with nullptr. - collectives_async_events_.try_emplace(untyped_op, nullptr); - - VLOG(1) << "Collective call is degenerate, not doing NCCL call"; - - // Degenerate collectives are simply identity function. Buffer - // assignment expects a copy, so that's what we do. - ThunkSequence thunks; - for (int64_t i = 0; i < buffers.size(); i++) { - const Shape shape = GetShape(op.getOperands()[i]); - thunks.push_back(std::make_unique( - buffers.size() == 1 ? Thunk::ThunkInfo::WithProfileAnnotation(op) - : Thunk::ThunkInfo(op), - /*source_buffer=*/buffers[i].source_buffer, - /*destination_buffer=*/buffers[i].destination_buffer, - /*mem_size=*/ShapeUtil::ByteSizeOf(shape), - /*source_value=*/buffers[i].source_value, - /*destination_value=*/buffers[i].destination_value)); - } - if (thunks.size() == 1) { - AddThunkToThunkSequence(std::move(thunks[0])); - } else { - AddThunkToThunkSequence(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), std::move(thunks))); + TF_RETURN_IF_ERROR(emit_kernel(xor_masks)); } return absl::OkStatus(); } -absl::Status IrEmitterUnnested::EmitNcclAsyncDone(Thunk::Kind kind, - mlir::Operation* op, - mlir::Value token) { - auto start_op = token.getDefiningOp(); - auto async_events = collectives_async_events_.extract(start_op); - TF_RET_CHECK(async_events) << "couldn't find async events for start op"; +template +absl::Status IrEmitterUnnested::EmitReplicaOrPartitionId( + const HloInstruction* instr) { + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice, + GetAllocationSliceForHlo(instr, {})); + auto thunk = std::make_unique( + Thunk::ThunkInfo::WithProfileAnnotation(instr), result_slice); + AddThunkToThunkSequence(std::move(thunk)); + return absl::OkStatus(); +} - // Can be null if no start thunk was created (e.g. if the start op is - // degenerate), in which case there's nothing to do here. - if (async_events.mapped()) { - AddThunkToThunkSequence(std::make_unique( - kind, Thunk::ThunkInfo::WithProfileAnnotation(op), - std::move(async_events.mapped()))); +Status IrEmitterUnnested::EmitCollectivePermute( + const HloCollectivePermuteInstruction* instr) { + TF_RET_CHECK(instr->operand_count() == 1); + auto* operand = instr->operand(0); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice source_slice, + GetAllocationSliceForHlo(operand)); + // First output is aliased. + TF_RET_CHECK( + instr->shape().IsTuple() && instr->shape().tuple_shapes_size() == 2 && + instr->shape().tuple_shapes(0) == instr->shape().tuple_shapes(1)); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice, + GetAllocationSliceForHlo(instr, {1})); + + const Shape shape = operand->shape(); + const auto& hlo_config = ir_emitter_context_->hlo_module().config(); + const int64_t replica_count = hlo_config.replica_count(); + const int64_t partition_count = hlo_config.num_partitions(); + + if (NcclCollectivePermuteStartThunk::IsDegenerate(instr, replica_count, + partition_count)) { + // For a degenerate collective permute, just generate a copy thunk. + AddThunkToThunkSequence(std::make_unique( + Thunk::ThunkInfo::WithProfileAnnotation(instr), + /*source_buffer=*/source_slice, + /*destination_buffer=*/result_slice, + /*mem_size=*/ShapeUtil::ByteSizeOf(shape))); + // Signal that start thunk not created with nullptr. + collectives_async_events_.try_emplace(instr, nullptr); + + } else { + const NcclCollectiveThunk::Buffer buffer = { + /*element_count=*/ShapeUtil::ElementsIn(shape), + /*source_buffer=*/source_slice, + /*destination_buffer=*/result_slice}; + auto thunk = std::make_unique( + Thunk::ThunkInfo::WithProfileAnnotation(instr), NcclApi::Default(), + instr, replica_count, partition_count, buffer); + collectives_async_events_.try_emplace(instr, thunk->async_events()); + AddThunkToThunkSequence(std::move(thunk)); } return absl::OkStatus(); } @@ -3611,9 +2512,7 @@ absl::Status IrEmitterUnnested::EmitNcclThunk( Thunk::ThunkInfo::WithProfileAnnotation(inst), /*source_buffer=*/buffers[i].source_buffer, /*destination_buffer=*/buffers[i].destination_buffer, - /*mem_size=*/ShapeUtil::ByteSizeOf(shape), - /*source_value=*/buffers[i].source_value, - /*destination_value=*/buffers[i].destination_value)); + /*mem_size=*/ShapeUtil::ByteSizeOf(shape))); } if (thunks.size() == 1) { AddThunkToThunkSequence(std::move(thunks[0])); @@ -3641,25 +2540,33 @@ absl::Status IrEmitterUnnested::EmitNcclAsyncDone(Thunk::Kind kind, return absl::OkStatus(); } -absl::StatusOr> IrEmitterUnnested::GetShapedSlices( - mlir::Operation::operand_range operands) { - std::vector shaped_slices; - shaped_slices.reserve(operands.size()); - for (mlir::Value opnd : operands) { - TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(opnd)); - shaped_slices.push_back(ShapedSlice{slice, GetShape(opnd)}); +absl::Status IrEmitterUnnested::EmitWaitForStreamsThunk( + const HloInstruction* inst, GpuBackendConfig& gpu_config, + bool is_async_done) { + std::vector wait_on_streams; + ExecutionStreamId source_stream_id = Thunk::GetMainComputeStreamId(); + // If it's for an async done, then we need to sychronize on the execution + // stream of the instruction from main compute stream + if (is_async_done) { + wait_on_streams.push_back( + ExecutionStreamId(gpu_config.operation_queue_id())); + } else if (gpu_config.wait_on_operation_queues().size() == 0) { + // If wait on queue is empty, we just synchronize on the main compute + // stream from the execution stream. + wait_on_streams.push_back(Thunk::GetMainComputeStreamId()); + source_stream_id = gpu_config.operation_queue_id(); + } else { + // Else, we synchronize on all specified + // streams from the execution stream. + for (int64_t stream_id : gpu_config.wait_on_operation_queues()) { + wait_on_streams.push_back(ExecutionStreamId(stream_id)); + } + source_stream_id = gpu_config.operation_queue_id(); } - return shaped_slices; -} - -absl::Status IrEmitterUnnested::EmitInfeed(mlir::Operation* op) { - mlir::Operation::operand_range operands = - mlir::cast(op).getOutputs(); - TF_ASSIGN_OR_RETURN(auto shaped_slices, GetShapedSlices(operands)); - auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), std::move(shaped_slices)); - AddThunkToThunkSequence(std::move(thunk)); + AddThunkToThunkSequence(std::make_unique( + Thunk::ThunkInfo::WithProfileAnnotation(inst), source_stream_id, + wait_on_streams)); return absl::OkStatus(); } @@ -3687,17 +2594,6 @@ absl::Status IrEmitterUnnested::EmitInfeed(const HloInfeedInstruction* instr) { return absl::OkStatus(); } -absl::Status IrEmitterUnnested::EmitOutfeed(mlir::Operation* op) { - mlir::Operation::operand_range operands = - mlir::cast(op).getInputs(); - TF_ASSIGN_OR_RETURN(auto shaped_slices, GetShapedSlices(operands)); - auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), std::move(shaped_slices)); - AddThunkToThunkSequence(std::move(thunk)); - - return absl::OkStatus(); -} - absl::Status IrEmitterUnnested::EmitOutfeed( const HloOutfeedInstruction* instr) { // HLO outfeed instruction has 2 operands, the source and a token, and a @@ -3724,39 +2620,6 @@ absl::Status IrEmitterUnnested::EmitOutfeed( return absl::OkStatus(); } -absl::StatusOr< - std::pair, std::vector>> -IrEmitterUnnested::BuildKernelThunkForNonFusionOp( - mlir::Operation* op, mlir::ValueRange needed_operands, - const LaunchDimensions& launch_dimensions) { - TF_RET_CHECK(!mlir::isa(op)) - << "Please use BuildKernelThunkForFusion!"; - - std::string suggested_kernel_name = GetIrNameFromLoc(op->getLoc()); - - TF_ASSIGN_OR_RETURN( - auto kernel_arguments, - KernelArguments::Create(ir_emitter_context_->allocations(), op, - needed_operands)); - - VLOG(3) << "Generating (without reuse check): " << suggested_kernel_name; - - llvm::Function* kernel; - std::vector inputs; - std::vector outputs; - TF_ASSIGN_OR_RETURN( - std::tie(kernel, inputs, outputs), - BuildKernelPrototype(*ir_emitter_context_, suggested_kernel_name, - kernel_arguments.args(), needed_operands.size(), - launch_dimensions, &b_)); - - AddThunkToThunkSequence(std::make_unique( - op, kernel->getName().str(), kernel_arguments.args(), launch_dimensions, - /*shmem_bytes=*/0)); - - return {{inputs, outputs}}; -} - absl::StatusOr /*inputs*/, std::vector /*outputs*/>> IrEmitterUnnested::BuildKernelThunkForNonFusionOp( @@ -3783,37 +2646,25 @@ IrEmitterUnnested::BuildKernelThunkForNonFusionOp( AddThunkToThunkSequence(std::make_unique( hlo, kernel->getName().str(), kernel_arguments.args(), launch_dimensions, + /*cluster_dim=*/std::nullopt, /*shmem_bytes=*/0)); return {{inputs, outputs}}; } -absl::StatusOr< - std::pair, std::vector>> -IrEmitterUnnested::BuildKernelThunkForNonFusionOp( - mlir::Operation* op, const LaunchDimensions& launch_dimensions) { - return BuildKernelThunkForNonFusionOp(op, op->getOperands(), - launch_dimensions); -} - absl::Status IrEmitterUnnested::BuildInitializerThunk( - mlir::Operation* op, const HloInstruction* instr, - const HloInstruction* init_value, mlir::Value init_value_mlir, - mlir::Value dest) { + const HloInstruction* instr, const HloInstruction* init_value) { // initial value must be a scalar memref. TF_RET_CHECK(init_value->shape().rank() == 0); - auto maybe_dest_slice = ir_emitter_context_->emit_ir_from_hlo() - ? GetAllocationSliceForHlo(instr, {}) - : GetAllocationSlice(dest); + auto maybe_dest_slice = GetAllocationSliceForHlo(instr, {}); if (!maybe_dest_slice.ok()) return maybe_dest_slice.status(); BufferAllocation::Slice dest_slice = *maybe_dest_slice; - TF_ASSIGN_OR_RETURN( - std::optional> constant_init_thunk, - BuildConstantInitializerThunk(*ir_emitter_context_, op, instr, init_value, - dest, dest_slice)); + TF_ASSIGN_OR_RETURN(std::optional> constant_init_thunk, + BuildConstantInitializerThunk(*ir_emitter_context_, instr, + init_value, dest_slice)); if (constant_init_thunk) { AddThunkToThunkSequence(*std::move(constant_init_thunk)); return absl::OkStatus(); @@ -3821,24 +2672,17 @@ absl::Status IrEmitterUnnested::BuildInitializerThunk( // Otherwise fall back to our slow initializer code. The thunk in this case // will just need the IR arrays for the initial value and the destination. - const Shape dest_shape = - ir_emitter_context_->emit_ir_from_hlo() ? instr->shape() : GetShape(dest); + const Shape& dest_shape = instr->shape(); LaunchDimensions launch_dimensions = CalculateLaunchDimensions( dest_shape, ir_emitter_context_->gpu_device_info()); TF_ASSIGN_OR_RETURN( auto ir_arrays, - ir_emitter_context_->emit_ir_from_hlo() - ? BuildKernelThunkForNonFusionOp(instr, {init_value}, - launch_dimensions) - : BuildKernelThunkForNonFusionOp(op, {init_value_mlir, dest}, - launch_dimensions)); + BuildKernelThunkForNonFusionOp(instr, {init_value}, launch_dimensions)); auto& [inputs, outputs] = ir_arrays; auto init_array = inputs[0]; - std::string name = ir_emitter_context_->emit_ir_from_hlo() - ? llvm_ir::IrName(instr, "init") - : GetIrNameFromLoc(op->getLoc()); + std::string name = llvm_ir::IrName(instr, "init"); TF_RETURN_IF_ERROR(ParallelLoopEmitter( [=](const llvm_ir::IrArray::Index& index) { return init_array.EmitReadArrayElement(index, &b_); @@ -3848,37 +2692,6 @@ absl::Status IrEmitterUnnested::BuildInitializerThunk( return absl::OkStatus(); } -absl::StatusOr> IrEmitterUnnested::BuildWhileThunk( - mlir::lmhlo::WhileOp while_op, const Thunk::ThunkInfo& thunk_info, - const absl::flat_hash_map& - hlo_for_lmhlo, - std::optional trip_count) { - // Generate thunk sequence for while 'condition'. - mlir::Region* condition = &while_op.getCond(); - auto ir_emitter_condition = IrEmitterUnnested::Create(ir_emitter_context_); - - TF_RETURN_IF_ERROR( - ir_emitter_condition->EmitLmhloRegion(condition, hlo_for_lmhlo)); - - // Generate thunk sequence for while 'body'. - mlir::Region* body = &while_op.getBody(); - auto ir_emitter_body = IrEmitterUnnested::Create(ir_emitter_context_); - - TF_RETURN_IF_ERROR(ir_emitter_body->EmitLmhloRegion(body, hlo_for_lmhlo)); - - // Extract the condition value from the last op (excluding the terminator op) - // in the condition region. - auto cond_result = GetHloOutputs(while_op); - TF_RET_CHECK(cond_result.size() == 1); - TF_ASSIGN_OR_RETURN(auto cond_result_slice, - GetAllocationSlice(cond_result[0])); - - return std::unique_ptr( - new WhileThunk(thunk_info, cond_result_slice, - ir_emitter_condition->ConsumeThunkSequence(), - ir_emitter_body->ConsumeThunkSequence(), trip_count)); -} - absl::StatusOr> IrEmitterUnnested::BuildWhileThunk( const HloInstruction* instr, const Thunk::ThunkInfo& thunk_info, std::optional trip_count) { @@ -4018,354 +2831,6 @@ absl::Status IrEmitterUnnested::EmitRecvDoneThunk( return absl::OkStatus(); } -absl::Status IrEmitterUnnested::EmitOp( - mlir::Operation* op, - const absl::flat_hash_map& - hlo_for_lmhlo) { - if (mlir::isa(op)) { - return absl::OkStatus(); - } - - if (mlir::isa(op)) { - const HloConstantInstruction* hlo_const_instr = - DynCast(hlo_for_lmhlo.at(op)); - TF_RET_CHECK(hlo_const_instr); - return EmitConstant(op, hlo_const_instr->literal()); - } - - bool is_gpu_runtime = ir_emitter_context_->debug_options() - .xla_gpu_enable_xla_runtime_executable(); - - if (auto call = mlir::dyn_cast(op)) { - if (call.getCallTargetName() == "PadToStatic") { - return EmitPadToStatic(op); - } - if (call.getCallTargetName() == "SliceToDynamic") { - return EmitSliceToDynamic(op); - } - const llvm::StringRef call_target = call.getCallTargetName(); -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - if (absl::string_view(call_target.data(), call_target.size()) == - kTriangularSolveCallTarget) { - return EmitTriangularSolveCustomCall(op); - } -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - - if (!is_gpu_runtime && call.getCallTargetName() == "__gpu$TopK") { - return EmitTopKCustomCall( - Cast(hlo_for_lmhlo.at(op))); - } - - return EmitCustomCallThunk( - op, Cast(hlo_for_lmhlo.at(op))); - } - - if (mlir::isa(op)) { - if (ir_emitter_context_->emit_ir_from_hlo()) { - const HloCustomCallInstruction* instr = - Cast(hlo_for_lmhlo.at(op)); - return EmitGemmThunk(instr); - } - return EmitGemmThunk(op); - } - -#if GOOGLE_CUDA || TF_HIPBLASLT - if (mlir::isa(op)) { - if (ir_emitter_context_->emit_ir_from_hlo()) { - const auto* instr = Cast(hlo_for_lmhlo.at(op)); - return EmitCublasLtMatmulThunk(instr); - } - return EmitCublasLtMatmulThunk(op); - } -#endif // GOOGLE_CUDA || TF_HIPBLASLT -#if GOOGLE_CUDA - if (mlir::isa(op)) { - if (ir_emitter_context_->emit_ir_from_hlo()) { - const auto* instr = Cast(hlo_for_lmhlo.at(op)); - return EmitCublasLtMatmulThunkF8(instr); - } - return EmitCublasLtMatmulThunkF8(op); - } - if (mlir::isa(op)) { - return EmitConvolutionReorderThunk(op); - } - if (mlir::isa(op)) { - return EmitNormThunk(op); - } - if (mlir::isa(op)) { - return EmitFusedMHAThunk(op); - } - if (mlir::isa(op)) { - return EmitFusedMHABackwardThunk(op); - } -#endif // GOOGLE_CUDA - - if (mlir::isa(op)) { - if (ir_emitter_context_->emit_ir_from_hlo()) { - return EmitConvolutionThunk( - Cast(hlo_for_lmhlo.at(op))); - } - return EmitConvolutionThunk(op); - } - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - if (mlir::isa(op)) { - return EmitCubDeviceRadixSort(op); - } - if (mlir::isa(op)) { - if (ir_emitter_context_->emit_ir_from_hlo()) { - return EmitCholeskyThunk(hlo_for_lmhlo.at(op)); - } else { - return EmitCholeskyThunk(op); - } - } -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - - if (mlir::isa(op)) { - if (ir_emitter_context_->emit_ir_from_hlo()) { - return EmitFftThunk(Cast(hlo_for_lmhlo.at(op))); - } - return EmitFftThunk(op); - } - - if (mlir::isa(op)) { - return Internal( - "TriangularSolve is implemented as a custom-call; we do not expect to " - "lower a true HLO TriangularSolve op."); - } - - if (mlir::isa(op)) { - if (ir_emitter_context_->emit_ir_from_hlo()) { - const HloFusionInstruction* instr = - Cast(hlo_for_lmhlo.at(op)); - const se::DeviceDescription& device_info = - ir_emitter_context_->gpu_device_info(); - auto fusion_analysis = HloFusionAnalysis::Create(instr, &device_info); - return EmitFusion(instr, fusion_analysis); - } - - return EmitFusion(op, hlo_for_lmhlo); - } - - if (mlir::isa(op)) { - if (ir_emitter_context_->emit_ir_from_hlo()) { - return EmitSelectAndScatter( - Cast(hlo_for_lmhlo.at(op))); - } - return EmitSelectAndScatter(op, hlo_for_lmhlo); - } - - if (mlir::isa(op)) { - if (ir_emitter_context_->emit_ir_from_hlo()) { - return EmitRngGetAndUpdateState( - Cast(hlo_for_lmhlo.at(op))); - } - return EmitRngGetAndUpdateState(op); - } - - if (mlir::isa(op)) { - return EmitSort(op, Cast(hlo_for_lmhlo.at(op))); - } - - if (mlir::isa(op)) { - if (ir_emitter_context_->emit_ir_from_hlo()) { - return EmitReplicaOrPartitionId(hlo_for_lmhlo.at(op)); - } - return EmitReplicaOrPartitionId( - op); - } - - if (mlir::isa(op)) { - if (ir_emitter_context_->emit_ir_from_hlo()) { - return EmitReplicaOrPartitionId(hlo_for_lmhlo.at(op)); - } - return EmitReplicaOrPartitionId(op); - } - - if (mlir::isa(op)) { - if (ir_emitter_context_->emit_ir_from_hlo()) { - return EmitCollectivePermute( - Cast(hlo_for_lmhlo.at(op))); - } - return EmitCollectivePermute(op); - } - - if (mlir::isa(op)) { - if (ir_emitter_context_->emit_ir_from_hlo()) { - return EmitNcclAsyncDone(Thunk::kNcclCollectivePermuteDone, - hlo_for_lmhlo.at(op)); - } - return EmitNcclAsyncDone( - Thunk::kNcclCollectivePermuteDone, op, - mlir::cast(op).getToken()); - } - - if (mlir::isa(op)) { - if (ir_emitter_context_->emit_ir_from_hlo()) { - auto* all_gather = Cast(hlo_for_lmhlo.at(op)); - return EmitNcclThunk( - Thunk::kNcclAllGatherStart, all_gather, all_gather, - all_gather->use_global_device_ids()); - } - return EmitNcclThunk(op); - } - - if (mlir::isa(op)) { - if (ir_emitter_context_->emit_ir_from_hlo()) { - return EmitNcclAsyncDone(Thunk::kNcclAllGatherDone, hlo_for_lmhlo.at(op)); - } - return EmitNcclAsyncDone( - Thunk::kNcclAllGatherDone, op, - mlir::cast(op).getToken()); - } - - if (mlir::isa(op)) { - if (ir_emitter_context_->emit_ir_from_hlo()) { - auto* all_reduce = Cast(hlo_for_lmhlo.at(op)); - return EmitNcclThunk( - Thunk::kNcclAllReduceStart, all_reduce, all_reduce, - all_reduce->use_global_device_ids()); - } - return EmitNcclThunk(op); - } - - if (mlir::isa(op)) { - if (ir_emitter_context_->emit_ir_from_hlo()) { - return EmitNcclAsyncDone(Thunk::kNcclAllReduceDone, hlo_for_lmhlo.at(op)); - } - return EmitNcclAsyncDone( - Thunk::kNcclAllReduceDone, op, - mlir::cast(op).getToken()); - } - - if (mlir::isa(op)) { - if (ir_emitter_context_->emit_ir_from_hlo()) { - auto* async_start = hlo_for_lmhlo.at(op); - auto* reduce_scatter = Cast( - async_start->async_wrapped_instruction()); - return EmitNcclThunk( - Thunk::kNcclReduceScatterStart, async_start, reduce_scatter, - reduce_scatter->use_global_device_ids()); - } - return EmitNcclThunk(op); - } - - if (mlir::isa(op)) { - if (ir_emitter_context_->emit_ir_from_hlo()) { - return EmitNcclAsyncDone(Thunk::kNcclReduceScatterDone, - hlo_for_lmhlo.at(op)); - } - return EmitNcclAsyncDone( - Thunk::kNcclReduceScatterDone, op, - mlir::cast(op).getToken()); - } - - if (mlir::isa(op)) { - return EmitNcclThunk(op); - } - - if (mlir::isa(op)) { - return EmitNcclAsyncDone( - Thunk::kNcclAllToAllDone, op, - mlir::cast(op).getToken()); - } - - if (mlir::isa(op)) { - if (ir_emitter_context_->emit_ir_from_hlo()) { - return EmitInfeed(Cast(hlo_for_lmhlo.at(op))); - } - return EmitInfeed(op); - } - - if (mlir::isa(op)) { - if (ir_emitter_context_->emit_ir_from_hlo()) { - return EmitOutfeed(Cast(hlo_for_lmhlo.at(op))); - } - return EmitOutfeed(op); - } - - if (mlir::isa(op)) { - return EmitConditional(op, hlo_for_lmhlo); - } - - if (mlir::isa(op)) { - // TODO(ezhulenev): While loops may contain instructions that do not support - // emitting from HLO, so we can't yet enable while thunk emission here. - static constexpr bool kWhileThunkNotSupported = true; - if (ir_emitter_context_->emit_ir_from_hlo() && !kWhileThunkNotSupported) { - return EmitWhile(hlo_for_lmhlo.at(op)); - } - return EmitWhile(op, hlo_for_lmhlo); - } - - // Remaining arith.constant ops are the gpu.launch_func dimensions as a result - // of inlining the fusion region after lowering. They can safely be skipped - // because constants have no side effects. - if (mlir::isa(op)) { - return absl::OkStatus(); - } - - if (mlir::isa(op)) { - return EmitCommandBufferThunk(hlo_for_lmhlo.at(op)); - } - - // In GPU runtime point-to-point communications implemented as runtime custom - // calls, and we do not need real thunks to construct them, so we can emit - // stubs that always fail. This is deprecated and will be removed in Q1 2024. - if (is_gpu_runtime && - mlir::isa(op)) { - return EmitUnreachable(op, - "Point-to-point communication operations are not " - "implemented as thunks"); - } - - if (mlir::isa(op)) { - return EmitSendThunk(Cast(hlo_for_lmhlo.at(op))); - } - - if (mlir::isa(op)) { - return EmitSendDoneThunk( - Cast(hlo_for_lmhlo.at(op))); - } - - if (mlir::isa(op)) { - return EmitRecvThunk(Cast(hlo_for_lmhlo.at(op))); - } - - if (mlir::isa(op)) { - return EmitRecvDoneThunk( - Cast(hlo_for_lmhlo.at(op))); - } - - return Internal("Unrecognized op: %s", llvm_ir::DumpToString(op)); -} - -absl::Status IrEmitterUnnested::EmitLmhloRegion( - mlir::Region* region, - const absl::flat_hash_map& - hlo_for_lmhlo) { - for (mlir::Operation& op : llvm::make_early_inc_range(region->front())) { - TF_RETURN_IF_ERROR(EmitOp(&op, hlo_for_lmhlo)); - } - return absl::OkStatus(); -} - absl::Status IrEmitterUnnested::EmitHloInstruction( const HloInstruction* instr) { // TODO(anlunx): Support other instruction opcodes. @@ -4395,9 +2860,23 @@ absl::Status IrEmitterUnnested::EmitHloInstruction( return EmitNcclAsyncDone(Thunk::kNcclReduceScatterDone, instr); case HloOpcode::kAllToAll: return EmitNcclAsyncDone(Thunk::kNcclAllToAllDone, instr); - default: + default: { + if (wrapped->has_backend_config()) { + TF_ASSIGN_OR_RETURN( + xla::gpu::GpuBackendConfig gpu_config, + wrapped->backend_config()); + if (gpu_config.operation_queue_id() != 0) { + // If there an async-done instruction that wraps an instruction + // that runs on a non-default stream, then we will + // just emit syncOnStreamThunk(). + return EmitWaitForStreamsThunk(instr, gpu_config, + /*is_async_done=*/true); + } + } + return Internal("Unsupported async done wrapped instruction: %s", HloOpcodeString(wrapped->opcode())); + } } } case HloOpcode::kAsyncStart: { @@ -4415,9 +2894,26 @@ absl::Status IrEmitterUnnested::EmitHloInstruction( return EmitNcclThunk( Thunk::kNcclAllToAll, instr, all_to_all, std::nullopt); } - default: + default: { + if (wrapped->has_backend_config()) { + TF_ASSIGN_OR_RETURN( + xla::gpu::GpuBackendConfig gpu_config, + wrapped->backend_config()); + if (gpu_config.operation_queue_id() != 0) { + // If there an async instruction that wraps an instruction + // that runs on a non-default stream, then we will + // emit syncOnStreamThunk(source=execution_stream, + // wait_on=main_compute_stream) + // then the thunk of wrapped instruction. + TF_RETURN_IF_ERROR( + EmitWaitForStreamsThunk(instr, gpu_config, + /*is_async_done=*/false)); + return EmitHloInstruction(wrapped); + } + } return Internal("Unsupported async start wrapped instruction: %s", HloOpcodeString(wrapped->opcode())); + } } } @@ -4448,6 +2944,18 @@ absl::Status IrEmitterUnnested::EmitHloInstruction( if (IsCublasLtMatmulF8(*instr)) { return EmitCublasLtMatmulThunkF8(custom_call); } + if (IsCudnnConvolutionReorder(*instr)) { + return EmitConvolutionReorderThunk(custom_call); + } + if (IsCustomCallToDnnNorm(*instr)) { + return EmitNormThunk(custom_call); + } + if (IsFwdCustomCallTofMHA(*instr)) { + return EmitFusedMHAThunk(custom_call); + } + if (IsBwdCustomCallTofMHA(*instr)) { + return EmitFusedMHABackwardThunk(custom_call); + } #endif // GOOGLE_CUDA if (IsCustomCallToTopK(*instr)) { return EmitTopKCustomCall(custom_call); @@ -4462,7 +2970,19 @@ absl::Status IrEmitterUnnested::EmitHloInstruction( if (IsTriangularSolve(*instr)) { return EmitTriangularSolveCustomCall(instr); } + if (IsCubDeviceRadixSort(*instr)) { + return EmitCubDeviceRadixSort(custom_call); + } #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + if (custom_call->custom_call_target() == "PadToStatic") { + return EmitPadToStatic(custom_call); + } + if (instr->custom_call_target() == "SliceToDynamic") { + return EmitSliceToDynamic(custom_call); + } + if (instr->custom_call_target() == "__gpu$xla.gpu.triton") { + return EmitTritonCustomCall(custom_call); + } return EmitCustomCallThunk(custom_call); } case HloOpcode::kFusion: { @@ -4504,7 +3024,8 @@ absl::Status IrEmitterUnnested::EmitHloInstruction( case HloOpcode::kWhile: return EmitWhile(instr); - // HLO module is already ordered, so kAfterAll is a noop. + // HLO module is already scheduled, so instructions for ordering are noops. + case HloOpcode::kAddDependency: case HloOpcode::kAfterAll: // We don't need to emit thunks for these operations because their semantics // are encoded by buffers. diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h index 18fccb3d292c26..1dd9afbcaae7a4 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h @@ -40,7 +40,7 @@ limitations under the License. #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emitter.h" #include "xla/service/gpu/nccl_collective_thunk.h" -#include "xla/service/gpu/runtime3/send_recv_thunk.h" +#include "xla/service/gpu/runtime/send_recv_thunk.h" #include "xla/service/gpu/thunk.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/llvm_util.h" @@ -48,6 +48,11 @@ limitations under the License. #include "xla/status.h" #include "xla/statusor.h" +#if TENSORFLOW_USE_ROCM +// for TF_HIPBLASLT +#include "rocm/rocm_config.h" +#endif + namespace xla { namespace gpu { @@ -106,22 +111,13 @@ class IrEmitterUnnested : public IrEmitter { return std::make_unique(std::move(thunk_sequence_)); } - // Emits code for the given LMHLO region. + // Emits code for the given HLO computation. // // Also populates related information to 'ir_emitter_context_' for // large-constant initializations. Large constants don't get initializers in // the generated code and so must be initialized by XLA. The value of these // constants will be stored in 'content'. Constants with initializers in the // generated code will have empty 'content'. - absl::Status EmitLmhloRegion( - mlir::Region* region, - const absl::flat_hash_map& - hlo_for_lmhlo); - - // Emits code for the given HLO computation. Right now it is only used to emit - // thunks for constructing command buffer. The plan is to replace - // EmitLmhloRegion by this function altogether, after we support emitting - // all instructions from HLO. absl::Status EmitHloComputation(const HloComputation* computation); static void GetDependentDialects(mlir::DialectRegistry& registry); @@ -129,79 +125,49 @@ class IrEmitterUnnested : public IrEmitter { private: explicit IrEmitterUnnested(IrEmitterContext* ir_emitter_context); - absl::Status EmitUnreachable(mlir::Operation* op, std::string error_message); - absl::Status EmitCommandBufferThunk(const HloInstruction* instr); // IrEmitterUnnested handles the following instructions differently from // IrEmitter. It also mixes in some special handling for custom kernels // via the ThunkEmitter. - absl::Status EmitConstant(mlir::Operation* op, const Literal& literal); absl::Status EmitConstant(const HloConstantInstruction* instr); - absl::Status EmitConditional( - mlir::Operation* op, - const absl::flat_hash_map& - hlo_for_lmhlo); absl::Status EmitConditional(const HloInstruction* instr); - absl::Status EmitConvolutionThunk(mlir::Operation* op); absl::Status EmitConvolutionThunk(const HloCustomCallInstruction* instr); - absl::Status EmitGemmThunk(mlir::Operation* op); absl::Status EmitGemmThunk(const HloCustomCallInstruction* instr); #if GOOGLE_CUDA || TF_HIPBLASLT - absl::Status EmitCublasLtMatmulThunk(mlir::Operation* op); absl::Status EmitCublasLtMatmulThunk(const HloCustomCallInstruction* instr); #endif // GOOGLE_CUDA || TF_HIPBLASLT #if GOOGLE_CUDA - absl::Status EmitCublasLtMatmulThunkF8(mlir::Operation* op); absl::Status EmitCublasLtMatmulThunkF8(const HloCustomCallInstruction* instr); - absl::Status EmitConvolutionReorderThunk(mlir::Operation* op); - absl::Status EmitNormThunk(mlir::Operation* op); - absl::Status EmitFusedMHAThunk(mlir::Operation* op); - absl::Status EmitFusedMHABackwardThunk(mlir::Operation* op); + absl::Status EmitConvolutionReorderThunk( + const HloCustomCallInstruction* instr); + absl::Status EmitNormThunk(const HloCustomCallInstruction* instr); + absl::Status EmitFusedMHAThunk(const HloCustomCallInstruction* instr); + absl::Status EmitFusedMHABackwardThunk(const HloCustomCallInstruction* instr); #endif // GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - absl::Status EmitCubDeviceRadixSort(mlir::Operation* op); - absl::Status EmitCholeskyThunk(mlir::Operation* op); + absl::Status EmitCubDeviceRadixSort(const HloCustomCallInstruction* instr); absl::Status EmitCholeskyThunk(const HloInstruction* instr); #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - absl::Status EmitCustomCallThunk(mlir::Operation* op, - const HloCustomCallInstruction* instr); absl::Status EmitCustomCallThunk(const HloCustomCallInstruction* instr); - absl::Status EmitFftThunk(mlir::Operation* op); absl::Status EmitFftThunk(const HloFftInstruction* instr); - absl::Status EmitFusion( - mlir::Operation* op, - const absl::flat_hash_map& - hlo_for_lmhlo); absl::Status EmitFusion(const HloFusionInstruction* instr, HloFusionAnalysis& fusion_analysis); - absl::Status EmitSelectAndScatter( - mlir::Operation* op, - const absl::flat_hash_map& - hlo_for_lmhlo); absl::Status EmitSelectAndScatter( const HloSelectAndScatterInstruction* instr); - absl::Status EmitWhile( - mlir::Operation* op, - const absl::flat_hash_map& - hlo_for_lmhlo); absl::Status EmitWhile(const HloInstruction* instr); - absl::Status EmitInfeed(mlir::Operation* op); absl::Status EmitInfeed(const HloInfeedInstruction* instr); - absl::Status EmitOutfeed(mlir::Operation* op); absl::Status EmitOutfeed(const HloOutfeedInstruction* instr); - absl::Status EmitRngGetAndUpdateState(mlir::Operation* op); absl::Status EmitRngGetAndUpdateState( const HloRngGetAndUpdateStateInstruction* instr); - absl::Status EmitSort(mlir::Operation* op, const HloSortInstruction* sort); absl::Status EmitSort(const HloSortInstruction* sort); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - absl::Status EmitTriangularSolveCustomCall(mlir::Operation* op); absl::Status EmitTriangularSolveCustomCall(const HloInstruction* instr); #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM absl::Status EmitTopKCustomCall(const HloCustomCallInstruction* instr); + absl::Status EmitTritonCustomCall(const HloCustomCallInstruction* instr); absl::Status EmitSendThunk(const HloSendInstruction* instr); absl::Status EmitSendDoneThunk(const HloSendDoneInstruction* instr); @@ -209,12 +175,6 @@ class IrEmitterUnnested : public IrEmitter { absl::Status EmitRecvThunk(const HloRecvInstruction* instr); absl::Status EmitRecvDoneThunk(const HloRecvDoneInstruction* instr); - template - absl::Status EmitNcclThunk(mlir::Operation* op); - - absl::Status EmitNcclAsyncDone(Thunk::Kind kind, mlir::Operation* op, - mlir::Value token); - template absl::Status EmitNcclThunk(Thunk::Kind kind, const HloInstruction* async_start, @@ -223,24 +183,17 @@ class IrEmitterUnnested : public IrEmitter { absl::Status EmitNcclAsyncDone(Thunk::Kind kind, const HloInstruction* instr); - template - absl::Status EmitReplicaOrPartitionId(mlir::Operation* op); + absl::Status EmitWaitForStreamsThunk(const HloInstruction* inst, + GpuBackendConfig& gpu_config, + bool is_async_done); template absl::Status EmitReplicaOrPartitionId(const HloInstruction* instr); - absl::Status EmitCollectivePermute(mlir::Operation* op); absl::Status EmitCollectivePermute( const HloCollectivePermuteInstruction* instr); - absl::Status EmitOp( - mlir::Operation* op, - const absl::flat_hash_map& - hlo_for_lmhlo); - absl::Status EmitHloInstruction(const HloInstruction* instr); - static Thunk::ThunkInfo GetThunkInfo(mlir::Operation* op); - absl::Status EmitTargetElementLoop( const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter) override; @@ -327,7 +280,7 @@ class IrEmitterUnnested : public IrEmitter { // return; // } // ``` - absl::Status EmitPadToStatic(mlir::Operation* op); + absl::Status EmitPadToStatic(const HloCustomCallInstruction* instr); // Input = {dynamic array(with dynamic dimension meta data at the end)} // Output = {static array, dynamic_dim0, dynamic_dim1} @@ -373,44 +326,13 @@ class IrEmitterUnnested : public IrEmitter { // return; // } // ``` - absl::Status EmitSliceToDynamic(mlir::Operation* op); - - absl::StatusOr GetAllocationSlice(mlir::Value v); - absl::StatusOr> GetAllocationSlices( - mlir::OperandRange operands); + absl::Status EmitSliceToDynamic(const HloCustomCallInstruction* instr); int64_t ByteSizeOf(const Shape& shape) const { return llvm_ir::ByteSizeOf( shape, ir_emitter_context_->llvm_module()->getDataLayout()); } - // Emits kernel thunk for a custom fusion implemented with hand written custom - // device kernels. - absl::StatusOr EmitCustomFusion( - const HloFusionInstruction* fusion, mlir::lmhlo::FusionOp fusion_op, - const CustomFusionConfig& config); - - // Builds a kernel thunk for a non-fusion operation, without reuse. - // - // All input and output tensors of `op` are passed to the kernel. - // - // TODO(tdanyluk): Consider also reusing non-fusion kernels. - absl::StatusOr /*inputs*/, - std::vector /*outputs*/>> - BuildKernelThunkForNonFusionOp(mlir::Operation* op, - const LaunchDimensions& launch_dimensions); - - // Builds a kernel thunk for a non-fusion operation, without reuse. - // - // Only the tensors specified in `needed_operands` are passed to the kernel. - // - // TODO(tdanyluk): Consider also reusing non-fusion kernels. - absl::StatusOr /*inputs*/, - std::vector /*outputs*/>> - BuildKernelThunkForNonFusionOp(mlir::Operation* op, - mlir::ValueRange needed_operands, - const LaunchDimensions& launch_dimensions); - absl::StatusOr /*inputs*/, std::vector /*outputs*/>> BuildKernelThunkForNonFusionOp( @@ -418,20 +340,11 @@ class IrEmitterUnnested : public IrEmitter { absl::Span needed_operands, const LaunchDimensions& launch_dimensions); - absl::Status BuildInitializerThunk(mlir::Operation* op, - const HloInstruction* instr, - const HloInstruction* init_value, - mlir::Value init_value_mlir, - mlir::Value dest); + absl::Status BuildInitializerThunk(const HloInstruction* instr, + const HloInstruction* init_value); // Returns a WhileThunk that invokes thunk sequences for 'condition' and // 'body' sub-computations of while instruction. - absl::StatusOr> BuildWhileThunk( - mlir::lmhlo::WhileOp while_op, const Thunk::ThunkInfo& thunk_info, - const absl::flat_hash_map& - hlo_for_lmhlo, - std::optional trip_count); - absl::StatusOr> BuildWhileThunk( const HloInstruction* instr, const Thunk::ThunkInfo& thunk_info, std::optional trip_count); @@ -460,11 +373,6 @@ class IrEmitterUnnested : public IrEmitter { // Container for async send/recv events shared by send/recv thunks. std::shared_ptr send_recv_events_; - // Begin optional members for XLA HLO -> LMHLO: - absl::flat_hash_map> - scratch_nested_computations_; - // End optional members for XLA HLO -> LMHLO. - // Returns the ShapedSlices for the given operands. absl::StatusOr> GetShapedSlices( mlir::Operation::operand_range operands); diff --git a/third_party/xla/xla/service/gpu/kernel_arguments.cc b/third_party/xla/xla/service/gpu/kernel_arguments.cc index 6c43ca7c9882f1..0ed64ec0ec4ed9 100644 --- a/third_party/xla/xla/service/gpu/kernel_arguments.cc +++ b/third_party/xla/xla/service/gpu/kernel_arguments.cc @@ -103,6 +103,7 @@ std::vector KernelArguments::ProcessArguments( absl::flat_hash_map> first_indices_for_slices; + int next_llvm_arg_index = 0; for (int i = 0; i < static_cast(kernel_arguments.size()); ++i) { KernelArgument& kernel_argument = kernel_arguments[i]; @@ -113,9 +114,11 @@ std::vector KernelArguments::ProcessArguments( kernel_argument.alignment_ = same.alignment_; kernel_argument.aliased_ = same.aliased_; kernel_argument.written_ = same.written_; + kernel_argument.llvm_arg_index_ = same.llvm_arg_index_; continue; } else { first_index = i; + kernel_argument.llvm_arg_index_ = next_llvm_arg_index++; } const BufferAllocation* alloc = kernel_argument.slice().allocation(); diff --git a/third_party/xla/xla/service/gpu/kernel_arguments.h b/third_party/xla/xla/service/gpu/kernel_arguments.h index cf73904787e740..1512b0d8c5c144 100644 --- a/third_party/xla/xla/service/gpu/kernel_arguments.h +++ b/third_party/xla/xla/service/gpu/kernel_arguments.h @@ -47,6 +47,7 @@ class KernelArgument { return first_with_same_slice_; } bool aliased() const { return aliased_; } + int llvm_arg_index() const { return llvm_arg_index_; } private: KernelArgument(mlir::Value value, Shape shape, BufferAllocation::Slice slice, @@ -59,6 +60,7 @@ class KernelArgument { bool aliased_ = true; int64_t alignment_ = 1; bool written_ = true; + int llvm_arg_index_; // Holds the index of the first argument which has the same slice as this, // if this is not the first such argument. std::optional first_with_same_slice_; diff --git a/third_party/xla/xla/service/gpu/kernel_reuse_cache.cc b/third_party/xla/xla/service/gpu/kernel_reuse_cache.cc index 86831a99c1adba..2d4a9d3356b672 100644 --- a/third_party/xla/xla/service/gpu/kernel_reuse_cache.cc +++ b/third_party/xla/xla/service/gpu/kernel_reuse_cache.cc @@ -94,7 +94,13 @@ KernelReuseCache::GetWithStatus( fused_computation, kernel_arguments, discriminator); VLOG(4) << "Fingerprint: "; XLA_VLOG_LINES(4, fingerprint); + return GetWithStatus(std::move(fingerprint), generator); +} +std::pair, bool> +KernelReuseCache::GetWithStatus( + std::string fingerprint, + const std::function()>& generator) { auto it = cache_.find(fingerprint); if (it != cache_.end()) { return {&it->second, /*was_cached=*/true}; @@ -102,7 +108,8 @@ KernelReuseCache::GetWithStatus( absl::StatusOr entry = generator(); if (entry.ok()) { - it = cache_.insert({fingerprint, std::move(entry.value())}).first; + it = + cache_.insert({std::move(fingerprint), std::move(entry.value())}).first; return {&it->second, /*was_cached=*/false}; } diff --git a/third_party/xla/xla/service/gpu/kernel_reuse_cache.h b/third_party/xla/xla/service/gpu/kernel_reuse_cache.h index 3af903e219f2e9..ea55d97dc43989 100644 --- a/third_party/xla/xla/service/gpu/kernel_reuse_cache.h +++ b/third_party/xla/xla/service/gpu/kernel_reuse_cache.h @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include @@ -27,6 +28,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/service/gpu/kernel_arguments.h" #include "xla/service/gpu/launch_dimensions.h" +#include "xla/stream_executor/launch_dim.h" namespace xla { namespace gpu { @@ -38,7 +40,8 @@ class KernelReuseCache { struct Entry { std::string kernel_name; LaunchDimensions launch_dimensions; - int64_t shmem_bytes; + std::optional cluster_dim; + int64_t shmem_bytes = 0; }; // Retrieves the cache entry for the given computation, or generates it using @@ -54,6 +57,17 @@ class KernelReuseCache { absl::string_view discriminator, const std::function()>& generator); + // Retrieves the cache entry for the given fingerprint, or generates it using + // the given generator function and stores it in the cache. + // + // The returned pointer is never nullptr. + // + // A non-OK status is returned if the entry is not found and the generator + // failed. + std::pair, bool /*was_cached*/> GetWithStatus( + std::string fingerprint, + const std::function()>& generator); + private: absl::flat_hash_map cache_; }; diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD index 12ab0a94af4b0c..cf62411f1451fd 100644 --- a/third_party/xla/xla/service/gpu/kernels/BUILD +++ b/third_party/xla/xla/service/gpu/kernels/BUILD @@ -1,11 +1,15 @@ load("//xla/tests:build_defs.bzl", "xla_test") load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library") +load("//xla:xla.bzl", "xla_cc_test") load("//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured") +load("//xla/service/gpu:build_defs.bzl", "gpu_kernel_library") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") +load("@local_tsl//tsl/platform:build_config_root.bzl", "tf_gpu_tests_tags") load("@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:private"], licenses = ["notice"], ) @@ -18,7 +22,7 @@ cc_library( name = "custom_kernel_fusion", srcs = ["custom_kernel_fusion.cc"], hdrs = ["custom_kernel_fusion.h"], - visibility = ["//visibility:public"], + visibility = [":friends"], deps = [ ":custom_kernel", "//xla:status", @@ -38,7 +42,7 @@ cc_library( name = "custom_kernel_fusion_pattern", srcs = ["custom_kernel_fusion_pattern.cc"], hdrs = ["custom_kernel_fusion_pattern.h"], - visibility = ["//visibility:public"], + visibility = [":friends"], deps = [ "//xla:statusor", "//xla/hlo/ir:hlo", @@ -57,7 +61,7 @@ cc_library( name = "custom_kernel", srcs = ["custom_kernel.cc"], hdrs = ["custom_kernel.h"], - visibility = ["//visibility:public"], + visibility = [":friends"], deps = [ "//xla/stream_executor", "@com_google_absl//absl/strings:str_format", @@ -68,7 +72,7 @@ cc_library( # a single dependency. cc_library( name = "custom_fusion_library", - visibility = ["//visibility:public"], + visibility = [":friends"], deps = [":cutlass_gemm_fusion"], ) @@ -76,7 +80,6 @@ cc_library( name = "cutlass_gemm_fusion", srcs = ["cutlass_gemm_fusion.cc"], hdrs = ["cutlass_gemm_fusion.h"], - visibility = ["//visibility:public"], deps = [ ":custom_kernel", ":custom_kernel_fusion", @@ -120,6 +123,76 @@ xla_test( ], ) +cc_library( + name = "topk_kernel", + srcs = if_gpu_is_configured(["topk_kernel.cc"]), + hdrs = if_gpu_is_configured(["topk_kernel.h"]), + compatible_with = [], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), + deps = [ + "//xla:shape_util", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/stream_executor", # build_cleaner: keep + "//xla/stream_executor:platform", + "//xla/stream_executor/gpu:gpu_stream_header", + "//xla/stream_executor/gpu:gpu_types_header", + "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ] + if_gpu_is_configured([ + ":topk_kernel_gpu", + ]), +) + +gpu_kernel_library( + name = "topk_kernel_gpu", + srcs = if_gpu_is_configured([ + "topk_kernel_bfloat16.cu.cc", + "topk_kernel_float.cu.cc", + "topk_kernel.cu.h", + ]), + hdrs = if_gpu_is_configured(["topk_kernel_common.h"]), + compatible_with = [], + deps = [ + "//xla:types", + "//xla/stream_executor/gpu:gpu_types_header", + "@local_tsl//tsl/lib/math:math_util", + ], +) + +xla_cc_test( + name = "topk_kernel_test", + srcs = if_gpu_is_configured(["topk_kernel_test.cc"]), + tags = tf_gpu_tests_tags(), + deps = [ + ":topk_kernel", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/stream_executor", # build_cleaner: keep + "//xla/stream_executor:platform_manager", + "//xla/stream_executor/gpu:gpu_init", + "//xla/stream_executor/gpu:gpu_stream_header", + "//xla/stream_executor/gpu:gpu_timer_header", + "//xla/stream_executor/gpu:gpu_types_header", + "//xla/stream_executor/host:host_platform", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/random", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_benchmark", + "@local_tsl//tsl/platform:test_main", + ], +) + cc_library( name = "topk_custom_kernel", srcs = ["topk_custom_kernel.cc"], @@ -127,23 +200,20 @@ cc_library( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), - visibility = ["//visibility:public"], + visibility = [":friends"], deps = [ ":custom_kernel", "//xla:statusor", + "//xla:types", "//xla:xla_data_proto_cc", "//xla/stream_executor", "@com_google_absl//absl/numeric:bits", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@eigen_archive//:eigen3", "@local_tsl//tsl/platform:statusor", ] + if_gpu_is_configured([ - "//xla/service/gpu/runtime:gpu_kernel_helper", - ]) + if_cuda_is_configured([ - "//xla/service/gpu/runtime:topk_kernel_cuda", - ]) + if_rocm_is_configured([ - "//xla/service/gpu/runtime:topk_kernel_rocm", + ":topk_kernel_gpu", ]), ) @@ -153,18 +223,19 @@ xla_test( backends = ["gpu"], deps = [ ":topk_custom_kernel", + "//xla:types", "//xla:xla_data_proto_cc", "//xla/service:platform_util", "//xla/stream_executor", - "//xla/stream_executor:multi_platform_manager", "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", "//xla/stream_executor/cuda:cuda_platform", "@com_google_absl//absl/random", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", - "@eigen_archive//:eigen3", "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", ], @@ -181,7 +252,6 @@ cc_library( ["cutlass_gemm_custom_kernel_stub.cc"], ), hdrs = ["cutlass_gemm_custom_kernel.h"], - visibility = ["//visibility:public"], deps = [ ":custom_kernel", ":cutlass_gemm", @@ -204,8 +274,8 @@ xla_test( ":cutlass_gemm_custom_kernel", "//xla:xla_data_proto_cc", "//xla/stream_executor", - "//xla/stream_executor:multi_platform_manager", "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", "//xla/stream_executor/cuda:cuda_platform", "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:path", @@ -224,10 +294,11 @@ cc_binary( "//xla:xla_data_proto_cc", "//xla/service:gpu_plugin", "//xla/stream_executor", - "//xla/stream_executor:multi_platform_manager", "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", "//xla/stream_executor/cuda:cuda_platform", "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", "@local_tsl//tsl/platform:test_main", @@ -242,7 +313,6 @@ cc_library( name = "cutlass_gemm", srcs = ["cutlass_gemm.cc"], hdrs = ["cutlass_gemm.h"], - visibility = ["//visibility:public"], deps = ["@local_tsl//tsl/platform:logging"], ) @@ -250,7 +320,6 @@ cuda_library( name = "cutlass_gemm_adaptor", hdrs = if_cuda_is_configured(["cutlass_gemm_adaptor.cu.h"]), copts = ["-Wno-unknown-attributes"], # __grid_constant__ is not supported by clang - visibility = ["//visibility:public"], deps = if_cuda_is_configured([ ":cutlass_gemm", "@cutlass_archive//:cutlass", @@ -261,7 +330,6 @@ cuda_library( name = "cutlass_gemm_epilogue", # TODO(ezhulenev): Update to regular hdrs after fixing CUTLASS headers. textual_hdrs = if_cuda_is_configured(["cutlass_gemm_epilogue.cu.h"]), - visibility = ["//visibility:public"], deps = if_cuda_is_configured(["@cutlass_archive//:cutlass"]), ) @@ -274,7 +342,6 @@ cuda_library( cc_library( name = "cutlass_gemm_kernels", - visibility = ["//visibility:public"], deps = [ ":cutlass_gemm_kernel_bf16xbf16_to_bf16", ":cutlass_gemm_kernel_bf16xbf16_to_bf16_sm80", @@ -293,7 +360,6 @@ cuda_library( name = "cutlass_gemm_kernel_bf16xbf16_to_bf16", srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xbf16_to_bf16.cu.cc"]), copts = ["-Wno-unknown-attributes -mllvm -unroll-threshold=100000"], - visibility = ["//visibility:public"], deps = if_cuda_is_configured([ ":cutlass_gemm_adaptor", "@cutlass_archive//:cutlass", @@ -305,7 +371,6 @@ cuda_library( name = "cutlass_gemm_kernel_bf16xbf16_to_bf16_sm80", srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xbf16_to_bf16_sm80.cu.cc"]), copts = ["-Wno-unknown-attributes -mllvm -unroll-threshold=100000"], - visibility = ["//visibility:public"], deps = if_cuda_is_configured([ ":cutlass_gemm_adaptor", "@cutlass_archive//:cutlass", @@ -317,7 +382,6 @@ cuda_library( name = "cutlass_gemm_kernel_bf16xbf16_to_bf16_sm90", srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xbf16_to_bf16_sm90.cu.cc"]), copts = ["-Wno-ctad-maybe-unsupported -Wno-unknown-attributes -mllvm -unroll-threshold=100000"], - visibility = ["//visibility:public"], deps = if_cuda_is_configured([ ":cutlass_gemm_adaptor", ":cutlass_gemm_epilogue", @@ -330,7 +394,6 @@ cuda_library( name = "cutlass_gemm_kernel_f32xf32_to_f32", srcs = if_cuda_is_configured(["cutlass_gemm_kernel_f32xf32_to_f32.cu.cc"]), copts = ["-Wno-unknown-attributes"], - visibility = ["//visibility:public"], deps = if_cuda_is_configured([ ":cutlass_gemm_adaptor", "@cutlass_archive//:cutlass", diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc index 4b2c8c5af5b624..77c37ae30ed6b3 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc @@ -20,12 +20,13 @@ limitations under the License. #include "xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/xla_data.pb.h" #include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" #include "tsl/platform/test_benchmark.h" @@ -39,16 +40,14 @@ static uint32_t BitPattern(float value) { static void BM_RowMajorGemm(benchmark::State& state) { se::Platform* platform = - se::MultiPlatformManager::PlatformWithName("CUDA").value(); + se::PlatformManager::PlatformWithName("CUDA").value(); se::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); const se::DeviceDescription& device = executor->GetDeviceDescription(); se::Stream stream(executor); - stream.Init(); + TF_CHECK_OK(stream.Initialize()); ASSERT_TRUE(stream.ok()); - se::Kernel gemm(executor); - // GEMM: 8192x4096 * 4096x16384 -> 8192x16384 int32_t m = 8192; int32_t n = 16384; @@ -57,16 +56,18 @@ static void BM_RowMajorGemm(benchmark::State& state) { auto custom_kernel = GetCutlassGemmKernel("cutlass_gemm", PrimitiveType::BF16, m, n, k, /*indices=*/{0, 1, 2}, /*slices=*/{}, device); - TF_CHECK_OK(executor->GetKernel(custom_kernel->kernel_spec(), &gemm)); + + TF_ASSERT_OK_AND_ASSIGN( + auto gemm, se::Kernel::Create(executor, custom_kernel->kernel_spec())); // Prepare arguments: a=1.1, b=1.2, c=0.0 se::DeviceMemory a = executor->AllocateArray(m * k, 0); se::DeviceMemory b = executor->AllocateArray(k * n, 0); se::DeviceMemory c = executor->AllocateArray(m * n, 0); - stream.ThenMemset32(&a, BitPattern(1.1f), a.size()); - stream.ThenMemset32(&b, BitPattern(1.2f), b.size()); - stream.ThenMemZero(&c, c.size()); + TF_CHECK_OK(stream.Memset32(&a, BitPattern(1.1f), a.size())); + TF_CHECK_OK(stream.Memset32(&b, BitPattern(1.2f), b.size())); + TF_CHECK_OK(stream.MemZero(&c, c.size())); se::KernelArgsDeviceMemoryArray args( std::vector({a, b, c}), @@ -74,7 +75,7 @@ static void BM_RowMajorGemm(benchmark::State& state) { for (auto s : state) { TF_CHECK_OK(executor->Launch(&stream, custom_kernel->thread_dims(), - custom_kernel->block_dims(), gemm, args)); + custom_kernel->block_dims(), *gemm, args)); TF_CHECK_OK(stream.BlockHostUntilDone()); } } diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc index b258192cf3ada3..f4e67adb7f64a7 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include #include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/xla_data.pb.h" @@ -35,20 +35,20 @@ namespace xla::gpu::kernel::gemm_universal { TEST(CutlassGemmKernelTest, SimpleGemm) { se::Platform* platform = - se::MultiPlatformManager::PlatformWithName("CUDA").value(); + se::PlatformManager::PlatformWithName("CUDA").value(); se::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); se::Stream stream(executor); - stream.Init(); + TF_ASSERT_OK(stream.Initialize()); ASSERT_TRUE(stream.ok()); - se::Kernel gemm(executor); - // Load [4, 4] x [4, 4] gemm kernel written in CUDA C++ with CUTLASS. auto custom_kernel = GetCutlassGemmKernel( "cutlass_gemm", PrimitiveType::F32, 4, 4, 4, /*indices=*/{0, 1, 2}, /*slices=*/{}, executor->GetDeviceDescription()); - TF_ASSERT_OK(executor->GetKernel(custom_kernel->kernel_spec(), &gemm)); + + TF_ASSERT_OK_AND_ASSIGN( + auto gemm, se::Kernel::Create(executor, custom_kernel->kernel_spec())); int64_t length = 4 * 4; int64_t byte_length = sizeof(float) * length; @@ -62,20 +62,20 @@ TEST(CutlassGemmKernelTest, SimpleGemm) { uint32_t pattern; std::memcpy(&pattern, &value, sizeof(pattern)); - stream.ThenMemset32(&a, pattern, byte_length); - stream.ThenMemset32(&b, pattern, byte_length); - stream.ThenMemZero(&c, byte_length); + TF_ASSERT_OK(stream.Memset32(&a, pattern, byte_length)); + TF_ASSERT_OK(stream.Memset32(&b, pattern, byte_length)); + TF_ASSERT_OK(stream.MemZero(&c, byte_length)); // Launch gemm kernel with device memory arguments. se::KernelArgsDeviceMemoryArray arr( std::vector({a, b, c}), custom_kernel->shared_memory_bytes()); TF_ASSERT_OK(executor->Launch(&stream, custom_kernel->thread_dims(), - custom_kernel->block_dims(), gemm, arr)); + custom_kernel->block_dims(), *gemm, arr)); // Copy `c` data back to host. std::vector dst(length, -1.0f); - stream.ThenMemcpy(dst.data(), c, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), c, byte_length)); std::vector expected(length, 16.0); ASSERT_EQ(dst, expected); @@ -87,20 +87,20 @@ TEST(CutlassGemmKernelTest, LoadFromSharedLibrary) { "cutlass_gemm_kernel_f32xf32_to_f32.so"); se::Platform* platform = - se::MultiPlatformManager::PlatformWithName("CUDA").value(); + se::PlatformManager::PlatformWithName("CUDA").value(); se::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); se::Stream stream(executor); - stream.Init(); + TF_ASSERT_OK(stream.Initialize()); ASSERT_TRUE(stream.ok()); - se::Kernel gemm(executor); - // Load [4, 4] x [4, 4] gemm kernel written in CUDA C++ with CUTLASS. auto custom_kernel = LoadCutlassGemmKernel( "cutlass_gemm", kernel_lib_path, PrimitiveType::F32, 4, 4, 4, /*indices=*/{0, 1, 2}, /*slices=*/{}, executor->GetDeviceDescription()); - TF_ASSERT_OK(executor->GetKernel(custom_kernel->kernel_spec(), &gemm)); + + TF_ASSERT_OK_AND_ASSIGN( + auto gemm, se::Kernel::Create(executor, custom_kernel->kernel_spec())); int64_t length = 4 * 4; int64_t byte_length = sizeof(float) * length; @@ -113,20 +113,20 @@ TEST(CutlassGemmKernelTest, LoadFromSharedLibrary) { uint32_t pattern; std::memcpy(&pattern, &value, sizeof(pattern)); - stream.ThenMemset32(&a, pattern, byte_length); - stream.ThenMemset32(&b, pattern, byte_length); - stream.ThenMemZero(&c, byte_length); + TF_ASSERT_OK(stream.Memset32(&a, pattern, byte_length)); + TF_ASSERT_OK(stream.Memset32(&b, pattern, byte_length)); + TF_ASSERT_OK(stream.MemZero(&c, byte_length)); // Launch gemm kernel with device memory arguments. se::KernelArgsDeviceMemoryArray arr( std::vector({a, b, c}), custom_kernel->shared_memory_bytes()); TF_ASSERT_OK(executor->Launch(&stream, custom_kernel->thread_dims(), - custom_kernel->block_dims(), gemm, arr)); + custom_kernel->block_dims(), *gemm, arr)); // Copy `c` data back to host. std::vector dst(length, -1.0f); - stream.ThenMemcpy(dst.data(), c, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), c, byte_length)); std::vector expected(length, 16.0); ASSERT_EQ(dst, expected); diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc index a2081e85e68455..70fa91a8d14154 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc @@ -82,7 +82,13 @@ struct GemmWithDynamicSlice { explicit GemmWithDynamicSlice(HloDynamicUpdateSliceInstruction* update_slice) : update_slice(update_slice) {} - std::vector Instrs() { return {dot, bitcast, update_slice}; } + std::vector Instrs() { + // Bitcast could be optional + if (bitcast == nullptr) { + return {dot, update_slice}; + } + return {dot, bitcast, update_slice}; + } HloInstruction* dot = nullptr; HloInstruction* bitcast = nullptr; // result bitcast @@ -152,14 +158,20 @@ static absl::StatusOr MatchGemmWithUpcast( return absl::InternalError("unsupported gemm with upcasing"); } +template +auto OptionalBitcast(HloInstruction** optional_bitcast, Pattern pattern) { + return m::AnyOf(m::Bitcast(optional_bitcast, pattern), + std::move(pattern)); +} + // Returns matched GEMM with result used to update a slice. static absl::StatusOr MatchGemmWithDynamicUpdateSlice( HloDynamicUpdateSliceInstruction* update_slice) { GemmWithDynamicSlice match(update_slice); - if (!Match( - const_cast(update_slice->operand(1)), - m::Bitcast(&match.bitcast, m::Dot(&match.dot, m::Op(), m::Op())))) { + if (!Match(const_cast(update_slice->operand(1)), + OptionalBitcast(&match.bitcast, + m::Dot(&match.dot, m::Op(), m::Op())))) { return absl::InternalError("failed to match update slice instr"); } @@ -204,9 +216,12 @@ CutlassGemmWithDynamicUpdateSlicePattern::TryMatch( match.AddReplacement(matched->dot, [=](HloFusionInstruction* fusion) { HloComputation* parent = fusion->parent(); auto* dus = Cast(matched->update_slice); + bool has_bitcast = matched->bitcast != nullptr; + const Shape dus_shape = + has_bitcast ? matched->bitcast->shape() : matched->dot->shape(); auto* slice = parent->AddInstruction(HloInstruction::CreateDynamicSlice( - matched->bitcast->shape(), fusion, dus->index_operands(), - matched->bitcast->shape().dimensions())); + dus_shape, fusion, dus->index_operands(), dus_shape.dimensions())); + return parent->AddInstruction( HloInstruction::CreateBitcast(matched->dot->shape(), slice)); }); @@ -337,7 +352,6 @@ class CutlassGemmWithDynamicUpdateSliceFusion : public CustomKernelFusion { // Mapping to a buffer that holds output slice offset. auto* offset = Cast(matched.update_slice->operand(2)); - kernel::gemm_universal::DynamicSliceIndices slices; slices.out = offset->parameter_number(); diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc index b7495aeb4ca826..0ffe3fa8fe4f3c 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc @@ -225,6 +225,52 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSliceMultipleUses) { RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); } +TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSliceWithoutBitcast) { + const char* hlo = R"( + HloModule test + + ENTRY %main (p0: f32[4,2], p1: f32[2,2], i: s32[]) -> f32[4,2] { + %p0 = f32[4,2]{1,0} parameter(0) + %p1 = f32[2,2]{1,0} parameter(1) + %i = s32[] parameter(2) + + %dot = f32[2,2]{1,0} dot(%p1, %p1), + lhs_contracting_dims={1}, + rhs_contracting_dims={0} + + ROOT %r = f32[4,2]{1,0} dynamic-update-slice(%p0, %dot, %i, %i) + } + )"; + + const char* expected = R"( + ; CHECK: %cutlass_gemm_with_dynamic_update_slice {{.*}} { + ; CHECK-DAG: [[P1:%[^ ]+]] = f32[4,2]{1,0} parameter + ; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter + ; CHECK-DAG: [[DOT:%[^ ]+]] = f32[2,2]{1,0} dot([[P0]], [[P0]]) + ; CHECK-DAG: [[P2:%[^ ]+]] = s32[] parameter + ; CHECK: ROOT [[DUS:%[^ ]+]] = f32[4,2]{1,0} dynamic-update-slice([[P1]], [[DOT]], [[P2]], [[P2]]) + ; CHECK: } + + ; CHECK: ENTRY %main {{.*}} { + ; CHECK: ROOT [[FUSION:%[^ ]+]] = f32[4,2]{1,0} fusion + ; CHECK: kind=kCustom, calls=%cutlass_gemm_with_dynamic_update_slice, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{ + ; CHECK: "name":"cutlass_gemm_with_dynamic_update_slice" + ; CHECK: } + ; CHECK: } + ; CHECK: } + )"; + + CustomKernelFusionPatternRegistry patterns; + patterns.Emplace(); + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + CustomKernelFusionRewriter pass(&device, &patterns); + RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); +} + //===----------------------------------------------------------------------===// // Run And Compare Tests //===----------------------------------------------------------------------===// @@ -373,4 +419,72 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSliceKernel) { /*run_hlo_passes=*/false)); } +TEST_F(CutlassFusionTest, + RowMajorGemmWithDynamicUpdateSliceKernelWithoutBitcast) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_text_cublas = R"( + HloModule cublas + + ENTRY e { + p0 = bf16[16,8]{1,0} parameter(0) + p1 = bf16[8,8]{1,0} parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + + gemm.tuple = (bf16[8,8]{1,0}, s8[0]{0}) custom-call(p1, p1), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} + gemm = bf16[8,8]{1,0} get-tuple-element(gemm.tuple), index=0 + + ROOT r = bf16[16,8]{1,0} dynamic-update-slice(p0, gemm, p2, p3) + } + )"; + + const char* hlo_text_custom_fusion = R"( + HloModule cutlass + + cutlass_gemm { + p0.1 = bf16[8,8]{1,0} parameter(0) + p1.1 = bf16[16,8]{1,0} parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + dot.1 = bf16[8,8]{1,0} dot(p0.1, p0.1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + r.1 = bf16[16,8]{1,0} dynamic-update-slice(p1.1, dot.1, p2, p3) + workspace = u8[1024]{0} custom-call(), + custom_call_target="__custom_kernel_fusion$workspace", + api_version=API_VERSION_TYPED_FFI + ROOT tuple = (bf16[16,8]{1,0}, u8[1024]{0}) tuple(r.1, workspace) + } + + ENTRY e { + p0 = bf16[16,8]{1,0} parameter(0) + p1 = bf16[8,8]{1,0} parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + r.0 = (bf16[16,8]{1,0}, u8[1024]{0}) fusion(p1, p0, p2, p3), kind=kCustom, + calls=%cutlass_gemm, + backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"cutlass_gemm_with_dynamic_update_slice"}}} + ROOT %get-tuple-element = bf16[16,8]{1,0} get-tuple-element(r.0), index=0 + })"; + + Array2D p0_arr(16, 8); // bf16[16,8] + Array2D p1_arr(8, 8); // bf16[8,8] + p1_arr.Each([](int64_t i, int64_t j, bfloat16* out) { + *out = bfloat16{1.0f * i * j}; + }); + + Array p2_arr({}, 0); + Array p3_arr({}, 1); + + auto p0 = LiteralUtil::CreateFromArray(p0_arr); + auto p1 = LiteralUtil::CreateFromArray(p1_arr); + auto p2 = LiteralUtil::CreateFromArray(p2_arr); + auto p3 = LiteralUtil::CreateFromArray(p3_arr); + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_cublas, hlo_text_custom_fusion, + {&p0, &p1, &p2, &p3}, error_spec, + /*run_hlo_passes=*/false)); +} + } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/kernels/topk_custom_kernel.cc b/third_party/xla/xla/service/gpu/kernels/topk_custom_kernel.cc index 1fc8d420c512bf..31364a9a0078a8 100644 --- a/third_party/xla/xla/service/gpu/kernels/topk_custom_kernel.cc +++ b/third_party/xla/xla/service/gpu/kernels/topk_custom_kernel.cc @@ -22,23 +22,23 @@ limitations under the License. #include #include -#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM) #include "absl/numeric/bits.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "Eigen/Core" // from @eigen_archive -#include "xla/service/gpu/runtime/gpu_kernel_helper.h" -#include "xla/service/gpu/runtime/topk_kernel_common.h" -#include "xla/statusor.h" +#include "xla/service/gpu/kernels/custom_kernel.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -#include "xla/service/gpu/kernels/custom_kernel.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/types.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" +#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM) +#include "xla/service/gpu/kernels/topk_kernel_common.h" +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + namespace xla::gpu::kernel::topk { #if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM) @@ -97,9 +97,8 @@ template absl::StatusOr GetTypedTopK(std::string name, size_t num_elements, size_t k, size_t batch_size) { constexpr size_t kMaxKVSize = sizeof(uint64_t); - constexpr size_t kWavefrontSize = WAVEFRONT_SIZE; // Allocate shmem assuming we have a full reduction. - int shmem_size = absl::bit_ceil(k) * kMaxKVSize * kWavefrontSize; + int shmem_size = absl::bit_ceil(k) * kMaxKVSize * GetTopKWaveFrontSize(); int num_threads = EstimateOptimalNumThreads(num_elements, k, batch_size); if (num_threads == 0) { return absl::FailedPreconditionError( @@ -128,8 +127,8 @@ absl::StatusOr GetTopKKernel(std::string name, case PrimitiveType::F32: return GetTypedTopK(std::move(name), num_elements, k, batch_size); case PrimitiveType::BF16: - return GetTypedTopK(std::move(name), num_elements, k, - batch_size); + return GetTypedTopK(std::move(name), num_elements, k, + batch_size); default: return absl::InvalidArgumentError( absl::StrCat("Unsupported GpuTopK data type: ", dtype)); diff --git a/third_party/xla/xla/service/gpu/kernels/topk_custom_kernel_test.cc b/third_party/xla/xla/service/gpu/kernels/topk_custom_kernel_test.cc index 4c93885429bc9b..8236632dee50bf 100644 --- a/third_party/xla/xla/service/gpu/kernels/topk_custom_kernel_test.cc +++ b/third_party/xla/xla/service/gpu/kernels/topk_custom_kernel_test.cc @@ -26,15 +26,16 @@ limitations under the License. #include "absl/random/random.h" #include "absl/strings/ascii.h" #include "absl/strings/substitute.h" -#include "Eigen/Core" // from @eigen_archive #include "xla/service/platform_util.h" #include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/types.h" #include "xla/xla_data.pb.h" #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla::gpu::kernel::topk { @@ -67,7 +68,7 @@ std::vector RandomVecNegative(int num_elements) { PrimitiveType Get(float) { return PrimitiveType::F32; } -PrimitiveType Get(Eigen::bfloat16) { return PrimitiveType::BF16; } +PrimitiveType Get(bfloat16) { return PrimitiveType::BF16; } // Params: // - n_kb: number of elements in kilobytes. @@ -84,12 +85,11 @@ TEST_P(TopKKernelTest, TopKFloat) { auto name = absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value()); - se::Platform* platform = - se::MultiPlatformManager::PlatformWithName(name).value(); + se::Platform* platform = se::PlatformManager::PlatformWithName(name).value(); se::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); se::Stream stream(executor); - stream.Init(); + TF_ASSERT_OK(stream.Initialize()); ASSERT_TRUE(stream.ok()); const auto [n_kb, k, batch_size, offset] = GetParam(); @@ -103,14 +103,17 @@ TEST_P(TopKKernelTest, TopKFloat) { executor->AllocateArray(k * batch_size, 0); auto source = RandomVec(n * batch_size); - stream.ThenMemcpy(&input_buffer, source.data(), n * batch_size * sizeof(T)); - stream.ThenMemZero(&output_values, k * batch_size * sizeof(T)); - stream.ThenMemZero(&output_indices, k * batch_size * sizeof(uint32_t)); + TF_ASSERT_OK( + stream.Memcpy(&input_buffer, source.data(), n * batch_size * sizeof(T))); + TF_ASSERT_OK(stream.MemZero(&output_values, k * batch_size * sizeof(T))); + TF_ASSERT_OK( + stream.MemZero(&output_indices, k * batch_size * sizeof(uint32_t))); - se::Kernel kernel(executor); auto custom_kernel = GetTopKKernel("topk", PrimitiveType::F32, n, k, batch_size); - TF_ASSERT_OK(executor->GetKernel(custom_kernel->kernel_spec(), &kernel)); + + TF_ASSERT_OK_AND_ASSIGN( + auto kernel, se::Kernel::Create(executor, custom_kernel->kernel_spec())); // Launch topk kernel with device memory arguments. se::KernelArgsDeviceMemoryArray arr( @@ -118,13 +121,13 @@ TEST_P(TopKKernelTest, TopKFloat) { {input_buffer, output_values, output_indices}), custom_kernel->shared_memory_bytes()); TF_ASSERT_OK(executor->Launch(&stream, custom_kernel->thread_dims(), - custom_kernel->block_dims(), kernel, arr)); + custom_kernel->block_dims(), *kernel, arr)); std::vector got(k); ASSERT_TRUE(stream.BlockHostUntilDone().ok()); for (int i = 0; i < batch_size; i++) { - stream.ThenMemcpy(got.data(), output_values.GetSlice(k * i, k), - k * sizeof(T)); + TF_ASSERT_OK(stream.Memcpy(got.data(), output_values.GetSlice(k * i, k), + k * sizeof(T))); std::vector slice(source.data() + n * i, source.data() + n * (i + 1)); std::sort(slice.begin(), slice.end(), std::greater()); slice.resize(k); @@ -138,12 +141,11 @@ TEST_P(TopKKernelTest, TopKPackedNegative) { auto name = absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value()); - se::Platform* platform = - se::MultiPlatformManager::PlatformWithName(name).value(); + se::Platform* platform = se::PlatformManager::PlatformWithName(name).value(); se::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); se::Stream stream(executor); - stream.Init(); + TF_ASSERT_OK(stream.Initialize()); ASSERT_TRUE(stream.ok()); const auto [n_kb, k, batch_size, offset] = GetParam(); @@ -157,14 +159,17 @@ TEST_P(TopKKernelTest, TopKPackedNegative) { executor->AllocateArray(k * batch_size, 0); auto source = RandomVecNegative(n * batch_size); - stream.ThenMemcpy(&input_buffer, source.data(), n * batch_size * sizeof(T)); - stream.ThenMemZero(&output_values, k * batch_size * sizeof(T)); - stream.ThenMemZero(&output_indices, k * batch_size * sizeof(uint32_t)); + TF_ASSERT_OK( + stream.Memcpy(&input_buffer, source.data(), n * batch_size * sizeof(T))); + TF_ASSERT_OK(stream.MemZero(&output_values, k * batch_size * sizeof(T))); + TF_ASSERT_OK( + stream.MemZero(&output_indices, k * batch_size * sizeof(uint32_t))); - se::Kernel kernel(executor); auto custom_kernel = GetTopKKernel("topk", PrimitiveType::F32, n, k, batch_size); - TF_ASSERT_OK(executor->GetKernel(custom_kernel->kernel_spec(), &kernel)); + + TF_ASSERT_OK_AND_ASSIGN( + auto kernel, se::Kernel::Create(executor, custom_kernel->kernel_spec())); // Launch topk kernel with device memory arguments. se::KernelArgsDeviceMemoryArray arr( @@ -172,13 +177,13 @@ TEST_P(TopKKernelTest, TopKPackedNegative) { {input_buffer, output_values, output_indices}), custom_kernel->shared_memory_bytes()); TF_ASSERT_OK(executor->Launch(&stream, custom_kernel->thread_dims(), - custom_kernel->block_dims(), kernel, arr)); + custom_kernel->block_dims(), *kernel, arr)); std::vector got(k); ASSERT_TRUE(stream.BlockHostUntilDone().ok()); for (int i = 0; i < batch_size; i++) { - stream.ThenMemcpy(got.data(), output_values.GetSlice(k * i, k), - k * sizeof(T)); + TF_ASSERT_OK(stream.Memcpy(got.data(), output_values.GetSlice(k * i, k), + k * sizeof(T))); std::vector slice(source.data() + n * i, source.data() + n * (i + 1)); std::sort(slice.begin(), slice.end(), std::greater()); slice.resize(k); diff --git a/third_party/xla/xla/service/gpu/runtime/topk_kernel.cc b/third_party/xla/xla/service/gpu/kernels/topk_kernel.cc similarity index 83% rename from third_party/xla/xla/service/gpu/runtime/topk_kernel.cc rename to third_party/xla/xla/service/gpu/kernels/topk_kernel.cc index 91fde6b8c8d27a..1261b5ea23402b 100644 --- a/third_party/xla/xla/service/gpu/runtime/topk_kernel.cc +++ b/third_party/xla/xla/service/gpu/kernels/topk_kernel.cc @@ -14,28 +14,30 @@ limitations under the License. ==============================================================================*/ // This file contains bespoke and optimized implementation for TopK shapes. When -// adding support for new shapes/dtypes, you also need to modify the rewritter +// adding support for new shapes/dtypes, you also need to modify the rewriter // on topk_specializer.cc for these changes to be picked up. -#include "xla/service/gpu/runtime/topk_kernel.h" +#include "xla/service/gpu/kernels/topk_kernel.h" #include +#include #include #include "absl/numeric/bits.h" #include "absl/status/status.h" -#include "Eigen/Core" // from @eigen_archive +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "xla/primitive_util.h" -#include "xla/service/gpu/runtime/gpu_kernel_helper.h" -#include "xla/service/gpu/runtime/topk_kernel_common.h" -#include "xla/stream_executor/gpu/gpu_stream.h" +#include "xla/service/gpu/kernels/topk_kernel_common.h" +#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/stream.h" +#include "xla/types.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" namespace xla::gpu { - namespace { size_t NumThreads(size_t n, size_t k, size_t batch_size) { @@ -66,7 +68,7 @@ absl::Status TypedTopK(se::Stream* stream, se::DeviceMemoryBase data, size_t batch_size) { constexpr size_t max_kv_size = sizeof(uint64_t); // Allocate shmem assuming we have a full reduction. - int shmem_size = absl::bit_ceil(k) * max_kv_size * WAVEFRONT_SIZE; + int shmem_size = absl::bit_ceil(k) * max_kv_size * GetTopKWaveFrontSize(); int num_threads = NumThreads(num_elements, k, batch_size); if (num_threads == 0) { return absl::FailedPreconditionError( @@ -81,14 +83,13 @@ absl::Status TypedTopK(se::Stream* stream, se::DeviceMemoryBase data, TF_ASSIGN_OR_RETURN(void* kernel_symbol, GetKernel(num_elements, k)); TF_ASSIGN_OR_RETURN( auto kernel, - (executor - ->CreateTypedKernel, size_t, se::DeviceMemory, - se::DeviceMemory, size_t>( - "topk", kernel_symbol))); + (se::TypedKernel, size_t, se::DeviceMemory, + se::DeviceMemory, + size_t>::Create(executor, "topk", kernel_symbol))); TF_RETURN_IF_ERROR(stream->ThenLaunch( se::ThreadDim(num_threads, 1, 1), se::BlockDim(batch_size, 1, 1), - shmem_size, *kernel, data_typed, num_elements, top_elements_typed, + shmem_size, kernel, data_typed, num_elements, top_elements_typed, top_indices_typed, k)); return absl::OkStatus(); @@ -108,8 +109,8 @@ absl::Status RunTopk(se::Stream* stream, PrimitiveType dtype, return TypedTopK(stream, data, num_elements, top_elements, top_indices, k, batch_size); case PrimitiveType::BF16: - return TypedTopK( - stream, data, num_elements, top_elements, top_indices, k, batch_size); + return TypedTopK(stream, data, num_elements, top_elements, + top_indices, k, batch_size); default: return absl::UnimplementedError("GpuTopK not implemented for this dtype"); } diff --git a/third_party/xla/xla/service/gpu/runtime/topk_kernel.cu.h b/third_party/xla/xla/service/gpu/kernels/topk_kernel.cu.h similarity index 81% rename from third_party/xla/xla/service/gpu/runtime/topk_kernel.cu.h rename to third_party/xla/xla/service/gpu/kernels/topk_kernel.cu.h index f981682c70d242..44b11394ece3e6 100644 --- a/third_party/xla/xla/service/gpu/runtime/topk_kernel.cu.h +++ b/third_party/xla/xla/service/gpu/kernels/topk_kernel.cu.h @@ -13,22 +13,77 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME_TOPK_KERNEL_CU_H_ -#define XLA_SERVICE_GPU_RUNTIME_TOPK_KERNEL_CU_H_ +#ifndef XLA_SERVICE_GPU_KERNELS_TOPK_KERNEL_CU_H_ +#define XLA_SERVICE_GPU_KERNELS_TOPK_KERNEL_CU_H_ // This file contains bespoke and optimized implementation for TopK shapes. When -// adding support for new shapes/dtypes, you also need to modify the rewritter +// adding support for new shapes/dtypes, you also need to modify the rewriter // on topk_specializer.cc for these changes to be picked up. #include #include #include -#include "xla/service/gpu/runtime/gpu_kernel_helper.h" -#include "xla/service/gpu/runtime/topk_kernel_common.h" +#include "xla/service/gpu/kernels/topk_kernel_common.h" +#include "xla/stream_executor/gpu/gpu_types.h" +#include "tsl/lib/math/math_util.h" + +#if GOOGLE_CUDA + +#define WAVEFRONT_SIZE 32 +#define FORCEINLINE __forceinline__ + +#elif TENSORFLOW_USE_ROCM // GOOGLE_CUDA + +#ifdef __AMDGCN_WAVEFRONT_SIZE +#define WAVEFRONT_SIZE __AMDGCN_WAVEFRONT_SIZE +#else +#define WAVEFRONT_SIZE 64 +#endif +#define FORCEINLINE __forceinline__ + +#endif // TENSORFLOW_USE_ROCM namespace xla::gpu { +enum class ShflType { kSync, kUp, kDown, kXor }; + +template +__device__ FORCEINLINE NT GpuShuffle(NT val, uint32_t idx, + uint32_t allmsk = 0xffffffffu) { + constexpr uint32_t SZ = + tsl::MathUtil::CeilOfRatio(sizeof(NT), sizeof(uint32_t)); + union S { + NT v; + uint32_t d[SZ]; + }; + S in{val}, res{}; + +#pragma unroll + for (uint32_t i = 0; i < SZ; i++) { +#if GOOGLE_CUDA + if constexpr (Type == ShflType::kSync) + res.d[i] = __shfl_sync(allmsk, in.d[i], idx); + else if constexpr (Type == ShflType::kUp) + res.d[i] = __shfl_up_sync(allmsk, in.d[i], idx); + else if constexpr (Type == ShflType::kDown) + res.d[i] = __shfl_down_sync(allmsk, in.d[i], idx); + else if constexpr (Type == ShflType::kXor) + res.d[i] = __shfl_xor_sync(allmsk, in.d[i], idx); +#elif TENSORFLOW_USE_ROCM // ROcm does not support sync shuffle intrinsics + if constexpr (Type == ShflType::kSync) + res.d[i] = __shfl(in.d[i], idx); + else if constexpr (Type == ShflType::kUp) + res.d[i] = __shfl_up(in.d[i], idx); + else if constexpr (Type == ShflType::kDown) + res.d[i] = __shfl_down(in.d[i], idx); + else if constexpr (Type == ShflType::kXor) + res.d[i] = __shfl_xor(in.d[i], idx); +#endif + } + return res.v; +} + // Default implementation for KV holder. Useful for testing while adding support // for a new type, but generally bitpacking those values is more efficient. See // implementations below. @@ -191,7 +246,7 @@ struct TopK { for (int offset = num_lanes / 2; offset > 0; offset /= 2) { #pragma unroll for (int i = 0; i < K; i++) { - KVT kv = GpuShuffle(tmp[i], offset); + KVT kv = GpuShuffle(tmp[i], offset); if (lane_id >= offset) continue; Push(tmp, kv); } @@ -245,12 +300,17 @@ __launch_bounds__(kTopKMaxThreadsPerBlock, 1) __global__ template void* GetTopKKernelForK(int n) { // TODO(doak): Switch to uint32_t if we don't have an efficient - // implemementation for uint16_t. + // implementation for uint16_t. return n < std::numeric_limits::max() ? reinterpret_cast(&Run) : reinterpret_cast(&Run); } +template +int32_t GetTopKWaveFrontSize() { + return WAVEFRONT_SIZE; +} + } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_RUNTIME_TOPK_KERNEL_CU_H_ +#endif // XLA_SERVICE_GPU_KERNELS_TOPK_KERNEL_CU_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/topk_kernel.h b/third_party/xla/xla/service/gpu/kernels/topk_kernel.h similarity index 83% rename from third_party/xla/xla/service/gpu/runtime/topk_kernel.h rename to third_party/xla/xla/service/gpu/kernels/topk_kernel.h index d1f8b7ec803ee3..8e15483f8c0667 100644 --- a/third_party/xla/xla/service/gpu/runtime/topk_kernel.h +++ b/third_party/xla/xla/service/gpu/kernels/topk_kernel.h @@ -13,18 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME_TOPK_KERNEL_H_ -#define XLA_SERVICE_GPU_RUNTIME_TOPK_KERNEL_H_ +#ifndef XLA_SERVICE_GPU_KERNELS_TOPK_KERNEL_H_ +#define XLA_SERVICE_GPU_KERNELS_TOPK_KERNEL_H_ #include -#include #include "absl/status/status.h" -#include "xla/stream_executor/gpu/gpu_types.h" -#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" +#include "xla/types.h" // IWYU pragma: keep #include "xla/xla_data.pb.h" namespace xla::gpu { @@ -43,4 +40,4 @@ absl::Status RunTopk(se::Stream* stream, PrimitiveType dtype, } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_RUNTIME_TOPK_KERNEL_H_ +#endif // XLA_SERVICE_GPU_KERNELS_TOPK_KERNEL_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/topk.h b/third_party/xla/xla/service/gpu/kernels/topk_kernel_bfloat16.cu.cc similarity index 63% rename from third_party/xla/xla/service/gpu/runtime/topk.h rename to third_party/xla/xla/service/gpu/kernels/topk_kernel_bfloat16.cu.cc index 3b6e056aa43522..c0e47295a18d07 100644 --- a/third_party/xla/xla/service/gpu/runtime/topk.h +++ b/third_party/xla/xla/service/gpu/kernels/topk_kernel_bfloat16.cu.cc @@ -13,16 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME_TOPK_H_ -#define XLA_SERVICE_GPU_RUNTIME_TOPK_H_ - -#include "xla/runtime/custom_call_registry.h" +#include "xla/service/gpu/kernels/topk_kernel.cu.h" +#include "xla/types.h" namespace xla::gpu { -// Registers XLA Gpu runtime TopK custom calls. -void RegisterTopkCustomCall(runtime::DirectCustomCallRegistry& registry); +template void* GetTopKKernelForK(int n); +template void* GetTopKKernelForK(int n); +template void* GetTopKKernelForK(int n); +template void* GetTopKKernelForK(int n); +template void* GetTopKKernelForK(int n); -} // namespace xla::gpu +template int32_t GetTopKWaveFrontSize(); -#endif // XLA_SERVICE_GPU_RUNTIME_TOPK_H_ +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/runtime/topk_kernel_common.h b/third_party/xla/xla/service/gpu/kernels/topk_kernel_common.h similarity index 83% rename from third_party/xla/xla/service/gpu/runtime/topk_kernel_common.h rename to third_party/xla/xla/service/gpu/kernels/topk_kernel_common.h index d7ba89fce1bb84..5ddd9ed513d9aa 100644 --- a/third_party/xla/xla/service/gpu/runtime/topk_kernel_common.h +++ b/third_party/xla/xla/service/gpu/kernels/topk_kernel_common.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME_TOPK_KERNEL_COMMON_H_ -#define XLA_SERVICE_GPU_RUNTIME_TOPK_KERNEL_COMMON_H_ +#ifndef XLA_SERVICE_GPU_KERNELS_TOPK_KERNEL_COMMON_H_ +#define XLA_SERVICE_GPU_KERNELS_TOPK_KERNEL_COMMON_H_ #include @@ -31,6 +31,9 @@ static constexpr size_t kTopKMaxThreadsPerBlock = 1024; template void* GetTopKKernelForK(int n); +template +int32_t GetTopKWaveFrontSize(); + } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_RUNTIME_TOPK_KERNEL_COMMON_H_ +#endif // XLA_SERVICE_GPU_KERNELS_TOPK_KERNEL_COMMON_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/topk_kernel_float.cu.cc b/third_party/xla/xla/service/gpu/kernels/topk_kernel_float.cu.cc similarity index 90% rename from third_party/xla/xla/service/gpu/runtime/topk_kernel_float.cu.cc rename to third_party/xla/xla/service/gpu/kernels/topk_kernel_float.cu.cc index 33ffd648c6934d..b7b7823a4dff2e 100644 --- a/third_party/xla/xla/service/gpu/runtime/topk_kernel_float.cu.cc +++ b/third_party/xla/xla/service/gpu/kernels/topk_kernel_float.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime/topk_kernel.cu.h" +#include "xla/service/gpu/kernels/topk_kernel.cu.h" namespace xla::gpu { @@ -23,4 +23,6 @@ template void* GetTopKKernelForK(int n); template void* GetTopKKernelForK(int n); template void* GetTopKKernelForK(int n); +template int32_t GetTopKWaveFrontSize(); + } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/runtime/topk_kernel_test.cc b/third_party/xla/xla/service/gpu/kernels/topk_kernel_test.cc similarity index 89% rename from third_party/xla/xla/service/gpu/runtime/topk_kernel_test.cc rename to third_party/xla/xla/service/gpu/kernels/topk_kernel_test.cc index 62a3f09765b3e2..0fe83b55b2599e 100644 --- a/third_party/xla/xla/service/gpu/runtime/topk_kernel_test.cc +++ b/third_party/xla/xla/service/gpu/kernels/topk_kernel_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime/topk_kernel.h" +#include "xla/service/gpu/kernels/topk_kernel.h" #include #include @@ -27,15 +27,14 @@ limitations under the License. #include "absl/random/random.h" #include "absl/strings/substitute.h" #include "absl/time/time.h" -#include "Eigen/Core" // from @eigen_archive -#include "xla/service/gpu/runtime/gpu_kernel_helper.h" #include "xla/stream_executor/gpu/gpu_init.h" #include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/gpu/gpu_timer.h" #include "xla/stream_executor/gpu/gpu_types.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" +#include "xla/types.h" #include "xla/xla_data.pb.h" #include "tsl/platform/test.h" #include "tsl/platform/test_benchmark.h" @@ -71,11 +70,11 @@ std::vector RandomVecNegative(int num_elements) { } PrimitiveType Get(float) { return PrimitiveType::F32; } -PrimitiveType Get(Eigen::bfloat16) { return PrimitiveType::BF16; } +PrimitiveType Get(bfloat16) { return PrimitiveType::BF16; } se::StreamExecutor* GetGpuExecutor() { auto* platform = - se::MultiPlatformManager::PlatformWithName(se::GpuPlatformName()).value(); + se::PlatformManager::PlatformWithName(se::GpuPlatformName()).value(); return platform->ExecutorForDevice(0).value(); } @@ -94,7 +93,7 @@ TEST_P(TopkTest, TopKFloat) { auto* executor = GetGpuExecutor(); se::Stream stream(executor); - stream.Init(); + CHECK_OK(stream.Initialize()); ASSERT_TRUE(stream.ok()); const auto [n_kb, k, batch_size, offset] = GetParam(); @@ -108,8 +107,8 @@ TEST_P(TopkTest, TopKFloat) { output_indices.is_null())); auto source = RandomVec(n * batch_size); - stream.ThenMemcpy(input_buffer.ptr(), source.data(), - n * batch_size * sizeof(T)); + CHECK_OK(stream.Memcpy(input_buffer.ptr(), source.data(), + n * batch_size * sizeof(T))); ASSERT_TRUE(RunTopk(&stream, Get(T()), *input_buffer, n, *output_values, *output_indices, k, batch_size) @@ -117,8 +116,8 @@ TEST_P(TopkTest, TopKFloat) { std::vector got(k); ASSERT_TRUE(stream.BlockHostUntilDone().ok()); for (int i = 0; i < batch_size; i++) { - stream.ThenMemcpy(got.data(), output_values->GetSlice(k * i, k), - k * sizeof(T)); + CHECK_OK(stream.Memcpy(got.data(), output_values->GetSlice(k * i, k), + k * sizeof(T))); std::vector slice(source.data() + n * i, source.data() + n * (i + 1)); std::sort(slice.begin(), slice.end(), std::greater()); slice.resize(k); @@ -132,7 +131,7 @@ TEST_P(TopkTest, TopKPackedNegative) { auto* executor = GetGpuExecutor(); se::Stream stream(executor); - stream.Init(); + CHECK_OK(stream.Initialize()); ASSERT_TRUE(stream.ok()); const auto [n_kb, k, batch_size, offset] = GetParam(); @@ -146,8 +145,8 @@ TEST_P(TopkTest, TopKPackedNegative) { output_indices.is_null())); auto source = RandomVecNegative(n * batch_size); - stream.ThenMemcpy(input_buffer.ptr(), source.data(), - n * batch_size * sizeof(T)); + CHECK_OK(stream.Memcpy(input_buffer.ptr(), source.data(), + n * batch_size * sizeof(T))); ASSERT_TRUE(RunTopk(&stream, Get(T()), *input_buffer, n, *output_values, *output_indices, k, batch_size) @@ -155,8 +154,8 @@ TEST_P(TopkTest, TopKPackedNegative) { std::vector got(k); ASSERT_TRUE(stream.BlockHostUntilDone().ok()); for (int i = 0; i < batch_size; i++) { - stream.ThenMemcpy(got.data(), output_values->GetSlice(k * i, k), - k * sizeof(T)); + CHECK_OK(stream.Memcpy(got.data(), output_values->GetSlice(k * i, k), + k * sizeof(T))); std::vector slice(source.data() + n * i, source.data() + n * (i + 1)); std::sort(slice.begin(), slice.end(), std::greater()); slice.resize(k); @@ -191,7 +190,7 @@ void BM_SmallTopk(benchmark::State& state) { auto* executor = GetGpuExecutor(); se::Stream stream(executor); - stream.Init(); + CHECK_OK(stream.Initialize()); ASSERT_TRUE(stream.ok()); auto input_buffer = executor->AllocateOwnedArray(n * batch_size), @@ -209,7 +208,7 @@ void BM_SmallTopk(benchmark::State& state) { // time to generate random data) for (size_t i = 0; i < batch_size; i++) { auto slice = input_buffer->GetSlice(i * n, n); - stream.ThenMemcpy(&slice, source.data(), n * sizeof(T)); + CHECK_OK(stream.Memcpy(&slice, source.data(), n * sizeof(T))); } for (auto _ : state) { diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD b/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD index be1d96e899b54c..64cdec48e19ca0 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD @@ -1,3 +1,4 @@ +load("@local_tsl//tsl:tsl.bzl", "internal_visibility") load("//xla:xla.bzl", "xla_cc_test") load( "@local_config_rocm//rocm:build_defs.bzl", @@ -5,7 +6,8 @@ load( ) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) @@ -26,7 +28,6 @@ cc_library( "gpu_backend_lib.h", "utils.h", ], - visibility = ["//visibility:public"], deps = [ "//xla:status_macros", "//xla:statusor", diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index bbd1352fd066ab..37373041e8a967 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -292,14 +292,10 @@ absl::Status NVPTXTargetModuleLinker(llvm::Module* module, std::unique_ptr NVPTXGetTargetMachine( llvm::Triple target_triple, se::CudaComputeCapability compute_capability, const DebugOptions& debug_options) { - // TODO(b/266678775): Make it always PTX 7.1 as soon as TF driver requirements - // are updated. - const std::string ptx_ver = - debug_options.xla_gpu_enable_triton_gemm() ? "+ptx71" : "+ptx60"; // Figure out the exact name of the processor as known to the NVPTX backend // from the gpu_architecture flag. return GetTargetMachine(target_triple, GetSmName(compute_capability), - debug_options, ptx_ver); + debug_options, /*feature_str=*/"+ptx74"); } using TargetModuleLinker = diff --git a/third_party/xla/xla/service/gpu/make_batch_pointers.cc b/third_party/xla/xla/service/gpu/make_batch_pointers.cc index 9cc84bf6be2ca3..027b128f413dfe 100644 --- a/third_party/xla/xla/service/gpu/make_batch_pointers.cc +++ b/third_party/xla/xla/service/gpu/make_batch_pointers.cc @@ -60,13 +60,15 @@ absl::Status MakeBatchPointers(se::Stream* stream, #else TF_ASSIGN_OR_RETURN( - auto kernel, (executor->CreateTypedKernel( - "make_batch_pointers", make_batch_pointers::kernel()))); + auto kernel, + (se::TypedKernel< + se::DeviceMemoryBase, size_t, size_t, + se::DeviceMemoryBase>::Create(executor, "make_batch_pointers", + make_batch_pointers::kernel()))); TF_RETURN_IF_ERROR( stream->ThenLaunch(se::ThreadDim(kThreads, 1, 1), - se::BlockDim(CeilOfRatio(n, kThreads), 1, 1), *kernel, + se::BlockDim(CeilOfRatio(n, kThreads), 1, 1), kernel, base_ptr, stride_bytes, n, ptrs_out)); #endif return absl::OkStatus(); diff --git a/third_party/xla/xla/service/gpu/matmul_utils.cc b/third_party/xla/xla/service/gpu/matmul_utils.cc index 424d9732a4ff80..962588333c1c21 100644 --- a/third_party/xla/xla/service/gpu/matmul_utils.cc +++ b/third_party/xla/xla/service/gpu/matmul_utils.cc @@ -26,10 +26,11 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/log/check.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" +#include "xla/autotuning.pb.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" #include "xla/primitive_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/shape.h" @@ -456,42 +457,6 @@ absl::StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, config.beta(), algorithm, precision, grad_x, grad_y); } -/*static*/ absl::StatusOr GemmConfig::For( - mlir::lmhlo_gpu::GEMMOp op) { - mlir::mhlo::DotDimensionNumbersAttr dot_dims = op.getDotDimensionNumbers(); - - std::optional algorithm; - if (op.getAlgorithm()) algorithm = *op.getAlgorithm(); - - bool grad_x = false; - bool grad_y = false; - auto attr_grad_x = op.getGradX(); - if (attr_grad_x) grad_x = attr_grad_x.value(); - auto attr_grad_y = op.getGradY(); - if (attr_grad_y) grad_y = attr_grad_y.value(); - - int64_t compute_precision = 0; // Default - if (op.getPrecisionConfig().has_value()) { - auto precision_config = op.getPrecisionConfig(); - for (auto attr : precision_config.value()) { - int64_t value = static_cast( - attr.template cast().getValue()); - if (value > compute_precision) { - compute_precision = value; - } - } - } - - return GemmConfig::For( - GetShape(op.getA()), dot_dims.getLhsBatchingDimensions(), - dot_dims.getLhsContractingDimensions(), GetShape(op.getB()), - dot_dims.getRhsBatchingDimensions(), - dot_dims.getRhsContractingDimensions(), GetShape(op.getC()), - op.getAlphaReal().convertToDouble(), op.getAlphaImag().convertToDouble(), - op.getBeta().convertToDouble(), algorithm, compute_precision, grad_x, - grad_y); -} - absl::StatusOr GemmConfig::GetMatrixDescriptors( se::DeviceMemoryBase lhs_buf, se::DeviceMemoryBase rhs_buf, se::DeviceMemoryBase out_buf) const { @@ -507,7 +472,7 @@ absl::StatusOr GemmConfig::GetMatrixDescriptors( ? se::blas::Transpose::kNoTranspose : se::blas::Transpose::kTranspose)}; }; - // make a local copy to prevent modification of layouts, + // TODO: make a local copy to prevent modification of layouts, // but maybe we can modify them once instead during creation ? se::gpu::MatrixLayout lhs = lhs_layout, rhs = rhs_layout, out = output_layout; @@ -557,21 +522,25 @@ absl::Status DoGemmWithAlgorithm(const se::gpu::MatrixDescriptor& lhs, se::DeviceMemory output_data(output.data); // Set a workspace for all Blas operations launched below. - se::blas::BlasSupport::ScopedWorkspace scoped_workspace( - stream->parent()->AsBlas(), &workspace); + auto* blas = stream->parent()->AsBlas(); + if (blas == nullptr) { + return absl::InternalError("No Blas support for stream"); + } + + se::blas::BlasSupport::ScopedWorkspace scoped_workspace(blas, &workspace); if (output.batch_size != 1) { - return stream->ThenBlasGemmStridedBatchedWithAlgorithm( - lhs.transpose, rhs.transpose, output.m, output.n, output.k, alpha, - lhs.cast(), lhs.leading_dim_stride, lhs.batch_stride, + return blas->BlasGemmStridedBatchedWithAlgorithm( + stream, lhs.transpose, rhs.transpose, output.m, output.n, output.k, + alpha, lhs.cast(), lhs.leading_dim_stride, lhs.batch_stride, rhs.cast(), rhs.leading_dim_stride, rhs.batch_stride, beta, &output_data, output.leading_dim_stride, output.batch_stride, output.batch_size, computation_type, algorithm, numeric_options, profile_result, context); } else { - return stream->ThenBlasGemmWithAlgorithm( - lhs.transpose, rhs.transpose, output.m, output.n, output.k, alpha, - lhs.cast(), lhs.leading_dim_stride, rhs.cast(), + return blas->BlasGemmWithAlgorithm( + stream, lhs.transpose, rhs.transpose, output.m, output.n, output.k, + alpha, lhs.cast(), lhs.leading_dim_stride, rhs.cast(), rhs.leading_dim_stride, beta, &output_data, output.leading_dim_stride, computation_type, algorithm, numeric_options, profile_result, context); } @@ -590,34 +559,34 @@ absl::Status DoGemm(const se::gpu::MatrixDescriptor& lhs, se::blas::CallContext context) { CHECK(output.transpose == se::blas::Transpose::kNoTranspose); se::DeviceMemory output_data(output.data); + auto* blas = stream->parent()->AsBlas(); + if (blas == nullptr) { + return absl::InternalError("No Blas support for stream"); + } // Set a workspace for all Blas operations launched below. - se::blas::BlasSupport::ScopedWorkspace scoped_workspace( - stream->parent()->AsBlas(), &workspace); + se::blas::BlasSupport::ScopedWorkspace scoped_workspace(blas, &workspace); -// TODO: enable DoGemmWithAlgorithm for ROCm ! -#if GOOGLE_CUDA if (algorithm) { return DoGemmWithAlgorithm( lhs, rhs, output, workspace, alpha, beta, stream, *algorithm, compute_precision, numeric_options, profile_result, context); } -#endif if (output.batch_size != 1) { - return stream->ThenBlasGemmStridedBatched( - lhs.transpose, rhs.transpose, output.m, output.n, output.k, alpha, - lhs.cast(), lhs.leading_dim_stride, lhs.batch_stride, + return blas->BlasGemmStridedBatched( + stream, lhs.transpose, rhs.transpose, output.m, output.n, output.k, + alpha, lhs.cast(), lhs.leading_dim_stride, lhs.batch_stride, rhs.cast(), rhs.leading_dim_stride, rhs.batch_stride, beta, &output_data, output.leading_dim_stride, output.batch_stride, output.batch_size, numeric_options, context); } - return stream->ThenBlasGemm( - lhs.transpose, rhs.transpose, output.m, output.n, output.k, alpha, - lhs.cast(), lhs.leading_dim_stride, rhs.cast(), - rhs.leading_dim_stride, beta, &output_data, output.leading_dim_stride, - numeric_options, context); + return blas->BlasGemm(stream, lhs.transpose, rhs.transpose, output.m, + output.n, output.k, alpha, lhs.cast(), + lhs.leading_dim_stride, rhs.cast(), + rhs.leading_dim_stride, beta, &output_data, + output.leading_dim_stride, numeric_options, context); } } // namespace @@ -661,8 +630,7 @@ absl::Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer, // graphs, so we are making sure we do not trigger it). if (config.alpha.real() == 0.0 && config.alpha.imag() == 0.0 && config.beta == 0.0) { - stream->ThenMemZero(&output_buffer, output_buffer.size()); - return absl::OkStatus(); + return stream->MemZero(&output_buffer, output_buffer.size()); } #define TYPED_GEMM(SCALENTYPE, ATYPE, BTYPE, CTYPE) \ @@ -758,30 +726,6 @@ absl::StatusOr EpilogueHasAuxiliaryOutput( } } -absl::StatusOr AsBlasLtEpilogue( - mlir::lmhlo_gpu::CublasLtMatmulEpilogue epilogue) { - using mlir::lmhlo_gpu::CublasLtMatmulEpilogue; - switch (epilogue) { - case CublasLtMatmulEpilogue::Default: - return se::gpu::BlasLt::Epilogue::kDefault; - case CublasLtMatmulEpilogue::Relu: - return se::gpu::BlasLt::Epilogue::kReLU; - case CublasLtMatmulEpilogue::Gelu: - return se::gpu::BlasLt::Epilogue::kGELU; - case CublasLtMatmulEpilogue::GeluAux: - return se::gpu::BlasLt::Epilogue::kGELUWithAux; - case CublasLtMatmulEpilogue::Bias: - return se::gpu::BlasLt::Epilogue::kBias; - case CublasLtMatmulEpilogue::BiasRelu: - return se::gpu::BlasLt::Epilogue::kBiasThenReLU; - case CublasLtMatmulEpilogue::BiasGelu: - return se::gpu::BlasLt::Epilogue::kBiasThenGELU; - case CublasLtMatmulEpilogue::BiasGeluAux: - return se::gpu::BlasLt::Epilogue::kBiasThenGELUWithAux; - } - return Internal("unexpected epilogue value"); -} - absl::StatusOr AsBlasLtEpilogue( GemmBackendConfig_Epilogue epilogue) { switch (epilogue) { @@ -808,16 +752,20 @@ absl::StatusOr AsBlasLtEpilogue( } // namespace gpublas_lt -/*static*/ TritonGemmConfig TritonGemmConfig::FromProto( +/*static*/ absl::StatusOr TritonGemmConfig::FromProto( const AutotuneResult::TritonGemmKey& proto) { - TritonGemmConfig config; - config.block_m = proto.block_m(); - config.block_n = proto.block_n(); - config.block_k = proto.block_k(); - config.split_k = proto.split_k(); - config.num_stages = proto.num_stages(); - config.num_warps = proto.num_warps(); - return config; + // Sanity check to avoid loading incomplete data. + TF_RET_CHECK(proto.block_m() > 0); + TF_RET_CHECK(proto.block_n() > 0); + TF_RET_CHECK(proto.block_k() > 0); + TF_RET_CHECK(proto.split_k() > 0); + TF_RET_CHECK(proto.num_stages() > 0); + TF_RET_CHECK(proto.num_warps() > 0); + TF_RET_CHECK(proto.num_ctas() > 0); + + return TritonGemmConfig(proto.block_m(), proto.block_n(), proto.block_k(), + proto.split_k(), proto.num_stages(), + proto.num_warps(), proto.num_ctas()); } AutotuneResult::TritonGemmKey TritonGemmConfig::ToProto() const { @@ -828,6 +776,7 @@ AutotuneResult::TritonGemmKey TritonGemmConfig::ToProto() const { key.set_split_k(split_k); key.set_num_stages(num_stages); key.set_num_warps(num_warps); + key.set_num_ctas(num_ctas); return key; } @@ -835,7 +784,7 @@ std::string TritonGemmConfig::ToString() const { return absl::StrCat("{block_m:", block_m, ",block_n:", block_n, ",block_k:", block_k, ",split_k:", split_k, ",num_stages:", num_stages, ",num_warps:", num_warps, - "}"); + ",num_ctas:", num_ctas, "}"); } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/matmul_utils.h b/third_party/xla/xla/service/gpu/matmul_utils.h index 2efbdd386af1e5..8141706fd1381b 100644 --- a/third_party/xla/xla/service/gpu/matmul_utils.h +++ b/third_party/xla/xla/service/gpu/matmul_utils.h @@ -26,7 +26,6 @@ limitations under the License. #include "absl/types/span.h" #include "xla/autotuning.pb.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/shape.h" @@ -97,7 +96,6 @@ struct GemmConfig : public se::gpu::GemmConfig { static constexpr int64_t kDefaultWorkspace = 4 * 1024 * 1024; // 4 MiB static absl::StatusOr For(const HloInstruction* gemm); - static absl::StatusOr For(mlir::lmhlo_gpu::GEMMOp op); static absl::StatusOr For( const Shape& lhs_shape, absl::Span lhs_batch_dims, @@ -119,43 +117,6 @@ struct GemmConfig : public se::gpu::GemmConfig { double alpha_imag, double beta, std::optional algorithm, int64_t compute_precision, bool grad_x, bool grad_y); - template ::value || - std::is_same::value>> - static absl::StatusOr For(CublasLtMatmulMaybeF8Op op) { - mlir::mhlo::DotDimensionNumbersAttr dot_dims = op.getDotDimensionNumbers(); - - int64_t compute_precision = 0; // Default - if (op.getPrecisionConfig().has_value()) { - auto precision_config = op.getPrecisionConfig(); - for (auto attr : precision_config.value()) { - int64_t value = static_cast( - attr.template cast().getValue()); - if (value > compute_precision) { - compute_precision = value; - } - } - } - - Shape bias_shape; - if (op.getBias() != nullptr) { - bias_shape = GetShape(op.getBias()); - } - return GemmConfig::For( - GetShape(op.getA()), dot_dims.getLhsBatchingDimensions(), - dot_dims.getLhsContractingDimensions(), GetShape(op.getB()), - dot_dims.getRhsBatchingDimensions(), - dot_dims.getRhsContractingDimensions(), GetShape(op.getC()), - op.getBias() == nullptr ? nullptr : &bias_shape, GetShape(op.getD()), - op.getAlphaReal().convertToDouble(), - op.getAlphaImag().convertToDouble(), op.getBeta().convertToDouble(), - op.getAlgorithm(), compute_precision, /*grad_x=*/false, - /*grad_y=*/false); - } - struct DescriptorsTuple { se::gpu::MatrixDescriptor lhs; se::gpu::MatrixDescriptor rhs; @@ -186,8 +147,6 @@ absl::StatusOr EpilogueAddsVectorBias( absl::StatusOr EpilogueHasAuxiliaryOutput( GemmBackendConfig_Epilogue epilogue); -absl::StatusOr AsBlasLtEpilogue( - mlir::lmhlo_gpu::CublasLtMatmulEpilogue epilogue); absl::StatusOr AsBlasLtEpilogue( GemmBackendConfig_Epilogue epilogue); @@ -196,47 +155,43 @@ absl::StatusOr AsBlasLtEpilogue( // We should use this in code instead of AutotuneResult::TritonGemmKey. // This has some advantages, for example it can be used in hashmaps. struct TritonGemmConfig { - struct ClusterDims { - constexpr ClusterDims() = default; - constexpr ClusterDims(int x, int y, int z) : x(x), y(y), z(z) {} - int x = 1; - int y = 1; - int z = 1; - }; - constexpr TritonGemmConfig() = default; constexpr TritonGemmConfig(int block_m, int block_n, int block_k, int split_k, - int num_stages, int num_warps, int num_ctas = 1, - ClusterDims cluster_dims = ClusterDims(1, 1, 1), - bool enable_warp_specialization = false) + int num_stages, int num_warps, int num_ctas = 1) : block_m(block_m), block_n(block_n), block_k(block_k), split_k(split_k), num_stages(num_stages), num_warps(num_warps), - num_ctas(num_ctas), - cluster_dims(cluster_dims), - enable_warp_specialization(enable_warp_specialization) {} + num_ctas(num_ctas) {} int block_m = 0; int block_n = 0; int block_k = 0; int split_k = 0; int num_stages = 0; int num_warps = 0; - int num_ctas = 1; - ClusterDims cluster_dims; - bool enable_warp_specialization = false; + // Number of blocks in a block cluster. + int num_ctas = 0; + + // When adding new members, please update all methods, such as ToTuple, + // FromProto, ToProto, ToString, etc. Updating ToTuple is not enough. + // Please also add new members to AutotuneResult::TritonGemmKey in + // autotuning.proto. Also kVersion has to be incremented in autotuner_util.cc + // and all the autotuning results stored in tests, repos, etc. will have to + // be updated. private: auto ToTuple() const { return std::make_tuple(block_m, block_n, block_k, split_k, num_stages, - num_warps, num_ctas, cluster_dims.x, cluster_dims.y, - cluster_dims.z, enable_warp_specialization); + num_warps, num_ctas); } public: - static TritonGemmConfig FromProto(const AutotuneResult::TritonGemmKey& proto); + // Creates a TritonGemmConfig from the supplied proto, doing a simple sanity + // check. + static absl::StatusOr FromProto( + const AutotuneResult::TritonGemmKey& proto); AutotuneResult::TritonGemmKey ToProto() const; std::string ToString() const; diff --git a/third_party/xla/xla/service/gpu/mock_nccl_utils.cc b/third_party/xla/xla/service/gpu/mock_nccl_utils.cc index 91f3a0ebba5842..2f7e63e0f8cb42 100644 --- a/third_party/xla/xla/service/gpu/mock_nccl_utils.cc +++ b/third_party/xla/xla/service/gpu/mock_nccl_utils.cc @@ -590,7 +590,7 @@ absl::Status RunMockCollectivePermute( // buffer. VLOG(3) << absl::StreamFormat( "%s : mock collective-Permute: Issuing MemZero", device_string); - stream.ThenMemZero(&dest_addr, dest_addr.size()); + return stream.MemZero(&dest_addr, dest_addr.size()); } return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/gpu/mock_nccl_utils.h b/third_party/xla/xla/service/gpu/mock_nccl_utils.h index 19526f7e51206c..70f07145d30e1a 100644 --- a/third_party/xla/xla/service/gpu/mock_nccl_utils.h +++ b/third_party/xla/xla/service/gpu/mock_nccl_utils.h @@ -18,9 +18,11 @@ limitations under the License. #include #include +#include #include #include "absl/status/status.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "xla/executable_run_options.h" #include "xla/service/collective_ops_utils.h" @@ -31,11 +33,25 @@ limitations under the License. #include "xla/service/gpu/nccl_collective_thunk.h" #include "xla/service/gpu/nccl_p2p_thunk_common.h" #include "xla/service/gpu/thunk.h" +#include "xla/service/lockable.h" #include "xla/stream_executor/stream.h" +#include "tsl/lib/gtl/int_type.h" namespace xla { namespace gpu { +TSL_LIB_GTL_DEFINE_INT_TYPE(OpId, int64_t); + +struct NcclCommName { + static std::string ToString(NcclApi::NcclCommHandle comm) { + return absl::StrFormat("lockable comm %p", comm); + } +}; + +struct NcclComm : public Lockable { + explicit NcclComm(NcclApi::NcclCommHandle comm) : Lockable(comm) {} +}; + // Create the mock nccl communicator assuming all hosts have the same hardwares. absl::StatusOr LockMockNcclComm( const Thunk::CollectiveExecuteParams& params, diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index cbea57ea7c36d3..e7cc425a0cf0a7 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -2,13 +2,15 @@ load("//xla/tests:build_defs.bzl", "xla_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("//xla:xla.bzl", "xla_cc_test", "xla_nvml_deps") +load("@local_tsl//tsl:tsl.bzl", "internal_visibility") load("@local_tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") load("@local_tsl//tsl/platform:build_config.bzl", "tf_proto_library") load("@local_tsl//tsl/platform:build_config_root.bzl", "tf_cuda_tests_tags") load("@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) @@ -23,8 +25,8 @@ cc_library( name = "analytical_latency_estimator", srcs = ["analytical_latency_estimator.cc"], hdrs = ["analytical_latency_estimator.h"], - visibility = ["//visibility:public"], deps = [ + ":gpu_collective_performance_model", ":gpu_hlo_cost_analysis", ":gpu_performance_model", "//xla:xla_proto_cc", @@ -71,7 +73,6 @@ cc_library( name = "fusion_analysis_cache", srcs = ["fusion_analysis_cache.cc"], hdrs = ["fusion_analysis_cache.h"], - visibility = ["//visibility:public"], deps = [ "//xla/hlo/ir:hlo", "//xla/service/gpu:hlo_fusion_analysis", @@ -102,7 +103,6 @@ cc_library( name = "gpu_cost_model_stats_collection", srcs = ["gpu_cost_model_stats_collection.cc"], hdrs = ["gpu_cost_model_stats_collection.h"], - visibility = ["//visibility:public"], deps = [ ":gpu_hlo_cost_analysis", ":gpu_performance_model", @@ -141,12 +141,8 @@ xla_cc_test( cc_library( name = "gpu_hlo_cost_analysis", srcs = ["gpu_hlo_cost_analysis.cc"], - hdrs = [ - "gpu_hlo_cost_analysis.h", - "hlo_op_profiles_data.h", - ], + hdrs = ["gpu_hlo_cost_analysis.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":hlo_op_profile_proto_cc", ":hlo_op_profiles", @@ -187,19 +183,121 @@ xla_cc_test( ], ) +cc_library( + name = "gpu_performance_model_base", + srcs = ["gpu_performance_model_base.cc"], + hdrs = ["gpu_performance_model_base.h"], + deps = [ + ":fusion_analysis_cache", + ":gpu_hlo_cost_analysis", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu/fusions", + "//xla/service/gpu/fusions:fusion_emitter", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + ], +) + +xla_cc_test( + name = "gpu_performance_model_base_test", + srcs = ["gpu_performance_model_base_test.cc"], + deps = [ + ":gpu_hlo_cost_analysis", + ":gpu_performance_model_base", + "//xla:shape_util", + "//xla:test_helpers", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + ], +) + cc_library( name = "gpu_performance_model", srcs = ["gpu_performance_model.cc"], hdrs = ["gpu_performance_model.h"], + deps = [ + ":coalescing_analysis", + ":gpu_hlo_cost_analysis", + ":gpu_performance_model_base", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:launch_dimensions", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@local_tsl//tsl/platform:status", + ], +) + +xla_cc_test( + name = "gpu_performance_model_test", + srcs = ["gpu_performance_model_test.cc"], + deps = [ + ":gpu_hlo_cost_analysis", + ":gpu_indexing_performance_model", + ":gpu_performance_model", + ":gpu_performance_model_base", + "//xla:shape_util", + "//xla:test_helpers", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_module_config", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@com_google_googletest//:gtest", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "gpu_collective_performance_model", + srcs = ["gpu_collective_performance_model.cc"], + hdrs = ["gpu_collective_performance_model.h"], local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - visibility = ["//visibility:public"], deps = [ ":coalescing_analysis", ":fusion_analysis_cache", ":gpu_hlo_cost_analysis", + ":gpu_performance_model_base", + ":hlo_op_profiles", + ":indexing_analysis", + ":indexing_map", "//xla:shape_util", + "//xla:statusor", "//xla:util", "//xla/hlo/ir:hlo", + "//xla/service:hlo_cost_analysis", "//xla/service:hlo_dataflow_analysis", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:gpu_fusible", @@ -207,6 +305,7 @@ cc_library( "//xla/service/gpu:hlo_traversal", "//xla/service/gpu:launch_dimensions", "//xla/service/gpu/fusions", + "//xla/service/gpu/fusions:fusion_emitter", "//xla/stream_executor:device_description", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -215,21 +314,63 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:status", ] + if_cuda_is_configured(xla_nvml_deps()), ) xla_cc_test( - name = "gpu_performance_model_test", - srcs = ["gpu_performance_model_test.cc"], + name = "gpu_collective_performance_model_test", + srcs = ["gpu_collective_performance_model_test.cc"], deps = [ + ":gpu_collective_performance_model", + "//xla/service/gpu:backend_configs_cc", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "gpu_indexing_performance_model", + srcs = ["gpu_indexing_performance_model.cc"], + hdrs = ["gpu_indexing_performance_model.h"], + deps = [ + ":coalescing_analysis", ":gpu_hlo_cost_analysis", - ":gpu_performance_model", + ":gpu_performance_model_base", + ":hlo_op_profiles", + ":indexing_analysis", + ":indexing_map", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_cost_analysis", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu:launch_dimensions", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:status", + ], +) + +xla_cc_test( + name = "gpu_indexing_performance_model_test", + srcs = ["gpu_indexing_performance_model_test.cc"], + deps = [ + ":gpu_hlo_cost_analysis", + ":gpu_indexing_performance_model", + ":gpu_performance_model_base", "//xla:shape_util", - "//xla:test_helpers", "//xla/hlo/ir:hlo", - "//xla/service:hlo_module_config", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/stream_executor:device_description", @@ -238,7 +379,7 @@ xla_cc_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:errors", + "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:statusor", ], ) @@ -247,9 +388,9 @@ cc_library( name = "affine_map_printer", srcs = ["affine_map_printer.cc"], hdrs = ["affine_map_printer.h"], - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", @@ -273,7 +414,6 @@ cc_library( name = "indexing_map", srcs = ["indexing_map.cc"], hdrs = ["indexing_map.h"], - visibility = ["//visibility:public"], deps = [ ":affine_map_printer", "@com_google_absl//absl/types:span", @@ -304,7 +444,6 @@ cc_library( testonly = True, srcs = ["indexing_test_utils.cc"], hdrs = ["indexing_test_utils.h"], - visibility = ["//visibility:public"], deps = [ ":affine_map_printer", ":indexing_analysis", @@ -326,7 +465,6 @@ cc_library( name = "indexing_analysis", srcs = ["indexing_analysis.cc"], hdrs = ["indexing_analysis.h"], - visibility = ["//visibility:public"], deps = [ ":affine_map_printer", ":indexing_map", @@ -337,9 +475,12 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service/gpu:hlo_traversal", "//xla/service/gpu:matmul_utils", + "//xla/service/gpu/fusions:tiling_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -352,9 +493,7 @@ xla_cc_test( srcs = ["indexing_analysis_test.cc"], deps = [ ":indexing_analysis", - ":indexing_map", ":indexing_test_utils", - "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/service/gpu:hlo_traversal", "//xla/tests:hlo_test_base", @@ -370,11 +509,11 @@ cc_library( name = "tile_analysis", srcs = ["tile_analysis.cc"], hdrs = ["tile_analysis.h"], - visibility = ["//visibility:public"], deps = [ ":affine_map_printer", ":indexing_map", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", @@ -405,7 +544,6 @@ cc_library( name = "coalescing_analysis", srcs = ["coalescing_analysis.cc"], hdrs = ["coalescing_analysis.h"], - visibility = ["//visibility:public"], deps = [ ":indexing_analysis", ":indexing_map", @@ -413,8 +551,15 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service/gpu:gpu_fusible", "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu/fusions:fusion_emitter", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -423,16 +568,18 @@ xla_cc_test( srcs = ["coalescing_analysis_test.cc"], deps = [ ":coalescing_analysis", - ":indexing_analysis", + "//xla:shape_util", "//xla/hlo/ir:hlo", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu/fusions", + "//xla/service/gpu/fusions:fusion_emitter", + "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", - "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", - "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], ) @@ -445,15 +592,16 @@ tf_proto_library( protodeps = [ "//xla/service:hlo_proto", ], - visibility = ["//visibility:public"], ) cc_library( name = "hlo_op_profiles", srcs = ["hlo_op_profiles.cc"], - hdrs = ["hlo_op_profiles.h"], + hdrs = [ + "hlo_op_profiles.h", + "hlo_op_profiles_data.h", + ], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], deps = [ ":hlo_op_profile_proto_cc", "//xla:types", @@ -488,7 +636,6 @@ cc_library( srcs = ["hlo_op_profiler.cc"], hdrs = ["hlo_op_profiler.h"], local_defines = if_cuda(["GOOGLE_CUDA"]), - visibility = ["//visibility:public"], deps = [ ":hlo_op_profile_proto_cc", "//xla:debug_options_flags", diff --git a/third_party/xla/xla/service/gpu/model/affine_map_printer.cc b/third_party/xla/xla/service/gpu/model/affine_map_printer.cc index 4687b742b80a63..972da89717e32a 100644 --- a/third_party/xla/xla/service/gpu/model/affine_map_printer.cc +++ b/third_party/xla/xla/service/gpu/model/affine_map_printer.cc @@ -18,8 +18,10 @@ limitations under the License. #include #include #include +#include #include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/raw_ostream.h" #include "mlir/IR/AffineExpr.h" // from @llvm-project @@ -40,6 +42,19 @@ using mlir::AffineSymbolExpr; } // namespace +AffineMapPrinter::AffineMapPrinter( + absl::Span dim_names, + absl::Span symbol_names) { + dim_id_to_name_.reserve(dim_names.size()); + for (const auto& [index, name] : llvm::enumerate(dim_names)) { + dim_id_to_name_[index] = name; + } + symbol_id_to_name_.reserve(symbol_names.size()); + for (const auto& [index, name] : llvm::enumerate(symbol_names)) { + symbol_id_to_name_[index] = name; + } +} + void AffineMapPrinter::Print(std::ostream& out, AffineMap affine_map) const { out << ToString(affine_map); } diff --git a/third_party/xla/xla/service/gpu/model/affine_map_printer.h b/third_party/xla/xla/service/gpu/model/affine_map_printer.h index 9b8b4e2e09d70b..67360a552fe2b9 100644 --- a/third_party/xla/xla/service/gpu/model/affine_map_printer.h +++ b/third_party/xla/xla/service/gpu/model/affine_map_printer.h @@ -19,7 +19,9 @@ limitations under the License. #include #include #include +#include +#include "absl/types/span.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringRef.h" #include "mlir/IR/AffineExpr.h" // from @llvm-project @@ -32,6 +34,12 @@ namespace gpu { // symbol and dimension names. class AffineMapPrinter { public: + AffineMapPrinter() = default; + AffineMapPrinter(AffineMapPrinter&& other) = default; + AffineMapPrinter& operator=(AffineMapPrinter&& other) = default; + AffineMapPrinter(absl::Span dim_names, + absl::Span symbol_names); + void SetSymbolName(int64_t symbol_id, llvm::StringRef name); void SetDimensionName(int64_t dim_id, llvm::StringRef name); diff --git a/third_party/xla/xla/service/gpu/model/analytical_latency_estimator.cc b/third_party/xla/xla/service/gpu/model/analytical_latency_estimator.cc index bbf5aa77cec668..9865c369dfd702 100644 --- a/third_party/xla/xla/service/gpu/model/analytical_latency_estimator.cc +++ b/third_party/xla/xla/service/gpu/model/analytical_latency_estimator.cc @@ -23,6 +23,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_query.h" +#include "xla/service/gpu/model/gpu_collective_performance_model.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/gpu/model/gpu_performance_model.h" #include "xla/service/hlo_cost_analysis.h" diff --git a/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc b/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc index 4073a83c7bf337..ef30f6d65722c8 100644 --- a/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc @@ -15,27 +15,42 @@ limitations under the License. #include "xla/service/gpu/model/coalescing_analysis.h" +#include +#include +#include #include +#include #include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/IR/AffineExpr.h" // from @llvm-project #include "mlir/IR/AffineMap.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout.h" +#include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/gpu_fusible.h" #include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/shape.h" +#include "xla/shape_util.h" namespace xla { namespace gpu { -bool IsReadCoalescedHeuristic(const HloFusionAnalysis& fusion_analysis, +// Returns true if all input reads are coalesced. If consumer is not nullptr, +// producer and consumer are considered as one fusion, otherwise it's only the +// producer. +bool IsReadCoalescedHeuristic(HloFusionAnalysis::EmitterFusionKind fusion_kind, const HloInstruction* producer, const HloInstruction* consumer) { - auto fusion_kind = fusion_analysis.GetEmitterFusionKind(); - // Transposing minor dimension breaks coalescing. if (fusion_kind != HloFusionAnalysis::EmitterFusionKind::kTranspose) { auto is_broadcast = [&](const HloInstruction* instr) { @@ -48,7 +63,6 @@ bool IsReadCoalescedHeuristic(const HloFusionAnalysis& fusion_analysis, instr = instr->operand(0); } }; - auto is_bad_transpose = [&](const HloInstruction* instr) { if (instr->opcode() == HloOpcode::kFusion) { for (auto* instr : instr->fused_instructions()) { @@ -62,50 +76,330 @@ bool IsReadCoalescedHeuristic(const HloFusionAnalysis& fusion_analysis, } return TransposesMinorDimension(instr); }; - if (is_bad_transpose(producer)) return false; if (consumer && is_bad_transpose(consumer)) return false; } - // Fusing two row reductions breaks coalescing. if (fusion_kind == HloFusionAnalysis::EmitterFusionKind::kReduction && IsInputFusibleReduction(*producer) && consumer && IsInputFusibleReduction(*consumer)) { return false; } - return true; } -bool IsReadCoalesced(const HloInstruction* operand, const HloInstruction* instr, - const absl::flat_hash_map& indexing_maps, - mlir::MLIRContext* mlir_context) { - bool is_coalesced = true; - const Shape& output_shape = instr->shape(); - const Shape& operand_shape = operand->shape(); - auto output_physical_to_logical_map = - GetIndexingMapFromPhysicalLayoutToLogical(output_shape, mlir_context); - auto input_logical_to_physical_map = - GetIndexingMapFromLogicalToPhysicalLayout(operand_shape, mlir_context); - for (const auto& indexing_map : indexing_maps.at(operand)) { - if (!indexing_map.has_value()) return false; - - auto normalized_indexing_map = indexing_map; - if (output_physical_to_logical_map.has_value()) { - normalized_indexing_map = ComposeIndexingMaps( - normalized_indexing_map, output_physical_to_logical_map); +namespace { + +using mlir::AffineExpr; +using mlir::AffineMap; +using mlir::getAffineConstantExpr; +using mlir::MLIRContext; + +// Performs backtracking to find all feasible dimensions, symbols that satisfy +// the constraints and then evaluates the affine map at those. +// For example, for the following indexing map: +// (d0)[s0] -> (d0 + s0) +// domain: +// d0 in [0, 3] +// s0 in [0, 1, 2] +// s0 mod 2 in [0, 0] +// The function will compute the following indices [0, 2, 1, 3, 2, 4, 3, 5]. +void FindAllIndices(const IndexingMap& thread_id_to_physical_index, + MLIRContext* mlir_context, int dim_id, int symbol_id, + std::vector* dimensions, + std::vector* symbols, + std::vector* indices) { + if (dim_id < thread_id_to_physical_index.GetDimensionCount()) { + Range dim_range = thread_id_to_physical_index.GetDimensionRange(dim_id); + for (int64_t dim_value = dim_range.lower_bound; + dim_value <= dim_range.upper_bound; ++dim_value) { + dimensions->push_back(getAffineConstantExpr(dim_value, mlir_context)); + FindAllIndices(thread_id_to_physical_index, mlir_context, dim_id + 1, + symbol_id, dimensions, symbols, indices); + dimensions->pop_back(); + } + return; + } + if (symbol_id < thread_id_to_physical_index.GetSymbolCount()) { + Range symbol_range = thread_id_to_physical_index.GetSymbolRange(symbol_id); + for (int64_t symbol_value = symbol_range.lower_bound; + symbol_value <= symbol_range.upper_bound; ++symbol_value) { + symbols->push_back(getAffineConstantExpr(symbol_value, mlir_context)); + FindAllIndices(thread_id_to_physical_index, mlir_context, dim_id, + symbol_id + 1, dimensions, symbols, indices); + symbols->pop_back(); + } + return; + } + if (!thread_id_to_physical_index.ConstraintsSatisfied(*dimensions, + *symbols)) { + return; + } + indices->push_back( + thread_id_to_physical_index.Evaluate(*dimensions, *symbols).front()); +} + +// Computes contiguous intervals of accessed elements. +// For example, for an indexing map +// (thread_x) -> (thread_x * 4 + s0 + (thread_x floordiv 16) * 1984) +// d0 in [0, 31] +// s0 in [0, 3] +// The intervals are [0, 63] and [2047, 2111]. +// TODO(b/325613460): Make it faster than O(number of elements in the domain). +std::vector FindContiguousIntervals( + const IndexingMap& thread_id_to_physical_index) { + CHECK(thread_id_to_physical_index.GetAffineMap().getNumResults() == 1) + << "Expects an affine map that maps to 1D."; + MLIRContext* mlir_context = thread_id_to_physical_index.GetMLIRContext(); + + // Find all linear indices, sort and deduplicate them. + std::vector dimensions, symbols; + std::vector linear_indices; + FindAllIndices(thread_id_to_physical_index, mlir_context, + /*dim_id=*/0, + /*symbol_id=*/0, &dimensions, &symbols, &linear_indices); + std::sort(linear_indices.begin(), linear_indices.end()); + linear_indices.erase( + std::unique(linear_indices.begin(), linear_indices.end()), + linear_indices.end()); + + // Scan over the sorted unique indices and combine them in intervals. + std::vector intervals; + for (int i = 0, start, end; i < linear_indices.size(); ++i) { + start = linear_indices[i++]; + end = start; + while (i < linear_indices.size() && linear_indices[i] == end + 1) { + ++end; + ++i; + } + intervals.push_back(Range{start, end}); + } + return intervals; +} + +int64_t CeilDiv(int64_t a, int64_t b) { return a / b + (a % b != 0); } + +// Approximately estimate the number of memory transactions needed to load all +// elements in every range and compare it with the "ideal" number of memory +// transactions, i.e. total number of elements in all ranges / WarpSize(). +// Note, that later we would need to take the element type into account. +bool EstimateCoalescingViaMemoryTransactionsCount( + absl::Span intervals, PrimitiveType element_type) { + constexpr int64_t kBytesPerMemoryTransaction = 128; + int64_t type_size = ShapeUtil::ByteSizeOfPrimitiveType(element_type); + int memory_transactions = 0; + int total_num_elements = 0; + for (const auto& range : intervals) { + int64_t num_elements = range.upper_bound - range.lower_bound + 1; + memory_transactions += + CeilDiv(num_elements * type_size, kBytesPerMemoryTransaction); + total_num_elements += num_elements; + } + if (memory_transactions == 0) { + return true; + } + int memory_transactions_lower_bound = + CeilDiv(total_num_elements * type_size, kBytesPerMemoryTransaction); + // The magic value chosen by an uneducated guess. + constexpr float kIsCoalescedThreshold = 0.9; + return memory_transactions_lower_bound > + memory_transactions * kIsCoalescedThreshold; +} + +bool IsCoalesced(const IndexingMap& thread_id_to_input_indexing_map, + PrimitiveType element_type) { + // Undefined indexing maps, i.e. those for which we don't know the indexing + // are assumed to be uncoalesced. + if (thread_id_to_input_indexing_map.IsUndefined()) { + return false; + } + // 0d constants are coalesced. + if (thread_id_to_input_indexing_map.GetAffineMap().getNumResults() == 0) { + return true; + } + MLIRContext* mlir_context = thread_id_to_input_indexing_map.GetMLIRContext(); + AffineExpr thread_x_dim = mlir::getAffineDimExpr( + KernelFusionInterface::kIndexingMapThreadIdxDims[0], mlir_context); + AffineExpr c0 = mlir::getAffineConstantExpr(0, mlir_context); + IndexingMap thread_x_first_32_elements{ + AffineMap::get(1, 0, {thread_x_dim, c0, c0, c0, c0, c0}, mlir_context), + {Range{0, 31}}, + {}}; + IndexingMap thread_x_to_linearized_input = + thread_x_first_32_elements * thread_id_to_input_indexing_map; + thread_x_to_linearized_input.Simplify(); + thread_x_to_linearized_input.RemoveUnusedSymbols(); + return EstimateCoalescingViaMemoryTransactionsCount( + FindContiguousIntervals(thread_x_to_linearized_input), element_type); +} + +// Returns a linearized shape, i.e. tensor. +Shape GetLinearizedShape(const Shape& shape) { + if (shape.rank() == 0) { + return shape; + } + std::vector dims{ShapeUtil::ElementsIn(shape)}; + auto result = Shape(shape.element_type(), dims, + absl::InlinedVector(dims.size(), false), {}); + *result.mutable_layout() = xla::Layout({0}); + return result; +} + +// Returns thread ID to linearized physical layout indexing map for each operand +// of the fusion. +std::optional GetThreadIdToInputMemoryLayoutsMaps( + const HloFusionAdaptor& fusion_adaptor, + absl::Span operands, + const HloFusionAnalysis& fusion_analysis, + KernelFusionInterface* fusion_interface, mlir::MLIRContext* mlir_context) { + GroupedByOpIndexingMap result; + for (const auto& [root_index, hero] : + llvm::enumerate(fusion_analysis.fusion_heroes())) { + for (const auto& [hero_operand_index, hero_operand] : + llvm::enumerate(hero->operands())) { + if (hero_operand->shape().rank() == 0) { + continue; + } + // Compute thread ID -> hero operand indexing map. + std::optional thread_id_to_hero_operand_map = + fusion_interface->ComputeThreadIdToInputIndexing( + root_index, hero_operand_index, mlir_context); + if (!thread_id_to_hero_operand_map.has_value()) { + return std::nullopt; + } + // Compute indexing from output to inputs for logical layout. + HloInstructionAdaptor hero_operand_adaptor(*hero_operand); + GroupedByOpIndexingMap instr_indexing_keyed_by_operands = + ComputeGroupedOutputToInputIndexing( + fusion_adaptor, hero_operand_adaptor, mlir_context); + // For every operand compute thread ID -> physical layout of operand + // indexing map. + for (const HloInstruction* operand : operands) { + auto operand_indexing_maps_it = + instr_indexing_keyed_by_operands.find(operand); + if (operand_indexing_maps_it == + instr_indexing_keyed_by_operands.end()) { + continue; + } + const Shape& operand_shape = operand->shape(); + + IndexingMap operand_logical_to_physical_map = + GetIndexingMapFromLogicalToPhysicalLayout(operand_shape, + mlir_context); + IndexingMap operand_physical_to_linearized_shape = GetBitcastMap( + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + operand_shape), + GetLinearizedShape(operand_shape), mlir_context); + IndexingMap operand_logical_to_linearized_physical_shape = + operand_logical_to_physical_map * + operand_physical_to_linearized_shape; + operand_logical_to_linearized_physical_shape.Simplify(); + + for (const IndexingMap& operand_indexing_map : + operand_indexing_maps_it->second) { + // If one of the indexing maps for the operand is undefined, we remove + // all indexing maps for it and store only the undefined one. + if (operand_indexing_map.IsUndefined()) { + result[operand] = {operand_indexing_map}; + break; + } + IndexingMap logical_output_to_linearized_physical_input_map = + operand_indexing_map * + operand_logical_to_linearized_physical_shape; + IndexingMap thread_id_to_linearized_physical_input_map = + *thread_id_to_hero_operand_map * + logical_output_to_linearized_physical_input_map; + thread_id_to_linearized_physical_input_map.Simplify(); + result[operand].insert(thread_id_to_linearized_physical_input_map); + } + } + } + } + return result; +} + +} // namespace + +CoalescingAnalysis::CoalescingAnalysis( + const HloInstruction* instr, + absl::Span operands, + const HloFusionAnalysis& fusion_analysis, + KernelFusionInterface* fusion_interface, mlir::MLIRContext* mlir_context, + bool use_heuristic) { + auto fusion_adaptor = HloFusionAdaptor::ForInstruction(instr); + if (!use_heuristic && ComputeCoalescingForAllOperands( + *fusion_adaptor, operands, fusion_analysis, + fusion_interface, mlir_context)) { + return; + } + // If ComputeCoalescingForAllOperands fails, fallback to using the heuristic. + is_coalesced_computed_by_heuristic_ = + IsReadCoalescedHeuristic(fusion_analysis.GetEmitterFusionKind(), instr); +} + +CoalescingAnalysis::CoalescingAnalysis( + const HloInstruction* producer, const HloInstruction* consumer, + absl::Span operands, + const HloFusionAnalysis& fusion_analysis, + KernelFusionInterface* fusion_interface, mlir::MLIRContext* mlir_context, + bool use_heuristic) { + ProducerConsumerFusion fusion_adaptor(producer, consumer); + if (!use_heuristic && + ComputeCoalescingForAllOperands(fusion_adaptor, operands, fusion_analysis, + fusion_interface, mlir_context)) { + return; + } + // If ComputeCoalescingForAllOperands fails, fallback to using the heuristic. + is_coalesced_computed_by_heuristic_ = IsReadCoalescedHeuristic( + fusion_analysis.GetEmitterFusionKind(), producer, consumer); +} + +bool CoalescingAnalysis::ComputeCoalescingForAllOperands( + const HloFusionAdaptor& fusion_adaptor, + absl::Span operands, + const HloFusionAnalysis& fusion_analysis, + KernelFusionInterface* fusion_interface, mlir::MLIRContext* mlir_context) { + std::optional thread_id_to_input_memory_layouts = + GetThreadIdToInputMemoryLayoutsMaps(fusion_adaptor, operands, + fusion_analysis, fusion_interface, + mlir_context); + if (!thread_id_to_input_memory_layouts.has_value()) { + return false; + } + for (const HloInstruction* operand : operands) { + if (operand->shape().rank() == 0) { + coalescing_per_operand_.insert({operand, true}); + continue; + } + auto operand_indexing_maps = + thread_id_to_input_memory_layouts->find(operand); + // If there is no indexing map for the operand, it means that it is not used + // in the fusion cluster. + if (operand_indexing_maps == thread_id_to_input_memory_layouts->end()) { + coalescing_per_operand_.insert({operand, true}); + continue; } - if (input_logical_to_physical_map.has_value()) { - normalized_indexing_map = ComposeIndexingMaps( - input_logical_to_physical_map, normalized_indexing_map); + for (const IndexingMap& operand_indexing_map : + operand_indexing_maps->second) { + bool is_coalesced = + IsCoalesced(operand_indexing_map, operand->shape().element_type()); + auto [it, inserted] = + coalescing_per_operand_.insert({operand, is_coalesced}); + if (!inserted) { + it->second &= is_coalesced; + } + if (!is_coalesced) break; } - // First version is naive, we just check that the affine maps of input and - // output have the same minor dimension. - is_coalesced &= - normalized_indexing_map->affine_map.isMinorIdentityWithBroadcasting(); } - return is_coalesced; + return true; +} + +bool CoalescingAnalysis::IsReadCoalesced(const HloInstruction* operand) const { + auto it = coalescing_per_operand_.find(operand); + if (it == coalescing_per_operand_.end()) { + return is_coalesced_computed_by_heuristic_; + } + return it->second; } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/model/coalescing_analysis.h b/third_party/xla/xla/service/gpu/model/coalescing_analysis.h index 3ea6e543c179b0..300036aa453bae 100644 --- a/third_party/xla/xla/service/gpu/model/coalescing_analysis.h +++ b/third_party/xla/xla/service/gpu/model/coalescing_analysis.h @@ -16,30 +16,61 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_MODEL_COALESCING_ANALYSIS_H_ #define XLA_SERVICE_GPU_MODEL_COALESCING_ANALYSIS_H_ -#include - #include "absl/container/flat_hash_map.h" +#include "absl/types/span.h" #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/hlo_traversal.h" namespace xla { namespace gpu { +// Computes read coalescing for operands of an instruction or a +// producer-consumer fusion. +// Note, that later, after we migrate away from using the heuristic, we might +// want to use HloFusionAdaptor instead of having two different constructors. +class CoalescingAnalysis { + public: + // Computes read coalescing for operands of `instr`. + CoalescingAnalysis(const HloInstruction* instr, + absl::Span operands, + const HloFusionAnalysis& fusion_analysis, + KernelFusionInterface* fusion_interface = nullptr, + mlir::MLIRContext* mlir_context = nullptr, + bool use_heuristic = true); + + // Computes read coalescing for operands of fused `producer` and `consumer`. + CoalescingAnalysis(const HloInstruction* producer, + const HloInstruction* consumer, + absl::Span operands, + const HloFusionAnalysis& fusion_analysis, + KernelFusionInterface* fusion_interface = nullptr, + mlir::MLIRContext* mlir_context = nullptr, + bool use_heuristic = true); + + // Returns true if the operand is read coalesced. + bool IsReadCoalesced(const HloInstruction* operand) const; + + private: + bool ComputeCoalescingForAllOperands( + const HloFusionAdaptor& fusion_adaptor, + absl::Span operands, + const HloFusionAnalysis& fusion_analysis, + KernelFusionInterface* fusion_interface, mlir::MLIRContext* mlir_context); + + absl::flat_hash_map coalescing_per_operand_; + bool is_coalesced_computed_by_heuristic_ = false; +}; + // Returns true if all input reads are coalesced. If consumer is not nullptr, // producer and consumer are considered as one fusion, otherwise it's only the // producer. -bool IsReadCoalescedHeuristic(const HloFusionAnalysis& fusion_analysis, +bool IsReadCoalescedHeuristic(HloFusionAnalysis::EmitterFusionKind fusion_kind, const HloInstruction* producer, const HloInstruction* consumer = nullptr); -// Returns true, if operand's read is coalesced. -bool IsReadCoalesced(const HloInstruction* operand, const HloInstruction* instr, - const absl::flat_hash_map& indexing_maps, - mlir::MLIRContext* mlir_context); - } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc b/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc index f01372b32d4759..5b250330bb8120 100644 --- a/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc @@ -15,107 +15,361 @@ limitations under the License. #include "xla/service/gpu/model/coalescing_analysis.h" +#include +#include #include #include -#include "absl/log/check.h" #include "absl/strings/string_view.h" -#include "llvm/ADT/STLExtras.h" #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/fusions/fusion_emitter.h" +#include "xla/service/gpu/fusions/fusions.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/hlo_traversal.h" -#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { namespace gpu { namespace { -using ::testing::ElementsAreArray; +using ::testing::ElementsAre; class CoalescingTest : public HloTestBase { public: - void GetRoot(absl::string_view hlo_string, - absl::Span expected_results) { - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + std::vector IsReadCoalescedPerOperand(absl::string_view hlo_string) { + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); HloInstruction* root = module->entry_computation()->root_instruction(); + return IsReadCoalescedPerOperand(root); + } - for (auto* operand : root->operands()) { - CHECK(operand->opcode() == HloOpcode::kParameter || - operand->opcode() == HloOpcode::kConstant) - << "If there are multiple instructions, they need to be wrapped in a " - "fusion."; - } + std::vector IsReadCoalescedPerOperand(const HloInstruction* root) { auto fusion_adaptor = HloFusionAdaptor::ForInstruction(root); + auto analysis = AnalyzeFusion(*root, device_info_); + auto emitter = GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis}); + auto fusion = dynamic_cast(emitter.value().get()); + EXPECT_TRUE(emitter.ok()); - auto output_to_input_indexing = - ComputeOutputToInputIndexing(root, /*output_id=*/0, &mlir_context_); - auto grouped_indexing_maps = - GroupIndexingMapsByProducers(output_to_input_indexing, root); + CoalescingAnalysis coalescing_analysis(root, root->operands(), analysis, + fusion, &mlir_context_, + /*use_heuristic=*/false); - std::vector actual_results; - actual_results.reserve(expected_results.size()); - for (auto [operand_id, is_coalesced] : llvm::enumerate(expected_results)) { - auto* operand = root->operand(operand_id); - actual_results.push_back(IsReadCoalesced( - operand, root, grouped_indexing_maps, &mlir_context_)); + std::vector results; + for (const HloInstruction* operand : root->operands()) { + results.push_back(coalescing_analysis.IsReadCoalesced(operand)); } - EXPECT_THAT(actual_results, ElementsAreArray(expected_results)); + return results; } + protected: + stream_executor::DeviceDescription device_info_ = + TestGpuDeviceInfo::RTXA6000DeviceInfo(); mlir::MLIRContext mlir_context_; }; TEST_F(CoalescingTest, IdentityLayout) { - GetRoot(R"( + absl::string_view ir = R"( HloModule m + fusion { + p0 = f32[100, 200] parameter(0) + p1 = f32[100, 200] parameter(1) + ROOT adthread_x = f32[100, 200] add(p0, p1) + } ENTRY e { - p0 = f32[10, 20] parameter(0) - p1 = f32[10, 20] parameter(1) - ROOT add0 = f32[10, 20] add(p0, p1) + p0 = f32[100, 200] parameter(0) + p1 = f32[100, 200] parameter(1) + ROOT fusion = f32[100, 200] fusion(p0, p1), kind=kInput, calls=fusion } - )", - {true, true}); + )"; + // thread_x to linearized input mapping for thread_x in [0, 31]: + // Operand 1: (thread_x) -> (thread_x) + // Operand 2: (thread_x) -> (thread_x) + EXPECT_THAT(IsReadCoalescedPerOperand(ir), ElementsAre(true, true)); } TEST_F(CoalescingTest, RhsTransposedLayout) { - GetRoot(R"( + absl::string_view ir = R"( HloModule m + fusion { + p0 = f32[100, 200]{1, 0} parameter(0) + p1 = f32[100, 200]{0, 1} parameter(1) + ROOT exp = f32[100, 200]{1, 0} add(p0, p1) + } ENTRY e { - p0 = f32[10, 20]{1, 0} parameter(0) - p1 = f32[10, 20]{0, 1} parameter(1) - ROOT exp = f32[10, 20]{1, 0} add(p0, p1) + p0 = f32[100, 200]{1, 0} parameter(0) + p1 = f32[100, 200]{0, 1} parameter(1) + ROOT fusion = f32[100, 200]{1, 0} fusion(p0, p1), kind=kInput, calls=fusion } - )", - {true, false}); + )"; + // thread_x to linearized input mapping for thread_x in [0, 31]: + // Operand 1: (thread_x) -> (thread_x) + // Operand 2: (thread_x) -> (thread_x * 100) + EXPECT_THAT(IsReadCoalescedPerOperand(ir), ElementsAre(true, false)); } TEST_F(CoalescingTest, OutputTransposedLayout) { - GetRoot(R"( + absl::string_view ir = R"( HloModule m + fusion { + p0 = f32[100, 200]{1, 0} parameter(0) + p1 = f32[100, 200]{1, 0} parameter(1) + ROOT exp = f32[100, 200]{0, 1} add(p0, p1) + } ENTRY e { - p0 = f32[10, 20]{1, 0} parameter(0) - p1 = f32[10, 20]{1, 0} parameter(1) - ROOT exp = f32[10, 20]{0, 1} add(p0, p1) + p0 = f32[100, 200]{1, 0} parameter(0) + p1 = f32[100, 200]{1, 0} parameter(1) + ROOT fusion = f32[100, 200]{0, 1} fusion(p0, p1), kind=kInput, calls=fusion } - )", - {false, false}); + )"; + // thread_x to linearized input mapping for thread_x in [0, 31]: + // Operand 1: (thread_x) -> (thread_x * 200) + // Operand 2: (thread_x) -> (thread_x * 200) + EXPECT_THAT(IsReadCoalescedPerOperand(ir), ElementsAre(false, false)); } TEST_F(CoalescingTest, OutputAndLhsTransposedLayout) { - GetRoot(R"( + absl::string_view ir = R"( HloModule m + fusion { + p0 = f32[100, 200]{1, 0} parameter(0) + p1 = f32[100, 200]{0, 1} parameter(1) + ROOT exp = f32[100, 200]{1, 0} add(p0, p1) + } ENTRY e { - p0 = f32[10, 20]{1, 0} parameter(0) - p1 = f32[10, 20]{0, 1} parameter(1) - ROOT exp = f32[10, 20]{1, 0} add(p0, p1) + p0 = f32[100, 200]{1, 0} parameter(0) + p1 = f32[100, 200]{0, 1} parameter(1) + ROOT fusion = f32[100, 200]{1, 0} fusion(p0, p1), kind=kInput, calls=fusion + } + )"; + // thread_x to linearized input mapping for thread_x in [0, 31]: + // Operand 1: (thread_x) -> (thread_x) + // Operand 2: (thread_x) -> (thread_x * 100) + EXPECT_THAT(IsReadCoalescedPerOperand(ir), ElementsAre(true, false)); +} + +TEST_F(CoalescingTest, Transpose) { + absl::string_view ir = R"( + HloModule module + + fusion { + %input = f32[100, 64, 32] parameter(0) + ROOT transpose = f32[32, 100, 64] transpose(%input), dimensions={2, 0, 1} + } + + ENTRY entry { + %input = f32[100, 64, 32] parameter(0) + ROOT %fusion = f32[32, 100, 64] fusion(%input), kind=kLoop, calls=fusion + })"; + // thread_x to linearized input mapping for thread_x in [0, 31]: + // Operand 1: (thread_x)[s0] -> (thread_x + s0 * 128) for s0 in [0, 7] + EXPECT_THAT(IsReadCoalescedPerOperand(ir), ElementsAre(true)); +} + +TEST_F(CoalescingTest, TransposeOnlyOuterDims) { + absl::string_view ir = R"( + HloModule module + + fusion { + %input = f32[100, 32, 64] parameter(0) + ROOT transpose = f32[32, 100, 64] transpose(%input), dimensions={1, 0, 2} + } + + ENTRY entry { + %input = f32[100, 32, 64] parameter(0) + ROOT %fusion = f32[32, 100, 64] fusion(%input), kind=kLoop, calls=fusion + })"; + // thread_x to linearized input mapping for thread_x in [0, 31]: + // Operand 1: + // (thread_x) -> (thread_x * 4 + s0 + (thread_x floordiv 16) * 1984) + // for s0 in [0, 3] + EXPECT_THAT(IsReadCoalescedPerOperand(ir), ElementsAre(true)); +} + +TEST_F(CoalescingTest, PadOp) { + absl::string_view ir = R"( + HloModule module + fusion { + p0 = f32[997, 436] parameter(0) + p1 = f32[] parameter(1) + ROOT pad = f32[1024, 512] pad(p0, p1), padding=10_17x24_52 + } + ENTRY entry { + p0 = f32[997, 436] parameter(0) + p1 = f32[] parameter(1) + ROOT %fusion = f32[1024, 512] fusion(p0, p1), kind=kLoop, calls=fusion + })"; + // thread_x to linearized input mapping for thread_x in [0, 31]: + // Operand 1: (thread_x)[s0] -> (thread_x * 4 + s0 - 4384) + // for s0 in [0, 3] and thread_x * 4 + s0 in [24, 459] + // Operand 2: (thread_x) -> () + EXPECT_THAT(IsReadCoalescedPerOperand(ir), ElementsAre(true, true)); +} + +TEST_F(CoalescingTest, RowReduction) { + absl::string_view ir = R"( + HloModule module + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + fusion { + %input = f32[100,64,512] parameter(0) + %c0 = f32[] constant(0) + ROOT reduce = f32[100,64] reduce(%input, %c0), dimensions={2}, to_apply=add } - )", - {true, false}); + ENTRY entry { + %input = f32[100,64,512] parameter(0) + ROOT %fusion = f32[100,64] fusion(%input), kind=kInput, calls=fusion + })"; + // thread_x to linearized input mapping for thread_x in [0, 31]: + // Operand 1: (thread_x)[s0] -> (thread_x + s0 * 32) for s0 in [0, 15] + EXPECT_THAT(IsReadCoalescedPerOperand(ir), ElementsAre(true)); +} + +TEST_F(CoalescingTest, MultiRowReduction) { + absl::string_view ir = R"( + HloModule module + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + fusion { + %input = f32[100,64,4] parameter(0) + %c0 = f32[] constant(0) + ROOT reduce = f32[100,64] reduce(%input, %c0), dimensions={2}, to_apply=add + } + ENTRY entry { + %input = f32[100,64,4] parameter(0) + ROOT %fusion = f32[100,64] fusion(%input), kind=kInput, calls=fusion + })"; + // thread_x to linearized input mapping for thread_x in [0, 31]: + // Operand 1: (thread_x) -> (thread_x) + EXPECT_THAT(IsReadCoalescedPerOperand(ir), ElementsAre(true)); +} + +TEST_F(CoalescingTest, ColumnReduction) { + absl::string_view ir = R"( + HloModule module + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + fusion { + %input = f32[100,64,32] parameter(0) + %c0 = f32[] constant(0) + ROOT reduce = f32[100,32] reduce(%input, %c0), + dimensions={1}, to_apply=add + } + ENTRY entry { + %input = f32[100,64,32] parameter(0) + ROOT %fusion = f32[100,32] fusion(%input), kind=kInput, calls=fusion + })"; + // thread_x to linearized input mapping for thread_x in [0, 31]: + // Operand 1: (thread_x)[s0] -> (thread_x + s0 * 1024) for s0 in [0, 1] + EXPECT_THAT(IsReadCoalescedPerOperand(ir), ElementsAre(true)); +} + +TEST_F(CoalescingTest, VariadicReduceViaLoopEmitter) { + absl::string_view ir = R"( + HloModule module + max { + p0 = s32[] parameter(0) + p1 = s32[] parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + max01 = s32[] maximum(p0, p1) + max23 = s32[] maximum(p2, p3) + ROOT max = (s32[], s32[]) tuple(max01, max23) + } + fusion { + p0 = s32 [5696,10,4] parameter(0) + p1 = s32 [5696,10,4] parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + ROOT reduce = (s32[5696,4], s32[5696,4]) reduce(s32[5696,10,4] p0, + s32[5696,10,4] p1, s32[] p2, s32[] p3), dimensions={1}, to_apply=max + } + ENTRY entry { + p0 = s32 [5696,10,4] parameter(0) + p1 = s32 [5696,10,4] parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + ROOT f = (s32[5696,4], s32[5696,4]) fusion(p0, p1, p2, p3), + kind=kInput, calls=fusion + })"; + EXPECT_THAT(IsReadCoalescedPerOperand(ir), + ElementsAre(true, true, true, true)); +} + +TEST_F(CoalescingTest, VariadicReduceViaReductionEmitter) { + absl::string_view ir = R"( + HloModule module + max { + p0 = s32[] parameter(0) + p1 = s32[] parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + max01 = s32[] maximum(p0, p1) + max23 = s32[] maximum(p2, p3) + ROOT max = (s32[], s32[]) tuple(max01, max23) + } + fusion { + p0 = s32[32,40] parameter(0) + p1 = s32[32,40] parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + ROOT reduce = (s32[32], s32[32]) + reduce(s32[32,40] p0, s32[32,40] p1, s32[] p2, s32[] p3), + dimensions={1}, to_apply=max + } + ENTRY entry { + p0 = s32[32,40] parameter(0) + p1 = s32[32,40] parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + ROOT f = (s32[32], s32[32]) fusion(p0, p1, p2, p3), + kind=kInput, calls=fusion + })"; + EXPECT_THAT(IsReadCoalescedPerOperand(ir), + ElementsAre(true, true, true, true)); +} + +TEST_F(CoalescingTest, UnusedParameter) { + Shape shape = ShapeUtil::MakeShape(F32, {100000}); + + auto module = std::make_unique("m", HloModuleConfig{}); + HloComputation::Builder b("b"); + auto p0 = b.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); + auto p1 = b.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1")); + + HloComputation::Builder sub_builder("subcomp"); + HloInstruction* p0f = sub_builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "p0f")); + // p1f is not used. + HloInstruction* p1f = sub_builder.AddInstruction( + HloInstruction::CreateParameter(1, shape, "p1f")); + ASSERT_NE(p1f, nullptr); + sub_builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0f)); + + HloComputation* subcomp = module->AddEmbeddedComputation(sub_builder.Build()); + auto fusion = HloInstruction::CreateFusion( + shape, HloInstruction::FusionKind::kLoop, {p0, p1}, subcomp); + b.AddInstruction(std::move(fusion)); + module->AddEntryComputation(b.Build()); + + EXPECT_THAT(IsReadCoalescedPerOperand( + module->entry_computation()->root_instruction()), + ElementsAre(true, true)); } } // namespace diff --git a/third_party/xla/xla/service/gpu/model/gpu_collective_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_collective_performance_model.cc new file mode 100644 index 00000000000000..188426ece72a77 --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/gpu_collective_performance_model.cc @@ -0,0 +1,304 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/gpu_collective_performance_model.h" + +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/numbers.h" +#include "absl/time/time.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/service/hlo_dataflow_analysis.h" +#include "xla/stream_executor/device_description.h" +#include "xla/util.h" + +#if GOOGLE_CUDA +#include "third_party/gpus/cuda/nvml/include/nvml.h" +#endif // GOOGLE_CUDA +namespace xla { +namespace gpu { + +namespace { + +int64_t GetNcclMaxNumChannels( + GpuPerformanceWithCollectiveModel::CollectiveAlgo algorithm) { + int64_t max_nchannels = 0; + switch (algorithm) { + // Tree and Ring algos share the same max channel number. + case GpuPerformanceWithCollectiveModel::RING: + case GpuPerformanceWithCollectiveModel::TREE: + max_nchannels = GpuPerformanceWithCollectiveModel::kMaxNumChannelsRing; + break; + } + const char* env = std::getenv("NCCL_MAX_NCHANNELS"); + if (env != nullptr) { + int64_t max_nchannels_from_env; + if (absl::SimpleAtoi(env, &max_nchannels_from_env)) { + max_nchannels = std::min(max_nchannels_from_env, max_nchannels); + } + } + return max_nchannels; +} + +int64_t GetMinNumberOfChannels( + GpuPerformanceWithCollectiveModel::CollectiveAlgo algorithm) { + int64_t min_nchannels = 0; + switch (algorithm) { + // Tree and Ring algos share the same min channel number. + case GpuPerformanceWithCollectiveModel::RING: + case GpuPerformanceWithCollectiveModel::TREE: + min_nchannels = 1; + break; + } + const char* env = std::getenv("NCCL_MIN_NCHANNELS"); + if (env != nullptr) { + int64_t min_nchannels_from_env; + if (absl::SimpleAtoi(env, &min_nchannels_from_env)) { + min_nchannels = std::min(min_nchannels_from_env, min_nchannels); + } + } + return min_nchannels; +} + +int GetNumThreads(int warp_size, int min_num_threads, int max_num_threads, + int default_num_threads) { + int threads_from_env = default_num_threads; + const char* env = std::getenv("NCCL_NTHREADS"); + if (env != nullptr) { + CHECK(absl::SimpleAtoi(env, &threads_from_env)); + } + int num_threads = threads_from_env; + if (num_threads > 0) { + if (num_threads % warp_size != 0) { + num_threads = max_num_threads; + } else if (num_threads > max_num_threads) { + num_threads = max_num_threads; + } else if (num_threads < min_num_threads) { + num_threads = min_num_threads; + } + } else { + num_threads = default_num_threads; + } + return num_threads; +} + +float GetMaxSysBwFromGpu(const se::CudaComputeCapability cc, + const double* bandwidths_table) { + switch (cc.major) { + case se::CudaComputeCapability::VOLTA: + return bandwidths_table[0]; + case se::CudaComputeCapability::AMPERE: + return bandwidths_table[1]; + case se::CudaComputeCapability::HOPPER: + return bandwidths_table[2]; + } + return -1; +} + +} // namespace + +// Returns NVLink bw in GB/s +/*static*/ +float GpuPerformanceWithCollectiveModel::GetNvlinkBw( + se::CudaComputeCapability compute_capability) { + return compute_capability.IsAtLeast(se::CudaComputeCapability::HOPPER) + ? kSm90NvlinkBandwidth + : compute_capability.IsAtLeast(se::CudaComputeCapability::AMPERE) + ? kSm80NvlinkBandwidth + : compute_capability.IsAtLeast(se::CudaComputeCapability::VOLTA) + ? kSm70NvlinkBandwidth + : compute_capability.IsAtLeast(se::CudaComputeCapability::PASCAL_) + ? kSm60NvlinkBandwidth + : kSm80NvlinkBandwidth; +} + +/*static*/ bool GpuPerformanceWithCollectiveModel::InitNvml() { +#if GOOGLE_CUDA + void* libhandle = dlopen("libnvidia-ml.so.1", RTLD_NOW); + CHECK(libhandle != nullptr) << "Failed to open libnvidia-ml.so.1"; + + struct SymbolEntry { + void** functor; + char const* name; + }; + + std::vector symbols = { + {(void**)&xla_nvmlInit, "nvmlInit_v2"}, + {(void**)&xla_nvmlShutdown, "nvmlShutdown"}, + {(void**)&xla_nvmlDeviceGetHandleByIndex, "nvmlDeviceGetHandleByIndex"}, + {(void**)&xla_nvmlDeviceGetNvLinkCapability, + "nvmlDeviceGetNvLinkCapability"}, + }; + for (SymbolEntry se : symbols) { + *se.functor = dlsym(libhandle, se.name); + } + nvmlReturn_t init_result = xla_nvmlInit(); + return init_result == NVML_SUCCESS; +#else + return false; +#endif // GOOGLE_CUDA +} + +/*static*/ bool GpuPerformanceWithCollectiveModel::ShutdownNvml() { +#if GOOGLE_CUDA + nvmlReturn_t shutdown_result = xla_nvmlShutdown(); + return shutdown_result == NVML_SUCCESS; +#else + return false; +#endif // GOOGLE_CUDA +} + +/*static*/ uint32_t +GpuPerformanceWithCollectiveModel::CheckIfNvlinkSupportsP2P() { +#if GOOGLE_CUDA + // We will use nvml library to detect nvlink capability + // to see if it supports p2p communication. + // We first load libnvidia-ml.so and assign symbols to function pointers + // to avoid linking errors. + // Then gpu 0 will be used to query for nvlink capability, note that + // we only look at link 0 of gpu 0 since all other links are assumed + // to have the same capability. + CHECK(InitNvml()) << "NVML init failed."; + nvmlDevice_t nvml_device; + nvmlReturn_t get_device_result = + xla_nvmlDeviceGetHandleByIndex(0, &nvml_device); + CHECK(get_device_result == NVML_SUCCESS); + + uint32_t supported_p2p = 0; + + nvmlReturn_t nvlink_cap_result = xla_nvmlDeviceGetNvLinkCapability( + nvml_device, /*nvlink link number*/ 0, NVML_NVLINK_CAP_P2P_SUPPORTED, + &supported_p2p); + CHECK(nvlink_cap_result == NVML_SUCCESS); + CHECK(ShutdownNvml()) << "NVML shutdown failed."; + return supported_p2p; +#else + return 0; +#endif // GOOGLE_CUDA +} + +/*static*/ absl::Duration +GpuPerformanceWithCollectiveModel::ComputeAllreduceTime( + const HloInstruction& instr, const GpuHloCostAnalysis* cost_analysis, + const se::DeviceDescription& gpu_device_info) { + // We use nccl group call to launch multiple allreduces so launch overhead + // only occurs once. + absl::Duration total_time = kNcclKernelLaunchOverhead; + stream_executor::CudaComputeCapability compute_cap = + gpu_device_info.cuda_compute_capability(); + + int64_t size_of_speed_array = kIntraNodeSpeeds.size(); + int64_t size_of_sm90_speed_array = kIntraNodeSpeedsSm90.size(); + + int num_speeds = compute_cap.major >= se::CudaComputeCapability::HOPPER + ? size_of_sm90_speed_array + : size_of_speed_array; + const double* speeds = compute_cap.major >= se::CudaComputeCapability::HOPPER + ? kIntraNodeSpeedsSm90.data() + : kIntraNodeSpeeds.data(); + + int speed_index = 0; + float max_sys_bw = + GetMaxSysBwFromGpu(compute_cap, kLowLatencyMaxBandwidths.data()); + + CHECK_GT(max_sys_bw, 0); + + while ((speed_index < num_speeds - 1) && speeds[speed_index] > max_sys_bw) { + speed_index++; + } + float bw_intra_node = speeds[speed_index]; + int64_t num_devices = cost_analysis->NumOfDevices(instr); + + int64_t min_nchannels = + std::max(num_devices, GetMinNumberOfChannels(CollectiveAlgo::RING)); + int64_t num_channels = + std::max(min_nchannels, GetNcclMaxNumChannels(CollectiveAlgo::RING)); + int default_threads = + (bw_intra_node * num_channels <= kPciBandwidth) ? 256 : kLL128NumThreads; + + int warp_size = gpu_device_info.threads_per_warp(); + int num_threads = GetNumThreads(warp_size, kLL128NumThreads / 4, + kLL128NumThreads, default_threads); + + // Since channels are pipelined together, compute time will only occur as in a + // single channel. + absl::Duration compute_time_per_channel = + ComputeTime(gpu_device_info, + cost_analysis->flop_count(instr) / num_channels, num_threads); + total_time += compute_time_per_channel; + + uint32_t supported_p2p = CheckIfNvlinkSupportsP2P(); + + if (supported_p2p == 0) { + VLOG(8) << "Nvlink doesn't support p2p communication. Model will " + "continue using default system bandwidth."; + } else { + VLOG(8) << "Nvlink supports p2p communication, setting intra node " + "bandwidth to nvlink bw."; + bw_intra_node = GetNvlinkBw(compute_cap); + } + + double bus_bandwidth = bw_intra_node * num_channels; + + // Get per channel LL128 ring bandwidth + double per_channel_ring_ll128_Bw = + GetMaxSysBwFromGpu(compute_cap, kPerChannelMaxRingLL128Bandwidths.data()); + + bus_bandwidth = std::min(bus_bandwidth * kRingAlgorithmDiscountFactor, + num_channels * per_channel_ring_ll128_Bw); + double actual_bandwidth = bus_bandwidth * cost_analysis->ScalingRatio(instr); + + absl::Duration communication_time = absl::Microseconds( + cost_analysis->bytes_accessed(instr) / (1e6 * actual_bandwidth)); + total_time += communication_time; + return total_time; +} + +/*static*/ absl::Duration +GpuPerformanceWithCollectiveModel::ComputeCollectiveTime( + const HloInstruction& instr, const GpuHloCostAnalysis* cost_analysis, + const se::DeviceDescription& gpu_device_info) { + if (cost_analysis->NumOfDevices(instr) == 1) { + VLOG(8) << "Returning only kernel launch overhead for a single partition."; + return kNcclKernelLaunchOverhead; + } + + if (HloDataflowAnalysis::IsAsynchronousOperationDone(instr.opcode())) { + VLOG(8) << "Returning 0 cost for async done op " << instr.name(); + return absl::ZeroDuration(); + } + switch (instr.opcode()) { + case HloOpcode::kAllReduce: + case HloOpcode::kAllReduceStart: + return ComputeAllreduceTime(instr, cost_analysis, gpu_device_info); + default: { + LOG(WARNING) + << "Runtime estimate for " << instr.name() + << " not implemented. Returning only the kernel launch time."; + return kNcclKernelLaunchOverhead; + } + } +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/gpu_collective_performance_model.h b/third_party/xla/xla/service/gpu/model/gpu_collective_performance_model.h new file mode 100644 index 00000000000000..c11a78c684e80d --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/gpu_collective_performance_model.h @@ -0,0 +1,128 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MODEL_GPU_COLLECTIVE_PERFORMANCE_MODEL_H_ +#define XLA_SERVICE_GPU_MODEL_GPU_COLLECTIVE_PERFORMANCE_MODEL_H_ + +#include +#include + +#include "absl/time/time.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/service/gpu/model/gpu_performance_model_base.h" +#include "xla/stream_executor/device_description.h" + +#if GOOGLE_CUDA +#include + +#include "third_party/gpus/cuda/nvml/include/nvml.h" +// Below is a list of function pointers to be used +// for querying device properties through nvml library. +#define NVML_FUNCTOR(name, rettype, args) \ + inline rettype(*xla_##name) args = nullptr; + +NVML_FUNCTOR(nvmlInit, nvmlReturn_t, ()) +NVML_FUNCTOR(nvmlShutdown, nvmlReturn_t, ()) +NVML_FUNCTOR(nvmlDeviceGetHandleByIndex, nvmlReturn_t, + (unsigned int index, nvmlDevice_t* device)) +NVML_FUNCTOR(nvmlDeviceGetNvLinkCapability, nvmlReturn_t, + (nvmlDevice_t device, unsigned int link, + nvmlNvLinkCapability_t capability, unsigned int* capResult)) + +#endif + +namespace xla { +namespace gpu { + +class GpuPerformanceWithCollectiveModel : public GpuPerformanceModelBase { + public: + // Different algorithms that can be used to perform the collective. + enum CollectiveAlgo { + RING = 0, + TREE, + }; + + // Table for max system bandwidths GB/s for using NCCL's low latency + // algorithm. This is used for intra-node estimate. + static constexpr std::array kLowLatencyMaxBandwidths = { + 39.0 /* Volta*/, 87.7 /* Ampere*/, 87.7 /* Hopper*/ + }; + + // Max bandwidth in GB/s for ring low latency 128 algorithm per channel on a + // single-node + static constexpr std::array kPerChannelMaxRingLL128Bandwidths = { + 20.0 /* Volta */, + 20.0 /* Ampere */, + 36.7 /* Hopper */, + }; + + // Nvlink unidirectional bandwidth for different compute cap. Note this is per + // lane bandwidth. + static constexpr double kSm60NvlinkBandwidth = 18.0; + static constexpr double kSm70NvlinkBandwidth = 20.0; + static constexpr double kSm80NvlinkBandwidth = 20.0; + static constexpr double kSm90NvlinkBandwidth = 20.0; + + // PCIE bandwidth for PCI Gen3 x16 + static constexpr double kPciBandwidth = 12.0; + + // Discount factor for ring algorithm + static constexpr double kRingAlgorithmDiscountFactor = 0.92; + + // Different tiers for intra-node bandwidth. + static constexpr std::array kIntraNodeSpeeds = { + 40.0, 30.0, 20.0, 18.0, 15.0, 12.0, 10.0, 9.0, 7.0, 6.0, 5.0, 4.0, 3.0}; + // SM90 has different bandwidths. + static constexpr std::array kIntraNodeSpeedsSm90 = { + 60.0, 40.0, 30.0, 24.0, 20.0, 15.0, 12.0, 6.0, 3.0}; + + // Maximum number of channels allowed by NCCL + static constexpr int64_t kMaxNumChannelsRing = 16; + + // ll128 is by default enabled for Volta, Ampere and Hopper, ll128 by default + // launches 640 threads. + static constexpr int64_t kLL128NumThreads = 640; + + static constexpr absl::Duration kNcclKernelLaunchOverhead = + absl::Microseconds(5); + + static absl::Duration ComputeCollectiveTime( + const HloInstruction& instr, const GpuHloCostAnalysis* cost_analysis, + const se::DeviceDescription& gpu_device_info); + + // Returns NVLink bw in GB/s + static float GetNvlinkBw(se::CudaComputeCapability compute_capability); + + // Initialize nvml library. + static bool InitNvml(); + + // Shut down nvml library. + static bool ShutdownNvml(); + + // This checks if the nvlink supports direct P2P communication, + // If not, we will use PCIE bandwidth to estimate latency. + static uint32_t CheckIfNvlinkSupportsP2P(); + + private: + static absl::Duration ComputeAllreduceTime( + const HloInstruction& instr, const GpuHloCostAnalysis* cost_analysis, + const se::DeviceDescription& gpu_device_info); +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_MODEL_GPU_COLLECTIVE_PERFORMANCE_MODEL_H_ diff --git a/third_party/xla/xla/service/gpu/model/gpu_collective_performance_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_collective_performance_model_test.cc new file mode 100644 index 00000000000000..3136a13422a6f3 --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/gpu_collective_performance_model_test.cc @@ -0,0 +1,45 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/gpu_collective_performance_model.h" + +#include +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/tests/hlo_test_base.h" + +namespace xla { +namespace gpu { +namespace { + +using GpuPerformanceWithCollectiveModelTest = HloTestBase; + +TEST_F(GpuPerformanceWithCollectiveModelTest, TestNvmlLibraryLoading) { +#if GOOGLE_CUDA + EXPECT_TRUE(GpuPerformanceWithCollectiveModel::InitNvml()); + // After successful init, we try to use one of the + // nvml functions to see if the result is good. + nvmlDevice_t nvml_device; + nvmlReturn_t get_device_result = + xla_nvmlDeviceGetHandleByIndex(0, &nvml_device); + EXPECT_TRUE(get_device_result == NVML_SUCCESS); + + EXPECT_TRUE(GpuPerformanceWithCollectiveModel::InitNvml()); + +#endif // GOOGLE_CUDA +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis.cc b/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis.cc index 2b4dc7e6456264..c1ec02bc5267cf 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis.cc @@ -41,7 +41,6 @@ limitations under the License. #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/model/hlo_op_profile.pb.h" #include "xla/service/gpu/model/hlo_op_profiles.h" -#include "xla/service/gpu/model/hlo_op_profiles_data.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" @@ -95,7 +94,7 @@ int64_t GpuHloCostAnalysis::FusionParameterReadBytes( if (!options_.count_multiple_input_accesses) { utilization = fmin(utilization, 1.0); } - return GetShapeSize(hlo->shape()) * utilization; + return std::llround(GetShapeSize(hlo->shape()) * utilization); } absl::Status GpuHloCostAnalysis::FusionCalculateUtilizations( @@ -327,11 +326,7 @@ int64_t GpuHloCostAnalysis::GetConvolutionFlops( int64_t FlopsPerElement(const se::DeviceDescription* device_info, const PrimitiveType type, const HloOpcode opcode) { - static const auto* hlo_op_profiles = - HloOpProfiles::Load(kDeviceHloOpProfiles, - /*default_profile_name=*/"sm_86") - .release(); - auto device_profile = hlo_op_profiles->GetProfile(device_info); + auto device_profile = HloOpProfiles::Singleton().GetProfile(device_info); // Elementwise instructions typically take at least a few clock cycles. constexpr int64_t kDefaultFlopsPerElement = 3; return FindOrDefault(device_profile, std::make_pair(opcode, type), @@ -436,6 +431,51 @@ absl::Status GpuHloCostAnalysis::HandleConcatenate(const HloInstruction* hlo) { return absl::OkStatus(); } +absl::Status GpuHloCostAnalysis::HandleReduce(const HloInstruction* hlo) { + // HloCostAnalysis::HandleReduce computes FLOPs for the computation correctly, + // but `bytes_accessed` estimates are different for GPU. + TF_RETURN_IF_ERROR(HloCostAnalysis::HandleReduce(hlo)); + + const HloReduceInstruction* reduce = DynCast(hlo); + auto output_shape = reduce->shape().IsArray() + ? reduce->shape() + : reduce->shape().tuple_shapes(0); + + int64_t output_bytes_accessed = 0; + ShapeUtil::ForEachLeafShape( + reduce->shape(), [&](const Shape& sub_shape, const ShapeIndex& index) { + output_bytes_accessed += GetShapeSize(sub_shape); + }); + + current_properties_.set_output_bytes_accessed(output_bytes_accessed); + + int64_t bytes_accessed = output_bytes_accessed; + for (int64_t input_operand_id = 0; input_operand_id < reduce->input_count(); + ++input_operand_id) { + bytes_accessed += + current_properties_.operand_bytes_accessed(input_operand_id); + } + + int64_t output_shape_size = ShapeUtil::ElementsIn(output_shape); + for (int64_t init_operand_id = reduce->input_count(); + init_operand_id < reduce->operand_count(); ++init_operand_id) { + auto init_operand = reduce->operand(init_operand_id); + + int64_t operand_bytes_accessed = + output_shape_size * GetShapeSize(init_operand->shape()); + current_properties_.set_operand_bytes_accessed(init_operand_id, + operand_bytes_accessed); + current_properties_.set_operand_utilization(init_operand_id, + output_shape_size); + + bytes_accessed += operand_bytes_accessed; + } + + current_properties_[kBytesAccessedKey] = bytes_accessed; + + return absl::OkStatus(); +} + absl::Status GpuHloCostAnalysis::HandleElementwiseOp( const HloInstruction* hlo) { current_properties_[kFlopsKey] = GetFlopsForElementwiseOp(device_info_, hlo); diff --git a/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis.h b/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis.h index ab0b51d3023112..ee7939c5a23fea 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis.h +++ b/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis.h @@ -54,6 +54,7 @@ class GpuHloCostAnalysis : public HloCostAnalysis { absl::Status HandleConcatenate(const HloInstruction* hlo) override; absl::Status HandleAllReduce(const HloInstruction* allreduce) override; + absl::Status HandleReduce(const HloInstruction* hlo) override; // Estimate the total size of IR accounting for both duplication // of producer code by consumer and the total number of basic blocks. diff --git a/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis_test.cc b/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis_test.cc index 161d69beebb872..edec3c996832dd 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis_test.cc @@ -537,5 +537,82 @@ TEST_F(GpuHloCostAnalysisTest, CommonElementwiseUseParameterAndRoot) { 0.f); } +TEST_F(GpuHloCostAnalysisTest, Reduce) { + absl::string_view hlo_string = R"( +HloModule m + +add { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + ROOT add.0 = f32[] add(param_0, param_1) +} + +ENTRY entry_computation { + param_0.3 = f32[32,40]{1,0} parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[32]{0} reduce(param_0.3, constant), dimensions={1}, to_apply=add +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + + int64_t input_bytes_accessed = 4 * 32 * 40; + int64_t init_bytes_accessed = 4 * 32; + int64_t output_bytes_accessed = 4 * 32; + + EXPECT_EQ(analysis_.operand_bytes_accessed(*reduce, 0), input_bytes_accessed); + EXPECT_EQ(analysis_.operand_bytes_accessed(*reduce, 1), init_bytes_accessed); + EXPECT_EQ(analysis_.output_bytes_accessed(*reduce), output_bytes_accessed); + EXPECT_EQ(analysis_.bytes_accessed(*reduce), + input_bytes_accessed + init_bytes_accessed + output_bytes_accessed); + EXPECT_EQ(analysis_.flop_count(*reduce), 32 * 39 * 3); +} + +TEST_F(GpuHloCostAnalysisTest, VariadicReduce) { + absl::string_view hlo_string = R"( +HloModule m + +add { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + param_2 = f32[] parameter(2) + param_3 = f32[] parameter(3) + add.0 = f32[] add(param_0, param_2) + add.1 = f32[] add(param_1, param_3) + ROOT t = (f32[], f32[]) tuple(add.0, add.1) +} + +ENTRY entry_computation { + param_0.3 = f32[32,40]{1,0} parameter(0) + param_1.3 = f32[32,40]{1,0} parameter(1) + param_2.2 = f32[] parameter(2) + constant = f32[] constant(0) + ROOT reduce = (f32[32]{0}, f32[32]{0}) reduce(param_0.3, param_1.3, param_2.2, constant), dimensions={1}, to_apply=add +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + + int64_t input_bytes_accessed = 4 * 32 * 40; + int64_t init_bytes_accessed = 4 * 32; + int64_t output_bytes_accessed = 2 * 4 * 32; + + EXPECT_EQ(analysis_.operand_bytes_accessed(*reduce, 0), input_bytes_accessed); + EXPECT_EQ(analysis_.operand_bytes_accessed(*reduce, 1), input_bytes_accessed); + EXPECT_EQ(analysis_.operand_bytes_accessed(*reduce, 2), init_bytes_accessed); + EXPECT_EQ(analysis_.operand_bytes_accessed(*reduce, 3), init_bytes_accessed); + EXPECT_EQ(analysis_.output_bytes_accessed(*reduce), output_bytes_accessed); + EXPECT_EQ(analysis_.bytes_accessed(*reduce), 2 * input_bytes_accessed + + 2 * init_bytes_accessed + + output_bytes_accessed); + EXPECT_EQ(analysis_.flop_count(*reduce), 32 * 39 * 6); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc new file mode 100644 index 00000000000000..814ac0e8e57f44 --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc @@ -0,0 +1,234 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/gpu_indexing_performance_model.h" + +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/coalescing_analysis.h" +#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/service/gpu/model/gpu_performance_model_base.h" +#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/util.h" +#include "tsl/platform/status.h" + +namespace xla { +namespace gpu { + +int64_t GpuPerformanceModelWithIndexingAnalysis::FlopsPerElement( + const HloInstruction* instr) const { + // TODO(shyshkov): Replace dependency on GpuHloCostAnalysis with independent + // flops calculation. + GpuHloCostAnalysis::Options cost_analysis_options{ + shape_size_, + /*per_second_rates=*/{}, + /*count_multiple_input_accesses=*/true}; + GpuHloCostAnalysis cost_analysis(cost_analysis_options, device_info_); + TF_CHECK_OK( + cost_analysis.RevisitInstruction(const_cast(instr))); + + int64_t num_elements = [&] { + if (instr->opcode() == HloOpcode::kReduce && instr->shape().IsTuple()) { + return ShapeUtil::ElementsInRecursive(instr->shape().tuple_shapes(0)); + } + return ShapeUtil::ElementsInRecursive(instr->shape()); + }(); + + return cost_analysis.flop_count(*instr) / num_elements; +} + +int64_t GpuPerformanceModelWithIndexingAnalysis::GetShapeSizeRecursive( + const Shape& shape) const { + CHECK(shape.IsArray() || shape.IsTuple()); + if (shape.IsArray()) { + return shape_size_(shape); + } + + int64_t total_size = 0; + for (const auto& element_shape : shape.tuple_shapes()) { + total_size += GetShapeSizeRecursive(element_shape); + } + return total_size; +} + +int64_t GetIterationSpaceSize(const IndexingMap& indexing_map, + const HloInstruction* instr) { + if (indexing_map.IsUndefined()) { + return ShapeUtil::ElementsInRecursive(instr->shape()); + } + + if (indexing_map.IsKnownEmpty()) { + return 0; + } + + auto get_ranges_iteration_space_size = [](const std::vector& ranges) { + int64_t num_iters = 1; + for (const Range& range : ranges) { + num_iters *= range.upper_bound - range.lower_bound + 1; + } + return num_iters; + }; + + return get_ranges_iteration_space_size(indexing_map.GetSymbolRanges()) * + get_ranges_iteration_space_size(indexing_map.GetDimensionRanges()); +} + +EstimateRunTimeData +GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForFusion( + const HloFusionAnalysis& fusion_analysis, bool is_coalesced) { + auto& fusion_adaptor = fusion_analysis.fusion(); + VLOG(5) << "EstimateRunTimeForFusion: " << fusion_adaptor.ToString(); + + auto roots = fusion_adaptor.GetRoots(); + CHECK_EQ(roots.size(), 1) + << "Indexing cost model doesn't support multi-output fusions."; + auto root_shape = roots.front().shape(); + + LaunchDimensions launch_dimensions = + EstimateFusionLaunchDimensions(ShapeUtil::ElementsInRecursive(root_shape), + fusion_analysis, *device_info_); + + int64_t num_threads = launch_dimensions.launch_bound(); + int64_t num_blocks = launch_dimensions.num_blocks(); + + // Compute indexing from root to each instruction in the fusion and fusion + // operands. For each instruction, tells which elements of the instructions + // result will be used to compute one result element of the fusion. + auto grouped_fusion_indexing = ComputeGroupedOutputToInputIndexing( + fusion_adaptor, roots[0], mlir_context_); + + int64_t flops = 0; + int64_t bytes_read = 0; + absl::Duration read_time = absl::ZeroDuration(); + + for (const auto& [instr, indexing_maps] : grouped_fusion_indexing) { + VLOG(10) << "instr: " << instr->name(); + HloInstructionAdaptor instr_adaptor(*instr); + + // Instructions inside the fusion are computation and account for FLOPs + // count. Instructions outside the fusion are operands of the fusion and + // account for memory read time. + bool is_operand = !fusion_adaptor.ContainsInstruction(instr_adaptor); + + auto element_type = instr->shape().element_type(); + int64_t n_bytes_total = 0; + for (const auto& indexing_map : indexing_maps) { + VLOG(10) << indexing_map.ToString(); + + int64_t num_iters = GetIterationSpaceSize(indexing_map, instr); + + if (is_operand) { + int64_t type_size = ShapeUtil::ByteSizeOfPrimitiveType(element_type); + n_bytes_total += type_size * num_iters; + } else { + int64_t flops_per_element = FlopsPerElement(instr); + flops += flops_per_element * num_iters; + } + } + + if (is_operand) { + int64_t operand_size = shape_size_(instr->shape()); + int64_t n_bytes_net = std::min(operand_size, n_bytes_total); + bytes_read += n_bytes_total; + + VLogOperandRead(instr, n_bytes_total, n_bytes_net, is_coalesced); + + read_time += + ReadTimeWithDRAMHeuristic(*device_info_, num_blocks, n_bytes_net, + n_bytes_total, element_type, is_coalesced); + } + } + + int64_t bytes_written = GetShapeSizeRecursive(root_shape); + + absl::Duration compute_time = ComputeTime(*device_info_, flops, num_threads); + absl::Duration write_time = WriteTime(*device_info_, bytes_written); + absl::Duration memory_access_time = read_time + write_time; + absl::Duration exec_time = CombineComputeAndMemoryAccessTime( + compute_time, memory_access_time, + GpuPerformanceModelOptions::PriorityFusion()); + + VLogResult(flops, bytes_read, bytes_written, num_threads, compute_time, + read_time, write_time, exec_time); + + return EstimateRunTimeData{flops, bytes_written, num_threads, read_time, + write_time, compute_time, exec_time}; +} + +EstimateRunTimeData +GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForInstruction( + const HloInstruction* producer) { + // Stand-alone bitcast is always no-op during runtime. + if (producer->opcode() == HloOpcode::kBitcast) { + return {0, 0, 0, absl::ZeroDuration(), absl::ZeroDuration()}; + } + + auto fusion_analysis = AnalyzeFusion(*producer, *device_info_); + + bool is_coalesced = IsReadCoalescedHeuristic( + fusion_analysis.GetEmitterFusionKind(), producer); + return EstimateRunTimeForFusion(fusion_analysis, is_coalesced); +} + +EstimateRunTimeData +GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForProducerConsumer( + const HloInstruction* producer, const HloInstruction* consumer) { + auto fusion_analysis = + AnalyzeProducerConsumerFusion(*producer, *consumer, *device_info_); + + bool is_coalesced = IsReadCoalescedHeuristic( + fusion_analysis.GetEmitterFusionKind(), producer, consumer); + return EstimateRunTimeForFusion(fusion_analysis, is_coalesced); +} + +/*static*/ +GpuPerformanceModelWithIndexingAnalysis::RunTimes +GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimes( + const HloInstruction* producer, + absl::Span fused_consumers) { + auto producer_runtime = EstimateRunTimeForInstruction(producer); + + absl::Duration time_unfused = + kKernelLaunchOverhead * (fused_consumers.size() + 1) + + producer_runtime.exec_time; + + absl::Duration time_fused = kKernelLaunchOverhead * fused_consumers.size(); + + for (const auto& consumer : fused_consumers) { + time_unfused += EstimateRunTimeForInstruction(consumer).exec_time; + time_fused += + EstimateRunTimeForProducerConsumer(producer, consumer).exec_time; + } + + return {time_unfused, time_fused}; +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h new file mode 100644 index 00000000000000..4328d3588009e6 --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h @@ -0,0 +1,75 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MODEL_GPU_INDEXING_PERFORMANCE_MODEL_H_ +#define XLA_SERVICE_GPU_MODEL_GPU_INDEXING_PERFORMANCE_MODEL_H_ + +#include + +#include "absl/types/span.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/model/gpu_performance_model_base.h" +#include "xla/service/gpu/model/hlo_op_profiles.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/stream_executor/device_description.h" + +namespace xla { +namespace gpu { + +// Implementation of Cost Model that uses indexing analysis to estimate amount +// of compute and memory access time. +class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase { + public: + explicit GpuPerformanceModelWithIndexingAnalysis( + const se::DeviceDescription* device_info, + HloCostAnalysis::ShapeSizeFunction shape_size, + mlir::MLIRContext* mlir_context) + : hlo_op_profile_(&HloOpProfiles::Singleton().GetProfile(device_info)), + device_info_(device_info), + shape_size_(shape_size), + mlir_context_(mlir_context) {} + + EstimateRunTimeData EstimateRunTimeForFusion( + const HloFusionAnalysis& fusion_analysis, bool is_coalesced = true); + + EstimateRunTimeData EstimateRunTimeForInstruction( + const HloInstruction* producer); + + EstimateRunTimeData EstimateRunTimeForProducerConsumer( + const HloInstruction* producer, const HloInstruction* consumer); + + RunTimes EstimateRunTimes( + const HloInstruction* producer, + absl::Span fused_consumers = {}); + + private: + // Returns an estimate how many FLOPs will be used to produce one element of + // the output. + int64_t FlopsPerElement(const HloInstruction* instr) const; + + int64_t GetShapeSizeRecursive(const Shape& shape) const; + + const HloOpProfiles::HloOpProfile* hlo_op_profile_; + const se::DeviceDescription* device_info_; + HloCostAnalysis::ShapeSizeFunction shape_size_; + mlir::MLIRContext* mlir_context_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_MODEL_GPU_INDEXING_PERFORMANCE_MODEL_H_ diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc new file mode 100644 index 00000000000000..5e52685762e524 --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc @@ -0,0 +1,172 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/gpu_indexing_performance_model.h" + +#include +#include + +#include +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/service/gpu/model/gpu_performance_model_base.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +class GpuIndexingPerformanceModelTest : public HloTestBase { + GpuHloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const { + return [&](const Shape& shape) { + constexpr int64_t kPointerSize = 8; + return ShapeUtil::ByteSizeOf(shape, kPointerSize); + }; + } + + public: + mlir::MLIRContext mlir_context_; + // The reference times in the test cases below are measured + // on A6000 by profiling the execution of the HLOs. + se::DeviceDescription device_info_{TestGpuDeviceInfo::RTXA6000DeviceInfo()}; + GpuPerformanceModelWithIndexingAnalysis indexing_cost_model_{ + &device_info_, ShapeSizeBytesFunction(), &mlir_context_}; + + GpuIndexingPerformanceModelTest() : HloTestBase() {} +}; + +TEST_F(GpuIndexingPerformanceModelTest, BroadcastElementwise) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( + R"( +HloModule extracted + +ENTRY entry_computation { + param_0 = f32[32]{0} parameter(0) + broadcast = f32[32,1,768]{2,1,0} broadcast(param_0), dimensions={0} + param_1 = f32[32,1,768]{2,1,0} parameter(1) + ROOT multiply = f32[32,1,768]{2,1,0} multiply(broadcast, param_1) +} +)")); + + auto producer = + module->entry_computation()->GetInstructionWithName("broadcast"); + auto consumer = + module->entry_computation()->GetInstructionWithName("multiply"); + + auto runtime_data = indexing_cost_model_.EstimateRunTimeForProducerConsumer( + producer, consumer); + EXPECT_EQ(runtime_data.flops, 73728); + EXPECT_EQ(runtime_data.bytes_written, 98304); + EXPECT_NEAR(absl::ToInt64Nanoseconds(runtime_data.write_time), 128, 2); + EXPECT_NEAR(absl::ToInt64Nanoseconds(runtime_data.exec_time), 267, 2); +} + +TEST_F(GpuIndexingPerformanceModelTest, Bitcast) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( + R"( +HloModule m + +ENTRY entry_computation { + param_0 = bf16[4,8,65,128]{3,2,1,0} parameter(0) + ROOT bitcast = bf16[8,4,65,128]{3,2,0,1} bitcast(param_0) +} +)")); + + auto instruction = + module->entry_computation()->GetInstructionWithName("bitcast"); + + auto runtime_data = + indexing_cost_model_.EstimateRunTimeForInstruction(instruction); + EXPECT_EQ(runtime_data.flops, 0); + EXPECT_EQ(runtime_data.bytes_written, 0); + EXPECT_EQ(runtime_data.write_time, absl::ZeroDuration()); + EXPECT_EQ(runtime_data.exec_time, absl::ZeroDuration()); +} + +TEST_F(GpuIndexingPerformanceModelTest, Reduce) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( + R"( +HloModule m + +add { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + ROOT add.0 = f32[] add(param_0, param_1) +} + +ENTRY entry_computation { + param_0.3 = f32[32,40]{1,0} parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[32]{0} reduce(param_0.3, constant), dimensions={1}, to_apply=add +} +)")); + + auto instruction = module->entry_computation()->root_instruction(); + + auto runtime_data = + indexing_cost_model_.EstimateRunTimeForInstruction(instruction); + EXPECT_EQ(runtime_data.flops, 3744); + EXPECT_EQ(runtime_data.bytes_written, 128); + EXPECT_NEAR(absl::ToDoubleNanoseconds(runtime_data.write_time), 0, 1); + EXPECT_NEAR(absl::ToDoubleNanoseconds(runtime_data.exec_time), 29, 1); +} + +TEST_F(GpuIndexingPerformanceModelTest, VariadicReduce) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( + R"( +HloModule m + +add { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + param_2 = f32[] parameter(2) + param_3 = f32[] parameter(3) + add.0 = f32[] add(param_0, param_2) + add.1 = f32[] add(param_1, param_3) + ROOT t = (f32[], f32[]) tuple(add.0, add.1) +} + +ENTRY entry_computation { + param_0.3 = f32[32,40]{1,0} parameter(0) + param_1.3 = f32[32,40]{1,0} parameter(1) + param_2.2 = f32[] parameter(2) + constant = f32[] constant(0) + ROOT reduce = (f32[32]{0}, f32[32]{0}) reduce(param_0.3, param_1.3, param_2.2, constant), dimensions={1}, to_apply=add +} +)")); + + auto instruction = module->entry_computation()->root_instruction(); + + auto runtime_data = + indexing_cost_model_.EstimateRunTimeForInstruction(instruction); + EXPECT_EQ(runtime_data.flops, 7488); + EXPECT_EQ(runtime_data.bytes_written, 256); + EXPECT_NEAR(absl::ToDoubleNanoseconds(runtime_data.write_time), 0, 1); + EXPECT_NEAR(absl::ToDoubleNanoseconds(runtime_data.exec_time), 58, 1); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc index 88fa14b037fab2..5144094c377b4d 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc @@ -18,319 +18,53 @@ limitations under the License. #include #include #include -#include #include #include -#include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" -#include "absl/strings/numbers.h" -#include "absl/synchronization/mutex.h" #include "absl/time/time.h" +#include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/fusions/fusions.h" #include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/model/coalescing_analysis.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" -#include "xla/service/hlo_dataflow_analysis.h" +#include "xla/service/gpu/model/gpu_performance_model_base.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" #include "xla/util.h" #include "tsl/platform/status.h" -#if GOOGLE_CUDA -#include "third_party/gpus/cuda/nvml/include/nvml.h" -#endif // GOOGLE_CUDA namespace xla { namespace gpu { - namespace { -// Estimated values in the absence of easy ways to query them. -static constexpr absl::Duration kKernelLaunchOverhead = absl::Microseconds(1); -static constexpr absl::Duration kNcclKernelLaunchOverhead = - absl::Microseconds(5); -static constexpr float kL2CacheSpeedup = 2.5; -static constexpr float kL1CacheSpeedup = 8; -// A very conservative estimate. L1 size varies because it can be dynamically -// configured as shared memory; there is no easy way to query its actual size; -// also we do not count what occupies cache, but rather claim that what is -// much smaller than the cache size will likely stay in it. -// For reference, it can be up to 256 kB per SM on RTX A6000. -static constexpr float kL1CacheSizePerSM = 2 * 1024; - -absl::Duration CombineComputeAndMemoryAccessTime( - absl::Duration compute_time, absl::Duration memory_access_time, - const GpuPerformanceModelOptions& config) { - return compute_time + memory_access_time - - std::min(compute_time, memory_access_time) * - config.memory_compute_parallelism; -} - -// Returns whether a fusion uses the parameter at the given index elementwise -// from its root. -bool FusionUsesParameterElementwiseFromRoot( - const HloInstruction* fusion, int parameter_index, - const GpuHloCostAnalysis* cost_analysis) { - return cost_analysis->CommonElementwiseUtilization( - fusion->fused_parameter(parameter_index), - fusion->fused_expression_root()) == 1.f; -} - -int GetCoalescingWasteFactor(PrimitiveType element_type) { - int64_t element_size_bytes = - element_type == PrimitiveType::TUPLE || - element_type == PrimitiveType::TOKEN - ? 4 /* Dummy value. TODO(jreiffers): Model this case. */ - : ShapeUtil::ByteSizeOfPrimitiveType(element_type); - // Cache line is 128B that is split into 4 sectors of 32B. Default transaction - // size from DRAM -> L2 = 64 Bytes = 2 sectors, since V100, but it can be also - // configured. - // https://developer.download.nvidia.com/video/gputechconf/gtc/2020/presentations/s21819-optimizing-applications-for-nvidia-ampere-gpu-architecture.pdf - // (page 10). - constexpr int kDRAMToL2TransactionSizeBytes = 64; - // Assume we use one element from the cache line and waste the remaining - // bandwidth. For example, if we're reading f32s, we use 1/16nd of the cache - // line. - return kDRAMToL2TransactionSizeBytes / element_size_bytes; -} - -// Limit the bandwidth for low occupancy cases. Each SM can issue at most -// one 32B memory transaction per clock. H100 needs at least 56.8 active SMs -// (1830 MHz) to saturate the memory bandwidth (3.35 TB/s). -float AdjustBandwidth(const se::DeviceDescription& gpu_device_info, - float bandwidth, int64_t num_blocks) { - float per_block_bandwidth = gpu_device_info.clock_rate_ghz() * 1.0e9f * 32; - float max_bandwidth = num_blocks * per_block_bandwidth; - - return std::min(bandwidth, max_bandwidth); -} - -// Estimate read time of n_bytes_total bytes from global memory on a -// given GPU. -// -// Assumes that the first n_bytes_net are always read from DRAM, but next reads -// can be cached. Applies waste factor if read from DRAM is uncoalesced. -absl::Duration ReadTimeWithDRAMHeuristic( - const se::DeviceDescription& gpu_device_info, int64_t num_blocks, - int64_t n_bytes_net, int64_t n_bytes_total, PrimitiveType element_type, - bool coalesced) { - int waste_factor = coalesced ? 1 : GetCoalescingWasteFactor(element_type); - - // The first read of the input buffer always happens from DRAM. If reads are - // no coaleced, bandwidth is reduced by the waste factor. - float dram_bandwidth = gpu_device_info.memory_bandwidth() / waste_factor; - - // Two things can happed on re-reading the buffer: - // - If the buffer fits into cache, the L1/L2 cache speedup is applied. - // - If the buffer doesn't fit, it will be read from DRAM and the same - // coalessing waste factor is applied. - float rest_bandwidth = gpu_device_info.memory_bandwidth(); - if (n_bytes_net < gpu_device_info.l2_cache_size()) { - rest_bandwidth *= kL2CacheSpeedup; - if (n_bytes_net < kL1CacheSizePerSM * gpu_device_info.core_count()) { - rest_bandwidth *= kL1CacheSpeedup; - } - } else { - rest_bandwidth /= waste_factor; - } - - dram_bandwidth = AdjustBandwidth(gpu_device_info, dram_bandwidth, num_blocks); - rest_bandwidth = AdjustBandwidth(gpu_device_info, rest_bandwidth, num_blocks); - - // n_bytes_net > n_bytes_total can happen when we compute read time of - // shared operand. This is a flaw in the interface that should be fixed. - int64_t n_bytes_read_dram = std::min(n_bytes_net, n_bytes_total); - - // Number of bytes that we be re-read, potentially from cache. - int64_t n_bytes_read_cache = n_bytes_total - n_bytes_read_dram; - - return absl::Seconds(n_bytes_read_dram / dram_bandwidth) + - absl::Seconds(n_bytes_read_cache / rest_bandwidth); -} - -// Estimate read time of n_bytes_total bytes from global memory on a -// given GPU. Account for L1 / L2 cache speedup if the input's nominal size -// n_bytes_net is small. -absl::Duration ReadTime(const se::DeviceDescription& gpu_device_info, - int64_t num_blocks, int64_t n_bytes_net, - int64_t n_bytes_total) { - float bandwidth = gpu_device_info.memory_bandwidth(); - if (n_bytes_net < gpu_device_info.l2_cache_size()) { - bandwidth *= kL2CacheSpeedup; - if (n_bytes_net < kL1CacheSizePerSM * gpu_device_info.core_count()) { - bandwidth *= kL1CacheSpeedup; - } - } - - bandwidth = AdjustBandwidth(gpu_device_info, bandwidth, num_blocks); - return absl::Seconds(n_bytes_total / bandwidth); -} - -int64_t GetNcclMaxNumChannels( - GpuPerformanceWithCollectiveModel::CollectiveAlgo algorithm) { - int64_t max_nchannels = 0; - switch (algorithm) { - // Tree and Ring algos share the same max channel number. - case GpuPerformanceWithCollectiveModel::RING: - case GpuPerformanceWithCollectiveModel::TREE: - max_nchannels = GpuPerformanceWithCollectiveModel::kMaxNumChannelsRing; - break; - } - const char* env = std::getenv("NCCL_MAX_NCHANNELS"); - if (env != nullptr) { - int64_t max_nchannels_from_env; - if (absl::SimpleAtoi(env, &max_nchannels_from_env)) { - max_nchannels = std::min(max_nchannels_from_env, max_nchannels); - } - } - return max_nchannels; -} - -int64_t GetMinNumberOfChannels( - GpuPerformanceWithCollectiveModel::CollectiveAlgo algorithm) { - int64_t min_nchannels = 0; - switch (algorithm) { - // Tree and Ring algos share the same min channel number. - case GpuPerformanceWithCollectiveModel::RING: - case GpuPerformanceWithCollectiveModel::TREE: - min_nchannels = 1; - break; - } - const char* env = std::getenv("NCCL_MIN_NCHANNELS"); - if (env != nullptr) { - int64_t min_nchannels_from_env; - if (absl::SimpleAtoi(env, &min_nchannels_from_env)) { - min_nchannels = std::min(min_nchannels_from_env, min_nchannels); - } - } - return min_nchannels; -} - -int GetNumThreads(int warp_size, int min_num_threads, int max_num_threads, - int default_num_threads) { - int threads_from_env = default_num_threads; - const char* env = std::getenv("NCCL_NTHREADS"); - if (env != nullptr) { - CHECK(absl::SimpleAtoi(env, &threads_from_env)); - } - int num_threads = threads_from_env; - if (num_threads > 0) { - if (num_threads % warp_size != 0) { - num_threads = max_num_threads; - } else if (num_threads > max_num_threads) { - num_threads = max_num_threads; - } else if (num_threads < min_num_threads) { - num_threads = min_num_threads; - } - } else { - num_threads = default_num_threads; - } - return num_threads; -} - -float GetMaxSysBwFromGpu(const se::CudaComputeCapability cc, - const double* bandwidths_table) { - switch (cc.major) { - case se::CudaComputeCapability::VOLTA: - return bandwidths_table[0]; - case se::CudaComputeCapability::AMPERE: - return bandwidths_table[1]; - case se::CudaComputeCapability::HOPPER: - return bandwidths_table[2]; +std::vector GetUniqueFusionOperands( + const HloInstruction* producer, const HloInstruction* consumer) { + std::vector fusion_operands; + for (const HloInstruction* operand : producer->operands()) { + fusion_operands.push_back(operand); } - return -1; -} - -// Uses HloFusionAnalysis for computing the actual number of threads and blocks -// that the IR emitter will use. -LaunchDimensions EstimateFusionLaunchDimensions( - int64_t estimated_num_threads, const HloFusionAnalysis& fusion_analysis, - const se::DeviceDescription& device_info) { - auto emitter = - GetFusionEmitter(PreBufferAssignmentFusionInfo{fusion_analysis}); - if (emitter.ok()) { - if (const auto* kernel_emitter = - dynamic_cast(emitter->get())) { - return kernel_emitter->launch_dimensions(); + for (const HloInstruction* operand : consumer->operands()) { + if (operand != producer) { + fusion_operands.push_back(operand); } } - int64_t block_size = 128; // Result for default LaunchDimensionsConfig. - int64_t num_blocks = CeilOfRatio(estimated_num_threads, block_size); - return LaunchDimensions(num_blocks, block_size); + std::sort(fusion_operands.begin(), fusion_operands.end()); + fusion_operands.erase( + std::unique(fusion_operands.begin(), fusion_operands.end()), + fusion_operands.end()); + return fusion_operands; } } // namespace -std::optional GpuPerformanceModelCache::Get( - const HloInstruction& instruction) { - absl::MutexLock lock(&mutex_); - - auto it = instruction_runtime_data_.find(HloInstructionAdaptor(instruction)); - if (it != instruction_runtime_data_.end()) { - return it->second; - } - return std::nullopt; -} - -std::optional GpuPerformanceModelCache::Get( - const HloInstruction& producer, const HloInstruction& consumer) { - absl::MutexLock lock(&mutex_); - - auto it = fusion_runtime_data_.find(HloInstructionAdaptor(producer)); - if (it != fusion_runtime_data_.end()) { - auto jt = it->second.find(HloInstructionAdaptor(consumer)); - if (jt != it->second.end()) { - return jt->second; - } - } - return std::nullopt; -} - -void GpuPerformanceModelCache::Set(const HloInstruction& instruction, - const EstimateRunTimeData& runtime_data) { - absl::MutexLock lock(&mutex_); - - instruction_runtime_data_[HloInstructionAdaptor(instruction)] = runtime_data; -} - -void GpuPerformanceModelCache::Set(const HloInstruction& producer, - const HloInstruction& consumer, - absl::Duration runtime) { - absl::MutexLock lock(&mutex_); - fusion_runtime_data_[HloInstructionAdaptor(producer)] - [HloInstructionAdaptor(consumer)] = runtime; -} - -void GpuPerformanceModelCache::Invalidate(const HloInstruction& instruction) { - absl::MutexLock lock(&mutex_); - HloInstructionAdaptor adaptor(instruction); - - // Remove runtime data for the instruction. - instruction_runtime_data_.erase(adaptor); - - // Remove cache for all producer-consumer pairs where the instruction is - // producer. - fusion_runtime_data_.erase(adaptor); - - // Iterate through operands to find all producer-consumer pairs where - // instruction is consumer and remove them from cache. - for (auto* operand : instruction.operands()) { - auto it = fusion_runtime_data_.find(HloInstructionAdaptor(*operand)); - if (it != fusion_runtime_data_.end()) { - it->second.erase(adaptor); - } - } -} - /*static*/ EstimateRunTimeData GpuPerformanceModel::EstimateRunTimeForInstruction( const HloInstruction* instr, const GpuHloCostAnalysis* cost_analysis, @@ -340,7 +74,6 @@ GpuPerformanceModel::EstimateRunTimeForInstruction( int64_t flops = cost_analysis->flop_count(*instr); int64_t bytes_written = cost_analysis->output_bytes_accessed(*instr); - int64_t bytes_read = cost_analysis->bytes_accessed(*instr) - bytes_written; // Use the analysis cache if present. // TODO(jreiffers): Remove this once all callers use a cache. @@ -359,47 +92,36 @@ GpuPerformanceModel::EstimateRunTimeForInstruction( absl::Duration compute_time = ComputeTime(*device_info, flops, num_threads); - // TODO(jreiffers): We should be checking each operand. - bool coalesced = IsReadCoalescedHeuristic(fusion_analysis, instr, - /*consumer=*/nullptr); + CoalescingAnalysis coalescing_analysis(instr, instr->operands(), + fusion_analysis); absl::Duration read_time; - for (int i = 0; i < instr->operand_count(); ++i) { - auto element_type = instr->operand(i)->shape().element_type(); - // Information about data read taking into account utilization. - // If `operand_utilization` is 0, `operand_bytes_accessed` should be also 0. - int64_t n_bytes_total = cost_analysis->operand_bytes_accessed(*instr, i); - float operand_utilization = cost_analysis->operand_utilization(*instr, i); - - // An estimate how much data would need to fit into L1/L2 cache to speed up - // the operand access. - // If `operand_utilization` < 1, only a part of the full operand size should - // be read. Otherwise, `n_bytes_total / operand_utilization` is the - // size of the operand without reuse. - int64_t n_bytes_net = - std::llround(n_bytes_total / std::max(operand_utilization, 1.0f)); + int64_t bytes_read = 0; + for (const auto [operand_id, operand] : llvm::enumerate(instr->operands())) { + int64_t operand_size = cost_analysis->GetShapeSize(operand->shape()); + int64_t n_bytes_total = + GetOperandBytesAccessed(cost_analysis, instr, operand); + int64_t n_bytes_net = std::min(operand_size, n_bytes_total); + bytes_read += n_bytes_total; - read_time += - ReadTimeWithDRAMHeuristic(*device_info, num_blocks, n_bytes_net, - n_bytes_total, element_type, coalesced); + bool coalesced = coalescing_analysis.IsReadCoalesced(operand); + + VLogOperandRead(operand, n_bytes_total, n_bytes_net, coalesced); + + read_time += ReadTimeWithDRAMHeuristic( + *device_info, num_blocks, n_bytes_net, n_bytes_total, + operand->shape().element_type(), coalesced); } - absl::Duration write_time = - absl::Seconds(1.0f * bytes_written / device_info->memory_bandwidth()); + absl::Duration write_time = WriteTime(*device_info, bytes_written); absl::Duration exec_time = CombineComputeAndMemoryAccessTime( compute_time, read_time + write_time, config); - if (VLOG_IS_ON(8)) { - LOG(INFO) << "FLOPs: " << flops; - LOG(INFO) << "Bytes read: " << bytes_read; - LOG(INFO) << "Bytes written: " << bytes_written; - LOG(INFO) << "Num threads: " << num_threads; - LOG(INFO) << "Compute time: " << compute_time; - LOG(INFO) << "Input read time: " << read_time; - LOG(INFO) << "Output write time: " << write_time; - } + VLogResult(flops, bytes_read, bytes_written, num_threads, compute_time, + read_time, write_time, exec_time); - return {flops, bytes_written, num_threads, write_time, exec_time}; + return {flops, bytes_written, num_threads, read_time, + write_time, compute_time, exec_time}; } /*static*/ EstimateRunTimeData @@ -422,145 +144,12 @@ GpuPerformanceModel::EstimateRunTimeForInstructionCached( return runtime_data; } -// Returns utilization of operand by instruction. Returns 0, if the operand is -// not used by the instruction. -float GetOperandUtilization(const GpuHloCostAnalysis* cost_analysis, - const HloInstruction* instr, - const HloInstruction* operand) { - if (!instr->IsUserOf(operand)) { - return 0.f; - } - - return cost_analysis->operand_utilization(*instr, - instr->operand_index(operand)); -} - -// Returns utilization `overlap` between a common operand of producer and -// consumer on merge. `utilization > 0` means that the operand will be accessed -// more efficiently after fusion. -// -// Currently covers two cases: -// 1) Producer has to use the common operand elementwise from its root if it is -// a fusion or just be an elementwise instruction. -// 2) Consumer has to have common elementwise roots for the producer and the -// common operand if it is a fusion or just be an elementwise instruction. -float GetCommonUtilization(const GpuHloCostAnalysis* cost_analysis, - const HloInstruction* producer, - int64_t producer_idx_of_operand, - const HloInstruction* consumer) { - const auto* operand = producer->operand(producer_idx_of_operand); - - if (!consumer || !consumer->IsUserOf(operand)) { - return 0.f; - } - - if (producer->IsElementwise() || - (producer->opcode() == HloOpcode::kFusion && - FusionUsesParameterElementwiseFromRoot(producer, producer_idx_of_operand, - cost_analysis))) { - if (consumer->opcode() == HloOpcode::kFusion) { - int64_t consumer_idx_of_common_operand = consumer->operand_index(operand); - int64_t consumer_idx_of_producer = consumer->operand_index(producer); - return cost_analysis->CommonElementwiseUtilization( - consumer->fused_parameter(consumer_idx_of_common_operand), - consumer->fused_parameter(consumer_idx_of_producer)); - } else { - if (consumer->IsElementwise()) { - return 1.f; - } - } - } - return 0.f; -} - -// Returns utilization of operand after producer and consumer are fused -// together. `GetCommonUtilization` works only for a limited set of elementwise -// cases. -// TODO(shyshkov): Combine logic from GpuHloCostAnalysis with boundary function -// to properly calculate utilization. -float GetSharedUtilization(const GpuHloCostAnalysis* cost_analysis, - const HloInstruction* producer, - const HloInstruction* consumer, - const HloInstruction* operand) { - float producer_utilization_by_consumer = - GetOperandUtilization(cost_analysis, consumer, producer); - - float operand_utilization_by_producer = - GetOperandUtilization(cost_analysis, producer, operand); - - float operand_utilization_by_consumer = - GetOperandUtilization(cost_analysis, consumer, operand); - - float common_utilization = - producer->IsUserOf(operand) - ? GetCommonUtilization(cost_analysis, producer, - producer->operand_index(operand), consumer) - : 0.f; - - return producer_utilization_by_consumer * operand_utilization_by_producer + - operand_utilization_by_consumer - common_utilization; -} - -// Tells input access time of the producer alone if fused_consumer -// is not specified. Otherwise estimates the access time to producer's -// inputs as if it is fused into the consumer. -/*static*/ absl::Duration GpuPerformanceModel::ProducerInputAccessTime( - const GpuHloCostAnalysis* cost_analysis, - const se::DeviceDescription& gpu_device_info, int64_t num_blocks, - const HloInstruction* producer, const HloFusionAnalysis& fusion_analysis, - const GpuPerformanceModelOptions& config, - const HloInstruction* fused_consumer) { - absl::Duration ret = absl::ZeroDuration(); - float producer_output_utilization = - fused_consumer - ? GetOperandUtilization(cost_analysis, fused_consumer, producer) - : 1.f; - - for (int i = 0; i < producer->operand_count(); ++i) { - // Information about data read taking into account utilization. - // If `operand_utilization` is 0, `operand_bytes_accessed` should be also 0. - int64_t operand_bytes_accessed = - cost_analysis->operand_bytes_accessed(*producer, i); - float operand_utilization = - cost_analysis->operand_utilization(*producer, i); - - // An estimate how much data would need to fit into L1/L2 cache to speed up - // the operand access. - // If `operand_utilization` < 1, only a part of the full operand size should - // be read. Otherwise, `operand_bytes_accessed / operand_utilization` is the - // size of the operand without reuse. - int64_t n_bytes_net = std::llround(operand_bytes_accessed / - std::max(operand_utilization, 1.0f)); - - // Look if common operand of producer and consumer will be accessed more - // efficiently on merge. - float common_utilization = GetCommonUtilization( - cost_analysis, producer, /*producer_idx_of_operand=*/i, fused_consumer); - - CHECK_LE(common_utilization, producer_output_utilization); - float n_bytes_total = operand_bytes_accessed * - (producer_output_utilization - common_utilization); - ret += ReadTime(gpu_device_info, num_blocks, n_bytes_net, n_bytes_total); - } - return ret; -} - -absl::Duration GpuPerformanceModel::ComputeTime( - const se::DeviceDescription& gpu_device_info, int64_t flops, - int64_t num_threads) { - int64_t fpu_count = - gpu_device_info.core_count() * gpu_device_info.fpus_per_core(); - int64_t n_threads_active = std::min(num_threads, fpu_count); - int64_t flop_per_ns_per_fpu = gpu_device_info.clock_rate_ghz() * /*fma:*/ 2; - int64_t flop_per_ns_effective = flop_per_ns_per_fpu * n_threads_active; - return absl::Nanoseconds(1.0f * flops / flop_per_ns_effective); -} - +/*static*/ absl::Duration GpuPerformanceModel::EstimateUnfusedExecTime( const HloInstruction* producer, const EstimateRunTimeData& producer_runtime, const GpuHloCostAnalysis* cost_analysis, const GpuPerformanceModelOptions& config, - const std::vector& fused_consumers) { + absl::Span fused_consumers) { const se::DeviceDescription* device_info = cost_analysis->device_info_; absl::Duration time_unfused = @@ -630,51 +219,43 @@ absl::Duration GpuPerformanceModel::EstimateUnfusedExecTime( producer_runtime.num_threads * utilization_by_this_consumer, fusion_analysis, *device_info); - int64_t fused_flops = producer_runtime.flops * utilization_by_this_consumer + - consumer_runtime.flops; + int64_t flops = producer_runtime.flops * utilization_by_this_consumer + + consumer_runtime.flops; int64_t num_threads = launch_dimensions.launch_bound(); - absl::Duration compute_time = - ComputeTime(*device_info, fused_flops, num_threads); + absl::Duration compute_time = ComputeTime(*device_info, flops, num_threads); - absl::flat_hash_set fusion_operands; - for (auto* operand : producer->operands()) { - fusion_operands.insert(operand); - } - for (auto* operand : consumer->operands()) { - if (operand != producer) { - fusion_operands.insert(operand); - } - } + std::vector fusion_operands = + GetUniqueFusionOperands(producer, consumer); + CoalescingAnalysis coalescing_analysis(producer, consumer, fusion_operands, + fusion_analysis); absl::Duration read_time; + int64_t bytes_read = 0; for (const auto* operand : fusion_operands) { - float operand_utilization = - GetSharedUtilization(cost_analysis, producer, consumer, operand); - int64_t operand_size = cost_analysis->GetShapeSize(operand->shape()); - int64_t n_bytes_total = std::llround(operand_size * operand_utilization); + int64_t n_bytes_total = GetSharedOperandBytesAccessed( + cost_analysis, producer, consumer, operand); int64_t n_bytes_net = std::min(operand_size, n_bytes_total); + bytes_read += n_bytes_total; - bool coalesced = - IsReadCoalescedHeuristic(fusion_analysis, producer, consumer); + bool coalesced = coalescing_analysis.IsReadCoalesced(operand); + + VLogOperandRead(operand, n_bytes_total, n_bytes_net, coalesced); read_time += ReadTimeWithDRAMHeuristic( *device_info, launch_dimensions.num_blocks(), n_bytes_net, n_bytes_total, operand->shape().element_type(), coalesced); } - if (VLOG_IS_ON(8)) { - LOG(INFO) << "Fused FLOPs: " << fused_flops; - LOG(INFO) << "Num threads: " << num_threads; - LOG(INFO) << "Compute time: " << compute_time; - LOG(INFO) << "Input read time: " << read_time; - LOG(INFO) << "Output write time: " << consumer_runtime.write_time; - } - - return CombineComputeAndMemoryAccessTime( + auto exec_time = CombineComputeAndMemoryAccessTime( compute_time, read_time + consumer_runtime.write_time, config); + + VLogResult(flops, bytes_read, consumer_runtime.bytes_written, num_threads, + compute_time, read_time, consumer_runtime.write_time, exec_time); + + return exec_time; } /*static*/ @@ -702,11 +283,13 @@ absl::Duration GpuPerformanceModel::EstimateRunTimeForFusionCached( return fusion_runtime; } +/*static*/ absl::Duration GpuPerformanceModel::EstimateFusedExecTime( const HloInstruction* producer, const EstimateRunTimeData& producer_runtime, const GpuHloCostAnalysis* cost_analysis, const GpuPerformanceModelOptions& config, - const std::vector& fused_consumers, bool multi_output) { + absl::Span fused_consumers, + bool multi_output) { const se::DeviceDescription* device_info = cost_analysis->device_info_; absl::Duration exec_time_fused = @@ -760,11 +343,13 @@ absl::Duration GpuPerformanceModel::EstimateFusedExecTime( return exec_time_fused; } +/*static*/ GpuPerformanceModel::RunTimes GpuPerformanceModel::EstimateRunTimesForPriorityFusion( const HloInstruction* producer, const GpuHloCostAnalysis* cost_analysis, const GpuPerformanceModelOptions& config, - std::vector fused_consumers, bool multi_output) { + absl::Span fused_consumers, + bool multi_output) { EstimateRunTimeData producer_runtime = EstimateRunTimeForInstructionCached(producer, cost_analysis, config); @@ -802,10 +387,12 @@ GpuPerformanceModel::EstimateRunTimesForPriorityFusion( return {time_unfused, time_fused}; } +/*static*/ GpuPerformanceModel::RunTimes GpuPerformanceModel::EstimateRunTimes( const HloInstruction* producer, const GpuHloCostAnalysis* cost_analysis, const GpuPerformanceModelOptions& config, - std::vector fused_consumers, bool multi_output) { + absl::Span fused_consumers, + bool multi_output) { VLOG(8) << "Producer: " << producer->name(); if (producer->opcode() == HloOpcode::kFusion) { VLOG(10) << producer->fused_instructions_computation()->ToString(); @@ -821,19 +408,8 @@ GpuPerformanceModel::RunTimes GpuPerformanceModel::EstimateRunTimes( EstimateFusedExecTime(producer, producer_runtime, cost_analysis, config, fused_consumers, multi_output); - int64_t fused_consumer_count = fused_consumers.size(); - float total_producer_utilization = 0; - - for (const HloInstruction* fused_consumer : fused_consumers) { - float utilization_by_this_consumer = cost_analysis->operand_utilization( - *fused_consumer, fused_consumer->operand_index(producer)); - total_producer_utilization += utilization_by_this_consumer; - } - if (VLOG_IS_ON(8)) { - LOG(INFO) << "Consumer count: " << fused_consumer_count; - LOG(INFO) << "Utilization of producer output: " - << total_producer_utilization; + LOG(INFO) << "Consumer count: " << fused_consumers.size(); LOG(INFO) << "Unfused time: " << time_unfused; LOG(INFO) << "Fused time: " << time_fused; } @@ -841,6 +417,7 @@ GpuPerformanceModel::RunTimes GpuPerformanceModel::EstimateRunTimes( return {time_unfused, time_fused}; } +/*static*/ void GpuPerformanceModel::RecordEstimatedRunTime( HloInstruction* instruction, const GpuHloCostAnalysis* cost_analysis, const GpuPerformanceModelOptions& config) { @@ -854,196 +431,19 @@ void GpuPerformanceModel::RecordEstimatedRunTime( auto gpu_config = instruction->backend_config(); TF_CHECK_OK(gpu_config.status()) << instruction->ToString(); - FusionBackendConfig& backend_config = - *gpu_config->mutable_fusion_backend_config(); - backend_config.mutable_reification_cost()->set_end_to_end_cycles(cycles); + auto reification_cost = + gpu_config->mutable_fusion_backend_config()->mutable_reification_cost(); + reification_cost->set_end_to_end_cycles(cycles); + reification_cost->set_compute_time_us( + absl::ToDoubleMicroseconds(data.compute_time)); + reification_cost->set_memory_access_time_us( + absl::ToDoubleMicroseconds(data.read_time + data.write_time)); + reification_cost->set_exec_time_us( + absl::ToDoubleMicroseconds(data.exec_time)); TF_CHECK_OK(instruction->set_backend_config(*gpu_config)); VLOG(8) << "RecordEstimatedRunTime: " << instruction->ToString(); } -// Returns NVLink bw in GB/s -/*static*/ -float GpuPerformanceWithCollectiveModel::GetNvlinkBw( - se::CudaComputeCapability compute_capability) { - return compute_capability.IsAtLeast(se::CudaComputeCapability::HOPPER) - ? kSm90NvlinkBandwidth - : compute_capability.IsAtLeast(se::CudaComputeCapability::AMPERE) - ? kSm80NvlinkBandwidth - : compute_capability.IsAtLeast(se::CudaComputeCapability::VOLTA) - ? kSm70NvlinkBandwidth - : compute_capability.IsAtLeast(se::CudaComputeCapability::PASCAL_) - ? kSm60NvlinkBandwidth - : kSm80NvlinkBandwidth; -} - -/*static*/ bool GpuPerformanceWithCollectiveModel::InitNvml() { -#if GOOGLE_CUDA - void* libhandle = dlopen("libnvidia-ml.so.1", RTLD_NOW); - CHECK(libhandle != nullptr) << "Failed to open libnvidia-ml.so.1"; - - struct SymbolEntry { - void** functor; - char const* name; - }; - - std::vector symbols = { - {(void**)&xla_nvmlInit, "nvmlInit_v2"}, - {(void**)&xla_nvmlShutdown, "nvmlShutdown"}, - {(void**)&xla_nvmlDeviceGetHandleByIndex, "nvmlDeviceGetHandleByIndex"}, - {(void**)&xla_nvmlDeviceGetNvLinkCapability, - "nvmlDeviceGetNvLinkCapability"}, - }; - for (SymbolEntry se : symbols) { - *se.functor = dlsym(libhandle, se.name); - } - nvmlReturn_t init_result = xla_nvmlInit(); - return init_result == NVML_SUCCESS; -#else - return false; -#endif // GOOGLE_CUDA -} - -/*static*/ bool GpuPerformanceWithCollectiveModel::ShutdownNvml() { -#if GOOGLE_CUDA - nvmlReturn_t shutdown_result = xla_nvmlShutdown(); - return shutdown_result == NVML_SUCCESS; -#else - return false; -#endif // GOOGLE_CUDA -} - -/*static*/ uint32_t -GpuPerformanceWithCollectiveModel::CheckIfNvlinkSupportsP2P() { -#if GOOGLE_CUDA - // We will use nvml library to detect nvlink capability - // to see if it supports p2p communication. - // We first load libnvidia-ml.so and assign symbols to function pointers - // to avoid linking errors. - // Then gpu 0 will be used to query for nvlink capability, note that - // we only look at link 0 of gpu 0 since all other links are assumed - // to have the same capability. - CHECK(InitNvml()) << "NVML init failed."; - nvmlDevice_t nvml_device; - nvmlReturn_t get_device_result = - xla_nvmlDeviceGetHandleByIndex(0, &nvml_device); - CHECK(get_device_result == NVML_SUCCESS); - - uint32_t supported_p2p = 0; - - nvmlReturn_t nvlink_cap_result = xla_nvmlDeviceGetNvLinkCapability( - nvml_device, /*nvlink link number*/ 0, NVML_NVLINK_CAP_P2P_SUPPORTED, - &supported_p2p); - CHECK(nvlink_cap_result == NVML_SUCCESS); - CHECK(ShutdownNvml()) << "NVML shutdown failed."; - return supported_p2p; -#else - return 0; -#endif // GOOGLE_CUDA -} - -/*static*/ absl::Duration -GpuPerformanceWithCollectiveModel::ComputeAllreduceTime( - const HloInstruction& instr, const GpuHloCostAnalysis* cost_analysis, - const se::DeviceDescription& gpu_device_info) { - // We use nccl group call to launch multiple allreduces so launch overhead - // only occurs once. - absl::Duration total_time = kNcclKernelLaunchOverhead; - stream_executor::CudaComputeCapability compute_cap = - gpu_device_info.cuda_compute_capability(); - - int64_t size_of_speed_array = kIntraNodeSpeeds.size(); - int64_t size_of_sm90_speed_array = kIntraNodeSpeedsSm90.size(); - - int num_speeds = compute_cap.major >= se::CudaComputeCapability::HOPPER - ? size_of_sm90_speed_array - : size_of_speed_array; - const double* speeds = compute_cap.major >= se::CudaComputeCapability::HOPPER - ? kIntraNodeSpeedsSm90.data() - : kIntraNodeSpeeds.data(); - - int speed_index = 0; - float max_sys_bw = - GetMaxSysBwFromGpu(compute_cap, kLowLatencyMaxBandwidths.data()); - - CHECK_GT(max_sys_bw, 0); - - while ((speed_index < num_speeds - 1) && speeds[speed_index] > max_sys_bw) { - speed_index++; - } - float bw_intra_node = speeds[speed_index]; - int64_t num_devices = cost_analysis->NumOfDevices(instr); - - int64_t min_nchannels = - std::max(num_devices, GetMinNumberOfChannels(CollectiveAlgo::RING)); - int64_t num_channels = - std::max(min_nchannels, GetNcclMaxNumChannels(CollectiveAlgo::RING)); - int default_threads = - (bw_intra_node * num_channels <= kPciBandwidth) ? 256 : kLL128NumThreads; - - int warp_size = gpu_device_info.threads_per_warp(); - int num_threads = GetNumThreads(warp_size, kLL128NumThreads / 4, - kLL128NumThreads, default_threads); - - // Since channels are pipelined together, compute time will only occur as in a - // single channel. - absl::Duration compute_time_per_channel = - ComputeTime(gpu_device_info, - cost_analysis->flop_count(instr) / num_channels, num_threads); - total_time += compute_time_per_channel; - - uint32_t supported_p2p = CheckIfNvlinkSupportsP2P(); - - if (supported_p2p == 0) { - VLOG(8) << "Nvlink doesn't support p2p communication. Model will " - "continue using default system bandwidth."; - } else { - VLOG(8) << "Nvlink supports p2p communication, setting intra node " - "bandwidth to nvlink bw."; - bw_intra_node = GetNvlinkBw(compute_cap); - } - - double bus_bandwidth = bw_intra_node * num_channels; - - // Get per channel LL128 ring bandwidth - double per_channel_ring_ll128_Bw = - GetMaxSysBwFromGpu(compute_cap, kPerChannelMaxRingLL128Bandwidths.data()); - - bus_bandwidth = std::min(bus_bandwidth * kRingAlgorithmDiscountFactor, - num_channels * per_channel_ring_ll128_Bw); - double actual_bandwidth = bus_bandwidth * cost_analysis->ScalingRatio(instr); - - absl::Duration communication_time = absl::Microseconds( - cost_analysis->bytes_accessed(instr) / (1e6 * actual_bandwidth)); - total_time += communication_time; - return total_time; -} - -/*static*/ absl::Duration -GpuPerformanceWithCollectiveModel::ComputeCollectiveTime( - const HloInstruction& instr, const GpuHloCostAnalysis* cost_analysis, - const se::DeviceDescription& gpu_device_info) { - if (cost_analysis->NumOfDevices(instr) == 1) { - VLOG(8) << "Returning only kernel launch overhead for a single partition."; - return kNcclKernelLaunchOverhead; - } - - if (HloDataflowAnalysis::IsAsynchronousOperationDone(instr.opcode())) { - VLOG(8) << "Returning 0 cost for async done op " << instr.name(); - return absl::ZeroDuration(); - } - switch (instr.opcode()) { - case HloOpcode::kAllReduce: - case HloOpcode::kAllReduceStart: - return ComputeAllreduceTime(instr, cost_analysis, gpu_device_info); - default: { - LOG(WARNING) - << "Runtime estimate for " << instr.name() - << " not implemented. Returning only the kernel launch time."; - return kNcclKernelLaunchOverhead; - } - } -} - } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model.h b/third_party/xla/xla/service/gpu/model/gpu_performance_model.h index cdd4b5feae4a19..d23c74d96563c2 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model.h +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model.h @@ -16,125 +16,18 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_MODEL_GPU_PERFORMANCE_MODEL_H_ #define XLA_SERVICE_GPU_MODEL_GPU_PERFORMANCE_MODEL_H_ -#include -#include -#include "absl/container/flat_hash_map.h" #include "absl/time/time.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/hlo_traversal.h" -#include "xla/service/gpu/model/fusion_analysis_cache.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" -#include "xla/stream_executor/device_description.h" - -#if GOOGLE_CUDA -#include - -#include "third_party/gpus/cuda/nvml/include/nvml.h" -// Below is a list of function pointers to be used -// for querying device properties through nvml library. -#define NVML_FUNCTOR(name, rettype, args) rettype(*xla_##name) args = nullptr; - -NVML_FUNCTOR(nvmlInit, nvmlReturn_t, ()) -NVML_FUNCTOR(nvmlShutdown, nvmlReturn_t, ()) -NVML_FUNCTOR(nvmlDeviceGetHandleByIndex, nvmlReturn_t, - (unsigned int index, nvmlDevice_t* device)) -NVML_FUNCTOR(nvmlDeviceGetNvLinkCapability, nvmlReturn_t, - (nvmlDevice_t device, unsigned int link, - nvmlNvLinkCapability_t capability, unsigned int* capResult)) - -#endif +#include "xla/service/gpu/model/gpu_performance_model_base.h" namespace xla { namespace gpu { -struct EstimateRunTimeData { - int64_t flops; - int64_t bytes_written; - int64_t num_threads; - absl::Duration write_time; - absl::Duration exec_time; -}; - -class GpuPerformanceModelCache { - public: - // Returns cached runtime data for the instruction or producer-consumer pair. - // Returns nullopt if there is no data in cache. - std::optional Get(const HloInstruction& instruction); - std::optional Get(const HloInstruction& producer, - const HloInstruction& consumer); - - // Sets cache value for the instruction or producer-consumer pair. - void Set(const HloInstruction& instruction, - const EstimateRunTimeData& runtime_data); - void Set(const HloInstruction& producer, const HloInstruction& consumer, - absl::Duration runtime); - - // Removes all cache entries for this instruction. The cache contains entries - // for individual instructions in instruction_runtime_data_ and for - // producer-consumer pairs in fusion_runtime_data_. - void Invalidate(const HloInstruction& instruction); - - private: - absl::Mutex mutex_; - - // Stores unfused runtime data for individual instructions. - absl::flat_hash_map - instruction_runtime_data_; - - // Stores fused runtime data for producer-consumer pairs. - absl::flat_hash_map< - HloInstructionAdaptor, - absl::flat_hash_map> - fusion_runtime_data_; -}; - -struct GpuPerformanceModelOptions { - // Factor for how much parallelism between compute and memory accesses should - // be assumed. If 1.0, assume perfect parallelism (the run time is the maximum - // of both times). If 0.0, assume no parallelism (the run time is the sum of - // both times). - double memory_compute_parallelism = 1.0; - - // If present, use this to retrieve fusion analyses. - HloFusionAnalysisCache* fusion_analysis_cache = nullptr; - - GpuPerformanceModelCache* gpu_performance_model_cache = nullptr; - - static GpuPerformanceModelOptions Default() { - return GpuPerformanceModelOptions(); - } - - static GpuPerformanceModelOptions PriorityFusion( - HloFusionAnalysisCache* fusion_analysis_cache = nullptr, - GpuPerformanceModelCache* gpu_performance_model_cache = nullptr) { - GpuPerformanceModelOptions config; - config.fusion_analysis_cache = fusion_analysis_cache; - config.gpu_performance_model_cache = gpu_performance_model_cache; - // This constant was chosen empirically in early 2024, based on runtime - // performance on a set of benchmarks internal to Google. Intuitively, we - // expect it to be close to 1, but not quite 1 (i.e., sometimes, compute - // or memory accesses will be stalled waiting for the other, but usually - // they won't). - config.memory_compute_parallelism = 0.95; - return config; - } - - static GpuPerformanceModelOptions ForModule(const HloModule* module) { - return module->config().debug_options().xla_gpu_enable_priority_fusion() - ? PriorityFusion() // Only cache within priority fusion. - : Default(); - } -}; - -class GpuPerformanceModel { +class GpuPerformanceModel : public GpuPerformanceModelBase { public: - struct RunTimes { - absl::Duration time_unfused; - absl::Duration time_fused; - }; - static EstimateRunTimeData EstimateRunTimeForInstruction( const HloInstruction* instr, const GpuHloCostAnalysis* cost_analysis, const GpuPerformanceModelOptions& config); @@ -163,113 +56,32 @@ class GpuPerformanceModel { const EstimateRunTimeData& producer_runtime, const GpuHloCostAnalysis* cost_analysis, const GpuPerformanceModelOptions& config, - const std::vector& fused_consumers); + absl::Span fused_consumers); static absl::Duration EstimateFusedExecTime( const HloInstruction* producer, const EstimateRunTimeData& producer_runtime, const GpuHloCostAnalysis* cost_analysis, const GpuPerformanceModelOptions& config, - const std::vector& fused_consumers, bool multi_output); + absl::Span fused_consumers, + bool multi_output); static RunTimes EstimateRunTimes( const HloInstruction* producer, const GpuHloCostAnalysis* cost_analysis, const GpuPerformanceModelOptions& config, - std::vector fused_consumers = {}, + absl::Span fused_consumers = {}, bool multi_output = false); static RunTimes EstimateRunTimesForPriorityFusion( const HloInstruction* producer, const GpuHloCostAnalysis* cost_analysis, const GpuPerformanceModelOptions& config, - std::vector fused_consumers = {}, + absl::Span fused_consumers = {}, bool multi_output = false); // Writes estimated execution time to FusionBackendConfig.reification_cost. static void RecordEstimatedRunTime(HloInstruction* instruction, const GpuHloCostAnalysis* cost_analysis, const GpuPerformanceModelOptions& config); - static absl::Duration ComputeTime( - const se::DeviceDescription& gpu_device_info, int64_t flops, - int64_t num_threads); - - static absl::Duration ProducerInputAccessTime( - const GpuHloCostAnalysis* cost_analysis, - const se::DeviceDescription& gpu_device_info, int64_t num_blocks, - const HloInstruction* producer, const HloFusionAnalysis& fusion_analysis, - const GpuPerformanceModelOptions& config, - const HloInstruction* fused_consumer = nullptr); -}; - -class GpuPerformanceWithCollectiveModel : public GpuPerformanceModel { - public: - // Different algorithms that can be used to perform the collective. - enum CollectiveAlgo { - RING = 0, - TREE, - }; - - // Table for max system bandwidths GB/s for using NCCL's low latency - // algorithm. This is used for intra-node estimate. - static constexpr std::array kLowLatencyMaxBandwidths = { - 39.0 /* Volta*/, 87.7 /* Ampere*/, 87.7 /* Hopper*/ - }; - - // Max bandwidth in GB/s for ring low latency 128 algorithm per channel on a - // single-node - static constexpr std::array kPerChannelMaxRingLL128Bandwidths = { - 20.0 /* Volta */, - 20.0 /* Ampere */, - 36.7 /* Hopper */, - }; - - // Nvlink unidirectional bandwidth for different compute cap. Note this is per - // lane bandwidth. - static constexpr double kSm60NvlinkBandwidth = 18.0; - static constexpr double kSm70NvlinkBandwidth = 20.0; - static constexpr double kSm80NvlinkBandwidth = 20.0; - static constexpr double kSm90NvlinkBandwidth = 20.0; - - // PCIE bandwidth for PCI Gen3 x16 - static constexpr double kPciBandwidth = 12.0; - - // Discount factor for ring algorithm - static constexpr double kRingAlgorithmDiscountFactor = 0.92; - - // Different tiers for intra-node bandwidth. - static constexpr std::array kIntraNodeSpeeds = { - 40.0, 30.0, 20.0, 18.0, 15.0, 12.0, 10.0, 9.0, 7.0, 6.0, 5.0, 4.0, 3.0}; - // SM90 has different bandwidths. - static constexpr std::array kIntraNodeSpeedsSm90 = { - 60.0, 40.0, 30.0, 24.0, 20.0, 15.0, 12.0, 6.0, 3.0}; - - // Maximum number of channels allowed by NCCL - static constexpr int64_t kMaxNumChannelsRing = 16; - - // ll128 is by default enabled for Volta, Ampere and Hopper, ll128 by default - // launches 640 threads. - static constexpr int64_t kLL128NumThreads = 640; - - static absl::Duration ComputeCollectiveTime( - const HloInstruction& instr, const GpuHloCostAnalysis* cost_analysis, - const se::DeviceDescription& gpu_device_info); - - // Returns NVLink bw in GB/s - static float GetNvlinkBw(se::CudaComputeCapability compute_capability); - - // Initialize nvml library. - static bool InitNvml(); - - // Shut down nvml library. - static bool ShutdownNvml(); - - // This checks if the nvlink supports direct P2P communication, - // If not, we will use PCIE bandwidth to estimate latency. - static uint32_t CheckIfNvlinkSupportsP2P(); - - private: - static absl::Duration ComputeAllreduceTime( - const HloInstruction& instr, const GpuHloCostAnalysis* cost_analysis, - const se::DeviceDescription& gpu_device_info); }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc new file mode 100644 index 00000000000000..40bb1ff69b1b7f --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc @@ -0,0 +1,404 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/gpu_performance_model_base.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/fusions/fusion_emitter.h" +#include "xla/service/gpu/fusions/fusions.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" +#include "xla/util.h" + +namespace xla { +namespace gpu { + +namespace { + +// Returns whether a fusion uses the parameter at the given index elementwise +// from its root. +bool FusionUsesParameterElementwiseFromRoot( + const HloInstruction* fusion, int parameter_index, + const GpuHloCostAnalysis* cost_analysis) { + return cost_analysis->CommonElementwiseUtilization( + fusion->fused_parameter(parameter_index), + fusion->fused_expression_root()) == 1.f; +} + +int GetCoalescingWasteFactor(PrimitiveType element_type) { + int64_t element_size_bytes = + element_type == PrimitiveType::TUPLE || + element_type == PrimitiveType::TOKEN + ? 4 /* Dummy value. TODO(jreiffers): Model this case. */ + : ShapeUtil::ByteSizeOfPrimitiveType(element_type); + // Cache line is 128B that is split into 4 sectors of 32B. Default transaction + // size from DRAM -> L2 = 64 Bytes = 2 sectors, since V100, but it can be also + // configured. + // https://developer.download.nvidia.com/video/gputechconf/gtc/2020/presentations/s21819-optimizing-applications-for-nvidia-ampere-gpu-architecture.pdf + // (page 10). + constexpr int kDRAMToL2TransactionSizeBytes = 64; + // Assume we use one element from the cache line and waste the remaining + // bandwidth. For example, if we're reading f32s, we use 1/16nd of the cache + // line. + return kDRAMToL2TransactionSizeBytes / element_size_bytes; +} + +// Limit the bandwidth for low occupancy cases. Each SM can issue at most +// one 32B memory transaction per clock. H100 needs at least 56.8 active SMs +// (1830 MHz) to saturate the memory bandwidth (3.35 TB/s). +float AdjustBandwidth(const se::DeviceDescription& gpu_device_info, + float bandwidth, int64_t num_blocks) { + float per_block_bandwidth = gpu_device_info.clock_rate_ghz() * 1.0e9f * 32; + float max_bandwidth = num_blocks * per_block_bandwidth; + + return std::min(bandwidth, max_bandwidth); +} + +} // namespace + +std::optional GpuPerformanceModelCache::Get( + const HloInstruction& instruction) { + absl::MutexLock lock(&mutex_); + + auto it = instruction_runtime_data_.find(HloInstructionAdaptor(instruction)); + if (it != instruction_runtime_data_.end()) { + return it->second; + } + return std::nullopt; +} + +std::optional GpuPerformanceModelCache::Get( + const HloInstruction& producer, const HloInstruction& consumer) { + absl::MutexLock lock(&mutex_); + + auto it = fusion_runtime_data_.find(HloInstructionAdaptor(producer)); + if (it != fusion_runtime_data_.end()) { + auto jt = it->second.find(HloInstructionAdaptor(consumer)); + if (jt != it->second.end()) { + return jt->second; + } + } + return std::nullopt; +} + +void GpuPerformanceModelCache::Set(const HloInstruction& instruction, + const EstimateRunTimeData& runtime_data) { + absl::MutexLock lock(&mutex_); + + instruction_runtime_data_[HloInstructionAdaptor(instruction)] = runtime_data; +} + +void GpuPerformanceModelCache::Set(const HloInstruction& producer, + const HloInstruction& consumer, + absl::Duration runtime) { + absl::MutexLock lock(&mutex_); + fusion_runtime_data_[HloInstructionAdaptor(producer)] + [HloInstructionAdaptor(consumer)] = runtime; +} + +void GpuPerformanceModelCache::Invalidate(const HloInstruction& instruction) { + absl::MutexLock lock(&mutex_); + HloInstructionAdaptor adaptor(instruction); + + // Remove runtime data for the instruction. + instruction_runtime_data_.erase(adaptor); + + // Remove cache for all producer-consumer pairs where the instruction is + // producer. + fusion_runtime_data_.erase(adaptor); + + // Iterate through operands to find all producer-consumer pairs where + // instruction is consumer and remove them from cache. + for (auto* operand : instruction.operands()) { + auto it = fusion_runtime_data_.find(HloInstructionAdaptor(*operand)); + if (it != fusion_runtime_data_.end()) { + it->second.erase(adaptor); + } + } +} + +/*static*/ +LaunchDimensions GpuPerformanceModelBase::EstimateFusionLaunchDimensions( + int64_t estimated_num_threads, const HloFusionAnalysis& fusion_analysis, + const se::DeviceDescription& device_info) { + auto emitter = + GetFusionEmitter(PreBufferAssignmentFusionInfo{fusion_analysis}); + if (emitter.ok()) { + if (const auto* kernel_emitter = + dynamic_cast(emitter->get())) { + return kernel_emitter->launch_dimensions(); + } + } + int64_t block_size = 128; // Result for default LaunchDimensionsConfig. + int64_t num_blocks = CeilOfRatio(estimated_num_threads, block_size); + return LaunchDimensions(num_blocks, block_size); +} + +/*static*/ +int64_t GpuPerformanceModelBase::GetOperandBytesAccessed( + const GpuHloCostAnalysis* cost_analysis, const HloInstruction* instr, + const HloInstruction* operand) { + // When called for a consumer-producer fusion, the operand can be from a + // different instruction. GpuHloCostAnalysis can't fail gravefully in this + // case, so we need an explicit check. + if (!instr->IsUserOf(operand)) { + return 0; + } + + return cost_analysis->operand_bytes_accessed(*instr, + instr->operand_index(operand)); +} + +/*static*/ +float GpuPerformanceModelBase::GetOperandUtilization( + const GpuHloCostAnalysis* cost_analysis, const HloInstruction* instr, + const HloInstruction* operand) { + // When called for a consumer-producer fusion, the operand can be from a + // different instruction. GpuHloCostAnalysis can't fail gravefully in this + // case, so we need an explicit check. + if (!instr->IsUserOf(operand)) { + return 0.f; + } + + return cost_analysis->operand_utilization(*instr, + instr->operand_index(operand)); +} + +/*static*/ +float GpuPerformanceModelBase::GetCommonUtilization( + const GpuHloCostAnalysis* cost_analysis, const HloInstruction* producer, + int64_t producer_idx_of_operand, const HloInstruction* consumer) { + const auto* operand = producer->operand(producer_idx_of_operand); + + if (!consumer || !consumer->IsUserOf(operand)) { + return 0.f; + } + + if (producer->IsElementwise() || + (producer->opcode() == HloOpcode::kFusion && + FusionUsesParameterElementwiseFromRoot(producer, producer_idx_of_operand, + cost_analysis))) { + if (consumer->opcode() == HloOpcode::kFusion) { + int64_t consumer_idx_of_common_operand = consumer->operand_index(operand); + int64_t consumer_idx_of_producer = consumer->operand_index(producer); + return cost_analysis->CommonElementwiseUtilization( + consumer->fused_parameter(consumer_idx_of_common_operand), + consumer->fused_parameter(consumer_idx_of_producer)); + } else { + if (consumer->IsElementwise()) { + return 1.f; + } + } + } + return 0.f; +} + +/*static*/ +int64_t GpuPerformanceModelBase::GetSharedOperandBytesAccessed( + const GpuHloCostAnalysis* cost_analysis, const HloInstruction* producer, + const HloInstruction* consumer, const HloInstruction* operand) { + float producer_utilization_by_consumer = + GetOperandUtilization(cost_analysis, consumer, producer); + + int64_t bytes_accessed_by_producer = + GetOperandBytesAccessed(cost_analysis, producer, operand); + + int64_t bytes_accessed_by_consumer = + GetOperandBytesAccessed(cost_analysis, consumer, operand); + + float common_utilization = + producer->IsUserOf(operand) + ? GetCommonUtilization(cost_analysis, producer, + producer->operand_index(operand), consumer) + : 0.f; + + int64_t operand_size = cost_analysis->GetShapeSize(operand->shape()); + int64_t common_bytes_accessed = + std::llround(operand_size * common_utilization); + + return std::llround(bytes_accessed_by_producer * + producer_utilization_by_consumer) + + bytes_accessed_by_consumer - common_bytes_accessed; +} + +/*static*/ +absl::Duration GpuPerformanceModelBase::ReadTime( + const se::DeviceDescription& gpu_device_info, int64_t num_blocks, + int64_t n_bytes_net, int64_t n_bytes_total) { + float bandwidth = gpu_device_info.memory_bandwidth(); + if (n_bytes_net < gpu_device_info.l2_cache_size()) { + bandwidth *= kL2CacheSpeedup; + if (n_bytes_net < kL1CacheSizePerSM * gpu_device_info.core_count()) { + bandwidth *= kL1CacheSpeedup; + } + } + + bandwidth = AdjustBandwidth(gpu_device_info, bandwidth, num_blocks); + return absl::Seconds(n_bytes_total / bandwidth); +} + +/*static*/ +absl::Duration GpuPerformanceModelBase::ReadTimeWithDRAMHeuristic( + const se::DeviceDescription& gpu_device_info, int64_t num_blocks, + int64_t n_bytes_net, int64_t n_bytes_total, PrimitiveType element_type, + bool coalesced) { + int waste_factor = coalesced ? 1 : GetCoalescingWasteFactor(element_type); + + // The first read of the input buffer always happens from DRAM. If reads are + // no coaleced, bandwidth is reduced by the waste factor. + float dram_bandwidth = gpu_device_info.memory_bandwidth() / waste_factor; + + // Two things can happed on re-reading the buffer: + // - If the buffer fits into cache, the L1/L2 cache speedup is applied. + // - If the buffer doesn't fit, it will be read from DRAM and the same + // coalessing waste factor is applied. + float rest_bandwidth = gpu_device_info.memory_bandwidth(); + if (n_bytes_net < gpu_device_info.l2_cache_size()) { + rest_bandwidth *= kL2CacheSpeedup; + if (n_bytes_net < kL1CacheSizePerSM * gpu_device_info.core_count()) { + rest_bandwidth *= kL1CacheSpeedup; + } + } else { + rest_bandwidth /= waste_factor; + } + + dram_bandwidth = AdjustBandwidth(gpu_device_info, dram_bandwidth, num_blocks); + rest_bandwidth = AdjustBandwidth(gpu_device_info, rest_bandwidth, num_blocks); + + // n_bytes_net > n_bytes_total can happen when we compute read time of + // shared operand. This is a flaw in the interface that should be fixed. + int64_t n_bytes_read_dram = std::min(n_bytes_net, n_bytes_total); + + // Number of bytes that we be re-read, potentially from cache. + int64_t n_bytes_read_cache = n_bytes_total - n_bytes_read_dram; + + return absl::Seconds(n_bytes_read_dram / dram_bandwidth) + + absl::Seconds(n_bytes_read_cache / rest_bandwidth); +} + +/*static*/ absl::Duration GpuPerformanceModelBase::ProducerInputAccessTime( + const GpuHloCostAnalysis* cost_analysis, + const se::DeviceDescription& gpu_device_info, int64_t num_blocks, + const HloInstruction* producer, const HloFusionAnalysis& fusion_analysis, + const GpuPerformanceModelOptions& config, + const HloInstruction* fused_consumer) { + absl::Duration ret = absl::ZeroDuration(); + float producer_output_utilization = + fused_consumer + ? GetOperandUtilization(cost_analysis, fused_consumer, producer) + : 1.f; + + for (int i = 0; i < producer->operand_count(); ++i) { + // Information about data read taking into account utilization. + // If `operand_utilization` is 0, `operand_bytes_accessed` should be also 0. + int64_t operand_bytes_accessed = + cost_analysis->operand_bytes_accessed(*producer, i); + float operand_utilization = + cost_analysis->operand_utilization(*producer, i); + + // An estimate how much data would need to fit into L1/L2 cache to speed up + // the operand access. + // If `operand_utilization` < 1, only a part of the full operand size should + // be read. Otherwise, `operand_bytes_accessed / operand_utilization` is the + // size of the operand without reuse. + int64_t n_bytes_net = std::llround(operand_bytes_accessed / + std::max(operand_utilization, 1.0f)); + + // Look if common operand of producer and consumer will be accessed more + // efficiently on merge. + float common_utilization = GetCommonUtilization( + cost_analysis, producer, /*producer_idx_of_operand=*/i, fused_consumer); + + CHECK_LE(common_utilization, producer_output_utilization); + float n_bytes_total = operand_bytes_accessed * + (producer_output_utilization - common_utilization); + ret += ReadTime(gpu_device_info, num_blocks, n_bytes_net, n_bytes_total); + } + return ret; +} + +/*static*/ +absl::Duration GpuPerformanceModelBase::WriteTime( + const se::DeviceDescription& gpu_device_info, int64_t bytes_written) { + return absl::Seconds(1.0f * bytes_written / + gpu_device_info.memory_bandwidth()); +} + +/*static*/ +absl::Duration GpuPerformanceModelBase::ComputeTime( + const se::DeviceDescription& gpu_device_info, int64_t flops, + int64_t num_threads) { + int64_t fpu_count = + gpu_device_info.core_count() * gpu_device_info.fpus_per_core(); + int64_t n_threads_active = std::min(num_threads, fpu_count); + int64_t flop_per_ns_per_fpu = gpu_device_info.clock_rate_ghz() * /*fma:*/ 2; + int64_t flop_per_ns_effective = flop_per_ns_per_fpu * n_threads_active; + return absl::Nanoseconds(1.0f * flops / flop_per_ns_effective); +} + +/*static*/ +absl::Duration GpuPerformanceModelBase::CombineComputeAndMemoryAccessTime( + absl::Duration compute_time, absl::Duration memory_access_time, + const GpuPerformanceModelOptions& config) { + return compute_time + memory_access_time - + std::min(compute_time, memory_access_time) * + config.memory_compute_parallelism; +} + +/*static*/ +void GpuPerformanceModelBase::VLogOperandRead(const HloInstruction* operand, + int64_t n_bytes_total, + int64_t n_bytes_net, + bool coalesced) { + VLOG(8) << "operand " << operand->name() + << ", n_bytes_total: " << n_bytes_total + << ", n_bytes_net: " << n_bytes_net << ", coalesced: " << coalesced; +} + +/*static*/ +void GpuPerformanceModelBase::VLogResult( + int64_t flops, int64_t bytes_read, int64_t bytes_written, + int64_t num_threads, absl::Duration compute_time, absl::Duration read_time, + absl::Duration write_time, absl::Duration exec_time) { + if (VLOG_IS_ON(8)) { + LOG(INFO) << "FLOPs: " << flops; + LOG(INFO) << "Bytes read: " << bytes_read; + LOG(INFO) << "Bytes written: " << bytes_written; + LOG(INFO) << "Num threads: " << num_threads; + LOG(INFO) << "Compute time: " << compute_time; + LOG(INFO) << "Input read time: " << read_time; + LOG(INFO) << "Output write time: " << write_time; + LOG(INFO) << "Exec time: " << exec_time; + } +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.h b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.h new file mode 100644 index 00000000000000..7d08a0c68a0bb1 --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.h @@ -0,0 +1,229 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MODEL_GPU_PERFORMANCE_MODEL_BASE_H_ +#define XLA_SERVICE_GPU_MODEL_GPU_PERFORMANCE_MODEL_BASE_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/fusion_analysis_cache.h" +#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/stream_executor/device_description.h" + +namespace xla { +namespace gpu { + +struct EstimateRunTimeData { + int64_t flops; + int64_t bytes_written; + int64_t num_threads; + absl::Duration read_time; + absl::Duration write_time; + absl::Duration compute_time; + absl::Duration exec_time; +}; + +class GpuPerformanceModelCache { + public: + // Returns cached runtime data for the instruction or producer-consumer pair. + // Returns nullopt if there is no data in cache. + std::optional Get(const HloInstruction& instruction); + std::optional Get(const HloInstruction& producer, + const HloInstruction& consumer); + + // Sets cache value for the instruction or producer-consumer pair. + void Set(const HloInstruction& instruction, + const EstimateRunTimeData& runtime_data); + void Set(const HloInstruction& producer, const HloInstruction& consumer, + absl::Duration runtime); + + // Removes all cache entries for this instruction. The cache contains entries + // for individual instructions in instruction_runtime_data_ and for + // producer-consumer pairs in fusion_runtime_data_. + void Invalidate(const HloInstruction& instruction); + + private: + absl::Mutex mutex_; + + // Stores unfused runtime data for individual instructions. + absl::flat_hash_map + instruction_runtime_data_; + + // Stores fused runtime data for producer-consumer pairs. + absl::flat_hash_map< + HloInstructionAdaptor, + absl::flat_hash_map> + fusion_runtime_data_; +}; + +struct GpuPerformanceModelOptions { + // Factor for how much parallelism between compute and memory accesses should + // be assumed. If 1.0, assume perfect parallelism (the run time is the maximum + // of both times). If 0.0, assume no parallelism (the run time is the sum of + // both times). + double memory_compute_parallelism = 1.0; + + // If present, use this to retrieve fusion analyses. + HloFusionAnalysisCache* fusion_analysis_cache = nullptr; + + GpuPerformanceModelCache* gpu_performance_model_cache = nullptr; + + static GpuPerformanceModelOptions Default() { + return GpuPerformanceModelOptions(); + } + + static GpuPerformanceModelOptions PriorityFusion( + HloFusionAnalysisCache* fusion_analysis_cache = nullptr, + GpuPerformanceModelCache* gpu_performance_model_cache = nullptr) { + GpuPerformanceModelOptions config; + config.fusion_analysis_cache = fusion_analysis_cache; + config.gpu_performance_model_cache = gpu_performance_model_cache; + // This constant was chosen empirically in early 2024, based on runtime + // performance on a set of benchmarks internal to Google. Intuitively, we + // expect it to be close to 1, but not quite 1 (i.e., sometimes, compute + // or memory accesses will be stalled waiting for the other, but usually + // they won't). + config.memory_compute_parallelism = 0.95; + return config; + } + + static GpuPerformanceModelOptions ForModule(const HloModule* module) { + return module->config().debug_options().xla_gpu_enable_priority_fusion() + ? PriorityFusion() // Only cache within priority fusion. + : Default(); + } +}; + +class GpuPerformanceModelBase { + public: + struct RunTimes { + absl::Duration time_unfused; + absl::Duration time_fused; + }; + + // Estimated values in the absence of easy ways to query them. + static constexpr absl::Duration kKernelLaunchOverhead = absl::Microseconds(1); + static constexpr absl::Duration kNcclKernelLaunchOverhead = + absl::Microseconds(5); + static constexpr float kL2CacheSpeedup = 2.5; + static constexpr float kL1CacheSpeedup = 8; + // A very conservative estimate. L1 size varies because it can be dynamically + // configured as shared memory; there is no easy way to query its actual size; + // also we do not count what occupies cache, but rather claim that what is + // much smaller than the cache size will likely stay in it. + // For reference, it can be up to 256 kB per SM on RTX A6000. + static constexpr float kL1CacheSizePerSM = 2 * 1024; + + // Uses HloFusionAnalysis for computing the actual number of threads and + // blocks that the IR emitter will use. + static LaunchDimensions EstimateFusionLaunchDimensions( + int64_t estimated_num_threads, const HloFusionAnalysis& fusion_analysis, + const se::DeviceDescription& device_info); + + // Returns bytes accessed of operand output by instruction. Returns 0, if the + // operand is not used by the instruction. + static int64_t GetOperandBytesAccessed( + const GpuHloCostAnalysis* cost_analysis, const HloInstruction* instr, + const HloInstruction* operand); + + // Returns utilization of operand by instruction. Returns 0, if the operand is + // not used by the instruction. + static float GetOperandUtilization(const GpuHloCostAnalysis* cost_analysis, + const HloInstruction* instr, + const HloInstruction* operand); + + // Returns utilization `overlap` between a common operand of producer and + // consumer on merge. `utilization > 0` means that the operand will be + // accessed more efficiently after fusion. + // + // Currently covers two cases: + // 1) Producer has to use the common operand elementwise from its root if it + // is a fusion or just be an elementwise instruction. + // 2) Consumer has to have common elementwise roots for the producer and the + // common operand if it is a fusion or just be an elementwise instruction. + static float GetCommonUtilization(const GpuHloCostAnalysis* cost_analysis, + const HloInstruction* producer, + int64_t producer_idx_of_operand, + const HloInstruction* consumer); + + // Returns bytes accessed of operand after producer and consumer are fused + // together. `GetCommonUtilization` works only for a limited set of + // elementwise cases. + static int64_t GetSharedOperandBytesAccessed( + const GpuHloCostAnalysis* cost_analysis, const HloInstruction* producer, + const HloInstruction* consumer, const HloInstruction* operand); + + // Estimate read time of n_bytes_total bytes from global memory on a + // given GPU. Account for L1 / L2 cache speedup if the input's nominal size + // n_bytes_net is small. + static absl::Duration ReadTime(const se::DeviceDescription& gpu_device_info, + int64_t num_blocks, int64_t n_bytes_net, + int64_t n_bytes_total); + + // Estimate read time of n_bytes_total bytes from global memory on a + // given GPU. + // + // Assumes that the first n_bytes_net are always read from DRAM, but next + // reads can be cached. Applies waste factor if read from DRAM is uncoalesced. + static absl::Duration ReadTimeWithDRAMHeuristic( + const se::DeviceDescription& gpu_device_info, int64_t num_blocks, + int64_t n_bytes_net, int64_t n_bytes_total, PrimitiveType element_type, + bool coalesced); + + // Tells input access time of the producer alone if fused_consumer + // is not specified. Otherwise estimates the access time to producer's + // inputs as if it is fused into the consumer. + static absl::Duration ProducerInputAccessTime( + const GpuHloCostAnalysis* cost_analysis, + const se::DeviceDescription& gpu_device_info, int64_t num_blocks, + const HloInstruction* producer, const HloFusionAnalysis& fusion_analysis, + const GpuPerformanceModelOptions& config, + const HloInstruction* fused_consumer = nullptr); + + static absl::Duration WriteTime(const se::DeviceDescription& gpu_device_info, + int64_t bytes_written); + + static absl::Duration ComputeTime( + const se::DeviceDescription& gpu_device_info, int64_t flops, + int64_t num_threads); + + static absl::Duration CombineComputeAndMemoryAccessTime( + absl::Duration compute_time, absl::Duration memory_access_time, + const GpuPerformanceModelOptions& config); + + // Logs estimates for the operand read if VLOG is enabled. + static void VLogOperandRead(const HloInstruction* operand, + int64_t n_bytes_total, int64_t n_bytes_net, + bool coalesced); + + // Logs estimate results of the performance model if VLOG is enabled. + static void VLogResult(int64_t flops, int64_t bytes_read, + int64_t bytes_written, int64_t num_threads, + absl::Duration compute_time, absl::Duration read_time, + absl::Duration write_time, absl::Duration exec_time); +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_MODEL_GPU_PERFORMANCE_MODEL_BASE_H_ diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc new file mode 100644 index 00000000000000..d15c0d4339bfc2 --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc @@ -0,0 +1,196 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/gpu_performance_model_base.h" + +#include + +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" +#include "xla/test_helpers.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +class GpuPerformanceModelBaseTest : public HloTestBase { + public: + GpuHloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const { + return [&](const Shape& shape) { + constexpr int64_t kPointerSize = 8; + return ShapeUtil::ByteSizeOf(shape, kPointerSize); + }; + } + + GpuHloCostAnalysis::Options options_{ShapeSizeBytesFunction(), + /*per_second_rates=*/{}, + /*count_multiple_input_accesses=*/true}; + // The reference times in the test cases below are measured + // on A6000 by profiling the execution of the HLOs. + se::DeviceDescription device_info_{TestGpuDeviceInfo::RTXA6000DeviceInfo()}; + GpuHloCostAnalysis analysis_{options_, &device_info_}; + + GpuPerformanceModelBaseTest() : HloTestBase() {} +}; + +TEST_F(GpuPerformanceModelBaseTest, SharedOperandBytesAccessed_InPlaceDUS) { + absl::string_view hlo_string = R"( +HloModule m + +ENTRY entry_computation { + param_0 = f32[8,16] parameter(0) + param_1 = f32[4,4] parameter(1) + c_0 = s32[] constant(0) + log = f32[4,4] log(param_1) + ROOT dynamic-update-slice = f32[8,16] dynamic-update-slice(param_0, log, c_0, c_0) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + auto computation = module->entry_computation(); + ASSERT_IS_OK(computation->Accept(&analysis_)); + + auto dus_consumer = computation->root_instruction(); + auto log_producer = dus_consumer->mutable_operand(1); + + auto get_shared_operand_bytes_accessed = [&](const HloInstruction* operand) { + return GpuPerformanceModelBase::GetSharedOperandBytesAccessed( + &analysis_, log_producer, dus_consumer, operand); + }; + + EXPECT_EQ(get_shared_operand_bytes_accessed(dus_consumer->operand(0)), 0); + EXPECT_EQ(get_shared_operand_bytes_accessed(log_producer->operand(0)), 64); +} + +TEST_F(GpuPerformanceModelBaseTest, SharedOperandBytesAccessed_DUS) { + absl::string_view hlo_string = R"( +HloModule m + +ENTRY entry_computation { + param_0 = f32[8,16] parameter(0) + param_1 = f32[4,4] parameter(1) + c_0 = s32[] constant(0) + log = f32[8,16] log(param_0) + ROOT dynamic-update-slice = f32[8,16] dynamic-update-slice(log, param_1, c_0, c_0) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + auto computation = module->entry_computation(); + ASSERT_IS_OK(computation->Accept(&analysis_)); + + auto dus_consumer = computation->root_instruction(); + auto log_producer = dus_consumer->mutable_operand(0); + + auto get_shared_operand_bytes_accessed = [&](const HloInstruction* operand) { + return GpuPerformanceModelBase::GetSharedOperandBytesAccessed( + &analysis_, log_producer, dus_consumer, operand); + }; + + EXPECT_EQ(get_shared_operand_bytes_accessed(dus_consumer->operand(1)), 64); + EXPECT_EQ(get_shared_operand_bytes_accessed(log_producer->operand(0)), 448); +} + +// This test documents current behaviour. See comments below how the correct +// result should look like. +TEST_F(GpuPerformanceModelBaseTest, + ReduceBroadcastedDim_IncorrectBytesAccessed) { + absl::string_view hlo_string = R"( +HloModule m + +add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) +} + +f1 { + p0 = f32[128] parameter(0) + c0 = f32[] constant(0) + broadcast = f32[128,256] broadcast(p0), dimensions={0} + ROOT reduce = f32[128] reduce(broadcast, c0), dimensions={1}, to_apply=add +} + +ENTRY entry_computation { + param_0 = f32[128] parameter(0) + param_1 = f32[4,4] parameter(1) + ROOT fusion = f32[128] fusion(param_0), kind=kLoop, calls=f1 +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + auto computation = module->entry_computation(); + ASSERT_IS_OK(computation->Accept(&analysis_)); + + auto root = computation->root_instruction(); + + // Cost Model estimates that input element we be re-read in reduce. Each + // element of reduce output needs only one input element. Bytes accessed + // should be 4*128=512. + EXPECT_EQ(GpuPerformanceModelBase::GetOperandBytesAccessed(&analysis_, root, + root->operand(0)), + /*4*128*256=*/131072); +} + +// This test documents current behaviour. See comments below how the correct +// result should look like. +TEST_F(GpuPerformanceModelBaseTest, ElementwiseBitcast_IncorrectBytesAccessed) { + absl::string_view hlo_string = R"( +HloModule m + +f1 { + p0 = f32[128] parameter(0) + bitcast.1 = f32[8,16] bitcast(p0) + log = f32[128] log(p0) + bitcast.2 = f32[8,16] bitcast(log) + ROOT add = f32[8,16] add(bitcast.1, bitcast.2) +} + +ENTRY entry_computation { + param_0 = f32[128] parameter(0) + ROOT fusion = f32[8,16] fusion(param_0), kind=kLoop, calls=f1 +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + auto computation = module->entry_computation(); + ASSERT_IS_OK(computation->Accept(&analysis_)); + + auto root = computation->root_instruction(); + + // Bitcast breaks the chain of elementwise utilization even if the bitcast + // doesn't change physical layout. Each element of `param_0` should be read + // only once, but Cost Model estimates that it will be accessed twice. Bytes + // accessed should be 4*128=512. + EXPECT_EQ(GpuPerformanceModelBase::GetOperandBytesAccessed(&analysis_, root, + root->operand(0)), + /*2*4*128=*/1024); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc index 6fa02ad2ad7f60..82984b1193bef8 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc @@ -21,8 +21,10 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -30,6 +32,8 @@ limitations under the License. #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/service/gpu/model/gpu_indexing_performance_model.h" +#include "xla/service/gpu/model/gpu_performance_model_base.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -68,13 +72,18 @@ class GpuPerformanceModelTest : public HloTestBase { fused_consumers); } + mlir::MLIRContext mlir_context_; GpuHloCostAnalysis::Options options_{ShapeSizeBytesFunction(), /*per_second_rates=*/{}, /*count_multiple_input_accesses=*/true}; // The reference times in the test cases below are measured // on A6000 by profiling the execution of the HLOs. - se::DeviceDescription dev_info_{TestGpuDeviceInfo::RTXA6000DeviceInfo()}; - GpuHloCostAnalysis analysis_{options_, &dev_info_}; + se::DeviceDescription device_info_{TestGpuDeviceInfo::RTXA6000DeviceInfo()}; + GpuHloCostAnalysis analysis_{options_, &device_info_}; + + GpuPerformanceModelWithIndexingAnalysis indexing_cost_model_{ + &device_info_, ShapeSizeBytesFunction(), &mlir_context_}; + GpuPerformanceModelTest() : HloTestBase() {} }; @@ -103,6 +112,9 @@ ENTRY e { auto prio_t = EstimateRunTimesForPriorityFusion(root); // Dominated by the DRAM bandwidth. EXPECT_NEAR(absl::ToInt64Microseconds(prio_t.time_unfused), 53, 10); + + auto indexing_t = indexing_cost_model_.EstimateRunTimes(root); + EXPECT_NEAR(absl::ToInt64Microseconds(indexing_t.time_unfused), 53, 10); } TEST_F(GpuPerformanceModelTest, SmallReadWrite) { @@ -132,11 +144,14 @@ ENTRY e { GpuPerformanceModel::RecordEstimatedRunTime( root, &analysis_, GpuPerformanceModelOptions::Default()); - double recorded_cycles = root->backend_config() - ->fusion_backend_config() - .reification_cost() - .end_to_end_cycles(); - EXPECT_NEAR(recorded_cycles, 257.7, 0.1); + auto reification_cost = root->backend_config() + ->fusion_backend_config() + .reification_cost(); + EXPECT_NEAR(reification_cost.end_to_end_cycles(), 257.7, 0.1); + EXPECT_NEAR(reification_cost.exec_time_us(), 0, 1); + + auto indexing_t = indexing_cost_model_.EstimateRunTimes(root); + EXPECT_NEAR(absl::ToInt64Microseconds(indexing_t.time_unfused), 1, 1); } TEST_F(GpuPerformanceModelTest, LargeReadWrite) { @@ -166,11 +181,13 @@ ENTRY e { GpuPerformanceModel::RecordEstimatedRunTime( root, &analysis_, GpuPerformanceModelOptions::Default()); - double recorded_cycles = root->backend_config() - ->fusion_backend_config() - .reification_cost() - .end_to_end_cycles(); - EXPECT_NEAR(recorded_cycles, 220284, 100); + auto reification_cost = root->backend_config() + ->fusion_backend_config() + .reification_cost(); + EXPECT_NEAR(reification_cost.end_to_end_cycles(), 220284, 100); + EXPECT_NEAR(reification_cost.exec_time_us(), 156, 10); + EXPECT_NEAR(reification_cost.compute_time_us(), 1, 1); + EXPECT_NEAR(reification_cost.memory_access_time_us(), 156, 10); } TEST_F(GpuPerformanceModelTest, L1CacheEffect) { @@ -261,23 +278,6 @@ TEST_F(GpuPerformanceModelTest, UnusedParameter) { EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 1, 1); } -using GpuPerformanceWithCollectiveModelTest = GpuPerformanceModelTest; - -TEST_F(GpuPerformanceWithCollectiveModelTest, TestNvmlLibraryLoading) { -#if GOOGLE_CUDA - EXPECT_TRUE(GpuPerformanceWithCollectiveModel::InitNvml()); - // After successful init, we try to use one of the - // nvml functions to see if the result is good. - nvmlDevice_t nvml_device; - nvmlReturn_t get_device_result = - xla_nvmlDeviceGetHandleByIndex(0, &nvml_device); - EXPECT_TRUE(get_device_result == NVML_SUCCESS); - - EXPECT_TRUE(GpuPerformanceWithCollectiveModel::InitNvml()); - -#endif // GOOGLE_CUDA -} - TEST_F(GpuPerformanceModelTest, ComputeBoundReducesWithSameLaunchDimensions) { // We compare two compute-bound reduces that do ~the same amount of compute // and have the same launch dimensions. The result should be approximately @@ -331,7 +331,7 @@ ENTRY fusion { auto run = [&](absl::string_view hlo_text) -> absl::StatusOr { TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_text)); - GpuHloCostAnalysis analysis(options_, &dev_info_); + GpuHloCostAnalysis analysis(options_, &device_info_); TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&analysis)); auto* producer = diff --git a/third_party/xla/xla/service/gpu/model/hlo_op_profiler_run.cc b/third_party/xla/xla/service/gpu/model/hlo_op_profiler_run.cc index 7e45f20dac33bf..38479bbc982a38 100644 --- a/third_party/xla/xla/service/gpu/model/hlo_op_profiler_run.cc +++ b/third_party/xla/xla/service/gpu/model/hlo_op_profiler_run.cc @@ -89,6 +89,7 @@ int RunProfiler(int argc, char** argv) { // Unary HloOpcode::kCbrt, HloOpcode::kCos, + HloOpcode::kErf, HloOpcode::kExp, HloOpcode::kExpm1, HloOpcode::kLog, diff --git a/third_party/xla/xla/service/gpu/model/hlo_op_profiler_test.cc b/third_party/xla/xla/service/gpu/model/hlo_op_profiler_test.cc index e2d7ccbcb93446..6a8ed6538e8edb 100644 --- a/third_party/xla/xla/service/gpu/model/hlo_op_profiler_test.cc +++ b/third_party/xla/xla/service/gpu/model/hlo_op_profiler_test.cc @@ -39,7 +39,7 @@ TEST_F(HloOpProfilerTest, BasicMeasurementsAreCorrect) { EXPECT_GT(profiler.MeasureClockCyclesPerOp(HloOpcode::kDivide, F64) .value() .clock_cycles(), - 500); + 400); // c128 sqrt is slow. EXPECT_GT(profiler.MeasureClockCyclesPerOp(HloOpcode::kSqrt, C128) .value() diff --git a/third_party/xla/xla/service/gpu/model/hlo_op_profiles.cc b/third_party/xla/xla/service/gpu/model/hlo_op_profiles.cc index fcceaa5a4057c7..e8a46b077478cc 100644 --- a/third_party/xla/xla/service/gpu/model/hlo_op_profiles.cc +++ b/third_party/xla/xla/service/gpu/model/hlo_op_profiles.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/model/hlo_op_profiles_data.h" #include "xla/stream_executor/device_description.h" #include "tsl/platform/logging.h" #include "tsl/platform/protobuf.h" @@ -31,6 +32,14 @@ limitations under the License. namespace xla { namespace gpu { +/*static*/ const HloOpProfiles& HloOpProfiles::Singleton() { + static const auto* hlo_op_profiles = + HloOpProfiles::Load(kDeviceHloOpProfiles, + /*default_profile_name=*/"sm_86") + .release(); + return *hlo_op_profiles; +} + /*static*/ std::string HloOpProfiles::GetProfileName( const se::DeviceDescription* device_info) { if (device_info != nullptr) { diff --git a/third_party/xla/xla/service/gpu/model/hlo_op_profiles.h b/third_party/xla/xla/service/gpu/model/hlo_op_profiles.h index 91866f0609fe1b..000ffc601b4dec 100644 --- a/third_party/xla/xla/service/gpu/model/hlo_op_profiles.h +++ b/third_party/xla/xla/service/gpu/model/hlo_op_profiles.h @@ -40,6 +40,9 @@ class HloOpProfiles { absl::flat_hash_map; + // Returns singleton with profiler data. + static const HloOpProfiles& Singleton(); + // Returns profile name for the gived device. // For CUDA, the format is "sm_XX". static std::string GetProfileName(const se::DeviceDescription* device_info); diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis.cc b/third_party/xla/xla/service/gpu/model/indexing_analysis.cc index b16e6727bf7bc4..88c4c4fadb95dc 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_analysis.cc @@ -24,13 +24,17 @@ limitations under the License. #include #include #include +#include #include #include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" #include "absl/types/span.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "mlir/IR/AffineExpr.h" // from @llvm-project @@ -42,8 +46,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/layout.h" -#include "xla/layout_util.h" #include "xla/permutation_util.h" +#include "xla/service/gpu/fusions/tiling_util.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/model/affine_map_printer.h" @@ -62,21 +66,26 @@ using mlir::AffineExpr; using mlir::AffineMap; using mlir::getAffineConstantExpr; using mlir::getAffineDimExpr; +using mlir::getAffineSymbolExpr; using mlir::MLIRContext; HloInstructionIndexing CreateUnknownIndexing(int64_t count = 1) { HloInstructionIndexing indexing; - indexing.indexing_maps = - std::vector>>( - count, {std::nullopt}); + indexing.indexing_maps = std::vector>( + count, {IndexingMap::GetUndefined()}); return indexing; } IndexingMap CreateIdentityMap(const Shape& shape, MLIRContext* ctx) { + if (shape.IsTuple()) { + // Should happen only for variadic reduce. In that case all tuple shapes are + // equal. + return CreateIdentityMap(shape.tuple_shapes(0), ctx); + } + auto dims = shape.dimensions(); - IndexingMap identity_map{ - .affine_map = AffineMap::getMultiDimIdentityMap(dims.size(), ctx), - .domain = Domain::FromUpperBounds(dims, {})}; + IndexingMap identity_map = IndexingMap::FromTensorSizes( + AffineMap::getMultiDimIdentityMap(dims.size(), ctx), dims, {}); return identity_map; } @@ -108,10 +117,10 @@ HloInstructionIndexing ComputeOutputToInputBroadcastOpIndexing( for (int64_t bcast_dim : bcast->dimensions()) { exprs.push_back(getAffineDimExpr(bcast_dim, mlir_context)); } - IndexingMap indexing_map{ - .affine_map = AffineMap::get(output_dims.size(), /*symbolCount=*/0, exprs, - mlir_context), - .domain = Domain::FromUpperBounds(output_dims, {})}; + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + AffineMap::get(output_dims.size(), /*symbolCount=*/0, exprs, + mlir_context), + output_dims, {}); return HloInstructionIndexing::FromIndexingMaps({indexing_map}); } @@ -138,15 +147,23 @@ HloInstructionIndexing ComputeInputToOutputBroadcastOpIndexing( exprs.push_back(getAffineDimExpr( std::distance(bcast_dims.begin(), bcast_dim), mlir_context)); } - IndexingMap indexing_map{ - .affine_map = AffineMap::get(input_shape.rank(), added_dims_sizes.size(), - exprs, mlir_context), - .domain = - Domain::FromUpperBounds(input_shape.dimensions(), added_dims_sizes)}; + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + AffineMap::get(input_shape.rank(), added_dims_sizes.size(), exprs, + mlir_context), + input_shape.dimensions(), added_dims_sizes); return HloInstructionIndexing::FromIndexingMaps({indexing_map}); } +std::vector RangesFromUpperBounds(absl::Span bounds) { + std::vector dim_ranges; + dim_ranges.reserve(bounds.size()); + for (int64_t dim : bounds) { + dim_ranges.push_back(Range{0, dim - 1}); + } + return dim_ranges; +} + HloInstructionIndexing ComputeOutputToInputConcatenateOpIndexing( const HloConcatenateInstruction* concat, MLIRContext* mlir_context) { const auto& operand_0_dims = concat->operand(0)->shape().dimensions(); @@ -155,11 +172,7 @@ HloInstructionIndexing ComputeOutputToInputConcatenateOpIndexing( // be adjusted for a particular operand_id. mlir::MutableAffineMap affine_map = AffineMap::getMultiDimIdentityMap(operand_0_dims.size(), mlir_context); - std::vector dim_ranges; - dim_ranges.reserve(operand_0_dims.size()); - for (int64_t dim : operand_0_dims) { - dim_ranges.push_back(Range{0, dim - 1}); - } + std::vector dim_ranges = RangesFromUpperBounds(operand_0_dims); HloInstructionIndexing concat_indexing; concat_indexing.indexing_maps.resize(concat->operand_count()); @@ -170,8 +183,9 @@ HloInstructionIndexing ComputeOutputToInputConcatenateOpIndexing( affine_map.setResult(concat_dim, concat_dim_expr - offset); int64_t operand_concat_dim = operand->shape().dimensions()[concat_dim]; dim_ranges[concat_dim] = Range{offset, offset + operand_concat_dim - 1}; - concat_indexing.indexing_maps[operand_id].insert(IndexingMap{ - affine_map.getAffineMap(), Domain(dim_ranges, /*symbol_ranges=*/{})}); + concat_indexing.indexing_maps[operand_id].insert( + IndexingMap(affine_map.getAffineMap(), dim_ranges, + /*symbol_ranges=*/{})); offset += operand_concat_dim; } return concat_indexing; @@ -192,8 +206,8 @@ HloInstructionIndexing ComputeInputToOutputConcatenateOpIndexing( AffineMap::getMultiDimIdentityMap(operand_dims.size(), mlir_context); affine_map.setResult(concat_dim, getAffineDimExpr(concat_dim, mlir_context) + offset); - IndexingMap indexing_map{.affine_map = affine_map.getAffineMap(), - .domain = Domain::FromUpperBounds(operand_dims, {})}; + IndexingMap indexing_map = + IndexingMap::FromTensorSizes(affine_map.getAffineMap(), operand_dims, {}); return HloInstructionIndexing::FromIndexingMaps({indexing_map}); } @@ -204,23 +218,14 @@ HloInstructionIndexing ComputeOutputToInputFusionOpIndexing( MLIRContext* mlir_context) { auto fusion_adaptor = HloFusionAdaptor::ForInstruction(fusion); auto grouped_indexing_maps = ComputeGroupedOutputToInputIndexing( - *fusion_adaptor, output_id, mlir_context); - - HloInstructionIndexing fusion_indexing; - - if (!grouped_indexing_maps.has_value()) { - fusion_indexing.indexing_maps = - std::vector>>( - fusion->operand_count(), {std::nullopt}); - return fusion_indexing; - } + *fusion_adaptor, fusion_adaptor->GetRoots()[output_id], mlir_context); // After the traversal, `grouped_indexing_maps` is keyed by // HloParameterInstructions. Convert them back to the operand id and return. + HloInstructionIndexing fusion_indexing; fusion_indexing.indexing_maps.resize(fusion->operand_count()); for (auto [operand_id, operand] : llvm::enumerate(fusion->operands())) { - fusion_indexing.indexing_maps[operand_id] = - grouped_indexing_maps.value()[operand]; + fusion_indexing.indexing_maps[operand_id] = grouped_indexing_maps[operand]; } return fusion_indexing; } @@ -290,21 +295,75 @@ HloInstructionIndexing ComputeOutputToInputDotOpIndexing( input_dim_sizes.push_back(lhs_shape.dimensions(lhs_contracting_dim)); } - IndexingMap lhs_indexing_map{ - .affine_map = AffineMap::get(dot->shape().rank(), input_dim_sizes.size(), - lhs_exprs, mlir_context), - .domain = - Domain::FromUpperBounds(dot->shape().dimensions(), input_dim_sizes)}; + IndexingMap lhs_indexing_map = IndexingMap::FromTensorSizes( + AffineMap::get(dot->shape().rank(), input_dim_sizes.size(), lhs_exprs, + mlir_context), + dot->shape().dimensions(), input_dim_sizes); - IndexingMap rhs_indexing_map{ - .affine_map = AffineMap::get(dot->shape().rank(), input_dim_sizes.size(), - rhs_exprs, mlir_context), - .domain = - Domain::FromUpperBounds(dot->shape().dimensions(), input_dim_sizes)}; + IndexingMap rhs_indexing_map = IndexingMap::FromTensorSizes( + AffineMap::get(dot->shape().rank(), input_dim_sizes.size(), rhs_exprs, + mlir_context), + dot->shape().dimensions(), input_dim_sizes); return HloInstructionIndexing::FromIndexingMaps( {lhs_indexing_map, rhs_indexing_map}); } +IndexingMap ComputeOutputToInputPadOpIndexingImpl( + absl::Span output_dims, + absl::Span padding_low, + absl::Span padding_high, + absl::Span padding_interior, MLIRContext* mlir_context) { + int64_t output_rank = output_dims.size(); + + std::vector exprs; + std::vector> constraints; + std::vector dimension_ranges; + exprs.reserve(output_rank); + constraints.reserve(output_rank); + int64_t output_dim_id = 0; + for (const auto [output_dim, pad_low, pad_high, pad_interior] : + llvm::zip(output_dims, padding_low, padding_high, padding_interior)) { + AffineExpr dim_expr = getAffineDimExpr(output_dim_id, mlir_context); + dimension_ranges.push_back( + Range{std::max(int64_t{0}, pad_low), + std::min(output_dim - 1, output_dim - 1 - pad_high)}); + if (pad_interior == 0) { + exprs.push_back(dim_expr - pad_low); + } else { + exprs.push_back((dim_expr - pad_low).floorDiv(pad_interior + 1)); + constraints.push_back( + {(dim_expr - pad_low) % (pad_interior + 1), Range{0, 0}}); + } + ++output_dim_id; + } + return IndexingMap{ + AffineMap::get(output_rank, /*symbolCount=*/0, exprs, mlir_context), + dimension_ranges, /*symbol_ranges = */ {}, absl::MakeSpan(constraints)}; +} + +HloInstructionIndexing ComputeOutputToInputPadOpIndexing( + const HloPadInstruction* pad, MLIRContext* mlir_context) { + const Shape& output_shape = pad->shape(); + int64_t rank = output_shape.rank(); + SmallVector padding_low, padding_high, padding_interior; + padding_low.reserve(rank); + padding_high.reserve(rank); + padding_interior.reserve(rank); + for (const auto& dim_config : pad->padding_config().dimensions()) { + padding_low.push_back(dim_config.edge_padding_low()); + padding_high.push_back(dim_config.edge_padding_high()); + padding_interior.push_back(dim_config.interior_padding()); + } + IndexingMap input_indexing_map = ComputeOutputToInputPadOpIndexingImpl( + output_shape.dimensions(), padding_low, padding_high, padding_interior, + mlir_context); + IndexingMap padding_value_indexing_map = IndexingMap::FromTensorSizes( + AffineMap::get(output_shape.rank(), /*symbolCount=*/0, {}, mlir_context), + output_shape.dimensions(), /*symbol_upper_bounds=*/{}); + return HloInstructionIndexing::FromIndexingMaps( + {input_indexing_map, padding_value_indexing_map}); +} + HloInstructionIndexing ComputeOutputToInputReduceOpIndexing( const HloReduceInstruction* reduce, int output_id, MLIRContext* mlir_context) { @@ -328,15 +387,13 @@ HloInstructionIndexing ComputeOutputToInputReduceOpIndexing( } exprs.push_back(getAffineDimExpr(output_dim_id++, mlir_context)); } - IndexingMap inputs_indexing_map{ - .affine_map = AffineMap::get(output_shape.rank(), reduce_dims_ids.size(), - exprs, mlir_context), - .domain = Domain::FromUpperBounds(output_shape.dimensions(), - parallel_dims_sizes)}; - IndexingMap inits_indexing_map{ - .affine_map = AffineMap::get(output_shape.rank(), /*symbolCount=*/0, {}, - mlir_context), - .domain = Domain::FromUpperBounds(output_shape.dimensions(), {})}; + IndexingMap inputs_indexing_map = IndexingMap::FromTensorSizes( + AffineMap::get(output_shape.rank(), reduce_dims_ids.size(), exprs, + mlir_context), + output_shape.dimensions(), parallel_dims_sizes); + IndexingMap inits_indexing_map = IndexingMap::FromTensorSizes( + AffineMap::get(output_shape.rank(), /*symbolCount=*/0, {}, mlir_context), + output_shape.dimensions(), {}); HloInstructionIndexing instr_indexing; instr_indexing.indexing_maps.resize(reduce->operand_count()); @@ -368,17 +425,15 @@ HloInstructionIndexing ComputeInputToOutputReduceOpIndexing( continue; } inputs_exprs.push_back(getAffineDimExpr(input_dim_id, mlir_context)); - inits_exprs.push_back( - mlir::getAffineSymbolExpr(output_dim_id++, mlir_context)); - } - IndexingMap inputs_indexing_map{ - .affine_map = AffineMap::get(input_shape.rank(), /*symbolCount=*/0, - inputs_exprs, mlir_context), - .domain = Domain::FromUpperBounds(input_shape.dimensions(), {})}; - IndexingMap inits_indexing_map{ - .affine_map = AffineMap::get(0, /*symbolCount=*/output_rank, inits_exprs, - mlir_context), - .domain = Domain::FromUpperBounds({}, output_shape.dimensions())}; + inits_exprs.push_back(getAffineSymbolExpr(output_dim_id++, mlir_context)); + } + IndexingMap inputs_indexing_map = IndexingMap::FromTensorSizes( + AffineMap::get(input_shape.rank(), /*symbolCount=*/0, inputs_exprs, + mlir_context), + input_shape.dimensions(), {}); + IndexingMap inits_indexing_map = IndexingMap::FromTensorSizes( + AffineMap::get(0, /*symbolCount=*/output_rank, inits_exprs, mlir_context), + {}, output_shape.dimensions()); HloInstructionIndexing instr_indexing; instr_indexing.indexing_maps.resize(reduce->operand_count()); @@ -391,6 +446,80 @@ HloInstructionIndexing ComputeInputToOutputReduceOpIndexing( return instr_indexing; } +// Indexing for reduce-window with dilations and non-trivial padding can be +// represented as a composition of pad op and reduce-window that never goes out +// of bounds. +HloInstructionIndexing ComputeOutputToInputReduceWindowOpIndexing( + const HloReduceWindowInstruction* reduce_window, int output_id, + MLIRContext* mlir_context) { + const Shape& input_shape = reduce_window->operand(0)->shape(); + const Shape& output_shape = GetOutputShape(reduce_window, 0); + int64_t rank = input_shape.rank(); + + // Compute shape of the padded input and the indexing map of pad op required + // to pad the input. + SmallVector padding_low, padding_high, padding_interior, + padded_input_dimensions; + padding_low.reserve(rank); + padding_high.reserve(rank); + padding_interior.reserve(rank); + padded_input_dimensions.reserve(rank); + SmallVector exprs; + std::vector dim_ranges, symbol_ranges; + exprs.reserve(rank); + dim_ranges.reserve(rank); + symbol_ranges.reserve(rank); + for (const auto& [dim_id, window_config] : + llvm::enumerate(reduce_window->window().dimensions())) { + padding_low.push_back(window_config.padding_low()); + padding_high.push_back(window_config.padding_high()); + // For some reason interior_padding in HLO pad is offset from base_dilations + // in HLO reduce-window by 1. + padding_interior.push_back(window_config.base_dilation() - 1); + padded_input_dimensions.push_back(input_shape.dimensions(dim_id) + + window_config.padding_low() + + window_config.padding_high() + + (input_shape.dimensions(dim_id) - 1) * + (window_config.base_dilation() - 1)); + AffineExpr dim_expr = getAffineDimExpr(dim_id, mlir_context); + AffineExpr symbol_expr = getAffineSymbolExpr(dim_id, mlir_context); + + exprs.push_back(symbol_expr + window_config.stride() * dim_expr); + dim_ranges.push_back(Range{0, output_shape.dimensions(dim_id) - 1}); + symbol_ranges.push_back(Range{0, window_config.size() - 1}); + } + // Indexing map for pad op that pads the input. + IndexingMap padded_input_indexing = ComputeOutputToInputPadOpIndexingImpl( + padded_input_dimensions, padding_low, padding_high, padding_interior, + mlir_context); + // Indexing map for reduce-window, that does not do any padding. + IndexingMap reduce_window_indexing_no_padding( + AffineMap::get(rank, rank, exprs, mlir_context), dim_ranges, + symbol_ranges); + + // Composed indexing. + IndexingMap inputs_indexing = ComposeIndexingMaps( + reduce_window_indexing_no_padding, padded_input_indexing); + inputs_indexing.Simplify(); + inputs_indexing.RemoveUnusedSymbols(); + + // Indexing map for the init value. + IndexingMap inits_indexing_map = IndexingMap::FromTensorSizes( + AffineMap::get(output_shape.rank(), /*symbolCount=*/0, {}, mlir_context), + output_shape.dimensions(), /*symbol_upper_bounds=*/{}); + + HloInstructionIndexing instr_indexing; + instr_indexing.indexing_maps.resize(reduce_window->operand_count()); + for (int64_t id = 0; id < reduce_window->input_count(); ++id) { + instr_indexing.indexing_maps[id].insert(inputs_indexing); + } + for (int64_t id = reduce_window->input_count(); + id < reduce_window->operand_count(); ++id) { + instr_indexing.indexing_maps[id].insert(inits_indexing_map); + } + return instr_indexing; +} + // Computes strides for a shape. std::vector ComputeStrides(absl::Span dims) { int rank = static_cast(dims.size()); @@ -542,10 +671,9 @@ HloInstructionIndexing ComputeOutputToInputReshapeOpIndexing( auto input_dims = reshape->operand(0)->shape().dimensions(); auto output_dims = reshape->shape().dimensions(); - IndexingMap reshape_indexing_map{ - .affine_map = - ComputeReshapeIndexingMap(input_dims, output_dims, mlir_context), - .domain = Domain::FromUpperBounds(output_dims, {})}; + IndexingMap reshape_indexing_map = IndexingMap::FromTensorSizes( + ComputeReshapeIndexingMap(input_dims, output_dims, mlir_context), + output_dims, {}); reshape_indexing_map.Simplify(); return HloInstructionIndexing::FromIndexingMaps({reshape_indexing_map}); } @@ -554,10 +682,9 @@ HloInstructionIndexing ComputeInputToOutputReshapeOpIndexing( auto input_dims = reshape->operand(0)->shape().dimensions(); auto output_dims = reshape->shape().dimensions(); - IndexingMap reshape_indexing_map{ - .affine_map = - ComputeReshapeIndexingMap(output_dims, input_dims, mlir_context), - .domain = Domain::FromUpperBounds(input_dims, {})}; + IndexingMap reshape_indexing_map = IndexingMap::FromTensorSizes( + ComputeReshapeIndexingMap(output_dims, input_dims, mlir_context), + input_dims, {}); reshape_indexing_map.Simplify(); return HloInstructionIndexing::FromIndexingMaps({reshape_indexing_map}); } @@ -579,10 +706,10 @@ HloInstructionIndexing ComputeReverseOpIndexing( exprs.push_back(-dim_expr + output_dim - 1); } - IndexingMap indexing_map{ - .affine_map = AffineMap::get(output_dims.size(), /*symbolCount=*/0, exprs, - mlir_context), - .domain = Domain::FromUpperBounds(output_dims, {})}; + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + AffineMap::get(output_dims.size(), /*symbolCount=*/0, exprs, + mlir_context), + output_dims, {}); return HloInstructionIndexing::FromIndexingMaps({indexing_map}); } @@ -598,10 +725,9 @@ HloInstructionIndexing ComputeOutputToInputSliceOpIndexing( exprs.push_back(dim_expr * slice->slice_strides()[dim] + slice->slice_starts()[dim]); } - IndexingMap indexing_map{ - .affine_map = - AffineMap::get(output_rank, /*symbolCount=*/0, exprs, mlir_context), - .domain = Domain::FromUpperBounds(slice->shape().dimensions(), {})}; + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + AffineMap::get(output_rank, /*symbolCount=*/0, exprs, mlir_context), + slice->shape().dimensions(), {}); return HloInstructionIndexing::FromIndexingMaps({indexing_map}); } @@ -616,24 +742,22 @@ HloInstructionIndexing ComputeOutputToInputTransposeOpIndexing( const HloTransposeInstruction* transpose, MLIRContext* mlir_context) { AffineMap inverse_permutation = ComputeTransposeIndexingMap( InversePermutation(transpose->dimensions()), mlir_context); - return HloInstructionIndexing::FromIndexingMaps({IndexingMap{ - .affine_map = inverse_permutation, - .domain = Domain::FromUpperBounds(transpose->shape().dimensions(), {})}}); + return HloInstructionIndexing::FromIndexingMaps({IndexingMap::FromTensorSizes( + inverse_permutation, transpose->shape().dimensions(), {})}); } HloInstructionIndexing ComputeInputToOutputTransposeOpIndexing( const HloTransposeInstruction* transpose, MLIRContext* mlir_context) { AffineMap forward_permutation = ComputeTransposeIndexingMap(transpose->dimensions(), mlir_context); - return HloInstructionIndexing::FromIndexingMaps( - {IndexingMap{.affine_map = forward_permutation, - .domain = Domain::FromUpperBounds( - transpose->operand(0)->shape().dimensions(), {})}}); + return HloInstructionIndexing::FromIndexingMaps({IndexingMap::FromTensorSizes( + forward_permutation, transpose->operand(0)->shape().dimensions(), {})}); } -std::optional ComputeOutputToInputBitcastOpIndexingImpl( - const Shape& input_shape, const Shape& output_shape, - MLIRContext* mlir_context) { +} // namespace + +IndexingMap GetBitcastMap(const Shape& input_shape, const Shape& output_shape, + MLIRContext* ctx) { ShapeUtil::BitcastDecomposition decomposed_bitcast = ShapeUtil::DecomposeBitcast(input_shape, output_shape); @@ -644,57 +768,49 @@ std::optional ComputeOutputToInputBitcastOpIndexingImpl( CHECK(permutation.has_value()) << "Failed to deduce permutation for a bitcast."; - return ComputeTransposeIndexingMap(InversePermutation(permutation.value()), - mlir_context); + return IndexingMap::FromTensorSizes( + ComputeTransposeIndexingMap(permutation.value(), ctx), + input_shape.dimensions(), {}); } if (std::holds_alternative( decomposed_bitcast)) { - return ComputeReshapeIndexingMap(input_shape.dimensions(), - output_shape.dimensions(), mlir_context); + // Note: ComputeReshapeIndexingMap assumes it's computing an output->input + // indexing, so input and output are reversed. + return IndexingMap::FromTensorSizes( + ComputeReshapeIndexingMap(output_shape.dimensions(), + input_shape.dimensions(), ctx), + input_shape.dimensions(), {}); } // `trt` stands for transpose-reshape-transpose decomposition of bitcast. auto trt = std::get(decomposed_bitcast); - AffineMap transpose_map_1 = ComputeTransposeIndexingMap( - InversePermutation(trt.transpose1_dims), mlir_context); - AffineMap reshape_map = - ComputeReshapeIndexingMap(trt.transpose1_shape.dimensions(), - trt.reshape_shape.dimensions(), mlir_context); - AffineMap transpose_map_2 = ComputeTransposeIndexingMap( - InversePermutation(trt.transpose2_dims), mlir_context); - return transpose_map_1.compose(reshape_map).compose(transpose_map_2); + auto transpose_map_1 = ComputeTransposeIndexingMap(trt.transpose1_dims, ctx); + auto reshape_map = ComputeReshapeIndexingMap( + trt.reshape_shape.dimensions(), trt.transpose1_shape.dimensions(), ctx); + auto transpose_map_2 = ComputeTransposeIndexingMap(trt.transpose2_dims, ctx); + auto bitcast_map = + transpose_map_2.compose(reshape_map).compose(transpose_map_1); + return IndexingMap::FromTensorSizes(bitcast_map, input_shape.dimensions(), + {}); } +namespace { + HloInstructionIndexing ComputeOutputToInputBitcastOpIndexing( const HloInstruction* bitcast, MLIRContext* mlir_context) { - const Shape& input_shape = bitcast->operand(0)->shape(); - const Shape& output_shape = bitcast->shape(); - auto bitcast_affine_map = ComputeOutputToInputBitcastOpIndexingImpl( - input_shape, output_shape, mlir_context); - if (!bitcast_affine_map.has_value()) return CreateUnknownIndexing(); - - IndexingMap bitcast_indexing_map{ - .affine_map = bitcast_affine_map.value(), - .domain = Domain::FromUpperBounds(output_shape.dimensions(), {})}; - bitcast_indexing_map.Simplify(); + auto bitcast_map = GetBitcastMap(bitcast->shape(), + bitcast->operand(0)->shape(), mlir_context); + bitcast_map.Simplify(); - return HloInstructionIndexing::FromIndexingMaps({bitcast_indexing_map}); + return HloInstructionIndexing::FromIndexingMaps({bitcast_map}); } HloInstructionIndexing ComputeInputToOutputBitcastOpIndexing( const HloInstruction* bitcast, MLIRContext* mlir_context) { - const Shape& input_shape = bitcast->operand(0)->shape(); - const Shape& output_shape = bitcast->shape(); - - auto bitcast_affine_map = ComputeOutputToInputBitcastOpIndexingImpl( - output_shape, input_shape, mlir_context); - if (!bitcast_affine_map.has_value()) return CreateUnknownIndexing(); + auto bitcast_map = GetBitcastMap(bitcast->operand(0)->shape(), + bitcast->shape(), mlir_context); + bitcast_map.Simplify(); - IndexingMap bitcast_indexing_map{ - .affine_map = bitcast_affine_map.value(), - .domain = Domain::FromUpperBounds(input_shape.dimensions(), {})}; - bitcast_indexing_map.Simplify(); - - return HloInstructionIndexing::FromIndexingMaps({bitcast_indexing_map}); + return HloInstructionIndexing::FromIndexingMaps({bitcast_map}); } // Converts a layout to a dimensions transposition necessary to get to that @@ -706,46 +822,122 @@ std::vector ToTransposeDimensions(const Layout& l) { return out; } +AffineMap GetTilingAffineMap(llvm::ArrayRef exprs, + const Tiling& tiling) { + return AffineMap::get( + /*dimCount=*/6, /*symbolCount=*/tiling.GetShape().size(), exprs, + exprs[0].getContext()); +} + } // namespace -// Creates an indexing map from the physical layout of the tensor to its logical -// layout. If it is an identity, return std::nullopt. -std::optional GetIndexingMapFromPhysicalLayoutToLogical( - const Shape& shape, MLIRContext* ctx) { - if (shape.rank() == 0 || - LayoutUtil::IsMonotonicWithDim0Major(shape.layout())) { - return std::nullopt; +llvm::SmallVector DelinearizeInBoundsIndex( + AffineExpr linear, absl::Span sizes, + absl::Span strides) { + llvm::SmallVector result; + result.reserve(sizes.size()); + for (auto [size, stride] : llvm::zip(sizes, strides)) { + result.push_back(linear.floorDiv(stride) % size); } - return IndexingMap{ + for (int dim = 0; dim < sizes.size(); ++dim) { + if (sizes[dim] > 1) { + // We assume the linear index is in bounds, so no mod for the first major + // non-degenerate dimension. Degenerate dimensions are already rewritten + // to 0 by operator%. + result[dim] = linear.floorDiv(strides[dim]); + break; + } + } + return result; +} + +IndexingMap GetIndexingMapFromPhysicalLayoutToLogical(const Shape& shape, + MLIRContext* ctx) { + if (shape.rank() == 0) { + return IndexingMap(AffineMap::get(ctx), {}, {}); + } + return IndexingMap::FromTensorSizes( ComputeTransposeIndexingMap( InversePermutation(ToTransposeDimensions(shape.layout())), ctx), - Domain::FromUpperBounds( - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(shape) - .dimensions(), - {})}; + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(shape) + .dimensions(), + {}); } -// Creates an indexing map from the logical layout of the tensor to its physical -// layout. If it is an identity, return std::nullopt. -std::optional GetIndexingMapFromLogicalToPhysicalLayout( - const Shape& shape, MLIRContext* ctx) { - if (shape.rank() == 0 || - LayoutUtil::IsMonotonicWithDim0Major(shape.layout())) { - return std::nullopt; +IndexingMap GetIndexingMapFromLogicalToPhysicalLayout(const Shape& shape, + MLIRContext* ctx) { + if (shape.rank() == 0) { + return IndexingMap(AffineMap::get(ctx), {}, {}); } - return IndexingMap{ + return IndexingMap::FromTensorSizes( ComputeTransposeIndexingMap(ToTransposeDimensions(shape.layout()), ctx), - Domain::FromUpperBounds(shape.dimensions(), {})}; + shape.dimensions(), {}); +} + +AffineMap GetBlockOffsetsForTiling(const Tiling& tiling, + mlir::MLIRContext* ctx) { + auto offsets = DelinearizeInBoundsIndex(getAffineDimExpr(3, ctx), + tiling.GetBlockCounts(), + tiling.GetBlockStrides()); + for (auto&& [offset, tile_size] : + llvm::zip(offsets, tiling.GetBlockTileSize())) { + offset = offset * tile_size; + } + return GetTilingAffineMap(offsets, tiling); +} + +AffineMap GetThreadOffsetsForTiling(const Tiling& tiling, + mlir::MLIRContext* ctx) { + auto offsets = DelinearizeInBoundsIndex(getAffineDimExpr(0, ctx), + tiling.GetThreadsPerBlock(), + tiling.GetThreadStrides()); + for (int dim = 0; dim < tiling.GetShape().size(); ++dim) { + if (tiling.GetThreadTileSize()[dim] > 1) { + offsets[dim] = offsets[dim] + getAffineSymbolExpr(dim, ctx) * + tiling.GetThreadsPerBlock()[dim]; + } + } + return GetTilingAffineMap(offsets, tiling); +} + +IndexingMap GetIndexingMapForTiling(const Tiling& tiling, + mlir::MLIRContext* ctx) { + return GetIndexingMapForTiling(GetBlockOffsetsForTiling(tiling, ctx), + GetThreadOffsetsForTiling(tiling, ctx), + tiling); +} + +IndexingMap GetIndexingMapForTiling(AffineMap block_offsets, + AffineMap thread_offsets, + const Tiling& tiling) { + llvm::SmallVector offsets; + offsets.reserve(block_offsets.getNumResults()); + for (auto [block, thread] : + llvm::zip(block_offsets.getResults(), thread_offsets.getResults())) { + offsets.push_back(block + thread); + } + + // TODO(jreiffers): Use general constraints for symbols: in the last blocks + // in each each dimension, the bounds can be different if we don't have a + // perfect tiling. + std::vector dimension_ranges{ + {0, tiling.GetNumThreadsPerBlock() - 1}, {}, {}, + {0, tiling.GetNumBlocks() - 1}, {}, {}, + }; + return {GetTilingAffineMap(offsets, tiling), dimension_ranges, + RangesFromUpperBounds(tiling.GetThreadTileSize())}; } bool HloInstructionIndexing::Simplify() { bool any_simplified = false; for (auto& operand_indexing : indexing_maps) { - std::vector> to_remove, to_add; - for (std::optional map : operand_indexing) { + std::vector to_remove, to_add; + for (IndexingMap map : operand_indexing) { to_remove.push_back(map); - if (!map.has_value() || map->Simplify()) { + if (map.IsUndefined()) { to_add.push_back(map); + } else if (map.Simplify()) { + map.RemoveUnusedSymbols(); } else { to_remove.pop_back(); } @@ -785,11 +977,11 @@ void HloInstructionIndexing::Print(std::ostream& out, llvm::enumerate(indexing_maps)) { out << "operand id = " << operand_id << ' '; for (const auto& indexing_map : indexing_maps) { - if (!indexing_map.has_value()) { + if (indexing_map.IsUndefined()) { out << "unknown indexing"; continue; } - indexing_map->Print(out, printer); + indexing_map.Print(out, printer); } } } @@ -818,31 +1010,51 @@ GroupedByOpIndexingMap GroupIndexingMapsByProducers( return result; } -std::optional ComputeGroupedOutputToInputIndexing( - const HloFusionAdaptor& fusion_adaptor, int output_id, MLIRContext* ctx) { - auto root = fusion_adaptor.GetRoots()[output_id]; - - auto initial_map = CreateIdentityMap(root.instruction().shape(), ctx); +GroupedByOpIndexingMap ComputeGroupedOutputToInputIndexing( + const HloFusionAdaptor& fusion_adaptor, HloInstructionAdaptor target_instr, + MLIRContext* ctx) { + auto initial_map = CreateIdentityMap(target_instr.instruction().shape(), ctx); GroupedByOpIndexingMap grouped_indexing_maps; - grouped_indexing_maps[&root.instruction()].insert(initial_map); + // If target_instr is a parameter of a fusion, then we create an identity map + // for the fusion operand. + if (fusion_adaptor.ContainsInstruction(target_instr)) { + if (auto parameter_instr = + DynCast(&target_instr.instruction())) { + const HloInstruction* user = parameter_instr->users().front(); + auto fusion_operand = HloInstructionAdaptor(*user).GetOperand( + parameter_instr->parameter_number()); + grouped_indexing_maps[&fusion_operand.instruction()] = {initial_map}; + return grouped_indexing_maps; + } + } + grouped_indexing_maps[&target_instr.instruction()].insert(initial_map); auto post_order = fusion_adaptor.MakeInstructionPostOrder(); // Iterator in reversed post-order (use-before-def). - for (auto it = post_order.rbegin(); it != post_order.rend(); ++it) { + auto it = std::find(post_order.rbegin(), post_order.rend(), target_instr); + for (; it != post_order.rend(); ++it) { auto producer_indexing = ComputeOutputToInputIndexing(&it->instruction(), /*output_id=*/0, ctx); - auto consumer_indexing_maps = grouped_indexing_maps[&it->instruction()]; + auto consumer_indexing_maps = + grouped_indexing_maps.find(&it->instruction()); + if (consumer_indexing_maps == grouped_indexing_maps.end()) { + continue; + } + // Indexing maps have to be copied because of rehashing. Consider using a + // different container to get better performance. + IndexingMapSet consumer_indexing_maps_copy = consumer_indexing_maps->second; for (const auto& [producer_operand_id, producer_operand_indexing] : llvm::enumerate(producer_indexing.indexing_maps)) { auto producer_operand_adaptor = it->GetOperand(producer_operand_id); - for (const std::optional& producer_map : - producer_operand_indexing) { - for (const std::optional& consumer_map : - consumer_indexing_maps) { + for (const IndexingMap& producer_map : producer_operand_indexing) { + for (const IndexingMap& consumer_map : consumer_indexing_maps_copy) { + auto composed_map = ComposeIndexingMaps(consumer_map, producer_map); + composed_map.Simplify(); + composed_map.RemoveUnusedSymbols(); grouped_indexing_maps[&producer_operand_adaptor.instruction()].insert( - ComposeIndexingMaps(producer_map, consumer_map)); + composed_map); } } } @@ -850,6 +1062,29 @@ std::optional ComputeGroupedOutputToInputIndexing( return grouped_indexing_maps; } +bool FuseProducerConsumerOutputToInputIndexing( + const HloInstruction* producer_instr, + absl::flat_hash_map* + consumer_indexing, + MLIRContext* mlir_context) { + auto producer_indexing = ComputeOutputToInputIndexing( + producer_instr, /*output_id=*/0, mlir_context); + auto consumer_indexing_maps = (*consumer_indexing)[producer_instr]; + for (const auto& [producer_operand_id, producer_operand_indexing] : + llvm::enumerate(producer_indexing.indexing_maps)) { + const HloInstruction* producer_operand_instr = + producer_instr->operand(producer_operand_id); + for (const IndexingMap& producer_map : producer_operand_indexing) { + for (const IndexingMap& consumer_map : consumer_indexing_maps) { + (*consumer_indexing)[producer_operand_instr].insert( + ComposeIndexingMaps(producer_map, consumer_map)); + } + } + } + consumer_indexing->erase(producer_instr); + return true; +} + HloInstructionIndexing ComputeOutputToInputIndexing(const HloInstruction* instr, int output_id, MLIRContext* ctx) { @@ -877,9 +1112,16 @@ HloInstructionIndexing ComputeOutputToInputIndexing(const HloInstruction* instr, if (auto iota = DynCast(instr)) { return HloInstructionIndexing{}; } + if (auto pad = DynCast(instr)) { + return ComputeOutputToInputPadOpIndexing(pad, ctx); + } if (auto reduce = DynCast(instr)) { return ComputeOutputToInputReduceOpIndexing(reduce, output_id, ctx); } + if (auto reduce_window = DynCast(instr)) { + return ComputeOutputToInputReduceWindowOpIndexing(reduce_window, output_id, + ctx); + } if (auto reshape = DynCast(instr)) { return ComputeOutputToInputReshapeOpIndexing(reshape, ctx); } diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis.h b/third_party/xla/xla/service/gpu/model/indexing_analysis.h index c3c9b805142249..6df85b418ec7dd 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_analysis.h +++ b/third_party/xla/xla/service/gpu/model/indexing_analysis.h @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include #include #include @@ -27,9 +26,11 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/types/span.h" #include "llvm/ADT/Hashing.h" +#include "mlir/IR/AffineExpr.h" // from @llvm-project #include "mlir/IR/AffineMap.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/fusions/tiling_util.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/model/affine_map_printer.h" #include "xla/service/gpu/model/indexing_map.h" @@ -38,6 +39,8 @@ limitations under the License. namespace xla { namespace gpu { +using IndexingMapSet = absl::flat_hash_set; + // Contains indexing maps for all N-dimensional tensor input operands that // correspond to a particular output. struct HloInstructionIndexing { @@ -55,7 +58,7 @@ struct HloInstructionIndexing { absl::Span indexing_maps); // Maps input operand index to the indexing map for one particular output. - std::vector>> indexing_maps; + std::vector indexing_maps; }; std::ostream& operator<<(std::ostream& out, const HloInstructionIndexing& instr_indexing); @@ -74,13 +77,13 @@ HloInstructionIndexing ComputeInputToOutputIndexing(const HloInstruction* instr, int input_id, mlir::MLIRContext* ctx); -using IndexingMapSet = absl::flat_hash_set>; using GroupedByOpIndexingMap = absl::flat_hash_map; -// Computes indexing for every instruction within a fusion cluster. -std::optional ComputeGroupedOutputToInputIndexing( - const HloFusionAdaptor& fusion_adaptor, int output_id, +// Computes output-to-input indexing for every instruction within a fusion +// cluster starting with `target_instr` and going from def to use. +GroupedByOpIndexingMap ComputeGroupedOutputToInputIndexing( + const HloFusionAdaptor& fusion_adaptor, HloInstructionAdaptor target_instr, mlir::MLIRContext* ctx); // Groups indexing maps by instructions. @@ -88,19 +91,55 @@ absl::flat_hash_map GroupIndexingMapsByProducers(const HloInstructionIndexing& indexing, const HloInstruction* instr); +// Computes producer indexing maps and fuse/compose them with the consumer +// indexing maps. +bool FuseProducerConsumerOutputToInputIndexing( + const HloInstruction* producer_instr, + absl::flat_hash_map* + consumer_indexing, + mlir::MLIRContext* mlir_context); + +// Creates an indexing map for bitcasting from `input_shape` to `output_shape`. +// Equivalent to linearizing the input_shape index and then delinearizing it +// to output_shape. +IndexingMap GetBitcastMap(const Shape& input_shape, const Shape& output_shape, + mlir::MLIRContext* ctx); + // Creates an indexing map from the physical layout of the tensor to its logical -// layout. If it is an identity, return std::nullopt. -std::optional GetIndexingMapFromPhysicalLayoutToLogical( - const Shape& shape, mlir::MLIRContext* ctx); +// layout. +IndexingMap GetIndexingMapFromPhysicalLayoutToLogical(const Shape& shape, + mlir::MLIRContext* ctx); // Creates an indexing map from the logical layout of the tensor to its physical -// layout. If it is an identity, return std::nullopt. -std::optional GetIndexingMapFromLogicalToPhysicalLayout( - const Shape& shape, mlir::MLIRContext* ctx); +// layout. +IndexingMap GetIndexingMapFromLogicalToPhysicalLayout(const Shape& shape, + mlir::MLIRContext* ctx); + +// Creates an indexing map from thread and block IDs to elements of the tiled +// shape. Uses the same convention as KernelFusionInterface: dimensions 0 to 2 +// are thread indices (currently only 0 is used), dimensions 3 to 5 are block +// indices (currently only 3 is used). +mlir::AffineMap GetBlockOffsetsForTiling(const Tiling& tiling, + mlir::MLIRContext* ctx); +mlir::AffineMap GetThreadOffsetsForTiling(const Tiling& tiling, + mlir::MLIRContext* ctx); + +// Convenience functions for the two functions above +// (`GetBlockOffsestsForTiling` + `GetThreadOffsetsForTiling`). Also sets up +// the ranges of dimensions and symbols. +IndexingMap GetIndexingMapForTiling(const Tiling& tiling, + mlir::MLIRContext* ctx); +IndexingMap GetIndexingMapForTiling(mlir::AffineMap block_offsets, + mlir::AffineMap thread_offsets, + const Tiling& tiling); // Returns the shape of the output of the instruction. const Shape& GetOutputShape(const HloInstruction* instr, int64_t output_id); +llvm::SmallVector DelinearizeInBoundsIndex( + mlir::AffineExpr linear, absl::Span sizes, + absl::Span strides); + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc b/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc index 0b7746c450b613..f21def004466fe 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc @@ -15,8 +15,6 @@ limitations under the License. #include "xla/service/gpu/model/indexing_analysis.h" -#include - #include #include #include "absl/strings/string_view.h" @@ -35,7 +33,6 @@ using ::testing::ElementsAre; using ::testing::Eq; using ::testing::ExplainMatchResult; using ::testing::IsEmpty; -using ::testing::Optional; using ::testing::Pair; using ::testing::UnorderedElementsAre; @@ -87,13 +84,13 @@ TEST_F(IndexingAnalysisTest, FuseProducerConsumerOutputToInputIndexing) { EXPECT_THAT( grouped_by_key, - UnorderedElementsAre(Pair(parameter, ElementsAre(MatchIndexingString(R"( + UnorderedElementsAre(Pair(parameter, ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: d0 in [0, 999] d1 in [0, 999] )"))), - Pair(transpose, ElementsAre(MatchIndexingString(R"( + Pair(transpose, ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: d0 in [0, 999] @@ -119,33 +116,180 @@ TEST_F(IndexingAnalysisTest, ComputeGroupedOutputToInputIndexing) { auto fusion_adaptor = ProducerConsumerFusion(transpose, root); auto grouped_indexing = ComputeGroupedOutputToInputIndexing( - fusion_adaptor, /*output_id=*/0, &mlir_context_); + fusion_adaptor, fusion_adaptor.GetRoots()[0], &mlir_context_); EXPECT_THAT(grouped_indexing, - Optional(UnorderedElementsAre( - Pair(root, ElementsAre(MatchIndexingString(R"( + UnorderedElementsAre( + Pair(root, ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: d0 in [0, 999] d1 in [0, 999] )"))), - Pair(transpose, ElementsAre(MatchIndexingString(R"( + Pair(transpose, ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: d0 in [0, 999] d1 in [0, 999] )"))), - Pair(parameter, UnorderedElementsAre(MatchIndexingString(R"( + Pair(parameter, UnorderedElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: d0 in [0, 999] d1 in [0, 999] )"), - MatchIndexingString(R"( + MatchIndexingMap(R"( (d0, d1) -> (d1, d0) domain: d0 in [0, 999] d1 in [0, 999] - )")))))); + )"))))); +} + +TEST_F(IndexingAnalysisTest, + ComputeGroupedOutputToInputIndexing_VariadicReduce) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule m + +add { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + param_2 = f32[] parameter(2) + param_3 = f32[] parameter(3) + add.0 = f32[] add(param_0, param_2) + add.1 = f32[] add(param_1, param_3) + ROOT t = (f32[], f32[]) tuple(add.0, add.1) +} + +ENTRY entry_computation { + param_0.3 = f32[32,40]{1,0} parameter(0) + param_1.3 = f32[32,40]{1,0} parameter(1) + param_2.2 = f32[] parameter(2) + constant = f32[] constant(0) + ROOT reduce = (f32[32]{0}, f32[32]{0}) reduce(param_0.3, param_1.3, param_2.2, constant), dimensions={1}, to_apply=add +} + )"); + EXPECT_TRUE(module.ok()); + const HloInstruction* root = + (*module)->entry_computation()->root_instruction(); + + auto fusion_adaptor = HloFusionAdaptor::ForInstruction(root); + + auto grouped_indexing = ComputeGroupedOutputToInputIndexing( + *fusion_adaptor, fusion_adaptor->GetRoots()[0], &mlir_context_); + + EXPECT_THAT(grouped_indexing, + UnorderedElementsAre( + Pair(root, ElementsAre(MatchIndexingMap(R"( + (d0) -> (d0) + domain: + d0 in [0, 31] + )"))), + Pair(root->operand(0), ElementsAre(MatchIndexingMap(R"( + (d0)[s0] -> (d0, s0) + domain: + d0 in [0, 31] + s0 in [0, 39] + )"))), + Pair(root->operand(1), ElementsAre(MatchIndexingMap(R"( + (d0)[s0] -> (d0, s0) + domain: + d0 in [0, 31] + s0 in [0, 39] + )"))), + Pair(root->operand(2), ElementsAre(MatchIndexingMap(R"( + (d0) -> () + domain: + d0 in [0, 31] + )"))), + Pair(root->operand(3), ElementsAre(MatchIndexingMap(R"( + (d0) -> () + domain: + d0 in [0, 31] + )"))))); +} + +TEST_F(IndexingAnalysisTest, ComputeGroupedOutputToInputIndexing_SingleOp) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule m + ENTRY e { + p0 = f32[1000, 1000] parameter(0) + p1 = f32[1000, 1000] parameter(1) + exp0 = f32[1000, 1000] exponential(p1) + ROOT a0 = f32[1000, 1000] add(p0, exp0) + } + )"); + EXPECT_TRUE(module.ok()); + HloComputation* entry_computation = (*module)->entry_computation(); + const HloInstruction* exponential = + entry_computation->GetInstructionWithName("exp0"); + const HloInstruction* parameter = + entry_computation->GetInstructionWithName("p1"); + + auto fusion_adaptor = HloFusionAdaptor::ForInstruction(exponential); + HloInstructionAdaptor parameter_adaptor(*parameter); + auto grouped_indexing = ComputeGroupedOutputToInputIndexing( + *fusion_adaptor, parameter_adaptor, &mlir_context_); + EXPECT_THAT(grouped_indexing, UnorderedElementsAre(Pair( + parameter, ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d0, d1) + domain: + d0 in [0, 999] + d1 in [0, 999] + )"))))); +} + +TEST_F(IndexingAnalysisTest, + ComputeGroupedOutputToInputIndexing_StartNotAtRoot) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule m + max { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT max = f32[] maximum(p0, p1) + } + f { + p0 = f32[15, 20] parameter(0) + p0_init = f32[] parameter(1) + p0_bcast = f32[15, 32, 20, 64] broadcast(p0), dimensions={0, 2} + + ROOT reduce_2 = f32[15, 64] reduce(p0_bcast, p0_init), + dimensions={1, 2}, to_apply=max + } + ENTRY e { + p0 = f32[15, 20] parameter(0) + p0_init = f32[] constant(-inf) + ROOT fusion = f32[15, 64] fusion(p0, p0_init), kind=kLoop, calls=f + } + )"); + EXPECT_TRUE(module.ok()); + + auto fusion_adaptor = HloFusionAdaptor::ForInstruction( + (*module)->entry_computation()->root_instruction()); + auto root = fusion_adaptor->GetRoots()[0]; + auto bcast = root.GetOperand(0); + auto parameter_0 = bcast.GetOperand(0); + + auto grouped_indexing = ComputeGroupedOutputToInputIndexing( + *fusion_adaptor, bcast, &mlir_context_); + EXPECT_THAT( + grouped_indexing, + UnorderedElementsAre( + Pair(&bcast.instruction(), ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3) -> (d0, d1, d2, d3) + domain: + d0 in [0, 14] + d1 in [0, 31] + d2 in [0, 19] + d3 in [0, 63] + )"))), + Pair(¶meter_0.instruction(), ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3) -> (d0, d2) + domain: + d0 in [0, 14] + d1 in [0, 31] + d2 in [0, 19] + d3 in [0, 63] + )"))))); } TEST_F(IndexingAnalysisTest, PhysicalLayoutTestOutputPermutation) { @@ -159,7 +303,7 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestOutputPermutation) { auto input_indexing = GetOutputToInputIndexingForEntryComputation( ir, /*output_id=*/0, /*use_physical_layout=*/true); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d1, d2, d0) domain: d0 in [0, 29] @@ -170,7 +314,7 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestOutputPermutation) { auto output_indexing = GetInputToOutputIndexingForEntryComputation( ir, /*input_id=*/0, /*use_physical_layout=*/true); EXPECT_THAT(output_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d2, d0, d1) domain: d0 in [0, 9] @@ -179,6 +323,37 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestOutputPermutation) { )")))); } +TEST_F(IndexingAnalysisTest, CopyNothing) { + auto ir = R"( + HloModule m + ENTRY e { + p0 = f32[0, 0]{0,1} parameter(0) + ROOT copy0 = f32[0, 0]{1,0} copy(p0) + } + )"; + auto input_indexing = + GetOutputToInputIndexingForEntryComputation(ir, /*output_id=*/0); + input_indexing.Simplify(); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d0, d1) + domain: + d0 in [0, -1] + d1 in [0, -1] + )")))); + + auto output_indexing = + GetInputToOutputIndexingForEntryComputation(ir, /*input_id=*/0); + output_indexing.Simplify(); + EXPECT_THAT(output_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d0, d1) + domain: + d0 in [0, -1] + d1 in [0, -1] + )")))); +} + TEST_F(IndexingAnalysisTest, PhysicalLayoutTestInputPermutation) { auto ir = R"( HloModule m @@ -190,7 +365,7 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestInputPermutation) { auto input_indexing = GetOutputToInputIndexingForEntryComputation( ir, /*output_id=*/0, /*use_physical_layout=*/true); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d2, d0, d1) domain: d0 in [0, 9] @@ -201,7 +376,7 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestInputPermutation) { auto output_indexing = GetInputToOutputIndexingForEntryComputation( ir, /*input_id=*/0, /*use_physical_layout=*/true); EXPECT_THAT(output_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d1, d2, d0) domain: d0 in [0, 29] @@ -221,7 +396,7 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestInputAndOutputPermutation) { auto input_indexing = GetOutputToInputIndexingForEntryComputation( ir, /*output_id=*/0, /*use_physical_layout=*/true); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1, d2) domain: d0 in [0, 29] @@ -232,7 +407,7 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestInputAndOutputPermutation) { auto output_indexing = GetInputToOutputIndexingForEntryComputation( ir, /*input_id=*/0, /*use_physical_layout=*/true); EXPECT_THAT(output_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1, d2) domain: d0 in [0, 29] @@ -252,13 +427,13 @@ TEST_F(IndexingAnalysisTest, ElementwiseOp) { )"; auto input_indexing = GetOutputToInputIndexingForEntryComputation(ir); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: d0 in [0, 9] d1 in [0, 19] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: d0 in [0, 9] @@ -268,7 +443,7 @@ TEST_F(IndexingAnalysisTest, ElementwiseOp) { auto output_indexing_0 = GetInputToOutputIndexingForEntryComputation(ir, /*input_id=*/0); EXPECT_THAT(output_indexing_0.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: d0 in [0, 9] @@ -278,7 +453,7 @@ TEST_F(IndexingAnalysisTest, ElementwiseOp) { auto output_indexing_1 = GetInputToOutputIndexingForEntryComputation(ir, /*input_id=*/1); EXPECT_THAT(output_indexing_1.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: d0 in [0, 9] @@ -295,7 +470,7 @@ TEST_F(IndexingAnalysisTest, BitcastIsReshape) { } )"); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1 * 4 + d2) domain: d0 in [0, 3] @@ -313,7 +488,7 @@ TEST_F(IndexingAnalysisTest, BitcastIsTranspose) { } )"); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3) -> (d0, d3, d1, d2) domain: d0 in [0, 2] @@ -333,7 +508,7 @@ TEST_F(IndexingAnalysisTest, BitcastIsTransposeReshapeTranspose) { )"; auto input_indexing = GetOutputToInputIndexingForEntryComputation(ir); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d1, d0 floordiv 3, d0 mod 3) domain: d0 in [0, 50] @@ -341,7 +516,7 @@ TEST_F(IndexingAnalysisTest, BitcastIsTransposeReshapeTranspose) { )")))); auto output_indexing = GetInputToOutputIndexingForEntryComputation(ir); EXPECT_THAT(output_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d1 * 3 + d2, d0) domain: d0 in [0, 15] @@ -360,7 +535,7 @@ TEST_F(IndexingAnalysisTest, BroadcastOp) { )"; auto input_indexing = GetOutputToInputIndexingForEntryComputation(ir); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d1) domain: d0 in [0, 9] @@ -370,7 +545,7 @@ TEST_F(IndexingAnalysisTest, BroadcastOp) { auto output_indexing = GetInputToOutputIndexingForEntryComputation(ir); EXPECT_THAT(output_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0)[s0, s1] -> (s0, d0, s1) domain: d0 in [0, 19] @@ -403,21 +578,21 @@ TEST_F(IndexingAnalysisTest, ConcatenateOp) { )"; auto input_indexing = GetOutputToInputIndexingForEntryComputation(ir); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1, d2) domain: d0 in [0, 1] d1 in [0, 4] d2 in [0, 6] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1 - 5, d2) domain: d0 in [0, 1] d1 in [5, 15] d2 in [0, 6] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1 - 16, d2) domain: d0 in [0, 1] @@ -428,7 +603,7 @@ TEST_F(IndexingAnalysisTest, ConcatenateOp) { auto output_indexing_0 = GetInputToOutputIndexingForEntryComputation(ir, /*input_id=*/0); EXPECT_THAT(output_indexing_0.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1, d2) domain: d0 in [0, 1] @@ -439,7 +614,7 @@ TEST_F(IndexingAnalysisTest, ConcatenateOp) { auto output_indexing_1 = GetInputToOutputIndexingForEntryComputation(ir, /*input_id=*/1); EXPECT_THAT(output_indexing_1.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1 + 5, d2) domain: d0 in [0, 1] @@ -450,7 +625,7 @@ TEST_F(IndexingAnalysisTest, ConcatenateOp) { auto output_indexing_2 = GetInputToOutputIndexingForEntryComputation(ir, /*input_id=*/2); EXPECT_THAT(output_indexing_2.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1 + 16, d2) domain: d0 in [0, 1] @@ -474,12 +649,12 @@ TEST_F(IndexingAnalysisTest, FusionOpWithSingleBinaryOp) { } )"); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0) -> (d0) domain: d0 in [0, 99] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( (d0) -> (d0) domain: d0 in [0, 99] @@ -547,7 +722,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDot) { )"); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3, d4, d5)[s0] -> (d2, d0 * 768 + s0, d4, d5) domain: d0 in [0, 15] @@ -558,7 +733,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDot) { d5 in [0, 127] s0 in [0, 767] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3, d4, d5)[s0] -> (d0 * 768 + s0) domain: d0 in [0, 15] @@ -569,7 +744,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDot) { d5 in [0, 127] s0 in [0, 767] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3, d4, d5) -> (d1) domain: d0 in [0, 15] @@ -579,7 +754,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDot) { d4 in [0, 5] d5 in [0, 127] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3, d4, d5)[s0] -> (d1, d0 * 768 + s0) domain: d0 in [0, 15] @@ -590,7 +765,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDot) { d5 in [0, 127] s0 in [0, 767] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3, d4, d5)[s0] -> (d1, d0 * 768 + s0) domain: d0 in [0, 15] @@ -601,7 +776,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDot) { d5 in [0, 127] s0 in [0, 767] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3, d4, d5) -> (d2, d4, d5) domain: d0 in [0, 15] @@ -660,7 +835,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithSoftmax) { } )"); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(UnorderedElementsAre(MatchIndexingString(R"( + ElementsAre(UnorderedElementsAre(MatchIndexingMap(R"( (d0, d1, d2)[s0] -> (d0, d1, s0) domain: d0 in [0, 1] @@ -668,7 +843,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithSoftmax) { d2 in [0, 124] s0 in [0, 124] )"), - MatchIndexingString(R"( + MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1, d2) domain: d0 in [0, 1] @@ -691,13 +866,13 @@ TEST_F(IndexingAnalysisTest, FusionOpTensorPlusTransposedTensor) { } )"); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(UnorderedElementsAre(MatchIndexingString(R"( + ElementsAre(UnorderedElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: d0 in [0, 999] d1 in [0, 999] )"), - MatchIndexingString(R"( + MatchIndexingMap(R"( (d0, d1) -> (d1, d0) domain: d0 in [0, 999] @@ -728,32 +903,32 @@ TEST_F(IndexingAnalysisTest, FusionExponentialDuplication) { calls=fused_computation })"); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(UnorderedElementsAre(MatchIndexingString(R"( + ElementsAre(UnorderedElementsAre(MatchIndexingMap(R"( (d0) -> (d0 + 1) domain: d0 in [0, 1] )"), - MatchIndexingString(R"( + MatchIndexingMap(R"( (d0) -> (d0) domain: d0 in [0, 1] )"), - MatchIndexingString(R"( + MatchIndexingMap(R"( (d0) -> (d0 + 2) domain: d0 in [0, 1] )")), - UnorderedElementsAre(MatchIndexingString(R"( + UnorderedElementsAre(MatchIndexingMap(R"( (d0) -> (d0 + 2) domain: d0 in [0, 1] )"), - MatchIndexingString(R"( + MatchIndexingMap(R"( (d0) -> (d0 + 1) domain: d0 in [0, 1] )"), - MatchIndexingString(R"( + MatchIndexingMap(R"( (d0) -> (d0) domain: d0 in [0, 1] @@ -783,7 +958,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReduceOfReduce) { } )"); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0)[s0, s1, s2] -> (s0, s2, d0, s1) domain: d0 in [0, 9] @@ -791,7 +966,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReduceOfReduce) { s1 in [0, 49] s2 in [0, 19] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( (d0) -> () domain: d0 in [0, 9] @@ -821,14 +996,14 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReduceOfBroadcast) { } )"); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1)[s0] -> (d0, s0) domain: d0 in [0, 14] d1 in [0, 63] s0 in [0, 19] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: d0 in [0, 14] @@ -862,7 +1037,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithTransposeOfTranspose) { } )"); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d2, d0, d1) domain: d0 in [0, 9] @@ -894,14 +1069,14 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReducedSlice) { } )"); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0)[s0, s1] -> (s0 + 5, d0 * 2, s1 * 3 + 50) domain: d0 in [0, 31] s0 in [0, 15] s1 in [0, 127] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( (d0) -> () domain: d0 in [0, 31] @@ -922,7 +1097,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReshape_CollapseOfExpand) { } )"); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0) -> (d0) domain: d0 in [0, 127] @@ -943,7 +1118,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReshape_ExpandOfCollapse) { } )"); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: d0 in [0, 7] @@ -964,13 +1139,15 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReshape_ChainedGenericReshapes) { ROOT fusion = f32[10, 10, 10] fusion(p0), kind=kLoop, calls=f } )"); + // TODO(jreiffers): Remove the redundant constraint. EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1, d2) domain: d0 in [0, 9] d1 in [0, 9] d2 in [0, 9] + d1 * 10 + d2 - (d1 floordiv 2) * 20 in [0, 19] )")))); } @@ -990,7 +1167,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithSliceOfSlice) { } )"); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0 * 2 + 8, d1 * 6 + 8, d2 * 12 + 65) @@ -1021,21 +1198,21 @@ TEST_F(IndexingAnalysisTest, FusionOpSliceOfAllConcatenateOpInputs) { } )"); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1 * 3, d2) domain: d0 in [0, 1] d1 in [0, 1] d2 in [0, 6] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1 * 3 - 5, d2) domain: d0 in [0, 1] d1 in [2, 5] d2 in [0, 6] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1 * 3 - 16, d2) domain: d0 in [0, 1] @@ -1064,21 +1241,21 @@ TEST_F(IndexingAnalysisTest, FusionOpSliceOfOneOfConcatenateOpInputs) { } )"); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1 * 2, d2) domain: d0 in [0, 1] d1 in [0, 2] d2 in [0, 6] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1 * 2 - 5, d2) domain: d0 in [0, 1] d1 in [3, 2] d2 in [0, 6] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1 * 2 - 16, d2) domain: d0 in [0, 1] @@ -1103,14 +1280,14 @@ TEST_F(IndexingAnalysisTest, FusionOpReshapeOfConcat) { } )"); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0 * 8 + d1) domain: d0 in [0, 3] d1 in [0, 7] d0 * 8 + d1 in [0, 1] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0 * 8 + d1 - 2) domain: d0 in [0, 3] @@ -1139,7 +1316,7 @@ TEST_F(IndexingAnalysisTest, ReshapeOpCollapseShape) { } )"); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0) -> (d0 floordiv 8, d0 mod 8) domain: d0 in [0, 31] @@ -1155,7 +1332,7 @@ TEST_F(IndexingAnalysisTest, ReshapeOpExpandShape) { } )"); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0 * 8 + d1) domain: d0 in [0, 3] @@ -1173,7 +1350,7 @@ TEST_F(IndexingAnalysisTest, ReshapeOpExpandAndCollapseShape) { )"; auto input_indexing = GetOutputToInputIndexingForEntryComputation(ir); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0 floordiv 8, d0 mod 8, d1 * 4 + d2) domain: d0 in [0, 31] @@ -1183,7 +1360,7 @@ TEST_F(IndexingAnalysisTest, ReshapeOpExpandAndCollapseShape) { auto output_indexing = GetInputToOutputIndexingForEntryComputation(ir); EXPECT_THAT(output_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0 * 8 + d1, d2 floordiv 4, d2 mod 4) domain: d0 in [0, 3] @@ -1201,7 +1378,7 @@ TEST_F(IndexingAnalysisTest, ReshapeOpExpandSubshapeOnly) { } )"); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0 * 4 + d1, d2) domain: d0 in [0, 3] @@ -1219,9 +1396,9 @@ TEST_F(IndexingAnalysisTest, ReshapeOpGenericReshape2DTO3D) { } )"); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( - (d0, d1, d2) -> (d0 * 2 + (d1 * 4 + d2) floordiv 8, - (d1 * 4 + d2) mod 8) + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (d0 * 2 + d1 floordiv 2, + d1 * 4 + d2 - (d1 floordiv 2) * 8) domain: d0 in [0, 1] d1 in [0, 3] @@ -1238,16 +1415,97 @@ TEST_F(IndexingAnalysisTest, ReshapeOpGenericReshape3DTO2D) { } )"); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( - (d0, d1) -> ((d0 * 8 + d1) floordiv 16, - ((d0 * 8 + d1) mod 16) floordiv 4, - d1 mod 4) + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d0 floordiv 2, + d0 * 2 - (d0 floordiv 2) * 4 + + d1 floordiv 4, + d1 mod 4) domain: d0 in [0, 3] d1 in [0, 7] )")))); } +TEST_F(IndexingAnalysisTest, PadOp) { + auto input_indexing = GetOutputToInputIndexingForEntryComputation(R"( + HloModule m + ENTRY e { + p0 = f32[4, 4] parameter(0) + p1 = f32[] parameter(1) + ROOT pad = f32[12, 16] pad(p0, p1), padding=1_4_1x4_8_0 + } + )"); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> ( + (d0 - 1) floordiv 2, + d1 - 4 + ) + domain: + d0 in [1, 7] + d1 in [4, 7] + (d0 - 1) mod 2 in [0, 0] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> () + domain: + d0 in [0, 11] + d1 in [0, 15] + )")))); +} + +TEST_F(IndexingAnalysisTest, PadOpNoInterior) { + auto input_indexing = GetOutputToInputIndexingForEntryComputation(R"( + HloModule m + ENTRY e { + p0 = f32[2,8] parameter(0) + p1 = f32[] parameter(1) + ROOT pad = f32[10,8] pad(p0, p1), padding=1_7x0_0 + } + )"); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d0 - 1, d1) + domain: + d0 in [1, 2] + d1 in [0, 7] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> () + domain: + d0 in [0, 9] + d1 in [0, 7] + )")))); +} + +TEST_F(IndexingAnalysisTest, PadOpNegativePadding) { + // The interior padding is applied first (even with negative padding), so we + // get a size of 5 (7 + 6 - 8). + // in: 0 1 2 3 4 5 6 + // padded: 0 p 1 p 2 p 3 p 4 p 5 p 6 + // sliced: p 2 p 3 p + auto input_indexing = GetOutputToInputIndexingForEntryComputation(R"( + HloModule m + ENTRY e { + p0 = f32[7] parameter(0) + p1 = f32[] parameter(1) + ROOT pad = f32[5] pad(p0, p1), padding=-3_-5_1 + } + )"); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0) -> ((d0 + 3) floordiv 2) + domain: + d0 in [0, 4] + (d0 + 3) mod 2 in [0, 0] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0) -> () + domain: + d0 in [0, 4] + )")))); +} + TEST_F(IndexingAnalysisTest, ReduceOp) { auto ir = R"( HloModule m @@ -1265,7 +1523,7 @@ TEST_F(IndexingAnalysisTest, ReduceOp) { )"; auto input_indexing = GetOutputToInputIndexingForEntryComputation(ir); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (d0, s0, d1, s1) domain: d0 in [0, 149] @@ -1273,7 +1531,7 @@ TEST_F(IndexingAnalysisTest, ReduceOp) { s0 in [0, 19] s1 in [0, 49] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: d0 in [0, 149] @@ -1282,7 +1540,7 @@ TEST_F(IndexingAnalysisTest, ReduceOp) { auto output_indexing = GetInputToOutputIndexingForEntryComputation(ir); EXPECT_THAT(output_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3) -> (d0, d2) domain: d0 in [0, 149] @@ -1290,7 +1548,7 @@ TEST_F(IndexingAnalysisTest, ReduceOp) { d2 in [0, 9] d3 in [0, 49] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( ()[s0, s1] -> (s0, s1) domain: s0 in [0, 149] @@ -1324,26 +1582,26 @@ TEST_F(IndexingAnalysisTest, VariadicReduceOp) { auto output_indexing_0 = GetOutputToInputIndexingForEntryComputation(ir, /*output_id=*/0); EXPECT_THAT(output_indexing_0.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0)[s0] -> (s0, d0) domain: d0 in [0, 9] s0 in [0, 255] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( (d0)[s0] -> (s0, d0) domain: d0 in [0, 9] s0 in [0, 255] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( (d0) -> () domain: d0 in [0, 9] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( (d0) -> () domain: d0 in [0, 9] @@ -1352,24 +1610,24 @@ TEST_F(IndexingAnalysisTest, VariadicReduceOp) { auto output_indexing_1 = GetOutputToInputIndexingForEntryComputation(ir, /*output_id=*/1); EXPECT_THAT(output_indexing_1.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0)[s0] -> (s0, d0) domain: d0 in [0, 9] s0 in [0, 255] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( (d0)[s0] -> (s0, d0) domain: d0 in [0, 9] s0 in [0, 255] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( (d0) -> () domain: d0 in [0, 9] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( (d0) -> () domain: d0 in [0, 9] @@ -1379,24 +1637,24 @@ TEST_F(IndexingAnalysisTest, VariadicReduceOp) { GetInputToOutputIndexingForEntryComputation(ir, /*input_id=*/0); EXPECT_THAT(input_indexing_0.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d1) domain: d0 in [0, 255] d1 in [0, 9] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d1) domain: d0 in [0, 255] d1 in [0, 9] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( ()[s0] -> (s0) domain: s0 in [0, 9] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( ()[s0] -> (s0) domain: s0 in [0, 9] @@ -1405,30 +1663,216 @@ TEST_F(IndexingAnalysisTest, VariadicReduceOp) { auto input_indexing_1 = GetInputToOutputIndexingForEntryComputation(ir, /*input_id=*/1); EXPECT_THAT(input_indexing_1.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d1) domain: d0 in [0, 255] d1 in [0, 9] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d1) domain: d0 in [0, 255] d1 in [0, 9] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( ()[s0] -> (s0) domain: s0 in [0, 9] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( ()[s0] -> (s0) domain: s0 in [0, 9] )")))); } +TEST_F(IndexingAnalysisTest, ReduceWindowOp_NoPadding) { + auto ir = R"( + HloModule m + max { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT max = f32[] maximum(p0, p1) + } + ENTRY e { + c_inf = f32[] constant(-inf) + p0 = f32[1024, 514]parameter(0) + ROOT reduce-window = f32[1024, 3] reduce-window(p0, c_inf), + window={size=1x512 pad=0_0x0_0}, to_apply=max + } + )"; + auto input_indexing = GetOutputToInputIndexingForEntryComputation(ir); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1)[s0] -> (d0, d1 + s0) + domain: + d0 in [0, 1023] + d1 in [0, 2] + s0 in [0, 511] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> () + domain: + d0 in [0, 1023] + d1 in [0, 2] + )")))); +} + +TEST_F(IndexingAnalysisTest, ReduceWindowOp_PaddingAndWindowStride) { + auto ir = R"( + HloModule m + max { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT max = f32[] maximum(p0, p1) + } + ENTRY e { + c_inf = f32[] constant(-inf) + p0 = f32[13, 17] parameter(0) + ROOT reduce-window = f32[7, 17] reduce-window(p0, c_inf), + window={size=3x2 stride=2x1 pad=1_1x0_1}, to_apply=max + } + )"; + auto input_indexing = GetOutputToInputIndexingForEntryComputation(ir); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1)[s0, s1] -> (d0 * 2 + s0 - 1, d1 + s1) + domain: + d0 in [0, 6] + d1 in [0, 16] + s0 in [0, 2] + s1 in [0, 1] + d0 * 2 + s0 in [1, 13] + d1 + s1 in [0, 16] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> () + domain: + d0 in [0, 6] + d1 in [0, 16] + )")))); +} + +TEST_F(IndexingAnalysisTest, ReduceWindowOp_Dilation) { + auto ir = R"( + HloModule m + max { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT max = f32[] maximum(p0, p1) + } + ENTRY e { + c_inf = f32[] constant(-inf) + p0 = f32[2, 3] parameter(0) + ROOT reduce-window = f32[3, 5] reduce-window(p0, c_inf), + window={size=1x1 pad=0_0x0_0 lhs_dilate=2x2}, to_apply=max + } + )"; + auto input_indexing = GetOutputToInputIndexingForEntryComputation(ir); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d0 floordiv 2, d1 floordiv 2) + domain: + d0 in [0, 2] + d1 in [0, 4] + d0 mod 2 in [0, 0] + d1 mod 2 in [0, 0] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> () + domain: + d0 in [0, 2] + d1 in [0, 4] + )")))); +} + +TEST_F(IndexingAnalysisTest, ReduceWindowOp_Variadic) { + auto ir = R"( + HloModule m + combiner { + a0 = f32[] parameter(0) + a1 = s32[] parameter(1) + b0 = f32[] parameter(2) + b1 = s32[] parameter(3) + add0 = f32[] add(a0, b0) + add1 = s32[] add(a1, b1) + ROOT sum2 = (f32[], s32[]) tuple(add0, add1) + } + ENTRY e { + c_f32 = f32[] constant(-inf) + c_s32 = s32[] constant(10) + p0 = f32[2, 3] parameter(0) + p1 = s32[2, 3] parameter(1) + ROOT reduce-window = (f32[1, 2], s32[1, 2]) + reduce-window(p0, p1, c_f32, c_s32), + window={size=2x2 pad=0_0x0_0}, to_apply=combiner + } + )"; + auto input_indexing_0 = + GetOutputToInputIndexingForEntryComputation(ir, /*output_id=*/0); + EXPECT_THAT(input_indexing_0.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1)[s0, s1] -> (s0, d1 + s1) + domain: + d0 in [0, 0] + d1 in [0, 1] + s0 in [0, 1] + s1 in [0, 1] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1)[s0, s1] -> (s0, d1 + s1) + domain: + d0 in [0, 0] + d1 in [0, 1] + s0 in [0, 1] + s1 in [0, 1] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> () + domain: + d0 in [0, 0] + d1 in [0, 1] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> () + domain: + d0 in [0, 0] + d1 in [0, 1] + )")))); + auto input_indexing_1 = + GetOutputToInputIndexingForEntryComputation(ir, /*output_id=*/1); + EXPECT_THAT(input_indexing_1.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1)[s0, s1] -> (s0, d1 + s1) + domain: + d0 in [0, 0] + d1 in [0, 1] + s0 in [0, 1] + s1 in [0, 1] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1)[s0, s1] -> (s0, d1 + s1) + domain: + d0 in [0, 0] + d1 in [0, 1] + s0 in [0, 1] + s1 in [0, 1] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> () + domain: + d0 in [0, 0] + d1 in [0, 1] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> () + domain: + d0 in [0, 0] + d1 in [0, 1] + )")))); +} + TEST_F(IndexingAnalysisTest, ReverseOp) { auto ir = R"( HloModule m @@ -1439,7 +1883,7 @@ TEST_F(IndexingAnalysisTest, ReverseOp) { )"; auto input_indexing = GetOutputToInputIndexingForEntryComputation(ir); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3) -> (d0, -d1 + 16, -d2 + 8, d3) domain: d0 in [0, 0] @@ -1450,7 +1894,7 @@ TEST_F(IndexingAnalysisTest, ReverseOp) { auto output_indexing = GetInputToOutputIndexingForEntryComputation(ir); EXPECT_THAT(output_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3) -> (d0, -d1 + 16, -d2 + 8, d3) domain: d0 in [0, 0] @@ -1477,7 +1921,7 @@ TEST_F(IndexingAnalysisTest, ReverseReshape) { } )"); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: d0 in [0, 9] @@ -1495,7 +1939,7 @@ TEST_F(IndexingAnalysisTest, SliceOp) { } )"); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0 + 5, d1 * 7 + 3, d2 * 2) domain: d0 in [0, 4] @@ -1515,7 +1959,7 @@ TEST_F(IndexingAnalysisTest, TransposeOp) { )"; auto input_indexing = GetOutputToInputIndexingForEntryComputation(ir); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3) -> (d0, d3, d1, d2) domain: d0 in [0, 2] @@ -1525,7 +1969,7 @@ TEST_F(IndexingAnalysisTest, TransposeOp) { )")))); auto output_indexing = GetInputToOutputIndexingForEntryComputation(ir); EXPECT_THAT(output_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3) -> (d0, d2, d3, d1) domain: d0 in [0, 2] @@ -1544,7 +1988,7 @@ TEST_F(IndexingAnalysisTest, TransposeOp4D) { } )"); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3) -> (d0, d3, d1, d2) domain: d0 in [0, 2] @@ -1566,7 +2010,7 @@ TEST_F(IndexingAnalysisTest, DotOp) { } )"); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(MatchIndexingString(R"( + ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3, d4, d5)[s0, s1] -> (d2, d1, s1, d3, s0, d0) domain: d0 in [0, 9] @@ -1578,7 +2022,7 @@ TEST_F(IndexingAnalysisTest, DotOp) { s0 in [0, 17] s1 in [0, 16] )")), - ElementsAre(MatchIndexingString(R"( + ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3, d4, d5)[s0, s1] -> (s1, d0, d4, s0, d5, d1) domain: d0 in [0, 9] @@ -1604,21 +2048,22 @@ TEST_F(IndexingAnalysisTest, UnsupportedOps) { } )"; auto input_indexing = GetOutputToInputIndexingForEntryComputation(ir); - EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(ElementsAre(std::nullopt), ElementsAre(std::nullopt), - ElementsAre(std::nullopt))); + EXPECT_THAT( + input_indexing.indexing_maps, + ElementsAre(ElementsAre(UndefinedMap()), ElementsAre(UndefinedMap()), + ElementsAre(UndefinedMap()))); auto output_indexing_0 = GetInputToOutputIndexingForEntryComputation(ir, 0); EXPECT_THAT(output_indexing_0.indexing_maps, - ElementsAre(ElementsAre(std::nullopt))); + ElementsAre(ElementsAre(UndefinedMap()))); auto output_indexing_1 = GetInputToOutputIndexingForEntryComputation(ir, 1); EXPECT_THAT(output_indexing_1.indexing_maps, - ElementsAre(ElementsAre(std::nullopt))); + ElementsAre(ElementsAre(UndefinedMap()))); auto output_indexing_2 = GetInputToOutputIndexingForEntryComputation(ir, 2); EXPECT_THAT(output_indexing_2.indexing_maps, - ElementsAre(ElementsAre(std::nullopt))); + ElementsAre(ElementsAre(UndefinedMap()))); } TEST_F(IndexingAnalysisTest, FusionWithUnsupportedOp) { @@ -1641,14 +2086,38 @@ TEST_F(IndexingAnalysisTest, FusionWithUnsupportedOp) { } )"); EXPECT_THAT(input_indexing.indexing_maps, - ElementsAre(UnorderedElementsAre(MatchIndexingString(R"( + ElementsAre(UnorderedElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0 * 4, d1) domain: d0 in [0, 4] d1 in [0, 4] )"), - std::nullopt), - ElementsAre(std::nullopt))); + UndefinedMap()), + ElementsAre(UndefinedMap()))); +} + +TEST_F(IndexingAnalysisTest, TilingIndexing) { + Tiling tiling{/*shape=*/{1024, 256, 16}, + /*tile_sizes=*/{8, 1, 4}, + /*num_threads=*/{1, 4, 4}}; + EXPECT_THAT(GetIndexingMapForTiling(tiling, &mlir_context_).ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + (d3 floordiv 64) * 8 + s0, + (d3 mod 64) * 4 + d0 floordiv 4, + d0 mod 4 + s2 * 4 + ) + domain: + d0 in [0, 15] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 8191] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 7] + s1 in [0, 0] + s2 in [0, 3] + )")); } } // namespace diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.cc b/third_party/xla/xla/service/gpu/model/indexing_map.cc index 126b092d46009f..41265518f32ef7 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -27,8 +28,11 @@ limitations under the License. #include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/SmallVector.h" #include "mlir/IR/AffineExpr.h" // from @llvm-project #include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "xla/service/gpu/model/affine_map_printer.h" #include "tsl/platform/logging.h" // IWYU pragma: keep @@ -37,6 +41,9 @@ namespace xla { namespace gpu { namespace { +using llvm::ArrayRef; +using llvm::SmallBitVector; +using llvm::SmallVector; using mlir::AffineBinaryOpExpr; using mlir::AffineConstantExpr; using mlir::AffineDimExpr; @@ -46,6 +53,7 @@ using mlir::AffineMap; using mlir::AffineSymbolExpr; using mlir::getAffineBinaryOpExpr; using mlir::getAffineConstantExpr; +using mlir::MLIRContext; int64_t FloorDiv(int64_t dividend, int64_t divisor) { return dividend / divisor - @@ -57,44 +65,396 @@ int64_t CeilDiv(int64_t dividend, int64_t divisor) { (((dividend >= 0) == (divisor >= 0) && dividend % divisor) ? 1 : 0); } +class AffineExprSimplifier { + public: + explicit AffineExprSimplifier(RangeEvaluator* range_evaluator) + : range_evaluator_(range_evaluator) {} + + // Simplifies the map as much as possible. + mlir::AffineMap Simplify(mlir::AffineMap affine_map); + + mlir::AffineExpr Simplify(mlir::AffineExpr expr); + + private: + std::optional GetConstantRhsMultiplier(mlir::AffineExpr expr); + + // Simplifier for mod. + // - Rewrites (a * 100 + ...) % 100 to (...) % 100 + // - Rewrites a % b to a if a is known to be less than b. + mlir::AffineExpr RewriteMod(mlir::AffineBinaryOpExpr mod); + + // Simplifier for floordiv. + // - Rewrites (a * 100 + ...) / 100 to a + (...) / 100 + // - Rewrites a / 100 to 0 when a is known to be less than 100. + mlir::AffineExpr RewriteFloorDiv(mlir::AffineBinaryOpExpr div); + + mlir::AffineExpr RewriteSum( + mlir::AffineExpr expr, + const std::function& map); + + mlir::AffineExpr RewriteSumIf( + mlir::AffineExpr expr, const std::function& pred); + + // Attempts to simplify the expression, but doesn't attempt to simplify the + // result further. + mlir::AffineExpr SimplifyOnce(mlir::AffineExpr expr); + + RangeEvaluator* range_evaluator_; +}; + +AffineExpr AffineExprSimplifier::RewriteMod(AffineBinaryOpExpr mod) { + auto lhs_simplified = SimplifyOnce(mod.getLHS()); + + auto lhs = range_evaluator_->ComputeExpressionRange(lhs_simplified); + auto rhs = range_evaluator_->ComputeExpressionRange(mod.getRHS()); + + // a % b where b is always larger than a? + if (0 <= lhs.lower_bound && lhs.upper_bound < rhs.lower_bound) { + return lhs_simplified; + } + + // The logic below assumes we have a constant RHS. + if (!rhs.IsPoint()) { + return mod; + } + int64_t m = rhs.lower_bound; + + Range no_multiplier_range{0, 0}; + int64_t multiplier_gcd = -1; + + auto new_lhs = RewriteSumIf(lhs_simplified, [&](AffineExpr expr) { + if (auto multiplier = GetConstantRhsMultiplier(expr)) { + if (*multiplier % m == 0) { + return false; + } + + if (multiplier_gcd == -1) { + multiplier_gcd = *multiplier; + } else { + multiplier_gcd = std::gcd(multiplier_gcd, *multiplier); + } + return true; + } + auto range = range_evaluator_->ComputeExpressionRange(expr); + no_multiplier_range.lower_bound += range.lower_bound; + no_multiplier_range.upper_bound += range.upper_bound; + return true; + }); + + mlir::AffineExpr extracted = getAffineConstantExpr(0, mod.getContext()); + if (m % multiplier_gcd == 0 && no_multiplier_range.lower_bound >= 0 && + no_multiplier_range.upper_bound < multiplier_gcd) { + // Remove everything that doesn't have a multiplier. + new_lhs = RewriteSumIf(new_lhs, [&](AffineExpr expr) { + if (GetConstantRhsMultiplier(expr)) { + return true; + } + extracted = extracted + expr; + return false; + }); + } + if (!new_lhs) new_lhs = getAffineConstantExpr(0, mod.getContext()); + return new_lhs % mod.getRHS() + extracted; +} + +AffineExpr AffineExprSimplifier::RewriteFloorDiv(AffineBinaryOpExpr div) { + auto mlir_context = range_evaluator_->GetMLIRContext(); + auto lhs_simplified = SimplifyOnce(div.getLHS()); + auto lhs = range_evaluator_->ComputeExpressionRange(lhs_simplified); + auto rhs = range_evaluator_->ComputeExpressionRange(div.getRHS()); + + if (0 <= lhs.lower_bound && lhs.upper_bound < rhs.lower_bound) { + return getAffineConstantExpr(0, mlir_context); + } + + // The logic below assumes we have a constant RHS. + if (!rhs.IsPoint()) { + return div; + } + int64_t d = rhs.lower_bound; + + // If the dividend's range has a single element, return its value. + int64_t a = FloorDiv(lhs.lower_bound, d); + int64_t b = FloorDiv(lhs.upper_bound, d); + if (a == b) { + return getAffineConstantExpr(a, mlir_context); + } + + // Rewrite `(a / b) / c` to `a / (b * c)` if `a >= 0` and `b` and `c` are + // constants. + if (lhs_simplified.getKind() == AffineExprKind::FloorDiv) { + auto lhs_div = mlir::cast(lhs_simplified); + auto lhs_lhs = range_evaluator_->ComputeExpressionRange(lhs_div.getLHS()); + if (lhs_lhs.lower_bound >= 0) { + auto lhs_rhs = range_evaluator_->ComputeExpressionRange(lhs_div.getRHS()); + if (lhs_rhs.IsPoint()) { + return lhs_div.getLHS().floorDiv(lhs_rhs.lower_bound * d); + } + } + } + + Range no_multiplier_range{0, 0}; + int64_t multiplier_gcd = -1; + // The maximum GCD of any remaining multiplier inside the div and the divisor. + int64_t max_remaining_multiplier_gcd = -1; + AffineExpr extracted = getAffineConstantExpr(0, mlir_context); + auto new_dividend = RewriteSumIf(lhs_simplified, [&](AffineExpr expr) { + if (auto multiplier = GetConstantRhsMultiplier(expr)) { + // (x * 7 + ...) / 3 -> can't extract. We could extract x * 2 and keep + // one x, but we currently have no reason to do that. + if (*multiplier % d != 0) { + if (multiplier_gcd == -1) { + multiplier_gcd = *multiplier; + } else { + multiplier_gcd = std::gcd(multiplier_gcd, *multiplier); + } + max_remaining_multiplier_gcd = + std::max(max_remaining_multiplier_gcd, std::gcd(*multiplier, d)); + return true; + } + int64_t factor = *multiplier / d; + extracted = + extracted + mlir::cast(expr).getLHS() * factor; + // Remove from dividend. + return false; + } + auto range = range_evaluator_->ComputeExpressionRange(expr); + no_multiplier_range.lower_bound += range.lower_bound; + no_multiplier_range.upper_bound += range.upper_bound; + // Not a constant multiplier, keep in dividend. + return true; + }); + + // If we removed everything, skip the div. + if (!new_dividend) { + return extracted; + } + + if ((d % multiplier_gcd) == 0) { + if (no_multiplier_range.lower_bound >= 0 && + no_multiplier_range.upper_bound < multiplier_gcd) { + // Remove everything that doesn't have a multiplier. + new_dividend = RewriteSumIf(new_dividend, [&](AffineExpr expr) { + auto mult = GetConstantRhsMultiplier(expr); + return mult.has_value(); + }); + } + } + + // If we have a gcd > 1, we can split the div into two: + // (x * 128 + y) // 192 -> (x * 2 + y // 64) // 3 + // This rule primarily exists because MLIR's upstream simplifier tends to + // generate expressions like this from %: + // + // s0 * 512 + // - ((s0 * 2 + s1 floordiv 64) floordiv 3) * 768 + // + ((s0 * 128 + s1) floordiv 192) * 768 + // + // This rule lets us eliminate the subtraction and the addition. + // TODO(pifon): Remove this once the remaining simplification is fixed. + if (max_remaining_multiplier_gcd > 1) { + AffineExpr partially_extracted = getAffineConstantExpr(0, mlir_context); + new_dividend = RewriteSumIf(new_dividend, [&](AffineExpr expr) { + if (auto multiplier = GetConstantRhsMultiplier(expr); + multiplier && ((*multiplier % max_remaining_multiplier_gcd) == 0)) { + auto expr_lhs = mlir::cast(expr).getLHS(); + partially_extracted = + partially_extracted + + expr_lhs * (*multiplier / max_remaining_multiplier_gcd); + // Remove from dividend. + return false; + } + return true; + }); + if (!new_dividend) { + new_dividend = getAffineConstantExpr(0, mlir_context); + } + return extracted + (partially_extracted + + new_dividend.floorDiv(max_remaining_multiplier_gcd)) + .floorDiv(d / max_remaining_multiplier_gcd); + } + + // If we removed nothing, return the original division. + if (extracted == getAffineConstantExpr(0, mlir_context) && + new_dividend == div.getLHS()) { + return div; + } + + return extracted + new_dividend.floorDiv(div.getRHS()); +} + +std::optional AffineExprSimplifier::GetConstantRhsMultiplier( + AffineExpr expr) { + if (expr.getKind() != AffineExprKind::Mul) { + return std::nullopt; + } + auto bound = range_evaluator_->ComputeExpressionRange( + mlir::cast(expr).getRHS()); + if (!bound.IsPoint()) { + return std::nullopt; + } + return bound.lower_bound; +} + +AffineExpr AffineExprSimplifier::RewriteSum( + AffineExpr expr, const std::function& map) { + if (expr.getKind() == AffineExprKind::Add) { + auto add = mlir::dyn_cast(expr); + return RewriteSum(add.getLHS(), map) + RewriteSum(add.getRHS(), map); + } + return map(expr); +} + +AffineExpr AffineExprSimplifier::RewriteSumIf( + AffineExpr expr, const std::function& pred) { + if (expr.getKind() == AffineExprKind::Add) { + auto add = mlir::dyn_cast(expr); + auto lhs = RewriteSumIf(add.getLHS(), pred); + auto rhs = RewriteSumIf(add.getRHS(), pred); + if (lhs == add.getLHS() && rhs == add.getRHS()) { + return add; + } + if (lhs && rhs) { + return lhs + rhs; + } + return lhs ? lhs : (rhs ? rhs : nullptr); + } + return pred(expr) ? expr : nullptr; +} + +AffineExpr AffineExprSimplifier::SimplifyOnce(AffineExpr expr) { + switch (expr.getKind()) { + case AffineExprKind::Mul: + case AffineExprKind::Add: { + auto binop = mlir::cast(expr); + auto lhs = SimplifyOnce(binop.getLHS()); + auto rhs = SimplifyOnce(binop.getRHS()); + if (lhs == binop.getLHS() && rhs == binop.getRHS()) { + return expr; + } + return getAffineBinaryOpExpr(expr.getKind(), lhs, rhs); + } + case AffineExprKind::Mod: + return RewriteMod(mlir::cast(expr)); + case AffineExprKind::FloorDiv: + return RewriteFloorDiv(mlir::cast(expr)); + case AffineExprKind::DimId: + case AffineExprKind::SymbolId: { + auto bounds = range_evaluator_->ComputeExpressionRange(expr); + if (bounds.IsPoint()) { + return getAffineConstantExpr(bounds.lower_bound, + range_evaluator_->GetMLIRContext()); + } + return expr; + } + + default: + return expr; + } +} + +AffineExpr AffineExprSimplifier::Simplify(AffineExpr expr) { + while (true) { + auto simplified = SimplifyOnce(expr); + if (simplified == expr) { + return expr; + } + expr = simplified; + } +} + +AffineMap AffineExprSimplifier::Simplify(AffineMap affine_map) { + affine_map = mlir::simplifyAffineMap(affine_map); + SmallVector results; + results.reserve(affine_map.getNumResults()); + bool nothing_changed = true; + for (AffineExpr expr : affine_map.getResults()) { + AffineExpr simplified = Simplify(expr); + nothing_changed &= simplified == expr; + results.push_back(simplified); + } + if (nothing_changed) { + return affine_map; + } + return Simplify(AffineMap::get(affine_map.getNumDims(), + affine_map.getNumSymbols(), results, + affine_map.getContext())); +} + // Computes intersection of two ranges. Range Intersect(const Range& lhs, const Range& rhs) { return Range{std::max(lhs.lower_bound, rhs.lower_bound), std::min(lhs.upper_bound, rhs.upper_bound)}; } -// Attempts to parse an expression dim_or_symbol * factor + shift. -bool ParseLinearFunction(AffineExpr expr, AffineExpr* symbol_or_dim, - int64_t* factor, int64_t* shift) { - AffineExpr residual = expr; - *shift = 0; - *factor = 1; - if (auto binop = mlir::dyn_cast(residual)) { - if (binop.getKind() == AffineExprKind::Add) { - auto constant = mlir::dyn_cast(binop.getRHS()); - if (!constant) { - return false; - } - *shift = constant.getValue(); - residual = binop.getLHS(); +// Simplifies a constraint range, i.e. a constraint d0 + x in [lb, ub] will +// become d0 in [lb - x, ub - x]. Also supports *, floorDiv. +bool SimplifyConstraintRangeOnce(AffineExpr* expr, Range* range) { + switch (expr->getKind()) { + case AffineExprKind::DimId: + case AffineExprKind::SymbolId: + // do the trick with constant + case AffineExprKind::Constant: { + return false; } - } - if (auto binop = mlir::dyn_cast(residual)) { - if (binop.getKind() == AffineExprKind::Mul) { - auto constant = mlir::dyn_cast(binop.getRHS()); + default: { + auto binary_op = mlir::cast(*expr); + CHECK(binary_op); + auto lhs = binary_op.getLHS(); + auto rhs = binary_op.getRHS(); + auto constant = mlir::dyn_cast(rhs); if (!constant) { return false; } - *factor = constant.getValue(); - residual = binop.getLHS(); + switch (expr->getKind()) { + case AffineExprKind::Add: { + int64_t shift = constant.getValue(); + range->lower_bound -= shift; + range->upper_bound -= shift; + *expr = lhs; + return true; + } + case AffineExprKind::Mul: { + int64_t factor = constant.getValue(); + if (factor < 0) { + factor *= -1; + range->lower_bound *= -1; + range->upper_bound *= -1; + std::swap(range->lower_bound, range->upper_bound); + } + range->lower_bound = CeilDiv(range->lower_bound, factor); + range->upper_bound = FloorDiv(range->upper_bound, factor); + *expr = lhs; + return true; + } + case AffineExprKind::FloorDiv: { + int64_t divisor = constant.getValue(); + if (divisor < 0) { + divisor *= -1; + range->lower_bound *= -1; + range->upper_bound *= -1; + std::swap(range->lower_bound, range->upper_bound); + } + range->lower_bound *= divisor; + range->upper_bound = (range->upper_bound + 1) * divisor - 1; + *expr = lhs; + return true; + } + default: { + return false; + } + } } } - if (residual.getKind() == AffineExprKind::DimId || - residual.getKind() == AffineExprKind::SymbolId) { - *symbol_or_dim = residual; - return true; +} + +// Repeatedly simplifies the range of the constraint. +bool SimplifyConstraintRange(AffineExpr* expr, Range* range) { + bool is_simplified = false; + while (SimplifyConstraintRangeOnce(expr, range)) { + is_simplified = true; } - return false; + return is_simplified; } } // namespace @@ -119,23 +479,24 @@ bool operator==(const Range& lhs, const Range& rhs) { lhs.upper_bound == rhs.upper_bound; } -Domain Domain::FromUpperBounds(absl::Span dim_upper_bounds, - absl::Span symbol_upper_bounds) { - Domain domain; - domain.dim_ranges_.reserve(dim_upper_bounds.size()); +IndexingMap IndexingMap::FromTensorSizes( + AffineMap affine_map, absl::Span dim_upper_bounds, + absl::Span symbol_upper_bounds) { + IndexingMap indexing_map; + indexing_map.affine_map_ = affine_map; + indexing_map.dim_ranges_.reserve(dim_upper_bounds.size()); for (int64_t ub : dim_upper_bounds) { - CHECK_GT(ub, 0); - domain.dim_ranges_.push_back(Range{0, ub - 1}); + indexing_map.dim_ranges_.push_back(Range{0, ub - 1}); } - domain.symbol_ranges_.reserve(symbol_upper_bounds.size()); + indexing_map.symbol_ranges_.reserve(symbol_upper_bounds.size()); for (int64_t ub : symbol_upper_bounds) { CHECK_GT(ub, 0); - domain.symbol_ranges_.push_back(Range{0, ub - 1}); + indexing_map.symbol_ranges_.push_back(Range{0, ub - 1}); } - return domain; + return indexing_map; } -void Domain::AddConstraint(mlir::AffineExpr expr, const Range& range) { +void IndexingMap::AddConstraint(mlir::AffineExpr expr, Range range) { if (auto dim_expr = mlir::dyn_cast(expr)) { Range& current_range = dim_ranges_[dim_expr.getPosition()]; current_range = Intersect(current_range, range); @@ -149,67 +510,68 @@ void Domain::AddConstraint(mlir::AffineExpr expr, const Range& range) { // TODO(b/322131639): Add a proper Constraints simplifier that will apply // simplification rules until it converges. For example, it should have a rule // for `symbol_or_dim floorDiv divisor`. - - // Try to parse a linear function of type symbol_or_dim * factor + shift. - AffineExpr symbol_or_dim; - int64_t factor, shift; - if (ParseLinearFunction(expr, &symbol_or_dim, &factor, &shift)) { - Range new_range = factor > 0 - ? Range{CeilDiv(range.lower_bound - shift, factor), - FloorDiv(range.upper_bound - shift, factor)} - : Range{CeilDiv(range.upper_bound - shift, factor), - FloorDiv(range.lower_bound - shift, factor)}; - AddConstraint(symbol_or_dim, new_range); + if (SimplifyConstraintRange(&expr, &range)) { + AddConstraint(expr, range); return; } - auto [it, inserted] = expr_ranges_.insert({expr, range}); + auto [it, inserted] = constraints_.insert({expr, range}); if (!inserted) { it->second = Intersect(it->second, range); } } -bool Domain::IsKnownEmpty() const { +bool IndexingMap::ConstraintsSatisfied( + ArrayRef dim_const_exprs, + ArrayRef symbol_const_exprs) const { + CHECK(dim_const_exprs.size() == GetDimensionCount()); + CHECK(symbol_const_exprs.size() == GetSymbolCount()); + if (IsKnownEmpty()) { + return false; + } + for (auto& [expr, range] : constraints_) { + int64_t expr_value = + mlir::cast( + expr.replaceDimsAndSymbols(dim_const_exprs, symbol_const_exprs)) + .getValue(); + if (expr_value < range.lower_bound || expr_value > range.upper_bound) { + return false; + } + } + return true; +} + +SmallVector IndexingMap::Evaluate( + ArrayRef dim_const_exprs, + ArrayRef symbol_const_exprs) const { + CHECK(dim_const_exprs.size() == GetDimensionCount()); + CHECK(symbol_const_exprs.size() == GetSymbolCount()); + AffineMap eval = affine_map_.replaceDimsAndSymbols( + dim_const_exprs, symbol_const_exprs, dim_const_exprs.size(), + symbol_const_exprs.size()); + return eval.getConstantResults(); +} + +bool IndexingMap::IsKnownEmpty() const { auto is_infeasible = [](const Range& range) { return range.lower_bound > range.upper_bound; }; return llvm::any_of(dim_ranges_, is_infeasible) || llvm::any_of(symbol_ranges_, is_infeasible) || - llvm::any_of(expr_ranges_, + llvm::any_of(constraints_, [&](const std::pair& item) { return is_infeasible(item.second); }); } -std::string Domain::ToString(const AffineMapPrinter& printer) const { - std::stringstream ss; - Print(ss, printer); - return ss.str(); -} - -void Domain::Print(std::ostream& out, const AffineMapPrinter& printer) const { - for (const auto& [index, range] : llvm::enumerate(dim_ranges_)) { - out << printer.GetDimensionName(static_cast(index)) << " in "; - range.Print(out); - out << '\n'; - } - for (const auto& [index, range] : llvm::enumerate(symbol_ranges_)) { - out << printer.GetSymbolName(static_cast(index)) << " in "; - range.Print(out); - out << '\n'; - } - std::vector expr_range_strings; - expr_range_strings.reserve(expr_ranges_.size()); - for (const auto& [expr, range] : expr_ranges_) { - std::stringstream ss; - printer.Print(ss, expr); - ss << " in "; - range.Print(ss); - ss << '\n'; - expr_range_strings.push_back(ss.str()); +RangeEvaluator::RangeEvaluator(absl::Span dim_ranges, + absl::Span symbol_ranges, + MLIRContext* mlir_context) + : mlir_context_(mlir_context) { + for (const auto& [index, range] : llvm::enumerate(dim_ranges)) { + expression_ranges_cache_[getAffineDimExpr(index, mlir_context_)] = range; } - std::sort(expr_range_strings.begin(), expr_range_strings.end()); - for (const auto& expr_range_string : expr_range_strings) { - out << expr_range_string; + for (const auto& [index, range] : llvm::enumerate(symbol_ranges)) { + expression_ranges_cache_[getAffineSymbolExpr(index, mlir_context_)] = range; } } @@ -228,12 +590,10 @@ Range RangeEvaluator::ComputeExpressionRange(AffineExpr expr) { return Range{value, value}; } case AffineExprKind::DimId: { - return domain_->GetDimensionRange( - mlir::cast(expr).getPosition()); + return expression_ranges_cache_[expr]; } case AffineExprKind::SymbolId: { - return domain_->GetSymbolRange( - mlir::cast(expr).getPosition()); + return expression_ranges_cache_[expr]; } default: auto bound = expression_ranges_cache_.find(expr); @@ -277,17 +637,6 @@ Range RangeEvaluator::ComputeExpressionRange(AffineExpr expr) { } } -std::ostream& operator<<(std::ostream& out, const Domain& domain) { - AffineMapPrinter printer; - domain.Print(out, printer); - return out; -} - -bool operator==(const Domain& lhs, const Domain& rhs) { - return lhs.GetDimensionRanges() == rhs.GetDimensionRanges() && - lhs.GetSymbolRanges() == rhs.GetSymbolRanges(); -} - std::string IndexingMap::ToString(const AffineMapPrinter& printer) const { std::stringstream ss; Print(ss, printer); @@ -296,10 +645,31 @@ std::string IndexingMap::ToString(const AffineMapPrinter& printer) const { void IndexingMap::Print(std::ostream& out, const AffineMapPrinter& printer) const { - printer.Print(out, affine_map); + printer.Print(out, affine_map_); out << "\ndomain:\n"; - domain.Print(out, printer); - out << "\n"; + for (const auto& [index, range] : llvm::enumerate(dim_ranges_)) { + out << printer.GetDimensionName(static_cast(index)) << " in "; + range.Print(out); + out << '\n'; + } + for (const auto& [index, range] : llvm::enumerate(symbol_ranges_)) { + out << printer.GetSymbolName(static_cast(index)) << " in "; + range.Print(out); + out << '\n'; + } + std::vector expr_range_strings; + expr_range_strings.reserve(constraints_.size()); + for (const auto& [expr, range] : constraints_) { + std::stringstream ss; + printer.Print(ss, expr); + ss << " in "; + range.Print(ss); + expr_range_strings.push_back(ss.str()); + } + std::sort(expr_range_strings.begin(), expr_range_strings.end()); + for (const auto& expr_range_string : expr_range_strings) { + out << expr_range_string << '\n'; + } } std::ostream& operator<<(std::ostream& out, const IndexingMap& indexing_map) { @@ -309,274 +679,275 @@ std::ostream& operator<<(std::ostream& out, const IndexingMap& indexing_map) { } bool operator==(const IndexingMap& lhs, const IndexingMap& rhs) { - return lhs.affine_map == rhs.affine_map && lhs.domain == rhs.domain; + return lhs.GetAffineMap() == rhs.GetAffineMap() && + lhs.GetDimensionRanges() == rhs.GetDimensionRanges() && + lhs.GetSymbolRanges() == rhs.GetSymbolRanges(); +} + +IndexingMap operator*(const IndexingMap& lhs, const IndexingMap& rhs) { + return ComposeIndexingMaps(lhs, rhs); } +// Simplification of IndexingMap has two main parts. +// At first we optimized constraints to make the domain as small and simple as +// possible. And only then we simplify the affine_map, because its +// simplification relies on lower/upper bounds of dimensions and symbols. + +// Constraint simplification is performed in two stages repeated until +// convergence. +// 1. Simplify affine expressions in all constraints. +// 2. Simplify constraint ranges for all constraints. +// We don't optimize every constraint separately to avoid re-initialization of +// RangeEvaluator for every constraint. Note that we start with "expr" +// simplification, because the ranges of constraints were already optimized once +// when IndexingMap was constructed. bool IndexingMap::Simplify() { - RangeEvaluator range_evaluator(&domain); + if (IsUndefined()) return false; + + // Simplify constraints to shrink the lower/upper bounds of dims and symbols. + bool constraints_were_simplified = false; + while (true) { + if (!SimplifyConstraintExprs()) break; + constraints_were_simplified = true; + if (!SimplifyConstraintRanges()) break; + } + // Simplify affine_map using the optimized ranges. + // Potentially, we can be smarter about recreating the range_evaluator. + RangeEvaluator range_evaluator(dim_ranges_, symbol_ranges_, GetMLIRContext()); AffineMap simplified_affine_map = - IndexingMapSimplifier(&range_evaluator, affine_map.getContext()) - .Simplify(affine_map); - if (simplified_affine_map == affine_map) { - return false; + AffineExprSimplifier(&range_evaluator).Simplify(affine_map_); + bool affine_map_was_simplified = simplified_affine_map != affine_map_; + if (affine_map_was_simplified) { + affine_map_ = simplified_affine_map; } - affine_map = simplified_affine_map; - return true; + return affine_map_was_simplified || constraints_were_simplified; } -std::optional ComposeIndexingMaps( - const std::optional& producer_map, - const std::optional& consumer_map) { - if (!producer_map.has_value() || !consumer_map.has_value()) { - return std::nullopt; - } - // AffineMap::compose(some_affine_map) actually computes some_affine_map ∘ - // this. - AffineMap composed_map = mlir::simplifyAffineMap( - producer_map->affine_map.compose(consumer_map->affine_map)); - - // After the composition some of the symbols might become unused, e.g. when a - // dimension was added by broadcasting as then reduced. We should remove these - // dimensions from the composed affine map and also from the resulting - // `domain.symbol_ranges_`. - // - // For example, if there is a reduction(broadcast): - // - // param = f32[15] parameter(0) - // bcast = f32[15, 20] broadcast(p0), dimensions={0} - // reduce = f32[15, 20] reduce(bcast, init) dimensions={1} - // - // then `reduce` has (d0)[s0] -> (d0, s0) with s0 in [0, 20). - // and `bcast` has (d0, d1) -> (d0) indexing map. - // - // The composition of there two maps yields (d0)[s0] -> (d0), - // although `s0` is not used in the mapping. In order to remove such symbols, - // we get the indices of unused symbols and remove them from the composed - // affine map and the `domain.symbol_ranges_`. - auto unused_symbols_bit_vector = - mlir::getUnusedSymbolsBitVector({composed_map}); - composed_map = mlir::compressSymbols(composed_map, unused_symbols_bit_vector); - - // The symbols in the composed map, i.e. combined - // producer_map.compose(consumer_map) are packed as [symbols(producer_map) | - // symbols(consumer_map)]. In that order we are adding the symbol ranges while - // skipping the symbols that are unused. - std::vector combined_symbol_ranges; - combined_symbol_ranges.reserve(producer_map->domain.GetSymbolCount() + - consumer_map->domain.GetSymbolCount()); - int64_t symbol_id = 0; - for (const Range& symbol_range : - llvm::concat(producer_map->domain.GetSymbolRanges(), - consumer_map->domain.GetSymbolRanges())) { - if (unused_symbols_bit_vector[symbol_id++]) { +bool IndexingMap::SimplifyConstraintExprs() { + // Simplify affine expression in the constraints_. + RangeEvaluator range_evaluator(dim_ranges_, symbol_ranges_, GetMLIRContext()); + AffineExprSimplifier simplifier(&range_evaluator); + std::vector to_remove; + std::vector> to_add; + for (const auto& [expr, range] : constraints_) { + AffineExpr simplified = simplifier.Simplify(expr); + + // Skip constraints that are always satisfied. + Range evaluated_range = range_evaluator.ComputeExpressionRange(simplified); + if (evaluated_range.upper_bound <= range.upper_bound && + evaluated_range.lower_bound >= range.lower_bound) { + to_remove.push_back(expr); continue; } - combined_symbol_ranges.push_back(symbol_range); + if (simplified == expr) continue; + to_add.push_back({simplified, range}); + to_remove.push_back(expr); } - - IndexingMap composed_indexing_map{ - std::move(composed_map), Domain{consumer_map->domain.GetDimensionRanges(), - combined_symbol_ranges}}; - composed_indexing_map.Simplify(); - - RangeEvaluator consumer_range_evaluator(&consumer_map->domain); - // Add constraints for consumer's codomain w.r.t. producer's domain. - for (auto [index, expr] : - llvm::enumerate(consumer_map->affine_map.getResults())) { - Range consumer_result_range = - consumer_range_evaluator.ComputeExpressionRange(expr); - Range producer_dim_range = - producer_map->domain.GetDimensionRange(static_cast(index)); - // If the constraint is always satisfied, we skip it. - if (consumer_result_range.upper_bound <= producer_dim_range.upper_bound && - consumer_result_range.lower_bound >= producer_dim_range.lower_bound) { - continue; - } - composed_indexing_map.domain.AddConstraint(expr, producer_dim_range); + for (const auto& expr : to_remove) { + constraints_.erase(expr); } - return composed_indexing_map; + for (const auto& [expr, range] : to_add) { + AddConstraint(expr, range); + } + return !to_add.empty(); } -AffineExpr IndexingMapSimplifier::RewriteMod(AffineBinaryOpExpr mod) { - auto lhs_simplified = SimplifyOnce(mod.getLHS()); - - auto lhs = range_evaluator_->ComputeExpressionRange(lhs_simplified); - auto rhs = range_evaluator_->ComputeExpressionRange(mod.getRHS()); - - // a % b where b is always larger than a? - if (0 <= lhs.lower_bound && lhs.upper_bound < rhs.upper_bound) { - return lhs_simplified; +bool IndexingMap::SimplifyConstraintRanges() { + std::vector to_remove; + std::vector> to_add; + for (const auto& [expr, range] : constraints_) { + AffineExpr simplified_expr = expr; + Range simplified_range = range; + if (SimplifyConstraintRange(&simplified_expr, &simplified_range)) { + to_add.push_back({simplified_expr, simplified_range}); + to_remove.push_back(expr); + } } - - // The logic below assumes we have a constant RHS. - if (!rhs.IsPoint()) { - return mod; + for (const auto& expr : to_remove) { + constraints_.erase(expr); } - int64_t m = rhs.lower_bound; + for (const auto& [expr, range] : to_add) { + AddConstraint(expr, range); + } + return !to_add.empty(); +} - auto new_lhs = RewriteSumIf(lhs_simplified, [&](AffineExpr expr) { - if (expr.getKind() != AffineExprKind::Mul) { - return true; - } +namespace { - auto mul_rhs = range_evaluator_->ComputeExpressionRange( - mlir::cast(expr).getRHS()); - bool remove = mul_rhs.IsPoint() && (mul_rhs.lower_bound % m) == 0; - return !remove; // We keep it if we don't remove it! - }); +struct UsedParameters { + llvm::DenseSet dimension_ids; + llvm::DenseSet symbol_ids; +}; - // If we weren't able to remove or simplify anything, return the original - // expression. - if (new_lhs == mod.getLHS()) { - return mod; +void GetUsedParametersImpl(const AffineExpr& expr, + UsedParameters& used_parameters) { + if (auto dim_expr = mlir::dyn_cast(expr)) { + used_parameters.dimension_ids.insert(dim_expr.getPosition()); + return; + } + if (auto symbol_expr = mlir::dyn_cast(expr)) { + used_parameters.symbol_ids.insert(symbol_expr.getPosition()); + return; } - // If we removed everything, return 0. - if (!new_lhs) { - return getAffineConstantExpr(0, mlir_context_); + if (auto binary_expr = mlir::dyn_cast(expr)) { + GetUsedParametersImpl(binary_expr.getLHS(), used_parameters); + GetUsedParametersImpl(binary_expr.getRHS(), used_parameters); } - // Otherwise, return new_sum % m. - return new_lhs % mod.getRHS(); } -AffineExpr IndexingMapSimplifier::RewriteFloorDiv(AffineBinaryOpExpr div) { - auto lhs_simplified = SimplifyOnce(div.getLHS()); - auto lhs = range_evaluator_->ComputeExpressionRange(lhs_simplified); - auto rhs = range_evaluator_->ComputeExpressionRange(div.getRHS()); +// Returns IDs of dimensions and symbols that participate in AffineExpr. +UsedParameters GetUsedParameters(const mlir::AffineExpr& expr) { + UsedParameters used_parameters; + GetUsedParametersImpl(expr, used_parameters); + return used_parameters; +} - if (0 <= lhs.lower_bound && lhs.upper_bound < rhs.lower_bound) { - return getAffineConstantExpr(0, mlir_context_); +bool IsFunctionOfUnusedDimsAndSymbolsOnly( + const UsedParameters& used_parameters, + const SmallBitVector& unused_dims_bit_vector, + const SmallBitVector& unused_symbols_bit_vector) { + for (int64_t dim_id : used_parameters.dimension_ids) { + if (!unused_dims_bit_vector[dim_id]) return false; } - - // The logic below assumes we have a constant RHS. - if (!rhs.IsPoint()) { - return div; + for (int64_t symbol_id : used_parameters.symbol_ids) { + if (!unused_symbols_bit_vector[symbol_id]) return false; } - int64_t d = rhs.lower_bound; + return true; +} - // If the dividend's range has a single element, return its value. - int64_t a = FloorDiv(lhs.lower_bound, d); - int64_t b = FloorDiv(lhs.upper_bound, d); - if (a == b) { - return getAffineConstantExpr(a, mlir_context_); - } +} // namespace - AffineExpr extracted = getAffineConstantExpr(0, mlir_context_); - auto new_dividend = RewriteSumIf(lhs_simplified, [&](AffineExpr expr) { - if (auto multiplier = GetConstantRhsMultiplier(expr)) { - // (x * 7 + ...) / 3 -> can't extract. We could extract x * 2 and keep - // one x, but we currently have no reason to do that. - if (*multiplier % d != 0) { - return true; - } - int64_t factor = *multiplier / d; - extracted = - extracted + mlir::cast(expr).getLHS() * factor; - // Remove from dividend. - return false; +void IndexingMap::RemoveUnusedSymbols() { + if (IsUndefined()) return; + + // Remove unused symbols from the affine_map. + unsigned num_symbols_before = affine_map_.getNumSymbols(); + SmallBitVector unused_symbols_bit_vector = + mlir::getUnusedSymbolsBitVector({affine_map_}); + SmallBitVector unused_dims_bit_vector = + mlir::getUnusedDimsBitVector({affine_map_}); + + // Check if the symbols that are unused in `affine_map` are also unused in + // expressions. + std::vector> candidates_to_remove; + for (const auto& [expr, range] : constraints_) { + UsedParameters used_parameters = GetUsedParameters(expr); + // If the expression uses only symbols and dims that are "unused" in + // `affine_map`, then we can remove it. + if (IsFunctionOfUnusedDimsAndSymbolsOnly(used_parameters, + unused_dims_bit_vector, + unused_symbols_bit_vector)) { + candidates_to_remove.push_back({expr, used_parameters}); + continue; + } + // Otherwise, we need to mark all symbols of these expr as "used". + for (int64_t symbol_id : used_parameters.symbol_ids) { + unused_symbols_bit_vector[symbol_id] = false; } - - // Not a constant multiplier, keep in dividend. - return true; - }); - - // If we removed everything, skip the div. - if (!new_dividend) { - return extracted; } - // If we removed nothing, return the original division. - if (extracted == getAffineConstantExpr(0, mlir_context_) && - new_dividend == div.getLHS()) { - return div; + for (const auto& [expr, used_parameters] : candidates_to_remove) { + if (IsFunctionOfUnusedDimsAndSymbolsOnly(used_parameters, + unused_dims_bit_vector, + unused_symbols_bit_vector)) { + constraints_.erase(expr); + } } - return extracted + new_dividend.floorDiv(div.getRHS()); -} + // Compress `affine_map` using the updated `unused_symbols_bit_vector`. + affine_map_ = mlir::compressSymbols(affine_map_, unused_symbols_bit_vector); -std::optional IndexingMapSimplifier::GetConstantRhsMultiplier( - AffineExpr expr) { - if (expr.getKind() != AffineExprKind::Mul) { - return std::nullopt; - } - auto bound = range_evaluator_->ComputeExpressionRange( - mlir::cast(expr).getRHS()); - if (!bound.IsPoint()) { - return std::nullopt; - } - return bound.lower_bound; -} + // Remap symbols in the constraint expressions accordingly. + unsigned num_symbols_after = affine_map_.getNumSymbols(); + if (num_symbols_after == num_symbols_before) return; -AffineExpr IndexingMapSimplifier::RewriteSumIf( - AffineExpr expr, const std::function& pred) { - if (expr.getKind() == AffineExprKind::Add) { - auto add = mlir::dyn_cast(expr); - auto lhs = RewriteSumIf(add.getLHS(), pred); - auto rhs = RewriteSumIf(add.getRHS(), pred); - if (lhs == add.getLHS() && rhs == add.getRHS()) { - return add; + std::vector compressed_symbol_ranges_; + MLIRContext* mlir_context = GetMLIRContext(); + int64_t used_symbols_count = 0; + std::vector symbol_replacements( + num_symbols_before, getAffineConstantExpr(0, mlir_context)); + for (int i = 0; i < unused_symbols_bit_vector.size(); ++i) { + if (!unused_symbols_bit_vector[i]) { + compressed_symbol_ranges_.push_back(symbol_ranges_[i]); + symbol_replacements[i] = + getAffineSymbolExpr(used_symbols_count++, mlir_context); } - if (lhs && rhs) { - return lhs + rhs; - } - return lhs ? lhs : (rhs ? rhs : nullptr); } - return pred(expr) ? expr : nullptr; + symbol_ranges_ = std::move(compressed_symbol_ranges_); + std::vector to_remove; + std::vector> to_add; + for (const auto& [expr, range] : constraints_) { + auto updated_expr = expr.replaceSymbols(symbol_replacements); + if (updated_expr == expr) continue; + to_add.push_back({updated_expr, range}); + to_remove.push_back(expr); + } + for (const auto& expr : to_remove) { + constraints_.erase(expr); + } + for (const auto& [expr, range] : to_add) { + AddConstraint(expr, range); + } } -AffineExpr IndexingMapSimplifier::SimplifyOnce(AffineExpr expr) { - switch (expr.getKind()) { - case AffineExprKind::Mul: - case AffineExprKind::Add: { - auto binop = mlir::cast(expr); - auto lhs = SimplifyOnce(binop.getLHS()); - auto rhs = SimplifyOnce(binop.getRHS()); - if (lhs == binop.getLHS() && rhs == binop.getRHS()) { - return expr; - } - return getAffineBinaryOpExpr(expr.getKind(), lhs, rhs); - } - case AffineExprKind::Mod: - return RewriteMod(mlir::cast(expr)); - case AffineExprKind::FloorDiv: - return RewriteFloorDiv(mlir::cast(expr)); - case AffineExprKind::DimId: - case AffineExprKind::SymbolId: { - auto bounds = range_evaluator_->ComputeExpressionRange(expr); - if (bounds.IsPoint()) { - return getAffineConstantExpr(bounds.lower_bound, mlir_context_); - } - return expr; - } - - default: - return expr; +IndexingMap ComposeIndexingMaps(const IndexingMap& first, + const IndexingMap& second) { + if (second.IsUndefined() || first.IsUndefined()) { + return IndexingMap::GetUndefined(); } -} + AffineMap producer_affine_map = second.GetAffineMap(); + // map1.compose(map2) computes map2 ∘ map1 for some reason. + AffineMap composed_map = producer_affine_map.compose(first.GetAffineMap()); -AffineExpr IndexingMapSimplifier::Simplify(AffineExpr expr) { - while (true) { - auto simplified = SimplifyOnce(expr); - if (simplified == expr) { - return expr; - } - expr = simplified; + // The symbols in the composed map, i.e. combined + // producer_map.compose(consumer_map) are packed as [symbols(producer_map) | + // symbols(consumer_map)]. + std::vector combined_symbol_ranges; + combined_symbol_ranges.reserve(second.GetSymbolCount() + + first.GetSymbolCount()); + for (const Range& symbol_range : llvm::concat( + second.GetSymbolRanges(), first.GetSymbolRanges())) { + combined_symbol_ranges.push_back(symbol_range); } -} -AffineMap IndexingMapSimplifier::Simplify(AffineMap affine_map) { - mlir::SmallVector results; - results.reserve(affine_map.getNumResults()); - bool nothing_changed = true; - for (AffineExpr expr : affine_map.getResults()) { - AffineExpr simplified = Simplify(expr); - nothing_changed &= simplified == expr; - results.push_back(simplified); + IndexingMap composed_indexing_map(composed_map, first.GetDimensionRanges(), + std::move(combined_symbol_ranges)); + // Add constraints that are already present in the producer_map. We have to + // compute consumer_map(producer_constraints). To keep all symbols and + // dimension IDs the same as in the `composed_indexing_map.affine_map`, we + // create an AffineMap + // (dims of producer_affine_map)[symbols_of_producer_affine_map] = + // (constraint_1, ..., constraint_N) and then compose. + std::vector constraints; + std::vector constraints_ranges; + for (const auto& [expr, range] : second.GetConstraints()) { + constraints.push_back(expr); + constraints_ranges.push_back(range); + } + auto constraints_map = AffineMap::get( + producer_affine_map.getNumDims(), producer_affine_map.getNumSymbols(), + constraints, producer_affine_map.getContext()); + auto remapped_constraints = constraints_map.compose(first.GetAffineMap()); + for (const auto& [expr, range] : + llvm::zip(remapped_constraints.getResults(), constraints_ranges)) { + composed_indexing_map.AddConstraint(expr, range); + } + // Remap symbol ids and add constraints that are already present in the + // consumer_map. + for (const auto& [expr, range] : first.GetConstraints()) { + composed_indexing_map.AddConstraint( + expr.shiftSymbols(first.GetSymbolCount(), second.GetSymbolCount()), + range); } - if (nothing_changed) { - return affine_map; + // Add constraints for consumer's codomain w.r.t. producer's domain. + for (auto [index, expr] : + llvm::enumerate(first.GetAffineMap().getResults())) { + Range producer_dim_range = + second.GetDimensionRange(static_cast(index)); + composed_indexing_map.AddConstraint( + expr.shiftSymbols(first.GetSymbolCount(), second.GetSymbolCount()), + producer_dim_range); } - return mlir::simplifyAffineMap( - AffineMap::get(affine_map.getNumDims(), affine_map.getNumSymbols(), - results, affine_map.getContext())); + return composed_indexing_map; } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.h b/third_party/xla/xla/service/gpu/model/indexing_map.h index 72789963a46126..e631bc7a45bd92 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.h +++ b/third_party/xla/xla/service/gpu/model/indexing_map.h @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include #include #include @@ -28,8 +27,10 @@ limitations under the License. #include "absl/types/span.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Hashing.h" +#include "llvm/ADT/SmallVector.h" #include "mlir/IR/AffineExpr.h" // from @llvm-project #include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "xla/service/gpu/model/affine_map_printer.h" namespace xla { @@ -53,67 +54,13 @@ H AbslHashValue(H h, const Range& range) { return H::combine(std::move(h), range.lower_bound, range.upper_bound); } -// Domain contains ranges for symbols and dimensions of an affine map. -class Domain { - public: - Domain() = default; - - Domain(absl::Span dim_ranges, - absl::Span symbol_ranges) - : dim_ranges_(dim_ranges.begin(), dim_ranges.end()), - symbol_ranges_(symbol_ranges.begin(), symbol_ranges.end()) {} - - static Domain FromUpperBounds(absl::Span dim_upper_bounds, - absl::Span symbol_upper_bounds); - - // Getters for dimension ranges. - Range GetDimensionRange(int64_t id) const { return dim_ranges_[id]; } - absl::Span GetDimensionRanges() const { return dim_ranges_; } - int64_t GetDimensionCount() const { return dim_ranges_.size(); } - - // Getters for symbol ranges. - Range GetSymbolRange(int64_t id) const { return symbol_ranges_[id]; } - absl::Span GetSymbolRanges() const { return symbol_ranges_; } - int64_t GetSymbolCount() const { return symbol_ranges_.size(); } - - // Getters for affine expression constraints. - const llvm::DenseMap& GetExprRanges() const { - return expr_ranges_; - } - int64_t GetExprCount() const { return expr_ranges_.size(); } - - // Allows to add bounds for the affine expression `expr`. If there are - // bounds for the `expr`, then computes intersection of the current and new - // ranges. - void AddConstraint(mlir::AffineExpr expr, const Range& range); - - // Returns true if the domain is empty. Right now it scans through all - // constraints to find the one where lower_bound > upper_bound. If it returns - // true, that does not mean that the domain is not effectively empty. - // For example, if there are two constraints 0 <= d0 mod 7 <= 0 and - // 0 <= d0 mod 11 <= 0 for a dimension 0<= d0 <= 50 then there is no d0 that - // satisfies both constraints. - bool IsKnownEmpty() const; - - std::string ToString( - const AffineMapPrinter& printer = AffineMapPrinter()) const; - - void Print(std::ostream& out, const AffineMapPrinter& printer) const; - - private: - std::vector dim_ranges_; - std::vector symbol_ranges_; - // Inequality constraints for affine expressions. They restrict the feasible - // set for the domain of the indexing map. It contains affine expressions - // other than AffineDimExpr and AffineSymbolExpr. - llvm::DenseMap expr_ranges_; -}; - // Evaluates lower and upper bounds for expressions given the domain. // Not thread safe. class RangeEvaluator { public: - explicit RangeEvaluator(const Domain* domain) : domain_(domain) {} + RangeEvaluator(absl::Span dim_ranges, + absl::Span symbol_ranges, + mlir::MLIRContext* mlir_context); // Checks whether an `AffineExpr` always describes a non-negative value. bool IsAlwaysPositiveOrZero(mlir::AffineExpr expr); @@ -124,19 +71,14 @@ class RangeEvaluator { // Computes the range of expression using its subexpression ranges. Range ComputeExpressionRange(mlir::AffineExpr expr); + // Return MLIR context. + mlir::MLIRContext* GetMLIRContext() const { return mlir_context_; } + private: - const Domain* const domain_; + mlir::MLIRContext* mlir_context_; llvm::DenseMap expression_ranges_cache_; }; -std::ostream& operator<<(std::ostream& out, const Domain& domain); -bool operator==(const Domain& lhs, const Domain& rhs); - -template -H AbslHashValue(H h, const Domain& domain) { - return H::combine(std::move(h), domain.GetDimensionRanges(), - domain.GetSymbolRanges()); -} // Contains an affine map with N dimension expressions and M symbols: // (d0, ..., d_{N - 1})[s_0, ..., s_{M - 1}] -> f(d_i, s_j) @@ -163,7 +105,25 @@ H AbslHashValue(H h, const Domain& domain) { // ``` // can be written as `(d0, d1, d2, d3) -> (d0, -d1 + 16, -d2 + 8, d3)` with // d0 in [0, 1), d1 in [0, 16], d2 in [0, 8] and d3 in [0, 8]. -struct IndexingMap { +class IndexingMap { + public: + IndexingMap(mlir::AffineMap affine_map, std::vector dim_ranges, + std::vector symbol_ranges, + absl::Span> constraints = {}) + : affine_map_(affine_map), + dim_ranges_(std::move(dim_ranges)), + symbol_ranges_(std::move(symbol_ranges)) { + for (const auto& [expr, range] : constraints) { + AddConstraint(expr, range); + } + } + + static IndexingMap GetUndefined() { return IndexingMap(); } + + static IndexingMap FromTensorSizes( + mlir::AffineMap affine_map, absl::Span dim_upper_bounds, + absl::Span symbol_upper_bounds); + std::string ToString( const AffineMapPrinter& printer = AffineMapPrinter()) const; @@ -172,63 +132,93 @@ struct IndexingMap { // Returns true if the map was simplified. bool Simplify(); - mlir::AffineMap affine_map; - Domain domain; -}; -std::ostream& operator<<(std::ostream& out, const IndexingMap& indexing_map); -bool operator==(const IndexingMap& lhs, const IndexingMap& rhs); + // Return MLIRContext. + mlir::MLIRContext* GetMLIRContext() const { return affine_map_.getContext(); } -// Composes affine maps, i.e. consumer_map ∘ producer_map. -// Right now the ranges of the composed indexing map are correct only when there -// is no composition with concat. -// TODO(b/319410501): Generalize domain modelling. -std::optional ComposeIndexingMaps( - const std::optional& producer_map, - const std::optional& consumer_map); + // Returns the affine map. + mlir::AffineMap GetAffineMap() const { return affine_map_; } -template -H AbslHashValue(H h, const IndexingMap& indexing_map) { - llvm::hash_code affine_map_hash = llvm::hash_combine(indexing_map.affine_map); - return H::combine(std::move(h), static_cast(affine_map_hash), - indexing_map.domain); -} + // Getters for dimension ranges. + Range GetDimensionRange(int64_t id) const { return dim_ranges_[id]; } + const std::vector& GetDimensionRanges() const { return dim_ranges_; } + int64_t GetDimensionCount() const { return dim_ranges_.size(); } -class IndexingMapSimplifier { - public: - IndexingMapSimplifier(RangeEvaluator* range_evaluator, - mlir::MLIRContext* mlir_context) - : range_evaluator_(range_evaluator), mlir_context_(mlir_context) {} + // Getters for symbol ranges. + Range GetSymbolRange(int64_t id) const { return symbol_ranges_[id]; } + const std::vector& GetSymbolRanges() const { return symbol_ranges_; } + int64_t GetSymbolCount() const { return symbol_ranges_.size(); } - // Simplifies the map as much as possible. - mlir::AffineMap Simplify(mlir::AffineMap affine_map); + // Getters for affine expression constraints. + const llvm::DenseMap& GetConstraints() const { + return constraints_; + } + int64_t GetConstraintsCount() const { return constraints_.size(); } - // Simplifies the expression as much as possible. - mlir::AffineExpr Simplify(mlir::AffineExpr expr); + // Allows to add bounds for the affine expression `expr`. If there are + // bounds for the `expr`, then computes intersection of the current and new + // ranges. + void AddConstraint(mlir::AffineExpr expr, Range range); - private: - std::optional GetConstantRhsMultiplier(mlir::AffineExpr expr); + // Evaluates the constraints at a given point and returns `true` if all + // constraints are satisfied. + bool ConstraintsSatisfied( + llvm::ArrayRef dim_const_exprs, + llvm::ArrayRef symbol_const_exprs) const; - // Simplifier for mod. - // - Rewrites (a * 100 + ...) % 100 to (...) % 100 - // - Rewrites a % b to a if a is known to be less than b. - mlir::AffineExpr RewriteMod(mlir::AffineBinaryOpExpr mod); + // Evaluates indexing map results at a given point. + llvm::SmallVector Evaluate( + llvm::ArrayRef dim_const_exprs, + llvm::ArrayRef symbol_const_exprs) const; - // Simplifier for floordiv. - // - Rewrites (a * 100 + ...) / 100 to a + (...) / 100 - // - Rewrites a / 100 to 0 when a is known to be less than 100. - mlir::AffineExpr RewriteFloorDiv(mlir::AffineBinaryOpExpr div); + // Returns true if the domain is empty. Right now it scans through all + // constraints to find the one where lower_bound > upper_bound. If it returns + // true, that does not mean that the domain is not effectively empty. + // For example, if there are two constraints 0 <= d0 mod 7 <= 0 and + // 0 <= d0 mod 11 <= 0 for a dimension 0<= d0 <= 50 then there is no d0 that + // satisfies both constraints. + bool IsKnownEmpty() const; - mlir::AffineExpr RewriteSumIf( - mlir::AffineExpr expr, const std::function& pred); + bool IsUndefined() const { return affine_map_ == mlir::AffineMap(); } - // Attempts to simplify the expression, but doesn't attempt to simplify the - // result further. - mlir::AffineExpr SimplifyOnce(mlir::AffineExpr expr); + // Removes unused symbols from the `affine_map_` and constraints. + void RemoveUnusedSymbols(); - RangeEvaluator* range_evaluator_; - mlir::MLIRContext* mlir_context_; + private: + IndexingMap() = default; + + // Performs AffineExpr simplification for all constraints. + // Returns true if simplification was performed. + bool SimplifyConstraintExprs(); + + // Performs range simplification for all constraints. + // Returns true if simplification was performed. + bool SimplifyConstraintRanges(); + + mlir::AffineMap affine_map_; + std::vector dim_ranges_; + std::vector symbol_ranges_; + // Inequality constraints for affine expressions. They restrict the feasible + // set for the domain of the indexing map. It contains affine expressions + // other than AffineDimExpr and AffineSymbolExpr. + llvm::DenseMap constraints_; }; +std::ostream& operator<<(std::ostream& out, const IndexingMap& indexing_map); +bool operator==(const IndexingMap& lhs, const IndexingMap& rhs); +IndexingMap operator*(const IndexingMap& lhs, const IndexingMap& rhs); +// Composes affine maps, i.e. first ∘ second. +IndexingMap ComposeIndexingMaps(const IndexingMap& first, + const IndexingMap& second); + +template +H AbslHashValue(H h, const IndexingMap& indexing_map) { + llvm::hash_code affine_map_hash = + llvm::hash_combine(indexing_map.GetAffineMap()); + return H::combine(std::move(h), static_cast(affine_map_hash), + indexing_map.GetDimensionRanges(), + indexing_map.GetSymbolRanges(), + indexing_map.GetConstraintsCount()); +} } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc index 1905c95332ce7b..d9eef56cd0749b 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc @@ -30,9 +30,6 @@ namespace gpu { namespace { using ::testing::ElementsAre; -using ::testing::HasSubstr; -using ::testing::IsEmpty; -using ::testing::UnorderedElementsAre; class IndexingMapTest : public HloTestBase { public: @@ -40,117 +37,411 @@ class IndexingMapTest : public HloTestBase { AffineMapPrinter printer_; }; -TEST_F(IndexingMapTest, ComposeWithPermutation) { - IndexingMap producer{ +TEST_F(IndexingMapTest, Evaluation) { + IndexingMap indexing_map = IndexingMap::FromTensorSizes( ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", &mlir_context_), - Domain::FromUpperBounds({4, 4}, {2, 2})}; + {4, 4}, {2, 2}); - IndexingMap consumer{ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), - Domain::FromUpperBounds({4}, {4})}; + auto results = indexing_map.Evaluate( + mlir::getAffineConstantExprs({1, 2}, &mlir_context_), + mlir::getAffineConstantExprs({3, 4}, &mlir_context_)); + EXPECT_THAT(results, ElementsAre(2, 1, 4, 3)); - auto composed = ComposeIndexingMaps(producer, consumer); - EXPECT_THAT(composed, - MatchIndexingMap( - "(d0)[s0, s1, s2] -> (s2, d0, s1, s0)", - MatchDomain(ElementsAre(MatchRange(0, 3)), - ElementsAre(MatchRange(0, 1), MatchRange(0, 1), - MatchRange(0, 3))))); + auto feasible = indexing_map.ConstraintsSatisfied( + mlir::getAffineConstantExprs({1, 2}, &mlir_context_), + mlir::getAffineConstantExprs({3, 4}, &mlir_context_)); + EXPECT_TRUE(feasible); + + indexing_map.AddConstraint(ParseAffineExpr("s0 mod 4", &mlir_context_), + Range{0, 0}); + + auto infeasible = indexing_map.ConstraintsSatisfied( + mlir::getAffineConstantExprs({1, 2}, &mlir_context_), + mlir::getAffineConstantExprs({5, 4}, &mlir_context_)); + EXPECT_FALSE(infeasible); } -TEST_F(IndexingMapTest, ComposeWithRestrictedRange) { - IndexingMap producer{ +TEST_F(IndexingMapTest, Composition_Permutation) { + IndexingMap producer = IndexingMap::FromTensorSizes( ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", &mlir_context_), - Domain::FromUpperBounds({5, 6}, {7, 2})}; + {4, 4}, {2, 2}); + + IndexingMap consumer = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), {4}, {4}); + + auto composed = ComposeIndexingMaps(consumer, producer); + EXPECT_THAT(composed, MatchIndexingMap(R"( + (d0)[s0, s1, s2] -> (s2, d0, s1, s0) + domain: + d0 in [0, 3] + s0 in [0, 1] + s1 in [0, 1] + s2 in [0, 3] + )")); +} - IndexingMap consumer{ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), - Domain::FromUpperBounds({10}, {8})}; +TEST_F(IndexingMapTest, Composition_RestrictedRange) { + IndexingMap producer = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", &mlir_context_), + {5, 6}, {7, 2}); + + IndexingMap consumer = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), {10}, {8}); + + auto composed = ComposeIndexingMaps(consumer, producer); + EXPECT_THAT(composed, MatchIndexingMap(R"( + (d0)[s0, s1, s2] -> (s2, d0, s1, s0) + domain: + d0 in [0, 4] + s0 in [0, 6] + s1 in [0, 1] + s2 in [0, 5] + )")); +} - auto composed = ComposeIndexingMaps(producer, consumer); - EXPECT_THAT(composed, - MatchIndexingMap( - "(d0)[s0, s1, s2] -> (s2, d0, s1, s0)", - MatchDomain(ElementsAre(MatchRange(0, 4)), - ElementsAre(MatchRange(0, 5), MatchRange(0, 1), - MatchRange(0, 7))))); +TEST_F(IndexingMapTest, Composition_ProducerAndConsumerHaveConstraints) { + IndexingMap producer = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", &mlir_context_), + {50, 60}, {70, 20}); + producer.AddConstraint(ParseAffineExpr("d0 mod 8", &mlir_context_), + Range{0, 0}); + producer.AddConstraint(ParseAffineExpr("s0 mod 3", &mlir_context_), + Range{1, 1}); + + IndexingMap consumer = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), {10}, {8}); + consumer.AddConstraint(ParseAffineExpr("d0 + s0", &mlir_context_), + Range{0, 20}); + consumer.AddConstraint(ParseAffineExpr("s0 mod 4", &mlir_context_), + Range{0, 0}); + + auto composed = ComposeIndexingMaps(consumer, producer); + EXPECT_THAT(composed, MatchIndexingMap(R"( + (d0)[s0, s1, s2] -> (s2, d0, s1, s0) + domain: + d0 in [0, 9] + s0 in [0, 69] + s1 in [0, 19] + s2 in [0, 7] + d0 + s2 in [0, 20] + d0 mod 8 in [0, 0] + s0 mod 3 in [1, 1] + s2 mod 4 in [0, 0] + )")); + composed.Simplify(); + EXPECT_THAT(composed, MatchIndexingMap(R"( + (d0)[s0, s1, s2] -> (s2, d0, s1, s0) + domain: + d0 in [0, 9] + s0 in [0, 69] + s1 in [0, 19] + s2 in [0, 7] + d0 mod 8 in [0, 0] + s0 mod 3 in [1, 1] + s2 mod 4 in [0, 0] + )")); } -TEST_F(IndexingMapTest, ComposeWithAddedConstraint) { - IndexingMap producer{ParseAffineMap("(d0) -> (d0)", &mlir_context_), - Domain::FromUpperBounds({2}, {})}; +TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesSymbol) { + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1)", &mlir_context_), + {50, 60}, {70, 20}); + // This constraint cannot be removed, because it contains a "used symbol". + indexing_map.AddConstraint(ParseAffineExpr("s0 + s1", &mlir_context_), + Range{1, 100}); + indexing_map.AddConstraint(ParseAffineExpr("s0 mod 3", &mlir_context_), + Range{0, 0}); + indexing_map.RemoveUnusedSymbols(); + EXPECT_THAT(indexing_map, MatchIndexingMap(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1) + domain: + d0 in [0, 49] + d1 in [0, 59] + s0 in [0, 69] + s1 in [0, 19] + s0 + s1 in [1, 100] + s0 mod 3 in [0, 0] + )")); +} - IndexingMap consumer{ParseAffineMap("(d0) -> (d0 mod 8)", &mlir_context_), - Domain::FromUpperBounds({100}, {})}; +TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesOnlyUnusedSymbols) { + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1)", &mlir_context_), + {50, 60}, {70, 20}); + // This constraint can be removed, because it contains only the unused symbol. + indexing_map.AddConstraint(ParseAffineExpr("s0 mod 3", &mlir_context_), + Range{0, 0}); + indexing_map.RemoveUnusedSymbols(); + EXPECT_THAT(indexing_map, MatchIndexingMap(R"( + (d0, d1)[s0] -> (d1, d0, s0) + domain: + d0 in [0, 49] + d1 in [0, 59] + s0 in [0, 19] + )")); +} - auto composed = ComposeIndexingMaps(producer, consumer); - EXPECT_THAT(composed, - MatchIndexingMap("(d0) -> (d0 mod 8)", - MatchDomainWithGenericConstraints( - ElementsAre(MatchRange(0, 99)), IsEmpty(), - UnorderedElementsAre(MatchExprRange( - "d0 mod 8", MatchRange(0, 1)))))); +TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintsWithManySymbols) { + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0)[s0, s1, s2, s3, s4] -> (d0 * 4 + s1 + s3 - 42)", + &mlir_context_), + {32}, {1, 2, 3, 4, 5}); + indexing_map.AddConstraint( + ParseAffineExpr("d0 * 4 + s1 + s3", &mlir_context_), Range{24, 459}); + indexing_map.RemoveUnusedSymbols(); + // Symbols s0, s2, s4 will be removed and s1 and s3 will become s0 and s1. + EXPECT_THAT(indexing_map, MatchIndexingMap(R"( + (d0)[s0, s1] -> (d0 * 4 + s0 + s1 - 42) + domain: + d0 in [0, 31] + s0 in [0, 1] + s1 in [0, 3] + d0 * 4 + s0 + s1 in [24, 459] + )")); } -TEST_F(IndexingMapTest, SimplifyConstantDims) { - IndexingMap indexing_map{ParseAffineMap("(d0) -> (d0)", &mlir_context_), - Domain{{Range{5, 5}}, {}}}; +TEST_F(IndexingMapTest, ConstraintRangeSimplification_Sum) { + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0) -> (d0)", &mlir_context_), {100}, {}); + + indexing_map.AddConstraint(ParseAffineExpr("(d0 mod 8) + 5", &mlir_context_), + Range{50, 54}); + + EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + (d0) -> (d0) + domain: + d0 in [0, 99] + d0 mod 8 in [45, 49] + )")); +} + +TEST_F(IndexingMapTest, + ConstraintRangeSimplification_FloorDivPositiveDivisorPositiveBounds) { + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0) -> (d0)", &mlir_context_), {100}, {}); + + indexing_map.AddConstraint(ParseAffineExpr("d0 floordiv 8", &mlir_context_), + Range{5, 11}); + EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + (d0) -> (d0) + domain: + d0 in [40, 95] + )")); +} + +TEST_F(IndexingMapTest, + ConstraintRangeSimplification_FloorDivPositiveDivisorNegativeBounds) { + IndexingMap indexing_map = + IndexingMap(ParseAffineMap("(d0)[s0] -> (d0)", &mlir_context_), + {Range{0, 99}}, {Range{-99, 99}}); + + indexing_map.AddConstraint(ParseAffineExpr("s0 floordiv 3", &mlir_context_), + Range{-11, -5}); + EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + (d0)[s0] -> (d0) + domain: + d0 in [0, 99] + s0 in [-33, -13] + )")); +} + +TEST_F(IndexingMapTest, + ConstraintRangeSimplification_FloorDivNegativeDivisorNegativeBounds) { + IndexingMap indexing_map = + IndexingMap(ParseAffineMap("(d0)[s0] -> (d0)", &mlir_context_), + {Range{0, 99}}, {Range{-99, 99}}); + + indexing_map.AddConstraint(ParseAffineExpr("s0 floordiv -3", &mlir_context_), + Range{-11, -5}); + EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + (d0)[s0] -> (d0) + domain: + d0 in [0, 99] + s0 in [15, 35] + )")); +} + +TEST_F(IndexingMapTest, + ConstraintRangeSimplification_MulPositiveMultiplierPositiveBounds) { + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0) -> (d0)", &mlir_context_), {100}, {}); + + indexing_map.AddConstraint(ParseAffineExpr("d0 * 8", &mlir_context_), + Range{14, 33}); + EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + (d0) -> (d0) + domain: + d0 in [2, 4] + )")); +} + +TEST_F(IndexingMapTest, + ConstraintRangeSimplification_MulPositiveMultiplierNegativeBounds) { + IndexingMap indexing_map = + IndexingMap(ParseAffineMap("(d0)[s0] -> (d0)", &mlir_context_), + {Range{0, 99}}, {Range{-99, 99}}); + + indexing_map.AddConstraint(ParseAffineExpr("s0 * 3", &mlir_context_), + Range{-11, -5}); + EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + (d0)[s0] -> (d0) + domain: + d0 in [0, 99] + s0 in [-3, -2] + )")); +} + +TEST_F(IndexingMapTest, + ConstraintRangeSimplification_MulNegativeMultiplierNegativeBounds) { + IndexingMap indexing_map = + IndexingMap(ParseAffineMap("(d0)[s0] -> (d0)", &mlir_context_), + {Range{0, 99}}, {Range{-99, 99}}); + + indexing_map.AddConstraint(ParseAffineExpr("s0 * -3", &mlir_context_), + Range{-11, -5}); + EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + (d0)[s0] -> (d0) + domain: + d0 in [0, 99] + s0 in [2, 3] + )")); +} + +TEST_F(IndexingMapTest, AffineMapSimplification_ConstantDims) { + IndexingMap indexing_map = IndexingMap( + ParseAffineMap("(d0) -> (d0)", &mlir_context_), {Range{5, 5}}, {}); indexing_map.Simplify(); - EXPECT_THAT(printer_.ToString(indexing_map.affine_map), - HasSubstr("(d0) -> (5)")); + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + (d0) -> (5) + domain: + d0 in [5, 5] + )")); } -TEST_F(IndexingMapTest, SimplifyDivsAndModsIfSmallerThanDivisor) { +TEST_F(IndexingMapTest, + AffineMapSimplification_DivsAndModsIfSmallerThanDivisor) { auto serialized_map = "(d0, d1) -> (d0 + d1 floordiv 16, d1 mod 16)"; - IndexingMap indexing_map{ParseAffineMap(serialized_map, &mlir_context_), - Domain::FromUpperBounds({8, 16}, {})}; + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap(serialized_map, &mlir_context_), {8, 16}, {}); indexing_map.Simplify(); - - EXPECT_THAT(printer_.ToString(indexing_map.affine_map), - HasSubstr("(d0, d1) -> (d0, d1)")); + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + (d0, d1) -> (d0, d1) + domain: + d0 in [0, 7] + d1 in [0, 15] + )")); } -TEST_F(IndexingMapTest, SimplifyDivsAndModsWithMultipliers) { +TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithMultipliers) { auto serialized_map = "(d0, d1, d2) -> ((d0 * 100 + d1 * 10 + d2) floordiv 100, " "((d0 * 100 + d1 * 10 + d2) mod 100) floordiv 10, " "d2 mod 10)"; - IndexingMap indexing_map{ParseAffineMap(serialized_map, &mlir_context_), - Domain::FromUpperBounds({9, 9, 9}, {})}; + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap(serialized_map, &mlir_context_), {9, 9, 9}, {}); indexing_map.Simplify(); - EXPECT_THAT(printer_.ToString(indexing_map.affine_map), - HasSubstr("(d0, d1, d2) -> (d0, d1, d2)")); + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + (d0, d1, d2) -> (d0, d1, d2) + domain: + d0 in [0, 8] + d1 in [0, 8] + d2 in [0, 8] + )")); } -TEST_F(IndexingMapTest, SimplifyDivsAndModsWithDivisibleMultipliers) { +TEST_F(IndexingMapTest, + AffineMapSimplification_DivsAndModsWithDivisibleMultipliers) { auto serialized_map = "(d0, d1, d2) -> ((d0 * 16 + d1 * 4 + d2) floordiv 8, " - "(d0 * 16 + d1 * 4 + d2) mod 8)"; + " (d0 * 16 + d1 * 4 + d2) mod 8)"; - IndexingMap indexing_map{ParseAffineMap(serialized_map, &mlir_context_), - Domain::FromUpperBounds({10, 10, 10}, {})}; + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap(serialized_map, &mlir_context_), {10, 10, 10}, {}); indexing_map.Simplify(); - - EXPECT_THAT(printer_.ToString(indexing_map.affine_map), - HasSubstr("(d0, d1, d2) -> (d0 * 2 + (d1 * 4 + d2) floordiv 8, " - "(d1 * 4 + d2) mod 8)")); + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + (d0, d1, d2) -> (d0 * 2 + (d1 + d2 floordiv 4) floordiv 2, (d1 * 4 + d2) mod 8) + domain: + d0 in [0, 9] + d1 in [0, 9] + d2 in [0, 9] + )")); } -TEST_F(IndexingMapTest, SimplifyDivsAndModsWithReverse) { +TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithReverse) { auto serialized_map = "(d0, d1) -> (-((d0 * -11 - d1 + 109) floordiv 11) + 9, " "d0 * 11 + d1 + ((d0 * -11 - d1 + 109) floordiv 11) * 11 - 99)"; - IndexingMap indexing_map{ParseAffineMap(serialized_map, &mlir_context_), - Domain::FromUpperBounds({8, 9}, {})}; + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap(serialized_map, &mlir_context_), {8, 9}, {}); + indexing_map.Simplify(); + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + (d0, d1) -> (d0, d1) + domain: + d0 in [0, 7] + d1 in [0, 8] + )")); +} + +TEST_F(IndexingMapTest, AffineMapSimplification_DivsInSequence) { + auto serialized_map = + "()[s0] -> (s0 - ((s0 floordiv 2) floordiv 7) * 14 + (s0 floordiv 14) * " + "14)"; + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap(serialized_map, &mlir_context_), {}, {1234}); + indexing_map.Simplify(); + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + ()[s0] -> (s0) + domain: + s0 in [0, 1233] + )")); +} + +TEST_F(IndexingMapTest, AffineMapSimplification_DivGcdGreater1) { + auto serialized_map = + "()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 - ((s0 * 2 + s1 floordiv 64) " + "floordiv 3) * 768 + ((s0 * 128 + s1) floordiv 192) * 768)"; + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap(serialized_map, &mlir_context_), {}, {1234, 128, 4}); indexing_map.Simplify(); + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + ()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2) + domain: + s0 in [0, 1233] + s1 in [0, 127] + s2 in [0, 3] + )")); +} - EXPECT_THAT(printer_.ToString(indexing_map.affine_map), - HasSubstr("(d0, d1) -> (d0, d1)")); +TEST_F(IndexingMapTest, AffineMapSimplification_ExtractFromMod) { + auto serialized_map = + "()[s0, s1, s2, s3] -> ((s0 * 458752 + s1 + s2 * 4 + s3 * 512) mod " + "20000)"; + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap(serialized_map, &mlir_context_), {}, {872, 4, 128, 896}); + indexing_map.Simplify(); + // TODO(jreiffers): Get rid of the division here. The important thing is that + // s1 was extracted from the mod and is not in the subtracted value, but we'd + // prefer the result to be: + // (s0 * 458752 + s2 * 4 + s3 * 512) mod 20000 + s1 + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + ()[s0, s1, s2, s3] -> ( + s0 * 458752 + s1 + s2 * 4 + s3 * 512 - + ((s0 * 14336 + s3 * 16 + s2 floordiv 8) floordiv 625) * 20000 + ) + domain: + s0 in [0, 871] + s1 in [0, 3] + s2 in [0, 127] + s3 in [0, 895] + )")); } TEST_F(IndexingMapTest, RangeEvaluatorTest) { - Domain domain({Range{0, 9}, Range{-10, -1}, Range{-1, 2}, Range{0, 0}}, {}); - RangeEvaluator range_evaluator(&domain); + RangeEvaluator range_evaluator( + {Range{0, 9}, Range{-10, -1}, Range{-1, 2}, Range{0, 0}}, {}, + &mlir_context_); mlir::AffineExpr d0, d1, d2, d3; bindDims(&mlir_context_, d0, d1, d2, d3); @@ -171,24 +462,9 @@ TEST_F(IndexingMapTest, RangeEvaluatorTest) { EXPECT_TRUE(range_evaluator.IsAlwaysNegativeOrZero(d3)); } -// TODO(b/313840171): Simplify `(d1 * 4 + d2) floordiv 8` to `d1 floordiv 2`. - -// TODO(b/313840171): Simplify `(d0 * 8 + d1) floordiv 16` to `d0 floordiv 2`. - // TODO(b/313840171): Simplify `((d0 * 8 + d1) mod 16) floordiv 4` to // `((d0 * 8 + d1) floordiv 4) mod 4` to `(d0 * 2 + d1 floordiv 4) mod 4`. -TEST_F(IndexingMapTest, AffineMapPrinterTest) { - auto map = - ParseAffineMap("(d0, d1)[s0, s1] -> (d0 + d1 floordiv 8, s0 + s1 mod 16)", - &mlir_context_); - printer_.SetDimensionName(0, "offset"); - printer_.SetSymbolName(1, "linear_index"); - EXPECT_THAT(printer_.ToString(map), - HasSubstr("(offset, d1)[s0, linear_index] -> " - "(offset + d1 floordiv 8, s0 + linear_index mod 16)")); -} - } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/indexing_test_utils.cc b/third_party/xla/xla/service/gpu/model/indexing_test_utils.cc index f8ced4105c22a3..5285ea4e3e98ad 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_test_utils.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_test_utils.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "llvm/ADT/STLExtras.h" #include "mlir/AsmParser/AsmParser.h" // from @llvm-project +#include "mlir/IR/AffineExpr.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "xla/hlo/ir/hlo_instruction.h" @@ -37,11 +38,12 @@ limitations under the License. namespace xla { namespace gpu { +using ::mlir::AffineExpr; using ::mlir::AffineMap; -using ::mlir::AffineMapAttr; +using ::mlir::MLIRContext; HloInstructionIndexing ComputeOutputToInputIndexingForEntryComputation( - HloTestBase* test_base, mlir::MLIRContext* mlir_context, + HloTestBase* test_base, MLIRContext* mlir_context, absl::string_view hlo_string, int output_id, bool use_physical_layout) { auto module = test_base->ParseAndReturnVerifiedModule(hlo_string); EXPECT_TRUE(module.ok()); @@ -61,26 +63,24 @@ HloInstructionIndexing ComputeOutputToInputIndexingForEntryComputation( if (!use_physical_layout) return indexing; - std::optional output_permutation = - GetIndexingMapFromPhysicalLayoutToLogical(GetOutputShape(root, output_id), - mlir_context); + IndexingMap output_permutation = GetIndexingMapFromPhysicalLayoutToLogical( + GetOutputShape(root, output_id), mlir_context); for (const auto& [operand_id, indexing_maps] : llvm::enumerate(indexing.indexing_maps)) { - std::optional operand_permutation = - GetIndexingMapFromLogicalToPhysicalLayout( - root->operand(operand_id)->shape(), mlir_context); + IndexingMap operand_permutation = GetIndexingMapFromLogicalToPhysicalLayout( + root->operand(operand_id)->shape(), mlir_context); - absl::flat_hash_set> operand_indexing_maps; - for (const std::optional& indexing_map : indexing_maps) { + absl::flat_hash_set operand_indexing_maps; + for (const IndexingMap& indexing_map : indexing_maps) { auto normalized_indexing_map = indexing_map; - if (output_permutation.has_value()) { + if (!output_permutation.GetAffineMap().isIdentity()) { normalized_indexing_map = - ComposeIndexingMaps(normalized_indexing_map, output_permutation); + ComposeIndexingMaps(output_permutation, normalized_indexing_map); } - if (operand_permutation.has_value()) { + if (!operand_permutation.GetAffineMap().isIdentity()) { normalized_indexing_map = - ComposeIndexingMaps(operand_permutation, normalized_indexing_map); + ComposeIndexingMaps(normalized_indexing_map, operand_permutation); } operand_indexing_maps.insert(normalized_indexing_map); } @@ -90,7 +90,7 @@ HloInstructionIndexing ComputeOutputToInputIndexingForEntryComputation( } HloInstructionIndexing ComputeInputToOutputIndexingForEntryComputation( - HloTestBase* test_base, mlir::MLIRContext* mlir_context, + HloTestBase* test_base, MLIRContext* mlir_context, absl::string_view hlo_string, int input_id, bool use_physical_layout) { auto module = test_base->ParseAndReturnVerifiedModule(hlo_string); EXPECT_TRUE(module.ok()); @@ -110,26 +110,24 @@ HloInstructionIndexing ComputeInputToOutputIndexingForEntryComputation( if (!use_physical_layout) return indexing; - std::optional input_permutation = - GetIndexingMapFromPhysicalLayoutToLogical( - root->operand(input_id)->shape(), mlir_context); + IndexingMap input_permutation = GetIndexingMapFromPhysicalLayoutToLogical( + root->operand(input_id)->shape(), mlir_context); for (const auto& [output_id, indexing_maps] : llvm::enumerate(indexing.indexing_maps)) { - std::optional operand_permutation = - GetIndexingMapFromLogicalToPhysicalLayout( - GetOutputShape(root, output_id), mlir_context); + IndexingMap operand_permutation = GetIndexingMapFromLogicalToPhysicalLayout( + GetOutputShape(root, output_id), mlir_context); - absl::flat_hash_set> operand_indexing_maps; - for (const std::optional& indexing_map : indexing_maps) { + absl::flat_hash_set operand_indexing_maps; + for (const IndexingMap& indexing_map : indexing_maps) { auto normalized_indexing_map = indexing_map; - if (input_permutation.has_value()) { + if (!input_permutation.GetAffineMap().isIdentity()) { normalized_indexing_map = - ComposeIndexingMaps(normalized_indexing_map, input_permutation); + ComposeIndexingMaps(input_permutation, normalized_indexing_map); } - if (operand_permutation.has_value()) { + if (!operand_permutation.GetAffineMap().isIdentity()) { normalized_indexing_map = - ComposeIndexingMaps(operand_permutation, normalized_indexing_map); + ComposeIndexingMaps(normalized_indexing_map, operand_permutation); } operand_indexing_maps.insert(normalized_indexing_map); } @@ -139,14 +137,28 @@ HloInstructionIndexing ComputeInputToOutputIndexingForEntryComputation( } AffineMap ParseAffineMap(absl::string_view serialized_affine_map, - mlir::MLIRContext* context) { + MLIRContext* context) { std::string full_affine_map_string = absl::StrCat("affine_map<", serialized_affine_map, ">"); return mlir::parseAttribute(full_affine_map_string, context) - .cast() + .cast() .getValue(); } +// Since MLIR does not have AffineExprAttr, we construct an AffineMap and then +// retrieve its first result. +AffineExpr ParseAffineExpr(absl::string_view serialized_affine_expr, + MLIRContext* context) { + std::string full_affine_map_string = absl::StrCat( + "affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9)" + "[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (", + serialized_affine_expr, ")>"); + return mlir::parseAttribute(full_affine_map_string, context) + .cast() + .getValue() + .getResult(0); +} + bool ApproximateMatch(std::string_view lhs, std::string_view rhs) { size_t lhs_length = lhs.size(); size_t rhs_length = rhs.size(); diff --git a/third_party/xla/xla/service/gpu/model/indexing_test_utils.h b/third_party/xla/xla/service/gpu/model/indexing_test_utils.h index 7962171d351383..f2e4ac2056c0a9 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_test_utils.h +++ b/third_party/xla/xla/service/gpu/model/indexing_test_utils.h @@ -29,60 +29,21 @@ limitations under the License. namespace xla { namespace gpu { -MATCHER_P2(MatchRange, lower_bound, upper_bound, - absl::StrCat(negation ? "equals " : "doesn't equal ", "range [", - lower_bound, ", ", upper_bound, "]")) { - return ExplainMatchResult(::testing::FieldsAre(lower_bound, upper_bound), arg, - result_listener); -} - -MATCHER_P2(MatchDomain, dim_ranges, symbol_ranges, "") { - return ExplainMatchResult(0, arg.GetExprCount(), result_listener) && - ExplainMatchResult(dim_ranges, arg.GetDimensionRanges(), - result_listener) && - ExplainMatchResult(symbol_ranges, arg.GetSymbolRanges(), - result_listener); -} - -MATCHER(IsEmptyDomain, "") { - return ExplainMatchResult(true, arg.IsKnownEmpty(), result_listener); -} - -MATCHER_P3(MatchDomainWithGenericConstraints, dim_ranges, symbol_ranges, - expr_ranges, "") { - return ExplainMatchResult(dim_ranges, arg.GetDimensionRanges(), - result_listener) && - ExplainMatchResult(symbol_ranges, arg.GetSymbolRanges(), - result_listener) && - ExplainMatchResult(expr_ranges, arg.GetExprRanges(), result_listener); -} +// Matches two strings ignoring whitespaces. +bool ApproximateMatch(std::string_view lhs, std::string_view rhs); -MATCHER_P2(MatchExprRange, affine_map_string, range, "") { - return ExplainMatchResult(::testing::HasSubstr(affine_map_string), - AffineMapPrinter().ToString(arg.first), - result_listener) && - ExplainMatchResult(range, arg.second, result_listener); -} +MATCHER(UndefinedMap, "") { return arg.IsUndefined(); } -MATCHER_P2(MatchIndexingMap, affine_map_string, domain, "") { - if (!arg.has_value()) { +MATCHER_P(MatchIndexingMap, indexing_string, "") { + if (arg.IsUndefined()) { return false; } - return ExplainMatchResult(::testing::HasSubstr(affine_map_string), - AffineMapPrinter().ToString(arg->affine_map), - result_listener) && - ExplainMatchResult(domain, arg->domain, result_listener); + return ExplainMatchResult( + true, ApproximateMatch(indexing_string, arg.ToString()), result_listener); } -// Matches two strings ignoring whitespaces. -bool ApproximateMatch(std::string_view lhs, std::string_view rhs); - MATCHER_P(MatchIndexingString, indexing_string, "") { - if (!arg.has_value()) { - return false; - } - return ExplainMatchResult(true, - ApproximateMatch(indexing_string, arg->ToString()), + return ExplainMatchResult(true, ApproximateMatch(indexing_string, arg), result_listener); } @@ -99,6 +60,9 @@ HloInstructionIndexing ComputeInputToOutputIndexingForEntryComputation( mlir::AffineMap ParseAffineMap(absl::string_view serialized_affine_map, mlir::MLIRContext* context); +mlir::AffineExpr ParseAffineExpr(absl::string_view serialized_affine_expr, + mlir::MLIRContext* context); + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/tile_analysis.cc b/third_party/xla/xla/service/gpu/model/tile_analysis.cc index 8badaba5b09972..24456afc75bf97 100644 --- a/third_party/xla/xla/service/gpu/model/tile_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/tile_analysis.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" #include "mlir/IR/AffineExpr.h" // from @llvm-project @@ -37,6 +38,7 @@ namespace xla { namespace gpu { namespace { +using absl::StrCat; using mlir::AffineDimExpr; using mlir::AffineExpr; using mlir::AffineMap; @@ -172,7 +174,7 @@ AffineMap SubstituteAllIndicesAndKnownSymbolsWithSameValue( // symbols, since they will have been replaced by constants. std::optional RawSymbolicTileFromIndexingMap( const IndexingMap& indexing_map) { - AffineMap affine_map = indexing_map.affine_map; + AffineMap affine_map = indexing_map.GetAffineMap(); if (!AffineMapDescribesTile(affine_map)) { return std::nullopt; } @@ -235,7 +237,7 @@ std::optional RawSymbolicTileFromIndexingMap( if (symbol_expr && symbol_expr.getPosition() < num_known_symbols) { CHECK(!size_expr); const Range& symbol_range = - indexing_map.domain.GetSymbolRange(symbol_expr.getPosition()); + indexing_map.GetSymbolRange(symbol_expr.getPosition()); size_expr = getAffineConstantExpr( symbol_range.upper_bound - symbol_range.lower_bound + 1, mlir_context); @@ -257,7 +259,9 @@ std::optional RawSymbolicTileFromIndexingMap( offset_expressions.reserve(num_results); std::vector stride_expressions; stride_expressions.reserve(num_results); - RangeEvaluator range_evaluator(&indexing_map.domain); + RangeEvaluator range_evaluator(indexing_map.GetDimensionRanges(), + indexing_map.GetSymbolRanges(), + indexing_map.GetMLIRContext()); for (auto [offset_expr, stride_expr, size_expr] : llvm::zip(unnormalized_offset_expressions, signed_stride_expressions, size_expressions)) { @@ -292,8 +296,8 @@ std::optional RawSymbolicTileFromIndexingMap( /*static*/ std::optional SymbolicTile::FromIndexingMap( const IndexingMap& indexing_map) { - MLIRContext* mlir_context = indexing_map.affine_map.getContext(); - int64_t num_input_dims = indexing_map.domain.GetDimensionCount(); + MLIRContext* mlir_context = indexing_map.GetAffineMap().getContext(); + int64_t num_input_dims = indexing_map.GetDimensionCount(); std::vector exprs; exprs.reserve(num_input_dims); @@ -301,11 +305,11 @@ std::optional RawSymbolicTileFromIndexingMap( tile_dimension_ranges.reserve(num_input_dims); std::vector tile_symbol_ranges; tile_symbol_ranges.reserve(kNumTileParametersPerInputDim * num_input_dims + - indexing_map.affine_map.getNumSymbols()); + indexing_map.GetAffineMap().getNumSymbols()); // The symbols declared in 'indexing_map.affine_map' will precede those // defined in the producer map we construct here. - absl::c_copy(indexing_map.domain.GetSymbolRanges(), + absl::c_copy(indexing_map.GetSymbolRanges(), std::back_inserter(tile_symbol_ranges)); // For each input dims we add kNumTileParametersPerInputDim = 3 symbols, as @@ -320,7 +324,7 @@ std::optional RawSymbolicTileFromIndexingMap( exprs.push_back(offset + stride * index); - Range range = indexing_map.domain.GetDimensionRange(dim); + Range range = indexing_map.GetDimensionRange(dim); tile_dimension_ranges.push_back(range); for (int64_t symbol_index = 0; symbol_index < kNumTileParametersPerInputDim; @@ -333,9 +337,10 @@ std::optional RawSymbolicTileFromIndexingMap( num_input_dims, kNumTileParametersPerInputDim * num_input_dims, exprs, mlir_context); - IndexingMap composed_indexing_map{ - .affine_map = indexing_map.affine_map.compose(producer_map), - .domain = Domain(tile_dimension_ranges, tile_symbol_ranges)}; + IndexingMap composed_indexing_map( + indexing_map.GetAffineMap().compose(producer_map), tile_dimension_ranges, + tile_symbol_ranges); + composed_indexing_map.Simplify(); std::optional maybe_raw_symbolic_tile = @@ -371,6 +376,19 @@ void SymbolicTile::Print(std::ostream& out, std::ostream& operator<<(std::ostream& out, const SymbolicTile& symbolic_tile) { AffineMapPrinter printer; + + // This utilizes the assumption that symbols are structured as triplets, i.e. + // [offset0, size0, stride0, ... offset{N-1}, size{N-1}, stride{N-1}] + // where N is the tensor rank. + for (int64_t triplet_start = 0; + triplet_start < symbolic_tile.offset_map().getNumSymbols(); + triplet_start += kNumTileParametersPerInputDim) { + int64_t triplet_idx = triplet_start / kNumTileParametersPerInputDim; + printer.SetSymbolName(triplet_start, StrCat("offset", triplet_idx)); + printer.SetSymbolName(triplet_start + 1, StrCat("size", triplet_idx)); + printer.SetSymbolName(triplet_start + 2, StrCat("stride", triplet_idx)); + } + symbolic_tile.Print(out, printer); return out; } diff --git a/third_party/xla/xla/service/gpu/model/tile_analysis.h b/third_party/xla/xla/service/gpu/model/tile_analysis.h index 3dcb5b2b06231e..d9d66925484512 100644 --- a/third_party/xla/xla/service/gpu/model/tile_analysis.h +++ b/third_party/xla/xla/service/gpu/model/tile_analysis.h @@ -73,6 +73,8 @@ class SymbolicTile { : offset_map_(offset_map), size_map_(size_map), stride_map_(stride_map) {} }; +// Prints symbolic_tile with triplet labels for each symbol. +// i.e. a symbol si which corresponds to an offset will be labeled offseti. std::ostream& operator<<(std::ostream& out, const SymbolicTile& symbolic_tile); } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc index 1ab76a8c87a19f..e138d0ff08299e 100644 --- a/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc @@ -16,6 +16,8 @@ limitations under the License. #include "xla/service/gpu/model/tile_analysis.h" #include +#include +#include #include #include @@ -33,6 +35,7 @@ namespace gpu { namespace { using ::testing::ExplainMatchResult; +using ::testing::HasSubstr; using ::testing::Optional; using ::testing::StrEq; @@ -65,9 +68,7 @@ class SymbolicTileTest : public HloTestBase { mlir::MLIRContext mlir_context_; }; - -TEST_F(SymbolicTileTest, - CanPropagateTileFromDotOutputToInputsWithoutSpecializedTileSizes) { +TEST_F(SymbolicTileTest, CanPropagateTileFromDotOutputToInputs) { auto input_indexing = GetOutputToInputIndexingForEntryComputation(R"( HloModule m ENTRY e { @@ -80,14 +81,14 @@ TEST_F(SymbolicTileTest, )"); EXPECT_THAT( - SymbolicTile::FromIndexingMap(**input_indexing.indexing_maps[0].begin()), + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), Optional(MatchSymbolicTile( "()[s0, s1, s2, s3, s4, s5, s6, s7, s8] -> (s0, s3, 0)", "()[s0, s1, s2, s3, s4, s5, s6, s7, s8] -> (s1, s4, 19)", "()[s0, s1, s2, s3, s4, s5, s6, s7, s8] -> (s2, s5, 1)"))); EXPECT_THAT( - SymbolicTile::FromIndexingMap(**input_indexing.indexing_maps[1].begin()), + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[1].begin()), Optional(MatchSymbolicTile( "()[s0, s1, s2, s3, s4, s5, s6, s7, s8] -> (s0, 0, s6)", "()[s0, s1, s2, s3, s4, s5, s6, s7, s8] -> (s1, 19, s7)", @@ -104,7 +105,7 @@ TEST_F(SymbolicTileTest, CanPropagateTileThroughTrivialReshape) { )"); EXPECT_THAT( - SymbolicTile::FromIndexingMap(**input_indexing.indexing_maps[0].begin()), + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), Optional(MatchSymbolicTile( "()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] " "-> (s3, s6, s9)", @@ -114,8 +115,7 @@ TEST_F(SymbolicTileTest, CanPropagateTileThroughTrivialReshape) { "-> (s5, s8, s11)"))); } -TEST_F(SymbolicTileTest, - FailsToPropagateTileThroughReshapeWithoutSpecializedTileSizes) { +TEST_F(SymbolicTileTest, FailsToPropagateTileThroughReshape) { auto input_indexing = GetOutputToInputIndexingForEntryComputation(R"( HloModule m ENTRY e { @@ -125,12 +125,11 @@ TEST_F(SymbolicTileTest, )"); EXPECT_EQ( - SymbolicTile::FromIndexingMap(**input_indexing.indexing_maps[0].begin()), + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), std::nullopt); } -TEST_F(SymbolicTileTest, - CanPropagateTileThroughElementwiseOpWithoutSpecializedTileSizes) { +TEST_F(SymbolicTileTest, CanPropagateTileThroughElementwiseOp) { auto input_indexing = GetOutputToInputIndexingForEntryComputation(R"( HloModule m ENTRY e { @@ -141,14 +140,13 @@ TEST_F(SymbolicTileTest, )"); EXPECT_THAT( - SymbolicTile::FromIndexingMap(**input_indexing.indexing_maps[0].begin()), + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), Optional(MatchSymbolicTile("()[s0, s1, s2] -> (s0)", "()[s0, s1, s2] -> (s1)", "()[s0, s1, s2] -> (s2)"))); } -TEST_F(SymbolicTileTest, - CanPropagateTileFromBroadcastOutputToInputWithoutSpecializedTileSizes) { +TEST_F(SymbolicTileTest, CanPropagateTileFromBroadcastOutputToInput) { auto input_indexing = GetOutputToInputIndexingForEntryComputation(R"( HloModule m ENTRY e { @@ -158,14 +156,13 @@ TEST_F(SymbolicTileTest, )"); EXPECT_THAT( - SymbolicTile::FromIndexingMap(**input_indexing.indexing_maps[0].begin()), + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), Optional(MatchSymbolicTile("()[s0, s1, s2, s3, s4, s5] -> (s3)", "()[s0, s1, s2, s3, s4, s5] -> (s4)", "()[s0, s1, s2, s3, s4, s5] -> (s5)"))); } -TEST_F(SymbolicTileTest, - CanPropagateTileFromReduceOutputToInputWithoutSpecializedTileSizes) { +TEST_F(SymbolicTileTest, CanPropagateTileFromReduceOutputToInput) { auto input_indexing = GetOutputToInputIndexingForEntryComputation(R"( HloModule m max { @@ -182,14 +179,13 @@ TEST_F(SymbolicTileTest, )"); EXPECT_THAT( - SymbolicTile::FromIndexingMap(**input_indexing.indexing_maps[0].begin()), + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), Optional(MatchSymbolicTile("()[s0, s1, s2] -> (0, s0)", "()[s0, s1, s2] -> (125, s1)", "()[s0, s1, s2] -> (1, s2)"))); } -TEST_F(SymbolicTileTest, - CanPropagateTileThroughReverseWithoutSpecializedTileSizes) { +TEST_F(SymbolicTileTest, CanPropagateTileThroughReverse) { auto input_indexing = GetOutputToInputIndexingForEntryComputation(R"( HloModule m ENTRY e { @@ -199,14 +195,13 @@ TEST_F(SymbolicTileTest, )"); EXPECT_THAT( - SymbolicTile::FromIndexingMap(**input_indexing.indexing_maps[0].begin()), + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), Optional(MatchSymbolicTile("()[s0, s1, s2] -> (-s0 - s2 * s1 + 178)", "()[s0, s1, s2] -> (s1)", "()[s0, s1, s2] -> (s2)"))); } -TEST_F(SymbolicTileTest, - CanPropagateTileFromSliceOutputToInputWithoutSpecializedTileSizes) { +TEST_F(SymbolicTileTest, CanPropagateTileFromSliceOutputToInput) { auto input_indexing = GetOutputToInputIndexingForEntryComputation(R"( HloModule m ENTRY e { @@ -216,15 +211,14 @@ TEST_F(SymbolicTileTest, )"); EXPECT_THAT( - SymbolicTile::FromIndexingMap(**input_indexing.indexing_maps[0].begin()), + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), Optional(MatchSymbolicTile( "()[s0, s1, s2, s3, s4, s5] -> (s0 * 2 + 40, s3 * 4 + 20)", "()[s0, s1, s2, s3, s4, s5] -> (s1, s4)", "()[s0, s1, s2, s3, s4, s5] -> (s2 * 2, s5 * 4)"))); } -TEST_F(SymbolicTileTest, - CanPropagateTileThroughTransposeWithoutSpecializedTileSizes) { +TEST_F(SymbolicTileTest, CanPropagateTileThroughTranspose) { auto input_indexing = GetOutputToInputIndexingForEntryComputation(R"( HloModule m ENTRY e { @@ -234,12 +228,107 @@ TEST_F(SymbolicTileTest, )"); EXPECT_THAT( - SymbolicTile::FromIndexingMap(**input_indexing.indexing_maps[0].begin()), + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), Optional(MatchSymbolicTile("()[s0, s1, s2, s3, s4, s5] -> (s3, s0)", "()[s0, s1, s2, s3, s4, s5] -> (s4, s1)", "()[s0, s1, s2, s3, s4, s5] -> (s5, s2)"))); } +TEST_F(SymbolicTileTest, CanPropagateTileThroughConcatenate) { + // TODO(325488844): Add additional concat test cases with constraints. + auto input_indexing = GetOutputToInputIndexingForEntryComputation(R"( + HloModule m + ENTRY e { + p0 = f32[2,5,7] parameter(0) + p1 = f32[2,11,7] parameter(1) + p2 = f32[2,17,7] parameter(2) + ROOT concat = f32[2,33,7] concatenate(p0, p1, p2), dimensions={1} + } + )"); + + EXPECT_THAT( + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), + Optional(MatchSymbolicTile( + "()[s0, s1, s2, s3, s4, s5, s6, s7, s8] -> (s0, s3, s6)", + "()[s0, s1, s2, s3, s4, s5, s6, s7, s8] -> (s1, s4, s7)", + "()[s0, s1, s2, s3, s4, s5, s6, s7, s8] -> (s2, s5, s8)"))); + EXPECT_THAT( + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[1].begin()), + Optional(MatchSymbolicTile( + "()[s0, s1, s2, s3, s4, s5, s6, s7, s8] -> (s0, s3 - 5, s6)", + "()[s0, s1, s2, s3, s4, s5, s6, s7, s8] -> (s1, s4, s7)", + "()[s0, s1, s2, s3, s4, s5, s6, s7, s8] -> (s2, s5, s8)"))); + EXPECT_THAT( + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[2].begin()), + Optional(MatchSymbolicTile( + "()[s0, s1, s2, s3, s4, s5, s6, s7, s8] -> (s0, s3 - 16, s6)", + "()[s0, s1, s2, s3, s4, s5, s6, s7, s8] -> (s1, s4, s7)", + "()[s0, s1, s2, s3, s4, s5, s6, s7, s8] -> (s2, s5, s8)"))); +} + +TEST_F(SymbolicTileTest, CanPropagateTileThroughPadOpWithoutInteriorPadding) { + // TODO(325488844): Add pad tests with defined constraints on tile input. + auto input_indexing = GetOutputToInputIndexingForEntryComputation(R"( + HloModule m + ENTRY e { + p0 = f32[4, 4] parameter(0) + p1 = f32[] parameter(1) + ROOT pad = f32[8,8] pad(p0, p1), padding=2_2_0x1_3_0 + } + )"); + + EXPECT_THAT( + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), + Optional( + MatchSymbolicTile("()[s0, s1, s2, s3, s4, s5] -> (s0 - 2, s3 - 1)", + "()[s0, s1, s2, s3, s4, s5] -> (s1, s4)", + "()[s0, s1, s2, s3, s4, s5] -> (s2, s5)"))); +} + +TEST_F(SymbolicTileTest, CanPrintSymbolicTileWithNamedTriplets) { + auto input_indexing = GetOutputToInputIndexingForEntryComputation(R"( + HloModule m + ENTRY e { + p0 = f32[17, 19] parameter(0) + p1 = f32[19, 23] parameter(1) + ROOT dot = f32[17, 23] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + )"); + + std::string s; + std::stringstream ss(s); + + SymbolicTile first_operand_tile = + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()) + .value(); + SymbolicTile second_operand_tile = + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[1].begin()) + .value(); + + ss << first_operand_tile; + EXPECT_THAT( + ss.str(), + AllOf(HasSubstr("()[offset0, size0, stride0, offset1, size1, stride1] " + "-> (offset0, 0)"), + HasSubstr("()[offset0, size0, stride0, offset1, size1, stride1] " + "-> (size0, 19)"), + HasSubstr("()[offset0, size0, stride0, offset1, size1, stride1] " + "-> (stride0, 1)"))); + + // Clear the stream and load the second map. + ss.str(""); + ss << second_operand_tile; + EXPECT_THAT( + ss.str(), + AllOf(HasSubstr("()[offset0, size0, stride0, offset1, size1, stride1] " + "-> (0, offset1)"), + HasSubstr("()[offset0, size0, stride0, offset1, size1, stride1] " + "-> (19, size1)"), + HasSubstr("()[offset0, size0, stride0, offset1, size1, stride1] " + "-> (1, stride1)"))); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/multi_output_fusion.cc b/third_party/xla/xla/service/gpu/multi_output_fusion.cc index 20cb217fe00ec7..3271d209a5a08a 100644 --- a/third_party/xla/xla/service/gpu/multi_output_fusion.cc +++ b/third_party/xla/xla/service/gpu/multi_output_fusion.cc @@ -208,10 +208,7 @@ FusionDecision ProducerCandidateIsFusible( GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes( &producer, cost_analysis, GpuPerformanceModelOptions::Default(), - - // `EstimateRunTimes`'s interface violates const correctness, so we - // need the const cast here. - {const_cast(&consumer)}, + /*fused_consumers=*/{&consumer}, /*multi_output=*/true); if (t.time_fused > t.time_unfused) { return "will execute slower if fused"; diff --git a/third_party/xla/xla/service/gpu/nccl_api.cc b/third_party/xla/xla/service/gpu/nccl_api.cc index 990420ad54dfa2..e09cac7c879194 100644 --- a/third_party/xla/xla/service/gpu/nccl_api.cc +++ b/third_party/xla/xla/service/gpu/nccl_api.cc @@ -17,8 +17,10 @@ limitations under the License. #include #include +#include #include #include +#include #include "absl/algorithm/container.h" #include "absl/hash/hash.h" @@ -26,18 +28,32 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "third_party/nccl/nccl.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "xla/primitive_util.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/nccl_clique_key.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/gpu/gpu_activation.h" #include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/stream.h" #include "xla/xla_data.pb.h" #include "tsl/concurrency/ref_count.h" +#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" +#if TENSORFLOW_USE_ROCM +#include "rocm/rocm_config.h" +#if (TF_ROCM_VERSION >= 50200) +#include "rocm/include/rccl/rccl.h" +#else +#include "rocm/include/rccl.h" +#endif // TF_ROCM_VERSION >= 50200 +#else +#include "third_party/nccl/nccl.h" +#endif // TENSORFLOW_USE_ROCM + namespace xla::gpu { //==-----------------------------------------------------------------------===// @@ -237,7 +253,7 @@ PersistentPlanAllocator::AllocateAndInitialize(void* src, size_t size) { VLOG(5) << "Allocate and initialize NCCL persistent plan; mem=" << owned_mem->opaque() << "; size=" << size; se::DeviceMemoryBase mem = owned_mem.Release(); - stream_->ThenMemcpy(&mem, src, size); + TF_RETURN_IF_ERROR(stream_->Memcpy(&mem, src, size)); return mem; } @@ -277,9 +293,13 @@ class DefaultNcclApi final : public NcclApi { public: absl::StatusOr GetUniqueId() final; - absl::StatusOr CommInitRank(int32_t nranks, - const NcclCliqueId& clique_id, - int32_t rank) final; + absl::StatusOr> CommInitRanks( + int32_t nranks, const NcclCliqueId& clique_id, + absl::Span ranks, const Config& config) final; + + absl::StatusOr> CommSplit( + absl::Span comms, int32_t color, + absl::Span keys, std::optional config) final; absl::Status CommAbort(NcclCommHandle comm) final; absl::Status CommFinalize(NcclCommHandle comm) final; @@ -344,20 +364,97 @@ absl::StatusOr DefaultNcclApi::GetUniqueId() { return NcclCliqueId(id.internal); } -absl::StatusOr DefaultNcclApi::CommInitRank( - int32_t nranks, const NcclCliqueId& clique_id, int32_t rank) { - VLOG(1) << "Initialize NCCL communicator for rank #" << rank << " of " - << nranks << "; hash(id)=" << absl::HashOf(clique_id.data()); +absl::StatusOr> +DefaultNcclApi::CommInitRanks(int32_t nranks, const NcclCliqueId& clique_id, + absl::Span ranks, + const Config& config) { + VLOG(1) << "Initialize NCCL communicator for " << ranks.size() + << " devices; hash(id)=" << absl::HashOf(clique_id); - if (rank < 0 || rank >= nranks) - return absl::InvalidArgumentError(absl::StrFormat( - "Invalid rank %d, it must be in [0, %d) range", rank, nranks)); + ncclConfig_t comm_config = NCCL_CONFIG_INITIALIZER; + comm_config.splitShare = config.split_share; + if (config.max_nchannels > 0) { + comm_config.maxCTAs = config.max_nchannels; + VLOG(1) << "Maximum number of channels for hash(id)=" + << absl::HashOf(clique_id) << " is set to: " << comm_config.maxCTAs; + } - ncclComm_t comm = nullptr; - XLA_NCCL_RETURN_IF_ERROR( - ncclCommInitRank(&comm, nranks, AsNcclUniqueId(clique_id), rank)); + std::vector comms; + comms.reserve(ranks.size()); - return OwnedNcclComm(Cast(comm), NcclCommDeleter{this}); + TF_RETURN_IF_ERROR(GroupStart()); + for (size_t i = 0; i < ranks.size(); ++i) { + VLOG(1) << "Initialize NCCL communicator for rank #" << ranks[i].rank + << " of " << nranks << "; hash(id)=" << absl::HashOf(clique_id); + + se::gpu::ScopedActivateExecutorContext activate_context(ranks[i].device); + + ncclComm_t comm_handle = nullptr; + XLA_NCCL_RETURN_IF_ERROR( + ncclCommInitRankConfig(&comm_handle, nranks, AsNcclUniqueId(clique_id), + ranks[i].rank, &comm_config)); + + comms.emplace_back(Cast(comm_handle), NcclCommDeleter{this}); + } + TF_RETURN_IF_ERROR(GroupEnd()); + + return comms; +} + +absl::StatusOr> DefaultNcclApi::CommSplit( + absl::Span comms, int32_t color, + absl::Span keys, std::optional config) { + VLOG(1) << absl::StreamFormat( + "Split %d NCCL communicators using color %d and keys: [%s]", comms.size(), + color, absl::StrJoin(keys, ",")); + +#if !defined(TENSORFLOW_USE_ROCM) || TF_ROCM_VERSION >= 60000 + if (keys.size() != comms.size()) { + return absl::InvalidArgumentError( + absl::StrFormat("Comms and keys must have the same size, but %d != %d", + comms.size(), keys.size())); + } + + ncclConfig_t comm_config = NCCL_CONFIG_INITIALIZER; + if (config.has_value()) { + comm_config.splitShare = config.value().split_share; + // If max_nchannels is set, then we don't want to + // inherit from parent comm. + if (config.value().max_nchannels > 0) { + comm_config.maxCTAs = config.value().max_nchannels; + VLOG(1) << "CommSplit maximum number of channels " + << " is set to: " << comm_config.maxCTAs; + } + } + + // In contrast to grouped initialization communicator splitting initializes + // communicators only after a successful call to `GroupEnd`, so we keep a + // vector of handles and after successful splitting convert to RAII wrappers. + std::vector split_comms_handles; + split_comms_handles.resize(comms.size(), nullptr); + + ncclConfig_t* comm_config_ptr = config.has_value() ? &comm_config : nullptr; + TF_RETURN_IF_ERROR(GroupStart()); + for (size_t i = 0; i < comms.size(); ++i) { + VLOG(1) << "Split NCCL communicator " << comms[i] << " with color " << color + << " and key " << keys[i]; + XLA_NCCL_RETURN_IF_ERROR(ncclCommSplit(Cast(comms[i]), color, keys[i], + &split_comms_handles[i], + /*config=*/comm_config_ptr)); + } + TF_RETURN_IF_ERROR(GroupEnd()); + + std::vector split_comms; + for (size_t i = 0; i < split_comms_handles.size(); ++i) { + split_comms.emplace_back(Cast(split_comms_handles[i]), + NcclCommDeleter{this}); + } + return split_comms; +#else + return absl::UnimplementedError( + absl::StrFormat("%s:%d: NCCL operation ncclCommSplit not implemented", + __FILE__, __LINE__)); +#endif // !defined(TENSORFLOW_USE_ROCM) || TF_ROCM_VERSION >= 60000 } absl::Status DefaultNcclApi::CommAbort(NcclCommHandle comm) { @@ -510,9 +607,10 @@ DefaultNcclApi::RegisterBuffer(NcclCommHandle comm, "Register buffer for NCCL communicator; buffer=%p; size=%d; comm=%p", buffer.opaque(), buffer.size(), comm); void* handle = nullptr; +#if (NCCL_VERSION_CODE >= 21901) XLA_NCCL_RETURN_IF_ERROR( ncclCommRegister(Cast(comm), buffer.opaque(), buffer.size(), &handle)); - +#endif // NCCL_VERSION_CODE >= 21901 return reinterpret_cast(handle); } @@ -522,8 +620,9 @@ DefaultNcclApi::DeregisterBuffer(NcclCommHandle comm, VLOG(3) << absl::StreamFormat( "Deregister buffer for NCCL communicator; handle=%p; comm=%p", handle, comm); +#if (NCCL_VERSION_CODE >= 21901) return XLA_NCCL_STATUS( ncclCommDeregister(Cast(comm), reinterpret_cast(handle))); +#endif // NCCL_VERSION_CODE >= 21901 } - } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/nccl_api.h b/third_party/xla/xla/service/gpu/nccl_api.h index 86f19a0d6c72f0..6294b0bfdaf568 100644 --- a/third_party/xla/xla/service/gpu/nccl_api.h +++ b/third_party/xla/xla/service/gpu/nccl_api.h @@ -19,9 +19,11 @@ limitations under the License. #include #include #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/types/span.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/nccl_clique_key.h" #include "xla/shape_util.h" @@ -45,6 +47,14 @@ class NcclApi { public: virtual ~NcclApi() = default; + // Communicator configuration. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig + struct Config { + bool split_share = false; + int64_t max_nchannels = 0; + }; + // Returns a default NcclApi for a current process. Can be a real one based on // NCCL or a stub if XLA compiled without NCCL or CUDA support. static NcclApi* Default(); @@ -113,6 +123,14 @@ class NcclApi { tsl::RCReference allocator_; }; + struct DeviceRank { + DeviceRank(se::StreamExecutor* device, int32_t rank) + : device(device), rank(rank) {} + + se::StreamExecutor* device; + int32_t rank; + }; + // Returns a slice of device memory `buff` containing `count` values of data // type `dtype` starting from `offset`. static se::DeviceMemoryBase Slice(se::DeviceMemoryBase buff, @@ -127,11 +145,25 @@ class NcclApi { // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclgetuniqueid virtual absl::StatusOr GetUniqueId() = 0; - // Creates a new communicator. + // Creates new communicators for given devices. + // + // This API doesn't have a corresponding API in NCCL and implemented as + // multiple calls to ncclCommInitRank within a single group. // // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcomminitrank - virtual absl::StatusOr CommInitRank( - int32_t nranks, const NcclCliqueId& clique_id, int32_t rank) = 0; + virtual absl::StatusOr> CommInitRanks( + int32_t nranks, const NcclCliqueId& clique_id, + absl::Span ranks, const Config& config) = 0; + + // Creates new communicators by splitting `comms`. + // + // This API doesn't have a corresponding API in NCCL and implemented as + // multiple calls to ncclCommSplit within a single group. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommsplit + virtual absl::StatusOr> CommSplit( + absl::Span comms, int32_t color, + absl::Span keys, std::optional config) = 0; // Abort any uncompleted operations and destroys the communicator. Frees // resources that are allocated to a communicator object comm. diff --git a/third_party/xla/xla/service/gpu/nccl_api_stub.cc b/third_party/xla/xla/service/gpu/nccl_api_stub.cc index 036ee6ee135e9b..d4140e752e8276 100644 --- a/third_party/xla/xla/service/gpu/nccl_api_stub.cc +++ b/third_party/xla/xla/service/gpu/nccl_api_stub.cc @@ -15,9 +15,12 @@ limitations under the License. #include #include +#include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/types/span.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/nccl_api.h" #include "xla/service/gpu/nccl_clique_key.h" @@ -83,8 +86,15 @@ class NcclApiStub final : public NcclApi { return UnimplementedError(); } - absl::StatusOr CommInitRank(int32_t, const NcclCliqueId&, - int32_t) final { + absl::StatusOr> CommInitRanks( + int32_t, const NcclCliqueId&, absl::Span, + const Config&) final { + return UnimplementedError(); + } + + absl::StatusOr> CommSplit( + absl::Span, int32_t, absl::Span, + std::optional) final { return UnimplementedError(); } diff --git a/third_party/xla/xla/service/gpu/nccl_clique.cc b/third_party/xla/xla/service/gpu/nccl_clique.cc index 5ea4d2d5564982..84156ac58fcf0c 100644 --- a/third_party/xla/xla/service/gpu/nccl_clique.cc +++ b/third_party/xla/xla/service/gpu/nccl_clique.cc @@ -24,8 +24,9 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/base/thread_annotations.h" -#include "absl/container/flat_hash_map.h" +#include "absl/container/btree_map.h" #include "absl/container/node_hash_map.h" #include "absl/functional/function_ref.h" #include "absl/hash/hash.h" @@ -33,7 +34,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "absl/synchronization/barrier.h" +#include "absl/strings/str_join.h" #include "absl/synchronization/mutex.h" #include "absl/time/clock.h" #include "absl/time/time.h" @@ -46,8 +47,10 @@ limitations under the License. #include "xla/service/lockable.h" #include "xla/service/rendezvous.h" #include "xla/status_macros.h" +#include "xla/stream_executor/stream_executor.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" +#include "tsl/platform/hash.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" @@ -100,34 +103,41 @@ static absl::Duration TerminateTimeout() { //===----------------------------------------------------------------------===// NcclCliqueCommunicators::NcclCliqueCommunicators( - NcclCliqueKey clique_key, NcclCliqueId clique_id, - absl::node_hash_map communicators) + NcclCliqueKey clique_key, std::optional clique_id, + absl::btree_map communicators) : clique_key_(std::move(clique_key)), clique_id_(std::move(clique_id)), communicators_(std::move(communicators)) {} -std::optional NcclCliqueCommunicators::comm(int32_t rank) { +std::optional NcclCliqueCommunicators::comm( + int32_t rank) { if (auto it = communicators_.find(rank); it != communicators_.end()) { - return &it->second; + return it->second.get(); } return std::nullopt; } +bool NcclCliqueCommunicators::IsLocal() const { + return communicators_.size() == clique_key_.devices().size(); +} + void NcclCliqueCommunicators::ForEachComm( - absl::FunctionRef fn) { + absl::FunctionRef fn) { for (auto& [rank, comm] : communicators_) { - fn(rank, comm); + fn(rank, comm.get()); } } std::string NcclCliqueCommunicators::DebugString() const { - std::string out = absl::StrFormat( - "clique_key: %s; hash(id): %d; size: %d; communicators: ", - clique_key_.ToString(), absl::HashOf(clique_id_), communicators_.size()); + std::string out = + absl::StrFormat("clique_key: %s; hash(id): %d; size: %d; communicators: ", + clique_key_.ToString(), + clique_id_.has_value() ? absl::HashOf(*clique_id_) : 0, + communicators_.size()); int32_t cnt = 0; for (const auto& [rank, comm] : communicators_) { if (cnt++) absl::StrAppend(&out, ", "); - absl::StrAppendFormat(&out, "[rank=%d, comm=%p]", rank, comm.value()); + absl::StrAppendFormat(&out, "[rank=%d, comm=%p]", rank, comm.get()); } return out; } @@ -150,26 +160,6 @@ static NcclCliques& GetNcclCliques() { return *cliques; } -// Acquires a NCCL clique for a given key. Should be used with extra care if -// executed outside of a rendezvous callback as it's unsafe to launch unrelated -// collective operations using the same clique out of order. -// -// If NCCL clique for a given key is not initialized returns an empty lock. -// Caller must always check if lock is valid before trying to use it. -static absl::StatusOr AcquireNcclClique( - const NcclCliqueKey& clique_key, RunId run_id, - int32_t num_local_participants) { - NcclCliques& cliques = GetNcclCliques(); - - absl::MutexLock lock(&cliques.mu); - if (auto it = cliques.map.find(clique_key); it != cliques.map.end()) { - return it->second.Acquire(); - } - - // Return empty lock if we do not have a clique for `clique_key`. - return NcclClique::Lock(); -} - //===----------------------------------------------------------------------===// // NcclClique Heart Beat Monitor //===----------------------------------------------------------------------===// @@ -177,17 +167,14 @@ static absl::StatusOr AcquireNcclClique( // Runs an async error check for a `comm` and aborts it if it is in the // error state. It will free resources that are allocated to a communicator // and abort any uncompleted operations before destroying the communicator. -static absl::Status CheckComm(NcclComm& lockable_comm) { - if (NcclComm::Lock comm = lockable_comm.TryAcquire()) { - absl::Status async_err = NcclApi::Default()->CommGetAsyncError(*comm); - if (!async_err.ok()) { - LOG(ERROR) << "Aborting communicator: " << comm - << " due to async NCCL error: " << async_err; - TF_RETURN_IF_ERROR(NcclApi::Default()->CommAbort(*comm)); - } - return async_err; +static absl::Status CheckComm(NcclApi::NcclCommHandle comm) { + absl::Status async_err = NcclApi::Default()->CommGetAsyncError(comm); + if (!async_err.ok()) { + LOG(ERROR) << "Aborting communicator: " << comm + << " due to async NCCL error: " << async_err; + TF_RETURN_IF_ERROR(NcclApi::Default()->CommAbort(comm)); } - return absl::OkStatus(); + return async_err; } // Runs async check on all communicators in a clique. @@ -195,8 +182,9 @@ static void CheckClique(const NcclCliqueKey& clique_key, NcclClique& lockable_clique) { if (NcclClique::Lock clique = lockable_clique.TryAcquire()) { VLOG(5) << "Checking NCCL clique " << clique_key.ToString() - << " for async errors; num_communicators=" << clique->size(); - clique->ForEachComm([](int32_t rank, NcclComm& comm) { + << " for async errors; num_communicators=" + << clique->num_communicators(); + clique->ForEachComm([](int32_t rank, NcclApi::NcclCommHandle comm) { if (auto status = CheckComm(comm); !status.ok()) LOG(ERROR) << status; }); } else { @@ -233,221 +221,297 @@ static void StartNcclCliqueHeartBeatMonitor() { // NcclClique initialization must be executed together by all participants, and // we rely on rendezvous to guarantee that all ranks are ready to initialize -// NCCL communicators. - -namespace { -// Local (in-process) NCCL clique initialization state. Once initialization is -// complete NCCL clique added to the NcclCliques container (see above). -struct InitializationState { - using Ranks = absl::Span; - InitializationState(NcclCliqueId clique_id, Ranks ranks); - - NcclCliqueId clique_id; - absl::node_hash_map> comms; - - // Signals when all participants updated entries in `comms`. - std::unique_ptr ready; -}; - -} // namespace - -InitializationState::InitializationState(NcclCliqueId clique_id, Ranks ranks) - : clique_id(clique_id), ready(new absl::Barrier(ranks.size())) { - // Initialize `comms` for all ranks so that each participating thread can - // write into it without synchronization. - for (const int32_t* rank : ranks) { - comms[*rank] = absl::InternalError("uninitialized NCCL communicator"); - } +// NCCL communicators. In general collective operations are expected to be +// executed concurrently by all participating ranks, and when some ranks do not +// join the operation it leads to deadlocks. We use a combination of rendezvous +// and locking to guarantee that all collective operations in XLA have a well +// defined order and do not deadlock inside underlying collective communication +// library. + +static auto DeviceRanksToString(absl::Span ranks) { + return absl::StrJoin(ranks, ",", [](std::string* str, auto& rank) { + str->append(std::to_string(rank.rank)); + }); } -// Creates a new NCCL communicator for a given `rank` and joins a rendezvous to -// initialize a clique for a `clique_key`. Returns a lock that gives exclusive -// access to a NCCL clique. +// Joins a NcclClique initialization rendezvous for a `clique_key` and returns +// a lock that gives an access to initialized clique (access is shared between +// all participating ranks that own a shared pointer). static absl::StatusOr> InitializeNcclClique( - RunId run_id, NcclCliqueKey clique_key, + se::StreamExecutor* device, RunId run_id, NcclCliqueKey clique_key, const NcclCliqueIdCallback& clique_id_callback, - int32_t num_local_participants, int32_t rank) { + int32_t num_local_participants, int32_t rank, NcclApi::Config& config) { int nranks = clique_key.devices().size(); VLOG(3) << "Initialize NCCL clique " << clique_key.ToString() << " rank #" - << rank << " of " << nranks - << "; num_local_participants=" << num_local_participants; + << rank << "; num_local_participants=" << num_local_participants; // Start NCCL clique heart beat monitor when create a first clique. StartNcclCliqueHeartBeatMonitor(); - // Creates initialization state for participating ranks. - auto create_initialization_state = [&](absl::Span ranks) - -> absl::StatusOr { + // Initializes a NcclClique for given device ranks and returns a lock that + // gives access to clique communicators. + auto initialize = [&](absl::Span args) + -> absl::StatusOr { TF_ASSIGN_OR_RETURN(auto clique_id, clique_id_callback(clique_key)); - VLOG(3) << "Created unique clique id (hash): " << absl::HashOf(clique_id); - return InitializationState(clique_id, ranks); + + std::vector ranks; + ranks.reserve(args.size()); + for (auto* arg : args) ranks.emplace_back(*arg); + + // Sort device ranks, mainly to get more readable logs below, NCCL does + // not care in what order ranks are initialized. + absl::c_sort(ranks, [](auto& a, auto& b) { return a.rank < b.rank; }); + + VLOG(3) << absl::StreamFormat( + "Create NCCL communicators for clique %s; ranks=[%s]; hash(id)=%d", + clique_key.ToString(), DeviceRanksToString(ranks), + absl::HashOf(clique_id)); + + TF_ASSIGN_OR_RETURN( + std::vector created_comms, + NcclApi::Default()->CommInitRanks(nranks, clique_id, ranks, config)); + + absl::btree_map comms; + for (size_t i = 0; i < ranks.size(); ++i) { + comms[ranks[i].rank] = std::move(created_comms[i]); + } + + VLOG(3) << absl::StreamFormat( + "Created NCCL communicators for clique %s; ranks=[%s]; hash(id)=%d", + clique_key.ToString(), DeviceRanksToString(ranks), + absl::HashOf(clique_id)); + + NcclCliques& cliques = GetNcclCliques(); + absl::MutexLock lock(&cliques.mu); + + // Create a new clique with given clique key and communicators. + auto emplaced = cliques.map.try_emplace(clique_key, clique_key, clique_id, + std::move(comms)); + + // We can have a race to create a clique for a given key, the winner + // inserts it into a map and the looser destroys all communicators. + if (!emplaced.second) { + VLOG(3) << "Clique already exists: " + << emplaced.first->second.DebugString(); + } else { + VLOG(3) << "Created new clique: " << emplaced.first->second.DebugString(); + } + + return emplaced.first->second.Acquire(); }; // We include `run_id` to a rendezvous key to make sure that multiple // concurrent initializations will not join the same rendezvous. The winner // will update cliques state, and others will destroy unused communicators. auto rendezvous_key = std::make_tuple(run_id, clique_key); - auto initialization_rendezvous_name = absl::StrFormat( - "create clique initialization state for rank %d; clique=%s; run_id=%d", - rank, clique_key.ToString(), run_id.ToInt()); - - // Do a round of rendezvous to wait for all participants to join NCCL clique - // initialization process. - TF_ASSIGN_OR_RETURN(std::shared_ptr state, - RendezvousSingle>( - initialization_rendezvous_name, rendezvous_key, rank, - num_local_participants, create_initialization_state, - WarnStuckTimeout(), TerminateTimeout())); - - VLOG(3) << "Create NCCL communicator for clique " << clique_key.ToString() - << " rank #" << rank << " of " << nranks - << "; num_local_participants=" << num_local_participants; + auto initialization_rendezvous_name = + absl::StrFormat("initialize clique for rank %d; clique=%s; run_id=%d", + rank, clique_key.ToString(), run_id.ToInt()); - absl::StatusOr comm = - NcclApi::Default()->CommInitRank(nranks, state->clique_id, rank); + NcclApi::DeviceRank device_rank = {device, rank}; - if (comm.ok()) { - state->comms[rank] = std::move(*comm); - } else { - state->comms[rank] = comm.status(); - } + return RendezvousSingle>( + initialization_rendezvous_name, rendezvous_key, device_rank, + num_local_participants, initialize, WarnStuckTimeout(), + TerminateTimeout()); +} - // Wait for all participants to complete communicator initialization. - bool completed_initialization = state->ready->Block(); +// Computes a unique NCCL communicator split color from a clique key. We use a +// deterministic hash function to guarantee that all participating processes get +// the same color value for a clique. +static int32_t GetCommSplitColor(const NcclCliqueKey& clique_key) { + std::vector global_device_ids; + global_device_ids.reserve(clique_key.devices().size()); - // Check that all ranks successfully initialize communicators. - for (const auto& [rank, comm] : state->comms) { - TF_RETURN_IF_ERROR(comm.status()); + for (GlobalDeviceId id : clique_key.devices()) { + global_device_ids.push_back(id.value()); } - // If we are the leader who completed the clique initialization we should - // update the local (in-process) cliques state. - if (completed_initialization) { - NcclCliques& cliques = GetNcclCliques(); + return abs(static_cast( + tsl::Hash32(reinterpret_cast(global_device_ids.data()), + sizeof(int64_t) * global_device_ids.size(), 0))); +} - // Create NCCL communicators from handles. - absl::node_hash_map communicators; - for (auto& [rank, comm] : state->comms) { - if (*comm == nullptr) { - return absl::InternalError(absl::StrFormat( - "uninitialized NCCL communicator for rank %d", rank)); +// Joins a NcclClique initialization rendezvous for a `clique_key` and returns +// a lock that gives an access to clique created by splitting already acquired +// `parent_clique` clique (access is shared between all participating ranks that +// own a shared pointer). +static absl::StatusOr> InitializeNcclClique( + se::StreamExecutor* device, RunId run_id, NcclCliqueKey clique_key, + std::shared_ptr parent_clique, + int32_t num_local_participants, int32_t rank, NcclApi::Config& config) { + // Find our rank in the parent clique. + const NcclCliqueKey& parent_clique_key = (*parent_clique)->clique_key(); + int32_t parent_rank = *parent_clique_key.rank(clique_key.devices()[rank]); + + VLOG(3) << "Initialize NCCL clique " << clique_key.ToString() << " rank #" + << rank << " by splitting rank #" << parent_rank + << " in parent clique " << parent_clique_key.ToString() + << "; num_local_participants=" << num_local_participants; + + using RankPair = std::pair; + RankPair rank_pair = {parent_rank, rank}; + + // Current approach for communicator splitting works because of XLAs SPMD + // programming model where all collective operations have replica groups that + // include all ranks. This property guarantees that we'll split each + // communicator exactly once with a unique color computed from rank mapping + // and each communicator in the parent clique will become a member of exactly + // one new clique. Clique splitting happens concurrently for multiple + // non-overlapping clique and this guarantees forward progress even with + // implicit synchronization inside NCCL. + + // Initializes a NcclClique for given device ranks and returns a lock that + // gives access to clique communicators. + auto split = [&](absl::Span rank_pairs) + -> absl::StatusOr { + // Collect mapping from ranks in parent clique to ranks in a new clique. + absl::btree_map rank_mapping; + for (auto* rank_pair : rank_pairs) { + rank_mapping[rank_pair->first] = rank_pair->second; + } + + auto rank_mapping_formatter = [](std::string* str, auto mapping) { + absl::StrAppend(str, mapping.first, "->", mapping.second); + }; + + // Collect parent communicators we'll be splitting from and keys for + // creating new communicators. + std::vector parent_comms; + std::vector keys; + + for (auto& [parent_rank, split_rank] : rank_mapping) { + auto parent_comm = (*parent_clique)->comm(parent_rank); + if (!parent_comm.has_value()) { + return absl::InvalidArgumentError(absl::StrFormat( + "Parent clique %s does not have a communicator for rank %d", + parent_clique_key.ToString(), parent_rank)); } - communicators.try_emplace(rank, comm->release()); + + parent_comms.push_back(*parent_comm); + keys.push_back(split_rank); } - VLOG(3) << "Completed NCCL clique initialization for a clique " - << clique_key.ToString(); + // Get a globally consistent color value for newly created clique. + int32_t color = GetCommSplitColor(clique_key); - // Create a new clique with given clique id and communicators. + VLOG(3) << absl::StreamFormat( + "Create NCCL communicators for clique %s; parent=%s; color=%d; " + "rank_mapping=[%s]", + clique_key.ToString(), parent_clique_key.ToString(), color, + absl::StrJoin(rank_mapping, ",", rank_mapping_formatter)); + + TF_ASSIGN_OR_RETURN( + auto splitted_comms, + NcclApi::Default()->CommSplit(parent_comms, color, keys, config)); + + absl::btree_map comms; + for (size_t i = 0; i < splitted_comms.size(); ++i) { + comms[i] = std::move(splitted_comms[i]); + } + + VLOG(3) << absl::StreamFormat( + "Created NCCL communicators for clique %s; parent=%s; color=%d; " + "rank_mapping=[%s]", + clique_key.ToString(), parent_clique_key.ToString(), color, + absl::StrJoin(rank_mapping, ",", rank_mapping_formatter)); + + NcclCliques& cliques = GetNcclCliques(); absl::MutexLock lock(&cliques.mu); - auto emplaced = cliques.map.try_emplace( - clique_key, clique_key, state->clique_id, std::move(communicators)); - // We can have a race to create a clique for a given key, the winner inserts - // it into a map and the looser destroys all communicators. + // Create a new clique with given clique key and communicators. + auto emplaced = cliques.map.try_emplace(clique_key, clique_key, + std::nullopt, std::move(comms)); + + // We can have a race to create a clique for a given key, the winner + // inserts it into a map and the looser destroys all communicators. if (!emplaced.second) { VLOG(3) << "Clique already exists: " << emplaced.first->second.DebugString(); } else { VLOG(3) << "Created new clique: " << emplaced.first->second.DebugString(); } - } - // Do one more round of rendezvous to guarantee that all ranks that - // participated in clique initialization will share an exclusive access to all - // communicators in a NCCL clique. - auto initialized_rendezvous_name = absl::StrFormat( - "acquire initialized clique for rank %d; clique=%s; run_id=%d", rank, - clique_key.ToString(), run_id.ToInt()); + return emplaced.first->second.Acquire(); + }; + + // We include `run_id` to a rendezvous key to make sure that multiple + // concurrent initializations will not join the same rendezvous. The winner + // will update cliques state, and others will destroy unused communicators. + auto rendezvous_key = std::make_tuple(run_id, clique_key, parent_clique_key); + auto initialization_rendezvous_name = absl::StrFormat( + "initialize clique for rank %d; clique=%s; run_id=%d; parent=%s", rank, + clique_key.ToString(), run_id.ToInt(), parent_clique_key.ToString()); return RendezvousSingle>( - initialized_rendezvous_name, rendezvous_key, num_local_participants, - [&] { - return AcquireNcclClique(clique_key, run_id, num_local_participants); - }, - WarnStuckTimeout(), TerminateTimeout()); + initialization_rendezvous_name, rendezvous_key, rank_pair, + num_local_participants, split, WarnStuckTimeout(), TerminateTimeout()); } //===----------------------------------------------------------------------===// +using AcquiredCliquesMap = NcclClique::AcquiredCliquesMap; + absl::StatusOr> AcquireNcclClique( - RunId run_id, OpId op_id, NcclCliqueKey clique_key, + se::StreamExecutor* device, RunId run_id, NcclCliqueKey clique_key, const NcclCliqueIdCallback& clique_id_callback, int32_t rank, - size_t num_local_participants, bool may_skip_rendezvous) { + size_t num_local_participants, const AcquiredCliquesMap& acquired_cliques, + int64_t max_nchannels) { VLOG(2) << "Acquire NCCL clique " << clique_key.ToString() << "; run" - << run_id.ToString() << "; op" << op_id.value() << "; rank " << rank + << run_id.ToString() << "; rank " << rank << "; num_local_participants=" << num_local_participants - << "; may_skip_rendezvous=" << may_skip_rendezvous; - - // If we prefer to skip rendezvous check if NcclClique is already available - // for a given key. - // TODO(ezhulenev): Remove this code path as it leads to deadlocks. - if (may_skip_rendezvous) { - TF_ASSIGN_OR_RETURN( - NcclClique::Lock clique, - AcquireNcclClique(clique_key, run_id, num_local_participants)); - - // If lock is not null return it to the caller. - if (clique) return std::make_shared(std::move(clique)); - - } else { - // Get the clique lock via the rendezvous process. - auto rendezvous_key = std::make_tuple(run_id, clique_key); - auto rendezvous_name = - absl::StrFormat("acquire clique for rank %d; clique=%s; run_id=%d", - rank, clique_key.ToString(), run_id.ToInt()); - - TF_ASSIGN_OR_RETURN( - std::shared_ptr clique, - RendezvousSingle>( - rendezvous_name, rendezvous_key, num_local_participants, - [&] { - return AcquireNcclClique(clique_key, run_id, - num_local_participants); - }, - WarnStuckTimeout(), TerminateTimeout())); - - // If lock is not null return it to the caller. - if (*clique) return clique; - } - - // If NCCL clique is not found try to initialize a new one for a given key. - return InitializeNcclClique(run_id, clique_key, clique_id_callback, - num_local_participants, rank); -} + << "; acquired_cliques=" << acquired_cliques.size(); -absl::StatusOr AcquireNcclComm( - RunId run_id, OpId op_id, std::vector participants, - size_t num_local_participants, - const NcclCliqueIdCallback& clique_id_callback, int32_t rank, - int64_t stream_id, bool enable_clique_optimization) { - // Ensure that this group of threads have exclusive access to the clique to - // prevent threads from different groups locking communicators in the clique. - // The enable_clique_optimization value is only used for asynchronous - // collective stream currently. For synchronous collectives, we should always - // enable the optimization. For P2P stream, we currently have to always enable - // the optimization, because we initially implement this optimization to - // workaround an NCCL bug related to P2P operations. - NcclCliqueKey clique_key(std::move(participants), stream_id); + // Get the clique lock via the rendezvous to guarantee that all clique + // members participate in XLA run. + auto rendezvous_key = std::make_tuple(run_id, clique_key); + auto rendezvous_name = + absl::StrFormat("acquire clique for rank %d; clique=%s; run_id=%d", rank, + clique_key.ToString(), run_id.ToInt()); TF_ASSIGN_OR_RETURN( std::shared_ptr clique, - AcquireNcclClique( - run_id, op_id, clique_key, clique_id_callback, rank, - num_local_participants, - enable_clique_optimization || - stream_id != GetStreamId(/*is_async=*/true, - AsyncStreamKind::kCollective))); - - // Check that clique has a communicator for our rank. - auto communicator = (*clique)->comm(rank); - if (!communicator.has_value()) { - return absl::InternalError(absl::StrCat("Communicator for rank ", rank, - " not found in a NCCL clique ", - clique_key.ToString())); + RendezvousSingle>( + rendezvous_name, rendezvous_key, num_local_participants, + [&] { + NcclCliques& cliques = GetNcclCliques(); + absl::MutexLock lock(&cliques.mu); + // Returns empty lock if we do not have a clique for `clique_key`. + auto it = cliques.map.find(clique_key); + return it == cliques.map.end() ? NcclClique::Lock() + : it->second.Acquire(); + }, + WarnStuckTimeout(), TerminateTimeout())); + + // If lock is not null return it to the caller. + if (*clique) return clique; + + // Maybe find if we acquired a clique with communicators that we can split. + static const int64_t enable_nccl_comm_splitting = + xla::GetDebugOptionsFromFlags().xla_gpu_enable_nccl_comm_splitting(); + + // We enable resource sharing between parent and split communicators by + // default because that's the only reason why we use comm splitting. + NcclApi::Config config; + config.split_share = true; + config.max_nchannels = max_nchannels; + + if (enable_nccl_comm_splitting) { + for (auto& [acquired_clique_key, acquired_clique] : acquired_cliques) { + // We don't support splitting non-local cliques as it requires careful + // synchronization between multiple processes. + if (!(*acquired_clique)->IsLocal()) continue; + + if (clique_key.IsSubsetOf(acquired_clique_key)) { + return InitializeNcclClique(device, run_id, clique_key, acquired_clique, + num_local_participants, rank, config); + } + } } - return (*communicator)->Acquire(); + // If we can't split any of the acquired cliques, create a new one. + return InitializeNcclClique(device, run_id, clique_key, clique_id_callback, + num_local_participants, rank, config); } } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/nccl_clique.h b/third_party/xla/xla/service/gpu/nccl_clique.h index 274c56ba7baccb..3275ba5b625b38 100644 --- a/third_party/xla/xla/service/gpu/nccl_clique.h +++ b/third_party/xla/xla/service/gpu/nccl_clique.h @@ -18,22 +18,21 @@ limitations under the License. #include #include +#include #include #include #include #include -#include -#include "absl/container/node_hash_map.h" +#include "absl/container/btree_map.h" #include "absl/functional/function_ref.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "xla/executable_run_options.h" -#include "xla/service/global_device_id.h" #include "xla/service/gpu/nccl_api.h" #include "xla/service/gpu/nccl_clique_key.h" #include "xla/service/lockable.h" -#include "tsl/lib/gtl/int_type.h" +#include "xla/stream_executor/stream_executor.h" namespace xla::gpu { @@ -55,9 +54,6 @@ namespace xla::gpu { // // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#using-multiple-nccl-communicators-concurrently -// Forward declare. -class NcclCliqueCommunicators; - //===----------------------------------------------------------------------===// // NcclUniqueId //===----------------------------------------------------------------------===// @@ -71,35 +67,6 @@ absl::StatusOr GetNcclCliqueIdCallback( const NcclCliqueIdCallback* clique_id_callback, // may be null bool is_local); -//===----------------------------------------------------------------------===// -// NcclComm -//===----------------------------------------------------------------------===// - -// TODO(b/319655685): Lockable NcclComm should be deleted and NcclClique should -// become the owner of all communicators making up a clique and responsible for -// synchronizing access to communicators. - -TSL_LIB_GTL_DEFINE_INT_TYPE(OpId, int64_t); - -struct NcclCommName { - static std::string ToString(NcclApi::NcclCommHandle comm) { - return absl::StrFormat("lockable comm %p", comm); - } -}; - -struct NcclComm : public Lockable { - friend class NcclCliqueCommunicators; - - explicit NcclComm(NcclApi::NcclCommHandle comm) : Lockable(comm) {} -}; - -// Acquires an exclusive access to NCCL communicator owned by a NCCL clique. -absl::StatusOr AcquireNcclComm( - RunId run_id, OpId op_id, std::vector participants, - size_t num_local_participants, - const NcclCliqueIdCallback& clique_id_callback, int32_t rank, - int64_t stream_id, bool enable_clique_optimization); - //===----------------------------------------------------------------------===// // NcclClique //===----------------------------------------------------------------------===// @@ -110,27 +77,33 @@ absl::StatusOr AcquireNcclComm( // operations that does not lead to deadlocks. class NcclCliqueCommunicators { public: - NcclCliqueCommunicators(NcclCliqueKey clique_key, NcclCliqueId, - absl::node_hash_map communicators); + NcclCliqueCommunicators( + NcclCliqueKey clique_key, std::optional clique_id, + absl::btree_map communicators); // Returns a NCCL communicator for a given rank if it's in a clique. - std::optional comm(int32_t rank); + std::optional comm(int32_t rank); + + // Return true if clique is local: all communicators belong to current + // process. Non-local cliques spans multiple processes (typically hosts). + bool IsLocal() const; // Calls `fn` for each communicator in the clique. - void ForEachComm(absl::FunctionRef fn); + void ForEachComm( + absl::FunctionRef fn); const NcclCliqueKey& clique_key() const { return clique_key_; } - const NcclCliqueId& clique_id() const { return clique_id_; } - size_t size() const { return communicators_.size(); } + const std::optional& clique_id() const { return clique_id_; } + size_t num_communicators() const { return communicators_.size(); } std::string DebugString() const; private: NcclCliqueKey clique_key_; - NcclCliqueId clique_id_; + std::optional clique_id_; // TODO(ezhulenev): Switch this map to GlobalDeviceId key. - absl::node_hash_map communicators_; + absl::btree_map communicators_; }; struct NcclCliqueName { @@ -140,10 +113,15 @@ struct NcclCliqueName { }; struct NcclClique : public Lockable { - NcclClique(NcclCliqueKey clique_key, NcclCliqueId clique_id, - absl::node_hash_map communicators) - : Lockable(NcclCliqueCommunicators{std::move(clique_key), clique_id, - std::move(communicators)}) {} + // We keep acquired cliques in a sorted container to guarantee that all + // participants iterate over cliques in the same order. + using AcquiredCliquesMap = + absl::btree_map, + std::greater>; + + NcclClique(NcclCliqueKey clique_key, std::optional clique_id, + absl::btree_map communicators) + : Lockable(std::move(clique_key), clique_id, std::move(communicators)) {} std::string DebugString() const; }; @@ -151,10 +129,16 @@ struct NcclClique : public Lockable { // Acquires an shared access to a NCCL clique (NcclClique::Lock collectively // owned by `num_local_participants` threads). XLA uses this lock to serialize // execution of all collective operations sharing a `clique_id`. +// +// If clique for a given key does not exist it will be initialized from newly +// created communicators or maybe created by splitting of the already acquired +// cliques. absl::StatusOr> AcquireNcclClique( - RunId run_id, OpId op_id, NcclCliqueKey clique_key, + se::StreamExecutor* device, RunId run_id, NcclCliqueKey clique_key, const NcclCliqueIdCallback& clique_id_callback, int32_t rank, - size_t num_local_participants, bool may_skip_rendezvous); + size_t num_local_participants, + const NcclClique::AcquiredCliquesMap& acquired_cliques, + int64_t max_nchannels = 0); } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/nccl_clique_key.cc b/third_party/xla/xla/service/gpu/nccl_clique_key.cc index 377c589679e29c..d8348d9d7a0c7a 100644 --- a/third_party/xla/xla/service/gpu/nccl_clique_key.cc +++ b/third_party/xla/xla/service/gpu/nccl_clique_key.cc @@ -26,7 +26,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/types/span.h" #include "xla/service/global_device_id.h" @@ -37,8 +37,10 @@ namespace xla::gpu { //===----------------------------------------------------------------------===// NcclCliqueKey::NcclCliqueKey(std::vector devices, - int64_t stream_id) - : devices_(std::move(devices)), stream_id_(stream_id) {} + int64_t stream_id, AsyncStreamKind stream_kind) + : devices_(std::move(devices)), + stream_id_(stream_id), + stream_kind_(stream_kind) {} absl::Span NcclCliqueKey::devices() const { return devices_; @@ -51,9 +53,16 @@ std::optional NcclCliqueKey::rank(GlobalDeviceId id) const { return std::nullopt; } +bool NcclCliqueKey::IsSubsetOf(const NcclCliqueKey& other) const { + return stream_id_ == other.stream_id_ && + absl::c_all_of(devices_, [&](GlobalDeviceId id) { + return absl::c_linear_search(other.devices_, id); + }); +} + std::string NcclCliqueKey::ToString() const { - return absl::StrCat("devices=", GlobalDeviceIdsToString(devices_), - "; stream=", stream_id_); + return absl::StrFormat("devices=[%s]; stream=%d", + GlobalDeviceIdsToString(devices_), stream_id_); } bool operator==(const NcclCliqueKey& a, const NcclCliqueKey& b) { @@ -61,10 +70,25 @@ bool operator==(const NcclCliqueKey& a, const NcclCliqueKey& b) { } bool operator<(const NcclCliqueKey& a, const NcclCliqueKey& b) { - if (a.stream_id_ < b.stream_id_) return true; - if (b.stream_id_ < a.stream_id_) return false; + if (a.devices_.size() < b.devices_.size()) return true; + if (b.devices_.size() < a.devices_.size()) return false; + + if (a.devices_ < b.devices_) return true; + if (b.devices_ < a.devices_) return false; + + return a.stream_id_ < b.stream_id_; +} + +bool operator>(const NcclCliqueKey& a, const NcclCliqueKey& b) { + if (a.devices_.size() > b.devices_.size()) return true; + if (b.devices_.size() > a.devices_.size()) return false; + + if (a.devices_ > b.devices_) return true; + if (b.devices_ > a.devices_) return false; - return a.devices_ < b.devices_; + // We still use `<` to order by stream id as we want to acquire sync cliques + // before async ones. + return a.stream_id_ < b.stream_id_; } //===----------------------------------------------------------------------===// @@ -80,8 +104,8 @@ NcclCliqueId::NcclCliqueId(char bytes[kSize]) { absl::StatusOr NcclCliqueId::FromString(std::string_view str) { if (str.size() != kSize) { return absl::InvalidArgumentError( - absl::StrCat("Invalid NCCL clique id size: ", str.size(), ", expected ", - kSize, " bytes")); + absl::StrFormat("Invalid NCCL clique id size: %d , expected %d bytes", + str.size(), kSize)); } char bytes[kSize]; std::copy(str.data(), str.data() + kSize, bytes); diff --git a/third_party/xla/xla/service/gpu/nccl_clique_key.h b/third_party/xla/xla/service/gpu/nccl_clique_key.h index ef4df883307372..46479d6027dd87 100644 --- a/third_party/xla/xla/service/gpu/nccl_clique_key.h +++ b/third_party/xla/xla/service/gpu/nccl_clique_key.h @@ -73,14 +73,24 @@ inline uint64_t GetStreamId( // executable. class NcclCliqueKey { public: - explicit NcclCliqueKey(std::vector devices, - int64_t stream_id = 0); + explicit NcclCliqueKey( + std::vector devices, int64_t stream_id = 0, + AsyncStreamKind stream_kind = AsyncStreamKind::kCollective); absl::Span devices() const; // Returns the rank of the global device in the clique. std::optional rank(GlobalDeviceId id) const; + // Returns true if this clique is a subset of `other`: both cliques have the + // same `stream_id` and all clique devices are part of `other` clique. + bool IsSubsetOf(const NcclCliqueKey& other) const; + + // Returns the stream kind for this clique key, + // stream kind will be used to specify what configuration + // to pass for each type of operation. + AsyncStreamKind stream_kind() const { return stream_kind_; } + std::string ToString() const; template @@ -88,10 +98,12 @@ class NcclCliqueKey { friend bool operator==(const NcclCliqueKey& a, const NcclCliqueKey& b); friend bool operator<(const NcclCliqueKey& a, const NcclCliqueKey& b); + friend bool operator>(const NcclCliqueKey& a, const NcclCliqueKey& b); private: const std::vector devices_; const int64_t stream_id_; + AsyncStreamKind stream_kind_; }; template diff --git a/third_party/xla/xla/service/gpu/nccl_clique_key_test.cc b/third_party/xla/xla/service/gpu/nccl_clique_key_test.cc new file mode 100644 index 00000000000000..bfee6e7174dc4f --- /dev/null +++ b/third_party/xla/xla/service/gpu/nccl_clique_key_test.cc @@ -0,0 +1,72 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/nccl_clique_key.h" + +#include +#include + +#include "absl/container/btree_map.h" +#include "xla/service/global_device_id.h" +#include "tsl/platform/test.h" + +namespace xla::gpu { + +TEST(NcclCliqueKeyTest, IsSubsetOf) { + GlobalDeviceId id0 = GlobalDeviceId(0); + GlobalDeviceId id1 = GlobalDeviceId(1); + GlobalDeviceId id2 = GlobalDeviceId(2); + GlobalDeviceId id3 = GlobalDeviceId(3); + + NcclCliqueKey key0({id0, id1}, 0); + NcclCliqueKey key1({id0, id1, id2, id3}, 0); + NcclCliqueKey key2({id0, id1, id2, id3}, 1); + NcclCliqueKey key3({id1, id2, id3}, 0); + + EXPECT_TRUE(key0.IsSubsetOf(key1)); + EXPECT_FALSE(key0.IsSubsetOf(key2)); + EXPECT_FALSE(key0.IsSubsetOf(key3)); +} + +TEST(NcclCliqueKeyTest, LargerCliqueGoFirst) { + GlobalDeviceId id0 = GlobalDeviceId(0); + GlobalDeviceId id1 = GlobalDeviceId(1); + GlobalDeviceId id2 = GlobalDeviceId(2); + GlobalDeviceId id3 = GlobalDeviceId(3); + + NcclCliqueKey key0({id0, id1}, 0); + NcclCliqueKey key1({id1, id2, id3}, 0); + + EXPECT_LT(key0, key1); + EXPECT_GT(key1, key0); +} + +TEST(NcclCliqueKeyTest, BtreeIterationOrder) { + GlobalDeviceId id0 = GlobalDeviceId(0); + GlobalDeviceId id1 = GlobalDeviceId(1); + GlobalDeviceId id2 = GlobalDeviceId(2); + GlobalDeviceId id3 = GlobalDeviceId(3); + + NcclCliqueKey key0({id0, id2}, 0); + NcclCliqueKey key1({id0, id1, id2, id3}, 0); + + absl::btree_map> map; + map[key0] = 0; + map[key1] = 1; + + EXPECT_EQ(map.begin()->first, key1); +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/nccl_collective_permute_thunk.cc b/third_party/xla/xla/service/gpu/nccl_collective_permute_thunk.cc index a46eb0279b5807..8ccef7e1586be0 100644 --- a/third_party/xla/xla/service/gpu/nccl_collective_permute_thunk.cc +++ b/third_party/xla/xla/service/gpu/nccl_collective_permute_thunk.cc @@ -302,7 +302,7 @@ absl::Status RunCollectivePermute( // buffer. VLOG(3) << absl::StreamFormat("%s : collective-Permute: Issuing MemZero", device_string); - stream.ThenMemZero(&dest_addr, dest_addr.size()); + TF_RETURN_IF_ERROR(stream.MemZero(&dest_addr, dest_addr.size())); } return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/gpu/nccl_collective_thunk.cc b/third_party/xla/xla/service/gpu/nccl_collective_thunk.cc index 54615fcfe6d9df..2b6ce21f1404c7 100644 --- a/third_party/xla/xla/service/gpu/nccl_collective_thunk.cc +++ b/third_party/xla/xla/service/gpu/nccl_collective_thunk.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -32,6 +33,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" +#include "absl/time/time.h" #include "mlir/IR/Value.h" // from @llvm-project #include "xla/hlo/ir/hlo_instructions.h" #include "xla/layout_util.h" @@ -45,9 +47,9 @@ limitations under the License. #include "xla/service/gpu/nccl_clique.h" #include "xla/service/gpu/nccl_clique_key.h" #include "xla/service/gpu/thunk.h" +#include "xla/service/rendezvous.h" #include "xla/shape.h" #include "xla/status.h" -#include "xla/status_macros.h" #include "xla/statusor.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/gpu/gpu_activation.h" @@ -56,6 +58,11 @@ limitations under the License. #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" +#if GOOGLE_CUDA +#include "xla/stream_executor/gpu/gpu_driver.h" +#include "xla/stream_executor/gpu/gpu_types.h" +#endif // GOOGLE_CUDA + namespace xla { namespace gpu { namespace { @@ -210,11 +217,11 @@ NcclCollectiveThunk::NcclCollectiveThunk(Kind kind, ThunkInfo thunk_info, nccl_api_(nccl_api), async_events_(is_sync ? nullptr : new AsyncEvents()) {} -absl::StatusOr GetNcclComm( +static absl::StatusOr GetNcclCliqueKey( const Thunk::CollectiveExecuteParams& params, - const Thunk::CollectiveCliques& collective_cliques, const std::vector& replica_groups, - CollectiveOpGroupMode group_mode, int64_t stream_id) { + CollectiveOpGroupMode group_mode, int64_t stream_id, + AsyncStreamKind stream_kind) { GlobalDeviceId global_device_id = params.global_device_id; TF_ASSIGN_OR_RETURN( @@ -229,59 +236,21 @@ absl::StatusOr GetNcclComm( "environment configuration."); } - NcclCliqueKey clique_key(std::move(participants), stream_id); - std::optional rank = clique_key.rank(global_device_id); - - return collective_cliques.GetComm(std::move(clique_key), *rank); + return NcclCliqueKey(std::move(participants), stream_id, stream_kind); } -// TODO(ezhulenev): This is a deprecated code path and should be removed after -// all users in legacy XLA runtime are removed. -absl::StatusOr LockNcclComm( +absl::StatusOr GetNcclComm( const Thunk::CollectiveExecuteParams& params, + const Thunk::CollectiveCliques& collective_cliques, const std::vector& replica_groups, - CollectiveOpGroupMode group_mode, int64_t op_id, int64_t stream_id, - bool enable_clique_optimization) { - GlobalDeviceId global_device_id = params.global_device_id; - - TF_ASSIGN_OR_RETURN( - std::vector participants, - GetParticipatingDevices(global_device_id, *params.device_assn, - replica_groups, group_mode)); - - if (IsGlobalNcclConfig() && - (participants.size() != params.device_assn->replica_count())) { - return InvalidArgument( - "Partial replica groups are not allowed when using NCCL_COMM_ID " - "environment configuration."); - } - - auto it = absl::c_find(participants, global_device_id); - TF_RET_CHECK(it != participants.end()); - int rank = it - participants.begin(); - - std::vector local_devices; - if (params.global_device_id_map) { - local_devices.reserve(params.global_device_id_map->size()); - for (const auto& entry : *params.global_device_id_map) { - local_devices.push_back(entry.second); - } - } - size_t num_local_participants = GetNumLocalParticipants( - participants, params.global_device_id_map ? &local_devices : nullptr); - - bool is_local = participants.size() == num_local_participants; - TF_ASSIGN_OR_RETURN( - const NcclCliqueIdCallback* clique_id_callback, - GetNcclCliqueIdCallback(params.nccl_clique_id_callback, is_local)); + CollectiveOpGroupMode group_mode, int64_t stream_id, + AsyncStreamKind stream_kind) { + TF_ASSIGN_OR_RETURN(NcclCliqueKey clique_key, + GetNcclCliqueKey(params, replica_groups, group_mode, + stream_id, stream_kind)); -#ifdef GOOGLE_CUDA - se::gpu::ScopedActivateExecutorContext scoped_context(params.stream_executor); -#endif // GOOGLE_CUDA - - return AcquireNcclComm(params.run_id, OpId(op_id), std::move(participants), - num_local_participants, *clique_id_callback, rank, - stream_id, enable_clique_optimization); + std::optional rank = clique_key.rank(params.global_device_id); + return collective_cliques.GetComm(std::move(clique_key), *rank); } absl::StatusOr> ConvertToDeviceBuffers( @@ -311,39 +280,59 @@ absl::StatusOr> ConvertToDeviceBuffers( return device_buffers; } -Status MaybeRegisterBuffers(NcclApi* nccl_api, int device_ordinal, - const std::vector& buffers, - NcclApi::NcclCommHandle comm) { +Status RegisterBufferOnce(NcclApi* nccl_api, int device_ordinal, + NcclApi::NcclCommHandle comm, + se::DeviceMemoryBase buffer) { // Keep track of which communicators we have registered for already. - // Each device has one NCCL buffer which only needs to be registered once per - // each comm. + // Each ncclMemAlloc'd buffer needs to be registered once per comm. struct RegisteredBuffers { absl::Mutex mu; - absl::flat_hash_map> - per_device_comms ABSL_GUARDED_BY(mu); + // Device ordinal, communicator, and base pointer address. + absl::flat_hash_set> records + ABSL_GUARDED_BY(mu); // Buffers could be deregistered with ncclCommDeregister. std::vector handles ABSL_GUARDED_BY(mu); }; static auto& all_registered = *new RegisteredBuffers; + // Since each XLA buffer is a slice into a larger BFCAllocator chunk, first + // get the base address of buffer. We will use the base address to keep track + // of which chunks we have registered. + void* base_ptr; + size_t base_size; +#ifdef GOOGLE_CUDA + TF_RETURN_IF_ERROR(se::gpu::GpuDriver::GetPointerAddressRange( + reinterpret_cast(buffer.opaque()), + reinterpret_cast(&base_ptr), &base_size)); +#else // GOOGLE_CUDA + base_ptr = nullptr; + base_size = 0; +#endif // GOOGLE_CUDA + absl::MutexLock lock(&all_registered.mu); + if (!all_registered.records.contains({device_ordinal, comm, base_ptr})) { + // ncclCommRegister will internally get and use the base address/size of the + // address we provide. + TF_ASSIGN_OR_RETURN(NcclApi::NcclRegisteredBufferHandle handle, + nccl_api->RegisterBuffer(comm, buffer)); + all_registered.handles.push_back(handle); + all_registered.records.insert({device_ordinal, comm, base_ptr}); + } + return OkStatus(); +} + +Status MaybeRegisterBuffers(NcclApi* nccl_api, int device_ordinal, + const std::vector& buffers, + NcclApi::NcclCommHandle comm) { for (int i = 0; i < buffers.size(); ++i) { - if (!all_registered.per_device_comms[device_ordinal].contains(comm)) { - if (buffers[i].source_memory_space == kCollectiveMemorySpaceColor) { - TF_ASSIGN_OR_RETURN( - NcclApi::NcclRegisteredBufferHandle handle, - nccl_api->RegisterBuffer(comm, buffers[i].source_buffer)); - all_registered.handles.push_back(handle); - all_registered.per_device_comms[device_ordinal].insert(comm); - } - if (buffers[i].destination_memory_space == kCollectiveMemorySpaceColor) { - TF_ASSIGN_OR_RETURN( - NcclApi::NcclRegisteredBufferHandle handle, - nccl_api->RegisterBuffer(comm, buffers[i].destination_buffer)); - all_registered.handles.push_back(handle); - all_registered.per_device_comms[device_ordinal].insert(comm); - } + if (buffers[i].source_memory_space == kCollectiveMemorySpaceColor) { + TF_RETURN_IF_ERROR(RegisterBufferOnce(nccl_api, device_ordinal, comm, + buffers[i].source_buffer)); + } + if (buffers[i].destination_memory_space == kCollectiveMemorySpaceColor) { + TF_RETURN_IF_ERROR(RegisterBufferOnce(nccl_api, device_ordinal, comm, + buffers[i].destination_buffer)); } } return OkStatus(); @@ -398,9 +387,9 @@ absl::Status NcclCollectiveThunk::Prepare(const PrepareParams& params, size_t num_local_participants = GetNumLocalParticipants( participants, collectives->global_device_id_map ? &local_devices : nullptr); - + AsyncStreamKind stream_kind = GetAsyncStreamKind(); return resource_requests.AddClique( - NcclCliqueKey(std::move(participants), GetStreamId()), + NcclCliqueKey(std::move(participants), GetStreamId(), stream_kind), num_local_participants); } @@ -411,48 +400,88 @@ absl::Status NcclCollectiveThunk::Initialize(const InitializeParams& params) { return absl::OkStatus(); } +namespace { +// Wrap NcclCliqueKey into a unique struct to guarantee we do not accidentally +// try to run multiple unrelated rendezvous for a same key. +struct FirstCallRendezvousKey { + NcclCliqueKey clique_key; + + template + friend H AbslHashValue(H h, const FirstCallRendezvousKey& key) { + return H::combine(std::move(h), key.clique_key); + } +}; + +bool operator==(const FirstCallRendezvousKey& a, + const FirstCallRendezvousKey& b) { + return a.clique_key == b.clique_key; +} +} // namespace + Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams& params) { VLOG(1) << absl::StreamFormat("Starting %s %s.", IsAsync() ? "async" : "sync", Thunk::KindToString(kind())); const int64_t stream_id = GetStreamId(); + AsyncStreamKind stream_kind = GetAsyncStreamKind(); TF_ASSIGN_OR_RETURN( - NcclComm::Lock comm, + NcclApi::NcclCommHandle comm, GetNcclComm(*params.collective_params, *params.collective_cliques, - config().replica_groups, config().group_mode, stream_id)); + config().replica_groups, config().group_mode, stream_id, + stream_kind)); se::StreamExecutor* executor = params.stream->parent(); - int64_t async_stream_idx = static_cast(GetAsyncStreamKind()); + int64_t async_stream_idx = static_cast(stream_kind); if (IsAsync()) { // Launch collective operation on an async stream. se::Stream& async_stream = *params.async_comms_streams[async_stream_idx]; // Wait for main compute stream to make sure all buffers are ready. - async_stream.ThenWaitFor(params.stream); + TF_RETURN_IF_ERROR(async_stream.WaitFor(params.stream)); - TF_RETURN_IF_ERROR(RunNcclCollective(params, async_stream, *comm)); + TF_RETURN_IF_ERROR(RunNcclCollective(params, async_stream, comm)); // Record collective operation completion. TF_ASSIGN_OR_RETURN(se::Event * event, async_events_->GetEvent(executor)); - async_stream.ThenRecordEvent(event); + TF_RETURN_IF_ERROR(async_stream.RecordEvent(event)); } else { // Launch collective operation on a main stream. - TF_RETURN_IF_ERROR(RunNcclCollective(params, *params.stream, *comm)); + TF_RETURN_IF_ERROR(RunNcclCollective(params, *params.stream, comm)); } - // Block host on the first call to ensure that all devices have allocated the - // required buffers for their communicators before allowing any device to - // continue enqueuing operations. Otherwise, the allocations can cause - // deadlock in the CUDA driver (b/215649390). - // - // TODO(ezhulenev): This can be removed with shared cliques acquisition. - if (first_call_to_execute_) { - se::Stream* stream = IsAsync() - ? params.async_comms_streams[async_stream_idx] - : params.stream; - TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); - first_call_to_execute_ = false; + // After a first execution of this instance of collective operation do a + // rendezvous with other participants to make sure that all of them allocated + // required state (internal to NCCL) and ready to continue. Going too far + // ahead on one rank leads to deadlocks in NCCL. + if (NeedFirstCallRendzevous() && !first_call_rendezvous_flag_.IsCompleted()) { + TF_ASSIGN_OR_RETURN( + NcclCliqueKey clique_key, + GetNcclCliqueKey(*params.collective_params, config().replica_groups, + config().group_mode, stream_id, stream_kind)); + + TF_ASSIGN_OR_RETURN( + size_t num_local_participants, + params.collective_cliques->num_communicators(clique_key)); + + auto global_device_id = params.collective_params->global_device_id; + VLOG(1) << "Do a rendezvous after a first call to " + << Thunk::KindToString(kind()) + << "; run_id=" << params.collective_params->run_id.ToInt() + << "; op_id=" << config().op_id + << "; num_local_participants=" << num_local_participants + << "; rank=" << clique_key.rank(global_device_id).value_or(-1) + << "; clique_key=" << clique_key.ToString(); + + auto rendezvous_key = FirstCallRendezvousKey{std::move(clique_key)}; + auto rendezvous_name = absl::StrFormat( + "first call to collective operation %d; run_id=%d", config().op_id, + params.collective_params->run_id.ToInt()); + + RendezvousSingle(first_call_rendezvous_flag_, rendezvous_name, + rendezvous_key, num_local_participants, + /*warn_stuck_timeout=*/absl::Seconds(10), + /*terminate_timeout=*/absl::Seconds(30)); } return absl::OkStatus(); @@ -479,8 +508,7 @@ absl::Status NcclCollectiveDoneThunk::ExecuteOnStream( const ExecuteParams& params) { se::StreamExecutor* executor = params.stream->parent(); TF_ASSIGN_OR_RETURN(se::Event * event, async_events_->GetEvent(executor)); - params.stream->ThenWaitFor(event); - return absl::OkStatus(); + return params.stream->WaitFor(event); } absl::Status IsValidOperand(mlir::Value operand, Thunk::Kind reduction_op) { diff --git a/third_party/xla/xla/service/gpu/nccl_collective_thunk.h b/third_party/xla/xla/service/gpu/nccl_collective_thunk.h index 9e9adadc3b1b84..c687dc78de97ce 100644 --- a/third_party/xla/xla/service/gpu/nccl_collective_thunk.h +++ b/third_party/xla/xla/service/gpu/nccl_collective_thunk.h @@ -36,13 +36,13 @@ limitations under the License. #include "xla/service/collective_ops_utils.h" #include "xla/service/global_device_id.h" #include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/gpu_executable_run_options.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/nccl_api.h" #include "xla/service/gpu/nccl_clique.h" #include "xla/service/gpu/nccl_clique_key.h" #include "xla/service/gpu/thunk.h" #include "xla/service/llvm_ir/llvm_util.h" +#include "xla/service/rendezvous.h" #include "xla/shape.h" #include "xla/status.h" #include "xla/stream_executor/device_memory.h" @@ -169,16 +169,34 @@ class NcclCollectiveThunk : public Thunk { return AsyncStreamKind::kCollective; } + // A collective thunk is normally an independent operation in a sense that + // different instances of the same collective thunk communicate each other. + // The only exception are SendThunk and RecvThunk. Assume two devices are + // executing a program contains the following instructions, the Recv from + // device 1 will release the Send from device 0. Adding first call + // rendezvous on the SendThunk would cause a runtime deadlock. + // Send(src_target={0,1}) + // Recv(src_target={0,1}) + virtual bool NeedFirstCallRendzevous() const { return true; } + private: bool IsAsync() const { return async_events_ != nullptr; } int64_t GetStreamId() const { return xla::gpu::GetStreamId(IsAsync(), GetAsyncStreamKind()); } - bool first_call_to_execute_ = true; - NcclApi* nccl_api_; std::shared_ptr async_events_; + + // After a first call to this particular instance of a NCCL collective thunk + // we do a round of rendezvous to make sure that all participants successfully + // allocated on-device state required for executing collective operation. This + // is required to avoid deadlocks when one device goes too far ahead and + // causes a deadlock in CUDA driver (root cause is mysterious). + // + // TODO(ezhulenev): Try to move this flag to NCCL clique as we need to make + // sure that all NCCL resources are allocated just once. + RendezvousSingleFlag first_call_rendezvous_flag_; }; //===----------------------------------------------------------------------===// @@ -239,26 +257,18 @@ size_t GetNumLocalParticipants( const std::vector& participants, const std::vector* local_devices); // may be null -absl::StatusOr GetNcclComm( +absl::StatusOr GetNcclComm( const Thunk::CollectiveExecuteParams& params, const Thunk::CollectiveCliques& collective_cliques, const std::vector& replica_groups, - CollectiveOpGroupMode group_mode, int64_t stream_id); - -// TODO(ezhulenev): This is a deprecated code path and should be removed after -// all users in legacy XLA runtime are removed. -absl::StatusOr LockNcclComm( - const Thunk::CollectiveExecuteParams& params, - const std::vector& replica_groups, - CollectiveOpGroupMode group_mode, int64_t op_id, int64_t stream_id, - bool enable_clique_optimization); + CollectiveOpGroupMode group_mode, int64_t stream_id, + AsyncStreamKind stream_kind); struct DeviceBufferPair { PrimitiveType element_type; int64_t element_count; se::DeviceMemoryBase source_buffer; se::DeviceMemoryBase destination_buffer; - // TODO(b/320767790): Remove once memory space added to DeviceMemoryBase. int64_t source_memory_space; int64_t destination_memory_space; }; diff --git a/third_party/xla/xla/service/gpu/nccl_recv_thunk.cc b/third_party/xla/xla/service/gpu/nccl_recv_thunk.cc index f8218c4c6188c3..346b876a3052b3 100644 --- a/third_party/xla/xla/service/gpu/nccl_recv_thunk.cc +++ b/third_party/xla/xla/service/gpu/nccl_recv_thunk.cc @@ -139,7 +139,7 @@ absl::Status RunRecv(NcclApi* nccl_api, // the destination buffer. VLOG(3) << absl::StreamFormat("%s : collective-Permute: Issuing MemZero", device_string); - stream.ThenMemZero(&dest_addr, dest_addr.size()); + TF_RETURN_IF_ERROR(stream.MemZero(&dest_addr, dest_addr.size())); } return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/gpu/nccl_recv_thunk.h b/third_party/xla/xla/service/gpu/nccl_recv_thunk.h index b9fb27ef0ea079..9adb6118076619 100644 --- a/third_party/xla/xla/service/gpu/nccl_recv_thunk.h +++ b/third_party/xla/xla/service/gpu/nccl_recv_thunk.h @@ -55,6 +55,7 @@ class NcclRecvThunk : public NcclCollectiveThunk { AsyncStreamKind GetAsyncStreamKind() const override { return AsyncStreamKind::kP2P; } + bool NeedFirstCallRendzevous() const override { return false; } private: const NcclP2PConfig config_; diff --git a/third_party/xla/xla/service/gpu/nccl_send_thunk.h b/third_party/xla/xla/service/gpu/nccl_send_thunk.h index 89122079414b89..192679bfcd1109 100644 --- a/third_party/xla/xla/service/gpu/nccl_send_thunk.h +++ b/third_party/xla/xla/service/gpu/nccl_send_thunk.h @@ -53,6 +53,7 @@ class NcclSendThunk : public NcclCollectiveThunk { AsyncStreamKind GetAsyncStreamKind() const override { return AsyncStreamKind::kP2P; } + bool NeedFirstCallRendzevous() const override { return false; } private: const NcclP2PConfig config_; diff --git a/third_party/xla/xla/service/gpu/non_atomically_upgradeable_rw_lock.h b/third_party/xla/xla/service/gpu/non_atomically_upgradeable_rw_lock.h deleted file mode 100644 index 7a6ad29f2087f7..00000000000000 --- a/third_party/xla/xla/service/gpu/non_atomically_upgradeable_rw_lock.h +++ /dev/null @@ -1,95 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_NON_ATOMICALLY_UPGRADEABLE_RW_LOCK_H_ -#define XLA_SERVICE_GPU_NON_ATOMICALLY_UPGRADEABLE_RW_LOCK_H_ - -#include -#include - -#include "absl/synchronization/mutex.h" - -namespace xla { -namespace gpu { - -// Augments absl::ReaderMutexLock with a poor man's upgrade/downgrade pair using -// RAII. Instead of a true upgrade (or downgrade), we simply drop the read -// (write) lock and then reacquire it as a write (read) lock. -class ABSL_SCOPED_LOCKABLE NonAtomicallyUpgradeableRWLock { - public: - explicit NonAtomicallyUpgradeableRWLock(absl::Mutex* mu) - ABSL_SHARED_LOCK_FUNCTION(mu) - : mu_(mu), is_reader_(true) { - mu_->ReaderLock(); - } - - NonAtomicallyUpgradeableRWLock(const NonAtomicallyUpgradeableRWLock&) = - delete; - NonAtomicallyUpgradeableRWLock(NonAtomicallyUpgradeableRWLock&&) = delete; - NonAtomicallyUpgradeableRWLock& operator=( - const NonAtomicallyUpgradeableRWLock&) = delete; - NonAtomicallyUpgradeableRWLock& operator=(NonAtomicallyUpgradeableRWLock&&) = - delete; - - ~NonAtomicallyUpgradeableRWLock() ABSL_UNLOCK_FUNCTION() { - if (is_reader_) { - mu_->ReaderUnlock(); - } else { - mu_->WriterUnlock(); - } - } - - // Upgrade and downgrade the reader lock via RAII. - class ABSL_SCOPED_LOCKABLE WriterLock { - public: - explicit WriterLock(NonAtomicallyUpgradeableRWLock* parent) - ABSL_EXCLUSIVE_LOCK_FUNCTION(parent->mu_) - : parent_(parent) { - assert(parent_->is_reader_); - parent_->mu_->ReaderUnlock(); - parent_->mu_->WriterLock(); - parent_->is_reader_ = false; - } - - WriterLock(const WriterLock&) = delete; - WriterLock(WriterLock&&) = delete; - WriterLock& operator=(const WriterLock&) = delete; - WriterLock& operator=(WriterLock&&) = delete; - - ~WriterLock() ABSL_UNLOCK_FUNCTION() { - parent_->mu_->WriterUnlock(); - parent_->mu_->ReaderLock(); - parent_->is_reader_ = true; - } - - private: - NonAtomicallyUpgradeableRWLock* parent_; - }; - - // Update the reader lock to a writer lock. The function is invalid if the - // lock is already upgraded. - WriterLock UpgradeToWriterMutexLock() ABSL_NO_THREAD_SAFETY_ANALYSIS { - return WriterLock(this); - } - - private: - absl::Mutex* const mu_; - bool is_reader_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_NON_ATOMICALLY_UPGRADEABLE_RW_LOCK_H_ diff --git a/third_party/xla/xla/service/gpu/non_atomically_upgradeable_rw_lock_test.cc b/third_party/xla/xla/service/gpu/non_atomically_upgradeable_rw_lock_test.cc deleted file mode 100644 index d163cb77ac5cc8..00000000000000 --- a/third_party/xla/xla/service/gpu/non_atomically_upgradeable_rw_lock_test.cc +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/non_atomically_upgradeable_rw_lock.h" - -#include -#include "tsl/platform/test.h" - -namespace xla { -namespace gpu { -namespace { - -TEST(NonAtomicallyUpgradeableRWLock, UpgradeReaderMutexLock) { - absl::Mutex mu; - { - NonAtomicallyUpgradeableRWLock reader_lock(&mu); - mu.AssertReaderHeld(); - - { - NonAtomicallyUpgradeableRWLock::WriterLock writer_lock = - reader_lock.UpgradeToWriterMutexLock(); - mu.AssertHeld(); - } - - // The lock downgrades after the WriterLock goes out of scope. - mu.AssertReaderHeld(); - } - mu.AssertNotHeld(); -} - -} // namespace -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.cc b/third_party/xla/xla/service/gpu/nvptx_compiler.cc index 5f9848b3622ba6..9b4764afa8e6eb 100644 --- a/third_party/xla/xla/service/gpu/nvptx_compiler.cc +++ b/third_party/xla/xla/service/gpu/nvptx_compiler.cc @@ -157,7 +157,7 @@ absl::Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization( /*layout_sensitive=*/false, /*allow_mixed_precision=*/false); - // Convert upsupported bf16 convolutions to f32. + // Convert unsupported bf16 convolutions to f32. ConvBfloat16Support conv_bf16_support(dnn_version, cuda_compute_capability); pipeline.AddPass(&conv_bf16_support); @@ -174,7 +174,8 @@ absl::Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization( pipeline.AddPass(); pipeline.AddPass(); - AlgebraicSimplifierOptions algsimp_options; + AlgebraicSimplifierOptions algsimp_options = + GetAlgebraicSimplifierOptions(hlo_module->config()); algsimp_options.set_enable_conv_operand_swap(false); algsimp_options.set_enable_unconditional_reduce_of_concat_replacement(false); pipeline.AddPass>(algsimp_options); @@ -234,7 +235,8 @@ absl::Status NVPTXCompiler::OptimizeHloPostLayoutAssignment( const DebugOptions& debug_options = hlo_module->config().debug_options(); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. - AlgebraicSimplifierOptions alg_sim_options; + AlgebraicSimplifierOptions alg_sim_options = + GetAlgebraicSimplifierOptions(hlo_module->config()); alg_sim_options.set_supports_non_canonical_dots(false); alg_sim_options.set_is_layout_sensitive(true); alg_sim_options.set_enable_conv_operand_swap(false); @@ -514,8 +516,7 @@ NVPTXCompiler::CompileTargetBinary(const HloModuleConfig& module_config, (debug_module != nullptr ? debug_module->name() : "(unknown)"), relocatable, options); - if (maybe_cubin.status().code() == absl::StatusCode::kCancelled || - maybe_cubin.status().code() == absl::StatusCode::kResourceExhausted) { + if (!maybe_cubin.ok()) { return maybe_cubin.status(); } return BackendCompileResult{std::move(ptx), std::move(maybe_cubin.value())}; @@ -613,13 +614,10 @@ static absl::StatusOr> AssembleOptionsAndCompile( } if (maybe_cubin.status().code() != absl::StatusCode::kUnimplemented) { - // If unimplemented is returned, we fallback to the driver. - LOG(FATAL) << "ptxas returned an error during compilation of ptx " - "to sass: '" - << maybe_cubin.status() << "' " - << "If the error message indicates that a file could " - "not be written, please verify that sufficient " - "filesystem space is provided."; + return AppendStatus( + maybe_cubin.status(), + "If the error message indicates that a file could not be written, " + "please verify that sufficient filesystem space is provided."); } return maybe_cubin; @@ -638,42 +636,45 @@ NVPTXCompiler::CompileGpuAsmOrGetCachedResult( !options.is_autotuning_compilation); tsl::profiler::TraceMe activity("PTX->CUBIN", tsl::profiler::TraceMeLevel::kInfo); - auto [iter, inserted] = [&] { + CompilationCacheValue* cache_value = nullptr; + bool inserted = [&] { + auto flags = CompilationCacheFlags{ + hlo_module_config.debug_options() + .xla_gpu_filter_kernels_spilling_registers_on_autotuning()}; absl::MutexLock lock(&mutex_); - return compilation_cache_.emplace( + auto [iter, inserted] = compilation_cache_.emplace( std::piecewise_construct, - std::forward_as_tuple(ptx, cc.major, cc.minor, relocatable), + std::forward_as_tuple(ptx, cc.major, cc.minor, relocatable, flags), std::forward_as_tuple()); + // Do not move this assignment outside of the critical section. There is + // a TOCTOU if `compilation_cache_` is rehashed before the iterator is used. + cache_value = &iter->second; + return inserted; }(); - // Pointers into compilation_cache_ where the ptx and (optional) cuBIN are - // stored. - CompilationCacheValue& cache_value = iter->second; - // Compile the ptx if it wasn't in the cache before we called this function. // Other threads asking for the same compilation key will block on // cache_value->mutex_ until compilation is done. - absl::MutexLock lock(&cache_value.mutex); + absl::MutexLock lock(&cache_value->mutex); if (inserted) { - CHECK(!cache_value.compilation_done); - absl::Cleanup mark_compilation_as_done = [&cache_value] { + CHECK(!cache_value->compilation_done); + absl::Cleanup mark_compilation_as_done = [cache_value] { // Note that we will set this to true also in the error case, so that we // don't retry this compilation. - cache_value.compilation_done = true; - cache_value.compilation_done_cv.SignalAll(); + cache_value->compilation_done = true; + cache_value->compilation_done_cv.SignalAll(); }; - TF_ASSIGN_OR_RETURN(cache_value.cubin_data, - AssembleOptionsAndCompile(ptx, cc, hlo_module_config, - options, relocatable)); - return cache_value.cubin_data; + cache_value->maybe_cubin = AssembleOptionsAndCompile( + ptx, cc, hlo_module_config, options, relocatable); + return cache_value->maybe_cubin; } - while (!cache_value.compilation_done) { - cache_value.compilation_done_cv.Wait(&cache_value.mutex); + while (!cache_value->compilation_done) { + cache_value->compilation_done_cv.Wait(&cache_value->mutex); } - return cache_value.cubin_data; + return cache_value->maybe_cubin; } static std::optional> GetNvLinkVersion( @@ -694,26 +695,23 @@ static std::optional> GetNvLinkVersion( } // Make sure nvlink exists and is executable. - const std::string bin_path = + absl::StatusOr bin_path = se::FindCudaExecutable("nvlink", preferred_cuda_dir); - auto version = se::GetToolVersion(bin_path); + + if (!bin_path.ok()) { + return std::nullopt; + } + + auto version = se::GetToolVersion(bin_path.value()); if (!version.ok()) { return std::nullopt; } return *version; } -absl::StatusOr NVPTXCompiler::ChooseLinkingMethod( - const std::string& preferred_cuda_dir) { - { - absl::MutexLock lock(&mutex_); - auto it = linking_methods_.find(preferred_cuda_dir); - if (it != linking_methods_.end()) { - return it->second; - } - } - - LinkingMethod linking_method = LinkingMethod::kNone; +absl::StatusOr ChooseLinkingMethodImpl( + const DebugOptions& debug_options, const std::string& preferred_cuda_dir) { + using LinkingMethod = NVPTXCompiler::LinkingMethod; TF_ASSIGN_OR_RETURN(auto ptxas_version_tuple, se::GetAsmCompilerVersion(preferred_cuda_dir)); @@ -724,33 +722,57 @@ absl::StatusOr NVPTXCompiler::ChooseLinkingMethod( return absl::InternalError("XLA requires ptxas version 11.8 or higher"); } - static const std::optional> nvlink_version = + std::optional> nvlink_version = GetNvLinkVersion(preferred_cuda_dir); if (nvlink_version && *nvlink_version >= ptxas_version_tuple) { - linking_method = LinkingMethod::kNvLink; - } else { - int ptxas_version = std::get<0>(ptxas_version_tuple) * 1000 + - std::get<1>(ptxas_version_tuple) * 10; - TF_ASSIGN_OR_RETURN(int driver_version, - se::gpu::GpuDriver::GetDriverVersion()); + return LinkingMethod::kNvLink; + } + + int ptxas_version = std::get<0>(ptxas_version_tuple) * 1000 + + std::get<1>(ptxas_version_tuple) * 10; + TF_ASSIGN_OR_RETURN(int driver_version, + se::gpu::GpuDriver::GetDriverVersion()); + + if (driver_version >= ptxas_version) { + return LinkingMethod::kDriver; + } + + LOG_FIRST_N(WARNING, 1) + << "The NVIDIA driver's CUDA version is " + << absl::StrFormat("%d.%d", driver_version / 1000, + (driver_version % 1000) / 10) + << " which is older than the ptxas CUDA version " + << absl::StrFormat("(%d.%d.%d)", std::get<0>(ptxas_version_tuple), + std::get<1>(ptxas_version_tuple), + std::get<2>(ptxas_version_tuple)) + << ". Because the driver is older than the ptxas version, XLA is " + "disabling parallel compilation, which may slow down " + "compilation. " + "You should update your NVIDIA driver or use the " + "NVIDIA-provided " + "CUDA forward compatibility packages."; + + return LinkingMethod::kNone; +} - if (driver_version >= ptxas_version) { - linking_method = LinkingMethod::kDriver; - } else { - LOG_FIRST_N(WARNING, 1) - << "The NVIDIA driver's CUDA version is " - << absl::StrFormat("%d.%d", driver_version / 1000, - (driver_version % 1000) / 10) - << " which is older than the ptxas CUDA version " - << absl::StrFormat("(%d.%d.%d)", std::get<0>(ptxas_version_tuple), - std::get<1>(ptxas_version_tuple), - std::get<2>(ptxas_version_tuple)) - << ". Because the driver is older than the ptxas version, XLA is " - "disabling parallel compilation, which may slow down compilation. " - "You should update your NVIDIA driver or use the NVIDIA-provided " - "CUDA forward compatibility packages."; +absl::StatusOr NVPTXCompiler::ChooseLinkingMethod( + const DebugOptions& debug_options) { + se::GpuAsmOpts ptxas_config = PtxOptsFromDebugOptions(debug_options); + std::string& preferred_cuda_dir = ptxas_config.preferred_cuda_dir; + + { + absl::MutexLock lock(&mutex_); + auto it = linking_methods_.find(preferred_cuda_dir); + if (it != linking_methods_.end()) { + return it->second; } } + + // This wrapper only handles caching. The actual choice happens in this call: + TF_ASSIGN_OR_RETURN( + LinkingMethod linking_method, + ChooseLinkingMethodImpl(debug_options, preferred_cuda_dir)); + { absl::MutexLock lock(&mutex_); linking_methods_[preferred_cuda_dir] = linking_method; @@ -762,10 +784,8 @@ absl::StatusOr NVPTXCompiler::CanUseLinkModules( const HloModuleConfig& hlo_module_config) { // TODO(phawkins): rather than comparing version numbers, it might be more // robust if we simply tried to link something the first time we compile. - auto ptxas_config = - PtxOptsFromDebugOptions(hlo_module_config.debug_options()); TF_ASSIGN_OR_RETURN(LinkingMethod linking_method, - ChooseLinkingMethod(ptxas_config.preferred_cuda_dir)); + ChooseLinkingMethod(hlo_module_config.debug_options())); return linking_method != LinkingMethod::kNone; } @@ -783,7 +803,7 @@ absl::StatusOr> NVPTXCompiler::LinkModules( stream_exec->platform_specific_handle().context); TF_ASSIGN_OR_RETURN(LinkingMethod linking_method, - ChooseLinkingMethod(ptxas_config.preferred_cuda_dir)); + ChooseLinkingMethod(debug_options)); if (linking_method == LinkingMethod::kNvLink) { return LinkUsingNvlink(debug_options.xla_gpu_cuda_data_dir(), context, images); diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.h b/third_party/xla/xla/service/gpu/nvptx_compiler.h index fa5afaec7d2fc5..62cdef69bcba24 100644 --- a/third_party/xla/xla/service/gpu/nvptx_compiler.h +++ b/third_party/xla/xla/service/gpu/nvptx_compiler.h @@ -72,6 +72,12 @@ class NVPTXCompiler : public GpuCompiler { se::GpuComputeCapability gpu_version, bool relocatable, const HloModule* debug_module, const CompileOptions& options) override; + enum class LinkingMethod { + kNone, + kNvLink, + kDriver, + }; + private: absl::StatusOr CanUseLinkModules( const HloModuleConfig& module_config) override; @@ -83,16 +89,11 @@ class NVPTXCompiler : public GpuCompiler { absl::Mutex mutex_; - enum class LinkingMethod { - kNone, - kNvLink, - kDriver, - }; absl::flat_hash_map linking_methods_ ABSL_GUARDED_BY(mutex_); absl::StatusOr ChooseLinkingMethod( - const std::string& preferred_cuda_dir); + const DebugOptions& debug_options); // Tries to compile the given ptx string to cubin. Returns a vector with the // compiled cubin if compilation succeeded. @@ -101,6 +102,22 @@ class NVPTXCompiler : public GpuCompiler { const HloModuleConfig& hlo_module_config, absl::string_view module_name, bool relocatable, const CompileOptions& options); + struct CompilationCacheFlags { + template + friend H AbslHashValue(H h, const CompilationCacheFlags& flags) { + return H::combine(std::move(h), + flags.filter_kernels_spilling_registers_on_autotuning); + } + + friend bool operator==(const CompilationCacheFlags& a, + const CompilationCacheFlags& b) { + return a.filter_kernels_spilling_registers_on_autotuning == + b.filter_kernels_spilling_registers_on_autotuning; + } + + bool filter_kernels_spilling_registers_on_autotuning; + }; + // The compilation_cache_ map is a cache from {ptx string, cc_major, cc_minor} // -> cubin so we don't recompile the same ptx twice. This is important for // some interactive workflows. (We also cache at the HLO level, but sometimes @@ -115,29 +132,36 @@ class NVPTXCompiler : public GpuCompiler { // and leave compilation up to the driver. struct CompilationCacheKey { CompilationCacheKey(std::string ptx, int cc_major, int cc_minor, - bool relocatable) + bool relocatable, CompilationCacheFlags flags) : ptx(std::move(ptx)), cc_major(cc_major), cc_minor(cc_minor), - relocatable(relocatable) {} + relocatable(relocatable), + flags(std::move(flags)) {} + template friend H AbslHashValue(H h, const CompilationCacheKey& key) { return H::combine(std::move(h), key.ptx, key.cc_major, key.cc_minor, - key.relocatable); + key.relocatable, key.flags); } + friend bool operator==(const CompilationCacheKey& a, const CompilationCacheKey& b) { return a.cc_major == b.cc_major && a.cc_minor == b.cc_minor && - a.ptx == b.ptx && a.relocatable == b.relocatable; + a.ptx == b.ptx && a.relocatable == b.relocatable && + a.flags == b.flags; } + std::string ptx; int cc_major; int cc_minor; bool relocatable; + CompilationCacheFlags flags; }; + struct CompilationCacheValue { bool compilation_done = false; - std::vector cubin_data; + absl::StatusOr> maybe_cubin; // mutex and condition variable to serialize compilation completing. absl::Mutex mutex; absl::CondVar compilation_done_cv; diff --git a/third_party/xla/xla/service/gpu/parallel_loop_emitter.cc b/third_party/xla/xla/service/gpu/parallel_loop_emitter.cc index 326cdce5d0d719..bb5dbf77df1e2e 100644 --- a/third_party/xla/xla/service/gpu/parallel_loop_emitter.cc +++ b/third_party/xla/xla/service/gpu/parallel_loop_emitter.cc @@ -66,7 +66,8 @@ ParallelLoopEmitter::EmitLinearBaseAndThreadIdx(llvm::Type* index_type, llvm::Value* block_id = EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b_); llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_counts().x, - static_cast(block_id)); + static_cast(block_id), + b_->GetInsertBlock()->getModule()); block_id = b_->CreateZExtOrTrunc(block_id, index_type, "block_id"); // Per the PTX documentation: @@ -74,7 +75,8 @@ ParallelLoopEmitter::EmitLinearBaseAndThreadIdx(llvm::Type* index_type, llvm::Value* thread_id_x = EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b_); llvm_ir::AddRangeMetadata(0, launch_dimensions_.thread_counts_per_block().x, - static_cast(thread_id_x)); + static_cast(thread_id_x), + b_->GetInsertBlock()->getModule()); thread_id_x = b_->CreateZExtOrTrunc(thread_id_x, index_type, "thread_id_x"); llvm::Value* linear_index_base = @@ -88,7 +90,8 @@ ParallelLoopEmitter::EmitLinearBaseAndThreadIdx(llvm::Type* index_type, llvm::Value* thread_id_y = EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdy, {}, {}, b_); llvm_ir::AddRangeMetadata(0, launch_dimensions_.thread_counts_per_block().y, - static_cast(thread_id_y)); + static_cast(thread_id_y), + b_->GetInsertBlock()->getModule()); thread_id_y = b_->CreateZExtOrTrunc(thread_id_y, index_type, "thread_id_y"); linear_index_base = b_->CreateAdd( linear_index_base, diff --git a/third_party/xla/xla/service/gpu/priority_fusion.cc b/third_party/xla/xla/service/gpu/priority_fusion.cc index 8831771dac6656..56837f6befbd76 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion.cc +++ b/third_party/xla/xla/service/gpu/priority_fusion.cc @@ -22,7 +22,7 @@ limitations under the License. #include #include #include -#include +#include #include #include @@ -31,14 +31,16 @@ limitations under the License. #include "absl/log/check.h" #include "absl/meta/type_traits.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" +#include "llvm/ADT/STLExtras.h" +#include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/dump.h" -#include "xla/service/fusion_node_indexing_evaluation.h" #include "xla/service/fusion_queue.h" #include "xla/service/gpu/fusion_process_dump.pb.h" #include "xla/service/gpu/gpu_fusible.h" @@ -47,13 +49,18 @@ limitations under the License. #include "xla/service/gpu/model/fusion_analysis_cache.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/gpu/model/gpu_performance_model.h" +#include "xla/service/gpu/model/gpu_performance_model_base.h" +#include "xla/service/hlo_graph_dumper.h" #include "xla/service/instruction_fusion.h" #include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" #include "xla/xla_data.pb.h" #include "tsl/platform/blocking_counter.h" +#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" +#include "tsl/platform/threadpool.h" namespace xla { namespace gpu { @@ -110,7 +117,7 @@ bool IsFusible(const HloInstruction& instr) { // performance when fusing it to all of its fusible users. We greedily pick the // max-benefit producer to fuse, and update the estimated benefits of the fused // nodes and their operands. -class GpuPriorityFusionQueue : public FusionQueue { +class GpuPriorityFusionQueue { using Priority = int64_t; using CanFuseCallback = std::function; @@ -131,6 +138,11 @@ class GpuPriorityFusionQueue : public FusionQueue { VLOG(2) << "Running full HLO cost analysis for " << computation_->name(); TF_CHECK_OK(computation_->Accept(&cost_analysis_)); + dump_fusion_visualization_ = computation->parent() + ->config() + .debug_options() + .xla_dump_fusion_visualization(); + // Initializes the priority queue. std::vector instructions; for (auto* instruction : computation->MakeInstructionPostOrder()) { @@ -174,18 +186,14 @@ class GpuPriorityFusionQueue : public FusionQueue { return priorities; } - std::pair> - DequeueNextInstructionAndOperandsToFuseInOrder() override { - // When current_consumers_ is empty, we need to dequeue a new producer. - // Update the priorities that changed during the last fusion. - if (current_consumers_.empty()) { - UpdatePriorities(); - } + // Gets the next pair of (producer, consumers) from the queue for fusion. + // Returns true if there is the next producer to fuse, otherwise false. Stores + // the producer and consumers in `current_producer_` and `current_consumers_`. + bool DequeueNextProducer() { + current_producer_ = nullptr; + current_consumers_.clear(); - while (current_consumers_.empty()) { - if (producer_priority_queue_.empty()) { - return {}; - } + while (!producer_priority_queue_.empty() && current_consumers_.empty()) { auto next_it = std::prev(producer_priority_queue_.end()); auto priority = next_it->first.first; @@ -198,6 +206,7 @@ class GpuPriorityFusionQueue : public FusionQueue { if (priority < 0) { continue; } + current_consumers_ = current_producer_->users(); if (current_producer_->opcode() == HloOpcode::kBitcast) { @@ -209,14 +218,7 @@ class GpuPriorityFusionQueue : public FusionQueue { } } - auto next_consumer = current_consumers_.back(); - int64_t producer_operand_index = - next_consumer->operand_index(current_producer_); - current_consumers_.pop_back(); - VLOG(5) << "next: " << next_consumer->name() << "(" << next_consumer - << ") + " << current_producer_->name() << "(" << current_producer_ - << ")"; - return {next_consumer, {producer_operand_index}}; + return !current_consumers_.empty(); } // Update priorities of all affected ops. @@ -256,7 +258,15 @@ class GpuPriorityFusionQueue : public FusionQueue { // Prepares producer and consumer instruction to be fused. Invalidates caches // and writes logs. - void PreFusion(HloInstruction* producer, HloInstruction* consumer) override { + void PreFusion(HloInstruction* producer, HloInstruction* consumer) { + if (dump_fusion_visualization_) { + RegisterFusionState( + *computation_, + absl::StrCat("About to fuse |", producer->name(), "| into |", + consumer->name(), "| inside PriorityFusion"), + *consumer, producer); + } + InvalidateCaches(producer); InvalidateCaches(consumer); } @@ -276,23 +286,12 @@ class GpuPriorityFusionQueue : public FusionQueue { gpu_performance_model_cache_.Invalidate(*instruction); fusion_analysis_cache_.Invalidate(*instruction); - - for (auto* user : instruction->users()) { - fusion_node_evaluations_.erase(user); - } - fusion_node_evaluations_.erase(instruction); } // Updates data for the new fusion instruction and its users and operands. void OnFusingInstruction(HloInstruction* fusion, HloInstruction* original_producer, - HloInstruction* original_consumer) override { - absl::string_view emitter_fusion_kind = - HloFusionAnalysis::GetEmitterFusionKindString( - fusion_analysis_cache_.Get(*fusion).GetEmitterFusionKind()); - fusion->SetAndSanitizeName(absl::StrCat(emitter_fusion_kind, "_fusion")); - fusion->UniquifyName(&fusion->GetModule()->instruction_name_uniquer()); - + HloInstruction* original_consumer) { if (fusion_process_dump_) { auto* fusion_step = fusion_process_dump_->add_fusion_steps()->mutable_fusion(); @@ -303,6 +302,14 @@ class GpuPriorityFusionQueue : public FusionQueue { fusion_step->set_consumer_name(std::string(original_consumer->name())); } + if (dump_fusion_visualization_) { + RegisterFusionState( + *computation_, + absl::StrCat("Fused |", original_producer->name(), "| into |", + fusion->name(), "| inside PriorityFusion"), + *fusion); + } + // The original consumer was replaced with the fusion, but it's pointer can // still be referenced somewhere, for example, in to_update_priority_. // Priority recomputation is called before DCE. Remove all references to @@ -322,8 +329,6 @@ class GpuPriorityFusionQueue : public FusionQueue { // Collect the instructions whose priorities need to be updated. for (HloInstruction* operand : fusion->operands()) { if (operand == original_producer || - original_producer->opcode() == HloOpcode::kBroadcast || - operand->opcode() == HloOpcode::kBroadcast || operand->opcode() == HloOpcode::kConstant || operand->opcode() == HloOpcode::kGetTupleElement) { continue; @@ -340,7 +345,7 @@ class GpuPriorityFusionQueue : public FusionQueue { } // Removes data for the instruction. - void RemoveInstruction(HloInstruction* instruction) override { + void RemoveInstruction(HloInstruction* instruction) { to_update_priority_.erase(instruction); fusion_analysis_cache_.Invalidate(*instruction); @@ -352,7 +357,11 @@ class GpuPriorityFusionQueue : public FusionQueue { reverse_map_.erase(reverse_it); } - const std::vector* FusionConfiguration() override { return nullptr; } + HloInstruction* current_producer() { return current_producer_; } + + const std::vector& current_consumers() { + return current_consumers_; + } private: // Returns the priority of the producer based on its current operands and @@ -370,7 +379,7 @@ class GpuPriorityFusionQueue : public FusionQueue { } // Don't fuse if we can't fuse in all users. - if (auto fusion_decision = CanFuseWithAllUsers(producer); + if (auto fusion_decision = CanFuseWithAllNonBitcastUsers(producer); !fusion_decision) { if (fusion_process_dump_) { absl::MutexLock lock(&fusion_process_dump_mutex_); @@ -413,6 +422,10 @@ class GpuPriorityFusionQueue : public FusionQueue { return "the consumer is not fusible"; } + if (consumer->opcode() == HloOpcode::kBitcast) { + return "not fusing into a single bitcast as consumer"; + } + // Scatter is special as it has no elemental version but is still input // fusible. Block attempts to create scatter fusions we can't codegen. if (auto can_fuse = CanEmitInputFusedScatter(*producer, *consumer); @@ -426,7 +439,9 @@ class GpuPriorityFusionQueue : public FusionQueue { auto contains_significant_reduce = [&](const HloInstruction* instr) { auto fusion = HloFusionAdaptor::ForInstruction(instr); return HloAnyOf(fusion->GetRoots(), *fusion, [](auto node) { - if (node.opcode() != HloOpcode::kReduce) return false; + if (!(node.opcode() == HloOpcode::kReduce && node.shape().IsArray())) { + return false; + } int64_t reduction_size = ShapeUtil::ElementsIn(node.instruction().operand(0)->shape()) / @@ -445,14 +460,13 @@ class GpuPriorityFusionQueue : public FusionQueue { // switch it to the loop emitter. This often occurs during epilog fusion for // reductions, which suffer from limited emitter support. // TODO(b/312686229): Cost model should handle this. - const auto& analysis_fused = - fusion_analysis_cache_.Get(*producer, *consumer); - if (producer->IsInputFusion() && - analysis_fused.GetEmitterFusionKind() == - HloFusionAnalysis::EmitterFusionKind::kLoop) { - const auto& analysis = fusion_analysis_cache_.Get(*producer); - if (analysis.GetEmitterFusionKind() == - HloFusionAnalysis::EmitterFusionKind::kReduction) { + const auto& analysis = fusion_analysis_cache_.Get(*producer); + if (analysis.GetEmitterFusionKind() == + HloFusionAnalysis::EmitterFusionKind::kReduction) { + const auto& analysis_fused = + fusion_analysis_cache_.Get(*producer, *consumer); + if (analysis_fused.GetEmitterFusionKind() == + HloFusionAnalysis::EmitterFusionKind::kLoop) { return "fusion into output of a reduce fusion would create a loop " "fusion"; } @@ -471,18 +485,8 @@ class GpuPriorityFusionQueue : public FusionQueue { // have exponential time/memory requirements for emitting certain fusion // kernels, in which case we don't want to fuse. // TODO(b/119692968): Remove this once we have fixed our fusion emitter. - if (consumer->opcode() == HloOpcode::kFusion) { - absl::MutexLock lock(&fusion_node_evaluations_mutex_); - if (fusion_node_evaluations_.find(consumer) == - fusion_node_evaluations_.end()) { - // We have no cached results for this fusion node yet. Compute it now. - fusion_node_evaluations_.emplace( - consumer, FusionNodeIndexingEvaluation(consumer)); - } - if (fusion_node_evaluations_.at(consumer).CodeDuplicationTooHigh( - producer)) { - return "the fusion would result in an overly large code duplication"; - } + if (cost_analysis_.ProducerConsumerMergedTooLarge(*producer, *consumer)) { + return "the fusion would result in an overly large code duplication"; } // Don't fuse across a root instruction. There are situation when a root @@ -525,13 +529,18 @@ class GpuPriorityFusionQueue : public FusionQueue { return fusion_decision; } - FusionDecision CanFuseWithAllUsers(HloInstruction* producer) { + FusionDecision CanFuseWithAllNonBitcastUsers(HloInstruction* producer) { if (producer->users().empty()) { return "No users to fuse"; } FusionDecision result; + bool has_non_bitcast_user = false; for (const auto& user : producer->users()) { + if (user->opcode() == HloOpcode::kBitcast) { + continue; + } + has_non_bitcast_user = true; if (auto fusion_decision = CanFuseCached(producer, user); !fusion_decision) { VLOG(10) << "Cannot fuse " << producer->name() << " with " @@ -539,6 +548,9 @@ class GpuPriorityFusionQueue : public FusionQueue { return fusion_decision; } } + if (!has_non_bitcast_user) { + return "not fusing because there are only bitcast users"; + } return {}; } @@ -589,11 +601,7 @@ class GpuPriorityFusionQueue : public FusionQueue { GpuPerformanceModelCache gpu_performance_model_cache_; - // Keep track of the number of times each instruction inside a fusion node is - // indexed with different index vectors. - absl::Mutex fusion_node_evaluations_mutex_; - absl::flat_hash_map - fusion_node_evaluations_; + bool dump_fusion_visualization_; }; } // namespace @@ -632,6 +640,14 @@ bool IsSmallConstant(const HloInstruction* instr) { ShapeUtil::ElementsIn(instr->shape()) <= 1; } +bool GpuPriorityFusion::ConsumeFuel(HloInstruction* producer, + HloInstruction* consumer) { + return xla::ConsumeFuel(name(), /*ran_out_of_fuel_msg=*/[&] { + return absl::StrFormat("Not fusing producer %s with consumer %s", + producer->name(), consumer->name()); + }); +}; + absl::StatusOr GpuPriorityFusion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { @@ -639,6 +655,8 @@ absl::StatusOr GpuPriorityFusion::Run( DumpingEnabledForHloPass(name(), module->config().debug_options()); if (dump_enabled) { fusion_process_dump_ = std::make_unique(); + *fusion_process_dump_->mutable_gpu_device_info() = + device_info_.ToGpuProto(); } // Appends ".0" suffix to all instructions. @@ -658,33 +676,72 @@ absl::StatusOr GpuPriorityFusion::Run( } } - auto result = InstructionFusion::Run(module, execution_threads); - - // Fuse all constants. - if (result.ok()) { - // Note: `GetFusionComputations` doesn't return the fusion computations, but - // the computations to be fused. - for (auto* computation : GetFusionComputations(module, execution_threads)) { - std::vector constants; - for (auto* instruction : computation->instructions()) { - // Small constants should be fused, because they can be folded and - // codegened efficiently. - // Fusing large constants doesn't give much benefits, because they're - // treated like parameters and read from global memory anyway. Fusion - // and duplication of large constants can, however, cause problems if we - // want to dump hlo and parse back, because in that case duplicated - // constants will be filled with different data. - if (IsSmallConstant(instruction)) { - constants.push_back(instruction); + if (dump_enabled) { + fusion_process_dump_->set_hlo_module_before_fusion( + module->ToString(HloPrintOptions::ShortParsable())); + } + + int changed = false; + // Note: `GetFusionComputations` doesn't return the fusion computations, but + // the computations to be fused. + for (auto* computation : GetFusionComputations(module, execution_threads)) { + CHECK(!computation->IsFusionComputation()); + + auto fusion_queue = std::make_unique( + computation, cost_analysis_options_, &device_info_, + fusion_process_dump_.get(), thread_pool_, fusion_analysis_cache_); + + while (fusion_queue->DequeueNextProducer()) { + auto producer = fusion_queue->current_producer(); + + for (auto* consumer : fusion_queue->current_consumers()) { + // Don't fuse into single bitcasts. We ignore them in the check + // CanFuseWithAllNonBitcastUsers(), so we need to check it here. + if (consumer->opcode() == HloOpcode::kBitcast) { + continue; } + if (!ConsumeFuel(producer, consumer)) continue; + + VLOG(5) << "next: " << consumer->name() << "(" << consumer << ") + " + << producer->name() << "(" << producer << ")"; + + fusion_queue->PreFusion(producer, consumer); + auto fusion_instruction = Fuse(producer, consumer, computation); + fusion_queue->OnFusingInstruction(fusion_instruction, producer, + consumer); + + changed = true; + } + + if (producer->user_count() == 0) { + fusion_queue->RemoveInstruction(producer); + // Remove from computation. + TF_RETURN_IF_ERROR(computation->RemoveInstruction(producer)); + } + + fusion_queue->UpdatePriorities(); + } + + // Fuse all constants. + std::vector constants; + for (auto* instruction : computation->instructions()) { + // Small constants should be fused, because they can be folded and + // codegened efficiently. + // Fusing large constants doesn't give much benefits, because they're + // treated like parameters and read from global memory anyway. Fusion + // and duplication of large constants can, however, cause problems if we + // want to dump hlo and parse back, because in that case duplicated + // constants will be filled with different data. + if (IsSmallConstant(instruction)) { + constants.push_back(instruction); } - for (auto* constant : constants) { - auto users = constant->users(); - for (auto* user : users) { - if (IsFusible(*user) && CanEmitInputFusedScatter(*constant, *user)) { - InstructionFusion::Fuse(constant, user, computation); - result = true; - } + } + for (auto* constant : constants) { + auto users = constant->users(); + for (auto* user : users) { + if (IsFusible(*user) && CanEmitInputFusedScatter(*constant, *user)) { + Fuse(constant, user, computation); + changed = true; } } } @@ -701,7 +758,7 @@ absl::StatusOr GpuPriorityFusion::Run( "priority_fusion_dump"); } - return result; + return changed; } FusionDecision GpuPriorityFusion::ShouldFuse(HloInstruction* consumer, @@ -747,9 +804,7 @@ HloInstruction* GpuPriorityFusion::FuseInstruction( std::unique_ptr GpuPriorityFusion::GetFusionQueue( HloComputation* computation) { - return std::unique_ptr(new GpuPriorityFusionQueue( - computation, cost_analysis_options_, &device_info_, - fusion_process_dump_.get(), thread_pool_, fusion_analysis_cache_)); + return nullptr; } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/priority_fusion.h b/third_party/xla/xla/service/gpu/priority_fusion.h index 7610705cc3ee38..4725e726f2011f 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion.h +++ b/third_party/xla/xla/service/gpu/priority_fusion.h @@ -26,7 +26,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/fusion_node_indexing_evaluation.h" #include "xla/service/fusion_queue.h" #include "xla/service/gpu/fusion_process_dump.pb.h" #include "xla/service/gpu/model/fusion_analysis_cache.h" @@ -74,6 +73,10 @@ class GpuPriorityFusion : public InstructionFusion { HloInstruction* FuseInstruction(HloInstruction* fusion_instruction, HloInstruction* producer) override; + // Consumes a unit of compiler fuel and returns true if we should + // continue with the transformation. + bool ConsumeFuel(HloInstruction* producer, HloInstruction* consumer); + tsl::thread::ThreadPool* thread_pool_; se::DeviceDescription device_info_; diff --git a/third_party/xla/xla/service/gpu/priority_fusion_test.cc b/third_party/xla/xla/service/gpu/priority_fusion_test.cc index ff4d6956f90201..8ff3f675e79e3f 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/priority_fusion_test.cc @@ -19,9 +19,11 @@ limitations under the License. #include #include +#include #include #include +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -40,7 +42,6 @@ limitations under the License. namespace m = ::xla::match; -using ::testing::ElementsAre; using ::testing::UnorderedElementsAre; using ::tsl::testing::IsOk; using ::tsl::testing::IsOkAndHolds; @@ -140,6 +141,24 @@ CHECK-NEXT: ROOT {{.*}} tuple(%[[FUSION_0]], %[[FUSION_1]]) )"); } +TEST_F(PriorityFusionTest, FuseBroadcastIntoBitcastConsumers) { + absl::string_view kHlo = R"( + HloModule test_module + + ENTRY main { + param_0 = f32[96]{0} parameter(0) + broadcast = f32[8,96,128,7]{3,2,1,0} broadcast(param_0), dimensions={1} + bitcast.6079.2 = f32[8,24,4,128,7]{4,3,2,1,0} bitcast(broadcast) + ROOT transpose.1990.2 = f32[8,24,128,7,4]{4,3,2,1,0} transpose(bitcast.6079.2), dimensions={0,1,3,4,2} + } + )"; + RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"( +CHECK: ENTRY +CHECK-NEXT: %[[PARAM:.*]] = f32[96]{0} parameter(0) +CHECK-NEXT: ROOT %{{.*}} fusion(%[[PARAM]]) + )"); +} + TEST_F(PriorityFusionTest, FuseWideningConvertIntoConsumers) { absl::string_view kHlo = R"( HloModule test_module @@ -157,8 +176,9 @@ TEST_F(PriorityFusionTest, FuseWideningConvertIntoConsumers) { CHECK: ENTRY CHECK-NEXT: %[[PARAM:.*]] = f16[512]{0} parameter(0) CHECK-NEXT: %[[FUSION_F32:.*]] = f32[512]{0} fusion(%[[PARAM]]) -CHECK-NEXT: %[[FUSION_S32:.*]] = s32[512]{0} fusion(%[[PARAM]]) -CHECK-NEXT: ROOT %{{.*}} = (f32[512]{0}, s32[512]{0}) tuple(%[[FUSION_F32]], %[[FUSION_S32]]) +CHECK-NEXT: %[[CONVERT_FUSION:.*]] = f32[512]{0} fusion(%[[PARAM]]) +CHECK-NEXT: %[[BITCAST:.*]] = s32[512]{0} bitcast(%[[CONVERT_FUSION]]) +CHECK-NEXT: ROOT %{{.*}} = (f32[512]{0}, s32[512]{0}) tuple(%[[FUSION_F32]], %[[BITCAST]]) )"); } @@ -201,58 +221,9 @@ CHECK-COUNT-3: fusion )"); } -TEST_F(PriorityFusionTest, FusionInstructionNames) { - absl::string_view kHlo = R"( - HloModule test_module - - square { - p = f32[16384] parameter(0) - ROOT m = f32[16384] multiply(p, p) - } - - exp { - p = f32[16384] parameter(0) - ROOT e = f32[16384] exponential(p) - } - - log { - p = f32[16384] parameter(0) - ROOT l = f32[16384] log(p) - } - - add { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT add = f32[] add(p0, p1) - } - - ENTRY main { - p0 = bf16[1024,8192] parameter(0) - p1 = f32[8192] parameter(1) - p2 = f32[16384] parameter(2) - convert = f32[1024,8192] convert(p0) - broadcast = f32[1024,8192] broadcast(p1), dimensions={1} - c0 = f32[] constant(0) - multiply = f32[1024,8192] multiply(broadcast, convert) - reduce = f32[1024] reduce(multiply, c0), dimensions={1}, to_apply=add - convert.1 = bf16[1024] convert(reduce) - s = f32[16384] fusion(p2), kind=kLoop, calls=square - e = f32[16384] fusion(s), kind=kLoop, calls=exp - l = f32[16384] fusion(s), kind=kInput, calls=log - ROOT result = (bf16[1024]{0}, f32[16384]{0}, f32[16384]{0}) tuple(convert.1, l, e) - })"; - - RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"( -CHECK: ENTRY %main -CHECK: %reduction_fusion{{.*}} fusion -CHECK: %loop_fusion{{.*}} calls=%log -CHECK: %loop_fusion{{.*}} calls=%exp -CHECK: ROOT %result - )"); -} - TEST_F(PriorityFusionTest, ReductionEpilogueFusionRegressionTest) { - // Regression test for epilogue fusion of convert+bitcast into a reduction. + // Regression test for epilogue fusion of convert into a reduction, even if + // the convert has a bitcast as consumer. absl::string_view kHlo = R"( HloModule test_module @@ -301,10 +272,37 @@ TEST_F(PriorityFusionTest, ReductionEpilogueFusionRegressionTest) { RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"( CHECK: ENTRY -CHECK: ROOT {{.*}} fusion( +CHECK: ROOT {{.*}} bitcast({{.*}}fusion{{.*}}) )"); } +TEST_F(PriorityFusionTest, DoNotChangeReductionFusionToLoopFusion) { + // Regression test for epilogue fusion of slice into a reduction. The fusion + // kind for the reduction fusion is intentionally chosen to be set to kLoop, + // as we cannot rely on reductions always having fusion kind kInput. + auto module = *ParseAndReturnVerifiedModule(R"( + HloModule test_module + + add { + rhs.407 = f32[] parameter(1) + lhs.407 = f32[] parameter(0) + ROOT add.24451 = f32[] add(lhs.407, rhs.407) + } + + fused_computation { + p0 = f32[16,64]{1,0} parameter(0) + zero = f32[] constant(0.0) + ROOT reduce = f32[16]{0} reduce(p0, zero), dimensions={1}, to_apply=add + } + + ENTRY main { + param0 = f32[16,64]{1,0} parameter(0) + fusion = f32[16]{0} fusion(param0), kind=kLoop, calls=fused_computation + ROOT slice = f32[8]{0} slice(fusion), slice={[0:8]} + })"); + EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false)); +} + TEST_F(PriorityFusionTest, DoNotFuseTransposeIntoReduce) { absl::string_view kHlo = R"( HloModule test_module @@ -805,5 +803,56 @@ TEST_F(PriorityFusionTest, FuseOnlySmallConstant) { m::Add(m::Parameter(), m::Broadcast(m::Constant()))))); } +TEST_F(PriorityFusionTest, DoNotFuseProducerConsumerMergedTooLarge) { + auto module = *ParseAndReturnVerifiedModule(R"( + HloModule module + + fused_computation.1 { + iota.9.7 = s32[3,1,1]{2,1,0} iota(), iota_dimension=0 + param_3.29 = s32[] parameter(2) + pad.2.7 = s32[3,1,2]{2,1,0} pad(iota.9.7, param_3.29), padding=0_0x0_0x0_1 + param_2.39 = s32[] parameter(1) + broadcast.76.1 = s32[3,1,2]{2,1,0} broadcast(param_2.39), dimensions={} + compare.9.1 = pred[3,1,2]{2,1,0} compare(pad.2.7, broadcast.76.1), direction=GE + param_1.73 = s32[2]{0} parameter(0) + broadcast.78.1 = s32[3,2]{1,0} broadcast(param_1.73), dimensions={1} + bitcast.1 = s32[3,2]{1,0} bitcast(pad.2.7) + compare.10.1 = pred[3,2]{1,0} compare(bitcast.1, broadcast.78.1), direction=LE + bitcast.2 = pred[3,1,2]{2,1,0} bitcast(compare.10.1) + ROOT and.3.1 = pred[3,1,2]{2,1,0} and(compare.9.1, bitcast.2) + } + + and { + x = pred[] parameter(0) + y = pred[] parameter(1) + ROOT and = pred[] and(x, y) + } + + fused_computation.2 { + param0 = pred[3,1,2]{2,1,0} parameter(0) + slice = pred[1,1,2]{2,1,0} slice(param0), slice={[0:1], [0:1], [0:2]} + bitcast = pred[2]{0} bitcast(slice) + init = pred[] constant(true) + reduce = pred[2]{0} reduce(param0, init), dimensions={0,1}, to_apply=and + and = pred[2]{0} and(bitcast, reduce) + pad = pred[3]{0} pad(and, init), padding=0_1 + broadcast = pred[3,2]{1,0} broadcast(pad), dimensions={0} + bitcast2 = pred[6]{0} bitcast(broadcast) + broadcast2 = pred[2,3]{1,0} broadcast(pad), dimensions={1} + bitcast3 = pred[6]{0} bitcast(broadcast2) + ROOT and2 = pred[6]{0} and(bitcast2, bitcast3) + } + + ENTRY main { + p0 = s32[2]{0} parameter(0) + p1 = s32[] parameter(1) + p2 = s32[] parameter(2) + fusion1 = pred[3,1,2]{2,1,0} fusion(p0, p1, p2), kind=kLoop, calls=fused_computation.1 + ROOT fusion2 = pred[6]{0} fusion(fusion1), kind=kInput, calls=fused_computation.2 + } + )"); + EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false)); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/rename_fusions.cc b/third_party/xla/xla/service/gpu/rename_fusions.cc new file mode 100644 index 00000000000000..1a6731cdd49190 --- /dev/null +++ b/third_party/xla/xla/service/gpu/rename_fusions.cc @@ -0,0 +1,93 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/rename_fusions.h" + +#include +#include + +#include "absl/container/btree_set.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/ir_emission_utils.h" + +namespace xla { +namespace gpu { +namespace { + +constexpr absl::string_view FusionKindToString( + HloInstruction::FusionKind kind) { + switch (kind) { + case HloInstruction::FusionKind::kCustom: + return "custom"; + case HloInstruction::FusionKind::kLoop: + return "loop"; + case HloInstruction::FusionKind::kInput: + return "input"; + case HloInstruction::FusionKind::kOutput: + return "output"; + } +} + +std::string MakeFusionHeroNames(const HloInstruction* instruction) { + std::unique_ptr fusion_adaptor = + HloFusionAdaptor::ForInstruction(instruction); + absl::btree_set heroes; + + for (auto root : fusion_adaptor->GetRoots()) { + heroes.insert(HloOpcodeString( + FindNonTrivialHero(root.instruction(), *fusion_adaptor).opcode())); + } + return absl::StrReplaceAll(absl::StrJoin(heroes, "_"), {{"-", "_"}}); +} + +void RenameFusion(HloModule* module, HloInstruction* instruction) { + std::string hero_names = MakeFusionHeroNames(instruction); + module->SetAndUniquifyInstrName( + instruction, absl::StrCat(FusionKindToString(instruction->fusion_kind()), + "_", hero_names, "_fusion")); + module->SetAndUniquifyComputationName( + instruction->fused_instructions_computation(), + absl::StrCat("fused_", hero_names)); +} + +} // namespace + +absl::StatusOr RenameFusions::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + for (HloComputation* computation : module->MakeNonfusionComputations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kFusion || + instruction->fusion_kind() == HloInstruction::FusionKind::kCustom) { + continue; + } + RenameFusion(module, instruction); + } + } + return true; +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/rename_fusions.h b/third_party/xla/xla/service/gpu/rename_fusions.h new file mode 100644 index 00000000000000..247750f378c86a --- /dev/null +++ b/third_party/xla/xla/service/gpu/rename_fusions.h @@ -0,0 +1,46 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RENAME_FUSIONS_H_ +#define XLA_SERVICE_GPU_RENAME_FUSIONS_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +// An HLO pass that gives fusions and fused computations descriptive names. +// +// The name is based on hero instructions and the fusion kind, i.e. +// Fusions get name "__fusion", +// and fused computations get name "fused_". +// In the case of multiple roots, the hero instructions in the name are +// underscore-separated and alphabetically sorted. + +class RenameFusions : public HloModulePass { + absl::string_view name() const override { return "rename_fusions"; } + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_RENAME_FUSIONS_H_ diff --git a/third_party/xla/xla/service/gpu/rename_fusions_test.cc b/third_party/xla/xla/service/gpu/rename_fusions_test.cc new file mode 100644 index 00000000000000..60c97cf2ff9438 --- /dev/null +++ b/third_party/xla/xla/service/gpu/rename_fusions_test.cc @@ -0,0 +1,83 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/rename_fusions.h" + +#include + +#include +#include "absl/strings/string_view.h" +#include "xla/tests/hlo_test_base.h" + +namespace xla { +namespace gpu { + +class RenameFusionsTest : public HloTestBase { + protected: + RenameFusions rename_fusions_; +}; + +TEST_F(RenameFusionsTest, FusionInstructionNames) { + absl::string_view kHlo = R"( + HloModule test_module + + square { + p = f32[16384] parameter(0) + ROOT m = f32[16384] multiply(p, p) + } + + exp { + p = f32[16384] parameter(0) + ROOT e = f32[16384] exponential(p) + } + + log { + p = f32[16384] parameter(0) + ROOT l = f32[16384] log(p) + } + + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + + ENTRY main { + p0 = bf16[1024,8192] parameter(0) + p1 = f32[8192] parameter(1) + p2 = f32[16384] parameter(2) + convert = f32[1024,8192] convert(p0) + broadcast = f32[1024,8192] broadcast(p1), dimensions={1} + c0 = f32[] constant(0) + multiply = f32[1024,8192] multiply(broadcast, convert) + reduce = f32[1024] reduce(multiply, c0), dimensions={1}, to_apply=add + convert.1 = bf16[1024] convert(reduce) + s = f32[16384] fusion(p2), kind=kLoop, calls=square + e = f32[16384] fusion(s), kind=kLoop, calls=exp + l = f32[16384] fusion(s), kind=kInput, calls=log + ROOT result = (bf16[1024]{0}, f32[16384]{0}, f32[16384]{0}) tuple(convert.1, l, e) + })"; + + RunAndFilecheckHloRewrite(kHlo, std::move(rename_fusions_), R"( +CHECK: ENTRY %main +CHECK: %loop_multiply_fusion{{.*}} calls=%fused_multiply +CHECK: %input_log_fusion{{.*}} calls=%fused_log +CHECK: %loop_exponential_fusion{{.*}} calls=%fused_exponential +CHECK: ROOT %result + )"); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD index a2f4915ce9e6ab..360bc385f90374 100644 --- a/third_party/xla/xla/service/gpu/runtime/BUILD +++ b/third_party/xla/xla/service/gpu/runtime/BUILD @@ -1,774 +1,742 @@ -load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library") -load("//xla:xla.bzl", "xla_cc_test") -load("//xla/service/gpu:build_defs.bzl", "gpu_kernel_library") +load("//xla/tests:build_defs.bzl", "xla_test") +load("//xla/service/gpu:build_defs.bzl", "get_cub_sort_kernel_types") load("//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured") -load( - "@local_config_rocm//rocm:build_defs.bzl", - "if_rocm_is_configured", - "rocm_library", -) -load( - "@local_tsl//tsl/platform:build_config_root.bzl", - "tf_gpu_tests_tags", -) +load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") load("@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], licenses = ["notice"], ) package_group( name = "friends", - includes = [ - "//xla:friends", - ], + includes = ["//xla:friends"], ) -gpu_kernel_library( - name = "gpu_kernel_helper", - hdrs = if_gpu_is_configured(["gpu_kernel_helper.h"]), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), +#===-------------------------------------------------------------------------------------------===// +# Runtime tracing libraries +#===-------------------------------------------------------------------------------------------===// + +cc_library( + name = "annotation", + srcs = ["annotation.cc"], + hdrs = ["annotation.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ - "//xla/stream_executor/platform", - "@local_tsl//tsl/lib/math:math_util", - ] + if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", - ]) + if_rocm_is_configured([ - "@local_config_rocm//rocm:rocm_headers", - ]), + "//xla:printer", + "//xla:status", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/profiler/lib:nvtx_utils", + "@local_tsl//tsl/profiler/lib:scoped_annotation", + ], ) +#===-------------------------------------------------------------------------------------------===// +# Command Buffer Integration +#===-------------------------------------------------------------------------------------------===// + cc_library( - name = "cholesky", - srcs = ["cholesky.cc"], - hdrs = ["cholesky.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - visibility = ["//visibility:public"], + name = "command_buffer_allocations", + srcs = ["command_buffer_allocations.cc"], + hdrs = ["command_buffer_allocations.h"], deps = [ - ":support", - "//xla:xla_proto_cc", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/service:executable", - "//xla/service/gpu:gpu_asm_opts_util", - "//xla/service/gpu/runtime3:cholesky_thunk", + "//xla:status", + "//xla:statusor", + "//xla/service:buffer_assignment", + "//xla/service/gpu:buffer_allocations", + "//xla/stream_executor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", ], ) cc_library( - name = "collectives", - srcs = ["collectives.cc"], - hdrs = ["collectives.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", + name = "command_buffer_cmd", + srcs = ["command_buffer_cmd.cc"], + hdrs = ["command_buffer_cmd.h"], + local_defines = if_cuda_is_configured([ + "GOOGLE_CUDA=1", ]), - visibility = ["//visibility:public"], deps = [ - ":support", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", + ":annotation", + ":custom_call_thunk", + ":nccl_all_gather_thunk", + ":nccl_all_reduce_thunk", + "//xla:shape_util", + "//xla:status", + "//xla:types", + "//xla:util", + "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", - "//xla/service:computation_placer_hdr", - "//xla/service:executable", + "//xla/service:computation_placer", + "//xla/service:custom_call_status_internal", + "//xla/service:custom_call_status_public_headers", "//xla/service:global_device_id", + "//xla/service/gpu:buffer_allocations", "//xla/service/gpu:gpu_executable_run_options", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu:matmul_utils", "//xla/service/gpu:nccl_api", + "//xla/service/gpu:nccl_clique", + "//xla/service/gpu:nccl_clique_key", "//xla/service/gpu:nccl_collective_thunks", + "//xla/service/gpu:stream_executor_util", "//xla/service/gpu:thunk", + "//xla/service/gpu/kernels:custom_kernel", "//xla/stream_executor", + "//xla/stream_executor/gpu:gpu_stream_header", + "//xla/stream_executor/gpu:gpu_types_header", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", - ] + if_gpu_is_configured([ - "//xla/service/gpu:mock_nccl_utils", - ]), + "@local_tsl//tsl/concurrency:ref_count", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], ) cc_library( - name = "conv", - srcs = ["conv.cc"], - hdrs = ["conv.h"], - local_defines = if_cuda_is_configured([ - "GOOGLE_CUDA=1", - ]), - visibility = ["//visibility:public"], + name = "command_buffer_cmd_emitter", + srcs = ["command_buffer_cmd_emitter.cc"], + hdrs = ["command_buffer_cmd_emitter.h"], deps = [ - ":support", + ":command_buffer_cmd", + ":conditional_thunk", + ":copy_thunk", + ":custom_call_thunk", + ":gemm_thunk", + ":kernel_thunk", + ":memset_thunk", + ":nccl_all_gather_thunk", + ":nccl_all_reduce_thunk", + ":replica_id_thunk", + ":sequential_thunk", + ":wait_for_streams_thunk", + ":while_thunk", "//xla:status", - "//xla:xla_proto_cc", - "//xla/mlir/runtime/transforms:custom_call_encoding", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/service:executable", - "//xla/service/gpu:autotuner_util", - "//xla/service/gpu:gpu_asm_opts_util", - "//xla/service/gpu:gpu_conv_runner", - "//xla/service/gpu:non_atomically_upgradeable_rw_lock", - "//xla/stream_executor:device_memory", - "//xla/stream_executor:device_memory_allocator", - "//xla/translate/mhlo_to_hlo:attribute_exporter", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/functional:function_ref", - "@com_google_absl//absl/synchronization", - "@llvm-project//llvm:Support", - ] + if_cuda_is_configured([ - "//xla/service/gpu:conv_algorithm_picker", - ]), + "//xla:statusor", + "//xla:util", + "//xla/service/gpu:nccl_collective_thunks", + "//xla/service/gpu:thunk", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], ) -cc_library( - name = "conv_reorder", - srcs = ["conv_reorder.cc"], - hdrs = ["conv_reorder.h"], - visibility = ["//visibility:public"], +xla_test( + name = "command_buffer_cmd_test", + srcs = if_gpu_is_configured(["command_buffer_cmd_test.cc"]), + backends = ["gpu"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), deps = [ - ":support", - "//xla:xla_proto_cc", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", + ":command_buffer_cmd", + "//xla:status", + "//xla:types", + "//xla/service:buffer_assignment", "//xla/service:executable", + "//xla/service:platform_util", + "//xla/service/gpu:buffer_allocations", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu:thunk", + "//xla/stream_executor", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor/gpu:gpu_test_kernels", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_benchmark", + "@local_tsl//tsl/platform:test_main", ], ) +#===-------------------------------------------------------------------------------------------===// +# XLA Thunks Runtime +#===-------------------------------------------------------------------------------------------===// + cc_library( - name = "norm", - srcs = ["norm.cc"], - hdrs = ["norm.h"], - visibility = ["//visibility:public"], - deps = [ - ":support", - "//xla:status", - "//xla:xla_proto_cc", - "//xla/mlir/runtime/transforms:custom_call_encoding", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/service:executable", - "//xla/service/gpu:gpu_asm_opts_util", - "//xla/service/gpu:gpu_norm_runner", + name = "cholesky_thunk", + srcs = if_gpu_is_configured(["cholesky_thunk.cc"]), + hdrs = if_gpu_is_configured(["cholesky_thunk.h"]), + deps = if_gpu_is_configured([ + "//xla/service/gpu:buffer_allocations", + "//xla/service/gpu:cusolver_context", + "//xla/service/gpu:make_batch_pointers", + "//xla/service/gpu:thunk", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/service:buffer_assignment", + "//xla/hlo/ir:hlo", + "@local_tsl//tsl/platform:logging", + "//xla/stream_executor", "//xla/stream_executor:device_memory", - "//xla/stream_executor:device_memory_allocator", - "//xla/translate/mhlo_to_hlo:attribute_exporter", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/synchronization", - "@llvm-project//llvm:Support", + "//xla/stream_executor/gpu:gpu_asm_opts", + ]) + [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status", ], ) cc_library( - name = "fused_attention", - srcs = ["fused_attention.cc"], - hdrs = ["fused_attention.h"], - visibility = ["//visibility:public"], + name = "command_buffer_thunk", + srcs = ["command_buffer_thunk.cc"], + hdrs = ["command_buffer_thunk.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ - ":support", + ":annotation", + ":command_buffer_allocations", + ":command_buffer_cmd", "//xla:status", - "//xla:xla_proto_cc", - "//xla/mlir/runtime/transforms:custom_call_encoding", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/service:executable", - "//xla/service/gpu:gpu_asm_opts_util", - "//xla/service/gpu:gpu_fused_mha_runner", - "//xla/stream_executor:device_memory", - "//xla/stream_executor:device_memory_allocator", - "//xla/translate/mhlo_to_hlo:attribute_exporter", - "@com_google_absl//absl/container:node_hash_map", + "//xla:statusor", + "//xla/service:buffer_assignment", # build_cleaner: keep + "//xla/service/gpu:buffer_allocations", # build_cleaner: keep + "//xla/service/gpu:thunk", + "//xla/stream_executor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", - "@llvm-project//llvm:Support", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:profiler_lock", + "@local_tsl//tsl/profiler/lib:scoped_annotation", + "@local_tsl//tsl/profiler/lib:traceme", + "@local_tsl//tsl/profiler/lib:traceme_encode", ], ) -cc_library( - name = "cub_sort", - srcs = ["cub_sort.cc"], - hdrs = ["cub_sort.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - visibility = ["//visibility:public"], +xla_test( + name = "command_buffer_thunk_test", + srcs = if_gpu_is_configured(["command_buffer_thunk_test.cc"]), + backend_tags = { + "gpu_a100": ["config-cuda-only"], + "gpu_v100": ["config-cuda-only"], + }, + backends = [ + "gpu_a100", + "gpu_v100", + ], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), deps = [ - ":support", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/runtime:memref_view", + ":command_buffer_allocations", + ":command_buffer_cmd", + ":command_buffer_thunk", + "//xla:shape_util", + "//xla:types", + "//xla/service:buffer_assignment", "//xla/service:executable", - "//xla/stream_executor:device_memory", - "@com_google_absl//absl/status", - ] + if_gpu_is_configured([ - "//xla/service/gpu/runtime3:cub_sort_thunk", + "//xla/service:platform_util", + "//xla/service/gpu:buffer_allocations", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu:matmul_utils", + "//xla/service/gpu:thunk", + "//xla/stream_executor", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor/gpu:gpu_test_kernels", + "//xla/stream_executor/gpu:gpu_types_header", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", ]), ) cc_library( - name = "custom_call", - srcs = ["custom_call.cc"], - hdrs = ["custom_call.h"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - visibility = ["//visibility:public"], + name = "conditional_thunk", + srcs = ["conditional_thunk.cc"], + hdrs = ["conditional_thunk.h"], deps = [ - ":support", - ":triangular_solve", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/service:custom_call_status_internal", - "//xla/service:custom_call_target_registry", - "//xla/service:executable", - "//xla/service/gpu:cublas_cudnn", - "//xla/stream_executor/gpu:gpu_stream_header", + ":sequential_thunk", + "//xla:status", + "//xla:status_macros", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_assignment", + "//xla/service/gpu:buffer_allocations", + "//xla/service/gpu:thunk", + "//xla/service/gpu:variant_visitor", + "//xla/stream_executor", + "//xla/stream_executor:memory_allocation", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", ], ) cc_library( - name = "custom_call_registry", - srcs = ["custom_call_registry.cc"], - hdrs = ["custom_call_registry.h"], - visibility = ["//visibility:public"], - deps = ["//xla/runtime:custom_call_registry"], -) - -cc_library( - name = "executable", - srcs = ["executable.cc"], - hdrs = ["executable.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - visibility = ["//visibility:public"], + name = "convolution_thunk", + srcs = ["convolution_thunk.cc"], + hdrs = ["convolution_thunk.h"], deps = [ - ":cholesky", - ":collectives", - ":concurrent_region", - ":conv", - ":conv_reorder", - ":cub_sort", - ":custom_call", - ":custom_call_registry", - ":fft", - ":fused_attention", - ":gemm", - ":gpublas_lt_matmul", - ":graph_launch", - ":io_feed", - ":kernel_launch", - ":memcpy", - ":memset", - ":norm", - ":send_recv", - ":stream_synchronization", - ":support", - ":topk", - ":tracing", - "//xla:statusor", - "//xla:xla_proto_cc", - "//xla/mlir/runtime/transforms:compilation_pipeline_gpu", - "//xla/runtime:executable", - "//xla/runtime:jit_executable", - "//xla/runtime:module_registry", - "//xla/service:executable", - "//xla/service:stream_pool", - "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:non_atomically_upgradeable_rw_lock", + "//xla:util", + "//xla/service:buffer_assignment", + "//xla/service/gpu:gpu_conv_runner", "//xla/service/gpu:thunk", "//xla/stream_executor", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status", - "@local_tsl//tsl/protobuf:dnn_proto_cc", ], ) cc_library( - name = "fft", - srcs = ["fft.cc"], - hdrs = ["fft.h"], - visibility = ["//visibility:public"], + name = "copy_thunk", + srcs = ["copy_thunk.cc"], + hdrs = ["copy_thunk.h"], deps = [ - ":support", - "//xla/mlir/runtime/transforms:custom_call_encoding", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/runtime:state", - "//xla/service/gpu/runtime3:fft_thunk", - "//xla/stream_executor:fft", + "//xla:status", + "//xla/service:buffer_assignment", + "//xla/service/gpu:thunk", + "//xla/stream_executor", + "@llvm-project//mlir:IR", ], ) cc_library( - name = "topk_kernel", - srcs = if_gpu_is_configured(["topk_kernel.cc"]), - hdrs = if_gpu_is_configured(["topk_kernel.h"]), - compatible_with = [], + name = "cub_sort_thunk", + srcs = if_gpu_is_configured(["cub_sort_thunk.cc"]), + hdrs = if_gpu_is_configured(["cub_sort_thunk.h"]), local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), - visibility = ["//visibility:public"], - deps = [ - ":gpu_kernel_helper", - ":support", + deps = if_gpu_is_configured([ + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "//xla/service:buffer_assignment", + "//xla/service/gpu:buffer_allocations", + "//xla/service/gpu:thunk", + "//xla/stream_executor:device_memory", "//xla:shape_util", - "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/runtime:memref_view", - "//xla/stream_executor", # build_cleaner: keep - "//xla/stream_executor:platform", - "//xla/stream_executor/gpu:gpu_stream_header", - "//xla/stream_executor/gpu:gpu_types_header", - "@com_google_absl//absl/numeric:bits", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@eigen_archive//:eigen3", "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ] + if_cuda_is_configured([ - ":topk_kernel_cuda", - ]) + if_rocm_is_configured([ - ":topk_kernel_rocm", - ]), -) - -cuda_library( - name = "topk_kernel_cuda", - srcs = if_cuda_is_configured( - [ - "topk_kernel_bfloat16.cu.cc", - "topk_kernel_float.cu.cc", - "topk_kernel.cu.h", - ], - ), - hdrs = if_cuda_is_configured(["topk_kernel_common.h"]), - compatible_with = [], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - visibility = ["//visibility:public"], - deps = [ - ":gpu_kernel_helper", - "@eigen_archive//:eigen3", - ], + ] + ["//xla/service/gpu:cub_sort_kernel_" + suffix for suffix in get_cub_sort_kernel_types()]), ) -rocm_library( - name = "topk_kernel_rocm", - srcs = if_rocm_is_configured( - [ - "topk_kernel_bfloat16.cu.cc", - "topk_kernel_float.cu.cc", - "topk_kernel.cu.h", - ], - ), - hdrs = if_rocm_is_configured(["topk_kernel_common.h"]), - compatible_with = [], - local_defines = if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), - deps = [ - ":gpu_kernel_helper", - "@eigen_archive//:eigen3", - ], -) - -xla_cc_test( - name = "topk_kernel_test", - srcs = if_gpu_is_configured(["topk_kernel_test.cc"]), - tags = tf_gpu_tests_tags(), +cc_library( + name = "custom_call_thunk", + srcs = ["custom_call_thunk.cc"], + hdrs = ["custom_call_thunk.h"], + local_defines = if_cuda_is_configured([ + "GOOGLE_CUDA=1", + ]), deps = [ - ":gpu_kernel_helper", - ":topk_kernel", - "//xla:xla_data_proto_cc", - "//xla/stream_executor", # build_cleaner: keep - "//xla/stream_executor:multi_platform_manager", - "//xla/stream_executor/gpu:gpu_init", + "//xla:executable_run_options", + "//xla:shape_util", + "//xla:status", + "//xla:util", + "//xla/ffi:call_frame", + "//xla/ffi:ffi_api", + "//xla/ffi/api:c_api", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_assignment", + "//xla/service:custom_call_status", + "//xla/service:custom_call_status_internal", + "//xla/service:executable", + "//xla/service/gpu:thunk", + "//xla/stream_executor:device_memory", "//xla/stream_executor/gpu:gpu_stream_header", - "//xla/stream_executor/gpu:gpu_timer_header", "//xla/stream_executor/gpu:gpu_types_header", - "//xla/stream_executor/host:host_platform", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/random", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - "@eigen_archive//:eigen3", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_benchmark", - "@local_tsl//tsl/platform:test_main", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", ], ) -xla_cc_test( - name = "topk_test", - srcs = ["topk_test.cc"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - tags = tf_gpu_tests_tags(), +cc_library( + name = "fft_thunk", + srcs = ["fft_thunk.cc"], + hdrs = ["fft_thunk.h"], deps = [ - ":topk", - "//xla:error_spec", - "//xla:shape_util", - "//xla:status", - "//xla:statusor", "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service:gpu_plugin", - "//xla/service:hlo_pass", - "//xla/service:platform_util", - "//xla/service:topk_rewriter", - "//xla/service/gpu:topk_specializer", - "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - "@com_google_absl//absl/container:flat_hash_set", + "//xla/service:buffer_assignment", + "//xla/service/gpu:buffer_allocations", + "//xla/service/gpu:thunk", + "//xla/stream_executor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", ], ) cc_library( - name = "topk", - srcs = if_gpu_is_configured(["topk.cc"]), - hdrs = ["topk.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - visibility = ["//visibility:public"], - deps = if_gpu_is_configured([":topk_kernel"]) + [ - ":support", - "//xla:executable_run_options", - "//xla:shape_util", - "//xla:status", - "//xla:statusor", - "//xla:types", + name = "fused_mha_thunk", + srcs = ["fused_mha_thunk.cc"], + hdrs = ["fused_mha_thunk.h"], + deps = [ + "//xla:util", "//xla:xla_data_proto_cc", - "//xla:xla_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/mlir/runtime/transforms:custom_call_encoding", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/runtime:state", - "//xla/service:executable", - "//xla/service:hlo_pass", - "//xla/service:tuple_util", - "//xla/stream_executor/gpu:gpu_stream_header", - "//xla/stream_executor/gpu:gpu_types_header", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:statusor", + "//xla/service:buffer_assignment", + "//xla/service/gpu:gpu_fused_mha_runner", + "//xla/service/gpu:thunk", + "//xla/stream_executor", + "@com_google_absl//absl/container:flat_hash_map", ], ) cc_library( - name = "gemm", - srcs = ["gemm.cc"], - hdrs = ["gemm.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - visibility = ["//visibility:public"], + name = "gemm_thunk", + srcs = ["gemm_thunk.cc"], + hdrs = ["gemm_thunk.h"], deps = [ - ":support", "//xla:status", - "//xla:xla_proto_cc", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/runtime:state", - "//xla/service:executable", - "//xla/service:hlo_module_config", - "//xla/service/gpu:gpu_asm_opts_util", + "//xla/service:buffer_assignment", "//xla/service/gpu:matmul_utils", - "//xla/service/gpu:non_atomically_upgradeable_rw_lock", - "//xla/stream_executor:blas", + "//xla/service/gpu:thunk", "//xla/stream_executor:device_memory", - "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/status", - "@local_tsl//tsl/platform:errors", - ] + if_gpu_is_configured([ - "//xla/service/gpu:gemm_algorithm_picker", - "//xla/stream_executor/gpu:redzone_allocator", - ]), + "@local_tsl//tsl/platform:logging", + ], ) cc_library( - name = "graph_launch", - srcs = ["graph_launch.cc"], - hdrs = ["graph_launch.h"], + name = "gpublas_lt_matmul_thunk", + srcs = if_gpu_is_configured(["gpublas_lt_matmul_thunk.cc"]), + hdrs = if_gpu_is_configured(["gpublas_lt_matmul_thunk.h"]), local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), - visibility = ["//visibility:public"], + deps = if_gpu_is_configured([ + "//xla/service:buffer_assignment", + "//xla/service/gpu:matmul_utils", + "//xla/service/gpu:thunk", + "//xla:status", + "//xla/stream_executor:device_memory", + "//xla/stream_executor", + "@local_tsl//tsl/platform:logging", + ]), +) + +cc_library( + name = "infeed_thunk", + srcs = ["infeed_thunk.cc"], + hdrs = ["infeed_thunk.h"], deps = [ - ":concurrent_region", - ":conv", - ":gemm", - ":kernel_launch", - ":support", - "//xla:statusor", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/service:executable", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:non_atomically_upgradeable_rw_lock", + "//xla/service/gpu:io_feed_manager", + "//xla/service/gpu:thunk", "//xla/stream_executor", - "//xla/stream_executor/gpu:gpu_graph", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + ], +) + +cc_library( + name = "kernel_thunk", + srcs = ["kernel_thunk.cc"], + hdrs = ["kernel_thunk.h"], + deps = [ + "//xla:status", + "//xla:types", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_assignment", + "//xla/service/gpu:kernel_arguments", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu:stream_executor_util", + "//xla/service/gpu:thunk", + "//xla/service/gpu/kernels:custom_kernel", + "//xla/stream_executor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/profiler/lib:profiler_lock", - "@local_tsl//tsl/profiler/lib:traceme", - "@local_tsl//tsl/profiler/lib:traceme_encode", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", ], ) cc_library( - name = "concurrent_region", - srcs = ["concurrent_region.cc"], - hdrs = ["concurrent_region.h"], - visibility = ["//visibility:public"], + name = "memset_thunk", + srcs = ["memset_thunk.cc"], + hdrs = ["memset_thunk.h"], deps = [ - ":support", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/service:executable", - "//xla/service:stream_pool", + "//xla:status", + "//xla/service:buffer_assignment", + "//xla/service/gpu:thunk", "//xla/stream_executor", - "@local_tsl//tsl/platform:statusor", + "@com_google_absl//absl/status", ], ) cc_library( - name = "stream_synchronization", - srcs = ["stream_synchronization.cc"], - hdrs = ["stream_synchronization.h"], - visibility = ["//visibility:public"], + name = "nccl_all_gather_thunk", + srcs = ["nccl_all_gather_thunk.cc"], + hdrs = ["nccl_all_gather_thunk.h"], deps = [ - ":concurrent_region", - ":support", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:collective_ops_utils", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:nccl_api", + "//xla/service/gpu:nccl_collective_thunks", + "//xla/service/gpu:thunk", + "//xla/stream_executor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", ], ) cc_library( - name = "io_feed", - srcs = ["io_feed.cc"], - hdrs = ["io_feed.h"], - visibility = ["//visibility:public"], + name = "nccl_all_reduce_thunk", + srcs = ["nccl_all_reduce_thunk.cc"], + hdrs = ["nccl_all_reduce_thunk.h"], deps = [ - ":support", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/service:executable", - "//xla/service/gpu:io_feed_manager", + "//xla:status_macros", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/mlir_hlo:lhlo_gpu", + "//xla/service:collective_ops_utils", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:nccl_api", + "//xla/service/gpu:nccl_collective_thunks", + "//xla/service/gpu:thunk", + "//xla/stream_executor", + "//xla/translate/hlo_to_mhlo:hlo_utils", + "//xla/translate/mhlo_to_hlo:type_to_shape", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", ], ) cc_library( - name = "kernel_launch", - srcs = ["kernel_launch.cc"], - hdrs = ["kernel_launch.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - visibility = ["//visibility:public"], + name = "nccl_all_to_all_thunk", + srcs = ["nccl_all_to_all_thunk.cc"], + hdrs = ["nccl_all_to_all_thunk.h"], deps = [ - ":concurrent_region", - ":support", - "//xla:statusor", + "//xla:shape_util", + "//xla:status_macros", "//xla/hlo/ir:hlo", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/runtime:memref_view", - "//xla/runtime:state", - "//xla/service:executable", - "//xla/service:hlo_proto_cc", - "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu:stream_executor_util", - "//xla/service/gpu/kernels:custom_kernel", - "//xla/service/gpu/kernels:custom_kernel_fusion", + "//xla/service:collective_ops_utils", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:nccl_api", + "//xla/service/gpu:nccl_collective_thunks", + "//xla/service/gpu:thunk", "//xla/stream_executor", - "//xla/stream_executor/gpu:gpu_graph", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", + "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", ], ) cc_library( - name = "gpublas_lt_matmul", - srcs = ["gpublas_lt_matmul.cc"], - hdrs = ["gpublas_lt_matmul.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - visibility = ["//visibility:public"], + name = "norm_thunk", + srcs = ["norm_thunk.cc"], + hdrs = ["norm_thunk.h"], deps = [ - ":support", - "//xla:xla_proto_cc", - "//xla/mlir/runtime/transforms:custom_call_encoding", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/runtime:logical_result", - "//xla/runtime:state", - "//xla/service:executable", - "//xla/service/gpu:matmul_utils", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/service:buffer_assignment", + "//xla/service/gpu:gpu_norm_runner", + "//xla/service/gpu:thunk", "//xla/stream_executor", - "@local_tsl//tsl/platform:status", - ] + if_rocm_is_configured([ - "@local_config_rocm//rocm:rocm_headers", - ]), + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + ], ) cc_library( - name = "memcpy", - srcs = ["memcpy.cc"], - hdrs = ["memcpy.h"], - visibility = ["//visibility:public"], + name = "outfeed_thunk", + srcs = ["outfeed_thunk.cc"], + hdrs = ["outfeed_thunk.h"], deps = [ - ":concurrent_region", - ":support", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/service:executable", + "//xla:util", + "//xla/service/gpu:io_feed_manager", + "//xla/service/gpu:thunk", + "//xla/stream_executor", + "@com_google_absl//absl/status", ], ) cc_library( - name = "memset", - srcs = ["memset.cc"], - hdrs = ["memset.h"], - visibility = ["//visibility:public"], + name = "replica_id_thunk", + srcs = ["replica_id_thunk.cc"], + hdrs = ["replica_id_thunk.h"], deps = [ - ":support", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/service:executable", - "@com_google_absl//absl/base", + "//xla/service:buffer_assignment", + "//xla/service:global_device_id", + "//xla/service/gpu:thunk", + "@com_google_absl//absl/status", + "@local_tsl//tsl/platform:statusor", ], ) cc_library( - name = "support", - srcs = ["support.cc"], - hdrs = ["support.h"], + name = "sequential_thunk", + srcs = ["sequential_thunk.cc"], + hdrs = ["sequential_thunk.h"], local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), - visibility = ["//visibility:public"], deps = [ - "//xla:shape_util", - "//xla/mlir/runtime/transforms:custom_call_encoding", - "//xla/runtime:custom_call", - "//xla/service/gpu:matmul_utils", - "//xla/stream_executor:blas", - "//xla/stream_executor:device_memory", + ":annotation", + "//xla:status", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:buffer_allocations", + "//xla/service/gpu:thunk", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@local_tsl//tsl/profiler/lib:scoped_annotation_stack", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/profiler/lib:scoped_annotation", ], ) cc_library( - name = "send_recv", - srcs = ["send_recv.cc"], - hdrs = ["send_recv.h"], - visibility = ["//visibility:public"], + name = "send_recv_thunk", + srcs = ["send_recv_thunk.cc"], + hdrs = ["send_recv_thunk.h"], deps = [ - ":support", - "//xla/mlir/runtime/transforms:custom_call_encoding", - "//xla/mlir_hlo", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/service:executable", + "//xla:shape_util", + "//xla:status", + "//xla:statusor", + "//xla:xla_data_proto_cc", + "//xla/service:buffer_assignment", + "//xla/service:global_device_id", + "//xla/service/gpu:thunk", "//xla/stream_executor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@local_tsl//tsl/concurrency:async_value", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/lib:traceme", - "@local_tsl//tsl/profiler/lib:traceme_encode", ], ) cc_library( - name = "tracing", - srcs = [ - "annotation.cc", - "tracing.cc", - ], - hdrs = [ - "annotation.h", - "tracing.h", + name = "triangular_solve_thunk", + srcs = if_gpu_is_configured(["triangular_solve_thunk.cc"]), + hdrs = if_gpu_is_configured(["triangular_solve_thunk.h"]), + deps = if_gpu_is_configured([ + "@com_google_absl//absl/strings:str_format", + "//xla/service/gpu:buffer_allocations", + "//xla/service/gpu:make_batch_pointers", + "//xla/service/gpu:thunk", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/service:buffer_assignment", + "//xla/hlo/ir:hlo", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "//xla/stream_executor", + "//xla/stream_executor:device_memory", + "//xla/stream_executor/gpu:gpu_asm_opts", + ]) + [ + "@com_google_absl//absl/status", + "@local_tsl//tsl/platform:status", ], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - visibility = ["//visibility:public"], +) + +cc_library( + name = "while_thunk", + srcs = ["while_thunk.cc"], + hdrs = ["while_thunk.h"], deps = [ - ":support", + ":sequential_thunk", + "//xla:status", "//xla/hlo/ir:hlo", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/runtime:tracing", - "//xla/runtime:type_id", + "//xla/service:buffer_assignment", + "//xla/service/gpu:buffer_allocations", + "//xla/service/gpu:thunk", + "//xla/stream_executor", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/profiler/lib:nvtx_utils", - "@local_tsl//tsl/profiler/lib:scoped_annotation_stack", + "@com_google_absl//absl/synchronization", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", ], ) cc_library( - name = "triangular_solve", - srcs = ["triangular_solve.cc"], - hdrs = ["triangular_solve.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM"]), - visibility = ["//visibility:public"], + name = "wait_for_streams_thunk", + srcs = ["wait_for_streams_thunk.cc"], + hdrs = ["wait_for_streams_thunk.h"], deps = [ - ":support", - "//xla:xla_proto_cc", - "//xla/runtime:custom_call", - "//xla/service:executable", - "//xla/service/gpu:gpu_asm_opts_util", - "//xla/service/gpu/runtime3:triangular_solve_thunk", - "@local_tsl//tsl/platform:human_readable_json", + "//xla/service:global_device_id", + "//xla/service/gpu:thunk", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/gpu/runtime/annotation.cc b/third_party/xla/xla/service/gpu/runtime/annotation.cc index c95d19f154a6fd..39af816c1b29fe 100644 --- a/third_party/xla/xla/service/gpu/runtime/annotation.cc +++ b/third_party/xla/xla/service/gpu/runtime/annotation.cc @@ -15,38 +15,237 @@ limitations under the License. #include "xla/service/gpu/runtime/annotation.h" +#include #include +#include +#include +#include +#include +#include +#include +#include +#include #include +#include +#include +#include +#include "absl/status/status.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_split.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/printer.h" +#include "xla/status.h" +#include "tsl/platform/errors.h" +#include "tsl/profiler/lib/nvtx_utils.h" +#include "tsl/profiler/lib/scoped_annotation.h" namespace xla::gpu { +using ::tsl::profiler::ScopedAnnotation; namespace { -nvtxStringHandle_t RegisterString(const char* str) { + +nvtxStringHandle_t RegisterString(const std::string& str) { #if GOOGLE_CUDA - auto domain = tsl::profiler::nvtx::GetNVTXDomain(); + auto domain = tsl::profiler::GetNVTXDomain(); if (!domain) { - // NVTX not enabled, so don't bother registering strings with it - return {}; + return {}; // NVTX not enabled, so don't registering strings. } - std::string buffer{}; constexpr auto max_length = 65330; - if (auto const length = std::strlen(str); length >= max_length) { - // nvbugs 4340868 - std::string_view suffix{"\n[truncated]\n"}; - buffer.reserve(max_length); - buffer.assign(str, str + length - suffix.size()); - buffer.append(suffix); - str = buffer.c_str(); - } - return nvtxDomainRegisterStringA(*domain, str); + if (str.size() <= max_length) { + return nvtxDomainRegisterStringA(*domain, str.c_str()); + } + // nvbugs 4340868 + std::string_view suffix{"\n[truncated]\n"}; + std::string buffer(str.data(), max_length - suffix.size()); + buffer.append(suffix); + return nvtxDomainRegisterStringA(*domain, buffer.c_str()); #else return {}; #endif } +// Nsight Systems supports some basic HTML markup in annotation strings. This +// escaping stops things like from disappearing. +std::ostream& PrintEscaped(std::ostream& os, std::string_view str) { + for (char c : str) { + switch (c) { + case '<': + os << "<"; + break; + case '>': + os << ">"; + break; + default: + os << c; + } + } + return os; +} + +// Print options for profiler annotations. +HloPrintOptions PrintOptions() { + auto opts = HloPrintOptions::ShortParsable(); + opts.set_print_large_constants(false); + opts.set_print_control_dependencies(true); + opts.set_print_operand_index_annotation_interval(5); + opts.set_print_backend_config(true); + opts.set_print_metadata(true); + opts.set_print_name_after_closing_brace(true); + return opts; +} + +// Sortable struct representing a frame in the Python stacktrace attached to a +// given instruction. +struct StackFrame { + std::string_view file_name, function_name, op_name; + int line, column; + + private: + auto tied() const { + return std::tie(file_name, line, column, function_name, op_name); + } + friend bool operator==(StackFrame const& lhs, StackFrame const& rhs) { + return lhs.tied() == rhs.tied(); + } + friend bool operator<(StackFrame const& lhs, StackFrame const& rhs) { + return lhs.tied() < rhs.tied(); + } +}; + +// Walk through the HLO graph from an instruction and collect the source +// file/line information we see along the way. This allows us to generate an +// annotation for each kernel that shows the (merged) Python stacktraces of the +// operations that were traced and compiled int this kernel. For example: +// +// - /opt/jax/examples/mnist_vae.py:143[] +// -- /opt/jax/examples/mnist_vae.py:127[run_epoch] +// --- /opt/jax/examples/mnist_vae.py:125[body_fun] +// ---- /opt/jax/examples/mnist_vae.py:124[] +// ----- /opt/jax/examples/mnist_vae.py:122[body_fun] transpose[permutation=(1, +// 0)] +// --- /opt/jax/examples/mnist_vae.py:126[body_fun] add +// --- /opt/jax/examples/mnist_vae.py:126[body_fun] mul +// --- /opt/jax/examples/mnist_vae.py:126[body_fun] sub +// +// shows four merged stacktraces (3 of depth 3, 1 of depth 5). +class SourceLocationVisitor : public ConstDfsHloVisitorWithDefault { + public: + explicit SourceLocationVisitor( + std::string_view op_name_prefix_to_remove__ = {}) + : op_name_prefix_to_remove_{op_name_prefix_to_remove__} {} + + std::string AsString(int32_t common_prefix) const { + // Format the call stacks we've collected; if call stack collection was not + // enabled then each "stack" just has depth 1 and no column/function name + // information. Skip the first `common_prefix` elements of each stack trace + if (common_prefix < 0) { + return "[invalid common_prefix]"; + } + std::ostringstream oss{}; + oss << '\n'; + std::vector current_state{}; + for (auto const& call_stack : location_set_) { + for (auto depth = 0; depth < call_stack.size() - common_prefix; ++depth) { + auto const& frame = call_stack[common_prefix + depth]; + if (depth < current_state.size() && current_state[depth] == frame) { + continue; + } + current_state.resize(depth + 1); + current_state[depth] = frame; + FormatFrame(oss, frame, depth); + } + } + return std::move(oss).str(); + } + + Status DefaultAction(HloInstruction const* inst) final { + OpMetadata const& meta = inst->metadata(); + // The full op_name is split across three places: the module-level + // annotation shows the prefix that is common to the whole module, the + // kernel-level annotation removes that prefix and shows whatever middle + // sections of the name are common to all operations in the kernel, and the + // individual call stack frames in the kernel-level annotation show the + // final parts of the op_name that have not already been shown. + std::string_view op_name = meta.op_name(); + if (!op_name.empty()) { + op_name = op_name.substr(op_name_prefix_to_remove_.size()); + } + if (!op_name.empty() && op_name.front() == '/') { + op_name = op_name.substr(1); + } + if (int frame_id = meta.stack_frame_id(); frame_id != 0) { + std::vector call_stack{}; + HloModule const* const hlo_module = inst->parent()->parent(); + while (frame_id != 0) { + HloModule::StackFrame frame = hlo_module->get_stack_frame(frame_id); + if (frame.empty()) { + break; + } + frame_id = frame.parent_frame_id; + call_stack.emplace_back(StackFrame{frame.file_name, frame.function_name, + op_name, frame.line, frame.column}); + // only attach the op_name to the most-nested frame + op_name = {}; + } + // re-order to be [caller, callee, ...] + std::reverse(call_stack.begin(), call_stack.end()); + location_set_.emplace(call_stack); + } else if (!meta.source_file().empty() && meta.source_line() != 0) { + location_set_.emplace(1, StackFrame{meta.source_file(), + {/* function_name */}, + op_name, + meta.source_line()}); + } + return OkStatus(); + } + + std::pair LongestSourceLocationPrefix() const { + // Find the longest common prefix along the members of location_set_ and + // return a formatted version of that prefix, along with its length. As + // location_set_ is sorted, that just means looking for the longest common + // prefix of the first and last elements. + if (location_set_.size() < 2) { + // Only extract a prefix if there are enough stack traces. + return {}; + } + const auto& first_loc = *location_set_.begin(); + const auto common_end = std::mismatch(first_loc.begin(), first_loc.end(), + location_set_.rbegin()->begin(), + location_set_.rbegin()->end()) + .first; + std::ostringstream oss{}; + oss << '\n'; + std::for_each(first_loc.begin(), common_end, + [&oss](const StackFrame& frame) { FormatFrame(oss, frame); }); + const int32_t prefix_frames = std::distance(first_loc.begin(), common_end); + return {RegisterString(std::move(oss).str()), prefix_frames}; + } + + private: + static void FormatFrame(std::ostringstream& oss, const StackFrame& frame, + int depth = -1) { + if (depth >= 0) { + oss << std::string(depth + 1, '-') << ' '; + } + PrintEscaped(oss, frame.file_name) << ':' << frame.line; + if (frame.column) { + oss << ':' << frame.column; + } + if (!frame.function_name.empty()) { + PrintEscaped(oss << '[', frame.function_name) << ']'; + } + if (!frame.op_name.empty()) { + PrintEscaped(oss << ' ', frame.op_name); + } + oss << '\n'; + } + std::string_view op_name_prefix_to_remove_{}; + std::set> location_set_{}; +}; + template absl::Status VisitInstAndCalledButNotOperands(Visitor& visitor, const HloInstruction& inst) { @@ -90,16 +289,17 @@ class OpNamePrefixVisitor : public ConstDfsHloVisitorWithDefault { absl::Status DefaultAction(const HloInstruction* inst) final { auto const& op_name = inst->metadata().op_name(); if (!op_name.empty()) { - prefix = prefix ? LongestPrefix(*prefix, op_name) : op_name; + prefix_ = prefix_ ? LongestPrefix(*prefix_, op_name) : op_name; } return absl::OkStatus(); } + std::string_view longest_op_name_prefix() const { - return prefix.value_or(std::string_view{}); + return prefix_.value_or(""); } private: - std::optional prefix{}; + std::optional prefix_; }; std::string_view GetLongestOpNamePrefix(const HloModule& mod) { @@ -132,30 +332,120 @@ std::string MakeTitle(const HloModule& mod, std::string_view longest_prefix) { return absl::StrFormat("XlaModule:#prefix=%s,hlo_module=%s,program_id=%d#", longest_prefix, mod.name(), mod.unique_id()); } + +std::string FormatSourceLocations(HloInstruction const& inst, + int32_t common_frames) { + // Inside the source location/backtrace report the op_name too, but remove the + // kernel-wide prefix for brevity + SourceLocationVisitor visitor{GetLongestOpNamePrefix(inst)}; + // Visit the given instruction, and the things it calls, but not its operands + // -- we don't want to collect the source code locations that produced the + // inputs to this kernel, just those corresponding to the kernel itself. + if (!VisitInstAndCalledButNotOperands(visitor, inst).ok()) { + return "[error]"; + } + return visitor.AsString(common_frames); +} + +// Get the string representation of this instruction as an std::string. +std::string InstructionAsString(HloInstruction const& inst) { + StringPrinter printer; + inst.Print(&printer, PrintOptions()); + return std::move(printer).ToString(); +} + +// Get the string representation of the HLO code called by this instruction, +// but not the instruction itself. The typical example is a fusion instruction, +// where InstructionAsString(fusion_inst) would be something like +// fusion.N = ... fusion(...), calls=fused_computation.N ... +// and CalledInstructionsAsString(fusion_inst) would be something like +// fused_computation.N { ... } +std::string CalledInstructionsAsString(HloInstruction const& inst) { + StringPrinter printer; + auto const opts = PrintOptions(); + for (HloComputation const* called : inst.called_computations()) { + called->Print(&printer, opts); + } + return std::move(printer).ToString(); +} + +// Get a string representing the longest common prefix of source locations in +// this module, and the number of frames that that represents. +std::pair GetLongestSourceLocationPrefix( + const HloModule& mod) { + // In the presence of (at least) debug callbacks, calling Accept on the root + // instruction of the module may not reach all instructions in the module. + SourceLocationVisitor visitor{}; + for (const HloComputation* computation : mod.computations()) { + for (const HloInstruction* inst : computation->instructions()) { + if (!visitor.DefaultAction(inst).ok()) { + return {}; + } + } + } + return visitor.LongestSourceLocationPrefix(); +} } // namespace -ModuleAnnotation::ModuleAnnotation(std::string module_name_, int module_id_) - : longest_prefix{}, - title_str{ - module_id_ >= 0 - ? absl::StrFormat("XlaModule:#hlo_module=%s,program_id=%d", - module_name_, module_id_) - : absl::StrFormat("XlaModule:#hlo_module=%s", module_name_)}, - title{RegisterString(title_str.c_str())} {} +ModuleAnnotation::ModuleAnnotation(std::string_view module_name_) + : title_str_(absl::StrFormat("XlaModule:#hlo_module=%s#", module_name_)), + title_(RegisterString(title_str_)), + module_name_(RegisterString(std::string{module_name_})) {} ModuleAnnotation::ModuleAnnotation(const HloModule& mod) - : longest_prefix{GetLongestOpNamePrefix(mod)}, - title_str{MakeTitle(mod, longest_prefix)}, - title{RegisterString(title_str.c_str())} {} - -std::string_view ModuleAnnotation::longest_op_name_prefix() const { - return longest_prefix; + : longest_prefix_(GetLongestOpNamePrefix(mod)), + title_str_(MakeTitle(mod, longest_prefix_)), + title_(RegisterString(title_str_)), + module_name_(RegisterString(mod.name())), + module_id_(mod.unique_id()) { + std::tie(common_src_locations_, common_stack_frames_) = + GetLongestSourceLocationPrefix(mod); } -std::string_view ModuleAnnotation::Title() const { return title_str; } +#if GOOGLE_CUDA +namespace { +auto schema_entry(uint64_t type, const char* name, uint64_t offset) { + nvtxPayloadSchemaEntry_t r{}; + r.type = type; + r.name = name; + r.offset = offset; + return r; +} +} // namespace +#endif -nvtxStringHandle_t ModuleAnnotation::NvtxRegisteredTitle() const { - return title; +uint64_t ModuleAnnotation::NvtxSchemaId() { + static std::uint64_t schema_id = []() -> std::uint64_t { +#if GOOGLE_CUDA + auto domain_opt = tsl::profiler::GetNVTXDomain(); + if (!domain_opt.has_value()) { + return 0; + } + const nvtxPayloadSchemaEntry_t schema[] = { + schema_entry(NVTX_PAYLOAD_ENTRY_TYPE_NVTX_REGISTERED_STRING_HANDLE, + "Name", offsetof(ModuleAnnotation, module_name_)), + schema_entry(NVTX_PAYLOAD_ENTRY_TYPE_INT32, "Unique ID", + offsetof(ModuleAnnotation, module_id_)), + schema_entry(NVTX_PAYLOAD_ENTRY_TYPE_NVTX_REGISTERED_STRING_HANDLE, + "Common source locations", + offsetof(ModuleAnnotation, common_src_locations_))}; + const nvtxPayloadSchemaAttr_t schemaAttr = { + /* .fieldMask = */ NVTX_PAYLOAD_SCHEMA_ATTR_NAME | + NVTX_PAYLOAD_SCHEMA_ATTR_TYPE | NVTX_PAYLOAD_SCHEMA_ATTR_ENTRIES | + NVTX_PAYLOAD_SCHEMA_ATTR_NUM_ENTRIES | + NVTX_PAYLOAD_SCHEMA_ATTR_STATIC_SIZE, + /* .name = */ "XlaModule", + /* .type = */ NVTX_PAYLOAD_SCHEMA_TYPE_STATIC, + /* .flags = */ NVTX_PAYLOAD_SCHEMA_FLAG_NONE, + /* .entries = */ schema, + /* .numEntries = */ sizeof(schema) / sizeof(schema[0]), + /* .payloadStaticSize = */ sizeof(ModuleAnnotation)}; + return nvtxPayloadSchemaRegister(*domain_opt, &schemaAttr); +#else + return 0; +#endif + }(); + return schema_id; } namespace { @@ -186,42 +476,105 @@ std::string MakeKernelName(std::string_view prefix, KernelAnnotation::KernelAnnotation(const ModuleAnnotation& module_annotation, const HloInstruction& inst) - : title_str{MakeKernelName(module_annotation.longest_op_name_prefix(), - inst)}, - title{RegisterString(title_str.c_str())} {} + : title_str( + MakeKernelName(module_annotation.longest_op_name_prefix(), inst)), + title(RegisterString(title_str)), + hlo_dump(RegisterString(InstructionAsString(inst))), + src_locations(RegisterString(FormatSourceLocations( + inst, module_annotation.common_stack_frames()))), + called_hlo_dump(RegisterString("\n" + CalledInstructionsAsString(inst))) { +} -std::string_view KernelAnnotation::Title() const { return title_str; } +ModuleAnnotations::ModuleAnnotations(std::string_view module_name) + : top_level(module_name) {} -nvtxStringHandle_t KernelAnnotation::NvtxRegisteredTitle() const { - return title; +uint64_t KernelAnnotation::NvtxSchemaId() { + static std::uint64_t schema_id = []() -> std::uint64_t { +#if GOOGLE_CUDA + auto domain_opt = tsl::profiler::GetNVTXDomain(); + if (!domain_opt.has_value()) { + return 0; + } + const nvtxPayloadSchemaEntry_t schema[] = { + schema_entry(NVTX_PAYLOAD_ENTRY_TYPE_NVTX_REGISTERED_STRING_HANDLE, + "Source locations", + offsetof(KernelAnnotation, src_locations)), + schema_entry(NVTX_PAYLOAD_ENTRY_TYPE_NVTX_REGISTERED_STRING_HANDLE, + "HLO", offsetof(KernelAnnotation, hlo_dump)), + schema_entry(NVTX_PAYLOAD_ENTRY_TYPE_NVTX_REGISTERED_STRING_HANDLE, + "Called HLO", + offsetof(KernelAnnotation, called_hlo_dump))}; + const nvtxPayloadSchemaAttr_t schemaAttr = { + /* .fieldMask = */ NVTX_PAYLOAD_SCHEMA_ATTR_NAME | + NVTX_PAYLOAD_SCHEMA_ATTR_TYPE | NVTX_PAYLOAD_SCHEMA_ATTR_ENTRIES | + NVTX_PAYLOAD_SCHEMA_ATTR_NUM_ENTRIES | + NVTX_PAYLOAD_SCHEMA_ATTR_STATIC_SIZE, + /* .name = */ "XlaKernel", + /* .type = */ NVTX_PAYLOAD_SCHEMA_TYPE_STATIC, + /* .flags = */ NVTX_PAYLOAD_SCHEMA_FLAG_NONE, + /* .entries = */ schema, + /* .numEntries = */ sizeof(schema) / sizeof(schema[0]), + /* .payloadStaticSize = */ sizeof(KernelAnnotation)}; + return nvtxPayloadSchemaRegister(*domain_opt, &schemaAttr); +#else + return 0; +#endif + }(); + return schema_id; } ModuleAnnotations::ModuleAnnotations(const HloModule& mod) : top_level{mod} { // loop through `mod` and populate `kernels` (string -> KernelAnnotation map) // with the information we want to attach to individual kernels. - for (const HloComputation* computation : - mod.computations()) { // top-level blocks in the module - for (const HloInstruction* inst : - computation->instructions()) { // statements within block - // working assumption: only custom calls and fusions end up with NVTX - // ranges named after them. bad assumption [at least partially]: cuda - // graph launches are not handled correctly - switch (inst->opcode()) { - case HloOpcode::kCustomCall: - case HloOpcode::kFusion: { - // e.g. inst.name is "fusion.6", inst.opcode is "kFusion" and called - // is ["fused_computation.5"], in which case the content of - // "fused_computation.5" ends up under an NVTX range called - // "fusion.6". We want to construct a useful annotation for that NVTX - // range based on the content of `inst`, including `called` etc. - // FIXME: using try_emplace here was sensitive to - // https://github.com/abseil/abseil-cpp/issues/388. - kernels.insert({inst->name(), {top_level, *inst}}); - } break; - default: - break; - } + for (const HloComputation* computation : mod.computations()) { + for (const HloInstruction* inst : computation->instructions()) { + // e.g. inst.name is "fusion.6", inst.opcode is "kFusion" and called + // is ["fused_computation.5"], in which case the content of + // "fused_computation.5" ends up under an NVTX range called + // "fusion.6". We want to construct a useful annotation for that NVTX + // range based on the content of `inst`, including `called` etc. + // FIXME: using try_emplace here was sensitive to + // https://github.com/abseil/abseil-cpp/issues/388. + kernels.insert({inst->name(), {top_level, *inst}}); + } + } +} + +//===----------------------------------------------------------------------===// +// Scoped RAII helper to set and restore thread local module annotations +//===----------------------------------------------------------------------===// + +namespace { +thread_local const ModuleAnnotations* current_annotations = nullptr; +} // namespace + +ScopedModuleAnnotations::ScopedModuleAnnotations( + const ModuleAnnotations* annotations) + : restore_(std::exchange(current_annotations, annotations)) {} + +ScopedModuleAnnotations::~ScopedModuleAnnotations() { + std::exchange(current_annotations, restore_); +} + +const ModuleAnnotations* GetCurrentModuleAnnotations() { + return current_annotations; +} + +std::optional GetKernelAnnotation( + const ModuleAnnotations* annotations, std::string_view profile_annotation) { + if (profile_annotation.empty()) { + return {}; + } + if (annotations) { + // Have a set of pre-prepared thunk/kernel annotations to use + const auto iter = annotations->kernels.find(profile_annotation); + if (iter != annotations->kernels.end()) { + // Have a pre-prepared annotation, use it + return std::optional{[&] { return iter->second; }}; } } + return std::optional{ + [&] { return absl::StrFormat("Thunk:#hlo_op=%s#", profile_annotation); }}; } + } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/runtime/annotation.h b/third_party/xla/xla/service/gpu/runtime/annotation.h index 9721ad80fd3049..eef1f8d91e5318 100644 --- a/third_party/xla/xla/service/gpu/runtime/annotation.h +++ b/third_party/xla/xla/service/gpu/runtime/annotation.h @@ -16,44 +16,95 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_RUNTIME_ANNOTATION_H_ #define XLA_SERVICE_GPU_RUNTIME_ANNOTATION_H_ +#include +#include +#include + #include "absl/container/flat_hash_map.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "tsl/profiler/lib/nvtx_utils.h" +#include "tsl/profiler/lib/scoped_annotation.h" namespace xla::gpu { + // Prepared information for the top level NVTX/profiler range covering an // HloModule -struct ModuleAnnotation { - ModuleAnnotation(std::string module_name, int module_id); - ModuleAnnotation(const HloModule& mod); - std::string_view longest_op_name_prefix() const; - nvtxStringHandle_t NvtxRegisteredTitle() const; - std::string_view Title() const; +class ModuleAnnotation { + public: + explicit ModuleAnnotation(std::string_view module_name); + explicit ModuleAnnotation(const HloModule& mod); + + std::string_view longest_op_name_prefix() const { return longest_prefix_; } + explicit operator std::string_view() const { return title_str_; } + nvtxStringHandle_t title() const { return title_; } + static uint64_t NvtxSchemaId(); + int32_t common_stack_frames() const { return common_stack_frames_; } private: - std::string longest_prefix; - std::string title_str; - nvtxStringHandle_t title{}; + friend void RangePush(nvtxDomainHandle_t domain, + const ModuleAnnotation& annotation) { + tsl::profiler::RangePush(domain, annotation.title(), annotation); + } + + std::string longest_prefix_; + std::string title_str_; + nvtxStringHandle_t title_; + nvtxStringHandle_t module_name_; + nvtxStringHandle_t common_src_locations_{}; + int32_t module_id_{-1}; + int32_t common_stack_frames_{}; }; // Prepared information for a kernel/thunk/fusion/... within an HloModule struct KernelAnnotation { - KernelAnnotation(const ModuleAnnotation& module_annotaion, + KernelAnnotation(const ModuleAnnotation& module_annotation, const HloInstruction& inst); - nvtxStringHandle_t NvtxRegisteredTitle() const; - std::string_view Title() const; + + explicit operator std::string_view() const { return title_str; } + static uint64_t NvtxSchemaId(); private: + friend void RangePush(nvtxDomainHandle_t domain, + const KernelAnnotation& annotation) { + tsl::profiler::RangePush(domain, annotation.title, annotation); + } + std::string title_str; - nvtxStringHandle_t title{}; + nvtxStringHandle_t title; + nvtxStringHandle_t hlo_dump; + nvtxStringHandle_t src_locations; + nvtxStringHandle_t called_hlo_dump; }; + // Parsed/prepared information for an HloModule that gets propagated to NVTX // ranges/profilers/... at execution time. struct ModuleAnnotations { - ModuleAnnotations(const HloModule&); + explicit ModuleAnnotations(std::string_view module_name); + explicit ModuleAnnotations(const HloModule&); + ModuleAnnotation top_level; - absl::flat_hash_map kernels{}; + absl::flat_hash_map kernels; }; + +//===----------------------------------------------------------------------===// +// Scoped RAII helper to set and restore thread local module annotations +//===----------------------------------------------------------------------===// + +class ScopedModuleAnnotations { + public: + explicit ScopedModuleAnnotations(const ModuleAnnotations* annotations); + ~ScopedModuleAnnotations(); + + private: + const ModuleAnnotations* restore_; +}; + +const ModuleAnnotations* GetCurrentModuleAnnotations(); + +std::optional GetKernelAnnotation( + const ModuleAnnotations* annotations, std::string_view profile_annotation); + } // namespace xla::gpu #endif // XLA_SERVICE_GPU_RUNTIME_ANNOTATION_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/cholesky.cc b/third_party/xla/xla/service/gpu/runtime/cholesky.cc deleted file mode 100644 index 4291d156b0ef19..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/cholesky.cc +++ /dev/null @@ -1,84 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/cholesky.h" - -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/service/gpu/gpu_asm_opts_util.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/xla.pb.h" - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/service/gpu/runtime3/cholesky_thunk.h" -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -namespace xla { -namespace gpu { - -using ::xla::runtime::CustomCall; -using ::xla::runtime::MemrefView; -using ::xla::runtime::StridedMemrefView; - -static absl::Status CholeskyImpl(const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, - StridedMemrefView operand, StridedMemrefView a, - MemrefView workspace, MemrefView info, - int64_t batch_size, bool is_lower, int64_t n) { -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - se::DeviceMemoryBase operand_buffer = GetDeviceAddress(operand); - se::DeviceMemoryBase a_buffer = GetDeviceAddress(a); - se::DeviceMemoryBase workspace_buffer = GetDeviceAddress(workspace); - se::DeviceMemoryBase info_buffer = GetDeviceAddress(info); - - VLOG(3) << "Running Cholesky"; - se::Stream* stream = run_options->stream(); - - // Copy operand to the a buffer if they are different. - if (a.data != operand.data) - stream->ThenMemcpy(&a_buffer, operand_buffer, operand_buffer.size()); - - using UpperLower = se::blas::UpperLower; - UpperLower uplo = is_lower ? UpperLower::kLower : UpperLower::kUpper; - - CholeskyParams params{n, batch_size, uplo, - a_buffer, workspace_buffer, info_buffer}; - return RunCholesky(xla::gpu::PtxOptsFromDebugOptions(*debug_options), - operand.dtype, ¶ms, stream); -#else // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - return absl::InternalError("Cholesky is not supported without GPU"); -#endif -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - Cholesky, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.cholesky") - .UserData() - .UserData() - .Arg() // operand - .Arg() // a - .Arg() // workspace - .Arg() // info - .Attr("batch_size") - .Attr("is_lower") - .Attr("n")); - -void RegisterCholeskyCustomCalls(runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.cholesky", Cholesky); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/cholesky.h b/third_party/xla/xla/service/gpu/runtime/cholesky.h deleted file mode 100644 index 0a8639093e2c59..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/cholesky.h +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_CHOLESKY_H_ -#define XLA_SERVICE_GPU_RUNTIME_CHOLESKY_H_ - -#include "xla/runtime/custom_call_registry.h" - -namespace xla { -namespace gpu { - -// Registers XLA Gpu runtime cholesky custom calls. -void RegisterCholeskyCustomCalls(runtime::DirectCustomCallRegistry& registry); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_CHOLESKY_H_ diff --git a/third_party/xla/xla/service/gpu/runtime3/cholesky_thunk.cc b/third_party/xla/xla/service/gpu/runtime/cholesky_thunk.cc similarity index 61% rename from third_party/xla/xla/service/gpu/runtime3/cholesky_thunk.cc rename to third_party/xla/xla/service/gpu/runtime/cholesky_thunk.cc index d6db6468211924..b91be4449d3b8b 100644 --- a/third_party/xla/xla/service/gpu/runtime3/cholesky_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/cholesky_thunk.cc @@ -13,20 +13,25 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime3/cholesky_thunk.h" +#include "xla/service/gpu/runtime/cholesky_thunk.h" #include +#include #include #include #include +#include "absl/status/status.h" +#include "absl/strings/str_format.h" #include "xla/service/gpu/cusolver_context.h" #include "xla/service/gpu/make_batch_pointers.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/gpu/gpu_asm_opts.h" #include "xla/stream_executor/stream_executor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" namespace xla { @@ -63,6 +68,26 @@ absl::Status DoPotrfBatched(const se::GpuAsmOpts& asm_opts, params->batch_size); } +template +absl::Status DoPotrfUnbatched(const se::GpuAsmOpts& asm_opts, + CholeskyParams* params, se::Stream* stream, + GpuSolverContext& context) { + T* a_base = static_cast(params->a_buffer.opaque()); + int* info_base = static_cast(params->info_buffer.opaque()); + + int64_t stride = params->n * params->n; + for (int64_t i = 0; i < params->batch_size; ++i) { + se::DeviceMemory a_data( + se::DeviceMemoryBase(&a_base[i * stride], sizeof(T) * stride)); + se::DeviceMemory info_data( + se::DeviceMemoryBase(&info_base[i], sizeof(int))); + se::DeviceMemory workspace_data(params->workspace_buffer); + TF_RETURN_IF_ERROR(context.Potrf(params->uplo, params->n, a_data, params->n, + info_data, workspace_data)); + } + return absl::OkStatus(); +} + } // namespace CholeskyThunk::CholeskyThunk(ThunkInfo thunk_info, @@ -109,23 +134,43 @@ absl::Status RunCholesky(const se::GpuAsmOpts& asm_opts, PrimitiveType type, TF_RETURN_IF_ERROR(context.status()); TF_RETURN_IF_ERROR(context->SetStream(stream)); - switch (type) { - case F32: - return DoPotrfBatched(asm_opts, cholesky_params, stream, *context); - case F64: - return DoPotrfBatched(asm_opts, cholesky_params, stream, - *context); - case C64: - return DoPotrfBatched>(asm_opts, cholesky_params, - stream, *context); - case C128: - return DoPotrfBatched>(asm_opts, cholesky_params, - stream, *context); - default: - return InvalidArgument("Invalid type for cholesky %s", - PrimitiveType_Name(type)); + if (cholesky_params->batch_size > 1) { + switch (type) { + case F32: + return DoPotrfBatched(asm_opts, cholesky_params, stream, + *context); + case F64: + return DoPotrfBatched(asm_opts, cholesky_params, stream, + *context); + case C64: + return DoPotrfBatched>(asm_opts, cholesky_params, + stream, *context); + case C128: + return DoPotrfBatched>(asm_opts, cholesky_params, + stream, *context); + default: + return InvalidArgument("Invalid type for cholesky %s", + PrimitiveType_Name(type)); + } + } else { + switch (type) { + case F32: + return DoPotrfUnbatched(asm_opts, cholesky_params, stream, + *context); + case F64: + return DoPotrfUnbatched(asm_opts, cholesky_params, stream, + *context); + case C64: + return DoPotrfUnbatched>(asm_opts, cholesky_params, + stream, *context); + case C128: + return DoPotrfUnbatched>(asm_opts, cholesky_params, + stream, *context); + default: + return InvalidArgument("Invalid type for cholesky %s", + PrimitiveType_Name(type)); + } } } - } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime3/cholesky_thunk.h b/third_party/xla/xla/service/gpu/runtime/cholesky_thunk.h similarity index 94% rename from third_party/xla/xla/service/gpu/runtime3/cholesky_thunk.h rename to third_party/xla/xla/service/gpu/runtime/cholesky_thunk.h index 26054a194e4f62..e56c01a7fa2646 100644 --- a/third_party/xla/xla/service/gpu/runtime3/cholesky_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/cholesky_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME3_CHOLESKY_THUNK_H_ -#define XLA_SERVICE_GPU_RUNTIME3_CHOLESKY_THUNK_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_CHOLESKY_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_CHOLESKY_THUNK_H_ #include @@ -80,4 +80,4 @@ absl::Status RunCholesky(const se::GpuAsmOpts& asm_opts, PrimitiveType type, } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_RUNTIME3_CHOLESKY_THUNK_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_CHOLESKY_THUNK_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/collectives.cc b/third_party/xla/xla/service/gpu/runtime/collectives.cc deleted file mode 100644 index 9dcfe6a42e06ea..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/collectives.cc +++ /dev/null @@ -1,1028 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/collectives.h" - -#include -#include -#include -#include -#include -#include - -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/strings/str_format.h" -#include "absl/types/span.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/service/collective_ops_utils.h" -#include "xla/service/computation_placer.h" -#include "xla/service/global_device_id.h" -#include "xla/service/gpu/gpu_executable_run_options.h" -#include "xla/service/gpu/nccl_all_gather_thunk.h" -#include "xla/service/gpu/nccl_all_reduce_thunk.h" -#include "xla/service/gpu/nccl_all_to_all_thunk.h" -#include "xla/service/gpu/nccl_api.h" -#include "xla/service/gpu/nccl_collective_permute_thunk.h" -#include "xla/service/gpu/nccl_collective_thunk.h" -#include "xla/service/gpu/nccl_recv_thunk.h" -#include "xla/service/gpu/nccl_send_thunk.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/gpu/thunk.h" -#include "xla/service/service_executable_run_options.h" - -#if XLA_ENABLE_XCCL -#include "xla/service/gpu/mock_nccl_utils.h" -#endif // XLA_ENABLE_XCCL - -namespace xla { -namespace gpu { - -using xla::runtime::CustomCall; -using xla::runtime::FlatMemrefView; -using xla::runtime::StridedMemrefView; - -namespace { - -absl::Status RunRepeated(int32_t count, absl::FunctionRef to_run) { - if (count != 0) { - VLOG(3) << "Running each collective " << count << " times\n"; - } - for (int32_t i = 0; i < count; ++i) { - TF_RETURN_IF_ERROR(to_run()); - } - return absl::OkStatus(); -} - -// Helper function to run a collective either synchronously on main stream or -// asynchronously on the async stream. -absl::Status RunSyncOrAsync( - const ServiceExecutableRunOptions* run_options, - CollectivesSupport* collectives, AsyncCollectivesSupport* async_collectives, - int32_t uid, bool is_async, - absl::FunctionRef to_run, - AsyncStreamKind stream_kind = AsyncStreamKind::kCollective) { - se::Stream* main_stream = run_options->stream(); - se::Stream* async_stream = - is_async ? async_collectives->async_comm_stream(stream_kind) : nullptr; - if (is_async) { - // Wait until compute inputs are ready. - async_stream->ThenWaitFor(main_stream); - } - - // Launch the collective on either the main or async stream. - se::Stream* stream = is_async ? async_stream : main_stream; - TF_RETURN_IF_ERROR(to_run(stream)); - - if (is_async) { - TF_RETURN_IF_ERROR(async_collectives->RecordEvent(uid, stream_kind)); - } - int32_t device_ordinal = main_stream->parent()->device_ordinal(); - return collectives->MaybeBlockAfterFirstRun(uid, device_ordinal, main_stream); -} - -#if XLA_ENABLE_XCCL -bool ShouldEnableCliqueOptimization( - const Thunk::CollectiveExecuteParams& params, - const DebugOptions* debug_options, bool no_parallel_custom_call) { - // Enable clique optimization for single-host application, which is indicated - // by the absence of nccl_clique_id_callback. For multiple-host, only enable - // when a debug flag is set for now, due to some divergent compilation issues. - return no_parallel_custom_call && - (!params.nccl_clique_id_callback || - debug_options->xla_gpu_enable_nccl_clique_optimization()); -} - -absl::StatusOr GetNcclComm( - const Thunk::CollectiveExecuteParams& params, int64_t group_mode, - int64_t op_id, absl::Span replica_group_offsets, - absl::Span replica_group_values, int64_t stream_id, - bool enable_clique_optimization) { - // TODO(b/233930690): Pass the attribute below as a nested array. - // Pass an array of arrays using two vectors; one specifying all the values - // and another specifying the (ending) offsets of each array in the other - // vector. Example: [ [10, 20, 30, 40], [50, 60], [70, 80, 90] ] turns into - // offsets=[4, 6, 9] values=[10, 20, 30, 40, 50, 60, 70, 80, 90]. - std::vector replica_groups; - int i = 0; - for (int64_t replica_group_end : replica_group_offsets) { - ReplicaGroup replica_group; - while (i < replica_group_end) - replica_group.add_replica_ids(replica_group_values[i++]); - replica_groups.push_back(replica_group); - } - - return LockNcclComm(params, replica_groups, - static_cast(group_mode), op_id, - stream_id, enable_clique_optimization); -} - -absl::StatusOr GetMockNcclComm( - const Thunk::CollectiveExecuteParams& params, int64_t group_mode, - int64_t op_id, absl::Span replica_group_offsets, - absl::Span replica_group_values, int64_t stream_id, - bool enable_clique_optimization, - GpuExecutableRunOptions::MockNcclTopoModel topo_model) { - // TODO(b/233930690): Pass the attribute below as a nested array. - // Pass an array of arrays using two vectors; one specifying all the values - // and another specifying the (ending) offsets of each array in the other - // vector. Example: [ [10, 20, 30, 40], [50, 60], [70, 80, 90] ] turns into - // offsets=[4, 6, 9] values=[10, 20, 30, 40, 50, 60, 70, 80, 90]. - std::vector replica_groups; - int i = 0; - for (int64_t replica_group_end : replica_group_offsets) { - ReplicaGroup replica_group; - while (i < replica_group_end) - replica_group.add_replica_ids(replica_group_values[i++]); - replica_groups.push_back(replica_group); - } - - return LockMockNcclComm(params, replica_groups, - static_cast(group_mode), op_id, - stream_id, enable_clique_optimization, topo_model); -} -#endif // XLA_ENABLE_XCCL - -absl::StatusOr> GetDeviceBufferPairs( - CustomCall::RemainingArgs& args) { - // Add MemRef arguments as buffer arguments. - TF_RET_CHECK(args.size() % 2 == 0); - const int buffer_pairs = args.size() / 2; - std::vector device_buffers; - device_buffers.reserve(buffer_pairs); - for (int i = 0; i < buffer_pairs; ++i) { - auto source = args.get(i); - auto destination = args.get(i + buffer_pairs); - if (failed(source) || failed(destination)) { - return InvalidArgument("Unsupported device buffer pair type"); - } - - int64_t element_count = 1; - for (int64_t size : source->sizes) element_count *= size; - device_buffers.emplace_back(DeviceBufferPair{ - source->dtype, element_count, GetDeviceAddress(*source), - GetDeviceAddress(*destination)}); - } - return device_buffers; -} - -// Expects a single argument, and returns a device buffer pair with that -// argument replicated in both source and destination buffer. -absl::StatusOr> GetSingleArgAsDeviceBufferPair( - CustomCall::RemainingArgs& args) { - TF_RET_CHECK(args.size() == 1); - auto buffer = args.get(0); - if (failed(buffer)) { - return InvalidArgument("Unsupported device buffer type"); - } - int64_t element_count = 1; - for (int64_t size : buffer->sizes) element_count *= size; - return std::vector{ - DeviceBufferPair{buffer->dtype, element_count, GetDeviceAddress(*buffer), - GetDeviceAddress(*buffer)}}; -} - -absl::Status AsyncDoneImpl(const ServiceExecutableRunOptions* run_options, - AsyncCollectivesSupport* async_collectives, - int32_t uid, std::string_view done_type) { -#if XLA_ENABLE_XCCL - VLOG(3) << "Running " << done_type; - se::Stream* stream = run_options->stream(); - - TF_ASSIGN_OR_RETURN(se::Event event, async_collectives->PopEvent(uid)); - stream->ThenWaitFor(&event); - - return absl::OkStatus(); -#else // XLA_ENABLE_XCCL - return absl::InternalError("NCCL disabled"); -#endif // XLA_ENABLE_XCCL -} - -#if XLA_ENABLE_XCCL -absl::Status MockNcclImplCommon( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, se::Stream* stream, - CustomCall::RemainingArgs args, int64_t group_mode, int64_t op_id, - absl::Span replica_group_offsets, - absl::Span replica_group_values, bool is_async, - Thunk::Kind reduce_op, - GpuExecutableRunOptions::MockNcclTopoModel topo_model) { - TF_ASSIGN_OR_RETURN(Thunk::CollectiveExecuteParams params, - Thunk::CollectiveExecuteParams::Create( - *run_options, stream->parent()->device_ordinal())); - - auto comm = - GetMockNcclComm(params, group_mode, op_id, replica_group_offsets, - replica_group_values, GetStreamId(is_async), - debug_options->xla_gpu_enable_nccl_clique_optimization(), - topo_model); //); - if (absl::IsCancelled(comm.status())) return absl::OkStatus(); - if (!comm.ok()) return comm.status(); - - TF_ASSIGN_OR_RETURN(auto device_buffers, GetDeviceBufferPairs(args)); - - return RunMockNcclCollectives(NcclApi::Default(), device_buffers, *stream, - **comm, reduce_op); -} -#endif // XLA_ENABLE_XCCL - -//===----------------------------------------------------------------------===// -// CollectivePermute. -//===----------------------------------------------------------------------===// - -#if XLA_ENABLE_XCCL -using NcclP2PRunner = absl::FunctionRef; - -using DeviceBuffersGetter = - absl::FunctionRef>( - CustomCall::RemainingArgs& args)>; - -absl::Status MockNcclP2PImplCommon( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, se::Stream* stream, - CustomCall::RemainingArgs args, int64_t group_mode, int64_t op_id, - absl::Span replica_group_offsets, - absl::Span replica_group_values, - absl::Span source_peers, - absl::Span target_peers, NcclP2PRunner runner, - DeviceBuffersGetter device_buffers_getter, uint64_t stream_id, - GpuExecutableRunOptions::MockNcclTopoModel topo_model) { - TF_ASSIGN_OR_RETURN(Thunk::CollectiveExecuteParams params, - Thunk::CollectiveExecuteParams::Create( - *run_options, stream->parent()->device_ordinal())); - - const std::string device_string = - NcclCollectiveThunk::GetDeviceString(params); - - auto comm = GetMockNcclComm( - params, group_mode, op_id, replica_group_offsets, replica_group_values, - stream_id, debug_options->xla_gpu_enable_nccl_clique_optimization(), - topo_model); - if (absl::IsCancelled(comm.status())) return absl::OkStatus(); - if (!comm.ok()) return comm.status(); - - auto device_buffers = device_buffers_getter(args); - if (!device_buffers.ok()) return device_buffers.status(); - if (device_buffers->size() != 1) { - return absl::InternalError(absl::StrFormat( - "Expected device buffer size: 1, got %d", device_buffers->size())); - } - - GlobalDeviceId global_device_id = params.global_device_id; - - TF_ASSIGN_OR_RETURN(DeviceAssignment::LogicalID current_logical_id, - params.device_assn->LogicalIdForDevice(global_device_id)); - - const int64_t current_id = static_cast(group_mode) == - CollectiveOpGroupMode::kCrossReplica - ? current_logical_id.replica_id - : current_logical_id.computation_id; - - NcclP2PConfig::IdToSourceTargetMap id_to_source_target; - for (int i = 0; i < source_peers.size(); ++i) { - id_to_source_target[target_peers[i]].source = source_peers[i]; - id_to_source_target[source_peers[i]].target = target_peers[i]; - } - const NcclP2PConfig::SourceTargetMapEntry source_target = - NcclP2PConfig::GetSourceTarget(id_to_source_target, current_id); - - return runner(NcclApi::Default(), source_target, (*device_buffers)[0], - *stream, **comm, device_string, current_id); -} - -absl::Status P2PImplCommon(const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, - se::Stream* stream, CustomCall::RemainingArgs args, - int64_t group_mode, int64_t op_id, - bool no_parallel_custom_call, - absl::Span replica_group_offsets, - absl::Span replica_group_values, - absl::Span source_peers, - absl::Span target_peers, - NcclP2PRunner runner, - DeviceBuffersGetter device_buffers_getter, - uint64_t stream_id) { - (void)no_parallel_custom_call; - TF_ASSIGN_OR_RETURN(Thunk::CollectiveExecuteParams params, - Thunk::CollectiveExecuteParams::Create( - *run_options, stream->parent()->device_ordinal())); - bool enable_clique_opt = ShouldEnableCliqueOptimization( - params, debug_options, no_parallel_custom_call); - - const std::string device_string = - NcclCollectiveThunk::GetDeviceString(params); - auto comm = GetNcclComm(params, group_mode, op_id, replica_group_offsets, - replica_group_values, stream_id, enable_clique_opt); - if (!comm.ok()) return comm.status(); - - auto device_buffers = device_buffers_getter(args); - if (!device_buffers.ok()) return device_buffers.status(); - if (device_buffers->size() != 1) { - return absl::InternalError(absl::StrFormat( - "Expected device buffer size: 1, got %d", device_buffers->size())); - } - - GlobalDeviceId global_device_id = params.global_device_id; - - TF_ASSIGN_OR_RETURN(DeviceAssignment::LogicalID current_logical_id, - params.device_assn->LogicalIdForDevice(global_device_id)); - - const int64_t current_id = static_cast(group_mode) == - CollectiveOpGroupMode::kCrossReplica - ? current_logical_id.replica_id - : current_logical_id.computation_id; - - NcclP2PConfig::IdToSourceTargetMap id_to_source_target; - for (int i = 0; i < source_peers.size(); ++i) { - id_to_source_target[target_peers[i]].source = source_peers[i]; - id_to_source_target[source_peers[i]].target = target_peers[i]; - } - const NcclP2PConfig::SourceTargetMapEntry source_target = - NcclP2PConfig::GetSourceTarget(id_to_source_target, current_id); - - return RunRepeated(debug_options->xla_gpu_collective_inflation_factor(), - [&]() -> absl::Status { - return runner(NcclApi::Default(), source_target, - (*device_buffers)[0], *stream, **comm, - device_string, current_id); - }); -} -#endif // XLA_ENABLE_XCCL - -absl::Status CollectivePermuteImpl( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, CollectivesSupport* collectives, - AsyncCollectivesSupport* async_collectives, CustomCall::RemainingArgs args, - int32_t uid, int64_t group_mode, int64_t op_id, bool is_async, - bool no_parallel_custom_call, - absl::Span replica_group_offsets, - absl::Span replica_group_values, - absl::Span source_peers, - absl::Span target_peers) { -#if XLA_ENABLE_XCCL - VLOG(3) << "Running CollectivePermute " << (is_async ? "(Async) " : "(Sync) ") - << no_parallel_custom_call; - return RunSyncOrAsync( - run_options, collectives, async_collectives, uid, is_async, - [&](se::Stream* stream) { - const gpu::GpuExecutableRunOptions* gpu_opts = - run_options->run_options().gpu_executable_run_options(); - if (gpu_opts && gpu_opts->enable_mock_nccl_collectives()) { - return MockNcclP2PImplCommon( - run_options, debug_options, stream, args, group_mode, op_id, - replica_group_offsets, replica_group_values, source_peers, - target_peers, RunMockCollectivePermute, GetDeviceBufferPairs, - GetStreamId(is_async), gpu_opts->mock_nccl_topo_model()); - } - return P2PImplCommon(run_options, debug_options, stream, args, - group_mode, op_id, no_parallel_custom_call, - replica_group_offsets, replica_group_values, - source_peers, target_peers, RunCollectivePermute, - GetDeviceBufferPairs, GetStreamId(is_async)); - }); -#else // XLA_ENABLE_XCCL - return absl::InternalError("NCCL disabled"); -#endif // XLA_ENABLE_XCCL -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - CollectivePermute, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.collective_permute") - .UserData() - .UserData() - .UserData() - .UserData() - .RemainingArgs() // args - .Attr("uid") - .Attr("group_mode") // CollectiveOpGroupMode - .Attr("op_id") - .Attr("is_async") - .Attr("no_parallel_custom_call") - .Attr>("replica_group_offsets") - .Attr>("replica_group_values") - .Attr>("source_peers") - .Attr>("target_peers")); - -//===----------------------------------------------------------------------===// -// Send. -//===----------------------------------------------------------------------===// - -static absl::Status P2PSendImpl(const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, - CollectivesSupport* collectives, - AsyncCollectivesSupport* async_collectives, - CustomCall::RemainingArgs args, int32_t uid, - int64_t group_mode, int64_t op_id, - bool is_async, bool no_parallel_custom_call, - absl::Span replica_group_offsets, - absl::Span replica_group_values, - absl::Span source_peers, - absl::Span target_peers) { -#if XLA_ENABLE_XCCL - VLOG(3) << "Running Send"; - TF_RET_CHECK(is_async); - // The scheduler guarantee no_parallel_custom_call for P2P chain, which is not - // reflected in the default value for the attribute. - return RunSyncOrAsync( - run_options, collectives, async_collectives, uid, is_async, - [&](se::Stream* stream) { - return P2PImplCommon( - run_options, debug_options, stream, args, group_mode, op_id, - /*no_parallel_custom_call=*/true, replica_group_offsets, - replica_group_values, source_peers, target_peers, RunSend, - GetSingleArgAsDeviceBufferPair, - GetStreamId(is_async, AsyncStreamKind::kP2P)); - }, - AsyncStreamKind::kP2P); -#else // XLA_ENABLE_XCCL - return absl::InternalError("NCCL disabled"); -#endif // XLA_ENABLE_XCCL -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - P2PSend, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.send") - .UserData() - .UserData() - .UserData() - .UserData() - .RemainingArgs() // args - .Attr("uid") - .Attr("group_mode") // CollectiveOpGroupMode - .Attr("op_id") - .Attr("is_async") - .Attr("no_parallel_custom_call") - .Attr>("replica_group_offsets") - .Attr>("replica_group_values") - .Attr>("source_peers") - .Attr>("target_peers")); - -//===----------------------------------------------------------------------===// -// Recv. -//===----------------------------------------------------------------------===// - -static absl::Status P2PRecvImpl(const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, - CollectivesSupport* collectives, - AsyncCollectivesSupport* async_collectives, - CustomCall::RemainingArgs args, int32_t uid, - int64_t group_mode, int64_t op_id, - bool is_async, bool no_parallel_custom_call, - absl::Span replica_group_offsets, - absl::Span replica_group_values, - absl::Span source_peers, - absl::Span target_peers) { -#if XLA_ENABLE_XCCL - VLOG(3) << "Running Recv"; - TF_RET_CHECK(is_async); - // The scheduler guarantee no_parallel_custom_call for P2P chain, which is not - // reflected in the default value for the attribute. - return RunSyncOrAsync( - run_options, collectives, async_collectives, uid, is_async, - [&](se::Stream* stream) { - return P2PImplCommon( - run_options, debug_options, stream, args, group_mode, op_id, - /*no_parallel_custom_call=*/true, replica_group_offsets, - replica_group_values, source_peers, target_peers, RunRecv, - GetSingleArgAsDeviceBufferPair, - GetStreamId(is_async, AsyncStreamKind::kP2P)); - }, - AsyncStreamKind::kP2P); -#else // XLA_ENABLE_XCCL - return absl::InternalError("NCCL disabled"); -#endif // XLA_ENABLE_XCCL -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - P2PRecv, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.recv") - .UserData() - .UserData() - .UserData() - .UserData() - .RemainingArgs() // args - .Attr("uid") - .Attr("group_mode") // CollectiveOpGroupMode - .Attr("op_id") - .Attr("is_async") - .Attr("no_parallel_custom_call") - .Attr>("replica_group_offsets") - .Attr>("replica_group_values") - .Attr>("source_peers") - .Attr>("target_peers")); - -//===----------------------------------------------------------------------===// -// AllGather. -//===----------------------------------------------------------------------===// - -#if XLA_ENABLE_XCCL -absl::Status AllGatherImplCommon( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, se::Stream* stream, - CustomCall::RemainingArgs args, int64_t group_mode, int64_t op_id, - absl::Span replica_group_offsets, - absl::Span replica_group_values, bool is_async, - bool no_parallel_custom_call) { - TF_ASSIGN_OR_RETURN(Thunk::CollectiveExecuteParams params, - Thunk::CollectiveExecuteParams::Create( - *run_options, stream->parent()->device_ordinal())); - bool enable_clique_opt = ShouldEnableCliqueOptimization( - params, debug_options, no_parallel_custom_call); - TF_ASSIGN_OR_RETURN( - auto comm, GetNcclComm(params, group_mode, op_id, replica_group_offsets, - replica_group_values, GetStreamId(is_async), - enable_clique_opt)); - - TF_ASSIGN_OR_RETURN(auto device_buffers, GetDeviceBufferPairs(args)); - - return RunRepeated( - debug_options->xla_gpu_collective_inflation_factor(), [&]() { - return RunAllGather(NcclApi::Default(), device_buffers, *stream, *comm); - }); -} -#endif // XLA_ENABLE_XCCL - -absl::Status AllGatherImpl(const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, - CollectivesSupport* collectives, - AsyncCollectivesSupport* async_collectives, - CustomCall::RemainingArgs args, int32_t uid, - int64_t group_mode, int64_t op_id, bool is_async, - bool no_parallel_custom_call, - absl::Span replica_group_offsets, - absl::Span replica_group_values) { -#if XLA_ENABLE_XCCL - VLOG(3) << "Running AllGather " << (is_async ? "(Async) " : "(Sync) ") - << no_parallel_custom_call; - return RunSyncOrAsync( - run_options, collectives, async_collectives, uid, is_async, - [&](se::Stream* stream) { - const gpu::GpuExecutableRunOptions* gpu_opts = - run_options->run_options().gpu_executable_run_options(); - if (gpu_opts && gpu_opts->enable_mock_nccl_collectives()) { - return MockNcclImplCommon( - run_options, debug_options, stream, args, group_mode, op_id, - replica_group_offsets, replica_group_values, is_async, - Thunk::kNcclAllGather, gpu_opts->mock_nccl_topo_model()); - } - return AllGatherImplCommon(run_options, debug_options, stream, args, - group_mode, op_id, replica_group_offsets, - replica_group_values, is_async, - no_parallel_custom_call); - }); -#else // XLA_ENABLE_XCCL - return absl::InternalError("NCCL diasbled"); -#endif // XLA_ENABLE_XCCL -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - AllGather, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.all_gather") - .UserData() - .UserData() - .UserData() - .UserData() - .RemainingArgs() // args - .Attr("uid") - .Attr("group_mode") // CollectiveOpGroupMode - .Attr("op_id") - .Attr("is_async") - .Attr("no_parallel_custom_call") - .Attr>("replica_group_offsets") - .Attr>("replica_group_values")); - -//===----------------------------------------------------------------------===// -// AllReduce. -//===----------------------------------------------------------------------===// - -#if XLA_ENABLE_XCCL -absl::Status AllReduceImplCommon( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, se::Stream* stream, - CustomCall::RemainingArgs args, int64_t group_mode, int64_t op_id, - int64_t reduction_kind, absl::Span replica_group_offsets, - absl::Span replica_group_values, bool is_async, - bool no_parallel_custom_call) { - TF_ASSIGN_OR_RETURN(Thunk::CollectiveExecuteParams params, - Thunk::CollectiveExecuteParams::Create( - *run_options, stream->parent()->device_ordinal())); - bool enable_clique_opt = ShouldEnableCliqueOptimization( - params, debug_options, no_parallel_custom_call); - - TF_ASSIGN_OR_RETURN( - auto comm, GetNcclComm(params, group_mode, op_id, replica_group_offsets, - replica_group_values, GetStreamId(is_async), - enable_clique_opt)); - - TF_ASSIGN_OR_RETURN(auto device_buffers, GetDeviceBufferPairs(args)); - - return RunRepeated( - debug_options->xla_gpu_collective_inflation_factor(), [&]() { - return RunAllReduce(NcclApi::Default(), - static_cast(reduction_kind), - device_buffers, *stream, *comm); - }); -} -#endif // XLA_ENABLE_XCCL - -absl::Status AllReduceImpl(const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, - CollectivesSupport* collectives, - AsyncCollectivesSupport* async_collectives, - CustomCall::RemainingArgs args, int32_t uid, - int64_t group_mode, int64_t op_id, bool is_async, - bool no_parallel_custom_call, int64_t reduction_kind, - absl::Span replica_group_offsets, - absl::Span replica_group_values) { -#if XLA_ENABLE_XCCL - VLOG(3) << "Running AllReduce " << (is_async ? "(Async) " : "(Sync) ") - << no_parallel_custom_call; - return RunSyncOrAsync( - run_options, collectives, async_collectives, uid, is_async, - [&](se::Stream* stream) { - const gpu::GpuExecutableRunOptions* gpu_opts = - run_options->run_options().gpu_executable_run_options(); - if (gpu_opts && gpu_opts->enable_mock_nccl_collectives()) { - return MockNcclImplCommon( - run_options, debug_options, stream, args, group_mode, op_id, - replica_group_offsets, replica_group_values, is_async, - Thunk::kNcclAllReduce, gpu_opts->mock_nccl_topo_model()); - } - return AllReduceImplCommon(run_options, debug_options, stream, args, - group_mode, op_id, reduction_kind, - replica_group_offsets, replica_group_values, - is_async, no_parallel_custom_call); - }); -#else // XLA_ENABLE_XCCL - // NCCL disabled. - return absl::InternalError("NCCL disabled"); -#endif // XLA_ENABLE_XCCL -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - AllReduce, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.all_reduce") - .UserData() - .UserData() - .UserData() - .UserData() - .RemainingArgs() // args - .Attr("uid") - .Attr("group_mode") // CollectiveOpGroupMode - .Attr("op_id") - .Attr("is_async") - .Attr("no_parallel_custom_call") - .Attr("reduction_kind") // ReductionKind - .Attr>("replica_group_offsets") - .Attr>("replica_group_values")); - -//===----------------------------------------------------------------------===// -// AllToAll. -//===----------------------------------------------------------------------===// - -#if XLA_ENABLE_XCCL -absl::Status MockAllToAllImplCommon( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, se::Stream* stream, - CustomCall::RemainingArgs args, int64_t group_mode, - bool has_split_dimension, int64_t op_id, - absl::Span replica_group_offsets, - absl::Span replica_group_values, bool is_async, - GpuExecutableRunOptions::MockNcclTopoModel topo_model) { - TF_ASSIGN_OR_RETURN(Thunk::CollectiveExecuteParams params, - Thunk::CollectiveExecuteParams::Create( - *run_options, stream->parent()->device_ordinal())); - - auto comm = GetMockNcclComm( - params, group_mode, op_id, replica_group_offsets, replica_group_values, - GetStreamId(is_async), - debug_options->xla_gpu_enable_nccl_clique_optimization(), topo_model); - // Skip mock nccl calls for gpus with non-zero ranks. Only run the nccl mock - // calls for the gpu with rank 0. - // TODO: Remove the check, once the pjrt client supports running benchmark - // with single gpu. - if (absl::IsCancelled(comm.status())) return absl::OkStatus(); - if (!comm.ok()) return comm.status(); - - TF_ASSIGN_OR_RETURN(auto device_buffers, GetDeviceBufferPairs(args)); - - return RunMockNcclAllToAll(NcclApi::Default(), has_split_dimension, - device_buffers, *stream, **comm); -} - -absl::Status AllToAllImplCommon(const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, - se::Stream* stream, - CustomCall::RemainingArgs args, - int64_t group_mode, bool has_split_dimension, - int64_t op_id, - absl::Span replica_group_offsets, - absl::Span replica_group_values, - bool is_async, bool no_parallel_custom_call) { - TF_ASSIGN_OR_RETURN(Thunk::CollectiveExecuteParams params, - Thunk::CollectiveExecuteParams::Create( - *run_options, stream->parent()->device_ordinal())); - bool enable_clique_opt = ShouldEnableCliqueOptimization( - params, debug_options, no_parallel_custom_call); - - TF_ASSIGN_OR_RETURN( - auto comm, GetNcclComm(params, group_mode, op_id, replica_group_offsets, - replica_group_values, GetStreamId(is_async), - enable_clique_opt)); - - TF_ASSIGN_OR_RETURN(auto device_buffers, GetDeviceBufferPairs(args)); - - return RunRepeated( - debug_options->xla_gpu_collective_inflation_factor(), [&]() { - return RunAllToAll(NcclApi::Default(), has_split_dimension, - device_buffers, *stream, *comm); - }); -} -#endif // XLA_ENABLE_XCCL - -absl::Status AllToAllImpl(const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, - CollectivesSupport* collectives, - AsyncCollectivesSupport* async_collectives, - CustomCall::RemainingArgs args, int32_t uid, - int64_t group_mode, bool has_split_dimension, - int64_t op_id, bool is_async, - bool no_parallel_custom_call, - absl::Span replica_group_offsets, - absl::Span replica_group_values) { -#if XLA_ENABLE_XCCL - VLOG(3) << "Running AllToAll " << (is_async ? "(Async) " : "(Sync) ") - << no_parallel_custom_call; - return RunSyncOrAsync( - run_options, collectives, async_collectives, uid, is_async, - [&](se::Stream* stream) { - const gpu::GpuExecutableRunOptions* gpu_opts = - run_options->run_options().gpu_executable_run_options(); - if (gpu_opts && gpu_opts->enable_mock_nccl_collectives()) { - return MockAllToAllImplCommon( - run_options, debug_options, stream, args, group_mode, - has_split_dimension, op_id, replica_group_offsets, - replica_group_values, is_async, gpu_opts->mock_nccl_topo_model()); - } - return AllToAllImplCommon(run_options, debug_options, stream, args, - group_mode, has_split_dimension, op_id, - replica_group_offsets, replica_group_values, - is_async, no_parallel_custom_call); - }); -#else // XLA_ENABLE_XCCL - return absl::InternalError("NCCL disabled"); -#endif // XLA_ENABLE_XCCL -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - AllToAll, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.all_to_all") - .UserData() - .UserData() - .UserData() - .UserData() - .RemainingArgs() // args - .Attr("uid") - .Attr("group_mode") // CollectiveOpGroupMode - .Attr("has_split_dimension") - .Attr("op_id") - .Attr("is_async") - .Attr("no_parallel_custom_call") - .Attr>("replica_group_offsets") - .Attr>("replica_group_values")); - -//===----------------------------------------------------------------------===// -// ReduceScatter. -//===----------------------------------------------------------------------===// - -#if XLA_ENABLE_XCCL -absl::Status ReduceScatterImplCommon( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, se::Stream* stream, - CustomCall::RemainingArgs args, int64_t group_mode, int64_t op_id, - int64_t reduction_kind, absl::Span replica_group_offsets, - absl::Span replica_group_values, bool is_async, - bool no_parallel_custom_call) { - TF_ASSIGN_OR_RETURN(Thunk::CollectiveExecuteParams params, - Thunk::CollectiveExecuteParams::Create( - *run_options, stream->parent()->device_ordinal())); - bool enable_clique_opt = ShouldEnableCliqueOptimization( - params, debug_options, no_parallel_custom_call); - - TF_ASSIGN_OR_RETURN( - auto comm, GetNcclComm(params, group_mode, op_id, replica_group_offsets, - replica_group_values, GetStreamId(is_async), - enable_clique_opt)); - - TF_ASSIGN_OR_RETURN(auto device_buffers, GetDeviceBufferPairs(args)); - - return RunRepeated( - debug_options->xla_gpu_collective_inflation_factor(), [&]() { - return RunReduceScatter(NcclApi::Default(), - static_cast(reduction_kind), - device_buffers, *stream, *comm); - }); -} -#endif // XLA_ENABLE_XCCL - -absl::Status ReduceScatterImpl(const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, - CollectivesSupport* collectives, - AsyncCollectivesSupport* async_collectives, - CustomCall::RemainingArgs args, int32_t uid, - int64_t group_mode, int64_t op_id, bool is_async, - bool no_parallel_custom_call, - int64_t reduction_kind, - absl::Span replica_group_offsets, - absl::Span replica_group_values) { -#if XLA_ENABLE_XCCL - VLOG(3) << "Running ReduceScatter " << (is_async ? "(Async) " : "(Sync) ") - << no_parallel_custom_call; - return RunSyncOrAsync( - run_options, collectives, async_collectives, uid, is_async, - [&](se::Stream* stream) { - const gpu::GpuExecutableRunOptions* gpu_opts = - run_options->run_options().gpu_executable_run_options(); - if (gpu_opts && gpu_opts->enable_mock_nccl_collectives()) { - return MockNcclImplCommon( - run_options, debug_options, stream, args, group_mode, op_id, - replica_group_offsets, replica_group_values, is_async, - Thunk::kNcclReduceScatter, gpu_opts->mock_nccl_topo_model()); - } - return ReduceScatterImplCommon( - run_options, debug_options, stream, args, group_mode, op_id, - reduction_kind, replica_group_offsets, replica_group_values, - is_async, no_parallel_custom_call); - }); -#else // XLA_ENABLE_XCCL - return absl::InternalError("NCCL disabled"); -#endif // XLA_ENABLE_XCCL -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - ReduceScatter, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.reduce_scatter") - .UserData() - .UserData() - .UserData() - .UserData() - .RemainingArgs() // args - .Attr("uid") - .Attr("group_mode") // CollectiveOpGroupMode - .Attr("op_id") - .Attr("is_async") - .Attr("no_parallel_custom_call") - .Attr("reduction_kind") // ReductionKind - .Attr>("replica_group_offsets") - .Attr>("replica_group_values")); - -//===----------------------------------------------------------------------===// -// AsyncDone. -//===----------------------------------------------------------------------===// - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - AsyncDone, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.async_collective_done") - .UserData() - .UserData() - .Attr("uid") - .Attr("done_type")); - -//===----------------------------------------------------------------------===// -// ReplicaId. -//===----------------------------------------------------------------------===// - -absl::Status ReplicaPartitionIdImpl( - const ServiceExecutableRunOptions* run_options, FlatMemrefView result, - bool is_replica_id) { - VLOG(3) << "Running " << (is_replica_id ? "ReplicaId" : "PartitionId"); - se::Stream* stream = run_options->stream(); - TF_ASSIGN_OR_RETURN(Thunk::CollectiveExecuteParams params, - Thunk::CollectiveExecuteParams::Create( - *run_options, stream->parent()->device_ordinal())); - - GlobalDeviceId global_device_id = params.global_device_id; - - TF_ASSIGN_OR_RETURN(DeviceAssignment::LogicalID logical_id, - params.device_assn->LogicalIdForDevice(global_device_id)); - - se::DeviceMemoryBase result_data = GetDeviceAddress(result); - const uint32_t id = - is_replica_id ? logical_id.replica_id : logical_id.computation_id; - stream->ThenMemset32(&result_data, id, /*size=*/4); - return absl::OkStatus(); -} - -absl::Status ReplicaIdImpl(const ServiceExecutableRunOptions* run_options, - FlatMemrefView result) { - return ReplicaPartitionIdImpl(run_options, result, /*is_replica_id=*/true); -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - ReplicaId, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.replica_id") - .UserData() - .Arg()); - -//===----------------------------------------------------------------------===// -// PartitionId. -//===----------------------------------------------------------------------===// - -absl::Status PartitionIdImpl(const ServiceExecutableRunOptions* run_options, - FlatMemrefView result) { - return ReplicaPartitionIdImpl(run_options, result, /*is_replica_id=*/false); -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - PartitionId, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.partition_id") - .UserData() - .Arg()); - -//===----------------------------------------------------------------------===// - -int64_t Key(int32_t uid, int32_t device_ordinal) { - return static_cast(uid) << 32 | device_ordinal; -} - -} // namespace - -//===----------------------------------------------------------------------===// -// Collectives support library. -//===----------------------------------------------------------------------===// - -absl::Status CollectivesSupport::MaybeBlockAfterFirstRun(int32_t uid, - int32_t device_ordinal, - se::Stream* stream) { - bool block = [&] { - absl::MutexLock lock(&mutex_); - return executed_.insert(Key(uid, device_ordinal)).second; - }(); - return block ? stream->BlockHostUntilDone() : absl::OkStatus(); -} - -AsyncCollectivesSupport::AsyncCollectivesSupport( - absl::Span async_streams) - : async_comm_streams_(async_streams.begin(), async_streams.end()) {} - -absl::Status AsyncCollectivesSupport::RecordEvent( - int32_t uid, gpu::AsyncStreamKind async_stream_kind) { - // Create an event on the async stream for the completion of the collective. - se::Event done_event(async_comm_stream(async_stream_kind)->parent()); - if (!done_event.Init()) return absl::InternalError("Failed to create event"); - async_comm_stream(async_stream_kind)->ThenRecordEvent(&done_event); - - absl::MutexLock lock(&mutex_); - auto [_, was_inserted] = done_events_.insert({uid, std::move(done_event)}); - if (!was_inserted) { - return absl::InternalError(absl::StrFormat( - "Async done event has not been consumed (uid=%d, device_ordinal=%d)", - uid, async_comm_stream(async_stream_kind)->parent()->device_ordinal())); - } - return absl::OkStatus(); -} - -absl::StatusOr AsyncCollectivesSupport::PopEvent(int32_t uid) { - absl::MutexLock lock(&mutex_); - auto done_event = done_events_.extract(uid); - if (!done_event) { - return absl::InternalError( - absl::StrFormat("Async done event was not found (uid=%d)", uid)); - } - return std::move(done_event.mapped()); -} - -void RegisterCollectiveCustomCalls( - runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.collective_permute", CollectivePermute); - registry.Register("xla.gpu.send", P2PSend); - registry.Register("xla.gpu.recv", P2PRecv); - registry.Register("xla.gpu.all_gather", AllGather); - registry.Register("xla.gpu.all_reduce", AllReduce); - registry.Register("xla.gpu.all_to_all", AllToAll); - registry.Register("xla.gpu.reduce_scatter", ReduceScatter); - - registry.Register("xla.gpu.collective_done", AsyncDone); - - registry.Register("xla.gpu.partition_id", PartitionId); - registry.Register("xla.gpu.replica_id", ReplicaId); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/collectives.h b/third_party/xla/xla/service/gpu/runtime/collectives.h deleted file mode 100644 index d6bda3d16392ab..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/collectives.h +++ /dev/null @@ -1,79 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_COLLECTIVES_H_ -#define XLA_SERVICE_GPU_RUNTIME_COLLECTIVES_H_ - -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/service/gpu/nccl_collective_thunk.h" -#include "xla/stream_executor/event.h" - -namespace xla { -namespace gpu { - -// Support for running async collective operations communicating via events. -// Registers XLA Gpu runtime collective custom calls. -void RegisterCollectiveCustomCalls(runtime::DirectCustomCallRegistry& registry); - -class CollectivesSupport { - public: - // Maybe block host after the first call to the collective operation with the - // given uid, to ensure that all devices have allocated the required buffers - // for their communicators before allowing any device to continue enqueuing - // operations. Otherwise, the allocations can cause deadlock in the CUDA - // driver. - // - // This basically ports workaround from cr/435058849 to Xla runtime (see - // details in the b/215649390). - absl::Status MaybeBlockAfterFirstRun(int32_t uid, int32_t device_ordinal, - se::Stream* stream); - - private: - absl::Mutex mutex_; - - // Store if a particular collective operation was executed at least once. We - // rely on unique `uid` assigned to each collective operation by the lowering - // pass. - absl::flat_hash_set executed_ ABSL_GUARDED_BY(mutex_); -}; - -// Support for running async collective operations communicating via events. -class AsyncCollectivesSupport { - public: - explicit AsyncCollectivesSupport(absl::Span async_streams); - - absl::Status RecordEvent(int32_t uid, AsyncStreamKind async_stream_kind); - absl::StatusOr PopEvent(int32_t uid); - - se::Stream* async_comm_stream(AsyncStreamKind async_stream_kind) const { - return async_comm_streams_[static_cast(async_stream_kind)]; - } - - private: - absl::Mutex mutex_; - absl::InlinedVector async_comm_streams_; - - // Store done events for the Done ops to wait upon. - absl::flat_hash_map done_events_ ABSL_GUARDED_BY(mutex_); -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_COLLECTIVES_H_ diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_allocations.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_allocations.cc similarity index 97% rename from third_party/xla/xla/service/gpu/runtime3/command_buffer_allocations.cc rename to third_party/xla/xla/service/gpu/runtime/command_buffer_allocations.cc index 83ee253dc4d1ad..a8ac270ff8d2e2 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_allocations.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_allocations.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime3/command_buffer_allocations.h" +#include "xla/service/gpu/runtime/command_buffer_allocations.h" #include diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_allocations.h b/third_party/xla/xla/service/gpu/runtime/command_buffer_allocations.h similarity index 90% rename from third_party/xla/xla/service/gpu/runtime3/command_buffer_allocations.h rename to third_party/xla/xla/service/gpu/runtime/command_buffer_allocations.h index 310a8c0bfc0325..3440852b581d92 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_allocations.h +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_allocations.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_ALLOCATIONS_H_ -#define XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_ALLOCATIONS_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_COMMAND_BUFFER_ALLOCATIONS_H_ +#define XLA_SERVICE_GPU_RUNTIME_COMMAND_BUFFER_ALLOCATIONS_H_ #include "absl/container/flat_hash_map.h" #include "xla/service/buffer_assignment.h" @@ -48,4 +48,4 @@ class CommandBufferAllocations : public BufferAllocations::ExternalAllocations { } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_ALLOCATIONS_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_COMMAND_BUFFER_ALLOCATIONS_H_ diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc similarity index 72% rename from third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc rename to third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc index a27076dbbd763b..86bffba4968525 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime3/command_buffer_cmd.h" +#include "xla/service/gpu/runtime/command_buffer_cmd.h" #include #include @@ -27,8 +27,10 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/base/optimization.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" +#include "absl/functional/function_ref.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" @@ -42,12 +44,13 @@ limitations under the License. #include "xla/service/gpu/kernels/custom_kernel.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/matmul_utils.h" -#include "xla/service/gpu/nccl_all_gather_thunk.h" -#include "xla/service/gpu/nccl_all_reduce_thunk.h" #include "xla/service/gpu/nccl_api.h" #include "xla/service/gpu/nccl_clique.h" #include "xla/service/gpu/nccl_clique_key.h" #include "xla/service/gpu/nccl_collective_thunk.h" +#include "xla/service/gpu/runtime/annotation.h" +#include "xla/service/gpu/runtime/nccl_all_gather_thunk.h" +#include "xla/service/gpu/runtime/nccl_all_reduce_thunk.h" #include "xla/service/gpu/stream_executor_util.h" #include "xla/service/gpu/thunk.h" #include "xla/stream_executor/command_buffer.h" @@ -56,10 +59,10 @@ limitations under the License. #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/stream_executor.h" #include "xla/types.h" // IWYU pragma: keep -#include "xla/util.h" #include "tsl/concurrency/ref_count.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM @@ -88,10 +91,10 @@ static std::string_view ReductionKindString(ReductionKind kind) { // Creates condition command buffer builder from a cmd sequence. static se::CommandBuffer::Builder ConditionBuilder( - CommandBufferCmdSequence* commands, - const CommandBufferCmd::RecordParams* params) { + CommandBufferCmdSequence* commands, const Thunk::ExecuteParams* params, + CommandBufferCmd::StateManager* state) { return [=](se::CommandBuffer* command_buffer) { - return commands->Record(*params, command_buffer, + return commands->Record(*params, *state, command_buffer, CommandBufferCmdSequence::RecordMode::kConditional); }; } @@ -99,14 +102,35 @@ static se::CommandBuffer::Builder ConditionBuilder( // Creates condition command buffer builders from a span of cmd sequences. static std::vector ConditionBuilders( absl::Span commands, - const CommandBufferCmd::RecordParams* params) { + const Thunk::ExecuteParams* params, CommandBufferCmd::StateManager* state) { std::vector builders; for (CommandBufferCmdSequence& cmd : commands) { - builders.push_back(ConditionBuilder(&cmd, params)); + builders.push_back(ConditionBuilder(&cmd, params, state)); } return builders; } +//===----------------------------------------------------------------------===// +// CommandBufferCmd +//===----------------------------------------------------------------------===// + +CommandBufferCmd::State* CommandBufferCmd::StateManager::GetOrNull( + const CommandBufferCmd* cmd) { + if (auto it = state_.find(cmd); it != state_.end()) { + return it->second.get(); + } + return nullptr; +} + +CommandBufferCmd::State* CommandBufferCmd::StateManager::GetOrCreate( + const CommandBufferCmd* cmd, + absl::FunctionRef()> create) { + if (auto it = state_.find(cmd); it != state_.end()) { + return it->second.get(); + } + return state_.try_emplace(cmd, create()).first->second.get(); +} + //===----------------------------------------------------------------------===// // CommandBufferCmdSequence //===----------------------------------------------------------------------===// @@ -137,7 +161,7 @@ void CommandBufferCmdSequence::Append(std::unique_ptr cmd) { if (requires_barrier) ClearTrackedBuffers(); - commands_.emplace_back(std::move(cmd), requires_barrier); + commands_.push_back({std::move(cmd), requires_barrier}); TrackBuffers(buffers); } @@ -151,9 +175,10 @@ absl::Status CommandBufferCmdSequence::Prepare( } absl::Status CommandBufferCmdSequence::Initialize( - se::StreamExecutor* executor, Thunk::Thunk::ExecutableSource source) { + const Thunk::InitializeParams& params, + CommandBufferCmd::StateManager& state) { for (auto& command : commands_) { - TF_RETURN_IF_ERROR(command.cmd->Initialize(executor, source)); + TF_RETURN_IF_ERROR(command.cmd->Initialize(params, state)); } return absl::OkStatus(); } @@ -207,7 +232,7 @@ static std::string_view RecordModeString( } absl::Status CommandBufferCmdSequence::Record( - const CommandBufferCmd::RecordParams& params, + const Thunk::ExecuteParams& params, CommandBufferCmd::StateManager& state, se::CommandBuffer* command_buffer, RecordMode mode) { VLOG(3) << "Record " << commands_.size() << " commands into command buffer" << "; mode=" << RecordModeString(mode); @@ -222,15 +247,17 @@ absl::Status CommandBufferCmdSequence::Record( // Track the number of commands recorded between barriers. int64_t num_recorded_commands = 0; + const ModuleAnnotations* annotations = GetCurrentModuleAnnotations(); for (auto& command : commands_) { if (command.requires_barrier) { VLOG(3) << "Add command buffer barrier after " << num_recorded_commands << " recorded commands"; - TF_RETURN_IF_ERROR(command_buffer->Barrier(params.executor)); + TF_RETURN_IF_ERROR(command_buffer->Barrier(params.stream->parent())); num_recorded_commands = 0; } - - TF_RETURN_IF_ERROR(command.cmd->Record(params, command_buffer)); + auto annotation = + GetKernelAnnotation(annotations, command.cmd->profile_annotation()); + TF_RETURN_IF_ERROR(command.cmd->Record(params, state, command_buffer)); ++num_recorded_commands; } @@ -262,6 +289,91 @@ std::vector CommandBufferCmdSequence::barriers() const { [](auto& command) { return command.requires_barrier; }); return barriers; } + +//===----------------------------------------------------------------------===// +// TracedCommandBuffer +//===----------------------------------------------------------------------===// + +TracedCommandBuffer::TracedCommandBuffer( + CommandBufferCmd::BufferUsageVector buffers, int64_t capacity) + : capacity_(capacity), entries_(capacity) { + CHECK_GT(capacity, 0) << "capacity must be larger than 0"; // NOLINT + // Collect unique buffer allocation indices in a set first and convert to + // vector as flat hash set iteration has measurable overheads. + absl::flat_hash_set allocs_indices; + for (auto& buffer : buffers) allocs_indices.insert(buffer.slice.index()); + allocs_indices_.assign(allocs_indices.begin(), allocs_indices.end()); +} + +absl::StatusOr TracedCommandBuffer::GetOrTraceCommandBuffer( + const BufferAllocations* buffer_allocation, se::StreamExecutor* executor, + se::Stream* stream, absl::FunctionRef trace) { + // Collect memory addresses for relevant allocations. + absl::InlinedVector allocs; + allocs.reserve(allocs_indices_.size()); + for (auto& index : allocs_indices_) { + allocs.emplace_back(buffer_allocation->GetDeviceAddress(index)); + } + + // Moves entry at `i` position to front and moves entries in `[0, i)` range + // one element to the right. Returns reference to the first entry. + auto shift_right = [&](size_t i) -> Entry& { + if (i == 0) return entries_[0]; + + Entry entry = std::move(entries_[i]); + do { + entries_[i] = std::move(entries_[i - 1]); + } while (--i > 0); + + return entries_[0] = std::move(entry); + }; + + for (size_t i = 0; i < capacity_; ++i) { + // Found entry for a given allocations, move it to front and return a + // pointer to cached command buffer. + if (ABSL_PREDICT_TRUE(absl::c_equal(entries_[i].recorded_allocs, allocs) && + entries_[i].command_buffer)) { + return shift_right(i).command_buffer.get(); + } + + // Create a new entry by calling a user-provided tracing function, move it + // to front and return a pointer to cached command buffer. + if (entries_[i].command_buffer == nullptr) { + TF_ASSIGN_OR_RETURN(entries_[i].command_buffer, + se::CommandBuffer::Trace(executor, stream, trace)); + entries_[i].recorded_allocs.assign(allocs.begin(), allocs.end()); + return shift_right(i).command_buffer.get(); + } + } + + // Create a new entry by calling a user-provided tracing function, replace the + // last entry with it, move it to front and return a pointer to cached command + // buffer. + TF_ASSIGN_OR_RETURN(entries_[capacity_ - 1].command_buffer, + se::CommandBuffer::Trace(executor, stream, trace)); + entries_[capacity_ - 1].recorded_allocs.assign(allocs.begin(), allocs.end()); + return shift_right(capacity_ - 1).command_buffer.get(); +} + +//===----------------------------------------------------------------------===// +// TracedCommandBufferCmd +//===----------------------------------------------------------------------===// + +absl::Status TracedCommandBufferCmd::AddTracedCommandBuffer( + const Thunk::ExecuteParams& params, StateManager& state, + se::CommandBuffer* command_buffer, + absl::FunctionRef trace) { + auto traced_cmd = state.GetOrCreate( + this, [&] { return std::make_unique(buffers()); }); + + TF_ASSIGN_OR_RETURN(auto nested_cmd, + traced_cmd->GetOrTraceCommandBuffer( + params.buffer_allocations, params.stream->parent(), + params.command_buffer_trace_stream, trace)); + + return command_buffer->AddNestedCommandBuffer(*nested_cmd); +} + //===----------------------------------------------------------------------===// // ComputationId //===----------------------------------------------------------------------===// @@ -280,7 +392,7 @@ std::vector CommandBufferCmdSequence::barriers() const { // // Easiest way to get PTX from C++ is to use https://godbolt.org. inline constexpr std::string_view kMemset32Kernel = R"( -.version 8.0 +.version 4.0 .target sm_50 .address_size 64 @@ -328,24 +440,25 @@ CommandBufferCmd::BufferUsageVector ComputationIdCmd::buffers() { return {{dest_, MemoryAccess::kWrite}}; } -absl::Status ComputationIdCmd::Initialize(se::StreamExecutor* executor, - Thunk::ExecutableSource source) { +absl::Status ComputationIdCmd::Initialize(const Thunk::InitializeParams& params, + StateManager& state) { { absl::MutexLock lock(&mutex_); - if (memset_kernels_.contains(executor)) return absl::OkStatus(); + if (memset_kernels_.contains(params.executor)) return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN( - std::unique_ptr kernel, - CreateKernel("memset32", 3, kMemset32Kernel, /*cubin_data=*/{}, executor, - /*shared_mem_bytes=*/0)); + TF_ASSIGN_OR_RETURN(std::unique_ptr kernel, + CreateKernel("memset32", 3, kMemset32Kernel, + /*cubin_data=*/{}, params.executor, + /*shared_mem_bytes=*/0)); absl::MutexLock lock(&mutex_); - memset_kernels_.emplace(executor, std::move(kernel)); + memset_kernels_.emplace(params.executor, std::move(kernel)); return absl::OkStatus(); } -absl::Status ComputationIdCmd::Record(const RecordParams& params, +absl::Status ComputationIdCmd::Record(const Thunk::ExecuteParams& params, + StateManager& state, se::CommandBuffer* command_buffer) { se::DeviceMemoryBase dst = params.buffer_allocations->GetDeviceAddress(dest_); @@ -364,7 +477,7 @@ absl::Status ComputationIdCmd::Record(const RecordParams& params, se::Kernel* memset_kernel = [&] { absl::MutexLock lock(&mutex_); - return memset_kernels_[params.executor].get(); + return memset_kernels_[params.stream->parent()].get(); }(); if (memset_kernel == nullptr) { @@ -372,13 +485,9 @@ absl::Status ComputationIdCmd::Record(const RecordParams& params, "Memset kernel not loaded on a command buffer executor"); } - auto* memset32 = static_cast< - se::TypedKernel>*>( - memset_kernel); - - return command_buffer->Launch(*memset32, se::ThreadDim(1), se::BlockDim(1), - /*n=*/int64_t{1}, value, - se::DeviceMemory(dst)); + auto args = se::PackKernelArgs(/*shmem_bytes=*/0, int64_t{1}, value, dst); + return command_buffer->Launch(se::ThreadDim(1), se::BlockDim(1), + *memset_kernel, *args); } //===----------------------------------------------------------------------===// @@ -395,30 +504,32 @@ LaunchCmd::LaunchCmd(std::string kernel_name, dims_(dims), shmem_bytes_(shmem_bytes) {} -absl::Status LaunchCmd::Initialize(se::StreamExecutor* executor, - Thunk::ExecutableSource source) { +absl::Status LaunchCmd::Initialize(const Thunk::InitializeParams& params, + StateManager& state) { { absl::MutexLock lock(&mutex_); - if (kernels_.contains(executor)) return absl::OkStatus(); + if (kernels_.contains(params.executor)) return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN(std::unique_ptr kernel, - CreateKernel(kernel_name_, args_.size(), source.text, - source.binary, executor, shmem_bytes_)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr kernel, + CreateKernel(kernel_name_, args_.size(), params.src.text, + params.src.binary, params.executor, shmem_bytes_)); absl::MutexLock lock(&mutex_); - kernels_.emplace(executor, std::move(kernel)); + kernels_.emplace(params.executor, std::move(kernel)); return absl::OkStatus(); } -absl::Status LaunchCmd::Record(const RecordParams& params, +absl::Status LaunchCmd::Record(const Thunk::ExecuteParams& params, + StateManager& state, se::CommandBuffer* command_buffer) { VLOG(5) << "LaunchCmd: kernel=" << kernel_name_ << ", shmem_bytes=" << shmem_bytes_; se::Kernel* kernel = [&] { absl::MutexLock lock(&mutex_); - return kernels_[params.executor].get(); + return kernels_[params.stream->parent()].get(); }(); if (kernel == nullptr) { @@ -459,29 +570,30 @@ CustomKernelLaunchCmd::CustomKernelLaunchCmd( args_access_(args_access.begin(), args_access.end()), custom_kernel_(std::move(custom_kernel)) {} -absl::Status CustomKernelLaunchCmd::Initialize(se::StreamExecutor* executor, - Thunk::ExecutableSource source) { +absl::Status CustomKernelLaunchCmd::Initialize( + const Thunk::InitializeParams& params, StateManager& state) { { absl::MutexLock lock(&mutex_); - if (kernels_.contains(executor)) return absl::OkStatus(); + if (kernels_.contains(params.executor)) return absl::OkStatus(); } - auto kernel = std::make_unique(executor); - TF_RETURN_IF_ERROR( - executor->GetKernel(custom_kernel_.kernel_spec(), kernel.get())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr kernel, + se::Kernel::Create(params.executor, custom_kernel_.kernel_spec())); absl::MutexLock lock(&mutex_); - kernels_.emplace(executor, std::move(kernel)); + kernels_.emplace(params.executor, std::move(kernel)); return absl::OkStatus(); } -absl::Status CustomKernelLaunchCmd::Record(const RecordParams& params, +absl::Status CustomKernelLaunchCmd::Record(const Thunk::ExecuteParams& params, + StateManager& state, se::CommandBuffer* command_buffer) { VLOG(5) << "CustomKernelLaunchCmd: custom_kernel=" << custom_kernel_.name(); se::Kernel* kernel = [&] { absl::MutexLock lock(&mutex_); - return kernels_[params.executor].get(); + return kernels_[params.stream->parent()].get(); }(); if (kernel == nullptr) { @@ -523,7 +635,8 @@ MemcpyDeviceToDeviceCmd::MemcpyDeviceToDeviceCmd(BufferAllocation::Slice dst, : dst_(dst), src_(src), num_bytes_(num_bytes) {} absl::Status MemcpyDeviceToDeviceCmd::Record( - const RecordParams& params, se::CommandBuffer* command_buffer) { + const Thunk::ExecuteParams& params, StateManager& state, + se::CommandBuffer* command_buffer) { se::DeviceMemoryBase dst = params.buffer_allocations->GetDeviceAddress(dst_); se::DeviceMemoryBase src = params.buffer_allocations->GetDeviceAddress(src_); @@ -549,7 +662,8 @@ CommandBufferCmd::BufferUsageVector MemcpyDeviceToDeviceCmd::buffers() { MemzeroCmd::MemzeroCmd(BufferAllocation::Slice dst) : dst_(dst) {} -absl::Status MemzeroCmd::Record(const RecordParams& params, +absl::Status MemzeroCmd::Record(const Thunk::ExecuteParams& params, + StateManager& state, se::CommandBuffer* command_buffer) { se::DeviceMemoryBase dst = params.buffer_allocations->GetDeviceAddress(dst_); @@ -575,7 +689,8 @@ CommandBufferCmd::BufferUsageVector MemzeroCmd::buffers() { Memset32Cmd::Memset32Cmd(BufferAllocation::Slice dst, uint32_t bit_pattern) : dst_(dst), bit_pattern_(bit_pattern) {} -absl::Status Memset32Cmd::Record(const RecordParams& params, +absl::Status Memset32Cmd::Record(const Thunk::ExecuteParams& params, + StateManager& state, se::CommandBuffer* command_buffer) { se::DeviceMemoryBase dst = params.buffer_allocations->GetDeviceAddress(dst_); @@ -603,18 +718,20 @@ IfCmd::IfCmd(BufferAllocation::Slice pred, CommandBufferCmdSequence then_commands) : pred_(pred), then_commands_(std::move(then_commands)) {} -absl::Status IfCmd::Initialize(se::StreamExecutor* executor, - Thunk::ExecutableSource source) { - return then_commands_.Initialize(executor, source); +absl::Status IfCmd::Initialize(const Thunk::InitializeParams& params, + StateManager& state) { + return then_commands_.Initialize(params, state); } -absl::Status IfCmd::Record(const RecordParams& params, +absl::Status IfCmd::Record(const Thunk::ExecuteParams& params, + StateManager& state, se::CommandBuffer* command_buffer) { se::DeviceMemoryBase pred = params.buffer_allocations->GetDeviceAddress(pred_); - return command_buffer->If(params.executor, se::DeviceMemory(pred), - ConditionBuilder(&then_commands_, ¶ms)); + return command_buffer->If(params.stream->parent(), + se::DeviceMemory(pred), + ConditionBuilder(&then_commands_, ¶ms, &state)); } CommandBufferCmd::BufferUsageVector IfCmd::buffers() { @@ -636,21 +753,23 @@ IfElseCmd::IfElseCmd(BufferAllocation::Slice pred, then_commands_(std::move(then_commands)), else_commands_(std::move(else_commands)) {} -absl::Status IfElseCmd::Initialize(se::StreamExecutor* executor, - Thunk::ExecutableSource source) { - TF_RETURN_IF_ERROR(then_commands_.Initialize(executor, source)); - TF_RETURN_IF_ERROR(else_commands_.Initialize(executor, source)); +absl::Status IfElseCmd::Initialize(const Thunk::InitializeParams& params, + StateManager& state) { + TF_RETURN_IF_ERROR(then_commands_.Initialize(params, state)); + TF_RETURN_IF_ERROR(else_commands_.Initialize(params, state)); return absl::OkStatus(); } -absl::Status IfElseCmd::Record(const RecordParams& params, +absl::Status IfElseCmd::Record(const Thunk::ExecuteParams& params, + StateManager& state, se::CommandBuffer* command_buffer) { se::DeviceMemoryBase pred = params.buffer_allocations->GetDeviceAddress(pred_); - return command_buffer->IfElse(params.executor, se::DeviceMemory(pred), - ConditionBuilder(&then_commands_, ¶ms), - ConditionBuilder(&else_commands_, ¶ms)); + return command_buffer->IfElse( + params.stream->parent(), se::DeviceMemory(pred), + ConditionBuilder(&then_commands_, ¶ms, &state), + ConditionBuilder(&else_commands_, ¶ms, &state)); } CommandBufferCmd::BufferUsageVector IfElseCmd::buffers() { @@ -671,22 +790,23 @@ CaseCmd::CaseCmd(BufferAllocation::Slice index, std::vector branches_commands) : index_(index), branches_commands_(std::move(branches_commands)) {} -absl::Status CaseCmd::Initialize(se::StreamExecutor* executor, - Thunk::ExecutableSource source) { +absl::Status CaseCmd::Initialize(const Thunk::InitializeParams& params, + StateManager& state) { for (auto& branch : branches_commands_) { - TF_RETURN_IF_ERROR(branch.Initialize(executor, source)); + TF_RETURN_IF_ERROR(branch.Initialize(params, state)); } return absl::OkStatus(); } -absl::Status CaseCmd::Record(const RecordParams& params, +absl::Status CaseCmd::Record(const Thunk::ExecuteParams& params, + StateManager& state, se::CommandBuffer* command_buffer) { se::DeviceMemoryBase index = params.buffer_allocations->GetDeviceAddress(index_); return command_buffer->Case( - params.executor, se::DeviceMemory(index), - ConditionBuilders(absl::MakeSpan(branches_commands_), ¶ms)); + params.stream->parent(), se::DeviceMemory(index), + ConditionBuilders(absl::MakeSpan(branches_commands_), ¶ms, &state)); } CommandBufferCmd::BufferUsageVector CaseCmd::buffers() { @@ -708,12 +828,13 @@ ForCmd::ForCmd(int32_t num_iterations, BufferAllocation::Slice loop_counter, loop_counter_(loop_counter), body_commands_(std::move(body_commands)) {} -absl::Status ForCmd::Initialize(se::StreamExecutor* executor, - Thunk::ExecutableSource source) { - return body_commands_.Initialize(executor, source); +absl::Status ForCmd::Initialize(const Thunk::InitializeParams& params, + StateManager& state) { + return body_commands_.Initialize(params, state); } -absl::Status ForCmd::Record(const RecordParams& params, +absl::Status ForCmd::Record(const Thunk::ExecuteParams& params, + StateManager& state, se::CommandBuffer* command_buffer) { se::DeviceMemoryBase loop_counter = params.buffer_allocations->GetDeviceAddress(loop_counter_); @@ -723,9 +844,10 @@ absl::Status ForCmd::Record(const RecordParams& params, VLOG(5) << " loop_counter: " << loop_counter_ << " (" << loop_counter.opaque() << ")"; - return command_buffer->For(params.executor, num_iterations_, - se::DeviceMemory(loop_counter), - ConditionBuilder(&body_commands_, ¶ms)); + return command_buffer->For( + params.stream->parent(), num_iterations_, + se::DeviceMemory(loop_counter), + ConditionBuilder(&body_commands_, ¶ms, &state)); } CommandBufferCmd::BufferUsageVector ForCmd::buffers() { @@ -747,13 +869,14 @@ WhileCmd::WhileCmd(BufferAllocation::Slice pred, cond_commands_(std::move(cond_commands)), body_commands_(std::move(body_commands)) {} -absl::Status WhileCmd::Initialize(se::StreamExecutor* executor, - Thunk::ExecutableSource source) { - TF_RETURN_IF_ERROR(cond_commands_.Initialize(executor, source)); - return body_commands_.Initialize(executor, source); +absl::Status WhileCmd::Initialize(const Thunk::InitializeParams& params, + StateManager& state) { + TF_RETURN_IF_ERROR(cond_commands_.Initialize(params, state)); + return body_commands_.Initialize(params, state); } -absl::Status WhileCmd::Record(const RecordParams& params, +absl::Status WhileCmd::Record(const Thunk::ExecuteParams& params, + StateManager& state, se::CommandBuffer* command_buffer) { se::DeviceMemoryBase pred = params.buffer_allocations->GetDeviceAddress(pred_); @@ -762,9 +885,10 @@ absl::Status WhileCmd::Record(const RecordParams& params, << " body_commands=" << body_commands_.size(); VLOG(5) << " pred: " << pred_ << " (" << pred.opaque() << ")"; - return command_buffer->While(params.executor, se::DeviceMemory(pred), - ConditionBuilder(&cond_commands_, ¶ms), - ConditionBuilder(&body_commands_, ¶ms)); + return command_buffer->While( + params.stream->parent(), se::DeviceMemory(pred), + ConditionBuilder(&cond_commands_, ¶ms, &state), + ConditionBuilder(&body_commands_, ¶ms, &state)); } CommandBufferCmd::BufferUsageVector WhileCmd::buffers() { @@ -784,7 +908,8 @@ CommandBufferCmd::BufferUsageVector WhileCmd::buffers() { AllocateCmd::AllocateCmd(BufferAllocation allocation) : allocation_(allocation) {} -absl::Status AllocateCmd::Record(const RecordParams& params, +absl::Status AllocateCmd::Record(const Thunk::ExecuteParams& params, + StateManager& state, se::CommandBuffer* command_buffer) { // Memory allocation address is returned on graph creation, and there is no // update operation @@ -804,7 +929,8 @@ CommandBufferCmd::BufferUsageVector AllocateCmd::buffers() { return {}; } FreeCmd::FreeCmd(BufferAllocation allocation) : allocation_(allocation) {} -absl::Status FreeCmd::Record(const RecordParams& params, +absl::Status FreeCmd::Record(const Thunk::ExecuteParams& params, + StateManager& state, se::CommandBuffer* command_buffer) { VLOG(2) << "FreeCmd: index=" << allocation_.index(); @@ -836,15 +962,16 @@ GemmCmd::GemmCmd(GemmConfig config, const BufferAllocation::Slice& lhs_buffer, workspace_(workspace), deterministic_(deterministic) {} -absl::Status GemmCmd::Initialize(se::StreamExecutor* executor, - Thunk::ExecutableSource source) { - if (!executor->AsBlas()) { +absl::Status GemmCmd::Initialize(const Thunk::InitializeParams& params, + StateManager& state) { + if (!params.stream->parent()->AsBlas()) { return absl::InternalError("Failed to initialize BLAS support for GemmCmd"); } return absl::OkStatus(); } -absl::Status GemmCmd::Record(const RecordParams& params, +absl::Status GemmCmd::Record(const Thunk::ExecuteParams& params, + StateManager& state, se::CommandBuffer* command_buffer) { se::DeviceMemoryBase lhs = params.buffer_allocations->GetDeviceAddress(lhs_buffer_); @@ -861,15 +988,11 @@ absl::Status GemmCmd::Record(const RecordParams& params, VLOG(5) << " Out: " << output_buffer_ << " (" << out.opaque() << ")"; VLOG(5) << " Workspace: " << workspace_ << " (" << workspace.opaque() << ")"; - TF_ASSIGN_OR_RETURN( - auto nested_buffer, - se::CommandBuffer::Trace( - params.executor, params.trace_stream, [&](se::Stream* stream) { - return RunGemm(config_, lhs, rhs, out, workspace, deterministic_, - stream); - })); - - return command_buffer->AddNestedCommandBuffer(nested_buffer); + return AddTracedCommandBuffer( + params, state, command_buffer, [&](se::Stream* stream) { + return RunGemm(config_, lhs, rhs, out, workspace, deterministic_, + stream); + }); } CommandBufferCmd::BufferUsageVector GemmCmd::buffers() { @@ -883,8 +1006,9 @@ CommandBufferCmd::BufferUsageVector GemmCmd::buffers() { // CustomCallCmd //===----------------------------------------------------------------------===// -Status CustomCallCmd::Record(const RecordParams& params, - se::CommandBuffer* command_buffer) { +absl::Status CustomCallCmd::Record(const Thunk::ExecuteParams& params, + StateManager& state, + se::CommandBuffer* command_buffer) { std::vector buffers; buffers.reserve(operands_.size() + results_.size()); for (auto& slices : {operands_, results_}) { @@ -904,31 +1028,30 @@ Status CustomCallCmd::Record(const RecordParams& params, } } - if (VLOG_IS_ON(5)) { - VLOG(5) << "CustomCallCmd: "; - for (int i = 0; i < operands_.size(); ++i) { - if (operands_[i].has_value()) { - VLOG(5) << " Operand " << i << ": " << operands_[i]->slice << " (" - << buffers[i] << ")"; - } else { - VLOG(5) << " Operand " << i << ": null"; - } + VLOG(5) << "CustomCallCmd: "; + for (int i = 0; i < operands_.size(); ++i) { + if (operands_[i].has_value()) { + VLOG(5) << " Operand " << i << ": " << operands_[i]->slice << " (" + << buffers[i] << ")"; + } else { + VLOG(5) << " Operand " << i << ": null"; } - for (int i = 0; i < results_.size(); ++i) { - if (results_[i].has_value()) { - VLOG(5) << " Result " << i << ": " << results_[i]->slice << " (" - << buffers[operands_.size() + i] << ")"; - } else { - VLOG(5) << " Result " << i << ": null"; - } + } + for (int i = 0; i < results_.size(); ++i) { + if (results_[i].has_value()) { + VLOG(5) << " Result " << i << ": " << results_[i]->slice << " (" + << buffers[operands_.size() + i] << ")"; + } else { + VLOG(5) << " Result " << i << ": null"; } } #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM TF_ASSIGN_OR_RETURN( - auto nested_buffer, + auto nested_cmd, se::CommandBuffer::Trace( - params.executor, params.trace_stream, [&](se::Stream* stream) { + params.stream->parent(), params.command_buffer_trace_stream, + [&](se::Stream* stream) { se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(stream); XlaCustomCallStatus custom_call_status; @@ -941,7 +1064,7 @@ Status CustomCallCmd::Record(const RecordParams& params, } return absl::OkStatus(); })); - return command_buffer->AddNestedCommandBuffer(nested_buffer); + return command_buffer->AddNestedCommandBuffer(*nested_cmd); #else // GOOGLE_CUDA || TENSORFLOW_USE_ROCM return Unavailable( "Custom calls on GPU are not supported in this configuration. Please " @@ -991,7 +1114,8 @@ absl::Status CollectiveCmd::Prepare( collectives->global_device_id_map ? &local_devices : nullptr); return resource_requests.AddClique( - NcclCliqueKey(std::move(participants), /*stream_id=*/0), + NcclCliqueKey(std::move(participants), /*stream_id=*/0, + GetAsyncStreamKind()), num_local_participants); } @@ -1007,7 +1131,8 @@ AllReduceCmd::AllReduceCmd( reduction_kind_(reduction_kind), buffers_(buffers.begin(), buffers.end()) {} -absl::Status AllReduceCmd::Record(const RecordParams& params, +absl::Status AllReduceCmd::Record(const Thunk::ExecuteParams& params, + StateManager& state, se::CommandBuffer* command_buffer) { TF_ASSIGN_OR_RETURN( std::vector device_buffers, @@ -1031,26 +1156,22 @@ absl::Status AllReduceCmd::Record(const RecordParams& params, // Today when recording collective operations into command buffers we always // use a sync mode and a stream id `0`. TF_ASSIGN_OR_RETURN( - NcclComm::Lock comm, + NcclApi::NcclCommHandle comm, GetNcclComm(*params.collective_params, *params.collective_cliques, config().replica_groups, config().group_mode, - /*stream_id=*/0)); + /*stream_id=*/0, GetAsyncStreamKind())); // Use custom allocator for persistent execution plans. NcclApi::ScopedPersistentPlanAllocator scoped_allocator( - *comm, tsl::MakeRef( - params.buffer_allocations->device_ordinal(), - params.buffer_allocations->memory_allocator(), params.stream)); - - TF_ASSIGN_OR_RETURN( - auto nested_buffer, - se::CommandBuffer::Trace( - params.executor, params.trace_stream, [&](se::Stream* stream) { - return RunAllReduce(nccl_api(), reduction_kind_, device_buffers, - *stream, *comm); - })); - - return command_buffer->AddNestedCommandBuffer(nested_buffer); + comm, tsl::MakeRef( + params.buffer_allocations->device_ordinal(), + params.buffer_allocations->memory_allocator(), params.stream)); + + return AddTracedCommandBuffer( + params, state, command_buffer, [&](se::Stream* stream) { + return RunAllReduce(nccl_api(), reduction_kind_, device_buffers, + *stream, comm); + }); } CommandBufferCmd::BufferUsageVector AllReduceCmd::buffers() { @@ -1074,7 +1195,8 @@ ReduceScatterCmd::ReduceScatterCmd( reduction_kind_(reduction_kind), buffers_(buffers.begin(), buffers.end()) {} -absl::Status ReduceScatterCmd::Record(const RecordParams& params, +absl::Status ReduceScatterCmd::Record(const Thunk::ExecuteParams& params, + StateManager& state, se::CommandBuffer* command_buffer) { TF_ASSIGN_OR_RETURN( std::vector device_buffers, @@ -1099,26 +1221,22 @@ absl::Status ReduceScatterCmd::Record(const RecordParams& params, // Today when recording collective operations into command buffers we always // use a sync mode and a stream id `0`. TF_ASSIGN_OR_RETURN( - NcclComm::Lock comm, + NcclApi::NcclCommHandle comm, GetNcclComm(*params.collective_params, *params.collective_cliques, config().replica_groups, config().group_mode, - /*stream_id=*/0)); + /*stream_id=*/0, GetAsyncStreamKind())); // Use custom allocator for persistent execution plans. NcclApi::ScopedPersistentPlanAllocator scoped_allocator( - *comm, tsl::MakeRef( - params.buffer_allocations->device_ordinal(), - params.buffer_allocations->memory_allocator(), params.stream)); - - TF_ASSIGN_OR_RETURN( - auto nested_buffer, - se::CommandBuffer::Trace( - params.executor, params.trace_stream, [&](se::Stream* stream) { - return RunReduceScatter(nccl_api(), reduction_kind_, device_buffers, - *stream, *comm); - })); - - return command_buffer->AddNestedCommandBuffer(nested_buffer); + comm, tsl::MakeRef( + params.buffer_allocations->device_ordinal(), + params.buffer_allocations->memory_allocator(), params.stream)); + + return AddTracedCommandBuffer( + params, state, command_buffer, [&](se::Stream* stream) { + return RunReduceScatter(nccl_api(), reduction_kind_, device_buffers, + *stream, comm); + }); } CommandBufferCmd::BufferUsageVector ReduceScatterCmd::buffers() { @@ -1140,7 +1258,8 @@ AllGatherCmd::AllGatherCmd( : CollectiveCmd(nccl_api, std::move(config)), buffers_(buffers.begin(), buffers.end()) {} -absl::Status AllGatherCmd::Record(const RecordParams& params, +absl::Status AllGatherCmd::Record(const Thunk::ExecuteParams& params, + StateManager& state, se::CommandBuffer* command_buffer) { TF_ASSIGN_OR_RETURN( std::vector device_buffers, @@ -1164,25 +1283,21 @@ absl::Status AllGatherCmd::Record(const RecordParams& params, // Today when recording collective operations into command buffers we always // use a sync mode and a stream id `0`. TF_ASSIGN_OR_RETURN( - NcclComm::Lock comm, + NcclApi::NcclCommHandle comm, GetNcclComm(*params.collective_params, *params.collective_cliques, config().replica_groups, config().group_mode, - /*stream_id=*/0)); + /*stream_id=*/0, GetAsyncStreamKind())); // Use custom allocator for persistent execution plans. NcclApi::ScopedPersistentPlanAllocator scoped_allocator( - *comm, tsl::MakeRef( - params.buffer_allocations->device_ordinal(), - params.buffer_allocations->memory_allocator(), params.stream)); - - TF_ASSIGN_OR_RETURN( - auto nested_buffer, - se::CommandBuffer::Trace( - params.executor, params.trace_stream, [&](se::Stream* stream) { - return RunAllGather(nccl_api(), device_buffers, *stream, *comm); - })); - - return command_buffer->AddNestedCommandBuffer(nested_buffer); + comm, tsl::MakeRef( + params.buffer_allocations->device_ordinal(), + params.buffer_allocations->memory_allocator(), params.stream)); + + return AddTracedCommandBuffer( + params, state, command_buffer, [&](se::Stream* stream) { + return RunAllGather(nccl_api(), device_buffers, *stream, comm); + }); } CommandBufferCmd::BufferUsageVector AllGatherCmd::buffers() { diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h similarity index 70% rename from third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h rename to third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h index 02504f33ad2b5a..e2d3f891017b5b 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_CMD_H_ -#define XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_CMD_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_COMMAND_BUFFER_CMD_H_ +#define XLA_SERVICE_GPU_RUNTIME_COMMAND_BUFFER_CMD_H_ #include #include @@ -28,6 +28,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" +#include "absl/functional/function_ref.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" @@ -40,12 +41,13 @@ limitations under the License. #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/nccl_api.h" #include "xla/service/gpu/nccl_collective_thunk.h" -#include "xla/service/gpu/runtime3/custom_call_thunk.h" +#include "xla/service/gpu/runtime/custom_call_thunk.h" #include "xla/service/gpu/thunk.h" #include "xla/status.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" namespace xla::gpu { @@ -57,11 +59,16 @@ namespace xla::gpu { // CommandBufferCmd is an abstract command that creates or updates command // buffer by recording commands into it. // -// Command initialization and recording must be thread safe as commands can be -// recorded concurrently for multiple command buffers on different stream -// executors. +// Commands have the same execution stages as thunks as they are executed by a +// command buffer thunk: Prepare, Initialize and Record (Execute). See Thunk +// documentation for details. +// +// Commands must be thread safe as they can be recorded into multiple command +// buffers concurrently on different stream executors. class CommandBufferCmd { public: + virtual ~CommandBufferCmd() = default; + enum class MemoryAccess { kRead, kWrite }; // BufferUsage tracks memory access type for a buffer slice, so that we can @@ -85,6 +92,59 @@ class CommandBufferCmd { using BufferUsageVector = absl::InlinedVector; + // A base class for externally managed command state. + // + // Commands can be executed concurrently for many stream executors (underlying + // devices) and command buffers. Managing per-executor state can become + // expensive as it requires synchronization. Furthermore the number of command + // buffers command is recorded into is unbounded as they come and go (command + // buffers evicted and reconstructed) which makes it hard to manage the + // lifetime of resources attached to command buffers. + // + // Externally managed state (owned and synchronized by CommandBufferThunk) + // allows commands to attach a piece of information to command buffer in a + // safe and performant way. + class State { + public: + virtual ~State() = default; + }; + + // An external manager for a state attached to commands. + class StateManager { + public: + virtual ~StateManager() = default; + + template + ConcreteState* GetOrNull(const CommandBufferCmd* cmd) { + static_assert(std::is_base_of_v); + return static_cast(GetOrNull(cmd)); + } + + template + ConcreteState* GetOrCreate( + const CommandBufferCmd* cmd, + absl::FunctionRef()> create) { + static_assert(std::is_base_of_v); + return static_cast(GetOrCreate( + cmd, [&]() -> std::unique_ptr { return create(); })); + } + + template + ConcreteState* GetOrCreate(const CommandBufferCmd* cmd) { + static_assert(std::is_base_of_v); + return static_cast( + GetOrCreate(cmd, [] { return std::make_unique(); })); + } + + private: + State* GetOrNull(const CommandBufferCmd* cmd); + + State* GetOrCreate(const CommandBufferCmd* cmd, + absl::FunctionRef()> create); + + absl::flat_hash_map> state_; + }; + // See Thunk documentation for XLA execution stages (prepare, initialize, // execute). Commands mirror thunks as they are executed as CommandBufferThunk // that is plugged into the Thunk execution cycle. @@ -96,36 +156,17 @@ class CommandBufferCmd { return absl::OkStatus(); } - // Run time parameters required for recording commands into the command - // buffer. For example when we emit command buffer cmd sequence from an HLO - // module, we only know the buffer slices required for HLO operations, but the - // concrete device pointers become available only at run time. - // - // For allocations that performed through command buffer Allocate command, the - // target addresses are tracked by command buffer runtime. To record command - // that consumes buffers allocated inside command buffer, user should specify - // the target address as se::DeviceMemoryBase{nullptr, size}. - // - // TODO(ezhulenev): Use Thunk ExecuteParams for recording commands. - struct RecordParams { - se::StreamExecutor* executor = nullptr; - se::Stream* stream = nullptr; - se::Stream* trace_stream = nullptr; - const BufferAllocations* buffer_allocations = nullptr; - const Thunk::CollectiveExecuteParams* collective_params = nullptr; - const Thunk::CollectiveCliques* collective_cliques = nullptr; - }; - - // Prepares a command for recording on a given executor. We split it into a + // Initialize a command for recording on a given executor. We split it into a // separate function to allow expensive initialization (e.g. device kernel // loading) to happen before a command buffer thunk execution. - virtual absl::Status Initialize(se::StreamExecutor* executor, - Thunk::ExecutableSource source) { + virtual absl::Status Initialize(const Thunk::InitializeParams& params, + StateManager& state) { return absl::OkStatus(); } // Records command into the command buffer. - virtual absl::Status Record(const RecordParams& params, + virtual absl::Status Record(const Thunk::ExecuteParams& params, + StateManager& state, se::CommandBuffer* command_buffer) = 0; // Returns all buffers used by the cmd. These will be used to track cmd @@ -135,7 +176,13 @@ class CommandBufferCmd { // Returns true if command implemented as a nested command buffer. virtual bool IsNestedCommandBuffer() const { return false; } - virtual ~CommandBufferCmd() = default; + std::string_view profile_annotation() const { return profile_annotation_; } + void set_profile_annotation(std::string_view profile_annotation) { + profile_annotation_ = profile_annotation; + } + + private: + std::string profile_annotation_; }; //===----------------------------------------------------------------------===// @@ -176,11 +223,12 @@ class CommandBufferCmdSequence { Thunk::ResourceRequests& resource_requests); // Initializes all commands added to a sequence. - absl::Status Initialize(se::StreamExecutor* executor, - Thunk::ExecutableSource source); + absl::Status Initialize(const Thunk::InitializeParams& params, + CommandBufferCmd::StateManager& state); // Records all commands added to a sequence into the given command buffer. - absl::Status Record(const CommandBufferCmd::RecordParams& params, + absl::Status Record(const Thunk::ExecuteParams& params, + CommandBufferCmd::StateManager& state, se::CommandBuffer* command_buffer, RecordMode mode = RecordMode::kExclusive); @@ -198,10 +246,7 @@ class CommandBufferCmdSequence { size_t size() const { return commands_.size(); } private: - struct Command { - Command(std::unique_ptr cmd, bool requires_barrier) - : cmd(std::move(cmd)), requires_barrier(requires_barrier) {} - + struct CommandInfo { std::unique_ptr cmd; bool requires_barrier; }; @@ -213,7 +258,7 @@ class CommandBufferCmdSequence { void ClearTrackedBuffers(); bool force_barriers_; - std::vector commands_; + std::vector commands_; // Buffers referenced by commands in this sequence. absl::flat_hash_set buffers_; @@ -228,6 +273,53 @@ class CommandBufferCmdSequence { absl::flat_hash_set write_set_; }; +//===----------------------------------------------------------------------===// +// TracedCommandBuffer +//===----------------------------------------------------------------------===// + +// A cache for traced command buffers that will re-trace on change in buffer +// allocations that are relevant for `buffers` passed to constructor. We use a +// very simple most-recently-used cache of traced command buffers as in practice +// subsequent calls to XLA executable tend to reuse the same allocations. +class TracedCommandBuffer : public CommandBufferCmd::State { + public: + explicit TracedCommandBuffer(CommandBufferCmd::BufferUsageVector buffers, + int64_t capacity = 16); + + // Returns cached command buffer traced using the same buffer addresses or + // traces and caches a new command buffer using user provided callback. + absl::StatusOr GetOrTraceCommandBuffer( + const BufferAllocations* buffer_allocation, se::StreamExecutor* executor, + se::Stream* stream, absl::FunctionRef trace); + + private: + std::vector allocs_indices_; + + struct Entry { + std::vector recorded_allocs; + std::unique_ptr command_buffer; + }; + + int64_t capacity_; + std::vector entries_; +}; + +//===----------------------------------------------------------------------===// +// TracedCommandBufferCmd +//===----------------------------------------------------------------------===// + +// A base class for commands implemented as tracing of stream activities. +class TracedCommandBufferCmd : public CommandBufferCmd { + protected: + // Creates a command buffer by calling a user-provided `trace` function and + // adds it as a nested command to `command_buffer`. Traced command buffers + // cached and reused in an instance of `TracedCommandBuffer` kept in `state`. + absl::Status AddTracedCommandBuffer( + const Thunk::ExecuteParams& params, StateManager& state, + se::CommandBuffer* command_buffer, + absl::FunctionRef trace); +}; + //===----------------------------------------------------------------------===// // ComputationIdCmd (ReplicaId and PartitionId) //===----------------------------------------------------------------------===// @@ -238,10 +330,10 @@ class ComputationIdCmd : public CommandBufferCmd { ComputationIdCmd(BufferAllocation::Slice dest, Kind kind); - absl::Status Initialize(se::StreamExecutor* executor, - Thunk::ExecutableSource source) override; + absl::Status Initialize(const Thunk::InitializeParams& params, + StateManager& state) override; - absl::Status Record(const RecordParams& params, + absl::Status Record(const Thunk::ExecuteParams& params, StateManager& state, se::CommandBuffer* command_buffer) override; BufferUsageVector buffers() override; @@ -274,10 +366,10 @@ class LaunchCmd : public CommandBufferCmd { absl::Span args_access, LaunchDimensions dims, int64_t shmem_bytes); - absl::Status Initialize(se::StreamExecutor* executor, - Thunk::ExecutableSource source) override; + absl::Status Initialize(const Thunk::InitializeParams& params, + StateManager& state) override; - absl::Status Record(const RecordParams& params, + absl::Status Record(const Thunk::ExecuteParams& params, StateManager& state, se::CommandBuffer* command_buffer) override; BufferUsageVector buffers() override; @@ -306,10 +398,10 @@ class CustomKernelLaunchCmd : public CommandBufferCmd { absl::Span args_access, CustomKernel custom_kernel); - absl::Status Initialize(se::StreamExecutor* executor, - Thunk::ExecutableSource source) override; + absl::Status Initialize(const Thunk::InitializeParams& params, + StateManager& state) override; - absl::Status Record(const RecordParams& params, + absl::Status Record(const Thunk::ExecuteParams& params, StateManager& state, se::CommandBuffer* command_buffer) override; BufferUsageVector buffers() override; @@ -335,7 +427,7 @@ class MemcpyDeviceToDeviceCmd : public CommandBufferCmd { MemcpyDeviceToDeviceCmd(BufferAllocation::Slice dst, BufferAllocation::Slice src, int64_t num_bytes); - absl::Status Record(const RecordParams& params, + absl::Status Record(const Thunk::ExecuteParams& params, StateManager& state, se::CommandBuffer* command_buffer) override; BufferUsageVector buffers() override; @@ -354,7 +446,7 @@ class MemzeroCmd : public CommandBufferCmd { public: explicit MemzeroCmd(BufferAllocation::Slice dst); - absl::Status Record(const RecordParams& params, + absl::Status Record(const Thunk::ExecuteParams& params, StateManager& state, se::CommandBuffer* command_buffer) override; BufferUsageVector buffers() override; @@ -371,7 +463,7 @@ class Memset32Cmd : public CommandBufferCmd { public: explicit Memset32Cmd(BufferAllocation::Slice dst, uint32_t bit_pattern); - absl::Status Record(const RecordParams& params, + absl::Status Record(const Thunk::ExecuteParams& params, StateManager& state, se::CommandBuffer* command_buffer) override; BufferUsageVector buffers() override; @@ -389,10 +481,10 @@ class IfCmd : public CommandBufferCmd { public: IfCmd(BufferAllocation::Slice pred, CommandBufferCmdSequence then_commands); - absl::Status Initialize(se::StreamExecutor* executor, - Thunk::ExecutableSource source) override; + absl::Status Initialize(const Thunk::InitializeParams& params, + StateManager& state) override; - absl::Status Record(const RecordParams& params, + absl::Status Record(const Thunk::ExecuteParams& params, StateManager& state, se::CommandBuffer* command_buffer) override; BufferUsageVector buffers() override; @@ -412,10 +504,10 @@ class IfElseCmd : public CommandBufferCmd { CommandBufferCmdSequence then_commands, CommandBufferCmdSequence else_commands); - absl::Status Initialize(se::StreamExecutor* executor, - Thunk::ExecutableSource source) override; + absl::Status Initialize(const Thunk::InitializeParams& params, + StateManager& state) override; - absl::Status Record(const RecordParams& params, + absl::Status Record(const Thunk::ExecuteParams& params, StateManager& state, se::CommandBuffer* command_buffer) override; BufferUsageVector buffers() override; @@ -435,10 +527,10 @@ class CaseCmd : public CommandBufferCmd { CaseCmd(BufferAllocation::Slice index, std::vector branches_commands); - absl::Status Initialize(se::StreamExecutor* executor, - Thunk::ExecutableSource source) override; + absl::Status Initialize(const Thunk::InitializeParams& params, + StateManager& state) override; - absl::Status Record(const RecordParams& params, + absl::Status Record(const Thunk::ExecuteParams& params, StateManager& state, se::CommandBuffer* command_buffer) override; BufferUsageVector buffers() override; @@ -457,10 +549,10 @@ class ForCmd : public CommandBufferCmd { ForCmd(int32_t num_iterations, BufferAllocation::Slice loop_counter, CommandBufferCmdSequence body_commands); - absl::Status Initialize(se::StreamExecutor* executor, - Thunk::ExecutableSource source) override; + absl::Status Initialize(const Thunk::InitializeParams& params, + StateManager& state) override; - absl::Status Record(const RecordParams& params, + absl::Status Record(const Thunk::ExecuteParams& params, StateManager& state, se::CommandBuffer* command_buffer) override; BufferUsageVector buffers() override; @@ -480,10 +572,10 @@ class WhileCmd : public CommandBufferCmd { WhileCmd(BufferAllocation::Slice pred, CommandBufferCmdSequence cond_commands, CommandBufferCmdSequence body_commands); - absl::Status Initialize(se::StreamExecutor* executor, - Thunk::ExecutableSource source) override; + absl::Status Initialize(const Thunk::InitializeParams& params, + StateManager& state) override; - absl::Status Record(const RecordParams& params, + absl::Status Record(const Thunk::ExecuteParams& params, StateManager& state, se::CommandBuffer* command_buffer) override; BufferUsageVector buffers() override; @@ -504,7 +596,7 @@ class AllocateCmd : public CommandBufferCmd { // After calling this function, the allocated memory is tracked in // CommandBuffer object. - absl::Status Record(const RecordParams& params, + absl::Status Record(const Thunk::ExecuteParams& params, StateManager& state, se::CommandBuffer* command_buffer) override; BufferUsageVector buffers() override; @@ -523,7 +615,7 @@ class FreeCmd : public CommandBufferCmd { // After calling this function, the allocated memory address for dst // BufferAllocation is freed, no update is required. - absl::Status Record(const RecordParams& params, + absl::Status Record(const Thunk::ExecuteParams& params, StateManager& state, se::CommandBuffer* command_buffer) override; BufferUsageVector buffers() override; @@ -536,17 +628,17 @@ class FreeCmd : public CommandBufferCmd { // GemmCmd //===----------------------------------------------------------------------===// -class GemmCmd : public CommandBufferCmd { +class GemmCmd : public TracedCommandBufferCmd { public: GemmCmd(GemmConfig config, const BufferAllocation::Slice& lhs_buffer, const BufferAllocation::Slice& rhs_buffer, const BufferAllocation::Slice& output_buffer, const BufferAllocation::Slice& workspace, bool deterministic); - absl::Status Initialize(se::StreamExecutor* executor, - Thunk::ExecutableSource source) override; + absl::Status Initialize(const Thunk::InitializeParams& params, + StateManager& state) override; - absl::Status Record(const RecordParams& params, + absl::Status Record(const Thunk::ExecuteParams& params, StateManager& state, se::CommandBuffer* command_buffer) override; BufferUsageVector buffers() override; @@ -585,7 +677,7 @@ class CustomCallCmd : public CommandBufferCmd { results_(std::move(results)), opaque_(opaque){}; - absl::Status Record(const RecordParams& params, + absl::Status Record(const Thunk::ExecuteParams& params, StateManager& state, se::CommandBuffer* command_buffer) override; BufferUsageVector buffers() override; @@ -602,7 +694,7 @@ class CustomCallCmd : public CommandBufferCmd { // CollectiveCmd //===----------------------------------------------------------------------===// -class CollectiveCmd : public CommandBufferCmd { +class CollectiveCmd : public TracedCommandBufferCmd { public: CollectiveCmd(NcclApi* nccl_api, NcclCollectiveConfig config); @@ -611,6 +703,8 @@ class CollectiveCmd : public CommandBufferCmd { bool IsNestedCommandBuffer() const final { return true; } + virtual AsyncStreamKind GetAsyncStreamKind() = 0; + protected: NcclApi* nccl_api() const { return nccl_api_; } const NcclCollectiveConfig& config() const { return config_; } @@ -630,11 +724,15 @@ class AllReduceCmd : public CollectiveCmd { ReductionKind reduction_kind, absl::Span buffers); - absl::Status Record(const RecordParams& params, + absl::Status Record(const Thunk::ExecuteParams& params, StateManager& state, se::CommandBuffer* command_buffer) override; BufferUsageVector buffers() override; + AsyncStreamKind GetAsyncStreamKind() override { + return AsyncStreamKind::kCollective; + }; + private: ReductionKind reduction_kind_; std::vector buffers_; @@ -650,11 +748,15 @@ class ReduceScatterCmd : public CollectiveCmd { ReductionKind reduction_kind, absl::Span buffers); - absl::Status Record(const RecordParams& params, + absl::Status Record(const Thunk::ExecuteParams& params, StateManager& state, se::CommandBuffer* command_buffer) override; BufferUsageVector buffers() override; + AsyncStreamKind GetAsyncStreamKind() override { + return AsyncStreamKind::kCollective; + }; + private: ReductionKind reduction_kind_; std::vector buffers_; @@ -669,15 +771,19 @@ class AllGatherCmd : public CollectiveCmd { AllGatherCmd(NcclApi* nccl_api, NcclCollectiveConfig config, absl::Span buffers); - absl::Status Record(const RecordParams& params, + absl::Status Record(const Thunk::ExecuteParams& params, StateManager& state, se::CommandBuffer* command_buffer) override; BufferUsageVector buffers() override; + AsyncStreamKind GetAsyncStreamKind() override { + return AsyncStreamKind::kCollective; + }; + private: std::vector buffers_; }; } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_CMD_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_COMMAND_BUFFER_CMD_H_ diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd_emitter.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc similarity index 88% rename from third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd_emitter.cc rename to third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc index 5556b186c0524d..daebb2a044711f 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd_emitter.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime3/command_buffer_cmd_emitter.h" +#include "xla/service/gpu/runtime/command_buffer_cmd_emitter.h" #include #include @@ -23,18 +23,19 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "xla/service/gpu/nccl_all_gather_thunk.h" -#include "xla/service/gpu/nccl_all_reduce_thunk.h" -#include "xla/service/gpu/runtime3/command_buffer_cmd.h" -#include "xla/service/gpu/runtime3/conditional_thunk.h" -#include "xla/service/gpu/runtime3/copy_thunk.h" -#include "xla/service/gpu/runtime3/custom_call_thunk.h" -#include "xla/service/gpu/runtime3/gemm_thunk.h" -#include "xla/service/gpu/runtime3/kernel_thunk.h" -#include "xla/service/gpu/runtime3/memset_thunk.h" -#include "xla/service/gpu/runtime3/replica_id_thunk.h" -#include "xla/service/gpu/runtime3/sequential_thunk.h" -#include "xla/service/gpu/runtime3/while_thunk.h" +#include "xla/service/gpu/runtime/command_buffer_cmd.h" +#include "xla/service/gpu/runtime/conditional_thunk.h" +#include "xla/service/gpu/runtime/copy_thunk.h" +#include "xla/service/gpu/runtime/custom_call_thunk.h" +#include "xla/service/gpu/runtime/gemm_thunk.h" +#include "xla/service/gpu/runtime/kernel_thunk.h" +#include "xla/service/gpu/runtime/memset_thunk.h" +#include "xla/service/gpu/runtime/nccl_all_gather_thunk.h" +#include "xla/service/gpu/runtime/nccl_all_reduce_thunk.h" +#include "xla/service/gpu/runtime/replica_id_thunk.h" +#include "xla/service/gpu/runtime/sequential_thunk.h" +#include "xla/service/gpu/runtime/wait_for_streams_thunk.h" +#include "xla/service/gpu/runtime/while_thunk.h" #include "xla/service/gpu/thunk.h" #include "xla/util.h" #include "tsl/platform/errors.h" @@ -162,16 +163,25 @@ static absl::StatusOr Convert(const CustomCallThunk& thunk) { } //===----------------------------------------------------------------------===// +static absl::StatusOr CopyMetadata(absl::StatusOr cmd, + const Thunk& thunk) { + if (cmd.ok()) { + (*cmd)->set_profile_annotation(thunk.profile_annotation()); + return cmd; + } + return cmd; +} template static absl::StatusOr Convert(const Thunk& thunk) { - return Convert(static_cast(thunk)); + return CopyMetadata(Convert(static_cast(thunk)), thunk); } template static absl::StatusOr Convert(const Thunk& thunk, bool force_barriers) { - return Convert(static_cast(thunk), force_barriers); + return CopyMetadata( + Convert(static_cast(thunk), force_barriers), thunk); } static absl::Status AppendCommands(CommandBufferCmdSequence& cmd_sequence, @@ -226,6 +236,7 @@ static absl::Status AppendCommands(CommandBufferCmdSequence& cmd_sequence, case Thunk::Kind::kNcclAllGatherDone: case Thunk::Kind::kNcclAllReduceDone: case Thunk::Kind::kNcclReduceScatterDone: + case Thunk::Kind::kWaitForStreams: return absl::OkStatus(); default: diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd_emitter.h b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.h similarity index 82% rename from third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd_emitter.h rename to third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.h index 237f0ee7c41fa6..9b77ab7630a99a 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd_emitter.h +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_CMD_EMITTER_H_ -#define XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_CMD_EMITTER_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_COMMAND_BUFFER_CMD_EMITTER_H_ +#define XLA_SERVICE_GPU_RUNTIME_COMMAND_BUFFER_CMD_EMITTER_H_ -#include "xla/service/gpu/runtime3/command_buffer_cmd.h" +#include "xla/service/gpu/runtime/command_buffer_cmd.h" #include "xla/service/gpu/thunk.h" #include "xla/statusor.h" @@ -31,4 +31,4 @@ absl::StatusOr ConvertToCommands( } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_CMD_EMITTER_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_COMMAND_BUFFER_CMD_EMITTER_H_ diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd_test.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_test.cc similarity index 53% rename from third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd_test.cc rename to third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_test.cc index 7cf458001b0a5a..036244a4706fbe 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_test.cc @@ -13,26 +13,34 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime3/command_buffer_cmd.h" +#include "xla/service/gpu/runtime/command_buffer_cmd.h" +#include #include #include +#include "absl/functional/function_ref.h" +#include "absl/status/status.h" #include "absl/strings/ascii.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/thunk.h" #include "xla/service/platform_util.h" +#include "xla/service/service_executable_run_options.h" #include "xla/status.h" #include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_test_kernels.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/types.h" // IWYU pragma: keep #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" +#include "tsl/platform/test_benchmark.h" namespace xla::gpu { @@ -43,7 +51,7 @@ using MemoryAccess = CommandBufferCmd::MemoryAccess; static se::StreamExecutor* GpuExecutor() { auto name = absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value()); - auto* platform = se::MultiPlatformManager::PlatformWithName(name).value(); + auto* platform = se::PlatformManager::PlatformWithName(name).value(); return platform->ExecutorForDevice(0).value(); } @@ -54,7 +62,8 @@ struct TestOnlyCommandBufferCmd : public CommandBufferCmd { explicit TestOnlyCommandBufferCmd(BufferUsageVector buffer_usage) : buffer_usage(buffer_usage) {} - absl::Status Record(const RecordParams&, se::CommandBuffer*) override { + absl::Status Record(const Thunk::ExecuteParams&, StateManager&, + se::CommandBuffer*) override { return absl::OkStatus(); } @@ -147,9 +156,7 @@ TEST(CommandBufferCmdTest, MemcpyCmd) { se::StreamExecutor* executor = GpuExecutor(); se::Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); - + CHECK_OK(stream.Initialize()); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; @@ -157,8 +164,8 @@ TEST(CommandBufferCmdTest, MemcpyCmd) { se::DeviceMemory a = executor->AllocateArray(length, 0); se::DeviceMemory b = executor->AllocateArray(length, 0); - stream.ThenMemset32(&a, 42, byte_length); - stream.ThenMemZero(&b, byte_length); + TF_ASSERT_OK(stream.Memset32(&a, 42, byte_length)); + TF_ASSERT_OK(stream.MemZero(&b, byte_length)); // Prepare buffer allocations for recording command buffer. BufferAllocation alloc_a(/*index=*/0, byte_length, /*color=*/0); @@ -171,18 +178,23 @@ TEST(CommandBufferCmdTest, MemcpyCmd) { CommandBufferCmdSequence commands; commands.Emplace(slice_b, slice_a, byte_length); + ServiceExecutableRunOptions run_options; BufferAllocations allocations({a, b}, 0, executor->GetAllocator()); + Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( + run_options, allocations, &stream, &stream, {}, nullptr, nullptr); + + CommandBufferCmd::StateManager state; + auto command_buffer = se::CommandBuffer::Create(executor).value(); - TF_ASSERT_OK(commands.Record({executor, &stream, &stream, &allocations}, - &command_buffer)); + TF_ASSERT_OK(commands.Record(params, state, command_buffer.get())); // Execute command buffer and verify that it copied the memory. - TF_ASSERT_OK(executor->Submit(&stream, command_buffer)); + TF_ASSERT_OK(executor->Submit(&stream, *command_buffer)); // Copy `b` data back to host. std::vector dst(4, 0); - stream.ThenMemcpy(dst.data(), b, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), b, byte_length)); ASSERT_EQ(dst, std::vector(4, 42)); } @@ -191,9 +203,7 @@ TEST(CommandBufferCmdTest, LaunchCmd) { se::StreamExecutor* executor = GpuExecutor(); se::Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); - + CHECK_OK(stream.Initialize()); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; @@ -201,8 +211,8 @@ TEST(CommandBufferCmdTest, LaunchCmd) { se::DeviceMemory a = executor->AllocateArray(length, 0); se::DeviceMemory b = executor->AllocateArray(length, 0); - stream.ThenMemset32(&a, 42, byte_length); - stream.ThenMemZero(&b, byte_length); + TF_ASSERT_OK(stream.Memset32(&a, 42, byte_length)); + TF_ASSERT_OK(stream.MemZero(&b, byte_length)); // Prepare buffer allocations for recording command buffer. BufferAllocation alloc_a(/*index=*/0, byte_length, /*color=*/0); @@ -230,22 +240,171 @@ TEST(CommandBufferCmdTest, LaunchCmd) { /*binary=*/se::gpu::internal::kAddI32KernelModule #endif }; - TF_ASSERT_OK(commands.Initialize(executor, source)); + CommandBufferCmd::StateManager state; + TF_ASSERT_OK(commands.Initialize({executor, source}, state)); + + ServiceExecutableRunOptions run_options; BufferAllocations allocations({a, b}, 0, executor->GetAllocator()); + Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( + run_options, allocations, &stream, &stream, {}, nullptr, nullptr); + auto command_buffer = se::CommandBuffer::Create(executor).value(); - TF_ASSERT_OK(commands.Record({executor, &stream, &stream, &allocations}, - &command_buffer)); + TF_ASSERT_OK(commands.Record(params, state, command_buffer.get())); // Execute command buffer and verify that it copied the memory. - TF_ASSERT_OK(executor->Submit(&stream, command_buffer)); + TF_ASSERT_OK(executor->Submit(&stream, *command_buffer)); // Copy `b` data back to host. std::vector dst(4, 0); - stream.ThenMemcpy(dst.data(), b, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), b, byte_length)); ASSERT_EQ(dst, std::vector(4, 42 + 42)); } +TEST(CommandBufferCmdStateManageTest, GetOrCreateState) { + struct TestState : public CommandBufferCmd::State { + int32_t value = 0; + }; + + // We need a fake command buffer pointer to use as a key. + CommandBufferCmd* cmd = reinterpret_cast(0x1234567); + + CommandBufferCmd::StateManager state_manager; + + auto* state0 = state_manager.GetOrNull(cmd); + ASSERT_EQ(state0, nullptr); + + auto* state1 = state_manager.GetOrCreate(cmd); + ASSERT_EQ(state1->value, 0); + state1->value += 42; + + auto* state2 = state_manager.GetOrCreate(cmd); + ASSERT_EQ(state2->value, 42); + ASSERT_EQ(state1, state2); +} + +TEST(TracedCommandBuffer, GetOrUpdateCommandBuffer) { + se::StreamExecutor* executor = GpuExecutor(); + + se::Stream stream(executor); + CHECK_OK(stream.Initialize()); + BufferAllocation alloc0(/*index=*/0, /*size=*/1024, /*color=*/0); + BufferAllocation alloc1(/*index=*/1, /*size=*/1024, /*color=*/0); + + CommandBufferCmd::BufferUsageVector buffers = { + {BufferAllocation::Slice(&alloc0, 0, 1024), MemoryAccess::kRead}, + {BufferAllocation::Slice(&alloc1, 0, 1024), MemoryAccess::kWrite}}; + + TracedCommandBuffer traced_cmd_buffer(buffers, /*capacity=*/2); + + se::DeviceMemoryBase mem0(reinterpret_cast(0x01234567)); + se::DeviceMemoryBase mem1(reinterpret_cast(0x12345670)); + + BufferAllocations allocations({mem0, mem1}, 0, executor->GetAllocator()); + + // No-op trace callback to count how many times it was called. + int64_t num_calls = 0; + auto trace = [&](se::Stream*) { + num_calls++; + return absl::OkStatus(); + }; + + TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer0, + traced_cmd_buffer.GetOrTraceCommandBuffer( + &allocations, executor, &stream, trace)); + + TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer1, + traced_cmd_buffer.GetOrTraceCommandBuffer( + &allocations, executor, &stream, trace)); + + // Check that command buffer was reused as buffer allocations didn't change. + ASSERT_EQ(command_buffer0, command_buffer1); + EXPECT_EQ(num_calls, 1); + + // Check that when memory address changes we re-trace the command buffer. + se::DeviceMemoryBase mem2(reinterpret_cast(0x23456701)); + allocations = BufferAllocations({mem0, mem2}, 0, executor->GetAllocator()); + + TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer2, + traced_cmd_buffer.GetOrTraceCommandBuffer( + &allocations, executor, &stream, trace)); + + ASSERT_NE(command_buffer0, command_buffer2); + EXPECT_EQ(num_calls, 2); + + // Check that we keep first command buffer in cache. + allocations = BufferAllocations({mem0, mem1}, 0, executor->GetAllocator()); + + TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer3, + traced_cmd_buffer.GetOrTraceCommandBuffer( + &allocations, executor, &stream, trace)); + ASSERT_EQ(command_buffer0, command_buffer3); + EXPECT_EQ(num_calls, 2); + + // Check that we trace a new graph when buffer allocation pattern is new. + allocations = BufferAllocations({mem0, mem0}, 0, executor->GetAllocator()); + + TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer4, + traced_cmd_buffer.GetOrTraceCommandBuffer( + &allocations, executor, &stream, trace)); + ASSERT_NE(command_buffer4, command_buffer3); + ASSERT_NE(command_buffer4, command_buffer2); + EXPECT_EQ(num_calls, 3); + + // Check that we still keep the previous graph in cache. + allocations = BufferAllocations({mem0, mem1}, 0, executor->GetAllocator()); + + TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer5, + traced_cmd_buffer.GetOrTraceCommandBuffer( + &allocations, executor, &stream, trace)); + ASSERT_EQ(command_buffer0, command_buffer5); + EXPECT_EQ(num_calls, 3); +} + +//===----------------------------------------------------------------------===// +// Performance benchmarks below +//===----------------------------------------------------------------------===// + +static void BM_GetOrTraceCommandBuffer(benchmark::State& state) { + se::StreamExecutor* executor = GpuExecutor(); + + se::Stream stream(executor); + TF_ASSERT_OK(stream.Initialize()); + CHECK(stream.ok()); + + BufferAllocation alloc0(/*index=*/0, /*size=*/1024, /*color=*/0); + BufferAllocation alloc1(/*index=*/1, /*size=*/1024, /*color=*/0); + + CommandBufferCmd::BufferUsageVector buffers = { + {BufferAllocation::Slice(&alloc0, 0, 1024), MemoryAccess::kRead}, + {BufferAllocation::Slice(&alloc1, 0, 1024), MemoryAccess::kWrite}}; + + se::DeviceMemoryBase mem0(reinterpret_cast(0x01234567)); + se::DeviceMemoryBase mem1(reinterpret_cast(0x12345670)); + + std::array allocations = { + BufferAllocations({mem0, mem1}, 0, executor->GetAllocator()), + BufferAllocations({mem1, mem0}, 0, executor->GetAllocator()), + BufferAllocations({mem0, mem0}, 0, executor->GetAllocator()), + BufferAllocations({mem1, mem1}, 0, executor->GetAllocator()), + }; + + int32_t index = 0; + TracedCommandBuffer traced_cmd_buffer(buffers); + + auto trace = [](se::Stream*) { return absl::OkStatus(); }; + absl::FunctionRef trace_ref(trace); + + for (auto s : state) { + TF_CHECK_OK(traced_cmd_buffer + .GetOrTraceCommandBuffer(&allocations[index++ % 4], + executor, &stream, trace_ref) + .status()); + } +} + +BENCHMARK(BM_GetOrTraceCommandBuffer); + } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk.cc similarity index 87% rename from third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.cc rename to third_party/xla/xla/service/gpu/runtime/command_buffer_thunk.cc index f5be72ee218248..43cbf440480010 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime3/command_buffer_thunk.h" +#include "xla/service/gpu/runtime/command_buffer_thunk.h" #include #include @@ -27,7 +27,8 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/runtime3/command_buffer_cmd.h" +#include "xla/service/gpu/runtime/annotation.h" +#include "xla/service/gpu/runtime/command_buffer_cmd.h" #include "xla/service/gpu/thunk.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/device_memory.h" @@ -37,13 +38,11 @@ limitations under the License. #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/lib/profiler_lock.h" -#include "tsl/profiler/lib/scoped_annotation.h" #include "tsl/profiler/lib/traceme.h" #include "tsl/profiler/lib/traceme_encode.h" namespace xla::gpu { -using tsl::profiler::ScopedAnnotation; using tsl::profiler::TraceMe; using tsl::profiler::TraceMeEncode; @@ -52,7 +51,7 @@ using tsl::profiler::TraceMeEncode; //===----------------------------------------------------------------------===// CommandBufferThunk::ExecutorCommandBuffer::ExecutorCommandBuffer( - se::CommandBuffer command_buffer) + std::unique_ptr command_buffer) : command_buffer(std::move(command_buffer)) {} CommandBufferThunk::CommandBufferThunk(CommandBufferCmdSequence commands, @@ -80,7 +79,7 @@ CommandBufferThunk::CommandBufferThunk(CommandBufferCmdSequence commands, bool CommandBufferThunk::ExecutorCommandBuffer::ShouldUpdateCommandBuffer( const CommandBufferCmdSequence& commands, - const CommandBufferCmd::RecordParams& params) { + const Thunk::ExecuteParams& params) { bool should_update = false; const BufferAllocations* allocs = params.buffer_allocations; @@ -127,7 +126,12 @@ absl::Status CommandBufferThunk::Initialize(const InitializeParams& params) { // are no-op (e.g. memcpy of size 0) and we have no emitted thunks for them. if (commands_.empty()) return absl::OkStatus(); - TF_RETURN_IF_ERROR(commands_.Initialize(params.executor, params.src)); + TF_ASSIGN_OR_RETURN(std::shared_ptr cmd_buffer, + GetOrCreateCommandBuffer(params.executor)); + absl::MutexLock lock(&cmd_buffer->mutex); + + // Initialize commands. + TF_RETURN_IF_ERROR(commands_.Initialize(params, cmd_buffer->state)); // Always initialize thunks if they are present so we are ready to fall back // on them if we detect profiling activity. @@ -137,26 +141,24 @@ absl::Status CommandBufferThunk::Initialize(const InitializeParams& params) { } } - TF_ASSIGN_OR_RETURN(std::shared_ptr cmd_buffer, - GetOrCreateCommandBuffer(params.executor)); - - absl::MutexLock lock(&cmd_buffer->mutex); - - CommandBufferCmd::RecordParams record_params = { - params.executor, - params.stream, - params.command_buffer_trace_stream, - const_cast(params.buffer_allocations), - params.collective_params, - params.collective_cliques}; + // Construct ExecuteParams with empty fields for everything that is not needed + // for recording commands. + Thunk::ExecuteParams execute_params( + params.buffer_allocations, params.stream, + params.command_buffer_trace_stream, {}, params.collective_params, + params.collective_cliques, /*device_to_host_stream=*/nullptr, + /*host_to_device_stream=*/nullptr, + /*send_device_memory_function=*/nullptr, + /*recv_device_memory_function=*/nullptr); // If command buffer is in `kCreate` state it means that command buffer // sequence was never recorded into it. We initialize all command buffers // before execution, because command buffers when instantiated will allocate // memory on device and this might lead to deadlocks when we have concurrent // NCCL operations in flight. - if (cmd_buffer->command_buffer.state() == se::CommandBuffer::State::kCreate && - cmd_buffer->ShouldUpdateCommandBuffer(commands_, record_params)) { + if (cmd_buffer->command_buffer->state() == + se::CommandBuffer::State::kCreate && + cmd_buffer->ShouldUpdateCommandBuffer(commands_, execute_params)) { VLOG(3) << "Initialize command buffer on device #" << params.executor->device_ordinal() << " by recoding command buffer cmd sequence" @@ -170,8 +172,8 @@ absl::Status CommandBufferThunk::Initialize(const InitializeParams& params) { uint64_t start_micros = tsl::Env::Default()->NowMicros(); - TF_RETURN_IF_ERROR( - commands_.Record(record_params, &cmd_buffer->command_buffer)); + TF_RETURN_IF_ERROR(commands_.Record(execute_params, cmd_buffer->state, + cmd_buffer->command_buffer.get())); uint64_t end_micros = tsl::Env::Default()->NowMicros(); VLOG(3) << "Initialized command buffer on device #" @@ -194,8 +196,10 @@ absl::Status CommandBufferThunk::ExecuteOnStream(const ExecuteParams& params) { if (tsl::profiler::ProfilerLock::HasActiveSession() && thunks_.has_value()) { VLOG(1) << "Execute command buffer thunk as a regular thunk sequence " "because we detected active profiling session"; + const ModuleAnnotations* annotations = GetCurrentModuleAnnotations(); for (auto& thunk : *thunks_) { - ScopedAnnotation annotation([&] { return thunk->profile_annotation(); }); + auto scoped_annotation = + GetKernelAnnotation(annotations, thunk->profile_annotation()); TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(params)); } return absl::OkStatus(); @@ -207,15 +211,7 @@ absl::Status CommandBufferThunk::ExecuteOnStream(const ExecuteParams& params) { absl::MutexLock lock(&cmd_buffer->mutex); - CommandBufferCmd::RecordParams record_params = { - executor, - params.stream, - params.command_buffer_trace_stream, - const_cast(params.buffer_allocations), - params.collective_params, - params.collective_cliques}; - - if (cmd_buffer->ShouldUpdateCommandBuffer(commands_, record_params)) { + if (cmd_buffer->ShouldUpdateCommandBuffer(commands_, params)) { VLOG(3) << "Update command buffer on device #" << executor->device_ordinal() << " by recoding command buffer cmd sequence" << " after " << cmd_buffer->num_executions << " executions since last update" @@ -231,8 +227,8 @@ absl::Status CommandBufferThunk::ExecuteOnStream(const ExecuteParams& params) { uint64_t start_micros = tsl::Env::Default()->NowMicros(); - TF_RETURN_IF_ERROR( - commands_.Record(record_params, &cmd_buffer->command_buffer)); + TF_RETURN_IF_ERROR(commands_.Record(params, cmd_buffer->state, + cmd_buffer->command_buffer.get())); uint64_t end_micros = tsl::Env::Default()->NowMicros(); VLOG(3) << "Updated command buffer in " << (end_micros - start_micros) @@ -253,7 +249,7 @@ absl::Status CommandBufferThunk::ExecuteOnStream(const ExecuteParams& params) { {"num_executions", cmd_buffer->num_executions}}); }); - return executor->Submit(params.stream, cmd_buffer->command_buffer); + return executor->Submit(params.stream, *cmd_buffer->command_buffer); } absl::StatusOr> diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.h b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk.h similarity index 88% rename from third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.h rename to third_party/xla/xla/service/gpu/runtime/command_buffer_thunk.h index 09598485885a8c..d81003c899661f 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_THUNK_H_ -#define XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_THUNK_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_COMMAND_BUFFER_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_COMMAND_BUFFER_THUNK_H_ #include #include @@ -26,8 +26,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" -#include "xla/service/gpu/runtime3/command_buffer_allocations.h" -#include "xla/service/gpu/runtime3/command_buffer_cmd.h" +#include "xla/service/gpu/runtime/command_buffer_allocations.h" +#include "xla/service/gpu/runtime/command_buffer_cmd.h" #include "xla/service/gpu/thunk.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/device_memory.h" @@ -55,18 +55,23 @@ class CommandBufferThunk : public Thunk { // Command buffer instantiated on a `se::StreamExecutor` instance, and // auxiliary state required for efficient command buffer updates. struct ExecutorCommandBuffer { - explicit ExecutorCommandBuffer(se::CommandBuffer command_buffer); + explicit ExecutorCommandBuffer( + std::unique_ptr command_buffer); // Returns true if `commands` cmd sequence has to be recorded into // `command_buffer` to update it (see `recorded_allocs` below). bool ShouldUpdateCommandBuffer(const CommandBufferCmdSequence& commands, - const CommandBufferCmd::RecordParams& params) + const Thunk::ExecuteParams& params) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex); // se::CommandBuffer is not thread safe, and we guard it with a mutex to // guarantee that we do not mutate it concurrently. absl::Mutex mutex; - se::CommandBuffer command_buffer ABSL_GUARDED_BY(mutex); + std::unique_ptr command_buffer ABSL_GUARDED_BY(mutex); + + // A manager for an external state attached by commands in a command + // sequence to a command buffer. + CommandBufferCmd::StateManager state ABSL_GUARDED_BY(mutex); // TODO(ezhulenev): We need to move command buffer allocations all the way // up to the GpuExecutable as we can have Allocate and Free commands in @@ -136,4 +141,4 @@ class CommandBufferThunk : public Thunk { } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_THUNK_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_COMMAND_BUFFER_THUNK_H_ diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc similarity index 87% rename from third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc rename to third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc index 45f132b029142b..9d367a7404b096 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime3/command_buffer_thunk.h" +#include "xla/service/gpu/runtime/command_buffer_thunk.h" #include #include @@ -28,8 +28,8 @@ limitations under the License. #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/matmul_utils.h" -#include "xla/service/gpu/runtime3/command_buffer_allocations.h" -#include "xla/service/gpu/runtime3/command_buffer_cmd.h" +#include "xla/service/gpu/runtime/command_buffer_allocations.h" +#include "xla/service/gpu/runtime/command_buffer_cmd.h" #include "xla/service/gpu/thunk.h" #include "xla/service/platform_util.h" #include "xla/service/service_executable_run_options.h" @@ -39,28 +39,30 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_test_kernels.h" #include "xla/stream_executor/gpu/gpu_types.h" // IWYU pragma: keep -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/types.h" // IWYU pragma: keep #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" +#ifdef GOOGLE_CUDA +#include "third_party/gpus/cuda/include/cuda.h" +#endif + namespace xla::gpu { using MemoryAccess = CommandBufferCmd::MemoryAccess; using KernelArgsPacking = se::MultiKernelLoaderSpec::KernelArgsPacking; -namespace { - static se::StreamExecutor* GpuExecutor() { auto name = absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value()); - auto* platform = se::MultiPlatformManager::PlatformWithName(name).value(); + auto* platform = se::PlatformManager::PlatformWithName(name).value(); return platform->ExecutorForDevice(0).value(); } -Thunk::ExecutableSource ExecutableSource() { +static Thunk::ExecutableSource ExecutableSource() { Thunk::ExecutableSource source = { #if defined(GOOGLE_CUDA) /*text=*/se::gpu::internal::kAddI32Kernel, @@ -73,7 +75,7 @@ Thunk::ExecutableSource ExecutableSource() { return source; } -KernelArgsPacking CreateDefaultArgsPacking() { +static KernelArgsPacking CreateDefaultArgsPacking() { using Packed = absl::StatusOr>; return [=](const se::Kernel& kernel, const se::KernelArgs& args) -> Packed { @@ -84,14 +86,22 @@ KernelArgsPacking CreateDefaultArgsPacking() { }; } -} // namespace +// Some of the tests rely on CUDA 12.3+ features. +static bool IsAtLeastCuda12300() { +#if defined(TENSORFLOW_USE_ROCM) + return false; +#endif +#if CUDA_VERSION >= 12030 + return true; +#endif + return false; +} TEST(CommandBufferThunkTest, MemcpyCmd) { se::StreamExecutor* executor = GpuExecutor(); se::Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); + TF_ASSERT_OK(stream.Initialize()); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; @@ -100,8 +110,8 @@ TEST(CommandBufferThunkTest, MemcpyCmd) { se::DeviceMemory a = executor->AllocateArray(length, 0); se::DeviceMemory b = executor->AllocateArray(length, 0); - stream.ThenMemset32(&a, 42, byte_length); - stream.ThenMemZero(&b, byte_length); + TF_ASSERT_OK(stream.Memset32(&a, 42, byte_length)); + TF_ASSERT_OK(stream.MemZero(&b, byte_length)); // Prepare buffer allocations for recording command buffer. BufferAllocation alloc_a(/*index=*/0, byte_length, /*color=*/0); @@ -129,12 +139,12 @@ TEST(CommandBufferThunkTest, MemcpyCmd) { // Copy `b` data back to host. std::vector dst(4, 0); - stream.ThenMemcpy(dst.data(), b, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), b, byte_length)); ASSERT_EQ(dst, std::vector(4, 42)); // Try to update the command buffer with the same buffers. - stream.ThenMemZero(&b, byte_length); + TF_ASSERT_OK(stream.MemZero(&b, byte_length)); // Thunk execution should automatically update underlying command buffer. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); @@ -142,7 +152,7 @@ TEST(CommandBufferThunkTest, MemcpyCmd) { // Copy `b` data back to host. std::fill(dst.begin(), dst.end(), 0); - stream.ThenMemcpy(dst.data(), b, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), b, byte_length)); ASSERT_EQ(dst, std::vector(4, 42)); } @@ -151,15 +161,14 @@ TEST(CommandBufferThunkTest, MemzeroCmd) { se::StreamExecutor* executor = GpuExecutor(); se::Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); + TF_ASSERT_OK(stream.Initialize()); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; // Prepare arguments: a=42 se::DeviceMemory a = executor->AllocateArray(length, 0); - stream.ThenMemset32(&a, 42, byte_length); + TF_ASSERT_OK(stream.Memset32(&a, 42, byte_length)); // Prepare buffer allocations for recording command buffer. BufferAllocation alloc_a(/*index=*/0, byte_length, /*color=*/0); @@ -184,7 +193,7 @@ TEST(CommandBufferThunkTest, MemzeroCmd) { // Copy `a` data back to host. std::vector dst(4, 0); - stream.ThenMemcpy(dst.data(), a, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), a, byte_length)); ASSERT_EQ(dst, std::vector(4, 0)); } @@ -193,8 +202,7 @@ TEST(CommandBufferThunkTest, Memset32Cmd) { se::StreamExecutor* executor = GpuExecutor(); se::Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); + TF_ASSERT_OK(stream.Initialize()); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; @@ -202,7 +210,7 @@ TEST(CommandBufferThunkTest, Memset32Cmd) { // Prepare arguments: a=42 se::DeviceMemory a = executor->AllocateArray(length, 0); - stream.ThenMemset32(&a, 42, byte_length); + TF_ASSERT_OK(stream.Memset32(&a, 42, byte_length)); // Prepare buffer allocations for recording command buffer. BufferAllocation alloc_a(/*index=*/0, byte_length, /*color=*/0); @@ -227,7 +235,7 @@ TEST(CommandBufferThunkTest, Memset32Cmd) { // Copy `a` data back to host. std::vector dst(4, 0); - stream.ThenMemcpy(dst.data(), a, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), a, byte_length)); ASSERT_EQ(dst, std::vector(4, 84)); } @@ -244,8 +252,7 @@ TEST(CommandBufferThunkTest, MemallocFreeCmdSameThunk) { se::StreamExecutor* executor = GpuExecutor(); se::Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); + TF_ASSERT_OK(stream.Initialize()); // Prepare arguments: int64_t length = 4; @@ -270,7 +277,7 @@ TEST(CommandBufferThunkTest, MemallocFreeCmdSameThunk) { // Prepare arguments: a=42, b=0 se::DeviceMemory a = executor->AllocateArray(length, 0); - stream.ThenMemset32(&a, 42, byte_length); + TF_ASSERT_OK(stream.Memset32(&a, 42, byte_length)); se::DeviceMemory b(se::DeviceMemoryBase( reinterpret_cast(BufferAllocations::kExternalAllocationMarker), @@ -293,8 +300,8 @@ TEST(CommandBufferThunkTest, MemallocFreeCmdSameThunk) { // Copy `b` data back to host. std::vector dst(4, 0); - stream.ThenMemcpy(dst.data(), allocations.GetMutableDeviceAddress(2), - byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), allocations.GetMutableDeviceAddress(2), + byte_length)); ASSERT_EQ(dst, std::vector(4, 42)); } @@ -310,8 +317,7 @@ TEST(CommandBufferThunkTest, MemallocFreeCmdAcrossThunk) { se::StreamExecutor* executor = GpuExecutor(); se::Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); + TF_ASSERT_OK(stream.Initialize()); // Prepare arguments: int64_t length = 4; @@ -335,7 +341,7 @@ TEST(CommandBufferThunkTest, MemallocFreeCmdAcrossThunk) { // Prepare arguments: a=42, b=0 se::DeviceMemory a = executor->AllocateArray(length, 0); - stream.ThenMemset32(&a, 42, byte_length); + TF_ASSERT_OK(stream.Memset32(&a, 42, byte_length)); se::DeviceMemory b(se::DeviceMemoryBase( reinterpret_cast(BufferAllocations::kExternalAllocationMarker), byte_length)); @@ -367,8 +373,8 @@ TEST(CommandBufferThunkTest, MemallocFreeCmdAcrossThunk) { // Copy `c` data back to host. std::vector dst(4, 0); - stream.ThenMemcpy(dst.data(), allocations.GetMutableDeviceAddress(2), - byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), allocations.GetMutableDeviceAddress(2), + byte_length)); ASSERT_EQ(dst, std::vector(4, 42)); } @@ -377,8 +383,7 @@ TEST(CommandBufferThunkTest, LaunchCmd) { se::StreamExecutor* executor = GpuExecutor(); se::Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); + TF_ASSERT_OK(stream.Initialize()); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; @@ -387,8 +392,8 @@ TEST(CommandBufferThunkTest, LaunchCmd) { se::DeviceMemory a = executor->AllocateArray(length, 0); se::DeviceMemory b = executor->AllocateArray(length, 0); - stream.ThenMemset32(&a, 42, byte_length); - stream.ThenMemZero(&b, byte_length); + TF_ASSERT_OK(stream.Memset32(&a, 42, byte_length)); + TF_ASSERT_OK(stream.MemZero(&b, byte_length)); // Prepare buffer allocations for recording command buffer. BufferAllocation alloc_a(/*index=*/0, byte_length, /*color=*/0); @@ -424,13 +429,13 @@ TEST(CommandBufferThunkTest, LaunchCmd) { // Copy `b` data back to host. std::vector dst(4, 0); - stream.ThenMemcpy(dst.data(), b, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), b, byte_length)); ASSERT_EQ(dst, std::vector(4, 42 + 42)); // Prepare buffer allocation for updating command buffer: c=0 se::DeviceMemory c = executor->AllocateArray(length, 0); - stream.ThenMemZero(&c, byte_length); + TF_ASSERT_OK(stream.MemZero(&c, byte_length)); // Update buffer allocation #1 to buffer `c`. allocations = BufferAllocations({a, c}, 0, executor->GetAllocator()); @@ -441,12 +446,12 @@ TEST(CommandBufferThunkTest, LaunchCmd) { // Copy `c` data back to host. std::fill(dst.begin(), dst.end(), 0); - stream.ThenMemcpy(dst.data(), c, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), c, byte_length)); ASSERT_EQ(dst, std::vector(4, 42 + 42)); // Try to update the command buffer with the same buffers. - stream.ThenMemZero(&c, byte_length); + TF_ASSERT_OK(stream.MemZero(&c, byte_length)); // Thunk execution should automatically update underlying command buffer. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); @@ -454,7 +459,7 @@ TEST(CommandBufferThunkTest, LaunchCmd) { // Copy `c` data back to host. std::fill(dst.begin(), dst.end(), 0); - stream.ThenMemcpy(dst.data(), c, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), c, byte_length)); ASSERT_EQ(dst, std::vector(4, 42 + 42)); } @@ -463,8 +468,7 @@ TEST(CommandBufferThunkTest, CustomAddKernelLaunchCmd) { se::StreamExecutor* executor = GpuExecutor(); se::Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); + TF_ASSERT_OK(stream.Initialize()); auto packing = CreateDefaultArgsPacking(); @@ -482,8 +486,8 @@ TEST(CommandBufferThunkTest, CustomAddKernelLaunchCmd) { se::DeviceMemory a = executor->AllocateArray(length, 0); se::DeviceMemory b = executor->AllocateArray(length, 0); - stream.ThenMemset32(&a, 42, byte_length); - stream.ThenMemZero(&b, byte_length); + TF_ASSERT_OK(stream.Memset32(&a, 42, byte_length)); + TF_ASSERT_OK(stream.MemZero(&b, byte_length)); // Prepare buffer allocations for recording command buffer. BufferAllocation alloc_a(/*index=*/0, byte_length, /*color=*/0); @@ -519,13 +523,13 @@ TEST(CommandBufferThunkTest, CustomAddKernelLaunchCmd) { // Copy `b` data back to host. std::vector dst(4, 0); - stream.ThenMemcpy(dst.data(), b, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), b, byte_length)); ASSERT_EQ(dst, std::vector(4, 42 + 42)); // Prepare buffer allocation for updating command buffer: c=0 se::DeviceMemory c = executor->AllocateArray(length, 0); - stream.ThenMemZero(&c, byte_length); + TF_ASSERT_OK(stream.MemZero(&c, byte_length)); // Update buffer allocation #1 to buffer `c`. allocations = BufferAllocations({a, c}, 0, executor->GetAllocator()); @@ -536,12 +540,12 @@ TEST(CommandBufferThunkTest, CustomAddKernelLaunchCmd) { // Copy `c` data back to host. std::fill(dst.begin(), dst.end(), 0); - stream.ThenMemcpy(dst.data(), c, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), c, byte_length)); ASSERT_EQ(dst, std::vector(4, 42 + 42)); // Try to update the command buffer with the same buffers. - stream.ThenMemZero(&c, byte_length); + TF_ASSERT_OK(stream.MemZero(&c, byte_length)); // Thunk execution should automatically update underlying command buffer. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); @@ -549,20 +553,20 @@ TEST(CommandBufferThunkTest, CustomAddKernelLaunchCmd) { // Copy `c` data back to host. std::fill(dst.begin(), dst.end(), 0); - stream.ThenMemcpy(dst.data(), c, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), c, byte_length)); ASSERT_EQ(dst, std::vector(4, 42 + 42)); } TEST(CommandBufferThunkTest, GemmCmd) { -#if !defined(TENSORFLOW_USE_ROCM) && CUDA_VERSION < 12030 - GTEST_SKIP() << "Command buffer tracing is not supported"; -#endif + if (!IsAtLeastCuda12300()) { + GTEST_SKIP() << "CUDA graph conditionals are not supported"; + } + se::StreamExecutor* executor = GpuExecutor(); se::Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); + TF_ASSERT_OK(stream.Initialize()); int64_t lhs_length = sizeof(float) * 2 * 4; int64_t rhs_length = sizeof(float) * 4 * 3; @@ -577,18 +581,18 @@ TEST(CommandBufferThunkTest, GemmCmd) { // 1.0, 1.0, 1.0] se::DeviceMemory lhs = executor->AllocateArray(2 * 4); std::vector lhs_arr{1, 2, 3, 4, 5, 6, 7, 8}; - stream.ThenMemcpy(&lhs, lhs_arr.data(), lhs_length); + TF_ASSERT_OK(stream.Memcpy(&lhs, lhs_arr.data(), lhs_length)); se::DeviceMemory rhs = executor->AllocateArray(4 * 3); std::vector rhs_arr(12, 1); - stream.ThenMemcpy(&rhs, rhs_arr.data(), rhs_length); + TF_ASSERT_OK(stream.Memcpy(&rhs, rhs_arr.data(), rhs_length)); se::DeviceMemory out = executor->AllocateArray(2 * 3); - stream.ThenMemZero(&out, out_length); + TF_ASSERT_OK(stream.MemZero(&out, out_length)); se::DeviceMemory workspace = executor->AllocateArray(1024 * 1024); - stream.ThenMemZero(&workspace, 1024 * 1024); + TF_ASSERT_OK(stream.MemZero(&workspace, 1024 * 1024)); // Prepare buffer allocations for recording command buffer. BufferAllocation alloc_lhs(/*index=*/0, lhs_length, /*color=*/0); @@ -633,13 +637,13 @@ TEST(CommandBufferThunkTest, GemmCmd) { // Copy `out` data back to host. std::vector dst(6, 0); - stream.ThenMemcpy(dst.data(), out, out_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), out, out_length)); ASSERT_EQ(dst, std::vector({10, 10, 10, 26, 26, 26})); // Prepare buffer allocation for updating command buffer. se::DeviceMemory updated_out = executor->AllocateArray(2 * 3); - stream.ThenMemZero(&updated_out, out_length); + TF_ASSERT_OK(stream.MemZero(&updated_out, out_length)); // Update buffer allocation to updated `out` buffer. allocations = BufferAllocations({lhs, rhs, updated_out, workspace}, 0, @@ -651,12 +655,12 @@ TEST(CommandBufferThunkTest, GemmCmd) { // Copy `updated_out` data back to host. std::fill(dst.begin(), dst.end(), 0); - stream.ThenMemcpy(dst.data(), updated_out, out_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), updated_out, out_length)); ASSERT_EQ(dst, std::vector({10, 10, 10, 26, 26, 26})); // Try to update the command buffer with the same buffers. - stream.ThenMemZero(&updated_out, out_length); + TF_ASSERT_OK(stream.MemZero(&updated_out, out_length)); // Thunk execution should automatically update underlying command buffer. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); @@ -664,7 +668,7 @@ TEST(CommandBufferThunkTest, GemmCmd) { // Copy `updated_out` data back to host. std::fill(dst.begin(), dst.end(), 0); - stream.ThenMemcpy(dst.data(), updated_out, out_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), updated_out, out_length)); ASSERT_EQ(dst, std::vector({10, 10, 10, 26, 26, 26})); } @@ -673,8 +677,7 @@ TEST(CommandBufferThunkTest, MultipleLaunchCmd) { se::StreamExecutor* executor = GpuExecutor(); se::Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); + TF_ASSERT_OK(stream.Initialize()); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; @@ -685,10 +688,10 @@ TEST(CommandBufferThunkTest, MultipleLaunchCmd) { se::DeviceMemory c = executor->AllocateArray(length, 0); se::DeviceMemory d = executor->AllocateArray(length, 0); - stream.ThenMemset32(&a, 42, byte_length); - stream.ThenMemZero(&b, byte_length); - stream.ThenMemset32(&c, 21, byte_length); - stream.ThenMemZero(&d, byte_length); + TF_ASSERT_OK(stream.Memset32(&a, 42, byte_length)); + TF_ASSERT_OK(stream.MemZero(&b, byte_length)); + TF_ASSERT_OK(stream.Memset32(&c, 21, byte_length)); + TF_ASSERT_OK(stream.MemZero(&d, byte_length)); // Prepare buffer allocations for recording command buffer. BufferAllocation alloc_a(/*index=*/0, byte_length, /*color=*/0); @@ -732,12 +735,12 @@ TEST(CommandBufferThunkTest, MultipleLaunchCmd) { // Copy `b` data back to host. std::vector dst(4, 0); - stream.ThenMemcpy(dst.data(), b, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), b, byte_length)); ASSERT_EQ(dst, std::vector(4, 42 + 42)); // Copy `d` data back to host. std::fill(dst.begin(), dst.end(), 0); - stream.ThenMemcpy(dst.data(), d, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), d, byte_length)); ASSERT_EQ(dst, std::vector(4, 21 + 21)); BufferAllocation alloc_e(/*index=*/3, byte_length, /*color=*/0); @@ -745,7 +748,7 @@ TEST(CommandBufferThunkTest, MultipleLaunchCmd) { // Prepare buffer allocation for updating command buffer: e=0 se::DeviceMemory e = executor->AllocateArray(length, 0); - stream.ThenMemZero(&e, byte_length); + TF_ASSERT_OK(stream.MemZero(&e, byte_length)); // Update buffer allocation #1 to buffer `c`. allocations = BufferAllocations({a, b, c, e}, 0, executor->GetAllocator()); @@ -756,16 +759,16 @@ TEST(CommandBufferThunkTest, MultipleLaunchCmd) { // Copy `b` data back to host. std::fill(dst.begin(), dst.end(), 0); - stream.ThenMemcpy(dst.data(), b, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), b, byte_length)); ASSERT_EQ(dst, std::vector(4, 42 + 42)); // Copy `e` data back to host. std::fill(dst.begin(), dst.end(), 0); - stream.ThenMemcpy(dst.data(), e, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), e, byte_length)); ASSERT_EQ(dst, std::vector(4, 21 + 21)); // Try to update the command buffer with the same buffers. - stream.ThenMemZero(&e, byte_length); + TF_ASSERT_OK(stream.MemZero(&e, byte_length)); // Thunk execution should automatically update underlying command buffer. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); @@ -773,24 +776,24 @@ TEST(CommandBufferThunkTest, MultipleLaunchCmd) { // Copy `b` data back to host. std::fill(dst.begin(), dst.end(), 0); - stream.ThenMemcpy(dst.data(), b, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), b, byte_length)); ASSERT_EQ(dst, std::vector(4, 42 + 42)); // Copy `e` data back to host. std::fill(dst.begin(), dst.end(), 0); - stream.ThenMemcpy(dst.data(), e, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), e, byte_length)); ASSERT_EQ(dst, std::vector(4, 21 + 21)); } TEST(CommandBufferThunkTest, IfCmd) { - se::StreamExecutor* executor = GpuExecutor(); - if (!se::CommandBuffer::SupportsConditionalCommands(executor->platform())) { + if (!IsAtLeastCuda12300()) { GTEST_SKIP() << "CUDA graph conditionals are not supported"; } + se::StreamExecutor* executor = GpuExecutor(); + se::Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); + TF_ASSERT_OK(stream.Initialize()); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; @@ -801,9 +804,9 @@ TEST(CommandBufferThunkTest, IfCmd) { se::DeviceMemory b = executor->AllocateArray(length, 0); constexpr bool kTrue = true; - stream.ThenMemcpy(&pred, &kTrue, 1); - stream.ThenMemset32(&a, 42, byte_length); - stream.ThenMemZero(&b, byte_length); + TF_ASSERT_OK(stream.Memcpy(&pred, &kTrue, 1)); + TF_ASSERT_OK(stream.Memset32(&a, 42, byte_length)); + TF_ASSERT_OK(stream.MemZero(&b, byte_length)); // Prepare buffer allocations for recording command buffer. BufferAllocation alloc_p(/*index=*/0, 1, /*color=*/0); @@ -846,13 +849,13 @@ TEST(CommandBufferThunkTest, IfCmd) { // Copy `b` data back to host. std::vector dst(4, 0); - stream.ThenMemcpy(dst.data(), b, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), b, byte_length)); ASSERT_EQ(dst, std::vector(4, 42 + 42)); // Prepare buffer allocation for updating command buffer: c=0 se::DeviceMemory c = executor->AllocateArray(length, 0); - stream.ThenMemZero(&c, byte_length); + TF_ASSERT_OK(stream.MemZero(&c, byte_length)); // Update buffer allocation #2 to buffer `c`. allocations = BufferAllocations({pred, a, c}, 0, executor->GetAllocator()); @@ -863,20 +866,20 @@ TEST(CommandBufferThunkTest, IfCmd) { // Copy `c` data back to host. std::fill(dst.begin(), dst.end(), 0); - stream.ThenMemcpy(dst.data(), c, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), c, byte_length)); ASSERT_EQ(dst, std::vector(4, 42 + 42)); } TEST(CommandBufferThunkTest, IfElseCmd) { - se::StreamExecutor* executor = GpuExecutor(); - if (!se::CommandBuffer::SupportsConditionalCommands(executor->platform())) { + if (!IsAtLeastCuda12300()) { GTEST_SKIP() << "CUDA graph conditionals are not supported"; } + se::StreamExecutor* executor = GpuExecutor(); + se::Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); + TF_ASSERT_OK(stream.Initialize()); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; @@ -887,9 +890,9 @@ TEST(CommandBufferThunkTest, IfElseCmd) { se::DeviceMemory b = executor->AllocateArray(length, 0); constexpr bool kTrue = true; - stream.ThenMemcpy(&pred, &kTrue, 1); - stream.ThenMemset32(&a, 42, byte_length); - stream.ThenMemZero(&b, byte_length); + TF_ASSERT_OK(stream.Memcpy(&pred, &kTrue, 1)); + TF_ASSERT_OK(stream.Memset32(&a, 42, byte_length)); + TF_ASSERT_OK(stream.MemZero(&b, byte_length)); // Prepare buffer allocations for recording command buffer. BufferAllocation alloc_p(/*index=*/0, 1, /*color=*/0); @@ -944,30 +947,30 @@ TEST(CommandBufferThunkTest, IfElseCmd) { // Copy `b` data back to host. std::vector dst(4, 0); - stream.ThenMemcpy(dst.data(), b, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), b, byte_length)); ASSERT_EQ(dst, std::vector(4, 42 + 42)); // Change branch to `else` and check that it updated the `b` buffer. constexpr bool kFalse = false; - stream.ThenMemcpy(&pred, &kFalse, 1); + TF_ASSERT_OK(stream.Memcpy(&pred, &kFalse, 1)); TF_ASSERT_OK(thunk.ExecuteOnStream(params)); TF_ASSERT_OK(stream.BlockHostUntilDone()); - stream.ThenMemcpy(dst.data(), b, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), b, byte_length)); ASSERT_EQ(dst, std::vector(4, 2 * (42 + 42))); } TEST(CommandBufferThunkTest, CaseCmd) { - se::StreamExecutor* executor = GpuExecutor(); - if (!se::CommandBuffer::SupportsConditionalCommands(executor->platform())) { + if (!IsAtLeastCuda12300()) { GTEST_SKIP() << "CUDA graph conditionals are not supported"; } + se::StreamExecutor* executor = GpuExecutor(); + se::Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); + TF_ASSERT_OK(stream.Initialize()); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; @@ -977,9 +980,9 @@ TEST(CommandBufferThunkTest, CaseCmd) { se::DeviceMemory a = executor->AllocateArray(length, 0); se::DeviceMemory b = executor->AllocateArray(length, 0); - stream.ThenMemset32(&index, 0, sizeof(int32_t)); - stream.ThenMemset32(&a, 42, byte_length); - stream.ThenMemZero(&b, byte_length); + TF_ASSERT_OK(stream.Memset32(&index, 0, sizeof(int32_t))); + TF_ASSERT_OK(stream.Memset32(&a, 42, byte_length)); + TF_ASSERT_OK(stream.MemZero(&b, byte_length)); // Prepare buffer allocations for recording command buffer. BufferAllocation alloc_i(/*index=*/0, 1, /*color=*/0); @@ -1032,29 +1035,29 @@ TEST(CommandBufferThunkTest, CaseCmd) { // Copy `b` data back to host. std::vector dst(4, 0); - stream.ThenMemcpy(dst.data(), b, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), b, byte_length)); ASSERT_EQ(dst, std::vector(4, 42 + 42)); // Change `index` to `1` and check that it updated the `b` buffer. - stream.ThenMemset32(&index, 1, sizeof(int32_t)); + TF_ASSERT_OK(stream.Memset32(&index, 1, sizeof(int32_t))); TF_ASSERT_OK(thunk.ExecuteOnStream(params)); TF_ASSERT_OK(stream.BlockHostUntilDone()); - stream.ThenMemcpy(dst.data(), b, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), b, byte_length)); ASSERT_EQ(dst, std::vector(4, 2 * (42 + 42))); } TEST(CommandBufferThunkTest, ForCmd) { - se::StreamExecutor* executor = GpuExecutor(); - if (!se::CommandBuffer::SupportsConditionalCommands(executor->platform())) { + if (!IsAtLeastCuda12300()) { GTEST_SKIP() << "CUDA graph conditionals are not supported"; } + se::StreamExecutor* executor = GpuExecutor(); + se::Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); + TF_ASSERT_OK(stream.Initialize()); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; @@ -1064,9 +1067,9 @@ TEST(CommandBufferThunkTest, ForCmd) { se::DeviceMemory a = executor->AllocateArray(length, 0); se::DeviceMemory b = executor->AllocateArray(length, 0); - stream.ThenMemset32(&loop_cnt, 0, sizeof(int32_t)); - stream.ThenMemset32(&a, 1, byte_length); - stream.ThenMemZero(&b, byte_length); + TF_ASSERT_OK(stream.Memset32(&loop_cnt, 0, sizeof(int32_t))); + TF_ASSERT_OK(stream.Memset32(&a, 1, byte_length)); + TF_ASSERT_OK(stream.MemZero(&b, byte_length)); // Prepare buffer allocations for recording command buffer. BufferAllocation alloc_cnt(/*index=*/0, 1, /*color=*/0); @@ -1110,7 +1113,7 @@ TEST(CommandBufferThunkTest, ForCmd) { // Copy `b` data back to host. std::vector dst(4, 0); - stream.ThenMemcpy(dst.data(), b, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), b, byte_length)); ASSERT_EQ(dst, std::vector(4, 10)); } diff --git a/third_party/xla/xla/service/gpu/runtime/concurrent_region.cc b/third_party/xla/xla/service/gpu/runtime/concurrent_region.cc deleted file mode 100644 index e1f2d15567b25c..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/concurrent_region.cc +++ /dev/null @@ -1,167 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/concurrent_region.h" - -#include -#include - -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/stream_pool.h" -#include "xla/stream_executor/stream.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace gpu { - -//===----------------------------------------------------------------------===// -// Definitions for ConcurrentRegionStatus. -//===----------------------------------------------------------------------===// - -ConcurrentRegionStatus::ConcurrentRegionStatus( - const ServiceExecutableRunOptions* run_options, int num_borrowed_streams) - : num_borrowed_streams_(num_borrowed_streams), - run_options_(run_options), - stream_index_(0), - capture_stream_(nullptr) {} - -ConcurrentRegionStatus::~ConcurrentRegionStatus() { - DCHECK(!IsInConcurrentRegion()); -} - -// Assign a stream in a round-robin fashion. Either the capture stream or one of -// the borrowed streams is returned. -se::Stream* ConcurrentRegionStatus::GetNextStream() { - DCHECK(IsInConcurrentRegion()); - if (borrowed_streams_.empty()) { - return nullptr; - } - - int index = stream_index_ % (borrowed_streams_.size() + 1); - stream_index_++; - - if (index == 0) { - return capture_stream_; - } - - return borrowed_streams_[index - 1].get(); -} - -absl::StatusOr ConcurrentRegionStatus::GetStream(int index) { - DCHECK(IsInConcurrentRegion()); - - if (index < 0 || index >= region_size_) { - return absl::OutOfRangeError("Invalid stream index"); - } - - if (index == 0) { - return capture_stream_; - } - - return borrowed_streams_[index - 1].get(); -} - -absl::Status ConcurrentRegionStatus::StartConcurrentRegion( - se::Stream* capture_stream, int64_t size) { - if (disabled_) { - return absl::OkStatus(); - } - - DCHECK(!IsInConcurrentRegion()); - se::StreamExecutor* executor = run_options_->stream()->parent(); - - // Stream borrowing should only happen in the first call to this function. - if (borrowed_streams_.empty()) { - TF_ASSIGN_OR_RETURN(std::vector borrowed_streams, - run_options_->BorrowStreams(executor->device_ordinal(), - num_borrowed_streams_)); - for (StreamPool::Ptr& stream : borrowed_streams) { - borrowed_streams_.push_back(std::move(stream)); - } - } - - // Switch borrowed streams into capture mode. We only synchronize enough - // streams to run the kernels. - for (int i = 0; i < std::min(size - 1, num_borrowed_streams_); ++i) { - borrowed_streams_[i]->ThenWaitFor(capture_stream); - } - - region_size_ = size; - capture_stream_ = capture_stream; - return absl::OkStatus(); -} - -void ConcurrentRegionStatus::EndConcurrentRegion() { - if (disabled_) { - return; - } - - DCHECK(IsInConcurrentRegion()); - - // Synchronize main capture stream with all borrowed streams in capture mode. - for (int i = 0; i < std::min(region_size_ - 1, num_borrowed_streams_); - ++i) { - capture_stream_->ThenWaitFor(borrowed_streams_[i].get()); - } - - stream_index_ = 0; - capture_stream_ = nullptr; -} - -bool ConcurrentRegionStatus::IsInConcurrentRegion() { - return capture_stream_ != nullptr; -} - -//===----------------------------------------------------------------------===// -// Define custom calls that mark the concurrent region in CUDA graphs. -//===----------------------------------------------------------------------===// - -using xla::runtime::CustomCall; - -static absl::Status RegionBegin(const ServiceExecutableRunOptions* run_options, - ConcurrentRegionStatus* region_status, - int64_t size) { - se::Stream* capture_stream = run_options->stream(); - return region_status->StartConcurrentRegion(capture_stream, size); -} - -static absl::Status RegionEnd(ConcurrentRegionStatus* region_status) { - region_status->EndConcurrentRegion(); - return absl::OkStatus(); -} - -//===----------------------------------------------------------------------===// - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - Begin, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.concurrent_region.begin") - .UserData() - .UserData() - .Attr("size")); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL(End, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.concurrent_region.end") - .UserData()); - -void RegisterConcurrentRegionCustomCalls( - runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.concurrent_region.begin", Begin); - registry.Register("xla.gpu.concurrent_region.end", End); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/concurrent_region.h b/third_party/xla/xla/service/gpu/runtime/concurrent_region.h deleted file mode 100644 index 2591d6d31a66ab..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/concurrent_region.h +++ /dev/null @@ -1,75 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_CONCURRENT_REGION_H_ -#define XLA_SERVICE_GPU_RUNTIME_CONCURRENT_REGION_H_ - -#include - -#include "xla/runtime/custom_call_registry.h" -#include "xla/service/service_executable_run_options.h" - -namespace xla { -namespace gpu { - -// Registers XLA Gpu runtime kernel launch custom calls. -void RegisterConcurrentRegionCustomCalls( - runtime::DirectCustomCallRegistry& registry); - -// The state to keep track of the information regarding concurrent regions -// between custom calls. -class ConcurrentRegionStatus { - public: - explicit ConcurrentRegionStatus( - const ServiceExecutableRunOptions* run_options, - int num_borrowed_streams = 10); - - ~ConcurrentRegionStatus(); - - absl::Status StartConcurrentRegion(se::Stream* capture_stream, int64_t size); - void EndConcurrentRegion(); - - // Temporarily disable concurrent execution when we run GPU graphs op-by-op. - // If disabled_ is set to true, StartConcurrentRegion will become an no-op and - // IsInConcurrentRegion always returns false. - void DisableConcurrentRegion() { disabled_ = true; } - void EnableConcurrentRegion() { disabled_ = false; } - - // Get a stream on which the concurrent-executable kernel runs. It returns a - // different stream each time to avoid building dependencies in the CUDA - // graph. - se::Stream* GetNextStream(); - - absl::StatusOr GetStream(int index); - - bool IsInConcurrentRegion(); - - private: - const int num_borrowed_streams_; - std::vector borrowed_streams_; - const ServiceExecutableRunOptions* run_options_; - - bool disabled_ = false; - int32_t stream_index_; - - // It is set to nullptr if not in a concurrent region. - se::Stream* capture_stream_; - int region_size_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_CONCURRENT_REGION_H_ diff --git a/third_party/xla/xla/service/gpu/runtime3/conditional_thunk.cc b/third_party/xla/xla/service/gpu/runtime/conditional_thunk.cc similarity index 56% rename from third_party/xla/xla/service/gpu/runtime3/conditional_thunk.cc rename to third_party/xla/xla/service/gpu/runtime/conditional_thunk.cc index 65df0cafe861e4..fb5a588fc21be0 100644 --- a/third_party/xla/xla/service/gpu/runtime3/conditional_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/conditional_thunk.cc @@ -13,15 +13,25 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime3/conditional_thunk.h" +#include "xla/service/gpu/runtime/conditional_thunk.h" +#include #include +#include +#include #include "absl/status/status.h" -#include "xla/hlo/ir/hlo_instruction.h" +#include "absl/synchronization/mutex.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/thunk.h" +#include "xla/service/gpu/variant_visitor.h" #include "xla/status.h" +#include "xla/status_macros.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/memory_allocation.h" #include "xla/util.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -55,6 +65,16 @@ absl::Status ConditionalThunk::Initialize(const InitializeParams& params) { for (auto& branch_thunk : config_.branch_thunks) { TF_RETURN_IF_ERROR(branch_thunk->Initialize(params)); } + + absl::MutexLock lock(&mutex_); + if (auto it = predicates_.find(params.executor); it == predicates_.end()) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr allocation, + params.executor->HostMemoryAllocate( + config_.branch_index_is_bool ? sizeof(bool) : sizeof(int32_t))); + predicates_.emplace(params.executor, std::move(allocation)); + } + return absl::OkStatus(); } @@ -62,28 +82,39 @@ absl::Status ConditionalThunk::ExecuteOnStream(const ExecuteParams& params) { auto& stream = *params.stream; // Copy the predicate value from device. - int32_t branch_index = -1; - bool pred = false; + auto branch_index_or_pred = [&]() -> std::variant { + absl::MutexLock lock(&mutex_); + se::StreamExecutor* executor = stream.parent(); + if (config_.branch_index_is_bool) { + return reinterpret_cast(predicates_.at(executor)->opaque()); + } else { + return reinterpret_cast(predicates_.at(executor)->opaque()); + } + }(); + se::DeviceMemoryBase branch_index_address = params.buffer_allocations->GetDeviceAddress(branch_index_buffer_index_); if (config_.branch_index_is_bool) { - stream.ThenMemcpy(&pred, branch_index_address, sizeof(bool)); + TF_RETURN_IF_ERROR(stream.Memcpy(std::get(branch_index_or_pred), + branch_index_address, sizeof(bool))); } else { - stream.ThenMemcpy(&branch_index, branch_index_address, sizeof(int32_t)); + TF_RETURN_IF_ERROR(stream.Memcpy(std::get(branch_index_or_pred), + branch_index_address, sizeof(int32_t))); } - absl::Status block_status = stream.BlockHostUntilDone(); - if (!block_status.ok()) { + if (absl::Status blocked = stream.BlockHostUntilDone(); !blocked.ok()) { return Internal("Failed to retrieve branch_index value on stream %p: %s.", - &stream, block_status.message()); + &stream, blocked.message()); } - if (config_.branch_index_is_bool) { - branch_index = pred ? 0 : 1; - } else { - // Handle default scenario for branch_index not in [0, num_branches). - if (branch_index < 0 || branch_index >= config_.branch_count) { - branch_index = config_.branch_count - 1; - } + + int32_t branch_index = std::visit( + VariantVisitor{[](int32_t* branch_index) { return *branch_index; }, + [](bool* pred) { return *pred ? 0 : 1; }}, + branch_index_or_pred); + + // Handle default scenario for branch_index not in [0, num_branches). + if (branch_index < 0 || branch_index >= config_.branch_count) { + branch_index = config_.branch_count - 1; } // Execute the branch computation corresponding to the value of branch_index. diff --git a/third_party/xla/xla/service/gpu/runtime3/conditional_thunk.h b/third_party/xla/xla/service/gpu/runtime/conditional_thunk.h similarity index 79% rename from third_party/xla/xla/service/gpu/runtime3/conditional_thunk.h rename to third_party/xla/xla/service/gpu/runtime/conditional_thunk.h index aca9cac32a64e4..a7a7d3be0e0e37 100644 --- a/third_party/xla/xla/service/gpu/runtime3/conditional_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/conditional_thunk.h @@ -13,20 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME3_CONDITIONAL_THUNK_H_ -#define XLA_SERVICE_GPU_RUNTIME3_CONDITIONAL_THUNK_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_CONDITIONAL_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_CONDITIONAL_THUNK_H_ #include #include #include +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/runtime3/sequential_thunk.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/runtime/sequential_thunk.h" #include "xla/service/gpu/thunk.h" #include "xla/status.h" +#include "xla/stream_executor/memory_allocation.h" #include "xla/stream_executor/stream_executor.h" namespace xla { @@ -72,9 +75,15 @@ class ConditionalThunk : public Thunk { private: const ConditionalThunkConfig config_; const BufferAllocation::Slice branch_index_buffer_index_; + + // Pinned host memory for transferring predicate value from device to host. + absl::Mutex mutex_; + absl::flat_hash_map> + predicates_ ABSL_GUARDED_BY(mutex_); }; } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_RUNTIME3_CONDITIONAL_THUNK_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_CONDITIONAL_THUNK_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/conv.cc b/third_party/xla/xla/service/gpu/runtime/conv.cc deleted file mode 100644 index 166123e81c7bb2..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/conv.cc +++ /dev/null @@ -1,689 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/conv.h" - -#include -#include -#include -#include -#include - -#include "llvm/ADT/Sequence.h" -#include "xla/mlir/runtime/transforms/custom_call_encoding.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/service/gpu/gpu_asm_opts_util.h" -#include "xla/service/gpu/gpu_conv_runner.h" -#include "xla/service/gpu/non_atomically_upgradeable_rw_lock.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/status.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/device_memory_allocator.h" -#include "xla/translate/mhlo_to_hlo/attribute_exporter.h" -#include "xla/xla.pb.h" - -#if GOOGLE_CUDA -#include "xla/service/gpu/autotuner_util.h" -#include "xla/service/gpu/conv_algorithm_picker.h" -#endif - -namespace xla { - -using xla::runtime::AggregateAttrDef; -using xla::runtime::AggregateAttrEncoding; -using xla::runtime::CustomCall; -using xla::runtime::EnumAttrEncoding; -using xla::runtime::FlatMemrefView; -using xla::runtime::State; -using xla::runtime::StridedMemrefView; -using xla::runtime::Tagged; - -namespace lmhlo_gpu = ::mlir::lmhlo_gpu; -namespace mhlo = ::mlir::mhlo; - -//===----------------------------------------------------------------------===// -// Structs for encoding convolution attributes defined in MHLO dialect. -//===----------------------------------------------------------------------===// - -namespace gpu { - -struct ConvDimensionNumbers { - int64_t input_batch_dim; - int64_t input_feature_dim; - absl::Span input_spatial_dims; - - int64_t kernel_in_feature_dim; - int64_t kernel_out_feature_dim; - absl::Span kernel_spatial_dims; - - int64_t output_batch_dim; - int64_t output_feature_dim; - absl::Span output_spatial_dims; -}; - -struct ConvBackendConfig { - int64_t algorithm; - bool tensor_ops_enabled; - bool is_cudnn_frontend; - bool is_cudnn_reordered_int8; - absl::Span knob_ids; - absl::Span knob_values; - absl::Span operand_0_layout; - absl::Span operand_1_layout; - absl::Span result_layout; - int64_t workspace_size; -}; - -} // namespace gpu - -//===----------------------------------------------------------------------===// -// Register convolution attributes decoding with the Xla runtime. -//===----------------------------------------------------------------------===// - -namespace runtime { - -XLA_RUNTIME_REGISTER_ENUM_ATTR_DECODING(se::dnn::ActivationMode); - -XLA_RUNTIME_REGISTER_AGGREGATE_ATTR_DECODING( - xla::gpu::ConvDimensionNumbers, - // --- input dimensions - AggregateMember("input_batch_dim"), - AggregateMember("input_feature_dim"), - AggregateMember>("input_spatial_dims"), - // --- kernel dimensions - AggregateMember("kernel_in_feature_dim"), - AggregateMember("kernel_out_feature_dim"), - AggregateMember>("kernel_spatial_dims"), - // --- output dimensions - AggregateMember("output_batch_dim"), - AggregateMember("output_feature_dim"), - AggregateMember>("output_spatial_dims")); - -XLA_RUNTIME_REGISTER_AGGREGATE_ATTR_DECODING( - xla::gpu::ConvBackendConfig, // - AggregateMember("algorithm"), - AggregateMember("tensor_ops_enabled"), - AggregateMember("is_cudnn_frontend"), - AggregateMember("is_cudnn_reordered_int8"), - AggregateMember>("knob_ids"), - AggregateMember>("knob_values"), - AggregateMember>("operand_0_layout"), - AggregateMember>("operand_1_layout"), - AggregateMember>("result_layout"), - AggregateMember("workspace_size")); - -} // namespace runtime - -//===----------------------------------------------------------------------===// -// Type names for encoded attributes. -//===----------------------------------------------------------------------===// - -namespace gpu { - -void RegisterConvTypeIdNames(runtime::TypeIDNameRegistry& registry) { - registry.Register>("__type_id_conv_dim_numbers"); - registry.Register>("__type_id_conv_backend_config"); -} - -//===----------------------------------------------------------------------===// -// Encoding from MHLO attributes to Xla runtime aggregate attributes. -//===----------------------------------------------------------------------===// - -// TODO(ezhulenev): We have to support enum encoding that can fail instead of -// always getting the value from returned StatusOr. -static auto EncodeConvActivation(lmhlo_gpu::Activation activation) { - return ConvertConvActivationMode(activation).value(); -} - -void PopulateConvAttrEncoding(runtime::CustomCallAttrEncodingSet& encoding) { - { // --- Encode `lmhlo_gpu::ActivationAttr`. - encoding - .Add>(EncodeConvActivation); - } - - { // --- Encode `mhlo::ConvDimensionNumbersAttr`. - using Attr = mhlo::ConvDimensionNumbersAttr; - encoding.Add>( - encoding, - AggregateAttrDef() - .Add("input_batch_dim", &Attr::getInputBatchDimension) - .Add("input_feature_dim", &Attr::getInputFeatureDimension) - .Add("input_spatial_dims", &Attr::getInputSpatialDimensions) - .Add("kernel_in_feature_dim", &Attr::getKernelInputFeatureDimension) - .Add("kernel_out_feature_dim", - &Attr::getKernelOutputFeatureDimension) - .Add("kernel_spatial_dims", &Attr::getKernelSpatialDimensions) - .Add("output_batch_dim", &Attr::getOutputBatchDimension) - .Add("output_feature_dim", &Attr::getOutputFeatureDimension) - .Add("output_spatial_dims", &Attr::getOutputSpatialDimensions)); - } - - { // --- Encode `lmhlo_gpu::ConvolutionBackendConfigAttr`. - using Attr = lmhlo_gpu::ConvolutionBackendConfigAttr; - encoding.Add>( - encoding, - AggregateAttrDef() - .Add("algorithm", &Attr::getAlgorithm) - .Add("tensor_ops_enabled", &Attr::getTensorOpsEnabled) - .Add("is_cudnn_frontend", &Attr::getIsCudnnFrontend) - .Add("is_cudnn_reordered_int8", &Attr::getIsCudnnReorderedInt8) - .Add("knob_ids", &Attr::getKnobIds) - .Add("knob_values", &Attr::getKnobValues) - .Add("operand_0_layout", &Attr::getOperand_0Layout) - .Add("operand_1_layout", &Attr::getOperand_1Layout) - .Add("result_layout", &Attr::getResultLayout) - .Add("workspace_size", &Attr::getWorkspaceSize)); - } -} - -//===----------------------------------------------------------------------===// -// Convolution runners caching. -//===----------------------------------------------------------------------===// - -StreamExecutorConvRunners* ConvRunners::operator()( - se::StreamExecutor* executor) { - absl::MutexLock lock(&mutex_); - return &runners_[executor]; -} - -//===----------------------------------------------------------------------===// -// Convolution custom call implementation. -//===----------------------------------------------------------------------===// - -namespace { - -struct Window { - absl::Span window_strides; - absl::Span padding; - absl::Span lhs_dilation; - absl::Span rhs_dilation; - absl::Span window_reversal; -}; - -struct ConvAttrs { - int64_t feature_group_count; - double result_scale; -}; - -struct FusedConvAttrs { - se::dnn::ActivationMode activation_mode; -}; - -struct SideInputAttrs { - double side_input_scale; -}; - -struct LeakyReluAlphaAttrs { - double leaky_relu_alpha; -}; - -} // namespace - -static GpuConvDescriptor GetConvDescriptor( - CudnnConvKind kind, - // Arguments - StridedMemrefView operand0, StridedMemrefView operand1, - StridedMemrefView output, FlatMemrefView scratch, - // Attributes - ConvDimensionNumbers dims, Window w, ConvBackendConfig b, ConvAttrs attrs, - // Conv-specific arguments and attributes - std::optional fused = std::nullopt, - std::optional side_input = std::nullopt, - std::optional leakyrelu_alpha = std::nullopt) { - // Build a convolution descriptor from the attributes. - GpuConvDescriptor descriptor; - descriptor.kind = kind; - - // Apply backend config layout to the shape. - auto apply_layout = [](StridedMemrefView& memref, - absl::Span minor_to_major) { - Shape shape = ToShape(memref); - return ShapeUtil::MakeShapeWithDenseLayout( - shape.element_type(), shape.dimensions(), minor_to_major); - }; - - descriptor.operand0_shape = apply_layout(operand0, b.operand_0_layout); - descriptor.operand1_shape = apply_layout(operand1, b.operand_1_layout); - descriptor.result_shape = apply_layout(output, b.result_layout); - - // Set up convolution dimensions numbers. - ConvolutionDimensionNumbers dns; - dns.set_input_batch_dimension(dims.input_batch_dim); - dns.set_input_feature_dimension(dims.input_feature_dim); - dns.set_kernel_input_feature_dimension(dims.kernel_in_feature_dim); - dns.set_kernel_output_feature_dimension(dims.kernel_out_feature_dim); - dns.set_output_batch_dimension(dims.output_batch_dim); - dns.set_output_feature_dimension(dims.output_feature_dim); - for (int64_t d : dims.input_spatial_dims) dns.add_input_spatial_dimensions(d); - for (int64_t d : dims.kernel_spatial_dims) - dns.add_kernel_spatial_dimensions(d); - for (int64_t d : dims.output_spatial_dims) - dns.add_output_spatial_dimensions(d); - descriptor.dnums = std::move(dns); - - // Put together convolution window config. - for (auto index : llvm::seq(0, w.window_strides.size())) { - WindowDimension* dim = descriptor.window.add_dimensions(); - // Window size for a convolution is the same as the kernel size. - // Kernel size of the convolution is operand1_shape. We need to look at - // the convolution dimension numbers kernel spatial dimensions to get - // the window size. - int kernel_dim = descriptor.dnums.kernel_spatial_dimensions(index); - dim->set_size(descriptor.operand0_shape.dimensions(kernel_dim)); - dim->set_stride(w.window_strides[index]); - dim->set_padding_low(w.padding[index]); - dim->set_padding_high(w.padding[index]); - dim->set_base_dilation(w.lhs_dilation[index]); - dim->set_window_dilation(w.rhs_dilation[index]); - dim->set_window_reversal(w.window_reversal[index]); - } - - descriptor.scratch_size = scratch.size_in_bytes; - descriptor.feature_group_count = attrs.feature_group_count; - descriptor.backend_config.set_conv_result_scale(attrs.result_scale); - descriptor.backend_config.set_reordered_int8_nchw_vect( - b.is_cudnn_reordered_int8); - - // Set up convolution algorigthm. - auto* algo = descriptor.backend_config.mutable_algorithm(); - algo->set_algo_id(b.algorithm); - algo->set_math_type(b.tensor_ops_enabled - ? se::dnn::AlgorithmProto::TENSOR_OP_MATH - : se::dnn::AlgorithmProto::DEFAULT_MATH); - algo->set_is_cudnn_frontend(b.is_cudnn_frontend); - - if (b.workspace_size >= 0) - algo->mutable_workspace_size()->set_value(b.workspace_size); - - for (unsigned i = 0; i < b.knob_ids.size(); ++i) { - algo->mutable_tuning_knobs()->insert({b.knob_ids[i], b.knob_values[i]}); - } - - // Set attributes specific for fused convolutions. - if (fused.has_value()) - descriptor.backend_config.set_activation_mode(fused->activation_mode); - - // Set attributes specific for fused convolutions with leaky_relu_alpha. - if (leakyrelu_alpha.has_value()) - descriptor.backend_config.set_leakyrelu_alpha( - leakyrelu_alpha->leaky_relu_alpha); - - // Set attributes specific for convolutions with side input. - if (side_input.has_value()) - descriptor.backend_config.set_side_input_scale( - side_input->side_input_scale); - - return descriptor; -} - -template -static absl::Status DoConv( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, NonAtomicallyUpgradeableRWLock* gpu_lock, - State runner, - // Arguments - StridedMemrefView operand0, StridedMemrefView operand1, - std::optional bias, - std::optional side_input, - absl::Span outputs, FlatMemrefView scratch, - int64_t uid, - // Convolution config - ConvDimensionNumbers conv_dims, - // Window config - absl::Span window_strides, absl::Span padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - absl::Span window_reversal, - // Backend config attributes - ConvBackendConfig backend_config, - // Remaining attributes - int64_t feature_group_count, double result_scale, - // Optional attributes for fused convolutions. - std::optional activation_mode = std::nullopt, - std::optional side_input_scale = std::nullopt, - std::optional leakyrelu_alpha = std::nullopt, - // Optional extra arguments for graph convolutions. - absl::Span extra_operands = {}, - std::optional serialized_graph = std::nullopt) { - // Build config for optional attributes. - std::optional fused_attrs = std::nullopt; - if (activation_mode.has_value()) fused_attrs = {*activation_mode}; - - std::optional side_input_attrs = std::nullopt; - if (side_input_scale.has_value()) side_input_attrs = {*side_input_scale}; - - std::optional leakyrelu_alpha_attrs = std::nullopt; - if (leakyrelu_alpha.has_value()) leakyrelu_alpha_attrs = {*leakyrelu_alpha}; - - bool runtime_autotuning = false; - if (backend_config.algorithm == -1) { - // Set the algorithm back to the default algorithm to avoid error from - // cuDNN. - backend_config.algorithm = 0; - runtime_autotuning = true; - } - - // Get or create the convolution runner state. - TF_ASSIGN_OR_RETURN( - ConvRunner * conv, - runner.GetOrCreate([&]() -> absl::StatusOr { - GpuConvDescriptor descriptor = GetConvDescriptor( - kind, operand0, operand1, outputs[0], scratch, conv_dims, - {window_strides, padding, lhs_dilation, rhs_dilation, - window_reversal}, - backend_config, {feature_group_count, result_scale}, fused_attrs, - side_input_attrs, leakyrelu_alpha_attrs); - if (serialized_graph.has_value()) { - descriptor.backend_config.set_serialized_graph( - std::string(serialized_graph.value())); - } - TF_ASSIGN_OR_RETURN(GpuConvConfig conv_config, - GetGpuConvConfig(descriptor, "")); - - return ConvRunner(std::move(conv_config)); - })); - - // Prepare buffer arguments. - std::vector buffers = {GetDeviceAddress(operand0), - GetDeviceAddress(operand1)}; - if (bias.has_value()) buffers.push_back(GetDeviceAddress(*bias)); - if (side_input.has_value()) buffers.push_back(GetDeviceAddress(*side_input)); - for (const StridedMemrefView& operand : extra_operands) { - buffers.push_back(GetDeviceAddress(operand)); - } - - std::vector result_buffers; - for (const StridedMemrefView& output : outputs) { - result_buffers.push_back(GetDeviceAddress(output)); - } - se::DeviceMemoryBase scratch_buffer = GetDeviceAddress(scratch); - - int64_t scratch_buffer_size = scratch_buffer.size(); - - // Do runtime conv autotuning. - if (runtime_autotuning) { -#if GOOGLE_CUDA - // Don't run autotuning concurrently on the same GPU. - NonAtomicallyUpgradeableRWLock::WriterLock writer_lock = - gpu_lock->UpgradeToWriterMutexLock(); - - auto stream_exec = run_options->stream()->parent(); - auto allocator = run_options->allocator(); - AutotuneConfig config(DeviceConfig{stream_exec, allocator}, *debug_options); - GpuConvAlgorithmPicker conv_algorithm_picker(config); - - GpuConvConfig gpu_conv_config = conv->config; - TF_ASSIGN_OR_RETURN( - AutotuneResult best_algo, - conv_algorithm_picker.PickBestAlgorithmWithAllocatedBuffer( - config, gpu_conv_config, run_options, *debug_options, buffers, - result_buffers)); - - // Set algorithm in the convolution runner state. - se::dnn::AlgorithmDesc algo_desc(best_algo.conv().algorithm(), - best_algo.conv().tensor_ops_enabled()); - conv->config.algorithm = algo_desc; - - // Set scratch buffer size according to the selected algorithm. - scratch_buffer_size = best_algo.scratch_bytes(); -#else - return absl::InternalError( - "Failed to run runtime autotuner because CUDA is not enabled"); -#endif - } - - RunConvOptions opts; - opts.runner_cache = &conv->runner; - - if (scratch_buffer_size > scratch_buffer.size()) { - // Need to reallocate scratch buffer. - se::DeviceMemoryAllocator* allocator = run_options->allocator(); - TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory allocated_buffer, - allocator->Allocate(run_options->device_ordinal(), - scratch_buffer_size)); - se::DeviceMemoryBase new_scratch_buffer(allocated_buffer.ptr(), - scratch_buffer_size); - - // Run the convolution using the new scratch buffer. - TF_RETURN_IF_ERROR(RunGpuConv(conv->config, buffers, result_buffers, - new_scratch_buffer, run_options->stream(), - opts)); - if (!run_options->stream()->ok()) { - return absl::InternalError("run_options stream not ok"); - } - return absl::OkStatus(); - } - - // Run the convolution. - TF_RETURN_IF_ERROR(RunGpuConv(conv->config, buffers, result_buffers, - scratch_buffer, run_options->stream(), opts)); - if (!run_options->stream()->ok()) { - return absl::InternalError("run_options stream not ok"); - } - - return absl::OkStatus(); -} - -template -static absl::Status ConvImpl( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, NonAtomicallyUpgradeableRWLock* gpu_lock, - State runner, - // Arguments - StridedMemrefView operand0, StridedMemrefView operand1, - std::optional bias, - std::optional side_input, StridedMemrefView output, - FlatMemrefView scratch, int64_t uid, - // Convolution config - ConvDimensionNumbers conv_dims, - // Window config - absl::Span window_strides, absl::Span padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - absl::Span window_reversal, - // Backend config attributes - ConvBackendConfig backend_config, - // Remaining attributes - int64_t feature_group_count, double result_scale, - // Optional attributes for fused convolutions. - std::optional activation_mode = std::nullopt, - std::optional side_input_scale = std::nullopt, - std::optional leakyrelu_alpha = std::nullopt) { - return DoConv(run_options, debug_options, gpu_lock, runner, operand0, - operand1, bias, side_input, {output}, scratch, uid, - conv_dims, window_strides, padding, lhs_dilation, - rhs_dilation, window_reversal, backend_config, - feature_group_count, result_scale, activation_mode, - side_input_scale, leakyrelu_alpha); -} - -template -static absl::Status ConvGraphImpl( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, NonAtomicallyUpgradeableRWLock* gpu_lock, - State runner, - // Arguments - StridedMemrefView operand0, StridedMemrefView operand1, - CustomCall::RemainingArgs args, int64_t uid, - // Convolution config - ConvDimensionNumbers conv_dims, - // Window config - absl::Span window_strides, absl::Span padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - absl::Span window_reversal, - // Backend config attributes - ConvBackendConfig backend_config, - // Remaining attributes - int64_t feature_group_count, double result_scale, int32_t n_aux_outputs, - std::string_view serialized_graph) { - // Let N be the size of 'args'. The first (N - n_aux_outputs - 2) elements of - // 'args' are extra operands, which are operands other than the input and - // filter. The next (n_aux_outputs + 1) elements are the outputs -- the first - // being the main convolution output and the others being the "auxiliary" - // outputs (e.g. amax). The last element of 'args' is the scratch space. - std::vector extra_operands; - for (int i = 0; i < args.size() - n_aux_outputs - 2; i++) { - auto arg = args.get(i); - if (failed(arg)) { - return absl::InternalError( - "Failed to get operand buffer for convolution graph"); - } - extra_operands.push_back(arg.value()); - } - - std::vector outputs; - for (int i = args.size() - n_aux_outputs - 2; i < args.size() - 1; i++) { - auto arg = args.get(i); - if (failed(arg)) { - return absl::InternalError( - "Failed to get output buffer for convolution graph"); - } - outputs.push_back(arg.value()); - } - - auto scratch = args.get(args.size() - 1); - if (failed(scratch)) { - return absl::InternalError( - "Failed to get scratch buffer for convolution graph"); - } - - return DoConv(run_options, debug_options, gpu_lock, runner, operand0, - operand1, /*bias=*/{}, - /*side_input=*/{}, outputs, scratch.value(), uid, - conv_dims, window_strides, padding, lhs_dilation, - rhs_dilation, window_reversal, backend_config, - feature_group_count, result_scale, /*activation_mode=*/{}, - /*side_input_scale=*/{}, /*leakyrelu_alpha=*/{}, - extra_operands, serialized_graph); -} - -//===----------------------------------------------------------------------===// -// Convolution custom calls bindings and registration. -//===----------------------------------------------------------------------===// - -using Kind = CudnnConvKind; - -template -static auto BindConvAttributes(runtime::CustomCallBinding binding) { - return std::move(binding) - // Unique convolution id for caching state. - .template Attr("uid") - // Convolution dimensions numbers - .template Attr("conv_dims") - // Window config - .template Attr>("window_strides") - .template Attr>("padding") - .template Attr>("lhs_dilation") - .template Attr>("rhs_dilation") - .template Attr>("window_reversal") - // Backend config attributes - .template Attr("backend_config") - // Remaining attributes. - .template Attr("feature_group_count") - .template Attr("result_scale"); -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL_TEMPLATE( - Kind kind, Conv, FunctionWrapper>(), checks, - BindConvAttributes( - CustomCall::Bind("xla.gpu.conv") - .UserData() - .UserData() - .UserData() - .State("uid") // runner - .Arg() // operand0 - .Arg() // operand1 - .Value(std::optional()) // bias - .Value(std::optional()) // side_input - .Arg() // output - .Arg() // scratch - ) - .Value(std::optional()) // activation_mode - .Value(std::optional()) // side_input_scale - .Value(std::optional()) // leaky_relu_alpha -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - ConvFused, FunctionWrapper>(), checks, - BindConvAttributes( - CustomCall::Bind("xla.gpu.conv.fused") - .UserData() - .UserData() - .UserData() - .State("uid") // runner - .Arg() // operand0 - .Arg() // operand1 - .Arg() // bias - .Value(std::optional()) // side_input - .Arg() // output - .Arg() // scratch - ) - .Attr("activation_mode") - .Value(std::optional()) // side_input_scale - .Attr("leakyrelu_alpha") // leaky_relu_alpha -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - ConvFusedSideInput, FunctionWrapper>(), - checks, - BindConvAttributes(CustomCall::Bind("xla.gpu.conv.fused.side_input") - .UserData() - .UserData() - .UserData() - .State("uid") // runner - .Arg() // operand0 - .Arg() // operand1 - .Arg() // bias - .Arg() // side_input - .Arg() // output - .Arg() // scratch - ) - .Attr("activation_mode") - .Attr("side_input_scale") - .Value(std::optional())); // leaky_relu_alpha - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - ConvForwardGraph, FunctionWrapper>(), - checks, - BindConvAttributes(CustomCall::Bind("xla.gpu.conv.forward.graph") - .UserData() - .UserData() - .UserData() - .State("uid") // runner - .Arg() // operand0 - .Arg() // operand1 - .RemainingArgs() // binary_operands - ) - .Attr("n_aux_outputs") - .Attr("serialized_graph")); - -//===----------------------------------------------------------------------===// - -void RegisterConvCustomCalls(runtime::DirectCustomCallRegistry& registry) { - auto conv = [](std::string name) { return "xla.gpu.conv." + name; }; - registry.Register(conv("forward"), Conv); - registry.Register(conv("backward.input"), Conv); - registry.Register(conv("backward.filter"), Conv); - registry.Register(conv("forward.fused"), ConvFused); - registry.Register(conv("forward.fused.side_input"), ConvFusedSideInput); - registry.Register(conv("forward.graph"), ConvForwardGraph); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/conv.h b/third_party/xla/xla/service/gpu/runtime/conv.h deleted file mode 100644 index 91591b8ffab123..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/conv.h +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_CONV_H_ -#define XLA_SERVICE_GPU_RUNTIME_CONV_H_ - -#include -#include - -#include "absl/container/node_hash_map.h" -#include "absl/synchronization/mutex.h" -#include "xla/mlir/runtime/transforms/custom_call_encoding.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/service/gpu/gpu_conv_runner.h" - -namespace xla { -namespace gpu { - -// Registers XLA Gpu runtime Conv custom calls. -void RegisterConvCustomCalls(runtime::DirectCustomCallRegistry& registry); - -// Register type names for convoluttion attributes defined by MHLO dialect. -void RegisterConvTypeIdNames(runtime::TypeIDNameRegistry& registry); - -// Add attributes encoding for convoluttion attributes defined by MHLO dialect. -void PopulateConvAttrEncoding(runtime::CustomCallAttrEncodingSet& encoding); - -//===----------------------------------------------------------------------===// -// Cache conv runners between invocations of convolution custom calls. -//===----------------------------------------------------------------------===// - -struct ConvRunner { - explicit ConvRunner(GpuConvConfig config) - : config(std::move(config)), runner(this->config) {} - GpuConvConfig config; - GenericConvRunner runner; -}; - -class StreamExecutorConvRunners : public runtime::StateVector {}; - -// Xla executable keeps a mapping from stream executors to convolution runners. -class ConvRunners { - public: - StreamExecutorConvRunners* operator()(se::StreamExecutor* executor); - - private: - mutable absl::Mutex mutex_; - absl::node_hash_map runners_ - ABSL_GUARDED_BY(mutex_); -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_CONV_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/conv_reorder.cc b/third_party/xla/xla/service/gpu/runtime/conv_reorder.cc deleted file mode 100644 index 2264f6b921b796..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/conv_reorder.cc +++ /dev/null @@ -1,105 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/conv_reorder.h" - -#include -#include - -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/xla.pb.h" - -namespace xla { -namespace gpu { -namespace { - -using ::xla::runtime::CustomCall; -using ::xla::runtime::FlatMemrefView; -using ::xla::runtime::StridedMemrefView; - -se::dnn::FilterDescriptor GetFilterDescriptor( - absl::Span filter_dims) { - se::dnn::FilterDescriptor filter_desc(2); - filter_desc.set_layout(se::dnn::FilterLayout::kOutputInputYX32); - filter_desc.set_output_feature_map_count(filter_dims[0]); - filter_desc.set_input_feature_map_count(filter_dims[1]); - filter_desc.set_input_filter_height(filter_dims[2]); - filter_desc.set_input_filter_width(filter_dims[3]); - return filter_desc; -} - -absl::Status ConvReorderFilterImpl( - const ServiceExecutableRunOptions* run_options, - StridedMemrefView input_view, StridedMemrefView output_view, - absl::Span filter_dims) { - auto input = se::DeviceMemory(GetDeviceAddress(input_view)); - auto output = se::DeviceMemory(GetDeviceAddress(output_view)); - - return run_options->stream()->CudnnReorderConvolutionFilterAndBias( - GetFilterDescriptor(filter_dims), input, &output, std::nullopt, - std::nullopt); -} - -absl::Status ConvReorderFilterAndBiasImpl( - const ServiceExecutableRunOptions* run_options, - StridedMemrefView filter_input_view, FlatMemrefView bias_input_view, - StridedMemrefView filter_output_view, FlatMemrefView bias_output_view, - absl::Span filter_dims) { - auto filter_input = - se::DeviceMemory(GetDeviceAddress(filter_input_view)); - auto filter_output = - se::DeviceMemory(GetDeviceAddress(filter_output_view)); - auto bias_input = se::DeviceMemory(GetDeviceAddress(bias_input_view)); - auto bias_output = - se::DeviceMemory(GetDeviceAddress(bias_output_view)); - - return run_options->stream()->CudnnReorderConvolutionFilterAndBias( - GetFilterDescriptor(filter_dims), filter_input, &filter_output, - std::make_optional(bias_input), std::make_optional(bias_output)); -} - -} // namespace - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - ConvReorderFilter, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.conv.reorder.filter") - .UserData() - .Arg() // filter_input - .Arg() // filter_output - .Attr>("filter_dims")); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - ConvReorderFilterAndBias, FunctionWrapper(), - checks, - CustomCall::Bind("xla.gpu.conv.reorder.filter_and_bias") - .UserData() - .Arg() // filter_input - .Arg() // bias_input - .Arg() // filter_output - .Arg() // bias_output - .Attr>("filter_dims")); - -void RegisterConvReorderCustomCalls( - runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.conv.reorder.filter", ConvReorderFilter); - registry.Register("xla.gpu.conv.reorder.filter_and_bias", - ConvReorderFilterAndBias); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime3/convolution_thunk.cc b/third_party/xla/xla/service/gpu/runtime/convolution_thunk.cc similarity index 87% rename from third_party/xla/xla/service/gpu/runtime3/convolution_thunk.cc rename to third_party/xla/xla/service/gpu/runtime/convolution_thunk.cc index fdbbaf12f91f2d..1d33eca1ac8603 100644 --- a/third_party/xla/xla/service/gpu/runtime3/convolution_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/convolution_thunk.cc @@ -13,12 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime3/convolution_thunk.h" +#include "xla/service/gpu/runtime/convolution_thunk.h" #include #include +#include "absl/container/inlined_vector.h" #include "absl/status/status.h" +#include "xla/service/buffer_assignment.h" #include "xla/service/gpu/gpu_conv_runner.h" #include "xla/stream_executor/stream_executor.h" #include "xla/util.h" @@ -83,12 +85,12 @@ absl::Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) { ConvolutionReorderThunk::ConvolutionReorderThunk( ThunkInfo thunk_info, absl::Span filter_nchw, - std::vector operand_slices, - std::vector result_slices) + absl::InlinedVector operand_slices, + absl::InlinedVector result_slices) : Thunk(Kind::kConvolutionReorder, thunk_info), filter_descriptor_(CreateFilterDescriptor(filter_nchw)), - operand_buffers_(std::move(operand_slices)), - result_buffers_(std::move(result_slices)) {} + operand_buffers_(operand_slices), + result_buffers_(result_slices) {} absl::Status ConvolutionReorderThunk::ExecuteOnStream( const ExecuteParams& params) { @@ -110,14 +112,13 @@ absl::Status ConvolutionReorderThunk::ExecuteOnStream( buffer_allocations.GetDeviceAddress(result_buffers_[1]))) : std::nullopt; - TF_RETURN_IF_ERROR(params.stream->CudnnReorderConvolutionFilterAndBias( - filter_descriptor_, filter_input, &filter_output, std::move(bias_input), - std::move(bias_output))); - - if (!params.stream->ok()) { - return Internal("ConvolutionReorderThunk::ExecuteOnStream failed."); + auto dnn = params.stream->parent()->AsDnn(); + if (dnn == nullptr) { + return absl::InternalError("No DNN for stream."); } - return absl::OkStatus(); + return dnn->CudnnReorderConvolutionFilterAndBias( + params.stream, filter_descriptor_, filter_input, &filter_output, + std::move(bias_input), std::move(bias_output)); } se::dnn::FilterDescriptor ConvolutionReorderThunk::CreateFilterDescriptor( diff --git a/third_party/xla/xla/service/gpu/runtime3/convolution_thunk.h b/third_party/xla/xla/service/gpu/runtime/convolution_thunk.h similarity index 82% rename from third_party/xla/xla/service/gpu/runtime3/convolution_thunk.h rename to third_party/xla/xla/service/gpu/runtime/convolution_thunk.h index 5dfaeda1624d42..02aecd464fce4a 100644 --- a/third_party/xla/xla/service/gpu/runtime3/convolution_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/convolution_thunk.h @@ -13,12 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME3_CONVOLUTION_THUNK_H_ -#define XLA_SERVICE_GPU_RUNTIME3_CONVOLUTION_THUNK_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_CONVOLUTION_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_CONVOLUTION_THUNK_H_ #include #include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/gpu_conv_runner.h" #include "xla/service/gpu/thunk.h" @@ -63,9 +64,10 @@ class ConvolutionThunk : public Thunk { // Launches the kernel that reorders input data for int8x32 convolutions. class ConvolutionReorderThunk : public Thunk { public: - ConvolutionReorderThunk(ThunkInfo thunk_info, absl::Span filter_nchw, - std::vector operand_slices, - std::vector result_slices); + ConvolutionReorderThunk( + ThunkInfo thunk_info, absl::Span filter_nchw, + absl::InlinedVector operand_slices, + absl::InlinedVector result_slices); ConvolutionReorderThunk(const ConvolutionReorderThunk&) = delete; ConvolutionReorderThunk& operator=(const ConvolutionReorderThunk&) = delete; @@ -77,11 +79,11 @@ class ConvolutionReorderThunk : public Thunk { absl::Span filter_nchw); const se::dnn::FilterDescriptor filter_descriptor_; - std::vector operand_buffers_; - std::vector result_buffers_; + absl::InlinedVector operand_buffers_; + absl::InlinedVector result_buffers_; }; } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_RUNTIME3_CONVOLUTION_THUNK_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_CONVOLUTION_THUNK_H_ diff --git a/third_party/xla/xla/service/gpu/runtime3/copy_thunk.cc b/third_party/xla/xla/service/gpu/runtime/copy_thunk.cc similarity index 82% rename from third_party/xla/xla/service/gpu/runtime3/copy_thunk.cc rename to third_party/xla/xla/service/gpu/runtime/copy_thunk.cc index fed80deaf8a52d..0b372da9d66450 100644 --- a/third_party/xla/xla/service/gpu/runtime3/copy_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/copy_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime3/copy_thunk.h" +#include "xla/service/gpu/runtime/copy_thunk.h" #include @@ -28,14 +28,11 @@ namespace gpu { DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk( ThunkInfo thunk_info, const BufferAllocation::Slice& source_buffer, - const BufferAllocation::Slice& destination_buffer, uint64_t mem_size, - mlir::Value source_value, mlir::Value destination_value) + const BufferAllocation::Slice& destination_buffer, uint64_t mem_size) : Thunk(Kind::kCopy, thunk_info), source_buffer_(source_buffer), destination_buffer_(destination_buffer), - mem_size_(mem_size), - source_value_(source_value), - destination_value_(destination_value) {} + mem_size_(mem_size) {} absl::Status DeviceToDeviceCopyThunk::ExecuteOnStream( const ExecuteParams& params) { @@ -43,8 +40,7 @@ absl::Status DeviceToDeviceCopyThunk::ExecuteOnStream( params.buffer_allocations->GetDeviceAddress(destination_buffer_); se::DeviceMemoryBase source_data = params.buffer_allocations->GetDeviceAddress(source_buffer_); - params.stream->ThenMemcpy(&destination_data, source_data, mem_size_); - return absl::OkStatus(); + return params.stream->Memcpy(&destination_data, source_data, mem_size_); } } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime3/copy_thunk.h b/third_party/xla/xla/service/gpu/runtime/copy_thunk.h similarity index 78% rename from third_party/xla/xla/service/gpu/runtime3/copy_thunk.h rename to third_party/xla/xla/service/gpu/runtime/copy_thunk.h index 6db53c6dbc9e9c..b2347657342a7f 100644 --- a/third_party/xla/xla/service/gpu/runtime3/copy_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/copy_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME3_COPY_THUNK_H_ -#define XLA_SERVICE_GPU_RUNTIME3_COPY_THUNK_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_COPY_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_COPY_THUNK_H_ #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/thunk.h" @@ -31,8 +31,7 @@ class DeviceToDeviceCopyThunk : public Thunk { DeviceToDeviceCopyThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& source_buffer, const BufferAllocation::Slice& destination_buffer, - uint64_t mem_size, mlir::Value source_value, - mlir::Value destination_value); + uint64_t mem_size); DeviceToDeviceCopyThunk(const DeviceToDeviceCopyThunk&) = delete; DeviceToDeviceCopyThunk& operator=(const DeviceToDeviceCopyThunk&) = delete; @@ -41,8 +40,6 @@ class DeviceToDeviceCopyThunk : public Thunk { void ClearCompileTimeInfo() override { Thunk::ClearCompileTimeInfo(); - source_value_ = nullptr; - destination_value_ = nullptr; } const BufferAllocation::Slice& source() const { return source_buffer_; } @@ -50,18 +47,14 @@ class DeviceToDeviceCopyThunk : public Thunk { return destination_buffer_; } uint64_t size_bytes() const { return mem_size_; } - mlir::Value source_value() const { return source_value_; } - mlir::Value destination_value() const { return destination_value_; } private: const BufferAllocation::Slice source_buffer_; const BufferAllocation::Slice destination_buffer_; const uint64_t mem_size_; - mlir::Value source_value_; - mlir::Value destination_value_; }; } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_RUNTIME3_COPY_THUNK_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_COPY_THUNK_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/cub_sort.cc b/third_party/xla/xla/service/gpu/runtime/cub_sort.cc deleted file mode 100644 index e779a35c5cb856..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/cub_sort.cc +++ /dev/null @@ -1,100 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/cub_sort.h" - -#include - -#include "absl/status/status.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/runtime/executable.h" // IWYU pragma: keep -#include "xla/runtime/memref_view.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/stream_executor/device_memory.h" - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/service/gpu/runtime3/cub_sort_thunk.h" -#endif - -namespace xla { -namespace gpu { -namespace { - -using ::stream_executor::DeviceMemoryBase; -using ::xla::runtime::CustomCall; -using ::xla::runtime::FlatMemrefView; - -absl::Status CubDeviceRadixSortKeysImpl( - const ServiceExecutableRunOptions* run_options, FlatMemrefView input_view, - FlatMemrefView output_view, FlatMemrefView scratch_view, bool descending) { -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - return RunCubSort(input_view.dtype, std::nullopt, - GetDeviceAddress(input_view), DeviceMemoryBase(), - GetDeviceAddress(output_view), DeviceMemoryBase(), - GetDeviceAddress(scratch_view), descending); -#else - return absl::UnimplementedError("CUB is not available"); -#endif -} - -absl::Status CubDeviceRadixSortPairsImpl( - const ServiceExecutableRunOptions* run_options, - FlatMemrefView input_keys_view, FlatMemrefView input_values_view, - FlatMemrefView output_keys_view, FlatMemrefView output_values_view, - FlatMemrefView scratch_view, bool descending) { -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - return RunCubSort( - input_keys_view.dtype, input_values_view.dtype, - GetDeviceAddress(input_keys_view), GetDeviceAddress(input_values_view), - GetDeviceAddress(output_keys_view), GetDeviceAddress(output_values_view), - GetDeviceAddress(scratch_view), descending); -#else - return absl::UnimplementedError("CUB is not available"); -#endif -} - -} // namespace - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - CubDeviceRadixSortKeys, FunctionWrapper(), - checks, - CustomCall::Bind("xla.gpu.radix_sort_keys") - .UserData() - .Arg() // input - .Arg() // output - .Arg() // scratch - .Attr("descending")); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - CubDeviceRadixSortPairs, FunctionWrapper(), - checks, - CustomCall::Bind("xla.gpu.radix_sort_pairs") - .UserData() - .Arg() // input_keys - .Arg() // input_values - .Arg() // output_keys - .Arg() // output_values - .Arg() // scratch - .Attr("descending")); - -void RegisterCubSortCustomCalls(runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.radix_sort_keys", CubDeviceRadixSortKeys); - registry.Register("xla.gpu.radix_sort_pairs", CubDeviceRadixSortPairs); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/cub_sort.h b/third_party/xla/xla/service/gpu/runtime/cub_sort.h deleted file mode 100644 index 113e9bc4726648..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/cub_sort.h +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_CUB_SORT_H_ -#define XLA_SERVICE_GPU_RUNTIME_CUB_SORT_H_ - -#include "xla/runtime/custom_call_registry.h" - -namespace xla { -namespace gpu { - -// Registers XLA Gpu runtime CUB sort custom calls. -void RegisterCubSortCustomCalls(runtime::DirectCustomCallRegistry& registry); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_CUB_SORT_H_ diff --git a/third_party/xla/xla/service/gpu/runtime3/cub_sort_thunk.cc b/third_party/xla/xla/service/gpu/runtime/cub_sort_thunk.cc similarity index 96% rename from third_party/xla/xla/service/gpu/runtime3/cub_sort_thunk.cc rename to third_party/xla/xla/service/gpu/runtime/cub_sort_thunk.cc index 0263c81d65f89e..1e4fe32cd97db2 100644 --- a/third_party/xla/xla/service/gpu/runtime3/cub_sort_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/cub_sort_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime3/cub_sort_thunk.h" +#include "xla/service/gpu/runtime/cub_sort_thunk.h" #include #include @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" @@ -246,11 +247,12 @@ CubSortRunnerInterface::Create(PrimitiveType type, : CreateCubSortRunner(type); } -CubSortThunk::CubSortThunk(ThunkInfo thunk_info, PrimitiveType type, - std::optional value_type, - std::vector operands, - std::vector results, - BufferAllocation::Slice scratch, bool descending) +CubSortThunk::CubSortThunk( + ThunkInfo thunk_info, PrimitiveType type, + std::optional value_type, + absl::InlinedVector operands, + absl::InlinedVector results, + BufferAllocation::Slice scratch, bool descending) : Thunk(Thunk::kCubSort, thunk_info), runner_(CubSortRunnerInterface::Create(type, value_type).value()), operands_(std::move(operands)), diff --git a/third_party/xla/xla/service/gpu/runtime3/cub_sort_thunk.h b/third_party/xla/xla/service/gpu/runtime/cub_sort_thunk.h similarity index 86% rename from third_party/xla/xla/service/gpu/runtime3/cub_sort_thunk.h rename to third_party/xla/xla/service/gpu/runtime/cub_sort_thunk.h index ce0b4e05129d82..bd5440d5fca30e 100644 --- a/third_party/xla/xla/service/gpu/runtime3/cub_sort_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/cub_sort_thunk.h @@ -13,14 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME3_CUB_SORT_THUNK_H_ -#define XLA_SERVICE_GPU_RUNTIME3_CUB_SORT_THUNK_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_CUB_SORT_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_CUB_SORT_THUNK_H_ #include #include #include #include +#include "absl/container/inlined_vector.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/thunk.h" #include "xla/stream_executor/device_memory.h" @@ -49,8 +50,8 @@ class CubSortThunk : public Thunk { public: CubSortThunk(ThunkInfo thunk_info, PrimitiveType type, std::optional value_type, - std::vector operands, - std::vector results, + absl::InlinedVector operands, + absl::InlinedVector results, BufferAllocation::Slice scratch, bool descending); absl::Status ExecuteOnStream(const ExecuteParams& params) override { @@ -64,8 +65,8 @@ class CubSortThunk : public Thunk { private: std::unique_ptr runner_; - std::vector operands_; - std::vector results_; + absl::InlinedVector operands_; + absl::InlinedVector results_; BufferAllocation::Slice scratch_; bool descending_; }; @@ -81,4 +82,4 @@ absl::Status RunCubSort(PrimitiveType type, } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_RUNTIME3_CUB_SORT_THUNK_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_CUB_SORT_THUNK_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/custom_call.cc b/third_party/xla/xla/service/gpu/runtime/custom_call.cc deleted file mode 100644 index 837ad2484f0af8..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/custom_call.cc +++ /dev/null @@ -1,177 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/custom_call.h" - -#include -#include -#include - -#include "xla/runtime/executable.h" -#include "xla/service/custom_call_status_internal.h" -#include "xla/service/custom_call_target_registry.h" -#include "xla/service/gpu/cublas_cudnn.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/gpu/runtime/triangular_solve.h" -#include "xla/service/service_executable_run_options.h" - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/stream_executor/gpu/gpu_stream.h" -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -namespace xla { -namespace gpu { - -// Custom calls with API version API_VERSION_TYPED_FFI lowered directly to an -// Xla runtime custom calls. Older API versions handled by adapting Xla runtime -// calling convention to the calling convention expected by the registered -// handler. -// -// Once all Xla backends will use Xla runtime we will deprecate older API -// version, and migrate all users to API_VERSION_TYPED_FFI. -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -using xla::runtime::CustomCall; -using xla::runtime::FlatMemrefView; -using xla::runtime::StridedMemrefView; - -static absl::Status XlaCustomCallImpl( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, CustomCall::RemainingArgs args, - std::string_view call_target_name, int32_t api_version, - std::string_view backend_config) { - // Pattern match custom call to a few special cases, otherwise find the custom - // call handler regustered with the runtime. - if (call_target_name == kTriangularSolveCallTarget) - return TriangularSolve::run(run_options, debug_options, args, - backend_config); - - // Find the Xla custom call handler. - auto& platform_name = run_options->stream()->parent()->platform()->Name(); - void* call_target = CustomCallTargetRegistry::Global()->Lookup( - std::string(call_target_name), platform_name); - if (!call_target) { - return absl::InvalidArgumentError(absl::StrCat( - "Cannot find the Xla custom call handler ", call_target_name)); - } - - // Prepare pointers to buffers to pass to the Xla custom call handler. - llvm::SmallVector buffers; - for (unsigned i = 0; i < args.size(); ++i) { - if (auto memref = args.get(i); succeeded(memref)) { - buffers.push_back(memref->data); - continue; - } - - if (auto strided = args.get(i); succeeded(strided)) { - buffers.push_back(strided->data); - continue; - } - - // TODO(ezhulenev): Add dialect and type to model Xla custom call holes, - // today we rely on the fact that custom calls do not support scalar - // arguments and we can disambiguate holes from real arguments. - if (auto hole = args.get(i); succeeded(hole)) { - buffers.push_back(nullptr); - continue; - } - - return absl::InvalidArgumentError( - "Failed to get arguments as (strided) memref view"); - } - - // Call custom call handler using the calling convention it requires. - using ApiVersion = CustomCallApiVersion; - - // Original custom call API version that doesn't support returning status. - if (api_version == ApiVersion::API_VERSION_ORIGINAL) { - using XlaCustomCallType = - void (*)(se::gpu::GpuStreamHandle, void**, const char*, size_t); - auto xla_call_target = reinterpret_cast(call_target); - - // As this is calling an external library, we should catch the - // error as there isn't another working correctly path to return - // an error to XLA. - try { - xla_call_target(se::gpu::AsGpuStreamValue(run_options->stream()), - buffers.data(), backend_config.data(), - backend_config.size()); - } catch (std::exception& e) { - return absl::UnknownError( - absl::StrCat(call_target_name, - " XLA extension have thrown an exception: ", e.what())); - } catch (...) { - return absl::UnknownError(absl::StrCat( - call_target_name, " XLA extension have thrown an exception.")); - } - - return absl::OkStatus(); - } - - // Xla Custom call API returning status. - if (api_version == ApiVersion::API_VERSION_STATUS_RETURNING || - api_version == ApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED) { - using XlaCustomCallType = - void (*)(se::gpu::GpuStreamHandle, void**, const char*, size_t, - XlaCustomCallStatus*); - auto xla_call_target = reinterpret_cast(call_target); - - XlaCustomCallStatus custom_call_status; - // As this is calling an external library, we should catch the - // error as there isn't another working correctly path to return - // an error to XLA. - try { - xla_call_target(se::gpu::AsGpuStreamValue(run_options->stream()), - buffers.data(), backend_config.data(), - backend_config.size(), &custom_call_status); - } catch (std::exception& e) { - return absl::UnknownError( - absl::StrCat(call_target_name, - " XLA extension have thrown an exception: ", e.what())); - } catch (...) { - return absl::UnknownError(absl::StrCat( - call_target_name, " XLA extension have thrown an exception.")); - } - - if (auto message = CustomCallStatusGetMessage(&custom_call_status)) { - return absl::InternalError(message.value()); - } else { - return absl::OkStatus(); - } - } - - return absl::InvalidArgumentError( - absl::StrFormat("Unsupported custom call API version: %d", api_version)); -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - XlaCustomCall, FunctionWrapper(), checks, - runtime::CustomCall::Bind("xla.gpu.memcpy") - .UserData() - .UserData() - .Arg() // args - .Attr("call_target_name") - .Attr("api_version") - .Attr("backend_config")); - -void RegisterXlaClassicCustomCalls( - runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.custom_call", XlaCustomCall); -} - -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/custom_call_registry.cc b/third_party/xla/xla/service/gpu/runtime/custom_call_registry.cc deleted file mode 100644 index 636c858a7dd069..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/custom_call_registry.cc +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/custom_call_registry.h" - -#include -#include - -#include "xla/runtime/custom_call_registry.h" - -namespace xla::gpu { - -using DirectCustomCallRegistration = - std::function; - -static std::vector* -DirectCustomCallRegistrations() { - static auto* storage = new std::vector(); - return storage; -} - -void AddDirectCustomCallRegistration( - DirectCustomCallRegistration registration) { - DirectCustomCallRegistrations()->push_back(registration); -} - -// Registers all direct custom calls with the given registry. -void RegisterDirectCustomCalls(runtime::DirectCustomCallRegistry& registry) { - for (auto& registration : *DirectCustomCallRegistrations()) { - registration(registry); - } -} - -} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/runtime/custom_call_registry.h b/third_party/xla/xla/service/gpu/runtime/custom_call_registry.h deleted file mode 100644 index 20c761de1b690d..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/custom_call_registry.h +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_CUSTOM_CALL_REGISTRY_H_ -#define XLA_SERVICE_GPU_RUNTIME_CUSTOM_CALL_REGISTRY_H_ - -#include - -#include "xla/runtime/custom_call_registry.h" - -namespace xla::gpu { - -// This is a static custom call registry for XLA:GPU executables. XLA runtime -// custom calls must not be confused with a "classic" custom calls, they are -// an internal implementation of XLA runtime (and XLA:GPU by extension), and -// do not provide stable ABI across dynamically loaded libraries. XLA runtime -// custom calls must be statically linked. -// -// XLA:FFI is the planned mechanism for registering "custom calls" via a stable -// C ABI for internal and external uses, however it's under construction. -// -// See more XLA runtime and XLA FFI plans here: -// https://docs.google.com/document/d/1XHzJyfq-ZFn9WHoKe4o_urnwS991dFHgWoNRboBK_3I/edit#bookmark=id.696pyshem503 -// -// XLA:FFI will become an official "external custom call" mechanism for XLA:GPU -// and XLA:CPU some time in 2024. - -// Adds a direct custom call registration function to a static registry. -void AddDirectCustomCallRegistration( - std::function registration); - -// Registers all direct custom calls with the given registry. -void RegisterDirectCustomCalls(runtime::DirectCustomCallRegistry& registry); - -//===----------------------------------------------------------------------===// -// Helper macro to define a static module registration. -//===----------------------------------------------------------------------===// - -#define XLA_GPU_REGISTER_RUNTIME_CUSTOM_CALL(FUNC) \ - XLA_GPU_REGISTER_RUNTIME_CUSTOM_CALL_IMPL(FUNC, __COUNTER__) - -#define XLA_GPU_REGISTER_RUNTIME_CUSTOM_CALL_IMPL(FUNC, N) \ - static bool xla_gpu_runtime_custom_call_##N##_registered_ = []() { \ - ::xla::gpu::AddDirectCustomCallRegistration(FUNC); \ - return true; \ - }() - -} // namespace xla::gpu - -#endif // XLA_SERVICE_GPU_RUNTIME_CUSTOM_CALL_REGISTRY_H_ diff --git a/third_party/xla/xla/service/gpu/runtime3/custom_call_thunk.cc b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc similarity index 99% rename from third_party/xla/xla/service/gpu/runtime3/custom_call_thunk.cc rename to third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc index 6329f6764e90c4..28a7dcebfc1dfa 100644 --- a/third_party/xla/xla/service/gpu/runtime3/custom_call_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime3/custom_call_thunk.h" +#include "xla/service/gpu/runtime/custom_call_thunk.h" #include #include diff --git a/third_party/xla/xla/service/gpu/runtime3/custom_call_thunk.h b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.h similarity index 96% rename from third_party/xla/xla/service/gpu/runtime3/custom_call_thunk.h rename to third_party/xla/xla/service/gpu/runtime/custom_call_thunk.h index 55134f59232f89..5fa1dce32842fe 100644 --- a/third_party/xla/xla/service/gpu/runtime3/custom_call_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME3_CUSTOM_CALL_THUNK_H_ -#define XLA_SERVICE_GPU_RUNTIME3_CUSTOM_CALL_THUNK_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_CUSTOM_CALL_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_CUSTOM_CALL_THUNK_H_ #include #include @@ -126,4 +126,4 @@ class CustomCallThunk : public Thunk { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_RUNTIME3_CUSTOM_CALL_THUNK_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_CUSTOM_CALL_THUNK_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/executable.cc b/third_party/xla/xla/service/gpu/runtime/executable.cc deleted file mode 100644 index 814bafa8df5013..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/executable.cc +++ /dev/null @@ -1,522 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/executable.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/inlined_vector.h" -#include "absl/status/status.h" -#include "xla/mlir/runtime/transforms/compilation_pipeline_gpu.h" -#include "xla/runtime/executable.h" -#include "xla/runtime/jit_executable.h" -#include "xla/service/gpu/non_atomically_upgradeable_rw_lock.h" -#include "xla/service/gpu/runtime/cholesky.h" -#include "xla/service/gpu/runtime/concurrent_region.h" -#include "xla/service/gpu/runtime/conv.h" -#include "xla/service/gpu/runtime/conv_reorder.h" -#include "xla/service/gpu/runtime/cub_sort.h" -#include "xla/service/gpu/runtime/custom_call.h" -#include "xla/service/gpu/runtime/custom_call_registry.h" -#include "xla/service/gpu/runtime/fft.h" -#include "xla/service/gpu/runtime/fused_attention.h" -#include "xla/service/gpu/runtime/gemm.h" -#include "xla/service/gpu/runtime/gpublas_lt_matmul.h" -#include "xla/service/gpu/runtime/graph_launch.h" -#include "xla/service/gpu/runtime/io_feed.h" -#include "xla/service/gpu/runtime/memcpy.h" -#include "xla/service/gpu/runtime/memset.h" -#include "xla/service/gpu/runtime/norm.h" -#include "xla/service/gpu/runtime/send_recv.h" -#include "xla/service/gpu/runtime/stream_synchronization.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/gpu/runtime/topk.h" -#include "xla/service/gpu/runtime/tracing.h" -#include "xla/service/gpu/thunk.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/service/stream_pool.h" -#include "xla/statusor.h" -#include "xla/stream_executor/stream.h" -#include "tsl/protobuf/dnn.pb.h" - -namespace xla { -namespace gpu { - -using ::xla::runtime::CustomCallAttrEncodingSet; -using ::xla::runtime::DirectCustomCallRegistry; -using ::xla::runtime::Executable; -using ::xla::runtime::JitExecutable; -using ::xla::runtime::Tagged; -using ::xla::runtime::TypeIDNameRegistry; - -using ::xla::runtime::CustomCall; -using ::xla::runtime::DiagnosticEngine; -using ::xla::runtime::ExportModules; - -void RegisterXlaGpuRuntimeCustomCalls(DirectCustomCallRegistry& registry) { - // Register custom calls from a static XLA:GPU registry. - RegisterDirectCustomCalls(registry); - - // Register builtin XLA:GPU custom calls (aka GPU runtime). - RegisterKernelLaunchCustomCalls(registry); - RegisterTracingCustomCalls(registry); - RegisterFftCustomCalls(registry); - RegisterCholeskyCustomCalls(registry); - RegisterCollectiveCustomCalls(registry); - RegisterGemmCustomCalls(registry); - RegisterConvCustomCalls(registry); - RegisterConvReorderCustomCalls(registry); - RegisterMemcpyCustomCalls(registry); - RegisterIoFeedCustomCalls(registry); - RegisterMemsetCustomCalls(registry); - RegisterSendRecvCustomCalls(registry); - -#if GOOGLE_CUDA || TF_HIPBLASLT - RegisterMatmulCustomCalls(registry); -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#if GOOGLE_CUDA - RegisterNormCustomCalls(registry); - RegisterFusedAttentionCustomCalls(registry); - RegisterFusedAttentionBackwardCustomCalls(registry); -#endif // GOOGLE_CUDA -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - // Graph launch kernels depend on Cuda Graph API. - RegisterGraphLaunchCustomCalls(registry); - RegisterConcurrentRegionCustomCalls(registry); - RegisterStreamSynchronizationCustomCalls(registry); - RegisterCubSortCustomCalls(registry); - RegisterXlaClassicCustomCalls(registry); - RegisterTopkCustomCall(registry); -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM -} - -void RegisterXlaGpuTypeIdNames(TypeIDNameRegistry& registry) { - registry.Register>( - "__type_id_se_dnn_activation"); - registry.Register>( - "__type_id_dot_dimension_numbers"); - registry.Register>("__type_id_se_fft_type"); - - RegisterTracingTypeIdNames(registry); - RegisterConvTypeIdNames(registry); - RegisterSendRecvTypeIdNames(registry); - -#if GOOGLE_CUDA || TF_HIPBLASLT - registry.Register>( - "__type_id_se_gpublas_lt_epilogue"); - RegisterFusedAttentionTypeIdNames(registry); - RegisterNormTypeIdNames(registry); -#endif // GOOGLE_CUDA || TF_HIPBLASLT -} - -void RegisterXlaGpuAttrEncoding(CustomCallAttrEncodingSet& encoding) { - PopulateConvAttrEncoding(encoding); - PopulateFftAttrEncoding(encoding); - PopulateDotDimsAttrEncoding(encoding); - PopulateSendRecvAttrEncoding(encoding); - -#if GOOGLE_CUDA || TF_HIPBLASLT - PopulateCublasLtMatmulAttrEncoding(encoding); - PopulateFusedAttentionAlgorithmConfigAttrEncoding(encoding); - PopulateFusedAttentionForwardDAGSignatureAttrEncoding(encoding); - PopulateFusedAttentionBackwardDAGSignatureAttrEncoding(encoding); - PopulateNormAlgorithmConfigAttrEncoding(encoding); -#endif // GOOGLE_CUDA || TF_HIPBLASLT -} - -//===----------------------------------------------------------------------===// - -// Executable can have only one "main" function and only graph capture function. -static int64_t GetNumGraphs(const runtime::Executable& executable) { - return executable.num_functions() - 1; -} - -GpuRuntimeExecutable::GpuRuntimeExecutable( - std::string module_name, std::vector buffer_sizes, - std::vector> allocation_indices, - std::unique_ptr jit_executable, DebugOptions debug_options, - ModulesState modules_state) - : module_name_(std::move(module_name)), - buffer_sizes_(std::move(buffer_sizes)), - allocation_indices_(std::move(allocation_indices)), - executable_(std::move(jit_executable)), - debug_options_(std::move(debug_options)), -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - graph_instances_(module_name_, GetNumGraphs(executable())), -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - modules_state_(std::move(modules_state)) { - ExportModules(dynamic_custom_calls_); // export runtime modules -} - -GpuRuntimeExecutable::GpuRuntimeExecutable( - std::string module_name, std::vector buffer_sizes, - std::vector> allocation_indices, - std::unique_ptr aot_executable, DebugOptions debug_options, - ModulesState modules_state) - : module_name_(std::move(module_name)), - buffer_sizes_(std::move(buffer_sizes)), - allocation_indices_(std::move(allocation_indices)), - executable_(std::move(aot_executable)), - debug_options_(std::move(debug_options)), -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - graph_instances_(module_name_, GetNumGraphs(executable())), -#endif // GOOGL_CUDA || TENSORFLOW_USE_ROCM - modules_state_(std::move(modules_state)) { - ExportModules(dynamic_custom_calls_); // export runtime modules -} - -//===----------------------------------------------------------------------===// -// Compile Xla program lowered to runtime dialects to Gpu runtime executable. -//===----------------------------------------------------------------------===// - -/*static*/ absl::StatusOr> -GpuRuntimeExecutable::Create(std::string module_name, - std::unique_ptr program) { - // Options for the default XLA Runtime compilation pipeline. - runtime::CompilationPipelineOptions copts; - - // Populate mapping from XLA (SE) enums/structs type id to symbol names. - copts.populate_type_id_names = RegisterXlaGpuTypeIdNames; - - // For passing LMHLO attributes as XLA (SE) enums/structs to custom calls. - copts.populate_attr_encodings = RegisterXlaGpuAttrEncoding; - - // Options for constructing XLA runtime JitExecutable. - JitExecutable::Options opts; - opts.specialization = JitExecutable::Specialization::kDisabled; - opts.compiler.verification_level = - program->debug_options.xla_gpu_llvm_verification_level(); - opts.compiler.register_dialects = - runtime::RegisterDefaultXlaGpuRuntimeDialects; - - // Register XLA Gpu runtime custom calls with the linker. - opts.compiler.symbols_binding = runtime::ToSymbolsBinding( - RegisterXlaGpuRuntimeCustomCalls, RegisterXlaGpuTypeIdNames); - - // We just use the default compilation pipeline provided by the XLA runtime. - // Alternatively instead of having a separate Xla Runtime program (LMHLO - // lowered to canonical dialects), we can assemble a pipeline that will - // compile starting from the LMHLO dialect. However this intermediate step - // helps with debugging, by materializing IR with XLA runtime custom calls. - opts.compiler.create_compilation_pipeline = - [copts](xla::runtime::PassManager& passes) { - runtime::CreateDefaultXlaGpuRuntimeCompilationPipeline(passes, copts); - return absl::OkStatus(); - }; - - // Do not run expensive optimization passes because we do not expect any - // non-trivial host code in XLA:GPU host executables. - opts.compiler.jit_code_opt_level = llvm::CodeGenOptLevel::None; - - // Instantiate new JitExecutable from the MLIR source. - auto jit_executable = - JitExecutable::Instantiate(program->module, program->entry_point, opts); - if (!jit_executable.ok()) - return Internal("Failed to compile XLA Runtime program: %s", - jit_executable.status().message()); - - // Instantiate state for all registered runtime modules. - auto modules_state = ModulesState::Instantiate(); - if (!modules_state.ok()) - return Internal("Failed to instantiate modules state: %s", - modules_state.status().message()); - - return std::unique_ptr(new GpuRuntimeExecutable( - std::move(module_name), std::move(program->buffer_sizes), - std::move(program->allocation_indices), - std::make_unique(std::move(*jit_executable)), - std::move(program->debug_options), std::move(*modules_state))); -} - -//===----------------------------------------------------------------------===// -// Constructs Gpu runtime executable from AOT compiled runtime artifact. -//===----------------------------------------------------------------------===// - -/*static*/ absl::StatusOr> -GpuRuntimeExecutable::Create( - std::string module_name, std::vector buffer_sizes, - std::vector> allocation_indices, Executable executable, - DebugOptions debug_options) { - // Instantiate state for all registered runtime modules. - auto modules_state = ModulesState::Instantiate(); - if (!modules_state.ok()) - return Internal("Failed to instantiate modules state: %s", - modules_state.status().message()); - - return std::unique_ptr(new GpuRuntimeExecutable( - std::move(module_name), std::move(buffer_sizes), - std::move(allocation_indices), - std::make_unique(std::move(executable)), - std::move(debug_options), std::move(*modules_state))); -} - -//===----------------------------------------------------------------------===// -// Executes with the given buffer arguments. -//===----------------------------------------------------------------------===// - -static runtime::AsyncTaskRunner* NoAsyncTaskRunner() { - return reinterpret_cast(0XDEADBEEF); -} - -// TODO(ezhulenev): We rely on implementation details of passing memrefs to the -// compiled kernel. We should have a nicer API to do this, without creating a -// vector of temporary MemrefDesc for passing operands. -static void InitializeCallFrame(runtime::Executable::CallFrame& call_frame, - const BufferAllocations& buffer_allocations, - absl::Span buffer_sizes, - llvm::SmallVectorImpl& ptrs) { - size_t num_allocations = buffer_allocations.size(); - assert(ptrs.empty() && "pointers storage must be empty"); - ptrs.resize_for_overwrite(num_allocations); - - // Each buffer allocation passed as 1d memref to the compiled function: - // {basePtr, dataPtr, offset, [sizes, ...], [strides, ...]} - size_t num_args_ptrs = 1 + num_allocations * 5; - call_frame.args.resize_for_overwrite(num_args_ptrs); - - // Pass pointers to these constants as a memref offset and stride. - static int64_t zero = 0; - static int64_t one = 1; - void* offset = &zero; - void* stride = &one; - - // Add a placeholder for the kernel context as the first argument. - call_frame.args[0] = nullptr; - - // Initialize arguments for the buffer operands. - for (unsigned i = 0; i < num_allocations; ++i) { - void* data = &(ptrs[i] = buffer_allocations.GetDeviceAddress(i).opaque()); - void* size = const_cast(&buffer_sizes[i]); - unsigned idx = 1 + i * 5; - call_frame.args[idx + 0] = data; - call_frame.args[idx + 1] = data; - call_frame.args[idx + 2] = offset; - call_frame.args[idx + 3] = size; - call_frame.args[idx + 4] = stride; - } -} - -absl::Status GpuRuntimeExecutable::Execute( - const ServiceExecutableRunOptions* run_options, const std::string& asm_text, - const std::vector& binary, - const BufferAllocations& buffer_allocations, - NonAtomicallyUpgradeableRWLock& gpu_lock, - const BufferAllocation* temp_alloc) { - // We pass a pointer to the executable through UserData, so that we can - // get access to other exported functions from custom call handlers. - runtime::Executable& executable = this->executable(); - - // Pack buffer allocations as executable arguments. It is guaranteed that - // the compiled function will make a copy of all arguments and will write all - // results after the call to `Execute` completes, so it is safe to keep them - // on the stack. - runtime::Executable::CallFrame call_frame; - - llvm::SmallVector ptrs; // storage for device address pointers - InitializeCallFrame(call_frame, buffer_allocations, buffer_sizes_, ptrs); - - // Check that initialized call frame is compatible with the executable - // entry point signature, otherwise compiled executable can read memory out of - // arguments bounds and crash with a segfault. - const runtime::FunctionType& signature = executable.signature(); - if (signature.num_operands() != buffer_allocations.size()) - return Internal("Expected %d arguments but got %d buffer allocations", - signature.num_operands(), buffer_allocations.size()); - - for (unsigned i = 0; i < executable.signature().num_operands(); ++i) { - auto* memref = llvm::dyn_cast(signature.operand(i)); - if (!memref) return InvalidArgument("Expected memref as %d-th argument", i); - - if (memref->rank() != 1 || memref->sizes()[0] != buffer_sizes_[i]) - return InvalidArgument("Expected a buffer of size %d but got %d", - memref->sizes()[0], buffer_sizes_[i]); - } - - // XLA Runtime executables do not return any values. - runtime::NoResultConverter converter; - - // Get the async communications stream for async collectives. - se::StreamExecutor* executor = run_options->stream()->parent(); - se::StreamPriority stream_priority = se::StreamPriority::Default; - if (debug_options_.xla_gpu_enable_highest_priority_async_stream()) { - stream_priority = se::StreamPriority::Highest; - } - - // Create the needed streams to support NcclCollectiveThunk. - // - // Calling BorrowStream multiple times doesn't work as intended, see - // b/293945751. - absl::InlinedVector async_comm_streams( - kAsyncStreamTotal, nullptr); - absl::StatusOr> streams = - run_options->BorrowStreams(executor->device_ordinal(), kAsyncStreamTotal, - stream_priority); - if (streams.ok()) { - for (int64_t i = 0; i < kAsyncStreamTotal; ++i) { - async_comm_streams[i] = streams->at(i).get(); - } - } - - // Async Collectives support and Send/Recv events instantiated for each Gpu - // executable run, so that concurrent executions can run independently using a - // separate set of events for communication. - AsyncCollectivesSupport async_collectives(async_comm_streams); - SendRecvEvents send_recv_events; - - // Always pass in the temp buffer, even if it is null, to accommodate the - // 0-sized buffer corner case. - se::DeviceMemoryBase temp_buffer; - if (temp_alloc) - temp_buffer = buffer_allocations.GetDeviceAddress(temp_alloc->index()); - - // State cached separately for each stream executor. - StreamExecutorKernels::Snapshot kernels = gpu_kernels_(executor)->snapshot(); - StreamExecutorConvRunners::Snapshot conv_runners = - conv_runners_(executor)->snapshot(); - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - std::shared_ptr executor_graphs = - graph_instances_(executor); - - StreamExecutorGraphInstances::Snapshot graph_instances = - executor_graphs->snapshot(); - CapturedFunctionExecutionCount::Snapshot execution_count = - captured_function_counts_(executor)->snapshot(); -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - - // Kernels in concurrent regions should be launched on borrowed stream, so - // that the cuda graph won't record dependencies between kernels. - // This state stores if the kernel being run is in a concurrent region and - // the borrowed streams for executing kernels in concurrent regions. - ConcurrentRegionStatus concurrent_region_status(run_options); - - // State cached globally for gpu executable. - GemmConfigs::Snapshot gemm_configs = gemm_configs_.snapshot(); - FftPlans::Snapshot fft_plans = fft_plans_.snapshot(); - -#if GOOGLE_CUDA || TF_HIPBLASLT - MatmulPlans::Snapshot matmul_plans = gpublas_lt_matmul_plans_.snapshot(); -#endif - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - StreamExecutorNormRunners::Snapshot norm_runners = - norm_runners_(executor)->snapshot(); - StreamExecutorFusedAttentionRunners::Snapshot fused_attention_runners = - fused_attention_runners_(executor)->snapshot(); - StreamExecutorFusedAttentionBackwardRunners::Snapshot - fused_attention_backward_runners = - fused_attention_backward_runners_(executor)->snapshot(); -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - - // Pass auxiliary data to the custom call handlers. - runtime::CustomCall::UserData user_data( - run_options, &executable, &debug_options_, &temp_buffer, &asm_text, - &binary, &kernels, &gemm_configs, &conv_runners, &collectives_, - &fft_plans, &send_recv_events, &gpu_lock, -#if GOOGLE_CUDA || TF_HIPBLASLT - &matmul_plans, -#endif -#if GOOGLE_CUDA - // Auxiliary data that is available only if compiled with CUDA support - // only. - &norm_runners, &fused_attention_runners, - &fused_attention_backward_runners, -#endif // GOOGLE_CUDA -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - &graph_instances, &execution_count, -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - &concurrent_region_status, - // Null pointer will be interpreted as an absence of async collectives - // support and custom calls will safely return an error. - async_collectives.async_comm_stream(AsyncStreamKind::kCollective) - ? &async_collectives - : nullptr); - - // Initialize state required for running functions from registered modules. - auto state_ref = modules_state_.InitializeUserData(user_data); - if (!state_ref.ok()) - return Internal("Failed to initialize runtime modules state: %s", - state_ref.status().message()); - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - // Instantiate all CUDA graphs before executing the main function. - if (debug_options_.xla_gpu_graph_num_runs_to_instantiate() < 0 && - !graph_instances_.InstantiatedAllGraphs(run_options, executable)) { - if (auto instantiated = graph_instances_.InstantiateAllGraphs( - run_options, executable, user_data, buffer_allocations, - buffer_sizes_, allocation_indices_, - debug_options_.xla_gpu_graph_eviction_timeout_seconds()); - !instantiated.ok()) { - return Internal("Failed to instantiate GPU graphs: %s", - instantiated.message()); - } - } -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - - // Collect all emitted diagnostic messages. - std::string diagnostic; - runtime::DiagnosticEngine diagnostic_engine; - AppendDiagnosticToString(diagnostic_engine, &diagnostic, true); - - // Prepare options for executing XLA Runtime program. - runtime::Executable::ExecuteOpts opts; - opts.async_task_runner = NoAsyncTaskRunner(); - opts.custom_call_data = &user_data; - opts.diagnostic_engine = &diagnostic_engine; - opts.custom_call_registry = &dynamic_custom_calls_; - - // Execute with the prepared call frame. - executable.Execute(call_frame, opts); - - if (auto st = executable.ReturnResults(converter, &call_frame); !st.ok()) { - return Internal("Failed to execute XLA Runtime executable: %s%s%s.", - st.message(), diagnostic.empty() ? "" : ": ", - diagnostic); - } - - return absl::OkStatus(); -} - -//===----------------------------------------------------------------------===// - -const Executable& GpuRuntimeExecutable::executable() const { - if (auto* jit = std::get_if>(&executable_)) { - return *(*jit)->DefaultExecutable(); - } - return *std::get>(executable_); -} - -absl::StatusOr GpuRuntimeExecutable::GetObjFile() const { - if (auto obj_file = executable().obj_file()) - return std::string_view(obj_file->getBuffer()); - - return Internal("gpu runtime executable didn't save the obj file"); -} - -absl::StatusOr GpuRuntimeExecutable::GetMlirModule() const { - const auto* jit = std::get_if>(&executable_); - if (!jit) return Internal("MLIR module is not available"); - - return (*jit)->mlir_module(); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/executable.h b/third_party/xla/xla/service/gpu/runtime/executable.h deleted file mode 100644 index 57c54186697524..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/executable.h +++ /dev/null @@ -1,210 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_EXECUTABLE_H_ -#define XLA_SERVICE_GPU_RUNTIME_EXECUTABLE_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "xla/runtime/executable.h" -#include "xla/runtime/jit_executable.h" -#include "xla/runtime/module_registry.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/non_atomically_upgradeable_rw_lock.h" -#include "xla/service/gpu/runtime/collectives.h" -#include "xla/service/gpu/runtime/conv.h" -#include "xla/service/gpu/runtime/fft.h" -#include "xla/service/gpu/runtime/fused_attention.h" -#include "xla/service/gpu/runtime/gemm.h" -#include "xla/service/gpu/runtime/gpublas_lt_matmul.h" -#include "xla/service/gpu/runtime/graph_launch.h" -#include "xla/service/gpu/runtime/kernel_launch.h" -#include "xla/service/gpu/runtime/norm.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/xla.pb.h" - -namespace xla { -namespace gpu { - -// Register custom calls implementing Xla Gpu runtime. -void RegisterXlaGpuRuntimeCustomCalls( - runtime::DirectCustomCallRegistry& registry); - -// Register mapping from XLA (SE) enums/structs type ids to symbol names. -void RegisterXlaGpuTypeIdNames(runtime::TypeIDNameRegistry& registry); - -// Register encoding for (L)MHLO attributes required by the runtime functions. -void RegisterXlaGpuAttrEncoding(runtime::CustomCallAttrEncodingSet& encoding); - -// Xla Gpu program lowered to the Xla runtime dialects. Gpu runtime executable -// jit-compiles this program to an executable artifact (via lowering to LLVM). -// -// We have this program as an intermediate step between lowering from HLO to -// runtime executable to be able to introspect the compilation process. Once we -// have this program, the Xla gpu compiler job is done, and lowering to LLVM is -// the responsibility of backend-agnostic Xla runtime passes. This is the last -// stage when IR is still at a fairly high level of abstraction and has a lot of -// Gpu specific details in it. -struct GpuRuntimeProgram { - GpuRuntimeProgram(std::string entry_point, std::string module, - std::vector buffer_sizes, - std::vector> allocation_indices, - DebugOptions debug_options) - : entry_point(std::move(entry_point)), - module(std::move(module)), - buffer_sizes(std::move(buffer_sizes)), - allocation_indices(std::move(allocation_indices)), - debug_options(std::move(debug_options)) {} - - std::string entry_point; - std::string module; - std::vector buffer_sizes; - std::vector> allocation_indices; - DebugOptions debug_options; -}; - -// Gpu runtime executable encapsulates the Xla runtime executable compiled from -// an Xla program and owns all the state required for running it (e.g. it owns -// various caches required for performance). -// -// TODO(ezhulenev): Once thunks are removed from Xla, it might make sense to -// merge this executable into GpuExecutable. Today we keep it separate to manage -// the complexity of mixing two execution modes in the same file. GpuExecutable -// provides an API at XLA level of abstraction (streams and buffers), and this -// executable provides a lower level API exposing some of the implementation -// details. -class GpuRuntimeExecutable { - using ModulesState = ::xla::runtime::ModulesState; - - public: - // Creates GpuRuntimeExecutable from the Xla Gpu Program. - static absl::StatusOr> Create( - std::string module_name, std::unique_ptr program); - - // Creates GpuRuntimeExecutable from the AOT compiled binary. - static absl::StatusOr> Create( - std::string module_name, std::vector buffer_sizes, - std::vector> allocation_indices, - runtime::Executable executable, DebugOptions debug_options); - - // Executes entry function with the given buffer arguments. - absl::Status Execute(const ServiceExecutableRunOptions* run_options, - const std::string& asm_text, - const std::vector& binary, - const BufferAllocations& buffer_allocations, - NonAtomicallyUpgradeableRWLock& gpu_lock, - const BufferAllocation* temp_alloc = nullptr); - - // Returns object file behind the runtime executable. This object file can - // be exported and loaded later to instantiate another executable. - absl::StatusOr GetObjFile() const; - - // Returns MLIR module behind this executable if it is available. - absl::StatusOr GetMlirModule() const; - - std::string_view module_name() const { return module_name_; } - - private: - GpuRuntimeExecutable(std::string module_name, - std::vector buffer_sizes, - std::vector> allocation_indices, - std::unique_ptr jit_executable, - DebugOptions debug_options, ModulesState modules_state); - - GpuRuntimeExecutable(std::string module_name, - std::vector buffer_sizes, - std::vector> allocation_indices, - std::unique_ptr aot_executable, - DebugOptions debug_options, ModulesState modules_state); - - std::string module_name_; - - // Depending on the state of `executable_` returns a reference to active - // Xla runtime executable. - runtime::Executable& executable() { - return const_cast( - const_cast(this)->executable()); - } - const runtime::Executable& executable() const; - - std::vector buffer_sizes_; - - // `rt.allocation_index` attributes for all exported functions. Indexed by - // function ordinal. - std::vector> allocation_indices_; - - // In JIT compilation mode `JitExecutable` is used. In AOT compilation mode - // `Executable` is used. - std::variant, - std::unique_ptr> - executable_; - - const DebugOptions debug_options_; - - // Keep gpu kernels loaded by this executable. - GpuExecutableKernels gpu_kernels_; - - // Keep gemm configs for all gemm operation in the program. - GemmConfigs gemm_configs_; - - // Keep a cache for conv configs for all conv operations in the program. - ConvRunners conv_runners_; - - // Keep a cache for fused norm configs for all fused norm operations in the - // program. - NormRunnerStates norm_runners_; - - // Keep a cache for fused_dot_attention configs for all fused_dot_attention - // operations in the program. - FusedAttentionRunners fused_attention_runners_; - - // Keep a cache for fused_dot_attention configs for all fused_dot_attention - // backward - // operations in the program. - FusedAttentionBackwardRunners fused_attention_backward_runners_; - - // Support for running collective operations. - CollectivesSupport collectives_; - - // Keep a cache of fft plans for all FFT operations in the program. - FftPlans fft_plans_; - -#if GOOGLE_CUDA || TF_HIPBLASLT // Keep matmul execution plans. - MatmulPlans gpublas_lt_matmul_plans_; -#endif - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - // Keep captured and instantiated GPU graphs instances. - GraphInstances graph_instances_; - CapturedFunctionExecutionCounts captured_function_counts_; -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - - // Keep an executable state for all registered runtime modules. - ModulesState modules_state_; - - // Dynamic custom calls exported from XLA runtime modules (and FFI modules). - runtime::DynamicCustomCallRegistry dynamic_custom_calls_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_EXECUTABLE_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/fft.cc b/third_party/xla/xla/service/gpu/runtime/fft.cc deleted file mode 100644 index 11c37aeb363d3c..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/fft.cc +++ /dev/null @@ -1,133 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/fft.h" - -#include - -#include "xla/mlir/runtime/transforms/custom_call_encoding.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/runtime/state.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/gpu/runtime3/fft_thunk.h" -#include "xla/stream_executor/fft.h" - -namespace xla { - -using xla::runtime::CustomCall; -using xla::runtime::State; -using xla::runtime::StridedMemrefView; - -//===----------------------------------------------------------------------===// -// Register FFT attributes decoding with the Xla runtime. -//===----------------------------------------------------------------------===// - -namespace runtime { - -XLA_RUNTIME_REGISTER_ENUM_ATTR_DECODING(se::fft::Type); - -} // namespace runtime - -//===----------------------------------------------------------------------===// -// Encoding from MHLO attributes to Xla runtime aggregate attributes. -//===----------------------------------------------------------------------===// - -namespace gpu { - -namespace mhlo = ::mlir::mhlo; - -static se::fft::Type ConvertFftType(mhlo::FftType type) { - switch (type) { - case mhlo::FftType::FFT: - return se::fft::Type::kC2CForward; - case mhlo::FftType::IFFT: - return se::fft::Type::kC2CInverse; - case mhlo::FftType::RFFT: - return se::fft::Type::kR2C; - case mhlo::FftType::IRFFT: - return se::fft::Type::kC2R; - default: - return se::fft::Type::kInvalid; - } -} - -void PopulateFftAttrEncoding(runtime::CustomCallAttrEncodingSet& encoding) { - encoding.Add>(ConvertFftType); -} - -//===----------------------------------------------------------------------===// -// FFT custom call implementation. -//===----------------------------------------------------------------------===// - -static absl::Status FftImpl(const ServiceExecutableRunOptions* run_options, - State> state, - StridedMemrefView input, StridedMemrefView output, - absl::Span fft_length, - se::fft::Type fft_type) { - se::Stream* stream = run_options->stream(); - se::StreamExecutor* executor = stream->parent(); - - if (input.dtype == PrimitiveType::F64 || input.dtype == PrimitiveType::C128) { - // Adjust FFT type to reflect double precision. - switch (fft_type) { - case se::fft::Type::kC2CForward: - fft_type = se::fft::Type::kZ2ZForward; - break; - case se::fft::Type::kC2CInverse: - fft_type = se::fft::Type::kZ2ZInverse; - break; - case se::fft::Type::kR2C: - fft_type = se::fft::Type::kD2Z; - break; - case se::fft::Type::kC2R: - fft_type = se::fft::Type::kZ2D; - break; - default: - return absl::InvalidArgumentError("Unsupported FFT type"); - } - } - - TF_ASSIGN_OR_RETURN( - std::unique_ptr * fft_plan_cache, - state.GetOrCreate([]() -> absl::StatusOr> { - return std::make_unique(); - })); - - return RunFft(GetDeviceAddress(input), ToShape(input), - GetDeviceAddress(output), ToShape(output), fft_type, fft_length, - executor->device_ordinal(), fft_plan_cache->get(), stream, - run_options->allocator()); -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - Fft, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.fft") - .UserData() - .State>("uid") - .Arg() // input - .Arg() // output - .Attr>("fft_length") - .Attr("fft_type")); - -//===----------------------------------------------------------------------===// - -void RegisterFftCustomCalls(runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.fft", Fft); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/fft.h b/third_party/xla/xla/service/gpu/runtime/fft.h deleted file mode 100644 index e57572c043607d..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/fft.h +++ /dev/null @@ -1,40 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_FFT_H_ -#define XLA_SERVICE_GPU_RUNTIME_FFT_H_ - -#include - -#include "xla/mlir/runtime/transforms/custom_call_encoding.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/service/gpu/runtime3/fft_thunk.h" - -namespace xla { -namespace gpu { - -// Registers XLA Gpu runtime fft custom calls. -void RegisterFftCustomCalls(runtime::DirectCustomCallRegistry& registry); - -// Adds attributes encoding set for fft custom calls -void PopulateFftAttrEncoding(runtime::CustomCallAttrEncodingSet& encoding); - -// Keep FftPlanCache for all FFT instances in the executable. -class FftPlans : public runtime::StateVector> {}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_FFT_H_ diff --git a/third_party/xla/xla/service/gpu/runtime3/fft_thunk.cc b/third_party/xla/xla/service/gpu/runtime/fft_thunk.cc similarity index 76% rename from third_party/xla/xla/service/gpu/runtime3/fft_thunk.cc rename to third_party/xla/xla/service/gpu/runtime/fft_thunk.cc index 74195445436ede..728c36752aeed5 100644 --- a/third_party/xla/xla/service/gpu/runtime3/fft_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/fft_thunk.cc @@ -13,11 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime3/fft_thunk.h" +#include "xla/service/gpu/runtime/fft_thunk.h" #include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "xla/stream_executor/scratch_allocator.h" @@ -66,6 +67,22 @@ std::string FftTypeToString(se::fft::Type type) { } } +absl::StatusOr GetBlas( + se::Stream* stream) { + auto blas = stream->parent()->AsBlas(); + if (blas == nullptr) { + return absl::InternalError("Unable to get Blas support"); + } + return blas; +} + +absl::StatusOr GetFft(se::Stream* stream) { + auto fft = stream->parent()->AsFft(); + if (fft == nullptr) { + return absl::InternalError("Unable to get fft support"); + } + return fft; +} } // namespace FftThunk::FftThunk(ThunkInfo thunk_info, FftType fft_type, @@ -113,6 +130,7 @@ absl::Status RunFft(se::DeviceMemoryBase input, const Shape& input_shape, // protect each plan with a mutex. absl::MutexLock lock(&fft_plan_ptr->mu); std::unique_ptr& fft_plan = fft_plan_ptr->plan; + TF_ASSIGN_OR_RETURN(auto fft, GetFft(stream)); if (fft_plan == nullptr) { const int64_t fft_rank = fft_len.size(); CHECK_LE(fft_rank, 3); @@ -138,7 +156,7 @@ absl::Status RunFft(se::DeviceMemoryBase input, const Shape& input_shape, } constexpr bool kInPlaceFft = false; - fft_plan = stream->parent()->AsFft()->CreateBatchedPlanWithScratchAllocator( + fft_plan = fft->CreateBatchedPlanWithScratchAllocator( stream, fft_rank, fft_length, input_embed, input_stride, input_distance, output_embed, output_stride, output_distance, fft_type, kInPlaceFft, batch_size, &scratch_allocator); @@ -146,8 +164,8 @@ absl::Status RunFft(se::DeviceMemoryBase input, const Shape& input_shape, << "Failed to create cuFFT batched plan with scratch allocator"; fft_plan_ptr->scale_factor = 1.0f / output_distance; } else { - stream->parent()->AsFft()->UpdatePlanWithScratchAllocator( - stream, fft_plan.get(), &scratch_allocator); + fft->UpdatePlanWithScratchAllocator(stream, fft_plan.get(), + &scratch_allocator); } float scale_factor = fft_plan_ptr->scale_factor; @@ -157,81 +175,72 @@ absl::Status RunFft(se::DeviceMemoryBase input, const Shape& input_shape, case se::fft::Type::kC2CForward: { se::DeviceMemory input_data(input); se::DeviceMemory output_data(output); - launch_ok = - stream->ThenFft(fft_plan.get(), input_data, &output_data).ok(); + launch_ok = fft->DoFft(stream, fft_plan.get(), input_data, &output_data); break; } case se::fft::Type::kZ2ZForward: { se::DeviceMemory input_data(input); se::DeviceMemory output_data(output); - launch_ok = - stream->ThenFft(fft_plan.get(), input_data, &output_data).ok(); + launch_ok = fft->DoFft(stream, fft_plan.get(), input_data, &output_data); break; } case se::fft::Type::kC2CInverse: { se::DeviceMemory input_data(input); se::DeviceMemory output_data(output); - launch_ok = - stream->ThenFft(fft_plan.get(), input_data, &output_data).ok(); + launch_ok = fft->DoFft(stream, fft_plan.get(), input_data, &output_data); if (launch_ok) { - launch_ok = stream - ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape), - complex64(scale_factor), &output_data, 1) - .ok(); + TF_ASSIGN_OR_RETURN(auto blas, GetBlas(stream)); + launch_ok = + blas->DoBlasScal(stream, ShapeUtil::ElementsIn(output_shape), + complex64(scale_factor), &output_data, 1); } break; } case se::fft::Type::kZ2ZInverse: { se::DeviceMemory input_data(input); se::DeviceMemory output_data(output); - launch_ok = - stream->ThenFft(fft_plan.get(), input_data, &output_data).ok(); + launch_ok = fft->DoFft(stream, fft_plan.get(), input_data, &output_data); if (launch_ok) { + TF_ASSIGN_OR_RETURN(auto blas, GetBlas(stream)); launch_ok = - stream - ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape), - complex128(scale_factor), &output_data, 1) - .ok(); + blas->DoBlasScal(stream, ShapeUtil::ElementsIn(output_shape), + complex128(scale_factor), &output_data, 1); } break; } case se::fft::Type::kR2C: { se::DeviceMemory input_data(input); se::DeviceMemory output_data(output); - launch_ok = - stream->ThenFft(fft_plan.get(), input_data, &output_data).ok(); + launch_ok = fft->DoFft(stream, fft_plan.get(), input_data, &output_data); break; } case se::fft::Type::kD2Z: { se::DeviceMemory input_data(input); se::DeviceMemory output_data(output); - launch_ok = - stream->ThenFft(fft_plan.get(), input_data, &output_data).ok(); + launch_ok = fft->DoFft(stream, fft_plan.get(), input_data, &output_data); break; } case se::fft::Type::kC2R: { se::DeviceMemory input_data(input); se::DeviceMemory output_data(output); - launch_ok = - stream->ThenFft(fft_plan.get(), input_data, &output_data).ok(); + launch_ok = fft->DoFft(stream, fft_plan.get(), input_data, &output_data); if (launch_ok) { - launch_ok = stream - ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape), - scale_factor, &output_data, 1) - .ok(); + TF_ASSIGN_OR_RETURN(auto blas, GetBlas(stream)); + launch_ok = + blas->DoBlasScal(stream, ShapeUtil::ElementsIn(output_shape), + scale_factor, &output_data, 1); } break; } case se::fft::Type::kZ2D: { se::DeviceMemory input_data(input); se::DeviceMemory output_data(output); - launch_ok = - stream->ThenFft(fft_plan.get(), input_data, &output_data).ok(); + launch_ok = fft->DoFft(stream, fft_plan.get(), input_data, &output_data); if (launch_ok) { - launch_ok = stream - ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape), - scale_factor, &output_data, 1) - .ok(); + TF_ASSIGN_OR_RETURN(auto blas, GetBlas(stream)); + launch_ok = + blas->DoBlasScal(stream, ShapeUtil::ElementsIn(output_shape), + scale_factor, &output_data, 1); } break; } diff --git a/third_party/xla/xla/service/gpu/runtime3/fft_thunk.h b/third_party/xla/xla/service/gpu/runtime/fft_thunk.h similarity index 95% rename from third_party/xla/xla/service/gpu/runtime3/fft_thunk.h rename to third_party/xla/xla/service/gpu/runtime/fft_thunk.h index dae9aa4e15affe..bf6e214f039fe0 100644 --- a/third_party/xla/xla/service/gpu/runtime3/fft_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/fft_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME3_FFT_THUNK_H_ -#define XLA_SERVICE_GPU_RUNTIME3_FFT_THUNK_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_FFT_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_FFT_THUNK_H_ #include @@ -98,4 +98,4 @@ absl::Status RunFft(se::DeviceMemoryBase input, const Shape& input_shape, } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_RUNTIME3_FFT_THUNK_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_FFT_THUNK_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/fused_attention.cc b/third_party/xla/xla/service/gpu/runtime/fused_attention.cc deleted file mode 100644 index 5846cbf2d4222f..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/fused_attention.cc +++ /dev/null @@ -1,1349 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License.1 -==============================================================================*/ - -#include "xla/service/gpu/runtime/fused_attention.h" - -#include -#include -#include -#include -#include - -#include "llvm/ADT/Sequence.h" -#include "xla/mlir/runtime/transforms/custom_call_encoding.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/service/gpu/gpu_asm_opts_util.h" -#include "xla/service/gpu/gpu_fused_mha_runner.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/status.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/device_memory_allocator.h" -#include "xla/translate/mhlo_to_hlo/attribute_exporter.h" -#include "xla/xla.pb.h" - -namespace xla { - -using xla::runtime::CustomCall; -using xla::runtime::EnumAttrEncoding; -using xla::runtime::FlatMemrefView; -using xla::runtime::State; -using xla::runtime::StridedMemrefView; -using xla::runtime::Tagged; - -namespace lmhlo_gpu = ::mlir::lmhlo_gpu; -namespace gpu { -//===----------------------------------------------------------------------===// -// Structs for encoding fused attention attributes defined in LMHLO dialect. -//===----------------------------------------------------------------------===// -struct AlgorithmConfig { - int64_t algorithm; - absl::Span knob_ids; - absl::Span knob_values; - int64_t workspace_size; -}; - -} // namespace gpu - -//===----------------------------------------------------------------------===// -// Register fused attention attributes decoding with the Xla runtime. -//===----------------------------------------------------------------------===// -namespace runtime { -XLA_RUNTIME_REGISTER_ENUM_ATTR_DECODING(xla::gpu::CudnnfMHAKind); - -XLA_RUNTIME_REGISTER_AGGREGATE_ATTR_DECODING( - xla::gpu::AlgorithmConfig, // - AggregateMember("algorithm"), - AggregateMember>("knob_ids"), - AggregateMember>("knob_values"), - AggregateMember("workspace_size")); - -} // namespace runtime - -//===----------------------------------------------------------------------===// -// Type names for encoded attributes. -//===----------------------------------------------------------------------===// - -namespace gpu { - -// Register type names for fused attention attributes defined by LMHLO dialect. -void RegisterFusedAttentionTypeIdNames(runtime::TypeIDNameRegistry& registry) { - registry.Register>("__type_id_algorithm_config"); - registry.Register>( - "__type_id_xla_gpu_cudnn_fmha_kind"); -} - -static auto EncodeFusedAttentionDAGSignature( - lmhlo_gpu::FusedMhaDagSignature signature) { - switch (signature) { - case mlir::lmhlo_gpu::FusedMhaDagSignature::Default: - return xla::gpu::CudnnfMHAKind::kBmmBmm; - case mlir::lmhlo_gpu::FusedMhaDagSignature::ScaleBiasMaskSoftmax: - return xla::gpu::CudnnfMHAKind::kScaleBiasMaskSoftmax; - case mlir::lmhlo_gpu::FusedMhaDagSignature::ScaleBiasMaskSoftmaxDropout: - return xla::gpu::CudnnfMHAKind::kScaleBiasMaskSoftmaxDropout; - case mlir::lmhlo_gpu::FusedMhaDagSignature::ScaleMaskSoftmax: - return xla::gpu::CudnnfMHAKind::kScaleMaskSoftmax; - case mlir::lmhlo_gpu::FusedMhaDagSignature::ScaleMaskSoftmaxDropout: - return xla::gpu::CudnnfMHAKind::kScaleMaskSoftmaxDropout; - case mlir::lmhlo_gpu::FusedMhaDagSignature::SoftmaxDropout: - return xla::gpu::CudnnfMHAKind::kSoftmaxDropout; - case mlir::lmhlo_gpu::FusedMhaDagSignature::Softmax: - return xla::gpu::CudnnfMHAKind::kSoftmax; - case mlir::lmhlo_gpu::FusedMhaDagSignature::ScaleBiasSoftmax: - return xla::gpu::CudnnfMHAKind::kScaleBiasSoftmax; - case mlir::lmhlo_gpu::FusedMhaDagSignature::ScaleBiasSoftmaxDropout: - return xla::gpu::CudnnfMHAKind::kScaleBiasSoftmaxDropout; - } -} - -static auto EncodeFusedAttentionBackwardDAGSignature( - lmhlo_gpu::FusedMhaBackwardDagSignature signature) { - switch (signature) { - // backward - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardSoftmax: - return xla::gpu::CudnnfMHAKind::kBackwardSoftmax; - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardSoftmaxDropout: - return xla::gpu::CudnnfMHAKind::kBackwardSoftmaxDropout; - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleBiasSoftmax: - return xla::gpu::CudnnfMHAKind::kBackwardScaleBiasSoftmax; - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleBiasSoftmaxDropout: - return xla::gpu::CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout; - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleBiasMaskSoftmax: - return xla::gpu::CudnnfMHAKind::kBackwardScaleBiasMaskSoftmax; - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleBiasMaskSoftmaxDropout: - return xla::gpu::CudnnfMHAKind::kBackwardScaleBiasMaskSoftmaxDropout; - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleMaskSoftmax: - return xla::gpu::CudnnfMHAKind::kBackwardScaleMaskSoftmax; - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleMaskSoftmaxDropout: - return xla::gpu::CudnnfMHAKind::kBackwardScaleMaskSoftmaxDropout; - } -} - -void PopulateFusedAttentionForwardDAGSignatureAttrEncoding( - runtime::CustomCallAttrEncodingSet& encoding) { - { // --- Encode `lmhlo_gpu::FusedMhaDagSignatureAttr`. - encoding.Add>( - EncodeFusedAttentionDAGSignature); - } -} - -void PopulateFusedAttentionBackwardDAGSignatureAttrEncoding( - runtime::CustomCallAttrEncodingSet& encoding) { - { // --- Encode `lmhlo_gpu::FusedMhaBackwardDagSignatureAttr`. - encoding.Add>( - EncodeFusedAttentionBackwardDAGSignature); - } -} - -void PopulateFusedAttentionAlgorithmConfigAttrEncoding( - runtime::CustomCallAttrEncodingSet& encoding) { - { // --- Encode `lmhlo_gpu::FusedMHAAlgorithmConfigAttr`. - using Attr = mlir::lmhlo_gpu::FusedMHAAlgorithmConfigAttr; - encoding.Add>( - encoding, xla::runtime::AggregateAttrDef() - .Add("algorithm", &Attr::getAlgorithm) - .Add("knob_ids", &Attr::getKnobIds) - .Add("knob_values", &Attr::getKnobValues) - .Add("workspace_size", &Attr::getWorkspaceSize)); - } -} - -//===----------------------------------------------------------------------===// -// Fused Dot Attention runners caching. -//===----------------------------------------------------------------------===// - -StreamExecutorFusedAttentionRunners* FusedAttentionRunners::operator()( - se::StreamExecutor* executor) { - absl::MutexLock lock(&mutex_); - return &runners_[executor]; -} - -StreamExecutorFusedAttentionBackwardRunners* -FusedAttentionBackwardRunners::operator()(se::StreamExecutor* executor) { - absl::MutexLock lock(&mutex_); - return &runners_[executor]; -} - -namespace { -struct DropoutAttrs { - double dropout_rate; - int64_t seed; -}; -} // namespace - -static GpufMHADescriptor GetGpufMHADescriptor( - CudnnfMHAKind kind, StridedMemrefView lhs_bmm1, StridedMemrefView rhs_bmm1, - StridedMemrefView rhs_bmm2, std::optional mask, - std::optional bias, StridedMemrefView output, - std::optional activation, double fmha_scale, - absl::Span intermediate_tensor_dimensions, - absl::Span intermediate_tensor_layout, AlgorithmConfig algo, - DotDimensionNumbers bmm1_dot_dimension_numbers, - DotDimensionNumbers bmm2_dot_dimension_numbers, bool is_flash_attention, - bool is_causal_mask, std::optional dropout = std::nullopt) { - GpufMHADescriptor descriptor; - descriptor.backend_config.set_fmha_scale(fmha_scale); - - auto* algorithm = descriptor.backend_config.mutable_algorithm(); - algorithm->set_algo_id(algo.algorithm); - for (unsigned i = 0; i < algo.knob_ids.size(); ++i) { - algorithm->mutable_tuning_knobs()->insert( - {algo.knob_ids[i], algo.knob_values[i]}); - } - algorithm->set_is_cudnn_frontend(true); - if (algo.workspace_size >= 0) { - algorithm->mutable_workspace_size()->set_value(algo.workspace_size); - } - descriptor.bmm1_dnums = - ConvertDotDimensionNumbers(bmm1_dot_dimension_numbers.lhs_batch, - bmm1_dot_dimension_numbers.lhs_contract, - bmm1_dot_dimension_numbers.rhs_batch, - bmm1_dot_dimension_numbers.rhs_contract); - descriptor.bmm2_dnums = - ConvertDotDimensionNumbers(bmm2_dot_dimension_numbers.lhs_batch, - bmm2_dot_dimension_numbers.lhs_contract, - bmm2_dot_dimension_numbers.rhs_batch, - bmm2_dot_dimension_numbers.rhs_contract); - // Apply backend config layout to the shape. - auto apply_shape = [](StridedMemrefView& memref) { - Shape shape = ToShape(memref); - return ShapeUtil::MakeShapeWithDenseLayout(shape.element_type(), - shape.dimensions(), - shape.layout().minor_to_major()); - }; - descriptor.lhs_bmm1_shape = apply_shape(lhs_bmm1); - descriptor.rhs_bmm1_shape = apply_shape(rhs_bmm1); - descriptor.rhs_bmm2_shape = apply_shape(rhs_bmm2); - descriptor.output_shapes.push_back(apply_shape(output)); - if (activation.has_value()) { - descriptor.output_shapes.push_back(apply_shape(*activation)); - } - if (bias.has_value()) { - descriptor.bias_shape = apply_shape(*bias); - } - if (mask.has_value()) { - descriptor.mask_shape = apply_shape(*mask); - } - - Shape out_shape = ToShape(output); - descriptor.intermediate_lhs_bmm2_shape = ShapeUtil::MakeShapeWithDenseLayout( - out_shape.element_type(), intermediate_tensor_dimensions, - intermediate_tensor_layout); - - if (dropout.has_value()) { - descriptor.backend_config.set_dropout_rate(dropout->dropout_rate); - descriptor.backend_config.set_seed(dropout->seed); - } - - descriptor.kind = kind; - descriptor.is_flash_attention = is_flash_attention; - descriptor.is_causal_mask = is_causal_mask; - return descriptor; -} - -static GpufMHABackwardDescriptor GetGpufMHABackwardDescriptor( - CudnnfMHAKind kind, StridedMemrefView bmm1_grad_gemm1_rhs, - StridedMemrefView bmm1_grad_gemm2_rhs, - StridedMemrefView bmm2_grad_gemm2_rhs, - StridedMemrefView bmm2_grad_gemm1_lhs, StridedMemrefView d_output, - std::optional mask, - std::optional d_bias, StridedMemrefView d_bmm1_lhs, - StridedMemrefView d_bmm1_rhs, StridedMemrefView d_bmm2_rhs, - std::optional d_S, - std::optional softmax_sum, - std::optional d_Q_accum, - std::optional fwd_output, - std::optional bias, double fmha_scale, - AlgorithmConfig algo, - DotDimensionNumbers bmm1_grad_gemm1_dot_dimension_numbers, - DotDimensionNumbers bmm1_grad_gemm2_dot_dimension_numbers, - DotDimensionNumbers bmm2_grad_gemm1_dot_dimension_numbers, - DotDimensionNumbers bmm2_grad_gemm2_dot_dimension_numbers, - absl::Span intermediate_tensor_dimensions, - absl::Span intermediate_tensor_layout, - bool is_flash_attention, bool is_causal_mask, - std::optional dropout_attrs = std::nullopt) { - GpufMHABackwardDescriptor descriptor; - descriptor.backend_config.set_fmha_scale(fmha_scale); - - auto* algorithm = descriptor.backend_config.mutable_algorithm(); - algorithm->set_algo_id(algo.algorithm); - for (unsigned i = 0; i < algo.knob_ids.size(); ++i) { - algorithm->mutable_tuning_knobs()->insert( - {algo.knob_ids[i], algo.knob_values[i]}); - } - algorithm->set_is_cudnn_frontend(true); - if (algo.workspace_size >= 0) { - algorithm->mutable_workspace_size()->set_value(algo.workspace_size); - } - - descriptor.bmm1_grad_gemm1_dnums = ConvertDotDimensionNumbers( - bmm1_grad_gemm1_dot_dimension_numbers.lhs_batch, - bmm1_grad_gemm1_dot_dimension_numbers.lhs_contract, - bmm1_grad_gemm1_dot_dimension_numbers.rhs_batch, - bmm1_grad_gemm1_dot_dimension_numbers.rhs_contract); - descriptor.bmm1_grad_gemm2_dnums = ConvertDotDimensionNumbers( - bmm1_grad_gemm2_dot_dimension_numbers.lhs_batch, - bmm1_grad_gemm2_dot_dimension_numbers.lhs_contract, - bmm1_grad_gemm2_dot_dimension_numbers.rhs_batch, - bmm1_grad_gemm2_dot_dimension_numbers.rhs_contract); - descriptor.bmm2_grad_gemm1_dnums = ConvertDotDimensionNumbers( - bmm2_grad_gemm1_dot_dimension_numbers.lhs_batch, - bmm2_grad_gemm1_dot_dimension_numbers.lhs_contract, - bmm2_grad_gemm1_dot_dimension_numbers.rhs_batch, - bmm2_grad_gemm1_dot_dimension_numbers.rhs_contract); - descriptor.bmm2_grad_gemm2_dnums = ConvertDotDimensionNumbers( - bmm2_grad_gemm2_dot_dimension_numbers.lhs_batch, - bmm2_grad_gemm2_dot_dimension_numbers.lhs_contract, - bmm2_grad_gemm2_dot_dimension_numbers.rhs_batch, - bmm2_grad_gemm2_dot_dimension_numbers.rhs_contract); - - // Apply backend config layout to the shape. - auto apply_shape = [](StridedMemrefView& memref) { - Shape shape = ToShape(memref); - return ShapeUtil::MakeShapeWithDenseLayout(shape.element_type(), - shape.dimensions(), - shape.layout().minor_to_major()); - }; - descriptor.bmm1_grad_gemm1_rhs_shape = apply_shape(bmm1_grad_gemm1_rhs); - descriptor.bmm1_grad_gemm2_rhs_shape = apply_shape(bmm1_grad_gemm2_rhs); - descriptor.bmm2_grad_gemm2_rhs_shape = apply_shape(bmm2_grad_gemm2_rhs); - if (is_flash_attention) { - // if it is flash attention then bmm2_grad_gemm1_lhs will be softmax_stats - // instead of P we need to use real P layout - descriptor.bmm2_grad_gemm1_lhs_shape = ShapeUtil::MakeShapeWithDenseLayout( - descriptor.bmm2_grad_gemm2_rhs_shape.element_type(), - intermediate_tensor_dimensions, intermediate_tensor_layout); - } else { - descriptor.bmm2_grad_gemm1_lhs_shape = apply_shape(bmm2_grad_gemm1_lhs); - } - - descriptor.d_output_shape = apply_shape(d_output); - descriptor.d_bmm1_lhs_shape = apply_shape(d_bmm1_lhs); - descriptor.d_bmm1_rhs_shape = apply_shape(d_bmm1_rhs); - descriptor.d_bmm2_rhs_shape = apply_shape(d_bmm2_rhs); - - if (mask.has_value()) { - descriptor.mask_shape = apply_shape(*mask); - } - if (d_bias.has_value()) { - descriptor.d_bias_shape = apply_shape(*d_bias); - } - if (fwd_output.has_value()) { - descriptor.fwd_output_shape = apply_shape(*fwd_output); - } - if (bias.has_value()) { - descriptor.bias_shape = apply_shape(*bias); - } - if (dropout_attrs.has_value()) { - descriptor.backend_config.set_dropout_rate(dropout_attrs->dropout_rate); - descriptor.backend_config.set_seed(dropout_attrs->seed); - } - - descriptor.kind = kind; - descriptor.is_flash_attention = is_flash_attention; - descriptor.is_causal_mask = is_causal_mask; - return descriptor; -} - -static absl::Status FusedAttentionForwardImpl( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, State runner, - StridedMemrefView lhs_bmm1, StridedMemrefView rhs_bmm1, - StridedMemrefView rhs_bmm2, std::optional mask, - std::optional bias, StridedMemrefView output, - FlatMemrefView scratch, std::optional activation, - int64_t uid, double fmha_scale, bool is_flash_attention, - bool is_causal_mask, - absl::Span intermediate_tensor_dimensions, - absl::Span intermediate_tensor_layout, - DotDimensionNumbers bmm1_dot_dimension_numbers, - DotDimensionNumbers bmm2_dot_dimension_numbers, - xla::gpu::CudnnfMHAKind kind, AlgorithmConfig algorithm_config, - std::optional dropout_rate = std::nullopt, - std::optional seed = std::nullopt) { - std::optional dropout_attrs = std::nullopt; - if (dropout_rate.has_value() && seed.has_value()) { - dropout_attrs = {*dropout_rate, *seed}; - } - // Get or create the fused attention runner state. - absl::StatusOr fda = - runner.GetOrCreate([&]() -> absl::StatusOr { - GpufMHADescriptor descriptor = GetGpufMHADescriptor( - kind, lhs_bmm1, rhs_bmm1, rhs_bmm2, mask, bias, output, activation, - fmha_scale, intermediate_tensor_dimensions, - intermediate_tensor_layout, algorithm_config, - bmm1_dot_dimension_numbers, bmm2_dot_dimension_numbers, - is_flash_attention, is_causal_mask, dropout_attrs); - - absl::StatusOr config = GpufMHAConfig::For(descriptor); - if (!config.ok()) return tsl::ToAbslStatus(config.status()); - - return FusedAttentionRunner(*std::move(config)); - }); - if (!fda.ok()) return fda.status(); - - se::DeviceMemoryBase lhs_bmm1_buffer = GetDeviceAddress(lhs_bmm1); - se::DeviceMemoryBase rhs_bmm1_buffer = GetDeviceAddress(rhs_bmm1); - se::DeviceMemoryBase rhs_bmm2_buffer = GetDeviceAddress(rhs_bmm2); - se::DeviceMemoryBase output_buffer = GetDeviceAddress(output); - se::DeviceMemoryBase scratch_buffer = GetDeviceAddress(scratch); - - se::DeviceMemoryBase mask_buffer; - if (mask.has_value()) { - mask_buffer = GetDeviceAddress(*mask); - } - se::DeviceMemoryBase bias_buffer; - if (bias.has_value()) { - bias_buffer = GetDeviceAddress(*bias); - } - se::DeviceMemoryBase activation_buffer; - if (activation.has_value()) { - activation_buffer = GetDeviceAddress(*activation); - } - - RunFusedMHAOptions opts; - opts.runner_cache = &(*fda)->runner; - - // Run the fused dot attention. - auto st = - RunGpuFMHA((*fda)->config, lhs_bmm1_buffer, rhs_bmm1_buffer, - rhs_bmm2_buffer, output_buffer, scratch_buffer, mask_buffer, - bias_buffer, activation_buffer, run_options->stream(), opts); - if (!st.ok() || !run_options->stream()->ok()) { - return tsl::ToAbslStatus(st); - } - return absl::OkStatus(); -} - -static absl::Status FusedAttentionBackwardImpl( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, - State runner, - StridedMemrefView bmm1_grad_gemm1_rhs, - StridedMemrefView bmm1_grad_gemm2_rhs, - StridedMemrefView bmm2_grad_gemm2_rhs, - StridedMemrefView bmm2_grad_gemm1_lhs, StridedMemrefView d_output, - std::optional mask, - std::optional bias, - std::optional fwd_output, StridedMemrefView d_bmm1_lhs, - StridedMemrefView d_bmm1_rhs, StridedMemrefView d_bmm2_rhs, - std::optional d_S, - std::optional softmax_sum, - std::optional d_Q_accum, FlatMemrefView scratch, - std::optional d_bias, int64_t uid, double fmha_scale, - bool is_flash_attention, bool is_causal_mask, - absl::Span intermediate_tensor_dimensions, - absl::Span intermediate_tensor_layout, - DotDimensionNumbers bmm1_grad_gemm1_dot_dimension_numbers, - DotDimensionNumbers bmm1_grad_gemm2_dot_dimension_numbers, - DotDimensionNumbers bmm2_grad_gemm1_dot_dimension_numbers, - DotDimensionNumbers bmm2_grad_gemm2_dot_dimension_numbers, - xla::gpu::CudnnfMHAKind kind, AlgorithmConfig algorithm_config, - std::optional dropout_rate = std::nullopt, - std::optional seed = std::nullopt) { - std::optional dropout_attrs = std::nullopt; - if (dropout_rate.has_value() && seed.has_value()) { - dropout_attrs = {*dropout_rate, *seed}; - } - - // Get or create the fused attention runner state. - absl::StatusOr fda = - runner.GetOrCreate([&]() -> absl::StatusOr { - GpufMHABackwardDescriptor descriptor = GetGpufMHABackwardDescriptor( - kind, bmm1_grad_gemm1_rhs, bmm1_grad_gemm2_rhs, bmm2_grad_gemm2_rhs, - bmm2_grad_gemm1_lhs, d_output, mask, d_bias, d_bmm1_lhs, d_bmm1_rhs, - d_bmm2_rhs, d_S, softmax_sum, d_Q_accum, fwd_output, bias, - fmha_scale, algorithm_config, bmm1_grad_gemm1_dot_dimension_numbers, - bmm1_grad_gemm2_dot_dimension_numbers, - bmm2_grad_gemm1_dot_dimension_numbers, - bmm2_grad_gemm2_dot_dimension_numbers, - intermediate_tensor_dimensions, intermediate_tensor_layout, - is_flash_attention, is_causal_mask, dropout_attrs); - absl::StatusOr config = - GpufMHABackwardConfig::For(descriptor); - if (!config.ok()) return tsl::ToAbslStatus(config.status()); - - return FusedAttentionBackwardRunner(*std::move(config)); - }); - if (!fda.ok()) return fda.status(); - - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer = - GetDeviceAddress(bmm1_grad_gemm1_rhs); - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer = - GetDeviceAddress(bmm1_grad_gemm2_rhs); - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer = - GetDeviceAddress(bmm2_grad_gemm2_rhs); - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer = - GetDeviceAddress(bmm2_grad_gemm1_lhs); - - se::DeviceMemoryBase d_output_buffer = GetDeviceAddress(d_output); - se::DeviceMemoryBase d_bmm1_lhs_buffer = GetDeviceAddress(d_bmm1_lhs); - se::DeviceMemoryBase d_bmm1_rhs_buffer = GetDeviceAddress(d_bmm1_rhs); - se::DeviceMemoryBase d_bmm2_rhs_buffer = GetDeviceAddress(d_bmm2_rhs); - se::DeviceMemoryBase scratch_buffer = GetDeviceAddress(scratch); - - se::DeviceMemoryBase d_S_buffer; - if (d_S.has_value()) { - d_S_buffer = GetDeviceAddress(*d_S); - } - - se::DeviceMemoryBase mask_buffer; - if (mask.has_value()) { - mask_buffer = GetDeviceAddress(*mask); - } - - se::DeviceMemoryBase d_bias_buffer; - if (d_bias.has_value()) { - d_bias_buffer = GetDeviceAddress(*d_bias); - } - - se::DeviceMemoryBase softmax_sum_buffer; - if (softmax_sum.has_value()) { - softmax_sum_buffer = GetDeviceAddress(*softmax_sum); - } - - se::DeviceMemoryBase d_Q_accum_buffer; - if (d_Q_accum.has_value()) { - d_Q_accum_buffer = GetDeviceAddress(*d_Q_accum); - } - - se::DeviceMemoryBase fwd_output_buffer; - if (fwd_output.has_value()) { - fwd_output_buffer = GetDeviceAddress(*fwd_output); - } - - se::DeviceMemoryBase bias_buffer; - if (bias.has_value()) { - bias_buffer = GetDeviceAddress(*bias); - } - - RunFusedMHABackwardOptions opts; - opts.runner_cache = &(*fda)->runner; - - // Run the fused attention backward. - auto st = RunGpuFMHABackward( - (*fda)->config, bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer, - bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, d_output_buffer, - scratch_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer, d_bmm2_rhs_buffer, - d_S_buffer, softmax_sum_buffer, d_Q_accum_buffer, mask_buffer, - d_bias_buffer, fwd_output_buffer, bias_buffer, run_options->stream(), - opts); - if (!st.ok() || !run_options->stream()->ok()) { - return tsl::ToAbslStatus(st); - } - return absl::OkStatus(); -} - -//===----------------------------------------------------------------------===// -// Fused Attention custom calls bindings and registration. -//===----------------------------------------------------------------------===// - -template -auto BindFusedAttentionAttributes(runtime::CustomCallBinding binding) { - return std::move(binding) - .template Attr("uid") - .template Attr("fmha_scale") - .template Attr("is_flash_attention") - .template Attr("is_causal_mask") - .template Attr>( - "intermediate_tensor_dimensions") - .template Attr>("intermediate_tensor_layout") - .template Attr("bmm1_dot_dimension_numbers") - .template Attr("bmm2_dot_dimension_numbers") - .template Attr("fused_mha_dag") - .template Attr("algorithm_config"); -} - -auto FusedAttentionCall(const char* name) { - return CustomCall::Bind(name) - .UserData() - .UserData() - .State("uid") - .Arg() // lhs_bmm1 - .Arg() // rhs_bmm1 - .Arg(); // rhs_bmm2 -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionBmmBmmInference, FunctionWrapper(), - checks, - BindFusedAttentionAttributes( - FusedAttentionCall("xla.gpu.fused.attention.bmm.bmm.inference") - .Value(std::optional()) // mask - .Value(std::optional()) // bias - .Arg() // output - .Arg() // scratch - .Value(std::optional()) // activation - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionBmmBmmForward, FunctionWrapper(), - checks, - BindFusedAttentionAttributes( - FusedAttentionCall("xla.gpu.fused.attention.bmm.bmm.forward") - .Value(std::optional()) // mask - .Value(std::optional()) // bias - .Arg() // output - .Arg() // scratch - .Arg() // activation - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionSoftmaxInference, - FunctionWrapper(), checks, - BindFusedAttentionAttributes( - FusedAttentionCall("xla.gpu.fused.attention.softmax.inference") - .Value(std::optional()) // mask - .Value(std::optional()) // bias - .Arg() // output - .Arg() // scratch - .Value(std::optional()) // activation - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionSoftmaxForward, FunctionWrapper(), - checks, - BindFusedAttentionAttributes( - FusedAttentionCall("xla.gpu.fused.attention.softmax.forward") - .Value(std::optional()) // mask - .Value(std::optional()) // bias - .Arg() // output - .Arg() // scratch - .Arg() // activation - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionSoftmaxDropoutInference, - FunctionWrapper(), checks, - BindFusedAttentionAttributes( - FusedAttentionCall("xla.gpu.fused.attention.softmax.dropout.inference") - .Value(std::optional()) // mask - .Value(std::optional()) // bias - .Arg() // output - .Arg() // scratch - .Value(std::optional()) // activation - ) - .Attr("dropout_rate") // dropout_rate - .Attr("seed") // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionSoftmaxDropoutForward, - FunctionWrapper(), checks, - BindFusedAttentionAttributes( - FusedAttentionCall("xla.gpu.fused.attention.softmax.dropout.forward") - .Value(std::optional()) // mask - .Value(std::optional()) // bias - .Arg() // output - .Arg() // scratch - .Arg() // activation - ) - .Attr("dropout_rate") // dropout_rate - .Attr("seed") // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleBiasSoftmaxInference, - FunctionWrapper(), checks, - BindFusedAttentionAttributes( - FusedAttentionCall( - "xla.gpu.fused.attention.scale.bias.softmax.inference") - .Value(std::optional()) // mask - .Arg() // bias - .Arg() // output - .Arg() // scratch - .Value(std::optional()) // activation - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleBiasSoftmaxForward, - FunctionWrapper(), checks, - BindFusedAttentionAttributes( - FusedAttentionCall("xla.gpu.fused.attention.scale.bias.softmax.forward") - .Value(std::optional()) // mask - .Arg() // bias - .Arg() // output - .Arg() // scratch - .Arg() // activation - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleBiasSoftmaxDropoutInference, - FunctionWrapper(), checks, - BindFusedAttentionAttributes( - FusedAttentionCall( - "xla.gpu.fused.attention.scale.bias.softmax.dropout.inference") - .Value(std::optional()) // mask - .Arg() // bias - .Arg() // output - .Arg() // scratch - .Value(std::optional()) // activation - ) - .Attr("dropout_rate") // dropout_rate - .Attr("seed") // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleBiasSoftmaxDropoutForward, - FunctionWrapper(), checks, - BindFusedAttentionAttributes( - FusedAttentionCall( - "xla.gpu.fused.attention.scale.bias.softmax.dropout.forward") - .Value(std::optional()) // mask - .Arg() // bias - .Arg() // output - .Arg() // scratch - .Arg() // activation - ) - .Attr("dropout_rate") // dropout_rate - .Attr("seed") // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleMaskSoftmaxInference, - FunctionWrapper(), checks, - BindFusedAttentionAttributes( - FusedAttentionCall( - "xla.gpu.fused.attention.scale.mask.softmax.inference") - .Arg() // mask - .Value(std::optional()) // bias - .Arg() // output - .Arg() // scratch - .Value(std::optional()) // activation - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleMaskSoftmaxForward, - FunctionWrapper(), checks, - BindFusedAttentionAttributes( - FusedAttentionCall("xla.gpu.fused.attention.scale.mask.softmax.forward") - .Arg() // mask - .Value(std::optional()) // bias - .Arg() // output - .Arg() // scratch - .Arg() // activation - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleMaskSoftmaxDropoutInference, - FunctionWrapper(), checks, - BindFusedAttentionAttributes( - FusedAttentionCall( - "xla.gpu.fused.attention.scale.mask.softmax.dropout.inference") - .Arg() // mask - .Value(std::optional()) // bias - .Arg() // output - .Arg() // scratch - .Value(std::optional()) // activation - ) - .Attr("dropout_rate") // dropout_rate - .Attr("seed") // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleMaskSoftmaxDropoutForward, - FunctionWrapper(), checks, - BindFusedAttentionAttributes( - FusedAttentionCall( - "xla.gpu.fused.attention.scale.mask.softmax.dropout.forward") - .Arg() // mask - .Value(std::optional()) // bias - .Arg() // output - .Arg() // scratch - .Arg() // activation - ) - .Attr("dropout_rate") // dropout_rate - .Attr("seed") // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleBiasMaskSoftmaxInference, - FunctionWrapper(), checks, - BindFusedAttentionAttributes( - FusedAttentionCall( - "xla.gpu.fused.attention.scale.bias.mask.softmax.inference") - .Arg() // mask - .Arg() // bias - .Arg() // output - .Arg() // scratch - .Value(std::optional()) // activation - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleBiasMaskSoftmaxForward, - FunctionWrapper(), checks, - BindFusedAttentionAttributes( - FusedAttentionCall( - "xla.gpu.fused.attention.scale.bias.mask.softmax.forward") - .Arg() // mask - .Arg() // bias - .Arg() // output - .Arg() // scratch - .Arg() // activation - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleBiasMaskSoftmaxDropoutInference, - FunctionWrapper(), checks, - BindFusedAttentionAttributes( - FusedAttentionCall( - "xla.gpu.fused.attention.scale.bias.mask.softmax.dropout.inference") - .Arg() // mask - .Arg() // bias - .Arg() // output - .Arg() // scratch - .Value(std::optional()) // activation - ) - .Attr("dropout_rate") // dropout_rate - .Attr("seed") // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleBiasMaskSoftmaxDropoutForward, - FunctionWrapper(), checks, - BindFusedAttentionAttributes( - FusedAttentionCall( - "xla.gpu.fused.attention.scale.bias.mask.softmax.dropout.forward") - .Arg() // mask - .Arg() // bias - .Arg() // output - .Arg() // scratch - .Arg() // activation - ) - .Attr("dropout_rate") // dropout_rate - .Attr("seed") // seed -); - -template -auto BindFusedAttentionBackwardAttributes( - runtime::CustomCallBinding binding) { - return std::move(binding) - .template Attr("uid") - .template Attr("fmha_scale") - .template Attr("is_flash_attention") - .template Attr("is_causal_mask") - .template Attr>( - "intermediate_tensor_dimensions") - .template Attr>("intermediate_tensor_layout") - .template Attr( - "bmm1_grad_gemm1_dot_dimension_numbers") - .template Attr( - "bmm1_grad_gemm2_dot_dimension_numbers") - .template Attr( - "bmm2_grad_gemm1_dot_dimension_numbers") - .template Attr( - "bmm2_grad_gemm2_dot_dimension_numbers") - .template Attr("fused_mha_dag") - .template Attr("algorithm_config"); -} - -auto FusedAttentionBackwardCall(const char* name) { - return CustomCall::Bind(name) - .UserData() - .UserData() - .State("uid") - .Arg() // bmm1_grad_gemm1_rhs - .Arg() // bmm1_grad_gemm2_rhs - .Arg() // bmm2_grad_gemm2_rhs - .Arg() // bmm2_grad_gemm1_lhs - .Arg(); // d_output -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleBiasSoftmaxBackward, - FunctionWrapper(), checks, - BindFusedAttentionBackwardAttributes( - FusedAttentionBackwardCall( - "xla.gpu.fused.attention.backward.scale.dbias.softmax") - .Value(std::optional()) // mask - .Value(std::optional()) // bias - .Value(std::optional()) // fwd_output - .Arg() // d_bmm1_lhs - .Arg() // d_bmm1_rhs - .Arg() // d_bmm2_rhs - .Arg() // d_S - .Value(std::optional()) // softmax_sum - .Value(std::optional()) // d_Q_accum - .Arg() // scratch - .Arg() // d_bias - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleSoftmaxBackward, - FunctionWrapper(), checks, - BindFusedAttentionBackwardAttributes( - FusedAttentionBackwardCall( - "xla.gpu.fused.attention.backward.scale.softmax") - .Value(std::optional()) // mask - .Value(std::optional()) // bias - .Value(std::optional()) // fwd_output - .Arg() // d_bmm1_lhs - .Arg() // d_bmm1_rhs - .Arg() // d_bmm2_rhs - .Arg() // d_S - .Value(std::optional()) // softmax_sum - .Value(std::optional()) // d_Q_accum - .Arg() // scratch - .Value(std::optional()) // d_bias - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleBiasSoftmaxDropoutBackward, - FunctionWrapper(), checks, - BindFusedAttentionBackwardAttributes( - FusedAttentionBackwardCall( - "xla.gpu.fused.attention.backward.scale.dbias.softmax.dropout") - .Value(std::optional()) // mask - .Value(std::optional()) // bias - .Value(std::optional()) // fwd_output - .Arg() // d_bmm1_lhs - .Arg() // d_bmm1_rhs - .Arg() // d_bmm2_rhs - .Arg() // d_S - .Value(std::optional()) // softmax_sum - .Value(std::optional()) // d_Q_accum - .Arg() // scratch - .Arg() // d_bias - ) - .Attr("dropout_rate") // dropout_rate - .Attr("seed") // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleSoftmaxDropoutBackward, - FunctionWrapper(), checks, - BindFusedAttentionBackwardAttributes( - FusedAttentionBackwardCall( - "xla.gpu.fused.attention.backward.scale.softmax.dropout") - .Value(std::optional()) // mask - .Value(std::optional()) // bias - .Value(std::optional()) // fwd_output - .Arg() // d_bmm1_lhs - .Arg() // d_bmm1_rhs - .Arg() // d_bmm2_rhs - .Arg() // d_S - .Value(std::optional()) // softmax_sum - .Value(std::optional()) // d_Q_accum - .Arg() // scratch - .Value(std::optional()) // d_bias - ) - .Attr("dropout_rate") // dropout_rate - .Attr("seed") // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleBiasMaskSoftmaxBackward, - FunctionWrapper(), checks, - BindFusedAttentionBackwardAttributes( - FusedAttentionBackwardCall( - "xla.gpu.fused.attention.backward.scale.dbias.mask.softmax") - .Arg() // mask - .Value(std::optional()) // bias - .Value(std::optional()) // fwd_output - .Arg() // d_bmm1_lhs - .Arg() // d_bmm1_rhs - .Arg() // d_bmm2_rhs - .Arg() // d_S - .Value(std::optional()) // softmax_sum - .Value(std::optional()) // d_Q_accum - .Arg() // scratch - .Arg() // d_bias - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleMaskSoftmaxBackward, - FunctionWrapper(), checks, - BindFusedAttentionBackwardAttributes( - FusedAttentionBackwardCall( - "xla.gpu.fused.attention.backward.scale.mask.softmax") - .Arg() // mask - .Value(std::optional()) // bias - .Value(std::optional()) // fwd_output - .Arg() // d_bmm1_lhs - .Arg() // d_bmm1_rhs - .Arg() // d_bmm2_rhs - .Arg() // d_S - .Value(std::optional()) // softmax_sum - .Value(std::optional()) // d_Q_accum - .Arg() // scratch - .Value(std::optional()) // d_bias - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleBiasMaskSoftmaxDropoutBackward, - FunctionWrapper(), checks, - BindFusedAttentionBackwardAttributes( - FusedAttentionBackwardCall( - "xla.gpu.fused.attention.backward.scale.dbias.mask.softmax.dropout") - .Arg() // mask - .Value(std::optional()) // bias - .Value(std::optional()) // fwd_output - .Arg() // d_bmm1_lhs - .Arg() // d_bmm1_rhs - .Arg() // d_bmm2_rhs - .Arg() // d_S - .Value(std::optional()) // softmax_sum - .Value(std::optional()) // d_Q_accum - .Arg() // scratch - .Arg() // d_bias - ) - .Attr("dropout_rate") // dropout_rate - .Attr("seed") // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleMaskSoftmaxDropoutBackward, - FunctionWrapper(), checks, - BindFusedAttentionBackwardAttributes( - FusedAttentionBackwardCall( - "xla.gpu.fused.attention.backward.scale.mask.softmax.dropout") - .Arg() // mask - .Value(std::optional()) // bias - .Value(std::optional()) // fwd_output - .Arg() // d_bmm1_lhs - .Arg() // d_bmm1_rhs - .Arg() // d_bmm2_rhs - .Arg() // d_S - .Value(std::optional()) // softmax_sum - .Value(std::optional()) // d_Q_accum - .Arg() // scratch - .Value(std::optional()) // d_bias - ) - .Attr("dropout_rate") // dropout_rate - .Attr("seed") // seed -); - -// flash attention backward custom call -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FlashAttentionScaleBiasSoftmaxBackward, - FunctionWrapper(), checks, - BindFusedAttentionBackwardAttributes( - FusedAttentionBackwardCall( - "xla.gpu.flash.attention.backward.scale.bias.softmax") - .Value(std::optional()) // mask - .Arg() // bias - .Arg() // fwd_output - .Arg() // d_bmm1_lhs - .Arg() // d_bmm1_rhs - .Arg() // d_bmm2_rhs - .Value(std::optional()) // d_S - .Arg() // softmax_sum - .Arg() // d_Q_accum - .Arg() // scratch - .Value(std::optional()) // d_bias - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FlashAttentionScaleSoftmaxBackward, - FunctionWrapper(), checks, - BindFusedAttentionBackwardAttributes( - FusedAttentionBackwardCall( - "xla.gpu.flash.attention.backward.scale.softmax") - .Value(std::optional()) // mask - .Value(std::optional()) // bias - .Arg() // fwd_output - .Arg() // d_bmm1_lhs - .Arg() // d_bmm1_rhs - .Arg() // d_bmm2_rhs - .Value(std::optional()) // d_S - .Arg() // softmax_sum - .Arg() // d_Q_accum - .Arg() // scratch - .Value(std::optional()) // d_bias - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FlashAttentionScaleBiasSoftmaxDropoutBackward, - FunctionWrapper(), checks, - BindFusedAttentionBackwardAttributes( - FusedAttentionBackwardCall( - "xla.gpu.flash.attention.backward.scale.bias.softmax.dropout") - .Value(std::optional()) // mask - .Arg() // bias - .Arg() // fwd_output - .Arg() // d_bmm1_lhs - .Arg() // d_bmm1_rhs - .Arg() // d_bmm2_rhs - .Value(std::optional()) // d_S - .Arg() // softmax_sum - .Arg() // d_Q_accum - .Arg() // scratch - .Value(std::optional()) // d_bias - ) - .Attr("dropout_rate") // dropout_rate - .Attr("seed") // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FlashAttentionScaleSoftmaxDropoutBackward, - FunctionWrapper(), checks, - BindFusedAttentionBackwardAttributes( - FusedAttentionBackwardCall( - "xla.gpu.flash.attention.backward.scale.softmax.dropout") - .Value(std::optional()) // mask - .Value(std::optional()) // bias - .Arg() // fwd_output - .Arg() // d_bmm1_lhs - .Arg() // d_bmm1_rhs - .Arg() // d_bmm2_rhs - .Value(std::optional()) // d_S - .Arg() // softmax_sum - .Arg() // d_Q_accum - .Arg() // scratch - .Value(std::optional()) // d_bias - ) - .Attr("dropout_rate") // dropout_rate - .Attr("seed") // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FlashAttentionScaleBiasMaskSoftmaxBackward, - FunctionWrapper(), checks, - BindFusedAttentionBackwardAttributes( - FusedAttentionBackwardCall( - "xla.gpu.flash.attention.backward.scale.bias.mask.softmax") - .Arg() // mask - .Arg() // bias - .Arg() // fwd_output - .Arg() // d_bmm1_lhs - .Arg() // d_bmm1_rhs - .Arg() // d_bmm2_rhs - .Value(std::optional()) // d_S - .Arg() // softmax_sum - .Arg() // d_Q_accum - .Arg() // scratch - .Value(std::optional()) // d_bias - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FlashAttentionScaleMaskSoftmaxBackward, - FunctionWrapper(), checks, - BindFusedAttentionBackwardAttributes( - FusedAttentionBackwardCall( - "xla.gpu.flash.attention.backward.scale.mask.softmax") - .Arg() // mask - .Value(std::optional()) // bias - .Arg() // fwd_output - .Arg() // d_bmm1_lhs - .Arg() // d_bmm1_rhs - .Arg() // d_bmm2_rhs - .Value(std::optional()) // d_S - .Arg() // softmax_sum - .Arg() // d_Q_accum - .Arg() // scratch - .Value(std::optional()) // d_bias - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FlashAttentionScaleBiasMaskSoftmaxDropoutBackward, - FunctionWrapper(), checks, - BindFusedAttentionBackwardAttributes( - FusedAttentionBackwardCall( - "xla.gpu.flash.attention.backward.scale.bias.mask.softmax.dropout") - .Arg() // mask - .Arg() // bias - .Arg() // fwd_output - .Arg() // d_bmm1_lhs - .Arg() // d_bmm1_rhs - .Arg() // d_bmm2_rhs - .Value(std::optional()) // d_S - .Arg() // softmax_sum - .Arg() // d_Q_accum - .Arg() // scratch - .Value(std::optional()) // d_bias - ) - .Attr("dropout_rate") // dropout_rate - .Attr("seed") // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FlashAttentionScaleMaskSoftmaxDropoutBackward, - FunctionWrapper(), checks, - BindFusedAttentionBackwardAttributes( - FusedAttentionBackwardCall( - "xla.gpu.flash.attention.backward.scale.mask.softmax.dropout") - .Arg() // mask - .Value(std::optional()) // bias - .Arg() // fwd_output - .Arg() // d_bmm1_lhs - .Arg() // d_bmm1_rhs - .Arg() // d_bmm2_rhs - .Value(std::optional()) // d_S - .Arg() // softmax_sum - .Arg() // d_Q_accum - .Arg() // scratch - .Value(std::optional()) // d_bias - ) - .Attr("dropout_rate") // dropout_rate - .Attr("seed") // seed -); -//===----------------------------------------------------------------------===// -// cuBLASLt custom calls bindings and registration. -//===----------------------------------------------------------------------===// -void RegisterFusedAttentionCustomCalls( - runtime::DirectCustomCallRegistry& registry) { - auto fused_attention = [](std::string name) { - return "xla.gpu.fused.attention." + name; - }; - registry.Register(fused_attention("bmm.bmm.inference"), - FusedAttentionBmmBmmInference); - registry.Register(fused_attention("bmm.bmm.forward"), - FusedAttentionBmmBmmForward); - registry.Register(fused_attention("softmax.inference"), - FusedAttentionSoftmaxInference); - registry.Register(fused_attention("softmax.forward"), - FusedAttentionSoftmaxForward); - registry.Register(fused_attention("softmax.dropout.inference"), - FusedAttentionSoftmaxDropoutInference); - registry.Register(fused_attention("softmax.dropout.forward"), - FusedAttentionSoftmaxDropoutForward); - registry.Register(fused_attention("scale.bias.softmax.inference"), - FusedAttentionScaleBiasSoftmaxInference); - registry.Register(fused_attention("scale.bias.softmax.forward"), - FusedAttentionScaleBiasSoftmaxForward); - registry.Register(fused_attention("scale.bias.softmax.dropout.inference"), - FusedAttentionScaleBiasSoftmaxDropoutInference); - registry.Register(fused_attention("scale.bias.softmax.dropout.forward"), - FusedAttentionScaleBiasSoftmaxDropoutForward); - registry.Register(fused_attention("scale.mask.softmax.inference"), - FusedAttentionScaleMaskSoftmaxInference); - registry.Register(fused_attention("scale.mask.softmax.forward"), - FusedAttentionScaleMaskSoftmaxForward); - registry.Register(fused_attention("scale.mask.softmax.dropout.inference"), - FusedAttentionScaleMaskSoftmaxDropoutInference); - registry.Register(fused_attention("scale.mask.softmax.dropout.forward"), - FusedAttentionScaleMaskSoftmaxDropoutForward); - registry.Register(fused_attention("scale.bias.mask.softmax.inference"), - FusedAttentionScaleBiasMaskSoftmaxInference); - registry.Register(fused_attention("scale.bias.mask.softmax.forward"), - FusedAttentionScaleBiasMaskSoftmaxForward); - registry.Register( - fused_attention("scale.bias.mask.softmax.dropout.inference"), - FusedAttentionScaleBiasMaskSoftmaxDropoutInference); - registry.Register(fused_attention("scale.bias.mask.softmax.dropout.forward"), - FusedAttentionScaleBiasMaskSoftmaxDropoutForward); -} - -void RegisterFusedAttentionBackwardCustomCalls( - runtime::DirectCustomCallRegistry& registry) { - auto fused_attention = [](std::string name) { - return "xla.gpu.fused.attention.backward." + name; - }; - registry.Register(fused_attention("scale.dbias.softmax"), - FusedAttentionScaleBiasSoftmaxBackward); - registry.Register(fused_attention("scale.softmax"), - FusedAttentionScaleSoftmaxBackward); - registry.Register(fused_attention("scale.dbias.softmax.dropout"), - FusedAttentionScaleBiasSoftmaxDropoutBackward); - registry.Register(fused_attention("scale.softmax.dropout"), - FusedAttentionScaleSoftmaxDropoutBackward); - registry.Register(fused_attention("scale.dbias.mask.softmax"), - FusedAttentionScaleBiasMaskSoftmaxBackward); - registry.Register(fused_attention("scale.mask.softmax"), - FusedAttentionScaleMaskSoftmaxBackward); - registry.Register(fused_attention("scale.dbias.mask.softmax.dropout"), - FusedAttentionScaleBiasMaskSoftmaxDropoutBackward); - registry.Register(fused_attention("scale.mask.softmax.dropout"), - FusedAttentionScaleMaskSoftmaxDropoutBackward); - // flash attention bwd - auto flash_attention = [](std::string name) { - return "xla.gpu.flash.attention.backward." + name; - }; - registry.Register(flash_attention("scale.bias.softmax"), - FlashAttentionScaleBiasSoftmaxBackward); - registry.Register(flash_attention("scale.softmax"), - FlashAttentionScaleSoftmaxBackward); - registry.Register(flash_attention("scale.bias.softmax.dropout"), - FlashAttentionScaleBiasSoftmaxDropoutBackward); - registry.Register(flash_attention("scale.softmax.dropout"), - FlashAttentionScaleSoftmaxDropoutBackward); - registry.Register(flash_attention("scale.bias.mask.softmax"), - FlashAttentionScaleBiasMaskSoftmaxBackward); - registry.Register(flash_attention("scale.mask.softmax"), - FlashAttentionScaleMaskSoftmaxBackward); - registry.Register(flash_attention("scale.bias.mask.softmax.dropout"), - FlashAttentionScaleBiasMaskSoftmaxDropoutBackward); - registry.Register(flash_attention("scale.mask.softmax.dropout"), - FlashAttentionScaleMaskSoftmaxDropoutBackward); -} -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/fused_attention.h b/third_party/xla/xla/service/gpu/runtime/fused_attention.h deleted file mode 100644 index 6c8a4ce6b2f790..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/fused_attention.h +++ /dev/null @@ -1,108 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_FUSED_ATTENTION_H_ -#define XLA_SERVICE_GPU_RUNTIME_FUSED_ATTENTION_H_ - -#include -#include - -#include "absl/container/node_hash_map.h" -#include "absl/synchronization/mutex.h" -#include "xla/mlir/runtime/transforms/custom_call_encoding.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/service/gpu/gpu_fused_mha_runner.h" - -namespace xla { -namespace gpu { - -// Registers XLA Gpu runtime fused attention custom calls. -void RegisterFusedAttentionCustomCalls( - runtime::DirectCustomCallRegistry& registry); - -// Register type names for fused attention attributes defined by MHLO dialect. -void RegisterFusedAttentionTypeIdNames(runtime::TypeIDNameRegistry& registry); - -// Add attributes encoding for fused attention attributes defined by LMHLO -// dialect. -void PopulateFusedAttentionForwardDAGSignatureAttrEncoding( - runtime::CustomCallAttrEncodingSet& encoding); - -// Registers XLA Gpu runtime fused attention backward custom calls. -void RegisterFusedAttentionBackwardCustomCalls( - runtime::DirectCustomCallRegistry& registry); - -// Add attributes encoding for fused attention backward attributes defined by -// LMHLO dialect. -void PopulateFusedAttentionBackwardDAGSignatureAttrEncoding( - runtime::CustomCallAttrEncodingSet& encoding); - -void PopulateFusedAttentionAlgorithmConfigAttrEncoding( - runtime::CustomCallAttrEncodingSet& encoding); - -//===----------------------------------------------------------------------===// -// Cache fused dot attention runners between invocations of fused dot attention -// custom calls. -//===----------------------------------------------------------------------===// -struct FusedAttentionRunner { - explicit FusedAttentionRunner(GpufMHAConfig config) - : config(std::move(config)), runner(this->config) {} - GpufMHAConfig config; - FusedMultiHeadedAttentionRunner runner; -}; - -struct FusedAttentionBackwardRunner { - explicit FusedAttentionBackwardRunner(GpufMHABackwardConfig config) - : config(std::move(config)), runner(this->config) {} - GpufMHABackwardConfig config; - FusedMultiHeadedAttentionBackwardRunner runner; -}; - -class StreamExecutorFusedAttentionRunners - : public runtime::StateVector {}; - -class StreamExecutorFusedAttentionBackwardRunners - : public runtime::StateVector {}; - -// Xla executable keeps a mapping from stream executors to fused attention -// runners. -class FusedAttentionRunners { - public: - StreamExecutorFusedAttentionRunners* operator()(se::StreamExecutor* executor); - - private: - mutable absl::Mutex mutex_; - absl::node_hash_map - runners_ ABSL_GUARDED_BY(mutex_); -}; - -// Xla executable keeps a mapping from stream executors to fused attention -// backward runners. -class FusedAttentionBackwardRunners { - public: - StreamExecutorFusedAttentionBackwardRunners* operator()( - se::StreamExecutor* executor); - - private: - mutable absl::Mutex mutex_; - absl::node_hash_map - runners_ ABSL_GUARDED_BY(mutex_); -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_FUSED_ATTENTION_H_ diff --git a/third_party/xla/xla/service/gpu/runtime3/fused_mha_thunk.cc b/third_party/xla/xla/service/gpu/runtime/fused_mha_thunk.cc similarity index 99% rename from third_party/xla/xla/service/gpu/runtime3/fused_mha_thunk.cc rename to third_party/xla/xla/service/gpu/runtime/fused_mha_thunk.cc index 731f4a087afa63..03978f0d1f8184 100644 --- a/third_party/xla/xla/service/gpu/runtime3/fused_mha_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/fused_mha_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime3/fused_mha_thunk.h" +#include "xla/service/gpu/runtime/fused_mha_thunk.h" #include #include diff --git a/third_party/xla/xla/service/gpu/runtime3/fused_mha_thunk.h b/third_party/xla/xla/service/gpu/runtime/fused_mha_thunk.h similarity index 97% rename from third_party/xla/xla/service/gpu/runtime3/fused_mha_thunk.h rename to third_party/xla/xla/service/gpu/runtime/fused_mha_thunk.h index c00ea8240942ce..9b7cbf000d12db 100644 --- a/third_party/xla/xla/service/gpu/runtime3/fused_mha_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/fused_mha_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME3_FUSED_MHA_THUNK_H_ -#define XLA_SERVICE_GPU_RUNTIME3_FUSED_MHA_THUNK_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_FUSED_MHA_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_FUSED_MHA_THUNK_H_ #include @@ -127,4 +127,4 @@ class FusedMHABackwardThunk : public Thunk { }; } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_RUNTIME3_FUSED_MHA_THUNK_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_FUSED_MHA_THUNK_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/gemm.cc b/third_party/xla/xla/service/gpu/runtime/gemm.cc deleted file mode 100644 index a4a1e47baca94c..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/gemm.cc +++ /dev/null @@ -1,208 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/gemm.h" - -#include -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/service/gpu/gpu_asm_opts_util.h" -#include "xla/service/gpu/matmul_utils.h" -#include "xla/service/gpu/non_atomically_upgradeable_rw_lock.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/hlo_module_config.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/status.h" -#include "xla/stream_executor/blas.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/xla.pb.h" -#include "tsl/platform/errors.h" - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/service/gpu/gemm_algorithm_picker.h" -#include "xla/stream_executor/gpu/redzone_allocator.h" -#endif - -namespace xla { -namespace gpu { - -using xla::runtime::CustomCall; -using xla::runtime::State; -using xla::runtime::StridedMemrefView; - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -// TODO(ezhulenev): Delete run time auto tuning from XLA. -absl::Status DoRuntimeAutotuning(se::Stream* stream, GemmConfig* config, - se::DeviceMemoryBase lhs_buffer, - se::DeviceMemoryBase rhs_buffer, - se::DeviceMemoryBase output_buffer, - const Shape& output_shape, double beta, - const DebugOptions* debug_options, - NonAtomicallyUpgradeableRWLock* gpu_lock) { - VLOG(3) << "Running GEMM runtime autotuning"; - std::vector algorithms; - stream->parent()->GetBlasGemmAlgorithms(stream, &algorithms); - const bool deterministic_ops = debug_options->xla_gpu_deterministic_ops(); - - AutotuneConfig autotune_config{ - DeviceConfig{stream->parent(), stream->parent()->GetAllocator()}, - *debug_options}; - - // TODO(jlebar): We should not use stream->parent()->GetAllocator() here; - // that's the global CUDA allocator. There may not be any free space in - // there, because TF usually gobbles it all up for its own BFCAllocator. We - // should use the allocator the user passed when running the XLA program. - se::RedzoneAllocator buffer_allocator( - stream, stream->parent()->GetAllocator(), - PtxOptsFromDebugOptions(*debug_options), - /*memory_limit=*/std::numeric_limits::max(), - /*redzone_size=*/autotune_config.should_check_correctness() - ? debug_options->xla_gpu_redzone_padding_bytes() - : 0); - - // Upgrade the reader lock for execution to a writer lock to protect runtime - // autotuning. - NonAtomicallyUpgradeableRWLock::WriterLock writer_lock = - gpu_lock->UpgradeToWriterMutexLock(); - - TF_ASSIGN_OR_RETURN( - AutotuneResult best_algorithm, - GetBestBlasAlgorithm( - stream, buffer_allocator, /*gemm_str=*/std::nullopt, autotune_config, - lhs_buffer, rhs_buffer, output_buffer, algorithms, output_shape, - HloModuleConfig(), beta, - [&](const se::blas::AlgorithmType& algorithm) - -> absl::StatusOr { - se::blas::ProfileResult profile_result; - // We expect GemmWithAlgorithm to fail sometimes -- in fact, it will - // fail for all algorithms if we're targeting < sm_50. But because - // we pass a non-null ProfileResult, DoGemmWithAlgorithm should - // always return true, and the actual success-ness is returned in - // ProfileResult::is_valid. - TF_RETURN_IF_ERROR( - RunGemm(*config, lhs_buffer, rhs_buffer, output_buffer, - se::DeviceMemoryBase(nullptr, 0), deterministic_ops, - stream, algorithm, &profile_result)); - return std::move(profile_result); - })); - - if (best_algorithm.has_gemm()) { - config->algorithm = algorithms[best_algorithm.gemm().algorithm()]; - return absl::OkStatus(); - } else { - return Internal("Runtime autotuning failed to select an algorithm"); - } -} -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -static absl::Status GemmImpl(const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, - NonAtomicallyUpgradeableRWLock* gpu_lock, - State state, StridedMemrefView lhs, - StridedMemrefView rhs, StridedMemrefView out, - StridedMemrefView workspace, int64_t algorithm, - double alpha_real, double alpha_imag, double beta, - DotDimensionNumbers dot_dims, - absl::Span precision) { - se::DeviceMemoryBase lhs_data = GetDeviceAddress(lhs); - se::DeviceMemoryBase rhs_data = GetDeviceAddress(rhs); - se::DeviceMemoryBase output_data = GetDeviceAddress(out); - se::DeviceMemoryBase workspace_data = GetDeviceAddress(workspace); - const bool deterministic_ops = debug_options->xla_gpu_deterministic_ops(); - - VLOG(3) << "Running GEMM"; - se::Stream* stream = run_options->stream(); - Shape output_shape = ToShape(out); - - // Get the gemm config from the state. - TF_ASSIGN_OR_RETURN(GemmConfig * gemm_config, state.GetOrCreate([&] { - absl::StatusOr gemm_config = - GetGemmConfig(lhs, rhs, out, algorithm, alpha_real, alpha_imag, beta, - dot_dims.lhs_batch, dot_dims.lhs_contract, - dot_dims.rhs_batch, dot_dims.rhs_contract, - precision.empty() ? se::blas::kDefaultComputePrecision - : *absl::c_max_element(precision)); - return ToAbsl(gemm_config); - })); - - // Set the gemm algorithm by runtime autotuning. We do runtime autotuning - // outside of state.GetOrCreate() because otherwise it would be a potential - // deadlock. - if (gemm_config->algorithm == stream_executor::blas::kRuntimeAutotuning) { -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - auto status = DoRuntimeAutotuning(stream, gemm_config, lhs_data, rhs_data, - output_data, output_shape, beta, - debug_options, gpu_lock); - if (!status.ok()) { - return absl::InternalError(status.ToString()); - } -#else - return absl::InternalError( - "Failed to run runtime autotuner because GPU support is not enabled"); -#endif - } - - return RunGemm(*gemm_config, lhs_data, rhs_data, output_data, workspace_data, - deterministic_ops, stream); -} - -static absl::Status InitCuBLASImpl( - const ServiceExecutableRunOptions* run_options) { - // Initialize (with memoization) BlasSupport here because cublasCreate fails - // during gpu graph capturing. - se::StreamExecutor* executor = run_options->stream()->parent(); - if (!executor->AsBlas()) { - return absl::InternalError("Failed to initialize BLAS support"); - } - return absl::OkStatus(); -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - Gemm, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.gemm") - .UserData() - .UserData() - .UserData() - .State("uid") - .Arg() // lhs - .Arg() // rhs - .Arg() // out - .Arg() // workspace - .Attr("algorithm") - .Attr("alpha_real") - .Attr("alpha_imag") - .Attr("beta") - .Attr("dot_dims") - .Attr>("precision")); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - InitCuBLAS, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.init_cublas") - .UserData()); - -void RegisterGemmCustomCalls(runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.gemm", Gemm); - registry.Register("xla.gpu.init_cublas", InitCuBLAS); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/gemm.h b/third_party/xla/xla/service/gpu/runtime/gemm.h deleted file mode 100644 index 8ccaea1ea25131..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/gemm.h +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_GEMM_H_ -#define XLA_SERVICE_GPU_RUNTIME_GEMM_H_ - -#include "absl/container/node_hash_map.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/runtime/state.h" -#include "xla/service/gpu/matmul_utils.h" - -namespace xla { -namespace gpu { - -// Registers XLA Gpu runtime Gemm# custom calls. -void RegisterGemmCustomCalls(runtime::DirectCustomCallRegistry& registry); - -// Keep GemmConfigs for all gemm/matmul instances in the executable. -class GemmConfigs : public runtime::StateVector {}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_GEMM_H_ diff --git a/third_party/xla/xla/service/gpu/runtime3/gemm_thunk.cc b/third_party/xla/xla/service/gpu/runtime/gemm_thunk.cc similarity index 91% rename from third_party/xla/xla/service/gpu/runtime3/gemm_thunk.cc rename to third_party/xla/xla/service/gpu/runtime/gemm_thunk.cc index 0e11081f31d404..cc74707c84e01f 100644 --- a/third_party/xla/xla/service/gpu/runtime3/gemm_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/gemm_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime3/gemm_thunk.h" +#include "xla/service/gpu/runtime/gemm_thunk.h" #include @@ -48,10 +48,14 @@ absl::Status GemmThunk::ExecuteOnStream(const ExecuteParams& params) { if (workspace_.has_value()) { workspace = allocs.GetDeviceAddress(workspace_.value()); } + TF_ASSIGN_OR_RETURN( + se::Stream * stream, + GetStreamForExecution(Thunk::execution_stream_id(), params)); + return RunGemm(config_, allocs.GetDeviceAddress(lhs_buffer_), allocs.GetDeviceAddress(rhs_buffer_), allocs.GetDeviceAddress(output_buffer_), workspace, - deterministic_, params.stream); + deterministic_, stream); } absl::Status GemmThunk::Initialize(const InitializeParams& params) { diff --git a/third_party/xla/xla/service/gpu/runtime3/gemm_thunk.h b/third_party/xla/xla/service/gpu/runtime/gemm_thunk.h similarity index 94% rename from third_party/xla/xla/service/gpu/runtime3/gemm_thunk.h rename to third_party/xla/xla/service/gpu/runtime/gemm_thunk.h index 79dffc09e1b7c5..a134ed2623b1ae 100644 --- a/third_party/xla/xla/service/gpu/runtime3/gemm_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/gemm_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME3_GEMM_THUNK_H_ -#define XLA_SERVICE_GPU_RUNTIME3_GEMM_THUNK_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_GEMM_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_GEMM_THUNK_H_ #include @@ -66,4 +66,4 @@ class GemmThunk : public Thunk { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_RUNTIME3_GEMM_THUNK_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_GEMM_THUNK_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/gpu_kernel_helper.h b/third_party/xla/xla/service/gpu/runtime/gpu_kernel_helper.h deleted file mode 100644 index f5a69111dc7342..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/gpu_kernel_helper.h +++ /dev/null @@ -1,148 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_GPU_KERNEL_HELPER_H_ -#define XLA_SERVICE_GPU_RUNTIME_GPU_KERNEL_HELPER_H_ - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -#include - -#include "tsl/lib/math/math_util.h" - -#if GOOGLE_CUDA -#include "third_party/gpus/cuda/include/cuda_runtime_api.h" -#else -#include "rocm/include/hip/hip_runtime.h" -#endif - -namespace xla { -namespace gpu { - -#if GOOGLE_CUDA -#define WAVEFRONT_SIZE 32 -#define FORCEINLINE __forceinline__ -using gpuStream_t = cudaStream_t; -using gpuError_t = cudaError_t; -using gpuEvent_t = cudaEvent_t; -#define gpuSuccess cudaSuccess -#define gpuGetLastError cudaGetLastError -#define gpuGetErrorString cudaGetErrorString -#define gpuEventRecord cudaEventRecord -#define gpuEventSynchronize cudaEventSynchronize -#define gpuEventDestroy cudaEventDestroy -#define gpuEventCreate cudaEventCreate -#define gpuEventCreateWithFlags cudaEventCreateWithFlags -#define gpuEventDisableTiming cudaEventDisableTiming -#define gpuEventElapsedTime cudaEventElapsedTime -#define gpuDeviceSynchronize cudaDeviceSynchronize -#define gpuLaunchKernel cudaLaunchKernel -#define gpuMemcpy cudaMemcpy -#define gpuMalloc cudaMalloc -#define gpuFree cudaFree -#define gpuMemcpyHostToDevice cudaMemcpyHostToDevice -#define gpuMemcpyDeviceToHost cudaMemcpyDeviceToHost -#define gpuStreamCreate cudaStreamCreate -#define gpuStreamSynchronize cudaStreamSynchronize - -#elif TENSORFLOW_USE_ROCM -using gpuStream_t = hipStream_t; -using gpuError_t = hipError_t; -using gpuEvent_t = hipEvent_t; -#define gpuSuccess hipSuccess -#define gpuGetLastError hipGetLastError -#define gpuGetErrorString hipGetErrorString -#define gpuEventRecord hipEventRecord -#define gpuEventDestroy hipEventDestroy -#define gpuEventSynchronize hipEventSynchronize -#define gpuEventCreate hipEventCreate -#define gpuEventCreateWithFlags hipEventCreateWithFlags -#define gpuEventDisableTiming hipEventDisableTiming -#define gpuEventElapsedTime hipEventElapsedTime -#define gpuDeviceSynchronize hipDeviceSynchronize -#define gpuLaunchKernel hipLaunchKernel -#define gpuMemcpy hipMemcpy -#define gpuMalloc hipMalloc -#define gpuFree hipFree -#define gpuMemcpyHostToDevice hipMemcpyHostToDevice -#define gpuMemcpyDeviceToHost hipMemcpyDeviceToHost -#define gpuStreamCreate hipStreamCreate -#define gpuStreamSynchronize hipStreamSynchronize - -#ifdef __AMDGCN_WAVEFRONT_SIZE -#define WAVEFRONT_SIZE __AMDGCN_WAVEFRONT_SIZE -#else -#define WAVEFRONT_SIZE 64 -#endif -#define FORCEINLINE __forceinline__ -#endif - -// macro wrapper to declare dynamic shared memory -#if GOOGLE_CUDA - -#define GPU_DYNAMIC_SHARED_MEM_DECL(ALIGN, TYPE, NAME) \ - extern __shared__ __align__(ALIGN) \ - TYPE NAME[] - -#elif TENSORFLOW_USE_ROCM - -#define GPU_DYNAMIC_SHARED_MEM_DECL(ALIGN, TYPE, NAME) \ - HIP_DYNAMIC_SHARED(TYPE, NAME) - -#endif - -enum class ShflType { Sync, Up, Down, Xor }; - -template -__device__ FORCEINLINE NT GpuShuffle(NT val, uint32_t idx, - uint32_t allmsk = 0xffffffffu) { - constexpr uint32_t SZ = - tsl::MathUtil::CeilOfRatio(sizeof(NT), sizeof(uint32_t)); - union S { - NT v; - uint32_t d[SZ]; - }; - S in{val}, res{}; - -#pragma unroll - for (uint32_t i = 0; i < SZ; i++) { -#if GOOGLE_CUDA - if constexpr (Type == ShflType::Sync) - res.d[i] = __shfl_sync(allmsk, in.d[i], idx); - else if constexpr (Type == ShflType::Up) - res.d[i] = __shfl_up_sync(allmsk, in.d[i], idx); - else if constexpr (Type == ShflType::Down) - res.d[i] = __shfl_down_sync(allmsk, in.d[i], idx); - else if constexpr (Type == ShflType::Xor) - res.d[i] = __shfl_xor_sync(allmsk, in.d[i], idx); -#elif TENSORFLOW_USE_ROCM // ROcm does not support sync shuffle intrinsics - if constexpr (Type == ShflType::Sync) - res.d[i] = __shfl(in.d[i], idx); - else if constexpr (Type == ShflType::Up) - res.d[i] = __shfl_up(in.d[i], idx); - else if constexpr (Type == ShflType::Down) - res.d[i] = __shfl_down(in.d[i], idx); - else if constexpr (Type == ShflType::Xor) - res.d[i] = __shfl_xor(in.d[i], idx); -#endif - } - return res.v; -} - -} // namespace gpu -} // namespace xla - -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#endif // XLA_SERVICE_GPU_RUNTIME_GPU_KERNEL_HELPER_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul.cc b/third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul.cc deleted file mode 100644 index 0183254e756a14..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul.cc +++ /dev/null @@ -1,302 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License.1 -==============================================================================*/ - -#include "xla/service/gpu/runtime/gpublas_lt_matmul.h" - -#include -#include -#include -#include -#include - -#include "xla/mlir/runtime/transforms/custom_call_encoding.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/runtime/logical_result.h" -#include "xla/runtime/state.h" -#include "xla/service/gpu/matmul_utils.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/stream_executor/scratch_allocator.h" -#include "xla/xla.pb.h" -#include "tsl/platform/status.h" - -#if TENSORFLOW_USE_ROCM -#include "rocm/rocm_config.h" -#endif - -namespace xla { -#if GOOGLE_CUDA || TF_HIPBLASLT - -using xla::runtime::CustomCall; -using xla::runtime::CustomCallAttrEncodingSet; -using xla::runtime::EnumAttrEncoding; -using xla::runtime::State; -using xla::runtime::StridedMemrefView; - -namespace lmhlo_gpu = ::mlir::lmhlo_gpu; - -//===----------------------------------------------------------------------===// -// Register cuBLASLt attributes decoding with the Xla runtime. -//===----------------------------------------------------------------------===// - -namespace runtime { -XLA_RUNTIME_REGISTER_ENUM_ATTR_DECODING(se::gpu::BlasLt::Epilogue); -} // namespace runtime - -//===----------------------------------------------------------------------===// -// Encoding from MHLO attributes to Xla runtime enums. -//===----------------------------------------------------------------------===// - -namespace gpu { - -void PopulateCublasLtMatmulAttrEncoding(CustomCallAttrEncodingSet& encoding) { - encoding.Add>( - [](lmhlo_gpu::CublasLtMatmulEpilogue value) -> se::gpu::BlasLt::Epilogue { - return gpublas_lt::AsBlasLtEpilogue(value).value(); - }); -} - -//===----------------------------------------------------------------------===// -// cuBLASLt matmul custom call implementation. -//===----------------------------------------------------------------------===// - -namespace { - -absl::Status DoMatmul( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, State gemm_config, - State matmul_plan, StridedMemrefView a, - StridedMemrefView b, StridedMemrefView c, StridedMemrefView d, - std::optional bias, std::optional aux, - std::optional a_scale, - std::optional b_scale, - std::optional c_scale, - std::optional d_scale, - std::optional d_amax, int64_t algorithm, - double alpha_real, double alpha_imag, double beta, - DotDimensionNumbers dot_dims, se::gpu::BlasLt::Epilogue epilogue, - absl::Span precision) { - se::Stream* stream = run_options->stream(); - - // Find the gemm config for this instance of matmul. - TF_ASSIGN_OR_RETURN(GemmConfig * config, gemm_config.GetOrCreate([&] { - return ToAbsl(GetGemmConfig( - a, b, d, algorithm, alpha_real, alpha_imag, beta, dot_dims.lhs_batch, - dot_dims.lhs_contract, dot_dims.rhs_batch, dot_dims.rhs_contract, - precision.empty() ? se::blas::kDefaultComputePrecision - : *absl::c_max_element(precision), - c, bias)); - })); - - // Get the matmul plan for this instance of matmul. - TF_ASSIGN_OR_RETURN(auto plan, matmul_plan.GetOrCreate([&] { - return ToAbsl(se::gpu::BlasLt::GetMatmulPlan(stream, *config, epilogue)); - })); - - TF_ASSIGN_OR_RETURN(auto algos, (*plan)->GetAlgorithms()); - if (static_cast(algorithm) >= algos.size()) { - return absl::InternalError( - absl::StrFormat("The requested gpublas-lt matmul " - "algorithm is not found. Total algorithms available: " - "%zu; requested: %ld", - algos.size(), algorithm)); - } - - se::DeviceMemoryBase a_data = GetDeviceAddress(a); - se::DeviceMemoryBase b_data = GetDeviceAddress(b); - se::DeviceMemoryBase c_data = GetDeviceAddress(c); - se::DeviceMemoryBase d_data = GetDeviceAddress(d); - se::DeviceMemoryBase bias_data; - if (bias.has_value()) bias_data = GetDeviceAddress(*bias); - se::DeviceMemoryBase aux_data; - if (aux.has_value()) aux_data = GetDeviceAddress(*aux); - - se::DeviceMemoryBase a_scale_data; - if (a_scale.has_value()) a_scale_data = GetDeviceAddress(*a_scale); - se::DeviceMemoryBase b_scale_data; - if (b_scale.has_value()) b_scale_data = GetDeviceAddress(*b_scale); - se::DeviceMemoryBase c_scale_data; - if (c_scale.has_value()) c_scale_data = GetDeviceAddress(*c_scale); - se::DeviceMemoryBase d_scale_data; - if (d_scale.has_value()) d_scale_data = GetDeviceAddress(*d_scale); - se::DeviceMemoryBase d_amax_data; - if (d_amax.has_value()) d_amax_data = GetDeviceAddress(*d_amax); - - se::OwningScratchAllocator<> scratch_allocator( - stream->parent()->device_ordinal(), stream->parent()->GetAllocator()); - - return (*plan)->ExecuteOnStream( - stream, a_data, b_data, c_data, d_data, bias_data, aux_data, a_scale_data, - b_scale_data, c_scale_data, d_scale_data, d_amax_data, algos[algorithm], - scratch_allocator); -} - -} // namespace - -static absl::Status CublasLtMatmulImpl( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, State gemm_config, - State matmul_plan, StridedMemrefView a, - StridedMemrefView b, StridedMemrefView c, StridedMemrefView d, - std::optional bias, std::optional aux, - int64_t algorithm, double alpha_real, double alpha_imag, double beta, - DotDimensionNumbers dot_dims, se::gpu::BlasLt::Epilogue epilogue, - absl::Span precision) { - VLOG(3) << "Running CublasLtMatmul"; - std::optional a_scale, b_scale, c_scale, d_scale, d_amax; - return DoMatmul(run_options, debug_options, gemm_config, matmul_plan, a, b, c, - d, bias, aux, a_scale, b_scale, c_scale, d_scale, d_amax, - algorithm, alpha_real, alpha_imag, beta, dot_dims, epilogue, - precision); -} - -static absl::Status CublasLtMatmulF8Impl( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, State gemm_config, - State matmul_plan, StridedMemrefView a, - StridedMemrefView b, StridedMemrefView c, StridedMemrefView a_scale, - StridedMemrefView b_scale, StridedMemrefView c_scale, - StridedMemrefView d_scale, StridedMemrefView d, - CustomCall::RemainingArgs remaining_args, int64_t algorithm, - double alpha_real, double alpha_imag, double beta, - DotDimensionNumbers dot_dims, se::gpu::BlasLt::Epilogue epilogue, - absl::Span precision) { - VLOG(3) << "Running CublasLtMatmulF8"; - std::optional bias, d_amax, aux; - int current_remaining_arg = 0; - - // Get bias, if present - if (epilogue == se::gpu::BlasLt::Epilogue::kBias || - epilogue == se::gpu::BlasLt::Epilogue::kBiasThenReLU || - epilogue == se::gpu::BlasLt::Epilogue::kBiasThenGELU || - epilogue == se::gpu::BlasLt::Epilogue::kBiasThenGELUWithAux) { - if (remaining_args.size() <= current_remaining_arg) { - return absl::InternalError("Epilogue not present in CublasLtMatmulF8 op"); - } - auto bias_or_failure = - remaining_args.get(current_remaining_arg++); - if (failed(bias_or_failure)) { - return absl::InternalError("Failed to get epilogue"); - } - bias = bias_or_failure.value(); - } - - // Get amax, if present - if (remaining_args.size() > current_remaining_arg) { - auto d_amax_or_failure = - remaining_args.get(current_remaining_arg++); - if (failed(d_amax_or_failure)) { - return absl::InternalError("Failed to get d_amax"); - } - d_amax = d_amax_or_failure.value(); - } - - return DoMatmul(run_options, debug_options, gemm_config, matmul_plan, a, b, c, - d, bias, aux, a_scale, b_scale, c_scale, d_scale, d_amax, - algorithm, alpha_real, alpha_imag, beta, dot_dims, epilogue, - precision); -} - -//===----------------------------------------------------------------------===// -// cuBLASLt custom calls bindings and registration. -//===----------------------------------------------------------------------===// - -template -auto BindMatmulAttributes(runtime::CustomCallBinding binding) { - return std::move(binding) - .template Attr("algorithm") - .template Attr("alpha_real") - .template Attr("alpha_imag") - .template Attr("beta") - .template Attr("dot_dims") - .template Attr("epilogue") - .template Attr>("precision"); -} - -auto CublasLtMatmulCall(const char* name) { - return CustomCall::Bind(name) - .UserData() - .UserData() - .State("uid") - .State("uid") - .Arg() // a - .Arg() // b - .Arg() // c - .Arg(); // d -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - CublasLtMatmul, FunctionWrapper(), checks, - BindMatmulAttributes(CublasLtMatmulCall("xla.gpu.cublas.lt.matmul") - .Value(std::optional()) // bias - .Value(std::optional()) // aux - )); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - CublasLtMatmulBias, FunctionWrapper(), checks, - BindMatmulAttributes(CublasLtMatmulCall("xla.gpu.cublas.lt.matmul.bias") - .Arg() // bias - .Value(std::optional()) // aux - )); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - CublasLtMatmulAux, FunctionWrapper(), checks, - BindMatmulAttributes(CublasLtMatmulCall("xla.gpu.cublas.lt.matmul.aux") - .Value(std::optional()) // bias - .Arg() // aux - )); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - CublasLtMatmulBiasAux, FunctionWrapper(), checks, - BindMatmulAttributes(CublasLtMatmulCall("xla.gpu.cublas.lt.matmul.bias.aux") - .Arg() // bias - .Arg() // aux - )); - -auto CublasLtMatmulF8Call(const char* name) { - return CustomCall::Bind(name) - .UserData() - .UserData() - .State("uid") - .State("uid") - .Arg() // a - .Arg() // b - .Arg() // c - .Arg() // a_scale - .Arg() // b_scale - .Arg() // c_scale - .Arg() // d_scale - .Arg(); // d -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - CublasLtMatmulF8, FunctionWrapper(), checks, - BindMatmulAttributes( - CublasLtMatmulF8Call("xla.gpu.cublas.lt.matmul.f8").RemainingArgs())); - -void RegisterMatmulCustomCalls(runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.cublas.lt.matmul", CublasLtMatmul); - registry.Register("xla.gpu.cublas.lt.matmul.bias", CublasLtMatmulBias); - registry.Register("xla.gpu.cublas.lt.matmul.aux", CublasLtMatmulAux); - registry.Register("xla.gpu.cublas.lt.matmul.bias.aux", CublasLtMatmulBiasAux); - registry.Register("xla.gpu.cublas.lt.matmul.f8", CublasLtMatmulF8); -} - -} // namespace gpu -#endif // GOOGLE_CUDA || TF_HIPBLASLT -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul.h b/third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul.h deleted file mode 100644 index d66a0db32bd2e3..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul.h +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_GPUBLAS_LT_MATMUL_H_ -#define XLA_SERVICE_GPU_RUNTIME_GPUBLAS_LT_MATMUL_H_ - -#include "xla/mlir/runtime/transforms/custom_call_encoding.h" -#include "xla/runtime/custom_call_registry.h" - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/service/gpu/matmul_utils.h" -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -namespace xla { -namespace gpu { - -// Add cuBLASLt attributes encoding -void PopulateCublasLtMatmulAttrEncoding( - runtime::CustomCallAttrEncodingSet& encoding); - -#if GOOGLE_CUDA || TF_HIPBLASLT - -// Registers XLA Gpu runtime cuBLASLt custom calls. -void RegisterMatmulCustomCalls(runtime::DirectCustomCallRegistry& registry); - -// Keep cublas_lt::MatmulPlan's for all matmul instances in the executable. -class MatmulPlans - : public runtime::StateVector {}; -#endif // GOOGLE_CUDA || TF_HIPBLASLT - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_GPUBLAS_LT_MATMUL_H_ diff --git a/third_party/xla/xla/service/gpu/runtime3/gpublas_lt_matmul_thunk.cc b/third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul_thunk.cc similarity index 98% rename from third_party/xla/xla/service/gpu/runtime3/gpublas_lt_matmul_thunk.cc rename to third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul_thunk.cc index ae0c7eb908cb8b..14cc3149163c1e 100644 --- a/third_party/xla/xla/service/gpu/runtime3/gpublas_lt_matmul_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime3/gpublas_lt_matmul_thunk.h" +#include "xla/service/gpu/runtime/gpublas_lt_matmul_thunk.h" #include diff --git a/third_party/xla/xla/service/gpu/runtime3/gpublas_lt_matmul_thunk.h b/third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul_thunk.h similarity index 94% rename from third_party/xla/xla/service/gpu/runtime3/gpublas_lt_matmul_thunk.h rename to third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul_thunk.h index e7eaf3a359815e..ca80a7bbd6ccea 100644 --- a/third_party/xla/xla/service/gpu/runtime3/gpublas_lt_matmul_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME3_GPUBLAS_LT_MATMUL_THUNK_H_ -#define XLA_SERVICE_GPU_RUNTIME3_GPUBLAS_LT_MATMUL_THUNK_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_GPUBLAS_LT_MATMUL_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_GPUBLAS_LT_MATMUL_THUNK_H_ #include @@ -79,4 +79,4 @@ class CublasLtMatmulThunk : public Thunk { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_RUNTIME3_GPUBLAS_LT_MATMUL_THUNK_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_GPUBLAS_LT_MATMUL_THUNK_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/graph_launch.cc b/third_party/xla/xla/service/gpu/runtime/graph_launch.cc deleted file mode 100644 index 2a23d9dfe97497..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/graph_launch.cc +++ /dev/null @@ -1,730 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/graph_launch.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/synchronization/mutex.h" -#include "absl/types/span.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/non_atomically_upgradeable_rw_lock.h" -#include "xla/service/gpu/runtime/concurrent_region.h" -#include "xla/service/gpu/runtime/conv.h" -#include "xla/service/gpu/runtime/gemm.h" -#include "xla/service/gpu/runtime/kernel_launch.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/statusor.h" -#include "tsl/profiler/lib/profiler_lock.h" -#include "tsl/profiler/lib/traceme.h" -#include "tsl/profiler/lib/traceme_encode.h" - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/stream_executor/gpu/gpu_graph.h" -#endif // #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -namespace xla { -namespace gpu { - -using tsl::profiler::TraceMe; -using tsl::profiler::TraceMeEncode; - -using xla::runtime::Arguments; -using xla::runtime::AsyncTaskRunner; -using xla::runtime::CustomCall; -using xla::runtime::Executable; -using xla::runtime::FunctionRef; -using xla::runtime::FunctionType; -using xla::runtime::MemrefDesc; -using xla::runtime::MemrefType; -using xla::runtime::StridedMemrefView; - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -using se::gpu::OwnedGpuGraph; - -// Captures Gpu graph by running given function in capture mode. -static absl::StatusOr CaptureGraph( - const ServiceExecutableRunOptions* run_options, - runtime::FunctionRef function_ref, Arguments& args, - CustomCall::UserData user_data); -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -//===----------------------------------------------------------------------===// -// GPU graphs caching. -//===----------------------------------------------------------------------===// - -struct GraphInstances::Impl { - struct State { - // A flag signalling if `InstantiateAllGraphs` was already called and we - // have all Gpu graph instantiated ahead of time. - bool instantiated = false; - - // Last time graph instances were used by a particular stream executor. - uint64_t last_use_micros = 0; - - std::shared_ptr instances = - std::make_shared(); - }; - - // XLA module name that owns graph instances. We use it only to produce logs - // that can be attributed back to XLA executables. - std::string module_name; - - // Number of graphs in the parent module. - int64_t num_graphs = 0; - - mutable absl::Mutex mu; - absl::node_hash_map graphs ABSL_GUARDED_BY(mu); -}; - -// Keep track of instantiated graphs on each StreamExecutor, we use this -// information in the graph eviction policy. -using GraphInstancesState = absl::flat_hash_map; - -static absl::Mutex* GetGraphInstancesStateMutex() { - static auto* mu = new absl::Mutex(); - return mu; -} - -static GraphInstancesState& GetGraphInstancesState() { - static auto* state = new GraphInstancesState(); - return *state; -} - -static int64_t NotifyGraphInstancesCreated(se::StreamExecutor* executor, - int64_t num_graphs) { - absl::MutexLock lock(GetGraphInstancesStateMutex()); - return GetGraphInstancesState()[executor] += num_graphs; -} - -static int64_t NotifyGraphInstancesDestroyed(se::StreamExecutor* executor, - int64_t num_graphs) { - absl::MutexLock lock(GetGraphInstancesStateMutex()); - return GetGraphInstancesState()[executor] -= num_graphs; -} - -// We keep track of all graph instances in the process, to implement graph -// eviction on OOM. Graph instances owned by GpuExecutable, so we rely on -// weak ptr to check if they are still alive. -using GraphInstancesVec = std::vector>; - -static absl::Mutex* GetGraphInstancesVecMutex() { - static auto* mu = new absl::Mutex(); - return mu; -} - -static GraphInstancesVec& GetGraphInstancesVec() { - static auto* vec = new GraphInstancesVec(); - return *vec; -} - -static void AddGraphInstances(std::weak_ptr impl) { - absl::MutexLock lock(GetGraphInstancesVecMutex()); - GetGraphInstancesVec().push_back(std::move(impl)); -} - -// Evicts all graphs for a given executor in the current process. -static void EvictAllGraphs( - se::StreamExecutor* executor, - std::optional eviction_timeout_seconds = std::nullopt) { - // We WARN only when we evict all Gpu graphs because it happens when we - // recover from OOM. Eviction by time out is business as usual. - if (eviction_timeout_seconds.has_value()) { - VLOG(3) << "Evict timed out gpu graphs from executor " << executor; - } else { - LOG(WARNING) << "Evict all gpu graphs from executor " << executor; - } - - TraceMe trace_instantiation([&] { - return TraceMeEncode("cuda.graph.evict_all_graphs", - {{"device_ordinal", executor->device_ordinal()}}); - }); - - absl::MutexLock lock(GetGraphInstancesVecMutex()); - auto& vec = GetGraphInstancesVec(); - - // Erase all expired graph instances. - vec.erase(std::remove_if(vec.begin(), vec.end(), - [](auto& weak_ptr) { return weak_ptr.expired(); }), - vec.end()); - - auto timed_out = [&](GraphInstances::Impl::State& state) -> bool { - if (!eviction_timeout_seconds.has_value()) { - return false; - } - - auto diff = tsl::Env::Default()->NowMicros() - state.last_use_micros; - return (diff / (1000 * 1000)) > *eviction_timeout_seconds; - }; - - int64_t num_evicted = 0; - - for (auto& weak_ptr : vec) { - auto ptr = weak_ptr.lock(); - if (!ptr) continue; - - if (!ptr->mu.TryLock()) continue; - - auto it = ptr->graphs.find(executor); - if (it == ptr->graphs.end()) { - ptr->mu.Unlock(); - continue; - } - - // If we have a timeout value, than check it first, otherwise always evict - // graphs for a given executor. - bool is_timed_out = timed_out(it->second); - if (eviction_timeout_seconds.has_value() && !is_timed_out) { - ptr->mu.Unlock(); - continue; - } - - if (ptr->num_graphs > 0) { - VLOG(3) << "Evict " << ptr->num_graphs << " graphs for: @" - << ptr->module_name << " at executor: " << executor - << " (timed_out = " << is_timed_out << ")." - << " Total remaining graphs at given executor: " - << NotifyGraphInstancesDestroyed(executor, ptr->num_graphs); - } - ptr->graphs.erase(it); - ptr->mu.Unlock(); - ++num_evicted; - } - - if (num_evicted > 0) { - VLOG(3) << "Evicted " << num_evicted << " graphs from executor " - << executor; -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - se::gpu::GpuGraphSupport::TrimDeviceMemory(executor); -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - } -} - -GraphInstances::GraphInstances(std::string module_name, int64_t num_graphs) - : impl_(std::make_shared()) { - impl_->module_name = std::move(module_name); - impl_->num_graphs = num_graphs; - if (impl_->num_graphs > 0) { - VLOG(3) << "Construct graph instances cache for: @" << impl_->module_name - << " (num_graphs = " << impl_->num_graphs << ")"; - } - AddGraphInstances(impl_); -} - -GraphInstances::~GraphInstances() { - if (impl_->num_graphs > 0) { - VLOG(3) << "Destroy graph instances cache for: @" << impl_->module_name - << " (num_graphs = " << impl_->num_graphs << ")"; - - absl::MutexLock lock(&impl_->mu); - for (auto& [executor, state] : impl_->graphs) { - VLOG(3) << "Destroy " << impl_->num_graphs << " graphs for: @" - << impl_->module_name << " at executor: " << executor - << ". Total remaining graphs at given executor: " - << NotifyGraphInstancesDestroyed(executor, impl_->num_graphs); - } - } -} - -std::shared_ptr GraphInstances::operator()( - se::StreamExecutor* executor) { - absl::MutexLock lock(&impl_->mu); - - auto it = impl_->graphs.try_emplace(executor); - if (it.second && impl_->num_graphs > 0) { - VLOG(3) << "Instantiate " << impl_->num_graphs << " graphs for: @" - << impl_->module_name << " at executor: " << executor - << ". Total graphs at given executor: " - << NotifyGraphInstancesCreated(executor, impl_->num_graphs); - } - - Impl::State& state = it.first->second; - state.last_use_micros = tsl::Env::Default()->NowMicros(); - return state.instances; -} - -bool GraphInstances::InstantiatedAllGraphs( - const ServiceExecutableRunOptions* run_options, - const Executable& executable) { - if (executable.num_functions() == 1) return true; - - absl::MutexLock lock(&impl_->mu); - return impl_->graphs[run_options->stream()->parent()].instantiated; -} - -absl::Status GraphInstances::InstantiateAllGraphs( - const ServiceExecutableRunOptions* run_options, - const Executable& executable, const CustomCall::UserData& user_data, - const BufferAllocations& buffer_allocations, - absl::Span buffer_sizes, - absl::Span> allocation_indices, - std::optional eviction_timeout_seconds) { - // We have only "main" function in the executable. - if (executable.num_functions() == 1) return absl::OkStatus(); - - absl::MutexLock lock(&impl_->mu); - se::StreamExecutor* executor = run_options->stream()->parent(); - - Impl::State& state = impl_->graphs[executor]; - - // All Gpu graphs are already instantiated for a given executor. - if (state.instantiated) return absl::OkStatus(); - - TraceMe trace("gpu.graph.instantiate_all"); - - // Evict all timeout graphs before trying to instantiate new ones. - EvictAllGraphs(executor, eviction_timeout_seconds); - - // We'll retry graph instantiation on OOM errors after evicting all graphs - // instantiated on `executor`. - int32_t num_retries = 0; - - StreamExecutorGraphInstances::Snapshot instances = - state.instances->snapshot(); - - // Instantiate all Gpu graphs by calling graph capture functions with fake - // arguments. Once we'll execute them first time for real, they'll be updated - // with correct pointers. - for (unsigned ordinal = 1; ordinal < executable.num_functions(); ++ordinal) { - if (!absl::StartsWith(executable.function_name(ordinal), - "xla.gpu.graph.capture")) - continue; - - VLOG(3) << "Instantiate Gpu graph defined by capture function @" - << executable.function_name(ordinal) << " (ordinal = " << ordinal - << ")"; - - TraceMe trace_instantiation([&] { - return TraceMeEncode("gpu.graph.instantiate", {{"ordinal", ordinal}}); - }); - - FunctionRef function_ref = executable.function_ref(ordinal); - - const FunctionType& signature = executable.signature(ordinal); - assert(signature.num_results() == 0 && "unexpected number of results"); - Arguments args(signature.num_operands()); - - // Mapping from graph capture argument to buffer allocation index. - absl::Span capture_allocs = allocation_indices[ordinal]; - if (capture_allocs.size() != signature.num_operands()) - return absl::InternalError( - "Invalid number of allocation indices for a graph capture function"); - - // Prepare arguments for the graph capture function. - for (size_t j = 0; j < signature.num_operands(); ++j) { - auto* memref = llvm::dyn_cast(signature.operand(j)); - - if (!memref) - return absl::InternalError(absl::StrFormat( - "Unsupported capture function argument type #%d", j)); - - if (memref->sizes().size() != 1) - return absl::InternalError( - absl::StrFormat("Unsupported capture function memref rank #%d: %d", - j, memref->sizes().size())); - - std::array sizes = {memref->size(0)}; - std::array strides = {1}; - - int64_t allocation_index = capture_allocs[j]; - args.emplace_back( - memref->element_type(), - buffer_allocations.GetDeviceAddress(allocation_index).opaque(), - /*offset=*/0, sizes, strides); - } - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - // Instantiate a Gpu graph with fake arguments. - auto instantiate = [&]() -> absl::StatusOr { - TF_ASSIGN_OR_RETURN( - auto g, CaptureGraph(run_options, function_ref, args, user_data)); - TF_ASSIGN_OR_RETURN(auto e, se::gpu::InstantiateGpuGraph(std::move(g))); - return GraphInstance(0, std::move(e)); - }; - - absl::StatusOr instance = - instances.GetOrCreate(ordinal, instantiate); - - if (instance.status().code() == absl::StatusCode::kResourceExhausted) { - if (num_retries == 0) { - LOG(WARNING) << "InstantiateAllGraph failed due to insufficient memory." - " Try to evict all graphs and free device memory."; - - // Retry on OOM error after evicting all graphs from executor. - EvictAllGraphs(executor); - num_retries++; - ordinal--; // we'll try to instantiate the same graph one more time - continue; - } else { - LOG(WARNING) << "InstantiateAllGraph failed due to insufficient memory." - " Unitialized graphs will run in op-by-op mode."; - return absl::OkStatus(); - } - } - - // Otherwise return an error to the caller. - if (!instance.ok()) return instance.status(); -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - } - - state.instantiated = true; - return absl::OkStatus(); -} - -CapturedFunctionExecutionCount* CapturedFunctionExecutionCounts::operator()( - se::StreamExecutor* executor) { - absl::MutexLock lock(&mutex_); - return &counts_[executor]; -} - -//===----------------------------------------------------------------------===// -// Helper structure to hash the remaining arguments' memref pointers. -//===----------------------------------------------------------------------===// - -struct RemainingArgsPtrs { - CustomCall::RemainingArgs args; - se::DeviceMemoryBase* temp_buffer; - - template - friend H AbslHashValue(H h, const RemainingArgsPtrs& m); -}; - -template -H AbslHashValue(H h, const RemainingArgsPtrs& m) { - for (size_t i = 0; i < m.args.size(); ++i) { - if (auto memref = m.args.get(i); succeeded(memref)) - h = H::combine(std::move(h), memref->data); - } - return std::move(H::combine(std::move(h), m.temp_buffer->opaque())); -} - -//----------------------------------------------------------------------------// -// Runs capture function exported by the executable to construct a gpu graph. -//----------------------------------------------------------------------------// - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -static bool InDebugMode() { -#ifdef NDEBUG - return false; -#endif - return true; -} - -// Forwards custom call arguments to an arguments container that can be passed -// to an executable function. -static absl::Status ForwardArguments(CustomCall::RemainingArgs fwd_args, - Arguments& args) { - for (size_t i = 0; i < fwd_args.size(); ++i) { - if (auto memref = fwd_args.get(i); succeeded(memref)) { - args.emplace_back(memref->dtype, memref->data, /*offset=*/0, - memref->sizes, memref->strides); - continue; - } - - return absl::InvalidArgumentError("Unsupported argument type"); - } - - return absl::OkStatus(); -} - -static absl::StatusOr CaptureGraph( - const ServiceExecutableRunOptions* run_options, - runtime::FunctionRef function_ref, Arguments& args, - CustomCall::UserData user_data) { - // We capture graph on a borrowed stream because we do not want to - // accidentally record any concurrent kernel launches from other XLA - // executables. - se::StreamExecutor* executor = run_options->stream()->parent(); - - // Initialize (with memoization) BlasSupport here because cublasCreate fails - // during gpu graph capturing. - if (function_ref.RequiresBlas()) { - if (!executor->AsBlas()) { - return absl::InternalError("Failed to initialize BLAS support"); - } - } - - absl::StatusOr capture_stream = - run_options->BorrowStream(executor->device_ordinal()); - - if (!capture_stream.ok()) - return absl::InternalError( - absl::StrFormat("Failed to borrow a stream for graph capture: %s", - capture_stream.status().message())); - - TraceMe trace([&] { - return TraceMeEncode("gpu.graph.capture", - {{"ordinal", function_ref.ordinal()}}); - }); - - // TODO(ezhulenev): Pass graph capture context explicitly to the custom calls - // via UserData to be able to detect when executing custom call in graph - // capture mode. Currently we rely on the fact that we know for sure that - // operations in the graph capture function do not need anything except the - // main stream (we capture only kernel launches). - ExecutableRunOptions capture_run_options; - capture_run_options.set_stream(capture_stream->get()); - - const ServiceExecutableRunOptions capture_opts(capture_run_options); - user_data.insert(&capture_opts); - - // Collect all emitted diagnostic messages. - std::string diagnostic; - runtime::DiagnosticEngine diagnostic_engine; - AppendDiagnosticToString(diagnostic_engine, &diagnostic); - - // Prepare options for executing graph capture function. - Executable::ExecuteOpts opts; - opts.custom_call_data = &user_data; - opts.diagnostic_engine = &diagnostic_engine; - - // Graph capture function should not launch any async tasks. - opts.async_task_runner = reinterpret_cast(0XDEADBEEF); - - // Create a graph from running the graph capture function. - auto captured = se::gpu::CaptureGpuGraph(capture_stream->get(), [&]() { - return function_ref(args, runtime::NoResultConverter{}, opts, - /*verify_arguments=*/InDebugMode()) - .status(); - }); - - if (!captured.ok()) { - return Internal("CaptureGpuGraph failed (%s): %s", - diagnostic.empty() ? "" : diagnostic, - captured.status().ToString()); - } - return std::move(*captured); -} - -// When graph execution is disabled we run the graph capture function in -// "regular" mode and execute all operation one by one. -static absl::Status RunGraphOpByOp( - const ServiceExecutableRunOptions* run_options, - runtime::FunctionRef function_ref, CustomCall::RemainingArgs fwd_args, - CustomCall::UserData user_data) { - // Prepare options for executing graph capture function. - Executable::ExecuteOpts opts; - auto* concurrent_region_status = user_data.get(); - // Ops should not run in parallel during op-by-op execution. - concurrent_region_status->DisableConcurrentRegion(); - opts.custom_call_data = &user_data; - - TraceMe trace([&] { - return TraceMeEncode("gpu.graph.run_op_by_op_fallback", - {{"ordinal", function_ref.ordinal()}}); - }); - - // Collect all emitted diagnostic messages. - std::string diagnostic; - runtime::DiagnosticEngine diagnostic_engine; - AppendDiagnosticToString(diagnostic_engine, &diagnostic); - - opts.diagnostic_engine = &diagnostic_engine; - - // Graph capture function should not launch any async tasks. - opts.async_task_runner = reinterpret_cast(0XDEADBEEF); - - Arguments args(fwd_args.size()); - TF_RETURN_IF_ERROR(ForwardArguments(fwd_args, args)); - - auto executed = - function_ref(args, runtime::NoResultConverter{}, opts, InDebugMode()); - concurrent_region_status->EnableConcurrentRegion(); - if (!executed.ok()) { - return Internal("RunGraphOpByOp failed (%s): %s", - diagnostic.empty() ? "" : diagnostic, - executed.status().ToString()); - } - return absl::OkStatus(); -} - -#endif // #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -//===----------------------------------------------------------------------===// -// Define the gpu graph launch custom call. -//===----------------------------------------------------------------------===// - -static absl::Status LaunchGraph( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, const std::string* ptx, - const std::vector* cubin, se::DeviceMemoryBase* temp_buffer, - StreamExecutorKernels::Snapshot* kernels, - StreamExecutorConvRunners::Snapshot* convs, - StreamExecutorGraphInstances::Snapshot* instances, - CapturedFunctionExecutionCount::Snapshot* counts, - GemmConfigs::Snapshot* gemm_config, runtime::Executable* executable, - NonAtomicallyUpgradeableRWLock* gpu_lock, - ConcurrentRegionStatus* region_status, CustomCall::RemainingArgs fwd_args, - CustomCall::FunctionOrdinal capture) { -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - VLOG(1) << "Launch GPU Graph: ordinal = " << capture.ordinal; - - // Get a reference to exported function that captures the gpu graph. - runtime::FunctionRef function_ref = executable->function_ref(capture.ordinal); - - // Compute the hash of the buffer arguments. - size_t ptrs_hash = absl::HashOf(RemainingArgsPtrs{fwd_args, temp_buffer}); - - // Forwards user data required for launching kernels. - auto user_data = [&] { - return CustomCall::UserData(run_options, debug_options, ptx, cubin, - temp_buffer, kernels, convs, executable, - gemm_config, gpu_lock, region_status); - }; - - TF_ASSIGN_OR_RETURN(std::unique_ptr> * get_count, - counts->GetOrCreate(capture.ordinal, [] { - return std::make_unique>(0); - })); - - int64_t count = (*get_count)->fetch_add(1); - int64_t num_runs_to_instantiate = - debug_options->xla_gpu_graph_num_runs_to_instantiate(); - - // TODO(b/290773547): Profiler + CUDA graphs lead to memory corruption. As a - // work around disable graph execution and run everything in op-by-op mode. - bool is_profiling = tsl::profiler::ProfilerLock::HasActiveSession(); - - if (count < num_runs_to_instantiate || is_profiling) { - VLOG(3) << "Run gpu graph in op-by-op mode: ordinal = " << capture.ordinal; - return RunGraphOpByOp(run_options, function_ref, fwd_args, user_data()); - } - - // Instantiate Gpu graph by running graph capture function. - auto instantiate = [&]() -> absl::StatusOr { - Arguments args(fwd_args.size()); - TF_RETURN_IF_ERROR(ForwardArguments(fwd_args, args)); - - TF_ASSIGN_OR_RETURN( - auto g, CaptureGraph(run_options, function_ref, args, user_data())); - - TF_ASSIGN_OR_RETURN(auto e, se::gpu::InstantiateGpuGraph(std::move(g))); - - return GraphInstance(ptrs_hash, std::move(e)); - }; - - GraphInstance* instance; - if (num_runs_to_instantiate < 0) { - // If num_runs_to_instantiate is less than 0, all graphs should be - // instantiated ahead-of-time. If we fail to get the graph instance, then - // graph instantiation failed due to OOM. So we run the graph op-by-op. - absl::StatusOr try_get_instance = - instances->Get(capture.ordinal); - if (try_get_instance.ok()) { - instance = try_get_instance.value(); - } else { - return RunGraphOpByOp(run_options, function_ref, fwd_args, user_data()); - } - } else { - TF_ASSIGN_OR_RETURN(instance, - instances->GetOrCreate(capture.ordinal, instantiate)); - } - - { - // Lock graph instance for read only access. If we'll have to update the - // graph, we'll update to a writer lock below. - absl::ReaderMutexLock lock(instance->mutex.get()); - - // If pointers did not change we can run captured graph. - if (ptrs_hash == instance->ptr_hash) { - TraceMe trace([&] { - return TraceMeEncode("gpu.graph.launch_cached", - {{"ordinal", capture.ordinal}}); - }); - - VLOG(3) << "Execute cached graph instance"; - return instance->exec.Launch(run_options->stream()); - } - } - - // Otherwise we have to re-capture the graph and update the graph instance. - VLOG(3) << "Update cached graph instance"; - - Arguments args(fwd_args.size()); - TF_RETURN_IF_ERROR(ForwardArguments(fwd_args, args)); - - // Capture GPU graph by running capture function. - TF_ASSIGN_OR_RETURN( - auto g, CaptureGraph(run_options, function_ref, args, user_data())); - - // At this point we have to grab a writer lock, because we might potentially - // have concurrent execution of the cached graph instance. - absl::WriterMutexLock lock(instance->mutex.get()); - - // Update captured graph executable. - TF_RETURN_IF_ERROR(instance->exec.Update(std::move(g))); - - // Update captured pointer hash. - instance->ptr_hash = ptrs_hash; - - TraceMe trace([&] { - return TraceMeEncode("gpu.graph.launch_updated", - {{"ordinal", capture.ordinal}}); - }); - - return instance->exec.Launch(run_options->stream()); - -#else // #if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM - - return absl::InternalError("GPU graphs are not supported"); - -#endif // #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -} - -//===----------------------------------------------------------------------===// - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - Launch, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.graph.launch") - .UserData() - .UserData() - .UserData() - .UserData*>() - .UserData() - .UserData() - .UserData() - .UserData() - .UserData() - .UserData() - .UserData() - .UserData() - .UserData() - .RemainingArgs() - .Attr("capture")); - -void RegisterGraphLaunchCustomCalls( - runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.graph.launch", Launch); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/graph_launch.h b/third_party/xla/xla/service/gpu/runtime/graph_launch.h deleted file mode 100644 index e030781aeabc68..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/graph_launch.h +++ /dev/null @@ -1,140 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_GRAPH_LAUNCH_H_ -#define XLA_SERVICE_GPU_RUNTIME_GRAPH_LAUNCH_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/node_hash_map.h" -#include "absl/types/span.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/runtime/executable.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/stream_executor/stream_executor.h" - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/stream_executor/gpu/gpu_graph.h" -#endif // #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -namespace xla { -namespace gpu { - -// Registers XLA Gpu runtime graph launch custom calls. -void RegisterGraphLaunchCustomCalls( - runtime::DirectCustomCallRegistry& registry); - -struct GraphInstance; // Forward declare -class StreamExecutorGraphInstances; // Forward declare - -// A state vector that keeps track of the number of times a capture function -// gets executed. Graph capture function ordinal is the key in this container. -class CapturedFunctionExecutionCount - : public runtime::StateVector>> {}; - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -// A state vector that owns all instantiated GPU graphs. Graph capture function -// ordinal is the key in this container. -class StreamExecutorGraphInstances - : public runtime::StateVector {}; - -// Instantiated GPU graph instance guarded with a mutex for exclusive access. -struct GraphInstance { - GraphInstance(size_t ptr_hash, se::gpu::OwnedGpuGraphExec exec) - : ptr_hash(ptr_hash), exec(std::move(exec)), mutex(new absl::Mutex) {} - - // Graph instance is fully identified by the hash of its pointer arguments - // because currently it's guaranteed that all shapes and launch dimensions - // will be constant from run to run. - size_t ptr_hash ABSL_GUARDED_BY(*mutex); - se::gpu::OwnedGpuGraphExec exec ABSL_GUARDED_BY(*mutex); - - // Access to a graph instance must be synchronized, because we potentially can - // run concurrent graph instance updates. - std::unique_ptr mutex; -}; - -#else // #if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM - -// Define empty struct and empty state when GPU is not enabled. -struct GraphInstance {}; -class StreamExecutorGraphInstances - : public runtime::StateVector {}; - -#endif // #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -// Xla executable keeps a mapping from stream executors to graph instances. -// -// Graph instances allocate on-device memory, so we periodically destroy -// them to free up some space on device. JAX for example keeps all XLA -// executables alive, and destroys them when the process shuts down, so we can -// end up with thousands of unused (or rarely used) graphs in device memory. -class GraphInstances { - public: - struct Impl; - - GraphInstances(std::string module_name, int64_t num_graphs); - ~GraphInstances(); - - std::shared_ptr operator()( - se::StreamExecutor* executor); - - // Instantiates all Gpu graphs defined by the given executable using user - // provided run options. This guarantees that once we start execution, all Gpu - // graphs are ready, and will only require cheap update operation and will not - // require allocating new resources (we avoid non deterministic OOM errors). - // - // If timeout is not nullopt it will evict all previously instantiated graphs - // that were used more than `eviction_timeout_seconds` seconds ago. - absl::Status InstantiateAllGraphs( - const ServiceExecutableRunOptions* run_options, - const runtime::Executable& executable, - const runtime::CustomCall::UserData& user_data, - const BufferAllocations& buffer_allocations, - absl::Span buffer_sizes, - absl::Span> allocation_indices, - std::optional eviction_timeout_seconds = std::nullopt); - - // Returns true if all Gpu graphs were already instantiated. - bool InstantiatedAllGraphs(const ServiceExecutableRunOptions* run_options, - const runtime::Executable& executable); - - private: - std::shared_ptr impl_; -}; - -// Xla executable keeps a mapping from stream executors to execution counts. -class CapturedFunctionExecutionCounts { - public: - CapturedFunctionExecutionCount* operator()(se::StreamExecutor* executor); - - private: - mutable absl::Mutex mutex_; - absl::node_hash_map - counts_ ABSL_GUARDED_BY(mutex_); -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_GRAPH_LAUNCH_H_ diff --git a/third_party/xla/xla/service/gpu/runtime3/infeed_thunk.cc b/third_party/xla/xla/service/gpu/runtime/infeed_thunk.cc similarity index 89% rename from third_party/xla/xla/service/gpu/runtime3/infeed_thunk.cc rename to third_party/xla/xla/service/gpu/runtime/infeed_thunk.cc index 62bccdabab9abf..02dabdf25e41ce 100644 --- a/third_party/xla/xla/service/gpu/runtime3/infeed_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/infeed_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime3/infeed_thunk.h" +#include "xla/service/gpu/runtime/infeed_thunk.h" #include "absl/status/status.h" #include "xla/service/gpu/buffer_allocations.h" @@ -45,14 +45,16 @@ absl::Status InfeedThunk::ExecuteOnStream(const ExecuteParams& params) { se::ScopedDeviceMemory& buffer = source.second; const Shape& source_shape = ShapeUtil::GetSubshape(source_buffers.shape(), shape_index); - TF_RET_CHECK(ShapeUtil::Equal(dest_slices_[index].shape, source_shape)) + TF_RET_CHECK( + ShapeUtil::ReshapeIsBitcast(dest_slices_[index].shape, source_shape)) << "Mismatch between infeed source buffer shape " << ShapeUtil::HumanStringWithLayout(source_shape) << " and infeed dest buffer shape " << ShapeUtil::HumanStringWithLayout(dest_slices_[index].shape); se::DeviceMemoryBase dest_address = buffer_allocations.GetDeviceAddress(dest_slices_[index++].slice); - stream.ThenMemcpy(&dest_address, *buffer.ptr(), buffer.ptr()->size()); + TF_RETURN_IF_ERROR( + stream.Memcpy(&dest_address, *buffer.ptr(), buffer.ptr()->size())); } // Make sure that all dest slices have been copied into. @@ -62,7 +64,7 @@ absl::Status InfeedThunk::ExecuteOnStream(const ExecuteParams& params) { absl::Status block_status = stream.BlockHostUntilDone(); if (!block_status.ok()) { return Internal("Failed to complete data transfer on stream %p: %s", - &stream, block_status.message()); + &stream, block_status.message()); } VLOG(2) << "Infeeding to GPU complete"; diff --git a/third_party/xla/xla/service/gpu/runtime3/infeed_thunk.h b/third_party/xla/xla/service/gpu/runtime/infeed_thunk.h similarity index 90% rename from third_party/xla/xla/service/gpu/runtime3/infeed_thunk.h rename to third_party/xla/xla/service/gpu/runtime/infeed_thunk.h index 9d685aa2c19d5f..9286193045f3cc 100644 --- a/third_party/xla/xla/service/gpu/runtime3/infeed_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/infeed_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME3_INFEED_THUNK_H_ -#define XLA_SERVICE_GPU_RUNTIME3_INFEED_THUNK_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_INFEED_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_INFEED_THUNK_H_ #include "xla/service/gpu/thunk.h" @@ -42,4 +42,4 @@ class InfeedThunk : public Thunk { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_RUNTIME3_INFEED_THUNK_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_INFEED_THUNK_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/io_feed.cc b/third_party/xla/xla/service/gpu/runtime/io_feed.cc deleted file mode 100644 index 97e6a83cb669c3..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/io_feed.cc +++ /dev/null @@ -1,178 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/io_feed.h" - -#include -#include - -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/service/gpu/infeed_manager.h" -#include "xla/service/gpu/outfeed_manager.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/service_executable_run_options.h" - -namespace xla { -namespace gpu { - -using runtime::CustomCall; - -static absl::Status InfeedImpl(const ServiceExecutableRunOptions* run_options, - CustomCall::RemainingArgs args, - std::string_view config) { - VLOG(3) << "Infeeding to GPU"; - - se::Stream* stream = run_options->stream(); - ShapeTree> source_buffers = - GetOrCreateInfeedManager(stream->parent())->BlockingGetNextDestination(); - - // Check that we have correct number of arguments. - if (args.size() != source_buffers.leaf_count()) - return absl::InvalidArgumentError("Incorrect number of arguments"); - - size_t index = 0; - for (auto& source : source_buffers.leaves()) { - // Get the destination buffer. - auto dest = args.get(index); - if (failed(dest)) - return absl::InternalError("Failed to get the destination buffer"); - - // Get the source buffer shape. - const Shape& source_shape = - ShapeUtil::GetSubshape(source_buffers.shape(), source.first); - - // Check that destination shape matches the source shape. - Shape dest_shape = ToShape(*dest); - if (!ShapeUtil::ReshapeIsBitcast(dest_shape, source_shape)) { - return absl::InvalidArgumentError( - "The destination shape does not match the source shape"); - } - - se::DeviceMemoryBase dest_address = GetDeviceAddress(*dest); - se::ScopedDeviceMemory& buffer = source.second; - stream->ThenMemcpy(&dest_address, *buffer.ptr(), buffer.ptr()->size()); - - ++index; - } - - // TODO(ezhulenev): Make this function async? - TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); - - VLOG(3) << "Infeeding to GPU complete"; - - return absl::OkStatus(); -} - -static absl::Status OutfeedImpl(const ServiceExecutableRunOptions* run_options, - CustomCall::RemainingArgs args, - std::string_view config) { - VLOG(3) << "Outfeeding from GPU"; - - se::Stream* stream = run_options->stream(); - OutfeedManager* outfeed_manager = GetOrCreateOutfeedManager(stream->parent()); - ShapeTree>* dest_buffers = - outfeed_manager->BlockingGetNextDestination(); - - // Nothing to be done for an outfeed with no inputs. - // Note: Must do this after `BlockingGetNextDestination` above to dequeue an - // entry from the outfeed manager. - if (args.empty()) return absl::OkStatus(); - - // Check that we have correct number of arguments. - if (args.size() != dest_buffers->leaf_count()) - return absl::InvalidArgumentError("Incorrect number of arguments"); - - int64_t leaf_count = dest_buffers->leaf_count(); - auto dest_leaf_it = dest_buffers->leaf_begin(); - - for (int64_t index = 0; index < leaf_count; ++index) { - const ShapeIndex& shape_index = dest_leaf_it->first; - std::unique_ptr& buffer = dest_leaf_it->second; - - // NOTE: This code needs deal with the `dest_buffers` object getting - // deleted when it is executing. Specifically, objects in the outfeed queue - // are pointers to instances of stack-allocated objects in - // `GpuTransferManager::TransferLiteralFromOutfeed`. When all leaf node - // buffers are notified via "buffer->Done()" below in the stream host - // callback, `TransferLiteralFromOutfeed` deletes this stack-allocated - // object when it returns. This means that it is possible that during the - // last iteration, after the call to "buffer->Done()" is scheduled onto the - // stream, the `dest_buffers` object might get deleted, so we should avoid - // accessing the object after that. - // - // To achieve that, increment the leaf iterator here before the last "Done" - // is enqueued, instead of in the loop increment, which would be after the - // "Done" is scheduled. - ++dest_leaf_it; - - // Get the source buffer. - auto source = args.get(index); - if (failed(source)) - return absl::InternalError("Failed to get the source buffer"); - - // Get the source buffer shape. - const Shape& dest_shape = - ShapeUtil::GetSubshape(dest_buffers->shape(), shape_index); - - // Check that destination shape matches the source shape. - Shape source_shape = ToShape(*source); - if (!ShapeUtil::ReshapeIsBitcast(dest_shape, source_shape)) { - return absl::InvalidArgumentError( - "The destination shape does not match the source shape"); - } - - se::DeviceMemoryBase source_address = GetDeviceAddress(*source); - - // Schedule the memory transfer. - auto* dest_address = buffer->destination()->untyped_data(); - stream->ThenMemcpy(dest_address, source_address, buffer->length()) - .ThenDoHostCallback([&buffer]() { buffer->Done(); }); - } - - TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); - - VLOG(3) << "Outfeeding from GPU complete"; - - return absl::OkStatus(); -} - -//===----------------------------------------------------------------------===// -// Define Xla runtime bindings for the custom calls. -//===----------------------------------------------------------------------===// - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - Infeed, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.infeed") - .UserData() - .Arg() // args - .Attr("config")); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - Outfeed, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.outfeed") - .UserData() - .Arg() // args - .Attr("config")); - -//===----------------------------------------------------------------------===// - -void RegisterIoFeedCustomCalls(runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.infeed", Infeed); - registry.Register("xla.gpu.outfeed", Outfeed); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/io_feed.h b/third_party/xla/xla/service/gpu/runtime/io_feed.h deleted file mode 100644 index e3d22f694b52a9..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/io_feed.h +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_IO_FEED_H_ -#define XLA_SERVICE_GPU_RUNTIME_IO_FEED_H_ - -#include "xla/runtime/custom_call_registry.h" - -namespace xla { -namespace gpu { - -// Registers XLA Gpu runtime infeed and outfeed custom calls. -void RegisterIoFeedCustomCalls(runtime::DirectCustomCallRegistry& registry); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_IO_FEED_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/kernel_launch.cc b/third_party/xla/xla/service/gpu/runtime/kernel_launch.cc deleted file mode 100644 index 9287914e646ef3..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/kernel_launch.cc +++ /dev/null @@ -1,331 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/kernel_launch.h" - -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/inlined_vector.h" -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "absl/synchronization/mutex.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/runtime/executable.h" -#include "xla/runtime/memref_view.h" -#include "xla/runtime/state.h" -#include "xla/service/gpu/kernels/custom_kernel.h" -#include "xla/service/gpu/kernels/custom_kernel_fusion.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/runtime/concurrent_region.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/gpu/stream_executor_util.h" -#include "xla/service/hlo.pb.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/statusor.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/launch_dim.h" -#include "xla/stream_executor/stream.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/stream_executor/gpu/gpu_graph.h" -#endif // #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -namespace xla { -namespace gpu { - -using xla::runtime::CustomCall; -using xla::runtime::State; -using xla::runtime::StridedMemrefView; - -StreamExecutorKernels* GpuExecutableKernels::operator()( - se::StreamExecutor* executor) { - absl::MutexLock lock(&mutex_); - return &kernels_[executor]; -} - -//===----------------------------------------------------------------------===// -// Define the kernel launch custom call. -//===----------------------------------------------------------------------===// - -static absl::Status LaunchImpl( - const ServiceExecutableRunOptions* run_options, const std::string* ptx, - const std::vector* cubin, se::DeviceMemoryBase* temp_buffer, - ConcurrentRegionStatus* region_status, - State> device_kernel, - int32_t shared_memory_bytes, int32_t grid_size_x, int32_t grid_size_y, - int32_t grid_size_z, int32_t block_size_x, int32_t block_size_y, - int32_t block_size_z, CustomCall::RemainingArgs args, std::string_view name, - int64_t stream_id) { - se::Stream* stream = run_options->stream(); - se::StreamExecutor* executor = stream->parent(); - - LaunchDimensions launch_dimensions( - se::BlockDim(grid_size_x, grid_size_y, grid_size_z), - se::ThreadDim(block_size_x, block_size_y, block_size_z)); - - const int args_size_including_temp_buffer = args.size() + 1; - - // If kernel does not exist create it from the ptx and cubin. - TF_ASSIGN_OR_RETURN( - std::unique_ptr * kernel, device_kernel.GetOrCreate([&] { - return ToAbsl(CreateKernel(absl::string_view(name.data(), name.size()), - args_size_including_temp_buffer, *ptx, - *cubin, executor, shared_memory_bytes)); - })); - assert((*kernel)->name() == name && "unexpected loaded kernel"); - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - if (VLOG_IS_ON(3)) { - TF_ASSIGN_OR_RETURN(bool is_capturing, se::gpu::IsStreamCapturing(stream)); - if (is_capturing) { - if (region_status->IsInConcurrentRegion()) { - LOG(INFO) << "Launching " << (*kernel)->name() - << "in a concurrent region during GPU graph capture"; - } else { - LOG(INFO) << "Launching " << (*kernel)->name() - << "during GPU graph capture"; - } - } else { - LOG(INFO) << "Launching " << (*kernel)->name(); - } - } -#else - VLOG(3) << "Launching " << (*kernel)->name(); -#endif - - absl::InlinedVector buffer_args( - args_size_including_temp_buffer); - - // Add MemRef arguments as buffer arguments. - for (unsigned i = 0; i < args.size(); ++i) { - // We get arguments corresponding to XLA allocations required by the - // compiled device kernel, and not the actual memrefs that device kernel - // writes/reads, so we don't have to pass the size along with the pointer. - if (auto strided = args.get(i); succeeded(strided)) { - buffer_args[i] = se::DeviceMemoryBase(strided->data); - continue; - } - - return absl::InvalidArgumentError( - absl::StrFormat("Unsupported argument #%d type", i)); - } - - // Always add temporary buffer as the last kernel argument. - buffer_args.back() = *temp_buffer; - - // If we are capturing a concurrent region in a GPU graph, then use the - // stream provided by ConcurrentRegionStatus to execute the kernel. - se::Stream* execution_stream = stream; - if (stream_id != 0) { - DCHECK(region_status->IsInConcurrentRegion()); - TF_ASSIGN_OR_RETURN(execution_stream, region_status->GetStream(stream_id)); - } else if (region_status->IsInConcurrentRegion()) { - execution_stream = region_status->GetNextStream(); - } - - // Execute device kernel on the execution stream. - return ExecuteKernelOnStream(**kernel, buffer_args, launch_dimensions, - execution_stream); -} - -//===----------------------------------------------------------------------===// -// Define the custom kernel (fusion) launch custom call. -//===----------------------------------------------------------------------===// - -static absl::StatusOr> CreateCustomKernel( - se::StreamExecutor* executor, std::string_view name, - std::string_view custom_fusion_computation) { - auto* registry = CustomKernelFusionRegistry::Default(); - auto* custom_kernel_fusion = registry->Lookup(name); - - // If custom fusion is not found it means that some of the build targets might - // not be statically linked into the binary. - if (custom_kernel_fusion == nullptr) { - return absl::InternalError(absl::StrCat( - "Custom kernel fusion ", name, " not found in a default registry.")); - } - - // Parse attached custom fusion computation. - HloComputationProto computation_proto; - if (!computation_proto.ParseFromArray(custom_fusion_computation.data(), - custom_fusion_computation.size())) { - return absl::InternalError("Failed to parse custom fusion computation"); - } - - // Build HloComputation from a proto for passing to custom fusion. - absl::flat_hash_map computation_map; - TF_ASSIGN_OR_RETURN( - std::unique_ptr computation, - HloComputation::CreateFromProto(computation_proto, computation_map)); - - // Load custom kernels that can implement a fusion computation. - TF_ASSIGN_OR_RETURN(std::vector kernels, - custom_kernel_fusion->LoadKernels( - executor->GetDeviceDescription(), computation.get())); - - // This should never happen, it means that compilation pipeline created a - // fusion operation that is not supported by a given custom fusion. - if (kernels.empty()) { - return absl::InternalError( - absl::StrCat("Custom kernel fusion ", name, - " returned empty custom kernels for a fused computation")); - } - - auto kernel = std::make_unique(executor); - TF_RETURN_IF_ERROR( - executor->GetKernel(kernels[0].kernel_spec(), kernel.get())); - - return kernel; -} - -static absl::Status CustomLaunchImpl( - const ServiceExecutableRunOptions* run_options, const std::string* ptx, - const std::vector* cubin, se::DeviceMemoryBase* temp_buffer, - ConcurrentRegionStatus* region_status, - State> device_kernel, - int32_t shared_memory_bytes, int32_t grid_size_x, int32_t grid_size_y, - int32_t grid_size_z, int32_t block_size_x, int32_t block_size_y, - int32_t block_size_z, CustomCall::RemainingArgs args, std::string_view name, - int64_t stream_id, std::string_view custom_fusion_computation) { - se::Stream* stream = run_options->stream(); - se::StreamExecutor* executor = stream->parent(); - - LaunchDimensions launch_dimensions( - se::BlockDim(grid_size_x, grid_size_y, grid_size_z), - se::ThreadDim(block_size_x, block_size_y, block_size_z)); - - // If kernel does not exist load it from a custom fusion computation. - TF_ASSIGN_OR_RETURN( - std::unique_ptr * kernel, device_kernel.GetOrCreate([&] { - return ToAbsl( - CreateCustomKernel(executor, name, custom_fusion_computation)); - })); - assert((*kernel)->name() == name && "unexpected loaded kernel"); - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - if (VLOG_IS_ON(3)) { - TF_ASSIGN_OR_RETURN(bool is_capturing, se::gpu::IsStreamCapturing(stream)); - if (is_capturing) { - if (region_status->IsInConcurrentRegion()) { - LOG(INFO) << "Launching " << (*kernel)->name() - << "in a concurrent region during GPU graph capture"; - } else { - LOG(INFO) << "Launching " << (*kernel)->name() - << "during GPU graph capture"; - } - } else { - LOG(INFO) << "Launching " << (*kernel)->name(); - } - } -#else - VLOG(3) << "Launching " << (*kernel)->name(); -#endif - - absl::InlinedVector buffer_args(args.size()); - - // Add MemRef arguments as buffer arguments. - for (unsigned i = 0; i < args.size(); ++i) { - // We get arguments corresponding to XLA allocations required by the - // compiled device kernel, and not the actual memrefs that device kernel - // writes/reads, so we don't have to pass the size along with the pointer. - if (auto strided = args.get(i); succeeded(strided)) { - buffer_args[i] = se::DeviceMemoryBase(strided->data); - continue; - } - - return absl::InvalidArgumentError( - absl::StrFormat("Unsupported argument #%d type", i)); - } - - // If we are capturing a concurrent region in a GPU graph, then use the - // stream provided by ConcurrentRegionStatus to execute the kernel. - se::Stream* execution_stream = stream; - if (stream_id != 0) { - DCHECK(region_status->IsInConcurrentRegion()); - TF_ASSIGN_OR_RETURN(execution_stream, region_status->GetStream(stream_id)); - } else if (region_status->IsInConcurrentRegion()) { - execution_stream = region_status->GetNextStream(); - } - - se::KernelArgsDeviceMemoryArray kernel_args(buffer_args, shared_memory_bytes); - return executor->Launch( - stream, se::ThreadDim(block_size_x, block_size_y, block_size_z), - se::BlockDim(grid_size_x, grid_size_y, grid_size_z), **kernel, - kernel_args); -} - -//===----------------------------------------------------------------------===// - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - Launch, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.func.launch") - .UserData() - .UserData() - .UserData*>() - .UserData() - .UserData() - .State>("uid") - .Arg() // shared_memory_bytes - .Arg() // grid_size_x - .Arg() // grid_size_y - .Arg() // grid_size_z - .Arg() // block_size_x - .Arg() // block_size_y - .Arg() // block_size_x - .RemainingArgs() // args - .Attr("kernel") - .Attr("stream")); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - CustomLaunch, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.func.custom_launch") - .UserData() - .UserData() - .UserData*>() - .UserData() - .UserData() - .State>("uid") - .Arg() // shared_memory_bytes - .Arg() // grid_size_x - .Arg() // grid_size_y - .Arg() // grid_size_z - .Arg() // block_size_x - .Arg() // block_size_y - .Arg() // block_size_x - .RemainingArgs() // args - .Attr("kernel") - .Attr("stream") - .Attr("__custom_fusion_computation")); - -void RegisterKernelLaunchCustomCalls( - runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.func.launch", Launch); - registry.Register("xla.gpu.func.custom_launch", CustomLaunch); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/kernel_launch.h b/third_party/xla/xla/service/gpu/runtime/kernel_launch.h deleted file mode 100644 index 9bec6c71dcae0e..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/kernel_launch.h +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_KERNEL_LAUNCH_H_ -#define XLA_SERVICE_GPU_RUNTIME_KERNEL_LAUNCH_H_ - -#include -#include -#include - -#include "absl/container/node_hash_map.h" -#include "absl/synchronization/mutex.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/runtime/state.h" -#include "xla/stream_executor/stream_executor.h" - -namespace xla { -namespace gpu { - -// Registers XLA Gpu runtime kernel launch custom calls. -void RegisterKernelLaunchCustomCalls( - runtime::DirectCustomCallRegistry& registry); - -// Kernels loaded by Gpu executable for a single stream executor. -class StreamExecutorKernels - : public runtime::StateVector> {}; - -// Xla runtime Gpu executable owns the pre-compiled device module (PTX and -// Cubin for Nvidia Gpus) for all device kernels, and the cache keeps a mapping -// from stream executor to pre-loaded kernels -class GpuExecutableKernels { - public: - StreamExecutorKernels* operator()(se::StreamExecutor* executor); - - private: - mutable absl::Mutex mutex_; - absl::node_hash_map kernels_ - ABSL_GUARDED_BY(mutex_); -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_KERNEL_LAUNCH_H_ diff --git a/third_party/xla/xla/service/gpu/runtime3/kernel_thunk.cc b/third_party/xla/xla/service/gpu/runtime/kernel_thunk.cc similarity index 76% rename from third_party/xla/xla/service/gpu/runtime3/kernel_thunk.cc rename to third_party/xla/xla/service/gpu/runtime/kernel_thunk.cc index ef741d71d8b104..45d02be5e54bf8 100644 --- a/third_party/xla/xla/service/gpu/runtime3/kernel_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/kernel_thunk.cc @@ -13,16 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime3/kernel_thunk.h" +#include "xla/service/gpu/runtime/kernel_thunk.h" #include #include +#include #include #include #include #include #include "absl/container/inlined_vector.h" +#include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" @@ -36,6 +38,7 @@ limitations under the License. #include "xla/status.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/stream_executor.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" @@ -60,17 +63,15 @@ mlir::Value RemoveTransformingOperations(mlir::Value value) { } // namespace -KernelThunk::KernelThunk( - std::variant op, - std::string kernel_name, absl::Span kernel_arguments, - LaunchDimensions launch_dimensions, int64_t shmem_bytes) - : Thunk(Kind::kKernel, std::holds_alternative(op) - ? Thunk::ThunkInfo::WithProfileAnnotation( - std::get(op)) - : Thunk::ThunkInfo::WithProfileAnnotation( - std::get(op))), +KernelThunk::KernelThunk(const HloInstruction* instr, std::string kernel_name, + absl::Span kernel_arguments, + LaunchDimensions launch_dimensions, + std::optional cluster_dim, + int64_t shmem_bytes) + : Thunk(Kind::kKernel, Thunk::ThunkInfo::WithProfileAnnotation(instr)), kernel_name_(std::move(kernel_name)), launch_dimensions_(std::move(launch_dimensions)), + cluster_dim_(std::move(cluster_dim)), shmem_bytes_(shmem_bytes) { args_.reserve(kernel_arguments.size()); written_.reserve(kernel_arguments.size()); @@ -80,23 +81,13 @@ KernelThunk::KernelThunk( written_.push_back(kernel_argument.written()); } } - - if (std::holds_alternative(op)) { - // Skip populating MLIR values_ if emitting from HLO. - return; - } - - values_.reserve(kernel_arguments.size()); - for (const auto& kernel_argument : kernel_arguments) { - if (!kernel_argument.first_with_same_slice().has_value()) { - values_.push_back(RemoveTransformingOperations(kernel_argument.value())); - } - } } std::string KernelThunk::ToStringExtra(int indent) const { - return absl::StrFormat(", kernel = %s, launch dimensions = %s", kernel_name_, - launch_dimensions_.ToString()); + return absl::StrFormat( + ", kernel = %s, launch dimensions = %s, cluster_dim = %s", kernel_name_, + launch_dimensions_.ToString(), + cluster_dim_.has_value() ? cluster_dim_->ToString() : "nullopt"); } absl::Status KernelThunk::Initialize(const InitializeParams& params) { @@ -125,7 +116,7 @@ static void PrintBufferContents( int input_idx = 0; for (const se::DeviceMemoryBase& buf : buffer_args) { auto host_buffer = std::make_unique(buf.size()); - CHECK(stream->ThenMemcpy(host_buffer.get(), buf, buf.size()).ok()); + CHECK(stream->Memcpy(host_buffer.get(), buf, buf.size()).ok()); CHECK_OK(stream->BlockHostUntilDone()); std::string buffer_contents; @@ -141,6 +132,7 @@ absl::Status KernelThunk::ExecuteOnStream(const ExecuteParams& params) { // Load the kernel. se::StreamExecutor* executor = params.stream->parent(); LaunchDimensions launch_dimensions; + std::optional cluster_dim; const se::Kernel* kernel = nullptr; { @@ -149,6 +141,7 @@ absl::Status KernelThunk::ExecuteOnStream(const ExecuteParams& params) { CHECK(it != kernel_cache_.end()) << "Initialize() not called for StreamExecutor " << executor; launch_dimensions = launch_dimensions_; + cluster_dim = cluster_dim_; kernel = it->second.get(); } @@ -165,8 +158,13 @@ absl::Status KernelThunk::ExecuteOnStream(const ExecuteParams& params) { PrintBufferContents(params.stream, buffer_args); } - return ExecuteKernelOnStream(*kernel, buffer_args, launch_dimensions, - params.stream); + if (cluster_dim.has_value()) { + return ExecuteKernelOnStream(*kernel, buffer_args, launch_dimensions, + cluster_dim.value(), params.stream); + } else { + return ExecuteKernelOnStream(*kernel, buffer_args, launch_dimensions, + params.stream); + } } //===----------------------------------------------------------------------===// @@ -174,15 +172,10 @@ absl::Status KernelThunk::ExecuteOnStream(const ExecuteParams& params) { //===----------------------------------------------------------------------===// CustomKernelThunk::CustomKernelThunk( - std::variant instr, - CustomKernel custom_kernel, + const HloInstruction* instr, CustomKernel custom_kernel, absl::Span kernel_arguments) : Thunk(Kind::kCustomKernel, - std::holds_alternative(instr) - ? Thunk::ThunkInfo::WithProfileAnnotation( - std::get(instr)) - : Thunk::ThunkInfo::WithProfileAnnotation( - std::get(instr))), + Thunk::ThunkInfo::WithProfileAnnotation(instr)), custom_kernel_(std::move(custom_kernel)) { args_.reserve(kernel_arguments.size()); written_.reserve(kernel_arguments.size()); @@ -192,18 +185,6 @@ CustomKernelThunk::CustomKernelThunk( written_.push_back(kernel_argument.written()); } } - - if (std::holds_alternative(instr)) { - // Skip populating MLIR values_ if emitting from HLO. - return; - } - - values_.reserve(kernel_arguments.size()); - for (const auto& kernel_argument : kernel_arguments) { - if (!kernel_argument.first_with_same_slice().has_value()) { - values_.push_back(RemoveTransformingOperations(kernel_argument.value())); - } - } } std::string CustomKernelThunk::ToStringExtra(int indent) const { @@ -215,9 +196,9 @@ absl::Status CustomKernelThunk::Initialize(const InitializeParams& params) { auto it = kernel_cache_.find(params.executor); if (kernel_cache_.end() == it) { - auto kernel = std::make_unique(params.executor); - TF_RETURN_IF_ERROR( - params.executor->GetKernel(custom_kernel_.kernel_spec(), kernel.get())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr kernel, + se::Kernel::Create(params.executor, custom_kernel_.kernel_spec())); kernel_cache_.emplace(params.executor, std::move(kernel)); } diff --git a/third_party/xla/xla/service/gpu/runtime3/kernel_thunk.h b/third_party/xla/xla/service/gpu/runtime/kernel_thunk.h similarity index 82% rename from third_party/xla/xla/service/gpu/runtime3/kernel_thunk.h rename to third_party/xla/xla/service/gpu/runtime/kernel_thunk.h index 6e4a45a4ffa765..c88ce5366c8c93 100644 --- a/third_party/xla/xla/service/gpu/runtime3/kernel_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/kernel_thunk.h @@ -13,11 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME3_KERNEL_THUNK_H_ -#define XLA_SERVICE_GPU_RUNTIME3_KERNEL_THUNK_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_KERNEL_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_KERNEL_THUNK_H_ #include #include +#include #include #include #include @@ -27,8 +28,6 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/kernel_arguments.h" @@ -36,6 +35,7 @@ limitations under the License. #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/thunk.h" #include "xla/status.h" +#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/stream_executor.h" #include "xla/types.h" // IWYU pragma: keep @@ -69,10 +69,10 @@ class KernelThunk : public Thunk { // output of the computation. Also, the values must correspond to each arg // directly, not to their base allocation (e.g. they can be the result of an // `mlir::memref::ViewOp`). - KernelThunk(std::variant op, - std::string kernel_name, + KernelThunk(const HloInstruction* instr, std::string kernel_name, absl::Span kernel_arguments, - LaunchDimensions launch_dimensions, int64_t shmem_bytes); + LaunchDimensions launch_dimensions, + std::optional cluster_dim, int64_t shmem_bytes); KernelThunk(const KernelThunk&) = delete; KernelThunk& operator=(const KernelThunk&) = delete; ~KernelThunk() override = default; @@ -82,12 +82,7 @@ class KernelThunk : public Thunk { absl::Status Initialize(const InitializeParams& params) override; absl::Status ExecuteOnStream(const ExecuteParams& params) override; - void ClearCompileTimeInfo() override { - Thunk::ClearCompileTimeInfo(); - for (auto& value : values_) { - value = nullptr; - } - } + void ClearCompileTimeInfo() override { Thunk::ClearCompileTimeInfo(); } const std::vector& arguments() const { return args_; @@ -100,7 +95,6 @@ class KernelThunk : public Thunk { } // The shared memory required by the kernel. int64_t shmem_bytes() const { return shmem_bytes_; } - absl::Span values() const { return values_; } private: // Buffer slices passed to the kernel as arguments. @@ -115,10 +109,10 @@ class KernelThunk : public Thunk { // The thread and block dimension used to launch the kernel. const LaunchDimensions launch_dimensions_; - int64_t shmem_bytes_; + // The cluster dimensions used to launch the kernel. + const std::optional cluster_dim_; - // mlir::Value(s) corresponding to the buffer slice arguments. - std::vector values_; + int64_t shmem_bytes_; // Loaded kernels for each `StreamExecutor`. mutable absl::Mutex mutex_; @@ -135,8 +129,7 @@ class KernelThunk : public Thunk { // compiled by XLA and loaded from an executable source. class CustomKernelThunk : public Thunk { public: - CustomKernelThunk(std::variant inst, - CustomKernel custom_kernel, + CustomKernelThunk(const HloInstruction* inst, CustomKernel custom_kernel, absl::Span kernel_arguments); std::string ToStringExtra(int indent) const override; @@ -149,13 +142,10 @@ class CustomKernelThunk : public Thunk { const std::vector& arguments() const { return args_; } - // TODO(ezhulenev): All of the APIs below needed only for LMHLO lowering and - // should be removed after we migrate to Thunks runtime. std::string_view custom_kernel_name() const { return custom_kernel_.name(); } const std::vector& written() const { return written_; } - absl::Span values() const { return values_; } LaunchDimensions launch_dimensions() const { return LaunchDimensions(custom_kernel_.block_dims(), @@ -171,9 +161,6 @@ class CustomKernelThunk : public Thunk { // args_[i] is written iff (written_[i] == true). std::vector written_; - // mlir::Value(s) corresponding to the buffer slice arguments. - std::vector values_; - CustomKernel custom_kernel_; // Loaded kernels for each `StreamExecutor`. @@ -185,4 +172,4 @@ class CustomKernelThunk : public Thunk { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_RUNTIME3_KERNEL_THUNK_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_KERNEL_THUNK_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/memcpy.cc b/third_party/xla/xla/service/gpu/runtime/memcpy.cc deleted file mode 100644 index 9fd5888433f0bc..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/memcpy.cc +++ /dev/null @@ -1,98 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/memcpy.h" - -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/service/gpu/runtime/concurrent_region.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/service_executable_run_options.h" - -namespace xla { -namespace gpu { - -using xla::runtime::CustomCall; -using xla::runtime::StridedMemrefView; - -enum class MemcpyDirection { kD2D, kD2H, kH2D }; - -template -absl::Status MemcpyImpl(const ServiceExecutableRunOptions* run_options, - ConcurrentRegionStatus* region_status, - runtime::StridedMemrefView dst, - runtime::StridedMemrefView src, int64_t stream_id) { - se::Stream* stream = run_options->stream(); - if (stream_id != 0) { - DCHECK(region_status->IsInConcurrentRegion()); - TF_ASSIGN_OR_RETURN(stream, region_status->GetStream(stream_id)); - } else if (region_status->IsInConcurrentRegion()) { - stream = region_status->GetNextStream(); - } - - if (dst.sizes != src.sizes) { - return absl::InvalidArgumentError( - "Source memref sizes do not match destination memref sizes"); - } - - if (dst.strides != src.strides) { - return absl::InvalidArgumentError( - "Source memref strides do not match destination memref strides"); - } - - switch (direction) { - case MemcpyDirection::kD2D: { - se::DeviceMemoryBase dst_data = GetDeviceAddress(dst); - se::DeviceMemoryBase src_data = GetDeviceAddress(src); - stream->ThenMemcpy(&dst_data, src_data, src_data.size()); - } break; - case MemcpyDirection::kD2H: { - se::DeviceMemoryBase src_data = GetDeviceAddress(src); - stream->ThenMemcpy(dst.data, src_data, src_data.size()); - } break; - case MemcpyDirection::kH2D: { - se::DeviceMemoryBase dst_data = GetDeviceAddress(dst); - stream->ThenMemcpy(&dst_data, src.data, dst_data.size()); - } break; - } - - // TODO(jacksonstokes): H2D and D2H memcpy instead of blocking the execution - // thread should return an async token that will become available when - // transfer is completed. - if (direction != MemcpyDirection::kD2D) { - TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); - } - - return absl::OkStatus(); -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL_TEMPLATE( - MemcpyDirection direction, Memcpy, FunctionWrapper>(), - checks, - CustomCall::Bind("xla.gpu.memcpy") - .UserData() - .UserData() - .Arg() // dst - .Arg() // src - .Attr("stream")); - -void RegisterMemcpyCustomCalls(runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.memcpy.d2d", Memcpy); - registry.Register("xla.gpu.memcpy.h2d", Memcpy); - registry.Register("xla.gpu.memcpy.d2h", Memcpy); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/memcpy.h b/third_party/xla/xla/service/gpu/runtime/memcpy.h deleted file mode 100644 index 6fe7ea05155644..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/memcpy.h +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_MEMCPY_H_ -#define XLA_SERVICE_GPU_RUNTIME_MEMCPY_H_ - -#include "xla/runtime/custom_call_registry.h" - -namespace xla { -namespace gpu { - -// Registers XLA Gpu runtime memcpy custom calls. -void RegisterMemcpyCustomCalls(runtime::DirectCustomCallRegistry& registry); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_MEMCPY_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/memset.cc b/third_party/xla/xla/service/gpu/runtime/memset.cc deleted file mode 100644 index b22834a1642375..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/memset.cc +++ /dev/null @@ -1,147 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/memset.h" - -#include "absl/base/casts.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/service_executable_run_options.h" - -namespace xla { -namespace gpu { - -using xla::runtime::CustomCall; -using xla::runtime::StridedMemrefView; - -// Checks all supported data types to see if the value is zero. -static bool IsZero(CustomCall::VariantArg constant) { - if (auto i1 = constant.get(); succeeded(i1)) - return *i1 == false; - else if (auto i8 = constant.get(); succeeded(i8)) - return *i8 == 0; - else if (auto i16 = constant.get(); succeeded(i16)) - return *i16 == 0; - else if (auto i32 = constant.get(); succeeded(i32)) - return *i32 == 0; - else if (auto i64 = constant.get(); succeeded(i64)) - return *i64 == 0; - else if (auto bf16 = constant.get(); succeeded(bf16)) - return *bf16 == bfloat16(0.0); - else if (auto f16 = constant.get(); succeeded(f16)) - return *f16 == half(0.0); - else if (auto f32 = constant.get(); succeeded(f32)) - return *f32 == 0.0; - else if (auto f64 = constant.get(); succeeded(f64)) - return *f64 == 0.0; - - return false; -} - -// Convert constant value to 32-bit pattern. -static absl::StatusOr ToBitPattern(CustomCall::VariantArg constant) { - // If the value is 8 or 16 bits wide, we can emit a 32-bit memset by - // repeating the value 4 or 2 times, so long as the destination buffer is - // an even multiple of 32 bits long. - // - // This code is identical to `ir_emitter_unnested`. - // - // We use `memcpy` operation to copy bytes between value and the uint32_t bit - // pattern because in theory they might have incompatible alignment, and we - // rely on LLVM to optimize it. - auto extend = [](auto value) -> uint32_t { - static constexpr size_t num_bytes = sizeof(value); - static_assert(num_bytes < 4); - - uint16_t pattern16; - if constexpr (num_bytes == 1) { - uint8_t b = value; - pattern16 = uint16_t{b} | (uint16_t{b} << 8); - } else { - memcpy(&pattern16, &value, sizeof(pattern16)); - } - return uint32_t{pattern16} | (uint32_t{pattern16} << 16); - }; - - // Truncate value to 32-bit pattern. - auto truncate = [](auto value) -> uint32_t { - static_assert(sizeof(value) >= 4); - - uint32_t pattern; - memcpy(&pattern, &value, sizeof(pattern)); - return pattern; - }; - - if (auto i1 = constant.get(); succeeded(i1)) - return extend(*i1); - else if (auto i8 = constant.get(); succeeded(i8)) - return extend(*i8); - else if (auto i16 = constant.get(); succeeded(i16)) - return extend(*i16); - else if (auto i32 = constant.get(); succeeded(i32)) - return truncate(*i32); - else if (auto i64 = constant.get(); succeeded(i64)) - return truncate(*i64); - else if (auto bf16 = constant.get(); succeeded(bf16)) - return extend(absl::bit_cast(*bf16)); - else if (auto f16 = constant.get(); succeeded(f16)) - return extend(absl::bit_cast(*f16)); - else if (auto f32 = constant.get(); succeeded(f32)) - return truncate(*f32); - else if (auto f64 = constant.get(); succeeded(f64)) - return truncate(*f64); - - return absl::InvalidArgumentError("Unsupported memset constant type"); -} - -static absl::Status MemsetImpl(const ServiceExecutableRunOptions* run_options, - StridedMemrefView dst, - CustomCall::VariantArg constant) { - se::Stream* stream = run_options->stream(); - se::DeviceMemoryBase dst_data = GetDeviceAddress(dst); - - // If the constant is zero we can use memzero directly. - if (IsZero(constant)) { - stream->ThenMemZero(&dst_data, dst_data.size()); - return absl::OkStatus(); - } - - // If the constant is not zero, use the given pattern to `memset`. - TF_ASSIGN_OR_RETURN(uint32_t pattern, ToBitPattern(constant)); - - if (dst_data.size() % 4 != 0) { - return absl::InvalidArgumentError("Memref size is not divisible by 4"); - } - - stream->ThenMemset32(&dst_data, pattern, dst_data.size()); - - return absl::OkStatus(); -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - Memset, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.memset") - .UserData() - .Arg() // dst - .Arg() // constant -); - -void RegisterMemsetCustomCalls(runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.memset", Memset); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime3/memset_thunk.cc b/third_party/xla/xla/service/gpu/runtime/memset_thunk.cc similarity index 83% rename from third_party/xla/xla/service/gpu/runtime3/memset_thunk.cc rename to third_party/xla/xla/service/gpu/runtime/memset_thunk.cc index 98c79a5fe4cb2e..573db38dd877e9 100644 --- a/third_party/xla/xla/service/gpu/runtime3/memset_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/memset_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime3/memset_thunk.h" +#include "xla/service/gpu/runtime/memset_thunk.h" #include "absl/status/status.h" #include "xla/stream_executor/stream_executor.h" @@ -24,16 +24,14 @@ namespace gpu { absl::Status MemzeroThunk::ExecuteOnStream(const ExecuteParams& params) { se::DeviceMemoryBase dest_data = params.buffer_allocations->GetDeviceAddress(dest_); - params.stream->ThenMemZero(&dest_data, dest_data.size()); - return absl::OkStatus(); + return params.stream->MemZero(&dest_data, dest_data.size()); } absl::Status Memset32BitValueThunk::ExecuteOnStream( const ExecuteParams& params) { se::DeviceMemoryBase dest_data = params.buffer_allocations->GetDeviceAddress(dest_); - params.stream->ThenMemset32(&dest_data, value_, dest_data.size()); - return absl::OkStatus(); + return params.stream->Memset32(&dest_data, value_, dest_data.size()); } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/runtime3/memset_thunk.h b/third_party/xla/xla/service/gpu/runtime/memset_thunk.h similarity index 69% rename from third_party/xla/xla/service/gpu/runtime3/memset_thunk.h rename to third_party/xla/xla/service/gpu/runtime/memset_thunk.h index b4f3886a424807..0c1646302eec71 100644 --- a/third_party/xla/xla/service/gpu/runtime3/memset_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/memset_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME3_MEMSET_THUNK_H_ -#define XLA_SERVICE_GPU_RUNTIME3_MEMSET_THUNK_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_MEMSET_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_MEMSET_THUNK_H_ #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/thunk.h" @@ -30,25 +30,17 @@ namespace gpu { class MemzeroThunk : public Thunk { public: explicit MemzeroThunk(ThunkInfo thunk_info, - const BufferAllocation::Slice& dest, - mlir::Value dest_value) - : Thunk(Kind::kMemzero, thunk_info), - dest_(dest), - dest_value_(dest_value) {} + const BufferAllocation::Slice& dest) + : Thunk(Kind::kMemzero, thunk_info), dest_(dest) {} absl::Status ExecuteOnStream(const ExecuteParams& params) override; - void ClearCompileTimeInfo() override { - Thunk::ClearCompileTimeInfo(); - dest_value_ = nullptr; - } + void ClearCompileTimeInfo() override { Thunk::ClearCompileTimeInfo(); } const BufferAllocation::Slice& destination() const { return dest_; } - mlir::Value dest_value() const { return dest_value_; } private: const BufferAllocation::Slice dest_; - mlir::Value dest_value_; }; // Thunk that sets a given chunk of memory to a particular 32-bit value. The @@ -56,31 +48,24 @@ class MemzeroThunk : public Thunk { class Memset32BitValueThunk : public Thunk { public: explicit Memset32BitValueThunk(ThunkInfo thunk_info, uint32_t value, - const BufferAllocation::Slice& dest, - mlir::Value dest_value) + const BufferAllocation::Slice& dest) : Thunk(Kind::kMemset32BitValue, thunk_info), value_(value), - dest_(dest), - dest_value_(dest_value) {} + dest_(dest) {} absl::Status ExecuteOnStream(const ExecuteParams& params) override; - void ClearCompileTimeInfo() override { - Thunk::ClearCompileTimeInfo(); - dest_value_ = nullptr; - } + void ClearCompileTimeInfo() override { Thunk::ClearCompileTimeInfo(); } const BufferAllocation::Slice& destination() const { return dest_; } uint32_t value() const { return value_; } - mlir::Value dest_value() const { return dest_value_; } private: const uint32_t value_; const BufferAllocation::Slice dest_; - mlir::Value dest_value_; }; } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_RUNTIME3_MEMSET_THUNK_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_MEMSET_THUNK_H_ diff --git a/third_party/xla/xla/service/gpu/nccl_all_gather_thunk.cc b/third_party/xla/xla/service/gpu/runtime/nccl_all_gather_thunk.cc similarity index 93% rename from third_party/xla/xla/service/gpu/nccl_all_gather_thunk.cc rename to third_party/xla/xla/service/gpu/runtime/nccl_all_gather_thunk.cc index d71a7fed092df9..49a83b344923cc 100644 --- a/third_party/xla/xla/service/gpu/nccl_all_gather_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_all_gather_thunk.cc @@ -13,10 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/nccl_all_gather_thunk.h" +#include "xla/service/gpu/runtime/nccl_all_gather_thunk.h" #include -#include #include #include "absl/status/status.h" @@ -25,7 +24,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/nccl_api.h" #include "xla/service/gpu/nccl_collective_thunk.h" #include "xla/service/gpu/thunk.h" @@ -64,7 +62,7 @@ absl::Status CheckImplementableInst(const HloAllGatherInstruction* inst) { if (!ShapeUtil::IsEffectivelyMostMajorDimension( shape, inst->all_gather_dimension())) { - return tsl::errors::Unimplemented(absl::StrFormat( + return absl::AbortedError(absl::StrFormat( "all-gather dim %u is not the most major in input shape %s", inst->all_gather_dimension(), shape.ToString(/*print_layout=*/true))); } @@ -79,7 +77,7 @@ absl::Status CheckImplementable(AllGatherStartOp op) { Shape shape = GetShape(operand); if (!ShapeUtil::IsEffectivelyMostMajorDimension( shape, op.getAllGatherDimension())) { - return tsl::errors::Unimplemented(absl::StrFormat( + return absl::AbortedError(absl::StrFormat( "all-gather dim %u is not the most major in input shape %s", op.getAllGatherDimension(), shape.ToString(/*print_layout=*/true))); } @@ -102,9 +100,7 @@ NcclAllGatherStartThunk::NcclAllGatherStartThunk( ThunkInfo thunk_info, NcclApi* nccl_api, const HloAllGatherInstruction* inst, std::vector buffers) : NcclCollectiveThunk(Thunk::kNcclAllGatherStart, thunk_info, nccl_api, - inst->backend_config() - ->collective_backend_config() - .is_sync()), + IsSyncCollective(inst)), config_(impl::GetNcclAllGatherConfig(inst)), buffers_(std::move(buffers)) { CHECK_EQ(config_.config.operand_count, buffers_.size()); diff --git a/third_party/xla/xla/service/gpu/nccl_all_gather_thunk.h b/third_party/xla/xla/service/gpu/runtime/nccl_all_gather_thunk.h similarity index 94% rename from third_party/xla/xla/service/gpu/nccl_all_gather_thunk.h rename to third_party/xla/xla/service/gpu/runtime/nccl_all_gather_thunk.h index a35355c17dc77d..cc62464ec501b1 100644 --- a/third_party/xla/xla/service/gpu/nccl_all_gather_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/nccl_all_gather_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_NCCL_ALL_GATHER_THUNK_H_ -#define XLA_SERVICE_GPU_NCCL_ALL_GATHER_THUNK_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_NCCL_ALL_GATHER_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_NCCL_ALL_GATHER_THUNK_H_ #include #include @@ -81,4 +81,4 @@ absl::Status RunAllGather(NcclApi* nccl_api, } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_NCCL_ALL_GATHER_THUNK_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_NCCL_ALL_GATHER_THUNK_H_ diff --git a/third_party/xla/xla/service/gpu/nccl_all_reduce_thunk.cc b/third_party/xla/xla/service/gpu/runtime/nccl_all_reduce_thunk.cc similarity index 97% rename from third_party/xla/xla/service/gpu/nccl_all_reduce_thunk.cc rename to third_party/xla/xla/service/gpu/runtime/nccl_all_reduce_thunk.cc index 50565cf43f851a..a2644e38c79574 100644 --- a/third_party/xla/xla/service/gpu/nccl_all_reduce_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_all_reduce_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/nccl_all_reduce_thunk.h" +#include "xla/service/gpu/runtime/nccl_all_reduce_thunk.h" #include #include @@ -126,7 +126,7 @@ absl::Status CheckImplementableInst(const HloInstruction* inst, if (!MatchReductionComputation(inst->called_computations().front()) .has_value()) { - return tsl::errors::Unimplemented("Unrecognized reduction computation"); + return absl::UnimplementedError("Unrecognized reduction computation"); } return absl::OkStatus(); @@ -140,7 +140,7 @@ absl::Status CheckImplementable(OpT op, Thunk::Kind reduction_op) { if (!NcclAllReduceReduceScatterThunkBase::MatchAllReduceComputation( op.getComputation()) .has_value()) { - return tsl::errors::Unimplemented("Unrecognized reduction computation"); + return absl::UnimplementedError("Unrecognized reduction computation"); } return absl::OkStatus(); } @@ -251,9 +251,7 @@ NcclAllReduceStartThunk::NcclAllReduceStartThunk( : NcclAllReduceReduceScatterThunkBase( Thunk::kNcclAllReduceStart, thunk_info, nccl_api, impl::GetNcclAllReduceConfigInst(inst), std::move(buffers), - inst->backend_config() - ->collective_backend_config() - .is_sync()) {} + IsSyncCollective(inst)) {} absl::Status NcclAllReduceStartThunk::CheckImplementable( AllReduceStartOp op, int64_t replica_count, int64_t partition_count) { diff --git a/third_party/xla/xla/service/gpu/nccl_all_reduce_thunk.h b/third_party/xla/xla/service/gpu/runtime/nccl_all_reduce_thunk.h similarity index 96% rename from third_party/xla/xla/service/gpu/nccl_all_reduce_thunk.h rename to third_party/xla/xla/service/gpu/runtime/nccl_all_reduce_thunk.h index 79750940fed80e..6fe1b8788a7240 100644 --- a/third_party/xla/xla/service/gpu/nccl_all_reduce_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/nccl_all_reduce_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_ -#define XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_NCCL_ALL_REDUCE_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_NCCL_ALL_REDUCE_THUNK_H_ #include #include @@ -144,4 +144,4 @@ absl::Status RunReduceScatter(NcclApi* nccl_api, ReductionKind reduction_kind, } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_NCCL_ALL_REDUCE_THUNK_H_ diff --git a/third_party/xla/xla/service/gpu/nccl_all_to_all_thunk.cc b/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc similarity index 99% rename from third_party/xla/xla/service/gpu/nccl_all_to_all_thunk.cc rename to third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc index 1b11d0b86b4373..e83db700456b95 100644 --- a/third_party/xla/xla/service/gpu/nccl_all_to_all_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/nccl_all_to_all_thunk.h" +#include "xla/service/gpu/runtime/nccl_all_to_all_thunk.h" #include #include diff --git a/third_party/xla/xla/service/gpu/nccl_all_to_all_thunk.h b/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.h similarity index 94% rename from third_party/xla/xla/service/gpu/nccl_all_to_all_thunk.h rename to third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.h index 83337339814083..ed19d638e3455e 100644 --- a/third_party/xla/xla/service/gpu/nccl_all_to_all_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_NCCL_ALL_TO_ALL_THUNK_H_ -#define XLA_SERVICE_GPU_NCCL_ALL_TO_ALL_THUNK_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_NCCL_ALL_TO_ALL_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_NCCL_ALL_TO_ALL_THUNK_H_ #include @@ -74,4 +74,4 @@ absl::Status RunAllToAll(NcclApi* nccl_api, bool has_split_dimension, } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_NCCL_ALL_TO_ALL_THUNK_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_NCCL_ALL_TO_ALL_THUNK_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/norm.cc b/third_party/xla/xla/service/gpu/runtime/norm.cc deleted file mode 100644 index 5a7647521029bf..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/norm.cc +++ /dev/null @@ -1,231 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/norm.h" - -#include -#include -#include -#include -#include - -#include "llvm/ADT/Sequence.h" -#include "xla/mlir/runtime/transforms/custom_call_encoding.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/service/gpu/gpu_asm_opts_util.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/status.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/device_memory_allocator.h" -#include "xla/translate/mhlo_to_hlo/attribute_exporter.h" -#include "xla/xla.pb.h" - -namespace xla { - -using xla::runtime::CustomCall; -using xla::runtime::EnumAttrEncoding; -using xla::runtime::FlatMemrefView; -using xla::runtime::State; -using xla::runtime::StridedMemrefView; -using xla::runtime::Tagged; - -namespace lmhlo_gpu = ::mlir::lmhlo_gpu; -namespace gpu { - -struct NormAlgorithmConfig { - int64_t algorithm; - int64_t workspace_size; -}; - -void PopulateNormAlgorithmConfigAttrEncoding( - runtime::CustomCallAttrEncodingSet& encoding) { - { // --- Encode `lmhlo_gpu::NormAlgorithmConfigAttr`. - using Attr = mlir::lmhlo_gpu::NormAlgorithmConfigAttr; - encoding - .Add>( - encoding, xla::runtime::AggregateAttrDef() - .Add("algorithm", &Attr::getAlgorithm) - .Add("workspace_size", &Attr::getWorkspaceSize)); - } -} -} // namespace gpu - -namespace runtime { -XLA_RUNTIME_REGISTER_AGGREGATE_ATTR_DECODING( - xla::gpu::NormAlgorithmConfig, // - AggregateMember("algorithm"), - AggregateMember("workspace_size")); -} // namespace runtime - -namespace gpu { - -void RegisterNormTypeIdNames(runtime::TypeIDNameRegistry& registry) { - registry.Register>( - "__type_id_norm_algorithm_config"); -} - -static GpuNormDescriptor GetGpuNormDescriptor( - StridedMemrefView input, StridedMemrefView scale, StridedMemrefView bias, - StridedMemrefView output, std::optional expectation, - std::optional norm_factor, double epsilon, - NormAlgorithmConfig algorithm_config, - absl::Span operand_layouts) { - GpuNormDescriptor descriptor; - - auto* algorithm = descriptor.backend_config.mutable_algorithm(); - algorithm->set_algo_id(algorithm_config.algorithm); - algorithm->set_is_cudnn_frontend(true); - if (algorithm_config.workspace_size >= 0) { - algorithm->mutable_workspace_size()->set_value( - algorithm_config.workspace_size); - } - - // Apply backend config layout to the shape. - int layout_idx = 0; - auto apply_shape = [&operand_layouts, - &layout_idx](const StridedMemrefView& memref) -> Shape { - std::vector minor_to_major = { - operand_layouts.begin() + layout_idx, - operand_layouts.begin() + layout_idx + memref.sizes.size()}; - layout_idx += memref.sizes.size(); - Shape shape = ToShape(memref); - return ShapeUtil::MakeShapeWithDenseLayout( - shape.element_type(), shape.dimensions(), minor_to_major); - }; - - descriptor.input_shape = apply_shape(input); - descriptor.scale_shape = apply_shape(scale); - descriptor.bias_shape = apply_shape(bias); - descriptor.output_shape = apply_shape(output); - if (expectation) { - descriptor.expectation_shape = apply_shape(*expectation); - } - if (norm_factor) { - descriptor.norm_factor_shape = apply_shape(*norm_factor); - } - - descriptor.backend_config.set_epsilon(epsilon); - - return descriptor; -} - -static absl::Status NormImpl(const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, - State runner_state, - StridedMemrefView input, StridedMemrefView scale, - StridedMemrefView bias, StridedMemrefView output, - CustomCall::RemainingArgs remaining_args, - int64_t uid, double epsilon, - absl::Span operand_layouts, - NormAlgorithmConfig algorithm_config) { - std::optional expectation, norm_factor; - // Final remaining arg is the scratch space. - if (remaining_args.size() == 3) { - auto expectation_ = remaining_args.get(0); - if (failed(expectation_)) { - return absl::InternalError("Failure while retrieving expectation."); - } - expectation = expectation_.value(); - - auto norm_factor_ = remaining_args.get(1); - if (failed(norm_factor_)) { - return absl::InternalError("Failure while retrieving norm factor."); - } - norm_factor = norm_factor_.value(); - } - - GpuNormDescriptor descriptor = - GetGpuNormDescriptor(input, scale, bias, output, expectation, norm_factor, - epsilon, algorithm_config, operand_layouts); - - auto config = GpuNormConfig::For(descriptor); - if (!config.ok()) { - return tsl::ToAbslStatus(config.status()); - } - auto current_runner = - runner_state.GetOrCreate([&config]() -> absl::StatusOr { - return NormRunnerState(std::move(config.value())); - }); - if (!current_runner.ok()) { - return tsl::ToAbslStatus(current_runner.status()); - } - - se::DeviceMemoryBase input_buffer = GetDeviceAddress(input); - se::DeviceMemoryBase scale_buffer = GetDeviceAddress(scale); - se::DeviceMemoryBase bias_buffer = GetDeviceAddress(bias); - se::DeviceMemoryBase output_buffer = GetDeviceAddress(output); - std::optional expectation_buffer, norm_factor_buffer; - if (expectation) { - expectation_buffer = GetDeviceAddress(expectation.value()); - } - if (norm_factor) { - norm_factor_buffer = GetDeviceAddress(norm_factor.value()); - } - - auto scratch = remaining_args.get(remaining_args.size() - 1); - if (failed(scratch)) { - return absl::InternalError("Failure while retrieving scratch."); - } - se::DeviceMemoryBase scratch_buffer = GetDeviceAddress(scratch.value()); - - RunNormOptions opts; - opts.norm_runner = ¤t_runner.value()->runner; - - // Run the norm. - return RunGpuNorm(current_runner.value()->config, input_buffer, scale_buffer, - bias_buffer, output_buffer, expectation_buffer, - norm_factor_buffer, scratch_buffer, run_options->stream(), - opts); -} - -template -auto BindNormAttributes(runtime::CustomCallBinding binding) { - return std::move(binding) - // Unique convolution id for caching state. - .template Attr("uid") - .template Attr("epsilon") - .template Attr>("operand_layouts") - .template Attr("norm_algorithm_config"); -} - -auto NormCall(const char* name) { - return CustomCall::Bind(name) - .UserData() - .UserData() - .State("uid") - .Arg() // input - .Arg() // scale - .Arg() // bias - .Arg(); // output -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - Norm, FunctionWrapper(), checks, - BindNormAttributes(NormCall("xla.gpu.norm").RemainingArgs())); - -void RegisterNormCustomCalls(runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.norm", Norm); -} - -StreamExecutorNormRunners* NormRunnerStates::operator()( - se::StreamExecutor* executor) { - absl::MutexLock lock(&mutex_); - return &runners_[executor]; -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/norm.h b/third_party/xla/xla/service/gpu/runtime/norm.h deleted file mode 100644 index ad2c13dbb4cc3b..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/norm.h +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_NORM_H_ -#define XLA_SERVICE_GPU_RUNTIME_NORM_H_ - -#include -#include - -#include "absl/container/node_hash_map.h" -#include "absl/synchronization/mutex.h" -#include "xla/mlir/runtime/transforms/custom_call_encoding.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/service/gpu/gpu_norm_runner.h" - -namespace xla { -namespace gpu { - -// Registers XLA GPU runtime norm custom calls. -void RegisterNormCustomCalls(runtime::DirectCustomCallRegistry& registry); - -// Register type names for norm attributes defined by MHLO dialect. -void RegisterNormTypeIdNames(runtime::TypeIDNameRegistry& registry); - -void PopulateNormAlgorithmConfigAttrEncoding( - runtime::CustomCallAttrEncodingSet& encoding); - -// State of the norm runners between invocations. -struct NormRunnerState { - explicit NormRunnerState(GpuNormConfig config) - : config(std::move(config)), runner(this->config) {} - GpuNormConfig config; - NormRunner runner; -}; - -class StreamExecutorNormRunners : public runtime::StateVector { -}; - -// XLA executable keeps a mapping from stream executors to norm runners. -class NormRunnerStates { - public: - StreamExecutorNormRunners* operator()(se::StreamExecutor* executor); - - private: - mutable absl::Mutex mutex_; - absl::node_hash_map runners_ - ABSL_GUARDED_BY(mutex_); -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_NORM_H_ diff --git a/third_party/xla/xla/service/gpu/runtime3/norm_thunk.cc b/third_party/xla/xla/service/gpu/runtime/norm_thunk.cc similarity index 61% rename from third_party/xla/xla/service/gpu/runtime3/norm_thunk.cc rename to third_party/xla/xla/service/gpu/runtime/norm_thunk.cc index 966f696ce8dafc..d3862f7bfeac74 100644 --- a/third_party/xla/xla/service/gpu/runtime3/norm_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/norm_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime3/norm_thunk.h" +#include "xla/service/gpu/runtime/norm_thunk.h" #include #include @@ -26,20 +26,26 @@ namespace xla { namespace gpu { NormThunk::NormThunk(ThunkInfo thunk_info, GpuNormConfig config, - BufferAllocation::Slice input_slice, + BufferAllocation::Slice x_slice, BufferAllocation::Slice scale_slice, - BufferAllocation::Slice bias_slice, - BufferAllocation::Slice output_slice, + BufferAllocation::Slice y_or_dx_slice, + std::optional bias_slice, std::optional expectation_slice, std::optional norm_factor_slice, + std::optional dy_slice, + std::optional dscale_slice, + std::optional dbias_slice, BufferAllocation::Slice scratch_slice) : Thunk(Kind::kNorm, thunk_info), - input_buffer_(input_slice), + x_buffer_(x_slice), scale_buffer_(scale_slice), + y_or_dx_buffer_(y_or_dx_slice), bias_buffer_(bias_slice), - output_buffer_(output_slice), expectation_buffer_(expectation_slice), norm_factor_buffer_(norm_factor_slice), + dy_buffer_(dy_slice), + dscale_buffer_(dscale_slice), + dbias_buffer_(dbias_slice), scratch_buffer_(scratch_slice), config_(config) {} @@ -57,25 +63,31 @@ NormRunner& NormThunk::GetOrCreateRunner( absl::Status NormThunk::ExecuteOnStream(const ExecuteParams& params) { const auto& buffer_allocations = *params.buffer_allocations; - se::DeviceMemoryBase input_se_buffer = - buffer_allocations.GetDeviceAddress(input_buffer_); + se::DeviceMemoryBase x_se_buffer = + buffer_allocations.GetDeviceAddress(x_buffer_); se::DeviceMemoryBase scale_se_buffer = buffer_allocations.GetDeviceAddress(scale_buffer_); - se::DeviceMemoryBase bias_se_buffer = - buffer_allocations.GetDeviceAddress(bias_buffer_); - se::DeviceMemoryBase output_se_buffer = - buffer_allocations.GetDeviceAddress(output_buffer_); + se::DeviceMemoryBase y_or_dx_se_buffer = + buffer_allocations.GetDeviceAddress(y_or_dx_buffer_); - std::optional expectation_se_buffer, - norm_factor_se_buffer; + std::optional bias_se_buffer, expectation_se_buffer, + norm_factor_se_buffer, dy_se_buffer, dscale_se_buffer, dbias_se_buffer; + if (bias_buffer_) { + bias_se_buffer = buffer_allocations.GetDeviceAddress(bias_buffer_.value()); + } if (expectation_buffer_) { expectation_se_buffer = buffer_allocations.GetDeviceAddress(expectation_buffer_.value()); - } - if (norm_factor_buffer_) { norm_factor_se_buffer = buffer_allocations.GetDeviceAddress(norm_factor_buffer_.value()); } + if (dscale_buffer_) { + dy_se_buffer = buffer_allocations.GetDeviceAddress(dy_buffer_.value()); + dscale_se_buffer = + buffer_allocations.GetDeviceAddress(dscale_buffer_.value()); + dbias_se_buffer = + buffer_allocations.GetDeviceAddress(dbias_buffer_.value()); + } se::DeviceMemoryBase scratch = buffer_allocations.GetDeviceAddress(scratch_buffer_); @@ -83,10 +95,10 @@ absl::Status NormThunk::ExecuteOnStream(const ExecuteParams& params) { RunNormOptions opts; opts.norm_runner = &GetOrCreateRunner(params.stream); - TF_RETURN_IF_ERROR(RunGpuNorm(config_, input_se_buffer, scale_se_buffer, - bias_se_buffer, output_se_buffer, - expectation_se_buffer, norm_factor_se_buffer, - scratch, params.stream, opts)); + TF_RETURN_IF_ERROR(RunGpuNorm( + config_, x_se_buffer, scale_se_buffer, y_or_dx_se_buffer, bias_se_buffer, + dy_se_buffer, expectation_se_buffer, norm_factor_se_buffer, + dscale_se_buffer, dbias_se_buffer, scratch, params.stream, opts)); if (!params.stream->ok()) { return Internal("NormThunk::ExecuteOnStream failed."); diff --git a/third_party/xla/xla/service/gpu/runtime3/norm_thunk.h b/third_party/xla/xla/service/gpu/runtime/norm_thunk.h similarity index 70% rename from third_party/xla/xla/service/gpu/runtime3/norm_thunk.h rename to third_party/xla/xla/service/gpu/runtime/norm_thunk.h index ba414941d695fa..4e362118df9138 100644 --- a/third_party/xla/xla/service/gpu/runtime3/norm_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/norm_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME3_NORM_THUNK_H_ -#define XLA_SERVICE_GPU_RUNTIME3_NORM_THUNK_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_NORM_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_NORM_THUNK_H_ #include #include @@ -32,10 +32,14 @@ namespace gpu { class NormThunk : public Thunk { public: NormThunk(ThunkInfo thunk_info, GpuNormConfig config, - BufferAllocation::Slice input, BufferAllocation::Slice scale, - BufferAllocation::Slice bias, BufferAllocation::Slice output, + BufferAllocation::Slice x, BufferAllocation::Slice scale, + BufferAllocation::Slice y_or_dx, + std::optional bias, std::optional expectation, std::optional norm_factor, + std::optional dy, + std::optional dscale, + std::optional dbias, BufferAllocation::Slice scratch); NormThunk(const NormThunk&) = delete; @@ -44,12 +48,15 @@ class NormThunk : public Thunk { absl::Status ExecuteOnStream(const ExecuteParams& params) override; private: - BufferAllocation::Slice input_buffer_; + BufferAllocation::Slice x_buffer_; BufferAllocation::Slice scale_buffer_; - BufferAllocation::Slice bias_buffer_; - BufferAllocation::Slice output_buffer_; + BufferAllocation::Slice y_or_dx_buffer_; + std::optional bias_buffer_; std::optional expectation_buffer_; std::optional norm_factor_buffer_; + std::optional dy_buffer_; + std::optional dscale_buffer_; + std::optional dbias_buffer_; BufferAllocation::Slice scratch_buffer_; NormRunner& GetOrCreateRunner(const stream_executor::Stream*); @@ -63,4 +70,4 @@ class NormThunk : public Thunk { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_RUNTIME3_NORM_THUNK_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_NORM_THUNK_H_ diff --git a/third_party/xla/xla/service/gpu/runtime3/outfeed_thunk.cc b/third_party/xla/xla/service/gpu/runtime/outfeed_thunk.cc similarity index 90% rename from third_party/xla/xla/service/gpu/runtime3/outfeed_thunk.cc rename to third_party/xla/xla/service/gpu/runtime/outfeed_thunk.cc index 85708fba01f741..dd3dc2e153dbb7 100644 --- a/third_party/xla/xla/service/gpu/runtime3/outfeed_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/outfeed_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime3/outfeed_thunk.h" +#include "xla/service/gpu/runtime/outfeed_thunk.h" #include "absl/status/status.h" #include "xla/service/gpu/outfeed_manager.h" @@ -73,7 +73,8 @@ absl::Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) { ++output_leaf_it; const Shape& output_shape = ShapeUtil::GetSubshape(output_buffers->shape(), shape_index); - TF_RET_CHECK(ShapeUtil::Equal(source_slices_[index].shape, output_shape)) + TF_RET_CHECK( + ShapeUtil::ReshapeIsBitcast(source_slices_[index].shape, output_shape)) << "Mismatch between outfeed output buffer shape " << ShapeUtil::HumanStringWithLayout(output_shape) << " and outfeed source buffer shape " @@ -87,16 +88,15 @@ absl::Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) { // TODO(b/111309141): Run this on a separate stream so it doesn't block // the GPU from doing work during the transfer. - stream - .ThenMemcpy(buffer->destination()->untyped_data(), data_address, - buffer->length()) - .ThenDoHostCallback([&buffer]() { buffer->Done(); }); + TF_RETURN_IF_ERROR(stream.Memcpy(buffer->destination()->untyped_data(), + data_address, buffer->length())); + TF_RETURN_IF_ERROR(stream.DoHostCallback([&buffer]() { buffer->Done(); })); } absl::Status block_status = stream.BlockHostUntilDone(); if (!block_status.ok()) { return Internal("Failed to complete data transfer on stream %p: %s", - &stream, block_status.message()); + &stream, block_status.message()); } VLOG(2) << "Outfeeding from GPU complete"; diff --git a/third_party/xla/xla/service/gpu/runtime3/outfeed_thunk.h b/third_party/xla/xla/service/gpu/runtime/outfeed_thunk.h similarity index 90% rename from third_party/xla/xla/service/gpu/runtime3/outfeed_thunk.h rename to third_party/xla/xla/service/gpu/runtime/outfeed_thunk.h index af76973ee48fda..6ffd59cae9713a 100644 --- a/third_party/xla/xla/service/gpu/runtime3/outfeed_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/outfeed_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME3_OUTFEED_THUNK_H_ -#define XLA_SERVICE_GPU_RUNTIME3_OUTFEED_THUNK_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_OUTFEED_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_OUTFEED_THUNK_H_ #include "xla/service/gpu/thunk.h" @@ -42,4 +42,4 @@ class OutfeedThunk : public Thunk { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_RUNTIME3_OUTFEED_THUNK_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_OUTFEED_THUNK_H_ diff --git a/third_party/xla/xla/service/gpu/runtime3/replica_id_thunk.cc b/third_party/xla/xla/service/gpu/runtime/replica_id_thunk.cc similarity index 90% rename from third_party/xla/xla/service/gpu/runtime3/replica_id_thunk.cc rename to third_party/xla/xla/service/gpu/runtime/replica_id_thunk.cc index b27eb6b298cd2d..c563afed6fea15 100644 --- a/third_party/xla/xla/service/gpu/runtime3/replica_id_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/replica_id_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime3/replica_id_thunk.h" +#include "xla/service/gpu/runtime/replica_id_thunk.h" #include "absl/status/status.h" #include "xla/service/global_device_id.h" @@ -32,8 +32,7 @@ absl::Status ReplicaOrPartitionIdThunk::ExecuteOnStream( global_device_id)); int id = kind() == Kind::kReplicaId ? logical_id.replica_id : logical_id.computation_id; - params.stream->ThenMemset32(&dest_addr, id, /*size=*/4); - return absl::OkStatus(); + return params.stream->Memset32(&dest_addr, id, /*size=*/4); } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/runtime3/replica_id_thunk.h b/third_party/xla/xla/service/gpu/runtime/replica_id_thunk.h similarity index 91% rename from third_party/xla/xla/service/gpu/runtime3/replica_id_thunk.h rename to third_party/xla/xla/service/gpu/runtime/replica_id_thunk.h index 2c94b20fa453f9..f9c2cda14c8001 100644 --- a/third_party/xla/xla/service/gpu/runtime3/replica_id_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/replica_id_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME3_REPLICA_ID_THUNK_H_ -#define XLA_SERVICE_GPU_RUNTIME3_REPLICA_ID_THUNK_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_REPLICA_ID_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_REPLICA_ID_THUNK_H_ #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/thunk.h" @@ -53,4 +53,4 @@ class PartitionIdThunk : public ReplicaOrPartitionIdThunk { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_RUNTIME3_REPLICA_ID_THUNK_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_REPLICA_ID_THUNK_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/send_recv.cc b/third_party/xla/xla/service/gpu/runtime/send_recv.cc deleted file mode 100644 index 7093b1f5be9fed..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/send_recv.cc +++ /dev/null @@ -1,307 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/send_recv.h" - -#include -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/strings/str_format.h" -#include "absl/synchronization/mutex.h" -#include "xla/mlir/runtime/transforms/custom_call_encoding.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/service_executable_run_options.h" -#include "tsl/concurrency/async_value.h" -#include "tsl/concurrency/async_value_ref.h" -#include "tsl/profiler/lib/traceme.h" -#include "tsl/profiler/lib/traceme_encode.h" - -namespace xla { -namespace gpu { - -using absl::InternalError; -using absl::InvalidArgumentError; -using absl::StrFormat; - -using tsl::AsyncValueRef; -using tsl::profiler::TraceMe; -using tsl::profiler::TraceMeEncode; - -using xla::runtime::AggregateAttrDef; -using xla::runtime::AggregateAttrEncoding; -using xla::runtime::CustomCall; -using xla::runtime::CustomCallAttrEncodingSet; -using xla::runtime::Dictionary; -using xla::runtime::StridedMemrefView; -using xla::runtime::Tagged; -using xla::runtime::TypeIDNameRegistry; - -namespace mhlo = ::mlir::mhlo; - -//===----------------------------------------------------------------------===// -// Structs for encoding send/recv operations attributes. -//===----------------------------------------------------------------------===// - -struct ChannelHandle { - int64_t handle; - int64_t type; -}; - -} // namespace gpu - -//===----------------------------------------------------------------------===// -// Register send/recv attributes decoding with the Xla runtime. -//===----------------------------------------------------------------------===// - -namespace runtime { - -XLA_RUNTIME_REGISTER_AGGREGATE_ATTR_DECODING(xla::gpu::ChannelHandle, - AggregateMember("handle"), - AggregateMember("type")); - -} // namespace runtime - -//===----------------------------------------------------------------------===// -// Type names for encoded attributes. -//===----------------------------------------------------------------------===// - -namespace gpu { - -void RegisterSendRecvTypeIdNames(TypeIDNameRegistry& registry) { - registry.Register>("__type_id_channel_handle"); -} - -//===----------------------------------------------------------------------===// -// Encoding from MHLO attributes to Xla runtime aggregate attributes. -//===----------------------------------------------------------------------===// - -void PopulateSendRecvAttrEncoding(CustomCallAttrEncodingSet& encoding) { - { // --- Encode `mhlo::ChannelHandleAttr`. - using Attr = mhlo::ChannelHandleAttr; - encoding.Add>( - encoding, AggregateAttrDef() - .Add("handle", &Attr::getHandle) - .Add("type", &Attr::getType)); - } -} - -//===----------------------------------------------------------------------===// -// Support for running asynchronous Send/Recv SendDone/RecvDone operations. -//===----------------------------------------------------------------------===// - -absl::Status SendRecvEvents::PushEvent(int32_t handle, - AsyncValueRef event) { - absl::MutexLock lock(&mutex_); - if (auto it = events_.try_emplace(handle, std::move(event)); it.second) - return absl::OkStatus(); - - return InternalError( - StrFormat("Async send/recv event already exists (handle=%d)", handle)); -} - -absl::StatusOr> SendRecvEvents::PopEvent( - int32_t handle) { - absl::MutexLock lock(&mutex_); - if (auto event = events_.extract(handle)) return std::move(event.mapped()); - - return InternalError( - StrFormat("Async send/recv event was not found (handle==%d)", handle)); -} - -//===----------------------------------------------------------------------===// -// Generate a map with frontend attributes. -//===----------------------------------------------------------------------===// - -absl::flat_hash_map GenerateFrontEndAttributeMap( - Dictionary frontend_attrs) { - absl::flat_hash_map frontend_attr_map; - for (std::string_view key : frontend_attrs.keys()) { - auto frontend_attr = frontend_attrs.get(key); - if (mlir::succeeded(frontend_attr)) { - frontend_attr_map.insert({std::string(key), std::string(*frontend_attr)}); - } - } - return frontend_attr_map; -} - -//===----------------------------------------------------------------------===// -// Send/Recv custom call implementation. -//===----------------------------------------------------------------------===// - -static absl::Status SendImpl(const ServiceExecutableRunOptions* run_options, - SendRecvEvents* events, StridedMemrefView arg, - ChannelHandle channel, Dictionary frontend_attrs) { - VLOG(3) << "Host Send buffer:" - << " channel=" << channel.handle; - - TraceMe trace([&] { - return TraceMeEncode("xla.gpu.send_host", {{"channel", channel.handle}}); - }); - - // Use device_to_host stream if it is available. - se::Stream* stream = run_options->run_options().device_to_host_stream(); - if (stream) { - stream->ThenWaitFor(run_options->stream()); - } else { - stream = run_options->stream(); - } - - // Send buffer to a handler registered with the run options. - if (auto* send = run_options->run_options().send_device_memory_function()) { - TF_ASSIGN_OR_RETURN( - auto done_event, - (*send)(channel.handle, stream, ToShape(arg), GetDeviceAddress(arg), - GenerateFrontEndAttributeMap(frontend_attrs))); - return events->PushEvent(channel.handle, std::move(done_event)); - } - - return InvalidArgumentError("SendDeviceMemoryFunction is not available"); -} - -static absl::Status RecvImpl(const ServiceExecutableRunOptions* run_options, - SendRecvEvents* events, StridedMemrefView arg, - ChannelHandle channel, Dictionary frontend_attrs) { - VLOG(3) << "Host Receive buffer:" - << " channel=" << channel.handle; - - TraceMe trace([&] { - return TraceMeEncode("xla.gpu.recv_host", {{"channel", channel.handle}}); - }); - - // Use host_to_device stream if it is available. - se::Stream* stream = run_options->run_options().host_to_device_stream(); - if (stream) { - stream->ThenWaitFor(run_options->stream()); - } else { - stream = run_options->stream(); - } - - // Recv buffer from a handler registered with the run options. - if (auto* recv = run_options->run_options().recv_device_memory_function()) { - auto dst = GetDeviceAddress(arg); - TF_ASSIGN_OR_RETURN(auto done_event, - (*recv)(channel.handle, stream, ToShape(arg), &dst, - GenerateFrontEndAttributeMap(frontend_attrs))); - return events->PushEvent(channel.handle, std::move(done_event)); - } - - return InvalidArgumentError("RecvDeviceMemoryFunction is not available"); -} - -static absl::Status SendDoneImpl(const ServiceExecutableRunOptions* run_options, - SendRecvEvents* events, - ChannelHandle channel) { - VLOG(3) << "Wait for Host Send completion:" - << " channel=" << channel.handle; - - TraceMe trace([&] { - return TraceMeEncode("xla.gpu.send_done_host", - {{"channel", channel.handle}}); - }); - - TF_ASSIGN_OR_RETURN(auto done_event, events->PopEvent(channel.handle)); - - // Wait until send handler will record an event on the stream. - BlockUntilReady(done_event.GetAsyncValue()); - if (done_event.IsError()) return done_event.GetError(); - - VLOG(5) << "Completed Host Send operation: " - << " channel=" << channel.handle; - - // Once event is recorded we can add a stream dependency. - run_options->stream()->ThenWaitFor(&done_event.get()); - return absl::OkStatus(); -} - -static absl::Status RecvDoneImpl(const ServiceExecutableRunOptions* run_options, - SendRecvEvents* events, - ChannelHandle channel) { - VLOG(3) << "Wait for Recv completion:" - << " channel=" << channel.handle; - - TraceMe trace([&] { - return TraceMeEncode("xla.gpu.recv_done_host", - {{"channel", channel.handle}}); - }); - - TF_ASSIGN_OR_RETURN(auto done_event, events->PopEvent(channel.handle)); - - // Wait until send handler will record an event on the stream. - BlockUntilReady(done_event.GetAsyncValue()); - if (done_event.IsError()) return done_event.GetError(); - - VLOG(5) << "Completed Host Recv operation: " - << " channel=" << channel.handle; - - // Once event is recorded we can add a stream dependency. - run_options->stream()->ThenWaitFor(&done_event.get()); - return absl::OkStatus(); -} - -//===----------------------------------------------------------------------===// -// Send/Recv custom calls bindings and registration. -//===----------------------------------------------------------------------===// - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - SendHost, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.send_host") - .UserData() - .UserData() - .Arg() - .Attr("channel_handle") - .Attr("frontend_attributes")); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - RecvHost, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.recv_host") - .UserData() - .UserData() - .Arg() - .Attr("channel_handle") - .Attr("frontend_attributes")); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - SendDoneHost, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.send_done_host") - .UserData() - .UserData() - .Attr("channel_handle")); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - RecvDoneHost, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.recv_done_host") - .UserData() - .UserData() - .Attr("channel_handle")); - -//===----------------------------------------------------------------------===// - -// Registers XLA Gpu runtime Host Send/Recv custom calls. -void RegisterSendRecvCustomCalls(runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.send_host", SendHost); - registry.Register("xla.gpu.recv_host", RecvHost); - registry.Register("xla.gpu.send_done_host", SendDoneHost); - registry.Register("xla.gpu.recv_done_host", RecvDoneHost); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/send_recv.h b/third_party/xla/xla/service/gpu/runtime/send_recv.h deleted file mode 100644 index 1090f643e1def7..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/send_recv.h +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_SEND_RECV_H_ -#define XLA_SERVICE_GPU_RUNTIME_SEND_RECV_H_ - -#include -#include - -#include "xla/mlir/runtime/transforms/custom_call_encoding.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/stream_executor/event.h" - -namespace xla { -namespace gpu { - -// Registers XLA Gpu runtime Send/Recv custom calls. -void RegisterSendRecvCustomCalls(runtime::DirectCustomCallRegistry& registry); - -// Register type names for communication attributes defined by MHLO dialect. -void RegisterSendRecvTypeIdNames(runtime::TypeIDNameRegistry& registry); - -// Adds attributes encoding for Send/Recv custom calls -void PopulateSendRecvAttrEncoding(runtime::CustomCallAttrEncodingSet& encoding); - -//===----------------------------------------------------------------------===// -// Support for running asynchronous Send/Recv SendDone/RecvDone operations. -//===----------------------------------------------------------------------===// - -class SendRecvEvents { - public: - absl::Status PushEvent(int32_t handle, tsl::AsyncValueRef event); - absl::StatusOr> PopEvent(int32_t handle); - - private: - absl::Mutex mutex_; - absl::flat_hash_map> events_ - ABSL_GUARDED_BY(mutex_); -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_SEND_RECV_H_ diff --git a/third_party/xla/xla/service/gpu/runtime3/send_recv_thunk.cc b/third_party/xla/xla/service/gpu/runtime/send_recv_thunk.cc similarity index 96% rename from third_party/xla/xla/service/gpu/runtime3/send_recv_thunk.cc rename to third_party/xla/xla/service/gpu/runtime/send_recv_thunk.cc index 0ed651b0a40f08..7d6a3183069e6e 100644 --- a/third_party/xla/xla/service/gpu/runtime3/send_recv_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/send_recv_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime3/send_recv_thunk.h" +#include "xla/service/gpu/runtime/send_recv_thunk.h" #include #include @@ -38,6 +38,7 @@ limitations under the License. #include "xla/stream_executor/stream_executor.h" #include "tsl/concurrency/async_value.h" #include "tsl/concurrency/async_value_ref.h" +#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" @@ -123,7 +124,7 @@ absl::Status SendThunk::ExecuteOnStream(const ExecuteParams& params) { // Use device_to_host stream if it is available. se::Stream* stream = params.device_to_host_stream; if (stream) { - stream->ThenWaitFor(params.stream); + TF_RETURN_IF_ERROR(stream->WaitFor(params.stream)); } else { stream = params.stream; } @@ -175,8 +176,7 @@ absl::Status SendDoneThunk::ExecuteOnStream(const ExecuteParams& params) { VLOG(5) << "Completed Send operation: channel_id=" << channel_id_; // Once event is recorded we can add a stream dependency. - params.stream->ThenWaitFor(&done_event.get()); - return absl::OkStatus(); + return params.stream->WaitFor(&done_event.get()); } //===----------------------------------------------------------------------===// @@ -210,7 +210,7 @@ absl::Status RecvThunk::ExecuteOnStream(const ExecuteParams& params) { // Use host_to_device stream if it is available. se::Stream* stream = params.host_to_device_stream; if (stream) { - stream->ThenWaitFor(params.stream); + TF_RETURN_IF_ERROR(stream->WaitFor(params.stream)); } else { stream = params.stream; } @@ -261,8 +261,7 @@ absl::Status RecvDoneThunk::ExecuteOnStream(const ExecuteParams& params) { VLOG(5) << "Completed Recv operation: channel=" << channel_id_; // Once event is recorded we can add a stream dependency. - params.stream->ThenWaitFor(&done_event.get()); - return absl::OkStatus(); + return params.stream->WaitFor(&done_event.get()); } } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/runtime3/send_recv_thunk.h b/third_party/xla/xla/service/gpu/runtime/send_recv_thunk.h similarity index 97% rename from third_party/xla/xla/service/gpu/runtime3/send_recv_thunk.h rename to third_party/xla/xla/service/gpu/runtime/send_recv_thunk.h index 25e891c7f4f970..034a9e01249853 100644 --- a/third_party/xla/xla/service/gpu/runtime3/send_recv_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/send_recv_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME3_SEND_RECV_THUNK_H_ -#define XLA_SERVICE_GPU_RUNTIME3_SEND_RECV_THUNK_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_SEND_RECV_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_SEND_RECV_THUNK_H_ #include #include @@ -166,4 +166,4 @@ class RecvDoneThunk : public Thunk { } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_RUNTIME3_SEND_RECV_THUNK_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_SEND_RECV_THUNK_H_ diff --git a/third_party/xla/xla/service/gpu/runtime3/sequential_thunk.cc b/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc similarity index 88% rename from third_party/xla/xla/service/gpu/runtime3/sequential_thunk.cc rename to third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc index d2f4e8616da94f..e872792fc5c467 100644 --- a/third_party/xla/xla/service/gpu/runtime3/sequential_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc @@ -13,13 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime3/sequential_thunk.h" +#include "xla/service/gpu/runtime/sequential_thunk.h" #include #include #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "xla/service/gpu/runtime/annotation.h" #include "xla/service/gpu/thunk.h" #include "tsl/platform/errors.h" #include "tsl/profiler/lib/scoped_annotation.h" @@ -54,8 +55,10 @@ absl::Status SequentialThunk::Initialize(const InitializeParams& params) { } absl::Status SequentialThunk::ExecuteOnStream(const ExecuteParams& params) { + const ModuleAnnotations* annotations = GetCurrentModuleAnnotations(); for (const auto& thunk : thunks_) { - ScopedAnnotation annotation([&] { return thunk->profile_annotation(); }); + auto annotation = + GetKernelAnnotation(annotations, thunk->profile_annotation()); TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(params)); } return absl::OkStatus(); diff --git a/third_party/xla/xla/service/gpu/runtime3/sequential_thunk.h b/third_party/xla/xla/service/gpu/runtime/sequential_thunk.h similarity index 91% rename from third_party/xla/xla/service/gpu/runtime3/sequential_thunk.h rename to third_party/xla/xla/service/gpu/runtime/sequential_thunk.h index 2538af9b45cf76..fe1abafe5df5d1 100644 --- a/third_party/xla/xla/service/gpu/runtime3/sequential_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/sequential_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME3_SEQUENTIAL_THUNK_H_ -#define XLA_SERVICE_GPU_RUNTIME3_SEQUENTIAL_THUNK_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_SEQUENTIAL_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_SEQUENTIAL_THUNK_H_ #include @@ -50,4 +50,4 @@ class SequentialThunk : public Thunk { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_RUNTIME3_SEQUENTIAL_THUNK_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_SEQUENTIAL_THUNK_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/stream_synchronization.cc b/third_party/xla/xla/service/gpu/runtime/stream_synchronization.cc deleted file mode 100644 index d41a57aa35d2a5..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/stream_synchronization.cc +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/stream_synchronization.h" - -#include "xla/runtime/executable.h" -#include "xla/service/gpu/runtime/concurrent_region.h" -#include "xla/service/gpu/runtime/support.h" - -namespace xla { -namespace gpu { - -static absl::Status AwaitImpl(ConcurrentRegionStatus* region_status, - int64_t from, absl::Span to) { - TF_ASSIGN_OR_RETURN(se::Stream * from_stream, region_status->GetStream(from)); - for (int64_t to_index : to) { - TF_ASSIGN_OR_RETURN(se::Stream * to_stream, - region_status->GetStream(to_index)); - from_stream->ThenWaitFor(to_stream); - } - - return absl::OkStatus(); -} - -//===----------------------------------------------------------------------===// -// Define custom calls that mark the concurrent region in CUDA graphs. -//===----------------------------------------------------------------------===// - -using xla::runtime::CustomCall; - -XLA_RUNTIME_DEFINE_CUSTOM_CALL(Await, FunctionWrapper(), checks, - CustomCall::Bind("xla.streams.await") - .UserData() - .Attr("from") - .Attr>("to")); - -void RegisterStreamSynchronizationCustomCalls( - runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.streams.await", Await); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/stream_synchronization.h b/third_party/xla/xla/service/gpu/runtime/stream_synchronization.h deleted file mode 100644 index 1b480b9864c8cb..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/stream_synchronization.h +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_STREAM_SYNCHRONIZATION_H_ -#define XLA_SERVICE_GPU_RUNTIME_STREAM_SYNCHRONIZATION_H_ - -#include "xla/runtime/custom_call_registry.h" - -namespace xla { -namespace gpu { - -// Registers XLA Gpu runtime stream synchronization custom calls. -void RegisterStreamSynchronizationCustomCalls( - runtime::DirectCustomCallRegistry& registry); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_STREAM_SYNCHRONIZATION_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/support.cc b/third_party/xla/xla/service/gpu/runtime/support.cc deleted file mode 100644 index a59c0d44cfdc53..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/support.cc +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/support.h" - -#include -#include - -#include "tsl/profiler/lib/scoped_annotation_stack.h" - -namespace xla { -namespace gpu { - -namespace { -static thread_local std::string_view current_tracing_scope = {}; -} // namespace - -void SetCurrentTracingScope(std::string_view scope) { - current_tracing_scope = scope; -} - -void ResetCurrentTracingScope() { current_tracing_scope = std::string_view(); } - -void AppendDiagnosticToString(runtime::DiagnosticEngine& diagnostic_engine, - std::string* diagnostic, - bool append_annotation_stack) { - diagnostic_engine.AddHandler( - [append_annotation_stack, diagnostic](runtime::Diagnostic& d) { - if (!diagnostic->empty()) absl::StrAppend(diagnostic, "; "); - absl::StrAppend(diagnostic, d.status().message()); - - // Append the current trace which should help identifying original HLO - // operation that fails. - if (!current_tracing_scope.empty()) { - absl::StrAppend(diagnostic, - "; current tracing scope: ", current_tracing_scope); - } - - // Append current profiling annotation which will have the XLA - // executable name and program id. - if (append_annotation_stack) { - absl::StrAppend(diagnostic, "; current profiling annotation: ", - tsl::profiler::AnnotationStack::Get()); - } - - LOG(WARNING) << "Intercepted XLA runtime error:\n" - << d.status().ToString( - absl::StatusToStringMode::kWithEverything); - - return runtime::success(); - }); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/support.h b/third_party/xla/xla/service/gpu/runtime/support.h deleted file mode 100644 index 10efd22a01e122..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/support.h +++ /dev/null @@ -1,166 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_SUPPORT_H_ -#define XLA_SERVICE_GPU_RUNTIME_SUPPORT_H_ - -#include -#include -#include - -#include "absl/strings/str_cat.h" -#include "llvm/ADT/ArrayRef.h" -#include "xla/mlir/runtime/transforms/custom_call_encoding.h" -#include "xla/runtime/custom_call.h" -#include "xla/service/gpu/matmul_utils.h" -#include "xla/shape_util.h" -#include "xla/stream_executor/blas.h" -#include "xla/stream_executor/device_memory.h" - -namespace xla { -namespace gpu { - -template -using FunctionWrapper = xla::runtime::CustomCall::FunctionWrapper; - -struct DotDimensionNumbers { - absl::Span lhs_batch; - absl::Span lhs_contract; - absl::Span rhs_batch; - absl::Span rhs_contract; -}; - -// Disable expensive CustomCall checks in optimized build. -inline constexpr runtime::CustomCall::RuntimeChecks checks = // NOLINT -#if defined(NDEBUG) - runtime::CustomCall::RuntimeChecks::kLess; -#else - runtime::CustomCall::RuntimeChecks::kDefault; -#endif - -template -absl::StatusOr ToAbsl(absl::StatusOr status_or) { - if (!status_or.ok()) return status_or.status(); - return std::move(status_or).value(); -} - -inline se::DeviceMemoryBase GetDeviceAddress( - const runtime::FlatMemrefView& memref) { - return se::DeviceMemoryBase(memref.data, memref.size_in_bytes); -} - -inline se::DeviceMemoryBase GetDeviceAddress( - const runtime::MemrefView& memref) { - uint64_t size = primitive_util::ByteWidth(memref.dtype); - for (auto dim : memref.sizes) size *= dim; - return se::DeviceMemoryBase(memref.data, size); -} - -inline se::DeviceMemoryBase GetDeviceAddress( - const runtime::StridedMemrefView& memref) { - uint64_t size = primitive_util::ByteWidth(memref.dtype); - for (auto dim : memref.sizes) size *= dim; - if (primitive_util::Is4BitType(memref.dtype)) { - size = (size + 1) / 2; - } - return se::DeviceMemoryBase(memref.data, size); -} - -inline Shape ToShape(const runtime::StridedMemrefView& memref) { - // Recover `minor_to_major` dimensions permutation from strides. - auto indexed_strides_range = - llvm::map_range(llvm::enumerate(memref.strides), [](auto pair) { - return std::pair{pair.value(), pair.index()}; - }); - - auto indexed_strides = llvm::to_vector(indexed_strides_range); - llvm::stable_sort(indexed_strides); - - llvm::SmallVector minor_to_major; - minor_to_major.reserve(indexed_strides.size()); - for (auto& pair : indexed_strides) minor_to_major.push_back(pair.second); - - return ShapeUtil::MakeShapeWithDenseLayout(memref.dtype, memref.sizes, - minor_to_major); -} - -inline absl::StatusOr GetGemmConfig( - const runtime::StridedMemrefView& lhs, - const runtime::StridedMemrefView& rhs, - const runtime::StridedMemrefView& out, int64_t algorithm, double alpha_real, - double alpha_imag, double beta, absl::Span lhs_batch, - absl::Span lhs_contract, absl::Span rhs_batch, - absl::Span rhs_contract, int64_t compute_precision, - const std::optional c = std::nullopt, - const std::optional& bias = std::nullopt, - bool grad_x = false, bool grad_y = false) { - Shape c_shape = ToShape(c.value_or(out)); - Shape bias_shape; - Shape* bias_shape_ptr = nullptr; - if (bias) { - bias_shape = ToShape(*bias); - bias_shape_ptr = &bias_shape; - } - return GemmConfig::For(ToShape(lhs), lhs_batch, lhs_contract, ToShape(rhs), - rhs_batch, rhs_contract, c_shape, bias_shape_ptr, - ToShape(out), alpha_real, alpha_imag, beta, algorithm, - compute_precision, grad_x, grad_y); -} - -// adds Dot Dimension Attribute encodings for calls to Gemm and cuBLASLt -inline void PopulateDotDimsAttrEncoding( - runtime::CustomCallAttrEncodingSet& encoding) { - using DotDimsAttr = mlir::mhlo::DotDimensionNumbersAttr; - encoding.Add< - xla::runtime::AggregateAttrEncoding>( - encoding, - xla::runtime::AggregateAttrDef() - .Add("lhs_batch", &DotDimsAttr::getLhsBatchingDimensions) - .Add("lhs_contract", &DotDimsAttr::getLhsContractingDimensions) - .Add("rhs_batch", &DotDimsAttr::getRhsBatchingDimensions) - .Add("rhs_contract", &DotDimsAttr::getRhsContractingDimensions)); -} - -// Appends to `diagnostic_engine` a handler that appends all emitted errors to -// the `diagnostic` string. If `append_annotation_stack` is true, it will append -// current profiler annotation stack to the diagnostic message (annotation used -// in Xprof). -void AppendDiagnosticToString(runtime::DiagnosticEngine& diagnostic_engine, - std::string* diagnostic, - bool append_annotation_stack = false); - -// Sets the current tracing scope that will be added to all emitted diagnostics. -void SetCurrentTracingScope(std::string_view scope); -void ResetCurrentTracingScope(); - -} // namespace gpu -} // namespace xla - -namespace xla { -namespace runtime { - -// using llvm::ArrayRef; - -XLA_RUNTIME_REGISTER_AGGREGATE_ATTR_DECODING( - xla::gpu::DotDimensionNumbers, - AggregateMember>("lhs_batch"), - AggregateMember>("lhs_contract"), - AggregateMember>("rhs_batch"), - AggregateMember>("rhs_contract")); - -} // namespace runtime -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_SUPPORT_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/topk.cc b/third_party/xla/xla/service/gpu/runtime/topk.cc deleted file mode 100644 index 197bd68bbd2c60..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/topk.cc +++ /dev/null @@ -1,68 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/topk.h" - -#include - -#include - -#include "absl/status/status.h" -#include "absl/types/span.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/runtime/executable.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/gpu/runtime/topk_kernel.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/stream_executor/gpu/gpu_stream.h" -#include "xla/types.h" -#include "xla/xla_data.pb.h" - -namespace xla::gpu { -using ::xla::runtime::CustomCall; -using ::xla::runtime::StridedMemrefView; - -static absl::Status TopkImpl(const ServiceExecutableRunOptions* run_options, - StridedMemrefView data, - StridedMemrefView top_elements, - StridedMemrefView indices) { - if (data.sizes.size() > 2) - return absl::InvalidArgumentError("Invalid input shape"); - if (indices.dtype != PrimitiveType::S32) - return absl::InvalidArgumentError("Indices should be S32"); - bool has_batch = data.sizes.size() == 2; - size_t batch_size = has_batch ? data.sizes[0] : 1; - size_t n = has_batch ? data.sizes[1] : data.sizes[0]; - size_t k = has_batch ? top_elements.sizes[1] : top_elements.sizes[0]; - return RunTopk(run_options->stream(), data.dtype, GetDeviceAddress(data), n, - GetDeviceAddress(top_elements), GetDeviceAddress(indices), k, - batch_size); -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - Topk, FunctionWrapper(), checks, - CustomCall::Bind("__gpu$TopK") - .UserData() - .Arg() // input - .Arg() // output (values) - .Arg() // output (indices) -); - -void RegisterTopkCustomCall(runtime::DirectCustomCallRegistry& registry) { - registry.Register("__gpu$TopK", Topk); -} - -} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/runtime/tracing.cc b/third_party/xla/xla/service/gpu/runtime/tracing.cc deleted file mode 100644 index f2f32552d6b0e2..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/tracing.cc +++ /dev/null @@ -1,94 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/tracing.h" - -#include - -#include "absl/status/statusor.h" -#include "absl/strings/str_format.h" -#include "xla/runtime/executable.h" -#include "xla/runtime/tracing.h" -#include "xla/service/gpu/runtime/support.h" -#include "tsl/profiler/lib/scoped_annotation_stack.h" - -namespace xla { -namespace gpu { - -using ::xla::runtime::CustomCall; -using ::xla::runtime::HloTrace; - -using ::tsl::profiler::ScopedAnnotationStack; - -//===----------------------------------------------------------------------===// -// Type names for encoded attributes. -//===----------------------------------------------------------------------===// - -void RegisterTracingTypeIdNames(runtime::TypeIDNameRegistry& registry) { - runtime::PopulateTraceTypeIdNames(registry); -} - -//===----------------------------------------------------------------------===// -// Tracing custom calls implementation. -//===----------------------------------------------------------------------===// - -namespace { -thread_local const ModuleAnnotations* current_annotations{}; -} - -static absl::StatusOr ActivityStart(runtime::HloTrace annotation) { - SetCurrentTracingScope(annotation.hlo_op); - if (current_annotations) { - // We know which HloModule we belong to, and may have pre-prepared - // annotation structs ready to use - const auto iter = current_annotations->kernels.find(annotation.hlo_op); - if (iter != current_annotations->kernels.end()) { - // Have a pre-prepared annotation, use it - return ScopedAnnotationStack::ActivityStart([&] { return iter->second; }); - } - } - return ScopedAnnotationStack::ActivityStart([&] { - // We use the same tracing annotation scheme as the ThunkSequence. - return absl::StrFormat("Thunk:#hlo_op=%s#", annotation.hlo_op); - }); -} - -static absl::Status ActivityEnd(int64_t activity_id) { - ResetCurrentTracingScope(); - ScopedAnnotationStack::ActivityEnd(activity_id); - return absl::OkStatus(); -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL(Start, FunctionWrapper(), checks, - CustomCall::Bind("xla.trace.activity_start") - .Attr("annotation") - .Ret()); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - End, FunctionWrapper(), checks, - CustomCall::Bind("xla.trace.activity_end").Arg()); - -void RegisterTracingCustomCalls(runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.trace.activity_start", Start); - registry.Register("xla.trace.activity_end", End); -} - -const ModuleAnnotations* SetCurrentModuleAnnotations( - const ModuleAnnotations* annotations) { - return std::exchange(current_annotations, annotations); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/tracing.h b/third_party/xla/xla/service/gpu/runtime/tracing.h deleted file mode 100644 index c59db877e49f8d..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/tracing.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_TRACING_H_ -#define XLA_SERVICE_GPU_RUNTIME_TRACING_H_ - -#include - -#include "xla/runtime/custom_call_registry.h" -#include "xla/runtime/type_id.h" -#include "xla/service/gpu/runtime/annotation.h" - -namespace xla { -namespace gpu { - -void RegisterTracingTypeIdNames(runtime::TypeIDNameRegistry& registry); - -void RegisterTracingCustomCalls(runtime::DirectCustomCallRegistry& registry); - -const ModuleAnnotations* SetCurrentModuleAnnotations( - const ModuleAnnotations* annotations); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_TRACING_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/triangular_solve.cc b/third_party/xla/xla/service/gpu/runtime/triangular_solve.cc deleted file mode 100644 index 264dcbde686f41..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/triangular_solve.cc +++ /dev/null @@ -1,137 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/triangular_solve.h" - -#include -#include -#include - -#include "xla/runtime/custom_call.h" -#include "xla/service/gpu/gpu_asm_opts_util.h" -#include "xla/service/gpu/runtime/support.h" -#include "tsl/platform/human_readable_json.h" - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/service/gpu/runtime3/triangular_solve_thunk.h" -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -namespace xla { -namespace gpu { - -using xla::runtime::CustomCall; - -using mlir::failure; -using mlir::FailureOr; - -absl::Status TriangularSolve::run( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, CustomCall::RemainingArgs args, - std::string_view backend_config) { - TriangularSolve handler = TriangularSolve::Handler(); - - if (args.size() != 4) - return absl::InvalidArgumentError( - absl::StrFormat("Expected 4 arguments, got %d", args.size())); - - // Check if all arguments have the correct type. - auto a = args.get(0); - auto b = args.get(1); - auto result = args.get(2); - auto temp = args.get(3); - if (failed(a) || failed(b) || failed(result) || failed(temp)) - return absl::InvalidArgumentError("Incorrect argument types"); - - // Parse backend config string. - TriangularSolveOptions opts; - - const std::string backend_config_str = - std::string(backend_config.data(), backend_config.length()); - - TF_RETURN_IF_ERROR(tsl::HumanReadableJsonToProto(backend_config_str, &opts)); - - return handler(run_options, debug_options, *a, *b, *result, *temp, - opts.left_side(), opts.lower(), opts.unit_diagonal(), - opts.transpose_a()); -} - -absl::Status TriangularSolve::operator()( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, runtime::StridedMemrefView a, - runtime::StridedMemrefView b, runtime::StridedMemrefView result, - runtime::FlatMemrefView temp, bool left_side, bool lower, - bool unit_diagonal, TriangularSolveOptions::Transpose transpose_a) const { -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - se::Stream* stream = run_options->stream(); - - se::DeviceMemoryBase a_data = GetDeviceAddress(a); - se::DeviceMemoryBase b_data = GetDeviceAddress(b); - se::DeviceMemoryBase result_data = GetDeviceAddress(result); - se::DeviceMemoryBase temp_data = GetDeviceAddress(temp); - - // Triangular solve is in-place on 'b', so copy 'b' to the output if they - // aren't the same buffer. - if (b.data != result.data) - stream->ThenMemcpy(&result_data, b_data, b_data.size()); - - Shape b_shape = ToShape(b); - int64_t m = b_shape.dimensions(b_shape.rank() - 2); - int64_t n = b_shape.dimensions(b_shape.rank() - 1); - int64_t batch_size = std::accumulate( - b_shape.dimensions().begin(), b_shape.dimensions().end() - 2, int64_t{1}, - [](int64_t a, int64_t b) { return a * b; }); - - PrimitiveType elem_type = b.dtype; - int64_t elem_size = ShapeUtil::ByteSizeOfPrimitiveType(elem_type); - int64_t a_batch_stride = left_side ? m * m * elem_size : n * n * elem_size; - int64_t b_batch_stride = m * n * elem_size; - - using Side = se::blas::Side; - using Diagonal = se::blas::Diagonal; - using Transpose = se::blas::Transpose; - using UpperLower = se::blas::UpperLower; - - // Convert custom call attributes to se::blas enums. - UpperLower uplo = lower ? UpperLower::kLower : UpperLower::kUpper; - Side side = left_side ? Side::kLeft : Side::kRight; - Diagonal diagonal = unit_diagonal ? Diagonal::kUnit : Diagonal::kNonUnit; - - auto transpose = [&]() -> mlir::FailureOr { - switch (transpose_a) { - case TriangularSolveOptions::NO_TRANSPOSE: - return se::blas::Transpose::kNoTranspose; - case TriangularSolveOptions::TRANSPOSE: - return se::blas::Transpose::kTranspose; - case TriangularSolveOptions::ADJOINT: - return se::blas::Transpose::kConjugateTranspose; - default: - return failure(); - } - }(); - - if (failed(transpose)) - return absl::InternalError("Failed to convert transpose type"); - - return RunTriangularSolve(a_data, result_data, temp_data, - PtxOptsFromDebugOptions(*debug_options), uplo, side, - diagonal, *transpose, elem_type, batch_size, m, n, - a_batch_stride, b_batch_stride, stream); -#else // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - return absl::InternalError("Not implemented without Gpu"); -#endif -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/triangular_solve.h b/third_party/xla/xla/service/gpu/runtime/triangular_solve.h deleted file mode 100644 index 3f5d9c22f0eb92..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/triangular_solve.h +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_TRIANGULAR_SOLVE_H_ -#define XLA_SERVICE_GPU_RUNTIME_TRIANGULAR_SOLVE_H_ - -#include - -#include "xla/runtime/custom_call.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/xla.pb.h" - -namespace xla { -namespace gpu { - -using runtime::CustomCall; - -struct TriangularSolve { - // Adaptor from XlaCustomCall API to properly typed TriangularSolve handler. - static absl::Status run(const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, - CustomCall::RemainingArgs args, - std::string_view backend_config); - - absl::Status operator()(const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, - runtime::StridedMemrefView a, - runtime::StridedMemrefView b, - runtime::StridedMemrefView result, - runtime::FlatMemrefView temp, bool left_side, - bool lower, bool unit_diagonal, - TriangularSolveOptions::Transpose transpose_a) const; - - static TriangularSolve Handler() { return TriangularSolve(); } -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_TRIANGULAR_SOLVE_H_ diff --git a/third_party/xla/xla/service/gpu/runtime3/triangular_solve_thunk.cc b/third_party/xla/xla/service/gpu/runtime/triangular_solve_thunk.cc similarity index 70% rename from third_party/xla/xla/service/gpu/runtime3/triangular_solve_thunk.cc rename to third_party/xla/xla/service/gpu/runtime/triangular_solve_thunk.cc index 81607a93f2498e..8be15a1846e12b 100644 --- a/third_party/xla/xla/service/gpu/runtime3/triangular_solve_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/triangular_solve_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime3/triangular_solve_thunk.h" +#include "xla/service/gpu/runtime/triangular_solve_thunk.h" #include "absl/status/status.h" #include "absl/strings/str_format.h" @@ -99,49 +99,43 @@ absl::Status RunTriangularSolve( const int lda = side == se::blas::Side::kLeft ? m : n; const int ldb = m; + auto blas = stream->parent()->AsBlas(); + if (blas == nullptr) { + return absl::InternalError("No BLAS support in stream."); + } bool launch_ok; if (batch_size == 1) { switch (type) { case F32: { se::DeviceMemory b_data_typed(b_data); - launch_ok = - stream - ->ThenBlasTrsm(side, uplo, transpose_a, unit_diagonal, m, n, - /*alpha=*/1.0f, se::DeviceMemory(a_data), - lda, &b_data_typed, ldb) - .ok(); + launch_ok = blas->DoBlasTrsm( + stream, side, uplo, transpose_a, unit_diagonal, m, n, + /*alpha=*/1.0f, se::DeviceMemory(a_data), lda, &b_data_typed, + ldb); break; } case F64: { se::DeviceMemory b_data_typed(b_data); - launch_ok = - stream - ->ThenBlasTrsm(side, uplo, transpose_a, unit_diagonal, m, n, - /*alpha=*/1.0, se::DeviceMemory(a_data), - lda, &b_data_typed, ldb) - .ok(); + launch_ok = blas->DoBlasTrsm( + stream, side, uplo, transpose_a, unit_diagonal, m, n, + /*alpha=*/1.0, se::DeviceMemory(a_data), lda, &b_data_typed, + ldb); break; } case C64: { se::DeviceMemory> b_data_typed(b_data); - launch_ok = - stream - ->ThenBlasTrsm(side, uplo, transpose_a, unit_diagonal, m, n, - /*alpha=*/1.0f, - se::DeviceMemory>(a_data), - lda, &b_data_typed, ldb) - .ok(); + launch_ok = blas->DoBlasTrsm( + stream, side, uplo, transpose_a, unit_diagonal, m, n, + /*alpha=*/1.0f, se::DeviceMemory>(a_data), lda, + &b_data_typed, ldb); break; } case C128: { se::DeviceMemory> b_data_typed(b_data); - launch_ok = - stream - ->ThenBlasTrsm(side, uplo, transpose_a, unit_diagonal, m, n, - /*alpha=*/1.0, - se::DeviceMemory>(a_data), - lda, &b_data_typed, ldb) - .ok(); + launch_ok = blas->DoBlasTrsm( + stream, side, uplo, transpose_a, unit_diagonal, m, n, + /*alpha=*/1.0, se::DeviceMemory>(a_data), lda, + &b_data_typed, ldb); break; } default: @@ -167,46 +161,34 @@ absl::Status RunTriangularSolve( switch (type) { case F32: { se::DeviceMemory typed_b_pointers(b_pointers); - launch_ok = - stream - ->ThenBlasTrsmBatched(side, uplo, transpose_a, unit_diagonal, m, - n, /*alpha=*/1.0f, - se::DeviceMemory(a_pointers), lda, - &typed_b_pointers, ldb, batch_size) - .ok(); + launch_ok = blas->DoBlasTrsmBatched( + stream, side, uplo, transpose_a, unit_diagonal, m, n, + /*alpha=*/1.0f, se::DeviceMemory(a_pointers), lda, + &typed_b_pointers, ldb, batch_size); break; } case F64: { se::DeviceMemory typed_b_pointers(b_pointers); - launch_ok = - stream - ->ThenBlasTrsmBatched(side, uplo, transpose_a, unit_diagonal, m, - n, /*alpha=*/1.0f, - se::DeviceMemory(a_pointers), - lda, &typed_b_pointers, ldb, batch_size) - .ok(); + launch_ok = blas->DoBlasTrsmBatched( + stream, side, uplo, transpose_a, unit_diagonal, m, n, + /*alpha=*/1.0f, se::DeviceMemory(a_pointers), lda, + &typed_b_pointers, ldb, batch_size); break; } case C64: { se::DeviceMemory*> typed_b_pointers(b_pointers); - launch_ok = stream - ->ThenBlasTrsmBatched( - side, uplo, transpose_a, unit_diagonal, m, n, - /*alpha=*/1.0f, - se::DeviceMemory*>(a_pointers), - lda, &typed_b_pointers, ldb, batch_size) - .ok(); + launch_ok = blas->DoBlasTrsmBatched( + stream, side, uplo, transpose_a, unit_diagonal, m, n, + /*alpha=*/1.0f, se::DeviceMemory*>(a_pointers), + lda, &typed_b_pointers, ldb, batch_size); break; } case C128: { se::DeviceMemory*> typed_b_pointers(b_pointers); - launch_ok = stream - ->ThenBlasTrsmBatched( - side, uplo, transpose_a, unit_diagonal, m, n, - /*alpha=*/1.0f, - se::DeviceMemory*>(a_pointers), - lda, &typed_b_pointers, ldb, batch_size) - .ok(); + launch_ok = blas->DoBlasTrsmBatched( + stream, side, uplo, transpose_a, unit_diagonal, m, n, + /*alpha=*/1.0f, se::DeviceMemory*>(a_pointers), + lda, &typed_b_pointers, ldb, batch_size); break; } default: diff --git a/third_party/xla/xla/service/gpu/runtime3/triangular_solve_thunk.h b/third_party/xla/xla/service/gpu/runtime/triangular_solve_thunk.h similarity index 94% rename from third_party/xla/xla/service/gpu/runtime3/triangular_solve_thunk.h rename to third_party/xla/xla/service/gpu/runtime/triangular_solve_thunk.h index 451a004b949462..e55ffb4da36a84 100644 --- a/third_party/xla/xla/service/gpu/runtime3/triangular_solve_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/triangular_solve_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME3_TRIANGULAR_SOLVE_THUNK_H_ -#define XLA_SERVICE_GPU_RUNTIME3_TRIANGULAR_SOLVE_THUNK_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_TRIANGULAR_SOLVE_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_TRIANGULAR_SOLVE_THUNK_H_ #include @@ -83,4 +83,4 @@ absl::Status RunTriangularSolve( } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_RUNTIME3_TRIANGULAR_SOLVE_THUNK_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_TRIANGULAR_SOLVE_THUNK_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/wait_for_streams_thunk.cc b/third_party/xla/xla/service/gpu/runtime/wait_for_streams_thunk.cc new file mode 100644 index 00000000000000..a63cf493c1a461 --- /dev/null +++ b/third_party/xla/xla/service/gpu/runtime/wait_for_streams_thunk.cc @@ -0,0 +1,47 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/wait_for_streams_thunk.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "xla/service/gpu/thunk.h" +#include "tsl/platform/errors.h" + +namespace xla::gpu { + +absl::Status WaitForStreamsThunk::ExecuteOnStream(const ExecuteParams& params) { + TF_ASSIGN_OR_RETURN(se::Stream * stream, + Thunk::GetStreamForExecution(stream_id_, params)); + + VLOG(5) << "Waiting for stream ids: " + << absl::StrJoin( + wait_for_stream_ids_, ", ", + [&](std::string* s, const ExecutionStreamId& stream_id) { + absl::StrAppend(s, stream_id.value()); + }); + for (const auto& stream_id : wait_for_stream_ids_) { + TF_ASSIGN_OR_RETURN(se::Stream * wait_on_stream, + Thunk::GetStreamForExecution(stream_id, params)); + + TF_RETURN_IF_ERROR(stream->WaitFor(wait_on_stream)); + } + return absl::OkStatus(); +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/runtime/wait_for_streams_thunk.h b/third_party/xla/xla/service/gpu/runtime/wait_for_streams_thunk.h new file mode 100644 index 00000000000000..4835b56b7e1e59 --- /dev/null +++ b/third_party/xla/xla/service/gpu/runtime/wait_for_streams_thunk.h @@ -0,0 +1,53 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_WAIT_FOR_STREAMS_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_WAIT_FOR_STREAMS_THUNK_H_ + +#include + +#include "absl/status/status.h" +#include "xla/service/gpu/thunk.h" + +namespace xla::gpu { + +// This thunk +class WaitForStreamsThunk : public Thunk { + public: + WaitForStreamsThunk(ThunkInfo thunk_info, ExecutionStreamId stream_id, + std::vector wait_for_stream_ids) + : Thunk(Kind::kWaitForStreams, thunk_info), + stream_id_(stream_id), + wait_for_stream_ids_(wait_for_stream_ids){}; + + WaitForStreamsThunk(const WaitForStreamsThunk&) = delete; + WaitForStreamsThunk& operator=(const WaitForStreamsThunk&) = delete; + + const ExecutionStreamId& stream_id() const { return stream_id_; } + + const std::vector& wait_for_stream_ids() const { + return wait_for_stream_ids_; + } + + absl::Status ExecuteOnStream(const ExecuteParams& params) override; + + private: + ExecutionStreamId stream_id_; + std::vector wait_for_stream_ids_; +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_RUNTIME_WAIT_FOR_STREAMS_THUNK_H_ diff --git a/third_party/xla/xla/service/gpu/runtime3/while_thunk.cc b/third_party/xla/xla/service/gpu/runtime/while_thunk.cc similarity index 77% rename from third_party/xla/xla/service/gpu/runtime3/while_thunk.cc rename to third_party/xla/xla/service/gpu/runtime/while_thunk.cc index 51dc58b45ffc36..0f933aad078f50 100644 --- a/third_party/xla/xla/service/gpu/runtime3/while_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/while_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime3/while_thunk.h" +#include "xla/service/gpu/runtime/while_thunk.h" #include #include @@ -22,12 +22,15 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/runtime3/sequential_thunk.h" +#include "xla/service/gpu/runtime/sequential_thunk.h" #include "xla/service/gpu/thunk.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/memory_allocation.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -57,6 +60,14 @@ absl::Status WhileThunk::Prepare(const PrepareParams& params, absl::Status WhileThunk::Initialize(const InitializeParams& params) { TF_RETURN_IF_ERROR(condition_thunk_sequence_->Initialize(params)); TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(params)); + + absl::MutexLock lock(&mutex_); + if (auto it = predicates_.find(params.executor); it == predicates_.end()) { + TF_ASSIGN_OR_RETURN(std::unique_ptr allocation, + params.executor->HostMemoryAllocate(sizeof(bool))); + predicates_.emplace(params.executor, std::move(allocation)); + } + return absl::OkStatus(); } @@ -78,22 +89,29 @@ absl::Status WhileThunk::ExecuteOnStream(const ExecuteParams& params) { int64_t iter = 0; + // Get memory allocation for copying condition result from device. + bool* condition_result = [&] { + absl::MutexLock lock(&mutex_); + return reinterpret_cast(predicates_.at(stream.parent())->opaque()); + }(); + while (true) { VLOG(3) << "Executing WhileThunk condition computation; iter=" << iter; TF_RETURN_IF_ERROR(condition_thunk_sequence_->ExecuteOnStream(params)); // Copy the result of condition computation and break the loop if 'false'. - bool condition_result; - stream.ThenMemcpy(&condition_result, condition_result_data, sizeof(bool)); - VLOG(3) << "condition_result = " << condition_result; + TF_RETURN_IF_ERROR( + stream.Memcpy(condition_result, condition_result_data, sizeof(bool))); + if (absl::Status blocked = stream.BlockHostUntilDone(); !blocked.ok()) { return absl::InternalError(absl::StrFormat( "Failed to complete all kernels launched on stream %p: %s", &stream, blocked.message())); } - if (!condition_result) { - VLOG(3) << "Break WHileThunk loop; iter=" << iter; + VLOG(3) << "condition_result = " << *condition_result; + if (!*condition_result) { + VLOG(3) << "Break WhileThunk loop; iter=" << iter; break; } diff --git a/third_party/xla/xla/service/gpu/runtime3/while_thunk.h b/third_party/xla/xla/service/gpu/runtime/while_thunk.h similarity index 81% rename from third_party/xla/xla/service/gpu/runtime3/while_thunk.h rename to third_party/xla/xla/service/gpu/runtime/while_thunk.h index 04a9c5b4756239..ec48b0c5714b04 100644 --- a/third_party/xla/xla/service/gpu/runtime3/while_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/while_thunk.h @@ -13,17 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME3_WHILE_THUNK_H_ -#define XLA_SERVICE_GPU_RUNTIME3_WHILE_THUNK_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_WHILE_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_WHILE_THUNK_H_ #include #include #include +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" +#include "absl/synchronization/mutex.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/runtime3/sequential_thunk.h" +#include "xla/service/gpu/runtime/sequential_thunk.h" #include "xla/service/gpu/thunk.h" +#include "xla/stream_executor/memory_allocation.h" +#include "xla/stream_executor/stream_executor.h" namespace xla { namespace gpu { @@ -75,9 +80,15 @@ class WhileThunk : public Thunk { std::unique_ptr condition_thunk_sequence_; std::unique_ptr body_thunk_sequence_; std::optional trip_count_; + + // Pinned host memory for transfering predicate value from device to host. + absl::Mutex mutex_; + absl::flat_hash_map> + predicates_ ABSL_GUARDED_BY(mutex_); }; } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_RUNTIME3_WHILE_THUNK_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_WHILE_THUNK_H_ diff --git a/third_party/xla/xla/service/gpu/runtime3/BUILD b/third_party/xla/xla/service/gpu/runtime3/BUILD deleted file mode 100644 index d2ddb7fea82053..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime3/BUILD +++ /dev/null @@ -1,612 +0,0 @@ -load("//xla/tests:build_defs.bzl", "xla_test") -load("//xla/service/gpu:build_defs.bzl", "get_cub_sort_kernel_types") -load("//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured") -load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") -load("@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], -) - -package_group( - name = "friends", - includes = ["//xla:friends"], -) - -#===-------------------------------------------------------------------------------------------===// -# Command Buffer Integration -#===-------------------------------------------------------------------------------------------===// - -cc_library( - name = "command_buffer_allocations", - srcs = ["command_buffer_allocations.cc"], - hdrs = ["command_buffer_allocations.h"], - visibility = ["//visibility:public"], - deps = [ - "//xla:status", - "//xla:statusor", - "//xla/service:buffer_assignment", - "//xla/service/gpu:buffer_allocations", - "//xla/stream_executor", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "command_buffer_cmd", - srcs = ["command_buffer_cmd.cc"], - hdrs = ["command_buffer_cmd.h"], - local_defines = if_cuda_is_configured([ - "GOOGLE_CUDA=1", - ]), - visibility = ["//visibility:public"], - deps = [ - ":custom_call_thunk", - "//xla:shape_util", - "//xla:status", - "//xla:types", - "//xla:util", - "//xla/service:buffer_assignment", - "//xla/service:collective_ops_utils", - "//xla/service:computation_placer", - "//xla/service:custom_call_status_internal", - "//xla/service:custom_call_status_public_headers", - "//xla/service:global_device_id", - "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:gpu_executable_run_options", - "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu:matmul_utils", - "//xla/service/gpu:nccl_api", - "//xla/service/gpu:nccl_clique", - "//xla/service/gpu:nccl_clique_key", - "//xla/service/gpu:nccl_collective_thunks", - "//xla/service/gpu:stream_executor_util", - "//xla/service/gpu:thunk", - "//xla/service/gpu/kernels:custom_kernel", - "//xla/stream_executor", - "//xla/stream_executor/gpu:gpu_stream_header", - "//xla/stream_executor/gpu:gpu_types_header", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/concurrency:ref_count", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "command_buffer_cmd_emitter", - srcs = ["command_buffer_cmd_emitter.cc"], - hdrs = ["command_buffer_cmd_emitter.h"], - visibility = ["//visibility:public"], - deps = [ - ":command_buffer_cmd", - ":conditional_thunk", - ":copy_thunk", - ":custom_call_thunk", - ":gemm_thunk", - ":kernel_thunk", - ":memset_thunk", - ":replica_id_thunk", - ":sequential_thunk", - ":while_thunk", - "//xla:status", - "//xla:statusor", - "//xla:util", - "//xla/service/gpu:nccl_collective_thunks", - "//xla/service/gpu:thunk", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "command_buffer_cmd_test", - srcs = if_gpu_is_configured(["command_buffer_cmd_test.cc"]), - backends = ["gpu"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), - deps = [ - ":command_buffer_cmd", - "//xla:status", - "//xla:types", - "//xla/service:buffer_assignment", - "//xla/service:platform_util", - "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu:thunk", - "//xla/stream_executor", - "//xla/stream_executor:multi_platform_manager", - "//xla/stream_executor:platform", - "//xla/stream_executor/gpu:gpu_test_kernels", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - ], -) - -#===-------------------------------------------------------------------------------------------===// -# XLA Thunks Runtime -#===-------------------------------------------------------------------------------------------===// - -cc_library( - name = "cholesky_thunk", - srcs = if_gpu_is_configured(["cholesky_thunk.cc"]), - hdrs = if_gpu_is_configured(["cholesky_thunk.h"]), - visibility = ["//visibility:public"], - deps = if_gpu_is_configured([ - "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:cusolver_context", - "//xla/service/gpu:make_batch_pointers", - "//xla/service/gpu:thunk", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/service:buffer_assignment", - "//xla/hlo/ir:hlo", - "@local_tsl//tsl/platform:logging", - "//xla/stream_executor", - "//xla/stream_executor:device_memory", - "//xla/stream_executor/gpu:gpu_asm_opts", - ]) + ["@local_tsl//tsl/platform:status"], -) - -cc_library( - name = "command_buffer_thunk", - srcs = ["command_buffer_thunk.cc"], - hdrs = ["command_buffer_thunk.h"], - visibility = ["//visibility:public"], - deps = [ - ":command_buffer_allocations", - ":command_buffer_cmd", - "//xla:status", - "//xla:statusor", - "//xla/service:buffer_assignment", # build_cleaner: keep - "//xla/service/gpu:buffer_allocations", # build_cleaner: keep - "//xla/service/gpu:thunk", - "//xla/stream_executor", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/profiler/lib:profiler_lock", - "@local_tsl//tsl/profiler/lib:scoped_annotation", - "@local_tsl//tsl/profiler/lib:traceme", - "@local_tsl//tsl/profiler/lib:traceme_encode", - ], -) - -xla_test( - name = "command_buffer_thunk_test", - srcs = if_gpu_is_configured(["command_buffer_thunk_test.cc"]), - backend_tags = { - "gpu_a100": ["config-cuda-only"], - "gpu_v100": ["config-cuda-only"], - }, - backends = [ - "gpu_a100", - "gpu_v100", - ], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), - deps = [ - ":command_buffer_allocations", - ":command_buffer_cmd", - ":command_buffer_thunk", - "//xla:shape_util", - "//xla:types", - "//xla/service:buffer_assignment", - "//xla/service:executable", - "//xla/service:platform_util", - "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu:matmul_utils", - "//xla/service/gpu:thunk", - "//xla/stream_executor", - "//xla/stream_executor:multi_platform_manager", - "//xla/stream_executor:platform", - "//xla/stream_executor/gpu:gpu_test_kernels", - "//xla/stream_executor/gpu:gpu_types_header", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "conditional_thunk", - srcs = ["conditional_thunk.cc"], - hdrs = ["conditional_thunk.h"], - visibility = ["//visibility:public"], - deps = [ - ":sequential_thunk", - "//xla:status", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:thunk", - "//xla/stream_executor", - "@com_google_absl//absl/status", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - ], -) - -cc_library( - name = "convolution_thunk", - srcs = ["convolution_thunk.cc"], - hdrs = ["convolution_thunk.h"], - visibility = ["//visibility:public"], - deps = [ - "//xla:util", - "//xla/service:buffer_assignment", - "//xla/service/gpu:gpu_conv_runner", - "//xla/service/gpu:thunk", - "//xla/stream_executor", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", - ], -) - -cc_library( - name = "copy_thunk", - srcs = ["copy_thunk.cc"], - hdrs = ["copy_thunk.h"], - visibility = ["//visibility:public"], - deps = [ - "//xla:status", - "//xla/service:buffer_assignment", - "//xla/service/gpu:thunk", - "//xla/stream_executor", - "@llvm-project//mlir:IR", - ], -) - -cc_library( - name = "cub_sort_thunk", - srcs = if_gpu_is_configured(["cub_sort_thunk.cc"]), - hdrs = if_gpu_is_configured(["cub_sort_thunk.h"]), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - visibility = ["//visibility:public"], - deps = if_gpu_is_configured([ - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "//xla/service:buffer_assignment", - "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:thunk", - "//xla/stream_executor:device_memory", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "@local_tsl//tsl/platform:errors", - ] + ["//xla/service/gpu:cub_sort_kernel_" + suffix for suffix in get_cub_sort_kernel_types()]), -) - -cc_library( - name = "custom_call_thunk", - srcs = ["custom_call_thunk.cc"], - hdrs = ["custom_call_thunk.h"], - local_defines = if_cuda_is_configured([ - "GOOGLE_CUDA=1", - ]), - visibility = ["//visibility:public"], - deps = [ - "//xla:executable_run_options", - "//xla:shape_util", - "//xla:status", - "//xla:util", - "//xla/ffi:call_frame", - "//xla/ffi:ffi_api", - "//xla/ffi/api:c_api", - "//xla/hlo/ir:hlo", - "//xla/service:buffer_assignment", - "//xla/service:custom_call_status", - "//xla/service:custom_call_status_internal", - "//xla/service:executable", - "//xla/service/gpu:thunk", - "//xla/stream_executor:device_memory", - "//xla/stream_executor/gpu:gpu_stream_header", - "//xla/stream_executor/gpu:gpu_types_header", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_library( - name = "fft_thunk", - srcs = ["fft_thunk.cc"], - hdrs = ["fft_thunk.h"], - visibility = ["//visibility:public"], - deps = [ - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:buffer_assignment", - "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:thunk", - "//xla/stream_executor", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", - ], -) - -cc_library( - name = "fused_mha_thunk", - srcs = ["fused_mha_thunk.cc"], - hdrs = ["fused_mha_thunk.h"], - visibility = ["//visibility:public"], - deps = [ - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/service:buffer_assignment", - "//xla/service/gpu:gpu_fused_mha_runner", - "//xla/service/gpu:thunk", - "//xla/stream_executor", - "@com_google_absl//absl/container:flat_hash_map", - ], -) - -cc_library( - name = "gemm_thunk", - srcs = ["gemm_thunk.cc"], - hdrs = ["gemm_thunk.h"], - visibility = ["//visibility:public"], - deps = [ - "//xla:status", - "//xla/service:buffer_assignment", - "//xla/service/gpu:matmul_utils", - "//xla/service/gpu:thunk", - "//xla/stream_executor:device_memory", - "@com_google_absl//absl/status", - "@local_tsl//tsl/platform:logging", - ], -) - -cc_library( - name = "gpublas_lt_matmul_thunk", - srcs = if_gpu_is_configured(["gpublas_lt_matmul_thunk.cc"]), - hdrs = if_gpu_is_configured(["gpublas_lt_matmul_thunk.h"]), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - visibility = ["//visibility:public"], - deps = if_gpu_is_configured([ - "//xla/service:buffer_assignment", - "//xla/service/gpu:matmul_utils", - "//xla/service/gpu:thunk", - "//xla:status", - "//xla/stream_executor:device_memory", - "//xla/stream_executor", - "@local_tsl//tsl/platform:logging", - ]), -) - -cc_library( - name = "infeed_thunk", - srcs = ["infeed_thunk.cc"], - hdrs = ["infeed_thunk.h"], - visibility = ["//visibility:public"], - deps = [ - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:io_feed_manager", - "//xla/service/gpu:thunk", - "//xla/stream_executor", - "@com_google_absl//absl/status", - ], -) - -cc_library( - name = "kernel_thunk", - srcs = ["kernel_thunk.cc"], - hdrs = ["kernel_thunk.h"], - visibility = ["//visibility:public"], - deps = [ - "//xla:status", - "//xla:types", - "//xla/hlo/ir:hlo", - "//xla/service:buffer_assignment", - "//xla/service/gpu:kernel_arguments", - "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu:stream_executor_util", - "//xla/service/gpu:thunk", - "//xla/service/gpu/kernels:custom_kernel", - "//xla/stream_executor", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:span", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "memset_thunk", - srcs = ["memset_thunk.cc"], - hdrs = ["memset_thunk.h"], - visibility = ["//visibility:public"], - deps = [ - "//xla:status", - "//xla/service:buffer_assignment", - "//xla/service/gpu:thunk", - "//xla/stream_executor", - "@com_google_absl//absl/status", - ], -) - -cc_library( - name = "norm_thunk", - srcs = ["norm_thunk.cc"], - hdrs = ["norm_thunk.h"], - visibility = ["//visibility:public"], - deps = [ - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/service:buffer_assignment", - "//xla/service/gpu:gpu_norm_runner", - "//xla/service/gpu:thunk", - "//xla/stream_executor", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", - ], -) - -cc_library( - name = "outfeed_thunk", - srcs = ["outfeed_thunk.cc"], - hdrs = ["outfeed_thunk.h"], - visibility = ["//visibility:public"], - deps = [ - "//xla:util", - "//xla/service/gpu:io_feed_manager", - "//xla/service/gpu:thunk", - "//xla/stream_executor", - "@com_google_absl//absl/status", - ], -) - -cc_library( - name = "replica_id_thunk", - srcs = ["replica_id_thunk.cc"], - hdrs = ["replica_id_thunk.h"], - visibility = ["//visibility:public"], - deps = [ - "//xla/service:buffer_assignment", - "//xla/service:global_device_id", - "//xla/service/gpu:thunk", - "@com_google_absl//absl/status", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "sequential_thunk", - srcs = ["sequential_thunk.cc"], - hdrs = ["sequential_thunk.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - visibility = ["//visibility:public"], - deps = [ - "//xla:status", - "//xla/hlo/ir:hlo", - "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:thunk", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/profiler/lib:scoped_annotation", - ], -) - -cc_library( - name = "send_recv_thunk", - srcs = ["send_recv_thunk.cc"], - hdrs = ["send_recv_thunk.h"], - visibility = ["//visibility:public"], - deps = [ - "//xla:shape_util", - "//xla:status", - "//xla:statusor", - "//xla:xla_data_proto_cc", - "//xla/service:buffer_assignment", - "//xla/service:global_device_id", - "//xla/service/gpu:thunk", - "//xla/stream_executor", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/concurrency:async_value", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/profiler/lib:traceme", - ], -) - -cc_library( - name = "triangular_solve_thunk", - srcs = if_gpu_is_configured(["triangular_solve_thunk.cc"]), - hdrs = if_gpu_is_configured(["triangular_solve_thunk.h"]), - visibility = ["//visibility:public"], - deps = if_gpu_is_configured([ - "@com_google_absl//absl/strings:str_format", - "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:make_batch_pointers", - "//xla/service/gpu:thunk", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/service:buffer_assignment", - "//xla/hlo/ir:hlo", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "//xla/stream_executor", - "//xla/stream_executor:device_memory", - "//xla/stream_executor/gpu:gpu_asm_opts", - ]) + [ - "@com_google_absl//absl/status", - "@local_tsl//tsl/platform:status", - ], -) - -cc_library( - name = "while_thunk", - srcs = ["while_thunk.cc"], - hdrs = ["while_thunk.h"], - visibility = ["//visibility:public"], - deps = [ - ":sequential_thunk", - "//xla:status", - "//xla/hlo/ir:hlo", - "//xla/service:buffer_assignment", - "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:thunk", - "//xla/stream_executor", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - ], -) diff --git a/third_party/xla/xla/service/gpu/runtime3/README.md b/third_party/xla/xla/service/gpu/runtime3/README.md deleted file mode 100644 index 351de805194d9a..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime3/README.md +++ /dev/null @@ -1,8 +0,0 @@ -# XLA:GPU Runtime Under Construction - -This is a temporary folder to consolidate and clean up the Thunk-based XLA:GPU -runtime (right now it's all over the xla/servive/gpu folder), with a goal to -eventually delete `runtime` and `runtime2` folders and make `runtime3` the -default and only XLA:GPU runtime. - -Preliminary timeline for completion is late Q4 2023 - early Q1 2024. diff --git a/third_party/xla/xla/service/gpu/runtime_intrinsics.cc b/third_party/xla/xla/service/gpu/runtime_intrinsics.cc index 13c81ba11e4737..94795cad0fb712 100644 --- a/third_party/xla/xla/service/gpu/runtime_intrinsics.cc +++ b/third_party/xla/xla/service/gpu/runtime_intrinsics.cc @@ -29,8 +29,8 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/status.h" #include "xla/statusor.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/util.h" @@ -51,7 +51,7 @@ absl::Status AssertOnGpu(void* stream_handle, void* buffer, absl::string_view error_msg) { TF_ASSIGN_OR_RETURN( se::Platform * platform, - se::MultiPlatformManager::PlatformWithName(GetGpuPlatformName())); + se::PlatformManager::PlatformWithName(GetGpuPlatformName())); se::StreamExecutorConfig config; config.gpu_stream = stream_handle; TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, @@ -64,9 +64,9 @@ absl::Status AssertOnGpu(void* stream_handle, void* buffer, int8_t expected = false; int64_t byte_size = sizeof(int8_t); CHECK_EQ(byte_size, ShapeUtil::ByteSizeOfPrimitiveType(PrimitiveType::PRED)); - stream->ThenMemcpy( + TF_RETURN_IF_ERROR(stream->Memcpy( &expected, se::DeviceMemoryBase{buffer, static_cast(byte_size)}, - byte_size); + byte_size)); TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); if (!static_cast(expected)) { return Internal("%s", error_msg); diff --git a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc b/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc index 0401ae5b546d3f..509a0d9289dd4b 100644 --- a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc +++ b/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc @@ -13,8 +13,8 @@ limitations under the License. #include "xla/service/gpu/softmax_rewriter_triton.h" #include -#include #include +#include #include #include "absl/algorithm/container.h" @@ -34,11 +34,11 @@ limitations under the License. #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/triton_support.h" +#include "xla/service/instruction_fusion.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_description.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -231,81 +231,10 @@ bool IsTriviallyConnectedProducerOf( bool IsTritonSupportedComputation(const HloComputation* computation, const se::GpuComputeCapability& gpu_version) { - for (const HloInstruction* instr : computation->instructions()) { - if (!IsTritonSupportedInstruction(instr, gpu_version)) { - return false; - } - } - return true; -} - -std::optional MatchesTritonCompatibleClosedReductionDiamond( - HloInstruction* instr, const se::GpuComputeCapability& gpu_version) { - // Return the producer of the following pattern: - // - // producer - // | \ - // | reduce_{max,sum,...} - // | | - // | broadcast - // | / - // binop (elementwise) - // - // where each edge is allowed to contain also trivial operations that can be - // generated by Triton. We mean by "trivial" here those operations that do not - // increase the amount of memory read/written by the fusion, and that are - // compatible with any chosen tiling. - // - // We also assume that the reduction is done on the last axis of the producer - // array. - std::optional match_failure = std::nullopt; - - if (!instr->IsElementwiseBinary() || - !IsTritonSupportedInstruction(instr, gpu_version)) { - return match_failure; - } - - HloInstruction* producer; - HloInstruction* broadcast; - HloInstruction* reduce; - - if (!(TrivialEdge(&broadcast, instr->mutable_operand(1), - HloOpcode::kBroadcast, gpu_version) && - TrivialEdge(&reduce, broadcast->mutable_operand(0), HloOpcode::kReduce, - gpu_version) && - HasDefaultLayout(broadcast->shape()) && - HasDefaultLayout(reduce->shape()) && reduce->operand_count() == 2 && - reduce->operand(1)->opcode() == HloOpcode::kConstant && - IsTritonSupportedComputation(reduce->to_apply(), gpu_version))) { - return match_failure; - } - - if (!HasOneUse(broadcast) || !HasOneUse(reduce)) { - return match_failure; - } - - producer = reduce->mutable_operand(0); - - if (!(reduce->dimensions().size() == 1 && - reduce->dimensions(0) == producer->shape().rank() - 1 && - !absl::c_linear_search(broadcast->dimensions(), - broadcast->shape().rank() - 1))) { - return match_failure; - } - - while (IsTriviallyFusible(producer, gpu_version)) { - producer = ChooseOperandForFusionProcessing(producer); - } - - if (!HasDefaultLayout(producer->shape()) || - !IsTriviallyConnectedProducerOf(producer, instr->mutable_operand(0), - gpu_version) || - !(producer == instr->operand(0) || - instr->operand(0)->user_count() == 1)) { - return match_failure; - } - - return producer; + return absl::c_all_of( + computation->instructions(), [&](const HloInstruction* instr) { + return IsTritonSupportedInstruction(instr, gpu_version); + }); } // Finds the first non-fusible producer of a diamond. This instruction is either @@ -391,6 +320,88 @@ using DiamondDescriptor = DiamondChainDescriptor; } // anonymous namespace +DiamondMatchingDecision +SoftmaxRewriterTriton::MatchesTritonCompatibleClosedReductionDiamond( + HloInstruction* instr) const { + if (!instr->IsElementwiseBinary()) { + return "Root is not elementwise binary."; + } + + if (!IsTritonSupportedInstruction(instr, gpu_version_)) { + return "Root is not supported for Triton instruction."; + } + + HloInstruction* producer; + HloInstruction* broadcast; + HloInstruction* reduce; + + if (!TrivialEdge(&broadcast, instr->mutable_operand(1), HloOpcode::kBroadcast, + gpu_version_)) { + return "Could not find a trivial connection from root to a broadcast."; + } + + if (!TrivialEdge(&reduce, broadcast->mutable_operand(0), HloOpcode::kReduce, + gpu_version_)) { + return "Could not find a trivial connection from matched broadcast to a " + "reduction."; + } + + if (!(HasDefaultLayout(broadcast->shape()) && + HasDefaultLayout(reduce->shape()))) { + return "Broadcast or reduce have non-default layouts."; + } + + if (!(reduce->operand_count() == 2 && + reduce->operand(1)->opcode() == HloOpcode::kConstant)) { + return "Reduce has a non-constant second operand and/or is variadic."; + } + + if (!(IsTritonSupportedComputation(reduce->to_apply(), gpu_version_))) { + return "Unsupported reduction by Triton."; + } + + if (!HasOneUse(broadcast) || !HasOneUse(reduce)) { + return "More than one use of broadcast or reduce."; + } + + producer = reduce->mutable_operand(0); + + if (reduce->dimensions().size() != 1 || + reduce->dimensions(0) != producer->shape().rank() - 1) { + return "Reduction is not a row-reduction of a single operand."; + } + + if (absl::c_linear_search(broadcast->dimensions(), + broadcast->shape().rank() - 1)) { + return "Broadcast is not along the reduction dimension."; + } + + while (IsTriviallyFusible(producer, gpu_version_)) { + producer = ChooseOperandForFusionProcessing(producer); + } + + if (!HasDefaultLayout(producer->shape())) { + return "Producer has non-default layout."; + } + + if (!IsTriviallyConnectedProducerOf(producer, instr->mutable_operand(0), + gpu_version_)) { + return "Producer is not trivially connected."; + } + + if (producer != instr->operand(0) && instr->operand(0)->user_count() != 1) { + return "Unsupported root-producer connection."; + } + + VLOG(5) << "Matched Softmax diamond with: "; + VLOG(5) << "root: " << instr->ToString(); + VLOG(5) << "producer: " << producer->ToString(); + VLOG(5) << "broadcast: " << broadcast->ToString(); + VLOG(5) << "reduce: " << reduce->ToString(); + + return producer; +} + std::vector SoftmaxRewriterTriton::FindAllFusibleDiamondChains( HloModule& module, @@ -411,9 +422,16 @@ SoftmaxRewriterTriton::FindAllFusibleDiamondChains( continue; } - if (auto producer = MatchesTritonCompatibleClosedReductionDiamond( - instr, gpu_version_)) { - matched_diamonds.push_back(DiamondDescriptor{instr, producer.value()}); + auto producer = MatchesTritonCompatibleClosedReductionDiamond(instr); + if (std::holds_alternative(producer)) { + matched_diamonds.push_back(DiamondDescriptor{ + instr, + std::get(producer), + }); + } else { + VLOG(5) << "Cannot match the diamond pattern for instruction " + << instr->ToString() + << ". Reason: " << std::get(producer).Explain(); } } } @@ -509,7 +527,8 @@ SoftmaxRewriterTriton::FindAllFusibleDiamondChains( // diamond producer of diamond chain n+1. diamond_chains.push_back(DiamondChainDescriptor{ last_trivially_fusible_user(previous_diamond_root), - current_fusion_producer}); + current_fusion_producer, + }); current_fusion_producer = first_non_fusible_diamond_producer; current_reduce_dimension_size = diamond_reduce_dimension_size; diff --git a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.h b/third_party/xla/xla/service/gpu/softmax_rewriter_triton.h index 3b71d3e3f9280e..9463d510f4590e 100644 --- a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.h +++ b/third_party/xla/xla/service/gpu/softmax_rewriter_triton.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_SOFTMAX_REWRITER_TRITON_H_ #define XLA_SERVICE_GPU_SOFTMAX_REWRITER_TRITON_H_ +#include #include #include "absl/container/flat_hash_set.h" @@ -23,8 +24,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" +#include "xla/service/instruction_fusion.h" #include "xla/status.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_description.h" namespace xla { @@ -35,6 +36,8 @@ struct DiamondChainDescriptor { HloInstruction* producer = nullptr; }; +using DiamondMatchingDecision = std::variant; + // Rewrite compatible Softmax into a custom fusion region to be code-generated // with the Triton-based Softmax emitter. class SoftmaxRewriterTriton : public HloModulePass { @@ -60,6 +63,26 @@ class SoftmaxRewriterTriton : public HloModulePass { // fusion. absl::Status FuseDiamondChain(const DiamondChainDescriptor& diamond_chain); + // Return the producer of the following pattern: + // + // producer + // | \ + // | reduce_{max,sum,...} + // | | + // | broadcast + // | / + // binop (elementwise) + // + // where each edge is allowed to contain also trivial operations that can be + // generated by Triton. We mean by "trivial" here those operations that do not + // increase the amount of memory read/written by the fusion, and that are + // compatible with any chosen tiling. + // + // We also assume that the reduction is done on the last axis of the producer + // array. + DiamondMatchingDecision MatchesTritonCompatibleClosedReductionDiamond( + HloInstruction* instr) const; + private: se::GpuComputeCapability gpu_version_; }; diff --git a/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc b/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc index 80c258cf9e47fc..6cbfc618a6b760 100644 --- a/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc @@ -13,16 +13,18 @@ limitations under the License. #include #include +#include #include #include #include #include "absl/base/optimization.h" #include "absl/log/check.h" -#include "absl/log/log.h" #include "absl/strings/substitute.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/primitive_util.h" +#include "xla/service/instruction_fusion.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/statusor.h" @@ -38,6 +40,8 @@ namespace { namespace m = ::xla::match; +using ::testing::HasSubstr; + // Wrapper around SoftmaxRewriterTriton(gpu_version).Run(module) that finds // and fuses as many diamond chains as possible without invoking any kind of // cost analysis. @@ -1736,6 +1740,47 @@ ENTRY main { SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); } +TEST_F(SoftmaxRewriterTritonTest, FusionDecisionIsCapturedExplicitly) { + const std::string hlo_string = R"( +HloModule softmax +max_computation { + arg_0 = f32[] parameter(0) + arg_1 = f32[] parameter(1) + ROOT maximum = f32[] maximum(arg_0, arg_1) +} +ENTRY main { + param_0 = f32[127,125]{1,0} parameter(0) + identity = f32[] parameter(1) + reduce = f32[127]{0} reduce(param_0, identity), dimensions={1}, to_apply=max_computation + broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0} + ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast) +} +)"; + + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); + SoftmaxRewriterTriton softmax_rewriter_triton(gpu_version_); + int unmatched = 0, matched = 0; + for (HloInstruction* instruction : + module->entry_computation()->MakeInstructionPostOrder()) { + DiamondMatchingDecision decision = + softmax_rewriter_triton.MatchesTritonCompatibleClosedReductionDiamond( + instruction); + if (std::holds_alternative(decision)) { + std::string actual_decision = + std::get(decision).Explain(); + EXPECT_THAT(actual_decision, + AnyOf(HasSubstr("Root is not elementwise binary"), + HasSubstr("Reduce has a non-constant second operand " + "and/or is variadic"))); + unmatched++; + } else { + matched++; + } + } + EXPECT_EQ(unmatched, 5); + EXPECT_EQ(matched, 0); +} + INSTANTIATE_TEST_SUITE_P(SoftmaxRewriterTritonTestSuite, SoftmaxRewriterTritonTest, ::testing::Values(F32, F16, BF16)); diff --git a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc index 84d393fecfe6b1..4269c180195af7 100644 --- a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc @@ -379,9 +379,9 @@ absl::Status MakeDotSplitKBatch(HloInstruction* dot_fusion, dot_fusion->parent()->AddInstruction(HloInstruction::CreateConstant( LiteralUtil::Zero(root->shape().element_type()))); // The batch dimension to reduce is the first one by construction. - TF_ASSIGN_OR_RETURN( - HloInstruction * reduce, - MakeReduceHlo(dot_fusion, zero, /*dimensions=*/{0}, HloOpcode::kAdd)); + TF_ASSIGN_OR_RETURN(HloInstruction * reduce, + MakeReduceHlo(dot_fusion, zero, /*dimensions=*/{0}, + HloOpcode::kAdd, &dot_fusion->metadata())); // The output of the reduce has to have the layout of the original dot. *reduce->mutable_shape()->mutable_layout() = output_layout; diff --git a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc index 7d1756abbe047b..d064c733ec4a16 100644 --- a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc @@ -92,15 +92,17 @@ ENTRY e { p0 = s8[3,128,5,32]{3,2,1,0} parameter(0) p1 = bf16[16,128]{1,0} parameter(1) ROOT fusion = bf16[480,16]{1,0} fusion(p0, p1), - kind=kCustom, calls=triton_gemm_dot, backend_config="__triton_gemm" + kind=kCustom, calls=triton_gemm_dot, backend_config="__triton_gemm", + metadata={op_name="foo"} })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); TritonGemmConfig config(16, 16, 16, 4, 1, 4); TF_EXPECT_OK(MakeDotSplitKBatch( module->entry_computation()->root_instruction(), config)); - EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(), - HloOpcode::kReduce); + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kReduce); + EXPECT_EQ(root->metadata().op_name(), "foo"); } TEST_F(SplitKTest, MakeSplitKWithOutputFusion) { diff --git a/third_party/xla/xla/service/gpu/stream_executor_util.cc b/third_party/xla/xla/service/gpu/stream_executor_util.cc index 60c6c9fec0f9e3..aace5768acb7bf 100644 --- a/third_party/xla/xla/service/gpu/stream_executor_util.cc +++ b/third_party/xla/xla/service/gpu/stream_executor_util.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/gpu/stream_executor_util.h" +#include #include #include #include @@ -37,8 +38,10 @@ limitations under the License. #include "xla/layout_util.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/launch_dim.h" #include "xla/util.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" #include "tsl/util/env_var.h" #include "tsl/util/proto/proto_utils.h" @@ -334,12 +337,13 @@ absl::StatusOr> CreateKernel( reinterpret_cast(cubin_data.data()), kernel_name); } - auto kernel_base = std::make_unique(stream_exec); - TF_RETURN_IF_ERROR(stream_exec->GetKernel(loader_spec, kernel_base.get())); + TF_ASSIGN_OR_RETURN(std::unique_ptr kernel, + se::Kernel::Create(stream_exec, loader_spec)); + se::KernelMetadata m; m.set_shared_memory_bytes(shared_mem_bytes); - kernel_base->set_metadata(m); - return std::move(kernel_base); + kernel->set_metadata(m); + return kernel; } absl::Status ExecuteKernelOnStream(const se::Kernel& kernel, @@ -354,6 +358,20 @@ absl::Status ExecuteKernelOnStream(const se::Kernel& kernel, dims.block_counts(), kernel, *kernel_args); } +absl::Status ExecuteKernelOnStream(const se::Kernel& kernel, + absl::Span args, + const LaunchDimensions& dims, + const se::ClusterDim& cluster_dim, + se::Stream* stream) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr kernel_args, + se::PackKernelArgs(args, kernel.metadata())); + + return stream->parent()->Launch(stream, dims.thread_counts_per_block(), + dims.block_counts(), cluster_dim, kernel, + *kernel_args); +} + // Unimplemented for integers yet. template typename std::enable_if::value, @@ -413,8 +431,8 @@ static void InitializeTypedBuffer(se::Stream* stream, int64_t elements_copied = std::min(host_buffer->size() - host_index, elements_left); se::DeviceMemoryBase mem(current_addr, elements_copied * sizeof(T)); - stream->ThenMemcpy(&mem, host_buffer->data() + host_index, - elements_copied * sizeof(T)); + TF_CHECK_OK(stream->Memcpy(&mem, host_buffer->data() + host_index, + elements_copied * sizeof(T))); current_addr += elements_copied * sizeof(T); elements_left -= elements_copied; host_index += elements_copied; @@ -468,6 +486,20 @@ absl::StatusOr GetDNNConvKindFromCudnnConvKind( return Internal("Unexpected convolution kind"); } +absl::StatusOr GetDNNNormKindFromCudnnNormKind( + CudnnNormKind kind) { + switch (kind) { + case CudnnNormKind::kLayerForwardInfer: + return se::dnn::LAYER_FWD_INFER; + case CudnnNormKind::kLayerForwardTrain: + return se::dnn::LAYER_FWD_TRAIN; + case CudnnNormKind::kLayerBackward: + return se::dnn::LAYER_BWD; + default: + return Internal("Unexpected norm kind"); + } +} + absl::StatusOr GetDNNFusedMHAKindFromCudnnfMHAKind( CudnnfMHAKind kind) { switch (kind) { @@ -519,7 +551,7 @@ absl::StatusOr GetDNNDataTypeFromPrimitiveType( default: break; } - return Internal("Unsupported convolution datatype"); + return Internal("Unsupported datatype"); } bool RequireDeterminism(const HloModuleConfig& config) { diff --git a/third_party/xla/xla/service/gpu/stream_executor_util.h b/third_party/xla/xla/service/gpu/stream_executor_util.h index ef418f5274a399..38598f3f42750c 100644 --- a/third_party/xla/xla/service/gpu/stream_executor_util.h +++ b/third_party/xla/xla/service/gpu/stream_executor_util.h @@ -27,6 +27,7 @@ limitations under the License. #include "xla/service/hlo_module_config.h" #include "xla/statusor.h" #include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/stream_executor.h" #include "xla/types.h" #include "xla/xla_data.pb.h" @@ -91,6 +92,13 @@ absl::Status ExecuteKernelOnStream(const se::Kernel& kernel, const LaunchDimensions& dims, se::Stream* stream); +// Runs loaded kernel on the stream with the provided arguments. +absl::Status ExecuteKernelOnStream(const se::Kernel& kernel, + absl::Span args, + const LaunchDimensions& dims, + const se::ClusterDim& cluster_dim, + se::Stream* stream); + // Initializes `buffer` with random data on `stream`. // `rng_state` is an inout parameter for the pseudorandom generator state. // `buffer_type` determines what buffer would be filled out with. @@ -103,6 +111,9 @@ void InitializeBuffer(se::Stream* stream, PrimitiveType buffer_type, absl::StatusOr GetDNNConvKindFromCudnnConvKind( CudnnConvKind kind); +absl::StatusOr GetDNNNormKindFromCudnnNormKind( + CudnnNormKind kind); + absl::StatusOr GetDNNFusedMHAKindFromCudnnfMHAKind( CudnnfMHAKind kind); diff --git a/third_party/xla/xla/service/gpu/target_util.cc b/third_party/xla/xla/service/gpu/target_util.cc index e7cd015f1161ac..15d0799b005153 100644 --- a/third_party/xla/xla/service/gpu/target_util.cc +++ b/third_party/xla/xla/service/gpu/target_util.cc @@ -205,6 +205,9 @@ struct TargetDeviceFunction GetDeviceFunctionRoot( case TargetDeviceFunctionID::kCos: { return {"__nv_cos", "__ocml_cos", "_Z15__spirv_ocl_cos"}; } + case TargetDeviceFunctionID::kErf: { + return {"__nv_erf", "__ocml_erf", "_Z15__spirv_ocl_erf"}; + } case TargetDeviceFunctionID::kExp: { return {"__nv_exp", "__ocml_exp", "_Z15__spirv_ocl_exp"}; } @@ -256,6 +259,8 @@ absl::StatusOr GetTargetDeviceFunctionID(HloOpcode op) { return TargetDeviceFunctionID::kCos; case HloOpcode::kExp: return TargetDeviceFunctionID::kExp; + case HloOpcode::kErf: + return TargetDeviceFunctionID::kErf; case HloOpcode::kExpm1: return TargetDeviceFunctionID::kExpm1; case HloOpcode::kLog: diff --git a/third_party/xla/xla/service/gpu/target_util.h b/third_party/xla/xla/service/gpu/target_util.h index 7024438b940be1..d88e8ed5b54f14 100644 --- a/third_party/xla/xla/service/gpu/target_util.h +++ b/third_party/xla/xla/service/gpu/target_util.h @@ -62,6 +62,7 @@ enum class TargetDeviceFunctionID { kSqrt, kTan, kTanh, + kErf, }; // HLO opcode -> TargetDeviceFunctionID mapping. diff --git a/third_party/xla/xla/service/gpu/tests/BUILD b/third_party/xla/xla/service/gpu/tests/BUILD index b5bfa5c4cd968a..ed6b14caee63d4 100644 --- a/third_party/xla/xla/service/gpu/tests/BUILD +++ b/third_party/xla/xla/service/gpu/tests/BUILD @@ -23,7 +23,8 @@ load( ) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], licenses = ["notice"], ) @@ -41,7 +42,6 @@ filegroup( "**/*.cc", "**/*.h", ]), - visibility = ["//visibility:public"], ) cc_library( @@ -50,12 +50,12 @@ cc_library( srcs = ["gpu_codegen_test.cc"], hdrs = ["gpu_codegen_test.h"], tags = tf_cuda_tests_tags(), - visibility = ["//visibility:public"], deps = [ "//xla:debug_options_flags", "//xla:shape_util", "//xla/service:gpu_plugin", "//xla/service/gpu:gpu_executable", + "//xla/stream_executor:platform_manager", "//xla/tests:filecheck", "//xla/tests:llvm_irgen_test_base", "//xla/tests:verified_hlo_module", @@ -85,6 +85,20 @@ xla_cc_test( ], ) +xla_test( + name = "float_conversions_test", + srcs = ["float_conversions_test.cc"], + backends = ["gpu"], + deps = [ + ":gpu_codegen_test", + "//xla:error_spec", + "//xla/tests:test_utils", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + xla_cc_test( name = "gpu_reduce_scatter_creator_test", srcs = ["gpu_reduce_scatter_creator_test.cc"], @@ -658,20 +672,16 @@ lit_test_suite( [ "add_preds.hlo", "calling_convention.hlo", - "concat.hlo", - "constant.hlo", "copy.hlo", - "copy_nested.hlo", "dynamic_update_slice_inplace.hlo", "element_wise_row_vectorization.hlo", "fused_scatter.hlo", - "fused_slice_different_operands.hlo", "fused_slice.hlo", - "fusion.hlo", "kernel_reuse.hlo", "launch_dimensions.hlo", "pad_to_static.hlo", "reduce_atomic_min.hlo", + "reduce_column_layout_change.hlo", "reduce_f64_column.hlo", "reduce_large_row_to_scalar.hlo", "reduce_row_vectorized.hlo", @@ -731,7 +741,6 @@ filegroup( "//xla/tools/hlo_opt:gpu_specs/v100.txtpb", "@llvm-project//llvm:FileCheck", ], - visibility = ["//visibility:public"], ) xla_cc_test( @@ -775,6 +784,7 @@ xla_cc_test( "//xla:shape_util", "//xla:types", "//xla:xla_proto_cc", + "//xla/stream_executor:platform_manager", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", @@ -856,7 +866,6 @@ cc_library( testonly = True, srcs = ["simple_optimization_test.cc"], tags = tf_cuda_tests_tags(), - visibility = ["//visibility:public"], deps = [ "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", diff --git a/third_party/xla/xla/service/gpu/tests/add_preds.hlo b/third_party/xla/xla/service/gpu/tests/add_preds.hlo index 0134a471255dbd..120b6a5ad686bf 100644 --- a/third_party/xla/xla/service/gpu/tests/add_preds.hlo +++ b/third_party/xla/xla/service/gpu/tests/add_preds.hlo @@ -1,29 +1,9 @@ -// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck %s -// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py - -// CHECK-LABEL: entry: -// CHECK-PTX: %[[VAL_0:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x -// CHECK-GCN: %[[VAL_0:.*]] = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK-PTX: %[[VAL_1:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x -// CHECK-GCN: %[[VAL_1:.*]] = call i32 @llvm.amdgcn.workitem.id.x -// CHECK: %[[VAL_2:.*]] = mul nuw nsw i32 %[[VAL_0]], 1 -// CHECK: %[[VAL_3:.*]] = add nuw nsw i32 %[[VAL_2]], %[[VAL_1]] -// CHECK: %[[VAL_4:.*]] = icmp ult i32 %[[VAL_3]], 1 -// CHECK: call void @llvm.assume(i1 %[[VAL_4]]) -// CHECK: %[[VAL_5:.*]] = icmp ult i32 %[[VAL_3]], 1 -// CHECK: br i1 %[[VAL_5]], label %[[VAL_6:.*]], label %[[VAL_7:.*]] -// CHECK: fusion.in_bounds-after: ; preds = %[[VAL_6]], %[[VAL_8:.*]] -// CHECK: ret void -// CHECK: fusion.in_bounds-true: ; preds = %[[VAL_8]] -// CHECK: %[[VAL_9:.*]] = load i8, ptr %[[VAL_10:.*]], align 1, !invariant.load -// CHECK: %[[VAL_11:.*]] = load i8, ptr %[[VAL_12:.*]], align 1, !invariant.load -// CHECK: %[[VAL_13:.*]] = or i8 %[[VAL_9]], %[[VAL_11]] -// CHECK: %[[VAL_14:.*]] = trunc i8 %[[VAL_13]] to i1 -// CHECK: %[[VAL_15:.*]] = xor i1 %[[VAL_14]], true -// CHECK: %[[VAL_16:.*]] = zext i1 %[[VAL_15]] to i8 -// CHECK: store i8 %[[VAL_16]], ptr %[[VAL_17:.*]], align 1 -// CHECK: br label %[[VAL_7]] +// CHECK: define void @fusion({{.*}}%[[ARG0:.*]], {{.*}}%[[ARG1:.*]], +// CHECK: %[[A:.*]] = load {{.*}} ptr %[[ARG0]] +// CHECK: %[[B:.*]] = load {{.*}} ptr %[[ARG1]] +// CHECK: or {{.*}} %[[A]], %[[B]] HloModule xla_computation_f.8, is_scheduled=true diff --git a/third_party/xla/xla/service/gpu/tests/calling_convention.hlo b/third_party/xla/xla/service/gpu/tests/calling_convention.hlo index 6abbeaad5e15b2..c84e0194c347cb 100644 --- a/third_party/xla/xla/service/gpu/tests/calling_convention.hlo +++ b/third_party/xla/xla/service/gpu/tests/calling_convention.hlo @@ -6,8 +6,8 @@ // CHECK-LABEL: target triple // CHECK: @buffer_for_dynamic // CHECK: @buffer_for_static -// CHECK-PTX: define void @custom_call(ptr noalias align 16 dereferenceable(32) %arg0, ptr noalias align 128 dereferenceable(4) %arg1, ptr noalias align 128 dereferenceable(4) %arg2, ptr noalias align 128 dereferenceable(32) %arg3) -// CHECK-GCN: define amdgpu_kernel void @custom_call(ptr noalias align 16 dereferenceable(32) %arg0, ptr noalias align 128 dereferenceable(4) %arg1, ptr noalias align 128 dereferenceable(4) %arg2, ptr noalias align 128 dereferenceable(32) %arg3) +// CHECK-PTX: define void @custom_call(ptr noalias align 16 dereferenceable(32) %arg0, ptr noalias align 128 dereferenceable(4) %arg1, ptr noalias align 128 dereferenceable(4) %arg2, ptr noalias align 128 dereferenceable(44) %arg3) +// CHECK-GCN: define amdgpu_kernel void @custom_call(ptr noalias align 16 dereferenceable(32) %arg0, ptr noalias align 128 dereferenceable(4) %arg1, ptr noalias align 128 dereferenceable(4) %arg2, ptr noalias align 128 dereferenceable(44) %arg3) // CHECK-NOT: @buffer_for_dynamic // CHECK-NOT: @buffer_for_static diff --git a/third_party/xla/xla/service/gpu/tests/concat.hlo b/third_party/xla/xla/service/gpu/tests/concat.hlo deleted file mode 100644 index f1087b150fb379..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/concat.hlo +++ /dev/null @@ -1,202 +0,0 @@ -// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s - -// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py - -// CHECK-LABEL: entry: -// CHECK-PTX: %[[VAL_0:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x -// CHECK-GCN: %[[VAL_0:.*]] = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK-PTX: %[[VAL_1:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x -// CHECK-GCN: %[[VAL_1:.*]] = call i32 @llvm.amdgcn.workitem.id.x -// CHECK-PTX: %[[VAL_2:.*]] = mul nuw nsw i32 %[[VAL_0]], 128 -// CHECK-GCN: %[[VAL_2:.*]] = mul nuw nsw i32 %[[VAL_0]], 256 -// CHECK: %[[VAL_3:.*]] = add nuw nsw i32 %[[VAL_2]], %[[VAL_1]] -// CHECK: %[[VAL_4:.*]] = icmp ult i32 %[[VAL_3]], 11008 -// CHECK: call void @llvm.assume(i1 %[[VAL_4]]) -// CHECK: %[[VAL_5:.*]] = add nuw nsw i32 %[[VAL_3]], 0 -// CHECK: %[[VAL_6:.*]] = udiv i32 %[[VAL_5]], 1 -// CHECK: %[[VAL_7:.*]] = icmp ult i32 %[[VAL_3]], 11000 -// CHECK: br i1 %[[VAL_7]], label %[[VAL_8:.*]], label %[[VAL_9:.*]] -// CHECK: fusion.in_bounds-after: ; preds = %[[VAL_10:.*]], %[[VAL_11:.*]] -// CHECK: ret void -// CHECK: fusion.in_bounds-true: ; preds = %[[VAL_11]] -// CHECK: br label %[[VAL_12:.*]] -// CHECK: concat_index_from_operand_id0: ; preds = %[[VAL_13:.*]] -// CHECK: %[[VAL_14:.*]] = phi i32 [ 0, %[[VAL_13]] ] -// CHECK: %[[VAL_15:.*]] = sub nsw i32 %[[VAL_6]], %[[VAL_14]] -// CHECK: %[[VAL_16:.*]] = getelementptr inbounds [1000 x float], ptr %[[VAL_17:.*]], i32 0, i32 %[[VAL_15]] -// CHECK: %[[VAL_18:.*]] = load float, ptr %[[VAL_16]], align 4, !invariant.load -// CHECK: %[[VAL_19:.*]] = fptrunc float %[[VAL_18]] to half -// CHECK: br label %[[VAL_10]] -// CHECK: concat_index_from_operand_id1: ; preds = %[[VAL_20:.*]] -// CHECK: %[[VAL_21:.*]] = phi i32 [ 1000, %[[VAL_20]] ] -// CHECK: %[[VAL_22:.*]] = sub nsw i32 %[[VAL_6]], %[[VAL_21]] -// CHECK: %[[VAL_23:.*]] = getelementptr inbounds [1000 x float], ptr %[[VAL_24:.*]], i32 0, i32 %[[VAL_22]] -// CHECK: %[[VAL_25:.*]] = load float, ptr %[[VAL_23]], align 4, !invariant.load -// CHECK: %[[VAL_26:.*]] = fptrunc float %[[VAL_25]] to half -// CHECK: br label %[[VAL_10]] -// CHECK: concat_index_from_operand_id2: ; preds = %[[VAL_27:.*]] -// CHECK: %[[VAL_28:.*]] = phi i32 [ 2000, %[[VAL_27]] ] -// CHECK: %[[VAL_29:.*]] = sub nsw i32 %[[VAL_6]], %[[VAL_28]] -// CHECK: %[[VAL_30:.*]] = getelementptr inbounds [1000 x float], ptr %[[VAL_31:.*]], i32 0, i32 %[[VAL_29]] -// CHECK: %[[VAL_32:.*]] = load float, ptr %[[VAL_30]], align 4, !invariant.load -// CHECK: %[[VAL_33:.*]] = fptrunc float %[[VAL_32]] to half -// CHECK: br label %[[VAL_10]] -// CHECK: concat_index_from_operand_id3: ; preds = %[[VAL_34:.*]] -// CHECK: %[[VAL_35:.*]] = phi i32 [ 3000, %[[VAL_34]] ] -// CHECK: %[[VAL_36:.*]] = sub nsw i32 %[[VAL_6]], %[[VAL_35]] -// CHECK: %[[VAL_37:.*]] = getelementptr inbounds [1000 x float], ptr %[[VAL_38:.*]], i32 0, i32 %[[VAL_36]] -// CHECK: %[[VAL_39:.*]] = load float, ptr %[[VAL_37]], align 4, !invariant.load -// CHECK: %[[VAL_40:.*]] = fptrunc float %[[VAL_39]] to half -// CHECK: br label %[[VAL_10]] -// CHECK: concat_index_from_operand_id4: ; preds = %[[VAL_41:.*]] -// CHECK: %[[VAL_42:.*]] = phi i32 [ 4000, %[[VAL_41]] ] -// CHECK: %[[VAL_43:.*]] = sub nsw i32 %[[VAL_6]], %[[VAL_42]] -// CHECK: %[[VAL_44:.*]] = getelementptr inbounds [1000 x float], ptr %[[VAL_45:.*]], i32 0, i32 %[[VAL_43]] -// CHECK: %[[VAL_46:.*]] = load float, ptr %[[VAL_44]], align 4, !invariant.load -// CHECK: %[[VAL_47:.*]] = fptrunc float %[[VAL_46]] to half -// CHECK: br label %[[VAL_10]] -// CHECK: concat_index_from_operand_id5: ; preds = %[[VAL_48:.*]] -// CHECK: %[[VAL_49:.*]] = phi i32 [ 5000, %[[VAL_48]] ] -// CHECK: %[[VAL_50:.*]] = sub nsw i32 %[[VAL_6]], %[[VAL_49]] -// CHECK: %[[VAL_51:.*]] = getelementptr inbounds [1000 x float], ptr %[[VAL_52:.*]], i32 0, i32 %[[VAL_50]] -// CHECK: %[[VAL_53:.*]] = load float, ptr %[[VAL_51]], align 4, !invariant.load -// CHECK: %[[VAL_54:.*]] = fptrunc float %[[VAL_53]] to half -// CHECK: br label %[[VAL_10]] -// CHECK: concat_index_from_operand_id6: ; preds = %[[VAL_55:.*]] -// CHECK: %[[VAL_56:.*]] = phi i32 [ 6000, %[[VAL_55]] ] -// CHECK: %[[VAL_57:.*]] = sub nsw i32 %[[VAL_6]], %[[VAL_56]] -// CHECK: %[[VAL_58:.*]] = getelementptr inbounds [1000 x float], ptr %[[VAL_59:.*]], i32 0, i32 %[[VAL_57]] -// CHECK: %[[VAL_60:.*]] = load float, ptr %[[VAL_58]], align 4, !invariant.load -// CHECK: %[[VAL_61:.*]] = fptrunc float %[[VAL_60]] to half -// CHECK: br label %[[VAL_10]] -// CHECK: concat_index_from_operand_id7: ; preds = %[[VAL_62:.*]] -// CHECK: %[[VAL_63:.*]] = phi i32 [ 7000, %[[VAL_62]] ] -// CHECK: %[[VAL_64:.*]] = sub nsw i32 %[[VAL_6]], %[[VAL_63]] -// CHECK: %[[VAL_65:.*]] = getelementptr inbounds [1000 x float], ptr %[[VAL_66:.*]], i32 0, i32 %[[VAL_64]] -// CHECK: %[[VAL_67:.*]] = load float, ptr %[[VAL_65]], align 4, !invariant.load -// CHECK: %[[VAL_68:.*]] = fptrunc float %[[VAL_67]] to half -// CHECK: br label %[[VAL_10]] -// CHECK: concat_index_from_operand_id8: ; preds = %[[VAL_69:.*]] -// CHECK: %[[VAL_70:.*]] = phi i32 [ 8000, %[[VAL_69]] ] -// CHECK: %[[VAL_71:.*]] = sub nsw i32 %[[VAL_6]], %[[VAL_70]] -// CHECK: %[[VAL_72:.*]] = getelementptr inbounds [1000 x float], ptr %[[VAL_73:.*]], i32 0, i32 %[[VAL_71]] -// CHECK: %[[VAL_74:.*]] = load float, ptr %[[VAL_72]], align 4, !invariant.load -// CHECK: %[[VAL_75:.*]] = fptrunc float %[[VAL_74]] to half -// CHECK: br label %[[VAL_10]] -// CHECK: concat_index_from_operand_id9: ; preds = %[[VAL_76:.*]] -// CHECK: %[[VAL_77:.*]] = phi i32 [ 9000, %[[VAL_76]] ] -// CHECK: %[[VAL_78:.*]] = sub nsw i32 %[[VAL_6]], %[[VAL_77]] -// CHECK: %[[VAL_79:.*]] = getelementptr inbounds [1000 x float], ptr %[[VAL_80:.*]], i32 0, i32 %[[VAL_78]] -// CHECK: %[[VAL_81:.*]] = load float, ptr %[[VAL_79]], align 4, !invariant.load -// CHECK: %[[VAL_82:.*]] = fptrunc float %[[VAL_81]] to half -// CHECK: br label %[[VAL_10]] -// CHECK: concat_index_from_operand_id10: ; preds = %[[VAL_83:.*]] -// CHECK: %[[VAL_84:.*]] = phi i32 [ 10000, %[[VAL_83]] ] -// CHECK: %[[VAL_85:.*]] = sub nsw i32 %[[VAL_6]], %[[VAL_84]] -// CHECK: %[[VAL_86:.*]] = getelementptr inbounds [1000 x float], ptr %[[VAL_87:.*]], i32 0, i32 %[[VAL_85]] -// CHECK: %[[VAL_88:.*]] = load float, ptr %[[VAL_86]], align 4, !invariant.load -// CHECK: %[[VAL_89:.*]] = fptrunc float %[[VAL_88]] to half -// CHECK: br label %[[VAL_10]] -// CHECK: concatenate.pivot.5000.: ; preds = %[[VAL_8]] -// CHECK: %[[VAL_90:.*]] = icmp ult i32 %[[VAL_6]], 5000 -// CHECK: br i1 %[[VAL_90]], label %[[VAL_91:.*]], label %[[VAL_92:.*]] -// CHECK: concatenate.pivot.2000.: ; preds = %[[VAL_12]] -// CHECK: %[[VAL_93:.*]] = icmp ult i32 %[[VAL_6]], 2000 -// CHECK: br i1 %[[VAL_93]], label %[[VAL_94:.*]], label %[[VAL_95:.*]] -// CHECK: concatenate.pivot.1000.: ; preds = %[[VAL_91]] -// CHECK: %[[VAL_96:.*]] = icmp ult i32 %[[VAL_6]], 1000 -// CHECK: br i1 %[[VAL_96]], label %[[VAL_13]], label %[[VAL_20]] -// CHECK: concatenate.pivot.0.: ; preds = %[[VAL_94]] -// CHECK: br label %[[VAL_97:.*]] -// CHECK: concatenate.pivot.1000.1: ; preds = %[[VAL_94]] -// CHECK: br label %[[VAL_98:.*]] -// CHECK: concatenate.pivot.3000.: ; preds = %[[VAL_91]] -// CHECK: %[[VAL_99:.*]] = icmp ult i32 %[[VAL_6]], 3000 -// CHECK: br i1 %[[VAL_99]], label %[[VAL_27]], label %[[VAL_100:.*]] -// CHECK: concatenate.pivot.2000.2: ; preds = %[[VAL_95]] -// CHECK: br label %[[VAL_101:.*]] -// CHECK: concatenate.pivot.4000.: ; preds = %[[VAL_95]] -// CHECK: %[[VAL_102:.*]] = icmp ult i32 %[[VAL_6]], 4000 -// CHECK: br i1 %[[VAL_102]], label %[[VAL_34]], label %[[VAL_41]] -// CHECK: concatenate.pivot.3000.3: ; preds = %[[VAL_100]] -// CHECK: br label %[[VAL_103:.*]] -// CHECK: concatenate.pivot.4000.4: ; preds = %[[VAL_100]] -// CHECK: br label %[[VAL_104:.*]] -// CHECK: concatenate.pivot.8000.: ; preds = %[[VAL_12]] -// CHECK: %[[VAL_105:.*]] = icmp ult i32 %[[VAL_6]], 8000 -// CHECK: br i1 %[[VAL_105]], label %[[VAL_106:.*]], label %[[VAL_107:.*]] -// CHECK: concatenate.pivot.6000.: ; preds = %[[VAL_92]] -// CHECK: %[[VAL_108:.*]] = icmp ult i32 %[[VAL_6]], 6000 -// CHECK: br i1 %[[VAL_108]], label %[[VAL_48]], label %[[VAL_109:.*]] -// CHECK: concatenate.pivot.5000.5: ; preds = %[[VAL_106]] -// CHECK: br label %[[VAL_110:.*]] -// CHECK: concatenate.pivot.7000.: ; preds = %[[VAL_106]] -// CHECK: %[[VAL_111:.*]] = icmp ult i32 %[[VAL_6]], 7000 -// CHECK: br i1 %[[VAL_111]], label %[[VAL_55]], label %[[VAL_62]] -// CHECK: concatenate.pivot.6000.6: ; preds = %[[VAL_109]] -// CHECK: br label %[[VAL_112:.*]] -// CHECK: concatenate.pivot.7000.7: ; preds = %[[VAL_109]] -// CHECK: br label %[[VAL_113:.*]] -// CHECK: concatenate.pivot.9000.: ; preds = %[[VAL_92]] -// CHECK: %[[VAL_114:.*]] = icmp ult i32 %[[VAL_6]], 9000 -// CHECK: br i1 %[[VAL_114]], label %[[VAL_69]], label %[[VAL_115:.*]] -// CHECK: concatenate.pivot.8000.8: ; preds = %[[VAL_107]] -// CHECK: br label %[[VAL_116:.*]] -// CHECK: concatenate.pivot.10000.: ; preds = %[[VAL_107]] -// CHECK: %[[VAL_117:.*]] = icmp ult i32 %[[VAL_6]], 10000 -// CHECK: br i1 %[[VAL_117]], label %[[VAL_76]], label %[[VAL_83]] -// CHECK: concatenate.pivot.9000.9: ; preds = %[[VAL_115]] -// CHECK: br label %[[VAL_118:.*]] -// CHECK: concatenate.pivot.10000.10: ; preds = %[[VAL_115]] -// CHECK: br label %[[VAL_119:.*]] -// CHECK: out.1.merge: ; preds = %[[VAL_119]], %[[VAL_118]], %[[VAL_116]], %[[VAL_113]], %[[VAL_112]], %[[VAL_110]], %[[VAL_104]], %[[VAL_103]], %[[VAL_101]], %[[VAL_98]], %[[VAL_97]] -// CHECK: %[[VAL_120:.*]] = phi half [ %[[VAL_19]], %[[VAL_97]] ], [ %[[VAL_26]], %[[VAL_98]] ], [ %[[VAL_33]], %[[VAL_101]] ], [ %[[VAL_40]], %[[VAL_103]] ], [ %[[VAL_47]], %[[VAL_104]] ], [ %[[VAL_54]], %[[VAL_110]] ], [ %[[VAL_61]], %[[VAL_112]] ], [ %[[VAL_68]], %[[VAL_113]] ], [ %[[VAL_75]], %[[VAL_116]] ], [ %[[VAL_82]], %[[VAL_118]] ], [ %[[VAL_89]], %[[VAL_119]] ] -// CHECK: %[[VAL_121:.*]] = getelementptr half, ptr %[[VAL_122:.*]], i32 %[[VAL_3]] -// CHECK: %[[VAL_123:.*]] = getelementptr inbounds half, ptr %[[VAL_121]], i32 0 -// CHECK: store half %[[VAL_120]], ptr %[[VAL_123]], align 2 -// CHECK: br label %[[VAL_9]] - - -HloModule module, is_scheduled=true - -%fused_computation (param_0.1: f32[1000], param_1.2: f32[1000], param_2.3: f32[1000], param_3.4: f32[1000], param_4.5: f32[1000], param_5.6: f32[1000], param_6.7: f32[1000], param_7.8: f32[1000], param_8.9: f32[1000], param_9.10: f32[1000], param_10.11: f32[1000]) -> f16[11000] { - %param_10.11 = f32[1000]{0} parameter(10) - %converted0.1 = f16[1000]{0} convert(f32[1000]{0} %param_10.11) - %param_9.10 = f32[1000]{0} parameter(9) - %converted1.1 = f16[1000]{0} convert(f32[1000]{0} %param_9.10) - %param_8.9 = f32[1000]{0} parameter(8) - %converted2.1 = f16[1000]{0} convert(f32[1000]{0} %param_8.9) - %param_7.8 = f32[1000]{0} parameter(7) - %converted3.1 = f16[1000]{0} convert(f32[1000]{0} %param_7.8) - %param_6.7 = f32[1000]{0} parameter(6) - %converted4.1 = f16[1000]{0} convert(f32[1000]{0} %param_6.7) - %param_5.6 = f32[1000]{0} parameter(5) - %converted5.1 = f16[1000]{0} convert(f32[1000]{0} %param_5.6) - %param_4.5 = f32[1000]{0} parameter(4) - %converted6.1 = f16[1000]{0} convert(f32[1000]{0} %param_4.5) - %param_3.4 = f32[1000]{0} parameter(3) - %converted7.1 = f16[1000]{0} convert(f32[1000]{0} %param_3.4) - %param_2.3 = f32[1000]{0} parameter(2) - %converted8.1 = f16[1000]{0} convert(f32[1000]{0} %param_2.3) - %param_1.2 = f32[1000]{0} parameter(1) - %converted9.1 = f16[1000]{0} convert(f32[1000]{0} %param_1.2) - %param_0.1 = f32[1000]{0} parameter(0) - %converted10.1 = f16[1000]{0} convert(f32[1000]{0} %param_0.1) - ROOT %out.1 = f16[11000]{0} concatenate(f16[1000]{0} %converted0.1, f16[1000]{0} %converted1.1, f16[1000]{0} %converted2.1, f16[1000]{0} %converted3.1, f16[1000]{0} %converted4.1, /*index=5*/f16[1000]{0} %converted5.1, f16[1000]{0} %converted6.1, f16[1000]{0} %converted7.1, f16[1000]{0} %converted8.1, f16[1000]{0} %converted9.1, /*index=10*/f16[1000]{0} %converted10.1), dimensions={0} -} - -ENTRY %computation (p0: f32[1000], p1: f32[1000], p2: f32[1000], p3: f32[1000], p4: f32[1000], p5: f32[1000], p6: f32[1000], p7: f32[1000], p8: f32[1000], p9: f32[1000], p10: f32[1000]) -> f16[11000] { - %p10 = f32[1000]{0} parameter(10) - %p9 = f32[1000]{0} parameter(9) - %p8 = f32[1000]{0} parameter(8) - %p7 = f32[1000]{0} parameter(7) - %p6 = f32[1000]{0} parameter(6) - %p5 = f32[1000]{0} parameter(5) - %p4 = f32[1000]{0} parameter(4) - %p3 = f32[1000]{0} parameter(3) - %p2 = f32[1000]{0} parameter(2) - %p1 = f32[1000]{0} parameter(1) - %p0 = f32[1000]{0} parameter(0) - ROOT %fusion = f16[11000]{0} fusion(f32[1000]{0} %p10, f32[1000]{0} %p9, f32[1000]{0} %p8, f32[1000]{0} %p7, f32[1000]{0} %p6, /*index=5*/f32[1000]{0} %p5, f32[1000]{0} %p4, f32[1000]{0} %p3, f32[1000]{0} %p2, f32[1000]{0} %p1, /*index=10*/f32[1000]{0} %p0), kind=kLoop, calls=%fused_computation -} - diff --git a/third_party/xla/xla/service/gpu/tests/constant.hlo b/third_party/xla/xla/service/gpu/tests/constant.hlo deleted file mode 100644 index f7a0554fc28f93..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/constant.hlo +++ /dev/null @@ -1,19 +0,0 @@ -// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck %s - -HloModule Test, is_scheduled=true - -fused_computation { - param_0 = pred[2,2]{1,0} parameter(0) - param_1 = pred[2,2]{1,0} parameter(1) - ROOT xor.1 = pred[2,2]{1,0} xor(pred[2,2]{1,0} param_0, pred[2,2]{1,0} param_1) -} - -ENTRY main { -// CHECK: %[[VAL_1:.*]] = getelementptr i8, ptr %arg0, i32 %{{.*}} -// CHECK: %{{.*}} = getelementptr inbounds i8, ptr %[[VAL_1]], i32 0 -// CHECK: %[[VAL_2:.*]] = getelementptr i8, ptr %arg1, i32 %{{.*}} -// CHECK: %{{.*}} = getelementptr inbounds i8, ptr %[[VAL_2]], i32 0 - a = pred[2, 2]{1,0} constant({{false, true}, {true, false}}) - b = pred[2, 2]{1,0} constant({{false, true}, {false, true}}) - ROOT wrapped_xor = pred[2,2]{1,0} fusion(pred[2,2]{1,0} a, pred[2,2]{1,0} b), kind=kLoop, calls=fused_computation -} diff --git a/third_party/xla/xla/service/gpu/tests/copy_nested.hlo b/third_party/xla/xla/service/gpu/tests/copy_nested.hlo deleted file mode 100644 index ce0888decf4cb9..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/copy_nested.hlo +++ /dev/null @@ -1,95 +0,0 @@ -// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s - -// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py - -// CHECK-LABEL: entry: -// CHECK: %[[VAL_0:.*]] = alloca i32, align 4 -// CHECK-PTX: store i32 0, ptr %[[VAL_0]], align 4 -// CHECK-GCN: store i32 0, ptr addrspace(5) %[[VAL_0]], align 4 -// CHECK: br label %[[VAL_1:.*]] -// CHECK: loop.loop_header: ; preds = %[[VAL_2:.*]], %[[VAL_3:.*]] -// CHECK-PTX: %[[VAL_4:.*]] = load i32, ptr %[[VAL_0]], align 4 -// CHECK-GCN: %[[VAL_4:.*]] = load i32, ptr addrspace(5) %[[VAL_0]], align 4 -// CHECK: %[[VAL_5:.*]] = icmp uge i32 %[[VAL_4]], 6000000 -// CHECK: br i1 %[[VAL_5]], label %[[VAL_6:.*]], label %[[VAL_7:.*]] -// CHECK: loop.loop_body: ; preds = %[[VAL_1]] -// CHECK-PTX: %[[VAL_8:.*]] = add nuw nsw i32 %[[VAL_4]], 516096 -// CHECK-GCN: %[[VAL_8:.*]] = add nuw nsw i32 %[[VAL_4]], 851968 -// CHECK-PTX: store i32 %[[VAL_8]], ptr %[[VAL_0]], align 4 -// CHECK-GCN: store i32 %[[VAL_8]], ptr addrspace(5) %[[VAL_0]], align 4 -// CHECK: %[[VAL_9:.*]] = icmp eq i32 %[[VAL_4]], 0 -// CHECK-PTX: %[[VAL_10:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x -// CHECK-GCN: %[[VAL_10:.*]] = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK-PTX: %[[VAL_11:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x -// CHECK-GCN: %[[VAL_11:.*]] = call i32 @llvm.amdgcn.workitem.id.x -// CHECK: %[[VAL_12:.*]] = mul nuw nsw i32 %[[VAL_10]], 128 -// CHECK: %[[VAL_13:.*]] = add nuw nsw i32 %[[VAL_12]], %[[VAL_11]] -// CHECK-PTX: %[[VAL_14:.*]] = icmp ult i32 %[[VAL_13]], 129024 -// CHECK-GCN: %[[VAL_14:.*]] = icmp ult i32 %[[VAL_13]], 212992 -// CHECK: call void @llvm.assume(i1 %[[VAL_14]]) -// CHECK: %[[VAL_15:.*]] = mul nuw nsw i32 %[[VAL_13]], 4 -// CHECK: %[[VAL_16:.*]] = add nuw nsw i32 %[[VAL_15]], %[[VAL_4]] -// CHECK: %[[VAL_17:.*]] = add nuw nsw i32 %[[VAL_16]], 0 -// CHECK: %[[VAL_18:.*]] = udiv i32 %[[VAL_17]], 1 -// CHECK: %[[VAL_19:.*]] = urem i32 %[[VAL_18]], 300 -// CHECK: %[[VAL_20:.*]] = udiv i32 %[[VAL_17]], 300 -// CHECK: %[[VAL_21:.*]] = urem i32 %[[VAL_20]], 100 -// CHECK: %[[VAL_22:.*]] = udiv i32 %[[VAL_17]], 30000 -// CHECK: %[[VAL_23:.*]] = add nuw nsw i32 %[[VAL_16]], 1 -// CHECK: %[[VAL_24:.*]] = udiv i32 %[[VAL_23]], 1 -// CHECK: %[[VAL_25:.*]] = urem i32 %[[VAL_24]], 300 -// CHECK: %[[VAL_26:.*]] = udiv i32 %[[VAL_23]], 300 -// CHECK: %[[VAL_27:.*]] = urem i32 %[[VAL_26]], 100 -// CHECK: %[[VAL_28:.*]] = udiv i32 %[[VAL_23]], 30000 -// CHECK: %[[VAL_29:.*]] = add nuw nsw i32 %[[VAL_16]], 2 -// CHECK: %[[VAL_30:.*]] = udiv i32 %[[VAL_29]], 1 -// CHECK: %[[VAL_31:.*]] = urem i32 %[[VAL_30]], 300 -// CHECK: %[[VAL_32:.*]] = udiv i32 %[[VAL_29]], 300 -// CHECK: %[[VAL_33:.*]] = urem i32 %[[VAL_32]], 100 -// CHECK: %[[VAL_34:.*]] = udiv i32 %[[VAL_29]], 30000 -// CHECK: %[[VAL_35:.*]] = add nuw nsw i32 %[[VAL_16]], 3 -// CHECK: %[[VAL_36:.*]] = udiv i32 %[[VAL_35]], 1 -// CHECK: %[[VAL_37:.*]] = urem i32 %[[VAL_36]], 300 -// CHECK: %[[VAL_38:.*]] = udiv i32 %[[VAL_35]], 300 -// CHECK: %[[VAL_39:.*]] = urem i32 %[[VAL_38]], 100 -// CHECK: %[[VAL_40:.*]] = udiv i32 %[[VAL_35]], 30000 -// CHECK: %[[VAL_41:.*]] = icmp ult i32 %[[VAL_16]], 6000000 -// CHECK: br i1 %[[VAL_41]], label %[[VAL_42:.*]], label %[[VAL_2]] -// CHECK: wrapped_b.in_bounds-after: ; preds = %[[VAL_42]], %[[VAL_7]] -// CHECK: br label %[[VAL_1]], !llvm.loop -// CHECK: loop.loop_exit: ; preds = %[[VAL_1]] -// CHECK: ret void -// CHECK: wrapped_b.in_bounds-true: ; preds = %[[VAL_7]] -// CHECK: %[[VAL_43:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], ptr %[[VAL_44:.*]], i32 0, i32 %[[VAL_21]], i32 %[[VAL_22]], i32 %[[VAL_19]] -// CHECK: %[[VAL_45:.*]] = load float, ptr %[[VAL_43]], align 4, !invariant.load -// CHECK: %[[VAL_46:.*]] = getelementptr float, ptr %[[VAL_47:.*]], i32 %[[VAL_16]] -// CHECK: %[[VAL_48:.*]] = getelementptr inbounds float, ptr %[[VAL_46]], i32 0 -// CHECK: store float %[[VAL_45]], ptr %[[VAL_48]], align 4 -// CHECK: %[[VAL_49:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], ptr %[[VAL_44]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_25]] -// CHECK: %[[VAL_50:.*]] = load float, ptr %[[VAL_49]], align 4, !invariant.load -// CHECK: %[[VAL_51:.*]] = getelementptr float, ptr %[[VAL_47]], i32 %[[VAL_16]] -// CHECK: %[[VAL_52:.*]] = getelementptr inbounds float, ptr %[[VAL_51]], i32 1 -// CHECK: store float %[[VAL_50]], ptr %[[VAL_52]], align 4 -// CHECK: %[[VAL_53:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], ptr %[[VAL_44]], i32 0, i32 %[[VAL_33]], i32 %[[VAL_34]], i32 %[[VAL_31]] -// CHECK: %[[VAL_54:.*]] = load float, ptr %[[VAL_53]], align 4, !invariant.load -// CHECK: %[[VAL_55:.*]] = getelementptr float, ptr %[[VAL_47]], i32 %[[VAL_16]] -// CHECK: %[[VAL_56:.*]] = getelementptr inbounds float, ptr %[[VAL_55]], i32 2 -// CHECK: store float %[[VAL_54]], ptr %[[VAL_56]], align 4 -// CHECK: %[[VAL_57:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], ptr %[[VAL_44]], i32 0, i32 %[[VAL_39]], i32 %[[VAL_40]], i32 %[[VAL_37]] -// CHECK: %[[VAL_58:.*]] = load float, ptr %[[VAL_57]], align 4, !invariant.load -// CHECK: %[[VAL_59:.*]] = getelementptr float, ptr %[[VAL_47]], i32 %[[VAL_16]] -// CHECK: %[[VAL_60:.*]] = getelementptr inbounds float, ptr %[[VAL_59]], i32 3 -// CHECK: store float %[[VAL_58]], ptr %[[VAL_60]], align 4 -// CHECK: br label %[[VAL_2]] - -HloModule Test, is_scheduled=true - -fused_computation { - param_0 = f32[100,200,300]{2,1,0} parameter(0) - ROOT b.1 = f32[100,200,300]{2,0,1} copy(f32[100,200,300]{2,1,0} param_0) -} - -ENTRY main { - a = f32[100, 200, 300]{2,1,0} parameter(0) - ROOT wrapped_b = f32[100,200,300]{2,0,1} fusion(f32[100,200,300]{2,1,0} %a), kind=kLoop, calls=fused_computation -} diff --git a/third_party/xla/xla/service/gpu/tests/dynamic_shared_memory_test.cc b/third_party/xla/xla/service/gpu/tests/dynamic_shared_memory_test.cc index ff8c5e59d2ef41..b239ffea27ac77 100644 --- a/third_party/xla/xla/service/gpu/tests/dynamic_shared_memory_test.cc +++ b/third_party/xla/xla/service/gpu/tests/dynamic_shared_memory_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/asm_compiler.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/xla.pb.h" #include "tsl/platform/status.h" @@ -130,10 +131,10 @@ TEST(SharedMemoryUseTest, ArrayReversalWorks) { // memory with it, read it back inverting both axes, // copy the result back to the host and verify it. se::Platform* platform = - se::MultiPlatformManager::PlatformWithName("cuda").value(); + se::PlatformManager::PlatformWithName("cuda").value(); se::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); se::Stream stream(executor); - stream.Init(); + TF_CHECK_OK(stream.Initialize()); // Use 90% of the available shared memory to verify that a fractional // amount works as well, not only the full size. @@ -170,17 +171,19 @@ TEST(SharedMemoryUseTest, ArrayReversalWorks) { } } - stream.ThenMemcpy(&device_buffer, host_buffer.data(), buffer_size_bytes); + TF_CHECK_OK( + stream.Memcpy(&device_buffer, host_buffer.data(), buffer_size_bytes)); se::DeviceMemory dev_n_cols = executor->AllocateScalar(); - stream.ThenMemcpy(&dev_n_cols, &n_cols, sizeof(uint32_t)); + TF_CHECK_OK(stream.Memcpy(&dev_n_cols, &n_cols, sizeof(uint32_t))); se::DeviceMemory dev_n_rows = executor->AllocateScalar(); - stream.ThenMemcpy(&dev_n_rows, &n_rows, sizeof(uint32_t)); + TF_CHECK_OK(stream.Memcpy(&dev_n_rows, &n_rows, sizeof(uint32_t))); TF_CHECK_OK(stream.BlockHostUntilDone()); TF_CHECK_OK(ExecuteKernelOnStream( *kernel, {device_buffer, dev_n_cols, dev_n_rows}, {/*block_x_count=*/1, /*thread_x_count_per_block=*/n_cols}, &stream)); TF_CHECK_OK(stream.BlockHostUntilDone()); - stream.ThenMemcpy(host_buffer.data(), device_buffer, buffer_size_bytes); + TF_CHECK_OK( + stream.Memcpy(host_buffer.data(), device_buffer, buffer_size_bytes)); TF_CHECK_OK(stream.BlockHostUntilDone()); for (int row = 0; row < n_rows; ++row) { diff --git a/third_party/xla/xla/service/gpu/tests/element_wise_row_vectorization.hlo b/third_party/xla/xla/service/gpu/tests/element_wise_row_vectorization.hlo index a1dd68b4e89652..b49e155da0a685 100644 --- a/third_party/xla/xla/service/gpu/tests/element_wise_row_vectorization.hlo +++ b/third_party/xla/xla/service/gpu/tests/element_wise_row_vectorization.hlo @@ -310,5 +310,8 @@ ENTRY computation { ROOT %fusion.9 = f16[5000,65,65,32] fusion(p0, zero), kind=kLoop, calls=%fused_computation.1 } -// Check that we emit vectorized read. -// CHECK: ld.global.nc.v4.f32 +// Our codegen can't emit a vectorized load here, but it can emit a vectorized +// store. +// CHECK-LABEL: .visible .entry fusion_9 +// CHECK-COUNT-4: ld.global.nc.u16 +// CHECK: st.global.v4.b16 diff --git a/third_party/xla/xla/service/gpu/tests/float_conversions_test.cc b/third_party/xla/xla/service/gpu/tests/float_conversions_test.cc new file mode 100644 index 00000000000000..502c8c223afc16 --- /dev/null +++ b/third_party/xla/xla/service/gpu/tests/float_conversions_test.cc @@ -0,0 +1,199 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/strings/string_view.h" +#include "xla/error_spec.h" +#include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/tests/test_utils.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace gpu { + +class FloatConversionTest : public GpuCodegenTest {}; + +TEST_F(FloatConversionTest, F8E5M2ToF16) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f8e5m2[] parameter(0) + ROOT %c = f16[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F8E4M3FNToF16) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f8e4m3fn[] parameter(0) + ROOT %c = f16[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F8E4M3B11FNUZToF16) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f8e4m3b11fnuz[] parameter(0) + ROOT %c = f16[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F8E5M2FNUZToF16) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f8e5m2fnuz[] parameter(0) + ROOT %c = f16[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F8E4M3FNUZToF16) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f8e4m3fnuz[] parameter(0) + ROOT %c = f16[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, BF16ToF32) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = bf16[] parameter(0) + ROOT %c = f32[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F16ToF32) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f16[] parameter(0) + ROOT %c = f32[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F64ToF32) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f64[] parameter(0) + ROOT %c = f32[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F16ToF8E5M2) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f16[] parameter(0) + ROOT %c = f8e5m2[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F16ToF8E4M3FN) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f16[] parameter(0) + ROOT %c = f8e4m3fn[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F16ToF8E4M3B11FNUZ) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f16[] parameter(0) + ROOT %c = f8e4m3b11fnuz[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F16ToF8E5M2FNUZ) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f16[] parameter(0) + ROOT %c = f8e5m2fnuz[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F16ToF8E4M3FNUZ) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f16[] parameter(0) + ROOT %c = f8e4m3fnuz[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F32ToBF16) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f32[] parameter(0) + ROOT %c = bf16[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F32ToF16) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f32[] parameter(0) + ROOT %c = f16[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F32ToF64) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f32[] parameter(0) + ROOT %c = f64[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F32ToPred) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + iota = f32[1000] iota(), iota_dimension=0 + c500 = f32[] constant(500) + c500_b = f32[1000] broadcast(c500), dimensions={} + sub = f32[1000] subtract(iota, c500_b) + ROOT c = pred[1000] convert(sub) + })", + ErrorSpec{1e-5, 1e-5})); + + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + n = f32[] constant(nan) + ROOT c = pred[] convert(n) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F32ToS8) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + iota = f32[1000] iota(), iota_dimension=0 + c500 = f32[] constant(500) + c500_b = f32[1000] broadcast(c500), dimensions={} + sub = f32[1000] subtract(iota, c500_b) + ROOT c = s8[1000] convert(sub) + })", + ErrorSpec{1e-5, 1e-5})); + + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + n = f32[] constant(nan) + ROOT c = s8[] convert(n) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, BF16ToS16IsBroken) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + iota = u16[65536] iota(), iota_dimension=0 + bc = bf16[65536] bitcast-convert(iota) + ROOT c = s16[65536] convert(bc) + })", + ErrorSpec{1e-5, 1e-5})); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/tests/fused_scatter.hlo b/third_party/xla/xla/service/gpu/tests/fused_scatter.hlo index b287e2fad69129..9a30436ebfa38c 100644 --- a/third_party/xla/xla/service/gpu/tests/fused_scatter.hlo +++ b/third_party/xla/xla/service/gpu/tests/fused_scatter.hlo @@ -2,92 +2,7 @@ // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py -// CHECK-LABEL: entry: -// CHECK-PTX: %[[VAL_0:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x -// CHECK-GCN: %[[VAL_0:.*]] = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK-PTX: %[[VAL_1:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x -// CHECK-GCN: %[[VAL_1:.*]] = call i32 @llvm.amdgcn.workitem.id.x -// CHECK: %[[VAL_2:.*]] = mul nuw nsw i32 %[[VAL_0]], 2 -// CHECK: %[[VAL_3:.*]] = add nuw nsw i32 %[[VAL_2]], %[[VAL_1]] -// CHECK: %[[VAL_4:.*]] = icmp ult i32 %[[VAL_3]], 2 -// CHECK: call void @llvm.assume(i1 %[[VAL_4]]) -// CHECK: %[[VAL_5:.*]] = add nuw nsw i32 %[[VAL_3]], 0 -// CHECK: %[[VAL_6:.*]] = udiv i32 %[[VAL_5]], 1 -// CHECK: %[[VAL_7:.*]] = icmp ult i32 %[[VAL_3]], 2 -// CHECK: br i1 %[[VAL_7]], label %[[VAL_8:.*]], label %[[VAL_9:.*]] -// CHECK: wrapped_indices.in_bounds-after: ; preds = %[[VAL_8]], %[[VAL_10:.*]] -// CHECK: ret void -// CHECK: wrapped_indices.in_bounds-true: ; preds = %[[VAL_10]] -// CHECK: %[[VAL_11:.*]] = getelementptr i32, ptr %[[VAL_12:.*]], i32 %[[VAL_3]] -// CHECK: %[[VAL_13:.*]] = getelementptr inbounds i32, ptr %[[VAL_11]], i32 0 -// CHECK: %[[VAL_14:.*]] = load i32, ptr %[[VAL_13]], align 4, !invariant.load -// CHECK: %[[VAL_15:.*]] = getelementptr i32, ptr %[[VAL_12]], i32 %[[VAL_3]] -// CHECK: %[[VAL_16:.*]] = getelementptr inbounds i32, ptr %[[VAL_15]], i32 0 -// CHECK: %[[VAL_17:.*]] = load i32, ptr %[[VAL_16]], align 4, !invariant.load -// CHECK: %[[VAL_18:.*]] = add i32 %[[VAL_14]], %[[VAL_17]] -// CHECK: %[[VAL_19:.*]] = getelementptr i32, ptr %[[VAL_20:.*]], i32 %[[VAL_3]] -// CHECK: %[[VAL_21:.*]] = getelementptr inbounds i32, ptr %[[VAL_19]], i32 0 -// CHECK: store i32 %[[VAL_18]], ptr %[[VAL_21]], align 4 -// CHECK: br label %[[VAL_9]] -// CHECK: entry: -// CHECK-PTX: %[[VAL_22:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x -// CHECK-GCN: %[[VAL_22:.*]] = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK-PTX: %[[VAL_23:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x -// CHECK-GCN: %[[VAL_23:.*]] = call i32 @llvm.amdgcn.workitem.id.x -// CHECK: %[[VAL_24:.*]] = mul nuw nsw i32 %[[VAL_22]], 6 -// CHECK: %[[VAL_25:.*]] = add nuw nsw i32 %[[VAL_24]], %[[VAL_23]] -// CHECK: %[[VAL_26:.*]] = icmp ult i32 %[[VAL_25]], 6 -// CHECK: call void @llvm.assume(i1 %[[VAL_26]]) -// CHECK: %[[VAL_27:.*]] = add nuw nsw i32 %[[VAL_25]], 0 -// CHECK: %[[VAL_28:.*]] = udiv i32 %[[VAL_27]], 1 -// CHECK: %[[VAL_29:.*]] = urem i32 %[[VAL_28]], 3 -// CHECK: %[[VAL_30:.*]] = udiv i32 %[[VAL_27]], 3 -// CHECK: %[[VAL_31:.*]] = icmp ult i32 %[[VAL_25]], 6 -// CHECK: br i1 %[[VAL_31]], label %[[VAL_32:.*]], label %[[VAL_33:.*]] -// CHECK: wrapped_updates.in_bounds-after: ; preds = %[[VAL_32]], %[[VAL_34:.*]] -// CHECK: ret void -// CHECK: wrapped_updates.in_bounds-true: ; preds = %[[VAL_34]] -// CHECK: %[[VAL_35:.*]] = getelementptr i32, ptr %[[VAL_36:.*]], i32 %[[VAL_25]] -// CHECK: %[[VAL_37:.*]] = getelementptr inbounds i32, ptr %[[VAL_35]], i32 0 -// CHECK: %[[VAL_38:.*]] = load i32, ptr %[[VAL_37]], align 4, !invariant.load -// CHECK: %[[VAL_39:.*]] = getelementptr i32, ptr %[[VAL_36]], i32 %[[VAL_25]] -// CHECK: %[[VAL_40:.*]] = getelementptr inbounds i32, ptr %[[VAL_39]], i32 0 -// CHECK: %[[VAL_41:.*]] = load i32, ptr %[[VAL_40]], align 4, !invariant.load -// CHECK: %[[VAL_42:.*]] = add i32 %[[VAL_38]], %[[VAL_41]] -// CHECK: %[[VAL_43:.*]] = getelementptr i32, ptr %[[VAL_44:.*]], i32 %[[VAL_25]] -// CHECK: %[[VAL_45:.*]] = getelementptr inbounds i32, ptr %[[VAL_43]], i32 0 -// CHECK: store i32 %[[VAL_42]], ptr %[[VAL_45]], align 4 -// CHECK: br label %[[VAL_33]] -// CHECK: entry: -// CHECK-PTX: %[[VAL_46:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x -// CHECK-GCN: %[[VAL_46:.*]] = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK-PTX: %[[VAL_47:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x -// CHECK-GCN: %[[VAL_47:.*]] = call i32 @llvm.amdgcn.workitem.id.x -// CHECK: %[[VAL_48:.*]] = mul nuw nsw i32 %[[VAL_46]], 9 -// CHECK: %[[VAL_49:.*]] = add nuw nsw i32 %[[VAL_48]], %[[VAL_47]] -// CHECK: %[[VAL_50:.*]] = icmp ult i32 %[[VAL_49]], 9 -// CHECK: call void @llvm.assume(i1 %[[VAL_50]]) -// CHECK: %[[VAL_51:.*]] = add nuw nsw i32 %[[VAL_49]], 0 -// CHECK: %[[VAL_52:.*]] = udiv i32 %[[VAL_51]], 1 -// CHECK: %[[VAL_53:.*]] = urem i32 %[[VAL_52]], 3 -// CHECK: %[[VAL_54:.*]] = udiv i32 %[[VAL_51]], 3 -// CHECK: %[[VAL_55:.*]] = icmp ult i32 %[[VAL_49]], 9 -// CHECK: br i1 %[[VAL_55]], label %[[VAL_56:.*]], label %[[VAL_57:.*]] -// CHECK: wrapped_operand.in_bounds-after: ; preds = %[[VAL_56]], %[[VAL_58:.*]] -// CHECK: ret void -// CHECK: wrapped_operand.in_bounds-true: ; preds = %[[VAL_58]] -// CHECK: %[[VAL_59:.*]] = getelementptr i32, ptr %[[VAL_60:.*]], i32 %[[VAL_49]] -// CHECK: %[[VAL_61:.*]] = getelementptr inbounds i32, ptr %[[VAL_59]], i32 0 -// CHECK: %[[VAL_62:.*]] = load i32, ptr %[[VAL_61]], align 4, !invariant.load -// CHECK: %[[VAL_63:.*]] = getelementptr i32, ptr %[[VAL_60]], i32 %[[VAL_49]] -// CHECK: %[[VAL_64:.*]] = getelementptr inbounds i32, ptr %[[VAL_63]], i32 0 -// CHECK: %[[VAL_65:.*]] = load i32, ptr %[[VAL_64]], align 4, !invariant.load -// CHECK: %[[VAL_66:.*]] = add i32 %[[VAL_62]], %[[VAL_65]] -// CHECK: %[[VAL_67:.*]] = getelementptr i32, ptr %[[VAL_68:.*]], i32 %[[VAL_49]] -// CHECK: %[[VAL_69:.*]] = getelementptr inbounds i32, ptr %[[VAL_67]], i32 0 -// CHECK: store i32 %[[VAL_66]], ptr %[[VAL_69]], align 4 -// CHECK: br label %[[VAL_57]] -// CHECK: entry: +// CHECK: define void @wrapped_scatter // CHECK: %[[VAL_70:.*]] = alloca i32, align 4 // CHECK-PTX: %[[VAL_71:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x // CHECK-GCN: %[[VAL_71:.*]] = call i32 @llvm.amdgcn.workgroup.id.x diff --git a/third_party/xla/xla/service/gpu/tests/fused_slice_different_operands.hlo b/third_party/xla/xla/service/gpu/tests/fused_slice_different_operands.hlo deleted file mode 100644 index c10dfd2395efcb..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/fused_slice_different_operands.hlo +++ /dev/null @@ -1,98 +0,0 @@ -// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s - -// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py - -// The script is designed to make adding checks to -// a test case fast, it is *not* designed to be authoritative -// about what constitutes a good test! The CHECK should be -// minimized and named to reflect the test intent. - - -// CHECK-LABEL: entry: -// CHECK-PTX: %[[VAL_0:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x -// CHECK-GCN: %[[VAL_0:.*]] = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK-PTX: %[[VAL_1:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x -// CHECK-GCN: %[[VAL_1:.*]] = call i32 @llvm.amdgcn.workitem.id.x -// CHECK-PTX: %[[VAL_2:.*]] = mul nuw nsw i32 %[[VAL_0]], 128 -// CHECK-GCN: %[[VAL_2:.*]] = mul nuw nsw i32 %[[VAL_0]], 256 -// CHECK: %[[VAL_3:.*]] = add nuw nsw i32 %[[VAL_2]], %[[VAL_1]] -// CHECK: %[[VAL_4:.*]] = icmp ult i32 %[[VAL_3]], 1024 -// CHECK: call void @llvm.assume(i1 %[[VAL_4]]) -// CHECK: %[[VAL_5:.*]] = add nuw nsw i32 %[[VAL_3]], 0 -// CHECK: %[[VAL_6:.*]] = udiv i32 %[[VAL_5]], 1 -// CHECK: %[[VAL_7:.*]] = icmp ult i32 %[[VAL_3]], 1024 -// CHECK: br i1 %[[VAL_7]], label %[[VAL_8:.*]], label %[[VAL_9:.*]] -// CHECK: fusion.in_bounds-after: ; preds = %[[VAL_10:.*]], %[[VAL_11:.*]] -// CHECK: ret void -// CHECK: fusion.in_bounds-true: ; preds = %[[VAL_11]] -// CHECK: %[[VAL_12:.*]] = add i32 %[[VAL_6]], 0 -// CHECK: br label %[[VAL_13:.*]] -// CHECK: concat_index_from_operand_id0: ; preds = %[[VAL_14:.*]] -// CHECK: %[[VAL_15:.*]] = phi i32 [ 0, %[[VAL_14]] ] -// CHECK: %[[VAL_16:.*]] = sub nsw i32 %[[VAL_12]], %[[VAL_15]] -// CHECK: %[[VAL_17:.*]] = getelementptr inbounds [1024 x half], ptr %[[VAL_18:.*]], i32 0, i32 %[[VAL_16]] -// CHECK: %[[VAL_19:.*]] = load half, ptr %[[VAL_17]], align 2, !invariant.load -// CHECK: %[[VAL_20:.*]] = getelementptr inbounds [1024 x half], ptr %[[VAL_21:.*]], i32 0, i32 %[[VAL_16]] -// CHECK: %[[VAL_22:.*]] = load half, ptr %[[VAL_20]], align 2, !invariant.load -// CHECK: %[[VAL_23:.*]] = fmul half %[[VAL_19]], %[[VAL_22]] -// CHECK: br label %[[VAL_10]] -// CHECK: concat_index_from_operand_id1: ; preds = %[[VAL_24:.*]] -// CHECK: %[[VAL_25:.*]] = phi i32 [ 1024, %[[VAL_24]] ] -// CHECK: %[[VAL_26:.*]] = sub nsw i32 %[[VAL_12]], %[[VAL_25]] -// CHECK: %[[VAL_27:.*]] = getelementptr inbounds [1023 x half], ptr %[[VAL_28:.*]], i32 0, i32 %[[VAL_26]] -// CHECK: %[[VAL_29:.*]] = load half, ptr %[[VAL_27]], align 2, !invariant.load -// CHECK: %[[VAL_30:.*]] = getelementptr inbounds [1023 x half], ptr %[[VAL_31:.*]], i32 0, i32 %[[VAL_26]] -// CHECK: %[[VAL_32:.*]] = load half, ptr %[[VAL_30]], align 2, !invariant.load -// CHECK: %[[VAL_33:.*]] = fadd half %[[VAL_29]], %[[VAL_32]] -// CHECK: br label %[[VAL_10]] -// CHECK: concatenate.pivot.1024.: ; preds = %[[VAL_8]] -// CHECK: %[[VAL_34:.*]] = icmp ult i32 %[[VAL_12]], 1024 -// CHECK: br i1 %[[VAL_34]], label %[[VAL_14]], label %[[VAL_24]] -// CHECK: concatenate.pivot.0.: ; preds = %[[VAL_13]] -// CHECK: br label %[[VAL_35:.*]] -// CHECK: concatenate.pivot.1024.1: ; preds = %[[VAL_13]] -// CHECK: br label %[[VAL_36:.*]] -// CHECK: concat.1.merge: ; preds = %[[VAL_36]], %[[VAL_35]] -// CHECK: %[[VAL_37:.*]] = phi half [ %[[VAL_23]], %[[VAL_35]] ], [ %[[VAL_33]], %[[VAL_36]] ] -// CHECK: %[[VAL_38:.*]] = insertvalue { half, half } undef, half %[[VAL_37]], 0 -// CHECK: %[[VAL_39:.*]] = add i32 %[[VAL_6]], 0 -// CHECK: %[[VAL_40:.*]] = getelementptr inbounds [1024 x half], ptr %[[VAL_18]], i32 0, i32 %[[VAL_39]] -// CHECK: %[[VAL_41:.*]] = load half, ptr %[[VAL_40]], align 2, !invariant.load -// CHECK: %[[VAL_42:.*]] = getelementptr inbounds [1024 x half], ptr %[[VAL_21]], i32 0, i32 %[[VAL_39]] -// CHECK: %[[VAL_43:.*]] = load half, ptr %[[VAL_42]], align 2, !invariant.load -// CHECK: %[[VAL_44:.*]] = fmul half %[[VAL_41]], %[[VAL_43]] -// CHECK: %[[VAL_45:.*]] = insertvalue { half, half } %[[VAL_38]], half %[[VAL_44]], 1 -// CHECK: %[[VAL_46:.*]] = extractvalue { half, half } %[[VAL_45]], 0 -// CHECK: %[[VAL_47:.*]] = getelementptr half, ptr %[[VAL_48:.*]], i32 %[[VAL_3]] -// CHECK: %[[VAL_49:.*]] = getelementptr inbounds half, ptr %[[VAL_47]], i32 0 -// CHECK: store half %[[VAL_46]], ptr %[[VAL_49]], align 2 -// CHECK: %[[VAL_50:.*]] = extractvalue { half, half } %[[VAL_45]], 1 -// CHECK: %[[VAL_51:.*]] = getelementptr half, ptr %[[VAL_52:.*]], i32 %[[VAL_3]] -// CHECK: %[[VAL_53:.*]] = getelementptr inbounds half, ptr %[[VAL_51]], i32 0 -// CHECK: store half %[[VAL_50]], ptr %[[VAL_53]], align 2 -// CHECK: br label %[[VAL_9]] - -HloModule input_fusion_with_a_tuple_of_slices, is_scheduled=true - -fused_computation { - arg.1 = f16[1024]{0} parameter(0) - arg.2 = f16[1024]{0} parameter(1) - arg.3 = f16[1023]{0} parameter(2) - arg.4 = f16[1023]{0} parameter(3) - mul.1 = f16[1024]{0} multiply(arg.1, arg.2) - add.1 = f16[1023]{0} add(arg.3, arg.4) - concat.1 = f16[2047]{0} concatenate(mul.1, add.1), dimensions={0} - slice.1 = f16[1024]{0} slice(concat.1), slice={[0:1024]} - slice.2 = f16[1024]{0} slice(mul.1), slice={[0:1024]} - ROOT tuple.1 = (f16[1024]{0}, f16[1024]{0}) tuple(slice.1, slice.2) -} - -ENTRY kernel_entry { - arg.1 = f16[1024]{0} parameter(0) - arg.2 = f16[1024]{0} parameter(1) - arg.3 = f16[1023]{0} parameter(2) - arg.4 = f16[1023]{0} parameter(3) - ROOT fusion = (f16[1024]{0}, f16[1024]{0}) - fusion(arg.1, arg.2, arg.3, arg.4), kind=kLoop, calls=fused_computation -} - diff --git a/third_party/xla/xla/service/gpu/tests/fusion.hlo b/third_party/xla/xla/service/gpu/tests/fusion.hlo deleted file mode 100644 index de3058b0bcb7d5..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/fusion.hlo +++ /dev/null @@ -1,290 +0,0 @@ -// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s - -HloModule TestModule, is_scheduled=true - -// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py - -// CHECK-LABEL: entry: -// CHECK-PTX: %[[VAL_0:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x -// CHECK-GCN: %[[VAL_0:.*]] = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK-PTX: %[[VAL_1:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x -// CHECK-GCN: %[[VAL_1:.*]] = call i32 @llvm.amdgcn.workitem.id.x -// CHECK-PTX: %[[VAL_2:.*]] = mul nuw nsw i32 %[[VAL_0]], 128 -// CHECK-GCN: %[[VAL_2:.*]] = mul nuw nsw i32 %[[VAL_0]], 256 -// CHECK: %[[VAL_3:.*]] = add nuw nsw i32 %[[VAL_2]], %[[VAL_1]] -// CHECK: %[[VAL_4:.*]] = icmp ult i32 %[[VAL_3]], 25690112 -// CHECK: call void @llvm.assume(i1 %[[VAL_4]]) -// CHECK: %[[VAL_5:.*]] = mul nuw nsw i32 %[[VAL_3]], 4 -// CHECK: %[[VAL_6:.*]] = add nuw nsw i32 %[[VAL_5]], 0 -// CHECK: %[[VAL_7:.*]] = udiv i32 %[[VAL_6]], 1 -// CHECK: %[[VAL_8:.*]] = urem i32 %[[VAL_7]], 64 -// CHECK: %[[VAL_9:.*]] = udiv i32 %[[VAL_6]], 64 -// CHECK: %[[VAL_10:.*]] = urem i32 %[[VAL_9]], 112 -// CHECK: %[[VAL_11:.*]] = udiv i32 %[[VAL_6]], 7168 -// CHECK: %[[VAL_12:.*]] = urem i32 %[[VAL_11]], 112 -// CHECK: %[[VAL_13:.*]] = udiv i32 %[[VAL_6]], 802816 -// CHECK: %[[VAL_14:.*]] = add nuw nsw i32 %[[VAL_5]], 1 -// CHECK: %[[VAL_15:.*]] = udiv i32 %[[VAL_14]], 1 -// CHECK: %[[VAL_16:.*]] = urem i32 %[[VAL_15]], 64 -// CHECK: %[[VAL_17:.*]] = udiv i32 %[[VAL_14]], 64 -// CHECK: %[[VAL_18:.*]] = urem i32 %[[VAL_17]], 112 -// CHECK: %[[VAL_19:.*]] = udiv i32 %[[VAL_14]], 7168 -// CHECK: %[[VAL_20:.*]] = urem i32 %[[VAL_19]], 112 -// CHECK: %[[VAL_21:.*]] = udiv i32 %[[VAL_14]], 802816 -// CHECK: %[[VAL_22:.*]] = add nuw nsw i32 %[[VAL_5]], 2 -// CHECK: %[[VAL_23:.*]] = udiv i32 %[[VAL_22]], 1 -// CHECK: %[[VAL_24:.*]] = urem i32 %[[VAL_23]], 64 -// CHECK: %[[VAL_25:.*]] = udiv i32 %[[VAL_22]], 64 -// CHECK: %[[VAL_26:.*]] = urem i32 %[[VAL_25]], 112 -// CHECK: %[[VAL_27:.*]] = udiv i32 %[[VAL_22]], 7168 -// CHECK: %[[VAL_28:.*]] = urem i32 %[[VAL_27]], 112 -// CHECK: %[[VAL_29:.*]] = udiv i32 %[[VAL_22]], 802816 -// CHECK: %[[VAL_30:.*]] = add nuw nsw i32 %[[VAL_5]], 3 -// CHECK: %[[VAL_31:.*]] = udiv i32 %[[VAL_30]], 1 -// CHECK: %[[VAL_32:.*]] = urem i32 %[[VAL_31]], 64 -// CHECK: %[[VAL_33:.*]] = udiv i32 %[[VAL_30]], 64 -// CHECK: %[[VAL_34:.*]] = urem i32 %[[VAL_33]], 112 -// CHECK: %[[VAL_35:.*]] = udiv i32 %[[VAL_30]], 7168 -// CHECK: %[[VAL_36:.*]] = urem i32 %[[VAL_35]], 112 -// CHECK: %[[VAL_37:.*]] = udiv i32 %[[VAL_30]], 802816 -// CHECK: %[[VAL_38:.*]] = icmp ult i32 %[[VAL_5]], 102760448 -// CHECK: br i1 %[[VAL_38]], label %[[VAL_39:.*]], label %[[VAL_40:.*]] -// CHECK: fusion.1.in_bounds-after: ; preds = %[[VAL_39]], %[[VAL_41:.*]] -// CHECK: ret void -// CHECK: fusion.1.in_bounds-true: ; preds = %[[VAL_41]] -// CHECK: %[[VAL_42:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_43:.*]], i32 0, i32 %[[VAL_8]] -// CHECK: %[[VAL_44:.*]] = load float, ptr %[[VAL_42]], align 4, !invariant.load -// CHECK: %[[VAL_45:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_46:.*]], i32 0, i32 %[[VAL_8]] -// CHECK: %[[VAL_47:.*]] = load float, ptr %[[VAL_45]], align 4, !invariant.load -// CHECK: %[[VAL_48:.*]] = fmul float %[[VAL_44]], %[[VAL_47]] -// CHECK: %[[VAL_49:.*]] = load float, ptr @0, align 4 -// CHECK: %[[VAL_50:.*]] = fmul float %[[VAL_48]], %[[VAL_49]] -// CHECK: %[[VAL_51:.*]] = getelementptr half, ptr %[[VAL_52:.*]], i32 %[[VAL_5]] -// CHECK: %[[VAL_53:.*]] = getelementptr inbounds half, ptr %[[VAL_51]], i32 0 -// CHECK: %[[VAL_54:.*]] = load half, ptr %[[VAL_53]], align 2, !invariant.load -// CHECK: %[[VAL_55:.*]] = load half, ptr @2, align 2 -// CHECK: %[[VAL_56:.*]] = fcmp ogt half %[[VAL_54]], %[[VAL_55]] -// CHECK: %[[VAL_57:.*]] = zext i1 %[[VAL_56]] to i8 -// CHECK: %[[VAL_58:.*]] = getelementptr half, ptr %[[VAL_59:.*]], i32 %[[VAL_5]] -// CHECK: %[[VAL_60:.*]] = getelementptr inbounds half, ptr %[[VAL_58]], i32 0 -// CHECK: %[[VAL_61:.*]] = load half, ptr %[[VAL_60]], align 2, !invariant.load -// CHECK: %[[VAL_62:.*]] = trunc i8 %[[VAL_57]] to i1 -// CHECK: %[[VAL_63:.*]] = select i1 %[[VAL_62]], half %[[VAL_61]], half %[[VAL_55]] -// CHECK: %[[VAL_64:.*]] = fpext half %[[VAL_63]] to float -// CHECK: %[[VAL_65:.*]] = load float, ptr @1, align 4 -// CHECK: %[[VAL_66:.*]] = fmul float %[[VAL_64]], %[[VAL_65]] -// CHECK: %[[VAL_67:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_68:.*]], i32 0, i32 %[[VAL_8]] -// CHECK: %[[VAL_69:.*]] = load float, ptr %[[VAL_67]], align 4, !invariant.load -// CHECK: %[[VAL_70:.*]] = fsub float %[[VAL_66]], %[[VAL_69]] -// CHECK: %[[VAL_71:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_72:.*]], i32 0, i32 %[[VAL_8]] -// CHECK: %[[VAL_73:.*]] = load float, ptr %[[VAL_71]], align 4, !invariant.load -// CHECK: %[[VAL_74:.*]] = getelementptr half, ptr %[[VAL_75:.*]], i32 %[[VAL_5]] -// CHECK: %[[VAL_76:.*]] = getelementptr inbounds half, ptr %[[VAL_74]], i32 0 -// CHECK: %[[VAL_77:.*]] = load half, ptr %[[VAL_76]], align 2, !invariant.load -// CHECK: %[[VAL_78:.*]] = fpext half %[[VAL_77]] to float -// CHECK: %[[VAL_79:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_80:.*]], i32 0, i32 %[[VAL_8]] -// CHECK: %[[VAL_81:.*]] = load float, ptr %[[VAL_79]], align 4, !invariant.load -// CHECK: %[[VAL_82:.*]] = load float, ptr @0, align 4 -// CHECK: %[[VAL_83:.*]] = fmul float %[[VAL_81]], %[[VAL_82]] -// CHECK: %[[VAL_84:.*]] = fsub float %[[VAL_78]], %[[VAL_83]] -// CHECK: %[[VAL_85:.*]] = fmul float %[[VAL_73]], %[[VAL_84]] -// CHECK: %[[VAL_86:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_87:.*]], i32 0, i32 %[[VAL_8]] -// CHECK: %[[VAL_88:.*]] = load float, ptr %[[VAL_86]], align 4, !invariant.load -// CHECK: %[[VAL_89:.*]] = fdiv float %[[VAL_85]], %[[VAL_88]] -// CHECK: %[[VAL_90:.*]] = fsub float %[[VAL_70]], %[[VAL_89]] -// CHECK: %[[VAL_91:.*]] = fmul float %[[VAL_50]], %[[VAL_90]] -// CHECK: %[[VAL_92:.*]] = fptrunc float %[[VAL_91]] to half -// CHECK: %[[VAL_93:.*]] = getelementptr half, ptr %[[VAL_94:.*]], i32 %[[VAL_5]] -// CHECK: %[[VAL_95:.*]] = getelementptr inbounds half, ptr %[[VAL_93]], i32 0 -// CHECK: store half %[[VAL_92]], ptr %[[VAL_95]], align 2 -// CHECK: %[[VAL_96:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_43]], i32 0, i32 %[[VAL_16]] -// CHECK: %[[VAL_97:.*]] = load float, ptr %[[VAL_96]], align 4, !invariant.load -// CHECK: %[[VAL_98:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_46]], i32 0, i32 %[[VAL_16]] -// CHECK: %[[VAL_99:.*]] = load float, ptr %[[VAL_98]], align 4, !invariant.load -// CHECK: %[[VAL_100:.*]] = fmul float %[[VAL_97]], %[[VAL_99]] -// CHECK: %[[VAL_101:.*]] = load float, ptr @0, align 4 -// CHECK: %[[VAL_102:.*]] = fmul float %[[VAL_100]], %[[VAL_101]] -// CHECK: %[[VAL_103:.*]] = getelementptr half, ptr %[[VAL_52]], i32 %[[VAL_5]] -// CHECK: %[[VAL_104:.*]] = getelementptr inbounds half, ptr %[[VAL_103]], i32 1 -// CHECK: %[[VAL_105:.*]] = load half, ptr %[[VAL_104]], align 2, !invariant.load -// CHECK: %[[VAL_106:.*]] = load half, ptr @2, align 2 -// CHECK: %[[VAL_107:.*]] = fcmp ogt half %[[VAL_105]], %[[VAL_106]] -// CHECK: %[[VAL_108:.*]] = zext i1 %[[VAL_107]] to i8 -// CHECK: %[[VAL_109:.*]] = getelementptr half, ptr %[[VAL_59]], i32 %[[VAL_5]] -// CHECK: %[[VAL_110:.*]] = getelementptr inbounds half, ptr %[[VAL_109]], i32 1 -// CHECK: %[[VAL_111:.*]] = load half, ptr %[[VAL_110]], align 2, !invariant.load -// CHECK: %[[VAL_112:.*]] = trunc i8 %[[VAL_108]] to i1 -// CHECK: %[[VAL_113:.*]] = select i1 %[[VAL_112]], half %[[VAL_111]], half %[[VAL_106]] -// CHECK: %[[VAL_114:.*]] = fpext half %[[VAL_113]] to float -// CHECK: %[[VAL_115:.*]] = load float, ptr @1, align 4 -// CHECK: %[[VAL_116:.*]] = fmul float %[[VAL_114]], %[[VAL_115]] -// CHECK: %[[VAL_117:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_68]], i32 0, i32 %[[VAL_16]] -// CHECK: %[[VAL_118:.*]] = load float, ptr %[[VAL_117]], align 4, !invariant.load -// CHECK: %[[VAL_119:.*]] = fsub float %[[VAL_116]], %[[VAL_118]] -// CHECK: %[[VAL_120:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_72]], i32 0, i32 %[[VAL_16]] -// CHECK: %[[VAL_121:.*]] = load float, ptr %[[VAL_120]], align 4, !invariant.load -// CHECK: %[[VAL_122:.*]] = getelementptr half, ptr %[[VAL_75]], i32 %[[VAL_5]] -// CHECK: %[[VAL_123:.*]] = getelementptr inbounds half, ptr %[[VAL_122]], i32 1 -// CHECK: %[[VAL_124:.*]] = load half, ptr %[[VAL_123]], align 2, !invariant.load -// CHECK: %[[VAL_125:.*]] = fpext half %[[VAL_124]] to float -// CHECK: %[[VAL_126:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_80]], i32 0, i32 %[[VAL_16]] -// CHECK: %[[VAL_127:.*]] = load float, ptr %[[VAL_126]], align 4, !invariant.load -// CHECK: %[[VAL_128:.*]] = load float, ptr @0, align 4 -// CHECK: %[[VAL_129:.*]] = fmul float %[[VAL_127]], %[[VAL_128]] -// CHECK: %[[VAL_130:.*]] = fsub float %[[VAL_125]], %[[VAL_129]] -// CHECK: %[[VAL_131:.*]] = fmul float %[[VAL_121]], %[[VAL_130]] -// CHECK: %[[VAL_132:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_87]], i32 0, i32 %[[VAL_16]] -// CHECK: %[[VAL_133:.*]] = load float, ptr %[[VAL_132]], align 4, !invariant.load -// CHECK: %[[VAL_134:.*]] = fdiv float %[[VAL_131]], %[[VAL_133]] -// CHECK: %[[VAL_135:.*]] = fsub float %[[VAL_119]], %[[VAL_134]] -// CHECK: %[[VAL_136:.*]] = fmul float %[[VAL_102]], %[[VAL_135]] -// CHECK: %[[VAL_137:.*]] = fptrunc float %[[VAL_136]] to half -// CHECK: %[[VAL_138:.*]] = getelementptr half, ptr %[[VAL_94]], i32 %[[VAL_5]] -// CHECK: %[[VAL_139:.*]] = getelementptr inbounds half, ptr %[[VAL_138]], i32 1 -// CHECK: store half %[[VAL_137]], ptr %[[VAL_139]], align 2 -// CHECK: %[[VAL_140:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_43]], i32 0, i32 %[[VAL_24]] -// CHECK: %[[VAL_141:.*]] = load float, ptr %[[VAL_140]], align 4, !invariant.load -// CHECK: %[[VAL_142:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_46]], i32 0, i32 %[[VAL_24]] -// CHECK: %[[VAL_143:.*]] = load float, ptr %[[VAL_142]], align 4, !invariant.load -// CHECK: %[[VAL_144:.*]] = fmul float %[[VAL_141]], %[[VAL_143]] -// CHECK: %[[VAL_145:.*]] = load float, ptr @0, align 4 -// CHECK: %[[VAL_146:.*]] = fmul float %[[VAL_144]], %[[VAL_145]] -// CHECK: %[[VAL_147:.*]] = getelementptr half, ptr %[[VAL_52]], i32 %[[VAL_5]] -// CHECK: %[[VAL_148:.*]] = getelementptr inbounds half, ptr %[[VAL_147]], i32 2 -// CHECK: %[[VAL_149:.*]] = load half, ptr %[[VAL_148]], align 2, !invariant.load -// CHECK: %[[VAL_150:.*]] = load half, ptr @2, align 2 -// CHECK: %[[VAL_151:.*]] = fcmp ogt half %[[VAL_149]], %[[VAL_150]] -// CHECK: %[[VAL_152:.*]] = zext i1 %[[VAL_151]] to i8 -// CHECK: %[[VAL_153:.*]] = getelementptr half, ptr %[[VAL_59]], i32 %[[VAL_5]] -// CHECK: %[[VAL_154:.*]] = getelementptr inbounds half, ptr %[[VAL_153]], i32 2 -// CHECK: %[[VAL_155:.*]] = load half, ptr %[[VAL_154]], align 2, !invariant.load -// CHECK: %[[VAL_156:.*]] = trunc i8 %[[VAL_152]] to i1 -// CHECK: %[[VAL_157:.*]] = select i1 %[[VAL_156]], half %[[VAL_155]], half %[[VAL_150]] -// CHECK: %[[VAL_158:.*]] = fpext half %[[VAL_157]] to float -// CHECK: %[[VAL_159:.*]] = load float, ptr @1, align 4 -// CHECK: %[[VAL_160:.*]] = fmul float %[[VAL_158]], %[[VAL_159]] -// CHECK: %[[VAL_161:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_68]], i32 0, i32 %[[VAL_24]] -// CHECK: %[[VAL_162:.*]] = load float, ptr %[[VAL_161]], align 4, !invariant.load -// CHECK: %[[VAL_163:.*]] = fsub float %[[VAL_160]], %[[VAL_162]] -// CHECK: %[[VAL_164:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_72]], i32 0, i32 %[[VAL_24]] -// CHECK: %[[VAL_165:.*]] = load float, ptr %[[VAL_164]], align 4, !invariant.load -// CHECK: %[[VAL_166:.*]] = getelementptr half, ptr %[[VAL_75]], i32 %[[VAL_5]] -// CHECK: %[[VAL_167:.*]] = getelementptr inbounds half, ptr %[[VAL_166]], i32 2 -// CHECK: %[[VAL_168:.*]] = load half, ptr %[[VAL_167]], align 2, !invariant.load -// CHECK: %[[VAL_169:.*]] = fpext half %[[VAL_168]] to float -// CHECK: %[[VAL_170:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_80]], i32 0, i32 %[[VAL_24]] -// CHECK: %[[VAL_171:.*]] = load float, ptr %[[VAL_170]], align 4, !invariant.load -// CHECK: %[[VAL_172:.*]] = load float, ptr @0, align 4 -// CHECK: %[[VAL_173:.*]] = fmul float %[[VAL_171]], %[[VAL_172]] -// CHECK: %[[VAL_174:.*]] = fsub float %[[VAL_169]], %[[VAL_173]] -// CHECK: %[[VAL_175:.*]] = fmul float %[[VAL_165]], %[[VAL_174]] -// CHECK: %[[VAL_176:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_87]], i32 0, i32 %[[VAL_24]] -// CHECK: %[[VAL_177:.*]] = load float, ptr %[[VAL_176]], align 4, !invariant.load -// CHECK: %[[VAL_178:.*]] = fdiv float %[[VAL_175]], %[[VAL_177]] -// CHECK: %[[VAL_179:.*]] = fsub float %[[VAL_163]], %[[VAL_178]] -// CHECK: %[[VAL_180:.*]] = fmul float %[[VAL_146]], %[[VAL_179]] -// CHECK: %[[VAL_181:.*]] = fptrunc float %[[VAL_180]] to half -// CHECK: %[[VAL_182:.*]] = getelementptr half, ptr %[[VAL_94]], i32 %[[VAL_5]] -// CHECK: %[[VAL_183:.*]] = getelementptr inbounds half, ptr %[[VAL_182]], i32 2 -// CHECK: store half %[[VAL_181]], ptr %[[VAL_183]], align 2 -// CHECK: %[[VAL_184:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_43]], i32 0, i32 %[[VAL_32]] -// CHECK: %[[VAL_185:.*]] = load float, ptr %[[VAL_184]], align 4, !invariant.load -// CHECK: %[[VAL_186:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_46]], i32 0, i32 %[[VAL_32]] -// CHECK: %[[VAL_187:.*]] = load float, ptr %[[VAL_186]], align 4, !invariant.load -// CHECK: %[[VAL_188:.*]] = fmul float %[[VAL_185]], %[[VAL_187]] -// CHECK: %[[VAL_189:.*]] = load float, ptr @0, align 4 -// CHECK: %[[VAL_190:.*]] = fmul float %[[VAL_188]], %[[VAL_189]] -// CHECK: %[[VAL_191:.*]] = getelementptr half, ptr %[[VAL_52]], i32 %[[VAL_5]] -// CHECK: %[[VAL_192:.*]] = getelementptr inbounds half, ptr %[[VAL_191]], i32 3 -// CHECK: %[[VAL_193:.*]] = load half, ptr %[[VAL_192]], align 2, !invariant.load -// CHECK: %[[VAL_194:.*]] = load half, ptr @2, align 2 -// CHECK: %[[VAL_195:.*]] = fcmp ogt half %[[VAL_193]], %[[VAL_194]] -// CHECK: %[[VAL_196:.*]] = zext i1 %[[VAL_195]] to i8 -// CHECK: %[[VAL_197:.*]] = getelementptr half, ptr %[[VAL_59]], i32 %[[VAL_5]] -// CHECK: %[[VAL_198:.*]] = getelementptr inbounds half, ptr %[[VAL_197]], i32 3 -// CHECK: %[[VAL_199:.*]] = load half, ptr %[[VAL_198]], align 2, !invariant.load -// CHECK: %[[VAL_200:.*]] = trunc i8 %[[VAL_196]] to i1 -// CHECK: %[[VAL_201:.*]] = select i1 %[[VAL_200]], half %[[VAL_199]], half %[[VAL_194]] -// CHECK: %[[VAL_202:.*]] = fpext half %[[VAL_201]] to float -// CHECK: %[[VAL_203:.*]] = load float, ptr @1, align 4 -// CHECK: %[[VAL_204:.*]] = fmul float %[[VAL_202]], %[[VAL_203]] -// CHECK: %[[VAL_205:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_68]], i32 0, i32 %[[VAL_32]] -// CHECK: %[[VAL_206:.*]] = load float, ptr %[[VAL_205]], align 4, !invariant.load -// CHECK: %[[VAL_207:.*]] = fsub float %[[VAL_204]], %[[VAL_206]] -// CHECK: %[[VAL_208:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_72]], i32 0, i32 %[[VAL_32]] -// CHECK: %[[VAL_209:.*]] = load float, ptr %[[VAL_208]], align 4, !invariant.load -// CHECK: %[[VAL_210:.*]] = getelementptr half, ptr %[[VAL_75]], i32 %[[VAL_5]] -// CHECK: %[[VAL_211:.*]] = getelementptr inbounds half, ptr %[[VAL_210]], i32 3 -// CHECK: %[[VAL_212:.*]] = load half, ptr %[[VAL_211]], align 2, !invariant.load -// CHECK: %[[VAL_213:.*]] = fpext half %[[VAL_212]] to float -// CHECK: %[[VAL_214:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_80]], i32 0, i32 %[[VAL_32]] -// CHECK: %[[VAL_215:.*]] = load float, ptr %[[VAL_214]], align 4, !invariant.load -// CHECK: %[[VAL_216:.*]] = load float, ptr @0, align 4 -// CHECK: %[[VAL_217:.*]] = fmul float %[[VAL_215]], %[[VAL_216]] -// CHECK: %[[VAL_218:.*]] = fsub float %[[VAL_213]], %[[VAL_217]] -// CHECK: %[[VAL_219:.*]] = fmul float %[[VAL_209]], %[[VAL_218]] -// CHECK: %[[VAL_220:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_87]], i32 0, i32 %[[VAL_32]] -// CHECK: %[[VAL_221:.*]] = load float, ptr %[[VAL_220]], align 4, !invariant.load -// CHECK: %[[VAL_222:.*]] = fdiv float %[[VAL_219]], %[[VAL_221]] -// CHECK: %[[VAL_223:.*]] = fsub float %[[VAL_207]], %[[VAL_222]] -// CHECK: %[[VAL_224:.*]] = fmul float %[[VAL_190]], %[[VAL_223]] -// CHECK: %[[VAL_225:.*]] = fptrunc float %[[VAL_224]] to half -// CHECK: %[[VAL_226:.*]] = getelementptr half, ptr %[[VAL_94]], i32 %[[VAL_5]] -// CHECK: %[[VAL_227:.*]] = getelementptr inbounds half, ptr %[[VAL_226]], i32 3 -// CHECK: store half %[[VAL_225]], ptr %[[VAL_227]], align 2 -// CHECK: br label %[[VAL_40]] - - -%fused_computation.1 (param_0.5: f32[64], param_1.3088: f32[64], param_2.2116: f32[64], param_3.974: f32[64], param_4.1162: f32[64], param_5.893: f32[64], param_6.809: f16[128,64,112,112], param_7.770: f16[128,64,112,112], param_8.637: f16[128,64,112,112]) -> f16[128,64,112,112] { - %param_4.1162 = f32[64]{0} parameter(4) - %broadcast.2313 = f32[128,64,112,112]{1,3,2,0} broadcast(f32[64]{0} %param_4.1162), dimensions={1} - %param_3.974 = f32[64]{0} parameter(3) - %broadcast.1844 = f32[128,64,112,112]{1,3,2,0} broadcast(f32[64]{0} %param_3.974), dimensions={1} - %multiply.1049 = f32[128,64,112,112]{1,3,2,0} multiply(f32[128,64,112,112]{1,3,2,0} %broadcast.2313, f32[128,64,112,112]{1,3,2,0} %broadcast.1844) - %constant_1404 = f32[] constant(6.22807704e-07) - %broadcast.1843 = f32[128,64,112,112]{1,3,2,0} broadcast(f32[] %constant_1404), dimensions={} - %multiply.1048 = f32[128,64,112,112]{1,3,2,0} multiply(f32[128,64,112,112]{1,3,2,0} %multiply.1049, f32[128,64,112,112]{1,3,2,0} %broadcast.1843) - %param_8.637 = f16[128,64,112,112]{1,3,2,0} parameter(8) - %constant_3626 = f16[] constant(0) - %broadcast.4770 = f16[128,64,112,112]{1,3,2,0} broadcast(f16[] %constant_3626), dimensions={} - %compare.259 = pred[128,64,112,112]{1,3,2,0} compare(f16[128,64,112,112]{1,3,2,0} %param_8.637, f16[128,64,112,112]{1,3,2,0} %broadcast.4770), direction=GT - %param_7.770 = f16[128,64,112,112]{1,3,2,0} parameter(7) - %select.254 = f16[128,64,112,112]{1,3,2,0} select(pred[128,64,112,112]{1,3,2,0} %compare.259, f16[128,64,112,112]{1,3,2,0} %param_7.770, f16[128,64,112,112]{1,3,2,0} %broadcast.4770) - %convert.108 = f32[128,64,112,112]{1,3,2,0} convert(f16[128,64,112,112]{1,3,2,0} %select.254) - %constant_1390 = f32[] constant(1605632) - %broadcast.1841 = f32[128,64,112,112]{1,3,2,0} broadcast(f32[] %constant_1390), dimensions={} - %multiply.1046 = f32[128,64,112,112]{1,3,2,0} multiply(f32[128,64,112,112]{1,3,2,0} %convert.108, f32[128,64,112,112]{1,3,2,0} %broadcast.1841) - %param_2.2116 = f32[64]{0} parameter(2) - %broadcast.1840 = f32[128,64,112,112]{1,3,2,0} broadcast(f32[64]{0} %param_2.2116), dimensions={1} - %subtract.266 = f32[128,64,112,112]{1,3,2,0} subtract(f32[128,64,112,112]{1,3,2,0} %multiply.1046, f32[128,64,112,112]{1,3,2,0} %broadcast.1840) - %param_1.3088 = f32[64]{0} parameter(1) - %broadcast.1839 = f32[128,64,112,112]{1,3,2,0} broadcast(f32[64]{0} %param_1.3088), dimensions={1} - %param_6.809 = f16[128,64,112,112]{1,3,2,0} parameter(6) - %convert.644 = f32[128,64,112,112]{1,3,2,0} convert(f16[128,64,112,112]{1,3,2,0} %param_6.809) - %param_5.893 = f32[64]{0} parameter(5) - %broadcast.3388 = f32[64]{0} broadcast(f32[] %constant_1404), dimensions={} - %multiply.2336 = f32[64]{0} multiply(f32[64]{0} %param_5.893, f32[64]{0} %broadcast.3388) - %broadcast.3387 = f32[128,64,112,112]{1,3,2,0} broadcast(f32[64]{0} %multiply.2336), dimensions={1} - %subtract.591 = f32[128,64,112,112]{1,3,2,0} subtract(f32[128,64,112,112]{1,3,2,0} %convert.644, f32[128,64,112,112]{1,3,2,0} %broadcast.3387) - %multiply.1045 = f32[128,64,112,112]{1,3,2,0} multiply(f32[128,64,112,112]{1,3,2,0} %broadcast.1839, f32[128,64,112,112]{1,3,2,0} %subtract.591) - %param_0.5 = f32[64]{0} parameter(0) - %broadcast.1838 = f32[128,64,112,112]{1,3,2,0} broadcast(f32[64]{0} %param_0.5), dimensions={1} - %divide.212 = f32[128,64,112,112]{1,3,2,0} divide(f32[128,64,112,112]{1,3,2,0} %multiply.1045, f32[128,64,112,112]{1,3,2,0} %broadcast.1838) - %subtract.265 = f32[128,64,112,112]{1,3,2,0} subtract(f32[128,64,112,112]{1,3,2,0} %subtract.266, f32[128,64,112,112]{1,3,2,0} %divide.212) - %multiply.1044 = f32[128,64,112,112]{1,3,2,0} multiply(f32[128,64,112,112]{1,3,2,0} %multiply.1048, f32[128,64,112,112]{1,3,2,0} %subtract.265) - ROOT %convert.107 = f16[128,64,112,112]{1,3,2,0} convert(f32[128,64,112,112]{1,3,2,0} %multiply.1044) -} - -ENTRY main { - %get-tuple-element.1532 = f32[64]{0} parameter(0) - %get-tuple-element.876 = f32[64]{0} parameter(1) - %get-tuple-element.877 = f32[64]{0} parameter(2) - %get-tuple-element.1530 = f32[64]{0} parameter(3) - %arg112.113 = f32[64]{0} parameter(4) - %get-tuple-element.881 = f32[64]{0} parameter(5) - %get-tuple-element.872 = f16[128,64,112,112]{1,3,2,0} parameter(6) - %select-and-scatter.3626 = f16[128,64,112,112]{1,3,2,0} parameter(7) - %fusion.845 = f16[128,64,112,112]{1,3,2,0} parameter(8) - - ROOT %fusion.1 = f16[128,64,112,112]{1,3,2,0} fusion(f32[64]{0} %get-tuple-element.1532, f32[64]{0} %get-tuple-element.876, f32[64]{0} %get-tuple-element.877, f32[64]{0} %get-tuple-element.1530, f32[64]{0} %arg112.113, f32[64]{0} %get-tuple-element.881, f16[128,64,112,112]{1,3,2,0} %get-tuple-element.872, f16[128,64,112,112]{1,3,2,0} %select-and-scatter.3626, f16[128,64,112,112]{1,3,2,0} %fusion.845), kind=kLoop, calls=%fused_computation.1 -} diff --git a/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc b/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc index 930d2e2c5449e4..493f9f5b0e62c8 100644 --- a/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc @@ -38,6 +38,8 @@ limitations under the License. #if GOOGLE_CUDA #include "third_party/gpus/cuda/include/cuda.h" +#elif TENSORFLOW_USE_ROCM +#include "rocm/rocm_config.h" #endif namespace xla { @@ -985,10 +987,6 @@ ENTRY bf16gemm { } TEST_P(ParameterizedGemmRewriteTest, Int8Gemm) { - if (CudaOrRocmCheck(Switch::False, Switch::True)) { - GTEST_SKIP() << "DoBlasGemmWithAlgorithm is not yet implemented on ROCm"; - } - const char* hlo_text = R"( HloModule int8gemm @@ -1044,9 +1042,6 @@ ENTRY main.4 { } TEST_P(ParameterizedGemmRewriteTest, Int8GemmNoAlphaRewrite) { - if (CudaOrRocmCheck(Switch::False, Switch::True)) { - GTEST_SKIP() << "DoBlasGemmWithAlgorithm is not yet implemented on ROCm"; - } const char* hlo_text = R"( HloModule int8gemm @@ -1082,9 +1077,6 @@ ENTRY int8gemm { } TEST_P(ParameterizedGemmRewriteTest, Int8GemmNoBetaRewrite) { - if (CudaOrRocmCheck(Switch::False, Switch::True)) { - GTEST_SKIP() << "DoBlasGemmWithAlgorithm is not yet implemented on ROCm"; - } const char* hlo_text = R"( HloModule int8gemm @@ -1898,9 +1890,6 @@ ENTRY test { #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Test gemm matrix bias add fusion with mix type TEST_F(LegacyCublasGemmRewriteTest, MatrixBiasMixType) { - if (CudaOrRocmCheck(Switch::False, Switch::True)) { - GTEST_SKIP() << "DoBlasGemmWithAlgorithm is not yet implemented on ROCm"; - } std::vector> type_combinations = { {"f16", "f32"}, @@ -1939,9 +1928,6 @@ ENTRY test { // Test batch gemm matrix bias add fusion with mix type TEST_F(LegacyCublasGemmRewriteTest, MatrixBiasMixTypeBatched) { - if (CudaOrRocmCheck(Switch::False, Switch::True)) { - GTEST_SKIP() << "DoBlasGemmWithAlgorithm is not yet implemented on ROCm"; - } std::vector> type_combinations = { {"f16", "f32"}, @@ -3456,7 +3442,12 @@ ENTRY test { } TEST_F(CublasLtGemmRewriteTest, VectorBiasThenApproxGeluActivation) { - if (CudaOrRocmCheck(Switch::False, Switch::True)) { +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60000 + auto rocm_switch = Switch::False; // GELU is only available from ROCM 6.0 +#else + auto rocm_switch = Switch::True; +#endif + if (CudaOrRocmCheck(Switch::False, rocm_switch)) { GTEST_SKIP() << "TODO: Unsupported blas-lt epilogue on ROCM"; } const char* hlo_text = R"( diff --git a/third_party/xla/xla/service/gpu/tests/gpu_alignment_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_alignment_test.cc index e98a1767659c52..27e7a5925e2612 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_alignment_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_alignment_test.cc @@ -13,13 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include - -#include "xla/service/custom_call_target_registry.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" -#include "xla/service/llvm_ir/alias_analysis.h" -#include "xla/tests/filecheck.h" #include "tsl/platform/test.h" namespace xla { @@ -44,13 +38,10 @@ ENTRY main { } )"; - auto expected_ir = is_built_with_rocm_ ? R"( -CHECK: @fusion(ptr noalias align 128 dereferenceable(800) %arg0, ptr noalias align 16 dereferenceable(400) %arg1, ptr noalias align 128 dereferenceable(600) %arg2) -)" - : R"( -CHECK: define void @fusion(ptr noalias align 128 dereferenceable(800) %arg0, ptr noalias align 16 dereferenceable(400) %arg1, ptr noalias align 128 dereferenceable(600) %arg2) -)"; - CompileAndVerifyIr(hlo_string, expected_ir); + CompileAndVerifyIr( + hlo_string, + "CHECK: {{.*}}align 128 dereferenceable(800) %{{.*}}align 16 " + "dereferenceable(400) %{{.*}}align 128 dereferenceable(600) %"); } } // namespace diff --git a/third_party/xla/xla/service/gpu/tests/gpu_codegen_test.h b/third_party/xla/xla/service/gpu/tests/gpu_codegen_test.h index c1648bb873a29d..4be343add7dd5c 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_codegen_test.h +++ b/third_party/xla/xla/service/gpu/tests/gpu_codegen_test.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "xla/stream_executor/platform_manager.h" #include "xla/tests/llvm_irgen_test_base.h" #include "xla/tests/verified_hlo_module.h" @@ -30,7 +31,7 @@ class GpuCodegenTest : public LlvmIrGenTestBase { public: GpuCodegenTest() : is_built_with_rocm_( - se::MultiPlatformManager::PlatformWithName("ROCM").ok()) {} + se::PlatformManager::PlatformWithName("ROCM").ok()) {} protected: // Converts LLVM match to be platform-specific. diff --git a/third_party/xla/xla/service/gpu/tests/gpu_copy_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_copy_test.cc index 87e863c9123761..e97a15fad66520 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_copy_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_copy_test.cc @@ -52,5 +52,24 @@ TEST_F(GpuCopyTest, UseMemcpy) { /*match_optimized_ir=*/false); } +TEST_F(GpuCopyTest, CopyTranspose) { + const char* hlo_text = R"( + HloModule Test + + fused_computation { + param_0 = f32[100,200,300]{2,1,0} parameter(0) + ROOT b.1 = f32[100,200,300]{2,0,1} copy(f32[100,200,300]{2,1,0} param_0) + } + + ENTRY main { + a = f32[100, 200, 300]{2,1,0} parameter(0) + ROOT wrapped_b = f32[100,200,300]{2,0,1} fusion(f32[100,200,300]{2,1,0} %a), kind=kLoop, calls=fused_computation + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, + ParseAndReturnVerifiedModule(hlo_text)); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/tests/gpu_dyn_shape_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_dyn_shape_test.cc index 9bc2cbda244739..19d37adaadf482 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_dyn_shape_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_dyn_shape_test.cc @@ -41,8 +41,8 @@ TEST_F(GpuDynamicShapeTest, DynamicShapeR2) { CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-DAG: is_thread_0-true -; CHECK-DAG: x_padded{{(_1)?}}.in_dyn_bounds-true -; CHECK-DAG: x_padded{{(_1)?}}.in_bounds-true +; CHECK-DAG: x.padded{{.*}}.in_dyn_bounds-true +; CHECK-DAG: x.padded{{.*}}.in_bounds-true ; CHECK: %[[dyn_dim_size:.*]] = load i32, ptr ; CHECK: %[[dyn_element_total:.*]] = mul i32 1, %[[dyn_dim_size:.*]] ; CHECK: %[[linear_index:.*]] = add nuw nsw i32 diff --git a/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc index 5a880dd49a54bc..0bf8c7dce1e129 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc @@ -84,7 +84,7 @@ class MultiHeadedAttentionTest : public GpuCodegenTest { protected: DebugOptions GetDebugOptionsForTest() override { auto debug_options = HloTestBase::GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_xla_runtime_executable(true); + debug_options.set_xla_gpu_enable_xla_runtime_executable(false); return debug_options; } @@ -134,7 +134,7 @@ class MultiHeadedAttentionTest : public GpuCodegenTest { EXPECT_TRUE( LiteralTestUtil::Near(expected_result, actual_result, error_spec_)); - std::string prefix = "__cudnn$fhma"; + std::string prefix = "__cudnn$fmha"; IsFMHACalled(hlo_string, config_with_fmha, prefix, is_training); } @@ -2839,6 +2839,254 @@ class FlashAttentionBMMScaleBiasMaskSoftmaxBMM true); } }; + +class FlashAttentionBMMScaleSoftmaxBMM : public MultiHeadedAttentionTest { + protected: + const std::string // NOLINT + GetModuleFlash_Attention_Training_BMM1_Softmax_BMM2_HloString_BF16() { // NOLINT + const std::string hlo_text = R"( + HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,64,1024]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0})->(bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,64,1024]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true} + + region_0.13 { + Arg_0.14 = bf16[] parameter(0) + Arg_1.15 = bf16[] parameter(1) + ROOT maximum.16 = bf16[] maximum(Arg_0.14, Arg_1.15) + } + + region_1.25 { + Arg_0.26 = f32[] parameter(0) + Arg_1.27 = f32[] parameter(1) + ROOT add.28 = f32[] add(Arg_0.26, Arg_1.27) + } + + region_2.47 { + Arg_0.48 = bf16[] parameter(0) + Arg_1.49 = bf16[] parameter(1) + ROOT add.50 = bf16[] add(Arg_0.48, Arg_1.49) + } + + region_3.59 { + Arg_0.60 = f32[] parameter(0) + Arg_1.61 = f32[] parameter(1) + ROOT add.62 = f32[] add(Arg_0.60, Arg_1.61) + } + + ENTRY main.72 { + Arg_0.1 = bf16[2,6,1024,64]{3,2,1,0} parameter(0), sharding={replicated} + Arg_1.2 = bf16[2,6,64,1024]{3,2,1,0} parameter(1), sharding={replicated} + dot.11 = bf16[2,6,1024,1024]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + constant.17 = bf16[] constant(2) + broadcast.29 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(constant.17), dimensions={} + multiply.2 = bf16[2,6,1024,1024]{3,2,1,0} multiply(dot.11, broadcast.29) + constant.9 = bf16[] constant(-inf) + reduce.17 = bf16[2,6,1024]{2,1,0} reduce(multiply.2, constant.9), dimensions={3}, to_apply=region_0.13 + reshape.18 = bf16[2,6,1024,1]{3,2,1,0} reshape(reduce.17) + broadcast.19 = bf16[2,6,1024,1]{3,2,1,0} broadcast(reshape.18), dimensions={0,1,2,3} + reshape.20 = bf16[2,6,1024]{2,1,0} reshape(broadcast.19) + broadcast.21 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.20), dimensions={0,1,2} + subtract.22 = bf16[2,6,1024,1024]{3,2,1,0} subtract(multiply.2, broadcast.21) + exponential.23 = bf16[2,6,1024,1024]{3,2,1,0} exponential(subtract.22) + convert.24 = f32[2,6,1024,1024]{3,2,1,0} convert(exponential.23) + constant.8 = f32[] constant(0) + reduce.29 = f32[2,6,1024]{2,1,0} reduce(convert.24, constant.8), dimensions={3}, to_apply=region_1.25 + reshape.30 = f32[2,6,1024,1]{3,2,1,0} reshape(reduce.29) + convert.31 = bf16[2,6,1024,1]{3,2,1,0} convert(reshape.30) + broadcast.32 = bf16[2,6,1024,1]{3,2,1,0} broadcast(convert.31), dimensions={0,1,2,3} + reshape.33 = bf16[2,6,1024]{2,1,0} reshape(broadcast.32) + broadcast.34 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.33), dimensions={0,1,2} + divide.35 = bf16[2,6,1024,1024]{3,2,1,0} divide(exponential.23, broadcast.34) + Arg_2.3 = bf16[2,6,1024,64]{3,2,1,0} parameter(2), sharding={replicated} + dot.38 = bf16[2,6,1024,64]{3,2,1,0} dot(divide.35, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + Arg_4.5 = bf16[2,6,1024,64]{3,2,1,0} parameter(3), sharding={replicated} + dot.41 = bf16[2,6,1024,1024]{3,2,1,0} dot(Arg_4.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + broadcast.54 = bf16[2,6,1024,1]{3,2,1,0} broadcast(convert.31), dimensions={0,1,2,3} + reshape.55 = bf16[2,6,1024]{2,1,0} reshape(broadcast.54) + broadcast.56 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.55), dimensions={0,1,2} + divide.57 = bf16[2,6,1024,1024]{3,2,1,0} divide(dot.41, broadcast.56) + constant.5 = bf16[] constant(1) + broadcast.6 = bf16[2,6,1024,1]{3,2,1,0} broadcast(constant.5), dimensions={} + multiply.36 = bf16[2,6,1024,1]{3,2,1,0} multiply(convert.31, convert.31) + divide.37 = bf16[2,6,1024,1]{3,2,1,0} divide(broadcast.6, multiply.36) + broadcast.42 = bf16[2,6,1024,1]{3,2,1,0} broadcast(divide.37), dimensions={0,1,2,3} + reshape.43 = bf16[2,6,1024]{2,1,0} reshape(broadcast.42) + broadcast.44 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.43), dimensions={0,1,2} + multiply.45 = bf16[2,6,1024,1024]{3,2,1,0} multiply(dot.41, broadcast.44) + multiply.46 = bf16[2,6,1024,1024]{3,2,1,0} multiply(multiply.45, exponential.23) + constant.7 = bf16[] constant(0) + reduce.51 = bf16[2,6,1024]{2,1,0} reduce(multiply.46, constant.7), dimensions={3}, to_apply=region_2.47 + reshape.52 = bf16[2,6,1024,1]{3,2,1,0} reshape(reduce.51) + negate.53 = bf16[2,6,1024,1]{3,2,1,0} negate(reshape.52) + convert.58 = f32[2,6,1024,1]{3,2,1,0} convert(negate.53) + reduce.63 = f32[2,6,1024]{2,1,0} reduce(convert.58, constant.8), dimensions={3}, to_apply=region_3.59 + broadcast.64 = f32[2,6,1024,1024]{3,2,1,0} broadcast(reduce.63), dimensions={0,1,2} + convert.65 = bf16[2,6,1024,1024]{3,2,1,0} convert(broadcast.64) + add.66 = bf16[2,6,1024,1024]{3,2,1,0} add(divide.57, convert.65) + multiply.67 = bf16[2,6,1024,1024]{3,2,1,0} multiply(add.66, exponential.23) + dot.70 = bf16[2,6,1024,64]{3,2,1,0} dot(multiply.67, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + dot.68 = bf16[2,6,1024,64]{3,2,1,0} dot(multiply.67, Arg_0.1), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + transpose.69 = bf16[2,6,64,1024]{2,3,1,0} transpose(dot.68), dimensions={0,1,3,2} + dot.39 = bf16[2,6,64,1024]{3,2,1,0} dot(Arg_4.5, divide.35), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + transpose.40 = bf16[2,6,1024,64]{2,3,1,0} transpose(dot.39), dimensions={0,1,3,2} + ROOT tuple.71 = (bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,64,1024]{2,3,1,0}, bf16[2,6,1024,64]{2,3,1,0}) tuple(dot.38, dot.70, transpose.69, transpose.40) + } + )"; + return hlo_text; + } + + template + void TestImpl_Flash_Attention_Training_BMM1_Softmax_BMM2() { + stream_executor::CudaComputeCapability cc = GetCudaComputeCapability(); + se::dnn::VersionInfo real_cudnn_version = GetCudnnVersion(); + if (!(cc.IsAtLeast(se::CudaComputeCapability::AMPERE) && cc.minor == 0 && + real_cudnn_version >= se::dnn::VersionInfo(8, 9, 3))) { + GTEST_SKIP() << "Flash Attention is supported with the Nvidia AMPERE+ " + "GPUs and cuDNN >= 8.9.3."; + } + XlaBuilder builder(TestName()); + auto lhs_bmm1_literal = + GetInput4DLiteral({2, 6, 1024, 64}, {3, 2, 1, 0}); + auto rhs_bmm1_literal = + GetInput4DLiteral({2, 6, 64, 1024}, {3, 2, 1, 0}); + auto rhs_bmm2_literal = + GetInput4DLiteral({2, 6, 1024, 64}, {3, 2, 1, 0}); + auto do_literal = GetInput4DLiteral({2, 6, 1024, 64}, {3, 2, 1, 0}); + std::string hlo_string = ""; + hlo_string = + GetModuleFlash_Attention_Training_BMM1_Softmax_BMM2_HloString_BF16(); // NOLINT + ExecuteAndCompare( + hlo_string, + {&lhs_bmm1_literal, &rhs_bmm1_literal, &rhs_bmm2_literal, &do_literal}, + true); + } +}; + +class FlashAttentionBMMScaleMaskSoftmaxBMM : public MultiHeadedAttentionTest { + protected: + const std::string // NOLINT + GetModuleFlash_Attention_Training_BMM1_Mask_Softmax_BMM2_HloString_BF16() { // NOLINT + const std::string hlo_text = R"( + HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,1024,64]{3,2,1,0},bf16[2,6,64,1024]{3,2,1,0},bf16[2,6,1024,64]{3,2,1,0},bf16[2,6,1024,64]{3,2,1,0},pred[2,6,1024,1024]{3,2,1,0})->(bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,64,1024]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true} + + region_0.21 { + Arg_0.22 = bf16[] parameter(0) + Arg_1.23 = bf16[] parameter(1) + ROOT maximum.24 = bf16[] maximum(Arg_0.22, Arg_1.23) + } + + region_1.33 { + Arg_0.34 = f32[] parameter(0) + Arg_1.35 = f32[] parameter(1) + ROOT add.36 = f32[] add(Arg_0.34, Arg_1.35) + } + + region_2.55 { + Arg_0.56 = bf16[] parameter(0) + Arg_1.57 = bf16[] parameter(1) + ROOT add.58 = bf16[] add(Arg_0.56, Arg_1.57) + } + + region_3.67 { + Arg_0.68 = f32[] parameter(0) + Arg_1.69 = f32[] parameter(1) + ROOT add.70 = f32[] add(Arg_0.68, Arg_1.69) + } + + ENTRY main.82 { + constant.16 = pred[2,6,1024,1024]{3,2,1,0} parameter(4) + Arg_0.1 = bf16[2,6,1024,64]{3,2,1,0} parameter(0) + Arg_1.2 = bf16[2,6,64,1024]{3,2,1,0} parameter(1) + dot.17 = bf16[2,6,1024,1024]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + constant.5 = bf16[] constant(2) + broadcast.6 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(constant.5), dimensions={} + multiply.18 = bf16[2,6,1024,1024]{3,2,1,0} multiply(dot.17, broadcast.6) + constant.7 = bf16[] constant(0) + broadcast.8 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(constant.7), dimensions={} + select.20 = bf16[2,6,1024,1024]{3,2,1,0} select(constant.16, multiply.18, broadcast.8) + constant.12 = bf16[] constant(-inf) + reduce.25 = bf16[2,6,1024]{2,1,0} reduce(select.20, constant.12), dimensions={3}, to_apply=region_0.21 + reshape.26 = bf16[2,6,1024,1]{3,2,1,0} reshape(reduce.25) + broadcast.27 = bf16[2,6,1024,1]{3,2,1,0} broadcast(reshape.26), dimensions={0,1,2,3} + reshape.28 = bf16[2,6,1024]{2,1,0} reshape(broadcast.27) + broadcast.29 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.28), dimensions={0,1,2} + subtract.30 = bf16[2,6,1024,1024]{3,2,1,0} subtract(select.20, broadcast.29) + exponential.31 = bf16[2,6,1024,1024]{3,2,1,0} exponential(subtract.30) + convert.32 = f32[2,6,1024,1024]{3,2,1,0} convert(exponential.31) + constant.11 = f32[] constant(0) + reduce.37 = f32[2,6,1024]{2,1,0} reduce(convert.32, constant.11), dimensions={3}, to_apply=region_1.33 + reshape.38 = f32[2,6,1024,1]{3,2,1,0} reshape(reduce.37) + convert.39 = bf16[2,6,1024,1]{3,2,1,0} convert(reshape.38) + broadcast.40 = bf16[2,6,1024,1]{3,2,1,0} broadcast(convert.39), dimensions={0,1,2,3} + reshape.41 = bf16[2,6,1024]{2,1,0} reshape(broadcast.40) + broadcast.42 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.41), dimensions={0,1,2} + divide.43 = bf16[2,6,1024,1024]{3,2,1,0} divide(exponential.31, broadcast.42) + Arg_2.3 = bf16[2,6,1024,64]{3,2,1,0} parameter(2) + dot.46 = bf16[2,6,1024,64]{3,2,1,0} dot(divide.43, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + Arg_3.4 = bf16[2,6,1024,64]{3,2,1,0} parameter(3) + dot.49 = bf16[2,6,1024,1024]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + broadcast.62 = bf16[2,6,1024,1]{3,2,1,0} broadcast(convert.39), dimensions={0,1,2,3} + reshape.63 = bf16[2,6,1024]{2,1,0} reshape(broadcast.62) + broadcast.64 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.63), dimensions={0,1,2} + divide.65 = bf16[2,6,1024,1024]{3,2,1,0} divide(dot.49, broadcast.64) + constant.9 = bf16[] constant(1) + broadcast.10 = bf16[2,6,1024,1]{3,2,1,0} broadcast(constant.9), dimensions={} + multiply.44 = bf16[2,6,1024,1]{3,2,1,0} multiply(convert.39, convert.39) + divide.45 = bf16[2,6,1024,1]{3,2,1,0} divide(broadcast.10, multiply.44) + broadcast.50 = bf16[2,6,1024,1]{3,2,1,0} broadcast(divide.45), dimensions={0,1,2,3} + reshape.51 = bf16[2,6,1024]{2,1,0} reshape(broadcast.50) + broadcast.52 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.51), dimensions={0,1,2} + multiply.53 = bf16[2,6,1024,1024]{3,2,1,0} multiply(dot.49, broadcast.52) + multiply.54 = bf16[2,6,1024,1024]{3,2,1,0} multiply(multiply.53, exponential.31) + constant.13 = bf16[] constant(0) + reduce.59 = bf16[2,6,1024]{2,1,0} reduce(multiply.54, constant.13), dimensions={3}, to_apply=region_2.55 + reshape.60 = bf16[2,6,1024,1]{3,2,1,0} reshape(reduce.59) + negate.61 = bf16[2,6,1024,1]{3,2,1,0} negate(reshape.60) + convert.66 = f32[2,6,1024,1]{3,2,1,0} convert(negate.61) + reduce.71 = f32[2,6,1024]{2,1,0} reduce(convert.66, constant.11), dimensions={3}, to_apply=region_3.67 + broadcast.72 = f32[2,6,1024,1024]{3,2,1,0} broadcast(reduce.71), dimensions={0,1,2} + convert.73 = bf16[2,6,1024,1024]{3,2,1,0} convert(broadcast.72) + add.74 = bf16[2,6,1024,1024]{3,2,1,0} add(divide.65, convert.73) + multiply.75 = bf16[2,6,1024,1024]{3,2,1,0} multiply(add.74, exponential.31) + select.76 = bf16[2,6,1024,1024]{3,2,1,0} select(constant.16, multiply.75, broadcast.8) + multiply.77 = bf16[2,6,1024,1024]{3,2,1,0} multiply(select.76, broadcast.6) + dot.80 = bf16[2,6,1024,64]{3,2,1,0} dot(multiply.77, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + dot.78 = bf16[2,6,1024,64]{3,2,1,0} dot(multiply.77, Arg_0.1), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + transpose.79 = bf16[2,6,64,1024]{2,3,1,0} transpose(dot.78), dimensions={0,1,3,2} + dot.47 = bf16[2,6,64,1024]{3,2,1,0} dot(Arg_3.4, divide.43), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + transpose.48 = bf16[2,6,1024,64]{2,3,1,0} transpose(dot.47), dimensions={0,1,3,2} + ROOT tuple.81 = (bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,64,1024]{2,3,1,0}, bf16[2,6,1024,64]{2,3,1,0}) tuple(dot.46, dot.80, transpose.79, transpose.48) + } + )"; + + return hlo_text; + } + + template + void TestImpl_Flash_Attention_Training_BMM1_Mask_Softmax_BMM2() { + stream_executor::CudaComputeCapability cc = GetCudaComputeCapability(); + se::dnn::VersionInfo real_cudnn_version = GetCudnnVersion(); + if (!(cc.IsAtLeast(se::CudaComputeCapability::AMPERE) && cc.minor == 0 && + real_cudnn_version >= se::dnn::VersionInfo(8, 9, 3))) { + GTEST_SKIP() << "Flash Attention is supported with the Nvidia AMPERE+ " + "GPUs and cuDNN >= 8.9.3."; + } + XlaBuilder builder(TestName()); + auto lhs_bmm1_literal = + GetInput4DLiteral({2, 6, 1024, 64}, {3, 2, 1, 0}); + auto rhs_bmm1_literal = + GetInput4DLiteral({2, 6, 64, 1024}, {3, 2, 1, 0}); + auto rhs_bmm2_literal = + GetInput4DLiteral({2, 6, 1024, 64}, {3, 2, 1, 0}); + auto mask_literal = GetMask4DLiteral({2, 6, 1024, 1024}, {3, 2, 1, 0}); + auto do_literal = GetInput4DLiteral({2, 6, 1024, 64}, {3, 2, 1, 0}); + std::string hlo_string = + GetModuleFlash_Attention_Training_BMM1_Mask_Softmax_BMM2_HloString_BF16(); // NOLINT + ExecuteAndCompare(hlo_string, + {&lhs_bmm1_literal, &rhs_bmm1_literal, &rhs_bmm2_literal, + &do_literal, &mask_literal}, + true); + } +}; + // BMM1 - BMM2 XLA_TEST_F(MultiHeadedAttentionBMMBMM, FMHABMM_BMM_vanilla_F16) { TestImpl_FMHABMM_BMM_vanilla(); @@ -3003,5 +3251,17 @@ XLA_TEST_F(FlashAttentionBMMScaleBiasMaskSoftmaxBMM, Flash_Attention_Training_BMM1_Bias_Mask_Softmax_BMM2_BF16) { TestImpl_Flash_Attention_Training_BMM1_Bias_Mask_Softmax_BMM2(); } + +// BMM1 - Scale - Softmax - BMM2 +XLA_TEST_F(FlashAttentionBMMScaleSoftmaxBMM, + Flash_Attention_Training_BMM1_Softmax_BMM2_BF16) { + TestImpl_Flash_Attention_Training_BMM1_Softmax_BMM2(); +} + +// BMM1 - Scale - Mask - Softmax - BMM2 +XLA_TEST_F(FlashAttentionBMMScaleMaskSoftmaxBMM, + Flash_Attention_Training_BMM1_Mask_Softmax_BMM2_BF16) { + TestImpl_Flash_Attention_Training_BMM1_Mask_Softmax_BMM2(); +} } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/tests/gpu_index_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_index_test.cc index dc1b75abd984f1..d4bf10b8e424f0 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_index_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_index_test.cc @@ -18,11 +18,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/literal.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_parser.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla.pb.h" @@ -87,36 +84,6 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithReshape) { /*match_optimized_ir=*/true); } -TEST_F(GpuIndexTest, - ReuseMultidimIndexWithTrivialReshapeAndNonContiguousBroadcast) { - HloModuleConfig config; - config.set_debug_options(HloTestBase::GetDebugOptionsForTest()); - auto module = ParseAndReturnVerifiedModule(R"( - HloModule test_module - - ENTRY CompatibleUseLinearIndexWithReshape { - x = f32[1,7,2,5,3]{4,3,2,1,0} parameter(0) - y = f32[2,1,3]{2,1,0} parameter(1) - reshape = f32[1,2,3]{2,1,0} reshape(y) - broadcast = f32[1,7,2,5,3]{4,3,2,1,0} broadcast(reshape), dimensions={0,2,4} - ROOT gte = pred[1,7,2,5,3]{4,3,2,1,0} compare(x, broadcast), direction=GE - })", - config) - .value(); - CompileAndVerifyIr(std::move(module), - R"( -; CHECK: %[[tmp4:.*]] = udiv i32 %[[linear_index:.*]], 1 -; CHECK: %[[dim4:.*]] = urem i32 %[[tmp4]], 3 -; CHECK: %[[tmp3:.*]] = udiv i32 %[[linear_index]], 3 -; CHECK: %[[dim3:.*]] = urem i32 %[[tmp3]], 5 -; CHECK: %[[tmp2:.*]] = udiv i32 %[[linear_index]], 15 -; CHECK: %[[dim2:.*]] = urem i32 %[[tmp2]], 2 -; CHECK: %[[tmp1:.*]] = udiv i32 %[[linear_index]], 30 -; CHECK: %{{.*}} = getelementptr inbounds [2 x [1 x [3 x float]]], ptr %{{.*}}, i32 0, i32 %[[dim2]], i32 0, i32 %[[dim4]] - )", - /*match_optimized_ir=*/false); -} - #if TENSORFLOW_USE_ROCM #else TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithReshapeAndBroadcast) { diff --git a/third_party/xla/xla/service/gpu/tests/gpu_input_fusible_slice_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_input_fusible_slice_test.cc index 58413aa022ff42..d7348f0a6d734c 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_input_fusible_slice_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_input_fusible_slice_test.cc @@ -67,12 +67,12 @@ TEST_F(GpuSliceInputFusionTest, InputFusionWithATupleOfSlices) { ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) .value(); auto expected_ir = is_built_with_rocm_ ? R"( -; CHECK-LABEL: define amdgpu_kernel void @fusion +; CHECK-LABEL: define amdgpu_kernel void @{{[a-z_]*}}fusion ; CHECK: slice2 ; CHECK: } )" : R"( -; CHECK-LABEL: define void @fusion +; CHECK-LABEL: define void @{{[a-z_]*}}fusion ; CHECK: slice2 ; CHECK: } )"; @@ -114,12 +114,12 @@ TEST_F(GpuSliceInputFusionTest, ConcatThenSplit) { ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) .value(); auto expected_ir = is_built_with_rocm_ ? R"( -; CHECK-LABEL: define amdgpu_kernel void @fusion +; CHECK-LABEL: define amdgpu_kernel void @{{[a-z_]*}}fusion ; CHECK: slice2 ; CHECK: } )" : R"( -; CHECK-LABEL: define void @fusion +; CHECK-LABEL: define void @{{[a-z_]*}}fusion ; CHECK: slice2 ; CHECK: } )"; diff --git a/third_party/xla/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index f42b559feca3ac..e073e100a79900 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -134,7 +134,7 @@ TEST_F(GpuKernelTilingTest, SimpleFusionWithTransposeTiled) { ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) .value(); auto expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @fusion +; CHECK-LABEL: define KERNEL_ANNOTATION @{{[a-z_]*}}fusion ; CHECK: call void BARRIER() ; CHECK: } )"; @@ -170,7 +170,7 @@ TEST_F(GpuKernelTilingTest, MultipleOutputFusionWithOnePossibleTransposeTiled) { ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) .value(); auto expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @fusion +; CHECK-LABEL: define KERNEL_ANNOTATION @{{[a-z_]*}}fusion ; CHECK: call void BARRIER() ; CHECK: } )"; @@ -202,7 +202,7 @@ TEST_F(GpuKernelTilingTest, TransposedInputWithUserReverseNotTiled) { ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) .value(); auto expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @fusion +; CHECK-LABEL: define KERNEL_ANNOTATION @{{[a-z_]*}}fusion ; CHECK-NOT: call void BARRIER() ; CHECK: } )"; @@ -231,7 +231,7 @@ TEST_F(GpuKernelTilingTest, TransposedInputWithUserBitcastNotTiled) { ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) .value(); auto expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @fusion +; CHECK-LABEL: define KERNEL_ANNOTATION @{{[a-z_]*}}fusion ; CHECK-NOT: call void BARRIER() ; CHECK: } )"; @@ -268,7 +268,7 @@ TEST_F(GpuKernelTilingTest, TransposedInputWithoutUnsafeUseTiled) { ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) .value(); auto expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @fusion +; CHECK-LABEL: define KERNEL_ANNOTATION @{{[a-z_]*}}fusion ; CHECK: call void BARRIER() ; CHECK: } )"; @@ -342,7 +342,7 @@ TEST_F(GpuKernelTilingTest, ColumnReductionWithLayoutChangeTiled) { /*match_optimized_ir=*/true); // Check that the kernel runs correctly. - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001})); + EXPECT_TRUE(RunAndCompare(kHloString, ErrorSpec{0.001})); } TEST_F(GpuKernelTilingTest, RowReductionWithLayoutChangeTiled) { @@ -366,7 +366,7 @@ TEST_F(GpuKernelTilingTest, RowReductionWithLayoutChangeTiled) { ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) .value(); auto expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @{{(wrapped_reduce|fusion)}} +; CHECK-LABEL: define KERNEL_ANNOTATION @{{(wrapped_reduce|.*fusion)}} ; CHECK: call SHUFFLE ; CHECK: } )"; @@ -375,7 +375,7 @@ TEST_F(GpuKernelTilingTest, RowReductionWithLayoutChangeTiled) { /*match_optimized_ir=*/true); // Check that the kernel runs correctly. - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001})); + EXPECT_TRUE(RunAndCompare(kHloString, ErrorSpec{0.001})); } TEST_F(GpuKernelTilingTest, RowReductionTwoRowsPerWarp) { @@ -400,7 +400,7 @@ TEST_F(GpuKernelTilingTest, RowReductionTwoRowsPerWarp) { ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) .value(); auto expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @{{(wrapped_reduce|fusion)}} +; CHECK-LABEL: define KERNEL_ANNOTATION @{{(wrapped_reduce|.*fusion)}} ; CHECK: %[[TID_X:.*]] = tail call i32 TIDX() ; CHECK: %[[TID_LOGICAL:.*]] = and i32 %[[TID_X]], 15 ; CHECK: call SHUFFLE @@ -439,14 +439,11 @@ TEST_F(GpuKernelTilingTest, RowReductionFourRowsPerWarp) { ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) .value(); auto expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @{{(wrapped_reduce|fusion)}} +; CHECK-LABEL: define KERNEL_ANNOTATION @{{(wrapped_reduce|.*fusion)}} ; CHECK: %[[TID_X:.*]] = tail call i32 TIDX() ; CHECK: %[[TID_LOGICAL:.*]] = and i32 %[[TID_X]], 7 ; CHECK: call SHUFFLE ; CHECK: %[[LOGICAL_T0:.*]] = icmp eq i32 %[[TID_LOGICAL]], 0 -; CHECK: LCAL -; CHECK: EXTV -; CHECK: BR_CAL )"; CompileAndVerifyIr(std::move(hlo_module), @@ -479,7 +476,7 @@ TEST_F(GpuKernelTilingTest, ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) .value(); const char *expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @{{(wrapped_reduce|fusion)}} +; CHECK-LABEL: define KERNEL_ANNOTATION @{{(wrapped_reduce|.*fusion)}} ; CHECK: store float %{{.*}}, ptr addrspace(1) ; CHECK: } )"; @@ -557,7 +554,7 @@ TEST_F(GpuKernelTilingTest, ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) .value(); auto expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @{{(wrapped_reduce|fusion)}} +; CHECK-LABEL: define KERNEL_ANNOTATION @{{(wrapped_reduce|.*fusion)}} ; CHECK-NOT: call SHUFFLE ; CHECK: } )"; @@ -637,10 +634,11 @@ TEST_F(GpuKernelTilingTest, RowReductionCorrectShmemUsage) { )"; auto hlo_module = ParseAndReturnVerifiedModule(kHloString).value(); auto expected_ir = is_built_with_rocm_ ? R"( -; CHECK: initial_value_addr = internal unnamed_addr addrspace({{[0-9]*}}) global [1024 x float] poison, align 4 +; CHECK: %llvm.amdgcn.kernel.input_reduce_fusion.lds.t = type { [4 x [2 x float]] } +; CHECK: @llvm.amdgcn.kernel.input_reduce_fusion.lds = internal addrspace(3) global %llvm.amdgcn.kernel.input_reduce_fusion.lds.t poison )" : R"( -; CHECK: shared_cache = private unnamed_addr addrspace({{[0-9]*}}) global [1 x [2 x float]] +; CHECK: shared_cache = private unnamed_addr addrspace({{[0-9]*}}) global [4 x [2 x float]] )"; CompileAndVerifyIr(std::move(hlo_module), expected_ir, /*match_optimized_ir=*/true); @@ -657,9 +655,9 @@ TEST_F(GpuKernelTilingTest, ReductionInputTooLarge) { } ENTRY reduce.1 { - parameter = f32[4,1048576,1024,1024] parameter(0) + parameter = f32[16,1048576,1024,1024] parameter(0) init_value = f32[] constant(0) - ROOT reduce = f32[4,1048576,1024] reduce(parameter, init_value), dimensions={3}, to_apply=Sum + ROOT reduce = f32[16,1048576,1024] reduce(parameter, init_value), dimensions={3}, to_apply=Sum } )"; auto hlo_module = ParseAndReturnVerifiedModule(kHloString).value(); diff --git a/third_party/xla/xla/service/gpu/tests/gpu_too_many_blocks_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_too_many_blocks_test.cc index 5b1301fa3eea40..758e7397319aa2 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_too_many_blocks_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_too_many_blocks_test.cc @@ -55,7 +55,7 @@ ENTRY primitive_computation_mul.8 { EXPECT_FALSE(failed_executable.ok()); EXPECT_THAT( failed_executable.status().ToString(), - ::testing::ContainsRegex("Kernel 'fusion.*' launch needs more blocks")); + ::testing::ContainsRegex("Kernel '.*fusion.*' launch needs more blocks")); } } // namespace diff --git a/third_party/xla/xla/service/gpu/tests/gpu_unrolling_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_unrolling_test.cc index f5bf91339c316d..786c0f3b0b68ae 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_unrolling_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_unrolling_test.cc @@ -50,27 +50,14 @@ TEST_F(GpuUnrollingTest, UnrollDefaultTimes) { CompileAndVerifyIr(std::move(hlo_module), R"( -; CHECK-LABEL: @fusion -; CHECK: load float -; CHECK: load float -; CHECK: fadd -; CHECK: store float -; CHECK: load float -; CHECK: load float -; CHECK: fadd -; CHECK: store float -; CHECK: load float -; CHECK: load float -; CHECK: fadd -; CHECK: store float -; CHECK: load float -; CHECK: load float -; CHECK: fadd -; CHECK: store float -; CHECK-NOT: fadd -; CHECK: } +; CHECK-LABEL: @{{[a-z_]*}}fusion +; CHECK-NOT: load float +; CHECK-NOT: store float +; CHECK: load <4 x float> +; CHECK: load <4 x float> +; CHECK: store <4 x float> )", - /*match_optimized_ir=*/false); + /*match_optimized_ir=*/true); } TEST_F(GpuUnrollingTest, UnrollUnfusedAdd) { @@ -91,26 +78,13 @@ TEST_F(GpuUnrollingTest, UnrollUnfusedAdd) { CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: @wrapped_add -; CHECK: load float -; CHECK: load float -; CHECK: fadd -; CHECK: store float -; CHECK: load float -; CHECK: load float -; CHECK: fadd -; CHECK: store float -; CHECK: load float -; CHECK: load float -; CHECK: fadd -; CHECK: store float -; CHECK: load float -; CHECK: load float -; CHECK: fadd -; CHECK: store float -; CHECK-NOT: fadd -; CHECK: } +; CHECK-NOT: load float +; CHECK-NOT: store float +; CHECK: load <4 x float> +; CHECK: load <4 x float> +; CHECK: store <4 x float> )", - /*match_optimized_ir=*/false); + /*match_optimized_ir=*/true); } TEST_F(GpuUnrollingTest, DisabledUnrollUnfusedSine) { @@ -243,38 +217,15 @@ TEST_F(GpuUnrollingTest, UnrollMultiOutputFusion) { CompileAndVerifyIr(std::move(hlo_module), R"( -; CHECK-LABEL: @fusion -; CHECK: load float -; CHECK: load float -; CHECK-NOT: load float -; CHECK-NOT: load float -; CHECK: fadd -; CHECK: load float -; CHECK: load float -; CHECK-NOT: load float -; CHECK-NOT: load float -; CHECK: fmul -; CHECK: store float -; CHECK: store float -; CHECK-NOT: store float -; CHECK-NOT: store float -; CHECK: load float -; CHECK: load float +; CHECK-LABEL: @{{[a-z_]*}}fusion ; CHECK-NOT: load float -; CHECK-NOT: load float -; CHECK: fadd -; CHECK: load float -; CHECK: load float -; CHECK-NOT: load float -; CHECK-NOT: load float -; CHECK: fmul -; CHECK: store float -; CHECK: store float -; CHECK-NOT: store float ; CHECK-NOT: store float -; CHECK: } +; CHECK: load <4 x float> +; CHECK: load <4 x float> +; CHECK: store <4 x float> +; CHECK: store <4 x float> )", - /*match_optimized_ir=*/false); + /*match_optimized_ir=*/true); } } // namespace diff --git a/third_party/xla/xla/service/gpu/tests/kernel_reuse.hlo b/third_party/xla/xla/service/gpu/tests/kernel_reuse.hlo index d4110f8d48864e..41734e06259a00 100644 --- a/third_party/xla/xla/service/gpu/tests/kernel_reuse.hlo +++ b/third_party/xla/xla/service/gpu/tests/kernel_reuse.hlo @@ -3,9 +3,9 @@ // All fusions must reuse the same kernel: // CHECK-LABEL: target triple // CHECK-PTX: define void -// CHECK-GCN: define amdgpu_kernelvoid +// CHECK-GCN: define amdgpu_kernel void // CHECK-PTX-NOT: define void -// CHECK-GCN-NOT: define amdgpu_kernelvoid +// CHECK-GCN-NOT: define amdgpu_kernel void HloModule KernelReuse, is_scheduled=true @@ -47,11 +47,11 @@ ENTRY main { // All (Triton) fusions must reuse the same kernel: // CHECK-LABEL: target triple // CHECK-PTX-NOT: define void -// CHECK-GCN-NOT: define amdgpu_kernelvoid +// CHECK-GCN-NOT: define amdgpu_kernel void // CHECK-PTX: define void @triton_gemm_dot1( -// CHECK-GCN: define amdgpu_kernelvoid @triton_gemm_dot1( +// CHECK-GCN: define amdgpu_kernel void @triton_gemm_dot1( // CHECK-PTX-NOT: define void -// CHECK-GCN-NOT: define amdgpu_kernelvoid +// CHECK-GCN-NOT: define amdgpu_kernel void HloModule t, is_scheduled=true @@ -74,8 +74,8 @@ ENTRY e { p2 = f16[15,19]{1,0} parameter(2) p1 = s8[19,17]{1,0} parameter(1) p0 = f16[15,19]{1,0} parameter(0) - triton_gemm_dot1 = f16[15,17]{1,0} fusion(p3, p2), kind=kCustom, calls=triton_gemm_dot1, backend_config="{ \"fusion_backend_config\": {kind: \"__triton_gemm\", triton_gemm_config: {\"block_m\":\"64\",\"block_n\":\"32\",\"block_k\":\"64\",\"split_k\":\"1\",\"num_stages\":\"4\",\"num_warps\":\"4\"}}}" - triton_gemm_dot0 = f16[15,17]{1,0} fusion(p1, p0), kind=kCustom, calls=triton_gemm_dot0, backend_config="{ \"fusion_backend_config\": {kind: \"__triton_gemm\", triton_gemm_config: {\"block_m\":\"64\",\"block_n\":\"32\",\"block_k\":\"64\",\"split_k\":\"1\",\"num_stages\":\"4\",\"num_warps\":\"4\"}}}" + triton_gemm_dot1 = f16[15,17]{1,0} fusion(p3, p2), kind=kCustom, calls=triton_gemm_dot1, backend_config="{ \"fusion_backend_config\": {kind: \"__triton_gemm\", triton_gemm_config: {\"block_m\":\"64\",\"block_n\":\"32\",\"block_k\":\"64\",\"split_k\":\"1\",\"num_stages\":\"4\",\"num_warps\":\"4\",\"num_ctas\":\"1\"}}}" + triton_gemm_dot0 = f16[15,17]{1,0} fusion(p1, p0), kind=kCustom, calls=triton_gemm_dot0, backend_config="{ \"fusion_backend_config\": {kind: \"__triton_gemm\", triton_gemm_config: {\"block_m\":\"64\",\"block_n\":\"32\",\"block_k\":\"64\",\"split_k\":\"1\",\"num_stages\":\"4\",\"num_warps\":\"4\",\"num_ctas\":\"1\"}}}" ROOT tuple = (f16[15,17]{1,0}, f16[15,17]{1,0}) tuple(triton_gemm_dot0, triton_gemm_dot1) } @@ -85,12 +85,12 @@ ENTRY e { // - @fusion_2's %arg0 must have align 16, because we are passing a module input // - @fusion_1's %arg0 must have align 128, because we are passing an internal buffer // CHECK-LABEL: target triple -// CHECK-PTX: define void @fusion_2(ptr noalias align 16 dereferenceable(100) %arg0, ptr noalias align 128 dereferenceable(100) %arg1) -// CHECK-GCN: define amdgpu_kernelvoid @fusion_2(ptr noalias align 16 dereferenceable(100) %arg0, ptr noalias align 128 dereferenceable(100) %arg1) -// CHECK-PTX: define void @fusion_1(ptr noalias align 128 dereferenceable(100) %arg0, ptr noalias align 128 dereferenceable(100) %arg1) -// CHECK-GCN: define amdgpu_kernelvoid @fusion_1(ptr noalias align 128 dereferenceable(100) %arg0, ptr noalias align 128 dereferenceable(100) %arg1) +// CHECK-PTX-DAG: define void @fusion_2(ptr noalias align 16 dereferenceable(100) %{{.*}}, ptr noalias align 128 dereferenceable(100) %{{.*}}) +// CHECK-GCN-DAG: define amdgpu_kernel void @fusion_2(ptr noalias align 16 dereferenceable(100) %{{.*}}, ptr noalias align 128 dereferenceable(100) %{{.*}}) +// CHECK-PTX-DAG: define void @fusion_1(ptr noalias align 128 dereferenceable(100) %{{.*}}, ptr noalias align 128 dereferenceable(100) %{{.*}}) +// CHECK-GCN-DAG: define amdgpu_kernel void @fusion_1(ptr noalias align 128 dereferenceable(100) %{{.*}}, ptr noalias align 128 dereferenceable(100) %{{.*}}) // CHECK-PTX-NOT: define void -// CHECK-GCN-NOT: define amdgpu_kernelvoid +// CHECK-GCN-NOT: define amdgpu_kernel void HloModule KernelReuse, is_scheduled=true @@ -132,10 +132,10 @@ ENTRY main { // The first has just 2 parameters (1 input, 1 output) and the second has 3 (2 input, 1 output). // All the parameters are noalias, because we are not passing the same argument twice to the kernel. // CHECK-LABEL: target triple -// CHECK-PTX: define void @fusion_2(ptr noalias align 128 dereferenceable(100) %arg0, ptr noalias align 128 dereferenceable(100) %arg1) -// CHECK-GCN: define amdgpu_kernelvoid @fusion_2(ptr noalias align 128 dereferenceable(100) %arg0, ptr noalias align 128 dereferenceable(100) %arg1) -// CHECK-PTX: define void @fusion_1(ptr noalias align 128 dereferenceable(100) %arg0, ptr noalias align 128 dereferenceable(100) %arg1, ptr noalias align 128 dereferenceable(100) %arg2) -// CHECK-GCN: define amdgpu_kernelvoid @fusion_1(ptr noalias align 128 dereferenceable(100) %arg0, ptr noalias align 128 dereferenceable(100) %arg1, ptr noalias align 128 dereferenceable(100) %arg2) +// CHECK-PTX-DAG: define void @fusion_2(ptr noalias align 128 dereferenceable(100) %{{.*}}, ptr noalias align 128 dereferenceable(100) %{{.*}}) +// CHECK-GCN-DAG: define amdgpu_kernel void @fusion_2(ptr noalias align 128 dereferenceable(100) %{{.*}}, ptr noalias align 128 dereferenceable(100) %{{.*}}) +// CHECK-PTX-DAG: define void @fusion_1(ptr noalias align 128 dereferenceable(100) %{{.*}}, ptr noalias align 128 dereferenceable(100) %{{.*}}, ptr noalias align 128 dereferenceable(100) %{{.*}}) +// CHECK-GCN-DAG: define amdgpu_kernel void @fusion_1(ptr noalias align 128 dereferenceable(100) %{{.*}}, ptr noalias align 128 dereferenceable(100) %{{.*}}, ptr noalias align 128 dereferenceable(100) %{{.*}}) // CHECK-NOT: define void HloModule KernelReuse, is_scheduled=true @@ -178,16 +178,16 @@ ENTRY main { // "!invariant.load" (thanks to ir_array.MarkInvariantOverWholeProgram). // // CHECK-LABEL: target triple -// CHECK-PTX: define void @fusion_2(ptr noalias align 128 dereferenceable(100) %arg0) -// CHECK-GCN: define amdgpu_kernelvoid @fusion_2(ptr noalias align 128 dereferenceable(100) %arg0) +// CHECK-PTX: define void @fusion_2(ptr noalias align 128 dereferenceable(100) %{{.*}}) +// CHECK-GCN: define amdgpu_kernel void @fusion_2(ptr noalias align 128 dereferenceable(100) %{{.*}}) // CHECK-NOT: !invariant.load -// CHECK-PTX: define void @fusion(ptr noalias align 128 dereferenceable(100) %arg0, ptr noalias align 128 dereferenceable(100) %arg1) -// CHECK-GCN: define amdgpu_kernelvoid @fusion(ptr noalias align 128 dereferenceable(100) %arg0, ptr noalias align 128 dereferenceable(100) %arg1) +// CHECK-PTX: define void @fusion(ptr noalias align 128 dereferenceable(100) %{{.*}}, ptr noalias align 128 dereferenceable(100) %{{.*}}) +// CHECK-GCN: define amdgpu_kernel void @fusion(ptr noalias align 128 dereferenceable(100) %{{.*}}, ptr noalias align 128 dereferenceable(100) %{{.*}}) // CHECK-PTX-NOT: define void -// CHECK-GCN-NOT: define amdgpu_kernelvoid +// CHECK-GCN-NOT: define amdgpu_kernel void // CHECK: !invariant.load // CHECK-PTX-NOT: define void -// CHECK-GCN-NOT: define amdgpu_kernelvoid +// CHECK-GCN-NOT: define amdgpu_kernel void HloModule KernelReuse, is_scheduled=true diff --git a/third_party/xla/xla/service/gpu/tests/launch_dimensions.hlo b/third_party/xla/xla/service/gpu/tests/launch_dimensions.hlo index ca926d2f0a7d6f..bcfa37733f7e67 100644 --- a/third_party/xla/xla/service/gpu/tests/launch_dimensions.hlo +++ b/third_party/xla/xla/service/gpu/tests/launch_dimensions.hlo @@ -2,13 +2,13 @@ // This tests that we do not increase the grid launch size when // few_waves is enabled. -// CHECK-LABEL: entry: -// CHECK-PTX: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK: ![[ctaid_range]] = !{i32 0, i32 2} -// CHECK: ![[tid_range]] = !{i32 0, i32 1024} +// CHECK-LABEL: define void @wrapped_b +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] +// CHECK-DAG: ![[ctaid_range]] = !{i32 0, i32 2} +// CHECK-DAG: ![[tid_range]] = !{i32 0, i32 1024} HloModule Test, is_scheduled=true @@ -27,14 +27,14 @@ ENTRY main { // This tests that we cap grid launch code when few_waves is enabled. -// CHECK-LABEL: entry: -// CHECK-PTX: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-PTX: ![[ctaid_range]] = !{i32 0, i32 1008} -// CHECK-GCN: ![[ctaid_range]] = !{i32 0, i32 1664} -// CHECK: ![[tid_range]] = !{i32 0, i32 128} +// CHECK-LABEL: define void @wrapped_b +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] +// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 1008} +// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 1760} +// CHECK-DAG: ![[tid_range]] = !{i32 0, i32 128} HloModule Test, is_scheduled=true @@ -53,15 +53,15 @@ ENTRY main { // This tests that we cap grid launch code when few_waves is enabled // and scalar broadcast are present. -// CHECK-LABEL: entry: -// CHECK-PTX: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] -// CHECK-PTX: ![[ctaid_range]] = !{i32 0, i32 1008} -// CHECK-PTX: ![[tid_range]] = !{i32 0, i32 128} -// CHECK-GCN: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN: ![[ctaid_range]] = !{i32 0, i32 1664} -// CHECK-GCN: ![[tid_range]] = !{i32 0, i32 128} +// CHECK-LABEL: define void @fusion_3 +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] +// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 1008} +// CHECK-PTX-DAG: ![[tid_range]] = !{i32 0, i32 128} +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] +// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 1760} +// CHECK-GCN-DAG: ![[tid_range]] = !{i32 0, i32 128} HloModule ScalarBroadcast, is_scheduled=true @@ -84,15 +84,15 @@ ENTRY main { // This tests that we enable few_waves in a simple fusion. It is the baseline // for the tests below. -// CHECK-LABEL: entry: -// CHECK-PTX: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] -// CHECK-PTX: ![[ctaid_range]] = !{i32 0, i32 1008} -// CHECK-PTX: ![[tid_range]] = !{i32 0, i32 128} -// CHECK-GCN: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN: ![[ctaid_range]] = !{i32 0, i32 1664} -// CHECK-GCN: ![[tid_range]] = !{i32 0, i32 128} +// CHECK-LABEL: define void @fusion +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] +// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 1008} +// CHECK-PTX-DAG: ![[tid_range]] = !{i32 0, i32 128} +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] +// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 1760} +// CHECK-GCN-DAG: ![[tid_range]] = !{i32 0, i32 128} HloModule SimpleFusion, is_scheduled=true @@ -113,15 +113,15 @@ ENTRY main { // This tests that we keep few_waves enabled for large constants. -// CHECK-LABEL: entry: -// CHECK-PTX: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] -// CHECK-PTX: ![[ctaid_range]] = !{i32 0, i32 1008} -// CHECK-PTX: ![[tid_range]] = !{i32 0, i32 128} -// CHECK-GCN: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN: ![[ctaid_range]] = !{i32 0, i32 1664} -// CHECK-GCN: ![[tid_range]] = !{i32 0, i32 128} +// CHECK-LABEL: define void @fusion +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] +// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 1008} +// CHECK-PTX-DAG: ![[tid_range]] = !{i32 0, i32 128} +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] +// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 1760} +// CHECK-GCN-DAG: ![[tid_range]] = !{i32 0, i32 128} HloModule LargeConstant, is_scheduled=true @@ -141,15 +141,15 @@ ENTRY main { // This tests that we disable few_waves if a non-elementwise op is present. -// CHECK-LABEL: entry: -// CHECK-PTX: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] -// CHECK-PTX: ![[ctaid_range]] = !{i32 0, i32 195313} -// CHECK-PTX: ![[tid_range]] = !{i32 0, i32 128} -// CHECK-GCN: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN: ![[ctaid_range]] = !{i32 0, i32 97657} -// CHECK-GCN: ![[tid_range]] = !{i32 0, i32 256} +// CHECK-LABEL: define void @fusion +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] +// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 195313} +// CHECK-PTX-DAG: ![[tid_range]] = !{i32 0, i32 128} +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] +// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 97657} +// CHECK-GCN-DAG: ![[tid_range]] = !{i32 0, i32 256} HloModule NonElementwise, is_scheduled=true @@ -175,15 +175,15 @@ ENTRY main { // - the fusion is not row-vectorizable // It serves as a baseline for the tests below. -// CHECK-LABEL: entry: -// CHECK-PTX: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] -// CHECK-PTX: ![[ctaid_range]] = !{i32 0, i32 7813} -// CHECK-PTX: ![[tid_range]] = !{i32 0, i32 128} -// CHECK-GCN: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN: ![[ctaid_range]] = !{i32 0, i32 3907} -// CHECK-GCN: ![[tid_range]] = !{i32 0, i32 256} +// CHECK-LABEL: define void @fusion +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] +// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 7813} +// CHECK-PTX-DAG: ![[tid_range]] = !{i32 0, i32 128} +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] +// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 3907} +// CHECK-GCN-DAG: ![[tid_range]] = !{i32 0, i32 256} HloModule NoFewWaves, is_scheduled=true @@ -219,15 +219,15 @@ ENTRY main { // - the fusion IS row-vectorizable // In this case, the block count is changed from 7813 to 2000. -// CHECK-LABEL: entry: -// CHECK-PTX: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] -// CHECK-PTX: ![[ctaid_range]] = !{i32 0, i32 2000} -// CHECK-PTX: ![[tid_range]] = !{i32 0, i32 500} -// CHECK-GCN: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN: ![[ctaid_range]] = !{i32 0, i32 2000} -// CHECK-GCN: ![[tid_range]] = !{i32 0, i32 500} +// CHECK-LABEL: define void @fusion +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] +// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 2000} +// CHECK-PTX-DAG: ![[tid_range]] = !{i32 0, i32 500} +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] +// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 2000} +// CHECK-GCN-DAG: ![[tid_range]] = !{i32 0, i32 500} HloModule RowVectorizable, is_scheduled=true @@ -260,15 +260,15 @@ ENTRY main { // - the fusion is not row-vectorizable // In this case, the block count is changed from 7813 to 1008. -// CHECK-LABEL: entry: -// CHECK-PTX: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] -// CHECK-PTX: ![[ctaid_range]] = !{i32 0, i32 1008} -// CHECK-PTX: ![[tid_range]] = !{i32 0, i32 128} -// CHECK-GCN: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN: ![[ctaid_range]] = !{i32 0, i32 1664} -// CHECK-GCN: ![[tid_range]] = !{i32 0, i32 128} +// CHECK-LABEL: define void @fusion +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] +// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 1008} +// CHECK-PTX-DAG: ![[tid_range]] = !{i32 0, i32 128} +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] +// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 1760} +// CHECK-GCN-DAG: ![[tid_range]] = !{i32 0, i32 128} HloModule ScalarBroadcastFourInputs, is_scheduled=true @@ -300,14 +300,14 @@ ENTRY main { // This tests the GELU kernel. The original kernel that // motivated few_waves implementation. -// CHECK-LABEL: entry: -// CHECK-PTX: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-PTX: ![[ctaid_range]] = !{i32 0, i32 1008} -// CHECK-GCN: ![[ctaid_range]] = !{i32 0, i32 1664} -// CHECK: ![[tid_range]] = !{i32 0, i32 128} +// CHECK-LABEL: define void @fusion +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] +// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 1008} +// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 1760} +// CHECK-DAG: ![[tid_range]] = !{i32 0, i32 128} HloModule Test, is_scheduled=true diff --git a/third_party/xla/xla/service/gpu/tests/reduce_atomic_min.hlo b/third_party/xla/xla/service/gpu/tests/reduce_atomic_min.hlo index ea888609024e78..cff4ff8a79a5b9 100644 --- a/third_party/xla/xla/service/gpu/tests/reduce_atomic_min.hlo +++ b/third_party/xla/xla/service/gpu/tests/reduce_atomic_min.hlo @@ -67,260 +67,244 @@ ENTRY reduce.1 { // CHECK: %[[VAL_36:.*]] = alloca float, align 4 // CHECK: %[[VAL_37:.*]] = alloca float, align 4 // CHECK: %[[VAL_38:.*]] = alloca float, align 4 -// CHECK: %[[VAL_39:.*]] = alloca float, align 4 +// CHECK: %[[LOOP3_I_2:loop3.invar_address.*]] = alloca i32, align 4 +// CHECK: %[[LOOP2_I_2:loop2.invar_address.*]] = alloca i32, align 4 +// CHECK: %[[VAL_42:return_buffer.*]] = alloca float, align 4 // CHECK: %[[VAL_40:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_41:.*]] = alloca float, align 4 -// CHECK: %[[VAL_42:.*]] = alloca float, align 4 // CHECK: %[[VAL_43:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_44:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_45:.*]] = alloca float, align 4 -// CHECK: %[[VAL_46:.*]] = alloca float, align 4 +// CHECK: %partial_reduction_result = alloca float, align 4 +// CHECK: %reduction_input_address = alloca float, align 4 // CHECK-PTX: %[[VAL_47:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y(), !range !4 -// CHECK-GCN: %[[VAL_56:.*]] = call i32 @llvm.amdgcn.workgroup.id.y +// CHECK-GCN: %[[VAL_47:.*]] = call i32 @llvm.amdgcn.workgroup.id.y // CHECK: %[[VAL_48:.*]] = icmp eq i32 %[[VAL_47]], 0 // CHECK: br i1 %[[VAL_48]], label %[[VAL_49:.*]], label %[[VAL_50:.*]] // CHECK: reduce-group-0-after: ; preds = %[[VAL_51:.*]], %[[VAL_52:.*]] // CHECK: ret void // CHECK: reduce-group-0-true: ; preds = %[[VAL_52]] // CHECK: %[[VAL_53:.*]] = load float, ptr %[[VAL_54:.*]], align 4, !invariant.load !5 -// CHECK: store float %[[VAL_53]], ptr %[[VAL_45]], align 4 -// CHECK-PTX: %[[VAL_55:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !6 -// CHECK-GCN: %[[VAL_55:.*]] = call i32 @llvm.amdgcn.workitem.id.x -// CHECK-PTX: %[[VAL_56:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !7 -// CHECK-GCN: %[[VAL_56:.*]] = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK: %[[VAL_57:.*]] = urem i32 %[[VAL_55]], 1024 -// CHECK: %[[VAL_58:.*]] = udiv i32 %[[VAL_55]], 1024 -// CHECK: %[[VAL_59:.*]] = mul i32 %[[VAL_56]], 1 -// CHECK: %[[VAL_60:.*]] = add i32 %[[VAL_59]], %[[VAL_58]] -// CHECK: %[[VAL_61:.*]] = icmp ult i32 %[[VAL_60]], 19 -// CHECK: br i1 %[[VAL_61]], label %[[VAL_62:.*]], label %[[VAL_63:.*]] -// CHECK: 9: ; preds = %[[VAL_49]] -// CHECK: %[[VAL_65:.*]] = udiv i32 %[[VAL_57]], 1024 -// CHECK: %[[VAL_64:.*]] = urem i32 %[[VAL_57]], 1024 -// CHECK: %[[VAL_114:.*]] = mul i32 %[[VAL_64]], 2 -// CHECK: %[[VAL_66:.*]] = urem i32 %[[VAL_57]], 32 -// CHECK: %[[VAL_67:.*]] = udiv i32 %[[VAL_60]], 1 -// CHECK: %[[VAL_68:.*]] = urem i32 %[[VAL_67]], 19 -// CHECK: %[[VAL_69:.*]] = udiv i32 %[[VAL_60]], 19 -// CHECK: %[[VAL_70:.*]] = urem i32 %[[VAL_69]], 1 -// CHECK: %[[VAL_71:.*]] = udiv i32 %[[VAL_60]], 19 -// CHECK: %[[VAL_74:.*]] = icmp eq i32 %[[VAL_68]], 18 -// CHECK: %[[VAL_75:.*]] = select i1 %[[VAL_74]], i32 5088, i32 16384 -// CHECK: %[[VAL_76:.*]] = mul i32 %[[VAL_71]], 1 -// CHECK: %[[VAL_77:.*]] = mul i32 %[[VAL_70]], 1 -// CHECK: %[[VAL_78:.*]] = mul i32 %[[VAL_68]], 16384 -// CHECK: store i32 %[[VAL_65]], ptr %[[VAL_44]], align 4 -// CHECK: br label %[[VAL_79:.*]] -// CHECK: loop1.loop_header: ; preds = %[[VAL_80:.*]], %[[VAL_62]] -// CHECK: %[[VAL_81:.*]] = load i32, ptr %[[VAL_44]], align 4 -// CHECK: %[[VAL_82:.*]] = icmp uge i32 %[[VAL_81]], 1 -// CHECK: br i1 %[[VAL_82]], label %[[VAL_83:.*]], label %[[VAL_84:.*]] -// CHECK: loop1.loop_body: ; preds = %[[VAL_79]] -// CHECK: %[[VAL_85:.*]] = add nuw nsw i32 %[[VAL_81]], 1 -// CHECK: store i32 %[[VAL_85]], ptr %[[VAL_44]], align 4 -// CHECK: %[[VAL_86:.*]] = icmp eq i32 %[[VAL_81]], %[[VAL_65]] -// CHECK: %[[VAL_87:.*]] = icmp eq i32 16384, %[[VAL_75]] -// CHECK: br i1 %[[VAL_87]], label %[[VAL_88:.*]], label %[[VAL_89:.*]] -// CHECK: is_full_tile-after: ; preds = %[[VAL_90:.*]], %[[VAL_91:.*]] -// CHECK: br label %[[VAL_79]], !llvm.loop !8 -// CHECK: loop1.loop_exit: ; preds = %[[VAL_79]] -// CHECK: %[[VAL_92:.*]] = load float, ptr %[[VAL_45]], align 4 -// CHECK: %[[VAL_93:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_92]], i32 16, i32 31) -// CHECK: store float %[[VAL_93]], ptr %[[VAL_37]], align 4 -// CHECK: call void @[[MIN:Min.*]](ptr %[[VAL_45]], ptr %[[VAL_37]], ptr %[[VAL_36]]) -// CHECK: %[[VAL_94:.*]] = load float, ptr %[[VAL_36]], align 4 -// CHECK: store float %[[VAL_94]], ptr %[[VAL_45]], align 4 -// CHECK: %[[VAL_95:.*]] = load float, ptr %[[VAL_45]], align 4 -// CHECK: %[[VAL_96:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_95]], i32 8, i32 31) -// CHECK: store float %[[VAL_96]], ptr %[[VAL_35]], align 4 -// CHECK: call void @[[MIN]](ptr %[[VAL_45]], ptr %[[VAL_35]], ptr %[[VAL_34]]) -// CHECK: %[[VAL_97:.*]] = load float, ptr %[[VAL_34]], align 4 -// CHECK: store float %[[VAL_97]], ptr %[[VAL_45]], align 4 -// CHECK: %[[VAL_98:.*]] = load float, ptr %[[VAL_45]], align 4 -// CHECK: %[[VAL_99:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_98]], i32 4, i32 31) -// CHECK: store float %[[VAL_99]], ptr %[[VAL_33]], align 4 -// CHECK: call void @[[MIN]](ptr %[[VAL_45]], ptr %[[VAL_33]], ptr %[[VAL_32]]) -// CHECK: %[[VAL_100:.*]] = load float, ptr %[[VAL_32]], align 4 -// CHECK: store float %[[VAL_100]], ptr %[[VAL_45]], align 4 -// CHECK: %[[VAL_101:.*]] = load float, ptr %[[VAL_45]], align 4 -// CHECK: %[[VAL_102:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_101]], i32 2, i32 31) -// CHECK: store float %[[VAL_102]], ptr %[[VAL_31]], align 4 -// CHECK: call void @[[MIN]](ptr %[[VAL_45]], ptr %[[VAL_31]], ptr %[[VAL_30]]) -// CHECK: %[[VAL_103:.*]] = load float, ptr %[[VAL_30]], align 4 -// CHECK: store float %[[VAL_103]], ptr %[[VAL_45]], align 4 -// CHECK: %[[VAL_104:.*]] = load float, ptr %[[VAL_45]], align 4 -// CHECK: %[[VAL_105:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_104]], i32 1, i32 31) -// CHECK: store float %[[VAL_105]], ptr %[[VAL_29]], align 4 -// CHECK: call void @[[MIN]](ptr %[[VAL_45]], ptr %[[VAL_29]], ptr %[[VAL_28]]) -// CHECK: %[[VAL_106:.*]] = load float, ptr %[[VAL_28]], align 4 -// CHECK: store float %[[VAL_106]], ptr %[[VAL_45]], align 4 -// CHECK: %[[VAL_107:.*]] = udiv i32 %[[VAL_64]], 32 -// CHECK: %[[VAL_108:.*]] = icmp eq i32 %[[VAL_66]], 0 -// CHECK: br i1 %[[VAL_108]], label %[[VAL_109:.*]], label %[[VAL_110:.*]] -// CHECK: intra_warp_reduce_write-after: ; preds = %[[VAL_109]], %[[VAL_83]] -// CHECK: call void @llvm.nvvm.barrier0() -// CHECK: %[[VAL_111:.*]] = icmp eq i32 %[[VAL_107]], 0 -// CHECK: br i1 %[[VAL_111]], label %[[VAL_112:.*]], label %[[VAL_51]] -// CHECK: inter_warp_reduce-after: ; preds = %[[VAL_113:.*]], %[[VAL_110]] +// CHECK: store float %[[VAL_53]], ptr %partial_reduction_result, align 4 +// CHECK-PTX: %thread.id.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !6 +// CHECK-GCN: %thread.id.x = call i32 @llvm.amdgcn.workitem.id.x +// CHECK-PTX: %block.id.x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !7 +// CHECK-GCN: %block.id.x = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK: %thread.id.2 = urem i32 %thread.id.x, 1024 +// CHECK: %lane_id = urem i32 %thread.id.x, 32 +// CHECK: %[[VAL_63:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VECTOR_OFFSET:.*]] = urem i32 %[[VAL_63]], 1 +// CHECK: %[[VAL_63_2:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_64:.*]] = urem i32 %[[VAL_63_2]], 19 +// CHECK: %[[VAL_65:.*]] = udiv i32 %block.id.x, 19 +// CHECK: %[[VAL_66:.*]] = urem i32 %[[VAL_65]], 1 +// CHECK: %[[VAL_67:.*]] = udiv i32 %block.id.x, 19 +// CHECK: %[[VAL_68:.*]] = icmp eq i32 %[[VAL_64]], 18 +// CHECK: %tile_bound.2 = select i1 %[[VAL_68]], i32 2544, i32 8192 +// CHECK: %tile_origin.0 = mul i32 %[[VAL_67]], 1 +// CHECK: %tile_origin.1 = mul i32 %[[VAL_66]], 1 +// CHECK: %tile_origin.2 = mul i32 %[[VAL_64]], 8192 +// CHECK: %tile_origin.3 = mul i32 %[[VECTOR_OFFSET]], 2 +// CHECK: %[[VAL_81:.*]] = icmp eq i32 8192, %tile_bound.2 +// CHECK: br i1 %[[VAL_81]], label %[[VAL_82:.*]], label %[[VAL_83:.*]] +// CHECK: is_full_tile-after: ; preds = %[[VAL_84:.*]], %[[VAL_85:.*]] +// CHECK: %[[VAL_86:.*]] = load float, ptr %partial_reduction_result, align 4 +// CHECK: %[[VAL_87:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_86]], i32 16, i32 31) +// CHECK: store float %[[VAL_87]], ptr %[[VAL_37]], align 4 +// CHECK: call void @[[MIN:Min.*]](ptr %partial_reduction_result, ptr %[[VAL_37]], ptr %[[VAL_36]]) +// CHECK: %[[VAL_88:.*]] = load float, ptr %[[VAL_36]], align 4 +// CHECK: store float %[[VAL_88]], ptr %partial_reduction_result, align 4 +// CHECK: %[[VAL_89:.*]] = load float, ptr %partial_reduction_result, align 4 +// CHECK: %[[VAL_90:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_89]], i32 8, i32 31) +// CHECK: store float %[[VAL_90]], ptr %[[VAL_35]], align 4 +// CHECK: call void @[[MIN]](ptr %partial_reduction_result, ptr %[[VAL_35]], ptr %[[VAL_34]]) +// CHECK: %[[VAL_91:.*]] = load float, ptr %[[VAL_34]], align 4 +// CHECK: store float %[[VAL_91]], ptr %partial_reduction_result, align 4 +// CHECK: %[[VAL_92:.*]] = load float, ptr %partial_reduction_result, align 4 +// CHECK: %[[VAL_93:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_92]], i32 4, i32 31) +// CHECK: store float %[[VAL_93]], ptr %[[VAL_33]], align 4 +// CHECK: call void @[[MIN]](ptr %partial_reduction_result, ptr %[[VAL_33]], ptr %[[VAL_32]]) +// CHECK: %[[VAL_94:.*]] = load float, ptr %[[VAL_32]], align 4 +// CHECK: store float %[[VAL_94]], ptr %partial_reduction_result, align 4 +// CHECK: %[[VAL_95:.*]] = load float, ptr %partial_reduction_result, align 4 +// CHECK: %[[VAL_96:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_95]], i32 2, i32 31) +// CHECK: store float %[[VAL_96]], ptr %[[VAL_31]], align 4 +// CHECK: call void @[[MIN]](ptr %partial_reduction_result, ptr %[[VAL_31]], ptr %[[VAL_30]]) +// CHECK: %[[VAL_97:.*]] = load float, ptr %[[VAL_30]], align 4 +// CHECK: store float %[[VAL_97]], ptr %partial_reduction_result, align 4 +// CHECK: %[[VAL_98:.*]] = load float, ptr %partial_reduction_result, align 4 +// CHECK: %[[VAL_99:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_98]], i32 1, i32 31) +// CHECK: store float %[[VAL_99]], ptr %[[VAL_29]], align 4 +// CHECK: call void @[[MIN]](ptr %partial_reduction_result, ptr %[[VAL_29]], ptr %[[VAL_28]]) +// CHECK: %[[VAL_100:.*]] = load float, ptr %[[VAL_28]], align 4 +// CHECK: store float %[[VAL_100]], ptr %partial_reduction_result, align 4 +// CHECK: %[[VAL_101:.*]] = udiv i32 %thread.id.2, 32 +// CHECK: br i1 true, label %[[VAL_105:.*]], label %[[VAL_51]] +// CHECK: thread_in_bounds-after: // CHECK: br label %[[VAL_50]] -// CHECK: early_return: ; preds = %[[VAL_49]] -// CHECK: ret void -// CHECK: is_full_tile-true: ; preds = %[[VAL_84]] +// CHECK: is_full_tile-true: // CHECK: store i32 0, ptr %[[VAL_43]], align 4 -// CHECK: br label %[[VAL_115:.*]] -// CHECK: loop2.loop_header: ; preds = %[[VAL_116:.*]], %[[VAL_88]] -// CHECK: %[[VAL_117:.*]] = load i32, ptr %[[VAL_43]], align 4 -// CHECK: %[[VAL_118:.*]] = icmp uge i32 %[[VAL_117]], 8 -// CHECK: br i1 %[[VAL_118]], label %[[VAL_91]], label %[[VAL_116]] -// CHECK: loop2.loop_body: ; preds = %[[VAL_115]] -// CHECK: %[[VAL_119:.*]] = add nuw nsw i32 %[[VAL_117]], 1 -// CHECK: store i32 %[[VAL_119]], ptr %[[VAL_43]], align 4 -// CHECK: %[[VAL_120:.*]] = icmp eq i32 %[[VAL_117]], 0 -// CHECK: %[[VAL_121:.*]] = mul i32 %[[VAL_117]], 2048 -// CHECK: %[[VAL_122:.*]] = add i32 %[[VAL_121]], 0 -// CHECK: %[[VAL_123:.*]] = add i32 %[[VAL_122]], %[[VAL_114]] -// CHECK: %[[VAL_124:.*]] = add i32 %[[VAL_77]], %[[VAL_81]] -// CHECK: %[[VAL_125:.*]] = add i32 %[[VAL_78]], %[[VAL_123]] -// CHECK: %[[VAL_126:.*]] = getelementptr inbounds [300000 x float], ptr %[[VAL_127:.*]], i32 0, i32 %[[VAL_125]] -// CHECK: %[[VAL_128:.*]] = load float, ptr %[[VAL_126]], align 4, !invariant.load !5 -// CHECK: store float %[[VAL_128]], ptr %[[VAL_46]], align 4 -// CHECK: call void @[[MIN]](ptr %[[VAL_45]], ptr %[[VAL_46]], ptr %[[VAL_42]]) -// CHECK: %[[VAL_130:.*]] = load float, ptr %[[VAL_42]], align 4 -// CHECK: store float %[[VAL_130]], ptr %[[VAL_45]], align 4 -// CHECK: %[[VAL_131:.*]] = mul i32 %[[VAL_117]], 2048 -// CHECK: %[[VAL_132:.*]] = add i32 %[[VAL_131]], 1 -// CHECK: %[[VAL_133:.*]] = add i32 %[[VAL_132]], %[[VAL_114]] -// CHECK: %[[VAL_134:.*]] = add i32 %[[VAL_77]], %[[VAL_81]] -// CHECK: %[[VAL_135:.*]] = add i32 %[[VAL_78]], %[[VAL_133]] -// CHECK: %[[VAL_136:.*]] = getelementptr inbounds [300000 x float], ptr %[[VAL_127]], i32 0, i32 %[[VAL_135]] -// CHECK: %[[VAL_137:.*]] = load float, ptr %[[VAL_136]], align 4, !invariant.load !5 -// CHECK: store float %[[VAL_137]], ptr %[[VAL_46]], align 4 -// CHECK: call void @[[MIN]](ptr %[[VAL_45]], ptr %[[VAL_46]], ptr %[[VAL_41]]) -// CHECK: %[[VAL_139:.*]] = load float, ptr %[[VAL_41]], align 4 -// CHECK: store float %[[VAL_139]], ptr %[[VAL_45]], align 4 -// CHECK: br label %[[VAL_115]], !llvm.loop !10 -// CHECK: loop2.loop_exit: ; preds = %[[VAL_115]] -// CHECK: br label %[[VAL_80]] -// CHECK: is_full_tile-false: ; preds = %[[VAL_84]] -// CHECK: store i32 0, ptr %[[VAL_40]], align 4 -// CHECK: br label %[[VAL_141:.*]] -// CHECK: loop2.loop_header7: ; preds = %[[VAL_142:.*]], %[[VAL_89]] -// CHECK: %[[VAL_143:.*]] = load i32, ptr %[[VAL_40]], align 4 -// CHECK: %[[VAL_144:.*]] = icmp uge i32 %[[VAL_143]], 8 -// CHECK: br i1 %[[VAL_144]], label %[[VAL_90]], label %[[VAL_145:.*]] -// CHECK: loop2.loop_body8: ; preds = %[[VAL_141]] -// CHECK: %[[VAL_146:.*]] = add nuw nsw i32 %[[VAL_143]], 1 -// CHECK: store i32 %[[VAL_146]], ptr %[[VAL_40]], align 4 -// CHECK: %[[VAL_147:.*]] = icmp eq i32 %[[VAL_143]], 0 -// CHECK: %[[VAL_148:.*]] = mul i32 %[[VAL_143]], 2048 -// CHECK: %[[VAL_149:.*]] = add i32 %[[VAL_148]], 0 -// CHECK: %[[VAL_150:.*]] = add i32 %[[VAL_149]], %[[VAL_114]] -// CHECK: %[[VAL_151:.*]] = icmp ult i32 %[[VAL_150]], %[[VAL_75]] -// CHECK: br i1 %[[VAL_151]], label %[[VAL_152:.*]], label %[[VAL_153:.*]] -// CHECK: x_in_tile-after: ; preds = %[[VAL_152]], %[[VAL_145]] -// CHECK: %[[VAL_154:.*]] = mul i32 %[[VAL_143]], 2048 -// CHECK: %[[VAL_155:.*]] = add i32 %[[VAL_154]], 1 -// CHECK: %[[VAL_156:.*]] = add i32 %[[VAL_155]], %[[VAL_114]] -// CHECK: %[[VAL_157:.*]] = icmp ult i32 %[[VAL_156]], %[[VAL_75]] -// CHECK: br i1 %[[VAL_157]], label %[[VAL_158:.*]], label %[[VAL_142]] -// CHECK: x_in_tile-after16: ; preds = %[[VAL_158]], %[[VAL_153]] -// CHECK: br label %[[VAL_141]], !llvm.loop !12 -// CHECK: loop2.loop_exit6: ; preds = %[[VAL_141]] +// CHECK: br label %[[VAL_107:.*]] +// CHECK: loop2.loop_header: ; preds = %[[VAL_108:.*]], %[[VAL_82]] +// CHECK: %[[VAL_109:.*]] = load i32, ptr %[[VAL_43]], align 4 +// CHECK: %[[VAL_110:.*]] = icmp uge i32 %[[VAL_109]], 8 +// CHECK: br i1 %[[VAL_110]], label %loop2.loop_exit, label %loop2.loop_body +// CHECK: loop2.loop_body: ; preds = %[[VAL_107]] +// CHECK: %[[VAL_111:.*]] = add nuw nsw i32 %[[VAL_109]], 1 +// CHECK: store i32 %[[VAL_111]], ptr %[[VAL_43]], align 4 +// CHECK: %[[VAL_112:.*]] = icmp eq i32 %[[VAL_109]], 0 +// CHECK: %[[OFFSET_2:.*]] = add i32 %loop2.indvar, %thread.id.2 +// CHECK: store i32 0, ptr %loop3.invar_address, align 4 +// CHECK: br label %loop3.loop_header +// CHECK: loop3.loop_header: +// CHECK: %loop3.indvar = load i32, ptr %loop3.invar_address, align 4 +// CHECK: %[[LOOP3_OOB:.*]] = icmp uge i32 %loop3.indvar, 2 +// CHECK: br i1 %[[LOOP3_OOB]], label %loop3.loop_exit, label %loop3.loop_body +// CHECK: loop3.loop_body: +// CHECK: %[[LOOP3_INC:.*]] = add nuw nsw i32 %loop3.indvar, 1 +// CHECK: store i32 %[[LOOP3_INC]], ptr %loop3.invar_address, align 4 +// CHECK: %[[START_0:.*]] = add i32 %tile_origin.0, 0 +// CHECK: %[[START_1:.*]] = add i32 %tile_origin.1, 0 +// CHECK: %[[START_2:.*]] = add i32 %tile_origin.2, %[[OFFSET_2]] +// CHECK: %[[START_3:.*]] = add i32 %tile_origin.3, %loop3.indvar +// CHECK: %[[VAL_113:.*]] = mul nuw nsw i32 %[[START_3]], 1 +// CHECK: %[[VAL_114:.*]] = add nuw nsw i32 0, %[[VAL_113]] +// CHECK: %[[VAL_115:.*]] = mul nuw nsw i32 %[[START_2]], 2 +// CHECK: %[[VAL_116:.*]] = add nuw nsw i32 %[[VAL_114]], %[[VAL_115]] +// CHECK: %[[VAL_119:.*]] = getelementptr inbounds [300000 x float], ptr %[[VAL_120:.*]], i32 0, i32 %[[VAL_116]] +// CHECK: %[[VAL_121:.*]] = load float, ptr %[[VAL_119]], align 4, !invariant.load !5 +// CHECK: store float %[[VAL_121]], ptr %reduction_input_address, align 4 +// CHECK: call void @[[MIN]](ptr %partial_reduction_result, ptr %reduction_input_address, ptr %[[VAL_42]]) +// CHECK: %[[VAL_123:.*]] = load float, ptr %[[VAL_42]], align 4 +// CHECK: store float %[[VAL_123]], ptr %partial_reduction_result, align 4 +// CHECK: br label %loop3.loop_header +// CHECK: loop3.loop_exit: +// CHECK: br label %loop2.loop_header +// CHECK: loop2.loop_exit: // CHECK: br label %is_full_tile-after -// CHECK: x_in_tile-true: ; preds = %[[VAL_145]] -// CHECK: %[[VAL_159:.*]] = add i32 %[[VAL_77]], %[[VAL_81]] -// CHECK: %[[VAL_160:.*]] = add i32 %[[VAL_78]], %[[VAL_150]] -// CHECK: %[[VAL_161:.*]] = getelementptr inbounds [300000 x float], ptr %[[VAL_127]], i32 0, i32 %[[VAL_160]] -// CHECK: %[[VAL_162:.*]] = load float, ptr %[[VAL_161]], align 4, !invariant.load !5 -// CHECK: store float %[[VAL_162]], ptr %[[VAL_46]], align 4 -// CHECK: call void @[[MIN]](ptr %[[VAL_45]], ptr %[[VAL_46]], ptr %[[VAL_39]]) -// CHECK: %[[VAL_164:.*]] = load float, ptr %[[VAL_39]], align 4 -// CHECK: store float %[[VAL_164]], ptr %[[VAL_45]], align 4 -// CHECK: br label %[[VAL_153]] -// CHECK: x_in_tile-true15: ; preds = %[[VAL_153]] -// CHECK: %[[VAL_165:.*]] = add i32 %[[VAL_77]], %[[VAL_81]] -// CHECK: %[[VAL_166:.*]] = add i32 %[[VAL_78]], %[[VAL_156]] -// CHECK: %[[VAL_167:.*]] = getelementptr inbounds [300000 x float], ptr %[[VAL_127]], i32 0, i32 %[[VAL_166]] -// CHECK: %[[VAL_168:.*]] = load float, ptr %[[VAL_167]], align 4, !invariant.load !5 -// CHECK: store float %[[VAL_168]], ptr %[[VAL_46]], align 4 -// CHECK: call void @[[MIN]](ptr %[[VAL_45]], ptr %[[VAL_46]], ptr %[[VAL_38]]) -// CHECK: %[[VAL_170:.*]] = load float, ptr %[[VAL_38]], align 4 -// CHECK: store float %[[VAL_170]], ptr %[[VAL_45]], align 4 -// CHECK: br label %[[VAL_142]] -// CHECK: intra_warp_reduce_write-true: -// CHECK: %[[VAL_173:.*]] = load float, ptr %[[VAL_45]], align 4 -// CHECK: %[[VAL_171:.*]] = getelementptr inbounds [1 x [32 x float]], ptr addrspace(3) @shared_cache, i32 0, i32 %[[VAL_58]], i32 %[[VAL_107]] -// CHECK: %[[VAL_172:.*]] = addrspacecast ptr addrspace(3) %[[VAL_171]] to ptr -// CHECK: store float %[[VAL_173]], ptr %[[VAL_172]], align 4 -// CHECK: br label %[[VAL_110]] -// CHECK: inter_warp_reduce-true: ; preds = %[[VAL_110]] -// CHECK: %[[VAL_174:.*]] = getelementptr inbounds [1 x [32 x float]], ptr addrspace(3) @shared_cache, i32 0, i32 %[[VAL_58]], i32 %[[VAL_66]] -// CHECK: %[[VAL_175:.*]] = addrspacecast ptr addrspace(3) %[[VAL_174]] to ptr -// CHECK: store float %[[VAL_53]], ptr %[[VAL_27]], align 4 -// CHECK: %[[VAL_176:.*]] = icmp ult i32 %[[VAL_64]], 32 -// CHECK: %[[VAL_177:.*]] = select i1 %[[VAL_176]], ptr %[[VAL_175]], ptr %[[VAL_27]] -// CHECK: %[[VAL_178:.*]] = load float, ptr %[[VAL_177]], align 4 -// CHECK: %[[VAL_179:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_178]], i32 16, i32 31) -// CHECK: store float %[[VAL_179]], ptr %[[VAL_26]], align 4 -// CHECK: call void @[[MIN]](ptr %[[VAL_177]], ptr %[[VAL_26]], ptr %[[VAL_25]]) -// CHECK: %[[VAL_180:.*]] = load float, ptr %[[VAL_25]], align 4 -// CHECK: store float %[[VAL_180]], ptr %[[VAL_177]], align 4 -// CHECK: %[[VAL_181:.*]] = load float, ptr %[[VAL_177]], align 4 -// CHECK: %[[VAL_182:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_181]], i32 8, i32 31) -// CHECK: store float %[[VAL_182]], ptr %[[VAL_24]], align 4 -// CHECK: call void @[[MIN]](ptr %[[VAL_177]], ptr %[[VAL_24]], ptr %[[VAL_23]]) -// CHECK: %[[VAL_183:.*]] = load float, ptr %[[VAL_23]], align 4 -// CHECK: store float %[[VAL_183]], ptr %[[VAL_177]], align 4 -// CHECK: %[[VAL_184:.*]] = load float, ptr %[[VAL_177]], align 4 -// CHECK: %[[VAL_185:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_184]], i32 4, i32 31) -// CHECK: store float %[[VAL_185]], ptr %[[VAL_22]], align 4 -// CHECK: call void @[[MIN]](ptr %[[VAL_177]], ptr %[[VAL_22]], ptr %[[VAL_21]]) -// CHECK: %[[VAL_186:.*]] = load float, ptr %[[VAL_21]], align 4 -// CHECK: store float %[[VAL_186]], ptr %[[VAL_177]], align 4 -// CHECK: %[[VAL_187:.*]] = load float, ptr %[[VAL_177]], align 4 -// CHECK: %[[VAL_188:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_187]], i32 2, i32 31) -// CHECK: store float %[[VAL_188]], ptr %[[VAL_20]], align 4 -// CHECK: call void @[[MIN]](ptr %[[VAL_177]], ptr %[[VAL_20]], ptr %[[VAL_19]]) -// CHECK: %[[VAL_189:.*]] = load float, ptr %[[VAL_19]], align 4 -// CHECK: store float %[[VAL_189]], ptr %[[VAL_177]], align 4 -// CHECK: %[[VAL_190:.*]] = load float, ptr %[[VAL_177]], align 4 -// CHECK: %[[VAL_191:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_190]], i32 1, i32 31) -// CHECK: store float %[[VAL_191]], ptr %[[VAL_18]], align 4 -// CHECK: call void @[[MIN]](ptr %[[VAL_177]], ptr %[[VAL_18]], ptr %[[VAL_17]]) -// CHECK: %[[VAL_192:.*]] = load float, ptr %[[VAL_17]], align 4 -// CHECK: store float %[[VAL_192]], ptr %[[VAL_177]], align 4 -// CHECK: %[[VAL_193:.*]] = icmp eq i32 %[[VAL_64]], 0 -// CHECK: br i1 %[[VAL_193]], label %[[VAL_194:.*]], label %[[VAL_113]] -// CHECK: reduction_write_output-after: ; preds = %[[VAL_195:.*]], %[[VAL_112]] +// CHECK: is_full_tile-false: +// CHECK: store i32 0, ptr %[[LOOP2_I_2]], align 4 +// CHECK: br label %[[VAL_134:.*]] +// CHECK: loop2.loop_header4: +// CHECK: %[[VAL_136:.*]] = load i32, ptr %[[LOOP2_I_2]], align 4 +// CHECK: %[[VAL_137:.*]] = icmp uge i32 %[[VAL_136]], 8 +// CHECK: br i1 %[[VAL_137]], label %[[VAL_84]], label %[[VAL_138:.*]] +// CHECK: loop2.loop_body5: +// CHECK: %[[VAL_139:.*]] = add nuw nsw i32 %[[VAL_136]], 1 +// CHECK: store i32 %[[VAL_139]], ptr %[[LOOP2_I_2]], align 4 +// CHECK: %[[VAL_140:.*]] = icmp eq i32 %[[VAL_136]], 0 +// CHECK: %[[VAL_141:.*]] = add i32 %[[VAL_136]], %thread.id.2 +// CHECK: %[[VAL_144:.*]] = icmp ult i32 %[[VAL_141]], %tile_bound.2 +// CHECK: br i1 %[[VAL_144]], label %x_in_tile-true, label %x_in_tile-after +// CHECK: x_in_tile-after: +// CHECK: br label %loop2.loop_header4 +// CHECK: loop2.loop_exit3: +// CHECK: br label %is_full_tile-after +// CHECK: x_in_tile-true: ; preds = %[[VAL_138]] +// CHECK: store i32 0, ptr %[[LOOP3_I_2]], align 4 +// CHECK: br label %loop3.loop_header10 +// CHECK: loop3.loop_header10: +// CHECK: %[[VAL_145:.*]] = load i32, ptr %[[LOOP3_I_2]], align 4 +// CHECK: %[[VAL_146:.*]] = icmp uge i32 %[[VAL_145]], 2 +// CHECK: br i1 %[[VAL_146]], label %loop3.loop_exit9, label %loop3.loop_body11 +// CHECK: loop3.loop_body11: +// CHECK: %[[VAL_147:.*]] = add nuw nsw i32 %[[VAL_145]], 1 +// CHECK: store i32 %[[VAL_147]], ptr %[[LOOP3_I_2]], align 4 +// CHECK: %[[IDX0:.*]] = add i32 %tile_origin.0, 0 +// CHECK: %[[IDX1:.*]] = add i32 %tile_origin.1, 0 +// CHECK: %[[IDX2:.*]] = add i32 %tile_origin.2, %[[VAL_141]] +// CHECK: %[[IDX3:.*]] = add i32 %tile_origin.3, %[[VAL_145]] +// CHECK: %[[VAL_148:.*]] = mul nuw nsw i32 %[[IDX3]], 1 +// CHECK: %[[VAL_149:.*]] = add nuw nsw i32 0, %[[VAL_148]] +// CHECK: %[[VAL_150:.*]] = mul nuw nsw i32 %[[IDX2]], 2 +// CHECK: %[[VAL_151:.*]] = add nuw nsw i32 %[[VAL_149]], %[[VAL_150]] +// CHECK: %[[VAL_155:.*]] = getelementptr inbounds [300000 x float], ptr %[[VAL_120]], i32 0, i32 %[[VAL_151]] +// CHECK: %[[VAL_156:.*]] = load float, ptr %[[VAL_155]], align 4, !invariant.load !5 +// CHECK: store float %[[VAL_156]], ptr %reduction_input_address, align 4 +// CHECK: call void @[[MIN]](ptr %partial_reduction_result, ptr %reduction_input_address, ptr %[[VAL_38]]) +// CHECK: %[[VAL_158:.*]] = load float, ptr %[[VAL_38]], align 4 +// CHECK: store float %[[VAL_158]], ptr %partial_reduction_result, align 4 +// CHECK: br label %loop3.loop_header10 +// CHECK: loop3.loop_exit9: +// CHECK: br label %x_in_tile-after +// CHECK: thread_in_bounds-true: +// CHECK: %[[VAL_166:.*]] = icmp eq i32 %lane_id, 0 +// CHECK: br i1 %[[VAL_166]], label %[[VAL_167:.*]], label %[[VAL_168:.*]] +// CHECK: intra_warp_reduce_write-after: ; preds = %[[VAL_167]], %[[VAL_105]] +// CHECK: call void @llvm.nvvm.barrier0() +// CHECK: %[[VAL_169:.*]] = icmp eq i32 %[[VAL_101]], 0 +// CHECK: br i1 %[[VAL_169]], label %inter_warp_reduce-true, label %inter_warp_reduce-after +// CHECK: inter_warp_reduce-after: ; preds = %[[VAL_171:.*]], %[[VAL_168]] // CHECK: br label %[[VAL_51]] -// CHECK: reduction_write_output-true: ; preds = %[[VAL_112]] -// CHECK: %[[VAL_197:.*]] = add i32 %[[VAL_77]], %[[VAL_65]] -// CHECK: %[[VAL_198:.*]] = add i32 %[[VAL_78]], %[[VAL_114]] -// CHECK: %[[VAL_199:.*]] = load float, ptr %[[VAL_177]], align 4 -// CHECK: %[[VAL_200:.*]] = load i32, ptr %[[VAL_201:.*]], align 4 -// CHECK: store i32 %[[VAL_200]], ptr %[[VAL_16]], align 4 -// CHECK: br label %[[VAL_202:.*]] -// CHECK: atomic_op_loop_exit: ; preds = %[[VAL_203:.*]], %[[VAL_202]] -// CHECK: br label %[[VAL_113]] -// CHECK: atomic_op_loop_body: ; preds = %[[VAL_203]], %[[VAL_194]] -// CHECK: %[[VAL_204:.*]] = load i32, ptr %[[VAL_16]], align 4 -// CHECK: store i32 %[[VAL_204]], ptr %[[VAL_15]], align 4 -// CHECK: call void @[[MIN]](ptr %[[VAL_15]], ptr %[[VAL_177]], ptr %[[VAL_15]]) -// CHECK: %[[VAL_205:.*]] = load i32, ptr %[[VAL_15]], align 4 -// CHECK: %[[VAL_206:.*]] = icmp eq i32 %[[VAL_204]], %[[VAL_205]] -// CHECK: br i1 %[[VAL_206]], label %[[VAL_195]], label %[[VAL_203]] -// CHECK: atomic_op_loop_cas: ; preds = %[[VAL_202]] -// CHECK: %[[VAL_207:.*]] = cmpxchg ptr %[[VAL_201]], i32 %[[VAL_204]], i32 %[[VAL_205]] seq_cst seq_cst, align 4 -// CHECK: %[[VAL_208:.*]] = extractvalue { i32, i1 } %[[VAL_207]], 0 -// CHECK: store i32 %[[VAL_208]], ptr %[[VAL_16]], align 4 -// CHECK: %[[VAL_209:.*]] = extractvalue { i32, i1 } %[[VAL_207]], 1 -// CHECK: br i1 %[[VAL_209]], label %[[VAL_195]], label %[[VAL_202]] +// CHECK: intra_warp_reduce_write-true: ; preds = %[[VAL_105]] +// CHECK: %[[VAL_172:.*]] = load float, ptr %partial_reduction_result, align 4 +// CHECK: %[[VAL_173:.*]] = getelementptr inbounds [1 x [32 x float]], ptr addrspace(3) @shared_cache, i32 0, i32 0, i32 %[[VAL_101]] +// CHECK: %[[VAL_174:.*]] = addrspacecast ptr addrspace(3) %[[VAL_173]] to ptr +// CHECK: store float %[[VAL_172]], ptr %[[VAL_174]], align 4 +// CHECK: br label %[[VAL_168]] +// CHECK: inter_warp_reduce-true: ; preds = %[[VAL_168]] +// CHECK: %[[VAL_175:.*]] = getelementptr inbounds [1 x [32 x float]], ptr addrspace(3) @shared_cache, i32 0, i32 0, i32 %lane_id +// CHECK: %[[VAL_176:.*]] = addrspacecast ptr addrspace(3) %[[VAL_175]] to ptr +// CHECK: store float %[[VAL_53]], ptr %[[VAL_27]], align 4 +// CHECK: %[[VAL_177:.*]] = icmp ult i32 %thread.id.2, 32 +// CHECK: %[[VAL_178:.*]] = select i1 %[[VAL_177]], ptr %[[VAL_176]], ptr %[[VAL_27]] +// CHECK: %[[VAL_179:.*]] = load float, ptr %[[VAL_178]], align 4 +// CHECK: %[[VAL_180:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_179]], i32 16, i32 31) +// CHECK: store float %[[VAL_180]], ptr %[[VAL_26]], align 4 +// CHECK: call void @[[MIN]](ptr %[[VAL_178]], ptr %[[VAL_26]], ptr %[[VAL_25]]) +// CHECK: %[[VAL_181:.*]] = load float, ptr %[[VAL_25]], align 4 +// CHECK: store float %[[VAL_181]], ptr %[[VAL_178]], align 4 +// CHECK: %[[VAL_182:.*]] = load float, ptr %[[VAL_178]], align 4 +// CHECK: %[[VAL_183:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_182]], i32 8, i32 31) +// CHECK: store float %[[VAL_183]], ptr %[[VAL_24]], align 4 +// CHECK: call void @[[MIN]](ptr %[[VAL_178]], ptr %[[VAL_24]], ptr %[[VAL_23]]) +// CHECK: %[[VAL_184:.*]] = load float, ptr %[[VAL_23]], align 4 +// CHECK: store float %[[VAL_184]], ptr %[[VAL_178]], align 4 +// CHECK: %[[VAL_185:.*]] = load float, ptr %[[VAL_178]], align 4 +// CHECK: %[[VAL_186:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_185]], i32 4, i32 31) +// CHECK: store float %[[VAL_186]], ptr %[[VAL_22]], align 4 +// CHECK: call void @[[MIN]](ptr %[[VAL_178]], ptr %[[VAL_22]], ptr %[[VAL_21]]) +// CHECK: %[[VAL_187:.*]] = load float, ptr %[[VAL_21]], align 4 +// CHECK: store float %[[VAL_187]], ptr %[[VAL_178]], align 4 +// CHECK: %[[VAL_188:.*]] = load float, ptr %[[VAL_178]], align 4 +// CHECK: %[[VAL_189:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_188]], i32 2, i32 31) +// CHECK: store float %[[VAL_189]], ptr %[[VAL_20]], align 4 +// CHECK: call void @[[MIN]](ptr %[[VAL_178]], ptr %[[VAL_20]], ptr %[[VAL_19]]) +// CHECK: %[[VAL_190:.*]] = load float, ptr %[[VAL_19]], align 4 +// CHECK: store float %[[VAL_190]], ptr %[[VAL_178]], align 4 +// CHECK: %[[VAL_191:.*]] = load float, ptr %[[VAL_178]], align 4 +// CHECK: %[[VAL_192:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_191]], i32 1, i32 31) +// CHECK: store float %[[VAL_192]], ptr %[[VAL_18]], align 4 +// CHECK: call void @[[MIN]](ptr %[[VAL_178]], ptr %[[VAL_18]], ptr %[[VAL_17]]) +// CHECK: %[[VAL_193:.*]] = load float, ptr %[[VAL_17]], align 4 +// CHECK: store float %[[VAL_193]], ptr %[[VAL_178]], align 4 +// CHECK: %[[VAL_194:.*]] = icmp eq i32 %thread.id.2, 0 +// CHECK: br i1 %[[VAL_194]], label %[[VAL_195:.*]], label %[[VAL_171]] +// CHECK: reduction_write_output-after: +// CHECK: br label %inter_warp_reduce-after +// CHECK: reduction_write_output-true: +// CHECK: %[[VAL_200:.*]] = load float, ptr %[[VAL_178]], align 4 +// CHECK: %[[VAL_201:.*]] = load i32, ptr %[[VAL_202:.*]], align 4 +// CHECK: store i32 %[[VAL_201]], ptr %[[VAL_16]], align 4 +// CHECK: br label %[[VAL_203:.*]] +// CHECK: atomic_op_loop_exit: ; preds = %[[VAL_204:.*]], %[[VAL_203]] +// CHECK: br label %[[VAL_171]] +// CHECK: atomic_op_loop_body: ; preds = %[[VAL_204]], %[[VAL_195]] +// CHECK: %[[VAL_205:.*]] = load i32, ptr %[[VAL_16]], align 4 +// CHECK: store i32 %[[VAL_205]], ptr %[[VAL_15]], align 4 +// CHECK: call void @[[MIN]](ptr %[[VAL_15]], ptr %[[VAL_178]], ptr %[[VAL_15]]) +// CHECK: %[[VAL_206:.*]] = load i32, ptr %[[VAL_15]], align 4 +// CHECK: %[[VAL_207:.*]] = icmp eq i32 %[[VAL_205]], %[[VAL_206]] +// CHECK: br i1 %[[VAL_207]], label %atomic_op_loop_exit, label %atomic_op_loop_cas +// CHECK: atomic_op_loop_cas: ; preds = %[[VAL_203]] +// CHECK: %[[VAL_208:.*]] = cmpxchg ptr %[[VAL_202]], i32 %[[VAL_205]], i32 %[[VAL_206]] seq_cst seq_cst, align 4 +// CHECK: %[[VAL_209:.*]] = extractvalue { i32, i1 } %[[VAL_208]], 0 +// CHECK: store i32 %[[VAL_209]], ptr %[[VAL_16]], align 4 +// CHECK: %[[VAL_210:.*]] = extractvalue { i32, i1 } %[[VAL_208]], 1 +// CHECK: br i1 %[[VAL_210]], label %atomic_op_loop_exit, label %atomic_op_loop_body // CHECK: entry: -// CHECK: %[[VAL_210:.*]] = alloca float, align 4 -// CHECK: %[[VAL_211:.*]] = load float, ptr %[[VAL_212:.*]], align 4 -// CHECK: %[[VAL_213:.*]] = load float, ptr %[[VAL_214:.*]], align 4 -// CHECK: %[[VAL_215:.*]] = call float @llvm.minimum.f32(float %[[VAL_211]], float %[[VAL_213]]) -// CHECK: store float %[[VAL_215]], ptr %[[VAL_210]], align 4 -// CHECK: %[[VAL_216:.*]] = load float, ptr %[[VAL_210]], align 4 -// CHECK: store float %[[VAL_216]], ptr %[[VAL_217:.*]], align 4 +// CHECK: %[[VAL_211:.*]] = alloca float, align 4 +// CHECK: %[[VAL_212:.*]] = load float, ptr %[[VAL_213:.*]], align 4 +// CHECK: %[[VAL_214:.*]] = load float, ptr %[[VAL_215:.*]], align 4 +// CHECK: %[[VAL_216:.*]] = call float @llvm.minimum.f32(float %[[VAL_212]], float %[[VAL_214]]) +// CHECK: store float %[[VAL_216]], ptr %[[VAL_211]], align 4 +// CHECK: %[[VAL_217:.*]] = load float, ptr %[[VAL_211]], align 4 +// CHECK: store float %[[VAL_217]], ptr %[[VAL_218:.*]], align 4 // CHECK: ret void diff --git a/third_party/xla/xla/service/gpu/tests/reduce_column_layout_change.hlo b/third_party/xla/xla/service/gpu/tests/reduce_column_layout_change.hlo new file mode 100644 index 00000000000000..4c90b12e02ca73 --- /dev/null +++ b/third_party/xla/xla/service/gpu/tests/reduce_column_layout_change.hlo @@ -0,0 +1,196 @@ +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck %s --check-prefixes=CHECK,CHECK-%{PTX} + +HloModule reduce_with_layout_change, is_scheduled=true + +reduction0 { + x0 = f32[] parameter(0) + y0 = f32[] parameter(1) + ROOT add0 = f32[] add(x0, y0) +} + +fused_computation { + arg0 = f32[12,3,32,16,32,4,3,12] parameter(0) + constant0 = f32[] constant(0) + ROOT reduce0 = f32[16,32,4,3,12]{1,3,2,0,4} reduce(arg0, constant0), dimensions={0,1,2}, to_apply=reduction0 +} + +ENTRY kernel_entry { + arg0 = f32[12,3,32,16,32,4,3,12] parameter(0) + ROOT fusion = f32[16,32,4,3,12]{1,3,2,0,4} fusion(arg0), kind=kInput, calls=fused_computation +} + +// CHECK-LABEL: entry: +// CHECK: %[[VAL_0:.*]] = alloca float, align 4 +// CHECK: %[[VAL_1:.*]] = alloca float, align 4 +// CHECK: %[[VAL_2:.*]] = alloca float, align 4 +// CHECK: %[[VAL_3:.*]] = alloca float, align 4 +// CHECK: %[[VAL_4:.*]] = alloca float, align 4 +// CHECK: %[[VAL_5:.*]] = alloca float, align 4 +// CHECK: %[[VAL_6:.*]] = alloca float, align 4 +// CHECK: %[[VAL_7:.*]] = alloca float, align 4 +// CHECK: %[[VAL_8:.*]] = alloca float, align 4 +// CHECK: %[[VAL_9:.*]] = alloca float, align 4 +// CHECK: %[[VAL_10:.*]] = alloca float, align 4 +// CHECK: %[[VAL_11:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_12:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_13:.*]] = alloca float, align 4 +// CHECK: %[[VAL_14:.*]] = alloca float, align 4 +// CHECK-PTX: %[[VAL_15:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y(), !range !2 +// CHECK-GCN: %[[VAL_15:.*]] = call i32 @llvm.amdgcn.workgroup.id.y +// CHECK: %[[VAL_16:.*]] = icmp eq i32 %[[VAL_15]], 0 +// CHECK: br i1 %[[VAL_16]], label %[[VAL_17:.*]], label %[[VAL_18:.*]] +// CHECK: reduce-group-0-after: ; preds = %[[VAL_19:.*]], %[[VAL_20:.*]] +// CHECK: ret void +// CHECK: reduce-group-0-true: ; preds = %[[VAL_20]] +// CHECK: %[[VAL_21:.*]] = load float, ptr @0, align 4 +// CHECK: store float %[[VAL_21]], ptr %[[VAL_13]], align 4 +// CHECK-PTX: %thread.id.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3 +// CHECK-GCN: %thread.id.x = call i32 @llvm.amdgcn.workitem.id.x +// CHECK-PTX: %block.id.x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !4 +// CHECK-GCN: %block.id.x = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK: %[[VAL_22:.*]] = udiv i32 %thread.id.x, 32 +// CHECK: %thread.id.1 = urem i32 %[[VAL_22]], 32 +// CHECK: %thread.id.2 = urem i32 %thread.id.x, 32 +// CHECK: %lane_id = urem i32 %thread.id.x, 32 +// CHECK: %[[VAL_23:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_24:.*]] = urem i32 %[[VAL_23]], 2304 +// CHECK: %[[VAL_25:.*]] = udiv i32 %block.id.x, 2304 +// CHECK: %[[VAL_26:.*]] = urem i32 %[[VAL_25]], 1 +// CHECK: %[[VAL_27:.*]] = udiv i32 %block.id.x, 2304 +// CHECK: %[[VAL_28:.*]] = icmp eq i32 %[[VAL_26]], 0 +// CHECK: %tile_bound.1 = select i1 %[[VAL_28]], i32 1152, i32 4096 +// CHECK: %tile_origin.0 = mul i32 %[[VAL_27]], 1 +// CHECK: %tile_origin.1 = mul i32 %[[VAL_26]], 4096 +// CHECK: %tile_origin.2 = mul i32 %[[VAL_24]], 32 +// CHECK: store i32 %thread.id.1, ptr %[[VAL_12]], align 4 +// CHECK: br label %[[VAL_29:.*]] +// CHECK: loop1.loop_header: ; preds = %[[VAL_30:.*]], %[[VAL_17]] +// CHECK: %[[VAL_31:.*]] = load i32, ptr %[[VAL_12]], align 4 +// CHECK: %[[VAL_32:.*]] = icmp uge i32 %[[VAL_31]], %tile_bound.1 +// CHECK: br i1 %[[VAL_32]], label %[[VAL_33:.*]], label %[[VAL_34:.*]] +// CHECK: loop1.loop_body: ; preds = %[[VAL_29]] +// CHECK: %[[VAL_35:.*]] = add nuw nsw i32 %[[VAL_31]], 32 +// CHECK: store i32 %[[VAL_35]], ptr %[[VAL_12]], align 4 +// CHECK: %[[VAL_36:.*]] = icmp eq i32 %[[VAL_31]], %thread.id.1 +// CHECK: store i32 0, ptr %[[VAL_11]], align 4 +// CHECK: br label %[[VAL_37:.*]] +// CHECK: loop2.loop_header: ; preds = %[[VAL_38:.*]], %[[VAL_34]] +// CHECK: %[[VAL_39:.*]] = load i32, ptr %[[VAL_11]], align 4 +// CHECK: %[[VAL_40:.*]] = icmp uge i32 %[[VAL_39]], 32 +// CHECK: br i1 %[[VAL_40]], label %[[VAL_30]], label %[[VAL_41:.*]] +// CHECK: loop2.loop_body: ; preds = %[[VAL_37]] +// CHECK: %[[VAL_42:.*]] = add nuw nsw i32 %[[VAL_39]], 32 +// CHECK: store i32 %[[VAL_42]], ptr %[[VAL_11]], align 4 +// CHECK: %[[VAL_43:.*]] = icmp eq i32 %[[VAL_39]], 0 +// CHECK: %[[VAL_44:.*]] = add i32 %[[VAL_39]], %thread.id.2 +// CHECK: %[[VAL_45:.*]] = icmp ult i32 %[[VAL_44]], 32 +// CHECK: br i1 %[[VAL_45]], label %[[VAL_46:.*]], label %[[VAL_38]] +// CHECK: x_in_tile-after: ; preds = %[[VAL_46]], %[[VAL_41]] +// CHECK: br label %[[VAL_37]], !llvm.loop !5 +// CHECK: loop2.loop_exit: ; preds = %[[VAL_37]] +// CHECK: br label %[[VAL_29]], !llvm.loop !8 +// CHECK: loop1.loop_exit: ; preds = %[[VAL_29]] +// CHECK: %[[VAL_47:.*]] = load float, ptr %[[VAL_13]], align 4 +// CHECK: %[[VAL_48:.*]] = getelementptr inbounds [32 x [33 x float]], ptr addrspace(3) @shared_cache, i32 0, i32 %thread.id.2, i32 %thread.id.1 +// CHECK: %[[VAL_49:.*]] = addrspacecast ptr addrspace(3) %[[VAL_48]] to ptr +// CHECK: store float %[[VAL_47]], ptr %[[VAL_49]], align 4 +// CHECK: call void @llvm.nvvm.barrier0() +// CHECK: %[[VAL_50:.*]] = getelementptr inbounds [32 x [33 x float]], ptr addrspace(3) @shared_cache, i32 0, i32 %thread.id.1, i32 %thread.id.2 +// CHECK: %[[VAL_51:.*]] = addrspacecast ptr addrspace(3) %[[VAL_50]] to ptr +// CHECK: %[[VAL_52:.*]] = load float, ptr %[[VAL_51]], align 4 +// CHECK: %[[VAL_53:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_52]], i32 16, i32 31) +// CHECK: store float %[[VAL_53]], ptr %[[VAL_9]], align 4 +// CHECK: call void @[[REDUCTION0:reduction0.*]](ptr %[[VAL_51]], ptr %[[VAL_9]], ptr %[[VAL_8]]) +// CHECK: %[[VAL_54:.*]] = load float, ptr %[[VAL_8]], align 4 +// CHECK: store float %[[VAL_54]], ptr %[[VAL_51]], align 4 +// CHECK: %[[VAL_55:.*]] = load float, ptr %[[VAL_51]], align 4 +// CHECK: %[[VAL_56:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_55]], i32 8, i32 31) +// CHECK: store float %[[VAL_56]], ptr %[[VAL_7]], align 4 +// CHECK: call void @[[REDUCTION0]](ptr %[[VAL_51]], ptr %[[VAL_7]], ptr %[[VAL_6]]) +// CHECK: %[[VAL_57:.*]] = load float, ptr %[[VAL_6]], align 4 +// CHECK: store float %[[VAL_57]], ptr %[[VAL_51]], align 4 +// CHECK: %[[VAL_58:.*]] = load float, ptr %[[VAL_51]], align 4 +// CHECK: %[[VAL_59:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_58]], i32 4, i32 31) +// CHECK: store float %[[VAL_59]], ptr %[[VAL_5]], align 4 +// CHECK: call void @[[REDUCTION0]](ptr %[[VAL_51]], ptr %[[VAL_5]], ptr %[[VAL_4]]) +// CHECK: %[[VAL_60:.*]] = load float, ptr %[[VAL_4]], align 4 +// CHECK: store float %[[VAL_60]], ptr %[[VAL_51]], align 4 +// CHECK: %[[VAL_61:.*]] = load float, ptr %[[VAL_51]], align 4 +// CHECK: %[[VAL_62:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_61]], i32 2, i32 31) +// CHECK: store float %[[VAL_62]], ptr %[[VAL_3]], align 4 +// CHECK: call void @[[REDUCTION0]](ptr %[[VAL_51]], ptr %[[VAL_3]], ptr %[[VAL_2]]) +// CHECK: %[[VAL_63:.*]] = load float, ptr %[[VAL_2]], align 4 +// CHECK: store float %[[VAL_63]], ptr %[[VAL_51]], align 4 +// CHECK: %[[VAL_64:.*]] = load float, ptr %[[VAL_51]], align 4 +// CHECK: %[[VAL_65:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_64]], i32 1, i32 31) +// CHECK: store float %[[VAL_65]], ptr %[[VAL_1]], align 4 +// CHECK: call void @[[REDUCTION0]](ptr %[[VAL_51]], ptr %[[VAL_1]], ptr %[[VAL_0]]) +// CHECK: %[[VAL_66:.*]] = load float, ptr %[[VAL_0]], align 4 +// CHECK: store float %[[VAL_66]], ptr %[[VAL_51]], align 4 +// CHECK: %[[VAL_67:.*]] = icmp ult i32 %thread.id.1, 32 +// CHECK: %[[VAL_68:.*]] = icmp ult i32 %thread.id.2, %tile_bound.1 +// CHECK: %[[VAL_69:.*]] = and i1 %[[VAL_67]], %[[VAL_68]] +// CHECK: %[[VAL_70:.*]] = icmp eq i32 %lane_id, 0 +// CHECK: %[[VAL_71:.*]] = and i1 %[[VAL_69]], %[[VAL_70]] +// CHECK: br i1 %[[VAL_71]], label %[[VAL_72:.*]], label %[[VAL_19]] +// CHECK: reduction_write_output-after: ; preds = %[[VAL_72]], %[[VAL_33]] +// CHECK: br label %[[VAL_18]] +// CHECK: x_in_tile-true: ; preds = %[[VAL_41]] +// CHECK: %[[VAL_73:.*]] = add i32 %tile_origin.0, 0 +// CHECK: %[[VAL_74:.*]] = add i32 %tile_origin.1, %[[VAL_31]] +// CHECK: %[[VAL_75:.*]] = add i32 %tile_origin.2, %[[VAL_44]] +// CHECK: %[[VAL_76:.*]] = mul nuw nsw i32 %[[VAL_75]], 1 +// CHECK: %[[VAL_77:.*]] = add nuw nsw i32 0, %[[VAL_76]] +// CHECK: %[[VAL_78:.*]] = urem i32 %[[VAL_77]], 12 +// CHECK: %[[VAL_79:.*]] = udiv i32 %[[VAL_77]], 12 +// CHECK: %[[VAL_80:.*]] = urem i32 %[[VAL_79]], 3 +// CHECK: %[[VAL_81:.*]] = udiv i32 %[[VAL_79]], 3 +// CHECK: %[[VAL_82:.*]] = urem i32 %[[VAL_81]], 4 +// CHECK: %[[VAL_83:.*]] = udiv i32 %[[VAL_81]], 4 +// CHECK: %[[VAL_84:.*]] = urem i32 %[[VAL_83]], 32 +// CHECK: %[[VAL_85:.*]] = udiv i32 %[[VAL_83]], 32 +// CHECK: %[[VAL_86:.*]] = udiv i32 %[[VAL_85]], 16 +// CHECK: %[[VAL_87:.*]] = mul nuw nsw i32 %[[VAL_74]], 1 +// CHECK: %[[VAL_88:.*]] = add nuw nsw i32 0, %[[VAL_87]] +// CHECK: %[[VAL_89:.*]] = urem i32 %[[VAL_88]], 32 +// CHECK: %[[VAL_90:.*]] = udiv i32 %[[VAL_88]], 32 +// CHECK: %[[VAL_91:.*]] = urem i32 %[[VAL_90]], 3 +// CHECK: %[[VAL_92:.*]] = udiv i32 %[[VAL_90]], 3 +// CHECK: %[[VAL_93:.*]] = udiv i32 %[[VAL_92]], 12 +// CHECK: %[[VAL_94:.*]] = mul nuw nsw i32 %[[VAL_73]], 1 +// CHECK: %[[VAL_95:.*]] = add nuw nsw i32 0, %[[VAL_94]] +// CHECK: %[[VAL_96:.*]] = getelementptr inbounds [12 x [3 x [32 x [16 x [32 x [4 x [3 x [12 x float]]]]]]]], ptr %[[VAL_97:.*]], i32 0, i32 %[[VAL_92]], i32 %[[VAL_91]], i32 %[[VAL_89]], i32 %[[VAL_85]], i32 %[[VAL_84]], i32 %[[VAL_82]], i32 %[[VAL_80]], i32 %[[VAL_78]] +// CHECK: %[[VAL_98:.*]] = load float, ptr %[[VAL_96]], align 4, !invariant.load !9 +// CHECK: store float %[[VAL_98]], ptr %[[VAL_14]], align 4 +// CHECK: call void @[[REDUCTION0]](ptr %[[VAL_13]], ptr %[[VAL_14]], ptr %[[VAL_10]]) +// CHECK: %[[VAL_99:.*]] = load float, ptr %[[VAL_10]], align 4 +// CHECK: store float %[[VAL_99]], ptr %[[VAL_13]], align 4 +// CHECK: br label %[[VAL_38]] +// CHECK: reduction_write_output-true: ; preds = %[[VAL_33]] +// CHECK: %[[VAL_100:.*]] = add i32 %tile_origin.2, %thread.id.1 +// CHECK: %[[VAL_101:.*]] = mul nuw nsw i32 %[[VAL_100]], 1 +// CHECK: %[[VAL_102:.*]] = add nuw nsw i32 0, %[[VAL_101]] +// CHECK: %[[VAL_103:.*]] = urem i32 %[[VAL_102]], 12 +// CHECK: %[[VAL_104:.*]] = udiv i32 %[[VAL_102]], 12 +// CHECK: %[[VAL_105:.*]] = urem i32 %[[VAL_104]], 3 +// CHECK: %[[VAL_106:.*]] = udiv i32 %[[VAL_104]], 3 +// CHECK: %[[VAL_107:.*]] = urem i32 %[[VAL_106]], 4 +// CHECK: %[[VAL_108:.*]] = udiv i32 %[[VAL_106]], 4 +// CHECK: %[[VAL_109:.*]] = urem i32 %[[VAL_108]], 32 +// CHECK: %[[VAL_110:.*]] = udiv i32 %[[VAL_108]], 32 +// CHECK: %[[VAL_111:.*]] = udiv i32 %[[VAL_110]], 16 +// CHECK: %[[VAL_112:.*]] = mul nuw nsw i32 %tile_origin.0, 1 +// CHECK: %[[VAL_113:.*]] = add nuw nsw i32 0, %[[VAL_112]] +// CHECK: %[[VAL_114:.*]] = getelementptr inbounds [12 x [16 x [4 x [3 x [32 x float]]]]], ptr %[[VAL_115:.*]], i32 0, i32 %[[VAL_103]], i32 %[[VAL_110]], i32 %[[VAL_107]], i32 %[[VAL_105]], i32 %[[VAL_109]] +// CHECK: %[[VAL_116:.*]] = load float, ptr %[[VAL_51]], align 4 +// CHECK: store float %[[VAL_116]], ptr %[[VAL_114]], align 4 +// CHECK: br label %[[VAL_19]] +// CHECK: entry: +// CHECK: %[[VAL_117:.*]] = alloca float, align 4 +// CHECK: %[[VAL_118:.*]] = load float, ptr %[[VAL_119:.*]], align 4 +// CHECK: %[[VAL_120:.*]] = load float, ptr %[[VAL_121:.*]], align 4 +// CHECK: %[[VAL_122:.*]] = fadd float %[[VAL_118]], %[[VAL_120]] +// CHECK: store float %[[VAL_122]], ptr %[[VAL_117]], align 4 +// CHECK: %[[VAL_123:.*]] = load float, ptr %[[VAL_117]], align 4 +// CHECK: store float %[[VAL_123]], ptr %[[VAL_124:.*]], align 4 +// CHECK: ret void diff --git a/third_party/xla/xla/service/gpu/tests/reduce_f64_column.hlo b/third_party/xla/xla/service/gpu/tests/reduce_f64_column.hlo index 2b7a5dd0b8fadd..abfab462332389 100644 --- a/third_party/xla/xla/service/gpu/tests/reduce_f64_column.hlo +++ b/third_party/xla/xla/service/gpu/tests/reduce_f64_column.hlo @@ -50,187 +50,207 @@ ENTRY e { // CHECK: ret void // CHECK: reduce-group-0-true: ; preds = %[[VAL_20]] // CHECK: %[[VAL_21:.*]] = load double, ptr @0, align 8 -// CHECK: store double %[[VAL_21]], ptr %[[VAL_13]], align 8 -// CHECK-PTX: %[[VAL_22:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3 -// CHECK-GCN: %[[VAL_22:.*]] = call i32 @llvm.amdgcn.workitem.id.x -// CHECK-PTX: %[[VAL_23:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !4 -// CHECK-GCN: %[[VAL_23:.*]] = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK: %[[VAL_24:.*]] = urem i32 %[[VAL_22]], 1024 -// CHECK: %[[VAL_25:.*]] = udiv i32 %[[VAL_22]], 1024 -// CHECK: %[[VAL_26:.*]] = mul i32 %[[VAL_23]], 1 -// CHECK: %[[VAL_27:.*]] = add i32 %[[VAL_26]], %[[VAL_25]] -// CHECK: %[[VAL_28:.*]] = icmp ult i32 %[[VAL_27]], 32 -// CHECK: br i1 %[[VAL_28]], label %[[VAL_29:.*]], label %[[VAL_30:.*]] -// CHECK: 9: ; preds = %[[VAL_17]] -// CHECK: %[[VAL_32:.*]] = udiv i32 %[[VAL_24]], 32 -// CHECK: %[[VAL_31:.*]] = urem i32 %[[VAL_24]], 32 -// CHECK: %[[VAL_54:.*]] = mul i32 %[[VAL_31]], 1 -// CHECK: %[[VAL_33:.*]] = urem i32 %[[VAL_24]], 32 -// CHECK: %[[VAL_34:.*]] = udiv i32 %[[VAL_27]], 1 -// CHECK: %[[VAL_35:.*]] = urem i32 %[[VAL_34]], 32 -// CHECK: %[[VAL_36:.*]] = udiv i32 %[[VAL_27]], 32 -// CHECK: %[[VAL_37:.*]] = urem i32 %[[VAL_36]], 1 -// CHECK: %[[VAL_38:.*]] = udiv i32 %[[VAL_27]], 32 -// CHECK: %[[VAL_39:.*]] = icmp eq i32 %[[VAL_37]], 0 -// CHECK: %[[VAL_40:.*]] = select i1 %[[VAL_39]], i32 1024, i32 4096 -// CHECK: %[[VAL_43:.*]] = mul i32 %[[VAL_38]], 1 -// CHECK: %[[VAL_44:.*]] = mul i32 %[[VAL_37]], 4096 -// CHECK: %[[VAL_45:.*]] = mul i32 %[[VAL_35]], 32 -// CHECK: store i32 %[[VAL_32]], ptr %[[VAL_12]], align 4 -// CHECK: br label %[[VAL_46:.*]] -// CHECK: loop1.loop_header: ; preds = %[[VAL_47:.*]], %[[VAL_29]] -// CHECK: %[[VAL_48:.*]] = load i32, ptr %[[VAL_12]], align 4 -// CHECK: %[[VAL_49:.*]] = icmp uge i32 %[[VAL_48]], %[[VAL_40]] -// CHECK: br i1 %[[VAL_49]], label %[[VAL_50:.*]], label %[[VAL_51:.*]] -// CHECK: loop1.loop_body: ; preds = %[[VAL_46]] -// CHECK: %[[VAL_52:.*]] = add nuw nsw i32 %[[VAL_48]], 32 -// CHECK: store i32 %[[VAL_52]], ptr %[[VAL_12]], align 4 -// CHECK: %[[VAL_53:.*]] = icmp eq i32 %[[VAL_48]], %[[VAL_32]] -// CHECK: store i32 0, ptr %[[VAL_11]], align 4 -// CHECK: br label %[[VAL_55:.*]] -// CHECK: loop2.loop_header: ; preds = %[[VAL_56:.*]], %[[VAL_51]] -// CHECK: %[[VAL_57:.*]] = load i32, ptr %[[VAL_11]], align 4 -// CHECK: %[[VAL_58:.*]] = icmp uge i32 %[[VAL_57]], 1 -// CHECK: br i1 %[[VAL_58]], label %[[VAL_47]], label %[[VAL_59:.*]] -// CHECK: loop2.loop_body: ; preds = %[[VAL_55]] -// CHECK: %[[VAL_60:.*]] = add nuw nsw i32 %[[VAL_57]], 1 -// CHECK: store i32 %[[VAL_60]], ptr %[[VAL_11]], align 4 -// CHECK: %[[VAL_61:.*]] = icmp eq i32 %[[VAL_57]], 0 -// CHECK: %[[VAL_62:.*]] = mul i32 %[[VAL_57]], 32 -// CHECK: %[[VAL_63:.*]] = add i32 %[[VAL_62]], 0 -// CHECK: %[[VAL_64:.*]] = add i32 %[[VAL_63]], %[[VAL_54]] -// CHECK: %[[VAL_65:.*]] = icmp ult i32 %[[VAL_64]], 32 -// CHECK: br i1 %[[VAL_65]], label %[[VAL_66:.*]], label %[[VAL_56]] -// CHECK: x_in_tile-after: ; preds = %[[VAL_66]], %[[VAL_59]] -// CHECK: br label %[[VAL_55]], !llvm.loop !5 -// CHECK: loop2.loop_exit: ; preds = %[[VAL_55]] -// CHECK: br label %[[VAL_46]], !llvm.loop !8 -// CHECK: loop1.loop_exit: ; preds = %[[VAL_46]] -// CHECK: %[[VAL_69:.*]] = load double, ptr %[[VAL_13]], align 8 -// CHECK: %[[VAL_67:.*]] = getelementptr inbounds [32 x [33 x double]], ptr addrspace(3) @shared_cache, i32 0, i32 %[[VAL_31]], i32 %[[VAL_32]] -// CHECK: %[[VAL_68:.*]] = addrspacecast ptr addrspace(3) %[[VAL_67]] to ptr -// CHECK: store double %[[VAL_69]], ptr %[[VAL_68]], align 8 -// CHECK: call void @llvm.nvvm.barrier0() -// CHECK: %[[VAL_70:.*]] = getelementptr inbounds [32 x [33 x double]], ptr addrspace(3) @shared_cache, i32 0, i32 %[[VAL_32]], i32 %[[VAL_31]] -// CHECK: %[[VAL_71:.*]] = addrspacecast ptr addrspace(3) %[[VAL_70]] to ptr -// CHECK: %[[VAL_72:.*]] = load double, ptr %[[VAL_71]], align 8 -// CHECK: %[[VAL_73:.*]] = bitcast double %[[VAL_72]] to i64 -// CHECK: %[[VAL_74:.*]] = bitcast i64 %[[VAL_73]] to <2 x i32> -// CHECK: %[[VAL_75:.*]] = extractelement <2 x i32> %[[VAL_74]], i64 0 -// CHECK: %[[VAL_76:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_75]], i32 16, i32 31) -// CHECK: %[[VAL_77:.*]] = insertelement <2 x i32> %[[VAL_74]], i32 %[[VAL_76]], i64 0 -// CHECK: %[[VAL_78:.*]] = extractelement <2 x i32> %[[VAL_77]], i64 1 -// CHECK: %[[VAL_79:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_78]], i32 16, i32 31) -// CHECK: %[[VAL_80:.*]] = insertelement <2 x i32> %[[VAL_77]], i32 %[[VAL_79]], i64 1 -// CHECK: %[[VAL_81:.*]] = bitcast <2 x i32> %[[VAL_80]] to i64 -// CHECK: %[[VAL_82:.*]] = bitcast i64 %[[VAL_81]] to double -// CHECK: store double %[[VAL_82]], ptr %[[VAL_9]], align 8 -// CHECK: call void @[[ADD:add.*]](ptr %[[VAL_71]], ptr %[[VAL_9]], ptr %[[VAL_8]]) -// CHECK: %[[VAL_83:.*]] = load double, ptr %[[VAL_8]], align 8 -// CHECK: store double %[[VAL_83]], ptr %[[VAL_71]], align 8 -// CHECK: %[[VAL_84:.*]] = load double, ptr %[[VAL_71]], align 8 -// CHECK: %[[VAL_85:.*]] = bitcast double %[[VAL_84]] to i64 -// CHECK: %[[VAL_86:.*]] = bitcast i64 %[[VAL_85]] to <2 x i32> -// CHECK: %[[VAL_87:.*]] = extractelement <2 x i32> %[[VAL_86]], i64 0 -// CHECK: %[[VAL_88:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_87]], i32 8, i32 31) -// CHECK: %[[VAL_89:.*]] = insertelement <2 x i32> %[[VAL_86]], i32 %[[VAL_88]], i64 0 -// CHECK: %[[VAL_90:.*]] = extractelement <2 x i32> %[[VAL_89]], i64 1 -// CHECK: %[[VAL_91:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_90]], i32 8, i32 31) -// CHECK: %[[VAL_92:.*]] = insertelement <2 x i32> %[[VAL_89]], i32 %[[VAL_91]], i64 1 -// CHECK: %[[VAL_93:.*]] = bitcast <2 x i32> %[[VAL_92]] to i64 -// CHECK: %[[VAL_94:.*]] = bitcast i64 %[[VAL_93]] to double -// CHECK: store double %[[VAL_94]], ptr %[[VAL_7]], align 8 -// CHECK: call void @[[ADD]](ptr %[[VAL_71]], ptr %[[VAL_7]], ptr %[[VAL_6]]) -// CHECK: %[[VAL_95:.*]] = load double, ptr %[[VAL_6]], align 8 -// CHECK: store double %[[VAL_95]], ptr %[[VAL_71]], align 8 -// CHECK: %[[VAL_96:.*]] = load double, ptr %[[VAL_71]], align 8 -// CHECK: %[[VAL_97:.*]] = bitcast double %[[VAL_96]] to i64 -// CHECK: %[[VAL_98:.*]] = bitcast i64 %[[VAL_97]] to <2 x i32> -// CHECK: %[[VAL_99:.*]] = extractelement <2 x i32> %[[VAL_98]], i64 0 -// CHECK: %[[VAL_100:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_99]], i32 4, i32 31) -// CHECK: %[[VAL_101:.*]] = insertelement <2 x i32> %[[VAL_98]], i32 %[[VAL_100]], i64 0 -// CHECK: %[[VAL_102:.*]] = extractelement <2 x i32> %[[VAL_101]], i64 1 -// CHECK: %[[VAL_103:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_102]], i32 4, i32 31) -// CHECK: %[[VAL_104:.*]] = insertelement <2 x i32> %[[VAL_101]], i32 %[[VAL_103]], i64 1 -// CHECK: %[[VAL_105:.*]] = bitcast <2 x i32> %[[VAL_104]] to i64 -// CHECK: %[[VAL_106:.*]] = bitcast i64 %[[VAL_105]] to double -// CHECK: store double %[[VAL_106]], ptr %[[VAL_5]], align 8 -// CHECK: call void @[[ADD]](ptr %[[VAL_71]], ptr %[[VAL_5]], ptr %[[VAL_4]]) -// CHECK: %[[VAL_107:.*]] = load double, ptr %[[VAL_4]], align 8 -// CHECK: store double %[[VAL_107]], ptr %[[VAL_71]], align 8 -// CHECK: %[[VAL_108:.*]] = load double, ptr %[[VAL_71]], align 8 -// CHECK: %[[VAL_109:.*]] = bitcast double %[[VAL_108]] to i64 -// CHECK: %[[VAL_110:.*]] = bitcast i64 %[[VAL_109]] to <2 x i32> -// CHECK: %[[VAL_111:.*]] = extractelement <2 x i32> %[[VAL_110]], i64 0 -// CHECK: %[[VAL_112:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_111]], i32 2, i32 31) -// CHECK: %[[VAL_113:.*]] = insertelement <2 x i32> %[[VAL_110]], i32 %[[VAL_112]], i64 0 -// CHECK: %[[VAL_114:.*]] = extractelement <2 x i32> %[[VAL_113]], i64 1 -// CHECK: %[[VAL_115:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_114]], i32 2, i32 31) -// CHECK: %[[VAL_116:.*]] = insertelement <2 x i32> %[[VAL_113]], i32 %[[VAL_115]], i64 1 -// CHECK: %[[VAL_117:.*]] = bitcast <2 x i32> %[[VAL_116]] to i64 -// CHECK: %[[VAL_118:.*]] = bitcast i64 %[[VAL_117]] to double -// CHECK: store double %[[VAL_118]], ptr %[[VAL_3]], align 8 -// CHECK: call void @[[ADD]](ptr %[[VAL_71]], ptr %[[VAL_3]], ptr %[[VAL_2]]) -// CHECK: %[[VAL_119:.*]] = load double, ptr %[[VAL_2]], align 8 -// CHECK: store double %[[VAL_119]], ptr %[[VAL_71]], align 8 -// CHECK: %[[VAL_120:.*]] = load double, ptr %[[VAL_71]], align 8 -// CHECK: %[[VAL_121:.*]] = bitcast double %[[VAL_120]] to i64 -// CHECK: %[[VAL_122:.*]] = bitcast i64 %[[VAL_121]] to <2 x i32> -// CHECK: %[[VAL_123:.*]] = extractelement <2 x i32> %[[VAL_122]], i64 0 -// CHECK: %[[VAL_124:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_123]], i32 1, i32 31) -// CHECK: %[[VAL_125:.*]] = insertelement <2 x i32> %[[VAL_122]], i32 %[[VAL_124]], i64 0 -// CHECK: %[[VAL_126:.*]] = extractelement <2 x i32> %[[VAL_125]], i64 1 -// CHECK: %[[VAL_127:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_126]], i32 1, i32 31) -// CHECK: %[[VAL_128:.*]] = insertelement <2 x i32> %[[VAL_125]], i32 %[[VAL_127]], i64 1 -// CHECK: %[[VAL_129:.*]] = bitcast <2 x i32> %[[VAL_128]] to i64 -// CHECK: %[[VAL_130:.*]] = bitcast i64 %[[VAL_129]] to double -// CHECK: store double %[[VAL_130]], ptr %[[VAL_1]], align 8 -// CHECK: call void @[[ADD]](ptr %[[VAL_71]], ptr %[[VAL_1]], ptr %[[VAL_0]]) -// CHECK: %[[VAL_131:.*]] = load double, ptr %[[VAL_0]], align 8 -// CHECK: store double %[[VAL_131]], ptr %[[VAL_71]], align 8 -// CHECK: %[[VAL_133:.*]] = icmp ult i32 %[[VAL_32]], 32 -// CHECK: %[[VAL_134:.*]] = icmp ult i32 %[[VAL_31]], %[[VAL_40]] -// CHECK: %[[VAL_135:.*]] = and i1 %[[VAL_133]], %[[VAL_134]] -// CHECK: %[[VAL_136:.*]] = icmp eq i32 %[[VAL_33]], 0 -// CHECK: %[[VAL_137:.*]] = and i1 %[[VAL_135]], %[[VAL_136]] -// CHECK: br i1 %[[VAL_137]], label %[[VAL_138:.*]], label %[[VAL_19]] -// CHECK: reduction_write_output-after: ; preds = %[[VAL_138]], %[[VAL_50]] +// CHECK: store double %[[VAL_21]], ptr{{.*}}%[[VAL_13]], align 8 +// CHECK-PTX: %thread.id.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3 +// CHECK-GCN: %thread.id.x = call i32 @llvm.amdgcn.workitem.id.x +// CHECK-PTX: %block.id.x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !4 +// CHECK-GCN: %block.id.x = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK: %[[VAL_22:.*]] = udiv i32 %thread.id.x, 32 +// CHECK: %thread.id.1 = urem i32 %[[VAL_22]], 32 +// CHECK: %thread.id.2 = urem i32 %thread.id.x, 32 +// CHECK: %lane_id = urem i32 %thread.id.x, 32 +// CHECK: %[[VAL_23:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_24:.*]] = urem i32 %[[VAL_23]], 32 +// CHECK: %[[VAL_25:.*]] = udiv i32 %block.id.x, 32 +// CHECK: %[[VAL_26:.*]] = urem i32 %[[VAL_25]], 1 +// CHECK: %[[VAL_27:.*]] = udiv i32 %block.id.x, 32 +// CHECK: %[[VAL_28:.*]] = icmp eq i32 %[[VAL_26]], 0 +// CHECK: %tile_bound.1 = select i1 %[[VAL_28]], i32 1024, i32 4096 +// CHECK: %tile_origin.0 = mul i32 %[[VAL_27]], 1 +// CHECK: %tile_origin.1 = mul i32 %[[VAL_26]], 4096 +// CHECK: %tile_origin.2 = mul i32 %[[VAL_24]], 32 +// CHECK: store i32 %thread.id.1, ptr{{.*}}%[[VAL_12]], align 4 +// CHECK: br label %[[VAL_29:.*]] +// CHECK: loop1.loop_header: ; preds = %[[VAL_30:.*]], %[[VAL_17]] +// CHECK: %[[VAL_31:.*]] = load i32, ptr{{.*}}%[[VAL_12]], align 4 +// CHECK: %[[VAL_32:.*]] = icmp uge i32 %[[VAL_31]], %tile_bound.1 +// CHECK: br i1 %[[VAL_32]], label %[[VAL_33:.*]], label %[[VAL_34:.*]] +// CHECK: loop1.loop_body: ; preds = %[[VAL_29]] +// CHECK: %[[VAL_35:.*]] = add nuw nsw i32 %[[VAL_31]], 32 +// CHECK: store i32 %[[VAL_35]], ptr{{.*}}%[[VAL_12]], align 4 +// CHECK: %[[VAL_36:.*]] = icmp eq i32 %[[VAL_31]], %thread.id.1 +// CHECK: store i32 0, ptr{{.*}}%[[VAL_11]], align 4 +// CHECK: br label %[[VAL_37:.*]] +// CHECK: loop2.loop_header: ; preds = %[[VAL_38:.*]], %[[VAL_34]] +// CHECK: %[[VAL_39:.*]] = load i32, ptr{{.*}}%[[VAL_11]], align 4 +// CHECK: %[[VAL_40:.*]] = icmp uge i32 %[[VAL_39]], 32 +// CHECK: br i1 %[[VAL_40]], label %[[VAL_30]], label %[[VAL_41:.*]] +// CHECK: loop2.loop_body: ; preds = %[[VAL_37]] +// CHECK: %[[VAL_42:.*]] = add nuw nsw i32 %[[VAL_39]], 32 +// CHECK: store i32 %[[VAL_42]], ptr{{.*}}%[[VAL_11]], align 4 +// CHECK: %[[VAL_43:.*]] = icmp eq i32 %[[VAL_39]], 0 +// CHECK: %[[VAL_44:.*]] = add i32 %[[VAL_39]], %thread.id.2 +// CHECK: %[[VAL_45:.*]] = icmp ult i32 %[[VAL_44]], 32 +// CHECK: br i1 %[[VAL_45]], label %[[VAL_46:.*]], label %[[VAL_38]] +// CHECK: x_in_tile-after: ; preds = %[[VAL_46]], %[[VAL_41]] +// CHECK: br label %[[VAL_37]], !llvm.loop !{{[0-9]}} +// CHECK: loop2.loop_exit: ; preds = %[[VAL_37]] +// CHECK: br label %[[VAL_29]], !llvm.loop !{{[0-9]}} +// CHECK: loop1.loop_exit: ; preds = %[[VAL_29]] +// CHECK: %[[VAL_47:.*]] = load double, ptr{{.*}}%[[VAL_13]], align 8 +// CHECK: %[[VAL_48:.*]] = getelementptr inbounds [32 x [33 x double]], ptr addrspace(3) @shared_cache, i32 0, i32 %thread.id.2, i32 %thread.id.1 +// CHECK: %[[VAL_49:.*]] = addrspacecast ptr addrspace(3) %[[VAL_48]] to ptr +// CHECK: store double %[[VAL_47]], ptr{{.*}}%[[VAL_49]], align 8 +// CHECK-PTX: call void @llvm.nvvm.barrier0() +// CHECK-GCN: call void @llvm.amdgcn.s.barrier() +// CHECK: %[[VAL_50:.*]] = getelementptr inbounds [32 x [33 x double]], ptr addrspace(3) @shared_cache, i32 0, i32 %thread.id.1, i32 %thread.id.2 +// CHECK: %[[VAL_51:.*]] = addrspacecast ptr addrspace(3) %[[VAL_50]] to ptr +// CHECK: %[[VAL_52:.*]] = load double, ptr{{.*}}%[[VAL_51]], align 8 +// CHECK: %[[VAL_53:.*]] = bitcast double %[[VAL_52]] to i64 +// CHECK: %[[VAL_54:.*]] = bitcast i64 %[[VAL_53]] to <2 x i32> +// CHECK: %[[VAL_55:.*]] = extractelement <2 x i32> %[[VAL_54]], i64 0 +// CHECK-PTX: %[[VAL_56:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_55]], i32 16, i32 31) +// CHECK-GCN: %[[VAL_56:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_55]], i32 16) +// CHECK: %[[VAL_57:.*]] = insertelement <2 x i32> %[[VAL_54]], i32 %[[VAL_56]], i64 0 +// CHECK: %[[VAL_58:.*]] = extractelement <2 x i32> %[[VAL_57]], i64 1 +// CHECK-PTX: %[[VAL_59:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_58]], i32 16, i32 31) +// CHECK-GCN: %[[VAL_59:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_58]], i32 16) +// CHECK: %[[VAL_60:.*]] = insertelement <2 x i32> %[[VAL_57]], i32 %[[VAL_59]], i64 1 +// CHECK: %[[VAL_61:.*]] = bitcast <2 x i32> %[[VAL_60]] to i64 +// CHECK: %[[VAL_62:.*]] = bitcast i64 %[[VAL_61]] to double +// CHECK: store double %[[VAL_62]], ptr{{.*}}%[[VAL_9]], align 8 +// CHECK-PTX: call void @[[ADD:add.*]](ptr %[[VAL_51]], ptr %[[VAL_9]], ptr %[[VAL_8]]) +// CHECK-GCN: %[[VAL_9_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_9]] to ptr +// CHECK-GCN: %[[VAL_8_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_8]] to ptr +// CHECK-GCN: call void @[[ADD:add.*]](ptr %[[VAL_51]], ptr %[[VAL_9_1]], ptr %[[VAL_8_1]]) +// CHECK: %[[VAL_63:.*]] = load double, ptr{{.*}}%[[VAL_8]], align 8 +// CHECK: store double %[[VAL_63]], ptr{{.*}}%[[VAL_51]], align 8 +// CHECK: %[[VAL_64:.*]] = load double, ptr{{.*}}%[[VAL_51]], align 8 +// CHECK: %[[VAL_65:.*]] = bitcast double %[[VAL_64]] to i64 +// CHECK: %[[VAL_66:.*]] = bitcast i64 %[[VAL_65]] to <2 x i32> +// CHECK: %[[VAL_67:.*]] = extractelement <2 x i32> %[[VAL_66]], i64 0 +// CHECK-PTX: %[[VAL_68:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_67]], i32 8, i32 31) +// CHECK-GCN: %[[VAL_68:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_67]], i32 8) +// CHECK: %[[VAL_69:.*]] = insertelement <2 x i32> %[[VAL_66]], i32 %[[VAL_68]], i64 0 +// CHECK: %[[VAL_70:.*]] = extractelement <2 x i32> %[[VAL_69]], i64 1 +// CHECK-PTX: %[[VAL_71:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_70]], i32 8, i32 31) +// CHECK-GCN: %[[VAL_71:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_70]], i32 8) +// CHECK: %[[VAL_72:.*]] = insertelement <2 x i32> %[[VAL_69]], i32 %[[VAL_71]], i64 1 +// CHECK: %[[VAL_73:.*]] = bitcast <2 x i32> %[[VAL_72]] to i64 +// CHECK: %[[VAL_74:.*]] = bitcast i64 %[[VAL_73]] to double +// CHECK: store double %[[VAL_74]], ptr{{.*}}%[[VAL_7]], align 8 +// CHECK-PTX: call void @[[ADD]](ptr %[[VAL_51]], ptr %[[VAL_7]], ptr %[[VAL_6]]) +// CHECK-GCN: %[[VAL_7_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_7]] to ptr +// CHECK-GCN: %[[VAL_6_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_6]] to ptr +// CHECK-GCN: call void @[[ADD]](ptr %[[VAL_51]], ptr %[[VAL_7_1]], ptr %[[VAL_6_1]]) +// CHECK: %[[VAL_75:.*]] = load double, ptr{{.*}}%[[VAL_6]], align 8 +// CHECK: store double %[[VAL_75]], ptr{{.*}}%[[VAL_51]], align 8 +// CHECK: %[[VAL_76:.*]] = load double, ptr{{.*}}%[[VAL_51]], align 8 +// CHECK: %[[VAL_77:.*]] = bitcast double %[[VAL_76]] to i64 +// CHECK: %[[VAL_78:.*]] = bitcast i64 %[[VAL_77]] to <2 x i32> +// CHECK: %[[VAL_79:.*]] = extractelement <2 x i32> %[[VAL_78]], i64 0 +// CHECK-PTX: %[[VAL_80:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_79]], i32 4, i32 31) +// CHECK-GCN: %[[VAL_80:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_79]], i32 4) +// CHECK: %[[VAL_81:.*]] = insertelement <2 x i32> %[[VAL_78]], i32 %[[VAL_80]], i64 0 +// CHECK: %[[VAL_82:.*]] = extractelement <2 x i32> %[[VAL_81]], i64 1 +// CHECK-PTX: %[[VAL_83:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_82]], i32 4, i32 31) +// CHECK-GCN: %[[VAL_83:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_82]], i32 4) +// CHECK: %[[VAL_84:.*]] = insertelement <2 x i32> %[[VAL_81]], i32 %[[VAL_83]], i64 1 +// CHECK: %[[VAL_85:.*]] = bitcast <2 x i32> %[[VAL_84]] to i64 +// CHECK: %[[VAL_86:.*]] = bitcast i64 %[[VAL_85]] to double +// CHECK: store double %[[VAL_86]], ptr{{.*}}%[[VAL_5]], align 8 +// CHECK-PTX: call void @[[ADD]](ptr %[[VAL_51]], ptr %[[VAL_5]], ptr %[[VAL_4]]) +// CHECK-GCN: %[[VAL_5_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_5]] to ptr +// CHECK-GCN: %[[VAL_4_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_4]] to ptr +// CHECK-GCN: call void @[[ADD]](ptr %[[VAL_51]], ptr %[[VAL_5_1]], ptr %[[VAL_4_1]]) +// CHECK: %[[VAL_87:.*]] = load double, ptr{{.*}}%[[VAL_4]], align 8 +// CHECK: store double %[[VAL_87]], ptr{{.*}}%[[VAL_51]], align 8 +// CHECK: %[[VAL_88:.*]] = load double, ptr{{.*}}%[[VAL_51]], align 8 +// CHECK: %[[VAL_89:.*]] = bitcast double %[[VAL_88]] to i64 +// CHECK: %[[VAL_90:.*]] = bitcast i64 %[[VAL_89]] to <2 x i32> +// CHECK: %[[VAL_91:.*]] = extractelement <2 x i32> %[[VAL_90]], i64 0 +// CHECK-PTX: %[[VAL_92:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_91]], i32 2, i32 31) +// CHECK-GCN: %[[VAL_92:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_91]], i32 2) +// CHECK: %[[VAL_93:.*]] = insertelement <2 x i32> %[[VAL_90]], i32 %[[VAL_92]], i64 0 +// CHECK: %[[VAL_94:.*]] = extractelement <2 x i32> %[[VAL_93]], i64 1 +// CHECK-PTX: %[[VAL_95:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_94]], i32 2, i32 31) +// CHECK-GCN: %[[VAL_95:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_94]], i32 2) +// CHECK: %[[VAL_96:.*]] = insertelement <2 x i32> %[[VAL_93]], i32 %[[VAL_95]], i64 1 +// CHECK: %[[VAL_97:.*]] = bitcast <2 x i32> %[[VAL_96]] to i64 +// CHECK: %[[VAL_98:.*]] = bitcast i64 %[[VAL_97]] to double +// CHECK: store double %[[VAL_98]], ptr{{.*}}%[[VAL_3]], align 8 +// CHECK-PTX: call void @[[ADD]](ptr %[[VAL_51]], ptr %[[VAL_3]], ptr %[[VAL_2]]) +// CHECK-GCN: %[[VAL_3_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_3]] to ptr +// CHECK-GCN: %[[VAL_2_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_2]] to ptr +// CHECK-GCN: call void @[[ADD]](ptr %[[VAL_51]], ptr %[[VAL_3_1]], ptr %[[VAL_2_1]]) +// CHECK: %[[VAL_99:.*]] = load double, ptr{{.*}}%[[VAL_2]], align 8 +// CHECK: store double %[[VAL_99]], ptr{{.*}}%[[VAL_51]], align 8 +// CHECK: %[[VAL_100:.*]] = load double, ptr{{.*}}%[[VAL_51]], align 8 +// CHECK: %[[VAL_101:.*]] = bitcast double %[[VAL_100]] to i64 +// CHECK: %[[VAL_102:.*]] = bitcast i64 %[[VAL_101]] to <2 x i32> +// CHECK: %[[VAL_103:.*]] = extractelement <2 x i32> %[[VAL_102]], i64 0 +// CHECK-PTX: %[[VAL_104:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_103]], i32 1, i32 31) +// CHECK-GCN: %[[VAL_104:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_103]], i32 1) +// CHECK: %[[VAL_105:.*]] = insertelement <2 x i32> %[[VAL_102]], i32 %[[VAL_104]], i64 0 +// CHECK: %[[VAL_106:.*]] = extractelement <2 x i32> %[[VAL_105]], i64 1 +// CHECK-PTX: %[[VAL_107:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_106]], i32 1, i32 31) +// CHECK-GCN: %[[VAL_107:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_106]], i32 1) +// CHECK: %[[VAL_108:.*]] = insertelement <2 x i32> %[[VAL_105]], i32 %[[VAL_107]], i64 1 +// CHECK: %[[VAL_109:.*]] = bitcast <2 x i32> %[[VAL_108]] to i64 +// CHECK: %[[VAL_110:.*]] = bitcast i64 %[[VAL_109]] to double +// CHECK: store double %[[VAL_110]], ptr{{.*}}%[[VAL_1]], align 8 +// CHECK-PTX: call void @[[ADD]](ptr %[[VAL_51]], ptr %[[VAL_1]], ptr %[[VAL_0]]) +// CHECK-GCN: %[[VAL_1_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_1]] to ptr +// CHECK-GCN: %[[VAL_0_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_0]] to ptr +// CHECK-GCN: call void @[[ADD]](ptr %[[VAL_51]], ptr %[[VAL_1_1]], ptr %[[VAL_0_1]]) +// CHECK: %[[VAL_111:.*]] = load double, ptr{{.*}}%[[VAL_0]], align 8 +// CHECK: store double %[[VAL_111]], ptr{{.*}}%[[VAL_51]], align 8 +// CHECK-PTX: %[[VAL_112:.*]] = icmp ult i32 %thread.id.1, 32 +// CHECK-PTX: %[[VAL_113:.*]] = icmp ult i32 %thread.id.2, %tile_bound.1 +// CHECK-GCN: %[[VAL_113:.*]] = icmp ult i32 %thread.id.2, %tile_bound.1 +// CHECK-GCN: %[[VAL_112:.*]] = icmp ult i32 %thread.id.1, 32 +// CHECK: %[[VAL_114:.*]] = and i1 %[[VAL_112]], %[[VAL_113]] +// CHECK: %[[VAL_115:.*]] = icmp eq i32 %lane_id, 0 +// CHECK: %[[VAL_116:.*]] = and i1 %[[VAL_114]], %[[VAL_115]] +// CHECK: br i1 %[[VAL_116]], label %[[VAL_117:.*]], label %[[VAL_19]] +// CHECK: reduction_write_output-after: ; preds = %[[VAL_117]], %[[VAL_33]] // CHECK: br label %[[VAL_18]] -// CHECK: early_return: ; preds = %[[VAL_17]] -// CHECK: ret void -// CHECK: x_in_tile-true: ; preds = %[[VAL_59]] -// CHECK: %[[VAL_139:.*]] = add i32 %[[VAL_44]], %[[VAL_48]] -// CHECK: %[[VAL_140:.*]] = add i32 %[[VAL_45]], %[[VAL_64]] -// CHECK: %[[VAL_143:.*]] = getelementptr inbounds [1024 x [1024 x i8]], ptr %[[VAL_144:.*]], i32 0, i32 %[[VAL_139]], i32 %[[VAL_140]] -// CHECK: %[[VAL_145:.*]] = load i8, ptr %[[VAL_143]], align 1, !invariant.load !9 -// CHECK: %[[VAL_146:.*]] = getelementptr inbounds [1024 x [1024 x double]], ptr %[[VAL_147:.*]], i32 0, i32 %[[VAL_139]], i32 %[[VAL_140]] -// CHECK: %[[VAL_148:.*]] = load double, ptr %[[VAL_146]], align 8, !invariant.load !9 -// CHECK: %[[VAL_149:.*]] = getelementptr inbounds [1024 x [1024 x double]], ptr %[[VAL_150:.*]], i32 0, i32 %[[VAL_139]], i32 %[[VAL_140]] -// CHECK: %[[VAL_151:.*]] = load double, ptr %[[VAL_149]], align 8, !invariant.load !9 -// CHECK: %[[VAL_152:.*]] = trunc i8 %[[VAL_145]] to i1 -// CHECK: %[[VAL_153:.*]] = select i1 %[[VAL_152]], double %[[VAL_148]], double %[[VAL_151]] -// CHECK: store double %[[VAL_153]], ptr %[[VAL_14]], align 8 -// CHECK: call void @[[ADD]](ptr %[[VAL_13]], ptr %[[VAL_14]], ptr %[[VAL_10]]) -// CHECK: %[[VAL_155:.*]] = load double, ptr %[[VAL_10]], align 8 -// CHECK: store double %[[VAL_155]], ptr %[[VAL_13]], align 8 -// CHECK: br label %[[VAL_56]] -// CHECK: reduction_write_output-true: ; preds = %[[VAL_50]] -// CHECK: %[[VAL_156:.*]] = add i32 %[[VAL_43]], 0 -// CHECK: %[[VAL_157:.*]] = add i32 %[[VAL_44]], %[[VAL_54]] -// CHECK: %[[VAL_158:.*]] = add i32 %[[VAL_45]], %[[VAL_32]] -// CHECK: %[[VAL_159:.*]] = mul i32 %[[VAL_156]], 1024 -// CHECK: %[[VAL_160:.*]] = add i32 %[[VAL_159]], %[[VAL_158]] -// CHECK: %[[VAL_161:.*]] = udiv i32 %[[VAL_160]], 1 -// CHECK: %[[VAL_162:.*]] = getelementptr inbounds [1024 x double], ptr %[[VAL_163:.*]], i32 0, i32 %[[VAL_161]] -// CHECK: %[[VAL_164:.*]] = load double, ptr %[[VAL_71]], align 8 -// CHECK: store double %[[VAL_164]], ptr %[[VAL_162]], align 8 +// CHECK: x_in_tile-true: ; preds = %[[VAL_41]] +// CHECK: %[[VAL_118:.*]] = add i32 %tile_origin.0, 0 +// CHECK: %[[VAL_119:.*]] = add i32 %tile_origin.1, %[[VAL_31]] +// CHECK: %[[VAL_120:.*]] = add i32 %tile_origin.2, %[[VAL_44]] +// CHECK: %[[VAL_121:.*]] = getelementptr inbounds [1024 x [1024 x i8]], ptr{{.*}}%[[VAL_122:.*]], i32 0, i32 %[[VAL_119]], i32 %[[VAL_120]] +// CHECK: %[[VAL_123:.*]] = load i8, ptr{{.*}}%[[VAL_121]], align 1, !invariant.load !{{[0-9]}} +// CHECK: %[[VAL_124:.*]] = getelementptr inbounds [1024 x [1024 x double]], ptr{{.*}}%[[VAL_125:.*]], i32 0, i32 %[[VAL_119]], i32 %[[VAL_120]] +// CHECK: %[[VAL_126:.*]] = load double, ptr{{.*}}%[[VAL_124]], align 8, !invariant.load !{{[0-9]}} +// CHECK: %[[VAL_127:.*]] = getelementptr inbounds [1024 x [1024 x double]], ptr{{.*}}%[[VAL_128:.*]], i32 0, i32 %[[VAL_119]], i32 %[[VAL_120]] +// CHECK: %[[VAL_129:.*]] = load double, ptr{{.*}}%[[VAL_127]], align 8, !invariant.load !{{[0-9]}} +// CHECK: %[[VAL_130:.*]] = trunc i8 %[[VAL_123]] to i1 +// CHECK: %[[VAL_131:.*]] = select i1 %[[VAL_130]], double %[[VAL_126]], double %[[VAL_129]] +// CHECK: store double %[[VAL_131]], ptr{{.*}}%[[VAL_14]], align 8 +// CHECK-PTX: call void @[[ADD]](ptr %[[VAL_13]], ptr %[[VAL_14]], ptr %[[VAL_10]]) +// CHECK-GCN: %[[VAL_13_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_13]] to ptr +// CHECK-GCN: %[[VAL_14_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_14]] to ptr +// CHECK-GCN: %[[VAL_10_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_10]] to ptr +// CHECK-GCN: call void @[[ADD]](ptr %[[VAL_13_1]], ptr %[[VAL_14_1]], ptr %[[VAL_10_1]]) +// CHECK: %[[VAL_132:.*]] = load double, ptr{{.*}}%[[VAL_10]], align 8 +// CHECK: store double %[[VAL_132]], ptr{{.*}}%[[VAL_13]], align 8 +// CHECK: br label %[[VAL_38]] +// CHECK: reduction_write_output-true: ; preds = %[[VAL_33]] +// CHECK: %[[VAL_135:.*]] = add i32 %tile_origin.2, %thread.id.1 +// CHECK: %[[VAL_139:.*]] = getelementptr inbounds [1024 x double], ptr{{.*}}%[[VAL_140:.*]], i32 0, i32 %[[VAL_135]] +// CHECK: %[[VAL_141:.*]] = load double, ptr{{.*}}%[[VAL_51]], align 8 +// CHECK: store double %[[VAL_141]], ptr{{.*}}%[[VAL_139]], align 8 // CHECK: br label %[[VAL_19]] // CHECK: entry: -// CHECK: %[[VAL_165:.*]] = alloca double, align 8 -// CHECK: %[[VAL_166:.*]] = load double, ptr %[[VAL_167:.*]], align 8 -// CHECK: %[[VAL_168:.*]] = load double, ptr %[[VAL_169:.*]], align 8 -// CHECK: %[[VAL_170:.*]] = fadd double %[[VAL_166]], %[[VAL_168]] -// CHECK: store double %[[VAL_170]], ptr %[[VAL_165]], align 8 -// CHECK: %[[VAL_171:.*]] = load double, ptr %[[VAL_165]], align 8 -// CHECK: store double %[[VAL_171]], ptr %[[VAL_172:.*]], align 8 +// CHECK: %[[VAL_142:.*]] = alloca double, align 8 +// CHECK: %[[VAL_143:.*]] = load double, ptr{{.*}}%[[VAL_144:.*]], align 8 +// CHECK: %[[VAL_145:.*]] = load double, ptr{{.*}}%[[VAL_146:.*]], align 8 +// CHECK: %[[VAL_147:.*]] = fadd double %[[VAL_143]], %[[VAL_145]] +// CHECK: store double %[[VAL_147]], ptr{{.*}}%[[VAL_142]], align 8 +// CHECK: %[[VAL_148:.*]] = load double, ptr{{.*}}%[[VAL_142]], align 8 +// CHECK: store double %[[VAL_148]], ptr{{.*}}%[[VAL_149:.*]], align 8 // CHECK: ret void + +// CHECK-PTX: !3 = !{i32 0, i32 1024} +// CHECK-PTX: !4 = !{i32 0, i32 32} diff --git a/third_party/xla/xla/service/gpu/tests/reduce_large_row_to_scalar.hlo b/third_party/xla/xla/service/gpu/tests/reduce_large_row_to_scalar.hlo index ffa8bc386d98fb..21d32aebf1915e 100644 --- a/third_party/xla/xla/service/gpu/tests/reduce_large_row_to_scalar.hlo +++ b/third_party/xla/xla/service/gpu/tests/reduce_large_row_to_scalar.hlo @@ -43,381 +43,378 @@ ENTRY reduce.1 { // CHECK: %[[VAL_20:.*]] = alloca %[[VAL_1]], align 8 // CHECK: %[[VAL_21:.*]] = alloca %[[VAL_1]], align 8 // CHECK: %[[VAL_22:.*]] = alloca %[[VAL_1]], align 8 -// CHECK: %[[VAL_23:.*]] = alloca %[[VAL_1]], align 8 +// CHECK: %[[VAL_23:.*]] = alloca i32, align 4 // CHECK: %[[VAL_24:.*]] = alloca i32, align 4 // CHECK: %[[VAL_25:.*]] = alloca %[[VAL_1]], align 8 -// CHECK: %[[VAL_26:.*]] = alloca %[[VAL_1]], align 8 +// CHECK: %[[VAL_26:.*]] = alloca i32, align 4 // CHECK: %[[VAL_27:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_28:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_28:.*]] = alloca %[[VAL_1]], align 8 // CHECK: %[[VAL_29:.*]] = alloca %[[VAL_1]], align 8 -// CHECK: %[[VAL_30:.*]] = alloca %[[VAL_1]], align 8 -// CHECK-PTX: %[[VAL_31:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y(), !range !2 -// CHECK-GCN: %[[VAL_31:.*]] = call i32 @llvm.amdgcn.workgroup.id.y -// CHECK: %[[VAL_32:.*]] = icmp eq i32 %[[VAL_31]], 0 -// CHECK: br i1 %[[VAL_32]], label %[[VAL_33:.*]], label %[[VAL_34:.*]] -// CHECK: reduce-group-0-after: ; preds = %[[VAL_35:.*]], %[[VAL_36:.*]] +// CHECK-PTX: %[[VAL_30:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y(), !range !2 +// CHECK-GCN: %[[VAL_30:.*]] = call i32 @llvm.amdgcn.workgroup.id.y +// CHECK: %[[VAL_31:.*]] = icmp eq i32 %[[VAL_30]], 0 +// CHECK: br i1 %[[VAL_31]], label %[[VAL_32:.*]], label %[[VAL_33:.*]] +// CHECK: reduce-group-0-after: ; preds = %thread_in_bounds-after, %[[VAL_34:.*]] // CHECK: ret void -// CHECK: reduce-group-0-true: ; preds = %[[VAL_36]] -// CHECK: %[[VAL_37:.*]] = load %[[VAL_1]], ptr %[[VAL_38:.*]], align 1, !invariant.load !3 -// CHECK: store %[[VAL_1]] %[[VAL_37]], ptr %[[VAL_29]], align 1 -// CHECK-PTX: %[[VAL_39:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !4 -// CHECK-GCN: %[[VAL_39:.*]] = call i32 @llvm.amdgcn.workitem.id.x -// CHECK-PTX: %[[VAL_40:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2 -// CHECK-GCN: %[[VAL_40:.*]] = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK: %[[VAL_41:.*]] = urem i32 %[[VAL_39]], 640 -// CHECK: %[[VAL_42:.*]] = udiv i32 %[[VAL_39]], 640 -// CHECK: %[[VAL_43:.*]] = mul i32 %[[VAL_40]], 1 -// CHECK: %[[VAL_44:.*]] = add i32 %[[VAL_43]], %[[VAL_42]] -// CHECK: %[[VAL_45:.*]] = icmp ult i32 %[[VAL_44]], 1 +// CHECK: reduce-group-0-true: ; preds = %[[VAL_34]] +// CHECK: %[[VAL_35:.*]] = load %[[VAL_1]], ptr %[[VAL_36:.*]], align 1, !invariant.load !3 +// CHECK: store %[[VAL_1]] %[[VAL_35]], ptr %[[VAL_28]], align 1 +// CHECK-PTX: %thread.id.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !4 +// CHECK-GCN: %thread.id.x = call i32 @llvm.amdgcn.workitem.id.x +// CHECK-PTX: %block.id.x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2 +// CHECK-GCN: %block.id.x = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK: %thread.id.2 = urem i32 %thread.id.x, 640 +// CHECK: %lane_id = urem i32 %thread.id.x, 32 +// CHECK: %[[VAL_37:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_38:.*]] = urem i32 %[[VAL_37]], 1 +// CHECK: %[[VAL_39:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_40:.*]] = urem i32 %[[VAL_39]], 1 +// CHECK: %[[VAL_41:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_42:.*]] = urem i32 %[[VAL_41]], 1 +// CHECK: %[[VAL_43:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_44:.*]] = icmp eq i32 %[[VAL_40]], 0 +// CHECK: %tile_bound.2 = select i1 %[[VAL_44]], i32 5000, i32 5120 +// CHECK: %tile_origin.0 = mul i32 %[[VAL_43]], 1 +// CHECK: %tile_origin.1 = mul i32 %[[VAL_42]], 1 +// CHECK: %tile_origin.2 = mul i32 %[[VAL_40]], 5120 +// CHECK: %tile_origin.3 = mul i32 %[[VAL_38]], 2 +// CHECK: %[[VAL_45:.*]] = icmp eq i32 5120, %tile_bound.2 // CHECK: br i1 %[[VAL_45]], label %[[VAL_46:.*]], label %[[VAL_47:.*]] -// CHECK: 9: ; preds = %[[VAL_33]] -// CHECK: %[[VAL_49:.*]] = udiv i32 %[[VAL_41]], 640 -// CHECK: %[[VAL_48:.*]] = urem i32 %[[VAL_41]], 640 -// CHECK: %[[VAL_189:.*]] = mul i32 %[[VAL_48]], 2 -// CHECK: %[[VAL_50:.*]] = urem i32 %[[VAL_41]], 32 -// CHECK: %[[VAL_51:.*]] = udiv i32 %[[VAL_44]], 1 -// CHECK: %[[VAL_52:.*]] = urem i32 %[[VAL_51]], 1 -// CHECK: %[[VAL_53:.*]] = udiv i32 %[[VAL_44]], 1 -// CHECK: %[[VAL_54:.*]] = urem i32 %[[VAL_53]], 1 -// CHECK: %[[VAL_55:.*]] = udiv i32 %[[VAL_44]], 1 -// CHECK: %[[VAL_58:.*]] = icmp eq i32 %[[VAL_52]], 0 -// CHECK: %[[VAL_59:.*]] = select i1 %[[VAL_58]], i32 10000, i32 10240 -// CHECK: %[[VAL_60:.*]] = mul i32 %[[VAL_55]], 1 -// CHECK: %[[VAL_61:.*]] = mul i32 %[[VAL_54]], 1 -// CHECK: %[[VAL_62:.*]] = mul i32 %[[VAL_52]], 10240 -// CHECK: store i32 %[[VAL_49]], ptr %[[VAL_28]], align 4 -// CHECK: br label %[[VAL_63:.*]] -// CHECK: loop1.loop_header: ; preds = %[[VAL_64:.*]], %[[VAL_46]] -// CHECK: %[[VAL_65:.*]] = load i32, ptr %[[VAL_28]], align 4 -// CHECK: %[[VAL_66:.*]] = icmp uge i32 %[[VAL_65]], 1 -// CHECK: br i1 %[[VAL_66]], label %[[VAL_67:.*]], label %[[VAL_68:.*]] -// CHECK: loop1.loop_body: ; preds = %[[VAL_63]] -// CHECK: %[[VAL_69:.*]] = add nuw nsw i32 %[[VAL_65]], 1 -// CHECK: store i32 %[[VAL_69]], ptr %[[VAL_28]], align 4 -// CHECK: %[[VAL_70:.*]] = icmp eq i32 %[[VAL_65]], %[[VAL_49]] -// CHECK: %[[VAL_71:.*]] = icmp eq i32 10240, %[[VAL_59]] -// CHECK: br i1 %[[VAL_71]], label %[[VAL_72:.*]], label %[[VAL_73:.*]] -// CHECK: is_full_tile-after: ; preds = %[[VAL_74:.*]], %[[VAL_75:.*]] -// CHECK: br label %[[VAL_63]], !llvm.loop !5 -// CHECK: loop1.loop_exit: ; preds = %[[VAL_63]] -// CHECK: %[[VAL_76:.*]] = load i128, ptr %[[VAL_29]], align 16 -// CHECK: %[[VAL_77:.*]] = bitcast i128 %[[VAL_76]] to <4 x i32> -// CHECK: %[[VAL_78:.*]] = extractelement <4 x i32> %[[VAL_77]], i64 0 -// CHECK: %[[VAL_79:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_78]], i32 16, i32 31) -// CHECK: %[[VAL_80:.*]] = insertelement <4 x i32> %[[VAL_77]], i32 %[[VAL_79]], i64 0 -// CHECK: %[[VAL_81:.*]] = extractelement <4 x i32> %[[VAL_80]], i64 1 -// CHECK: %[[VAL_82:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_81]], i32 16, i32 31) -// CHECK: %[[VAL_83:.*]] = insertelement <4 x i32> %[[VAL_80]], i32 %[[VAL_82]], i64 1 -// CHECK: %[[VAL_84:.*]] = extractelement <4 x i32> %[[VAL_83]], i64 2 -// CHECK: %[[VAL_85:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_84]], i32 16, i32 31) -// CHECK: %[[VAL_86:.*]] = insertelement <4 x i32> %[[VAL_83]], i32 %[[VAL_85]], i64 2 -// CHECK: %[[VAL_87:.*]] = extractelement <4 x i32> %[[VAL_86]], i64 3 -// CHECK: %[[VAL_88:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_87]], i32 16, i32 31) -// CHECK: %[[VAL_89:.*]] = insertelement <4 x i32> %[[VAL_86]], i32 %[[VAL_88]], i64 3 -// CHECK: %[[VAL_90:.*]] = bitcast <4 x i32> %[[VAL_89]] to i128 -// CHECK: store i128 %[[VAL_90]], ptr %[[VAL_21]], align 16 -// CHECK: call void @[[SUM:Sum.*]](ptr %[[VAL_29]], ptr %[[VAL_21]], ptr %[[VAL_20]]) -// CHECK: %[[VAL_91:.*]] = load %[[VAL_1]], ptr %[[VAL_20]], align 1 -// CHECK: store %[[VAL_1]] %[[VAL_91]], ptr %[[VAL_29]], align 1 -// CHECK: %[[VAL_92:.*]] = load i128, ptr %[[VAL_29]], align 16 -// CHECK: %[[VAL_93:.*]] = bitcast i128 %[[VAL_92]] to <4 x i32> -// CHECK: %[[VAL_94:.*]] = extractelement <4 x i32> %[[VAL_93]], i64 0 -// CHECK: %[[VAL_95:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_94]], i32 8, i32 31) -// CHECK: %[[VAL_96:.*]] = insertelement <4 x i32> %[[VAL_93]], i32 %[[VAL_95]], i64 0 -// CHECK: %[[VAL_97:.*]] = extractelement <4 x i32> %[[VAL_96]], i64 1 -// CHECK: %[[VAL_98:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_97]], i32 8, i32 31) -// CHECK: %[[VAL_99:.*]] = insertelement <4 x i32> %[[VAL_96]], i32 %[[VAL_98]], i64 1 -// CHECK: %[[VAL_100:.*]] = extractelement <4 x i32> %[[VAL_99]], i64 2 -// CHECK: %[[VAL_101:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_100]], i32 8, i32 31) -// CHECK: %[[VAL_102:.*]] = insertelement <4 x i32> %[[VAL_99]], i32 %[[VAL_101]], i64 2 -// CHECK: %[[VAL_103:.*]] = extractelement <4 x i32> %[[VAL_102]], i64 3 -// CHECK: %[[VAL_104:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_103]], i32 8, i32 31) -// CHECK: %[[VAL_105:.*]] = insertelement <4 x i32> %[[VAL_102]], i32 %[[VAL_104]], i64 3 -// CHECK: %[[VAL_106:.*]] = bitcast <4 x i32> %[[VAL_105]] to i128 -// CHECK: store i128 %[[VAL_106]], ptr %[[VAL_19]], align 16 -// CHECK: call void @[[SUM]](ptr %[[VAL_29]], ptr %[[VAL_19]], ptr %[[VAL_18]]) -// CHECK: %[[VAL_107:.*]] = load %[[VAL_1]], ptr %[[VAL_18]], align 1 -// CHECK: store %[[VAL_1]] %[[VAL_107]], ptr %[[VAL_29]], align 1 -// CHECK: %[[VAL_108:.*]] = load i128, ptr %[[VAL_29]], align 16 -// CHECK: %[[VAL_109:.*]] = bitcast i128 %[[VAL_108]] to <4 x i32> -// CHECK: %[[VAL_110:.*]] = extractelement <4 x i32> %[[VAL_109]], i64 0 -// CHECK: %[[VAL_111:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_110]], i32 4, i32 31) -// CHECK: %[[VAL_112:.*]] = insertelement <4 x i32> %[[VAL_109]], i32 %[[VAL_111]], i64 0 -// CHECK: %[[VAL_113:.*]] = extractelement <4 x i32> %[[VAL_112]], i64 1 -// CHECK: %[[VAL_114:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_113]], i32 4, i32 31) -// CHECK: %[[VAL_115:.*]] = insertelement <4 x i32> %[[VAL_112]], i32 %[[VAL_114]], i64 1 -// CHECK: %[[VAL_116:.*]] = extractelement <4 x i32> %[[VAL_115]], i64 2 -// CHECK: %[[VAL_117:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_116]], i32 4, i32 31) -// CHECK: %[[VAL_118:.*]] = insertelement <4 x i32> %[[VAL_115]], i32 %[[VAL_117]], i64 2 -// CHECK: %[[VAL_119:.*]] = extractelement <4 x i32> %[[VAL_118]], i64 3 -// CHECK: %[[VAL_120:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_119]], i32 4, i32 31) -// CHECK: %[[VAL_121:.*]] = insertelement <4 x i32> %[[VAL_118]], i32 %[[VAL_120]], i64 3 -// CHECK: %[[VAL_122:.*]] = bitcast <4 x i32> %[[VAL_121]] to i128 -// CHECK: store i128 %[[VAL_122]], ptr %[[VAL_17]], align 16 -// CHECK: call void @[[SUM]](ptr %[[VAL_29]], ptr %[[VAL_17]], ptr %[[VAL_16]]) -// CHECK: %[[VAL_123:.*]] = load %[[VAL_1]], ptr %[[VAL_16]], align 1 -// CHECK: store %[[VAL_1]] %[[VAL_123]], ptr %[[VAL_29]], align 1 -// CHECK: %[[VAL_124:.*]] = load i128, ptr %[[VAL_29]], align 16 -// CHECK: %[[VAL_125:.*]] = bitcast i128 %[[VAL_124]] to <4 x i32> -// CHECK: %[[VAL_126:.*]] = extractelement <4 x i32> %[[VAL_125]], i64 0 -// CHECK: %[[VAL_127:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_126]], i32 2, i32 31) -// CHECK: %[[VAL_128:.*]] = insertelement <4 x i32> %[[VAL_125]], i32 %[[VAL_127]], i64 0 -// CHECK: %[[VAL_129:.*]] = extractelement <4 x i32> %[[VAL_128]], i64 1 -// CHECK: %[[VAL_130:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_129]], i32 2, i32 31) -// CHECK: %[[VAL_131:.*]] = insertelement <4 x i32> %[[VAL_128]], i32 %[[VAL_130]], i64 1 -// CHECK: %[[VAL_132:.*]] = extractelement <4 x i32> %[[VAL_131]], i64 2 -// CHECK: %[[VAL_133:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_132]], i32 2, i32 31) -// CHECK: %[[VAL_134:.*]] = insertelement <4 x i32> %[[VAL_131]], i32 %[[VAL_133]], i64 2 -// CHECK: %[[VAL_135:.*]] = extractelement <4 x i32> %[[VAL_134]], i64 3 -// CHECK: %[[VAL_136:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_135]], i32 2, i32 31) -// CHECK: %[[VAL_137:.*]] = insertelement <4 x i32> %[[VAL_134]], i32 %[[VAL_136]], i64 3 -// CHECK: %[[VAL_138:.*]] = bitcast <4 x i32> %[[VAL_137]] to i128 -// CHECK: store i128 %[[VAL_138]], ptr %[[VAL_15]], align 16 -// CHECK: call void @[[SUM]](ptr %[[VAL_29]], ptr %[[VAL_15]], ptr %[[VAL_14]]) -// CHECK: %[[VAL_139:.*]] = load %[[VAL_1]], ptr %[[VAL_14]], align 1 -// CHECK: store %[[VAL_1]] %[[VAL_139]], ptr %[[VAL_29]], align 1 -// CHECK: %[[VAL_140:.*]] = load i128, ptr %[[VAL_29]], align 16 -// CHECK: %[[VAL_141:.*]] = bitcast i128 %[[VAL_140]] to <4 x i32> -// CHECK: %[[VAL_142:.*]] = extractelement <4 x i32> %[[VAL_141]], i64 0 -// CHECK: %[[VAL_143:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_142]], i32 1, i32 31) -// CHECK: %[[VAL_144:.*]] = insertelement <4 x i32> %[[VAL_141]], i32 %[[VAL_143]], i64 0 -// CHECK: %[[VAL_145:.*]] = extractelement <4 x i32> %[[VAL_144]], i64 1 -// CHECK: %[[VAL_146:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_145]], i32 1, i32 31) -// CHECK: %[[VAL_147:.*]] = insertelement <4 x i32> %[[VAL_144]], i32 %[[VAL_146]], i64 1 -// CHECK: %[[VAL_148:.*]] = extractelement <4 x i32> %[[VAL_147]], i64 2 -// CHECK: %[[VAL_149:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_148]], i32 1, i32 31) -// CHECK: %[[VAL_150:.*]] = insertelement <4 x i32> %[[VAL_147]], i32 %[[VAL_149]], i64 2 -// CHECK: %[[VAL_151:.*]] = extractelement <4 x i32> %[[VAL_150]], i64 3 -// CHECK: %[[VAL_152:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_151]], i32 1, i32 31) -// CHECK: %[[VAL_153:.*]] = insertelement <4 x i32> %[[VAL_150]], i32 %[[VAL_152]], i64 3 -// CHECK: %[[VAL_154:.*]] = bitcast <4 x i32> %[[VAL_153]] to i128 -// CHECK: store i128 %[[VAL_154]], ptr %[[VAL_13]], align 16 -// CHECK: call void @[[SUM]](ptr %[[VAL_29]], ptr %[[VAL_13]], ptr %[[VAL_12]]) -// CHECK: %[[VAL_155:.*]] = load %[[VAL_1]], ptr %[[VAL_12]], align 1 -// CHECK: store %[[VAL_1]] %[[VAL_155]], ptr %[[VAL_29]], align 1 -// CHECK: %[[VAL_156:.*]] = udiv i32 %[[VAL_48]], 32 -// CHECK: %[[VAL_157:.*]] = icmp eq i32 %[[VAL_50]], 0 -// CHECK: br i1 %[[VAL_157]], label %[[VAL_158:.*]], label %[[VAL_159:.*]] -// CHECK: intra_warp_reduce_write-after: ; preds = %[[VAL_158]], %[[VAL_67]] -// CHECK: call void @llvm.nvvm.barrier0() -// CHECK: %[[VAL_160:.*]] = icmp eq i32 %[[VAL_156]], 0 -// CHECK: br i1 %[[VAL_160]], label %[[VAL_161:.*]], label %[[VAL_35]] -// CHECK: inter_warp_reduce-after: ; preds = %[[VAL_162:.*]], %[[VAL_159]] -// CHECK: br label %[[VAL_34]] -// CHECK: early_return: ; preds = %[[VAL_33]] -// CHECK: ret void -// CHECK: is_full_tile-true: ; preds = %[[VAL_68]] +// CHECK: is_full_tile-after: ; preds = %[[VAL_48:.*]], %[[VAL_49:.*]] +// CHECK: %[[VAL_50:.*]] = load i128, ptr %[[VAL_28]], align 16 +// CHECK: %[[VAL_51:.*]] = bitcast i128 %[[VAL_50]] to <4 x i32> +// CHECK: %[[VAL_52:.*]] = extractelement <4 x i32> %[[VAL_51]], i64 0 +// CHECK: %[[VAL_53:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_52]], i32 16, i32 31) +// CHECK: %[[VAL_54:.*]] = insertelement <4 x i32> %[[VAL_51]], i32 %[[VAL_53]], i64 0 +// CHECK: %[[VAL_55:.*]] = extractelement <4 x i32> %[[VAL_54]], i64 1 +// CHECK: %[[VAL_56:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_55]], i32 16, i32 31) +// CHECK: %[[VAL_57:.*]] = insertelement <4 x i32> %[[VAL_54]], i32 %[[VAL_56]], i64 1 +// CHECK: %[[VAL_58:.*]] = extractelement <4 x i32> %[[VAL_57]], i64 2 +// CHECK: %[[VAL_59:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_58]], i32 16, i32 31) +// CHECK: %[[VAL_60:.*]] = insertelement <4 x i32> %[[VAL_57]], i32 %[[VAL_59]], i64 2 +// CHECK: %[[VAL_61:.*]] = extractelement <4 x i32> %[[VAL_60]], i64 3 +// CHECK: %[[VAL_62:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_61]], i32 16, i32 31) +// CHECK: %[[VAL_63:.*]] = insertelement <4 x i32> %[[VAL_60]], i32 %[[VAL_62]], i64 3 +// CHECK: %[[VAL_64:.*]] = bitcast <4 x i32> %[[VAL_63]] to i128 +// CHECK: store i128 %[[VAL_64]], ptr %[[VAL_21]], align 16 +// CHECK: call void @[[SUM:Sum.*]](ptr %[[VAL_28]], ptr %[[VAL_21]], ptr %[[VAL_20]]) +// CHECK: %[[VAL_65:.*]] = load %[[VAL_1]], ptr %[[VAL_20]], align 1 +// CHECK: store %[[VAL_1]] %[[VAL_65]], ptr %[[VAL_28]], align 1 +// CHECK: %[[VAL_66:.*]] = load i128, ptr %[[VAL_28]], align 16 +// CHECK: %[[VAL_67:.*]] = bitcast i128 %[[VAL_66]] to <4 x i32> +// CHECK: %[[VAL_68:.*]] = extractelement <4 x i32> %[[VAL_67]], i64 0 +// CHECK: %[[VAL_69:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_68]], i32 8, i32 31) +// CHECK: %[[VAL_70:.*]] = insertelement <4 x i32> %[[VAL_67]], i32 %[[VAL_69]], i64 0 +// CHECK: %[[VAL_71:.*]] = extractelement <4 x i32> %[[VAL_70]], i64 1 +// CHECK: %[[VAL_72:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_71]], i32 8, i32 31) +// CHECK: %[[VAL_73:.*]] = insertelement <4 x i32> %[[VAL_70]], i32 %[[VAL_72]], i64 1 +// CHECK: %[[VAL_74:.*]] = extractelement <4 x i32> %[[VAL_73]], i64 2 +// CHECK: %[[VAL_75:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_74]], i32 8, i32 31) +// CHECK: %[[VAL_76:.*]] = insertelement <4 x i32> %[[VAL_73]], i32 %[[VAL_75]], i64 2 +// CHECK: %[[VAL_77:.*]] = extractelement <4 x i32> %[[VAL_76]], i64 3 +// CHECK: %[[VAL_78:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_77]], i32 8, i32 31) +// CHECK: %[[VAL_79:.*]] = insertelement <4 x i32> %[[VAL_76]], i32 %[[VAL_78]], i64 3 +// CHECK: %[[VAL_80:.*]] = bitcast <4 x i32> %[[VAL_79]] to i128 +// CHECK: store i128 %[[VAL_80]], ptr %[[VAL_19]], align 16 +// CHECK: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_19]], ptr %[[VAL_18]]) +// CHECK: %[[VAL_81:.*]] = load %[[VAL_1]], ptr %[[VAL_18]], align 1 +// CHECK: store %[[VAL_1]] %[[VAL_81]], ptr %[[VAL_28]], align 1 +// CHECK: %[[VAL_82:.*]] = load i128, ptr %[[VAL_28]], align 16 +// CHECK: %[[VAL_83:.*]] = bitcast i128 %[[VAL_82]] to <4 x i32> +// CHECK: %[[VAL_84:.*]] = extractelement <4 x i32> %[[VAL_83]], i64 0 +// CHECK: %[[VAL_85:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_84]], i32 4, i32 31) +// CHECK: %[[VAL_86:.*]] = insertelement <4 x i32> %[[VAL_83]], i32 %[[VAL_85]], i64 0 +// CHECK: %[[VAL_87:.*]] = extractelement <4 x i32> %[[VAL_86]], i64 1 +// CHECK: %[[VAL_88:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_87]], i32 4, i32 31) +// CHECK: %[[VAL_89:.*]] = insertelement <4 x i32> %[[VAL_86]], i32 %[[VAL_88]], i64 1 +// CHECK: %[[VAL_90:.*]] = extractelement <4 x i32> %[[VAL_89]], i64 2 +// CHECK: %[[VAL_91:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_90]], i32 4, i32 31) +// CHECK: %[[VAL_92:.*]] = insertelement <4 x i32> %[[VAL_89]], i32 %[[VAL_91]], i64 2 +// CHECK: %[[VAL_93:.*]] = extractelement <4 x i32> %[[VAL_92]], i64 3 +// CHECK: %[[VAL_94:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_93]], i32 4, i32 31) +// CHECK: %[[VAL_95:.*]] = insertelement <4 x i32> %[[VAL_92]], i32 %[[VAL_94]], i64 3 +// CHECK: %[[VAL_96:.*]] = bitcast <4 x i32> %[[VAL_95]] to i128 +// CHECK: store i128 %[[VAL_96]], ptr %[[VAL_17]], align 16 +// CHECK: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_17]], ptr %[[VAL_16]]) +// CHECK: %[[VAL_97:.*]] = load %[[VAL_1]], ptr %[[VAL_16]], align 1 +// CHECK: store %[[VAL_1]] %[[VAL_97]], ptr %[[VAL_28]], align 1 +// CHECK: %[[VAL_98:.*]] = load i128, ptr %[[VAL_28]], align 16 +// CHECK: %[[VAL_99:.*]] = bitcast i128 %[[VAL_98]] to <4 x i32> +// CHECK: %[[VAL_100:.*]] = extractelement <4 x i32> %[[VAL_99]], i64 0 +// CHECK: %[[VAL_101:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_100]], i32 2, i32 31) +// CHECK: %[[VAL_102:.*]] = insertelement <4 x i32> %[[VAL_99]], i32 %[[VAL_101]], i64 0 +// CHECK: %[[VAL_103:.*]] = extractelement <4 x i32> %[[VAL_102]], i64 1 +// CHECK: %[[VAL_104:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_103]], i32 2, i32 31) +// CHECK: %[[VAL_105:.*]] = insertelement <4 x i32> %[[VAL_102]], i32 %[[VAL_104]], i64 1 +// CHECK: %[[VAL_106:.*]] = extractelement <4 x i32> %[[VAL_105]], i64 2 +// CHECK: %[[VAL_107:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_106]], i32 2, i32 31) +// CHECK: %[[VAL_108:.*]] = insertelement <4 x i32> %[[VAL_105]], i32 %[[VAL_107]], i64 2 +// CHECK: %[[VAL_109:.*]] = extractelement <4 x i32> %[[VAL_108]], i64 3 +// CHECK: %[[VAL_110:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_109]], i32 2, i32 31) +// CHECK: %[[VAL_111:.*]] = insertelement <4 x i32> %[[VAL_108]], i32 %[[VAL_110]], i64 3 +// CHECK: %[[VAL_112:.*]] = bitcast <4 x i32> %[[VAL_111]] to i128 +// CHECK: store i128 %[[VAL_112]], ptr %[[VAL_15]], align 16 +// CHECK: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_15]], ptr %[[VAL_14]]) +// CHECK: %[[VAL_113:.*]] = load %[[VAL_1]], ptr %[[VAL_14]], align 1 +// CHECK: store %[[VAL_1]] %[[VAL_113]], ptr %[[VAL_28]], align 1 +// CHECK: %[[VAL_114:.*]] = load i128, ptr %[[VAL_28]], align 16 +// CHECK: %[[VAL_115:.*]] = bitcast i128 %[[VAL_114]] to <4 x i32> +// CHECK: %[[VAL_116:.*]] = extractelement <4 x i32> %[[VAL_115]], i64 0 +// CHECK: %[[VAL_117:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_116]], i32 1, i32 31) +// CHECK: %[[VAL_118:.*]] = insertelement <4 x i32> %[[VAL_115]], i32 %[[VAL_117]], i64 0 +// CHECK: %[[VAL_119:.*]] = extractelement <4 x i32> %[[VAL_118]], i64 1 +// CHECK: %[[VAL_120:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_119]], i32 1, i32 31) +// CHECK: %[[VAL_121:.*]] = insertelement <4 x i32> %[[VAL_118]], i32 %[[VAL_120]], i64 1 +// CHECK: %[[VAL_122:.*]] = extractelement <4 x i32> %[[VAL_121]], i64 2 +// CHECK: %[[VAL_123:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_122]], i32 1, i32 31) +// CHECK: %[[VAL_124:.*]] = insertelement <4 x i32> %[[VAL_121]], i32 %[[VAL_123]], i64 2 +// CHECK: %[[VAL_125:.*]] = extractelement <4 x i32> %[[VAL_124]], i64 3 +// CHECK: %[[VAL_126:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_125]], i32 1, i32 31) +// CHECK: %[[VAL_127:.*]] = insertelement <4 x i32> %[[VAL_124]], i32 %[[VAL_126]], i64 3 +// CHECK: %[[VAL_128:.*]] = bitcast <4 x i32> %[[VAL_127]] to i128 +// CHECK: store i128 %[[VAL_128]], ptr %[[VAL_13]], align 16 +// CHECK: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_13]], ptr %[[VAL_12]]) +// CHECK: %[[VAL_129:.*]] = load %[[VAL_1]], ptr %[[VAL_12]], align 1 +// CHECK: store %[[VAL_1]] %[[VAL_129]], ptr %[[VAL_28]], align 1 +// CHECK: %[[VAL_130:.*]] = udiv i32 %thread.id.2, 32 +// CHECK: br i1 true, label %thread_in_bounds-true, label %thread_in_bounds-after +// CHECK: thread_in_bounds-after: ; preds = %[[VAL_131:.*]], %[[VAL_132:.*]] +// CHECK: br label %[[VAL_33]] +// CHECK: is_full_tile-true: ; preds = %[[VAL_32]] // CHECK: store i32 0, ptr %[[VAL_27]], align 4 -// CHECK: br label %[[VAL_164:.*]] -// CHECK: loop2.loop_header: ; preds = %[[VAL_165:.*]], %[[VAL_72]] -// CHECK: %[[VAL_166:.*]] = load i32, ptr %[[VAL_27]], align 4 -// CHECK: %[[VAL_167:.*]] = icmp uge i32 %[[VAL_166]], 8 -// CHECK: br i1 %[[VAL_167]], label %[[VAL_75]], label %[[VAL_165]] -// CHECK: loop2.loop_body: ; preds = %[[VAL_164]] -// CHECK: %[[VAL_168:.*]] = add nuw nsw i32 %[[VAL_166]], 1 -// CHECK: store i32 %[[VAL_168]], ptr %[[VAL_27]], align 4 -// CHECK: %[[VAL_169:.*]] = icmp eq i32 %[[VAL_166]], 0 -// CHECK: %[[VAL_170:.*]] = mul i32 %[[VAL_166]], 1280 -// CHECK: %[[VAL_171:.*]] = add i32 %[[VAL_170]], 0 -// CHECK: %[[VAL_172:.*]] = add i32 %[[VAL_171]], %[[VAL_189]] -// CHECK: %[[VAL_173:.*]] = add i32 %[[VAL_61]], %[[VAL_65]] -// CHECK: %[[VAL_174:.*]] = add i32 %[[VAL_62]], %[[VAL_172]] -// CHECK: %[[VAL_175:.*]] = getelementptr inbounds [10000 x %[[VAL_1]]], ptr %[[VAL_176:.*]], i32 0, i32 %[[VAL_174]] -// CHECK: %[[VAL_177:.*]] = load %[[VAL_1]], ptr %[[VAL_175]], align 1, !invariant.load !3 -// CHECK: store %[[VAL_1]] %[[VAL_177]], ptr %[[VAL_30]], align 1 -// CHECK: call void @[[SUM]](ptr %[[VAL_29]], ptr %[[VAL_30]], ptr %[[VAL_26]]) -// CHECK: %[[VAL_179:.*]] = load %[[VAL_1]], ptr %[[VAL_26]], align 1 -// CHECK: store %[[VAL_1]] %[[VAL_179]], ptr %[[VAL_29]], align 1 -// CHECK: %[[VAL_180:.*]] = mul i32 %[[VAL_166]], 1280 -// CHECK: %[[VAL_181:.*]] = add i32 %[[VAL_180]], 1 -// CHECK: %[[VAL_182:.*]] = add i32 %[[VAL_181]], %[[VAL_189]] -// CHECK: %[[VAL_183:.*]] = add i32 %[[VAL_61]], %[[VAL_65]] -// CHECK: %[[VAL_184:.*]] = add i32 %[[VAL_62]], %[[VAL_182]] -// CHECK: %[[VAL_185:.*]] = getelementptr inbounds [10000 x %[[VAL_1]]], ptr %[[VAL_176]], i32 0, i32 %[[VAL_184]] -// CHECK: %[[VAL_186:.*]] = load %[[VAL_1]], ptr %[[VAL_185]], align 1, !invariant.load !3 -// CHECK: store %[[VAL_1]] %[[VAL_186]], ptr %[[VAL_30]], align 1 -// CHECK: call void @[[SUM]](ptr %[[VAL_29]], ptr %[[VAL_30]], ptr %[[VAL_25]]) -// CHECK: %[[VAL_188:.*]] = load %[[VAL_1]], ptr %[[VAL_25]], align 1 -// CHECK: store %[[VAL_1]] %[[VAL_188]], ptr %[[VAL_29]], align 1 -// CHECK: br label %[[VAL_164]], !llvm.loop !7 -// CHECK: loop2.loop_exit: ; preds = %[[VAL_164]] -// CHECK: br label %[[VAL_64]] -// CHECK: is_full_tile-false: ; preds = %[[VAL_68]] +// CHECK: br label %[[VAL_133:.*]] +// CHECK: loop2.loop_header: ; preds = %[[VAL_134:.*]], %[[VAL_46]] +// CHECK: %[[VAL_135:.*]] = load i32, ptr %[[VAL_27]], align 4 +// CHECK: %[[VAL_136:.*]] = icmp uge i32 %[[VAL_135]], 5120 +// CHECK: br i1 %[[VAL_136]], label %[[VAL_49]], label %[[VAL_137:.*]] +// CHECK: loop2.loop_body: ; preds = %[[VAL_133]] +// CHECK: %[[VAL_138:.*]] = add nuw nsw i32 %[[VAL_135]], 640 +// CHECK: store i32 %[[VAL_138]], ptr %[[VAL_27]], align 4 +// CHECK: %[[VAL_139:.*]] = icmp eq i32 %[[VAL_135]], 0 +// CHECK: %[[VAL_140:.*]] = add i32 %[[VAL_135]], %thread.id.2 +// CHECK: store i32 0, ptr %[[VAL_26]], align 4 +// CHECK: br label %[[VAL_141:.*]] +// CHECK: loop3.loop_header: ; preds = %[[VAL_142:.*]], %[[VAL_137]] +// CHECK: %[[VAL_143:.*]] = load i32, ptr %[[VAL_26]], align 4 +// CHECK: %[[VAL_144:.*]] = icmp uge i32 %[[VAL_143]], 2 +// CHECK: br i1 %[[VAL_144]], label %[[VAL_134]], label %[[VAL_142]] +// CHECK: loop3.loop_body: ; preds = %[[VAL_141]] +// CHECK: %[[VAL_145:.*]] = add nuw nsw i32 %[[VAL_143]], 1 +// CHECK: store i32 %[[VAL_145]], ptr %[[VAL_26]], align 4 +// CHECK: %[[VAL_146:.*]] = icmp eq i32 %[[VAL_143]], 0 +// CHECK: %[[VAL_147:.*]] = add i32 %tile_origin.0, 0 +// CHECK: %[[VAL_148:.*]] = add i32 %tile_origin.1, 0 +// CHECK: %[[VAL_149:.*]] = add i32 %tile_origin.2, %[[VAL_140]] +// CHECK: %[[VAL_150:.*]] = add i32 %tile_origin.3, %[[VAL_143]] +// CHECK: %[[VAL_151:.*]] = mul nuw nsw i32 %[[VAL_150]], 1 +// CHECK: %[[VAL_152:.*]] = add nuw nsw i32 0, %[[VAL_151]] +// CHECK: %[[VAL_153:.*]] = mul nuw nsw i32 %[[VAL_149]], 2 +// CHECK: %[[VAL_154:.*]] = add nuw nsw i32 %[[VAL_152]], %[[VAL_153]] +// CHECK: %[[VAL_155:.*]] = udiv i32 %[[VAL_154]], 10000 +// CHECK: %[[VAL_156:.*]] = mul nuw nsw i32 %[[VAL_148]], 1 +// CHECK: %[[VAL_157:.*]] = add nuw nsw i32 0, %[[VAL_156]] +// CHECK: %[[VAL_158:.*]] = mul nuw nsw i32 %[[VAL_147]], 1 +// CHECK: %[[VAL_159:.*]] = add nuw nsw i32 0, %[[VAL_158]] +// CHECK: %[[VAL_160:.*]] = getelementptr inbounds [10000 x %[[VAL_1]]], ptr %[[VAL_161:.*]], i32 0, i32 %[[VAL_154]] +// CHECK: %[[VAL_162:.*]] = load %[[VAL_1]], ptr %[[VAL_160]], align 1, !invariant.load !3 +// CHECK: store %[[VAL_1]] %[[VAL_162]], ptr %[[VAL_29]], align 1 +// CHECK: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_29]], ptr %[[VAL_25]]) +// CHECK: %[[VAL_163:.*]] = load %[[VAL_1]], ptr %[[VAL_25]], align 1 +// CHECK: store %[[VAL_1]] %[[VAL_163]], ptr %[[VAL_28]], align 1 +// CHECK: br label %[[VAL_141]], !llvm.loop !5 +// CHECK: loop3.loop_exit: ; preds = %[[VAL_141]] +// CHECK: br label %[[VAL_133]], !llvm.loop !7 +// CHECK: loop2.loop_exit: ; preds = %[[VAL_133]] +// CHECK: br label %[[VAL_132]] +// CHECK: is_full_tile-false: ; preds = %[[VAL_32]] // CHECK: store i32 0, ptr %[[VAL_24]], align 4 -// CHECK: br label %[[VAL_190:.*]] -// CHECK: loop2.loop_header7: ; preds = %[[VAL_191:.*]], %[[VAL_73]] -// CHECK: %[[VAL_192:.*]] = load i32, ptr %[[VAL_24]], align 4 -// CHECK: %[[VAL_193:.*]] = icmp uge i32 %[[VAL_192]], 8 -// CHECK: br i1 %[[VAL_193]], label %[[VAL_74]], label %[[VAL_194:.*]] -// CHECK: loop2.loop_body8: ; preds = %[[VAL_190]] -// CHECK: %[[VAL_195:.*]] = add nuw nsw i32 %[[VAL_192]], 1 -// CHECK: store i32 %[[VAL_195]], ptr %[[VAL_24]], align 4 -// CHECK: %[[VAL_196:.*]] = icmp eq i32 %[[VAL_192]], 0 -// CHECK: %[[VAL_197:.*]] = mul i32 %[[VAL_192]], 1280 -// CHECK: %[[VAL_198:.*]] = add i32 %[[VAL_197]], 0 -// CHECK: %[[VAL_199:.*]] = add i32 %[[VAL_198]], %[[VAL_189]] -// CHECK: %[[VAL_200:.*]] = icmp ult i32 %[[VAL_199]], %[[VAL_59]] -// CHECK: br i1 %[[VAL_200]], label %[[VAL_201:.*]], label %[[VAL_202:.*]] -// CHECK: x_in_tile-after: ; preds = %[[VAL_201]], %[[VAL_194]] -// CHECK: %[[VAL_203:.*]] = mul i32 %[[VAL_192]], 1280 -// CHECK: %[[VAL_204:.*]] = add i32 %[[VAL_203]], 1 -// CHECK: %[[VAL_205:.*]] = add i32 %[[VAL_204]], %[[VAL_189]] -// CHECK: %[[VAL_206:.*]] = icmp ult i32 %[[VAL_205]], %[[VAL_59]] -// CHECK: br i1 %[[VAL_206]], label %[[VAL_207:.*]], label %[[VAL_191]] -// CHECK: x_in_tile-after16: ; preds = %[[VAL_207]], %[[VAL_202]] -// CHECK: br label %[[VAL_190]], !llvm.loop !9 -// CHECK: loop2.loop_exit6: ; preds = %[[VAL_190]] -// CHECK: br label %[[VAL_64]] -// CHECK: x_in_tile-true: ; preds = %[[VAL_194]] -// CHECK: %[[VAL_208:.*]] = add i32 %[[VAL_61]], %[[VAL_65]] -// CHECK: %[[VAL_209:.*]] = add i32 %[[VAL_62]], %[[VAL_199]] -// CHECK: %[[VAL_210:.*]] = getelementptr inbounds [10000 x %[[VAL_1]]], ptr %[[VAL_176]], i32 0, i32 %[[VAL_209]] -// CHECK: %[[VAL_211:.*]] = load %[[VAL_1]], ptr %[[VAL_210]], align 1, !invariant.load !3 -// CHECK: store %[[VAL_1]] %[[VAL_211]], ptr %[[VAL_30]], align 1 -// CHECK: call void @[[SUM]](ptr %[[VAL_29]], ptr %[[VAL_30]], ptr %[[VAL_23]]) -// CHECK: %[[VAL_213:.*]] = load %[[VAL_1]], ptr %[[VAL_23]], align 1 -// CHECK: store %[[VAL_1]] %[[VAL_213]], ptr %[[VAL_29]], align 1 +// CHECK: br label %[[VAL_164:.*]] +// CHECK: loop2.loop_header4: ; preds = %[[VAL_165:.*]], %[[VAL_47]] +// CHECK: %[[VAL_166:.*]] = load i32, ptr %[[VAL_24]], align 4 +// CHECK: %[[VAL_167:.*]] = icmp uge i32 %[[VAL_166]], 5120 +// CHECK: br i1 %[[VAL_167]], label %[[VAL_48]], label %[[VAL_168:.*]] +// CHECK: loop2.loop_body5: ; preds = %[[VAL_164]] +// CHECK: %[[VAL_169:.*]] = add nuw nsw i32 %[[VAL_166]], 640 +// CHECK: store i32 %[[VAL_169]], ptr %[[VAL_24]], align 4 +// CHECK: %[[VAL_170:.*]] = icmp eq i32 %[[VAL_166]], 0 +// CHECK: %[[VAL_171:.*]] = add i32 %[[VAL_166]], %thread.id.2 +// CHECK: %[[VAL_172:.*]] = icmp ult i32 %[[VAL_171]], %tile_bound.2 +// CHECK: br i1 %[[VAL_172]], label %[[VAL_173:.*]], label %[[VAL_165]] +// CHECK: x_in_tile-after: ; preds = %[[VAL_174:.*]], %[[VAL_168]] +// CHECK: br label %[[VAL_164]], !llvm.loop !9 +// CHECK: loop2.loop_exit3: ; preds = %[[VAL_164]] +// CHECK: br label %[[VAL_132]] +// CHECK: x_in_tile-true: ; preds = %[[VAL_168]] +// CHECK: store i32 0, ptr %[[VAL_23]], align 4 +// CHECK: br label %[[VAL_175:.*]] +// CHECK: loop3.loop_header10: ; preds = %[[VAL_176:.*]], %[[VAL_173]] +// CHECK: %[[VAL_177:.*]] = load i32, ptr %[[VAL_23]], align 4 +// CHECK: %[[VAL_178:.*]] = icmp uge i32 %[[VAL_177]], 2 +// CHECK: br i1 %[[VAL_178]], label %[[VAL_174]], label %[[VAL_176]] +// CHECK: loop3.loop_body11: ; preds = %[[VAL_175]] +// CHECK: %[[VAL_179:.*]] = add nuw nsw i32 %[[VAL_177]], 1 +// CHECK: store i32 %[[VAL_179]], ptr %[[VAL_23]], align 4 +// CHECK: %[[VAL_180:.*]] = icmp eq i32 %[[VAL_177]], 0 +// CHECK: %[[VAL_181:.*]] = add i32 %tile_origin.0, 0 +// CHECK: %[[VAL_182:.*]] = add i32 %tile_origin.1, 0 +// CHECK: %[[VAL_183:.*]] = add i32 %tile_origin.2, %[[VAL_171]] +// CHECK: %[[VAL_184:.*]] = add i32 %tile_origin.3, %[[VAL_177]] +// CHECK: %[[VAL_185:.*]] = mul nuw nsw i32 %[[VAL_184]], 1 +// CHECK: %[[VAL_186:.*]] = add nuw nsw i32 0, %[[VAL_185]] +// CHECK: %[[VAL_187:.*]] = mul nuw nsw i32 %[[VAL_183]], 2 +// CHECK: %[[VAL_188:.*]] = add nuw nsw i32 %[[VAL_186]], %[[VAL_187]] +// CHECK: %[[VAL_189:.*]] = udiv i32 %[[VAL_188]], 10000 +// CHECK: %[[VAL_190:.*]] = mul nuw nsw i32 %[[VAL_182]], 1 +// CHECK: %[[VAL_191:.*]] = add nuw nsw i32 0, %[[VAL_190]] +// CHECK: %[[VAL_192:.*]] = mul nuw nsw i32 %[[VAL_181]], 1 +// CHECK: %[[VAL_193:.*]] = add nuw nsw i32 0, %[[VAL_192]] +// CHECK: %[[VAL_194:.*]] = getelementptr inbounds [10000 x %[[VAL_1]]], ptr %[[VAL_161]], i32 0, i32 %[[VAL_188]] +// CHECK: %[[VAL_195:.*]] = load %[[VAL_1]], ptr %[[VAL_194]], align 1, !invariant.load !3 +// CHECK: store %[[VAL_1]] %[[VAL_195]], ptr %[[VAL_29]], align 1 +// CHECK: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_29]], ptr %[[VAL_22]]) +// CHECK: %[[VAL_196:.*]] = load %[[VAL_1]], ptr %[[VAL_22]], align 1 +// CHECK: store %[[VAL_1]] %[[VAL_196]], ptr %[[VAL_28]], align 1 +// CHECK: br label %[[VAL_175]], !llvm.loop !10 +// CHECK: loop3.loop_exit9: ; preds = %[[VAL_175]] +// CHECK: br label %[[VAL_165]] +// CHECK: thread_in_bounds-true: ; preds = %[[VAL_132]] +// CHECK: %[[VAL_197:.*]] = icmp eq i32 %lane_id, 0 +// CHECK: br i1 %[[VAL_197]], label %[[VAL_198:.*]], label %[[VAL_199:.*]] +// CHECK: intra_warp_reduce_write-after: ; preds = %[[VAL_198]], %thread_in_bounds-true +// CHECK: call void @llvm.nvvm.barrier0() +// CHECK: %[[VAL_200:.*]] = icmp eq i32 %[[VAL_130]], 0 +// CHECK: br i1 %[[VAL_200]], label %[[VAL_201:.*]], label %[[VAL_131]] +// CHECK: inter_warp_reduce-after: ; preds = %[[VAL_202:.*]], %[[VAL_199]] +// CHECK: br label %thread_in_bounds-after +// CHECK: intra_warp_reduce_write-true: ; preds = %thread_in_bounds-true +// CHECK: %[[VAL_203:.*]] = load %[[VAL_1]], ptr %[[VAL_28]], align 1 +// CHECK: %[[VAL_204:.*]] = getelementptr inbounds [1 x [20 x %[[VAL_1]]]], ptr addrspace(3) @shared_cache, i32 0, i32 0, i32 %[[VAL_130]] +// CHECK: %[[VAL_205:.*]] = addrspacecast ptr addrspace(3) %[[VAL_204]] to ptr +// CHECK: store %[[VAL_1]] %[[VAL_203]], ptr %[[VAL_205]], align 1 +// CHECK: br label %[[VAL_199]] +// CHECK: inter_warp_reduce-true: ; preds = %[[VAL_199]] +// CHECK: %[[VAL_206:.*]] = getelementptr inbounds [1 x [20 x %[[VAL_1]]]], ptr addrspace(3) @shared_cache, i32 0, i32 0, i32 %lane_id +// CHECK: %[[VAL_207:.*]] = addrspacecast ptr addrspace(3) %[[VAL_206]] to ptr +// CHECK: store %[[VAL_1]] %[[VAL_35]], ptr %[[VAL_11]], align 1 +// CHECK: %[[VAL_208:.*]] = icmp ult i32 %thread.id.2, 20 +// CHECK: %[[VAL_209:.*]] = select i1 %[[VAL_208]], ptr %[[VAL_207]], ptr %[[VAL_11]] +// CHECK: %[[VAL_210:.*]] = load i128, ptr %[[VAL_209]], align 16 +// CHECK: %[[VAL_211:.*]] = bitcast i128 %[[VAL_210]] to <4 x i32> +// CHECK: %[[VAL_212:.*]] = extractelement <4 x i32> %[[VAL_211]], i64 0 +// CHECK: %[[VAL_213:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_212]], i32 16, i32 31) +// CHECK: %[[VAL_214:.*]] = insertelement <4 x i32> %[[VAL_211]], i32 %[[VAL_213]], i64 0 +// CHECK: %[[VAL_215:.*]] = extractelement <4 x i32> %[[VAL_214]], i64 1 +// CHECK: %[[VAL_216:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_215]], i32 16, i32 31) +// CHECK: %[[VAL_217:.*]] = insertelement <4 x i32> %[[VAL_214]], i32 %[[VAL_216]], i64 1 +// CHECK: %[[VAL_218:.*]] = extractelement <4 x i32> %[[VAL_217]], i64 2 +// CHECK: %[[VAL_219:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_218]], i32 16, i32 31) +// CHECK: %[[VAL_220:.*]] = insertelement <4 x i32> %[[VAL_217]], i32 %[[VAL_219]], i64 2 +// CHECK: %[[VAL_221:.*]] = extractelement <4 x i32> %[[VAL_220]], i64 3 +// CHECK: %[[VAL_222:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_221]], i32 16, i32 31) +// CHECK: %[[VAL_223:.*]] = insertelement <4 x i32> %[[VAL_220]], i32 %[[VAL_222]], i64 3 +// CHECK: %[[VAL_224:.*]] = bitcast <4 x i32> %[[VAL_223]] to i128 +// CHECK: store i128 %[[VAL_224]], ptr %[[VAL_10]], align 16 +// CHECK: call void @[[SUM]](ptr %[[VAL_209]], ptr %[[VAL_10]], ptr %[[VAL_9]]) +// CHECK: %[[VAL_225:.*]] = load %[[VAL_1]], ptr %[[VAL_9]], align 1 +// CHECK: store %[[VAL_1]] %[[VAL_225]], ptr %[[VAL_209]], align 1 +// CHECK: %[[VAL_226:.*]] = load i128, ptr %[[VAL_209]], align 16 +// CHECK: %[[VAL_227:.*]] = bitcast i128 %[[VAL_226]] to <4 x i32> +// CHECK: %[[VAL_228:.*]] = extractelement <4 x i32> %[[VAL_227]], i64 0 +// CHECK: %[[VAL_229:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_228]], i32 8, i32 31) +// CHECK: %[[VAL_230:.*]] = insertelement <4 x i32> %[[VAL_227]], i32 %[[VAL_229]], i64 0 +// CHECK: %[[VAL_231:.*]] = extractelement <4 x i32> %[[VAL_230]], i64 1 +// CHECK: %[[VAL_232:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_231]], i32 8, i32 31) +// CHECK: %[[VAL_233:.*]] = insertelement <4 x i32> %[[VAL_230]], i32 %[[VAL_232]], i64 1 +// CHECK: %[[VAL_234:.*]] = extractelement <4 x i32> %[[VAL_233]], i64 2 +// CHECK: %[[VAL_235:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_234]], i32 8, i32 31) +// CHECK: %[[VAL_236:.*]] = insertelement <4 x i32> %[[VAL_233]], i32 %[[VAL_235]], i64 2 +// CHECK: %[[VAL_237:.*]] = extractelement <4 x i32> %[[VAL_236]], i64 3 +// CHECK: %[[VAL_238:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_237]], i32 8, i32 31) +// CHECK: %[[VAL_239:.*]] = insertelement <4 x i32> %[[VAL_236]], i32 %[[VAL_238]], i64 3 +// CHECK: %[[VAL_240:.*]] = bitcast <4 x i32> %[[VAL_239]] to i128 +// CHECK: store i128 %[[VAL_240]], ptr %[[VAL_8]], align 16 +// CHECK: call void @[[SUM]](ptr %[[VAL_209]], ptr %[[VAL_8]], ptr %[[VAL_7]]) +// CHECK: %[[VAL_241:.*]] = load %[[VAL_1]], ptr %[[VAL_7]], align 1 +// CHECK: store %[[VAL_1]] %[[VAL_241]], ptr %[[VAL_209]], align 1 +// CHECK: %[[VAL_242:.*]] = load i128, ptr %[[VAL_209]], align 16 +// CHECK: %[[VAL_243:.*]] = bitcast i128 %[[VAL_242]] to <4 x i32> +// CHECK: %[[VAL_244:.*]] = extractelement <4 x i32> %[[VAL_243]], i64 0 +// CHECK: %[[VAL_245:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_244]], i32 4, i32 31) +// CHECK: %[[VAL_246:.*]] = insertelement <4 x i32> %[[VAL_243]], i32 %[[VAL_245]], i64 0 +// CHECK: %[[VAL_247:.*]] = extractelement <4 x i32> %[[VAL_246]], i64 1 +// CHECK: %[[VAL_248:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_247]], i32 4, i32 31) +// CHECK: %[[VAL_249:.*]] = insertelement <4 x i32> %[[VAL_246]], i32 %[[VAL_248]], i64 1 +// CHECK: %[[VAL_250:.*]] = extractelement <4 x i32> %[[VAL_249]], i64 2 +// CHECK: %[[VAL_251:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_250]], i32 4, i32 31) +// CHECK: %[[VAL_252:.*]] = insertelement <4 x i32> %[[VAL_249]], i32 %[[VAL_251]], i64 2 +// CHECK: %[[VAL_253:.*]] = extractelement <4 x i32> %[[VAL_252]], i64 3 +// CHECK: %[[VAL_254:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_253]], i32 4, i32 31) +// CHECK: %[[VAL_255:.*]] = insertelement <4 x i32> %[[VAL_252]], i32 %[[VAL_254]], i64 3 +// CHECK: %[[VAL_256:.*]] = bitcast <4 x i32> %[[VAL_255]] to i128 +// CHECK: store i128 %[[VAL_256]], ptr %[[VAL_6]], align 16 +// CHECK: call void @[[SUM]](ptr %[[VAL_209]], ptr %[[VAL_6]], ptr %[[VAL_5]]) +// CHECK: %[[VAL_257:.*]] = load %[[VAL_1]], ptr %[[VAL_5]], align 1 +// CHECK: store %[[VAL_1]] %[[VAL_257]], ptr %[[VAL_209]], align 1 +// CHECK: %[[VAL_258:.*]] = load i128, ptr %[[VAL_209]], align 16 +// CHECK: %[[VAL_259:.*]] = bitcast i128 %[[VAL_258]] to <4 x i32> +// CHECK: %[[VAL_260:.*]] = extractelement <4 x i32> %[[VAL_259]], i64 0 +// CHECK: %[[VAL_261:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_260]], i32 2, i32 31) +// CHECK: %[[VAL_262:.*]] = insertelement <4 x i32> %[[VAL_259]], i32 %[[VAL_261]], i64 0 +// CHECK: %[[VAL_263:.*]] = extractelement <4 x i32> %[[VAL_262]], i64 1 +// CHECK: %[[VAL_264:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_263]], i32 2, i32 31) +// CHECK: %[[VAL_265:.*]] = insertelement <4 x i32> %[[VAL_262]], i32 %[[VAL_264]], i64 1 +// CHECK: %[[VAL_266:.*]] = extractelement <4 x i32> %[[VAL_265]], i64 2 +// CHECK: %[[VAL_267:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_266]], i32 2, i32 31) +// CHECK: %[[VAL_268:.*]] = insertelement <4 x i32> %[[VAL_265]], i32 %[[VAL_267]], i64 2 +// CHECK: %[[VAL_269:.*]] = extractelement <4 x i32> %[[VAL_268]], i64 3 +// CHECK: %[[VAL_270:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_269]], i32 2, i32 31) +// CHECK: %[[VAL_271:.*]] = insertelement <4 x i32> %[[VAL_268]], i32 %[[VAL_270]], i64 3 +// CHECK: %[[VAL_272:.*]] = bitcast <4 x i32> %[[VAL_271]] to i128 +// CHECK: store i128 %[[VAL_272]], ptr %[[VAL_4]], align 16 +// CHECK: call void @[[SUM]](ptr %[[VAL_209]], ptr %[[VAL_4]], ptr %[[VAL_3]]) +// CHECK: %[[VAL_273:.*]] = load %[[VAL_1]], ptr %[[VAL_3]], align 1 +// CHECK: store %[[VAL_1]] %[[VAL_273]], ptr %[[VAL_209]], align 1 +// CHECK: %[[VAL_274:.*]] = load i128, ptr %[[VAL_209]], align 16 +// CHECK: %[[VAL_275:.*]] = bitcast i128 %[[VAL_274]] to <4 x i32> +// CHECK: %[[VAL_276:.*]] = extractelement <4 x i32> %[[VAL_275]], i64 0 +// CHECK: %[[VAL_277:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_276]], i32 1, i32 31) +// CHECK: %[[VAL_278:.*]] = insertelement <4 x i32> %[[VAL_275]], i32 %[[VAL_277]], i64 0 +// CHECK: %[[VAL_279:.*]] = extractelement <4 x i32> %[[VAL_278]], i64 1 +// CHECK: %[[VAL_280:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_279]], i32 1, i32 31) +// CHECK: %[[VAL_281:.*]] = insertelement <4 x i32> %[[VAL_278]], i32 %[[VAL_280]], i64 1 +// CHECK: %[[VAL_282:.*]] = extractelement <4 x i32> %[[VAL_281]], i64 2 +// CHECK: %[[VAL_283:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_282]], i32 1, i32 31) +// CHECK: %[[VAL_284:.*]] = insertelement <4 x i32> %[[VAL_281]], i32 %[[VAL_283]], i64 2 +// CHECK: %[[VAL_285:.*]] = extractelement <4 x i32> %[[VAL_284]], i64 3 +// CHECK: %[[VAL_286:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_285]], i32 1, i32 31) +// CHECK: %[[VAL_287:.*]] = insertelement <4 x i32> %[[VAL_284]], i32 %[[VAL_286]], i64 3 +// CHECK: %[[VAL_288:.*]] = bitcast <4 x i32> %[[VAL_287]] to i128 +// CHECK: store i128 %[[VAL_288]], ptr %[[VAL_2]], align 16 +// CHECK: call void @[[SUM]](ptr %[[VAL_209]], ptr %[[VAL_2]], ptr %[[VAL_0]]) +// CHECK: %[[VAL_289:.*]] = load %[[VAL_1]], ptr %[[VAL_0]], align 1 +// CHECK: store %[[VAL_1]] %[[VAL_289]], ptr %[[VAL_209]], align 1 +// CHECK: %[[VAL_290:.*]] = icmp eq i32 %thread.id.2, 0 +// CHECK: br i1 %[[VAL_290]], label %[[VAL_291:.*]], label %[[VAL_202]] +// CHECK: reduction_write_output-after: ; preds = %[[VAL_291]], %[[VAL_201]] +// CHECK: br label %[[VAL_131]] +// CHECK: reduction_write_output-true: ; preds = %[[VAL_201]] +// CHECK: %[[VAL_293:.*]] = add i32 %tile_origin.1, 0 +// CHECK: %[[VAL_296:.*]] = load %[[VAL_1]], ptr %[[VAL_209]], align 1 +// CHECK: store %[[VAL_1]] %[[VAL_296]], ptr %[[VAL_297:.*]], align 1 // CHECK: br label %[[VAL_202]] -// CHECK: x_in_tile-true15: ; preds = %[[VAL_202]] -// CHECK: %[[VAL_214:.*]] = add i32 %[[VAL_61]], %[[VAL_65]] -// CHECK: %[[VAL_215:.*]] = add i32 %[[VAL_62]], %[[VAL_205]] -// CHECK: %[[VAL_216:.*]] = getelementptr inbounds [10000 x %[[VAL_1]]], ptr %[[VAL_176]], i32 0, i32 %[[VAL_215]] -// CHECK: %[[VAL_217:.*]] = load %[[VAL_1]], ptr %[[VAL_216]], align 1, !invariant.load !3 -// CHECK: store %[[VAL_1]] %[[VAL_217]], ptr %[[VAL_30]], align 1 -// CHECK: call void @[[SUM]](ptr %[[VAL_29]], ptr %[[VAL_30]], ptr %[[VAL_22]]) -// CHECK: %[[VAL_219:.*]] = load %[[VAL_1]], ptr %[[VAL_22]], align 1 -// CHECK: store %[[VAL_1]] %[[VAL_219]], ptr %[[VAL_29]], align 1 -// CHECK: br label %[[VAL_191]] -// CHECK: intra_warp_reduce_write-true: ; preds = %[[VAL_67]] -// CHECK: %[[VAL_222:.*]] = load %[[VAL_1]], ptr %[[VAL_29]], align 1 -// CHECK: %[[VAL_220:.*]] = getelementptr inbounds [1 x [20 x %[[VAL_1]]]], ptr addrspace(3) @shared_cache, i32 0, i32 %[[VAL_42]], i32 %[[VAL_156]] -// CHECK: %[[VAL_221:.*]] = addrspacecast ptr addrspace(3) %[[VAL_220]] to ptr -// CHECK: store %[[VAL_1]] %[[VAL_222]], ptr %[[VAL_221]], align 1 -// CHECK: br label %[[VAL_159]] -// CHECK: inter_warp_reduce-true: ; preds = %[[VAL_159]] -// CHECK: %[[VAL_223:.*]] = getelementptr inbounds [1 x [20 x %[[VAL_1]]]], ptr addrspace(3) @shared_cache, i32 0, i32 %[[VAL_42]], i32 %[[VAL_50]] -// CHECK: %[[VAL_224:.*]] = addrspacecast ptr addrspace(3) %[[VAL_223]] to ptr -// CHECK: store %[[VAL_1]] %[[VAL_37]], ptr %[[VAL_11]], align 1 -// CHECK: %[[VAL_225:.*]] = icmp ult i32 %[[VAL_48]], 20 -// CHECK: %[[VAL_226:.*]] = select i1 %[[VAL_225]], ptr %[[VAL_224]], ptr %[[VAL_11]] -// CHECK: %[[VAL_227:.*]] = load i128, ptr %[[VAL_226]], align 16 -// CHECK: %[[VAL_228:.*]] = bitcast i128 %[[VAL_227]] to <4 x i32> -// CHECK: %[[VAL_229:.*]] = extractelement <4 x i32> %[[VAL_228]], i64 0 -// CHECK: %[[VAL_230:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_229]], i32 16, i32 31) -// CHECK: %[[VAL_231:.*]] = insertelement <4 x i32> %[[VAL_228]], i32 %[[VAL_230]], i64 0 -// CHECK: %[[VAL_232:.*]] = extractelement <4 x i32> %[[VAL_231]], i64 1 -// CHECK: %[[VAL_233:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_232]], i32 16, i32 31) -// CHECK: %[[VAL_234:.*]] = insertelement <4 x i32> %[[VAL_231]], i32 %[[VAL_233]], i64 1 -// CHECK: %[[VAL_235:.*]] = extractelement <4 x i32> %[[VAL_234]], i64 2 -// CHECK: %[[VAL_236:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_235]], i32 16, i32 31) -// CHECK: %[[VAL_237:.*]] = insertelement <4 x i32> %[[VAL_234]], i32 %[[VAL_236]], i64 2 -// CHECK: %[[VAL_238:.*]] = extractelement <4 x i32> %[[VAL_237]], i64 3 -// CHECK: %[[VAL_239:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_238]], i32 16, i32 31) -// CHECK: %[[VAL_240:.*]] = insertelement <4 x i32> %[[VAL_237]], i32 %[[VAL_239]], i64 3 -// CHECK: %[[VAL_241:.*]] = bitcast <4 x i32> %[[VAL_240]] to i128 -// CHECK: store i128 %[[VAL_241]], ptr %[[VAL_10]], align 16 -// CHECK: call void @[[SUM]](ptr %[[VAL_226]], ptr %[[VAL_10]], ptr %[[VAL_9]]) -// CHECK: %[[VAL_242:.*]] = load %[[VAL_1]], ptr %[[VAL_9]], align 1 -// CHECK: store %[[VAL_1]] %[[VAL_242]], ptr %[[VAL_226]], align 1 -// CHECK: %[[VAL_243:.*]] = load i128, ptr %[[VAL_226]], align 16 -// CHECK: %[[VAL_244:.*]] = bitcast i128 %[[VAL_243]] to <4 x i32> -// CHECK: %[[VAL_245:.*]] = extractelement <4 x i32> %[[VAL_244]], i64 0 -// CHECK: %[[VAL_246:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_245]], i32 8, i32 31) -// CHECK: %[[VAL_247:.*]] = insertelement <4 x i32> %[[VAL_244]], i32 %[[VAL_246]], i64 0 -// CHECK: %[[VAL_248:.*]] = extractelement <4 x i32> %[[VAL_247]], i64 1 -// CHECK: %[[VAL_249:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_248]], i32 8, i32 31) -// CHECK: %[[VAL_250:.*]] = insertelement <4 x i32> %[[VAL_247]], i32 %[[VAL_249]], i64 1 -// CHECK: %[[VAL_251:.*]] = extractelement <4 x i32> %[[VAL_250]], i64 2 -// CHECK: %[[VAL_252:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_251]], i32 8, i32 31) -// CHECK: %[[VAL_253:.*]] = insertelement <4 x i32> %[[VAL_250]], i32 %[[VAL_252]], i64 2 -// CHECK: %[[VAL_254:.*]] = extractelement <4 x i32> %[[VAL_253]], i64 3 -// CHECK: %[[VAL_255:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_254]], i32 8, i32 31) -// CHECK: %[[VAL_256:.*]] = insertelement <4 x i32> %[[VAL_253]], i32 %[[VAL_255]], i64 3 -// CHECK: %[[VAL_257:.*]] = bitcast <4 x i32> %[[VAL_256]] to i128 -// CHECK: store i128 %[[VAL_257]], ptr %[[VAL_8]], align 16 -// CHECK: call void @[[SUM]](ptr %[[VAL_226]], ptr %[[VAL_8]], ptr %[[VAL_7]]) -// CHECK: %[[VAL_258:.*]] = load %[[VAL_1]], ptr %[[VAL_7]], align 1 -// CHECK: store %[[VAL_1]] %[[VAL_258]], ptr %[[VAL_226]], align 1 -// CHECK: %[[VAL_259:.*]] = load i128, ptr %[[VAL_226]], align 16 -// CHECK: %[[VAL_260:.*]] = bitcast i128 %[[VAL_259]] to <4 x i32> -// CHECK: %[[VAL_261:.*]] = extractelement <4 x i32> %[[VAL_260]], i64 0 -// CHECK: %[[VAL_262:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_261]], i32 4, i32 31) -// CHECK: %[[VAL_263:.*]] = insertelement <4 x i32> %[[VAL_260]], i32 %[[VAL_262]], i64 0 -// CHECK: %[[VAL_264:.*]] = extractelement <4 x i32> %[[VAL_263]], i64 1 -// CHECK: %[[VAL_265:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_264]], i32 4, i32 31) -// CHECK: %[[VAL_266:.*]] = insertelement <4 x i32> %[[VAL_263]], i32 %[[VAL_265]], i64 1 -// CHECK: %[[VAL_267:.*]] = extractelement <4 x i32> %[[VAL_266]], i64 2 -// CHECK: %[[VAL_268:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_267]], i32 4, i32 31) -// CHECK: %[[VAL_269:.*]] = insertelement <4 x i32> %[[VAL_266]], i32 %[[VAL_268]], i64 2 -// CHECK: %[[VAL_270:.*]] = extractelement <4 x i32> %[[VAL_269]], i64 3 -// CHECK: %[[VAL_271:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_270]], i32 4, i32 31) -// CHECK: %[[VAL_272:.*]] = insertelement <4 x i32> %[[VAL_269]], i32 %[[VAL_271]], i64 3 -// CHECK: %[[VAL_273:.*]] = bitcast <4 x i32> %[[VAL_272]] to i128 -// CHECK: store i128 %[[VAL_273]], ptr %[[VAL_6]], align 16 -// CHECK: call void @[[SUM]](ptr %[[VAL_226]], ptr %[[VAL_6]], ptr %[[VAL_5]]) -// CHECK: %[[VAL_274:.*]] = load %[[VAL_1]], ptr %[[VAL_5]], align 1 -// CHECK: store %[[VAL_1]] %[[VAL_274]], ptr %[[VAL_226]], align 1 -// CHECK: %[[VAL_275:.*]] = load i128, ptr %[[VAL_226]], align 16 -// CHECK: %[[VAL_276:.*]] = bitcast i128 %[[VAL_275]] to <4 x i32> -// CHECK: %[[VAL_277:.*]] = extractelement <4 x i32> %[[VAL_276]], i64 0 -// CHECK: %[[VAL_278:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_277]], i32 2, i32 31) -// CHECK: %[[VAL_279:.*]] = insertelement <4 x i32> %[[VAL_276]], i32 %[[VAL_278]], i64 0 -// CHECK: %[[VAL_280:.*]] = extractelement <4 x i32> %[[VAL_279]], i64 1 -// CHECK: %[[VAL_281:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_280]], i32 2, i32 31) -// CHECK: %[[VAL_282:.*]] = insertelement <4 x i32> %[[VAL_279]], i32 %[[VAL_281]], i64 1 -// CHECK: %[[VAL_283:.*]] = extractelement <4 x i32> %[[VAL_282]], i64 2 -// CHECK: %[[VAL_284:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_283]], i32 2, i32 31) -// CHECK: %[[VAL_285:.*]] = insertelement <4 x i32> %[[VAL_282]], i32 %[[VAL_284]], i64 2 -// CHECK: %[[VAL_286:.*]] = extractelement <4 x i32> %[[VAL_285]], i64 3 -// CHECK: %[[VAL_287:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_286]], i32 2, i32 31) -// CHECK: %[[VAL_288:.*]] = insertelement <4 x i32> %[[VAL_285]], i32 %[[VAL_287]], i64 3 -// CHECK: %[[VAL_289:.*]] = bitcast <4 x i32> %[[VAL_288]] to i128 -// CHECK: store i128 %[[VAL_289]], ptr %[[VAL_4]], align 16 -// CHECK: call void @[[SUM]](ptr %[[VAL_226]], ptr %[[VAL_4]], ptr %[[VAL_3]]) -// CHECK: %[[VAL_290:.*]] = load %[[VAL_1]], ptr %[[VAL_3]], align 1 -// CHECK: store %[[VAL_1]] %[[VAL_290]], ptr %[[VAL_226]], align 1 -// CHECK: %[[VAL_291:.*]] = load i128, ptr %[[VAL_226]], align 16 -// CHECK: %[[VAL_292:.*]] = bitcast i128 %[[VAL_291]] to <4 x i32> -// CHECK: %[[VAL_293:.*]] = extractelement <4 x i32> %[[VAL_292]], i64 0 -// CHECK: %[[VAL_294:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_293]], i32 1, i32 31) -// CHECK: %[[VAL_295:.*]] = insertelement <4 x i32> %[[VAL_292]], i32 %[[VAL_294]], i64 0 -// CHECK: %[[VAL_296:.*]] = extractelement <4 x i32> %[[VAL_295]], i64 1 -// CHECK: %[[VAL_297:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_296]], i32 1, i32 31) -// CHECK: %[[VAL_298:.*]] = insertelement <4 x i32> %[[VAL_295]], i32 %[[VAL_297]], i64 1 -// CHECK: %[[VAL_299:.*]] = extractelement <4 x i32> %[[VAL_298]], i64 2 -// CHECK: %[[VAL_300:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_299]], i32 1, i32 31) -// CHECK: %[[VAL_301:.*]] = insertelement <4 x i32> %[[VAL_298]], i32 %[[VAL_300]], i64 2 -// CHECK: %[[VAL_302:.*]] = extractelement <4 x i32> %[[VAL_301]], i64 3 -// CHECK: %[[VAL_303:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_302]], i32 1, i32 31) -// CHECK: %[[VAL_304:.*]] = insertelement <4 x i32> %[[VAL_301]], i32 %[[VAL_303]], i64 3 -// CHECK: %[[VAL_305:.*]] = bitcast <4 x i32> %[[VAL_304]] to i128 -// CHECK: store i128 %[[VAL_305]], ptr %[[VAL_2]], align 16 -// CHECK: call void @[[SUM]](ptr %[[VAL_226]], ptr %[[VAL_2]], ptr %[[VAL_0]]) -// CHECK: %[[VAL_306:.*]] = load %[[VAL_1]], ptr %[[VAL_0]], align 1 -// CHECK: store %[[VAL_1]] %[[VAL_306]], ptr %[[VAL_226]], align 1 -// CHECK: %[[VAL_307:.*]] = icmp eq i32 %[[VAL_48]], 0 -// CHECK: br i1 %[[VAL_307]], label %[[VAL_308:.*]], label %[[VAL_162]] -// CHECK: reduction_write_output-after: ; preds = %[[VAL_308]], %[[VAL_161]] -// CHECK: br label %[[VAL_35]] -// CHECK: reduction_write_output-true: ; preds = %[[VAL_161]] -// CHECK: %[[VAL_310:.*]] = add i32 %[[VAL_61]], %[[VAL_49]] -// CHECK: %[[VAL_311:.*]] = add i32 %[[VAL_62]], %[[VAL_189]] -// CHECK: %[[VAL_312:.*]] = load %[[VAL_1]], ptr %[[VAL_226]], align 1 -// CHECK: store %[[VAL_1]] %[[VAL_312]], ptr %[[VAL_313:.*]], align 1 -// CHECK: br label %[[VAL_162]] // CHECK: entry: -// CHECK: %[[VAL_314:.*]] = alloca %[[VAL_315:.*]], align 8 -// CHECK: %[[VAL_316:.*]] = load %[[VAL_315]], ptr %[[VAL_317:.*]], align 1 -// CHECK: %[[VAL_318:.*]] = load %[[VAL_315]], ptr %[[VAL_319:.*]], align 1 -// CHECK: %[[VAL_320:.*]] = extractvalue %[[VAL_315]] %[[VAL_316]], 0 -// CHECK: %[[VAL_321:.*]] = extractvalue %[[VAL_315]] %[[VAL_318]], 0 -// CHECK: %[[VAL_322:.*]] = fadd double %[[VAL_320]], %[[VAL_321]] -// CHECK: %[[VAL_323:.*]] = extractvalue %[[VAL_315]] %[[VAL_316]], 1 -// CHECK: %[[VAL_324:.*]] = extractvalue %[[VAL_315]] %[[VAL_318]], 1 -// CHECK: %[[VAL_325:.*]] = fadd double %[[VAL_323]], %[[VAL_324]] -// CHECK: %[[VAL_326:.*]] = insertvalue %[[VAL_315]] zeroinitializer, double %[[VAL_322]], 0 -// CHECK: %[[VAL_327:.*]] = insertvalue %[[VAL_315]] %[[VAL_326]], double %[[VAL_325]], 1 -// CHECK: store %[[VAL_315]] %[[VAL_327]], ptr %[[VAL_314]], align 1 -// CHECK: %[[VAL_328:.*]] = load %[[VAL_315]], ptr %[[VAL_314]], align 1 -// CHECK: store %[[VAL_315]] %[[VAL_328]], ptr %[[VAL_329:.*]], align 1 +// CHECK: %[[VAL_298:.*]] = alloca %[[VAL_299:.*]], align 8 +// CHECK: %[[VAL_300:.*]] = load %[[VAL_299]], ptr %[[VAL_301:.*]], align 1 +// CHECK: %[[VAL_302:.*]] = load %[[VAL_299]], ptr %[[VAL_303:.*]], align 1 +// CHECK: %[[VAL_304:.*]] = extractvalue %[[VAL_299]] %[[VAL_300]], 0 +// CHECK: %[[VAL_305:.*]] = extractvalue %[[VAL_299]] %[[VAL_302]], 0 +// CHECK: %[[VAL_306:.*]] = fadd double %[[VAL_304]], %[[VAL_305]] +// CHECK: %[[VAL_307:.*]] = extractvalue %[[VAL_299]] %[[VAL_300]], 1 +// CHECK: %[[VAL_308:.*]] = extractvalue %[[VAL_299]] %[[VAL_302]], 1 +// CHECK: %[[VAL_309:.*]] = fadd double %[[VAL_307]], %[[VAL_308]] +// CHECK: %[[VAL_310:.*]] = insertvalue %[[VAL_299]] zeroinitializer, double %[[VAL_306]], 0 +// CHECK: %[[VAL_311:.*]] = insertvalue %[[VAL_299]] %[[VAL_310]], double %[[VAL_309]], 1 +// CHECK: store %[[VAL_299]] %[[VAL_311]], ptr %[[VAL_298]], align 1 +// CHECK: %[[VAL_312:.*]] = load %[[VAL_299]], ptr %[[VAL_298]], align 1 +// CHECK: store %[[VAL_299]] %[[VAL_312]], ptr %[[VAL_313:.*]], align 1 // CHECK: ret void diff --git a/third_party/xla/xla/service/gpu/tests/reduce_row_vectorized.hlo b/third_party/xla/xla/service/gpu/tests/reduce_row_vectorized.hlo index af908358f3c152..b85eeb0ac8831a 100644 --- a/third_party/xla/xla/service/gpu/tests/reduce_row_vectorized.hlo +++ b/third_party/xla/xla/service/gpu/tests/reduce_row_vectorized.hlo @@ -44,10 +44,10 @@ ENTRY reduce.1 { // CHECK: %[[VAL_19:.*]] = alloca float, align 4 // CHECK: %[[VAL_20:.*]] = alloca float, align 4 // CHECK: %[[VAL_21:.*]] = alloca float, align 4 -// CHECK: %[[VAL_22:.*]] = alloca float, align 4 +// CHECK: %[[VAL_22:.*]] = alloca i32, align 4 // CHECK: %[[VAL_23:.*]] = alloca i32, align 4 // CHECK: %[[VAL_24:.*]] = alloca float, align 4 -// CHECK: %[[VAL_25:.*]] = alloca float, align 4 +// CHECK: %[[VAL_25:.*]] = alloca i32, align 4 // CHECK: %[[VAL_26:.*]] = alloca i32, align 4 // CHECK: %[[VAL_27:.*]] = alloca i32, align 4 // CHECK: %[[VAL_28:.*]] = alloca float, align 4 @@ -56,231 +56,245 @@ ENTRY reduce.1 { // CHECK-GCN: %[[VAL_30:.*]] = call i32 @llvm.amdgcn.workgroup.id.y // CHECK: %[[VAL_31:.*]] = icmp eq i32 %[[VAL_30]], 0 // CHECK: br i1 %[[VAL_31]], label %[[VAL_32:.*]], label %[[VAL_33:.*]] -// CHECK: reduce-group-0-after: ; preds = %[[VAL_34:.*]], %[[VAL_35:.*]] +// CHECK: reduce-group-0-after: ; preds = %thread_in_bounds-after, %[[VAL_34:.*]] // CHECK: ret void -// CHECK: reduce-group-0-true: ; preds = %[[VAL_35]] -// CHECK: %[[VAL_36:.*]] = load float, ptr @0, align 4 -// CHECK: store float %[[VAL_36]], ptr %[[VAL_28]], align 4 -// CHECK-PTX: %[[VAL_37:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3 -// CHECK-GCN: %[[VAL_37:.*]] = call i32 @llvm.amdgcn.woritem.id.x -// CHECK-PTX: %[[VAL_38:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !4 -// CHECK-GCN: %[[VAL_38:.*]] = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK: %[[VAL_39:.*]] = urem i32 %[[VAL_37]], 64 -// CHECK: %[[VAL_40:.*]] = udiv i32 %[[VAL_37]], 64 -// CHECK: %[[VAL_41:.*]] = mul i32 %[[VAL_38]], 1 -// CHECK: %[[VAL_42:.*]] = add i32 %[[VAL_41]], %[[VAL_40]] -// CHECK: %[[VAL_43:.*]] = icmp ult i32 %[[VAL_42]], 131072 -// CHECK: br i1 %[[VAL_43]], label %[[VAL_44:.*]], label %[[VAL_45:.*]] -// CHECK: 9: ; preds = %[[VAL_32]] -// CHECK: %[[VAL_47:.*]] = udiv i32 %[[VAL_39]], 64 -// CHECK: %[[VAL_46:.*]] = urem i32 %[[VAL_39]], 64 -// CHECK: %[[VAL_96:.*]] = mul i32 %[[VAL_46]], 2 -// CHECK: %[[VAL_48:.*]] = urem i32 %[[VAL_39]], 32 -// CHECK: %[[VAL_49:.*]] = udiv i32 %[[VAL_42]], 1 -// CHECK: %[[VAL_50:.*]] = urem i32 %[[VAL_49]], 1 -// CHECK: %[[VAL_51:.*]] = udiv i32 %[[VAL_42]], 1 -// CHECK: %[[VAL_52:.*]] = urem i32 %[[VAL_51]], 131072 -// CHECK: %[[VAL_53:.*]] = udiv i32 %[[VAL_42]], 131072 -// CHECK: %[[VAL_58:.*]] = mul i32 %[[VAL_53]], 1 -// CHECK: %[[VAL_59:.*]] = mul i32 %[[VAL_52]], 1 -// CHECK: %[[VAL_60:.*]] = mul i32 %[[VAL_50]], 1024 -// CHECK: store i32 %[[VAL_47]], ptr %[[VAL_27]], align 4 -// CHECK: br label %[[VAL_61:.*]] -// CHECK: loop1.loop_header: ; preds = %[[VAL_62:.*]], %[[VAL_44]] -// CHECK: %[[VAL_63:.*]] = load i32, ptr %[[VAL_27]], align 4 -// CHECK: %[[VAL_64:.*]] = icmp uge i32 %[[VAL_63]], 1 -// CHECK: br i1 %[[VAL_64]], label %[[VAL_65:.*]], label %[[VAL_66:.*]] -// CHECK: loop1.loop_body: ; preds = %[[VAL_61]] -// CHECK: %[[VAL_67:.*]] = add nuw nsw i32 %[[VAL_63]], 1 -// CHECK: store i32 %[[VAL_67]], ptr %[[VAL_27]], align 4 -// CHECK: %[[VAL_68:.*]] = icmp eq i32 %[[VAL_63]], %[[VAL_47]] -// CHECK: br i1 true, label %[[VAL_70:.*]], label %[[VAL_71:.*]] -// CHECK: is_full_tile-after: ; preds = %[[VAL_72:.*]], %[[VAL_73:.*]] -// CHECK: br label %[[VAL_61]], !llvm.loop !5 -// CHECK: loop1.loop_exit: ; preds = %[[VAL_61]] -// CHECK: %[[VAL_74:.*]] = load float, ptr %[[VAL_28]], align 4 -// CHECK: %[[VAL_75:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_74]], i32 16, i32 31) -// CHECK: store float %[[VAL_75]], ptr %[[VAL_20]], align 4 +// CHECK: reduce-group-0-true: ; preds = %[[VAL_34]] +// CHECK: %[[VAL_35:.*]] = load float, ptr @0, align 4 +// CHECK: store float %[[VAL_35]], ptr %[[VAL_28]], align 4 +// CHECK-PTX: %thread.id.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3 +// CHECK-GCN: %thread.id.x = call i32 @llvm.amdgcn.workitem.id.x +// CHECK-PTX: %block.id.x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !4 +// CHECK-GCN: %block.id.x = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK: %[[VAL_36:.*]] = udiv i32 %thread.id.x, 64 +// CHECK: %thread.id.1 = urem i32 %[[VAL_36]], 4 +// CHECK: %thread.id.2 = urem i32 %thread.id.x, 64 +// CHECK: %lane_id = urem i32 %thread.id.x, 32 +// CHECK: %[[VAL_37:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_38:.*]] = urem i32 %[[VAL_37]], 1 +// CHECK: %[[VAL_39:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_40:.*]] = urem i32 %[[VAL_39]], 1 +// CHECK: %[[VAL_41:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_42:.*]] = urem i32 %[[VAL_41]], 32768 +// CHECK: %[[VAL_43:.*]] = udiv i32 %block.id.x, 32768 +// CHECK: %tile_origin.0 = mul i32 %[[VAL_43]], 1 +// CHECK: %tile_origin.1 = mul i32 %[[VAL_42]], 4 +// CHECK: %tile_origin.2 = mul i32 %[[VAL_40]], 512 +// CHECK: %tile_origin.3 = mul i32 %[[VAL_38]], 2 +// CHECK: store i32 %thread.id.1, ptr %[[VAL_27]], align 4 +// CHECK: br label %[[VAL_44:.*]] +// CHECK: loop1.loop_header: ; preds = %[[VAL_45:.*]], %[[VAL_32]] +// CHECK: %[[VAL_46:.*]] = load i32, ptr %[[VAL_27]], align 4 +// CHECK: %[[VAL_47:.*]] = icmp uge i32 %[[VAL_46]], 4 +// CHECK: br i1 %[[VAL_47]], label %[[VAL_48:.*]], label %[[VAL_49:.*]] +// CHECK: loop1.loop_body: ; preds = %[[VAL_44]] +// CHECK: %[[VAL_50:.*]] = add nuw nsw i32 %[[VAL_46]], 4 +// CHECK: store i32 %[[VAL_50]], ptr %[[VAL_27]], align 4 +// CHECK: %[[VAL_51:.*]] = icmp eq i32 %[[VAL_46]], %thread.id.1 +// CHECK: br i1 true, label %[[VAL_52:.*]], label %[[VAL_53:.*]] +// CHECK: is_full_tile-after: ; preds = %[[VAL_54:.*]], %[[VAL_55:.*]] +// CHECK: br label %[[VAL_44]], !llvm.loop !5 +// CHECK: loop1.loop_exit: ; preds = %[[VAL_44]] +// CHECK: %[[VAL_56:.*]] = load float, ptr %[[VAL_28]], align 4 +// CHECK: %[[VAL_57:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_56]], i32 16, i32 31) +// CHECK: store float %[[VAL_57]], ptr %[[VAL_20]], align 4 // CHECK: call void @[[SUM:Sum.*]](ptr %[[VAL_28]], ptr %[[VAL_20]], ptr %[[VAL_19]]) -// CHECK: %[[VAL_76:.*]] = load float, ptr %[[VAL_19]], align 4 -// CHECK: store float %[[VAL_76]], ptr %[[VAL_28]], align 4 -// CHECK: %[[VAL_77:.*]] = load float, ptr %[[VAL_28]], align 4 -// CHECK: %[[VAL_78:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_77]], i32 8, i32 31) -// CHECK: store float %[[VAL_78]], ptr %[[VAL_18]], align 4 +// CHECK: %[[VAL_58:.*]] = load float, ptr %[[VAL_19]], align 4 +// CHECK: store float %[[VAL_58]], ptr %[[VAL_28]], align 4 +// CHECK: %[[VAL_59:.*]] = load float, ptr %[[VAL_28]], align 4 +// CHECK: %[[VAL_60:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_59]], i32 8, i32 31) +// CHECK: store float %[[VAL_60]], ptr %[[VAL_18]], align 4 // CHECK: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_18]], ptr %[[VAL_17]]) -// CHECK: %[[VAL_79:.*]] = load float, ptr %[[VAL_17]], align 4 -// CHECK: store float %[[VAL_79]], ptr %[[VAL_28]], align 4 -// CHECK: %[[VAL_80:.*]] = load float, ptr %[[VAL_28]], align 4 -// CHECK: %[[VAL_81:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_80]], i32 4, i32 31) -// CHECK: store float %[[VAL_81]], ptr %[[VAL_16]], align 4 +// CHECK: %[[VAL_61:.*]] = load float, ptr %[[VAL_17]], align 4 +// CHECK: store float %[[VAL_61]], ptr %[[VAL_28]], align 4 +// CHECK: %[[VAL_62:.*]] = load float, ptr %[[VAL_28]], align 4 +// CHECK: %[[VAL_63:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_62]], i32 4, i32 31) +// CHECK: store float %[[VAL_63]], ptr %[[VAL_16]], align 4 // CHECK: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_16]], ptr %[[VAL_15]]) -// CHECK: %[[VAL_82:.*]] = load float, ptr %[[VAL_15]], align 4 -// CHECK: store float %[[VAL_82]], ptr %[[VAL_28]], align 4 -// CHECK: %[[VAL_83:.*]] = load float, ptr %[[VAL_28]], align 4 -// CHECK: %[[VAL_84:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_83]], i32 2, i32 31) -// CHECK: store float %[[VAL_84]], ptr %[[VAL_14]], align 4 +// CHECK: %[[VAL_64:.*]] = load float, ptr %[[VAL_15]], align 4 +// CHECK: store float %[[VAL_64]], ptr %[[VAL_28]], align 4 +// CHECK: %[[VAL_65:.*]] = load float, ptr %[[VAL_28]], align 4 +// CHECK: %[[VAL_66:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_65]], i32 2, i32 31) +// CHECK: store float %[[VAL_66]], ptr %[[VAL_14]], align 4 // CHECK: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_14]], ptr %[[VAL_13]]) -// CHECK: %[[VAL_85:.*]] = load float, ptr %[[VAL_13]], align 4 -// CHECK: store float %[[VAL_85]], ptr %[[VAL_28]], align 4 -// CHECK: %[[VAL_86:.*]] = load float, ptr %[[VAL_28]], align 4 -// CHECK: %[[VAL_87:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_86]], i32 1, i32 31) -// CHECK: store float %[[VAL_87]], ptr %[[VAL_12]], align 4 +// CHECK: %[[VAL_67:.*]] = load float, ptr %[[VAL_13]], align 4 +// CHECK: store float %[[VAL_67]], ptr %[[VAL_28]], align 4 +// CHECK: %[[VAL_68:.*]] = load float, ptr %[[VAL_28]], align 4 +// CHECK: %[[VAL_69:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_68]], i32 1, i32 31) +// CHECK: store float %[[VAL_69]], ptr %[[VAL_12]], align 4 // CHECK: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_12]], ptr %[[VAL_11]]) -// CHECK: %[[VAL_88:.*]] = load float, ptr %[[VAL_11]], align 4 -// CHECK: store float %[[VAL_88]], ptr %[[VAL_28]], align 4 -// CHECK: %[[VAL_89:.*]] = udiv i32 %[[VAL_46]], 32 -// CHECK: %[[VAL_90:.*]] = icmp eq i32 %[[VAL_48]], 0 -// CHECK: br i1 %[[VAL_90]], label %[[VAL_91:.*]], label %[[VAL_92:.*]] -// CHECK: intra_warp_reduce_write-after: ; preds = %[[VAL_91]], %[[VAL_65]] -// CHECK: call void @llvm.nvvm.barrier0() -// CHECK: %[[VAL_93:.*]] = icmp eq i32 %[[VAL_89]], 0 -// CHECK: br i1 %[[VAL_93]], label %[[VAL_94:.*]], label %[[VAL_34]] -// CHECK: inter_warp_reduce-after: ; preds = %[[VAL_95:.*]], %[[VAL_92]] +// CHECK: %[[VAL_70:.*]] = load float, ptr %[[VAL_11]], align 4 +// CHECK: store float %[[VAL_70]], ptr %[[VAL_28]], align 4 +// CHECK: %[[VAL_71:.*]] = udiv i32 %thread.id.2, 32 +// CHECK: %[[VAL_72:.*]] = icmp ult i32 %thread.id.1, 4 +// CHECK: br i1 %[[VAL_72]], label %thread_in_bounds-true, label %thread_in_bounds-after +// CHECK: thread_in_bounds-after: ; preds = %[[VAL_73:.*]], %[[VAL_48]] // CHECK: br label %[[VAL_33]] -// CHECK: early_return: ; preds = %[[VAL_32]] -// CHECK: ret void -// CHECK: is_full_tile-true: ; preds = %[[VAL_66]] +// CHECK: is_full_tile-true: ; preds = %[[VAL_49]] // CHECK: store i32 0, ptr %[[VAL_26]], align 4 -// CHECK: br label %[[VAL_97:.*]] -// CHECK: loop2.loop_header: ; preds = %[[VAL_98:.*]], %[[VAL_70]] -// CHECK: %[[VAL_99:.*]] = load i32, ptr %[[VAL_26]], align 4 -// CHECK: %[[VAL_100:.*]] = icmp uge i32 %[[VAL_99]], 8 -// CHECK: br i1 %[[VAL_100]], label %[[VAL_73]], label %[[VAL_98]] -// CHECK: loop2.loop_body: ; preds = %[[VAL_97]] -// CHECK: %[[VAL_101:.*]] = add nuw nsw i32 %[[VAL_99]], 1 -// CHECK: store i32 %[[VAL_101]], ptr %[[VAL_26]], align 4 -// CHECK: %[[VAL_102:.*]] = icmp eq i32 %[[VAL_99]], 0 -// CHECK: %[[VAL_103:.*]] = mul i32 %[[VAL_99]], 128 -// CHECK: %[[VAL_104:.*]] = add i32 %[[VAL_103]], 0 -// CHECK: %[[VAL_105:.*]] = add i32 %[[VAL_104]], %[[VAL_96]] -// CHECK: %[[VAL_106:.*]] = add i32 %[[VAL_59]], %[[VAL_63]] -// CHECK: %[[VAL_107:.*]] = add i32 %[[VAL_60]], %[[VAL_105]] -// CHECK: %[[VAL_108:.*]] = getelementptr inbounds [131072 x [1024 x float]], ptr %[[VAL_109:.*]], i32 0, i32 %[[VAL_106]], i32 %[[VAL_107]] -// CHECK: %[[VAL_110:.*]] = load float, ptr %[[VAL_108]], align 4, !invariant.load !7 -// CHECK: store float %[[VAL_110]], ptr %[[VAL_29]], align 4 -// CHECK: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_29]], ptr %[[VAL_25]]) -// CHECK: %[[VAL_112:.*]] = load float, ptr %[[VAL_25]], align 4 -// CHECK: store float %[[VAL_112]], ptr %[[VAL_28]], align 4 -// CHECK: %[[VAL_113:.*]] = mul i32 %[[VAL_99]], 128 -// CHECK: %[[VAL_114:.*]] = add i32 %[[VAL_113]], 1 -// CHECK: %[[VAL_115:.*]] = add i32 %[[VAL_114]], %[[VAL_96]] -// CHECK: %[[VAL_116:.*]] = add i32 %[[VAL_59]], %[[VAL_63]] -// CHECK: %[[VAL_117:.*]] = add i32 %[[VAL_60]], %[[VAL_115]] -// CHECK: %[[VAL_118:.*]] = getelementptr inbounds [131072 x [1024 x float]], ptr %[[VAL_109]], i32 0, i32 %[[VAL_116]], i32 %[[VAL_117]] -// CHECK: %[[VAL_119:.*]] = load float, ptr %[[VAL_118]], align 4, !invariant.load !7 -// CHECK: store float %[[VAL_119]], ptr %[[VAL_29]], align 4 +// CHECK: br label %[[VAL_74:.*]] +// CHECK: loop2.loop_header: ; preds = %[[VAL_75:.*]], %[[VAL_52]] +// CHECK: %[[VAL_76:.*]] = load i32, ptr %[[VAL_26]], align 4 +// CHECK: %[[VAL_77:.*]] = icmp uge i32 %[[VAL_76]], 512 +// CHECK: br i1 %[[VAL_77]], label %[[VAL_55]], label %[[VAL_78:.*]] +// CHECK: loop2.loop_body: ; preds = %[[VAL_74]] +// CHECK: %[[VAL_79:.*]] = add nuw nsw i32 %[[VAL_76]], 64 +// CHECK: store i32 %[[VAL_79]], ptr %[[VAL_26]], align 4 +// CHECK: %[[VAL_80:.*]] = icmp eq i32 %[[VAL_76]], 0 +// CHECK: %[[VAL_81:.*]] = add i32 %[[VAL_76]], %thread.id.2 +// CHECK: store i32 0, ptr %[[VAL_25]], align 4 +// CHECK: br label %[[VAL_82:.*]] +// CHECK: loop3.loop_header: ; preds = %[[VAL_83:.*]], %[[VAL_78]] +// CHECK: %[[VAL_84:.*]] = load i32, ptr %[[VAL_25]], align 4 +// CHECK: %[[VAL_85:.*]] = icmp uge i32 %[[VAL_84]], 2 +// CHECK: br i1 %[[VAL_85]], label %[[VAL_75]], label %[[VAL_83]] +// CHECK: loop3.loop_body: ; preds = %[[VAL_82]] +// CHECK: %[[VAL_86:.*]] = add nuw nsw i32 %[[VAL_84]], 1 +// CHECK: store i32 %[[VAL_86]], ptr %[[VAL_25]], align 4 +// CHECK: %[[VAL_87:.*]] = icmp eq i32 %[[VAL_84]], 0 +// CHECK: %[[VAL_88:.*]] = add i32 %tile_origin.0, 0 +// CHECK: %[[VAL_89:.*]] = add i32 %tile_origin.1, %[[VAL_46]] +// CHECK: %[[VAL_90:.*]] = add i32 %tile_origin.2, %[[VAL_81]] +// CHECK: %[[VAL_91:.*]] = add i32 %tile_origin.3, %[[VAL_84]] +// CHECK: %[[VAL_92:.*]] = mul nuw nsw i32 %[[VAL_91]], 1 +// CHECK: %[[VAL_93:.*]] = add nuw nsw i32 0, %[[VAL_92]] +// CHECK: %[[VAL_94:.*]] = mul nuw nsw i32 %[[VAL_90]], 2 +// CHECK: %[[VAL_95:.*]] = add nuw nsw i32 %[[VAL_93]], %[[VAL_94]] +// CHECK: %[[VAL_96:.*]] = udiv i32 %[[VAL_95]], 1024 +// CHECK: %[[VAL_97:.*]] = mul nuw nsw i32 %[[VAL_89]], 1 +// CHECK: %[[VAL_98:.*]] = add nuw nsw i32 0, %[[VAL_97]] +// CHECK: %[[VAL_99:.*]] = udiv i32 %[[VAL_98]], 131072 +// CHECK: %[[VAL_100:.*]] = mul nuw nsw i32 %[[VAL_88]], 1 +// CHECK: %[[VAL_101:.*]] = add nuw nsw i32 0, %[[VAL_100]] +// CHECK: %[[VAL_102:.*]] = getelementptr inbounds [131072 x [1024 x float]], ptr %[[VAL_103:.*]], i32 0, i32 %[[VAL_98]], i32 %[[VAL_95]] +// CHECK: %[[VAL_104:.*]] = load float, ptr %[[VAL_102]], align 4, !invariant.load !7 +// CHECK: store float %[[VAL_104]], ptr %[[VAL_29]], align 4 // CHECK: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_29]], ptr %[[VAL_24]]) -// CHECK: %[[VAL_121:.*]] = load float, ptr %[[VAL_24]], align 4 -// CHECK: store float %[[VAL_121]], ptr %[[VAL_28]], align 4 -// CHECK: br label %[[VAL_97]], !llvm.loop !8 -// CHECK: loop2.loop_exit: ; preds = %[[VAL_97]] -// CHECK: br label %[[VAL_62]] -// CHECK: is_full_tile-false: ; preds = %[[VAL_66]] +// CHECK: %[[VAL_105:.*]] = load float, ptr %[[VAL_24]], align 4 +// CHECK: store float %[[VAL_105]], ptr %[[VAL_28]], align 4 +// CHECK: br label %[[VAL_82]], !llvm.loop !8 +// CHECK: loop3.loop_exit: ; preds = %[[VAL_82]] +// CHECK: br label %[[VAL_74]], !llvm.loop !9 +// CHECK: loop2.loop_exit: ; preds = %[[VAL_74]] +// CHECK: br label %[[VAL_45]] +// CHECK: is_full_tile-false: ; preds = %[[VAL_49]] // CHECK: store i32 0, ptr %[[VAL_23]], align 4 -// CHECK: br label %[[VAL_123:.*]] -// CHECK: loop2.loop_header7: ; preds = %[[VAL_124:.*]], %[[VAL_71]] -// CHECK: %[[VAL_125:.*]] = load i32, ptr %[[VAL_23]], align 4 -// CHECK: %[[VAL_126:.*]] = icmp uge i32 %[[VAL_125]], 8 -// CHECK: br i1 %[[VAL_126]], label %[[VAL_72]], label %[[VAL_127:.*]] -// CHECK: loop2.loop_body8: ; preds = %[[VAL_123]] -// CHECK: %[[VAL_128:.*]] = add nuw nsw i32 %[[VAL_125]], 1 -// CHECK: store i32 %[[VAL_128]], ptr %[[VAL_23]], align 4 -// CHECK: %[[VAL_129:.*]] = icmp eq i32 %[[VAL_125]], 0 -// CHECK: %[[VAL_130:.*]] = mul i32 %[[VAL_125]], 128 -// CHECK: %[[VAL_131:.*]] = add i32 %[[VAL_130]], 0 -// CHECK: %[[VAL_132:.*]] = add i32 %[[VAL_131]], %[[VAL_96]] -// CHECK: %[[VAL_133:.*]] = icmp ult i32 %[[VAL_132]], 1024 -// CHECK: br i1 %[[VAL_133]], label %[[VAL_134:.*]], label %[[VAL_135:.*]] -// CHECK: x_in_tile-after: ; preds = %[[VAL_134]], %[[VAL_127]] -// CHECK: %[[VAL_136:.*]] = mul i32 %[[VAL_125]], 128 -// CHECK: %[[VAL_137:.*]] = add i32 %[[VAL_136]], 1 -// CHECK: %[[VAL_138:.*]] = add i32 %[[VAL_137]], %[[VAL_96]] -// CHECK: %[[VAL_139:.*]] = icmp ult i32 %[[VAL_138]], 1024 -// CHECK: br i1 %[[VAL_139]], label %[[VAL_140:.*]], label %[[VAL_124]] -// CHECK: x_in_tile-after16: ; preds = %[[VAL_140]], %[[VAL_135]] -// CHECK: br label %[[VAL_123]], !llvm.loop !10 -// CHECK: loop2.loop_exit6: ; preds = %[[VAL_123]] -// CHECK: br label %[[VAL_62]] -// CHECK: x_in_tile-true: ; preds = %[[VAL_127]] -// CHECK: %[[VAL_141:.*]] = add i32 %[[VAL_59]], %[[VAL_63]] -// CHECK: %[[VAL_142:.*]] = add i32 %[[VAL_60]], %[[VAL_132]] -// CHECK: %[[VAL_143:.*]] = getelementptr inbounds [131072 x [1024 x float]], ptr %[[VAL_109]], i32 0, i32 %[[VAL_141]], i32 %[[VAL_142]] -// CHECK: %[[VAL_144:.*]] = load float, ptr %[[VAL_143]], align 4, !invariant.load !7 -// CHECK: store float %[[VAL_144]], ptr %[[VAL_29]], align 4 -// CHECK: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_29]], ptr %[[VAL_22]]) -// CHECK: %[[VAL_146:.*]] = load float, ptr %[[VAL_22]], align 4 -// CHECK: store float %[[VAL_146]], ptr %[[VAL_28]], align 4 -// CHECK: br label %[[VAL_135]] -// CHECK: x_in_tile-true15: ; preds = %[[VAL_135]] -// CHECK: %[[VAL_147:.*]] = add i32 %[[VAL_59]], %[[VAL_63]] -// CHECK: %[[VAL_148:.*]] = add i32 %[[VAL_60]], %[[VAL_138]] -// CHECK: %[[VAL_149:.*]] = getelementptr inbounds [131072 x [1024 x float]], ptr %[[VAL_109]], i32 0, i32 %[[VAL_147]], i32 %[[VAL_148]] -// CHECK: %[[VAL_150:.*]] = load float, ptr %[[VAL_149]], align 4, !invariant.load !7 -// CHECK: store float %[[VAL_150]], ptr %[[VAL_29]], align 4 +// CHECK: br label %[[VAL_106:.*]] +// CHECK: loop2.loop_header5: ; preds = %[[VAL_107:.*]], %[[VAL_53]] +// CHECK: %[[VAL_108:.*]] = load i32, ptr %[[VAL_23]], align 4 +// CHECK: %[[VAL_109:.*]] = icmp uge i32 %[[VAL_108]], 512 +// CHECK: br i1 %[[VAL_109]], label %[[VAL_54]], label %[[VAL_110:.*]] +// CHECK: loop2.loop_body6: ; preds = %[[VAL_106]] +// CHECK: %[[VAL_111:.*]] = add nuw nsw i32 %[[VAL_108]], 64 +// CHECK: store i32 %[[VAL_111]], ptr %[[VAL_23]], align 4 +// CHECK: %[[VAL_112:.*]] = icmp eq i32 %[[VAL_108]], 0 +// CHECK: %[[VAL_113:.*]] = add i32 %[[VAL_108]], %thread.id.2 +// CHECK: %[[VAL_114:.*]] = icmp ult i32 %[[VAL_113]], 512 +// CHECK: br i1 %[[VAL_114]], label %[[VAL_115:.*]], label %[[VAL_107]] +// CHECK: x_in_tile-after: ; preds = %[[VAL_116:.*]], %[[VAL_110]] +// CHECK: br label %[[VAL_106]], !llvm.loop !11 +// CHECK: loop2.loop_exit4: ; preds = %[[VAL_106]] +// CHECK: br label %[[VAL_45]] +// CHECK: x_in_tile-true: ; preds = %[[VAL_110]] +// CHECK: store i32 0, ptr %[[VAL_22]], align 4 +// CHECK: br label %[[VAL_117:.*]] +// CHECK: loop3.loop_header11: ; preds = %[[VAL_118:.*]], %[[VAL_115]] +// CHECK: %[[VAL_119:.*]] = load i32, ptr %[[VAL_22]], align 4 +// CHECK: %[[VAL_120:.*]] = icmp uge i32 %[[VAL_119]], 2 +// CHECK: br i1 %[[VAL_120]], label %[[VAL_116]], label %[[VAL_118]] +// CHECK: loop3.loop_body12: ; preds = %[[VAL_117]] +// CHECK: %[[VAL_121:.*]] = add nuw nsw i32 %[[VAL_119]], 1 +// CHECK: store i32 %[[VAL_121]], ptr %[[VAL_22]], align 4 +// CHECK: %[[VAL_122:.*]] = icmp eq i32 %[[VAL_119]], 0 +// CHECK: %[[VAL_123:.*]] = add i32 %tile_origin.0, 0 +// CHECK: %[[VAL_124:.*]] = add i32 %tile_origin.1, %[[VAL_46]] +// CHECK: %[[VAL_125:.*]] = add i32 %tile_origin.2, %[[VAL_113]] +// CHECK: %[[VAL_126:.*]] = add i32 %tile_origin.3, %[[VAL_119]] +// CHECK: %[[VAL_127:.*]] = mul nuw nsw i32 %[[VAL_126]], 1 +// CHECK: %[[VAL_128:.*]] = add nuw nsw i32 0, %[[VAL_127]] +// CHECK: %[[VAL_129:.*]] = mul nuw nsw i32 %[[VAL_125]], 2 +// CHECK: %[[VAL_130:.*]] = add nuw nsw i32 %[[VAL_128]], %[[VAL_129]] +// CHECK: %[[VAL_131:.*]] = udiv i32 %[[VAL_130]], 1024 +// CHECK: %[[VAL_132:.*]] = mul nuw nsw i32 %[[VAL_124]], 1 +// CHECK: %[[VAL_133:.*]] = add nuw nsw i32 0, %[[VAL_132]] +// CHECK: %[[VAL_134:.*]] = udiv i32 %[[VAL_133]], 131072 +// CHECK: %[[VAL_135:.*]] = mul nuw nsw i32 %[[VAL_123]], 1 +// CHECK: %[[VAL_136:.*]] = add nuw nsw i32 0, %[[VAL_135]] +// CHECK: %[[VAL_137:.*]] = getelementptr inbounds [131072 x [1024 x float]], ptr %[[VAL_103]], i32 0, i32 %[[VAL_133]], i32 %[[VAL_130]] +// CHECK: %[[VAL_138:.*]] = load float, ptr %[[VAL_137]], align 4, !invariant.load !7 +// CHECK: store float %[[VAL_138]], ptr %[[VAL_29]], align 4 // CHECK: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_29]], ptr %[[VAL_21]]) -// CHECK: %[[VAL_152:.*]] = load float, ptr %[[VAL_21]], align 4 -// CHECK: store float %[[VAL_152]], ptr %[[VAL_28]], align 4 -// CHECK: br label %[[VAL_124]] -// CHECK: intra_warp_reduce_write-true: ; preds = %[[VAL_65]] -// CHECK: %[[VAL_155:.*]] = load float, ptr %[[VAL_28]], align 4 -// CHECK: %[[VAL_153:.*]] = getelementptr inbounds [1 x [2 x float]], ptr addrspace(3) @shared_cache, i32 0, i32 %[[VAL_40]], i32 %[[VAL_89]] -// CHECK: %[[VAL_154:.*]] = addrspacecast ptr addrspace(3) %[[VAL_153]] to ptr -// CHECK: store float %[[VAL_155]], ptr %[[VAL_154]], align 4 -// CHECK: br label %[[VAL_92]] -// CHECK: inter_warp_reduce-true: ; preds = %[[VAL_92]] -// CHECK: %[[VAL_156:.*]] = getelementptr inbounds [1 x [2 x float]], ptr addrspace(3) @shared_cache, i32 0, i32 %[[VAL_40]], i32 %[[VAL_48]] -// CHECK: %[[VAL_157:.*]] = addrspacecast ptr addrspace(3) %[[VAL_156]] to ptr -// CHECK: store float %[[VAL_36]], ptr %[[VAL_10]], align 4 -// CHECK: %[[VAL_158:.*]] = icmp ult i32 %[[VAL_46]], 2 -// CHECK: %[[VAL_159:.*]] = select i1 %[[VAL_158]], ptr %[[VAL_157]], ptr %[[VAL_10]] -// CHECK: %[[VAL_160:.*]] = load float, ptr %[[VAL_159]], align 4 -// CHECK: %[[VAL_161:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_160]], i32 16, i32 31) -// CHECK: store float %[[VAL_161]], ptr %[[VAL_9]], align 4 -// CHECK: call void @[[SUM]](ptr %[[VAL_159]], ptr %[[VAL_9]], ptr %[[VAL_8]]) -// CHECK: %[[VAL_162:.*]] = load float, ptr %[[VAL_8]], align 4 -// CHECK: store float %[[VAL_162]], ptr %[[VAL_159]], align 4 -// CHECK: %[[VAL_163:.*]] = load float, ptr %[[VAL_159]], align 4 -// CHECK: %[[VAL_164:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_163]], i32 8, i32 31) -// CHECK: store float %[[VAL_164]], ptr %[[VAL_7]], align 4 -// CHECK: call void @[[SUM]](ptr %[[VAL_159]], ptr %[[VAL_7]], ptr %[[VAL_6]]) -// CHECK: %[[VAL_165:.*]] = load float, ptr %[[VAL_6]], align 4 -// CHECK: store float %[[VAL_165]], ptr %[[VAL_159]], align 4 -// CHECK: %[[VAL_166:.*]] = load float, ptr %[[VAL_159]], align 4 -// CHECK: %[[VAL_167:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_166]], i32 4, i32 31) -// CHECK: store float %[[VAL_167]], ptr %[[VAL_5]], align 4 -// CHECK: call void @[[SUM]](ptr %[[VAL_159]], ptr %[[VAL_5]], ptr %[[VAL_4]]) -// CHECK: %[[VAL_168:.*]] = load float, ptr %[[VAL_4]], align 4 -// CHECK: store float %[[VAL_168]], ptr %[[VAL_159]], align 4 -// CHECK: %[[VAL_169:.*]] = load float, ptr %[[VAL_159]], align 4 -// CHECK: %[[VAL_170:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_169]], i32 2, i32 31) -// CHECK: store float %[[VAL_170]], ptr %[[VAL_3]], align 4 -// CHECK: call void @[[SUM]](ptr %[[VAL_159]], ptr %[[VAL_3]], ptr %[[VAL_2]]) -// CHECK: %[[VAL_171:.*]] = load float, ptr %[[VAL_2]], align 4 -// CHECK: store float %[[VAL_171]], ptr %[[VAL_159]], align 4 -// CHECK: %[[VAL_172:.*]] = load float, ptr %[[VAL_159]], align 4 -// CHECK: %[[VAL_173:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_172]], i32 1, i32 31) -// CHECK: store float %[[VAL_173]], ptr %[[VAL_1]], align 4 -// CHECK: call void @[[SUM]](ptr %[[VAL_159]], ptr %[[VAL_1]], ptr %[[VAL_0]]) -// CHECK: %[[VAL_174:.*]] = load float, ptr %[[VAL_0]], align 4 -// CHECK: store float %[[VAL_174]], ptr %[[VAL_159]], align 4 -// CHECK: %[[VAL_175:.*]] = icmp eq i32 %[[VAL_46]], 0 -// CHECK: br i1 %[[VAL_175]], label %[[VAL_176:.*]], label %[[VAL_95]] -// CHECK: reduction_write_output-after: ; preds = %[[VAL_176]], %[[VAL_94]] -// CHECK: br label %[[VAL_34]] -// CHECK: reduction_write_output-true: ; preds = %[[VAL_94]] -// CHECK: %[[VAL_178:.*]] = add i32 %[[VAL_59]], %[[VAL_47]] -// CHECK: %[[VAL_179:.*]] = add i32 %[[VAL_60]], %[[VAL_96]] -// CHECK: %[[VAL_180:.*]] = udiv i32 %[[VAL_178]], 1 -// CHECK: %[[VAL_181:.*]] = getelementptr inbounds [131072 x float], ptr %[[VAL_182:.*]], i32 0, i32 %[[VAL_180]] -// CHECK: %[[VAL_183:.*]] = load float, ptr %[[VAL_159]], align 4 -// CHECK: store float %[[VAL_183]], ptr %[[VAL_181]], align 4 -// CHECK: br label %[[VAL_95]] +// CHECK: %[[VAL_139:.*]] = load float, ptr %[[VAL_21]], align 4 +// CHECK: store float %[[VAL_139]], ptr %[[VAL_28]], align 4 +// CHECK: br label %[[VAL_117]], !llvm.loop !12 +// CHECK: loop3.loop_exit10: ; preds = %[[VAL_117]] +// CHECK: br label %[[VAL_107]] +// CHECK: thread_in_bounds-true: ; preds = %[[VAL_48]] +// CHECK: %[[VAL_140:.*]] = icmp eq i32 %lane_id, 0 +// CHECK: br i1 %[[VAL_140]], label %[[VAL_141:.*]], label %[[VAL_142:.*]] +// CHECK: intra_warp_reduce_write-after: ; preds = %[[VAL_141]], %thread_in_bounds-true +// CHECK: call void @llvm.nvvm.barrier0() +// CHECK: %[[VAL_143:.*]] = icmp eq i32 %[[VAL_71]], 0 +// CHECK: br i1 %[[VAL_143]], label %[[VAL_144:.*]], label %[[VAL_73]] +// CHECK: inter_warp_reduce-after: ; preds = %[[VAL_145:.*]], %[[VAL_142]] +// CHECK: br label %thread_in_bounds-after +// CHECK: intra_warp_reduce_write-true: ; preds = %thread_in_bounds-true +// CHECK: %[[VAL_146:.*]] = load float, ptr %[[VAL_28]], align 4 +// CHECK: %[[VAL_147:.*]] = getelementptr inbounds [4 x [2 x float]], ptr addrspace(3) @shared_cache, i32 0, i32 %thread.id.1, i32 %[[VAL_71]] +// CHECK: %[[VAL_148:.*]] = addrspacecast ptr addrspace(3) %[[VAL_147]] to ptr +// CHECK: store float %[[VAL_146]], ptr %[[VAL_148]], align 4 +// CHECK: br label %[[VAL_142]] +// CHECK: inter_warp_reduce-true: ; preds = %[[VAL_142]] +// CHECK: %[[VAL_149:.*]] = getelementptr inbounds [4 x [2 x float]], ptr addrspace(3) @shared_cache, i32 0, i32 %thread.id.1, i32 %lane_id +// CHECK: %[[VAL_150:.*]] = addrspacecast ptr addrspace(3) %[[VAL_149]] to ptr +// CHECK: store float %[[VAL_35]], ptr %[[VAL_10]], align 4 +// CHECK: %[[VAL_151:.*]] = icmp ult i32 %thread.id.2, 2 +// CHECK: %[[VAL_152:.*]] = select i1 %[[VAL_151]], ptr %[[VAL_150]], ptr %[[VAL_10]] +// CHECK: %[[VAL_153:.*]] = load float, ptr %[[VAL_152]], align 4 +// CHECK: %[[VAL_154:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_153]], i32 16, i32 31) +// CHECK: store float %[[VAL_154]], ptr %[[VAL_9]], align 4 +// CHECK: call void @[[SUM]](ptr %[[VAL_152]], ptr %[[VAL_9]], ptr %[[VAL_8]]) +// CHECK: %[[VAL_155:.*]] = load float, ptr %[[VAL_8]], align 4 +// CHECK: store float %[[VAL_155]], ptr %[[VAL_152]], align 4 +// CHECK: %[[VAL_156:.*]] = load float, ptr %[[VAL_152]], align 4 +// CHECK: %[[VAL_157:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_156]], i32 8, i32 31) +// CHECK: store float %[[VAL_157]], ptr %[[VAL_7]], align 4 +// CHECK: call void @[[SUM]](ptr %[[VAL_152]], ptr %[[VAL_7]], ptr %[[VAL_6]]) +// CHECK: %[[VAL_158:.*]] = load float, ptr %[[VAL_6]], align 4 +// CHECK: store float %[[VAL_158]], ptr %[[VAL_152]], align 4 +// CHECK: %[[VAL_159:.*]] = load float, ptr %[[VAL_152]], align 4 +// CHECK: %[[VAL_160:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_159]], i32 4, i32 31) +// CHECK: store float %[[VAL_160]], ptr %[[VAL_5]], align 4 +// CHECK: call void @[[SUM]](ptr %[[VAL_152]], ptr %[[VAL_5]], ptr %[[VAL_4]]) +// CHECK: %[[VAL_161:.*]] = load float, ptr %[[VAL_4]], align 4 +// CHECK: store float %[[VAL_161]], ptr %[[VAL_152]], align 4 +// CHECK: %[[VAL_162:.*]] = load float, ptr %[[VAL_152]], align 4 +// CHECK: %[[VAL_163:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_162]], i32 2, i32 31) +// CHECK: store float %[[VAL_163]], ptr %[[VAL_3]], align 4 +// CHECK: call void @[[SUM]](ptr %[[VAL_152]], ptr %[[VAL_3]], ptr %[[VAL_2]]) +// CHECK: %[[VAL_164:.*]] = load float, ptr %[[VAL_2]], align 4 +// CHECK: store float %[[VAL_164]], ptr %[[VAL_152]], align 4 +// CHECK: %[[VAL_165:.*]] = load float, ptr %[[VAL_152]], align 4 +// CHECK: %[[VAL_166:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_165]], i32 1, i32 31) +// CHECK: store float %[[VAL_166]], ptr %[[VAL_1]], align 4 +// CHECK: call void @[[SUM]](ptr %[[VAL_152]], ptr %[[VAL_1]], ptr %[[VAL_0]]) +// CHECK: %[[VAL_167:.*]] = load float, ptr %[[VAL_0]], align 4 +// CHECK: store float %[[VAL_167]], ptr %[[VAL_152]], align 4 +// CHECK: %[[VAL_168:.*]] = icmp eq i32 %thread.id.2, 0 +// CHECK: br i1 %[[VAL_168]], label %[[VAL_169:.*]], label %[[VAL_145]] +// CHECK: reduction_write_output-after: ; preds = %[[VAL_169]], %[[VAL_144]] +// CHECK: br label %[[VAL_73]] +// CHECK: reduction_write_output-true: ; preds = %[[VAL_144]] +// CHECK: %[[VAL_171:.*]] = add i32 %tile_origin.1, %thread.id.1 +// CHECK: %[[VAL_175:.*]] = getelementptr inbounds [131072 x float], ptr %[[VAL_176:.*]], i32 0, i32 %[[VAL_171]] +// CHECK: %[[VAL_177:.*]] = load float, ptr %[[VAL_152]], align 4 +// CHECK: store float %[[VAL_177]], ptr %[[VAL_175]], align 4 +// CHECK: br label %[[VAL_145]] // CHECK: entry: -// CHECK: %[[VAL_184:.*]] = alloca float, align 4 -// CHECK: %[[VAL_185:.*]] = load float, ptr %[[VAL_186:.*]], align 4 -// CHECK: %[[VAL_187:.*]] = load float, ptr %[[VAL_188:.*]], align 4 -// CHECK: %[[VAL_189:.*]] = fadd float %[[VAL_185]], %[[VAL_187]] -// CHECK: store float %[[VAL_189]], ptr %[[VAL_184]], align 4 -// CHECK: %[[VAL_190:.*]] = load float, ptr %[[VAL_184]], align 4 -// CHECK: store float %[[VAL_190]], ptr %[[VAL_191:.*]], align 4 +// CHECK: %[[VAL_178:.*]] = alloca float, align 4 +// CHECK: %[[VAL_179:.*]] = load float, ptr %[[VAL_180:.*]], align 4 +// CHECK: %[[VAL_181:.*]] = load float, ptr %[[VAL_182:.*]], align 4 +// CHECK: %[[VAL_183:.*]] = fadd float %[[VAL_179]], %[[VAL_181]] +// CHECK: store float %[[VAL_183]], ptr %[[VAL_178]], align 4 +// CHECK: %[[VAL_184:.*]] = load float, ptr %[[VAL_178]], align 4 +// CHECK: store float %[[VAL_184]], ptr %[[VAL_185:.*]], align 4 // CHECK: ret void diff --git a/third_party/xla/xla/service/gpu/tests/reduce_unnested.hlo b/third_party/xla/xla/service/gpu/tests/reduce_unnested.hlo index 6c796f4b4baead..844c3ded2ef024 100644 --- a/third_party/xla/xla/service/gpu/tests/reduce_unnested.hlo +++ b/third_party/xla/xla/service/gpu/tests/reduce_unnested.hlo @@ -80,34 +80,3 @@ ENTRY reduce.1 { f32[131072,1024] parameter0 ), kind=kLoop, calls=fusion_not_vectorized } - -// ----- - -// TODO(jreiffers): This should most likely not be unrolled. The heuristic only -// checks instructions that are directly in the fusion, not nested computations. - -// CHECK: define void @fusion_row_reduction_sin_does_not_prevent_vectorization( -// CHECK-COUNT-2: {{^x_in_tile-true}} -// CHECK-NOT: {{^x_in_tile-true}} - -HloModule RowReductionVectorized, is_scheduled=true - -Sum { - x.1 = f32[] parameter(0) - y.1 = f32[] parameter(1) - sin = f32[] sine(y.1) - ROOT add.1 = f32[] add(x.1, sin) -} - -fusion_vectorized { - a = f32[131072,1024] parameter(0) - init = f32[] constant(0) - ROOT reduce = f32[131072] reduce(a, init), dimensions={1}, to_apply=Sum -} - -ENTRY reduce.1 { - parameter0 = f32[131072,1024] parameter(0) - ROOT fusion_row_reduction_sin_does_not_prevent_vectorization = f32[131072] fusion( - f32[131072,1024] parameter0 - ), kind=kLoop, calls=fusion_vectorized -} diff --git a/third_party/xla/xla/service/gpu/tests/reduce_variadic_column.hlo b/third_party/xla/xla/service/gpu/tests/reduce_variadic_column.hlo index ef296aff11c900..2032a64930717b 100644 --- a/third_party/xla/xla/service/gpu/tests/reduce_variadic_column.hlo +++ b/third_party/xla/xla/service/gpu/tests/reduce_variadic_column.hlo @@ -72,293 +72,393 @@ ENTRY main { // CHECK-GCN: %[[VAL_41:.*]] = call i32 @llvm.amdgcn.workgroup.id.y // CHECK: %[[VAL_42:.*]] = icmp eq i32 %[[VAL_41]], 0 // CHECK: br i1 %[[VAL_42]], label %[[VAL_43:.*]], label %[[VAL_44:.*]] -// CHECK: reduce-group-0-after: ; preds = %[[VAL_45:.*]], %[[VAL_46:.*]] +// CHECK: reduce-group-0-after: ; preds = %thread_in_bounds-after, %[[VAL_45:.*]] // CHECK: ret void -// CHECK: reduce-group-0-true: ; preds = %[[VAL_46]] -// CHECK: %[[VAL_47:.*]] = load float, ptr %[[VAL_48:.*]], align 4, !invariant.load !3 -// CHECK: store float %[[VAL_47]], ptr %[[VAL_39]], align 4 -// CHECK: %[[VAL_49:.*]] = load float, ptr %[[VAL_48]], align 4, !invariant.load !3 -// CHECK: store float %[[VAL_49]], ptr %[[VAL_37]], align 4 -// CHECK-PTX: %[[VAL_50:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !4 -// CHECK-GCN: %[[VAL_50:.*]] = call i32 @llvm.amdgcn.workitem.id.x -// CHECK-PTX: %[[VAL_51:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !5 -// CHECK-GCN: %[[VAL_51:.*]] = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK: %[[VAL_52:.*]] = urem i32 %[[VAL_50]], 32 -// CHECK: %[[VAL_53:.*]] = udiv i32 %[[VAL_50]], 32 -// CHECK: %[[VAL_54:.*]] = mul i32 %[[VAL_51]], 1 -// CHECK: %[[VAL_55:.*]] = add i32 %[[VAL_54]], %[[VAL_53]] -// CHECK: %[[VAL_56:.*]] = icmp ult i32 %[[VAL_55]], 200 -// CHECK: br i1 %[[VAL_56]], label %[[VAL_57:.*]], label %[[VAL_58:.*]] -// CHECK: 9: ; preds = %[[VAL_43]] -// CHECK: %[[VAL_60:.*]] = udiv i32 %[[VAL_52]], 32 -// CHECK: %[[VAL_59:.*]] = urem i32 %[[VAL_52]], 32 -// CHECK: %[[THREAD_X:.*]] = mul i32 %[[VAL_59]], 1 -// CHECK: %[[VAL_61:.*]] = urem i32 %[[VAL_52]], 32 -// CHECK: %[[VAL_62:.*]] = udiv i32 %[[VAL_55]], 1 -// CHECK: %[[VAL_63:.*]] = urem i32 %[[VAL_62]], 1 -// CHECK: %[[VAL_64:.*]] = udiv i32 %[[VAL_55]], 1 -// CHECK: %[[VAL_65:.*]] = urem i32 %[[VAL_64]], 200 -// CHECK: %[[VAL_66:.*]] = udiv i32 %[[VAL_55]], 200 -// CHECK: %[[VAL_69:.*]] = icmp eq i32 %[[VAL_63]], 0 -// CHECK: %[[VAL_70:.*]] = select i1 %[[VAL_69]], i32 300, i32 512 -// CHECK: %[[VAL_71:.*]] = mul i32 %[[VAL_66]], 5 -// CHECK: %[[VAL_72:.*]] = mul i32 %[[VAL_65]], 1 -// CHECK: %[[VAL_73:.*]] = mul i32 %[[VAL_63]], 512 -// CHECK: store i32 0, ptr %[[VAL_36]], align 4 -// CHECK: br label %[[VAL_77:.*]] -// CHECK: loop0.loop_header: ; preds = %[[VAL_78:.*]], %[[VAL_57]] -// CHECK: %[[VAL_79:.*]] = load i32, ptr %[[VAL_36]], align 4 -// CHECK: %[[VAL_80:.*]] = icmp uge i32 %[[VAL_79]], 5 -// CHECK: br i1 %[[VAL_80]], label %[[VAL_81:.*]], label %[[VAL_82:.*]] -// CHECK: loop0.loop_body: ; preds = %[[VAL_77]] -// CHECK: %[[VAL_83:.*]] = add nuw nsw i32 %[[VAL_79]], 1 -// CHECK: store i32 %[[VAL_83]], ptr %[[VAL_36]], align 4 -// CHECK: %[[VAL_84:.*]] = icmp eq i32 %[[VAL_79]], 0 -// CHECK: store i32 %[[VAL_60]], ptr %[[VAL_35]], align 4 -// CHECK: br label %[[VAL_86:.*]] -// CHECK: loop1.loop_header: ; preds = %[[VAL_87:.*]], %[[VAL_82]] -// CHECK: %[[VAL_88:.*]] = load i32, ptr %[[VAL_35]], align 4 -// CHECK: %[[VAL_89:.*]] = icmp uge i32 %[[VAL_88]], 1 -// CHECK: br i1 %[[VAL_89]], label %[[VAL_78]], label %[[VAL_90:.*]] -// CHECK: loop1.loop_body: ; preds = %[[VAL_86]] -// CHECK: %[[VAL_91:.*]] = add nuw nsw i32 %[[VAL_88]], 1 -// CHECK: store i32 %[[VAL_91]], ptr %[[VAL_35]], align 4 -// CHECK: %[[VAL_92:.*]] = icmp eq i32 %[[VAL_88]], %[[VAL_60]] -// CHECK: %[[VAL_93:.*]] = icmp eq i32 512, %[[VAL_70]] -// CHECK: br i1 %[[VAL_93]], label %[[VAL_94:.*]], label %[[VAL_95:.*]] -// CHECK: is_full_tile-after: ; preds = %[[VAL_96:.*]], %[[VAL_97:.*]] -// CHECK: br label %[[VAL_86]], !llvm.loop !6 -// CHECK: loop1.loop_exit: ; preds = %[[VAL_86]] -// CHECK: br label %[[VAL_77]], !llvm.loop !8 -// CHECK: loop0.loop_exit: ; preds = %[[VAL_77]] -// CHECK: %[[VAL_98:.*]] = load float, ptr %[[VAL_39]], align 4 -// CHECK: %[[VAL_99:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_98]], i32 16, i32 31) -// CHECK: store float %[[VAL_99]], ptr %[[VAL_26]], align 4 -// CHECK: %[[VAL_100:.*]] = load float, ptr %[[VAL_37]], align 4 -// CHECK: %[[VAL_101:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_100]], i32 16, i32 31) -// CHECK: store float %[[VAL_101]], ptr %[[VAL_25]], align 4 -// CHECK: %[[VAL_102:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_24]], i64 0, i64 0 -// CHECK: store ptr %[[VAL_22]], ptr %[[VAL_102]], align 8 -// CHECK: %[[VAL_103:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_24]], i64 0, i64 1 -// CHECK: store ptr %[[VAL_23]], ptr %[[VAL_103]], align 8 -// CHECK: call void @[[ADD:Add.*]](ptr %[[VAL_39]], ptr %[[VAL_37]], ptr %[[VAL_26]], ptr %[[VAL_25]], ptr %[[VAL_24]]) -// CHECK: %[[VAL_104:.*]] = load float, ptr %[[VAL_22]], align 4 -// CHECK: %[[VAL_105:.*]] = load float, ptr %[[VAL_23]], align 4 -// CHECK: store float %[[VAL_104]], ptr %[[VAL_39]], align 4 -// CHECK: store float %[[VAL_105]], ptr %[[VAL_37]], align 4 -// CHECK: %[[VAL_106:.*]] = load float, ptr %[[VAL_39]], align 4 -// CHECK: %[[VAL_107:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_106]], i32 8, i32 31) -// CHECK: store float %[[VAL_107]], ptr %[[VAL_21]], align 4 -// CHECK: %[[VAL_108:.*]] = load float, ptr %[[VAL_37]], align 4 -// CHECK: %[[VAL_109:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_108]], i32 8, i32 31) -// CHECK: store float %[[VAL_109]], ptr %[[VAL_20]], align 4 -// CHECK: %[[VAL_110:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_19]], i64 0, i64 0 -// CHECK: store ptr %[[VAL_17]], ptr %[[VAL_110]], align 8 -// CHECK: %[[VAL_111:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_19]], i64 0, i64 1 -// CHECK: store ptr %[[VAL_18]], ptr %[[VAL_111]], align 8 -// CHECK: call void @[[ADD]](ptr %[[VAL_39]], ptr %[[VAL_37]], ptr %[[VAL_21]], ptr %[[VAL_20]], ptr %[[VAL_19]]) -// CHECK: %[[VAL_112:.*]] = load float, ptr %[[VAL_17]], align 4 -// CHECK: %[[VAL_113:.*]] = load float, ptr %[[VAL_18]], align 4 -// CHECK: store float %[[VAL_112]], ptr %[[VAL_39]], align 4 -// CHECK: store float %[[VAL_113]], ptr %[[VAL_37]], align 4 -// CHECK: %[[VAL_114:.*]] = load float, ptr %[[VAL_39]], align 4 -// CHECK: %[[VAL_115:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_114]], i32 4, i32 31) -// CHECK: store float %[[VAL_115]], ptr %[[VAL_16]], align 4 -// CHECK: %[[VAL_116:.*]] = load float, ptr %[[VAL_37]], align 4 -// CHECK: %[[VAL_117:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_116]], i32 4, i32 31) -// CHECK: store float %[[VAL_117]], ptr %[[VAL_15]], align 4 -// CHECK: %[[VAL_118:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_14]], i64 0, i64 0 -// CHECK: store ptr %[[VAL_12]], ptr %[[VAL_118]], align 8 -// CHECK: %[[VAL_119:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_14]], i64 0, i64 1 -// CHECK: store ptr %[[VAL_13]], ptr %[[VAL_119]], align 8 -// CHECK: call void @[[ADD]](ptr %[[VAL_39]], ptr %[[VAL_37]], ptr %[[VAL_16]], ptr %[[VAL_15]], ptr %[[VAL_14]]) -// CHECK: %[[VAL_120:.*]] = load float, ptr %[[VAL_12]], align 4 -// CHECK: %[[VAL_121:.*]] = load float, ptr %[[VAL_13]], align 4 -// CHECK: store float %[[VAL_120]], ptr %[[VAL_39]], align 4 -// CHECK: store float %[[VAL_121]], ptr %[[VAL_37]], align 4 -// CHECK: %[[VAL_122:.*]] = load float, ptr %[[VAL_39]], align 4 -// CHECK: %[[VAL_123:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_122]], i32 2, i32 31) -// CHECK: store float %[[VAL_123]], ptr %[[VAL_11]], align 4 -// CHECK: %[[VAL_124:.*]] = load float, ptr %[[VAL_37]], align 4 -// CHECK: %[[VAL_125:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_124]], i32 2, i32 31) -// CHECK: store float %[[VAL_125]], ptr %[[VAL_10]], align 4 -// CHECK: %[[VAL_126:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_9]], i64 0, i64 0 -// CHECK: store ptr %[[VAL_7]], ptr %[[VAL_126]], align 8 -// CHECK: %[[VAL_127:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_9]], i64 0, i64 1 -// CHECK: store ptr %[[VAL_8]], ptr %[[VAL_127]], align 8 -// CHECK: call void @[[ADD]](ptr %[[VAL_39]], ptr %[[VAL_37]], ptr %[[VAL_11]], ptr %[[VAL_10]], ptr %[[VAL_9]]) -// CHECK: %[[VAL_128:.*]] = load float, ptr %[[VAL_7]], align 4 -// CHECK: %[[VAL_129:.*]] = load float, ptr %[[VAL_8]], align 4 -// CHECK: store float %[[VAL_128]], ptr %[[VAL_39]], align 4 -// CHECK: store float %[[VAL_129]], ptr %[[VAL_37]], align 4 -// CHECK: %[[VAL_130:.*]] = load float, ptr %[[VAL_39]], align 4 -// CHECK: %[[VAL_131:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_130]], i32 1, i32 31) -// CHECK: store float %[[VAL_131]], ptr %[[VAL_6]], align 4 -// CHECK: %[[VAL_132:.*]] = load float, ptr %[[VAL_37]], align 4 -// CHECK: %[[VAL_133:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_132]], i32 1, i32 31) -// CHECK: store float %[[VAL_133]], ptr %[[VAL_5]], align 4 -// CHECK: %[[VAL_134:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_4]], i64 0, i64 0 -// CHECK: store ptr %[[VAL_2]], ptr %[[VAL_134]], align 8 -// CHECK: %[[VAL_135:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_4]], i64 0, i64 1 -// CHECK: store ptr %[[VAL_3]], ptr %[[VAL_135]], align 8 -// CHECK: call void @[[ADD]](ptr %[[VAL_39]], ptr %[[VAL_37]], ptr %[[VAL_6]], ptr %[[VAL_5]], ptr %[[VAL_4]]) -// CHECK: %[[VAL_136:.*]] = load float, ptr %[[VAL_2]], align 4 -// CHECK: %[[VAL_137:.*]] = load float, ptr %[[VAL_3]], align 4 -// CHECK: store float %[[VAL_136]], ptr %[[VAL_39]], align 4 -// CHECK: store float %[[VAL_137]], ptr %[[VAL_37]], align 4 -// CHECK: %[[VAL_138:.*]] = udiv i32 %[[VAL_59]], 32 -// CHECK: %[[VAL_139:.*]] = icmp eq i32 %[[VAL_61]], 0 -// CHECK: br i1 %[[VAL_139]], label %[[VAL_140:.*]], label %[[VAL_141:.*]] -// CHECK: intra_warp_reduce_write-after: ; preds = %[[VAL_140]], %[[VAL_81]] -// CHECK: call void @llvm.nvvm.barrier0() -// CHECK: %[[VAL_142:.*]] = icmp eq i32 %[[VAL_138]], 0 -// CHECK: br i1 %[[VAL_142]], label %[[VAL_143:.*]], label %[[VAL_45]] -// CHECK: inter_warp_reduce-after: ; preds = %[[VAL_144:.*]], %[[VAL_141]] +// CHECK: reduce-group-0-true: ; preds = %[[VAL_45]] +// CHECK: %[[VAL_46:.*]] = load float, ptr{{.*}}%[[VAL_47:.*]], align 4, !invariant.load !{{[0-9]}} +// CHECK: store float %[[VAL_46]], ptr{{.*}}%[[VAL_39]], align 4 +// CHECK: %[[VAL_48:.*]] = load float, ptr{{.*}}%[[VAL_47]], align 4, !invariant.load !{{[0-9]}} +// CHECK: store float %[[VAL_48]], ptr{{.*}}%[[VAL_37]], align 4 +// CHECK-PTX: %thread.id.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !4 +// CHECK-GCN: %thread.id.x = call i32 @llvm.amdgcn.workitem.id.x +// CHECK-PTX: %block.id.x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !5 +// CHECK-GCN: %block.id.x = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK: %[[VAL_49:.*]] = udiv i32 %thread.id.x, 32 +// CHECK: %thread.id.1 = urem i32 %[[VAL_49]], 8 +// CHECK: %thread.id.2 = urem i32 %thread.id.x, 32 +// CHECK: %lane_id = urem i32 %thread.id.x, 32 +// CHECK: %[[VAL_50:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_51:.*]] = urem i32 %[[VAL_50]], 1 +// CHECK: %[[VAL_52:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_53:.*]] = urem i32 %[[VAL_52]], 25 +// CHECK: %[[VAL_54:.*]] = udiv i32 %block.id.x, 25 +// CHECK: %[[VAL_55:.*]] = icmp eq i32 %[[VAL_51]], 0 +// CHECK: %tile_bound.2 = select i1 %[[VAL_55]], i32 300, i32 512 +// CHECK: %tile_origin.0 = mul i32 %[[VAL_54]], 5 +// CHECK: %tile_origin.1 = mul i32 %[[VAL_53]], 8 +// CHECK: %tile_origin.2 = mul i32 %[[VAL_51]], 512 +// CHECK: store i32 0, ptr{{.*}}%[[VAL_36]], align 4 +// CHECK: br label %[[VAL_56:.*]] +// CHECK: loop0.loop_header: ; preds = %[[VAL_57:.*]], %[[VAL_43]] +// CHECK: %[[VAL_58:.*]] = load i32, ptr{{.*}}%[[VAL_36]], align 4 +// CHECK: %[[VAL_59:.*]] = icmp uge i32 %[[VAL_58]], 5 +// CHECK: br i1 %[[VAL_59]], label %[[VAL_60:.*]], label %[[VAL_61:.*]] +// CHECK: loop0.loop_body: ; preds = %[[VAL_56]] +// CHECK: %[[VAL_62:.*]] = add nuw nsw i32 %[[VAL_58]], 1 +// CHECK: store i32 %[[VAL_62]], ptr{{.*}}%[[VAL_36]], align 4 +// CHECK: %[[VAL_63:.*]] = icmp eq i32 %[[VAL_58]], 0 +// CHECK: store i32 %thread.id.1, ptr{{.*}}%[[VAL_35]], align 4 +// CHECK: br label %[[VAL_64:.*]] +// CHECK: loop1.loop_header: ; preds = %[[VAL_65:.*]], %[[VAL_61]] +// CHECK: %[[VAL_66:.*]] = load i32, ptr{{.*}}%[[VAL_35]], align 4 +// CHECK: %[[VAL_67:.*]] = icmp uge i32 %[[VAL_66]], 8 +// CHECK: br i1 %[[VAL_67]], label %[[VAL_57]], label %[[VAL_68:.*]] +// CHECK: loop1.loop_body: ; preds = %[[VAL_64]] +// CHECK: %[[VAL_69:.*]] = add nuw nsw i32 %[[VAL_66]], 8 +// CHECK: store i32 %[[VAL_69]], ptr{{.*}}%[[VAL_35]], align 4 +// CHECK: %[[VAL_70:.*]] = icmp eq i32 %[[VAL_66]], %thread.id.1 +// CHECK: %[[VAL_71:.*]] = icmp eq i32 512, %tile_bound.2 +// CHECK: br i1 %[[VAL_71]], label %[[VAL_72:.*]], label %[[VAL_73:.*]] +// CHECK: is_full_tile-after: ; preds = %[[VAL_74:.*]], %[[VAL_75:.*]] +// CHECK: br label %[[VAL_64]], !llvm.loop !{{[0-9]}} +// CHECK: loop1.loop_exit: ; preds = %[[VAL_64]] +// CHECK: br label %[[VAL_56]], !llvm.loop !{{[0-9]}} +// CHECK: loop0.loop_exit: ; preds = %[[VAL_56]] +// CHECK: %[[VAL_76:.*]] = load float, ptr{{.*}}%[[VAL_39]], align 4 +// CHECK-PTX: %[[VAL_77:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_76]], i32 16, i32 31) +// CHECK-GCN: %[[VAL_76_1:.*]] = bitcast float %[[VAL_76]] to i32 +// CHECK-GCN: %[[VAL_77_1:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_76_1:.*]], i32 16) +// CHECK-GCN: %[[VAL_77:.*]] = bitcast i32 %[[VAL_77_1:.*]] to float +// CHECK: store float %[[VAL_77]], ptr{{.*}}%[[VAL_26]], align 4 +// CHECK: %[[VAL_78:.*]] = load float, ptr{{.*}}%[[VAL_37]], align 4 +// CHECK-PTX: %[[VAL_79:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_78]], i32 16, i32 31) +// CHECK-GCN: %[[VAL_78_1:.*]] = bitcast float %[[VAL_78]] to i32 +// CHECK-GCN: %[[VAL_79_1:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_78_1:.*]], i32 16) +// CHECK-GCN: %[[VAL_79:.*]] = bitcast i32 %[[VAL_79_1:.*]] to float +// CHECK: store float %[[VAL_79]], ptr{{.*}}%[[VAL_25]], align 4 +// CHECK-GCN: %[[VAL_22_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_22]] to ptr +// CHECK: %[[VAL_80:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_24]], i64 0, i64 0 +// CHECK-PTX: store ptr %[[VAL_22]], ptr %[[VAL_80]], align 8 +// CHECK-GCN: store ptr %[[VAL_22_1]], ptr{{.*}}%[[VAL_80]], align 8 +// CHECK-GCN: %[[VAL_23_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_23]] to ptr +// CHECK: %[[VAL_81:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_24]], i64 0, i64 1 +// CHECK-PTX: store ptr %[[VAL_23]], ptr %[[VAL_81]], align 8 +// CHECK-GCN: store ptr %[[VAL_23_1]], ptr{{.*}}%[[VAL_81]], align 8 +// CHECK-PTX: call void @[[ADD:Add.*]](ptr %[[VAL_39]], ptr %[[VAL_37]], ptr %[[VAL_26]], ptr %[[VAL_25]], ptr %[[VAL_24]]) +// CHECK-GCN: %[[VAL_39_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_39]] to ptr +// CHECK-GCN: %[[VAL_37_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_37]] to ptr +// CHECK-GCN: %[[VAL_26_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_26]] to ptr +// CHECK-GCN: %[[VAL_25_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_25]] to ptr +// CHECK-GCN: %[[VAL_24_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_24]] to ptr +// CHECK-GCN: call void @[[ADD:Add.*]](ptr %[[VAL_39_1]], ptr %[[VAL_37_1]], ptr %[[VAL_26_1]], ptr %[[VAL_25_1]], ptr %[[VAL_24_1]]) +// CHECK: %[[VAL_82:.*]] = load float, ptr{{.*}}%[[VAL_22]], align 4 +// CHECK: %[[VAL_83:.*]] = load float, ptr{{.*}}%[[VAL_23]], align 4 +// CHECK: store float %[[VAL_82]], ptr{{.*}}%[[VAL_39]], align 4 +// CHECK: store float %[[VAL_83]], ptr{{.*}}%[[VAL_37]], align 4 +// CHECK: %[[VAL_84:.*]] = load float, ptr{{.*}}%[[VAL_39]], align 4 +// CHECK-PTX: %[[VAL_85:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_84]], i32 8, i32 31) +// CHECK-GCN: %[[VAL_84_1:.*]] = bitcast float %[[VAL_84]] to i32 +// CHECK-GCN: %[[VAL_85_1:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_84_1:.*]], i32 8) +// CHECK-GCN: %[[VAL_85:.*]] = bitcast i32 %[[VAL_85_1:.*]] to float +// CHECK: store float %[[VAL_85]], ptr{{.*}}%[[VAL_21]], align 4 +// CHECK: %[[VAL_86:.*]] = load float, ptr{{.*}}%[[VAL_37]], align 4 +// CHECK-PTX: %[[VAL_87:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_86]], i32 8, i32 31) +// CHECK-GCN: %[[VAL_86_1:.*]] = bitcast float %[[VAL_86]] to i32 +// CHECK-GCN: %[[VAL_87_1:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_86_1:.*]], i32 8) +// CHECK-GCN: %[[VAL_87:.*]] = bitcast i32 %[[VAL_87_1:.*]] to float +// CHECK: store float %[[VAL_87]], ptr{{.*}}%[[VAL_20]], align 4 +// CHECK-GCN: %[[VAL_17_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_17]] to ptr +// CHECK: %[[VAL_88:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_19]], i64 0, i64 0 +// CHECK-PTX: store ptr %[[VAL_17]], ptr %[[VAL_88]], align 8 +// CHECK-GCN: store ptr %[[VAL_17_1]], ptr{{.*}}%[[VAL_88]], align 8 +// CHECK-GCN: %[[VAL_18_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_18]] to ptr +// CHECK: %[[VAL_89:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_19]], i64 0, i64 1 +// CHECK-PTX: store ptr %[[VAL_18]], ptr %[[VAL_89]], align 8 +// CHECK-GCN: store ptr %[[VAL_18_1]], ptr{{.*}}%[[VAL_89]], align 8 +// CHECK-PTX: call void @[[ADD:Add.*]](ptr %[[VAL_39]], ptr %[[VAL_37]], ptr %[[VAL_21]], ptr %[[VAL_20]], ptr %[[VAL_19]]) +// CHECK-GCN: %[[VAL_39_2:.*]] = addrspacecast ptr{{.*}}%[[VAL_39]] to ptr +// CHECK-GCN: %[[VAL_37_2:.*]] = addrspacecast ptr{{.*}}%[[VAL_37]] to ptr +// CHECK-GCN: %[[VAL_21_2:.*]] = addrspacecast ptr{{.*}}%[[VAL_21]] to ptr +// CHECK-GCN: %[[VAL_20_2:.*]] = addrspacecast ptr{{.*}}%[[VAL_20]] to ptr +// CHECK-GCN: %[[VAL_19_2:.*]] = addrspacecast ptr{{.*}}%[[VAL_19]] to ptr +// CHECK-GCN: call void @[[ADD:Add.*]](ptr %[[VAL_39_2]], ptr %[[VAL_37_2]], ptr %[[VAL_21_2]], ptr %[[VAL_20_2]], ptr %[[VAL_19_2]]) +// CHECK: %[[VAL_90:.*]] = load float, ptr{{.*}}%[[VAL_17]], align 4 +// CHECK: %[[VAL_91:.*]] = load float, ptr{{.*}}%[[VAL_18]], align 4 +// CHECK: store float %[[VAL_90]], ptr{{.*}}%[[VAL_39]], align 4 +// CHECK: store float %[[VAL_91]], ptr{{.*}}%[[VAL_37]], align 4 +// CHECK: %[[VAL_92:.*]] = load float, ptr{{.*}}%[[VAL_39]], align 4 +// CHECK-PTX: %[[VAL_93:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_92]], i32 4, i32 31) +// CHECK-GCN: %[[VAL_92_1:.*]] = bitcast float %[[VAL_92]] to i32 +// CHECK-GCN: %[[VAL_93_1:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_92_1:.*]], i32 4) +// CHECK-GCN: %[[VAL_93:.*]] = bitcast i32 %[[VAL_93_1:.*]] to float +// CHECK: store float %[[VAL_93]], ptr{{.*}}%[[VAL_16]], align 4 +// CHECK: %[[VAL_94:.*]] = load float, ptr{{.*}}%[[VAL_37]], align 4 +// CHECK-PTX: %[[VAL_95:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_94]], i32 4, i32 31) +// CHECK-GCN: %[[VAL_94_1:.*]] = bitcast float %[[VAL_94]] to i32 +// CHECK-GCN: %[[VAL_95_1:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_94_1:.*]], i32 4) +// CHECK-GCN: %[[VAL_95:.*]] = bitcast i32 %[[VAL_95_1:.*]] to float +// CHECK: store float %[[VAL_95]], ptr{{.*}}%[[VAL_15]], align 4 +// CHECK-GCN: %[[VAL_12_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_12]] to ptr +// CHECK: %[[VAL_96:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_14]], i64 0, i64 0 +// CHECK-PTX: store ptr %[[VAL_12]], ptr %[[VAL_96]], align 8 +// CHECK-GCN: store ptr %[[VAL_12_1]], ptr{{.*}}%[[VAL_96]], align 8 +// CHECK-GCN: %[[VAL_13_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_13]] to ptr +// CHECK: %[[VAL_97:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_14]], i64 0, i64 1 +// CHECK-PTX: store ptr %[[VAL_13]], ptr %[[VAL_97]], align 8 +// CHECK-GCN: store ptr %[[VAL_13_1]], ptr{{.*}}%[[VAL_97]], align 8 +// CHECK-PTX: call void @[[ADD:Add.*]](ptr %[[VAL_39]], ptr %[[VAL_37]], ptr %[[VAL_16]], ptr %[[VAL_15]], ptr %[[VAL_14]]) +// CHECK-GCN: %[[VAL_39_3:.*]] = addrspacecast ptr{{.*}}%[[VAL_39]] to ptr +// CHECK-GCN: %[[VAL_37_3:.*]] = addrspacecast ptr{{.*}}%[[VAL_37]] to ptr +// CHECK-GCN: %[[VAL_16_3:.*]] = addrspacecast ptr{{.*}}%[[VAL_16]] to ptr +// CHECK-GCN: %[[VAL_15_3:.*]] = addrspacecast ptr{{.*}}%[[VAL_15]] to ptr +// CHECK-GCN: %[[VAL_14_3:.*]] = addrspacecast ptr{{.*}}%[[VAL_14]] to ptr +// CHECK-GCN: call void @[[ADD:Add.*]](ptr %[[VAL_39_3]], ptr %[[VAL_37_3]], ptr %[[VAL_16_3]], ptr %[[VAL_15_3]], ptr %[[VAL_14_3]]) +// CHECK: %[[VAL_98:.*]] = load float, ptr{{.*}}%[[VAL_12]], align 4 +// CHECK: %[[VAL_99:.*]] = load float, ptr{{.*}}%[[VAL_13]], align 4 +// CHECK: store float %[[VAL_98]], ptr{{.*}}%[[VAL_39]], align 4 +// CHECK: store float %[[VAL_99]], ptr{{.*}}%[[VAL_37]], align 4 +// CHECK: %[[VAL_100:.*]] = load float, ptr{{.*}}%[[VAL_39]], align 4 +// CHECK-PTX: %[[VAL_101:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_100]], i32 2, i32 31) +// CHECK-GCN: %[[VAL_100_1:.*]] = bitcast float %[[VAL_100]] to i32 +// CHECK-GCN: %[[VAL_101_1:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_100_1:.*]], i32 2) +// CHECK-GCN: %[[VAL_101:.*]] = bitcast i32 %[[VAL_101_1:.*]] to float +// CHECK: store float %[[VAL_101]], ptr{{.*}}%[[VAL_11]], align 4 +// CHECK: %[[VAL_102:.*]] = load float, ptr{{.*}}%[[VAL_37]], align 4 +// CHECK-PTX: %[[VAL_103:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_102]], i32 2, i32 31) +// CHECK-GCN: %[[VAL_102_1:.*]] = bitcast float %[[VAL_102]] to i32 +// CHECK-GCN: %[[VAL_103_1:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_102_1:.*]], i32 2) +// CHECK-GCN: %[[VAL_103:.*]] = bitcast i32 %[[VAL_103_1:.*]] to float +// CHECK: store float %[[VAL_103]], ptr{{.*}}%[[VAL_10]], align 4 +// CHECK-GCN: %[[VAL_7_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_7]] to ptr +// CHECK: %[[VAL_104:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_9]], i64 0, i64 0 +// CHECK-PTX: store ptr %[[VAL_7]], ptr %[[VAL_104]], align 8 +// CHECK-GCN: store ptr %[[VAL_7_1]], ptr{{.*}}%[[VAL_104]], align 8 +// CHECK-GCN: %[[VAL_8_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_8]] to ptr +// CHECK: %[[VAL_105:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_9]], i64 0, i64 1 +// CHECK-PTX: store ptr %[[VAL_8]], ptr %[[VAL_105]], align 8 +// CHECK-GCN: store ptr %[[VAL_8_1]], ptr{{.*}}%[[VAL_105]], align 8 +// CHECK-PTX: call void @[[ADD:Add.*]](ptr %[[VAL_39]], ptr %[[VAL_37]], ptr %[[VAL_11]], ptr %[[VAL_10]], ptr %[[VAL_9]]) +// CHECK-GCN: %[[VAL_39_4:.*]] = addrspacecast ptr{{.*}}%[[VAL_39]] to ptr +// CHECK-GCN: %[[VAL_37_4:.*]] = addrspacecast ptr{{.*}}%[[VAL_37]] to ptr +// CHECK-GCN: %[[VAL_11_4:.*]] = addrspacecast ptr{{.*}}%[[VAL_11]] to ptr +// CHECK-GCN: %[[VAL_10_4:.*]] = addrspacecast ptr{{.*}}%[[VAL_10]] to ptr +// CHECK-GCN: %[[VAL_9_4:.*]] = addrspacecast ptr{{.*}}%[[VAL_9]] to ptr +// CHECK-GCN: call void @[[ADD:Add.*]](ptr %[[VAL_39_4]], ptr %[[VAL_37_4]], ptr %[[VAL_11_4]], ptr %[[VAL_10_4]], ptr %[[VAL_9_4]]) +// CHECK: %[[VAL_106:.*]] = load float, ptr{{.*}}%[[VAL_7]], align 4 +// CHECK: %[[VAL_107:.*]] = load float, ptr{{.*}}%[[VAL_8]], align 4 +// CHECK: store float %[[VAL_106]], ptr{{.*}}%[[VAL_39]], align 4 +// CHECK: store float %[[VAL_107]], ptr{{.*}}%[[VAL_37]], align 4 +// CHECK: %[[VAL_108:.*]] = load float, ptr{{.*}}%[[VAL_39]], align 4 +// CHECK-PTX: %[[VAL_109:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_108]], i32 1, i32 31) +// CHECK-GCN: %[[VAL_108_1:.*]] = bitcast float %[[VAL_108]] to i32 +// CHECK-GCN: %[[VAL_109_1:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_108_1:.*]], i32 1) +// CHECK-GCN: %[[VAL_109:.*]] = bitcast i32 %[[VAL_109_1:.*]] to float +// CHECK: store float %[[VAL_109]], ptr{{.*}}%[[VAL_6]], align 4 +// CHECK: %[[VAL_110:.*]] = load float, ptr{{.*}}%[[VAL_37]], align 4 +// CHECK-PTX: %[[VAL_111:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_110]], i32 1, i32 31) +// CHECK-GCN: %[[VAL_110_1:.*]] = bitcast float %[[VAL_110]] to i32 +// CHECK-GCN: %[[VAL_111_1:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_110_1:.*]], i32 1) +// CHECK-GCN: %[[VAL_111:.*]] = bitcast i32 %[[VAL_111_1:.*]] to float +// CHECK: store float %[[VAL_111]], ptr{{.*}}%[[VAL_5]], align 4 +// CHECK-GCN: %[[VAL_2_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_2]] to ptr +// CHECK: %[[VAL_112:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_4]], i64 0, i64 0 +// CHECK-PTX: store ptr %[[VAL_2]], ptr %[[VAL_112]], align 8 +// CHECK-GCN: store ptr %[[VAL_2_1]], ptr{{.*}}%[[VAL_112]], align 8 +// CHECK-GCN: %[[VAL_3_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_3]] to ptr +// CHECK: %[[VAL_113:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_4]], i64 0, i64 1 +// CHECK-PTX: store ptr %[[VAL_3]], ptr %[[VAL_113]], align 8 +// CHECK-GCN: store ptr %[[VAL_3_1]], ptr{{.*}}%[[VAL_113]], align 8 +// CHECK-PTX: call void @[[ADD:Add.*]](ptr %[[VAL_39]], ptr %[[VAL_37]], ptr %[[VAL_6]], ptr %[[VAL_5]], ptr %[[VAL_4]]) +// CHECK-GCN: %[[VAL_39_5:.*]] = addrspacecast ptr{{.*}}%[[VAL_39]] to ptr +// CHECK-GCN: %[[VAL_37_5:.*]] = addrspacecast ptr{{.*}}%[[VAL_37]] to ptr +// CHECK-GCN: %[[VAL_6_5:.*]] = addrspacecast ptr{{.*}}%[[VAL_6]] to ptr +// CHECK-GCN: %[[VAL_5_5:.*]] = addrspacecast ptr{{.*}}%[[VAL_5]] to ptr +// CHECK-GCN: %[[VAL_4_5:.*]] = addrspacecast ptr{{.*}}%[[VAL_4]] to ptr +// CHECK-GCN: call void @[[ADD:Add.*]](ptr %[[VAL_39_5]], ptr %[[VAL_37_5]], ptr %[[VAL_6_5]], ptr %[[VAL_5_5]], ptr %[[VAL_4_5]]) +// CHECK: %[[VAL_114:.*]] = load float, ptr{{.*}}%[[VAL_2]], align 4 +// CHECK: %[[VAL_115:.*]] = load float, ptr{{.*}}%[[VAL_3]], align 4 +// CHECK: store float %[[VAL_114]], ptr{{.*}}%[[VAL_39]], align 4 +// CHECK: store float %[[VAL_115]], ptr{{.*}}%[[VAL_37]], align 4 +// CHECK: %[[VAL_116:.*]] = udiv i32 %thread.id.2, 32 +// CHECK: %[[VAL_117:.*]] = icmp ult i32 %thread.id.1, 8 +// CHECK: br i1 %[[VAL_117]], label %thread_in_bounds-true, label %thread_in_bounds-after +// CHECK: thread_in_bounds-after: ; preds = %[[VAL_118:.*]], %[[VAL_60]] // CHECK: br label %[[VAL_44]] -// CHECK: early_return: ; preds = %[[VAL_43]] -// CHECK: ret void -// CHECK: is_full_tile-true: ; preds = %[[VAL_90]] -// CHECK: store i32 0, ptr %[[VAL_34]], align 4 -// CHECK: br label %[[VAL_146:.*]] -// CHECK: loop2.loop_header: ; preds = %[[VAL_147:.*]], %[[VAL_94]] -// CHECK: %[[VAL_148:.*]] = load i32, ptr %[[VAL_34]], align 4 -// CHECK: %[[VAL_149:.*]] = icmp uge i32 %[[VAL_148]], 16 -// CHECK: br i1 %[[VAL_149]], label %[[VAL_97]], label %[[VAL_147]] -// CHECK: loop2.loop_body: ; preds = %[[VAL_146]] -// CHECK: %[[VAL_150:.*]] = add nuw nsw i32 %[[VAL_148]], 1 -// CHECK: store i32 %[[VAL_150]], ptr %[[VAL_34]], align 4 -// CHECK: %[[VAL_151:.*]] = icmp eq i32 %[[VAL_148]], 0 -// CHECK: %[[VAL_152:.*]] = mul i32 %[[VAL_148]], 32 -// CHECK: %[[VAL_153:.*]] = add i32 %[[VAL_152]], 0 -// CHECK: %[[VAL_154:.*]] = add i32 %[[VAL_153]], %[[THREAD_X]] -// CHECK: %[[VAL_85:.*]] = add i32 %[[VAL_71]], %[[VAL_79]] -// CHECK: %[[VAL_155:.*]] = add i32 %[[VAL_72]], %[[VAL_88]] -// CHECK: %[[VAL_156:.*]] = add i32 %[[VAL_73]], %[[VAL_154]] -// CHECK: %[[VAL_157:.*]] = getelementptr inbounds [5 x [200 x [300 x float]]], ptr %[[VAL_158:.*]], i32 0, i32 %[[VAL_85]], i32 %[[VAL_155]], i32 %[[VAL_156]] -// CHECK: %[[VAL_159:.*]] = load float, ptr %[[VAL_157]], align 4, !invariant.load !3 -// CHECK: store float %[[VAL_159]], ptr %[[VAL_40]], align 4 -// CHECK: %[[VAL_161:.*]] = getelementptr inbounds [5 x [200 x [300 x float]]], ptr %[[VAL_162:.*]], i32 0, i32 %[[VAL_85]], i32 %[[VAL_155]], i32 %[[VAL_156]] -// CHECK: %[[VAL_163:.*]] = load float, ptr %[[VAL_161]], align 4, !invariant.load !3 -// CHECK: store float %[[VAL_163]], ptr %[[VAL_38]], align 4 -// CHECK: %[[VAL_165:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_33]], i64 0, i64 0 -// CHECK: store ptr %[[VAL_31]], ptr %[[VAL_165]], align 8 -// CHECK: %[[VAL_166:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_33]], i64 0, i64 1 -// CHECK: store ptr %[[VAL_32]], ptr %[[VAL_166]], align 8 -// CHECK: call void @[[ADD]](ptr %[[VAL_39]], ptr %[[VAL_37]], ptr %[[VAL_40]], ptr %[[VAL_38]], ptr %[[VAL_33]]) -// CHECK: %[[VAL_167:.*]] = load float, ptr %[[VAL_31]], align 4 -// CHECK: %[[VAL_168:.*]] = load float, ptr %[[VAL_32]], align 4 -// CHECK: store float %[[VAL_167]], ptr %[[VAL_39]], align 4 -// CHECK: store float %[[VAL_168]], ptr %[[VAL_37]], align 4 -// CHECK: br label %[[VAL_146]], !llvm.loop !9 -// CHECK: loop2.loop_exit: ; preds = %[[VAL_146]] -// CHECK: br label %[[VAL_87]] -// CHECK: is_full_tile-false: ; preds = %[[VAL_90]] -// CHECK: store i32 0, ptr %[[VAL_30]], align 4 -// CHECK: br label %[[VAL_170:.*]] -// CHECK: loop2.loop_header9: ; preds = %[[VAL_171:.*]], %[[VAL_95]] -// CHECK: %[[VAL_172:.*]] = load i32, ptr %[[VAL_30]], align 4 -// CHECK: %[[VAL_173:.*]] = icmp uge i32 %[[VAL_172]], 16 -// CHECK: br i1 %[[VAL_173]], label %[[VAL_96]], label %[[VAL_174:.*]] -// CHECK: loop2.loop_body10: ; preds = %[[VAL_170]] -// CHECK: %[[VAL_175:.*]] = add nuw nsw i32 %[[VAL_172]], 1 -// CHECK: store i32 %[[VAL_175]], ptr %[[VAL_30]], align 4 -// CHECK: %[[VAL_176:.*]] = icmp eq i32 %[[VAL_172]], 0 -// CHECK: %[[VAL_177:.*]] = mul i32 %[[VAL_172]], 32 -// CHECK: %[[VAL_178:.*]] = add i32 %[[VAL_177]], 0 -// CHECK: %[[VAL_179:.*]] = add i32 %[[VAL_178]], %[[THREAD_X]] -// CHECK: %[[VAL_180:.*]] = icmp ult i32 %[[VAL_179]], %[[VAL_70]] -// CHECK: br i1 %[[VAL_180]], label %[[VAL_181:.*]], label %[[VAL_171]] -// CHECK: x_in_tile-after: ; preds = %[[VAL_181]], %[[VAL_174]] -// CHECK: br label %[[VAL_170]], !llvm.loop !11 -// CHECK: loop2.loop_exit8: ; preds = %[[VAL_170]] -// CHECK: br label %[[VAL_87]] -// CHECK: x_in_tile-true: ; preds = %[[VAL_174]] -// CHECK: %[[VAL_85:.*]] = add i32 %[[VAL_71]], %[[VAL_79]] -// CHECK: %[[VAL_182:.*]] = add i32 %[[VAL_72]], %[[VAL_88]] -// CHECK: %[[VAL_183:.*]] = add i32 %[[VAL_73]], %[[VAL_179]] -// CHECK: %[[VAL_184:.*]] = getelementptr inbounds [5 x [200 x [300 x float]]], ptr %[[VAL_158]], i32 0, i32 %[[VAL_85]], i32 %[[VAL_182]], i32 %[[VAL_183]] -// CHECK: %[[VAL_185:.*]] = load float, ptr %[[VAL_184]], align 4, !invariant.load !3 -// CHECK: store float %[[VAL_185]], ptr %[[VAL_40]], align 4 -// CHECK: %[[VAL_187:.*]] = getelementptr inbounds [5 x [200 x [300 x float]]], ptr %[[VAL_162]], i32 0, i32 %[[VAL_85]], i32 %[[VAL_182]], i32 %[[VAL_183]] -// CHECK: %[[VAL_188:.*]] = load float, ptr %[[VAL_187]], align 4, !invariant.load !3 -// CHECK: store float %[[VAL_188]], ptr %[[VAL_38]], align 4 -// CHECK: %[[VAL_190:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_29]], i64 0, i64 0 -// CHECK: store ptr %[[VAL_27]], ptr %[[VAL_190]], align 8 -// CHECK: %[[VAL_191:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_29]], i64 0, i64 1 -// CHECK: store ptr %[[VAL_28]], ptr %[[VAL_191]], align 8 -// CHECK: call void @[[ADD]](ptr %[[VAL_39]], ptr %[[VAL_37]], ptr %[[VAL_40]], ptr %[[VAL_38]], ptr %[[VAL_29]]) -// CHECK: %[[VAL_192:.*]] = load float, ptr %[[VAL_27]], align 4 -// CHECK: %[[VAL_193:.*]] = load float, ptr %[[VAL_28]], align 4 -// CHECK: store float %[[VAL_192]], ptr %[[VAL_39]], align 4 -// CHECK: store float %[[VAL_193]], ptr %[[VAL_37]], align 4 -// CHECK: br label %[[VAL_171]] -// CHECK: intra_warp_reduce_write-true: ; preds = %[[VAL_81]] -// CHECK: %[[VAL_196:.*]] = load float, ptr %[[VAL_39]], align 4 -// CHECK: %[[VAL_194:.*]] = getelementptr inbounds [1 x [1 x float]], ptr addrspace(3) @shared_cache, i32 0, i32 %[[VAL_53]], i32 %[[VAL_138]] -// CHECK: %[[VAL_195:.*]] = addrspacecast ptr addrspace(3) %[[VAL_194]] to ptr -// CHECK: store float %[[VAL_196]], ptr %[[VAL_195]], align 4 -// CHECK: %[[VAL_199:.*]] = load float, ptr %[[VAL_37]], align 4 -// CHECK: %[[VAL_197:.*]] = getelementptr inbounds [1 x [1 x float]], ptr addrspace(3) @shared_cache1, i32 0, i32 %[[VAL_53]], i32 %[[VAL_138]] -// CHECK: %[[VAL_198:.*]] = addrspacecast ptr addrspace(3) %[[VAL_197]] to ptr -// CHECK: store float %[[VAL_199]], ptr %[[VAL_198]], align 4 -// CHECK: br label %[[VAL_141]] -// CHECK: inter_warp_reduce-true: ; preds = %[[VAL_141]] -// CHECK: %[[VAL_200:.*]] = getelementptr inbounds [1 x [1 x float]], ptr addrspace(3) @shared_cache, i32 0, i32 %[[VAL_53]], i32 %[[VAL_61]] -// CHECK: %[[VAL_201:.*]] = addrspacecast ptr addrspace(3) %[[VAL_200]] to ptr -// CHECK: store float %[[VAL_47]], ptr %[[VAL_1]], align 4 -// CHECK: %[[VAL_202:.*]] = icmp ult i32 %[[VAL_59]], 1 -// CHECK: %[[VAL_203:.*]] = select i1 %[[VAL_202]], ptr %[[VAL_201]], ptr %[[VAL_1]] -// CHECK: %[[VAL_204:.*]] = getelementptr inbounds [1 x [1 x float]], ptr addrspace(3) @shared_cache1, i32 0, i32 %[[VAL_53]], i32 %[[VAL_61]] -// CHECK: %[[VAL_205:.*]] = addrspacecast ptr addrspace(3) %[[VAL_204]] to ptr -// CHECK: store float %[[VAL_49]], ptr %[[VAL_0]], align 4 -// CHECK: %[[VAL_206:.*]] = icmp ult i32 %[[VAL_59]], 1 -// CHECK: %[[VAL_207:.*]] = select i1 %[[VAL_206]], ptr %[[VAL_205]], ptr %[[VAL_0]] -// CHECK: %[[VAL_208:.*]] = icmp eq i32 %[[VAL_59]], 0 -// CHECK: br i1 %[[VAL_208]], label %[[VAL_209:.*]], label %[[VAL_144]] -// CHECK: reduction_write_output-after: ; preds = %[[VAL_209]], %[[VAL_143]] -// CHECK: br label %[[VAL_45]] -// CHECK: reduction_write_output-true: ; preds = %[[VAL_143]] -// CHECK: %[[VAL_211:.*]] = add i32 %[[VAL_72]], %[[VAL_60]] -// CHECK: %[[VAL_212:.*]] = add i32 %[[VAL_73]], %[[THREAD_X]] -// CHECK: %[[VAL_213:.*]] = udiv i32 %[[VAL_211]], 1 -// CHECK: %[[VAL_214:.*]] = getelementptr inbounds [200 x float], ptr %[[VAL_215:.*]], i32 0, i32 %[[VAL_213]] -// CHECK: %[[VAL_216:.*]] = load float, ptr %[[VAL_203]], align 4 -// CHECK: store float %[[VAL_216]], ptr %[[VAL_214]], align 4 -// CHECK: %[[VAL_218:.*]] = add i32 %[[VAL_72]], %[[VAL_60]] -// CHECK: %[[VAL_219:.*]] = add i32 %[[VAL_73]], %[[THREAD_X]] -// CHECK: %[[VAL_220:.*]] = udiv i32 %[[VAL_218]], 1 -// CHECK: %[[VAL_221:.*]] = getelementptr inbounds [200 x float], ptr %[[VAL_222:.*]], i32 0, i32 %[[VAL_220]] -// CHECK: %[[VAL_223:.*]] = load float, ptr %[[VAL_207]], align 4 -// CHECK: store float %[[VAL_223]], ptr %[[VAL_221]], align 4 -// CHECK: br label %[[VAL_144]] +// CHECK: is_full_tile-true: ; preds = %[[VAL_68]] +// CHECK: store i32 0, ptr{{.*}}%[[VAL_34]], align 4 +// CHECK: br label %[[VAL_119:.*]] +// CHECK: loop2.loop_header: ; preds = %[[VAL_120:.*]], %[[VAL_72]] +// CHECK: %[[VAL_121:.*]] = load i32, ptr{{.*}}%[[VAL_34]], align 4 +// CHECK: %[[VAL_122:.*]] = icmp uge i32 %[[VAL_121]], 512 +// CHECK: br i1 %[[VAL_122]], label %[[VAL_75]], label %[[VAL_120]] +// CHECK: loop2.loop_body: ; preds = %[[VAL_119]] +// CHECK: %[[VAL_123:.*]] = add nuw nsw i32 %[[VAL_121]], 32 +// CHECK: store i32 %[[VAL_123]], ptr{{.*}}%[[VAL_34]], align 4 +// CHECK: %[[VAL_124:.*]] = icmp eq i32 %[[VAL_121]], 0 +// CHECK: %[[VAL_125:.*]] = add i32 %[[VAL_121]], %thread.id.2 +// CHECK: %[[VAL_126:.*]] = add i32 %tile_origin.0, %[[VAL_58]] +// CHECK: %[[VAL_127:.*]] = add i32 %tile_origin.1, %[[VAL_66]] +// CHECK: %[[VAL_128:.*]] = add i32 %tile_origin.2, %[[VAL_125]] +// CHECK: %[[VAL_129:.*]] = getelementptr inbounds [5 x [200 x [300 x float]]], ptr{{.*}}%[[VAL_130:.*]], i32 0, i32 %[[VAL_126]], i32 %[[VAL_127]], i32 %[[VAL_128]] +// CHECK: %[[VAL_131:.*]] = load float, ptr{{.*}}%[[VAL_129]], align 4, !invariant.load !{{[0-9]}} +// CHECK: store float %[[VAL_131]], ptr{{.*}}%[[VAL_40]], align 4 +// CHECK: %[[VAL_132:.*]] = getelementptr inbounds [5 x [200 x [300 x float]]], ptr{{.*}}%[[VAL_133:.*]], i32 0, i32 %[[VAL_126]], i32 %[[VAL_127]], i32 %[[VAL_128]] +// CHECK: %[[VAL_134:.*]] = load float, ptr{{.*}}%[[VAL_132]], align 4, !invariant.load !{{[0-9]}} +// CHECK: store float %[[VAL_134]], ptr{{.*}}%[[VAL_38]], align 4 +// CHECK-GCN: %[[VAL_31_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_31]] to ptr +// CHECK: %[[VAL_135:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_33]], i64 0, i64 0 +// CHECK-PTX: store ptr %[[VAL_31]], ptr %[[VAL_135]], align 8 +// CHECK-GCN: store ptr %[[VAL_31_1]], ptr{{.*}}%[[VAL_135]], align 8 +// CHECK-GCN: %[[VAL_32_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_32]] to ptr +// CHECK: %[[VAL_136:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_33]], i64 0, i64 1 +// CHECK-PTX: store ptr %[[VAL_32]], ptr %[[VAL_136]], align 8 +// CHECK-GCN: store ptr %[[VAL_32_1]], ptr{{.*}}%[[VAL_136]], align 8 +// CHECK-PTX: call void @[[ADD:Add.*]](ptr %[[VAL_39]], ptr %[[VAL_37]], ptr %[[VAL_40]], ptr %[[VAL_38]], ptr %[[VAL_33]]) +// CHECK-GCN: %[[VAL_39_6:.*]] = addrspacecast ptr{{.*}}%[[VAL_39]] to ptr +// CHECK-GCN: %[[VAL_37_6:.*]] = addrspacecast ptr{{.*}}%[[VAL_37]] to ptr +// CHECK-GCN: %[[VAL_40_6:.*]] = addrspacecast ptr{{.*}}%[[VAL_40]] to ptr +// CHECK-GCN: %[[VAL_38_6:.*]] = addrspacecast ptr{{.*}}%[[VAL_38]] to ptr +// CHECK-GCN: %[[VAL_33_6:.*]] = addrspacecast ptr{{.*}}%[[VAL_33]] to ptr +// CHECK-GCN: call void @[[ADD:Add.*]](ptr %[[VAL_39_6]], ptr %[[VAL_37_6]], ptr %[[VAL_40_6]], ptr %[[VAL_38_6]], ptr %[[VAL_33_6]]) +// CHECK: %[[VAL_137:.*]] = load float, ptr{{.*}}%[[VAL_31]], align 4 +// CHECK: %[[VAL_138:.*]] = load float, ptr{{.*}}%[[VAL_32]], align 4 +// CHECK: store float %[[VAL_137]], ptr{{.*}}%[[VAL_39]], align 4 +// CHECK: store float %[[VAL_138]], ptr{{.*}}%[[VAL_37]], align 4 +// CHECK: br label %[[VAL_119]], !llvm.loop !{{[0-9]}} +// CHECK: loop2.loop_exit: ; preds = %[[VAL_119]] +// CHECK: br label %[[VAL_65]] +// CHECK: is_full_tile-false: ; preds = %[[VAL_68]] +// CHECK: store i32 0, ptr{{.*}}%[[VAL_30]], align 4 +// CHECK: br label %[[VAL_139:.*]] +// CHECK: loop2.loop_header9: ; preds = %[[VAL_140:.*]], %[[VAL_73]] +// CHECK: %[[VAL_141:.*]] = load i32, ptr{{.*}}%[[VAL_30]], align 4 +// CHECK: %[[VAL_142:.*]] = icmp uge i32 %[[VAL_141]], 512 +// CHECK: br i1 %[[VAL_142]], label %[[VAL_74]], label %[[VAL_143:.*]] +// CHECK: loop2.loop_body10: ; preds = %[[VAL_139]] +// CHECK: %[[VAL_144:.*]] = add nuw nsw i32 %[[VAL_141]], 32 +// CHECK: store i32 %[[VAL_144]], ptr{{.*}}%[[VAL_30]], align 4 +// CHECK: %[[VAL_145:.*]] = icmp eq i32 %[[VAL_141]], 0 +// CHECK: %[[VAL_146:.*]] = add i32 %[[VAL_141]], %thread.id.2 +// CHECK: %[[VAL_147:.*]] = icmp ult i32 %[[VAL_146]], %tile_bound.2 +// CHECK: br i1 %[[VAL_147]], label %[[VAL_148:.*]], label %[[VAL_140]] +// CHECK: x_in_tile-after: ; preds = %[[VAL_148]], %[[VAL_143]] +// CHECK: br label %[[VAL_139]], !llvm.loop !{{[0-9]}} +// CHECK: loop2.loop_exit8: ; preds = %[[VAL_139]] +// CHECK: br label %[[VAL_65]] +// CHECK: x_in_tile-true: ; preds = %[[VAL_143]] +// CHECK: %[[VAL_149:.*]] = add i32 %tile_origin.0, %[[VAL_58]] +// CHECK: %[[VAL_150:.*]] = add i32 %tile_origin.1, %[[VAL_66]] +// CHECK: %[[VAL_151:.*]] = add i32 %tile_origin.2, %[[VAL_146]] +// CHECK: %[[VAL_152:.*]] = getelementptr inbounds [5 x [200 x [300 x float]]], ptr{{.*}}%[[VAL_130]], i32 0, i32 %[[VAL_149]], i32 %[[VAL_150]], i32 %[[VAL_151]] +// CHECK: %[[VAL_153:.*]] = load float, ptr{{.*}}%[[VAL_152]], align 4, !invariant.load !{{[0-9]}} +// CHECK: store float %[[VAL_153]], ptr{{.*}}%[[VAL_40]], align 4 +// CHECK: %[[VAL_154:.*]] = getelementptr inbounds [5 x [200 x [300 x float]]], ptr{{.*}}%[[VAL_133]], i32 0, i32 %[[VAL_149]], i32 %[[VAL_150]], i32 %[[VAL_151]] +// CHECK: %[[VAL_155:.*]] = load float, ptr{{.*}}%[[VAL_154]], align 4, !invariant.load !{{[0-9]}} +// CHECK: store float %[[VAL_155]], ptr{{.*}}%[[VAL_38]], align 4 +// CHECK-GCN: %[[VAL_27_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_27]] to ptr +// CHECK: %[[VAL_156:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_29]], i64 0, i64 0 +// CHECK-PTX: store ptr %[[VAL_27]], ptr %[[VAL_156]], align 8 +// CHECK-GCN: store ptr %[[VAL_27_1]], ptr{{.*}}%[[VAL_156]], align 8 +// CHECK-GCN: %[[VAL_28_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_28]] to ptr +// CHECK: %[[VAL_157:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_29]], i64 0, i64 1 +// CHECK-PTX: store ptr %[[VAL_28]], ptr %[[VAL_157]], align 8 +// CHECK-GCN: store ptr %[[VAL_28_1]], ptr{{.*}}%[[VAL_157]], align 8 +// CHECK-PTX: call void @[[ADD:Add.*]](ptr %[[VAL_39]], ptr %[[VAL_37]], ptr %[[VAL_40]], ptr %[[VAL_38]], ptr %[[VAL_29]]) +// CHECK-GCN: %[[VAL_39_7:.*]] = addrspacecast ptr{{.*}}%[[VAL_39]] to ptr +// CHECK-GCN: %[[VAL_37_7:.*]] = addrspacecast ptr{{.*}}%[[VAL_37]] to ptr +// CHECK-GCN: %[[VAL_40_7:.*]] = addrspacecast ptr{{.*}}%[[VAL_40]] to ptr +// CHECK-GCN: %[[VAL_38_7:.*]] = addrspacecast ptr{{.*}}%[[VAL_38]] to ptr +// CHECK-GCN: %[[VAL_29_7:.*]] = addrspacecast ptr{{.*}}%[[VAL_29]] to ptr +// CHECK-GCN: call void @[[ADD:Add.*]](ptr %[[VAL_39_7]], ptr %[[VAL_37_7]], ptr %[[VAL_40_7]], ptr %[[VAL_38_7]], ptr %[[VAL_29_7]]) +// CHECK: %[[VAL_158:.*]] = load float, ptr{{.*}}%[[VAL_27]], align 4 +// CHECK: %[[VAL_159:.*]] = load float, ptr{{.*}}%[[VAL_28]], align 4 +// CHECK: store float %[[VAL_158]], ptr{{.*}}%[[VAL_39]], align 4 +// CHECK: store float %[[VAL_159]], ptr{{.*}}%[[VAL_37]], align 4 +// CHECK: br label %[[VAL_140]] +// CHECK: thread_in_bounds-true: ; preds = %[[VAL_60]] +// CHECK: %[[VAL_160:.*]] = icmp eq i32 %lane_id, 0 +// CHECK: br i1 %[[VAL_160]], label %[[VAL_161:.*]], label %[[VAL_162:.*]] +// CHECK: intra_warp_reduce_write-after: ; preds = %[[VAL_161]], %thread_in_bounds-true +// CHECK-PTX: call void @llvm.nvvm.barrier0() +// CHECK-GCN: call void @llvm.amdgcn.s.barrier() +// CHECK: %[[VAL_163:.*]] = icmp eq i32 %[[VAL_116]], 0 +// CHECK: br i1 %[[VAL_163]], label %[[VAL_164:.*]], label %[[VAL_118]] +// CHECK: inter_warp_reduce-after: ; preds = %[[VAL_165:.*]], %[[VAL_162]] +// CHECK: br label %thread_in_bounds-after +// CHECK: intra_warp_reduce_write-true: ; preds = %thread_in_bounds-true +// CHECK: %[[VAL_166:.*]] = load float, ptr{{.*}}%[[VAL_39]], align 4 +// CHECK: %[[VAL_167:.*]] = getelementptr inbounds [8 x [1 x float]], ptr addrspace(3) @shared_cache, i32 0, i32 %thread.id.1, i32 %[[VAL_116]] +// CHECK: %[[VAL_168:.*]] = addrspacecast ptr addrspace(3) %[[VAL_167]] to ptr +// CHECK: store float %[[VAL_166]], ptr{{.*}}%[[VAL_168]], align 4 +// CHECK: %[[VAL_169:.*]] = load float, ptr{{.*}}%[[VAL_37]], align 4 +// CHECK: %[[VAL_170:.*]] = getelementptr inbounds [8 x [1 x float]], ptr addrspace(3) @shared_cache{{.*}}, i32 0, i32 %thread.id.1, i32 %[[VAL_116]] +// CHECK: %[[VAL_171:.*]] = addrspacecast ptr addrspace(3) %[[VAL_170]] to ptr +// CHECK: store float %[[VAL_169]], ptr{{.*}}%[[VAL_171]], align 4 +// CHECK: br label %[[VAL_162]] +// CHECK: inter_warp_reduce-true: ; preds = %[[VAL_162]] +// CHECK: %[[VAL_172:.*]] = getelementptr inbounds [8 x [1 x float]], ptr addrspace(3) @shared_cache, i32 0, i32 %thread.id.1, i32 %lane_id +// CHECK: %[[VAL_173:.*]] = addrspacecast ptr addrspace(3) %[[VAL_172]] to ptr +// CHECK-GCN: %[[VAL_1_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_1]] to ptr +// CHECK-PTX: store float %[[VAL_46]], ptr %[[VAL_1]], align 4 +// CHECK-GCN: store float %[[VAL_46]], ptr %[[VAL_1_1]], align 4 +// CHECK: %[[VAL_174:.*]] = icmp ult i32 %thread.id.2, 1 +// CHECK-PTX: %[[VAL_175:.*]] = select i1 %[[VAL_174]], ptr %[[VAL_173]], ptr %[[VAL_1]] +// CHECK-GCN: %[[VAL_175:.*]] = select i1 %[[VAL_174]], ptr %[[VAL_173]], ptr %[[VAL_1_1]] +// CHECK: %[[VAL_176:.*]] = getelementptr inbounds [8 x [1 x float]], ptr addrspace(3) @shared_cache{{.*}}, i32 0, i32 %thread.id.1, i32 %lane_id +// CHECK: %[[VAL_177:.*]] = addrspacecast ptr addrspace(3) %[[VAL_176]] to ptr +// CHECK-GCN: %[[VAL_0_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_0]] to ptr +// CHECK-PTX: store float %[[VAL_48]], ptr{{.*}}%[[VAL_0]], align 4 +// CHECK-GCN: store float %[[VAL_48]], ptr{{.*}}%[[VAL_0_1]], align 4 +// CHECK: %[[VAL_178:.*]] = icmp ult i32 %thread.id.2, 1 +// CHECK-PTX: %[[VAL_179:.*]] = select i1 %[[VAL_178]], ptr{{.*}}%[[VAL_177]], ptr %[[VAL_0]] +// CHECK-GCN: %[[VAL_179:.*]] = select i1 %[[VAL_178]], ptr{{.*}}%[[VAL_177]], ptr %[[VAL_0_1]] +// CHECK: %[[VAL_180:.*]] = icmp eq i32 %thread.id.2, 0 +// CHECK: br i1 %[[VAL_180]], label %[[VAL_181:.*]], label %[[VAL_165]] +// CHECK: reduction_write_output-after: ; preds = %[[VAL_181]], %[[VAL_164]] +// CHECK: br label %[[VAL_118]] +// CHECK: reduction_write_output-true: ; preds = %[[VAL_164]] +// CHECK: %[[VAL_183:.*]] = add i32 %tile_origin.1, %thread.id.1 +// CHECK: %[[VAL_186:.*]] = getelementptr inbounds [200 x float], ptr{{.*}}%[[VAL_187:.*]], i32 0, i32 %[[VAL_183]] +// CHECK: %[[VAL_188:.*]] = load float, ptr{{.*}}%[[VAL_175]], align 4 +// CHECK: store float %[[VAL_188]], ptr{{.*}}%[[VAL_186]], align 4 +// CHECK: %[[VAL_190:.*]] = add i32 %tile_origin.1, %thread.id.1 +// CHECK: %[[VAL_193:.*]] = getelementptr inbounds [200 x float], ptr{{.*}}%[[VAL_194:.*]], i32 0, i32 %[[VAL_190]] +// CHECK: %[[VAL_195:.*]] = load float, ptr{{.*}}%[[VAL_179]], align 4 +// CHECK: store float %[[VAL_195]], ptr{{.*}}%[[VAL_193]], align 4 +// CHECK: br label %[[VAL_165]] // CHECK: entry: -// CHECK: %[[VAL_224:.*]] = alloca float, align 4 -// CHECK: %[[VAL_225:.*]] = alloca float, align 4 -// CHECK: %[[VAL_226:.*]] = alloca [2 x ptr], align 8 -// CHECK: %[[VAL_227:.*]] = alloca [2 x ptr], align 8 -// CHECK: %[[VAL_228:.*]] = alloca [2 x ptr], align 8 -// CHECK: %[[VAL_229:.*]] = load float, ptr %[[VAL_230:.*]], align 4 -// CHECK: %[[VAL_231:.*]] = load float, ptr %[[VAL_232:.*]], align 4 -// CHECK: %[[VAL_233:.*]] = fadd float %[[VAL_229]], %[[VAL_231]] -// CHECK: store float %[[VAL_233]], ptr %[[VAL_225]], align 4 -// CHECK: %[[VAL_234:.*]] = load float, ptr %[[VAL_235:.*]], align 4 -// CHECK: %[[VAL_236:.*]] = load float, ptr %[[VAL_237:.*]], align 4 -// CHECK: %[[VAL_238:.*]] = fadd float %[[VAL_234]], %[[VAL_236]] -// CHECK: store float %[[VAL_238]], ptr %[[VAL_224]], align 4 -// CHECK: %[[VAL_239:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_228]], i64 0, i64 0 -// CHECK: store ptr %[[VAL_225]], ptr %[[VAL_239]], align 8 -// CHECK: %[[VAL_240:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_228]], i64 0, i64 1 -// CHECK: store ptr %[[VAL_224]], ptr %[[VAL_240]], align 8 -// CHECK: %[[VAL_241:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_242:.*]], i64 0, i64 0 -// CHECK: %[[VAL_243:.*]] = load ptr, ptr %[[VAL_241]], align 8, !dereferenceable !12, !align !13 -// CHECK: %[[VAL_244:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_228]], i64 0, i64 0 -// CHECK: %[[VAL_245:.*]] = load ptr, ptr %[[VAL_244]], align 8, !dereferenceable !12, !align !13 -// CHECK: %[[VAL_246:.*]] = load float, ptr %[[VAL_245]], align 4 -// CHECK: store float %[[VAL_246]], ptr %[[VAL_243]], align 4 -// CHECK: %[[VAL_247:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_242]], i64 0, i64 1 -// CHECK: %[[VAL_248:.*]] = load ptr, ptr %[[VAL_247]], align 8, !dereferenceable !12, !align !13 -// CHECK: %[[VAL_249:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_228]], i64 0, i64 1 -// CHECK: %[[VAL_250:.*]] = load ptr, ptr %[[VAL_249]], align 8, !dereferenceable !12, !align !13 -// CHECK: %[[VAL_251:.*]] = load float, ptr %[[VAL_250]], align 4 -// CHECK: store float %[[VAL_251]], ptr %[[VAL_248]], align 4 +// CHECK: %[[VAL_196:.*]] = alloca float, align 4 +// CHECK: %[[VAL_197:.*]] = alloca float, align 4 +// CHECK: %[[VAL_198:.*]] = alloca [2 x ptr], align 8 +// CHECK: %[[VAL_199:.*]] = alloca [2 x ptr], align 8 +// CHECK: %[[VAL_200:.*]] = alloca [2 x ptr], align 8 +// CHECK: %[[VAL_201:.*]] = load float, ptr{{.*}}%[[VAL_202:.*]], align 4 +// CHECK: %[[VAL_203:.*]] = load float, ptr{{.*}}%[[VAL_204:.*]], align 4 +// CHECK: %[[VAL_205:.*]] = fadd float %[[VAL_201]], %[[VAL_203]] +// CHECK: store float %[[VAL_205]], ptr{{.*}}%[[VAL_197]], align 4 +// CHECK: %[[VAL_206:.*]] = load float, ptr{{.*}}%[[VAL_207:.*]], align 4 +// CHECK: %[[VAL_208:.*]] = load float, ptr{{.*}}%[[VAL_209:.*]], align 4 +// CHECK: %[[VAL_210:.*]] = fadd float %[[VAL_206]], %[[VAL_208]] +// CHECK: store float %[[VAL_210]], ptr{{.*}}%[[VAL_196]], align 4 +// CHECK-GCN: %[[VAL_197_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_197]] to ptr +// CHECK: %[[VAL_211:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_200]], i64 0, i64 0 +// CHECK-PTX: store ptr %[[VAL_197]], ptr %[[VAL_211]], align 8 +// CHECK-GCN: store ptr %[[VAL_197_1]], ptr{{.*}}%[[VAL_211]], align 8 +// CHECK-GCN: %[[VAL_196_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_196]] to ptr +// CHECK: %[[VAL_212:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_200]], i64 0, i64 1 +// CHECK-PTX: store ptr %[[VAL_196]], ptr %[[VAL_212]], align 8 +// CHECK-GCN: store ptr %[[VAL_196_1]], ptr{{.*}}%[[VAL_212]], align 8 +// CHECK: %[[VAL_213:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_214:.*]], i64 0, i64 0 +// CHECK: %[[VAL_215:.*]] = load ptr, ptr{{.*}}%[[VAL_213]], align 8, !dereferenceable !{{[0-9]*}}, !align !{{[0-9]*}} +// CHECK: %[[VAL_216:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_200]], i64 0, i64 0 +// CHECK: %[[VAL_217:.*]] = load ptr, ptr{{.*}}%[[VAL_216]], align 8, !dereferenceable !{{[0-9]*}}, !align !{{[0-9]*}} +// CHECK: %[[VAL_218:.*]] = load float, ptr{{.*}}%[[VAL_217]], align 4 +// CHECK: store float %[[VAL_218]], ptr{{.*}}%[[VAL_215]], align 4 +// CHECK: %[[VAL_219:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_214]], i64 0, i64 1 +// CHECK: %[[VAL_220:.*]] = load ptr, ptr{{.*}}%[[VAL_219]], align 8, !dereferenceable !{{[0-9]*}}, !align !{{[0-9]*}} +// CHECK: %[[VAL_221:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_200]], i64 0, i64 1 +// CHECK: %[[VAL_222:.*]] = load ptr, ptr{{.*}}%[[VAL_221]], align 8, !dereferenceable !{{[0-9]*}}, !align !{{[0-9]*}} +// CHECK: %[[VAL_223:.*]] = load float, ptr{{.*}}%[[VAL_222]], align 4 +// CHECK: store float %[[VAL_223]], ptr{{.*}}%[[VAL_220]], align 4 // CHECK: ret void + diff --git a/third_party/xla/xla/service/gpu/tests/reduction_vectorization_sm_all.hlo b/third_party/xla/xla/service/gpu/tests/reduction_vectorization_sm_all.hlo index fab7dd00de8164..6a25580a4bcff9 100644 --- a/third_party/xla/xla/service/gpu/tests/reduction_vectorization_sm_all.hlo +++ b/third_party/xla/xla/service/gpu/tests/reduction_vectorization_sm_all.hlo @@ -32,7 +32,7 @@ ENTRY %main { // ----- // CHECK-SM86-LABEL: .entry wrapped_reduce_small_row -// CHECK-SM86: .reqntid 96, 1, 1 +// CHECK-SM86: .reqntid 256, 1, 1 HloModule ReduceSmallRow, is_scheduled=true @@ -189,7 +189,7 @@ HloModule ReduceEvenColumns, is_scheduled=true // CHECK-SM70-LABEL: .entry wrapped_reduce_even_col // CHECK-SM70-COUNT-2: ld.global.nc.v2.f32 -// CHECK-SM70-COUNT-2: ld.global.nc.f32 +// CHECK-SM70-COUNT-2: ld.global.nc.v2.f32 %max_ { %x = f32[] parameter(0) diff --git a/third_party/xla/xla/service/gpu/tests/simplify_fp_conversions_test.cc b/third_party/xla/xla/service/gpu/tests/simplify_fp_conversions_test.cc index 9cf312bdde936a..98e56085123a17 100644 --- a/third_party/xla/xla/service/gpu/tests/simplify_fp_conversions_test.cc +++ b/third_party/xla/xla/service/gpu/tests/simplify_fp_conversions_test.cc @@ -27,7 +27,7 @@ class SimplifyFPConversionsTest : public HloTestBase { public: DebugOptions GetDebugOptionsForTest() override { DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); - debug_options.set_xla_gpu_simplify_all_fp_conversions( + debug_options.set_xla_allow_excess_precision( enable_simplify_all_fp_conversions_); return debug_options; } diff --git a/third_party/xla/xla/service/gpu/tests/slice_to_dynamic.hlo b/third_party/xla/xla/service/gpu/tests/slice_to_dynamic.hlo index fa7511e282aa93..242bd749bdaf11 100644 --- a/third_party/xla/xla/service/gpu/tests/slice_to_dynamic.hlo +++ b/third_party/xla/xla/service/gpu/tests/slice_to_dynamic.hlo @@ -34,7 +34,7 @@ // CHECK: %[[VAL_26:.*]] = udiv i32 %[[VAL_21]], 4 // CHECK: %[[VAL_27:.*]] = icmp ult i32 %[[VAL_19]], 8 // CHECK: br i1 %[[VAL_27]], label %[[VAL_28:.*]], label %[[VAL_29:.*]] -// CHECK: custom_call.in_bounds-after: ; preds = %[[VAL_30:.*]], %[[VAL_11]] +// CHECK: custom-call.in_bounds-after: ; preds = %[[VAL_30:.*]], %[[VAL_11]] // CHECK: ret void // CHECK: is_thread_0-true: ; preds = %[[VAL_12]] // CHECK: %[[VAL_31:.*]] = getelementptr inbounds i8, ptr %[[VAL_32:.*]], i32 32 @@ -44,7 +44,7 @@ // CHECK: %[[VAL_34:.*]] = getelementptr inbounds i8, ptr %[[VAL_32]], i32 40 // CHECK: store i32 %[[VAL_4]], ptr %[[VAL_34]], align 4 // CHECK: br label %[[VAL_11]] -// CHECK: custom_call.in_bounds-true: ; preds = %[[VAL_11]] +// CHECK: custom-call.in_bounds-true: ; preds = %[[VAL_11]] // CHECK: %[[VAL_35:.*]] = mul nuw nsw i32 %[[VAL_23]], 1 // CHECK: %[[VAL_36:.*]] = add nuw nsw i32 0, %[[VAL_35]] // CHECK: %[[VAL_37:.*]] = mul nuw nsw i32 %[[VAL_25]], 2 @@ -53,9 +53,9 @@ // CHECK: %[[VAL_40:.*]] = add nuw nsw i32 %[[VAL_38]], %[[VAL_39]] // CHECK: %[[VAL_41:.*]] = icmp ult i32 %[[VAL_40]], %[[VAL_15]] // CHECK: br i1 %[[VAL_41]], label %[[VAL_42:.*]], label %[[VAL_30]] -// CHECK: custom_call.in_dyn_bounds-after: ; preds = %[[VAL_42]], %[[VAL_28]] +// CHECK: custom-call.in_dyn_bounds-after: ; preds = %[[VAL_42]], %[[VAL_28]] // CHECK: br label %[[VAL_29]] -// CHECK: custom_call.in_dyn_bounds-true: ; preds = %[[VAL_28]] +// CHECK: custom-call.in_dyn_bounds-true: ; preds = %[[VAL_28]] // CHECK: %[[VAL_43:.*]] = udiv i32 %[[VAL_40]], 1 // CHECK: %[[VAL_44:.*]] = urem i32 %[[VAL_43]], %[[VAL_4]] // CHECK: %[[VAL_45:.*]] = mul i32 1, %[[VAL_4]] diff --git a/third_party/xla/xla/service/gpu/tests/test_autotune_cache.textproto b/third_party/xla/xla/service/gpu/tests/test_autotune_cache.textproto index 2dee80241f5da4..b20a9d20ece50a 100644 --- a/third_party/xla/xla/service/gpu/tests/test_autotune_cache.textproto +++ b/third_party/xla/xla/service/gpu/tests/test_autotune_cache.textproto @@ -1,4 +1,18 @@ -version: 2 +# Copyright 2023 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +version: 3 results { device: "sm_8.0 with 42331013120B RAM, 108 cores, 1410000KHz clock, 1215000KHz mem clock, 41943040B L2$" hlo: "{\n tmp_0 = f16[1,16,17,3]{3,2,1,0} parameter(0)\n tmp_1 = f16[16,51]{1,0} bitcast(f16[1,16,17,3]{3,2,1,0} tmp_0)\n tmp_2 = s8[16,17,3]{2,1,0} parameter(1)\n tmp_3 = s8[51,16]{0,1} bitcast(s8[16,17,3]{2,1,0} tmp_2)\n tmp_4 = f16[51,16]{0,1} convert(s8[51,16]{0,1} tmp_3)\n tmp_5 = f16[16,16]{1,0} dot(f16[16,51]{1,0} tmp_1, f16[51,16]{0,1} tmp_4), lhs_contracting_dims={1}, rhs_contracting_dims={0}\n ROOT tmp_6 = f16[1,16,16]{2,1,0} bitcast(f16[16,16]{1,0} tmp_5)\n}" @@ -13,6 +27,7 @@ results { split_k: 1 num_stages: 1 num_warps: 4 + num_ctas: 1 } } } diff --git a/third_party/xla/xla/service/gpu/tests/transpose_021.hlo b/third_party/xla/xla/service/gpu/tests/transpose_021.hlo index 9377220756f4ee..370779aed2142a 100644 --- a/third_party/xla/xla/service/gpu/tests/transpose_021.hlo +++ b/third_party/xla/xla/service/gpu/tests/transpose_021.hlo @@ -18,114 +18,90 @@ ENTRY main { // CHECK: %[[VAL_1:.*]] = alloca i32, align 4 // CHECK: %[[VAL_2:.*]] = alloca i32, align 4 // CHECK: %[[VAL_3:.*]] = alloca i32, align 4 -// CHECK-PTX: %[[VAL_4:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2 -// CHECK-GCN: %[[VAL_4:.*]] = call i32 @llvm.amdgcn.workitem.id.x -// CHECK-PTX: %[[VAL_5:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !3 -// CHECK-GCN: %[[VAL_5:.*]] = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK: %[[VAL_6:.*]] = urem i32 %[[VAL_4]], 128 -// CHECK: %[[VAL_7:.*]] = udiv i32 %[[VAL_4]], 128 -// CHECK: %[[VAL_8:.*]] = mul i32 %[[VAL_5]], 1 -// CHECK: %[[VAL_9:.*]] = add i32 %[[VAL_8]], %[[VAL_7]] -// CHECK: %[[VAL_10:.*]] = icmp ult i32 %[[VAL_9]], 2 -// CHECK: br i1 %[[VAL_10]], label %[[VAL_11:.*]], label %[[VAL_12:.*]] -// CHECK: 7: ; preds = %[[VAL_13:.*]] -// CHECK: %[[VAL_15:.*]] = udiv i32 %[[VAL_6]], 32 -// CHECK: %[[VAL_14:.*]] = urem i32 %[[VAL_6]], 32 -// CHECK: %[[VAL_37:.*]] = mul i32 %[[VAL_14]], 1 -// CHECK: %[[VAL_16:.*]] = urem i32 %[[VAL_6]], 32 -// CHECK: %[[VAL_17:.*]] = udiv i32 %[[VAL_9]], 1 -// CHECK: %[[VAL_18:.*]] = urem i32 %[[VAL_17]], 1 -// CHECK: %[[VAL_19:.*]] = udiv i32 %[[VAL_9]], 1 -// CHECK: %[[VAL_20:.*]] = urem i32 %[[VAL_19]], 1 -// CHECK: %[[VAL_21:.*]] = udiv i32 %[[VAL_9]], 1 -// CHECK: %[[VAL_22:.*]] = icmp eq i32 %[[VAL_20]], 0 -// CHECK: %[[VAL_23:.*]] = select i1 %[[VAL_22]], i32 16, i32 32 -// CHECK: %[[VAL_24:.*]] = icmp eq i32 %[[VAL_18]], 0 -// CHECK: %[[VAL_25:.*]] = select i1 %[[VAL_24]], i32 17, i32 32 -// CHECK: %[[VAL_26:.*]] = mul i32 %[[VAL_21]], 1 -// CHECK: %[[VAL_27:.*]] = mul i32 %[[VAL_20]], 32 -// CHECK: %[[VAL_28:.*]] = mul i32 %[[VAL_18]], 32 -// CHECK: store i32 %[[VAL_15]], ptr %[[VAL_3]], align 4 -// CHECK: br label %[[VAL_29:.*]] -// CHECK: loop1.loop_header: ; preds = %[[VAL_30:.*]], %[[VAL_11]] -// CHECK: %[[VAL_31:.*]] = load i32, ptr %[[VAL_3]], align 4 -// CHECK: %[[VAL_32:.*]] = icmp uge i32 %[[VAL_31]], %[[VAL_23]] -// CHECK: br i1 %[[VAL_32]], label %[[VAL_33:.*]], label %[[VAL_34:.*]] -// CHECK: loop1.loop_body: ; preds = %[[VAL_29]] -// CHECK: %[[VAL_35:.*]] = add nuw nsw i32 %[[VAL_31]], 4 -// CHECK: store i32 %[[VAL_35]], ptr %[[VAL_3]], align 4 -// CHECK: %[[VAL_36:.*]] = icmp eq i32 %[[VAL_31]], %[[VAL_15]] -// CHECK: store i32 0, ptr %[[VAL_2]], align 4 -// CHECK: br label %[[VAL_38:.*]] -// CHECK: loop2.loop_header: ; preds = %[[VAL_39:.*]], %[[VAL_34]] -// CHECK: %[[VAL_40:.*]] = load i32, ptr %[[VAL_2]], align 4 -// CHECK: %[[VAL_41:.*]] = icmp uge i32 %[[VAL_40]], 1 -// CHECK: br i1 %[[VAL_41]], label %[[VAL_30]], label %[[VAL_42:.*]] -// CHECK: loop2.loop_body: ; preds = %[[VAL_38]] -// CHECK: %[[VAL_43:.*]] = add nuw nsw i32 %[[VAL_40]], 1 -// CHECK: store i32 %[[VAL_43]], ptr %[[VAL_2]], align 4 -// CHECK: %[[VAL_44:.*]] = icmp eq i32 %[[VAL_40]], 0 -// CHECK: %[[VAL_45:.*]] = mul i32 %[[VAL_40]], 32 -// CHECK: %[[VAL_46:.*]] = add i32 %[[VAL_45]], 0 -// CHECK: %[[VAL_47:.*]] = add i32 %[[VAL_46]], %[[VAL_37]] -// CHECK: %[[VAL_48:.*]] = icmp ult i32 %[[VAL_47]], %[[VAL_25]] -// CHECK: br i1 %[[VAL_48]], label %[[VAL_49:.*]], label %[[VAL_39]] -// CHECK: x_in_tile-after: ; preds = %[[VAL_49]], %[[VAL_42]] -// CHECK: br label %[[VAL_38]], !llvm.loop !4 -// CHECK: loop2.loop_exit: ; preds = %[[VAL_38]] -// CHECK: br label %[[VAL_29]], !llvm.loop !7 -// CHECK: loop1.loop_exit: ; preds = %[[VAL_29]] -// CHECK: call void @llvm.nvvm.barrier0() -// CHECK: store i32 %[[VAL_15]], ptr %[[VAL_1]], align 4 -// CHECK: br label %[[VAL_50:.*]] -// CHECK: loop1.loop_header5: ; preds = %[[VAL_51:.*]], %[[VAL_33]] -// CHECK: %[[VAL_52:.*]] = load i32, ptr %[[VAL_1]], align 4 -// CHECK: %[[VAL_53:.*]] = icmp uge i32 %[[VAL_52]], %[[VAL_25]] -// CHECK: br i1 %[[VAL_53]], label %[[VAL_54:.*]], label %[[VAL_55:.*]] -// CHECK: loop1.loop_body6: ; preds = %[[VAL_50]] -// CHECK: %[[VAL_56:.*]] = add nuw nsw i32 %[[VAL_52]], 4 -// CHECK: store i32 %[[VAL_56]], ptr %[[VAL_1]], align 4 -// CHECK: %[[VAL_57:.*]] = icmp eq i32 %[[VAL_52]], %[[VAL_15]] -// CHECK: store i32 0, ptr %[[VAL_0]], align 4 -// CHECK: br label %[[VAL_59:.*]] -// CHECK: loop2.loop_header11: ; preds = %[[VAL_60:.*]], %[[VAL_55]] -// CHECK: %[[VAL_61:.*]] = load i32, ptr %[[VAL_0]], align 4 -// CHECK: %[[VAL_62:.*]] = icmp uge i32 %[[VAL_61]], 1 -// CHECK: br i1 %[[VAL_62]], label %[[VAL_51]], label %[[VAL_63:.*]] -// CHECK: loop2.loop_body12: ; preds = %[[VAL_59]] -// CHECK: %[[VAL_64:.*]] = add nuw nsw i32 %[[VAL_61]], 1 -// CHECK: store i32 %[[VAL_64]], ptr %[[VAL_0]], align 4 -// CHECK: %[[VAL_65:.*]] = icmp eq i32 %[[VAL_61]], 0 -// CHECK: %[[VAL_66:.*]] = mul i32 %[[VAL_61]], 32 -// CHECK: %[[VAL_67:.*]] = add i32 %[[VAL_66]], 0 -// CHECK: %[[VAL_68:.*]] = add i32 %[[VAL_67]], %[[VAL_37]] -// CHECK: %[[VAL_69:.*]] = icmp ult i32 %[[VAL_68]], %[[VAL_23]] -// CHECK: br i1 %[[VAL_69]], label %[[VAL_70:.*]], label %[[VAL_60]] -// CHECK: x_in_tile-after17: ; preds = %[[VAL_70]], %[[VAL_63]] -// CHECK: br label %[[VAL_59]], !llvm.loop !8 -// CHECK: loop2.loop_exit10: ; preds = %[[VAL_59]] -// CHECK: br label %[[VAL_50]], !llvm.loop !9 -// CHECK: loop1.loop_exit4: ; preds = %[[VAL_50]] +// CHECK-PTX: %thread.id.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2 +// CHECK-GCN: %thread.id.x = call i32 @llvm.amdgcn.workitem.id.x +// CHECK-PTX: %block.id.x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !3 +// CHECK-GCN: %block.id.x = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK: %[[VAL_4:.*]] = udiv i32 %thread.id.x, 32 +// CHECK: %thread.id.1 = urem i32 %[[VAL_4]], 4 +// CHECK: %thread.id.2 = urem i32 %thread.id.x, 32 +// CHECK: %lane_id = urem i32 %thread.id.x, 32 +// CHECK: %[[VAL_5:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_6:.*]] = urem i32 %[[VAL_5]], 1 +// CHECK: %[[VAL_7:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_8:.*]] = urem i32 %[[VAL_7]], 1 +// CHECK: %[[VAL_9:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_10:.*]] = icmp eq i32 %[[VAL_8]], 0 +// CHECK: %tile_bound.1 = select i1 %[[VAL_10]], i32 16, i32 32 +// CHECK: %[[VAL_11:.*]] = icmp eq i32 %[[VAL_6]], 0 +// CHECK: %tile_bound.2 = select i1 %[[VAL_11]], i32 17, i32 32 +// CHECK: %tile_origin.0 = mul i32 %[[VAL_9]], 1 +// CHECK: %tile_origin.1 = mul i32 %[[VAL_8]], 32 +// CHECK: %tile_origin.2 = mul i32 %[[VAL_6]], 32 +// CHECK: store i32 %thread.id.1, ptr{{.*}} %[[VAL_3]], align 4 +// CHECK: br label %[[VAL_12:.*]] +// CHECK: loop1.loop_header: ; preds = %[[VAL_13:.*]], %[[VAL_14:.*]] +// CHECK: %[[VAL_15:.*]] = load i32, ptr{{.*}} %[[VAL_3]], align 4 +// CHECK: %[[VAL_16:.*]] = icmp uge i32 %[[VAL_15]], %tile_bound.1 +// CHECK: br i1 %[[VAL_16]], label %[[VAL_17:.*]], label %[[VAL_18:.*]] +// CHECK: loop1.loop_body: ; preds = %[[VAL_12]] +// CHECK: %[[VAL_19:.*]] = add nuw nsw i32 %[[VAL_15]], 4 +// CHECK: store i32 %[[VAL_19]], ptr{{.*}} %[[VAL_3]], align 4 +// CHECK: %[[VAL_20:.*]] = icmp eq i32 %[[VAL_15]], %thread.id.1 +// CHECK: store i32 %thread.id.2, ptr{{.*}} %[[VAL_2]], align 4 +// CHECK: br label %[[VAL_21:.*]] +// CHECK: loop2.loop_header: ; preds = %[[VAL_22:.*]], %[[VAL_18]] +// CHECK: %[[VAL_23:.*]] = load i32, ptr{{.*}} %[[VAL_2]], align 4 +// CHECK: %[[VAL_24:.*]] = icmp uge i32 %[[VAL_23]], %tile_bound.2 +// CHECK: br i1 %[[VAL_24]], label %[[VAL_13]], label %[[VAL_22]] +// CHECK: loop2.loop_body: ; preds = %[[VAL_21]] +// CHECK: %[[VAL_25:.*]] = add nuw nsw i32 %[[VAL_23]], 32 +// CHECK: store i32 %[[VAL_25]], ptr{{.*}} %[[VAL_2]], align 4 +// CHECK: %[[VAL_26:.*]] = icmp eq i32 %[[VAL_23]], %thread.id.2 +// CHECK: %[[VAL_27:.*]] = add i32 %tile_origin.0, 0 +// CHECK: %[[VAL_28:.*]] = add i32 %tile_origin.1, %[[VAL_15]] +// CHECK: %[[VAL_29:.*]] = add i32 %tile_origin.2, %[[VAL_23]] +// CHECK: %[[VAL_30:.*]] = getelementptr{{.*}} inbounds [2 x [16 x [17 x float]]], ptr{{.*}} %[[VAL_31:.*]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_29]] +// CHECK: %[[VAL_32:.*]] = load float, ptr{{.*}} %[[VAL_30]], align 4, !invariant.load !{{[0-9]}} +// CHECK: %[[VAL_33:.*]] = getelementptr{{.*}} inbounds [1 x [32 x [33 x float]]], ptr{{.*}} addrspace(3) @tr_tile_0, i32 0, i32 0, i32 %[[VAL_15]], i32 %[[VAL_23]] +// CHECK: %[[VAL_34:.*]] = addrspacecast ptr{{.*}} addrspace(3) %[[VAL_33]] to ptr +// CHECK: store float %[[VAL_32]], ptr{{.*}} %[[VAL_34]], align 4 +// CHECK: br label %[[VAL_21]], !llvm.loop !{{[0-9]}} +// CHECK: loop2.loop_exit: ; preds = %[[VAL_21]] +// CHECK: br label %[[VAL_12]], !llvm.loop !{{[0-9]}} +// CHECK: loop1.loop_exit: ; preds = %[[VAL_12]] +// CHECK-PTX: call void @llvm.nvvm.barrier0() +// CHECK-GCN: call void @llvm.amdgcn.s.barrier() +// CHECK: store i32 %thread.id.1, ptr{{.*}} %[[VAL_1]], align 4 +// CHECK: br label %[[VAL_35:.*]] +// CHECK: loop1.loop_header4: ; preds = %[[VAL_36:.*]], %[[VAL_17]] +// CHECK: %[[VAL_37:.*]] = load i32, ptr{{.*}} %[[VAL_1]], align 4 +// CHECK: %[[VAL_38:.*]] = icmp uge i32 %[[VAL_37]], %tile_bound.2 +// CHECK: br i1 %[[VAL_38]], label %[[VAL_39:.*]], label %[[VAL_40:.*]] +// CHECK: loop1.loop_body5: ; preds = %[[VAL_35]] +// CHECK: %[[VAL_41:.*]] = add nuw nsw i32 %[[VAL_37]], 4 +// CHECK: store i32 %[[VAL_41]], ptr{{.*}} %[[VAL_1]], align 4 +// CHECK: %[[VAL_42:.*]] = icmp eq i32 %[[VAL_37]], %thread.id.1 +// CHECK: store i32 %thread.id.2, ptr{{.*}} %[[VAL_0]], align 4 +// CHECK: br label %[[VAL_43:.*]] +// CHECK: loop2.loop_header10: ; preds = %[[VAL_44:.*]], %[[VAL_40]] +// CHECK: %[[VAL_45:.*]] = load i32, ptr{{.*}} %[[VAL_0]], align 4 +// CHECK: %[[VAL_46:.*]] = icmp uge i32 %[[VAL_45]], %tile_bound.1 +// CHECK: br i1 %[[VAL_46]], label %[[VAL_36]], label %[[VAL_44]] +// CHECK: loop2.loop_body11: ; preds = %[[VAL_43]] +// CHECK: %[[VAL_47:.*]] = add nuw nsw i32 %[[VAL_45]], 32 +// CHECK: store i32 %[[VAL_47]], ptr{{.*}} %[[VAL_0]], align 4 +// CHECK: %[[VAL_48:.*]] = icmp eq i32 %[[VAL_45]], %thread.id.2 +// CHECK: %[[VAL_49:.*]] = add i32 %tile_origin.0, 0 +// CHECK: %[[VAL_50:.*]] = add i32 %tile_origin.2, %[[VAL_37]] +// CHECK: %[[VAL_51:.*]] = add i32 %tile_origin.1, %[[VAL_45]] +// CHECK: %[[VAL_52:.*]] = getelementptr{{.*}} inbounds [1 x [32 x [33 x float]]], ptr{{.*}} addrspace(3) @tr_tile_0, i32 0, i32 0, i32 %[[VAL_45]], i32 %[[VAL_37]] +// CHECK: %[[VAL_53:.*]] = addrspacecast ptr{{.*}} addrspace(3) %[[VAL_52]] to ptr +// CHECK: %[[VAL_54:.*]] = load float, ptr{{.*}} %[[VAL_53]], align 4 +// CHECK: %[[VAL_55:.*]] = getelementptr{{.*}} inbounds [2 x [17 x [16 x float]]], ptr{{.*}} %[[VAL_56:.*]], i32 0, i32 %[[VAL_49]], i32 %[[VAL_50]], i32 %[[VAL_51]] +// CHECK: store float %[[VAL_54]], ptr{{.*}} %[[VAL_55]], align 4 +// CHECK: br label %[[VAL_43]], !llvm.loop !{{[0-9]}} +// CHECK: loop2.loop_exit9: ; preds = %[[VAL_43]] +// CHECK: br label %[[VAL_35]], !llvm.loop !{{[0-9]}} +// CHECK: loop1.loop_exit3: ; preds = %[[VAL_35]] // CHECK: ret void -// CHECK: early_return: ; preds = %[[VAL_13]] -// CHECK: ret void -// CHECK: x_in_tile-true: ; preds = %[[VAL_42]] -// CHECK: %[[TILE_0:.*]] = add i32 %[[VAL_26]], 0 -// CHECK: %[[VAL_71:.*]] = add i32 %[[VAL_27]], %[[VAL_31]] -// CHECK: %[[VAL_72:.*]] = add i32 %[[VAL_28]], %[[VAL_47]] -// CHECK: %[[VAL_73:.*]] = getelementptr inbounds [2 x [16 x [17 x float]]], ptr %[[VAL_74:.*]], i32 0, i32 %[[TILE_0]], i32 %[[VAL_71]], i32 %[[VAL_72]] -// CHECK: %[[VAL_75:.*]] = load float, ptr %[[VAL_73]], align 4, !invariant.load !10 -// CHECK: %[[VAL_76:.*]] = getelementptr inbounds [1 x [32 x [33 x float]]], ptr addrspace(3) @tr_tile_0, i32 0, i32 0, i32 %[[VAL_31]], i32 %[[VAL_47]] -// CHECK: %[[VAL_77:.*]] = addrspacecast ptr addrspace(3) %[[VAL_76]] to ptr -// CHECK: store float %[[VAL_75]], ptr %[[VAL_77]], align 4 -// CHECK: br label %[[VAL_39]] -// CHECK: x_in_tile-true16: ; preds = %[[VAL_63]] -// CHECK: %[[TILE_0:.*]] = add i32 %[[VAL_26]], 0 -// CHECK: %[[VAL_78:.*]] = add i32 %[[VAL_28]], %[[VAL_52]] -// CHECK: %[[VAL_79:.*]] = add i32 %[[VAL_27]], %[[VAL_68]] -// CHECK: %[[VAL_80:.*]] = getelementptr inbounds [1 x [32 x [33 x float]]], ptr addrspace(3) @tr_tile_0, i32 0, i32 0, i32 %[[VAL_68]], i32 %[[VAL_52]] -// CHECK: %[[VAL_81:.*]] = addrspacecast ptr addrspace(3) %[[VAL_80]] to ptr -// CHECK: %[[VAL_82:.*]] = load float, ptr %[[VAL_81]], align 4 -// CHECK: %[[VAL_83:.*]] = getelementptr inbounds [2 x [17 x [16 x float]]], ptr %[[VAL_84:.*]], i32 0, i32 %[[TILE_0]], i32 %[[VAL_78]], i32 %[[VAL_79]] -// CHECK: store float %[[VAL_82]], ptr %[[VAL_83]], align 4 -// CHECK: br label %[[VAL_60]] diff --git a/third_party/xla/xla/service/gpu/tests/transpose_021_extra_output.hlo b/third_party/xla/xla/service/gpu/tests/transpose_021_extra_output.hlo index 017805b00ef9dd..b285b335e91610 100644 --- a/third_party/xla/xla/service/gpu/tests/transpose_021_extra_output.hlo +++ b/third_party/xla/xla/service/gpu/tests/transpose_021_extra_output.hlo @@ -15,124 +15,101 @@ ENTRY main { ROOT %fusion = (f32[2,16,17], f32[2,17,16]) fusion(%param), kind=kInput, calls=%fused_computation } + // CHECK-LABEL: entry: // CHECK: %[[VAL_0:.*]] = alloca i32, align 4 // CHECK: %[[VAL_1:.*]] = alloca i32, align 4 // CHECK: %[[VAL_2:.*]] = alloca i32, align 4 // CHECK: %[[VAL_3:.*]] = alloca i32, align 4 -// CHECK-PTX: %[[VAL_4:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2 -// CHECK-GCN: %[[VAL_4:.*]] = call i32 @llvm.amdgcn.workitem.id.x -// CHECK-PTX: %[[VAL_5:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !3 -// CHECK-GCN: %[[VAL_5:.*]] = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK: %[[VAL_6:.*]] = urem i32 %[[VAL_4]], 128 -// CHECK: %[[VAL_7:.*]] = udiv i32 %[[VAL_4]], 128 -// CHECK: %[[VAL_8:.*]] = mul i32 %[[VAL_5]], 1 -// CHECK: %[[VAL_9:.*]] = add i32 %[[VAL_8]], %[[VAL_7]] -// CHECK: %[[VAL_10:.*]] = icmp ult i32 %[[VAL_9]], 2 -// CHECK: br i1 %[[VAL_10]], label %[[VAL_11:.*]], label %[[VAL_12:.*]] -// CHECK: 7: ; preds = %[[VAL_13:.*]] -// CHECK: %[[VAL_15:.*]] = udiv i32 %[[VAL_6]], 32 -// CHECK: %[[VAL_14:.*]] = urem i32 %[[VAL_6]], 32 -// CHECK: %[[VAL_37:.*]] = mul i32 %[[VAL_14]], 1 -// CHECK: %[[VAL_16:.*]] = urem i32 %[[VAL_6]], 32 -// CHECK: %[[VAL_17:.*]] = udiv i32 %[[VAL_9]], 1 -// CHECK: %[[VAL_18:.*]] = urem i32 %[[VAL_17]], 1 -// CHECK: %[[VAL_19:.*]] = udiv i32 %[[VAL_9]], 1 -// CHECK: %[[VAL_20:.*]] = urem i32 %[[VAL_19]], 1 -// CHECK: %[[VAL_21:.*]] = udiv i32 %[[VAL_9]], 1 -// CHECK: %[[VAL_22:.*]] = icmp eq i32 %[[VAL_20]], 0 -// CHECK: %[[VAL_23:.*]] = select i1 %[[VAL_22]], i32 16, i32 32 -// CHECK: %[[VAL_24:.*]] = icmp eq i32 %[[VAL_18]], 0 -// CHECK: %[[VAL_25:.*]] = select i1 %[[VAL_24]], i32 17, i32 32 -// CHECK: %[[VAL_26:.*]] = mul i32 %[[VAL_21]], 1 -// CHECK: %[[VAL_27:.*]] = mul i32 %[[VAL_20]], 32 -// CHECK: %[[VAL_28:.*]] = mul i32 %[[VAL_18]], 32 -// CHECK: store i32 %[[VAL_15]], ptr %[[VAL_3]], align 4 -// CHECK: br label %[[VAL_29:.*]] -// CHECK: loop1.loop_header: ; preds = %[[VAL_30:.*]], %[[VAL_11]] -// CHECK: %[[VAL_31:.*]] = load i32, ptr %[[VAL_3]], align 4 -// CHECK: %[[VAL_32:.*]] = icmp uge i32 %[[VAL_31]], %[[VAL_23]] -// CHECK: br i1 %[[VAL_32]], label %[[VAL_33:.*]], label %[[VAL_34:.*]] -// CHECK: loop1.loop_body: ; preds = %[[VAL_29]] -// CHECK: %[[VAL_35:.*]] = add nuw nsw i32 %[[VAL_31]], 4 -// CHECK: store i32 %[[VAL_35]], ptr %[[VAL_3]], align 4 -// CHECK: %[[VAL_36:.*]] = icmp eq i32 %[[VAL_31]], %[[VAL_15]] -// CHECK: store i32 0, ptr %[[VAL_2]], align 4 -// CHECK: br label %[[VAL_38:.*]] -// CHECK: loop2.loop_header: ; preds = %[[VAL_39:.*]], %[[VAL_34]] -// CHECK: %[[VAL_40:.*]] = load i32, ptr %[[VAL_2]], align 4 -// CHECK: %[[VAL_41:.*]] = icmp uge i32 %[[VAL_40]], 1 -// CHECK: br i1 %[[VAL_41]], label %[[VAL_30]], label %[[VAL_42:.*]] -// CHECK: loop2.loop_body: ; preds = %[[VAL_38]] -// CHECK: %[[VAL_43:.*]] = add nuw nsw i32 %[[VAL_40]], 1 -// CHECK: store i32 %[[VAL_43]], ptr %[[VAL_2]], align 4 -// CHECK: %[[VAL_44:.*]] = icmp eq i32 %[[VAL_40]], 0 -// CHECK: %[[VAL_45:.*]] = mul i32 %[[VAL_40]], 32 -// CHECK: %[[VAL_46:.*]] = add i32 %[[VAL_45]], 0 -// CHECK: %[[VAL_47:.*]] = add i32 %[[VAL_46]], %[[VAL_37]] -// CHECK: %[[VAL_48:.*]] = icmp ult i32 %[[VAL_47]], %[[VAL_25]] -// CHECK: br i1 %[[VAL_48]], label %[[VAL_49:.*]], label %[[VAL_39]] -// CHECK: x_in_tile-after: ; preds = %[[VAL_49]], %[[VAL_42]] -// CHECK: br label %[[VAL_38]], !llvm.loop !4 -// CHECK: loop2.loop_exit: ; preds = %[[VAL_38]] -// CHECK: br label %[[VAL_29]], !llvm.loop !7 -// CHECK: loop1.loop_exit: ; preds = %[[VAL_29]] -// CHECK: call void @llvm.nvvm.barrier0() -// CHECK: store i32 %[[VAL_15]], ptr %[[VAL_1]], align 4 -// CHECK: br label %[[VAL_50:.*]] -// CHECK: loop1.loop_header7: ; preds = %[[VAL_51:.*]], %[[VAL_33]] -// CHECK: %[[VAL_52:.*]] = load i32, ptr %[[VAL_1]], align 4 -// CHECK: %[[VAL_53:.*]] = icmp uge i32 %[[VAL_52]], %[[VAL_25]] -// CHECK: br i1 %[[VAL_53]], label %[[VAL_54:.*]], label %[[VAL_55:.*]] -// CHECK: loop1.loop_body8: ; preds = %[[VAL_50]] -// CHECK: %[[VAL_56:.*]] = add nuw nsw i32 %[[VAL_52]], 4 -// CHECK: store i32 %[[VAL_56]], ptr %[[VAL_1]], align 4 -// CHECK: %[[VAL_57:.*]] = icmp eq i32 %[[VAL_52]], %[[VAL_15]] -// CHECK: store i32 0, ptr %[[VAL_0]], align 4 -// CHECK: br label %[[VAL_59:.*]] -// CHECK: loop2.loop_header13: ; preds = %[[VAL_60:.*]], %[[VAL_55]] -// CHECK: %[[VAL_61:.*]] = load i32, ptr %[[VAL_0]], align 4 -// CHECK: %[[VAL_62:.*]] = icmp uge i32 %[[VAL_61]], 1 -// CHECK: br i1 %[[VAL_62]], label %[[VAL_51]], label %[[VAL_63:.*]] -// CHECK: loop2.loop_body14: ; preds = %[[VAL_59]] -// CHECK: %[[VAL_64:.*]] = add nuw nsw i32 %[[VAL_61]], 1 -// CHECK: store i32 %[[VAL_64]], ptr %[[VAL_0]], align 4 -// CHECK: %[[VAL_65:.*]] = icmp eq i32 %[[VAL_61]], 0 -// CHECK: %[[VAL_66:.*]] = mul i32 %[[VAL_61]], 32 -// CHECK: %[[VAL_67:.*]] = add i32 %[[VAL_66]], 0 -// CHECK: %[[VAL_68:.*]] = add i32 %[[VAL_67]], %[[VAL_37]] -// CHECK: %[[VAL_69:.*]] = icmp ult i32 %[[VAL_68]], %[[VAL_23]] -// CHECK: br i1 %[[VAL_69]], label %[[VAL_70:.*]], label %[[VAL_60]] -// CHECK: x_in_tile-after19: ; preds = %[[VAL_70]], %[[VAL_63]] -// CHECK: br label %[[VAL_59]], !llvm.loop !8 -// CHECK: loop2.loop_exit12: ; preds = %[[VAL_59]] -// CHECK: br label %[[VAL_50]], !llvm.loop !9 -// CHECK: loop1.loop_exit6: ; preds = %[[VAL_50]] -// CHECK: ret void -// CHECK: early_return: ; preds = %[[VAL_13]] +// CHECK-PTX: %thread.id.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2 +// CHECK-GCN: %thread.id.x = call i32 @llvm.amdgcn.workitem.id.x +// CHECK-PTX: %block.id.x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !3 +// CHECK-GCN: %block.id.x = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK: %[[VAL_4:.*]] = udiv i32 %thread.id.x, 32 +// CHECK: %thread.id.1 = urem i32 %[[VAL_4]], 4 +// CHECK: %thread.id.2 = urem i32 %thread.id.x, 32 +// CHECK: %lane_id = urem i32 %thread.id.x, 32 +// CHECK: %[[VAL_5:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_6:.*]] = urem i32 %[[VAL_5]], 1 +// CHECK: %[[VAL_7:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_8:.*]] = urem i32 %[[VAL_7]], 1 +// CHECK: %[[VAL_9:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_10:.*]] = icmp eq i32 %[[VAL_8]], 0 +// CHECK: %tile_bound.1 = select i1 %[[VAL_10]], i32 16, i32 32 +// CHECK: %[[VAL_11:.*]] = icmp eq i32 %[[VAL_6]], 0 +// CHECK: %tile_bound.2 = select i1 %[[VAL_11]], i32 17, i32 32 +// CHECK: %tile_origin.0 = mul i32 %[[VAL_9]], 1 +// CHECK: %tile_origin.1 = mul i32 %[[VAL_8]], 32 +// CHECK: %tile_origin.2 = mul i32 %[[VAL_6]], 32 +// CHECK: store i32 %thread.id.1, ptr{{.*}} %[[VAL_3]], align 4 +// CHECK: br label %[[VAL_12:.*]] +// CHECK: loop1.loop_header: ; preds = %[[VAL_13:.*]], %[[VAL_14:.*]] +// CHECK: %[[VAL_15:.*]] = load i32, ptr{{.*}} %[[VAL_3]], align 4 +// CHECK: %[[VAL_16:.*]] = icmp uge i32 %[[VAL_15]], %tile_bound.1 +// CHECK: br i1 %[[VAL_16]], label %[[VAL_17:.*]], label %[[VAL_18:.*]] +// CHECK: loop1.loop_body: ; preds = %[[VAL_12]] +// CHECK: %[[VAL_19:.*]] = add nuw nsw i32 %[[VAL_15]], 4 +// CHECK: store i32 %[[VAL_19]], ptr{{.*}} %[[VAL_3]], align 4 +// CHECK: %[[VAL_20:.*]] = icmp eq i32 %[[VAL_15]], %thread.id.1 +// CHECK: store i32 %thread.id.2, ptr{{.*}} %[[VAL_2]], align 4 +// CHECK: br label %[[VAL_21:.*]] +// CHECK: loop2.loop_header: ; preds = %[[VAL_22:.*]], %[[VAL_18]] +// CHECK: %[[VAL_23:.*]] = load i32, ptr{{.*}} %[[VAL_2]], align 4 +// CHECK: %[[VAL_24:.*]] = icmp uge i32 %[[VAL_23]], %tile_bound.2 +// CHECK: br i1 %[[VAL_24]], label %[[VAL_13]], label %[[VAL_22]] +// CHECK: loop2.loop_body: ; preds = %[[VAL_21]] +// CHECK: %[[VAL_25:.*]] = add nuw nsw i32 %[[VAL_23]], 32 +// CHECK: store i32 %[[VAL_25]], ptr{{.*}} %[[VAL_2]], align 4 +// CHECK: %[[VAL_26:.*]] = icmp eq i32 %[[VAL_23]], %thread.id.2 +// CHECK: %[[VAL_27:.*]] = add i32 %tile_origin.0, 0 +// CHECK: %[[VAL_28:.*]] = add i32 %tile_origin.1, %[[VAL_15]] +// CHECK: %[[VAL_29:.*]] = add i32 %tile_origin.2, %[[VAL_23]] +// CHECK: %[[VAL_30:.*]] = getelementptr inbounds [2 x [16 x [17 x float]]], ptr{{.*}} %[[VAL_31:.*]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_29]] +// CHECK: %[[VAL_32:.*]] = load float, ptr{{.*}} %[[VAL_30]], align 4, !invariant.load !{{[0-9]}} +// CHECK: %[[VAL_33:.*]] = getelementptr inbounds [1 x [32 x [33 x float]]], ptr{{.*}} addrspace(3) @tr_tile_0, i32 0, i32 0, i32 %[[VAL_15]], i32 %[[VAL_23]] +// CHECK: %[[VAL_34:.*]] = addrspacecast ptr{{.*}} addrspace(3) %[[VAL_33]] to ptr +// CHECK: store float %[[VAL_32]], ptr{{.*}} %[[VAL_34]], align 4 +// CHECK: %[[VAL_35:.*]] = getelementptr inbounds [2 x [16 x [17 x float]]], ptr{{.*}} %[[VAL_31]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_29]] +// CHECK: %[[VAL_36:.*]] = load float, ptr{{.*}} %[[VAL_35]], align 4, !invariant.load !{{[0-9]}} +// CHECK: %[[VAL_37:.*]] = fneg float %[[VAL_36]] +// CHECK: %[[VAL_38:.*]] = getelementptr inbounds [2 x [16 x [17 x float]]], ptr{{.*}} %[[VAL_39:.*]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_29]] +// CHECK: store float %[[VAL_37]], ptr{{.*}} %[[VAL_38]], align 4 +// CHECK: br label %[[VAL_21]], !llvm.loop !{{[0-9]}} +// CHECK: loop2.loop_exit: ; preds = %[[VAL_21]] +// CHECK: br label %[[VAL_12]], !llvm.loop !{{[0-9]}} +// CHECK: loop1.loop_exit: ; preds = %[[VAL_12]] +// CHECK-PTX: call void @llvm.nvvm.barrier0() +// CHECK-GCN: call void @llvm.amdgcn.s.barrier() +// CHECK: store i32 %thread.id.1, ptr{{.*}} %[[VAL_1]], align 4 +// CHECK: br label %[[VAL_40:.*]] +// CHECK: loop1.loop_header6: ; preds = %[[VAL_41:.*]], %[[VAL_17]] +// CHECK: %[[VAL_42:.*]] = load i32, ptr{{.*}} %[[VAL_1]], align 4 +// CHECK: %[[VAL_43:.*]] = icmp uge i32 %[[VAL_42]], %tile_bound.2 +// CHECK: br i1 %[[VAL_43]], label %[[VAL_44:.*]], label %[[VAL_45:.*]] +// CHECK: loop1.loop_body7: ; preds = %[[VAL_40]] +// CHECK: %[[VAL_46:.*]] = add nuw nsw i32 %[[VAL_42]], 4 +// CHECK: store i32 %[[VAL_46]], ptr{{.*}} %[[VAL_1]], align 4 +// CHECK: %[[VAL_47:.*]] = icmp eq i32 %[[VAL_42]], %thread.id.1 +// CHECK: store i32 %thread.id.2, ptr{{.*}} %[[VAL_0]], align 4 +// CHECK: br label %[[VAL_48:.*]] +// CHECK: loop2.loop_header12: ; preds = %[[VAL_49:.*]], %[[VAL_45]] +// CHECK: %[[VAL_50:.*]] = load i32, ptr{{.*}} %[[VAL_0]], align 4 +// CHECK: %[[VAL_51:.*]] = icmp uge i32 %[[VAL_50]], %tile_bound.1 +// CHECK: br i1 %[[VAL_51]], label %[[VAL_41]], label %[[VAL_49]] +// CHECK: loop2.loop_body13: ; preds = %[[VAL_48]] +// CHECK: %[[VAL_52:.*]] = add nuw nsw i32 %[[VAL_50]], 32 +// CHECK: store i32 %[[VAL_52]], ptr{{.*}} %[[VAL_0]], align 4 +// CHECK: %[[VAL_53:.*]] = icmp eq i32 %[[VAL_50]], %thread.id.2 +// CHECK: %[[VAL_54:.*]] = add i32 %tile_origin.0, 0 +// CHECK: %[[VAL_55:.*]] = add i32 %tile_origin.2, %[[VAL_42]] +// CHECK: %[[VAL_56:.*]] = add i32 %tile_origin.1, %[[VAL_50]] +// CHECK: %[[VAL_57:.*]] = getelementptr inbounds [1 x [32 x [33 x float]]], ptr{{.*}} addrspace(3) @tr_tile_0, i32 0, i32 0, i32 %[[VAL_50]], i32 %[[VAL_42]] +// CHECK: %[[VAL_58:.*]] = addrspacecast ptr{{.*}} addrspace(3) %[[VAL_57]] to ptr +// CHECK: %[[VAL_59:.*]] = load float, ptr{{.*}} %[[VAL_58]], align 4 +// CHECK: %[[VAL_60:.*]] = getelementptr inbounds [2 x [17 x [16 x float]]], ptr{{.*}} %[[VAL_61:.*]], i32 0, i32 %[[VAL_54]], i32 %[[VAL_55]], i32 %[[VAL_56]] +// CHECK: store float %[[VAL_59]], ptr{{.*}} %[[VAL_60]], align 4 +// CHECK: br label %[[VAL_48]], !llvm.loop !{{[0-9]}} +// CHECK: loop2.loop_exit11: ; preds = %[[VAL_48]] +// CHECK: br label %[[VAL_40]], !llvm.loop !{{[0-9]}} +// CHECK: loop1.loop_exit5: ; preds = %[[VAL_40]] // CHECK: ret void -// CHECK: x_in_tile-true: ; preds = %[[VAL_42]] -// CHECK: %[[TILE_0:.*]] = add i32 %[[VAL_26]], 0 -// CHECK: %[[VAL_71:.*]] = add i32 %[[VAL_27]], %[[VAL_31]] -// CHECK: %[[VAL_72:.*]] = add i32 %[[VAL_28]], %[[VAL_47]] -// CHECK: %[[VAL_73:.*]] = getelementptr inbounds [2 x [16 x [17 x float]]], ptr %[[VAL_74:.*]], i32 0, i32 %[[TILE_0]], i32 %[[VAL_71]], i32 %[[VAL_72]] -// CHECK: %[[VAL_75:.*]] = load float, ptr %[[VAL_73]], align 4, !invariant.load !10 -// CHECK: %[[VAL_76:.*]] = getelementptr inbounds [1 x [32 x [33 x float]]], ptr addrspace(3) @tr_tile_0, i32 0, i32 0, i32 %[[VAL_31]], i32 %[[VAL_47]] -// CHECK: %[[VAL_77:.*]] = addrspacecast ptr addrspace(3) %[[VAL_76]] to ptr -// CHECK: store float %[[VAL_75]], ptr %[[VAL_77]], align 4 -// CHECK: %[[VAL_78:.*]] = getelementptr inbounds [2 x [16 x [17 x float]]], ptr %[[VAL_74]], i32 0, i32 %[[TILE_0]], i32 %[[VAL_71]], i32 %[[VAL_72]] -// CHECK: %[[VAL_79:.*]] = load float, ptr %[[VAL_78]], align 4, !invariant.load !10 -// CHECK: %[[VAL_80:.*]] = fneg float %[[VAL_79]] -// CHECK: %[[VAL_81:.*]] = getelementptr inbounds [2 x [16 x [17 x float]]], ptr %[[VAL_82:.*]], i32 0, i32 %[[TILE_0]], i32 %[[VAL_71]], i32 %[[VAL_72]] -// CHECK: store float %[[VAL_80]], ptr %[[VAL_81]], align 4 -// CHECK: br label %[[VAL_39]] -// CHECK: x_in_tile-true18: ; preds = %[[VAL_63]] -// CHECK: %[[TILE_0:.*]] = add i32 %[[VAL_26]], 0 -// CHECK: %[[VAL_83:.*]] = add i32 %[[VAL_28]], %[[VAL_52]] -// CHECK: %[[VAL_84:.*]] = add i32 %[[VAL_27]], %[[VAL_68]] -// CHECK: %[[VAL_85:.*]] = getelementptr inbounds [1 x [32 x [33 x float]]], ptr addrspace(3) @tr_tile_0, i32 0, i32 0, i32 %[[VAL_68]], i32 %[[VAL_52]] -// CHECK: %[[VAL_86:.*]] = addrspacecast ptr addrspace(3) %[[VAL_85]] to ptr -// CHECK: %[[VAL_87:.*]] = load float, ptr %[[VAL_86]], align 4 -// CHECK: %[[VAL_88:.*]] = getelementptr inbounds [2 x [17 x [16 x float]]], ptr %[[VAL_89:.*]], i32 0, i32 %[[TILE_0]], i32 %[[VAL_83]], i32 %[[VAL_84]] -// CHECK: store float %[[VAL_87]], ptr %[[VAL_88]], align 4 -// CHECK: br label %[[VAL_60]] diff --git a/third_party/xla/xla/service/gpu/tests/transpose_210.hlo b/third_party/xla/xla/service/gpu/tests/transpose_210.hlo index f510ba8e81b828..cf83fa7a8c0292 100644 --- a/third_party/xla/xla/service/gpu/tests/transpose_210.hlo +++ b/third_party/xla/xla/service/gpu/tests/transpose_210.hlo @@ -18,114 +18,89 @@ ENTRY main { // CHECK: %[[VAL_1:.*]] = alloca i32, align 4 // CHECK: %[[VAL_2:.*]] = alloca i32, align 4 // CHECK: %[[VAL_3:.*]] = alloca i32, align 4 -// CHECK-PTX: %[[VAL_4:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2 -// CHECK-GCN: %[[VAL_4:.*]] = call i32 @llvm.amdgcn.workitem.id.x -// CHECK-PTX: %[[VAL_5:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !3 -// CHECK-GCN: %[[VAL_5:.*]] = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK: %[[VAL_6:.*]] = urem i32 %[[VAL_4]], 128 -// CHECK: %[[VAL_7:.*]] = udiv i32 %[[VAL_4]], 128 -// CHECK: %[[VAL_8:.*]] = mul i32 %[[VAL_5]], 1 -// CHECK: %[[VAL_9:.*]] = add i32 %[[VAL_8]], %[[VAL_7]] -// CHECK: %[[VAL_10:.*]] = icmp ult i32 %[[VAL_9]], 294 -// CHECK: br i1 %[[VAL_10]], label %[[VAL_11:.*]], label %[[VAL_12:.*]] -// CHECK: 7: ; preds = %[[VAL_13:.*]] -// CHECK: %[[VAL_15:.*]] = udiv i32 %[[VAL_6]], 32 -// CHECK: %[[VAL_14:.*]] = urem i32 %[[VAL_6]], 32 -// CHECK: %[[VAL_37:.*]] = mul i32 %[[VAL_14]], 1 -// CHECK: %[[VAL_16:.*]] = urem i32 %[[VAL_6]], 32 -// CHECK: %[[VAL_17:.*]] = udiv i32 %[[VAL_9]], 1 -// CHECK: %[[VAL_18:.*]] = urem i32 %[[VAL_17]], 3 -// CHECK: %[[VAL_19:.*]] = udiv i32 %[[VAL_9]], 3 -// CHECK: %[[VAL_20:.*]] = urem i32 %[[VAL_19]], 2 -// CHECK: %[[VAL_21:.*]] = udiv i32 %[[VAL_9]], 6 -// CHECK: %[[VAL_22:.*]] = icmp eq i32 %[[VAL_20]], 1 -// CHECK: %[[VAL_23:.*]] = select i1 %[[VAL_22]], i32 1, i32 32 -// CHECK: %[[VAL_24:.*]] = icmp eq i32 %[[VAL_18]], 2 -// CHECK: %[[VAL_25:.*]] = select i1 %[[VAL_24]], i32 1, i32 32 -// CHECK: %[[VAL_26:.*]] = mul i32 %[[VAL_21]], 1 -// CHECK: %[[VAL_27:.*]] = mul i32 %[[VAL_20]], 32 -// CHECK: %[[VAL_28:.*]] = mul i32 %[[VAL_18]], 32 -// CHECK: store i32 %[[VAL_15]], ptr %[[VAL_3]], align 4 -// CHECK: br label %[[VAL_29:.*]] -// CHECK: loop1.loop_header: ; preds = %[[VAL_30:.*]], %[[VAL_11]] -// CHECK: %[[VAL_31:.*]] = load i32, ptr %[[VAL_3]], align 4 -// CHECK: %[[VAL_32:.*]] = icmp uge i32 %[[VAL_31]], %[[VAL_23]] -// CHECK: br i1 %[[VAL_32]], label %[[VAL_33:.*]], label %[[VAL_34:.*]] -// CHECK: loop1.loop_body: ; preds = %[[VAL_29]] -// CHECK: %[[VAL_35:.*]] = add nuw nsw i32 %[[VAL_31]], 4 -// CHECK: store i32 %[[VAL_35]], ptr %[[VAL_3]], align 4 -// CHECK: %[[VAL_36:.*]] = icmp eq i32 %[[VAL_31]], %[[VAL_15]] -// CHECK: store i32 0, ptr %[[VAL_2]], align 4 -// CHECK: br label %[[VAL_38:.*]] -// CHECK: loop2.loop_header: ; preds = %[[VAL_39:.*]], %[[VAL_34]] -// CHECK: %[[VAL_40:.*]] = load i32, ptr %[[VAL_2]], align 4 -// CHECK: %[[VAL_41:.*]] = icmp uge i32 %[[VAL_40]], 1 -// CHECK: br i1 %[[VAL_41]], label %[[VAL_30]], label %[[VAL_42:.*]] -// CHECK: loop2.loop_body: ; preds = %[[VAL_38]] -// CHECK: %[[VAL_43:.*]] = add nuw nsw i32 %[[VAL_40]], 1 -// CHECK: store i32 %[[VAL_43]], ptr %[[VAL_2]], align 4 -// CHECK: %[[VAL_44:.*]] = icmp eq i32 %[[VAL_40]], 0 -// CHECK: %[[VAL_45:.*]] = mul i32 %[[VAL_40]], 32 -// CHECK: %[[VAL_46:.*]] = add i32 %[[VAL_45]], 0 -// CHECK: %[[VAL_47:.*]] = add i32 %[[VAL_46]], %[[VAL_37]] -// CHECK: %[[VAL_48:.*]] = icmp ult i32 %[[VAL_47]], %[[VAL_25]] -// CHECK: br i1 %[[VAL_48]], label %[[VAL_49:.*]], label %[[VAL_39]] -// CHECK: x_in_tile-after: ; preds = %[[VAL_49]], %[[VAL_42]] -// CHECK: br label %[[VAL_38]], !llvm.loop !4 -// CHECK: loop2.loop_exit: ; preds = %[[VAL_38]] -// CHECK: br label %[[VAL_29]], !llvm.loop !7 -// CHECK: loop1.loop_exit: ; preds = %[[VAL_29]] -// CHECK: call void @llvm.nvvm.barrier0() -// CHECK: store i32 %[[VAL_15]], ptr %[[VAL_1]], align 4 -// CHECK: br label %[[VAL_50:.*]] -// CHECK: loop1.loop_header5: ; preds = %[[VAL_51:.*]], %[[VAL_33]] -// CHECK: %[[VAL_52:.*]] = load i32, ptr %[[VAL_1]], align 4 -// CHECK: %[[VAL_53:.*]] = icmp uge i32 %[[VAL_52]], %[[VAL_25]] -// CHECK: br i1 %[[VAL_53]], label %[[VAL_54:.*]], label %[[VAL_55:.*]] -// CHECK: loop1.loop_body6: ; preds = %[[VAL_50]] -// CHECK: %[[VAL_56:.*]] = add nuw nsw i32 %[[VAL_52]], 4 -// CHECK: store i32 %[[VAL_56]], ptr %[[VAL_1]], align 4 -// CHECK: %[[VAL_57:.*]] = icmp eq i32 %[[VAL_52]], %[[VAL_15]] -// CHECK: store i32 0, ptr %[[VAL_0]], align 4 -// CHECK: br label %[[VAL_59:.*]] -// CHECK: loop2.loop_header11: ; preds = %[[VAL_60:.*]], %[[VAL_55]] -// CHECK: %[[VAL_61:.*]] = load i32, ptr %[[VAL_0]], align 4 -// CHECK: %[[VAL_62:.*]] = icmp uge i32 %[[VAL_61]], 1 -// CHECK: br i1 %[[VAL_62]], label %[[VAL_51]], label %[[VAL_63:.*]] -// CHECK: loop2.loop_body12: ; preds = %[[VAL_59]] -// CHECK: %[[VAL_64:.*]] = add nuw nsw i32 %[[VAL_61]], 1 -// CHECK: store i32 %[[VAL_64]], ptr %[[VAL_0]], align 4 -// CHECK: %[[VAL_65:.*]] = icmp eq i32 %[[VAL_61]], 0 -// CHECK: %[[VAL_66:.*]] = mul i32 %[[VAL_61]], 32 -// CHECK: %[[VAL_67:.*]] = add i32 %[[VAL_66]], 0 -// CHECK: %[[VAL_68:.*]] = add i32 %[[VAL_67]], %[[VAL_37]] -// CHECK: %[[VAL_69:.*]] = icmp ult i32 %[[VAL_68]], %[[VAL_23]] -// CHECK: br i1 %[[VAL_69]], label %[[VAL_70:.*]], label %[[VAL_60]] -// CHECK: x_in_tile-after17: ; preds = %[[VAL_70]], %[[VAL_63]] -// CHECK: br label %[[VAL_59]], !llvm.loop !8 -// CHECK: loop2.loop_exit10: ; preds = %[[VAL_59]] -// CHECK: br label %[[VAL_50]], !llvm.loop !9 -// CHECK: loop1.loop_exit4: ; preds = %[[VAL_50]] +// CHECK-PTX: %thread.id.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2 +// CHECK-GCN: %thread.id.x = call i32 @llvm.amdgcn.workitem.id.x +// CHECK-PTX: %block.id.x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !3 +// CHECK-GCN: %block.id.x = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK: %thread.id.0 = udiv i32 %thread.id.x, 32 +// CHECK: %thread.id.2 = urem i32 %thread.id.x, 32 +// CHECK: %lane_id = urem i32 %thread.id.x, 32 +// CHECK: %[[VAL_5:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_6:.*]] = urem i32 %[[VAL_5]], 3 +// CHECK: %[[VAL_7:.*]] = udiv i32 %block.id.x, 3 +// CHECK: %[[VAL_8:.*]] = urem i32 %[[VAL_7]], 49 +// CHECK: %[[VAL_9:.*]] = udiv i32 %block.id.x, 147 +// CHECK: %[[VAL_10:.*]] = icmp eq i32 %[[VAL_9]], 1 +// CHECK: %tile_bound.0 = select i1 %[[VAL_10]], i32 1, i32 32 +// CHECK: %[[VAL_11:.*]] = icmp eq i32 %[[VAL_6]], 2 +// CHECK: %tile_bound.2 = select i1 %[[VAL_11]], i32 1, i32 32 +// CHECK: %tile_origin.0 = mul i32 %[[VAL_9]], 32 +// CHECK: %tile_origin.1 = mul i32 %[[VAL_8]], 1 +// CHECK: %tile_origin.2 = mul i32 %[[VAL_6]], 32 +// CHECK: store i32 %thread.id.0, ptr{{.*}} %[[VAL_3]], align 4 +// CHECK: br label %[[VAL_12:.*]] +// CHECK: loop0.loop_header: ; preds = %[[VAL_13:.*]], %[[VAL_14:.*]] +// CHECK: %[[VAL_15:.*]] = load i32, ptr{{.*}} %[[VAL_3]], align 4 +// CHECK: %[[VAL_16:.*]] = icmp uge i32 %[[VAL_15]], %tile_bound.0 +// CHECK: br i1 %[[VAL_16]], label %[[VAL_17:.*]], label %[[VAL_18:.*]] +// CHECK: loop0.loop_body: ; preds = %[[VAL_12]] +// CHECK: %[[VAL_19:.*]] = add nuw nsw i32 %[[VAL_15]], 4 +// CHECK: store i32 %[[VAL_19]], ptr{{.*}} %[[VAL_3]], align 4 +// CHECK: %[[VAL_20:.*]] = icmp eq i32 %[[VAL_15]], %thread.id.0 +// CHECK: store i32 %thread.id.2, ptr{{.*}} %[[VAL_2]], align 4 +// CHECK: br label %[[VAL_21:.*]] +// CHECK: loop2.loop_header: ; preds = %[[VAL_22:.*]], %[[VAL_18]] +// CHECK: %[[VAL_23:.*]] = load i32, ptr{{.*}} %[[VAL_2]], align 4 +// CHECK: %[[VAL_24:.*]] = icmp uge i32 %[[VAL_23]], %tile_bound.2 +// CHECK: br i1 %[[VAL_24]], label %[[VAL_13]], label %[[VAL_22]] +// CHECK: loop2.loop_body: ; preds = %[[VAL_21]] +// CHECK: %[[VAL_25:.*]] = add nuw nsw i32 %[[VAL_23]], 32 +// CHECK: store i32 %[[VAL_25]], ptr{{.*}} %[[VAL_2]], align 4 +// CHECK: %[[VAL_26:.*]] = icmp eq i32 %[[VAL_23]], %thread.id.2 +// CHECK: %[[VAL_27:.*]] = add i32 %tile_origin.0, %[[VAL_15]] +// CHECK: %[[VAL_28:.*]] = add i32 %tile_origin.1, 0 +// CHECK: %[[VAL_29:.*]] = add i32 %tile_origin.2, %[[VAL_23]] +// CHECK: %[[VAL_30:.*]] = getelementptr inbounds [33 x [49 x [65 x float]]], ptr %[[VAL_31:.*]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_29]] +// CHECK: %[[VAL_32:.*]] = load float, ptr %[[VAL_30]], align 4, !invariant.load !{{[0-9]}} +// CHECK: %[[VAL_33:.*]] = getelementptr inbounds [32 x [1 x [33 x float]]], ptr addrspace(3) @tr_tile_0, i32 0, i32 %[[VAL_15]], i32 0, i32 %[[VAL_23]] +// CHECK: %[[VAL_34:.*]] = addrspacecast ptr addrspace(3) %[[VAL_33]] to ptr +// CHECK: store float %[[VAL_32]], ptr %[[VAL_34]], align 4 +// CHECK: br label %[[VAL_21]], !llvm.loop !{{[0-9]}} +// CHECK: loop2.loop_exit: ; preds = %[[VAL_21]] +// CHECK: br label %[[VAL_12]], !llvm.loop !{{[0-9]}} +// CHECK: loop0.loop_exit: ; preds = %[[VAL_12]] +// CHECK-PTX: call void @llvm.nvvm.barrier0() +// CHECK-GCN: call void @llvm.amdgcn.s.barrier() +// CHECK: store i32 %thread.id.0, ptr{{.*}} %[[VAL_1]], align 4 +// CHECK: br label %[[VAL_35:.*]] +// CHECK: loop0.loop_header4: ; preds = %[[VAL_36:.*]], %[[VAL_17]] +// CHECK: %[[VAL_37:.*]] = load i32, ptr{{.*}} %[[VAL_1]], align 4 +// CHECK: %[[VAL_38:.*]] = icmp uge i32 %[[VAL_37]], %tile_bound.2 +// CHECK: br i1 %[[VAL_38]], label %[[VAL_39:.*]], label %[[VAL_40:.*]] +// CHECK: loop0.loop_body5: ; preds = %[[VAL_35]] +// CHECK: %[[VAL_41:.*]] = add nuw nsw i32 %[[VAL_37]], 4 +// CHECK: store i32 %[[VAL_41]], ptr{{.*}} %[[VAL_1]], align 4 +// CHECK: %[[VAL_42:.*]] = icmp eq i32 %[[VAL_37]], %thread.id.0 +// CHECK: store i32 %thread.id.2, ptr{{.*}} %[[VAL_0]], align 4 +// CHECK: br label %[[VAL_43:.*]] +// CHECK: loop2.loop_header10: ; preds = %[[VAL_44:.*]], %[[VAL_40]] +// CHECK: %[[VAL_45:.*]] = load i32, ptr{{.*}} %[[VAL_0]], align 4 +// CHECK: %[[VAL_46:.*]] = icmp uge i32 %[[VAL_45]], %tile_bound.0 +// CHECK: br i1 %[[VAL_46]], label %[[VAL_36]], label %[[VAL_44]] +// CHECK: loop2.loop_body11: ; preds = %[[VAL_43]] +// CHECK: %[[VAL_47:.*]] = add nuw nsw i32 %[[VAL_45]], 32 +// CHECK: store i32 %[[VAL_47]], ptr{{.*}} %[[VAL_0]], align 4 +// CHECK: %[[VAL_48:.*]] = icmp eq i32 %[[VAL_45]], %thread.id.2 +// CHECK: %[[VAL_49:.*]] = add i32 %tile_origin.2, %[[VAL_37]] +// CHECK: %[[VAL_50:.*]] = add i32 %tile_origin.1, 0 +// CHECK: %[[VAL_51:.*]] = add i32 %tile_origin.0, %[[VAL_45]] +// CHECK: %[[VAL_52:.*]] = getelementptr inbounds [32 x [1 x [33 x float]]], ptr addrspace(3) @tr_tile_0, i32 0, i32 %[[VAL_45]], i32 0, i32 %[[VAL_37]] +// CHECK: %[[VAL_53:.*]] = addrspacecast ptr addrspace(3) %[[VAL_52]] to ptr +// CHECK: %[[VAL_54:.*]] = load float, ptr{{.*}} %[[VAL_53]], align 4 +// CHECK: %[[VAL_55:.*]] = getelementptr inbounds [65 x [49 x [33 x float]]], ptr %[[VAL_56:.*]], i32 0, i32 %[[VAL_49]], i32 %[[VAL_50]], i32 %[[VAL_51]] +// CHECK: store float %[[VAL_54]], ptr %[[VAL_55]], align 4 +// CHECK: br label %[[VAL_43]], !llvm.loop !{{[0-9]}} +// CHECK: loop2.loop_exit9: ; preds = %[[VAL_43]] +// CHECK: br label %[[VAL_35]], !llvm.loop !{{[0-9]}} +// CHECK: loop0.loop_exit3: ; preds = %[[VAL_35]] // CHECK: ret void -// CHECK: early_return: ; preds = %[[VAL_13]] -// CHECK: ret void -// CHECK: x_in_tile-true: ; preds = %[[VAL_42]] -// CHECK: %[[TILE_1:.*]] = add i32 %[[VAL_26]], 0 -// CHECK: %[[VAL_71:.*]] = add i32 %[[VAL_27]], %[[VAL_31]] -// CHECK: %[[VAL_72:.*]] = add i32 %[[VAL_28]], %[[VAL_47]] -// CHECK: %[[VAL_73:.*]] = getelementptr inbounds [33 x [49 x [65 x float]]], ptr %[[VAL_74:.*]], i32 0, i32 %[[VAL_71]], i32 %[[TILE_1]], i32 %[[VAL_72]] -// CHECK: %[[VAL_75:.*]] = load float, ptr %[[VAL_73]], align 4, !invariant.load !10 -// CHECK: %[[VAL_76:.*]] = getelementptr inbounds [1 x [32 x [33 x float]]], ptr addrspace(3) @tr_tile_0, i32 0, i32 0, i32 %[[VAL_31]], i32 %[[VAL_47]] -// CHECK: %[[VAL_77:.*]] = addrspacecast ptr addrspace(3) %[[VAL_76]] to ptr -// CHECK: store float %[[VAL_75]], ptr %[[VAL_77]], align 4 -// CHECK: br label %[[VAL_39]] -// CHECK: x_in_tile-true16: ; preds = %[[VAL_63]] -// CHECK: %[[TILE_1:.*]] = add i32 %[[VAL_26]], 0 -// CHECK: %[[VAL_78:.*]] = add i32 %[[VAL_28]], %[[VAL_52]] -// CHECK: %[[VAL_79:.*]] = add i32 %[[VAL_27]], %[[VAL_68]] -// CHECK: %[[VAL_80:.*]] = getelementptr inbounds [1 x [32 x [33 x float]]], ptr addrspace(3) @tr_tile_0, i32 0, i32 0, i32 %[[VAL_68]], i32 %[[VAL_52]] -// CHECK: %[[VAL_81:.*]] = addrspacecast ptr addrspace(3) %[[VAL_80]] to ptr -// CHECK: %[[VAL_82:.*]] = load float, ptr %[[VAL_81]], align 4 -// CHECK: %[[VAL_83:.*]] = getelementptr inbounds [65 x [49 x [33 x float]]], ptr %[[VAL_84:.*]], i32 0, i32 %[[VAL_78]], i32 %[[TILE_1]], i32 %[[VAL_79]] -// CHECK: store float %[[VAL_82]], ptr %[[VAL_83]], align 4 -// CHECK: br label %[[VAL_60]] diff --git a/third_party/xla/xla/service/gpu/tests/transpose_210_extra_output.hlo b/third_party/xla/xla/service/gpu/tests/transpose_210_extra_output.hlo index 3fff5e9bbab17d..9581099deed5c4 100644 --- a/third_party/xla/xla/service/gpu/tests/transpose_210_extra_output.hlo +++ b/third_party/xla/xla/service/gpu/tests/transpose_210_extra_output.hlo @@ -20,119 +20,94 @@ ENTRY main { // CHECK: %[[VAL_1:.*]] = alloca i32, align 4 // CHECK: %[[VAL_2:.*]] = alloca i32, align 4 // CHECK: %[[VAL_3:.*]] = alloca i32, align 4 -// CHECK-PTX: %[[VAL_4:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2 -// CHECK-GCN: %[[VAL_4:.*]] = call i32 @llvm.amdgcn.workitem.id.x -// CHECK-PTX: %[[VAL_5:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !3 -// CHECK-GCN: %[[VAL_5:.*]] = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK: %[[VAL_6:.*]] = urem i32 %[[VAL_4]], 128 -// CHECK: %[[VAL_7:.*]] = udiv i32 %[[VAL_4]], 128 -// CHECK: %[[VAL_8:.*]] = mul i32 %[[VAL_5]], 1 -// CHECK: %[[VAL_9:.*]] = add i32 %[[VAL_8]], %[[VAL_7]] -// CHECK: %[[VAL_10:.*]] = icmp ult i32 %[[VAL_9]], 294 -// CHECK: br i1 %[[VAL_10]], label %[[VAL_11:.*]], label %[[VAL_12:.*]] -// CHECK: 7: ; preds = %[[VAL_13:.*]] -// CHECK: %[[VAL_15:.*]] = udiv i32 %[[VAL_6]], 32 -// CHECK: %[[VAL_14:.*]] = urem i32 %[[VAL_6]], 32 -// CHECK: %[[VAL_37:.*]] = mul i32 %[[VAL_14]], 1 -// CHECK: %[[VAL_16:.*]] = urem i32 %[[VAL_6]], 32 -// CHECK: %[[VAL_17:.*]] = udiv i32 %[[VAL_9]], 1 -// CHECK: %[[VAL_18:.*]] = urem i32 %[[VAL_17]], 3 -// CHECK: %[[VAL_19:.*]] = udiv i32 %[[VAL_9]], 3 -// CHECK: %[[VAL_20:.*]] = urem i32 %[[VAL_19]], 2 -// CHECK: %[[VAL_21:.*]] = udiv i32 %[[VAL_9]], 6 -// CHECK: %[[VAL_22:.*]] = icmp eq i32 %[[VAL_20]], 1 -// CHECK: %[[VAL_23:.*]] = select i1 %[[VAL_22]], i32 1, i32 32 -// CHECK: %[[VAL_24:.*]] = icmp eq i32 %[[VAL_18]], 2 -// CHECK: %[[VAL_25:.*]] = select i1 %[[VAL_24]], i32 1, i32 32 -// CHECK: %[[VAL_26:.*]] = mul i32 %[[VAL_21]], 1 -// CHECK: %[[VAL_27:.*]] = mul i32 %[[VAL_20]], 32 -// CHECK: %[[VAL_28:.*]] = mul i32 %[[VAL_18]], 32 -// CHECK: store i32 %[[VAL_15]], ptr %[[VAL_3]], align 4 -// CHECK: br label %[[VAL_29:.*]] -// CHECK: loop1.loop_header: ; preds = %[[VAL_30:.*]], %[[VAL_11]] -// CHECK: %[[VAL_31:.*]] = load i32, ptr %[[VAL_3]], align 4 -// CHECK: %[[VAL_32:.*]] = icmp uge i32 %[[VAL_31]], %[[VAL_23]] -// CHECK: br i1 %[[VAL_32]], label %[[VAL_33:.*]], label %[[VAL_34:.*]] -// CHECK: loop1.loop_body: ; preds = %[[VAL_29]] -// CHECK: %[[VAL_35:.*]] = add nuw nsw i32 %[[VAL_31]], 4 -// CHECK: store i32 %[[VAL_35]], ptr %[[VAL_3]], align 4 -// CHECK: %[[VAL_36:.*]] = icmp eq i32 %[[VAL_31]], %[[VAL_15]] -// CHECK: store i32 0, ptr %[[VAL_2]], align 4 -// CHECK: br label %[[VAL_38:.*]] -// CHECK: loop2.loop_header: ; preds = %[[VAL_39:.*]], %[[VAL_34]] -// CHECK: %[[VAL_40:.*]] = load i32, ptr %[[VAL_2]], align 4 -// CHECK: %[[VAL_41:.*]] = icmp uge i32 %[[VAL_40]], 1 -// CHECK: br i1 %[[VAL_41]], label %[[VAL_30]], label %[[VAL_42:.*]] -// CHECK: loop2.loop_body: ; preds = %[[VAL_38]] -// CHECK: %[[VAL_43:.*]] = add nuw nsw i32 %[[VAL_40]], 1 -// CHECK: store i32 %[[VAL_43]], ptr %[[VAL_2]], align 4 -// CHECK: %[[VAL_44:.*]] = icmp eq i32 %[[VAL_40]], 0 -// CHECK: %[[VAL_45:.*]] = mul i32 %[[VAL_40]], 32 -// CHECK: %[[VAL_46:.*]] = add i32 %[[VAL_45]], 0 -// CHECK: %[[VAL_47:.*]] = add i32 %[[VAL_46]], %[[VAL_37]] -// CHECK: %[[VAL_48:.*]] = icmp ult i32 %[[VAL_47]], %[[VAL_25]] -// CHECK: br i1 %[[VAL_48]], label %[[VAL_49:.*]], label %[[VAL_39]] -// CHECK: x_in_tile-after: ; preds = %[[VAL_49]], %[[VAL_42]] -// CHECK: br label %[[VAL_38]], !llvm.loop !4 -// CHECK: loop2.loop_exit: ; preds = %[[VAL_38]] -// CHECK: br label %[[VAL_29]], !llvm.loop !7 -// CHECK: loop1.loop_exit: ; preds = %[[VAL_29]] -// CHECK: call void @llvm.nvvm.barrier0() -// CHECK: store i32 %[[VAL_15]], ptr %[[VAL_1]], align 4 -// CHECK: br label %[[VAL_50:.*]] -// CHECK: loop1.loop_header7: ; preds = %[[VAL_51:.*]], %[[VAL_33]] -// CHECK: %[[VAL_52:.*]] = load i32, ptr %[[VAL_1]], align 4 -// CHECK: %[[VAL_53:.*]] = icmp uge i32 %[[VAL_52]], %[[VAL_25]] -// CHECK: br i1 %[[VAL_53]], label %[[VAL_54:.*]], label %[[VAL_55:.*]] -// CHECK: loop1.loop_body8: ; preds = %[[VAL_50]] -// CHECK: %[[VAL_56:.*]] = add nuw nsw i32 %[[VAL_52]], 4 -// CHECK: store i32 %[[VAL_56]], ptr %[[VAL_1]], align 4 -// CHECK: %[[VAL_57:.*]] = icmp eq i32 %[[VAL_52]], %[[VAL_15]] -// CHECK: store i32 0, ptr %[[VAL_0]], align 4 -// CHECK: br label %[[VAL_59:.*]] -// CHECK: loop2.loop_header13: ; preds = %[[VAL_60:.*]], %[[VAL_55]] -// CHECK: %[[VAL_61:.*]] = load i32, ptr %[[VAL_0]], align 4 -// CHECK: %[[VAL_62:.*]] = icmp uge i32 %[[VAL_61]], 1 -// CHECK: br i1 %[[VAL_62]], label %[[VAL_51]], label %[[VAL_63:.*]] -// CHECK: loop2.loop_body14: ; preds = %[[VAL_59]] -// CHECK: %[[VAL_64:.*]] = add nuw nsw i32 %[[VAL_61]], 1 -// CHECK: store i32 %[[VAL_64]], ptr %[[VAL_0]], align 4 -// CHECK: %[[VAL_65:.*]] = icmp eq i32 %[[VAL_61]], 0 -// CHECK: %[[VAL_66:.*]] = mul i32 %[[VAL_61]], 32 -// CHECK: %[[VAL_67:.*]] = add i32 %[[VAL_66]], 0 -// CHECK: %[[VAL_68:.*]] = add i32 %[[VAL_67]], %[[VAL_37]] -// CHECK: %[[VAL_69:.*]] = icmp ult i32 %[[VAL_68]], %[[VAL_23]] -// CHECK: br i1 %[[VAL_69]], label %[[VAL_70:.*]], label %[[VAL_60]] -// CHECK: x_in_tile-after19: ; preds = %[[VAL_70]], %[[VAL_63]] -// CHECK: br label %[[VAL_59]], !llvm.loop !8 -// CHECK: loop2.loop_exit12: ; preds = %[[VAL_59]] -// CHECK: br label %[[VAL_50]], !llvm.loop !9 -// CHECK: loop1.loop_exit6: ; preds = %[[VAL_50]] +// CHECK-PTX: %thread.id.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2 +// CHECK-GCN: %thread.id.x = call i32 @llvm.amdgcn.workitem.id.x +// CHECK-PTX: %block.id.x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !3 +// CHECK-GCN: %block.id.x = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK: %thread.id.0 = udiv i32 %thread.id.x, 32 +// CHECK: %thread.id.2 = urem i32 %thread.id.x, 32 +// CHECK: %lane_id = urem i32 %thread.id.x, 32 +// CHECK: %[[VAL_5:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_6:.*]] = urem i32 %[[VAL_5]], 3 +// CHECK: %[[VAL_7:.*]] = udiv i32 %block.id.x, 3 +// CHECK: %[[VAL_8:.*]] = urem i32 %[[VAL_7]], 49 +// CHECK: %[[VAL_9:.*]] = udiv i32 %block.id.x, 147 +// CHECK: %[[VAL_10:.*]] = icmp eq i32 %[[VAL_9]], 1 +// CHECK: %tile_bound.0 = select i1 %[[VAL_10]], i32 1, i32 32 +// CHECK: %[[VAL_11:.*]] = icmp eq i32 %[[VAL_6]], 2 +// CHECK: %tile_bound.2 = select i1 %[[VAL_11]], i32 1, i32 32 +// CHECK: %tile_origin.0 = mul i32 %[[VAL_9]], 32 +// CHECK: %tile_origin.1 = mul i32 %[[VAL_8]], 1 +// CHECK: %tile_origin.2 = mul i32 %[[VAL_6]], 32 +// CHECK: store i32 %thread.id.0, ptr{{.*}} %[[VAL_3]], align 4 +// CHECK: br label %[[VAL_12:.*]] +// CHECK: loop0.loop_header: ; preds = %[[VAL_13:.*]], %[[VAL_14:.*]] +// CHECK: %[[VAL_15:.*]] = load i32, ptr{{.*}} %[[VAL_3]], align 4 +// CHECK: %[[VAL_16:.*]] = icmp uge i32 %[[VAL_15]], %tile_bound.0 +// CHECK: br i1 %[[VAL_16]], label %[[VAL_17:.*]], label %[[VAL_18:.*]] +// CHECK: loop0.loop_body: ; preds = %[[VAL_12]] +// CHECK: %[[VAL_19:.*]] = add nuw nsw i32 %[[VAL_15]], 4 +// CHECK: store i32 %[[VAL_19]], ptr{{.*}} %[[VAL_3]], align 4 +// CHECK: %[[VAL_20:.*]] = icmp eq i32 %[[VAL_15]], %thread.id.0 +// CHECK: store i32 %thread.id.2, ptr{{.*}} %[[VAL_2]], align 4 +// CHECK: br label %[[VAL_21:.*]] +// CHECK: loop2.loop_header: ; preds = %[[VAL_22:.*]], %[[VAL_18]] +// CHECK: %[[VAL_23:.*]] = load i32, ptr{{.*}} %[[VAL_2]], align 4 +// CHECK: %[[VAL_24:.*]] = icmp uge i32 %[[VAL_23]], %tile_bound.2 +// CHECK: br i1 %[[VAL_24]], label %[[VAL_13]], label %[[VAL_22]] +// CHECK: loop2.loop_body: ; preds = %[[VAL_21]] +// CHECK: %[[VAL_25:.*]] = add nuw nsw i32 %[[VAL_23]], 32 +// CHECK: store i32 %[[VAL_25]], ptr{{.*}} %[[VAL_2]], align 4 +// CHECK: %[[VAL_26:.*]] = icmp eq i32 %[[VAL_23]], %thread.id.2 +// CHECK: %[[VAL_27:.*]] = add i32 %tile_origin.0, %[[VAL_15]] +// CHECK: %[[VAL_28:.*]] = add i32 %tile_origin.1, 0 +// CHECK: %[[VAL_29:.*]] = add i32 %tile_origin.2, %[[VAL_23]] +// CHECK: %[[VAL_30:.*]] = getelementptr inbounds [33 x [49 x [65 x float]]], ptr{{.*}} %[[VAL_31:.*]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_29]] +// CHECK: %[[VAL_32:.*]] = load float, ptr{{.*}} %[[VAL_30]], align 4, !invariant.load !{{[0-9]}} +// CHECK: %[[VAL_33:.*]] = getelementptr inbounds [32 x [1 x [33 x float]]], ptr{{.*}} addrspace(3) @tr_tile_0, i32 0, i32 %[[VAL_15]], i32 0, i32 %[[VAL_23]] +// CHECK: %[[VAL_34:.*]] = addrspacecast ptr{{.*}} addrspace(3) %[[VAL_33]] to ptr +// CHECK: store float %[[VAL_32]], ptr{{.*}} %[[VAL_34]], align 4 +// CHECK: %[[VAL_35:.*]] = getelementptr inbounds [33 x [49 x [65 x float]]], ptr{{.*}} %[[VAL_31]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_29]] +// CHECK: %[[VAL_36:.*]] = load float, ptr{{.*}} %[[VAL_35]], align 4, !invariant.load !{{[0-9]}} +// CHECK: %[[VAL_37:.*]] = fneg float %[[VAL_36]] +// CHECK: %[[VAL_38:.*]] = getelementptr inbounds [33 x [49 x [65 x float]]], ptr{{.*}} %[[VAL_39:.*]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_29]] +// CHECK: store float %[[VAL_37]], ptr{{.*}} %[[VAL_38]], align 4 +// CHECK: br label %[[VAL_21]], !llvm.loop !{{[0-9]}} +// CHECK: loop2.loop_exit: ; preds = %[[VAL_21]] +// CHECK: br label %[[VAL_12]], !llvm.loop !{{[0-9]}} +// CHECK: loop0.loop_exit: ; preds = %[[VAL_12]] +// CHECK-PTX: call void @llvm.nvvm.barrier0() +// CHECK-GCN: call void @llvm.amdgcn.s.barrier() +// CHECK: store i32 %thread.id.0, ptr{{.*}} %[[VAL_1]], align 4 +// CHECK: br label %[[VAL_40:.*]] +// CHECK: loop0.loop_header6: ; preds = %[[VAL_41:.*]], %[[VAL_17]] +// CHECK: %[[VAL_42:.*]] = load i32, ptr{{.*}} %[[VAL_1]], align 4 +// CHECK: %[[VAL_43:.*]] = icmp uge i32 %[[VAL_42]], %tile_bound.2 +// CHECK: br i1 %[[VAL_43]], label %[[VAL_44:.*]], label %[[VAL_45:.*]] +// CHECK: loop0.loop_body7: ; preds = %[[VAL_40]] +// CHECK: %[[VAL_46:.*]] = add nuw nsw i32 %[[VAL_42]], 4 +// CHECK: store i32 %[[VAL_46]], ptr{{.*}} %[[VAL_1]], align 4 +// CHECK: %[[VAL_47:.*]] = icmp eq i32 %[[VAL_42]], %thread.id.0 +// CHECK: store i32 %thread.id.2, ptr{{.*}} %[[VAL_0]], align 4 +// CHECK: br label %[[VAL_48:.*]] +// CHECK: loop2.loop_header12: ; preds = %[[VAL_49:.*]], %[[VAL_45]] +// CHECK: %[[VAL_50:.*]] = load i32, ptr{{.*}} %[[VAL_0]], align 4 +// CHECK: %[[VAL_51:.*]] = icmp uge i32 %[[VAL_50]], %tile_bound.0 +// CHECK: br i1 %[[VAL_51]], label %[[VAL_41]], label %[[VAL_49]] +// CHECK: loop2.loop_body13: ; preds = %[[VAL_48]] +// CHECK: %[[VAL_52:.*]] = add nuw nsw i32 %[[VAL_50]], 32 +// CHECK: store i32 %[[VAL_52]], ptr{{.*}} %[[VAL_0]], align 4 +// CHECK: %[[VAL_53:.*]] = icmp eq i32 %[[VAL_50]], %thread.id.2 +// CHECK: %[[VAL_54:.*]] = add i32 %tile_origin.2, %[[VAL_42]] +// CHECK: %[[VAL_55:.*]] = add i32 %tile_origin.1, 0 +// CHECK: %[[VAL_56:.*]] = add i32 %tile_origin.0, %[[VAL_50]] +// CHECK: %[[VAL_57:.*]] = getelementptr inbounds [32 x [1 x [33 x float]]], ptr{{.*}} addrspace(3) @tr_tile_0, i32 0, i32 %[[VAL_50]], i32 0, i32 %[[VAL_42]] +// CHECK: %[[VAL_58:.*]] = addrspacecast ptr{{.*}} addrspace(3) %[[VAL_57]] to ptr +// CHECK: %[[VAL_59:.*]] = load float, ptr{{.*}} %[[VAL_58]], align 4 +// CHECK: %[[VAL_60:.*]] = getelementptr inbounds [65 x [49 x [33 x float]]], ptr{{.*}} %[[VAL_61:.*]], i32 0, i32 %[[VAL_54]], i32 %[[VAL_55]], i32 %[[VAL_56]] +// CHECK: store float %[[VAL_59]], ptr{{.*}} %[[VAL_60]], align 4 +// CHECK: br label %[[VAL_48]], !llvm.loop !{{[0-9]}} +// CHECK: loop2.loop_exit11: ; preds = %[[VAL_48]] +// CHECK: br label %[[VAL_40]], !llvm.loop !{{[0-9]}} +// CHECK: loop0.loop_exit5: ; preds = %[[VAL_40]] // CHECK: ret void -// CHECK: early_return: ; preds = %[[VAL_13]] -// CHECK: ret void -// CHECK: x_in_tile-true: ; preds = %[[VAL_42]] -// CHECK: %[[TILE_1:.*]] = add i32 %[[VAL_26]], 0 -// CHECK: %[[VAL_71:.*]] = add i32 %[[VAL_27]], %[[VAL_31]] -// CHECK: %[[VAL_72:.*]] = add i32 %[[VAL_28]], %[[VAL_47]] -// CHECK: %[[VAL_73:.*]] = getelementptr inbounds [33 x [49 x [65 x float]]], ptr %[[VAL_74:.*]], i32 0, i32 %[[VAL_71]], i32 %[[TILE_1]], i32 %[[VAL_72]] -// CHECK: %[[VAL_75:.*]] = load float, ptr %[[VAL_73]], align 4, !invariant.load !10 -// CHECK: %[[VAL_76:.*]] = getelementptr inbounds [1 x [32 x [33 x float]]], ptr addrspace(3) @tr_tile_0, i32 0, i32 0, i32 %[[VAL_31]], i32 %[[VAL_47]] -// CHECK: %[[VAL_77:.*]] = addrspacecast ptr addrspace(3) %[[VAL_76]] to ptr -// CHECK: store float %[[VAL_75]], ptr %[[VAL_77]], align 4 -// CHECK: %[[VAL_78:.*]] = getelementptr inbounds [33 x [49 x [65 x float]]], ptr %[[VAL_74]], i32 0, i32 %[[VAL_71]], i32 %[[TILE_1]], i32 %[[VAL_72]] -// CHECK: %[[VAL_79:.*]] = load float, ptr %[[VAL_78]], align 4, !invariant.load !10 -// CHECK: %[[VAL_80:.*]] = fneg float %[[VAL_79]] -// CHECK: %[[VAL_81:.*]] = getelementptr inbounds [33 x [49 x [65 x float]]], ptr %[[VAL_82:.*]], i32 0, i32 %[[VAL_71]], i32 %[[TILE_1]], i32 %[[VAL_72]] -// CHECK: store float %[[VAL_80]], ptr %[[VAL_81]], align 4 -// CHECK: br label %[[VAL_39]] -// CHECK: x_in_tile-true18: ; preds = %[[VAL_63]] -// CHECK: %[[TILE_1:.*]] = add i32 %[[VAL_26]], 0 -// CHECK: %[[VAL_83:.*]] = add i32 %[[VAL_28]], %[[VAL_52]] -// CHECK: %[[VAL_84:.*]] = add i32 %[[VAL_27]], %[[VAL_68]] -// CHECK: %[[VAL_85:.*]] = getelementptr inbounds [1 x [32 x [33 x float]]], ptr addrspace(3) @tr_tile_0, i32 0, i32 0, i32 %[[VAL_68]], i32 %[[VAL_52]] -// CHECK: %[[VAL_86:.*]] = addrspacecast ptr addrspace(3) %[[VAL_85]] to ptr -// CHECK: %[[VAL_87:.*]] = load float, ptr %[[VAL_86]], align 4 -// CHECK: %[[VAL_88:.*]] = getelementptr inbounds [65 x [49 x [33 x float]]], ptr %[[VAL_89:.*]], i32 0, i32 %[[VAL_83]], i32 %[[TILE_1]], i32 %[[VAL_84]] -// CHECK: store float %[[VAL_87]], ptr %[[VAL_88]], align 4 -// CHECK: br label %[[VAL_60]] diff --git a/third_party/xla/xla/service/gpu/tests/triton_naming.hlo b/third_party/xla/xla/service/gpu/tests/triton_naming.hlo index a954c4c46561ad..2739e349181786 100644 --- a/third_party/xla/xla/service/gpu/tests/triton_naming.hlo +++ b/third_party/xla/xla/service/gpu/tests/triton_naming.hlo @@ -15,5 +15,5 @@ HloModule t, is_scheduled=true, entry_computation_layout={(f16[15,19]{1,0},s8[19 ENTRY %e (p0: f16[15,19], p1: s8[19,17]) -> f16[15,17] { %p1 = s8[19,17]{1,0} parameter(1) %p0 = f16[15,19]{1,0} parameter(0) - ROOT %triton_gemm_r = f16[15,17]{1,0} fusion(%p1, %p0), kind=kCustom, calls=%triton_gemm_r, backend_config="{ \"fusion_backend_config\": {kind: \"__triton_gemm\", triton_gemm_config: {\"block_m\":\"64\",\"block_n\":\"32\",\"block_k\":\"64\",\"split_k\":\"1\",\"num_stages\":\"2\",\"num_warps\":\"8\"}}}" + ROOT %triton_gemm_r = f16[15,17]{1,0} fusion(%p1, %p0), kind=kCustom, calls=%triton_gemm_r, backend_config="{ \"fusion_backend_config\": {kind: \"__triton_gemm\", triton_gemm_config: {\"block_m\":\"64\",\"block_n\":\"32\",\"block_k\":\"64\",\"split_k\":\"1\",\"num_stages\":\"2\",\"num_warps\":\"8\",\"num_ctas\":\"1\"}}}" } diff --git a/third_party/xla/xla/service/gpu/thunk.cc b/third_party/xla/xla/service/gpu/thunk.cc index 1318ff18993ec0..c6dcfbb3ed5143 100644 --- a/third_party/xla/xla/service/gpu/thunk.cc +++ b/third_party/xla/xla/service/gpu/thunk.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/gpu/thunk.h" +#include #include #include #include @@ -33,8 +34,10 @@ limitations under the License. #include "xla/executable_run_options.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/global_device_id.h" +#include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/gpu_executable_run_options.h" +#include "xla/service/gpu/nccl_api.h" #include "xla/service/gpu/nccl_clique.h" #include "xla/service/gpu/nccl_clique_key.h" #include "xla/service/service_executable_run_options.h" @@ -49,10 +52,11 @@ namespace gpu { // Thunk::CollectiveCliques //===----------------------------------------------------------------------===// -Thunk::CollectiveCliques::CollectiveCliques(CliquesMap cliques_map) +Thunk::CollectiveCliques::CollectiveCliques( + NcclClique::AcquiredCliquesMap cliques_map) : cliques_map_(std::move(cliques_map)) {} -absl::StatusOr Thunk::CollectiveCliques::GetComm( +absl::StatusOr Thunk::CollectiveCliques::GetComm( const NcclCliqueKey& clique_key, int32_t rank) const { // Check that we locked access to a clique for `clique_key`. auto clique = cliques_map_.find(clique_key); @@ -69,7 +73,19 @@ absl::StatusOr Thunk::CollectiveCliques::GetComm( clique_key.ToString())); } - return (*communicator)->Acquire(); + return *communicator; +} + +absl::StatusOr Thunk::CollectiveCliques::num_communicators( + const NcclCliqueKey& clique_key) const { + // Check that we locked access to a clique for `clique_key`. + auto clique = cliques_map_.find(clique_key); + if (clique == cliques_map_.end()) { + return absl::NotFoundError(absl::StrCat("No clique found for clique key: ", + clique_key.ToString())); + } + + return (*clique->second)->num_communicators(); } //===----------------------------------------------------------------------===// @@ -98,7 +114,8 @@ static absl::StatusOr GetGlobalDeviceId( absl::StatusOr Thunk::CollectiveExecuteParams::Create( const ServiceExecutableRunOptions& run_options, - int64_t local_device_ordinal) { + int64_t local_device_ordinal, int64_t collective_max_nchannels, + int64_t p2p_max_nchannels) { const GpuExecutableRunOptions* gpu_options = run_options.run_options().gpu_executable_run_options(); @@ -113,23 +130,28 @@ Thunk::CollectiveExecuteParams::Create( TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id, GetGlobalDeviceId(device_id_map, local_device_ordinal)); - return CollectiveExecuteParams(run_options.run_options().run_id(), - local_device_ordinal, global_device_id, - run_options.run_options().device_assignment(), - device_id_map, nccl_callback); + return CollectiveExecuteParams( + run_options.stream()->parent(), run_options.run_options().run_id(), + local_device_ordinal, global_device_id, + run_options.run_options().device_assignment(), device_id_map, + nccl_callback, collective_max_nchannels, p2p_max_nchannels); } Thunk::CollectiveExecuteParams::CollectiveExecuteParams( - RunId run_id, int64_t local_device_ordinal, GlobalDeviceId global_device_id, - const DeviceAssignment* device_assn, + se::StreamExecutor* executor, RunId run_id, int64_t local_device_ordinal, + GlobalDeviceId global_device_id, const DeviceAssignment* device_assn, const GlobalDeviceIdMap* global_device_id_map, - const NcclCliqueIdCallback* nccl_clique_id_callback) - : run_id(run_id), + const NcclCliqueIdCallback* nccl_clique_id_callback, + int64_t collective_max_nchannels, int64_t p2p_max_nchannels) + : executor(executor), + run_id(run_id), local_device_ordinal(local_device_ordinal), global_device_id(global_device_id), device_assn(device_assn), global_device_id_map(global_device_id_map), - nccl_clique_id_callback(nccl_clique_id_callback) {} + nccl_clique_id_callback(nccl_clique_id_callback), + collective_max_nchannels(collective_max_nchannels), + p2p_max_nchannels(p2p_max_nchannels) {} //===----------------------------------------------------------------------===// // Thunk::ExecuteParams @@ -141,14 +163,16 @@ Thunk::ExecuteParams Thunk::ExecuteParams::Create( se::Stream* command_buffer_trace_stream, absl::Span async_streams, CollectiveExecuteParams* collective_params, - CollectiveCliques* collective_cliques) { + CollectiveCliques* collective_cliques, + ExecutionStreamIdMap additional_compute_streams) { return ExecuteParams(&buffer_allocations, stream, command_buffer_trace_stream, {async_streams.begin(), async_streams.end()}, collective_params, collective_cliques, run_options.run_options().device_to_host_stream(), run_options.run_options().host_to_device_stream(), run_options.run_options().send_device_memory_function(), - run_options.run_options().recv_device_memory_function()); + run_options.run_options().recv_device_memory_function(), + additional_compute_streams); } Thunk::ExecuteParams::ExecuteParams( @@ -159,7 +183,8 @@ Thunk::ExecuteParams::ExecuteParams( CollectiveCliques* collective_cliques, se::Stream* device_to_host_stream, se::Stream* host_to_device_stream, SendDeviceMemoryFunction* send_device_memory_function, - RecvDeviceMemoryFunction* recv_device_memory_function) + RecvDeviceMemoryFunction* recv_device_memory_function, + ExecutionStreamIdMap additional_compute_streams) : buffer_allocations(buffer_allocations), stream(stream), command_buffer_trace_stream(command_buffer_trace_stream), @@ -169,7 +194,8 @@ Thunk::ExecuteParams::ExecuteParams( device_to_host_stream(device_to_host_stream), host_to_device_stream(host_to_device_stream), send_device_memory_function(send_device_memory_function), - recv_device_memory_function(recv_device_memory_function) {} + recv_device_memory_function(recv_device_memory_function), + additional_compute_streams(additional_compute_streams) {} //===----------------------------------------------------------------------===// @@ -225,7 +251,21 @@ Thunk::ExecuteParams::ExecuteParams( CASE(kTriangularSolve); CASE(kWhile); CASE(kFusedMHA); + CASE(kWaitForStreams); + } +} + +/*static*/ +absl::StatusOr Thunk::GetStreamForExecution( + ExecutionStreamId stream_id, const ExecuteParams& params) { + if (stream_id == GetMainComputeStreamId()) { + return params.stream; } + auto iter = params.additional_compute_streams.find(stream_id); + if (iter == params.additional_compute_streams.end()) { + return absl::InvalidArgumentError("Invalid execution stream id."); + } + return iter->second; } std::ostream& operator<<(std::ostream& os, Thunk::Kind kind) { @@ -270,16 +310,21 @@ bool IsReductionCollective(Thunk::Kind kind) { Thunk::ThunkInfo Thunk::ThunkInfo::WithProfileAnnotation(mlir::Operation* op) { ThunkInfo thunk_info(op); - thunk_info.profile_annotation = absl::StrFormat( - "Thunk:#hlo_op=%s#", mlir::mhlo::GetDebugNameFromLocation(op->getLoc())); + thunk_info.profile_annotation = + mlir::mhlo::GetDebugNameFromLocation(op->getLoc()); return thunk_info; } Thunk::ThunkInfo Thunk::ThunkInfo::WithProfileAnnotation( const HloInstruction* instr) { ThunkInfo thunk_info(nullptr); - thunk_info.profile_annotation = - absl::StrFormat("Thunk:#hlo_op=%s#", instr->name()); + thunk_info.profile_annotation = instr->name(); + auto gpu_backend_config = instr->backend_config(); + if (gpu_backend_config.ok()) { + thunk_info.execution_stream_id = + std::max(Thunk::GetMainComputeStreamId().value(), + gpu_backend_config->operation_queue_id()); + } return thunk_info; } diff --git a/third_party/xla/xla/service/gpu/thunk.h b/third_party/xla/xla/service/gpu/thunk.h index f0ae399975f61d..25de078d5a381d 100644 --- a/third_party/xla/xla/service/gpu/thunk.h +++ b/third_party/xla/xla/service/gpu/thunk.h @@ -16,11 +16,11 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_THUNK_H_ #define XLA_SERVICE_GPU_THUNK_H_ +#include #include #include #include #include -#include #include #include #include @@ -38,15 +38,19 @@ limitations under the License. #include "xla/service/buffer_assignment.h" #include "xla/service/global_device_id.h" #include "xla/service/gpu/buffer_allocations.h" +#include "xla/service/gpu/nccl_api.h" #include "xla/service/gpu/nccl_clique.h" #include "xla/service/gpu/nccl_clique_key.h" #include "xla/service/service_executable_run_options.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "tsl/lib/gtl/int_type.h" namespace xla { namespace gpu { +TSL_LIB_GTL_DEFINE_INT_TYPE(ExecutionStreamId, int64_t); + // Thunk acts as the bridge between IrEmitter and GpuExecutable. It stores the // metadata IrEmitter generates for GpuExecutable to invoke an HloInstruction. // @@ -75,6 +79,9 @@ namespace gpu { // different threads and coordinate resource acquisition via rendezvous. class Thunk { public: + using ExecutionStreamIdMap = + absl::flat_hash_map; + enum Kind { kCholesky, kConditional, @@ -122,7 +129,8 @@ class Thunk { kSendDone, kTriangularSolve, kWhile, - kFusedMHA + kFusedMHA, + kWaitForStreams }; // TODO(ezhulenev): This should become a part of StreamExecutor library, but @@ -143,6 +151,8 @@ class Thunk { // TODO(b/304613751): This is only needed by the LMHLO. Remove this when // LMHLO is removed from the runtime pipeline. mlir::Operation* op; + + ExecutionStreamId execution_stream_id = Thunk::GetMainComputeStreamId(); }; //===--------------------------------------------------------------------===// @@ -167,19 +177,21 @@ class Thunk { // collected from all thunks at prepare stage. class CollectiveCliques { public: - using CliquesMap = - absl::flat_hash_map>; - CollectiveCliques() = default; - explicit CollectiveCliques(CliquesMap cliques_map); + explicit CollectiveCliques(NcclClique::AcquiredCliquesMap cliques_map); - absl::StatusOr GetComm(const NcclCliqueKey& clique_key, - int32_t rank) const; + absl::StatusOr GetComm( + const NcclCliqueKey& clique_key, int32_t rank) const; + + // Returns the number of communicators in a collective clique. Returns error + // if we do not have an acquired clique for a given key. + absl::StatusOr num_communicators( + const NcclCliqueKey& clique_key) const; bool empty() const { return cliques_map_.empty(); } private: - CliquesMap cliques_map_; + NcclClique::AcquiredCliquesMap cliques_map_; }; //===--------------------------------------------------------------------===// @@ -194,11 +206,14 @@ class Thunk { // missing a global device mapping for a local device ordinal). static absl::StatusOr Create( const ServiceExecutableRunOptions& run_options, - int64_t local_device_ordinal); + int64_t local_device_ordinal, int64_t collective_max_nchannels = 0, + int64_t p2p_max_nchannels = 0); // A mapping from local device ordinals to global device IDs. using GlobalDeviceIdMap = std::map; + se::StreamExecutor* executor; + // XLA execution run id allows us to distinguish collective operations // from different concurrent executions and avoid deadlocks. RunId run_id; @@ -210,12 +225,18 @@ class Thunk { const GlobalDeviceIdMap* global_device_id_map; const NcclCliqueIdCallback* nccl_clique_id_callback; + int64_t collective_max_nchannels; + int64_t p2p_max_nchannels; + private: - CollectiveExecuteParams( - RunId run_id, int64_t local_device_ordinal, - GlobalDeviceId global_device_id, const DeviceAssignment* device_assn, - const GlobalDeviceIdMap* global_device_id_map, - const NcclCliqueIdCallback* nccl_clique_id_callback); + CollectiveExecuteParams(se::StreamExecutor* executor, RunId run_id, + int64_t local_device_ordinal, + GlobalDeviceId global_device_id, + const DeviceAssignment* device_assn, + const GlobalDeviceIdMap* global_device_id_map, + const NcclCliqueIdCallback* nccl_clique_id_callback, + int64_t collective_max_nchannels, + int64_t p2p_max_nchannels); }; //===--------------------------------------------------------------------===// @@ -234,6 +255,9 @@ class Thunk { // InitializeParams //===--------------------------------------------------------------------===// + // TODO(ezhulenev): Merge InitializeParams and ExecuteParams as they have + // almost the same members and tightly coupled. + // Parameters passed to Initialize. At thunk initialization time we do not // launch any "work" on device and only initialize thunks for execution, i.e. // we pre-load kernels on device and instantiate all command buffers. @@ -255,7 +279,7 @@ class Thunk { se::Stream* command_buffer_trace_stream = nullptr; // Parameters for executing collective operations. - const CollectiveExecuteParams* collective_params = nullptr; + CollectiveExecuteParams* collective_params = nullptr; // Collective cliques acquired based on resource requests. CollectiveCliques* collective_cliques = nullptr; @@ -271,13 +295,14 @@ class Thunk { struct ExecuteParams { // Constructs execute parameters from an executable run options. Return // error if run options are misconfigured. - static ExecuteParams Create(const ServiceExecutableRunOptions& run_options, - const BufferAllocations& buffer_allocations, - se::Stream* stream, - se::Stream* command_buffer_trace_stream, - absl::Span async_streams, - CollectiveExecuteParams* collective_params, - CollectiveCliques* collective_cliques); + static ExecuteParams Create( + const ServiceExecutableRunOptions& run_options, + const BufferAllocations& buffer_allocations, se::Stream* stream, + se::Stream* command_buffer_trace_stream, + absl::Span async_streams, + CollectiveExecuteParams* collective_params, + CollectiveCliques* collective_cliques, + ExecutionStreamIdMap additional_compute_streams = {}); const BufferAllocations* buffer_allocations; // never null @@ -306,7 +331,12 @@ class Thunk { SendDeviceMemoryFunction* send_device_memory_function; RecvDeviceMemoryFunction* recv_device_memory_function; + // Additional compute streams on which thunks launch operations. + ExecutionStreamIdMap additional_compute_streams; + private: + friend class CommandBufferThunk; + ExecuteParams(const BufferAllocations* buffer_allocations, se::Stream* stream, se::Stream* command_buffer_trace_stream, absl::InlinedVector async_comms_streams, @@ -315,7 +345,8 @@ class Thunk { se::Stream* device_to_host_stream, se::Stream* host_to_device_stream, SendDeviceMemoryFunction* send_device_memory_function, - RecvDeviceMemoryFunction* recv_device_memory_function); + RecvDeviceMemoryFunction* recv_device_memory_function, + ExecutionStreamIdMap additional_compute_streams = {}); }; //===--------------------------------------------------------------------===// @@ -326,14 +357,15 @@ class Thunk { Thunk(Kind kind, ThunkInfo thunk_info) : kind_(kind), profile_annotation_(thunk_info.profile_annotation), - op_(thunk_info.op) {} + op_(thunk_info.op), + execution_stream_id_(thunk_info.execution_stream_id) {} virtual ~Thunk() = default; Thunk(const Thunk&) = delete; Thunk& operator=(const Thunk&) = delete; virtual std::string ToStringExtra(int indent) const { return ""; } Kind kind() const { return kind_; } - std::string profile_annotation() const { return profile_annotation_; } + std::string_view profile_annotation() const { return profile_annotation_; } // Only valid during compilation, i.e., lowering thunks to kernel-launch // related XLA runtime custom calls). nullptr at runtime. MLIR codegen will @@ -372,10 +404,20 @@ class Thunk { static absl::string_view KindToString(Thunk::Kind kind); + ExecutionStreamId execution_stream_id() const { return execution_stream_id_; } + + static absl::StatusOr GetStreamForExecution( + ExecutionStreamId stream_id, const ExecuteParams& params); + + static ExecutionStreamId GetMainComputeStreamId() { + return ExecutionStreamId(0); + } + private: Kind kind_; std::string profile_annotation_; mlir::Operation* op_; + ExecutionStreamId execution_stream_id_; }; // A sequence of thunks. diff --git a/third_party/xla/xla/service/gpu/runtime/topk_test.cc b/third_party/xla/xla/service/gpu/topk_test.cc similarity index 99% rename from third_party/xla/xla/service/gpu/runtime/topk_test.cc rename to third_party/xla/xla/service/gpu/topk_test.cc index a508aa79b4de9c..28a9caa7bda2ff 100644 --- a/third_party/xla/xla/service/gpu/runtime/topk_test.cc +++ b/third_party/xla/xla/service/gpu/topk_test.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime/topk.h" - #include #include diff --git a/third_party/xla/xla/service/gpu/triton_autotuner.cc b/third_party/xla/xla/service/gpu/triton_autotuner.cc index 550430d79a1d61..be223f7ff5133e 100644 --- a/third_party/xla/xla/service/gpu/triton_autotuner.cc +++ b/third_party/xla/xla/service/gpu/triton_autotuner.cc @@ -74,6 +74,7 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/redzone_allocator.h" #include "xla/stream_executor/stream.h" +#include "xla/tools/hlo_decomposer.h" #include "xla/util.h" #include "xla/xla.pb.h" #include "tsl/lib/core/bits.h" @@ -133,7 +134,7 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { return absl::InternalError(absl::StrCat( "Expect autotune result cache hit for deviceless " "compilation (HLO: ", - hlo->ToString())); + hlo->ToString(), ")")); } return absl::InternalError("Expect autotune result cache hit."); })); @@ -156,8 +157,9 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { // This cannot be the "else" branch of the previous "if". if (backend_config.has_triton_gemm_config()) { - const TritonGemmConfig config = - TritonGemmConfig::FromProto(backend_config.triton_gemm_config()); + TF_ASSIGN_OR_RETURN( + const TritonGemmConfig config, + TritonGemmConfig::FromProto(backend_config.triton_gemm_config())); if (config.split_k > 1) { TF_RETURN_IF_ERROR(MakeDotSplitKBatch(hlo, config)); } @@ -290,23 +292,12 @@ constexpr std::array BLOCK_SIZES = {16, 32, 64, 128, 256, 512}; constexpr std::array NUM_STAGES = {1, 2, 3, 4}; constexpr std::array NUM_WARPS = {2, 4, 8, 16}; constexpr std::array SPLIT_K = {1, 2, 4, 8, 16}; - -// For arch >= Hopper autotuning. -constexpr std::array CLUSTER_DIMS = { - TritonGemmConfig::ClusterDims(1, 1, 1), - TritonGemmConfig::ClusterDims(2, 2, 1), - TritonGemmConfig::ClusterDims(2, 4, 1), - TritonGemmConfig::ClusterDims(4, 2, 1), - TritonGemmConfig::ClusterDims(4, 4, 1), - TritonGemmConfig::ClusterDims(2, 8, 1), - TritonGemmConfig::ClusterDims(8, 2, 1), -}; -constexpr std::array WARP_SPECIALIZATION = {false, true}; - -// Currently we believe that num_ctas is inferable from cluster_dims. -int InferNumCtas(TritonGemmConfig::ClusterDims cluster_dims) { - return cluster_dims.x * cluster_dims.y * cluster_dims.z; -} +// This is the number of blocks per cluster. +// +// Clusters have 3 dimensions (x,y,z) and only 1 <= x*y*z <= 16 are supported. +// Triton doesn't support (3,3,1) and possibly other non-"power of 2" values. +// It's possible that some other values may be(come) supported. +constexpr std::array NUM_CTAS = {1, 2, 4, 8, 16}; std::vector GetExhaustiveMatmulAutotuneConfigs( const HloDotInstruction& dot, @@ -331,9 +322,7 @@ std::vector GetExhaustiveMatmulAutotuneConfigs( continue; } for (int block_n : BLOCK_SIZES) { - // Exclude configs not supported by MMA layout v2. - if (block_n > limit.block_n || - (mma_layout_v2 && (block_m * block_n / 256) % num_warps != 0)) { + if (block_n > limit.block_n) { continue; } for (int block_k : BLOCK_SIZES) { @@ -351,15 +340,14 @@ std::vector GetExhaustiveMatmulAutotuneConfigs( block_m, block_n, block_k, split_k, num_stages, num_warps)); continue; } - // Arch >= Hopper autotuning. - for (bool enable_ws : WARP_SPECIALIZATION) { - for (TritonGemmConfig::ClusterDims cluster_dims : - CLUSTER_DIMS) { - configs.push_back(TritonGemmConfig( - block_m, block_n, block_k, split_k, num_stages, num_warps, - InferNumCtas(cluster_dims), cluster_dims, enable_ws)); - } + // We only want to autotune this if it provides any speedup. So + // please think about that before adding it to the default + // autotuning parameters. + for (int num_ctas : NUM_CTAS) { + configs.push_back(TritonGemmConfig(block_m, block_n, block_k, + split_k, num_stages, + num_warps, num_ctas)); } } } @@ -398,14 +386,13 @@ std::vector GetFixedMatmulAutotuneConfigs( std::back_inserter(configs)); } if (compute_capability.IsAtLeast(se::CudaComputeCapability::HOPPER)) { - configs.erase( - std::remove_if(configs.begin(), configs.end(), - [](const Config& config) { - return (config.block_m * config.block_n / 256) % - config.num_warps != - 0; - }), - configs.end()); + absl::c_copy( + std::vector{ + Config(16, 32, 32, 8, 1, 2), + Config(16, 64, 128, 8, 1, 4), + Config(16, 64, 128, 16, 3, 4), + }, + std::back_inserter(configs)); } configs.erase(std::remove_if(configs.begin(), configs.end(), [&](const Config& config) { @@ -447,7 +434,7 @@ absl::StatusOr> TritonGemmAutotuneExtractor( const HloFusionInstruction* fusion, DebugOptions debug_opts, bool allow_filtering_kernels_spilling_registers) { std::unique_ptr new_module = - AutotunerUtil::ExtractInstructionIntoNewModule(*fusion); + ExtractInstructionIntoNewModule(*fusion); // Reduce memory usage during compilation by disabling GPU runtime. debug_opts.set_xla_gpu_enable_xla_runtime_executable(false); // TODO(anlunx): Disable command buffers for now because it breaks triton @@ -501,7 +488,7 @@ absl::StatusOr> CublasGemmAutotuneExtractor( const HloComputation* fusion_computation = fusion->called_computations().at(0); std::unique_ptr new_module = - AutotunerUtil::ExtractComputationIntoNewModule(*fusion_computation); + ExtractComputationIntoNewModule(*fusion_computation); new_module->mutable_config().set_debug_options(debug_opts); GemmRewriter rewriter(config.GetGpuComputeCapability()); @@ -669,6 +656,7 @@ CompileMany(const AutotuneConfig& config, AutotunerCompileUtil& util, const GemmConfigSet& gemm_config_set = key_value.second; for (const TritonGemmConfig& gemm_config : gemm_config_set.configs) { + VLOG(5) << "Compiling " << gemm_config.ToString(); TF_ASSIGN_OR_RETURN( bool has_executable, compile( @@ -851,17 +839,20 @@ absl::StatusOr Execute(const AutotuneConfig& config, return best_triton; } -absl::Status DumpAutotunedFusion(const AutotuneConfig& config, +absl::Status DumpAutotunedFusion(const AutotuneConfig& autotune_config, AutotunerCompileUtil& util, const AutotuneResult result, const HloFusionInstruction* fusion, int fusion_id) { + TF_ASSIGN_OR_RETURN(TritonGemmConfig triton_gemm_config, + TritonGemmConfig::FromProto(result.triton())); + const se::DeviceDescription& device_desc = + autotune_config.GetExecutor()->GetDeviceDescription(); TF_ASSIGN_OR_RETURN( std::unique_ptr module, util.ExtractModule([&](const DebugOptions& debug_opts) { return TritonGemmAutotuneExtractor( - TritonGemmConfig::FromProto(result.triton()), - config.GetExecutor()->GetDeviceDescription(), fusion, debug_opts, + triton_gemm_config, device_desc, fusion, debug_opts, /*allow_filtering_kernels_spilling_registers=*/true); })); module->set_name(std::string(fusion->name())); diff --git a/third_party/xla/xla/service/gpu/triton_autotuner_test.cc b/third_party/xla/xla/service/gpu/triton_autotuner_test.cc index 8bca3088f86563..588e7ebee24c98 100644 --- a/third_party/xla/xla/service/gpu/triton_autotuner_test.cc +++ b/third_party/xla/xla/service/gpu/triton_autotuner_test.cc @@ -48,6 +48,7 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_utils.h" #include "xla/tests/verified_hlo_module.h" +#include "xla/tools/hlo_decomposer.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" #include "tsl/lib/core/status_test_util.h" @@ -88,9 +89,8 @@ ENTRY entry { })") .value(); - std::unique_ptr extracted_module = - AutotunerUtil::ExtractInstructionIntoNewModule( - *module->entry_computation()->root_instruction()->operand(0)); + std::unique_ptr extracted_module = ExtractInstructionIntoNewModule( + *module->entry_computation()->root_instruction()->operand(0)); // Destroy the original module to be sure that the extracted one has no // dependency on it. @@ -127,11 +127,10 @@ ENTRY entry { .value(); std::unique_ptr extracted_module = - AutotunerUtil::ExtractComputationIntoNewModule( - *module->entry_computation() - ->root_instruction() - ->operand(0) - ->fused_instructions_computation()); + ExtractComputationIntoNewModule(*module->entry_computation() + ->root_instruction() + ->operand(0) + ->fused_instructions_computation()); // Destroy the original module to be sure that the extracted one has no // dependency on it. @@ -426,12 +425,12 @@ ENTRY e { p0 = f16[55,120]{1,0} parameter(0) p1 = f16[120,20]{1,0} parameter(1) ROOT _ = f16[55,20] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config":{kind: "__triton_gemm", triton_gemm_config: {"block_m":16,"block_n":64,"block_k":32,"split_k":3,"num_stages":1,"num_warps":2}}} + backend_config={"fusion_backend_config":{kind: "__triton_gemm", triton_gemm_config: {"block_m":16,"block_n":64,"block_k":32,"split_k":3,"num_stages":1,"num_warps":2,"num_ctas":1}}} })"; MatchOptimizedHlo(kHloText, R"( ; CHECK: f16[3,55,20] -; CHECK: {"block_m":16,"block_n":64,"block_k":32,"split_k":3,"num_stages":1,"num_warps":2} +; CHECK: {"block_m":16,"block_n":64,"block_k":32,"split_k":3,"num_stages":1,"num_warps":2,"num_ctas":1} ; CHECK: f16[55,20]{1,0} {{(reduce|fusion)}} )"); @@ -455,7 +454,7 @@ ENTRY %e { %get-tuple-element.7020 = s8[12288,1536]{1,0} parameter(0) %convert = s8[4,12288]{1,0} parameter(1) ROOT %triton = s8[4,1536]{1,0} fusion(s8[12288,1536]{1,0} %get-tuple-element.7020, s8[4,12288]{1,0} %convert), kind=kCustom, calls=%triton_gemm_dot, - backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"16","split_k":"1","num_stages":"1","num_warps":"16"}}} + backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"16","split_k":"1","num_stages":"1","num_warps":"16","num_ctas":"1"}}} })"; if (!GetCudaComputeCapability().IsAtLeast( @@ -493,7 +492,7 @@ ENTRY %e { %get-tuple-element.7020 = s8[12288,1536]{1,0} parameter(0) %convert = s8[4,12288]{1,0} parameter(1) ROOT %triton = s8[4,1536]{1,0} fusion(s8[12288,1536]{1,0} %get-tuple-element.7020, s8[4,12288]{1,0} %convert), kind=kCustom, calls=%triton_gemm_dot, - backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"16","split_k":"1","num_stages":"1","num_warps":"16"}}} + backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"16","split_k":"1","num_stages":"1","num_warps":"16","num_ctas":"1"}}} })"; if (!GetCudaComputeCapability().IsAtLeast( @@ -535,7 +534,7 @@ ENTRY %e { %p0 = s8[12288,1536]{1,0} parameter(0) %p1 = f16[4,12288]{1,0} parameter(1) ROOT %triton_dot = f16[4,1536]{1,0} fusion(s8[12288,1536]{1,0} %p0, f16[4,12288]{1,0} %p1), kind=kCustom, calls=%triton_gemm_dot, - backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"16","block_n":"32","block_k":"16","split_k":"1","num_stages":"1","num_warps":"2"}}} + backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"16","block_n":"32","block_k":"16","split_k":"1","num_stages":"1","num_warps":"2","num_ctas":"1"}}} })"; auto module = ParseAndReturnVerifiedModule(kHloText).value(); @@ -631,7 +630,7 @@ ENTRY e { RunFileCheck( module->ToString(HloPrintOptions{}.set_print_operand_shape(false)), R"( -// CHECK: backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"32","block_n":"32","block_k":"32","split_k":"1","num_stages":"1","num_warps":"4"}}} +// CHECK: backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"32","block_n":"32","block_k":"32","split_k":"1","num_stages":"1","num_warps":"4","num_ctas":"1"}}} )")); EXPECT_TRUE(filecheck_matches); } else { diff --git a/third_party/xla/xla/service/gpu/triton_call.cc b/third_party/xla/xla/service/gpu/triton_call.cc new file mode 100644 index 00000000000000..dfc88e578b0ebe --- /dev/null +++ b/third_party/xla/xla/service/gpu/triton_call.cc @@ -0,0 +1,51 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/triton_call.h" + +#include +#include +#include + +#include "mlir/AsmParser/AsmParser.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace xla::gpu { + +TritonCall TritonCall::Parse(std::string_view backend_config, + mlir::MLIRContext* mlir_context) { + // TODO(slebedev): Plumb through num_ctas and enable_wrap_specialization. + auto attrs = mlir::cast( + mlir::parseAttribute(backend_config, mlir_context)); + auto name = attrs.getAs("name").getValue().str(); + auto ir = attrs.getAs("ir").str(); + auto grid_x = static_cast( + attrs.getAs("grid_x").getValue().getSExtValue()); + auto grid_y = static_cast( + attrs.getAs("grid_y").getValue().getSExtValue()); + auto grid_z = static_cast( + attrs.getAs("grid_z").getValue().getSExtValue()); + auto num_stages = + attrs.getAs("num_stages").getValue().getSExtValue(); + auto num_warps = + attrs.getAs("num_warps").getValue().getSExtValue(); + return TritonCall{std::move(name), std::move(ir), num_stages, num_warps, + grid_x, grid_y, grid_z}; +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/runtime/topk_kernel_bfloat16.cu.cc b/third_party/xla/xla/service/gpu/triton_call.h similarity index 50% rename from third_party/xla/xla/service/gpu/runtime/topk_kernel_bfloat16.cu.cc rename to third_party/xla/xla/service/gpu/triton_call.h index 218770510e6657..169e4e703e7dc2 100644 --- a/third_party/xla/xla/service/gpu/runtime/topk_kernel_bfloat16.cu.cc +++ b/third_party/xla/xla/service/gpu/triton_call.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The OpenXLA Authors. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,15 +13,31 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "Eigen/Core" // from @eigen_archive -#include "xla/service/gpu/runtime/topk_kernel.cu.h" +#ifndef XLA_SERVICE_GPU_TRITON_CALL_H_ +#define XLA_SERVICE_GPU_TRITON_CALL_H_ + +#include +#include +#include + +#include "mlir/IR/MLIRContext.h" // from @llvm-project namespace xla::gpu { -template void* GetTopKKernelForK(int n); -template void* GetTopKKernelForK(int n); -template void* GetTopKKernelForK(int n); -template void* GetTopKKernelForK(int n); -template void* GetTopKKernelForK(int n); +struct TritonCall { + std::string name; + std::string ir; + int64_t num_stages; + int64_t num_warps; + int32_t grid_x; + int32_t grid_y; + int32_t grid_z; + + // Parse the metadata of a __gpu$xla.gpu.triton call. + static TritonCall Parse(std::string_view backend_config, + mlir::MLIRContext* mlir_context); +}; } // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_TRITON_CALL_H_ diff --git a/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc b/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc index 9981ad86ac03c9..c942962ce35867 100644 --- a/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc +++ b/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc @@ -256,6 +256,7 @@ absl::Status TritonFusionAnalysis::ExecuteForDotFusion( .insert( {output, context.dim_orders().at(output).ToTensorIterationSpec()}) .second); + parameters_[Scope::OUTPUT] = {}; if (output != &dot) { // Propagate back to parameters of the output fusion. TF_RETURN_IF_ERROR(context.PropagateDimensionOrdersToParameters( diff --git a/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc b/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc index a43e61a52cf0bb..aa9639e5e09986 100644 --- a/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc @@ -39,6 +39,28 @@ using ::testing::FieldsAre; using TritonDotAnalysisTest = HloTestBase; +TEST_F(TritonDotAnalysisTest, QueryingOutputScopeParametersAlwaysWorks) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +triton_dot { + p0 = f32[8,8] parameter(0) + ROOT dot = f32[8,8] dot(p0, p0), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f32[8,8] parameter(0) + ROOT r = f32[8,8] fusion(p0), kind=kCustom, calls=triton_dot +})")); + TF_ASSERT_OK_AND_ASSIGN( + const auto analysis, + TritonFusionAnalysis::Execute(*module->entry_computation() + ->root_instruction() + ->called_computations()[0])); + EXPECT_TRUE( + analysis.ScopeParameters(TritonFusionAnalysis::Scope::OUTPUT).empty()); +} + TEST_F(TritonDotAnalysisTest, NopBitcasts) { const std::string hlo_text = R"( HloModule t diff --git a/third_party/xla/xla/service/gpu/triton_support.cc b/third_party/xla/xla/service/gpu/triton_support.cc index c207f59859c155..5713bd70050ff1 100644 --- a/third_party/xla/xla/service/gpu/triton_support.cc +++ b/third_party/xla/xla/service/gpu/triton_support.cc @@ -90,7 +90,7 @@ std::vector TritonSupportedUnaryElementwise( HloOpcode::kLog1p, HloOpcode::kRsqrt, HloOpcode::kSin, HloOpcode::kSqrt, HloOpcode::kCbrt, HloOpcode::kTan, - HloOpcode::kTanh}, + HloOpcode::kTanh, HloOpcode::kErf}, std::back_inserter(ret)); } return ret; diff --git a/third_party/xla/xla/service/gpu_compilation_environment.cc b/third_party/xla/xla/service/gpu_compilation_environment.cc index d92f541f52a354..d6551239db2a99 100644 --- a/third_party/xla/xla/service/gpu_compilation_environment.cc +++ b/third_party/xla/xla/service/gpu_compilation_environment.cc @@ -49,7 +49,7 @@ void InitializeFlagsForGpuCompEnv(std::vector* flag_list, gpu_comp_env->dummy_flag(), "Dummy flag to demonstrate the flow")); } -StatusOr CreateGpuCompEnvFromFlagStrings( +absl::StatusOr CreateGpuCompEnvFromFlagStrings( std::vector& flags, bool strict) { GpuCompilationEnvironment gpu_comp_env; std::vector flag_objects; @@ -62,7 +62,7 @@ StatusOr CreateGpuCompEnvFromFlagStrings( return gpu_comp_env; } -StatusOr CreateGpuCompEnvFromEnvVar() { +absl::StatusOr CreateGpuCompEnvFromEnvVar() { GpuCompilationEnvironment env; std::vector flag_objects; InitializeFlagsForGpuCompEnv(&flag_objects, &env); @@ -119,7 +119,7 @@ namespace { // // The implementation returns Empty env if one doesn't exist already. // NOLINTNEXTLINE -StatusOr> +absl::StatusOr> ProcessNewGpuCompilationEnvironment( std::unique_ptr env) { // NOLINT if (!env) { diff --git a/third_party/xla/xla/service/gpu_compilation_environment.h b/third_party/xla/xla/service/gpu_compilation_environment.h index 881ddc802ad08d..23f2a30273c8a9 100644 --- a/third_party/xla/xla/service/gpu_compilation_environment.h +++ b/third_party/xla/xla/service/gpu_compilation_environment.h @@ -23,10 +23,10 @@ limitations under the License. namespace xla { -StatusOr CreateGpuCompEnvFromFlagStrings( +absl::StatusOr CreateGpuCompEnvFromFlagStrings( std::vector& flags, bool strict); -StatusOr CreateGpuCompEnvFromEnvVar(); +absl::StatusOr CreateGpuCompEnvFromEnvVar(); GpuCompilationEnvironment CreateGpuCompEnvWithDefaultValues(); diff --git a/third_party/xla/xla/service/graphcycles/BUILD b/third_party/xla/xla/service/graphcycles/BUILD index 3cd370a96da244..dfaf6a27b5ca21 100644 --- a/third_party/xla/xla/service/graphcycles/BUILD +++ b/third_party/xla/xla/service/graphcycles/BUILD @@ -1,9 +1,12 @@ load("//xla:xla.bzl", "xla_cc_test") -load("@local_tsl//tsl:tsl.bzl", "set_external_visibility") +load("@local_tsl//tsl:tsl.bzl", "internal_visibility") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//tensorflow/compiler:__subpackages__", + ]), licenses = ["notice"], ) @@ -11,7 +14,6 @@ cc_library( name = "graphcycles", srcs = ["graphcycles.cc"], hdrs = ["graphcycles.h"], - visibility = ["//visibility:public"], deps = [ ":ordered_set", "@com_google_absl//absl/algorithm:container", @@ -26,7 +28,6 @@ cc_library( cc_library( name = "ordered_set", hdrs = ["ordered_set.h"], - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:span", diff --git a/third_party/xla/xla/service/heap_simulator/BUILD b/third_party/xla/xla/service/heap_simulator/BUILD index 2f756362ec9b9f..161c8f069028d3 100644 --- a/third_party/xla/xla/service/heap_simulator/BUILD +++ b/third_party/xla/xla/service/heap_simulator/BUILD @@ -8,7 +8,8 @@ load( load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], licenses = ["notice"], ) @@ -23,7 +24,6 @@ cc_library( name = "allocation_block", srcs = ["allocation_block.cc"], hdrs = ["allocation_block.h"], - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", @@ -35,7 +35,6 @@ cc_library( name = "heap_simulator", srcs = ["heap_simulator.cc"], hdrs = ["heap_simulator.h"], - visibility = ["//visibility:public"], deps = [ ":allocation_block", "//xla:comparison_util", diff --git a/third_party/xla/xla/service/hlo.proto b/third_party/xla/xla/service/hlo.proto index e1908ce0f6a8c3..e505f31f78ec72 100644 --- a/third_party/xla/xla/service/hlo.proto +++ b/third_party/xla/xla/service/hlo.proto @@ -30,6 +30,7 @@ syntax = "proto3"; package xla; +import "google/protobuf/any.proto"; import "xla/xla_data.proto"; option cc_enable_arenas = true; @@ -111,7 +112,7 @@ enum CustomCallApiVersion { } // Serialization of HloInstruction. -// Next ID: 86 +// Next ID: 87 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -373,6 +374,9 @@ message HloInstructionProto { reserved 83; // Used to be wait_on_operation_queues. reserved 84; + + // Sparsity descriptor for dot operation. + xla.SparsityDescriptor dot_sparsity = 86; } // Serialization of HloComputation. @@ -790,6 +794,9 @@ message HloPassMetadata { // Timestamp before and after the pass is run. Note they may be equal. int64 start_timestamp_usec = 8; int64 end_timestamp_usec = 9; + + // Custom metadata for the pass. + google.protobuf.Any custom_metadata = 10; } // Encodes the underlying Xla runtime executable compiled from the XLA module. diff --git a/third_party/xla/xla/service/hlo_cost_analysis.cc b/third_party/xla/xla/service/hlo_cost_analysis.cc index 8889ba7849012e..1511edb3ec828b 100644 --- a/third_party/xla/xla/service/hlo_cost_analysis.cc +++ b/third_party/xla/xla/service/hlo_cost_analysis.cc @@ -135,13 +135,14 @@ Status HloCostAnalysis::HandleElementwiseOp( // operation can correspond to several floating point ops. // kLogistic is included in "trascendental" as it is implemented using // trascendental ops (tanh or exp). - if (opcode == HloOpcode::kExp || opcode == HloOpcode::kLog || - opcode == HloOpcode::kLogistic || opcode == HloOpcode::kPower || - opcode == HloOpcode::kSqrt || opcode == HloOpcode::kCbrt || - opcode == HloOpcode::kRsqrt || opcode == HloOpcode::kTanh || - opcode == HloOpcode::kSin || opcode == HloOpcode::kCos || - opcode == HloOpcode::kExpm1 || opcode == HloOpcode::kLog1p || - opcode == HloOpcode::kAtan2 || opcode == HloOpcode::kTan) { + if (opcode == HloOpcode::kErf || opcode == HloOpcode::kExp || + opcode == HloOpcode::kLog || opcode == HloOpcode::kLogistic || + opcode == HloOpcode::kPower || opcode == HloOpcode::kSqrt || + opcode == HloOpcode::kCbrt || opcode == HloOpcode::kRsqrt || + opcode == HloOpcode::kTanh || opcode == HloOpcode::kSin || + opcode == HloOpcode::kCos || opcode == HloOpcode::kExpm1 || + opcode == HloOpcode::kLog1p || opcode == HloOpcode::kAtan2 || + opcode == HloOpcode::kTan) { current_properties_[kTranscendentalsKey] = computation_count; } else { // Note: transcendental operations are considered a separate category from diff --git a/third_party/xla/xla/service/hlo_creation_utils.h b/third_party/xla/xla/service/hlo_creation_utils.h index 1ece27022567b6..c55ae84eeede9c 100644 --- a/third_party/xla/xla/service/hlo_creation_utils.h +++ b/third_party/xla/xla/service/hlo_creation_utils.h @@ -284,8 +284,8 @@ HloInstruction* MakeScalarLike(HloInstruction* base, NativeT value) { *scalar->mutable_shape() = base->shape(); return scalar; } - return base->AddInstruction( - HloInstruction::CreateBroadcast(base->shape(), scalar, {})); + return base->AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeStaticShape(base->shape()), scalar, {})); } // Creates a fusion instruction and fuses `fused` into the created fusion diff --git a/third_party/xla/xla/service/hlo_creation_utils_test.cc b/third_party/xla/xla/service/hlo_creation_utils_test.cc index fc4ee3cf36ea6b..4df62a6463e484 100644 --- a/third_party/xla/xla/service/hlo_creation_utils_test.cc +++ b/third_party/xla/xla/service/hlo_creation_utils_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" @@ -372,6 +373,7 @@ TEST_F(HloCreationUtilsTest, MaybeMakeTupleTuplizesMultipleOperands) { Literal expected_result = LiteralUtil::MakeTuple({&input1, &input0}); EXPECT_EQ(result_literal, expected_result); } + TEST_F(HloCreationUtilsTest, DynamicUpdateSliceVectorStartIndices) { auto module = CreateNewVerifiedModule("dus-creation-test"); // arg: @@ -485,5 +487,19 @@ TEST_F(HloCreationUtilsTest, ReduceWindow) { expected_output_shape); } +TEST_F(HloCreationUtilsTest, DynamicBroadcastShape) { + HloInstruction* param; + HloComputation* entry_computation; + + auto module = CreateModuleWithProgramShape(F32, /*input_shape_dims=*/{10}, + /*output_shape_dims=*/{10}, ¶m, + &entry_computation); + param->mutable_shape()->set_dynamic_dimension(0, true); + + HloInstruction* one_constant = MakeScalarLike(param, 1.0f); + // Broadcasts should always have a static shape that is inferred. + EXPECT_TRUE(one_constant->shape().is_static()); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/hlo_dataflow_analysis.cc b/third_party/xla/xla/service/hlo_dataflow_analysis.cc index a115624c608e52..cf4f665fd7b91b 100644 --- a/third_party/xla/xla/service/hlo_dataflow_analysis.cc +++ b/third_party/xla/xla/service/hlo_dataflow_analysis.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" +#include "absl/log/check.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" @@ -332,7 +333,7 @@ HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction, } void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) { - const HloValue& value = *values_.at(value_id); + const HloValue& value = GetValue(value_id); VLOG(4) << "MarkValueForDeletion(" << value.ToShortString() << ")"; value_ids_to_delete_.push_back(value_id); @@ -526,11 +527,15 @@ bool HloDataflowAnalysis::Phi( } const HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) const { - return *values_.at(value_id); + const auto value = values_.find(value_id); + CHECK(value != values_.end()) << "Value not found: " << value_id; + return *value->second; } HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) { - return *values_.at(value_id); + const auto value = values_.find(value_id); + CHECK(value != values_.end()) << "Value not found: " << value_id; + return *value->second; } HloValueSet HloDataflowAnalysis::GetFlattenedValueSet( @@ -1403,12 +1408,18 @@ void HloDataflowAnalysis::Propagate() { const InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet( const HloInstruction* instruction) const { - return *value_sets_.at(instruction); + const auto value_set = value_sets_.find(instruction); + CHECK(value_set != value_sets_.end()) + << "Instruction " << instruction->ToString() << " not found."; + return *value_set->second; } InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet( const HloInstruction* instruction) { - return *value_sets_.at(instruction); + const auto value_set = value_sets_.find(instruction); + CHECK(value_set != value_sets_.end()) + << "Instruction " << instruction->ToString() << " not found."; + return *value_set->second; } Status HloDataflowAnalysis::InitializeInstructionValueSets() { diff --git a/third_party/xla/xla/service/hlo_dce_test.cc b/third_party/xla/xla/service/hlo_dce_test.cc index 8b44ebd7f484fa..e80cb84c67a5e7 100644 --- a/third_party/xla/xla/service/hlo_dce_test.cc +++ b/third_party/xla/xla/service/hlo_dce_test.cc @@ -111,6 +111,30 @@ TEST_F(HloDceTest, CustomCallInstructionsWithSideEffect) { EXPECT_FALSE(result); } +TEST_F(HloDceTest, AsyncCustomCallInstructionsWithSideEffect) { + // Verify that custom call instruction with side-effect is not removed. + auto builder = HloComputation::Builder(TestName()); + auto instr = Cast(builder.AddInstruction( + HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}), + /*operands=*/{}, + /*custom_call_target=*/"foo"))); + instr->set_custom_call_has_side_effect(true); + builder.AddInstruction(HloInstruction::CreateTuple({})); + + auto module = CreateNewVerifiedModule(); + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN([[maybe_unused]] HloInstruction * async_done, + module->entry_computation()->CreateAsyncInstructions( + instr, {{ShapeUtil::MakeScalarShape(U32)}}, + HloInstruction::kMainExecutionThread, + /*replace=*/true, /*override_names=*/true)); + + HloDCE dce; + TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&dce, module.get())); + EXPECT_FALSE(result); +} + TEST_F(HloDceTest, CustomCallInstructionsWithoutSideEffect) { // Verify that custom call instruction without side-effect is removed. auto builder = HloComputation::Builder(TestName()); @@ -128,6 +152,30 @@ TEST_F(HloDceTest, CustomCallInstructionsWithoutSideEffect) { EXPECT_TRUE(result); } +TEST_F(HloDceTest, AsyncCustomCallInstructionsWithoutSideEffect) { + // Verify that custom call instruction without side-effect is removed. + auto builder = HloComputation::Builder(TestName()); + auto instr = Cast(builder.AddInstruction( + HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}), + /*operands=*/{}, + /*custom_call_target=*/"foo"))); + instr->set_custom_call_has_side_effect(false); + builder.AddInstruction(HloInstruction::CreateTuple({})); + + auto module = CreateNewVerifiedModule(); + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN([[maybe_unused]] HloInstruction * async_done, + module->entry_computation()->CreateAsyncInstructions( + instr, {{ShapeUtil::MakeScalarShape(U32)}}, + HloInstruction::kMainExecutionThread, + /*replace=*/true, /*override_names=*/true)); + + HloDCE dce; + TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&dce, module.get())); + EXPECT_TRUE(result); +} + TEST_F(HloDceTest, ShardingCustomCallInstruction) { // Verify that sharding custom call instruction is not removed. auto builder = HloComputation::Builder(TestName()); diff --git a/third_party/xla/xla/service/hlo_graph_dumper.cc b/third_party/xla/xla/service/hlo_graph_dumper.cc index 02325b834dfe05..40d8872dce44d0 100644 --- a/third_party/xla/xla/service/hlo_graph_dumper.cc +++ b/third_party/xla/xla/service/hlo_graph_dumper.cc @@ -1094,6 +1094,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kConvert: case HloOpcode::kCos: case HloOpcode::kDivide: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: @@ -1209,6 +1210,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kAllReduceStart: case HloOpcode::kAllReduceDone: case HloOpcode::kAllToAll: + case HloOpcode::kCollectiveBroadcast: case HloOpcode::kCollectivePermute: case HloOpcode::kCollectivePermuteStart: case HloOpcode::kCollectivePermuteDone: diff --git a/third_party/xla/xla/service/hlo_instruction_test.cc b/third_party/xla/xla/service/hlo_instruction_test.cc index e40cc977c9aca4..0ddabf39028575 100644 --- a/third_party/xla/xla/service/hlo_instruction_test.cc +++ b/third_party/xla/xla/service/hlo_instruction_test.cc @@ -856,6 +856,60 @@ TEST_F(HloInstructionTest, AsyncOp) { EXPECT_EQ(computation->root_instruction(), async_done); } +TEST_F(HloInstructionTest, AsyncOpWithDeps) { + HloComputation::Builder builder(TestName()); + // Create a call instruction containing a single binary operation. + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.1f))); + + auto constant3 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); + auto constant4 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.1f))); + + auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( + r0f32_, HloOpcode::kAdd, constant3, constant4)); + + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + r0f32_, HloOpcode::kAdd, constant1, constant2)); + + auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( + r0f32_, HloOpcode::kAdd, constant1, constant2)); + + // control chain is add1 <- add <- add2 + TF_ASSERT_OK(add1->AddControlDependencyTo(add)); + + TF_ASSERT_OK(add->AddControlDependencyTo(add2)); + + auto module = CreateNewVerifiedModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN( + auto* async_done, + computation->CreateAsyncInstructions( + add, {ShapeUtil::MakeScalarShape(U32)}, "parallel_thread")); + auto* async_start = async_done->operand(0); + // Verify that control chain is not broken. + // New chain should be add1 <- asyncStart <- asyncDone <- add2 + EXPECT_EQ(async_start->control_predecessors().size(), 1); + EXPECT_EQ(async_start->control_predecessors()[0], add1); + + EXPECT_EQ(async_done->control_successors().size(), 1); + EXPECT_EQ(async_done->control_successors()[0], add2); + + EXPECT_EQ(async_start->shape().tuple_shapes_size(), 3); + EXPECT_EQ(async_start->async_execution_thread(), "parallel_thread"); + EXPECT_EQ(async_done->async_execution_thread(), "parallel_thread"); + EXPECT_TRUE(ShapeUtil::Equal(async_start->shape().tuple_shapes(2), + ShapeUtil::MakeScalarShape(U32))); + EXPECT_EQ(async_start->async_wrapped_computation()->execution_thread(), + "parallel_thread"); + EXPECT_EQ(async_done->async_wrapped_computation()->execution_thread(), + "parallel_thread"); + EXPECT_THAT(async_start->operands(), ElementsAre(constant1, constant2)); +} + TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) { HloComputation::Builder builder(TestName()); auto constant = builder.AddInstruction( diff --git a/third_party/xla/xla/service/hlo_module_config.cc b/third_party/xla/xla/service/hlo_module_config.cc index b7f8b6153bc1ca..bef8a0417c2b3b 100644 --- a/third_party/xla/xla/service/hlo_module_config.cc +++ b/third_party/xla/xla/service/hlo_module_config.cc @@ -85,6 +85,9 @@ std::string HloModuleConfig::compilation_cache_key() const { StrAppend(&key, device_type()); } StrAppend(&key, "::alias_passthrough_params=", alias_passthrough_params_); + StrAppend(&key, "::allow_spmd_sharding_propagation_to_parameters={", + absl::StrJoin(allow_spmd_sharding_propagation_to_parameters_, ","), + "}"); StrAppend(&key, "::allow_spmd_sharding_propagation_to_output={", absl::StrJoin(allow_spmd_sharding_propagation_to_output_, ","), "}"); @@ -303,6 +306,9 @@ StatusOr HloModuleConfig::ToProto() const { AssignProtoPhaseOrderingConfig(proto, phase_ordering_config_); proto.set_phase_index(phase_index_); + for (bool value : allow_spmd_sharding_propagation_to_parameters_) { + proto.add_allow_spmd_sharding_propagation_to_parameters(value); + } for (bool value : allow_spmd_sharding_propagation_to_output_) { proto.add_allow_spmd_sharding_propagation_to_output(value); } @@ -370,6 +376,9 @@ StatusOr> HloModuleConfig::CreateFromProto( proto.memory_space_assignment_config().end()); AssignStructPhaseOrderingConfig(*config, proto); config->phase_index_ = proto.phase_index(); + config->allow_spmd_sharding_propagation_to_parameters_.assign( + proto.allow_spmd_sharding_propagation_to_parameters().begin(), + proto.allow_spmd_sharding_propagation_to_parameters().end()); config->allow_spmd_sharding_propagation_to_output_.assign( proto.allow_spmd_sharding_propagation_to_output().begin(), proto.allow_spmd_sharding_propagation_to_output().end()); diff --git a/third_party/xla/xla/service/hlo_module_config.h b/third_party/xla/xla/service/hlo_module_config.h index 6765a6f5ec33eb..9a5e415bfa099c 100644 --- a/third_party/xla/xla/service/hlo_module_config.h +++ b/third_party/xla/xla/service/hlo_module_config.h @@ -324,9 +324,17 @@ class HloModuleConfig { int phase_index() const { return phase_index_; } void set_phase_index(const int phase_index) { phase_index_ = phase_index; } + absl::Span allow_spmd_sharding_propagation_to_parameters() const { + return allow_spmd_sharding_propagation_to_parameters_; + } absl::Span allow_spmd_sharding_propagation_to_output() const { return allow_spmd_sharding_propagation_to_output_; } + void set_allow_spmd_sharding_propagation_to_parameters( + absl::Span data) { + return allow_spmd_sharding_propagation_to_parameters_.assign(data.begin(), + data.end()); + } void set_allow_spmd_sharding_propagation_to_output( absl::Span data) { return allow_spmd_sharding_propagation_to_output_.assign(data.begin(), @@ -453,6 +461,18 @@ class HloModuleConfig { // config across functions during compilation. int phase_index_ = 0; + // Allows sharding propagation to propagate to the parameters. This changes + // the input shape of the computation (which is undesirable), but it can be + // used to allow to run partial compilation to determine what would be the + // input sharding of a computation if XLA would be allowed to propagate the + // sharding which can be used by higher level framework as a way to query + // intermediate sharding of operations when multiple computation would be + // chained and merged together. + // This is a vector of bool, because the user can control which parameters can + // have the sharding substituted. If only one boolean value is passed in the + // vector that is interpreted as the value to be applied for every parameter. + absl::InlinedVector allow_spmd_sharding_propagation_to_parameters_ = + {false}; // Allows sharding propagation to propagate to the outputs. This changes the // output shape of the computation (which is undesirable), but it can be used // to allow to run partial compilation to determine what would be the output diff --git a/third_party/xla/xla/service/hlo_module_util.cc b/third_party/xla/xla/service/hlo_module_util.cc index b7db06527fdba9..fab668e4d4e6db 100644 --- a/third_party/xla/xla/service/hlo_module_util.cc +++ b/third_party/xla/xla/service/hlo_module_util.cc @@ -100,6 +100,11 @@ StatusOr> CreateModuleConfig( } config->set_use_spmd_partitioning( execution_options->use_spmd_partitioning()); + if (!execution_options->allow_spmd_sharding_propagation_to_parameters() + .empty()) { + config->set_allow_spmd_sharding_propagation_to_parameters( + execution_options->allow_spmd_sharding_propagation_to_parameters()); + } if (!execution_options->allow_spmd_sharding_propagation_to_output() .empty()) { config->set_allow_spmd_sharding_propagation_to_output( diff --git a/third_party/xla/xla/service/hlo_opcode_test.cc b/third_party/xla/xla/service/hlo_opcode_test.cc index b0cbe761e086b8..b073d043749268 100644 --- a/third_party/xla/xla/service/hlo_opcode_test.cc +++ b/third_party/xla/xla/service/hlo_opcode_test.cc @@ -58,6 +58,7 @@ TEST(HloOpcodeTest, OpcodeProperties) { case HloOpcode::kAllReduceStart: case HloOpcode::kAllToAll: case HloOpcode::kCall: + case HloOpcode::kCollectiveBroadcast: case HloOpcode::kCollectivePermute: case HloOpcode::kCollectivePermuteStart: case HloOpcode::kConcatenate: diff --git a/third_party/xla/xla/service/hlo_parser.cc b/third_party/xla/xla/service/hlo_parser.cc index 8e0e10c2beafda..ce0832c2462d25 100644 --- a/third_party/xla/xla/service/hlo_parser.cc +++ b/third_party/xla/xla/service/hlo_parser.cc @@ -129,6 +129,7 @@ bool CanInferShape(HloOpcode code) { case HloOpcode::kDivide: case HloOpcode::kDomain: case HloOpcode::kDot: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFft: @@ -196,6 +197,7 @@ bool CanInferShape(HloOpcode code) { case HloOpcode::kAllReduceStart: case HloOpcode::kAllReduceDone: case HloOpcode::kAllToAll: + case HloOpcode::kCollectiveBroadcast: case HloOpcode::kCollectivePermute: case HloOpcode::kCollectivePermuteStart: case HloOpcode::kCollectivePermuteDone: @@ -1014,6 +1016,7 @@ bool HloParserImpl::ParseHloModule(HloModule* module, absl::flat_hash_map attrs; std::optional entry_computation_layout; std::optional frontend_attributes; + BoolList allow_spmd_sharding_propagation_to_parameters; BoolList allow_spmd_sharding_propagation_to_output; attrs["is_scheduled"] = {/*required=*/false, AttrTy::kBool, &is_scheduled}; @@ -1031,6 +1034,9 @@ bool HloParserImpl::ParseHloModule(HloModule* module, &entry_computation_layout}; attrs["frontend_attributes"] = { /*required=*/false, AttrTy::kFrontendAttributes, &frontend_attributes}; + attrs["allow_spmd_sharding_propagation_to_parameters"] = { + /*required=*/false, AttrTy::kBracedBoolListOrBool, + &allow_spmd_sharding_propagation_to_parameters}; attrs["allow_spmd_sharding_propagation_to_output"] = { /*required=*/false, AttrTy::kBracedBoolListOrBool, &allow_spmd_sharding_propagation_to_output}; @@ -1089,6 +1095,11 @@ bool HloParserImpl::ParseHloModule(HloModule* module, if (frontend_attributes) { module->set_frontend_attributes(frontend_attributes.value()); } + if (!allow_spmd_sharding_propagation_to_parameters.empty()) { + config.set_allow_spmd_sharding_propagation_to_parameters( + allow_spmd_sharding_propagation_to_parameters); + default_config = false; + } if (!allow_spmd_sharding_propagation_to_output.empty()) { config.set_allow_spmd_sharding_propagation_to_output( allow_spmd_sharding_propagation_to_output); @@ -1504,6 +1515,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT case HloOpcode::kCopyDone: case HloOpcode::kCos: case HloOpcode::kOptimizationBarrier: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kImag: @@ -2105,14 +2117,15 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT !ParseAttributes(attrs, allow_attributes)) { return nullptr; } - if (dynamic_cast(operands[0]) == nullptr) { - return nullptr; - } - if (channel_id != operands[0]->channel_id()) { - return nullptr; + + if (dynamic_cast(operands[0]) != nullptr) { + if (channel_id != operands[0]->channel_id()) { + return nullptr; + } } - return builder->AddInstruction( - HloInstruction::CreateRecvDone(operands[0], *is_host_transfer)); + + return builder->AddInstruction(HloInstruction::CreateRecvDone( + operands[0], channel_id.value(), *is_host_transfer)); } case HloOpcode::kSend: { optional channel_id; @@ -3214,8 +3227,9 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT return builder->AddInstruction(HloInstruction::CreateSetDimensionSize( *shape, operands[0], operands[1], (*dimensions)[0])); } + default: + return nullptr; } - return nullptr; } // NOLINT(readability/fn_size) // ::= '{' (single_sharding | tuple_sharding) '}' diff --git a/third_party/xla/xla/service/hlo_parser_test.cc b/third_party/xla/xla/service/hlo_parser_test.cc index 46dbd653a9f6c2..e34b6daa3a1482 100644 --- a/third_party/xla/xla/service/hlo_parser_test.cc +++ b/third_party/xla/xla/service/hlo_parser_test.cc @@ -4529,6 +4529,50 @@ ENTRY TestComputation { "attr_value"); } +TEST_F(HloParserTest, CheckAllowSpmdShardingPropagationToParameters) { + const char* const hlo_string = R"( +HloModule TestModule, allow_spmd_sharding_propagation_to_parameters=true + +ENTRY TestComputation { + p0 = f16[2048,1024] parameter(0) + p1 = f16[2048,1024] parameter(1) + ROOT root = (f16[2048,1024], f16[2048,1024]) tuple(p0, p1) +} +)"; + auto result = ParseAndReturnVerifiedModule(hlo_string); + TF_EXPECT_OK(result.status()); + EXPECT_EQ((*result) + ->config() + .allow_spmd_sharding_propagation_to_parameters() + .size(), + 1); + EXPECT_TRUE( + (*result)->config().allow_spmd_sharding_propagation_to_parameters()[0]); +} + +TEST_F(HloParserTest, CheckAllowSpmdShardingPropagationToParametersVec) { + const char* const hlo_string = R"( +HloModule TestModule, allow_spmd_sharding_propagation_to_parameters={true,false} + +ENTRY TestComputation { + p0 = f16[2048,1024] parameter(0) + p1 = f16[2048,1024] parameter(1) + ROOT root = (f16[2048,1024], f16[2048,1024]) tuple(p0, p1) +} +)"; + auto result = ParseAndReturnVerifiedModule(hlo_string); + TF_EXPECT_OK(result.status()); + EXPECT_EQ((*result) + ->config() + .allow_spmd_sharding_propagation_to_parameters() + .size(), + 2); + EXPECT_TRUE( + (*result)->config().allow_spmd_sharding_propagation_to_parameters()[0]); + EXPECT_FALSE( + (*result)->config().allow_spmd_sharding_propagation_to_parameters()[1]); +} + TEST_F(HloParserTest, CheckAllowSpmdShardingPropagationToOutput) { const char* const hlo_string = R"( HloModule TestModule, allow_spmd_sharding_propagation_to_output=true @@ -5057,5 +5101,75 @@ ENTRY %Entry (p0: f32[10]) -> f32[20] { "but got foo_thread"))); } +TEST_F(HloParserTest, PipelinedSendRecv) { + const std::string hlo_string = R"( + HloModule test + cond { + param = (u32[], u32[2], (u32[2], u32[], token[])) parameter(0) + count = get-tuple-element(%param), index=0 + ub = u32[] constant(1) + ROOT result = pred[] compare(count, ub), direction=LT + } + + body { + param = (u32[], u32[2], (u32[2], u32[], token[])) parameter(0) + count = get-tuple-element(%param), index=0 + send-data = get-tuple-element(%param), index=1 + + after-all.0 = token[] after-all() + send.0 = (u32[2], u32[], token[]) send(send-data, after-all.0), + channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{1,0}}" + } + + recv.0 = (u32[2], u32[], token[]) get-tuple-element(param), index=2 + recv-done.0 = (u32[2], token[]) recv-done(recv.0), channel_id=1 + recv-data.0 = u32[2] get-tuple-element(recv-done.0), index=0 + + c1 = u32[] constant(1) + new_count = u32[] add(count, c1) + + after-all.0.n = token[] after-all() + recv.0.n = (u32[2], u32[], token[]) recv(after-all.0.n), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{1,0}}" + } + + send-done.0 = token[] send-done(send.0), channel_id=1 + + ROOT result = (u32[], u32[2], (u32[2], u32[], token[])) tuple(new_count, recv-data.0, recv.0.n) + } + + ENTRY test_computation { + c0 = u32[] constant(0) + init = u32[2] broadcast(c0), dimensions={} + after-all.0.p = token[] after-all() + recv.0.p = (u32[2], u32[], token[]) recv(after-all.0.p), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{1,0}}" + } + + while_init = (u32[], u32[2], (u32[2], u32[], token[])) tuple(c0, init, recv.0.p) + while_result = (u32[], u32[2], (u32[2], u32[], token[])) while(while_init), body=body, condition=cond + + send-data.q = u32[2] get-tuple-element(while_result), index=1 + after-all.0.q = token[] after-all() + send.0.q = (u32[2], u32[], token[]) send(send-data.q, after-all.0.q), + channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{1,0}}" + } + + recv.0.q = (u32[2], u32[], token[]) get-tuple-element(while_result), index=2 + recv-done.0.q = (u32[2], token[]) recv-done(recv.0.q), channel_id=1 + send-done.0.q = token[] send-done(send.0.q), channel_id=1 + + ROOT recv-data.0.q = u32[2] get-tuple-element(recv-done.0.q), index=0 + })"; + auto result = ParseAndReturnUnverifiedModule(hlo_string); + EXPECT_EQ(OkStatus(), result.status()); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/hlo_pass_pipeline.cc b/third_party/xla/xla/service/hlo_pass_pipeline.cc index d08ebbd8762c8c..01a9c7ee103794 100644 --- a/third_party/xla/xla/service/hlo_pass_pipeline.cc +++ b/third_party/xla/xla/service/hlo_pass_pipeline.cc @@ -99,29 +99,6 @@ void RecordPassEndMetadata(HloModuleGroup& module_group, } } -void SetInstructionMetadata(HloModule& module) { - StatusOr pass_id = module.metadata()->current_pass_id(); - if (!pass_id.ok()) { - LOG(FATAL) << pass_id.status(); - } - for (xla::HloComputation* computation : module.computations()) { - for (xla::HloInstruction* instruction : computation->instructions()) { - if (instruction->metadata().creation_pass_id() == 0) { - instruction->set_creation_pass_id(*pass_id); - } - if (instruction->metadata().logical_creation_pass_id() == 0) { - instruction->set_logical_creation_pass_id(*pass_id); - } - } - } -} - -void SetInstructionMetadata(HloModuleGroup& module_group) { - for (HloModule* module : module_group.modules()) { - SetInstructionMetadata(*module); - } -} - } // namespace template @@ -179,7 +156,6 @@ StatusOr HloPassPipeline::RunPassesInternal( RunInvariantCheckers(hlo, kPipelineStart, execution_threads)); RecordPassStartMetadata(*hlo, std::string(kPipelineStart), pipeline_name); - SetInstructionMetadata(*hlo); MaybeDumpHloAndSaveFilenames(*hlo, /*after_pass_name=*/kPipelineStart, /*before_pass_name=*/passes.empty() @@ -207,7 +183,6 @@ StatusOr HloPassPipeline::RunPassesInternal( pass_name, absl::StatusCodeToString(status.code())); } TF_ASSIGN_OR_RETURN(bool pass_changed, status_or_changed); - SetInstructionMetadata(*hlo); if (!dump_regex.empty() && (pass_changed || dump_regex != ".*")) { MaybeDumpHloAndSaveFilenames(*hlo, /*after_pass_name=*/pass_name, diff --git a/third_party/xla/xla/service/hlo_runner.cc b/third_party/xla/xla/service/hlo_runner.cc index 0e56fda23a26db..14e9b9b461c204 100644 --- a/third_party/xla/xla/service/hlo_runner.cc +++ b/third_party/xla/xla/service/hlo_runner.cc @@ -640,6 +640,11 @@ HloRunner::CreateExecutableWithBufferAssignment( LOG(WARNING) << "Ignoring buffer assignment provided because hlo passes " "are enabled."; } + // Setup intra-op threads in module config + if (backend().eigen_intra_op_thread_pool() != nullptr) { + module->mutable_config().set_intra_op_parallelism_threads( + backend().eigen_intra_op_thread_pool()->NumThreads()); + } auto module_group = std::make_unique(std::move(module)); TF_ASSIGN_OR_RETURN( auto executables, diff --git a/third_party/xla/xla/service/hlo_verifier.cc b/third_party/xla/xla/service/hlo_verifier.cc index d234f45de64476..e182a6e193f004 100644 --- a/third_party/xla/xla/service/hlo_verifier.cc +++ b/third_party/xla/xla/service/hlo_verifier.cc @@ -299,26 +299,11 @@ Status ShapeVerifier::HandleOptimizationBarrier(HloInstruction* hlo) { return CheckShape(hlo, hlo->operand(0)->shape()); } -bool ShapeVerifier::ShapesSame( - const Shape& a, const Shape& b, bool minor_to_major_only, - bool ignore_memory_space, bool ignore_tiles, - bool ignore_trailing_padding_alignment_in_elements) { +bool ShapeVerifier::ShapesSame(const Shape& a, const Shape& b, + Shape::Equal equal) { if (!opts_.layout_sensitive) { return ShapeUtil::Compatible(a, b); } - Shape::Equal equal; - if (ignore_memory_space) { - equal.IgnoreMemorySpaceInLayout(); - } - if (minor_to_major_only) { - equal.MinorToMajorOnlyInLayout(); - } - if (ignore_tiles) { - equal.IgnoreTilesInLayout(); - } - if (ignore_trailing_padding_alignment_in_elements) { - equal.IgnoreTailPaddingAlignmentInElements(); - } return equal(a, b); } @@ -1612,8 +1597,7 @@ Status ShapeVerifier::HandleCopyDone(HloInstruction* copy_done) { const Shape& dest_shape = ShapeUtil::GetTupleElementShape(operand_shape, 0); const Shape& src_shape = ShapeUtil::GetTupleElementShape(operand_shape, 1); if (!ShapesSame(dest_shape, src_shape, - /*minor_to_major_only=*/false, - /*ignore_memory_space=*/true)) { + Shape::Equal().IgnoreMemorySpaceInLayout())) { return Internal( "Source and destination buffers in CopyDone arguments need to be the " "same shape found %s and %s\n%s", @@ -1838,18 +1822,27 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, case HloOpcode::kSend: case HloOpcode::kSendDone: case HloOpcode::kTuple: - case HloOpcode::kWhile: - return ShapesSame(instruction->shape(), inferred_shape, - only_compare_minor_to_major_in_layout); - case HloOpcode::kDynamicUpdateSlice: - // For DynamicUpdateSlice it has an "in-place" update semantics, but - // inside of fusions memory space propagation doesn't propagate the - // memory spaces all the way, causing possible mismatches. Relax the - // constraint in that condition. - return ShapesSame(instruction->shape(), inferred_shape, - only_compare_minor_to_major_in_layout, - /*ignore_memory_space=*/ - instruction->parent()->IsFusionComputation()); + case HloOpcode::kWhile: { + Shape::Equal equal; + if (only_compare_minor_to_major_in_layout) { + equal.MinorToMajorOnlyInLayout(); + } + return ShapesSame(instruction->shape(), inferred_shape, equal); + } + case HloOpcode::kDynamicUpdateSlice: { + Shape::Equal equal; + if (only_compare_minor_to_major_in_layout) { + equal.MinorToMajorOnlyInLayout(); + } + if (instruction->parent()->IsFusionComputation()) { + // For DynamicUpdateSlice it has an "in-place" update semantics, but + // inside of fusions memory space propagation doesn't propagate the + // memory spaces all the way, causing possible mismatches. Relax the + // constraint in that condition. + equal.IgnoreMemorySpaceInLayout(); + } + return ShapesSame(instruction->shape(), inferred_shape, equal); + } // We allow arbitrary layout and f32->bf16 transformations on all other // instructions, although this may be made more strict pending discussion @@ -1922,9 +1915,9 @@ Status ShapeVerifier::VerifyEntryComputationLayout(const HloModule& module) { // let's not check that. if (!ShapesSame(computation->root_instruction()->shape(), result_layout.shape(), - /*minor_to_major_only=*/false, /*ignore_memory_space=*/false, - /*ignore_tiles=*/true, - /*ignore_trailing_padding_alignment_in_elements=*/true)) { + Shape::Equal() + .IgnoreTilesInLayout() + .IgnoreTailPaddingAlignmentInElements())) { return Internal( "Shape of the root instruction of entry computation (%s) should be " "compatible to one specified in module's entry computation layout (%s)", @@ -1946,10 +1939,9 @@ Status ShapeVerifier::VerifyEntryComputationLayout(const HloModule& module) { // TPU layout assignment doesn't set the tiles on entry_computation_layout, // so let's not check that. if (!ShapesSame(parameter->shape(), layout.parameter_shape(i), - /*minor_to_major_only=*/false, - /*ignore_memory_space=*/false, - /*ignore_tiles=*/true, - /*ignore_trailing_padding_alignment_in_elements=*/true)) { + Shape::Equal() + .IgnoreTilesInLayout() + .IgnoreTailPaddingAlignmentInElements())) { return Internal( "Shape of the entry computation parameter %d is %s should be " "compatible to the one specified in module's entry computation " @@ -2254,20 +2246,40 @@ Status VerifyChannels(const HloModule& module) { } case HloOpcode::kRecv: { TF_RET_CHECK(instruction->users().size() == 1); - const HloInstruction* recv_done = instruction->users().front(); - TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone); - TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_done)); - TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, recv_done)); + const HloInstruction* recv_user = instruction->users().front(); + if (recv_user->opcode() == HloOpcode::kRecvDone) { + TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_user)); + TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, recv_user)); + } else { + // If a Recv user is not a RecvDone, it has to be a tuple that is + // either the root of a while-body or the init of a while-loop. + TF_RET_CHECK(recv_user->opcode() == HloOpcode::kTuple); + if (recv_user != recv_user->parent()->root_instruction()) { + TF_RET_CHECK(recv_user->users().size() == 1); + const HloInstruction* user = recv_user->users().front(); + TF_RET_CHECK(user->opcode() == HloOpcode::kWhile); + } + } break; } case HloOpcode::kSendDone: TF_RET_CHECK(instruction->operands().size() == 1); TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kSend); break; - case HloOpcode::kRecvDone: + case HloOpcode::kRecvDone: { TF_RET_CHECK(instruction->operands().size() == 1); - TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kRecv); + const HloInstruction* recv_done_operand = instruction->operand(0); + if (recv_done_operand->opcode() != HloOpcode::kRecv) { + // If the RecvDone operand is not a Recv, it has to be either part + // of a while-loop result or a parameter of a while-body. + TF_RET_CHECK(recv_done_operand->opcode() == + HloOpcode::kGetTupleElement); + HloOpcode opcode = recv_done_operand->operand(0)->opcode(); + TF_RET_CHECK(opcode == HloOpcode::kWhile || + opcode == HloOpcode::kParameter); + } break; + } default: break; } @@ -2865,12 +2877,6 @@ void MetadataTracker::HandleMetadata(const OpMetadata& metadata) { if (metadata.source_line() != 0) { ++has_source_line_count_; } - if (metadata.creation_pass_id() != 0) { - ++has_creation_pass_id_count_; - } - if (metadata.logical_creation_pass_id() != 0) { - ++has_logical_creation_pass_id_count_; - } if (metadata.size_of_generated_code_in_bytes() != 0) { ++has_size_of_generated_code_in_bytes_count_; } diff --git a/third_party/xla/xla/service/hlo_verifier.h b/third_party/xla/xla/service/hlo_verifier.h index c806dafe420777..2debca84b89c23 100644 --- a/third_party/xla/xla/service/hlo_verifier.h +++ b/third_party/xla/xla/service/hlo_verifier.h @@ -244,10 +244,7 @@ class ShapeVerifier : public DfsHloVisitor { protected: // Helpers that switch on layout_sensitive_. - bool ShapesSame(const Shape& a, const Shape& b, - bool minor_to_major_only = false, - bool ignore_memory_space = false, bool ignore_tiles = false, - bool ignore_trailing_padding_alignment_in_elements = false); + bool ShapesSame(const Shape& a, const Shape& b, Shape::Equal equal = {}); // Check the instruction's shape against the shape given by ShapeInference // and return an appropriate error if there is a mismatch. @@ -270,19 +267,6 @@ class ShapeVerifier : public DfsHloVisitor { Status CheckVariadicShape(const HloInstruction* instruction); private: - bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b, - bool minor_to_major_only = false) { - if (!opts_.layout_sensitive) { - return ShapeUtil::CompatibleIgnoringFpPrecision(a, b); - } - Shape::Equal equal; - if (minor_to_major_only) { - equal.MinorToMajorOnlyInLayout(); - } - equal.IgnoreFpPrecision(); - return equal(a, b); - } - std::string StringifyShape(const Shape& s) { return opts_.layout_sensitive ? ShapeUtil::HumanStringWithLayout(s) : ShapeUtil::HumanString(s); diff --git a/third_party/xla/xla/service/host_memory_offload_annotations.h b/third_party/xla/xla/service/host_memory_offload_annotations.h new file mode 100644 index 00000000000000..fc7ab75559d084 --- /dev/null +++ b/third_party/xla/xla/service/host_memory_offload_annotations.h @@ -0,0 +1,42 @@ +/* Copyright 2024 The OpenXLA Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ + +#ifndef XLA_SERVICE_HOST_MEMORY_OFFLOAD_ANNOTATIONS_H_ +#define XLA_SERVICE_HOST_MEMORY_OFFLOAD_ANNOTATIONS_H_ + +#include "absl/strings/string_view.h" + +namespace xla { +namespace host_memory_offload_annotations { + +// External annotations: +inline const absl::string_view kDevicePlacement = "annotate_device_placement"; +inline const absl::string_view kMemoryTargetPinnedHost = "pinned_host"; +inline const absl::string_view kMemoryTargetUnpinnedHost = "unpinned_host"; +inline const absl::string_view kMemoryTargetDevice = "device"; + +// Internal annotations: +// These are currently called PipelineForward/PipelineBackward, because they +// were originally meant as a hook point for the collective-pipeliner. They do +// more than just that though (identify memory movement direction), so should be +// renamed to something related to memory movement. +inline const absl::string_view kMoveToHostCustomCallTarget = "PipelineForward"; +inline const absl::string_view kMoveToDeviceCustomCallTarget = + "PipelineBackward"; + +} // namespace host_memory_offload_annotations +} // namespace xla + +#endif // XLA_SERVICE_HOST_MEMORY_OFFLOAD_ANNOTATIONS_H_ diff --git a/third_party/xla/xla/service/host_memory_transfer_asyncifier.cc b/third_party/xla/xla/service/host_memory_transfer_asyncifier.cc index df6d1487617eaa..499e6f62cca116 100644 --- a/third_party/xla/xla/service/host_memory_transfer_asyncifier.cc +++ b/third_party/xla/xla/service/host_memory_transfer_asyncifier.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -60,6 +61,10 @@ class HostMemoryTransferAsyncifierVisitor : public DfsHloVisitorWithDefault { ", does not have a layout."); } + VLOG(3) << absl::StreamFormat( + "\"%s\" from S(%d) to S(%d)", dynamic_slice->name(), + dynamic_slice_operand->shape().layout().memory_space(), + dynamic_slice->shape().layout().memory_space()); // Check that this is a dynamic-slice slicing from host memory to device // memory. if (dynamic_slice_operand->shape().layout().memory_space() != diff --git a/third_party/xla/xla/service/host_offload_legalize.cc b/third_party/xla/xla/service/host_offload_legalize.cc new file mode 100644 index 00000000000000..04c28cd6090f93 --- /dev/null +++ b/third_party/xla/xla/service/host_offload_legalize.cc @@ -0,0 +1,598 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/host_offload_legalize.h" + +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/hlo_value.h" +#include "xla/service/host_memory_offload_annotations.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/statusor.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { + +namespace { + +constexpr std::array kUsersOpcodes = {HloOpcode::kSlice, + HloOpcode::kDynamicSlice}; + +// Find an annotation moving up. Meant to find an annotation from a DUS operand. +HloInstruction* FindAnnotationToUpdate(HloInstruction* instr) { + while (!instr->IsCustomCall( + host_memory_offload_annotations::kMoveToHostCustomCallTarget)) { + if ((instr->opcode() != HloOpcode::kBitcast && + instr->opcode() != HloOpcode::kCopy && + instr->opcode() != HloOpcode::kReshape) || + instr->mutable_operand(0)->user_count() != 1) { + return nullptr; + } + instr = instr->mutable_operand(0); + } + return instr; +} + +// Find an annotation moving up. Meant to find an annotation from a DUS operand. +HloInstruction* FindToDeviceAnnotationToUpdate(HloInstruction* instr) { + while (!instr->IsCustomCall( + host_memory_offload_annotations::kMoveToDeviceCustomCallTarget)) { + if (instr->user_count() != 1 || + (instr->opcode() != HloOpcode::kBitcast && + instr->opcode() != HloOpcode::kReshape && + instr->opcode() != HloOpcode::kCopy && + !absl::c_linear_search(kUsersOpcodes, instr->opcode()))) { + return nullptr; + } + instr = instr->users()[0]; + } + return instr; +} + +// Find a DUS starting from an annotation. +HloInstruction* FindDUSFromAnnotation(HloInstruction* instr) { + while (instr->opcode() != HloOpcode::kDynamicUpdateSlice) { + if (instr->user_count() != 1 || (instr->opcode() != HloOpcode::kBitcast && + instr->opcode() != HloOpcode::kReshape)) { + break; + } + instr = instr->users()[0]; + } + return instr; +} + +// Make sure that broadcasts are duplicated for each use. +StatusOr DuplicateBroadcastForEachUse(HloModule* module) { + bool split_at_least_one = false; + for (HloComputation* computation : module->computations()) { + std::vector broadcasts; + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kBroadcast || + !instruction->HasConstantOperand()) { + continue; + } + broadcasts.push_back(instruction); + } + for (HloInstruction* instruction : broadcasts) { + if (instruction->opcode() != HloOpcode::kBroadcast || + !instruction->HasConstantOperand()) { + continue; + } + absl::InlinedVector uses; + for (HloInstruction* user : instruction->users()) { + for (int64_t i = 0; i < user->operand_count(); ++i) { + if (user->operand(i) != instruction) { + continue; + } + uses.push_back(HloUse{user, i, /*operand_index=*/{}}); + } + } + + if (uses.size() <= 1) { + VLOG(5) << "Skipping broadcast " << instruction->ToString() + << " which has " << uses.size() << " uses"; + continue; + } + + VLOG(5) << "Splitting broadcast " << instruction->ToString() + << " which has " << uses.size() << " uses"; + split_at_least_one = true; + // Don't create a new broadcast for the first use; we can still use the + // original. + for (int i = 1; i < uses.size(); ++i) { + const HloUse& use = uses[i]; + HloInstruction* new_broadcast = + instruction->parent()->AddInstruction(instruction->Clone()); + VLOG(5) << "New broadcast " << new_broadcast->ToString(); + TF_RETURN_IF_ERROR(use.instruction->ReplaceOperandWith( + use.operand_number, new_broadcast)); + } + } + } + return split_at_least_one; +} + +// Walk up in the chain of memory offloaded instructions. Status not-ok when +// an instructions not supported or end of chain reached. +// Walks one instruction at a time. +StatusOr> WalkUpMemoryOffload( + std::pair current_value, + const CallGraph& call_graph) { + // TODO(maggioni): Verify that set of instructions supported in chain by + // legalization is in sync with host_offloader. + auto& [instruction, index] = current_value; + // Walk up to find definition + switch (instruction->opcode()) { + case HloOpcode::kGetTupleElement: { + CHECK_EQ(index, -1); + return std::make_pair(instruction->mutable_operand(0), + instruction->tuple_index()); + } + case HloOpcode::kBitcast: + case HloOpcode::kReshape: { + return std::make_pair(instruction->mutable_operand(0), index); + } + case HloOpcode::kTuple: { + return std::make_pair(instruction->mutable_operand(index), -1); + } + case HloOpcode::kOptimizationBarrier: { + return std::make_pair(instruction->mutable_operand(0), index); + } + case HloOpcode::kWhile: { + HloComputation* while_body = instruction->while_body(); + HloInstruction* root = while_body->root_instruction(); + CHECK_EQ(root->opcode(), HloOpcode::kTuple); + return std::make_pair(root, index); + } + case HloOpcode::kParameter: { + CHECK_NE(instruction->parent(), + instruction->GetModule()->entry_computation()); + auto callers = call_graph.GetComputationCallers(instruction->parent()); + if (callers.size() != 1) { + return absl::InvalidArgumentError( + "Expected to be called only by one caller"); + } + auto* caller = callers[0]; + if (caller->opcode() != HloOpcode::kWhile) { + return absl::InvalidArgumentError( + "Expected to be called by a while loop"); + } + return std::make_pair(caller->mutable_operand(0), index); + } + case HloOpcode::kDynamicUpdateSlice: { + return std::make_pair(instruction->mutable_operand(0), index); + } + case HloOpcode::kCustomCall: { + if (!instruction->IsCustomCall("AllocateBuffer") && + !instruction->IsCustomCall( + host_memory_offload_annotations::kMoveToHostCustomCallTarget)) { + return absl::InvalidArgumentError( + "Expected AllocateBuffer or MoveToHost custom-call"); + } + return std::make_pair(instruction, index); + } + case HloOpcode::kBroadcast: { + auto* broadcast_operand = instruction->mutable_operand(0); + if (broadcast_operand->opcode() != HloOpcode::kConstant) { + return absl::InvalidArgumentError("Expected a constant as operand"); + } + if (!ShapeUtil::IsEffectiveScalar(broadcast_operand->shape())) { + return absl::InvalidArgumentError("Expected a scalar broadcast"); + } + return std::make_pair(instruction, index); + } + default: { + return absl::InvalidArgumentError( + absl::StrFormat("Invalid opcode %s", instruction->ToString())); + } + } +} + +// Walk down in the chain of memory offloaded instructions. Status not-ok when +// an instructions not supported or end of chain reached. +// Walks one instruction at a time, but returns multiple instructions for each +// conforming user. +StatusOr>> WalkDownMemoryOffload( + const std::pair& current_value, + const CallGraph& call_graph) { + // TODO(maggioni): Verify that set of instructions supported in chain by + // legalization is in sync with host_offloader. + VLOG(5) << "Current value in progress: " << current_value.first->ToString() + << " idx: " << current_value.second; + std::vector> results; + auto add_gte_for_idx = [&results](HloInstruction* instr, int idx) -> Status { + HloInstruction* gte = nullptr; + for (HloInstruction* user : instr->users()) { + if (user->opcode() != HloOpcode::kGetTupleElement) { + return absl::InvalidArgumentError( + "Expected users to be only get-tuple-elements"); + } + if (user->tuple_index() != idx) { + continue; + } + if (gte != nullptr) { + return absl::InvalidArgumentError( + "Expected to find only one gte per index."); + } + results.push_back(std::make_pair(user, -1)); + } + return OkStatus(); + }; + if (current_value.first->user_count() == 0) { + if (current_value.first->parent()->root_instruction() == + current_value.first) { + auto callers = + call_graph.GetComputationCallers(current_value.first->parent()); + if (callers.size() != 1 || callers[0]->opcode() != HloOpcode::kWhile) { + return absl::InvalidArgumentError( + "Expected to be called only by one caller and caller be a While"); + } + TF_RETURN_IF_ERROR(add_gte_for_idx(callers[0], current_value.second)); + return results; + } + } + if (current_value.first->opcode() == HloOpcode::kParameter && + current_value.first->shape().IsTuple()) { + TF_RETURN_IF_ERROR( + add_gte_for_idx(current_value.first, current_value.second)); + return results; + } + for (HloInstruction* user : current_value.first->users()) { + switch (user->opcode()) { + case HloOpcode::kGetTupleElement: { + CHECK_NE(user->tuple_index(), -1); + if (user->tuple_index() != current_value.second) { + continue; + } + results.push_back(std::make_pair(user, -1)); + break; + } + case HloOpcode::kTuple: { + auto output_indices = user->OperandIndices(current_value.first); + if (output_indices.size() != 1) { + return absl::InvalidArgumentError( + "Expected operand to be used only once in the tuple."); + } + results.push_back(std::make_pair(user, output_indices[0])); + break; + } + case HloOpcode::kOptimizationBarrier: { + results.push_back(std::make_pair(user, current_value.second)); + break; + } + case HloOpcode::kWhile: { + HloComputation* while_body = user->while_body(); + HloInstruction* parameter = while_body->parameter_instruction(0); + results.push_back(std::make_pair(parameter, current_value.second)); + break; + } + case HloOpcode::kDynamicUpdateSlice: { + if (user->OperandIndices(current_value.first)[0] != 0) { + return absl::InvalidArgumentError( + "Expected to be used by first operand of dynamic-update-slice"); + } + results.push_back(std::make_pair(user, current_value.second)); + break; + } + case HloOpcode::kCustomCall: { + if (user->IsCustomCall(host_memory_offload_annotations:: + kMoveToDeviceCustomCallTarget)) { + results.push_back(std::make_pair(user, current_value.second)); + break; + } + return absl::InvalidArgumentError("Invalid custom-call found."); + } + case HloOpcode::kBitcast: + case HloOpcode::kCopy: + case HloOpcode::kDynamicSlice: + case HloOpcode::kReshape: + case HloOpcode::kSlice: { + results.push_back(std::make_pair(user, current_value.second)); + break; + } + default: { + return absl::InvalidArgumentError("Unrecognized user opcode"); + } + } + } + return results; +} + +StatusOr ProcessAnnotationForCopyMovement( + HloInstruction* instruction, const CallGraph* call_graph, + absl::flat_hash_set& processed_annotations, + std::vector& to_remove) { + HloInstruction* starting_instr = + FindDUSFromAnnotation(instruction->users()[0]); + // If it's the pure copy case reset instruction. + if (starting_instr->opcode() != HloOpcode::kDynamicUpdateSlice) { + starting_instr = instruction; + } + VLOG(3) << "Dus or Annotation: " << starting_instr->ToString(); + std::pair current_value = + std::make_pair(starting_instr, -1); + // Found a copy that would block offloading. Walk up to find all annotations + // to update (required in case there are multiple insertions in the buffer). + processed_annotations.insert(current_value.first); + if (!current_value.first->IsCustomCall( + host_memory_offload_annotations::kMoveToHostCustomCallTarget)) { + CHECK_EQ(current_value.first->opcode(), HloOpcode::kDynamicUpdateSlice); + while (true) { + VLOG(10) << "Current value before: " << current_value.first->ToString(); + auto current_value_up = WalkUpMemoryOffload(current_value, *call_graph); + // Invalid upward walking means the chain is unrecognized. + if (!current_value_up.ok()) { + return false; + } + // This means we encountered a broadcast with constant 0 expansion. + if (current_value_up.value() == current_value) { + break; + } + current_value = current_value_up.value(); + VLOG(10) << "Current value after: " << current_value.first->ToString(); + HloInstruction* annotation = current_value.first; + if (annotation->opcode() == HloOpcode::kDynamicUpdateSlice) { + HloInstruction* real_annotation = + FindAnnotationToUpdate(annotation->mutable_operand(1)); + // Check if this dynamic-update-slice doesn't have an annotation + // attached. + if (!real_annotation->IsCustomCall( + host_memory_offload_annotations::kMoveToHostCustomCallTarget)) { + return false; + } + } + } + } + std::vector> copies_to_move; + // Do a final walkdown from the top to collect all the instructions that need + // their shape updated. + std::vector> stack(1, current_value); + while (!stack.empty()) { + VLOG(5) << "Current value before down: " << stack.back().first->ToString(); + if (absl::c_linear_search(kUsersOpcodes, stack.back().first->opcode()) || + stack.back().first->IsCustomCall( + host_memory_offload_annotations::kMoveToDeviceCustomCallTarget)) { + HloInstruction* annotation = + FindToDeviceAnnotationToUpdate(stack.back().first); + if (!annotation || + !annotation->IsCustomCall( + host_memory_offload_annotations::kMoveToDeviceCustomCallTarget)) { + VLOG(5) << "Couldn't find annotation for consumer instruction in chain"; + return false; + } + stack.pop_back(); + continue; + } + auto current_value_down = WalkDownMemoryOffload(stack.back(), *call_graph); + if (!current_value_down.ok()) { + VLOG(5) << "Current value down failed: " << current_value_down.status(); + break; + } + stack.pop_back(); + stack.insert(stack.end(), current_value_down.value().begin(), + current_value_down.value().end()); + for (auto& instruction : current_value_down.value()) { + VLOG(5) << "Current value last down: " << stack.back().first->ToString(); + if (instruction.first->opcode() == HloOpcode::kCopy) { + copies_to_move.push_back(instruction); + } + } + } + auto update_shape_layout = + [&](const std::pair& instruction, + HloInstruction* copy_to_move) { + // Update shape. Tuple shape vs array shape. + if (instruction.second != -1) { + *instruction.first->mutable_shape() + ->mutable_tuple_shapes(instruction.second) + ->mutable_layout() = copy_to_move->operand(0)->shape().layout(); + } else { + *instruction.first->mutable_shape()->mutable_layout() = + copy_to_move->operand(0)->shape().layout(); + } + }; + // Process all copies one at a time from the last to the first and push it to + // its specific user. + while (!copies_to_move.empty()) { + auto& copy_to_move = copies_to_move.back(); + VLOG(5) << "Copy to move: " << copy_to_move.first->ToString(); + stack.clear(); + stack.push_back(copy_to_move); + while (!stack.empty()) { + VLOG(5) << "Current value before down: " + << stack.back().first->ToString(); + auto current_value_down = + WalkDownMemoryOffload(stack.back(), *call_graph); + if (!current_value_down.ok()) { + VLOG(5) << "Current value down failed: " << current_value_down.status(); + break; + } + for (auto& instruction : current_value_down.value()) { + update_shape_layout(instruction, copy_to_move.first); + if (instruction.first->opcode() == HloOpcode::kParameter) { + auto callers = + call_graph->GetComputationCallers(instruction.first->parent()); + if (callers.size() != 1) { + return absl::InvalidArgumentError( + "Expected to be called only by one caller"); + } + auto* caller = callers[0]; + update_shape_layout(std::make_pair(caller, instruction.second), + copy_to_move.first); + } + } + stack.pop_back(); + for (auto& instruction : current_value_down.value()) { + VLOG(5) << "Current value last down: " << instruction.first->ToString(); + CHECK_NE(instruction.first->opcode(), HloOpcode::kCopy) + << "Copies should be processed in order"; + if (absl::c_linear_search(kUsersOpcodes, instruction.first->opcode()) || + instruction.first->IsCustomCall( + host_memory_offload_annotations:: + kMoveToDeviceCustomCallTarget)) { + HloInstruction* annotation = + FindToDeviceAnnotationToUpdate(instruction.first); + CHECK_NE(annotation, nullptr) + << "We already verified we could find an annotation here. " + "Something went wrong."; + HloInstruction* new_annotation = nullptr; + if (instruction.first->opcode() == HloOpcode::kCustomCall) { + new_annotation = annotation; + } else { + new_annotation = instruction.first->AddInstruction( + annotation->CloneWithNewOperands(instruction.first->shape(), + {instruction.first})); + } + update_shape_layout(std::make_pair(new_annotation, -1), + copy_to_move.first); + HloInstruction* new_copy = instruction.first->AddInstruction( + copy_to_move.first->CloneWithNewOperands(new_annotation->shape(), + {new_annotation})); + std::vector users = instruction.first->users(); + for (auto* use : users) { + if (use == new_copy || use == new_annotation) { + continue; + } + TF_RETURN_IF_ERROR( + instruction.first->ReplaceUseWithDifferentShape(use, new_copy)); + } + // Move the copy here. + if (new_annotation != annotation) { + TF_RETURN_IF_ERROR(annotation->ReplaceAllUsesWithDifferentShape( + annotation->mutable_operand(0))); + to_remove.push_back(annotation); + } + continue; + } + // Move the annotation first just before dynamic-update-slice to avoid + // shape changes. + if (instruction.first->opcode() == HloOpcode::kDynamicUpdateSlice) { + HloInstruction* annotation = + FindAnnotationToUpdate(instruction.first->mutable_operand(1)); + if (annotation == nullptr) { + CHECK(false); + return false; + } + CHECK(annotation->opcode() == HloOpcode::kCustomCall); + HloInstruction* new_annotation = instruction.first->AddInstruction( + annotation->CloneWithNewOperands( + instruction.first->operand(1)->shape(), + {instruction.first->mutable_operand(1)})); + TF_RETURN_IF_ERROR( + instruction.first->ReplaceOperandWith(1, new_annotation)); + TF_RETURN_IF_ERROR( + annotation->ReplaceAllUsesWith(annotation->mutable_operand(0))); + processed_annotations.insert(annotation); + processed_annotations.insert(new_annotation); + to_remove.push_back(annotation); + } + stack.push_back(instruction); + } + } + VLOG(5) << "MOVED: " << copy_to_move.first->ToString(); + TF_RETURN_IF_ERROR(copy_to_move.first->ReplaceAllUsesWithDifferentShape( + copy_to_move.first->mutable_operand(0))); + TF_RETURN_IF_ERROR( + copy_to_move.first->parent()->RemoveInstruction(copy_to_move.first)); + copies_to_move.pop_back(); + } + return true; +} + +// Fixes layout changing copies in between on the path to users. +StatusOr FixupInterveningCopies( + const std::vector& copy_to_host_annotations, + const CallGraph* call_graph) { + absl::flat_hash_set processed_annotations; + std::vector annotations_to_remove; + bool changed = false; + for (HloInstruction* instruction : copy_to_host_annotations) { + if (processed_annotations.count(instruction)) { + continue; + } + TF_ASSIGN_OR_RETURN(bool changed_annotation_for_copy_movement, + ProcessAnnotationForCopyMovement( + instruction, call_graph, processed_annotations, + annotations_to_remove)); + changed |= changed_annotation_for_copy_movement; + } + for (HloInstruction* instruction : annotations_to_remove) { + TF_RETURN_IF_ERROR(instruction->parent()->RemoveInstruction(instruction)); + } + return changed; +} + +} // namespace + +StatusOr HostOffloadLegalize::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + bool changed = false; + + // Split broadcasts so that each HloUse of a broadcast instruction will get + // its own copy. + // TODO(b/319293925): Do not blindly duplicate all broadcasts, instead do it + // only when necessary. + TF_ASSIGN_OR_RETURN(bool duplicated_at_least_one_broadcast, + DuplicateBroadcastForEachUse(module)); + if (duplicated_at_least_one_broadcast) { + changed = true; + } + if (!after_layout_) { + return changed; + } + std::unique_ptr call_graph = CallGraph::Build(module); + std::vector copy_to_host_annotations; + + // Iterate over all instructions and look for XLA host offload annotations. + for (HloComputation* computation : + module->MakeNonfusionComputations(execution_threads)) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kCustomCall) { + continue; + } + if (instruction->custom_call_target() == + host_memory_offload_annotations::kMoveToHostCustomCallTarget) { + copy_to_host_annotations.push_back(instruction); + } + } + } + // Fixup layout changing copies that are in between memory offloaded sections. + // Move them before the data is moved to the host. + TF_ASSIGN_OR_RETURN( + bool changed_intervening_copies, + FixupInterveningCopies(copy_to_host_annotations, call_graph.get())); + changed |= changed_intervening_copies; + + return changed; +} + +} // namespace xla diff --git a/third_party/xla/xla/service/host_offload_legalize.h b/third_party/xla/xla/service/host_offload_legalize.h new file mode 100644 index 00000000000000..3bbe387c557a03 --- /dev/null +++ b/third_party/xla/xla/service/host_offload_legalize.h @@ -0,0 +1,55 @@ +/* Copyright 2024 The OpenXLA Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#ifndef XLA_SERVICE_HOST_OFFLOAD_LEGALIZE_H_ +#define XLA_SERVICE_HOST_OFFLOAD_LEGALIZE_H_ + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "xla/service/hlo_alias_analysis.h" +#include "xla/service/hlo_pass_interface.h" + +namespace xla { + +class HloCostAnalysis; + +// This pass legalizes the graph for the "host memory offloading" pass to +// correctly identified buffers that are meant to be move on the host. Any +// legalization that could block that is welcome into this pass. +class HostOffloadLegalize : public HloModulePass { + public: + explicit HostOffloadLegalize(int64_t host_memory_space_color, + bool after_layout) + : kHostMemorySpaceColor(host_memory_space_color), + after_layout_(after_layout) {} + ~HostOffloadLegalize() override = default; + + absl::string_view name() const override { return "host-offload-legalize"; } + + using HloPassInterface::Run; + StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + const int64_t kHostMemorySpaceColor; + const bool after_layout_; +}; + +} // namespace xla + +#endif // XLA_SERVICE_HOST_OFFLOAD_LEGALIZE_H_ diff --git a/third_party/xla/xla/service/host_offload_legalize_test.cc b/third_party/xla/xla/service/host_offload_legalize_test.cc new file mode 100644 index 00000000000000..6ef84706147b4c --- /dev/null +++ b/third_party/xla/xla/service/host_offload_legalize_test.cc @@ -0,0 +1,349 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/host_offload_legalize.h" + +#include +#include +#include + +#include +#include +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/host_memory_offload_annotations.h" +#include "xla/service/pattern_matcher.h" +#include "xla/service/pattern_matcher_gmock.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/statusor.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/util.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" + +namespace m = ::xla::match; + +namespace xla { +namespace { + +class HostOffloadLegalizeTest : public HloTestBase { + protected: + static constexpr int64_t kHostMemorySpaceColor{5}; + + StatusOr RunHostOffloadLegalize(HloModule* module) { + TF_EXPECT_OK(verifier().Run(module).status()); + if (module->has_schedule()) { + return absl::InternalError("Expected a non-scheduled module"); + } + HostOffloadLegalize host_offload_legalize(kHostMemorySpaceColor, + /*after_layout=*/true); + return host_offload_legalize.Run(module); + } + + void TestShapeHasMemorySpace(const Shape& shape, int64_t memory_space) { + ASSERT_TRUE(shape.has_layout()); + EXPECT_EQ(shape.layout().memory_space(), memory_space); + } + + bool HaveRemainingOffloadAnnotations(const HloModule* module) { + for (const HloComputation* computation : module->computations()) { + for (const HloInstruction* instruction : computation->instructions()) { + if (instruction->IsCustomCall( + {host_memory_offload_annotations::kMoveToHostCustomCallTarget, + host_memory_offload_annotations:: + kMoveToDeviceCustomCallTarget})) { + return true; + } + } + } + return false; + } +}; + +TEST_F(HostOffloadLegalizeTest, NoCopyWithOptBarrierMoreElaborate) { + const std::string& hlo_string = R"( +HloModule jit_f, entry_computation_layout={(f32[16,256]{0,1})->f32[16,256]{0,1}} + +ENTRY main.24 { + Arg_0.1 = f32[16,256]{0,1} parameter(0) + cosine.4 = f32[16,256]{0,1} cosine(Arg_0.1) + custom-call.5 = f32[16,256]{0,1} custom-call(cosine.4), custom_call_target="PipelineForward" + sine.3 = f32[16,256]{0,1} sine(Arg_0.1) + cosine.7 = f32[16,256]{0,1} cosine(sine.3) + custom-call.8 = f32[16,256]{0,1} custom-call(cosine.7), custom_call_target="PipelineForward" + sine.6 = f32[16,256]{0,1} sine(sine.3) + cosine.9 = f32[16,256]{0,1} cosine(sine.6) + custom-call.10 = f32[16,256]{0,1} custom-call(cosine.9), custom_call_target="PipelineForward" + constant.2 = f32[] constant(1) + cp = f32[16,256]{1,0} copy(custom-call.8) + tuple.11 = (f32[16,256]{0,1}, f32[16,256]{1,0}, f32[16,256]{0,1}, f32[]) tuple(custom-call.5, cp, custom-call.10, constant.2) + opt-barrier.12 = (f32[16,256]{0,1}, f32[16,256]{1,0}, f32[16,256]{0,1}, f32[]) opt-barrier(tuple.11) + get-tuple-element.16 = f32[] get-tuple-element(opt-barrier.12), index=3 + broadcast.20 = f32[16,256]{0,1} broadcast(get-tuple-element.16), dimensions={} + get-tuple-element.15 = f32[16,256]{0,1} get-tuple-element(opt-barrier.12), index=2 + custom-call.19 = f32[16,256]{0,1} custom-call(get-tuple-element.15), custom_call_target="PipelineBackward" + multiply.21 = f32[16,256]{0,1} multiply(broadcast.20, custom-call.19) + cp2 = f32[16,256]{1,0} copy(multiply.21) + get-tuple-element.14 = f32[16,256]{1,0} get-tuple-element(opt-barrier.12), index=1 + custom-call.18 = f32[16,256]{1,0} custom-call(get-tuple-element.14), custom_call_target="PipelineBackward" + multiply.22 = f32[16,256]{1,0} multiply(cp2, custom-call.18) + get-tuple-element.13 = f32[16,256]{0,1} get-tuple-element(opt-barrier.12), index=0 + custom-call.17 = f32[16,256]{0,1} custom-call(get-tuple-element.13), custom_call_target="PipelineBackward" + cp3 = f32[16,256]{1,0} copy(custom-call.17) + ROOT multiply.23 = f32[16,256]{1,0} multiply(multiply.22, cp3) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloadLegalize(module.get())); + + EXPECT_TRUE(changed); + XLA_VLOG_LINES(1, module->ToString()); + HloInstruction* custom_call = FindInstruction(module.get(), "custom-call.18"); + EXPECT_EQ(custom_call->users()[0]->opcode(), HloOpcode::kCopy); + EXPECT_EQ(custom_call->shape().layout(), LayoutUtil::MakeLayout({0, 1})); +} + +TEST_F(HostOffloadLegalizeTest, LlmActivationHostMemoryMultipleConsumers) { + const std::string& hlo_string = R"( +HloModule llm_while + +producing_while_condition { + producing_condition_param = (s32[], f32[96,8,6,2048,2048]{0,1,2,3,4}) parameter(0) + producing_condition_current_iteration_index = s32[] get-tuple-element(producing_condition_param), index=0 + producing_condition_iteration_count = s32[] constant(96) + ROOT producing_condition_result = pred[] compare(producing_condition_current_iteration_index, producing_condition_iteration_count), direction=LT +} + +consuming_while_condition { + consuming_condition_param = (s32[], f32[96,8,6,2048,2048]{0,1,2,3,4}) parameter(0) + consuming_condition_current_iteration_index = s32[] get-tuple-element(consuming_condition_param), index=0 + consuming_condition_iteration_count = s32[] constant(96) + ROOT consuming_condition_result = pred[] compare(consuming_condition_current_iteration_index, consuming_condition_iteration_count), direction=LT +} + +producing_while_body { + input_tuple.0 = (s32[], f32[96,8,6,2048,2048]{0,1,2,3,4}) parameter(0) + current_iteration_index.0 = s32[] get-tuple-element(input_tuple.0), index=0 + data_0.0 = f32[96,8,6,2048,2048]{0,1,2,3,4} get-tuple-element(input_tuple.0), index=1 + constant_0.0 = s32[] constant(0) + constant_1.0 = s32[] constant(1) + constant_96 = s32[] constant(96) + + /* Create dummy data used in DUS */ + slice_data_0 = f32[1,8,6,2048,2048] constant({...}) + + /* Build DUS index */ + compare_result.0 = pred[] compare(current_iteration_index.0, constant_0.0), direction=LT + add_result = s32[] add(current_iteration_index.0, constant_96) + select_result.0 = s32[] select(compare_result.0, add_result, current_iteration_index.0) + + /* Annotate DUS for offload */ + custom_call_0.0 = f32[1,8,6,2048,2048] custom-call(slice_data_0), custom_call_target="PipelineForward" + + dynamic_update_slice_0 = f32[96,8,6,2048,2048]{0,1,2,3,4} dynamic-update-slice(data_0.0, custom_call_0.0, select_result.0, constant_0.0, constant_0.0, constant_0.0, constant_0.0) + + /* Increment iteration index */ + incremented_index.0 = s32[] add(current_iteration_index.0, constant_1.0) + ROOT tuple_result.0 = (s32[], f32[96,8,6,2048,2048]{0,1,2,3,4}) tuple(incremented_index.0, dynamic_update_slice_0) +} + +consuming_while_body { + input_tuple.1 = (s32[], f32[96,8,6,2048,2048]{0,1,3,2,4}) parameter(0) + current_iteration_index.1 = s32[] get-tuple-element(input_tuple.1), index=0 + data_0.1 = f32[96,8,6,2048,2048]{0,1,3,2,4} get-tuple-element(input_tuple.1), index=1 + constant_0.1 = s32[] constant(0) + constant_1.1 = s32[] constant(1) + constant_95 = s32[] constant(95) + constant_191 = s32[] constant(191) + + /* Build DS index */ + subtract_0 = s32[] subtract(constant_95, current_iteration_index.1) + compare_result.1 = pred[] compare(subtract_0, constant_0.1), direction=LT + subtract_1 = s32[] subtract(constant_191, current_iteration_index.1) + select_result.1 = s32[] select(compare_result.1, subtract_1, subtract_0) + + dynamic_slice_0 = f32[1,8,6,2048,2048] dynamic-slice(data_0.1, select_result.1, constant_0.1, constant_0.1, constant_0.1, constant_0.1), dynamic_slice_sizes={1,8,6,2048,2048} + + /* Annotate DS for offload */ + custom_call_0.1 = f32[1,8,6,2048,2048] custom-call(dynamic_slice_0), custom_call_target="PipelineBackward" + + /* Do some work with the dynamic slice outputs. */ + tanh_0 = f32[1,8,6,2048,2048] tanh(custom_call_0.1) + + /* Increment iteration index */ + incremented_index.1 = s32[] add(current_iteration_index.1, constant_1.1) + ROOT tuple_result.1 = (s32[], f32[96,8,6,2048,2048]{0,1,3,2,4}) tuple(incremented_index.1, data_0.1) +} + +ENTRY main { + entry_param_0 = f32[] parameter(0) + entry_param_1 = s32[] parameter(1) + entry_param_2 = s32[] parameter(2) + cs0 = f32[] constant(0) + broadcast_0 = f32[96,8,6,2048,2048]{0,1,2,3,4} broadcast(cs0), dimensions={} + constant_s32_0 = s32[] constant(0) + tuple_for_producing_while = (s32[], f32[96,8,6,2048,2048]{0,1,2,3,4}) tuple(constant_s32_0, broadcast_0) + producing_while = (s32[], f32[96,8,6,2048,2048]{0,1,2,3,4}) while(tuple_for_producing_while), condition=producing_while_condition, body=producing_while_body + while_output_1 = f32[96,8,6,2048,2048]{0,1,2,3,4} get-tuple-element(producing_while), index=1 + cp = f32[96,8,6,2048,2048]{0,1,3,2,4} copy(while_output_1) + tuple_for_consuming_while = (s32[], f32[96,8,6,2048,2048]{0,1,3,2,4}) tuple(constant_s32_0, cp) + consuming_while = (s32[], f32[96,8,6,2048,2048]{0,1,3,2,4}) while(tuple_for_consuming_while), condition=consuming_while_condition, body=consuming_while_body + second_while_output = f32[96,8,6,2048,2048]{0,1,3,2,4} get-tuple-element(consuming_while), index=1 + final_dynamic_slice_0 = f32[1,8,6,2048,2048] dynamic-slice(second_while_output, entry_param_1, constant_s32_0, constant_s32_0, constant_s32_0, constant_s32_0), dynamic_slice_sizes={1,8,6,2048,2048} + final_host_to_device_custom_call_0 = f32[1,8,6,2048,2048] custom-call(final_dynamic_slice_0), custom_call_target="PipelineBackward" + final_slice_0 = f32[1,8,6,2048,2048] slice(second_while_output), slice={[41:42], [0:8], [0:6], [0:2048], [0:2048]} + final_host_to_device_custom_call_1 = f32[1,8,6,2048,2048] custom-call(final_slice_0), custom_call_target="PipelineBackward" + ROOT add = f32[1,8,6,2048,2048] add(final_host_to_device_custom_call_0, final_host_to_device_custom_call_1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloadLegalize(module.get())); + + EXPECT_TRUE(changed); + HloInstruction* copy = FindInstruction(module.get(), HloOpcode::kCopy); + HloInstruction* consuming_while = + FindInstruction(module.get(), "consuming_while"); + EXPECT_NE(copy, nullptr); + EXPECT_NE(consuming_while, nullptr); + EXPECT_EQ(copy->parent(), consuming_while->while_body()); + XLA_VLOG_LINES(1, module->ToString()); +} + +TEST_F(HostOffloadLegalizeTest, LlmActivationHostMemoryMultipleCopies) { + const std::string& hlo_string = R"( +HloModule llm_while + +producing_while_condition { + producing_condition_param = (s32[], f32[96,8,6,2048,2048]{0,1,2,3,4}) parameter(0) + producing_condition_current_iteration_index = s32[] get-tuple-element(producing_condition_param), index=0 + producing_condition_iteration_count = s32[] constant(96) + ROOT producing_condition_result = pred[] compare(producing_condition_current_iteration_index, producing_condition_iteration_count), direction=LT +} + +consuming_while_condition { + consuming_condition_param = (s32[], f32[96,8,6,2048,2048]{0,1,2,3,4}) parameter(0) + consuming_condition_current_iteration_index = s32[] get-tuple-element(consuming_condition_param), index=0 + consuming_condition_iteration_count = s32[] constant(96) + ROOT consuming_condition_result = pred[] compare(consuming_condition_current_iteration_index, consuming_condition_iteration_count), direction=LT +} + +producing_while_body { + input_tuple.0 = (s32[], f32[96,8,6,2048,2048]{0,1,2,3,4}) parameter(0) + current_iteration_index.0 = s32[] get-tuple-element(input_tuple.0), index=0 + data_0.0 = f32[96,8,6,2048,2048]{0,1,2,3,4} get-tuple-element(input_tuple.0), index=1 + constant_0.0 = s32[] constant(0) + constant_1.0 = s32[] constant(1) + constant_96 = s32[] constant(96) + + /* Create dummy data used in DUS */ + slice_data_0 = f32[1,8,6,2048,2048] constant({...}) + + /* Build DUS index */ + compare_result.0 = pred[] compare(current_iteration_index.0, constant_0.0), direction=LT + add_result = s32[] add(current_iteration_index.0, constant_96) + select_result.0 = s32[] select(compare_result.0, add_result, current_iteration_index.0) + + /* Annotate DUS for offload */ + custom_call_0.0 = f32[1,8,6,2048,2048] custom-call(slice_data_0), custom_call_target="PipelineForward" + + dynamic_update_slice_0 = f32[96,8,6,2048,2048]{0,1,2,3,4} dynamic-update-slice(data_0.0, custom_call_0.0, select_result.0, constant_0.0, constant_0.0, constant_0.0, constant_0.0) + + /* Increment iteration index */ + incremented_index.0 = s32[] add(current_iteration_index.0, constant_1.0) + ROOT tuple_result.0 = (s32[], f32[96,8,6,2048,2048]{0,1,2,3,4}) tuple(incremented_index.0, dynamic_update_slice_0) +} + +consuming_while_body { + input_tuple.1 = (s32[], f32[96,8,6,2048,2048]{0,1,3,2,4}) parameter(0) + current_iteration_index.1 = s32[] get-tuple-element(input_tuple.1), index=0 + data_0.1 = f32[96,8,6,2048,2048]{0,1,3,2,4} get-tuple-element(input_tuple.1), index=1 + constant_0.1 = s32[] constant(0) + constant_1.1 = s32[] constant(1) + constant_95 = s32[] constant(95) + constant_191 = s32[] constant(191) + + /* Build DS index */ + subtract_0 = s32[] subtract(constant_95, current_iteration_index.1) + compare_result.1 = pred[] compare(subtract_0, constant_0.1), direction=LT + subtract_1 = s32[] subtract(constant_191, current_iteration_index.1) + select_result.1 = s32[] select(compare_result.1, subtract_1, subtract_0) + + dynamic_slice_0 = f32[1,8,6,2048,2048] dynamic-slice(data_0.1, select_result.1, constant_0.1, constant_0.1, constant_0.1, constant_0.1), dynamic_slice_sizes={1,8,6,2048,2048} + + /* Annotate DS for offload */ + custom_call_0.1 = f32[1,8,6,2048,2048] custom-call(dynamic_slice_0), custom_call_target="PipelineBackward" + + /* Do some work with the dynamic slice outputs. */ + tanh_0 = f32[1,8,6,2048,2048] tanh(custom_call_0.1) + + /* Increment iteration index */ + incremented_index.1 = s32[] add(current_iteration_index.1, constant_1.1) + ROOT tuple_result.1 = (s32[], f32[96,8,6,2048,2048]{0,1,3,2,4}) tuple(incremented_index.1, data_0.1) +} + +ENTRY main { + entry_param_0 = f32[] parameter(0) + entry_param_1 = s32[] parameter(1) + entry_param_2 = s32[] parameter(2) + cs0 = f32[] constant(0) + broadcast_0 = f32[96,8,6,2048,2048]{0,1,2,3,4} broadcast(cs0), dimensions={} + constant_s32_0 = s32[] constant(0) + tuple_for_producing_while = (s32[], f32[96,8,6,2048,2048]{0,1,2,3,4}) tuple(constant_s32_0, broadcast_0) + producing_while = (s32[], f32[96,8,6,2048,2048]{0,1,2,3,4}) while(tuple_for_producing_while), condition=producing_while_condition, body=producing_while_body + while_output_1 = f32[96,8,6,2048,2048]{0,1,2,3,4} get-tuple-element(producing_while), index=1 + cp = f32[96,8,6,2048,2048]{0,1,3,2,4} copy(while_output_1) + cp1 = f32[96,8,6,2048,2048]{0,1,3,2,4} copy(cp) + tuple_for_consuming_while = (s32[], f32[96,8,6,2048,2048]{0,1,3,2,4}) tuple(constant_s32_0, cp1) + consuming_while = (s32[], f32[96,8,6,2048,2048]{0,1,3,2,4}) while(tuple_for_consuming_while), condition=consuming_while_condition, body=consuming_while_body + second_while_output = f32[96,8,6,2048,2048]{0,1,3,2,4} get-tuple-element(consuming_while), index=1 + final_dynamic_slice_0 = f32[1,8,6,2048,2048] dynamic-slice(second_while_output, entry_param_1, constant_s32_0, constant_s32_0, constant_s32_0, constant_s32_0), dynamic_slice_sizes={1,8,6,2048,2048} + final_host_to_device_custom_call_0 = f32[1,8,6,2048,2048] custom-call(final_dynamic_slice_0), custom_call_target="PipelineBackward" + final_slice_0 = f32[1,8,6,2048,2048] slice(second_while_output), slice={[41:42], [0:8], [0:6], [0:2048], [0:2048]} + final_host_to_device_custom_call_1 = f32[1,8,6,2048,2048] custom-call(final_slice_0), custom_call_target="PipelineBackward" + ROOT add = f32[1,8,6,2048,2048] add(final_host_to_device_custom_call_0, final_host_to_device_custom_call_1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloadLegalize(module.get())); + + EXPECT_TRUE(changed); + HloInstruction* copy_0 = FindInstruction(module.get(), "cp.2"); + HloInstruction* copy_1 = FindInstruction(module.get(), "cp1.2"); + HloInstruction* consuming_while = + FindInstruction(module.get(), "consuming_while"); + EXPECT_NE(copy_0, nullptr); + EXPECT_NE(copy_1, nullptr); + EXPECT_NE(consuming_while, nullptr); + EXPECT_EQ(copy_0->parent(), module->entry_computation()); + EXPECT_EQ(copy_1->operand(0), copy_0); + XLA_VLOG_LINES(1, module->ToString()); +} + +} // namespace + +} // namespace xla diff --git a/third_party/xla/xla/service/host_offloader.cc b/third_party/xla/xla/service/host_offloader.cc index 32c36db1689a17..021a92d6e52cbc 100644 --- a/third_party/xla/xla/service/host_offloader.cc +++ b/third_party/xla/xla/service/host_offloader.cc @@ -30,8 +30,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal_util.h" +#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_buffer.h" #include "xla/service/hlo_value.h" +#include "xla/service/host_memory_offload_annotations.h" +#include "xla/service/pattern_matcher.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status.h" @@ -49,51 +52,106 @@ void SetMemorySpace(Shape* shape, int64_t memory_space_color) { shape->mutable_layout()->set_memory_space(memory_space_color); } -StatusOr DuplicateBroadcastForEachUse(HloModule* module) { - bool split_at_least_one = false; - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : - computation->MakeInstructionPostOrder()) { - if (instruction->opcode() != HloOpcode::kBroadcast || - !instruction->HasConstantOperand()) { - continue; - } - absl::InlinedVector uses; - for (HloInstruction* user : instruction->users()) { - for (int64_t i = 0; i < user->operand_count(); ++i) { - if (user->operand(i) != instruction) { - continue; +// Checks if all of the HloPositions of this HloValue, apart from the defining +// position, are allowed when doing memory-only offload. +bool AllPositionsAreAllowed(const HloValue* value) { + // Given an HloValue, validate that none of its positions are doing any + // compute. + for (const HloPosition& position : value->positions()) { + if (position == value->defining_position()) { + // Skip defining positions. + continue; + } + // Check if this position is of an allowed type. + if (!absl::c_linear_search(HostOffloader::GetAllowedPositionOpcodes(), + position.instruction->opcode())) { + VLOG(1) << "Position " << position.instruction->ToString() + << " is not supported."; + return false; + } + } + + // Did not find any invalid ops. + return true; +} + +bool DefiningPositionIsAllowed(const HloInstruction* instruction) { + static constexpr std::array kAllowedOpcodes = {HloOpcode::kWhile, + HloOpcode::kParameter}; + return absl::c_linear_search(kAllowedOpcodes, instruction->opcode()); +} + +template +StatusOr BufferHasPositionWithUser(const HloBuffer& buffer, + MatcherType matcher) { + HloInstruction* result = nullptr; + for (const HloValue* value : buffer.values()) { + for (const HloPosition& position : value->positions()) { + for (HloInstruction* user : position.instruction->users()) { + if (Match(user, matcher)) { + if (result != nullptr && result != user) { + return Internal("Found multiple matching users! At least %s and %s", + result->name(), user->name()); } - uses.push_back(HloUse{user, i, /*operand_index=*/{}}); + result = user; } } + } + } + return result; +} - if (uses.size() <= 1) { - continue; +template +StatusOr> GetBufferUsersOfType( + const HloBuffer& buffer, MatcherType matcher) { + std::vector result; + for (const HloValue* value : buffer.values()) { + VLOG(3) << "Buffer defined at " << value->defining_instruction()->name() + << " has positions [" + << absl::StrJoin(value->positions(), ", ", + [](std::string* out, const HloPosition& position) { + out->append(position.instruction->name()); + }) + << "]"; + for (const HloPosition& position : value->positions()) { + VLOG(4) << " Position " << position.instruction->name() << " has users [" + << absl::StrJoin( + position.instruction->users(), ", ", + [](std::string* out, const HloInstruction* user) { + out->append(user->name()); + }) + << "]"; + for (HloInstruction* user : position.instruction->users()) { + if (Match(user, matcher)) { + result.emplace_back(user); + } } + } + } + return result; +} - VLOG(1) << "Splitting broadcast " << instruction->ToString() - << " which has " << uses.size() << " uses"; - split_at_least_one = true; - // Don't create a new broadcast for the first use; we can still use the - // original. - for (int i = 1; i < uses.size(); ++i) { - const HloUse& use = uses[i]; - HloInstruction* new_broadcast = - instruction->parent()->AddInstruction(instruction->Clone()); - VLOG(2) << "New broadcast " << new_broadcast->ToString(); - TF_RETURN_IF_ERROR(use.instruction->ReplaceOperandWith( - use.operand_number, new_broadcast)); - } +HloInstruction* FindDSAnnotation(HloInstruction* hlo) { + while (!hlo->IsCustomCall( + host_memory_offload_annotations::kMoveToDeviceCustomCallTarget)) { + if (hlo->opcode() != HloOpcode::kReshape && + hlo->opcode() != HloOpcode::kBitcast) { + break; } + if (hlo->user_count() != 1) { + break; + } + hlo = hlo->users()[0]; } - return split_at_least_one; + return hlo; } } // namespace Status HostOffloader::HandlePipelineForwardCustomCall( HloInstruction* custom_call) { + VLOG(2) << "Found a custom call annotating start-of-host-offload: " + << custom_call->ToString(); // Save a pointer to this custom call for when we want to remove it later. custom_calls_to_remove_.emplace(custom_call); @@ -107,91 +165,327 @@ Status HostOffloader::HandlePipelineForwardCustomCall( out->append(user->name()); })); } - HloInstruction* dynamic_update_slice = custom_call->users()[0]; + HloInstruction* op_being_annotated = custom_call->users()[0]; // Skip past any bitcasts. - // TODO(b/319167527): Update this to be a bit more generic and safe. - while (dynamic_update_slice->opcode() == HloOpcode::kBitcast) { - VLOG(1) << "Skipping bitcast " << dynamic_update_slice->ToString(); - dynamic_update_slice = dynamic_update_slice->users()[0]; + while (op_being_annotated->opcode() == HloOpcode::kBitcast) { + VLOG(1) << "Skipping bitcast " << op_being_annotated->ToString(); + op_being_annotated = op_being_annotated->users()[0]; } - if (dynamic_update_slice->opcode() != HloOpcode::kDynamicUpdateSlice) { - return Internal( - "Expecting only bitcasts between custom call (%s) and dynamic update " - "slice (%s)", - custom_call->name(), dynamic_update_slice->name()); + + if (op_being_annotated->opcode() == HloOpcode::kDynamicUpdateSlice) { + TF_RETURN_IF_ERROR(MemoryOnlyOffloadStartingWithDus(op_being_annotated)); + } else if (op_being_annotated->opcode() == HloOpcode::kCopy) { + TF_RETURN_IF_ERROR(MemoryOnlyOffloadStartingWithCopy(op_being_annotated)); + } else { + TF_RETURN_IF_ERROR(MemoryOnlyOffloadInsertCopies(custom_call)); } + return OkStatus(); +} +Status HostOffloader::MemoryOnlyOffloadStartingWithDus( + const HloInstruction* dynamic_update_slice) { + // The user wants to offload the data defined by this dynamic-update-slice. + VLOG(2) << "Host memory offload starts with a dynamic-update-slice: " + << dynamic_update_slice->name(); // Get the buffer for this DUS. const HloBuffer& unique_buffer = alias_analysis_->GetUniqueBufferAt(dynamic_update_slice); - // Look at the positions of this DUS: - // TODO(b/319167527): - // Add kCopy to the list after ensuring that it is always safe to - // do so. - constexpr std::array kAllowedPositionOpcodes = { - HloOpcode::kTuple, - HloOpcode::kGetTupleElement, - HloOpcode::kDynamicUpdateSlice, - HloOpcode::kBroadcast, - HloOpcode::kWhile, - HloOpcode::kParameter, - HloOpcode::kOptimizationBarrier}; + // We must find at least two HloValues: + // 1. Defined by a broadcast. + // a. For now, we only offload if the original destination of DUS is + // created by a broadcast. + // 2. Defined by a dynamic-update-slice. + const HloValue* dus_value = nullptr; + const HloValue* broadcast_value = nullptr; for (const HloValue* value : unique_buffer.values()) { - for (const HloPosition& position : value->positions()) { - // Check if this position is of an allowed type. - if (absl::c_find(kAllowedPositionOpcodes, - position.instruction->opcode()) == - kAllowedPositionOpcodes.end()) { + HloInstruction* defining_instruction = + value->defining_position().instruction; + if (defining_instruction->opcode() == HloOpcode::kBroadcast) { + if (broadcast_value != nullptr) { + LOG(WARNING) << "Already found one broadcast (" + << broadcast_value->defining_position().instruction->name() + << ") value for this buffer. This one is " + << defining_instruction->name(); + } + broadcast_value = value; + } else if (defining_instruction->opcode() == + HloOpcode::kDynamicUpdateSlice) { + if (dus_value != nullptr) { + LOG(WARNING) << "Already found one dynamic-update-slice (" + << dus_value->defining_position().instruction->name() + << ") value for this buffer. This one is " + << defining_instruction->name(); + } + dus_value = value; + } else { + // For all values other than the two we were looking for, ensure that the + // defining position is non-compute as well as all other positions. + if (!DefiningPositionIsAllowed(value->defining_position().instruction)) { return Internal( - "DynamicUpdateSlice %s's position %s is not supported. Not going " - "to offload this one", - dynamic_update_slice->name(), position.instruction->name()); + "HloValue is defined by an unsupported op: %s. HloValue: %s", + defining_instruction->name(), value->ToString()); + } + if (!AllPositionsAreAllowed(value)) { + return Internal( + "HloValue defined by %s has an invalid position. HloValue: %s", + defining_instruction->name(), value->ToString()); } } } - // Check if there is a broadcast which creates this buffer. - // For now, we only offload if the original destination of DUS is created by a - // broadcast. - HloInstruction* broadcast_instruction = nullptr; - for (const HloValue* val : unique_buffer.values()) { - HloInstruction* defining_instruction = val->defining_position().instruction; - if (defining_instruction->opcode() == HloOpcode::kBroadcast) { - VLOG(1) << "Found a broadcast instruction " - << defining_instruction->ToString(); - broadcast_instruction = defining_instruction; - break; - } + // For the two found HloValues, ensure that all other positions are + // non-compute. + if (dus_value == nullptr) { + return Internal( + "DynamicUpdateSlice's buffer does not have a value which is defined by " + "a dynamic update slice. HloBuffer: %s", + unique_buffer.ToString()); + } + if (!AllPositionsAreAllowed(dus_value)) { + return Internal( + "HloValue defined by %s has an invalid position. HloValue: %s", + dus_value->defining_position().instruction->name(), + dus_value->ToString()); + } + if (broadcast_value == nullptr) { + return Internal( + "DynamicUpdateSlice's buffer does not have a value which is defined by " + "a broadcast. HloBuffer: %s", + unique_buffer.ToString()); + } + if (!AllPositionsAreAllowed(broadcast_value)) { + return Internal( + "HloValue defined by %s has an invalid position. HloValue: %s", + broadcast_value->defining_position().instruction->name(), + broadcast_value->ToString()); } - if (broadcast_instruction == nullptr) { + // TODO(b/319681297): Further analyze the HloValue defined by the broadcast. + // Make sure that nothing is expecting the result of the broadcast, as we'll + // be replacing it. + + // Check that this buffer is finally an input to at least one slice or + // dynamic-slice. + TF_ASSIGN_OR_RETURN( + std::vector consuming_slices, + GetBufferUsersOfType( + unique_buffer, + match::AnyOf(match::Slice(), match::DynamicSlice()))); + VLOG(2) << dynamic_update_slice->name() + << " is consumed by [dynamic-]slices: [" + << absl::StrJoin(consuming_slices, ", ", + [](std::string* out, const HloInstruction* inst) { + out->append(inst->name()); + }) + << ']'; + if (consuming_slices.empty()) { return Internal( - "The destination buffer of %s was not created by a broadcast; cannot " - "offload. Has defining position(s) [%s]", - dynamic_update_slice->name(), - absl::StrJoin(unique_buffer.values(), ", ", - [](std::string* str, const HloValue* value) { - str->append( - value->defining_position().instruction->name()); - })); + "The dynamic-update-slice (%s) never feeds into a slice nor " + "dynamic-slice.", + dynamic_update_slice->name()); } - // TODO(b/319681297): Check that all uses of the broadcast are preceded by a - // host copy. + // Each dynamic_slice and slice should feed into another annotation. + for (HloInstruction* consuming_slice : consuming_slices) { + VLOG(1) << "Host data produced by " << dynamic_update_slice->name() + << " is consumed by " << consuming_slice->name(); + if (consuming_slice->user_count() != 1) { + return Internal( + "Slice/Dynamic-slice %s should only have one user. It should be an " + "annotation " + "to load the data back on the device. Instead, it has users [%s]", + consuming_slice->name(), + absl::StrJoin(consuming_slice->users(), ", ", + [](std::string* out, const HloInstruction* inst) { + out->append(inst->name()); + })); + } + HloInstruction* consuming_slice_user = + FindDSAnnotation(consuming_slice->users()[0]); + if (consuming_slice_user->opcode() != HloOpcode::kCustomCall) { + return Internal( + "Slice/Dynamic-slice %s does not have a matching annotation.", + consuming_slice->name()); + } + if (consuming_slice_user->custom_call_target() != + host_memory_offload_annotations::kMoveToDeviceCustomCallTarget) { + return Internal( + "Found custom-call (%s) is not the expected matching host offload " + "annotation", + consuming_slice_user->name()); + } + expected_host_to_device_annotations_.emplace(consuming_slice_user); + } // Save the broadcast to later be replaced with a // custom-call("AllocateBuffer") - broadcasts_to_replace_.emplace(broadcast_instruction); - buffers_to_move_to_host_memory_.emplace(&unique_buffer); + broadcasts_to_replace_.emplace( + broadcast_value->defining_position().instruction); + AddAllPositionsToBeMovedToHostMemory(unique_buffer); return OkStatus(); } -void HostOffloader::HandlePipelineBackwardCustomCall( +void HostOffloader::AddAllPositionsToBeMovedToHostMemory( + const HloBuffer& unique_buffer) { + for (const HloValue* value : unique_buffer.values()) { + for (const HloPosition& position : value->positions()) { + positions_to_move_to_host_memory_.emplace(position); + } + } +} + +Status HostOffloader::MemoryOnlyOffloadStartingWithCopy( + const HloInstruction* copy) { + // The user wants to offload the data defined by this copy. + VLOG(2) << "Host memory offload starts with a copy: " << copy->name(); + + // Get the buffer for this copy. + const HloBuffer& unique_buffer = alias_analysis_->GetUniqueBufferAt(copy); + + // Look for a value defined by a copy. + const HloValue* copy_value = nullptr; + for (const HloValue* value : unique_buffer.values()) { + HloInstruction* defining_instruction = + value->defining_position().instruction; + if (defining_instruction->opcode() == HloOpcode::kCopy) { + if (copy_value != nullptr) { + LOG(WARNING) + << "Already found one dynamic-update-slice value for this buffer"; + } + copy_value = value; + } else { + // For all other values (that aren't defined by a copy), ensure that the + // defining position is non-compute as well as all other positions. + if (!DefiningPositionIsAllowed(value->defining_position().instruction)) { + return Internal( + "HloValue is defined by an unsupported op: %s. HloValue: %s", + defining_instruction->name(), value->ToString()); + } + if (!AllPositionsAreAllowed(value)) { + return Internal( + "HloValue defined by %s has an invalid position. HloValue: %s", + defining_instruction->name(), value->ToString()); + } + } + } + + if (copy_value == nullptr) { + return Internal( + "Copy's buffer does not have a value which is defined by a copy. " + "HloBuffer: %s", + unique_buffer.ToString()); + } + // For the copy, ensure that all other positions are non-compute. + if (!AllPositionsAreAllowed(copy_value)) { + return Internal( + "HloValue defined by %s has an invalid position. HloValue: %s", + copy_value->defining_position().instruction->name(), + copy_value->ToString()); + } + + // Check that this buffer is finally an input to another copy. + TF_ASSIGN_OR_RETURN(HloInstruction * consuming_copy, + BufferHasPositionWithUser(unique_buffer, match::Copy())); + if (consuming_copy == nullptr) { + return Internal("The copy (%s) never feeds into another copy.", + copy->name()); + } + + // The copy should feed into another annotation. + if (consuming_copy->user_count() != 1) { + return Internal( + "Copy should only have one user. It should be an annotation to load " + "the data back on the device. Instead, it has users [%s]", + absl::StrJoin(consuming_copy->users(), ", ", + [](std::string* out, const HloInstruction* inst) { + out->append(inst->name()); + })); + } + HloInstruction* consuming_copy_user = consuming_copy->users()[0]; + if (consuming_copy_user->opcode() != HloOpcode::kCustomCall) { + return Internal("Copy does not have a matching annotation."); + } + if (consuming_copy_user->custom_call_target() != + host_memory_offload_annotations::kMoveToDeviceCustomCallTarget) { + return Internal( + "Found custom-call is not the expected matching host offload " + "annotation"); + } + expected_host_to_device_annotations_.emplace(consuming_copy_user); + + AddAllPositionsToBeMovedToHostMemory(unique_buffer); + return OkStatus(); +} + +Status HostOffloader::MemoryOnlyOffloadInsertCopies( HloInstruction* custom_call) { + VLOG(3) << "Found an offload annotation (" << custom_call->name() + << "). Expecting that we'll need to insert copies"; + const HloBuffer& unique_buffer = + alias_analysis_->GetUniqueBufferAt(custom_call); + for (const HloValue* value : unique_buffer.values()) { + HloInstruction* defining_instruction = + value->defining_position().instruction; + if (!AllPositionsAreAllowed(value)) { + return Internal( + "HloValue defined by %s has an invalid position. HloValue: %s", + defining_instruction->name(), value->ToString()); + } + } + + // Check that this buffer is finally an input to a load-from-host custom-call. + TF_ASSIGN_OR_RETURN( + HloInstruction * matching_annotation, + BufferHasPositionWithUser( + unique_buffer, + match::CustomCall({host_memory_offload_annotations:: + kMoveToDeviceCustomCallTarget}))); + if (matching_annotation == nullptr) { + return Internal( + "The offloaded data (from %s) never feeds into a matching \"load\" " + "annotation.", + custom_call->name()); + } + expected_host_to_device_annotations_.emplace(matching_annotation); + + // This fits the pattern that we're looking for. Now insert copies to do the + // offload and reload. + HloInstruction* thing_to_offload = custom_call->operands()[0]; + // Create a copy (to host) of the first and only operand to the given custom + // call. + HloInstruction* copy_to_host = + custom_call->parent()->AddInstruction(HloInstruction::CreateUnary( + thing_to_offload->shape(), HloOpcode::kCopy, thing_to_offload)); + // Replace all uses of the offloading custom call with the first copy. + TF_RETURN_IF_ERROR(custom_call->ReplaceAllUsesWith(copy_to_host)); + + HloInstruction* operand_of_load_annotation = + matching_annotation->mutable_operand(0); + // Create another copy (back to device) of that copy. + HloInstruction* copy_to_device = + custom_call->parent()->AddInstruction(HloInstruction::CreateUnary( + copy_to_host->shape(), HloOpcode::kCopy, operand_of_load_annotation)); + // Replace all uses of the operand of the matching annotation with the second + // copy. + TF_RETURN_IF_ERROR( + operand_of_load_annotation->ReplaceAllUsesWith(copy_to_device)); + + AddAllPositionsToBeMovedToHostMemory(unique_buffer); + // Also save the position of the newly created copy-to-host to later have its + // memory space updated. + positions_to_move_to_host_memory_.emplace(HloPosition{copy_to_host}); + return OkStatus(); +} + +Status HostOffloader::HandlePipelineBackwardCustomCall( + HloInstruction* custom_call) { + VLOG(2) << "Found a custom call annotating end-of-host-offload: " + << custom_call->ToString(); // Save a pointer to this custom call for later removal. - custom_calls_to_remove_.emplace(custom_call); + found_host_to_device_annotations_.emplace(custom_call); + return OkStatus(); } Status HostOffloader::DynamifySlice(HloInstruction* slice) { @@ -222,54 +516,64 @@ StatusOr HostOffloader::Run( const absl::flat_hash_set& execution_threads) { bool changed = false; - // Split broadcasts so that each HloUse of a broadcast instruction will get - // its own copy. - // TODO(b/319293925): Do not blindly duplicate all broadcasts, instead do it - // only when necessary. - TF_ASSIGN_OR_RETURN(bool duplicated_at_least_one_broadcast, - DuplicateBroadcastForEachUse(module)); - if (duplicated_at_least_one_broadcast) { - changed = true; - } - // Run HloAliasAnalysis on module. TF_ASSIGN_OR_RETURN(alias_analysis_, HloAliasAnalysis::Run(module)); // Iterate over all instructions and look for XLA host offload annoations. for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { - for (HloInstruction* instruction : computation->instructions()) { + for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { if (instruction->opcode() != HloOpcode::kCustomCall) { continue; } - if (instruction->custom_call_target() == kPipelineForwardTarget) { + if (instruction->custom_call_target() == + host_memory_offload_annotations::kMoveToHostCustomCallTarget) { TF_RETURN_IF_ERROR(HandlePipelineForwardCustomCall(instruction)); - } else if (instruction->custom_call_target() == kPipelineBackwardTarget) { - HandlePipelineBackwardCustomCall(instruction); + } else if (instruction->custom_call_target() == + host_memory_offload_annotations:: + kMoveToDeviceCustomCallTarget) { + TF_RETURN_IF_ERROR(HandlePipelineBackwardCustomCall(instruction)); } } } + // Check that we found all the annotations that we expected. + if (found_host_to_device_annotations_ != + expected_host_to_device_annotations_) { + return Internal( + "There is a mismatch between the expected host-to-device annotations " + "(%s) and the found host-to-device annotations (%s)", + absl::StrJoin(expected_host_to_device_annotations_, ", ", + [](std::string* str, HloInstruction* instr) { + str->append(instr->name()); + }), + absl::StrJoin(found_host_to_device_annotations_, ", ", + [](std::string* str, HloInstruction* instr) { + str->append(instr->name()); + })); + } + + // Remove these host-to-device annotations. + for (HloInstruction* instr : found_host_to_device_annotations_) { + custom_calls_to_remove_.emplace(instr); + } + absl::flat_hash_set slices_to_dynamify; - // Change the memory space of these buffers to the host memory space. - for (const HloBuffer* buffer : buffers_to_move_to_host_memory_) { - for (const HloValue* value : buffer->values()) { - for (const HloPosition& position : value->positions()) { - for (HloInstruction* user : position.instruction->users()) { - // If a user of this position is a slice, change it to be a - // dynamic-slice. - if (user->opcode() == HloOpcode::kSlice) { - slices_to_dynamify.emplace(user); - } - } - Shape* shape_to_change = ShapeUtil::GetMutableSubshape( - position.instruction->mutable_shape(), position.index); - VLOG(2) << "Setting instruction to have host memory space: " - << position.instruction->name(); - SetMemorySpace(shape_to_change, kHostMemorySpaceColor); - changed = true; + // Change the memory space of these positions to the host memory space. + for (const HloPosition& position : positions_to_move_to_host_memory_) { + // If a user of this position is a slice, change it to be a dynamic-slice. + for (HloInstruction* user : position.instruction->users()) { + if (user->opcode() == HloOpcode::kSlice) { + slices_to_dynamify.emplace(user); } } + Shape* shape_to_change = ShapeUtil::GetMutableSubshape( + position.instruction->mutable_shape(), position.index); + VLOG(2) << "Setting instruction to have host memory space: " + << position.instruction->name(); + SetMemorySpace(shape_to_change, kHostMemorySpaceColor); + changed = true; } for (HloInstruction* user : slices_to_dynamify) { @@ -289,23 +593,42 @@ StatusOr HostOffloader::Run( changed = true; } + // Recompute alias analysis after changes. + TF_ASSIGN_OR_RETURN(alias_analysis_, HloAliasAnalysis::Run(module)); + auto uses_parameter_buffer = [this](HloInstruction* hlo) { + for (const HloBuffer* buffer : alias_analysis_->ComputeBuffersAt(hlo)) { + for (const HloValue* value : buffer->values()) { + for (const HloPosition& pos : value->positions()) { + if (absl::c_linear_search(hlo->parent()->parameter_instructions(), + pos.instruction)) { + return true; + } + } + } + } + return false; + }; // Remove these custom-calls that were previously used for annotation. for (HloInstruction* custom_call : custom_calls_to_remove_) { CHECK_EQ(custom_call->operand_count(), 1); HloInstruction* operand = custom_call->operands()[0]; - - if (custom_call->shape().layout() != operand->shape().layout()) { - // LayoutAssignment might change the layout of the operand but leave the - // custom call layout unchanged. In that case, insert a copy. - // TODO(b/319686942): Once LayoutAssignment propagates the layout through - // this specific custom call, remove this insertion of a copy. - TF_RETURN_IF_ERROR(custom_call->ReplaceAllUsesWith( - custom_call->parent()->AddInstruction(HloInstruction::CreateUnary( - custom_call->shape(), HloOpcode::kCopy, operand)))); - } else { - CHECK_OK(custom_call->ReplaceAllUsesWith(operand)); + if (custom_call->parent() != + custom_call->GetModule()->entry_computation() && + custom_call->IsCustomCall( + host_memory_offload_annotations::kMoveToHostCustomCallTarget)) { + // Replace custom call with a copy for dynamic-update-slice in case it + // used parameter buffer directly because in case of aliasing with loop + // parameters control dependencies can mess with scheduling. + if (uses_parameter_buffer(operand)) { + VLOG(10) << "Adding copy for custom call " << custom_call->name(); + operand = + custom_call->parent()->AddInstruction(HloInstruction::CreateUnary( + operand->shape(), HloOpcode::kCopy, operand)); + } else { + VLOG(10) << "NOT Adding copy for custom call " << custom_call->name(); + } } - + CHECK_OK(custom_call->ReplaceAllUsesWith(operand)); TF_RETURN_IF_ERROR(custom_call->parent()->RemoveInstruction(custom_call)); changed = true; } diff --git a/third_party/xla/xla/service/host_offloader.h b/third_party/xla/xla/service/host_offloader.h index ed2bc6ddf48c01..6d09717279df1f 100644 --- a/third_party/xla/xla/service/host_offloader.h +++ b/third_party/xla/xla/service/host_offloader.h @@ -29,18 +29,12 @@ class HloCostAnalysis; // This pass does "host memory offloading". If a tensor is annotated to be moved // to or from the host, this pass will remove the annotations and update each -// tensor's layout with host memory spaces and insert copies* if necessary. This +// tensor's layout with host memory spaces and insert copies if necessary. This // pass checks to make sure that no compute is done on the tensors annotated for // host memory offload; if there is compute, it is considered a user error and // an error will be returned. -// -// * TODO(b/319293918): Inserting of copies is not yet implemented. class HostOffloader : public HloModulePass { public: - static constexpr absl::string_view kPipelineForwardTarget = "PipelineForward"; - static constexpr absl::string_view kPipelineBackwardTarget = - "PipelineBackward"; - explicit HostOffloader(int64_t host_memory_space_color) : kHostMemorySpaceColor(host_memory_space_color) {} ~HostOffloader() override = default; @@ -51,17 +45,45 @@ class HostOffloader : public HloModulePass { StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; + static absl::Span GetAllowedPositionOpcodes() { + return kAllowedPositionOpcodes; + } private: const int64_t kHostMemorySpaceColor; std::unique_ptr alias_analysis_; + absl::flat_hash_set found_host_to_device_annotations_; + absl::flat_hash_set expected_host_to_device_annotations_; absl::flat_hash_set custom_calls_to_remove_; absl::flat_hash_set broadcasts_to_replace_; - absl::flat_hash_set buffers_to_move_to_host_memory_; + absl::flat_hash_set positions_to_move_to_host_memory_; + + // Positions of all HloValues of the given HloBuffer will be added to + // positions_to_move_to_host_memory_. + void AddAllPositionsToBeMovedToHostMemory(const HloBuffer& unique_buffer); Status HandlePipelineForwardCustomCall(HloInstruction* custom_call); - void HandlePipelineBackwardCustomCall(HloInstruction* custom_call); + Status HandlePipelineBackwardCustomCall(HloInstruction* custom_call); + + // Handle memory-only offloading where the data is written to the host via a + // dynamic-update-slice and is read back via a dynamic-slice. + Status MemoryOnlyOffloadStartingWithDus( + const HloInstruction* dynamic_update_slice); + + // Handle memory-only offloading where the data is written to the host via a + // copy and is read back via a copy. + Status MemoryOnlyOffloadStartingWithCopy(const HloInstruction* copy); + + // Handle memory-only offloading where there are no ops yet for data movement. + // We will insert copies at the points where the annotations are. + Status MemoryOnlyOffloadInsertCopies(HloInstruction* custom_call); + Status DynamifySlice(HloInstruction* slice); + + static constexpr std::array kAllowedPositionOpcodes = { + HloOpcode::kTuple, HloOpcode::kGetTupleElement, + HloOpcode::kOptimizationBarrier, HloOpcode::kParameter, + HloOpcode::kWhile}; }; } // namespace xla diff --git a/third_party/xla/xla/service/host_offloader_test.cc b/third_party/xla/xla/service/host_offloader_test.cc index 4eda253b8c6802..faebee5ee7647c 100644 --- a/third_party/xla/xla/service/host_offloader_test.cc +++ b/third_party/xla/xla/service/host_offloader_test.cc @@ -26,6 +26,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/host_memory_offload_annotations.h" +#include "xla/service/host_offload_legalize.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" @@ -50,9 +52,15 @@ class HostOffloaderTest : public HloTestBase { if (module->has_schedule()) { return absl::InternalError("Expected a non-scheduled module"); } - + bool changed = false; + HostOffloadLegalize host_offload_legalize(kHostMemorySpaceColor, + /*after_layout=*/false); + TF_ASSIGN_OR_RETURN(bool legal_changed, host_offload_legalize.Run(module)); + changed |= legal_changed; HostOffloader host_offloader(kHostMemorySpaceColor); - return host_offloader.Run(module); + TF_ASSIGN_OR_RETURN(bool offload_changed, host_offloader.Run(module)); + changed |= offload_changed; + return changed; } void TestShapeHasMemorySpace(const Shape& shape, int64_t memory_space) { @@ -64,8 +72,9 @@ class HostOffloaderTest : public HloTestBase { for (const HloComputation* computation : module->computations()) { for (const HloInstruction* instruction : computation->instructions()) { if (instruction->IsCustomCall( - {HostOffloader::kPipelineForwardTarget, - HostOffloader::kPipelineBackwardTarget})) { + {host_memory_offload_annotations::kMoveToHostCustomCallTarget, + host_memory_offload_annotations:: + kMoveToDeviceCustomCallTarget})) { return true; } } @@ -76,7 +85,7 @@ class HostOffloaderTest : public HloTestBase { TEST_F(HostOffloaderTest, BasicDusDs) { const std::string& hlo_string = R"( -HloModule llm_while +HloModule my_module ENTRY main { data_param = f32[1,2048,2048] parameter(0) index_param = s32[] parameter(1) @@ -123,9 +132,322 @@ ENTRY main { EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); } +TEST_F(HostOffloaderTest, BasicCopy) { + const std::string& hlo_string = R"( +HloModule my_module +ENTRY main { + data_param = f32[2048] parameter(0) + offload_custom_call = f32[2048] custom-call(data_param), custom_call_target="PipelineForward" + copy_0 = f32[2048] copy(offload_custom_call) + copy_1 = f32[2048] copy(copy_0) + ROOT load_custom_call = f32[2048] custom-call(copy_1), custom_call_target="PipelineBackward" +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + + EXPECT_TRUE(changed); + + // Look for the following pattern: + // param + // | + // copy (to host) + // | + // copy (to device) + + HloInstruction* param; + HloInstruction* copy_to_host; + HloInstruction* copy_to_device; + ASSERT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Copy(©_to_device, + m::Copy(©_to_host, m::Parameter(¶m, 0))))); + TestShapeHasMemorySpace(param->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(copy_to_host->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(copy_to_device->shape(), Layout::kDefaultMemorySpace); + + EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); +} + +TEST_F(HostOffloaderTest, BasicNoCopy) { + const std::string& hlo_string = R"( +HloModule my_module +ENTRY main { + data_param = f32[2048] parameter(0) + offload_custom_call = f32[2048] custom-call(data_param), custom_call_target="PipelineForward" + ROOT load_custom_call = f32[2048] custom-call(offload_custom_call), custom_call_target="PipelineBackward" +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + + EXPECT_TRUE(changed); + + // Look for the following pattern: + // param + // | + // copy (to host) + // | + // copy (to device) + + HloInstruction* param; + HloInstruction* copy_to_host; + HloInstruction* copy_to_device; + ASSERT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Copy(©_to_device, + m::Copy(©_to_host, m::Parameter(¶m, 0))))); + TestShapeHasMemorySpace(param->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(copy_to_host->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(copy_to_device->shape(), Layout::kDefaultMemorySpace); + + EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); +} + +TEST_F(HostOffloaderTest, NoCopyWithOptBarrier) { + const std::string& hlo_string = R"( +HloModule my_module +ENTRY main { + data_param = f32[2048] parameter(0) + offload_custom_call = f32[2048] custom-call(data_param), custom_call_target="PipelineForward" + tuple = (f32[2048]) tuple(offload_custom_call) + opt_barrier = (f32[2048]) opt-barrier(tuple) + get_tuple_element = f32[2048] get-tuple-element(opt_barrier), index=0 + ROOT load_custom_call = f32[2048] custom-call(get_tuple_element), custom_call_target="PipelineBackward" +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + + EXPECT_TRUE(changed); + + // Look for the following pattern: + // param + // | + // copy (to host) + // | + // tuple + // | + // opt-barrier + // | + // get-tuple-element + // | + // copy (to device) + + HloInstruction* param; + HloInstruction* copy_to_host; + HloInstruction* tuple; + HloInstruction* opt_barrier; + HloInstruction* gte; + HloInstruction* copy_to_device; + ASSERT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Copy( + ©_to_device, + m::GetTupleElement( + >e, m::OptimizationBarrier( + &opt_barrier, + m::Tuple(&tuple, m::Copy(©_to_host, + m::Parameter(¶m, 0)))))))); + TestShapeHasMemorySpace(param->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(copy_to_host->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {0}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(opt_barrier->shape(), {0}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(gte->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(copy_to_device->shape(), Layout::kDefaultMemorySpace); + + EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); +} + +TEST_F(HostOffloaderTest, NoCopyWithOptBarrierMoreElaborate) { + const std::string& hlo_string = R"( +HloModule jit_f, entry_computation_layout={(f32[16]{0})->f32[16]{0}} + +ENTRY main.24 { + Arg_0.1 = f32[16]{0} parameter(0), sharding={devices=[2]<=[2]} + cosine.4 = f32[16]{0} cosine(Arg_0.1) + custom-call.5 = f32[16]{0} custom-call(cosine.4), custom_call_target="PipelineForward" + sine.3 = f32[16]{0} sine(Arg_0.1) + cosine.7 = f32[16]{0} cosine(sine.3) + custom-call.8 = f32[16]{0} custom-call(cosine.7), custom_call_target="PipelineForward" + sine.6 = f32[16]{0} sine(sine.3) + cosine.9 = f32[16]{0} cosine(sine.6) + custom-call.10 = f32[16]{0} custom-call(cosine.9), custom_call_target="PipelineForward" + constant.2 = f32[] constant(1) + tuple.11 = (f32[16]{0}, f32[16]{0}, f32[16]{0}, f32[]) tuple(custom-call.5, custom-call.8, custom-call.10, constant.2) + opt-barrier.12 = (f32[16]{0}, f32[16]{0}, f32[16]{0}, f32[]) opt-barrier(tuple.11) + get-tuple-element.16 = f32[] get-tuple-element(opt-barrier.12), index=3 + broadcast.20 = f32[16]{0} broadcast(get-tuple-element.16), dimensions={} + get-tuple-element.15 = f32[16]{0} get-tuple-element(opt-barrier.12), index=2 + custom-call.19 = f32[16]{0} custom-call(get-tuple-element.15), custom_call_target="PipelineBackward" + multiply.21 = f32[16]{0} multiply(broadcast.20, custom-call.19) + get-tuple-element.14 = f32[16]{0} get-tuple-element(opt-barrier.12), index=1 + custom-call.18 = f32[16]{0} custom-call(get-tuple-element.14), custom_call_target="PipelineBackward" + multiply.22 = f32[16]{0} multiply(multiply.21, custom-call.18) + get-tuple-element.13 = f32[16]{0} get-tuple-element(opt-barrier.12), index=0 + custom-call.17 = f32[16]{0} custom-call(get-tuple-element.13), custom_call_target="PipelineBackward" + ROOT multiply.23 = f32[16]{0} multiply(multiply.22, custom-call.17) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + + EXPECT_TRUE(changed); + + // Look for the following pattern: + // param constant + // __________/ | | + // / | | + // cosine sine | + // | | \____________ | + // | | \ | + // | | sine | + // | | | | + // | cosine cosine | + // | | | | + // copy(to host) copy(to host) copy(to host) | + // \ \ / | + // \______________ | | _________________/ + // \ | | / + // tuple + // | + // opt-barrier + // _____________/ / \ \_____________ + // / / \ \ + // get-tuple-element get-tuple-element get-tuple-element get-tuple-element + // | | | | + // copy(to device) copy(to device) copy(to device) broadcast + // \ \ \ / + // \ \__________ multiply + // \ \ / + // \ multiply + // \_________________________ / + // \ / + // multiply + + HloInstruction* param; + HloInstruction* constant; + HloInstruction* sine_0; + HloInstruction* sine_1; + HloInstruction* cosine_0; + HloInstruction* cosine_1; + HloInstruction* cosine_2; + HloInstruction* copy_to_host_0; + HloInstruction* copy_to_host_1; + HloInstruction* copy_to_host_2; + HloInstruction* tuple; + HloInstruction* opt_barrier; + HloInstruction* gte_0; + HloInstruction* gte_1; + HloInstruction* gte_2; + HloInstruction* gte_3; + HloInstruction* broadcast; + HloInstruction* copy_to_device_0; + HloInstruction* copy_to_device_1; + HloInstruction* copy_to_device_2; + HloInstruction* multiply_0; + HloInstruction* multiply_1; + HloInstruction* multiply_2; + + auto parameter_matcher = m::Parameter(¶m, 0); + auto first_sine_matcher = m::Op(&sine_0) + .WithOpcode(xla::HloOpcode::kSin) + .WithOperand(0, parameter_matcher); + auto opt_barrier_matcher = m::OptimizationBarrier( + &opt_barrier, + m::Tuple( + &tuple, + m::Copy(©_to_host_0, m::Op(&cosine_0) + .WithOpcode(xla::HloOpcode::kCos) + .WithOperand(0, parameter_matcher)), + m::Copy(©_to_host_1, m::Op(&cosine_1) + .WithOpcode(xla::HloOpcode::kCos) + .WithOperand(0, first_sine_matcher)), + m::Copy(©_to_host_2, + m::Op(&cosine_2) + .WithOpcode(xla::HloOpcode::kCos) + .WithOperand(0, m::Op(&sine_1) + .WithOpcode(xla::HloOpcode::kSin) + .WithOperand(0, first_sine_matcher))), + m::Constant(&constant))); + ASSERT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Multiply( + &multiply_0, + m::Multiply( + &multiply_1, + m::Multiply( + &multiply_2, + m::Broadcast(&broadcast, m::GetTupleElement( + >e_3, opt_barrier_matcher, 3)), + m::Copy(©_to_device_2, + m::GetTupleElement(>e_2, opt_barrier_matcher, 2))), + m::Copy(©_to_device_1, + m::GetTupleElement(>e_1, opt_barrier_matcher, 1))), + m::Copy(©_to_device_0, + m::GetTupleElement(>e_0, opt_barrier_matcher, 0))))); + + TestShapeHasMemorySpace(param->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(constant->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(sine_0->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(sine_1->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(cosine_0->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(cosine_1->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(cosine_2->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(copy_to_host_0->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(copy_to_host_1->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(copy_to_host_2->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {0}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {2}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {3}), + Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(opt_barrier->shape(), {0}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(opt_barrier->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(opt_barrier->shape(), {2}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(opt_barrier->shape(), {3}), + Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(gte_0->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(gte_1->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(gte_2->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(gte_3->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(broadcast->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(copy_to_device_0->shape(), + Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(copy_to_device_1->shape(), + Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(copy_to_device_2->shape(), + Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(multiply_0->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(multiply_1->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(multiply_2->shape(), Layout::kDefaultMemorySpace); + + EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); +} + TEST_F(HostOffloaderTest, BasicDusDsWithMultipleBroadcastUsers) { const std::string& hlo_string = R"( -HloModule llm_while +HloModule my_module ENTRY main { data_param = f32[1,2048,2048] parameter(0) index_param = s32[] parameter(1) @@ -190,7 +512,7 @@ ENTRY main { TEST_F(HostOffloaderTest, BasicDusDsBitcastBeforeDus) { const std::string& hlo_string = R"( -HloModule llm_while +HloModule my_module ENTRY main { data_param = f32[2048,2048] parameter(0) index_param = s32[] parameter(1) @@ -247,7 +569,7 @@ ENTRY main { // before. TEST_F(HostOffloaderTest, BasicDusDsDusAnnotationOnWrongSide) { const std::string& hlo_string = R"( -HloModule llm_while +HloModule my_module ENTRY main { data_param = f32[1,2048,2048] parameter(0) index_param = s32[] parameter(1) @@ -270,10 +592,9 @@ ENTRY main { } // The annotation is mistakenly before the dynamic-slice; it should be after. -// TODO(b/319686133): Enable this test once it passes. -TEST_F(HostOffloaderTest, DISABLED_BasicDusDsDsAnnotationOnWrongSide) { +TEST_F(HostOffloaderTest, BasicDusDsDsAnnotationOnWrongSide) { const std::string& hlo_string = R"( -HloModule llm_while +HloModule my_module ENTRY main { data_param = f32[1,2048,2048] parameter(0) index_param = s32[] parameter(1) @@ -376,9 +697,9 @@ consuming_while_body { } ENTRY main { - moop = f32[] parameter(0) - broadcast_0 = f32[96,8,6,2048,2048] broadcast(moop), dimensions={} - broadcast_1 = f32[96,8,6,2048,1] broadcast(moop), dimensions={} + entry_param_0 = f32[] parameter(0) + broadcast_0 = f32[96,8,6,2048,2048] broadcast(entry_param_0), dimensions={} + broadcast_1 = f32[96,8,6,2048,1] broadcast(entry_param_0), dimensions={} constant_s32_0 = s32[] constant(0) tuple_for_producing_while = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) tuple(constant_s32_0, broadcast_0, broadcast_1) producing_while = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) while(tuple_for_producing_while), condition=producing_while_condition, body=producing_while_body @@ -616,6 +937,715 @@ ENTRY main { EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); } +TEST_F(HostOffloaderTest, LlmActivationDsWithReshape) { + const std::string& hlo_string = R"( +HloModule llm_while + +producing_while_condition { + producing_condition_param = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) parameter(0) + producing_condition_current_iteration_index = s32[] get-tuple-element(producing_condition_param), index=0 + producing_condition_iteration_count = s32[] constant(96) + ROOT producing_condition_result = pred[] compare(producing_condition_current_iteration_index, producing_condition_iteration_count), direction=LT +} + +consuming_while_condition { + consuming_condition_param = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) parameter(0) + consuming_condition_current_iteration_index = s32[] get-tuple-element(consuming_condition_param), index=0 + consuming_condition_iteration_count = s32[] constant(96) + ROOT consuming_condition_result = pred[] compare(consuming_condition_current_iteration_index, consuming_condition_iteration_count), direction=LT +} + +producing_while_body { + input_tuple.0 = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) parameter(0) + current_iteration_index.0 = s32[] get-tuple-element(input_tuple.0), index=0 + data_0.0 = f32[96,8,6,2048,2048] get-tuple-element(input_tuple.0), index=1 + data_1.0 = f32[96,8,6,2048,1] get-tuple-element(input_tuple.0), index=2 + constant_0.0 = s32[] constant(0) + constant_1.0 = s32[] constant(1) + constant_96 = s32[] constant(96) + + /* Create dummy data used in DUS */ + slice_data_0 = f32[1,8,6,2048,2048] constant({...}) + slice_data_1 = f32[1,8,6,2048,1] constant({...}) + + /* Build DUS index */ + compare_result.0 = pred[] compare(current_iteration_index.0, constant_0.0), direction=LT + add_result = s32[] add(current_iteration_index.0, constant_96) + select_result.0 = s32[] select(compare_result.0, add_result, current_iteration_index.0) + + /* Annotate DUS for offload */ + custom_call_0.0 = f32[1,8,6,2048,2048] custom-call(slice_data_0), custom_call_target="PipelineForward" + custom_call_1.0 = f32[1,8,6,2048,1] custom-call(slice_data_1), custom_call_target="PipelineForward" + + dynamic_update_slice_0 = f32[96,8,6,2048,2048] dynamic-update-slice(data_0.0, custom_call_0.0, select_result.0, constant_0.0, constant_0.0, constant_0.0, constant_0.0) + dynamic_update_slice_1 = f32[96,8,6,2048,1] dynamic-update-slice(data_1.0, custom_call_1.0, select_result.0, constant_0.0, constant_0.0, constant_0.0, constant_0.0) + + /* Increment iteration index */ + incremented_index.0 = s32[] add(current_iteration_index.0, constant_1.0) + ROOT tuple_result.0 = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) tuple(incremented_index.0, dynamic_update_slice_0, dynamic_update_slice_1) +} + +consuming_while_body { + input_tuple.1 = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) parameter(0) + current_iteration_index.1 = s32[] get-tuple-element(input_tuple.1), index=0 + data_0.1 = f32[96,8,6,2048,2048] get-tuple-element(input_tuple.1), index=1 + data_1.1 = f32[96,8,6,2048,1] get-tuple-element(input_tuple.1), index=2 + constant_0.1 = s32[] constant(0) + constant_1.1 = s32[] constant(1) + constant_95 = s32[] constant(95) + constant_191 = s32[] constant(191) + + /* Build DS index */ + subtract_0 = s32[] subtract(constant_95, current_iteration_index.1) + compare_result.1 = pred[] compare(subtract_0, constant_0.1), direction=LT + subtract_1 = s32[] subtract(constant_191, current_iteration_index.1) + select_result.1 = s32[] select(compare_result.1, subtract_1, subtract_0) + + dynamic_slice_0 = f32[1,8,6,2048,2048] dynamic-slice(data_0.1, select_result.1, constant_0.1, constant_0.1, constant_0.1, constant_0.1), dynamic_slice_sizes={1,8,6,2048,2048} + dynamic_slice_1 = f32[1,8,6,2048,1] dynamic-slice(data_1.1, select_result.1, constant_0.1, constant_0.1, constant_0.1, constant_0.1), dynamic_slice_sizes={1,8,6,2048,1} + rs = f32[1,8,6,2048,2048] reshape(dynamic_slice_0) + rs2 = f32[1,8,6,2048,1] reshape(dynamic_slice_1) + /* Annotate DS for offload */ + custom_call_0.1 = f32[1,8,6,2048,2048] custom-call(rs), custom_call_target="PipelineBackward" + custom_call_1.1 = f32[1,8,6,2048,1] custom-call(rs2), custom_call_target="PipelineBackward" + + /* Do some work with the dynamic slice outputs. */ + tanh_0 = f32[1,8,6,2048,2048] tanh(custom_call_0.1) + tanh_1 = f32[1,8,6,2048,1] tanh(custom_call_1.1) + + /* Increment iteration index */ + incremented_index.1 = s32[] add(current_iteration_index.1, constant_1.1) + ROOT tuple_result.1 = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) tuple(incremented_index.1, data_0.1, data_1.1) +} + +ENTRY main { + entry_param_0 = f32[] parameter(0) + broadcast_0 = f32[96,8,6,2048,2048] broadcast(entry_param_0), dimensions={} + broadcast_1 = f32[96,8,6,2048,1] broadcast(entry_param_0), dimensions={} + constant_s32_0 = s32[] constant(0) + tuple_for_producing_while = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) tuple(constant_s32_0, broadcast_0, broadcast_1) + producing_while = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) while(tuple_for_producing_while), condition=producing_while_condition, body=producing_while_body + while_output_1 = f32[96,8,6,2048,2048] get-tuple-element(producing_while), index=1 + while_output_2 = f32[96,8,6,2048,1] get-tuple-element(producing_while), index=2 + tuple_for_consuming_while = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) tuple(constant_s32_0, while_output_1, while_output_2) + ROOT consuming_while = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) while(tuple_for_consuming_while), condition=consuming_while_condition, body=consuming_while_body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + + EXPECT_TRUE(changed); + + // First, look for the pattern: + // producing_while + // / \ + // gte gte constant + // \ / / + // \/ / + // tuple + // | + // consuming_while + HloInstruction* consuming_while; + HloInstruction* producing_while_0; + HloInstruction* producing_while_1; + { + HloInstruction* tuple; + HloInstruction* gte_0; + HloInstruction* gte_1; + ASSERT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::While( + &consuming_while, + m::Tuple( + &tuple, m::Constant(), + m::GetTupleElement(>e_0, m::While(&producing_while_0)), + m::GetTupleElement(>e_1, m::While(&producing_while_1)))))); + ASSERT_EQ(producing_while_0, producing_while_1); + + // Check that the memory spaces were properly set. + TestShapeHasMemorySpace(gte_0->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(gte_1->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace( + ShapeUtil::GetSubshape(consuming_while->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace( + ShapeUtil::GetSubshape(consuming_while->shape(), {2}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace( + ShapeUtil::GetSubshape(producing_while_0->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace( + ShapeUtil::GetSubshape(producing_while_0->shape(), {2}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {2}), + kHostMemorySpaceColor); + } + + // Now, look for the AllocateBuffers leading into the producing while. + { + HloInstruction* allocate_buffer_0; + HloInstruction* allocate_buffer_1; + ASSERT_THAT(producing_while_0, + GmockMatch(m::While(m::Tuple( + m::Constant(), + m::CustomCall(&allocate_buffer_0, {"AllocateBuffer"}), + m::CustomCall(&allocate_buffer_1, {"AllocateBuffer"}))))); + // Check that the memory spaces were properly set. + ASSERT_TRUE(allocate_buffer_0->shape().has_layout()); + EXPECT_EQ(allocate_buffer_0->shape().layout().memory_space(), + kHostMemorySpaceColor); + ASSERT_TRUE(allocate_buffer_1->shape().has_layout()); + EXPECT_EQ(allocate_buffer_1->shape().layout().memory_space(), + kHostMemorySpaceColor); + } + + // There are 4 computations to look at: + // - Consuming while's body + // - Consuming while's condition + // - Producing while's body + // - Producing while's condition + + // For the condition computations, just check that the parameters have the + // right memory space. + TestShapeHasMemorySpace( + ShapeUtil::GetSubshape( + consuming_while->while_condition()->parameter_instruction(0)->shape(), + {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace( + ShapeUtil::GetSubshape( + consuming_while->while_condition()->parameter_instruction(0)->shape(), + {2}), + kHostMemorySpaceColor); + + // Now, check the producing while for the following pattern: + // param param + // | | + // gte _... gte _... + // | / | / + // | / | / + // | / | / + // dus dus + // | / + // | / + // _ | / + // \ | / + // \ | / + // \| / + // tuple + { + HloInstruction* tuple; + HloInstruction* dynamic_update_slice_0; + HloInstruction* dynamic_update_slice_1; + HloInstruction* dynamic_update_slice_second_param_0; + HloInstruction* dynamic_update_slice_second_param_1; + HloInstruction* gte_0; + HloInstruction* gte_1; + HloInstruction* param_0; + HloInstruction* param_1; + ASSERT_THAT(producing_while_0->while_body()->root_instruction(), + GmockMatch(m::Tuple( + &tuple, m::Op(), + m::DynamicUpdateSlice( + &dynamic_update_slice_0, + m::GetTupleElement(>e_0, m::Parameter(¶m_0)), + m::Op(&dynamic_update_slice_second_param_0), m::Op(), + m::Op(), m::Op(), m::Op(), m::Op()), + m::DynamicUpdateSlice( + &dynamic_update_slice_1, + m::GetTupleElement(>e_1, m::Parameter(¶m_1)), + m::Op(&dynamic_update_slice_second_param_1), m::Op(), + m::Op(), m::Op(), m::Op(), m::Op())))); + EXPECT_EQ(param_0, param_1); + + // Check that the memory spaces were properly set. + // HOST: + // tuple subshape 1 + // tuple subshape 2 + // dynamic_update_slice_0 shape + // dynamic_update_slice_1 shape + // gte_0 shape + // gte_1 shape + // param_0 subshape 1 + // param_0 subshape 2 + // DEVICE: + // dynamic_update_slice_second_param_0 + // dynamic_update_slice_second_param_1 + + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {2}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(dynamic_update_slice_0->shape(), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(dynamic_update_slice_1->shape(), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(gte_0->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(gte_1->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(param_0->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(param_0->shape(), {2}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(dynamic_update_slice_second_param_0->shape(), + Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(dynamic_update_slice_second_param_1->shape(), + Layout::kDefaultMemorySpace); + } + + // Now, check the consuming while for the following pattern: + // param + // | | + // gte gte + // | | + // ds ds + { + // Since we do not do anything meaningful with the result of the + // dynamic-slices, there is no easy way to access them from the root. + // Instead, search from the parameter and find all dynamic-slices. + EXPECT_EQ(consuming_while->while_body()->parameter_instructions().size(), + 1); + const HloInstruction* param = + consuming_while->while_body()->parameter_instruction(0); + absl::flat_hash_set dynamic_slices; + std::stack stack; + stack.emplace(param); + while (!stack.empty()) { + const HloInstruction* current = stack.top(); + stack.pop(); + if (current->opcode() == HloOpcode::kDynamicSlice) { + dynamic_slices.emplace(current); + continue; + } + // Add all users. + for (const HloInstruction* user : current->users()) { + stack.emplace(user); + } + } + // There should only be two dynamic-slices. + ASSERT_EQ(dynamic_slices.size(), 2); + for (const HloInstruction* dynamic_slice : dynamic_slices) { + const HloInstruction* get_tuple_element; + const HloInstruction* parameter; + ASSERT_THAT( + dynamic_slice, + GmockMatch(m::DynamicSlice( + m::GetTupleElement(&get_tuple_element, m::Parameter(¶meter)), + m::Op(), m::Op(), m::Op(), m::Op(), m::Op()))); + + // Check that the memory spaces were properly set. + // HOST: + // parameter subshape 1 + // parameter subshape 2 + // get_tuple_element + // DEVICE: + // dynamic_slice + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(parameter->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(parameter->shape(), {2}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(get_tuple_element->shape(), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(dynamic_slice->shape(), + Layout::kDefaultMemorySpace); + } + } + + // Finally, ensure that all annotations have been removed. + EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); +} + +TEST_F(HostOffloaderTest, LlmActivationHostMemoryMultipleConsumers) { + const std::string& hlo_string = R"( +HloModule llm_while + +producing_while_condition { + producing_condition_param = (s32[], f32[96,8,6,2048,2048]) parameter(0) + producing_condition_current_iteration_index = s32[] get-tuple-element(producing_condition_param), index=0 + producing_condition_iteration_count = s32[] constant(96) + ROOT producing_condition_result = pred[] compare(producing_condition_current_iteration_index, producing_condition_iteration_count), direction=LT +} + +consuming_while_condition { + consuming_condition_param = (s32[], f32[96,8,6,2048,2048]) parameter(0) + consuming_condition_current_iteration_index = s32[] get-tuple-element(consuming_condition_param), index=0 + consuming_condition_iteration_count = s32[] constant(96) + ROOT consuming_condition_result = pred[] compare(consuming_condition_current_iteration_index, consuming_condition_iteration_count), direction=LT +} + +producing_while_body { + input_tuple.0 = (s32[], f32[96,8,6,2048,2048]) parameter(0) + current_iteration_index.0 = s32[] get-tuple-element(input_tuple.0), index=0 + data_0.0 = f32[96,8,6,2048,2048] get-tuple-element(input_tuple.0), index=1 + constant_0.0 = s32[] constant(0) + constant_1.0 = s32[] constant(1) + constant_96 = s32[] constant(96) + + /* Create dummy data used in DUS */ + slice_data_0 = f32[1,8,6,2048,2048] constant({...}) + + /* Build DUS index */ + compare_result.0 = pred[] compare(current_iteration_index.0, constant_0.0), direction=LT + add_result = s32[] add(current_iteration_index.0, constant_96) + select_result.0 = s32[] select(compare_result.0, add_result, current_iteration_index.0) + + /* Annotate DUS for offload */ + custom_call_0.0 = f32[1,8,6,2048,2048] custom-call(slice_data_0), custom_call_target="PipelineForward" + + dynamic_update_slice_0 = f32[96,8,6,2048,2048] dynamic-update-slice(data_0.0, custom_call_0.0, select_result.0, constant_0.0, constant_0.0, constant_0.0, constant_0.0) + + /* Increment iteration index */ + incremented_index.0 = s32[] add(current_iteration_index.0, constant_1.0) + ROOT tuple_result.0 = (s32[], f32[96,8,6,2048,2048]) tuple(incremented_index.0, dynamic_update_slice_0) +} + +consuming_while_body { + input_tuple.1 = (s32[], f32[96,8,6,2048,2048]) parameter(0) + current_iteration_index.1 = s32[] get-tuple-element(input_tuple.1), index=0 + data_0.1 = f32[96,8,6,2048,2048] get-tuple-element(input_tuple.1), index=1 + constant_0.1 = s32[] constant(0) + constant_1.1 = s32[] constant(1) + constant_95 = s32[] constant(95) + constant_191 = s32[] constant(191) + + /* Build DS index */ + subtract_0 = s32[] subtract(constant_95, current_iteration_index.1) + compare_result.1 = pred[] compare(subtract_0, constant_0.1), direction=LT + subtract_1 = s32[] subtract(constant_191, current_iteration_index.1) + select_result.1 = s32[] select(compare_result.1, subtract_1, subtract_0) + + dynamic_slice_0 = f32[1,8,6,2048,2048] dynamic-slice(data_0.1, select_result.1, constant_0.1, constant_0.1, constant_0.1, constant_0.1), dynamic_slice_sizes={1,8,6,2048,2048} + + /* Annotate DS for offload */ + custom_call_0.1 = f32[1,8,6,2048,2048] custom-call(dynamic_slice_0), custom_call_target="PipelineBackward" + + /* Do some work with the dynamic slice outputs. */ + tanh_0 = f32[1,8,6,2048,2048] tanh(custom_call_0.1) + + /* Increment iteration index */ + incremented_index.1 = s32[] add(current_iteration_index.1, constant_1.1) + ROOT tuple_result.1 = (s32[], f32[96,8,6,2048,2048]) tuple(incremented_index.1, data_0.1) +} + +ENTRY main { + entry_param_0 = f32[] parameter(0) + entry_param_1 = s32[] parameter(1) + entry_param_2 = s32[] parameter(2) + broadcast_0 = f32[96,8,6,2048,2048] broadcast(entry_param_0), dimensions={} + constant_s32_0 = s32[] constant(0) + tuple_for_producing_while = (s32[], f32[96,8,6,2048,2048]) tuple(constant_s32_0, broadcast_0) + producing_while = (s32[], f32[96,8,6,2048,2048]) while(tuple_for_producing_while), condition=producing_while_condition, body=producing_while_body + while_output_1 = f32[96,8,6,2048,2048] get-tuple-element(producing_while), index=1 + tuple_for_consuming_while = (s32[], f32[96,8,6,2048,2048]) tuple(constant_s32_0, while_output_1) + consuming_while = (s32[], f32[96,8,6,2048,2048]) while(tuple_for_consuming_while), condition=consuming_while_condition, body=consuming_while_body + second_while_output = f32[96,8,6,2048,2048] get-tuple-element(consuming_while), index=1 + final_dynamic_slice_0 = f32[1,8,6,2048,2048] dynamic-slice(second_while_output, entry_param_1, constant_s32_0, constant_s32_0, constant_s32_0, constant_s32_0), dynamic_slice_sizes={1,8,6,2048,2048} + final_host_to_device_custom_call_0 = f32[1,8,6,2048,2048] custom-call(final_dynamic_slice_0), custom_call_target="PipelineBackward" + final_slice_0 = f32[1,8,6,2048,2048] slice(second_while_output), slice={[41:42], [0:8], [0:6], [0:2048], [0:2048]} + final_host_to_device_custom_call_1 = f32[1,8,6,2048,2048] custom-call(final_slice_0), custom_call_target="PipelineBackward" + ROOT add = f32[1,8,6,2048,2048] add(final_host_to_device_custom_call_0, final_host_to_device_custom_call_1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + + EXPECT_TRUE(changed); + + // First, look for the pattern: + // producing_while + // | + // constant gte + // \ | + // \ | + // tuple + // | + // consuming_while + // | + // gte + // / \ + // dynamic-slice dynamic-slice + // \ / + // add + // Note: The second dynamic-slice was originally a slice. + HloInstruction* consuming_while; + HloInstruction* producing_while; + { + HloInstruction* tuple; + HloInstruction* gte_between_whiles; + HloInstruction* final_gte; + HloInstruction* dynamic_slice_0; + HloInstruction* dynalic_slice_1; + HloInstruction* add; + auto pattern_ending_in_gte = m::GetTupleElement( + &final_gte, + m::While(&consuming_while, + m::Tuple(&tuple, m::Constant(), + m::GetTupleElement(>e_between_whiles, + m::While(&producing_while))))); + ASSERT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch( + m::Add(&add, + m::DynamicSlice(&dynamic_slice_0, pattern_ending_in_gte, + m::Op(), m::Op(), m::Op(), m::Op(), m::Op()), + m::DynamicSlice(&dynalic_slice_1, pattern_ending_in_gte, + m::ConstantScalar(41), m::Op(), m::Op(), + m::Op(), m::Op())))); + + // Check that the memory spaces were properly set. + TestShapeHasMemorySpace(gte_between_whiles->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace( + ShapeUtil::GetSubshape(consuming_while->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace( + ShapeUtil::GetSubshape(producing_while->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(final_gte->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(dynamic_slice_0->shape(), + Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(dynalic_slice_1->shape(), + Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(add->shape(), Layout::kDefaultMemorySpace); + } + + // Now, look for the AllocateBuffers leading into the producing while. + { + HloInstruction* allocate_buffer; + ASSERT_THAT(producing_while, + GmockMatch(m::While(m::Tuple( + m::Constant(), + m::CustomCall(&allocate_buffer, {"AllocateBuffer"}))))); + // Check that the memory spaces were properly set. + ASSERT_TRUE(allocate_buffer->shape().has_layout()); + EXPECT_EQ(allocate_buffer->shape().layout().memory_space(), + kHostMemorySpaceColor); + } + + // There are 4 computations to look at: + // - Consuming while's body + // - Consuming while's condition + // - Producing while's body + // - Producing while's condition + + // For the condition computations, just check that the parameters have the + // right memory space. + TestShapeHasMemorySpace( + ShapeUtil::GetSubshape( + consuming_while->while_condition()->parameter_instruction(0)->shape(), + {1}), + kHostMemorySpaceColor); + + // Now, check the producing while for the following pattern: + // param + // | + // gte _ + // | / + // | / + // _ dus + // \ | + // tuple + { + HloInstruction* tuple; + HloInstruction* dynamic_update_slice; + HloInstruction* dynamic_update_slice_second_param; + HloInstruction* gte; + HloInstruction* param; + ASSERT_THAT( + producing_while->while_body()->root_instruction(), + GmockMatch(m::Tuple(&tuple, m::Op(), + m::DynamicUpdateSlice( + &dynamic_update_slice, + m::GetTupleElement(>e, m::Parameter(¶m)), + m::Op(&dynamic_update_slice_second_param), + m::Op(), m::Op(), m::Op(), m::Op(), m::Op())))); + + // Check that the memory spaces were properly set. + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(dynamic_update_slice->shape(), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(gte->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(param->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(dynamic_update_slice_second_param->shape(), + Layout::kDefaultMemorySpace); + } + + // Now, check the consuming while for the following pattern: + // param + // | + // gte + // | + // ds + { + // Since we do not do anything meaningful with the result of the + // dynamic-slices, there is no easy way to access them from the root. + // Instead, search from the parameter and find all dynamic-slices. + EXPECT_EQ(consuming_while->while_body()->parameter_instructions().size(), + 1); + const HloInstruction* param = + consuming_while->while_body()->parameter_instruction(0); + absl::flat_hash_set dynamic_slices; + std::stack stack; + stack.emplace(param); + while (!stack.empty()) { + const HloInstruction* current = stack.top(); + stack.pop(); + if (current->opcode() == HloOpcode::kDynamicSlice) { + dynamic_slices.emplace(current); + continue; + } + // Add all users. + for (const HloInstruction* user : current->users()) { + stack.emplace(user); + } + } + // There should only be one dynamic-slice. + ASSERT_EQ(dynamic_slices.size(), 1); + const HloInstruction* dynamic_slice = *dynamic_slices.begin(); + const HloInstruction* get_tuple_element; + const HloInstruction* parameter; + ASSERT_THAT( + dynamic_slice, + GmockMatch(m::DynamicSlice( + m::GetTupleElement(&get_tuple_element, m::Parameter(¶meter)), + m::Op(), m::Op(), m::Op(), m::Op(), m::Op()))); + + // Check that the memory spaces were properly set. + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(parameter->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(get_tuple_element->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(dynamic_slice->shape(), + Layout::kDefaultMemorySpace); + } + + // Finally, ensure that all annotations have been removed. + EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); +} + +TEST_F(HostOffloaderTest, InsertExtraCopyForScheduling) { + const std::string& hlo_string = R"( +HloModule llm_while + +producing_while_condition { + producing_condition_param = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1], f32[1,8,6,2048,1]) parameter(0) + producing_condition_current_iteration_index = s32[] get-tuple-element(producing_condition_param), index=0 + producing_condition_iteration_count = s32[] constant(96) + ROOT producing_condition_result = pred[] compare(producing_condition_current_iteration_index, producing_condition_iteration_count), direction=LT +} + +consuming_while_condition { + consuming_condition_param = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) parameter(0) + consuming_condition_current_iteration_index = s32[] get-tuple-element(consuming_condition_param), index=0 + consuming_condition_iteration_count = s32[] constant(96) + ROOT consuming_condition_result = pred[] compare(consuming_condition_current_iteration_index, consuming_condition_iteration_count), direction=LT +} + +producing_while_body { + input_tuple.0 = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1], f32[1,8,6,2048,1]) parameter(0) + current_iteration_index.0 = s32[] get-tuple-element(input_tuple.0), index=0 + data_0.0 = f32[96,8,6,2048,2048] get-tuple-element(input_tuple.0), index=1 + data_1.0 = f32[96,8,6,2048,1] get-tuple-element(input_tuple.0), index=2 + data_2.1 = f32[1,8,6,2048,1] get-tuple-element(input_tuple.0), index=3 + constant_0.0 = s32[] constant(0) + constant_1.0 = s32[] constant(1) + constant_96 = s32[] constant(96) + + /* Create dummy data used in DUS */ + slice_data_0 = f32[1,8,6,2048,2048] constant({...}) + slice_data_1 = f32[1,8,6,2048,1] constant({...}) + + /* Build DUS index */ + compare_result.0 = pred[] compare(current_iteration_index.0, constant_0.0), direction=LT + add_result = s32[] add(current_iteration_index.0, constant_96) + select_result.0 = s32[] select(compare_result.0, add_result, current_iteration_index.0) + + /* Annotate DUS for offload */ + custom_call_0.0 = f32[1,8,6,2048,2048] custom-call(slice_data_0), custom_call_target="PipelineForward" + custom_call_1.0 = f32[1,8,6,2048,1] custom-call(data_2.1), custom_call_target="PipelineForward" + + dynamic_update_slice_0 = f32[96,8,6,2048,2048] dynamic-update-slice(data_0.0, custom_call_0.0, select_result.0, constant_0.0, constant_0.0, constant_0.0, constant_0.0) + dynamic_update_slice_1 = f32[96,8,6,2048,1] dynamic-update-slice(data_1.0, custom_call_1.0, select_result.0, constant_0.0, constant_0.0, constant_0.0, constant_0.0) + + /* Increment iteration index */ + incremented_index.0 = s32[] add(current_iteration_index.0, constant_1.0) + ROOT tuple_result.0 = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1], f32[1,8,6,2048,1]) tuple(incremented_index.0, dynamic_update_slice_0, dynamic_update_slice_1, data_2.1) +} + +consuming_while_body { + input_tuple.1 = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) parameter(0) + current_iteration_index.1 = s32[] get-tuple-element(input_tuple.1), index=0 + data_0.1 = f32[96,8,6,2048,2048] get-tuple-element(input_tuple.1), index=1 + data_1.1 = f32[96,8,6,2048,1] get-tuple-element(input_tuple.1), index=2 + constant_0.1 = s32[] constant(0) + constant_1.1 = s32[] constant(1) + constant_95 = s32[] constant(95) + constant_191 = s32[] constant(191) + + /* Build DS index */ + subtract_0 = s32[] subtract(constant_95, current_iteration_index.1) + compare_result.1 = pred[] compare(subtract_0, constant_0.1), direction=LT + subtract_1 = s32[] subtract(constant_191, current_iteration_index.1) + select_result.1 = s32[] select(compare_result.1, subtract_1, subtract_0) + + dynamic_slice_0 = f32[1,8,6,2048,2048] dynamic-slice(data_0.1, select_result.1, constant_0.1, constant_0.1, constant_0.1, constant_0.1), dynamic_slice_sizes={1,8,6,2048,2048} + dynamic_slice_1 = f32[1,8,6,2048,1] dynamic-slice(data_1.1, select_result.1, constant_0.1, constant_0.1, constant_0.1, constant_0.1), dynamic_slice_sizes={1,8,6,2048,1} + + /* Annotate DS for offload */ + custom_call_0.1 = f32[1,8,6,2048,2048] custom-call(dynamic_slice_0), custom_call_target="PipelineBackward" + custom_call_1.1 = f32[1,8,6,2048,1] custom-call(dynamic_slice_1), custom_call_target="PipelineBackward" + + /* Do some work with the dynamic slice outputs. */ + tanh_0 = f32[1,8,6,2048,2048] tanh(custom_call_0.1) + tanh_1 = f32[1,8,6,2048,1] tanh(custom_call_1.1) + + /* Increment iteration index */ + incremented_index.1 = s32[] add(current_iteration_index.1, constant_1.1) + ROOT tuple_result.1 = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) tuple(incremented_index.1, data_0.1, data_1.1) +} + +ENTRY main { + entry_param_0 = f32[] parameter(0) + broadcast_0 = f32[96,8,6,2048,2048] broadcast(entry_param_0), dimensions={} + broadcast_1 = f32[96,8,6,2048,1] broadcast(entry_param_0), dimensions={} + broadcast_2 = f32[1,8,6,2048,1] broadcast(entry_param_0), dimensions={} + constant_s32_0 = s32[] constant(0) + tuple_for_producing_while = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1], f32[1,8,6,2048,1]) tuple(constant_s32_0, broadcast_0, broadcast_1, broadcast_2) + producing_while = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1], f32[1,8,6,2048,1]) while(tuple_for_producing_while), condition=producing_while_condition, body=producing_while_body + while_output_1 = f32[96,8,6,2048,2048] get-tuple-element(producing_while), index=1 + while_output_2 = f32[96,8,6,2048,1] get-tuple-element(producing_while), index=2 + tuple_for_consuming_while = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) tuple(constant_s32_0, while_output_1, while_output_2) + ROOT consuming_while = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) while(tuple_for_consuming_while), condition=consuming_while_condition, body=consuming_while_body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + EXPECT_TRUE(changed); + // Finally, ensure that all annotations have been removed. + EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); + const HloInstruction* dus0 = + FindInstruction(module.get(), "dynamic_update_slice_0"); + const HloInstruction* dus1 = + FindInstruction(module.get(), "dynamic_update_slice_1"); + EXPECT_THAT(dus0, GmockMatch(m::DynamicUpdateSlice(m::Op(), m::Constant(), + m::Op(), m::Op(), m::Op(), + m::Op(), m::Op()))); + EXPECT_THAT(dus1, GmockMatch(m::DynamicUpdateSlice(m::Op(), m::Copy(), + m::Op(), m::Op(), m::Op(), + m::Op(), m::Op()))); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/instruction_fusion.cc b/third_party/xla/xla/service/instruction_fusion.cc index 59095d70086c2b..a36681f055bf8f 100644 --- a/third_party/xla/xla/service/instruction_fusion.cc +++ b/third_party/xla/xla/service/instruction_fusion.cc @@ -177,12 +177,14 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kAllReduceStart: case HloOpcode::kAllReduceDone: case HloOpcode::kAllToAll: + case HloOpcode::kCollectiveBroadcast: case HloOpcode::kCollectivePermute: case HloOpcode::kCollectivePermuteDone: case HloOpcode::kCollectivePermuteStart: case HloOpcode::kCustomCall: case HloOpcode::kDomain: case HloOpcode::kDot: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFft: diff --git a/third_party/xla/xla/service/latency_hiding_scheduler.cc b/third_party/xla/xla/service/latency_hiding_scheduler.cc index 6dd05f8775b977..8ba9d68143dfb8 100644 --- a/third_party/xla/xla/service/latency_hiding_scheduler.cc +++ b/third_party/xla/xla/service/latency_hiding_scheduler.cc @@ -75,6 +75,10 @@ CanonicalAsyncOp DefaultGetCanonicalAsyncOp(const HloInstruction& hlo) { return {HloOpcode::kAsyncStart, HloOpcode::kAllGather}; case HloOpcode::kCollectivePermuteStart: return {HloOpcode::kAsyncStart, HloOpcode::kCollectivePermute}; + case HloOpcode::kCopyStart: + return {HloOpcode::kAsyncStart, HloOpcode::kCopy}; + case HloOpcode::kCopyDone: + return {HloOpcode::kAsyncDone, HloOpcode::kCopy}; case HloOpcode::kAllReduceDone: return {HloOpcode::kAsyncDone, HloOpcode::kAllReduce}; case HloOpcode::kAllGatherDone: @@ -135,6 +139,7 @@ bool AsyncTracker::IsSupportedAsyncDone(const HloInstruction& hlo) const { case HloOpcode::kAllGather: case HloOpcode::kAllReduce: case HloOpcode::kCollectivePermute: + case HloOpcode::kCopy: case HloOpcode::kReduceScatter: return true; default: @@ -162,6 +167,7 @@ bool AsyncTracker::IsSupportedAsyncStart(const HloInstruction& hlo) const { case HloOpcode::kAllGather: case HloOpcode::kAllReduce: case HloOpcode::kCollectivePermute: + case HloOpcode::kCopy: case HloOpcode::kReduceScatter: return true; default: @@ -184,6 +190,8 @@ ResourcesVector AsyncTracker::GetResourcesFromInstructionImpl( return ResourceType::kAllToAll; case HloOpcode::kCollectivePermute: return ResourceType::kCollectivePermute; + case HloOpcode::kCopy: + return ResourceType::kCopy; case HloOpcode::kReduceScatter: return ResourceType::kReduceScatter; default: @@ -327,6 +335,8 @@ void AsyncTracker::SetConcurrentResourceLimits( max_concurrent_resource[ResourceTypeToIndex( ResourceType::kCollectivePermute)] = config_.collective_permute_overlap_limit; + max_concurrent_resource[ResourceTypeToIndex(ResourceType::kCopy)] = + config_.copy_overlap_limit; max_concurrent_resource[ResourceTypeToIndex(ResourceType::kAllToAll)] = config_.all_to_all_overlap_limit; max_concurrent_resource[ResourceTypeToIndex(ResourceType::kAllGather)] = @@ -362,12 +372,16 @@ absl::string_view AsyncTracker::GetResourceName(int64_t resource_type) const { return "kAllReduce"; case ResourceTypeToIndex(ResourceType::kCollectivePermute): return "kCollectivePermute"; + case ResourceTypeToIndex(ResourceType::kCopy): + return "kCopy"; case ResourceTypeToIndex(ResourceType::kSendRecv): return "kSendRecv"; case ResourceTypeToIndex(ResourceType::kSendHost): return "kSendHost"; case ResourceTypeToIndex(ResourceType::kRecvHost): return "kRecvHost"; + case ResourceTypeToIndex(ResourceType::kReduceScatter): + return "kReduceScatter"; default: return "Not a valid default resource"; } diff --git a/third_party/xla/xla/service/latency_hiding_scheduler.h b/third_party/xla/xla/service/latency_hiding_scheduler.h index 5f73cff7e1afbd..e491dd86ec47a0 100644 --- a/third_party/xla/xla/service/latency_hiding_scheduler.h +++ b/third_party/xla/xla/service/latency_hiding_scheduler.h @@ -56,11 +56,12 @@ enum class ResourceType { kAllGather = 2, kAllReduce = 3, kCollectivePermute = 4, - kReduceScatter = 5, - kSendRecv = 6, - kSendHost = 7, - kRecvHost = 8, - kNumResources = 9, + kCopy = 5, + kReduceScatter = 6, + kSendRecv = 7, + kSendHost = 8, + kRecvHost = 9, + kNumResources = 10, kTargetDefinedResourcesBound = 10000, }; @@ -99,6 +100,7 @@ struct SchedulerConfig { int64_t reduce_scatter_overlap_limit = 1; int64_t send_recv_overlap_limit = 1; int64_t send_recv_host_overlap_limit = 1; + int64_t copy_overlap_limit = 1; uint64_t memory_limit = UINT64_MAX; bool schedule_send_recvs = false; // Consider send recv as the same resource. Some platforms do not take well diff --git a/third_party/xla/xla/service/latency_hiding_scheduler_test.cc b/third_party/xla/xla/service/latency_hiding_scheduler_test.cc index dec1378f8aa3c8..03bcf499c7dc94 100644 --- a/third_party/xla/xla/service/latency_hiding_scheduler_test.cc +++ b/third_party/xla/xla/service/latency_hiding_scheduler_test.cc @@ -2974,4 +2974,79 @@ ENTRY main { // not create a failure of scheduling by the async done checks. EXPECT_TRUE(RunScheduler(hlo_module.get(), sched_config).ok()); } + +TEST_F(LatencyHidingSchedulerTest, CopyScheduling) { + absl::string_view hlo_string = R"( +HloModule EinsumTest, is_scheduled=true +ENTRY AddR2 { + y_host = bf16[12800,12800]{1,0:T(8,128)(2,1)} parameter(1) + z = bf16[12800,12800]{1,0:T(8,128)(2,1)} parameter(2) + x = bf16[12800,12800]{1,0:T(8,128)(2,1)} parameter(0) + convolution = bf16[12800,12800]{1,0:T(8,128)(2,1)} convolution(x, z), dim_labels=bf_io->bf + copy-start = (bf16[12800,12800]{1,0:T(8,128)(2,1)}, bf16[12800,12800]{1,0:T(8,128)(2,1)}, u32[]{:S(2)}) copy-start(y_host) + copy-done = bf16[12800,12800]{1,0:T(8,128)(2,1)} copy-done(copy-start) + ROOT convolution.1 = bf16[12800,12800]{1,0:T(8,128)(2,1)} convolution(convolution, copy-done), dim_labels=bf_io->bf +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloText(hlo_string)); + HloSchedule& module_schedule = hlo_module->schedule(); + EXPECT_TRUE(hlo_module->has_entry_computation()); + HloComputation* entry_computation = hlo_module->entry_computation(); + std::vector original_instruction_sequence = + module_schedule.sequence(entry_computation).instructions(); + auto sched_config = GetDefaultSchedConfig(); + EXPECT_TRUE(RunScheduler(hlo_module.get(), sched_config).ok()); + const HloInstruction* conv = FindInstruction(hlo_module.get(), "convolution"); + const HloInstruction* cps = FindInstruction(hlo_module.get(), "copy-start"); + const HloInstruction* cpd = FindInstruction(hlo_module.get(), "copy-done"); + std::vector new_instruction_sequence = + module_schedule.sequence(entry_computation).instructions(); + EXPECT_LT(PositionInVector(new_instruction_sequence, cps), + PositionInVector(new_instruction_sequence, conv)); + EXPECT_LT(PositionInVector(new_instruction_sequence, conv), + PositionInVector(new_instruction_sequence, cpd)); + XLA_VLOG_LINES(1, hlo_module->ToString()); +} + +TEST_F(LatencyHidingSchedulerTest, MaxCopyScheduling) { + absl::string_view hlo_string = R"( +HloModule EinsumTest, is_scheduled=true +ENTRY AddR2 { + y_host = bf16[12800,12800]{1,0:T(8,128)(2,1)} parameter(1) + q_host = bf16[12800,12800]{1,0:T(8,128)(2,1)} parameter(3) + z = bf16[12800,12800]{1,0:T(8,128)(2,1)} parameter(2) + x = bf16[12800,12800]{1,0:T(8,128)(2,1)} parameter(0) + convolution = bf16[12800,12800]{1,0:T(8,128)(2,1)} convolution(x, z), dim_labels=bf_io->bf + copy-start = (bf16[12800,12800]{1,0:T(8,128)(2,1)}, bf16[12800,12800]{1,0:T(8,128)(2,1)}, u32[]{:S(2)}) copy-start(y_host) + copy-done = bf16[12800,12800]{1,0:T(8,128)(2,1)} copy-done(copy-start) + copy-start2 = (bf16[12800,12800]{1,0:T(8,128)(2,1)}, bf16[12800,12800]{1,0:T(8,128)(2,1)}, u32[]{:S(2)}) copy-start(q_host) + copy-done2 = bf16[12800,12800]{1,0:T(8,128)(2,1)} copy-done(copy-start2) + ROOT t = (bf16[12800,12800]{1,0:T(8,128)(2,1)}, bf16[12800,12800]{1,0:T(8,128)(2,1)}) tuple(copy-done2, copy-done) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloText(hlo_string)); + HloSchedule& module_schedule = hlo_module->schedule(); + EXPECT_TRUE(hlo_module->has_entry_computation()); + HloComputation* entry_computation = hlo_module->entry_computation(); + std::vector original_instruction_sequence = + module_schedule.sequence(entry_computation).instructions(); + auto sched_config = GetDefaultSchedConfig(); + EXPECT_TRUE(RunScheduler(hlo_module.get(), sched_config).ok()); + const HloInstruction* conv = FindInstruction(hlo_module.get(), "convolution"); + const HloInstruction* cps = FindInstruction(hlo_module.get(), "copy-start"); + const HloInstruction* cps2 = FindInstruction(hlo_module.get(), "copy-start2"); + const HloInstruction* cpd2 = FindInstruction(hlo_module.get(), "copy-done2"); + std::vector new_instruction_sequence = + module_schedule.sequence(entry_computation).instructions(); + EXPECT_LT(PositionInVector(new_instruction_sequence, cps2), + PositionInVector(new_instruction_sequence, conv)); + EXPECT_LT(PositionInVector(new_instruction_sequence, conv), + PositionInVector(new_instruction_sequence, cpd2)); + EXPECT_LT(PositionInVector(new_instruction_sequence, cps), + PositionInVector(new_instruction_sequence, cpd2)); + XLA_VLOG_LINES(1, hlo_module->ToString()); +} + } // namespace xla diff --git a/third_party/xla/xla/service/layout_assignment.cc b/third_party/xla/xla/service/layout_assignment.cc index b55568ca3b9d58..b73f96444a659b 100644 --- a/third_party/xla/xla/service/layout_assignment.cc +++ b/third_party/xla/xla/service/layout_assignment.cc @@ -2726,10 +2726,12 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kAllGatherStart: case HloOpcode::kAllGatherDone: case HloOpcode::kAllToAll: + case HloOpcode::kCollectiveBroadcast: case HloOpcode::kCollectivePermute: case HloOpcode::kDivide: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFft: diff --git a/third_party/xla/xla/service/layout_normalization.cc b/third_party/xla/xla/service/layout_normalization.cc index 729ab2578b7443..052d6e458440fe 100644 --- a/third_party/xla/xla/service/layout_normalization.cc +++ b/third_party/xla/xla/service/layout_normalization.cc @@ -238,6 +238,24 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { return OkStatus(); } + Status HandleIota(HloInstruction* hlo) override { + VLOG(3) << "Input iota: " << hlo->ToString(); + auto s = hlo->shape(); + auto normalized_shape = Normalize(s); + std::vector orig_output_layout_as_permutation = + ToTransposeDimensions(s.layout()); + int64_t iota_dimension = hlo->dimensions()[0]; + int64_t new_iota_dimension = + FindIndex(orig_output_layout_as_permutation, iota_dimension); + auto normalized_iota = hlo->AddInstruction( + HloInstruction::CreateIota(normalized_shape, new_iota_dimension)); + SetVisited(*normalized_iota); + VLOG(3) << "Generated iota: " << normalized_iota->ToString(); + auto bc_to_orig = MakeBitcastHlo(normalized_iota, s); + TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); + return OkStatus(); + } + // BitcastConvert is only layout-preserving if it doesn't change the rank. Status HandleBitcastConvert(HloInstruction* hlo) override { // If the rank isn't changing this is just an unary op. diff --git a/third_party/xla/xla/service/layout_normalization_test.cc b/third_party/xla/xla/service/layout_normalization_test.cc index 5dd284edbf29c1..6cebe0f2a858cc 100644 --- a/third_party/xla/xla/service/layout_normalization_test.cc +++ b/third_party/xla/xla/service/layout_normalization_test.cc @@ -311,6 +311,23 @@ ENTRY main { )"); } +TEST_F(LayoutNormalizationTest, IotaCustomOutputLayout) { + const char* hlo = R"( +HloModule module + +ENTRY main { + a = f32[2,4,3]{1,2,0} iota(), iota_dimension=2 + ROOT out = abs(a) +} +)"; + + CheckLayoutNormalization(hlo, R"( +// CHECK: [[iota_2:%[^ ]+]] = f32[2,3,4]{2,1,0} iota(), iota_dimension=1 +// CHECK: [[abs_3:%[^ ]+]] = f32[2,3,4]{2,1,0} abs([[iota_2]]) +// CHECK: ROOT [[bitcast_3_4:%[^ ]+]] = f32[2,4,3]{1,2,0} bitcast([[abs_3]]) +)"); +} + TEST_F(LayoutNormalizationTest, Concatenate) { const char* hlo = R"( HloModule module diff --git a/third_party/xla/xla/service/llvm_ir/BUILD b/third_party/xla/xla/service/llvm_ir/BUILD index a932ff23a0102b..a706957cb6f960 100644 --- a/third_party/xla/xla/service/llvm_ir/BUILD +++ b/third_party/xla/xla/service/llvm_ir/BUILD @@ -2,11 +2,13 @@ # Libraries for helping construct LLVM IR for XLA backends. load("//xla:xla.bzl", "xla_cc_test") +load("@local_tsl//tsl:tsl.bzl", "internal_visibility") load("@local_tsl//tsl:tsl.default.bzl", "filegroup") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) @@ -24,14 +26,12 @@ filegroup( "**/*.cc", "**/*.h", ]), - visibility = ["//visibility:public"], ) cc_library( name = "alias_analysis", srcs = ["alias_analysis.cc"], hdrs = ["alias_analysis.h"], - visibility = ["//visibility:public"], deps = [ ":ir_array", ":llvm_type_conversion_util", @@ -63,7 +63,6 @@ cc_library( name = "llvm_util", srcs = ["llvm_util.cc"], hdrs = ["llvm_util.h"], - visibility = ["//visibility:public"], deps = [ ":llvm_type_conversion_util", "//xla:literal", @@ -81,6 +80,7 @@ cc_library( "@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", "@llvm-project//llvm:TransformUtils", "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:byte_order", @@ -93,7 +93,6 @@ cc_library( cc_library( name = "llvm_type_conversion_util", hdrs = ["llvm_type_conversion_util.h"], - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -104,7 +103,6 @@ cc_library( cc_library( name = "llvm_command_line_options", hdrs = ["llvm_command_line_options.h"], - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", @@ -116,7 +114,6 @@ cc_library( name = "ir_array", srcs = ["ir_array.cc"], hdrs = ["ir_array.h"], - visibility = ["//visibility:public"], deps = [ ":llvm_type_conversion_util", ":llvm_util", @@ -138,7 +135,6 @@ cc_library( name = "llvm_loop", srcs = ["llvm_loop.cc"], hdrs = ["llvm_loop.h"], - visibility = ["//visibility:public"], deps = [ ":ir_array", ":llvm_util", @@ -157,7 +153,6 @@ cc_library( name = "loop_emitter", srcs = ["loop_emitter.cc"], hdrs = ["loop_emitter.h"], - visibility = ["//visibility:public"], deps = [ ":ir_array", ":llvm_loop", @@ -177,7 +172,6 @@ cc_library( name = "fused_ir_emitter", srcs = ["fused_ir_emitter.cc"], hdrs = ["fused_ir_emitter.h"], - visibility = ["//visibility:public"], deps = [ ":ir_array", ":llvm_util", @@ -191,6 +185,7 @@ cc_library( "//xla/service:fusion_node_indexing_evaluation", "@com_google_absl//absl/container:flat_hash_map", "@llvm-project//llvm:Core", + "@llvm-project//llvm:TargetParser", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", @@ -201,7 +196,6 @@ cc_library( name = "dynamic_update_slice_util", srcs = ["dynamic_update_slice_util.cc"], hdrs = ["dynamic_update_slice_util.h"], - visibility = ["//visibility:public"], deps = [ ":fused_ir_emitter", ":ir_array", @@ -220,7 +214,6 @@ cc_library( name = "sort_util", srcs = ["sort_util.cc"], hdrs = ["sort_util.h"], - visibility = ["//visibility:public"], deps = [ ":ir_array", ":kernel_support_library", @@ -244,7 +237,6 @@ cc_library( name = "tuple_ops", srcs = ["tuple_ops.cc"], hdrs = ["tuple_ops.h"], - visibility = ["//visibility:public"], deps = [ ":ir_array", ":llvm_type_conversion_util", @@ -262,7 +254,6 @@ cc_library( name = "kernel_support_library", srcs = ["kernel_support_library.cc"], hdrs = ["kernel_support_library.h"], - visibility = ["//visibility:public"], deps = [ ":llvm_loop", ":llvm_type_conversion_util", @@ -276,7 +267,6 @@ cc_library( name = "buffer_assignment_util", srcs = ["buffer_assignment_util.cc"], hdrs = ["buffer_assignment_util.h"], - visibility = ["//visibility:public"], deps = [ "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", @@ -288,7 +278,6 @@ cc_library( name = "math_ops", srcs = ["math_ops.cc"], hdrs = ["math_ops.h"], - visibility = ["//visibility:public"], deps = [ ":llvm_util", "@llvm-project//llvm:Core", @@ -299,7 +288,6 @@ cc_library( name = "ir_builder_mixin", srcs = [], hdrs = ["ir_builder_mixin.h"], - visibility = ["//visibility:public"], deps = [ "@llvm-project//llvm:Core", ], diff --git a/third_party/xla/xla/service/llvm_ir/dynamic_update_slice_util.cc b/third_party/xla/xla/service/llvm_ir/dynamic_update_slice_util.cc index 90d8ba733b33c1..6b17b942364799 100644 --- a/third_party/xla/xla/service/llvm_ir/dynamic_update_slice_util.cc +++ b/third_party/xla/xla/service/llvm_ir/dynamic_update_slice_util.cc @@ -98,7 +98,7 @@ bool CanEmitFusedDynamicUpdateSliceInPlace(HloInstruction* fusion, // EmitFusedDynamicUpdateSliceInPlace. // // Emits a sequential loop if launch_dimensions is null. -using IndexGenerator = std::function(int64_t)>; +using IndexGenerator = std::function(int64_t)>; static Status EmitDynamicUpdateSliceInPlaceImpl( const Shape& update_shape, const IndexGenerator& start_indices_generator, @@ -235,7 +235,7 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl( fused_emitter->GetGenerator(*update)); IndexGenerator start_indices_generator = - [&](int64_t index) -> StatusOr { + [&](int64_t index) -> absl::StatusOr { TF_ASSIGN_OR_RETURN(ElementGenerator element_generator, fused_emitter->GetGenerator( *dynamic_update_slice->operand(2 + index))); diff --git a/third_party/xla/xla/service/llvm_ir/fused_ir_emitter.cc b/third_party/xla/xla/service/llvm_ir/fused_ir_emitter.cc index d15499c60aa8ec..f69cc88a14f711 100644 --- a/third_party/xla/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/third_party/xla/xla/service/llvm_ir/fused_ir_emitter.cc @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" +#include "llvm/TargetParser/Triple.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -44,14 +45,14 @@ namespace xla { using llvm_ir::IrArray; -StatusOr FusedIrEmitter::DefaultAction( +absl::StatusOr FusedIrEmitter::DefaultAction( const HloInstruction& instruction) { IndexedGenerator generator = elemental_emitter_.MakeElementGenerator( &instruction, indexed_generators_); - return StatusOr([&, generator = std::move(generator)]( - const IrArray::Index& index) - -> StatusOr { + return absl::StatusOr([&, generator = std::move(generator)]( + const IrArray::Index& index) + -> absl::StatusOr { ValueCacheKey key{&instruction, index.multidim()}; llvm::Value* value = value_cache_.insert({key, nullptr}).first->second; @@ -95,6 +96,8 @@ FusedIrEmitter::IndexedGenerator FusedIrEmitter::HandleConstant( llvm::Module* module = elemental_emitter_.module(); llvm::IRBuilder<>* b = elemental_emitter_.b(); + // Explicitly set global addrspace for SPIR backend. + int addrspace = llvm::Triple(module->getTargetTriple()).isSPIR() ? 1 : 0; llvm::Constant* initializer = llvm_ir::ConvertLiteralToIrConstant(constant.literal(), module); llvm::GlobalVariable* global = new llvm::GlobalVariable( @@ -104,7 +107,7 @@ FusedIrEmitter::IndexedGenerator FusedIrEmitter::HandleConstant( /*Initializer=*/initializer, /*Name=*/"", /*InsertBefore=*/nullptr, /*TLMode=*/llvm::GlobalValue::NotThreadLocal, - /*AddressSpace=*/0, + /*AddressSpace=*/addrspace, /*isExternallyInitialized=*/false); global->setUnnamedAddr(llvm::GlobalVariable::UnnamedAddr::Global); @@ -116,7 +119,7 @@ FusedIrEmitter::IndexedGenerator FusedIrEmitter::HandleConstant( }; } -StatusOr FusedIrEmitter::HandleTuple( +absl::StatusOr FusedIrEmitter::HandleTuple( const HloInstruction& tuple) { std::vector element_ir_types; element_ir_types.reserve(tuple.operand_count()); @@ -128,8 +131,9 @@ StatusOr FusedIrEmitter::HandleTuple( llvm::IRBuilder<>* b = elemental_emitter_.b(); llvm::Type* type = llvm::StructType::get(b->getContext(), element_ir_types); - return StatusOr([&, b, type](const IrArray::Index& index) - -> StatusOr { + return absl::StatusOr([&, b, + type](const IrArray::Index& index) + -> absl::StatusOr { llvm::Value* ret = llvm::UndefValue::get(type); for (size_t i = 0; i < tuple.operand_count(); ++i) { IrArray::Index used_index = index; @@ -146,8 +150,8 @@ StatusOr FusedIrEmitter::HandleTuple( }); } -StatusOr FusedIrEmitter::CreateGenerator( - const HloInstruction& instruction) { +absl::StatusOr +FusedIrEmitter::CreateGenerator(const HloInstruction& instruction) { switch (instruction.opcode()) { case HloOpcode::kConstant: return HandleConstant(instruction); @@ -162,7 +166,7 @@ StatusOr FusedIrEmitter::CreateGenerator( } } -StatusOr FusedIrEmitter::GetGenerator( +absl::StatusOr FusedIrEmitter::GetGenerator( const HloInstruction& instruction) { std::vector stack = {&instruction}; while (!stack.empty()) { diff --git a/third_party/xla/xla/service/llvm_ir/fused_ir_emitter.h b/third_party/xla/xla/service/llvm_ir/fused_ir_emitter.h index e8e38117a24cb0..098666d890eb96 100644 --- a/third_party/xla/xla/service/llvm_ir/fused_ir_emitter.h +++ b/third_party/xla/xla/service/llvm_ir/fused_ir_emitter.h @@ -56,13 +56,16 @@ class FusedIrEmitter { } // Returns the generator function for the given instruction. - StatusOr GetGenerator(const HloInstruction& instruction); + absl::StatusOr GetGenerator( + const HloInstruction& instruction); private: - StatusOr CreateGenerator(const HloInstruction& instruction); - StatusOr DefaultAction(const HloInstruction& instruction); + absl::StatusOr CreateGenerator( + const HloInstruction& instruction); + absl::StatusOr DefaultAction( + const HloInstruction& instruction); IndexedGenerator HandleConstant(const HloInstruction& constant); - StatusOr HandleTuple(const HloInstruction& tuple); + absl::StatusOr HandleTuple(const HloInstruction& tuple); ElementalIrEmitter& elemental_emitter_; diff --git a/third_party/xla/xla/service/llvm_ir/ir_array.cc b/third_party/xla/xla/service/llvm_ir/ir_array.cc index d29fabe8588e90..25785d175b108c 100644 --- a/third_party/xla/xla/service/llvm_ir/ir_array.cc +++ b/third_party/xla/xla/service/llvm_ir/ir_array.cc @@ -373,6 +373,13 @@ IrArray::Index IrArray::Index::SourceIndexOfBitcast( return index; } +IrArray::Index IrArray::Index::SourceIndexOfBitcast( + const Shape& operand_shape, llvm::IRBuilder<>* builder) const { + auto shape = ShapeUtil::MakeShape(F32, dims_); + *shape.mutable_layout() = layout_; + return SourceIndexOfBitcast(shape, operand_shape, builder); +} + IrArray::Index IrArray::Index::SourceIndexOfBroadcast( const Shape& shape, const Shape& operand_shape, absl::Span dimension_mapping, diff --git a/third_party/xla/xla/service/llvm_ir/ir_array.h b/third_party/xla/xla/service/llvm_ir/ir_array.h index 837b24512d4341..0212e90e8da0d0 100644 --- a/third_party/xla/xla/service/llvm_ir/ir_array.h +++ b/third_party/xla/xla/service/llvm_ir/ir_array.h @@ -165,6 +165,9 @@ class IrArray { // to `shape`, returns the source index. Index SourceIndexOfBitcast(const Shape& shape, const Shape& operand_shape, llvm::IRBuilder<>* builder) const; + // Same as above, but for bitcasts from `operand_shape` to `this->dims`. + Index SourceIndexOfBitcast(const Shape& operand_shape, + llvm::IRBuilder<>* builder) const; // Given that "this" is the target index of a broadcast from `operand_shape` // to `shape` with the given dimension mapping, returns the source index. diff --git a/third_party/xla/xla/service/llvm_ir/llvm_util.cc b/third_party/xla/xla/service/llvm_ir/llvm_util.cc index ece29e638306b4..3a9ab3bb6f9c3c 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_util.cc +++ b/third_party/xla/xla/service/llvm_ir/llvm_util.cc @@ -50,6 +50,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "llvm/Support/CodeGen.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/Utils/Cloning.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project @@ -321,9 +322,8 @@ llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) { return result_type; } -StatusOr EncodeSelfDescribingShapeConstant(const Shape& shape, - int32_t* shape_size, - llvm::IRBuilder<>* b) { +absl::StatusOr EncodeSelfDescribingShapeConstant( + const Shape& shape, int32_t* shape_size, llvm::IRBuilder<>* b) { std::string encoded_shape = shape.SerializeAsString(); if (encoded_shape.size() > std::numeric_limits::max()) { return Internal("Encoded shape size exceeded int32_t size limit."); @@ -420,8 +420,11 @@ llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type, llvm::Function* function = b->GetInsertBlock()->getParent(); b->SetInsertPoint(&function->getEntryBlock(), function->getEntryBlock().getFirstInsertionPt()); + llvm::Module* module = b->GetInsertBlock()->getModule(); + // Explicitly set local addrspace for SPIR backend. + int addrspace = llvm::Triple(module->getTargetTriple()).isSPIR() ? 5 : 0; llvm::AllocaInst* alloca = - b->CreateAlloca(type, element_count, AsStringRef(name)); + b->CreateAlloca(type, addrspace, element_count, AsStringRef(name)); if (alignment != 0) { alloca->setAlignment(llvm::Align(alignment)); } @@ -539,7 +542,11 @@ void SetDereferenceableMetadataForLoad(llvm::LoadInst* load, } llvm::Instruction* AddRangeMetadata(int32_t lower, int32_t upper, - llvm::Instruction* inst) { + llvm::Instruction* inst, + llvm::Module* module) { + if (llvm::Triple(module->getTargetTriple()).isSPIR()) { + return inst; + } llvm::LLVMContext& context = inst->getParent()->getContext(); llvm::IntegerType* i32 = llvm::Type::getInt32Ty(context); inst->setMetadata( diff --git a/third_party/xla/xla/service/llvm_ir/llvm_util.h b/third_party/xla/xla/service/llvm_ir/llvm_util.h index b5cabf61d2def8..41945f6cbbeb8b 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_util.h +++ b/third_party/xla/xla/service/llvm_ir/llvm_util.h @@ -143,9 +143,8 @@ llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module); // Returns a value that represents a pointer to a global string constant that // encodes the shape as a serialized protobuf. -StatusOr EncodeSelfDescribingShapeConstant(const Shape& shape, - int32_t* shape_size, - llvm::IRBuilder<>* b); +absl::StatusOr EncodeSelfDescribingShapeConstant( + const Shape& shape, int32_t* shape_size, llvm::IRBuilder<>* b); // Converts a given literal to an IR Constant. Literals have known constant // values at IR emission time. @@ -273,7 +272,8 @@ void SetDereferenceableMetadataForLoad(llvm::LoadInst* load, // Tells LLVM `inst >= lower && inst < upper`. Returns `inst` for convenience. llvm::Instruction* AddRangeMetadata(int32_t lower, int32_t upper, - llvm::Instruction* inst); + llvm::Instruction* inst, + llvm::Module* module); void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder); diff --git a/third_party/xla/xla/service/llvm_ir/loop_emitter.h b/third_party/xla/xla/service/llvm_ir/loop_emitter.h index 01d030e17b568f..da23eeb5ce37c2 100644 --- a/third_party/xla/xla/service/llvm_ir/loop_emitter.h +++ b/third_party/xla/xla/service/llvm_ir/loop_emitter.h @@ -34,7 +34,7 @@ namespace llvm_ir { // The function has to emit code to compute this value and return the resulting // llvm::Value*. using ElementGenerator = - std::function(const IrArray::Index& index)>; + std::function(const IrArray::Index& index)>; using BodyEmitter = std::function; // Creates the body emitter from target arrays. diff --git a/third_party/xla/xla/service/llvm_ir/math_ops.cc b/third_party/xla/xla/service/llvm_ir/math_ops.cc index 2210a64b77c454..f33e8ec40bb3b8 100644 --- a/third_party/xla/xla/service/llvm_ir/math_ops.cc +++ b/third_party/xla/xla/service/llvm_ir/math_ops.cc @@ -78,5 +78,76 @@ llvm::Value* EmitFastTanh(llvm::IRBuilder<>* b, llvm::Value* input, b->CreateFDiv(numerator, denominator)); } +llvm::Value* EmitErfF32(llvm::IRBuilder<>* b, llvm::Value* x) { + auto type = x->getType(); + constexpr float kErfInvOneMinusHalfULP = 3.832506856900711f; + auto call_fabs = [b](llvm::Value* operand_value) { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {operand_value}, + {operand_value->getType()}, b); + }; + auto fcmp_le = [b](llvm::Value* lhs_value, llvm::Value* rhs_value) { + return b->CreateFCmpOLE(lhs_value, rhs_value); + }; + llvm::Value* const clamp = fcmp_le( + llvm::ConstantFP::get(type, kErfInvOneMinusHalfULP), call_fabs(x)); + // The monomial coefficients of the numerator polynomial (odd). + llvm::Value* const alpha_1 = llvm::ConstantFP::get(type, 1.128379143519084f); + llvm::Value* const alpha_3 = + llvm::ConstantFP::get(type, 0.18520832239976145f); + llvm::Value* const alpha_5 = + llvm::ConstantFP::get(type, 0.050955695062380861f); + llvm::Value* const alpha_7 = + llvm::ConstantFP::get(type, 0.0034082910107109506f); + llvm::Value* const alpha_9 = + llvm::ConstantFP::get(type, 0.00022905065861350646f); + + // The monomial coefficients of the denominator polynomial (even). + llvm::Value* const beta_0 = llvm::ConstantFP::get(type, 1.0f); + llvm::Value* const beta_2 = llvm::ConstantFP::get(type, 0.49746925110067538f); + llvm::Value* const beta_4 = llvm::ConstantFP::get(type, 0.11098505178285362f); + llvm::Value* const beta_6 = + llvm::ConstantFP::get(type, 0.014070470171167667f); + llvm::Value* const beta_8 = + llvm::ConstantFP::get(type, 0.0010179625278914885f); + llvm::Value* const beta_10 = + llvm::ConstantFP::get(type, 0.000023547966471313185f); + llvm::Value* const beta_12 = + llvm::ConstantFP::get(type, -1.1791602954361697e-7f); + + // Since the polynomials are odd/even, we need x^2. + llvm::Value* const x2 = b->CreateFMul(x, x); + + // Evaluate the numerator polynomial p. + auto call_fma = [b](llvm::Value* multiplier, llvm::Value* multiplicand, + llvm::Value* addend) { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fma, + {multiplier, multiplicand, addend}, + {multiplier->getType()}, b); + }; + llvm::Value* p = call_fma(x2, alpha_9, alpha_7); + p = call_fma(x2, p, alpha_5); + p = call_fma(x2, p, alpha_3); + p = call_fma(x2, p, alpha_1); + p = b->CreateFMul(x, p); + + // Evaluate the denominator polynomial p. + llvm::Value* q = call_fma(x2, beta_12, beta_10); + q = call_fma(x2, q, beta_8); + q = call_fma(x2, q, beta_6); + q = call_fma(x2, q, beta_4); + q = call_fma(x2, q, beta_2); + q = call_fma(x2, q, beta_0); + + // Divide the numerator by the denominator. + auto call_copysign = [b](llvm::Value* mag, llvm::Value* sign) { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign, {mag, sign}, + {mag->getType()}, b); + }; + auto* result = + b->CreateSelect(clamp, call_copysign(llvm::ConstantFP::get(type, 1.0), x), + b->CreateFDiv(p, q)); + return result; +} + } // namespace llvm_ir } // namespace xla diff --git a/third_party/xla/xla/service/llvm_ir/math_ops.h b/third_party/xla/xla/service/llvm_ir/math_ops.h index 558e6a77d83148..7c5bf27c55de0d 100644 --- a/third_party/xla/xla/service/llvm_ir/math_ops.h +++ b/third_party/xla/xla/service/llvm_ir/math_ops.h @@ -28,6 +28,10 @@ namespace llvm_ir { llvm::Value* EmitFastTanh(llvm::IRBuilder<>* b, llvm::Value* input, bool with_fma = false); +// Emits an approximation of erf. The implementation uses the same rational +// interpolant as implemented in Eigen3. +llvm::Value* EmitErfF32(llvm::IRBuilder<>* b, llvm::Value* x); + } // namespace llvm_ir } // namespace xla diff --git a/third_party/xla/xla/service/llvm_ir/sort_util.cc b/third_party/xla/xla/service/llvm_ir/sort_util.cc index 7cbe55ad27f848..22ed179ae40958 100644 --- a/third_party/xla/xla/service/llvm_ir/sort_util.cc +++ b/third_party/xla/xla/service/llvm_ir/sort_util.cc @@ -165,7 +165,8 @@ Status EmitTiledCompareLoop( llvm::Value* thread_id = gpu::EmitCallToTargetIntrinsic( gpu::TargetIntrinsicID::kThreadIdx, {}, {}, b); llvm_ir::AddRangeMetadata(0, tile_size / 2, - llvm::cast(thread_id)); + llvm::cast(thread_id), + b->GetInsertBlock()->getModule()); thread_id = b->CreateIntCast(thread_id, tiled_keys_index.GetType(), /*isSigned=*/true, "thread.id.x"); diff --git a/third_party/xla/xla/service/lockable.h b/third_party/xla/xla/service/lockable.h index 5d9c72b1f1742c..3a71685c653432 100644 --- a/third_party/xla/xla/service/lockable.h +++ b/third_party/xla/xla/service/lockable.h @@ -73,10 +73,16 @@ class Lockable { }; Lockable() = default; + explicit Lockable(T value) : value_(std::move(value)) { VLOG(2) << "Constructed " << LockableName::ToString(value_); } + template + explicit Lockable(Args&&... args) : value_(std::forward(args)...) { + VLOG(2) << "Constructed " << LockableName::ToString(value_); + } + Lockable(const Lockable&) = delete; Lockable& operator=(const Lockable&) = delete; diff --git a/third_party/xla/xla/service/logistic_expander_test.cc b/third_party/xla/xla/service/logistic_expander_test.cc index af808f0653452b..dae9292b715e4a 100644 --- a/third_party/xla/xla/service/logistic_expander_test.cc +++ b/third_party/xla/xla/service/logistic_expander_test.cc @@ -16,30 +16,19 @@ limitations under the License. #include "xla/service/logistic_expander.h" #include +#include -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/layout_util.h" -#include "xla/literal.h" -#include "xla/service/hlo_creation_utils.h" +#include "xla/service/dynamic_padder.h" #include "xla/service/hlo_parser.h" -#include "xla/service/hlo_pass_fix.h" -#include "xla/service/hlo_pass_pipeline.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" -#include "xla/service/shape_inference.h" -#include "xla/shape_util.h" +#include "xla/statusor.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "xla/types.h" -#include "xla/window_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { @@ -73,5 +62,21 @@ TEST_F(LogisticExpanderTest, ExpandWith) { m::Exp(m::Negate(m::Parameter(0))))))); } +TEST_F(LogisticExpanderTest, DynamicDimensions) { + constexpr std::string_view hlo = R"( +HloModule DynamicDimensions + +ENTRY main { + p = f32[<=10] parameter(0) + ROOT root = f32[<=10] logistic(p) +} + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + + LogisticExpander logistic_expander; + ASSERT_TRUE(logistic_expander.Run(module.get()).value()); + DynamicPadder dynamic_padder; + EXPECT_TRUE(dynamic_padder.Run(module.get()).value()); +} } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/memory_space_assignment/BUILD b/third_party/xla/xla/service/memory_space_assignment/BUILD index 6b90dd29c50ac0..eaccc3d1d405e9 100644 --- a/third_party/xla/xla/service/memory_space_assignment/BUILD +++ b/third_party/xla/xla/service/memory_space_assignment/BUILD @@ -5,6 +5,7 @@ load( "//xla:xla.bzl", "xla_cc_test", ) +load("@local_tsl//tsl:tsl.bzl", "internal_visibility") load( "@local_tsl//tsl/platform:build_config.bzl", "tf_proto_library", @@ -12,7 +13,8 @@ load( load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) @@ -28,17 +30,19 @@ tf_proto_library( srcs = ["memory_space_assignment.proto"], cc_api_version = 2, make_default_target_header_only = True, - visibility = ["//visibility:public"], ) cc_library( name = "memory_space_assignment", srcs = ["memory_space_assignment.cc"], hdrs = ["memory_space_assignment.h"], - visibility = ["//visibility:public"], deps = [ + ":allocation", + ":cost_analysis", ":memory_space_assignment_proto_cc", + ":prefetch_interval_picker", ":repacking", + ":slice", ":tuning_utils", ":utils", "//xla:debug_options_flags", @@ -47,6 +51,7 @@ cc_library( "//xla:status_macros", "//xla:statusor", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_live_range", "//xla/service:buffer_value", @@ -58,7 +63,6 @@ cc_library( "//xla/service:hlo_proto_cc", "//xla/service:hlo_value", "//xla/service:time_utils", - "//xla/service:tuple_util", "//xla/service/heap_simulator", "//xla/service/heap_simulator:allocation_block", "@com_google_absl//absl/algorithm:container", @@ -84,9 +88,14 @@ xla_cc_test( name = "memory_space_assignment_test", srcs = ["memory_space_assignment_test.cc"], deps = [ + ":allocation", + ":cost_analysis", ":memory_space_assignment", ":memory_space_assignment_proto_cc", + ":prefetch_interval_picker", ":repacking", + ":slice", + ":testing_utils", "//xla:shape_util", "//xla:status", "//xla:statusor", @@ -126,7 +135,6 @@ xla_cc_test( cc_library( name = "repacking", hdrs = ["repacking.h"], - visibility = ["//visibility:public"], deps = [ "//xla:statusor", "//xla/service/heap_simulator:allocation_block", @@ -138,7 +146,6 @@ cc_library( name = "best_fit_repacker", srcs = ["best_fit_repacker.cc"], hdrs = ["best_fit_repacker.h"], - visibility = ["//visibility:public"], deps = [ ":repacking", "//xla:comparison_util", @@ -160,7 +167,6 @@ cc_library( name = "utils", srcs = ["utils.cc"], hdrs = ["utils.h"], - visibility = ["//visibility:public"], deps = [ "//xla/hlo/ir:hlo", "//xla/service:hlo_value", @@ -168,17 +174,166 @@ cc_library( ], ) +cc_library( + name = "slice", + srcs = ["slice.cc"], + hdrs = [ + "slice.h", + ], + deps = [ + ":memory_space_assignment_proto_cc", + "//xla:shape_util", + "//xla/service/heap_simulator", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "allocation", + srcs = ["allocation.cc"], + hdrs = [ + "allocation.h", + ], + deps = [ + ":cost_analysis", + ":memory_space_assignment_proto_cc", + ":slice", + "//xla:shape_util", + "//xla:status", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_value", + "//xla/service:time_utils", + "//xla/service:tuple_util", + "//xla/service/heap_simulator", + "//xla/service/heap_simulator:allocation_block", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + ], +) + cc_library( name = "tuning_utils", srcs = ["tuning_utils.cc"], hdrs = ["tuning_utils.h"], - visibility = ["//visibility:public"], deps = [ "//xla/hlo/ir:hlo", "//xla/service/heap_simulator", ], ) +cc_library( + name = "cost_analysis", + srcs = ["cost_analysis.cc"], + hdrs = ["cost_analysis.h"], + deps = [ + "//xla:shape_util", + "//xla:statusor", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_live_range", + "//xla/service:call_graph", + "//xla/service:hlo_alias_analysis", + "//xla/service:hlo_buffer", + "//xla/service:hlo_cost_analysis", + "//xla/service:hlo_value", + "//xla/service/heap_simulator", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "cost_analysis_test", + srcs = ["cost_analysis_test.cc"], + deps = [ + ":cost_analysis", + "//xla:shape_util", + "//xla:status", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_cost_analysis", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "prefetch_interval_picker", + srcs = ["prefetch_interval_picker.cc"], + hdrs = ["prefetch_interval_picker.h"], + deps = [ + ":cost_analysis", + ":memory_space_assignment_proto_cc", + "//xla:shape_util", + "//xla:status", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_live_range", + "//xla/service:hlo_proto_cc", + "//xla/service:hlo_value", + "//xla/service/heap_simulator", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:logging", + ], +) + +cc_library( + name = "testing_utils", + testonly = True, + hdrs = ["testing_utils.h"], + deps = [ + ":cost_analysis", + "//xla:shape_util", + "//xla:statusor", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_live_range", + "//xla/service:call_graph", + "//xla/service:hlo_alias_analysis", + "//xla/service:hlo_cost_analysis", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "prefetch_interval_picker_test", + srcs = ["prefetch_interval_picker_test.cc"], + deps = [ + ":cost_analysis", + ":prefetch_interval_picker", + ":testing_utils", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_cost_analysis", + "//xla/service:hlo_value", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + ], +) + xla_cc_test( name = "best_fit_repacker_test", srcs = ["best_fit_repacker_test.cc"], diff --git a/third_party/xla/xla/service/memory_space_assignment/allocation.cc b/third_party/xla/xla/service/memory_space_assignment/allocation.cc new file mode 100644 index 00000000000000..54cf00f9a31f39 --- /dev/null +++ b/third_party/xla/xla/service/memory_space_assignment/allocation.cc @@ -0,0 +1,854 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/memory_space_assignment/allocation.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/functional/function_ref.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/heap_simulator/allocation_block.h" +#include "xla/service/heap_simulator/heap_simulator.h" +#include "xla/service/hlo_value.h" +#include "xla/service/memory_space_assignment/cost_analysis.h" +#include "xla/service/memory_space_assignment/slice.h" +#include "xla/service/time_utils.h" +#include "xla/service/tuple_util.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" + +namespace xla::memory_space_assignment { +namespace { + +std::string UsesToString(const std::vector& uses) { + if (uses.empty()) { + return "none"; + } + std::vector uses_str; + uses_str.reserve(uses.size()); + for (const auto& use : uses) { + uses_str.push_back(use.ToString()); + } + return absl::StrJoin(uses_str, ","); +} + +// Helper function to compute the start time for a SlicedCopyAllocation. +int64_t GetSlicedCopyAllocationExclusiveStartTime( + const std::vector& + slice_decisions_sorted_by_exclusive_start_time) { + if (slice_decisions_sorted_by_exclusive_start_time.empty()) { + return -1; + } + + return slice_decisions_sorted_by_exclusive_start_time.front() + .exclusive_start_time; +} + +// Helper function to compute the underlying Allocation chunk for a +// SlicedCopyAllocation. +std::optional GetSlicedCopyAllocationChunk( + const std::vector& slice_decisions_sorted_by_start_time) { + if (slice_decisions_sorted_by_start_time.empty()) { + return std::nullopt; + } + auto offset_cmp = [](const SliceDecision& lhs, const SliceDecision& rhs) { + return lhs.chunk.offset < rhs.chunk.offset; + }; + auto end_cmp = [](const SliceDecision& lhs, const SliceDecision& rhs) { + return lhs.chunk.chunk_end() < rhs.chunk.chunk_end(); + }; + return HeapSimulator::Chunk::FromOffsetEnd( + std::min_element(slice_decisions_sorted_by_start_time.begin(), + slice_decisions_sorted_by_start_time.end(), offset_cmp) + ->chunk.offset, + std::max_element(slice_decisions_sorted_by_start_time.begin(), + slice_decisions_sorted_by_start_time.end(), end_cmp) + ->chunk.chunk_end()); +} + +} // namespace + +std::optional Allocation::cross_program_prefetch_index() const { + return cross_program_prefetch_index_; +} + +HeapSimulator::Chunk Allocation::chunk() const { + CHECK(chunk_.has_value()); + return *chunk_; +} + +void Allocation::set_offset(int64_t offset) { + CHECK(chunk_.has_value()); + *chunk_ = HeapSimulator::Chunk::FromOffsetSize(offset, chunk_->size); +} + +bool Allocation::is_in_alternate_mem() const { + return memory_space_ == MemorySpace::kAlternate; +} + +bool Allocation::is_in_default_mem() const { + return memory_space_ == MemorySpace::kDefault; +} + +void Allocation::AddUse(HloUse use) { + HloInstruction* operand = + use.instruction->mutable_operand(use.operand_number); + // If the use is a tuple, look inside the tuple to find the actual use. + for (int64_t index : use.operand_index) { + if (operand->opcode() != HloOpcode::kTuple) { + break; + } + operand = operand->mutable_operand(index); + } + + // Look beyond GetTupleElement(Tuple()) pattern for any bitcasts. + std::function get_simplified_operand; + get_simplified_operand = [&](HloInstruction* instruction) { + while (instruction->opcode() == HloOpcode::kGetTupleElement) { + HloInstruction* operand = + get_simplified_operand(instruction->mutable_operand(0)); + if (operand->opcode() == HloOpcode::kTuple) { + instruction = operand->mutable_operand(instruction->tuple_index()); + } else { + return instruction; + } + } + return instruction; + }; + operand = get_simplified_operand(operand); + + uses_.push_back(use); +} + +Status Allocation::UpdateUses(HloComputation* computation, + HloInstruction* producing_instruction) { + for (const HloUse& use : uses()) { + HloInstruction* replacement_instruction = producing_instruction; + Shape operand_shape = use.instruction->operand(use.operand_number)->shape(); + if (operand_shape.IsTuple()) { + TF_ASSIGN_OR_RETURN( + replacement_instruction, + TupleUtil::ReplaceTupleWith( + producing_instruction, + use.instruction->mutable_operand(use.operand_number), + use.operand_index)); + } else if (operand_shape != producing_instruction->shape()) { + // When processing allocations, we treat bitcasts as trivial positions and + // do not create allocations for them. We insert bitcasts after copies, to + // account for the fact that we don't have an allocation for the bitcast. + VLOG(4) << "Old shape = " << operand_shape.ToString() + << ", new shape = " << producing_instruction->shape().ToString() + << "; inserting a bitcast."; + replacement_instruction = computation->AddInstruction( + HloInstruction::CreateBitcast(operand_shape, producing_instruction)); + } + TF_RETURN_IF_ERROR(use.instruction->ReplaceOperandWith( + use.operand_number, replacement_instruction)); + } + return OkStatus(); +} + +bool Allocation::is_copy_like_allocation() const { + return is_copy_allocation() || is_sliced_copy_allocation(); +} + +HloInstruction* Allocation::AddGetTupleElements() const { + CHECK_NE(defining_position().instruction, nullptr); + + Shape shape = defining_position().shape(); + CHECK(shape.IsArray()) << "Allocation shape is not an array. Shape = " + << shape.ToString() + << " position = " << defining_position().shape(); + return TupleUtil::AddGetTupleElements(defining_position()); +} + +Allocation::Allocation(HloPosition defining_position, MemorySpace memory_space, + std::optional chunk, + int64_t start_time, int64_t end_time, + bool is_scoped_allocation, + std::optional cross_program_prefetch_index) + : original_defining_position_(std::move(defining_position)), + memory_space_(memory_space), + chunk_(std::move(chunk)), + start_time_(start_time), + end_time_(end_time), + is_scoped_allocation_(is_scoped_allocation), + cross_program_prefetch_index_(cross_program_prefetch_index) { + CHECK(!is_scoped_allocation || + original_defining_position_.index == ShapeIndex({})); +} + +HloPosition Allocation::original_defining_position() const { + return original_defining_position_; +} + +void Allocation::set_original_defining_position(HloPosition defining_position) { + original_defining_position_ = std::move(defining_position); +} + +bool Allocation::base_is_equal(const Allocation& other) const { + return defining_position() == other.defining_position() && + uses() == other.uses() && memory_space() == other.memory_space() && + chunk() == other.chunk() && start_time() == other.start_time() && + end_time() == other.end_time() && + earliest_available_time() == other.earliest_available_time() && + is_copy_allocation() == other.is_copy_allocation() && + is_scoped_allocation() == other.is_scoped_allocation(); +} + +PinnedAllocation::PinnedAllocation(HloPosition defining_position, + MemorySpace memory_space, + std::optional chunk, + int64_t start_time, int64_t end_time, + bool is_scoped_allocation) + : Allocation(std::move(defining_position), memory_space, chunk, start_time, + end_time, is_scoped_allocation, + /*cross_program_prefetch_index=*/std::nullopt) {} + +HloPosition PinnedAllocation::defining_position() const { + return original_defining_position(); +} + +bool PinnedAllocation::operator==(const PinnedAllocation& other) const { + return this->base_is_equal(static_cast(other)); +} + +bool MirroredAllocation::operator==(const MirroredAllocation& other) const { + return this->base_is_equal(static_cast(other)); +} + +bool ParentAllocation::operator==(const ParentAllocation& other) const { + return this->base_is_equal(static_cast(other)); +} + +bool PinnedAllocation::operator==(const Allocation& other) const { + const PinnedAllocation* casted_other = + dynamic_cast(&other); + return casted_other != nullptr && (*this) == (*casted_other); +} + +Status PinnedAllocation::Process() { + if (is_scoped_allocation()) { + // Nothing to do here for scoped allocations. + return OkStatus(); + } + HloInstruction* producing_instruction = AddGetTupleElements(); + HloComputation* computation = producing_instruction->parent(); + return UpdateUses(computation, producing_instruction); +} + +std::string PinnedAllocation::ToString() const { + std::string memory_space_str = + memory_space() == MemorySpace::kDefault ? "def" : "alt"; + std::optional chunk = maybe_chunk(); + if (chunk) { + absl::StrAppend(&memory_space_str, " (off: ", chunk->offset, ")"); + } + return absl::StrCat((is_scoped_allocation() ? "Scoped " : ""), + "PinnedAllocation in ", memory_space_str, " defined at ", + original_defining_position().ToString(), + ", start_time:", start_time(), ", end_time:", end_time(), + ", uses: ", UsesToString(uses())); +} + +void PinnedAllocation::MarkIfNeeded( + absl::flat_hash_set& needed_allocations) const { + MarkNeeded(needed_allocations); +} + +void PinnedAllocation::MarkNeeded( + absl::flat_hash_set& needed_allocations) const { + needed_allocations.insert(this); +} + +CopyAllocation::CopyAllocation( + Allocation& prev_allocation, MemorySpace memory_space, + std::optional chunk, + int64_t copy_start_schedule_after_time, + int64_t copy_done_schedule_before_time, int64_t end_time, + std::optional cross_program_prefetch_index) + : Allocation( + /*defining_position=*/{nullptr, {}}, memory_space, chunk, + // Allocation uses an inclusive start time + ExclusiveToInclusiveStartTime(copy_start_schedule_after_time), + end_time, + /*is_scoped_allocation=*/false, cross_program_prefetch_index), + prev_allocation_(prev_allocation), + copy_start_schedule_after_(copy_start_schedule_after_time), + copy_done_schedule_before_(copy_done_schedule_before_time) {} + +int64_t CopyAllocation::earliest_available_time() const { + return copy_done_schedule_before_; +} + +Status CopyAllocation::Process() { + // Copy allocations need to insert asynchronous copy nodes. + Shape shape = defining_position().shape(); + HloInstruction* producing_instruction = AddGetTupleElements(); + HloComputation* computation = producing_instruction->parent(); + copy_start_ = computation->AddInstruction(HloInstruction::CreateCopyStart( + ShapeUtil::MakeTupleShape({shape, shape, ShapeUtil::MakeShape(U32, {})}), + producing_instruction, cross_program_prefetch_index())); + copy_done_ = computation->AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start_)); + VLOG(4) << "Created " << copy_start_->name() + << " for copy allocation: " << ToString(); + + // Update the allocation position with the copy complete instruction, so that + // if there are further copies from it, they can find the correct position. + set_original_defining_position(HloPosition{copy_done_, {}}); + return UpdateUses(computation, copy_done_); +} + +void CopyAllocation::MarkIfNeeded( + absl::flat_hash_set& needed_allocations) const { + MarkNeeded(needed_allocations); +} + +void CopyAllocation::MarkNeeded( + absl::flat_hash_set& needed_allocations) const { + needed_allocations.insert(this); + prev_allocation_.MarkNeeded(needed_allocations); +} + +std::string CopyAllocation::ToString() const { + std::string memory_space_str = + memory_space() == MemorySpace::kDefault ? "def" : "alt"; + std::optional chunk = maybe_chunk(); + if (chunk) { + absl::StrAppend(&memory_space_str, " (off: ", chunk->offset, ")"); + } + return absl::StrCat("Copy Allocation in ", memory_space_str, + ", start_time:", start_time(), ", end_time:", end_time(), + ", copy_start_after_time: ", copy_start_schedule_after(), + ", copy_done_before_time: ", copy_done_schedule_before(), + ", uses: ", UsesToString(uses()), ", from ", + prev_allocation_.ToString()); +} + +HloPosition CopyAllocation::defining_position() const { + // Unless explicitly set, the defining position of a copy allocation is + // retrieved from the previous allocation. This is because we don't create + // new CopyStart/CopyDone instructions until later and the position should + // point to the previous (copy or otherwise) allocation's position for the + // original defining position. + HloPosition defining_position = original_defining_position(); + if (defining_position.instruction == nullptr) { + return prev_allocation_.defining_position(); + } + return defining_position; +} + +bool CopyAllocation::operator==(const CopyAllocation& other) const { + return this->base_is_equal(static_cast(other)) && + copy_done_schedule_before() == other.copy_done_schedule_before() && + copy_start_schedule_after() == other.copy_start_schedule_after() && + copy_start() == other.copy_start() && copy_done() == other.copy_done(); +} + +bool CopyAllocation::operator==(const Allocation& other) const { + const CopyAllocation* casted_other = + dynamic_cast(&other); + return casted_other != nullptr && (*this) == (*casted_other); +} + +void CopyAllocation::set_copy_start_schedule_after( + int64_t copy_start_schedule_after) { + copy_start_schedule_after_ = copy_start_schedule_after; +} + +void CopyAllocation::set_copy_done_schedule_before( + int64_t copy_done_schedule_before) { + copy_done_schedule_before_ = copy_done_schedule_before; +} + +int64_t CopyAllocation::copy_start_schedule_after() const { + return copy_start_schedule_after_; +} + +int64_t CopyAllocation::copy_done_schedule_before() const { + return copy_done_schedule_before_; +} + +SlicedCopyAllocation::SlicedCopyAllocation( + const Allocation& prev_allocation, MemorySpace memory_space, + std::vector slice_decisions_sorted_by_exclusive_start_time, + int64_t copy_done_schedule_before_time, int64_t end_time, + const SlicedPrefetchOptions& sliced_prefetch_options, + absl::FunctionRef get_equivalent_s8_shape_fn) + : Allocation( + /*defining_position=*/{nullptr, {}}, memory_space, + GetSlicedCopyAllocationChunk( + slice_decisions_sorted_by_exclusive_start_time), + // Allocation uses an inclusive start time + ExclusiveToInclusiveStartTime( + GetSlicedCopyAllocationExclusiveStartTime( + slice_decisions_sorted_by_exclusive_start_time)), + end_time, + /*is_scoped_allocation=*/false, + /*cross_program_prefetch_index=*/std::nullopt), + original_shape_to_slice_(prev_allocation.defining_position().shape()), + prev_allocation_(prev_allocation), + sliced_prefetch_options_(sliced_prefetch_options), + get_equivalent_s8_shape_fn_(get_equivalent_s8_shape_fn) { + CHECK_GE(slice_decisions_sorted_by_exclusive_start_time.size(), 2); + slice_details_sorted_by_exclusive_start_time_.reserve( + slice_decisions_sorted_by_exclusive_start_time.size()); + for (SliceDecision& decision : + slice_decisions_sorted_by_exclusive_start_time) { + int64_t copy_done_schedule_after_time = decision.exclusive_start_time; + slice_details_sorted_by_exclusive_start_time_.push_back(SliceDetail{ + std::move(decision), + copy_done_schedule_after_time, + copy_done_schedule_before_time, + /*copy_start=*/nullptr, + /*copy_done=*/nullptr, + }); + } +} + +Status SlicedCopyAllocation::Process() { + Shape shape = defining_position().shape(); + HloInstruction* producing_instruction = AddGetTupleElements(); + + // Calling Process() over the previous allocation might have modified the + // defining position, and hence the shape that was used when we computed + // the slices. In cases where the shape has changed, we insert a bitcast, so + // slice instructions operate on the originally sliced shape. + // + // Note, these bitcasts are being inserted in the same cases that + // UpdateUses() is inserting bitcasts, except we are + // inserting the bitcasts before the copy, instead of after the copy. + if (!Shape::Equal().IgnoreMemorySpaceInLayout()(shape, + original_shape_to_slice_)) { + int64_t new_memory_space = shape.layout().memory_space(); + shape = original_shape_to_slice_; + shape.mutable_layout()->set_memory_space(new_memory_space); + producing_instruction = producing_instruction->parent()->AddInstruction( + HloInstruction::CreateBitcast(shape, producing_instruction)); + } + + HloComputation* computation = producing_instruction->parent(); + std::vector slice_dones; + slice_dones.reserve(slice_details_sorted_by_exclusive_start_time_.size()); + + // If we are trying to make all slices a uniform size, we bitcast the + // producing instruction to an array of bytes, so it is easy to slice into any + // size. + Shape slice_shape = shape; + if (IsUniformSliceSizingEnabled(sliced_prefetch_options_)) { + slice_shape = get_equivalent_s8_shape_fn_(shape); + producing_instruction = producing_instruction->parent()->AddInstruction( + HloInstruction::CreateBitcast(slice_shape, producing_instruction)); + } + + // Sliced copy allocations need to insert asynchronous copy nodes. + for (SliceDetail& slice_detail : + slice_details_sorted_by_exclusive_start_time_) { + TF_RETURN_IF_ERROR(slice_detail.CreateAsyncSlice( + slice_shape, *producing_instruction, *computation)); + VLOG(4) << "Created " << slice_detail.copy_start->name() + << " for sliced copy allocation: " << ToString(); + slice_dones.push_back(slice_detail.copy_done); + } + + TF_RETURN_IF_ERROR(CreateBitcastConcat(shape, slice_dones)); + + // If we bitcast to an array of bytes above, the result of the concatenated + // slices will also be an array of bytes. Thus, we need to cast the + // concatentation back to the original shape. + if (IsUniformSliceSizingEnabled(sliced_prefetch_options_)) { + concat_ = concat_->parent()->AddInstruction( + HloInstruction::CreateBitcast(shape, concat_)); + } + + // Update the allocation position with the copy complete instruction, so that + // if there are further copies from it, they can find the correct position. + set_original_defining_position(HloPosition{concat_, {}}); + return UpdateUses(computation, concat_); +} + +void SlicedCopyAllocation::MarkIfNeeded( + absl::flat_hash_set& needed_allocations) const { + MarkNeeded(needed_allocations); +} + +void SlicedCopyAllocation::MarkNeeded( + absl::flat_hash_set& needed_allocations) const { + needed_allocations.insert(this); + prev_allocation_.MarkNeeded(needed_allocations); +} + +HloPosition SlicedCopyAllocation::defining_position() const { + // Unless explicitly set, the defining position of a sliced copy allocation is + // retrieved from the previous allocation. This is because we don't create + // new CopyStart/CopyDone instructions until later and the position should + // point to the previous (copy or otherwise) allocation's position for the + // original defining position. + HloPosition defining_position = original_defining_position(); + if (defining_position.instruction == nullptr) { + return prev_allocation_.defining_position(); + } + return defining_position; +} + +int64_t SlicedCopyAllocation::earliest_available_time() const { + return slice_details_sorted_by_start_time().back().copy_done_before_time; +} + +std::vector SlicedCopyAllocation::SliceOffsetsSortedByStartTime() + const { + std::vector offsets; + offsets.reserve(slice_details_sorted_by_exclusive_start_time_.size()); + + for (const SliceDetail& slice_detail : + slice_details_sorted_by_exclusive_start_time_) { + offsets.push_back(slice_detail.slice_decision.chunk.offset); + } + + return offsets; +} + +void SlicedCopyAllocation::AddDiffToAllSliceOffsets(int64_t diff) { + for (SliceDetail& slice_detail : + slice_details_sorted_by_exclusive_start_time_) { + HeapSimulator::Chunk& chunk = slice_detail.slice_decision.chunk; + chunk = + HeapSimulator::Chunk::FromOffsetSize(chunk.offset + diff, chunk.size); + } +} + +void SlicedCopyAllocation::ImportRepackedSliceData( + const SlicedAllocationData& data) { + int num_slices = slice_details_sorted_by_exclusive_start_time_.size(); + CHECK_EQ(data.slices_sorted_by_offset.size(), num_slices); + + std::vector slice_details_sorted_by_offset; + slice_details_sorted_by_offset.reserve(num_slices); + for (SliceDetail& slice_detail : + slice_details_sorted_by_exclusive_start_time_) { + slice_details_sorted_by_offset.push_back(&slice_detail); + } + absl::c_sort(slice_details_sorted_by_offset, [](const SliceDetail* lhs, + const SliceDetail* rhs) { + return lhs->slice_decision.chunk.offset < rhs->slice_decision.chunk.offset; + }); + + for (int i = 0; i < num_slices; ++i) { + SliceDetail* slice_detail = slice_details_sorted_by_offset[i]; + HeapSimulator::Chunk& chunk = slice_detail->slice_decision.chunk; + const AllocatedSlice& repacked_slice_data = data.slices_sorted_by_offset[i]; + chunk = HeapSimulator::Chunk::FromOffsetSize(repacked_slice_data.offset, + chunk.size); + slice_detail->copy_start_after_time = + repacked_slice_data.inclusive_start_time - 1; + slice_detail->slice_decision.exclusive_start_time = + InclusiveToExclusiveStartTime(repacked_slice_data.inclusive_start_time); + } + + absl::c_sort(slice_details_sorted_by_exclusive_start_time_, + [](const SliceDetail& lhs, const SliceDetail& rhs) { + return std::make_tuple(lhs.copy_start_after_time, + lhs.slice_decision.chunk.offset) < + std::make_tuple(rhs.copy_start_after_time, + rhs.slice_decision.chunk.offset); + }); +} + +const std::vector& +SlicedCopyAllocation::slice_details_sorted_by_start_time() const { + return slice_details_sorted_by_exclusive_start_time_; +} + +std::vector& +SlicedCopyAllocation::mutable_slice_details_sorted_by_start_time() { + return slice_details_sorted_by_exclusive_start_time_; +} + +bool SlicedCopyAllocation::operator==(const SlicedCopyAllocation& other) const { + return this->base_is_equal(static_cast(other)) && + slice_details_sorted_by_exclusive_start_time_ == + other.slice_details_sorted_by_exclusive_start_time_ && + concat_ == other.concat_; +} + +std::string SlicedCopyAllocation::ToString() const { + std::string memory_space_str = "def"; + if (memory_space() == MemorySpace::kAlternate) { + memory_space_str = absl::StrCat("alt (off: ", maybe_chunk()->offset, ")"); + } + return absl::StrCat( + "Sliced Copy Allocation in ", memory_space_str, + ", start_time:", start_time(), ", end_time:", end_time(), + ", first_slice_copy_start_after_time: ", + slice_details_sorted_by_start_time().front().copy_start_after_time, + ", last_slice_copy_done_before_time: ", + slice_details_sorted_by_start_time().back().copy_done_before_time, + ", uses: ", UsesToString(uses()), ", from ", prev_allocation_.ToString()); +} + +Status SlicedCopyAllocation::CreateBitcastConcat( + const Shape& shape, absl::Span slices) { + CHECK(!slices.empty()); + concat_ = + slices.front()->parent()->AddInstruction(HloInstruction::CreateCustomCall( + shape, slices, + xla::memory_space_assignment::kConcatBitcastCustomCall)); + return OkStatus(); +} + +std::string SlicedCopyAllocation::SliceDetail::ToString() const { + return absl::StrCat("{ slice_decision: ", slice_decision.ToString(), + ", copy_start_after_time: ", copy_start_after_time, + ", copy_done_before_time: ", copy_done_before_time, " }"); +} + +std::tuple +SliceDetailToTuple(const SlicedCopyAllocation::SliceDetail& slice_detail) { + return std::make_tuple(std::ref(slice_detail.slice_decision), + slice_detail.copy_start_after_time, + slice_detail.copy_done_before_time, + slice_detail.copy_start, slice_detail.copy_done); +} + +bool SlicedCopyAllocation::SliceDetail::operator==( + const SliceDetail& other) const { + return SliceDetailToTuple(*this) == SliceDetailToTuple(other); +} + +Status SlicedCopyAllocation::SliceDetail::CreateAsyncSlice( + const Shape& original_shape, HloInstruction& producer, + HloComputation& parent) { + if (original_shape.rank() != slice_decision.sizing.slice_params.size()) { + return FailedPrecondition( + "%s", absl::StrCat("The number of SlicedCopyAllocation parameters ", + slice_decision.sizing.slice_params.size(), + " does not match the rank ", original_shape.rank(), + " of the tensor we are slicing.")); + } + + std::vector start_indices; + start_indices.reserve(slice_decision.sizing.slice_params.size()); + std::vector limit_indices; + limit_indices.reserve(slice_decision.sizing.slice_params.size()); + std::vector strides; + strides.reserve(slice_decision.sizing.slice_params.size()); + + for (int i = 0; i < slice_decision.sizing.slice_params.size(); ++i) { + const SliceParam& slice_param = slice_decision.sizing.slice_params[i]; + start_indices.push_back(slice_param.start_inclusive); + limit_indices.push_back(slice_param.end_exclusive); + strides.push_back(1); + const int64_t new_dim = + slice_param.end_exclusive - slice_param.start_inclusive; + if (new_dim <= 0) { + return FailedPrecondition( + "%s", absl::StrCat("SlicedCopyAllocation new dimension size is ", + new_dim, ", expected something > 0.")); + } + if (original_shape.dimensions(i) < new_dim) { + return FailedPrecondition( + "%s", + absl::StrCat("SlicedCopyAllocation sliced dimension size ", new_dim, + " is bigger than its original dimension size of ", + original_shape.dimensions(i), ".")); + } + } + + HloInstruction* slice = parent.AddInstruction( + HloInstruction::CreateSlice(slice_decision.sizing.slice_shape, &producer, + start_indices, limit_indices, strides)); + TF_ASSIGN_OR_RETURN(copy_done, parent.CreateAsyncInstructions( + slice, {ShapeUtil::MakeShape(S32, {})})); + copy_start = copy_done->mutable_operand(0); + + return OkStatus(); +} + +bool SlicedCopyAllocation::operator==(const Allocation& other) const { + const SlicedCopyAllocation* casted_other = + dynamic_cast(&other); + return casted_other != nullptr && (*this) == (*casted_other); +} + +HloPosition MirroredAllocation::defining_position() const { + return original_defining_position(); +} + +std::string MirroredAllocation::ToString() const { + return absl::StrCat("Mirrored Allocation for ", + original_allocation_.ToString()); +} + +std::string ParentAllocation::ToString() const { + return absl::StrCat("Parent Allocation mirrored at ", + original_defining_position().ToString(), ", originally ", + original_allocation_.ToString()); +} + +MirroredAllocation::MirroredAllocation(const Allocation& original_allocation, + int64_t time) + : Allocation(original_allocation.defining_position(), MemorySpace::kDefault, + original_allocation.maybe_chunk(), + /*start_time=*/time, + /*end_time=*/time, /*is_scoped_allocation=*/false, + /*cross_program_prefetch_index=*/std::nullopt), + original_allocation_(original_allocation) {} + +Status MirroredAllocation::Process() { + set_original_defining_position(original_allocation_.defining_position()); + if (is_scoped_allocation()) { + // Nothing to do here for scoped allocations. + return OkStatus(); + } + HloInstruction* producing_instruction = AddGetTupleElements(); + HloComputation* computation = producing_instruction->parent(); + return UpdateUses(computation, producing_instruction); +} + +ParentAllocation::ParentAllocation(const Allocation& original_allocation, + HloInstruction* calling_instruction, + HloPosition position, int64_t time) + : Allocation(std::move(position), MemorySpace::kDefault, + original_allocation.maybe_chunk(), + /*start_time=*/time, + /*end_time=*/time, /*is_scoped_allocation=*/false, + /*cross_program_prefetch_index=*/std::nullopt), + original_allocation_(original_allocation), + calling_instruction_(calling_instruction) {} + +HloPosition ParentAllocation::defining_position() const { + return original_defining_position(); +} + +Status ParentAllocation::Process() { + // Add an additional parameter to the while HLO with a reference to the buffer + // in the default memory space. + HloInstruction* producing_instruction = + original_allocation_.AddGetTupleElements(); + int new_tuple_index = calling_instruction_->shape().tuple_shapes_size(); + + TF_ASSIGN_OR_RETURN( + HloInstruction * new_while_operand, + TupleUtil::ReplaceTupleWith(producing_instruction, + calling_instruction_->mutable_operand(0), + {new_tuple_index})); + TF_RETURN_IF_ERROR(calling_instruction_->ReplaceOperandWithDifferentShape( + 0, new_while_operand)); + *calling_instruction_->mutable_shape() = new_while_operand->shape(); + *calling_instruction_->while_condition() + ->parameter_instruction(0) + ->mutable_shape() = new_while_operand->shape(); + *calling_instruction_->while_body() + ->parameter_instruction(0) + ->mutable_shape() = new_while_operand->shape(); + HloPosition defining_position = original_defining_position(); + defining_position.index = {new_tuple_index}; + set_original_defining_position(defining_position); + // Also replace the while op with a tuple that has the old shape. Note that we + // need to first take a snapshot of the users before calling ExtractPrefix + // since ExtractPrefix introduces additional gte users. + std::vector while_users = calling_instruction_->users(); + HloInstruction* tuple_with_old_shape = + TupleUtil::ExtractPrefix(calling_instruction_, new_tuple_index); + TF_RETURN_IF_ERROR(calling_instruction_->ReplaceAllUsesWithDifferentShape( + while_users, tuple_with_old_shape)); + + if (is_scoped_allocation()) { + // Nothing to do here for scoped allocations. + return OkStatus(); + } + HloInstruction* final_instruction = AddGetTupleElements(); + HloComputation* computation = final_instruction->parent(); + return UpdateUses(computation, final_instruction); +} + +Status ParentAllocation::PostProcess() { + // Update the root of the while body with the new parameter. The reason why we + // need a separate post-process for this is because other allocations may have + // while body root as a use, so they would update the old root instead of the + // new root. Doing the post-process step later ensures the root has been + // updated with other changes, and we can safely add the additional parameter. + HloComputation* while_body = calling_instruction_->while_body(); + TF_ASSIGN_OR_RETURN(HloInstruction * new_while_body_root, + TupleUtil::ReplaceTupleWith( + AddGetTupleElements(), while_body->root_instruction(), + original_defining_position().index)); + while_body->set_root_instruction(new_while_body_root, + /*accept_different_shape=*/true); + return OkStatus(); +} + +void ParentAllocation::MarkIfNeeded( + absl::flat_hash_set& needed_allocations) const { + // Parent allocations are only needed if they have any uses or if there is a + // copy allocation that copies this value (in that case, the copy allocation + // will call this allocation's MarkNeeded function). + if (!has_no_uses()) { + MarkNeeded(needed_allocations); + } +} + +void ParentAllocation::MarkNeeded( + absl::flat_hash_set& needed_allocations) const { + needed_allocations.insert(this); + original_allocation_.MarkNeeded(needed_allocations); +} + +bool ParentAllocation::operator==(const Allocation& other) const { + const ParentAllocation* casted_other = + dynamic_cast(&other); + return casted_other != nullptr && (*this) == (*casted_other); +} + +void MirroredAllocation::MarkIfNeeded( + absl::flat_hash_set& needed_allocations) const { + MarkNeeded(needed_allocations); +} + +void MirroredAllocation::MarkNeeded( + absl::flat_hash_set& needed_allocations) const { + needed_allocations.insert(this); + original_allocation_.MarkNeeded(needed_allocations); +} + +bool MirroredAllocation::operator==(const Allocation& other) const { + const MirroredAllocation* casted_other = + dynamic_cast(&other); + return casted_other != nullptr && (*this) == (*casted_other); +} + +} // namespace xla::memory_space_assignment diff --git a/third_party/xla/xla/service/memory_space_assignment/allocation.h b/third_party/xla/xla/service/memory_space_assignment/allocation.h new file mode 100644 index 00000000000000..86bdab72c55bae --- /dev/null +++ b/third_party/xla/xla/service/memory_space_assignment/allocation.h @@ -0,0 +1,447 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_ALLOCATION_H_ +#define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_ALLOCATION_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/functional/function_ref.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/heap_simulator/allocation_block.h" +#include "xla/service/heap_simulator/heap_simulator.h" +#include "xla/service/hlo_value.h" +#include "xla/service/memory_space_assignment/memory_space_assignment.pb.h" +#include "xla/service/memory_space_assignment/slice.h" +#include "xla/shape.h" +#include "xla/status.h" + +namespace xla::memory_space_assignment { + +// MemorySpaceAssignment uses a notion of a slow and large default memory +// space and a fast and small alternate memory space. +enum class MemorySpace : std::uint8_t { kDefault, kAlternate }; + +// An interface describing what to do with a value in memory over its lifetime. +// An allocation might either be placed in the default or alternate memory. An +// HloValue might live in multiple different allocations over its lifetime. The +// lifetimes of the allocations are defined using start_time and end_time, which +// corresponds to the instruction indexes in the flattened schedule. Each of +// these allocations might partially overlap with each other. +// +// Consider an instruction Foo, and its users Bar and Baz, and the times given +// in terms of the flattened schedule of the entire module: +// +// Foo:10 +// / \ +// Bar:14 \ +// Baz:25 +// +// A valid memory space assignment could be like the following: +// +// Time: 10 ... 14 ... 25 +// Foo Bar Baz +// Alternate +-------+ +-----+ +// Default +---------------------+ +// ^ ^ ^ ^ +// | | | | +// evict evict prefetch prefetch +// start end start end +// +// This would be represented with: +// - PinnedAllocation(memory_space=kAlternate, start_time=10, end_time=14) +// - CopyAllocation(memory_space=kDefault, start_time=12, end_time=25) +// - CopyAllocation(memory_space=kAlternate, start_time=22, end_time=25) +class Allocation { + public: + virtual ~Allocation() = default; + + // Allocation source methods + // -------------------------------------------------------------------------- + // Returns the defining position for this allocation. + virtual HloPosition defining_position() const = 0; + // Returns the cross-program prefetch index for this allocation. + std::optional cross_program_prefetch_index() const; + + // Allocation timing methods + // -------------------------------------------------------------------------- + // TODO(cl/604356742): update all timing methods to explicitly state that + // they're representing inclusive intervals. + int64_t start_time() const { return start_time_; } + int64_t end_time() const { return end_time_; } + // Returns the time the buffer is first available to be used + virtual int64_t earliest_available_time() const = 0; + void set_start_time(int64_t start_time) { start_time_ = start_time; } + void set_end_time(int64_t end_time) { end_time_ = end_time; } + // Extends the end time of this allocation. + void Extend(int64_t end_time) { end_time_ = std::max(end_time_, end_time); } + + // Allocation space methods + // -------------------------------------------------------------------------- + MemorySpace memory_space() const { return memory_space_; } + // Returns the associated chunk that may be a nullopt if the allocation is + // in the default memory space. + std::optional maybe_chunk() const { return chunk_; } + // Returns the associated chunk. The caller should ensure that the chunk is + // defined (the allocation should be in the alternate memory space). + HeapSimulator::Chunk chunk() const; + HeapSimulator::Chunk* mutable_chunk() { return &*chunk_; } + void set_offset(int64_t offset); + bool is_scoped_allocation() const { return is_scoped_allocation_; } + // Returns true if the allocation is in the alternate memory space. + bool is_in_alternate_mem() const; + // Returns true if the allocation is in the default memory space. + bool is_in_default_mem() const; + + // Use methods + // -------------------------------------------------------------------------- + const std::vector& uses() const { return uses_; } + void clear_uses() { uses_.clear(); } + bool has_no_uses() const { return uses_.empty(); } + // Adds a use to this allocation. + void AddUse(HloUse use); + // Replaces all uses of the allocation with the copy_complete instruction. + Status UpdateUses(HloComputation* computation, + HloInstruction* producing_instruction); + + // Allocation type methods + // -------------------------------------------------------------------------- + virtual bool is_copy_allocation() const = 0; + virtual bool is_sliced_copy_allocation() const = 0; + // True if the allocation is for a copy or a sliced-copy. + bool is_copy_like_allocation() const; + + // Processing methods + // -------------------------------------------------------------------------- + // Recursively create kGetTupleElement instructions if the defining position + // shape is not an array. Returns the new instruction that has array shape. + HloInstruction* AddGetTupleElements() const; + // After all of the time ranges for the allocations have been assigned, + // Process morphs the instructions affected to assign the memory spaces and + // insert asynchronous copy instructions if necessary. + virtual Status Process() = 0; + // An optional post-process step that will be called after all allocations + // have been processed. + virtual Status PostProcess() = 0; + // Marks (adds this allocation to needed_allocations) if this allocation is + // needed. PinnedAllocation and CopyAllocations are always needed and + // ParentAllocations are needed if they have any uses or if other + // CopyAllocation or ParentAllocations depend on them. + virtual void MarkIfNeeded( + absl::flat_hash_set& needed_allocations) const = 0; + // Marks this allocation as needed. + virtual void MarkNeeded( + absl::flat_hash_set& needed_allocations) const = 0; + + // Utility methods + // -------------------------------------------------------------------------- + virtual std::string ToString() const = 0; + virtual bool operator==(const Allocation& other) const = 0; + + protected: + // Protected constructor to encourage use of the final subclasses (e.g., + // PinnedAllocation, CopyAllocation, etc.). + Allocation(HloPosition defining_position, MemorySpace memory_space, + std::optional chunk, int64_t start_time, + int64_t end_time, bool is_scoped_allocation, + std::optional cross_program_prefetch_index); + + // Returns the original defining position of this allocation. + HloPosition original_defining_position() const; + // Sets the original defining position of this allocation. + void set_original_defining_position(HloPosition defining_position); + bool base_is_equal(const Allocation& other) const; + + private: + HloPosition original_defining_position_; + MemorySpace memory_space_; + std::optional chunk_; + int64_t start_time_; + int64_t end_time_; + const bool is_scoped_allocation_; + std::vector uses_; + std::optional cross_program_prefetch_index_; +}; + +// This class represents an allocation that pins a tensor to +// a specific memory space. +class PinnedAllocation final : public Allocation { + public: + PinnedAllocation(HloPosition defining_position, MemorySpace memory_space, + std::optional chunk, + int64_t start_time, int64_t end_time, + bool is_scoped_allocation); + + // Overridden methods + // + // Returns the original defining position. + HloPosition defining_position() const override; + int64_t earliest_available_time() const override { return start_time(); } + bool is_copy_allocation() const override { return false; } + bool is_sliced_copy_allocation() const override { return false; } + Status Process() override; + Status PostProcess() override { return OkStatus(); } + void MarkIfNeeded(absl::flat_hash_set& needed_allocations) + const override; + void MarkNeeded(absl::flat_hash_set& needed_allocations) + const override; + std::string ToString() const override; + bool operator==(const Allocation& other) const override; + + // New non-virtual methods + bool operator==(const PinnedAllocation& other) const; +}; + +// This class represents an allocation as a result of an asynchronous copy. +// Note: CopyStart instructions are inserted after +// `copy_start_schedule_after`, while CopyDone instructions are inserted +// before `copy_done_schedule_before_time`. +class CopyAllocation final : public Allocation { + public: + // TODO(b/307342076): Reorder scheduling times to be + // copy_start_schedule_after_time, copy_done_schedule_before_time, end_time + CopyAllocation( + Allocation& prev_allocation, MemorySpace memory_space, + std::optional chunk, + int64_t copy_start_schedule_after_time, + int64_t copy_done_schedule_before_time, int64_t end_time, + std::optional cross_program_prefetch_index = std::nullopt); + + // Overridden methods + // + HloPosition defining_position() const override; + // Returns the time the buffer is first available to be used. For + // CopyAllocation, this is when the copy ends, which is + // copy_done_schedule_before. + int64_t earliest_available_time() const override; + bool is_copy_allocation() const override { return true; } + bool is_sliced_copy_allocation() const override { return false; } + Status Process() override; + Status PostProcess() override { return OkStatus(); } + void MarkIfNeeded(absl::flat_hash_set& needed_allocations) + const override; + void MarkNeeded(absl::flat_hash_set& needed_allocations) + const override; + std::string ToString() const override; + bool operator==(const Allocation& other) const override; + + // New non-virtual methods + bool operator==(const CopyAllocation& other) const; + + const Allocation& prev_allocation() { return prev_allocation_; } + Allocation& mutable_prev_allocation() { return prev_allocation_; } + + HloInstruction* copy_start() const { return copy_start_; } + HloInstruction* copy_done() const { return copy_done_; } + + void set_copy_start_schedule_after(int64_t copy_start_schedule_after); + void set_copy_done_schedule_before(int64_t copy_done_schedule_before); + int64_t copy_start_schedule_after() const; + int64_t copy_done_schedule_before() const; + + private: + Allocation& prev_allocation_; + // These variables define the scheduling boundaries where CopyStart and + // CopyDone can be scheduled. The earliest CopyStart can be scheduled is + // after copy_start_schedule_after_ and the latest CopyDone can be scheduled + // is before copy_done_schedule_before_. + int64_t copy_start_schedule_after_; + int64_t copy_done_schedule_before_; + HloInstruction* copy_start_ = nullptr; + HloInstruction* copy_done_ = nullptr; +}; + +// This class represents an allocation resulting from asynchronous sliced +// copies. +// +// Let the sliced allocation be represented as follows, and imagine that t3 +// is the time when the entire buffer [p0, p3) is available for use +// +// space +// ^ +// p3 | +-----------+ +// | | | +// p2 | +---+ | +// | | | +// p1 | +-------+ | +// | | | +// p0 | +-------+ +// +---|---|---|---|---|----> time +// t0 t1 t2 t3 t4 +// +// The PinnedAllocation underlying the SlicedCopyAllocation will use the +// following dimensions: +// - chunk = [p0, p3) +// - start time = t2 +// - earliest_available_time = t3 +// - end_time = t4 +class SlicedCopyAllocation final : public Allocation { + public: + // Full details about a slice in the sliced allocation. + struct SliceDetail { + std::string ToString() const; + std::tuple + ToTuple() const; + bool operator==(const SliceDetail& other) const; + + // Create the instructions to copy the slice. This method updates + // copy_start and copy_done. + Status CreateAsyncSlice(const Shape& original_shape, + HloInstruction& producer, HloComputation& parent); + + SliceDecision slice_decision; + int64_t copy_start_after_time = -1; + int64_t copy_done_before_time = -1; + HloInstruction* copy_start = nullptr; + HloInstruction* copy_done = nullptr; + }; + + // REQUIRES: + // - slice_decisions_sorted_by_exclusive_start_time.size() >= 2, otherwise, + // CopyAllocation should be used. + SlicedCopyAllocation( + const Allocation& prev_allocation, MemorySpace memory_space, + std::vector slice_decisions_sorted_by_exclusive_start_time, + int64_t copy_done_schedule_before_time, int64_t end_time, + const SlicedPrefetchOptions& sliced_prefetch_options, + absl::FunctionRef get_equivalent_s8_shape_fn); + + // Overridden methods + // + HloPosition defining_position() const override; + // Returns the time the buffer is first available to be used. For + // SlicedCopyAllocation, this is when all copies have ended. + int64_t earliest_available_time() const override; + bool is_copy_allocation() const override { return false; } + bool is_sliced_copy_allocation() const override { return true; } + // MemorySpaceAssignment::Process() calls Process() to create asynchronous + // slice copies, and a bitcast-concat call to glue the slices back together. + Status Process() override; + Status PostProcess() override { return OkStatus(); } + // Marks the allocation as needed. + void MarkIfNeeded(absl::flat_hash_set& needed_allocations) + const override; + void MarkNeeded(absl::flat_hash_set& needed_allocations) + const override; + std::string ToString() const override; + bool operator==(const Allocation& other) const override; + + // New non-virtual methods + bool operator==(const SlicedCopyAllocation& other) const; + + std::vector SliceOffsetsSortedByStartTime() const; + void AddDiffToAllSliceOffsets(int64_t diff); + // Used to update offsets and start times after repacking. + void ImportRepackedSliceData(const SlicedAllocationData& data); + const std::vector& slice_details_sorted_by_start_time() const; + std::vector& mutable_slice_details_sorted_by_start_time(); + HloInstruction* concat() const { return concat_; } + + private: + SlicedCopyAllocation() = delete; + + // Create an instruction to concatenate the slices. Populates concat_. + Status CreateBitcastConcat(const Shape& shape, + absl::Span slices); + + Shape original_shape_to_slice_; + const Allocation& prev_allocation_; + // REQUIRES: + // - sorted_segments_[i].copy_start_after_time <= + // sorted_segments_[i+j].copy.start_after_time + // - sorted_segments_[i].copy_done_before_time <= + // sorted_segments_[i+j].copy.start_before_time + std::vector slice_details_sorted_by_exclusive_start_time_; + HloInstruction* concat_ = nullptr; + const SlicedPrefetchOptions& sliced_prefetch_options_; + absl::FunctionRef get_equivalent_s8_shape_fn_; +}; + +// An allocation in the default memory space that mirrors another Allocation +// object. This is useful to model an eviction that happens before a while op +// so that we don't need to redundantly evict the buffer after the while op as +// well. +class MirroredAllocation final : public Allocation { + public: + MirroredAllocation(const Allocation& original_allocation, int64_t time); + + // Overridden methods + // + // Returns the original defining position. + HloPosition defining_position() const override; + int64_t earliest_available_time() const override { return start_time(); } + bool is_copy_allocation() const override { return false; } + bool is_sliced_copy_allocation() const override { return false; } + Status Process() override; + Status PostProcess() override { return OkStatus(); } + void MarkIfNeeded(absl::flat_hash_set& needed_allocations) + const override; + void MarkNeeded(absl::flat_hash_set& needed_allocations) + const override; + std::string ToString() const override; + bool operator==(const Allocation& other) const override; + + // New non-virtual methods + bool operator==(const MirroredAllocation& other) const; + + private: + const Allocation& original_allocation_; +}; + +// An allocation in default memory space that is defined in the parent +// computation. If a value has a copy in the default memory space in the +// parent computation, we don't need to evict this buffer in a while loop. +class ParentAllocation final : public Allocation { + public: + ParentAllocation(const Allocation& original_allocation, + HloInstruction* calling_instruction, HloPosition position, + int64_t time); + + // Overridden methods + // + // Returns the original defining position. + HloPosition defining_position() const override; + int64_t earliest_available_time() const override { return start_time(); } + bool is_copy_allocation() const override { return false; } + bool is_sliced_copy_allocation() const override { return false; } + Status Process() override; + Status PostProcess() override; + void MarkIfNeeded(absl::flat_hash_set& needed_allocations) + const override; + void MarkNeeded(absl::flat_hash_set& needed_allocations) + const override; + std::string ToString() const override; + bool operator==(const Allocation& other) const override; + + // New non-virtual methods + bool operator==(const ParentAllocation& other) const; + + private: + const Allocation& original_allocation_; + HloInstruction* calling_instruction_; +}; + +} // namespace xla::memory_space_assignment + +#endif // XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_ALLOCATION_H_ diff --git a/third_party/xla/xla/service/memory_space_assignment/cost_analysis.cc b/third_party/xla/xla/service/memory_space_assignment/cost_analysis.cc new file mode 100644 index 00000000000000..2e5c07ac7546c1 --- /dev/null +++ b/third_party/xla/xla/service/memory_space_assignment/cost_analysis.cc @@ -0,0 +1,403 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/memory_space_assignment/cost_analysis.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/memory/memory.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_live_range.h" +#include "xla/service/call_graph.h" +#include "xla/service/heap_simulator/heap_simulator.h" +#include "xla/service/hlo_alias_analysis.h" +#include "xla/service/hlo_buffer.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/service/hlo_value.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/statusor.h" +#include "xla/util.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace memory_space_assignment { +/*static*/ StatusOr> CostAnalysis::Create( + const HloCostAnalysis& cost_analysis, const CostAnalysisOptions& options, + const HloModule& module) { + TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module)); + TF_ASSIGN_OR_RETURN(auto hlo_live_range, + HloLiveRange::Run(module.schedule(), *alias_analysis, + module.entry_computation())); + auto call_graph = CallGraph::Build(&module); + // Using `new` to access a non-public constructor. + return absl::WrapUnique( + new CostAnalysis(cost_analysis, options, std::move(alias_analysis), + std::move(hlo_live_range), std::move(call_graph))); +} + +float CostAnalysis::GetAlternateMemoryBenefit( + const HloInstruction& instruction, float elapsed_time_due_to_alternate_mem, + CostAnalysis::Cache* cache) const { + float elapsed_time_due_to_compute = + GetInstructionElapsedDueToCompute(instruction); + float elapsed_time_due_to_memory = + GetInstructionElapsedDueToMemory(instruction); + if (elapsed_time_due_to_memory > elapsed_time_due_to_compute) { + // Memory bound, return how much alternate memory is better. + float while_nest_multiplier; + if (cache) { + // If there is a cache provided, memoize the while nest multiplier. + auto it = cache->while_nest_multiplier.find(&instruction); + if (it != cache->while_nest_multiplier.end()) { + while_nest_multiplier = it->second; + } else { + while_nest_multiplier = GetWhileNestMultiplier( + CalculateComputationNestLevel(&instruction, + /*while_only=*/true)); + cache->while_nest_multiplier[&instruction] = while_nest_multiplier; + } + } else { + while_nest_multiplier = GetWhileNestMultiplier( + CalculateComputationNestLevel(&instruction, + /*while_only=*/true)); + } + return (elapsed_time_due_to_memory - elapsed_time_due_to_alternate_mem) * + while_nest_multiplier; + } else { + // Compute bound, return how far off are we to memory boundedness. + return elapsed_time_due_to_memory - elapsed_time_due_to_compute; + } +} + +float CostAnalysis::GetMemoryBoundedness( + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval, + CostAnalysis::Cache* cache) const { + if (cache) { + auto it = + cache->memory_boundedness.find(interval.buffer->defining_position()); + if (it != cache->memory_boundedness.end()) { + return it->second; + } + } + float alternate_mem_benefit = + GetAlternateMemoryBenefit(interval.buffer->defining_position(), cache); + + for (const HloBuffer* buffer : alias_analysis_->ComputeBuffersAt( + interval.buffer->defining_position().instruction, + interval.buffer->defining_position().index)) { + for (const HloValue* value : buffer->values()) { + for (const HloUse& use : value->GetUses()) { + // We look inside the called computations of while and conditional, so + // don't use the benefit of while and conditional directly. + if (use.instruction->opcode() == HloOpcode::kWhile || + use.instruction->opcode() == HloOpcode::kConditional) { + continue; + } + float use_alternate_mem_benefit = GetAlternateMemoryBenefit(use, cache); + // If the benefit is positive (memory bound), add it to this buffer's + // benefit. If the benefit is negative (compute bound), calculate the + // maximum. + if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) { + alternate_mem_benefit += use_alternate_mem_benefit; + } else { + alternate_mem_benefit = + std::max(alternate_mem_benefit, use_alternate_mem_benefit); + } + } + } + } + + // Penalize larger buffers by dividing the benefit by the square root of + // the size. Empirically, we observed this resulted in better performance + // compared to dividing by the size. + float memory_boundedness = 1; + if (options_ + .xla_tpu_alternate_memory_benefit_scaling_factor_for_large_buffers == + "NO_SCALE") { + memory_boundedness = alternate_mem_benefit; + } else { + memory_boundedness = alternate_mem_benefit / std::sqrt(interval.size); + } + + if (cache) { + cache->memory_boundedness[interval.buffer->defining_position()] = + memory_boundedness; + } + return memory_boundedness; +} + +float CostAnalysis::GetAlternateMemoryBenefit( + const HloPosition& position, CostAnalysis::Cache* cache) const { + return GetAlternateMemoryBenefit( + *position.instruction, + GetInstructionElapsedDueToMemory( + *position.instruction, + /*operands_in_alternate_mem=*/{}, + /*outputs_in_alternate_mem=*/{position.index}), + cache); +} + +float CostAnalysis::GetAlternateMemoryBenefit( + const HloUse& use, CostAnalysis::Cache* cache) const { + return GetAlternateMemoryBenefit( + *use.instruction, + GetInstructionElapsedDueToMemory( + *use.instruction, + /*operands_in_alternate_mem=*/{std::make_pair(use.operand_number, + use.operand_index)}), + cache); +} + +int CostAnalysis::CalculateComputationNestLevel( + const HloInstruction* instruction, bool while_only) const { + int nest_level = 0; + const HloComputation* computation = instruction->parent(); + while (!computation->IsEntryComputation()) { + auto& node = call_graph_->GetNode(computation); + auto callsites = node.caller_callsites(); + CHECK(node.computation()->IsAsyncComputation() || callsites.size() == 1) + << "The module is not flattened!"; + auto& callsite = callsites[0]; + if (!while_only || callsite.instruction()->opcode() == HloOpcode::kWhile) { + ++nest_level; + } + computation = callsite.instruction()->parent(); + } + return nest_level; +} + +float CostAnalysis::GetWhileNestMultiplier(int while_nest_level) const { + return IPow( + options_.xla_tpu_memory_space_assignment_while_execution_count, + while_nest_level); +} + +float CostAnalysis::GetDefaultMemoryAccessOverhead( + const HloInstruction& instruction, + absl::Span> operands_in_alternate_mem, + absl::Span outputs_in_alternate_mem) const { + // Calculate the pipeline overhead of accessing the default memory. We use the + // maximum of the window size heuristic and the actual default memory bytes + // accessed multiplied with the compute as the overhead. So, the math is: + // + // overhead = compute_per_iteration + // = compute_elapsed / num_iterations + // = compute_elapsed / (bytes_accessed / window_size) + // = (window_size / bytes_accessed) * compute_elapsed + const float window_size_bytes = + options_.pipeline_overhead_window_size_mib * 1024 * 1024; + const float bytes_accessed = hlo_cost_analysis_.bytes_accessed(instruction); + const float default_memory_bytes_accessed = + bytes_accessed - + GetBytesAccessedFromAlternateMemory( + instruction, operands_in_alternate_mem, outputs_in_alternate_mem); + const float compute_elapsed = GetInstructionElapsedDueToCompute(instruction); + const float effective_window_size_bytes = + std::min(window_size_bytes, default_memory_bytes_accessed); + float overhead = 0; + if (bytes_accessed > 0) { + overhead = (effective_window_size_bytes / bytes_accessed) * compute_elapsed; + } + return overhead; +} + +float CostAnalysis::GetDefaultMemoryBandwidthIdleTime( + const HloInstruction& instruction, + absl::Span> operands_in_alternate_mem, + absl::Span outputs_in_alternate_mem) const { + const float default_memory_bytes_accessed = + hlo_cost_analysis_.bytes_accessed(instruction) - + GetBytesAccessedFromAlternateMemory( + instruction, operands_in_alternate_mem, outputs_in_alternate_mem); + const float elapsed_due_to_default_mem = + default_memory_bytes_accessed / + hlo_cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey); + const float elapsed = GetInstructionElapsedInAlternateMemory( + instruction, operands_in_alternate_mem, outputs_in_alternate_mem); + return elapsed - elapsed_due_to_default_mem; +} + +float CostAnalysis::GetBytesAccessedFromAlternateMemory( + const HloInstruction& instruction, + absl::Span> operands_in_alternate_mem, + absl::Span outputs_in_alternate_mem) const { + float bytes_accessed_from_alternate_mem = 0.0; + for (auto& operand : operands_in_alternate_mem) { + const float operand_bytes_accessed = + hlo_cost_analysis_.operand_bytes_accessed(instruction, operand.first, + operand.second); + bytes_accessed_from_alternate_mem += operand_bytes_accessed; + } + + for (auto& shape_idx : outputs_in_alternate_mem) { + const float output_bytes_accessed = + hlo_cost_analysis_.output_bytes_accessed(instruction, shape_idx); + bytes_accessed_from_alternate_mem += output_bytes_accessed; + } + return bytes_accessed_from_alternate_mem; +} + +namespace { +// Returns true on async instructions since we assume they are already +// efficiently scheduled such that they are not in the critical path and appear +// to take no time. +bool ExcludeInstructionFromElapsed(const HloInstruction& instruction) { + return instruction.opcode() == HloOpcode::kAllGatherStart || + instruction.opcode() == HloOpcode::kAllGatherDone || + instruction.opcode() == HloOpcode::kAllReduceStart || + instruction.opcode() == HloOpcode::kAllReduceDone || + instruction.opcode() == HloOpcode::kAsyncStart || + instruction.opcode() == HloOpcode::kAsyncDone || + instruction.opcode() == HloOpcode::kCollectivePermuteStart || + instruction.opcode() == HloOpcode::kCollectivePermuteDone || + instruction.opcode() == HloOpcode::kCopyStart || + instruction.opcode() == HloOpcode::kCopyDone; +} +} // namespace + +float CostAnalysis::GetInstructionElapsedDueToCompute( + const HloInstruction& instruction) const { + if (ExcludeInstructionFromElapsed(instruction)) { + return 0.0f; + } + return std::max( + hlo_cost_analysis_.flop_count(instruction) / + hlo_cost_analysis_.per_second_rate(HloCostAnalysis::kFlopsKey), + hlo_cost_analysis_.transcendental_count(instruction) / + hlo_cost_analysis_.per_second_rate( + HloCostAnalysis::kTranscendentalsKey)); +} + +float CostAnalysis::GetInstructionElapsedDueToMemory( + const HloInstruction& instruction, + absl::Span> operands_in_alternate_mem, + absl::Span outputs_in_alternate_mem) const { + if (ExcludeInstructionFromElapsed(instruction)) { + return 0.0f; + } + float total_bytes_accessed = hlo_cost_analysis_.bytes_accessed(instruction); + float bytes_accessed_from_alternate_mem = GetBytesAccessedFromAlternateMemory( + instruction, operands_in_alternate_mem, outputs_in_alternate_mem); + float elapsed_due_to_alternate_mem = + bytes_accessed_from_alternate_mem / + options_.alternate_mem_bandwidth_bytes_per_second; + float elapsed_due_to_default_mem = + (total_bytes_accessed - bytes_accessed_from_alternate_mem) / + hlo_cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey); + return elapsed_due_to_alternate_mem + elapsed_due_to_default_mem; +} + +float CostAnalysis::GetInstructionElapsedDueToMemory( + const HloInstruction& instruction, + IsInAlternateMemoryFun is_in_alternate_mem) const { + if (ExcludeInstructionFromElapsed(instruction)) { + return 0.0f; + } + float total_bytes_accessed = hlo_cost_analysis_.bytes_accessed(instruction); + float bytes_accessed_from_alternate_mem = 0.0; + for (int operand_num = 0; operand_num < instruction.operand_count(); + ++operand_num) { + ShapeUtil::ForEachSubshape( + instruction.operand(operand_num)->shape(), + [&](const Shape& subshape, const ShapeIndex& index) { + if (!subshape.IsArray()) { + return; + } + if (is_in_alternate_mem(operand_num, index, subshape)) { + bytes_accessed_from_alternate_mem += + hlo_cost_analysis_.operand_bytes_accessed(instruction, + operand_num, index); + } + }); + } + ShapeUtil::ForEachSubshape(instruction.shape(), [&](const Shape& subshape, + const ShapeIndex& index) { + if (!subshape.IsArray()) { + return; + } + if (is_in_alternate_mem(/*operand_num=*/std::nullopt, index, subshape)) { + bytes_accessed_from_alternate_mem += + hlo_cost_analysis_.output_bytes_accessed(instruction, index); + } + }); + float elapsed_due_to_alternate_mem = + bytes_accessed_from_alternate_mem / + options_.alternate_mem_bandwidth_bytes_per_second; + float elapsed_due_to_default_mem = + (total_bytes_accessed - bytes_accessed_from_alternate_mem) / + hlo_cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey); + return elapsed_due_to_alternate_mem + elapsed_due_to_default_mem; +} + +float CostAnalysis::GetInstructionElapsed( + const HloInstruction& instruction) const { + if (ExcludeInstructionFromElapsed(instruction)) { + return 0.0f; + } + float overhead = GetDefaultMemoryAccessOverhead(instruction); + return std::max(GetInstructionElapsedDueToCompute(instruction), + GetInstructionElapsedDueToMemory(instruction) + overhead); +} + +float CostAnalysis::GetInstructionElapsedInAlternateMemory( + const HloInstruction& instruction, + absl::Span> operands_in_alternate_mem, + absl::Span outputs_in_alternate_mem) const { + if (ExcludeInstructionFromElapsed(instruction)) { + return 0.0f; + } + float overhead = GetDefaultMemoryAccessOverhead( + instruction, operands_in_alternate_mem, outputs_in_alternate_mem); + return std::max( + GetInstructionElapsedDueToCompute(instruction), + GetInstructionElapsedDueToMemory(instruction, operands_in_alternate_mem, + outputs_in_alternate_mem) + + overhead); +} + +float CostAnalysis::GetInstructionElapsedInAlternateMemory( + const HloInstruction& instruction, + IsInAlternateMemoryFun is_in_alternate_mem) const { + if (ExcludeInstructionFromElapsed(instruction)) { + return 0.0f; + } + return std::max( + GetInstructionElapsedDueToCompute(instruction), + GetInstructionElapsedDueToMemory(instruction, is_in_alternate_mem)); +} + +float CostAnalysis::GetAsyncCopyElapsed(const Shape& shape) const { + int64_t size_in_bytes = hlo_cost_analysis_.GetShapeSize(shape); + return static_cast(size_in_bytes) / + (options_.async_copy_bandwidth_bytes_per_second * + options_.async_copy_bandwidth_scaling_factor); +} + +int64_t CostAnalysis::GetScheduleEndTime() const { + return hlo_live_range_->schedule_end_time(); +} + +} // namespace memory_space_assignment +} // namespace xla diff --git a/third_party/xla/xla/service/memory_space_assignment/cost_analysis.h b/third_party/xla/xla/service/memory_space_assignment/cost_analysis.h new file mode 100644 index 00000000000000..8463f7f914564b --- /dev/null +++ b/third_party/xla/xla/service/memory_space_assignment/cost_analysis.h @@ -0,0 +1,237 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_COST_ANALYSIS_H_ +#define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_COST_ANALYSIS_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/functional/function_ref.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/utils/hlo_live_range.h" +#include "xla/service/call_graph.h" +#include "xla/service/heap_simulator/heap_simulator.h" +#include "xla/service/hlo_alias_analysis.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/service/hlo_value.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/statusor.h" +#include "xla/util.h" + +namespace xla { +namespace memory_space_assignment { + +// Options to be passed to the CostAnalysis. +struct CostAnalysisOptions { + // This variable is used by the cost analysis in estimating how many times + // each while loop will execute. Nested loops will be assumed to have + // executed pow(while_execution_count, nesting_level) times. + uint64_t xla_tpu_memory_space_assignment_while_execution_count = 5ULL; + + // This variable is used to scale the alternate memory benefit factor for + // large buffers. The default scaling function is sqrt. + std::string + xla_tpu_alternate_memory_benefit_scaling_factor_for_large_buffers = + "SQRT"; + + // The window size used to calculate the pipeline overhead when HLO accesses + // the default memory, in MiB. + float pipeline_overhead_window_size_mib = 0; + + float alternate_mem_bandwidth_bytes_per_second = 0.0f; + + float async_copy_bandwidth_bytes_per_second = 0.0f; + + // Scales effective bandwidth for async copies. Valid range is (0, 1]. + float async_copy_bandwidth_scaling_factor = 1.0; +}; + +// A wrapper class around HloCostAnalysis with additional knowledge about the +// bandwidths of different memory spaces. +class CostAnalysis { + public: + // An optional Cache object may be provided to some of the methods below to + // speed up the lookup. + struct Cache { + absl::flat_hash_map while_nest_multiplier; + absl::flat_hash_map memory_boundedness; + }; + + // Function type that can be used to indicate which input/output values are in + // the alternate memory. + using IsInAlternateMemoryFun = absl::FunctionRef /*operand_num*/, const ShapeIndex& /*index*/, + const Shape& /*shape*/)>; + + virtual ~CostAnalysis() = default; + + static StatusOr> Create( + const HloCostAnalysis& cost_analysis, const CostAnalysisOptions& options, + const HloModule& module); + + const HloCostAnalysis& hlo_cost_analysis() const { + return hlo_cost_analysis_; + } + + // Returns a heuristic value that captures how much putting this tensor to the + // alternate memory would help if the op is memory bound, or otherwise how far + // off is the op to memory boundedness. The larger this number, the higher + // priority it will be placed in the alternate memory. + float GetAlternateMemoryBenefit(const HloInstruction& instruction, + float elapsed_time_due_to_alternate_mem, + Cache* cache = nullptr) const; + // Like above, return the benefit of putting the output tensor in the + // alternate memory. + float GetAlternateMemoryBenefit(const HloPosition& position, + Cache* cache = nullptr) const; + // Like above, return the benefit of putting the input tensor in the alternate + // memory. + float GetAlternateMemoryBenefit(const HloUse& use, + Cache* cache = nullptr) const; + + // Returns a heuristic value of memory boundedness for the given + // BufferInterval. The larger this number, the higher priority it will be + // placed in the alternate memory. + float GetMemoryBoundedness( + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval, + Cache* cache = nullptr) const; + + // If enabled in CostAnalysisOptions::pipeline_overhead_window_size_mib, + // returns the overhead of accessing the default memory, in seconds. The + // source of the overhead is the software pipelining ovehead. The lowering of + // the operations typically use tiling to copy one window at a time from + // default memory, and perform compute: + // + // Pipeline overhead: <-> + // +----+----+----+----+ + // Copy from default mem: | | | | | + // +----+----+----+----+ + // \ \ \ \ + // \ \ \ \ + // V V V V + // +--+ +--+ +--+ +--+ + // Compute: | | | | | | | | + // +--+ +--+ +--+ +--+ + float GetDefaultMemoryAccessOverhead( + const HloInstruction& instruction, + absl::Span> + operands_in_alternate_mem = {}, + absl::Span outputs_in_alternate_mem = {}) const; + + // Returns the amount of time the default memory bandwidth is idle, while + // executing this instruction, in seconds. This value can be multiplied with + // the default memory bandwidth to get the amount of bytes that are available + // to be copied to/from default memory during the execution of this + // instruction. + float GetDefaultMemoryBandwidthIdleTime( + const HloInstruction& instruction, + absl::Span> + operands_in_alternate_mem = {}, + absl::Span outputs_in_alternate_mem = {}) const; + + // Returns the bytes accessed from alternate memory. + float GetBytesAccessedFromAlternateMemory( + const HloInstruction& instruction, + absl::Span> + operands_in_alternate_mem = {}, + absl::Span outputs_in_alternate_mem = {}) const; + + // Returns the elapsed time in seconds due to compute only. + float GetInstructionElapsedDueToCompute( + const HloInstruction& instruction) const; + + // Returns the elapsed time in seconds due to memory only. If + // operands_in_alternate_mem or outputs_in_alternate_mem is provided, it will + // assume that the corresponding operands or output will be in the alternate + // memory space. This is useful for calculating the benefit of placing the + // buffer in alternate memory. + float GetInstructionElapsedDueToMemory( + const HloInstruction& instruction, + absl::Span> + operands_in_alternate_mem = {}, + absl::Span outputs_in_alternate_mem = {}) const; + + // Like above, only the inputs/outputs indicated by is_in_alternate_mem are in + // the alternate memory. + float GetInstructionElapsedDueToMemory( + const HloInstruction& instruction, + IsInAlternateMemoryFun is_in_alternate_mem) const; + + // Returns the estimated elapsed duration of the instruction in seconds. It + // assumes all operands and outputs of the instruction are in the default + // memory. + virtual float GetInstructionElapsed(const HloInstruction& instruction) const; + + // Returns the estimated elapsed duration of the instruction in seconds. It + // assumes all operands and outputs of the instruction are in the default + // memory, except for the operands and outputs specified to be in the + // alternate memory. + virtual float GetInstructionElapsedInAlternateMemory( + const HloInstruction& instruction, + absl::Span> + operands_in_alternate_mem, + absl::Span outputs_in_alternate_mem) const; + + // Like above, only the inputs/outputs indicated by is_in_alternate_mem are in + // the alternate memory. + float GetInstructionElapsedInAlternateMemory( + const HloInstruction& instruction, + IsInAlternateMemoryFun is_in_alternate_mem) const; + + // Returns the elapsed time it would take to asynchronously copy the shape + // from default to alternate memory space (or vice versa). + virtual float GetAsyncCopyElapsed(const Shape& shape) const; + + int64_t GetScheduleEndTime() const; + + // Returns the number of nested computation levels this instruction resides + // in. If while_only is true, it returns the while loop nest level and 0 + // means the instruction is not in a while loop. + int CalculateComputationNestLevel(const HloInstruction* instruction, + bool while_only) const; + float GetWhileNestMultiplier(int while_nest_level) const; + + const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; } + + protected: + CostAnalysis(const HloCostAnalysis& hlo_cost_analysis, + const CostAnalysisOptions& options, + std::unique_ptr alias_analysis, + std::unique_ptr hlo_live_range, + std::unique_ptr call_graph) + : hlo_cost_analysis_(hlo_cost_analysis), + options_(options), + alias_analysis_(std::move(alias_analysis)), + hlo_live_range_(std::move(hlo_live_range)), + call_graph_(std::move(call_graph)) {} + + private: + const HloCostAnalysis& hlo_cost_analysis_; + const CostAnalysisOptions options_; + std::unique_ptr alias_analysis_; + std::unique_ptr hlo_live_range_; + std::unique_ptr call_graph_; +}; + +} // namespace memory_space_assignment +} // namespace xla +#endif // XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_COST_ANALYSIS_H_ diff --git a/third_party/xla/xla/service/memory_space_assignment/cost_analysis_test.cc b/third_party/xla/xla/service/memory_space_assignment/cost_analysis_test.cc new file mode 100644 index 00000000000000..54567bb3855d0f --- /dev/null +++ b/third_party/xla/xla/service/memory_space_assignment/cost_analysis_test.cc @@ -0,0 +1,229 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/memory_space_assignment/cost_analysis.h" + +#include +#include + +#include +#include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { + +using memory_space_assignment::CostAnalysis; +using memory_space_assignment::CostAnalysisOptions; + +constexpr int64_t kPointerSize = 8; + +int64_t ShapeSize(const Shape& shape) { + return ShapeUtil::ByteSizeOf(shape, kPointerSize); +} + +class MemorySpaceAssignmentCostAnalysisTest : public HloTestBase { + protected: + Status Initialize(const HloModule* module, + float pipeline_overhead_window_size_mib = 0.0) { + HloCostAnalysis::Options options; + options_.alternate_mem_bandwidth_bytes_per_second = 128; + options_.async_copy_bandwidth_bytes_per_second = 32; + options_.pipeline_overhead_window_size_mib = + pipeline_overhead_window_size_mib; + options.shape_size = ShapeSize; + options.set_flops_per_second(8); + options.set_bytes_per_second(32); + options.set_transcendentals_per_second(16); + hlo_cost_analysis_ = std::make_unique(options); + TF_RETURN_IF_ERROR( + module->entry_computation()->Accept(hlo_cost_analysis_.get())); + TF_ASSIGN_OR_RETURN( + cost_analysis_, + CostAnalysis::Create(*hlo_cost_analysis_, options_, *module)); + return OkStatus(); + } + + CostAnalysisOptions options_; + std::unique_ptr hlo_cost_analysis_; + std::unique_ptr cost_analysis_; +}; + +TEST_F(MemorySpaceAssignmentCostAnalysisTest, NoPipelineOverhead) { + absl::string_view hlo_string = R"( + HloModule module, is_scheduled=true + + ENTRY Entry { + param0 = f32[2,4] parameter(0) + param1 = f32[2,4] parameter(1) + ROOT add = f32[2,4] add(param0, param1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK(Initialize(module.get())); + + const HloInstruction* add = module->entry_computation()->root_instruction(); + const float expected_compute_elapsed = + /*num_flops=*/8 / /*flops_per_second=*/8.0; + LOG(INFO) << "Expected compute elapsed = " << expected_compute_elapsed; + EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToCompute(*add), + expected_compute_elapsed); + float expected_memory_elapsed = + /*bytes_accessed=*/(3 * 4 * 8) / /*bytes_per_second=*/32.0; + LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; + EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToMemory(*add), + expected_memory_elapsed); + + // This HLO is memory-bound. + EXPECT_EQ(cost_analysis_->GetInstructionElapsed(*add), + expected_memory_elapsed); + EXPECT_EQ( + cost_analysis_->GetInstructionElapsedInAlternateMemory(*add, {}, {}), + expected_memory_elapsed); + + // Put operand 0 in alternate memory. Still memory bound. + expected_memory_elapsed = + (/*bytes_accessed=*/(2 * 4 * 8) / /*bytes_per_second=*/32.0) + + (/*bytes_accessed=*/(4 * 8) / /*bytes_per_second=*/128.0); + LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; + EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToMemory(*add, {{0, {}}}), + expected_memory_elapsed); + EXPECT_EQ(cost_analysis_->GetInstructionElapsedInAlternateMemory( + *add, {{0, {}}}, {}), + expected_memory_elapsed); + + // Put operand 0 and output in alternate memory. Still memory bound. + expected_memory_elapsed = + (/*bytes_accessed=*/(4 * 8) / /*bytes_per_second=*/32.0) + + (/*bytes_accessed=*/(2 * 4 * 8) / /*bytes_per_second=*/128.0); + LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; + EXPECT_EQ( + cost_analysis_->GetInstructionElapsedDueToMemory(*add, {{0, {}}}, {{}}), + expected_memory_elapsed); + EXPECT_EQ(cost_analysis_->GetInstructionElapsedInAlternateMemory( + *add, {{0, {}}}, {{}}), + expected_memory_elapsed); + + // Put everything in alternate memory. We're now compute bound. + expected_memory_elapsed = + /*bytes_accessed=*/(3 * 4 * 8) / /*bytes_per_second=*/128.0; + LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; + EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToMemory( + *add, {{0, {}}, {1, {}}}, {{}}), + expected_memory_elapsed); + EXPECT_EQ(cost_analysis_->GetInstructionElapsedInAlternateMemory( + *add, {{0, {}}, {1, {}}}, {{}}), + expected_compute_elapsed); +} + +TEST_F(MemorySpaceAssignmentCostAnalysisTest, PipelineOverhead) { + absl::string_view hlo_string = R"( + HloModule module, is_scheduled=true + + ENTRY Entry { + param0 = f32[2,4] parameter(0) + param1 = f32[2,4] parameter(1) + ROOT add = f32[2,4] add(param0, param1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + // Set the window size 64B. + TF_ASSERT_OK( + Initialize(module.get(), + /*pipeline_overhead_window_size_mib=*/(64.0 / 1024 / 1024))); + + const HloInstruction* add = module->entry_computation()->root_instruction(); + const float expected_compute_elapsed = + /*num_flops=*/8 / /*flops_per_second=*/8.0; + LOG(INFO) << "Expected compute elapsed = " << expected_compute_elapsed; + EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToCompute(*add), + expected_compute_elapsed); + float expected_memory_elapsed = + /*bytes_accessed=*/(3 * 4 * 8) / /*bytes_per_second=*/32.0; + LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; + EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToMemory(*add), + expected_memory_elapsed); + + float expected_overhead = expected_compute_elapsed * 2 / 3; + LOG(INFO) << "Expected overhead = " << expected_overhead; + EXPECT_EQ(cost_analysis_->GetDefaultMemoryAccessOverhead(*add), + expected_overhead); + // This HLO is memory-bound. + EXPECT_EQ(cost_analysis_->GetInstructionElapsed(*add), + expected_memory_elapsed + expected_overhead); + EXPECT_EQ( + cost_analysis_->GetInstructionElapsedInAlternateMemory(*add, {}, {}), + expected_memory_elapsed + expected_overhead); + + // Put operand 0 in alternate memory. Still memory bound. + expected_memory_elapsed = + (/*bytes_accessed=*/(2 * 4 * 8) / /*bytes_per_second=*/32.0) + + (/*bytes_accessed=*/(4 * 8) / /*bytes_per_second=*/128.0); + LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; + EXPECT_EQ(cost_analysis_->GetDefaultMemoryAccessOverhead(*add, {{0, {}}}), + expected_overhead); + EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToMemory(*add, {{0, {}}}), + expected_memory_elapsed); + EXPECT_EQ(cost_analysis_->GetInstructionElapsedInAlternateMemory( + *add, {{0, {}}}, {}), + expected_memory_elapsed + expected_overhead); + + // Put operand 0 and output in alternate memory. Still memory bound. + expected_memory_elapsed = + (/*bytes_accessed=*/(4 * 8) / /*bytes_per_second=*/32.0) + + (/*bytes_accessed=*/(2 * 4 * 8) / /*bytes_per_second=*/128.0); + LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; + expected_overhead = expected_compute_elapsed / 3; + LOG(INFO) << "Expected overhead = " << expected_overhead; + EXPECT_EQ( + cost_analysis_->GetDefaultMemoryAccessOverhead(*add, {{0, {}}}, {{}}), + expected_overhead); + EXPECT_EQ( + cost_analysis_->GetInstructionElapsedDueToMemory(*add, {{0, {}}}, {{}}), + expected_memory_elapsed); + EXPECT_EQ(cost_analysis_->GetInstructionElapsedInAlternateMemory( + *add, {{0, {}}}, {{}}), + expected_memory_elapsed + expected_overhead); + + // Put everything in alternate memory. We're now compute bound. + expected_memory_elapsed = + /*bytes_accessed=*/(3 * 4 * 8) / /*bytes_per_second=*/128.0; + LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; + expected_overhead = 0; + LOG(INFO) << "Expected overhead = " << expected_overhead; + EXPECT_EQ(cost_analysis_->GetDefaultMemoryAccessOverhead( + *add, {{0, {}}, {1, {}}}, {{}}), + expected_overhead); + EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToMemory( + *add, {{0, {}}, {1, {}}}, {{}}), + expected_memory_elapsed); + EXPECT_EQ(cost_analysis_->GetInstructionElapsedInAlternateMemory( + *add, {{0, {}}, {1, {}}}, {{}}), + expected_compute_elapsed); +} + +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc index f021a25ff0a1a5..40ec7d2eaf05ef 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc @@ -42,11 +42,9 @@ limitations under the License. #include "absl/functional/function_ref.h" #include "absl/log/check.h" #include "absl/memory/memory.h" -#include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" -#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "re2/re2.h" @@ -65,18 +63,21 @@ limitations under the License. #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_dataflow_analysis.h" #include "xla/service/hlo_value.h" +#include "xla/service/memory_space_assignment/allocation.h" +#include "xla/service/memory_space_assignment/cost_analysis.h" #include "xla/service/memory_space_assignment/memory_space_assignment.pb.h" #include "xla/service/memory_space_assignment/repacking.h" +#include "xla/service/memory_space_assignment/slice.h" #include "xla/service/memory_space_assignment/tuning_utils.h" #include "xla/service/memory_space_assignment/utils.h" #include "xla/service/time_utils.h" -#include "xla/service/tuple_util.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status.h" #include "xla/status_macros.h" #include "xla/statusor.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/casts.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" @@ -93,13 +94,6 @@ const HeapSimulator::Chunk kDummyChunk = // For cross-program prefetched buffer, we only perform the freeing optimization // if the buffer occupies less of the execution time ratio than this value. const float kCrossProgramPrefetchOccupyFreeingLimit = 0.6; -// Each time we retry compilation, increase the preferred eviction end time by -// this amount multiplied by preferred overlap to async copy ratio. -const float kEvictionRetryMultiplier = 2.0; -// The number of decreasing intervals for CostAnalysisPrefetchIntervalPicker to -// return when it runs out of increasing intervals. Increasing this number may -// hurt compilation time. -const int kNumExploredDecreasingIntervals = 100; template std::string VectorToString(const std::vector& v, @@ -233,6 +227,17 @@ std::vector FindCrossProgramPrefetchUses( bool IsCrossProgramPrefetchCandidate(const HloValue& value, const HloAliasAnalysis& alias_analysis, const Options& options) { + // Filter out values that alias with the entry computation root. + const HloBuffer& buffer = alias_analysis.GetBufferContainingValue(value); + const HloInstruction* root = alias_analysis.dataflow_analysis() + .module() + .entry_computation() + ->root_instruction(); + for (const HloPosition& position : buffer.ComputePositions()) { + if (position.instruction == root) { + return false; + } + } std::vector uses = FindCrossProgramPrefetchUses(value.GetUses(), alias_analysis); return value.defining_instruction()->parent() == @@ -340,18 +345,6 @@ Status InsertInstructionAndEnsureOperandsInserted( return OkStatus(); } -std::string UsesToString(const std::vector& uses) { - if (uses.empty()) { - return "none"; - } - std::vector uses_str; - uses_str.reserve(uses.size()); - for (const auto& use : uses) { - uses_str.push_back(use.ToString()); - } - return absl::StrJoin(uses_str, ","); -} - StatusOr GetScheduleTimeFromInstructionName( absl::string_view name, const absl::flat_hash_map GetAllocationSortTuple( - const std::unique_ptr& allocation) { + const std::unique_ptr& allocation) { int64_t scheduled_on_or_before = allocation->start_time(); int64_t scheduled_on_or_after = allocation->start_time(); if (allocation->is_copy_allocation()) { auto copy_allocation = - tensorflow::down_cast( - allocation.get()); + tensorflow::down_cast(allocation.get()); scheduled_on_or_before = copy_allocation->copy_done_schedule_before(); scheduled_on_or_after = copy_allocation->copy_start_schedule_after(); } @@ -540,12 +532,10 @@ std::tuple GetAllocationSortTuple( void SortAllocationSequence( MemorySpaceAssignment::AllocationSequence& allocations) { - absl::c_sort( - allocations, - [](const std::unique_ptr& lhs, - const std::unique_ptr& rhs) { - return GetAllocationSortTuple(lhs) < GetAllocationSortTuple(rhs); - }); + absl::c_sort(allocations, [](const std::unique_ptr& lhs, + const std::unique_ptr& rhs) { + return GetAllocationSortTuple(lhs) < GetAllocationSortTuple(rhs); + }); } std::string AllocationSequenceToString( @@ -555,8 +545,7 @@ std::string AllocationSequenceToString( SortAllocationSequence(allocations); } std::string allocations_str = "\n"; - for (const std::unique_ptr& allocation : - allocations) { + for (const std::unique_ptr& allocation : allocations) { absl::StrAppend(&allocations_str, allocation->ToString(), "\n"); } return allocations_str; @@ -580,15 +569,12 @@ std::string InstructionScheduleToString(const HloLiveRange& hlo_live_range) { return instruction_schedule_str; } -void EnsureParentAllocationIsAvailableForCopy( - MemorySpaceAssignment::CopyAllocation* copy_allocation) { - MemorySpaceAssignment::Allocation& parent_allocation = - copy_allocation->mutable_prev_allocation(); +void EnsureParentAllocationIsAvailableForCopy(CopyAllocation* copy_allocation) { + Allocation& parent_allocation = copy_allocation->mutable_prev_allocation(); parent_allocation.Extend(copy_allocation->copy_done_schedule_before()); if (parent_allocation.is_copy_allocation()) { auto parent_copy_allocation = - tensorflow::down_cast( - &parent_allocation); + tensorflow::down_cast(&parent_allocation); parent_copy_allocation->set_copy_done_schedule_before( std::min(parent_copy_allocation->copy_done_schedule_before(), copy_allocation->start_time())); @@ -598,8 +584,8 @@ void EnsureParentAllocationIsAvailableForCopy( } } -void MakeCopyAllocationJitForSingleUse( - MemorySpaceAssignment::CopyAllocation* copy_allocation, int64_t use_time) { +void MakeCopyAllocationJitForSingleUse(CopyAllocation* copy_allocation, + int64_t use_time) { copy_allocation->set_start_time(use_time - 1); copy_allocation->set_copy_start_schedule_after(use_time - 1); copy_allocation->set_end_time(use_time); @@ -611,12 +597,10 @@ int64_t GetUseTime(const HloUse& use, const HloLiveRange& hlo_live_range) { return hlo_live_range.instruction_schedule().at(use.instruction); } -std::vector -GetAllocationSequenceInRawPointers( +std::vector GetAllocationSequenceInRawPointers( MemorySpaceAssignment::AllocationSequence& allocations) { - std::vector allocations_in_raw_pointers; - for (const std::unique_ptr& allocation : - allocations) { + std::vector allocations_in_raw_pointers; + for (const std::unique_ptr& allocation : allocations) { allocations_in_raw_pointers.push_back(allocation.get()); } return allocations_in_raw_pointers; @@ -625,14 +609,13 @@ GetAllocationSequenceInRawPointers( void ProcessPrefetchesToAlternateMemory( MemorySpaceAssignment::AllocationSequence& allocations, const HloLiveRange& hlo_live_range) { - std::vector allocations_in_raw_pointers = + std::vector allocations_in_raw_pointers = GetAllocationSequenceInRawPointers(allocations); for (auto allocation : allocations_in_raw_pointers) { if (allocation->is_copy_allocation() && allocation->is_in_alternate_mem() && !allocation->uses().empty()) { - MemorySpaceAssignment::CopyAllocation* prefetch = - tensorflow::down_cast( - allocation); + CopyAllocation* prefetch = + tensorflow::down_cast(allocation); std::vector uses = prefetch->uses(); // Create a copy of uses. prefetch->clear_uses(); // Clear old uses. // For every prefetch, update prefetch to serve earliest use just in time. @@ -644,11 +627,9 @@ void ProcessPrefetchesToAlternateMemory( for (size_t use_index = 1; use_index < uses.size(); ++use_index) { const HloUse& use = uses[use_index]; int64_t use_time = GetUseTime(use, hlo_live_range); - auto jit_single_use_prefetch = - std::make_unique( - prefetch->mutable_prev_allocation(), - MemorySpaceAssignment::MemorySpace::kAlternate, - prefetch->chunk(), use_time - 1, use_time, use_time); + auto jit_single_use_prefetch = std::make_unique( + prefetch->mutable_prev_allocation(), MemorySpace::kAlternate, + prefetch->chunk(), use_time - 1, use_time, use_time); jit_single_use_prefetch->set_copy_start_schedule_after(use_time - 1); jit_single_use_prefetch->AddUse(use); EnsureParentAllocationIsAvailableForCopy(jit_single_use_prefetch.get()); @@ -658,28 +639,21 @@ void ProcessPrefetchesToAlternateMemory( } } -void MakeEvictionImmediate(MemorySpaceAssignment::CopyAllocation* eviction) { - const MemorySpaceAssignment::Allocation& parent_allocation = - eviction->prev_allocation(); +void MakeEvictionImmediate(CopyAllocation* eviction) { + const Allocation& parent_allocation = eviction->prev_allocation(); eviction->set_start_time(parent_allocation.start_time()); eviction->set_copy_start_schedule_after(parent_allocation.start_time()); eviction->set_copy_done_schedule_before(parent_allocation.start_time() + 1); eviction->Extend(parent_allocation.start_time() + 1); } -absl::flat_hash_map -GetEvictionsMap(std::vector& allocations) { - absl::flat_hash_map - evictions_map; +absl::flat_hash_map GetEvictionsMap( + std::vector& allocations) { + absl::flat_hash_map evictions_map; for (auto& allocation : allocations) { if (allocation->is_copy_allocation() && allocation->is_in_default_mem()) { - auto eviction = - tensorflow::down_cast( - allocation); - MemorySpaceAssignment::Allocation& parent_allocation = - eviction->mutable_prev_allocation(); + auto eviction = tensorflow::down_cast(allocation); + Allocation& parent_allocation = eviction->mutable_prev_allocation(); if (!parent_allocation.is_copy_allocation()) { evictions_map[&parent_allocation] = eviction; } @@ -691,13 +665,12 @@ GetEvictionsMap(std::vector& allocations) { void ProcessBuffersProducedInAlternateMemory( MemorySpaceAssignment::AllocationSequence& allocations, const HloLiveRange& hlo_live_range) { - std::vector allocations_in_raw_pointers = + std::vector allocations_in_raw_pointers = GetAllocationSequenceInRawPointers(allocations); // For all parent allocations produced in alternate memory, create a map from // parent allocation -> eviction. - absl::flat_hash_map - evictions_map = GetEvictionsMap(allocations_in_raw_pointers); + absl::flat_hash_map evictions_map = + GetEvictionsMap(allocations_in_raw_pointers); // Make all such evictions immediate. for (auto& [_, eviction] : evictions_map) { MakeEvictionImmediate(eviction); @@ -723,22 +696,19 @@ void ProcessBuffersProducedInAlternateMemory( continue; } if (!evictions_map.contains(allocation)) { - auto eviction_unique_ptr = - std::make_unique( - *allocation, MemorySpaceAssignment::MemorySpace::kDefault, - std::nullopt, allocation->start_time(), - allocation->start_time() + 1, allocation->start_time() + 1); + auto eviction_unique_ptr = std::make_unique( + *allocation, MemorySpace::kDefault, std::nullopt, + allocation->start_time(), allocation->start_time() + 1, + allocation->start_time() + 1); eviction_unique_ptr->set_copy_start_schedule_after( allocation->start_time()); evictions_map[allocation] = eviction_unique_ptr.get(); allocations.push_back(std::move(eviction_unique_ptr)); } - MemorySpaceAssignment::CopyAllocation* eviction = - evictions_map[allocation]; - auto jit_single_use_prefetch = - std::make_unique( - *eviction, MemorySpaceAssignment::MemorySpace::kAlternate, - allocation->chunk(), use_time - 1, use_time, use_time); + CopyAllocation* eviction = evictions_map[allocation]; + auto jit_single_use_prefetch = std::make_unique( + *eviction, MemorySpace::kAlternate, allocation->chunk(), + use_time - 1, use_time, use_time); jit_single_use_prefetch->set_copy_start_schedule_after(use_time - 1); jit_single_use_prefetch->AddUse(use); EnsureParentAllocationIsAvailableForCopy(jit_single_use_prefetch.get()); @@ -763,883 +733,9 @@ void TransformAllocationSequenceToSpill( XLA_LOG_LINES(2, AllocationSequenceToString(allocations, true)); SortAllocationSequence(allocations); } -} // namespace - -/*static*/ StatusOr> -MemorySpaceAssignmentCostAnalysis::Create(const HloCostAnalysis& cost_analysis, - const Options& options, - const HloModule& module) { - TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module)); - TF_ASSIGN_OR_RETURN(auto hlo_live_range, - HloLiveRange::Run(module.schedule(), *alias_analysis, - module.entry_computation())); - auto call_graph = CallGraph::Build(&module); - return absl::WrapUnique(new MemorySpaceAssignmentCostAnalysis( - cost_analysis, options, std::move(alias_analysis), - std::move(hlo_live_range), std::move(call_graph))); -} - -float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit( - const HloInstruction& instruction, float elapsed_time_due_to_alternate_mem, - MemorySpaceAssignmentCostAnalysis::Cache* cache) const { - float elapsed_time_due_to_compute = - GetInstructionElapsedDueToCompute(instruction); - float elapsed_time_due_to_memory = - GetInstructionElapsedDueToMemory(instruction); - if (elapsed_time_due_to_memory > elapsed_time_due_to_compute) { - // Memory bound, return how much alternate memory is better. - float while_nest_multiplier; - if (cache) { - // If there is a cache provided, memoize the while nest multiplier. - auto it = cache->while_nest_multiplier.find(&instruction); - if (it != cache->while_nest_multiplier.end()) { - while_nest_multiplier = it->second; - } else { - while_nest_multiplier = IPow( - options_.xla_tpu_memory_space_assignment_while_execution_count, - CalculateComputationNestLevel(&instruction, - /*while_only=*/true)); - cache->while_nest_multiplier[&instruction] = while_nest_multiplier; - } - } else { - while_nest_multiplier = IPow( - options_.xla_tpu_memory_space_assignment_while_execution_count, - CalculateComputationNestLevel(&instruction, - /*while_only=*/true)); - } - return (elapsed_time_due_to_memory - elapsed_time_due_to_alternate_mem) * - while_nest_multiplier; - } else { - // TODO(b/317935037): Multiply with while nest multiplier and test for - // speedup. - // Compute bound, return how far off are we to memory boundedness. - return elapsed_time_due_to_memory - elapsed_time_due_to_compute; - } -} -float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval, - MemorySpaceAssignmentCostAnalysis::Cache* cache) const { - if (cache) { - auto it = - cache->memory_boundedness.find(interval.buffer->defining_position()); - if (it != cache->memory_boundedness.end()) { - return it->second; - } - } - float alternate_mem_benefit = - GetAlternateMemoryBenefit(interval.buffer->defining_position(), cache); - - for (const HloBuffer* buffer : alias_analysis_->ComputeBuffersAt( - interval.buffer->defining_position().instruction, - interval.buffer->defining_position().index)) { - for (const HloValue* value : buffer->values()) { - for (const HloUse& use : value->GetUses()) { - // We look inside the called computations of while and conditional, so - // don't use the benefit of while and conditional directly. - if (use.instruction->opcode() == HloOpcode::kWhile || - use.instruction->opcode() == HloOpcode::kConditional) { - continue; - } - float use_alternate_mem_benefit = GetAlternateMemoryBenefit(use, cache); - // If the benefit is positive (memory bound), add it to this buffer's - // benefit. If the benefit is negative (compute bound), calculate the - // maximum. - if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) { - alternate_mem_benefit += use_alternate_mem_benefit; - } else { - alternate_mem_benefit = - std::max(alternate_mem_benefit, use_alternate_mem_benefit); - } - } - } - } - - // Penalize larger buffers by dividing the benefit by the square root of - // the size. Empirically, we observed this resulted in better performance - // compared to dividing by the size. - float memory_boundedness = 1; - if (options_ - .xla_tpu_alternate_memory_benefit_scaling_factor_for_large_buffers == - "NO_SCALE") { - memory_boundedness = alternate_mem_benefit; - } else { - memory_boundedness = alternate_mem_benefit / std::sqrt(interval.size); - } - - if (cache) { - cache->memory_boundedness[interval.buffer->defining_position()] = - memory_boundedness; - } - return memory_boundedness; -} - -float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit( - const HloPosition& position, - MemorySpaceAssignmentCostAnalysis::Cache* cache) const { - return GetAlternateMemoryBenefit( - *position.instruction, - GetInstructionElapsedDueToMemory( - *position.instruction, - /*operands_in_alternate_mem=*/{}, - /*outputs_in_alternate_mem=*/{position.index}), - cache); -} - -float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit( - const HloUse& use, MemorySpaceAssignmentCostAnalysis::Cache* cache) const { - return GetAlternateMemoryBenefit( - *use.instruction, - GetInstructionElapsedDueToMemory( - *use.instruction, - /*operands_in_alternate_mem=*/{std::make_pair(use.operand_number, - use.operand_index)}), - cache); -} - -int MemorySpaceAssignmentCostAnalysis::CalculateComputationNestLevel( - const HloInstruction* instruction, bool while_only) const { - int nest_level = 0; - const HloComputation* computation = instruction->parent(); - while (!computation->IsEntryComputation()) { - auto& node = call_graph_->GetNode(computation); - auto callsites = node.caller_callsites(); - CHECK(node.computation()->IsAsyncComputation() || callsites.size() == 1) - << "The module is not flattened!"; - auto& callsite = callsites[0]; - if (!while_only || callsite.instruction()->opcode() == HloOpcode::kWhile) { - ++nest_level; - } - computation = callsite.instruction()->parent(); - } - return nest_level; -} - -float MemorySpaceAssignmentCostAnalysis::GetDefaultMemoryAccessOverhead( - const HloInstruction& instruction, - absl::Span> operands_in_alternate_mem, - absl::Span outputs_in_alternate_mem) const { - // Calculate the pipeline overhead of accessing the default memory. We use the - // maximum of the window size heuristic and the actual default memory bytes - // accessed multiplied with the compute as the overhead. So, the math is: - // - // overhead = compute_per_iteration - // = compute_elapsed / num_iterations - // = compute_elapsed / (bytes_accessed / window_size) - // = (window_size / bytes_accessed) * compute_elapsed - const float window_size_bytes = - options_.pipeline_overhead_window_size_mib * 1024 * 1024; - const float bytes_accessed = cost_analysis_.bytes_accessed(instruction); - const float default_memory_bytes_accessed = - bytes_accessed - - GetBytesAccessedFromAlternateMemory( - instruction, operands_in_alternate_mem, outputs_in_alternate_mem); - const float compute_elapsed = GetInstructionElapsedDueToCompute(instruction); - const float effective_window_size_bytes = - std::min(window_size_bytes, default_memory_bytes_accessed); - float overhead = 0; - if (bytes_accessed > 0) { - overhead = (effective_window_size_bytes / bytes_accessed) * compute_elapsed; - } - return overhead; -} - -float MemorySpaceAssignmentCostAnalysis::GetDefaultMemoryBandwidthIdleTime( - const HloInstruction& instruction, - absl::Span> operands_in_alternate_mem, - absl::Span outputs_in_alternate_mem) const { - const float default_memory_bytes_accessed = - cost_analysis_.bytes_accessed(instruction) - - GetBytesAccessedFromAlternateMemory( - instruction, operands_in_alternate_mem, outputs_in_alternate_mem); - const float elapsed_due_to_default_mem = - default_memory_bytes_accessed / - cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey); - const float elapsed = GetInstructionElapsedInAlternateMemory( - instruction, operands_in_alternate_mem, outputs_in_alternate_mem); - return elapsed - elapsed_due_to_default_mem; -} - -float MemorySpaceAssignmentCostAnalysis::GetBytesAccessedFromAlternateMemory( - const HloInstruction& instruction, - absl::Span> operands_in_alternate_mem, - absl::Span outputs_in_alternate_mem) const { - float bytes_accessed_from_alternate_mem = 0.0; - for (auto& operand : operands_in_alternate_mem) { - const float operand_bytes_accessed = cost_analysis_.operand_bytes_accessed( - instruction, operand.first, operand.second); - bytes_accessed_from_alternate_mem += operand_bytes_accessed; - } - - for (auto& shape_idx : outputs_in_alternate_mem) { - const float output_bytes_accessed = - cost_analysis_.output_bytes_accessed(instruction, shape_idx); - bytes_accessed_from_alternate_mem += output_bytes_accessed; - } - return bytes_accessed_from_alternate_mem; -} - -namespace { -// Returns true on async instructions since we assume they are already -// efficiently scheduled such that they are not in the critical path and appear -// to take no time. -bool ExcludeInstructionFromElapsed(const HloInstruction& instruction) { - return instruction.opcode() == HloOpcode::kAllGatherStart || - instruction.opcode() == HloOpcode::kAllGatherDone || - instruction.opcode() == HloOpcode::kAllReduceStart || - instruction.opcode() == HloOpcode::kAllReduceDone || - instruction.opcode() == HloOpcode::kAsyncStart || - instruction.opcode() == HloOpcode::kAsyncDone || - instruction.opcode() == HloOpcode::kCollectivePermuteStart || - instruction.opcode() == HloOpcode::kCollectivePermuteDone || - instruction.opcode() == HloOpcode::kCopyStart || - instruction.opcode() == HloOpcode::kCopyDone; -} } // namespace -float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToCompute( - const HloInstruction& instruction) const { - if (ExcludeInstructionFromElapsed(instruction)) { - return 0.0f; - } - return std::max( - cost_analysis_.flop_count(instruction) / - cost_analysis_.per_second_rate(HloCostAnalysis::kFlopsKey), - cost_analysis_.transcendental_count(instruction) / - cost_analysis_.per_second_rate(HloCostAnalysis::kTranscendentalsKey)); -} - -float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToMemory( - const HloInstruction& instruction, - absl::Span> operands_in_alternate_mem, - absl::Span outputs_in_alternate_mem) const { - if (ExcludeInstructionFromElapsed(instruction)) { - return 0.0f; - } - float total_bytes_accessed = cost_analysis_.bytes_accessed(instruction); - float bytes_accessed_from_alternate_mem = GetBytesAccessedFromAlternateMemory( - instruction, operands_in_alternate_mem, outputs_in_alternate_mem); - float elapsed_due_to_alternate_mem = - bytes_accessed_from_alternate_mem / - options().alternate_mem_bandwidth_bytes_per_second; - float elapsed_due_to_default_mem = - (total_bytes_accessed - bytes_accessed_from_alternate_mem) / - cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey); - return elapsed_due_to_alternate_mem + elapsed_due_to_default_mem; -} - -float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToMemory( - const HloInstruction& instruction, - IsInAlternateMemoryFun is_in_alternate_mem) const { - if (ExcludeInstructionFromElapsed(instruction)) { - return 0.0f; - } - float total_bytes_accessed = cost_analysis_.bytes_accessed(instruction); - float bytes_accessed_from_alternate_mem = 0.0; - for (int operand_num = 0; operand_num < instruction.operand_count(); - ++operand_num) { - ShapeUtil::ForEachSubshape( - instruction.operand(operand_num)->shape(), - [&](const Shape& subshape, const ShapeIndex& index) { - if (!subshape.IsArray()) { - return; - } - if (is_in_alternate_mem(operand_num, index, subshape)) { - bytes_accessed_from_alternate_mem += - cost_analysis_.operand_bytes_accessed(instruction, operand_num, - index); - } - }); - } - ShapeUtil::ForEachSubshape(instruction.shape(), [&](const Shape& subshape, - const ShapeIndex& index) { - if (!subshape.IsArray()) { - return; - } - if (is_in_alternate_mem(/*operand_num=*/std::nullopt, index, subshape)) { - bytes_accessed_from_alternate_mem += - cost_analysis_.output_bytes_accessed(instruction, index); - } - }); - float elapsed_due_to_alternate_mem = - bytes_accessed_from_alternate_mem / - options().alternate_mem_bandwidth_bytes_per_second; - float elapsed_due_to_default_mem = - (total_bytes_accessed - bytes_accessed_from_alternate_mem) / - cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey); - return elapsed_due_to_alternate_mem + elapsed_due_to_default_mem; -} - -float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsed( - const HloInstruction& instruction) const { - if (ExcludeInstructionFromElapsed(instruction)) { - return 0.0f; - } - float overhead = GetDefaultMemoryAccessOverhead(instruction); - return std::max(GetInstructionElapsedDueToCompute(instruction), - GetInstructionElapsedDueToMemory(instruction) + overhead); -} - -float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedInAlternateMemory( - const HloInstruction& instruction, - absl::Span> operands_in_alternate_mem, - absl::Span outputs_in_alternate_mem) const { - if (ExcludeInstructionFromElapsed(instruction)) { - return 0.0f; - } - float overhead = GetDefaultMemoryAccessOverhead( - instruction, operands_in_alternate_mem, outputs_in_alternate_mem); - return std::max( - GetInstructionElapsedDueToCompute(instruction), - GetInstructionElapsedDueToMemory(instruction, operands_in_alternate_mem, - outputs_in_alternate_mem) + - overhead); -} - -float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedInAlternateMemory( - const HloInstruction& instruction, - IsInAlternateMemoryFun is_in_alternate_mem) const { - if (ExcludeInstructionFromElapsed(instruction)) { - return 0.0f; - } - return std::max( - GetInstructionElapsedDueToCompute(instruction), - GetInstructionElapsedDueToMemory(instruction, is_in_alternate_mem)); -} - -float MemorySpaceAssignmentCostAnalysis::GetAsyncCopyElapsed( - const Shape& shape) const { - int64_t size_in_bytes = cost_analysis_.GetShapeSize(shape); - return static_cast(size_in_bytes) / - (options().async_copy_bandwidth_bytes_per_second * - options().async_copy_bandwidth_scaling_factor); -} - -int64_t MemorySpaceAssignmentCostAnalysis::GetScheduleEndTime() const { - return hlo_live_range_->schedule_end_time(); -} - -bool InstructionCountPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy( - const Shape& shape, int64_t start_time, int64_t end_time) const { - return end_time - start_time <= max_overlap_count_; -} - -int64_t InstructionCountPrefetchIntervalPicker::PreferredEvictionEndTime( - const Shape& shape, int64_t start_time, int64_t latest_end_time) const { - return std::min(start_time + min_overlap_count_, latest_end_time); -} - -int64_t InstructionCountPrefetchIntervalPicker::LatestPrefetchStartTime( - const Shape& shape, int64_t start_time, int64_t end_time, - const HloUse* use) const { - return end_time - min_overlap_count_; -} - -int64_t InstructionCountPrefetchIntervalPicker::PreferredPrefetchStartTime( - const Shape& shape, int64_t earliest_prefetch_start_time, - int64_t latest_prefetch_start_time, int64_t prefetch_end_time) const { - return std::max(earliest_prefetch_start_time, - prefetch_end_time - max_overlap_count_); -} - -int64_t InstructionCountPrefetchIntervalPicker::EstimatedPrefetchEndTime( - const Shape& shape, int64_t start_time, int64_t end_time) const { - // For testing, assume the end time is the estimated prefetch end time. - return end_time; -} - -float InstructionCountPrefetchIntervalPicker::GetLogicalIntervalElapsed( - int64_t start_time, int64_t end_time) const { - // For testing, just assume every HLO takes 1 second. - return static_cast(end_time - start_time - 1); -} - -void InstructionCountPrefetchIntervalPicker::Begin( - const HloUse& use, int64_t start_time, int64_t end_time, - std::optional preferred_time) { - end_time_ = end_time; - const Shape& shape = ShapeUtil::GetSubshape( - use.instruction->operand(use.operand_number)->shape(), use.operand_index); - if (preferred_time) { - current_prefetch_time_ = *preferred_time; - } else { - current_prefetch_time_ = - PreferredPrefetchStartTime(shape, start_time, end_time, end_time); - } -} - -int64_t InstructionCountPrefetchIntervalPicker::Next() { - CHECK(!Done()) << "Prefetch interval picker's Next() is called even though " - "Done() is false"; - return current_prefetch_time_++; -} - -bool InstructionCountPrefetchIntervalPicker::Done() const { - return end_time_ - current_prefetch_time_ <= min_overlap_count_; -} - -int64_t InstructionCountPrefetchIntervalPicker::latest_time() const { - return end_time_ - min_overlap_count_ - 1; -} - -std::string InstructionCountPrefetchIntervalPicker::ToDebugString() const { - return absl::StrCat("Overlapped HLOs = ", end_time_ - current_prefetch_time_); -} - -std::string InstructionCountPrefetchIntervalPicker::ToNoCopyDebugString( - const Shape& shape, int64_t start_time, int64_t end_time) const { - return absl::StrCat("Overlapped HLOs = ", end_time - start_time); -} - -CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker( - const MemorySpaceAssignmentCostAnalysis& cost_analysis, - float min_overlap_to_async_copy_ratio, - float preferred_overlap_to_async_copy_ratio, - float max_overlap_to_mem_size_async_copy_ratio, int64_t mem_size_bytes, - const Shape* shape_override) - : while_nest_level_( - cost_analysis.hlo_live_range().instruction_schedule().size() + 1, 0), - computation_nest_level_( - cost_analysis.hlo_live_range().instruction_schedule().size() + 1, 0), - cost_analysis_(cost_analysis), - min_overlap_to_async_copy_ratio_(min_overlap_to_async_copy_ratio), - preferred_overlap_to_async_copy_ratio_( - preferred_overlap_to_async_copy_ratio), - max_async_copy_elapsed_( - cost_analysis_.GetAsyncCopyElapsed( - ShapeUtil::MakeShape(S32, {mem_size_bytes / 4})) * - max_overlap_to_mem_size_async_copy_ratio), - shape_override_(shape_override ? std::optional(*shape_override) - : std::nullopt) { - instruction_schedule_ = - &cost_analysis_.hlo_live_range().instruction_schedule(); - - // Create a vector of elapsed times and while nesting levels of HLO - // instructions. The elapsed times are multiplied by - // pow(while_execution_count, nest_level) to account for executing the HLOs - // multiple times in while loops. - std::vector instructions_elapsed_time( - instruction_schedule_->size() + 1, 0.0); - int max_while_nest_level = 0; - for (const auto& instruction_and_logical_time : *instruction_schedule_) { - // To avoid double counting, don't include the elapsed time of while and - // conditional HLOs. - const HloInstruction* instruction = instruction_and_logical_time.first; - int64_t logical_time = instruction_and_logical_time.second; - if (logical_time >= instructions_elapsed_time.size()) { - instructions_elapsed_time.resize(logical_time + 1, 0.0); - while_nest_level_.resize(logical_time + 1, 0); - } - int while_nest_level = cost_analysis_.CalculateComputationNestLevel( - instruction_and_logical_time.first, /*while_only=*/true); - while_nest_level_[logical_time] = while_nest_level; - max_while_nest_level = std::max(max_while_nest_level, while_nest_level); - int computation_nest_level = cost_analysis_.CalculateComputationNestLevel( - instruction_and_logical_time.first, /*while_only=*/false); - computation_nest_level_[logical_time] = computation_nest_level; - if (instruction->opcode() == HloOpcode::kWhile || - instruction->opcode() == HloOpcode::kConditional) { - continue; - } - float elapsed_time = cost_analysis_.GetInstructionElapsed( - *instruction_and_logical_time.first); - instructions_elapsed_time[logical_time] = - elapsed_time * - IPow(cost_analysis_.options() - .xla_tpu_memory_space_assignment_while_execution_count, - while_nest_level); - } - // As an optimization, create a cumulative sum vector of elapsed time. - float cumsum = 0.0; - elapsed_time_cumsum_.reserve(instructions_elapsed_time.size()); - for (float elapsed_time : instructions_elapsed_time) { - cumsum += elapsed_time; - elapsed_time_cumsum_.push_back(cumsum); - } - // To be able to accurately determine the minimum nest level between a start - // time and an end time efficiently, populate a data structure that stores the - // closest 'smaller' nest level change index. - const int64_t size = instructions_elapsed_time.size(); - CHECK_EQ(size, while_nest_level_.size()); - std::vector most_recent_by_level(while_nest_level_.size(), -1); - int prev_nest_level = 0; - int change_idx = -1; - while_nest_level_change_.reserve(size); - for (int i = 0; i < size; ++i) { - int nest_level = while_nest_level_[i]; - if (nest_level != prev_nest_level) { - prev_nest_level = nest_level; - // Compute last change index by choosing the most recent instruction index - // with smaller nesting level. Note that it may happen that even though - // there were few different regions with other nest levels before, all of - // then are same or bigger than this one, in which case we'll end up with - // -1, e.g. if you got nest level 0 no need checking anything else. - change_idx = -1; - for (int smaller_level = 0; smaller_level < nest_level; smaller_level++) { - change_idx = std::max(change_idx, most_recent_by_level[smaller_level]); - } - } - most_recent_by_level[nest_level] = i; - while_nest_level_change_.push_back(change_idx); - } - for (int i = 0; i <= max_while_nest_level; ++i) { - while_execution_counts_.push_back( - IPow(cost_analysis_.options() - .xla_tpu_memory_space_assignment_while_execution_count, - i)); - } -} - -float CostAnalysisPrefetchIntervalPicker::GetMaxElapsedInAlternateMemory( - float async_copy_elapsed) const { - return max_async_copy_elapsed_; -} - -bool CostAnalysisPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy( - const Shape& shape, int64_t start_time, int64_t end_time) const { - // Even though this method returns if we allow the buffer in alternate memory - // _without_ asynchronous copies, calculate how long it would have taken to - // copy it and compare it to the elapsed time in the logical interval. - float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed( - shape_override_ ? *shape_override_ : shape); - float logical_interval_elapsed = - GetLogicalIntervalElapsed(start_time, end_time); - return GetMaxElapsedInAlternateMemory(async_copy_elapsed) > - logical_interval_elapsed; -} - -int64_t CostAnalysisPrefetchIntervalPicker::PreferredEvictionEndTime( - const Shape& shape, int64_t start_time, int64_t latest_end_time) const { - float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed( - shape_override_ ? *shape_override_ : shape); - int64_t end_time; - for (end_time = start_time + 1; end_time <= latest_end_time; ++end_time) { - float logical_interval_elapsed = - GetLogicalIntervalElapsed(start_time, end_time); - if (logical_interval_elapsed >= - (1 + kEvictionRetryMultiplier * retry_number_) * - preferred_overlap_to_async_copy_ratio_ * async_copy_elapsed) { - break; - } - } - return end_time; -} - -int64_t CostAnalysisPrefetchIntervalPicker::LatestPrefetchStartTime( - const Shape& shape, int64_t start_time, int64_t end_time, - const HloUse* use) const { - // Find the earliest time that satisfies max_overlap_to_async_copy_ratio_. - float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed( - shape_override_ ? *shape_override_ : shape); - // If there is a use, estimate the time we would save by having this op in - // alternate memory. - float inst_elapsed_reduction = 0.0f; - if (use) { - float elapsed_time = - cost_analysis_.GetInstructionElapsed(*use->instruction); - float elapsed_time_in_alternate_mem = - cost_analysis_.GetInstructionElapsedInAlternateMemory( - *use->instruction, - /*operands_in_alternate_mem=*/ - {std::make_pair(use->operand_number, use->operand_index)}, - /*outputs_in_alternate_mem=*/{}); - inst_elapsed_reduction = elapsed_time - elapsed_time_in_alternate_mem; - } - int end_nest_level = computation_nest_level_[end_time]; - - // Find the latest time we're allowed to start prefetching. - float min_interval = min_overlap_to_async_copy_ratio_ * async_copy_elapsed; - int latest_prefetch_time; - for (latest_prefetch_time = end_time - 1; - latest_prefetch_time >= start_time && - (computation_nest_level_[latest_prefetch_time] != end_nest_level || - min_interval > - GetLogicalIntervalElapsed(latest_prefetch_time, end_time) + - inst_elapsed_reduction); - --latest_prefetch_time) { - } - - return latest_prefetch_time; -} - -int64_t CostAnalysisPrefetchIntervalPicker::PreferredPrefetchStartTime( - const Shape& shape, int64_t earliest_prefetch_start_time, - int64_t latest_prefetch_start_time, int64_t prefetch_end_time) const { - // Between the earliest and latest prefetch interval, find the interval - // closest to the preferred interval and start iterating from there. - float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed( - shape_override_ ? *shape_override_ : shape); - int64_t preferred_prefetch_start_time = earliest_prefetch_start_time; - float preferred_interval = - preferred_overlap_to_async_copy_ratio_ * async_copy_elapsed; - float best_interval = GetLogicalIntervalElapsed(earliest_prefetch_start_time, - prefetch_end_time); - int end_nest_level = computation_nest_level_[prefetch_end_time]; - for (int64_t prefetch_start_time = earliest_prefetch_start_time + 1; - prefetch_start_time <= latest_prefetch_start_time; - ++prefetch_start_time) { - float interval = - GetLogicalIntervalElapsed(prefetch_start_time, prefetch_end_time); - if (computation_nest_level_[prefetch_start_time] == end_nest_level && - std::abs(preferred_interval - interval) < - std::abs(preferred_interval - best_interval)) { - best_interval = interval; - preferred_prefetch_start_time = prefetch_start_time; - } - } - return preferred_prefetch_start_time; -} - -int64_t CostAnalysisPrefetchIntervalPicker::LatestPrefetchEndTime( - int64_t original_prefetch_end_time, - int64_t proposed_prefetch_end_time) const { - // Iterate towards the beginning until we find a suitable end time that is the - // same while nest level as the original prefetch end time. - int64_t original_nest_level = - computation_nest_level_[original_prefetch_end_time]; - int64_t new_prefetch_end_time; - for (new_prefetch_end_time = proposed_prefetch_end_time; - computation_nest_level_[new_prefetch_end_time] != original_nest_level; - --new_prefetch_end_time) { - } - return new_prefetch_end_time; -} - -int64_t CostAnalysisPrefetchIntervalPicker::EstimatedPrefetchEndTime( - const Shape& shape, int64_t start_time, int64_t end_time) const { - float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed( - shape_override_ ? *shape_override_ : shape); - int64_t estimated_end_time; - for (estimated_end_time = start_time + 1; estimated_end_time < end_time; - ++estimated_end_time) { - float interval = GetLogicalIntervalElapsed(start_time, estimated_end_time); - if (interval >= async_copy_elapsed) { - break; - } - } - return estimated_end_time; -} - -void CostAnalysisPrefetchIntervalPicker::Begin( - const HloUse& use, int64_t start_time, int64_t end_time, - std::optional preferred_time) { - const Shape& shape = ShapeUtil::GetSubshape( - use.instruction->operand(use.operand_number)->shape(), use.operand_index); - // Find the earliest time that satisfies max_overlap_to_async_copy_ratio_. - async_copy_elapsed_ = cost_analysis_.GetAsyncCopyElapsed( - shape_override_ ? *shape_override_ : shape); - // Estimate the time we would save by having this op in alternate memory. - float elapsed_time = cost_analysis_.GetInstructionElapsed(*use.instruction); - float elapsed_time_in_alternate_mem = - cost_analysis_.GetInstructionElapsedInAlternateMemory( - *use.instruction, /*operands_in_alternate_mem=*/ - {std::make_pair(use.operand_number, use.operand_index)}, - /*outputs_in_alternate_mem=*/{}); - inst_elapsed_reduction_ = elapsed_time - elapsed_time_in_alternate_mem; - end_logical_time_ = end_time; - int end_nest_level = computation_nest_level_[end_logical_time_]; - - // Find the latest time we're allowed to start prefetching. - float min_interval = min_overlap_to_async_copy_ratio_ * async_copy_elapsed_; - latest_prefetch_time_ = - LatestPrefetchStartTime(shape, start_time, end_time, &use); - - // Find the earliest time we're allowed to start prefetching. - float max_interval = GetMaxElapsedInAlternateMemory(async_copy_elapsed_); - for (earliest_prefetch_time_ = start_time; - earliest_prefetch_time_ < latest_prefetch_time_ && - (computation_nest_level_[earliest_prefetch_time_] != end_nest_level || - max_interval < GetLogicalIntervalElapsed(earliest_prefetch_time_, - end_logical_time_)); - ++earliest_prefetch_time_) { - } - if (earliest_prefetch_time_ > latest_prefetch_time_) { - // There is no available prefetch interval for the given start and end - // times. Set the iterators accordingly to ensure Done() returns true. - increasing_prefetch_time_iterator_ = earliest_prefetch_time_; - decreasing_prefetch_time_iterator_ = latest_prefetch_time_; - CHECK(Done()); - return; - } - - int64_t starting_prefetch_time; - if (preferred_time && *preferred_time <= latest_prefetch_time_) { - starting_prefetch_time = *preferred_time; - } else { - starting_prefetch_time = - PreferredPrefetchStartTime(shape, earliest_prefetch_time_, - latest_prefetch_time_, end_logical_time_); - } - float preferred_interval = - preferred_overlap_to_async_copy_ratio_ * async_copy_elapsed_; - VLOG(4) << "Interval min/max/preferred = " << min_interval << " " - << max_interval << " " << preferred_interval - << " prefetch time earliest/latest/starting = " - << earliest_prefetch_time_ << " " << latest_prefetch_time_ << " " - << starting_prefetch_time; - - increasing_prefetch_time_iterator_ = starting_prefetch_time; - decreasing_prefetch_time_iterator_ = starting_prefetch_time; - using_increasing_prefetch_time_iterator_ = true; - // Since both iterators start at the same position, call Next() once to - // advance one of the iterators. - Next(); -} - -int64_t CostAnalysisPrefetchIntervalPicker::Next() { - CHECK(!Done()) << "Prefetch interval picker's Next() is called even though " - "Done() is false"; - if (using_increasing_prefetch_time_iterator_) { - int64_t prefetch_time = increasing_prefetch_time_iterator_++; - while (increasing_prefetch_time_iterator_ <= latest_prefetch_time_ && - computation_nest_level_[increasing_prefetch_time_iterator_] != - computation_nest_level_[end_logical_time_]) { - ++increasing_prefetch_time_iterator_; - } - if (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_) { - using_increasing_prefetch_time_iterator_ = false; - } - return prefetch_time; - } else { - int64_t prefetch_time = decreasing_prefetch_time_iterator_--; - // As a compilation time optimization, reduce the number of intervals that - // this prefetch interval picker returns. When we run out of the increasing - // prefetch time iterator, only explore up to - // kNumExploredDecreasingIntervals intervals. To do that, calculate the - // 1/kNumExploredDecreasingIntervals of the elapsed time between the - // earliest prefetch time and the use, and decrement the iterator until the - // prefetch elapsed time is at least as large as this target value. This - // allows us to reduce the number of expensive heap fit and resource checks - // when the graph consists of a large number of fast-executing HLOs. - // - // Shown pictorially, assuming kNumExploredDecreasingIntervals = 3 and the - // numbers indicating the elapsed time of the HLOs, only the indicated - // options for prefetch start time would be explored: - // - // ---1---1---3---1---1---1---1---0---0---0---0---1---5---X - // ^ ^ ^ ^ - // Option3 Option2 Option1 Use - // (Earliest) - float next_target_interval_elapsed = 0; - if (increasing_prefetch_time_iterator_ > latest_prefetch_time_) { - next_target_interval_elapsed = - GetLogicalIntervalElapsed(prefetch_time, end_logical_time_) + - (GetLogicalIntervalElapsed(earliest_prefetch_time_, - end_logical_time_) / - kNumExploredDecreasingIntervals); - VLOG(3) << "Next target interval elapsed: " - << next_target_interval_elapsed; - } - while (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_ && - (computation_nest_level_[decreasing_prefetch_time_iterator_] != - computation_nest_level_[end_logical_time_] || - GetLogicalIntervalElapsed(decreasing_prefetch_time_iterator_, - end_logical_time_) < - next_target_interval_elapsed)) { - --decreasing_prefetch_time_iterator_; - } - if (increasing_prefetch_time_iterator_ <= latest_prefetch_time_) { - using_increasing_prefetch_time_iterator_ = true; - } - return prefetch_time; - } -} - -bool CostAnalysisPrefetchIntervalPicker::Done() const { - return increasing_prefetch_time_iterator_ > latest_prefetch_time_ && - decreasing_prefetch_time_iterator_ < earliest_prefetch_time_; -} - -int64_t CostAnalysisPrefetchIntervalPicker::latest_time() const { - return latest_prefetch_time_; -} - -void CostAnalysisPrefetchIntervalPicker::SetRetryNumber(int retry_number) { - retry_number_ = retry_number; -} - -int CostAnalysisPrefetchIntervalPicker::GetMinWhileNestLevel( - int64_t start_time, int64_t end_time) const { - int min_nest_level = - std::min(while_nest_level_[start_time], while_nest_level_[end_time]); - int change_idx = while_nest_level_change_[end_time]; - while (change_idx >= start_time) { - min_nest_level = std::min(min_nest_level, while_nest_level_[change_idx]); - change_idx = while_nest_level_change_[change_idx]; - } - return min_nest_level; -} - -float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed( - int64_t start_time, int64_t end_time) const { - CHECK_LE(start_time, end_time); - if (start_time == end_time) { - return 0.0; - } - if (start_time < 0) { - start_time = 0; - } - // Since elapsed_time_cumsum_ is already weighed by the while loop nesting - // level, normalize the elapsed time by dividing with the nesting factor of - // the interval (start and end times). - int interval_while_nest_level = GetMinWhileNestLevel(start_time, end_time); - return (elapsed_time_cumsum_[end_time - 1] - - elapsed_time_cumsum_[start_time]) / - while_execution_counts_[interval_while_nest_level]; -} - -std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const { - int current_logical_prefetch_time = using_increasing_prefetch_time_iterator_ - ? increasing_prefetch_time_iterator_ - : decreasing_prefetch_time_iterator_; - float logical_interval_elapsed = GetLogicalIntervalElapsed( - current_logical_prefetch_time, end_logical_time_); - return absl::StrCat( - "Async copy elapsed (s) = ", async_copy_elapsed_, - ", inst elapsed reduction (s) = ", inst_elapsed_reduction_, - ", logical interval elapsed (s) = ", logical_interval_elapsed, - ", interval = (", current_logical_prefetch_time, ", ", end_logical_time_, - ")"); -} - -std::string CostAnalysisPrefetchIntervalPicker::ToNoCopyDebugString( - const Shape& shape, int64_t start_time, int64_t end_time) const { - float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed( - shape_override_ ? *shape_override_ : shape); - float logical_interval_elapsed = - GetLogicalIntervalElapsed(start_time, end_time); - return absl::StrCat( - "Async copy elapsed (s) = ", async_copy_elapsed, - ", logical interval elapsed (s) = ", logical_interval_elapsed); -} - -std::optional -CostAnalysisPrefetchIntervalPicker::BufferIntervalAlternateMemoryBenefit( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) - const { - return cost_analysis_.GetMemoryBoundedness(interval); -} - -bool MemorySpaceAssignment::Allocation::operator==( - const MemorySpaceAssignment::Allocation& other) const { - return defining_position() == other.defining_position() && - uses() == other.uses() && memory_space() == other.memory_space() && - chunk() == other.chunk() && start_time() == other.start_time() && - end_time() == other.end_time() && - earliest_available_time() == other.earliest_available_time() && - is_copy_allocation() == other.is_copy_allocation() && - is_scoped_allocation() == other.is_scoped_allocation(); -} - -bool MemorySpaceAssignment::CopyAllocation::operator==( - const MemorySpaceAssignment::CopyAllocation& other) const { - return static_cast(*this) == - static_cast(other) && - copy_done_schedule_before() == other.copy_done_schedule_before() && - copy_start_schedule_after() == other.copy_start_schedule_after() && - copy_start() == other.copy_start() && copy_done() == other.copy_done(); -} - std::string MemorySpaceAssignment::AllocationValue::ToString() const { std::string out = absl::StrCat("computation = ", computation()->name()); absl::StrAppend(&out, @@ -2109,8 +1205,7 @@ void AlternateMemoryBestFitHeap::AppendScopedAllocationBufferInfoDebugString( } void AlternateMemoryBestFitHeap::AppendAllocationInfoDebugString( - const MemorySpaceAssignment::Allocation& allocation, - std::string& debug_str) const { + const Allocation& allocation, std::string& debug_str) const { // Columns in allocation information: // buffer_id: int. This value can be used the match with buffer info. // size: int. In bytes. @@ -2150,7 +1245,7 @@ MemoryBoundLoopOptimizer::Create( int loop_start, int loop_end, uint64_t alternate_memory_size, const MemoryBoundLoopOptimizerOptions& options, const HloLiveRange& hlo_live_range, const HloAliasAnalysis& alias_analysis, - const MemorySpaceAssignmentCostAnalysis& cost_analysis, + const CostAnalysis& cost_analysis, const BufferValue::SizeFunction& size_function) { std::unique_ptr optimizer = absl::WrapUnique(new MemoryBoundLoopOptimizer( @@ -2164,7 +1259,7 @@ MemoryBoundLoopOptimizer::MemoryBoundLoopOptimizer( int loop_start, int loop_end, uint64_t alternate_memory_size, const MemoryBoundLoopOptimizerOptions& options, const HloLiveRange& hlo_live_range, const HloAliasAnalysis& alias_analysis, - const MemorySpaceAssignmentCostAnalysis& cost_analysis, + const CostAnalysis& cost_analysis, const BufferValue::SizeFunction& size_function) : loop_start_(loop_start), loop_end_(loop_end), @@ -2309,7 +1404,7 @@ void MemoryBoundLoopOptimizer::MaybeCreateLoopValue( // Keep track of bytes accessed by this value. if (loop_index || prev_iteration_index) { float bytes_accessed = - cost_analysis_.cost_analysis().output_bytes_accessed( + cost_analysis_.hlo_cost_analysis().output_bytes_accessed( *position.instruction, position.index); pos_bytes += bytes_accessed; VLOG(3) << " accessed: " << bytes_accessed; @@ -2340,7 +1435,7 @@ void MemoryBoundLoopOptimizer::MaybeCreateLoopValue( // Keep track of bytes accessed by this value. if (loop_index || next_iteration_index) { float bytes_accessed = - cost_analysis_.cost_analysis().operand_bytes_accessed( + cost_analysis_.hlo_cost_analysis().operand_bytes_accessed( *use.instruction, use.operand_number, use.operand_index); use_bytes += bytes_accessed; VLOG(3) << " accessed: " << bytes_accessed; @@ -2443,14 +1538,12 @@ void MemoryBoundLoopOptimizer::Optimize() { float MemoryBoundLoopOptimizer::CalculateExecutionTime() const { // First populate the list of prefetches. - std::vector> - prefetches; + std::vector> prefetches; for (const LoopValue& value : loop_values_) { if (!value.allocations.empty() && value.allocations.back()->is_copy_allocation()) { prefetches.push_back( - {static_cast( - value.allocations.back().get()), + {static_cast(value.allocations.back().get()), cost_analysis_.GetAsyncCopyElapsed( value.hlo_values.front()->shape())}); } @@ -2474,11 +1567,8 @@ float MemoryBoundLoopOptimizer::CalculateExecutionTime() const { // Sort the prefetches by first the start time, then the effective done time. absl::c_sort( - prefetches, - [&](const std::pair& - a, - const std::pair& - b) { + prefetches, [&](const std::pair& a, + const std::pair& b) { return std::forward_as_tuple( a.first->copy_start_schedule_after(), get_effective_done_time( @@ -2750,11 +1840,9 @@ void MemoryBoundLoopOptimizer::PostProcess() { if (!unallocated_uses.empty()) { // TODO(b/281582241): We should find the correct position. For now, we're // using the defining position on the first HLO value. - value.allocations.push_back( - std::make_unique( - value.hlo_values.front()->defining_position(), - MemorySpaceAssignment::MemorySpace::kDefault, std::nullopt, 0, - loop_size_, /*is_scoped_allocation=*/false)); + value.allocations.push_back(std::make_unique( + value.hlo_values.front()->defining_position(), MemorySpace::kDefault, + std::nullopt, 0, loop_size_, /*is_scoped_allocation=*/false)); for (const HloUse& use : unallocated_uses) { value.allocations.back()->AddUse(use); } @@ -2798,12 +1886,10 @@ bool MemoryBoundLoopOptimizer::AllocateTemporary(LoopValue& value) { bool success = AllocateBetween(definition_idx, max_use_idx, value.size); if (success) { VLOG(3) << "Pos: " << value.loop_positions[0].second; - value.allocations.push_back( - std::make_unique( - value.loop_positions[0].second, - MemorySpaceAssignment::MemorySpace::kAlternate, std::nullopt, - definition_idx, max_use_idx, - /*is_scoped_allocation=*/false)); + value.allocations.push_back(std::make_unique( + value.loop_positions[0].second, MemorySpace::kAlternate, std::nullopt, + definition_idx, max_use_idx, + /*is_scoped_allocation=*/false)); AddAllLoopPositionsAndUses(value, /*allocate_next_iteration_uses=*/true); } return success; @@ -2813,12 +1899,10 @@ bool MemoryBoundLoopOptimizer::AllocatePinned(LoopValue& value) { bool success = AllocateBetween(0, loop_size_ - 1, value.size); if (success) { CHECK(value.header_position); - value.allocations.push_back( - std::make_unique( - *value.header_position, - MemorySpaceAssignment::MemorySpace::kAlternate, std::nullopt, 0, - loop_size_, - /*is_scoped_allocation=*/false)); + value.allocations.push_back(std::make_unique( + *value.header_position, MemorySpace::kAlternate, std::nullopt, 0, + loop_size_, + /*is_scoped_allocation=*/false)); AddAllLoopPositionsAndUses(value, /*allocate_next_iteration_uses=*/false); } return success; @@ -3046,8 +2130,8 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( if (early_forced_value->allocations.empty()) { continue; } - const MemorySpaceAssignment::CopyAllocation* early_forced_prefetch = - static_cast( + const CopyAllocation* early_forced_prefetch = + static_cast( early_forced_value->allocations.back().get()); VLOG(3) << "Prefetch: " << early_forced_prefetch->ToString(); @@ -3192,16 +2276,13 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( // Create the Allocation objects that correspond to the scheduled prefetch. CHECK(value->header_position); - value->allocations.push_back( - std::make_unique( - *value->header_position, MemorySpaceAssignment::MemorySpace::kDefault, - std::nullopt, 0, loop_size_, /*is_scoped_allocation=*/false)); - value->allocations.push_back( - std::make_unique( - *value->allocations.back(), - MemorySpaceAssignment::MemorySpace::kAlternate, std::nullopt, - ((*copy_start_time - 1) + loop_size_) % loop_size_, first_use_idx, - last_use_idx_sentinel)); + value->allocations.push_back(std::make_unique( + *value->header_position, MemorySpace::kDefault, std::nullopt, 0, + loop_size_, /*is_scoped_allocation=*/false)); + value->allocations.push_back(std::make_unique( + *value->allocations.back(), MemorySpace::kAlternate, std::nullopt, + ((*copy_start_time - 1) + loop_size_) % loop_size_, first_use_idx, + last_use_idx_sentinel)); AddAllLoopPositionsAndUses(*value, /*allocate_next_iteration_uses=*/true); // Account for the additional memory used by early forcing the already @@ -3212,9 +2293,8 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( LoopValue* early_forced_value = context.values.at( context.value_indices[early_forced_prefetch_value_index]); CHECK(!early_forced_value->allocations.empty()); - MemorySpaceAssignment::CopyAllocation* early_forced_prefetch = - static_cast( - early_forced_value->allocations.back().get()); + CopyAllocation* early_forced_prefetch = static_cast( + early_forced_value->allocations.back().get()); for (int index = early_forced_prefetch->copy_start_schedule_after(); index >= *copy_start_time; --index) { update_additional_memory_used(index, early_forced_value->size); @@ -3231,7 +2311,7 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( void MemoryBoundLoopOptimizer::AddAllLoopPositionsAndUses( LoopValue& value, bool allocate_next_iteration_uses) { CHECK_GE(value.allocations.size(), 1); - MemorySpaceAssignment::Allocation& allocation = *value.allocations.back(); + Allocation& allocation = *value.allocations.back(); for (const auto& [idx, position] : value.loop_positions) { positions_in_alternate_mem_[position.instruction].push_back(position.index); } @@ -3928,26 +3008,23 @@ HloPosition TupleUseToPosition(const HloUse& use) { } // Returns the memory space of the defining position of an Allocation object. -MemorySpaceAssignment::MemorySpace GetDefiningPositionMemorySpace( - const MemorySpaceAssignment::Allocation& allocation) { +MemorySpace GetDefiningPositionMemorySpace(const Allocation& allocation) { if (!allocation.is_copy_like_allocation()) { return allocation.memory_space(); } - if (allocation.memory_space() == - MemorySpaceAssignment::MemorySpace::kDefault) { - return MemorySpaceAssignment::MemorySpace::kAlternate; + if (allocation.memory_space() == MemorySpace::kDefault) { + return MemorySpace::kAlternate; } - return MemorySpaceAssignment::MemorySpace::kDefault; + return MemorySpace::kDefault; } } // namespace -std::vector> +std::vector> AlternateMemoryBestFitHeap::GetLinkedAllocationsInAlternateMemory( absl::Span allocation_values) const { - std::vector> - linked_allocations; + std::vector> linked_allocations; // A map from position to index into linked_allocations. absl::flat_hash_map link_id_map; // Iterate over the allocation values. Find Allocation objects across the @@ -4040,8 +3117,7 @@ AlternateMemoryBestFitHeap::GetLinkedAllocationsInAlternateMemory( if (VLOG_IS_ON(3)) { for (int i = 0; i < linked_allocations.size(); ++i) { VLOG(3) << "Link id = " << i; - for (const MemorySpaceAssignment::Allocation* allocation : - linked_allocations[i]) { + for (const Allocation* allocation : linked_allocations[i]) { VLOG(3) << " " << allocation->ToString(); } } @@ -4080,15 +3156,16 @@ AlternateMemoryBestFitHeap::GetInefficientAllocationSites( const HloPosition& defining_position = allocation->defining_position(); int64_t accessed = - options_.cost_analysis->cost_analysis().output_bytes_accessed( + options_.cost_analysis->hlo_cost_analysis().output_bytes_accessed( *defining_position.instruction, defining_position.index); VLOG(3) << " pos: " << defining_position.ToString() << ", accessed: " << accessed << " / " << size; } for (const HloUse& use : allocation->uses()) { int64_t accessed = - options_.cost_analysis->cost_analysis().operand_bytes_accessed( - *use.instruction, use.operand_number, use.operand_index); + options_.cost_analysis->hlo_cost_analysis() + .operand_bytes_accessed(*use.instruction, use.operand_number, + use.operand_index); VLOG(3) << " use: " << use.ToString() << ", accessed: " << accessed << " / " << size; } @@ -4096,12 +3173,11 @@ AlternateMemoryBestFitHeap::GetInefficientAllocationSites( } } - std::vector> - linked_allocations = - GetLinkedAllocationsInAlternateMemory(allocation_values); + std::vector> linked_allocations = + GetLinkedAllocationsInAlternateMemory(allocation_values); std::vector inefficient_sites; - for (const std::vector& - allocation_group : linked_allocations) { + for (const std::vector& allocation_group : + linked_allocations) { // For all of allocation in the linked allocation group, calculate the total // use bytes in alternate memory and async copy bytes. If the ratio between // the two is below inefficient_use_to_copy_ratio, add all of the @@ -4109,8 +3185,7 @@ AlternateMemoryBestFitHeap::GetInefficientAllocationSites( VLOG(3) << "AllocationGroup:"; int64_t copy_bytes = 0; int64_t use_bytes = 0; - for (const MemorySpaceAssignment::Allocation* allocation : - allocation_group) { + for (const Allocation* allocation : allocation_group) { VLOG(3) << " Allocation: " << allocation->ToString(); MemorySpace position_memory_space = GetDefiningPositionMemorySpace(*allocation); @@ -4119,22 +3194,22 @@ AlternateMemoryBestFitHeap::GetInefficientAllocationSites( } if (position_memory_space == MemorySpace::kAlternate) { use_bytes += - options_.cost_analysis->cost_analysis().output_bytes_accessed( + options_.cost_analysis->hlo_cost_analysis().output_bytes_accessed( *allocation->defining_position().instruction, allocation->defining_position().index); } if (allocation->memory_space() == MemorySpace::kAlternate) { for (const HloUse& use : allocation->uses()) { use_bytes += - options_.cost_analysis->cost_analysis().operand_bytes_accessed( - *use.instruction, use.operand_number, use.operand_index); + options_.cost_analysis->hlo_cost_analysis() + .operand_bytes_accessed(*use.instruction, use.operand_number, + use.operand_index); } } } VLOG(3) << " use bytes: " << use_bytes << ", copy bytes: " << copy_bytes; if (options_.inefficient_use_to_copy_ratio * copy_bytes > use_bytes) { - for (const MemorySpaceAssignment::Allocation* allocation : - allocation_group) { + for (const Allocation* allocation : allocation_group) { MemorySpace position_memory_space = GetDefiningPositionMemorySpace(*allocation); if (position_memory_space == MemorySpace::kAlternate) { @@ -4409,16 +3484,22 @@ AlternateMemoryBestFitHeap::AllocateAllocationValues( loop_optimized_allocations_map_.end()) { const LoopOptimizedAllocationInfo& loop_optimized_allocation_info = loop_optimized_allocation_it->second; - const MemorySpaceAssignment::Allocation* allocation = + const Allocation* allocation = loop_optimized_allocation_info.loop_optimized_allocation; VLOG(3) << "Found optimized allocation for " << use.hlo_use.ToString() << " (loop idx: " << loop_optimized_allocation_info.use_index << "): " << allocation->ToString(); - if (allocation->is_copy_allocation()) { + if (require_no_copy_alternate_mem_allocation) { + if (allocation->is_copy_allocation() || + allocation->memory_space() == MemorySpace::kDefault) { + LOG(WARNING) << "Optimized allocation could not be applied " + "because the tensor is pre-colored, allocation: " + << allocation->ToString(); + } + } else if (allocation->is_copy_allocation()) { allow_no_copy_alternate_mem_allocation = true; - const MemorySpaceAssignment::CopyAllocation* copy_allocation = - static_cast( - allocation); + const CopyAllocation* copy_allocation = + static_cast(allocation); int64_t effective_copy_start_time = copy_allocation->copy_start_schedule_after(); if (copy_allocation->copy_start_schedule_after() == @@ -4546,9 +3627,8 @@ AlternateMemoryBestFitHeap::AllocateAllocationValues( } // Propagate the allocation to any aliases this use might have had. - MemorySpaceAssignment::Allocation* aliased_allocation = - GetLiveAllocationAt(*allocation_value.allocation_sequence(), - use_time); + Allocation* aliased_allocation = GetLiveAllocationAt( + *allocation_value.allocation_sequence(), use_time); for (const HloPosition& aliased_position : use.aliases) { AddAliasedRequiredAssignment(aliased_position.instruction, aliased_position.index, @@ -4599,7 +3679,7 @@ AlternateMemoryBestFitHeap::AllocateAllocationValues( int64_t body_parameter_time = instruction_schedule.at( body_allocation_value_it->defining_instruction()); body_allocation_value_it->mutable_allocation_sequence()->push_back( - std::make_unique( + std::make_unique( **prev_allocation_in_default_mem_it, hlo_use.instruction, body_allocation_value_it->defining_position(), body_parameter_time)); @@ -4617,9 +3697,8 @@ AlternateMemoryBestFitHeap::AllocateAllocationValues( << after_while_allocation_value_it->ToShortString(); int64_t while_time = instruction_schedule.at(hlo_use.instruction); after_while_allocation_value_it->mutable_allocation_sequence() - ->push_back( - std::make_unique( - **prev_allocation_in_default_mem_it, while_time)); + ->push_back(std::make_unique( + **prev_allocation_in_default_mem_it, while_time)); VLOG(3) << "Created: " << after_while_allocation_value_it->allocation_sequence() ->back() @@ -4921,7 +4000,7 @@ struct CopyResourceDumpData { std::string AsynchronousCopyResource::Dump( int64_t start_time, int64_t end_time, - MemorySpaceAssignment::MemorySpace memory_space_filter) const { + MemorySpace memory_space_filter) const { std::vector available = GetCurrentResources(); std::vector time_dump_data; for (int i = start_time; i < end_time; ++i) { @@ -4983,15 +4062,14 @@ std::string AsynchronousCopyResource::Dump( } AlternateMemoryBestFitHeap::AliasedOffset* -AlternateMemoryBestFitHeap::GetAliasedOffset( - const MemorySpaceAssignment::Allocation& allocation) { +AlternateMemoryBestFitHeap::GetAliasedOffset(const Allocation& allocation) { auto aliased_offset_it = aliased_offset_map_.find(&allocation); CHECK(aliased_offset_it != aliased_offset_map_.end()); return aliased_offset_it->second; } void AlternateMemoryBestFitHeap::CreateOrAddToAliasedOffset( - const MemorySpaceAssignment::Allocation& allocation, + const Allocation& allocation, AlternateMemoryBestFitHeap::AliasedOffset* aliased_offset) { CHECK(allocation.memory_space() == MemorySpace::kAlternate); CHECK(!aliased_offset_map_.contains(&allocation)); @@ -5004,8 +4082,7 @@ void AlternateMemoryBestFitHeap::CreateOrAddToAliasedOffset( aliased_offset_map_[&allocation] = aliased_offset; } -/*static*/ MemorySpaceAssignment::Allocation* -AlternateMemoryBestFitHeap::GetLiveAllocationAt( +/*static*/ Allocation* AlternateMemoryBestFitHeap::GetLiveAllocationAt( const MemorySpaceAssignment::AllocationSequence& allocations, int64_t time) { for (auto allocation_it = allocations.rbegin(); @@ -5032,7 +4109,7 @@ void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer( module->AddCrossProgramPrefetch(parameter, buffer->index()); MemorySpaceAssignment::AllocationSequence allocations; - allocations.push_back(std::make_unique( + allocations.push_back(std::make_unique( buffer->defining_position(), MemorySpace::kDefault, kDummyChunk, prefetch_candidate.start, prefetch_candidate.end, /*is_scoped_allocation=*/false)); @@ -5200,10 +4277,9 @@ void AlternateMemoryBestFitHeap::AllocateReservedScopedAllocations() { instruction, i, reserved_scoped_memory, buffer_info_str_); } - allocations_->push_back( - std::make_unique( - HloPosition{instruction_sequence[i], {}}, MemorySpace::kAlternate, - chunk_candidate, i, i, /*is_scoped_allocation=*/true)); + allocations_->push_back(std::make_unique( + HloPosition{instruction_sequence[i], {}}, MemorySpace::kAlternate, + chunk_candidate, i, i, /*is_scoped_allocation=*/true)); repack_allocation_blocks_.push_back(MakeRepackAllocationBlock( i, i, reserved_scoped_memory, @@ -5280,7 +4356,7 @@ AlternateMemoryBestFitHeap::AliasedRequiredAssignmentForUse( void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignment( const HloInstruction* instruction, ShapeIndex index, - const MemorySpaceAssignment::Allocation* aliased_allocation) { + const Allocation* aliased_allocation) { AliasedOffset* offset = nullptr; if (aliased_allocation->memory_space() == MemorySpace::kAlternate) { offset = GetAliasedOffset(*aliased_allocation); @@ -5291,8 +4367,8 @@ void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignment( void AlternateMemoryBestFitHeap::AddRequiredAssignment( const HloValue* value, const HloInstruction* instruction, - MemorySpaceAssignment::MemorySpace memory_space, int64_t time, - AliasedOffset* offset, bool add_to_pending) { + MemorySpace memory_space, int64_t time, AliasedOffset* offset, + bool add_to_pending) { // Check for existing required assignment at this time and make sure it is the // same as this if there is one. auto existing_required_assignment = RequiredMemoryAssignmentAt(value, time); @@ -5527,7 +4603,7 @@ void AlternateMemoryBestFitHeap::UpdateReservedScopedAllocationSize() { } // Update scoped allocation sizes. for (RepackAllocationBlock& allocation_block : repack_allocation_blocks_) { - MemorySpaceAssignment::Allocation* allocation = allocation_block.allocation; + Allocation* allocation = allocation_block.allocation; if (allocation->is_scoped_allocation()) { allocation_block.size = reserved_scoped_memory_map[allocation->start_time()]; @@ -5539,7 +4615,6 @@ void AlternateMemoryBestFitHeap::UpdateReservedScopedAllocationSize() { void AlternateMemoryBestFitHeap::ExportAllocationsForRepacking( std::vector& allocations) { - using SlicedCopyAllocation = MemorySpaceAssignment::SlicedCopyAllocation; using SliceDetail = SlicedCopyAllocation::SliceDetail; if (options_.reduce_scoped_memory_limit) { @@ -5602,7 +4677,7 @@ void AlternateMemoryBestFitHeap::ImportRepackedAllocations() { void AlternateMemoryBestFitHeap::ImportRepackedNonSlicedAllocation( RepackAllocationBlock& block) { - MemorySpaceAssignment::Allocation* allocation = block.allocation; + Allocation* allocation = block.allocation; int64_t original_offset = block.initial_offset; int64_t repacked_offset = block.offset; @@ -5621,7 +4696,7 @@ void AlternateMemoryBestFitHeap::ImportRepackedNonSlicedAllocation( void AlternateMemoryBestFitHeap::ImportRepackedSlicedAllocation( RepackAllocationBlock& block) { - using SlicedCopyAllocation = MemorySpaceAssignment::SlicedCopyAllocation; + using SlicedCopyAllocation = memory_space_assignment::SlicedCopyAllocation; using SliceDetail = SlicedCopyAllocation::SliceDetail; CHECK_OK(AreRepackedSlicesValid(block)); @@ -5788,8 +4863,7 @@ void AlternateMemoryBestFitHeap::UncommitPendingChunks( void AlternateMemoryBestFitHeap::FinalizeAllocations( absl::Span allocation_values) { - absl::flat_hash_map> + absl::flat_hash_map> colocation_map; for (AllocationValue& allocation_value : allocation_values) { for (auto& allocation : *allocation_value.mutable_allocation_sequence()) { @@ -5806,8 +4880,7 @@ void AlternateMemoryBestFitHeap::FinalizeAllocations( } } allocations_->push_back(std::move(allocation)); - MemorySpaceAssignment::Allocation* inserted_allocation = - allocations_->back().get(); + Allocation* inserted_allocation = allocations_->back().get(); if (inserted_allocation->memory_space() == MemorySpace::kAlternate) { colocation_map[GetAliasedOffset(*inserted_allocation)].push_back( inserted_allocation); @@ -5819,8 +4892,7 @@ void AlternateMemoryBestFitHeap::FinalizeAllocations( // reduce fragmentation. for (auto& colocation : colocation_map) { std::vector colocations; - for (MemorySpaceAssignment::Allocation* colocated_allocation : - colocation.second) { + for (Allocation* colocated_allocation : colocation.second) { repack_allocation_blocks_.push_back(MakeRepackAllocationBlock( colocated_allocation->start_time(), colocated_allocation->end_time(), colocated_allocation->chunk().size, @@ -5888,7 +4960,7 @@ AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateSegment( // consumed multiple times by the same instruction. We can just find the // previous allocation and use that allocation. if (request.inclusive_start_time == request.end_time) { - MemorySpaceAssignment::Allocation* allocation = + Allocation* allocation = GetLiveAllocationAt(*allocation_sequence, request.end_time); CHECK_NE(allocation, nullptr); allocation->AddUse(request.use->hlo_use); @@ -5919,12 +4991,13 @@ AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateSegment( << " use benefit = " << options_.cost_analysis->GetAlternateMemoryBenefit( request.use->hlo_use); - VLOG(3) << "Definition bytes accessed = " - << options_.cost_analysis->cost_analysis().output_bytes_accessed( - *defining_position.instruction, defining_position.index) - << ", use bytes accessed = " - << options_.cost_analysis->cost_analysis().operand_bytes_accessed( - *use.instruction, use.operand_number, use.operand_index); + VLOG(3) + << "Definition bytes accessed = " + << options_.cost_analysis->hlo_cost_analysis().output_bytes_accessed( + *defining_position.instruction, defining_position.index) + << ", use bytes accessed = " + << options_.cost_analysis->hlo_cost_analysis().operand_bytes_accessed( + *use.instruction, use.operand_number, use.operand_index); } // There could be a requirement to pin this buffer to default memory either @@ -5978,12 +5051,11 @@ AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateSegment( aliased_chunk = Chunk::FromOffsetSize( required_assignment_at_start->offset->offset, request.size); } - allocation_sequence->push_back( - std::make_unique( - defining_position, required_assignment_at_start->memory_space, - aliased_chunk, request.inclusive_start_time, - request.inclusive_start_time, - /*is_scoped_allocation=*/false)); + allocation_sequence->push_back(std::make_unique( + defining_position, required_assignment_at_start->memory_space, + aliased_chunk, request.inclusive_start_time, + request.inclusive_start_time, + /*is_scoped_allocation=*/false)); if (required_assignment_at_start->memory_space == MemorySpace::kAlternate) { CreateOrAddToAliasedOffset(*allocation_sequence->back(), @@ -6032,12 +5104,10 @@ AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateSegment( } prev_allocation_in_default_mem_it = allocation_sequence->rbegin(); } else if (prev_allocation_in_default_mem_it == allocation_sequence->rend()) { - allocation_sequence->push_back( - std::make_unique( - defining_position, MemorySpace::kDefault, - /*chunk=*/std::nullopt, request.inclusive_start_time, - request.end_time, - /*is_scoped_allocation=*/false)); + allocation_sequence->push_back(std::make_unique( + defining_position, MemorySpace::kDefault, + /*chunk=*/std::nullopt, request.inclusive_start_time, request.end_time, + /*is_scoped_allocation=*/false)); prev_allocation_in_default_mem_it = allocation_sequence->rbegin(); } @@ -6076,21 +5146,17 @@ AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateSegment( // Warn if the prefetch time picked doesn't match the preferred prefetch // time. CHECK(!request.allocation_value->allocation_sequence()->empty()); - const MemorySpaceAssignment::Allocation* allocation = + const Allocation* allocation = request.allocation_value->allocation_sequence()->back().get(); int64_t prefetch_time = 0; if (allocation->is_copy_allocation()) { - prefetch_time = - static_cast( - allocation) - ->copy_start_schedule_after(); + prefetch_time = static_cast(allocation) + ->copy_start_schedule_after(); } else if (allocation->is_sliced_copy_allocation()) { - prefetch_time = - static_cast( - allocation) - ->slice_details_sorted_by_start_time() - .front() - .copy_start_after_time; + prefetch_time = static_cast(allocation) + ->slice_details_sorted_by_start_time() + .front() + .copy_start_after_time; } else { LOG(FATAL) << "Prefetch allocation are expected to be " "CopyAllocations or SlicedCopyAllocations."; @@ -6137,34 +5203,29 @@ AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateSegment( } void AlternateMemoryBestFitHeap::AddAsyncCopy( - MemorySpaceAssignment::Allocation& prev_allocation, - MemorySpace memory_space, std::optional chunk, - int64_t exclusive_start_time, int64_t end_time, + Allocation& prev_allocation, MemorySpace memory_space, + std::optional chunk, int64_t exclusive_start_time, int64_t end_time, int64_t copy_done_schedule_before_time, MemorySpaceAssignment::AllocationSequence* allocations, AliasedOffset* aliased_offset, float resource, std::optional cross_program_prefetch_index) { VLOG(3) << "Copy to " - << (memory_space == MemorySpaceAssignment::MemorySpace::kDefault - ? "default" - : "alternate") + << (memory_space == MemorySpace::kDefault ? "default" : "alternate") << " memory in (" << exclusive_start_time << ", " << copy_done_schedule_before_time << "), keeping until " << end_time << ", estimated copy resource is " << resource; CHECK_LT(exclusive_start_time, copy_done_schedule_before_time); - allocations->push_back( - std::make_unique( - prev_allocation, memory_space, chunk, exclusive_start_time, - copy_done_schedule_before_time, end_time, - cross_program_prefetch_index)); + allocations->push_back(std::make_unique( + prev_allocation, memory_space, chunk, exclusive_start_time, + copy_done_schedule_before_time, end_time, cross_program_prefetch_index)); // Register the additional async copy with the interval tree to keep track of // the limit at any given time. pending_async_copies_.push_back({exclusive_start_time, copy_done_schedule_before_time, resource, memory_space, next_async_copy_id_++}); - if (memory_space == MemorySpaceAssignment::MemorySpace::kAlternate) { + if (memory_space == MemorySpace::kAlternate) { prefetch_interval_tree_.Add( /*start=*/ ExclusiveToInclusiveStartTime(exclusive_start_time), @@ -6192,8 +5253,8 @@ namespace { // - When the allocation for the slice ends // - An estimation of how much copy resource the slice consumes std::string SliceTimesAndCopyResourcesToString( - const std::vector& slice_decisions, - int64_t prefetch_end, int64_t allocation_end) { + const std::vector& slice_decisions, int64_t prefetch_end, + int64_t allocation_end) { std::vector slice_strings; slice_strings.reserve(slice_decisions.size()); @@ -6217,11 +5278,10 @@ std::string SliceTimesAndCopyResourcesToString( } // namespace void AlternateMemoryBestFitHeap::AddAsyncSlicesForPrefetch( - const MemorySpaceAssignment::Allocation& prev_allocation, + const Allocation& prev_allocation, MemorySpaceAssignment::AllocationSequence* allocations, AliasedOffset* aliased_offset, - const std::vector& - slice_decisions_sorted_by_start_time, + const std::vector& slice_decisions_sorted_by_start_time, int64_t prefetch_end_time, int64_t allocation_end_time) { VLOG(3) << "Sliced copy to alternate memory. " << SliceTimesAndCopyResourcesToString( @@ -6232,19 +5292,18 @@ void AlternateMemoryBestFitHeap::AddAsyncSlicesForPrefetch( return slice_decision.exclusive_start_time < prefetch_end_time; })); - allocations->push_back( - std::make_unique( - prev_allocation, MemorySpaceAssignment::MemorySpace::kAlternate, - slice_decisions_sorted_by_start_time, prefetch_end_time, - allocation_end_time)); + allocations->push_back(std::make_unique( + prev_allocation, MemorySpace::kAlternate, + slice_decisions_sorted_by_start_time, prefetch_end_time, + allocation_end_time, options_.sliced_prefetch_options, + options_.get_equivalent_s8_shape_fn)); // Register the additional async copy with the interval tree to keep track of // the limit at any given time. for (const auto& slice_decision : slice_decisions_sorted_by_start_time) { pending_async_copies_.push_back( {slice_decision.exclusive_start_time, prefetch_end_time, - slice_decision.copy_resource_consumed, - MemorySpaceAssignment::MemorySpace::kAlternate, + slice_decision.copy_resource_consumed, MemorySpace::kAlternate, next_async_copy_id_++}); prefetch_interval_tree_.Add(slice_decision.exclusive_start_time, prefetch_end_time, kDummyChunk); @@ -6289,7 +5348,7 @@ bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies( AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( const AllocationRequest& request) { - MemorySpaceAssignment::Allocation* prev_allocation = nullptr; + Allocation* prev_allocation = nullptr; bool can_eliminate_copy = false; if (request.allocation_value->allocation_sequence()->empty()) { // There hasn't been any allocations for this interval so far. We can @@ -6398,7 +5457,7 @@ AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( prev_allocation->set_end_time(request.end_time); } else { request.allocation_value->mutable_allocation_sequence()->push_back( - std::make_unique( + std::make_unique( defining_position, MemorySpace::kAlternate, chunk_candidate, request.inclusive_start_time, request.end_time, /*is_scoped_allocation=*/false)); @@ -6420,7 +5479,7 @@ AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::Evict( const AllocationRequest& request) { CHECK_GT(request.allocation_value->allocation_sequence()->size(), 0); - MemorySpaceAssignment::Allocation* prev_allocation = + Allocation* prev_allocation = request.allocation_value->allocation_sequence()->back().get(); // We do not ever expect an Evict() to be immediately proceeded by a prefetch. // If that case ever occurs, the eviction_exclusive_start_time below will be @@ -6553,7 +5612,7 @@ namespace { // A debugging/logging method for describing a sliced solution. std::string DescribeSlicedBufferMove( - const std::vector& slice_decisions, + const std::vector& slice_decisions, const AlternateMemoryBestFitHeap::HeapResult& heap_result, const AlternateMemoryBestFitHeap::Chunk& full_chunk, absl::string_view prefetch_picker_debug_string) { @@ -6578,7 +5637,7 @@ std::string DescribeSlicedBufferMove( AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::Prefetch( const AllocationRequest& request, - MemorySpaceAssignment::Allocation& prev_allocation_in_default_mem) { + Allocation& prev_allocation_in_default_mem) { // Try partially placing the buffer in the alternate space. The time that is // overlapped will be used to asynchronously copy the buffer from the // default memory to the alternate memory. @@ -6763,12 +5822,10 @@ void AlternateMemoryBestFitHeap::GenerateSliceProposal( } VLOG(6) << log_prefix() << ". Slice proposal = [" - << absl::StrJoin( - status_or_proposal.value(), ", ", - [](std::string* out, - const MemorySpaceAssignment::SliceProposal& proposal) { - absl::StrAppend(out, proposal.ToString()); - }) + << absl::StrJoin(status_or_proposal.value(), ", ", + [](std::string* out, const SliceProposal& proposal) { + absl::StrAppend(out, proposal.ToString()); + }) << "]"; context.slice_proposal_collection = std::move(status_or_proposal.value()); @@ -6802,7 +5859,7 @@ void AlternateMemoryBestFitHeap::SetupPrefetchWorkingIntervalsAndSliceProposal( context.sliced_solution_intervals.full)); std::vector sizes; sizes.reserve(context.slice_proposal_collection->size()); - for (const MemorySpaceAssignment::SliceProposal& single_slice_proposal : + for (const SliceProposal& single_slice_proposal : *context.slice_proposal_collection) { sizes.push_back(single_slice_proposal.slice_size); } @@ -6905,12 +5962,10 @@ float CopyResourceForShape(const Options& options, const Shape& shape) { // collection, in descending order. std::vector GetCopyResourcesSortedDescending( const Options& options, - const MemorySpaceAssignment::SliceProposalCollection& - slice_proposal_collection) { + const SliceProposalCollection& slice_proposal_collection) { std::vector copy_resources; copy_resources.reserve(slice_proposal_collection.size()); - for (const MemorySpaceAssignment::SliceProposal& proposal : - slice_proposal_collection) { + for (const SliceProposal& proposal : slice_proposal_collection) { copy_resources.push_back( CopyResourceForShape(options, proposal.slice_shape)); } @@ -7128,20 +6183,16 @@ AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::CheckPrefetchFit( GetCandidateToProposalIndexMap(chunk_candidates); // Create slice decisions, sorted by time. - std::vector - slice_decisions_sorted_by_start_time; + std::vector slice_decisions_sorted_by_start_time; for (int64_t slice_time = 0; slice_time < sliced_buffer_interval->num_slices(); ++slice_time) { - const MemorySpaceAssignment::SliceProposal& proposal = - context.slice_proposal_collection->at( - candidate_to_proposal_index_map[slice_time]); + const SliceProposal& proposal = context.slice_proposal_collection->at( + candidate_to_proposal_index_map[slice_time]); copy_resource_per_slice_sorted_by_start_time[slice_time] = CopyResourceForShape(options_, proposal.slice_shape); - slice_decisions_sorted_by_start_time.push_back( - MemorySpaceAssignment::SliceDecision{ - chunk_candidates[slice_time], - exclusive_slice_start_times[slice_time], proposal, - copy_resource_per_slice_sorted_by_start_time[slice_time]}); + slice_decisions_sorted_by_start_time.push_back(SliceDecision{ + chunk_candidates[slice_time], exclusive_slice_start_times[slice_time], + proposal, copy_resource_per_slice_sorted_by_start_time[slice_time]}); } // Check that we have enough copy resources for all the slice decisions. @@ -7548,45 +6599,6 @@ Status MemorySpaceAssignment::FindAllocationSequence( return OkStatus(); } -bool MemorySpaceAssignment::Allocation::is_copy_like_allocation() const { - return is_copy_allocation() || is_sliced_copy_allocation(); -} - -void MemorySpaceAssignment::Allocation::AddUse(HloUse use) { - HloInstruction* operand = - use.instruction->mutable_operand(use.operand_number); - // If the use is a tuple, look inside the tuple to find the actual use. - for (int64_t index : use.operand_index) { - if (operand->opcode() != HloOpcode::kTuple) { - break; - } - operand = operand->mutable_operand(index); - } - - // Look beyond GetTupleElement(Tuple()) pattern for any bitcasts. - std::function get_simplified_operand; - get_simplified_operand = [&](HloInstruction* instruction) { - while (instruction->opcode() == HloOpcode::kGetTupleElement) { - HloInstruction* operand = - get_simplified_operand(instruction->mutable_operand(0)); - if (operand->opcode() == HloOpcode::kTuple) { - instruction = operand->mutable_operand(instruction->tuple_index()); - } else { - return instruction; - } - } - return instruction; - }; - operand = get_simplified_operand(operand); - - uses_.push_back(use); -} - -void MemorySpaceAssignment::Allocation::set_offset(int64_t offset) { - CHECK(chunk_.has_value()); - *chunk_ = Chunk::FromOffsetSize(offset, chunk_->size); -} - float MemorySpaceAssignment::ComputeEstimatedElapsedTime( const HloLiveRange& hlo_live_range, const AllocationSequence& allocations) { absl::flat_hash_map> @@ -7629,646 +6641,16 @@ float MemorySpaceAssignment::ComputeEstimatedElapsedTime( options_.cost_analysis->GetInstructionElapsedInAlternateMemory( *instruction, operands_in_alternate_memory, outputs_in_alternate_memory); - float while_nest_multiplier = IPow( - options_.xla_tpu_memory_space_assignment_while_execution_count, - options_.cost_analysis->CalculateComputationNestLevel( - instruction, - /*while_only=*/true)); + float while_nest_multiplier = + options_.cost_analysis->GetWhileNestMultiplier( + options_.cost_analysis->CalculateComputationNestLevel( + instruction, + /*while_only=*/true)); total_elapsed += while_nest_multiplier * instruction_elapsed; } return total_elapsed; } -Status MemorySpaceAssignment::Allocation::Process() { - if (is_scoped_allocation()) { - // Nothing to do here for scoped allocations. - return OkStatus(); - } - HloInstruction* producing_instruction = AddGetTupleElements(); - HloComputation* computation = producing_instruction->parent(); - for (const HloUse& use : uses_) { - Shape operand_shape = use.instruction->operand(use.operand_number)->shape(); - HloInstruction* replacement_instruction = producing_instruction; - if (operand_shape.IsTuple()) { - TF_ASSIGN_OR_RETURN( - replacement_instruction, - TupleUtil::ReplaceTupleWith( - producing_instruction, - use.instruction->mutable_operand(use.operand_number), - use.operand_index)); - } else if (operand_shape != producing_instruction->shape()) { - VLOG(4) << "Old shape = " << operand_shape.ToString() - << ", new shape = " << producing_instruction->shape().ToString() - << "; inserting a bitcast."; - replacement_instruction = computation->AddInstruction( - HloInstruction::CreateBitcast(operand_shape, producing_instruction)); - } - TF_RETURN_IF_ERROR(use.instruction->ReplaceOperandWith( - use.operand_number, replacement_instruction)); - } - return OkStatus(); -} - -HloInstruction* MemorySpaceAssignment::Allocation::AddGetTupleElements() const { - CHECK_NE(defining_position().instruction, nullptr); - - Shape shape = defining_position().shape(); - CHECK(shape.IsArray()) << "Allocation shape is not an array. Shape = " - << shape.ToString() - << " position = " << defining_position().shape(); - return TupleUtil::AddGetTupleElements(defining_position()); -} - -std::string MemorySpaceAssignment::Allocation::ToString() const { - std::string memory_space_str = - memory_space_ == MemorySpace::kDefault ? "def" : "alt"; - if (chunk_) { - absl::StrAppend(&memory_space_str, " (off: ", chunk_->offset, ")"); - } - return absl::StrCat((is_scoped_allocation() ? "Scoped " : ""), - "Allocation in ", memory_space_str, " defined at ", - defining_position_.ToString(), - ", start_time:", start_time(), ", end_time:", end_time(), - ", uses: ", UsesToString(uses())); -} - -std::string MemorySpaceAssignment::CopyAllocation::ToString() const { - std::string memory_space_str = - memory_space_ == MemorySpace::kDefault ? "def" : "alt"; - if (chunk_) { - absl::StrAppend(&memory_space_str, " (off: ", chunk_->offset, ")"); - } - return absl::StrCat("Copy Allocation in ", memory_space_str, - ", start_time:", start_time(), ", end_time:", end_time(), - ", copy_start_after_time: ", copy_start_schedule_after(), - ", copy_done_before_time: ", copy_done_schedule_before(), - ", uses: ", UsesToString(uses()), ", from ", - prev_allocation_.ToString()); -} - -std::string MemorySpaceAssignment::SliceParam::ToString() const { - return absl::StrCat("[", start_inclusive, ",", end_exclusive, ")"); -} - -bool MemorySpaceAssignment::SliceParam::operator==( - const SliceParam& other) const { - return start_inclusive == other.start_inclusive && - end_exclusive == other.end_exclusive; -} - -std::string MemorySpaceAssignment::SliceProposal::ToString() const { - return absl::StrCat( - "{ slice_shape: ", slice_shape.ToString(true), ", slice_params: { ", - absl::StrJoin(slice_params, ", ", - [](std::string* out, const SliceParam& param) { - absl::StrAppend(out, param.ToString()); - }), - " }, slice_size: ", slice_size, " }"); -} - -std::ostream& operator<<(std::ostream& os, - const MemorySpaceAssignment::SliceProposal& proposal) { - os << proposal.ToString(); - return os; -} - -std::tuple&, - int64_t> -MemorySpaceAssignment::SliceProposal::ToTuple() const { - return std::make_tuple(std::ref(slice_shape), std::ref(slice_params), - slice_size); -} - -bool MemorySpaceAssignment::SliceProposal::operator==( - const SliceProposal& other) const { - return ToTuple() == other.ToTuple(); -} - -std::string MemorySpaceAssignment::SliceDecision::ToString() const { - return absl::StrCat("{ chunk: ", chunk.ToString(), - ", (exclusive) start_time: ", exclusive_start_time, - ", sizing: ", sizing.ToString(), - ", copy_resource_consumed: ", copy_resource_consumed, - " }"); -} - -namespace { - -std::tuple -SliceDecisionToTuple(const MemorySpaceAssignment::SliceDecision& decision) { - return std::make_tuple( - std::ref(decision.chunk), decision.exclusive_start_time, - std::ref(decision.sizing), decision.copy_resource_consumed); -} - -} // namespace - -bool MemorySpaceAssignment::SliceDecision::operator==( - const SliceDecision& other) const { - return SliceDecisionToTuple(*this) == SliceDecisionToTuple(other); -} - -std::string MemorySpaceAssignment::SlicedCopyAllocation::SliceDetail::ToString() - const { - return absl::StrCat("{ slice_decision: ", slice_decision.ToString(), - ", copy_start_after_time: ", copy_start_after_time, - ", copy_done_before_time: ", copy_done_before_time, " }"); -} - -namespace { - -std::tuple -SliceDetailToTuple( - const MemorySpaceAssignment::SlicedCopyAllocation::SliceDetail& - slice_detail) { - return std::make_tuple(std::ref(slice_detail.slice_decision), - slice_detail.copy_start_after_time, - slice_detail.copy_done_before_time, - slice_detail.copy_start, slice_detail.copy_done); -} - -} // namespace - -bool MemorySpaceAssignment::SlicedCopyAllocation::SliceDetail::operator==( - const SliceDetail& other) const { - return SliceDetailToTuple(*this) == SliceDetailToTuple(other); -} - -Status -MemorySpaceAssignment::SlicedCopyAllocation::SliceDetail::CreateAsyncSlice( - const Shape& original_shape, HloInstruction& producer, - HloComputation& parent) { - if (original_shape.rank() != slice_decision.sizing.slice_params.size()) { - return FailedPrecondition( - "%s", absl::StrCat("The number of SlicedCopyAllocation parameters ", - slice_decision.sizing.slice_params.size(), - " does not match the rank ", original_shape.rank(), - " of the tensor we are slicing.")); - } - - std::vector start_indices; - start_indices.reserve(slice_decision.sizing.slice_params.size()); - std::vector limit_indices; - limit_indices.reserve(slice_decision.sizing.slice_params.size()); - std::vector strides; - strides.reserve(slice_decision.sizing.slice_params.size()); - - for (int i = 0; i < slice_decision.sizing.slice_params.size(); ++i) { - const SliceParam& slice_param = slice_decision.sizing.slice_params[i]; - start_indices.push_back(slice_param.start_inclusive); - limit_indices.push_back(slice_param.end_exclusive); - strides.push_back(1); - const int64_t new_dim = - slice_param.end_exclusive - slice_param.start_inclusive; - if (new_dim <= 0) { - return FailedPrecondition( - "%s", absl::StrCat("SlicedCopyAllocation new dimension size is ", - new_dim, ", expected something > 0.")); - } - if (original_shape.dimensions(i) < new_dim) { - return FailedPrecondition( - "%s", - absl::StrCat("SlicedCopyAllocation sliced dimension size ", new_dim, - " is bigger than its original dimension size of ", - original_shape.dimensions(i), ".")); - } - } - - HloInstruction* slice = parent.AddInstruction( - HloInstruction::CreateSlice(slice_decision.sizing.slice_shape, &producer, - start_indices, limit_indices, strides)); - TF_ASSIGN_OR_RETURN(copy_done, parent.CreateAsyncInstructions( - slice, {ShapeUtil::MakeShape(S32, {})})); - copy_start = copy_done->mutable_operand(0); - - return OkStatus(); -} - -namespace { - -// Helper function to compute the underlying Allocation chunk for a -// SlicedCopyAllocation. -std::optional GetSlicedCopyAllocationChunk( - const std::vector& - slice_decisions_sorted_by_start_time) { - if (slice_decisions_sorted_by_start_time.empty()) { - return std::nullopt; - } - auto offset_cmp = [](const MemorySpaceAssignment::SliceDecision& lhs, - const MemorySpaceAssignment::SliceDecision& rhs) { - return lhs.chunk.offset < rhs.chunk.offset; - }; - auto end_cmp = [](const MemorySpaceAssignment::SliceDecision& lhs, - const MemorySpaceAssignment::SliceDecision& rhs) { - return lhs.chunk.chunk_end() < rhs.chunk.chunk_end(); - }; - return MemorySpaceAssignment::Chunk::FromOffsetEnd( - std::min_element(slice_decisions_sorted_by_start_time.begin(), - slice_decisions_sorted_by_start_time.end(), offset_cmp) - ->chunk.offset, - std::max_element(slice_decisions_sorted_by_start_time.begin(), - slice_decisions_sorted_by_start_time.end(), end_cmp) - ->chunk.chunk_end()); -} - -// Helper function to compute the start time for a SlicedCopyAllocation. -int64_t GetSlicedCopyAllocationExclusiveStartTime( - const std::vector& - slice_decisions_sorted_by_exclusive_start_time) { - if (slice_decisions_sorted_by_exclusive_start_time.empty()) { - return -1; - } - - return slice_decisions_sorted_by_exclusive_start_time.front() - .exclusive_start_time; -} - -} // namespace - -MemorySpaceAssignment::SlicedCopyAllocation::SlicedCopyAllocation( - const Allocation& prev_allocation, MemorySpace memory_space, - std::vector slice_decisions_sorted_by_exclusive_start_time, - int64_t copy_done_schedule_before_time, int64_t end_time) - : Allocation( - /*defining_position=*/{nullptr, {}}, memory_space, - GetSlicedCopyAllocationChunk( - slice_decisions_sorted_by_exclusive_start_time), - // Allocation uses an inclusive start time - ExclusiveToInclusiveStartTime( - GetSlicedCopyAllocationExclusiveStartTime( - slice_decisions_sorted_by_exclusive_start_time)), - end_time, - /*is_scoped_allocation=*/false), - original_shape_to_slice_(prev_allocation.defining_position().shape()), - prev_allocation_(prev_allocation) { - CHECK_GE(slice_decisions_sorted_by_exclusive_start_time.size(), 2); - slice_details_sorted_by_start_time_.reserve( - slice_decisions_sorted_by_exclusive_start_time.size()); - for (SliceDecision& decision : - slice_decisions_sorted_by_exclusive_start_time) { - int64_t copy_done_schedule_after_time = decision.exclusive_start_time; - slice_details_sorted_by_start_time_.push_back(SliceDetail{ - std::move(decision), - copy_done_schedule_after_time, - copy_done_schedule_before_time, - /*copy_start=*/nullptr, - /*copy_done=*/nullptr, - }); - } -} - -namespace { - -// Sets defining_position with the copy_complete instruction and replaces all -// uses of the allocation with the copy_complete instruction. -Status ProcessCopyLikeAllocationUses(HloPosition& defining_position, - std::vector& uses, - HloComputation* computation, - HloInstruction* copy_complete) { - // Update the allocation position with the copy complete instruction, so that - // if there are further copies from it, they can find the correct position. - defining_position = HloPosition{copy_complete, {}}; - - // Replace all the uses of the copy-like allocation with the copy complete - // instruction. - for (HloUse use : uses) { - // If the operand is a tuple, we need to descend to the actual instruction - // we want to replace. - HloInstruction* replacement_instruction = copy_complete; - Shape operand_shape = use.instruction->operand(use.operand_number)->shape(); - if (operand_shape.IsTuple()) { - TF_ASSIGN_OR_RETURN( - replacement_instruction, - TupleUtil::ReplaceTupleWith( - copy_complete, - use.instruction->mutable_operand(use.operand_number), - use.operand_index)); - } else if (operand_shape != copy_complete->shape()) { - // When processing allocations, we treat bitcasts as trivial positions and - // do not create allocations for them. We insert bitcasts after copies, to - // account for the fact that we don't have an allocation for the bitcast. - VLOG(4) << "Old shape = " << operand_shape.ToString() - << ", new shape = " << copy_complete->shape().ToString() - << "; inserting a bitcast."; - replacement_instruction = computation->AddInstruction( - HloInstruction::CreateBitcast(operand_shape, copy_complete)); - } - TF_RETURN_IF_ERROR(use.instruction->ReplaceOperandWith( - use.operand_number, replacement_instruction)); - } - - return OkStatus(); -} - -} // namespace - -Status MemorySpaceAssignment::SlicedCopyAllocation::Process() { - Shape shape = defining_position().shape(); - HloInstruction* producing_instruction = AddGetTupleElements(); - - // Calling Process() over the previous allocation might have modified the - // defining position, and hence the shape that was used when we computed - // the slices. In cases where the shape has changed, we insert a bitcast, so - // slice instructions operate on the originally sliced shape. - // - // Note, these bitcasts are being inserted in the same cases that - // ProcessCopyLikeAllocationUses() is inserting bitcasts, except we are - // inserting the bitcasts before the copy, instead of after the copy. - if (!Shape::Equal().IgnoreMemorySpaceInLayout()(shape, - original_shape_to_slice_)) { - int64_t new_memory_space = shape.layout().memory_space(); - shape = original_shape_to_slice_; - shape.mutable_layout()->set_memory_space(new_memory_space); - producing_instruction = producing_instruction->parent()->AddInstruction( - HloInstruction::CreateBitcast(shape, producing_instruction)); - } - - HloComputation* computation = producing_instruction->parent(); - std::vector slice_dones; - slice_dones.reserve(slice_details_sorted_by_start_time_.size()); - - // Sliced copy allocations need to insert asynchronous copy nodes. - for (SliceDetail& slice_detail : slice_details_sorted_by_start_time_) { - TF_RETURN_IF_ERROR(slice_detail.CreateAsyncSlice( - shape, *producing_instruction, *computation)); - VLOG(4) << "Created " << slice_detail.copy_start->name() - << " for sliced copy allocation: " << ToString(); - slice_dones.push_back(slice_detail.copy_done); - } - - TF_RETURN_IF_ERROR(CreateBitcastConcat(shape, slice_dones)); - - return ProcessCopyLikeAllocationUses(defining_position_, uses_, computation, - concat_); -} - -void MemorySpaceAssignment::SlicedCopyAllocation::MarkNeeded( - absl::flat_hash_set& needed_allocations) const { - needed_allocations.insert(this); - prev_allocation_.MarkNeeded(needed_allocations); -} - -HloPosition MemorySpaceAssignment::SlicedCopyAllocation::defining_position() - const { - // Unless explicitly set, the defining position of a sliced copy allocation is - // retrieved from the previous allocation. This is because we don't create - // new CopyStart/CopyDone instructions until later and the position should - // point to the previous (copy or otherwise) allocation's position for the - // original defining position. - if (defining_position_.instruction == nullptr) { - return prev_allocation_.defining_position(); - } - return defining_position_; -} - -int64_t MemorySpaceAssignment::SlicedCopyAllocation::earliest_available_time() - const { - return slice_details_sorted_by_start_time().back().copy_done_before_time; -} - -std::vector -MemorySpaceAssignment::SlicedCopyAllocation::SliceOffsetsSortedByStartTime() - const { - std::vector offsets; - offsets.reserve(slice_details_sorted_by_start_time_.size()); - - for (const SliceDetail& slice_detail : slice_details_sorted_by_start_time_) { - offsets.push_back(slice_detail.slice_decision.chunk.offset); - } - - return offsets; -} - -void MemorySpaceAssignment::SlicedCopyAllocation::AddDiffToAllSliceOffsets( - int64_t diff) { - for (SliceDetail& slice_detail : slice_details_sorted_by_start_time_) { - Chunk& chunk = slice_detail.slice_decision.chunk; - chunk = Chunk::FromOffsetSize(chunk.offset + diff, chunk.size); - } -} - -void MemorySpaceAssignment::SlicedCopyAllocation::ImportRepackedSliceData( - const SlicedAllocationData& data) { - int num_slices = slice_details_sorted_by_start_time_.size(); - CHECK_EQ(data.slices_sorted_by_offset.size(), num_slices); - - std::vector slice_details_sorted_by_offset; - slice_details_sorted_by_offset.reserve(num_slices); - for (SliceDetail& slice_detail : slice_details_sorted_by_start_time_) { - slice_details_sorted_by_offset.push_back(&slice_detail); - } - absl::c_sort(slice_details_sorted_by_offset, [](const SliceDetail* lhs, - const SliceDetail* rhs) { - return lhs->slice_decision.chunk.offset < rhs->slice_decision.chunk.offset; - }); - - for (int i = 0; i < num_slices; ++i) { - SliceDetail* slice_detail = slice_details_sorted_by_offset[i]; - Chunk& chunk = slice_detail->slice_decision.chunk; - const AllocatedSlice& repacked_slice_data = data.slices_sorted_by_offset[i]; - chunk = Chunk::FromOffsetSize(repacked_slice_data.offset, chunk.size); - slice_detail->copy_start_after_time = - repacked_slice_data.inclusive_start_time - 1; - slice_detail->slice_decision.exclusive_start_time = - InclusiveToExclusiveStartTime(repacked_slice_data.inclusive_start_time); - } - - absl::c_sort(slice_details_sorted_by_start_time_, - [](const SliceDetail& lhs, const SliceDetail& rhs) { - return std::make_tuple(lhs.copy_start_after_time, - lhs.slice_decision.chunk.offset) < - std::make_tuple(rhs.copy_start_after_time, - rhs.slice_decision.chunk.offset); - }); -} - -const std::vector& -MemorySpaceAssignment::SlicedCopyAllocation:: - slice_details_sorted_by_start_time() const { - return slice_details_sorted_by_start_time_; -} - -std::vector& -MemorySpaceAssignment::SlicedCopyAllocation:: - mutable_slice_details_sorted_by_start_time() { - return slice_details_sorted_by_start_time_; -} - -std::tuple&, - const HloInstruction*> -MemorySpaceAssignment::SlicedCopyAllocation::ToTuple() const { - return std::make_tuple( - std::ref(*this), std::ref(slice_details_sorted_by_start_time_), concat_); -} - -bool MemorySpaceAssignment::SlicedCopyAllocation::operator==( - const SlicedCopyAllocation& other) const { - return ToTuple() == other.ToTuple(); -} - -std::string MemorySpaceAssignment::SlicedCopyAllocation::ToString() const { - std::string memory_space_str = "def"; - if (memory_space_ == MemorySpace::kAlternate) { - memory_space_str = absl::StrCat("alt (off: ", chunk_->offset, ")"); - } - return absl::StrCat( - "Sliced Copy Allocation in ", memory_space_str, - ", start_time:", start_time(), ", end_time:", end_time(), - ", first_slice_copy_start_after_time: ", - slice_details_sorted_by_start_time().front().copy_start_after_time, - ", last_slice_copy_done_before_time: ", - slice_details_sorted_by_start_time().back().copy_done_before_time, - ", uses: ", UsesToString(uses()), ", from ", prev_allocation_.ToString()); -} - -Status MemorySpaceAssignment::SlicedCopyAllocation::CreateBitcastConcat( - const Shape& shape, absl::Span slices) { - CHECK(!slices.empty()); - concat_ = - slices.front()->parent()->AddInstruction(HloInstruction::CreateCustomCall( - shape, slices, kConcatBitcastCustomCall)); - return OkStatus(); -} - -std::string MemorySpaceAssignment::MirroredAllocation::ToString() const { - return absl::StrCat("Mirrored Allocation for ", - original_allocation_.ToString()); -} - -std::string MemorySpaceAssignment::ParentAllocation::ToString() const { - return absl::StrCat("Parent Allocation mirrored at ", - defining_position_.ToString(), ", originally ", - original_allocation_.ToString()); -} - -MemorySpaceAssignment::CopyAllocation::CopyAllocation( - Allocation& prev_allocation, MemorySpace memory_space, - std::optional chunk, int64_t copy_start_schedule_after_time, - int64_t copy_done_schedule_before_time, int64_t end_time, - std::optional cross_program_prefetch_index) - : Allocation(/*defining_position=*/{nullptr, {}}, memory_space, chunk, - // Allocation uses an inclusive start time - ExclusiveToInclusiveStartTime(copy_start_schedule_after_time), - end_time, - /*is_scoped_allocation=*/false), - prev_allocation_(prev_allocation), - copy_start_schedule_after_(copy_start_schedule_after_time), - copy_done_schedule_before_(copy_done_schedule_before_time), - cross_program_prefetch_index_(cross_program_prefetch_index) {} - -Status MemorySpaceAssignment::CopyAllocation::Process() { - // Copy allocations need to insert asynchronous copy nodes. - Shape shape = defining_position().shape(); - HloInstruction* producing_instruction = AddGetTupleElements(); - HloComputation* computation = producing_instruction->parent(); - copy_start_ = computation->AddInstruction(HloInstruction::CreateCopyStart( - ShapeUtil::MakeTupleShape({shape, shape, ShapeUtil::MakeShape(U32, {})}), - producing_instruction, cross_program_prefetch_index_)); - copy_done_ = computation->AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start_)); - VLOG(4) << "Created " << copy_start_->name() - << " for copy allocation: " << ToString(); - - return ProcessCopyLikeAllocationUses(defining_position_, uses_, computation, - copy_done_); -} - -Status MemorySpaceAssignment::MirroredAllocation::Process() { - defining_position_ = original_allocation_.defining_position(); - return Allocation::Process(); -} - -Status MemorySpaceAssignment::ParentAllocation::Process() { - // Add an additional parameter to the while HLO with a reference to the buffer - // in the default memory space. - HloInstruction* producing_instruction = - original_allocation_.AddGetTupleElements(); - int new_tuple_index = calling_instruction_->shape().tuple_shapes_size(); - - TF_ASSIGN_OR_RETURN( - HloInstruction * new_while_operand, - TupleUtil::ReplaceTupleWith(producing_instruction, - calling_instruction_->mutable_operand(0), - {new_tuple_index})); - TF_RETURN_IF_ERROR(calling_instruction_->ReplaceOperandWithDifferentShape( - 0, new_while_operand)); - *calling_instruction_->mutable_shape() = new_while_operand->shape(); - *calling_instruction_->while_condition() - ->parameter_instruction(0) - ->mutable_shape() = new_while_operand->shape(); - *calling_instruction_->while_body() - ->parameter_instruction(0) - ->mutable_shape() = new_while_operand->shape(); - defining_position_.index = {new_tuple_index}; - // Also replace the while op with a tuple that has the old shape. Note that we - // need to first take a snapshot of the users before calling ExtractPrefix - // since ExtractPrefix introduces additional gte users. - std::vector while_users = calling_instruction_->users(); - HloInstruction* tuple_with_old_shape = - TupleUtil::ExtractPrefix(calling_instruction_, new_tuple_index); - TF_RETURN_IF_ERROR(calling_instruction_->ReplaceAllUsesWithDifferentShape( - while_users, tuple_with_old_shape)); - return Allocation::Process(); -} - -Status MemorySpaceAssignment::ParentAllocation::PostProcess() { - // Update the root of the while body with the new parameter. The reason why we - // need a separate post-process for this is because other allocations may have - // while body root as a use, so they would update the old root instead of the - // new root. Doing the post-process step later ensures the root has been - // updated with other changes, and we can safely add the additional parameter. - HloComputation* while_body = calling_instruction_->while_body(); - TF_ASSIGN_OR_RETURN(HloInstruction * new_while_body_root, - TupleUtil::ReplaceTupleWith( - AddGetTupleElements(), while_body->root_instruction(), - defining_position_.index)); - while_body->set_root_instruction(new_while_body_root, - /*accept_different_shape=*/true); - return OkStatus(); -} - -void MemorySpaceAssignment::Allocation::MarkIfNeeded( - absl::flat_hash_set& needed_allocations) const { - MarkNeeded(needed_allocations); -} - -void MemorySpaceAssignment::Allocation::MarkNeeded( - absl::flat_hash_set& needed_allocations) const { - needed_allocations.insert(this); -} - -void MemorySpaceAssignment::CopyAllocation::MarkNeeded( - absl::flat_hash_set& needed_allocations) const { - needed_allocations.insert(this); - prev_allocation_.MarkNeeded(needed_allocations); -} - -void MemorySpaceAssignment::ParentAllocation::MarkIfNeeded( - absl::flat_hash_set& needed_allocations) const { - // Parent allocations are only needed if they have any uses or if there is a - // copy allocation that copies this value (in that case, the copy allocation - // will call this allocation's MarkNeeded function). - if (!uses_.empty()) { - MarkNeeded(needed_allocations); - } -} - -void MemorySpaceAssignment::ParentAllocation::MarkNeeded( - absl::flat_hash_set& needed_allocations) const { - needed_allocations.insert(this); - original_allocation_.MarkNeeded(needed_allocations); -} - -void MemorySpaceAssignment::MirroredAllocation::MarkNeeded( - absl::flat_hash_set& needed_allocations) const { - needed_allocations.insert(this); - original_allocation_.MarkNeeded(needed_allocations); -} - Status MemorySpaceAssignment::Process(const HloLiveRange& hlo_live_range) { VLOG(1) << "Processing assigned buffers..."; // Since some parent allocations may not be needed (e.g. when they don't have @@ -8579,8 +6961,7 @@ class AsyncCopyStep { class AsyncCopyStepForCopyAllocation : public AsyncCopyStep { public: - explicit AsyncCopyStepForCopyAllocation( - MemorySpaceAssignment::CopyAllocation* copy_allocation) + explicit AsyncCopyStepForCopyAllocation(CopyAllocation* copy_allocation) : AsyncCopyStep(), copy_allocation_(copy_allocation) {} ~AsyncCopyStepForCopyAllocation() override = default; @@ -8606,14 +6987,13 @@ class AsyncCopyStepForCopyAllocation : public AsyncCopyStep { } private: - MemorySpaceAssignment::CopyAllocation* copy_allocation_ = nullptr; + CopyAllocation* copy_allocation_ = nullptr; }; class AsyncCopyStepForSlice : public AsyncCopyStep { public: - AsyncCopyStepForSlice( - MemorySpaceAssignment::SlicedCopyAllocation* sliced_copy_allocation, - size_t slice_index) + AsyncCopyStepForSlice(SlicedCopyAllocation* sliced_copy_allocation, + size_t slice_index) : AsyncCopyStep(), sliced_copy_allocation_(sliced_copy_allocation), slice_index_(slice_index) {} @@ -8625,10 +7005,9 @@ class AsyncCopyStepForSlice : public AsyncCopyStep { } std::optional start_phase() const override { - const MemorySpaceAssignment::SlicedCopyAllocation::SliceDetail& - slice_details = - sliced_copy_allocation_ - ->slice_details_sorted_by_start_time()[slice_index_]; + const SlicedCopyAllocation::SliceDetail& slice_details = + sliced_copy_allocation_ + ->slice_details_sorted_by_start_time()[slice_index_]; StartPhase phase{slice_details.copy_start_after_time, slice_details.copy_start}; @@ -8642,10 +7021,9 @@ class AsyncCopyStepForSlice : public AsyncCopyStep { } DonePhase done_phase() const override { - const MemorySpaceAssignment::SlicedCopyAllocation::SliceDetail& - slice_details = - sliced_copy_allocation_ - ->slice_details_sorted_by_start_time()[slice_index_]; + const SlicedCopyAllocation::SliceDetail& slice_details = + sliced_copy_allocation_ + ->slice_details_sorted_by_start_time()[slice_index_]; DonePhase phase{slice_details.copy_done_before_time, slice_details.copy_done}; @@ -8653,15 +7031,14 @@ class AsyncCopyStepForSlice : public AsyncCopyStep { } private: - MemorySpaceAssignment::SlicedCopyAllocation* sliced_copy_allocation_ = - nullptr; + SlicedCopyAllocation* sliced_copy_allocation_ = nullptr; size_t slice_index_; }; class AsyncCopyStepForSliceConcat : public AsyncCopyStep { public: explicit AsyncCopyStepForSliceConcat( - MemorySpaceAssignment::SlicedCopyAllocation* sliced_copy_allocation) + SlicedCopyAllocation* sliced_copy_allocation) : AsyncCopyStep(), sliced_copy_allocation_(sliced_copy_allocation) {} ~AsyncCopyStepForSliceConcat() override = default; @@ -8682,8 +7059,7 @@ class AsyncCopyStepForSliceConcat : public AsyncCopyStep { } private: - MemorySpaceAssignment::SlicedCopyAllocation* sliced_copy_allocation_ = - nullptr; + SlicedCopyAllocation* sliced_copy_allocation_ = nullptr; }; } // namespace @@ -9115,12 +7491,22 @@ DefaultCrossProgramPrefetchBufferIntervalComparator::GetTuple( MemoryBoundednessBufferIntervalComparator:: MemoryBoundednessBufferIntervalComparator( - const MemorySpaceAssignmentCostAnalysis& cost_analysis, - MemorySpaceAssignmentCostAnalysis::Cache* cost_analysis_cache) + const CostAnalysis& cost_analysis, + CostAnalysis::Cache* cost_analysis_cache) : MemorySpaceAssignment::BufferIntervalComparator(), cost_analysis_(cost_analysis), cost_analysis_cache_(cost_analysis_cache) {} +MemoryBoundednessBufferIntervalComparator:: + MemoryBoundednessBufferIntervalComparator( + const CostAnalysis& cost_analysis, + CostAnalysis::Cache* cost_analysis_cache, + MsaSortOrderOverrides msa_sort_order_overrides) + : MemorySpaceAssignment::BufferIntervalComparator(), + cost_analysis_(cost_analysis), + cost_analysis_cache_(cost_analysis_cache), + msa_sort_order_overrides_(msa_sort_order_overrides) {} + std::string MemoryBoundednessBufferIntervalComparator::DescribeComparisonCriteria() const { return "[override priority, -memory boundedness, -size, -buffer duration, " @@ -9162,7 +7548,7 @@ MemoryBoundednessBufferIntervalComparator::ComparisonTuple MemoryBoundednessBufferIntervalComparator::GetTuple( const BufferInterval& buffer_interval) { int64_t priority = GetBufferIntervalOverridePriority( - cost_analysis_.options().msa_sort_order_overrides, buffer_interval); + msa_sort_order_overrides_, buffer_interval); float inverse_memory_boundedness = -1.0 * cost_analysis_.GetMemoryBoundedness(buffer_interval, cost_analysis_cache_); diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.h b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.h index eadb1d01a3ea2a..a5f9a5a3f42ef5 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.h +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.h @@ -13,6 +13,156 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +/* +Quick reference + +This section is meant as to be a quick reference for getting the gist of +commonly used terminology in the code and logging. Please see the code for more +details. + +General concepts + + - Time: In MSA, time typically refers to an index into the flattened + instruction schedule. + + - Cross-program prefetch: Cross-program prefetched tensors are copied from + memory to alternate the first time a program executes, like usual + prefetches. MSA keeps these buffers alive in alternate memory at the end of + the program, such that if the same program is executed again, these tensors + would not need to be prefetched again. + +Classes + + - HloPosition (Hlo dataflow analysis concept): Identifies a tensor referenced + in an instruction's output. Defined by . + + - HloValue (Hlo dataflow analysis concept): The value of a tensor. Each + HloValue is represented by a collection of HloPositions. Exactly 1 of those + positions is the HloValue's defining position, i.e., the point in code where + the value is created/written. The rest of the positions pertain to read-only + uses of the value. + * Example: A tensor that is inserted in a Tuple has 2 HloPositions, one for + the instruction that creates the tensor, and one indexing into the Tuple + instruction result. + * The read-only positions of an HloValue should not be confused with + HloUses. Read-only positions are references to the HloValue in the output + of an instruction. Uses are references to an HloValue in the input of an + instruction. + * Dataflow analysis assigns HloValues for the instructions in computations + pertaining to while loops, conditionals, and call ops. However, it does + not assign HloValues to the computations pertaining to instructions with + "call" semantics (e.g., fusions, reduce, and custom-call) because those + computations are treated as black boxes. + * If a while loop does not modify an input tensor, that tensor will be + assigned 1 HloValue that lasts from its creation point through the while + loop. + * If a while loop modifies one of its input tensors, that tensor will + receive at least the following HloValues: + - An HloValue for the tensor's creation, with a use at the operand of the + while instruction. + - An HloValue with its defining position at the while body's parameter. + - An HloValue whose defining position is an instruction in the while body + that feeds the new tensor value to the body's ROOT instruction. + - An HloValue with its defining position at the while instruction's + result. + + - HloBuffer (Hlo alias analysis concept): A memory container that holds one + or more HloValues that must alias. Typically, each HloValue corresponds to + 1 HloBuffer; however, many exceptions exist. For example, tensors that are + modified by a while loop have their HloValues share an HloBuffer, for the + HloValues that come immediately before, during, and immediately after the + loop. HloBuffers are shared between HloValues wherever their is aliasing, + whether implicit by the nature of the instruction (e.g., + dynamic-update-slice) or explicit (e.g., fusion input-output aliasing). + + - BufferInterval (HeapSimulator concept): A BufferInterval is defined by a + buffer of a given size, with a defined lifetime. In MSA, the buffer + corresponds to an HloValue. + + - AllocationValue: An AllocationValue is defined by an HloValue, and *one* of + its HloPositions. + * We do not create AllocationValues for non-trivial HloPositions, e.g., ones + defined by Tuple, GetTupleElement, and Bitcast instructions. + * The HloPosition used to define the AllocationValue is referred to as the + AllocationValue's defining position. + * Typically, this is also the defining position of the HloValue. However, + it may not be. For example, we would create an AllocationValue with an + HloPosition of a read-only while loop parameter, but the HloValue + corresponding to that HloPosition would have a different defining + position. + * The uses of an AllocationValue are limited to the direct uses of the + AllocationValue's defining position. + * An AllocationValue is associated with an AllocationSequence, describing + what to do with the underlying tensor, in memory, over the lifetime of the + AllocationValue. + + - (Use) Segment: Each AllocationValue and its uses are separated into periods + of time called use segments. The first use segment is from the (inclusive) + time of the AllocationValue's defining position to its first use + (inclusive). The second use segment is from the first use (inclusive) to + the second use (inclusive), etc. + + - AllocationRequest: A request to determine what to do with an + AllocationValue, in memory, during a use segment. It also contains + restrictions and preferences on what to do. + * A request results in updates to the AllocationValue's AllocationSequence. + It may add Allocations, or modify existing Allocations in the sequence. + + - Allocation: A description of what to do with an AllocationValue in memory, + over a period of time. + * Pure virtual base class of all Allocations. + + - AllocationSequence: A sequential list of Allocations, explaining what to do + with an AllocationValue over its lifetime. Allocations in the sequence may + overlap. + + - Pinned Allocation: Represents producing a tensor in a particular memory + space, or keeping a tensor in a memory space in which it already exists. + + - Copy Allocation: Instructions to copy an AllocationValue from one memory + space to another. Used for prefetching (default mem -> alt mem), and + eviction (alt mem -> default mem). + * A copy Allocation contains a copy_done_schedule_before_time. The buffer is + available for use at that schedule time, through the Allocation's + end_time. + + - Sliced Copy Allocation: Similar to a Copy Allocation, except the memory is + copied in slices, in an effort to delay allocating memory in the destination + memory space, for as long as possible. + + - Mirrored Allocation and Parent Allocation: R/W tensors passed to while loops + typically have at least 3 AllocationValues, 1 for the producer of the tensor + before the while loop, 1 for the while loop's body parameter, and 1 for the + result of the while loop. There are situations heading into a while loop, in + which the while loop input is both in alternate memory and default memory. + (For example, this could happen beause we want the buffer in alternate + memory for the while loop and default memory after the while loop, but we + don't have resources to evict the buffer after the while loop.) In those + cases, we use a mirrored allocation for the AllocationValue inside the + while loop, to mirror the allocation in default memory. We use a parent + allocation for the AllocationValue resulting from the while loop result. + +Useful logging and error messages + + - Live range too long: The live range of a use segement is too long to for an + alternate memory no copy, i.e., its longer than we want to keep a buffer in + alternate memory wihtout being used. + * If the CostAnalysisPrefetchIntervalPicker is used, which is the default, + live range too long is governed by the picker's + max_overlap_to_mem_size_async_copy_ratio argument. + + - Live range too short: The live range of a use segement is too short to + prefetch a buffer to alternate memory, according to some heuristic and not + based on limited copy resource. + * If the CostAnalysisPrefetchIntervalPicker is used, which is the default, + live range too long is governed by the picker's + min_overlap_to_async_copy_ratio argument. + + - "Finding allocation for": Magical logging phrase indicating the point in + time where we are are trying to determine how to update an AllocationValue's + AllocationSequenece, for a particular use segment. +*/ + #ifndef XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_MEMORY_SPACE_ASSIGNMENT_H_ #define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_MEMORY_SPACE_ASSIGNMENT_H_ @@ -41,7 +191,6 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/functional/any_invocable.h" -#include "absl/functional/function_ref.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -52,10 +201,13 @@ limitations under the License. #include "xla/service/hlo.pb.h" #include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_buffer.h" -#include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_value.h" +#include "xla/service/memory_space_assignment/allocation.h" +#include "xla/service/memory_space_assignment/cost_analysis.h" #include "xla/service/memory_space_assignment/memory_space_assignment.pb.h" +#include "xla/service/memory_space_assignment/prefetch_interval_picker.h" #include "xla/service/memory_space_assignment/repacking.h" +#include "xla/service/memory_space_assignment/slice.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/statusor.h" @@ -67,8 +219,6 @@ namespace memory_space_assignment { // Forward Declaration of Options. struct Options; -inline constexpr char kConcatBitcastCustomCall[] = "ConcatBitcast"; - // This class contains pre-set assignments determined by memory space // assignment. It contains two data structures: (1) a chunks vector that maps a // defining HloPosition to a Chunk (offset and size), and (2) an assignment_info @@ -139,428 +289,6 @@ class PresetAssignments { std::string instruction_schedule_str_; }; -// A wrapper class around HloCostAnalysis with additional knowledge about the -// bandwidths of different memory spaces. -class MemorySpaceAssignmentCostAnalysis { - public: - // An optional Cache object may be provided to some of the methods below to - // speed up the lookup. - struct Cache { - absl::flat_hash_map while_nest_multiplier; - absl::flat_hash_map memory_boundedness; - }; - - // Function type that can be used to indicate which input/output values are in - // the alternate memory. - using IsInAlternateMemoryFun = absl::FunctionRef /*operand_num*/, const ShapeIndex& /*index*/, - const Shape& /*shape*/)>; - - virtual ~MemorySpaceAssignmentCostAnalysis() = default; - - static StatusOr> Create( - const HloCostAnalysis& cost_analysis, const Options& options, - const HloModule& module); - - const HloCostAnalysis& cost_analysis() const { return cost_analysis_; } - - // Returns a heuristic value that captures how much putting this tensor to the - // alternate memory would help if the op is memory bound, or otherwise how far - // off is the op to memory boundedness. The larger this number, the higher - // priority it will be placed in the alternate memory. - float GetAlternateMemoryBenefit(const HloInstruction& instruction, - float elapsed_time_due_to_alternate_mem, - Cache* cache = nullptr) const; - // Like above, return the benefit of putting the output tensor in the - // alternate memory. - float GetAlternateMemoryBenefit(const HloPosition& position, - Cache* cache = nullptr) const; - // Like above, return the benefit of putting the input tensor in the alternate - // memory. - float GetAlternateMemoryBenefit(const HloUse& use, - Cache* cache = nullptr) const; - - // Returns a heuristic value of memory boundedness for the given - // BufferInterval. The larger this number, the higher priority it will be - // placed in the alternate memory. - float GetMemoryBoundedness( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval, - Cache* cache = nullptr) const; - - // If enabled in Options::pipeline_overhead_window_size_mib, returns the - // overhead of accessing the default memory, in seconds. The source of the - // overhead is the software pipelining ovehead. The lowering of the operations - // typically use tiling to copy one window at a time from default memory, and - // perform compute: - // - // Pipeline overhead: <-> - // +----+----+----+----+ - // Copy from default mem: | | | | | - // +----+----+----+----+ - // \ \ \ \ - // \ \ \ \ - // V V V V - // +--+ +--+ +--+ +--+ - // Compute: | | | | | | | | - // +--+ +--+ +--+ +--+ - float GetDefaultMemoryAccessOverhead( - const HloInstruction& instruction, - absl::Span> - operands_in_alternate_mem = {}, - absl::Span outputs_in_alternate_mem = {}) const; - - // Returns the amount of time the default memory bandwidth is idle, while - // executing this instruction, in seconds. This value can be multiplied with - // the default memory bandwidth to get the amount of bytes that are available - // to be copied to/from default memory during the execution of this - // instruction. - float GetDefaultMemoryBandwidthIdleTime( - const HloInstruction& instruction, - absl::Span> - operands_in_alternate_mem = {}, - absl::Span outputs_in_alternate_mem = {}) const; - - // Returns the bytes accessed from alternate memory. - float GetBytesAccessedFromAlternateMemory( - const HloInstruction& instruction, - absl::Span> - operands_in_alternate_mem = {}, - absl::Span outputs_in_alternate_mem = {}) const; - - // Returns the elapsed time in seconds due to compute only. - float GetInstructionElapsedDueToCompute( - const HloInstruction& instruction) const; - - // Returns the elapsed time in seconds due to memory only. If - // operands_in_alternate_mem or outputs_in_alternate_mem is provided, it will - // assume that the corresponding operands or output will be in the alternate - // memory space. This is useful for calculating the benefit of placing the - // buffer in alternate memory. - float GetInstructionElapsedDueToMemory( - const HloInstruction& instruction, - absl::Span> - operands_in_alternate_mem = {}, - absl::Span outputs_in_alternate_mem = {}) const; - - // Like above, only the inputs/outputs indicated by is_in_alternate_mem are in - // the alternate memory. - float GetInstructionElapsedDueToMemory( - const HloInstruction& instruction, - IsInAlternateMemoryFun is_in_alternate_mem) const; - - // Returns the estimated elapsed duration of the instruction in seconds. It - // assumes all operands and outputs of the instruction are in the default - // memory. - virtual float GetInstructionElapsed(const HloInstruction& instruction) const; - - // Returns the estimated elapsed duration of the instruction in seconds. It - // assumes all operands and outputs of the instruction are in the default - // memory, except for the operands and outputs specified to be in the - // alternate memory. - virtual float GetInstructionElapsedInAlternateMemory( - const HloInstruction& instruction, - absl::Span> - operands_in_alternate_mem, - absl::Span outputs_in_alternate_mem) const; - - // Like above, only the inputs/outputs indicated by is_in_alternate_mem are in - // the alternate memory. - float GetInstructionElapsedInAlternateMemory( - const HloInstruction& instruction, - IsInAlternateMemoryFun is_in_alternate_mem) const; - - // Returns the elapsed time it would take to asynchronously copy the shape - // from default to alternate memory space (or vice versa). - virtual float GetAsyncCopyElapsed(const Shape& shape) const; - - int64_t GetScheduleEndTime() const; - - // Returns the number of nested computation levels this instruction resides - // in. If while_only is true, it returns the while loop nest level and 0 - // means the instruction is not in a while loop. - int CalculateComputationNestLevel(const HloInstruction* instruction, - bool while_only) const; - - const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; } - const Options& options() const { return options_; } - - protected: - MemorySpaceAssignmentCostAnalysis( - const HloCostAnalysis& cost_analysis, const Options& options, - std::unique_ptr alias_analysis, - std::unique_ptr hlo_live_range, - std::unique_ptr call_graph) - : cost_analysis_(cost_analysis), - options_(options), - alias_analysis_(std::move(alias_analysis)), - hlo_live_range_(std::move(hlo_live_range)), - call_graph_(std::move(call_graph)) {} - - private: - const HloCostAnalysis& cost_analysis_; - const Options& options_; - std::unique_ptr alias_analysis_; - std::unique_ptr hlo_live_range_; - std::unique_ptr call_graph_; -}; - -// Abstract base class that memory space assignment uses to pick prefetch -// intervals. -class PrefetchIntervalPicker { - public: - PrefetchIntervalPicker() = default; - virtual ~PrefetchIntervalPicker() = default; - - // Returns true if the buffer can be allocated in alternate memory space - // without any copies (prefetches). - virtual bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape, - int64_t start_time, - int64_t end_time) const = 0; - - // Returns the preferred end time for an eviction that starts at a given time - // and must end by the given end time. - virtual int64_t PreferredEvictionEndTime(const Shape& shape, - int64_t start_time, - int64_t latest_end_time) const = 0; - - // Returns the latest time that a prefetch can start. - virtual int64_t LatestPrefetchStartTime(const Shape& shape, - int64_t start_time, int64_t end_time, - const HloUse* use) const = 0; - - // Returns the preferred time that a prefetch can start. - virtual int64_t PreferredPrefetchStartTime( - const Shape& shape, int64_t earliest_prefetch_start_time, - int64_t latest_prefetch_start_time, int64_t prefetch_end_time) const = 0; - - // Returns the latest time that a prefetch can end that is less than or equal - // to proposed_prefetch_end_time. - virtual int64_t LatestPrefetchEndTime( - int64_t original_prefetch_end_time, - int64_t proposed_prefetch_end_time) const { - return proposed_prefetch_end_time; - } - - // Returns the estimated end time of a prefetch that starts at the given time. - virtual int64_t EstimatedPrefetchEndTime(const Shape& shape, - int64_t start_time, - int64_t end_time) const = 0; - - // Returns the elapsed time in seconds between the logical interval that - // corresponds to the instruction schedule. - virtual float GetLogicalIntervalElapsed(int64_t start_time, - int64_t end_time) const = 0; - - // Begins the iterator for the first start time of the prefetch. - virtual void Begin(const HloUse& use, int64_t start_time, int64_t end_time, - std::optional preferred_time) = 0; - - // Advances the start time of the prefetch and returns that value. - virtual int64_t Next() = 0; - - // Returns true if the available prefetch intervals have been exhausted. - virtual bool Done() const = 0; - - // Returns the latest time the prefetch interval picker will have pick. - virtual int64_t latest_time() const = 0; - - // The retry number can be used to modify the interval picking policies. The - // first attempt will have a retry_number of 0, then 1, etc. - virtual void SetRetryNumber(int retry_number) { - retry_number_ = retry_number; - } - int retry_number() const { return retry_number_; } - - // Returns a debug string for the current state of the prefetch interval - // picker. - virtual std::string ToDebugString() const = 0; - - // Returns a debug string for no-copy allocation. - virtual std::string ToNoCopyDebugString(const Shape& shape, - int64_t start_time, - int64_t end_time) const = 0; - - // Prefetch interval pickers may return a value corresponding to the benefit - // of placing the BufferInterval in the alternate memory. The larger value, - // the more beneficial. - virtual std::optional BufferIntervalAlternateMemoryBenefit( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) - const { - return std::nullopt; - } - - protected: - const absl::flat_hash_map* - instruction_schedule_ = nullptr; - int retry_number_ = 0; -}; - -// Prefetch interval picker that uses instruction count to overlap asynchronous -// copies with independent computation. The min and max overlap counts describe -// the number of independent HLOs overlapped while a value is being prefetched -// into the alternate memory (between CopyStart and CopyDone HLO instructions). -// max_overlap_count attempts to prevent bringing tensors into the alternate -// memory too eagerly and hence occupying the space for other tensors which -// might use it. min_overlap_count attempts to prevent cases where tensors are -// prefetched into the alternate memory without sufficient time for the copy to -// take place. In those cases, it's just better to keep the tensor in the -// default memory instead of hurting the critical path with this copy that -// likely won't finish in time. -class InstructionCountPrefetchIntervalPicker : public PrefetchIntervalPicker { - public: - InstructionCountPrefetchIntervalPicker(int64_t min_overlap_count, - int64_t max_overlap_count) - : min_overlap_count_(min_overlap_count), - max_overlap_count_(max_overlap_count) {} - - bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape, - int64_t start_time, - int64_t end_time) const override; - - int64_t PreferredEvictionEndTime(const Shape& shape, int64_t start_time, - int64_t latest_end_time) const override; - - int64_t LatestPrefetchStartTime(const Shape& shape, int64_t start_time, - int64_t end_time, - const HloUse* use) const override; - - int64_t PreferredPrefetchStartTime(const Shape& shape, - int64_t earliest_prefetch_start_time, - int64_t latest_prefetch_start_time, - int64_t prefetch_end_time) const override; - - int64_t EstimatedPrefetchEndTime(const Shape& shape, int64_t start_time, - int64_t end_time) const override; - float GetLogicalIntervalElapsed(int64_t start_time, - int64_t end_time) const override; - - void Begin(const HloUse& use, int64_t start_time, int64_t end_time, - std::optional preferred_time) override; - - int64_t Next() override; - bool Done() const override; - - int64_t latest_time() const override; - - std::string ToDebugString() const override; - std::string ToNoCopyDebugString(const Shape& shape, int64_t start_time, - int64_t end_time) const override; - - private: - int64_t min_overlap_count_; - int64_t max_overlap_count_; - int64_t end_time_; - int64_t current_prefetch_time_; -}; - -// Forward Declaration of MemorySpaceAssignmentCostAnalysis -class MemorySpaceAssignmentCostAnalysis; -// Prefetch interval picker that uses cost analysis to overlap asynchronous -// copies with independent computation. It uses min (independent computation -// duration) / (asynchronous copy duration) ratio to guide whether the prefetch -// is within the lower bound. For the upper bound, it restricts the maximum -// duration that a buffer may occupy the alternate memory space as a multiple of -// the time it would take to copy a buffer that is the size of the alternate -// memory. It starts with the preferred ratio in Begin() and works its way for -// alternately earlier and later prefetches until hitting min and max ratios. -// The value for buffer size for max async copy is a mechanism to prevent -// copying small buffers between the two memories unnecessarily. For calculating -// the max time that the buffer can reside in alternate memory, we use the -// larger of this value and the actual size of the buffer. A shape override can -// also be provided which causes the interval picker to use that shape for async -// copy durations instead of the actual shape of the copy. -class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker { - public: - CostAnalysisPrefetchIntervalPicker( - const MemorySpaceAssignmentCostAnalysis& cost_analysis, - float min_overlap_to_async_copy_ratio, - float preferred_overlap_to_async_copy_ratio, - float max_overlap_to_mem_size_async_copy_ratio, int64_t mem_size_bytes, - const Shape* shape_override = nullptr); - - bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape, - int64_t start_time, - int64_t end_time) const override; - - int64_t PreferredEvictionEndTime(const Shape& shape, int64_t start_time, - int64_t latest_end_time) const override; - - int64_t LatestPrefetchEndTime( - int64_t original_prefetch_end_time, - int64_t proposed_prefetch_end_time) const override; - - int64_t LatestPrefetchStartTime(const Shape& shape, int64_t start_time, - int64_t end_time, - const HloUse* use) const override; - - int64_t PreferredPrefetchStartTime(const Shape& shape, - int64_t earliest_prefetch_start_time, - int64_t latest_prefetch_start_time, - int64_t prefetch_end_time) const override; - - int64_t EstimatedPrefetchEndTime(const Shape& shape, int64_t start_time, - int64_t end_time) const override; - float GetLogicalIntervalElapsed(int64_t start_time, - int64_t end_time) const override; - - void Begin(const HloUse& use, int64_t start_time, int64_t end_time, - std::optional preferred_time) override; - - int64_t Next() override; - bool Done() const override; - - int64_t latest_time() const override; - - void SetRetryNumber(int retry_number) override; - - std::string ToDebugString() const override; - std::string ToNoCopyDebugString(const Shape& shape, int64_t start_time, - int64_t end_time) const override; - - std::optional BufferIntervalAlternateMemoryBenefit( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) - const override; - - private: - // Finds the minimum nest level in the given interval. - int GetMinWhileNestLevel(int64_t start_time, int64_t end_time) const; - - // Given the elapsed time to copy this buffer to the alternate memory, returns - // the longest time that this buffer may reside in the alternate memory space. - float GetMaxElapsedInAlternateMemory(float async_copy_elapsed) const; - - // For each instruction in the flattened schedule, maintain their elapsed time - // (in cumulative sum) and while nesting level. - std::vector elapsed_time_cumsum_; - std::vector while_nest_level_; - std::vector computation_nest_level_; - // Maintain the index of the most recent (before this instruction) nest level - // change in order to efficiently determine the minimum nest level in an - // interval. - std::vector while_nest_level_change_; - - const MemorySpaceAssignmentCostAnalysis& cost_analysis_; - float min_overlap_to_async_copy_ratio_; - float preferred_overlap_to_async_copy_ratio_; - float max_async_copy_elapsed_; - float max_overlap_multiplier_ = 1.0; - - float async_copy_elapsed_; - float inst_elapsed_reduction_; - int64_t end_logical_time_; - int64_t earliest_prefetch_time_; - int64_t latest_prefetch_time_; - bool using_increasing_prefetch_time_iterator_ = true; - int64_t increasing_prefetch_time_iterator_; - int64_t decreasing_prefetch_time_iterator_; - - std::vector while_execution_counts_; - // Shape override is used to override the shape of the shape of the async copy - // to treat all async copies the same duration. Having an override forces - // prefetches to be scheduled roughly in FIFO order. - std::optional shape_override_; -}; - // A class for turning a copy start time and end time into slice start times. class SlicedPrefetchStartTimePicker { public: @@ -614,6 +342,8 @@ class MemorySpaceAssignment { std::pair>& /*operands_in_alternate_memory*/, const absl::flat_hash_set& /*outputs_in_alternate_memory*/)>; + using AllocationSequence = std::vector>; + // The BufferInterval sorting interface that MemorySpaceAssignment expects. class BufferIntervalComparator { public: @@ -647,445 +377,6 @@ class MemorySpaceAssignment { BufferIntervalComparator() = default; }; - // MemorySpaceAssignment uses a notion of a slow and large default memory - // space and a fast and small alternate memory space. - enum class MemorySpace { kDefault, kAlternate }; - - // Forward declaration for Allocation. - class Allocation; - class ParentAllocation; - - // This class represents an allocation that might either be in the default or - // alternate memory. An HloValue might live in multiple different allocations - // over its lifetime. The lifetimes of the allocations are defined using - // start_time and end_time, which corresponds to the instruction indexes in - // the flattened schedule. Each of these allocations might partially overlap - // with each other. CopyAllocation defined below represents asynchronous - // copies between Allocations. - // - // Consider an instruction Foo, and its users Bar and Baz, and the times given - // in terms of the flattened schedule of the entire module: - // - // Foo:10 - // / \ - // Bar:14 \ - // Baz:25 - // - // A valid memory space assignment could be like the following: - // - // Time: 10 ... 14 ... 25 - // Foo Bar Baz - // Alternate +-------+ +-----+ - // Default +---------------------+ - // ^ ^ ^ ^ - // | | | | - // evict evict prefetch prefetch - // start end start end - // - // This would be represented with: - // - Allocation(memory_space=kAlternate, start_time=10, end_time=14) - // - CopyAllocation(memory_space=kDefault, start_time=12, end_time=25) - // - CopyAllocation(memory_space=kAlternate, start_time=22, end_time=25) - class Allocation { - friend class ParentAllocation; - - public: - Allocation(HloPosition defining_position, MemorySpace memory_space, - std::optional chunk, int64_t start_time, int64_t end_time, - bool is_scoped_allocation) - : defining_position_(defining_position), - memory_space_(memory_space), - chunk_(chunk), - start_time_(start_time), - end_time_(end_time), - is_scoped_allocation_(is_scoped_allocation) { - CHECK(!is_scoped_allocation || defining_position.index == ShapeIndex({})); - } - virtual ~Allocation() = default; - - // True if the allocation is for a copy or a sliced-copy. - bool is_copy_like_allocation() const; - - virtual bool is_copy_allocation() const { return false; } - virtual bool is_sliced_copy_allocation() const { return false; } - - // Adds a use to this allocation. - void AddUse(HloUse use); - - // Extends the end time of this allocation. - void Extend(int64_t end_time) { end_time_ = std::max(end_time_, end_time); } - - // After all of the time ranges for the allocations have been assigned, - // Process morphs the instructions affected to assign the memory spaces and - // insert asynchronous copy instructions if necessary. - virtual Status Process(); - - // An optional post-process step that will be called after all allocations - // have been processed. - virtual Status PostProcess() { return OkStatus(); } - - // Marks (adds this allocation to needed_allocations) if this allocation is - // needed. Allocation and CopyAllocations are always needed and - // ParentAllocations are needed if they have any uses or if other - // CopyAllocation or ParentAllocations depend on them. - virtual void MarkIfNeeded( - absl::flat_hash_set& needed_allocations) const; - - // Marks this allocation as needed. - virtual void MarkNeeded( - absl::flat_hash_set& needed_allocations) const; - - // Returns the defining position for this allocation. - virtual HloPosition defining_position() const { return defining_position_; } - - // Returns the time the buffer is first available to be used. For - // Allocation, this is start_time. - virtual int64_t earliest_available_time() const { return start_time_; } - - const std::vector& uses() const { return uses_; } - void clear_uses() { uses_.clear(); } - MemorySpace memory_space() const { return memory_space_; } - // Returns the associated chunk that may be a nullopt if the allocation is - // in the default memory space. - std::optional maybe_chunk() const { return chunk_; } - // Returns the associated chunk. The caller should ensure that the chunk is - // defined (the allocation should be in the alternate memory space). - Chunk chunk() const { - CHECK(chunk_.has_value()); - return *chunk_; - } - Chunk* mutable_chunk() { return &*chunk_; } - void set_offset(int64_t offset); - void set_start_time(int64_t start_time) { start_time_ = start_time; } - void set_end_time(int64_t end_time) { end_time_ = end_time; } - int64_t start_time() const { return start_time_; } - int64_t end_time() const { return end_time_; } - bool is_scoped_allocation() const { return is_scoped_allocation_; } - virtual std::optional cross_program_prefetch_index() const { - return std::nullopt; - } - - bool operator==(const Allocation& other) const; - virtual std::string ToString() const; - - bool is_in_alternate_mem() const { - return memory_space_ == MemorySpace::kAlternate; - } - bool is_in_default_mem() const { - return memory_space_ == MemorySpace::kDefault; - } - - protected: - // Recursively create kGetTupleElement instructions if the defining position - // shape is not an array. Returns the new instruction that has array shape. - HloInstruction* AddGetTupleElements() const; - - HloPosition defining_position_; - std::vector uses_; - MemorySpace memory_space_; - std::optional chunk_; - int64_t start_time_; - int64_t end_time_; - const bool is_scoped_allocation_; - }; - - // This class represents an allocation as a result of an asynchronous copy. - // Note: CopyStart instructions are inserted after - // `copy_start_schedule_after`, while CopyDone instructions are inserted - // before `copy_done_schedule_before_time`. - class CopyAllocation : public Allocation { - public: - // TODO(b/307342076): Reorder scheduling times to be - // copy_start_schedule_after_time, copy_done_schedule_before_time, end_time - CopyAllocation( - Allocation& prev_allocation, MemorySpace memory_space, - std::optional chunk, int64_t copy_start_schedule_after_time, - int64_t copy_done_schedule_before_time, int64_t end_time, - std::optional cross_program_prefetch_index = std::nullopt); - - bool is_copy_allocation() const override { return true; } - - Status Process() override; - - void MarkNeeded(absl::flat_hash_set& needed_allocations) - const override; - - HloPosition defining_position() const override { - // Unless explicitly set, the defining position of a copy allocation in - // retrieved from the previous allocation. This is because we don't create - // new CopyStart/CopyDone instructions until later and the position should - // point to the previous (copy or otherwise) allocation's position for the - // original defining position. - if (defining_position_.instruction == nullptr) { - return prev_allocation_.defining_position(); - } - return defining_position_; - } - - HloInstruction* copy_start() const { return copy_start_; } - HloInstruction* copy_done() const { return copy_done_; } - - // Returns the time the buffer is first available to be used. For - // CopyAllocation, this is when the copy ends, which is - // copy_done_schedule_before. - int64_t earliest_available_time() const override { - return copy_done_schedule_before_; - } - - int64_t copy_start_schedule_after() const { - return copy_start_schedule_after_; - } - int64_t copy_done_schedule_before() const { - return copy_done_schedule_before_; - } - - void set_copy_start_schedule_after(int64_t copy_start_schedule_after) { - copy_start_schedule_after_ = copy_start_schedule_after; - } - - void set_copy_done_schedule_before(int64_t copy_done_schedule_before) { - copy_done_schedule_before_ = copy_done_schedule_before; - } - - std::optional cross_program_prefetch_index() const override { - return cross_program_prefetch_index_; - } - - bool operator==(const CopyAllocation& other) const; - std::string ToString() const override; - - const Allocation& prev_allocation() { return prev_allocation_; } - Allocation& mutable_prev_allocation() { return prev_allocation_; } - - private: - Allocation& prev_allocation_; - // These variables define the scheduling boundaries where CopyStart and - // CopyDone can be scheduled. The earliest CopyStart can be scheduled is - // after copy_start_schedule_after_ and the latest CopyDone can be scheduled - // is before copy_done_schedule_before_. - int64_t copy_start_schedule_after_; - int64_t copy_done_schedule_before_; - HloInstruction* copy_start_; - HloInstruction* copy_done_; - std::optional cross_program_prefetch_index_; - }; - - // The parameters for slicing a single dimension of a tensor. - struct SliceParam { - std::string ToString() const; - bool operator==(const SliceParam& other) const; - - int64_t start_inclusive; - int64_t end_exclusive; - }; - - // A proposed way to slice a buffer. - struct SliceProposal { - std::string ToString() const; - friend std::ostream& operator<<(std::ostream& os, - const SliceProposal& proposal); - std::tuple&, int64_t> - ToTuple() const; - bool operator==(const SliceProposal& other) const; - - // Shape resulting from the slice. - Shape slice_shape; - - // slice_params map to the parameters that would be passed to a slice - // instruction. Thus: - // * There should be a slice parameter for every dimension in the shape of - // the tensor being sliced. - // * The ith slice_param applies to the ith logical dimension in the shape - // being sliced. - // * If a dimension is not being sliced, it should have a SliceParam of - // {0, dim size}. - std::vector slice_params; - - // The size to be allocated for the slice. Note, this may be > the size of - // the slice shape, due to additional padding that may occur when the slices - // are concatenated back together. - int64_t slice_size; - }; - - // A SliceProposalCollection proposes a way to to slice an AllocationRequest. - // A SliceProposalCollection is generated from a SliceProposalFunction and is - // used when we want to slice a prefetch. - using SliceProposalCollection = std::vector; - using SliceProposalFunction = std::function( - const Shape& shape, const SlicedPrefetchOptions& options)>; - - // A SliceDecision is a SliceProposal that we've determined where and when to - // allocate. - struct SliceDecision { - std::string ToString() const; - bool operator==(const SliceDecision& other) const; - - Chunk chunk; - int64_t exclusive_start_time; - SliceProposal sizing; - float copy_resource_consumed; - }; - - // This class represents an allocation resulting from asynchronous sliced - // copies. - // - // Let the sliced allocation be represented as follows, and imagine that t3 - // is the time when the entire buffer [p0, p3) is available for use - // - // space - // ^ - // p3 | +-----------+ - // | | | - // p2 | +---+ | - // | | | - // p1 | +-------+ | - // | | | - // p0 | +-------+ - // +---|---|---|---|---|----> time - // t0 t1 t2 t3 t4 - // - // The Allocation underlying the SlicedCopyAllocation will use the following - // dimensions: - // - chunk = [p0, p3) - // - start time = t2 - // - earliest_available_time = t3 - // - end_time = t4 - class SlicedCopyAllocation : public Allocation { - public: - // Full details about a slice in the sliced allocation. - struct SliceDetail { - std::string ToString() const; - std::tuple - ToTuple() const; - bool operator==(const SliceDetail& other) const; - - // Create the instructions to copy the slice. This method updates - // copy_start and copy_done. - Status CreateAsyncSlice(const Shape& original_shape, - HloInstruction& producer, HloComputation& parent); - - SliceDecision slice_decision; - int64_t copy_start_after_time = -1; - int64_t copy_done_before_time = -1; - HloInstruction* copy_start = nullptr; - HloInstruction* copy_done = nullptr; - }; - - // REQUIRES: - // - slice_decisions_sorted_by_start_time.size() >= 2, otherwise, - // CopyAllocation should be used. - SlicedCopyAllocation( - const Allocation& prev_allocation, MemorySpace memory_space, - std::vector slice_decisions_sorted_by_start_time, - int64_t copy_done_schedule_before_time, int64_t end_time); - - bool is_sliced_copy_allocation() const override { return true; } - - // MemorySpaceAssignment::Process() calls Process() to create asynchronous - // slice copies, and a bitcast-concat call to glue the slices back together. - Status Process() override; - - // Marks the allocation as needed. - void MarkNeeded(absl::flat_hash_set& needed_allocations) - const override; - - // Returns the defining position for this allocation. - HloPosition defining_position() const override; - - // Returns the time the buffer is first available to be used. For - // SlicedCopyAllocation, this is when all copies have ended. - int64_t earliest_available_time() const override; - - std::vector SliceOffsetsSortedByStartTime() const; - void AddDiffToAllSliceOffsets(int64_t diff); - - // Used to update offsets and start times after repacking. - void ImportRepackedSliceData(const SlicedAllocationData& data); - - const std::vector& slice_details_sorted_by_start_time() const; - std::vector& mutable_slice_details_sorted_by_start_time(); - HloInstruction* concat() const { return concat_; } - - std::tuple&, - const HloInstruction*> - ToTuple() const; - bool operator==(const SlicedCopyAllocation& other) const; - std::string ToString() const override; - - private: - SlicedCopyAllocation() = delete; - - // Create an instruction to concatenate the slices. Populates concat_. - Status CreateBitcastConcat(const Shape& shape, - absl::Span slices); - - Shape original_shape_to_slice_; - const Allocation& prev_allocation_; - // REQUIRES: - // - sorted_segments_[i].copy_start_after_time <= - // sorted_segments_[i+j].copy.start_after_time - // - sorted_segments_[i].copy_done_before_time <= - // sorted_segments_[i+j].copy.start_before_time - std::vector slice_details_sorted_by_start_time_; - HloInstruction* concat_ = nullptr; - }; - - // An allocation in the default memory space that mirrors another Allocation - // object. This is useful to model an eviction that happens before a while op - // so that we don't need to redundantly evict the buffer after the while op as - // well. - class MirroredAllocation : public Allocation { - public: - MirroredAllocation(const Allocation& original_allocation, int64_t time) - : Allocation(original_allocation.defining_position(), - MemorySpace::kDefault, original_allocation.maybe_chunk(), - /*start_time=*/time, - /*end_time=*/time, /*is_scoped_allocation=*/false), - original_allocation_(original_allocation) {} - - Status Process() override; - - void MarkNeeded(absl::flat_hash_set& needed_allocations) - const override; - - std::string ToString() const override; - - private: - const Allocation& original_allocation_; - }; - - // An allocation in default memory space that is defined in the parent - // computation. If a value has a copy in the default memory space in the - // parent computation, we don't need to evict this buffer in a while loop. - class ParentAllocation : public Allocation { - public: - ParentAllocation(const Allocation& original_allocation, - HloInstruction* calling_instruction, HloPosition position, - int64_t time) - : Allocation(position, MemorySpace::kDefault, - original_allocation.maybe_chunk(), /*start_time=*/time, - /*end_time=*/time, /*is_scoped_allocation=*/false), - original_allocation_(original_allocation), - calling_instruction_(calling_instruction) {} - - Status Process() override; - Status PostProcess() override; - - void MarkIfNeeded(absl::flat_hash_set& - needed_allocations) const override; - void MarkNeeded(absl::flat_hash_set& needed_allocations) - const override; - - std::string ToString() const override; - - private: - const Allocation& original_allocation_; - HloInstruction* calling_instruction_; - }; - - using AllocationSequence = std::vector>; // AllocationValue is used to break up HloValues for each non-trivial position // (trivial positions are considered Tuple, GetTupleElement, and Bitcast). An // HloValue may include positions and uses that alias with each other across @@ -1349,8 +640,13 @@ class MemoryBoundednessBufferIntervalComparator : public MemorySpaceAssignment::BufferIntervalComparator { public: MemoryBoundednessBufferIntervalComparator( - const MemorySpaceAssignmentCostAnalysis& cost_analysis, - MemorySpaceAssignmentCostAnalysis::Cache* cost_analysis_cache); + const CostAnalysis& cost_analysis, + CostAnalysis::Cache* cost_analysis_cache); + + MemoryBoundednessBufferIntervalComparator( + const CostAnalysis& cost_analysis, + CostAnalysis::Cache* cost_analysis_cache, + MsaSortOrderOverrides msa_sort_order_overrides); ~MemoryBoundednessBufferIntervalComparator() override = default; @@ -1367,8 +663,12 @@ class MemoryBoundednessBufferIntervalComparator ComparisonTuple GetTuple(const BufferInterval& buffer_interval); int64_t GetLatestUseTime(const BufferInterval& buffer_interval); absl::flat_hash_map buffer_to_latest_use_; - const MemorySpaceAssignmentCostAnalysis& cost_analysis_; - MemorySpaceAssignmentCostAnalysis::Cache* cost_analysis_cache_; + const CostAnalysis& cost_analysis_; + CostAnalysis::Cache* cost_analysis_cache_; + + // Config to override alternate memory assignment sorting order for filtered + // buffers. + MsaSortOrderOverrides msa_sort_order_overrides_; }; // The default BufferIntervalComparator used for cross-program prefetching. @@ -1424,11 +724,13 @@ struct Options { PrefetchIntervalPicker* prefetch_interval_picker = nullptr; // This object is used to determine the benefit of a particular allocation. - MemorySpaceAssignmentCostAnalysis* cost_analysis = nullptr; + CostAnalysis* cost_analysis = nullptr; // Size function for buffer values. BufferValue::SizeFunction size_fn; + std::function get_equivalent_s8_shape_fn; + // This function can be used to prevent certain HloValues (e.g., based on // the opcode) to be placed on the alternate memory. MemorySpaceAssignment::IsAllowedInAlternateMemoryFunction @@ -1483,21 +785,6 @@ struct Options { // greater than 0, repacker must be non-nullptr. int64_t max_repacks = 0; - // This variable is used by the cost analysis in estimating how many times - // each while loop will execute. Nested loops will be assumed to have - // executed pow(while_execution_count, nesting_level) times. - uint64_t xla_tpu_memory_space_assignment_while_execution_count = 5ULL; - - // This variable is used to scale the alternate memory benefit factor for - // large buffers. The default scaling function is sqrt. - std::string - xla_tpu_alternate_memory_benefit_scaling_factor_for_large_buffers = - "SQRT"; - - float async_copy_bandwidth_bytes_per_second = 0.0f; - - float alternate_mem_bandwidth_bytes_per_second = 0.0f; - // The repacking algorithm to reduce fragmentation. Must be non-null if // max_repacks is greater than 0. MemorySpaceAssignmentRepacker* repacker = nullptr; @@ -1543,9 +830,6 @@ struct Options { // to sort allocated buffers. std::optional> autotuning_config = std::nullopt; - // Scales effective bandwidth for async copies. Valid range is (0, 1]. - float async_copy_bandwidth_scaling_factor = 1.0; - // If true, uses the earlier instance of the same instruction to use as // preferred prefetch start time. bool use_repeated_instance_for_preferred_prefetch_time = false; @@ -1571,10 +855,6 @@ struct Options { absl::Span)> get_inefficient_allocation_sites_fn = nullptr; - // The window size used to calculate the pipeline overhead when HLO accesses - // the default memory, in MiB. - float pipeline_overhead_window_size_mib = 0; - // Config to filter prefetches and update preferred prefetch times for the // filtered prefetches. PreferredPrefetchOverrides preferred_prefetch_overrides; @@ -1585,19 +865,15 @@ struct Options { // Options for the memory-bound loop optimizer feature. MemoryBoundLoopOptimizerOptions memory_bound_loop_optimizer_options; - MemorySpaceAssignment::SliceProposalFunction propose_slice_fn = - [](const Shape&, const SlicedPrefetchOptions&) - -> xla::StatusOr { + SliceProposalFunction propose_slice_fn = [](const Shape&, + const SlicedPrefetchOptions&) + -> xla::StatusOr { return UnimplementedStrCat("Generation of SliceProposals unimplemented"); }; // Option to always spill buffers from alternate memory to default memory // and prefetching back to alternate memory(if needed) just in time for use. bool always_spill_to_default_memory = false; - - // Config to override alternate memory assignment sorting order for filtered - // buffers. - MsaSortOrderOverrides msa_sort_order_overrides; }; // A struct representing an asynchronous copy with its logical start and end @@ -1607,12 +883,10 @@ struct AsynchronousCopy { int64_t exclusive_start_time; int64_t end_time; float resource; - MemorySpaceAssignment::MemorySpace destination; + MemorySpace destination; int64_t id; - std::tuple - AsTuple() const { + std::tuple AsTuple() const { return std::make_tuple(exclusive_start_time, end_time, resource, destination, id); } @@ -1727,9 +1001,8 @@ class AsynchronousCopyResource { // A useful debugging tool for printing several pieces of information about // AsynchronousCopyResource. - std::string Dump( - int64_t start_time, int64_t end_time, - MemorySpaceAssignment::MemorySpace memory_space_filter) const; + std::string Dump(int64_t start_time, int64_t end_time, + MemorySpace memory_space_filter) const; private: // Internal helper method to implement adding/removing/checking resources. @@ -1878,7 +1151,7 @@ class MemoryBoundLoopOptimizer { const MemoryBoundLoopOptimizerOptions& options, const HloLiveRange& hlo_live_range, const HloAliasAnalysis& alias_analysis_, - const MemorySpaceAssignmentCostAnalysis& cost_analysis, + const CostAnalysis& cost_analysis, const BufferValue::SizeFunction& size_function); // Optimize the loop. Initialize must be called first. @@ -1920,13 +1193,13 @@ class MemoryBoundLoopOptimizer { std::vector additional_memory_used; }; - MemoryBoundLoopOptimizer( - int loop_start, int loop_end, uint64_t alternate_memory_size, - const MemoryBoundLoopOptimizerOptions& options, - const HloLiveRange& hlo_live_range, - const HloAliasAnalysis& alias_analysis_, - const MemorySpaceAssignmentCostAnalysis& cost_analysis, - const BufferValue::SizeFunction& size_function); + MemoryBoundLoopOptimizer(int loop_start, int loop_end, + uint64_t alternate_memory_size, + const MemoryBoundLoopOptimizerOptions& options, + const HloLiveRange& hlo_live_range, + const HloAliasAnalysis& alias_analysis_, + const CostAnalysis& cost_analysis, + const BufferValue::SizeFunction& size_function); // Initializes the data structures used by the optimizer. Status Initialize(); @@ -1995,7 +1268,7 @@ class MemoryBoundLoopOptimizer { MemoryBoundLoopOptimizerOptions options_; const HloLiveRange& hlo_live_range_; const HloAliasAnalysis& alias_analysis_; - const MemorySpaceAssignmentCostAnalysis& cost_analysis_; + const CostAnalysis& cost_analysis_; BufferValue::SizeFunction size_function_; absl::flat_hash_map instructions_in_loop_; @@ -2017,7 +1290,6 @@ class MemoryBoundLoopOptimizer { class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { public: - using MemorySpace = MemorySpaceAssignment::MemorySpace; using AllocationValue = MemorySpaceAssignment::AllocationValue; using HloPositionOrUse = std::variant; @@ -2069,14 +1341,14 @@ class AlternateMemoryBestFitHeap // We inherit AllocationBlock struct to attach the Allocation information to // make importing repacked offsets easier. struct RepackAllocationBlock : AllocationBlock { - MemorySpaceAssignment::Allocation* allocation; + Allocation* allocation; }; // A data structure we use to associate Allocation objects that are aliased // and must get the same offset. struct AliasedOffset { int64_t offset; - absl::flat_hash_set allocations; + absl::flat_hash_set allocations; }; // An allocation request for a use segment. A use segment is the time segment @@ -2121,7 +1393,7 @@ class AlternateMemoryBestFitHeap // time of the parameter instruction, and an output's time would correspond to // the time of last use. struct RequiredMemoryAssignment { - MemorySpaceAssignment::MemorySpace memory_space; + MemorySpace memory_space; int64_t time; AliasedOffset* offset; @@ -2149,7 +1421,7 @@ class AlternateMemoryBestFitHeap // instruction. int64_t loop_size; // A pointer into an Allocation in loop_optimized_allocations_. - const MemorySpaceAssignment::Allocation* loop_optimized_allocation; + const Allocation* loop_optimized_allocation; }; // A context object that is used to share state amongst the methods that @@ -2184,8 +1456,7 @@ class AlternateMemoryBestFitHeap // p0 | +-------+ // +---|---|---|---|---|----> time // t0 t1 t2 t3 t4 - std::vector - slice_decisions_sorted_by_start_time; + std::vector slice_decisions_sorted_by_start_time; // In order to support colocated buffer calculations, we need to add a // BufferInterval-Chunk pair to pending_chunks_, such that: @@ -2247,7 +1518,7 @@ class AlternateMemoryBestFitHeap // Parameters to Prefetch(). const AllocationRequest* request; - MemorySpaceAssignment::Allocation* prev_allocation_in_default_mem; + Allocation* prev_allocation_in_default_mem; // Intermediate calculations common to both the sliced and unsliced // solutions. @@ -2261,8 +1532,8 @@ class AlternateMemoryBestFitHeap std::optional exclusive_out_of_mem_start = std::nullopt; // Data structures used to compute and store the sliced solution. - std::optional - slice_proposal_collection = std::nullopt; + std::optional slice_proposal_collection = + std::nullopt; WorkingIntervals sliced_solution_intervals; std::optional sliced_solution; @@ -2347,20 +1618,18 @@ class AlternateMemoryBestFitHeap void AllocateReservedScopedAllocations(); // Returns the AliasedOffset object associated with the allocation. - AliasedOffset* GetAliasedOffset( - const MemorySpaceAssignment::Allocation& allocation); + AliasedOffset* GetAliasedOffset(const Allocation& allocation); // If aliased_offset is non-null, this method adds the allocation to // aliased_offset. Otherwise, it creates a new AliasedOffset object and adds // the allocation to this new AliasedOffset. - void CreateOrAddToAliasedOffset( - const MemorySpaceAssignment::Allocation& allocation, - AliasedOffset* aliased_offset); + void CreateOrAddToAliasedOffset(const Allocation& allocation, + AliasedOffset* aliased_offset); // Given an allocation sequence, returns the live allocation at time with a // preference towards allocations in alternate memory. Returns nullptr if no // allocation is alive at that time. - static MemorySpaceAssignment::Allocation* GetLiveAllocationAt( + static Allocation* GetLiveAllocationAt( const MemorySpaceAssignment::AllocationSequence& allocations, int64_t time); @@ -2406,9 +1675,8 @@ class AlternateMemoryBestFitHeap int64_t earliest_prefetch_time) const; // Try prefetching to alternate memory space. - Result Prefetch( - const AllocationRequest& request, - MemorySpaceAssignment::Allocation& prev_allocation_in_default_mem); + Result Prefetch(const AllocationRequest& request, + Allocation& prev_allocation_in_default_mem); // Helper methods used to implement Prefetch(). // @@ -2463,9 +1731,9 @@ class AlternateMemoryBestFitHeap colocated_intervals); // Propagates aliased required assignment for a given position. - void AddAliasedRequiredAssignment( - const HloInstruction* instruction, ShapeIndex index, - const MemorySpaceAssignment::Allocation* aliased_allocation); + void AddAliasedRequiredAssignment(const HloInstruction* instruction, + ShapeIndex index, + const Allocation* aliased_allocation); // This sets a required assignment. CHECK fails if there is a conflicting // required assignment at the same time. @@ -2493,7 +1761,7 @@ class AlternateMemoryBestFitHeap // allocations all share a common allocation site (a use or position) with // each other. This can be used to determine if a group of linked allocations // are considered efficient or not. - std::vector> + std::vector> GetLinkedAllocationsInAlternateMemory( absl::Span allocation_values) const; @@ -2556,10 +1824,9 @@ class AlternateMemoryBestFitHeap // Adds an asynchronous copy to allocations. void AddAsyncCopy( - MemorySpaceAssignment::Allocation& prev_allocation, - MemorySpace memory_space, std::optional chunk, - int64_t exclusive_start_time, int64_t end_time, - int64_t copy_done_schedule_before_time, + Allocation& prev_allocation, MemorySpace memory_space, + std::optional chunk, int64_t exclusive_start_time, + int64_t end_time, int64_t copy_done_schedule_before_time, MemorySpaceAssignment::AllocationSequence* allocations, AliasedOffset* aliased_offset, float resource, std::optional cross_program_prefetch_index = std::nullopt); @@ -2568,11 +1835,10 @@ class AlternateMemoryBestFitHeap // asynchronous copy data structures, prefetch_interval_tree_, and aliasing // data structures void AddAsyncSlicesForPrefetch( - const MemorySpaceAssignment::Allocation& prev_allocation, + const Allocation& prev_allocation, MemorySpaceAssignment::AllocationSequence* allocations, AliasedOffset* aliased_offset, - const std::vector& - slice_decisions_sorted_by_start_time, + const std::vector& slice_decisions_sorted_by_start_time, int64_t prefetch_end_time, int64_t allocation_end_time); // This method is used for committing the chunk candidate but adding it to @@ -2599,9 +1865,8 @@ class AlternateMemoryBestFitHeap void AppendScopedAllocationBufferInfoDebugString( const HloInstruction* instruction, int64_t time, int64_t size, std::string& debug_str) const; - void AppendAllocationInfoDebugString( - const MemorySpaceAssignment::Allocation& allocation, - std::string& debug_str) const; + void AppendAllocationInfoDebugString(const Allocation& allocation, + std::string& debug_str) const; void DumpDebugStringsIfEnabled() const; // Returns the available heap size in the alternate memory. @@ -2618,8 +1883,7 @@ class AlternateMemoryBestFitHeap // Creates and returns a RepackAllocationBlock. static RepackAllocationBlock MakeRepackAllocationBlock( int64_t start_time, int64_t end_time, int64_t size, - int64_t initial_offset, int64_t id, - MemorySpaceAssignment::Allocation* allocation) { + int64_t initial_offset, int64_t id, Allocation* allocation) { RepackAllocationBlock allocation_block; allocation_block.inclusive_start_time = start_time; allocation_block.end_time = end_time; @@ -2671,8 +1935,7 @@ class AlternateMemoryBestFitHeap // The data structure that contains AliasedOffset objects and Allocation to // AliasedOffset map for efficient lookup. std::list aliased_offsets_; - absl::flat_hash_map - aliased_offset_map_; + absl::flat_hash_map aliased_offset_map_; // This map contains required memory assignments for HloValues (e.g., input // and outputs). absl::flat_hash_map> @@ -2714,6 +1977,7 @@ class AlternateMemoryBestFitHeap std::string allocation_info_str_; std::string instruction_schedule_str_; }; + } // namespace memory_space_assignment } // namespace xla diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.proto b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.proto index 00a55401c9d877..47a89e74bebb00 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.proto +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.proto @@ -40,6 +40,10 @@ message SlicedPrefetchOptions { // The threshold for max_slices after which we limit the permutations of slice // times that we try when placing a sliced allocation. uint32 all_slice_time_permutations_threshold = 4; + + // The preferred slize size for MSA sliced prefetches. 0 means there is no + // preferred slice size, in which case, we'll try to slice into max_slices. + uint64 preferred_slice_size = 5; } // Options for memory-bound loop optimizations in memory space assignment. If diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc index f4e574a41174cc..e9421c99735823 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc @@ -54,8 +54,13 @@ limitations under the License. #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_value.h" #include "xla/service/instruction_hoister.h" +#include "xla/service/memory_space_assignment/allocation.h" +#include "xla/service/memory_space_assignment/cost_analysis.h" #include "xla/service/memory_space_assignment/memory_space_assignment.pb.h" +#include "xla/service/memory_space_assignment/prefetch_interval_picker.h" #include "xla/service/memory_space_assignment/repacking.h" +#include "xla/service/memory_space_assignment/slice.h" +#include "xla/service/memory_space_assignment/testing_utils.h" #include "xla/service/time_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -67,39 +72,18 @@ limitations under the License. #include "xla/xla_data.pb.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" -#include "tsl/platform/protobuf.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep #include "tsl/platform/status.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { +namespace memory_space_assignment { namespace { namespace op = xla::testing::opcode_matchers; using Chunk = HeapSimulator::Chunk; -using memory_space_assignment::AsynchronousCopy; -using memory_space_assignment::AsynchronousCopyOrdering; -using memory_space_assignment::AsynchronousCopyResource; -using memory_space_assignment::CostAnalysisPrefetchIntervalPicker; -using memory_space_assignment::InstructionCountPrefetchIntervalPicker; -using memory_space_assignment::MemoryBoundLoopOptimizer; -using memory_space_assignment::MemoryBoundLoopOptimizerOptions; -using memory_space_assignment::MemorySpaceAssignment; -using memory_space_assignment::MemorySpaceAssignmentCostAnalysis; -using memory_space_assignment::MemorySpaceAssignmentRepacker; -using memory_space_assignment::MsaSortOrderOverrides; -using memory_space_assignment::Options; -using memory_space_assignment::PreferredPrefetchOverrides; -using memory_space_assignment::PrefetchIntervalPicker; -using memory_space_assignment::PresetAssignments; -using memory_space_assignment::SlicedPrefetchOptions; -using SliceParam = memory_space_assignment::MemorySpaceAssignment::SliceParam; -using SliceProposal = - memory_space_assignment::MemorySpaceAssignment::SliceProposal; -using SliceProposalCollection = - memory_space_assignment::MemorySpaceAssignment::SliceProposalCollection; -using MSA = memory_space_assignment::MemorySpaceAssignment; using ::testing::_; using ::testing::Return; using ::testing::UnorderedElementsAre; @@ -178,8 +162,6 @@ class MemorySpaceAssignmentTestBase : public HloTestBase { Options DefaultMemorySpaceOptions() { Options options; - options.async_copy_bandwidth_bytes_per_second = kAsyncCopyBandwidth; - options.alternate_mem_bandwidth_bytes_per_second = kAlternateMemBandwidth; options.max_size_in_bytes = 128; options.alignment_in_bytes = 8; options.verify = true; @@ -192,6 +174,13 @@ class MemorySpaceAssignmentTestBase : public HloTestBase { return options; } + CostAnalysisOptions DefaultCostAnalysisOptions() { + CostAnalysisOptions options; + options.async_copy_bandwidth_bytes_per_second = kAsyncCopyBandwidth; + options.alternate_mem_bandwidth_bytes_per_second = kAlternateMemBandwidth; + return options; + } + Options UpdateMaxAsyncCopies(Options options, int64_t max_async_copies) { options.max_outstanding_prefetches = max_async_copies; options.max_outstanding_evictions = max_async_copies; @@ -202,14 +191,18 @@ class MemorySpaceAssignmentTestBase : public HloTestBase { std::unique_ptr AssignMemorySpaceUsingCostAnalysis( HloModule* module, std::optional memory_space_options_override = std::nullopt, - std::optional cost_options_override = + std::optional cost_analysis_options_override = + std::nullopt, + std::optional hlo_cost_options_override = + std::nullopt, + std::optional optional_msa_sort_order_overrides = std::nullopt) { - HloCostAnalysis::Options cost_options = DefaultHloCostAnalysisOptions(); - if (cost_options_override) { - cost_options = *cost_options_override; + HloCostAnalysis::Options hlo_cost_options = DefaultHloCostAnalysisOptions(); + if (hlo_cost_options_override) { + hlo_cost_options = *hlo_cost_options_override; } - HloCostAnalysis hlo_cost_analysis(cost_options); + HloCostAnalysis hlo_cost_analysis(hlo_cost_options); for (HloComputation* computation : module->MakeNonfusionComputations()) { TF_CHECK_OK(computation->Accept(&hlo_cost_analysis)); } @@ -219,10 +212,14 @@ class MemorySpaceAssignmentTestBase : public HloTestBase { if (memory_space_options_override) { memory_space_options = *memory_space_options_override; } + CostAnalysisOptions cost_analysis_options = DefaultCostAnalysisOptions(); + if (cost_analysis_options_override) { + cost_analysis_options = *cost_analysis_options_override; + } - auto cost_analysis = MemorySpaceAssignmentCostAnalysis::Create( - hlo_cost_analysis, memory_space_options, *module) - .value(); + auto cost_analysis = + CostAnalysis::Create(hlo_cost_analysis, cost_analysis_options, *module) + .value(); memory_space_options.cost_analysis = cost_analysis.get(); CostAnalysisPrefetchIntervalPicker prefetch_interval_picker( CostAnalysisPrefetchIntervalPicker( @@ -230,8 +227,12 @@ class MemorySpaceAssignmentTestBase : public HloTestBase { /*preferred_overlap_to_async_copy_ratio=*/1.5, /*max_overlap_to_mem_size_async_copy_ratio=*/10.0, /*mem_size_bytes=*/memory_space_options.max_size_in_bytes)); - memory_space_assignment::MemoryBoundednessBufferIntervalComparator - comparator(*cost_analysis, &cache_); + MsaSortOrderOverrides msa_sort_order_overrides; + if (optional_msa_sort_order_overrides.has_value()) { + msa_sort_order_overrides = optional_msa_sort_order_overrides.value(); + } + MemoryBoundednessBufferIntervalComparator comparator( + *cost_analysis, &cache_, msa_sort_order_overrides); return AssignMemorySpace( module, memory_space_options, [&comparator](const MemorySpaceAssignment::BufferInterval& lhs, @@ -497,7 +498,7 @@ class MemorySpaceAssignmentTestBase : public HloTestBase { return module; } - MemorySpaceAssignmentCostAnalysis::Cache cache_; + CostAnalysis::Cache cache_; }; class MemorySpaceAssignmentTest : public MemorySpaceAssignmentTestBase, @@ -506,93 +507,6 @@ class MemorySpaceAssignmentTest : public MemorySpaceAssignmentTestBase, bool allocate_across_sequential_calls() const override { return GetParam(); } }; -// For testing purposes, we define a cost analysis where we can control the -// elapsed times of each HLO and asynchronous copy. -class FakeMemorySpaceAssignmentCostAnalysis - : public MemorySpaceAssignmentCostAnalysis { - public: - static StatusOr> - Create(const HloCostAnalysis& cost_analysis, const HloModule& module, - const Options& options) { - TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module)); - TF_ASSIGN_OR_RETURN(auto hlo_live_range, - HloLiveRange::Run(module.schedule(), *alias_analysis, - module.entry_computation())); - auto call_graph = CallGraph::Build(&module); - return absl::WrapUnique(new FakeMemorySpaceAssignmentCostAnalysis( - cost_analysis, options, std::move(alias_analysis), - std::move(hlo_live_range), std::move(call_graph))); - } - - float GetInstructionElapsed( - const HloInstruction& instruction) const override { - if (get_instruction_elapsed_override_) { - return get_instruction_elapsed_override_(instruction); - } - return 1.0; - } - - float GetInstructionElapsedInAlternateMemory( - const HloInstruction& instruction, - absl::Span> - operands_in_alternate_mem, - absl::Span outputs_in_alternate_mem) const override { - if (get_instruction_elapsed_in_alternate_memory_override_) { - return get_instruction_elapsed_in_alternate_memory_override_( - instruction, operands_in_alternate_mem, outputs_in_alternate_mem); - } - if (!operands_in_alternate_mem.empty()) { - return 0.5; - } else { - return 1.0; - } - } - - float GetAsyncCopyElapsed(const Shape& shape) const override { - if (get_async_copy_elapsed_override_) { - return get_async_copy_elapsed_override_(shape); - } - return 3.0; - } - - // The following methods can be used to override what the above API calls - // return. - void SetOverrideForGetInstructionElapsed( - std::function function) { - get_instruction_elapsed_override_ = function; - } - void SetOverrideForGetInstructionElapsedInAlternateMemory( - std::function>, - absl::Span)> - function) { - get_instruction_elapsed_in_alternate_memory_override_ = function; - } - void SetOverrideForGetAsyncCopyElapsed( - std::function function) { - get_async_copy_elapsed_override_ = function; - } - - protected: - FakeMemorySpaceAssignmentCostAnalysis( - const HloCostAnalysis& cost_analysis, const Options& options, - std::unique_ptr alias_analysis, - std::unique_ptr hlo_live_range, - std::unique_ptr call_graph) - : MemorySpaceAssignmentCostAnalysis( - cost_analysis, options, std::move(alias_analysis), - std::move(hlo_live_range), std::move(call_graph)) {} - - private: - std::function - get_instruction_elapsed_override_ = nullptr; - std::function>, - absl::Span)> - get_instruction_elapsed_in_alternate_memory_override_ = nullptr; - std::function get_async_copy_elapsed_override_ = nullptr; -}; - TEST_P(MemorySpaceAssignmentTest, ParameterOnly) { // A module consisting of a single parameter. Inputs/outputs are currently // excluded from memory space assignment. @@ -4650,16 +4564,19 @@ TEST_P(MemorySpaceAssignmentTest, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_string)); - Options options = DefaultMemorySpaceOptions(); const std::string text_proto = R"pb( overrides { hlo_position_matcher { instruction_name_regex: "negate(.*)" } override_options { assign_first: true } })pb"; - TF_ASSERT_OK_AND_ASSIGN(options.msa_sort_order_overrides, + TF_ASSERT_OK_AND_ASSIGN(auto msa_sort_order_overrides, ParseTextProto(text_proto)); - AssignMemorySpaceUsingCostAnalysis(module.get(), options); + AssignMemorySpaceUsingCostAnalysis( + module.get(), /*memory_space_options_override=*/std::nullopt, + /*cost_analysis_options_override=*/std::nullopt, + /*hlo_cost_options_override=*/std::nullopt, + /*optional_msa_sort_order_overrides=*/msa_sort_order_overrides); // Parameters are in the default memory space. const HloInstruction* p0 = FindInstruction(module.get(), "p0"); EXPECT_EQ(p0->shape().layout().memory_space(), kDefaultMemorySpace); @@ -4715,17 +4632,20 @@ TEST_P(MemorySpaceAssignmentTest, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_string)); - Options options = DefaultMemorySpaceOptions(); const std::string text_proto = R"pb( overrides { hlo_position_matcher { instruction_name_regex: "negate(.*)" } override_options { assign_last: true } } )pb"; - TF_ASSERT_OK_AND_ASSIGN(options.msa_sort_order_overrides, + TF_ASSERT_OK_AND_ASSIGN(auto msa_sort_order_overrides, ParseTextProto(text_proto)); - AssignMemorySpaceUsingCostAnalysis(module.get(), options); + AssignMemorySpaceUsingCostAnalysis( + module.get(), /*memory_space_options_override=*/std::nullopt, + /*cost_analysis_options_override=*/std::nullopt, + /*hlo_cost_options_override=*/std::nullopt, + /*optional_msa_sort_order_overrides=*/msa_sort_order_overrides); // Parameters are in the default memory space. const HloInstruction* p0 = FindInstruction(module.get(), "p0"); EXPECT_EQ(p0->shape().layout().memory_space(), kDefaultMemorySpace); @@ -7577,7 +7497,8 @@ ENTRY entry { // Disable inefficiency check. Expect that the fusion output and operand are // in the alternate memory. options.inefficient_use_to_copy_ratio = 0.0; - AssignMemorySpaceUsingCostAnalysis(module.get(), options); + AssignMemorySpaceUsingCostAnalysis(module.get(), + /*memory_space_options_override=*/options); if (allocate_across_sequential_calls()) { EXPECT_THAT( module->entry_computation()->root_instruction(), @@ -7593,7 +7514,8 @@ ENTRY entry { // f32[2,3]), so this should be considered inefficient (8/48 < 0.5). TF_ASSERT_OK_AND_ASSIGN(module, ParseAndReturnVerifiedModule(hlo_string)); options.inefficient_use_to_copy_ratio = 0.5; - AssignMemorySpaceUsingCostAnalysis(module.get(), options); + AssignMemorySpaceUsingCostAnalysis(module.get(), + /*memory_space_options_override=*/options); EXPECT_THAT(module->entry_computation()->root_instruction(), op::Tuple(op::Fusion(op::Parameter()), op::Negate())); } @@ -7663,10 +7585,13 @@ ENTRY entry { Options options = DefaultMemorySpaceOptions(); options.enable_cross_program_prefetch = false; options.inefficient_use_to_copy_ratio = 0.5; - HloCostAnalysis::Options cost_options = DefaultHloCostAnalysisOptions(); - cost_options.set_transcendentals_per_second(0.4); + HloCostAnalysis::Options hlo_cost_options = DefaultHloCostAnalysisOptions(); + hlo_cost_options.set_transcendentals_per_second(0.4); - AssignMemorySpaceUsingCostAnalysis(module.get(), options, cost_options); + AssignMemorySpaceUsingCostAnalysis( + module.get(), /*memory_space_options_override=*/options, + /*cost_analysis_options_override=*/std::nullopt, + /*hlo_cost_options_override=*/hlo_cost_options); } TEST_P(MemorySpaceAssignmentTest, AsyncOpElapsedTime) { @@ -7773,7 +7698,7 @@ TEST_F(AsynchronousCopyOrderingTest, Simple) { // 6,17 +----------+ Violate // 5,13 +-------+ OK (same start as 5,14) // 5,14 +--------+ OK (same as 5,14) - auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate; + auto alternate_mem_space = MemorySpace::kAlternate; AsynchronousCopyOrdering ordering; EXPECT_FALSE(ordering.ViolatesOrdering(3, 11)); ordering.AddCopy({3, 11, 1, alternate_mem_space, 0}); @@ -7793,7 +7718,7 @@ TEST_F(AsynchronousCopyOrderingTest, Simple) { } TEST_F(AsynchronousCopyOrderingTest, SameInterval) { - auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate; + auto alternate_mem_space = MemorySpace::kAlternate; AsynchronousCopyOrdering ordering; EXPECT_FALSE(ordering.ViolatesOrdering(1, 5)); EXPECT_FALSE(ordering.ViolatesOrdering(2, 4)); @@ -7825,7 +7750,7 @@ TEST_F(AsynchronousCopyResourceTest, Simple) { // 4,9,3 +-------+ Violate // 4,8,2 +-----+ OK; The 5,9 copy shifts resource to right. // resource: 0 0 0 3 7 0 0 0 0 4 - auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate; + auto alternate_mem_space = MemorySpace::kAlternate; AsynchronousCopyResource resource( {2.0, 3.0, 1.0, 6.0, 7.0, 1.0, 7.0, 2.0, 2.0, 4.0}); EXPECT_TRUE(resource.HasEnoughResource(-1, 3, 5.0)); @@ -7859,7 +7784,7 @@ TEST_F(AsynchronousCopyResourceTest, Propagate) { // 0,4,3 +-----+ OK // resource: 2 0 0 0 0 0 0 0 0 0 // 0,4,1 +-----+ Violate - auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate; + auto alternate_mem_space = MemorySpace::kAlternate; AsynchronousCopyResource resource( {2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0}); EXPECT_TRUE(resource.HasEnoughResource(6, 10, 2.0)); @@ -7903,7 +7828,7 @@ TEST_F(AsynchronousCopyResourceTest, CantPropagate) { // 4,8,4 +-----+ OK // resource: 2 2 2 2 2 0 0 0 0 2 // 3,6,4 +---+ Violate - auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate; + auto alternate_mem_space = MemorySpace::kAlternate; AsynchronousCopyResource resource( {2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0}); EXPECT_TRUE(resource.HasEnoughResource(5, 10, 2.0)); @@ -7930,7 +7855,7 @@ TEST_F(AsynchronousCopyResourceTest, Nested) { // 1,3,2 +-+ OK // resource: 2 2 0 2 2 // 0,4,4 +-----+ Violate - auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate; + auto alternate_mem_space = MemorySpace::kAlternate; AsynchronousCopyResource resource({2.0, 2.0, 2.0, 2.0, 2.0}); EXPECT_TRUE(resource.HasEnoughResource(1, 3, 2.0)); resource.AddCopy({1, 3, 2.0, alternate_mem_space, 0}); @@ -7954,7 +7879,7 @@ TEST_F(AsynchronousCopyResourceTest, Remove) { // resource: 0 1 2 2 2 // rem:-1,2,3+---+ // resource: 2 2 2 2 2 - auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate; + auto alternate_mem_space = MemorySpace::kAlternate; AsynchronousCopyResource resource({2.0, 2.0, 2.0, 2.0, 2.0}); AsynchronousCopy copy1{2, 5, 2.0, alternate_mem_space, 0}; AsynchronousCopy copy2{-1, 2, 3.0, alternate_mem_space, 1}; @@ -7997,7 +7922,7 @@ TEST_F(AsynchronousCopyResourceTest, NestedRemove) { // resource: 2 2 2 2 2 // add:1,3,2 +-+ OK // resource: 2 2 0 2 2 - auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate; + auto alternate_mem_space = MemorySpace::kAlternate; AsynchronousCopyResource resource({2.0, 2.0, 2.0, 2.0, 2.0}); AsynchronousCopy copy1{1, 3, 2.0, alternate_mem_space, 0}; AsynchronousCopy copy2{0, 4, 4.0, alternate_mem_space, 1}; @@ -8044,7 +7969,7 @@ TEST_F(AsynchronousCopyResourceTest, PropagateRemove) { // resource: 2 0 0 0 0 0 0 0 1 2 // rem:0,4,3 +-----+ // resource: 2 2 0 0 0 0 0 0 2 2 - auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate; + auto alternate_mem_space = MemorySpace::kAlternate; AsynchronousCopyResource resource( {2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0}); EXPECT_TRUE(resource.HasEnoughResource(6, 10, 2.0)); @@ -8096,7 +8021,7 @@ TEST_F(AsynchronousCopyResourceTest, StartAtZeroAndRemove) { // resource: 0 0 1 1 2 // add:0,4,2 +-----+ OK // resource: 0 0 0 0 2 - auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate; + auto alternate_mem_space = MemorySpace::kAlternate; AsynchronousCopyResource resource({0.0, 0.0, 1.0, 1.0, 2.0}); AsynchronousCopy copy1{0, 4, 2.0, alternate_mem_space, 0}; EXPECT_TRUE(resource.HasEnoughResource(0, 4, 2.0)); @@ -8139,7 +8064,7 @@ TEST_F(AsynchronousCopyResourceTest, OutOfOrderRemovalSameStartTime) { // resource: 2 2 1 2 2 // rem:1,5,1 +-----+ // resource: 2 2 2 2 2 - auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate; + auto alternate_mem_space = MemorySpace::kAlternate; AsynchronousCopyResource resource({2.0, 2.0, 2.0, 2.0, 2.0}); AsynchronousCopy copy1{1, 3, 1.0, alternate_mem_space, 0}; AsynchronousCopy copy2{1, 4, 2.0, alternate_mem_space, 1}; @@ -8204,7 +8129,7 @@ TEST_F(AsynchronousCopyResourceTest, HasEnoughResourceMultiCheckSuccess) { // 0,6,4 +-----------+ // 4,6,3 +-+ 2 copies OK; The 1,10 copy shifts. // resource: 0 0 0 0 6 0 7 2 2 4 - auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate; + auto alternate_mem_space = MemorySpace::kAlternate; AsynchronousCopyResource resource( {2.0, 1.0, 3.0, 6.0, 7.0, 3.0, 7.0, 2.0, 2.0, 4.0}); EXPECT_TRUE(resource.HasEnoughResource(-1, 3, 5.0)); @@ -8232,7 +8157,7 @@ TEST_F(AsynchronousCopyResourceTest, HasEnoughResourceMultiCheckFailure) { // resource: 0 0 0 3 7 3 7 2 2 4 // 0,6,4 +-----------+ // 4,6,4 +-+ Not-OK - auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate; + auto alternate_mem_space = MemorySpace::kAlternate; AsynchronousCopyResource resource( {2.0, 1.0, 3.0, 6.0, 7.0, 3.0, 7.0, 2.0, 2.0, 4.0}); EXPECT_TRUE(resource.HasEnoughResource(-1, 3, 5.0)); @@ -8249,7 +8174,7 @@ TEST_F(AsynchronousCopyResourceTest, HasEnoughResourceMultiCheckFailure) { TEST_F(AsynchronousCopyResourceTest, HasEnoughResourceMultiCheckRegressionTest) { - auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate; + auto alternate_mem_space = MemorySpace::kAlternate; AsynchronousCopyResource resource({/*0:*/ 24.0f, /*1:*/ 0.0f, /*2:*/ 6.0f, @@ -8865,6 +8790,35 @@ TEST_P(MemorySpaceAssignmentTest, CrossProgramRootDupMayAlias) { op::Parameter(0)); } +TEST_P(MemorySpaceAssignmentTest, CrossProgramRootDusFusionMayAlias) { + absl::string_view hlo_string = R"( + HloModule cross_program_prefetch, is_scheduled=true, input_output_alias={ {}: (0, {}, may-alias) } + fused_computation { + fused_p0 = s32[2,2] parameter(0) + fused_p1 = s32[1,2] parameter(1) + fused_p2 = s32[] parameter(2) + fused_p3 = s32[] parameter(3) + ROOT dus = s32[2,2] dynamic-update-slice(fused_p0, fused_p1, fused_p2, fused_p3) + } + + ENTRY CrossProgramPrefetch { + p0 = s32[2,2] parameter(0) + c0 = s32[1,2] constant({{77, 77}}) + c1 = s32[] constant(0) + bitcast1 = s32[2,2] bitcast(p0) + ROOT fusion = s32[2,2] fusion(bitcast1, c0, c1, c1), kind=kLoop, calls=fused_computation + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + auto preset_assignments = AssignMemorySpace( + module.get(), DefaultMemorySpaceOptions(), + /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2); + + auto cross_program_prefetches = module->CrossProgramPrefetches(); + EXPECT_EQ(cross_program_prefetches.size(), 0); +} + TEST_P(MemorySpaceAssignmentTest, CrossProgramRootDup) { absl::string_view hlo_string = R"( HloModule cross_program_prefetch, is_scheduled=true @@ -8929,19 +8883,15 @@ TEST_P(MemorySpaceAssignmentTest, CrossProgramRootDotMayAlias) { /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2); auto cross_program_prefetches = module->CrossProgramPrefetches(); - EXPECT_EQ(cross_program_prefetches.size(), 1); + EXPECT_EQ(cross_program_prefetches.size(), 0); EXPECT_THAT(FindInstruction(module.get(), "dot")->operand(1), - op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, - op::Parameter(0))); + op::Parameter(0)); } TEST_P(MemorySpaceAssignmentTest, CrossProgramRootLiveOutBug) { - // An in-place fusion that lives out should not be included as a use to the - // cross-program prefetch allocation. Due to a bug, we considered in-place - // update that feeds the ROOT of the entry computation as a valid use of the - // cross-program prefetch. This then would cause this live-out buffer to be - // placed in the alternate memory. We expect p0 to be cross-program prefetched - // but only for the dot operand and not the fusion operand. + // Input-output aliased buffers should not be cross-program prefetched since + // the update on the buffer will not be reflected on the next program + // execution (the data in the alternate memory would be stale). absl::string_view hlo_string = R"( HloModule cross_program_prefetch, is_scheduled=true, input_output_alias={ {0}: (0, {}, may-alias) } fused_computation { @@ -8967,12 +8917,7 @@ TEST_P(MemorySpaceAssignmentTest, CrossProgramRootLiveOutBug) { /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2); auto cross_program_prefetches = module->CrossProgramPrefetches(); - EXPECT_EQ(cross_program_prefetches.size(), 1); - EXPECT_THAT(FindInstruction(module.get(), "dot")->operand(1), - op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, - op::Parameter(0))); - EXPECT_THAT(FindInstruction(module.get(), "fusion")->operand(0), - op::Parameter(0)); + EXPECT_EQ(cross_program_prefetches.size(), 0); } TEST_P(MemorySpaceAssignmentTest, CrossProgramRootParameter) { @@ -9355,9 +9300,10 @@ ENTRY main { // Setup cost analysis so it takes 2 instructions to prefetch anything. HloCostAnalysis hlo_cost_analysis(ShapeSize); + CostAnalysisOptions cost_analysis_options; TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis, - FakeMemorySpaceAssignmentCostAnalysis::Create( - hlo_cost_analysis, *module, options)); + FakeCostAnalysis::Create(hlo_cost_analysis, *module, + cost_analysis_options)); cost_analysis->SetOverrideForGetInstructionElapsed( [](const HloInstruction& instruction) -> float { return 10.0; }); cost_analysis->SetOverrideForGetAsyncCopyElapsed( @@ -9444,549 +9390,6 @@ ENTRY main { EXPECT_EQ(f_index, p1_copy_end + 1); } -using CostAnalysisPrefetchIntervalPickerTest = HloTestBase; - -TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrder) { - absl::string_view hlo_string = R"( - HloModule bug, is_scheduled=true - - ENTRY Entry { - param0 = f32[2,4] parameter(0) - a = f32[2,4] negate(param0) - b = f32[2,4] negate(a) - c = f32[2,4] negate(b) - d = f32[2,4] negate(c) - e = f32[2,4] negate(d) - f = f32[2,4] negate(e) - g = f32[2,4] negate(f) - h = f32[2,4] negate(g) - i = f32[2,4] negate(h) - j = f32[2,4] negate(i) - k = f32[2,4] negate(j) - l = f32[2,4] negate(k) - m = f32[2,4] negate(l) - n = f32[2,4] negate(m) - o = f32[2,4] negate(n) - p = f32[2,4] negate(o) - q = f32[2,4] negate(p) - r = f32[2,4] negate(q) - s = f32[2,4] negate(r) - t = f32[2,4] negate(s) - u = f32[2,4] negate(t) - ROOT v = f32[2,4] add(u, param0) - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - - HloCostAnalysis hlo_cost_analysis(ShapeSize); - Options options; - TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis, - FakeMemorySpaceAssignmentCostAnalysis::Create( - hlo_cost_analysis, *module, options)); - CostAnalysisPrefetchIntervalPicker interval_picker( - *cost_analysis, - /*min_overlap_to_async_copy_ratio=*/1.0, - /*preferred_overlap_to_async_copy_ratio=*/2.0, - /*max_overlap_to_mem_size_async_copy_ratio=*/4.0, - /*mem_size_bytes=*/32); - - HloInstruction* root = module->entry_computation()->root_instruction(); - const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}}; - interval_picker.Begin(use, /*start_time=*/0, /*end_time=*/22, std::nullopt); - - // Expect that the first interval is (15, 22), which has elapsed time of 6.0, - // twice of the async copy elased (3.0). Then we expect that intervals will be - // visited in alternating increasing and decreasing orders until hitting the - // min and max async copy overlap ratios, which are the intervals (18, 22) - // and (9, 22) respectively. - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_EQ(interval_picker.Next(), 15); - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_EQ(interval_picker.Next(), 16); - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_EQ(interval_picker.Next(), 14); - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_EQ(interval_picker.Next(), 17); - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_EQ(interval_picker.Next(), 13); - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_EQ(interval_picker.Next(), 18); // Min async overlap ratio reached. - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_EQ(interval_picker.Next(), 12); - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_EQ(interval_picker.Next(), 11); - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_EQ(interval_picker.Next(), 10); - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_EQ(interval_picker.Next(), 9); // Max async overlap ratio reached. - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_TRUE(interval_picker.Done()); - - // Expect that if the time between start_time and end_time is too short, there - // won't be any available intervals. - interval_picker.Begin(use, /*start_time=*/19, /*end_time=*/22, std::nullopt); - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_TRUE(interval_picker.Done()); -} - -TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrderWhile) { - absl::string_view hlo_string = R"( - HloModule bug, is_scheduled=true - - while_condition { - param1 = (f32[2,4]) parameter(0) // 19 - ROOT cond = pred[] constant(true) // 20 - } - - while_body { - param2 = (f32[2,4]) parameter(0) // 21 - gte2 = f32[2,4] get-tuple-element(param2), index=0 // 22 - add = f32[2,4] add(gte2, gte2) // 23 - ROOT tuple2 = (f32[2,4]) tuple(add) // 24 - } - - ENTRY Entry { - param0 = f32[2,4] parameter(0) // 0 - a = f32[2,4] negate(param0) // 1 - b = f32[2,4] negate(a) // 2 - c = f32[2,4] negate(b) // 3 - d = f32[2,4] negate(c) // 4 - e = f32[2,4] negate(d) // 5 - f = f32[2,4] negate(e) // 6 - g = f32[2,4] negate(f) // 7 - h = f32[2,4] negate(g) // 8 - i = f32[2,4] negate(h) // 9 - j = f32[2,4] negate(i) // 10 - k = f32[2,4] negate(j) // 11 - l = f32[2,4] negate(k) // 12 - m = f32[2,4] negate(l) // 13 - n = f32[2,4] negate(m) // 14 - o = f32[2,4] negate(n) // 15 - p = f32[2,4] negate(o) // 16 - q = f32[2,4] negate(p) // 17 - tuple = (f32[2,4]) tuple(q) // 18 - while = (f32[2,4]) while(tuple), condition=while_condition, body=while_body // 25 - gte1 = f32[2,4] get-tuple-element(while), index=0 // 26 - r = f32[2,4] negate(gte1) // 27 - s = f32[2,4] negate(r) // 28 - t = f32[2,4] negate(s) // 29 - u = f32[2,4] negate(t) // 30 - ROOT v = f32[2,4] add(u, param0) // 31 - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - - HloCostAnalysis hlo_cost_analysis(ShapeSize); - Options options; - TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis, - FakeMemorySpaceAssignmentCostAnalysis::Create( - hlo_cost_analysis, *module, options)); - CostAnalysisPrefetchIntervalPicker interval_picker( - *cost_analysis, - /*min_overlap_to_async_copy_ratio=*/1.0, - /*preferred_overlap_to_async_copy_ratio=*/2.0, - /*max_overlap_to_mem_size_async_copy_ratio=*/12.0, - /*mem_size_bytes=*/32); - - EXPECT_EQ(cost_analysis->options() - .xla_tpu_memory_space_assignment_while_execution_count, - 5); - HloInstruction* root = module->entry_computation()->root_instruction(); - const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}}; - interval_picker.Begin(use, /*start_time=*/0, /*end_time=*/31, std::nullopt); - - // Because there are while loop computations between [19, 24], we ensure that - // the interval picker avoids this interval. - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_EQ(interval_picker.Next(), 25); - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_EQ(interval_picker.Next(), 26); - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_EQ(interval_picker.Next(), 18); - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_EQ(interval_picker.Next(), 27); // Min async overlap ratio reached. - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_EQ(interval_picker.Next(), 17); // Max async overlap ratio reached. - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_TRUE(interval_picker.Done()); -} - -TEST_F(CostAnalysisPrefetchIntervalPickerTest, NestedWhile) { - // This test is to check against a bug where we didn't assign - // while_nest_level_ for while instructions, and defaulting to 0. This could - // cause the prefetch interval logic to think a nested while instruction is - // the same level as the outermost computation. - absl::string_view hlo_string = R"( - HloModule bug, is_scheduled=true - - while_condition.2 { - param1 = (f32[2,4]) parameter(0) // 11 - ROOT cond = pred[] constant(true) // 12 - } - - while_body.2 { - param2 = (f32[2,4]) parameter(0) // 13 - gte2 = f32[2,4] get-tuple-element(param2), index=0 // 14 - add = f32[2,4] add(gte2, gte2) // 15 - ROOT tuple2 = (f32[2,4]) tuple(add) // 16 - } - - while_condition.1 { - param3 = (f32[2,4]) parameter(0) // 5 - ROOT cond = pred[] constant(true) // 6 - } - - while_body.1 { - param4 = (f32[2,4]) parameter(0) // 7 - gte1 = f32[2,4] get-tuple-element(param4), index=0 // 8 - add1 = f32[2,4] add(gte1, gte1) // 9 - tuple1 = (f32[2,4]) tuple(add1) // 10 - while = (f32[2,4]) while(tuple1), condition=while_condition.2, body=while_body.2 // 17 - gte2 = f32[2,4] get-tuple-element(while), index=0 // 18 - add2 = f32[2,4] add(gte2, gte2) // 19 - ROOT tuple2 = (f32[2,4]) tuple(add2) // 20 - } - - ENTRY Entry { - param0 = f32[2,4] parameter(0) // 0 - a = f32[2,4] negate(param0) // 1 - b = f32[2,4] negate(a) // 2 - c = f32[2,4] negate(b) // 3 - tuple = (f32[2,4]) tuple(c) // 4 - while = (f32[2,4]) while(tuple), condition=while_condition.1, body=while_body.1 // 21 - gte1 = f32[2,4] get-tuple-element(while), index=0 // 22 - ROOT root = f32[2,4] add(gte1, param0) // 23 - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - - HloCostAnalysis hlo_cost_analysis(ShapeSize); - Options options; - TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis, - FakeMemorySpaceAssignmentCostAnalysis::Create( - hlo_cost_analysis, *module, options)); - CostAnalysisPrefetchIntervalPicker interval_picker( - *cost_analysis, - /*min_overlap_to_async_copy_ratio=*/1.0, - /*preferred_overlap_to_async_copy_ratio=*/2.0, - /*max_overlap_to_mem_size_async_copy_ratio=*/12.0, - /*mem_size_bytes=*/32); - - HloInstruction* root = module->entry_computation()->root_instruction(); - const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}}; - const Shape& shape = root->operand(1)->shape(); - - // We expect the root's latest prefetch start time to be before the while loop - // (logical time 4). - EXPECT_EQ(interval_picker.LatestPrefetchStartTime(shape, /*start_time=*/0, - /*end_time=*/23, &use), - 4); -} - -TEST_F(CostAnalysisPrefetchIntervalPickerTest, ConsecutiveConditionals) { - // This is a test for b/170668492, where prefetching for consecutive - // conditionals can cause the prefetch to start in the conditional's - // computation. - absl::string_view hlo_string = R"( - HloModule bug, is_scheduled=true - - true_computation.0 { - p0 = (f32[3]{0}) parameter(0) // 5 - gte = f32[3]{0} get-tuple-element(p0), index=0 // 6 - ROOT neg1 = f32[3]{0} negate(gte) // 7 - } - - false_computation.0 { - p0 = (f32[3]{0}) parameter(0) // 8 - gte = f32[3]{0} get-tuple-element(p0), index=0 // 9 - ROOT neg2 = f32[3]{0} negate(gte) // 10 - } - - true_computation.1 { - p0 = (f32[3]{0}) parameter(0) // 12 - gte = f32[3]{0} get-tuple-element(p0), index=0 // 13 - ROOT neg1 = f32[3]{0} negate(gte) // 14 - } - - false_computation.1 { - p0 = (f32[3]{0}) parameter(0) // 15 - gte = f32[3]{0} get-tuple-element(p0), index=0 // 16 - ROOT neg2 = f32[3]{0} negate(gte) // 17 - } - - ENTRY entry { - p0 = f32[3]{0} parameter(0) // 0 - p1 = f32[3]{0} parameter(1) // 1 - p2 = pred[] parameter(2) // 2 - tuple0 = (f32[3]{0}) tuple(p0) // 3 - tuple1 = (f32[3]{0}) tuple(p1) // 4 - conditional0 = f32[3]{0} conditional(p2, tuple0, tuple0), true_computation=true_computation.0, false_computation=false_computation.0 // 11 - conditional1 = f32[3]{0} conditional(p2, tuple1, tuple1), true_computation=true_computation.1, false_computation=false_computation.1 // 18 - ROOT tuple2 = (f32[3]{0}, f32[3]{0}) tuple(conditional0, conditional1) // 19 - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - - HloCostAnalysis hlo_cost_analysis(ShapeSize); - Options options; - TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis, - FakeMemorySpaceAssignmentCostAnalysis::Create( - hlo_cost_analysis, *module, options)); - CostAnalysisPrefetchIntervalPicker interval_picker( - *cost_analysis, - /*min_overlap_to_async_copy_ratio=*/1.0, - /*preferred_overlap_to_async_copy_ratio=*/2.0, - /*max_overlap_to_mem_size_async_copy_ratio=*/12.0, - /*mem_size_bytes=*/32); - - LOG(INFO) << module->ToString(); - - HloInstruction* conditional1 = - module->entry_computation()->GetInstructionWithName("conditional1"); - const HloUse use{conditional1, /*operand_number=*/1, /*operand_index=*/{0}}; - const Shape& shape = - module->entry_computation()->parameter_instruction(0)->shape(); - - // Expect that the prefetch to start before conditional0's called - // computations. - EXPECT_LT(interval_picker.LatestPrefetchStartTime(shape, /*start_time=*/0, - /*end_time=*/11, &use), - 5); -} - -TEST_F(CostAnalysisPrefetchIntervalPickerTest, EarliestLatestWindowTooSmall) { - // This tests the scenario where there is an op that takes a long time (tanh - // in this example) and as a result the earliest and latest times both fall - // inside this long-running op. In this case, we should still return a valid - // prefetch interval just before the long-running op. - absl::string_view hlo_string = R"( - HloModule bug, is_scheduled=true - - ENTRY Entry { - param0 = f32[2,4] parameter(0) - negate = f32[2,4] negate(param0) - tanh = f32[2,4] tanh(param0) - ROOT add = f32[2,4] add(tanh, negate) - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - - HloCostAnalysis hlo_cost_analysis(ShapeSize); - Options options; - TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis, - FakeMemorySpaceAssignmentCostAnalysis::Create( - hlo_cost_analysis, *module, options)); - cost_analysis->SetOverrideForGetInstructionElapsed( - [](const HloInstruction& hlo) { - if (hlo.opcode() == HloOpcode::kTanh) { - return 20.0; - } - return 1.0; - }); - CostAnalysisPrefetchIntervalPicker interval_picker( - *cost_analysis, - /*min_overlap_to_async_copy_ratio=*/1.0, - /*preferred_overlap_to_async_copy_ratio=*/2.0, - /*max_overlap_to_mem_size_async_copy_ratio=*/12.0, - /*mem_size_bytes=*/32); - - HloInstruction* root = module->entry_computation()->root_instruction(); - const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}}; - interval_picker.Begin(use, /*start_time=*/1, /*end_time=*/3, std::nullopt); - - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_FALSE(interval_picker.Done()); - EXPECT_EQ(interval_picker.Next(), 1); - EXPECT_TRUE(interval_picker.Done()); -} - -class MemorySpaceAssignmentCostAnalysisTest : public HloTestBase { - protected: - Status Initialize(const HloModule* module, - float pipeline_overhead_window_size_mib = 0.0) { - HloCostAnalysis::Options options; - options_.alternate_mem_bandwidth_bytes_per_second = 128; - options_.async_copy_bandwidth_bytes_per_second = 32; - options_.pipeline_overhead_window_size_mib = - pipeline_overhead_window_size_mib; - options.shape_size = ShapeSize; - options.set_flops_per_second(8); - options.set_bytes_per_second(32); - options.set_transcendentals_per_second(16); - hlo_cost_analysis_ = std::make_unique(options); - TF_RETURN_IF_ERROR( - module->entry_computation()->Accept(hlo_cost_analysis_.get())); - TF_ASSIGN_OR_RETURN(cost_analysis_, - MemorySpaceAssignmentCostAnalysis::Create( - *hlo_cost_analysis_, options_, *module)); - return OkStatus(); - } - - Options options_; - std::unique_ptr hlo_cost_analysis_; - std::unique_ptr cost_analysis_; -}; - -TEST_F(MemorySpaceAssignmentCostAnalysisTest, NoPipelineOverhead) { - absl::string_view hlo_string = R"( - HloModule module, is_scheduled=true - - ENTRY Entry { - param0 = f32[2,4] parameter(0) - param1 = f32[2,4] parameter(1) - ROOT add = f32[2,4] add(param0, param1) - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - TF_ASSERT_OK(Initialize(module.get())); - - const HloInstruction* add = module->entry_computation()->root_instruction(); - const float expected_compute_elapsed = - /*num_flops=*/8 / /*flops_per_second=*/8.0; - LOG(INFO) << "Expected compute elapsed = " << expected_compute_elapsed; - EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToCompute(*add), - expected_compute_elapsed); - float expected_memory_elapsed = - /*bytes_accessed=*/(3 * 4 * 8) / /*bytes_per_second=*/32.0; - LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; - EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToMemory(*add), - expected_memory_elapsed); - - // This HLO is memory-bound. - EXPECT_EQ(cost_analysis_->GetInstructionElapsed(*add), - expected_memory_elapsed); - EXPECT_EQ( - cost_analysis_->GetInstructionElapsedInAlternateMemory(*add, {}, {}), - expected_memory_elapsed); - - // Put operand 0 in alternate memory. Still memory bound. - expected_memory_elapsed = - (/*bytes_accessed=*/(2 * 4 * 8) / /*bytes_per_second=*/32.0) + - (/*bytes_accessed=*/(4 * 8) / /*bytes_per_second=*/128.0); - LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; - EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToMemory(*add, {{0, {}}}), - expected_memory_elapsed); - EXPECT_EQ(cost_analysis_->GetInstructionElapsedInAlternateMemory( - *add, {{0, {}}}, {}), - expected_memory_elapsed); - - // Put operand 0 and output in alternate memory. Still memory bound. - expected_memory_elapsed = - (/*bytes_accessed=*/(4 * 8) / /*bytes_per_second=*/32.0) + - (/*bytes_accessed=*/(2 * 4 * 8) / /*bytes_per_second=*/128.0); - LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; - EXPECT_EQ( - cost_analysis_->GetInstructionElapsedDueToMemory(*add, {{0, {}}}, {{}}), - expected_memory_elapsed); - EXPECT_EQ(cost_analysis_->GetInstructionElapsedInAlternateMemory( - *add, {{0, {}}}, {{}}), - expected_memory_elapsed); - - // Put everything in alternate memory. We're now compute bound. - expected_memory_elapsed = - /*bytes_accessed=*/(3 * 4 * 8) / /*bytes_per_second=*/128.0; - LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; - EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToMemory( - *add, {{0, {}}, {1, {}}}, {{}}), - expected_memory_elapsed); - EXPECT_EQ(cost_analysis_->GetInstructionElapsedInAlternateMemory( - *add, {{0, {}}, {1, {}}}, {{}}), - expected_compute_elapsed); -} - -TEST_F(MemorySpaceAssignmentCostAnalysisTest, PipelineOverhead) { - absl::string_view hlo_string = R"( - HloModule module, is_scheduled=true - - ENTRY Entry { - param0 = f32[2,4] parameter(0) - param1 = f32[2,4] parameter(1) - ROOT add = f32[2,4] add(param0, param1) - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - // Set the window size 64B. - TF_ASSERT_OK( - Initialize(module.get(), - /*pipeline_overhead_window_size_mib=*/(64.0 / 1024 / 1024))); - - const HloInstruction* add = module->entry_computation()->root_instruction(); - const float expected_compute_elapsed = - /*num_flops=*/8 / /*flops_per_second=*/8.0; - LOG(INFO) << "Expected compute elapsed = " << expected_compute_elapsed; - EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToCompute(*add), - expected_compute_elapsed); - float expected_memory_elapsed = - /*bytes_accessed=*/(3 * 4 * 8) / /*bytes_per_second=*/32.0; - LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; - EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToMemory(*add), - expected_memory_elapsed); - - float expected_overhead = expected_compute_elapsed * 2 / 3; - LOG(INFO) << "Expected overhead = " << expected_overhead; - EXPECT_EQ(cost_analysis_->GetDefaultMemoryAccessOverhead(*add), - expected_overhead); - // This HLO is memory-bound. - EXPECT_EQ(cost_analysis_->GetInstructionElapsed(*add), - expected_memory_elapsed + expected_overhead); - EXPECT_EQ( - cost_analysis_->GetInstructionElapsedInAlternateMemory(*add, {}, {}), - expected_memory_elapsed + expected_overhead); - - // Put operand 0 in alternate memory. Still memory bound. - expected_memory_elapsed = - (/*bytes_accessed=*/(2 * 4 * 8) / /*bytes_per_second=*/32.0) + - (/*bytes_accessed=*/(4 * 8) / /*bytes_per_second=*/128.0); - LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; - EXPECT_EQ(cost_analysis_->GetDefaultMemoryAccessOverhead(*add, {{0, {}}}), - expected_overhead); - EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToMemory(*add, {{0, {}}}), - expected_memory_elapsed); - EXPECT_EQ(cost_analysis_->GetInstructionElapsedInAlternateMemory( - *add, {{0, {}}}, {}), - expected_memory_elapsed + expected_overhead); - - // Put operand 0 and output in alternate memory. Still memory bound. - expected_memory_elapsed = - (/*bytes_accessed=*/(4 * 8) / /*bytes_per_second=*/32.0) + - (/*bytes_accessed=*/(2 * 4 * 8) / /*bytes_per_second=*/128.0); - LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; - expected_overhead = expected_compute_elapsed / 3; - LOG(INFO) << "Expected overhead = " << expected_overhead; - EXPECT_EQ( - cost_analysis_->GetDefaultMemoryAccessOverhead(*add, {{0, {}}}, {{}}), - expected_overhead); - EXPECT_EQ( - cost_analysis_->GetInstructionElapsedDueToMemory(*add, {{0, {}}}, {{}}), - expected_memory_elapsed); - EXPECT_EQ(cost_analysis_->GetInstructionElapsedInAlternateMemory( - *add, {{0, {}}}, {{}}), - expected_memory_elapsed + expected_overhead); - - // Put everything in alternate memory. We're now compute bound. - expected_memory_elapsed = - /*bytes_accessed=*/(3 * 4 * 8) / /*bytes_per_second=*/128.0; - LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; - expected_overhead = 0; - LOG(INFO) << "Expected overhead = " << expected_overhead; - EXPECT_EQ(cost_analysis_->GetDefaultMemoryAccessOverhead( - *add, {{0, {}}, {1, {}}}, {{}}), - expected_overhead); - EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToMemory( - *add, {{0, {}}, {1, {}}}, {{}}), - expected_memory_elapsed); - EXPECT_EQ(cost_analysis_->GetInstructionElapsedInAlternateMemory( - *add, {{0, {}}, {1, {}}}, {{}}), - expected_compute_elapsed); -} - class MemoryBoundLoopOptimizerTest : public HloTestBase { public: MemoryBoundLoopOptimizerTest() = default; @@ -10004,9 +9407,9 @@ class MemoryBoundLoopOptimizerTest : public HloTestBase { optimizer_options.set_allow_unsatisfied_fully_pipelined_prefetch(false); optimizer_options.set_min_num_iterations(3.0); options_.memory_bound_loop_optimizer_options = optimizer_options; - options_.alternate_mem_bandwidth_bytes_per_second = 128; - options_.async_copy_bandwidth_bytes_per_second = 32; - options_.pipeline_overhead_window_size_mib = 1; + cost_analysis_options_.alternate_mem_bandwidth_bytes_per_second = 128; + cost_analysis_options_.async_copy_bandwidth_bytes_per_second = 32; + cost_analysis_options_.pipeline_overhead_window_size_mib = 1; options.shape_size = ShapeSize; options.set_flops_per_second(16); options.set_bytes_per_second(32); @@ -10015,8 +9418,8 @@ class MemoryBoundLoopOptimizerTest : public HloTestBase { TF_RETURN_IF_ERROR( module->entry_computation()->Accept(hlo_cost_analysis_.get())); TF_ASSIGN_OR_RETURN(cost_analysis_, - MemorySpaceAssignmentCostAnalysis::Create( - *hlo_cost_analysis_, options_, *module)); + CostAnalysis::Create(*hlo_cost_analysis_, + cost_analysis_options_, *module)); TF_ASSIGN_OR_RETURN(alias_analysis_, HloAliasAnalysis::Run(module)); TF_ASSIGN_OR_RETURN(live_range_, HloLiveRange::Run(module->schedule(), *alias_analysis_, @@ -10224,9 +9627,9 @@ ENTRY Entry { if (!cost_analysis_) { TF_RETURN_IF_ERROR(Initialize(module, alternate_memory_size)); } - MemorySpaceAssignmentCostAnalysis::Cache cache; - memory_space_assignment::MemoryBoundednessBufferIntervalComparator - comparator(*cost_analysis_, &cache); + CostAnalysis::Cache cache; + MemoryBoundednessBufferIntervalComparator comparator(*cost_analysis_, + &cache); options_.buffer_interval_comparator = &comparator; CostAnalysisPrefetchIntervalPicker prefetch_interval_picker( CostAnalysisPrefetchIntervalPicker( @@ -10269,9 +9672,7 @@ ENTRY Entry { Status VerifyMsaEquivalence(HloModule* module, bool expect_unsupported_allocations = false) { // Create a map indexed by instruction number and operand number. - absl::flat_hash_map, - const MemorySpaceAssignment::Allocation*> - allocation_map; + absl::flat_hash_map, const Allocation*> allocation_map; for (const MemoryBoundLoopOptimizer::LoopValue& value : optimizer_->loop_values()) { // Skip verification for unsupported allocations as they will go through @@ -10325,25 +9726,22 @@ ENTRY Entry { TF_RET_CHECK(expect_unsupported_allocations); continue; } - const MemorySpaceAssignment::Allocation* allocation = + const Allocation* allocation = allocation_map.at({inst_number, operand_number}); if (!allocation->is_copy_allocation()) { // We don't expect a prefetch here. EXPECT_NE(operand->opcode(), HloOpcode::kCopyDone); int expected_memory_space = - allocation->memory_space() == - MemorySpaceAssignment::MemorySpace::kDefault + allocation->memory_space() == MemorySpace::kDefault ? kDefaultMemorySpace : kAlternateMemorySpace; EXPECT_EQ(operand->shape().layout().memory_space(), expected_memory_space); } else { - EXPECT_EQ(allocation->memory_space(), - MemorySpaceAssignment::MemorySpace::kAlternate); + EXPECT_EQ(allocation->memory_space(), MemorySpace::kAlternate); TF_RET_CHECK(operand->opcode() == HloOpcode::kCopyDone); - const MemorySpaceAssignment::CopyAllocation* copy_allocation = - static_cast( - allocation); + const CopyAllocation* copy_allocation = + static_cast(allocation); if (copy_allocation->copy_done_schedule_before() != inst_number) { // The only case where the copy done schedule before is not the // same as this use would be that this use is not the first use of @@ -10402,8 +9800,9 @@ ENTRY Entry { private: Options options_; + CostAnalysisOptions cost_analysis_options_; std::unique_ptr hlo_cost_analysis_; - std::unique_ptr cost_analysis_; + std::unique_ptr cost_analysis_; std::unique_ptr alias_analysis_; std::unique_ptr live_range_; std::unique_ptr optimizer_; @@ -10476,8 +9875,7 @@ TEST_F(MemoryBoundLoopOptimizerTest, NoAlternateMem) { optimizer->loop_values()) { LOG(INFO) << loop_value.ToString(); for (const auto& allocation : loop_value.allocations) { - EXPECT_EQ(allocation->memory_space(), - MemorySpaceAssignment::MemorySpace::kDefault); + EXPECT_EQ(allocation->memory_space(), MemorySpace::kDefault); for (const HloUse& use : allocation->uses()) { EXPECT_FALSE(seen_uses.contains(use)) << use.ToString(); seen_uses.insert(use); @@ -10573,20 +9971,19 @@ TEST_F(MemoryBoundLoopOptimizerTest, PrefetchFifoOrderWithOverlap) { // +=========+ // 13 14| 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14| 0 1 // prev | loop | next - std::vector prefetches; + std::vector prefetches; for (const MemoryBoundLoopOptimizer::LoopValue& loop_value : optimizer->loop_values()) { if (!loop_value.allocations.empty() && loop_value.allocations.back()->is_copy_allocation()) { - prefetches.push_back( - static_cast( - loop_value.allocations.back().get())); + prefetches.push_back(static_cast( + loop_value.allocations.back().get())); } } EXPECT_EQ(prefetches.size(), 3); bool seen_overlap = false; bool seen_nonoverlap = false; - for (const MemorySpaceAssignment::CopyAllocation* prefetch : prefetches) { + for (const CopyAllocation* prefetch : prefetches) { const HloUse& use = *prefetch->uses().begin(); if (use.instruction->name() == "op14") { EXPECT_EQ(prefetch->copy_done_schedule_before(), 14); @@ -10684,19 +10081,18 @@ TEST_F(MemoryBoundLoopOptimizerTest, PrefetchFifoOrderWithoutOverlap) { // =====> ===============================> // 13 14| 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14| 0 1 // prev | loop | next - std::vector prefetches; + std::vector prefetches; for (const MemoryBoundLoopOptimizer::LoopValue& loop_value : optimizer->loop_values()) { if (!loop_value.allocations.empty() && loop_value.allocations.back()->is_copy_allocation()) { - prefetches.push_back( - static_cast( - loop_value.allocations.back().get())); + prefetches.push_back(static_cast( + loop_value.allocations.back().get())); } } EXPECT_EQ(prefetches.size(), 2); std::optional expected_op14_copy_start_time; - for (const MemorySpaceAssignment::CopyAllocation* prefetch : prefetches) { + for (const CopyAllocation* prefetch : prefetches) { const HloUse& use = *prefetch->uses().begin(); if (use.instruction->name() == "op1") { EXPECT_EQ(prefetch->copy_done_schedule_before(), 1); @@ -10705,7 +10101,7 @@ TEST_F(MemoryBoundLoopOptimizerTest, PrefetchFifoOrderWithoutOverlap) { } } EXPECT_TRUE(expected_op14_copy_start_time.has_value()); - for (const MemorySpaceAssignment::CopyAllocation* prefetch : prefetches) { + for (const CopyAllocation* prefetch : prefetches) { const HloUse& use = *prefetch->uses().begin(); if (use.instruction->name() == "op14") { EXPECT_EQ(prefetch->copy_done_schedule_before(), 14); @@ -10771,20 +10167,19 @@ TEST_F(MemoryBoundLoopOptimizerTest, PrefetchFifoOrderWithOverlap2) { // ==> ========================================> ====== // 13 14| 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14| 0 1 // prev | loop | next - std::vector prefetches; + std::vector prefetches; for (const MemoryBoundLoopOptimizer::LoopValue& loop_value : optimizer->loop_values()) { if (!loop_value.allocations.empty() && loop_value.allocations.back()->is_copy_allocation()) { - prefetches.push_back( - static_cast( - loop_value.allocations.back().get())); + prefetches.push_back(static_cast( + loop_value.allocations.back().get())); } } EXPECT_EQ(prefetches.size(), 3); bool seen_overlap = false; bool seen_nonoverlap = false; - for (const MemorySpaceAssignment::CopyAllocation* prefetch : prefetches) { + for (const CopyAllocation* prefetch : prefetches) { const HloUse& use = *prefetch->uses().begin(); if (use.instruction->name() == "op13") { EXPECT_EQ(prefetch->copy_done_schedule_before(), 13); @@ -11201,7 +10596,7 @@ class SlicedPrefetchStartTimePickerTest : public ::testing::Test { std::vector Pick( const std::vector& schedule_data, int64_t num_slices, int64_t prefetch_start_time, int64_t prefetch_end_time) { - return memory_space_assignment::SlicedPrefetchStartTimePicker::Pick( + return SlicedPrefetchStartTimePicker::Pick( num_slices, prefetch_start_time, prefetch_end_time, [&schedule_data](int64_t exclusive_start_time, int64_t exclusive_end_time) { @@ -11424,8 +10819,8 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { } // A class that can be mocked to set expectations on slice proposals. To do - // that, we set memory_space_assignment::Options::propose_slice_fn to a lambda - // that calls our mocks ProposeSlices() method. + // that, we set Options::propose_slice_fn to a lambda that calls our mocks + // ProposeSlices() method. class SliceProposer { public: SliceProposer() = default; @@ -11445,9 +10840,11 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { // An HloInstruction* matcher for matching the asynchronous sliced copies // produced by MSA. In particular, the matcher performs the following // checks: - // - The copy is concluded with a concat-bitcast custom call + // - The copy is concluded with a concat-bitcast custom call, or a + // bitcast of a concat-bitcast custom call if expect_bitcasted_io is true // - The operands to the concat-bitcast are asynchronous slices of the - // expected operand + // expected operand, or asynchronous slices of a bitcast of the expected + // operand if expect_bitcasted_io is true // - The number of slices is as expected (i.e., // expected_slice_params_per_slice_in_spatial_order_.size()) // - The copy is from and to the correct memory spaces @@ -11465,47 +10862,57 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { AsyncSlicedCopy(int64_t to_space, int64_t from_space, std::vector> expected_slice_params_per_slice_in_spatial_order, - ::testing::Matcher operand) + ::testing::Matcher operand, + bool expect_bitcasted_io) : to_space_(to_space), from_space_(from_space), expected_slice_params_per_slice_in_spatial_order_( std::move(expected_slice_params_per_slice_in_spatial_order)), - custom_call_matcher_( - memory_space_assignment::kConcatBitcastCustomCall, - std::vector<::testing::Matcher>( - expected_slice_params_per_slice_in_spatial_order_.size(), - op::AsyncDone(op::AsyncStart(operand)))) {} + base_hlo_matcher_(CreateBaseHloMatcher( + operand, expected_slice_params_per_slice_in_spatial_order_.size(), + expect_bitcasted_io)), + expect_bitcasted_io_(expect_bitcasted_io) {} bool MatchAndExplain( const HloInstruction* instruction, ::testing::MatchResultListener* listener) const override { - // Match the custom call. - if (!custom_call_matcher_.MatchAndExplain(instruction, listener)) { + // Match opcodes and number of operands. + if (!base_hlo_matcher_.MatchAndExplain(instruction, listener)) { return false; } - // Check if the custom call has the proper memory space. - const HloInstruction* concat_bitcast = instruction; - if (!MatchMemorySpace(concat_bitcast, to_space_, "concat-bitcast", - listener)) { + // Check if the copied result has the proper memory space. + if (!MatchMemorySpace(instruction, to_space_, "copy result", listener)) { return false; } - // Check if the copied tensor has the proper memory space. + // Find some instructions in the async copy. + const HloInstruction* concat_bitcast = + (expect_bitcasted_io_ ? instruction->operand(0) : instruction); + VLOG(2) << "AsyncSlicedCopy identified the concat-bitcast as " + << concat_bitcast->name(); const HloInstruction* copy_operand = concat_bitcast->operand(0)->operand(0)->operand(0); - if (!MatchMemorySpace(copy_operand, from_space_, "copy operand", + const HloInstruction* original_copy_operand = + (expect_bitcasted_io_ ? copy_operand->operand(0) : copy_operand); + VLOG(2) << "AsyncSlicedCopy identified the copy operand as " + << copy_operand->name() << ", and the original copy operand as " + << original_copy_operand->name(); + + // Check if the copied tensor has the proper memory space. + if (!MatchMemorySpace(original_copy_operand, from_space_, "copy operand", listener)) { return false; } // Check if the copied tensor retains its shape. - if (!Shape::Equal().IgnoreMemorySpaceInLayout()(concat_bitcast->shape(), - copy_operand->shape())) { + if (!Shape::Equal().IgnoreMemorySpaceInLayout()( + instruction->shape(), original_copy_operand->shape())) { *listener << " has a shape of " - << copy_operand->shape().ToString(/*print_layout=*/true) + << original_copy_operand->shape().ToString( + /*print_layout=*/true) << " before copying but a shape of " - << concat_bitcast->shape().ToString(/*print_layout=*/true) + << instruction->shape().ToString(/*print_layout=*/true) << " after copying (ignoring memory space)"; return false; @@ -11581,7 +10988,7 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { } void DescribeTo(std::ostream* os) const override { - custom_call_matcher_.DescribeTo(os); + base_hlo_matcher_.DescribeTo(os); std::vector slice_parameters_per_operand; for (int op_idx = 0; op_idx < expected_slice_params_per_slice_in_spatial_order_.size(); @@ -11609,6 +11016,22 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { } private: + static ::testing::Matcher CreateBaseHloMatcher( + ::testing::Matcher operand, int64_t num_slices, + bool expect_bitcasted_io) { + if (expect_bitcasted_io) { + return op::Bitcast(op::CustomCall( + kConcatBitcastCustomCall, + std::vector<::testing::Matcher>( + num_slices, + op::AsyncDone(op::AsyncStart(op::Bitcast(operand)))))); + } + return op::CustomCall( + kConcatBitcastCustomCall, + std::vector<::testing::Matcher>( + num_slices, op::AsyncDone(op::AsyncStart(operand)))); + } + static bool MatchMemorySpace(const HloInstruction* instruction, int64_t expected_memory_space, std::string_view error_message_identifier, @@ -11636,7 +11059,8 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { int64_t from_space_; std::vector> expected_slice_params_per_slice_in_spatial_order_; - ::xla::testing::HloCustomCallMatcher custom_call_matcher_; + ::testing::Matcher base_hlo_matcher_; + bool expect_bitcasted_io_; }; // Returns an AsyncSlicedCopy matcher. @@ -11644,10 +11068,11 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { int64_t to_space, int64_t from_space, std::vector> expected_slice_params_per_slice_in_spatial_order, - ::testing::Matcher operand_matcher) { + ::testing::Matcher operand_matcher, + bool expect_bitcasted_io = false) { return ::testing::MakeMatcher(new AsyncSlicedCopy( to_space, from_space, expected_slice_params_per_slice_in_spatial_order, - operand_matcher)); + operand_matcher, expect_bitcasted_io)); } // We make our own matcher for SlicedPrefetchOptions to work around the fact @@ -11761,8 +11186,7 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { // Returns true if instruction is a concat-bitcast. static bool IsConcatBitcast(const HloInstruction* instruction) { - return instruction->IsCustomCall( - memory_space_assignment::kConcatBitcastCustomCall); + return instruction->IsCustomCall(kConcatBitcastCustomCall); } // Returns the index of the first instruction with the given name. @@ -12041,6 +11465,8 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { std::string_view slices_start_after_instruction_name, std::string_view slices_done_before_instruction_name, bool expect_slices_started_at_different_times) { + CHECK(concat_bitcast->IsCustomCall(kConcatBitcastCustomCall)); + // Get the schedule. auto entry_schedule = module.schedule().sequence(module.entry_computation()).instructions(); @@ -12114,29 +11540,38 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { } // Returns OkStatus iff: - // - When the slices of concat_bitcast are sorted in expected spatial order, - // they are assigned chunks that spatially fall in the same order AND - // - The slices of concat_bitcast are assigned contiguous memory chunks AND - // - The concat_bitcast is assigned a chunk that is the concatenation of the - // slice chunks AND - // - The size of the chunk assigned to the concat_bitcast has the same size - // as the instruction's shape + // - Each slice is assigned a chunk that is the same size as the slice + // instruction's shape. + // - When the slices of sliced_copy_result are sorted in expected spatial + // order, they are assigned chunks that spatially fall in the same order AND + // - The slices of sliced_copy_result are assigned contiguous memory chunks + // AND + // - The sliced_copy_result is assigned a chunk that is the concatenation of + // the slice chunks AND + // - The size of the chunk assigned to the sliced_copy_result has the same + // size as the instruction's shape static Status CheckSliceChunks(const PresetAssignments& assignments, - const HloInstruction* concat_bitcast) { + const HloInstruction* sliced_copy_result, + bool expect_bitcasted_io = false) { + const HloInstruction* concat_bitcast = + (expect_bitcasted_io ? sliced_copy_result->operand(0) + : sliced_copy_result); + CHECK(concat_bitcast->IsCustomCall(kConcatBitcastCustomCall)); + absl::flat_hash_map slices_to_chunks; - std::optional concat_bitcast_chunk = std::nullopt; + std::optional result_chunk = std::nullopt; for (const std::pair& position_chunk_pair : assignments.chunks()) { - if (position_chunk_pair.first.instruction == concat_bitcast) { - if (concat_bitcast_chunk.has_value()) { + if (position_chunk_pair.first.instruction == sliced_copy_result) { + if (result_chunk.has_value()) { return FailedPrecondition( - "%s", absl::StrCat("Concat-bitcast ", concat_bitcast->name(), + "%s", absl::StrCat("Sliced copy ", sliced_copy_result->name(), " is assigned more than one chunk: ", - concat_bitcast_chunk->ToString(), " and ", + result_chunk->ToString(), " and ", position_chunk_pair.second.ToString())); } - concat_bitcast_chunk = position_chunk_pair.second; + result_chunk = position_chunk_pair.second; } for (const HloInstruction* slice : concat_bitcast->operands()) { if (position_chunk_pair.first.instruction == slice) { @@ -12155,7 +11590,7 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { std::vector sorted_slices = SortSlicesInExpectedSpatialOrder(concat_bitcast); - VLOG(1) << "Chunk assignments for " << concat_bitcast->name() << ":\n" + VLOG(1) << "Chunk assignments for " << sliced_copy_result->name() << ":\n" << absl::StrJoin( sorted_slices, "\n", [&](std::string* out, const HloInstruction* slice) { @@ -12167,16 +11602,16 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { absl::StrAppend(out, " slice ", slice->name(), ": ", chunk); }) - << "\n concat-bitcast " << concat_bitcast->name() << ": " - << (concat_bitcast_chunk.has_value() - ? concat_bitcast_chunk->ToString() - : "no chunk assigned"); + << "\n sliced copy result " << sliced_copy_result->name() << ": " + << (result_chunk.has_value() ? result_chunk->ToString() + : "no chunk assigned"); if (sorted_slices.empty()) { return OkStatus(); } // Check that slices are assigned contiguous chunks that are spatially - // ordered according to sorted_slices. + // ordered according to sorted_slices. Also make sure that slices are + // assigned chunks with sizes that match their shape. int64_t previous_end = -1; int64_t min_offset = std::numeric_limits::max(); int64_t max_limit = std::numeric_limits::min(); @@ -12188,6 +11623,16 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { absl::StrCat("Slice ", slice->name(), " is not assigned a chunk")); } const Chunk& chunk = it->second; + + if (chunk.size != ShapeSize(slice->shape())) { + return FailedPrecondition( + "%s", + absl::StrCat("Slice ", slice->name(), " is assigned chunk ", + chunk.ToString(), " with size ", chunk.size, + ". Expected a size of ", ShapeSize(slice->shape()), + ", to match its shape.")); + } + if (previous_end != -1 && chunk.offset != previous_end) { return FailedPrecondition( "%s", absl::StrCat( @@ -12200,31 +11645,29 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { max_limit = std::max(max_limit, chunk.chunk_end()); } - // Check that the concat_bitcast is assigned a chunk that is the + // Check that the sliced copy result is assigned a chunk that is the // concatenation of the slice chunks. - if (!concat_bitcast_chunk.has_value()) { + if (!result_chunk.has_value()) { return FailedPrecondition( - "%s", absl::StrCat("Concat-bitcast ", concat_bitcast->name(), + "%s", absl::StrCat("Sliced copy result ", sliced_copy_result->name(), " is not assigned a chunk.")); } - Chunk expected_concat_bitcast_chunk = - Chunk::FromOffsetEnd(min_offset, max_limit); - if (!(*concat_bitcast_chunk == expected_concat_bitcast_chunk)) { + Chunk expected_result_chunk = Chunk::FromOffsetEnd(min_offset, max_limit); + if (!(*result_chunk == expected_result_chunk)) { return FailedPrecondition( - "%s", - absl::StrCat("Concat-bitcast ", concat_bitcast->name(), - " is assigned chunk ", concat_bitcast_chunk->ToString(), - " but its expected to be assigned chunk ", - expected_concat_bitcast_chunk.ToString())); + "%s", absl::StrCat("Sliced copy result ", sliced_copy_result->name(), + " is assigned chunk ", result_chunk->ToString(), + ", but it's expected to be assigned chunk ", + expected_result_chunk.ToString())); } - if (concat_bitcast_chunk->size != ShapeSize(concat_bitcast->shape())) { + if (result_chunk->size != ShapeSize(sliced_copy_result->shape())) { return FailedPrecondition( - "%s", - absl::StrCat( - "Concat-bitcast ", concat_bitcast->name(), " is assigned chunk ", - concat_bitcast_chunk->ToString(), " with size ", - concat_bitcast_chunk->size, ". Expected a size of ", - ShapeSize(concat_bitcast->shape()), ", to match its shape.")); + "%s", absl::StrCat("Sliced copy result ", sliced_copy_result->name(), + " is assigned chunk ", result_chunk->ToString(), + " with size ", result_chunk->size, + ". Expected a size of ", + ShapeSize(sliced_copy_result->shape()), + ", to match its shape.")); } return OkStatus(); @@ -12237,11 +11680,13 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { options_.max_size_in_bytes = 1024; options_.sliced_prefetch_options.set_max_slices(2); options_.sliced_prefetch_options.set_min_bytes(8); - options_.propose_slice_fn = - [&](const Shape& shape, - const memory_space_assignment::SlicedPrefetchOptions& options) { - return slice_proposer_.ProposeSlices(shape, options); - }; + options_.propose_slice_fn = [&](const Shape& shape, + const SlicedPrefetchOptions& options) { + return slice_proposer_.ProposeSlices(shape, options); + }; + options_.get_equivalent_s8_shape_fn = [](const Shape& original_shape) { + return ShapeUtil::MakeShape(S8, {ShapeSize(original_shape)}); + }; } bool allocate_across_sequential_calls() const override { return true; } @@ -12518,7 +11963,8 @@ ENTRY main { << module->ToString(HloPrintOptions::ShortParsable()); std::unique_ptr assignments = - AssignMemorySpaceUsingCostAnalysis(module.get(), options_); + AssignMemorySpaceUsingCostAnalysis( + module.get(), /*memory_space_options_override=*/options_); VLOG(1) << "Post-MSA module:\n" << module->ToString(HloPrintOptions::ShortParsable()); @@ -12594,7 +12040,8 @@ ENTRY main { << module->ToString(HloPrintOptions::ShortParsable()); std::unique_ptr assignments = - AssignMemorySpaceUsingCostAnalysis(module.get(), options_); + AssignMemorySpaceUsingCostAnalysis( + module.get(), /*memory_space_options_override=*/options_); VLOG(1) << "Post-MSA module:\n" << module->ToString(HloPrintOptions::ShortParsable()); @@ -13139,5 +12586,97 @@ TEST_F(RepackingTest, Colocations) { EXPECT_EQ(f.GetColocationsCount(), 3); EXPECT_THAT(f.GetColocations(), UnorderedElementsAre(&d, &e, &f)); } + +TEST_F(SlicedPrefetchTest, UniformSizedSlicing) { + std::string hlo_text = R"zz( +HloModule Slice, is_scheduled=true + +ENTRY main { + p0 = f32[8,8] parameter(0) + p1 = f32[8,8] parameter(1) + p2 = f32[8,16] parameter(2) + constant1 = f32[] constant(1.1) + + a = f32[8,8] tanh(p0) + b = f32[8,8] tanh(a) + c = f32[8,8] tanh(b) + d = f32[8,8] tanh(c) + e = f32[8,8] tanh(d) + f = f32[8,8] tanh(e) + g = f32[8,8] tanh(f) + h = f32[8,8] tanh(g) + + x = f32[8,8] add(p1, h) + padded_x = f32[8,16] pad(x, constant1), padding=0_0x0_8 + ROOT r = f32[8,16] add(padded_x, p2) +})zz"; + const Shape f32_8_16 = ShapeUtil::MakeShape(F32, {8, 16}); + const Shape s8_128 = ShapeUtil::MakeShape(S8, {128}); + + options_.sliced_prefetch_options.set_max_slices(100000); + options_.sliced_prefetch_options.set_preferred_slice_size(4 * 8 * 4); + + EXPECT_CALL(slice_proposer_, + ProposeSlices(f32_8_8_, EqualsSlicedPrefetchOptions( + options_.sliced_prefetch_options))) + .WillRepeatedly(Return(SliceProposalCollection({ + SliceProposal( + {s8_128, std::vector({{0, 128}}), ShapeSize(s8_128)}), + SliceProposal({s8_128, std::vector({{128, 256}}), + ShapeSize(s8_128)}), + }))); + + EXPECT_CALL(slice_proposer_, + ProposeSlices(f32_8_16, EqualsSlicedPrefetchOptions( + options_.sliced_prefetch_options))) + .WillRepeatedly(Return(SliceProposalCollection({ + SliceProposal( + {s8_128, std::vector({{0, 128}}), ShapeSize(s8_128)}), + SliceProposal({s8_128, std::vector({{128, 256}}), + ShapeSize(s8_128)}), + SliceProposal({s8_128, std::vector({{256, 384}}), + ShapeSize(s8_128)}), + SliceProposal({s8_128, std::vector({{384, 512}}), + ShapeSize(s8_128)}), + }))); + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text)); + VLOG(1) << "Original module:\n" + << module->ToString(HloPrintOptions::ShortParsable()); + + std::unique_ptr assignments = AssignMemorySpace( + module.get(), options_, + /*max_prefetch_interval=*/100, /*min_prefetch_interval=*/1); + + VLOG(1) << "Post-MSA module:\n" + << module->ToString(HloPrintOptions::ShortParsable()); + + auto root = module->entry_computation()->root_instruction(); + + // Expect p1 to be asynchronously copied via 2 slices, and p2 to be + // asynchronously copied via 4 slices. We expect p1 and p2 to be bitcast + // before slicing and after slicing. + EXPECT_THAT( + root, + op::Add(op::Pad(op::Add(IsAsyncSlicedCopy( + kAlternateMemorySpace, kDefaultMemorySpace, + {{{0, 128}}, {{128, 256}}}, op::Parameter(1), + /*expect_bitcasted_io=*/true), + /*don't care*/ _), + /*padding constant*/ _), + IsAsyncSlicedCopy( + kAlternateMemorySpace, kDefaultMemorySpace, + {{{0, 128}}, {{128, 256}}, {{256, 384}}, {{384, 512}}}, + op::Parameter(2), /*expect_bitcasted_io=*/true))); + + // Check expectations on the chunks assigned to the asynchronous sliced copy. + TF_EXPECT_OK(CheckSliceChunks(*assignments, root->operand(1), + /*expect_bitcasted_io=*/true)); + TF_EXPECT_OK(CheckSliceChunks(*assignments, + root->operand(0)->operand(0)->operand(0), + /*expect_bitcasted_io=*/true)); +} + } // namespace +} // namespace memory_space_assignment } // namespace xla diff --git a/third_party/xla/xla/service/memory_space_assignment/prefetch_interval_picker.cc b/third_party/xla/xla/service/memory_space_assignment/prefetch_interval_picker.cc new file mode 100644 index 00000000000000..ad63a509dc4fdb --- /dev/null +++ b/third_party/xla/xla/service/memory_space_assignment/prefetch_interval_picker.cc @@ -0,0 +1,553 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/memory_space_assignment/prefetch_interval_picker.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_live_range.h" +#include "xla/service/heap_simulator/heap_simulator.h" +#include "xla/service/hlo_value.h" +#include "xla/service/memory_space_assignment/cost_analysis.h" +#include "xla/service/memory_space_assignment/memory_space_assignment.pb.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/util.h" +#include "tsl/platform/logging.h" + +namespace xla { +namespace memory_space_assignment { +namespace { + +// Each time we retry compilation, increase the preferred eviction end time by +// this amount multiplied by preferred overlap to async copy ratio. +const float kEvictionRetryMultiplier = 2.0; + +// The number of decreasing intervals for CostAnalysisPrefetchIntervalPicker to +// return when it runs out of increasing intervals. Increasing this number may +// hurt compilation time. +const int kNumExploredDecreasingIntervals = 100; + +} // namespace + +bool InstructionCountPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy( + const Shape& shape, int64_t start_time, int64_t end_time) const { + return end_time - start_time <= max_overlap_count_; +} + +int64_t InstructionCountPrefetchIntervalPicker::PreferredEvictionEndTime( + const Shape& shape, int64_t start_time, int64_t latest_end_time) const { + return std::min(start_time + min_overlap_count_, latest_end_time); +} + +int64_t InstructionCountPrefetchIntervalPicker::LatestPrefetchStartTime( + const Shape& shape, int64_t start_time, int64_t end_time, + const HloUse* use) const { + return end_time - min_overlap_count_; +} + +int64_t InstructionCountPrefetchIntervalPicker::PreferredPrefetchStartTime( + const Shape& shape, int64_t earliest_prefetch_start_time, + int64_t latest_prefetch_start_time, int64_t prefetch_end_time) const { + return std::max(earliest_prefetch_start_time, + prefetch_end_time - max_overlap_count_); +} + +int64_t InstructionCountPrefetchIntervalPicker::EstimatedPrefetchEndTime( + const Shape& shape, int64_t start_time, int64_t end_time) const { + // For testing, assume the end time is the estimated prefetch end time. + return end_time; +} + +float InstructionCountPrefetchIntervalPicker::GetLogicalIntervalElapsed( + int64_t start_time, int64_t end_time) const { + // For testing, just assume every HLO takes 1 second. + return static_cast(end_time - start_time - 1); +} + +void InstructionCountPrefetchIntervalPicker::Begin( + const HloUse& use, int64_t start_time, int64_t end_time, + std::optional preferred_time) { + end_time_ = end_time; + const Shape& shape = ShapeUtil::GetSubshape( + use.instruction->operand(use.operand_number)->shape(), use.operand_index); + if (preferred_time) { + current_prefetch_time_ = *preferred_time; + } else { + current_prefetch_time_ = + PreferredPrefetchStartTime(shape, start_time, end_time, end_time); + } +} + +int64_t InstructionCountPrefetchIntervalPicker::Next() { + CHECK(!Done()) << "Prefetch interval picker's Next() is called even though " + "Done() is false"; + return current_prefetch_time_++; +} + +bool InstructionCountPrefetchIntervalPicker::Done() const { + return end_time_ - current_prefetch_time_ <= min_overlap_count_; +} + +int64_t InstructionCountPrefetchIntervalPicker::latest_time() const { + return end_time_ - min_overlap_count_ - 1; +} + +std::string InstructionCountPrefetchIntervalPicker::ToDebugString() const { + return absl::StrCat("Overlapped HLOs = ", end_time_ - current_prefetch_time_); +} + +std::string InstructionCountPrefetchIntervalPicker::ToNoCopyDebugString( + const Shape& shape, int64_t start_time, int64_t end_time) const { + return absl::StrCat("Overlapped HLOs = ", end_time - start_time); +} + +CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker( + const CostAnalysis& cost_analysis, float min_overlap_to_async_copy_ratio, + float preferred_overlap_to_async_copy_ratio, + float max_overlap_to_mem_size_async_copy_ratio, int64_t mem_size_bytes, + const Shape* shape_override) + : while_nest_level_( + cost_analysis.hlo_live_range().instruction_schedule().size() + 1, 0), + computation_nest_level_( + cost_analysis.hlo_live_range().instruction_schedule().size() + 1, 0), + cost_analysis_(cost_analysis), + min_overlap_to_async_copy_ratio_(min_overlap_to_async_copy_ratio), + preferred_overlap_to_async_copy_ratio_( + preferred_overlap_to_async_copy_ratio), + max_async_copy_elapsed_( + cost_analysis_.GetAsyncCopyElapsed( + ShapeUtil::MakeShape(S32, {mem_size_bytes / 4})) * + max_overlap_to_mem_size_async_copy_ratio), + shape_override_(shape_override ? std::optional(*shape_override) + : std::nullopt) { + instruction_schedule_ = + &cost_analysis_.hlo_live_range().instruction_schedule(); + + // Create a vector of elapsed times and while nesting levels of HLO + // instructions. The elapsed times are multiplied by + // pow(while_execution_count, nest_level) to account for executing the HLOs + // multiple times in while loops. + std::vector instructions_elapsed_time( + instruction_schedule_->size() + 1, 0.0); + int max_while_nest_level = 0; + for (const auto& instruction_and_logical_time : *instruction_schedule_) { + // To avoid double counting, don't include the elapsed time of while and + // conditional HLOs. + const HloInstruction* instruction = instruction_and_logical_time.first; + int64_t logical_time = instruction_and_logical_time.second; + if (logical_time >= instructions_elapsed_time.size()) { + instructions_elapsed_time.resize(logical_time + 1, 0.0); + while_nest_level_.resize(logical_time + 1, 0); + } + int while_nest_level = cost_analysis_.CalculateComputationNestLevel( + instruction_and_logical_time.first, /*while_only=*/true); + while_nest_level_[logical_time] = while_nest_level; + max_while_nest_level = std::max(max_while_nest_level, while_nest_level); + int computation_nest_level = cost_analysis_.CalculateComputationNestLevel( + instruction_and_logical_time.first, /*while_only=*/false); + computation_nest_level_[logical_time] = computation_nest_level; + if (instruction->opcode() == HloOpcode::kWhile || + instruction->opcode() == HloOpcode::kConditional) { + continue; + } + float elapsed_time = cost_analysis_.GetInstructionElapsed( + *instruction_and_logical_time.first); + instructions_elapsed_time[logical_time] = + elapsed_time * cost_analysis_.GetWhileNestMultiplier(while_nest_level); + } + // As an optimization, create a cumulative sum vector of elapsed time. + float cumsum = 0.0; + elapsed_time_cumsum_.reserve(instructions_elapsed_time.size()); + for (float elapsed_time : instructions_elapsed_time) { + cumsum += elapsed_time; + elapsed_time_cumsum_.push_back(cumsum); + } + // To be able to accurately determine the minimum nest level between a start + // time and an end time efficiently, populate a data structure that stores the + // closest 'smaller' nest level change index. + const int64_t size = instructions_elapsed_time.size(); + CHECK_EQ(size, while_nest_level_.size()); + std::vector most_recent_by_level(while_nest_level_.size(), -1); + int prev_nest_level = 0; + int change_idx = -1; + while_nest_level_change_.reserve(size); + for (int i = 0; i < size; ++i) { + int nest_level = while_nest_level_[i]; + if (nest_level != prev_nest_level) { + prev_nest_level = nest_level; + // Compute last change index by choosing the most recent instruction index + // with smaller nesting level. Note that it may happen that even though + // there were few different regions with other nest levels before, all of + // then are same or bigger than this one, in which case we'll end up with + // -1, e.g. if you got nest level 0 no need checking anything else. + change_idx = -1; + for (int smaller_level = 0; smaller_level < nest_level; smaller_level++) { + change_idx = std::max(change_idx, most_recent_by_level[smaller_level]); + } + } + most_recent_by_level[nest_level] = i; + while_nest_level_change_.push_back(change_idx); + } + for (int i = 0; i <= max_while_nest_level; ++i) { + while_execution_counts_.push_back(cost_analysis_.GetWhileNestMultiplier(i)); + } +} + +float CostAnalysisPrefetchIntervalPicker::GetMaxElapsedInAlternateMemory( + float async_copy_elapsed) const { + return max_async_copy_elapsed_; +} + +bool CostAnalysisPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy( + const Shape& shape, int64_t start_time, int64_t end_time) const { + // Even though this method returns if we allow the buffer in alternate memory + // _without_ asynchronous copies, calculate how long it would have taken to + // copy it and compare it to the elapsed time in the logical interval. + float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed( + shape_override_ ? *shape_override_ : shape); + float logical_interval_elapsed = + GetLogicalIntervalElapsed(start_time, end_time); + return GetMaxElapsedInAlternateMemory(async_copy_elapsed) > + logical_interval_elapsed; +} + +int64_t CostAnalysisPrefetchIntervalPicker::PreferredEvictionEndTime( + const Shape& shape, int64_t start_time, int64_t latest_end_time) const { + float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed( + shape_override_ ? *shape_override_ : shape); + int64_t end_time; + for (end_time = start_time + 1; end_time <= latest_end_time; ++end_time) { + float logical_interval_elapsed = + GetLogicalIntervalElapsed(start_time, end_time); + if (logical_interval_elapsed >= + (1 + kEvictionRetryMultiplier * retry_number_) * + preferred_overlap_to_async_copy_ratio_ * async_copy_elapsed) { + break; + } + } + return end_time; +} + +int64_t CostAnalysisPrefetchIntervalPicker::LatestPrefetchStartTime( + const Shape& shape, int64_t start_time, int64_t end_time, + const HloUse* use) const { + // Find the earliest time that satisfies max_overlap_to_async_copy_ratio_. + float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed( + shape_override_ ? *shape_override_ : shape); + // If there is a use, estimate the time we would save by having this op in + // alternate memory. + float inst_elapsed_reduction = 0.0f; + if (use) { + float elapsed_time = + cost_analysis_.GetInstructionElapsed(*use->instruction); + float elapsed_time_in_alternate_mem = + cost_analysis_.GetInstructionElapsedInAlternateMemory( + *use->instruction, + /*operands_in_alternate_mem=*/ + {std::make_pair(use->operand_number, use->operand_index)}, + /*outputs_in_alternate_mem=*/{}); + inst_elapsed_reduction = elapsed_time - elapsed_time_in_alternate_mem; + } + int end_nest_level = computation_nest_level_[end_time]; + + // Find the latest time we're allowed to start prefetching. + float min_interval = min_overlap_to_async_copy_ratio_ * async_copy_elapsed; + int latest_prefetch_time; + for (latest_prefetch_time = end_time - 1; + latest_prefetch_time >= start_time && + (computation_nest_level_[latest_prefetch_time] != end_nest_level || + min_interval > + GetLogicalIntervalElapsed(latest_prefetch_time, end_time) + + inst_elapsed_reduction); + --latest_prefetch_time) { + } + + return latest_prefetch_time; +} + +int64_t CostAnalysisPrefetchIntervalPicker::PreferredPrefetchStartTime( + const Shape& shape, int64_t earliest_prefetch_start_time, + int64_t latest_prefetch_start_time, int64_t prefetch_end_time) const { + // Between the earliest and latest prefetch interval, find the interval + // closest to the preferred interval and start iterating from there. + float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed( + shape_override_ ? *shape_override_ : shape); + int64_t preferred_prefetch_start_time = earliest_prefetch_start_time; + float preferred_interval = + preferred_overlap_to_async_copy_ratio_ * async_copy_elapsed; + float best_interval = GetLogicalIntervalElapsed(earliest_prefetch_start_time, + prefetch_end_time); + int end_nest_level = computation_nest_level_[prefetch_end_time]; + for (int64_t prefetch_start_time = earliest_prefetch_start_time + 1; + prefetch_start_time <= latest_prefetch_start_time; + ++prefetch_start_time) { + float interval = + GetLogicalIntervalElapsed(prefetch_start_time, prefetch_end_time); + if (computation_nest_level_[prefetch_start_time] == end_nest_level && + std::abs(preferred_interval - interval) < + std::abs(preferred_interval - best_interval)) { + best_interval = interval; + preferred_prefetch_start_time = prefetch_start_time; + } + } + return preferred_prefetch_start_time; +} + +int64_t CostAnalysisPrefetchIntervalPicker::LatestPrefetchEndTime( + int64_t original_prefetch_end_time, + int64_t proposed_prefetch_end_time) const { + // Iterate towards the beginning until we find a suitable end time that is the + // same while nest level as the original prefetch end time. + int64_t original_nest_level = + computation_nest_level_[original_prefetch_end_time]; + int64_t new_prefetch_end_time; + for (new_prefetch_end_time = proposed_prefetch_end_time; + computation_nest_level_[new_prefetch_end_time] != original_nest_level; + --new_prefetch_end_time) { + } + return new_prefetch_end_time; +} + +int64_t CostAnalysisPrefetchIntervalPicker::EstimatedPrefetchEndTime( + const Shape& shape, int64_t start_time, int64_t end_time) const { + float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed( + shape_override_ ? *shape_override_ : shape); + int64_t estimated_end_time; + for (estimated_end_time = start_time + 1; estimated_end_time < end_time; + ++estimated_end_time) { + float interval = GetLogicalIntervalElapsed(start_time, estimated_end_time); + if (interval >= async_copy_elapsed) { + break; + } + } + return estimated_end_time; +} + +void CostAnalysisPrefetchIntervalPicker::Begin( + const HloUse& use, int64_t start_time, int64_t end_time, + std::optional preferred_time) { + const Shape& shape = ShapeUtil::GetSubshape( + use.instruction->operand(use.operand_number)->shape(), use.operand_index); + // Find the earliest time that satisfies max_overlap_to_async_copy_ratio_. + async_copy_elapsed_ = cost_analysis_.GetAsyncCopyElapsed( + shape_override_ ? *shape_override_ : shape); + // Estimate the time we would save by having this op in alternate memory. + float elapsed_time = cost_analysis_.GetInstructionElapsed(*use.instruction); + float elapsed_time_in_alternate_mem = + cost_analysis_.GetInstructionElapsedInAlternateMemory( + *use.instruction, /*operands_in_alternate_mem=*/ + {std::make_pair(use.operand_number, use.operand_index)}, + /*outputs_in_alternate_mem=*/{}); + inst_elapsed_reduction_ = elapsed_time - elapsed_time_in_alternate_mem; + end_logical_time_ = end_time; + int end_nest_level = computation_nest_level_[end_logical_time_]; + + // Find the latest time we're allowed to start prefetching. + float min_interval = min_overlap_to_async_copy_ratio_ * async_copy_elapsed_; + latest_prefetch_time_ = + LatestPrefetchStartTime(shape, start_time, end_time, &use); + + // Find the earliest time we're allowed to start prefetching. + float max_interval = GetMaxElapsedInAlternateMemory(async_copy_elapsed_); + for (earliest_prefetch_time_ = start_time; + earliest_prefetch_time_ < latest_prefetch_time_ && + (computation_nest_level_[earliest_prefetch_time_] != end_nest_level || + max_interval < GetLogicalIntervalElapsed(earliest_prefetch_time_, + end_logical_time_)); + ++earliest_prefetch_time_) { + } + if (earliest_prefetch_time_ > latest_prefetch_time_) { + // There is no available prefetch interval for the given start and end + // times. Set the iterators accordingly to ensure Done() returns true. + increasing_prefetch_time_iterator_ = earliest_prefetch_time_; + decreasing_prefetch_time_iterator_ = latest_prefetch_time_; + CHECK(Done()); + return; + } + + int64_t starting_prefetch_time; + if (preferred_time && *preferred_time <= latest_prefetch_time_) { + starting_prefetch_time = *preferred_time; + } else { + starting_prefetch_time = + PreferredPrefetchStartTime(shape, earliest_prefetch_time_, + latest_prefetch_time_, end_logical_time_); + } + float preferred_interval = + preferred_overlap_to_async_copy_ratio_ * async_copy_elapsed_; + VLOG(4) << "Interval min/max/preferred = " << min_interval << " " + << max_interval << " " << preferred_interval + << " prefetch time earliest/latest/starting = " + << earliest_prefetch_time_ << " " << latest_prefetch_time_ << " " + << starting_prefetch_time; + + increasing_prefetch_time_iterator_ = starting_prefetch_time; + decreasing_prefetch_time_iterator_ = starting_prefetch_time; + using_increasing_prefetch_time_iterator_ = true; + // Since both iterators start at the same position, call Next() once to + // advance one of the iterators. + Next(); +} + +int64_t CostAnalysisPrefetchIntervalPicker::Next() { + CHECK(!Done()) << "Prefetch interval picker's Next() is called even though " + "Done() is false"; + if (using_increasing_prefetch_time_iterator_) { + int64_t prefetch_time = increasing_prefetch_time_iterator_++; + while (increasing_prefetch_time_iterator_ <= latest_prefetch_time_ && + computation_nest_level_[increasing_prefetch_time_iterator_] != + computation_nest_level_[end_logical_time_]) { + ++increasing_prefetch_time_iterator_; + } + if (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_) { + using_increasing_prefetch_time_iterator_ = false; + } + return prefetch_time; + } else { + int64_t prefetch_time = decreasing_prefetch_time_iterator_--; + // As a compilation time optimization, reduce the number of intervals that + // this prefetch interval picker returns. When we run out of the increasing + // prefetch time iterator, only explore up to + // kNumExploredDecreasingIntervals intervals. To do that, calculate the + // 1/kNumExploredDecreasingIntervals of the elapsed time between the + // earliest prefetch time and the use, and decrement the iterator until the + // prefetch elapsed time is at least as large as this target value. This + // allows us to reduce the number of expensive heap fit and resource checks + // when the graph consists of a large number of fast-executing HLOs. + // + // Shown pictorially, assuming kNumExploredDecreasingIntervals = 3 and the + // numbers indicating the elapsed time of the HLOs, only the indicated + // options for prefetch start time would be explored: + // + // ---1---1---3---1---1---1---1---0---0---0---0---1---5---X + // ^ ^ ^ ^ + // Option3 Option2 Option1 Use + // (Earliest) + float next_target_interval_elapsed = 0; + if (increasing_prefetch_time_iterator_ > latest_prefetch_time_) { + next_target_interval_elapsed = + GetLogicalIntervalElapsed(prefetch_time, end_logical_time_) + + (GetLogicalIntervalElapsed(earliest_prefetch_time_, + end_logical_time_) / + kNumExploredDecreasingIntervals); + VLOG(3) << "Next target interval elapsed: " + << next_target_interval_elapsed; + } + while (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_ && + (computation_nest_level_[decreasing_prefetch_time_iterator_] != + computation_nest_level_[end_logical_time_] || + GetLogicalIntervalElapsed(decreasing_prefetch_time_iterator_, + end_logical_time_) < + next_target_interval_elapsed)) { + --decreasing_prefetch_time_iterator_; + } + if (increasing_prefetch_time_iterator_ <= latest_prefetch_time_) { + using_increasing_prefetch_time_iterator_ = true; + } + return prefetch_time; + } +} + +bool CostAnalysisPrefetchIntervalPicker::Done() const { + return increasing_prefetch_time_iterator_ > latest_prefetch_time_ && + decreasing_prefetch_time_iterator_ < earliest_prefetch_time_; +} + +int64_t CostAnalysisPrefetchIntervalPicker::latest_time() const { + return latest_prefetch_time_; +} + +void CostAnalysisPrefetchIntervalPicker::SetRetryNumber(int retry_number) { + retry_number_ = retry_number; +} + +int CostAnalysisPrefetchIntervalPicker::GetMinWhileNestLevel( + int64_t start_time, int64_t end_time) const { + int min_nest_level = + std::min(while_nest_level_[start_time], while_nest_level_[end_time]); + int change_idx = while_nest_level_change_[end_time]; + while (change_idx >= start_time) { + min_nest_level = std::min(min_nest_level, while_nest_level_[change_idx]); + change_idx = while_nest_level_change_[change_idx]; + } + return min_nest_level; +} + +float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed( + int64_t start_time, int64_t end_time) const { + CHECK_LE(start_time, end_time); + if (start_time == end_time) { + return 0.0; + } + if (start_time < 0) { + start_time = 0; + } + // Since elapsed_time_cumsum_ is already weighed by the while loop nesting + // level, normalize the elapsed time by dividing with the nesting factor of + // the interval (start and end times). + int interval_while_nest_level = GetMinWhileNestLevel(start_time, end_time); + return (elapsed_time_cumsum_[end_time - 1] - + elapsed_time_cumsum_[start_time]) / + while_execution_counts_[interval_while_nest_level]; +} + +std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const { + int current_logical_prefetch_time = using_increasing_prefetch_time_iterator_ + ? increasing_prefetch_time_iterator_ + : decreasing_prefetch_time_iterator_; + float logical_interval_elapsed = GetLogicalIntervalElapsed( + current_logical_prefetch_time, end_logical_time_); + return absl::StrCat( + "Async copy elapsed (s) = ", async_copy_elapsed_, + ", inst elapsed reduction (s) = ", inst_elapsed_reduction_, + ", logical interval elapsed (s) = ", logical_interval_elapsed, + ", interval = (", current_logical_prefetch_time, ", ", end_logical_time_, + ")"); +} + +std::string CostAnalysisPrefetchIntervalPicker::ToNoCopyDebugString( + const Shape& shape, int64_t start_time, int64_t end_time) const { + float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed( + shape_override_ ? *shape_override_ : shape); + float logical_interval_elapsed = + GetLogicalIntervalElapsed(start_time, end_time); + return absl::StrCat( + "Async copy elapsed (s) = ", async_copy_elapsed, + ", logical interval elapsed (s) = ", logical_interval_elapsed); +} + +std::optional +CostAnalysisPrefetchIntervalPicker::BufferIntervalAlternateMemoryBenefit( + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) + const { + return cost_analysis_.GetMemoryBoundedness(interval); +} + +} // namespace memory_space_assignment +} // namespace xla diff --git a/third_party/xla/xla/service/memory_space_assignment/prefetch_interval_picker.h b/third_party/xla/xla/service/memory_space_assignment/prefetch_interval_picker.h new file mode 100644 index 00000000000000..0ae8af53071283 --- /dev/null +++ b/third_party/xla/xla/service/memory_space_assignment/prefetch_interval_picker.h @@ -0,0 +1,292 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_PREFETCH_INTERVAL_PICKER_H_ +#define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_PREFETCH_INTERVAL_PICKER_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/heap_simulator/heap_simulator.h" +#include "xla/service/hlo.pb.h" +#include "xla/service/hlo_value.h" +#include "xla/service/memory_space_assignment/cost_analysis.h" +#include "xla/service/memory_space_assignment/memory_space_assignment.pb.h" +#include "xla/shape.h" +#include "xla/util.h" + +namespace xla { +namespace memory_space_assignment { + +// Abstract base class that memory space assignment uses to pick prefetch +// intervals. +class PrefetchIntervalPicker { + public: + PrefetchIntervalPicker() = default; + virtual ~PrefetchIntervalPicker() = default; + + // Returns true if the buffer can be allocated in alternate memory space + // without any copies (prefetches). + virtual bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape, + int64_t start_time, + int64_t end_time) const = 0; + + // Returns the preferred end time for an eviction that starts at a given time + // and must end by the given end time. + virtual int64_t PreferredEvictionEndTime(const Shape& shape, + int64_t start_time, + int64_t latest_end_time) const = 0; + + // Returns the latest time that a prefetch can start. + virtual int64_t LatestPrefetchStartTime(const Shape& shape, + int64_t start_time, int64_t end_time, + const HloUse* use) const = 0; + + // Returns the preferred time that a prefetch can start. + virtual int64_t PreferredPrefetchStartTime( + const Shape& shape, int64_t earliest_prefetch_start_time, + int64_t latest_prefetch_start_time, int64_t prefetch_end_time) const = 0; + + // Returns the latest time that a prefetch can end that is less than or equal + // to proposed_prefetch_end_time. + virtual int64_t LatestPrefetchEndTime( + int64_t original_prefetch_end_time, + int64_t proposed_prefetch_end_time) const { + return proposed_prefetch_end_time; + } + + // Returns the estimated end time of a prefetch that starts at the given time. + virtual int64_t EstimatedPrefetchEndTime(const Shape& shape, + int64_t start_time, + int64_t end_time) const = 0; + + // Returns the elapsed time in seconds between the logical interval that + // corresponds to the instruction schedule. + virtual float GetLogicalIntervalElapsed(int64_t start_time, + int64_t end_time) const = 0; + + // Begins the iterator for the first start time of the prefetch. + virtual void Begin(const HloUse& use, int64_t start_time, int64_t end_time, + std::optional preferred_time) = 0; + + // Advances the start time of the prefetch and returns that value. + virtual int64_t Next() = 0; + + // Returns true if the available prefetch intervals have been exhausted. + virtual bool Done() const = 0; + + // Returns the latest time the prefetch interval picker will have pick. + virtual int64_t latest_time() const = 0; + + // The retry number can be used to modify the interval picking policies. The + // first attempt will have a retry_number of 0, then 1, etc. + virtual void SetRetryNumber(int retry_number) { + retry_number_ = retry_number; + } + int retry_number() const { return retry_number_; } + + // Returns a debug string for the current state of the prefetch interval + // picker. + virtual std::string ToDebugString() const = 0; + + // Returns a debug string for no-copy allocation. + virtual std::string ToNoCopyDebugString(const Shape& shape, + int64_t start_time, + int64_t end_time) const = 0; + + // Prefetch interval pickers may return a value corresponding to the benefit + // of placing the BufferInterval in the alternate memory. The larger value, + // the more beneficial. + virtual std::optional BufferIntervalAlternateMemoryBenefit( + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) + const { + return std::nullopt; + } + + protected: + const absl::flat_hash_map* + instruction_schedule_ = nullptr; + int retry_number_ = 0; +}; + +// Prefetch interval picker that uses instruction count to overlap asynchronous +// copies with independent computation. The min and max overlap counts describe +// the number of independent HLOs overlapped while a value is being prefetched +// into the alternate memory (between CopyStart and CopyDone HLO instructions). +// max_overlap_count attempts to prevent bringing tensors into the alternate +// memory too eagerly and hence occupying the space for other tensors which +// might use it. min_overlap_count attempts to prevent cases where tensors are +// prefetched into the alternate memory without sufficient time for the copy to +// take place. In those cases, it's just better to keep the tensor in the +// default memory instead of hurting the critical path with this copy that +// likely won't finish in time. +class InstructionCountPrefetchIntervalPicker : public PrefetchIntervalPicker { + public: + InstructionCountPrefetchIntervalPicker(int64_t min_overlap_count, + int64_t max_overlap_count) + : min_overlap_count_(min_overlap_count), + max_overlap_count_(max_overlap_count) {} + + bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape, + int64_t start_time, + int64_t end_time) const override; + + int64_t PreferredEvictionEndTime(const Shape& shape, int64_t start_time, + int64_t latest_end_time) const override; + + int64_t LatestPrefetchStartTime(const Shape& shape, int64_t start_time, + int64_t end_time, + const HloUse* use) const override; + + int64_t PreferredPrefetchStartTime(const Shape& shape, + int64_t earliest_prefetch_start_time, + int64_t latest_prefetch_start_time, + int64_t prefetch_end_time) const override; + + int64_t EstimatedPrefetchEndTime(const Shape& shape, int64_t start_time, + int64_t end_time) const override; + float GetLogicalIntervalElapsed(int64_t start_time, + int64_t end_time) const override; + + void Begin(const HloUse& use, int64_t start_time, int64_t end_time, + std::optional preferred_time) override; + + int64_t Next() override; + bool Done() const override; + + int64_t latest_time() const override; + + std::string ToDebugString() const override; + std::string ToNoCopyDebugString(const Shape& shape, int64_t start_time, + int64_t end_time) const override; + + private: + int64_t min_overlap_count_; + int64_t max_overlap_count_; + int64_t end_time_; + int64_t current_prefetch_time_; +}; + +// Prefetch interval picker that uses cost analysis to overlap asynchronous +// copies with independent computation. It uses min (independent computation +// duration) / (asynchronous copy duration) ratio to guide whether the prefetch +// is within the lower bound. For the upper bound, it restricts the maximum +// duration that a buffer may occupy the alternate memory space as a multiple of +// the time it would take to copy a buffer that is the size of the alternate +// memory. It starts with the preferred ratio in Begin() and works its way for +// alternately earlier and later prefetches until hitting min and max ratios. +// The value for buffer size for max async copy is a mechanism to prevent +// copying small buffers between the two memories unnecessarily. For calculating +// the max time that the buffer can reside in alternate memory, we use the +// larger of this value and the actual size of the buffer. A shape override can +// also be provided which causes the interval picker to use that shape for async +// copy durations instead of the actual shape of the copy. +class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker { + public: + CostAnalysisPrefetchIntervalPicker( + const CostAnalysis& cost_analysis, float min_overlap_to_async_copy_ratio, + float preferred_overlap_to_async_copy_ratio, + float max_overlap_to_mem_size_async_copy_ratio, int64_t mem_size_bytes, + const Shape* shape_override = nullptr); + + bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape, + int64_t start_time, + int64_t end_time) const override; + + int64_t PreferredEvictionEndTime(const Shape& shape, int64_t start_time, + int64_t latest_end_time) const override; + + int64_t LatestPrefetchEndTime( + int64_t original_prefetch_end_time, + int64_t proposed_prefetch_end_time) const override; + + int64_t LatestPrefetchStartTime(const Shape& shape, int64_t start_time, + int64_t end_time, + const HloUse* use) const override; + + int64_t PreferredPrefetchStartTime(const Shape& shape, + int64_t earliest_prefetch_start_time, + int64_t latest_prefetch_start_time, + int64_t prefetch_end_time) const override; + + int64_t EstimatedPrefetchEndTime(const Shape& shape, int64_t start_time, + int64_t end_time) const override; + float GetLogicalIntervalElapsed(int64_t start_time, + int64_t end_time) const override; + + void Begin(const HloUse& use, int64_t start_time, int64_t end_time, + std::optional preferred_time) override; + + int64_t Next() override; + bool Done() const override; + + int64_t latest_time() const override; + + void SetRetryNumber(int retry_number) override; + + std::string ToDebugString() const override; + std::string ToNoCopyDebugString(const Shape& shape, int64_t start_time, + int64_t end_time) const override; + + std::optional BufferIntervalAlternateMemoryBenefit( + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) + const override; + + private: + // Finds the minimum nest level in the given interval. + int GetMinWhileNestLevel(int64_t start_time, int64_t end_time) const; + + // Given the elapsed time to copy this buffer to the alternate memory, returns + // the longest time that this buffer may reside in the alternate memory space. + float GetMaxElapsedInAlternateMemory(float async_copy_elapsed) const; + + // For each instruction in the flattened schedule, maintain their elapsed time + // (in cumulative sum) and while nesting level. + std::vector elapsed_time_cumsum_; + std::vector while_nest_level_; + std::vector computation_nest_level_; + // Maintain the index of the most recent (before this instruction) nest level + // change in order to efficiently determine the minimum nest level in an + // interval. + std::vector while_nest_level_change_; + + const CostAnalysis& cost_analysis_; + float min_overlap_to_async_copy_ratio_; + float preferred_overlap_to_async_copy_ratio_; + float max_async_copy_elapsed_; + float async_copy_elapsed_; + float inst_elapsed_reduction_; + int64_t end_logical_time_; + int64_t earliest_prefetch_time_; + int64_t latest_prefetch_time_; + bool using_increasing_prefetch_time_iterator_ = true; + int64_t increasing_prefetch_time_iterator_; + int64_t decreasing_prefetch_time_iterator_; + + std::vector while_execution_counts_; + // Shape override is used to override the shape of the shape of the async copy + // to treat all async copies the same duration. Having an override forces + // prefetches to be scheduled roughly in FIFO order. + std::optional shape_override_; +}; + +} // namespace memory_space_assignment +} // namespace xla + +#endif // XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_PREFETCH_INTERVAL_PICKER_H_ diff --git a/third_party/xla/xla/service/memory_space_assignment/prefetch_interval_picker_test.cc b/third_party/xla/xla/service/memory_space_assignment/prefetch_interval_picker_test.cc new file mode 100644 index 00000000000000..7b8cac3fcab70a --- /dev/null +++ b/third_party/xla/xla/service/memory_space_assignment/prefetch_interval_picker_test.cc @@ -0,0 +1,406 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/memory_space_assignment/prefetch_interval_picker.h" + +#include +#include + +#include +#include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/service/hlo_value.h" +#include "xla/service/memory_space_assignment/cost_analysis.h" +#include "xla/service/memory_space_assignment/testing_utils.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace memory_space_assignment { +namespace { + +constexpr int64_t kPointerSize = 8; + +int64_t ShapeSize(const Shape& shape) { + return ShapeUtil::ByteSizeOf(shape, kPointerSize); +} + +using CostAnalysisPrefetchIntervalPickerTest = HloTestBase; + +TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrder) { + absl::string_view hlo_string = R"( + HloModule bug, is_scheduled=true + + ENTRY Entry { + param0 = f32[2,4] parameter(0) + a = f32[2,4] negate(param0) + b = f32[2,4] negate(a) + c = f32[2,4] negate(b) + d = f32[2,4] negate(c) + e = f32[2,4] negate(d) + f = f32[2,4] negate(e) + g = f32[2,4] negate(f) + h = f32[2,4] negate(g) + i = f32[2,4] negate(h) + j = f32[2,4] negate(i) + k = f32[2,4] negate(j) + l = f32[2,4] negate(k) + m = f32[2,4] negate(l) + n = f32[2,4] negate(m) + o = f32[2,4] negate(n) + p = f32[2,4] negate(o) + q = f32[2,4] negate(p) + r = f32[2,4] negate(q) + s = f32[2,4] negate(r) + t = f32[2,4] negate(s) + u = f32[2,4] negate(t) + ROOT v = f32[2,4] add(u, param0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + HloCostAnalysis hlo_cost_analysis(ShapeSize); + CostAnalysisOptions options; + TF_ASSERT_OK_AND_ASSIGN( + auto cost_analysis, + FakeCostAnalysis::Create(hlo_cost_analysis, *module, options)); + CostAnalysisPrefetchIntervalPicker interval_picker( + *cost_analysis, + /*min_overlap_to_async_copy_ratio=*/1.0, + /*preferred_overlap_to_async_copy_ratio=*/2.0, + /*max_overlap_to_mem_size_async_copy_ratio=*/4.0, + /*mem_size_bytes=*/32); + + HloInstruction* root = module->entry_computation()->root_instruction(); + const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}}; + interval_picker.Begin(use, /*start_time=*/0, /*end_time=*/22, std::nullopt); + + // Expect that the first interval is (15, 22), which has elapsed time of 6.0, + // twice of the async copy elased (3.0). Then we expect that intervals will be + // visited in alternating increasing and decreasing orders until hitting the + // min and max async copy overlap ratios, which are the intervals (18, 22) + // and (9, 22) respectively. + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 15); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 16); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 14); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 17); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 13); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 18); // Min async overlap ratio reached. + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 12); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 11); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 10); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 9); // Max async overlap ratio reached. + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_TRUE(interval_picker.Done()); + + // Expect that if the time between start_time and end_time is too short, there + // won't be any available intervals. + interval_picker.Begin(use, /*start_time=*/19, /*end_time=*/22, std::nullopt); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_TRUE(interval_picker.Done()); +} + +TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrderWhile) { + absl::string_view hlo_string = R"( + HloModule bug, is_scheduled=true + + while_condition { + param1 = (f32[2,4]) parameter(0) // 19 + ROOT cond = pred[] constant(true) // 20 + } + + while_body { + param2 = (f32[2,4]) parameter(0) // 21 + gte2 = f32[2,4] get-tuple-element(param2), index=0 // 22 + add = f32[2,4] add(gte2, gte2) // 23 + ROOT tuple2 = (f32[2,4]) tuple(add) // 24 + } + + ENTRY Entry { + param0 = f32[2,4] parameter(0) // 0 + a = f32[2,4] negate(param0) // 1 + b = f32[2,4] negate(a) // 2 + c = f32[2,4] negate(b) // 3 + d = f32[2,4] negate(c) // 4 + e = f32[2,4] negate(d) // 5 + f = f32[2,4] negate(e) // 6 + g = f32[2,4] negate(f) // 7 + h = f32[2,4] negate(g) // 8 + i = f32[2,4] negate(h) // 9 + j = f32[2,4] negate(i) // 10 + k = f32[2,4] negate(j) // 11 + l = f32[2,4] negate(k) // 12 + m = f32[2,4] negate(l) // 13 + n = f32[2,4] negate(m) // 14 + o = f32[2,4] negate(n) // 15 + p = f32[2,4] negate(o) // 16 + q = f32[2,4] negate(p) // 17 + tuple = (f32[2,4]) tuple(q) // 18 + while = (f32[2,4]) while(tuple), condition=while_condition, body=while_body // 25 + gte1 = f32[2,4] get-tuple-element(while), index=0 // 26 + r = f32[2,4] negate(gte1) // 27 + s = f32[2,4] negate(r) // 28 + t = f32[2,4] negate(s) // 29 + u = f32[2,4] negate(t) // 30 + ROOT v = f32[2,4] add(u, param0) // 31 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + HloCostAnalysis hlo_cost_analysis(ShapeSize); + CostAnalysisOptions options; + TF_ASSERT_OK_AND_ASSIGN( + auto cost_analysis, + FakeCostAnalysis::Create(hlo_cost_analysis, *module, options)); + CostAnalysisPrefetchIntervalPicker interval_picker( + *cost_analysis, + /*min_overlap_to_async_copy_ratio=*/1.0, + /*preferred_overlap_to_async_copy_ratio=*/2.0, + /*max_overlap_to_mem_size_async_copy_ratio=*/12.0, + /*mem_size_bytes=*/32); + + EXPECT_EQ(cost_analysis->GetWhileNestMultiplier(1), 5.0); + HloInstruction* root = module->entry_computation()->root_instruction(); + const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}}; + interval_picker.Begin(use, /*start_time=*/0, /*end_time=*/31, std::nullopt); + + // Because there are while loop computations between [19, 24], we ensure that + // the interval picker avoids this interval. + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 25); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 26); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 18); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 27); // Min async overlap ratio reached. + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 17); // Max async overlap ratio reached. + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_TRUE(interval_picker.Done()); +} + +TEST_F(CostAnalysisPrefetchIntervalPickerTest, NestedWhile) { + // This test is to check against a bug where we didn't assign + // while_nest_level_ for while instructions, and defaulting to 0. This could + // cause the prefetch interval logic to think a nested while instruction is + // the same level as the outermost computation. + absl::string_view hlo_string = R"( + HloModule bug, is_scheduled=true + + while_condition.2 { + param1 = (f32[2,4]) parameter(0) // 11 + ROOT cond = pred[] constant(true) // 12 + } + + while_body.2 { + param2 = (f32[2,4]) parameter(0) // 13 + gte2 = f32[2,4] get-tuple-element(param2), index=0 // 14 + add = f32[2,4] add(gte2, gte2) // 15 + ROOT tuple2 = (f32[2,4]) tuple(add) // 16 + } + + while_condition.1 { + param3 = (f32[2,4]) parameter(0) // 5 + ROOT cond = pred[] constant(true) // 6 + } + + while_body.1 { + param4 = (f32[2,4]) parameter(0) // 7 + gte1 = f32[2,4] get-tuple-element(param4), index=0 // 8 + add1 = f32[2,4] add(gte1, gte1) // 9 + tuple1 = (f32[2,4]) tuple(add1) // 10 + while = (f32[2,4]) while(tuple1), condition=while_condition.2, body=while_body.2 // 17 + gte2 = f32[2,4] get-tuple-element(while), index=0 // 18 + add2 = f32[2,4] add(gte2, gte2) // 19 + ROOT tuple2 = (f32[2,4]) tuple(add2) // 20 + } + + ENTRY Entry { + param0 = f32[2,4] parameter(0) // 0 + a = f32[2,4] negate(param0) // 1 + b = f32[2,4] negate(a) // 2 + c = f32[2,4] negate(b) // 3 + tuple = (f32[2,4]) tuple(c) // 4 + while = (f32[2,4]) while(tuple), condition=while_condition.1, body=while_body.1 // 21 + gte1 = f32[2,4] get-tuple-element(while), index=0 // 22 + ROOT root = f32[2,4] add(gte1, param0) // 23 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + HloCostAnalysis hlo_cost_analysis(ShapeSize); + CostAnalysisOptions options; + TF_ASSERT_OK_AND_ASSIGN( + auto cost_analysis, + FakeCostAnalysis::Create(hlo_cost_analysis, *module, options)); + CostAnalysisPrefetchIntervalPicker interval_picker( + *cost_analysis, + /*min_overlap_to_async_copy_ratio=*/1.0, + /*preferred_overlap_to_async_copy_ratio=*/2.0, + /*max_overlap_to_mem_size_async_copy_ratio=*/12.0, + /*mem_size_bytes=*/32); + + HloInstruction* root = module->entry_computation()->root_instruction(); + const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}}; + const Shape& shape = root->operand(1)->shape(); + + // We expect the root's latest prefetch start time to be before the while loop + // (logical time 4). + EXPECT_EQ(interval_picker.LatestPrefetchStartTime(shape, /*start_time=*/0, + /*end_time=*/23, &use), + 4); +} + +TEST_F(CostAnalysisPrefetchIntervalPickerTest, ConsecutiveConditionals) { + // This is a test for b/170668492, where prefetching for consecutive + // conditionals can cause the prefetch to start in the conditional's + // computation. + absl::string_view hlo_string = R"( + HloModule bug, is_scheduled=true + + true_computation.0 { + p0 = (f32[3]{0}) parameter(0) // 5 + gte = f32[3]{0} get-tuple-element(p0), index=0 // 6 + ROOT neg1 = f32[3]{0} negate(gte) // 7 + } + + false_computation.0 { + p0 = (f32[3]{0}) parameter(0) // 8 + gte = f32[3]{0} get-tuple-element(p0), index=0 // 9 + ROOT neg2 = f32[3]{0} negate(gte) // 10 + } + + true_computation.1 { + p0 = (f32[3]{0}) parameter(0) // 12 + gte = f32[3]{0} get-tuple-element(p0), index=0 // 13 + ROOT neg1 = f32[3]{0} negate(gte) // 14 + } + + false_computation.1 { + p0 = (f32[3]{0}) parameter(0) // 15 + gte = f32[3]{0} get-tuple-element(p0), index=0 // 16 + ROOT neg2 = f32[3]{0} negate(gte) // 17 + } + + ENTRY entry { + p0 = f32[3]{0} parameter(0) // 0 + p1 = f32[3]{0} parameter(1) // 1 + p2 = pred[] parameter(2) // 2 + tuple0 = (f32[3]{0}) tuple(p0) // 3 + tuple1 = (f32[3]{0}) tuple(p1) // 4 + conditional0 = f32[3]{0} conditional(p2, tuple0, tuple0), true_computation=true_computation.0, false_computation=false_computation.0 // 11 + conditional1 = f32[3]{0} conditional(p2, tuple1, tuple1), true_computation=true_computation.1, false_computation=false_computation.1 // 18 + ROOT tuple2 = (f32[3]{0}, f32[3]{0}) tuple(conditional0, conditional1) // 19 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + HloCostAnalysis hlo_cost_analysis(ShapeSize); + CostAnalysisOptions options; + TF_ASSERT_OK_AND_ASSIGN( + auto cost_analysis, + FakeCostAnalysis::Create(hlo_cost_analysis, *module, options)); + CostAnalysisPrefetchIntervalPicker interval_picker( + *cost_analysis, + /*min_overlap_to_async_copy_ratio=*/1.0, + /*preferred_overlap_to_async_copy_ratio=*/2.0, + /*max_overlap_to_mem_size_async_copy_ratio=*/12.0, + /*mem_size_bytes=*/32); + + LOG(INFO) << module->ToString(); + + HloInstruction* conditional1 = + module->entry_computation()->GetInstructionWithName("conditional1"); + const HloUse use{conditional1, /*operand_number=*/1, /*operand_index=*/{0}}; + const Shape& shape = + module->entry_computation()->parameter_instruction(0)->shape(); + + // Expect that the prefetch to start before conditional0's called + // computations. + EXPECT_LT(interval_picker.LatestPrefetchStartTime(shape, /*start_time=*/0, + /*end_time=*/11, &use), + 5); +} + +TEST_F(CostAnalysisPrefetchIntervalPickerTest, EarliestLatestWindowTooSmall) { + // This tests the scenario where there is an op that takes a long time (tanh + // in this example) and as a result the earliest and latest times both fall + // inside this long-running op. In this case, we should still return a valid + // prefetch interval just before the long-running op. + absl::string_view hlo_string = R"( + HloModule bug, is_scheduled=true + + ENTRY Entry { + param0 = f32[2,4] parameter(0) + negate = f32[2,4] negate(param0) + tanh = f32[2,4] tanh(param0) + ROOT add = f32[2,4] add(tanh, negate) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + HloCostAnalysis hlo_cost_analysis(ShapeSize); + CostAnalysisOptions options; + TF_ASSERT_OK_AND_ASSIGN( + auto cost_analysis, + FakeCostAnalysis::Create(hlo_cost_analysis, *module, options)); + cost_analysis->SetOverrideForGetInstructionElapsed( + [](const HloInstruction& hlo) { + if (hlo.opcode() == HloOpcode::kTanh) { + return 20.0; + } + return 1.0; + }); + CostAnalysisPrefetchIntervalPicker interval_picker( + *cost_analysis, + /*min_overlap_to_async_copy_ratio=*/1.0, + /*preferred_overlap_to_async_copy_ratio=*/2.0, + /*max_overlap_to_mem_size_async_copy_ratio=*/12.0, + /*mem_size_bytes=*/32); + + HloInstruction* root = module->entry_computation()->root_instruction(); + const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}}; + interval_picker.Begin(use, /*start_time=*/1, /*end_time=*/3, std::nullopt); + + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_FALSE(interval_picker.Done()); + EXPECT_EQ(interval_picker.Next(), 1); + EXPECT_TRUE(interval_picker.Done()); +} + +} // namespace +} // namespace memory_space_assignment +} // namespace xla diff --git a/third_party/xla/xla/service/memory_space_assignment/slice.cc b/third_party/xla/xla/service/memory_space_assignment/slice.cc new file mode 100644 index 00000000000000..e550e965f804fb --- /dev/null +++ b/third_party/xla/xla/service/memory_space_assignment/slice.cc @@ -0,0 +1,89 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/memory_space_assignment/slice.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "xla/service/heap_simulator/heap_simulator.h" +#include "xla/shape.h" + +namespace xla::memory_space_assignment { + +std::tuple +SliceDecisionToTuple(const SliceDecision& decision) { + return std::make_tuple( + std::ref(decision.chunk), decision.exclusive_start_time, + std::ref(decision.sizing), decision.copy_resource_consumed); +} + +std::string SliceDecision::ToString() const { + return absl::StrCat("{ chunk: ", chunk.ToString(), + ", (exclusive) start_time: ", exclusive_start_time, + ", sizing: ", sizing.ToString(), + ", copy_resource_consumed: ", copy_resource_consumed, + " }"); +} + +bool SliceDecision::operator==(const SliceDecision& other) const { + return SliceDecisionToTuple(*this) == SliceDecisionToTuple(other); +} + +std::string SliceProposal::ToString() const { + return absl::StrCat( + "{ slice_shape: ", slice_shape.ToString(true), ", slice_params: { ", + absl::StrJoin(slice_params, ", ", + [](std::string* out, const SliceParam& param) { + absl::StrAppend(out, param.ToString()); + }), + " }, slice_size: ", slice_size, " }"); +} + +std::ostream& operator<<(std::ostream& os, const SliceProposal& proposal) { + os << proposal.ToString(); + return os; +} + +std::tuple&, int64_t> +SliceProposal::ToTuple() const { + return std::make_tuple(std::ref(slice_shape), std::ref(slice_params), + slice_size); +} + +bool SliceProposal::operator==(const SliceProposal& other) const { + return ToTuple() == other.ToTuple(); +} + +std::string SliceParam::ToString() const { + return absl::StrCat("[", start_inclusive, ",", end_exclusive, ")"); +} + +bool SliceParam::operator==(const SliceParam& other) const { + return start_inclusive == other.start_inclusive && + end_exclusive == other.end_exclusive; +} + +bool IsUniformSliceSizingEnabled(const SlicedPrefetchOptions& options) { + return options.max_slices() > 0 && options.preferred_slice_size() > 0; +} + +} // namespace xla::memory_space_assignment diff --git a/third_party/xla/xla/service/memory_space_assignment/slice.h b/third_party/xla/xla/service/memory_space_assignment/slice.h new file mode 100644 index 00000000000000..ca67dd29faf7b1 --- /dev/null +++ b/third_party/xla/xla/service/memory_space_assignment/slice.h @@ -0,0 +1,119 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file contains definitions for MSA slicing. Slicing is an allocation +// technique in which we allocate a buffer in slices that can start at different +// times, but once allocated, form a contiguous buffer. When copying buffers, we +// may want to allocate a buffer in slices, so that we delay allocating memory +// that would otherwise not be in use, due to copy bandwidth constraints. +// +// The following illustrates a buffer that is fully allocated at time t2, via +// slices. +// +// space +// ^ +// p3 | +-----------+ +// | | s2 | +// p2 | +---+-----------+ +// | | s1 | +// p1 | +-------+-------+ +// | | s0 | +// p0 | +-------+ +// +---|---|---|---|---|----> time +// t0 t1 t2 t3 t4 + +#ifndef XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_SLICE_H_ +#define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_SLICE_H_ + +#include +#include +#include +#include +#include + +#include "xla/service/heap_simulator/heap_simulator.h" +#include "xla/service/memory_space_assignment/memory_space_assignment.pb.h" +#include "xla/shape.h" +#include "xla/shape_util.h" + +namespace xla::memory_space_assignment { + +// The target of a custom call that slicing uses to concatenate slices +// that are already contiguous in memory, into a larger buffer. +inline constexpr char kConcatBitcastCustomCall[] = "ConcatBitcast"; + +// The parameters for slicing a single dimension of a tensor. +struct SliceParam { + std::string ToString() const; + bool operator==(const SliceParam& other) const; + + int64_t start_inclusive; + int64_t end_exclusive; +}; + +// A proposed way to slice a buffer. +struct SliceProposal { + std::string ToString() const; + friend std::ostream& operator<<(std::ostream& os, + const SliceProposal& proposal); + std::tuple&, int64_t> ToTuple() + const; + bool operator==(const SliceProposal& other) const; + + // Shape resulting from the slice. + Shape slice_shape; + + // slice_params map to the parameters that would be passed to a slice + // instruction. Thus: + // * There should be a slice parameter for every dimension in the shape of + // the tensor being sliced. + // * The ith slice_param applies to the ith logical dimension in the shape + // being sliced. + // * If a dimension is not being sliced, it should have a SliceParam of + // {0, dim size}. + std::vector slice_params; + + // The size to be allocated for the slice. Note, this may be > the size of + // the slice shape, due to additional padding that may occur when the slices + // are concatenated back together. + int64_t slice_size; +}; + +// A SliceProposalCollection proposes a way to to slice an AllocationRequest. +// A SliceProposalCollection is generated from a SliceProposalFunction and is +// used when we want to slice a prefetch. +using SliceProposalCollection = std::vector; +using SliceProposalFunction = std::function( + const Shape& shape, const SlicedPrefetchOptions& options)>; + +// A SliceDecision is a SliceProposal that we've determined where and when to +// allocate. +struct SliceDecision { + std::string ToString() const; + bool operator==(const SliceDecision& other) const; + + HeapSimulator::Chunk chunk; + int64_t exclusive_start_time; + SliceProposal sizing; + float copy_resource_consumed; +}; + +// Returns true if the options indicates that there is a preferred slice +// size. +bool IsUniformSliceSizingEnabled(const SlicedPrefetchOptions& options); + +} // namespace xla::memory_space_assignment + +#endif // XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_SLICE_H_ diff --git a/third_party/xla/xla/service/memory_space_assignment/testing_utils.h b/third_party/xla/xla/service/memory_space_assignment/testing_utils.h new file mode 100644 index 00000000000000..b1d3c94cc0f421 --- /dev/null +++ b/third_party/xla/xla/service/memory_space_assignment/testing_utils.h @@ -0,0 +1,128 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_TESTING_UTILS_H_ +#define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_TESTING_UTILS_H_ + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/utils/hlo_live_range.h" +#include "xla/service/call_graph.h" +#include "xla/service/hlo_alias_analysis.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/service/memory_space_assignment/cost_analysis.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/statusor.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace memory_space_assignment { + +// For testing purposes, we define a cost analysis where we can control the +// elapsed times of each HLO and asynchronous copy. +class FakeCostAnalysis : public CostAnalysis { + public: + static StatusOr> Create( + const HloCostAnalysis& cost_analysis, const HloModule& module, + const CostAnalysisOptions& options) { + TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module)); + TF_ASSIGN_OR_RETURN(auto hlo_live_range, + HloLiveRange::Run(module.schedule(), *alias_analysis, + module.entry_computation())); + auto call_graph = CallGraph::Build(&module); + return absl::WrapUnique( + new FakeCostAnalysis(cost_analysis, options, std::move(alias_analysis), + std::move(hlo_live_range), std::move(call_graph))); + } + + float GetInstructionElapsed( + const HloInstruction& instruction) const override { + if (get_instruction_elapsed_override_) { + return get_instruction_elapsed_override_(instruction); + } + return 1.0; + } + + float GetInstructionElapsedInAlternateMemory( + const HloInstruction& instruction, + absl::Span> + operands_in_alternate_mem, + absl::Span outputs_in_alternate_mem) const override { + if (get_instruction_elapsed_in_alternate_memory_override_) { + return get_instruction_elapsed_in_alternate_memory_override_( + instruction, operands_in_alternate_mem, outputs_in_alternate_mem); + } + if (!operands_in_alternate_mem.empty()) { + return 0.5; + } else { + return 1.0; + } + } + + float GetAsyncCopyElapsed(const Shape& shape) const override { + if (get_async_copy_elapsed_override_) { + return get_async_copy_elapsed_override_(shape); + } + return 3.0; + } + + // The following methods can be used to override what the above API calls + // return. + void SetOverrideForGetInstructionElapsed( + std::function function) { + get_instruction_elapsed_override_ = function; + } + void SetOverrideForGetInstructionElapsedInAlternateMemory( + std::function>, + absl::Span)> + function) { + get_instruction_elapsed_in_alternate_memory_override_ = function; + } + void SetOverrideForGetAsyncCopyElapsed( + std::function function) { + get_async_copy_elapsed_override_ = function; + } + + protected: + FakeCostAnalysis(const HloCostAnalysis& cost_analysis, + const CostAnalysisOptions& options, + std::unique_ptr alias_analysis, + std::unique_ptr hlo_live_range, + std::unique_ptr call_graph) + : CostAnalysis(cost_analysis, options, std::move(alias_analysis), + std::move(hlo_live_range), std::move(call_graph)) {} + + private: + std::function + get_instruction_elapsed_override_ = nullptr; + std::function>, + absl::Span)> + get_instruction_elapsed_in_alternate_memory_override_ = nullptr; + std::function get_async_copy_elapsed_override_ = nullptr; +}; + +} // namespace memory_space_assignment +} // namespace xla + +#endif // XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_TESTING_UTILS_H_ diff --git a/third_party/xla/xla/service/metrics.proto b/third_party/xla/xla/service/metrics.proto index 910616b35dcffc..90325b70fcc6fc 100644 --- a/third_party/xla/xla/service/metrics.proto +++ b/third_party/xla/xla/service/metrics.proto @@ -2,6 +2,7 @@ syntax = "proto3"; package xla; +import "google/protobuf/any.proto"; import "google/protobuf/duration.proto"; import "google/protobuf/timestamp.proto"; @@ -13,6 +14,10 @@ message PassMetrics { string pass_name = 2; // Duration of the pass. google.protobuf.Duration pass_duration = 3; + // Custom pass metrics. This is kept opaque, via `google.protobuf.Any`, in + // order to decouple pass agnostic compilation logs from possibly proprietary + // compiler passes. + google.protobuf.Any custom_metrics = 4; } // Defines XLA compilation metrics. diff --git a/third_party/xla/xla/service/optimization_barrier_expander.h b/third_party/xla/xla/service/optimization_barrier_expander.h index 540a974c125ed8..562d29fa8e93a4 100644 --- a/third_party/xla/xla/service/optimization_barrier_expander.h +++ b/third_party/xla/xla/service/optimization_barrier_expander.h @@ -27,7 +27,6 @@ class OptimizationBarrierExpander : public HloModulePass { absl::string_view name() const override { return "cse_barrier_expander"; } - protected: using HloPassInterface::Run; StatusOr Run( HloModule* module, diff --git a/third_party/xla/xla/service/pattern_matcher.h b/third_party/xla/xla/service/pattern_matcher.h index 9796671527baa1..e460b7bbbb5a83 100644 --- a/third_party/xla/xla/service/pattern_matcher.h +++ b/third_party/xla/xla/service/pattern_matcher.h @@ -2694,6 +2694,7 @@ XLA_UNOP_PATTERN(IsFinite) XLA_UNOP_PATTERN(Log) XLA_UNOP_PATTERN(Not) XLA_UNOP_PATTERN(Negate) +XLA_UNOP_PATTERN(OptimizationBarrier) XLA_UNOP_PATTERN(Real) XLA_UNOP_PATTERN(Recv) XLA_UNOP_PATTERN(RecvDone) diff --git a/third_party/xla/xla/service/platform_util.cc b/third_party/xla/xla/service/platform_util.cc index bce5419de36261..83c54d2e9b1a65 100644 --- a/third_party/xla/xla/service/platform_util.cc +++ b/third_party/xla/xla/service/platform_util.cc @@ -27,6 +27,7 @@ limitations under the License. #include "xla/statusor.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" #include "xla/stream_executor/host/host_platform_id.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/stream_executor.h" #include "xla/types.h" @@ -64,7 +65,7 @@ std::string CanonicalPlatformName(const std::string& platform_name) { } StatusOr> GetSupportedPlatforms() { - return se::MultiPlatformManager::PlatformsWithFilter( + return se::PlatformManager::PlatformsWithFilter( [](const se::Platform* platform) { auto compiler_status = Compiler::GetForPlatform(platform); bool supported = compiler_status.ok(); @@ -124,7 +125,7 @@ PlatformUtil::GetSupportedPlatforms() { /*static*/ StatusOr PlatformUtil::GetPlatform( const std::string& platform_name) { TF_ASSIGN_OR_RETURN(se::Platform * platform, - se::MultiPlatformManager::PlatformWithName( + se::PlatformManager::PlatformWithName( xla::CanonicalPlatformName(platform_name))); TF_RETURN_IF_ERROR(Compiler::GetForPlatform(platform).status()); return platform; diff --git a/third_party/xla/xla/service/profile_guided_latency_estimator.cc b/third_party/xla/xla/service/profile_guided_latency_estimator.cc index 288e6775db59b0..50d250a102a205 100644 --- a/third_party/xla/xla/service/profile_guided_latency_estimator.cc +++ b/third_party/xla/xla/service/profile_guided_latency_estimator.cc @@ -37,10 +37,27 @@ LatencyEstimator::TimeCost ProfileGuidedLatencyEstimator::GetLatencyBetween( } auto it = instr_map_.find(from.GetInstr().name()); + if (it == instr_map_.end() && + (from.GetInstr().opcode() == HloOpcode::kAsyncStart || + from.GetInstr().opcode() == HloOpcode::kAsyncDone)) { + absl::string_view wrapped_inst_name = + from.GetInstr().async_wrapped_instruction()->name(); + VLOG(10) << "PGLE found async wrapped instruction: " << wrapped_inst_name + << " in " << from.GetInstr().name(); + it = instr_map_.find(wrapped_inst_name); + } + if (it == instr_map_.end()) { return latency_estimator_->GetLatencyBetween(from, target); } + auto it2 = it->second.latencies.find(target.GetInstr().name()); + if (it2 == it->second.latencies.end() && + (target.GetInstr().opcode() == HloOpcode::kAsyncStart || + target.GetInstr().opcode() == HloOpcode::kAsyncDone)) { + it2 = it->second.latencies.find( + target.GetInstr().async_wrapped_instruction()->name()); + } if (it2 != it->second.latencies.end()) { VLOG(10) << "PGLE found latency between " << from.GetInstr().name() << " and " << target.GetInstr().name() << " in latency info"; diff --git a/third_party/xla/xla/service/profile_guided_latency_estimator_test.cc b/third_party/xla/xla/service/profile_guided_latency_estimator_test.cc index 49aab73c61dcae..1f226f69e6e11c 100644 --- a/third_party/xla/xla/service/profile_guided_latency_estimator_test.cc +++ b/third_party/xla/xla/service/profile_guided_latency_estimator_test.cc @@ -158,4 +158,54 @@ ENTRY entry { INSTANTIATE_TEST_SUITE_P(LatencyHidingSchedulerTest, LatencyHidingSchedulerTest, ::testing::Bool()); +using ProfileGuidedLatencyEstimatorTest = HloTestBase; + +TEST_F(ProfileGuidedLatencyEstimatorTest, + TestProfileGuidedLatencyEstimatorWithAsyncInstruction) { + absl::string_view hlo_string = R"( +HloModule module, is_scheduled=true + +add.1 { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x, y) +} + +ENTRY entry { + p0 = f32[16,64,256]{2,1,0} parameter(0) + p1 = f32[16,64,256]{2,1,0} parameter(1) + reduce-scatter-start = ((f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}), (f32[4,64,256]{2,1,0}, f32[4,64,256]{2,1,0})) reduce-scatter-start(p0, p1), channel_id=1, replica_groups={}, dimensions={0}, to_apply=add.1 + reduce-scatter-done = (f32[4,64,256]{2,1,0}, f32[4,64,256]{2,1,0}) reduce-scatter-done(reduce-scatter-start) + ROOT gte = f32[4,64,256]{2,1,0} get-tuple-element(reduce-scatter-done), index=0 +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, + ParseAndReturnVerifiedModule(hlo_string)); + EXPECT_TRUE(hlo_module->has_entry_computation()); + + std::string profiled_instructions_text_proto = R"pb( + costs { name: "reduce-scatter" cost_us: 120.0 } + )pb"; + ; + tensorflow::profiler::ProfiledInstructionsProto profiled_instructions_proto; + ASSERT_TRUE(tsl::protobuf::TextFormat::ParseFromString( + profiled_instructions_text_proto, &profiled_instructions_proto)); + + auto sched_config = GetDefaultSchedConfig(); + auto latency_estimator = std::make_unique( + sched_config, std::make_unique(), + profiled_instructions_proto); + HloInstruction* rs_start = + FindInstruction(hlo_module.get(), "reduce-scatter-start"); + HloInstruction* rs_done = + FindInstruction(hlo_module.get(), "reduce-scatter-done"); + HloGraphNode rs_start_node = HloGraphNode(rs_start, 0); + HloGraphNode rs_done_node = HloGraphNode(rs_done, 1); + + double latency = + latency_estimator->GetLatencyBetween(rs_start_node, rs_done_node); + EXPECT_EQ(latency, 120.0); +} + } // namespace xla diff --git a/third_party/xla/xla/service/rendezvous.cc b/third_party/xla/xla/service/rendezvous.cc index 5fea8029195d39..ff37bebca30acc 100644 --- a/third_party/xla/xla/service/rendezvous.cc +++ b/third_party/xla/xla/service/rendezvous.cc @@ -15,14 +15,18 @@ limitations under the License. #include "xla/service/rendezvous.h" +#include +#include #include +#include #include #include "absl/synchronization/notification.h" #include "absl/time/time.h" #include "tsl/platform/logging.h" -namespace xla::internal { +namespace xla { +namespace internal { void AwaitAndLogIfStuck(absl::Notification& ready, std::string_view name, size_t num_threads, absl::Duration warn_stuck_timeout, @@ -52,4 +56,62 @@ void AwaitAndLogIfStuck(absl::Notification& ready, std::string_view name, std::exit(42); } -} // namespace xla::internal +} // namespace internal + +namespace { +inline constexpr int32_t kPending = 0; +inline constexpr int32_t kCompleted = std::numeric_limits::max(); +} // namespace + +RendezvousSingleFlag::RendezvousSingleFlag() : state_(kPending) {} + +RendezvousSingleFlag::InFlightRendezvous::InFlightRendezvous( + RendezvousSingleFlag* flag) + : flag_(flag) {} + +RendezvousSingleFlag::InFlightRendezvous::~InFlightRendezvous() { + if (flag_ == nullptr) return; + + // Reload state and use CAS to decide if we are the one who + // should mark rendezvous flag completed. + int32_t state = flag_->state_.load(); + + CHECK(state != kPending && state != kCompleted) // NOLINT + << "rendezvous can't be in pending or completed state"; + + // Exit the critical section and maybe mark rendezvous as completed. + while (!flag_->state_.compare_exchange_weak( + state, state == 1 ? kCompleted : state - 1)) { + // Check state after CAS failure: while we are in this function no one + // should complete rendezvous without us or switch it back to pending. + CHECK(state != kPending && state != kCompleted); // NOLINT + } +} + +RendezvousSingleFlag::InFlightRendezvous::operator bool() const { + return flag_ != nullptr; +} + +RendezvousSingleFlag::InFlightRendezvous RendezvousSingleFlag::TryJoin() { + // If `state_` is `kCompleted` it means that we have at least one completed + // rendezvous for this flag and can skip it. + if (state_.load() == kCompleted) return InFlightRendezvous(nullptr); + + // Try to increment a state in a CAS loop to signal all other participants + // that we joined an in-flight rendezvous. + int32_t state = state_.load(); + while (state != kCompleted && + !state_.compare_exchange_weak(state, state + 1)) { + } + + // Someone else completed the rendezvous and we don't need to join. + if (state == kCompleted) return InFlightRendezvous(nullptr); + + return InFlightRendezvous(this); +} + +bool RendezvousSingleFlag::IsCompleted() const { + return state_.load() == kCompleted; +} + +} // namespace xla diff --git a/third_party/xla/xla/service/rendezvous.h b/third_party/xla/xla/service/rendezvous.h index 4bda2c33b18af9..bee54f44fc9bd4 100644 --- a/third_party/xla/xla/service/rendezvous.h +++ b/third_party/xla/xla/service/rendezvous.h @@ -23,6 +23,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/base/thread_annotations.h" @@ -50,6 +51,8 @@ struct RendezvousResult { using Type = std::shared_ptr; static Type Wrap(R result) { return std::make_shared(std::move(result)); } + + static Type Empty() { return std::shared_ptr(); } }; template @@ -60,6 +63,8 @@ struct RendezvousResult> { if (!result.ok()) return result.status(); return std::make_shared(std::move(*result)); } + + static Type Empty() { return {std::shared_ptr()}; } }; template @@ -91,6 +96,76 @@ void RendezvousSingle( absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), absl::Duration terminate_timeout = absl::InfiniteDuration()); +// An `std::once_flag`-like primitive for executing RendezvousSingle operations. +// +// RendezvousSingleFlag guarantees that all or none participants in a rendezvous +// join the rendezvous process and once rendezvous is completed flag marked as +// `completed` and all further rendezvous using this flag will be skipped. It +// has a weaker than exactly-once guarantee and multiple racing rendezvous can +// execute in parallel, and the last completed rendezvous will switch flag to +// `completed` state. +// +// In XLA rendezvous are rare and used to guard costly shared state +// initialization, so in practice we do not expect to see many racing rendezvous +// and prefer simpler implementation with weaker guarantees. +// +// See: https://en.cppreference.com/w/cpp/thread/once_flag +class RendezvousSingleFlag { + public: + RendezvousSingleFlag(); + + RendezvousSingleFlag(const RendezvousSingleFlag&) = delete; + RendezvousSingleFlag& operator=(const RendezvousSingleFlag&) = delete; + + // RAII wrapper to exit from in-flight rendezvous when destructed. + class InFlightRendezvous { + public: + explicit InFlightRendezvous(RendezvousSingleFlag* flag); + ~InFlightRendezvous(); + + InFlightRendezvous(const InFlightRendezvous&) = delete; + InFlightRendezvous& operator=(const InFlightRendezvous&) = delete; + + operator bool() const; // NOLINT + + private: + RendezvousSingleFlag* flag_; + }; + + // Returns InFlightRendezvous convertible to `true` if the caller should join + // the rendezvous process. If result conversion to bool is `false` it means + // that the rendezvous is already completed. + InFlightRendezvous TryJoin(); + + bool IsCompleted() const; + + private: + friend class InFlightRendezvous; + + std::atomic state_; +}; + +// A rendezvous for a group of threads that will be executed only if the flag is +// not in `completed` state and will switch it to `completed` after finishing a +// rendezvous. If rendezvous will not be executed it will return empty shared +// pointer result. +template +RendezvousResultType RendezvousSingle( + RendezvousSingleFlag& flag, std::string_view name, const K& key, + size_t num_threads, Fn fn, + absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), + absl::Duration terminate_timeout = absl::InfiniteDuration()); + +// A rendezvous for a group of threads that will be executed only if the flag is +// not in `completed` state and will switch it to `completed` after finishing a +// rendezvous. +template +void RendezvousSingle( + RendezvousSingleFlag& flag, std::string_view name, const K& key, + size_t num_threads, + absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), + absl::Duration terminate_timeout = absl::InfiniteDuration()); + //===----------------------------------------------------------------------===// // Internal implementation details. //===----------------------------------------------------------------------===// @@ -209,7 +284,9 @@ RendezvousResultType RendezvousSingle(std::string_view name, const K& key, << "Id can't be larger than the number of participating threads" << "; id=" << id << "; num_threads=" << num_threads; - state->values[id] = &value; + // std::vector::operator[] creates data races, so we rely on data pointer + // here and when we create an absl::Span below. + *(state->values.data() + id) = &value; // Use a second atomic to safely publish values without data races. if constexpr (!std::is_same_v) { @@ -227,7 +304,8 @@ RendezvousResultType RendezvousSingle(std::string_view name, const K& key, // be notified via `state->ready` notification when result is ready, and we // rely on the notification to create a memory barrier that makes access to // `state->result` safe without any extra synchronization. - rendezvous.Complete(key, RendezvousResult::Wrap(fn(state->values))); + absl::Span values(state->values.data(), num_threads); + rendezvous.Complete(key, RendezvousResult::Wrap(fn(values))); } return state->result; @@ -252,6 +330,31 @@ void RendezvousSingle(std::string_view name, const K& key, size_t num_threads, warn_stuck_timeout, terminate_timeout); } +template +RendezvousResultType RendezvousSingle(RendezvousSingleFlag& flag, + std::string_view name, const K& key, + size_t num_threads, Fn fn, + absl::Duration warn_stuck_timeout, + absl::Duration terminate_timeout) { + if (auto in_flight_rendezvous = flag.TryJoin()) { + return RendezvousSingle(name, key, num_threads, std::move(fn), + warn_stuck_timeout, terminate_timeout); + } else { + return RendezvousResult::Empty(); + } +} + +template +void RendezvousSingle(RendezvousSingleFlag& flag, std::string_view name, + const K& key, size_t num_threads, + absl::Duration warn_stuck_timeout, + absl::Duration terminate_timeout) { + if (auto in_flight_rendezvous = flag.TryJoin()) { + RendezvousSingle(name, key, num_threads, warn_stuck_timeout, + terminate_timeout); + } +} + } // namespace xla #endif // XLA_SERVICE_RENDEZVOUS_H_ diff --git a/third_party/xla/xla/service/rendezvous_test.cc b/third_party/xla/xla/service/rendezvous_test.cc index d88bf312af097b..f7fda89bdccdd0 100644 --- a/third_party/xla/xla/service/rendezvous_test.cc +++ b/third_party/xla/xla/service/rendezvous_test.cc @@ -16,10 +16,13 @@ limitations under the License. #include #include +#include #include #include "absl/status/statusor.h" #include "absl/synchronization/blocking_counter.h" +#include "absl/synchronization/notification.h" +#include "absl/time/time.h" #include "absl/types/span.h" #include "tsl/platform/env.h" #include "tsl/platform/test.h" @@ -29,6 +32,9 @@ limitations under the License. namespace xla { namespace { +absl::Duration Timeout() { return absl::Seconds(10); } +absl::Duration Terminate() { return absl::Seconds(10); } + tsl::thread::ThreadPool CreateThreadPool(int32_t size) { return tsl::thread::ThreadPool(tsl::Env::Default(), "rendezvous_test", size); } @@ -128,6 +134,97 @@ TEST(RendezvousTest, ReturningStatusOr) { ASSERT_EQ(**results[1], 42); } +TEST(RendezvousTest, RendezvousSingleFlag) { + RendezvousSingleFlag flag; + + auto thread_pool = CreateThreadPool(2); + int32_t num_executed = 0; + + absl::BlockingCounter round_0(2); + absl::BlockingCounter round_1(2); + + auto task = [&](absl::BlockingCounter& counter) { + return [&] { + RendezvousSingle( + flag, "rendezvous_test", 0, 2, [&] { return ++num_executed; }, + Timeout(), Terminate()); + counter.DecrementCount(); + }; + }; + + // Execute rendezvous a first time. + thread_pool.Schedule(task(round_0)); + thread_pool.Schedule(task(round_0)); + round_0.Wait(); + + ASSERT_EQ(num_executed, 1); + + // Execute rendezvous a second time. + thread_pool.Schedule(task(round_1)); + thread_pool.Schedule(task(round_1)); + round_1.Wait(); + + // Check that we did not execute it second time. + ASSERT_EQ(num_executed, 1); +} + +TEST(RendezvousTest, RendezvousSingleFlagRace) { + RendezvousSingleFlag flag; + + static constexpr int32_t kNumRendezvous = 16; + static constexpr int32_t kNumThreads = 8; + + auto thread_pool = CreateThreadPool(kNumRendezvous * kNumThreads); + + auto task = [&](int32_t key) { + return [&, key] { + RendezvousSingle(flag, "key: " + std::to_string(key), key, kNumThreads, + Timeout(), Terminate()); + }; + }; + + for (int32_t key = 0; key < kNumRendezvous; ++key) { + for (int32_t thread = 0; thread < kNumThreads; ++thread) { + thread_pool.Schedule(task(key)); + } + } +} + +TEST(RendezvousTest, RendezvousSingleFlagRaceWithBarriers) { + RendezvousSingleFlag flag; + + static constexpr int32_t kNumRendezvous = 16; + static constexpr int32_t kNumThreads = 8; + + auto thread_pool = CreateThreadPool(kNumRendezvous * kNumThreads); + + // We use barriers and notifications to make sure all 128 threads start + // rendezvous at the same time to detect potential deadlocks and data races. + absl::BlockingCounter participants_ready(kNumRendezvous * kNumThreads); + absl::Notification participants_notification; + absl::BlockingCounter participants_done(kNumRendezvous * kNumThreads); + + auto task = [&](int32_t key) { + return [&, key] { + participants_ready.DecrementCount(); + participants_notification.WaitForNotification(); + RendezvousSingle(flag, "key: " + std::to_string(key), key, kNumThreads, + Timeout(), Terminate()); + participants_done.DecrementCount(); + }; + }; + + for (int32_t key = 0; key < kNumRendezvous; ++key) { + for (int32_t thread = 0; thread < kNumThreads; ++thread) { + thread_pool.Schedule(task(key)); + } + } + + participants_notification.Notify(); + participants_ready.Wait(); + participants_done.Wait(); +} + //===----------------------------------------------------------------------===// // Performance benchmarks below //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/service/service.h b/third_party/xla/xla/service/service.h index f7b9325efbb5ad..f606698b1a8582 100644 --- a/third_party/xla/xla/service/service.h +++ b/third_party/xla/xla/service/service.h @@ -171,7 +171,7 @@ class Service : public ServiceInterface { // Create a Hlo module config for the given program shape and arguments. // aot_options is optional; if not given a default is used. - StatusOr> CreateModuleConfig( + absl::StatusOr> CreateModuleConfig( const ProgramShape& program_shape, absl::Span argument_shapes, const ExecutionOptions* execution_options, @@ -186,19 +186,19 @@ class Service : public ServiceInterface { private: // A private overload for Service itself, used by other methods within this // class. - StatusOr> CreateModuleConfig( + absl::StatusOr> CreateModuleConfig( const ProgramShape& program_shape, absl::Span arguments, const ExecutionOptions& execution_options, const AotCompilationOptions* aot_options = nullptr); // Prepare the executors for executing parallel. - StatusOr> GetExecutors( + absl::StatusOr> GetExecutors( const ExecutionOptions& execution_options, int64_t requests_size, int64_t request_index) const; // Prepare the arguments for executing parallel. - StatusOr>> GetArguments( + absl::StatusOr>> GetArguments( const ExecutionOptions& execution_options, absl::Span arguments) const; @@ -214,7 +214,7 @@ class Service : public ServiceInterface { // the corresponding allocations for every replica. The function also verifies // that each allocation matches the execution platform and device ordinal of // the corresponding replica. - StatusOr>> + absl::StatusOr>> ResolveAndValidateArguments( absl::Span arguments, absl::Span stream_executors) const; @@ -225,7 +225,7 @@ class Service : public ServiceInterface { // If device_allocator is not null, the compiler may use it to allocate temp // buffers, which the compiler is responsible for freeing. The allocator // given here need not match the allocator used when running the executable. - StatusOr> BuildExecutable( + absl::StatusOr> BuildExecutable( const HloModuleProto& module_proto, std::unique_ptr module_config, Backend* backend, se::StreamExecutor* executor, const Compiler::CompileOptions& options, @@ -233,7 +233,7 @@ class Service : public ServiceInterface { // Same as BuildExecutable() above, but builds a list of Executables for the // given computations that may interact with each other. - StatusOr>> BuildExecutables( + absl::StatusOr>> BuildExecutables( const std::vector& module_protos, std::vector> module_configs, Backend* backend, std::vector> executors, @@ -243,17 +243,19 @@ class Service : public ServiceInterface { // Same as BuildExecutable() above, but builds a list of // AotCompilationResult(s), which can be persisted to later load Executable // objects. - StatusOr>> BuildAotResults( - const std::vector& module_protos, - std::vector> module_configs, - Backend* backend, std::vector> executors, - const Compiler::CompileOptions& options, bool run_backend_only = false); + absl::StatusOr>> + BuildAotResults(const std::vector& module_protos, + std::vector> module_configs, + Backend* backend, + std::vector> executors, + const Compiler::CompileOptions& options, + bool run_backend_only = false); // Runs the given executable with the given arguments and register the result // in the allocation tracker. The handle of the result from the tracker is // returned. If the parameter "profile" is not null, it points to an // ExecutionProfile object which will be filled in with profile data. - StatusOr ExecuteAndRegisterResult( + absl::StatusOr ExecuteAndRegisterResult( Executable* executable, absl::Span> arguments, Backend* backend, const DeviceHandle& device_handle, @@ -262,7 +264,8 @@ class Service : public ServiceInterface { // Runs the given executables with the given arguments and register the result // from each executable in the allocation tracker. The handles of the result // from the tracker are returned. - StatusOr> ExecuteParallelAndRegisterResult( + absl::StatusOr> + ExecuteParallelAndRegisterResult( absl::Span executables, absl::Span>> arguments, Backend* backend, absl::Span device_handles, @@ -271,7 +274,7 @@ class Service : public ServiceInterface { // Returns the stream executors assigned to the replicas represented by the // given device handle. Each device_handle is a virtual replicated device that // represents a set of physical devices for the replicas. - StatusOr> Replicas( + absl::StatusOr> Replicas( const Backend& backend, const DeviceHandle& device_handle) const; // Returns the device handle that represents the replicated device for a diff --git a/third_party/xla/xla/service/service_executable_run_options.h b/third_party/xla/xla/service/service_executable_run_options.h index 1d47628e69b520..6013ca758d42ba 100644 --- a/third_party/xla/xla/service/service_executable_run_options.h +++ b/third_party/xla/xla/service/service_executable_run_options.h @@ -35,8 +35,9 @@ class ServiceExecutableRunOptions { // with the first argument being the device ordinal, the second // argument being the number of streams to borrow, and the third // argument being the priority of the streams. - using StreamBorrower = std::function>( - int, int, se::StreamPriority)>; + using StreamBorrower = + std::function>( + int, int, se::StreamPriority)>; ServiceExecutableRunOptions() : ServiceExecutableRunOptions(ExecutableRunOptions()) {} @@ -59,7 +60,7 @@ class ServiceExecutableRunOptions { // Borrows a stream and returns a smart pointer which returns the stream on // destruction. - StatusOr BorrowStream( + absl::StatusOr BorrowStream( int device_ordinal, se::StreamPriority priority = se::StreamPriority::Default) const { if (!stream_borrower_) { @@ -73,7 +74,7 @@ class ServiceExecutableRunOptions { return stream; } - StatusOr> BorrowStreams( + absl::StatusOr> BorrowStreams( int device_ordinal, int num_streams, se::StreamPriority priority = se::StreamPriority::Default) const { return stream_borrower_ diff --git a/third_party/xla/xla/service/shape_inference.cc b/third_party/xla/xla/service/shape_inference.cc index 01aa72a4252dab..c2c447521e4566 100644 --- a/third_party/xla/xla/service/shape_inference.cc +++ b/third_party/xla/xla/service/shape_inference.cc @@ -184,9 +184,9 @@ Status VerifyReducerShape(const ProgramShape& reducer_shape, return OkStatus(); } -StatusOr InferWindowOutputShape(const Shape& base_shape, - const Window& window, - PrimitiveType element_type) { +absl::StatusOr InferWindowOutputShape(const Shape& base_shape, + const Window& window, + PrimitiveType element_type) { if (window.dimensions_size() != base_shape.rank()) { return InvalidArgument( "Window has dimension %d but base shape has dimension %d.", @@ -290,11 +290,11 @@ DimAndBound InferConcatenatedDimAndBound(int64_t left_size, int64_t right_size, // A HLO static dimension size `X` is expressed as size=X, and bound=? // A bounded dynamic dimension size `<=X` is be expressed as size=X, and bound=? // A unbounded dynamic dimension size, `?`, is expressed as size=?, and bound=? -StatusOr InferMostSpecificDimAndBound(int64_t dim, - int64_t left_size, - int64_t right_size, - int64_t left_bound, - int64_t right_bound) { +absl::StatusOr InferMostSpecificDimAndBound(int64_t dim, + int64_t left_size, + int64_t right_size, + int64_t left_bound, + int64_t right_bound) { bool is_left_static_dim = !IsUnboundedDynamicSize(left_size); bool is_right_static_dim = !IsUnboundedDynamicSize(right_size); bool is_left_static_bound = !IsUnboundedDynamicSize(left_bound); @@ -330,12 +330,12 @@ StatusOr InferMostSpecificDimAndBound(int64_t dim, } // namespace -/* static */ StatusOr ShapeInference::InferUnaryOpShape( +/* static */ absl::StatusOr ShapeInference::InferUnaryOpShape( HloOpcode opcode, const HloInstruction* operand) { return InferUnaryOpShape(opcode, operand->shape()); } -/* static */ StatusOr ShapeInference::InferUnaryOpShape( +/* static */ absl::StatusOr ShapeInference::InferUnaryOpShape( HloOpcode opcode, const Shape& shape) { // There is no copy operation at the proto level, so handle copy explicitly. // A domain shape is the same as the input one. @@ -349,6 +349,7 @@ StatusOr InferMostSpecificDimAndBound(int64_t dim, switch (opcode) { case HloOpcode::kFloor: case HloOpcode::kCeil: + case HloOpcode::kErf: case HloOpcode::kRoundNearestAfz: case HloOpcode::kRoundNearestEven: if (!ShapeUtil::ElementIsFloating(shape)) { @@ -464,7 +465,7 @@ StatusOr InferMostSpecificDimAndBound(int64_t dim, } } -/* static */ StatusOr ShapeInference::InferTopKShape( +/* static */ absl::StatusOr ShapeInference::InferTopKShape( const Shape& operand_shape, int64_t k) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of top-k operation")); int64_t last_dim = operand_shape.rank() - 1; @@ -486,7 +487,7 @@ StatusOr InferMostSpecificDimAndBound(int64_t dim, return ShapeUtil::MakeTupleShape({out, idxs_shape}); } -/* static */ StatusOr ShapeInference::InferConcatOpShape( +/* static */ absl::StatusOr ShapeInference::InferConcatOpShape( absl::Span arg_shapes, const int64_t dimension) { if (arg_shapes.empty()) { return InvalidArgument("Concatenate expects at least one argument."); @@ -579,7 +580,7 @@ StatusOr InferMostSpecificDimAndBound(int64_t dim, return result; } -/* static */ StatusOr ShapeInference::InferConvertShape( +/* static */ absl::StatusOr ShapeInference::InferConvertShape( const Shape& operand_shape, PrimitiveType new_element_type) { if (!operand_shape.IsArray() || !primitive_util::IsArrayType(new_element_type)) { @@ -595,7 +596,7 @@ StatusOr InferMostSpecificDimAndBound(int64_t dim, return ShapeUtil::ChangeElementType(operand_shape, new_element_type); } -/* static */ StatusOr ShapeInference::InferBitcastConvertShape( +/* static */ absl::StatusOr ShapeInference::InferBitcastConvertShape( const Shape& operand_shape, PrimitiveType new_element_type) { auto old_element_type = operand_shape.element_type(); if (primitive_util::IsComplexType(old_element_type) != @@ -649,7 +650,7 @@ StatusOr InferMostSpecificDimAndBound(int64_t dim, return new_shape; } -/* static */ StatusOr ShapeInference::InferStochasticConvertShape( +/* static */ absl::StatusOr ShapeInference::InferStochasticConvertShape( const Shape& operand_shape, const Shape& random_shape, PrimitiveType new_element_type) { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape)); @@ -693,7 +694,7 @@ StatusOr InferMostSpecificDimAndBound(int64_t dim, return ShapeUtil::ChangeElementType(operand_shape, new_element_type); } -/* static */ StatusOr ShapeInference::InferReducePrecisionShape( +/* static */ absl::StatusOr ShapeInference::InferReducePrecisionShape( const Shape& operand_shape, const int exponent_bits, const int mantissa_bits) { if (!ShapeUtil::ElementIsFloating(operand_shape)) { @@ -717,7 +718,7 @@ StatusOr InferMostSpecificDimAndBound(int64_t dim, return operand_shape; } -/* static */ StatusOr ShapeInference::InferPadShape( +/* static */ absl::StatusOr ShapeInference::InferPadShape( const Shape& operand_shape, const Shape& padding_value_shape, const PaddingConfig& padding_config) { if (!operand_shape.IsArray()) { @@ -840,7 +841,7 @@ Status ValidateDotDimensionNumbers( } // namespace -/* static */ StatusOr ShapeInference::InferDotOpShape( +/* static */ absl::StatusOr ShapeInference::InferDotOpShape( const Shape& lhs, const Shape& rhs, const DotDimensionNumbers& dimension_numbers, std::optional preferred_element_type) { @@ -939,9 +940,8 @@ Status ValidateDotDimensionNumbers( return result; } -/* static */ StatusOr -ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, - const Shape& lhs, +/* static */ absl::StatusOr +ShapeInference::InferDegenerateDimensionBroadcastShape(const Shape& lhs, const Shape& rhs) { TF_RET_CHECK(lhs.rank() == rhs.rank()); @@ -994,10 +994,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ? rhs.is_dynamic_dimension(i) : lhs.is_dynamic_dimension(i); } else { - return InvalidArgument( - "Binary op %s with incompatible shapes: %s and %s.", - HloOpcodeString(operation), ShapeUtil::HumanString(lhs), - ShapeUtil::HumanString(rhs)); + return InvalidArgument("Binary op with incompatible shapes: %s and %s.", + ShapeUtil::HumanString(lhs), + ShapeUtil::HumanString(rhs)); } } @@ -1005,7 +1004,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, output_dimensions, output_dimensions_is_dynamic); } -/* static */ StatusOr ShapeInference::InferInDimBroadcastShape( +/* static */ absl::StatusOr ShapeInference::InferInDimBroadcastShape( const Shape& smaller_shape, const Shape& larger_shape, absl::Span broadcast_dimensions) { if (smaller_shape.is_unbounded_dynamic() || @@ -1131,7 +1130,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return output_shape; } -/* static */ StatusOr ShapeInference::InferElementwiseBinaryOpShape( +/* static */ absl::StatusOr +ShapeInference::InferElementwiseBinaryOpShape( HloOpcode operation, const Shape& lhs, const Shape& rhs, absl::Span broadcast_dimensions) { TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation")); @@ -1171,7 +1171,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return result; } else if (lhs.rank() == rhs.rank()) { - return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs); + return InferDegenerateDimensionBroadcastShape(lhs, rhs); } else { // Ranks do not match, so perform InDim broadcasting using // broadcast_dimensions. Scalar broadcasting is a special case of this. @@ -1183,18 +1183,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, InferInDimBroadcastShape(smaller_shape, larger_shape, broadcast_dimensions)); - return InferDegenerateDimensionBroadcastShape( - operation, indim_broadcast_shape, larger_shape); + return InferDegenerateDimensionBroadcastShape(indim_broadcast_shape, + larger_shape); } } -/* static */ StatusOr ShapeInference::InferBinaryOpShape( +/* static */ absl::StatusOr ShapeInference::InferBinaryOpShape( HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs) { return InferBinaryOpShape(opcode, lhs->shape(), rhs->shape(), /*broadcast_dimensions=*/{}); } -/* static */ StatusOr ShapeInference::InferBinaryOpShape( +/* static */ absl::StatusOr ShapeInference::InferBinaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, absl::Span broadcast_dimensions) { VLOG(2) << StrFormat( @@ -1279,13 +1279,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } } -/* static */ StatusOr ShapeInference::InferTernaryOpShape( +/* static */ absl::StatusOr ShapeInference::InferTernaryOpShape( HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs, const HloInstruction* ehs) { return InferTernaryOpShape(opcode, lhs->shape(), rhs->shape(), ehs->shape()); } -/* static */ StatusOr ShapeInference::InferTernaryOpShape( +/* static */ absl::StatusOr ShapeInference::InferTernaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, const Shape& ehs) { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); @@ -1300,7 +1300,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } } -/* static */ StatusOr ShapeInference::InferVariadicOpShape( +/* static */ absl::StatusOr ShapeInference::InferVariadicOpShape( HloOpcode opcode, absl::Span operands) { std::vector operand_shapes; operand_shapes.reserve(operands.size()); @@ -1310,7 +1310,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InferVariadicOpShape(opcode, operand_shapes); } -/* static */ StatusOr ShapeInference::InferVariadicOpShape( +/* static */ absl::StatusOr ShapeInference::InferVariadicOpShape( HloOpcode opcode, absl::Span operand_shapes) { for (const Shape* shape : operand_shapes) { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape)); @@ -1347,7 +1347,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } } -/* static */ StatusOr ShapeInference::InferMapShape( +/* static */ absl::StatusOr ShapeInference::InferMapShape( absl::Span arg_shapes, const ProgramShape& to_apply, absl::Span dimensions) { if (arg_shapes.empty()) { @@ -1439,7 +1439,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, arg_shape->dimensions()); } -/* static */ StatusOr ShapeInference::InferBatchNormTrainingShape( +/* static */ absl::StatusOr ShapeInference::InferBatchNormTrainingShape( const Shape& operand_shape, const Shape& scale_shape, const Shape& offset_shape, int64_t feature_index) { TF_RETURN_IF_ERROR( @@ -1547,7 +1547,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, &output_shape_for_mean_and_var}); } -/* static */ StatusOr ShapeInference::InferBatchNormInferenceShape( +/* static */ absl::StatusOr ShapeInference::InferBatchNormInferenceShape( const Shape& operand_shape, const Shape& scale_shape, const Shape& offset_shape, const Shape& mean_shape, const Shape& variance_shape, int64_t feature_index) { @@ -1693,7 +1693,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return operand_shape; } -/* static */ StatusOr ShapeInference::InferBatchNormGradShape( +/* static */ absl::StatusOr ShapeInference::InferBatchNormGradShape( const Shape& operand_shape, const Shape& scale_shape, const Shape& mean_shape, const Shape& var_shape, const Shape& output_grad_shape, int64_t feature_index) { @@ -1853,7 +1853,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, {&operand_shape, &feature_shape, &feature_shape}); } -/* static */ StatusOr ShapeInference::InferConvolveShape( +/* static */ absl::StatusOr ShapeInference::InferConvolveShape( const Shape& lhs, const Shape& rhs, int64_t feature_group_count, int64_t batch_group_count, const Window& window, const ConvolutionDimensionNumbers& dnums, @@ -2117,7 +2117,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::MakeShape(type, dimensions, is_dynamic); } -/* static */ StatusOr ShapeInference::InferFftShape( +/* static */ absl::StatusOr ShapeInference::InferFftShape( const Shape& in, const FftType fft_type, const absl::Span fft_length) { const int64_t fft_rank = fft_length.size(); @@ -2206,7 +2206,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, #undef RET_CHECK_RANK } -/* static */ StatusOr ShapeInference::InferTriangularSolveShape( +/* static */ absl::StatusOr ShapeInference::InferTriangularSolveShape( const Shape& a, const Shape& b, const TriangularSolveOptions& options) { if ((!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) || a.element_type() != b.element_type()) { @@ -2257,7 +2257,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return b; } -/* static */ StatusOr ShapeInference::InferCholeskyShape( +/* static */ absl::StatusOr ShapeInference::InferCholeskyShape( const Shape& a) { if (!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) { return InvalidArgument( @@ -2278,7 +2278,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return a; } -/* static */ StatusOr ShapeInference::InferAllGatherShape( +/* static */ absl::StatusOr ShapeInference::InferAllGatherShape( absl::Span operand_shapes, int64_t all_gather_dimension, int64_t shard_count) { TF_RET_CHECK(all_gather_dimension >= 0); @@ -2302,7 +2302,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::MakeTupleShape(output_shapes); } -/* static */ StatusOr ShapeInference::InferAllGatherStartShape( +/* static */ absl::StatusOr ShapeInference::InferAllGatherStartShape( absl::Span operand_shapes, int64_t all_gather_dimension, int64_t shard_count) { TF_ASSIGN_OR_RETURN( @@ -2317,12 +2317,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::MakeTupleShapeWithPtrs({&input_shape, &ag_shape}); } -/* static */ StatusOr ShapeInference::InferAllGatherDoneShape( +/* static */ absl::StatusOr ShapeInference::InferAllGatherDoneShape( const Shape& all_gather_start_shape) { return ShapeUtil::GetTupleElementShape(all_gather_start_shape, 1); } -/* static */ StatusOr ShapeInference::InferAllReduceShape( +/* static */ absl::StatusOr ShapeInference::InferAllReduceShape( absl::Span operand_shapes) { for (const Shape* operand_shape : operand_shapes) { TF_RETURN_IF_ERROR( @@ -2334,7 +2334,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::MakeTupleShapeWithPtrs(operand_shapes); } -/* static */ StatusOr ShapeInference::InferReduceScatterShape( +/* static */ absl::StatusOr ShapeInference::InferReduceScatterShape( absl::Span operand_shapes, int64_t scatter_dimension, int64_t shard_count) { TF_RET_CHECK(scatter_dimension >= 0); @@ -2369,18 +2369,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::MakeTupleShape(output_shapes); } -/* static */ StatusOr ShapeInference::InferAllReduceStartShape( +/* static */ absl::StatusOr ShapeInference::InferAllReduceStartShape( absl::Span operand_shapes) { return InferAllReduceShape(operand_shapes); } -/* static */ StatusOr ShapeInference::InferAllReduceDoneShape( +/* static */ absl::StatusOr ShapeInference::InferAllReduceDoneShape( const Shape& operand_shape) { // The returned value from AllReduceDone is the operand forwarded. return operand_shape; } -/* static */ StatusOr ShapeInference::InferAllToAllShape( +/* static */ absl::StatusOr ShapeInference::InferAllToAllShape( const Shape& shape, int64_t split_dimension, int64_t concat_dimension, int64_t split_count) { TF_RET_CHECK(split_count > 0); @@ -2407,7 +2407,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::MakeShape(shape.element_type(), new_dimensions); } -/* static */ StatusOr ShapeInference::InferAllToAllTupleShape( +/* static */ absl::StatusOr ShapeInference::InferAllToAllTupleShape( absl::Span operand_shapes) { // An Alltoall HLO instruction receives N operands (with the same shape) and // returns a tuple that contains N array shapes. @@ -2426,7 +2426,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes); } -/* static */ StatusOr ShapeInference::InferCollectivePermuteShape( +/* static */ absl::StatusOr ShapeInference::InferCollectivePermuteShape( absl::Span operand_shapes) { if (operand_shapes.size() == 1) { TF_RETURN_IF_ERROR( @@ -2438,7 +2438,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } } -/* static */ StatusOr ShapeInference::InferCollectivePermuteStartShape( +/* static */ absl::StatusOr +ShapeInference::InferCollectivePermuteStartShape( absl::Span operand_shapes, absl::Span context_shapes) { absl::InlinedVector shapes; @@ -2455,13 +2456,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::MakeTupleShapeWithPtrs(shapes); } -/* static */ StatusOr ShapeInference::InferCollectivePermuteDoneShape( - const Shape& operand_shape) { +/* static */ absl::StatusOr +ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { TF_RET_CHECK(operand_shape.IsTuple()); return ShapeUtil::GetTupleElementShape(operand_shape, 1); } -/* static */ StatusOr ShapeInference::InferReduceShape( +/* static */ absl::StatusOr ShapeInference::InferReduceShape( absl::Span arg_shapes, absl::Span dimensions_to_reduce, const ProgramShape& to_apply) { @@ -2538,7 +2539,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } } -/* static */ StatusOr ShapeInference::InferReduceWindowShape( +/* static */ absl::StatusOr ShapeInference::InferReduceWindowShape( const Shape& operand_shape, const Shape& init_value_shape, const Window& window, const ProgramShape& to_apply_shape) { TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape}, @@ -2547,7 +2548,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InferReduceWindowShape(operand_shape, init_value_shape, window); } -/* static */ StatusOr ShapeInference::InferReduceWindowShape( +/* static */ absl::StatusOr ShapeInference::InferReduceWindowShape( absl::Span operands, absl::Span init_values, const Window& window, const ProgramShape& to_apply_shape) { @@ -2588,7 +2589,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } } -/* static */ StatusOr ShapeInference::InferReduceWindowShape( +/* static */ absl::StatusOr ShapeInference::InferReduceWindowShape( const Shape& operand_shape, const Shape& init_value_shape, const Window& window) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window")); @@ -2596,7 +2597,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, init_value_shape.element_type()); } -/* static */ StatusOr ShapeInference::InferSelectAndScatterShape( +/* static */ absl::StatusOr ShapeInference::InferSelectAndScatterShape( const Shape& operand_shape, const ProgramShape& select_shape, const Window& window, const Shape& source_shape, const Shape& init_value_shape, const ProgramShape& scatter_shape) { @@ -2655,7 +2656,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return operand_shape; } -/* static */ StatusOr ShapeInference::InferGetDimensionSizeShape( +/* static */ absl::StatusOr ShapeInference::InferGetDimensionSizeShape( const Shape& shape, int64_t dimension) { if (dimension < 0 || dimension >= shape.rank()) { return InvalidArgument("GetDimensionSize dimension out of bounds: %d.", @@ -2674,7 +2675,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::MakeShape(S32, {}); } -/* static */ StatusOr ShapeInference::InferSetDimensionSizeShape( +/* static */ absl::StatusOr ShapeInference::InferSetDimensionSizeShape( const Shape& shape, const Shape& val_shape, int64_t dimension) { if (dimension < 0 || dimension >= shape.rank()) { return InvalidArgument("SetDimensionSize dimension out of bounds: %d.", @@ -2700,7 +2701,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return result; } -/* static */ StatusOr ShapeInference::InferWindowFromDimensions( +/* static */ absl::StatusOr ShapeInference::InferWindowFromDimensions( absl::Span window_dimensions, absl::Span window_strides, absl::Span> padding, @@ -2762,7 +2763,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return window; } -/* static */ StatusOr ShapeInference::InferSliceShape( +/* static */ absl::StatusOr ShapeInference::InferSliceShape( const Shape& arg, absl::Span starts, absl::Span limits, absl::Span strides) { auto error = [&](const std::string& message) { @@ -2837,7 +2838,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::MakeShape(arg.element_type(), sizes, is_dynamic); } -/* static */ StatusOr ShapeInference::InferDynamicSliceShape( +/* static */ absl::StatusOr ShapeInference::InferDynamicSliceShape( const Shape& operand_shape, absl::Span start_index_shapes, absl::Span slice_sizes, bool allow_scalar_indices) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice")); @@ -2950,7 +2951,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return result; } -/* static */ StatusOr ShapeInference::InferDynamicUpdateSliceShape( +/* static */ absl::StatusOr ShapeInference::InferDynamicUpdateSliceShape( const Shape& operand_shape, const Shape& update_shape, absl::Span start_index_shapes, bool allow_scalar_indices) { TF_RETURN_IF_ERROR( @@ -3087,7 +3088,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return result_shape; } -/*static */ StatusOr ShapeInference::InferReverseShape( +/*static */ absl::StatusOr ShapeInference::InferReverseShape( const Shape& operand_shape, absl::Span dimensions) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse")); if (!AllUnique(dimensions)) { @@ -3103,7 +3104,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return operand_shape; } -/* static */ StatusOr ShapeInference::InferGetTupleElementShape( +/* static */ absl::StatusOr ShapeInference::InferGetTupleElementShape( const Shape& arg, int64_t index) { if (!arg.IsTuple()) { return InvalidArgument( @@ -3121,7 +3122,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return arg.tuple_shapes(index); } -/* static */ StatusOr ShapeInference::InferWhileShape( +/* static */ absl::StatusOr ShapeInference::InferWhileShape( const ProgramShape& condition, const ProgramShape& body, const Shape& init) { // Check the number of parameters for given computations. @@ -3158,7 +3159,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return init; } -/* static */ StatusOr ShapeInference::InferConditionalShape( +/* static */ absl::StatusOr ShapeInference::InferConditionalShape( const Shape& branch_index, absl::Span branch_computations, absl::Span branch_operands) { @@ -3233,7 +3234,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return result; } -/* static */ StatusOr ShapeInference::InferBroadcastShape( +/* static */ absl::StatusOr ShapeInference::InferBroadcastShape( const Shape& operand, absl::Span broadcast_sizes) { // This method is used to infer shape for xla::BroadcastInDim. TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast")); @@ -3263,7 +3264,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return result; } -/* static */ StatusOr ShapeInference::InferBroadcastShape( +/* static */ absl::StatusOr ShapeInference::InferBroadcastShape( const Shape& operand_shape, const Shape& output_shape, absl::Span broadcast_dimensions) { // This method is used to infer shape for xla::BroadcastInDim. @@ -3320,7 +3321,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return output_shape; } -/* static */ StatusOr ShapeInference::InferDynamicReshapeShape( +/* static */ absl::StatusOr ShapeInference::InferDynamicReshapeShape( const Shape& operand, absl::Span dim_size_shapes, absl::Span new_size_bounds, const std::vector& dims_are_dynamic) { @@ -3352,7 +3353,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return inferred_shape; } -/* static */ StatusOr ShapeInference::InferReshapeShape( +/* static */ absl::StatusOr ShapeInference::InferReshapeShape( const Shape& operand, absl::Span dimensions, absl::Span new_sizes, int64_t inferred_dimension) { TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape")); @@ -3511,7 +3512,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return inferred_shape; } -/* static */ StatusOr ShapeInference::InferTransposeShape( +/* static */ absl::StatusOr ShapeInference::InferTransposeShape( const Shape& operand, absl::Span dimensions) { TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose")); @@ -3528,7 +3529,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::PermuteDimensions(dimensions, operand); } -/* static */ StatusOr ShapeInference::InferClampShape( +/* static */ absl::StatusOr ShapeInference::InferClampShape( const Shape& min, const Shape& operand, const Shape& max) { TF_RETURN_IF_ERROR(ExpectArray(min, "clamp min")); TF_RETURN_IF_ERROR(ExpectArray(operand, "clamp operand")); @@ -3545,7 +3546,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return operand; } -/* static */ StatusOr ShapeInference::InferSelectShape( +/* static */ absl::StatusOr ShapeInference::InferSelectShape( const Shape& pred, const Shape& on_true, const Shape& on_false) { TF_RETURN_IF_ERROR(ExpectArray(pred, "select pred")); TF_RETURN_IF_ERROR(ExpectArray(on_true, "select on-true")); @@ -3575,7 +3576,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, for (int64_t dimension = 0; dimension < pred.rank(); ++dimension) { if (on_true.is_unbounded_dynamic_dimension(dimension) || on_false.is_unbounded_dynamic_dimension(dimension)) { - StatusOr inferred = InferMostSpecificDimAndBound( + absl::StatusOr inferred = InferMostSpecificDimAndBound( dimension, on_true.dimensions(dimension), on_false.dimensions(dimension), on_true.dimensions(dimension), on_false.dimensions(dimension)); @@ -3593,7 +3594,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return std::move(result); } -/* static */ StatusOr ShapeInference::InferCallShape( +/* static */ absl::StatusOr ShapeInference::InferCallShape( absl::Span arg_shapes, const ProgramShape& to_apply) { // The applied function's arity equals the number of arguments. if (arg_shapes.size() != to_apply.parameters_size()) { @@ -3719,7 +3720,7 @@ static Status ValidateGatherDimensionNumbers( return OkStatus(); } -/*static*/ StatusOr ShapeInference::InferGatherShape( +/*static*/ absl::StatusOr ShapeInference::InferGatherShape( const Shape& input_shape, const Shape& start_indices_shape, const GatherDimensionNumbers& gather_dim_numbers, absl::Span slice_sizes) { @@ -3957,7 +3958,7 @@ Status ValidateScatterDimensionNumbers( } // namespace -/*static*/ StatusOr ShapeInference::InferScatterShape( +/*static*/ absl::StatusOr ShapeInference::InferScatterShape( absl::Span arg_shapes, const ProgramShape& to_apply_shape, const ScatterDimensionNumbers& scatter_dim_numbers) { diff --git a/third_party/xla/xla/service/shape_inference.h b/third_party/xla/xla/service/shape_inference.h index 5618d87ad288af..2a91a107330699 100644 --- a/third_party/xla/xla/service/shape_inference.h +++ b/third_party/xla/xla/service/shape_inference.h @@ -44,142 +44,141 @@ class ShapeInference { public: // Infers the shape produced by applying the given unary operation to the // given input shape. - static StatusOr InferUnaryOpShape(HloOpcode opcode, - const Shape& shape); - static StatusOr InferUnaryOpShape(HloOpcode opcode, - const HloInstruction* operand); + static absl::StatusOr InferUnaryOpShape(HloOpcode opcode, + const Shape& shape); + static absl::StatusOr InferUnaryOpShape(HloOpcode opcode, + const HloInstruction* operand); // Infers the shape produced by applying the given binary operation to the // given input shapes. - static StatusOr InferBinaryOpShape( + static absl::StatusOr InferBinaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, absl::Span broadcast_dimensions); - static StatusOr InferBinaryOpShape(HloOpcode opcode, - const HloInstruction* lhs, - const HloInstruction* rhs); + static absl::StatusOr InferBinaryOpShape(HloOpcode opcode, + const HloInstruction* lhs, + const HloInstruction* rhs); // Infers the shape produced by applying the given ternary operation to the // given input shapes. - static StatusOr InferTernaryOpShape(HloOpcode opcode, const Shape& lhs, - const Shape& rhs, - const Shape& ehs); - static StatusOr InferTernaryOpShape(HloOpcode opcode, - const HloInstruction* lhs, - const HloInstruction* rhs, - const HloInstruction* ehs); + static absl::StatusOr InferTernaryOpShape(HloOpcode opcode, + const Shape& lhs, + const Shape& rhs, + const Shape& ehs); + static absl::StatusOr InferTernaryOpShape(HloOpcode opcode, + const HloInstruction* lhs, + const HloInstruction* rhs, + const HloInstruction* ehs); // Infers the shape produced by applying the given variadic operation to the // given input operand shapes. - static StatusOr InferVariadicOpShape( + static absl::StatusOr InferVariadicOpShape( HloOpcode opcode, absl::Span operand_shapes); - static StatusOr InferVariadicOpShape( + static absl::StatusOr InferVariadicOpShape( HloOpcode opcode, absl::Span operands); // Infers the shape produced by applying the given mapping computation shape // to the given operand shapes. - static StatusOr InferMapShape( + static absl::StatusOr InferMapShape( absl::Span arg_shapes, const ProgramShape& to_apply, absl::Span dimensions); // Infers the shape produced by InferBatchNormTraining with the given // operands. - static StatusOr InferBatchNormTrainingShape(const Shape& operand_shape, - const Shape& scale_shape, - const Shape& offset_shape, - int64_t feature_index); + static absl::StatusOr InferBatchNormTrainingShape( + const Shape& operand_shape, const Shape& scale_shape, + const Shape& offset_shape, int64_t feature_index); // Infers the shape produced by InferBatchNormInference with the given // operands. - static StatusOr InferBatchNormInferenceShape( + static absl::StatusOr InferBatchNormInferenceShape( const Shape& operand_shape, const Shape& scale_shape, const Shape& offset_shape, const Shape& mean_shape, const Shape& variance_shape, int64_t feature_index); // Infers the shape produced by InferBatchNormGrad with the given operands. - static StatusOr InferBatchNormGradShape(const Shape& operand_shape, - const Shape& scale_shape, - const Shape& mean_shape, - const Shape& var_shape, - const Shape& output_grad_shape, - int64_t feature_index); + static absl::StatusOr InferBatchNormGradShape( + const Shape& operand_shape, const Shape& scale_shape, + const Shape& mean_shape, const Shape& var_shape, + const Shape& output_grad_shape, int64_t feature_index); // Infers the shape produced by applying the given convolutional filter (rhs) // to lhs in the way specified by the fields on window. An optional // preferred_element_type can be specified to upcast the element type. - static StatusOr InferConvolveShape( + static absl::StatusOr InferConvolveShape( const Shape& lhs, const Shape& rhs, int64_t feature_group_count, int64_t batch_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, std::optional preferred_element_type); // Infers the shape produced by the given FFT type on the given operand. - static StatusOr InferFftShape(const Shape& in, FftType fft_type, - absl::Span fft_length); + static absl::StatusOr InferFftShape( + const Shape& in, FftType fft_type, absl::Span fft_length); // Infers the shape produced by the given triangular solve operation. - static StatusOr InferTriangularSolveShape( + static absl::StatusOr InferTriangularSolveShape( const Shape& a, const Shape& b, const TriangularSolveOptions& options); // Infers the shape produced by the given triangular solve operation. - static StatusOr InferCholeskyShape(const Shape& a); + static absl::StatusOr InferCholeskyShape(const Shape& a); // Infers the shape produced by an all-gather with the given operand shape, // concat dimension, and shard count. - static StatusOr InferAllGatherShape( + static absl::StatusOr InferAllGatherShape( absl::Span operand_shapes, int64_t all_gather_dimension, int64_t shard_count); // Infers the shape produced by an all-gather-start with the given operand // shape, concat dimension, and shard count. - static StatusOr InferAllGatherStartShape( + static absl::StatusOr InferAllGatherStartShape( absl::Span operand_shapes, int64_t all_gather_dimension, int64_t shard_count); // Infers the shape produced by an all-gather-done given a certain // all-gather-start shape. - static StatusOr InferAllGatherDoneShape( + static absl::StatusOr InferAllGatherDoneShape( const Shape& all_gather_start_shape); // Infers the shape produced by a cross replica sum with the given operand // shapes. - static StatusOr InferAllReduceShape( + static absl::StatusOr InferAllReduceShape( absl::Span operand_shapes); // Infers the shape produced by a reduce-scatter with the given operand // shape, scatter dimension, and shard count. - static StatusOr InferReduceScatterShape( + static absl::StatusOr InferReduceScatterShape( absl::Span operand_shapes, int64_t scatter_dimension, int64_t shard_count); // Infers the shape produced by a cross replica sum start. - static StatusOr InferAllReduceStartShape( + static absl::StatusOr InferAllReduceStartShape( absl::Span operand_shapes); // Infers the shape produced by a cross replica sum done. - static StatusOr InferAllReduceDoneShape(const Shape& operand_shape); + static absl::StatusOr InferAllReduceDoneShape( + const Shape& operand_shape); // Infers final shape of an Alltoall operation that is created by the xla // builder. - static StatusOr InferAllToAllShape(const Shape& shape, - int64_t split_dimension, - int64_t concat_dimension, - int64_t split_count); + static absl::StatusOr InferAllToAllShape(const Shape& shape, + int64_t split_dimension, + int64_t concat_dimension, + int64_t split_count); // Infers the shape of an HLO all-to-all instruction. - static StatusOr InferAllToAllTupleShape( + static absl::StatusOr InferAllToAllTupleShape( absl::Span operand_shapes); // Infers the shape of a collective permute operation. - static StatusOr InferCollectivePermuteShape( + static absl::StatusOr InferCollectivePermuteShape( absl::Span operand_shapes); // Infers the shape of a collective permute start operation. - static StatusOr InferCollectivePermuteStartShape( + static absl::StatusOr InferCollectivePermuteStartShape( absl::Span operand_shapes, absl::Span context_shapes); // Infers the shape of a collective permute operation. - static StatusOr InferCollectivePermuteDoneShape( + static absl::StatusOr InferCollectivePermuteDoneShape( const Shape& operand_shape); // Infers the shape produced by applying the given reduction computation @@ -188,58 +187,57 @@ class ShapeInference { // If pass_index is true, the reduce function is invoked with the element // index as the leading parameter, and the program shape should match // accordingly (or an error will result). - static StatusOr InferReduceShape( + static absl::StatusOr InferReduceShape( absl::Span arg_shapes, absl::Span dimensions_to_reduce, const ProgramShape& to_apply); // Infers the shape produced by applying the given computation to the operand // shape with the given window and stride dimensions. - static StatusOr InferReduceWindowShape( + static absl::StatusOr InferReduceWindowShape( const Shape& operand_shape, const Shape& init_value, const Window& window, const ProgramShape& to_apply_shape); - static StatusOr InferReduceWindowShape(const Shape& operand_shape, - const Shape& init_value, - const Window& window); - static StatusOr InferReduceWindowShape( + static absl::StatusOr InferReduceWindowShape( + const Shape& operand_shape, const Shape& init_value, + const Window& window); + static absl::StatusOr InferReduceWindowShape( absl::Span operands, absl::Span init_values, const Window& window, const ProgramShape& to_apply_shape); - static StatusOr InferReduceWindowShape( + static absl::StatusOr InferReduceWindowShape( absl::Span operands, absl::Span init_values, const Window& window); // Infers the shape produced by scattering the given source shape to the // selected indices of each window on the operand shape. - static StatusOr InferSelectAndScatterShape( + static absl::StatusOr InferSelectAndScatterShape( const Shape& operand_shape, const ProgramShape& select_shape, const Window& window, const Shape& source_shape, const Shape& init_value_shape, const ProgramShape& scatter_shape); // Infers the shape produced by a reverse operation that reverses the order // of the elements in the given dimensions. - static StatusOr InferReverseShape( + static absl::StatusOr InferReverseShape( const Shape& operand_shape, absl::Span dimensions); // Infers the shape produced by a slice operation spanning from the starts to // the limits in the original shape's dimensions. // // e.g. slice f32[32x32] 0:16 0:16 -> f32[16x16] - static StatusOr InferSliceShape(const Shape& arg, - absl::Span starts, - absl::Span limits, - absl::Span strides); + static absl::StatusOr InferSliceShape( + const Shape& arg, absl::Span starts, + absl::Span limits, absl::Span strides); // Infers the shape produced by a dynamic slice operation of size specified // in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'. - static StatusOr InferDynamicSliceShape( + static absl::StatusOr InferDynamicSliceShape( const Shape& operand_shape, absl::Span start_index_shapes, absl::Span slice_sizes, bool allow_scalar_indices = true); // Infers the shape produced by a dynamic update slice operation based // on the shape of operand and update. - static StatusOr InferDynamicUpdateSliceShape( + static absl::StatusOr InferDynamicUpdateSliceShape( const Shape& operand_shape, const Shape& update_shape, absl::Span start_index_shapes, bool allow_scalar_indices = true); @@ -248,100 +246,99 @@ class ShapeInference { // the given input shape. This is essential for operations on tuples, because // it is impossible to infer the type that comes out of the tuple indexing if // it is not a compile time constant. - static StatusOr InferGetTupleElementShape(const Shape& arg, - int64_t index); + static absl::StatusOr InferGetTupleElementShape(const Shape& arg, + int64_t index); // Infers the shape produced from a while node. condition and body are the // shapes of computations for the condition and the body of a while node, and // init is the shape of data initially passed in to the body as an argument. // The shapes must match; condition: T -> PRED, body: T -> T, init: T - static StatusOr InferWhileShape(const ProgramShape& condition, - const ProgramShape& body, - const Shape& init); + static absl::StatusOr InferWhileShape(const ProgramShape& condition, + const ProgramShape& body, + const Shape& init); // Infers the shape produced by a predicated or indexed conditional operation. - static StatusOr InferConditionalShape( + static absl::StatusOr InferConditionalShape( const Shape& branch_index, absl::Span branch_computations, absl::Span branch_operands); // Infers the shape produced by a broadcast operation. - static StatusOr InferBroadcastShape( + static absl::StatusOr InferBroadcastShape( const Shape& operand, absl::Span broadcast_sizes); // Checks whether the given parameters can form a broadcast. Returns the same // output_shape if it's legal. - static StatusOr InferBroadcastShape( + static absl::StatusOr InferBroadcastShape( const Shape& operand_shape, const Shape& output_shape, absl::Span broadcast_dimensions); // Infers the shape produced by a reshape operation from the element type of // its operand and the new dimension sizes specified. - static StatusOr InferReshapeShape(const Shape& operand, - absl::Span dimensions, - absl::Span new_sizes, - int64_t inferred_dimension); + static absl::StatusOr InferReshapeShape( + const Shape& operand, absl::Span dimensions, + absl::Span new_sizes, int64_t inferred_dimension); // Infers the shape produced by a dynamic reshape operation from the element // type of its operand and the new dimension sizes specified. The result shape // will have dynamic dimensions as specific in `dim_is_dynamic` and bound // `new_size_bounds`. - static StatusOr InferDynamicReshapeShape( + static absl::StatusOr InferDynamicReshapeShape( const Shape& operand, absl::Span dim_size_shapes, absl::Span new_size_bounds, const std::vector& dims_are_dynamic); // Infers the shape produced by a transpose operation from the element type of // its operand and its dimensions field. - static StatusOr InferTransposeShape( + static absl::StatusOr InferTransposeShape( const Shape& operand, absl::Span dimensions); // Helper that infers the shape produced by performing a concatenate operation // with the given operand shapes. - static StatusOr InferConcatOpShape( + static absl::StatusOr InferConcatOpShape( absl::Span arg_shapes, int64_t dimension); // Helper that validates the given operand shape can be converted to the // target output_shape via a convert instruction -- the requirement is that // the shape is identical except for the element type. - static StatusOr InferConvertShape(const Shape& operand_shape, - PrimitiveType new_element_type); + static absl::StatusOr InferConvertShape( + const Shape& operand_shape, PrimitiveType new_element_type); // Helper that validates the given operand shape can be bitcast converted to // the target output_shape via a bitcast convert instruction -- the // requirement is that the shape is identical except for the element type and // the element types have identical bit-widths. - static StatusOr InferBitcastConvertShape( + static absl::StatusOr InferBitcastConvertShape( const Shape& operand_shape, PrimitiveType new_element_type); // Helper that validates the given operand shape can be converted to the // target output_shape via a stochastic convert instruction -- the requirement // is that the shape is identical except for the element type. - static StatusOr InferStochasticConvertShape( + static absl::StatusOr InferStochasticConvertShape( const Shape& operand_shape, const Shape& random_shape, PrimitiveType new_element_type); // Helper that validates the input data type for a reduce-precision operation, // and returns the result shape. - static StatusOr InferReducePrecisionShape(const Shape& operand_shape, - const int exponent_bits, - const int mantissa_bits); + static absl::StatusOr InferReducePrecisionShape( + const Shape& operand_shape, const int exponent_bits, + const int mantissa_bits); // Helper that infers the shape produced by a pad operation based on the // padding configuration. - static StatusOr InferPadShape(const Shape& operand_shape, - const Shape& padding_value_shape, - const PaddingConfig& padding_config); + static absl::StatusOr InferPadShape( + const Shape& operand_shape, const Shape& padding_value_shape, + const PaddingConfig& padding_config); // Helper that validates the given arg_shapes are compatible with the shape of // the to_apply parameters, and returns the to_apply result shape. - static StatusOr InferCallShape( + static absl::StatusOr InferCallShape( absl::Span arg_shapes, const ProgramShape& to_apply); // Helper that infers the shape produced by performing a dot operation with // the given LHS and RHS shapes. An optional preferred_element_type can be // specified to upcast the element type. - static StatusOr InferDotOpShape( + static absl::StatusOr InferDotOpShape( const Shape& lhs, const Shape& rhs, const DotDimensionNumbers& dimension_numbers, std::optional preferred_element_type); @@ -349,7 +346,7 @@ class ShapeInference { // Helper that infers the shape of the tensor produced by a gather operation // with the given input shape, gather indices shape and gather dimension // numbers. - static StatusOr InferGatherShape( + static absl::StatusOr InferGatherShape( const Shape& input_shape, const Shape& start_indices_shape, const GatherDimensionNumbers& gather_dim_numbers, absl::Span slice_sizes); @@ -357,25 +354,25 @@ class ShapeInference { // Helper that validates the given input shape, scatter indices shape, updates // shape, and scatter dimension numbers that constitute a scatter operation, // and returns the result shape of the scatter operation. - static StatusOr InferScatterShape( + static absl::StatusOr InferScatterShape( absl::Span arg_shapes, const ProgramShape& to_apply_shape, const ScatterDimensionNumbers& scatter_dim_numbers); // Helper that validates the given input shape to GetDimensionSize. - static StatusOr InferGetDimensionSizeShape(const Shape& shape, - int64_t dimension); + static absl::StatusOr InferGetDimensionSizeShape(const Shape& shape, + int64_t dimension); // Helper that validates the given input shape to SetDimensionSize. - static StatusOr InferSetDimensionSizeShape(const Shape& operand_shape, - const Shape& val_shape, - int64_t dimension); + static absl::StatusOr InferSetDimensionSizeShape( + const Shape& operand_shape, const Shape& val_shape, int64_t dimension); - static StatusOr InferTopKShape(const Shape& operand_shape, int64_t k); + static absl::StatusOr InferTopKShape(const Shape& operand_shape, + int64_t k); // Helper function for creating a Window proto from user-supplied data. // Returns error if the user-supplied data was invalid. - static StatusOr InferWindowFromDimensions( + static absl::StatusOr InferWindowFromDimensions( absl::Span window_dimensions, absl::Span window_strides, absl::Span> padding, @@ -389,31 +386,32 @@ class ShapeInference { // Note: By "element-wise" we mean operations that look at a single element in // the LHS and a single element in the RHS to produce a single output element, // even in the presence of broadcasting of one of the operands over the other. - static StatusOr InferElementwiseBinaryOpShape( + static absl::StatusOr InferElementwiseBinaryOpShape( HloOpcode operation, const Shape& lhs, const Shape& rhs, absl::Span broadcast_dimensions); // Helper for inferring the shape of Clamp ops. - static StatusOr InferClampShape(const Shape& min, const Shape& operand, - const Shape& max); + static absl::StatusOr InferClampShape(const Shape& min, + const Shape& operand, + const Shape& max); // Helper for inferring the shape of Select ops. - static StatusOr InferSelectShape(const Shape& pred, - const Shape& on_true, - const Shape& on_false); + static absl::StatusOr InferSelectShape(const Shape& pred, + const Shape& on_true, + const Shape& on_false); // Helper for inferring shapes of binary operations which use degenerate // dimension broadcasting (a dimension of size 1 in one operand is broadcast // up to match the size of the dimension in the other operand). - static StatusOr InferDegenerateDimensionBroadcastShape( - HloOpcode operation, const Shape& lhs, const Shape& rhs); + static absl::StatusOr InferDegenerateDimensionBroadcastShape( + const Shape& lhs, const Shape& rhs); // Helper for inferring shapes of binary operations using "InDim" // broadcasting. This is the broadcasting used in the *InDim binary operations // (for example ComputationBuilder::AddInDim). smaller_shape must be a // lower-rank shape than larger_shape. Returns the shape that the // smaller_shape is broadcast to. - static StatusOr InferInDimBroadcastShape( + static absl::StatusOr InferInDimBroadcastShape( const Shape& smaller_shape, const Shape& larger_shape, absl::Span broadcast_dimensions); diff --git a/third_party/xla/xla/service/shape_inference_test.cc b/third_party/xla/xla/service/shape_inference_test.cc index 43dcb0f249d6af..65e7b41c9eb20d 100644 --- a/third_party/xla/xla/service/shape_inference_test.cc +++ b/third_party/xla/xla/service/shape_inference_test.cc @@ -45,6 +45,9 @@ namespace { using ::testing::ContainsRegex; using ::testing::HasSubstr; +constexpr absl::string_view kIncompatibleBinaryOpShapeErrorMessage = + "Binary op with incompatible shapes"; + class ShapeInferenceTest : public ::testing::Test { protected: // Some handy scalar shapes. @@ -355,7 +358,7 @@ TEST_F(ShapeInferenceTest, Complex) { } TEST_F(ShapeInferenceTest, VariadicOpTuplify) { - StatusOr result = + absl::StatusOr result = ShapeInference::InferVariadicOpShape(HloOpcode::kTuple, {&s32_, &f32_}); ASSERT_IS_OK(result.status()); ASSERT_TRUE( @@ -2467,7 +2470,7 @@ TEST_F(ShapeInferenceTest, ConditionalDynamic) { TEST_F(ShapeInferenceTest, BadSlice) { auto arg = ShapeUtil::MakeShape(F32, {4}); - StatusOr statusor = + absl::StatusOr statusor = ShapeInference::InferSliceShape(arg, {0}, {5}, {1}); ASSERT_FALSE(statusor.ok()); @@ -2483,7 +2486,7 @@ TEST_F(ShapeInferenceTest, BadSlice) { TEST_F(ShapeInferenceTest, BadSort) { auto keys = ShapeUtil::MakeShape(F32, {4}); auto values = ShapeUtil::MakeShape(F32, {5}); - StatusOr statusor = + absl::StatusOr statusor = ShapeInference::InferVariadicOpShape(HloOpcode::kSort, {&keys, &values}); EXPECT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().message(), HasSubstr("dimensions must match")) @@ -2494,7 +2497,7 @@ TEST_F(ShapeInferenceTest, BadSortValuesMismatch) { auto keys = ShapeUtil::MakeShape(F32, {4}); auto values_good = ShapeUtil::MakeShape(F32, {4}); auto values_bad = ShapeUtil::MakeShape(F32, {5}); - StatusOr statusor = ShapeInference::InferVariadicOpShape( + absl::StatusOr statusor = ShapeInference::InferVariadicOpShape( HloOpcode::kSort, {&keys, &values_good, &values_bad}); EXPECT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().message(), HasSubstr("dimensions must match")) @@ -2505,7 +2508,7 @@ TEST_F(ShapeInferenceTest, SortManyValues) { auto keys = ShapeUtil::MakeShape(F32, {4}); auto values_s32 = ShapeUtil::MakeShape(S32, {4}); auto values_u32 = ShapeUtil::MakeShape(U32, {4}); - StatusOr statusor = ShapeInference::InferVariadicOpShape( + absl::StatusOr statusor = ShapeInference::InferVariadicOpShape( HloOpcode::kSort, {&keys, &values_s32, &values_u32}); EXPECT_IS_OK(statusor); Shape inferred_shape = *statusor; @@ -2516,7 +2519,7 @@ TEST_F(ShapeInferenceTest, SortManyValues) { TEST_F(ShapeInferenceTest, GoodTopK) { auto input = ShapeUtil::MakeShape(F32, {3, 4, 5}); - StatusOr s = ShapeInference::InferTopKShape(input, /*k=*/2); + absl::StatusOr s = ShapeInference::InferTopKShape(input, /*k=*/2); ASSERT_IS_OK(s.status()); ASSERT_TRUE(ShapeUtil::Equal( *s, ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 4, 2}), @@ -2525,7 +2528,8 @@ TEST_F(ShapeInferenceTest, GoodTopK) { TEST_F(ShapeInferenceTest, FailTopKLargeK) { auto input = ShapeUtil::MakeShape(F32, {3, 4, 5}); - StatusOr statusor = ShapeInference::InferTopKShape(input, /*k=*/10); + absl::StatusOr statusor = + ShapeInference::InferTopKShape(input, /*k=*/10); EXPECT_FALSE(statusor.ok()); } @@ -2774,7 +2778,7 @@ TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) { } TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) { - StatusOr statusor = ShapeInference::InferGatherShape( + absl::StatusOr statusor = ShapeInference::InferGatherShape( tuple_shape_, s64_vector_32_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{0}, @@ -2789,7 +2793,7 @@ TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) { } TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { - StatusOr statusor = ShapeInference::InferGatherShape( + absl::StatusOr statusor = ShapeInference::InferGatherShape( s64_vector_32_, tuple_shape_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{0}, @@ -2804,7 +2808,7 @@ TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { } TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) { - StatusOr statusor = ShapeInference::InferGatherShape( + absl::StatusOr statusor = ShapeInference::InferGatherShape( s64_vector_32_, vector_32_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{0}, @@ -2820,7 +2824,7 @@ TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) { TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_NonAscendingWindowIndices) { - StatusOr statusor = ShapeInference::InferGatherShape( + absl::StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 8, 7}, @@ -2837,7 +2841,7 @@ TEST_F(GatherShapeInferenceTest, TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedWindowIndices) { - StatusOr statusor = ShapeInference::InferGatherShape( + absl::StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 7}, @@ -2854,7 +2858,7 @@ TEST_F(GatherShapeInferenceTest, TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowIndexOutOfBounds) { - StatusOr statusor = ShapeInference::InferGatherShape( + absl::StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 99, 100, 101}, @@ -2870,7 +2874,7 @@ TEST_F(GatherShapeInferenceTest, TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) { - StatusOr statusor = ShapeInference::InferGatherShape( + absl::StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 9}, @@ -2886,7 +2890,7 @@ TEST_F(GatherShapeInferenceTest, TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingElidedWindowDims) { - StatusOr statusor = ShapeInference::InferGatherShape( + absl::StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 8}, @@ -2904,7 +2908,7 @@ TEST_F(GatherShapeInferenceTest, TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) { - StatusOr statusor = ShapeInference::InferGatherShape( + absl::StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 8}, @@ -2921,7 +2925,7 @@ TEST_F(GatherShapeInferenceTest, TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedWindowToInputMapping) { - StatusOr statusor = ShapeInference::InferGatherShape( + absl::StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 8}, @@ -2938,7 +2942,7 @@ TEST_F(GatherShapeInferenceTest, TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingGatherToInputMapping) { - StatusOr statusor = ShapeInference::InferGatherShape( + absl::StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 8}, @@ -2956,7 +2960,7 @@ TEST_F(GatherShapeInferenceTest, TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) { - StatusOr statusor = ShapeInference::InferGatherShape( + absl::StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 8}, @@ -2972,7 +2976,7 @@ TEST_F(GatherShapeInferenceTest, TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedGatherToInputMapping) { - StatusOr statusor = ShapeInference::InferGatherShape( + absl::StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 8}, @@ -2989,7 +2993,7 @@ TEST_F(GatherShapeInferenceTest, TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_NonAscendingElidedWindowDims) { - StatusOr statusor = ShapeInference::InferGatherShape( + absl::StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 8}, @@ -3004,7 +3008,7 @@ TEST_F(GatherShapeInferenceTest, } TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) { - StatusOr statusor = ShapeInference::InferGatherShape( + absl::StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7}, @@ -3021,7 +3025,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) { TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) { - StatusOr statusor = ShapeInference::InferGatherShape( + absl::StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 8}, @@ -3038,7 +3042,7 @@ TEST_F(GatherShapeInferenceTest, TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) { - StatusOr statusor = ShapeInference::InferGatherShape( + absl::StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7}, @@ -3055,7 +3059,7 @@ TEST_F(GatherShapeInferenceTest, } TEST_F(GatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) { - StatusOr statusor = ShapeInference::InferGatherShape( + absl::StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 8}, @@ -3197,7 +3201,7 @@ TEST_P(ScatterShapeInferenceTest, TfScatterWithPartialUpdatesV2) { TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesBiggerThanInput) { auto shapes = CreateShapes({64, 48}, s64_vector(32), {65, 32}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{0}, @@ -3214,7 +3218,7 @@ TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesBiggerThanInput) { TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesBiggerThanInputV2) { auto shapes = CreateShapes({64, 48}, s64_vector(32), {32, 49}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{1}, @@ -3231,7 +3235,7 @@ TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesBiggerThanInputV2) { TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesNotMatchingIndices) { auto shapes = CreateShapes({64, 48}, s64_vector(32), {64, 31}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{0}, @@ -3249,7 +3253,7 @@ TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesNotMatchingIndices) { TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesNotMatchingIndicesV2) { auto shapes = CreateShapes({64, 48}, s64_vector(32), {31, 48}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{1}, @@ -3328,7 +3332,7 @@ TEST_P(ScatterShapeInferenceTest, TfScatterNdWithPartialUpdatesV2) { TEST_P(ScatterShapeInferenceTest, TfScatterNdWithUpdatesBiggerThanInput) { auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}), {10, 9, 8, 7, 65}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4}, @@ -3346,7 +3350,7 @@ TEST_P(ScatterShapeInferenceTest, TfScatterNdWithUpdatesBiggerThanInput) { TEST_P(ScatterShapeInferenceTest, TfScatterNdWithUpdatesNotMatchingIndices) { auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}), {9, 9, 8, 7, 64}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4}, @@ -3458,7 +3462,7 @@ TEST_P(ScatterShapeInferenceTest, ScatterWithTupleShapedTensorInput) { ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1}), ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1})}); Shape s64_vector_32 = s64_vector(32); - StatusOr statusor = ShapeInference::InferScatterShape( + absl::StatusOr statusor = ShapeInference::InferScatterShape( {&tuple_shape, &s64_vector_32, &s64_vector_32}, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{0}, @@ -3476,7 +3480,7 @@ TEST_P(ScatterShapeInferenceTest, ScatterWithTupleShapedScatterIndicesInput) { ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1}), ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1})}); Shape s64_vector_32 = s64_vector(32); - StatusOr statusor = ShapeInference::InferScatterShape( + absl::StatusOr statusor = ShapeInference::InferScatterShape( {&s64_vector_32, &tuple_shape, &s64_vector_32}, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{0}, @@ -3494,7 +3498,7 @@ TEST_P(ScatterShapeInferenceTest, ScatterWithTupleShapedUpdatesInput) { ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1}), ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1})}); Shape s64_vector_32 = s64_vector(32); - StatusOr statusor = ShapeInference::InferScatterShape( + absl::StatusOr statusor = ShapeInference::InferScatterShape( {&s64_vector_32, &s64_vector_32, &tuple_shape}, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{0}, @@ -3509,7 +3513,7 @@ TEST_P(ScatterShapeInferenceTest, ScatterWithTupleShapedUpdatesInput) { TEST_P(ScatterShapeInferenceTest, FloatingPointScatterIndicesInput) { Shape s64_vector_32 = s64_vector(32); - StatusOr statusor = ShapeInference::InferScatterShape( + absl::StatusOr statusor = ShapeInference::InferScatterShape( {&s64_vector_32, &vector_32_, &s64_vector_32}, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{0}, @@ -3525,7 +3529,7 @@ TEST_P(ScatterShapeInferenceTest, FloatingPointScatterIndicesInput) { TEST_P(ScatterShapeInferenceTest, OutOfBoundsScatterIndicesLeafDim) { auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), {10, 9, 8, 7, 30, 29, 28}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4, 5, 6}, @@ -3542,7 +3546,7 @@ TEST_P(ScatterShapeInferenceTest, OutOfBoundsScatterIndicesLeafDim) { TEST_P(ScatterShapeInferenceTest, InvalidUpdates) { auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), {10, 9, 8, 7, 30, 29, 28, 50}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4, 5, 6}, @@ -3560,7 +3564,7 @@ TEST_P(ScatterShapeInferenceTest, InvalidUpdateComputation) { ShapeUtil::MakeProgramShape({f32_}, f32_); auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), {10, 9, 8, 7, 30, 29, 28}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, invalid_update_computation, HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4, 5, 6}, @@ -3579,7 +3583,7 @@ TEST_P(ScatterShapeInferenceTest, InvalidScatterDimNumbers_NonAscendingUpdateWindowDims) { auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), {10, 9, 8, 7, 30, 29, 28, 27, 26}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4, 5, 6, 8, 7}, @@ -3596,7 +3600,7 @@ TEST_P(ScatterShapeInferenceTest, InvalidScatterDimNumbers_RepeatedUpdateWindowDims) { auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), {10, 9, 8, 7, 30, 29, 28, 27, 26}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4, 5, 6, 7, 7}, @@ -3613,7 +3617,7 @@ TEST_P(ScatterShapeInferenceTest, InvalidScatterDimNumbers_OutOfBoundsUpdateWindowDims) { auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), {10, 9, 8, 7, 30, 29, 28, 27, 26}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4, 5, 6, 7, 9}, @@ -3631,7 +3635,7 @@ TEST_P(ScatterShapeInferenceTest, InvalidScatterDimNumbers_NonAscendingInsertedWindowDims) { auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), {10, 9, 8, 7, 30, 29, 28}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4, 5, 6}, @@ -3648,7 +3652,7 @@ TEST_P(ScatterShapeInferenceTest, InvalidScatterDimNumbers_RepeatedInsertedWindowDims) { auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), {10, 9, 8, 7, 30, 29, 28}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4, 5, 6}, @@ -3665,7 +3669,7 @@ TEST_P(ScatterShapeInferenceTest, InvalidScatterDimNumbers_OutOfBoundsInsertedWindowDims) { auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), {10, 9, 8, 7, 30, 29, 28}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4, 5, 6}, @@ -3683,7 +3687,7 @@ TEST_P(ScatterShapeInferenceTest, InvalidScatterDimNumbers_MismatchingScatterDimsToOperandDims) { auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), {10, 9, 8, 7, 30, 29, 28}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4, 5, 6}, @@ -3703,7 +3707,7 @@ TEST_P(ScatterShapeInferenceTest, InvalidScatterDimNumbers_OutOfBoundsScatterDimsToOperandDims) { auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), {10, 9, 8, 7, 30, 29, 28}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4, 5, 6}, @@ -3721,7 +3725,7 @@ TEST_P(ScatterShapeInferenceTest, InvalidScatterDimNumbers_RepeatedValuesInScatterDimsToOperandDims) { auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), {10, 9, 8, 7, 30, 29, 28}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4, 5, 6}, @@ -3740,7 +3744,7 @@ TEST_P(ScatterShapeInferenceTest, InvalidScatterDimNumbers_InsufficientWindowDims) { auto shapes = CreateShapes({50, 49, 48, 47, 46}, scalar(S64), {30, 29, 28, 27}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{0, 1, 2, 3}, @@ -3782,7 +3786,7 @@ TEST_P(UnboundedUnaryOpShapeInferenceTest, UnboundedUnaryOps) { TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedAdd) { TF_ASSERT_OK_AND_ASSIGN(Shape lhs, ParseShape(GetParam().lhs)); TF_ASSERT_OK_AND_ASSIGN(Shape rhs, ParseShape(GetParam().rhs)); - StatusOr inferred_status = + absl::StatusOr inferred_status = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, lhs, rhs, /*broadcast_dimensions=*/{}); if (inferred_status.ok()) { @@ -3792,14 +3796,14 @@ TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedAdd) { << " expected: " << ShapeUtil::HumanString(expected); } else { EXPECT_THAT(inferred_status.status().message(), - HasSubstr("Binary op add with incompatible shapes")); + HasSubstr(kIncompatibleBinaryOpShapeErrorMessage)); } } TEST_P(UnboundedAndOpShapeInferenceTest, UnboundedAnd) { TF_ASSERT_OK_AND_ASSIGN(Shape lhs, ParseShape(GetParam().lhs)); TF_ASSERT_OK_AND_ASSIGN(Shape rhs, ParseShape(GetParam().rhs)); - StatusOr inferred_status = + absl::StatusOr inferred_status = ShapeInference::InferBinaryOpShape(HloOpcode::kAnd, lhs, rhs, /*broadcast_dimensions=*/{}); if (inferred_status.ok()) { @@ -3809,7 +3813,7 @@ TEST_P(UnboundedAndOpShapeInferenceTest, UnboundedAnd) { << " expected: " << ShapeUtil::HumanString(expected); } else { EXPECT_THAT(inferred_status.status().message(), - HasSubstr("Binary op and with incompatible shapes")); + HasSubstr(kIncompatibleBinaryOpShapeErrorMessage)); } } @@ -3867,7 +3871,7 @@ TEST_F(ShapeInferenceTest, UnboundedBatchNormTraining) { TEST_F(ShapeInferenceTest, UnboundedBroadcastUnsupportedOperand) { TF_ASSERT_OK_AND_ASSIGN(Shape operand, ParseShape("f32[<=2, ?]")); TF_ASSERT_OK_AND_ASSIGN(Shape expected, ParseShape("f32[1, <=2, ?]")); - StatusOr inferred_status = + absl::StatusOr inferred_status = ShapeInference::InferBroadcastShape(operand, /*broadcast_sizes=*/{1}); EXPECT_THAT(inferred_status.status().message(), HasSubstr("is_unbounded_dynamic")); @@ -3875,7 +3879,7 @@ TEST_F(ShapeInferenceTest, UnboundedBroadcastUnsupportedOperand) { TEST_F(ShapeInferenceTest, UnboundedBroadcastUnsupportedBroadcastSize) { TF_ASSERT_OK_AND_ASSIGN(Shape operand, ParseShape("f32[<=2, 4]")); - StatusOr inferred_status = ShapeInference::InferBroadcastShape( + absl::StatusOr inferred_status = ShapeInference::InferBroadcastShape( operand, /*broadcast_sizes=*/{Shape::kUnboundedSize}); EXPECT_THAT(inferred_status.status().message(), HasSubstr("Non-broadcast dimensions must not be dynamic.")); @@ -3908,7 +3912,7 @@ TEST_F(ShapeInferenceTest, UnboundedBroadcastInDimToBounded) { TEST_F(ShapeInferenceTest, UnboundedBroadcastInDimUnsupportedOutput) { TF_ASSERT_OK_AND_ASSIGN(Shape operand, ParseShape("f32[<=2, ?]")); TF_ASSERT_OK_AND_ASSIGN(Shape expected, ParseShape("f32[<=2, 3, ?]")); - StatusOr inferred_status = + absl::StatusOr inferred_status = ShapeInference::InferBroadcastShape(operand, expected, /*broadcast_dimensions=*/{0, 2}); EXPECT_THAT(inferred_status.status().message(), @@ -3917,7 +3921,7 @@ TEST_F(ShapeInferenceTest, UnboundedBroadcastInDimUnsupportedOutput) { TEST_F(ShapeInferenceTest, UnboundedBroadcastInDimUnsupported) { TF_ASSERT_OK_AND_ASSIGN(Shape operand, ParseShape("f32[<=2, 4]")); - StatusOr inferred_status = ShapeInference::InferBroadcastShape( + absl::StatusOr inferred_status = ShapeInference::InferBroadcastShape( operand, /*broadcast_sizes=*/{2, Shape::kUnboundedSize, 4}); EXPECT_THAT(inferred_status.status().message(), HasSubstr("Non-broadcast dimensions must not be dynamic.")); @@ -3927,7 +3931,7 @@ TEST_P(UnboundedClampOpShapeInferenceTest, UnboundedClamp) { TF_ASSERT_OK_AND_ASSIGN(Shape lhs, ParseShape(GetParam()[0])); TF_ASSERT_OK_AND_ASSIGN(Shape rhs, ParseShape(GetParam()[1])); TF_ASSERT_OK_AND_ASSIGN(Shape ehs, ParseShape(GetParam()[2])); - StatusOr inferred_status = + absl::StatusOr inferred_status = ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, lhs, rhs, ehs); if (inferred_status.ok()) { TF_ASSERT_OK_AND_ASSIGN(Shape expected, ParseShape(GetParam()[3])); @@ -3944,7 +3948,7 @@ TEST_F(ShapeInferenceTest, UnboundedClampWithTuple) { TF_ASSERT_OK_AND_ASSIGN(Shape rhs, ParseShape("(f32[?], f32[2])")); TF_ASSERT_OK_AND_ASSIGN(Shape ehs, ParseShape("(f32[2], f32[?])")); TF_ASSERT_OK_AND_ASSIGN(Shape expected, ParseShape("(f32[?], f32[2])")); - StatusOr inferred_status = + absl::StatusOr inferred_status = ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, lhs, rhs, ehs); EXPECT_THAT( inferred_status.status().message(), @@ -3955,7 +3959,7 @@ TEST_F(ShapeInferenceTest, UnboundedClampWithTuple) { TEST_P(UnboundedCompareOpShapeInferenceTest, UnboundedCompare) { TF_ASSERT_OK_AND_ASSIGN(Shape lhs, ParseShape(GetParam().lhs)); TF_ASSERT_OK_AND_ASSIGN(Shape rhs, ParseShape(GetParam().rhs)); - StatusOr inferred_status = + absl::StatusOr inferred_status = ShapeInference::InferBinaryOpShape(HloOpcode::kCompare, lhs, rhs, /*broadcast_dimensions=*/{}); if (inferred_status.ok()) { @@ -3965,14 +3969,14 @@ TEST_P(UnboundedCompareOpShapeInferenceTest, UnboundedCompare) { << " expected: " << ShapeUtil::HumanString(expected); } else { EXPECT_THAT(inferred_status.status().message(), - HasSubstr("Binary op compare with incompatible shapes")); + HasSubstr(kIncompatibleBinaryOpShapeErrorMessage)); } } TEST_P(UnboundedConcatenateOpShapeInferenceTest, UnboundedConcatenate) { TF_ASSERT_OK_AND_ASSIGN(Shape operand1, ParseShape(GetParam()[0])); TF_ASSERT_OK_AND_ASSIGN(Shape operand2, ParseShape(GetParam()[1])); - StatusOr inferred_status = + absl::StatusOr inferred_status = ShapeInference::InferConcatOpShape({&operand1, &operand2}, /*dimension=*/0); if (inferred_status.ok()) { @@ -3990,7 +3994,7 @@ TEST_F(UnboundedConcatenateOpShapeInferenceTest, TF_ASSERT_OK_AND_ASSIGN(Shape operand1, ParseShape("f32[2, ?]")); TF_ASSERT_OK_AND_ASSIGN(Shape operand2, ParseShape("f32[2, 3]")); TF_ASSERT_OK_AND_ASSIGN(Shape operand3, ParseShape("f32[2, 4]")); - StatusOr inferred_status = + absl::StatusOr inferred_status = ShapeInference::InferConcatOpShape({&operand1, &operand2, &operand3}, /*dimension=*/0); EXPECT_THAT(inferred_status.status().message(), @@ -4002,7 +4006,7 @@ TEST_F(UnboundedConcatenateOpShapeInferenceTest, TF_ASSERT_OK_AND_ASSIGN(Shape operand1, ParseShape("f32[2, ?]")); TF_ASSERT_OK_AND_ASSIGN(Shape operand2, ParseShape("f32[2, <=3]")); TF_ASSERT_OK_AND_ASSIGN(Shape operand3, ParseShape("f32[2, <=4]")); - StatusOr inferred_status = + absl::StatusOr inferred_status = ShapeInference::InferConcatOpShape({&operand1, &operand2, &operand3}, /*dimension=*/0); EXPECT_THAT(inferred_status.status().message(), @@ -4059,7 +4063,7 @@ TEST_F(ShapeInferenceTest, UnboundedConvolution) { TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedDiv) { TF_ASSERT_OK_AND_ASSIGN(Shape lhs, ParseShape(GetParam().lhs)); TF_ASSERT_OK_AND_ASSIGN(Shape rhs, ParseShape(GetParam().rhs)); - StatusOr inferred_status = + absl::StatusOr inferred_status = ShapeInference::InferBinaryOpShape(HloOpcode::kDivide, lhs, rhs, /*broadcast_dimensions=*/{}); if (inferred_status.ok()) { @@ -4069,7 +4073,7 @@ TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedDiv) { << " expected: " << ShapeUtil::HumanString(expected); } else { EXPECT_THAT(inferred_status.status().message(), - HasSubstr("Binary op divide with incompatible shapes")); + HasSubstr(kIncompatibleBinaryOpShapeErrorMessage)); } } @@ -4136,7 +4140,7 @@ TEST_F(ShapeInferenceTest, UnboundedGather) { TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedMax) { TF_ASSERT_OK_AND_ASSIGN(Shape lhs, ParseShape(GetParam().lhs)); TF_ASSERT_OK_AND_ASSIGN(Shape rhs, ParseShape(GetParam().rhs)); - StatusOr inferred_status = + absl::StatusOr inferred_status = ShapeInference::InferBinaryOpShape(HloOpcode::kMaximum, lhs, rhs, /*broadcast_dimensions=*/{}); if (inferred_status.ok()) { @@ -4146,14 +4150,14 @@ TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedMax) { << " expected: " << ShapeUtil::HumanString(expected); } else { EXPECT_THAT(inferred_status.status().message(), - HasSubstr("Binary op maximum with incompatible shapes")); + HasSubstr(kIncompatibleBinaryOpShapeErrorMessage)); } } TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedMul) { TF_ASSERT_OK_AND_ASSIGN(Shape lhs, ParseShape(GetParam().lhs)); TF_ASSERT_OK_AND_ASSIGN(Shape rhs, ParseShape(GetParam().rhs)); - StatusOr inferred_status = + absl::StatusOr inferred_status = ShapeInference::InferBinaryOpShape(HloOpcode::kMultiply, lhs, rhs, /*broadcast_dimensions=*/{}); if (inferred_status.ok()) { @@ -4163,7 +4167,7 @@ TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedMul) { << " expected: " << ShapeUtil::HumanString(expected); } else { EXPECT_THAT(inferred_status.status().message(), - HasSubstr("Binary op multiply with incompatible shapes")); + HasSubstr(kIncompatibleBinaryOpShapeErrorMessage)); } } @@ -4191,7 +4195,7 @@ TEST_F(ShapeInferenceTest, UnboundedPad) { TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedPow) { TF_ASSERT_OK_AND_ASSIGN(Shape lhs, ParseShape(GetParam().lhs)); TF_ASSERT_OK_AND_ASSIGN(Shape rhs, ParseShape(GetParam().rhs)); - StatusOr inferred_status = + absl::StatusOr inferred_status = ShapeInference::InferBinaryOpShape(HloOpcode::kPower, lhs, rhs, /*broadcast_dimensions=*/{}); if (inferred_status.ok()) { @@ -4201,7 +4205,7 @@ TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedPow) { << " expected: " << ShapeUtil::HumanString(expected); } else { EXPECT_THAT(inferred_status.status().message(), - HasSubstr("Binary op power with incompatible shapes")); + HasSubstr(kIncompatibleBinaryOpShapeErrorMessage)); } } @@ -4232,7 +4236,7 @@ TEST_F(ShapeInferenceTest, UnboundedReduceInvalidReduceDimension) { ProgramShape to_apply = ShapeUtil::MakeProgramShape( {f32_, f32_, f32_, f32_, f32_, f32_}, ShapeUtil::MakeTupleShape({f32_, f32_, f32_})); - StatusOr inferred_status = ShapeInference::InferReduceShape( + absl::StatusOr inferred_status = ShapeInference::InferReduceShape( {&input0, &input1, &input2, &f32_, &f32_, &f32_}, {1}, to_apply); EXPECT_THAT(inferred_status.status().message(), HasSubstr("All reduced tensors must have compatible dimension")); @@ -4279,7 +4283,7 @@ TEST_F(ShapeInferenceTest, UnboundedReshape) { TEST_F(ShapeInferenceTest, UnboundedReshapeUnsupportedOutputShape) { TF_ASSERT_OK_AND_ASSIGN(Shape operand, ParseShape("f32[6]")); - StatusOr inferred_status = ShapeInference::InferReshapeShape( + absl::StatusOr inferred_status = ShapeInference::InferReshapeShape( operand, /*dimensions=*/{0}, /*new_sizes=*/{Shape::kUnboundedSize, Shape::kUnboundedSize}, -1); EXPECT_THAT( @@ -4302,7 +4306,7 @@ TEST_P(UnboundedSelectOpShapeInferenceTest, UnboundedSelect) { TF_ASSERT_OK_AND_ASSIGN(Shape lhs, ParseShape(GetParam()[0])); TF_ASSERT_OK_AND_ASSIGN(Shape rhs, ParseShape(GetParam()[1])); TF_ASSERT_OK_AND_ASSIGN(Shape ehs, ParseShape(GetParam()[2])); - StatusOr inferred_status = + absl::StatusOr inferred_status = ShapeInference::InferTernaryOpShape(HloOpcode::kSelect, lhs, rhs, ehs); if (inferred_status.ok()) { TF_ASSERT_OK_AND_ASSIGN(Shape expected, ParseShape(GetParam()[3])); @@ -4319,7 +4323,7 @@ TEST_F(ShapeInferenceTest, UnboundedSelectWithTupleUnsupported) { TF_ASSERT_OK_AND_ASSIGN(Shape rhs, ParseShape("(f32[?], f32[2])")); TF_ASSERT_OK_AND_ASSIGN(Shape ehs, ParseShape("(f32[2], f32[?])")); TF_ASSERT_OK_AND_ASSIGN(Shape expected, ParseShape("(f32[?], f32[2])")); - StatusOr inferred_status = + absl::StatusOr inferred_status = ShapeInference::InferTernaryOpShape(HloOpcode::kSelect, lhs, rhs, ehs); EXPECT_THAT(inferred_status.status().message(), HasSubstr("Expected array argument for select pred, but got " @@ -4341,7 +4345,7 @@ TEST_F(ShapeInferenceTest, UnboundedSlice) { TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedSub) { TF_ASSERT_OK_AND_ASSIGN(Shape lhs, ParseShape(GetParam().lhs)); TF_ASSERT_OK_AND_ASSIGN(Shape rhs, ParseShape(GetParam().rhs)); - StatusOr inferred_status = + absl::StatusOr inferred_status = ShapeInference::InferBinaryOpShape(HloOpcode::kSubtract, lhs, rhs, /*broadcast_dimensions=*/{}); if (inferred_status.ok()) { @@ -4351,7 +4355,7 @@ TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedSub) { << " expected: " << ShapeUtil::HumanString(expected); } else { EXPECT_THAT(inferred_status.status().message(), - HasSubstr("Binary op subtract with incompatible shapes")); + HasSubstr(kIncompatibleBinaryOpShapeErrorMessage)); } } @@ -4562,6 +4566,7 @@ INSTANTIATE_TEST_SUITE_P(UnboundedDynamism, UnboundedUnaryOpShapeInferenceTest, {"f32[?]", "f32[?]", HloOpcode::kCeil}, {"u32[?]", "u32[?]", HloOpcode::kClz}, {"f32[?]", "f32[?]", HloOpcode::kCos}, + {"f32[?]", "f32[?]", HloOpcode::kErf}, {"f32[?]", "f32[?]", HloOpcode::kExp}, {"f32[?]", "f32[?]", HloOpcode::kExpm1}, {"f32[?]", "f32[?]", HloOpcode::kFloor}, diff --git a/third_party/xla/xla/service/shaped_buffer.cc b/third_party/xla/xla/service/shaped_buffer.cc index 18e6b2defc1fd3..a429ff8b3819af 100644 --- a/third_party/xla/xla/service/shaped_buffer.cc +++ b/third_party/xla/xla/service/shaped_buffer.cc @@ -67,7 +67,7 @@ ShapedBuffer& ShapedBuffer::operator=(ShapedBuffer&& s) { ShapedBuffer::~ShapedBuffer() {} -StatusOr ShapedBuffer::SubShapedBuffer( +absl::StatusOr ShapedBuffer::SubShapedBuffer( const ShapeIndex& index) const { TF_ASSIGN_OR_RETURN(const Shape* device_sub_shape, ShapeUtil::TryGetSubshape(on_device_shape(), index)); diff --git a/third_party/xla/xla/service/shaped_buffer.h b/third_party/xla/xla/service/shaped_buffer.h index d2211306265bba..3882241e746702 100644 --- a/third_party/xla/xla/service/shaped_buffer.h +++ b/third_party/xla/xla/service/shaped_buffer.h @@ -115,7 +115,7 @@ class ShapedBuffer { const ShapeTree& buffers() const { return buffers_; } ShapeTree& buffers() { return buffers_; } - StatusOr SubShapedBuffer(const ShapeIndex& index) const; + absl::StatusOr SubShapedBuffer(const ShapeIndex& index) const; // Set all device memory pointers in the object to null. void clear(); diff --git a/third_party/xla/xla/service/shaped_buffer_test.cc b/third_party/xla/xla/service/shaped_buffer_test.cc index fc6fb8c8c641c4..08ebed27077600 100644 --- a/third_party/xla/xla/service/shaped_buffer_test.cc +++ b/third_party/xla/xla/service/shaped_buffer_test.cc @@ -56,9 +56,9 @@ class TestAllocator : public se::DeviceMemoryAllocator { // Pull in two-arg overload of Allocate. using se::DeviceMemoryAllocator::Allocate; - StatusOr Allocate(int device_ordinal, uint64_t size, - bool /*retry_on_failure*/, - int64_t /*memory_space*/) override { + absl::StatusOr Allocate( + int device_ordinal, uint64_t size, bool /*retry_on_failure*/, + int64_t /*memory_space*/) override { // By contract, we must return null if size == 0. if (size == 0) { return se::OwningDeviceMemory(); @@ -86,7 +86,7 @@ class TestAllocator : public se::DeviceMemoryAllocator { bool AllowsAsynchronousDeallocation() const override { return false; } - StatusOr GetStream(int device_ordinal) override { + absl::StatusOr GetStream(int device_ordinal) override { LOG(FATAL) << "Not implemented"; } diff --git a/third_party/xla/xla/service/sharding_format_picker.cc b/third_party/xla/xla/service/sharding_format_picker.cc index 2262497e744395..13e0a244853de0 100644 --- a/third_party/xla/xla/service/sharding_format_picker.cc +++ b/third_party/xla/xla/service/sharding_format_picker.cc @@ -164,7 +164,7 @@ std::unique_ptr MaybeConvertToV1(const HloSharding& sharding) { } // namespace -StatusOr ShardingFormatPicker::Run( +absl::StatusOr ShardingFormatPicker::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/third_party/xla/xla/service/sharding_format_picker.h b/third_party/xla/xla/service/sharding_format_picker.h index a67c141dc46db8..1ebd96520c97bd 100644 --- a/third_party/xla/xla/service/sharding_format_picker.h +++ b/third_party/xla/xla/service/sharding_format_picker.h @@ -32,7 +32,7 @@ class ShardingFormatPicker : public HloModulePass { : sharding_type_(sharding_type) {} absl::string_view name() const override { return "sharding-format-picker"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/third_party/xla/xla/service/sharding_propagation.cc b/third_party/xla/xla/service/sharding_propagation.cc index 24750840b794cb..2d9230b9ff2f4f 100644 --- a/third_party/xla/xla/service/sharding_propagation.cc +++ b/third_party/xla/xla/service/sharding_propagation.cc @@ -303,8 +303,10 @@ const HloInstruction* PickRepresentativeOperand( case HloOpcode::kAllReduce: case HloOpcode::kReduceScatter: case HloOpcode::kAllToAll: + case HloOpcode::kCollectiveBroadcast: case HloOpcode::kCollectivePermute: case HloOpcode::kDivide: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: @@ -1513,7 +1515,7 @@ bool InferReduceShardingFromOperand(HloInstruction* instruction, // copy node for reshard. // `unspecified_dims` will be populated with the converted copies if the custom // call is partially specified. -StatusOr ProcessShardingInstruction( +absl::StatusOr ProcessShardingInstruction( HloModule* module, const absl::flat_hash_set& execution_threads, bool replace_sharding_with_copy, @@ -2845,7 +2847,7 @@ Status ShardingPropagation::CanonicalizeLayouts(HloModule* module) { return OkStatus(); } -StatusOr ShardingPropagation::Run( +absl::StatusOr ShardingPropagation::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { std::optional> diff --git a/third_party/xla/xla/service/sharding_propagation.h b/third_party/xla/xla/service/sharding_propagation.h index 6b90c7443d8a73..390be726e24b7c 100644 --- a/third_party/xla/xla/service/sharding_propagation.h +++ b/third_party/xla/xla/service/sharding_propagation.h @@ -57,7 +57,7 @@ bool InferConvolutionShardingFromOperands(HloInstruction* instruction, // operand's existing sharding. // unspecified_dims will be populated with the converted copies if the custom // call is partially specified. -StatusOr ProcessShardingInstruction( +absl::StatusOr ProcessShardingInstruction( HloModule* module, const absl::flat_hash_set& execution_threads, bool replace_sharding_with_copy, @@ -122,7 +122,7 @@ class ShardingPropagation : public HloModulePass { } absl::string_view name() const override { return "sharding-propagation"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/third_party/xla/xla/service/sharding_remover.cc b/third_party/xla/xla/service/sharding_remover.cc index ef5f066abfb67e..2dc0b6e6b1d409 100644 --- a/third_party/xla/xla/service/sharding_remover.cc +++ b/third_party/xla/xla/service/sharding_remover.cc @@ -31,7 +31,7 @@ namespace xla { // Remove Sharding custom-call instruction by assigning its users to // to its operand. -StatusOr ShardingRemover::Run( +absl::StatusOr ShardingRemover::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/third_party/xla/xla/service/sharding_remover.h b/third_party/xla/xla/service/sharding_remover.h index 173f88837d4774..39acf378cad655 100644 --- a/third_party/xla/xla/service/sharding_remover.h +++ b/third_party/xla/xla/service/sharding_remover.h @@ -31,7 +31,7 @@ class ShardingRemover : public HloModulePass { public: absl::string_view name() const override { return "sharding-remover"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/third_party/xla/xla/service/simplify_fp_conversions.cc b/third_party/xla/xla/service/simplify_fp_conversions.cc index c0a0fca5d66ee2..8cabbb17a4da19 100644 --- a/third_party/xla/xla/service/simplify_fp_conversions.cc +++ b/third_party/xla/xla/service/simplify_fp_conversions.cc @@ -16,11 +16,8 @@ limitations under the License. #include "xla/service/simplify_fp_conversions.h" #include -#include -#include #include "absl/container/flat_hash_set.h" -#include "absl/strings/match.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" @@ -36,37 +33,13 @@ namespace xla { namespace { // Simplifies floating-point conversions `A -> B -> C -> D` as `A -> D`. -StatusOr RunOnComputation(HloComputation& computation, - SimplifyFPConversions::Scope scope) { - // Since the goal of this pass is to simplify type conversions by removing - // some Convert ops, we don't want to run this pass for tests that are meant - // to test for functionality of the Convert op itself. - const absl::string_view comp_name = computation.name(); - const std::vector test_names{ - "ConvertF16F8e5m2Roundtrip", - "ConvertF16F8e4m3fnRoundtrip", - "ConvertF16F8e4m3b11fnuzRoundtrip", - "ConvertF16F8e5m2fnuzRoundtrip", - "ConvertF32F8e5m2fnuzRoundtrip", - "ConvertF8e5m2fnuzRoundtripExhaustive", - "ConvertF16F8e4m3fnuzRoundtrip", - "ConvertF32F8e4m3fnuzRoundtrip", - "ConvertF8e4m3fnuzRoundtripExhaustive"}; - for (const auto& test_name : test_names) { - if (absl::StrContains(comp_name, test_name)) { - return false; - } - } - const int minimum_logical_creation_pass_id = - (scope == SimplifyFPConversions::Scope::kSimplifyAllConversions) ? -1 : 0; +absl::StatusOr RunOnComputation(HloComputation& computation) { bool changed = false; for (HloInstruction* instruction : computation.MakeInstructionPostOrder()) { HloInstruction* input = instruction; size_t convert_chain_length = 0; - while ((input->opcode() == HloOpcode::kConvert) && - (input->metadata().logical_creation_pass_id() >= - minimum_logical_creation_pass_id) && + while (input->opcode() == HloOpcode::kConvert && primitive_util::IsFloatingPointType(input->shape().element_type())) { input = input->mutable_operand(0); ++convert_chain_length; @@ -89,36 +62,23 @@ StatusOr RunOnComputation(HloComputation& computation, return changed; } -std::string ToString(SimplifyFPConversions::Scope scope) { - using Scope = SimplifyFPConversions::Scope; - switch (scope) { - case Scope::kSimplifyAllConversions: - return "SimplifyAllConversions"; - case Scope::kOnlySimplifyCompilerGeneratedConversions: - return "OnlySimplifyCompilerGeneratedConversions"; - } -} - } // namespace -StatusOr SimplifyFPConversions::Run( +absl::StatusOr SimplifyFPConversions::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { XLA_VLOG_LINES( - 2, - absl::StrFormat("SimplifyFPConversions::Run() with scope=%s, before:\n%s", - ToString(scope_), module->ToString())); + 2, absl::StrFormat("SimplifyFPConversions::Run() with before:\n%s", + module->ToString())); bool changed = false; for (HloComputation* computation : module->MakeComputationPostOrder(execution_threads)) { - TF_ASSIGN_OR_RETURN(bool comp_changed, - RunOnComputation(*computation, scope_)); + TF_ASSIGN_OR_RETURN(bool comp_changed, RunOnComputation(*computation)); changed |= comp_changed; } - XLA_VLOG_LINES( - 2, - absl::StrFormat("SimplifyFPConversions::Run() with scope=%s, after:\n%s", - ToString(scope_), module->ToString())); + XLA_VLOG_LINES(2, + absl::StrFormat("SimplifyFPConversions::Run() with after:\n%s", + module->ToString())); return changed; } diff --git a/third_party/xla/xla/service/simplify_fp_conversions.h b/third_party/xla/xla/service/simplify_fp_conversions.h index b7d0ac4463f655..099b06a283cefc 100644 --- a/third_party/xla/xla/service/simplify_fp_conversions.h +++ b/third_party/xla/xla/service/simplify_fp_conversions.h @@ -26,31 +26,19 @@ namespace xla { // Simplifies chains of floating-point conversions. // // The algebraic simplifier will remove convert pairs of the form `X -> Y -> X`, -// only when they are a no-op (e.g. `bf16 -> f32 -> bf16`). This passes does -// similar, but has two scopes: -// - kSimplifyAllConversions: Simplify any chain of float conversions, possibly -// improving accuracy (e.g. `f32 -> bf16 -> f32` is removed). -// - kOnlySimplifyCompilerGeneratedConversions: Only simplify chains of float -// conversions generated by the compiler in one of the previous optimization -// passes. +// only when they are a no-op, e.g. `bf16 -> f32 -> bf16` or +// `f32 -> bf16 -> f32`. Note that the latter optimization might lead to +// increased precision. class SimplifyFPConversions : public HloModulePass { public: - enum class Scope { - kOnlySimplifyCompilerGeneratedConversions, - kSimplifyAllConversions - }; - - explicit SimplifyFPConversions(Scope scope) : scope_(scope) {} + explicit SimplifyFPConversions() = default; absl::string_view name() const override { return "simplify-fp-conversions"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; - - private: - Scope scope_; }; } // namespace xla diff --git a/third_party/xla/xla/service/simplify_fp_conversions_test.cc b/third_party/xla/xla/service/simplify_fp_conversions_test.cc index 8249f1969c47d9..ad85bb873eb654 100644 --- a/third_party/xla/xla/service/simplify_fp_conversions_test.cc +++ b/third_party/xla/xla/service/simplify_fp_conversions_test.cc @@ -36,32 +36,6 @@ using ::tsl::testing::IsOkAndHolds; using SimplifyFPConversionsTest = HloTestBase; -// This marks all ops in `module` as user-provided, meaning the -// simplifier won't remove any of the converts -static void InitializeCreationPassIds(HloModule* module) { - constexpr int kUserSuppliedOpCreationPassId = -1; - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - instruction->set_creation_pass_id(kUserSuppliedOpCreationPassId); - instruction->set_logical_creation_pass_id(kUserSuppliedOpCreationPassId); - } - } -} - -// This marks all converts ops in `module` as being created by the -// optimization pass `creation_pass_id`. -static void SetCreationPassIdInAllConvertOps(HloModule* module, - int creation_pass_id) { - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kConvert) { - instruction->set_creation_pass_id(creation_pass_id); - instruction->set_logical_creation_pass_id(creation_pass_id); - } - } - } -} - TEST_F(SimplifyFPConversionsTest, DoesNotChangeSingleConvert) { const absl::string_view kModuleStr = R"( HloModule test @@ -74,10 +48,8 @@ TEST_F(SimplifyFPConversionsTest, DoesNotChangeSingleConvert) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleStr)); - InitializeCreationPassIds(module.get()); - SimplifyFPConversions simplifier{ - SimplifyFPConversions::Scope::kSimplifyAllConversions}; + SimplifyFPConversions simplifier; EXPECT_THAT(simplifier.Run(module.get()), IsOkAndHolds(false)); } @@ -94,60 +66,13 @@ TEST_F(SimplifyFPConversionsTest, SimplifiesF32ToBF16ToF32) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleStr)); - InitializeCreationPassIds(module.get()); - SimplifyFPConversions simplifier{ - SimplifyFPConversions::Scope::kSimplifyAllConversions}; + SimplifyFPConversions simplifier; EXPECT_THAT(simplifier.Run(module.get()), IsOkAndHolds(true)); EXPECT_THAT(module->entry_computation()->root_instruction(), op::Tuple(op::Parameter(0))); } -TEST_F(SimplifyFPConversionsTest, SimplifiesCompilerGeneratedF32ToBF16ToF32) { - const absl::string_view kModuleStr = R"( - HloModule test - - ENTRY entry { - p0 = f32[2,3] parameter(0) - c0 = bf16[2,3] convert(p0) - c1 = f32[2,3] convert(c0) - ROOT ret = (f32[2,3]) tuple(c1) - } - )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(kModuleStr)); - InitializeCreationPassIds(module.get()); - - constexpr int kRandomCreationPassId = 42; - SetCreationPassIdInAllConvertOps(module.get(), kRandomCreationPassId); - - SimplifyFPConversions simplifier{ - SimplifyFPConversions::Scope::kOnlySimplifyCompilerGeneratedConversions}; - EXPECT_THAT(simplifier.Run(module.get()), IsOkAndHolds(true)); - EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Tuple(op::Parameter(0))); -} - -TEST_F(SimplifyFPConversionsTest, DoesNotChangeUserInsertedConverts) { - const absl::string_view kModuleStr = R"( - HloModule test - - ENTRY entry { - p0 = f32[2,3] parameter(0) - c0 = bf16[2,3] convert(p0) - c1 = f32[2,3] convert(c0) - ROOT ret = (f32[2,3]) tuple(c1) - } - )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(kModuleStr)); - InitializeCreationPassIds(module.get()); - - SimplifyFPConversions simplifier{ - SimplifyFPConversions::Scope::kOnlySimplifyCompilerGeneratedConversions}; - EXPECT_THAT(simplifier.Run(module.get()), IsOkAndHolds(false)); -} - TEST_F(SimplifyFPConversionsTest, SimplifiesF64ToF16ToF32ToBF16) { const absl::string_view kModuleStr = R"( HloModule test @@ -162,10 +87,8 @@ TEST_F(SimplifyFPConversionsTest, SimplifiesF64ToF16ToF32ToBF16) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleStr)); - InitializeCreationPassIds(module.get()); - SimplifyFPConversions simplifier{ - SimplifyFPConversions::Scope::kSimplifyAllConversions}; + SimplifyFPConversions simplifier; EXPECT_THAT(simplifier.Run(module.get()), IsOkAndHolds(true)); EXPECT_THAT( module->entry_computation()->root_instruction(), diff --git a/third_party/xla/xla/service/slice_sinker.cc b/third_party/xla/xla/service/slice_sinker.cc index f2dc260f22d2d8..dc7559444436c6 100644 --- a/third_party/xla/xla/service/slice_sinker.cc +++ b/third_party/xla/xla/service/slice_sinker.cc @@ -238,7 +238,7 @@ Status SinkSlices(const std::vector& slice_sources, // This pass currently doesn't transform non-elementwise instructions. We may // extend this pass to transform non-elementwise instructions, such as dot, // broadcast and reduce in the future. -StatusOr SliceSinker::Run( +absl::StatusOr SliceSinker::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/third_party/xla/xla/service/slice_sinker.h b/third_party/xla/xla/service/slice_sinker.h index 04e63e84e2b3ae..61805ca874211d 100644 --- a/third_party/xla/xla/service/slice_sinker.h +++ b/third_party/xla/xla/service/slice_sinker.h @@ -27,7 +27,7 @@ class SliceSinker : public HloModulePass { absl::string_view name() const override { return "slice-sinker"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/third_party/xla/xla/service/sort_simplifier.cc b/third_party/xla/xla/service/sort_simplifier.cc index d6bef49876e2fc..99df7c6035beba 100644 --- a/third_party/xla/xla/service/sort_simplifier.cc +++ b/third_party/xla/xla/service/sort_simplifier.cc @@ -30,7 +30,7 @@ namespace { // If the sort instruction has a tuple shape then looks for unused output // values and removes them from the sort instruction. Returns true if the // graph has been modified. -StatusOr RemoveUnusedOperandFromSort(HloInstruction* sort) { +absl::StatusOr RemoveUnusedOperandFromSort(HloInstruction* sort) { if (!sort->shape().IsTuple()) { return false; } @@ -135,7 +135,7 @@ StatusOr RemoveUnusedOperandFromSort(HloInstruction* sort) { } } // namespace -StatusOr SortSimplifier::Run( +absl::StatusOr SortSimplifier::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { VLOG(2) << "HLO module before SortSimplifier:"; diff --git a/third_party/xla/xla/service/sort_simplifier.h b/third_party/xla/xla/service/sort_simplifier.h index 9d416ffd7ce73d..2f02216168b930 100644 --- a/third_party/xla/xla/service/sort_simplifier.h +++ b/third_party/xla/xla/service/sort_simplifier.h @@ -28,7 +28,7 @@ class SortSimplifier : public HloModulePass { public: absl::string_view name() const override { return "simplify-sorts"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/third_party/xla/xla/service/space_to_batch_converter.cc b/third_party/xla/xla/service/space_to_batch_converter.cc index 472dbe92dbe9f0..890d734f25e260 100644 --- a/third_party/xla/xla/service/space_to_batch_converter.cc +++ b/third_party/xla/xla/service/space_to_batch_converter.cc @@ -118,11 +118,12 @@ class ConvolutionVisitor { // Propagates space-to-batch on the op, and returns a bool that indicates if // the users of the op need to be propagated through. - StatusOr Propagate(HloInstruction* consumer, HloInstruction* producer); + absl::StatusOr Propagate(HloInstruction* consumer, + HloInstruction* producer); // Splits the given spatial dimension on the activations and returns the // new instructions, and the dimension permutation of the new shape. - StatusOr>> SplitSpace( + absl::StatusOr>> SplitSpace( HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers, int64_t& activations_batch_dim, int64_t high_padding, int64_t low_padding, int64_t spatial_split_size, int64_t num_splits, @@ -130,7 +131,7 @@ class ConvolutionVisitor { bool is_backprop = false, bool is_rhs = false); // Performs the actual dimension splitting. - StatusOr PerformSplitSpace( + absl::StatusOr PerformSplitSpace( HloInstruction* activations, absl::Span spatial_dimensions_to_split, int64_t activations_batch_dim, int64_t spatial_split_size, @@ -140,22 +141,22 @@ class ConvolutionVisitor { // merges the batch(es). // The input activations dimensions are ... B, B0, S0, B1, S1, ... Bn, Sn, ... // The output dimensions will be ..., B, S0, S1,.. Sn, ... - StatusOr TransposeAndMergeBatch( + absl::StatusOr TransposeAndMergeBatch( HloInstruction* activations, absl::Span final_split_spatial_dim_positioning, int64_t activations_batch_dim, int64_t old_batch_size); // Helper function for the SplitSpace function above. Handles padding and // reshaping to generate space-to-batched shape. - StatusOr PadAndSplitSpace( + absl::StatusOr PadAndSplitSpace( HloInstruction* activations, absl::Span spatial_dimensions_to_split, int64_t activations_batch_dim, int64_t high_padding, int64_t low_padding, int64_t spatial_split_size, int64_t num_splits); // Perform space-to-batch propagation on constants. - StatusOr PropagateOnConstant(HloInstruction* consumer, - HloInstruction* producer); + absl::StatusOr PropagateOnConstant(HloInstruction* consumer, + HloInstruction* producer); // Perform space-to-batch propagation on the convolution. Assumes the // activations were already space-to-batched. @@ -189,7 +190,7 @@ class ConvolutionVisitor { // Generates masked output with valid data. This is useful when larger shapes // are generated due to space-to-batch. - StatusOr SelectValidPortion( + absl::StatusOr SelectValidPortion( HloInstruction* new_instr, HloInstruction* old_instr, HloInstruction* select_val, int64_t new_batch_dim, absl::Span new_space_dims, int64_t old_batch_dim, @@ -201,7 +202,7 @@ class ConvolutionVisitor { }; // Performs tranposition so that space dimension follows the batch dimension. - StatusOr BringSpaceNextToBatch( + absl::StatusOr BringSpaceNextToBatch( HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers, int64_t& activations_batch_dim, std::vector* spatial_dimensions_to_split, @@ -209,29 +210,29 @@ class ConvolutionVisitor { // Decreases the spatial dimension size in an already space-to-batched shape // so that the new size is new_spatial_dim_size. - StatusOr ChangeSpatialSizeOnSpaceToBatchedShape( + absl::StatusOr ChangeSpatialSizeOnSpaceToBatchedShape( HloInstruction* activations, int64_t batch_dimension, int64_t old_batch_size, absl::Span spatial_dimensions_to_split, int64_t new_spatial_dim_size, bool increase_spatial_size = false); // Turns B, S0, S1, ..., Sn into B, B0, S0, B1, S1,... Bn, Sn. - StatusOr SplitAndTransposeMergedBatch( + absl::StatusOr SplitAndTransposeMergedBatch( HloInstruction* activations, int64_t batch_dimension, int64_t old_batch_size, absl::Span spatial_dimensions); // Function that converts spaced-to-batch shape back to the original. - StatusOr BatchToSpace(HloInstruction* old_instr); + absl::StatusOr BatchToSpace(HloInstruction* old_instr); // Duplicates elements at boundaries. - StatusOr HaloDuplicateWithSlice( + absl::StatusOr HaloDuplicateWithSlice( HloInstruction* activations, absl::Span spatial_dimensions_to_split, int64_t activations_batch_dim, int64_t low_padding, int64_t halo_size, HloInstruction* pad_val = nullptr); // Runs the visitor on a computation. - StatusOr Run(); + absl::StatusOr Run(); // Returns whether any convolution ops were rewritten. const bool changed() const { return changed_; } @@ -507,7 +508,7 @@ bool ConvolutionVisitor::IsThisBackPropFilterConv(HloInstruction* convolution) { return true; } -StatusOr ConvolutionVisitor::HaloDuplicateWithSlice( +absl::StatusOr ConvolutionVisitor::HaloDuplicateWithSlice( HloInstruction* activations, absl::Span spatial_dimensions_to_split, int64_t activations_batch_dim, int64_t low_padding, int64_t halo_size, @@ -636,7 +637,7 @@ StatusOr ConvolutionVisitor::HaloDuplicateWithSlice( return activations; } -StatusOr +absl::StatusOr ConvolutionVisitor::BringSpaceNextToBatch( HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers, int64_t& activations_batch_dim, @@ -741,7 +742,8 @@ ConvolutionVisitor::BringSpaceNextToBatch( return SpaceNextToBatchDetails{activations, transpose_dims}; } -StatusOr ConvolutionVisitor::SplitAndTransposeMergedBatch( +absl::StatusOr +ConvolutionVisitor::SplitAndTransposeMergedBatch( HloInstruction* activations, int64_t batch_dimension, int64_t old_batch_size, absl::Span spatial_dimensions) { CHECK_EQ(batch_dimension + 1, spatial_dimensions[0]); @@ -792,7 +794,7 @@ StatusOr ConvolutionVisitor::SplitAndTransposeMergedBatch( return batch_split_activations; } -StatusOr +absl::StatusOr ConvolutionVisitor::ChangeSpatialSizeOnSpaceToBatchedShape( HloInstruction* activations, int64_t batch_dimension, int64_t old_batch_size, absl::Span spatial_dimensions, @@ -881,7 +883,7 @@ ConvolutionVisitor::ChangeSpatialSizeOnSpaceToBatchedShape( return activations_new; } -StatusOr ConvolutionVisitor::Run() { +absl::StatusOr ConvolutionVisitor::Run() { for (auto conv : conv_visitor_list_) { // If we expect to see an unpropagatable op, space-to-batch may not be // beneficial. @@ -1770,8 +1772,8 @@ bool ConvolutionVisitor::SupportedOpForPropagation(HloInstruction* consumer, return false; } -StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, - HloInstruction* producer) { +absl::StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, + HloInstruction* producer) { auto computation = consumer->parent(); if (IsTrivialElementwise(consumer)) { auto dim_map_val = instr_to_dim_map_[producer]; @@ -2325,7 +2327,7 @@ StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, return true; } -StatusOr ConvolutionVisitor::SelectValidPortion( +absl::StatusOr ConvolutionVisitor::SelectValidPortion( HloInstruction* new_instr, HloInstruction* old_instr, HloInstruction* select_val, int64_t new_batch_dim, absl::Span new_space_dims, int64_t old_batch_dim, @@ -2407,7 +2409,7 @@ StatusOr ConvolutionVisitor::SelectValidPortion( return new_instr; } -StatusOr ConvolutionVisitor::BatchToSpace( +absl::StatusOr ConvolutionVisitor::BatchToSpace( HloInstruction* old_instr) { if (batch_to_space_map_.count(old_instr)) { CHECK_NE(batch_to_space_map_[old_instr], nullptr); @@ -2885,7 +2887,7 @@ Status ConvolutionVisitor::PropagateOnSlice(HloInstruction* slice) { return OkStatus(); } -StatusOr ConvolutionVisitor::TransposeAndMergeBatch( +absl::StatusOr ConvolutionVisitor::TransposeAndMergeBatch( HloInstruction* activations, absl::Span final_split_spatial_dim_positioning, int64_t activations_batch_dim, int64_t old_batch_size) { @@ -2927,7 +2929,7 @@ StatusOr ConvolutionVisitor::TransposeAndMergeBatch( return batch_collapsed_reshape; } -StatusOr ConvolutionVisitor::PerformSplitSpace( +absl::StatusOr ConvolutionVisitor::PerformSplitSpace( HloInstruction* activations, absl::Span spatial_dimensions_to_split, int64_t activations_batch_dim, int64_t spatial_split_size, @@ -2973,7 +2975,7 @@ StatusOr ConvolutionVisitor::PerformSplitSpace( activations_batch_dim, old_batch_size); } -StatusOr ConvolutionVisitor::PadAndSplitSpace( +absl::StatusOr ConvolutionVisitor::PadAndSplitSpace( HloInstruction* activations, absl::Span spatial_dimensions_to_split, int64_t activations_batch_dim, int64_t high_padding, int64_t low_padding, @@ -3007,7 +3009,7 @@ StatusOr ConvolutionVisitor::PadAndSplitSpace( num_splits); } -StatusOr>> +absl::StatusOr>> ConvolutionVisitor::SplitSpace( HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers, int64_t& activations_batch_dim, int64_t high_padding, int64_t low_padding, @@ -3029,7 +3031,7 @@ ConvolutionVisitor::SplitSpace( return std::make_pair(new_activations, transpose_dims); } -StatusOr ConvolutionVisitor::PropagateOnConstant( +absl::StatusOr ConvolutionVisitor::PropagateOnConstant( HloInstruction* consumer, HloInstruction* producer) { CHECK(old_to_new_instrs_.contains(producer)); HloInstruction* new_producer = old_to_new_instrs_[producer]; @@ -3920,7 +3922,7 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution( } // namespace -StatusOr SpaceToBatchConverter::Run( +absl::StatusOr SpaceToBatchConverter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { XLA_VLOG_LINES( diff --git a/third_party/xla/xla/service/space_to_batch_converter.h b/third_party/xla/xla/service/space_to_batch_converter.h index 29a08231c29e9d..2d9dba06a2b581 100644 --- a/third_party/xla/xla/service/space_to_batch_converter.h +++ b/third_party/xla/xla/service/space_to_batch_converter.h @@ -57,7 +57,7 @@ class SpaceToBatchConverter : public HloModulePass { // Run convolution rewriting on the given computation. Returns whether the // computation was changed. using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/third_party/xla/xla/service/spmd/BUILD b/third_party/xla/xla/service/spmd/BUILD index 2cd13d18fbb2d6..5f0e4af6a67712 100644 --- a/third_party/xla/xla/service/spmd/BUILD +++ b/third_party/xla/xla/service/spmd/BUILD @@ -4,7 +4,8 @@ load("//xla:xla.bzl", "xla_cc_test") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], licenses = ["notice"], ) @@ -32,7 +33,6 @@ cc_library( "spmd_partitioner.h", "spmd_partitioner_util.h", ], - visibility = ["//visibility:public"], deps = [ "//xla:array", "//xla:comparison_util", @@ -129,7 +129,6 @@ cc_library( name = "canonicalize_all_gather_for_cse", srcs = ["canonicalize_all_gather_for_cse.cc"], hdrs = ["canonicalize_all_gather_for_cse.h"], - visibility = ["//visibility:public"], deps = [ "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", @@ -159,7 +158,6 @@ cc_library( name = "schedule_aware_collective_ops_cse", srcs = ["schedule_aware_collective_ops_cse.cc"], hdrs = ["schedule_aware_collective_ops_cse.h"], - visibility = ["//visibility:public"], deps = [ "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", @@ -189,7 +187,6 @@ cc_library( name = "spmd_prepare", srcs = ["spmd_prepare.cc"], hdrs = ["spmd_prepare.h"], - visibility = ["//visibility:public"], deps = [ "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_sharding_util", @@ -203,7 +200,6 @@ cc_library( name = "stateful_rng_spmd_partitioner", srcs = ["stateful_rng_spmd_partitioner.cc"], hdrs = ["stateful_rng_spmd_partitioner.h"], - visibility = ["//visibility:public"], deps = [ ":spmd_partitioner", "//xla/hlo/ir:hlo", @@ -235,7 +231,6 @@ cc_library( name = "collective_permute_motion", srcs = ["collective_permute_motion.cc"], hdrs = ["collective_permute_motion.h"], - visibility = ["//visibility:public"], deps = [ "//xla:comparison_util", "//xla:shape_util", @@ -268,7 +263,6 @@ cc_library( hdrs = [ "partition_assignment.h", ], - visibility = ["//visibility:public"], deps = [ "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", @@ -294,7 +288,6 @@ cc_library( hdrs = [ "whole_graph_manual_pass.h", ], - visibility = ["//visibility:public"], deps = [ "//xla:statusor", "//xla/hlo/ir:hlo", diff --git a/third_party/xla/xla/service/stream_pool_test.cc b/third_party/xla/xla/service/stream_pool_test.cc index e227df525fe145..551a35cef1843c 100644 --- a/third_party/xla/xla/service/stream_pool_test.cc +++ b/third_party/xla/xla/service/stream_pool_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/test_helpers.h" @@ -27,7 +28,7 @@ class StreamPoolTest : public ::testing::Test { protected: std::unique_ptr NewStreamExecutor() { se::Platform* platform = - se::MultiPlatformManager::PlatformWithName("Host").value(); + se::PlatformManager::PlatformWithName("Host").value(); se::StreamExecutorConfig config(/*ordinal=*/0); return platform->GetUncachedExecutor(config).value(); } @@ -93,78 +94,5 @@ TEST_F(StreamPoolTest, TwoStreamPool) { EXPECT_NE(stream3_ptr, stream4_ptr); } -TEST_F(StreamPoolTest, BadStreamDiscarded) { - std::unique_ptr executor = NewStreamExecutor(); - StreamPool pool; - - // Borrow a stream. - StreamPool::Ptr stream1 = pool.BorrowStream(executor.get()); - EXPECT_TRUE(stream1->ok()); - - // Force an error on the stream; here we call a method that requires - // DNN support, which we know the Host platform doesn't support. - stream1->ThenDepthConcatenate({}, {}, nullptr); - EXPECT_FALSE(stream1->ok()); - - // Return stream1 and borrow stream2. - stream1 = nullptr; - StreamPool::Ptr stream2 = pool.BorrowStream(executor.get()); - se::Stream* stream2_ptr = stream2.get(); - EXPECT_TRUE(stream2->ok()); - - // The underlying streams should be different. They would have been - // the same, but since we forced an error on stream1, it cannot be - // put back into the pool. Sadly we can't just check: - // EXPECT_NE(stream1_ptr, stream2_ptr); - // - // The above should hold logically, but it may fail if the new - // stream instance allocated for stream2 happens to reside in the - // same memory address as stream1, which has been deleted. - // - // The check that stream2->ok() serves as a good-enough check. - - // Return stream2 and borrow stream3. The previous error on stream1 - // has no effect on these streams, and they are the same. - stream2 = nullptr; - StreamPool::Ptr stream3 = pool.BorrowStream(executor.get()); - se::Stream* stream3_ptr = stream3.get(); - EXPECT_TRUE(stream3->ok()); - EXPECT_EQ(stream2_ptr, stream3_ptr); -} - -TEST_F(StreamPoolTest, BadStreamAfterReturnDiscarded) { - std::unique_ptr executor = NewStreamExecutor(); - StreamPool pool; - - // Borrow a stream. - StreamPool::Ptr stream1 = pool.BorrowStream(executor.get()); - EXPECT_TRUE(stream1->ok()); - - // Return the stream, but hold a handle to it. - se::Stream* stream1_ptr = stream1.get(); - stream1 = nullptr; - - // Now stream1 is back in the pool, force an error on the stream. Here we call - // a method that requires DNN support, which we know the Host platform doesn't - // support. - stream1_ptr->ThenDepthConcatenate({}, {}, nullptr); - EXPECT_FALSE(stream1_ptr->ok()); - - // Borrow stream2. - StreamPool::Ptr stream2 = pool.BorrowStream(executor.get()); - EXPECT_TRUE(stream2->ok()); - - // The underlying streams should be different. They would have been - // the same, but since we forced an error on stream1, it cannot be - // put back into the pool. Sadly we can't just check: - // EXPECT_NE(stream1_ptr, stream2_ptr); - // - // The above should hold logically, but it may fail if the new - // stream instance allocated for stream2 happens to reside in the - // same memory address as stream1, which has been deleted. - // - // The check that stream2->ok() serves as a good-enough check. -} - } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/topk_rewriter.cc b/third_party/xla/xla/service/topk_rewriter.cc index a138ea2e7907b4..0258d607b42848 100644 --- a/third_party/xla/xla/service/topk_rewriter.cc +++ b/third_party/xla/xla/service/topk_rewriter.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/topk_rewriter.h" #include +#include #include #include #include @@ -29,8 +30,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/primitive_util.h" #include "xla/service/pattern_matcher.h" #include "xla/shape_util.h" +#include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" @@ -120,6 +123,36 @@ static bool IsNanSafeGt(HloComputation* comp) { param_s32); }; + auto match_generic_iec559 = [](int64_t parameter_number, + PrimitiveType fp_type, + PrimitiveType int_type) { + auto param = m::Parameter(parameter_number) + .WithShape(m::Shape().WithElementType(fp_type)); + auto signed_value = m::BitcastConvert(param).WithShape( + m::Shape().WithElementType(int_type)); + int64_t bit_width = primitive_util::BitWidth(fp_type); + auto max_value = m::ConstantScalar(LsbMask(bit_width - 1)); + auto flipped_value = m::XorAnyOrder(max_value, signed_value); + auto is_negative = m::Lt(signed_value, m::ConstantScalar(0)); + return m::Select(is_negative, flipped_value, signed_value); + }; + + auto match_generic_iec559_with_convert = + [](int64_t parameter_number, PrimitiveType param_type, + PrimitiveType fp_type, PrimitiveType int_type) { + auto param = m::Parameter(parameter_number) + .WithShape(m::Shape().WithElementType(param_type)); + auto convert = + m::Convert(param).WithShape(m::Shape().WithElementType(fp_type)); + auto signed_value = m::BitcastConvert(convert).WithShape( + m::Shape().WithElementType(int_type)); + int64_t bit_width = primitive_util::BitWidth(fp_type); + auto max_value = m::ConstantScalar(LsbMask(bit_width - 1)); + auto flipped_value = m::XorAnyOrder(max_value, signed_value); + auto is_negative = m::Lt(signed_value, m::ConstantScalar(0)); + return m::Select(is_negative, flipped_value, signed_value); + }; + auto match_s32 = [](int64_t parameter_number) { auto param = m::Parameter(parameter_number) .WithShape(m::Shape().WithElementType(S32)); @@ -155,6 +188,15 @@ static bool IsNanSafeGt(HloComputation* comp) { }; return Match(comp->root_instruction(), + m::Gt(match_generic_iec559(0, F32, S32), + match_generic_iec559(1, F32, S32))) || + Match(comp->root_instruction(), + m::Gt(match_generic_iec559(0, BF16, S16), + match_generic_iec559(1, BF16, S16))) || + Match(comp->root_instruction(), + m::Gt(match_generic_iec559_with_convert(0, BF16, F32, S32), + match_generic_iec559_with_convert(1, BF16, F32, S32))) || + Match(comp->root_instruction(), m::Gt(match_bitcast_f32(0), match_bitcast_f32(1))) || Match(comp->root_instruction(), m::Gt(match_bitcast_bf16(0), match_bitcast_bf16(1))) || diff --git a/third_party/xla/xla/service/transfer_manager.cc b/third_party/xla/xla/service/transfer_manager.cc index 5307001e111e35..11158d947a813f 100644 --- a/third_party/xla/xla/service/transfer_manager.cc +++ b/third_party/xla/xla/service/transfer_manager.cc @@ -15,23 +15,31 @@ limitations under the License. #include "xla/service/transfer_manager.h" +#include #include #include -#include #include +#include +#include "absl/base/const_init.h" #include "absl/cleanup/cleanup.h" -#include "absl/strings/str_cat.h" +#include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" +#include "xla/literal.h" #include "xla/service/compiler.h" #include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/shaped_buffer.h" +#include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/status.h" #include "xla/status_macros.h" -#include "xla/types.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream.h" #include "xla/util.h" +#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/notification.h" - -using absl::StrCat; +#include "tsl/platform/statusor.h" namespace xla { @@ -58,7 +66,7 @@ Status TransferManager::TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer, const MutableBorrowingLiteral& literal, const TransferMetadata* transfer_metadata) { - se::Stream* substream = stream->GetOrCreateSubStream(); + TF_ASSIGN_OR_RETURN(se::Stream * substream, stream->GetOrCreateSubStream()); substream->ThenWaitFor(stream); absl::Cleanup cleanup = [&]() { stream->ReturnSubStream(substream); }; @@ -82,7 +90,7 @@ Status TransferManager::TransferLiteralToDevice( // Implement the synchronous version by waiting on the asynchronous version. // Use a substream so that if we are called from a HostCallback we don't // deadlock. - se::Stream* substream = stream->GetOrCreateSubStream(); + TF_ASSIGN_OR_RETURN(se::Stream * substream, stream->GetOrCreateSubStream()); substream->ThenWaitFor(stream); absl::Cleanup cleanup = [&]() { stream->ReturnSubStream(substream); }; TF_RETURN_IF_ERROR(TransferLiteralToDeviceAsync( @@ -111,7 +119,7 @@ Status TransferManager::TransferArrayToDevice( // Implement the synchronous version by waiting on the asynchronous version. // Use a substream so that if we are called from a HostCallback we don't // deadlock. - se::Stream* substream = stream->GetOrCreateSubStream(); + TF_ASSIGN_OR_RETURN(se::Stream * substream, stream->GetOrCreateSubStream()); substream->ThenWaitFor(stream); absl::Cleanup cleanup = [&]() { stream->ReturnSubStream(substream); }; TF_RETURN_IF_ERROR( diff --git a/third_party/xla/xla/service/transfer_manager.h b/third_party/xla/xla/service/transfer_manager.h index 89fc0d876614b6..3157b37b627443 100644 --- a/third_party/xla/xla/service/transfer_manager.h +++ b/third_party/xla/xla/service/transfer_manager.h @@ -16,19 +16,24 @@ limitations under the License. #ifndef XLA_SERVICE_TRANSFER_MANAGER_H_ #define XLA_SERVICE_TRANSFER_MANAGER_H_ -#include -#include -#include +#include +#include +#include #include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/literal.h" -#include "xla/service/executable.h" +#include "xla/service/maybe_owning_device_memory.h" #include "xla/service/shaped_buffer.h" +#include "xla/shape.h" +#include "xla/shape_tree.h" +#include "xla/shape_util.h" +#include "xla/status.h" #include "xla/statusor.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" #include "xla/xla_data.pb.h" namespace xla { diff --git a/third_party/xla/xla/service/while_loop_unroller.cc b/third_party/xla/xla/service/while_loop_unroller.cc index 6b2c03669f04ba..4f96cc48ef53e1 100644 --- a/third_party/xla/xla/service/while_loop_unroller.cc +++ b/third_party/xla/xla/service/while_loop_unroller.cc @@ -15,7 +15,6 @@ limitations under the License. #include "xla/service/while_loop_unroller.h" -#include #include #include #include @@ -25,11 +24,14 @@ limitations under the License. #include "absl/algorithm/algorithm.h" #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/comparison_util.h" #include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -41,6 +43,7 @@ limitations under the License. #include "xla/overflow_util.h" #include "xla/primitive_util.h" #include "xla/service/call_inliner.h" +#include "xla/service/collective_ops_utils.h" #include "xla/service/flatten_call_graph.h" #include "xla/service/hlo_cse.h" #include "xla/service/hlo_pass_fix.h" @@ -48,6 +51,7 @@ limitations under the License. #include "xla/service/while_loop_analysis.h" #include "xla/service/while_loop_constant_sinking.h" #include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -56,30 +60,188 @@ limitations under the License. namespace xla { namespace { + using hlo_query::ContainsInstrWithOpcode; +// Parameters for the unroller that can be adjusted. const int kUnrollTripCountThreshold = 64; const int kUnrollInstructionCountThreshold = 800; const int kUnrollExpandFactorThreshold = 10000; -}; // namespace + +// The following sequence of passes are necessary to prepare loops for +// unrolling. Failure to run these passes will prevent unroller from unrolling +// loops that would have been otherwise unrollable. +// +// Instead of placing these passes in compiler, they are placed +// here to indicate explicit dependency to these passes. +StatusOr PrepareModuleForUnrolling( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + bool changed = false; + TF_ASSIGN_OR_RETURN( + bool applied_cse, + HloCSE{/*is_layout_sensitive=*/true}.Run(module, execution_threads)); + if (applied_cse) { + changed = true; + VLOG(3) << "Applied hlo cse to module " << module->name(); + } + + TF_ASSIGN_OR_RETURN(bool applied_tuple_simplifier, + TupleSimplifier{}.Run(module, execution_threads)); + if (applied_tuple_simplifier) { + changed = true; + VLOG(3) << "Applied tuple simplifier to module " << module->name(); + } + + // We apply constant sinking to fix point. + HloPassFix constant_sinking( + /*sink_broadcast_of_constants=*/true); + TF_ASSIGN_OR_RETURN(bool applied_constant_sinking, + constant_sinking.Run(module, execution_threads)); + if (applied_constant_sinking) { + changed = true; + VLOG(3) << "Applied constant sinking to module " << module->name(); + } + return changed; +} + +// A utility function that decides whether a loop is unrollable or not. +std::optional IsLoopUnrollable(HloInstruction* while_op) { + CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); + + // TODO(b/300668690): Add support for unrolling loops with control dependency. + // For now, we bail. + // + // Finding all the while loops where other instructions have explicit control + // dependencies on them. + std::vector while_dependees; + for (HloComputation* comp : while_op->GetModule()->computations()) { + for (HloInstruction* instr : comp->instructions()) { + for (HloInstruction* control_dep : instr->control_predecessors()) { + if (control_dep->opcode() == HloOpcode::kWhile) { + while_dependees.push_back(control_dep); + } + } + } + } + if (absl::linear_search(while_dependees.begin(), while_dependees.end(), + while_op)) { + VLOG(2) << "Not attempting to unroll " << while_op->name() + << " due to control dependency: " << while_op->ToShortString(); + return std::nullopt; + } + + // We can't remove while loops that contain send/recv nodes, because we + // rely on the particular loop structure around the node matching on the + // send and recv sides. + if (ContainsInstrWithOpcode(while_op->while_body(), + {HloOpcode::kSend, HloOpcode::kSendDone, + HloOpcode::kRecv, HloOpcode::kRecvDone}) || + ContainsInstrWithOpcode(while_op->while_condition(), + {HloOpcode::kSend, HloOpcode::kSendDone, + HloOpcode::kRecv, HloOpcode::kRecvDone})) { + VLOG(2) << "Not attempting to unroll " << while_op->name() + << " because it contains a send/recv node: " + << while_op->ToShortString(); + return std::nullopt; + } + + if (while_op->operand(0)->opcode() != HloOpcode::kTuple) { + VLOG(2) << "Not attempting to unroll " << while_op->name() + << " because the operand is not a tuple: " + << while_op->ToShortString(); + return std::nullopt; + } + + // We cannot unroll loops that have side effecting condition because the + // condition will be removed after unrolling. This might be relaxed + // later when we add partial unrolling. + if (while_op->while_condition()->HasSideEffect()) { + VLOG(2) << "Not attempting to remove while loop whose condition contains " + "side-effecting instructions: " + << while_op->ToShortString(); + return std::nullopt; + } + + std::optional indvar_tuple_idx = + GetLoopInductionVarTupleIdx(while_op); + if (!indvar_tuple_idx.has_value()) { + return std::nullopt; + } + + HloEvaluator evaluator(/*max_loop_iterations=*/0); + const HloInstruction* while_init = while_op->operand(0); + const HloInstruction* indvar_init = while_init->operand(*indvar_tuple_idx); + StatusOr indvar_init_result = evaluator.Evaluate(indvar_init); + if (!indvar_init_result.ok()) { + VLOG(2) << "Couldn't evaluate induction variable init, " + << indvar_init_result.status() << ", " << indvar_init->ToString(); + return std::nullopt; + } + Literal indvar_iter_val = std::move(indvar_init_result).value(); + + std::optional trip_count = + MatchTrivialLoopTripCount(while_op, *indvar_tuple_idx, indvar_iter_val); + if (!trip_count.has_value()) { + return std::nullopt; + } + + VLOG(3) << "Loop trip count " << trip_count.value(); + + WhileLoopConfig config; + config.init = + LiteralUtil::LiteralAsScalarInt64(std::move(indvar_iter_val)).value(); + config.trip_count = trip_count.value(); + config.induction_var_idx = *indvar_tuple_idx; + + return config; +} + +std::unique_ptr GetConstantWithPrimitiveType(PrimitiveType type, + int64_t value) { + return primitive_util::PrimitiveTypeSwitch>( + [&](auto literal_constant) -> std::unique_ptr { + if constexpr (primitive_util::IsIntegralType(literal_constant)) { + using NativeT = primitive_util::NativeTypeOf; + return HloInstruction::CreateConstant( + LiteralUtil::CreateR0(static_cast(value))); + } + LOG(FATAL) << "literal is of non-integral type"; + }, + type); +} // Helper function that replaces a single iteration of a while loop with // induction variable equal to induction_value. -static StatusOr> -UnrollSingleIterationOfTrivialLoop(HloInstruction* while_op, - const int64_t indvar_idx, - const int64_t induction_value) { +StatusOr> UnrollSingleIterationOfTrivialLoop( + HloInstruction* while_op, const int64_t indvar_idx, + const int64_t induction_value) { // We clone the body since we are changing the computation. std::unique_ptr while_body_clone = while_op->while_body()->Clone(absl::StrCat(induction_value)); - const HloInstruction* induction_var_hlo = - while_op->operand(0)->operand(indvar_idx); + HloInstruction* induction_var_hlo = + while_op->mutable_operand(0)->mutable_operand(indvar_idx); + + // We record the next channel id to utilize when unrolling loops with + // collective communication instructions. During unrolling a single iteration + // of the body, we can reuse the same unique_channel_id. For the later + // iterations, we obtain it again. + int64_t unique_channel_id = hlo_query::NextChannelId(*while_op->GetModule()); // Go through the instructions in while body to get the instruction that // points to the induction var. Then replace it everywhere with the concrete // value. for (HloInstruction* body_inst : while_body_clone->instructions()) { + // We need to assign a unique channel_id for the collective ops that are + // unrolled within the while loop body or fusions containing collectives. + if (IsCollectiveWithChannelId(body_inst)) { + // To obtain the channel_id for the collective ops we only need to + // increment the `unique_channel_id` since it records the next available + // channel_id across the module. + body_inst->set_channel_id(unique_channel_id++); + } + if (body_inst->opcode() != HloOpcode::kGetTupleElement) { continue; } @@ -103,22 +265,8 @@ UnrollSingleIterationOfTrivialLoop(HloInstruction* while_op, // Found the induction var as an operand of body instruction. if (indvar_use_operand == body_inst) { std::unique_ptr constant = - primitive_util::PrimitiveTypeSwitch< - std::unique_ptr>( - [&](auto literal_constant) - -> std::unique_ptr { - if constexpr (primitive_util::IsIntegralType( - literal_constant)) { - using NativeT = - primitive_util::NativeTypeOf; - return HloInstruction::CreateConstant( - LiteralUtil::CreateR0( - static_cast(induction_value))); - } - LOG(FATAL) << "literal is of non-integral type"; - }, - induction_var_hlo->shape().element_type()); - + GetConstantWithPrimitiveType( + induction_var_hlo->shape().element_type(), induction_value); // Assign the same shape of the old instruction to the new // instruction. *constant->mutable_shape() = body_inst->shape(); @@ -132,45 +280,182 @@ UnrollSingleIterationOfTrivialLoop(HloInstruction* while_op, return while_body_clone; } -StatusOr WhileLoopUnroller::Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) { - // TODO(b/288130138) For now, we only support full unrolling. Will add partial - // unrolling if needed. - if (unroll_factor_ != -1) { - return false; +// Helper function to create a condition for a single iteration while loop in +// the form of 'i <= init_value' where i is the induction variable. +std::unique_ptr MakeSingleIterWhileCond( + HloInstruction* while_op, int64_t induction_idx, int64_t init_value) { + auto condition_builder = + HloComputation::Builder(absl::StrCat("unrolled-cond-", while_op->name())); + + auto param_instruction = condition_builder.AddParameter( + while_op->while_condition()->parameter_instruction(0)->Clone()); + + CHECK_OK(param_instruction); + + HloInstruction* indvar_instruction = condition_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(*param_instruction, induction_idx)); + + auto init_value_constant = + condition_builder.AddInstruction(GetConstantWithPrimitiveType( + indvar_instruction->shape().element_type(), init_value)); + + return condition_builder.Build( + condition_builder.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PrimitiveType::PRED, {}), indvar_instruction, + init_value_constant, ComparisonDirection::kLe))); +} + +absl::Status InitialFeasibilityCheck(HloInstruction* while_op, + WhileLoopConfig config, + int64_t unroll_factor) { + CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); + + // While loop must have a single tuple operand. + CHECK_EQ(while_op->operands().size(), 1); + if (while_op->operands().size() != 1) { + return FailedPrecondition( + "%s", + absl::StrCat("Cannot unroll while loop. While loop must have a single " + "tuple operand, instead has more than one operand: ", + while_op->operands().size())); } - XLA_VLOG_LINES(3, "WhileLoopUnroller::Run(), before:\n" + module->ToString()); - bool changed = false; - // The following sequence of passes are necessary to prepare loops for - // unrolling. Instead of placing these passes in compiler, they are placed - // here to indicate explicit dependency to these passes. - TF_ASSIGN_OR_RETURN( - bool applied_cse, - HloCSE{/*is_layout_sensitive=*/true}.Run(module, execution_threads)); - if (applied_cse) { - changed = true; - VLOG(3) << "Applied hlo cse to module " << module->name(); + VLOG(5) << "Trying to unroll " << while_op->ToShortString(); + + // TODO(b/288130138): For now, we only support full unrolling. Will add + // partial unrolling if needed. + if (unroll_factor != -1) { + return UnimplementedStrCat( + "Currently, only full unrolling is supported, unroll factor: ", + unroll_factor); } - TF_ASSIGN_OR_RETURN(bool applied_tuple_simplifier, - TupleSimplifier{}.Run(module, execution_threads)); - if (applied_tuple_simplifier) { - changed = true; - VLOG(3) << "Applied tuple simplifier to module " << module->name(); + // TODO(b/291628533): Extract this parameter to the unroller config. We don't + // attempt to unroll loops where the body has more than + // kUnrollInstructionCountThreshold instructions. + if (while_op->while_body()->instruction_count() > + kUnrollInstructionCountThreshold) { + return FailedPrecondition( + "%s", + absl::StrCat( + "Cannot unroll while loop. Too many instructions in the body: ", + while_op->while_body()->instruction_count())); } - // We apply constant sinking to fix point. - HloPassFix constant_sinking( - /*sink_broadcast_of_constants=*/true); - TF_ASSIGN_OR_RETURN(bool applied_constant_sinking, - constant_sinking.Run(module, execution_threads)); - if (applied_constant_sinking) { - VLOG(3) << "Applied constant sinking to module " << module->name(); + // TODO(b/291628533): Extract this parameter to the an unroller config. We + // only unroll loops up to a threshold. + if (config.trip_count > kUnrollTripCountThreshold) { + return FailedPrecondition( + "%s", + absl::StrCat("Cannot unroll while loop. The tip count is greater " + "than the threshold: ", + config.trip_count, " vs ", kUnrollTripCountThreshold)); } - // Processing the while loops in the reverse of topological order. If the body + // TODO(b/291628533): Extract this parameter to the unroller config. We don't + // unroll loops that increase the instruction count by more than + // kUnrollExpandFactorThreshold. + if (config.trip_count * while_op->while_body()->instruction_count() > + kUnrollExpandFactorThreshold) { + return FailedPrecondition( + "%s", absl::StrCat("Not attempting to unroll due to instruction count " + "increase explosion. New instruction count: ", + config.trip_count * + while_op->while_body()->instruction_count(), + " vs ", kUnrollExpandFactorThreshold)); + } + return absl::OkStatus(); +} + +StatusOr UnrollInternal(HloInstruction* while_op, WhileLoopConfig config, + int64_t unroll_factor) { + TF_RETURN_IF_ERROR(InitialFeasibilityCheck(while_op, config, unroll_factor)); + + VLOG(3) << "Unrolling while instruction " << while_op->ToShortString() + << " with body instruction count " + << while_op->while_body()->instruction_count(); + + HloModule* module = while_op->GetModule(); + HloComputation* computation = while_op->parent(); + HloInstruction* unrolled_body_call_op; + std::vector call_operands = {while_op->operands().at(0)}; + for (int64_t i = config.init; i < config.trip_count + config.init; ++i) { + CHECK(OverflowSafeAdd(i, (int64_t)1).has_value()); + + HloComputation* unrolled_body = module->AddEmbeddedComputation( + UnrollSingleIterationOfTrivialLoop(while_op, config.induction_var_idx, + i) + .value()); + unrolled_body_call_op = + computation->AddInstruction(HloInstruction::CreateCall( + while_op->shape(), call_operands, unrolled_body)); + call_operands.clear(); + call_operands.emplace_back(unrolled_body_call_op); + } + TF_RETURN_IF_ERROR( + computation->ReplaceInstruction(while_op, unrolled_body_call_op)); + + // Needed for the nested while loops in which the outer loop has been + // unrolled which leaves the call graph non-flat. + TF_RETURN_IF_ERROR(FlattenCallGraph().Run(module).status()); + return true; +} + +StatusOr UnrollInternalWrapped(HloInstruction* while_op, + WhileLoopConfig config, + int64_t unroll_factor) { + TF_RETURN_IF_ERROR(InitialFeasibilityCheck(while_op, config, unroll_factor)); + + VLOG(3) << "Unrolling (wrapped) while instruction " + << while_op->ToShortString() << " with body instruction count " + << while_op->while_body()->instruction_count(); + + HloModule* module = while_op->GetModule(); + HloComputation* computation = while_op->parent(); + HloInstruction* unrolled_body_call_op; + + auto body_builder = + HloComputation::Builder(absl::StrCat("unrolled-body-", while_op->name())); + StatusOr p = body_builder.AddParameter( + while_op->while_body()->parameter_instruction(0)->Clone()); + + std::vector call_operands = {p.value()}; + for (int64_t i = config.init; i < config.trip_count + config.init; ++i) { + CHECK(OverflowSafeAdd(i, (int64_t)1).has_value()); + + HloComputation* unrolled_body = module->AddEmbeddedComputation( + UnrollSingleIterationOfTrivialLoop(while_op, config.induction_var_idx, + i) + .value()); + unrolled_body_call_op = + body_builder.AddInstruction(HloInstruction::CreateCall( + while_op->shape(), call_operands, unrolled_body)); + call_operands.clear(); + call_operands.emplace_back(unrolled_body_call_op); + } + HloComputation* new_body = + module->AddEmbeddedComputation(body_builder.Build(unrolled_body_call_op)); + HloComputation* new_cond = module->AddEmbeddedComputation( + MakeSingleIterWhileCond(while_op, config.induction_var_idx, config.init)); + + HloInstruction* new_while_op = + computation->AddInstruction(HloInstruction::CreateWhile( + while_op->shape(), new_cond, new_body, while_op->mutable_operand(0))); + + CHECK_OK(computation->ReplaceInstruction(while_op, new_while_op)); + + // Needed for the nested while loops in which the outer loop has been + // unrolled which leaves the call graph non-flat. + TF_RETURN_IF_ERROR(FlattenCallGraph().Run(module).status()); + return true; +} + +}; // namespace + +absl::flat_hash_map GetUnrollableLoops( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + // Processing the while loops in the reverse topological order. If the body // of while loop A calls while loop B, B comes before A. std::vector all_while_ops; for (auto* comp : module->MakeComputationPostOrder(execution_threads)) { @@ -178,184 +463,97 @@ StatusOr WhileLoopUnroller::Run( HloPredicateIsOp); } - // Finding all the while loops where other instructions have explicit control - // dependencies on them. - std::vector while_with_deps; - for (HloComputation* comp : module->computations(execution_threads)) { - for (HloInstruction* instr : comp->instructions()) { - for (HloInstruction* control_dep : instr->control_predecessors()) { - if (control_dep->opcode() == HloOpcode::kWhile) { - if (std::find(all_while_ops.begin(), all_while_ops.end(), - control_dep) != all_while_ops.end()) { - while_with_deps.push_back(control_dep); - } - } - } - } - } - - // Gather a preliminary vector of all the while ops that we think we can - // unroll. We only consider while loops that take a tuple as an argument. We - // do this ahead of time so we don't have to worry about mutating the lists of - // computations or instructions while we iterate. - std::vector while_ops; + absl::flat_hash_map while_loop_configs; for (HloInstruction* instr : all_while_ops) { - // TODO(b/300668690): Check control dependencies to the while - // instruction - if (absl::linear_search(while_with_deps.begin(), while_with_deps.end(), - instr)) { - VLOG(2) << "Not attempting to unroll " << instr->name() - << " due to control dependency: " << instr->ToShortString(); - continue; + std::optional config = IsLoopUnrollable(instr); + if (config.has_value()) { + while_loop_configs[instr] = *config; } - - // We can't remove while loops that contain send/recv nodes, because we - // rely on the particular loop structure around the node matching on the - // send and recv sides. - if (ContainsInstrWithOpcode(instr->while_body(), - {HloOpcode::kSend, HloOpcode::kSendDone, - HloOpcode::kRecv, HloOpcode::kRecvDone}) || - ContainsInstrWithOpcode(instr->while_condition(), - {HloOpcode::kSend, HloOpcode::kSendDone, - HloOpcode::kRecv, HloOpcode::kRecvDone})) { - VLOG(2) << "Not attempting to unroll " << instr->name() - << " because it contains a send/recv node: " - << instr->ToShortString(); - continue; - } - // TODO(b/291146216): Handle this case later - if (ContainsInstrWithOpcode(instr->while_body(), {HloOpcode::kAllReduce, - HloOpcode::kAllGather})) { - VLOG(2) << "Not attempting to unroll " << instr->name() - << " for now because it contains an all-reduce or an all-gather: " - << instr->ToShortString(); - continue; - } - if (instr->operand(0)->opcode() != HloOpcode::kTuple) { - VLOG(2) << "Not attempting to unroll " << instr->name() - << " because the operand is not a tuple: " - << instr->ToShortString(); - continue; - } - // We cannot unroll loops that have side effecting condition because the - // condition will be removed after unrolling. This might be relaxed - // later when we add partial unrolling. - if (instr->while_condition()->HasSideEffect()) { - VLOG(2) << "Not attempting to remove while loop whose condition contains " - "side-effecting instructions: " - << instr->ToShortString(); - return false; - } - // TODO(b/291628533): Extract this to the unroller config - if (instr->while_body()->instruction_count() > - kUnrollInstructionCountThreshold) { - continue; - } - while_ops.push_back(instr); } + return while_loop_configs; +} - VLOG(3) << "Number of while instructions in the module to unroll: " - << while_ops.size(); +StatusOr Unroll(HloInstruction* while_op, int64_t unroll_factor, + bool wrap_in_trivial_loop) { + bool changed = false; + HloModule* module = while_op->GetModule(); - for (HloInstruction* while_op : while_ops) { - VLOG(3) << "Trying to unroll " << while_op->ToShortString(); - bool unrolled_current_loop = false; - int64_t unroll_factor_current_loop = unroll_factor_; + // Make sure all the necessary passes are executed before unrolling in order + // to unroll every possible loop. + TF_ASSIGN_OR_RETURN( + changed, PrepareModuleForUnrolling(module, /*execution_threads=*/{})); - // TODO(b/288130138) For now, we only support full unrolling. Will add - // partial unrolling if needed. - CHECK_EQ(unroll_factor_current_loop, -1); + // Construct the loop config + std::optional config = IsLoopUnrollable(while_op); + if (!config.has_value()) { + return false; + } - std::optional indvar_tuple_idx = - GetLoopInductionVarTupleIdx(while_op); - if (!indvar_tuple_idx.has_value()) { - continue; - } + bool unrolled = false; + if (wrap_in_trivial_loop) { + TF_ASSIGN_OR_RETURN(unrolled, UnrollInternalWrapped( + while_op, config.value(), unroll_factor)); + } else { + TF_ASSIGN_OR_RETURN( + unrolled, UnrollInternal(while_op, config.value(), unroll_factor)); + } - HloEvaluator evaluator(/*max_loop_iterations=*/0); - const HloInstruction* while_init = while_op->operand(0); - const HloInstruction* indvar_init = while_init->operand(*indvar_tuple_idx); - StatusOr indvar_init_result = evaluator.Evaluate(indvar_init); - if (!indvar_init_result.ok()) { - VLOG(2) << "Couldn't evaluate induction variable init, " - << indvar_init_result.status() << ", " << indvar_init->ToString(); - continue; - } - Literal indvar_iter_val = std::move(indvar_init_result).value(); + // We need to inline the calls created for unrolling since later passes rely + // on the calls to be inlined. + if (unrolled) { + TF_RETURN_IF_ERROR(CallInliner().Run(module).status()); + } + return unrolled; +} - // TODO(b/288907795): Try using ComputeWhileLoopTripCount - std::optional trip_count = - MatchTrivialLoopTripCount(while_op, *indvar_tuple_idx, indvar_iter_val); - if (!trip_count.has_value()) { - continue; - } +StatusOr WhileLoopUnroller::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + // TODO(b/288130138) For now, we only support full unrolling. Will add partial + // unrolling if needed. + if (unroll_factor_ != -1) { + return false; + } + XLA_VLOG_LINES(3, "WhileLoopUnroller::Run(), before:\n" + module->ToString()); + bool changed = false; - VLOG(3) << "Loop trip count " << trip_count.value(); + // Make sure all the necessary passes are executed before unrolling in order + // to unroll every possible loop. + TF_ASSIGN_OR_RETURN(changed, + PrepareModuleForUnrolling(module, execution_threads)); - // TODO(b/291628533): Extract this to the unroller config. We only unroll - // loops up to a threshold. - if (trip_count > kUnrollTripCountThreshold) { - continue; - } + // Processing the while loops in the reverse of topological order. If the body + // of while loop A calls while loop B, B comes before A. + std::vector all_while_ops; + for (auto* comp : module->MakeComputationPostOrder(execution_threads)) { + absl::c_copy_if(comp->instructions(), std::back_inserter(all_while_ops), + HloPredicateIsOp); + } - unroll_factor_current_loop = trip_count.value(); - - // TODO(b/291628533): Extract this to the unroller config. We don't unroll - // loops that increase the instruction count by more than - // kUnrollExpandFactorThreshold. - if (trip_count.value() * while_op->while_body()->instruction_count() > - kUnrollExpandFactorThreshold) { - VLOG(3) << "Not attempting to unroll due to instruction count increase " - "explosion."; - VLOG(3) << "New instruction count: " - << trip_count.value() * - while_op->while_body()->instruction_count(); - continue; - } + // Gather a preliminary vector of all the while ops that we think we can + // unroll. We do this ahead of time so we don't have to worry about mutating + // the lists of computations or instructions while we iterate. + absl::flat_hash_map unrollable_while_ops = + GetUnrollableLoops(module, execution_threads); - std::optional init_value = - LiteralUtil::LiteralAsScalarInt64(indvar_iter_val); - // Init value must be int64_t at this point since we found the trip count. - CHECK(init_value.has_value()); - - unrolled_current_loop = true; - VLOG(3) << "Unrolling while instruction " << while_op->ToShortString() - << " with body instruction count " - << while_op->while_body()->instruction_count(); - HloComputation* computation = while_op->parent(); - HloInstruction* unrolled_body_call_op; - std::vector call_operands; - // We assume while has only one tuple parameter - call_operands.emplace_back(while_op->operands().at(0)); - for (int64_t i = init_value.value(); - i < unroll_factor_current_loop + init_value.value(); ++i) { - CHECK(OverflowSafeAdd(i, (int64_t)1).has_value()); - - HloComputation* unrolled_body = module->AddEmbeddedComputation( - UnrollSingleIterationOfTrivialLoop(while_op, *indvar_tuple_idx, i) - .value()); - unrolled_body_call_op = - computation->AddInstruction(HloInstruction::CreateCall( - while_op->shape(), call_operands, unrolled_body)); - call_operands.clear(); - call_operands.emplace_back(unrolled_body_call_op); - } - CHECK_OK(computation->ReplaceInstruction(while_op, unrolled_body_call_op)); - - // Need to perform following passes only if the current while loop has been - // unrolled. - if (unrolled_current_loop) { - // Needed for the nested while loops in which the outer loop has been - // unrolled which leaves the call graph non-flat. - TF_RETURN_IF_ERROR( - FlattenCallGraph().Run(module, execution_threads).status()); + VLOG(3) << "Number of while instructions in the module to unroll: " + << unrollable_while_ops.size(); + + bool unrolled = false; + for (auto& [while_op, config] : unrollable_while_ops) { + if (wrap_in_trivial_loop_) { + TF_ASSIGN_OR_RETURN( + unrolled, UnrollInternalWrapped(while_op, config, unroll_factor_)); + } else { + TF_ASSIGN_OR_RETURN(unrolled, + UnrollInternal(while_op, config, unroll_factor_)); } - changed |= unrolled_current_loop; + changed |= unrolled; } + // We need to inline the calls created for unrolling since later passes rely + // on the calls to be inlined. if (changed) { - // We need to inline the calls created for unrolling since later passes rely - // on the calls to be inlined. TF_RETURN_IF_ERROR(CallInliner().Run(module, execution_threads).status()); } diff --git a/third_party/xla/xla/service/while_loop_unroller.h b/third_party/xla/xla/service/while_loop_unroller.h index 751b672034d1b6..bd9bebaa3d16e0 100644 --- a/third_party/xla/xla/service/while_loop_unroller.h +++ b/third_party/xla/xla/service/while_loop_unroller.h @@ -17,16 +17,39 @@ limitations under the License. #define XLA_SERVICE_WHILE_LOOP_UNROLLER_H_ #include +#include +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" #include "xla/statusor.h" namespace xla { +// Config for unrollable while loops. +struct WhileLoopConfig { + // The initial value of the induction variable of the while loop. + int64_t init; + // The number of iterations the loop executes. + int64_t trip_count; + // The index of the induction variable in the input tuple of the while loop. + int64_t induction_var_idx; +}; + +// Returns the list of unrollable loops in the given module +absl::flat_hash_map GetUnrollableLoops( + HloModule* module, + const absl::flat_hash_set& execution_threads); + +// Unrolls the given while loop with the defaul behaviour set to full unroll. If +// wrap_in_trivial_loop is set, the unrolled body of the loop will be wrapped in +// a loop with trip count of one. +StatusOr Unroll(HloInstruction* while_op, int64_t unroll_factor = -1, + bool wrap_in_trivial_loop = false); + // This pass unrolls while loops with the given unrolling factor. The value of // unroll_factor = -1 will fully unroll the loop. // @@ -42,8 +65,10 @@ class WhileLoopUnroller : public HloModulePass { ~WhileLoopUnroller() override = default; // Default unroll_factor of -1 indicates full unrolling - explicit WhileLoopUnroller(int64_t unroll_factor = -1) - : unroll_factor_(unroll_factor) {} + explicit WhileLoopUnroller(int64_t unroll_factor = -1, + bool wrap_in_trivial_loop = false) + : unroll_factor_(unroll_factor), + wrap_in_trivial_loop_(wrap_in_trivial_loop) {} absl::string_view name() const override { return "while_loop_unroller"; } @@ -54,6 +79,8 @@ class WhileLoopUnroller : public HloModulePass { private: int64_t unroll_factor_; + // Whether to wrap the unrolled computation in a loop with trip count of one. + bool wrap_in_trivial_loop_; }; } // namespace xla diff --git a/third_party/xla/xla/service/while_loop_unroller_test.cc b/third_party/xla/xla/service/while_loop_unroller_test.cc index 0c529a7e7dd249..d8122067f30bf4 100644 --- a/third_party/xla/xla/service/while_loop_unroller_test.cc +++ b/third_party/xla/xla/service/while_loop_unroller_test.cc @@ -16,21 +16,26 @@ limitations under the License. #include "xla/service/while_loop_unroller.h" #include +#include #include #include #include #include +#include #include +#include "absl/algorithm/container.h" #include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_replace.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/verified_hlo_module.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -47,15 +52,19 @@ class WhileLoopUnrollerTest : public HloTestBase { MakeModuleWithLoopBodyNestedCopyIndVar(int num_iters); [[nodiscard]] std::unique_ptr MakeModuleWithWhileFeedingAnotherWhile(int num_iters); + [[nodiscard]] std::unique_ptr + MakeModuleWithSimpleLoopAllReduce(int num_iters); public: void UnrollAndCompare(std::unique_ptr module, absl::Span arguments, - int64_t unroll_factor = -1) { + int64_t unroll_factor = -1, bool wrap_in_loop = false) { Literal before_unroll = ExecuteAndTransfer(module->Clone(), arguments); - VLOG(2) << "after unroll value: " << before_unroll.ToString(); + VLOG(2) << "before unroll value: " << before_unroll.ToString(); - EXPECT_TRUE(WhileLoopUnroller(unroll_factor).Run(module.get()).value()); + EXPECT_TRUE(WhileLoopUnroller(unroll_factor, wrap_in_loop) + .Run(module.get()) + .value()); Literal after_unroll = ExecuteAndTransfer(std::move(module), arguments); VLOG(2) << "after unroll value: " << after_unroll.ToString(); @@ -293,8 +302,55 @@ WhileLoopUnrollerTest::MakeModuleWithWhileFeedingAnotherWhile(int num_iters) { return ParseAndReturnVerifiedModule(hlo_string).value(); } +std::unique_ptr +WhileLoopUnrollerTest::MakeModuleWithSimpleLoopAllReduce(int num_iters) { + std::string hlo_string_template = R"( + HloModule SimpleLoop + + %reduction { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %x, f32[] %y) + } + + SimpleLoop.body { + loop_var.1 = (s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + get-tuple-element.2 = f32[1024, 1024] get-tuple-element(loop_var.1), index=1 + get-tuple-element.3 = f32[1024, 1024] get-tuple-element(loop_var.1), index=2 + + %all-reduce = f32[1024, 1024] all-reduce(f32[1024, 1024] get-tuple-element.2), channel_id=1, replica_groups={{0}}, to_apply=%reduction + %accumulation = f32[1024, 1024] add(f32[1024, 1024] %all-reduce, f32[1024, 1024] get-tuple-element.3) + + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + ROOT tuple = (s32[], f32[1024, 1024], f32[1024, 1024]) tuple(add, get-tuple-element.2, %accumulation) + } + SimpleLoop.condition { + loop_var.2 = (s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant({{LOOP_BOUND}}) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT + } + ENTRY SimpleLoop { + %param.1 = f32[1024, 1024] parameter(0) + constant.3 = s32[] constant(0) + + %accumulation_buffer_init = f32[] constant(0) + %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={} + + tuple.1 = (s32[], f32[1024, 1024], f32[1024, 1024]) tuple(constant.3, %param.1, %accumulation_buffer) + ROOT while = (s32[], f32[1024, 1024], f32[1024, 1024]) while(tuple.1), condition=SimpleLoop.condition, body=SimpleLoop.body + } + )"; + std::string hlo_string = absl::StrReplaceAll( + hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(num_iters)}}); + return ParseAndReturnVerifiedModule(hlo_string).value(); +} + TEST_F(WhileLoopUnrollerTest, SimpleLoopUnroll) { - UnrollAndCompare(MakeModuleWithSimpleLoop(/*num_iters=*/5), {}); + UnrollAndCompare(MakeModuleWithSimpleLoop(/*num_iters=*/5), {}, -1, false); + UnrollAndCompare(MakeModuleWithSimpleLoop(/*num_iters=*/5), {}, -1, true); } TEST_F(WhileLoopUnrollerTest, SimpleLoopNotRoot) { @@ -325,10 +381,95 @@ TEST_F(WhileLoopUnrollerTest, SimpleLoopNotRoot) { ROOT result = s32[3]{0} get-tuple-element(while), index=1 } )"; - UnrollAndCompare(ParseAndReturnVerifiedModule(hlo_string).value(), {}); + UnrollAndCompare(ParseAndReturnVerifiedModule(hlo_string).value(), {}, -1, + false); + UnrollAndCompare(ParseAndReturnVerifiedModule(hlo_string).value(), {}, -1, + true); } -TEST_F(WhileLoopUnrollerTest, SimpleLoopNonZeroInit) { +TEST_F(WhileLoopUnrollerTest, GetUnrollableLoops) { + std::string hlo_string = R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s64[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s64[] get-tuple-element(loop_var.1), index=0 + constant.1 = s64[] constant(1) + add = s64[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} add(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s64[], s32[3]{0}) tuple(add, multiply) + } + SimpleLoop.condition { + loop_var.2 = (s64[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s64[] get-tuple-element(loop_var.2), index=0 + /* number of iterations is 10 */ + constant.2 = s64[] constant(10) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT + } + SimpleLoop.body.2 { + loop_var.1 = (s64[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s64[] get-tuple-element(loop_var.1), index=0 + constant.1 = s64[] constant(1) + add = s64[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} add(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s64[], s32[3]{0}) tuple(add, multiply) + } + SimpleLoop.condition.2 { + loop_var.2 = (s64[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s64[] get-tuple-element(loop_var.2), index=0 + /* number of iterations is 10 */ + constant.2 = s64[] constant(10) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT + } + SimpleLoop.body.3 { + loop_var.1 = (s64[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s64[] get-tuple-element(loop_var.1), index=0 + constant.1 = s64[] constant(1) + add = s64[] multiply(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} add(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s64[], s32[3]{0}) tuple(add, multiply) + } + SimpleLoop.condition.3 { + loop_var.2 = (s64[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s64[] get-tuple-element(loop_var.2), index=0 + /* number of iterations is 10 */ + constant.2 = s64[] constant(10) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT + } + ENTRY SimpleLoop { + constant.3 = s64[] constant(0) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s64[], s32[3]{0}) tuple(constant.3, constant.4) + while1 = (s64[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + while3 = (s64[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition.3, body=SimpleLoop.body.3 + while2 = (s64[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition.2, body=SimpleLoop.body.2 + o1 = s32[3]{0} get-tuple-element(while1), index=1 + o2 = s32[3]{0} get-tuple-element(while2), index=1 + ROOT result = (s32[3]{0}, s32[3]{0}) tuple(o1,o2) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + HloInstruction* while1 = + module->entry_computation()->GetInstructionWithName("while1"); + HloInstruction* while2 = + module->entry_computation()->GetInstructionWithName("while2"); + HloInstruction* while3 = + module->entry_computation()->GetInstructionWithName("while3"); + + auto unrollable_loops = GetUnrollableLoops(module.get(), {}); + EXPECT_TRUE(unrollable_loops.contains(while1)); + EXPECT_TRUE(unrollable_loops.contains(while2)); + EXPECT_FALSE(unrollable_loops.contains(while3)); +} + +TEST_F(WhileLoopUnrollerTest, UnrollMutipleLoops) { std::string hlo_string = R"( HloModule SimpleLoop SimpleLoop.body { @@ -347,6 +488,86 @@ TEST_F(WhileLoopUnrollerTest, SimpleLoopNonZeroInit) { constant.2 = s64[] constant(10) ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } + SimpleLoop.body.2 { + loop_var.1 = (s64[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s64[] get-tuple-element(loop_var.1), index=0 + constant.1 = s64[] constant(1) + add = s64[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} add(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s64[], s32[3]{0}) tuple(add, multiply) + } + SimpleLoop.condition.2 { + loop_var.2 = (s64[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s64[] get-tuple-element(loop_var.2), index=0 + /* number of iterations is 10 */ + constant.2 = s64[] constant(10) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT + } + ENTRY SimpleLoop { + constant.3 = s64[] constant(0) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s64[], s32[3]{0}) tuple(constant.3, constant.4) + while1 = (s64[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + input = s32[3]{0} get-tuple-element(while1), index=1 + tuple.2 = (s64[], s32[3]{0}) tuple(constant.3, input) + while2 = (s64[], s32[3]{0}) while(tuple.2), condition= + SimpleLoop.condition.2, body=SimpleLoop.body.2 + o1 = s32[3]{0} get-tuple-element(while1), index=1 + o2 = s32[3]{0} get-tuple-element(while2), index=1 + ROOT result = (s32[3]{0}, s32[3]{0}) tuple(o1,o2) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + // Unroll the first loop + TF_ASSERT_OK_AND_ASSIGN( + bool unrolled1, + Unroll(module->entry_computation()->GetInstructionWithName("while1"))); + EXPECT_TRUE(unrolled1); + + // There should be no call instructions after unrolling either loops since we + // inline all the calls after unrolling. + std::vector call_instrs_1; + for (auto* comp : module->MakeComputationPostOrder()) { + absl::c_copy_if(comp->instructions(), std::back_inserter(call_instrs_1), + HloPredicateIsOp); + } + EXPECT_EQ(call_instrs_1.size(), 0); + + // Unroll the second loop + TF_ASSERT_OK_AND_ASSIGN( + bool unrolled2, + Unroll(module->entry_computation()->GetInstructionWithName("while2"))); + EXPECT_TRUE(unrolled2); + std::vector call_instrs_2; + for (auto* comp : module->MakeComputationPostOrder()) { + absl::c_copy_if(comp->instructions(), std::back_inserter(call_instrs_2), + HloPredicateIsOp); + } + EXPECT_EQ(call_instrs_2.size(), 0); +} + +TEST_F(WhileLoopUnrollerTest, SimpleLoopNonZeroInit) { + std::string hlo_string = R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s64[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s64[] get-tuple-element(loop_var.1), index=0 + constant.1 = s64[] constant(1) + add = s64[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} add(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s64[], s32[3]{0}) tuple(add, multiply) + } + SimpleLoop.condition { + loop_var.2 = (s64[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s64[] get-tuple-element(loop_var.2), index=0 + constant.2 = s64[] constant(10) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT + } ENTRY SimpleLoop { constant.3 = s64[] constant(4) constant.4 = s32[3]{0} constant({0, 1, 2}) @@ -356,7 +577,10 @@ TEST_F(WhileLoopUnrollerTest, SimpleLoopNonZeroInit) { ROOT result = s32[3]{0} get-tuple-element(while), index=1 } )"; - UnrollAndCompare(ParseAndReturnVerifiedModule(hlo_string).value(), {}); + UnrollAndCompare(ParseAndReturnVerifiedModule(hlo_string).value(), {}, -1, + false); + UnrollAndCompare(ParseAndReturnVerifiedModule(hlo_string).value(), {}, -1, + true); } TEST_F(WhileLoopUnrollerTest, SimpleLoopS16IndVar) { @@ -386,7 +610,10 @@ TEST_F(WhileLoopUnrollerTest, SimpleLoopS16IndVar) { SimpleLoop.condition, body=SimpleLoop.body } )"; - UnrollAndCompare(ParseAndReturnVerifiedModule(hlo_string).value(), {}); + UnrollAndCompare(ParseAndReturnVerifiedModule(hlo_string).value(), {}, -1, + false); + UnrollAndCompare(ParseAndReturnVerifiedModule(hlo_string).value(), {}, -1, + true); } TEST_F(WhileLoopUnrollerTest, LoopWithControlDep) { @@ -431,17 +658,244 @@ TEST_F(WhileLoopUnrollerTest, SimpleLoopPartialUnroll) { TEST_F(WhileLoopUnrollerTest, IndirectBodyInc) { std::unique_ptr module = MakeModuleWithLoopBodyIndirectInc(/*num_iters=*/5); - UnrollAndCompare(std::move(module), {}); + UnrollAndCompare(MakeModuleWithLoopBodyIndirectInc(/*num_iters=*/5), {}, -1, + false); + UnrollAndCompare(MakeModuleWithLoopBodyIndirectInc(/*num_iters=*/5), {}, -1, + true); } TEST_F(WhileLoopUnrollerTest, NestedIndirectBodyInc) { std::unique_ptr module = MakeModuleWithNestedLoopBodyIndirectInc(/*num_iters=*/5); - UnrollAndCompare(std::move(module), {}); + UnrollAndCompare(MakeModuleWithNestedLoopBodyIndirectInc(/*num_iters=*/5), {}, + -1, false); + UnrollAndCompare(MakeModuleWithNestedLoopBodyIndirectInc(/*num_iters=*/5), {}, + -1, true); } TEST_F(WhileLoopUnrollerTest, WhileFeedingWhile) { - UnrollAndCompare(MakeModuleWithWhileFeedingAnotherWhile(/*num_iters=*/5), {}); + UnrollAndCompare(MakeModuleWithWhileFeedingAnotherWhile(/*num_iters=*/5), {}, + -1, false); + UnrollAndCompare(MakeModuleWithWhileFeedingAnotherWhile(/*num_iters=*/5), {}, + -1, true); +} + +TEST_F(WhileLoopUnrollerTest, LoopWithCollective) { + int64_t num_iters = 5; + auto module = MakeModuleWithSimpleLoopAllReduce(num_iters); + + EXPECT_TRUE( + WhileLoopUnroller(/*unroll_factor=*/-1).Run(module.get()).value()); + + EXPECT_EQ(absl::c_count_if(module->entry_computation()->instructions(), + [](const HloInstruction* instruction) { + return instruction->opcode() == + HloOpcode::kAllReduce; + }), + num_iters); +} + +TEST_F(WhileLoopUnrollerTest, LoopWithCollective2) { + std::string hlo_string = R"( + HloModule module, entry_computation_layout={(s8[1,32,2048]{2,1,0:T(8,128)(4,1)S(1)}, s8[1,2048,4096]{2,1,0:T(8,128)(4,1)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)})->(s8[1,32,2048]{2,1,0:T(8,128)(4,1)S(1)}, s8[1,2048,4096]{2,1,0:T(8,128)(4,1)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, u32[]{:T(128)}, /*index=5*/u32[]{:T(128)}, u32[256]{0:T(256)}, u32[]{:T(128)}, u32[]{:T(128)}, s32[]{:T(128)}, /*index=10*/u32[]{:T(128)}, u32[]{:T(128)}, u32[]{:T(128)})} + + fused_computation.70.clone.clone.clone { + param_0.10545 = s8[1,32,2048]{2,1,0:T(8,128)(4,1)S(1)} parameter(0) + ROOT bitcast.7213 = s8[32,2048,1]{1,0,2:T(8,128)(4,1)} bitcast(param_0.10545) + } + + fused_computation.68.clone.clone.clone { + param_1.12561 = s8[1,2048,1,4096]{3,1,2,0:T(8,128)(4,1)S(1)} parameter(1) + constant.26622 = s8[]{:T(512)} constant(0) + pad.3783 = s8[1,2048,2,4096]{3,1,2,0:T(8,128)(4,1)} pad(param_1.12561, constant.26622), padding=0_0x0_0x0_1x0_0 + constant.26621 = s32[]{:T(128)} constant(0) + param_2.10214 = s32[]{:T(128)S(6)} parameter(2) + dynamic-slice.5474 = s8[1,2048,2,256]{3,1,2,0:T(8,128)(4,1)} dynamic-slice(pad.3783, constant.26621, constant.26621, constant.26621, param_2.10214), dynamic_slice_sizes={1,2048,2,256} + pad.3782 = s8[1,2048,2,4096]{3,1,2,0:T(8,128)(4,1)} pad(param_1.12561, constant.26622), padding=0_0x0_0x1_0x0_0 + param_0.10544 = s32[]{:T(128)S(6)} parameter(0) + dynamic-slice.5473 = s8[1,2048,2,256]{3,1,2,0:T(8,128)(4,1)} dynamic-slice(pad.3782, constant.26621, constant.26621, constant.26621, param_0.10544), dynamic_slice_sizes={1,2048,2,256} + add.10207 = s8[1,2048,2,256]{3,1,2,0:T(8,128)(4,1)} add(dynamic-slice.5474, dynamic-slice.5473) + ROOT bitcast.7212 = s8[2048,2,256]{2,0,1:T(8,128)(4,1)} bitcast(add.10207) + } + + fused_computation.71.clone { + param_3.7588 = s8[1,32,2048]{2,1,0:T(8,128)(4,1)S(1)} parameter(3) + fusion.4288 = s8[32,2048,1]{1,0,2:T(8,128)(4,1)} fusion(param_3.7588), kind=kLoop, calls=fused_computation.70.clone.clone.clone + param_0.10546 = s32[]{:T(128)S(6)} parameter(0) + param_1.12562 = s8[1,2048,1,4096]{3,1,2,0:T(8,128)(4,1)S(1)} parameter(1) + param_2.10215 = s32[]{:T(128)S(6)} parameter(2) + fusion.4287 = s8[2048,2,256]{2,0,1:T(8,128)(4,1)} fusion(param_0.10546, param_1.12562, param_2.10215), kind=kLoop, calls=fused_computation.68.clone.clone.clone + convolution.802 = s32[32,2,256]{2,0,1:T(8,128)} convolution(fusion.4288, fusion.4287), window={size=2 pad=1_1 rhs_reversal=1}, dim_labels=bf0_i0o->b0f + ROOT bitcast.7214 = s32[1,32,2,256]{3,1,2,0:T(8,128)S(1)} bitcast(convolution.802) + } + + fused_computation.76.clone { + param_0.10547 = s32[1,32,256]{2,1,0:T(8,128)S(1)} parameter(0) + param_1.12563 = s32[1,32,2,256]{3,1,2,0:T(8,128)S(1)} parameter(1) + slice.12606 = s32[1,32,1,256]{3,1,2,0:T(8,128)} slice(param_1.12563), slice={[0:1], [0:32], [1:2], [0:256]} + bitcast.7215 = s32[1,32,256]{2,1,0:T(8,128)} bitcast(slice.12606) + add.10208 = s32[1,32,256]{2,1,0:T(8,128)S(1)} add(param_0.10547, bitcast.7215) + param_2.10216 = s32[1,32,256]{2,1,0:T(8,128)S(1)} parameter(2) + slice.12000.clone.2 = s32[1,32,1,256]{3,1,2,0:T(8,128)} slice(param_1.12563), slice={[0:1], [0:32], [0:1], [0:256]} + bitcast.1776.clone.2 = s32[1,32,256]{2,1,0:T(8,128)} bitcast(slice.12000.clone.2) + add.6006.clone.2 = s32[1,32,256]{2,1,0:T(8,128)S(1)} add(param_2.10216, bitcast.1776.clone.2) + ROOT tuple.2892 = (s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}) tuple(add.10208, add.6006.clone.2) + } + + fused_computation.69.clone.clone.clone { + param_0.10549 = s8[1,32,2048]{2,1,0:T(8,128)(4,1)S(1)} parameter(0) + ROOT bitcast.7217 = s8[32,2048,1]{1,0,2:T(8,128)(4,1)} bitcast(param_0.10549) + } + + fused_computation.66.clone.clone.clone { + param_1.12564 = s8[1,2048,1,4096]{3,1,2,0:T(8,128)(4,1)S(1)} parameter(1) + constant.26625 = s8[]{:T(512)} constant(0) + pad.3785 = s8[1,2048,2,4096]{3,1,2,0:T(8,128)(4,1)} pad(param_1.12564, constant.26625), padding=0_0x0_0x0_1x0_0 + constant.26624 = s32[]{:T(128)} constant(0) + param_2.10217 = s32[]{:T(128)S(6)} parameter(2) + dynamic-slice.5476 = s8[1,2048,2,256]{3,1,2,0:T(8,128)(4,1)} dynamic-slice(pad.3785, constant.26624, constant.26624, constant.26624, param_2.10217), dynamic_slice_sizes={1,2048,2,256} + pad.3784 = s8[1,2048,2,4096]{3,1,2,0:T(8,128)(4,1)} pad(param_1.12564, constant.26625), padding=0_0x0_0x1_0x0_0 + param_0.10548 = s32[]{:T(128)S(6)} parameter(0) + dynamic-slice.5475 = s8[1,2048,2,256]{3,1,2,0:T(8,128)(4,1)} dynamic-slice(pad.3784, constant.26624, constant.26624, constant.26624, param_0.10548), dynamic_slice_sizes={1,2048,2,256} + add.10212 = s8[1,2048,2,256]{3,1,2,0:T(8,128)(4,1)} add(dynamic-slice.5476, dynamic-slice.5475) + ROOT bitcast.7216 = s8[2048,2,256]{2,0,1:T(8,128)(4,1)} bitcast(add.10212) + } + + fused_computation.72.clone { + param_3.7589 = s8[1,32,2048]{2,1,0:T(8,128)(4,1)S(1)} parameter(3) + fusion.4292 = s8[32,2048,1]{1,0,2:T(8,128)(4,1)} fusion(param_3.7589), kind=kLoop, calls=fused_computation.69.clone.clone.clone + param_0.10550 = s32[]{:T(128)S(6)} parameter(0) + param_1.12565 = s8[1,2048,1,4096]{3,1,2,0:T(8,128)(4,1)S(1)} parameter(1) + param_2.10218 = s32[]{:T(128)S(6)} parameter(2) + fusion.4291 = s8[2048,2,256]{2,0,1:T(8,128)(4,1)} fusion(param_0.10550, param_1.12565, param_2.10218), kind=kLoop, calls=fused_computation.66.clone.clone.clone + convolution.803 = s32[32,2,256]{2,0,1:T(8,128)} convolution(fusion.4292, fusion.4291), window={size=2 pad=1_1 rhs_reversal=1}, dim_labels=bf0_i0o->b0f + ROOT bitcast.7218 = s32[1,32,2,256]{3,1,2,0:T(8,128)S(1)} bitcast(convolution.803) + } + + fused_computation.74.clone { + param_0.10551 = s32[1,32,256]{2,1,0:T(8,128)S(1)} parameter(0) + param_1.12566 = s32[1,32,2,256]{3,1,2,0:T(8,128)S(1)} parameter(1) + slice.12607 = s32[1,32,1,256]{3,1,2,0:T(8,128)} slice(param_1.12566), slice={[0:1], [0:32], [1:2], [0:256]} + bitcast.7219 = s32[1,32,256]{2,1,0:T(8,128)} bitcast(slice.12607) + add.10213 = s32[1,32,256]{2,1,0:T(8,128)S(1)} add(param_0.10551, bitcast.7219) + param_2.10219 = s32[1,32,256]{2,1,0:T(8,128)S(1)} parameter(2) + slice.11997.clone.2 = s32[1,32,1,256]{3,1,2,0:T(8,128)} slice(param_1.12566), slice={[0:1], [0:32], [0:1], [0:256]} + bitcast.1773.clone.2 = s32[1,32,256]{2,1,0:T(8,128)} bitcast(slice.11997.clone.2) + add.6005.clone.2 = s32[1,32,256]{2,1,0:T(8,128)S(1)} add(param_2.10219, bitcast.1773.clone.2) + ROOT tuple.2893 = (s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}) tuple(add.10213, add.6005.clone.2) + } + + wide.windowed_dot_general_body { + wide_param.41 = (s8[1,32,2048]{2,1,0:T(8,128)(4,1)S(1)}, s8[1,2048,4096]{2,1,0:T(8,128)(4,1)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, u32[]{:T(128)}, /*index=5*/u32[]{:T(128)}, u32[256]{0:T(256)}, u32[]{:T(128)}, u32[]{:T(128)}, s32[]{:T(128)}, /*index=10*/u32[]{:T(128)}, u32[]{:T(128)}, u32[]{:T(128)}) parameter(0) + get-tuple-element.29000 = s8[1,32,2048]{2,1,0:T(8,128)(4,1)S(1)} get-tuple-element(wide_param.41), index=0 + get-tuple-element.29001 = s8[1,2048,4096]{2,1,0:T(8,128)(4,1)S(1)} get-tuple-element(wide_param.41), index=1 + get-tuple-element.28990 = s32[1,32,256]{2,1,0:T(8,128)S(1)} get-tuple-element(wide_param.41), index=3 + collective-permute-start = (s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, u32[]{:S(2)}, u32[]{:S(2)}) collective-permute-start(get-tuple-element.28990), channel_id=18, source_target_pairs={{0,1},{1,2},{2,3},{3,4},{4,5},{5,6},{6,7},{7,8},{8,9},{9,10},{10,11},{11,12},{12,13},{13,14},{14,15},{15,0},{16,17},{17,18},{18,19},{19,20},{20,21},{21,22},{22,23},{23,24},{24,25},{25,26},{26,27},{27,28},{28,29},{29,30},{30,31},{31,16},{32,33},{33,34},{34,35},{35,36},{36,37},{37,38},{38,39},{39,40},{40,41},{41,42},{42,43},{43,44},{44,45},{45,46},{46,47},{47,32},{48,49},{49,50},{50,51},{51,52},{52,53},{53,54},{54,55},{55,56},{56,57},{57,58},{58,59},{59,60},{60,61},{61,62},{62,63},{63,48},{64,65},{65,66},{66,67},{67,68},{68,69},{69,70},{70,71},{71,72},{72,73},{73,74},{74,75},{75,76},{76,77},{77,78},{78,79},{79,64},{80,81},{81,82},{82,83},{83,84},{84,85},{85,86},{86,87},{87,88},{88,89},{89,90},{90,91},{91,92},{92,93},{93,94},{94,95},{95,80},{96,97},{97,98},{98,99},{99,100},{100,101},{101,102},{102,103},{103,104},{104,105},{105,106},{106,107},{107,108},{108,109},{109,110},{110,111},{111,96},{112,113},{113,114},{114,115},{115,116},{116,117},{117,118},{118,119},{119,120},{120,121},{121,122},{122,123},{123,124},{124,125},{125,126},{126,127},{127,112},{128,129},{129,130},{130,131},{131,132},{132,133},{133,134},{134,135},{135,136},{136,137},{137,138},{138,139},{139,140},{140,141},{141,142},{142,143},{143,128},{144,145},{145,146},{146,147},{147,148},{148,149},{149,150},{150,151},{151,152},{152,153},{153,154},{154,155},{155,156},{156,157},{157,158},{158,159},{159,144},{160,161},{161,162},{162,163},{163,164},{164,165},{165,166},{166,167},{167,168},{168,169},{169,170},{170,171},{171,172},{172,173},{173,174},{174,175},{175,160},{176,177},{177,178},{178,179},{179,180},{180,181},{181,182},{182,183},{183,184},{184,185},{185,186},{186,187},{187,188},{188,189},{189,190},{190,191},{191,176},{192,193},{193,194},{194,195},{195,196},{196,197},{197,198},{198,199},{199,200},{200,201},{201,202},{202,203},{203,204},{204,205},{205,206},{206,207},{207,192},{208,209},{209,210},{210,211},{211,212},{212,213},{213,214},{214,215},{215,216},{216,217},{217,218},{218,219},{219,220},{220,221},{221,222},{222,223},{223,208},{224,225},{225,226},{226,227},{227,228},{228,229},{229,230},{230,231},{231,232},{232,233},{233,234},{234,235},{235,236},{236,237},{237,238},{238,239},{239,224},{240,241},{241,242},{242,243},{243,244},{244,245},{245,246},{246,247},{247,248},{248,249},{249,250},{250,251},{251,252},{252,253},{253,254},{254,255},{255,240}} + collective-permute-done = s32[1,32,256]{2,1,0:T(8,128)S(1)} collective-permute-done(collective-permute-start) + get-tuple-element.29005 = u32[]{:T(128)} get-tuple-element(wide_param.41), index=5 + get-tuple-element.29006 = u32[256]{0:T(256)} get-tuple-element(wide_param.41), index=6 + partition-id.101 = u32[] partition-id() + dynamic-slice.5472 = u32[1]{0:T(128)} dynamic-slice(get-tuple-element.29006, partition-id.101), dynamic_slice_sizes={1} + bitcast.7210 = u32[]{:T(128)} bitcast(dynamic-slice.5472) + get-tuple-element.29007 = u32[]{:T(128)} get-tuple-element(wide_param.41), index=7 + add.10204 = u32[]{:T(128)S(6)} add(bitcast.7210, get-tuple-element.29007) + get-tuple-element.28991 = u32[]{:T(128)} get-tuple-element(wide_param.41), index=4 + subtract.2863 = u32[]{:T(128)S(6)} subtract(add.10204, get-tuple-element.28991) + get-tuple-element.29008 = u32[]{:T(128)} get-tuple-element(wide_param.41), index=8 + and.400 = u32[]{:T(128)S(6)} and(subtract.2863, get-tuple-element.29008) + clamp.1712 = u32[]{:T(128)S(6)} clamp(get-tuple-element.29005, and.400, get-tuple-element.29008) + convert.8615 = s32[]{:T(128)S(6)} convert(clamp.1712) + get-tuple-element.29009 = s32[]{:T(128)} get-tuple-element(wide_param.41), index=9 + multiply.14830 = s32[]{:T(128)S(6)} multiply(convert.8615, get-tuple-element.29009) + bitcast.8823 = s8[1,2048,1,4096]{3,1,2,0:T(8,128)(4,1)S(1)} bitcast(get-tuple-element.29001) + add.10205 = u32[]{:T(128)S(6)} add(get-tuple-element.28991, bitcast.7210) + get-tuple-element.29010 = u32[]{:T(128)} get-tuple-element(wide_param.41), index=10 + add.10206 = u32[]{:T(128)S(6)} add(add.10205, get-tuple-element.29010) + and.401 = u32[]{:T(128)S(6)} and(add.10206, get-tuple-element.29008) + clamp.1713 = u32[]{:T(128)S(6)} clamp(get-tuple-element.29005, and.401, get-tuple-element.29008) + convert.8616 = s32[]{:T(128)S(6)} convert(clamp.1713) + multiply.14831 = s32[]{:T(128)S(6)} multiply(convert.8616, get-tuple-element.29009) + fusion.4289 = s32[1,32,2,256]{3,1,2,0:T(8,128)S(1)} fusion(multiply.14830, bitcast.8823, multiply.14831, get-tuple-element.29000), kind=kOutput, calls=fused_computation.71.clone + get-tuple-element.28989 = s32[1,32,256]{2,1,0:T(8,128)S(1)} get-tuple-element(wide_param.41), index=2 + collective-permute-start.1 = (s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, u32[]{:S(2)}, u32[]{:S(2)}) collective-permute-start(get-tuple-element.28989), channel_id=17, source_target_pairs={{0,15},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6},{8,7},{9,8},{10,9},{11,10},{12,11},{13,12},{14,13},{15,14},{16,31},{17,16},{18,17},{19,18},{20,19},{21,20},{22,21},{23,22},{24,23},{25,24},{26,25},{27,26},{28,27},{29,28},{30,29},{31,30},{32,47},{33,32},{34,33},{35,34},{36,35},{37,36},{38,37},{39,38},{40,39},{41,40},{42,41},{43,42},{44,43},{45,44},{46,45},{47,46},{48,63},{49,48},{50,49},{51,50},{52,51},{53,52},{54,53},{55,54},{56,55},{57,56},{58,57},{59,58},{60,59},{61,60},{62,61},{63,62},{64,79},{65,64},{66,65},{67,66},{68,67},{69,68},{70,69},{71,70},{72,71},{73,72},{74,73},{75,74},{76,75},{77,76},{78,77},{79,78},{80,95},{81,80},{82,81},{83,82},{84,83},{85,84},{86,85},{87,86},{88,87},{89,88},{90,89},{91,90},{92,91},{93,92},{94,93},{95,94},{96,111},{97,96},{98,97},{99,98},{100,99},{101,100},{102,101},{103,102},{104,103},{105,104},{106,105},{107,106},{108,107},{109,108},{110,109},{111,110},{112,127},{113,112},{114,113},{115,114},{116,115},{117,116},{118,117},{119,118},{120,119},{121,120},{122,121},{123,122},{124,123},{125,124},{126,125},{127,126},{128,143},{129,128},{130,129},{131,130},{132,131},{133,132},{134,133},{135,134},{136,135},{137,136},{138,137},{139,138},{140,139},{141,140},{142,141},{143,142},{144,159},{145,144},{146,145},{147,146},{148,147},{149,148},{150,149},{151,150},{152,151},{153,152},{154,153},{155,154},{156,155},{157,156},{158,157},{159,158},{160,175},{161,160},{162,161},{163,162},{164,163},{165,164},{166,165},{167,166},{168,167},{169,168},{170,169},{171,170},{172,171},{173,172},{174,173},{175,174},{176,191},{177,176},{178,177},{179,178},{180,179},{181,180},{182,181},{183,182},{184,183},{185,184},{186,185},{187,186},{188,187},{189,188},{190,189},{191,190},{192,207},{193,192},{194,193},{195,194},{196,195},{197,196},{198,197},{199,198},{200,199},{201,200},{202,201},{203,202},{204,203},{205,204},{206,205},{207,206},{208,223},{209,208},{210,209},{211,210},{212,211},{213,212},{214,213},{215,214},{216,215},{217,216},{218,217},{219,218},{220,219},{221,220},{222,221},{223,222},{224,239},{225,224},{226,225},{227,226},{228,227},{229,228},{230,229},{231,230},{232,231},{233,232},{234,233},{235,234},{236,235},{237,236},{238,237},{239,238},{240,255},{241,240},{242,241},{243,242},{244,243},{245,244},{246,245},{247,246},{248,247},{249,248},{250,249},{251,250},{252,251},{253,252},{254,253},{255,254}} + collective-permute-done.1 = s32[1,32,256]{2,1,0:T(8,128)S(1)} collective-permute-done(collective-permute-start.1) + fusion.4290 = (s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}) fusion(collective-permute-done, fusion.4289, collective-permute-done.1), kind=kLoop, calls=fused_computation.76.clone + get-tuple-element.22079 = s32[1,32,256]{2,1,0:T(8,128)S(1)} get-tuple-element(fusion.4290), index=0 + collective-permute-start.2 = (s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, u32[]{:S(2)}, u32[]{:S(2)}) collective-permute-start(get-tuple-element.22079), channel_id=20, source_target_pairs={{0,1},{1,2},{2,3},{3,4},{4,5},{5,6},{6,7},{7,8},{8,9},{9,10},{10,11},{11,12},{12,13},{13,14},{14,15},{15,0},{16,17},{17,18},{18,19},{19,20},{20,21},{21,22},{22,23},{23,24},{24,25},{25,26},{26,27},{27,28},{28,29},{29,30},{30,31},{31,16},{32,33},{33,34},{34,35},{35,36},{36,37},{37,38},{38,39},{39,40},{40,41},{41,42},{42,43},{43,44},{44,45},{45,46},{46,47},{47,32},{48,49},{49,50},{50,51},{51,52},{52,53},{53,54},{54,55},{55,56},{56,57},{57,58},{58,59},{59,60},{60,61},{61,62},{62,63},{63,48},{64,65},{65,66},{66,67},{67,68},{68,69},{69,70},{70,71},{71,72},{72,73},{73,74},{74,75},{75,76},{76,77},{77,78},{78,79},{79,64},{80,81},{81,82},{82,83},{83,84},{84,85},{85,86},{86,87},{87,88},{88,89},{89,90},{90,91},{91,92},{92,93},{93,94},{94,95},{95,80},{96,97},{97,98},{98,99},{99,100},{100,101},{101,102},{102,103},{103,104},{104,105},{105,106},{106,107},{107,108},{108,109},{109,110},{110,111},{111,96},{112,113},{113,114},{114,115},{115,116},{116,117},{117,118},{118,119},{119,120},{120,121},{121,122},{122,123},{123,124},{124,125},{125,126},{126,127},{127,112},{128,129},{129,130},{130,131},{131,132},{132,133},{133,134},{134,135},{135,136},{136,137},{137,138},{138,139},{139,140},{140,141},{141,142},{142,143},{143,128},{144,145},{145,146},{146,147},{147,148},{148,149},{149,150},{150,151},{151,152},{152,153},{153,154},{154,155},{155,156},{156,157},{157,158},{158,159},{159,144},{160,161},{161,162},{162,163},{163,164},{164,165},{165,166},{166,167},{167,168},{168,169},{169,170},{170,171},{171,172},{172,173},{173,174},{174,175},{175,160},{176,177},{177,178},{178,179},{179,180},{180,181},{181,182},{182,183},{183,184},{184,185},{185,186},{186,187},{187,188},{188,189},{189,190},{190,191},{191,176},{192,193},{193,194},{194,195},{195,196},{196,197},{197,198},{198,199},{199,200},{200,201},{201,202},{202,203},{203,204},{204,205},{205,206},{206,207},{207,192},{208,209},{209,210},{210,211},{211,212},{212,213},{213,214},{214,215},{215,216},{216,217},{217,218},{218,219},{219,220},{220,221},{221,222},{222,223},{223,208},{224,225},{225,226},{226,227},{227,228},{228,229},{229,230},{230,231},{231,232},{232,233},{233,234},{234,235},{235,236},{236,237},{237,238},{238,239},{239,224},{240,241},{241,242},{242,243},{243,244},{244,245},{245,246},{246,247},{247,248},{248,249},{249,250},{250,251},{251,252},{252,253},{253,254},{254,255},{255,240}} + collective-permute-done.2 = s32[1,32,256]{2,1,0:T(8,128)S(1)} collective-permute-done(collective-permute-start.2) + get-tuple-element.29011 = u32[]{:T(128)} get-tuple-element(wide_param.41), index=11 + add.10209 = u32[]{:T(128)S(6)} add(get-tuple-element.28991, get-tuple-element.29011) + subtract.2864 = u32[]{:T(128)S(6)} subtract(add.10204, add.10209) + and.402 = u32[]{:T(128)S(6)} and(subtract.2864, get-tuple-element.29008) + clamp.1714 = u32[]{:T(128)S(6)} clamp(get-tuple-element.29005, and.402, get-tuple-element.29008) + convert.8617 = s32[]{:T(128)S(6)} convert(clamp.1714) + multiply.14832 = s32[]{:T(128)S(6)} multiply(convert.8617, get-tuple-element.29009) + bitcast.8824 = s8[1,2048,1,4096]{3,1,2,0:T(8,128)(4,1)S(1)} bitcast(get-tuple-element.29001) + add.10210 = u32[]{:T(128)S(6)} add(add.10209, bitcast.7210) + add.10211 = u32[]{:T(128)S(6)} add(add.10210, get-tuple-element.29010) + and.403 = u32[]{:T(128)S(6)} and(add.10211, get-tuple-element.29008) + clamp.1715 = u32[]{:T(128)S(6)} clamp(get-tuple-element.29005, and.403, get-tuple-element.29008) + convert.8618 = s32[]{:T(128)S(6)} convert(clamp.1715) + multiply.14833 = s32[]{:T(128)S(6)} multiply(convert.8618, get-tuple-element.29009) + fusion.4293 = s32[1,32,2,256]{3,1,2,0:T(8,128)S(1)} fusion(multiply.14832, bitcast.8824, multiply.14833, get-tuple-element.29000), kind=kOutput, calls=fused_computation.72.clone + get-tuple-element.22080 = s32[1,32,256]{2,1,0:T(8,128)S(1)} get-tuple-element(fusion.4290), index=1 + collective-permute-start.3 = (s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, u32[]{:S(2)}, u32[]{:S(2)}) collective-permute-start(get-tuple-element.22080), channel_id=19, source_target_pairs={{0,15},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6},{8,7},{9,8},{10,9},{11,10},{12,11},{13,12},{14,13},{15,14},{16,31},{17,16},{18,17},{19,18},{20,19},{21,20},{22,21},{23,22},{24,23},{25,24},{26,25},{27,26},{28,27},{29,28},{30,29},{31,30},{32,47},{33,32},{34,33},{35,34},{36,35},{37,36},{38,37},{39,38},{40,39},{41,40},{42,41},{43,42},{44,43},{45,44},{46,45},{47,46},{48,63},{49,48},{50,49},{51,50},{52,51},{53,52},{54,53},{55,54},{56,55},{57,56},{58,57},{59,58},{60,59},{61,60},{62,61},{63,62},{64,79},{65,64},{66,65},{67,66},{68,67},{69,68},{70,69},{71,70},{72,71},{73,72},{74,73},{75,74},{76,75},{77,76},{78,77},{79,78},{80,95},{81,80},{82,81},{83,82},{84,83},{85,84},{86,85},{87,86},{88,87},{89,88},{90,89},{91,90},{92,91},{93,92},{94,93},{95,94},{96,111},{97,96},{98,97},{99,98},{100,99},{101,100},{102,101},{103,102},{104,103},{105,104},{106,105},{107,106},{108,107},{109,108},{110,109},{111,110},{112,127},{113,112},{114,113},{115,114},{116,115},{117,116},{118,117},{119,118},{120,119},{121,120},{122,121},{123,122},{124,123},{125,124},{126,125},{127,126},{128,143},{129,128},{130,129},{131,130},{132,131},{133,132},{134,133},{135,134},{136,135},{137,136},{138,137},{139,138},{140,139},{141,140},{142,141},{143,142},{144,159},{145,144},{146,145},{147,146},{148,147},{149,148},{150,149},{151,150},{152,151},{153,152},{154,153},{155,154},{156,155},{157,156},{158,157},{159,158},{160,175},{161,160},{162,161},{163,162},{164,163},{165,164},{166,165},{167,166},{168,167},{169,168},{170,169},{171,170},{172,171},{173,172},{174,173},{175,174},{176,191},{177,176},{178,177},{179,178},{180,179},{181,180},{182,181},{183,182},{184,183},{185,184},{186,185},{187,186},{188,187},{189,188},{190,189},{191,190},{192,207},{193,192},{194,193},{195,194},{196,195},{197,196},{198,197},{199,198},{200,199},{201,200},{202,201},{203,202},{204,203},{205,204},{206,205},{207,206},{208,223},{209,208},{210,209},{211,210},{212,211},{213,212},{214,213},{215,214},{216,215},{217,216},{218,217},{219,218},{220,219},{221,220},{222,221},{223,222},{224,239},{225,224},{226,225},{227,226},{228,227},{229,228},{230,229},{231,230},{232,231},{233,232},{234,233},{235,234},{236,235},{237,236},{238,237},{239,238},{240,255},{241,240},{242,241},{243,242},{244,243},{245,244},{246,245},{247,246},{248,247},{249,248},{250,249},{251,250},{252,251},{253,252},{254,253},{255,254}} + collective-permute-done.3 = s32[1,32,256]{2,1,0:T(8,128)S(1)} collective-permute-done(collective-permute-start.3) + fusion.4294 = (s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}) fusion(collective-permute-done.2, fusion.4293, collective-permute-done.3), kind=kLoop, calls=fused_computation.74.clone + get-tuple-element.29002 = s32[1,32,256]{2,1,0:T(8,128)S(1)} get-tuple-element(fusion.4294), index=1 + get-tuple-element.29003 = s32[1,32,256]{2,1,0:T(8,128)S(1)} get-tuple-element(fusion.4294), index=0 + get-tuple-element.29012 = u32[]{:T(128)} get-tuple-element(wide_param.41), index=12 + constant.28871 = u32[]{:T(128)} constant(2) + add.10214 = u32[]{:T(128)} add(get-tuple-element.28991, constant.28871) + ROOT tuple.3341 = (s8[1,32,2048]{2,1,0:T(8,128)(4,1)S(1)}, s8[1,2048,4096]{2,1,0:T(8,128)(4,1)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, u32[]{:T(128)}, /*index=5*/u32[]{:T(128)}, u32[256]{0:T(256)}, u32[]{:T(128)}, u32[]{:T(128)}, s32[]{:T(128)}, /*index=10*/u32[]{:T(128)}, u32[]{:T(128)}, u32[]{:T(128)}) tuple(get-tuple-element.29000, get-tuple-element.29001, get-tuple-element.29002, get-tuple-element.29003, add.10214, get-tuple-element.29005, get-tuple-element.29006, get-tuple-element.29007, get-tuple-element.29008, get-tuple-element.29009, get-tuple-element.29010, get-tuple-element.29011, get-tuple-element.29012) + } + + wide.windowed_dot_general_cond { + wide_param.40 = (s8[1,32,2048]{2,1,0:T(8,128)(4,1)S(1)}, s8[1,2048,4096]{2,1,0:T(8,128)(4,1)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, u32[]{:T(128)}, /*index=5*/u32[]{:T(128)}, u32[256]{0:T(256)}, u32[]{:T(128)}, u32[]{:T(128)}, s32[]{:T(128)}, /*index=10*/u32[]{:T(128)}, u32[]{:T(128)}, u32[]{:T(128)}) parameter(0) + get-tuple-element.22055 = u32[]{:T(128)} get-tuple-element(wide_param.40), index=4 + constant.26614 = u32[]{:T(128)} constant(8) + ROOT compare.2683 = pred[]{:T(512)} compare(get-tuple-element.22055, constant.26614), direction=LT + } + + ENTRY test { + fusion.4456 = s8[1,32,2048]{2,1,0:T(8,128)(4,1)S(1)} parameter(0) + fusion.4457 = s8[1,2048,4096]{2,1,0:T(8,128)(4,1)S(1)} parameter(1) + broadcast.26239 = s32[1,32,256]{2,1,0:T(8,128)S(1)} parameter(2) + broadcast.26239.clone = s32[1,32,256]{2,1,0:T(8,128)S(1)} parameter(3) + constant.28863 = u32[]{:T(128)} constant(0) + constant.28864 = u32[]{:T(128)} constant(0) + constant.28865 = u32[256]{0:T(256)} constant({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255}) + constant.28866 = u32[]{:T(128)} constant(8) + constant.28867 = u32[]{:T(128)} constant(15) + constant.28868 = s32[]{:T(128)} constant(256) + constant.28869 = u32[]{:T(128)} constant(9) + constant.28870 = u32[]{:T(128)} constant(1) + constant.28871 = u32[]{:T(128)} constant(2) + tuple.3339 = (s8[1,32,2048]{2,1,0:T(8,128)(4,1)S(1)}, s8[1,2048,4096]{2,1,0:T(8,128)(4,1)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, u32[]{:T(128)}, /*index=5*/u32[]{:T(128)}, u32[256]{0:T(256)}, u32[]{:T(128)}, u32[]{:T(128)}, s32[]{:T(128)}, /*index=10*/u32[]{:T(128)}, u32[]{:T(128)}, u32[]{:T(128)}) tuple(fusion.4456, fusion.4457, broadcast.26239, broadcast.26239.clone, constant.28863, constant.28864, constant.28865, constant.28866, constant.28867, constant.28868, constant.28869, constant.28870, constant.28871) + ROOT while.636 = (s8[1,32,2048]{2,1,0:T(8,128)(4,1)S(1)}, s8[1,2048,4096]{2,1,0:T(8,128)(4,1)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, u32[]{:T(128)}, /*index=5*/u32[]{:T(128)}, u32[256]{0:T(256)}, u32[]{:T(128)}, u32[]{:T(128)}, s32[]{:T(128)}, /*index=10*/u32[]{:T(128)}, u32[]{:T(128)}, u32[]{:T(128)}) while(tuple.3339), condition=wide.windowed_dot_general_cond, body=wide.windowed_dot_general_body + })"; + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); + + int64_t fusion_instr_count = absl::c_count_if( + module->GetComputationWithName("wide.windowed_dot_general_body") + ->instructions(), + [](const HloInstruction* instr) { + return (instr->IsLoopFusion() || instr->IsOutputFusion()); + }); + + // Fully unroll the specific loop (trip count is 4) + EXPECT_TRUE( + WhileLoopUnroller(/*unroll_factor=*/-1).Run(module.get()).value()); + + int64_t fusion_instr_count_after_unroll = absl::c_count_if( + module->entry_computation()->instructions(), + [](const HloInstruction* instr) { + return (instr->IsLoopFusion() || instr->IsOutputFusion()); + }); + + // The total number of fusions in the unrolled version in the entry must be + // equal to loop_trip_count * fusion_instr_count + EXPECT_EQ(fusion_instr_count * 4, fusion_instr_count_after_unroll); } } // namespace diff --git a/third_party/xla/xla/service/xla_aot_compile_cpu_test.cc b/third_party/xla/xla/service/xla_aot_compile_cpu_test.cc index 2e8a82707b485d..00dc0743f9670e 100644 --- a/third_party/xla/xla/service/xla_aot_compile_cpu_test.cc +++ b/third_party/xla/xla/service/xla_aot_compile_cpu_test.cc @@ -69,7 +69,9 @@ TEST(XlaCompileTest, LoadCpuExecutable) { executable_run_options.set_allocator(client->backend().memory_allocator()); TF_ASSERT_OK_AND_ASSIGN( ScopedShapedBuffer result, - local_executable->Run({&array1, &array2}, executable_run_options)); + local_executable->Run( + absl::Span{&array1, &array2}, + executable_run_options)); TF_ASSERT_OK_AND_ASSIGN(Literal output, client->ShapedBufferToLiteral(result)); diff --git a/third_party/xla/xla/service/xla_aot_compile_test_autotune_results.prototxt b/third_party/xla/xla/service/xla_aot_compile_test_autotune_results.prototxt index 6b413161dc4860..0f339aa4974f0a 100644 --- a/third_party/xla/xla/service/xla_aot_compile_test_autotune_results.prototxt +++ b/third_party/xla/xla/service/xla_aot_compile_test_autotune_results.prototxt @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -version: 2 +version: 3 results { device: "sm_6.0 with 17071734784B RAM, 56 cores, 1480500KHz clock, 715000KHz mem clock, 4194304B L2$" hlo: "(f32[3,3]{1,0}, s8[72]{0}) custom-call(f32[3,3]{1,0}, f32[3,3]{1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[],\"gemm_backend_config\":{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\",\"lhs_stride\":\"9\",\"rhs_stride\":\"9\",\"grad_x\":false,\"grad_y\":false}}" diff --git a/third_party/xla/xla/shape.cc b/third_party/xla/xla/shape.cc index 617465149e756c..f562e2f5958a12 100644 --- a/third_party/xla/xla/shape.cc +++ b/third_party/xla/xla/shape.cc @@ -116,23 +116,17 @@ std::string Shape::ToString(bool print_layout) const { } bool Shape::IsInteger() const { - if (primitive_util::IsIntegralType(element_type())) { - return true; - } if (IsTuple()) { - return absl::c_any_of(tuple_shapes_, + return absl::c_all_of(tuple_shapes_, [](const Shape& s) { return s.IsInteger(); }); } - return false; + return primitive_util::IsIntegralType(element_type()); } bool Shape::is_static() const { if (IsTuple()) { - for (const Shape& subshape : tuple_shapes_) { - if (!subshape.is_static()) { - return false; - } - } + return absl::c_all_of(tuple_shapes_, + [](const Shape& s) { return s.is_static(); }); } return !absl::c_any_of(dynamic_dimensions_, [](bool b) { return b; }); } diff --git a/third_party/xla/xla/shape.h b/third_party/xla/xla/shape.h index 2cdf308e7872f9..5361d8dfc7d1f8 100644 --- a/third_party/xla/xla/shape.h +++ b/third_party/xla/xla/shape.h @@ -130,6 +130,11 @@ class Shape { return dynamic_dimensions_[dimension]; } + // Returns true if the given dimension is statically-sized. + bool is_static_dimension(int dimension) const { + return !dynamic_dimensions_[dimension]; + } + // Sets whether or not the given dimension is dynamically-sized. void set_dynamic_dimension(int dimension, bool is_dynamic) { dynamic_dimensions_[dimension] = is_dynamic; diff --git a/third_party/xla/xla/shape_test.cc b/third_party/xla/xla/shape_test.cc index d573fa3d5c5e9a..3d23b421718ed8 100644 --- a/third_party/xla/xla/shape_test.cc +++ b/third_party/xla/xla/shape_test.cc @@ -108,6 +108,26 @@ TEST_F(ShapeTest, EqualityTest) { ShapeUtil::MakeShapeWithDenseLayout(F32, {23, 44}, {1, 0})); } +TEST_F(ShapeTest, IsInteger) { + EXPECT_FALSE(opaque_.IsInteger()); + EXPECT_FALSE(token_.IsInteger()); + EXPECT_TRUE(matrix_.IsInteger()); + EXPECT_FALSE(tuple_.IsInteger()); + EXPECT_FALSE(nested_tuple_.IsInteger()); + + Shape u32_shape = ShapeUtil::MakeShape(U32, {1}); + EXPECT_TRUE(u32_shape.IsInteger()); + + Shape f32_shape = ShapeUtil::MakeShape(F32, {1}); + EXPECT_FALSE(f32_shape.IsInteger()); + + Shape integer_tuple = ShapeUtil::MakeTupleShape({u32_shape, u32_shape}); + EXPECT_TRUE(integer_tuple.IsInteger()); + + Shape mixed_type_tuple = ShapeUtil::MakeTupleShape({u32_shape, f32_shape}); + EXPECT_FALSE(mixed_type_tuple.IsInteger()); +} + TEST_F(ShapeTest, IsStatic) { EXPECT_TRUE(opaque_.is_static()); EXPECT_TRUE(token_.is_static()); @@ -165,6 +185,15 @@ TEST_F(ShapeTest, IsDynamicDimension) { EXPECT_FALSE(unbounded_.is_dynamic_dimension(1)); } +TEST_F(ShapeTest, IsStaticDimension) { + Shape dynamic_matrix = matrix_; + dynamic_matrix.set_dynamic_dimension(1, true); + EXPECT_TRUE(dynamic_matrix.is_static_dimension(0)); + EXPECT_FALSE(dynamic_matrix.is_static_dimension(1)); + EXPECT_FALSE(unbounded_.is_static_dimension(0)); + EXPECT_TRUE(unbounded_.is_static_dimension(1)); +} + TEST_F(ShapeTest, ProgramShapeToFromProto) { ProgramShape program_shape; *program_shape.add_parameters() = ShapeUtil::MakeShape(F32, {1, 2, 3}); diff --git a/third_party/xla/xla/shape_util.cc b/third_party/xla/xla/shape_util.cc index fa6ed6c0c82577..439f20243145e9 100644 --- a/third_party/xla/xla/shape_util.cc +++ b/third_party/xla/xla/shape_util.cc @@ -688,8 +688,7 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { } /* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) { - return shape.IsArray() && - absl::c_any_of(shape.dimensions(), [](int64_t d) { return d == 0; }); + return shape.IsArray() && absl::c_linear_search(shape.dimensions(), 0); } /* static */ bool ShapeUtil::IsScalarWithElementType( @@ -966,6 +965,7 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { } bool overflow; std::tie(product, overflow) = OverflowSafeMultiply(product, dimension); + any_overflows |= overflow; } if (any_overflows) { return InvalidArgument("shape's dimensions overflow: %s", diff --git a/third_party/xla/xla/status.h b/third_party/xla/xla/status.h index 71c18131030b56..818bfdf4b1ba2d 100644 --- a/third_party/xla/xla/status.h +++ b/third_party/xla/xla/status.h @@ -24,7 +24,6 @@ namespace xla { // NOLINTBEGIN(misc-unused-using-decls) using absl::OkStatus; using absl::Status; -using absl::StatusOr; // NOLINTEND(misc-unused-using-decls) } // namespace xla diff --git a/third_party/xla/xla/stream_executor/BUILD b/third_party/xla/xla/stream_executor/BUILD index b226a7e6c8e5a7..3d3e17c6a8a9c2 100644 --- a/third_party/xla/xla/stream_executor/BUILD +++ b/third_party/xla/xla/stream_executor/BUILD @@ -1,13 +1,14 @@ load("//xla:xla.bzl", "xla_cc_test") load("//xla/stream_executor:build_defs.bzl", "stream_executor_friends", "stream_executor_internal") -load("@local_tsl//tsl:tsl.bzl", "set_external_visibility", "transitive_hdrs") +load("@local_tsl//tsl:tsl.bzl", "internal_visibility", "transitive_hdrs") load("@local_tsl//tsl:tsl.default.bzl", "filegroup") load("@local_tsl//tsl/platform:build_config.bzl", "tf_proto_library") load("@local_tsl//tsl/platform:build_config_root.bzl", "if_static") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) @@ -50,7 +51,7 @@ package_group( # an implementation detail of StreamExecutor and has internal visibility. # # TODO(ezhulenev): Remove from public API headers that are exported via standalone public libraries, -# e.g. `platform` and `multi_platform_manager` should be added with an explicit dependency. +# e.g. `platform` and `platform_manager` should be added with an explicit dependency. filegroup( name = "stream_executor_api_headers", srcs = [ @@ -63,19 +64,21 @@ filegroup( "device_options.h", "event.h", "executor_cache.h", + "host_memory_allocation.h", "kernel.h", "kernel_spec.h", "launch_dim.h", + "memory_allocation.h", "module_spec.h", "multi_platform_manager.h", "numeric_options.h", "platform.h", + "platform_manager.h", "scratch_allocator.h", "stream.h", "stream_executor.h", - "temporary_device_memory.h", ], - visibility = ["//visibility:public"], + visibility = ["//visibility:private"], ) # These are the headers for default StreamExecutor plugins. @@ -86,7 +89,7 @@ filegroup( "dnn.h", "fft.h", ], - visibility = ["//visibility:public"], + visibility = ["//visibility:private"], ) # This is a list of dependencies required for building `stream_executor` target (and required for @@ -110,6 +113,7 @@ STREAM_EXECUTOR_DEPENDENCIES = [ "//xla/stream_executor/platform", "@local_tsl//tsl/framework:device_id", "@local_tsl//tsl/framework:device_type", + "@local_tsl//tsl/lib/gtl:int_type", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:ml_dtypes", @@ -126,7 +130,6 @@ cc_library( ":stream_executor_api_headers", ":stream_executor_plugin_headers", ], - visibility = ["//visibility:public"], deps = STREAM_EXECUTOR_DEPENDENCIES + [ ":stream_executor_pimpl", "@com_google_absl//absl/status:statusor", @@ -152,14 +155,12 @@ tf_proto_library( cc_api_version = 2, make_default_target_header_only = True, protodeps = ["//xla:autotune_results_proto"], - visibility = ["//visibility:public"], ) cc_library( name = "device_description", srcs = ["device_description.cc"], hdrs = ["device_description.h"], - visibility = ["//visibility:public"], deps = [ ":device_description_proto_cc", ":launch_dim", @@ -174,7 +175,6 @@ cc_library( cc_library( name = "device_memory", hdrs = ["device_memory.h"], - visibility = ["//visibility:public"], deps = [ "//xla/stream_executor/platform", "@local_tsl//tsl/platform:logging", @@ -184,7 +184,7 @@ cc_library( cc_library( name = "data_type", hdrs = ["data_type.h"], - visibility = ["//visibility:public"], + visibility = [":internal"], deps = [ "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/protobuf:dnn_proto_cc", @@ -194,7 +194,6 @@ cc_library( cc_library( name = "device_memory_allocator", hdrs = ["device_memory_allocator.h"], - visibility = ["//visibility:public"], deps = [ ":device_memory", ":platform", @@ -210,32 +209,43 @@ cc_library( cc_library( name = "device_options", hdrs = ["device_options.h"], - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:logging", ], ) +cc_library( + name = "host_memory_allocation", + srcs = ["host_memory_allocation.cc"], + hdrs = ["host_memory_allocation.h"], + deps = [ + ":memory_allocation", + ":stream_executor_internal", # TODO(b/323534971): Remove dependency on Interface. + ], +) + cc_library( name = "host_or_device_scalar", hdrs = ["host_or_device_scalar.h"], - visibility = ["//visibility:public"], deps = [":device_memory"], ) cc_library( name = "launch_dim", hdrs = ["launch_dim.h"], - visibility = ["//visibility:public"], deps = ["@com_google_absl//absl/strings"], ) cc_library( - name = "multi_platform_manager", - srcs = ["multi_platform_manager.cc"], - hdrs = ["multi_platform_manager.h"], - visibility = ["//visibility:public"], + name = "memory_allocation", + hdrs = ["memory_allocation.h"], +) + +cc_library( + name = "platform_manager", + srcs = ["platform_manager.cc"], + hdrs = ["platform_manager.h"], deps = [ ":platform", "//xla/stream_executor/platform", @@ -251,17 +261,24 @@ cc_library( ], ) +# TODO(hebecker): Remove compatibility target when all users have been migrated to :platform_manager +cc_library( + name = "multi_platform_manager", + hdrs = ["multi_platform_manager.h"], + deprecation = "The type MultiPlatformManager is being renamed to PlatformManager. " + + "Use target :platform_manager instead.", + deps = [":platform_manager"], +) + cc_library( name = "numeric_options", hdrs = ["numeric_options.h"], - visibility = ["//visibility:public"], ) cc_library( name = "platform", srcs = ["platform.cc"], hdrs = ["platform.h"], - visibility = ["//visibility:public"], deps = [ ":device_description", ":device_options", @@ -283,7 +300,6 @@ cc_library( name = "blas", srcs = ["blas.cc"], hdrs = ["blas.h"], - visibility = ["//visibility:public"], deps = [ ":data_type", ":device_memory", @@ -292,6 +308,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/protobuf:dnn_proto_cc", ], ) @@ -300,7 +317,6 @@ cc_library( name = "dnn", srcs = ["dnn.cc"], hdrs = ["dnn.h"], - visibility = ["//visibility:public"], deps = [ ":data_type", ":device_description_proto_cc", @@ -326,7 +342,6 @@ cc_library( cc_library( name = "fft", hdrs = ["fft.h"], - visibility = ["//visibility:public"], deps = [ "//xla/stream_executor/platform", ], @@ -335,7 +350,6 @@ cc_library( cc_library( name = "lazy_op_runner", hdrs = ["lazy_op_runner.h"], - visibility = ["//visibility:public"], deps = [ ":stream_executor_headers", "@com_google_absl//absl/base", @@ -346,10 +360,7 @@ cc_library( ) # TODO(ezhulenev): This should be removed. -exports_files( - ["lazy_op_runner.h"], - visibility = ["//visibility:public"], -) +exports_files(["lazy_op_runner.h"]) #===--------------------------------------------------------------------------------------------===# # StreamExecutor platform-dependent interfaces @@ -365,7 +376,7 @@ exports_files( cc_library( name = "stream_executor_internal", hdrs = ["stream_executor_internal.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([":internal"]), deps = [ ":stream_executor_headers", "//xla/stream_executor/platform", @@ -391,7 +402,7 @@ cc_library( ":stream_executor_api_headers", ":stream_executor_plugin_headers", ], - visibility = ["//visibility:public"], + visibility = [":internal"], deps = STREAM_EXECUTOR_DEPENDENCIES + if_static([ "@com_google_protobuf//:protobuf", # indirectly-used by dnn.h ]) + [ @@ -404,7 +415,7 @@ cc_library( name = "plugin_registry", srcs = ["plugin_registry.cc"], hdrs = ["plugin_registry.h"], - visibility = ["//visibility:public"], + visibility = [":internal"], deps = [ ":blas", ":dnn", @@ -431,7 +442,7 @@ cc_library( name = "allocator_stats", srcs = ["allocator_stats.cc"], hdrs = ["allocator_stats.h"], - visibility = ["//visibility:public"], + visibility = ["//visibility:private"], deps = [ "//xla/stream_executor/platform", "@com_google_absl//absl/strings:str_format", @@ -442,19 +453,15 @@ cc_library( name = "command_buffer", srcs = ["command_buffer.cc"], hdrs = ["command_buffer.h"], - local_defines = select({ - "//xla/stream_executor/cuda:graph_conditional_enabled": [ - "STREAM_EXECUTOR_CUDA_ENABLE_GRAPH_CONDITIONAL=1", - ], - "//conditions:default": [], - }), - visibility = ["//visibility:public"], + visibility = ["//visibility:private"], deps = [ ":stream_executor_headers", ":stream_executor_internal", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/lib/gtl:int_type", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], @@ -464,7 +471,7 @@ cc_library( name = "event", srcs = ["event.cc"], hdrs = ["event.h"], - visibility = ["//visibility:public"], + visibility = ["//visibility:private"], deps = [ ":stream_executor_headers", ":stream_executor_internal", @@ -477,7 +484,7 @@ cc_library( name = "executor_cache", srcs = ["executor_cache.cc"], hdrs = ["executor_cache.h"], - visibility = ["//visibility:public"], + visibility = ["//visibility:private"], deps = [ ":platform", ":stream_executor_headers", @@ -496,7 +503,7 @@ cc_library( name = "kernel_spec", srcs = ["kernel_spec.cc"], hdrs = ["kernel_spec.h"], - visibility = ["//visibility:public"], + visibility = ["//visibility:private"], deps = [ "//xla/stream_executor/platform", "@com_google_absl//absl/status:statusor", @@ -510,27 +517,31 @@ cc_library( name = "kernel", srcs = ["kernel.cc"], hdrs = ["kernel.h"], - visibility = ["//visibility:public"], + visibility = ["//visibility:private"], deps = [ ":device_memory", + ":kernel_spec", ":platform", ":stream_executor_headers", ":stream_executor_internal", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:statusor", ], ) cc_library( name = "scratch_allocator", hdrs = ["scratch_allocator.h"], - visibility = ["//visibility:public"], + visibility = ["//visibility:private"], deps = [ ":device_memory_allocator", ":stream_executor_headers", @@ -541,18 +552,14 @@ cc_library( ], ) -cc_library( - name = "temporary_device_memory", - srcs = ["temporary_device_memory.cc"], - hdrs = ["temporary_device_memory.h"], - visibility = ["//visibility:public"], - deps = [":stream_executor_headers"], -) - #===--------------------------------------------------------------------------------------------===# transitive_hdrs( name = "stream_executor_install_hdrs", + tags = [ + "alt_dep=:stream_executor_headers", + "avoid_dep", + ], deps = [":stream_executor_headers"], ) @@ -566,17 +573,18 @@ cc_library( "stream_executor_pimpl.cc", ], hdrs = ["stream_executor_pimpl.h"], - visibility = ["//visibility:public"], + tags = ["avoid_dep"], + visibility = ["//visibility:private"], deps = [ ":blas", # build_cleaner: keep ":command_buffer", # build_cleaner: keep ":dnn", # build_cleaner: keep ":fft", + ":host_memory_allocation", ":kernel_spec", ":platform", ":stream_executor_headers", ":stream_executor_internal", - ":temporary_device_memory", "//xla/stream_executor/platform", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/functional:any_invocable", @@ -604,7 +612,6 @@ cc_library( # things that lead to nearly impossible to debug run time crashes. cc_library( name = "stream_executor_impl", - visibility = ["//visibility:public"], deps = [ ":allocator_stats", ":device_description", @@ -615,8 +622,8 @@ cc_library( ":kernel", ":kernel_spec", ":launch_dim", - ":multi_platform_manager", ":platform", + ":platform_manager", ":scratch_allocator", ":stream_executor_headers", ":stream_executor_pimpl", @@ -641,6 +648,18 @@ xla_cc_test( ], ) +xla_cc_test( + name = "stream_executor_test", + srcs = ["stream_executor_test.cc"], + deps = [ + ":stream_executor", + "//xla/stream_executor/host:host_platform", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + xla_cc_test( name = "stream_test", size = "small", @@ -648,6 +667,9 @@ xla_cc_test( deps = [ ":stream_executor", "//xla/stream_executor/host:host_platform", + "@com_google_absl//absl/log:check", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", ], @@ -671,23 +693,21 @@ xla_cc_test( alias( name = "cuda_platform", actual = "//xla/stream_executor/cuda:all_runtime", - visibility = ["//visibility:public"], ) alias( name = "rocm_platform", actual = "//xla/stream_executor/rocm:all_runtime", - visibility = ["//visibility:public"], ) # TODO(ezhulenev): This should be removed. cc_library( name = "stream_executor_bundle", - visibility = ["//visibility:public"], + visibility = [":internal"], deps = [ ":dnn", ":event", - ":multi_platform_manager", + ":platform_manager", ":scratch_allocator", ":stream_executor", "//xla/stream_executor/cuda:cuda_platform_id", diff --git a/third_party/xla/xla/stream_executor/blas.h b/third_party/xla/xla/stream_executor/blas.h index 0d067d6aee6696..07a5dadc681624 100644 --- a/third_party/xla/xla/stream_executor/blas.h +++ b/third_party/xla/xla/stream_executor/blas.h @@ -1,3 +1,4 @@ +#include "tsl/platform/errors.h" /* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,24 +18,6 @@ limitations under the License. // use in conjunction with the StreamExecutor abstraction. // // Note that this interface is optionally supported by platforms. -// -// This abstraction makes it simple to entrain BLAS operations on GPU data into -// a Stream -- users typically will not use this API directly, but will use the -// Stream builder methods to entrain these operations "under the hood". For -// example: -// -// DeviceMemory x = stream_exec->AllocateArray(1024); -// DeviceMemory y = stream_exec->AllocateArray(1024); -// // ... populate x and y ... -// Stream stream{stream_exec}; -// stream -// .Init() -// .ThenBlasAxpy(1024, 5.5, x, 1, &y, 1); -// TF_CHECK_OK(stream.BlockHostUntilDone()); -// -// By using stream operations in this manner the user can easily intermix custom -// kernel launches (via StreamExecutor::ThenLaunch()) with these pre-canned BLAS -// routines. #ifndef XLA_STREAM_EXECUTOR_BLAS_H_ #define XLA_STREAM_EXECUTOR_BLAS_H_ @@ -61,7 +44,9 @@ namespace stream_executor { namespace gpu { struct BlasLt; -} +struct MatrixDescriptor; +struct OutputMatrixDescriptor; +} // namespace gpu class Stream; class ScratchAllocator; @@ -204,6 +189,21 @@ class AlgorithmConfig { typedef int64_t ComputePrecision; constexpr ComputePrecision kDefaultComputePrecision = 0; +namespace detail { + +// Helper to return if `T` is the same type as `First` or any or `Rest`. +template +constexpr bool is_any_of() { + return false; +} + +template +constexpr bool is_any_of() { + return std::is_same_v || is_any_of(); +} + +} // namespace detail + // BLAS support interface -- this can be derived from a GPU executor when the // underlying platform has an BLAS library implementation available. See // StreamExecutor::AsBlas(). @@ -311,7 +311,10 @@ class BlasSupport { // Gets a list of supported algorithms for DoBlasGemmWithAlgorithm. virtual bool GetBlasGemmAlgorithms( - Stream *stream, std::vector *out_algorithms) = 0; + Stream *stream, const gpu::MatrixDescriptor &a, + const gpu::MatrixDescriptor &b, gpu::OutputMatrixDescriptor *c, + const void *alpha, const void *beta, + std::vector *out_algorithms) = 0; // Like DoBlasGemm, but accepts an algorithm and an compute type. // @@ -404,6 +407,170 @@ class BlasSupport { DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count, const NumericOptions &numeric_options, blas::CallContext context) = 0; + template + absl::Status BlasGemmStridedBatchedWithAlgorithm( + Stream *stream, blas::Transpose transa, blas::Transpose transb, + uint64_t m, uint64 n, uint64_t k, ConstantType alpha, + const DeviceMemory &a, int lda, int64_t stride_a, + const DeviceMemory &b, int ldb, int64_t stride_b, + ConstantType beta, DeviceMemory *c, int ldc, int64_t stride_c, + int batch_count, blas::ComputationType computation_type, + blas::AlgorithmType algorithm, const NumericOptions &numeric_options, + blas::ProfileResult *output_profile_result, blas::CallContext context) { + TF_RETURN_IF_ERROR( + CheckTypesForExtendedBlas( + computation_type)); + + void *alpha_ptr = α + void *beta_ptr = β + float alpha_storage, beta_storage; + UpcastHalfToFloat(&alpha_ptr, &beta_ptr, &alpha_storage, + &beta_storage); + absl::Status status = DoBlasGemmStridedBatchedWithAlgorithm( + stream, transa, transb, m, n, k, alpha_ptr, a, + blas::ToDataType::value, lda, stride_a, b, + blas::ToDataType::value, ldb, stride_b, beta_ptr, c, + blas::ToDataType::value, ldc, stride_c, batch_count, + computation_type, algorithm, numeric_options, output_profile_result, + context); + if (output_profile_result) { + // The error is recorded in the profile. + return absl::OkStatus(); + } + return status; + } + + template + absl::Status BlasGemm(Stream *stream, blas::Transpose transa, + blas::Transpose transb, uint64_t m, uint64 n, uint64 k, + ConstantType alpha, const DeviceMemory &a, + int lda, const DeviceMemory &b, int ldb, + ConstantType beta, DeviceMemory *c, int ldc, + const NumericOptions &numeric_options, + blas::CallContext context) { + static_assert( + detail::is_any_of, + std::complex>(), + "Input can be int8_t, half, bf16, float, double, std::complex " + "or " + "std::complex"); + static_assert(!std::is_same_v || + detail::is_any_of(), + "If input is Eigen::half, constant has to be either " + "Eigen::half or float"); + static_assert(detail::is_any_of(), + "If input is not int8_t, Eigen::half, constant and input " + "types have to match"); + void *alpha_ptr = α + void *beta_ptr = β + float alpha_storage, beta_storage; + UpcastHalfToFloat(&alpha_ptr, &beta_ptr, &alpha_storage, + &beta_storage); + + return DoBlasGemm(stream, transa, transb, m, n, k, + blas::ToDataType::value, alpha_ptr, a, lda, b, + ldb, beta_ptr, c, ldc, numeric_options, context); + } + + template + absl::Status BlasGemm(Stream *stream, blas::Transpose transa, + blas::Transpose transb, uint64_t m, uint64 n, uint64 k, + const DeviceMemory &a, int lda, + const DeviceMemory &b, int ldb, + DeviceMemory *c, int ldc, + const NumericOptions &numeric_options, + blas::CallContext context) { + InputType alpha{1.0}; + InputType beta{0.0}; + return BlasGemm(stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, + beta, c, ldc, numeric_options, context); + } + + template + absl::Status BlasGemmWithAlgorithm( + Stream *stream, blas::Transpose transa, blas::Transpose transb, + uint64_t m, uint64 n, uint64_t k, ConstantType alpha, + const DeviceMemory &a, int lda, + const DeviceMemory &b, int ldb, ConstantType beta, + DeviceMemory *c, int ldc, + blas::ComputationType computation_type, blas::AlgorithmType algorithm, + const NumericOptions &numeric_options, + blas::ProfileResult *output_profile_result, blas::CallContext context) { + TF_RETURN_IF_ERROR( + CheckTypesForExtendedBlas( + computation_type)); + + void *alpha_ptr = α + void *beta_ptr = β + float alpha_storage, beta_storage; + UpcastHalfToFloat(&alpha_ptr, &beta_ptr, &alpha_storage, + &beta_storage); + + absl::Status st = DoBlasGemmWithAlgorithm( + stream, transa, transb, m, n, k, alpha_ptr, a, + blas::ToDataType::value, lda, b, + blas::ToDataType::value, ldb, beta_ptr, c, + blas::ToDataType::value, ldc, computation_type, algorithm, + numeric_options, output_profile_result, context); + + if (output_profile_result) { + // The error is recorded in the profile. + return absl::OkStatus(); + } + return st; + } + + template + absl::Status BlasGemmWithAlgorithm( + Stream *stream, blas::Transpose transa, blas::Transpose transb, + uint64_t m, uint64 n, uint64_t k, const DeviceMemory &a, + int lda, const DeviceMemory &b, int ldb, + DeviceMemory *c, int ldc, + blas::ComputationType computation_type, blas::AlgorithmType algorithm, + blas::ProfileResult *output_profile_result, blas::CallContext context) { + OutputType alpha{1}; + OutputType beta{0}; + + return BlasGemmWithAlgorithm(stream, transa, transb, m, n, k, alpha, a, lda, + b, ldb, beta, c, ldc, computation_type, + algorithm, NumericOptions{}, + output_profile_result, context); + } + + template + absl::Status BlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, + uint64_t m, uint64 n, uint64_t k, ConstantType alpha, + const DeviceMemory &a, int lda, int64_t stride_a, + const DeviceMemory &b, int ldb, int64_t stride_b, + ConstantType beta, DeviceMemory *c, int ldc, int64_t stride_c, + int batch_count, const NumericOptions &numeric_options, + blas::CallContext context) { + static_assert( + detail::is_any_of, + std::complex>(), + "Unsupported input type"); + static_assert(std::is_same_v || + (detail::is_any_of() && + std::is_same_v), + "Mismatched input and alpha/beta types"); + + void *alpha_ptr = α + void *beta_ptr = β + float alpha_storage, beta_storage; + UpcastHalfToFloat(&alpha_ptr, &beta_ptr, &alpha_storage, + &beta_storage); + + return DoBlasGemmStridedBatched( + stream, transa, transb, m, n, k, blas::ToDataType::value, + alpha_ptr, a, lda, stride_a, b, ldb, stride_b, beta_ptr, c, ldc, + stride_c, batch_count, numeric_options, context); + } + // Solves a triangular matrix equation. // // op(a) * x = alpha * b, @@ -505,6 +672,71 @@ class BlasSupport { // own memory pool for allocating workspace. void ResetWorkspace(); + // Checks whether types match before a call to extended BLAS version. + template + absl::Status CheckTypesForExtendedBlas( + blas::ComputationType computation_type) { + static_assert( + detail::is_any_of, std::complex>(), + "The only buffer types supported are: Eigen::half, float, " + "double, int8, std::complex and std::complex"); + static_assert( + std::is_same_v || + (std::is_same_v && + detail::is_any_of()), + "Mismatched alpha/beta and output types"); + + bool valid_computation_type = [computation_type] { + switch (computation_type) { + case blas::ComputationType::kF16: + return std::is_same_v; + case blas::ComputationType::kF32: + return detail::is_any_of>(); + case blas::ComputationType::kF64: + return detail::is_any_of>(); + case blas::ComputationType::kI32: + return std::is_same_v; + case blas::ComputationType::kF16AsF32: // fall-through + case blas::ComputationType::kBF16AsF32: // fall-through + case blas::ComputationType::kTF32AsF32: + return detail::is_any_of>(); + } + }(); + + if (!valid_computation_type) { + return absl::InternalError(absl::StrCat( + "Invalid computation type ", + blas::ComputationTypeString(computation_type), " for output type: ", + blas::DataTypeString(blas::ToDataType::value))); + } + return absl::OkStatus(); + } + + // Non-extended BLAS interface requires alpha/beta to be floats when input + // type is Eigen::half. However, for consistency purposes it is convenient + // for the interface to accept Eigen::half. + template + void UpcastHalfToFloat(void **alpha_ptr, void **beta_ptr, + float *alpha_storage, float *beta_storage) { + if (std::is_same::value) { + *alpha_storage = + static_cast(*reinterpret_cast(*alpha_ptr)); + *beta_storage = + static_cast(*reinterpret_cast(*beta_ptr)); + *alpha_ptr = alpha_storage; + *beta_ptr = beta_storage; + } else if (std::is_same::value) { + *alpha_storage = + static_cast(*reinterpret_cast(*alpha_ptr)); + *beta_storage = + static_cast(*reinterpret_cast(*beta_ptr)); + *alpha_ptr = alpha_storage; + *beta_ptr = beta_storage; + } + } + BlasSupport(const BlasSupport &) = delete; void operator=(const BlasSupport &) = delete; }; @@ -563,9 +795,11 @@ class BlasSupport { const void *beta, DeviceMemoryBase *c, int ldc, \ const NumericOptions &numeric_options, blas::CallContext context) \ override; \ - bool GetBlasGemmAlgorithms(Stream *stream, \ - std::vector *out_algorithms) \ - override; \ + bool GetBlasGemmAlgorithms( \ + Stream *stream, const gpu::MatrixDescriptor &a, \ + const gpu::MatrixDescriptor &b, gpu::OutputMatrixDescriptor *c, \ + const void *alpha, const void *beta, \ + std::vector *out_algorithms) override; \ absl::Status DoBlasGemmWithAlgorithm( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64_t m, uint64 n, uint64 k, const void *alpha, \ diff --git a/third_party/xla/xla/stream_executor/build_defs.bzl b/third_party/xla/xla/stream_executor/build_defs.bzl index a29f7d28d60718..6916574c646edf 100644 --- a/third_party/xla/xla/stream_executor/build_defs.bzl +++ b/third_party/xla/xla/stream_executor/build_defs.bzl @@ -1,7 +1,11 @@ """Configurations for StreamExecutor builds""" load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") -load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") +load("@local_config_rocm//rocm:build_defs.bzl", _if_gpu_is_configured = "if_gpu_is_configured") +load( + "@local_tsl//tsl/platform:rules_cc.bzl", + "cc_library", +) def stream_executor_friends(): return ["//..."] @@ -13,12 +17,11 @@ def tf_additional_cuda_platform_deps(): return [] def tf_additional_cudnn_plugin_copts(): - # TODO(timshen): remove TF_ENABLE_CUDNN_FRONTEND once cudnn-frontend is imported. - return ["-DNV_CUDNN_DISABLE_EXCEPTION", "-DTF_ENABLE_CUDNN_FRONTEND"] + return ["-DNV_CUDNN_DISABLE_EXCEPTION"] -# Returns whether any GPU backend is configuered. -def if_gpu_is_configured(x): - return if_cuda_is_configured(x) + if_rocm_is_configured(x) +# Returns whether any GPU backend is configured. +def if_gpu_is_configured(if_true, if_false = []): + return _if_gpu_is_configured(if_true, if_false) def if_cuda_or_rocm(x): return if_gpu_is_configured(x) @@ -27,3 +30,61 @@ def if_cuda_or_rocm(x): # unnecessary dependency def tf_additional_gpu_compilation_copts(): return ["-DTF_DISABLE_NVLINK_BY_DEFAULT"] + +def gpu_only_cc_library(name, tags = [], **kwargs): + """A library that only gets compiled when GPU is configured, otherwise it's an empty target. + + Args: + name: Name of the target + tags: Tags being applied to the implementation target + **kwargs: Accepts all arguments that a `cc_library` would also accept + """ + if not native.package_name().startswith("xla/stream_executor"): + fail("gpu_only_cc_library may only be used in `xla/stream_executor/...`.") + + cc_library( + name = "%s_non_gpu" % name, + tags = ["manual"], + ) + cc_library( + name = "%s_gpu_only" % name, + tags = tags + ["manual"], + **kwargs + ) + native.alias( + name = name, + actual = if_gpu_is_configured(":%s_gpu_only" % name, ":%s_non_gpu" % name), + visibility = kwargs.get("visibility"), + compatible_with = kwargs.get("compatible_with"), + restricted_to = kwargs.get("restricted_to"), + target_compatible_with = kwargs.get("target_compatible_with"), + ) + +def cuda_only_cc_library(name, tags = [], **kwargs): + """A library that only gets compiled when CUDA is configured, otherwise it's an empty target. + + Args: + name: Name of the target + tags: Tags being applied to the implementation target + **kwargs: Accepts all arguments that a `cc_library` would also accept + """ + if not native.package_name().startswith("xla/stream_executor"): + fail("cuda_only_cc_library may only be used in `xla/stream_executor/...`.") + + cc_library( + name = "%s_non_cuda" % name, + tags = ["manual"], + ) + cc_library( + name = "%s_cuda_only" % name, + tags = tags + ["manual"], + **kwargs + ) + native.alias( + name = name, + actual = if_cuda_is_configured(":%s_cuda_only" % name, ":%s_non_cuda" % name), + visibility = kwargs.get("visibility"), + compatible_with = kwargs.get("compatible_with"), + restricted_to = kwargs.get("restricted_to"), + target_compatible_with = kwargs.get("target_compatible_with"), + ) diff --git a/third_party/xla/xla/stream_executor/command_buffer.cc b/third_party/xla/xla/stream_executor/command_buffer.cc index 874d0198e9ff71..189cec64cdd2cf 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.cc +++ b/third_party/xla/xla/stream_executor/command_buffer.cc @@ -15,19 +15,14 @@ limitations under the License. #include "xla/stream_executor/command_buffer.h" -#include -#include #include #include -#include #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/kernel_spec.h" -#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_internal.h" #include "tsl/platform/errors.h" @@ -35,26 +30,12 @@ limitations under the License. namespace stream_executor { -CommandBuffer::~CommandBuffer() = default; -CommandBuffer::CommandBuffer(CommandBuffer&&) = default; -CommandBuffer& CommandBuffer::operator=(CommandBuffer&&) = default; - -void CommandBuffer::Deleter::operator()( - internal::CommandBufferInterface* impl) { - if (owned) delete impl; -} - -/*static*/ absl::StatusOr CommandBuffer::Create( +absl::StatusOr> CommandBuffer::Create( StreamExecutor* executor, Mode mode) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr command_buffer, - executor->implementation()->GetCommandBufferImplementation(mode)); - - CommandBuffer cmd(std::move(command_buffer)); - return cmd; + return executor->implementation()->CreateCommandBuffer(mode); } -/*static*/ absl::StatusOr CommandBuffer::Trace( +absl::StatusOr> CommandBuffer::Trace( StreamExecutor* executor, absl::AnyInvocable function, Mode mode) { Stream stream(executor); @@ -65,7 +46,7 @@ void CommandBuffer::Deleter::operator()( return Trace(executor, &stream, std::move(function), mode); } -/*static*/ absl::StatusOr CommandBuffer::Trace( +absl::StatusOr> CommandBuffer::Trace( StreamExecutor* executor, Stream* stream, absl::AnyInvocable function, Mode mode) { if (stream == nullptr) @@ -73,132 +54,15 @@ void CommandBuffer::Deleter::operator()( "Can't trace command buffer on a null stream"); // Prepare an empty command buffer instance. - TF_ASSIGN_OR_RETURN(CommandBuffer command_buffer, + TF_ASSIGN_OR_RETURN(std::unique_ptr command_buffer, CommandBuffer::Create(executor, mode)); // Trace and finalize the command buffer. - TF_RETURN_IF_ERROR(command_buffer.implementation()->Trace( - stream, [&]() { return function(stream); })); - TF_RETURN_IF_ERROR(command_buffer.implementation()->Finalize()); + TF_RETURN_IF_ERROR( + command_buffer->Trace(stream, [&]() { return function(stream); })); + TF_RETURN_IF_ERROR(command_buffer->Finalize()); return command_buffer; } -/*static*/ bool CommandBuffer::SupportsConditionalCommands( - const Platform* platform) { - // TODO(ezhulenev): We should extend a Platform with a way to query - // implemented StreamExecutor features, for now we know that only CUDA - // platform supports conditional commands in command buffers. -#if defined(STREAM_EXECUTOR_CUDA_ENABLE_GRAPH_CONDITIONAL) - return platform->Name() == "CUDA"; -#endif - return false; -} - -const internal::CommandBufferInterface* CommandBuffer::implementation() const { - return implementation_.get(); -} - -internal::CommandBufferInterface* CommandBuffer::implementation() { - return implementation_.get(); -} - -/*static*/ CommandBuffer CommandBuffer::Create( - std::unique_ptr implementation) { - return CommandBuffer(std::move(implementation)); -} - -/*static*/ absl::Status CommandBuffer::Build( - internal::CommandBufferInterface* implementation, - const CommandBuffer::Builder& builder) { - CommandBuffer command_buffer(implementation); - return builder(&command_buffer); -} - -CommandBuffer::CommandBuffer( - std::unique_ptr implementation) - : implementation_(implementation.release(), {/*owned=*/true}) {} - -CommandBuffer::CommandBuffer(internal::CommandBufferInterface* implementation) - : implementation_(implementation, {/*owned=*/false}) {} - -absl::Status CommandBuffer::Barrier(StreamExecutor* executor) { - return implementation_->Barrier(executor); -} - -absl::Status CommandBuffer::Launch(const ThreadDim& threads, - const BlockDim& blocks, const Kernel& kernel, - const KernelArgs& args) { - return implementation_->Launch(threads, blocks, kernel, args); -} - -absl::Status CommandBuffer::AddNestedCommandBuffer( - const CommandBuffer& nested) { - return implementation_->AddNestedCommandBuffer(nested); -} - -absl::Status CommandBuffer::MemcpyDeviceToDevice(DeviceMemoryBase* dst, - const DeviceMemoryBase& src, - uint64_t size) { - return implementation_->MemcpyDeviceToDevice(dst, src, size); -} - -absl::Status CommandBuffer::Memset(DeviceMemoryBase* dst, - BitPattern bit_pattern, - size_t num_elements) { - return implementation_->Memset(dst, bit_pattern, num_elements); -} - -absl::StatusOr CommandBuffer::Allocate(size_t bytes) { - return implementation_->Allocate(bytes); -} - -absl::Status CommandBuffer::If(StreamExecutor* executor, - DeviceMemory pred, Builder then_builder) { - return implementation_->If(executor, pred, std::move(then_builder)); -} - -absl::Status CommandBuffer::IfElse(StreamExecutor* executor, - DeviceMemory pred, - Builder then_builder, Builder else_builder) { - return implementation_->IfElse(executor, pred, std::move(then_builder), - std::move(else_builder)); -} - -absl::Status CommandBuffer::Case(StreamExecutor* executor, - DeviceMemory index, - std::vector branches) { - return implementation_->Case(executor, index, std::move(branches)); -} - -absl::Status CommandBuffer::For(StreamExecutor* executor, int32_t num_iteration, - DeviceMemory loop_counter, - Builder body_builder) { - return implementation_->For(executor, num_iteration, loop_counter, - std::move(body_builder)); -} - -absl::Status CommandBuffer::While(StreamExecutor* executor, - DeviceMemory pred, Builder cond_builder, - Builder body_builder) { - return implementation_->While(executor, pred, std::move(cond_builder), - std::move(body_builder)); -} - -absl::Status CommandBuffer::Free(DeviceMemoryBase dst) { - return implementation_->Free(dst); -} - -CommandBuffer::Mode CommandBuffer::mode() const { - return implementation_->mode(); -} - -CommandBuffer::State CommandBuffer::state() const { - return implementation_->state(); -} - -absl::Status CommandBuffer::Finalize() { return implementation_->Finalize(); } - -absl::Status CommandBuffer::Update() { return implementation_->Update(); } - } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/command_buffer.h b/third_party/xla/xla/stream_executor/command_buffer.h index 7dd88aa9766741..8b25dea694a074 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.h +++ b/third_party/xla/xla/stream_executor/command_buffer.h @@ -26,10 +26,12 @@ limitations under the License. #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/types/span.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" +#include "tsl/lib/gtl/int_type.h" #include "tsl/platform/errors.h" namespace stream_executor { @@ -37,10 +39,6 @@ namespace stream_executor { class Stream; class StreamExecutor; -namespace internal { -class CommandBufferInterface; -} - //===----------------------------------------------------------------------===// // CommandBuffer //===----------------------------------------------------------------------===// @@ -54,11 +52,92 @@ class CommandBufferInterface; class CommandBuffer { public: // Builder constructs nested command buffers owned by a parent command buffer. + // + // Builder can use arbitrary number of nested execution scopes, the only + // requirement is that after builder constructed all commands, they all must + // be synchronized with a default execution scope. using Builder = std::function; - ~CommandBuffer(); - CommandBuffer(CommandBuffer&&); - CommandBuffer& operator=(CommandBuffer&&); + // Execution scope enables fine-grained synchronization scopes inside + // commands buffers. Implementation is very backend-specific and for CUDA/ROCM + // backends it's implemented as DAG edges. By default all commands launched in + // the `kDefaulExecutionScope` execution scope. + // + // Example #1: independent execution scopes and independent barriers + // + // ExecutionScope #0 ExecutionScope #1 + // + // A D + // B E + // ----- barrier ----- ----- barrier ----- + // C F + // + // (1) Commands A and B can run concurrently and must complete before C. + // (2) Commands D and E can run concurrently and must complete before F. + // (3) There is no syncrhonization between execution scopes, and commands + // from different execution scopes can execute concurrently with each + // other as long as they satisfy constraints of their respective + // execution scopes. + // + // + // + // Example #2: dependencies between scopes and inter-scope barriers + // + // ExecutionScope #0 ExecutionScope #1 + // + // A D + // B E + // ----------------- barrier ------------------ + // C F + // + // (1) Commands A and B can run concurrently and must complete before + // C and F. + // (2) Commands D and E can run concurrently and must complete before + // C and F. + // (3) Commands C and F can run concurrently. + // (4) All commands before a shared barrier (in both excecution scopes) + // should complete before any command after a berrier starts execution. + // + // + // + // Example #3: one-directional barriers between execution scopes + // + // ExecutionScope #0 ExecutionScope #1 + // + // A + // B + // ----- barrier ----- D + // C \ E + // ----- barrier ----- + // F + // + // (1) Commands A and B can run concurrently and must complete before + // C and F. + // (2) Commands D and E can run concurrently and must complete before + // F (does not synchronize with C). + // (3) Commands C and F can run concurrently. + // + // This is a more fine-grained barrier than in example #2: it enforces + // synchronization from execution scope #0 to execution scope #1 but no + // synchronization in other direction. For CUDA/ROCM backend it has the same + // semantics as stream wait operation. + // + TSL_LIB_GTL_DEFINE_INT_TYPE(ExecutionScopeId, int64_t); + static constexpr auto kDefaulExecutionScope = ExecutionScopeId(0); + + // An extension of a `Builder` defined above that builds a nested command + // buffer in a given execution scope. Builder can use arbitrary number of + // nested execution scopes, the only requirement is that after builder + // constructed all commands, they all must be synchronized with an execution + // scope passed as an argument. + using ExecutionScopeBuilder = + std::function; + + CommandBuffer() = default; + virtual ~CommandBuffer() = default; + + CommandBuffer(const CommandBuffer&) = delete; + void operator=(const CommandBuffer&) = delete; // Command buffer state: // @@ -86,9 +165,13 @@ class CommandBuffer { // Command buffer constructors //===--------------------------------------------------------------------===// + // TODO(b/323534971): Command buffer constructors should be moved to + // StreamExecutor or a dedicated CommandBufferFactory accessible via + // StreamExecutor. + // Creates a new empty command buffer on the given executor. - static absl::StatusOr Create(StreamExecutor* executor, - Mode mode = Mode::kPrimary); + static absl::StatusOr> Create( + StreamExecutor* executor, Mode mode = Mode::kPrimary); // Creates a new command buffer on the given executor by tracing `function` // invocation. All StreamExecutor operations on a Stream argument will be @@ -100,59 +183,129 @@ class CommandBuffer { // default we construct traced command buffers in nested mode because the // primary use case for traced command buffers is to be inserted into primary // command buffers constructed with explicit APIs. - static absl::StatusOr Trace( + static absl::StatusOr> Trace( StreamExecutor* executor, absl::AnyInvocable function, Mode mode = Mode::kNested); // Creates a new command buffer on the given executor by tracing `function` // invocation using a user provided stream that will be passed to `function`. - static absl::StatusOr Trace( + static absl::StatusOr> Trace( StreamExecutor* executor, Stream* stream, absl::AnyInvocable function, Mode mode = Mode::kNested); //===--------------------------------------------------------------------===// - // Command buffer properties + // Command buffer API //===--------------------------------------------------------------------===// - // Returns true if command buffer on a given platform supports conditional - // commands (If, IfThen, While). - static bool SupportsConditionalCommands(const Platform* platform); + // Adds an execution barrier to a given execution scope: all commands added + // before a barrier in a the execution scope will complete before any of the + // commands added after a barrier in the same execution scope. + virtual absl::Status Barrier(StreamExecutor* executor, + ExecutionScopeId execution_scope_id) = 0; - //===--------------------------------------------------------------------===// - // Command buffer API - //===--------------------------------------------------------------------===// + // Adds an execution barrier that synchronizes commands across multiple + // execution scopes. See example #2 in execution scope id documentation. + virtual absl::Status Barrier( + StreamExecutor* executor, + absl::Span execution_scope_ids) = 0; + + // Adds an execution barrier from execution scope `from_execution_scope_id` to + // execution scope `to_execution_scope_id`. See example #3 for details. + virtual absl::Status Barrier(StreamExecutor* executor, + ExecutionScopeId from_execution_scope_id, + ExecutionScopeId to_execution_scope_id) = 0; - // Adds an execution barrier to a command buffer: all commands added before a - // barrier will complete before any of the commands added after a barrier. - absl::Status Barrier(StreamExecutor* executor); + // Adds an execution barrier to the default execution scope. + absl::Status Barrier(StreamExecutor* executor) { + return Barrier(executor, kDefaulExecutionScope); + } - // Adds a kernel launch command to the command buffer. + // Adds a kernel launch command. + virtual absl::Status Launch(ExecutionScopeId execution_scope_id, + const ThreadDim& threads, const BlockDim& blocks, + const Kernel& kernel, const KernelArgs& args) = 0; + + // Adds a kernel launch command to the default execution scope. absl::Status Launch(const ThreadDim& threads, const BlockDim& blocks, - const Kernel& kernel, const KernelArgs& args); + const Kernel& kernel, const KernelArgs& args) { + return Launch(kDefaulExecutionScope, threads, blocks, kernel, args); + } - // Adds a nested command buffer to the command buffer. - absl::Status AddNestedCommandBuffer(const CommandBuffer& nested); + // Type-safe wrapper for launching typed kernels. Notice that the order of + // arguments is different do disambiguate from the regular launch API. + template + absl::Status Launch(const TypedKernel& kernel, + ExecutionScopeId execution_scope_id, + const ThreadDim& threads, const BlockDim& blocks, + Args... args); - // Adds a device-to-device memory copy to the command buffer. + // Type-safe wrapper for launching typed kernels in default execution scope. + template + absl::Status Launch(const TypedKernel& kernel, + const ThreadDim& threads, const BlockDim& blocks, + Args... args) { + return Launch(kernel, kDefaulExecutionScope, threads, blocks, args...); + } + + // Adds a nested command buffer. + virtual absl::Status AddNestedCommandBuffer( + ExecutionScopeId execution_scope_id, const CommandBuffer& nested) = 0; + + // Adds a nested command buffer to the default execution scope. + absl::Status AddNestedCommandBuffer(const CommandBuffer& nested) { + return AddNestedCommandBuffer(kDefaulExecutionScope, nested); + } + + // Adds a device-to-device memory copy. + virtual absl::Status MemcpyDeviceToDevice(ExecutionScopeId execution_scope_id, + DeviceMemoryBase* dst, + const DeviceMemoryBase& src, + uint64_t size) = 0; + + // Adds a device-to-device memory copy to the default execution scope. absl::Status MemcpyDeviceToDevice(DeviceMemoryBase* dst, - const DeviceMemoryBase& src, uint64_t size); + const DeviceMemoryBase& src, + uint64_t size) { + return MemcpyDeviceToDevice(kDefaulExecutionScope, dst, src, size); + } - // Adds a memset node to the command buffer. + // Supported bit patterns for memset commands. using BitPattern = std::variant; + + // Adds a memset command. + virtual absl::Status Memset(ExecutionScopeId execution_scope_id, + DeviceMemoryBase* dst, BitPattern bit_pattern, + size_t num_elements) = 0; + + // Adds a memset command to the default execution scope. absl::Status Memset(DeviceMemoryBase* dst, BitPattern bit_pattern, - size_t num_elements); + size_t num_elements) { + return Memset(kDefaulExecutionScope, dst, bit_pattern, num_elements); + } //--------------------------------------------------------------------------// // Command buffer memory allocation API //--------------------------------------------------------------------------// - // Adds a device memory allocation command to the command buffer. - absl::StatusOr Allocate(size_t bytes); + // Adds a device memory allocation command. + virtual absl::StatusOr Allocate( + ExecutionScopeId execution_scope_id, size_t bytes) = 0; + + // Adds a device memory allocation command to the default execution scope. + absl::StatusOr Allocate(size_t bytes) { + return Allocate(kDefaulExecutionScope, bytes); + } + + // Adds a device memory free command. + virtual absl::Status Free(ExecutionScopeId execution_scope_id, + DeviceMemoryBase dst) = 0; - // This API free buffer that is allocated by Allocate command - absl::Status Free(DeviceMemoryBase dst); + // Adds a device memory free command to the default execution scope. + absl::Status Free(DeviceMemoryBase dst) { + return Free(kDefaulExecutionScope, dst); + } //--------------------------------------------------------------------------// // Command buffer condtitional commands API @@ -160,29 +313,61 @@ class CommandBuffer { // Adds a conditional operation that will execute a command buffer constructed // by `then_builder` if `pred` value is `true`. + virtual absl::Status If(ExecutionScopeId execution_scope_id, + StreamExecutor* executor, DeviceMemory pred, + Builder then_builder) = 0; + + // Adds a conditional If operation to default execution scope. absl::Status If(StreamExecutor* executor, DeviceMemory pred, - Builder then_builder); + Builder then_builder) { + return If(kDefaulExecutionScope, executor, pred, then_builder); + } // Adds a conditional operation that will execute a command buffer constructed // by `then_builder` if `pred` value is `true`, or a command buffer // constructed by `else_builder` if `pred` is `false`. + virtual absl::Status IfElse(ExecutionScopeId execution_scope_id, + StreamExecutor* executor, DeviceMemory pred, + Builder then_builder, Builder else_builder) = 0; + + // Adds a conditional IfElse operation to default execution scope. absl::Status IfElse(StreamExecutor* executor, DeviceMemory pred, - Builder then_builder, Builder else_builder); + Builder then_builder, Builder else_builder) { + return IfElse(kDefaulExecutionScope, executor, pred, then_builder, + else_builder); + } // Adds a conditional operation that will execute a command buffer constructed // by the `branches` builder at `index`. If `index` is out of range, then it // will run a conditional command buffer constructed by the last builder. // // See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#case + virtual absl::Status Case(ExecutionScopeId execution_scope_id, + StreamExecutor* executor, + DeviceMemory index, + std::vector branches) = 0; + + // Adds a conditional Case operation to default execution scope. absl::Status Case(StreamExecutor* executor, DeviceMemory index, - std::vector branches); + std::vector branches) { + return Case(kDefaulExecutionScope, executor, index, branches); + } // Adds a conditional operation that will execute a command buffer constructed // by the `body_builder` exactly `num_iteration` times. This means the // condition is known at compile time (`num_iteration` < `loop_counter`), and // does not require a `cond_builder`. + virtual absl::Status For(ExecutionScopeId execution_scope_id, + StreamExecutor* executor, int32_t num_iteration, + DeviceMemory loop_counter, + Builder body_builder) = 0; + + // Adds a conditional For operation to default execution scope. absl::Status For(StreamExecutor* executor, int32_t num_iteration, - DeviceMemory loop_counter, Builder body_builder); + DeviceMemory loop_counter, Builder body_builder) { + return For(kDefaulExecutionScope, executor, num_iteration, loop_counter, + body_builder); + } // Adds a conditional operation that will execute a command buffer constructed // by the `cond_builder` that must update `pred` value, and then depending on @@ -197,68 +382,48 @@ class CommandBuffer { // body_builder() // cond_builder() // + virtual absl::Status While(ExecutionScopeId execution_scope_id, + StreamExecutor* executor, DeviceMemory pred, + Builder cond_builder, Builder body_builder) = 0; + + // Adds a conditional While operation to default execution scope. absl::Status While(StreamExecutor* executor, DeviceMemory pred, - Builder cond_builder, Builder body_builder); + Builder cond_builder, Builder body_builder) { + return While(kDefaulExecutionScope, executor, pred, cond_builder, + body_builder); + } + //--------------------------------------------------------------------------// + // Command buffer state management API //--------------------------------------------------------------------------// // Finalizes command buffer and makes it executable. Once command buffer is // finalized no commands can be added to it. - absl::Status Finalize(); + virtual absl::Status Finalize() = 0; // Begins command buffer update. Command buffer update should be finalized // before it can be executed. - absl::Status Update(); - - // Type-safe wrapper for launching typed kernels. Notice that the order of - // arguments is different do disambiguate from the regular launch API. - template - absl::Status Launch(const TypedKernel& kernel, - const ThreadDim& threads, const BlockDim& blocks, - Args... args); + virtual absl::Status Update() = 0; // Returns command buffer execution mode. - Mode mode() const; + virtual Mode mode() const = 0; // Returns command buffer state. - State state() const; - - //===--------------------------------------------------------------------===// - // Semi-internal APIs - //===--------------------------------------------------------------------===// - - // Following APIs are public, but considered to be implementation detail and - // discouraged from uses outside of StreamExecutor package. - const internal::CommandBufferInterface* implementation() const; - internal::CommandBufferInterface* implementation(); - - // Creates a command buffer from a platform-specific command buffer - // implementation. - static CommandBuffer Create( - std::unique_ptr implementation); - - // An adaptor for a command buffer builder that records commands into the - // platform-specific implementation - static absl::Status Build(internal::CommandBufferInterface* implementation, - const CommandBuffer::Builder& builder); + virtual State state() const = 0; + //--------------------------------------------------------------------------// + // Command buffer tracing API + //--------------------------------------------------------------------------// private: - explicit CommandBuffer( - std::unique_ptr implementation); - - explicit CommandBuffer(internal::CommandBufferInterface* implementation); - - // A custom deleter to be able to construct command buffer that doesn't own - // underlying implementation (behaves like std::weak_ptr for implementation). - struct Deleter { - void operator()(internal::CommandBufferInterface*); - bool owned = true; - }; - - std::unique_ptr implementation_; - - CommandBuffer(const CommandBuffer&) = delete; - void operator=(const CommandBuffer&) = delete; + // Tracing APIs are private because they do not compose with command buffer + // updates. Instead of tracing directly into the command buffer users should + // create traced command buffers using factory methods and add them to primary + // command buffers as nested operations. + + // Traces `function` invocation by recording all operations on the `stream` + // into the command buffer. Command buffer must be empty. + virtual absl::Status Trace(Stream* stream, + absl::AnyInvocable function) = 0; }; //===----------------------------------------------------------------------===// @@ -267,11 +432,13 @@ class CommandBuffer { template inline absl::Status CommandBuffer::Launch(const TypedKernel& kernel, + ExecutionScopeId execution_scope_id, const ThreadDim& threads, const BlockDim& blocks, Args... args) { auto kernel_args = PackKernelArgs(kernel, args...); - TF_RETURN_IF_ERROR(Launch(threads, blocks, kernel, *kernel_args)); + TF_RETURN_IF_ERROR( + Launch(execution_scope_id, threads, blocks, *kernel, *kernel_args)); return absl::OkStatus(); } diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index dff6e99f2ec6f8..b77a3aafb87077 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -1,17 +1,17 @@ load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") -load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library", "if_cuda") load( "//xla:xla.bzl", "xla_cc_test", ) load( "//xla/stream_executor:build_defs.bzl", + "cuda_only_cc_library", "stream_executor_friends", "tf_additional_cuda_platform_deps", "tf_additional_cudnn_plugin_copts", "tf_additional_gpu_compilation_copts", ) -load("@local_tsl//tsl:tsl.bzl", "if_google", "if_nccl", "set_external_visibility", "tsl_copts") +load("@local_tsl//tsl:tsl.bzl", "if_google", "if_nccl", "internal_visibility", "tsl_copts") load( "@local_tsl//tsl/platform:build_config_root.bzl", "if_static", @@ -27,7 +27,8 @@ load( ) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) @@ -36,21 +37,6 @@ package_group( packages = stream_executor_friends(), ) -# Add `--//third_party/tensorflow/compiler/xla/stream_executor/cuda:enable_graph_conditional` to -# build command to enable CUDA graph conditional nodes support. Requires CUDA >=12.3. -# -# See: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#conditional-graph-nodes -bool_flag( - name = "enable_graph_conditional", - build_setting_default = False, -) - -config_setting( - name = "graph_conditional_enabled", - flag_values = {":enable_graph_conditional": "True"}, - visibility = ["//visibility:public"], -) - bool_flag( name = "enable_libnvptxcompiler_support", build_setting_default = if_google( @@ -64,133 +50,115 @@ config_setting( flag_values = { ":enable_libnvptxcompiler_support": "True", }, - visibility = ["//visibility:public"], ) cc_library( name = "cuda_platform_id", srcs = ["cuda_platform_id.cc"], hdrs = ["cuda_platform_id.h"], - visibility = ["//visibility:public"], deps = ["//xla/stream_executor:platform"], ) -cc_library( +cuda_only_cc_library( name = "cuda_platform", - srcs = if_cuda_is_configured(["cuda_platform.cc"]), - hdrs = if_cuda_is_configured(["cuda_platform.h"]), + srcs = ["cuda_platform.cc"], + hdrs = ["cuda_platform.h"], visibility = ["//visibility:public"], - deps = if_cuda_is_configured( + deps = [ ":cuda_activation", - ":cuda_driver", - ":cuda_runtime", ":cuda_collectives", + ":cuda_driver", ":cuda_executor", ":cuda_platform_id", + ":cuda_runtime", + "//xla/stream_executor", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream_executor_internal", + "//xla/stream_executor/gpu:gpu_driver_header", + "//xla/stream_executor/gpu:gpu_executor_header", + "//xla/stream_executor/platform", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", - "//xla/stream_executor", - "//xla/stream_executor:stream_executor_internal", - "//xla/stream_executor:multi_platform_manager", - "//xla/stream_executor/gpu:gpu_executor_header", - "//xla/stream_executor/platform", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", - ], - ) + tf_additional_cuda_platform_deps(), - alwayslink = True, # Registers itself with the MultiPlatformManager. + ] + tf_additional_cuda_platform_deps(), + alwayslink = True, # Registers itself with the PlatformManager. ) -cc_library( +cuda_only_cc_library( name = "cuda_diagnostics", - srcs = if_cuda_is_configured(["cuda_diagnostics.cc"]), - hdrs = if_cuda_is_configured(["cuda_diagnostics.h"]), - visibility = ["//visibility:public"], - deps = if_cuda_is_configured([ + srcs = ["cuda_diagnostics.cc"], + hdrs = ["cuda_diagnostics.h"], + deps = [ + "//xla/stream_executor/gpu:gpu_diagnostics_header", "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/strings", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "//xla/stream_executor/gpu:gpu_diagnostics_header", - "//xla/stream_executor/platform", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:platform_port", - "@local_tsl//tsl/platform:status", - ]), + ], ) # Buildozer can not remove dependencies inside select guards, so we have to use # an intermediate target. -cc_library( - name = "ptxas_wrapper", - visibility = ["//visibility:public"], -) +cc_library(name = "ptxas_wrapper") -cc_library( - name = "nvlink_wrapper", - visibility = ["//visibility:public"], -) +cc_library(name = "nvlink_wrapper") # Buildozer can not remove dependencies inside select guards, so we have to use # an intermediate target. -cc_library( - name = "fatbinary_wrapper", - visibility = ["//visibility:public"], -) +cc_library(name = "fatbinary_wrapper") -cc_library( +cuda_only_cc_library( name = "cuda_driver", - srcs = if_cuda_is_configured(["cuda_driver.cc"]), - hdrs = if_cuda_is_configured(["cuda_driver.h"]), - visibility = ["//visibility:public"], - deps = if_cuda_is_configured([ - ":cuda_diagnostics", + srcs = ["cuda_driver.cc"], + hdrs = ["cuda_driver.h"], + local_defines = ["GOOGLE_CUDA=1"], + deps = [ + "//xla/stream_executor", + "//xla/stream_executor:device_options", + "//xla/stream_executor/gpu:gpu_diagnostics_header", + "//xla/stream_executor/gpu:gpu_driver_header", + "//xla/stream_executor/gpu:gpu_types_header", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/debugging:leak_check", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@local_config_cuda//cuda:cuda_headers", - "//xla/stream_executor:device_options", - "//xla/stream_executor", - "//xla/stream_executor/gpu:gpu_diagnostics_header", - "//xla/stream_executor/gpu:gpu_driver_header", - "//xla/stream_executor/gpu:gpu_types_header", - "//xla/stream_executor/platform", - "//xla/stream_executor/platform:dso_loader", "@local_tsl//tsl/cuda", "@local_tsl//tsl/cuda:cudart", "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:numbers", "@local_tsl//tsl/platform:stacktrace", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - ]), + ], ) -cc_library( +cuda_only_cc_library( name = "cuda_runtime", - srcs = if_cuda_is_configured(["cuda_runtime.cc"]), - visibility = ["//visibility:public"], + srcs = ["cuda_runtime.cc"], deps = [ "//xla/stream_executor/gpu:gpu_runtime_header", "//xla/stream_executor/gpu:gpu_types_header", @@ -199,18 +167,16 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@local_config_cuda//cuda:cuda_headers", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", - ] + if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", - ]), + ], ) -cc_library( +cuda_only_cc_library( name = "cuda_collectives", - srcs = if_cuda_is_configured(["cuda_collectives.cc"]), + srcs = ["cuda_collectives.cc"], defines = if_nccl(["STREAM_EXECUTOR_GPU_ENABLE_XCCL"]), - visibility = ["//visibility:public"], deps = [ ":cuda_driver", "//xla/stream_executor/gpu:gpu_collectives_header", @@ -251,7 +217,8 @@ xla_cc_test( ], deps = [ ":cuda_driver", - "@com_google_absl//absl/memory", + "//xla/stream_executor/gpu:gpu_driver_header", + "@com_google_absl//absl/log", "@local_config_cuda//cuda:cuda_headers", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", @@ -269,8 +236,9 @@ xla_cc_test( ":cuda_platform", "//xla/stream_executor", "//xla/stream_executor:device_memory", - "//xla/stream_executor:multi_platform_manager", - "@com_google_absl//absl/memory", + "//xla/stream_executor:platform_manager", + "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", ], @@ -285,77 +253,66 @@ cc_library( deps = ["//xla/stream_executor/gpu:gpu_activation_header"], ) -cc_library( +cuda_only_cc_library( name = "cuda_activation", srcs = [], - hdrs = if_cuda_is_configured(["cuda_activation.h"]), - visibility = ["//visibility:public"], - deps = if_cuda_is_configured([ + hdrs = ["cuda_activation.h"], + deps = [ ":cuda_driver", - "@local_config_cuda//cuda:cuda_headers", "//xla/stream_executor", "//xla/stream_executor:stream_executor_internal", "//xla/stream_executor/gpu:gpu_activation", "//xla/stream_executor/platform", - ]), + "@local_config_cuda//cuda:cuda_headers", + ], ) -cc_library( +cuda_only_cc_library( name = "cublas_lt_header", - hdrs = if_cuda_is_configured([ + hdrs = [ "cuda_blas_lt.h", "cuda_blas_utils.h", - ]), + ], visibility = ["//visibility:public"], - deps = if_cuda_is_configured([ - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@local_config_cuda//cuda:cuda_headers", + deps = [ "//xla:types", - "//xla/stream_executor:host_or_device_scalar", "//xla/stream_executor", "//xla/stream_executor/gpu:gpu_blas_lt", - "//xla/stream_executor/platform", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + "@local_config_cuda//cuda:cuda_headers", "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - ]), + ], ) -cc_library( +cuda_only_cc_library( name = "cublas_plugin", - srcs = if_cuda_is_configured([ + srcs = [ "cuda_blas.cc", "cuda_blas_lt.cc", - ]), - hdrs = if_cuda_is_configured([ + ], + hdrs = [ "cuda_blas.h", "cuda_blas_lt.h", - ]), + ], visibility = ["//visibility:public"], - deps = if_cuda_is_configured([ + deps = [ ":cuda_activation", ":cuda_blas_utils", ":cuda_executor", ":cuda_helpers", ":cuda_platform_id", ":cuda_stream", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@eigen_archive//:eigen3", - "@local_config_cuda//cuda:cuda_headers", "//xla:shape_util", "//xla:status_macros", "//xla:types", "//xla:util", "//xla/stream_executor", - "//xla/stream_executor:plugin_registry", "//xla/stream_executor:device_memory", "//xla/stream_executor:host_or_device_scalar", + "//xla/stream_executor:plugin_registry", "//xla/stream_executor/gpu:gpu_activation_header", "//xla/stream_executor/gpu:gpu_blas_lt", "//xla/stream_executor/gpu:gpu_executor_header", @@ -364,67 +321,76 @@ cc_library( "//xla/stream_executor/gpu:gpu_timer", "//xla/stream_executor/gpu:gpu_types_header", "//xla/stream_executor/platform", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@eigen_archive//:eigen3", + "@local_config_cuda//cuda:cuda_headers", "@local_tsl//tsl/cuda:cublas", "@local_tsl//tsl/cuda:cublas_lt", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", - ]) + if_static([ + "@local_tsl//tsl/protobuf:dnn_proto_cc", + ] + if_static([ "@local_tsl//tsl/platform:tensor_float_32_utils", ]), alwayslink = True, ) -cc_library( +cuda_only_cc_library( name = "cuda_blas_utils", - srcs = if_cuda_is_configured(["cuda_blas_utils.cc"]), - hdrs = if_cuda_is_configured(["cuda_blas_utils.h"]), - visibility = ["//visibility:public"], - deps = if_cuda_is_configured([ - "@com_google_absl//absl/strings", + srcs = ["cuda_blas_utils.cc"], + hdrs = ["cuda_blas_utils.h"], + deps = [ + "//xla/stream_executor", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@local_config_cuda//cuda:cuda_headers", - "//xla/stream_executor", "@local_tsl//tsl/cuda:cublas", "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", - ]), + ], ) -cc_library( +cuda_only_cc_library( name = "cufft_plugin", - srcs = if_cuda_is_configured(["cuda_fft.cc"]), - hdrs = if_cuda_is_configured(["cuda_fft.h"]), + srcs = ["cuda_fft.cc"], + hdrs = ["cuda_fft.h"], visibility = ["//visibility:public"], - deps = if_cuda_is_configured([ + deps = [ ":cuda_activation_header", - ":cuda_helpers", ":cuda_platform_id", - ":cuda_stream", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/status", - "@local_config_cuda//cuda:cuda_headers", "//xla/stream_executor", - "//xla/stream_executor:stream_executor_internal", "//xla/stream_executor:fft", "//xla/stream_executor:plugin_registry", + "//xla/stream_executor:stream_executor_internal", "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_helpers_header", + "//xla/stream_executor/gpu:gpu_stream_header", "//xla/stream_executor/platform", - "//xla/stream_executor/platform:dso_loader", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@local_config_cuda//cuda:cuda_headers", "@local_tsl//tsl/cuda:cufft", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", - ]), + "@local_tsl//tsl/platform:statusor", + ], alwayslink = True, ) cc_library( name = "cuda_dnn_headers", textual_hdrs = ["cuda_dnn.h"], - visibility = ["//visibility:public"], deps = if_cuda_is_configured([ ":cuda_activation_header", "//xla/stream_executor:dnn", @@ -432,45 +398,55 @@ cc_library( ]) + [ "//xla/stream_executor", # build_cleaner: keep "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/protobuf:dnn_proto_cc", ], ) -cc_library( +cuda_only_cc_library( name = "cudnn_plugin", - srcs = if_cuda_is_configured(["cuda_dnn.cc"]), - hdrs = if_cuda_is_configured(["cuda_dnn.h"]), + srcs = ["cuda_dnn.cc"], + hdrs = ["cuda_dnn.h"], copts = tf_additional_cudnn_plugin_copts(), visibility = ["//visibility:public"], - deps = if_cuda_is_configured([ + deps = [ ":cuda_activation", ":cuda_diagnostics", ":cuda_driver", ":cuda_executor", ":cuda_platform_id", ":cuda_stream", + "//xla/stream_executor", + "//xla/stream_executor:dnn", + "//xla/stream_executor:plugin_registry", + "//xla/stream_executor:stream_executor_internal", + "//xla/stream_executor/gpu:gpu_activation_header", + "//xla/stream_executor/gpu:gpu_diagnostics_header", + "//xla/stream_executor/gpu:gpu_driver_header", + "//xla/stream_executor/gpu:gpu_executor_header", + "//xla/stream_executor/gpu:gpu_stream", + "//xla/stream_executor/gpu:gpu_timer_header", + "//xla/stream_executor/platform", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@com_google_absl//absl/status:statusor", "@cudnn_frontend_archive//:cudnn_frontend", "@eigen_archive//:eigen3", "@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:cudnn_header", - "//xla/stream_executor:dnn", - "//xla/stream_executor:plugin_registry", - "//xla/stream_executor:stream_executor_internal", - "//xla/stream_executor", - "//xla/stream_executor/gpu:gpu_executor_header", - "//xla/stream_executor/gpu:gpu_timer_header", - "//xla/stream_executor/platform", "@local_tsl//tsl/cuda:cudnn", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", @@ -478,91 +454,69 @@ cc_library( "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", "@local_tsl//tsl/platform:tensor_float_32_utils", + "@local_tsl//tsl/protobuf:dnn_proto_cc", "@local_tsl//tsl/util:env_var", - ]), + ], alwayslink = True, ) -cc_library( +cuda_only_cc_library( name = "cuda_kernel", - srcs = if_cuda_is_configured(["cuda_kernel.cc"]), - hdrs = if_cuda_is_configured(["cuda_kernel.h"]), - visibility = ["//visibility:public"], - deps = if_cuda_is_configured([ + srcs = ["cuda_kernel.cc"], + hdrs = ["cuda_kernel.h"], + deps = [ ":cuda_driver", - "@com_google_absl//absl/log", - "@local_config_cuda//cuda:cuda_headers", - "@com_google_absl//absl/status:statusor", "//xla/stream_executor", - "//xla/stream_executor/gpu:gpu_kernel_header", "//xla/stream_executor/gpu:gpu_driver_header", + "//xla/stream_executor/gpu:gpu_kernel_header", "//xla/stream_executor/platform", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@local_config_cuda//cuda:cuda_headers", "@local_tsl//tsl/platform:statusor", - ]), + ], ) -cuda_library( +cc_library( name = "cuda_conditional_kernels", - srcs = if_cuda( - ["cuda_conditional_kernels.cu.cc"], - ["cuda_conditional_kernels.cc"], - ), - local_defines = select({ - ":graph_conditional_enabled": ["STREAM_EXECUTOR_CUDA_ENABLE_GRAPH_CONDITIONAL=1"], - "//conditions:default": [], - }), - visibility = ["//visibility:public"], - deps = [ - "@com_google_absl//absl/log", - "@local_config_cuda//cuda:cuda_headers", - ], + srcs = ["cuda_conditional_kernels.cc"], ) # TODO(leary) we likely need to canonicalize/eliminate this. cc_library( name = "cuda_helpers", textual_hdrs = if_cuda_is_configured(["cuda_helpers.h"]), - visibility = ["//visibility:public"], deps = if_cuda_is_configured([ "//xla/stream_executor/gpu:gpu_helpers_header", ]), ) -cc_library( +cuda_only_cc_library( name = "cuda_event", - srcs = if_cuda_is_configured(["cuda_event.cc"]), - hdrs = if_cuda_is_configured(["cuda_event.h"]), - visibility = ["//visibility:public"], - deps = if_cuda_is_configured([ - ":cuda_driver", - ":cuda_stream", + srcs = ["cuda_event.cc"], + hdrs = ["cuda_event.h"], + deps = [ "//xla/stream_executor", + "//xla/stream_executor/gpu:gpu_driver_header", "//xla/stream_executor/gpu:gpu_event", + "//xla/stream_executor/gpu:gpu_executor_header", + "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@local_config_cuda//cuda:cuda_headers", - "//xla/stream_executor/gpu:gpu_executor_header", - "//xla/stream_executor/gpu:gpu_stream_header", - "@local_tsl//tsl/platform:statusor", - ]), + ], ) -cc_library( +cuda_only_cc_library( name = "cuda_stream", srcs = [], - hdrs = if_cuda_is_configured(["cuda_stream.h"]), - visibility = ["//visibility:public"], - deps = if_cuda_is_configured([ - ":cuda_driver", + hdrs = ["cuda_stream.h"], + deps = [ "//xla/stream_executor", "//xla/stream_executor/gpu:gpu_stream", - "//xla/stream_executor/platform", - ]), + ], ) -cc_library( - name = "libnvptxcompiler_empty", - visibility = ["//visibility:public"], -) +cc_library(name = "libnvptxcompiler_empty") # This intermediate target is needed because we can't nest `select` statements. alias( @@ -572,12 +526,11 @@ alias( "//conditions:default": ":libnvptxcompiler_empty", }), tags = ["manual"], - visibility = ["//visibility:public"], ) -cc_library( +cuda_only_cc_library( name = "cuda_asm_compiler", - srcs = if_cuda_is_configured(["cuda_asm_compiler.cc"]), + srcs = ["cuda_asm_compiler.cc"], copts = tf_additional_gpu_compilation_copts(), local_defines = select({ ":libnvptxcompiler_support_enabled": [ @@ -585,9 +538,13 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], - deps = if_cuda_is_configured([ + deps = [ ":libnvptxcompiler_if_enabled", + "//xla:status_macros", + "//xla/stream_executor:stream_executor_headers", + "//xla/stream_executor/gpu:asm_compiler_header", + "//xla/stream_executor/gpu:gpu_asm_opts", + "//xla/stream_executor/gpu:gpu_driver_header", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", @@ -598,64 +555,79 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@local_config_cuda//cuda:cuda_headers", - "//xla:status_macros", - "//xla/stream_executor/gpu:asm_compiler_header", - "//xla/stream_executor/gpu:gpu_asm_opts", - "//xla/stream_executor/gpu:gpu_diagnostics_header", - "//xla/stream_executor/gpu:gpu_driver_header", + "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:subprocess", - ]), + ], ) -cc_library( +xla_cc_test( + name = "cuda_asm_compiler_test", + srcs = ["cuda_asm_compiler_test.cc"], + local_defines = select({ + ":libnvptxcompiler_support_enabled": [ + "ENABLE_LIBNVPTXCOMPILER_SUPPORT=1", + ], + "//conditions:default": [], + }), + tags = tf_cuda_tests_tags(), + deps = [ + ":cuda_asm_compiler", # buildcleaner: keep + "//xla/stream_executor:device_description", + "//xla/stream_executor/gpu:asm_compiler", # buildcleaner: keep + "//xla/stream_executor/gpu:gpu_asm_opts", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +cuda_only_cc_library( name = "cuda_executor", - srcs = if_cuda_is_configured(["cuda_executor.cc"]), - visibility = ["//visibility:public"], - deps = if_cuda_is_configured([ - ":cuda_activation", - ":cuda_asm_compiler", + srcs = ["cuda_executor.cc"], + deps = [ + ":cuda_collectives", # buildcleaner: keep ":cuda_diagnostics", ":cuda_driver", - ":cuda_event", - ":cuda_kernel", + ":cuda_event", # buildcleaner: keep + ":cuda_kernel", # buildcleaner: keep ":cuda_platform_id", - ":cuda_runtime", - ":cuda_collectives", - ":cuda_stream", - "@com_google_absl//absl/functional:any_invocable", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:span", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", - "@local_config_cuda//cuda:cuda_headers", + ":cuda_runtime", # buildcleaner: keep "//xla/stream_executor", "//xla/stream_executor:plugin_registry", "//xla/stream_executor:stream_executor_internal", - "//xla/stream_executor/gpu:asm_compiler", - "//xla/stream_executor/gpu:gpu_command_buffer", "//xla/stream_executor/gpu:gpu_collectives_header", + "//xla/stream_executor/gpu:gpu_command_buffer", + "//xla/stream_executor/gpu:gpu_diagnostics_header", "//xla/stream_executor/gpu:gpu_driver_header", - "//xla/stream_executor/gpu:gpu_runtime_header", "//xla/stream_executor/gpu:gpu_event_header", - "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_kernel_header", + "//xla/stream_executor/gpu:gpu_runtime_header", "//xla/stream_executor/gpu:gpu_stream_header", "//xla/stream_executor/gpu:gpu_timer", "//xla/stream_executor/gpu:gpu_types_header", - "//xla/stream_executor/platform", - "//xla/stream_executor/platform:dso_loader", + "@com_google_absl//absl/base", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/numeric:int128", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@local_config_cuda//cuda:cuda_headers", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:fingerprint", "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:numbers", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", - ]), + ], alwayslink = True, ) @@ -680,12 +652,10 @@ cc_library( cc_library( name = "IOKit", linkopts = ["-framework IOKit"], - visibility = ["//visibility:public"], ) cc_library( name = "stream_executor_cuda", - visibility = ["//visibility:public"], deps = [ "//xla/stream_executor:stream_executor_bundle", ] + if_google( diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc b/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc index bf2a8957164684..e61ebb233dcc62 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc @@ -20,11 +20,12 @@ limitations under the License. #include #include #include +#include #include #include "absl/algorithm/container.h" -#include "absl/base/attributes.h" #include "absl/base/call_once.h" +#include "absl/base/optimization.h" #include "absl/cleanup/cleanup.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -33,12 +34,16 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "xla/status_macros.h" #include "xla/stream_executor/gpu/asm_compiler.h" #include "xla/stream_executor/gpu/gpu_asm_opts.h" -#include "xla/stream_executor/gpu/gpu_diagnostics.h" #include "xla/stream_executor/gpu/gpu_driver.h" +#include "xla/stream_executor/stream_executor.h" +#include "tsl/platform/env.h" #include "tsl/platform/errors.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/subprocess.h" #ifdef ENABLE_LIBNVPTXCOMPILER_SUPPORT @@ -69,8 +74,10 @@ absl::StatusOr> LinkUsingNvlink( absl::call_once(log_once, [] { LOG(INFO) << "Using nvlink for parallel linking"; }); } - const std::string bin_path = - FindCudaExecutable("nvlink", std::string(preferred_cuda_dir)); + + TF_ASSIGN_OR_RETURN( + std::string bin_path, + FindCudaExecutable("nvlink", std::string(preferred_cuda_dir))); if (images.empty()) { return std::vector(); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler_test.cc new file mode 100644 index 00000000000000..6151beac0318f6 --- /dev/null +++ b/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler_test.cc @@ -0,0 +1,219 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/gpu/asm_compiler.h" +#include "xla/stream_executor/gpu/gpu_asm_opts.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/test.h" + +#ifdef ENABLE_LIBNVPTXCOMPILER_SUPPORT + +namespace { + +// Generated by the following command: +// +// echo "__global__ void kernel(int* in) { for (int i=0; i < 16; i++) \ +// { in[i] += i; } for (int i=0; i < 16; i++) { in[15-i] += i; }}" \ +// | nvcc -o - -rdc true --ptx --x cu - -O0 +// +// The `.maxnreg` directive was added manually afterwards. +constexpr const char kSpillingPtx[] = R"( +// +// Generated by NVIDIA NVVM Compiler +// +// Compiler Build ID: CL-32267302 +// Cuda compilation tools, release 12.0, V12.0.140 +// Based on NVVM 7.0.1 +// + +.version 8.0 +.target sm_52 +.address_size 64 + + // .globl _Z6kernelPi + +.visible .entry _Z6kernelPi( + .param .u64 _Z6kernelPi_param_0 +) + .maxnreg 16 +{ + .reg .b32 %r<33>; + .reg .b64 %rd<3>; + + + ld.param.u64 %rd1, [_Z6kernelPi_param_0]; + cvta.to.global.u64 %rd2, %rd1; + ld.global.u32 %r1, [%rd2+4]; + ld.global.u32 %r2, [%rd2+8]; + ld.global.u32 %r3, [%rd2+12]; + ld.global.u32 %r4, [%rd2+16]; + ld.global.u32 %r5, [%rd2+20]; + ld.global.u32 %r6, [%rd2+24]; + ld.global.u32 %r7, [%rd2+28]; + ld.global.u32 %r8, [%rd2+32]; + ld.global.u32 %r9, [%rd2+36]; + ld.global.u32 %r10, [%rd2+40]; + ld.global.u32 %r11, [%rd2+44]; + ld.global.u32 %r12, [%rd2+48]; + ld.global.u32 %r13, [%rd2+52]; + ld.global.u32 %r14, [%rd2+56]; + ld.global.u32 %r15, [%rd2+60]; + add.s32 %r16, %r15, 15; + st.global.u32 [%rd2+60], %r16; + add.s32 %r17, %r14, 15; + st.global.u32 [%rd2+56], %r17; + add.s32 %r18, %r13, 15; + st.global.u32 [%rd2+52], %r18; + add.s32 %r19, %r12, 15; + st.global.u32 [%rd2+48], %r19; + add.s32 %r20, %r11, 15; + st.global.u32 [%rd2+44], %r20; + add.s32 %r21, %r10, 15; + st.global.u32 [%rd2+40], %r21; + add.s32 %r22, %r9, 15; + st.global.u32 [%rd2+36], %r22; + add.s32 %r23, %r8, 15; + st.global.u32 [%rd2+32], %r23; + add.s32 %r24, %r7, 15; + st.global.u32 [%rd2+28], %r24; + add.s32 %r25, %r6, 15; + st.global.u32 [%rd2+24], %r25; + add.s32 %r26, %r5, 15; + st.global.u32 [%rd2+20], %r26; + add.s32 %r27, %r4, 15; + st.global.u32 [%rd2+16], %r27; + add.s32 %r28, %r3, 15; + st.global.u32 [%rd2+12], %r28; + add.s32 %r29, %r2, 15; + st.global.u32 [%rd2+8], %r29; + add.s32 %r30, %r1, 15; + st.global.u32 [%rd2+4], %r30; + ld.global.u32 %r31, [%rd2]; + add.s32 %r32, %r31, 15; + st.global.u32 [%rd2], %r32; + ret; +} +)"; + +// Generated by the following command: +// +// echo "__global__ void kernel(int* output) { *output = 42; }" | +// nvcc -o - -rdc true --ptx --x cu - +// +constexpr const char kSimplePtx[] = R"( +.version 8.0 +.target sm_52 +.address_size 64 + + // .globl _Z6kernelPi + +.visible .entry _Z6kernelPi ( + .param .u64 _Z6kernelPi_param_0 +) +{ + .reg .b32 %r<16>; + .reg .b64 %rd<3>; + + + ld.param.u64 %rd1, [_Z6kernelPi_param_0]; + cvta.to.global.u64 %rd2, %rd1; + mov.u32 %r1, 42; + st.global.u32 [%rd2], %r15; + ret; + +})"; + +constexpr stream_executor::CudaComputeCapability kDefaultComputeCapability{5, + 2}; + +absl::StatusOr> CompileHelper( + stream_executor::CudaComputeCapability cc, const char* const ptx_input, + bool disable_gpuasm_optimizations = false, bool cancel_if_reg_spill = false, + std::vector extra_flags = {}) { + stream_executor::GpuAsmOpts options{}; + options.disable_gpuasm_optimizations = disable_gpuasm_optimizations; + options.extra_flags = std::move(extra_flags); + + return stream_executor::CompileGpuAsmUsingLibNvPtxCompiler( + cc.major, cc.minor, ptx_input, options, cancel_if_reg_spill); +} + +TEST(NvPtxCompilerTest, IdentifiesUnsupportedArchitecture) { + stream_executor::GpuAsmOpts options{}; + EXPECT_THAT( + CompileHelper(stream_executor::CudaComputeCapability{100, 0}, kSimplePtx), + tsl::testing::StatusIs(absl::StatusCode::kUnimplemented)); +} + +TEST(NvPtxCompilerTest, CanCompileSingleCompilationUnit) { + stream_executor::GpuAsmOpts options{}; + EXPECT_THAT(CompileHelper(kDefaultComputeCapability, kSimplePtx), + tsl::testing::IsOk()); +} + +TEST(NvPtxCompilerTest, CancelsOnRegSpill) { + // We have to disable optimization here, otherwise PTXAS will optimize our + // trivial register usages away and we don't spill as intended. + EXPECT_THAT(CompileHelper(kDefaultComputeCapability, kSpillingPtx, + /*disable_gpuasm_optimizations=*/true, + /*cancel_if_reg_spill=*/true), + tsl::testing::StatusIs(absl::StatusCode::kCancelled)); + + // We also test the converse to ensure our test case isn't broken. + EXPECT_THAT(CompileHelper(kDefaultComputeCapability, kSpillingPtx, + /*disable_gpuasm_optimizations=*/true, + /*cancel_if_reg_spill=*/false), + tsl::testing::IsOk()); +} + +TEST(NvPtxCompilerTest, AcceptsExtraArguments) { + // It's tricky to test whether `extra_arguments` works without depending on + // too much nvptx internals. So we pass the `--generate-line-info` flags and + // expect strictly larger outputs than without the flag. + auto reference_cubin = CompileHelper(kDefaultComputeCapability, kSimplePtx, + /*disable_gpuasm_optimizations=*/false, + /*cancel_if_reg_spill=*/false, {}); + auto cubin_with_line_info = + CompileHelper(kDefaultComputeCapability, kSimplePtx, + /*disable_gpuasm_optimizations=*/false, + /*cancel_if_reg_spill=*/false, {"--generate-line-info"}); + + EXPECT_THAT(reference_cubin, tsl::testing::IsOk()); + EXPECT_THAT(cubin_with_line_info, tsl::testing::IsOk()); + EXPECT_GT(cubin_with_line_info->size(), reference_cubin->size()); + + // We also test whether invalid flags lead to a compilation error. + EXPECT_THAT( + CompileHelper(kDefaultComputeCapability, kSimplePtx, + /*disable_gpuasm_optimizations=*/false, + /*cancel_if_reg_spill=*/false, {"--flag-does-not-exist"}), + tsl::testing::StatusIs(absl::StatusCode::kInternal)); +} + +} // namespace + +#endif // if defined(ENABLE_LIBNVPTXCOMPILER_SUPPORT) diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas.cc b/third_party/xla/xla/stream_executor/cuda/cuda_blas.cc index ca8bf1b3087594..dac34ead0ad30b 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_blas.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas.cc @@ -16,24 +16,32 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_blas.h" #include +#include #include +#include #include +#include #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "absl/time/time.h" #include "Eigen/Core" // from @eigen_archive #include "third_party/gpus/cuda/include/cublas_v2.h" #include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/cuda_bf16.h" +#include "third_party/gpus/cuda/include/cuda_fp16.h" +#include "third_party/gpus/cuda/include/driver_types.h" +#include "third_party/gpus/cuda/include/library_types.h" +#include "third_party/gpus/cuda/include/vector_types.h" #include "xla/stream_executor/blas.h" -#include "xla/stream_executor/cuda/cuda_activation.h" #include "xla/stream_executor/cuda/cuda_blas_utils.h" -#include "xla/stream_executor/cuda/cuda_helpers.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" -#include "xla/stream_executor/cuda/cuda_stream.h" +#include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/gpu/gpu_activation.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_helpers.h" #include "xla/stream_executor/gpu/gpu_stream.h" @@ -45,8 +53,11 @@ limitations under the License. #include "xla/stream_executor/plugin_registry.h" #include "xla/stream_executor/scratch_allocator.h" #include "xla/stream_executor/stream_executor.h" +#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/tensor_float_32_utils.h" +#include "tsl/protobuf/dnn.pb.h" namespace stream_executor { namespace cuda { @@ -833,7 +844,9 @@ absl::Status CUDABlas::DoBlasGemmStridedBatchedWithAlgorithm( } bool CUDABlas::GetBlasGemmAlgorithms( - Stream *stream, std::vector *out_algorithms) { + Stream *stream, const gpu::MatrixDescriptor &, + const gpu::MatrixDescriptor &, gpu::OutputMatrixDescriptor *, const void *, + const void *, std::vector *out_algorithms) { // cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx) // were first introduced in CUDA 8. // @@ -942,53 +955,22 @@ absl::Status CUDABlas::DoBlasGemmBatchedInternal( const size_t size = batch_count * sizeof(CUDA_T *); - // Device-side copy of pointers to matrices. - DeviceMemory a; - DeviceMemory b; - DeviceMemory c; - - // If temporary space is allocated for device-side copies of pointers to - // matrices, that temporary space should not be freed until this function - // returns. Although the values for these unique_ptrs are not set here, they - // are declared at this scope so they will be destroyed when the function - // returns. - // - // If a scratch allocator is provided, these pointers will not be used at all. - std::unique_ptr> a_temporary; - std::unique_ptr> b_temporary; - std::unique_ptr> c_temporary; - - // Decide how to allocate device-side copy of pointers to matrices based on - // whether a scratch allocator was passed. - if (scratch_allocator != nullptr) { - TF_ASSIGN_OR_RETURN(DeviceMemory a_bytes, - scratch_allocator->AllocateBytes(size)); - TF_ASSIGN_OR_RETURN(DeviceMemory b_bytes, - scratch_allocator->AllocateBytes(size)); - TF_ASSIGN_OR_RETURN(DeviceMemory c_bytes, - scratch_allocator->AllocateBytes(size)); - a = DeviceMemory(a_bytes); - b = DeviceMemory(b_bytes); - c = DeviceMemory(c_bytes); - } else { - TF_ASSIGN_OR_RETURN(a_temporary, - stream->AllocateTemporaryArray(batch_count)); - TF_ASSIGN_OR_RETURN(b_temporary, - stream->AllocateTemporaryArray(batch_count)); - TF_ASSIGN_OR_RETURN(c_temporary, - stream->AllocateTemporaryArray(batch_count)); - a = DeviceMemory(*a_temporary->mutable_device_memory()); - b = DeviceMemory(*b_temporary->mutable_device_memory()); - c = DeviceMemory(*c_temporary->mutable_device_memory()); - } - - if (!stream->ThenMemcpy(&a, a_raw_ptrs.data(), size).ok() || - !stream->ThenMemcpy(&b, b_raw_ptrs.data(), size).ok() || - !stream->ThenMemcpy(&c, c_raw_ptrs.data(), size).ok()) { - return absl::InternalError( - "failed to copy memory from host to device in " - "CUDABlas::DoBlasGemmBatched"); + if (scratch_allocator == nullptr) { + return absl::InternalError("scratch_allocator is null"); } + TF_ASSIGN_OR_RETURN(DeviceMemory a_bytes, + scratch_allocator->AllocateBytes(size)); + TF_ASSIGN_OR_RETURN(DeviceMemory b_bytes, + scratch_allocator->AllocateBytes(size)); + TF_ASSIGN_OR_RETURN(DeviceMemory c_bytes, + scratch_allocator->AllocateBytes(size)); + DeviceMemory a(a_bytes); + DeviceMemory b(b_bytes); + DeviceMemory c(c_bytes); + + TF_RETURN_IF_ERROR(stream->Memcpy(&a, a_raw_ptrs.data(), size)); + TF_RETURN_IF_ERROR(stream->Memcpy(&b, b_raw_ptrs.data(), size)); + TF_RETURN_IF_ERROR(stream->Memcpy(&c, c_raw_ptrs.data(), size)); cudaDataType_t data_type = CUDADataType::type; @@ -1471,5 +1453,6 @@ void initialize_cublas() { } // namespace cuda } // namespace stream_executor -REGISTER_MODULE_INITIALIZER(register_cublas, - { stream_executor::cuda::initialize_cublas(); }); +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(register_cublas, { + stream_executor::cuda::initialize_cublas(); +}); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas.h b/third_party/xla/xla/stream_executor/cuda/cuda_blas.h index f760a968f66dd7..5f69e8b04765a8 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_blas.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas.h @@ -20,14 +20,17 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_ #define XLA_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_ +#include + #include "absl/base/thread_annotations.h" +#include "absl/status/status.h" #include "absl/synchronization/mutex.h" #include "third_party/gpus/cuda/include/cublas_v2.h" +#include "third_party/gpus/cuda/include/driver_types.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/cuda/cuda_blas_lt.h" -#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/numeric_options.h" #include "xla/stream_executor/platform/port.h" -#include "xla/stream_executor/plugin_registry.h" namespace stream_executor { diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc index 6fde4412ba59fe..0d20233d296e4b 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc @@ -15,20 +15,27 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_blas_lt.h" +#include #include #include #include +#include #include #include #include -#include #include #include #include +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" #include "third_party/gpus/cuda/include/cublasLt.h" #include "third_party/gpus/cuda/include/cublas_v2.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/library_types.h" #include "xla/primitive_util.h" #include "xla/status_macros.h" #include "xla/stream_executor/blas.h" @@ -41,7 +48,11 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_timer.h" #include "xla/stream_executor/scratch_allocator.h" #include "xla/stream_executor/stream.h" +#include "xla/types.h" #include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/ml_dtypes.h" +#include "tsl/platform/statusor.h" #define SET_ATTR(setter, handle, attr, value) \ ToStatus(setter(handle, attr, &value, sizeof(decltype(value))), #setter) diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h index 8c3e86c19c92a5..2a0a5611b81ce0 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h @@ -16,13 +16,20 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_BLAS_LT_H_ #define XLA_STREAM_EXECUTOR_CUDA_CUDA_BLAS_LT_H_ +#include #include +#include +#include #include +#include "absl/base/thread_annotations.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" #include "third_party/gpus/cuda/include/cublasLt.h" #include "third_party/gpus/cuda/include/cublas_v2.h" +#include "third_party/gpus/cuda/include/library_types.h" +#include "xla/stream_executor/blas.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_blas_lt.h" #include "xla/types.h" diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas_utils.cc b/third_party/xla/xla/stream_executor/cuda/cuda_blas_utils.cc index c85c30b39f7e75..ec700d6f3168f1 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_blas_utils.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas_utils.cc @@ -15,10 +15,12 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_blas_utils.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "third_party/gpus/cuda/include/cublas_v2.h" #include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/library_types.h" #include "xla/stream_executor/blas.h" namespace stream_executor { diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas_utils.h b/third_party/xla/xla/stream_executor/cuda/cuda_blas_utils.h index cb37c850a48dd4..aaaf4257f4f5b5 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_blas_utils.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas_utils.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_BLAS_UTILS_H_ #define XLA_STREAM_EXECUTOR_CUDA_CUDA_BLAS_UTILS_H_ -#include #include "absl/status/status.h" #include "third_party/gpus/cuda/include/cublas_v2.h" +#include "third_party/gpus/cuda/include/library_types.h" #include "xla/stream_executor/blas.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cc b/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cc index cb6a9e716f46b4..005889c540587c 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cc @@ -13,33 +13,732 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "absl/log/log.h" +#include namespace stream_executor::gpu { -void* GetSetIfConditionKernel() { - LOG(ERROR) << "XLA compiled without --config=cuda"; - return nullptr; +// Collection of helper kernels required by command buffers on CUDA backends. We +// use pre-compiled PTX instead of a CUDA C++ because conditional nodes require +// CUDA 12.3+ and trying to run with earlier CUDA versions leads to run time +// errors as all CUDA C++ kernels registered in a global static registry and a +// failure to load ONE kernel leads to failure to load ANY kernel at all. We +// should be able to switch to CUDA C++ once the minimum supported CUDA version +// will be larger than 12.3. + +// In all kernels defined below we set conditional handle value to `1` when we +// want to execute a CUDA graph tied to it, and to `0` otherwise. For loops, the +// graph will keep being executed until the conditional handle becomes `0`. + +// PTX kernel compiled from: +// +// __global__ void SetIfCondition(cudaGraphConditionalHandle then_handle, +// bool* predicate) { +// if (*predicate) { +// cudaGraphSetConditional(then_handle, 1); +// } else { +// cudaGraphSetConditional(then_handle, 0); +// } +// } +// +// Easiest way to get PTX from C++ is to use https://godbolt.org. +std::string_view GetSetIfConditionKernel() { + return R"( +.version 4.0 +.target sm_50 +.address_size 64 + +.extern .func cudaGraphSetConditional +( + .param .b64 cudaGraphSetConditional_param_0, + .param .b32 cudaGraphSetConditional_param_1 +) + +.visible .entry set_if_condition( + .param .u64 set_if_condition_param_0, + .param .u64 set_if_condition_param_1 +) +{ + .reg .pred %p<2>; + .reg .b16 %rs<2>; + .reg .b64 %rd<4>; + .loc 1 1 0 + + ld.param.u64 %rd1, [set_if_condition_param_0]; + ld.param.u64 %rd2, [set_if_condition_param_1]; + .loc 1 3 3 + cvta.to.global.u64 %rd3, %rd2; + ld.global.u8 %rs1, [%rd3]; + setp.eq.s16 %p1, %rs1, 0; + @%p1 bra $L__BB0_2; + + .loc 1 4 5 + { // callseq 0, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd1; + .param .b32 param1; + st.param.b32 [param1+0], 1; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 0 + bra.uni $L__BB0_3; + +$L__BB0_2: + .loc 1 6 5 + { // callseq 1, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd1; + .param .b32 param1; + st.param.b32 [param1+0], 0; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 1 + +$L__BB0_3: + .loc 1 8 1 + ret; + +})"; } -void* GetSetIfElseConditionKernel() { - LOG(ERROR) << "XLA compiled without --config=cuda"; - return nullptr; +// PTX kernel compiled from: +// +// __global__ void SetIfElseCondition(cudaGraphConditionalHandle then_handle, +// cudaGraphConditionalHandle else_handle, +// bool* predicate) { +// if (*predicate) { +// cudaGraphSetConditional(then_handle, 1); +// cudaGraphSetConditional(else_handle, 0); +// } else { +// cudaGraphSetConditional(then_handle, 0); +// cudaGraphSetConditional(else_handle, 1); +// } +// } +// +// Easiest way to get PTX from C++ is to use https://godbolt.org. +std::string_view GetSetIfElseConditionKernel() { + return R"( +.version 4.0 +.target sm_50 +.address_size 64 + +.extern .func cudaGraphSetConditional +( + .param .b64 cudaGraphSetConditional_param_0, + .param .b32 cudaGraphSetConditional_param_1 +) + +.visible .entry set_if_else_condition( + .param .u64 set_if_else_condition_param_0, + .param .u64 set_if_else_condition_param_1, + .param .u64 set_if_else_condition_param_2 +) +{ + .reg .pred %p<2>; + .reg .b16 %rs<2>; + .reg .b64 %rd<5>; + .loc 1 1 0 + + ld.param.u64 %rd1, [set_if_else_condition_param_0]; + ld.param.u64 %rd2, [set_if_else_condition_param_1]; + ld.param.u64 %rd3, [set_if_else_condition_param_2]; + .loc 1 4 3 + cvta.to.global.u64 %rd4, %rd3; + ld.global.u8 %rs1, [%rd4]; + setp.eq.s16 %p1, %rs1, 0; + @%p1 bra $L__BB0_2; + + .loc 1 5 5 + { // callseq 0, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd1; + .param .b32 param1; + st.param.b32 [param1+0], 1; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 0 + .loc 1 6 5 + { // callseq 1, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd2; + .param .b32 param1; + st.param.b32 [param1+0], 0; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 1 + bra.uni $L__BB0_3; + +$L__BB0_2: + .loc 1 8 5 + { // callseq 2, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd1; + .param .b32 param1; + st.param.b32 [param1+0], 0; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 2 + .loc 1 9 5 + { // callseq 3, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd2; + .param .b32 param1; + st.param.b32 [param1+0], 1; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 3 + +$L__BB0_3: + .loc 1 11 1 + ret; + +})"; } -void* GetSetCaseConditionKernel() { - LOG(ERROR) << "XLA compiled without --config=cuda"; - return nullptr; +// PTX kernel compiled from: +// +// __global__ void SetCaseCondition( +// cudaGraphConditionalHandle h0, cudaGraphConditionalHandle h1, +// cudaGraphConditionalHandle h2, cudaGraphConditionalHandle h3, +// cudaGraphConditionalHandle h4, cudaGraphConditionalHandle h5, +// cudaGraphConditionalHandle h6, cudaGraphConditionalHandle h7, +// int32_t* index, int32_t num_handles) { +// // Only handles in [0, num_handles) range are valid. +// // +// // We can't define a device function with dynamic number of handle +// // arguments, so we always pass 8 handles, but only some of them are valid. +// // Size 8 picked as a reasonable (but random) upper bound for what we see +// // in XLA uses. +// std::array handles = {h0, h1, h2, h3, +// h4, h5, h6, h7}; + +// // If branch index is out of range activate the last valid handle. +// int32_t branch_index = *index; +// if (branch_index < 0 || branch_index >= num_handles) { +// branch_index = num_handles - 1; +// } + +// for (int32_t i = 0; i < num_handles; ++i) { +// if (branch_index == i) { +// cudaGraphSetConditional(handles[i], 1); +// } else { +// cudaGraphSetConditional(handles[i], 0); +// } +// } +// } +// +// Easiest way to get PTX from C++ is to use https://godbolt.org. +std::string_view GetSetCaseConditionKernel() { + return R"( +.version 4.0 +.target sm_50 +.address_size 64 + +.extern .func cudaGraphSetConditional +( + .param .b64 cudaGraphSetConditional_param_0, + .param .b32 cudaGraphSetConditional_param_1 +) + +.visible .entry set_case_condition( + .param .u64 set_case_condition_param_0, + .param .u64 set_case_condition_param_1, + .param .u64 set_case_condition_param_2, + .param .u64 set_case_condition_param_3, + .param .u64 set_case_condition_param_4, + .param .u64 set_case_condition_param_5, + .param .u64 set_case_condition_param_6, + .param .u64 set_case_condition_param_7, + .param .u64 set_case_condition_param_8, + .param .u32 set_case_condition_param_9 +) +{ + .local .align 16 .b8 __local_depot0[64]; + .reg .b64 %SP; + .reg .b64 %SPL; + .reg .pred %p<14>; + .reg .b32 %r<31>; + .reg .b64 %rd<27>; + .loc 1 4 0 + + mov.u64 %SPL, __local_depot0; + ld.param.u64 %rd13, [set_case_condition_param_8]; + ld.param.u32 %r18, [set_case_condition_param_9]; + cvta.to.global.u64 %rd14, %rd13; + .loc 1 15 3 + add.u64 %rd1, %SPL, 0; + ld.param.u64 %rd16, [set_case_condition_param_1]; + ld.param.u64 %rd17, [set_case_condition_param_0]; + st.local.v2.u64 [%rd1], {%rd17, %rd16}; + ld.param.u64 %rd18, [set_case_condition_param_3]; + ld.param.u64 %rd19, [set_case_condition_param_2]; + st.local.v2.u64 [%rd1+16], {%rd19, %rd18}; + ld.param.u64 %rd20, [set_case_condition_param_5]; + ld.param.u64 %rd21, [set_case_condition_param_4]; + .loc 1 16 60 + st.local.v2.u64 [%rd1+32], {%rd21, %rd20}; + ld.param.u64 %rd22, [set_case_condition_param_7]; + ld.param.u64 %rd23, [set_case_condition_param_6]; + .loc 1 16 68 + st.local.v2.u64 [%rd1+48], {%rd23, %rd22}; + .loc 1 19 3 + ld.global.u32 %r19, [%rd14]; + .loc 1 20 3 + setp.lt.s32 %p1, %r19, 0; + setp.ge.s32 %p2, %r19, %r18; + or.pred %p3, %p1, %p2; + .loc 1 21 5 + add.s32 %r1, %r18, -1; + .loc 1 20 3 + selp.b32 %r2, %r1, %r19, %p3; + .loc 1 24 3 + setp.lt.s32 %p4, %r18, 1; + @%p4 bra $L__BB0_22; + + .loc 1 25 5 + and.b32 %r30, %r18, 3; + setp.lt.u32 %p5, %r1, 3; + mov.u32 %r28, 0; + @%p5 bra $L__BB0_16; + + sub.s32 %r27, %r18, %r30; + neg.s32 %r25, %r2; + mov.u32 %r28, 0; + mov.u64 %rd25, %rd1; + +$L__BB0_3: + .loc 1 0 0 + ld.local.u64 %rd4, [%rd25]; + .loc 1 25 5 + setp.eq.s32 %p6, %r25, 0; + @%p6 bra $L__BB0_5; + + .loc 1 28 7 + { // callseq 0, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd4; + .param .b32 param1; + st.param.b32 [param1+0], 0; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 0 + bra.uni $L__BB0_6; + +$L__BB0_5: + .loc 1 26 7 + { // callseq 1, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd4; + .param .b32 param1; + st.param.b32 [param1+0], 1; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 1 + +$L__BB0_6: + .loc 1 24 40 + add.s32 %r22, %r28, 1; + .loc 1 25 5 + setp.eq.s32 %p7, %r2, %r22; + .loc 1 0 0 + ld.local.u64 %rd5, [%rd25+8]; + .loc 1 25 5 + @%p7 bra $L__BB0_8; + bra.uni $L__BB0_7; + +$L__BB0_8: + .loc 1 26 7 + { // callseq 3, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd5; + .param .b32 param1; + st.param.b32 [param1+0], 1; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 3 + bra.uni $L__BB0_9; + +$L__BB0_7: + .loc 1 28 7 + { // callseq 2, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd5; + .param .b32 param1; + st.param.b32 [param1+0], 0; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 2 + +$L__BB0_9: + .loc 1 24 40 + add.s32 %r23, %r28, 2; + .loc 1 25 5 + setp.eq.s32 %p8, %r2, %r23; + .loc 1 0 0 + ld.local.u64 %rd6, [%rd25+16]; + .loc 1 25 5 + @%p8 bra $L__BB0_11; + bra.uni $L__BB0_10; + +$L__BB0_11: + .loc 1 26 7 + { // callseq 5, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd6; + .param .b32 param1; + st.param.b32 [param1+0], 1; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 5 + bra.uni $L__BB0_12; + +$L__BB0_10: + .loc 1 28 7 + { // callseq 4, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd6; + .param .b32 param1; + st.param.b32 [param1+0], 0; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 4 + +$L__BB0_12: + .loc 1 24 40 + add.s32 %r24, %r28, 3; + .loc 1 25 5 + setp.eq.s32 %p9, %r2, %r24; + .loc 1 0 0 + ld.local.u64 %rd7, [%rd25+24]; + .loc 1 25 5 + @%p9 bra $L__BB0_14; + bra.uni $L__BB0_13; + +$L__BB0_14: + .loc 1 26 7 + { // callseq 7, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd7; + .param .b32 param1; + st.param.b32 [param1+0], 1; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 7 + bra.uni $L__BB0_15; + +$L__BB0_13: + .loc 1 28 7 + { // callseq 6, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd7; + .param .b32 param1; + st.param.b32 [param1+0], 0; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 6 + +$L__BB0_15: + .loc 1 24 40 + add.s64 %rd25, %rd25, 32; + add.s32 %r28, %r28, 4; + .loc 1 24 3 + add.s32 %r25, %r25, 4; + add.s32 %r27, %r27, -4; + setp.ne.s32 %p10, %r27, 0; + @%p10 bra $L__BB0_3; + +$L__BB0_16: + .loc 1 25 5 + setp.eq.s32 %p11, %r30, 0; + @%p11 bra $L__BB0_22; + + mul.wide.s32 %rd24, %r28, 8; + add.s64 %rd26, %rd1, %rd24; + sub.s32 %r29, %r28, %r2; + +$L__BB0_18: + .pragma "nounroll"; + .loc 1 0 0 + ld.local.u64 %rd11, [%rd26]; + .loc 1 25 5 + setp.eq.s32 %p12, %r29, 0; + @%p12 bra $L__BB0_20; + + .loc 1 28 7 + { // callseq 8, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd11; + .param .b32 param1; + st.param.b32 [param1+0], 0; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 8 + bra.uni $L__BB0_21; + +$L__BB0_20: + .loc 1 26 7 + { // callseq 9, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd11; + .param .b32 param1; + st.param.b32 [param1+0], 1; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 9 + +$L__BB0_21: + .loc 1 24 3 + add.s32 %r30, %r30, -1; + add.s64 %rd26, %rd26, 8; + add.s32 %r29, %r29, 1; + setp.ne.s32 %p13, %r30, 0; + @%p13 bra $L__BB0_18; + +$L__BB0_22: + .loc 1 31 1 + ret; + +})"; } -void* GetSetForConditionKernel() { - LOG(ERROR) << "XLA compiled without --config=cuda"; - return nullptr; +// PTX kernel compiled from: +// +// __global__ void SetForCondition(cudaGraphConditionalHandle handle, +// int32_t* loop_index, +// int32_t num_iterations) { +// if (*loop_index < num_iterations) { +// cudaGraphSetConditional(handle, 1); +// } else { +// cudaGraphSetConditional(handle, 0); +// } +// *loop_index += 1; +// } +// +// Easiest way to get PTX from C++ is to use https://godbolt.org. +std::string_view GetSetForConditionKernel() { + return R"( +.version 4.0 +.target sm_50 +.address_size 64 + +.extern .func cudaGraphSetConditional +( + .param .b64 cudaGraphSetConditional_param_0, + .param .b32 cudaGraphSetConditional_param_1 +) + +.visible .entry set_for_condition( + .param .u64 set_for_condition_param_0, + .param .u64 set_for_condition_param_1, + .param .u32 set_for_condition_param_2 +) +{ + .reg .pred %p<2>; + .reg .b32 %r<5>; + .reg .b64 %rd<4>; + .loc 1 1 0 + + ld.param.u64 %rd2, [set_for_condition_param_0]; + ld.param.u64 %rd3, [set_for_condition_param_1]; + ld.param.u32 %r1, [set_for_condition_param_2]; + .loc 1 3 3 + cvta.to.global.u64 %rd1, %rd3; + ld.global.u32 %r2, [%rd1]; + setp.lt.s32 %p1, %r2, %r1; + @%p1 bra $L__BB0_2; + bra.uni $L__BB0_1; + +$L__BB0_2: + .loc 1 4 5 + { // callseq 1, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd2; + .param .b32 param1; + st.param.b32 [param1+0], 1; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 1 + bra.uni $L__BB0_3; + +$L__BB0_1: + .loc 1 6 5 + { // callseq 0, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd2; + .param .b32 param1; + st.param.b32 [param1+0], 0; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 0 + +$L__BB0_3: + .loc 1 8 3 + ld.global.u32 %r3, [%rd1]; + add.s32 %r4, %r3, 1; + st.global.u32 [%rd1], %r4; + .loc 1 9 1 + ret; + +})"; } -void* GetSetWhileConditionKernel() { - LOG(ERROR) << "XLA compiled without --config=cuda"; - return nullptr; +std::string_view GetSetWhileConditionKernel() { + // While condition kernel is the same as an `If` with a single branch. + return R"( +.version 4.0 +.target sm_50 +.address_size 64 + +.extern .func cudaGraphSetConditional +( + .param .b64 cudaGraphSetConditional_param_0, + .param .b32 cudaGraphSetConditional_param_1 +) + +.visible .entry set_while_condition( + .param .u64 set_while_condition_param_0, + .param .u64 set_while_condition_param_1 +) +{ + .reg .pred %p<2>; + .reg .b16 %rs<2>; + .reg .b64 %rd<4>; + .loc 1 1 0 + + ld.param.u64 %rd1, [set_while_condition_param_0]; + ld.param.u64 %rd2, [set_while_condition_param_1]; + .loc 1 3 3 + cvta.to.global.u64 %rd3, %rd2; + ld.global.u8 %rs1, [%rd3]; + setp.eq.s16 %p1, %rs1, 0; + @%p1 bra $L__BB0_2; + + .loc 1 4 5 + { // callseq 0, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd1; + .param .b32 param1; + st.param.b32 [param1+0], 1; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 0 + bra.uni $L__BB0_3; + +$L__BB0_2: + .loc 1 6 5 + { // callseq 1, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd1; + .param .b32 param1; + st.param.b32 [param1+0], 0; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 1 + +$L__BB0_3: + .loc 1 8 1 + ret; + +})"; } } // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc b/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc deleted file mode 100644 index 7279b3f24f8367..00000000000000 --- a/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc +++ /dev/null @@ -1,129 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "third_party/gpus/cuda/include/cuda.h" - -namespace stream_executor { -namespace cuda { -namespace { - -// In all kernels defined below we set conditional handle value to `1` when we -// want to execute a CUDA graph tied to it, and to `0` otherwise. For loops, the -// graph will keep being executed until the conditional handle becomes `0`. - -#if defined(STREAM_EXECUTOR_CUDA_ENABLE_GRAPH_CONDITIONAL) && \ - CUDA_VERSION >= 12030 - -__global__ void SetIfCondition(cudaGraphConditionalHandle then_handle, - bool* predicate) { - if (*predicate) { - cudaGraphSetConditional(then_handle, 1); - } else { - cudaGraphSetConditional(then_handle, 0); - } -} - -__global__ void SetIfElseCondition(cudaGraphConditionalHandle then_handle, - cudaGraphConditionalHandle else_handle, - bool* predicate) { - if (*predicate) { - cudaGraphSetConditional(then_handle, 1); - cudaGraphSetConditional(else_handle, 0); - } else { - cudaGraphSetConditional(then_handle, 0); - cudaGraphSetConditional(else_handle, 1); - } -} - -__global__ void SetCaseCondition( - cudaGraphConditionalHandle h0, cudaGraphConditionalHandle h1, - cudaGraphConditionalHandle h2, cudaGraphConditionalHandle h3, - cudaGraphConditionalHandle h4, cudaGraphConditionalHandle h5, - cudaGraphConditionalHandle h6, cudaGraphConditionalHandle h7, - int32_t* index, int32_t num_handles) { - // Only handles in [0, num_handles) range are valid. - // - // We can't define a device function with dynamic number of handle arguments, - // so we always pass 8 handles, but only some of them are valid. Size 8 picked - // as a reasonable (but random) upper bound for what we see in XLA uses. - std::array handles = {h0, h1, h2, h3, - h4, h5, h6, h7}; - - // If branch index is out of range activate the last valid handle. - int32_t branch_index = *index; - if (branch_index < 0 || branch_index >= num_handles) { - branch_index = num_handles - 1; - } - - for (int32_t i = 0; i < num_handles; ++i) { - if (branch_index == i) { - cudaGraphSetConditional(handles[i], 1); - } else { - cudaGraphSetConditional(handles[i], 0); - } - } -} - -__global__ void SetForCondition(cudaGraphConditionalHandle handle, - int32_t* loop_index, int32_t num_iterations) { - if (*loop_index < num_iterations) { - cudaGraphSetConditional(handle, 1); - } else { - cudaGraphSetConditional(handle, 0); - } - *loop_index += 1; -} - -#else // CUDA graph conditionals are not available - -__global__ void SetIfCondition() {} -__global__ void SetIfElseCondition() {} -__global__ void SetCaseCondition() {} -__global__ void SetForCondition() {} - -#endif - -} // namespace -} // namespace cuda - -namespace gpu { - -void* GetSetIfConditionKernel() { - return reinterpret_cast(&cuda::SetIfCondition); -} - -void* GetSetIfElseConditionKernel() { - return reinterpret_cast(&cuda::SetIfElseCondition); -} - -void* GetSetCaseConditionKernel() { - return reinterpret_cast(&cuda::SetCaseCondition); -} - -void* GetSetForConditionKernel() { - return reinterpret_cast(&cuda::SetForCondition); -} - -void* GetSetWhileConditionKernel() { - // While condition kernel is the same as an `If` with a single branch. - return reinterpret_cast(&cuda::SetIfCondition); -} - -} // namespace gpu - -} // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_diagnostics.cc b/third_party/xla/xla/stream_executor/cuda/cuda_diagnostics.cc index 99f687e3c2b03b..561ac0d401e2f2 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_diagnostics.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_diagnostics.cc @@ -24,15 +24,16 @@ limitations under the License. #include #include #include + #if !defined(PLATFORM_WINDOWS) #include #include #include #endif + #include -#include -#include +#include #include #include "absl/container/inlined_vector.h" diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_diagnostics.h b/third_party/xla/xla/stream_executor/cuda/cuda_diagnostics.h index 9bffa93deccc2b..ea1fa0cfc51a7c 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_diagnostics.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_diagnostics.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_DIAGNOSTICS_H_ #define XLA_STREAM_EXECUTOR_CUDA_CUDA_DIAGNOSTICS_H_ +#include + +#include "absl/status/statusor.h" #include "xla/stream_executor/gpu/gpu_diagnostics.h" namespace stream_executor { diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc index 01098d5a56088a..ab5240e6752f20 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc @@ -17,34 +17,49 @@ limitations under the License. #include #include +#include #include #include #include #include #include #include +#include #include #include #include #include #include +#include "absl/algorithm/container.h" #include "absl/base/optimization.h" #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/memory/memory.h" +#include "absl/container/inlined_vector.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "absl/types/span.h" #include "Eigen/Core" // from @eigen_archive +#include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "third_party/gpus/cuda/include/driver_types.h" #include "xla/stream_executor/cuda/cuda_activation.h" #include "xla/stream_executor/cuda/cuda_diagnostics.h" -#include "xla/stream_executor/cuda/cuda_driver.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" -#include "xla/stream_executor/cuda/cuda_stream.h" +#include "xla/stream_executor/data_type.h" +#include "xla/stream_executor/device_description.h" #include "xla/stream_executor/dnn.h" +#include "xla/stream_executor/gpu/gpu_activation.h" +#include "xla/stream_executor/gpu/gpu_diagnostics.h" +#include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_executor.h" +#include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/gpu/gpu_timer.h" #include "xla/stream_executor/numeric_options.h" #include "xla/stream_executor/platform/initialize.h" @@ -55,19 +70,47 @@ limitations under the License. #include "xla/stream_executor/stream_executor_internal.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" +#include "tsl/platform/status.h" #include "tsl/platform/statusor.h" #include "tsl/platform/tensor_float_32_utils.h" +#include "tsl/protobuf/dnn.pb.h" #include "tsl/util/env_var.h" // clang-format off #include "third_party/gpus/cuda/include/library_types.h" -#include "third_party/gpus/cudnn/cudnn.h" #include "third_party/gpus/cudnn/cudnn_version.h" -#if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND + +#if CUDNN_VERSION >= 9000 +#include "third_party/gpus/cudnn/cudnn_adv.h" +#include "third_party/gpus/cudnn/cudnn_cnn.h" +#include "third_party/gpus/cudnn/cudnn_ops.h" +#elif CUDNN_VERSION >= 8100 +#include "third_party/gpus/cudnn/cudnn_adv_infer.h" +#include "third_party/gpus/cudnn/cudnn_adv_train.h" +#include "third_party/gpus/cudnn/cudnn_cnn_infer.h" +#include "third_party/gpus/cudnn/cudnn_cnn_train.h" +#include "third_party/gpus/cudnn/cudnn_ops_infer.h" +#include "third_party/gpus/cudnn/cudnn_ops_train.h" +#endif + +#include "third_party/gpus/cudnn/cudnn_backend.h" + +#if CUDNN_VERSION >= 8100 #include "third_party/cudnn_frontend/include/cudnn_frontend.h" #include "third_party/cudnn_frontend/include/cudnn_frontend_utils.h" -#endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND -#include "absl/strings/string_view.h" +#include "third_party/cudnn_frontend/include/cudnn_frontend_EngineConfig.h" +#include "third_party/cudnn_frontend/include/cudnn_frontend_Errata.h" +#include "third_party/cudnn_frontend/include/cudnn_frontend_ExecutionPlan.h" +#include "third_party/cudnn_frontend/include/cudnn_frontend_Filters.h" +#include "third_party/cudnn_frontend/include/cudnn_frontend_Heuristics.h" +#include "third_party/cudnn_frontend/include/cudnn_frontend_MatMulDesc.h" +#include "third_party/cudnn_frontend/include/cudnn_frontend_Operation.h" +#include "third_party/cudnn_frontend/include/cudnn_frontend_OperationGraph.h" +#include "third_party/cudnn_frontend/include/cudnn_frontend_PointWiseDesc.h" +#include "third_party/cudnn_frontend/include/cudnn_frontend_Rng.h" +#include "third_party/cudnn_frontend/include/cudnn_frontend_Tensor.h" +#include "third_party/cudnn_frontend/include/cudnn_frontend_VariantPack.h" +#endif // CUDNN_VERSION >= 8100 // clang-format on #ifdef __clang__ @@ -85,6 +128,25 @@ namespace { static_assert(CUDNN_VERSION >= 7300, "cuDNN needs to be version 7.3 or higher"); +// If 'expr' returns an error, then this returns from the current +// function with a non-successful absl::Status. +#define RETURN_IF_CUDNN_FRONTEND_ERROR(expr) \ + do { \ + if (ABSL_PREDICT_TRUE((expr).is_bad())) { \ + std::ostringstream oss; \ + oss << (expr).get_message() << "\nin " << __FILE__ << "(" << __LINE__ \ + << "): '" << #expr << "' "; \ + return absl::UnknownError(oss.str()); \ + } \ + } while (false) + +#define RETURN_FALSE_IF_CUDNN_FRONTEND_ERROR(expr) \ + do { \ + if (ABSL_PREDICT_TRUE((expr).is_bad())) { \ + return false; \ + } \ + } while (false) + // Exits the program if 'expr' doesn't return CUDNN_STATUS_SUCCESS. #define CHECK_CUDNN_OK(expr) CHECK_EQ(expr, CUDNN_STATUS_SUCCESS) @@ -373,16 +435,13 @@ void PreloadCudnnSubLibs(PreloadCudnnType type) { [[clang::fallthrough]]; } case PreloadCudnnType::ConvFwd: { -#if CUDNN >= 9000 && TF_ENABLE_CUDNN_FRONTEND +#if CUDNN_VERSION >= 9000 cudnnGraphVersionCheck(); cudnnOpsVersionCheck(); -#elif CUDNN_VERSION >= 9000 - cudnnCnnVersionCheck(); - cudnnOpsVersionCheck(); #elif CUDNN_VERSION >= 8004 cudnnOpsInferVersionCheck(); cudnnCnnInferVersionCheck(); -#endif // CUDNN >= 9000 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 9000 break; } case PreloadCudnnType::Rnn: { @@ -769,7 +828,7 @@ class CudnnFilterDescriptor { FilterDescriptor handle_; // Owned. }; -#if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#if CUDNN_VERSION >= 8100 // The errata sheet (JSON format) for marking the cudnn engines that might be // buggy. For example, we don't want the engine 999 of forward convolution: // R"({ "version" : 1, @@ -865,7 +924,7 @@ const json* CudnnExecutionPlanEngineFilterRuntime() { return json_handle; } -#endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 8100 // A helper function to decide whether to use // CUDNN_BATCHNORM_SPATIAL_PERSISTENT in batchnorm. This mode can be faster in @@ -1145,56 +1204,79 @@ class CudnnActivationDescriptor { ActivationDescriptor handle_; // Owned. }; -cudnnDataType_t ToCudnnDataType( +cudnn_frontend::DataType_t ToCudnnFrontendDataType( dnn::DataType data_type, dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) { switch (data_type) { case dnn::DataType::kFloat: - return CUDNN_DATA_FLOAT; + return cudnn_frontend::DataType_t::FLOAT; case dnn::DataType::kDouble: - return CUDNN_DATA_DOUBLE; + return cudnn_frontend::DataType_t::DOUBLE; case dnn::DataType::kHalf: - return CUDNN_DATA_HALF; + return cudnn_frontend::DataType_t::HALF; case dnn::DataType::kInt8: switch (data_layout) { case dnn::DataLayout::kBatchDepthYX4: - return CUDNN_DATA_INT8x4; + return cudnn_frontend::DataType_t::INT8x4; case dnn::DataLayout::kBatchDepthYX32: - return CUDNN_DATA_INT8x32; + return cudnn_frontend::DataType_t::INT8x32; default: - return CUDNN_DATA_INT8; + return cudnn_frontend::DataType_t::INT8; } case dnn::DataType::kInt32: - return CUDNN_DATA_INT32; + return cudnn_frontend::DataType_t::INT32; case dnn::DataType::kInt64: - return CUDNN_DATA_INT64; + return cudnn_frontend::DataType_t::INT64; #if CUDNN_VERSION >= 8200 case dnn::DataType::kBF16: - return CUDNN_DATA_BFLOAT16; + return cudnn_frontend::DataType_t::BFLOAT16; #endif #if CUDNN_VERSION >= 8900 case dnn::DataType::kF8E4M3FN: - return CUDNN_DATA_FP8_E4M3; + return cudnn_frontend::DataType_t::FP8_E4M3; case dnn::DataType::kF8E5M2: - return CUDNN_DATA_FP8_E5M2; + return cudnn_frontend::DataType_t::FP8_E5M2; #endif default: LOG(FATAL) << "Invalid DNN data type: " << static_cast(data_type); } } -cudnnDataType_t ToCudnnDataType(dnn::DataType data_type, - dnn::FilterLayout filter_layout) { +cudnnDataType_t ToCudnnDataType( + dnn::DataType data_type, + dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) { + cudnnDataType_t type; + CHECK_CUDNN_OK(cudnn_frontend::detail::convert_to_cudnn_type( + ToCudnnFrontendDataType(data_type, data_layout), type)); + return type; +} + +cudnn_frontend::DataType_t ToCudnnFrontendDataType( + dnn::DataType data_type, dnn::FilterLayout filter_layout) { if (data_type == dnn::DataType::kInt8 && filter_layout == dnn::FilterLayout::kOutputInputYX4) { - return CUDNN_DATA_INT8x4; + return cudnn_frontend::DataType_t::INT8x4; } if (data_type == dnn::DataType::kInt8 && (filter_layout == dnn::FilterLayout::kOutputInputYX32 || filter_layout == dnn::FilterLayout::kOutputInputYX32_CudnnReordered)) { - return CUDNN_DATA_INT8x32; + return cudnn_frontend::DataType_t::INT8x32; } - return ToCudnnDataType(data_type); + return ToCudnnFrontendDataType(data_type); +} + +cudnnDataType_t ToCudnnDataType(dnn::DataType data_type, + dnn::FilterLayout filter_layout) { + cudnnDataType_t type; + CHECK_CUDNN_OK(cudnn_frontend::detail::convert_to_cudnn_type( + ToCudnnFrontendDataType(data_type, filter_layout), type)); + return type; +} + +template +cudnn_frontend::DataType_t GetCudnnFrontendDataType( + dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) { + return ToCudnnFrontendDataType(dnn::ToDataType::value, data_layout); } template @@ -1203,6 +1285,12 @@ cudnnDataType_t GetCudnnDataType( return ToCudnnDataType(dnn::ToDataType::value, data_layout); } +template +cudnn_frontend::DataType_t GetCudnnFrontendDataType( + dnn::FilterLayout filter_layout) { + return ToCudnnFrontendDataType(dnn::ToDataType::value, filter_layout); +} + template cudnnDataType_t GetCudnnDataType(dnn::FilterLayout filter_layout) { return ToCudnnDataType(dnn::ToDataType::value, filter_layout); @@ -2033,7 +2121,7 @@ absl::Status CheckRNNParameterSize( const CudnnHandle& cudnn, const CudnnRnnDescriptor& rnn_desc, const CudnnRnnSequenceTensorDescriptor& input_desc) { size_t params_size_in_bytes = 0; -#if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#if CUDNN_VERSION >= 8100 RETURN_IF_CUDNN_ERROR(cudnnGetRNNWeightSpaceSize( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), /*sizeInBytes=*/¶ms_size_in_bytes)); @@ -2177,7 +2265,7 @@ static absl::Status PopulateProfileFromTimer( profile_result->set_scratch_size(*scratch_size); } } - return tsl::OkStatus(); + return absl::OkStatus(); } template @@ -2423,7 +2511,8 @@ absl::Status CudnnSupport::DoRnnBackwardImpl( if (params_backprop_data != nullptr) { // Clear the dw to zeros. - stream->ThenMemZero(params_backprop_data, params_backprop_data->size()); + TF_RETURN_IF_ERROR( + stream->MemZero(params_backprop_data, params_backprop_data->size())); #if CUDNN_VERSION >= 8100 RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeights_v8( /*handle=*/cudnn.handle(), @@ -2488,7 +2577,8 @@ absl::Status CudnnSupport::DoRnnBackwardImpl( if (params_backprop_data != nullptr) { // Clear the dw to zeros. - stream->ThenMemZero(params_backprop_data, params_backprop_data->size()); + TF_RETURN_IF_ERROR( + stream->MemZero(params_backprop_data, params_backprop_data->size())); // make the backward weight call RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeights( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), @@ -2555,7 +2645,7 @@ absl::Status CudnnSupport::DoCtcLossImpl( } absl::StatusOr> -CudnnSupport::createRnnDescriptor( +CudnnSupport::CreateRnnDescriptor( int num_layers, int hidden_size, int input_size, int cell_size, int batch_size, dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode, @@ -2579,7 +2669,7 @@ CudnnSupport::createRnnDescriptor( } absl::StatusOr> -CudnnSupport::createRnnSequenceTensorDescriptor(int max_seq_length, +CudnnSupport::CreateRnnSequenceTensorDescriptor(int max_seq_length, int batch_size, int data_size, dnn::DataType data_type) { TF_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor, @@ -2591,7 +2681,7 @@ CudnnSupport::createRnnSequenceTensorDescriptor(int max_seq_length, } absl::StatusOr> -CudnnSupport::createRnnSequenceTensorDescriptor( +CudnnSupport::CreateRnnSequenceTensorDescriptor( int max_seq_length, int batch_size, int data_size, const absl::Span& seq_lengths, bool time_major, dnn::DataType data_type) { @@ -2604,7 +2694,7 @@ CudnnSupport::createRnnSequenceTensorDescriptor( } absl::StatusOr> -CudnnSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size, +CudnnSupport::CreateRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size, dnn::DataType data_type) { return std::unique_ptr( @@ -3439,7 +3529,7 @@ struct RnnDoFP32ComputationFP16Input { namespace { -#if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#if CUDNN_VERSION >= 8100 bool GenericEngineFilter(cudnnBackendDescriptor_t engine_config, bool disable_winograd, bool disable_nondeterminism, @@ -3466,7 +3556,7 @@ bool GenericEngineFilter(cudnnBackendDescriptor_t engine_config, return ret; } -#endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 8100 } // namespace @@ -3541,7 +3631,7 @@ dnn::DataType GetConvAccumulatorType(dnn::DataType data_type) { } } -#if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#if CUDNN_VERSION >= 8100 namespace { static bool allowAllConfig(cudnnBackendDescriptor_t engine_config) { @@ -3610,7 +3700,7 @@ std::tuple GetTensorVectorSizeAndDim( return std::make_tuple(vector_size, vector_dim); } -#if (CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND) +#if CUDNN_VERSION >= 8800 absl::StatusOr CreateCudnnTensor( absl::Span dims, absl::Span strides, int64_t uid, dnn::DataType dtype, int64_t vec_count, int64_t vec_dim, @@ -3666,7 +3756,7 @@ absl::StatusOr CreateCudnnTensor( absl::StatusOr CreateCudnnTensor( const cudnn_frontend::Tensor& original, int64_t uid, dnn::DataType dtype, bool is_virtual = false) { -#if (CUDNN_VERSION >= 8900 && TF_ENABLE_CUDNN_FRONTEND) +#if CUDNN_VERSION >= 8900 auto tensor = cudnn_frontend::TensorBuilder() .cloneFrom(original, uid) .setAlignment(32) @@ -3677,10 +3767,10 @@ absl::StatusOr CreateCudnnTensor( return tensor; #else return tsl::errors::Internal("Not implemented."); -#endif // CUDNN_VERSION >= 8900 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 8900 } -#if (CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND) +#if CUDNN_VERSION >= 8800 enum CudnnfMHAUid { Q_ID = 400, K_ID, @@ -4107,7 +4197,7 @@ absl::StatusOr CreateCudnnDropoutFwdTensor( return dropout_scale_out_tensor; } -#endif // CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 8800 absl::StatusOr> GetCudnnOperationGraph(dnn::ConvolutionKind kind, dnn::DataType input_type, @@ -4154,7 +4244,7 @@ GetCudnnOperationGraph(dnn::ConvolutionKind kind, dnn::DataType input_type, std::vector filter_strides = filter_descriptor.vectorized_strides( dnn::FilterLayout::kOutputInputYX, vector_size, vector_dim); -#if (CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND) +#if CUDNN_VERSION >= 8800 cudnnBackendTensorReordering_t tensor_ordering_type = filter_descriptor.layout() == dnn::FilterLayout::kOutputInputYX32_CudnnReordered @@ -4166,7 +4256,7 @@ GetCudnnOperationGraph(dnn::ConvolutionKind kind, dnn::DataType input_type, dnn::FilterLayout::kOutputInputYX32_CudnnReordered; #endif -#if (CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND) +#if CUDNN_VERSION >= 8800 TF_ASSIGN_OR_RETURN( auto tensor_w, CreateCudnnTensor(filter_dims, filter_strides, 'w', input_type, @@ -4316,7 +4406,7 @@ class OpGraph { } it->is_virtual = true; } - return tsl::OkStatus(); + return absl::OkStatus(); } absl::StatusOr FindOpDescriptor(int uid) const { @@ -4342,7 +4432,7 @@ class OpGraph { return tsl::errors::Internal("Unknown ID."); } it->sequence_index = index; - return tsl::OkStatus(); + return absl::OkStatus(); } bool Empty() const { return ops_.empty(); } @@ -4752,7 +4842,7 @@ GetCudnnFusedOperationGraph( std::vector filter_strides = filter_descriptor.vectorized_strides( dnn::FilterLayout::kOutputInputYX, vector_size, vector_dim); -#if (CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND) +#if CUDNN_VERSION >= 8800 cudnnBackendTensorReordering_t tensor_ordering_type = filter_descriptor.layout() == dnn::FilterLayout::kOutputInputYX32_CudnnReordered @@ -4764,7 +4854,7 @@ GetCudnnFusedOperationGraph( dnn::FilterLayout::kOutputInputYX32_CudnnReordered; #endif -#if (CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND) +#if CUDNN_VERSION >= 8800 TF_ASSIGN_OR_RETURN( auto tensor_w, CreateCudnnTensor(filter_dims, filter_strides, 'w', input_type, @@ -4815,7 +4905,7 @@ GetCudnnFusedOperationGraph( auto maybe_tensor_b = CreateCudnnTensor(bias_dims, bias_strides, 'b', bias_type, vector_size, vector_dim, /*is_virtual=*/false, -#if (CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND) +#if CUDNN_VERSION >= 8800 tensor_ordering_type #else is_reordered_nchw_vect @@ -5152,7 +5242,7 @@ GetCudnnFusedMatmulGraph(dnn::DataType input_type, dnn::DataType bias_type, return std::make_unique(std::move(op_graph)); } -#if (CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND) +#if CUDNN_VERSION >= 8800 absl::StatusOr> GetCudnnFusedMHAOperationGraph( const dnn::MatmulTensorDescriptor& bmm1_lhs_descriptor, @@ -5681,7 +5771,7 @@ absl::StatusOr CreateCudnnMaskBwdTensor( return dummy_mask_out_tensor; } } -#if (CUDNN_VERSION >= 8901 && TF_ENABLE_CUDNN_FRONTEND) +#if CUDNN_VERSION >= 8901 absl::StatusOr CreateCudnnBiasBwdTensor( std::vector& ops, absl::Span dims, absl::Span strides, dnn::DataType dtype, @@ -5757,7 +5847,7 @@ absl::StatusOr CreateCudnnBiasBwdTensor( return dbias_tensor; } -#endif // (CUDNN_VERSION >= 8901 && TF_ENABLE_CUDNN_FRONTEND) +#endif // CUDNN_VERSION >= 8901 absl::StatusOr> GetCudnnFusedMHABackwardOperationGraph( const dnn::MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor, @@ -6050,7 +6140,7 @@ GetCudnnFusedMHABackwardOperationGraph( // bias backward if (use_bias) { -#if (CUDNN_VERSION >= 8901 && TF_ENABLE_CUDNN_FRONTEND) +#if CUDNN_VERSION >= 8901 TF_ASSIGN_OR_RETURN( auto tensor_dbias, CreateCudnnBiasBwdTensor(intermediate_ops, p_dims, p_strides, dtype, @@ -7453,7 +7543,7 @@ GetCudnnFlashAttentionBackwardOperationGraph( return std::make_unique(std::move(op_graph)); } -#endif // CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 8800 } // namespace @@ -7556,7 +7646,7 @@ static absl::StatusOr RebuildExecutionPlan( return {std::move(plan)}; } -#endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 8100 } // namespace @@ -7939,7 +8029,7 @@ class ScalingParam { dnn::DataType default_target_dtype_; }; -#if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#if CUDNN_VERSION >= 8100 struct BackendDescriptorDeleter { void operator()(cudnnBackendDescriptor_t desc) { cudnnBackendDestroyDescriptor(desc); @@ -8159,7 +8249,7 @@ class CudnnExecutionPlanRunner } } if (offset_increment_ > 0) { -#if (CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND) +#if CUDNN_VERSION >= 8800 initial_offset_ += offset_increment_; data_uids_vec.push_back(CudnnfMHAUid::D_SEED_ID); data_uids_vec.push_back(CudnnfMHAUid::D_OFFSET_ID); @@ -8177,7 +8267,7 @@ class CudnnExecutionPlanRunner return absl::UnimplementedError( "Cudnn dropout offset and seed are only supported with Cudnn >= " "8.8."); -#endif // CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 8800 } auto variantPack = cudnn_frontend::VariantPackBuilder() @@ -8201,7 +8291,7 @@ class CudnnExecutionPlanRunner // should memset dq_accum because it is being atomic added std::vector dev_mem{inputs...}; DeviceMemoryBase* dev_dq_accum = &(dev_mem[10]); - stream->ThenMemZero(dev_dq_accum, dev_dq_accum->size()); + TF_RETURN_IF_ERROR(stream->MemZero(dev_dq_accum, dev_dq_accum->size())); } } @@ -8295,9 +8385,9 @@ class CudnnExecutionPlanRunner int64_t rng_seed_; bool is_flash_attention_; }; -#endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 8100 -#if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#if CUDNN_VERSION >= 8100 namespace { template @@ -8406,11 +8496,11 @@ absl::Status CreateOpRunners( VLOG(4) << "\nReturned execution plans size: " << out_runners->size(); - return tsl::OkStatus(); + return absl::OkStatus(); } } // namespace -#endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 8100 absl::Status CudnnSupport::GetConvolveRunners( bool use_cudnn_frontend, dnn::ConvolutionKind kind, @@ -8503,7 +8593,7 @@ absl::Status CudnnSupport::GetConvolveRunners( return absl::OkStatus(); } -#if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#if CUDNN_VERSION >= 8100 auto cudnn = cudnn_->GetHandle(parent_, stream); TF_ASSIGN_OR_RETURN( auto op_graph, @@ -8518,7 +8608,7 @@ absl::Status CudnnSupport::GetConvolveRunners( #else return tsl::errors::Unimplemented( "Cudnn execution plans are only supported with Cudnn >= 8.1."); -#endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 8100 } absl::Status CudnnSupport::GetGraphConvolveRunners( @@ -8585,7 +8675,7 @@ CudnnSupport::ConvolveRunnerFromDesc( return {std::make_unique(std::move(runner))}; } -#if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#if CUDNN_VERSION >= 8100 auto cudnn = cudnn_->GetHandle(parent_, stream); TF_ASSIGN_OR_RETURN( @@ -8624,7 +8714,7 @@ CudnnSupport::GraphConvolveRunnerFromDesc( "cuDNN graph execution requires the use of the cuDNN frontend."); } -#if CUDNN_VERSION >= 8900 && TF_ENABLE_CUDNN_FRONTEND +#if CUDNN_VERSION >= 8900 auto cudnn = cudnn_->GetHandle(parent_, stream); TF_ASSIGN_OR_RETURN( @@ -8880,7 +8970,7 @@ CudnnSupport::FusedConvolveRunnerFromDesc( return {std::make_unique(std::move(runner))}; } -#if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#if CUDNN_VERSION >= 8100 auto cudnn = cudnn_->GetHandle(parent_, stream); TF_ASSIGN_OR_RETURN(auto op_graph, @@ -9023,7 +9113,7 @@ absl::Status CudnnSupport::GetFusedConvolveRunners( return absl::OkStatus(); } -#if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#if CUDNN_VERSION >= 8100 auto cudnn = cudnn_->GetHandle(parent_, stream); auto op_graph_status = GetCudnnFusedOperationGraph( kind, input_type, bias_type, output_type, conv_scale, side_input_scale, @@ -9044,7 +9134,7 @@ absl::Status CudnnSupport::GetFusedConvolveRunners( #else return tsl::errors::Unimplemented( "Cudnn execution plans are only supported with Cudnn >= 8.1."); -#endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 8100 } absl::Status CudnnSupport::GetFusedMatmulRunners( @@ -9055,7 +9145,7 @@ absl::Status CudnnSupport::GetFusedMatmulRunners( const NumericOptions& numeric_options, std::vector>* out_exec_plans) { -#if CUDNN_VERSION >= 8400 && TF_ENABLE_CUDNN_FRONTEND +#if CUDNN_VERSION >= 8400 if (!use_cudnn_frontend) { return tsl::errors::Unimplemented( "Cudnn execution plans for matmul are only supported with cudnn " @@ -9083,7 +9173,7 @@ absl::Status CudnnSupport::GetFusedMatmulRunners( #else return tsl::errors::Unimplemented( "Cudnn execution plans for matmul are only supported with Cudnn >= 8.4."); -#endif // CUDNN_VERSION >= 8400 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 8400 } bool CudnnSupport::GetConvolveAlgorithms( @@ -9128,14 +9218,18 @@ bool CudnnSupport::GetConvolveAlgorithms( absl::StatusOr> CudnnSupport::NormRunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, double epsilon, - const dnn::TensorDescriptor& input_descriptor, + Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, + dnn::NormKind kind, double epsilon, + const dnn::TensorDescriptor& x_descriptor, const dnn::TensorDescriptor& scale_descriptor, - const dnn::TensorDescriptor& bias_descriptor, - const dnn::TensorDescriptor& output_descriptor, + const dnn::TensorDescriptor& y_or_dx_descriptor, + std::optional bias_descriptor, + std::optional dy_descriptor, std::optional expectation_descriptor, - std::optional norm_factor_descriptor) { -#if (CUDNN_VERSION >= 8905 && TF_ENABLE_CUDNN_FRONTEND) + std::optional norm_factor_descriptor, + std::optional dscale_descriptor, + std::optional dbias_descriptor) { +#if (CUDNN_VERSION >= 8905) auto cudnn = cudnn_->GetHandle(parent_, stream); std::vector uids; @@ -9146,45 +9240,48 @@ CudnnSupport::NormRunnerFromDesc( return uids.emplace_back(uids.back() + 1); }; - TF_ASSIGN_OR_RETURN( - auto xTensor, - CreateCudnnTensor(input_descriptor.dimensions(), - input_descriptor.GetPhysicalStridesMajorToMinor(), - next_uid(), input_descriptor.type(), 1, -1)); - TF_ASSIGN_OR_RETURN( - auto scaleTensor, - CreateCudnnTensor(scale_descriptor.dimensions(), - scale_descriptor.GetPhysicalStridesMajorToMinor(), - next_uid(), scale_descriptor.type(), 1, -1)); - TF_ASSIGN_OR_RETURN( - auto biasTensor, - CreateCudnnTensor(bias_descriptor.dimensions(), - bias_descriptor.GetPhysicalStridesMajorToMinor(), - next_uid(), bias_descriptor.type(), 1, -1)); - TF_ASSIGN_OR_RETURN( - auto yTensor, - CreateCudnnTensor(output_descriptor.dimensions(), - output_descriptor.GetPhysicalStridesMajorToMinor(), - next_uid(), output_descriptor.type(), 1, -1)); - std::optional expectation_tensor, norm_factor_tensor; - if (expectation_descriptor) { - TF_ASSIGN_OR_RETURN( - expectation_tensor, - CreateCudnnTensor( - expectation_descriptor->dimensions(), - expectation_descriptor->GetPhysicalStridesMajorToMinor(), - next_uid(), expectation_descriptor->type(), 1, -1)); - TF_ASSIGN_OR_RETURN( - norm_factor_tensor, - CreateCudnnTensor( - norm_factor_descriptor->dimensions(), - norm_factor_descriptor->GetPhysicalStridesMajorToMinor(), - next_uid(), norm_factor_descriptor->type(), 1, -1)); + auto create_cudnn_tensor = [next_uid](dnn::TensorDescriptor tensor_descriptor) + -> tsl::StatusOr { + return CreateCudnnTensor(tensor_descriptor.dimensions(), + tensor_descriptor.GetPhysicalStridesMajorToMinor(), + next_uid(), tensor_descriptor.type(), 1, -1); + }; + + TF_ASSIGN_OR_RETURN(auto x_tensor, create_cudnn_tensor(x_descriptor)); + TF_ASSIGN_OR_RETURN(auto scale_tensor, create_cudnn_tensor(scale_descriptor)); + TF_ASSIGN_OR_RETURN(auto y_or_dx_tensor, + create_cudnn_tensor(y_or_dx_descriptor)); + + std::optional bias_tensor, expectation_tensor, + norm_factor_tensor, dy_tensor, dscale_tensor, dbias_tensor; + if (kind == dnn::NormKind::LAYER_FWD_INFER || + kind == dnn::NormKind::LAYER_FWD_TRAIN) { + TF_ASSIGN_OR_RETURN(bias_tensor, + create_cudnn_tensor(bias_descriptor.value())); + } + + if (kind == dnn::LAYER_FWD_TRAIN) { + TF_ASSIGN_OR_RETURN(expectation_tensor, + create_cudnn_tensor(expectation_descriptor.value())); + TF_ASSIGN_OR_RETURN(norm_factor_tensor, + create_cudnn_tensor(norm_factor_descriptor.value())); + } + + if (kind == dnn::LAYER_BWD) { + TF_ASSIGN_OR_RETURN(dy_tensor, create_cudnn_tensor(dy_descriptor.value())); + TF_ASSIGN_OR_RETURN(expectation_tensor, + create_cudnn_tensor(expectation_descriptor.value())); + TF_ASSIGN_OR_RETURN(norm_factor_tensor, + create_cudnn_tensor(norm_factor_descriptor.value())); + TF_ASSIGN_OR_RETURN(dscale_tensor, + create_cudnn_tensor(dscale_descriptor.value())); + TF_ASSIGN_OR_RETURN(dbias_tensor, + create_cudnn_tensor(dbias_descriptor.value())); } std::vector scale_dim(4, 1), scalar_uids; TF_ASSIGN_OR_RETURN( - auto epsilonTensor, + auto epsilon_tensor, CreateCudnnTensor(scale_dim, scale_dim, scalar_uids.emplace_back(uids.back() + 1), dnn::DataType::kDouble, 1, -1, /*is_virtual=*/false, @@ -9194,30 +9291,47 @@ CudnnSupport::NormRunnerFromDesc( cudnnBackendNormMode_t normalizationMode = CUDNN_LAYER_NORM; std::optional norm_op; - if (!expectation_descriptor) { - cudnnBackendNormFwdPhase_t phase = CUDNN_NORM_FWD_INFERENCE; - norm_op = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR) - .setNormalizationMode(normalizationMode) - .setNormFwdPhase(phase) - .setxDesc(xTensor) - .setScaleAndBias(scaleTensor, biasTensor) - .setEpsilonTensor(epsilonTensor) - .setyDesc(yTensor) - .build(); - } else { - cudnnBackendNormFwdPhase_t phase = CUDNN_NORM_FWD_TRAINING; - norm_op = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR) - .setNormalizationMode(normalizationMode) - .setNormFwdPhase(phase) - .setxDesc(xTensor) - .setScaleAndBias(scaleTensor, biasTensor) - .setEpsilonTensor(epsilonTensor) - .setSavedMeanAndInvVar(expectation_tensor.value(), - norm_factor_tensor.value()) - .setyDesc(yTensor) - .build(); + switch (kind) { + case dnn::LAYER_FWD_INFER: + norm_op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR) + .setNormalizationMode(normalizationMode) + .setNormFwdPhase(CUDNN_NORM_FWD_INFERENCE) + .setxDesc(x_tensor) + .setScaleAndBias(scale_tensor, bias_tensor.value()) + .setEpsilonTensor(epsilon_tensor) + .setyDesc(y_or_dx_tensor) + .build(); + break; + case dnn::LAYER_FWD_TRAIN: + norm_op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR) + .setNormalizationMode(normalizationMode) + .setNormFwdPhase(CUDNN_NORM_FWD_TRAINING) + .setxDesc(x_tensor) + .setScaleAndBias(scale_tensor, bias_tensor.value()) + .setEpsilonTensor(epsilon_tensor) + .setSavedMeanAndInvVar(expectation_tensor.value(), + norm_factor_tensor.value()) + .setyDesc(y_or_dx_tensor) + .build(); + break; + case dnn::LAYER_BWD: + norm_op = + cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_NORM_BACKWARD_DESCRIPTOR) + .setNormalizationMode(normalizationMode) + .setxDesc(x_tensor) + .setScale(scale_tensor) + .setSavedMeanAndInvVar(expectation_tensor.value(), + norm_factor_tensor.value()) + .setDScaleAndDBias(dscale_tensor.value(), dbias_tensor.value()) + .setdyDesc(dy_tensor.value()) + .setdxDesc(y_or_dx_tensor) + .build(); + break; + default: + break; } std::array ops = {&norm_op.value()}; @@ -9246,7 +9360,7 @@ CudnnSupport::NormRunnerFromDesc( #else return absl::UnimplementedError( "Layer norm kernels require cuDNN 8.9.5 or higher."); -#endif // CUDNN_VERSION >= 8905 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 8905 } // Returns the offset to increment for the dropout rng. @@ -9276,7 +9390,7 @@ CudnnSupport::FusedMHARunnerFromDesc( std::optional bias_descriptor, double scale, std::optional dropout_rate, std::optional seed, bool is_flash_attention, bool is_causal_mask) { -#if (CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND) +#if CUDNN_VERSION >= 8800 auto cudnn = cudnn_->GetHandle(parent_, stream); bool use_dropout = dropout_rate && *dropout_rate > 0.0; std::vector intermediate_shape; @@ -9370,7 +9484,7 @@ CudnnSupport::FusedMHARunnerFromDesc( #else return absl::UnimplementedError( "Cudnn execution plans are only supported with Cudnn >= 8.8."); -#endif // CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 8800 } absl::StatusOr> @@ -9392,7 +9506,7 @@ CudnnSupport::FusedMHABackwardRunnerFromDesc( std::optional bias_descriptor, double scale, std::optional dropout_rate, std::optional seed, bool is_flash_attention, bool is_causal_mask) { -#if (CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND) +#if CUDNN_VERSION >= 8800 auto cudnn = cudnn_->GetHandle(parent_, stream); bool use_dropout = dropout_rate && *dropout_rate > 0.0; @@ -9511,7 +9625,7 @@ CudnnSupport::FusedMHABackwardRunnerFromDesc( return absl::UnimplementedError( "Cudnn execution plans with dbias calculation in bwd are only " "supported with Cudnn >= 8.8."); -#endif // CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 8800 } bool CudnnSupport::GetRnnAlgorithms( @@ -9768,8 +9882,8 @@ absl::Status CudnnSupport::DoBatchNormalizationForwardImpl( void* batch_var_opaque; if (!batch_mean->is_null() && !batch_var->is_null()) { if (exponential_average_factor == 1.0) { - stream->ThenMemZero(batch_mean, batch_mean->size()); - stream->ThenMemZero(batch_var, batch_var->size()); + TF_RETURN_IF_ERROR(stream->MemZero(batch_mean, batch_mean->size())); + TF_RETURN_IF_ERROR(stream->MemZero(batch_var, batch_var->size())); } batch_mean_opaque = batch_mean->opaque(); batch_var_opaque = batch_var->opaque(); @@ -10108,7 +10222,7 @@ absl::Status CudnnSupport::CudnnReorderConvolutionFilterAndBias( /*reorderedBiasData=*/has_bias ? bias_output->opaque() : nullptr); RETURN_IF_CUDNN_ERROR(status); - return tsl::OkStatus(); + return absl::OkStatus(); } absl::Status CudnnSupport::DoPrepareForCtcLoss( @@ -10524,64 +10638,6 @@ bool CudnnSupport::DoNormalizeBackwardWithDimensions( return IsStatusOk(status, /*report_error=*/true); } -bool CudnnSupport::DoDepthConcatenate(Stream* stream, - BatchDescriptorSlice input_dimensions, - DeviceMemorySlice input_data, - DeviceMemory* output_data) { - CHECK_EQ(input_dimensions.size(), input_data.size()); - - for (const auto& dimensions : input_dimensions) { - if (dimensions.layout() != dnn::DataLayout::kBatchDepthYX) { - LOG(ERROR) << "CudnnSupport::DoDepthConcatenate currently only " - "supports the kBatchDepthYX layout."; - return false; - } - } - - if (input_dimensions.empty()) { - return true; // Nothing to do. - } - - dnn::BatchDescriptor output_dimensions = - dnn::BatchDescriptor::DepthConcatenateOutputDescriptor(input_dimensions); - - const int64_t area = output_dimensions.width() * output_dimensions.height(); - const auto index = [area](int64_t batch, int64_t depth, int64_t yx, - int64_t max_depth) { - return (batch * max_depth + depth) * area + yx; - }; - - std::vector output_host(output_dimensions.ElementCount()); - std::vector tmp; - int64_t depth_sum = 0; - for (size_t i = 0; i < input_data.size(); ++i) { - const auto& dimensions = input_dimensions[i]; - tmp.resize(dimensions.ElementCount()); - stream->ThenMemcpyD2H(*input_data[i], absl::MakeSpan(tmp)); - absl::Status block_status = stream->BlockHostUntilDone(); - if (!block_status.ok()) { - LOG(ERROR) << "BlockHostUntilDone failed: " << block_status; - return false; - } - - for (int64_t batch = 0; batch < output_dimensions.count(); ++batch) { - for (int64_t yx = 0; yx < area; ++yx) { - for (int64_t depth = 0; depth < dimensions.feature_map_count(); - ++depth) { - LOG(INFO) << output_dimensions.ElementCount() << ' ' << batch << ' ' - << yx << ' ' << depth; - output_host[index(batch, depth + depth_sum, yx, - output_dimensions.feature_map_count())] = - tmp[index(batch, depth, yx, dimensions.feature_map_count())]; - } - } - } - depth_sum += dimensions.feature_map_count(); - } - stream->ThenMemcpyH2D(output_host, output_data); - return true; -} - bool CudnnSupport::DeriveOutputBatchDescriptor( const dnn::BatchDescriptor& batch_descriptor, const dnn::FilterDescriptor& filter_descriptor, @@ -10644,5 +10700,6 @@ void initialize_cudnn() { #pragma clang diagnostic pop #endif -REGISTER_MODULE_INITIALIZER(register_cudnn, - { stream_executor::initialize_cudnn(); }); +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(register_cudnn, { + stream_executor::initialize_cudnn(); +}); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h index 0e11868b5d64ae..fb7999e5e0cce0 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h @@ -19,18 +19,21 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_ #define XLA_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_ +#include #include +#include #include #include #include -#include "absl/base/thread_annotations.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/types/span.h" -#include "xla/stream_executor/cuda/cuda_activation.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/dnn.h" -#include "xla/stream_executor/plugin_registry.h" -#include "xla/stream_executor/temporary_device_memory.h" +#include "xla/stream_executor/numeric_options.h" +#include "tsl/protobuf/dnn.pb.h" namespace stream_executor { namespace gpu { @@ -55,7 +58,7 @@ class CudnnSupport : public dnn::DnnSupport { absl::Status Init() override; absl::StatusOr GetVersion() override; - absl::StatusOr> createRnnDescriptor( + absl::StatusOr> CreateRnnDescriptor( int num_layers, int hidden_size, int input_size, int cell_size, int batch_size, dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode, @@ -64,19 +67,19 @@ class CudnnSupport : public dnn::DnnSupport { ScratchAllocator* state_allocator, bool use_padded_io) override; absl::StatusOr> - createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size, + CreateRnnSequenceTensorDescriptor(int max_seq_length, int batch_size, int data_size, dnn::DataType data_type) override; absl::StatusOr> - createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size, + CreateRnnSequenceTensorDescriptor(int max_seq_length, int batch_size, int data_size, const absl::Span& seq_lengths, bool time_major, dnn::DataType data_type) override; absl::StatusOr> - createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size, + CreateRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size, dnn::DataType data_type) override; bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, @@ -296,13 +299,17 @@ class CudnnSupport : public dnn::DnnSupport { dnn::ActivationMode activation_mode) override; absl::StatusOr> NormRunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, double epsilon, - const dnn::TensorDescriptor& input_descriptor, + Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, + dnn::NormKind kind, double epsilon, + const dnn::TensorDescriptor& x_descriptor, const dnn::TensorDescriptor& scale_descriptor, - const dnn::TensorDescriptor& bias_descriptor, - const dnn::TensorDescriptor& output_descriptor, + const dnn::TensorDescriptor& y_or_dx_descriptor, + std::optional bias_descriptor, + std::optional dy_descriptor, std::optional expectation_descriptor, - std::optional norm_factor_descriptor) override; + std::optional norm_factor_descriptor, + std::optional dscale_descriptor, + std::optional dbias_descriptor) override; absl::StatusOr> FusedMHARunnerFromDesc( @@ -515,10 +522,6 @@ class CudnnSupport : public dnn::DnnSupport { DeviceMemory* raw_variable_gradient, ScratchAllocator* workspace_allocator) override; - bool DoDepthConcatenate(Stream* stream, BatchDescriptorSlice input_dimensions, - DeviceMemorySlice input_data, - DeviceMemory* output_data) override; - // Derives an output batch descriptor from an input batch and convolution // descriptors. bool DeriveOutputBatchDescriptor( diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc index 36662e6cc70949..f73448ce5299ab 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc @@ -49,14 +49,12 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/platform.h" #include "tsl/platform/env.h" -#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/macros.h" #include "tsl/platform/numbers.h" #include "tsl/platform/stacktrace.h" #include "tsl/platform/threadpool.h" - static constexpr bool FLAGS_gpuexec_cuda_driver_inject_init_error = false; static constexpr bool FLAGS_gpuexec_cuda_sync_around_driver_calls = false; static constexpr bool FLAGS_gpuexec_cuda_device_0_only = false; @@ -713,6 +711,25 @@ GpuDriver::GraphNodeGetType(CUgraphNode node) { return absl::InternalError("Invalid CUDA graph node type"); } +absl::StatusOr> +GpuDriver::GraphNodeGetDependencies(GpuGraphNodeHandle node) { + VLOG(2) << "Get CUDA graph node " << node << " dependencies"; + + std::vector dependencies; + + size_t num_dependencies = 0; + RETURN_IF_CUDA_RES_ERROR( + cuGraphNodeGetDependencies(node, nullptr, &num_dependencies), + "Failed to get CUDA graph node depedencies size"); + + dependencies.resize(num_dependencies, nullptr); + RETURN_IF_CUDA_RES_ERROR( + cuGraphNodeGetDependencies(node, dependencies.data(), &num_dependencies), + "Failed to get CUDA graph node depedencies"); + + return dependencies; +} + /* static */ absl::Status GpuDriver::DestroyGraphExec(CUgraphExec exec) { VLOG(2) << "Destroying CUDA executable graph " << exec; RETURN_IF_CUDA_RES_ERROR(cuGraphExecDestroy(exec), @@ -792,7 +809,7 @@ static std::string ConditionalTypeToString( /* static */ absl::StatusOr GpuDriver::GraphAddNode(CUgraphNode* node, CUgraph graph, - absl::Span deps, + absl::Span deps, const GpuGraphNodeParams& params) { #if CUDA_VERSION >= 12030 // Add conditional node to a graph. @@ -834,7 +851,7 @@ GpuDriver::GraphAddNode(CUgraphNode* node, CUgraph graph, } /* static */ absl::Status GpuDriver::GraphAddEmptyNode( - CUgraphNode* node, CUgraph graph, absl::Span deps) { + CUgraphNode* node, CUgraph graph, absl::Span deps) { VLOG(2) << "Add empty node to a graph " << graph << "; deps: " << deps.size(); RETURN_IF_CUDA_RES_ERROR( @@ -845,7 +862,7 @@ GpuDriver::GraphAddNode(CUgraphNode* node, CUgraph graph, } /* static */ absl::Status GpuDriver::GraphAddKernelNode( - CUgraphNode* node, CUgraph graph, absl::Span deps, + CUgraphNode* node, CUgraph graph, absl::Span deps, absl::string_view kernel_name, CUfunction function, unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z, unsigned int block_dim_x, unsigned int block_dim_y, unsigned int block_dim_z, @@ -978,7 +995,7 @@ static CUmemAllocationType ToCudaAllocationType( } /*static*/ absl::Status GpuDriver::GraphAddMemAllocNode( - CUgraphNode* node, CUgraph graph, absl::Span deps, + CUgraphNode* node, CUgraph graph, absl::Span deps, GpuDriver::MemAccessFlags access_flags, GpuDriver::MemLocationType location_type, int device_id, GpuDriver::MemAllocationType allocation_type, uint64_t size, @@ -1029,7 +1046,7 @@ GpuDriver::GraphGetMemAllocNodeParams(CUgraphNode node) { } /*static*/ absl::Status GpuDriver::GraphAddMemFreeNode( - CUgraphNode* node, CUgraph graph, absl::Span deps, + CUgraphNode* node, CUgraph graph, absl::Span deps, CUdeviceptr gpu_dst) { RETURN_IF_CUDA_RES_ERROR( cuGraphAddMemFreeNode(node, graph, deps.data(), deps.size(), gpu_dst), @@ -1039,8 +1056,8 @@ GpuDriver::GraphGetMemAllocNodeParams(CUgraphNode node) { /* static */ absl::Status GpuDriver::GraphAddMemcpyD2DNode( GpuContext* context, CUgraphNode* node, CUgraph graph, - absl::Span deps, CUdeviceptr gpu_dst, CUdeviceptr gpu_src, - uint64_t size) { + absl::Span deps, CUdeviceptr gpu_dst, + CUdeviceptr gpu_src, uint64_t size) { VLOG(2) << "Add memcpy d2d node to a graph " << graph << "; dst: " << reinterpret_cast(gpu_dst) << "; src: " << reinterpret_cast(gpu_src) << "; size: " << size @@ -1125,7 +1142,7 @@ struct BitPatternToValue { /* static */ absl::Status GpuDriver::GraphAddMemsetNode( GpuContext* context, CUgraphNode* node, GpuGraphHandle graph, - absl::Span deps, CUdeviceptr dst, + absl::Span deps, CUdeviceptr dst, std::variant bit_pattern, uint64_t num_elements) { VLOG(2) << "Add memset node to a graph " << graph @@ -1184,7 +1201,7 @@ struct BitPatternToValue { } /* static */ absl::Status GpuDriver::GraphAddChildNode( - CUgraphNode* node, CUgraph graph, absl::Span deps, + CUgraphNode* node, CUgraph graph, absl::Span deps, CUgraph child) { VLOG(2) << "Create a new node by cloning the child graph " << child << " and add it to " << graph << "; deps: " << deps.size(); @@ -1379,6 +1396,7 @@ struct BitPatternToValue { "Failed to load PTX text as a module: %s", ToString(res))); } notification.Notify(); + return; } VLOG(3) << "PTX compilation info log (" << info_log_buffer_bytes diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver.h b/third_party/xla/xla/stream_executor/cuda/cuda_driver.h index a3f9e065209a0f..a72740ef4ead86 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_driver.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver.h @@ -18,10 +18,19 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_DRIVER_H_ #define XLA_STREAM_EXECUTOR_CUDA_CUDA_DRIVER_H_ +#include +#include +#include +#include +#include +#include + #include "absl/container/node_hash_map.h" -#include "absl/memory/memory.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" +#include "third_party/gpus/cuda/include/cuda.h" #include "xla/stream_executor/gpu/gpu_driver.h" namespace stream_executor { @@ -109,9 +118,8 @@ class CreatedContexts { } } - // Return the context associated to that ptr. - static CUcontext GetAnyContext(void* ptr) { - absl::ReaderMutexLock lock(&mu_); + // Find device id from cuda pointer value. + static int GetDeviceOrdinal(void* ptr) { int device_ordinal; CUresult result = cuPointerGetAttribute(static_cast(&device_ordinal), CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL, @@ -120,6 +128,13 @@ class CreatedContexts { LOG(FATAL) << "Not able to get the device_ordinal for ptr: " << ptr << ". Error: " << ToString(result); } + return device_ordinal; + } + + // Return the context associated to that ptr. + static CUcontext GetAnyContext(void* ptr) { + absl::ReaderMutexLock lock(&mu_); + int device_ordinal = GetDeviceOrdinal(ptr); CHECK_EQ(LiveOrdinal()->count(device_ordinal), 1); CHECK(!LiveOrdinal()->at(device_ordinal).empty()) << "Need at least one context."; diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_driver_test.cc index bd54dac773a65f..05dd05d1a9ff8e 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_driver_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver_test.cc @@ -13,10 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/log/log.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/driver_types.h" +#include "xla/stream_executor/gpu/gpu_driver.h" #if GOOGLE_CUDA #include "xla/stream_executor/cuda/cuda_driver.h" -#include "absl/memory/memory.h" #include "third_party/gpus/cuda/include/cuda_runtime_api.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_event.cc b/third_party/xla/xla/stream_executor/cuda/cuda_event.cc index e3fd233795d561..f42cf47e86e1ab 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_event.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_event.cc @@ -13,11 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/stream_executor/cuda/cuda_event.h" - +#include "absl/log/log.h" #include "absl/status/statusor.h" #include "third_party/gpus/cuda/include/cuda.h" -#include "xla/stream_executor/cuda/cuda_stream.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/gpu/gpu_driver.h" +#include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_executor.h" namespace stream_executor { diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index e163e1fd675c75..de50a4cc13e7e1 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -15,9 +15,27 @@ limitations under the License. #include #include +#include +#include #include #include +#include +#include #include +#include + +#include "absl/base/casts.h" +#include "absl/numeric/int128.h" +#include "absl/strings/str_join.h" +#include "xla/stream_executor/blas.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_options.h" +#include "xla/stream_executor/dnn.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/fft.h" +#include "xla/stream_executor/gpu/gpu_diagnostics.h" +#include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/launch_dim.h" #if defined(PLATFORM_WINDOWS) #include @@ -389,9 +407,9 @@ GpuExecutor::CreateOrShareConstant(Stream* stream, "Failed to allocate %d bytes for new constant", content.size())); } - absl::Status status = - stream->ThenMemcpy(new_constant, content.data(), content.size()) - .BlockHostUntilDone(); + TF_RETURN_IF_ERROR( + stream->Memcpy(new_constant, content.data(), content.size())); + absl::Status status = stream->BlockHostUntilDone(); if (!status.ok()) { Deallocate(new_constant); status.Update(absl::InternalError(absl::StrFormat( @@ -454,7 +472,8 @@ absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, if (VLOG_IS_ON(2)) { absl::MutexLock lock(&launched_kernels_mu_); if (!launched_kernels_.count(cufunc)) { - VlogOccupancyInfo(kernel, thread_dims, block_dims); + VlogOccupancyInfo(stream->parent()->GetDeviceDescription(), kernel, + thread_dims, block_dims); // TODO(rspringer): Remove elements from launched_kernels_...if we ever // expose a kernel/module deallocation method. launched_kernels_.insert(cufunc); @@ -469,8 +488,15 @@ absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, // Launch CUDA kernels with packed arguments. auto launch = [&](const KernelArgsPackedArrayBase& packed) { - CHECK_EQ(kernel.Arity() + (packed.number_of_shared_bytes() > 0), - packed.number_of_arguments()); + int32_t expected_number_of_arguments = + kernel.Arity() + (packed.number_of_shared_bytes() > 0); + + CHECK_EQ(expected_number_of_arguments, packed.number_of_arguments()) + << "Kernel " << kernel.name() << " has " << packed.number_of_arguments() + << " arguments, but expected " << expected_number_of_arguments + << "; arity=" << kernel.Arity() + << "; number_of_shared_bytes=" << packed.number_of_shared_bytes(); + void** params = const_cast(packed.argument_addresses().data()); if (cluster_dims.has_value()) { @@ -526,7 +552,8 @@ absl::Status GpuExecutor::Submit(Stream* stream, // This is a non-essential operation; if there's a failure, proceed without // logging an error. It's nearly certain that in case of failures, we'd never // get here in the first place; these are very low-impact routines. -void GpuExecutor::VlogOccupancyInfo(const Kernel& kernel, +void GpuExecutor::VlogOccupancyInfo(const DeviceDescription& device_description, + const Kernel& kernel, const ThreadDim& thread_dims, const BlockDim& block_dims) { VLOG(2) << "Computing kernel occupancy for kernel " @@ -541,9 +568,6 @@ void GpuExecutor::VlogOccupancyInfo(const Kernel& kernel, return; } - const DeviceDescription& device_description = - kernel.parent()->GetDeviceDescription(); - const GpuKernel* cuda_kernel = AsGpuKernel(&kernel); CUfunction cufunc = cuda_kernel->AsGpuFunctionHandle(); @@ -924,28 +948,25 @@ GpuExecutor::CreateEventImplementation() { return std::unique_ptr(new GpuEvent(this)); } -std::unique_ptr -GpuExecutor::CreateKernelImplementation() { - return std::unique_ptr(new GpuKernel()); -} - std::unique_ptr GpuExecutor::GetStreamImplementation() { return std::unique_ptr(new GpuStream(this)); } -absl::StatusOr> -GpuExecutor::GetCommandBufferImplementation(CommandBuffer::Mode mode) { +absl::StatusOr> GpuExecutor::CreateKernel() { + return std::make_unique(this); +} + +absl::StatusOr> GpuExecutor::CreateCommandBuffer( + CommandBuffer::Mode mode) { VLOG(2) << "Create CUDA command buffer (CUDA graph)"; GpuGraphHandle graph = nullptr; TF_RETURN_IF_ERROR(GpuDriver::CreateGraph(&graph)); return std::make_unique(mode, /*parent=*/this, graph); } -std::unique_ptr -GpuExecutor::GetCommandBufferImplementation(CommandBuffer::Mode mode, - GpuGraphHandle graph, - bool is_owned_graph) { +std::unique_ptr GpuExecutor::CreateCommandBuffer( + CommandBuffer::Mode mode, GpuGraphHandle graph, bool is_owned_graph) { VLOG(2) << "Create CUDA command buffer (CUDA graph) from existing graph " << graph << "; is_owned_graph=" << is_owned_graph; return std::make_unique(mode, /*parent=*/this, graph, diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_fft.cc b/third_party/xla/xla/stream_executor/cuda/cuda_fft.cc index e45c4e4fbd7a82..60408566b087ef 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_fft.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_fft.cc @@ -19,22 +19,29 @@ limitations under the License. #include #include #include +#include +#include +#include #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/cufft.h" #include "xla/stream_executor/cuda/cuda_activation.h" -#include "xla/stream_executor/cuda/cuda_helpers.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" -#include "xla/stream_executor/cuda/cuda_stream.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/fft.h" #include "xla/stream_executor/gpu/gpu_executor.h" +#include "xla/stream_executor/gpu/gpu_helpers.h" +#include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/platform/initialize.h" #include "xla/stream_executor/platform/port.h" #include "xla/stream_executor/plugin_registry.h" +#include "xla/stream_executor/scratch_allocator.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor_internal.h" -#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" namespace stream_executor { namespace gpu { @@ -379,7 +386,7 @@ bool CUDAFft::DoFftInternal(Stream *stream, fft::Plan *plan, FuncT cufftExec, if (allocator) { auto allocated = allocator->AllocateBytes(input.size()); if (allocated.ok()) { - if (stream->ThenMemcpy(&allocated.value(), input, input.size()).ok()) { + if (stream->Memcpy(&allocated.value(), input, input.size()).ok()) { input_maybe_copy = DeviceMemory(allocated.value()); } } @@ -481,5 +488,6 @@ void initialize_cufft() { } // namespace stream_executor -REGISTER_MODULE_INITIALIZER(register_cufft, - { stream_executor::initialize_cufft(); }); +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(register_cufft, { + stream_executor::initialize_cufft(); +}); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_fft.h b/third_party/xla/xla/stream_executor/cuda/cuda_fft.h index c3e0dbde33ee78..111f47903b2fc7 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_fft.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_fft.h @@ -20,12 +20,14 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_FFT_H_ #define XLA_STREAM_EXECUTOR_CUDA_CUDA_FFT_H_ +#include #include +#include "absl/log/log.h" +#include "absl/status/status.h" #include "third_party/gpus/cuda/include/cufft.h" #include "xla/stream_executor/fft.h" #include "xla/stream_executor/platform/port.h" -#include "xla/stream_executor/plugin_registry.h" #include "xla/stream_executor/scratch_allocator.h" namespace stream_executor { diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc b/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc index 3018f70586d433..9fa69a09e2b472 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc @@ -15,20 +15,29 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_platform.h" +#include +#include +#include #include +#include +#include #include "absl/base/call_once.h" -#include "absl/base/const_init.h" -#include "absl/memory/memory.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "xla/stream_executor/cuda/cuda_driver.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_options.h" +#include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_executor.h" +#include "xla/stream_executor/multi_platform_manager.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform/initialize.h" -#include "tsl/platform/errors.h" +#include "xla/stream_executor/platform_manager.h" +#include "tsl/platform/status.h" namespace stream_executor { namespace gpu { @@ -177,14 +186,14 @@ CudaPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) { } // namespace gpu static void InitializeCudaPlatform() { - // Disabling leak checking, MultiPlatformManager does not destroy its + // Disabling leak checking, PlatformManager does not destroy its // registered platforms. std::unique_ptr platform(new gpu::CudaPlatform); - TF_CHECK_OK(MultiPlatformManager::RegisterPlatform(std::move(platform))); + TF_CHECK_OK(PlatformManager::RegisterPlatform(std::move(platform))); } } // namespace stream_executor -REGISTER_MODULE_INITIALIZER(cuda_platform, - stream_executor::InitializeCudaPlatform()); +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER( + cuda_platform, stream_executor::InitializeCudaPlatform()); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_platform.h b/third_party/xla/xla/stream_executor/cuda/cuda_platform.h index fb4c662eeaaff7..153282b26507e6 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_platform.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_platform.h @@ -17,16 +17,13 @@ limitations under the License. #define XLA_STREAM_EXECUTOR_CUDA_CUDA_PLATFORM_H_ #include -#include +#include -#include "absl/base/thread_annotations.h" #include "absl/status/statusor.h" #include "xla/stream_executor/executor_cache.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" -#include "xla/stream_executor/platform/port.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/stream_executor/stream_executor_internal.h" namespace stream_executor { namespace cuda { diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_platform_id.cc b/third_party/xla/xla/stream_executor/cuda/cuda_platform_id.cc index 6a3f807ec0be7c..c8754155d6d511 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_platform_id.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_platform_id.cc @@ -15,6 +15,8 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_platform_id.h" +#include "xla/stream_executor/platform.h" + namespace stream_executor { namespace cuda { diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_stream.h b/third_party/xla/xla/stream_executor/cuda/cuda_stream.h index 05106debf7ffa0..7e651b45d0e6fa 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_stream.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_stream.h @@ -19,6 +19,7 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_STREAM_H_ #define XLA_STREAM_EXECUTOR_CUDA_CUDA_STREAM_H_ +#include "xla/stream_executor/blas.h" #include "xla/stream_executor/gpu/gpu_stream.h" namespace stream_executor { diff --git a/third_party/xla/xla/stream_executor/cuda/memcpy_test.cc b/third_party/xla/xla/stream_executor/cuda/memcpy_test.cc index 4f897859f81fb2..cf3d421e01b739 100644 --- a/third_party/xla/xla/stream_executor/cuda/memcpy_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/memcpy_test.cc @@ -14,26 +14,30 @@ limitations under the License. ==============================================================================*/ #if GOOGLE_CUDA -#include "absl/memory/memory.h" #include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/multi_platform_manager.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace stream_executor { TEST(MemcpyTest, PinnedHostMemory) { - Platform* platform = MultiPlatformManager::PlatformWithName("CUDA").value(); + Platform* platform = PlatformManager::PlatformWithName("CUDA").value(); StreamExecutor* executor = platform->ExecutorForDevice(0).value(); Stream stream(executor); - stream.Init(); + TF_ASSERT_OK(stream.Initialize()); ASSERT_TRUE(stream.ok()); - void* d_ptr = executor->HostMemoryAllocate(sizeof(int)); - DeviceMemoryBase d_mem(d_ptr, sizeof(int)); + TF_ASSERT_OK_AND_ASSIGN(auto d_ptr, + executor->HostMemoryAllocate(sizeof(int))); + DeviceMemoryBase d_mem(d_ptr->opaque(), sizeof(int)); + int h_ptr; - stream.ThenMemcpy(&h_ptr, d_mem, d_mem.size()); + TF_ASSERT_OK(stream.Memcpy(&h_ptr, d_mem, d_mem.size())); EXPECT_TRUE(stream.BlockHostUntilDone().ok()); } diff --git a/third_party/xla/xla/stream_executor/cuda/stream_search_test.cc b/third_party/xla/xla/stream_executor/cuda/stream_search_test.cc index 8489f59186863c..6620622ff9c080 100644 --- a/third_party/xla/xla/stream_executor/cuda/stream_search_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/stream_search_test.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "absl/status/statusor.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "tsl/platform/test.h" @@ -24,9 +26,7 @@ namespace { class StreamSearchTest : public ::testing::Test { public: - Platform* GetPlatform() { - return *MultiPlatformManager::PlatformWithName("CUDA"); - } + Platform* GetPlatform() { return *PlatformManager::PlatformWithName("CUDA"); } }; TEST_F(StreamSearchTest, NoMatchBadPtr) { diff --git a/third_party/xla/xla/stream_executor/device_description.h b/third_party/xla/xla/stream_executor/device_description.h index e66f4e74765a58..6d06755956d9f3 100644 --- a/third_party/xla/xla/stream_executor/device_description.h +++ b/third_party/xla/xla/stream_executor/device_description.h @@ -53,8 +53,8 @@ struct CudaComputeCapability { HOPPER = 9 }; - CudaComputeCapability() = default; - CudaComputeCapability(int major, int minor) { + constexpr CudaComputeCapability() = default; + constexpr CudaComputeCapability(int major, int minor) { this->major = major; this->minor = minor; } diff --git a/third_party/xla/xla/stream_executor/device_memory.h b/third_party/xla/xla/stream_executor/device_memory.h index c201afc7656f9a..9e81a90dfba3eb 100644 --- a/third_party/xla/xla/stream_executor/device_memory.h +++ b/third_party/xla/xla/stream_executor/device_memory.h @@ -28,6 +28,7 @@ limitations under the License. #include #include +#include #include "xla/stream_executor/platform/port.h" #include "tsl/platform/logging.h" @@ -62,12 +63,16 @@ class DeviceMemoryBase { bool operator==(std::nullptr_t other) const { return is_null(); } bool operator!=(std::nullptr_t other) const { return !is_null(); } + bool operator==(const DeviceMemoryBase &other) const { + return opaque_ == other.opaque_ && size_ == other.size_; + } + // Provides a partial order between device memory values. // // This operator is provided so that this object can be used as a key in an // ordered map. bool operator<(const DeviceMemoryBase &other) const { - return opaque() < other.opaque(); + return std::tie(opaque_, size_) < std::tie(other.opaque_, other.size_); } // Returns the size, in bytes, for the backing memory. diff --git a/third_party/xla/xla/stream_executor/dnn.cc b/third_party/xla/xla/stream_executor/dnn.cc index 91b5c0157d117f..eb0bf1028943c3 100644 --- a/third_party/xla/xla/stream_executor/dnn.cc +++ b/third_party/xla/xla/stream_executor/dnn.cc @@ -221,13 +221,17 @@ DnnSupport::FusedConvolveRunnerFromDesc( absl::StatusOr> DnnSupport::NormRunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, double epsilon, - const dnn::TensorDescriptor& input_descriptor, + Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, + dnn::NormKind kind, double epsilon, + const dnn::TensorDescriptor& x_descriptor, const dnn::TensorDescriptor& scale_descriptor, - const dnn::TensorDescriptor& bias_descriptor, - const dnn::TensorDescriptor& output_descriptor, + const dnn::TensorDescriptor& y_or_dx_descriptor, + std::optional bias_descriptor, + std::optional dy_descriptor, std::optional expectation_descriptor, - std::optional norm_factor_descriptor) { + std::optional norm_factor_descriptor, + std::optional dscale_descriptor, + std::optional dbias_descriptor) { return absl::UnimplementedError("NormRunnerFromDesc not implemented."); } diff --git a/third_party/xla/xla/stream_executor/dnn.h b/third_party/xla/xla/stream_executor/dnn.h index 996ac82a0e8850..71b65131f72dbf 100644 --- a/third_party/xla/xla/stream_executor/dnn.h +++ b/third_party/xla/xla/stream_executor/dnn.h @@ -217,7 +217,7 @@ class MatmulTensorDescriptor { // Specifies the descriptor for a RNN model. // // An example use case: -// * The user first creates a model through createRnnDescriptor. +// * The user first creates a model through CreateRnnDescriptor. // * The user queries the size of the underlying opaque parameter buffer. // * The user creates and initializes a parameter buffer of the proper size. // * The user runs forward and backward operations using this RNN descriptor. @@ -416,7 +416,7 @@ class BatchDescriptor { // dimensions, except possibly for feature_map_count(), though this // function does not verify that. static BatchDescriptor DepthConcatenateOutputDescriptor( - absl::Span inputs); + absl::Span inputs); private: absl::Span spatial_size() const { @@ -567,7 +567,7 @@ enum class PadAlignment : int64_t { std::string PadAlignmentString(PadAlignment alignment); // Print alignment to str. Needed to use CHECK_EQ between two PadAlignments. -std::ostream& operator<<(std::ostream& str, dnn::PadAlignment alignment); +std::ostream& operator<<(std::ostream& str, PadAlignment alignment); // Describes a convolution. // @@ -1309,12 +1309,11 @@ class DnnSupport { const DeviceMemory& scale, const DeviceMemory& offset, const DeviceMemory& estimated_mean, const DeviceMemory& estimated_variance, - const DeviceMemory& side_input, const dnn::BatchDescriptor& x_desc, - const dnn::BatchDescriptor& scale_offset_desc, const double epsilon, - const double exponential_average_factor, - dnn::ActivationMode activation_mode, DeviceMemory* y, - DeviceMemory* batch_mean, DeviceMemory* batch_var, - DeviceMemory* reserve_space_1, + const DeviceMemory& side_input, const BatchDescriptor& x_desc, + const BatchDescriptor& scale_offset_desc, const double epsilon, + const double exponential_average_factor, ActivationMode activation_mode, + DeviceMemory* y, DeviceMemory* batch_mean, + DeviceMemory* batch_var, DeviceMemory* reserve_space_1, DeviceMemory* reserve_space_2, bool is_training, ScratchAllocator* reserve_space_allocator, ScratchAllocator* workspace_allocator) { @@ -1329,10 +1328,9 @@ class DnnSupport { const DeviceMemory& estimated_mean, const DeviceMemory& estimated_variance, const DeviceMemory& side_input, - const dnn::BatchDescriptor& x_desc, - const dnn::BatchDescriptor& scale_offset_desc, const double epsilon, - const double exponential_average_factor, - dnn::ActivationMode activation_mode, DeviceMemory* y, + const BatchDescriptor& x_desc, const BatchDescriptor& scale_offset_desc, + const double epsilon, const double exponential_average_factor, + ActivationMode activation_mode, DeviceMemory* y, DeviceMemory* batch_mean, DeviceMemory* batch_var, DeviceMemory* reserve_space_1, DeviceMemory* reserve_space_2, bool is_training, @@ -1349,10 +1347,9 @@ class DnnSupport { const DeviceMemory& estimated_mean, const DeviceMemory& estimated_variance, const DeviceMemory& side_input, - const dnn::BatchDescriptor& x_desc, - const dnn::BatchDescriptor& scale_offset_desc, const double epsilon, - const double exponential_average_factor, - dnn::ActivationMode activation_mode, DeviceMemory* y, + const BatchDescriptor& x_desc, const BatchDescriptor& scale_offset_desc, + const double epsilon, const double exponential_average_factor, + ActivationMode activation_mode, DeviceMemory* y, DeviceMemory* batch_mean, DeviceMemory* batch_var, DeviceMemory* reserve_space_1, DeviceMemory* reserve_space_2, bool is_training, @@ -1383,10 +1380,10 @@ class DnnSupport { const DeviceMemory& x, const DeviceMemory& scale, const DeviceMemory& offset, const DeviceMemory& mean, const DeviceMemory& inv_var, const DeviceMemory& y, - const dnn::BatchDescriptor& x_desc, - const dnn::BatchDescriptor& scale_offset_desc, const double epsilon, - dnn::ActivationMode activation_mode, DeviceMemory* x_backprop, - DeviceMemory* scale_backprop, DeviceMemory* offset_backprop, + const BatchDescriptor& x_desc, const BatchDescriptor& scale_offset_desc, + const double epsilon, ActivationMode activation_mode, + DeviceMemory* x_backprop, DeviceMemory* scale_backprop, + DeviceMemory* offset_backprop, DeviceMemory* side_input_backprop, DeviceMemory* reserve_space_data, ScratchAllocator* workspace_allocator) { @@ -1401,9 +1398,8 @@ class DnnSupport { const DeviceMemory& x, const DeviceMemory& scale, const DeviceMemory& offset, const DeviceMemory& mean, const DeviceMemory& inv_var, const DeviceMemory& y, - const dnn::BatchDescriptor& x_desc, - const dnn::BatchDescriptor& scale_offset_desc, const double epsilon, - dnn::ActivationMode activation_mode, + const BatchDescriptor& x_desc, const BatchDescriptor& scale_offset_desc, + const double epsilon, ActivationMode activation_mode, DeviceMemory* x_backprop, DeviceMemory* scale_backprop, DeviceMemory* offset_backprop, DeviceMemory* side_input_backprop, @@ -1420,11 +1416,9 @@ class DnnSupport { const DeviceMemory& x, const DeviceMemory& scale, const DeviceMemory& offset, const DeviceMemory& mean, const DeviceMemory& inv_var, - const DeviceMemory& y, - const dnn::BatchDescriptor& x_desc, - const dnn::BatchDescriptor& scale_offset_desc, const double epsilon, - dnn::ActivationMode activation_mode, - DeviceMemory* x_backprop, + const DeviceMemory& y, const BatchDescriptor& x_desc, + const BatchDescriptor& scale_offset_desc, const double epsilon, + ActivationMode activation_mode, DeviceMemory* x_backprop, DeviceMemory* scale_backprop, DeviceMemory* offset_backprop, DeviceMemory* side_input_backprop, DeviceMemory* reserve_space_data, @@ -1485,22 +1479,44 @@ class DnnSupport { virtual absl::Status DoFusedConvolve( Stream* stream, DataType input_type, DataType side_input_type, DataType bias_type, DataType output_type, - const dnn::BatchDescriptor& conv_input_descriptor, + const BatchDescriptor& conv_input_descriptor, DeviceMemoryBase conv_input_data, double conv_input_scale, - const dnn::FilterDescriptor& filter_descriptor, - DeviceMemoryBase filter_data, - const dnn::ConvolutionDescriptor& convolution_descriptor, + const FilterDescriptor& filter_descriptor, DeviceMemoryBase filter_data, + const ConvolutionDescriptor& convolution_descriptor, DeviceMemoryBase side_input_data, double side_input_scale, - const dnn::BatchDescriptor& bias_descriptor, DeviceMemoryBase biases, - dnn::ActivationMode activation_mode, - const dnn::BatchDescriptor& output_descriptor, + const BatchDescriptor& bias_descriptor, DeviceMemoryBase biases, + ActivationMode activation_mode, const BatchDescriptor& output_descriptor, DeviceMemoryBase output_data, ScratchAllocator* scratch_allocator, - const dnn::AlgorithmConfig& algorithm_config, - dnn::ProfileResult* output_profile_result) { + const AlgorithmConfig& algorithm_config, + ProfileResult* output_profile_result) { return absl::UnimplementedError( "DnnSupport::DoFusedConvolve not implemented on this platform."); } + template + absl::Status FusedConvolveWithAlgorithm( + Stream* stream, const BatchDescriptor& conv_input_descriptor, + const DeviceMemory& conv_input_data, ScaleT conv_input_scale, + const FilterDescriptor& filter_descriptor, + const DeviceMemory& filter_data, + const ConvolutionDescriptor& convolution_descriptor, + const DeviceMemory& side_input_data, ScaleT side_input_scale, + const BatchDescriptor& bias_descriptor, const DeviceMemory& biases, + ActivationMode activation_mode, const BatchDescriptor& output_descriptor, + DeviceMemory* output, ScratchAllocator* scratch_allocator, + const AlgorithmConfig& algorithm_config, + ProfileResult* output_profile_result) { + return DoFusedConvolve( + stream, ToDataType::value, ToDataType::value, + ToDataType::value, ToDataType::value, + conv_input_descriptor, conv_input_data, conv_input_scale, + filter_descriptor, filter_data, convolution_descriptor, side_input_data, + side_input_scale, bias_descriptor, biases, activation_mode, + output_descriptor, *output, scratch_allocator, algorithm_config, + output_profile_result); + } + template absl::Status PrepareForConvolution( ConvolutionKind kind, Stream* stream, @@ -1578,115 +1594,135 @@ class DnnSupport { AlgorithmDesc algorithm_desc, DeviceMemory scratch_memory, ProfileResult* output_profile_result) = 0; + template + absl::Status ConvolveWithAlgorithm( + Stream* stream, ConvolutionKind kind, + const BatchDescriptor& input_descriptor, + DeviceMemory input_data, + const FilterDescriptor& filter_descriptor, + DeviceMemory filter_data, + const BatchDescriptor& output_descriptor, + DeviceMemory output_data, + const ConvolutionDescriptor& convolution_descriptor, + ScratchAllocator* scratch_allocator, + const AlgorithmConfig& algorithm_config, + ProfileResult* output_profile_result) { + DeviceMemory scratch_memory; + AlgorithmDesc algorithm_desc; + TF_RETURN_IF_ERROR(PrepareForConvolution( + kind, stream, input_descriptor, input_data, filter_descriptor, + filter_data, output_descriptor, output_data, convolution_descriptor, + algorithm_config, scratch_allocator, &algorithm_desc, &scratch_memory)); + return DoConvolve(kind, ToDataType::value, + ToDataType::value, stream, input_descriptor, + input_data, filter_descriptor, filter_data, + output_descriptor, output_data, convolution_descriptor, + algorithm_desc, scratch_memory, output_profile_result); + } + virtual absl::Status GetConvolveRunners( - bool use_cudnn_frontend, dnn::ConvolutionKind kind, - dnn::DataType input_type, dnn::DataType output_type, Stream* stream, - const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, - const dnn::FilterDescriptor& filter_descriptor, - DeviceMemoryBase filter_data, - const dnn::BatchDescriptor& output_descriptor, - DeviceMemoryBase output_data, - const dnn::ConvolutionDescriptor& convolution_descriptor, - bool use_fallback, ScratchAllocator* scratch_allocator, + bool use_cudnn_frontend, ConvolutionKind kind, DataType input_type, + DataType output_type, Stream* stream, + const BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, + const FilterDescriptor& filter_descriptor, DeviceMemoryBase filter_data, + const BatchDescriptor& output_descriptor, DeviceMemoryBase output_data, + const ConvolutionDescriptor& convolution_descriptor, bool use_fallback, + ScratchAllocator* scratch_allocator, const NumericOptions& numeric_options, - std::vector>* out_exec_plans); + std::vector>* out_exec_plans); - virtual absl::StatusOr> - ConvolveRunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - dnn::ConvolutionKind kind, dnn::DataType element_type, - dnn::DataType output_type, const dnn::BatchDescriptor& input_descriptor, - const dnn::FilterDescriptor& filter_descriptor, - const dnn::BatchDescriptor& output_descriptor, - const dnn::ConvolutionDescriptor& convolution_descriptor); + virtual absl::StatusOr> + ConvolveRunnerFromDesc(Stream* stream, const AlgorithmDesc& algorithm_desc, + ConvolutionKind kind, DataType element_type, + DataType output_type, + const BatchDescriptor& input_descriptor, + const FilterDescriptor& filter_descriptor, + const BatchDescriptor& output_descriptor, + const ConvolutionDescriptor& convolution_descriptor); virtual absl::Status GetGraphConvolveRunners( - dnn::ConvolutionKind kind, dnn::DataType input_type, - dnn::DataType output_type, Stream* stream, - const dnn::BatchDescriptor& input_descriptor, - const dnn::FilterDescriptor& filter_descriptor, - const dnn::BatchDescriptor& output_descriptor, - const dnn::ConvolutionDescriptor& convolution_descriptor, - bool use_fallback, const NumericOptions& numeric_options, - std::vector>* out_exec_plans, + ConvolutionKind kind, DataType input_type, DataType output_type, + Stream* stream, const BatchDescriptor& input_descriptor, + const FilterDescriptor& filter_descriptor, + const BatchDescriptor& output_descriptor, + const ConvolutionDescriptor& convolution_descriptor, bool use_fallback, + const NumericOptions& numeric_options, + std::vector>* out_exec_plans, std::string serialized_graph); - virtual absl::StatusOr> + virtual absl::StatusOr> GraphConvolveRunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - dnn::ConvolutionKind kind, dnn::DataType element_type, - dnn::DataType output_type, const dnn::BatchDescriptor& input_descriptor, - const dnn::FilterDescriptor& filter_descriptor, - const dnn::BatchDescriptor& output_descriptor, - const dnn::ConvolutionDescriptor& convolution_descriptor, + Stream* stream, const AlgorithmDesc& algorithm_desc, ConvolutionKind kind, + DataType element_type, DataType output_type, + const BatchDescriptor& input_descriptor, + const FilterDescriptor& filter_descriptor, + const BatchDescriptor& output_descriptor, + const ConvolutionDescriptor& convolution_descriptor, std::string serialized_graph); virtual absl::Status GetFusedConvolveRunners( - bool use_cudnn_frontend, dnn::ConvolutionKind kind, - dnn::DataType element_type, dnn::DataType bias_type, - dnn::DataType output_type, double conv_input_scale, + bool use_cudnn_frontend, ConvolutionKind kind, DataType element_type, + DataType bias_type, DataType output_type, double conv_input_scale, double side_input_scale, double leakyrelu_alpha, Stream* stream, - const dnn::BatchDescriptor& input_descriptor, - const dnn::FilterDescriptor& filter_descriptor, - const dnn::BatchDescriptor& bias_descriptor, - const dnn::BatchDescriptor& output_descriptor, - const dnn::ConvolutionDescriptor& convolution_descriptor, - bool use_fallback, dnn::ActivationMode activation_mode, - const NumericOptions& numeric_options, - std::vector>* out_exec_plans); + const BatchDescriptor& input_descriptor, + const FilterDescriptor& filter_descriptor, + const BatchDescriptor& bias_descriptor, + const BatchDescriptor& output_descriptor, + const ConvolutionDescriptor& convolution_descriptor, bool use_fallback, + ActivationMode activation_mode, const NumericOptions& numeric_options, + std::vector>* out_exec_plans); virtual absl::Status GetFusedMatmulRunners( - bool use_cudnn_frontend, dnn::DataType element_type, - dnn::DataType bias_type, dnn::DataType output_type, Stream* stream, - bool trans_a, bool trans_b, uint64_t m, uint64_t n, uint64_t k, - int64_t lda, int64_t ldb, int64_t ldc, - dnn::ActivationMode activation_mode, bool use_fallback, + bool use_cudnn_frontend, DataType element_type, DataType bias_type, + DataType output_type, Stream* stream, bool trans_a, bool trans_b, + uint64_t m, uint64_t n, uint64_t k, int64_t lda, int64_t ldb, int64_t ldc, + ActivationMode activation_mode, bool use_fallback, const NumericOptions& numeric_options, - std::vector>* - out_exec_plans); + std::vector>* out_exec_plans); - virtual absl::StatusOr> + virtual absl::StatusOr> FusedConvolveRunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - dnn::ConvolutionKind kind, dnn::DataType element_type, - dnn::DataType bias_type, dnn::DataType output_type, double conv_scale, - double side_input_scale, double leakyrelu_alpha, - const dnn::BatchDescriptor& input_descriptor, - const dnn::FilterDescriptor& filter_descriptor, - const dnn::BatchDescriptor& bias_descriptor, - const dnn::BatchDescriptor& output_descriptor, - const dnn::ConvolutionDescriptor& convolution_descriptor, - dnn::ActivationMode activation_mode); + Stream* stream, const AlgorithmDesc& algorithm_desc, ConvolutionKind kind, + DataType element_type, DataType bias_type, DataType output_type, + double conv_scale, double side_input_scale, double leakyrelu_alpha, + const BatchDescriptor& input_descriptor, + const FilterDescriptor& filter_descriptor, + const BatchDescriptor& bias_descriptor, + const BatchDescriptor& output_descriptor, + const ConvolutionDescriptor& convolution_descriptor, + ActivationMode activation_mode); virtual absl::StatusOr> NormRunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, double epsilon, - const dnn::TensorDescriptor& input_descriptor, + Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, + dnn::NormKind kind, double epsilon, + const dnn::TensorDescriptor& x_descriptor, const dnn::TensorDescriptor& scale_descriptor, - const dnn::TensorDescriptor& bias_descriptor, - const dnn::TensorDescriptor& output_descriptor, + const dnn::TensorDescriptor& y_or_dx_descriptor, + std::optional bias_descriptor, + std::optional dy_descriptor, std::optional expectation_descriptor, - std::optional norm_factor_descriptor); + std::optional norm_factor_descriptor, + std::optional dscale_descriptor, + std::optional dbias_descriptor); - virtual absl::StatusOr> + virtual absl::StatusOr> FusedMHARunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - dnn::FusedMHAKind kind, - const dnn::MatmulTensorDescriptor& bmm1_lhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm1_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor, - const dnn::TensorDescriptor& output_descriptor, - std::optional activation_descriptor, - std::optional mask_descriptor, - std::optional bias_descriptor, double scale, + Stream* stream, const AlgorithmDesc& algorithm_desc, FusedMHAKind kind, + const MatmulTensorDescriptor& bmm1_lhs_descriptor, + const MatmulTensorDescriptor& bmm1_rhs_descriptor, + const MatmulTensorDescriptor& bmm2_rhs_descriptor, + const MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor, + const TensorDescriptor& output_descriptor, + std::optional activation_descriptor, + std::optional mask_descriptor, + std::optional bias_descriptor, double scale, std::optional dropout_rate, std::optional seed, bool is_flash_attention, bool is_causal_mask); - virtual absl::StatusOr> + virtual absl::StatusOr> FusedMHABackwardRunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - dnn::FusedMHAKind kind, + Stream* stream, const AlgorithmDesc& algorithm_desc, FusedMHAKind kind, const MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor, const MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor, const MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor, @@ -1695,29 +1731,57 @@ class DnnSupport { const TensorDescriptor& d_bmm1_lhs_descriptor, const TensorDescriptor& d_bmm1_rhs_descriptor, const TensorDescriptor& d_bmm2_rhs_descriptor, - std::optional d_s_descriptor, - std::optional mask_descriptor, - std::optional d_bias_descriptor, - std::optional fwd_output_descriptor, - std::optional bias_descriptor, double scale, + std::optional d_s_descriptor, + std::optional mask_descriptor, + std::optional d_bias_descriptor, + std::optional fwd_output_descriptor, + std::optional bias_descriptor, double scale, std::optional dropout_rate, std::optional seed, bool is_flash_attention, bool is_causal_mask); virtual bool GetMIOpenConvolveAlgorithms( - dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream, - const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, - const dnn::FilterDescriptor& filter_descriptor, - DeviceMemoryBase filter_data, - const dnn::BatchDescriptor& output_descriptor, - DeviceMemoryBase output_data, - const dnn::ConvolutionDescriptor& convolution_descriptor, + ConvolutionKind kind, DataType element_type, Stream* stream, + const BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, + const FilterDescriptor& filter_descriptor, DeviceMemoryBase filter_data, + const BatchDescriptor& output_descriptor, DeviceMemoryBase output_data, + const ConvolutionDescriptor& convolution_descriptor, ScratchAllocator* scratch_allocator, std::vector* out_algorithms); // Returns a list of supported rnn algorithms. virtual bool GetRnnAlgorithms(std::vector* out_algorithms); - // Performs a forward pooling operation on input_data, writing to + template + absl::Status PoolForward(Stream* stream, + const PoolingDescriptor& pooling_dimensions, + const NumericOptions& numeric_options, + const BatchDescriptor& input_dimensions, + const DeviceMemory& input_data, + const BatchDescriptor& output_dimensions, + DeviceMemory* output_data, + ScratchAllocator* workspace_allocator = nullptr) { + return DoPoolForward(ToDataType::value, stream, + pooling_dimensions, numeric_options, input_dimensions, + input_data, output_dimensions, *output_data, + workspace_allocator); + } + + template + absl::Status PoolBackward(Stream* stream, + const PoolingDescriptor& pooling_dimensions, + const NumericOptions& numeric_options, + const BatchDescriptor& input_dimensions, + const DeviceMemory& input_data, + const BatchDescriptor& output_dimensions, + const DeviceMemory& output_data, + const DeviceMemory& input_diff_data, + DeviceMemory* output_diff_data, + ScratchAllocator* workspace_allocator = nullptr) { + return DoPoolBackward( + ToDataType::value, stream, pooling_dimensions, + numeric_options, input_dimensions, input_data, output_dimensions, + output_data, input_diff_data, *output_diff_data, workspace_allocator); + } // Performs a forward pooling operation on input_data, writing to // output_data. See PoolingDescriptor for how to configure the // pooling operation. // @@ -1732,37 +1796,36 @@ class DnnSupport { // See PoolingDescriptor for how to configure the pooling operation. virtual absl::Status DoPoolForward( DataType element_type, Stream* stream, - const dnn::PoolingDescriptor& pooling_dimensions, - const dnn::BatchDescriptor& input_dimensions, DeviceMemoryBase input_data, - const dnn::BatchDescriptor& output_dimensions, - DeviceMemoryBase output_data, ScratchAllocator* workspace_allocator) = 0; + const PoolingDescriptor& pooling_dimensions, + const BatchDescriptor& input_dimensions, DeviceMemoryBase input_data, + const BatchDescriptor& output_dimensions, DeviceMemoryBase output_data, + ScratchAllocator* workspace_allocator) = 0; virtual absl::Status DoPoolForward( DataType element_type, Stream* stream, - const dnn::PoolingDescriptor& pooling_dimensions, + const PoolingDescriptor& pooling_dimensions, const NumericOptions& numeric_options, - const dnn::BatchDescriptor& input_dimensions, DeviceMemoryBase input_data, - const dnn::BatchDescriptor& output_dimensions, - DeviceMemoryBase output_data, ScratchAllocator* workspace_allocator); + const BatchDescriptor& input_dimensions, DeviceMemoryBase input_data, + const BatchDescriptor& output_dimensions, DeviceMemoryBase output_data, + ScratchAllocator* workspace_allocator); // Performs differentiation of the pooling operation. virtual absl::Status DoPoolBackward( DataType element_type, Stream* stream, - const dnn::PoolingDescriptor& pooling_dimensions, - const dnn::BatchDescriptor& input_dimensions, DeviceMemoryBase input_data, - const dnn::BatchDescriptor& output_dimensions, - DeviceMemoryBase output_data, DeviceMemoryBase input_diff_data, - DeviceMemoryBase output_diff_data, + const PoolingDescriptor& pooling_dimensions, + const BatchDescriptor& input_dimensions, DeviceMemoryBase input_data, + const BatchDescriptor& output_dimensions, DeviceMemoryBase output_data, + DeviceMemoryBase input_diff_data, DeviceMemoryBase output_diff_data, ScratchAllocator* workspace_allocator) = 0; virtual absl::Status DoPoolBackward( DataType element_type, Stream* stream, - const dnn::PoolingDescriptor& pooling_dimensions, + const PoolingDescriptor& pooling_dimensions, const NumericOptions& numeric_options, - const dnn::BatchDescriptor& input_dimensions, DeviceMemoryBase input_data, - const dnn::BatchDescriptor& output_dimensions, - DeviceMemoryBase output_data, DeviceMemoryBase input_diff_data, - DeviceMemoryBase output_diff_data, ScratchAllocator* workspace_allocator); + const BatchDescriptor& input_dimensions, DeviceMemoryBase input_data, + const BatchDescriptor& output_dimensions, DeviceMemoryBase output_data, + DeviceMemoryBase input_diff_data, DeviceMemoryBase output_diff_data, + ScratchAllocator* workspace_allocator); // Applies local response normalization to the values from input_data and // writes the result to output_data. @@ -1770,9 +1833,9 @@ class DnnSupport { // See comments on NormalizeDescriptor for a description of local response // normalization. virtual bool DoNormalizeWithDimensions( - Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor, - const dnn::BatchDescriptor& dimensions, - const DeviceMemory& input_data, DeviceMemory* output_data) { + Stream* stream, const NormalizeDescriptor& normalize_descriptor, + const BatchDescriptor& dimensions, const DeviceMemory& input_data, + DeviceMemory* output_data) { return false; } @@ -1789,9 +1852,8 @@ class DnnSupport { // See comments on NormalizeDescriptor for a description of local response // normalization. virtual bool DoNormalizeBackwardWithDimensions( - Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor, - const dnn::BatchDescriptor& dimensions, - const DeviceMemory& raw_data, + Stream* stream, const NormalizeDescriptor& normalize_descriptor, + const BatchDescriptor& dimensions, const DeviceMemory& raw_data, const DeviceMemory& normalized_data, const DeviceMemory& normalized_variable_gradient, DeviceMemory* raw_variable_gradient, @@ -1799,25 +1861,6 @@ class DnnSupport { return false; } - // Concatenates several layers into one, by concatenating the depth of each - // layer at matching x and y coordinates. - // The inputs must all have the same width and height, the output will have - // the same width and height as the inputs and its depth will be the sum of - // the input depths. - // - // Arguments (all borrowed): - // stream: borrowed pointer to the stream that the 'depth concatenate' - // operation should be enqueued onto. - // input_dimensions: The dimensions of each input. - // input_data: un-owned device memory region which contains the - // input data for each input layer. - // output_data: un-owned device memory region in which to place the - // depth concatenate result. - virtual bool DoDepthConcatenate( - Stream* stream, absl::Span input_dimensions, - absl::Span* const> input_data, - DeviceMemory* output_data) = 0; - // Create an RNN descriptor based on model shapes and configurations. // The caller retains the ownership of the descriptor. // @@ -1840,17 +1883,14 @@ class DnnSupport { // for dropout layer. The user has to maintain the memory until the model // is no longer in use. // use_padded_io: a bool to specify whether the input is using padded IO. - virtual absl::StatusOr> - createRnnDescriptor(int num_layers, int hidden_size, int input_size, - int cell_size, int batch_size, - dnn::RnnInputMode input_mode, - dnn::RnnDirectionMode direction_mode, - dnn::RnnMode rnn_mode, dnn::DataType data_type, - const dnn::AlgorithmConfig& algorithm_config, - const NumericOptions& numeric_options, float dropout, - uint64_t seed, ScratchAllocator* state_allocator, - bool use_padded_io) { - return absl::UnimplementedError("createRnnDescriptor is unimplemented"); + virtual absl::StatusOr> CreateRnnDescriptor( + int num_layers, int hidden_size, int input_size, int cell_size, + int batch_size, RnnInputMode input_mode, RnnDirectionMode direction_mode, + RnnMode rnn_mode, DataType data_type, + const AlgorithmConfig& algorithm_config, + const NumericOptions& numeric_options, float dropout, uint64_t seed, + ScratchAllocator* state_allocator, bool use_padded_io) { + return absl::UnimplementedError("CreateRnnDescriptor is unimplemented"); } // Create a RNN sequence descriptor that specifies either the input or output @@ -1862,36 +1902,36 @@ class DnnSupport { // data_size: the size of the state. // seq_lengths: the lengths of sequences in a batch. // data_type: an enum to specify the type for the underlying data. - virtual absl::StatusOr> - createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size, - int data_size, dnn::DataType data_type) { + virtual absl::StatusOr> + CreateRnnSequenceTensorDescriptor(int max_seq_length, int batch_size, + int data_size, DataType data_type) { return absl::UnimplementedError( - "createRnnSequenceTensorDescriptor is unimplemented"); + "CreateRnnSequenceTensorDescriptor is unimplemented"); } - virtual absl::StatusOr> - createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size, + virtual absl::StatusOr> + CreateRnnSequenceTensorDescriptor(int max_seq_length, int batch_size, int data_size, const absl::Span& seq_lengths, - bool time_major, dnn::DataType data_type) { + bool time_major, DataType data_type) { return absl::UnimplementedError( - "createRnnSequenceTensorDescriptor is unimplemented"); + "CreateRnnSequenceTensorDescriptor is unimplemented"); } // Create an RNN state descriptor that specifies the input or hidden state. // The caller retains the ownership of the returned descriptor. - virtual absl::StatusOr> - createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size, - dnn::DataType data_type) { + virtual absl::StatusOr> + CreateRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size, + DataType data_type) { return absl::UnimplementedError( - "createRnnStateTensorDescriptor is unimplemented"); + "CreateRnnStateTensorDescriptor is unimplemented"); } // Enqueue a forward operation of the RNN model onto the stream. // // Arguments: // stream: pointer to the stream where this operation should be enqueued to. - // rnn_desc: a RNN descriptor created by createRnnDescriptor. + // rnn_desc: a RNN descriptor created by CreateRnnDescriptor. // input_desc: descriptor for the input sequence. // input_data: the device memory region that contains the input data. // input_h_desc: descriptor for the input "h" state. @@ -1916,76 +1956,76 @@ class DnnSupport { // workspace_allocator: an allocator to create temporary workspace used in // this kernel. The caller is responsible for retaining the memory long // enough for the lifespan of this operation, and recycles afterwards. - virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, - const dnn::RnnSequenceTensorDescriptor& input_desc, + virtual bool DoRnnForward(Stream* stream, const RnnDescriptor& rnn_desc, + const RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, const DeviceMemory& seq_lengths_data, - const dnn::RnnStateTensorDescriptor& input_h_desc, + const RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, - const dnn::RnnStateTensorDescriptor& input_c_desc, + const RnnStateTensorDescriptor& input_c_desc, const DeviceMemory& input_c_data, const DeviceMemory& params, - const dnn::RnnSequenceTensorDescriptor& output_desc, + const RnnSequenceTensorDescriptor& output_desc, DeviceMemory* output_data, - const dnn::RnnStateTensorDescriptor& output_h_desc, + const RnnStateTensorDescriptor& output_h_desc, DeviceMemory* output_h_data, - const dnn::RnnStateTensorDescriptor& output_c_desc, + const RnnStateTensorDescriptor& output_c_desc, DeviceMemory* output_c_data, bool is_training, ScratchAllocator* reserve_space_allocator, ScratchAllocator* workspace_allocator, - dnn::ProfileResult* output_profile_result) { + ProfileResult* output_profile_result) { return false; } - virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, - const dnn::RnnSequenceTensorDescriptor& input_desc, + virtual bool DoRnnForward(Stream* stream, const RnnDescriptor& rnn_desc, + const RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, const DeviceMemory& seq_lengths_data, - const dnn::RnnStateTensorDescriptor& input_h_desc, + const RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, - const dnn::RnnStateTensorDescriptor& input_c_desc, + const RnnStateTensorDescriptor& input_c_desc, const DeviceMemory& input_c_data, const DeviceMemory& params, - const dnn::RnnSequenceTensorDescriptor& output_desc, + const RnnSequenceTensorDescriptor& output_desc, DeviceMemory* output_data, - const dnn::RnnStateTensorDescriptor& output_h_desc, + const RnnStateTensorDescriptor& output_h_desc, DeviceMemory* output_h_data, - const dnn::RnnStateTensorDescriptor& output_c_desc, + const RnnStateTensorDescriptor& output_c_desc, DeviceMemory* output_c_data, bool is_training, ScratchAllocator* reserve_space_allocator, ScratchAllocator* workspace_allocator, - dnn::ProfileResult* output_profile_result) { + ProfileResult* output_profile_result) { return false; } - virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, - const dnn::RnnSequenceTensorDescriptor& input_desc, + virtual bool DoRnnForward(Stream* stream, const RnnDescriptor& rnn_desc, + const RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, const DeviceMemory& seq_lengths_data, - const dnn::RnnStateTensorDescriptor& input_h_desc, + const RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, - const dnn::RnnStateTensorDescriptor& input_c_desc, + const RnnStateTensorDescriptor& input_c_desc, const DeviceMemory& input_c_data, const DeviceMemory& params, - const dnn::RnnSequenceTensorDescriptor& output_desc, + const RnnSequenceTensorDescriptor& output_desc, DeviceMemory* output_data, - const dnn::RnnStateTensorDescriptor& output_h_desc, + const RnnStateTensorDescriptor& output_h_desc, DeviceMemory* output_h_data, - const dnn::RnnStateTensorDescriptor& output_c_desc, + const RnnStateTensorDescriptor& output_c_desc, DeviceMemory* output_c_data, bool is_training, ScratchAllocator* reserve_space_allocator, ScratchAllocator* workspace_allocator, - dnn::ProfileResult* output_profile_result) { + ProfileResult* output_profile_result) { return false; } // Enqueue a backward operation of the RNN model onto the stream. // // Arguments: // stream: pointer to the stream where this operation should be enqueued to. - // rnn_desc: a RNN descriptor created by createRnnDescriptor. + // rnn_desc: a RNN descriptor created by CreateRnnDescriptor. // input_desc: descriptor for the input sequence. // input_data: the device memory region that contains the input data. // input_h_desc: descriptor for the input "h" state. @@ -2023,20 +2063,20 @@ class DnnSupport { // keeping the memory alive long enough for this operation, and recylces // afterwards. virtual bool DoRnnBackward( - Stream* stream, const dnn::RnnDescriptor& rnn_desc, - const dnn::RnnSequenceTensorDescriptor& input_desc, + Stream* stream, const RnnDescriptor& rnn_desc, + const RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, const DeviceMemory& seq_lengths_data, - const dnn::RnnStateTensorDescriptor& input_h_desc, + const RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, - const dnn::RnnStateTensorDescriptor& input_c_desc, + const RnnStateTensorDescriptor& input_c_desc, const DeviceMemory& input_c_data, const DeviceMemory& params, - const dnn::RnnSequenceTensorDescriptor& output_desc, + const RnnSequenceTensorDescriptor& output_desc, const DeviceMemory& output_data, - const dnn::RnnStateTensorDescriptor& output_h_desc, + const RnnStateTensorDescriptor& output_h_desc, const DeviceMemory& output_h_data, - const dnn::RnnStateTensorDescriptor& output_c_desc, + const RnnStateTensorDescriptor& output_c_desc, const DeviceMemory& output_c_data, const DeviceMemory& output_backprop_data, const DeviceMemory& output_h_backprop_data, @@ -2047,65 +2087,63 @@ class DnnSupport { DeviceMemory* params_backprop_data, DeviceMemory* reserve_space_data, ScratchAllocator* workspace_allocator, - dnn::ProfileResult* output_profile_result) { + ProfileResult* output_profile_result) { return false; } - virtual bool DoRnnBackward( - Stream* stream, const dnn::RnnDescriptor& rnn_desc, - const dnn::RnnSequenceTensorDescriptor& input_desc, - const DeviceMemory& input_data, - const DeviceMemory& seq_lengths_data, - const dnn::RnnStateTensorDescriptor& input_h_desc, - const DeviceMemory& input_h_data, - const dnn::RnnStateTensorDescriptor& input_c_desc, - const DeviceMemory& input_c_data, - const DeviceMemory& params, - const dnn::RnnSequenceTensorDescriptor& output_desc, - const DeviceMemory& output_data, - const dnn::RnnStateTensorDescriptor& output_h_desc, - const DeviceMemory& output_h_data, - const dnn::RnnStateTensorDescriptor& output_c_desc, - const DeviceMemory& output_c_data, - const DeviceMemory& output_backprop_data, - const DeviceMemory& output_h_backprop_data, - const DeviceMemory& output_c_backprop_data, - DeviceMemory* input_backprop_data, - DeviceMemory* input_h_backprop_data, - DeviceMemory* input_c_backprop_data, - DeviceMemory* params_backprop_data, - DeviceMemory* reserve_space_data, - ScratchAllocator* workspace_allocator, - dnn::ProfileResult* output_profile_result) { + virtual bool DoRnnBackward(Stream* stream, const RnnDescriptor& rnn_desc, + const RnnSequenceTensorDescriptor& input_desc, + const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, + const RnnStateTensorDescriptor& input_h_desc, + const DeviceMemory& input_h_data, + const RnnStateTensorDescriptor& input_c_desc, + const DeviceMemory& input_c_data, + const DeviceMemory& params, + const RnnSequenceTensorDescriptor& output_desc, + const DeviceMemory& output_data, + const RnnStateTensorDescriptor& output_h_desc, + const DeviceMemory& output_h_data, + const RnnStateTensorDescriptor& output_c_desc, + const DeviceMemory& output_c_data, + const DeviceMemory& output_backprop_data, + const DeviceMemory& output_h_backprop_data, + const DeviceMemory& output_c_backprop_data, + DeviceMemory* input_backprop_data, + DeviceMemory* input_h_backprop_data, + DeviceMemory* input_c_backprop_data, + DeviceMemory* params_backprop_data, + DeviceMemory* reserve_space_data, + ScratchAllocator* workspace_allocator, + ProfileResult* output_profile_result) { return false; } - virtual bool DoRnnBackward( - Stream* stream, const dnn::RnnDescriptor& rnn_desc, - const dnn::RnnSequenceTensorDescriptor& input_desc, - const DeviceMemory& input_data, - const DeviceMemory& seq_lengths_data, - const dnn::RnnStateTensorDescriptor& input_h_desc, - const DeviceMemory& input_h_data, - const dnn::RnnStateTensorDescriptor& input_c_desc, - const DeviceMemory& input_c_data, - const DeviceMemory& params, - const dnn::RnnSequenceTensorDescriptor& output_desc, - const DeviceMemory& output_data, - const dnn::RnnStateTensorDescriptor& output_h_desc, - const DeviceMemory& output_h_data, - const dnn::RnnStateTensorDescriptor& output_c_desc, - const DeviceMemory& output_c_data, - const DeviceMemory& output_backprop_data, - const DeviceMemory& output_h_backprop_data, - const DeviceMemory& output_c_backprop_data, - DeviceMemory* input_backprop_data, - DeviceMemory* input_h_backprop_data, - DeviceMemory* input_c_backprop_data, - DeviceMemory* params_backprop_data, - DeviceMemory* reserve_space_data, - ScratchAllocator* workspace_allocator, - dnn::ProfileResult* output_profile_result) { + virtual bool DoRnnBackward(Stream* stream, const RnnDescriptor& rnn_desc, + const RnnSequenceTensorDescriptor& input_desc, + const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, + const RnnStateTensorDescriptor& input_h_desc, + const DeviceMemory& input_h_data, + const RnnStateTensorDescriptor& input_c_desc, + const DeviceMemory& input_c_data, + const DeviceMemory& params, + const RnnSequenceTensorDescriptor& output_desc, + const DeviceMemory& output_data, + const RnnStateTensorDescriptor& output_h_desc, + const DeviceMemory& output_h_data, + const RnnStateTensorDescriptor& output_c_desc, + const DeviceMemory& output_c_data, + const DeviceMemory& output_backprop_data, + const DeviceMemory& output_h_backprop_data, + const DeviceMemory& output_c_backprop_data, + DeviceMemory* input_backprop_data, + DeviceMemory* input_h_backprop_data, + DeviceMemory* input_c_backprop_data, + DeviceMemory* params_backprop_data, + DeviceMemory* reserve_space_data, + ScratchAllocator* workspace_allocator, + ProfileResult* output_profile_result) { return false; } @@ -2149,7 +2187,7 @@ class DnnSupport { // keeping the memory alive long enough for this operation, and recylces // afterwards. virtual absl::Status DoCtcLoss( - Stream* stream, dnn::DataType element_type, + Stream* stream, DataType element_type, const RnnStateTensorDescriptor& probs_desc, const DeviceMemoryBase probs_data, absl::Span labels_data, absl::Span labels_lengths_data, @@ -2158,14 +2196,13 @@ class DnnSupport { DeviceMemory scratch_memory, int ctc_loss_algo_id); template - bool DoCtcLoss(Stream* stream, - const dnn::RnnStateTensorDescriptor& probs_desc, + bool DoCtcLoss(Stream* stream, const RnnStateTensorDescriptor& probs_desc, const DeviceMemory& probs_data, absl::Span labels_data, absl::Span labels_lengths_data, absl::Span input_lengths_data, DeviceMemory* costs_data, - const dnn::RnnStateTensorDescriptor& grads_desc, + const RnnStateTensorDescriptor& grads_desc, DeviceMemory* grads_data, DeviceMemory* scratch_memory, int ctc_loss_algo_id) { return IsStatusOk( @@ -2188,13 +2225,10 @@ class DnnSupport { // output_type: the data type of the output tensor. // scale: an element-wise scaling factor to apply. // output_data: the device memory region that contains the output tensor. - virtual bool DoTransformTensor(Stream* stream, - const dnn::BatchDescriptor& input_desc, - dnn::DataType input_type, - const DeviceMemoryBase& input_data, - const dnn::BatchDescriptor& output_desc, - dnn::DataType output_type, float scale, - DeviceMemoryBase* output_data) { + virtual bool DoTransformTensor( + Stream* stream, const BatchDescriptor& input_desc, DataType input_type, + const DeviceMemoryBase& input_data, const BatchDescriptor& output_desc, + DataType output_type, float scale, DeviceMemoryBase* output_data) { return false; } diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index d45878133fa03d..366b1f1f1eddcb 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -19,6 +19,7 @@ load( ) load( "//xla/stream_executor:build_defs.bzl", + "gpu_only_cc_library", "if_gpu_is_configured", ) load( @@ -29,7 +30,7 @@ load( load( "@local_tsl//tsl:tsl.bzl", "if_libtpu", - "set_external_visibility", + "internal_visibility", "tsl_copts", "tsl_gpu_library", ) @@ -48,56 +49,61 @@ load( ) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//tensorflow/compiler/tf2xla:__subpackages__", + "//xla:__subpackages__", + "//tensorflow/core/kernels:__subpackages__", + "//xla/pjrt:__subpackages__", + "//xla/service/gpu:__subpackages__", + "//xla/stream_executor:__subpackages__", + "//tensorflow/core/common_runtime/gpu:__subpackages__", + "//waymo/ml/compiler/triton:__subpackages__", + ]), licenses = ["notice"], ) cc_library( name = "gpu_activation_header", hdrs = ["gpu_activation.h"], - visibility = ["//visibility:public"], - deps = ["//xla/stream_executor/platform"], ) -cc_library( +gpu_only_cc_library( name = "gpu_activation", - srcs = if_gpu_is_configured(["gpu_activation.cc"]), - hdrs = if_gpu_is_configured(["gpu_activation.h"]), - visibility = ["//visibility:public"], - deps = if_gpu_is_configured([ - ":gpu_executor_header", + srcs = ["gpu_activation.cc"], + hdrs = ["gpu_activation.h"], + deps = [ ":gpu_activation_header", ":gpu_driver_header", + ":gpu_executor_header", "//xla/stream_executor", - "//xla/stream_executor:stream_executor_internal", - "//xla/stream_executor/platform", - ]), + ], ) -cc_library( +gpu_only_cc_library( name = "gpu_diagnostics_header", - hdrs = if_gpu_is_configured(["gpu_diagnostics.h"]), - visibility = ["//visibility:public"], - deps = [ - "//xla/stream_executor/platform", - "@com_google_absl//absl/status:statusor", - ], + hdrs = ["gpu_diagnostics.h"], + deps = ["@com_google_absl//absl/status:statusor"], ) -cc_library( +gpu_only_cc_library( name = "gpu_collectives_header", - hdrs = if_gpu_is_configured(["gpu_collectives.h"]), - visibility = ["//visibility:public"], + hdrs = ["gpu_collectives.h"], deps = [ "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", ], ) -cc_library( +gpu_only_cc_library( name = "gpu_driver_header", - hdrs = if_gpu_is_configured(["gpu_driver.h"]), - visibility = ["//visibility:public"], + hdrs = ["gpu_driver.h"], + visibility = internal_visibility([ + "//xla/service/gpu:__subpackages__", + "//xla/stream_executor:__subpackages__", + "//tensorflow/core/common_runtime/gpu:__subpackages__", + "//tensorflow/core/util/autotune_maps:__subpackages__", + ]), deps = [ ":gpu_types_header", "//xla/stream_executor", @@ -114,30 +120,31 @@ cc_library( ), ) -cc_library( +gpu_only_cc_library( name = "gpu_runtime_header", - hdrs = if_gpu_is_configured(["gpu_runtime.h"]), - visibility = ["//visibility:public"], + hdrs = ["gpu_runtime.h"], + visibility = internal_visibility([ + "//xla/service/gpu:__subpackages__", + "//xla/stream_executor:__subpackages__", + ]), deps = [ ":gpu_types_header", "@com_google_absl//absl/status:statusor", ], ) -cc_library( +gpu_only_cc_library( name = "gpu_kernels", - hdrs = if_gpu_is_configured(["gpu_kernels.h"]), - visibility = ["//visibility:public"], + hdrs = ["gpu_kernels.h"], ) -cc_library( +gpu_only_cc_library( name = "gpu_command_buffer", - srcs = if_gpu_is_configured(["gpu_command_buffer.cc"]), - hdrs = if_gpu_is_configured(["gpu_command_buffer.h"]), + srcs = ["gpu_command_buffer.cc"], + hdrs = ["gpu_command_buffer.h"], local_defines = if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), - visibility = ["//visibility:public"], deps = [ ":gpu_driver_header", ":gpu_executor_header", @@ -147,6 +154,7 @@ cc_library( ":gpu_types_header", "//xla/stream_executor", "//xla/stream_executor:stream_executor_internal", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log", @@ -154,6 +162,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", @@ -169,61 +178,56 @@ cc_library( ]), ) -cc_library( +gpu_only_cc_library( name = "gpu_event_header", - hdrs = if_gpu_is_configured(["gpu_event.h"]), - visibility = ["//visibility:public"], - deps = if_gpu_is_configured([ - ":gpu_driver_header", + hdrs = ["gpu_event.h"], + deps = [ ":gpu_stream_header", - "@com_google_absl//absl/status", + ":gpu_types_header", "//xla/stream_executor", - "@local_tsl//tsl/platform:status", - ]), + "@com_google_absl//absl/status", + ], ) -cc_library( +gpu_only_cc_library( name = "gpu_event", - srcs = if_gpu_is_configured(["gpu_event.cc"]), - hdrs = if_gpu_is_configured(["gpu_event.h"]), - visibility = ["//visibility:public"], + srcs = ["gpu_event.cc"], + hdrs = ["gpu_event.h"], deps = [ ":gpu_driver_header", ":gpu_executor_header", ":gpu_stream", + ":gpu_types_header", "//xla/stream_executor", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", ], ) -cc_library( +gpu_only_cc_library( name = "gpu_executor_header", - hdrs = if_gpu_is_configured(["gpu_executor.h"]), - visibility = ["//visibility:public"], + hdrs = ["gpu_executor.h"], deps = [ ":gpu_collectives_header", - ":gpu_kernel_header", + ":gpu_driver_header", ":gpu_types_header", "//xla/stream_executor", "//xla/stream_executor:platform", "//xla/stream_executor:stream_executor_internal", - "//xla/stream_executor/platform", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/numeric:int128", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:fingerprint", + "@local_tsl//tsl/platform:thread_annotations", ], ) -cc_library( +gpu_only_cc_library( name = "gpu_helpers_header", - hdrs = if_gpu_is_configured(["gpu_helpers.h"]), - visibility = ["//visibility:public"], + hdrs = ["gpu_helpers.h"], deps = [ ":gpu_types_header", "@local_tsl//tsl/platform:logging", @@ -235,7 +239,9 @@ tsl_gpu_library( hdrs = [ "gpu_init.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "@local_tsl//tsl:internal", + ]), deps = [ "@com_google_absl//absl/status", "@local_tsl//tsl/platform:status", @@ -254,10 +260,15 @@ tsl_gpu_library( ], copts = tsl_copts(), linkstatic = True, - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//tensorflow/compiler/tf2xla:__subpackages__", + "//xla:__subpackages__", + "//tensorflow/core/common_runtime/gpu:__subpackages__", + "//tensorflow/stream_executor:__subpackages__", + ]), deps = [ - "//xla/stream_executor:multi_platform_manager", "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:logging", @@ -265,74 +276,78 @@ tsl_gpu_library( alwayslink = True, ) -cc_library( +gpu_only_cc_library( name = "gpu_kernel_header", - hdrs = if_gpu_is_configured(["gpu_kernel.h"]), - visibility = ["//visibility:public"], + hdrs = ["gpu_kernel.h"], deps = [ ":gpu_driver_header", + ":gpu_executor_header", ":gpu_types_header", "//xla/stream_executor", "//xla/stream_executor:stream_executor_internal", "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:logging", ], ) -cc_library( +gpu_only_cc_library( name = "gpu_stream_header", - hdrs = if_gpu_is_configured(["gpu_stream.h"]), - visibility = ["//visibility:public"], + hdrs = ["gpu_stream.h"], deps = [ ":gpu_types_header", "//xla/stream_executor", "//xla/stream_executor:stream_executor_internal", + "@com_google_absl//absl/log:check", ], ) -cc_library( +gpu_only_cc_library( name = "gpu_stream", - srcs = if_gpu_is_configured(["gpu_stream.cc"]), - hdrs = if_gpu_is_configured(["gpu_stream.h"]), - visibility = ["//visibility:public"], + srcs = ["gpu_stream.cc"], + hdrs = ["gpu_stream.h"], deps = [ + ":gpu_driver_header", ":gpu_executor_header", ":gpu_types_header", "//xla/stream_executor", "//xla/stream_executor:stream_executor_internal", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", ], ) -cc_library( +gpu_only_cc_library( name = "gpu_timer_header", - hdrs = if_gpu_is_configured(["gpu_timer.h"]), - visibility = ["//visibility:public"], + hdrs = ["gpu_timer.h"], deps = [ - ":gpu_driver_header", ":gpu_executor_header", - "//xla/stream_executor:stream_executor_internal", + ":gpu_types_header", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/time", ], ) -cc_library( +gpu_only_cc_library( name = "gpu_timer", - srcs = if_gpu_is_configured(["gpu_timer.cc"]), - hdrs = if_gpu_is_configured(["gpu_timer.h"]), - visibility = ["//visibility:public"], + srcs = ["gpu_timer.cc"], + hdrs = ["gpu_timer.h"], deps = [ ":gpu_driver_header", ":gpu_executor_header", ":gpu_stream", + ":gpu_types_header", "//xla/stream_executor", "//xla/stream_executor:stream_executor_internal", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_absl//absl/utility", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", ] + if_cuda_is_configured([ @@ -342,10 +357,9 @@ cc_library( ]), ) -cc_library( +gpu_only_cc_library( name = "gpu_types_header", - hdrs = if_gpu_is_configured(["gpu_types.h"]), - visibility = ["//visibility:public"], + hdrs = ["gpu_types.h"], deps = [ "//xla/stream_executor/platform", ] + if_cuda_is_configured([ @@ -358,80 +372,104 @@ cc_library( cc_library( name = "gpu_asm_opts", hdrs = ["gpu_asm_opts.h"], - visibility = ["//visibility:public"], + visibility = internal_visibility([ + "//xla/service/gpu:__subpackages__", + "//xla/stream_executor:__subpackages__", + "//tensorflow/core/kernels:__subpackages__", + ]), deps = [ "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) -cc_library( +gpu_only_cc_library( name = "asm_compiler_header", - hdrs = if_gpu_is_configured(["asm_compiler.h"]), + hdrs = ["asm_compiler.h"], copts = tsl_copts(), - visibility = ["//visibility:public"], - deps = if_gpu_is_configured([ + visibility = internal_visibility([ + "//tensorflow/compiler/mlir/tools/kernel_gen:__subpackages__", + "//xla/service/gpu:__subpackages__", + "//xla/stream_executor:__subpackages__", + "//tensorflow/core/kernels:__subpackages__", + ]), + deps = [ ":gpu_asm_opts", ":gpu_driver_header", ":gpu_helpers_header", - "@com_google_absl//absl/cleanup", - "@com_google_absl//absl/base:core_headers", - "@local_tsl//tsl/platform:regexp", - "@local_tsl//tsl/platform:mutex", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:subprocess", - "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/platform:cuda_libdevice_path", "//xla/stream_executor", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:statusor", "//xla/stream_executor/platform", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - ]) + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + "@local_tsl//tsl/platform:cuda_libdevice_path", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:mutex", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:regexp", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:subprocess", + ] + if_cuda_is_configured([ "//xla/stream_executor/cuda:cuda_driver", ]) + if_rocm_is_configured([ "//xla/stream_executor/rocm:rocm_driver", ]), ) -cc_library( +gpu_only_cc_library( name = "asm_compiler", - srcs = if_gpu_is_configured(["asm_compiler.cc"]), - hdrs = if_gpu_is_configured(["asm_compiler.h"]), + srcs = ["asm_compiler.cc"], + hdrs = ["asm_compiler.h"], copts = tsl_copts(), local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - visibility = ["//visibility:public"], - deps = if_gpu_is_configured([ + visibility = internal_visibility([ + "//third_party/py/jax:__subpackages__", + "//tensorflow/compiler/mlir/tools/kernel_gen:__subpackages__", + "//xla/service/gpu:__subpackages__", + "//xla/stream_executor:__subpackages__", + "//tensorflow/core/kernels:__subpackages__", + ]), + deps = [ ":gpu_asm_opts", ":gpu_driver_header", ":gpu_helpers_header", - "@com_google_absl//absl/cleanup", - "@com_google_absl//absl/base:core_headers", - "@local_tsl//tsl/platform:regexp", - "@local_tsl//tsl/platform:mutex", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:subprocess", - "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/platform:cuda_libdevice_path", + ":gpu_types_header", + "//xla:util", "//xla/stream_executor", "//xla/stream_executor/platform", - "//xla:util", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:statusor", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - ]) + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + "@local_tsl//tsl/platform:cuda_libdevice_path", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:mutex", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:regexp", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:subprocess", + ] + if_cuda_is_configured([ "//xla/stream_executor/cuda:cuda_asm_compiler", "//xla/stream_executor/cuda:cuda_driver", "//xla/stream_executor/cuda:ptxas_wrapper", @@ -454,6 +492,8 @@ gpu_kernel_library( "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", ]) + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", ]) + if_rocm_is_configured([ @@ -461,33 +501,42 @@ gpu_kernel_library( ]), ) -cc_library( +gpu_only_cc_library( name = "redzone_allocator", - srcs = if_gpu_is_configured(["redzone_allocator.cc"]), - hdrs = if_gpu_is_configured(["redzone_allocator.h"]), + srcs = ["redzone_allocator.cc"], + hdrs = ["redzone_allocator.h"], copts = tsl_copts(), local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), - visibility = ["//visibility:public"], - deps = if_gpu_is_configured([ + visibility = internal_visibility([ + "//xla/service/gpu:__subpackages__", + "//xla/stream_executor:__subpackages__", + "//tensorflow/core/kernels:__subpackages__", + ]), + deps = [ ":asm_compiler", ":gpu_asm_opts", + "//xla/stream_executor", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_allocator", "@com_google_absl//absl/base", "@com_google_absl//absl/container:fixed_array", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/framework:allocator", "@local_tsl//tsl/lib/math:math_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/framework:allocator", - "//xla/stream_executor:device_memory", - "//xla/stream_executor:device_memory_allocator", - "//xla/stream_executor", "@local_tsl//tsl/platform:status", - ]) + if_rocm_is_configured([ + "@local_tsl//tsl/platform:statusor", + ] + if_rocm_is_configured([ ":redzone_allocator_kernel", ]), ) @@ -505,7 +554,10 @@ xla_cc_test( "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", ], @@ -515,14 +567,12 @@ xla_cc_test( tsl_gpu_library( name = "gpu_cudamallocasync_allocator_header", hdrs = ["gpu_cudamallocasync_allocator.h"], - visibility = ["//visibility:public"], deps = [ - "//xla/stream_executor", + "//xla/stream_executor:stream_executor_headers", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@local_tsl//tsl/framework:allocator", "@local_tsl//tsl/framework:device_id", - "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:mutex", ], ) @@ -537,53 +587,20 @@ tsl_gpu_library( "//xla/stream_executor/cuda:cuda_activation", "//xla/stream_executor/cuda:cuda_executor", ], - visibility = ["//visibility:public"], deps = [ ":gpu_init_impl", - "//xla/stream_executor", + "//xla/stream_executor:stream_executor_headers", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@local_tsl//tsl/framework:allocator", "@local_tsl//tsl/framework:device_id", "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/util:env_var", ], ) -cc_library( - name = "gpu_graph", - srcs = if_gpu_is_configured(["gpu_graph.cc"]), - hdrs = if_gpu_is_configured(["gpu_graph.h"]), - visibility = ["//visibility:public"], - deps = if_gpu_is_configured([ - ":gpu_driver_header", - ":gpu_kernel_header", - ":gpu_types_header", - ":gpu_executor_header", - "@com_google_absl//absl/status", - "@com_google_absl//absl/types:span", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/functional:any_invocable", - "//xla/stream_executor/gpu:gpu_stream", - "//xla/stream_executor", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - "@com_google_absl//absl/strings", - ]) + if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", - "//xla/stream_executor/cuda:cuda_driver", - ]) + if_rocm_is_configured([ - "//xla/stream_executor/rocm:rocm_driver", - ]), -) - cc_library( name = "gpu_blas_lt", srcs = ["gpu_blas_lt.cc"], @@ -591,7 +608,6 @@ cc_library( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla:status", @@ -605,7 +621,9 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/protobuf:dnn_proto_cc", ] + if_cuda_is_configured([ "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", ]) + if_static([ @@ -635,9 +653,12 @@ xla_test( ":gpu_test_kernels", "//xla/service:platform_util", "//xla/stream_executor", - "//xla/stream_executor:multi_platform_manager", "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", + "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", ] + if_cuda([ @@ -653,17 +674,22 @@ xla_test( backends = ["gpu"], local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), deps = [ + ":gpu_command_buffer", ":gpu_test_kernels", ":gpu_types_header", "//xla/service:platform_util", "//xla/stream_executor", - "//xla/stream_executor:multi_platform_manager", "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor/gpu:gpu_driver_header", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@local_config_cuda//cuda:cuda_headers", "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", "@local_tsl//tsl/platform:test_main", diff --git a/third_party/xla/xla/stream_executor/gpu/asm_compiler.cc b/third_party/xla/xla/stream_executor/gpu/asm_compiler.cc index 16f7c69c0adc87..5a44b29ec95b7a 100644 --- a/third_party/xla/xla/stream_executor/gpu/asm_compiler.cc +++ b/third_party/xla/xla/stream_executor/gpu/asm_compiler.cc @@ -16,7 +16,10 @@ limitations under the License. #include "xla/stream_executor/gpu/asm_compiler.h" #include +#include #include +#include +#include #include #include #include @@ -28,21 +31,30 @@ limitations under the License. #include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" +#include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "xla/stream_executor/gpu/gpu_asm_opts.h" #include "xla/stream_executor/gpu/gpu_driver.h" +#include "xla/stream_executor/gpu/gpu_types.h" #include "xla/util.h" #include "tsl/platform/cuda_libdevice_path.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/path.h" #include "tsl/platform/regexp.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/subprocess.h" namespace stream_executor { @@ -170,8 +182,8 @@ absl::StatusOr> CompileGpuAsm(int device_ordinal, return CompileGpuAsm(cc_major, cc_minor, ptx_contents, options); } -std::string FindCudaExecutable(const std::string& binary_name, - const std::string& preferred_cuda_dir) { +absl::StatusOr FindCudaExecutable( + const std::string& binary_name, const std::string& preferred_cuda_dir) { static absl::Mutex mu(absl::kConstInit); static auto* seen_binary_paths ABSL_GUARDED_BY(mu) = new absl::flat_hash_map, @@ -192,45 +204,44 @@ std::string FindCudaExecutable(const std::string& binary_name, } auto env = tsl::Env::Default(); - std::string binary_path = - tsl::io::JoinPath(preferred_cuda_dir, "bin", binary_filename); - - // Search in the preferred cuda directory - VLOG(2) << "Looking for " << binary_filename << " at " << binary_path; - if (env->FileExists(binary_path).ok() && - GetToolVersionString(binary_path).ok()) { - VLOG(2) << "Using " << binary_filename << " at " << binary_path; - seen_binary_paths->emplace(std::move(cache_key), binary_path); - return binary_path; - } - - // Try searching in PATH if the preferred cuda directory didn't work. - if (GetToolVersionString(binary_filename).ok()) { - VLOG(2) << "Using " << binary_filename; - seen_binary_paths->emplace(std::move(cache_key), binary_filename); - return binary_filename; - } - - // Search in cuda root candidates. - for (const std::string& cuda_root : tsl::CandidateCudaRoots()) { - binary_path = tsl::io::JoinPath(cuda_root, "bin", binary_filename); - VLOG(2) << "Looking for " << binary_filename << " at " << binary_path; - if (env->FileExists(binary_path).ok() && - GetToolVersionString(binary_path).ok()) { - VLOG(2) << "Using " << binary_filename << " at " << binary_path; - seen_binary_paths->emplace(std::move(cache_key), binary_path); - return binary_path; + std::vector candidates{}; + + // #1 - Check the preferred CUDA directory + candidates.emplace_back( + tsl::io::JoinPath(preferred_cuda_dir, "bin", binary_filename)); + + std::string_view path_env = std::getenv("PATH"); + +#if defined(PLATFORM_WINDOWS) + constexpr char kSearchPathSeparator = ';'; +#else + constexpr char kSearchPathSeparator = ':'; +#endif + + // #2 - Check the PATH environment variable + for (std::string_view path : absl::StrSplit(path_env, kSearchPathSeparator)) { + candidates.emplace_back(tsl::io::JoinPath(path, binary_filename)); + } + + // #2 - Check generic CUDA locations + for (std::string_view path : tsl::CandidateCudaRoots()) { + candidates.emplace_back(tsl::io::JoinPath(path, "bin", binary_filename)); + } + + for (const auto& candidate : candidates) { + VLOG(2) << "Looking for " << candidate; + if (env->FileExists(candidate).ok() && + GetToolVersionString(candidate).ok()) { + VLOG(2) << "Using " << candidate; + seen_binary_paths->emplace(std::move(cache_key), candidate); + return candidate; } } - // Give up and just rely on subprocess invocation to find the correct - // binary. This won't work, in all probability, given we already tried that - // above, but it's the best we can do. - VLOG(2) << "Unable to find " << binary_name; - binary_path = binary_filename; - VLOG(2) << "Using " << binary_filename << " at " << binary_path; - seen_binary_paths->emplace(std::move(cache_key), binary_path); - return binary_path; + return absl::NotFoundError( + absl::StrCat("Couldn't find ", binary_name, + ". The following locations were considered: ", + absl::StrJoin(candidates, ", "))); } static void LogPtxasTooOld(const std::string& ptxas_path, int cc_major, @@ -263,7 +274,8 @@ static void AppendArgsFromOptions(GpuAsmOpts options, absl::StatusOr> GetAsmCompilerVersion( const std::string& preferred_cuda_dir) { - std::string ptxas_path = FindCudaExecutable("ptxas", preferred_cuda_dir); + TF_ASSIGN_OR_RETURN(std::string ptxas_path, + FindCudaExecutable("ptxas", preferred_cuda_dir)); return GetToolVersion(ptxas_path); } @@ -280,8 +292,8 @@ absl::StatusOr> CompileGpuAsm(int cc_major, int cc_minor, std::get<0>(ptxas_version_tuple), std::get<1>(ptxas_version_tuple), std::get<2>(ptxas_version_tuple))); } - std::string ptxas_path = - FindCudaExecutable("ptxas", options.preferred_cuda_dir); + TF_ASSIGN_OR_RETURN(std::string ptxas_path, + FindCudaExecutable("ptxas", options.preferred_cuda_dir)); WarnIfBadPtxasVersion(ptxas_path); @@ -382,8 +394,9 @@ absl::StatusOr> CompileGpuAsm(int cc_major, int cc_minor, absl::StatusOr> BundleGpuAsm( std::vector images, GpuAsmOpts options) { - std::string fatbinary_path = - FindCudaExecutable("fatbinary", options.preferred_cuda_dir); + TF_ASSIGN_OR_RETURN( + std::string fatbinary_path, + FindCudaExecutable("fatbinary", options.preferred_cuda_dir)); // Write images to temporary files. std::vector image_paths; diff --git a/third_party/xla/xla/stream_executor/gpu/asm_compiler.h b/third_party/xla/xla/stream_executor/gpu/asm_compiler.h index eda86fb9dc97ac..79a2e6a1669d3a 100644 --- a/third_party/xla/xla/stream_executor/gpu/asm_compiler.h +++ b/third_party/xla/xla/stream_executor/gpu/asm_compiler.h @@ -18,21 +18,24 @@ limitations under the License. #include #include +#include #include #include #include "absl/base/const_init.h" #include "absl/base/thread_annotations.h" -#include "absl/container/flat_hash_map.h" +#include "absl/container/node_hash_map.h" +#include "absl/log/check.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/stream_executor/gpu/gpu_asm_opts.h" #include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/platform/port.h" #include "xla/stream_executor/stream_executor.h" #include "tsl/platform/statusor.h" #if GOOGLE_CUDA +#include "third_party/gpus/cuda/include/cuda.h" #include "xla/stream_executor/cuda/cuda_driver.h" #endif // GOOGLE_CUDA @@ -101,8 +104,8 @@ absl::StatusOr> LinkUsingNvlink( absl::string_view preferred_cuda_dir, gpu::GpuContext* context, std::vector images); -std::string FindCudaExecutable(const std::string& binary_name, - const std::string& preferred_cuda_dir); +absl::StatusOr FindCudaExecutable( + const std::string& binary_name, const std::string& preferred_cuda_dir); // Runs tool --version and parses its version string. absl::StatusOr> GetToolVersion( @@ -115,7 +118,7 @@ absl::StatusOr> GetAsmCompilerVersion( #if GOOGLE_CUDA // Maintains a cache of pointers to loaded kernels template -absl::StatusOr>> LoadKernelOrGetPtr( +absl::StatusOr*> LoadKernelOrGetPtr( StreamExecutor* executor, absl::string_view kernel_name, absl::string_view ptx, absl::Span cubin_data) { using KernelPtrCacheKey = @@ -123,8 +126,7 @@ absl::StatusOr>> LoadKernelOrGetPtr( static absl::Mutex kernel_ptr_cache_mutex(absl::kConstInit); static auto& kernel_ptr_cache ABSL_GUARDED_BY(kernel_ptr_cache_mutex) = - *new absl::flat_hash_map>>(); + *new absl::node_hash_map>(); CUcontext current_context = cuda::CurrentContextOrDie(); KernelPtrCacheKey kernel_ptr_cache_key{current_context, kernel_name, ptx}; absl::MutexLock lock(&kernel_ptr_cache_mutex); @@ -132,14 +134,14 @@ absl::StatusOr>> LoadKernelOrGetPtr( auto it = kernel_ptr_cache.find(kernel_ptr_cache_key); if (it == kernel_ptr_cache.end()) { TF_ASSIGN_OR_RETURN( - std::shared_ptr> loaded, - executor->CreateTypedKernel(kernel_name, ptx, cubin_data)); + TypedKernel loaded, + (TypedKernel::Create(executor, kernel_name, ptx, cubin_data))); it = kernel_ptr_cache.emplace(kernel_ptr_cache_key, std::move(loaded)).first; } CHECK(it != kernel_ptr_cache.end()); - return it->second; + return &it->second; } #endif // GOOGLE_CUDA diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_activation.cc b/third_party/xla/xla/stream_executor/gpu/gpu_activation.cc index 8e397e352bf154..c40182cccf1692 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_activation.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_activation.cc @@ -18,7 +18,6 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/stream_executor/stream_executor_internal.h" namespace stream_executor { namespace gpu { diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_activation.h b/third_party/xla/xla/stream_executor/gpu/gpu_activation.h index 315ecc285a8aa3..a28bef2e5da836 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_activation.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_activation.h @@ -23,7 +23,6 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_GPU_GPU_ACTIVATION_H_ #define XLA_STREAM_EXECUTOR_GPU_GPU_ACTIVATION_H_ -#include "xla/stream_executor/platform/port.h" namespace stream_executor { diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.cc b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.cc index 9e30d1272bc67e..2a052e96975223 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_blas_lt.h" #include +#include #include #include "absl/status/statusor.h" @@ -23,6 +24,7 @@ limitations under the License. #include "xla/stream_executor/blas.h" #include "xla/stream_executor/stream.h" #include "xla/util.h" +#include "tsl/protobuf/dnn.pb.h" #if GOOGLE_CUDA #include "tsl/platform/tensor_float_32_utils.h" #endif diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h index 6bac2ea13676e5..6027794255dd3f 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h @@ -17,18 +17,20 @@ limitations under the License. #define XLA_STREAM_EXECUTOR_GPU_GPU_BLAS_LT_H_ #include +#include #include +#include #include #include #include -#include "xla/shape.h" #include "xla/status.h" -#include "xla/statusor.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/host_or_device_scalar.h" #include "xla/types.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" namespace stream_executor::gpu { diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc index 23984bd04619e5..e9910cec9e85ff 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc @@ -32,6 +32,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/types/span.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/device_memory.h" @@ -119,6 +120,7 @@ GpuCommandBuffer::GpuCommandBuffer(Mode mode, GpuExecutor* parent, VLOG(5) << "Created command buffer for graph " << graph_ << "; mode=" << ModeToString(mode) << "; is_owned_graph=" << is_owned_graph_; + execution_scopes_.try_emplace(kDefaulExecutionScope); } GpuCommandBuffer::~GpuCommandBuffer() { @@ -126,12 +128,14 @@ GpuCommandBuffer::~GpuCommandBuffer() { VLOG(5) << "Destroy GPU command buffer executable graph " << exec_ << " " << "(remaining alive executable graphs: " << NotifyExecDestroyed() << ")"; - auto st = GpuDriver::DestroyGraphExec(exec_); - CHECK(st.ok()) << "Failed to destroy GPU graph exec: " << st.message(); + if (auto status = GpuDriver::DestroyGraphExec(exec_); !status.ok()) { + LOG(ERROR) << "Failed to destroy GPU graph exec: " << status.message(); + } } if (graph_ != nullptr && is_owned_graph_) { - auto st = GpuDriver::DestroyGraph(graph_); - CHECK(st.ok()) << "Failed to destroy GPU graph: " << st.message(); + if (auto status = GpuDriver::DestroyGraph(graph_); !status.ok()) { + LOG(ERROR) << "Failed to destroy GPU graph: " << status.message(); + } } } @@ -155,8 +159,6 @@ static GpuDevicePtr AsDevicePtr(const DeviceMemoryBase& mem) { absl::Status GpuCommandBuffer::Trace( Stream* stream, absl::AnyInvocable function) { - // TODO(ezhulenev): Check that graph is empty, because we should not be mixing - // graph tracing with explicit graph construction. TF_RETURN_IF_ERROR(CheckNotFinalized()); VLOG(5) << "Trace into GPU command buffer graph " << graph_ @@ -187,38 +189,101 @@ absl::Status GpuCommandBuffer::Trace( return absl::OkStatus(); } -GpuCommandBuffer::Dependencies GpuCommandBuffer::GetBarrier() { - return barrier_ ? Dependencies{barrier_} : Dependencies{}; +GpuCommandBuffer::Dependencies GpuCommandBuffer::GetBarrier( + ExecutionScopeId execution_scope_id) { + ExecutionScope& execution_scope = execution_scopes_[execution_scope_id]; + return execution_scope.barriers.empty() + ? Dependencies{} + : Dependencies{execution_scope.barriers.back().handle}; +} + +absl::StatusOr +GpuCommandBuffer::GetSetIfConditionKernel(StreamExecutor* executor) { + if (!set_if_condition_kernel_) { + MultiKernelLoaderSpec spec(/*arity=*/2); + spec.AddCudaPtxInMemory(gpu::GetSetIfConditionKernel(), "set_if_condition"); + TF_ASSIGN_OR_RETURN(set_if_condition_kernel_, + SetIfConditionKernel::Create(executor, spec)); + } + return &set_if_condition_kernel_; +} + +absl::StatusOr +GpuCommandBuffer::GetSetIfElseConditionKernel(StreamExecutor* executor) { + if (!set_if_else_condition_kernel_) { + MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddCudaPtxInMemory(gpu::GetSetIfElseConditionKernel(), + "set_if_else_condition"); + TF_ASSIGN_OR_RETURN(set_if_else_condition_kernel_, + SetIfElseConditionKernel::Create(executor, spec)); + } + return &set_if_else_condition_kernel_; +} + +absl::StatusOr +GpuCommandBuffer::GetSetCaseConditionKernel(StreamExecutor* executor) { + if (!set_case_condition_kernel_) { + MultiKernelLoaderSpec spec(/*arity=*/10); + spec.AddCudaPtxInMemory(gpu::GetSetCaseConditionKernel(), + "set_case_condition"); + TF_ASSIGN_OR_RETURN(set_case_condition_kernel_, + SetCaseConditionKernel::Create(executor, spec)); + } + return &set_case_condition_kernel_; +} + +absl::StatusOr +GpuCommandBuffer::GetSetForConditionKernel(StreamExecutor* executor) { + if (!set_for_condition_kernel_) { + MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddCudaPtxInMemory(gpu::GetSetForConditionKernel(), + "set_for_condition"); + TF_ASSIGN_OR_RETURN(set_for_condition_kernel_, + SetForConditionKernel::Create(executor, spec)); + } + return &set_for_condition_kernel_; +} + +absl::StatusOr +GpuCommandBuffer::GetSetWhileConditionKernel(StreamExecutor* executor) { + if (!set_while_condition_kernel_) { + MultiKernelLoaderSpec spec(/*arity=*/2); + spec.AddCudaPtxInMemory(gpu::GetSetWhileConditionKernel(), + "set_while_condition"); + TF_ASSIGN_OR_RETURN(set_while_condition_kernel_, + SetWhileConditionKernel::Create(executor, spec)); + } + return &set_while_condition_kernel_; } absl::StatusOr GpuCommandBuffer::GetNoOpKernel( StreamExecutor* executor) { if (!noop_kernel_) { - auto noop_kernel = std::make_unique(executor); - MultiKernelLoaderSpec spec(/*arity=*/0); #if !defined(TENSORFLOW_USE_ROCM) spec.AddCudaPtxInMemory(gpu::kNoOpKernel, "noop"); #else spec.AddInProcessSymbol(gpu::GetNoOpKernel(), "noop"); #endif // TENSORFLOW_USE_ROCM - TF_RETURN_IF_ERROR(executor->GetKernel(spec, noop_kernel.get())); - - noop_kernel_ = std::move(noop_kernel); + TF_ASSIGN_OR_RETURN(noop_kernel_, NoOpKernel::Create(executor, spec)); } - - return noop_kernel_.get(); + return &noop_kernel_; } absl::Status GpuCommandBuffer::DisableBarriersExecution( GpuGraphExecHandle exec) { - for (GpuGraphNodeHandle barrier : barriers_) { - if (barrier == nullptr) continue; - TF_RETURN_IF_ERROR(GpuDriver::GraphNodeSetEnabled(exec, barrier, false)); + ExecutionScope& execution_scope = execution_scopes_[kDefaulExecutionScope]; + + for (GpuGraphBarrierInfo& barrier : execution_scope.barriers) { + if (barrier.is_barrier_node) { + TF_RETURN_IF_ERROR( + GpuDriver::GraphNodeSetEnabled(exec, barrier.handle, false)); + } } - for (ConditionalCommandBuffers& cmd_buffers : conditional_command_buffers_) { - for (CommandBuffer& cmd_buffer : cmd_buffers.command_buffers) { - TF_RETURN_IF_ERROR(Cast(&cmd_buffer)->DisableBarriersExecution(exec)); + for (ConditionalCommandBuffers& cmd_buffers : + execution_scope.conditional_command_buffers) { + for (auto& cmd_buffer : cmd_buffers.command_buffers) { + TF_RETURN_IF_ERROR(cmd_buffer->DisableBarriersExecution(exec)); } } return absl::OkStatus(); @@ -241,49 +306,185 @@ absl::Status GpuCommandBuffer::CheckNumCommandBuffers( return absl::OkStatus(); } -absl::Status GpuCommandBuffer::Barrier(StreamExecutor* executor) { - // We don't support adding barriers as root nodes and simply skip them. - if ((state_ == State::kCreate && nodes_.empty()) || - (state_ == State::kUpdate && update_state_.node_idx == 0)) +absl::StatusOr GpuCommandBuffer::CreateBarrierNode( + StreamExecutor* executor, const Dependencies& dependencies) { + // TODO(b/316343054): Instead of empty nodes we create no-op kernel nodes as + // barriers because CUDA 12.3 does not support empty nodes inside + // conditional command buffers. This should be fixed in CUDA 12.4. + TF_ASSIGN_OR_RETURN(NoOpKernel * noop, GetNoOpKernel(executor)); + + GpuGraphNodeHandle barrier_handle = nullptr; + TF_RETURN_IF_ERROR(GpuDriver::GraphAddKernelNode( + &barrier_handle, graph_, dependencies, "noop", + AsGpuKernel(&**noop)->AsGpuFunctionHandle(), 1, 1, 1, 1, 1, 1, 0, + /*kernel_params=*/nullptr, /*extra=*/nullptr)); + + return barrier_handle; +} + +GpuCommandBuffer::Dependencies GpuCommandBuffer::GetBarrierDependencies( + ExecutionScopeId execution_scope_id) { + ExecutionScope& execution_scope = execution_scopes_[execution_scope_id]; + auto& barriers = execution_scope.barriers; + + // Collect nodes that will become a new barrier dependencies starting from + // the first command node added after the last barrier in the scope. + Dependencies dependencies; + for (size_t i = barriers.empty() ? 0 : barriers.back().nodes_offset; + i < execution_scope.nodes.size(); ++i) { + dependencies.push_back(execution_scope.nodes[i].handle); + } + return dependencies; +} + +absl::Status GpuCommandBuffer::Barrier(StreamExecutor* executor, + ExecutionScopeId execution_scope_id) { + ExecutionScope& execution_scope = execution_scopes_[execution_scope_id]; + + if (state_ == State::kCreate) { + // Nodes offset for a newly created barrier. + size_t nodes_offset = execution_scope.nodes.size(); + + // Collect nodes that will become a new barrier dependencies starting from + // the first command node added after the last barrier. + Dependencies dependencies = GetBarrierDependencies(execution_scope_id); + + // If there are no new dependencies and we have an existing barrier simply + // copy information from the last barrier to a new one. + if (dependencies.empty() && !execution_scope.barriers.empty()) { + execution_scope.barriers.push_back({execution_scope.barriers.back()}); + return absl::OkStatus(); + } + + // If we have only one node added after the last barrier simply reuse the + // last node corresponding to a command as a barrier. + if (dependencies.size() == 1) { + execution_scope.barriers.push_back( + {execution_scope.nodes.back().handle, false, nodes_offset}); + return absl::OkStatus(); + } + + // If we have multiple dependencies or no existing barriers we have to + // create a new empty node acting as an execution barrier. + TF_ASSIGN_OR_RETURN(auto barrier_handle, + CreateBarrierNode(executor, dependencies)); + execution_scope.barriers.push_back({barrier_handle, true, nodes_offset}); return absl::OkStatus(); + } + + if (state_ == State::kUpdate) { + // Command buffer updates can't change the structure of the underlying gpu + // graph (add or delete barriers). We simply do a sanity check that at + // update time we didn't try to add more barriers than we had originally. + if (execution_scope.update_state.barrier_idx++ >= + execution_scope.barriers.size()) { + return absl::InternalError( + absl::StrFormat("Execution scope %d barrier index out of range", + execution_scope_id.value())); + } + return absl::OkStatus(); + } + + return UnsupportedStateError(state_); +} + +absl::Status GpuCommandBuffer::Barrier( + StreamExecutor* executor, + absl::Span execution_scope_ids) { + // Nothing to synchronize here. + if (execution_scope_ids.empty()) return absl::OkStatus(); + + // Do not create two-level barriers for single execution scope. + if (execution_scope_ids.size() == 1) { + return Barrier(executor, execution_scope_ids[0]); + } + + // Add a new barrier to every synchronized execution scope. + for (ExecutionScopeId execution_scope_id : execution_scope_ids) { + TF_RETURN_IF_ERROR(Barrier(executor, execution_scope_id)); + } if (state_ == State::kCreate) { - // Collect nodes that will become a new barrier dependencies. + // Collect barriers from each scope as a dependencies. Dependencies dependencies; - for (int32_t i = nodes_.size() - 1; i >= 0; --i) { - if (nodes_[i].handle == barrier_) break; - dependencies.push_back(nodes_[i].handle); + for (ExecutionScopeId execution_scope_id : execution_scope_ids) { + ExecutionScope& execution_scope = execution_scopes_[execution_scope_id]; + dependencies.push_back(execution_scope.barriers.back().handle); } - // Add a noop kernel node acting as a barrier. - if (dependencies.size() > 1) { - // TODO(b/316343054): This should be an empty node, however CUDA 12.3 does - // not support empty nodes inside conditional command buffers. - GpuGraphNodeInfo& node_info = nodes_.emplace_back(); - TF_ASSIGN_OR_RETURN(auto noop, GetNoOpKernel(executor)); - TF_RETURN_IF_ERROR(GpuDriver::GraphAddKernelNode( - &node_info.handle, graph_, absl::MakeSpan(dependencies), noop->name(), - AsGpuKernel(noop)->AsGpuFunctionHandle(), 1, 1, 1, 1, 1, 1, 0, - /*kernel_params=*/nullptr, /*extra=*/nullptr)); - barriers_.push_back(node_info.handle); - } else { - barriers_.push_back(nullptr); + // Create a new barrier that joins all per-scope barriers together. + TF_ASSIGN_OR_RETURN(auto barrier_handle, + CreateBarrierNode(executor, dependencies)); + + // Broadcast new barrier to all participating execution scopes. + for (ExecutionScopeId execution_scope_id : execution_scope_ids) { + ExecutionScope& execution_scope = execution_scopes_[execution_scope_id]; + size_t nodes_offset = execution_scope.nodes.size(); + execution_scope.barriers.push_back({barrier_handle, true, nodes_offset}); } - // Make the last node a barrier, if we didn't add a new no-op node acting - // as a barrier we simply reuse the last node. - barrier_ = nodes_.back().handle; return absl::OkStatus(); } if (state_ == State::kUpdate) { - // Increment update node index only if we added a no-op node earlier and it - // means that we just updated a "real" barrier node, otherwise barrier is - // the last updated node. - if (barriers_[update_state_.barrier_idx++]) { - barrier_ = nodes_[update_state_.node_idx++].handle; - } else if (update_state_.node_idx) { - barrier_ = nodes_[update_state_.node_idx - 1].handle; + // Command buffer updates can't change the structure of the underlying gpu + // graph (add or delete barriers). We simply do a sanity check that at + // update time we didn't try to add more barriers than we had originally. + for (ExecutionScopeId execution_scope_id : execution_scope_ids) { + ExecutionScope& execution_scope = execution_scopes_[execution_scope_id]; + if (execution_scope.update_state.barrier_idx++ >= + execution_scope.barriers.size()) { + return absl::InternalError( + absl::StrFormat("Execution scope %d barrier index out of range", + execution_scope_id.value())); + } + } + return absl::OkStatus(); + } + + return UnsupportedStateError(state_); +} + +absl::Status GpuCommandBuffer::Barrier(StreamExecutor* executor, + ExecutionScopeId from_execution_scope_id, + ExecutionScopeId to_execution_scope_id) { + // If scopes are the same simply add a barrier to it. + if (from_execution_scope_id == to_execution_scope_id) { + return Barrier(executor, from_execution_scope_id); + } + + // Create new barriers in both execution scopes. + TF_RETURN_IF_ERROR(Barrier(executor, from_execution_scope_id)); + TF_RETURN_IF_ERROR(Barrier(executor, to_execution_scope_id)); + + if (state_ == State::kCreate) { + // Collect barriers from each scope as dependencies. + Dependencies dependencies = { + execution_scopes_[from_execution_scope_id].barriers.back().handle, + execution_scopes_[to_execution_scope_id].barriers.back().handle}; + + // Create a new barrier that joins `from` and `to` scopes. + TF_ASSIGN_OR_RETURN(auto barrier_handle, + CreateBarrierNode(executor, dependencies)); + + // Add a new barrier only to the `to_execution_scope_id`. + ExecutionScope& execution_scope = execution_scopes_[to_execution_scope_id]; + size_t nodes_offset = execution_scope.nodes.size(); + execution_scope.barriers.push_back({barrier_handle, true, nodes_offset}); + + return absl::OkStatus(); + } + + if (state_ == State::kUpdate) { + // Command buffer updates can't change the structure of the underlying gpu + // graph (add or delete barriers). We simply do a sanity check that at + // update time we didn't try to add more barriers than we had originally. + ExecutionScope& execution_scope = execution_scopes_[to_execution_scope_id]; + if (execution_scope.update_state.barrier_idx++ >= + execution_scope.barriers.size()) { + return absl::InternalError( + absl::StrFormat("Execution scope %d barrier index out of range", + to_execution_scope_id.value())); } return absl::OkStatus(); } @@ -292,8 +493,11 @@ absl::Status GpuCommandBuffer::Barrier(StreamExecutor* executor) { } absl::Status GpuCommandBuffer::LaunchWithPackedArgs( - const ThreadDim& threads, const BlockDim& blocks, const Kernel& kernel, + ExecutionScopeId execution_scope_id, const ThreadDim& threads, + const BlockDim& blocks, const Kernel& kernel, const KernelArgsPackedArrayBase& packed_args) { + ExecutionScope& execution_scope = execution_scopes_[execution_scope_id]; + CHECK_EQ(kernel.Arity() + (packed_args.number_of_shared_bytes() > 0), packed_args.number_of_arguments()); @@ -305,17 +509,18 @@ absl::Status GpuCommandBuffer::LaunchWithPackedArgs( // Adds a new kernel node to the graph under construction. if (state_ == State::kCreate) { - Dependencies barrier = GetBarrier(); - GpuGraphNodeInfo& node_info = nodes_.emplace_back(); + Dependencies barrier = GetBarrier(execution_scope_id); + GpuGraphNodeInfo& node_info = execution_scope.nodes.emplace_back(); return GpuDriver::GraphAddKernelNode( - &node_info.handle, graph_, absl::MakeSpan(barrier), kernel.name(), - gpu_func, blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z, + &node_info.handle, graph_, barrier, kernel.name(), gpu_func, blocks.x, + blocks.y, blocks.z, threads.x, threads.y, threads.z, packed_args.number_of_shared_bytes(), kernel_params, /*extra=*/nullptr); } // Updates kernel node in the executable graph. if (state_ == State::kUpdate) { - GpuGraphNodeHandle node = nodes_[update_state_.node_idx++].handle; + GpuGraphNodeHandle node = + execution_scope.nodes[execution_scope.update_state.node_idx++].handle; return GpuDriver::GraphExecKernelNodeSetParams( exec_, node, kernel.name(), gpu_func, blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z, packed_args.number_of_shared_bytes(), @@ -325,7 +530,8 @@ absl::Status GpuCommandBuffer::LaunchWithPackedArgs( return UnsupportedStateError(state_); } -absl::Status GpuCommandBuffer::Launch(const ThreadDim& threads, +absl::Status GpuCommandBuffer::Launch(ExecutionScopeId execution_scope_id, + const ThreadDim& threads, const BlockDim& blocks, const Kernel& kernel, const KernelArgs& args) { @@ -333,7 +539,8 @@ absl::Status GpuCommandBuffer::Launch(const ThreadDim& threads, // If arguments are already packed we can just launch the kernel. if (auto* packed = DynCast(&args)) { - return LaunchWithPackedArgs(threads, blocks, kernel, *packed); + return LaunchWithPackedArgs(execution_scope_id, threads, blocks, kernel, + *packed); } // For device memory array we rely on a custom kernel arguments packing. @@ -346,50 +553,57 @@ absl::Status GpuCommandBuffer::Launch(const ThreadDim& threads, } TF_ASSIGN_OR_RETURN(auto packed, pack(kernel, *device_mem)); - return LaunchWithPackedArgs(threads, blocks, kernel, *packed); + return LaunchWithPackedArgs(execution_scope_id, threads, blocks, kernel, + *packed); } return absl::InternalError("Unsupported kernel arguments type"); } absl::Status GpuCommandBuffer::AddNestedCommandBuffer( - const CommandBuffer& nested) { + ExecutionScopeId execution_scope_id, const CommandBuffer& nested) { + ExecutionScope& execution_scope = execution_scopes_[execution_scope_id]; + TF_RETURN_IF_ERROR(CheckNotFinalized()); GpuGraphHandle child_graph = GpuCommandBuffer::Cast(&nested)->graph(); // Adds a child graph node to the graph under construction. if (state_ == State::kCreate) { - Dependencies barrier = GetBarrier(); - GpuGraphNodeInfo& node_info = nodes_.emplace_back(); - return GpuDriver::GraphAddChildNode(&node_info.handle, graph_, - absl::MakeSpan(barrier), child_graph); + Dependencies barrier = GetBarrier(execution_scope_id); + GpuGraphNodeInfo& node_info = execution_scope.nodes.emplace_back(); + return GpuDriver::GraphAddChildNode(&node_info.handle, graph_, barrier, + child_graph); } // Updates child graph node in the executable graph. if (state_ == State::kUpdate) { - GpuGraphNodeHandle node = nodes_[update_state_.node_idx++].handle; + GpuGraphNodeHandle node = + execution_scope.nodes[execution_scope.update_state.node_idx++].handle; return GpuDriver::GraphExecChildNodeSetParams(exec_, node, child_graph); } return UnsupportedStateError(state_); } -absl::Status GpuCommandBuffer::MemcpyDeviceToDevice(DeviceMemoryBase* dst, - const DeviceMemoryBase& src, - uint64_t size) { +absl::Status GpuCommandBuffer::MemcpyDeviceToDevice( + ExecutionScopeId execution_scope_id, DeviceMemoryBase* dst, + const DeviceMemoryBase& src, uint64_t size) { + ExecutionScope& execution_scope = execution_scopes_[execution_scope_id]; + TF_RETURN_IF_ERROR(CheckNotFinalized()); if (state_ == State::kCreate) { - Dependencies barrier = GetBarrier(); - GpuGraphNodeInfo& node_info = nodes_.emplace_back(); + Dependencies barrier = GetBarrier(execution_scope_id); + GpuGraphNodeInfo& node_info = execution_scope.nodes.emplace_back(); return GpuDriver::GraphAddMemcpyD2DNode( - parent_->gpu_context(), &node_info.handle, graph_, - absl::MakeSpan(barrier), AsDevicePtr(*dst), AsDevicePtr(src), size); + parent_->gpu_context(), &node_info.handle, graph_, barrier, + AsDevicePtr(*dst), AsDevicePtr(src), size); } if (state_ == State::kUpdate) { - GpuGraphNodeHandle node = nodes_[update_state_.node_idx++].handle; + GpuGraphNodeHandle node = + execution_scope.nodes[execution_scope.update_state.node_idx++].handle; return GpuDriver::GraphExecMemcpyD2DNodeSetParams( parent_->gpu_context(), exec_, node, AsDevicePtr(*dst), AsDevicePtr(src), size); @@ -398,21 +612,25 @@ absl::Status GpuCommandBuffer::MemcpyDeviceToDevice(DeviceMemoryBase* dst, return UnsupportedStateError(state_); } -absl::Status GpuCommandBuffer::Memset(DeviceMemoryBase* dst, +absl::Status GpuCommandBuffer::Memset(ExecutionScopeId execution_scope_id, + DeviceMemoryBase* dst, CommandBuffer::BitPattern bit_pattern, size_t num_elements) { + ExecutionScope& execution_scope = execution_scopes_[execution_scope_id]; + TF_RETURN_IF_ERROR(CheckNotFinalized()); if (state_ == State::kCreate) { - Dependencies barrier = GetBarrier(); - GpuGraphNodeInfo& node_info = nodes_.emplace_back(); + Dependencies barrier = GetBarrier(execution_scope_id); + GpuGraphNodeInfo& node_info = execution_scope.nodes.emplace_back(); return GpuDriver::GraphAddMemsetNode( - parent_->gpu_context(), &node_info.handle, graph_, - absl::MakeSpan(barrier), AsDevicePtr(*dst), bit_pattern, num_elements); + parent_->gpu_context(), &node_info.handle, graph_, barrier, + AsDevicePtr(*dst), bit_pattern, num_elements); } if (state_ == State::kUpdate) { - GpuGraphNodeHandle node = nodes_[update_state_.node_idx++].handle; + GpuGraphNodeHandle node = + execution_scope.nodes[execution_scope.update_state.node_idx++].handle; return GpuDriver::GraphExecMemsetNodeSetParams( parent_->gpu_context(), exec_, node, AsDevicePtr(*dst), bit_pattern, num_elements); @@ -421,17 +639,20 @@ absl::Status GpuCommandBuffer::Memset(DeviceMemoryBase* dst, return UnsupportedStateError(state_); } -absl::StatusOr GpuCommandBuffer::Allocate(size_t bytes) { +absl::StatusOr GpuCommandBuffer::Allocate( + ExecutionScopeId execution_scope_id, size_t bytes) { + ExecutionScope& execution_scope = execution_scopes_[execution_scope_id]; + TF_RETURN_IF_ERROR(CheckNotFinalized()); // Adds a new memory allocation node to the graph under construction. if (state_ == State::kCreate) { - Dependencies barrier = GetBarrier(); - GpuGraphNodeInfo& node_info = nodes_.emplace_back(); + Dependencies barrier = GetBarrier(execution_scope_id); + GpuGraphNodeInfo& node_info = execution_scope.nodes.emplace_back(); GpuDevicePtr ptr; TF_RETURN_IF_ERROR(GpuDriver::GraphAddMemAllocNode( - &node_info.handle, graph_, absl::MakeSpan(barrier), + &node_info.handle, graph_, barrier, GpuDriver::MemAccessFlags::kReadWrite, GpuDriver::MemLocationType::kDevice, parent_->device_ordinal(), GpuDriver::MemAllocationType::kPinned, bytes, &ptr)); @@ -448,9 +669,10 @@ absl::StatusOr GpuCommandBuffer::Allocate(size_t bytes) { // Memory allocation node implemented through CUDA graph does not allocate // new memory region on update, just return the memory region allocated // during the create step. + GpuGraphNodeHandle node = + execution_scope.nodes[execution_scope.update_state.node_idx++].handle; TF_ASSIGN_OR_RETURN(AllocationResult params, - GpuDriver::GraphGetMemAllocNodeParams( - nodes_[update_state_.node_idx++].handle)); + GpuDriver::GraphGetMemAllocNodeParams(node)); return DeviceMemoryBase(reinterpret_cast(params.first), params.second); } @@ -458,16 +680,19 @@ absl::StatusOr GpuCommandBuffer::Allocate(size_t bytes) { return UnsupportedStateError(state_); } -absl::Status GpuCommandBuffer::Free(DeviceMemoryBase dst) { +absl::Status GpuCommandBuffer::Free(ExecutionScopeId execution_scope_id, + DeviceMemoryBase dst) { + ExecutionScope& execution_scope = execution_scopes_[execution_scope_id]; + TF_RETURN_IF_ERROR(CheckNotFinalized()); // Adds a new memfree node to the graph under construction. if (state_ == State::kCreate) { - Dependencies barrier = GetBarrier(); - GpuGraphNodeInfo& node_info = nodes_.emplace_back(); + Dependencies barrier = GetBarrier(execution_scope_id); + GpuGraphNodeInfo& node_info = execution_scope.nodes.emplace_back(); GpuDevicePtr gpu_dptr = AsDevicePtr(dst); - TF_RETURN_IF_ERROR(GpuDriver::GraphAddMemFreeNode( - &node_info.handle, graph_, absl::MakeSpan(barrier), gpu_dptr)); + TF_RETURN_IF_ERROR(GpuDriver::GraphAddMemFreeNode(&node_info.handle, graph_, + barrier, gpu_dptr)); return absl::OkStatus(); } @@ -475,7 +700,7 @@ absl::Status GpuCommandBuffer::Free(DeviceMemoryBase dst) { // memfree node implemented through CUDA graph only free buffers that is // allocated through memory alloc node, so buffer address will not change, // no update is required. - update_state_.node_idx++; + execution_scope.update_state.node_idx++; return absl::OkStatus(); } @@ -486,6 +711,8 @@ absl::Status GpuCommandBuffer::Free(DeviceMemoryBase dst) { // Command buffer condtitional commands API //--------------------------------------------------------------------------// +using ConditionalHandles = absl::Span; + /*static*/ GpuCommandBuffer::ConditionBuilder GpuCommandBuffer::ToConditionBuilder(CommandBuffer::Builder builder) { return [builder = std::move(builder)](CommandBuffer* cmd_buffer, @@ -504,40 +731,12 @@ GpuCommandBuffer::CreateConditionalHandles(size_t num_handles) { return handles; } -absl::StatusOr> -GpuCommandBuffer::CreateConditionalNodes( - ConditionType type, absl::Span handles) { - std::vector conditional_graphs; - - using ConditionalParams = GpuDriver::GpuGraphConditionalNodeParams; - using ConditionalResult = GpuDriver::GpuGraphConditionalNodeParams::Result; - - for (GpuGraphConditionalHandle handle : handles) { - Dependencies barrier = GetBarrier(); - GpuGraphNodeInfo& node_info = nodes_.emplace_back(); - - ConditionalParams params; - params.type = type; - params.handle = handle; - params.context = parent_->gpu_context(); - - TF_ASSIGN_OR_RETURN( - GpuDriver::GpuGraphNodeResult result, - GpuDriver::GraphAddNode(&node_info.handle, graph_, - absl::MakeSpan(barrier), params)); - - conditional_graphs.push_back(std::get(result).graph); - } - - return conditional_graphs; -} - -absl::StatusOr> +absl::StatusOr>> GpuCommandBuffer::CreateConditionalCommandBuffers( absl::Span handles, absl::Span graphs, absl::Span builders) { - std::vector cmd_buffers; + std::vector> cmd_buffers; // Conditional command buffers always created in nested mode and with // underlying graphs owned by a conditional node. @@ -545,13 +744,10 @@ GpuCommandBuffer::CreateConditionalCommandBuffers( bool is_owned_graph = false; for (size_t i = 0; i < handles.size(); ++i) { - auto command_buffer_impl = parent_->GetCommandBufferImplementation( - nested, graphs[i], is_owned_graph); - - auto command_buffer = CommandBuffer::Create(std::move(command_buffer_impl)); - - TF_RETURN_IF_ERROR(builders[i](&command_buffer, handles[i])); - TF_RETURN_IF_ERROR(command_buffer.Finalize()); + auto command_buffer = + parent_->CreateCommandBuffer(nested, graphs[i], is_owned_graph); + TF_RETURN_IF_ERROR(builders[i](command_buffer.get(), handles[i])); + TF_RETURN_IF_ERROR(command_buffer->Finalize()); cmd_buffers.push_back(std::move(command_buffer)); } @@ -561,23 +757,56 @@ GpuCommandBuffer::CreateConditionalCommandBuffers( absl::Status GpuCommandBuffer::UpdateConditionalCommandBuffers( absl::Span handles, - absl::Span command_buffers, + absl::Span> command_buffers, absl::Span builders) { for (size_t i = 0; i < command_buffers.size(); ++i) { // Use parent graph executable for conditional command buffer update. - ScopedGpuGraphExec scoped_exec(Cast(&command_buffers[i]), exec_); + ScopedGpuGraphExec scoped_exec(command_buffers[i].get(), exec_); // Update command buffer using user-provided builder callback. - TF_RETURN_IF_ERROR(command_buffers[i].Update()); - TF_RETURN_IF_ERROR(builders[i](&command_buffers[i], handles[i])); - TF_RETURN_IF_ERROR(command_buffers[i].Finalize()); + TF_RETURN_IF_ERROR(command_buffers[i]->Update()); + TF_RETURN_IF_ERROR(builders[i](command_buffers[i].get(), handles[i])); + TF_RETURN_IF_ERROR(command_buffers[i]->Finalize()); } return absl::OkStatus(); } +absl::StatusOr> +GpuCommandBuffer::CreateConditionalNodes( + ExecutionScopeId execution_scope_id, ConditionType type, + absl::Span handles) { + ExecutionScope& execution_scope = execution_scopes_[execution_scope_id]; + + std::vector conditional_graphs; + + using ConditionalParams = GpuDriver::GpuGraphConditionalNodeParams; + using ConditionalResult = GpuDriver::GpuGraphConditionalNodeParams::Result; + + for (GpuGraphConditionalHandle handle : handles) { + Dependencies barrier = GetBarrier(execution_scope_id); + GpuGraphNodeInfo& node_info = execution_scope.nodes.emplace_back(); + + ConditionalParams params; + params.type = type; + params.handle = handle; + params.context = parent_->gpu_context(); + + TF_ASSIGN_OR_RETURN( + GpuDriver::GpuGraphNodeResult result, + GpuDriver::GraphAddNode(&node_info.handle, graph_, barrier, params)); + + conditional_graphs.push_back(std::get(result).graph); + } + + return conditional_graphs; +} + absl::Status GpuCommandBuffer::CreateConditionalCommand( - StreamExecutor* executor, ConditionType type, SetConditionFn set_condition, + ExecutionScopeId execution_scope_id, StreamExecutor* executor, + ConditionType type, SetConditionFn set_condition, absl::Span builders) { + ExecutionScope& execution_scope = execution_scopes_[execution_scope_id]; + TF_RETURN_IF_ERROR(CheckNotFinalized()); // Every conditional command buffer is controlled by its own handle. @@ -587,38 +816,41 @@ absl::Status GpuCommandBuffer::CreateConditionalCommand( TF_ASSIGN_OR_RETURN(auto handles, CreateConditionalHandles(num_handles)); // Add a kernel to update conditional handles values. - TF_RETURN_IF_ERROR(set_condition(handles)); + TF_RETURN_IF_ERROR(set_condition(execution_scope_id, handles)); // Add a barrier between conditional handles and conditional nodes. - TF_RETURN_IF_ERROR(Barrier(executor)); + TF_RETURN_IF_ERROR(Barrier(executor, execution_scope_id)); // Create conditional command buffer for each builder. - TF_ASSIGN_OR_RETURN(auto graphs, CreateConditionalNodes(type, handles)); + TF_ASSIGN_OR_RETURN( + auto graphs, CreateConditionalNodes(execution_scope_id, type, handles)); TF_ASSIGN_OR_RETURN(auto cmd_buffers, CreateConditionalCommandBuffers( handles, graphs, builders)); // Keep track of created conditional handles and command buffers. - conditional_command_buffers_.emplace_back(std::move(handles), - std::move(cmd_buffers)); + execution_scope.conditional_command_buffers.push_back( + {std::move(handles), std::move(cmd_buffers)}); return absl::OkStatus(); } if (state_ == State::kUpdate) { ConditionalCommandBuffers& cond_cmd_buffers = - conditional_command_buffers_[update_state_.conditional_idx++]; + execution_scope.conditional_command_buffers[execution_scope.update_state + .conditional_idx++]; // Sanity check that we got the correct conditional command buffers. TF_RETURN_IF_ERROR(CheckNumCommandBuffers(cond_cmd_buffers, num_handles)); // Update a kernel that updates conditional handles values. - TF_RETURN_IF_ERROR(set_condition(cond_cmd_buffers.handles)); + TF_RETURN_IF_ERROR( + set_condition(execution_scope_id, cond_cmd_buffers.handles)); // Update a barrier between conditional handles and conditional nodes. - TF_RETURN_IF_ERROR(Barrier(executor)); + TF_RETURN_IF_ERROR(Barrier(executor, execution_scope_id)); // Skip updating conditional nodes. - update_state_.node_idx += num_handles; + execution_scope.update_state.node_idx += num_handles; return UpdateConditionalCommandBuffers( cond_cmd_buffers.handles, @@ -628,66 +860,53 @@ absl::Status GpuCommandBuffer::CreateConditionalCommand( return UnsupportedStateError(state_); } -absl::Status GpuCommandBuffer::If(StreamExecutor* executor, +absl::Status GpuCommandBuffer::If(ExecutionScopeId execution_scope_id, + StreamExecutor* executor, DeviceMemory predicate, CommandBuffer::Builder then_builder) { DCHECK(executor->implementation() == parent_); - // TODO(ezhulenev): Keep kernel in `GpuCommandBuffer` to avoid loading it on - // every call to `If`. - SetIfConditionKernel set_if_condition(executor); - - { // Load kernels that updates condition handle value. - MultiKernelLoaderSpec spec(/*arity=*/2); - spec.AddInProcessSymbol(gpu::GetSetIfConditionKernel(), "set_if_condition"); - TF_RETURN_IF_ERROR(executor->GetKernel(spec, &set_if_condition)); - } + TF_ASSIGN_OR_RETURN(SetIfConditionKernel * set_if_condition, + GetSetIfConditionKernel(executor)); - auto set_cond_fn = [&](absl::Span handles) { - return Launch(set_if_condition, ThreadDim(), BlockDim(), handles[0], - predicate); + auto set_cond_fn = [&](ExecutionScopeId id, ConditionalHandles handles) { + return CommandBuffer::Launch(*set_if_condition, id, ThreadDim(), BlockDim(), + handles[0], predicate); }; std::array builders = { ToConditionBuilder(std::move(then_builder))}; - return CreateConditionalCommand(executor, ConditionType::kIf, set_cond_fn, - builders); + return CreateConditionalCommand(execution_scope_id, executor, + ConditionType::kIf, set_cond_fn, builders); } -absl::Status GpuCommandBuffer::IfElse(StreamExecutor* executor, +absl::Status GpuCommandBuffer::IfElse(ExecutionScopeId execution_scope_id, + StreamExecutor* executor, DeviceMemory predicate, CommandBuffer::Builder then_builder, CommandBuffer::Builder else_builder) { DCHECK(executor->implementation() == parent_); - // TODO(ezhulenev): Keep kernel in `GpuCommandBuffer` to avoid loading it on - // every call to `IfElse`. - SetIfElseConditionKernel set_if_else_condition(executor); - - { // Load kernels that updates condition handle value. - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(gpu::GetSetIfElseConditionKernel(), - "set_if_else_condition"); - TF_RETURN_IF_ERROR(executor->GetKernel(spec, &set_if_else_condition)); - } + TF_ASSIGN_OR_RETURN(SetIfElseConditionKernel * set_if_else_condition, + GetSetIfElseConditionKernel(executor)); - auto set_cond_fn = [&](absl::Span handles) { - return Launch(set_if_else_condition, ThreadDim(), BlockDim(), handles[0], - handles[1], predicate); + auto set_cond_fn = [&](ExecutionScopeId id, ConditionalHandles handles) { + return CommandBuffer::Launch(*set_if_else_condition, id, ThreadDim(), + BlockDim(), handles[0], handles[1], predicate); }; std::array builders = { ToConditionBuilder(std::move(then_builder)), ToConditionBuilder(std::move(else_builder))}; - return CreateConditionalCommand(executor, ConditionType::kIf, set_cond_fn, - builders); + return CreateConditionalCommand(execution_scope_id, executor, + ConditionType::kIf, set_cond_fn, builders); } absl::Status GpuCommandBuffer::Case( - StreamExecutor* executor, DeviceMemory index, - std::vector branches) { + ExecutionScopeId execution_scope_id, StreamExecutor* executor, + DeviceMemory index, std::vector branches) { DCHECK(executor->implementation() == parent_); // TODO(ezhulenev): Relax this constraint, we can launch multiple back to back @@ -697,18 +916,10 @@ absl::Status GpuCommandBuffer::Case( "Case command supports only up to 8 branches, got: ", branches.size())); } - // TODO(ezhulenev): Keep kernel in `GpuCommandBuffer` to avoid loading it on - // every call to `Case`. - SetCaseConditionKernel set_case_condition(executor); + TF_ASSIGN_OR_RETURN(SetCaseConditionKernel * set_case_condition, + GetSetCaseConditionKernel(executor)); - { // Load kernels that updates condition handle value. - MultiKernelLoaderSpec spec(/*arity=*/10); - spec.AddInProcessSymbol(gpu::GetSetCaseConditionKernel(), - "set_case_condition"); - TF_RETURN_IF_ERROR(executor->GetKernel(spec, &set_case_condition)); - } - - auto set_cond_fn = [&](absl::Span handles) { + auto set_cond_fn = [&](ExecutionScopeId id, ConditionalHandles handles) { int32_t num_handles = handles.size(); // Pad handles up to size 8 with a default initialized handle. @@ -716,10 +927,11 @@ absl::Status GpuCommandBuffer::Case( handles.end()); padded_handles.resize(8); - return Launch(set_case_condition, ThreadDim(), BlockDim(), - padded_handles[0], padded_handles[1], padded_handles[2], - padded_handles[3], padded_handles[4], padded_handles[5], - padded_handles[6], padded_handles[7], index, num_handles); + return CommandBuffer::Launch( + *set_case_condition, id, ThreadDim(), BlockDim(), padded_handles[0], + padded_handles[1], padded_handles[2], padded_handles[3], + padded_handles[4], padded_handles[5], padded_handles[6], + padded_handles[7], index, num_handles); }; // Wrap all branches into conditional command buffer builders. @@ -729,34 +941,28 @@ absl::Status GpuCommandBuffer::Case( builders.push_back(ToConditionBuilder(std::move(branch))); } - return CreateConditionalCommand(executor, ConditionType::kIf, set_cond_fn, - builders); + return CreateConditionalCommand(execution_scope_id, executor, + ConditionType::kIf, set_cond_fn, builders); } -absl::Status GpuCommandBuffer::For(StreamExecutor* executor, +absl::Status GpuCommandBuffer::For(ExecutionScopeId execution_scope_id, + StreamExecutor* executor, int32_t num_iteration, DeviceMemory loop_counter, CommandBuffer::Builder body_builder) { DCHECK(executor->implementation() == parent_); - // TODO(ezhulenev): Keep kernel in `GpuCommandBuffer` to avoid loading it on - // every call to `For`. - SetForConditionKernel set_for_condition(executor); - - { // Load kernels that updates condition handle value. - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(gpu::GetSetForConditionKernel(), - "set_for_condition"); - TF_RETURN_IF_ERROR(executor->GetKernel(spec, &set_for_condition)); - } + TF_ASSIGN_OR_RETURN(SetForConditionKernel * set_for_condition, + GetSetForConditionKernel(executor)); // Reset loop counter to zero. - TF_RETURN_IF_ERROR(Memset(&loop_counter, uint32_t{0}, 1)); - TF_RETURN_IF_ERROR(Barrier(executor)); + TF_RETURN_IF_ERROR(Memset(execution_scope_id, &loop_counter, uint32_t{0}, 1)); + TF_RETURN_IF_ERROR(Barrier(executor, execution_scope_id)); - auto set_cond_fn = [&](absl::Span handles) { - return Launch(set_for_condition, ThreadDim(), BlockDim(), handles[0], - loop_counter, num_iteration); + auto set_cond_fn = [&](ExecutionScopeId id, ConditionalHandles handles) { + return CommandBuffer::Launch(*set_for_condition, id, ThreadDim(), + BlockDim(), handles[0], loop_counter, + num_iteration); }; auto body = [&](CommandBuffer* body, GpuGraphConditionalHandle handle) { @@ -764,40 +970,34 @@ absl::Status GpuCommandBuffer::For(StreamExecutor* executor, TF_RETURN_IF_ERROR(body->Barrier(executor)); // Decide if we want to continue loop iteration. - return body->Launch(set_for_condition, ThreadDim(), BlockDim(), handle, + return body->Launch(*set_for_condition, ThreadDim(), BlockDim(), handle, loop_counter, num_iteration); }; std::array builders = {std::move(body)}; - return CreateConditionalCommand(executor, ConditionType::kWhile, set_cond_fn, - builders); + return CreateConditionalCommand(execution_scope_id, executor, + ConditionType::kWhile, set_cond_fn, builders); } -absl::Status GpuCommandBuffer::While(StreamExecutor* executor, +absl::Status GpuCommandBuffer::While(ExecutionScopeId execution_scope_id, + StreamExecutor* executor, DeviceMemory pred, CommandBuffer::Builder cond_builder, CommandBuffer::Builder body_builder) { DCHECK(executor->implementation() == parent_); - // TODO(ezhulenev): Keep kernel in `GpuCommandBuffer` to avoid loading it on - // every call to `While`. - SetWhileConditionKernel set_while_condition(executor); - - { // Load kernels that updates condition handle value. - MultiKernelLoaderSpec spec(/*arity=*/2); - spec.AddInProcessSymbol(gpu::GetSetWhileConditionKernel(), - "set_while_condition"); - TF_RETURN_IF_ERROR(executor->GetKernel(spec, &set_while_condition)); - } + TF_ASSIGN_OR_RETURN(SetWhileConditionKernel * set_while_condition, + GetSetWhileConditionKernel(executor)); // Record condition commands into the parent command buffer. - TF_RETURN_IF_ERROR(CommandBuffer::Build(this, cond_builder)); - TF_RETURN_IF_ERROR(Barrier(executor)); + TF_RETURN_IF_ERROR(cond_builder(this)); + TF_RETURN_IF_ERROR( + Barrier(executor, kDefaulExecutionScope, execution_scope_id)); - auto set_cond_fn = [&](absl::Span handles) { - return Launch(set_while_condition, ThreadDim(), BlockDim(), handles[0], - pred); + auto set_cond_fn = [&](ExecutionScopeId id, ConditionalHandles handles) { + return CommandBuffer::Launch(*set_while_condition, id, ThreadDim(), + BlockDim(), handles[0], pred); }; auto body = [&](CommandBuffer* body, GpuGraphConditionalHandle handle) { @@ -805,14 +1005,14 @@ absl::Status GpuCommandBuffer::While(StreamExecutor* executor, TF_RETURN_IF_ERROR(body->Barrier(executor)); TF_RETURN_IF_ERROR(cond_builder(body)); TF_RETURN_IF_ERROR(body->Barrier(executor)); - return body->Launch(set_while_condition, ThreadDim(), BlockDim(), handle, + return body->Launch(*set_while_condition, ThreadDim(), BlockDim(), handle, pred); }; std::array builders = {std::move(body)}; - return CreateConditionalCommand(executor, ConditionType::kWhile, set_cond_fn, - builders); + return CreateConditionalCommand(execution_scope_id, executor, + ConditionType::kWhile, set_cond_fn, builders); } absl::Status GpuCommandBuffer::Finalize() { @@ -829,6 +1029,13 @@ absl::Status GpuCommandBuffer::Finalize() { } } + // Collect number of nodes and conditionals for logging below. + size_t num_nodes = 0, num_cond_cmd_buffers = 0; + for (auto& [_, execution_scope] : execution_scopes_) { + num_nodes += execution_scope.nodes.size(); + num_cond_cmd_buffers += execution_scope.conditional_command_buffers.size(); + } + if (mode_ == Mode::kPrimary && state_ == State::kCreate) { // If this is the first time we finalize command buffer after construction, // we need to instantiate it to an executable graph. @@ -841,20 +1048,35 @@ absl::Status GpuCommandBuffer::Finalize() { auto instantiated = GpuDriver::GraphInstantiate(&exec_, graph_, flags); if (instantiated.code() == absl::StatusCode::kResourceExhausted) { LOG(WARNING) << "Retry CUDA graph instantiation after OOM error" - << "; nodes=" << nodes_.size() - << "; conditionals=" << conditional_command_buffers_.size() + << "; execution_scopes: " << execution_scopes_.size() + << "; nodes: " << num_nodes + << "; conditionals: " << num_cond_cmd_buffers << "; alive executable graphs: " << AliveExecs(); TF_RETURN_IF_ERROR(GpuDriver::DeviceGraphMemTrim(parent_->device())); - TF_RETURN_IF_ERROR(GpuDriver::GraphInstantiate(&exec_, graph_, flags)); + + auto retry = GpuDriver::GraphInstantiate(&exec_, graph_, flags); + if (retry.code() == absl::StatusCode::kResourceExhausted) { + return absl::ResourceExhaustedError(absl::StrFormat( + "CUDA driver ran out of memory trying to instantiate CUDA graph " + "with %d nodes and %d conditionals (total of %d alive CUDA graphs " + "in the process). You can try to (a) Give more memory to CUDA " + "driver by reducing XLA_PYTHON_CLIENT_MEM_FRACTION (b) Disable " + "CUDA graph with 'XLA_FLAGS=--xla_gpu_enable_command_buffer=' " + "(empty set). Original error: %s", + num_nodes, num_cond_cmd_buffers, AliveExecs(), retry.message())); + } else { + TF_RETURN_IF_ERROR(retry); + } } uint64_t end_nanos = tsl::Env::Default()->NowNanos(); VLOG(5) << "Instantiated executable graph #" << NotifyExecCreated() << " " << exec_ << " in " << (end_nanos - start_nanos) / 1000 << " μs" - << "; nodes: " << nodes_.size() - << "; conditionals: " << conditional_command_buffers_.size() + << "; execution_scopes: " << execution_scopes_.size() + << "; nodes: " << num_nodes + << "; conditionals: " << num_cond_cmd_buffers << "; alive executable graphs: " << AliveExecs(); TF_RETURN_IF_ERROR(DisableBarriersExecution(exec_)); @@ -891,9 +1113,24 @@ absl::Status GpuCommandBuffer::Update() { << " command buffer update for executable graph " << exec_; state_ = State::kUpdate; - barrier_ = nullptr; - update_state_ = UpdateState(); + for (auto& [_, execution_scope] : execution_scopes_) { + execution_scope.update_state = ExecutionScope::UpdateState(); + } return absl::OkStatus(); } +absl::Span GpuCommandBuffer::nodes( + ExecutionScopeId id) const { + if (auto it = execution_scopes_.find(id); it != execution_scopes_.end()) + return it->second.nodes; + return {}; +} + +absl::Span +GpuCommandBuffer::barriers(ExecutionScopeId id) const { + if (auto it = execution_scopes_.find(id); it != execution_scopes_.end()) + return it->second.barriers; + return {}; +} + } // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h index 6e333746c7c5fe..2adeadc9dc63a4 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h @@ -20,10 +20,12 @@ limitations under the License. #include #include #include +#include #include #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/functional/any_invocable.h" #include "absl/status/status.h" @@ -37,57 +39,96 @@ limitations under the License. #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/stream_executor_internal.h" -#include "tsl/platform/errors.h" namespace stream_executor::gpu { -// GpuCommandBuffer provides platform-specific CommandBufferInterface -// implementation (it's backed by CUDA or HIP graphs on NVIDIA and AMD devices). -class GpuCommandBuffer : public internal::CommandBufferInterface { +// GpuCommandBuffer provides platform-specific CommandBuffer implementation +// (it's backed by CUDA or HIP graphs on NVIDIA and AMD devices). +class GpuCommandBuffer : public CommandBuffer { public: - GpuCommandBuffer(CommandBuffer::Mode mode, GpuExecutor* parent, - GpuGraphHandle graph, bool is_owned_graph = true); + // A handle to a Gpu graph node and a metadata describing its properties. Each + // command (launch, memcpy, etc.) creates one or more graph nodes. + struct GpuGraphNodeInfo { + // A handle to the gpu graph node corresponding to a command. + GpuGraphNodeHandle handle = nullptr; + }; + + // A handle to Gpu graph barrier and metadata describing its properties. Each + // call to `Barrier` creates a new barrier record. + struct GpuGraphBarrierInfo { + // A handle to graph node acting as a barrier that defines execution order. + // It can be a handle to a `GpuGraphNodeInfo` node or a handle to an empty + // node created to be a barrier. We try to reuse existing nodes as barriers + // if possible to reduce the size of constructed gpu graphs. + GpuGraphNodeHandle handle = nullptr; + + // If `true` it means `handle` corresponds to an empty node specifically + // created to act as an execution barrier, otherwise `handle` points to one + // of the nodes created for recorded commands. + bool is_barrier_node = true; + + // Nodes with index smaller than `nodes_offset` are synchronized with this + // barrier. We use this offset to find nodes added after the last barrier + // that should be added as dependencies to the next barrier. + size_t nodes_offset = 0; + }; + + GpuCommandBuffer(Mode mode, GpuExecutor* parent, GpuGraphHandle graph, + bool is_owned_graph = true); ~GpuCommandBuffer() override; - absl::Status Trace(Stream* stream, - absl::AnyInvocable function) override; + absl::Status Barrier(StreamExecutor* executor, + ExecutionScopeId execution_scope_id) override; - absl::Status Barrier(StreamExecutor* executor) override; + absl::Status Barrier( + StreamExecutor* executor, + absl::Span execution_scope_ids) override; - absl::Status Launch(const ThreadDim& threads, const BlockDim& blocks, + absl::Status Barrier(StreamExecutor* executor, + ExecutionScopeId from_execution_scope_id, + ExecutionScopeId to_execution_scope_id) override; + + absl::Status Launch(ExecutionScopeId execution_scope_id, + const ThreadDim& threads, const BlockDim& blocks, const Kernel& kernel, const KernelArgs& args) override; - absl::Status AddNestedCommandBuffer(const CommandBuffer& nested) override; + absl::Status AddNestedCommandBuffer(ExecutionScopeId execution_scope_id, + const CommandBuffer& nested) override; - absl::Status MemcpyDeviceToDevice(DeviceMemoryBase* dst, + absl::Status MemcpyDeviceToDevice(ExecutionScopeId execution_scope_id, + DeviceMemoryBase* dst, const DeviceMemoryBase& src, uint64_t size) override; - absl::Status Memset(DeviceMemoryBase* dst, - CommandBuffer::BitPattern bit_pattern, + absl::Status Memset(ExecutionScopeId execution_scope_id, + DeviceMemoryBase* dst, BitPattern bit_pattern, size_t num_elements) override; - absl::StatusOr Allocate(size_t bytes) override; + absl::StatusOr Allocate(ExecutionScopeId execution_scope_id, + size_t bytes) override; - absl::Status Free(DeviceMemoryBase dst) override; + absl::Status Free(ExecutionScopeId execution_scope_id, + DeviceMemoryBase dst) override; - absl::Status If(StreamExecutor* executor, DeviceMemory predicate, - CommandBuffer::Builder then_builder) override; + absl::Status If(ExecutionScopeId execution_scope_id, StreamExecutor* executor, + DeviceMemory predicate, Builder then_builder) override; - absl::Status IfElse(StreamExecutor* executor, DeviceMemory predicate, - CommandBuffer::Builder then_builder, - CommandBuffer::Builder else_builder) override; + absl::Status IfElse(ExecutionScopeId execution_scope_id, + StreamExecutor* executor, DeviceMemory predicate, + Builder then_builder, Builder else_builder) override; - absl::Status Case(StreamExecutor* executor, DeviceMemory index, - std::vector branches) override; + absl::Status Case(ExecutionScopeId execution_scope_id, + StreamExecutor* executor, DeviceMemory index, + std::vector branches) override; - absl::Status For(StreamExecutor* executor, int32_t num_iteration, + absl::Status For(ExecutionScopeId execution_scope_id, + StreamExecutor* executor, int32_t num_iteration, DeviceMemory loop_counter, - CommandBuffer::Builder body_builder) override; + Builder body_builder) override; - absl::Status While(StreamExecutor* executor, DeviceMemory pred, - CommandBuffer::Builder cond_builder, - CommandBuffer::Builder body_builder) override; + absl::Status While(ExecutionScopeId execution_scope_id, + StreamExecutor* executor, DeviceMemory pred, + Builder cond_builder, Builder body_builder) override; absl::Status Finalize() override; absl::Status Update() override; @@ -95,37 +136,40 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { GpuGraphExecHandle executable() const { return exec_; } GpuGraphHandle graph() const { return graph_; } - CommandBuffer::Mode mode() const override { return mode_; } - CommandBuffer::State state() const override { return state_; } + Mode mode() const override { return mode_; } + State state() const override { return state_; } - // A helper template for launching typed kernels. - template - absl::Status Launch(const TypedKernel& kernel, - const ThreadDim& threads, const BlockDim& blocks, - Args... args); + static GpuCommandBuffer* Cast(CommandBuffer* command_buffer) { + return static_cast(command_buffer); + } + + static const GpuCommandBuffer* Cast(const CommandBuffer* command_buffer) { + return static_cast(command_buffer); + } + + absl::Span nodes(ExecutionScopeId id) const; + absl::Span barriers(ExecutionScopeId id) const; + + absl::Span nodes() const { + return nodes(kDefaulExecutionScope); + } + + absl::Span barriers() const { + return barriers(kDefaulExecutionScope); + } + + private: + absl::Status Trace(Stream* stream, + absl::AnyInvocable function) override; // We track the total number of allocated and alive executable graphs in the // process to track the command buffers resource usage. Executable graph // allocates resources on a GPU devices (rule of thumb is ~8kb per node), so // we have to be careful not to keep too many of them alive for too long, or // we have a higher risk of OOM errors. - // - // TODO(ezhulenev): We need to have a policy for how to evict unused - // executable graph instances from a device, currently lifetime of an - // executable graph is tied to a parent command buffer, and we can have - // thousands of command buffers alive at the same time. static int64_t AllocatedExecs(); static int64_t AliveExecs(); - static GpuCommandBuffer* Cast(CommandBuffer* command_buffer) { - return static_cast(command_buffer->implementation()); - } - - static const GpuCommandBuffer* Cast(const CommandBuffer* command_buffer) { - return static_cast( - command_buffer->implementation()); - } - private: using Dependencies = absl::InlinedVector; @@ -153,16 +197,16 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { TypedKernel>; // A callback to launch a kernel that updates conditional handles state. - using SetConditionFn = - std::function)>; + using SetConditionFn = std::function)>; - // An extension of `CommandBuffer::Builder` for building conditional command - // buffers tied to conditional handles. + // An extension of `Builder` for building conditional command buffers tied to + // conditional handles. using ConditionBuilder = std::function; // Wraps a regular command buffer builder into condition builder. - static ConditionBuilder ToConditionBuilder(CommandBuffer::Builder builder); + static ConditionBuilder ToConditionBuilder(Builder builder); using ConditionType = typename GpuDriver::GpuGraphConditionalNodeParams::Type; @@ -183,13 +227,8 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { // For each conditional node in the Gpu graph we keep a record of conditional // command buffers attached to a node, so we can apply updates to them. struct ConditionalCommandBuffers { - ConditionalCommandBuffers(std::vector handles, - std::vector command_buffers) - : handles(std::move(handles)), - command_buffers(std::move(command_buffers)) {} - std::vector handles; - std::vector command_buffers; + std::vector> command_buffers; }; using AllocationResult = std::pair; @@ -197,28 +236,40 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { absl::StatusOr> CreateConditionalHandles(size_t num_handles); - absl::StatusOr> CreateConditionalNodes( - ConditionType type, absl::Span handles); - - absl::StatusOr> CreateConditionalCommandBuffers( + absl::StatusOr>> + CreateConditionalCommandBuffers( absl::Span handles, absl::Span graphs, absl::Span builders); absl::Status UpdateConditionalCommandBuffers( absl::Span handles, - absl::Span command_buffers, + absl::Span> command_buffers, absl::Span builders); + absl::StatusOr> CreateConditionalNodes( + ExecutionScopeId execution_scope_id, ConditionType type, + absl::Span handles); + absl::Status CreateConditionalCommand( - StreamExecutor* executor, ConditionType type, - SetConditionFn set_condition, + ExecutionScopeId execution_scope_id, StreamExecutor* executor, + ConditionType type, SetConditionFn set_condition, absl::Span builders); - Dependencies GetBarrier(); - - // Returns loaded no-op kernel used as a barrier, or loads it on a given - // stream executor. Loaded kernel owned by a current command buffer. + Dependencies GetBarrier(ExecutionScopeId execution_scope_id); + + // Returns loaded auxiliary kernels, or loads them on a given stream executor. + // Loaded kernels owned by a current command buffer. + absl::StatusOr GetSetIfConditionKernel( + StreamExecutor* executor); + absl::StatusOr GetSetIfElseConditionKernel( + StreamExecutor* executor); + absl::StatusOr GetSetCaseConditionKernel( + StreamExecutor* executor); + absl::StatusOr GetSetForConditionKernel( + StreamExecutor* executor); + absl::StatusOr GetSetWhileConditionKernel( + StreamExecutor* executor); absl::StatusOr GetNoOpKernel(StreamExecutor* executor); // Recursively disable all nodes corresponding to barriers (including nested @@ -229,7 +280,8 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { // Launches CUDA kernels with packed arguments. absl::Status LaunchWithPackedArgs( - const ThreadDim& threads, const BlockDim& blocks, const Kernel& kernel, + ExecutionScopeId execution_scope_id, const ThreadDim& threads, + const BlockDim& blocks, const Kernel& kernel, const KernelArgsPackedArrayBase& packed_args); // Returns OK status if command buffer is not finalized and it is still @@ -241,6 +293,13 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { absl::Status CheckNumCommandBuffers( const ConditionalCommandBuffers& cmd_buffers, size_t num_cmd_buffers); + // Creates a new no-op node acting as a barrier. + absl::StatusOr CreateBarrierNode( + StreamExecutor* executor, const Dependencies& dependencies); + + // Collects a set of dependencies for a new barrier. + Dependencies GetBarrierDependencies(ExecutionScopeId execution_scope_id); + static_assert(std::is_pointer_v, "GpuGraphHandle must be a pointer"); static_assert(std::is_pointer_v, @@ -248,8 +307,8 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { static_assert(std::is_pointer_v, "GpuGraphNodeHandle must be a pointer"); - CommandBuffer::Mode mode_; - CommandBuffer::State state_ = CommandBuffer::State::kCreate; + Mode mode_; + State state_ = State::kCreate; GpuExecutor* parent_; // not owned, must outlive *this @@ -259,60 +318,55 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { GpuGraphExecHandle exec_ = nullptr; // owned if `is_owned_graph_exec_` bool is_owned_graph_exec_ = true; // ownership of `is_owned_graph_exec_` - // Handle of a graph node that acts as a barrier for all newly added commands. - GpuGraphNodeHandle barrier_ = nullptr; - - // A handle to a Gpu graph node and a metadata describing the node properties. - struct GpuGraphNodeInfo { - // Gpu graph node handle owned by `graph_` instance. - GpuGraphNodeHandle handle = nullptr; + // ExecutionScope holds the state of an underlying CUDA graph (nodes an + // barriers added to a graph) for a single execution scope. + struct ExecutionScope { + // Tracks indices into data structures during command buffer updates. + struct UpdateState { + // Index points to the graph node inside `nodes` that will be updated + // next. + int64_t node_idx = 0; + + // Index points to the barrier node inside `barriers` that will be updated + // on a next call to `Barrier(...)`. + int64_t barrier_idx = 0; + + // Index points to the conditional command buffers that will be updated + // next when we'll be updating next conditional command (If, Case, While). + int64_t conditional_idx = 0; + }; + + // Gpu graph nodes corresponding to recorded commands (launch, memcpy, + // etc.). + std::vector nodes; + + // Gpu graph barriers that define recorded commands execution order. + std::vector barriers; + + // Command buffers for conditional nodes in the Gpu graph. Underlying Gpu + // graphs owned by the `graph_` instance. + std::vector conditional_command_buffers; + + // Tracks execution scope update state. + UpdateState update_state; }; - // Gpu graph nodes info for load bearing graph nodes (kernel, memcpy, etc.) - // corresponding to command buffer commands and also to no-op nodes - // corresponding to barriers (nodes defining DAG structure). - std::vector nodes_; - - // Handles to no-op graph nodes corresponding to barriers that define nodes - // execution order. Can be nullptr if regular node acts as a barrier. - std::vector barriers_; - - // Command buffers for conditional nodes in the Gpu graph. Underlying Gpu - // graphs owned by the `graph_` instance. - std::vector conditional_command_buffers_; + // Execution scopes recorded into the command buffer. + absl::flat_hash_map execution_scopes_; // Track the number of command buffer updates for debugging. int64_t num_updates_ = 0; - // Tracks indices into internal data structures during command buffer updates. - struct UpdateState { - // Index points to the graph node inside `nodes_` that will be updated next. - int64_t node_idx = 0; - - // Index points to the barrier node inside `barriers_` that will be updated - // on a next call to `Barrier()`. - int64_t barrier_idx = 0; - - // Index points to the conditional command buffers that will be updated next - // when we'll be updating next conditional command (If, Case, While). - int64_t conditional_idx = 0; - }; - - UpdateState update_state_; - - // Loaded instance of a no-op kernel used as command buffer barrier. - std::unique_ptr noop_kernel_; + // Lazy loaded auxiliary kernels required for building CUDA graphs (no-op + // barriers, updating conditional handles, etc.). + SetIfConditionKernel set_if_condition_kernel_; + SetIfElseConditionKernel set_if_else_condition_kernel_; + SetCaseConditionKernel set_case_condition_kernel_; + SetForConditionKernel set_for_condition_kernel_; + SetWhileConditionKernel set_while_condition_kernel_; + NoOpKernel noop_kernel_; }; -template -inline absl::Status GpuCommandBuffer::Launch( - const TypedKernel& kernel, const ThreadDim& threads, - const BlockDim& blocks, Args... args) { - auto kernel_args = PackKernelArgs(kernel, args...); - TF_RETURN_IF_ERROR(Launch(threads, blocks, kernel, *kernel_args)); - return absl::OkStatus(); -} - //===----------------------------------------------------------------------===// // Implementation details device kernels required by GpuCommandBuffer. //===----------------------------------------------------------------------===// @@ -321,16 +375,15 @@ inline absl::Status GpuCommandBuffer::Launch( // empty nodes are not supported within conditional CUDA graphs (in CUDA 12.3). void* GetNoOpKernel(); -// See `cuda_conditional_kernels.cu.cc` for CUDA implementations. These are +// See `cuda_conditional_kernels.cc` for CUDA implementation. These are // various kernels that update Gpu conditionals based on the device memory // values, and allow implementing on-device control flow via conditional command // buffers. - -void* GetSetIfConditionKernel(); -void* GetSetIfElseConditionKernel(); -void* GetSetCaseConditionKernel(); -void* GetSetForConditionKernel(); -void* GetSetWhileConditionKernel(); +std::string_view GetSetIfConditionKernel(); +std::string_view GetSetIfElseConditionKernel(); +std::string_view GetSetCaseConditionKernel(); +std::string_view GetSetForConditionKernel(); +std::string_view GetSetWhileConditionKernel(); } // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc index 53bf0958d23851..15cd989f722b16 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc @@ -13,31 +13,45 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "xla/stream_executor/gpu/gpu_command_buffer.h" + #include +#include #include #include #include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/ascii.h" #include "xla/service/platform_util.h" #include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_test_kernels.h" #include "xla/stream_executor/gpu/gpu_types.h" // IWYU pragma: keep #include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" #include "tsl/platform/test_benchmark.h" +#if GOOGLE_CUDA +#include "third_party/gpus/cuda/include/cuda.h" +#endif + namespace stream_executor::gpu { static Platform* GpuPlatform() { auto name = absl::AsciiStrToUpper( xla::PlatformUtil::CanonicalPlatformName("gpu").value()); - return MultiPlatformManager::PlatformWithName(name).value(); + return PlatformManager::PlatformWithName(name).value(); } static MultiKernelLoaderSpec GetAddI32KernelSpec() { @@ -63,19 +77,40 @@ using AddI32Ptrs3 = TypedKernel>; static constexpr auto nested = CommandBuffer::Mode::kNested; // NOLINT static constexpr auto primary = CommandBuffer::Mode::kPrimary; // NOLINT +template +static std::vector Deps(Info info) { + if (auto deps = GpuDriver::GraphNodeGetDependencies(info.handle); deps.ok()) { + return *deps; + } + return {GpuGraphNodeHandle(0xDEADBEEF)}; +} + +template +static std::vector ExpectedDeps(Infos... info) { + return {info.handle...}; +} + +// Some of the tests rely on CUDA 12.3+ features. +static bool IsAtLeastCuda12300() { +#if defined(TENSORFLOW_USE_ROCM) + return false; +#endif +#if CUDA_VERSION >= 12030 + return true; +#endif + return false; +} + TEST(GpuCommandBufferTest, LaunchSingleKernel) { Platform* platform = GpuPlatform(); StreamExecutor* executor = platform->ExecutorForDevice(0).value(); Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); + TF_ASSERT_OK(stream.Initialize()); MultiKernelLoaderSpec spec(/*arity=*/3); spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); - - AddI32Kernel add(executor); - TF_ASSERT_OK(executor->GetKernel(spec, &add)); + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; @@ -85,38 +120,38 @@ TEST(GpuCommandBufferTest, LaunchSingleKernel) { DeviceMemory b = executor->AllocateArray(length, 0); DeviceMemory c = executor->AllocateArray(length, 0); - stream.ThenMemset32(&a, 1, byte_length); - stream.ThenMemset32(&b, 2, byte_length); - stream.ThenMemZero(&c, byte_length); + TF_ASSERT_OK(stream.Memset32(&a, 1, byte_length)); + TF_ASSERT_OK(stream.Memset32(&b, 2, byte_length)); + TF_ASSERT_OK(stream.MemZero(&c, byte_length)); // Create a command buffer with a single kernel launch. auto cmd_buffer = CommandBuffer::Create(executor).value(); - TF_ASSERT_OK(cmd_buffer.Launch(add, ThreadDim(), BlockDim(4), a, b, c)); - TF_ASSERT_OK(cmd_buffer.Finalize()); + TF_ASSERT_OK(cmd_buffer->Launch(add, ThreadDim(), BlockDim(4), a, b, c)); + TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); // Copy `c` data back to host. std::vector dst(4, 42); - stream.ThenMemcpy(dst.data(), c, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), c, byte_length)); std::vector expected = {3, 3, 3, 3}; ASSERT_EQ(dst, expected); // Prepare argument for graph update: d = 0 DeviceMemory d = executor->AllocateArray(length, 0); - stream.ThenMemZero(&d, byte_length); + TF_ASSERT_OK(stream.MemZero(&d, byte_length)); // Update command buffer to write into `d` buffer. - TF_ASSERT_OK(cmd_buffer.Update()); - TF_ASSERT_OK(cmd_buffer.Launch(add, ThreadDim(), BlockDim(4), a, b, d)); - TF_ASSERT_OK(cmd_buffer.Finalize()); + TF_ASSERT_OK(cmd_buffer->Update()); + TF_ASSERT_OK(cmd_buffer->Launch(add, ThreadDim(), BlockDim(4), a, b, d)); + TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); // Copy `d` data back to host. std::fill(dst.begin(), dst.end(), 42); - stream.ThenMemcpy(dst.data(), d, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), d, byte_length)); ASSERT_EQ(dst, expected); } @@ -131,10 +166,7 @@ TEST(CudaCommandBufferTest, TraceSingleKernel) { StreamExecutor* executor = platform->ExecutorForDevice(0).value(); Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); - - AddI32Ptrs3 add(executor); + TF_ASSERT_OK(stream.Initialize()); // Register a kernel with a custom arguments packing function that packs // device memory arguments into a struct with pointers. @@ -142,15 +174,15 @@ TEST(CudaCommandBufferTest, TraceSingleKernel) { const KernelArgs& args) { auto bufs = Cast(&args)->device_memory_args(); auto cast = [](auto m) { return reinterpret_cast(m.opaque()); }; - return PackKernelArgs(add, internal::Ptrs3{ - cast(bufs[0]), - cast(bufs[1]), - cast(bufs[2]), - }); + return PackKernelArgs(/*shmem_bytes=*/0, internal::Ptrs3{ + cast(bufs[0]), + cast(bufs[1]), + cast(bufs[2]), + }); }); spec.AddInProcessSymbol(internal::GetAddI32Ptrs3Kernel(), "add"); - TF_ASSERT_OK(executor->GetKernel(spec, &add)); + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Ptrs3::Create(executor, spec)); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; @@ -160,9 +192,9 @@ TEST(CudaCommandBufferTest, TraceSingleKernel) { DeviceMemory b = executor->AllocateArray(length, 0); DeviceMemory c = executor->AllocateArray(length, 0); - stream.ThenMemset32(&a, 1, byte_length); - stream.ThenMemset32(&b, 2, byte_length); - stream.ThenMemZero(&c, byte_length); + TF_ASSERT_OK(stream.Memset32(&a, 1, byte_length)); + TF_ASSERT_OK(stream.Memset32(&b, 2, byte_length)); + TF_ASSERT_OK(stream.MemZero(&c, byte_length)); // Use an array of device memory base pointers as argument to test packing. KernelArgsDeviceMemoryArray args({a, b, c}, 0); @@ -171,16 +203,16 @@ TEST(CudaCommandBufferTest, TraceSingleKernel) { auto cmd_buffer = CommandBuffer::Trace( executor, [&](Stream* stream) { - return executor->Launch(stream, ThreadDim(), BlockDim(4), add, args); + return executor->Launch(stream, ThreadDim(), BlockDim(4), *add, args); }, primary); TF_ASSERT_OK(cmd_buffer.status()); - TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); + TF_ASSERT_OK(executor->Submit(&stream, **cmd_buffer)); // Copy data back to host. std::vector dst(4, 42); - stream.ThenMemcpy(dst.data(), c, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), c, byte_length)); std::vector expected = {3, 3, 3, 3}; ASSERT_EQ(dst, expected); @@ -191,13 +223,10 @@ TEST(GpuCommandBufferTest, LaunchNestedCommandBuffer) { StreamExecutor* executor = platform->ExecutorForDevice(0).value(); Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); + TF_ASSERT_OK(stream.Initialize()); MultiKernelLoaderSpec spec = GetAddI32KernelSpec(); - - AddI32Kernel add(executor); - TF_ASSERT_OK(executor->GetKernel(spec, &add)); + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; @@ -207,43 +236,43 @@ TEST(GpuCommandBufferTest, LaunchNestedCommandBuffer) { DeviceMemory b = executor->AllocateArray(length, 0); DeviceMemory c = executor->AllocateArray(length, 0); - stream.ThenMemset32(&a, 1, byte_length); - stream.ThenMemset32(&b, 2, byte_length); - stream.ThenMemZero(&c, byte_length); + TF_ASSERT_OK(stream.Memset32(&a, 1, byte_length)); + TF_ASSERT_OK(stream.Memset32(&b, 2, byte_length)); + TF_ASSERT_OK(stream.MemZero(&c, byte_length)); // Create a command buffer with a single kernel launch. auto primary_cmd = CommandBuffer::Create(executor).value(); auto nested_cmd = CommandBuffer::Create(executor, nested).value(); - TF_ASSERT_OK(nested_cmd.Launch(add, ThreadDim(), BlockDim(4), a, b, c)); - TF_ASSERT_OK(primary_cmd.AddNestedCommandBuffer(nested_cmd)); - TF_ASSERT_OK(primary_cmd.Finalize()); + TF_ASSERT_OK(nested_cmd->Launch(add, ThreadDim(), BlockDim(4), a, b, c)); + TF_ASSERT_OK(primary_cmd->AddNestedCommandBuffer(*nested_cmd)); + TF_ASSERT_OK(primary_cmd->Finalize()); - TF_ASSERT_OK(executor->Submit(&stream, primary_cmd)); + TF_ASSERT_OK(executor->Submit(&stream, *primary_cmd)); // Copy `c` data back to host. std::vector dst(4, 42); - stream.ThenMemcpy(dst.data(), c, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), c, byte_length)); std::vector expected = {3, 3, 3, 3}; ASSERT_EQ(dst, expected); // Prepare argument for graph update: d = 0 DeviceMemory d = executor->AllocateArray(length, 0); - stream.ThenMemZero(&d, byte_length); + TF_ASSERT_OK(stream.MemZero(&d, byte_length)); // Update command buffer to write into `d` buffer by creating a new nested // command buffer. nested_cmd = CommandBuffer::Create(executor, nested).value(); - TF_ASSERT_OK(nested_cmd.Launch(add, ThreadDim(), BlockDim(4), a, b, d)); - TF_ASSERT_OK(primary_cmd.Update()); - TF_ASSERT_OK(primary_cmd.AddNestedCommandBuffer(nested_cmd)); - TF_ASSERT_OK(primary_cmd.Finalize()); + TF_ASSERT_OK(nested_cmd->Launch(add, ThreadDim(), BlockDim(4), a, b, d)); + TF_ASSERT_OK(primary_cmd->Update()); + TF_ASSERT_OK(primary_cmd->AddNestedCommandBuffer(*nested_cmd)); + TF_ASSERT_OK(primary_cmd->Finalize()); - TF_ASSERT_OK(executor->Submit(&stream, primary_cmd)); + TF_ASSERT_OK(executor->Submit(&stream, *primary_cmd)); // Copy `d` data back to host. std::fill(dst.begin(), dst.end(), 42); - stream.ThenMemcpy(dst.data(), d, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), d, byte_length)); ASSERT_EQ(dst, expected); } @@ -252,8 +281,7 @@ TEST(GpuCommandBufferTest, MemcpyDeviceToDevice) { StreamExecutor* executor = platform->ExecutorForDevice(0).value(); Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); + TF_ASSERT_OK(stream.Initialize()); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; @@ -262,35 +290,35 @@ TEST(GpuCommandBufferTest, MemcpyDeviceToDevice) { DeviceMemory a = executor->AllocateArray(length, 0); DeviceMemory b = executor->AllocateArray(length, 0); - stream.ThenMemset32(&a, 42, byte_length); + TF_ASSERT_OK(stream.Memset32(&a, 42, byte_length)); // Create a command buffer with a single a to b memcpy command. auto cmd_buffer = CommandBuffer::Create(executor).value(); - TF_ASSERT_OK(cmd_buffer.MemcpyDeviceToDevice(&b, a, byte_length)); - TF_ASSERT_OK(cmd_buffer.Finalize()); + TF_ASSERT_OK(cmd_buffer->MemcpyDeviceToDevice(&b, a, byte_length)); + TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); // Copy `b` data back to host. std::vector dst(4, 0); - stream.ThenMemcpy(dst.data(), a, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), a, byte_length)); std::vector expected = {42, 42, 42, 42}; ASSERT_EQ(dst, expected); // Update command buffer to swap the memcpy direction. - TF_ASSERT_OK(cmd_buffer.Update()); - TF_ASSERT_OK(cmd_buffer.MemcpyDeviceToDevice(&a, b, byte_length)); - TF_ASSERT_OK(cmd_buffer.Finalize()); + TF_ASSERT_OK(cmd_buffer->Update()); + TF_ASSERT_OK(cmd_buffer->MemcpyDeviceToDevice(&a, b, byte_length)); + TF_ASSERT_OK(cmd_buffer->Finalize()); // Clear destination to test that command buffer actually copied memory. - stream.ThenMemset32(&a, 0, byte_length); + TF_ASSERT_OK(stream.Memset32(&a, 0, byte_length)); - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); // Copy `a` data back to host. std::fill(dst.begin(), dst.end(), 0); - stream.ThenMemcpy(dst.data(), a, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), a, byte_length)); ASSERT_EQ(dst, expected); } @@ -299,8 +327,7 @@ TEST(GpuCommandBufferTest, Memset) { StreamExecutor* executor = platform->ExecutorForDevice(0).value(); Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); + TF_ASSERT_OK(stream.Initialize()); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; @@ -309,53 +336,378 @@ TEST(GpuCommandBufferTest, Memset) { // Create a command buffer with a single memset command. auto cmd_buffer = CommandBuffer::Create(executor).value(); - TF_ASSERT_OK(cmd_buffer.Memset(&a, uint32_t{42}, length)); - TF_ASSERT_OK(cmd_buffer.Finalize()); + TF_ASSERT_OK(cmd_buffer->Memset(&a, uint32_t{42}, length)); + TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); // Copy `a` data back to host. std::vector dst(4, 0); - stream.ThenMemcpy(dst.data(), a, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), a, byte_length)); std::vector expected = {42, 42, 42, 42}; ASSERT_EQ(dst, expected); // Update command buffer to use a new bit pattern. - TF_ASSERT_OK(cmd_buffer.Update()); - TF_ASSERT_OK(cmd_buffer.Memset(&a, uint32_t{43}, length)); - TF_ASSERT_OK(cmd_buffer.Finalize()); + TF_ASSERT_OK(cmd_buffer->Update()); + TF_ASSERT_OK(cmd_buffer->Memset(&a, uint32_t{43}, length)); + TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); // Copy `d` data back to host. std::fill(dst.begin(), dst.end(), 0); - stream.ThenMemcpy(dst.data(), a, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), a, byte_length)); expected = {43, 43, 43, 43}; ASSERT_EQ(dst, expected); } -TEST(GpuCommandBufferTest, ConditionalIf) { +TEST(GpuCommandBufferTest, Barriers) { Platform* platform = GpuPlatform(); - if (!CommandBuffer::SupportsConditionalCommands(platform)) { - GTEST_SKIP() << "CUDA graph conditionals are not supported"; + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + Stream stream(executor); + TF_ASSERT_OK(stream.Initialize()); + + // Allocate device buffers for memset operations. + std::vector> buffers; + for (size_t i = 0; i < 6; ++i) { + buffers.push_back(executor->AllocateArray(1, 0)); + } + + // Transfer buffers data back to host. + auto transfer_buffers = [&]() -> std::vector { + std::vector dst(buffers.size(), 0); + for (size_t i = 0; i < buffers.size(); ++i) { + TF_CHECK_OK(stream.Memcpy(dst.data() + i, buffers[i], sizeof(int32_t))); + } + return dst; + }; + + auto record = [&](CommandBuffer* cmd_buffer, uint32_t bit_pattern) { + // Check that root barrier ignored. + TF_RETURN_IF_ERROR(cmd_buffer->Barrier(executor)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(&buffers[0], bit_pattern + 0, 1)); + // Check barrier after a single command. + TF_RETURN_IF_ERROR(cmd_buffer->Barrier(executor)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(&buffers[1], bit_pattern + 1, 1)); + // Check that repeated barriers are no-op. + TF_RETURN_IF_ERROR(cmd_buffer->Barrier(executor)); + TF_RETURN_IF_ERROR(cmd_buffer->Barrier(executor)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(&buffers[2], bit_pattern + 2, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(&buffers[3], bit_pattern + 3, 1)); + // Check that barrier can have multiple dependencies. + TF_RETURN_IF_ERROR(cmd_buffer->Barrier(executor)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(&buffers[4], bit_pattern + 4, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(&buffers[5], bit_pattern + 5, 1)); + // Check that barrier can be that last command. + TF_RETURN_IF_ERROR(cmd_buffer->Barrier(executor)); + return cmd_buffer->Finalize(); + }; + + // Create a command buffer with a DAG of memset commands. + auto cmd_buffer = CommandBuffer::Create(executor).value(); + TF_ASSERT_OK(record(cmd_buffer.get(), 42)); + TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); + + std::vector expected = {42, 43, 44, 45, 46, 47}; + ASSERT_EQ(transfer_buffers(), expected); + + // Check the command buffer structure. + GpuCommandBuffer* gpu_cmd_buffer = GpuCommandBuffer::Cast(cmd_buffer.get()); + ASSERT_EQ(gpu_cmd_buffer->nodes().size(), 6); + ASSERT_EQ(gpu_cmd_buffer->barriers().size(), 6); + + auto nodes = gpu_cmd_buffer->nodes(); + auto barriers = gpu_cmd_buffer->barriers(); + + // First barrier does not have any dependencies. + EXPECT_TRUE(barriers[0].is_barrier_node); + EXPECT_TRUE(Deps(barriers[0]).empty()); + + // Second barrier reuses first memset node. + EXPECT_FALSE(barriers[1].is_barrier_node); + EXPECT_EQ(barriers[1].handle, nodes[0].handle); + + // Third and fourth barriers reuse second memset node. + EXPECT_FALSE(barriers[2].is_barrier_node); + EXPECT_FALSE(barriers[3].is_barrier_node); + EXPECT_EQ(barriers[2].handle, nodes[1].handle); + EXPECT_EQ(barriers[3].handle, nodes[1].handle); + + // Fifth and sixth barriers are barrier nodes. + EXPECT_TRUE(barriers[4].is_barrier_node); + EXPECT_TRUE(barriers[5].is_barrier_node); + + EXPECT_EQ(Deps(barriers[4]), ExpectedDeps(nodes[2], nodes[3])); + EXPECT_EQ(Deps(barriers[5]), ExpectedDeps(nodes[4], nodes[5])); + + // Update command buffer to use a new bit pattern. + TF_ASSERT_OK(cmd_buffer->Update()); + TF_ASSERT_OK(record(cmd_buffer.get(), 43)); + TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); + + expected = {43, 44, 45, 46, 47, 48}; + ASSERT_EQ(transfer_buffers(), expected); +} + +TEST(GpuCommandBufferTest, IndependentExecutionScopes) { + Platform* platform = GpuPlatform(); + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + Stream stream(executor); + TF_ASSERT_OK(stream.Initialize()); + + CommandBuffer::ExecutionScopeId s0 = CommandBuffer::ExecutionScopeId(0); + CommandBuffer::ExecutionScopeId s1 = CommandBuffer::ExecutionScopeId(1); + + // Allocate device buffers for memset operations. + std::vector> buffers; + for (size_t i = 0; i < 4; ++i) { + buffers.push_back(executor->AllocateArray(1, 0)); + } + + // Transfer buffers data back to host. + auto transfer_buffers = [&]() -> std::vector { + std::vector dst(buffers.size(), 0); + for (size_t i = 0; i < buffers.size(); ++i) { + TF_CHECK_OK(stream.Memcpy(dst.data() + i, buffers[i], sizeof(int32_t))); + } + return dst; + }; + + auto record = [&](CommandBuffer* cmd_buffer, uint32_t bit_pattern) { + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s0, &buffers[0], bit_pattern + 0, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s0, &buffers[1], bit_pattern + 1, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s1, &buffers[2], bit_pattern + 2, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s1, &buffers[3], bit_pattern + 3, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Barrier(executor, s0)); + TF_RETURN_IF_ERROR(cmd_buffer->Barrier(executor, s1)); + return cmd_buffer->Finalize(); + }; + + // Create a command buffer with a DAG of memset commands. + auto cmd_buffer = CommandBuffer::Create(executor).value(); + TF_ASSERT_OK(record(cmd_buffer.get(), 42)); + TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); + + std::vector expected = {42, 43, 44, 45}; + ASSERT_EQ(transfer_buffers(), expected); + + // Check the command buffer structure. + GpuCommandBuffer* gpu_cmd_buffer = GpuCommandBuffer::Cast(cmd_buffer.get()); + + auto nodes0 = gpu_cmd_buffer->nodes(s0); + auto nodes1 = gpu_cmd_buffer->nodes(s1); + auto barriers0 = gpu_cmd_buffer->barriers(s0); + auto barriers1 = gpu_cmd_buffer->barriers(s1); + + ASSERT_EQ(nodes0.size(), 2); + ASSERT_EQ(nodes1.size(), 2); + ASSERT_EQ(barriers0.size(), 1); + ASSERT_EQ(barriers1.size(), 1); + + EXPECT_TRUE(barriers0[0].is_barrier_node); + EXPECT_TRUE(barriers1[0].is_barrier_node); + + EXPECT_EQ(Deps(barriers0[0]), ExpectedDeps(nodes0[0], nodes0[1])); + EXPECT_EQ(Deps(barriers1[0]), ExpectedDeps(nodes1[0], nodes1[1])); + + // Update command buffer to use a new bit pattern. + TF_ASSERT_OK(cmd_buffer->Update()); + TF_ASSERT_OK(record(cmd_buffer.get(), 43)); + TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); + + expected = {43, 44, 45, 46}; + ASSERT_EQ(transfer_buffers(), expected); +} + +TEST(GpuCommandBufferTest, ExecutionScopeBarriers) { + Platform* platform = GpuPlatform(); + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + Stream stream(executor); + TF_ASSERT_OK(stream.Initialize()); + + CommandBuffer::ExecutionScopeId s0 = CommandBuffer::ExecutionScopeId(0); + CommandBuffer::ExecutionScopeId s1 = CommandBuffer::ExecutionScopeId(1); + CommandBuffer::ExecutionScopeId s2 = CommandBuffer::ExecutionScopeId(2); + + // Allocate device buffers for memset operations. + std::vector> buffers; + for (size_t i = 0; i < 7; ++i) { + buffers.push_back(executor->AllocateArray(1, 0)); } + // Transfer buffers data back to host. + auto transfer_buffers = [&]() -> std::vector { + std::vector dst(buffers.size(), 0); + for (size_t i = 0; i < buffers.size(); ++i) { + TF_CHECK_OK(stream.Memcpy(dst.data() + i, buffers[i], sizeof(int32_t))); + } + return dst; + }; + + auto record = [&](CommandBuffer* cmd_buffer, uint32_t bit_pattern) { + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s0, &buffers[0], bit_pattern + 0, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s0, &buffers[1], bit_pattern + 1, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s1, &buffers[2], bit_pattern + 2, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s1, &buffers[3], bit_pattern + 3, 1)); + // This will synchronize scopes 0 and 1 and also create an empty scope 2. + TF_RETURN_IF_ERROR(cmd_buffer->Barrier(executor, {s0, s1, s2})); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s0, &buffers[4], bit_pattern + 4, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s1, &buffers[5], bit_pattern + 5, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s2, &buffers[6], bit_pattern + 6, 1)); + return cmd_buffer->Finalize(); + }; + + // Create a command buffer with a DAG of memset commands. + auto cmd_buffer = CommandBuffer::Create(executor).value(); + TF_ASSERT_OK(record(cmd_buffer.get(), 42)); + TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); + + std::vector expected = {42, 43, 44, 45, 46, 47, 48}; + ASSERT_EQ(transfer_buffers(), expected); + + // Check the command buffer structure. + GpuCommandBuffer* gpu_cmd_buffer = GpuCommandBuffer::Cast(cmd_buffer.get()); + + auto nodes0 = gpu_cmd_buffer->nodes(s0); + auto nodes1 = gpu_cmd_buffer->nodes(s1); + auto nodes2 = gpu_cmd_buffer->nodes(s2); + auto barriers0 = gpu_cmd_buffer->barriers(s0); + auto barriers1 = gpu_cmd_buffer->barriers(s1); + auto barriers2 = gpu_cmd_buffer->barriers(s2); + + ASSERT_EQ(nodes0.size(), 3); + ASSERT_EQ(nodes1.size(), 3); + ASSERT_EQ(nodes2.size(), 1); + ASSERT_EQ(barriers0.size(), 2); + ASSERT_EQ(barriers1.size(), 2); + ASSERT_EQ(barriers2.size(), 2); + + // All barriers are real barrier nodes. + EXPECT_TRUE(barriers0[0].is_barrier_node && barriers0[1].is_barrier_node); + EXPECT_TRUE(barriers1[0].is_barrier_node && barriers1[1].is_barrier_node); + EXPECT_TRUE(barriers2[0].is_barrier_node && barriers2[1].is_barrier_node); + + // All scopes share a broadcasted barrier. + EXPECT_TRUE(barriers0[1].handle == barriers1[1].handle); + EXPECT_TRUE(barriers1[1].handle == barriers2[1].handle); + + EXPECT_EQ(Deps(barriers0[0]), ExpectedDeps(nodes0[0], nodes0[1])); + EXPECT_EQ(Deps(barriers1[0]), ExpectedDeps(nodes1[0], nodes1[1])); + + EXPECT_TRUE(Deps(barriers2[0]).empty()); + EXPECT_EQ(Deps(barriers2[1]), + ExpectedDeps(barriers0[0], barriers1[0], barriers2[0])); + + EXPECT_EQ(Deps(nodes0[2]), ExpectedDeps(barriers0[1])); + EXPECT_EQ(Deps(nodes1[2]), ExpectedDeps(barriers1[1])); + EXPECT_EQ(Deps(nodes2[0]), ExpectedDeps(barriers2[1])); + + // Update command buffer to use a new bit pattern. + TF_ASSERT_OK(cmd_buffer->Update()); + TF_ASSERT_OK(record(cmd_buffer.get(), 43)); + TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); + + expected = {43, 44, 45, 46, 47, 48, 49}; + ASSERT_EQ(transfer_buffers(), expected); +} + +TEST(GpuCommandBufferTest, ExecutionScopeOneDirectionalBarriers) { + Platform* platform = GpuPlatform(); StreamExecutor* executor = platform->ExecutorForDevice(0).value(); Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); + TF_ASSERT_OK(stream.Initialize()); + + CommandBuffer::ExecutionScopeId s0 = CommandBuffer::ExecutionScopeId(0); + CommandBuffer::ExecutionScopeId s1 = CommandBuffer::ExecutionScopeId(1); + + // Allocate device buffers for memset operations. + std::vector> buffers; + for (size_t i = 0; i < 6; ++i) { + buffers.push_back(executor->AllocateArray(1, 0)); + } + + // Transfer buffers data back to host. + auto transfer_buffers = [&]() -> std::vector { + std::vector dst(buffers.size(), 0); + for (size_t i = 0; i < buffers.size(); ++i) { + stream.ThenMemcpy(dst.data() + i, buffers[i], sizeof(int32_t)); + } + return dst; + }; + + auto record = [&](CommandBuffer* cmd_buffer, uint32_t bit_pattern) { + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s0, &buffers[0], bit_pattern + 0, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s0, &buffers[1], bit_pattern + 1, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s1, &buffers[2], bit_pattern + 2, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s1, &buffers[3], bit_pattern + 3, 1)); + // This will synchronize scopes 0 and 1. + TF_RETURN_IF_ERROR(cmd_buffer->Barrier(executor, s0, s1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s0, &buffers[4], bit_pattern + 4, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s1, &buffers[5], bit_pattern + 5, 1)); + return cmd_buffer->Finalize(); + }; + + // Create a command buffer with a DAG of memset commands. + auto cmd_buffer = CommandBuffer::Create(executor).value(); + TF_ASSERT_OK(record(cmd_buffer.get(), 42)); + TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); + + std::vector expected = {42, 43, 44, 45, 46, 47}; + ASSERT_EQ(transfer_buffers(), expected); + + // Check the command buffer structure. + GpuCommandBuffer* gpu_cmd_buffer = GpuCommandBuffer::Cast(cmd_buffer.get()); + + auto nodes0 = gpu_cmd_buffer->nodes(s0); + auto nodes1 = gpu_cmd_buffer->nodes(s1); + auto barriers0 = gpu_cmd_buffer->barriers(s0); + auto barriers1 = gpu_cmd_buffer->barriers(s1); + + ASSERT_EQ(nodes0.size(), 3); + ASSERT_EQ(nodes1.size(), 3); + ASSERT_EQ(barriers0.size(), 1); + ASSERT_EQ(barriers1.size(), 2); + + // All barriers are real barrier nodes. + EXPECT_TRUE(barriers0[0].is_barrier_node); + EXPECT_TRUE(barriers1[0].is_barrier_node && barriers1[1].is_barrier_node); + + EXPECT_EQ(Deps(barriers0[0]), ExpectedDeps(nodes0[0], nodes0[1])); + EXPECT_EQ(Deps(barriers1[0]), ExpectedDeps(nodes1[0], nodes1[1])); + EXPECT_EQ(Deps(barriers1[1]), ExpectedDeps(barriers0[0], barriers1[0])); + EXPECT_EQ(Deps(nodes0[2]), ExpectedDeps(barriers0[0])); + EXPECT_EQ(Deps(nodes1[2]), ExpectedDeps(barriers1[1])); + + // Update command buffer to use a new bit pattern. + TF_ASSERT_OK(cmd_buffer->Update()); + TF_ASSERT_OK(record(cmd_buffer.get(), 43)); + TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); - AddI32Kernel add(executor); + expected = {43, 44, 45, 46, 47, 48}; + ASSERT_EQ(transfer_buffers(), expected); +} - { // Load addition kernel. - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); - TF_ASSERT_OK(executor->GetKernel(spec, &add)); +TEST(GpuCommandBufferTest, ConditionalIf) { + if (!IsAtLeastCuda12300()) { + GTEST_SKIP() << "CUDA graph conditionals are not supported"; } + Platform* platform = GpuPlatform(); + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + Stream stream(executor); + TF_ASSERT_OK(stream.Initialize()); + + MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); + int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; @@ -366,10 +718,10 @@ TEST(GpuCommandBufferTest, ConditionalIf) { DeviceMemory c = executor->AllocateArray(length, 0); constexpr bool kTrue = true; - stream.ThenMemcpy(&pred, &kTrue, 1); - stream.ThenMemset32(&a, 1, byte_length); - stream.ThenMemset32(&b, 2, byte_length); - stream.ThenMemZero(&c, byte_length); + TF_ASSERT_OK(stream.Memcpy(&pred, &kTrue, 1)); + TF_ASSERT_OK(stream.Memset32(&a, 1, byte_length)); + TF_ASSERT_OK(stream.Memset32(&b, 2, byte_length)); + TF_ASSERT_OK(stream.MemZero(&c, byte_length)); // if (pred == true) c = a + b CommandBuffer::Builder then_builder = [&](CommandBuffer* then_cmd) { @@ -378,37 +730,37 @@ TEST(GpuCommandBufferTest, ConditionalIf) { // Create a command buffer with a single conditional operation. auto cmd_buffer = CommandBuffer::Create(executor).value(); - TF_ASSERT_OK(cmd_buffer.If(executor, pred, then_builder)); - TF_ASSERT_OK(cmd_buffer.Finalize()); + TF_ASSERT_OK(cmd_buffer->If(executor, pred, then_builder)); + TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); // Copy `c` data back to host. std::vector dst(4, 42); - stream.ThenMemcpy(dst.data(), c, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), c, byte_length)); std::vector expected = {3, 3, 3, 3}; ASSERT_EQ(dst, expected); // Reset predicate to false and clear output buffer. constexpr bool kFalse = false; - stream.ThenMemcpy(&pred, &kFalse, 1); - stream.ThenMemZero(&c, byte_length); + TF_ASSERT_OK(stream.Memcpy(&pred, &kFalse, 1)); + TF_ASSERT_OK(stream.MemZero(&c, byte_length)); // Submit the same command buffer, but this time it should not execute // conditional branch as conditional handle should be updated to false. - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); - stream.ThenMemcpy(dst.data(), c, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), c, byte_length)); std::vector zeroes = {0, 0, 0, 0}; ASSERT_EQ(dst, zeroes); // Prepare argument for graph update: d = 0 DeviceMemory d = executor->AllocateArray(length, 0); - stream.ThenMemZero(&d, byte_length); + TF_ASSERT_OK(stream.MemZero(&d, byte_length)); // Set predicate buffer to true to run conditional command buffer. - stream.ThenMemcpy(&pred, &kTrue, 1); + TF_ASSERT_OK(stream.Memcpy(&pred, &kTrue, 1)); // if (pred == true) d = a + b (write to a new location). then_builder = [&](CommandBuffer* then_cmd) { @@ -416,44 +768,38 @@ TEST(GpuCommandBufferTest, ConditionalIf) { }; // Update command buffer with a conditional to use new builder. - TF_ASSERT_OK(cmd_buffer.Update()); - TF_ASSERT_OK(cmd_buffer.If(executor, pred, then_builder)); - TF_ASSERT_OK(cmd_buffer.Finalize()); + TF_ASSERT_OK(cmd_buffer->Update()); + TF_ASSERT_OK(cmd_buffer->If(executor, pred, then_builder)); + TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); // Copy `d` data back to host. std::fill(dst.begin(), dst.end(), 42); - stream.ThenMemcpy(dst.data(), d, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), d, byte_length)); ASSERT_EQ(dst, expected); } TEST(GpuCommandBufferTest, ConditionalIfElse) { - Platform* platform = GpuPlatform(); - if (!CommandBuffer::SupportsConditionalCommands(platform)) { + if (!IsAtLeastCuda12300()) { GTEST_SKIP() << "CUDA graph conditionals are not supported"; } + Platform* platform = GpuPlatform(); StreamExecutor* executor = platform->ExecutorForDevice(0).value(); Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); + TF_ASSERT_OK(stream.Initialize()); - AddI32Kernel add(executor); - MulI32Kernel mul(executor); - - { // Load addition kernel. - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); - TF_ASSERT_OK(executor->GetKernel(spec, &add)); - } + // Load addition kernel. + MultiKernelLoaderSpec add_spec(/*arity=*/3); + add_spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, add_spec)); - { // Load multiplication kernel. - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(internal::GetMulI32Kernel(), "mul"); - TF_ASSERT_OK(executor->GetKernel(spec, &mul)); - } + // Load multiplication kernel. + MultiKernelLoaderSpec mul_spec(/*arity=*/3); + mul_spec.AddInProcessSymbol(internal::GetMulI32Kernel(), "mul"); + TF_ASSERT_OK_AND_ASSIGN(auto mul, MulI32Kernel::Create(executor, mul_spec)); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; @@ -465,10 +811,10 @@ TEST(GpuCommandBufferTest, ConditionalIfElse) { DeviceMemory c = executor->AllocateArray(length, 0); constexpr bool kTrue = true; - stream.ThenMemcpy(&pred, &kTrue, 1); - stream.ThenMemset32(&a, 2, byte_length); - stream.ThenMemset32(&b, 3, byte_length); - stream.ThenMemZero(&c, byte_length); + TF_ASSERT_OK(stream.Memcpy(&pred, &kTrue, 1)); + TF_ASSERT_OK(stream.Memset32(&a, 2, byte_length)); + TF_ASSERT_OK(stream.Memset32(&b, 3, byte_length)); + TF_ASSERT_OK(stream.MemZero(&c, byte_length)); // if (pred == true) c = a + b CommandBuffer::Builder then_builder = [&](CommandBuffer* then_cmd) { @@ -482,35 +828,35 @@ TEST(GpuCommandBufferTest, ConditionalIfElse) { // Create a command buffer with a single conditional operation. auto cmd_buffer = CommandBuffer::Create(executor).value(); - TF_ASSERT_OK(cmd_buffer.IfElse(executor, pred, then_builder, else_builder)); - TF_ASSERT_OK(cmd_buffer.Finalize()); + TF_ASSERT_OK(cmd_buffer->IfElse(executor, pred, then_builder, else_builder)); + TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); TF_ASSERT_OK(stream.BlockHostUntilDone()); // Copy `c` data back to host. std::vector dst(4, 42); - stream.ThenMemcpy(dst.data(), c, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), c, byte_length)); std::vector expected_add = {5, 5, 5, 5}; ASSERT_EQ(dst, expected_add); // Reset predicate to false. constexpr bool kFalse = false; - stream.ThenMemcpy(&pred, &kFalse, 1); + TF_ASSERT_OK(stream.Memcpy(&pred, &kFalse, 1)); // Submit the same command buffer, but this time it should execute `else` // branch and multiply inputs. - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); TF_ASSERT_OK(stream.BlockHostUntilDone()); - stream.ThenMemcpy(dst.data(), c, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), c, byte_length)); std::vector expected_mul = {6, 6, 6, 6}; ASSERT_EQ(dst, expected_mul); // Prepare argument for graph update: d = 0 DeviceMemory d = executor->AllocateArray(length, 0); - stream.ThenMemZero(&d, byte_length); + TF_ASSERT_OK(stream.MemZero(&d, byte_length)); // if (pred == false) d = a * b (write to a new location). else_builder = [&](CommandBuffer* else_cmd) { @@ -518,45 +864,39 @@ TEST(GpuCommandBufferTest, ConditionalIfElse) { }; // Update command buffer with a conditional to use new `else` builder. - TF_ASSERT_OK(cmd_buffer.Update()); - TF_ASSERT_OK(cmd_buffer.IfElse(executor, pred, then_builder, else_builder)); - TF_ASSERT_OK(cmd_buffer.Finalize()); + TF_ASSERT_OK(cmd_buffer->Update()); + TF_ASSERT_OK(cmd_buffer->IfElse(executor, pred, then_builder, else_builder)); + TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); TF_ASSERT_OK(stream.BlockHostUntilDone()); // Copy `d` data back to host. std::fill(dst.begin(), dst.end(), 42); - stream.ThenMemcpy(dst.data(), d, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), d, byte_length)); ASSERT_EQ(dst, expected_mul); } TEST(GpuCommandBufferTest, ConditionalCase) { - Platform* platform = GpuPlatform(); - if (!CommandBuffer::SupportsConditionalCommands(platform)) { + if (!IsAtLeastCuda12300()) { GTEST_SKIP() << "CUDA graph conditionals are not supported"; } + Platform* platform = GpuPlatform(); StreamExecutor* executor = platform->ExecutorForDevice(0).value(); Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); + TF_ASSERT_OK(stream.Initialize()); - AddI32Kernel add(executor); - MulI32Kernel mul(executor); - - { // Load addition kernel. - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); - TF_ASSERT_OK(executor->GetKernel(spec, &add)); - } + // Load addition kernel. + MultiKernelLoaderSpec add_spec(/*arity=*/3); + add_spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, add_spec)); - { // Load multiplication kernel. - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(internal::GetMulI32Kernel(), "mul"); - TF_ASSERT_OK(executor->GetKernel(spec, &mul)); - } + // Load multiplication kernel. + MultiKernelLoaderSpec mul_spec(/*arity=*/3); + mul_spec.AddInProcessSymbol(internal::GetMulI32Kernel(), "mul"); + TF_ASSERT_OK_AND_ASSIGN(auto mul, MulI32Kernel::Create(executor, mul_spec)); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; @@ -567,10 +907,10 @@ TEST(GpuCommandBufferTest, ConditionalCase) { DeviceMemory b = executor->AllocateArray(length, 0); DeviceMemory c = executor->AllocateArray(length, 0); - stream.ThenMemset32(&index, 0, sizeof(int32_t)); - stream.ThenMemset32(&a, 2, byte_length); - stream.ThenMemset32(&b, 3, byte_length); - stream.ThenMemZero(&c, byte_length); + TF_ASSERT_OK(stream.Memset32(&index, 0, sizeof(int32_t))); + TF_ASSERT_OK(stream.Memset32(&a, 2, byte_length)); + TF_ASSERT_OK(stream.Memset32(&b, 3, byte_length)); + TF_ASSERT_OK(stream.MemZero(&c, byte_length)); // if (index == 0) c = a + b CommandBuffer::Builder branch0 = [&](CommandBuffer* branch0_cmd) { @@ -584,68 +924,63 @@ TEST(GpuCommandBufferTest, ConditionalCase) { // Create a command buffer with a single conditional operation. auto cmd_buffer = CommandBuffer::Create(executor).value(); - TF_ASSERT_OK(cmd_buffer.Case(executor, index, {branch0, branch1})); - TF_ASSERT_OK(cmd_buffer.Finalize()); + TF_ASSERT_OK(cmd_buffer->Case(executor, index, {branch0, branch1})); + TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); TF_ASSERT_OK(stream.BlockHostUntilDone()); // Copy `c` data back to host. std::vector dst(4, 42); - stream.ThenMemcpy(dst.data(), c, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), c, byte_length)); std::vector expected_add = {5, 5, 5, 5}; ASSERT_EQ(dst, expected_add); // Set index to `1` - stream.ThenMemset32(&index, 1, sizeof(int32_t)); + TF_ASSERT_OK(stream.Memset32(&index, 1, sizeof(int32_t))); // Submit the same command buffer, but this time it should multiply inputs. - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); TF_ASSERT_OK(stream.BlockHostUntilDone()); - stream.ThenMemcpy(dst.data(), c, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), c, byte_length)); std::vector expected_mul = {6, 6, 6, 6}; ASSERT_EQ(dst, expected_mul); // Set index to `-1` (out of bound index value). - stream.ThenMemset32(&index, -1, sizeof(int32_t)); + TF_ASSERT_OK(stream.Memset32(&index, -1, sizeof(int32_t))); - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); TF_ASSERT_OK(stream.BlockHostUntilDone()); - stream.ThenMemcpy(dst.data(), c, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), c, byte_length)); ASSERT_EQ(dst, expected_mul); // Set index to `2` (out of bound index value). - stream.ThenMemset32(&index, 2, sizeof(int32_t)); + TF_ASSERT_OK(stream.Memset32(&index, 2, sizeof(int32_t))); - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); TF_ASSERT_OK(stream.BlockHostUntilDone()); - stream.ThenMemcpy(dst.data(), c, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), c, byte_length)); ASSERT_EQ(dst, expected_mul); } TEST(GpuCommandBufferTest, ConditionalFor) { - Platform* platform = GpuPlatform(); - if (!CommandBuffer::SupportsConditionalCommands(platform)) { + if (!IsAtLeastCuda12300()) { GTEST_SKIP() << "CUDA graph conditionals are not supported"; } + Platform* platform = GpuPlatform(); StreamExecutor* executor = platform->ExecutorForDevice(0).value(); Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); - - AddI32Kernel add(executor); + TF_ASSERT_OK(stream.Initialize()); - { // Load addition kernel. - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); - TF_ASSERT_OK(executor->GetKernel(spec, &add)); - } + MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; @@ -656,9 +991,9 @@ TEST(GpuCommandBufferTest, ConditionalFor) { DeviceMemory b = executor->AllocateArray(length, 0); // Set loop counter to 100 to check that command buffer resets it. - stream.ThenMemset32(&loop_counter, 100, sizeof(int32_t)); - stream.ThenMemset32(&a, 1, byte_length); - stream.ThenMemZero(&b, byte_length); + TF_ASSERT_OK(stream.Memset32(&loop_counter, 100, sizeof(int32_t))); + TF_ASSERT_OK(stream.Memset32(&a, 1, byte_length)); + TF_ASSERT_OK(stream.MemZero(&b, byte_length)); // Loop body: b = a + b CommandBuffer::Builder body_builder = [&](CommandBuffer* body_cmd) { @@ -669,45 +1004,41 @@ TEST(GpuCommandBufferTest, ConditionalFor) { // Create a command buffer with a single conditional operation. auto cmd_buffer = CommandBuffer::Create(executor).value(); - TF_ASSERT_OK(cmd_buffer.For(executor, num_iters, loop_counter, body_builder)); - TF_ASSERT_OK(cmd_buffer.Finalize()); + TF_ASSERT_OK( + cmd_buffer->For(executor, num_iters, loop_counter, body_builder)); + TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); // Copy `b` data back to host. std::vector dst(4, 42); - stream.ThenMemcpy(dst.data(), b, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), b, byte_length)); std::vector expected = {10, 10, 10, 10}; ASSERT_EQ(dst, expected); } TEST(GpuCommandBufferTest, ConditionalWhile) { - Platform* platform = GpuPlatform(); - if (!CommandBuffer::SupportsConditionalCommands(platform)) { + if (!IsAtLeastCuda12300()) { GTEST_SKIP() << "CUDA graph conditionals are not supported"; } + Platform* platform = GpuPlatform(); StreamExecutor* executor = platform->ExecutorForDevice(0).value(); Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); + TF_ASSERT_OK(stream.Initialize()); - AddI32Kernel add(executor); - IncAndCmpKernel inc_and_cmp(executor); - - { // Load addition kernel. - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); - TF_ASSERT_OK(executor->GetKernel(spec, &add)); - } + // Load addition kernel. + MultiKernelLoaderSpec add_spec(/*arity=*/3); + add_spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, add_spec)); - { // Load inc_and_cmp kernel. - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(internal::GetIncAndCmpKernel(), "inc_and_cmp"); - TF_ASSERT_OK(executor->GetKernel(spec, &inc_and_cmp)); - } + // Load inc_and_cmp kernel. + MultiKernelLoaderSpec icmp_spec(/*arity=*/3); + icmp_spec.AddInProcessSymbol(internal::GetIncAndCmpKernel(), "inc_and_cmp"); + TF_ASSERT_OK_AND_ASSIGN(auto inc_and_cmp, + IncAndCmpKernel::Create(executor, icmp_spec)); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; @@ -721,10 +1052,10 @@ TEST(GpuCommandBufferTest, ConditionalWhile) { DeviceMemory b = executor->AllocateArray(length, 0); static constexpr bool kFalse = false; - stream.ThenMemcpy(&pred, &kFalse, 1); - stream.ThenMemset32(&loop_counter, 0, sizeof(int32_t)); - stream.ThenMemset32(&a, 1, byte_length); - stream.ThenMemZero(&b, byte_length); + TF_ASSERT_OK(stream.Memcpy(&pred, &kFalse, 1)); + TF_ASSERT_OK(stream.Memset32(&loop_counter, 0, sizeof(int32_t))); + TF_ASSERT_OK(stream.Memset32(&a, 1, byte_length)); + TF_ASSERT_OK(stream.MemZero(&b, byte_length)); int32_t num_iters = 10; @@ -741,19 +1072,203 @@ TEST(GpuCommandBufferTest, ConditionalWhile) { // Create a command buffer with a single conditional operation. auto cmd_buffer = CommandBuffer::Create(executor).value(); - TF_ASSERT_OK(cmd_buffer.While(executor, pred, cond_builder, body_builder)); - TF_ASSERT_OK(cmd_buffer.Finalize()); + TF_ASSERT_OK(cmd_buffer->While(executor, pred, cond_builder, body_builder)); + TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); // Copy `b` data back to host. std::vector dst(4, 42); - stream.ThenMemcpy(dst.data(), b, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), b, byte_length)); std::vector expected = {10, 10, 10, 10}; ASSERT_EQ(dst, expected); } +TEST(GpuCommandBufferTest, ConditionalIfInExecutionScope) { + if (!IsAtLeastCuda12300()) { + GTEST_SKIP() << "CUDA graph conditionals are not supported"; + } + + Platform* platform = GpuPlatform(); + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + Stream stream(executor); + TF_ASSERT_OK(stream.Initialize()); + + CommandBuffer::ExecutionScopeId s0 = CommandBuffer::ExecutionScopeId(0); + CommandBuffer::ExecutionScopeId s1 = CommandBuffer::ExecutionScopeId(1); + + DeviceMemory pred = executor->AllocateArray(1, 0); + + constexpr bool kTrue = true; + TF_ASSERT_OK(stream.Memcpy(&pred, &kTrue, 1)); + + // Allocate device buffers for memset operations. + std::vector> buffers; + for (size_t i = 0; i < 3; ++i) { + buffers.push_back(executor->AllocateArray(1, 0)); + } + + // Transfer buffers back to host. + auto transfer_buffers = [&]() -> std::vector { + std::vector dst(buffers.size(), 0); + for (size_t i = 0; i < buffers.size(); ++i) { + stream.Memcpy(dst.data() + i, buffers[i], sizeof(int32_t)).IgnoreError(); + } + return dst; + }; + + auto record = [&](CommandBuffer* cmd_buffer, uint32_t bit_pattern) { + // Record memsets in execution scope #0 + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s0, &buffers[0], bit_pattern + 0, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s0, &buffers[1], bit_pattern + 1, 1)); + + // Record If in execution scope #1 + TF_RETURN_IF_ERROR( + cmd_buffer->If(s1, executor, pred, [&](CommandBuffer* then_cmd) { + return then_cmd->Memset(&buffers[2], bit_pattern + 2, 1); + })); + + // Create a barrier in execution scope #0. + TF_RETURN_IF_ERROR(cmd_buffer->Barrier(executor, s0)); + + // Create a barrier between two execution scopes. + TF_RETURN_IF_ERROR(cmd_buffer->Barrier(executor, {s0, s1})); + + return cmd_buffer->Finalize(); + }; + + // Create a command buffer with a DAG of memset commands. + auto cmd_buffer = CommandBuffer::Create(executor).value(); + TF_ASSERT_OK(record(cmd_buffer.get(), 42)); + TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); + + std::vector expected = {42, 43, 44}; + ASSERT_EQ(transfer_buffers(), expected); + + // Check the command buffer structure. + GpuCommandBuffer* gpu_cmd_buffer = GpuCommandBuffer::Cast(cmd_buffer.get()); + + auto nodes0 = gpu_cmd_buffer->nodes(s0); + auto nodes1 = gpu_cmd_buffer->nodes(s1); + auto barriers0 = gpu_cmd_buffer->barriers(s0); + auto barriers1 = gpu_cmd_buffer->barriers(s1); + + ASSERT_EQ(nodes0.size(), 2); + ASSERT_EQ(nodes1.size(), 2); + ASSERT_EQ(barriers0.size(), 3); + ASSERT_EQ(barriers1.size(), 3); + + EXPECT_EQ(Deps(barriers0[0]), ExpectedDeps(nodes0[0], nodes0[1])); + EXPECT_EQ(barriers0[0].handle, barriers0[1].handle); + + EXPECT_EQ(barriers1[0].handle, nodes1[0].handle); + EXPECT_EQ(barriers1[1].handle, nodes1[1].handle); + + // s0 and s1 share broadcasted barrier. + EXPECT_TRUE(barriers0[2].handle == barriers1[2].handle); + EXPECT_EQ(Deps(barriers0[2]), ExpectedDeps(barriers0[1], nodes1[1])); + + // TODO(b/326284532): Add a test for bit pattern update. + + // Disable conditional branch. + constexpr bool kFalse = false; + TF_ASSERT_OK(stream.Memcpy(&pred, &kFalse, 1)); + TF_ASSERT_OK(stream.MemZero(&buffers[2], sizeof(int32_t))); + TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); + + expected = {42, 43, 0}; + ASSERT_EQ(transfer_buffers(), expected); +} + +TEST(GpuCommandBufferTest, ConditionalWhileInExecutionScope) { + if (!IsAtLeastCuda12300()) { + GTEST_SKIP() << "CUDA graph conditionals are not supported"; + } + + Platform* platform = GpuPlatform(); + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + Stream stream(executor); + TF_ASSERT_OK(stream.Initialize()); + + CommandBuffer::ExecutionScopeId s0 = CommandBuffer::ExecutionScopeId(0); + CommandBuffer::ExecutionScopeId s1 = CommandBuffer::ExecutionScopeId(1); + + // Load addition kernel. + MultiKernelLoaderSpec add_spec(/*arity=*/3); + add_spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, add_spec)); + + // Load inc_and_cmp kernel. + MultiKernelLoaderSpec icmp_spec(/*arity=*/3); + icmp_spec.AddInProcessSymbol(internal::GetIncAndCmpKernel(), "inc_and_cmp"); + TF_ASSERT_OK_AND_ASSIGN(auto inc_and_cmp, + IncAndCmpKernel::Create(executor, icmp_spec)); + + DeviceMemory pred = executor->AllocateArray(1, 0); + DeviceMemory loop_counter = executor->AllocateArray(1, 0); + DeviceMemory a = executor->AllocateArray(1, 0); + DeviceMemory b = executor->AllocateArray(1, 0); + DeviceMemory c = executor->AllocateArray(1, 0); + + TF_ASSERT_OK(stream.MemZero(&loop_counter, sizeof(int32_t))); + TF_ASSERT_OK(stream.Memset32(&a, 1, sizeof(int32_t))); + TF_ASSERT_OK(stream.MemZero(&b, sizeof(int32_t))); + + auto record = [&](CommandBuffer* cmd_buffer, uint32_t bit_pattern, + int32_t num_iters) { + // Record memset in execution scope #0 + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s0, &c, bit_pattern, 1)); + + // Record While in execution scope #1 + TF_RETURN_IF_ERROR(cmd_buffer->While( + s1, executor, pred, + // Loop cond: loop_counter++ < num_iters; + [&](CommandBuffer* cond_cmd) { + return cond_cmd->Launch(inc_and_cmp, ThreadDim(), BlockDim(), + loop_counter, pred, num_iters); + }, + // Loop body: b = a + b + [&](CommandBuffer* body_cmd) { + return body_cmd->Launch(add, ThreadDim(), BlockDim(), a, b, b); + })); + + // Create a barrier between two execution scopes. + TF_RETURN_IF_ERROR(cmd_buffer->Barrier(executor, {s0, s1})); + + return cmd_buffer->Finalize(); + }; + + // Create a command buffer with a single conditional operation. + auto cmd_buffer = CommandBuffer::Create(executor).value(); + TF_ASSERT_OK(record(cmd_buffer.get(), 42, 10)); + TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); + + // Copy `b` and `c` data back to host. + int32_t b_dst, c_dst; + stream.ThenMemcpy(&b_dst, b, sizeof(int32_t)); + stream.ThenMemcpy(&c_dst, c, sizeof(int32_t)); + + EXPECT_EQ(b_dst, 10); + EXPECT_EQ(c_dst, 42); + + // Update bit pattern and number of iterations + TF_ASSERT_OK(cmd_buffer->Update()); + TF_ASSERT_OK(record(cmd_buffer.get(), 43, 20)); + + TF_ASSERT_OK(stream.MemZero(&loop_counter, sizeof(int32_t))); + TF_ASSERT_OK(stream.MemZero(&b, sizeof(int32_t))); + TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); + + stream.ThenMemcpy(&b_dst, b, sizeof(int32_t)); + stream.ThenMemcpy(&c_dst, c, sizeof(int32_t)); + + EXPECT_EQ(b_dst, 20); + EXPECT_EQ(c_dst, 43); +} + //===----------------------------------------------------------------------===// // Performance benchmarks below //===----------------------------------------------------------------------===// @@ -767,19 +1282,18 @@ static void BM_CreateCommandBuffer(benchmark::State& state) { Platform* platform = GpuPlatform(); StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - MultiKernelLoaderSpec spec = GetAddI32KernelSpec(); - - AddI32Kernel add(executor); - CHECK_OK(executor->GetKernel(spec, &add)); + MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); DeviceMemory b = executor->AllocateArray(1, 0); for (auto s : state) { auto cmd_buffer = CommandBuffer::Create(executor, nested).value(); for (int i = 1; i < state.range(0); ++i) { - CHECK_OK(cmd_buffer.Launch(add, ThreadDim(), BlockDim(4), b, b, b)); + CHECK_OK(cmd_buffer->Launch(add, ThreadDim(), BlockDim(4), b, b, b)); } - CHECK_OK(cmd_buffer.Finalize()); + CHECK_OK(cmd_buffer->Finalize()); } } @@ -790,13 +1304,11 @@ static void BM_TraceCommandBuffer(benchmark::State& state) { StreamExecutor* executor = platform->ExecutorForDevice(0).value(); Stream stream(executor); - stream.Init(); - CHECK(stream.ok()); + TF_CHECK_OK(stream.Initialize()); - MultiKernelLoaderSpec spec = GetAddI32KernelSpec(); - - AddI32Kernel add(executor); - CHECK_OK(executor->GetKernel(spec, &add)); + MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); DeviceMemory b = executor->AllocateArray(1, 0); @@ -818,25 +1330,24 @@ static void BM_UpdateCommandBuffer(benchmark::State& state) { Platform* platform = GpuPlatform(); StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - MultiKernelLoaderSpec spec = GetAddI32KernelSpec(); - - AddI32Kernel add(executor); - CHECK_OK(executor->GetKernel(spec, &add)); + MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); DeviceMemory b = executor->AllocateArray(1, 0); auto cmd_buffer = CommandBuffer::Create(executor, primary).value(); for (int i = 1; i < state.range(0); ++i) { - CHECK_OK(cmd_buffer.Launch(add, ThreadDim(), BlockDim(4), b, b, b)); + CHECK_OK(cmd_buffer->Launch(add, ThreadDim(), BlockDim(4), b, b, b)); } - CHECK_OK(cmd_buffer.Finalize()); + CHECK_OK(cmd_buffer->Finalize()); for (auto s : state) { - CHECK_OK(cmd_buffer.Update()); + CHECK_OK(cmd_buffer->Update()); for (int i = 1; i < state.range(0); ++i) { - CHECK_OK(cmd_buffer.Launch(add, ThreadDim(), BlockDim(4), b, b, b)); + CHECK_OK(cmd_buffer->Launch(add, ThreadDim(), BlockDim(4), b, b, b)); } - CHECK_OK(cmd_buffer.Finalize()); + CHECK_OK(cmd_buffer->Finalize()); } } diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc b/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc index 4247fdee450b24..a97ebdd56b3e77 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc @@ -15,6 +15,8 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_cudamallocasync_allocator.h" +#include +#include #include #include #include @@ -28,13 +30,13 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" -#include "xla/stream_executor/gpu/gpu_init.h" -#include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/gpu/gpu_init.h" // IWYU pragma: keep +#include "xla/stream_executor/stream_executor.h" // IWYU pragma: keep #include "tsl/framework/allocator.h" #include "tsl/framework/device_id.h" #include "tsl/platform/logging.h" #include "tsl/platform/mutex.h" -#include "tsl/util/env_var.h" +#include "tsl/util/env_var.h" // IWYU pragma: keep namespace stream_executor { @@ -441,8 +443,7 @@ void GpuCudaMallocAsyncAllocator::SetStreamAndPreallocateMemory(void* stream) { void* ptr = AllocateRaw(0, prealloc_size); DeallocateRaw(ptr); VLOG(2) << Name() << " GpuCudaMallocAsyncAllocator reserved the pool for " - << prealloc_size << " bytes" - << ". First ptr: " << ptr; + << prealloc_size << " bytes" << ". First ptr: " << ptr; ClearStats(); } #endif diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.h b/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.h index 4fcf31f4d3bd48..7e9d274163228b 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.h @@ -16,16 +16,17 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_GPU_GPU_CUDAMALLOCASYNC_ALLOCATOR_H_ #define XLA_STREAM_EXECUTOR_GPU_GPU_CUDAMALLOCASYNC_ALLOCATOR_H_ +#include +#include #include #include #include #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" -#include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/stream_executor.h" // IWYU pragma: keep #include "tsl/framework/allocator.h" #include "tsl/framework/device_id.h" -#include "tsl/platform/macros.h" #include "tsl/platform/mutex.h" #if GOOGLE_CUDA @@ -34,7 +35,6 @@ limitations under the License. #define TF_CUDA_MALLOC_ASYNC_SUPPORTED CUDA_VERSION >= 11020 #endif // GOOGLE_CUDA - namespace stream_executor { // An allocator that wraps cudaMallocAsync. It has fewer fragmentation diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_diagnostics.h b/third_party/xla/xla/stream_executor/gpu/gpu_diagnostics.h index 128f9b703c60d0..678a34e50a40ed 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_diagnostics.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_diagnostics.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_GPU_GPU_DIAGNOSTICS_H_ #define XLA_STREAM_EXECUTOR_GPU_GPU_DIAGNOSTICS_H_ +#include #include #include "absl/status/statusor.h" -#include "xla/stream_executor/platform/port.h" namespace stream_executor { namespace gpu { diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h index e9e3bafc323f29..9046cbdbb65775 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h @@ -24,6 +24,7 @@ limitations under the License. #include #include #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -33,6 +34,10 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/platform.h" +#ifdef GOOGLE_CUDA +#include "third_party/gpus/cuda/include/cuda.h" +#endif + namespace stream_executor { namespace gpu { @@ -171,7 +176,7 @@ class GpuDriver { GpuContext* context, stream_executor::StreamPriority stream_priority); // Virtual memory support was added to CUDA in 10.2 -#if CUDA_VERSION >= 10020 +#if defined(GOOGLE_CUDA) && CUDA_VERSION >= 10020 // Reserves a range of virtual device memory addresses via // cuMemAddressReserve. bytes must be a multiple of the host page size. @@ -226,7 +231,7 @@ class GpuDriver { // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html#group__CUDA__VA_1gfb50aac00c848fd7087e858f59bf7e2a static void UnmapMemory(GpuContext* context, GpuDevicePtr va, uint64_t bytes); -#endif // CUDA_VERSION >= 10200 +#endif // defined(GOOGLE_CUDA) && CUDA_VERSION >= 10020 // Given a device ordinal, returns a device handle into the device outparam, // which must not be null. @@ -433,6 +438,12 @@ class GpuDriver { static absl::StatusOr GraphNodeGetType( GpuGraphNodeHandle node); + // Returns a node's dependencies. + // + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g048f4c0babcbba64a933fc277cd45083 + static absl::StatusOr> + GraphNodeGetDependencies(GpuGraphNodeHandle node); + // Destroys an executable graph. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1ga32ad4944cc5d408158207c978bc43a7 // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#graph-management @@ -489,20 +500,21 @@ class GpuDriver { // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g4210c258cbba352040a26d1b4e658f9d static absl::StatusOr GraphAddNode( GpuGraphNodeHandle* node, GpuGraphHandle graph, - absl::Span deps, const GpuGraphNodeParams& params); + absl::Span deps, + const GpuGraphNodeParams& params); // Creates an empty node and adds it to a graph. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g14b625984430cb2d574c63f29c9b9223 - static absl::Status GraphAddEmptyNode(GpuGraphNodeHandle* node, - GpuGraphHandle graph, - absl::Span deps); + static absl::Status GraphAddEmptyNode( + GpuGraphNodeHandle* node, GpuGraphHandle graph, + absl::Span deps); // Creates a kernel execution node and adds it to a graph. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g50d871e3bd06c1b835e52f2966ef366b // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#graph-management static absl::Status GraphAddKernelNode( GpuGraphNodeHandle* node, GpuGraphHandle graph, - absl::Span deps, absl::string_view kernel_name, + absl::Span deps, absl::string_view kernel_name, GpuFunctionHandle function, unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z, unsigned int block_dim_x, unsigned int block_dim_y, @@ -549,7 +561,7 @@ class GpuDriver { // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g73a351cb71b2945a0bcb913a93f69ec9 static absl::Status GraphAddMemAllocNode( GpuGraphNodeHandle* node, GpuGraphHandle graph, - absl::Span deps, MemAccessFlags access_flags, + absl::Span deps, MemAccessFlags access_flags, MemLocationType location_type, int device_id, MemAllocationType allocation_type, uint64_t size, GpuDevicePtr* d_ptr, uint64_t max_pool_size = 0); @@ -561,16 +573,15 @@ class GpuDriver { // Create a memfree node and adds it to a graph. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1geb7cdce5d9be2d28d9428e74eb00fa53 - static absl::Status GraphAddMemFreeNode(GpuGraphNodeHandle* node, - GpuGraphHandle graph, - absl::Span deps, - GpuDevicePtr gpu_dst); + static absl::Status GraphAddMemFreeNode( + GpuGraphNodeHandle* node, GpuGraphHandle graph, + absl::Span deps, GpuDevicePtr gpu_dst); // Creates a memcpy node and adds it to a graph. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g674da6ab54a677f13e0e0e8206ff5073 static absl::Status GraphAddMemcpyD2DNode( GpuContext* context, GpuGraphNodeHandle* node, GpuGraphHandle graph, - absl::Span deps, GpuDevicePtr gpu_dst, + absl::Span deps, GpuDevicePtr gpu_dst, GpuDevicePtr gpu_src, uint64_t size); // Sets the parameters for a memcpy node in the given graphExec. @@ -583,7 +594,7 @@ class GpuDriver { // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g89dc8fc3743392777c0daa2c4aca40d3 static absl::Status GraphAddMemsetNode( GpuContext* context, GpuGraphNodeHandle* node, GpuGraphHandle graph, - absl::Span deps, GpuDevicePtr dst, + absl::Span deps, GpuDevicePtr dst, std::variant bit_pattern, uint64_t num_elements); @@ -596,10 +607,9 @@ class GpuDriver { // Creates a child graph node and adds it to a graph. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1gde52afbcf91a8c79d4d7efbe0e3b6844 - static absl::Status GraphAddChildNode(GpuGraphNodeHandle* node, - GpuGraphHandle graph, - absl::Span deps, - GpuGraphHandle child); + static absl::Status GraphAddChildNode( + GpuGraphNodeHandle* node, GpuGraphHandle graph, + absl::Span deps, GpuGraphHandle child); // Sets the parameters for a child graph node in the given graph exec. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g8f2d9893f6b899f992db1a2942ec03ff diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_event.cc b/third_party/xla/xla/stream_executor/gpu/gpu_event.cc index 6bc79d822587b8..bc714a519343c2 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_event.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_event.cc @@ -16,8 +16,10 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_event.h" #include "absl/status/status.h" +#include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_stream.h" +#include "xla/stream_executor/gpu/gpu_types.h" namespace stream_executor { namespace gpu { diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_event.h b/third_party/xla/xla/stream_executor/gpu/gpu_event.h index 7eb543089e68a6..2c8b588dab76cb 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_event.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_event.h @@ -18,8 +18,8 @@ limitations under the License. #include "absl/status/status.h" #include "xla/stream_executor/event.h" -#include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_stream.h" +#include "xla/stream_executor/gpu/gpu_types.h" namespace stream_executor { namespace gpu { diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h index a728c2775bdc35..b05fa774407abc 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h @@ -22,29 +22,41 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_GPU_GPU_EXECUTOR_H_ #define XLA_STREAM_EXECUTOR_GPU_GPU_EXECUTOR_H_ +#include #include +#include #include #include #include #include #include +#include #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/functional/any_invocable.h" +#include "absl/numeric/int128.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_options.h" +#include "xla/stream_executor/dnn.h" #include "xla/stream_executor/event.h" +#include "xla/stream_executor/fft.h" #include "xla/stream_executor/gpu/gpu_collectives.h" -#include "xla/stream_executor/gpu/gpu_kernel.h" +#include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_types.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/module_spec.h" #include "xla/stream_executor/platform.h" -#include "xla/stream_executor/platform/port.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_internal.h" +#include "tsl/platform/thread_annotations.h" namespace stream_executor { @@ -52,6 +64,9 @@ class StreamExecutor; namespace gpu { +class GpuKernel; +class GpuCommandBuffer; + // CUDA-platform implementation of the platform-agnostic // StreamExecutorInterface. class GpuExecutor : public internal::StreamExecutorInterface { @@ -265,20 +280,18 @@ class GpuExecutor : public internal::StreamExecutorInterface { std::unique_ptr CreateEventImplementation() override; - std::unique_ptr CreateKernelImplementation() - override; - std::unique_ptr GetStreamImplementation() override; - absl::StatusOr> - GetCommandBufferImplementation(CommandBuffer::Mode mode) override; + absl::StatusOr> CreateKernel() override; + + absl::StatusOr> CreateCommandBuffer( + CommandBuffer::Mode mode) override; // Wraps existing Gpu graph handle into an instance of Gpu command buffer. // This is required for wrapping nested graphs constructed for conditional // nodes and owned by a parent graph executable. - std::unique_ptr - GetCommandBufferImplementation(CommandBuffer::Mode mode, GpuGraphHandle graph, - bool is_owned_graph); + std::unique_ptr CreateCommandBuffer( + CommandBuffer::Mode mode, GpuGraphHandle graph, bool is_owned_graph); void* platform_specific_context() override; @@ -324,7 +337,8 @@ class GpuExecutor : public internal::StreamExecutorInterface { // Prints to VLOG(2) information about the kernel's occupancy and how it might // be improved. - void VlogOccupancyInfo(const Kernel& kernel, const ThreadDim& thread_dims, + void VlogOccupancyInfo(const DeviceDescription& device_description, + const Kernel& kernel, const ThreadDim& thread_dims, const BlockDim& block_dims); // (supported on CUDA only) diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_graph.cc b/third_party/xla/xla/stream_executor/gpu/gpu_graph.cc deleted file mode 100644 index daf719c90a0b1f..00000000000000 --- a/third_party/xla/xla/stream_executor/gpu/gpu_graph.cc +++ /dev/null @@ -1,319 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/stream_executor/gpu/gpu_graph.h" - -#include -#include -#include -#include -#include -#include - -#include "absl/functional/any_invocable.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "xla/stream_executor/gpu/gpu_driver.h" -#include "xla/stream_executor/gpu/gpu_executor.h" -#include "xla/stream_executor/gpu/gpu_kernel.h" -#include "xla/stream_executor/gpu/gpu_stream.h" -#include "xla/stream_executor/gpu/gpu_types.h" -#include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/stream.h" -#include "xla/stream_executor/stream_executor.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/path.h" -#include "tsl/platform/statusor.h" - -namespace stream_executor { -namespace gpu { - -//===----------------------------------------------------------------------===// -// RAII helpers for gpu graph types. -//===----------------------------------------------------------------------===// - -std::atomic GpuGraphSupport::allocated_gpu_graph_execs_; -std::atomic GpuGraphSupport::alive_gpu_graph_execs_; - -/*static*/ void GpuGraphSupport::TrimDeviceMemory(StreamExecutor* executor) { - auto* gpu_executor = ExtractGpuExecutor(executor); - auto st = GpuDriver::DeviceGraphMemTrim(gpu_executor->device()); - if (!st.ok()) { - LOG(ERROR) << "Failed to trim Gpu device graph memory: " << st.message(); - } -} - -/*static*/ size_t GpuGraphSupport::NotifyGraphExecCreated() { - alive_gpu_graph_execs_.fetch_add(1, std::memory_order_relaxed); - return allocated_gpu_graph_execs_.fetch_add(1, std::memory_order_relaxed); -} - -/*static*/ size_t GpuGraphSupport::NotifyGraphExecDestroyed() { - return alive_gpu_graph_execs_.fetch_sub(1, std::memory_order_relaxed) - 1; -} - -/*static*/ size_t GpuGraphSupport::allocated_gpu_graph_execs() { - return allocated_gpu_graph_execs_.load(std::memory_order_relaxed); -} - -/*static*/ size_t GpuGraphSupport::alive_gpu_graph_execs() { - return alive_gpu_graph_execs_.load(std::memory_order_relaxed); -} - -void GpuGraphSupport::DestroyGraph::operator()(GpuGraphHandle graph) { - auto st = GpuDriver::DestroyGraph(graph); - CHECK(st.ok()) << "Failed to destroy gpu graph: " << st.message(); -} - -void GpuGraphSupport::DestroyGraphExec::operator()(GpuGraphExecHandle exec) { - auto st = GpuDriver::DestroyGraphExec(exec); - CHECK(st.ok()) << "Failed to destroy executable gpu graph: " << st.message(); -} - -absl::StatusOr GraphExecUpdateResultToString( - GpuDriver::GraphExecUpdateResult result) { - switch (result) { - case GpuDriver::GraphExecUpdateResult::kSuccess: - return "kSuccess"; - case GpuDriver::GraphExecUpdateResult::kError: - return "kFailure"; - case GpuDriver::GraphExecUpdateResult::kTopologyChanged: - return "kTopologyChanged"; - case GpuDriver::GraphExecUpdateResult::kAttributesChanged: - return "kAttributesChanged"; - case GpuDriver::GraphExecUpdateResult::kFunctionChanged: - return "kFunctionChanged"; - case GpuDriver::GraphExecUpdateResult::kParametersChanged: - return "kParametersChanged"; - case GpuDriver::GraphExecUpdateResult::kUnsupportedFunctionChange: - return "kUnsupportedFunctionChange"; - case GpuDriver::GraphExecUpdateResult::kNodeTypeChanged: - return "kNodeTypeChanged"; - case GpuDriver::GraphExecUpdateResult::kNotSupported: - return "kNotSupported"; - } - return absl::InternalError("Unexpected value for GraphExecUpdateResult"); -} - -absl::StatusOr GraphNodeTypeToString( - GpuDriver::GraphNodeType node_type) { - switch (node_type) { - case GpuDriver::GraphNodeType::kKernel: - return "kKernel"; - case GpuDriver::GraphNodeType::kMemcpy: - return "kMemcpy"; - case GpuDriver::GraphNodeType::kMemset: - return "kMemset"; - case GpuDriver::GraphNodeType::kHost: - return "kHost"; - case GpuDriver::GraphNodeType::kGraph: - return "kGraph"; - case GpuDriver::GraphNodeType::kEmpty: - return "kEmpty"; - case GpuDriver::GraphNodeType::kWaitEvent: - return "kWaitEvent"; - case GpuDriver::GraphNodeType::kEventRecord: - return "kEventRecord"; - case GpuDriver::GraphNodeType::kExtSemasSignal: - return "kExtSemasSignal"; - case GpuDriver::GraphNodeType::kExtSemasWait: - return "kExtSemasWait"; - case GpuDriver::GraphNodeType::kMemAlloc: - return "kMemAlloc"; - case GpuDriver::GraphNodeType::kMemFree: - return "kMemFree"; - case GpuDriver::GraphNodeType::kBatchMemOp: - return "kBatchMemOp"; - } - return absl::InternalError("Unexpected value for GraphNodeType"); -} - -absl::Status OwnedGpuGraphExec::Update(OwnedGpuGraph graph) { - VLOG(3) << "Update gpu graph exec with a new graph after " << num_launches_ - << " launches since last update" - << " #" << num_updates_++; - - num_launches_ = 0; - - uint64_t start_nanos = tsl::Env::Default()->NowNanos(); - GpuDriver::GraphExecUpdateResultInfo result; - memset(&result, 0, sizeof(result)); - auto st = GpuDriver::GraphExecUpdate(get(), graph.get(), &result); - uint64_t end_nanos = tsl::Env::Default()->NowNanos(); - - if (!st.ok()) { - TF_ASSIGN_OR_RETURN(std::string result_str, - GraphExecUpdateResultToString(result.result)); - std::string error_message = absl::StrCat( - "Failed to update gpu graph: Graph update result=", result_str); - - if (result.error_node) { - TF_ASSIGN_OR_RETURN(GpuDriver::GraphNodeType node_type, - GpuDriver::GraphNodeGetType(result.error_node)); - TF_ASSIGN_OR_RETURN(std::string node_type_str, - GraphNodeTypeToString(node_type)); - absl::StrAppend(&error_message, ", Error node name=", node_type_str); - } - - if (result.error_from_node) { - TF_ASSIGN_OR_RETURN(GpuDriver::GraphNodeType node_type, - GpuDriver::GraphNodeGetType(result.error_from_node)); - TF_ASSIGN_OR_RETURN(std::string node_type_str, - GraphNodeTypeToString(node_type)); - absl::StrAppend(&error_message, ", Error from node name=", node_type_str); - } - - absl::StrAppend(&error_message, ": ", st.message()); - return absl::InternalError(error_message); - } - - VLOG(5) << "Updated gpu graph exec #" << id_ << " (took " - << (end_nanos - start_nanos) / 1000 << " us)"; - - return absl::OkStatus(); -} - -absl::Status OwnedGpuGraphExec::Launch(stream_executor::Stream* stream) { - VLOG(3) << "Launch gpu graph " << get() - << " on a stream: " << stream->DebugStreamPointers() << " #" - << ++num_launches_; - - return GpuDriver::GraphLaunch(get(), AsGpuStreamValue(stream)); -} - -OwnedGpuGraphExec::~OwnedGpuGraphExec() { - if (*this) // do not log for moved-from instances - VLOG(5) << "Destroy GPU graph exec #" << id_ - << " (remaining alive instances: " - << GpuGraphSupport::NotifyGraphExecDestroyed() << ")"; -} - -//===----------------------------------------------------------------------===// -// GPU Graph Helpers. -//===----------------------------------------------------------------------===// - -absl::StatusOr CreateGpuGraph() { - GpuGraphHandle graph; - TF_RETURN_IF_ERROR(GpuDriver::CreateGraph(&graph)); - return OwnedGpuGraph(graph); -} - -absl::StatusOr AddKernelNode( - GpuGraphHandle graph, absl::Span deps, - ThreadDim threads, BlockDim blocks, const Kernel& kernel, - const KernelArgs& args) { - const GpuKernel* gpu_kernel = AsGpuKernel(&kernel); - GpuFunctionHandle gpu_func = gpu_kernel->AsGpuFunctionHandle(); - - auto* packed_args = DynCast(&args); - if (!packed_args) - return absl::InternalError("Unsupported kernel arguments type"); - - void** kernel_params = - const_cast(packed_args->argument_addresses().data()); - - GpuGraphNodeHandle node; - TF_RETURN_IF_ERROR(GpuDriver::GraphAddKernelNode( - &node, graph, deps, kernel.name(), gpu_func, blocks.x, blocks.y, blocks.z, - threads.x, threads.y, threads.z, args.number_of_shared_bytes(), - kernel_params, /*extra=*/nullptr)); - - return node; -} - -static GpuDevicePtr AsDevicePtr(const DeviceMemoryBase& mem) { - return reinterpret_cast(const_cast(mem.opaque())); -} - -absl::StatusOr AddMemcpyD2DNode( - GpuContext* context, GpuGraphHandle graph, - absl::Span deps, const DeviceMemoryBase& dst, - const DeviceMemoryBase& src) { - GpuGraphNodeHandle node; - TF_RETURN_IF_ERROR(GpuDriver::GraphAddMemcpyD2DNode( - context, &node, graph, deps, AsDevicePtr(dst), AsDevicePtr(src), - dst.size())); - return node; -} - -absl::StatusOr CaptureGpuGraph( - stream_executor::Stream* stream, - absl::AnyInvocable capture) { - VLOG(3) << "Capture gpu graph on a stream: " << stream->DebugStreamPointers(); - uint64_t start_nanos = tsl::Env::Default()->NowNanos(); - - GpuGraphHandle graph; - - // Get the underlying stream for passing to GPU runtime APIs. - auto gpu_stream = AsGpuStreamValue(stream); - - // Capture graph constructed by the exported graph capture function. - TF_RETURN_IF_ERROR(GpuDriver::StreamBeginCapture( - gpu_stream, GpuDriver::StreamCaptureMode::kThreadLocal)); - - // Call into graph capture function. - auto captured = capture(); - - // Always stop capturing the stream before checking `captured` result. - TF_RETURN_IF_ERROR(GpuDriver::StreamEndCapture(gpu_stream, &graph)); - - if (!captured.ok()) - return absl::InternalError( - absl::StrCat("failed to capture gpu graph: ", captured.message())); - - uint64_t end_nanos = tsl::Env::Default()->NowNanos(); - VLOG(5) << "Captured XLA:GPU operations into the graph " << graph << " (took " - << (end_nanos - start_nanos) / 1000 << " us)"; - - if (const char* path = getenv("XLA_GPU_GRAPH_DEBUG_DIRECTORY"); path) { - std::string file = tsl::io::JoinPath(std::string(path), "/gpu-graph-"); - - if (tsl::Env::Default()->CreateUniqueFileName(&file, ".dot")) { - VLOG(100) << "Print gpu graph " << graph - << " debug dot file to: " << file; - auto printed = GpuDriver::GraphDebugDotPrint(graph, file.c_str()); - printed.IgnoreError(); // warning will be printed by GpuDriver - } else { - LOG(WARNING) << "Cannot create unique filename, won't enable gpu " - "graph debugging"; - } - } - - return OwnedGpuGraph(graph); -} - -absl::StatusOr InstantiateGpuGraph(OwnedGpuGraph graph) { - GpuGraphExecHandle exec; - - uint64_t start_nanos = tsl::Env::Default()->NowNanos(); - GpuDriver::GraphInstantiateFlags flags; - TF_RETURN_IF_ERROR(GpuDriver::GraphInstantiate(&exec, graph.get(), flags)); - uint64_t end_nanos = tsl::Env::Default()->NowNanos(); - - size_t id = GpuGraphSupport::NotifyGraphExecCreated(); - VLOG(5) << "Instantiated gpu graph exec instance #" << id << " in " - << (end_nanos - start_nanos) / 1000 << " us (alive instances: " - << GpuGraphSupport::alive_gpu_graph_execs() << ")"; - return OwnedGpuGraphExec(id, exec); -} - -absl::StatusOr IsStreamCapturing(stream_executor::Stream* stream) { - return GpuDriver::StreamIsCapturing(AsGpuStreamValue(stream)); -} - -} // namespace gpu -} // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_graph.h b/third_party/xla/xla/stream_executor/gpu/gpu_graph.h deleted file mode 100644 index cada2064e9c8b1..00000000000000 --- a/third_party/xla/xla/stream_executor/gpu/gpu_graph.h +++ /dev/null @@ -1,142 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_STREAM_EXECUTOR_GPU_GPU_GRAPH_H_ -#define XLA_STREAM_EXECUTOR_GPU_GPU_GRAPH_H_ - -#include -#include -#include -#include -#include - -#include "absl/functional/any_invocable.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/types/span.h" -#include "xla/stream_executor/gpu/gpu_types.h" -#include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/launch_dim.h" -#include "xla/stream_executor/stream.h" -#include "xla/stream_executor/stream_executor.h" - -namespace stream_executor { -namespace gpu { - -// Forward declare. -class GpuContext; - -class GpuGraphSupport { - public: - // Deleters for gpu graph and graph exec instance that check the returned - // status and terminate on error. - struct DestroyGraph { - void operator()(GpuGraphHandle); - }; - struct DestroyGraphExec { - void operator()(GpuGraphExecHandle); - }; - - static size_t NotifyGraphExecCreated(); - static size_t NotifyGraphExecDestroyed(); - - static size_t allocated_gpu_graph_execs(); - static size_t alive_gpu_graph_execs(); - - static void TrimDeviceMemory(StreamExecutor* executor); - - private: - // Global counters for the total number of allocated and alive gpu graph - // execs to track the resource usage at run time. - static std::atomic allocated_gpu_graph_execs_; - static std::atomic alive_gpu_graph_execs_; -}; - -//===----------------------------------------------------------------------===// -// RAII helpers for gpu graph types. -//===----------------------------------------------------------------------===// - -class OwnedGpuGraph - : public std::unique_ptr, - GpuGraphSupport::DestroyGraph> { - // Bring std::unique_ptr constructors in scope. - using std::unique_ptr, - GpuGraphSupport::DestroyGraph>::unique_ptr; -}; - -class OwnedGpuGraphExec - : public std::unique_ptr, - GpuGraphSupport::DestroyGraphExec> { - using Base = std::unique_ptr, - GpuGraphSupport::DestroyGraphExec>; - - public: - OwnedGpuGraphExec(uint64_t id, GpuGraphExecHandle exec) - : Base(exec), id_(id) {} - ~OwnedGpuGraphExec(); - - OwnedGpuGraphExec(OwnedGpuGraphExec&&) = default; - OwnedGpuGraphExec& operator=(OwnedGpuGraphExec&&) = default; - - // Updates executable graph instance with a newly captured graph. Returns an - // error if the new graph is not compatible (see `cudaGraphExecUpdate`). - absl::Status Update(OwnedGpuGraph graph); - - // Launches captured graph on a given stream. - absl::Status Launch(stream_executor::Stream* stream); - - uint64_t id() const { return id_; } - - private: - uint64_t id_; - uint64_t num_updates_ = 0; - uint64_t num_launches_ = 0; -}; - -//===----------------------------------------------------------------------===// -// Gpu Graph Helpers. -//===----------------------------------------------------------------------===// - -// Creates new empty Gpu graph. -absl::StatusOr CreateGpuGraph(); - -// Adds a kernel node to the graph. -absl::StatusOr AddKernelNode( - GpuGraphHandle graph, absl::Span deps, - ThreadDim threads, BlockDim blocks, const Kernel& kernel, - const KernelArgs& args); - -// Adds a memory copy node to the graph. -absl::StatusOr AddMemcpyD2DNode( - GpuContext* context, GpuGraphHandle graph, - absl::Span deps, const DeviceMemoryBase& dst, - const DeviceMemoryBase& src); - -// Captures all operations added to a `stream` by the `capture` function into -// the gpu graph instance. -absl::StatusOr CaptureGpuGraph( - stream_executor::Stream* stream, - absl::AnyInvocable capture); - -// Instantiates a captured gpu graph instance into a gpu graph executable. -absl::StatusOr InstantiateGpuGraph(OwnedGpuGraph graph); - -// Returns true if the stream is in graph capture mode -absl::StatusOr IsStreamCapturing(stream_executor ::Stream* stream); - -} // namespace gpu -} // namespace stream_executor - -#endif // XLA_STREAM_EXECUTOR_GPU_GPU_GRAPH_H_ diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_helpers.h b/third_party/xla/xla/stream_executor/gpu/gpu_helpers.h index 16f18454ff68ad..c86f49140a5218 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_helpers.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_helpers.h @@ -24,6 +24,7 @@ limitations under the License. #include #include +#include #include "xla/stream_executor/gpu/gpu_types.h" #include "tsl/platform/logging.h" diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_init.cc b/third_party/xla/xla/stream_executor/gpu/gpu_init.cc index 9747deba245e88..a0f8e5919a5ea7 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_init.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_init.cc @@ -19,14 +19,14 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "tsl/platform/logging.h" namespace stream_executor { absl::Status ValidateGPUMachineManager() { - return MultiPlatformManager::PlatformWithName(GpuPlatformName()).status(); + return PlatformManager::PlatformWithName(GpuPlatformName()).status(); } Platform* GPUMachineManager() { @@ -34,7 +34,7 @@ Platform* GPUMachineManager() { // (and probably other things as well). static Platform* platform = [&] { absl::StatusOr p = - MultiPlatformManager::PlatformWithName(GpuPlatformName()); + PlatformManager::PlatformWithName(GpuPlatformName()); if (!p.ok()) { LOG(FATAL) << "Could not find Platform with name " << GpuPlatformName(); } diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h b/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h index 4cf8b79705694b..251093cf965141 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h @@ -29,23 +29,21 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/stream_executor/gpu/gpu_driver.h" +#include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" -#include "xla/stream_executor/stream_executor_internal.h" +#include "tsl/platform/logging.h" -namespace stream_executor { -namespace gpu { +namespace stream_executor::gpu { -// Wraps a GpuFunctionHandle to implement the platform-independent -// KernelInterface. -class GpuKernel : public internal::KernelInterface { +class GpuKernel : public Kernel { public: - GpuKernel() = default; + explicit GpuKernel(GpuExecutor* gpu_executor) : gpu_executor_(gpu_executor) {} // Note that the function is unloaded when the module is unloaded, and the // module that the function is contained in is owned by the GpuExecutor. - ~GpuKernel() override {} + ~GpuKernel() override { gpu_executor_->UnloadKernel(this); } // As arity cannot be reflected upon using the CUDA API, the arity is // explicitly set during the GpuExecutor::GetKernel initialization process. @@ -89,6 +87,7 @@ class GpuKernel : public internal::KernelInterface { ThreadDim threads, size_t dynamic_shared_memory_bytes) const override; private: + GpuExecutor* gpu_executor_ = nullptr; GpuContext* gpu_context_ = nullptr; // context where kernel is loaded std::string name_; // kernel name @@ -99,19 +98,14 @@ class GpuKernel : public internal::KernelInterface { KernelCacheConfig preferred_cache_config_ = KernelCacheConfig::kNoPreference; }; -// Given a platform-independent kernel datatype, returns the (const) internal -// CUDA platform implementation pointer. inline const GpuKernel* AsGpuKernel(const Kernel* kernel) { - return static_cast(kernel->implementation()); + return static_cast(kernel); } -// Given a platform-independent kernel datatype, returns the (non-const) -// internal CUDA platform implementation pointer. inline GpuKernel* AsGpuKernel(Kernel* kernel) { - return static_cast(kernel->implementation()); + return static_cast(kernel); } -} // namespace gpu -} // namespace stream_executor +} // namespace stream_executor::gpu #endif // XLA_STREAM_EXECUTOR_GPU_GPU_KERNEL_H_ diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_kernel_test.cc b/third_party/xla/xla/stream_executor/gpu/gpu_kernel_test.cc index d7597e310b7554..150d2307f08b95 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_kernel_test.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_kernel_test.cc @@ -17,13 +17,17 @@ limitations under the License. #include #include +#include "absl/strings/ascii.h" #include "xla/service/platform_util.h" #include "xla/stream_executor/gpu/gpu_test_kernels.h" +#include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace stream_executor::gpu { @@ -33,11 +37,11 @@ TEST(GpuKernelTest, Add) { DeviceMemory>; auto name = absl::AsciiStrToUpper( xla::PlatformUtil::CanonicalPlatformName("gpu").value()); - Platform* platform = MultiPlatformManager::PlatformWithName(name).value(); + Platform* platform = PlatformManager::PlatformWithName(name).value(); StreamExecutor* executor = platform->ExecutorForDevice(0).value(); Stream stream(executor); - stream.Init(); + TF_ASSERT_OK(stream.Initialize()); ASSERT_TRUE(stream.ok()); MultiKernelLoaderSpec spec(/*arity=*/3); @@ -48,8 +52,7 @@ TEST(GpuKernelTest, Add) { reinterpret_cast(&internal::kAddI32KernelModule[0]), "add"); #endif - AddI32Kernel add(executor); - ASSERT_TRUE(executor->GetKernel(spec, &add).ok()); + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; @@ -59,16 +62,16 @@ TEST(GpuKernelTest, Add) { DeviceMemory b = executor->AllocateArray(length, 0); DeviceMemory c = executor->AllocateArray(length, 0); - stream.ThenMemset32(&a, 1, byte_length); - stream.ThenMemset32(&b, 2, byte_length); - stream.ThenMemZero(&c, byte_length); + TF_ASSERT_OK(stream.Memset32(&a, 1, byte_length)); + TF_ASSERT_OK(stream.Memset32(&b, 2, byte_length)); + TF_ASSERT_OK(stream.MemZero(&c, byte_length)); // Launch kernel. ASSERT_TRUE(stream.ThenLaunch(ThreadDim(), BlockDim(4), add, a, b, c).ok()); // Copy data back to host. std::vector dst(4, 42); - stream.ThenMemcpy(dst.data(), c, byte_length); + TF_ASSERT_OK(stream.Memcpy(dst.data(), c, byte_length)); std::vector expected = {3, 3, 3, 3}; ASSERT_EQ(dst, expected); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_kernels.h b/third_party/xla/xla/stream_executor/gpu/gpu_kernels.h index 779864089a3f74..a2f14ec15f4514 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_kernels.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_kernels.h @@ -28,7 +28,7 @@ namespace stream_executor::gpu { // // Easiest way to get PTX from C++ is to use https://godbolt.org. inline constexpr std::string_view kNoOpKernel = R"( -.version 8.0 +.version 4.0 .target sm_50 .address_size 64 diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc b/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc index 571f583e3881da..a04f5410dbd613 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc @@ -17,8 +17,13 @@ limitations under the License. #include +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" +#include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_executor.h" +#include "xla/stream_executor/gpu/gpu_types.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" namespace stream_executor { diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_stream.h b/third_party/xla/xla/stream_executor/gpu/gpu_stream.h index 3c5fc6e19eb5da..166be822995879 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_stream.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_stream.h @@ -21,7 +21,9 @@ limitations under the License. #include +#include "absl/log/check.h" #include "xla/stream_executor/gpu/gpu_types.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor_internal.h" namespace stream_executor { diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels.h b/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels.h index 49be0a157c2df1..74931452bb6624 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels.h @@ -37,7 +37,7 @@ namespace stream_executor::gpu::internal { // // Easiest way to get PTX from C++ is to use https://godbolt.org. inline constexpr std::string_view kAddI32Kernel = R"( -.version 8.0 +.version 4.0 .target sm_50 .address_size 64 diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_timer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_timer.cc index cda8c5646f6c8b..7ce8e8092c6f4f 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_timer.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_timer.cc @@ -15,12 +15,15 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_timer.h" +#include #include #include #include #include "absl/base/const_init.h" #include "absl/base/thread_annotations.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" @@ -29,6 +32,8 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_stream.h" +#include "xla/stream_executor/gpu/gpu_types.h" +#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace stream_executor { diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_timer.h b/third_party/xla/xla/stream_executor/gpu/gpu_timer.h index d2a52b72be47a8..2851bd569def67 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_timer.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_timer.h @@ -21,9 +21,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/time/time.h" -#include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_executor.h" -#include "xla/stream_executor/stream_executor_internal.h" +#include "xla/stream_executor/gpu/gpu_types.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cc b/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cc index aa1643f5d987c1..5f08f84c2d86b5 100644 --- a/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cc +++ b/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cc @@ -15,26 +15,35 @@ limitations under the License. #include "xla/stream_executor/gpu/redzone_allocator.h" +#include #include #include +#include #include +#include +#include #include "absl/base/call_once.h" #include "absl/container/fixed_array.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" -#include "absl/types/optional.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/gpu/asm_compiler.h" #include "xla/stream_executor/gpu/gpu_asm_opts.h" #include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "tsl/framework/allocator.h" #include "tsl/lib/math/math_util.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace stream_executor { @@ -95,7 +104,7 @@ absl::StatusOr> RedzoneAllocator::AllocateBytes( // Split up the RHS redzone into two pieces: // - 0 to kRhsRedzoneAlign bytes adjacent to the user buffer, followed by // - redzone_size_ bytes. - // We do this because Stream::ThenMemset32 requires the buffer address and + // We do this because Stream::Memset32 requires the buffer address and // size to be aligned to 4 bytes. DeviceMemory rhs_redzone_slop = allocated_buffer_memory.GetSlice(redzone_size_ + byte_size, rhs_slop); @@ -107,11 +116,13 @@ absl::StatusOr> RedzoneAllocator::AllocateBytes( redzone_pattern_}; uint32_t pattern32; std::memcpy(&pattern32, pattern_arr, sizeof(pattern32)); - stream_->ThenMemset32(&lhs_redzone, pattern32, redzone_size_); + TF_RETURN_IF_ERROR(stream_->Memset32(&lhs_redzone, pattern32, redzone_size_)); if (rhs_slop != 0) { - stream_->ThenMemcpy(&rhs_redzone_slop, &pattern32, rhs_slop); + TF_RETURN_IF_ERROR( + stream_->Memcpy(&rhs_redzone_slop, &pattern32, rhs_slop)); } - stream_->ThenMemset32(&rhs_redzone_nonslop, pattern32, redzone_size_); + TF_RETURN_IF_ERROR( + stream_->Memset32(&rhs_redzone_nonslop, pattern32, redzone_size_)); allocated_buffers_.emplace_back(std::move(allocated_buffer), byte_size); return data_chunk; @@ -186,8 +197,8 @@ static absl::StatusOr CheckRedzoneHost( absl::string_view name, Stream* stream, uint8_t redzone_pattern) { uint64_t size = redzone.size(); auto redzone_data = std::make_unique(size); - TF_RETURN_IF_ERROR(stream->ThenMemcpy(redzone_data.get(), redzone, size) - .BlockHostUntilDone()); + TF_RETURN_IF_ERROR(stream->Memcpy(redzone_data.get(), redzone, size)); + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); std::array pattern_arr; pattern_arr.fill(redzone_pattern); @@ -246,7 +257,8 @@ static absl::Status ReinitializeRedzone(Stream* stream, uint8_t redzone_pattern) { absl::FixedArray redzone_array(redzone.size()); redzone_array.fill(redzone_pattern); - stream->ThenMemcpy(&redzone, redzone_array.data(), redzone.size()); + TF_RETURN_IF_ERROR( + stream->Memcpy(&redzone, redzone_array.data(), redzone.size())); TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); return absl::OkStatus(); } @@ -281,7 +293,7 @@ static absl::StatusOr CheckRedzonesForBuffer( out_param, comparison_kernel)); int64_t result; CHECK_EQ(out_param.size(), sizeof(result)); - stream->ThenMemcpy(&result, out_param, sizeof(result)); + TF_RETURN_IF_ERROR(stream->Memcpy(&result, out_param, sizeof(result))); TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); if (result != 0) { @@ -326,27 +338,29 @@ absl::StatusOr RedzoneAllocator::CheckRedzones() const { } TF_ASSIGN_OR_RETURN( - std::shared_ptr loaded_kernel, + ComparisonKernelT * kernel_ptr, (LoadKernelOrGetPtr, uint8_t, uint64_t, DeviceMemory>( executor, "redzone_checker", redzone_checker_ptx, compiled_ptx))); #elif TENSORFLOW_USE_ROCM TF_ASSIGN_OR_RETURN( - std::unique_ptr loaded_kernel, - (executor->CreateTypedKernel, uint8, uint64_t, - DeviceMemory>("redzone_checker", - kernel_symbol()))); + ComparisonKernelT loaded_kernel, + (TypedKernel, uint8, uint64_t, + DeviceMemory>::Create(executor, "redzone_checker", + kernel_symbol()))); + // CUDA side returns a pointer => hence get a pointer to the loaded kernel + auto* kernel_ptr = &loaded_kernel; #endif // GOOGLE_CUDA auto out_param = executor->AllocateOwnedScalar(); - stream_->ThenMemZero(out_param.ptr(), sizeof(uint64_t)); + TF_RETURN_IF_ERROR(stream_->MemZero(out_param.ptr(), sizeof(uint64_t))); for (const auto& buf_and_size : allocated_buffers_) { TF_ASSIGN_OR_RETURN( RedzoneCheckStatus redzone_status, CheckRedzonesForBuffer(stream_, *buf_and_size.first, out_param.cref(), - *loaded_kernel, buf_and_size.second, - redzone_size_, redzone_pattern_)); + *kernel_ptr, buf_and_size.second, redzone_size_, + redzone_pattern_)); if (!redzone_status.ok()) { return redzone_status; } diff --git a/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cu.cc b/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cu.cc index e42f82cf6c9e0c..d6a5108ef37ca8 100644 --- a/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cu.cc +++ b/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cu.cc @@ -15,7 +15,6 @@ limitations under the License. #include "xla/stream_executor/gpu/redzone_allocator.h" -#include namespace stream_executor { diff --git a/third_party/xla/xla/stream_executor/gpu/redzone_allocator.h b/third_party/xla/xla/stream_executor/gpu/redzone_allocator.h index e6f54e58aede1d..8fcac1ea2b7678 100644 --- a/third_party/xla/xla/stream_executor/gpu/redzone_allocator.h +++ b/third_party/xla/xla/stream_executor/gpu/redzone_allocator.h @@ -17,8 +17,12 @@ limitations under the License. #define XLA_STREAM_EXECUTOR_GPU_REDZONE_ALLOCATOR_H_ #include +#include +#include #include +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/gpu/gpu_asm_opts.h" #include "xla/stream_executor/scratch_allocator.h" diff --git a/third_party/xla/xla/stream_executor/gpu/redzone_allocator_test.cc b/third_party/xla/xla/stream_executor/gpu/redzone_allocator_test.cc index 4c9f4bda1fe054..d9fc43f94a87bd 100644 --- a/third_party/xla/xla/stream_executor/gpu/redzone_allocator_test.cc +++ b/third_party/xla/xla/stream_executor/gpu/redzone_allocator_test.cc @@ -16,14 +16,19 @@ limitations under the License. #include "xla/stream_executor/gpu/redzone_allocator.h" #include +#include #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/gpu/gpu_asm_opts.h" #include "xla/stream_executor/gpu/gpu_init.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace stream_executor { @@ -52,7 +57,7 @@ TEST(RedzoneAllocatorTest, WriteToRedzone) { constexpr int64_t kAllocSize = (1 << 25) + 1; Platform* platform = - MultiPlatformManager::PlatformWithName(GpuPlatformName()).value(); + PlatformManager::PlatformWithName(GpuPlatformName()).value(); StreamExecutor* stream_exec = platform->ExecutorForDevice(0).value(); GpuAsmOpts opts; StreamExecutorMemoryAllocator se_allocator(platform, {stream_exec}); @@ -74,8 +79,8 @@ TEST(RedzoneAllocatorTest, WriteToRedzone) { // Check that the redzones are in fact filled with kRedzonePattern. auto check_redzone = [&](DeviceMemoryBase redzone, absl::string_view name) { std::vector host_buf(kRedzoneSize); - TF_ASSERT_OK(stream.ThenMemcpy(host_buf.data(), redzone, kRedzoneSize) - .BlockHostUntilDone()); + TF_ASSERT_OK(stream.Memcpy(host_buf.data(), redzone, kRedzoneSize)); + TF_ASSERT_OK(stream.BlockHostUntilDone()); const int64_t kMaxMismatches = 16; int64_t mismatches = 0; for (int64_t i = 0; i < host_buf.size(); ++i) { @@ -103,8 +108,8 @@ TEST(RedzoneAllocatorTest, WriteToRedzone) { reinterpret_cast(redzone.opaque()) + offset, 1); char old_redzone_value = 0; { EXPECT_REDZONE_OK(allocator.CheckRedzones()); } - stream.ThenMemcpy(&old_redzone_value, redzone_at_offset, 1) - .ThenMemZero(&redzone_at_offset, 1); + TF_ASSERT_OK(stream.Memcpy(&old_redzone_value, redzone_at_offset, 1)); + TF_ASSERT_OK(stream.MemZero(&redzone_at_offset, 1)); EXPECT_REDZONE_VIOLATION(allocator.CheckRedzones()); // Checking reinitializes the redzone. @@ -126,7 +131,7 @@ TEST(RedzoneAllocatorTest, VeryLargeRedzone) { // Make sure the redzone size would require grid dimension > 65535. constexpr int64_t kRedzoneSize = 65535 * 1024 + 1; Platform* platform = - MultiPlatformManager::PlatformWithName(GpuPlatformName()).value(); + PlatformManager::PlatformWithName(GpuPlatformName()).value(); StreamExecutor* stream_exec = platform->ExecutorForDevice(0).value(); GpuAsmOpts opts; StreamExecutorMemoryAllocator se_allocator(platform, {stream_exec}); diff --git a/third_party/xla/xla/stream_executor/host/BUILD b/third_party/xla/xla/stream_executor/host/BUILD index 6ec4ff8b6b74db..6212e467a34b98 100644 --- a/third_party/xla/xla/stream_executor/host/BUILD +++ b/third_party/xla/xla/stream_executor/host/BUILD @@ -3,11 +3,12 @@ load("//xla:xla.bzl", "xla_cc_test") load("//xla/stream_executor:build_defs.bzl", "stream_executor_friends") -load("@local_tsl//tsl:tsl.bzl", "set_external_visibility") +load("@local_tsl//tsl:tsl.bzl", "internal_visibility") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) @@ -24,7 +25,6 @@ cc_library( hdrs = [ "host_platform_id.h", ], - visibility = ["//visibility:public"], deps = [ "//xla/stream_executor:platform", ], @@ -43,13 +43,13 @@ cc_library( ":host_gpu_executor", ":host_platform_id", "//xla/stream_executor", - "//xla/stream_executor:multi_platform_manager", + "//xla/stream_executor:platform_manager", "//xla/stream_executor/platform", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings:str_format", "@local_tsl//tsl/platform:errors", ], - alwayslink = True, # Registers itself with the MultiPlatformManager. + alwayslink = True, # Registers itself with the PlatformManager. ) cc_library( @@ -60,7 +60,6 @@ cc_library( hdrs = [ "host_stream.h", ], - visibility = ["//visibility:public"], deps = [ "//xla/stream_executor:stream_executor_internal", "@com_google_absl//absl/functional:any_invocable", @@ -81,19 +80,15 @@ cc_library( hdrs = [ "host_gpu_executor.h", ], - visibility = ["//visibility:public"], deps = [ - ":host_platform_id", ":host_stream", "//xla/stream_executor", - "//xla/stream_executor:plugin_registry", "//xla/stream_executor:stream_executor_internal", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform/profile_utils:profile_utils_cpu_utils", ], @@ -106,8 +101,8 @@ xla_cc_test( deps = [ ":host_platform", "//xla/stream_executor", - "//xla/stream_executor:multi_platform_manager", "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", "@com_google_absl//absl/status", "@com_google_absl//absl/synchronization", "@local_tsl//tsl/lib/core:status_test_util", diff --git a/third_party/xla/xla/stream_executor/host/host_gpu_executor.cc b/third_party/xla/xla/stream_executor/host/host_gpu_executor.cc index af784988570852..902a88c628bf55 100644 --- a/third_party/xla/xla/stream_executor/host/host_gpu_executor.cc +++ b/third_party/xla/xla/stream_executor/host/host_gpu_executor.cc @@ -32,9 +32,7 @@ limitations under the License. #include "absl/synchronization/notification.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_options.h" -#include "xla/stream_executor/host/host_platform_id.h" #include "xla/stream_executor/host/host_stream.h" -#include "xla/stream_executor/plugin_registry.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_internal.h" #include "tsl/platform/mem.h" @@ -275,32 +273,6 @@ HostExecutor::CreateDeviceDescription(int device_ordinal) { return builder.Build(); } -blas::BlasSupport* HostExecutor::CreateBlas() { - PluginRegistry* registry = PluginRegistry::Instance(); - absl::StatusOr status = - registry->GetFactory(kHostPlatformId); - if (!status.ok()) { - LOG(ERROR) << "Unable to retrieve BLAS factory: " - << status.status().message(); - return nullptr; - } - - return status.value()(this); -} - -fft::FftSupport* HostExecutor::CreateFft() { - PluginRegistry* registry = PluginRegistry::Instance(); - absl::StatusOr status = - registry->GetFactory(kHostPlatformId); - if (!status.ok()) { - LOG(ERROR) << "Unable to retrieve FFT factory: " - << status.status().message(); - return nullptr; - } - - return status.value()(this); -} - std::unique_ptr HostExecutor::GetStreamImplementation() { return std::unique_ptr( diff --git a/third_party/xla/xla/stream_executor/host/host_gpu_executor.h b/third_party/xla/xla/stream_executor/host/host_gpu_executor.h index 4182066be0fe09..3c3e827d446454 100644 --- a/third_party/xla/xla/stream_executor/host/host_gpu_executor.h +++ b/third_party/xla/xla/stream_executor/host/host_gpu_executor.h @@ -22,18 +22,14 @@ limitations under the License. #include #include "absl/functional/any_invocable.h" -#include "xla/stream_executor/blas.h" -#include "xla/stream_executor/host/host_stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_internal.h" -#include "tsl/platform/errors.h" namespace stream_executor { namespace host { // An implementation of StreamExecutor that does no communication or interaction // with a device, but DOES perform memory operations backed by the host. -// Plugin routines (BLAS) are also supported and functional. // Kernel invocations will fail, but host callbacks may be enqueued on this // executor and its associated stream, and should follow standard ordering // semantics. @@ -134,20 +130,9 @@ class HostExecutor : public internal::StreamExecutorInterface { return true; } - blas::BlasSupport* CreateBlas() override; - - dnn::DnnSupport* CreateDnn() override { return nullptr; } - - fft::FftSupport* CreateFft() override; - std::unique_ptr CreateEventImplementation() override; - std::unique_ptr CreateKernelImplementation() - override { - return nullptr; - } - std::unique_ptr GetStreamImplementation() override; private: diff --git a/third_party/xla/xla/stream_executor/host/host_platform.cc b/third_party/xla/xla/stream_executor/host/host_platform.cc index c425e0385c785b..5b8a36080ac668 100644 --- a/third_party/xla/xla/stream_executor/host/host_platform.cc +++ b/third_party/xla/xla/stream_executor/host/host_platform.cc @@ -22,6 +22,7 @@ limitations under the License. #include "xla/stream_executor/host/host_gpu_executor.h" #include "xla/stream_executor/host/host_platform_id.h" #include "xla/stream_executor/platform/initialize.h" +#include "xla/stream_executor/platform_manager.h" #include "tsl/platform/errors.h" namespace stream_executor { @@ -73,11 +74,11 @@ HostPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) { static void InitializeHostPlatform() { std::unique_ptr platform(new host::HostPlatform); - TF_CHECK_OK(MultiPlatformManager::RegisterPlatform(std::move(platform))); + TF_CHECK_OK(PlatformManager::RegisterPlatform(std::move(platform))); } } // namespace host } // namespace stream_executor -REGISTER_MODULE_INITIALIZER(host_platform, - stream_executor::host::InitializeHostPlatform()); +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER( + host_platform, stream_executor::host::InitializeHostPlatform()); diff --git a/third_party/xla/xla/stream_executor/host/host_platform.h b/third_party/xla/xla/stream_executor/host/host_platform.h index 08f89fbf6cb1d8..74d234b940f4bc 100644 --- a/third_party/xla/xla/stream_executor/host/host_platform.h +++ b/third_party/xla/xla/stream_executor/host/host_platform.h @@ -24,9 +24,9 @@ limitations under the License. #include #include "xla/stream_executor/executor_cache.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform/port.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" namespace stream_executor { diff --git a/third_party/xla/xla/stream_executor/host/host_stream_test.cc b/third_party/xla/xla/stream_executor/host/host_stream_test.cc index 11e8632f108110..12e8c17fb59556 100644 --- a/third_party/xla/xla/stream_executor/host/host_stream_test.cc +++ b/third_party/xla/xla/stream_executor/host/host_stream_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/synchronization/mutex.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "tsl/lib/core/status_test_util.h" @@ -27,22 +27,22 @@ namespace se = stream_executor; TEST(HostStream, EnforcesFIFOOrder) { se::Platform* platform = - se::MultiPlatformManager::PlatformWithName("Host").value(); + se::PlatformManager::PlatformWithName("Host").value(); se::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); se::Stream stream(executor); - stream.Init(); + TF_ASSERT_OK(stream.Initialize()); absl::Mutex mu; int expected = 0; bool ok = true; for (int i = 0; i < 2000; ++i) { - stream.ThenDoHostCallback([i, &mu, &expected, &ok]() { + TF_ASSERT_OK(stream.DoHostCallback([i, &mu, &expected, &ok]() { absl::MutexLock lock(&mu); if (expected != i) { ok = false; } ++expected; - }); + })); } TF_ASSERT_OK(stream.BlockHostUntilDone()); absl::MutexLock lock(&mu); @@ -51,13 +51,13 @@ TEST(HostStream, EnforcesFIFOOrder) { TEST(HostStream, ReportsHostCallbackError) { se::Platform* platform = - se::MultiPlatformManager::PlatformWithName("Host").value(); + se::PlatformManager::PlatformWithName("Host").value(); se::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); se::Stream stream(executor); - stream.Init(); + TF_ASSERT_OK(stream.Initialize()); - stream.ThenDoHostCallbackWithStatus( - []() { return absl::InternalError("error!"); }); + TF_ASSERT_OK(stream.DoHostCallbackWithStatus( + []() { return absl::InternalError("error!"); })); auto status = stream.BlockHostUntilDone(); ASSERT_EQ(status.code(), tsl::error::INTERNAL); @@ -66,15 +66,15 @@ TEST(HostStream, ReportsHostCallbackError) { TEST(HostStream, ReportsFirstHostCallbackError) { se::Platform* platform = - se::MultiPlatformManager::PlatformWithName("Host").value(); + se::PlatformManager::PlatformWithName("Host").value(); se::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); se::Stream stream(executor); - stream.Init(); + TF_ASSERT_OK(stream.Initialize()); - stream.ThenDoHostCallbackWithStatus( - []() { return absl::InternalError("error 1"); }); - stream.ThenDoHostCallbackWithStatus( - []() { return absl::InternalError("error 2"); }); + TF_ASSERT_OK(stream.DoHostCallbackWithStatus( + []() { return absl::InternalError("error 1"); })); + TF_ASSERT_OK(stream.DoHostCallbackWithStatus( + []() { return absl::InternalError("error 2"); })); // "error 2" is just lost. ASSERT_EQ(stream.BlockHostUntilDone().message(), "error 1"); diff --git a/tensorflow/core/profiler/backends/gpu/nvtx_utils.h b/third_party/xla/xla/stream_executor/host_memory_allocation.cc similarity index 52% rename from tensorflow/core/profiler/backends/gpu/nvtx_utils.h rename to third_party/xla/xla/stream_executor/host_memory_allocation.cc index 31d22b64a9feee..12affb0b3c68b6 100644 --- a/tensorflow/core/profiler/backends/gpu/nvtx_utils.h +++ b/third_party/xla/xla/stream_executor/host_memory_allocation.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,21 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_BACKENDS_GPU_NVTX_UTILS_H_ -#define TENSORFLOW_CORE_PROFILER_BACKENDS_GPU_NVTX_UTILS_H_ +#include "xla/stream_executor/host_memory_allocation.h" -#include +#include -#include "absl/strings/string_view.h" -#include "xla/backends/profiler/gpu/nvtx_utils.h" -#include "tensorflow/core/platform/macros.h" +#include "xla/stream_executor/stream_executor_internal.h" -namespace tensorflow { -namespace profiler { +namespace stream_executor { -using xla::profiler::NVTXRangeTracker; // NOLINT +HostMemoryAllocation::HostMemoryAllocation( + void* ptr, uint64_t size, internal::StreamExecutorInterface* executor) + : ptr_(ptr), size_(size), executor_(executor) {} -} // namespace profiler -} // namespace tensorflow +HostMemoryAllocation::~HostMemoryAllocation() { + if (ptr_ != nullptr && executor_ != nullptr) { + executor_->HostMemoryDeallocate(ptr_); + } +} -#endif // TENSORFLOW_CORE_PROFILER_BACKENDS_GPU_NVTX_UTILS_H_ +} // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/host_memory_allocation.h b/third_party/xla/xla/stream_executor/host_memory_allocation.h new file mode 100644 index 00000000000000..974eb63fb8daa5 --- /dev/null +++ b/third_party/xla/xla/stream_executor/host_memory_allocation.h @@ -0,0 +1,48 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_HOST_MEMORY_ALLOCATION_H_ +#define XLA_STREAM_EXECUTOR_HOST_MEMORY_ALLOCATION_H_ + +#include + +#include "xla/stream_executor/memory_allocation.h" + +namespace stream_executor { + +namespace internal { +class StreamExecutorInterface; +} + +// RAII container for pinned host memory allocation allocated on an underlying +// device owned by `*this`. +class HostMemoryAllocation final : public MemoryAllocation { + public: + HostMemoryAllocation(void* ptr, uint64_t size, + internal::StreamExecutorInterface* executor); + ~HostMemoryAllocation() final; + + void* opaque() const final { return ptr_; } + uint64_t size() const final { return size_; } + + private: + void* ptr_ = nullptr; + uint64_t size_ = 0; + internal::StreamExecutorInterface* executor_ = nullptr; +}; + +} // namespace stream_executor + +#endif // XLA_STREAM_EXECUTOR_HOST_MEMORY_ALLOCATION_H_ diff --git a/third_party/xla/xla/stream_executor/integrations/BUILD b/third_party/xla/xla/stream_executor/integrations/BUILD index 367e12463d6405..af782baf3a5185 100644 --- a/third_party/xla/xla/stream_executor/integrations/BUILD +++ b/third_party/xla/xla/stream_executor/integrations/BUILD @@ -1,11 +1,13 @@ +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("//xla:xla.bzl", "xla_cc_test") -load("//xla/stream_executor:build_defs.bzl", "stream_executor_friends") -load("@local_tsl//tsl:tsl.bzl", "if_google", "set_external_visibility") +load("//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured", "stream_executor_friends") +load("@local_tsl//tsl:tsl.bzl", "if_google", "internal_visibility") load("@local_tsl//tsl:tsl.default.bzl", "filegroup") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) @@ -28,7 +30,7 @@ filegroup( "device_host_allocator.h", "device_mem_allocator.h", ], - visibility = ["//visibility:public"], + visibility = internal_visibility(["//tensorflow/core:__pkg__"]), ) #===--------------------------------------------------------------------------------------------===# @@ -42,7 +44,6 @@ cc_library( name = "tf_allocator_adapter", srcs = ["tf_allocator_adapter.cc"], hdrs = ["tf_allocator_adapter.h"], - visibility = ["//visibility:public"], deps = [ "//xla/stream_executor", "//xla/stream_executor:device_memory", @@ -65,15 +66,38 @@ cc_library( "device_host_allocator.h", "device_mem_allocator.h", ], - visibility = ["//visibility:public"], deps = [ "//xla/stream_executor", + "//xla/stream_executor:memory_allocation", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/synchronization", "@local_tsl//tsl/framework:allocator", "@local_tsl//tsl/framework:device_id", + "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/profiler/lib:traceme", ], ) +cc_library( + name = "gpu_virtual_mem_allocator", + srcs = ["gpu_virtual_mem_allocator.cc"], + hdrs = ["gpu_virtual_mem_allocator.h"], + defines = if_cuda(["GOOGLE_CUDA=1"]), + deps = [ + "//xla/stream_executor:stream_executor_headers", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/framework:allocator", + "@local_tsl//tsl/framework:device_id_impl", + "@local_tsl//tsl/platform:numbers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:traceme", + ] + if_cuda([ + "//xla/stream_executor/gpu:gpu_driver_header", + "//xla/stream_executor/gpu:gpu_types_header", + ]), +) + xla_cc_test( name = "tf_allocator_adapter_test", srcs = ["tf_allocator_adapter_test.cc"], @@ -92,3 +116,21 @@ xla_cc_test( "@local_tsl//tsl/framework:allocator", ]), ) + +xla_cc_test( + name = "gpu_virtual_mem_allocator_test", + srcs = if_gpu_is_configured(["gpu_virtual_mem_allocator_test.cc"]), + tags = [ + "gpu", + "no_oss", + "requires-gpu-nvidia", + ], + deps = [ + ":gpu_virtual_mem_allocator", + "//xla/stream_executor/gpu:gpu_init", + "@local_tsl//tsl/framework:device_id_impl", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_benchmark", + "@local_tsl//tsl/platform:test_main", + ], +) diff --git a/third_party/xla/xla/stream_executor/integrations/device_host_allocator.h b/third_party/xla/xla/stream_executor/integrations/device_host_allocator.h index 27af4b6f84b06a..90292674a54569 100644 --- a/third_party/xla/xla/stream_executor/integrations/device_host_allocator.h +++ b/third_party/xla/xla/stream_executor/integrations/device_host_allocator.h @@ -16,13 +16,22 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_INTEGRATIONS_DEVICE_HOST_ALLOCATOR_H_ #define XLA_STREAM_EXECUTOR_INTEGRATIONS_DEVICE_HOST_ALLOCATOR_H_ +#include +#include +#include #include +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" +#include "xla/stream_executor/memory_allocation.h" #include "xla/stream_executor/stream_executor.h" #include "tsl/framework/allocator.h" +#include "tsl/platform/logging.h" #include "tsl/profiler/lib/traceme.h" namespace stream_executor { + // Allocator for pinned CPU RAM that is made known to a StreamExecutor-based // device for the purpose of efficient DMA with the device. class DeviceHostAllocator : public tsl::SubAllocator { @@ -45,15 +54,22 @@ class DeviceHostAllocator : public tsl::SubAllocator { void* ptr = nullptr; *bytes_received = num_bytes; + if (num_bytes > 0) { - ptr = stream_exec_->HostMemoryAllocate(num_bytes); - if (ptr == nullptr) { + auto allocation = stream_exec_->HostMemoryAllocate(num_bytes); + if (!allocation.ok()) { LOG(WARNING) << "could not allocate pinned host memory of size: " << num_bytes; - return ptr; + return nullptr; } + + ptr = (*allocation)->opaque(); VisitAlloc(ptr, numa_node_, num_bytes); + + absl::MutexLock lock(&mutex_); + allocs_[ptr] = std::move(*allocation); } + return ptr; } @@ -62,7 +78,8 @@ class DeviceHostAllocator : public tsl::SubAllocator { if (ptr != nullptr) { VisitFree(ptr, numa_node_, num_bytes); - stream_exec_->HostMemoryDeallocate(ptr); + absl::MutexLock lock(&mutex_); + allocs_.erase(ptr); } } @@ -78,6 +95,10 @@ class DeviceHostAllocator : public tsl::SubAllocator { DeviceHostAllocator(const DeviceHostAllocator&) = delete; void operator=(const DeviceHostAllocator&) = delete; + + absl::Mutex mutex_; + absl::flat_hash_map> allocs_ + ABSL_GUARDED_BY(mutex_); }; } // namespace stream_executor diff --git a/tensorflow/core/common_runtime/gpu/gpu_virtual_mem_allocator.cc b/third_party/xla/xla/stream_executor/integrations/gpu_virtual_mem_allocator.cc similarity index 92% rename from tensorflow/core/common_runtime/gpu/gpu_virtual_mem_allocator.cc rename to third_party/xla/xla/stream_executor/integrations/gpu_virtual_mem_allocator.cc index 774763ca4b24d2..470c9c7789e25b 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_virtual_mem_allocator.cc +++ b/third_party/xla/xla/stream_executor/integrations/gpu_virtual_mem_allocator.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/common_runtime/gpu/gpu_virtual_mem_allocator.h" +#include "xla/stream_executor/integrations/gpu_virtual_mem_allocator.h" // IWYU pragma: keep -#include "absl/strings/str_format.h" -#include "xla/stream_executor/stream_executor.h" -#include "tsl/platform/numbers.h" -#include "tsl/profiler/lib/traceme.h" +#include "absl/strings/str_format.h" // IWYU pragma: keep +#include "xla/stream_executor/stream_executor.h" // IWYU pragma: keep +#include "tsl/platform/numbers.h" // IWYU pragma: keep +#include "tsl/profiler/lib/traceme.h" // IWYU pragma: keep -#if CUDA_VERSION >= 10020 +#if GOOGLE_CUDA -namespace tensorflow { +namespace stream_executor { namespace { using ::stream_executor::gpu::GpuContext; @@ -49,7 +49,7 @@ Status CheckVirtualAddressManagementSupport(GpuDeviceHandle device, TF_ASSIGN_OR_RETURN(bool supports_virtual_address_management, SupportsVirtualAddressManagement(device)); if (!supports_virtual_address_management) { - return tsl::errors::Internal(absl::StrFormat( + return absl::InternalError(absl::StrFormat( "GPU %d does not support virtual memory address management.", gpu_id.value())); } @@ -121,7 +121,10 @@ GpuVirtualMemAllocator::GpuVirtualMemAllocator( gpu_id_(gpu_id), access_gpu_handles_(access_gpu_handles), vmem_(vmem), - granularity_(granularity) {} + granularity_(granularity) { + CHECK_EQ(granularity & (granularity - 1), 0) + << "Granularity must be a power of two; granularity=" << granularity; +} GpuVirtualMemAllocator::~GpuVirtualMemAllocator() { for (const auto mapping : mappings_) { @@ -224,6 +227,6 @@ void GpuVirtualMemAllocator::Free(void* ptr, size_t num_bytes) { VisitFree(ptr, gpu_id_.value(), num_bytes); } -} // namespace tensorflow +} // namespace stream_executor -#endif +#endif // GOOGLE_CUDA diff --git a/third_party/xla/xla/stream_executor/integrations/gpu_virtual_mem_allocator.h b/third_party/xla/xla/stream_executor/integrations/gpu_virtual_mem_allocator.h new file mode 100644 index 00000000000000..0562907e2e201f --- /dev/null +++ b/third_party/xla/xla/stream_executor/integrations/gpu_virtual_mem_allocator.h @@ -0,0 +1,113 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_INTEGRATIONS_GPU_VIRTUAL_MEM_ALLOCATOR_H_ +#define XLA_STREAM_EXECUTOR_INTEGRATIONS_GPU_VIRTUAL_MEM_ALLOCATOR_H_ + +#include +#include + +#include "xla/stream_executor/stream_executor.h" +#include "tsl/framework/allocator.h" +#include "tsl/framework/device_id.h" +#include "tsl/platform/statusor.h" + +#if GOOGLE_CUDA +#include "xla/stream_executor/gpu/gpu_driver.h" +#include "xla/stream_executor/gpu/gpu_types.h" + +namespace stream_executor { + +// GpuVirtualMemAllocator is a SubAllocator for use with BFCAllocator which +// provides contiguous allocations with each call to Alloc. This is done by +// reserving a large chunk of virtual addresses at construction and then mapping +// physical memory pages to this virtual address range as requested. +// +// This class is not thread-safe. +class GpuVirtualMemAllocator : public tsl::SubAllocator { + public: + static tsl::StatusOr> Create( + const std::vector& alloc_visitors, + const std::vector& free_visitors, + stream_executor::gpu::GpuContext& gpu_context, + tsl::PlatformDeviceId gpu_id, size_t virtual_address_space_size, + const std::vector& peer_gpu_ids); + ~GpuVirtualMemAllocator() override; + + // Allocates memory at least as large as requested by num_bytes. Will be + // aligned to the min allocation granularity (typically 2MiB). + // alignment is ignored by this allocator. + void* Alloc(size_t alignment, size_t num_bytes, + size_t* bytes_received) override; + + // Frees should only happen at the end of the contiguous memory allocations or + // else we introduce pointless fragmentation...But, this is supported. If the + // allocation happens at the end, then the next_alloc_offset_ is moved back, + // otherwise a hole is created. + // + // Holes are not re-used, all allocations continue to come at the end of the + // next_alloc_offset_. To accommodate this, the virtual_address_space_size + // should be much larger than the max physical size of the allocator. + // + // In practice, since the BFC allocator coalesces adjacent AllocationRegions, + // this free function should never be invoked. + void Free(void* ptr, size_t num_bytes) override; + + bool SupportsCoalescing() const override { return true; } + + private: + GpuVirtualMemAllocator( + const std::vector& alloc_visitors, + const std::vector& free_visitors, + stream_executor::gpu::GpuContext& gpu_context, + tsl::PlatformDeviceId gpu_id, + std::vector access_device_handles, + stream_executor::gpu::GpuDriver::VmemSpan vmem, size_t granularity); + + stream_executor::gpu::GpuContext& gpu_context_; + tsl::PlatformDeviceId gpu_id_; + + // Peer access is configured at mmap time so the allocator must be aware of + // all gpus that may want to read the memory. This list also includes the + // above gpu_id_ to facilitate the invocation of the GpuDriver::MapMemory + // function. + const std::vector access_gpu_handles_; + + // The virtual memory span held by this allocator. + stream_executor::gpu::GpuDriver::VmemSpan vmem_; + // The next offset from the vmem base address that will be allocated. This + // corresponds to the size of physically pinned memory if holes haven't been + // created with "free". + size_t next_alloc_offset_ = 0; + + // Smallest allocation as determined by CUDA. + const size_t granularity_; + + struct Mapping { + stream_executor::gpu::GpuDevicePtr va; + stream_executor::gpu::GpuDriver::GenericMemoryHandle physical; + }; + // List of mappings, sorted by va. + std::vector mappings_; + + GpuVirtualMemAllocator(const GpuVirtualMemAllocator&) = delete; + void operator=(const GpuVirtualMemAllocator&) = delete; +}; + +} // namespace stream_executor + +#endif // GOOGLE_CUDA + +#endif // XLA_STREAM_EXECUTOR_INTEGRATIONS_GPU_VIRTUAL_MEM_ALLOCATOR_H_ diff --git a/tensorflow/core/common_runtime/gpu/gpu_virtual_mem_allocator_test.cc b/third_party/xla/xla/stream_executor/integrations/gpu_virtual_mem_allocator_test.cc similarity index 92% rename from tensorflow/core/common_runtime/gpu/gpu_virtual_mem_allocator_test.cc rename to third_party/xla/xla/stream_executor/integrations/gpu_virtual_mem_allocator_test.cc index f40e70a2603aff..1db65c4d530ab7 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_virtual_mem_allocator_test.cc +++ b/third_party/xla/xla/stream_executor/integrations/gpu_virtual_mem_allocator_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/common_runtime/gpu/gpu_virtual_mem_allocator.h" - -#if CUDA_VERSION >= 10020 +#include "xla/stream_executor/integrations/gpu_virtual_mem_allocator.h" // IWYU pragma: keep #include "xla/stream_executor/gpu/gpu_init.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/test_benchmark.h" #include "tsl/framework/device_id.h" +#include "tsl/platform/test.h" // IWYU pragma: keep +#include "tsl/platform/test_benchmark.h" // IWYU pragma: keep + +#if GOOGLE_CUDA -namespace tensorflow { +namespace stream_executor { namespace { using ::stream_executor::gpu::GpuContext; @@ -36,7 +36,7 @@ constexpr size_t k2MiB{2 << 20}; std::unique_ptr CreateAllocator() { tsl::PlatformDeviceId gpu_id(0); auto executor = - se::GPUMachineManager()->ExecutorForDevice(gpu_id.value()).value(); + GPUMachineManager()->ExecutorForDevice(gpu_id.value()).value(); GpuContext* gpu_context = reinterpret_cast( executor->platform_specific_handle().context); return GpuVirtualMemAllocator::Create( @@ -48,7 +48,7 @@ std::unique_ptr CreateAllocator() { TEST(GpuVirtualMemAllocatorTest, SimpleAlloc) { tsl::PlatformDeviceId gpu_id(0); auto executor = - se::GPUMachineManager()->ExecutorForDevice(gpu_id.value()).value(); + GPUMachineManager()->ExecutorForDevice(gpu_id.value()).value(); GpuContext* gpu_context = reinterpret_cast( executor->platform_specific_handle().context); auto allocator = GpuVirtualMemAllocator::Create( @@ -177,6 +177,6 @@ TEST(GpuVirtualMemAllocatorTest, FreeRange) { } } // namespace -} // namespace tensorflow +} // namespace stream_executor -#endif +#endif // GOOGLE_CUDA diff --git a/third_party/xla/xla/stream_executor/kernel.cc b/third_party/xla/xla/stream_executor/kernel.cc index f38d12555c7f1c..f51257aa049c36 100644 --- a/third_party/xla/xla/stream_executor/kernel.cc +++ b/third_party/xla/xla/stream_executor/kernel.cc @@ -15,19 +15,21 @@ limitations under the License. #include "xla/stream_executor/kernel.h" -#include #include +#include #include #include -#include #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/strings/strip.h" +#include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_internal.h" #include "tsl/platform/demangle.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace stream_executor { @@ -51,39 +53,11 @@ void KernelMetadata::set_shared_memory_bytes(int shared_memory_bytes) { // Kernel //===----------------------------------------------------------------------===// -Kernel::Kernel(Kernel &&from) - : parent_(from.parent_), - implementation_(std::move(from.implementation_)), - name_(std::move(from.name_)), - demangled_name_(std::move(from.demangled_name_)), - metadata_(from.metadata_) { - from.parent_ = nullptr; -} - -Kernel::Kernel(StreamExecutor *parent) - : parent_(parent), - implementation_(parent->implementation()->CreateKernelImplementation()) {} - -Kernel::~Kernel() { - if (parent_) { - parent_->UnloadKernel(this); - } -} - -unsigned Kernel::Arity() const { return implementation_->Arity(); } - -void Kernel::SetPreferredCacheConfig(KernelCacheConfig config) { - return implementation_->SetPreferredCacheConfig(config); -} - -KernelCacheConfig Kernel::GetPreferredCacheConfig() const { - return implementation_->GetPreferredCacheConfig(); -} - -absl::StatusOr Kernel::GetMaxOccupiedBlocksPerCore( - ThreadDim threads, size_t dynamic_shared_memory_bytes) const { - return implementation_->GetMaxOccupiedBlocksPerCore( - threads, dynamic_shared_memory_bytes); +absl::StatusOr> Kernel::Create( + StreamExecutor *executor, const MultiKernelLoaderSpec &spec) { + TF_ASSIGN_OR_RETURN(auto kernel, executor->implementation()->CreateKernel()); + TF_RETURN_IF_ERROR(executor->GetKernel(spec, kernel.get())); + return kernel; } void Kernel::set_name(absl::string_view name) { diff --git a/third_party/xla/xla/stream_executor/kernel.h b/third_party/xla/xla/stream_executor/kernel.h index 42f2ef6d5960f2..c8cf0d90e976cf 100644 --- a/third_party/xla/xla/stream_executor/kernel.h +++ b/third_party/xla/xla/stream_executor/kernel.h @@ -88,18 +88,16 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" #include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" namespace stream_executor { class Kernel; class StreamExecutor; -namespace internal { -class KernelInterface; -} // namespace internal - //===----------------------------------------------------------------------===// // Kernel cache config //===----------------------------------------------------------------------===// @@ -230,34 +228,22 @@ class Kernel { std::function>( const Kernel &kernel, const KernelArgs &args)>; - Kernel(Kernel &&from); + // TODO(b/323534971): Kernel constructor should be moved to StreamExecutor or + // a dedicated KernelFactory accessible via StreamExecutor. - // Constructs an "empty" (not-yet-loaded) kernel instance. - // - // parent is the StreamExecutor that will be responsible for loading the - // implementation of this kernel. It must not be null. - explicit Kernel(StreamExecutor *parent); + // Creates kernel on a given executor from a given kernel specification. + static absl::StatusOr> Create( + StreamExecutor *executor, const MultiKernelLoaderSpec &spec); - // Releases resources associated with the kernel instance (i.e. - // platform-specific implementation). - ~Kernel(); + Kernel() = default; + virtual ~Kernel() = default; + + Kernel(const Kernel &) = delete; + void operator=(const Kernel &) = delete; // Returns the number of parameters that this kernel accepts. (Arity refers to // nullary, unary, ...). - unsigned Arity() const; - - // Returns the StreamExecutor that represents the platform this kernel - // executes upon. - StreamExecutor *parent() const { return parent_; } - - // Returns a const pointer to the (opaque) platform-dependent implementation. - const internal::KernelInterface *implementation() const { - return implementation_.get(); - } - - // Returns a non-const pointer to the (opaque) platform-dependent - // implementation. - internal::KernelInterface *implementation() { return implementation_.get(); } + virtual unsigned Arity() const = 0; void set_metadata(const KernelMetadata &metadata) { metadata_ = metadata; } @@ -265,15 +251,15 @@ class Kernel { // Sets the preferred cache configuration for a kernel. This is just a // suggestion to the runtime, and may not be honored during execution. - void SetPreferredCacheConfig(KernelCacheConfig config); + virtual void SetPreferredCacheConfig(KernelCacheConfig config) = 0; // Gets the preferred cache configuration for a kernel. - KernelCacheConfig GetPreferredCacheConfig() const; + virtual KernelCacheConfig GetPreferredCacheConfig() const = 0; // Returns the maximum number of blocks (per multiprocessor) occupied by the // kernel given the number of threads per block and shared memory size. - absl::StatusOr GetMaxOccupiedBlocksPerCore( - ThreadDim threads, size_t dynamic_shared_memory_bytes) const; + virtual absl::StatusOr GetMaxOccupiedBlocksPerCore( + ThreadDim threads, size_t dynamic_shared_memory_bytes) const = 0; // Sets custom kernels arguments packing function for a kernel. void set_kernel_args_packing(KernelArgsPacking kernel_args_packing) { @@ -285,38 +271,66 @@ class Kernel { } void set_name(absl::string_view name); - const std::string &name() const { return name_; } - const std::string &demangled_name() const { return demangled_name_; } + std::string_view name() const { return name_; } + std::string_view demangled_name() const { return demangled_name_; } private: - // The StreamExecutor that loads this kernel object. - StreamExecutor *parent_; - - // Implementation delegated to for platform-specific functionality. - std::unique_ptr implementation_; - std::string name_; std::string demangled_name_; KernelMetadata metadata_; KernelArgsPacking kernel_args_packing_; - - Kernel(const Kernel &) = delete; - void operator=(const Kernel &) = delete; }; //===----------------------------------------------------------------------===// // Typed kernel //===----------------------------------------------------------------------===// -// Typed variant of Kernel, like a typed device function pointer. +// Typed kernel is a typed smart-pointer-like wrapper around untyped Kernel. template -class TypedKernel : public Kernel { +class TypedKernel { public: static constexpr size_t kNumberOfParameters = sizeof...(Params); - explicit TypedKernel(StreamExecutor *parent) : Kernel(parent) {} + // Creates a typed kernel on a given executor from a kernel specification. + static absl::StatusOr Create(StreamExecutor *executor, + const MultiKernelLoaderSpec &spec) { + TF_ASSIGN_OR_RETURN(std::unique_ptr kernel, + Kernel::Create(executor, spec)); + return TypedKernel(std::move(kernel)); + } + + // Creates a kernel which can be launched with `stream.ThenLaunch(...)` from a + // PTX (and optional CUBIN), such that the types of the arguments provided for + // launch would have to match types of the arguments provided at creation + // time. The canonical storage for both ptx and cubin_data should outlive the + // lifetime of the kernel. + static absl::StatusOr Create( + StreamExecutor *executor, absl::string_view kernel_name, + absl::string_view ptx, absl::Span cubin_data); + + // Creates a kernel which can be launched with `stream.ThenLaunch(...)` from + // an in-process symbol pointer. + static absl::StatusOr Create(StreamExecutor *executor, + absl::string_view kernel_name, + void *symbol); + + TypedKernel() = default; + + Kernel &operator*() { return *kernel_; } + const Kernel &operator*() const { return *kernel_; } + + Kernel *operator->() { return kernel_.get(); } + const Kernel *operator->() const { return kernel_.get(); } + + operator bool() const { return static_cast(kernel_); } // NOLINT + + private: + explicit TypedKernel(std::unique_ptr kernel) + : kernel_(std::move(kernel)) {} + + std::unique_ptr kernel_; }; //===----------------------------------------------------------------------===// @@ -720,10 +734,34 @@ std::unique_ptr PackKernelArgs( PackedParams::template CheckCompatibleStaticAssert(); - int64_t shmem_bytes = kernel.metadata().shared_memory_bytes().value_or(0); + int64_t shmem_bytes = kernel->metadata().shared_memory_bytes().value_or(0); return std::make_unique(std::forward(args)..., shmem_bytes); } +template +inline absl::StatusOr> TypedKernel::Create( + StreamExecutor *executor, absl::string_view kernel_name, + absl::string_view ptx, absl::Span cubin_data) { + MultiKernelLoaderSpec loader_spec(TypedKernel::kNumberOfParameters); + loader_spec.AddCudaPtxInMemory(ptx, kernel_name); + + if (!cubin_data.empty()) { + loader_spec.AddCudaCubinInMemory( + reinterpret_cast(cubin_data.data()), kernel_name); + } + + return TypedKernel::Create(executor, loader_spec); +} + +template +inline absl::StatusOr> TypedKernel::Create( + StreamExecutor *executor, absl::string_view kernel_name, void *symbol) { + MultiKernelLoaderSpec loader_spec(TypedKernel::kNumberOfParameters); + loader_spec.AddInProcessSymbol(symbol, kernel_name); + + return TypedKernel::Create(executor, loader_spec); +} + } // namespace stream_executor #endif // XLA_STREAM_EXECUTOR_KERNEL_H_ diff --git a/third_party/xla/xla/stream_executor/kernel_test.cc b/third_party/xla/xla/stream_executor/kernel_test.cc index adbc998902e70d..897b58a6bd2c4a 100644 --- a/third_party/xla/xla/stream_executor/kernel_test.cc +++ b/third_party/xla/xla/stream_executor/kernel_test.cc @@ -22,6 +22,8 @@ limitations under the License. #include #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "tsl/platform/test.h" #include "tsl/platform/test_benchmark.h" @@ -63,7 +65,7 @@ static_assert( std::tuple>); static std::unique_ptr NewStreamExecutor() { - Platform* platform = MultiPlatformManager::PlatformWithName("Host").value(); + Platform* platform = PlatformManager::PlatformWithName("Host").value(); StreamExecutorConfig config(/*ordinal=*/0); return platform->GetUncachedExecutor(config).value(); } @@ -103,11 +105,8 @@ TEST(KernelTest, PackPodArguments) { ASSERT_EQ(f64, 3.0); } -TEST(KernelTest, PackTypedKernelArguments) { - auto executor = NewStreamExecutor(); - TypedKernel kernel(executor.get()); - - auto args = PackKernelArgs(kernel, 1, 2.0f, 3.0); +TEST(KernelTest, PackTupleArguments) { + auto args = PackKernelArgs(/*shmem_bytes=*/0, 1, 2.0f, 3.0); ASSERT_EQ(args->number_of_arguments(), 3); auto packed = args->argument_addresses(); @@ -120,6 +119,14 @@ TEST(KernelTest, PackTypedKernelArguments) { ASSERT_EQ(f64, 3.0); } +TEST(KernelTest, FailToCreateTypedKernelFromEmptySpec) { + MultiKernelLoaderSpec empty_spec(/*arity=*/0); + + auto executor = NewStreamExecutor(); + auto kernel = TypedKernel<>::Create(executor.get(), empty_spec); + EXPECT_FALSE(kernel.ok()); +} + //===----------------------------------------------------------------------===// // Performance benchmarks below //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/stream_executor/lazy_op_runner.h b/third_party/xla/xla/stream_executor/lazy_op_runner.h index ea9a7b165420a0..d69521777b07d2 100644 --- a/third_party/xla/xla/stream_executor/lazy_op_runner.h +++ b/third_party/xla/xla/stream_executor/lazy_op_runner.h @@ -25,6 +25,7 @@ limitations under the License. #include "absl/base/call_once.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/stream_executor/blas.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/stream.h" #include "tsl/platform/statusor.h" @@ -32,6 +33,17 @@ limitations under the License. namespace stream_executor { namespace dnn { +namespace internal { +// Returns the DnnSupport object for the given stream. +inline absl::StatusOr GetDnnFromStream(Stream* stream) { + auto dnn = stream->parent()->AsDnn(); + if (dnn == nullptr) { + return absl::InternalError("No DNN support for stream"); + } + return dnn; +} +} // namespace internal + // A lazily-initialized OpRunner from an AlgorithmDesc. // // This exists to hold a choice of conv algorithm for a particular config, @@ -154,8 +166,9 @@ struct ConvOp { static absl::StatusOr>> RunnerFromAlgorithmDesc(const AlgorithmDesc& desc, Config config, Stream* stream) { - return stream->ConvolveRunnerFromDesc( - desc, config.kind, config.input_type, config.output_type, + TF_ASSIGN_OR_RETURN(auto dnn, internal::GetDnnFromStream(stream)); + return dnn->ConvolveRunnerFromDesc( + stream, desc, config.kind, config.input_type, config.output_type, config.input_descriptor, config.filter_descriptor, config.output_descriptor, config.convolution_descriptor); } @@ -179,8 +192,9 @@ struct GraphConvOp { static absl::StatusOr>> RunnerFromAlgorithmDesc(const AlgorithmDesc& desc, Config config, Stream* stream) { - return stream->GraphConvolveRunnerFromDesc( - desc, config.kind, config.input_type, config.output_type, + TF_ASSIGN_OR_RETURN(auto dnn, internal::GetDnnFromStream(stream)); + return dnn->GraphConvolveRunnerFromDesc( + stream, desc, config.kind, config.input_type, config.output_type, config.input_descriptor, config.filter_descriptor, config.output_descriptor, config.convolution_descriptor, config.serialized_graph); @@ -206,8 +220,9 @@ struct FusedConvOp { static absl::StatusOr>> RunnerFromAlgorithmDesc(const AlgorithmDesc& desc, Config config, Stream* stream) { - return stream->FusedConvolveRunnerFromDesc( - desc, config.kind, config.input_type, config.bias_type, + TF_ASSIGN_OR_RETURN(auto dnn, internal::GetDnnFromStream(stream)); + return dnn->FusedConvolveRunnerFromDesc( + stream, desc, config.kind, config.input_type, config.bias_type, config.output_type, config.conv_scale, config.side_input_scale, config.leakyrelu_alpha, config.input_descriptor, config.filter_descriptor, config.bias_descriptor, @@ -221,22 +236,29 @@ struct NormOp { using Signature = NormSignature; struct Config { + NormKind kind; double epsilon; - const TensorDescriptor& input_descriptor; + const TensorDescriptor& x_descriptor; const TensorDescriptor& scale_descriptor; - const TensorDescriptor& bias_descriptor; - const TensorDescriptor& output_descriptor; - std::optional expectation_descriptor; - std::optional norm_factor_descriptor; + const TensorDescriptor& y_or_dx_descriptor; + std::optional bias_descriptor; + std::optional dy_descriptor; + std::optional expectation_descriptor; + std::optional norm_factor_descriptor; + std::optional dscale_descriptor; + std::optional dbias_descriptor; }; static absl::StatusOr>> RunnerFromAlgorithmDesc(const AlgorithmDesc& desc, Config config, Stream* stream) { - return stream->NormRunnerFromDesc( - desc, config.epsilon, config.input_descriptor, config.scale_descriptor, - config.bias_descriptor, config.output_descriptor, - config.expectation_descriptor, config.norm_factor_descriptor); + TF_ASSIGN_OR_RETURN(auto dnn, internal::GetDnnFromStream(stream)); + return dnn->NormRunnerFromDesc( + stream, desc, config.kind, config.epsilon, config.x_descriptor, + config.scale_descriptor, config.y_or_dx_descriptor, + config.bias_descriptor, config.dy_descriptor, + config.expectation_descriptor, config.norm_factor_descriptor, + config.dscale_descriptor, config.dbias_descriptor); } }; @@ -278,8 +300,9 @@ struct FusedMHAOp { static absl::StatusOr>> RunnerFromAlgorithmDesc(const AlgorithmDesc& desc, Config config, Stream* stream) { - return stream->FusedMHARunnerFromDesc( - desc, config.kind, config.bmm1_lhs_descriptor, + TF_ASSIGN_OR_RETURN(auto dnn, internal::GetDnnFromStream(stream)); + return dnn->FusedMHARunnerFromDesc( + stream, desc, config.kind, config.bmm1_lhs_descriptor, config.bmm1_rhs_descriptor, config.bmm2_rhs_descriptor, config.intermediate_bmm2_lhs_descriptor, config.output_descriptor, config.activation_descriptor, config.mask_descriptor, @@ -317,8 +340,9 @@ struct FusedMHABackwardOp { std::unique_ptr>> RunnerFromAlgorithmDesc(const AlgorithmDesc& desc, Config config, Stream* stream) { - return stream->FusedMHABackwardRunnerFromDesc( - desc, config.kind, config.bmm1_grad_gemm1_rhs_descriptor, + TF_ASSIGN_OR_RETURN(auto dnn, internal::GetDnnFromStream(stream)); + return dnn->FusedMHABackwardRunnerFromDesc( + stream, desc, config.kind, config.bmm1_grad_gemm1_rhs_descriptor, config.bmm1_grad_gemm2_rhs_descriptor, config.bmm2_grad_gemm1_lhs_descriptor, config.bmm2_grad_gemm2_rhs_descriptor, config.d_output_descriptor, diff --git a/third_party/xla/xla/stream_executor/memory_allocation.h b/third_party/xla/xla/stream_executor/memory_allocation.h new file mode 100644 index 00000000000000..0e0df2442001e0 --- /dev/null +++ b/third_party/xla/xla/stream_executor/memory_allocation.h @@ -0,0 +1,40 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_MEMORY_ALLOCATION_H_ +#define XLA_STREAM_EXECUTOR_MEMORY_ALLOCATION_H_ + +#include + +namespace stream_executor { + +// An RAII handle for a memory allocated for a device. It can be pinned host +// memory, unified memory, device memory, etc. depending on what kinds of +// memories are supported by underlying device. +class MemoryAllocation { + public: + MemoryAllocation() = default; + virtual ~MemoryAllocation() = default; + + MemoryAllocation(MemoryAllocation&&) = delete; + MemoryAllocation& operator=(MemoryAllocation&&) = delete; + + virtual void* opaque() const = 0; + virtual uint64_t size() const = 0; +}; + +} // namespace stream_executor + +#endif // XLA_STREAM_EXECUTOR_MEMORY_ALLOCATION_H_ diff --git a/third_party/xla/xla/stream_executor/multi_platform_manager.h b/third_party/xla/xla/stream_executor/multi_platform_manager.h index d7a726327a4648..352fae619fe2ea 100644 --- a/third_party/xla/xla/stream_executor/multi_platform_manager.h +++ b/third_party/xla/xla/stream_executor/multi_platform_manager.h @@ -13,129 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This is a registration-oriented interface for multiple platforms. -// -// Usage: -// -// In your BUILD rule, add a dependency on a platform plugin that you'd like -// to use, such as: -// -// //third_party/tensorflow/compiler/xla/stream_executor/cuda:cuda_platform -// //third_party/tensorflow/compiler/xla/stream_executor/opencl:opencl_platform -// -// This will register platform plugins that can be discovered via this -// interface. Sample API usage: -// -// absl::StatusOr platform_status = -// se::MultiPlatformManager::PlatformWithName("OpenCL"); -// if (!platform_status.ok()) { ... } -// Platform* platform = platform_status.value(); -// LOG(INFO) << platform->VisibleDeviceCount() << " devices visible"; -// if (platform->VisibleDeviceCount() <= 0) { return; } -// -// for (int i = 0; i < platform->VisibleDeviceCount(); ++i) { -// absl::StatusOr executor_status = -// platform->ExecutorForDevice(i); -// if (!executor_status.ok()) { -// LOG(INFO) << "could not retrieve executor for device ordinal " << i -// << ": " << executor_status.status(); -// continue; -// } -// LOG(INFO) << "found usable executor: " << executor_status.value(); -// } -// -// A few things to note: -// - There is no standard formatting/practice for identifying the name of a -// platform. Ideally, a platform will list its registered name in its header -// or in other associated documentation. -// - Platform name lookup is case-insensitive. "OpenCL" or "opencl" (or even -// ("OpEnCl") would work correctly in the above example. -// -// And similarly, for standard interfaces (BLAS, etc.) you can add -// dependencies on support libraries, e.g.: -// -// //third_party/tensorflow/compiler/xla/stream_executor/cuda:pluton_blas_plugin -// //third_party/tensorflow/compiler/xla/stream_executor/cuda:cudnn_plugin -// //third_party/tensorflow/compiler/xla/stream_executor/cuda:cublas_plugin - #ifndef XLA_STREAM_EXECUTOR_MULTI_PLATFORM_MANAGER_H_ #define XLA_STREAM_EXECUTOR_MULTI_PLATFORM_MANAGER_H_ -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" namespace stream_executor { - -// Manages multiple platforms that may be present on the current machine. -class MultiPlatformManager { - public: - // Registers a platform object, returns an error status if the platform is - // already registered. The associated listener, if not null, will be used to - // trace events for ALL executors for that platform. - // Takes ownership of platform. - static absl::Status RegisterPlatform(std::unique_ptr platform); - - // Retrieves the platform registered with the given platform name (e.g. - // "CUDA", "OpenCL", ...) or id (an opaque, comparable value provided by the - // Platform's Id() method). - // - // If the platform has not already been initialized, it will be initialized - // with a default set of parameters. - // - // If the requested platform is not registered, an error status is returned. - // Ownership of the platform is NOT transferred to the caller -- - // the MultiPlatformManager owns the platforms in a singleton-like fashion. - static absl::StatusOr PlatformWithName(absl::string_view target); - static absl::StatusOr PlatformWithId(const Platform::Id& id); - - // Same functions as above, but allows platforms to be returned without - // initialization if initialize_platform == false. - static absl::StatusOr PlatformWithName(absl::string_view target, - bool initialize_platform); - - // Retrieves the platform registered with the given platform id (an opaque, - // comparable value provided by the Platform's Id() method). - // - // The platform will be initialized with the given options. If the platform - // was already initialized, an error will be returned. - // - // If the requested platform is not registered, an error status is returned. - // Ownership of the platform is NOT transferred to the caller -- - // the MultiPlatformManager owns the platforms in a singleton-like fashion. - static absl::StatusOr InitializePlatformWithId( - const Platform::Id& id, - const std::map& options); - - // Retrieves the platforms satisfying the given filter, i.e. returns true. - // Returned Platforms are always initialized. - static absl::StatusOr> PlatformsWithFilter( - const std::function& filter); - - static absl::StatusOr> PlatformsWithFilter( - const std::function& filter, - bool initialize_platform); - - // Although the MultiPlatformManager "owns" its platforms, it holds them as - // undecorated pointers to prevent races during program exit (between this - // object's data and the underlying platforms (e.g., CUDA, OpenCL). - // Because certain platforms have unpredictable deinitialization - // times/sequences, it is not possible to strucure a safe deinitialization - // sequence. Thus, we intentionally "leak" allocated platforms to defer - // cleanup to the OS. This should be acceptable, as these are one-time - // allocations per program invocation. - // The MultiPlatformManager should be considered the owner - // of any platforms registered with it, and leak checking should be disabled - // during allocation of such Platforms, to avoid spurious reporting at program - // exit. -}; - +// The name `MultiPlatformManager` is deprecated. Please use `PlatformManager` +// instead and include `platform_manager.h`. +// TODO(hebecker): A migration is to `PlatformManager` is under way. +using MultiPlatformManager [[deprecated("Rename to PlatformManager")]] = + PlatformManager; } // namespace stream_executor #endif // XLA_STREAM_EXECUTOR_MULTI_PLATFORM_MANAGER_H_ diff --git a/third_party/xla/xla/stream_executor/platform.h b/third_party/xla/xla/stream_executor/platform.h index 5556d860093c33..e78a741a0b259e 100644 --- a/third_party/xla/xla/stream_executor/platform.h +++ b/third_party/xla/xla/stream_executor/platform.h @@ -62,7 +62,7 @@ struct StreamExecutorConfig { DeviceOptions device_options; }; -// Abstract base class for a platform registered with the MultiPlatformManager. +// Abstract base class for a platform registered with the PlatformManager. class Platform { public: virtual ~Platform(); @@ -103,7 +103,7 @@ class Platform { // initialized before obtaining StreamExecutor objects. The interpretation of // the platform_options argument is implementation specific. This method may // return an error if unrecognized options are provided. If using - // MultiPlatformManager, this method will be called automatically by + // PlatformManager, this method will be called automatically by // InitializePlatformWithId/InitializePlatformWithName. virtual absl::Status Initialize( const std::map& platform_options); diff --git a/third_party/xla/xla/stream_executor/platform/BUILD b/third_party/xla/xla/stream_executor/platform/BUILD index 485a897c4aba3f..cd73a976be545b 100644 --- a/third_party/xla/xla/stream_executor/platform/BUILD +++ b/third_party/xla/xla/stream_executor/platform/BUILD @@ -1,10 +1,11 @@ load("//xla/stream_executor:build_defs.bzl", "stream_executor_friends") -load("@local_tsl//tsl:tsl.bzl", "set_external_visibility") +load("@local_tsl//tsl:tsl.bzl", "internal_visibility") load("@local_tsl//tsl/platform:build_config.bzl", "tf_stream_executor_deps") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) @@ -20,7 +21,6 @@ cc_library( "platform.h", "port.h", ], - visibility = ["//visibility:public"], deps = [ "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:macros", @@ -31,7 +31,6 @@ cc_library( cc_library( name = "dso_loader", hdrs = ["dso_loader.h"], - visibility = ["//visibility:public"], deps = [ ":platform", ] + tf_stream_executor_deps("dso_loader", "//xla/stream_executor/platform/"), diff --git a/third_party/xla/xla/stream_executor/platform/default/BUILD b/third_party/xla/xla/stream_executor/platform/default/BUILD index cdd0b54a4d29e6..8f201e9468172f 100644 --- a/third_party/xla/xla/stream_executor/platform/default/BUILD +++ b/third_party/xla/xla/stream_executor/platform/default/BUILD @@ -4,13 +4,15 @@ load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") licenses(["notice"]) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//xla/stream_executor:__subpackages__", + ], ) cc_library( name = "platform", textual_hdrs = ["initialize.h"], - visibility = ["//visibility:public"], ) cc_library( @@ -22,7 +24,6 @@ cc_library( "manual", "nobuilder", ], - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:dso_loader", diff --git a/third_party/xla/xla/stream_executor/platform/default/initialize.h b/third_party/xla/xla/stream_executor/platform/default/initialize.h index 559faf28665e6b..cb951ed8b0611c 100644 --- a/third_party/xla/xla/stream_executor/platform/default/initialize.h +++ b/third_party/xla/xla/stream_executor/platform/default/initialize.h @@ -13,13 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// IWYU pragma: private, include "third_party/tensorflow/compiler/xla/stream_executor/platform/initialize.h" + #ifndef XLA_STREAM_EXECUTOR_PLATFORM_DEFAULT_INITIALIZE_H_ #define XLA_STREAM_EXECUTOR_PLATFORM_DEFAULT_INITIALIZE_H_ -#undef REGISTER_MODULE_INITIALIZER -#undef DECLARE_MODULE_INITIALIZER -#undef REGISTER_MODULE_INITIALIZER_SEQUENCE - namespace stream_executor { namespace port { @@ -44,19 +42,20 @@ class Initializer { } // namespace port } // namespace stream_executor -#define REGISTER_INITIALIZER(type, name, body) \ +#define STREAM_EXECUTOR_REGISTER_INITIALIZER(type, name, body) \ static void google_init_##type##_##name() { body; } \ ::stream_executor::port::Initializer google_initializer_##type##_##name( \ google_init_##type##_##name) -#define REGISTER_MODULE_INITIALIZER(name, body) \ - REGISTER_INITIALIZER(module, name, body) +#define STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(name, body) \ + STREAM_EXECUTOR_REGISTER_INITIALIZER(module, name, body) -#define DECLARE_INITIALIZER(type, name) \ +#define STREAM_EXECUTOR_DECLARE_INITIALIZER(type, name) \ extern ::stream_executor::port::Initializer google_initializer_##type##_##name -#define DECLARE_MODULE_INITIALIZER(name) DECLARE_INITIALIZER(module, name) +#define STREAM_EXECUTOR_DECLARE_MODULE_INITIALIZER(name) \ + STREAM_EXECUTOR_DECLARE_INITIALIZER(module, name) -#define REGISTER_MODULE_INITIALIZER_SEQUENCE(name1, name2) +#define STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER_SEQUENCE(name1, name2) #endif // XLA_STREAM_EXECUTOR_PLATFORM_DEFAULT_INITIALIZE_H_ diff --git a/third_party/xla/xla/stream_executor/multi_platform_manager.cc b/third_party/xla/xla/stream_executor/platform_manager.cc similarity index 85% rename from third_party/xla/xla/stream_executor/multi_platform_manager.cc rename to third_party/xla/xla/stream_executor/platform_manager.cc index 7d29885a994e91..ec97c3e9603e5a 100644 --- a/third_party/xla/xla/stream_executor/multi_platform_manager.cc +++ b/third_party/xla/xla/stream_executor/platform_manager.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/stream_executor/multi_platform_manager.h" - #include #include #include @@ -28,6 +26,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/platform.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -35,7 +34,7 @@ limitations under the License. namespace stream_executor { namespace { -class MultiPlatformManagerImpl { +class PlatformManagerImpl { public: absl::Status RegisterPlatform(std::unique_ptr platform) ABSL_LOCKS_EXCLUDED(mu_); @@ -89,7 +88,7 @@ class MultiPlatformManagerImpl { absl::flat_hash_map name_map_ ABSL_GUARDED_BY(mu_); }; -absl::Status MultiPlatformManagerImpl::RegisterPlatform( +absl::Status PlatformManagerImpl::RegisterPlatform( std::unique_ptr platform) { CHECK(platform != nullptr); std::string key = absl::AsciiStrToLower(platform->Name()); @@ -109,17 +108,17 @@ absl::Status MultiPlatformManagerImpl::RegisterPlatform( return absl::OkStatus(); } -absl::StatusOr MultiPlatformManagerImpl::PlatformWithName( +absl::StatusOr PlatformManagerImpl::PlatformWithName( absl::string_view target) { return PlatformWithName(target, /*initialize_platform=*/true); } -absl::StatusOr MultiPlatformManagerImpl::PlatformWithId( +absl::StatusOr PlatformManagerImpl::PlatformWithId( const Platform::Id& id) { return PlatformWithId(id, /*initialize_platform=*/true); } -absl::StatusOr MultiPlatformManagerImpl::PlatformWithName( +absl::StatusOr PlatformManagerImpl::PlatformWithName( absl::string_view target, bool initialize_platform) { absl::MutexLock lock(&mu_); @@ -131,7 +130,7 @@ absl::StatusOr MultiPlatformManagerImpl::PlatformWithName( return platform; } -absl::StatusOr MultiPlatformManagerImpl::PlatformWithId( +absl::StatusOr PlatformManagerImpl::PlatformWithId( const Platform::Id& id, bool initialize_platform) { absl::MutexLock lock(&mu_); @@ -143,7 +142,7 @@ absl::StatusOr MultiPlatformManagerImpl::PlatformWithId( return platform; } -absl::StatusOr MultiPlatformManagerImpl::InitializePlatformWithName( +absl::StatusOr PlatformManagerImpl::InitializePlatformWithName( absl::string_view target, const std::map& options) { absl::MutexLock lock(&mu_); @@ -159,7 +158,7 @@ absl::StatusOr MultiPlatformManagerImpl::InitializePlatformWithName( return platform; } -absl::StatusOr MultiPlatformManagerImpl::InitializePlatformWithId( +absl::StatusOr PlatformManagerImpl::InitializePlatformWithId( const Platform::Id& id, const std::map& options) { absl::MutexLock lock(&mu_); @@ -174,8 +173,7 @@ absl::StatusOr MultiPlatformManagerImpl::InitializePlatformWithId( return platform; } -absl::StatusOr> -MultiPlatformManagerImpl::PlatformsWithFilter( +absl::StatusOr> PlatformManagerImpl::PlatformsWithFilter( const std::function& filter, bool initialize_platform) { absl::MutexLock lock(&mu_); @@ -195,7 +193,7 @@ MultiPlatformManagerImpl::PlatformsWithFilter( } std::vector -MultiPlatformManagerImpl::InitializedPlatformNamesWithFilter( +PlatformManagerImpl::InitializedPlatformNamesWithFilter( const std::function& filter) { CHECK_EQ(id_map_.size(), name_map_.size()); std::vector initialized_platforms_names; @@ -211,7 +209,7 @@ MultiPlatformManagerImpl::InitializedPlatformNamesWithFilter( return initialized_platforms_names; } -absl::StatusOr MultiPlatformManagerImpl::LookupByNameLocked( +absl::StatusOr PlatformManagerImpl::LookupByNameLocked( absl::string_view target) { auto it = name_map_.find(absl::AsciiStrToLower(target)); if (it == name_map_.end()) { @@ -223,7 +221,7 @@ absl::StatusOr MultiPlatformManagerImpl::LookupByNameLocked( return it->second; } -absl::StatusOr MultiPlatformManagerImpl::LookupByIdLocked( +absl::StatusOr PlatformManagerImpl::LookupByIdLocked( const Platform::Id& id) { auto it = id_map_.find(id); if (it == id_map_.end()) { @@ -233,47 +231,46 @@ absl::StatusOr MultiPlatformManagerImpl::LookupByIdLocked( return it->second; } -MultiPlatformManagerImpl& Impl() { - static MultiPlatformManagerImpl* impl = new MultiPlatformManagerImpl; +PlatformManagerImpl& Impl() { + static PlatformManagerImpl* impl = new PlatformManagerImpl; return *impl; } } // namespace -/*static*/ absl::Status MultiPlatformManager::RegisterPlatform( +/*static*/ absl::Status PlatformManager::RegisterPlatform( std::unique_ptr platform) { return Impl().RegisterPlatform(std::move(platform)); } -/*static*/ absl::StatusOr MultiPlatformManager::PlatformWithName( +/*static*/ absl::StatusOr PlatformManager::PlatformWithName( absl::string_view target) { return Impl().PlatformWithName(target); } -/*static*/ absl::StatusOr MultiPlatformManager::PlatformWithId( +/*static*/ absl::StatusOr PlatformManager::PlatformWithId( const Platform::Id& id) { return Impl().PlatformWithId(id); } -/*static*/ absl::StatusOr MultiPlatformManager::PlatformWithName( +/*static*/ absl::StatusOr PlatformManager::PlatformWithName( absl::string_view target, bool initialize_platform) { return Impl().PlatformWithName(target, initialize_platform); } -/*static*/ absl::StatusOr -MultiPlatformManager::InitializePlatformWithId( +/*static*/ absl::StatusOr PlatformManager::InitializePlatformWithId( const Platform::Id& id, const std::map& options) { return Impl().InitializePlatformWithId(id, options); } /*static*/ absl::StatusOr> -MultiPlatformManager::PlatformsWithFilter( +PlatformManager::PlatformsWithFilter( const std::function& filter) { return PlatformsWithFilter(filter, /*initialize_platform=*/true); } /*static*/ absl::StatusOr> -MultiPlatformManager::PlatformsWithFilter( +PlatformManager::PlatformsWithFilter( const std::function& filter, bool initialize_platform) { return Impl().PlatformsWithFilter(filter, initialize_platform); diff --git a/third_party/xla/xla/stream_executor/platform_manager.h b/third_party/xla/xla/stream_executor/platform_manager.h new file mode 100644 index 00000000000000..665bf3c2cd0eb2 --- /dev/null +++ b/third_party/xla/xla/stream_executor/platform_manager.h @@ -0,0 +1,141 @@ +/* Copyright 2015 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is a registration-oriented interface for multiple platforms. +// +// Usage: +// +// In your BUILD rule, add a dependency on a platform plugin that you'd like +// to use, such as: +// +// //third_party/tensorflow/compiler/xla/stream_executor/cuda:cuda_platform +// //third_party/tensorflow/compiler/xla/stream_executor/opencl:opencl_platform +// +// This will register platform plugins that can be discovered via this +// interface. Sample API usage: +// +// absl::StatusOr platform_status = +// se::PlatformManager::PlatformWithName("OpenCL"); +// if (!platform_status.ok()) { ... } +// Platform* platform = platform_status.value(); +// LOG(INFO) << platform->VisibleDeviceCount() << " devices visible"; +// if (platform->VisibleDeviceCount() <= 0) { return; } +// +// for (int i = 0; i < platform->VisibleDeviceCount(); ++i) { +// absl::StatusOr executor_status = +// platform->ExecutorForDevice(i); +// if (!executor_status.ok()) { +// LOG(INFO) << "could not retrieve executor for device ordinal " << i +// << ": " << executor_status.status(); +// continue; +// } +// LOG(INFO) << "found usable executor: " << executor_status.value(); +// } +// +// A few things to note: +// - There is no standard formatting/practice for identifying the name of a +// platform. Ideally, a platform will list its registered name in its header +// or in other associated documentation. +// - Platform name lookup is case-insensitive. "OpenCL" or "opencl" (or even +// ("OpEnCl") would work correctly in the above example. +// +// And similarly, for standard interfaces (BLAS, etc.) you can add +// dependencies on support libraries, e.g.: +// +// //third_party/tensorflow/compiler/xla/stream_executor/cuda:pluton_blas_plugin +// //third_party/tensorflow/compiler/xla/stream_executor/cuda:cudnn_plugin +// //third_party/tensorflow/compiler/xla/stream_executor/cuda:cublas_plugin + +#ifndef XLA_STREAM_EXECUTOR_PLATFORM_MANAGER_H_ +#define XLA_STREAM_EXECUTOR_PLATFORM_MANAGER_H_ + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/stream_executor/platform.h" + +namespace stream_executor { + +// Manages multiple platforms that may be present on the current machine. +class PlatformManager { + public: + // Registers a platform object, returns an error status if the platform is + // already registered. The associated listener, if not null, will be used to + // trace events for ALL executors for that platform. + // Takes ownership of platform. + static absl::Status RegisterPlatform(std::unique_ptr platform); + + // Retrieves the platform registered with the given platform name (e.g. + // "CUDA", "OpenCL", ...) or id (an opaque, comparable value provided by the + // Platform's Id() method). + // + // If the platform has not already been initialized, it will be initialized + // with a default set of parameters. + // + // If the requested platform is not registered, an error status is returned. + // Ownership of the platform is NOT transferred to the caller -- + // the PlatformManager owns the platforms in a singleton-like fashion. + static absl::StatusOr PlatformWithName(absl::string_view target); + static absl::StatusOr PlatformWithId(const Platform::Id& id); + + // Same functions as above, but allows platforms to be returned without + // initialization if initialize_platform == false. + static absl::StatusOr PlatformWithName(absl::string_view target, + bool initialize_platform); + + // Retrieves the platform registered with the given platform id (an opaque, + // comparable value provided by the Platform's Id() method). + // + // The platform will be initialized with the given options. If the platform + // was already initialized, an error will be returned. + // + // If the requested platform is not registered, an error status is returned. + // Ownership of the platform is NOT transferred to the caller -- + // the PlatformManager owns the platforms in a singleton-like fashion. + static absl::StatusOr InitializePlatformWithId( + const Platform::Id& id, + const std::map& options); + + // Retrieves the platforms satisfying the given filter, i.e. returns true. + // Returned Platforms are always initialized. + static absl::StatusOr> PlatformsWithFilter( + const std::function& filter); + + static absl::StatusOr> PlatformsWithFilter( + const std::function& filter, + bool initialize_platform); + + // Although the PlatformManager "owns" its platforms, it holds them as + // undecorated pointers to prevent races during program exit (between this + // object's data and the underlying platforms (e.g., CUDA, OpenCL). + // Because certain platforms have unpredictable deinitialization + // times/sequences, it is not possible to strucure a safe deinitialization + // sequence. Thus, we intentionally "leak" allocated platforms to defer + // cleanup to the OS. This should be acceptable, as these are one-time + // allocations per program invocation. + // The PlatformManager should be considered the owner + // of any platforms registered with it, and leak checking should be disabled + // during allocation of such Platforms, to avoid spurious reporting at program + // exit. +}; + +} // namespace stream_executor + +#endif // XLA_STREAM_EXECUTOR_PLATFORM_MANAGER_H_ diff --git a/third_party/xla/xla/stream_executor/rocm/BUILD b/third_party/xla/xla/stream_executor/rocm/BUILD index 5afb310388f889..e2157eb0028ce3 100644 --- a/third_party/xla/xla/stream_executor/rocm/BUILD +++ b/third_party/xla/xla/stream_executor/rocm/BUILD @@ -16,12 +16,13 @@ load( "if_rocm_is_configured", "rocm_library", ) -load("@local_tsl//tsl:tsl.bzl", "set_external_visibility", "tsl_copts") +load("@local_tsl//tsl:tsl.bzl", "internal_visibility", "tsl_copts") load("@local_tsl//tsl/platform:build_config_root.bzl", "if_static") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) @@ -34,7 +35,6 @@ cc_library( name = "rocm_diagnostics", srcs = if_rocm_is_configured(["rocm_diagnostics.cc"]), hdrs = if_rocm_is_configured(["rocm_diagnostics.h"]), - visibility = ["//visibility:public"], deps = if_rocm_is_configured([ "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", @@ -52,7 +52,6 @@ cc_library( "rocm_driver_wrapper.h", "rocm_driver.h", ]), - visibility = ["//visibility:public"], deps = if_rocm_is_configured([ ":rocm_diagnostics", "@com_google_absl//absl/base", @@ -78,7 +77,6 @@ cc_library( "rocm_driver_wrapper.h", "rocm_driver.h", ]), - visibility = ["//visibility:public"], deps = if_rocm_is_configured([ "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", @@ -99,7 +97,6 @@ cc_library( cc_library( name = "rocm_collectives", srcs = if_rocm_is_configured(["rocm_collectives.cc"]), - visibility = ["//visibility:public"], deps = if_rocm_is_configured([ "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -112,7 +109,6 @@ cc_library( name = "rocm_activation", srcs = [], hdrs = if_rocm_is_configured(["rocm_activation.h"]), - visibility = ["//visibility:public"], deps = if_rocm_is_configured([ ":rocm_driver", "@local_config_rocm//rocm:rocm_headers", @@ -126,7 +122,6 @@ cc_library( cc_library( name = "rocm_event", srcs = if_rocm_is_configured(["rocm_event.cc"]), - visibility = ["//visibility:public"], deps = if_rocm_is_configured([ ":rocm_driver", "//xla/stream_executor", @@ -139,7 +134,6 @@ cc_library( cc_library( name = "rocm_executor", srcs = if_rocm_is_configured(["rocm_executor.cc"]), - visibility = ["//visibility:public"], deps = if_rocm_is_configured([ ":rocm_diagnostics", ":rocm_driver", @@ -205,20 +199,18 @@ cc_library( "//xla/stream_executor", # buildcleaner: keep "//xla/stream_executor/platform", ]), - alwayslink = True, # Registers itself with the MultiPlatformManager. + alwayslink = True, # Registers itself with the PlatformManager. ) cc_library( name = "rocm_platform_id", srcs = ["rocm_platform_id.cc"], hdrs = ["rocm_platform_id.h"], - visibility = ["//visibility:public"], deps = ["//xla/stream_executor:platform"], ) cc_library( name = "rocblas_if_static", - visibility = ["//visibility:public"], deps = if_static([ ":rocblas_if_rocm_configured", ]), @@ -226,7 +218,6 @@ cc_library( cc_library( name = "rocblas_if_rocm_configured", - visibility = ["//visibility:public"], deps = if_rocm_is_configured([ "@local_config_rocm//rocm:rocblas", ]), @@ -236,7 +227,6 @@ cc_library( name = "rocblas_wrapper", srcs = if_rocm_is_configured(["rocblas_wrapper.h"]), hdrs = if_rocm_is_configured(["rocblas_wrapper.h"]), - visibility = ["//visibility:public"], deps = if_rocm_is_configured([ ":rocblas_if_static", ":rocm_executor", @@ -282,7 +272,6 @@ cc_library( cc_library( name = "hipfft_if_static", - visibility = ["//visibility:public"], deps = if_static([ ":hipfft_if_rocm_configured", ]), @@ -290,7 +279,6 @@ cc_library( cc_library( name = "hipfft_if_rocm_configured", - visibility = ["//visibility:public"], deps = if_rocm_is_configured([ "@local_config_rocm//rocm:hipfft", ]), @@ -322,7 +310,6 @@ cc_library( cc_library( name = "miopen_if_static", - visibility = ["//visibility:public"], deps = if_static([ ":miopen_if_rocm_configured", ]), @@ -330,7 +317,6 @@ cc_library( cc_library( name = "miopen_if_rocm_configured", - visibility = ["//visibility:public"], deps = if_rocm_is_configured([ "@local_config_rocm//rocm:miopen", ]), @@ -360,6 +346,7 @@ cc_library( "//xla/stream_executor/gpu:gpu_stream_header", "//xla/stream_executor/gpu:gpu_timer_header", "//xla/stream_executor/gpu:gpu_types_header", + "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor/platform", "//xla/stream_executor/platform:dso_loader", "@com_google_absl//absl/algorithm:container", @@ -375,7 +362,6 @@ cc_library( cc_library( name = "hiprand_if_static", - visibility = ["//visibility:public"], deps = if_static([ ":hiprand_if_rocm_configured", ]), @@ -383,7 +369,6 @@ cc_library( cc_library( name = "hiprand_if_rocm_configured", - visibility = ["//visibility:public"], deps = if_rocm_is_configured([ "@local_config_rocm//rocm:hiprand", ]), @@ -391,7 +376,6 @@ cc_library( cc_library( name = "hipsparse_if_static", - visibility = ["//visibility:public"], deps = if_static([ ":hipsparse_if_rocm_configured", ]), @@ -399,7 +383,6 @@ cc_library( cc_library( name = "hipsparse_if_rocm_configured", - visibility = ["//visibility:public"], deps = if_rocm_is_configured([ "@local_config_rocm//rocm:hipsparse", ]), @@ -409,7 +392,6 @@ cc_library( name = "hipsparse_wrapper", srcs = if_rocm_is_configured(["hipsparse_wrapper.h"]), hdrs = if_rocm_is_configured(["hipsparse_wrapper.h"]), - visibility = ["//visibility:public"], deps = if_rocm_is_configured([ ":hipsparse_if_static", ":rocm_executor", @@ -424,7 +406,6 @@ cc_library( cc_library( name = "rocsolver_if_static", - visibility = ["//visibility:public"], deps = if_static([ ":rocsolver_if_rocm_configured", ]), @@ -432,7 +413,6 @@ cc_library( cc_library( name = "rocsolver_if_rocm_configured", - visibility = ["//visibility:public"], deps = if_rocm_is_configured([ "@local_config_rocm//rocm:rocsolver", ]), @@ -442,7 +422,6 @@ cc_library( name = "rocsolver_wrapper", srcs = if_rocm_is_configured(["rocsolver_wrapper.h"]), hdrs = if_rocm_is_configured(["rocsolver_wrapper.h"]), - visibility = ["//visibility:public"], deps = if_rocm_is_configured([ ":rocm_executor", ":rocm_platform_id", @@ -457,7 +436,6 @@ cc_library( cc_library( name = "hipsolver_if_static", - visibility = ["//visibility:public"], deps = if_static([ ":hipsolver_if_rocm_configured", ]), @@ -465,7 +443,6 @@ cc_library( cc_library( name = "hipsolver_if_rocm_configured", - visibility = ["//visibility:public"], deps = if_rocm_is_configured([ "@local_config_rocm//rocm:hipsolver", ]), @@ -475,7 +452,6 @@ cc_library( name = "hipsolver_wrapper", srcs = if_rocm_is_configured(["hipsolver_wrapper.h"]), hdrs = if_rocm_is_configured(["hipsolver_wrapper.h"]), - visibility = ["//visibility:public"], deps = if_rocm_is_configured([ ":rocm_executor", ":rocm_platform_id", @@ -490,7 +466,6 @@ cc_library( cc_library( name = "hipblaslt_if_static", - visibility = ["//visibility:public"], deps = if_rocm_hipblaslt([ "@local_config_rocm//rocm:hipblaslt", ]), @@ -504,7 +479,6 @@ cc_library( "hipblaslt_wrapper.h", "hip_blas_utils.h", ]), - visibility = ["//visibility:public"], deps = if_rocm_is_configured([ ":rocm_executor", ":rocm_platform_id", @@ -550,7 +524,6 @@ cc_library( name = "hip_blas_utils", srcs = if_rocm_is_configured(["hip_blas_utils.cc"]), hdrs = if_rocm_is_configured(["hip_blas_utils.h"]), - visibility = ["//visibility:public"], deps = if_rocm_is_configured([ ":rocblas_plugin", ":hipblas_lt_header", @@ -564,7 +537,6 @@ cc_library( cc_library( name = "roctracer_if_static", - visibility = ["//visibility:public"], deps = if_static([ ":roctracer_if_rocm_configured", ]), @@ -572,7 +544,6 @@ cc_library( cc_library( name = "roctracer_if_rocm_configured", - visibility = ["//visibility:public"], deps = if_rocm_is_configured([ "@local_config_rocm//rocm:roctracer", ]), @@ -582,7 +553,6 @@ cc_library( name = "roctracer_wrapper", srcs = if_rocm_is_configured(["roctracer_wrapper.h"]), hdrs = if_rocm_is_configured(["roctracer_wrapper.h"]), - visibility = ["//visibility:public"], deps = if_rocm_is_configured([ ":rocm_executor", ":rocm_platform_id", @@ -628,13 +598,11 @@ cc_library( "-Wl,-rpath,../local_config_rocm/rocm/rocm/lib", ], }), - visibility = ["//visibility:public"], deps = [], ) cc_library( name = "stream_executor_rocm", - visibility = ["//visibility:public"], deps = [ ":rocm_rpath", "//xla/stream_executor:stream_executor_bundle", diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc index bbc7f834fccfcb..0a7df71587cd5e 100644 --- a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc +++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc @@ -110,6 +110,14 @@ static absl::StatusOr AsHipblasLtEpilogue( return HIPBLASLT_EPILOGUE_RELU_BIAS; case gpu::BlasLt::Epilogue::kGELU: return HIPBLASLT_EPILOGUE_GELU; +#if TF_ROCM_VERSION >= 60000 + case gpu::BlasLt::Epilogue::kGELUWithAux: + return HIPBLASLT_EPILOGUE_GELU_AUX; + case gpu::BlasLt::Epilogue::kBiasThenGELU: + return HIPBLASLT_EPILOGUE_GELU_BIAS; + case gpu::BlasLt::Epilogue::kBiasThenGELUWithAux: + return HIPBLASLT_EPILOGUE_GELU_AUX_BIAS; +#endif default: return absl::InternalError("Unsupported epilogue: " + std::to_string((int)epilogue)); diff --git a/third_party/xla/xla/stream_executor/rocm/hip_conditional_kernels.cu.cc b/third_party/xla/xla/stream_executor/rocm/hip_conditional_kernels.cu.cc index 6bec473fcc8e48..d88ae44f240d3a 100644 --- a/third_party/xla/xla/stream_executor/rocm/hip_conditional_kernels.cu.cc +++ b/third_party/xla/xla/stream_executor/rocm/hip_conditional_kernels.cu.cc @@ -13,33 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include -namespace stream_executor { -namespace rocm { -namespace { +namespace stream_executor::gpu { -__global__ void SetCondition() {} +std::string_view GetSetIfConditionKernel() { return ""; } +std::string_view GetSetIfElseConditionKernel() { return ""; } +std::string_view GetSetCaseConditionKernel() { return ""; } +std::string_view GetSetForConditionKernel() { return ""; } +std::string_view GetSetWhileConditionKernel() { return ""; } -} // namespace -} // namespace rocm - -namespace gpu { -void* GetSetIfConditionKernel() { - return reinterpret_cast(&rocm::SetCondition); -} -void* GetSetIfElseConditionKernel() { - return reinterpret_cast(&rocm::SetCondition); -} -void* GetSetCaseConditionKernel() { - return reinterpret_cast(&rocm::SetCondition); -} -void* GetSetForConditionKernel() { - return reinterpret_cast(&rocm::SetCondition); -} -void* GetSetWhileConditionKernel() { - return reinterpret_cast(&rocm::SetCondition); -} -} // namespace gpu - -} // namespace stream_executor +} // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/rocm/rocblas_wrapper.h b/third_party/xla/xla/stream_executor/rocm/rocblas_wrapper.h index 116d9fc03ad578..a637d68428d5e3 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocblas_wrapper.h +++ b/third_party/xla/xla/stream_executor/rocm/rocblas_wrapper.h @@ -20,6 +20,8 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_ROCM_ROCBLAS_WRAPPER_H_ #define XLA_STREAM_EXECUTOR_ROCM_ROCBLAS_WRAPPER_H_ +// needed for rocblas_gemm_ex_get_solutions* functionality +#define ROCBLAS_BETA_FEATURES_API #include "rocm/include/rocblas/rocblas.h" #include "xla/stream_executor/gpu/gpu_activation.h" #include "xla/stream_executor/platform/dso_loader.h" @@ -32,44 +34,42 @@ namespace wrap { using stream_executor::internal::CachedDsoLoader::GetRocblasDsoHandle; #ifdef PLATFORM_GOOGLE -#define ROCBLAS_API_WRAPPER(__name) \ - struct WrapperShim__##__name { \ - static const char* kName; \ - template \ - rocblas_status operator()(Args... args) { \ - return ::__name(args...); \ - } \ - } __name; \ - const char* WrapperShim__##__name::kName = #__name; +#define ROCBLAS_API_WRAPPER(__name) \ + struct WrapperShim__##__name { \ + constexpr static const char* kName = #__name; \ + template \ + rocblas_status operator()(Args... args) { \ + return ::__name(args...); \ + } \ + } __name; #else -#define ROCBLAS_API_WRAPPER(__name) \ - struct DynLoadShim__##__name { \ - static const char* kName; \ - using FuncPtrT = std::add_pointer::type; \ - static void* GetDsoHandle() { \ - auto s = GetRocblasDsoHandle(); \ - return s.value(); \ - } \ - static FuncPtrT LoadOrDie() { \ - void* f; \ - auto s = tsl::Env::Default() \ - -> GetSymbolFromLibrary(GetDsoHandle(), kName, &f); \ - CHECK(s.ok()) << "could not find " << kName \ - << " in rocblas DSO; dlerror: " << s.message(); \ - return reinterpret_cast(f); \ - } \ - static FuncPtrT DynLoad() { \ - static FuncPtrT f = LoadOrDie(); \ - return f; \ - } \ - template \ - rocblas_status operator()(Args... args) { \ - return DynLoad()(args...); \ - } \ - } __name; \ - const char* DynLoadShim__##__name::kName = #__name; +#define ROCBLAS_API_WRAPPER(__name) \ + struct DynLoadShim__##__name { \ + constexpr static const char* kName = #__name; \ + using FuncPtrT = std::add_pointer::type; \ + static void* GetDsoHandle() { \ + auto s = GetRocblasDsoHandle(); \ + return s.value(); \ + } \ + static FuncPtrT LoadOrDie() { \ + void* f; \ + auto s = tsl::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \ + kName, &f); \ + CHECK(s.ok()) << "could not find " << kName \ + << " in rocblas DSO; dlerror: " << s.message(); \ + return reinterpret_cast(f); \ + } \ + static FuncPtrT DynLoad() { \ + static FuncPtrT f = LoadOrDie(); \ + return f; \ + } \ + template \ + rocblas_status operator()(Args... args) { \ + return DynLoad()(args...); \ + } \ + } __name; #endif @@ -257,6 +257,11 @@ using stream_executor::internal::CachedDsoLoader::GetRocblasDsoHandle; __macro(rocblas_zgemm_strided_batched) \ __macro(rocblas_gemm_ex) \ __macro(rocblas_gemm_strided_batched_ex) \ + __macro(rocblas_gemm_ex_get_solutions) \ + __macro(rocblas_gemm_ex_get_solutions_by_type) \ + __macro(rocblas_gemm_batched_ex_get_solutions) \ + __macro(rocblas_gemm_batched_ex_get_solutions_by_type) \ + __macro(rocblas_gemm_strided_batched_ex_get_solutions) \ __macro(rocblas_strsm_batched) \ __macro(rocblas_dtrsm_batched) \ __macro(rocblas_ctrsm_batched) \ diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_blas.cc b/third_party/xla/xla/stream_executor/rocm/rocm_blas.cc index cc5bc55e888812..129349f7144a5b 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_blas.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_blas.cc @@ -52,48 +52,56 @@ extern void rocm_Broadcast_fp32(void *stream, float *dst, int dst_stride, int size); template -const typename RocBlasTypeConversionHelper::mapped_type *complex_cast( - const DeviceMemory &a) { - return reinterpret_cast< - const typename RocBlasTypeConversionHelper::mapped_type *>( - GpuMemory(a)); +const RocBlasType_t *const *complex_cast(const DeviceMemory &a) { + return reinterpret_cast *const *>(GpuMemory(a)); } template -const typename RocBlasTypeConversionHelper::mapped_type *complex_cast( - const T &a) { - return reinterpret_cast< - const typename RocBlasTypeConversionHelper::mapped_type *>(&a); +RocBlasType_t *const *complex_cast(DeviceMemory &a) { + return reinterpret_cast *const *>(GpuMemory(a)); } + template -typename RocBlasTypeConversionHelper::mapped_type *complex_cast( - DeviceMemory *a) { - return reinterpret_cast< - typename RocBlasTypeConversionHelper::mapped_type *>( - GpuMemoryMutable(a)); +const RocBlasType_t *complex_cast(const DeviceMemory &a) { + return reinterpret_cast *>(GpuMemory(a)); } -static void blas_log(const char *c) {} +template +const RocBlasType_t *complex_cast(const T &a) { + return reinterpret_cast *>(&a); +} +template +RocBlasType_t *complex_cast(DeviceMemory *a) { + return reinterpret_cast *>(GpuMemoryMutable(a)); +} static string ToString(rocblas_status status) { +#define XVAL(x) \ + case x: \ + return #x switch (status) { - case rocblas_status_success: - return "rocblas_status_success"; - case rocblas_status_invalid_handle: - return "rocblas_status_invalid_handle"; - case rocblas_status_not_implemented: - return "rocblas_status_not_implemented"; - case rocblas_status_invalid_pointer: - return "rocblas_status_invalid_pointer"; - case rocblas_status_invalid_size: - return "rocblas_status_invalid_size"; - case rocblas_status_memory_error: - return "rocblas_status_memory_error"; - case rocblas_status_internal_error: - return "rocblas_status_internal_error"; + XVAL(rocblas_status_success); + XVAL(rocblas_status_invalid_handle); + XVAL(rocblas_status_not_implemented); + XVAL(rocblas_status_invalid_pointer); + XVAL(rocblas_status_invalid_size); + XVAL(rocblas_status_memory_error); + XVAL(rocblas_status_internal_error); +#if TF_ROCM_VERSION >= 60000 + XVAL(rocblas_status_perf_degraded); + XVAL(rocblas_status_size_query_mismatch); + XVAL(rocblas_status_size_increased); + XVAL(rocblas_status_size_unchanged); + XVAL(rocblas_status_invalid_value); + XVAL(rocblas_status_continue); + XVAL(rocblas_status_check_numerics_fail); + XVAL(rocblas_status_excluded_from_build); + XVAL(rocblas_status_arch_mismatch); +#endif default: return absl::StrCat(""); } +#undef XVAL } bool ROCMBlas::Init() { @@ -110,6 +118,17 @@ bool ROCMBlas::Init() { return false; } #endif + + int dev = 0; + hipError_t result = hipGetDevice(&dev); + hipDeviceProp_t props; + result = hipGetDeviceProperties(&props, dev); + if (result == hipSuccess) { + auto cap = RocmComputeCapability(props.gcnArchName); + has_mfma_ = cap.has_mfma_instr_support(); + use_hgemm_alt_impl_ = (cap.gfx_version() == "90a"); + } + return true; } @@ -203,17 +222,113 @@ rocblas_side ROCMBlasSide(blas::Side side) { } } +absl::StatusOr AsRocBlasType(blas::DataType type) { + switch (type) { + case blas::DataType::kHalf: + return rocblas_datatype_f16_r; + case blas::DataType::kBF16: + return rocblas_datatype_bf16_r; + case blas::DataType::kFloat: + return rocblas_datatype_f32_r; + case blas::DataType::kDouble: + return rocblas_datatype_f64_r; + case blas::DataType::kInt8: + return rocblas_datatype_i8_r; + case blas::DataType::kInt32: + return rocblas_datatype_i32_r; + case blas::DataType::kComplexFloat: + return rocblas_datatype_f32_c; + case blas::DataType::kComplexDouble: + return rocblas_datatype_f64_c; + default: + return absl::InternalError( + absl::StrFormat("Unsupported blas data type: %d", (int)type)); + } +} + +absl::StatusOr AsRocBlasComputeType( + blas::ComputationType type) { + switch (type) { + case blas::ComputationType::kF16: + return rocblas_datatype_f16_r; + case blas::ComputationType::kF32: + return rocblas_datatype_f32_r; + case blas::ComputationType::kF64: + return rocblas_datatype_f64_r; + case blas::ComputationType::kI32: + return rocblas_datatype_i32_r; + case blas::ComputationType::kF16AsF32: + case blas::ComputationType::kBF16AsF32: + case blas::ComputationType::kTF32AsF32: + default: + return absl::InternalError( + absl::StrFormat("Unsupported compute type: %d", (int)type)); + } +} + +void CheckPreconditions(blas::Transpose transa, blas::Transpose transb, + uint64_t m, uint64_t n, uint64_t k, + blas::DataType dtype, int lda, int ldb) { + if (dtype == blas::DataType::kHalf || dtype == blas::DataType::kFloat) { + if (transa == blas::Transpose::kNoTranspose) { + if (lda < static_cast(m)) { + LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); " + "precondition violation"; + } + } else { + if (lda < static_cast(k)) { + LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k + << ") (transpose case); precondition violation"; + } + } + if (transb == blas::Transpose::kNoTranspose) { + if (ldb < static_cast(k)) { + LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k + << ") (no transpose case); precondition violation"; + } + } else { + if (ldb < static_cast(n)) { + LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); " + "precondition violation"; + } + } + } +} + +uint32_t GemmFloat16Flags(blas::DataType dtype, blas::CallContext context, + bool use_alt_impl) { + bool is_backprop = (context == blas::CallContext::kBackpropInput1 || + context == blas::CallContext::kBackpropInput2); + + return ((dtype == blas::DataType::kHalf) && is_backprop && use_alt_impl) + ? rocblas_gemm_flags_fp16_alt_impl + : rocblas_gemm_flags_none; +} + +absl::Status PopulateProfileFromTimer( + std::optional &timer, blas::AlgorithmType algorithm, + blas::ProfileResult *output_profile_result) { + if (output_profile_result) { + TF_ASSIGN_OR_RETURN(absl::Duration duration, timer->GetElapsedDuration()); + output_profile_result->set_is_valid(true); + output_profile_result->set_algorithm(algorithm); + output_profile_result->set_elapsed_time_in_ms( + absl::ToDoubleMilliseconds(duration)); + } + return absl::OkStatus(); +} + } // namespace template -bool ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream, - bool pointer_mode_host, bool err_on_failure, - Args... args) { +absl::Status ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream, + bool pointer_mode_host, + bool err_on_failure, Args &&...args) { absl::MutexLock lock{&mu_}; CHECK(blas_ != nullptr); if (!SetStream(stream)) { - return false; + return absl::InternalError("Setting stream failed"); } gpu::ScopedActivateExecutorContext sac{parent_}; @@ -224,23 +339,26 @@ bool ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream, if (!allow_atomics) { ret = wrap::rocblas_set_atomics_mode(blas_, rocblas_atomics_not_allowed); if (err_on_failure && ret != rocblas_status_success) { - LOG(ERROR) << "failed to to set atomics mode before " - << rocblas_func.kName << ": " << ToString(ret); + LOG(ERROR) << "failed to to set atomics mode before " << FuncT::kName + << ": " << ToString(ret); } } - ret = rocblas_func(blas_, args...); - if (err_on_failure && ret != rocblas_status_success) { - LOG(ERROR) << "failed to run ROCBLAS routine " << rocblas_func.kName << ": " - << ToString(ret); + ret = rocblas_func(blas_, std::forward(args)...); + if (ret != rocblas_status_success) { + auto err_str = + absl::StrFormat("%s failed with: %s", FuncT::kName, ToString(ret)); + if (err_on_failure) { + LOG(ERROR) << err_str; + } + return absl::InternalError(err_str); } - return ret == rocblas_status_success; + return absl::OkStatus(); } bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64_t elem_count, float alpha, const DeviceMemory &x, int incx, DeviceMemory *y, int incy) { - blas_log("DoBlasAxpy"); return DoBlasInternal(wrap::rocblas_saxpy, stream, /* pointer_mode_host = */ true, elem_count, &alpha, GpuMemory(x), incx, GpuMemoryMutable(y), incy); @@ -254,235 +372,143 @@ bool ROCMBlas::DoBlasCopy(Stream *stream, uint64_t elem_count, GpuMemory(x), incx, GpuMemoryMutable(y), incy); } -bool ROCMBlas::DoBlasScal(Stream *stream, uint64_t elem_count, float alpha, - DeviceMemory *x, int incx) { - blas_log("DoBlasScal"); - return DoBlasInternal(wrap::rocblas_sscal, stream, - /* pointer_mode_host = */ true, elem_count, &alpha, - GpuMemoryMutable(x), incx); -} - -bool ROCMBlas::DoBlasScal(Stream *stream, uint64_t elem_count, double alpha, - DeviceMemory *x, int incx) { - return DoBlasInternal(wrap::rocblas_dscal, stream, - /* pointer_mode_host = */ true, elem_count, &alpha, - GpuMemoryMutable(x), incx); -} - -bool ROCMBlas::DoBlasScal(Stream *stream, uint64_t elem_count, float alpha, - DeviceMemory> *x, int incx) { - return DoBlasInternal(wrap::rocblas_csscal, stream, - /* pointer_mode_host = */ true, elem_count, &alpha, - complex_cast(x), incx); -} - -bool ROCMBlas::DoBlasScal(Stream *stream, uint64_t elem_count, double alpha, - DeviceMemory> *x, int incx) { - return DoBlasInternal(wrap::rocblas_zdscal, stream, - /* pointer_mode_host = */ true, elem_count, &alpha, - complex_cast(x), incx); -} - -bool ROCMBlas::DoBlasScal(Stream *stream, uint64_t elem_count, - std::complex alpha, - DeviceMemory> *x, int incx) { - return DoBlasInternal(wrap::rocblas_cscal, stream, - /* pointer_mode_host = */ true, elem_count, - complex_cast(alpha), complex_cast(x), incx); -} - -bool ROCMBlas::DoBlasScal(Stream *stream, uint64_t elem_count, - std::complex alpha, - DeviceMemory> *x, int incx) { - return DoBlasInternal(wrap::rocblas_zscal, stream, - /* pointer_mode_host = */ true, elem_count, - complex_cast(alpha), complex_cast(x), incx); -} - -bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, - uint64_t n, float alpha, const DeviceMemory &a, - int lda, const DeviceMemory &x, int incx, - float beta, DeviceMemory *y, int incy) { - blas_log("DoBlasGemv"); - return DoBlasInternal( - wrap::rocblas_sgemv, stream, /* pointer_mode_host = */ true, - ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x), - incx, &beta, GpuMemoryMutable(y), incy); -} - -bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, - uint64_t n, double alpha, - const DeviceMemory &a, int lda, - const DeviceMemory &x, int incx, double beta, - DeviceMemory *y, int incy) { - blas_log("DoBlasGemv"); - return DoBlasInternal( - wrap::rocblas_dgemv, stream, /* pointer_mode_host = */ true, - ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x), - incx, &beta, GpuMemoryMutable(y), incy); -} - -bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, - uint64_t n, std::complex alpha, - const DeviceMemory> &a, int lda, - const DeviceMemory> &x, int incx, - std::complex beta, - DeviceMemory> *y, int incy) { - blas_log("DoBlasGemv"); - return DoBlasInternal( - wrap::rocblas_cgemv, stream, /* pointer_mode_host = */ true, - ROCMBlasTranspose(trans), m, n, complex_cast(alpha), complex_cast(a), lda, - complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy); -} +#define Impl_DoBlasScal(Fun, T, Ta) \ + bool ROCMBlas::DoBlasScal(Stream *stream, uint64_t elem_count, Ta alpha, \ + DeviceMemory *x, int incx) { \ + return DoBlasInternal(Fun, stream, /* pointer_mode_host = */ true, \ + elem_count, complex_cast(alpha), complex_cast(x), \ + incx); \ + } -bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, - uint64_t n, std::complex alpha, - const DeviceMemory> &a, int lda, - const DeviceMemory> &x, int incx, - std::complex beta, - DeviceMemory> *y, int incy) { - blas_log("DoBlasGemv\n"); - return DoBlasInternal( - wrap::rocblas_zgemv, stream, /* pointer_mode_host = */ true, - ROCMBlasTranspose(trans), m, n, complex_cast(alpha), complex_cast(a), lda, - complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy); -} +Impl_DoBlasScal(wrap::rocblas_sscal, float, + float) Impl_DoBlasScal(wrap::rocblas_dscal, double, double) + Impl_DoBlasScal(wrap::rocblas_csscal, std::complex, float) + Impl_DoBlasScal(wrap::rocblas_zdscal, std::complex, double) + Impl_DoBlasScal(wrap::rocblas_cscal, std::complex, + std::complex) + Impl_DoBlasScal(wrap::rocblas_zscal, std::complex, + std::complex) +#define Impl_DoBlasGemv(fun, T) \ + bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, \ + uint64_t n, T alpha, const DeviceMemory &a, \ + int lda, const DeviceMemory &x, int incx, \ + T beta, DeviceMemory *y, int incy) { \ + return DoBlasInternal(fun, stream, /* pointer_mode_host = */ true, \ + ROCMBlasTranspose(trans), m, n, complex_cast(alpha), \ + complex_cast(a), lda, complex_cast(x), incx, \ + complex_cast(beta), complex_cast(y), incy); \ + } -bool ROCMBlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64_t n, - uint64_t k, float alpha, const DeviceMemory &a, - int lda, const DeviceMemory &x, int incx, - float beta, DeviceMemory *y, int incy) { + Impl_DoBlasGemv(wrap::rocblas_sgemv, float) + Impl_DoBlasGemv(wrap::rocblas_dgemv, double) + Impl_DoBlasGemv(wrap::rocblas_cgemv, + std::complex) + Impl_DoBlasGemv(wrap::rocblas_zgemv, + std::complex) + + bool ROCMBlas::DoBlasSbmv( + Stream *stream, blas::UpperLower uplo, + uint64_t n, uint64_t k, float alpha, + const DeviceMemory &a, int lda, + const DeviceMemory &x, int incx, + float beta, DeviceMemory *y, + int incy) { return DoBlasInternal( wrap::rocblas_ssbmv, stream, /* pointer_mode_host = */ true, ROCMBlasUpperLower(uplo), n, k, &alpha, GpuMemory(a), lda, GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy); } +/** + * + * ALPHA/BETA TYPES + * + * For half and bf16, alpha and beta point to floats. + * For all other types, alpha and beta point to values of the same type as + *a/b/c. + * + * On the rocblas side, non-ex functions expect the same type as a/b/c + * (this seems to be a deviation from the blas standard); + * and ex functions expect the same type as the compute type (i.e. floats.) + * + **/ + absl::Status ROCMBlas::DoBlasGemm( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, - uint64 n, uint64_t k, blas::DataType dtype, const void *alpha, + uint64_t n, uint64_t k, blas::DataType dtype, const void *alpha, const DeviceMemoryBase &a, int lda, const DeviceMemoryBase &b, int ldb, const void *beta, DeviceMemoryBase *c, int ldc, const NumericOptions &numeric_options, blas::CallContext context) { - blas_log("DoBlasGemm"); VLOG(1) << absl::StreamFormat( "doing rocBLAS GEMM: at=%d bt=%d m=%u n=%u " "k=%llu alpha=%p a=%p lda=%d b=%p ldb=%d beta=%p " "c=%p ldc=%d", static_cast(transa), static_cast(transb), m, n, k, alpha, a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc); - if (dtype == blas::DataType::kHalf || dtype == blas::DataType::kFloat) { - if (transa == blas::Transpose::kNoTranspose) { - if (lda < static_cast(m)) { - LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); " - "precondition violation"; - } - } else { - if (lda < static_cast(k)) { - LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k - << ") (transpose case); precondition violation"; - } - } - if (transb == blas::Transpose::kNoTranspose) { - if (ldb < static_cast(k)) { - LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k - << ") (no transpose case); precondition violation"; - } - } else { - if (ldb < static_cast(n)) { - LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); " - "precondition violation"; - } - } + + CheckPreconditions(transa, transb, m, n, k, dtype, lda, ldb); + + absl::Status status; + uint32_t gemm_ex_flags = rocblas_gemm_flags_none; + bool is_backprop = (context == blas::CallContext::kBackpropInput1) || + (context == blas::CallContext::kBackpropInput2); + if (is_backprop && use_hgemm_alt_impl_) + gemm_ex_flags = rocblas_gemm_flags_fp16_alt_impl; + + Eigen::half alpha_half, beta_half; + + const void *alpha_downcast = alpha, *beta_downcast = beta; + if (dtype == blas::DataType::kHalf) { + alpha_half = Eigen::half(*static_cast(alpha)); + beta_half = Eigen::half(*static_cast(beta)); + alpha_downcast = &alpha_half; + beta_downcast = &beta_half; } + /* I would like to specify the type with a template parameter: + * + * auto call_gemm = [&](auto func) { ... } + * ... + * status = call_gemm(wrap::rocblas_sgemm); + * + * but that's a C++20 extension and can't be enabled (the compiler does + * support it, but enabling it causes compilation errors inside Eigen.) */ + auto call_gemm = [&](auto func, auto type) { + return DoBlasInternalStatus( + func, stream, /* pointer_mode_host = */ true, ROCMBlasTranspose(transa), + ROCMBlasTranspose(transb), m, n, k, + reinterpret_cast(alpha_downcast), + reinterpret_cast(a.opaque()), lda, + reinterpret_cast(b.opaque()), ldb, + reinterpret_cast(beta_downcast), + reinterpret_cast(c->opaque()), ldc); + }; + + auto call_gemm_ex = [&](rocblas_datatype dt) { + return DoBlasInternalStatus( + wrap::rocblas_gemm_ex, stream, /* pointer_mode_host = */ true, + ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), (rocblas_int)m, + (rocblas_int)n, (rocblas_int)k, alpha, a.opaque(), dt, lda, b.opaque(), + dt, ldb, beta, c->opaque(), dt, ldc, c->opaque(), dt, ldc, + rocblas_datatype_f32_r, rocblas_gemm_algo_standard, 0, gemm_ex_flags); + }; + switch (dtype) { - case blas::DataType::kHalf: { - absl::StatusOr maybe_hasXDLOPS = GpuDriver::GetMFMASupport(); - if (maybe_hasXDLOPS.ok() && maybe_hasXDLOPS.value()) { - VLOG(1) << "Using rocblas_gemm_ex"; - bool is_backprop = (context == blas::CallContext::kBackpropInput1) || - (context == blas::CallContext::kBackpropInput2); - - uint32_t flags = rocblas_gemm_flags_none; -#if TF_ROCM_VERSION >= 50000 - if (is_backprop) { - flags = rocblas_gemm_flags_fp16_alt_impl; - } -#endif - return DoBlasInternalStatus( - wrap::rocblas_gemm_ex, stream, /* pointer_mode_host = */ true, - ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), - (rocblas_int)m, (rocblas_int)n, (rocblas_int)k, alpha, a.opaque(), - rocblas_datatype_f16_r, lda, b.opaque(), rocblas_datatype_f16_r, - ldb, beta, c->opaque(), rocblas_datatype_f16_r, ldc, c->opaque(), - rocblas_datatype_f16_r, ldc, rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, flags); - } else { - VLOG(1) << "Using rocblas_hgemm"; - const Eigen::half alpha_half(*static_cast(alpha)); - const Eigen::half beta_half(*static_cast(beta)); - return DoBlasInternalStatus( - wrap::rocblas_hgemm, stream, /* pointer_mode_host = */ true, - ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, - reinterpret_cast(&alpha_half), - reinterpret_cast(a.opaque()), lda, - reinterpret_cast(b.opaque()), ldb, - reinterpret_cast(&beta_half), - reinterpret_cast(c->opaque()), ldc); - } - } + case blas::DataType::kHalf: + if (has_mfma_) + return call_gemm_ex(rocblas_datatype_f16_r); + else + return call_gemm(wrap::rocblas_hgemm, rocblas_half()); case blas::DataType::kBF16: - return DoBlasInternalStatus( - wrap::rocblas_gemm_ex, stream, /* pointer_mode_host = */ true, - ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), (rocblas_int)m, - (rocblas_int)n, (rocblas_int)k, alpha, a.opaque(), - rocblas_datatype_bf16_r, lda, b.opaque(), rocblas_datatype_bf16_r, - ldb, beta, c->opaque(), rocblas_datatype_bf16_r, ldc, c->opaque(), - rocblas_datatype_bf16_r, ldc, rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, 0); + return call_gemm_ex(rocblas_datatype_bf16_r); case blas::DataType::kFloat: - return DoBlasInternalStatus( - wrap::rocblas_sgemm, stream, /* pointer_mode_host = */ true, - ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, - static_cast(alpha), - static_cast(a.opaque()), lda, - static_cast(b.opaque()), ldb, - static_cast(beta), static_cast(c->opaque()), - ldc); + return call_gemm(wrap::rocblas_sgemm, 1.0f); case blas::DataType::kDouble: - return DoBlasInternalStatus( - wrap::rocblas_dgemm, stream, /* pointer_mode_host = */ true, - ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, - static_cast(alpha), - static_cast(a.opaque()), lda, - static_cast(b.opaque()), ldb, - static_cast(beta), static_cast(c->opaque()), - ldc); - case blas::DataType::kComplexFloat: { - auto cb_alpha = - complex_cast(*static_cast *>(alpha)); - auto cb_beta = - complex_cast(*static_cast *>(beta)); - return DoBlasInternalStatus( - wrap::rocblas_cgemm, stream, /* pointer_mode_host = */ true, - ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, - cb_alpha, static_cast(a.opaque()), lda, - static_cast(b.opaque()), ldb, cb_beta, - static_cast(c->opaque()), ldc); - } - case blas::DataType::kComplexDouble: { - auto cb_alpha = - complex_cast(*static_cast *>(alpha)); - auto cb_beta = - complex_cast(*static_cast *>(beta)); - return DoBlasInternalStatus( - wrap::rocblas_zgemm, stream, /* pointer_mode_host = */ true, - ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, - cb_alpha, static_cast(a.opaque()), - lda, static_cast(b.opaque()), ldb, - cb_beta, static_cast(c->opaque()), ldc); - } + return call_gemm(wrap::rocblas_dgemm, 1.0); + case blas::DataType::kComplexFloat: + return call_gemm(wrap::rocblas_cgemm, rocblas_float_complex()); + case blas::DataType::kComplexDouble: + return call_gemm(wrap::rocblas_zgemm, rocblas_double_complex()); default: return absl::InternalError(absl::StrCat("Unsupported datatype for GEMM: ", blas::DataTypeString(dtype))); @@ -491,38 +517,206 @@ absl::Status ROCMBlas::DoBlasGemm( absl::Status ROCMBlas::DoBlasGemmWithAlgorithm( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, - uint64_t n, uint64 k, const void *alpha, const DeviceMemoryBase &a, + uint64_t n, uint64_t k, const void *alpha, const DeviceMemoryBase &a, blas::DataType type_a, int lda, const DeviceMemoryBase &b, blas::DataType type_b, int ldb, const void *beta, DeviceMemoryBase *c, blas::DataType type_c, int ldc, blas::ComputationType computation_type, blas::AlgorithmType algorithm, const NumericOptions &numeric_options, - blas::ProfileResult *output_profile_result, blas::CallContext context) { - // ROCM TODO: properly implement the interface - return absl::InternalError( - "DoBlasGemmWithAlgorithm " - "is not implemented on ROCm yet"); + blas::ProfileResult *profile_result, blas::CallContext context) { + if (type_a != type_b) { + return absl::InternalError(absl::StrFormat( + "DoBlasGemmWithAlgorithm: different " + "datatypes for the inputs a (%d) and b (%d) are unsupported", + static_cast(type_a), static_cast(type_b))); + } + TF_ASSIGN_OR_RETURN( + auto timer, + GpuTimer::CreateIfNeeded(AsGpuStream(stream), profile_result != nullptr)); + + // fall back to the default implementation + if (algorithm == blas::kDefaultAlgorithm && type_a == type_c) { + TF_RETURN_IF_ERROR(DoBlasGemm(stream, transa, transb, m, n, k, type_a, + alpha, a, lda, b, ldb, beta, c, ldc, + numeric_options, context)); + + } else { + CheckPreconditions(transa, transb, m, n, k, type_a, lda, ldb); + TF_ASSIGN_OR_RETURN(auto roc_type_a, AsRocBlasType(type_a)); + TF_ASSIGN_OR_RETURN(auto roc_type_c, AsRocBlasType(type_c)); + TF_ASSIGN_OR_RETURN(auto roc_comp_type, + AsRocBlasComputeType(computation_type)); + + VLOG(1) << absl::StreamFormat( + "doing rocBLAS GEMM with Algorithm: at=%d bt=%d m=%u n=%u " + "k=%llu alpha=%p a=%p lda=%d b=%p ldb=%d beta=%p " + "c=%p ldc=%d algorithm=%d type_a/b=%d type_c=%d comp_type=%d", + static_cast(transa), static_cast(transb), m, n, k, alpha, + a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc, algorithm, + static_cast(roc_type_a), static_cast(roc_type_c), + static_cast(roc_comp_type)); + + TF_RETURN_IF_ERROR(DoBlasInternalImpl( + wrap::rocblas_gemm_ex, stream, + /* pointer_mode_host = */ true, + /* error_on_failure = */ false, ROCMBlasTranspose(transa), + ROCMBlasTranspose(transb), (rocblas_int)m, (rocblas_int)n, + (rocblas_int)k, alpha, a.opaque(), roc_type_a, lda, b.opaque(), + roc_type_a, ldb, beta, c->opaque(), roc_type_c, ldc, c->opaque(), + roc_type_c, ldc, roc_comp_type, rocblas_gemm_algo_solution_index, + algorithm, GemmFloat16Flags(type_a, context, use_hgemm_alt_impl_))); + } + TF_RETURN_IF_ERROR( + PopulateProfileFromTimer(timer, algorithm, profile_result)); + + return absl::OkStatus(); } absl::Status ROCMBlas::DoBlasGemmStridedBatchedWithAlgorithm( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, - uint64_t n, uint64 k, const void *alpha, const DeviceMemoryBase &a, + uint64_t n, uint64_t k, const void *alpha, const DeviceMemoryBase &a, blas::DataType type_a, int lda, int64_t stride_a, const DeviceMemoryBase &b, blas::DataType type_b, int ldb, int64_t stride_b, const void *beta, DeviceMemoryBase *c, blas::DataType type_c, int ldc, int64_t stride_c, int batch_count, blas::ComputationType computation_type, blas::AlgorithmType algorithm, const NumericOptions &numeric_options, - blas::ProfileResult *output_profile_result, blas::CallContext context) { - // ROCM TODO: properly implement the interface - return absl::InternalError( - "DoBlasGemmStridedBatchedWithAlgorithm " - "is not implemented on ROCm yet"); + blas::ProfileResult *profile_result, blas::CallContext context) { + if (type_a != type_b) { + return absl::InternalError(absl::StrFormat( + "DoBlasGemmStridedBatchedWithAlgorithm: different " + "datatypes for the inputs a (%d) and b (%d) are unsupported", + static_cast(type_a), static_cast(type_b))); + } + TF_ASSIGN_OR_RETURN( + auto timer, + GpuTimer::CreateIfNeeded(AsGpuStream(stream), profile_result != nullptr)); + + // fall back to the default implementation + if (algorithm == blas::kDefaultAlgorithm && type_a == type_c) { + TF_RETURN_IF_ERROR(DoBlasGemmStridedBatched( + stream, transa, transb, m, n, k, type_a, alpha, a, lda, stride_a, b, + ldb, stride_b, beta, c, ldc, stride_c, batch_count, numeric_options, + context)); + } else { + VLOG(1) << absl::StreamFormat( + "doing rocBLAS GEMM strided batched with Algorithm: at=%d bt=%d m=%u " + "n=%u " + "k=%llu alpha=%p a=%p lda=%d b=%p ldb=%d beta=%p " + "c=%p ldc=%d algorithm=%d type_a/b=%d type_c=%d stride_a/b/c=%d/%d/%d " + "batch_count=%d", + static_cast(transa), static_cast(transb), m, n, k, alpha, + a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc, algorithm, + static_cast(type_a), static_cast(type_c), stride_a, stride_b, + stride_c, batch_count); + + TF_ASSIGN_OR_RETURN(auto roc_type_a, AsRocBlasType(type_a)); + TF_ASSIGN_OR_RETURN(auto roc_type_c, AsRocBlasType(type_c)); + TF_ASSIGN_OR_RETURN(auto roc_comp_type, + AsRocBlasComputeType(computation_type)); + + TF_RETURN_IF_ERROR(DoBlasInternalImpl( + wrap::rocblas_gemm_strided_batched_ex, stream, + /* pointer_mode_host = */ true, + /* error_on_failure = */ false, ROCMBlasTranspose(transa), + ROCMBlasTranspose(transb), (rocblas_int)m, (rocblas_int)n, + (rocblas_int)k, alpha, a.opaque(), roc_type_a, lda, stride_a, + b.opaque(), roc_type_a, ldb, stride_b, beta, c->opaque(), roc_type_c, + ldc, stride_c, c->opaque(), roc_type_c, ldc, stride_c, batch_count, + roc_comp_type, rocblas_gemm_algo_solution_index, algorithm, + GemmFloat16Flags(type_a, context, use_hgemm_alt_impl_))); + } + TF_RETURN_IF_ERROR( + PopulateProfileFromTimer(timer, algorithm, profile_result)); + + return absl::OkStatus(); } +template +struct NameWrap : Lambda { + using Lambda::operator(); + constexpr static const char *kName = "rocblas_gemm_ex_get_solutions"; +}; +template +NameWrap(Func) -> NameWrap; + +#define ASSIGN_OR_FALSE(lhs, rexpr) \ + result = (rexpr); \ + if (TF_PREDICT_FALSE(!result.ok())) return false; \ + lhs = std::move(result).value() + bool ROCMBlas::GetBlasGemmAlgorithms( - Stream *stream, std::vector *out_algorithms) { - // ROCM TODO: properly implement the interface - return true; + Stream *stream, const gpu::MatrixDescriptor &a, + const gpu::MatrixDescriptor &b, gpu::OutputMatrixDescriptor *c, + const void *alpha, const void *beta, + std::vector *out_algorithms) { + out_algorithms->clear(); + auto blas_lambda = [this, out_algorithms](auto handle, auto &&blas_func, + auto &&...rest) { + rocblas_int num_sols = 0; + + if (auto ret = blas_func(handle, std::forward(rest)..., + nullptr, &num_sols); + ret != rocblas_status_success) { + return ret; + } + solutions_.resize(num_sols); + if (auto ret = blas_func(handle, std::forward(rest)..., + solutions_.data(), &num_sols); + ret != rocblas_status_success) { + return ret; + } + out_algorithms->resize(num_sols); + for (rocblas_int i = 0; i < num_sols; i++) { + (*out_algorithms)[i] = solutions_[i]; + } + return rocblas_status_success; + }; + + VLOG(1) << absl::StreamFormat( + "GetBlasAlgorithms: at=%d bt=%d m=%u n=%u " + "k=%llu alpha=%p a=%p lda=%d b=%p ldb=%d beta=%p " + "c=%p ldc=%d type_a/b=%d type_c=%d stride_a/b/c=%d/%d/%d " + "batch_count=%d", + static_cast(a.transpose), static_cast(b.transpose), c->m, c->n, + c->k, alpha, a.data.opaque(), a.leading_dim_stride, b.data.opaque(), + b.leading_dim_stride, beta, c->data.opaque(), c->leading_dim_stride, + static_cast(a.type), static_cast(c->type), a.batch_stride, + b.batch_stride, c->batch_stride, c->batch_size); + + if (a.type != b.type) { + LOG(ERROR) << "Gemm arguments types differ: no feasible solutions!"; + return false; + } + absl::StatusOr result; + ASSIGN_OR_FALSE(auto roc_type_a, AsRocBlasType(a.type)); + ASSIGN_OR_FALSE(auto roc_type_c, AsRocBlasType(c->type)); + ASSIGN_OR_FALSE(auto roc_comp_type, AsRocBlasComputeType(c->compute_type)); + + if (c->batch_size == 1) { + // TODO: we should possibly use GemmFloat16Flags(type_a, context) here.. + return DoBlasInternalFailureOK( + NameWrap{blas_lambda}, stream, true, + wrap::rocblas_gemm_ex_get_solutions, ROCMBlasTranspose(a.transpose), + ROCMBlasTranspose(b.transpose), c->m, c->n, c->k, alpha, + a.data.opaque(), roc_type_a, a.leading_dim_stride, b.data.opaque(), + roc_type_a, b.leading_dim_stride, beta, c->data.opaque(), roc_type_c, + c->leading_dim_stride, c->data.opaque(), roc_type_c, + c->leading_dim_stride, roc_comp_type, rocblas_gemm_algo_solution_index, + 0); + } + return DoBlasInternalFailureOK( + NameWrap{blas_lambda}, stream, true, + wrap::rocblas_gemm_strided_batched_ex_get_solutions, + ROCMBlasTranspose(a.transpose), ROCMBlasTranspose(b.transpose), c->m, + c->n, c->k, alpha, a.data.opaque(), roc_type_a, a.leading_dim_stride, + a.batch_stride, b.data.opaque(), roc_type_a, b.leading_dim_stride, + b.batch_stride, beta, c->data.opaque(), roc_type_c, c->leading_dim_stride, + c->batch_stride, c->data.opaque(), roc_type_c, c->leading_dim_stride, + c->batch_stride, c->batch_size, roc_comp_type, + rocblas_gemm_algo_solution_index, 0); } +#undef ASSIGN_OR_FALSE + +namespace { struct MemoryCopyOp { char *src_ptr; @@ -535,7 +729,7 @@ struct MemoryCopyOp { // Check whether two Memory Copy Ops can be fold together. // If it's true, fold it. Otherwise, return false. -static bool MemCopyOpsFold(MemoryCopyOp &y, const MemoryCopyOp &x) { +bool MemCopyOpsFold(MemoryCopyOp &y, const MemoryCopyOp &x) { bool misaligned = (x.size & 3) || (reinterpret_cast(x.dst_ptr) & 3) || (reinterpret_cast(x.src_ptr) & 3) || @@ -624,11 +818,7 @@ absl::Status ReorganizeMemory(Stream *stream, } else { DeviceMemoryBase src_mem = DeviceMemoryBase(x.src_ptr, x.size); DeviceMemoryBase target_mem = DeviceMemoryBase(x.dst_ptr, x.size); - bool a_status = stream->ThenMemcpy(&target_mem, src_mem, x.size).ok(); - if (!a_status) { - return absl::InternalError( - "failed to copy device memory in ROCMBlas::DoBlasGemmBatched"); - } + TF_RETURN_IF_ERROR(stream->Memcpy(&target_mem, src_mem, x.size)); } i++; } @@ -636,19 +826,21 @@ absl::Status ReorganizeMemory(Stream *stream, } template -absl::Status ROCMBlas::AllocateStridedBuffer( - const std::vector::mapped_type *> - &raw_ptrs, - int batch_count, uint64_t batch_stride, ScratchAllocator *scratch_allocator, - Stream *stream, - std::unique_ptr::mapped_type>> *temp_memory, - DeviceMemory::mapped_type> - *device_memory, - bool copy_data, bool &reallocated) { - assert(device_memory != nullptr); - - using MAPPED_T = typename RocBlasTypeConversionHelper::mapped_type; +struct AllocateStridedResult { + using Type = RocBlasType_t; + DeviceMemory device_mem; + bool reallocated; +}; + +// A helper allocation function to convert raw pointers memory layout to +// strided flavor +template +absl::StatusOr> AllocateStridedBuffer( + const std::vector *> &raw_ptrs, int batch_count, + uint64_t batch_stride, ScratchAllocator *scratch_allocator, Stream *stream, + bool copy_data) { + using MAPPED_T = RocBlasType_t; + AllocateStridedResult res; bool needs_allocate_strided = false; for (int i = 1; i < batch_count; ++i) { @@ -664,42 +856,37 @@ absl::Status ROCMBlas::AllocateStridedBuffer( // No need to do re-allocation, take the short cut and return if (!needs_allocate_strided) { - *device_memory = DeviceMemory( + res.device_mem = DeviceMemory( DeviceMemoryBase(raw_ptrs[0], matrix_batch_byte_size)); - reallocated = false; - return absl::OkStatus(); + res.reallocated = false; + return res; } - if (scratch_allocator != nullptr) { - TF_ASSIGN_OR_RETURN( - DeviceMemory batch_matrix_bytes, - scratch_allocator->AllocateBytes(matrix_batch_byte_size)); - *device_memory = DeviceMemory(batch_matrix_bytes); - } else { - assert(temp_memory != nullptr); - TF_ASSIGN_OR_RETURN(*temp_memory, stream->AllocateTemporaryArray( - matrix_batch_byte_size)); - *device_memory = - DeviceMemory(*(*temp_memory)->mutable_device_memory()); + if (scratch_allocator == nullptr) { + return absl::InternalError("scratch_allocator is null"); } - - reallocated = true; - - if (copy_data) - return ReorganizeMemory(stream, device_memory, raw_ptrs, batch_count, + TF_ASSIGN_OR_RETURN(DeviceMemory batch_matrix_bytes, + scratch_allocator->AllocateBytes(matrix_batch_byte_size)); + res.device_mem = DeviceMemory(batch_matrix_bytes); + res.reallocated = true; + if (copy_data) { + return ReorganizeMemory(stream, &res.device_mem, raw_ptrs, batch_count, batch_stride, true); - return absl::OkStatus(); + } + return res; } +} // namespace + template absl::Status ROCMBlas::DoBlasGemmBatchedInternal( FuncT rocblas_func, Stream *stream, blas::Transpose transa, - blas::Transpose transb, uint64_t m, uint64 n, uint64 k, T alpha, + blas::Transpose transb, uint64_t m, uint64_t n, uint64_t k, T alpha, DeviceMemorySlice a_ptrs_to_wrappers, int lda, DeviceMemorySlice b_ptrs_to_wrappers, int ldb, T beta, DeviceMemorySlice c_ptrs_to_wrappers, int ldc, int batch_count, ScratchAllocator *scratch_allocator) { - using MAPPED_T = typename RocBlasTypeConversionHelper::mapped_type; + using MAPPED_T = RocBlasType_t; // Sanity checks before making any further progress uint64_t batch_stride_a = 0; @@ -726,94 +913,137 @@ absl::Status ROCMBlas::DoBlasGemmBatchedInternal( } // Allocate local vectors to hold device pointers to matrices - std::vector a_raw_ptrs, b_raw_ptrs, c_raw_ptrs; + std::vector a_raw_ptrs(batch_count), b_raw_ptrs(batch_count), + c_raw_ptrs(batch_count); for (int i = 0; i < batch_count; ++i) { // static_cast does work when converting Eigen::half* to rocblas_half*, // hence the use of reinterpret_cast - a_raw_ptrs.push_back( - reinterpret_cast(a_ptrs_to_wrappers[i]->opaque())); - b_raw_ptrs.push_back( - reinterpret_cast(b_ptrs_to_wrappers[i]->opaque())); - c_raw_ptrs.push_back( - reinterpret_cast(c_ptrs_to_wrappers[i]->opaque())); + a_raw_ptrs[i] = + reinterpret_cast(a_ptrs_to_wrappers[i]->opaque()); + b_raw_ptrs[i] = + reinterpret_cast(b_ptrs_to_wrappers[i]->opaque()); + c_raw_ptrs[i] = + reinterpret_cast(c_ptrs_to_wrappers[i]->opaque()); } - DeviceMemory a; // Make sure the temporary memory are in-scope before the function returns - std::unique_ptr> a_temp; - bool reallocated_a, reallocated_b, reallocated_c; - absl::Status a_allocation_status = AllocateStridedBuffer( - a_raw_ptrs, batch_count, batch_stride_a, scratch_allocator, stream, - &a_temp, &a, true, reallocated_a); - if (a_allocation_status != absl::OkStatus()) { - return a_allocation_status; + TF_ASSIGN_OR_RETURN( + auto a, AllocateStridedBuffer(a_raw_ptrs, batch_count, batch_stride_a, + scratch_allocator, stream, true)); + + TF_ASSIGN_OR_RETURN( + auto b, AllocateStridedBuffer(b_raw_ptrs, batch_count, batch_stride_b, + scratch_allocator, stream, true)); + + TF_ASSIGN_OR_RETURN( + auto c, AllocateStridedBuffer(c_raw_ptrs, batch_count, batch_stride_c, + scratch_allocator, stream, + true)); // can disable copy if beta=0 + + MAPPED_T *alpha_ptr = reinterpret_cast(&alpha); + MAPPED_T *beta_ptr = reinterpret_cast(&beta); + bool ok = DoBlasInternal( + rocblas_func, stream, /* pointer_mode_host = */ true, + ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, + GpuComplex(alpha_ptr), GpuMemory(a.device_mem), lda, batch_stride_a, + GpuMemory(b.device_mem), ldb, batch_stride_b, GpuComplex(beta_ptr), + GpuMemoryMutable(&c.device_mem), ldc, batch_stride_c, batch_count); + + if (!ok) { + return absl::Status(absl::StatusCode::kInternal, + "failed BLAS call, see log for details"); } - - DeviceMemory b; - std::unique_ptr> b_temp; - absl::Status b_allocation_status = AllocateStridedBuffer( - b_raw_ptrs, batch_count, batch_stride_b, scratch_allocator, stream, - &b_temp, &b, true, reallocated_b); - if (b_allocation_status != absl::OkStatus()) { - return b_allocation_status; + if (c.reallocated) { + return ReorganizeMemory(stream, &c.device_mem, c_raw_ptrs, batch_count, + batch_stride_c, false); } + return absl::OkStatus(); +} - DeviceMemory c; - std::unique_ptr> c_temp; - absl::Status c_allocation_status = AllocateStridedBuffer( - c_raw_ptrs, batch_count, batch_stride_c, scratch_allocator, stream, - &c_temp, &c, true, reallocated_c); // can disable copy if beta=0 - if (c_allocation_status != absl::OkStatus()) { - return c_allocation_status; +class rocblas_hgemm_strided_batched_mfma { + int ALT_; + + public: + rocblas_hgemm_strided_batched_mfma(int ALT) : ALT_(ALT) {} + static const char *kName; + rocblas_status operator()(rocblas_handle handle, rocblas_operation transA, + rocblas_operation transB, rocblas_int m, + rocblas_int n, rocblas_int k, + const rocblas_half *alpha, const rocblas_half *A, + rocblas_int lda, rocblas_stride stride_a, + const rocblas_half *B, rocblas_int ldb, + rocblas_stride stride_b, const rocblas_half *beta, + rocblas_half *C, rocblas_int ldc, + rocblas_stride stride_c, rocblas_int batch_count) { + float alpha32 = static_cast(*(const __half *)alpha); + float beta32 = static_cast(*(const __half *)beta); + uint32_t flags = rocblas_gemm_flags_none; + if (ALT_) flags = rocblas_gemm_flags_fp16_alt_impl; + return wrap::rocblas_gemm_strided_batched_ex( + handle, transA, transB, m, n, k, &alpha32, A, rocblas_datatype_f16_r, + lda, stride_a, B, rocblas_datatype_f16_r, ldb, stride_b, &beta32, C, + rocblas_datatype_f16_r, ldc, stride_c, C, rocblas_datatype_f16_r, ldc, + stride_c, batch_count, rocblas_datatype_f32_r, + rocblas_gemm_algo_standard, 0, flags); } +}; - bool ok; - if constexpr (std::is_same_v) { - float alpha_ = static_cast(alpha); - float beta_ = static_cast(beta); - const void *alpha_ptr = reinterpret_cast(&alpha_); - const void *beta_ptr = reinterpret_cast(&beta_); - - ok = DoBlasInternal( - rocblas_func, stream, /* pointer_mode_host = */ true, - ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, - alpha_ptr, a.opaque(), rocblas_datatype_bf16_r, lda, batch_stride_a, - b.opaque(), rocblas_datatype_bf16_r, ldb, batch_stride_b, beta_ptr, - c.opaque(), rocblas_datatype_bf16_r, ldc, batch_stride_c, c.opaque(), - rocblas_datatype_bf16_r, ldc, batch_stride_c, batch_count, - rocblas_datatype_f32_r, rocblas_gemm_algo_standard, 0, 0); - } else { - MAPPED_T *alpha_ptr = reinterpret_cast(&alpha); - MAPPED_T *beta_ptr = reinterpret_cast(&beta); - ok = DoBlasInternal(rocblas_func, stream, /* pointer_mode_host = */ true, - ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, - n, k, GpuComplex(alpha_ptr), GpuMemory(a), lda, - batch_stride_a, GpuMemory(b), ldb, batch_stride_b, - GpuComplex(beta_ptr), GpuMemoryMutable(&c), ldc, - batch_stride_c, batch_count); +const char *rocblas_hgemm_strided_batched_mfma::kName = + "rocblas_hgemm_strided_batched_mfma"; + +class rocblas_gemm_strided_batched_bf16 { + public: + static const char *kName; + rocblas_status operator()(rocblas_handle handle, rocblas_operation transA, + rocblas_operation transB, rocblas_int m, + rocblas_int n, rocblas_int k, + const rocblas_bfloat16 *alpha, + const rocblas_bfloat16 *A, rocblas_int lda, + rocblas_stride stride_a, const rocblas_bfloat16 *B, + rocblas_int ldb, rocblas_stride stride_b, + const rocblas_bfloat16 *beta, rocblas_bfloat16 *C, + rocblas_int ldc, rocblas_stride stride_c, + rocblas_int batch_count) { + float alpha32 = static_cast(*(const Eigen::bfloat16 *)alpha); + float beta32 = static_cast(*(const Eigen::bfloat16 *)beta); + uint32_t flags = rocblas_gemm_flags_none; + return wrap::rocblas_gemm_strided_batched_ex( + handle, transA, transB, m, n, k, &alpha32, A, rocblas_datatype_bf16_r, + lda, stride_a, B, rocblas_datatype_bf16_r, ldb, stride_b, &beta32, C, + rocblas_datatype_bf16_r, ldc, stride_c, C, rocblas_datatype_bf16_r, ldc, + stride_c, batch_count, rocblas_datatype_f32_r, + rocblas_gemm_algo_standard, 0, flags); } - if (!ok) return absl::InternalError("failed BLAS call, see log for details"); - if (reallocated_c) - return ReorganizeMemory(stream, &c, c_raw_ptrs, batch_count, batch_stride_c, - false); - return absl::OkStatus(); -} +}; +const char *rocblas_gemm_strided_batched_bf16::kName = + "rocblas_gemm_strided_batched_bf16"; bool ROCMBlas::DoBlasGemmBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, - uint64_t n, uint64 k, float alpha, DeviceMemorySlice a, + uint64_t n, uint64_t k, float alpha, DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, int ldc, int batch_count, const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator, blas::CallContext context) { - blas_log("DoBlasGemmBatched"); const Eigen::half alpha_half(alpha); const Eigen::half beta_half(beta); + absl::Status status; + + auto call_gemm = [&](auto x) { + return DoBlasGemmBatchedInternal(x, stream, transa, transb, m, n, k, + alpha_half, a, lda, b, ldb, beta_half, c, + ldc, batch_count, scratch_allocator); + }; + + if (has_mfma_) { + bool is_backprop = (context == blas::CallContext::kBackpropInput1) || + (context == blas::CallContext::kBackpropInput2); + status = call_gemm( + rocblas_hgemm_strided_batched_mfma(is_backprop && use_hgemm_alt_impl_)); + } else { + status = call_gemm(wrap::rocblas_hgemm_strided_batched); + } - absl::Status status = DoBlasGemmBatchedInternal( - wrap::rocblas_hgemm_strided_batched, stream, transa, transb, m, n, k, - alpha_half, a, lda, b, ldb, beta_half, c, ldc, batch_count, - scratch_allocator); if (!status.ok()) { LOG(ERROR) << status; } @@ -823,18 +1053,17 @@ bool ROCMBlas::DoBlasGemmBatched( bool ROCMBlas::DoBlasGemmBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, - uint64_t n, uint64 k, float alpha, + uint64_t n, uint64_t k, float alpha, DeviceMemorySlice a_array, int lda, DeviceMemorySlice b_array, int ldb, float beta, DeviceMemorySlice c_array, int ldc, int batch_count, const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator, blas::CallContext context) { - blas_log("DoBlasGemmBatched"); const Eigen::bfloat16 alpha_bf16(alpha); const Eigen::bfloat16 beta_bf16(beta); absl::Status status = DoBlasGemmBatchedInternal( - wrap::rocblas_gemm_strided_batched_ex, stream, transa, transb, m, n, k, + rocblas_gemm_strided_batched_bf16(), stream, transa, transb, m, n, k, alpha_bf16, a_array, lda, b_array, ldb, beta_bf16, c_array, ldc, batch_count, scratch_allocator); if (!status.ok()) { @@ -843,280 +1072,128 @@ bool ROCMBlas::DoBlasGemmBatched( return status.ok(); } -bool ROCMBlas::DoBlasGemmBatched( - Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, - uint64_t n, uint64 k, float alpha, DeviceMemorySlice a_array, - int lda, DeviceMemorySlice b_array, int ldb, float beta, - DeviceMemorySlice c_array, int ldc, int batch_count, - const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator, - blas::CallContext context) { - blas_log("DoBlasGemmBatched"); - absl::Status status = DoBlasGemmBatchedInternal( - wrap::rocblas_sgemm_strided_batched, stream, transa, transb, m, n, k, - alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, - scratch_allocator); - if (!status.ok()) { - LOG(ERROR) << status; +#define IMPL_DoBlasGemmBatched(T, Fun) \ + bool ROCMBlas::DoBlasGemmBatched( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64_t m, uint64_t n, uint64 k, T alpha, DeviceMemorySlice a_array, \ + int lda, DeviceMemorySlice b_array, int ldb, T beta, \ + DeviceMemorySlice c_array, int ldc, int batch_count, \ + const NumericOptions &numeric_options, \ + ScratchAllocator *scratch_allocator, blas::CallContext context) { \ + absl::Status status = DoBlasGemmBatchedInternal( \ + Fun, stream, transa, transb, m, n, k, alpha, a_array, lda, b_array, \ + ldb, beta, c_array, ldc, batch_count, scratch_allocator); \ + if (!status.ok()) { \ + LOG(ERROR) << status; \ + } \ + return status.ok(); \ } - return status.ok(); -} -bool ROCMBlas::DoBlasGemmBatched( - Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, - uint64_t n, uint64 k, double alpha, DeviceMemorySlice a_array, - int lda, DeviceMemorySlice b_array, int ldb, double beta, - DeviceMemorySlice c_array, int ldc, int batch_count, - const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator, - blas::CallContext context) { - blas_log("DoBlasGemmBatched"); - absl::Status status = DoBlasGemmBatchedInternal( - wrap::rocblas_dgemm_strided_batched, stream, transa, transb, m, n, k, - alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, - scratch_allocator); - if (!status.ok()) { - LOG(ERROR) << status; +IMPL_DoBlasGemmBatched(float, wrap::rocblas_sgemm_strided_batched) + IMPL_DoBlasGemmBatched(double, wrap::rocblas_dgemm_strided_batched) + IMPL_DoBlasGemmBatched(std::complex, + wrap::rocblas_cgemm_strided_batched) + IMPL_DoBlasGemmBatched(std::complex, + wrap::rocblas_zgemm_strided_batched) +#define IMPL_DoBlasTrsm(T, Fun, Fun2) \ + bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, \ + blas::UpperLower uplo, blas::Transpose transa, \ + blas::Diagonal diag, uint64_t m, uint64 n, \ + T alpha, const DeviceMemory &a, int lda, \ + DeviceMemory *b, int ldb) { \ + return DoBlasInternal(Fun, stream, /* pointer_mode_host = */ true, \ + ROCMBlasSide(side), ROCMBlasUpperLower(uplo), \ + ROCMBlasTranspose(transa), ROCMBlasDiagonal(diag), \ + m, n, complex_cast(alpha), complex_cast(a), lda, \ + complex_cast(b), ldb); \ + } \ + \ + bool ROCMBlas::DoBlasTrsmBatched( \ + Stream *stream, blas::Side side, blas::UpperLower uplo, \ + blas::Transpose transa, blas::Diagonal diag, uint64_t m, uint64 n, \ + T alpha, const DeviceMemory &as, int lda, DeviceMemory *bs, \ + int ldb, int batch_count) { \ + return DoBlasInternal(Fun2, stream, true /* = pointer_mode_host */, \ + ROCMBlasSide(side), ROCMBlasUpperLower(uplo), \ + ROCMBlasTranspose(transa), ROCMBlasDiagonal(diag), \ + m, n, complex_cast(alpha), complex_cast(as), lda, \ + complex_cast(*bs), ldb, batch_count); \ } - return status.ok(); -} -bool ROCMBlas::DoBlasGemmBatched( - Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, - uint64_t n, uint64 k, std::complex alpha, - DeviceMemorySlice> a_array, int lda, - DeviceMemorySlice> b_array, int ldb, - std::complex beta, DeviceMemorySlice> c_array, - int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator, blas::CallContext context) { - blas_log("DoBlasGemmBatched"); - absl::Status status = DoBlasGemmBatchedInternal( - wrap::rocblas_cgemm_strided_batched, stream, transa, transb, m, n, k, - alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, - scratch_allocator); - if (!status.ok()) { - LOG(ERROR) << status; - } - return status.ok(); -} - -bool ROCMBlas::DoBlasGemmBatched( - Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, - uint64_t n, uint64 k, std::complex alpha, - DeviceMemorySlice> a_array, int lda, - DeviceMemorySlice> b_array, int ldb, - std::complex beta, DeviceMemorySlice> c_array, - int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator, blas::CallContext context) { - blas_log("DoBlasGemmBatched"); - absl::Status status = DoBlasGemmBatchedInternal( - wrap::rocblas_zgemm_strided_batched, stream, transa, transb, m, n, k, - alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, - scratch_allocator); - if (!status.ok()) { - LOG(ERROR) << status; - } - return status.ok(); -} - -bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, - blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, - float alpha, const DeviceMemory &a, int lda, - DeviceMemory *b, int ldb) { - blas_log("DoBlasTrsm"); - return DoBlasInternal(wrap::rocblas_strsm, stream, - /* pointer_mode_host = */ true, ROCMBlasSide(side), - ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), - ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(a), lda, - GpuMemoryMutable(b), ldb); -} - -bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, - blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, - double alpha, const DeviceMemory &a, int lda, - DeviceMemory *b, int ldb) { - blas_log("DoBlasTrsm"); - return DoBlasInternal(wrap::rocblas_dtrsm, stream, - /* pointer_mode_host = */ true, ROCMBlasSide(side), - ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), - ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(a), lda, - GpuMemoryMutable(b), ldb); -} - -bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, - blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, - std::complex alpha, - const DeviceMemory> &a, int lda, - DeviceMemory> *b, int ldb) { - return DoBlasInternal(wrap::rocblas_ctrsm, stream, - /* pointer_mode_host = */ true, ROCMBlasSide(side), - ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), - ROCMBlasDiagonal(diag), m, n, complex_cast(alpha), - complex_cast(a), lda, complex_cast(b), ldb); -} - -bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, - blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, - std::complex alpha, - const DeviceMemory> &a, int lda, - DeviceMemory> *b, int ldb) { - return DoBlasInternal(wrap::rocblas_ztrsm, stream, - /* pointer_mode_host = */ true, ROCMBlasSide(side), - ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), - ROCMBlasDiagonal(diag), m, n, complex_cast(alpha), - complex_cast(a), lda, complex_cast(b), ldb); -} - -bool ROCMBlas::DoBlasTrsmBatched(Stream *stream, blas::Side side, - blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, - float alpha, const DeviceMemory &as, - int lda, DeviceMemory *bs, int ldb, - int batch_count) { - return DoBlasInternal(wrap::rocblas_strsm_batched, stream, - true /* = pointer_mode_host */, ROCMBlasSide(side), - ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), - ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(as), - lda, GpuMemoryMutable(bs), ldb, batch_count); -} - -bool ROCMBlas::DoBlasTrsmBatched(Stream *stream, blas::Side side, - blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, - double alpha, const DeviceMemory &as, - int lda, DeviceMemory *bs, int ldb, - int batch_count) { - return DoBlasInternal(wrap::rocblas_dtrsm_batched, stream, - true /* = pointer_mode_host */, ROCMBlasSide(side), - ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), - ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(as), - lda, GpuMemoryMutable(bs), ldb, batch_count); -} - -bool ROCMBlas::DoBlasTrsmBatched(Stream *stream, blas::Side side, - blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, - std::complex alpha, - const DeviceMemory *> &as, - int lda, - DeviceMemory *> *bs, - int ldb, int batch_count) { - return DoBlasInternal( - wrap::rocblas_ctrsm_batched, stream, true /* = pointer_mode_host */, - ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), - ROCMBlasDiagonal(diag), m, n, complex_cast(alpha), - static_cast(as.opaque()), lda, - static_cast(bs->opaque()), ldb, - batch_count); -} - -bool ROCMBlas::DoBlasTrsmBatched(Stream *stream, blas::Side side, - blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, - std::complex alpha, - const DeviceMemory *> &as, - int lda, - DeviceMemory *> *bs, - int ldb, int batch_count) { - return DoBlasInternal( - wrap::rocblas_ztrsm_batched, stream, true /* = pointer_mode_host */, - ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), - ROCMBlasDiagonal(diag), m, n, complex_cast(alpha), - static_cast(as.opaque()), lda, - static_cast(bs->opaque()), ldb, - batch_count); -} - -absl::Status ROCMBlas::DoBlasGemmStridedBatched( - Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, - uint64_t n, uint64 k, blas::DataType dtype, const void *alpha, - const DeviceMemoryBase &a, int lda, int64_t stride_a, - const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta, - DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count, - const NumericOptions &numeric_options, blas::CallContext context) { + IMPL_DoBlasTrsm(float, wrap::rocblas_strsm, + wrap::rocblas_strsm_batched) + IMPL_DoBlasTrsm(double, wrap::rocblas_dtrsm, + wrap::rocblas_dtrsm_batched) + IMPL_DoBlasTrsm(std::complex, + wrap::rocblas_ctrsm, + wrap::rocblas_ctrsm_batched) + IMPL_DoBlasTrsm(std::complex, + wrap::rocblas_ztrsm, + wrap::rocblas_ztrsm_batched) + + absl::Status + ROCMBlas::DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, + uint64_t m, uint64_t n, uint64_t k, blas::DataType dtype, + const void *alpha, const DeviceMemoryBase &a, int lda, int64_t stride_a, + const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta, + DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count, + const NumericOptions &numeric_options, blas::CallContext context) { VLOG(1) << absl::StreamFormat( - "doing rocBLAS SGEMM Strided Batched: at=%d bt=%d m=%u n=%u " + "doing rocBLAS GEMM Strided Batched: at=%d bt=%d m=%u n=%u " "k=%llu alpha=%p a=%p lda=%d b=%p ldb=%d beta=%p " - "c=%p ldc=%d", + "c=%p ldc=%d stride_a/b/c=%d/%d/%d batch_count=%d", static_cast(transa), static_cast(transb), m, n, k, alpha, - a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc); + a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc, stride_a, + stride_b, stride_c, batch_count); + + absl::Status status; + auto call_gemm = [&](auto func, auto type) { + return DoBlasInternalStatus( + func, stream, false, /* pointer_mode_host */ + ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(a.opaque()), lda, stride_a, + reinterpret_cast(b.opaque()), ldb, stride_b, + reinterpret_cast(beta), + reinterpret_cast(c->opaque()), ldc, stride_c, + batch_count); + }; switch (dtype) { case blas::DataType::kHalf: { - const Eigen::half alpha_half(*static_cast(alpha)); - const Eigen::half beta_half(*static_cast(beta)); - return DoBlasInternalStatus( - wrap::rocblas_hgemm_strided_batched, stream, - false, /* pointer_mode_host */ - ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, - reinterpret_cast(&alpha_half), - reinterpret_cast(a.opaque()), lda, stride_a, - reinterpret_cast(b.opaque()), ldb, stride_b, - reinterpret_cast(&beta_half), - reinterpret_cast(c->opaque()), ldc, stride_c, - batch_count); + bool is_backprop = (context == blas::CallContext::kBackpropInput1) || + (context == blas::CallContext::kBackpropInput2); + Eigen::half alpha_half = Eigen::half(*static_cast(alpha)); + Eigen::half beta_half = Eigen::half(*static_cast(beta)); + alpha = &alpha_half; + beta = &beta_half; + if (has_mfma_) { + return call_gemm(rocblas_hgemm_strided_batched_mfma( + is_backprop && use_hgemm_alt_impl_), + rocblas_half()); + } else { + return call_gemm(wrap::rocblas_hgemm_strided_batched, rocblas_half()); + } + } + case blas::DataType::kBF16: { + Eigen::bfloat16 alpha_bf16, beta_bf16; + alpha_bf16 = Eigen::bfloat16(*static_cast(alpha)); + beta_bf16 = Eigen::bfloat16(*static_cast(beta)); + alpha = &alpha_bf16; + beta = &beta_bf16; + return call_gemm(rocblas_gemm_strided_batched_bf16(), rocblas_bfloat16()); } - case blas::DataType::kBF16: - return DoBlasInternalStatus( - wrap::rocblas_gemm_strided_batched_ex, stream, - false, /* pointer_mode_host */ - ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, alpha, - a.opaque(), rocblas_datatype_bf16_r, lda, stride_a, b.opaque(), - rocblas_datatype_bf16_r, ldb, stride_b, beta, c->opaque(), - rocblas_datatype_bf16_r, ldc, stride_c, c->opaque(), - rocblas_datatype_bf16_r, ldc, stride_c, batch_count, - rocblas_datatype_f32_r, rocblas_gemm_algo_standard, 0, 0); case blas::DataType::kFloat: - return DoBlasInternalStatus( - wrap::rocblas_sgemm_strided_batched, stream, - false, /* pointer_mode_host */ - ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(a.opaque()), lda, stride_a, - reinterpret_cast(b.opaque()), ldb, stride_b, - reinterpret_cast(beta), - reinterpret_cast(c->opaque()), ldc, stride_c, batch_count); + return call_gemm(wrap::rocblas_sgemm_strided_batched, 1.0f); case blas::DataType::kDouble: - return DoBlasInternalStatus( - wrap::rocblas_dgemm_strided_batched, stream, - false, /* pointer_mode_host */ - ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(a.opaque()), lda, stride_a, - reinterpret_cast(b.opaque()), ldb, stride_b, - reinterpret_cast(beta), - reinterpret_cast(c->opaque()), ldc, stride_c, batch_count); - case blas::DataType::kComplexFloat: { - auto cb_alpha = - complex_cast(*static_cast *>(alpha)); - auto cb_beta = - complex_cast(*static_cast *>(beta)); - return DoBlasInternalStatus( - wrap::rocblas_cgemm_strided_batched, stream, - /* pointer_mode_host = */ true, ROCMBlasTranspose(transa), - ROCMBlasTranspose(transb), m, n, k, cb_alpha, - static_cast(a.opaque()), lda, stride_a, - static_cast(b.opaque()), ldb, stride_b, - cb_beta, static_cast(c->opaque()), ldc, - stride_c, batch_count); - } - case blas::DataType::kComplexDouble: { - auto cb_alpha = - complex_cast(*static_cast *>(alpha)); - auto cb_beta = - complex_cast(*static_cast *>(beta)); - return DoBlasInternalStatus( - wrap::rocblas_zgemm_strided_batched, stream, - /* pointer_mode_host = */ true, ROCMBlasTranspose(transa), - ROCMBlasTranspose(transb), m, n, k, cb_alpha, - static_cast(a.opaque()), lda, - stride_a, static_cast(b.opaque()), - ldb, stride_b, cb_beta, - static_cast(c->opaque()), ldc, stride_c, - batch_count); - } + return call_gemm(wrap::rocblas_dgemm_strided_batched, 1.0); + case blas::DataType::kComplexFloat: + return call_gemm(wrap::rocblas_cgemm_strided_batched, + rocblas_float_complex()); + case blas::DataType::kComplexDouble: + return call_gemm(wrap::rocblas_zgemm_strided_batched, + rocblas_double_complex()); default: return absl::InternalError(absl::StrCat("Unsupported datatype for GEMM: ", blas::DataTypeString(dtype))); @@ -1124,7 +1201,25 @@ absl::Status ROCMBlas::DoBlasGemmStridedBatched( } absl::Status ROCMBlas::GetVersion(string *version) { +#if TF_ROCM_VERSION >= 60300 // Not yet available in ROCM-6.1 + absl::MutexLock lock{&mu_}; + size_t len = 0; + if (auto res = rocblas_get_version_string_size(&len); + res != rocblas_status_success) { + return absl::InternalError( + absl::StrCat("GetVersion failed with: ", ToString(res))); + } + std::vector buf(len + 1); + if (auto res = rocblas_get_version_string(buf.data(), len); + res != rocblas_status_success) { + return absl::InternalError( + absl::StrCat("GetVersion failed with: ", ToString(res))); + } + *version = string(buf.begin(), buf.end()); + return absl::OkStatus(); +#else return absl::UnimplementedError(""); +#endif } } // namespace gpu @@ -1167,5 +1262,6 @@ void initialize_rocblas() { } // namespace stream_executor -REGISTER_MODULE_INITIALIZER(register_rocblas, - { stream_executor::initialize_rocblas(); }); +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(register_rocblas, { + stream_executor::initialize_rocblas(); +}); diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_blas.h b/third_party/xla/xla/stream_executor/rocm/rocm_blas.h index 85679e3934afda..537a3a7a46f07a 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_blas.h +++ b/third_party/xla/xla/stream_executor/rocm/rocm_blas.h @@ -24,15 +24,17 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "rocm/rocm_config.h" + +#define ROCBLAS_BETA_FEATURES_API #if TF_ROCM_VERSION >= 50600 #include "rocm/include/rocblas/rocblas.h" #else #include "rocm/include/rocblas.h" #endif #include "xla/stream_executor/blas.h" +#include "xla/stream_executor/gpu/gpu_blas_lt.h" #include "xla/stream_executor/platform/port.h" #include "xla/stream_executor/plugin_registry.h" -#include "xla/stream_executor/temporary_device_memory.h" #if TF_HIPBLASLT #include "xla/stream_executor/rocm/hip_blas_lt.h" #endif @@ -43,32 +45,34 @@ class Stream; namespace gpu { -// Type conversion helper that helps to map non-rocblas types to rocblas types -// Right now, it only converts the Eigen::half type to rocblas_half type -template -struct RocBlasTypeConversionHelper { - using mapped_type = T; -}; - -template <> -struct RocBlasTypeConversionHelper { - using mapped_type = rocblas_half; +template +struct ChooseType { + using type = std::conditional_t< + std::is_same_v, B, + typename ChooseType::type>; }; -template <> -struct RocBlasTypeConversionHelper { - using mapped_type = rocblas_bfloat16; +template +struct ChooseType { + // default case: return the same type Target if there is no recursive match + using type = std::conditional_t, B, Target>; }; -template <> -struct RocBlasTypeConversionHelper> { - using mapped_type = rocblas_float_complex; +template +struct ChooseType { + // default case: return compile error if type is not found + static_assert(std::is_same_v, + "ChooseType: the target type is not found!"); + using type = B; }; -template <> -struct RocBlasTypeConversionHelper> { - using mapped_type = rocblas_double_complex; -}; +// Type conversion helper that helps to map non-rocblas types to rocblas types +template +using RocBlasType_t = + typename ChooseType, + rocblas_float_complex, std::complex, + rocblas_double_complex>::type; class GpuExecutor; @@ -124,49 +128,39 @@ class ROCMBlas : public blas::BlasSupport { // err_on_failure: Whether to print an error if the rocBLAS function // fails. args: Arguments of rocBLAS function. template - bool DoBlasInternalImpl(FuncT rocblas_func, Stream *stream, - bool pointer_mode_host, bool err_on_failure, - Args... args); + absl::Status DoBlasInternalImpl(FuncT rocblas_func, Stream *stream, + bool pointer_mode_host, bool err_on_failure, + Args &&...args); // Convenience functions that call DoBlasInternalImpl with different values // for err_on_failure. template bool DoBlasInternal(FuncT rocblas_func, Stream *stream, - bool pointer_mode_host, Args... args) { - return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host, - /*err_on_failure=*/true, args...); + bool pointer_mode_host, Args &&...args) { + auto ret = DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host, + /*err_on_failure=*/true, + std::forward(args)...); + return ret.ok(); } // Same as above, but returns absl::Status. - template - absl::Status DoBlasInternalStatus(Args... args) { - if (!DoBlasInternal(args...)) { - return absl::InternalError("Failed calling rocBLAS"); - } - return absl::OkStatus(); + template + absl::Status DoBlasInternalStatus(FuncT rocblas_func, Stream *stream, + bool pointer_mode_host, Args &&...args) { + return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host, + /*err_on_failure=*/true, + std::forward(args)...); } template bool DoBlasInternalFailureOK(FuncT rocblas_func, Stream *stream, - bool pointer_mode_host, Args... args) { - return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host, - /*err_on_failure=*/false, args...); + bool pointer_mode_host, Args &&...args) { + auto ret = DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host, + /*err_on_failure=*/false, + std::forward(args)...); + return ret.ok(); } - // A helper allocation function to convert raw pointers memory layout to - // strided flavor - template - absl::Status AllocateStridedBuffer( - const std::vector::mapped_type *> - &raw_ptrs, - int batch_count, uint64_t batch_stride, - ScratchAllocator *scratch_allocator, Stream *stream, - std::unique_ptr::mapped_type>> *temp_memory, - DeviceMemory::mapped_type> - *device_memory, - bool copy_data, bool &reallocated); - // A helper function to implement DoBlasGemmBatched interfaces for generic // types. // @@ -186,7 +180,7 @@ class ROCMBlas : public blas::BlasSupport { template absl::Status DoBlasGemmBatchedInternal( FuncT rocblas_func, Stream *stream, blas::Transpose transa, - blas::Transpose transb, uint64_t m, uint64 n, uint64 k, T alpha, + blas::Transpose transb, uint64_t m, uint64_t n, uint64_t k, T alpha, DeviceMemorySlice a_ptrs_to_wrappers, int lda, DeviceMemorySlice b_ptrs_to_wrappers, int ldb, T beta, DeviceMemorySlice c_ptrs_to_wrappers, int ldc, int batch_count, @@ -202,12 +196,18 @@ class ROCMBlas : public blas::BlasSupport { // rocBLAS library handle on the device. rocblas_handle blas_ ABSL_GUARDED_BY(mu_); + // container holding solutions vector (to avoid reallocating it each time) + std::vector solutions_; + #if TF_HIPBLASLT rocm::BlasLt blas_lt_; #endif ROCMBlas(const ROCMBlas &) = delete; void operator=(const ROCMBlas &) = delete; + + bool has_mfma_ = false; + bool use_hgemm_alt_impl_ = false; }; } // namespace gpu diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_diagnostics.cc b/third_party/xla/xla/stream_executor/rocm/rocm_diagnostics.cc index e816dc9c3c2496..73aad489e6fbaf 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_diagnostics.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_diagnostics.cc @@ -41,12 +41,12 @@ limitations under the License. namespace stream_executor { namespace rocm { -string DriverVersionToString(DriverVersion version) { +std::string DriverVersionToString(DriverVersion version) { return absl::StrFormat("%d.%d.%d", std::get<0>(version), std::get<1>(version), std::get<2>(version)); } -string DriverVersionStatusToString(absl::StatusOr version) { +std::string DriverVersionStatusToString(absl::StatusOr version) { if (!version.ok()) { return version.status().ToString(); } @@ -54,8 +54,8 @@ string DriverVersionStatusToString(absl::StatusOr version) { return DriverVersionToString(version.value()); } -absl::StatusOr StringToDriverVersion(const string& value) { - std::vector pieces = absl::StrSplit(value, '.'); +absl::StatusOr StringToDriverVersion(const std::string& value) { + std::vector pieces = absl::StrSplit(value, '.'); if (pieces.size() != 2 && pieces.size() != 3) { return absl::Status{absl::StatusCode::kInvalidArgument, absl::StrFormat("expected %%d.%%d or %%d.%%d.%%d form " @@ -102,7 +102,7 @@ namespace gpu { // -- class Diagnostician -string Diagnostician::GetDevNodePath(int dev_node_ordinal) { +std::string Diagnostician::GetDevNodePath(int dev_node_ordinal) { return absl::StrCat("/dev/kfd", dev_node_ordinal); } @@ -117,10 +117,10 @@ void Diagnostician::LogDiagnosticInformation() { LOG(INFO) << "hostname: " << tsl::port::Hostname(); if (VLOG_IS_ON(1)) { const char* value = getenv("LD_LIBRARY_PATH"); - string library_path = value == nullptr ? "" : value; + std::string library_path = value == nullptr ? "" : value; VLOG(1) << "LD_LIBRARY_PATH is: \"" << library_path << "\""; - std::vector pieces = absl::StrSplit(library_path, ':'); + std::vector pieces = absl::StrSplit(library_path, ':'); for (const auto& piece : pieces) { if (piece.empty()) { continue; @@ -176,11 +176,11 @@ absl::StatusOr Diagnostician::FindDsoVersion() { if (dot == nullptr) { return 0; } - string dso_version = dot + strlen(so_suffix); + std::string dso_version = dot + strlen(so_suffix); // TODO(b/22689637): Eliminate the explicit namespace if possible. auto stripped_dso_version = absl::StripSuffix(dso_version, ".ld64"); auto result = static_cast*>(data); - *result = rocm::StringToDriverVersion(string(stripped_dso_version)); + *result = rocm::StringToDriverVersion(std::string(stripped_dso_version)); return 1; } return 0; @@ -192,10 +192,10 @@ absl::StatusOr Diagnostician::FindDsoVersion() { } absl::StatusOr Diagnostician::FindKernelModuleVersion( - const string& driver_version_file_contents) { + const std::string& driver_version_file_contents) { static const char* kDriverFilePrelude = "Kernel Module "; size_t offset = driver_version_file_contents.find(kDriverFilePrelude); - if (offset == string::npos) { + if (offset == std::string::npos) { return absl::Status{ absl::StatusCode::kNotFound, absl::StrCat("could not find kernel module information in " @@ -203,13 +203,13 @@ absl::StatusOr Diagnostician::FindKernelModuleVersion( driver_version_file_contents, "\"")}; } - string version_and_rest = driver_version_file_contents.substr( - offset + strlen(kDriverFilePrelude), string::npos); + std::string version_and_rest = driver_version_file_contents.substr( + offset + strlen(kDriverFilePrelude), std::string::npos); size_t space_index = version_and_rest.find(" "); auto kernel_version = version_and_rest.substr(0, space_index); // TODO(b/22689637): Eliminate the explicit namespace if possible. auto stripped_kernel_version = absl::StripSuffix(kernel_version, ".ld64"); - return rocm::StringToDriverVersion(string(stripped_kernel_version)); + return rocm::StringToDriverVersion(std::string(stripped_kernel_version)); } void Diagnostician::WarnOnDsoKernelMismatch( diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_diagnostics.h b/third_party/xla/xla/stream_executor/rocm/rocm_diagnostics.h index 2685d0ae099575..f9bc2c2c484b55 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_diagnostics.h +++ b/third_party/xla/xla/stream_executor/rocm/rocm_diagnostics.h @@ -25,13 +25,13 @@ namespace rocm { using DriverVersion = gpu::DriverVersion; // Converts a parsed driver version to string form. -string DriverVersionToString(DriverVersion version); +std::string DriverVersionToString(DriverVersion version); // Converts a parsed driver version or status value to natural string form. -string DriverVersionStatusToString(absl::StatusOr version); +std::string DriverVersionStatusToString(absl::StatusOr version); // Converts a string of a form like "331.79" to a DriverVersion{331, 79}. -absl::StatusOr StringToDriverVersion(const string& value); +absl::StatusOr StringToDriverVersion(const std::string& value); using Diagnostician = gpu::Diagnostician; diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_dnn.cc b/third_party/xla/xla/stream_executor/rocm/rocm_dnn.cc index 3152c934375004..6d431b767e48ae 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_dnn.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_dnn.cc @@ -241,31 +241,31 @@ namespace wrap { #else -#define STREAM_EXECUTOR_MIOPEN_WRAP(__name) \ - struct DynLoadShim__##__name { \ - static const char* kName; \ - using FuncPtrT = std::add_pointer::type; \ - static void* GetDsoHandle() { \ - auto s = internal::CachedDsoLoader::GetMiopenDsoHandle(); \ - return s.value(); \ - } \ - static FuncPtrT LoadOrDie() { \ - void* f; \ - auto s = tsl::Env::Default() \ - -> GetSymbolFromLibrary(GetDsoHandle(), kName, &f); \ - CHECK(s.ok()) << "could not find " << kName \ - << " in miopen DSO; dlerror: " << s.message(); \ - return reinterpret_cast(f); \ - } \ - static FuncPtrT DynLoad() { \ - static FuncPtrT f = LoadOrDie(); \ - return f; \ - } \ - template \ - miopenStatus_t operator()(Args... args) { \ - return DynLoad()(args...); \ - } \ - } __name; \ +#define STREAM_EXECUTOR_MIOPEN_WRAP(__name) \ + struct DynLoadShim__##__name { \ + static const char* kName; \ + using FuncPtrT = std::add_pointer::type; \ + static void* GetDsoHandle() { \ + auto s = internal::CachedDsoLoader::GetMiopenDsoHandle(); \ + return s.value(); \ + } \ + static FuncPtrT LoadOrDie() { \ + void* f; \ + auto s = tsl::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \ + kName, &f); \ + CHECK(s.ok()) << "could not find " << kName \ + << " in miopen DSO; dlerror: " << s.message(); \ + return reinterpret_cast(f); \ + } \ + static FuncPtrT DynLoad() { \ + static FuncPtrT f = LoadOrDie(); \ + return f; \ + } \ + template \ + miopenStatus_t operator()(Args... args) { \ + return DynLoad()(args...); \ + } \ + } __name; \ const char* DynLoadShim__##__name::kName = #__name; #endif @@ -2292,7 +2292,9 @@ bool CreateRnnWorkspace(Stream* stream, miopenHandle_t miopen_handle, return false; } - stream->ThenMemZero(workspace, workspace_size_in_bytes); + if (!stream->MemZero(workspace, workspace_size_in_bytes).ok()) { + return false; + } } else { *workspace = DeviceMemory(); } @@ -2370,7 +2372,8 @@ absl::Status MIOpenSupport::DoRnnForwardImpl( LOG(ERROR) << "Fail to allocate RNN reserve space"; return absl::InternalError("AllocateBytes for RNN failed"); } - stream->ThenMemZero(&reserve_space, reserve_space_size_in_bytes); + TF_RETURN_IF_ERROR( + stream->MemZero(&reserve_space, reserve_space_size_in_bytes)); } } @@ -2488,17 +2491,20 @@ absl::Status MIOpenSupport::DoRnnBackwardImpl( auto size_data = input_desc.seq_length() * input_desc.batch_size() * input_desc.data_size(); if ((size_data > 0) && (input_backprop_data->opaque() != nullptr)) - stream->ThenMemZero(input_backprop_data, size_data * type_size); + TF_RETURN_IF_ERROR( + stream->MemZero(input_backprop_data, size_data * type_size)); size_data = input_h_desc.num_layers() * input_h_desc.batch_size() * input_h_desc.data_size(); if ((size_data > 0) && (input_h_backprop_data->opaque() != nullptr)) - stream->ThenMemZero(input_h_backprop_data, size_data * type_size); + TF_RETURN_IF_ERROR( + stream->MemZero(input_h_backprop_data, size_data * type_size)); size_data = input_c_desc.num_layers() * input_c_desc.batch_size() * input_c_desc.data_size(); if ((size_data > 0) && (input_c_backprop_data->opaque() != nullptr)) - stream->ThenMemZero(input_c_backprop_data, size_data * type_size); + TF_RETURN_IF_ERROR( + stream->MemZero(input_c_backprop_data, size_data * type_size)); const bool is_profiling = output_profile_result != nullptr; @@ -2533,7 +2539,8 @@ absl::Status MIOpenSupport::DoRnnBackwardImpl( if (params_backprop_data != nullptr) { // Clear the dw to zeros. - stream->ThenMemZero(params_backprop_data, params_backprop_data->size()); + TF_RETURN_IF_ERROR( + stream->MemZero(params_backprop_data, params_backprop_data->size())); // make the backward weight call status = wrap::miopenRNNBackwardWeights( miopen.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/, @@ -2761,7 +2768,7 @@ absl::Status MIOpenSupport::DoCtcLoss( } absl::StatusOr> -MIOpenSupport::createRnnDescriptor( +MIOpenSupport::CreateRnnDescriptor( int num_layers, int hidden_size, int input_size, int cell_size, int batch_size, dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode, @@ -2796,7 +2803,7 @@ MIOpenSupport::createRnnDescriptor( } absl::StatusOr> -MIOpenSupport::createRnnSequenceTensorDescriptor(int seq_length, int batch_size, +MIOpenSupport::CreateRnnSequenceTensorDescriptor(int seq_length, int batch_size, int data_size, dnn::DataType data_type) { std::unique_ptr seq_desc( @@ -2810,7 +2817,7 @@ MIOpenSupport::createRnnSequenceTensorDescriptor(int seq_length, int batch_size, } absl::StatusOr> -MIOpenSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size, +MIOpenSupport::CreateRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size, dnn::DataType data_type) { std::unique_ptr state_desc( @@ -4099,11 +4106,15 @@ absl::Status ROCmFusedMatmulRunner::gemm(Stream* stream, blas::Transpose tb = _trans_b ? blas::Transpose::kTranspose : blas::Transpose::kNoTranspose; - return stream->ThenBlasGemm( - tb, ta, _n, _m, _k, static_cast>(b_data), _ldb, - static_cast>(a_data), _lda, - static_cast*>(&c_data), _ldc, NumericOptions{}, - blas::CallContext::kNone); + auto* blas = stream->parent()->AsBlas(); + if (blas == nullptr) { + return absl::InternalError("No Blas support for stream"); + } + return blas->BlasGemm(stream, tb, ta, _n, _m, _k, + static_cast>(b_data), _ldb, + static_cast>(a_data), _lda, + static_cast*>(&c_data), _ldc, + NumericOptions{}, blas::CallContext::kNone); } template @@ -4245,7 +4256,7 @@ absl::Status MIOpenSupport::DoPoolForward( bool do_backward = false; uint8* workspace = nullptr; size_t workspace_size = 0; - std::unique_ptr> wsp_mem; + ScopedDeviceMemory wsp_mem; if (m_pooling_cache_enabled && element_type == dnn::DataType::kFloat) { do_backward = true; auto status = wrap::miopenPoolingGetWorkSpaceSizeV2( @@ -4264,12 +4275,10 @@ absl::Status MIOpenSupport::DoPoolForward( miopenFloat, pdesc); if (cache_hit) { // reusing the same buffer - workspace = reinterpret_cast( - pdesc->workspace->mutable_device_memory()->opaque()); + workspace = reinterpret_cast(pdesc->workspace.ptr()->opaque()); } else { - wsp_mem = stream->AllocateTemporaryArray(workspace_size).value(); - workspace = reinterpret_cast( - wsp_mem->mutable_device_memory()->opaque()); + wsp_mem = stream->parent()->AllocateOwnedArray(workspace_size); + workspace = reinterpret_cast(wsp_mem.ptr()->opaque()); m_pooling_cache.insert(input_data.opaque(), input_dimensions, output_dimensions, pooling_dimensions, miopenFloat, wsp_mem, workspace_size, @@ -4326,7 +4335,7 @@ void PoolingWorkspaceCache::insert( const void* p, const dnn::BatchDescriptor& input_dimensions, const dnn::BatchDescriptor& output_dimensions, const dnn::PoolingDescriptor& pooling_dimensions, int _type, - std::unique_ptr>& workspace, size_t wsp_size, + ScopedDeviceMemory& workspace, size_t wsp_size, hipStream_t hip_stream) { PoolingWorkspaceDescriptor* desc = 0; auto it = cache.find(p); @@ -4423,8 +4432,8 @@ absl::Status MIOpenSupport::DoPoolBackward( miopen_dtype, pdesc); if (cache_hit) { assert(pdesc != 0); - workspace_ptr = reinterpret_cast( - pdesc->workspace->mutable_device_memory()->opaque()); + workspace_ptr = + reinterpret_cast(pdesc->workspace.ptr()->opaque()); VLOG(1) << "Pooling cache hit"; } else { VLOG(1) << "Pooling cache miss"; @@ -4623,64 +4632,6 @@ bool MIOpenSupport::DoNormalizeBackwardWithDimensions( return true; } -bool MIOpenSupport::DoDepthConcatenate( - Stream* stream, absl::Span input_dimensions, - absl::Span* const> input_data, - DeviceMemory* output_data) { - CHECK_EQ(input_dimensions.size(), input_data.size()); - - for (const auto& dimensions : input_dimensions) { - if (dimensions.layout() != dnn::DataLayout::kBatchDepthYX) { - LOG(ERROR) << "MIOpenSupport::DoDepthConcatenate currently only " - "supports the kBatchDepthYX layout."; - return false; - } - } - - if (input_dimensions.empty()) { - return true; // Nothing to do. - } - - dnn::BatchDescriptor output_dimensions = - dnn::BatchDescriptor::DepthConcatenateOutputDescriptor(input_dimensions); - - const int64_t area = output_dimensions.width() * output_dimensions.height(); - const auto index = [area](int64_t batch, int64_t depth, int64_t yx, - int64_t max_depth) { - return (batch * max_depth + depth) * area + yx; - }; - - std::vector output_host(output_dimensions.ElementCount()); - std::vector tmp; - int64_t depth_sum = 0; - for (size_t i = 0; i < input_data.size(); ++i) { - const auto& dimensions = input_dimensions[i]; - tmp.resize(dimensions.ElementCount()); - stream->ThenMemcpyD2H(*input_data[i], absl::MakeSpan(tmp)); - absl::Status block_status = stream->BlockHostUntilDone(); - if (!block_status.ok()) { - LOG(ERROR) << "BlockHostUntilDone failed: " << block_status; - return false; - } - - for (int64_t batch = 0; batch < output_dimensions.count(); ++batch) { - for (int64_t yx = 0; yx < area; ++yx) { - for (int64_t depth = 0; depth < dimensions.feature_map_count(); - ++depth) { - LOG(INFO) << output_dimensions.ElementCount() << ' ' << batch << ' ' - << yx << ' ' << depth; - output_host[index(batch, depth + depth_sum, yx, - output_dimensions.feature_map_count())] = - tmp[index(batch, depth, yx, dimensions.feature_map_count())]; - } - } - } - depth_sum += dimensions.feature_map_count(); - } - stream->ThenMemcpyH2D(output_host, output_data); - return true; -} - bool MIOpenSupport::DeriveOutputBatchDescriptor( const BatchDescriptor& batch_descriptor, const FilterDescriptor& filter_descriptor, @@ -4763,5 +4714,6 @@ void initialize_miopen() { } // namespace stream_executor -REGISTER_MODULE_INITIALIZER(register_miopen, - { stream_executor::initialize_miopen(); }); +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(register_miopen, { + stream_executor::initialize_miopen(); +}); diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_dnn.h b/third_party/xla/xla/stream_executor/rocm/rocm_dnn.h index b9ed7ce6b6d9ea..ecaffd3cad392d 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_dnn.h +++ b/third_party/xla/xla/stream_executor/rocm/rocm_dnn.h @@ -22,9 +22,9 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "rocm/include/miopen/miopen.h" +#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/plugin_registry.h" -#include "xla/stream_executor/temporary_device_memory.h" namespace stream_executor { namespace gpu { @@ -41,7 +41,7 @@ struct PoolingWorkspaceDescriptor { dnn::PoolingDescriptor op; int dtype; uint64_t timestamp; - std::unique_ptr> workspace; + ScopedDeviceMemory workspace; size_t workspace_size; bool IsSame(const dnn::BatchDescriptor& input_dimensions, const dnn::BatchDescriptor& output_dimensions, @@ -61,8 +61,8 @@ struct PoolingWorkspaceCache { void insert(const void* p, const dnn::BatchDescriptor& input_dimensions, const dnn::BatchDescriptor& output_dimensions, const dnn::PoolingDescriptor& pooling_dimensions, int _type, - std::unique_ptr>& workspace, - size_t wsp_size, hipStream_t hip_stream); + ScopedDeviceMemory& workspace, size_t wsp_size, + hipStream_t hip_stream); private: void trim(hipStream_t hip_stream); @@ -77,7 +77,7 @@ class MIOpenSupport : public dnn::DnnSupport { absl::Status Init() override; absl::StatusOr GetVersion() override; - absl::StatusOr> createRnnDescriptor( + absl::StatusOr> CreateRnnDescriptor( int num_layers, int hidden_size, int input_size, int cell_size, int batch_size, dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode, @@ -86,12 +86,12 @@ class MIOpenSupport : public dnn::DnnSupport { ScratchAllocator* state_allocator, bool use_padded_io) override; absl::StatusOr> - createRnnSequenceTensorDescriptor(int seq_length, int batch_size, + CreateRnnSequenceTensorDescriptor(int seq_length, int batch_size, int data_size, dnn::DataType data_type) override; absl::StatusOr> - createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size, + CreateRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size, dnn::DataType data_type) override; bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, @@ -422,11 +422,6 @@ class MIOpenSupport : public dnn::DnnSupport { DeviceMemory* raw_variable_gradient, ScratchAllocator* workspace_allocator = nullptr) override; - bool DoDepthConcatenate( - Stream* stream, absl::Span input_dimensions, - absl::Span* const> input_data, - DeviceMemory* output_data) override; - // Derives an output batch descriptor from an input batch and convolution // descriptors. bool DeriveOutputBatchDescriptor( diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc b/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc index 02dcac01c008f3..2d075f06cd1897 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc @@ -182,7 +182,7 @@ ScopedActivateContext::ScopedActivateContext(GpuContext* hip_context) { if (tls->depth == 0) { VLOG(3) << "ScopedActivateContext switching to " << hip_context->device_ordinal(); - FAIL_IF_ROCM_ERROR(hipCtxSetCurrent(hip_context->context()), + FAIL_IF_ROCM_ERROR(wrap::hipCtxSetCurrent(hip_context->context()), "Failed setting context"); tls->depth = 1; tls->current_device_ordinal = hip_context->device_ordinal(); @@ -205,7 +205,7 @@ ScopedActivateContext::ScopedActivateContext(GpuContext* hip_context) { to_restore_ = tls->context; // Set the device and update thread local. - FAIL_IF_ROCM_ERROR(hipCtxSetCurrent(hip_context->context()), + FAIL_IF_ROCM_ERROR(wrap::hipCtxSetCurrent(hip_context->context()), "Failed setting context"); tls->current_device_ordinal = hip_context->device_ordinal(); tls->context = hip_context; @@ -229,7 +229,7 @@ ScopedActivateContext::~ScopedActivateContext() { } // Set context and update thread local. - FAIL_IF_ROCM_ERROR(hipCtxSetCurrent(to_restore_->context()), + FAIL_IF_ROCM_ERROR(wrap::hipCtxSetCurrent(to_restore_->context()), "Failed setting context"); tls->current_device_ordinal = to_restore_->device_ordinal(); tls->context = to_restore_; @@ -743,16 +743,16 @@ GpuDriver::GraphNodeGetType(hipGraphNode_t node) { /* static */ absl::StatusOr GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, - absl::Span deps, + absl::Span deps, const GpuGraphNodeParams& params) { return absl::UnimplementedError("unsupported node type"); } /* static */ absl::Status GpuDriver::GraphAddKernelNode( - hipGraphNode_t* node, hipGraph_t graph, absl::Span deps, - absl::string_view kernel_name, hipFunction_t function, - unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z, - unsigned int block_dim_x, unsigned int block_dim_y, + hipGraphNode_t* node, hipGraph_t graph, + absl::Span deps, absl::string_view kernel_name, + hipFunction_t function, unsigned int grid_dim_x, unsigned int grid_dim_y, + unsigned int grid_dim_z, unsigned int block_dim_x, unsigned int block_dim_y, unsigned int block_dim_z, unsigned int shared_mem_bytes, void** kernel_params, void** extra) { VLOG(2) << "Add kernel node to a graph " << graph @@ -833,8 +833,8 @@ GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, } /* static */ absl::Status GpuDriver::GraphAddChildNode( - hipGraphNode_t* node, hipGraph_t graph, absl::Span deps, - hipGraph_t child) { + hipGraphNode_t* node, hipGraph_t graph, + absl::Span deps, hipGraph_t child) { VLOG(2) << "Create a new node by cloning the child graph " << child << " and add it to " << graph << "; deps: " << deps.size(); @@ -895,7 +895,7 @@ static hipMemAllocationType ToHipAllocationType( /*static*/ absl::Status GpuDriver::GraphAddMemFreeNode( GpuGraphNodeHandle* node, GpuGraphHandle graph, - absl::Span deps, GpuDevicePtr gpu_dst) { + absl::Span deps, GpuDevicePtr gpu_dst) { RETURN_IF_ROCM_ERROR(wrap::hipGraphAddMemFreeNode(node, graph, deps.data(), deps.size(), gpu_dst), "Failed to add memory free node to a HIP graph"); @@ -904,7 +904,7 @@ static hipMemAllocationType ToHipAllocationType( /*static*/ absl::Status GpuDriver::GraphAddMemAllocNode( GpuGraphNodeHandle* node, GpuGraphHandle graph, - absl::Span deps, MemAccessFlags access_flags, + absl::Span deps, MemAccessFlags access_flags, MemLocationType location_type, int device_id, MemAllocationType allocation_type, uint64_t size, GpuDevicePtr* d_ptr, uint64_t max_pool_size) { @@ -952,25 +952,16 @@ GpuDriver::GraphGetMemAllocNodeParams(GpuGraphNodeHandle node) { /* static */ absl::Status GpuDriver::GraphAddMemcpyD2DNode( GpuContext* context, GpuGraphNodeHandle* node, GpuGraphHandle graph, - absl::Span deps, GpuDevicePtr gpu_dst, + absl::Span deps, GpuDevicePtr gpu_dst, GpuDevicePtr gpu_src, uint64_t size) { VLOG(2) << "Add memcpy d2d node to a graph " << graph << "; dst: " << reinterpret_cast(gpu_dst) << "; src: " << reinterpret_cast(gpu_src) << "; size: " << size << "; context: " << context->context() << "; deps: " << deps.size(); - hipMemcpy3DParms params{ - .srcArray = {}, - .srcPos = {}, - .srcPtr = {.ptr = gpu_src, .pitch = size, .xsize = size, .ysize = 1}, - .dstArray = {}, - .dstPos = {}, - .dstPtr = {.ptr = gpu_dst, .pitch = size, .xsize = size, .ysize = 1}, - .extent = hipExtent{.width = size, .height = 1, .depth = 1}, - .kind = hipMemcpyDeviceToDevice}; - - RETURN_IF_ROCM_ERROR(wrap::hipGraphAddMemcpyNode(node, graph, deps.data(), - deps.size(), ¶ms), + RETURN_IF_ROCM_ERROR(wrap::hipGraphAddMemcpyNode1D( + node, graph, deps.data(), deps.size(), gpu_dst, + gpu_src, size, hipMemcpyDeviceToDevice), "Failed to add memcpy d2d node to a HIP graph"); return absl::OkStatus(); @@ -984,18 +975,9 @@ GpuDriver::GraphGetMemAllocNodeParams(GpuGraphNodeHandle node) { << "; src: " << reinterpret_cast(gpu_src) << "; size: " << size << "; context: " << context->context(); - hipMemcpy3DParms params{ - .srcArray = {}, - .srcPos = {}, - .srcPtr = {.ptr = gpu_src, .pitch = size, .xsize = size, .ysize = 1}, - .dstArray = {}, - .dstPos = {}, - .dstPtr = {.ptr = gpu_dst, .pitch = size, .xsize = size, .ysize = 1}, - .extent = hipExtent{.width = size, .height = 1, .depth = 1}, - .kind = hipMemcpyDeviceToDevice}; - RETURN_IF_ROCM_ERROR( - wrap::hipGraphExecMemcpyNodeSetParams(exec, node, ¶ms), + wrap::hipGraphExecMemcpyNodeSetParams1D(exec, node, gpu_dst, gpu_src, + size, hipMemcpyDeviceToDevice), "Failed to set memcpy d2d node params"); return absl::OkStatus(); @@ -1035,7 +1017,7 @@ struct BitPatternToValue { /* static */ absl::Status GpuDriver::GraphAddMemsetNode( GpuContext* context, GpuGraphNodeHandle* node, GpuGraphHandle graph, - absl::Span deps, GpuDevicePtr dst, + absl::Span deps, GpuDevicePtr dst, std::variant bit_pattern, uint64_t num_elements) { VLOG(2) << "Add memset node to a graph " << graph diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h b/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h index 833ae720cfe8c3..32079f4236a2bb 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h +++ b/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h @@ -53,7 +53,7 @@ namespace wrap { static FuncPtrT loaded = []() -> FuncPtrT { \ static const char *kName = TO_STR(hipSymbolName); \ void *f; \ - auto s = tsl::Env::Default() -> GetSymbolFromLibrary( \ + auto s = tsl::Env::Default()->GetSymbolFromLibrary( \ stream_executor::internal::CachedDsoLoader::GetHipDsoHandle() \ .value(), \ kName, &f); \ @@ -106,7 +106,6 @@ namespace wrap { __macro(hipGraphAddKernelNode) \ __macro(hipGraphAddChildGraphNode) \ __macro(hipGraphAddMemAllocNode) \ - __macro(hipGraphAddMemcpyNode) \ __macro(hipGraphAddMemcpyNode1D) \ __macro(hipGraphAddMemsetNode) \ __macro(hipGraphAddMemFreeNode) \ @@ -116,7 +115,7 @@ namespace wrap { __macro(hipGraphExecChildGraphNodeSetParams) \ __macro(hipGraphExecDestroy) \ __macro(hipGraphExecKernelNodeSetParams) \ - __macro(hipGraphExecMemcpyNodeSetParams) \ + __macro(hipGraphExecMemcpyNodeSetParams1D) \ __macro(hipGraphExecMemsetNodeSetParams) \ __macro(hipGraphExecUpdate) \ __macro(hipGraphInstantiate) \ diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc index a08c9730b6caf2..338502cf306d9e 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc @@ -160,9 +160,9 @@ GpuExecutor::CreateOrShareConstant(Stream* stream, "Failed to allocate %d bytes for new constant", content.size())); } - absl::Status status = - stream->ThenMemcpy(new_constant, content.data(), content.size()) - .BlockHostUntilDone(); + TF_RETURN_IF_ERROR( + stream->Memcpy(new_constant, content.data(), content.size())); + absl::Status status = stream->BlockHostUntilDone(); if (!status.ok()) { Deallocate(new_constant); status.Update(absl::InternalError(absl::StrFormat( @@ -346,7 +346,8 @@ absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, if (VLOG_IS_ON(2)) { absl::MutexLock lock(&launched_kernels_mu_); if (!launched_kernels_.count(hipfunc)) { - VlogOccupancyInfo(kernel, thread_dims, block_dims); + VlogOccupancyInfo(stream->parent()->GetDeviceDescription(), kernel, + thread_dims, block_dims); // TODO(rspringer): Remove elements from launched_kernels_...if we ever // expose a kernel/module deallocation method. launched_kernels_.insert(hipfunc); @@ -464,7 +465,8 @@ absl::Status GpuExecutor::LoadModuleFromHsaco(const char* hsaco, // This is a non-essential operation; if there's a failure, proceed without // logging an error. It's nearly certain that in case of failures, we'd never // get here in the first place; these are very low-impact routines. -void GpuExecutor::VlogOccupancyInfo(const Kernel& kernel, +void GpuExecutor::VlogOccupancyInfo(const DeviceDescription& device_description, + const Kernel& kernel, const ThreadDim& thread_dims, const BlockDim& block_dims) { VLOG(2) << "Computing kernel occupancy for kernel " @@ -479,9 +481,6 @@ void GpuExecutor::VlogOccupancyInfo(const Kernel& kernel, return; } - const DeviceDescription& device_description = - kernel.parent()->GetDeviceDescription(); - const GpuKernel* rocm_kernel = AsGpuKernel(&kernel); auto hipfunc = rocm_kernel->AsGpuFunctionHandle(); @@ -857,28 +856,25 @@ GpuExecutor::CreateEventImplementation() { return std::unique_ptr(new GpuEvent(this)); } -std::unique_ptr -GpuExecutor::CreateKernelImplementation() { - return std::unique_ptr(new GpuKernel()); -} - std::unique_ptr GpuExecutor::GetStreamImplementation() { return std::unique_ptr(new GpuStream(this)); } -absl::StatusOr> -GpuExecutor::GetCommandBufferImplementation(CommandBuffer::Mode mode) { +absl::StatusOr> GpuExecutor::CreateKernel() { + return std::make_unique(this); +} + +absl::StatusOr> GpuExecutor::CreateCommandBuffer( + CommandBuffer::Mode mode) { VLOG(2) << "Create ROCm command buffer (ROCm graph)"; GpuGraphHandle graph = nullptr; TF_RETURN_IF_ERROR(GpuDriver::CreateGraph(&graph)); return std::make_unique(mode, /*parent=*/this, graph); } -std::unique_ptr -GpuExecutor::GetCommandBufferImplementation(CommandBuffer::Mode mode, - GpuGraphHandle graph, - bool is_owned_graph) { +std::unique_ptr GpuExecutor::CreateCommandBuffer( + CommandBuffer::Mode mode, GpuGraphHandle graph, bool is_owned_graph) { VLOG(2) << "Create HIP command buffer (HIP graph) from existing graph " << graph << "; is_owned_graph=" << is_owned_graph; return std::make_unique(mode, /*parent=*/this, graph, @@ -1078,4 +1074,4 @@ GpuExecutor::CreateDeviceDescription(int device_ordinal) { } // namespace stream_executor -REGISTER_MODULE_INITIALIZER(rocm_executor, {}); +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(rocm_executor, {}); diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_fft.cc b/third_party/xla/xla/stream_executor/rocm/rocm_fft.cc index 8c862d3dae070c..9d2f9b30899cea 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_fft.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_fft.cc @@ -424,7 +424,7 @@ bool ROCMFft::DoFftInternal(Stream *stream, fft::Plan *plan, FuncT hipfftExec, if (allocator) { auto allocated = allocator->AllocateBytes(input.size()); if (allocated.ok()) { - if (stream->ThenMemcpy(&allocated.value(), input, input.size()).ok()) { + if (stream->Memcpy(&allocated.value(), input, input.size()).ok()) { input_maybe_copy = DeviceMemory(allocated.value()); } else { LOG(ERROR) << "failed to copy input buffer for rocFFT."; @@ -529,5 +529,6 @@ void initialize_rocfft() { } // namespace stream_executor -REGISTER_MODULE_INITIALIZER(register_rocfft, - { stream_executor::initialize_rocfft(); }); +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(register_rocfft, { + stream_executor::initialize_rocfft(); +}); diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_platform.cc b/third_party/xla/xla/stream_executor/rocm/rocm_platform.cc index 174affdfe07657..1b5f54d11076f2 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_platform.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_platform.cc @@ -147,16 +147,16 @@ ROCmPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) { } // namespace gpu static void InitializeROCmPlatform() { - // Disabling leak checking, MultiPlatformManager does not destroy its + // Disabling leak checking, PlatformManager does not destroy its // registered platforms. - auto status = MultiPlatformManager::PlatformWithName("ROCM"); + auto status = PlatformManager::PlatformWithName("ROCM"); if (!status.ok()) { std::unique_ptr platform(new gpu::ROCmPlatform); - TF_CHECK_OK(MultiPlatformManager::RegisterPlatform(std::move(platform))); + TF_CHECK_OK(PlatformManager::RegisterPlatform(std::move(platform))); } } } // namespace stream_executor -REGISTER_MODULE_INITIALIZER(rocm_platform, - stream_executor::InitializeROCmPlatform()); +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER( + rocm_platform, stream_executor::InitializeROCmPlatform()); diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_platform.h b/third_party/xla/xla/stream_executor/rocm/rocm_platform.h index 3f4de3120f906d..7c9f5037435496 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_platform.h +++ b/third_party/xla/xla/stream_executor/rocm/rocm_platform.h @@ -21,9 +21,9 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "xla/stream_executor/executor_cache.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform/port.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_internal.h" diff --git a/third_party/xla/xla/stream_executor/scratch_allocator.h b/third_party/xla/xla/stream_executor/scratch_allocator.h index adb9f566cbbf05..4c860627c947ae 100644 --- a/third_party/xla/xla/stream_executor/scratch_allocator.h +++ b/third_party/xla/xla/stream_executor/scratch_allocator.h @@ -25,7 +25,6 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/platform/port.h" -#include "xla/stream_executor/temporary_device_memory.h" #include "tsl/platform/statusor.h" namespace stream_executor { diff --git a/third_party/xla/xla/stream_executor/stream.cc b/third_party/xla/xla/stream_executor/stream.cc index eaaaca826b81dd..b0c2313e8ef45d 100644 --- a/third_party/xla/xla/stream_executor/stream.cc +++ b/third_party/xla/xla/stream_executor/stream.cc @@ -34,7 +34,7 @@ limitations under the License. #include "xla/stream_executor/platform/port.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_internal.h" -#include "xla/stream_executor/temporary_device_memory.h" +#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/stacktrace.h" @@ -45,55 +45,6 @@ namespace { // will be VLOG'ed. We need overloads, instead of // e.g. BatchDescriptorToVlogString(), as the code that calls these // functions does not know what the type of the parameter is. -std::string ToVlogString(const dnn::BatchDescriptor &descriptor) { - return descriptor.ToShortString(); -} - -std::string ToVlogString(const dnn::FilterDescriptor &descriptor) { - return descriptor.ToShortString(); -} - -std::string ToVlogString(const dnn::ConvolutionDescriptor &descriptor) { - return descriptor.ToShortString(); -} - -std::string ToVlogString(const dnn::PoolingDescriptor &descriptor) { - return descriptor.ToShortString(); -} - -std::string ToVlogString(const dnn::NormalizeDescriptor &descriptor) { - return descriptor.ToShortString(); -} - -std::string ToVlogString(dnn::ActivationMode mode) { - return dnn::ActivationModeString(mode); -} - -std::string ToVlogString(const dnn::AlgorithmConfig &algo_config) { - return algo_config.ToString(); -} - -std::string ToVlogString(dnn::ElementwiseOperation op) { - return dnn::ElementwiseOperationString(op); -} - -std::string ToVlogString(dnn::QuantizedActivationMode mode) { - return dnn::QuantizedActivationModeString(mode); -} - -std::string ToVlogString(blas::Transpose t) { return blas::TransposeString(t); } - -std::string ToVlogString(blas::UpperLower ul) { - return blas::UpperLowerString(ul); -} - -std::string ToVlogString(blas::Diagonal d) { return blas::DiagonalString(d); } - -std::string ToVlogString(blas::Side s) { return blas::SideString(s); } - -std::string ToVlogString(blas::ComputationType ty) { - return blas::ComputationTypeString(ty); -} std::string ToVlogString(const void *ptr) { if (ptr == nullptr) { @@ -106,14 +57,6 @@ std::string ToVlogString(const void *ptr) { return out.str(); } -template -std::string ToVlogString(const std::complex &c) { - // StrCat does not convert std::complex to text. - std::ostringstream out; - out << c; - return out.str(); -} - template std::string ToVlogString(const std::function &f) { return f == nullptr ? "null" : ""; @@ -132,76 +75,12 @@ std::string ToVlogString(const DeviceMemoryBase *memory) { return memory == nullptr ? "null" : ToVlogString(*memory); } -std::string ToVlogString(const Eigen::half &h) { - return absl::StrCat(static_cast(h)); -} - -std::string ToVlogString(const Eigen::bfloat16 &bf) { // NOLINT - return absl::StrCat(static_cast(bf)); -} - -std::string ToVlogString(int i) { return absl::StrCat(i); } - std::string ToVlogString(uint32_t i) { return absl::StrCat(i); } std::string ToVlogString(uint64_t i) { return absl::StrCat(i); } -std::string ToVlogString(int64_t i) { return absl::StrCat(i); } - std::string ToVlogString(float f) { return absl::StrCat(f); } -std::string ToVlogString(double d) { return absl::StrCat(d); } - -template -std::string ToVlogString(absl::Span elements) { - std::string str = absl::StrCat( - ToVlogString(reinterpret_cast(elements.data())), "[", - elements.size(), "]{"); - const char *separator = ""; - size_t max_to_show = std::numeric_limits::max(); - if (!VLOG_IS_ON(2)) { - max_to_show = 5; - } else if (!VLOG_IS_ON(3)) { - max_to_show = 20; - } else if (!VLOG_IS_ON(11)) { - max_to_show = 1000; - } - for (size_t i = 0; i < elements.size(); ++i) { - if (i == max_to_show) { - str += ", ..."; - break; - } - absl::StrAppend(&str, separator, ToVlogString(elements[i])); - separator = ", "; - } - str += "}"; - return str; -} - -template -std::string ToVlogString(absl::Span elements) { - return ToVlogString(absl::Span(elements)); -} - -std::string ToVlogString(dnn::DataType data_type) { - switch (data_type) { - case dnn::DataType::kFloat: - return "dnn::DataType::kFloat"; - case dnn::DataType::kDouble: - return "dnn::DataType::kDouble"; - case dnn::DataType::kHalf: - return "dnn::DataType::kHalf"; - case dnn::DataType::kInt8: - return "dnn::DataType::kInt8"; - case dnn::DataType::kInt32: - return "dnn::DataType::kInt32"; - case dnn::DataType::kBF16: - return "dnn::DataType::kBF16"; - default: - return "unknown DataType"; - } -} - // Used together with PARAM to VLOG calls made to the stream. Intended // to be used like this: // @@ -305,271 +184,58 @@ absl::Status Stream::RefreshStatus() { return status; } -Stream &Stream::Init() { +absl::Status Stream::Initialize() { VLOG_CALL(); absl::MutexLock lock(&mu_); - CHECK_EQ(false, allocated_) - << "stream appears to already have been initialized"; - CHECK(!status_.ok()) << "stream should be in !ok() state pre-initialization"; + if (allocated_) { + return absl::InternalError( + "stream appears to already have been initialized"); + } + if (status_.ok()) { + return absl::InternalError( + "stream should be in !ok() state pre-initialization"); + } if (parent_->AllocateStream(this)) { // Successful initialization! allocated_ = true; status_ = absl::OkStatus(); - } else { - LOG(ERROR) << "failed to allocate stream during initialization"; + return absl::OkStatus(); } - return *this; + return absl::InternalError("failed to allocate stream during initialization"); } -Stream &Stream::ThenRecordEvent(Event *event) { - VLOG_CALL(PARAM(event)); +Stream &Stream::Init() { + VLOG_CALL(); - absl::Status status = parent_->RecordEvent(this, event); + absl::Status status = Initialize(); if (!status.ok()) { - LOG(ERROR) << "Error recording event in stream: " << status.message() - << "; not marking stream as bad, as the Event object may be " - << "at fault. Monitor for further errors."; - } - - return *this; -} - -Stream &Stream::ThenBatchNormalizationForward( - const DeviceMemory &x, const DeviceMemory &scale, - const DeviceMemory &offset, - const DeviceMemory &estimated_mean, - const DeviceMemory &estimated_variance, - const DeviceMemory &side_input, const dnn::BatchDescriptor &x_desc, - const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, - const double exponential_average_factor, - dnn::ActivationMode activation_mode, DeviceMemory *y, - DeviceMemory *batch_mean, DeviceMemory *batch_var, - DeviceMemory *saved_mean, DeviceMemory *saved_inv_var, - bool is_training, ScratchAllocator *reserve_space_allocator, - ScratchAllocator *workspace_allocator) { - VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc), - PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y)); - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoBatchNormalizationForward( - this, x, scale, offset, estimated_mean, estimated_variance, side_input, - x_desc, scale_offset_desc, epsilon, exponential_average_factor, - activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var, - is_training, reserve_space_allocator, workspace_allocator)); - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; -} - -Stream &Stream::ThenBatchNormalizationBackward( - const DeviceMemory &y_backprop, const DeviceMemory &x, - const DeviceMemory &scale, const DeviceMemory &offset, - const DeviceMemory &mean, const DeviceMemory &inv_var, - const DeviceMemory &y, const dnn::BatchDescriptor &x_desc, - const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, - dnn::ActivationMode activation_mode, DeviceMemory *x_backprop, - DeviceMemory *scale_backprop, DeviceMemory *offset_backprop, - DeviceMemory *side_input_backprop, - DeviceMemory *reserve_space_data, - ScratchAllocator *workspace_allocator) { - VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc), - PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop), - PARAM(scale_backprop), PARAM(offset_backprop)); - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoBatchNormalizationBackward( - this, y_backprop, x, scale, offset, mean, inv_var, y, x_desc, - scale_offset_desc, epsilon, activation_mode, x_backprop, scale_backprop, - offset_backprop, side_input_backprop, reserve_space_data, - workspace_allocator)); - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; -} - -Stream &Stream::ThenBatchNormalizationForward( - const DeviceMemory &x, const DeviceMemory &scale, - const DeviceMemory &offset, - const DeviceMemory &estimated_mean, - const DeviceMemory &estimated_variance, - const DeviceMemory &side_input, - const dnn::BatchDescriptor &x_desc, - const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, - const double exponential_average_factor, - dnn::ActivationMode activation_mode, DeviceMemory *y, - DeviceMemory *batch_mean, DeviceMemory *batch_var, - DeviceMemory *saved_mean, DeviceMemory *saved_inv_var, - bool is_training, ScratchAllocator *reserve_space_allocator, - ScratchAllocator *workspace_allocator) { - VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc), - PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y)); - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoBatchNormalizationForward( - this, x, scale, offset, estimated_mean, estimated_variance, side_input, - x_desc, scale_offset_desc, epsilon, exponential_average_factor, - activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var, - is_training, reserve_space_allocator, workspace_allocator)); - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; -} - -Stream &Stream::ThenBatchNormalizationBackward( - const DeviceMemory &y_backprop, - const DeviceMemory &x, const DeviceMemory &scale, - const DeviceMemory &offset, const DeviceMemory &mean, - const DeviceMemory &inv_var, const DeviceMemory &y, - const dnn::BatchDescriptor &x_desc, - const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, - dnn::ActivationMode activation_mode, DeviceMemory *x_backprop, - DeviceMemory *scale_backprop, DeviceMemory *offset_backprop, - DeviceMemory *side_input_backprop, - DeviceMemory *reserve_space_data, - ScratchAllocator *workspace_allocator) { - VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc), - PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop), - PARAM(scale_backprop), PARAM(offset_backprop)); - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoBatchNormalizationBackward( - this, y_backprop, x, scale, offset, mean, inv_var, y, x_desc, - scale_offset_desc, epsilon, activation_mode, x_backprop, scale_backprop, - offset_backprop, side_input_backprop, reserve_space_data, - workspace_allocator)); - - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; -} - -Stream &Stream::ThenBatchNormalizationForward( - const DeviceMemory &x, const DeviceMemory &scale, - const DeviceMemory &offset, - const DeviceMemory &estimated_mean, - const DeviceMemory &estimated_variance, - const DeviceMemory &side_input, - const dnn::BatchDescriptor &x_desc, - const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, - const double exponential_average_factor, - dnn::ActivationMode activation_mode, DeviceMemory *y, - DeviceMemory *batch_mean, DeviceMemory *batch_var, - DeviceMemory *saved_mean, DeviceMemory *saved_inv_var, - bool is_training, ScratchAllocator *reserve_space_allocator, - ScratchAllocator *workspace_allocator) { - VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc), - PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y)); - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoBatchNormalizationForward( - this, x, scale, offset, estimated_mean, estimated_variance, side_input, - x_desc, scale_offset_desc, epsilon, exponential_average_factor, - activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var, - is_training, reserve_space_allocator, workspace_allocator)); - } else { - SetErrorAndLogNoDnnSupport(); + LOG(ERROR) << status; } - return *this; -} - -Stream &Stream::ThenBatchNormalizationBackward( - const DeviceMemory &y_backprop, - const DeviceMemory &x, const DeviceMemory &scale, - const DeviceMemory &offset, const DeviceMemory &mean, - const DeviceMemory &inv_var, const DeviceMemory &y, - const dnn::BatchDescriptor &x_desc, - const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, - dnn::ActivationMode activation_mode, - DeviceMemory *x_backprop, - DeviceMemory *scale_backprop, DeviceMemory *offset_backprop, - DeviceMemory *side_input_backprop, - DeviceMemory *reserve_space_data, - ScratchAllocator *workspace_allocator) { - VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc), - PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop), - PARAM(scale_backprop), PARAM(offset_backprop)); - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoBatchNormalizationBackward( - this, y_backprop, x, scale, offset, mean, inv_var, y, x_desc, - scale_offset_desc, epsilon, activation_mode, x_backprop, scale_backprop, - offset_backprop, side_input_backprop, reserve_space_data, - workspace_allocator)); - } else { - SetErrorAndLogNoDnnSupport(); - } return *this; } -Stream &Stream::ThenNormalizeWithDimensions( - const dnn::NormalizeDescriptor &normalize_descriptor, - const dnn::BatchDescriptor &dimensions, - const DeviceMemory &input_data, DeviceMemory *output_data) { - VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(input_data), - PARAM(output_data)); +Stream &Stream::ThenRecordEvent(Event *event) { + VLOG_CALL(PARAM(event)); - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoNormalizeWithDimensions( - this, normalize_descriptor, dimensions, input_data, output_data)); - } else { - SetErrorAndLogNoDnnSupport(); + absl::Status status = RecordEvent(event); + if (!status.ok()) { + LOG(ERROR) << "Error recording event in stream: " << status.message() + << "; not marking stream as bad, as the Event object may be " + << "at fault. Monitor for further errors."; } - return *this; -} -Stream &Stream::ThenNormalizeBackwardWithDimensions( - const dnn::NormalizeDescriptor &normalize_descriptor, - const dnn::BatchDescriptor &dimensions, const DeviceMemory &raw_data, - const DeviceMemory &normalized_data, - const DeviceMemory &normalized_variable_gradient, - DeviceMemory *raw_variable_gradient, - ScratchAllocator *workspace_allocator) { - VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(raw_data), - PARAM(normalized_data), PARAM(normalized_variable_gradient), - PARAM(raw_variable_gradient), PARAM(workspace_allocator)); - - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoNormalizeBackwardWithDimensions( - this, normalize_descriptor, dimensions, raw_data, normalized_data, - normalized_variable_gradient, raw_variable_gradient, - workspace_allocator)); - } else { - SetErrorAndLogNoDnnSupport(); - } return *this; } -Stream &Stream::ThenDepthConcatenate( - absl::Span input_dimensions, - absl::Span *const> input_data, - DeviceMemory *output_data) { - VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data)); - - for (size_t i = 1; i < input_dimensions.size(); ++i) { - if (input_dimensions[i].count() != input_dimensions[0].count() || - input_dimensions[i].height() != input_dimensions[0].height() || - input_dimensions[i].width() != input_dimensions[0].width()) { - SetError(); - LOG(ERROR) << "Incompatible dimensions for depth concatenation.\n" - << "input_dimensions[0]: " << input_dimensions[0].ToString() - << "input_dimensions[" << i - << "]: " << input_dimensions[i].ToString(); - return *this; - } - } - - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoDepthConcatenate(this, input_dimensions, input_data, - output_data)); - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; +absl::Status Stream::RecordEvent(Event *event) { + return parent_->RecordEvent(this, event); } -Stream *Stream::GetOrCreateSubStream() { +absl::StatusOr Stream::GetOrCreateSubStream() { // Do not destroy bad streams when holding mu_ because ~Stream() may // BlockHostUntilDone and it's host callbacks might attempt to acquire mu_. std::vector> bad_streams; @@ -608,13 +274,9 @@ Stream *Stream::GetOrCreateSubStream() { } // No streams are reusable; create a new stream. - sub_streams_.emplace_back(std::unique_ptr{new Stream{parent_}}, - false); + sub_streams_.emplace_back(std::make_unique(parent_), false); Stream *sub_stream = sub_streams_.back().first.get(); - sub_stream->Init(); - if (!sub_stream->ok()) { - LOG(ERROR) << "sub-stream failed to be initialized"; - } + TF_RETURN_IF_ERROR(sub_stream->Initialize()); VLOG(1) << DebugStreamPointers() << " created new sub_stream " << sub_stream->DebugStreamPointers(); @@ -666,7 +328,7 @@ Stream &Stream::ThenWaitFor(Stream *other) { CHECK(this != other) << "stream cannot wait for itself"; if (ok() && other->ok()) { - CheckError(parent_->CreateStreamDependency(this, other)); + CheckStatus(WaitFor(other)); } else { SetError(); LOG(INFO) << DebugStreamPointers() << " did not wait for " @@ -679,7 +341,7 @@ Stream &Stream::ThenWaitFor(Event *event) { VLOG_CALL(PARAM(event)); if (ok()) { - absl::Status status = parent_->WaitForEvent(this, event); + absl::Status status = WaitFor(event); if (!status.ok()) { LOG(ERROR) << "Error waiting for event in stream: " << status.message() << "; not marking stream as bad, as the Event object may be " @@ -691,847 +353,82 @@ Stream &Stream::ThenWaitFor(Event *event) { return *this; } -// A functor that implements ThenBlasXXX interfaces, which calls DoBlasXXX -// functions and logs for errors. -template -struct ThenBlasImpl { - // blas_func is the DoBlasXXX member function pointer, and args are its - // arguments except the first one of Stream* type. - Stream &operator()(Stream *stream, - bool (blas::BlasSupport::*blas_func)(Stream *, Args...), - Args... args) { - return Run(stream, blas_func, /*record_error=*/true, args...); - } - - // Like operator(), but only calls stream->CheckError() if record_error is - // true. - Stream &Run(Stream *stream, - bool (blas::BlasSupport::*blas_func)(Stream *, Args...), - bool record_error, Args... args); -}; - -template -Stream &ThenBlasImpl::Run( - Stream *stream, bool (blas::BlasSupport::*blas_func)(Stream *, Args...), - bool record_error, Args... args) { - if (stream->ok()) { - bool ok; - if (blas::BlasSupport *blas = stream->parent_->AsBlas()) { - ok = (blas->*blas_func)(stream, args...); - } else { - LOG(WARNING) - << "attempting to perform BLAS operation using StreamExecutor " - "without BLAS support"; - ok = false; - } - if (record_error) { - stream->CheckError(ok); - } +absl::Status Stream::WaitFor(Stream *other) { + if (this == other) { + return absl::InternalError("stream cannot wait for itself"); } - return *stream; -} - -Stream &Stream::ThenBlasAxpy(uint64_t elem_count, float alpha, - const DeviceMemory &x, int incx, - DeviceMemory *y, int incy) { - VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y), - PARAM(incy)); - - ThenBlasImpl &, int, - DeviceMemory *, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx, - y, incy); -} - -Stream &Stream::ThenBlasCopy(uint64_t elem_count, const DeviceMemory &x, - int incx, DeviceMemory *y, int incy) { - VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy)); - - ThenBlasImpl &, int, - DeviceMemory *, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y, - incy); -} - -Stream &Stream::ThenBlasScal(uint64_t elem_count, float alpha, - DeviceMemory *x, int incx) { - VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx)); - - ThenBlasImpl *, int> impl; - return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx); -} - -Stream &Stream::ThenBlasScal(uint64_t elem_count, double alpha, - DeviceMemory *x, int incx) { - VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx)); - - ThenBlasImpl *, int> impl; - return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx); -} - -Stream &Stream::ThenBlasScal(uint64_t elem_count, float alpha, - DeviceMemory> *x, int incx) { - VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx)); - - ThenBlasImpl> *, int> impl; - return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx); -} - -Stream &Stream::ThenBlasScal(uint64_t elem_count, double alpha, - DeviceMemory> *x, int incx) { - VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx)); - - ThenBlasImpl> *, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx); -} - -Stream &Stream::ThenBlasScal(uint64_t elem_count, std::complex alpha, - DeviceMemory> *x, int incx) { - VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx)); - - ThenBlasImpl, - DeviceMemory> *, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx); -} - -Stream &Stream::ThenBlasScal(uint64_t elem_count, std::complex alpha, - DeviceMemory> *x, int incx) { - VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx)); - - ThenBlasImpl, - DeviceMemory> *, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx); -} - -Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n, - float alpha, const DeviceMemory &a, int lda, - const DeviceMemory &x, int incx, float beta, - DeviceMemory *y, int incy) { - VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a), - PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), - PARAM(incy)); - - ThenBlasImpl &, int, const DeviceMemory &, - int, float, DeviceMemory *, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda, - x, incx, beta, y, incy); -} - -Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n, - double alpha, const DeviceMemory &a, - int lda, const DeviceMemory &x, int incx, - double beta, DeviceMemory *y, int incy) { - VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a), - PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), - PARAM(incy)); - - ThenBlasImpl &, int, const DeviceMemory &, - int, double, DeviceMemory *, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda, - x, incx, beta, y, incy); -} - -Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n, - std::complex alpha, - const DeviceMemory> &a, - int lda, - const DeviceMemory> &x, - int incx, std::complex beta, - DeviceMemory> *y, int incy) { - VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a), - PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), - PARAM(incy)); - - ThenBlasImpl, - const DeviceMemory> &, int, - const DeviceMemory> &, int, - std::complex, DeviceMemory> *, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda, - x, incx, beta, y, incy); -} - -Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n, - std::complex alpha, - const DeviceMemory> &a, - int lda, - const DeviceMemory> &x, - int incx, std::complex beta, - DeviceMemory> *y, int incy) { - VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a), - PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), - PARAM(incy)); - - ThenBlasImpl, - const DeviceMemory> &, int, - const DeviceMemory> &, int, - std::complex, DeviceMemory> *, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda, - x, incx, beta, y, incy); -} - -Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64_t n, uint64 k, - float alpha, const DeviceMemory &a, int lda, - const DeviceMemory &x, int incx, float beta, - DeviceMemory *y, int incy) { - VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), - PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy)); - - ThenBlasImpl &, int, const DeviceMemory &, - int, float, DeviceMemory *, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda, - x, incx, beta, y, incy); -} - -namespace { -// Like ThenBlasImpl, except this expects the last argument of blas_func to be a -// blas::ProfileResult*. This functor doesn't put the stream into an error -// state if the op fails and the profile result is non-null. Instead, the -// error-ness is returned in the profile result itself. -template -struct ThenBlasWithProfileImpl { - Stream &operator()(Stream *stream, - bool (blas::BlasSupport::*blas_func)( - Stream *, Args..., blas::ProfileResult *), - Args... args, blas::ProfileResult *profile_result) { - ThenBlasImpl Runner; - bool record_error = profile_result == nullptr; - return Runner.Run(stream, blas_func, record_error, args..., profile_result); + if (parent_->CreateStreamDependency(this, other)) { + return absl::OkStatus(); } -}; -} // anonymous namespace - -Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, - blas::Transpose transa, blas::Diagonal diag, - uint64_t m, uint64 n, float alpha, - const DeviceMemory &a, int lda, - DeviceMemory *b, int ldb) { - VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), - PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb)); - - ThenBlasImpl &, int, - DeviceMemory *, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m, - n, alpha, a, lda, b, ldb); -} - -Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, - blas::Transpose transa, blas::Diagonal diag, - uint64_t m, uint64 n, double alpha, - const DeviceMemory &a, int lda, - DeviceMemory *b, int ldb) { - VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), - PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb)); - - ThenBlasImpl &, int, - DeviceMemory *, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m, - n, alpha, a, lda, b, ldb); -} - -Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, - blas::Transpose transa, blas::Diagonal diag, - uint64_t m, uint64 n, std::complex alpha, - const DeviceMemory> &a, - int lda, DeviceMemory> *b, - int ldb) { - VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), - PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb)); - - ThenBlasImpl, - const DeviceMemory> &, int, - DeviceMemory> *, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m, - n, alpha, a, lda, b, ldb); -} - -Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, - blas::Transpose transa, blas::Diagonal diag, - uint64_t m, uint64 n, std::complex alpha, - const DeviceMemory> &a, - int lda, DeviceMemory> *b, - int ldb) { - VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), - PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb)); - - ThenBlasImpl, - const DeviceMemory> &, int, - DeviceMemory> *, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m, - n, alpha, a, lda, b, ldb); -} - -Stream &Stream::ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo, - blas::Transpose transa, blas::Diagonal diag, - uint64_t m, uint64 n, float alpha, - const DeviceMemory &as, int lda, - DeviceMemory *bs, int ldb, - int batch_count) { - VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), - PARAM(n), PARAM(alpha), PARAM(as), PARAM(lda), PARAM(bs), - PARAM(ldb), PARAM(batch_count)); - - ThenBlasImpl &, int, - DeviceMemory *, int, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasTrsmBatched, side, uplo, transa, - diag, m, n, alpha, as, lda, bs, ldb, batch_count); -} - -Stream &Stream::ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo, - blas::Transpose transa, blas::Diagonal diag, - uint64_t m, uint64 n, double alpha, - const DeviceMemory &as, int lda, - DeviceMemory *bs, int ldb, - int batch_count) { - VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), - PARAM(n), PARAM(alpha), PARAM(as), PARAM(lda), PARAM(bs), - PARAM(ldb), PARAM(batch_count)); - - ThenBlasImpl &, int, - DeviceMemory *, int, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasTrsmBatched, side, uplo, transa, - diag, m, n, alpha, as, lda, bs, ldb, batch_count); -} - -Stream &Stream::ThenBlasTrsmBatched( - blas::Side side, blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, std::complex alpha, - const DeviceMemory *> &as, int lda, - DeviceMemory *> *bs, int ldb, int batch_count) { - VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), - PARAM(n), PARAM(alpha), PARAM(as), PARAM(lda), PARAM(bs), - PARAM(ldb), PARAM(batch_count)); - - ThenBlasImpl, - const DeviceMemory *> &, int, - DeviceMemory *> *, int, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasTrsmBatched, side, uplo, transa, - diag, m, n, alpha, as, lda, bs, ldb, batch_count); -} - -Stream &Stream::ThenBlasTrsmBatched( - blas::Side side, blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, std::complex alpha, - const DeviceMemory *> &as, int lda, - DeviceMemory *> *bs, int ldb, int batch_count) { - VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), - PARAM(n), PARAM(alpha), PARAM(as), PARAM(lda), PARAM(bs), - PARAM(ldb), PARAM(batch_count)); - - ThenBlasImpl, - const DeviceMemory *> &, int, - DeviceMemory *> *, int, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasTrsmBatched, side, uplo, transa, - diag, m, n, alpha, as, lda, bs, ldb, batch_count); + return absl::InternalError("stream cannot wait for other"); } -Stream &Stream::ThenBlasGemmBatchedWithScratch( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, float alpha, DeviceMemorySlice a, int lda, - DeviceMemorySlice b, int ldb, float beta, - DeviceMemorySlice c, int ldc, int batch_count, - const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator, - blas::CallContext context) { - VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), - PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), - PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); - - ThenBlasImpl, int, - DeviceMemorySlice, int, float, - DeviceMemorySlice, int, int, const NumericOptions &, - ScratchAllocator *, blas::CallContext> - impl; - return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, - k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - numeric_options, scratch_allocator, context); -} - -Stream &Stream::ThenBlasGemmBatchedWithScratch( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, float alpha, DeviceMemorySlice a, int lda, - DeviceMemorySlice b, int ldb, float beta, - DeviceMemorySlice c, int ldc, int batch_count, - const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator, - blas::CallContext context) { - VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), - PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), - PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); - - ThenBlasImpl, int, - DeviceMemorySlice, int, float, - DeviceMemorySlice, int, int, - const NumericOptions &, ScratchAllocator *, blas::CallContext> - impl; - return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, - k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - numeric_options, scratch_allocator, context); -} - -Stream &Stream::ThenBlasGemmBatched( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, float alpha, DeviceMemorySlice a, int lda, - DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, - int ldc, int batch_count, const NumericOptions &numeric_options, - blas::CallContext context) { - return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, - b, ldb, beta, c, ldc, batch_count, - numeric_options, - /*scratch_allocator=*/nullptr, context); -} - -Stream &Stream::ThenBlasGemmBatchedWithScratch( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, float alpha, DeviceMemorySlice a, int lda, - DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, - int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator, blas::CallContext context) { - VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), - PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), - PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); - - ThenBlasImpl, int, DeviceMemorySlice, - int, float, DeviceMemorySlice, int, int, - const NumericOptions &, ScratchAllocator *, blas::CallContext> - impl; - return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, - k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - numeric_options, scratch_allocator, context); -} - -Stream &Stream::ThenBlasGemmBatchedWithScratch( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, double alpha, DeviceMemorySlice a, int lda, - DeviceMemorySlice b, int ldb, double beta, - DeviceMemorySlice c, int ldc, int batch_count, - const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator, - blas::CallContext context) { - VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), - PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), - PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); - - ThenBlasImpl, int, - DeviceMemorySlice, int, double, - DeviceMemorySlice, int, int, const NumericOptions &, - ScratchAllocator *, blas::CallContext> - impl; - return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, - k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - numeric_options, scratch_allocator, context); -} - -Stream &Stream::ThenBlasGemmBatchedWithScratch( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, std::complex alpha, - DeviceMemorySlice> a, int lda, - DeviceMemorySlice> b, int ldb, std::complex beta, - DeviceMemorySlice> c, int ldc, int batch_count, - const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator, - blas::CallContext context) { - VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), - PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), - PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); - - ThenBlasImpl, DeviceMemorySlice>, int, - DeviceMemorySlice>, int, std::complex, - DeviceMemorySlice>, int, int, - const NumericOptions &, ScratchAllocator *, blas::CallContext> - impl; - return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, - k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - numeric_options, scratch_allocator, context); -} - -Stream &Stream::ThenBlasGemmBatchedWithScratch( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, std::complex alpha, - DeviceMemorySlice> a, int lda, - DeviceMemorySlice> b, int ldb, - std::complex beta, DeviceMemorySlice> c, - int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator, blas::CallContext context) { - VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), - PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), - PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); - - ThenBlasImpl, DeviceMemorySlice>, - int, DeviceMemorySlice>, int, - std::complex, DeviceMemorySlice>, - int, int, const NumericOptions &, ScratchAllocator *, - blas::CallContext> - impl; - return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, - k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - numeric_options, scratch_allocator, context); +absl::Status Stream::WaitFor(Event *event) { + return parent_->WaitForEvent(this, event); } Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src, uint64_t size) { VLOG_CALL(PARAM(host_dst), PARAM(gpu_src), PARAM(size)); - CheckError(parent_->Memcpy(this, host_dst, gpu_src, size)); + CheckStatus(Memcpy(host_dst, gpu_src, size)); return *this; } +absl::Status Stream::Memcpy(void *host_dst, const DeviceMemoryBase &gpu_src, + uint64_t size) { + if (parent_->Memcpy(this, host_dst, gpu_src, size)) { + return absl::OkStatus(); + } + return absl::InternalError("failed to memcpy"); +} + Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src, uint64_t size) { VLOG_CALL(PARAM(gpu_dst), PARAM(host_src), PARAM(size)); - CheckError(parent_->Memcpy(this, gpu_dst, host_src, size)); + CheckStatus(Memcpy(gpu_dst, host_src, size)); return *this; } +absl::Status Stream::Memcpy(DeviceMemoryBase *gpu_dst, const void *host_src, + uint64_t size) { + if (parent_->Memcpy(this, gpu_dst, host_src, size)) { + return absl::OkStatus(); + } + return absl::InternalError("failed to memcpy"); +} + Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src, uint64_t size) { VLOG_CALL(PARAM(gpu_dst), PARAM(gpu_src), PARAM(size)); - CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size)); - return *this; -} - -Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64_t size) { - VLOG_CALL(PARAM(location), PARAM(size)); - - CheckStatus(parent_->MemZero(this, location, size)); - return *this; -} - -Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32_t pattern, - uint64_t size) { - VLOG_CALL(PARAM(location), PARAM(pattern), PARAM(size)); - - CheckStatus(parent_->Memset32(this, location, pattern, size)); - return *this; -} - -Stream &Stream::ThenRnnForward( - const dnn::RnnDescriptor &rnn_desc, - const dnn::RnnSequenceTensorDescriptor &input_desc, - const DeviceMemory &input_data, - const DeviceMemory &seq_lengths_data, - const dnn::RnnStateTensorDescriptor &input_h_desc, - const DeviceMemory &input_h_data, - const dnn::RnnStateTensorDescriptor &input_c_desc, - const DeviceMemory &input_c_data, - const DeviceMemory ¶ms, - const dnn::RnnSequenceTensorDescriptor &output_desc, - DeviceMemory *output_data, - const dnn::RnnStateTensorDescriptor &output_h_desc, - DeviceMemory *output_h_data, - const dnn::RnnStateTensorDescriptor &output_c_desc, - DeviceMemory *output_c_data, bool is_training, - ScratchAllocator *reserve_space_allocator, - ScratchAllocator *workspace_allocator, - dnn::ProfileResult *output_profile_result) { - // TODO(zhengxq): add VLOG PARAM calls. - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - auto status = dnn->DoRnnForward( - this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc, - input_h_data, input_c_desc, input_c_data, params, output_desc, - output_data, output_h_desc, output_h_data, output_c_desc, output_c_data, - is_training, reserve_space_allocator, workspace_allocator, - output_profile_result); - if (!status && !output_profile_result) { - SetError(); - } - } else { - SetErrorAndLogNoDnnSupport(); - } + CheckStatus(Memcpy(gpu_dst, gpu_src, size)); return *this; } -Stream &Stream::ThenRnnForward( - const dnn::RnnDescriptor &rnn_desc, - const dnn::RnnSequenceTensorDescriptor &input_desc, - const DeviceMemory &input_data, - const DeviceMemory &seq_lengths_data, - const dnn::RnnStateTensorDescriptor &input_h_desc, - const DeviceMemory &input_h_data, - const dnn::RnnStateTensorDescriptor &input_c_desc, - const DeviceMemory &input_c_data, const DeviceMemory ¶ms, - const dnn::RnnSequenceTensorDescriptor &output_desc, - DeviceMemory *output_data, - const dnn::RnnStateTensorDescriptor &output_h_desc, - DeviceMemory *output_h_data, - const dnn::RnnStateTensorDescriptor &output_c_desc, - DeviceMemory *output_c_data, bool is_training, - ScratchAllocator *reserve_space_allocator, - ScratchAllocator *workspace_allocator, - dnn::ProfileResult *output_profile_result) { - // TODO(zhengxq): add VLOG PARAM calls. - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - auto status = dnn->DoRnnForward( - this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc, - input_h_data, input_c_desc, input_c_data, params, output_desc, - output_data, output_h_desc, output_h_data, output_c_desc, output_c_data, - is_training, reserve_space_allocator, workspace_allocator, - output_profile_result); - if (!status && !output_profile_result) { - SetError(); - } - } else { - SetErrorAndLogNoDnnSupport(); +absl::Status Stream::Memcpy(DeviceMemoryBase *gpu_dst, + const DeviceMemoryBase &gpu_src, uint64_t size) { + if (parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size)) { + return absl::OkStatus(); } - return *this; + return absl::InternalError("failed to memcpy"); } -Stream &Stream::ThenRnnForward( - const dnn::RnnDescriptor &rnn_desc, - const dnn::RnnSequenceTensorDescriptor &input_desc, - const DeviceMemory &input_data, - const DeviceMemory &seq_lengths_data, - const dnn::RnnStateTensorDescriptor &input_h_desc, - const DeviceMemory &input_h_data, - const dnn::RnnStateTensorDescriptor &input_c_desc, - const DeviceMemory &input_c_data, - const DeviceMemory ¶ms, - const dnn::RnnSequenceTensorDescriptor &output_desc, - DeviceMemory *output_data, - const dnn::RnnStateTensorDescriptor &output_h_desc, - DeviceMemory *output_h_data, - const dnn::RnnStateTensorDescriptor &output_c_desc, - DeviceMemory *output_c_data, bool is_training, - ScratchAllocator *reserve_space_allocator, - ScratchAllocator *workspace_allocator, - dnn::ProfileResult *output_profile_result) { - // TODO(zhengxq): add VLOG PARAM calls. - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - auto status = dnn->DoRnnForward( - this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc, - input_h_data, input_c_desc, input_c_data, params, output_desc, - output_data, output_h_desc, output_h_data, output_c_desc, output_c_data, - is_training, reserve_space_allocator, workspace_allocator, - output_profile_result); - if (!status && !output_profile_result) { - SetError(); - } - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; -} +Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64_t size) { + VLOG_CALL(PARAM(location), PARAM(size)); -Stream &Stream::ThenRnnBackward( - const dnn::RnnDescriptor &rnn_desc, - const dnn::RnnSequenceTensorDescriptor &input_desc, - const DeviceMemory &input_data, - const DeviceMemory &seq_lengths_data, - const dnn::RnnStateTensorDescriptor &input_h_desc, - const DeviceMemory &input_h_data, - const dnn::RnnStateTensorDescriptor &input_c_desc, - const DeviceMemory &input_c_data, - const DeviceMemory ¶ms, - const dnn::RnnSequenceTensorDescriptor &output_desc, - const DeviceMemory &output_data, - const dnn::RnnStateTensorDescriptor &output_h_desc, - const DeviceMemory &output_h_data, - const dnn::RnnStateTensorDescriptor &output_c_desc, - const DeviceMemory &output_c_data, - const DeviceMemory &output_backprop_data, - const DeviceMemory &output_h_backprop_data, - const DeviceMemory &output_c_backprop_data, - DeviceMemory *input_backprop_data, - DeviceMemory *input_h_backprop_data, - DeviceMemory *input_c_backprop_data, - DeviceMemory *params_backprop_data, - DeviceMemory *reserve_space_data, - ScratchAllocator *workspace_allocator, - dnn::ProfileResult *output_profile_result) { - // TODO(zhengxq): add VLOG PARAM calls. - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - auto status = dnn->DoRnnBackward( - this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc, - input_h_data, input_c_desc, input_c_data, params, output_desc, - output_data, output_h_desc, output_h_data, output_c_desc, output_c_data, - output_backprop_data, output_h_backprop_data, output_c_backprop_data, - input_backprop_data, input_h_backprop_data, input_c_backprop_data, - params_backprop_data, reserve_space_data, workspace_allocator, - output_profile_result); - if (!status && !output_profile_result) { - SetError(); - } - } else { - SetError(); - LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support"; - } + CheckStatus(MemZero(location, size)); return *this; } -Stream &Stream::ThenRnnBackward( - const dnn::RnnDescriptor &rnn_desc, - const dnn::RnnSequenceTensorDescriptor &input_desc, - const DeviceMemory &input_data, - const DeviceMemory &seq_lengths_data, - const dnn::RnnStateTensorDescriptor &input_h_desc, - const DeviceMemory &input_h_data, - const dnn::RnnStateTensorDescriptor &input_c_desc, - const DeviceMemory &input_c_data, const DeviceMemory ¶ms, - const dnn::RnnSequenceTensorDescriptor &output_desc, - const DeviceMemory &output_data, - const dnn::RnnStateTensorDescriptor &output_h_desc, - const DeviceMemory &output_h_data, - const dnn::RnnStateTensorDescriptor &output_c_desc, - const DeviceMemory &output_c_data, - const DeviceMemory &output_backprop_data, - const DeviceMemory &output_h_backprop_data, - const DeviceMemory &output_c_backprop_data, - DeviceMemory *input_backprop_data, - DeviceMemory *input_h_backprop_data, - DeviceMemory *input_c_backprop_data, - DeviceMemory *params_backprop_data, - DeviceMemory *reserve_space_data, - ScratchAllocator *workspace_allocator, - dnn::ProfileResult *output_profile_result) { - // TODO(zhengxq): add VLOG PARAM calls. - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - auto status = dnn->DoRnnBackward( - this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc, - input_h_data, input_c_desc, input_c_data, params, output_desc, - output_data, output_h_desc, output_h_data, output_c_desc, output_c_data, - output_backprop_data, output_h_backprop_data, output_c_backprop_data, - input_backprop_data, input_h_backprop_data, input_c_backprop_data, - params_backprop_data, reserve_space_data, workspace_allocator, - output_profile_result); - if (!status && !output_profile_result) { - SetError(); - } - } else { - SetError(); - LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support"; - } - return *this; +absl::Status Stream::MemZero(DeviceMemoryBase *location, uint64_t size) { + return parent_->MemZero(this, location, size); } -Stream &Stream::ThenRnnBackward( - const dnn::RnnDescriptor &rnn_desc, - const dnn::RnnSequenceTensorDescriptor &input_desc, - const DeviceMemory &input_data, - const DeviceMemory &seq_lengths_data, - const dnn::RnnStateTensorDescriptor &input_h_desc, - const DeviceMemory &input_h_data, - const dnn::RnnStateTensorDescriptor &input_c_desc, - const DeviceMemory &input_c_data, - const DeviceMemory ¶ms, - const dnn::RnnSequenceTensorDescriptor &output_desc, - const DeviceMemory &output_data, - const dnn::RnnStateTensorDescriptor &output_h_desc, - const DeviceMemory &output_h_data, - const dnn::RnnStateTensorDescriptor &output_c_desc, - const DeviceMemory &output_c_data, - const DeviceMemory &output_backprop_data, - const DeviceMemory &output_h_backprop_data, - const DeviceMemory &output_c_backprop_data, - DeviceMemory *input_backprop_data, - DeviceMemory *input_h_backprop_data, - DeviceMemory *input_c_backprop_data, - DeviceMemory *params_backprop_data, - DeviceMemory *reserve_space_data, - ScratchAllocator *workspace_allocator, - dnn::ProfileResult *output_profile_result) { - // TODO(zhengxq): add VLOG PARAM calls. - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - auto status = dnn->DoRnnBackward( - this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc, - input_h_data, input_c_desc, input_c_data, params, output_desc, - output_data, output_h_desc, output_h_data, output_c_desc, output_c_data, - output_backprop_data, output_h_backprop_data, output_c_backprop_data, - input_backprop_data, input_h_backprop_data, input_c_backprop_data, - params_backprop_data, reserve_space_data, workspace_allocator, - output_profile_result); - if (!status && !output_profile_result) { - SetError(); - } - } else { - SetError(); - LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support"; - } - return *this; -} - -Stream &Stream::ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc, - const DeviceMemory &probs_data, - absl::Span labels_data, - absl::Span labels_lengths_data, - absl::Span input_lengths_data, - const NumericOptions &numeric_options, - DeviceMemory *costs_data, - const dnn::RnnStateTensorDescriptor &grads_desc, - DeviceMemory *grads_data, - ScratchAllocator *workspace_allocator) { - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - DeviceMemory scratch_memory; - int ctc_loss_algo_id; - auto status = - dnn->PrepareForCtcLoss( - this, probs_desc, probs_data, grads_desc, labels_data, - labels_lengths_data, input_lengths_data, numeric_options, - workspace_allocator, &scratch_memory, &ctc_loss_algo_id) - .ok(); - if (status) { - status = dnn->DoCtcLoss(this, probs_desc, probs_data, labels_data, - labels_lengths_data, input_lengths_data, - costs_data, grads_desc, grads_data, - &scratch_memory, ctc_loss_algo_id); - } - if (!status) { - SetError(); - } - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; -} - -Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc, - dnn::DataType input_type, - const DeviceMemoryBase &input_data, - const dnn::BatchDescriptor &output_desc, - dnn::DataType output_type, float scale, - DeviceMemoryBase *output_data) { - VLOG_CALL(PARAM(input_desc), PARAM(input_type), PARAM(input_data), - PARAM(output_desc), PARAM(output_type), PARAM(scale), - PARAM(output_data)); - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoTransformTensor(this, input_desc, input_type, input_data, - output_desc, output_type, scale, - output_data)); - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; +absl::Status Stream::Memset32(DeviceMemoryBase *location, uint32_t pattern, + uint64_t size) { + return parent_->Memset32(this, location, pattern, size); } Stream &Stream::ThenDoHostCallback(absl::AnyInvocable callback) { @@ -1541,6 +438,13 @@ Stream &Stream::ThenDoHostCallback(absl::AnyInvocable callback) { }); } +absl::Status Stream::DoHostCallback(absl::AnyInvocable callback) { + return DoHostCallbackWithStatus([cb = std::move(callback)]() mutable { + std::move(cb)(); + return absl::OkStatus(); + }); +} + Stream &Stream::ThenDoHostCallbackWithStatus( absl::AnyInvocable callback) { VLOG_CALL(PARAM(callback)); @@ -1549,10 +453,18 @@ Stream &Stream::ThenDoHostCallbackWithStatus( LOG(INFO) << DebugStreamPointers() << " was in error state before adding host callback"; } - CheckError(parent_->HostCallback(this, std::move(callback))); + CheckStatus(DoHostCallbackWithStatus(std::move(callback))); return *this; } +absl::Status Stream::DoHostCallbackWithStatus( + absl::AnyInvocable callback) { + if (parent_->HostCallback(this, std::move(callback))) { + return absl::OkStatus(); + } + return absl::InternalError("failed to host callback"); +} + void Stream::CheckError(bool operation_retcode) { if (operation_retcode) { return; @@ -1561,114 +473,6 @@ void Stream::CheckError(bool operation_retcode) { status_ = absl::InternalError("Unknown error"); } -Stream &Stream::ThenFft(fft::Plan *plan, - const DeviceMemory> &input, - DeviceMemory> *output) { - VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output)); - - if (fft::FftSupport *fft = parent_->AsFft()) { - CheckError(fft->DoFft(this, plan, input, output)); - } else { - SetError(); - LOG(INFO) << DebugStreamPointers() - << " attempting to perform FFT operation using StreamExecutor" - " without FFT support"; - } - return *this; -} - -Stream &Stream::ThenFft(fft::Plan *plan, - const DeviceMemory> &input, - DeviceMemory> *output) { - VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output)); - - if (fft::FftSupport *fft = parent_->AsFft()) { - CheckError(fft->DoFft(this, plan, input, output)); - } else { - SetError(); - LOG(INFO) << DebugStreamPointers() - << " attempting to perform FFT operation using StreamExecutor" - " without FFT support"; - } - return *this; -} - -Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory &input, - DeviceMemory> *output) { - VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output)); - - if (fft::FftSupport *fft = parent_->AsFft()) { - CheckError(fft->DoFft(this, plan, input, output)); - } else { - SetError(); - LOG(INFO) << DebugStreamPointers() - << " attempting to perform FFT operation using StreamExecutor" - " without FFT support"; - } - return *this; -} - -Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory &input, - DeviceMemory> *output) { - VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output)); - - if (fft::FftSupport *fft = parent_->AsFft()) { - CheckError(fft->DoFft(this, plan, input, output)); - } else { - SetError(); - LOG(INFO) << DebugStreamPointers() - << " attempting to perform FFT operation using StreamExecutor" - " without FFT support"; - } - return *this; -} - -Stream &Stream::ThenFft(fft::Plan *plan, - const DeviceMemory> &input, - DeviceMemory *output) { - VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output)); - - if (fft::FftSupport *fft = parent_->AsFft()) { - CheckError(fft->DoFft(this, plan, input, output)); - } else { - SetError(); - LOG(INFO) << DebugStreamPointers() - << " attempting to perform FFT operation using StreamExecutor" - " without FFT support"; - } - return *this; -} - -Stream &Stream::ThenFft(fft::Plan *plan, - const DeviceMemory> &input, - DeviceMemory *output) { - VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output)); - - if (fft::FftSupport *fft = parent_->AsFft()) { - CheckError(fft->DoFft(this, plan, input, output)); - } else { - SetError(); - LOG(INFO) << DebugStreamPointers() - << " attempting to perform FFT operation using StreamExecutor" - " without FFT support"; - } - return *this; -} - -// It looks confusing, but all this is doing is inserting a callback at the -// present point in the stream to then enqueue a task on the host executor. -Stream &Stream::ThenEnqueueOnBackgroundThread( - std::function task) { - VLOG_CALL(PARAM(task)); - - StreamExecutor *stream_executor = this->parent_; - std::function bound_task = std::bind(task, stream_executor); - - return ThenDoHostCallback([stream_executor, bound_task]() { - stream_executor->EnqueueOnBackgroundThread(bound_task); - }); -} - absl::Status Stream::BlockHostUntilDone() { VLOG_CALL(); @@ -1702,19 +506,4 @@ void Stream::CheckStatus(absl::Status status) { status_ = status; } -absl::StatusOr> -Stream::AllocateArrayBase(uint64_t element_count, uint64_t element_size) { - uint64_t byte_size = element_count * element_size; - DeviceMemoryBase device_memory = parent()->AllocateArray(byte_size); - if (device_memory == nullptr) { - return absl::ResourceExhaustedError(absl::StrCat( - "could not allocate temporary memory of ", byte_size, " bytes")); - } - - VLOG(1) << absl::StreamFormat( - "stream %p allocated temporary device memory at %p (size %u) in ", this, - device_memory.opaque(), byte_size); - return std::make_unique(this, device_memory); -} - } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/stream.h b/third_party/xla/xla/stream_executor/stream.h index ea8daad8b10b77..ab80740a0948ba 100644 --- a/third_party/xla/xla/stream_executor/stream.h +++ b/third_party/xla/xla/stream_executor/stream.h @@ -21,36 +21,27 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_STREAM_H_ #define XLA_STREAM_EXECUTOR_STREAM_H_ -#include #include -#include #include -#include #include -#include #include #include #include +#include "absl/base/attributes.h" #include "absl/base/thread_annotations.h" #include "absl/functional/any_invocable.h" #include "absl/status/status.h" -#include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" -#include "xla/stream_executor/blas.h" #include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/dnn.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/fft.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" -#include "xla/stream_executor/numeric_options.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform/port.h" #include "xla/stream_executor/stream_executor_pimpl.h" -#include "xla/stream_executor/temporary_device_memory.h" #include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" #include "tsl/platform/thread_annotations.h" namespace stream_executor { @@ -63,35 +54,7 @@ class DeviceMemoryBase; template class DeviceMemory; -namespace dnn { -class BatchDescriptor; -class FilterDescriptor; -class ConvolutionDescriptor; -class ProfileResult; -class AlgorithmDesc; -} // namespace dnn - class StreamExecutor; -class ScratchAllocator; - -namespace detail { - -// Helper to return if `T` is the same type as `First` or any or `Rest`. -template -constexpr bool is_any_of() { - return false; -} - -template -constexpr bool is_any_of() { - return std::is_same_v || is_any_of(); -} - -} // namespace detail - -// Convert a type to the corresponding QuantizedActivationMode. -template -struct Quantization; // Represents a stream of dependent computations on a GPU device. // @@ -144,14 +107,16 @@ class Stream { // Initialize the stream. This must be performed before entraining any other // operations. + ABSL_DEPRECATED("Use absl::Status Stream::Initialize instead.") Stream &Init() TF_LOCKS_EXCLUDED(mu_); + absl::Status Initialize(); // Get or create a sub-stream from this stream. If there is any sub-stream in // the pool that can be reused then just return this sub-stream. Otherwise // create a new sub-stream. // // TODO(b/112196569): The semantics of failed sub-streams is error-prone. - Stream *GetOrCreateSubStream() TF_LOCKS_EXCLUDED(mu_); + absl::StatusOr GetOrCreateSubStream() TF_LOCKS_EXCLUDED(mu_); // Return the sub-stream back to the host stream so that it can be reused // later. Sub-streams that are !ok() will not be reused. @@ -159,12 +124,6 @@ class Stream { // TODO(b/112196569): The semantics of failed sub-streams is error-prone. void ReturnSubStream(Stream *sub_stream) TF_LOCKS_EXCLUDED(mu_); - // Allocate temporary memories. The stream will deallocate them when blocked - // or destroyed. - template - absl::StatusOr>> - AllocateTemporaryArray(uint64_t element_count); - // Entrains onto the stream of operations: a kernel launch with the given // (variadic) parameters for the invocation. These arguments can be things // like DeviceMemory or primitive types such as int. What arguments you may @@ -210,980 +169,97 @@ class Stream { // Checks that a stream does not wait for itself, and it is up to the // user to guarantee that a stream does not come to wait on itself in a // cyclic manner; in that case, behavior is undefined. - // - // N.B. Base recursion case for the variadic ThenWaitFor. + ABSL_DEPRECATED("Use absl::Status returning method instead.") Stream &ThenWaitFor(Stream *other); + absl::Status WaitFor(Stream *other); // Waits for an event object to be set. // Note that ThenRecordEvent must have been called on the event before // you call this function; otherwise the event will be considered complete // and this wait will do nothing. + ABSL_DEPRECATED("Use absl::Status returning method instead.") Stream &ThenWaitFor(Event *event); + absl::Status WaitFor(Event *event); // Inserts the specified event into the end of this stream. Once the stream // has processed all events prior to the insertion point, the event will be // marked as completed. // The stream does not take ownership of event - meaning that event's lifetime // must extend past the point at which it is marked complete! + ABSL_DEPRECATED("Use absl::Status returning method instead.") Stream &ThenRecordEvent(Event *event); - - //////////////// - // DNN support - // - // See DnnSupport::* for comments on the following methods. - - Stream &ThenBatchNormalizationForward( - const DeviceMemory &x, const DeviceMemory &scale, - const DeviceMemory &offset, - const DeviceMemory &estimated_mean, - const DeviceMemory &estimated_variance, - const DeviceMemory &side_input, const dnn::BatchDescriptor &x_desc, - const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, - const double exponential_average_factor, - dnn::ActivationMode activation_mode, DeviceMemory *y, - DeviceMemory *batch_mean, DeviceMemory *batch_var, - DeviceMemory *saved_mean, DeviceMemory *saved_inv_var, - bool is_training, ScratchAllocator *reserve_space_allocator, - ScratchAllocator *workspace_allocator); - - Stream &ThenBatchNormalizationBackward( - const DeviceMemory &y_backprop, const DeviceMemory &x, - const DeviceMemory &scale, const DeviceMemory &offset, - const DeviceMemory &mean, const DeviceMemory &inv_var, - const DeviceMemory &y, const dnn::BatchDescriptor &x_desc, - const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, - dnn::ActivationMode activation_mode, DeviceMemory *x_backprop, - DeviceMemory *scale_backprop, DeviceMemory *offset_backprop, - DeviceMemory *side_input_backprop, - DeviceMemory *reserve_space_data, - ScratchAllocator *workspace_allocator); - - Stream &ThenBatchNormalizationForward( - const DeviceMemory &x, const DeviceMemory &scale, - const DeviceMemory &offset, - const DeviceMemory &estimated_mean, - const DeviceMemory &estimated_variance, - const DeviceMemory &side_input, - const dnn::BatchDescriptor &x_desc, - const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, - const double exponential_average_factor, - dnn::ActivationMode activation_mode, DeviceMemory *y, - DeviceMemory *batch_mean, DeviceMemory *batch_var, - DeviceMemory *saved_mean, DeviceMemory *saved_inv_var, - bool is_training, ScratchAllocator *reserve_space_allocator, - ScratchAllocator *workspace_allocator); - - Stream &ThenBatchNormalizationBackward( - const DeviceMemory &y_backprop, - const DeviceMemory &x, const DeviceMemory &scale, - const DeviceMemory &offset, const DeviceMemory &mean, - const DeviceMemory &inv_var, const DeviceMemory &y, - const dnn::BatchDescriptor &x_desc, - const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, - dnn::ActivationMode activation_mode, - DeviceMemory *x_backprop, - DeviceMemory *scale_backprop, DeviceMemory *offset_backprop, - DeviceMemory *side_input_backprop, - DeviceMemory *reserve_space_data, - ScratchAllocator *workspace_allocator); - - Stream &ThenBatchNormalizationForward( - const DeviceMemory &x, const DeviceMemory &scale, - const DeviceMemory &offset, - const DeviceMemory &estimated_mean, - const DeviceMemory &estimated_variance, - const DeviceMemory &side_input, - const dnn::BatchDescriptor &x_desc, - const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, - const double exponential_average_factor, - dnn::ActivationMode activation_mode, DeviceMemory *y, - DeviceMemory *batch_mean, DeviceMemory *batch_var, - DeviceMemory *saved_mean, DeviceMemory *saved_inv_var, - bool is_training, ScratchAllocator *reserve_space_allocator, - ScratchAllocator *workspace_allocator); - - Stream &ThenBatchNormalizationBackward( - const DeviceMemory &y_backprop, - const DeviceMemory &x, const DeviceMemory &scale, - const DeviceMemory &offset, const DeviceMemory &mean, - const DeviceMemory &inv_var, - const DeviceMemory &y, - const dnn::BatchDescriptor &x_desc, - const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, - dnn::ActivationMode activation_mode, - DeviceMemory *x_backprop, - DeviceMemory *scale_backprop, DeviceMemory *offset_backprop, - DeviceMemory *side_input_backprop, - DeviceMemory *reserve_space_data, - ScratchAllocator *workspace_allocator); - - template - absl::Status ConvolveWithAlgorithm( - dnn::ConvolutionKind kind, const dnn::BatchDescriptor &input_descriptor, - DeviceMemory input_data, - const dnn::FilterDescriptor &filter_descriptor, - DeviceMemory filter_data, - const dnn::BatchDescriptor &output_descriptor, - DeviceMemory output_data, - const dnn::ConvolutionDescriptor &convolution_descriptor, - ScratchAllocator *scratch_allocator, - const dnn::AlgorithmConfig &algorithm_config, - dnn::ProfileResult *output_profile_result) { - DeviceMemory scratch_memory; - dnn::AlgorithmDesc algorithm_desc; - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - TF_RETURN_IF_ERROR(dnn->PrepareForConvolution( - kind, this, input_descriptor, input_data, filter_descriptor, - filter_data, output_descriptor, output_data, convolution_descriptor, - algorithm_config, scratch_allocator, &algorithm_desc, - &scratch_memory)); - return dnn->DoConvolve(kind, dnn::ToDataType::value, - dnn::ToDataType::value, this, - input_descriptor, input_data, filter_descriptor, - filter_data, output_descriptor, output_data, - convolution_descriptor, algorithm_desc, - scratch_memory, output_profile_result); - } - return absl::UnimplementedError("DNN library is not found."); - } - - template - absl::Status FusedConvolveWithAlgorithm( - const dnn::BatchDescriptor &conv_input_descriptor, - const DeviceMemory &conv_input_data, ScaleT conv_input_scale, - const dnn::FilterDescriptor &filter_descriptor, - const DeviceMemory &filter_data, - const dnn::ConvolutionDescriptor &convolution_descriptor, - const DeviceMemory &side_input_data, ScaleT side_input_scale, - const dnn::BatchDescriptor &bias_descriptor, - const DeviceMemory &biases, dnn::ActivationMode activation_mode, - const dnn::BatchDescriptor &output_descriptor, - DeviceMemory *output, ScratchAllocator *scratch_allocator, - const dnn::AlgorithmConfig &algorithm_config, - dnn::ProfileResult *output_profile_result) { - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - return dnn->DoFusedConvolve( - this, dnn::ToDataType::value, - dnn::ToDataType::value, dnn::ToDataType::value, - dnn::ToDataType::value, conv_input_descriptor, - conv_input_data, conv_input_scale, filter_descriptor, filter_data, - convolution_descriptor, side_input_data, side_input_scale, - bias_descriptor, biases, activation_mode, output_descriptor, *output, - scratch_allocator, algorithm_config, output_profile_result); - } - return absl::UnimplementedError("DNN library is not found."); - } - - absl::Status CudnnReorderConvolutionFilterAndBias( - const dnn::FilterDescriptor &filter_descriptor, - const DeviceMemory &filter_input, - DeviceMemory *filter_output, - std::optional> bias_input, - std::optional> bias_output) { - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - return dnn->CudnnReorderConvolutionFilterAndBias( - this, filter_descriptor, filter_input, filter_output, - std::move(bias_input), std::move(bias_output)); - } - return absl::UnimplementedError("DNN library is not found."); - } - - absl::StatusOr> ConvolveRunnerFromDesc( - const dnn::AlgorithmDesc &algorithm_desc, dnn::ConvolutionKind kind, - dnn::DataType element_type, dnn::DataType output_type, - const dnn::BatchDescriptor &input_descriptor, - const dnn::FilterDescriptor &filter_descriptor, - const dnn::BatchDescriptor &output_descriptor, - const dnn::ConvolutionDescriptor &convolution_descriptor) { - dnn::DnnSupport *dnn_support = parent_->AsDnn(); - if (!dnn_support) { - return absl::UnimplementedError("DNN library is not found."); - } - return dnn_support->ConvolveRunnerFromDesc( - this, algorithm_desc, kind, element_type, output_type, input_descriptor, - filter_descriptor, output_descriptor, convolution_descriptor); - } - - absl::StatusOr> - GraphConvolveRunnerFromDesc( - const dnn::AlgorithmDesc &algorithm_desc, dnn::ConvolutionKind kind, - dnn::DataType element_type, dnn::DataType output_type, - const dnn::BatchDescriptor &input_descriptor, - const dnn::FilterDescriptor &filter_descriptor, - const dnn::BatchDescriptor &output_descriptor, - const dnn::ConvolutionDescriptor &convolution_descriptor, - std::string serialized_graph) { - dnn::DnnSupport *dnn_support = parent_->AsDnn(); - if (!dnn_support) { - return absl::UnimplementedError("DNN library is not found."); - } - return dnn_support->GraphConvolveRunnerFromDesc( - this, algorithm_desc, kind, element_type, output_type, input_descriptor, - filter_descriptor, output_descriptor, convolution_descriptor, - serialized_graph); - } - - absl::StatusOr> - FusedConvolveRunnerFromDesc( - const dnn::AlgorithmDesc &algorithm_desc, dnn::ConvolutionKind kind, - dnn::DataType element_type, dnn::DataType bias_type, - dnn::DataType output_type, double conv_input_scale, - double side_input_scale, double leakyrelu_alpha, - const dnn::BatchDescriptor &input_descriptor, - const dnn::FilterDescriptor &filter_descriptor, - const dnn::BatchDescriptor &bias_descriptor, - const dnn::BatchDescriptor &output_descriptor, - const dnn::ConvolutionDescriptor &convolution_descriptor, - dnn::ActivationMode activation_mode) { - dnn::DnnSupport *dnn_support = parent_->AsDnn(); - if (!dnn_support) { - return absl::UnimplementedError("DNN library is not found."); - } - return dnn_support->FusedConvolveRunnerFromDesc( - this, algorithm_desc, kind, element_type, bias_type, output_type, - conv_input_scale, side_input_scale, leakyrelu_alpha, input_descriptor, - filter_descriptor, bias_descriptor, output_descriptor, - convolution_descriptor, activation_mode); - } - - absl::StatusOr> NormRunnerFromDesc( - const dnn::AlgorithmDesc &algorithm_desc, double epsilon, - const dnn::TensorDescriptor &input_descriptor, - const dnn::TensorDescriptor &scale_descriptor, - const dnn::TensorDescriptor &bias_descriptor, - const dnn::TensorDescriptor &output_descriptor, - std::optional expectation_descriptor, - std::optional norm_factor_descriptor) { - dnn::DnnSupport *dnn_support = parent_->AsDnn(); - if (!dnn_support) { - return absl::UnimplementedError("DNN library is not found."); - } - return dnn_support->NormRunnerFromDesc( - this, algorithm_desc, epsilon, input_descriptor, scale_descriptor, - bias_descriptor, output_descriptor, expectation_descriptor, - norm_factor_descriptor); - } - - absl::StatusOr> - FusedMHARunnerFromDesc( - const dnn::AlgorithmDesc &algorithm_desc, dnn::FusedMHAKind kind, - const dnn::MatmulTensorDescriptor &bmm1_lhs_descriptor, - const dnn::MatmulTensorDescriptor &bmm1_rhs_descriptor, - const dnn::MatmulTensorDescriptor &bmm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor &intermediate_bmm2_lhs_descriptor, - const dnn::TensorDescriptor &output_descriptor, - std::optional activation_descriptor, - std::optional mask_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - bool is_flash_attention, bool is_causal_mask) { - dnn::DnnSupport *dnn_support = parent_->AsDnn(); - if (!dnn_support) { - return absl::UnimplementedError("DNN library is not found."); - } - return dnn_support->FusedMHARunnerFromDesc( - this, algorithm_desc, kind, bmm1_lhs_descriptor, bmm1_rhs_descriptor, - bmm2_rhs_descriptor, intermediate_bmm2_lhs_descriptor, - output_descriptor, activation_descriptor, mask_descriptor, - bias_descriptor, scale, dropout_rate, seed, is_flash_attention, - is_causal_mask); - } - - absl::StatusOr> - FusedMHABackwardRunnerFromDesc( - const dnn::AlgorithmDesc &algorithm_desc, dnn::FusedMHAKind kind, - const dnn::MatmulTensorDescriptor &bmm1_grad_gemm1_rhs_descriptor, - const dnn::MatmulTensorDescriptor &bmm1_grad_gemm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor &bmm2_grad_gemm1_lhs_descriptor, - const dnn::MatmulTensorDescriptor &bmm2_grad_gemm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor &d_output_descriptor, - const dnn::TensorDescriptor &d_bmm1_lhs_descriptor, - const dnn::TensorDescriptor &d_bmm1_rhs_descriptor, - const dnn::TensorDescriptor &d_bmm2_rhs_descriptor, - std::optional d_s_descriptor, - std::optional mask_descriptor, - std::optional d_bias_descriptor, - std::optional fwd_output_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - bool is_flash_attention, bool is_causal_mask) { - dnn::DnnSupport *dnn_support = parent_->AsDnn(); - if (!dnn_support) { - return absl::UnimplementedError("DNN library is not found."); - } - return dnn_support->FusedMHABackwardRunnerFromDesc( - this, algorithm_desc, kind, bmm1_grad_gemm1_rhs_descriptor, - bmm1_grad_gemm2_rhs_descriptor, bmm2_grad_gemm1_lhs_descriptor, - bmm2_grad_gemm2_rhs_descriptor, d_output_descriptor, - d_bmm1_lhs_descriptor, d_bmm1_rhs_descriptor, d_bmm2_rhs_descriptor, - d_s_descriptor, mask_descriptor, d_bias_descriptor, - fwd_output_descriptor, bias_descriptor, scale, dropout_rate, seed, - is_flash_attention, is_causal_mask); - } - - template - absl::Status ThenPoolForward( - const dnn::PoolingDescriptor &pooling_dimensions, - const NumericOptions &numeric_options, - const dnn::BatchDescriptor &input_dimensions, - const DeviceMemory &input_data, - const dnn::BatchDescriptor &output_dimensions, - DeviceMemory *output_data, - ScratchAllocator *workspace_allocator = nullptr) { - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - return dnn->DoPoolForward(dnn::ToDataType::value, this, - pooling_dimensions, numeric_options, - input_dimensions, input_data, output_dimensions, - *output_data, workspace_allocator); - } - return absl::UnimplementedError("DNN library is not found."); - } - - template - absl::Status ThenPoolBackward( - const dnn::PoolingDescriptor &pooling_dimensions, - const NumericOptions &numeric_options, - const dnn::BatchDescriptor &input_dimensions, - const DeviceMemory &input_data, - const dnn::BatchDescriptor &output_dimensions, - const DeviceMemory &output_data, - const DeviceMemory &input_diff_data, - DeviceMemory *output_diff_data, - ScratchAllocator *workspace_allocator = nullptr) { - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - return dnn->DoPoolBackward( - dnn::ToDataType::value, this, pooling_dimensions, - numeric_options, input_dimensions, input_data, output_dimensions, - output_data, input_diff_data, *output_diff_data, workspace_allocator); - } - return absl::UnimplementedError("DNN library is not found."); - } - - Stream &ThenNormalizeWithDimensions( - const dnn::NormalizeDescriptor &normalize_descriptor, - const dnn::BatchDescriptor &dimensions, - const DeviceMemory &input_data, DeviceMemory *output_data); - - Stream &ThenNormalizeBackwardWithDimensions( - const dnn::NormalizeDescriptor &normalize_descriptor, - const dnn::BatchDescriptor &dimensions, - const DeviceMemory &raw_data, - const DeviceMemory &normalized_data, - const DeviceMemory &normalized_variable_gradient, - DeviceMemory *raw_variable_gradient, - ScratchAllocator *workspace_allocator = nullptr); - - Stream &ThenDepthConcatenate( - absl::Span input_dimensions, - absl::Span *const> input_data, - DeviceMemory *output_data); - - ///////////////// - // BLAS support - - // See BlasSupport::DoBlasAxpy. Note that, even for the case where alpha is - // present in DeviceMemory, it must be an execution-time constant (i.e. a - // value - // that the stream does not change or populate during the course of - // execution). The value is effectively captured at stream-enqueue time. - Stream &ThenBlasAxpy(uint64_t elem_count, float alpha, - const DeviceMemory &x, int incx, - DeviceMemory *y, int incy); - - // See BlasSupport::DoBlasCopy. - Stream &ThenBlasCopy(uint64_t elem_count, const DeviceMemory &x, - int incx, DeviceMemory *y, int incy); - - // See BlasSupport::DoBlasScal. - Stream &ThenBlasScal(uint64_t elem_count, float alpha, DeviceMemory *x, - int incx); - Stream &ThenBlasScal(uint64_t elem_count, double alpha, - DeviceMemory *x, int incx); - Stream &ThenBlasScal(uint64_t elem_count, float alpha, - DeviceMemory> *x, int incx); - Stream &ThenBlasScal(uint64_t elem_count, double alpha, - DeviceMemory> *x, int incx); - Stream &ThenBlasScal(uint64_t elem_count, std::complex alpha, - DeviceMemory> *x, int incx); - Stream &ThenBlasScal(uint64_t elem_count, std::complex alpha, - DeviceMemory> *x, int incx); - - // See BlasSupport::DoBlasGemv. - Stream &ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n, float alpha, - const DeviceMemory &a, int lda, - const DeviceMemory &x, int incx, float beta, - DeviceMemory *y, int incy); - Stream &ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n, - double alpha, const DeviceMemory &a, int lda, - const DeviceMemory &x, int incx, double beta, - DeviceMemory *y, int incy); - Stream &ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n, - std::complex alpha, - const DeviceMemory> &a, int lda, - const DeviceMemory> &x, int incx, - std::complex beta, - DeviceMemory> *y, int incy); - Stream &ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n, - std::complex alpha, - const DeviceMemory> &a, int lda, - const DeviceMemory> &x, int incx, - std::complex beta, - DeviceMemory> *y, int incy); - - // See BlasSupport::DoBlasSbmv. - Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64_t n, uint64 k, float alpha, - const DeviceMemory &a, int lda, - const DeviceMemory &x, int incx, float beta, - DeviceMemory *y, int incy); - - template - absl::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, - uint64_t m, uint64 n, uint64 k, - const DeviceMemory &a, int lda, - const DeviceMemory &b, int ldb, - DeviceMemory *c, int ldc, - const NumericOptions &numeric_options, - blas::CallContext context) { - InputType alpha{1.0}; - InputType beta{0.0}; - return ThenBlasGemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, - ldc, numeric_options, context); - } - - template - absl::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, - uint64_t m, uint64 n, uint64 k, ConstantType alpha, - const DeviceMemory &a, int lda, - const DeviceMemory &b, int ldb, - ConstantType beta, DeviceMemory *c, - int ldc, const NumericOptions &numeric_options, - blas::CallContext context) { - static_assert( - detail::is_any_of, - std::complex>(), - "Input can be int8_t, half, bf16, float, double, std::complex " - "or " - "std::complex"); - static_assert(!std::is_same_v || - detail::is_any_of(), - "If input is Eigen::half, constant has to be either " - "Eigen::half or float"); - static_assert(detail::is_any_of(), - "If input is not int8_t, Eigen::half, constant and input " - "types have to match"); - blas::BlasSupport *blas = parent()->AsBlas(); - if (!blas) { - return absl::InternalError( - "Attempting to perform BLAS operation using " - "StreamExecutor without BLAS support"); - } - - void *alpha_ptr = α - void *beta_ptr = β - float alpha_storage, beta_storage; - UpcastHalfToFloat(&alpha_ptr, &beta_ptr, &alpha_storage, - &beta_storage); - - return blas->DoBlasGemm( - this, transa, transb, m, n, k, blas::ToDataType::value, - alpha_ptr, a, lda, b, ldb, beta_ptr, c, ldc, numeric_options, context); - } - - // TODO(reedwm): Update all callers to pass correct NumericOptions. - template - absl::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, - uint64_t m, uint64 n, uint64 k, ConstantType alpha, - const DeviceMemory &a, int lda, - const DeviceMemory &b, int ldb, - ConstantType beta, DeviceMemory *c, - int ldc, blas::CallContext context) { - return ThenBlasGemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, - ldc, NumericOptions{}, context); - } - - template - absl::Status ThenBlasGemmWithAlgorithm( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, const DeviceMemory &a, int lda, - const DeviceMemory &b, int ldb, DeviceMemory *c, - int ldc, blas::ComputationType computation_type, - blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result, - blas::CallContext context) { - OutputType alpha{1}; - OutputType beta{0}; - return ThenBlasGemmWithAlgorithm(transa, transb, m, n, k, alpha, a, lda, b, - ldb, beta, c, ldc, computation_type, - algorithm, NumericOptions{}, - output_profile_result, context); - } - - template - absl::Status ThenBlasGemmWithAlgorithm( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, ConstantType alpha, const DeviceMemory &a, int lda, - const DeviceMemory &b, int ldb, ConstantType beta, - DeviceMemory *c, int ldc, - blas::ComputationType computation_type, blas::AlgorithmType algorithm, - const NumericOptions &numeric_options, - blas::ProfileResult *output_profile_result, blas::CallContext context) { - TF_RETURN_IF_ERROR( - CheckTypesForExtendedBlas( - computation_type)); - - blas::BlasSupport *blas = parent()->AsBlas(); - if (!blas) { - return absl::InternalError( - "Attempting to perform BLAS operation using " - "StreamExecutor without BLAS support"); - } - - void *alpha_ptr = α - void *beta_ptr = β - float alpha_storage, beta_storage; - UpcastHalfToFloat(&alpha_ptr, &beta_ptr, &alpha_storage, - &beta_storage); - - absl::Status st = blas->DoBlasGemmWithAlgorithm( - this, transa, transb, m, n, k, alpha_ptr, a, - blas::ToDataType::value, lda, b, - blas::ToDataType::value, ldb, beta_ptr, c, - blas::ToDataType::value, ldc, computation_type, algorithm, - numeric_options, output_profile_result, context); - - if (output_profile_result) { - // The error is recorded in the profile. - return absl::OkStatus(); - } - return st; - } - - template - absl::Status ThenBlasGemmStridedBatchedWithAlgorithm( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, ConstantType alpha, const DeviceMemory &a, int lda, - int64_t stride_a, const DeviceMemory &b, int ldb, - int64_t stride_b, ConstantType beta, DeviceMemory *c, int ldc, - int64_t stride_c, int batch_count, blas::ComputationType computation_type, - blas::AlgorithmType algorithm, const NumericOptions &numeric_options, - blas::ProfileResult *output_profile_result, blas::CallContext context) { - TF_RETURN_IF_ERROR( - CheckTypesForExtendedBlas( - computation_type)); - - blas::BlasSupport *blas = parent()->AsBlas(); - if (!blas) { - return absl::InternalError( - "Attempting to perform BLAS operation using " - "StreamExecutor without BLAS support"); - } - void *alpha_ptr = α - void *beta_ptr = β - float alpha_storage, beta_storage; - UpcastHalfToFloat(&alpha_ptr, &beta_ptr, &alpha_storage, - &beta_storage); - absl::Status st = blas->DoBlasGemmStridedBatchedWithAlgorithm( - this, transa, transb, m, n, k, alpha_ptr, a, - blas::ToDataType::value, lda, stride_a, b, - blas::ToDataType::value, ldb, stride_b, beta_ptr, c, - blas::ToDataType::value, ldc, stride_c, batch_count, - computation_type, algorithm, numeric_options, output_profile_result, - context); - if (output_profile_result) { - // The error is recorded in the profile. - return absl::OkStatus(); - } - return st; - } - - template - using DeviceMemorySlice = absl::Span *const>; - - // See BlasSupport::DoBlasGemmBatched. - Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb, - uint64_t m, uint64 n, uint64 k, float alpha, - DeviceMemorySlice a, int lda, - DeviceMemorySlice b, int ldb, float beta, - DeviceMemorySlice c, int ldc, - int batch_count, - const NumericOptions &numeric_options, - blas::CallContext context); - - Stream &ThenBlasGemmBatchedWithScratch( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, float alpha, DeviceMemorySlice a, int lda, - DeviceMemorySlice b, int ldb, float beta, - DeviceMemorySlice c, int ldc, int batch_count, - const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator, blas::CallContext context); - - Stream &ThenBlasGemmBatchedWithScratch( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, float alpha, DeviceMemorySlice a, int lda, - DeviceMemorySlice b, int ldb, float beta, - DeviceMemorySlice c, int ldc, int batch_count, - const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator, blas::CallContext context); - - Stream &ThenBlasGemmBatchedWithScratch( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, float alpha, DeviceMemorySlice a, int lda, - DeviceMemorySlice b, int ldb, float beta, - DeviceMemorySlice c, int ldc, int batch_count, - const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator, blas::CallContext context); - - Stream &ThenBlasGemmBatchedWithScratch( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, double alpha, DeviceMemorySlice a, int lda, - DeviceMemorySlice b, int ldb, double beta, - DeviceMemorySlice c, int ldc, int batch_count, - const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator, blas::CallContext context); - - Stream &ThenBlasGemmBatchedWithScratch( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, std::complex alpha, - DeviceMemorySlice> a, int lda, - DeviceMemorySlice> b, int ldb, - std::complex beta, DeviceMemorySlice> c, - int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator, blas::CallContext context); - - Stream &ThenBlasGemmBatchedWithScratch( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, std::complex alpha, - DeviceMemorySlice> a, int lda, - DeviceMemorySlice> b, int ldb, - std::complex beta, DeviceMemorySlice> c, - int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator, blas::CallContext context); - - template - absl::Status ThenBlasGemmStridedBatched( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, ConstantType alpha, const DeviceMemory &a, int lda, - int64_t stride_a, const DeviceMemory &b, int ldb, - int64_t stride_b, ConstantType beta, DeviceMemory *c, int ldc, - int64_t stride_c, int batch_count, const NumericOptions &numeric_options, - blas::CallContext context) { - static_assert( - detail::is_any_of, - std::complex>(), - "Unsupported input type"); - static_assert(std::is_same_v || - (detail::is_any_of() && - std::is_same_v), - "Mismatched input and alpha/beta types"); - blas::BlasSupport *blas = parent()->AsBlas(); - if (!blas) { - return absl::InternalError( - "Attempting to perform BLAS operation using " - "StreamExecutor without BLAS support"); - } - - void *alpha_ptr = α - void *beta_ptr = β - float alpha_storage, beta_storage; - UpcastHalfToFloat(&alpha_ptr, &beta_ptr, &alpha_storage, - &beta_storage); - - return blas->DoBlasGemmStridedBatched( - this, transa, transb, m, n, k, blas::ToDataType::value, - alpha_ptr, a, lda, stride_a, b, ldb, stride_b, beta_ptr, c, ldc, - stride_c, batch_count, numeric_options, context); - } - - // See BlasSupport::DoBlasTrsm. - Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, - blas::Transpose transa, blas::Diagonal diag, uint64_t m, - uint64_t n, float alpha, const DeviceMemory &a, - int lda, DeviceMemory *b, int ldb); - Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, - blas::Transpose transa, blas::Diagonal diag, uint64_t m, - uint64_t n, double alpha, const DeviceMemory &a, - int lda, DeviceMemory *b, int ldb); - Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, - blas::Transpose transa, blas::Diagonal diag, uint64_t m, - uint64_t n, std::complex alpha, - const DeviceMemory> &a, int lda, - DeviceMemory> *b, int ldb); - Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, - blas::Transpose transa, blas::Diagonal diag, uint64_t m, - uint64_t n, std::complex alpha, - const DeviceMemory> &a, int lda, - DeviceMemory> *b, int ldb); - - // See BlasSupport::DoBlasTrsmBatched. - Stream &ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo, - blas::Transpose transa, blas::Diagonal diag, - uint64_t m, uint64 n, float alpha, - const DeviceMemory &as, int lda, - DeviceMemory *bs, int ldb, - int batch_count); - Stream &ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo, - blas::Transpose transa, blas::Diagonal diag, - uint64_t m, uint64 n, double alpha, - const DeviceMemory &as, int lda, - DeviceMemory *bs, int ldb, - int batch_count); - Stream &ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo, - blas::Transpose transa, blas::Diagonal diag, - uint64_t m, uint64 n, std::complex alpha, - const DeviceMemory *> &as, - int lda, DeviceMemory *> *bs, - int ldb, int batch_count); - Stream &ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo, - blas::Transpose transa, blas::Diagonal diag, - uint64_t m, uint64 n, std::complex alpha, - const DeviceMemory *> &as, - int lda, DeviceMemory *> *bs, - int ldb, int batch_count); - - // See FftSupport::DoFft. - Stream &ThenFft(fft::Plan *plan, - const DeviceMemory> &input, - DeviceMemory> *output); - Stream &ThenFft(fft::Plan *plan, - const DeviceMemory> &input, - DeviceMemory> *output); - Stream &ThenFft(fft::Plan *plan, const DeviceMemory &input, - DeviceMemory> *output); - Stream &ThenFft(fft::Plan *plan, const DeviceMemory &input, - DeviceMemory> *output); - Stream &ThenFft(fft::Plan *plan, - const DeviceMemory> &input, - DeviceMemory *output); - Stream &ThenFft(fft::Plan *plan, - const DeviceMemory> &input, - DeviceMemory *output); + absl::Status RecordEvent(Event *event); // Entrain onto the stream: a memcpy to a host destination from a GPU source // of the given target size. host_dst must be a pointer to host memory // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and // then registered with StreamExecutor::HostMemoryRegister. + ABSL_DEPRECATED("Use absl::Status returning method instead.") Stream &ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src, uint64_t size); + absl::Status Memcpy(void *host_dst, const DeviceMemoryBase &gpu_src, + uint64_t size); // Entrain onto the stream: a memcpy to a GPU destination from a host source // of the given target size. host_src must be a pointer to host memory // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and // then registered with StreamExecutor::HostMemoryRegister. + ABSL_DEPRECATED("Use absl::Status returning method instead.") Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src, uint64_t size); + absl::Status Memcpy(DeviceMemoryBase *gpu_dst, const void *host_src, + uint64_t size); // Alternative interface for memcpying from device to host that takes an // array slice. Checks that the destination size can accommodate the host // slice size. template - Stream &ThenMemcpyD2H(const DeviceMemory &gpu_src, - absl::Span host_dst) { + absl::Status MemcpyD2H(const DeviceMemory &gpu_src, + absl::Span host_dst) { auto host_size = host_dst.size() * sizeof(T); - CHECK(gpu_src.size() == 0 || host_size >= gpu_src.size()); - return ThenMemcpy(host_dst.begin(), gpu_src, host_size); + if (gpu_src.size() == 0 || host_size >= gpu_src.size()) { + return Memcpy(host_dst.begin(), gpu_src, host_size); + } + return absl::InternalError("Bad source size."); } // Alternative interface for memcpying from host to device that takes an // array slice. Checks that the destination size can accommodate the host // slice size. template - Stream &ThenMemcpyH2D(absl::Span host_src, - DeviceMemory *gpu_dst) { + absl::Status MemcpyH2D(absl::Span host_src, + DeviceMemory *gpu_dst) { auto host_size = host_src.size() * sizeof(T); - CHECK(gpu_dst->size() == 0 || gpu_dst->size() >= host_size); - return ThenMemcpy(gpu_dst, host_src.begin(), host_size); + if (gpu_dst->size() == 0 || gpu_dst->size() >= host_size) { + return Memcpy(gpu_dst, host_src.begin(), host_size); + } + return absl::InternalError("Bad destination size."); } // Entrain onto the stream: a memcpy to a GPU destination from a GPU source // of the given target size. gpu_src/dst must be pointers to GPU memory and // peer access must be enabled between their owning StreamExecutors. + ABSL_DEPRECATED("Use absl::Status returning method instead.") Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src, uint64_t size); - - // Calls to the device-to-device copy overload of ThenMemcpy -- useful for - // ensuring that the host pointer isn't getting confused accidentally with a - // device pointer if you're not doing metaprogramming against the API. - Stream &ThenMemcpyD2D(DeviceMemoryBase *gpu_dst, - const DeviceMemoryBase &gpu_src, uint64_t size) { - return ThenMemcpy(gpu_dst, gpu_src, size); + absl::Status Memcpy(DeviceMemoryBase *gpu_dst, + const DeviceMemoryBase &gpu_src, uint64_t size); + absl::Status MemcpyD2D(DeviceMemoryBase *gpu_dst, + const DeviceMemoryBase &gpu_src, uint64_t size) { + return Memcpy(gpu_dst, gpu_src, size); } // Entrain onto the stream: a memset of zero at a GPU location of size bytes. // The location must not be null. + ABSL_DEPRECATED("Use absl::Status returning method instead.") Stream &ThenMemZero(DeviceMemoryBase *location, uint64_t size); + absl::Status MemZero(DeviceMemoryBase *location, uint64_t size); // Entrain onto the stream: a memset of a 32-bit pattern at a GPU location of // size bytes, where bytes must be evenly 32-bit sized (i.e. evenly divisible // by 4). The location must not be null. - Stream &ThenMemset32(DeviceMemoryBase *location, uint32_t pattern, - uint64_t size); - - // Enqueue a forward operation of the RNN model onto the stream. - // See DnnSupport::DoRnnForward for more details. - Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc, - const dnn::RnnSequenceTensorDescriptor &input_desc, - const DeviceMemory &input_data, - const DeviceMemory &seq_lengths_data, - const dnn::RnnStateTensorDescriptor &input_h_desc, - const DeviceMemory &input_h_data, - const dnn::RnnStateTensorDescriptor &input_c_desc, - const DeviceMemory &input_c_data, - const DeviceMemory ¶ms, - const dnn::RnnSequenceTensorDescriptor &output_desc, - DeviceMemory *output_data, - const dnn::RnnStateTensorDescriptor &output_h_desc, - DeviceMemory *output_h_data, - const dnn::RnnStateTensorDescriptor &output_c_desc, - DeviceMemory *output_c_data, - bool is_training, - ScratchAllocator *reserve_space_allocator, - ScratchAllocator *workspace_allocator, - dnn::ProfileResult *output_profile_result); - - Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc, - const dnn::RnnSequenceTensorDescriptor &input_desc, - const DeviceMemory &input_data, - const DeviceMemory &seq_lengths_data, - const dnn::RnnStateTensorDescriptor &input_h_desc, - const DeviceMemory &input_h_data, - const dnn::RnnStateTensorDescriptor &input_c_desc, - const DeviceMemory &input_c_data, - const DeviceMemory ¶ms, - const dnn::RnnSequenceTensorDescriptor &output_desc, - DeviceMemory *output_data, - const dnn::RnnStateTensorDescriptor &output_h_desc, - DeviceMemory *output_h_data, - const dnn::RnnStateTensorDescriptor &output_c_desc, - DeviceMemory *output_c_data, bool is_training, - ScratchAllocator *reserve_space_allocator, - ScratchAllocator *workspace_allocator, - dnn::ProfileResult *output_profile_result); - - Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc, - const dnn::RnnSequenceTensorDescriptor &input_desc, - const DeviceMemory &input_data, - const DeviceMemory &seq_lengths_data, - const dnn::RnnStateTensorDescriptor &input_h_desc, - const DeviceMemory &input_h_data, - const dnn::RnnStateTensorDescriptor &input_c_desc, - const DeviceMemory &input_c_data, - const DeviceMemory ¶ms, - const dnn::RnnSequenceTensorDescriptor &output_desc, - DeviceMemory *output_data, - const dnn::RnnStateTensorDescriptor &output_h_desc, - DeviceMemory *output_h_data, - const dnn::RnnStateTensorDescriptor &output_c_desc, - DeviceMemory *output_c_data, bool is_training, - ScratchAllocator *reserve_space_allocator, - ScratchAllocator *workspace_allocator, - dnn::ProfileResult *output_profile_result); - - // Enqueue a backward operation of the RNN model onto the stream. - // See DnnSupport::DoRnnBackward for more details. - Stream &ThenRnnBackward( - const dnn::RnnDescriptor &rnn_desc, - const dnn::RnnSequenceTensorDescriptor &input_desc, - const DeviceMemory &input_data, - const DeviceMemory &seq_lengths_data, - const dnn::RnnStateTensorDescriptor &input_h_desc, - const DeviceMemory &input_h_data, - const dnn::RnnStateTensorDescriptor &input_c_desc, - const DeviceMemory &input_c_data, - const DeviceMemory ¶ms, - const dnn::RnnSequenceTensorDescriptor &output_desc, - const DeviceMemory &output_data, - const dnn::RnnStateTensorDescriptor &output_h_desc, - const DeviceMemory &output_h_data, - const dnn::RnnStateTensorDescriptor &output_c_desc, - const DeviceMemory &output_c_data, - const DeviceMemory &output_backprop_data, - const DeviceMemory &output_h_backprop_data, - const DeviceMemory &output_c_backprop_data, - DeviceMemory *input_backprop_data, - DeviceMemory *input_h_backprop_data, - DeviceMemory *input_c_backprop_data, - DeviceMemory *params_backprop_data, - DeviceMemory *reserve_space_data, - ScratchAllocator *workspace_allocator, - dnn::ProfileResult *output_profile_result); - - Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc, - const dnn::RnnSequenceTensorDescriptor &input_desc, - const DeviceMemory &input_data, - const DeviceMemory &seq_lengths_data, - const dnn::RnnStateTensorDescriptor &input_h_desc, - const DeviceMemory &input_h_data, - const dnn::RnnStateTensorDescriptor &input_c_desc, - const DeviceMemory &input_c_data, - const DeviceMemory ¶ms, - const dnn::RnnSequenceTensorDescriptor &output_desc, - const DeviceMemory &output_data, - const dnn::RnnStateTensorDescriptor &output_h_desc, - const DeviceMemory &output_h_data, - const dnn::RnnStateTensorDescriptor &output_c_desc, - const DeviceMemory &output_c_data, - const DeviceMemory &output_backprop_data, - const DeviceMemory &output_h_backprop_data, - const DeviceMemory &output_c_backprop_data, - DeviceMemory *input_backprop_data, - DeviceMemory *input_h_backprop_data, - DeviceMemory *input_c_backprop_data, - DeviceMemory *params_backprop_data, - DeviceMemory *reserve_space_data, - ScratchAllocator *workspace_allocator, - dnn::ProfileResult *output_profile_result); - - Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc, - const dnn::RnnSequenceTensorDescriptor &input_desc, - const DeviceMemory &input_data, - const DeviceMemory &seq_lengths_data, - const dnn::RnnStateTensorDescriptor &input_h_desc, - const DeviceMemory &input_h_data, - const dnn::RnnStateTensorDescriptor &input_c_desc, - const DeviceMemory &input_c_data, - const DeviceMemory ¶ms, - const dnn::RnnSequenceTensorDescriptor &output_desc, - const DeviceMemory &output_data, - const dnn::RnnStateTensorDescriptor &output_h_desc, - const DeviceMemory &output_h_data, - const dnn::RnnStateTensorDescriptor &output_c_desc, - const DeviceMemory &output_c_data, - const DeviceMemory &output_backprop_data, - const DeviceMemory &output_h_backprop_data, - const DeviceMemory &output_c_backprop_data, - DeviceMemory *input_backprop_data, - DeviceMemory *input_h_backprop_data, - DeviceMemory *input_c_backprop_data, - DeviceMemory *params_backprop_data, - DeviceMemory *reserve_space_data, - ScratchAllocator *workspace_allocator, - dnn::ProfileResult *output_profile_result); - - // Enqueue a CTCLoss operation onto the stream. - // See DnnSupport::DoCtcLoss for more details. - Stream &ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc, - const DeviceMemory &probs_data, - absl::Span labels_data, - absl::Span labels_lengths_data, - absl::Span input_lengths_data, - const NumericOptions &numeric_options, - DeviceMemory *costs_data, - const dnn::RnnStateTensorDescriptor &grads_desc, - DeviceMemory *grads_data, - ScratchAllocator *workspace_allocator); - - // Enqueue onto the stream a operation that transforms a tensor. - // See DnnSupport::DoTransformTensor for more details. - Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc, - dnn::DataType input_type, - const DeviceMemoryBase &input_data, - const dnn::BatchDescriptor &output_desc, - dnn::DataType output_type, float scale, - DeviceMemoryBase *output_data); + absl::Status Memset32(DeviceMemoryBase *location, uint32_t pattern, + uint64_t size); // (Synchronously) block the host code waiting for the operations // entrained on the stream (enqueued to this point in program @@ -1193,25 +269,6 @@ class Stream { // Otherwise returns an error describing why the blocking failed. absl::Status BlockHostUntilDone() TF_LOCKS_EXCLUDED(mu_); - // Warning! This method interacts with internal threads in - // sometimes-unpredictable ways and is intended for GPU-Executor-internal - // use - // only. Please check with a member of the FASTR team before making use of - // this method. - // - // Entrains onto the stream a function to be executed on the host at some - // point in the future. - // Async host callbacks DO NOT block the stream as device functions (or as - // synchronous host callbacks). No synchronization is possible with - // asynchronous callbacks; they are strictly fire-and-forget. - // This method is private due to the potential for undefined behavior with - // synchronization using OpenCL user events. - // The ONLY lifetime guarantee in these calls is that the StreamExecutor - // parameter will still be valid - this Stream may not be! - // Any callbacks requiring device API calls must use this method. - Stream &ThenEnqueueOnBackgroundThread( - std::function task); - // Returns the (opaque) platform-specific backing object. Ownership is not // transferred to the caller. internal::StreamInterface *implementation() { return implementation_.get(); } @@ -1223,7 +280,9 @@ class Stream { // This is kept for backward compatibility. Future code should use // ThenDoHostCallbackWithStatus and explicitly return a success status. // TODO(b/112125301): Eventually remove this method. + ABSL_DEPRECATED("Use absl::Status returning method instead.") Stream &ThenDoHostCallback(absl::AnyInvocable callback); + absl::Status DoHostCallback(absl::AnyInvocable callback); // Entrains onto the stream a callback to the host (from the device). // Host callbacks block/occupy the stream just as device functions @@ -1237,8 +296,11 @@ class Stream { // // On certain platforms, ThenDoHostCallback is expected to have significant // negative effects on performance. + ABSL_DEPRECATED("Use absl::Status returning method instead.") Stream &ThenDoHostCallbackWithStatus( absl::AnyInvocable callback); + absl::Status DoHostCallbackWithStatus( + absl::AnyInvocable callback); // Returns the StreamExecutor (parent object) associated with this stream. StreamExecutor *parent() const { @@ -1264,51 +326,6 @@ class Stream { std::variant priority() const; private: - template - friend struct ThenBlasImpl; // for implementing ThenBlasXXX. - - // Checks whether types match before a call to extended BLAS version. - template - absl::Status CheckTypesForExtendedBlas( - blas::ComputationType computation_type) { - static_assert( - detail::is_any_of, std::complex>(), - "The only buffer types supported are: Eigen::half, float, " - "double, int8, std::complex and std::complex"); - static_assert( - std::is_same_v || - (std::is_same_v && - detail::is_any_of()), - "Mismatched alpha/beta and output types"); - - bool valid_computation_type = [computation_type] { - switch (computation_type) { - case blas::ComputationType::kF16: - return std::is_same_v; - case blas::ComputationType::kF32: - return detail::is_any_of>(); - case blas::ComputationType::kF64: - return detail::is_any_of>(); - case blas::ComputationType::kI32: - return std::is_same_v; - case blas::ComputationType::kF16AsF32: // fall-through - case blas::ComputationType::kBF16AsF32: // fall-through - case blas::ComputationType::kTF32AsF32: - return detail::is_any_of>(); - } - }(); - - if (!valid_computation_type) { - return absl::InternalError(absl::StrCat( - "Invalid computation type ", - blas::ComputationTypeString(computation_type), " for output type: ", - blas::DataTypeString(blas::ToDataType::value))); - } - return absl::OkStatus(); - } - bool InErrorState() const TF_LOCKS_EXCLUDED(mu_) { absl::ReaderMutexLock lock(&mu_); return !status_.ok(); @@ -1323,19 +340,6 @@ class Stream { void SetError() { CheckError(false /* = operation_retcode */); } - void SetErrorAndLogNoDnnSupport() { - SetError(); - LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor " - "without DNN support"; - } - - // Allocates an array without type parameterization, so that the - // implementation can live in the source file. Without this base allocation - // method, we incur a circular dependency between the StreamExecutor - // definition and this class' definition. - absl::StatusOr> AllocateArrayBase( - uint64_t element_count, uint64 element_size); - // The StreamExecutor that supports the operation of this stream. StreamExecutor *parent_; @@ -1361,29 +365,6 @@ class Stream { std::vector, bool>> sub_streams_ ABSL_GUARDED_BY(mu_); - // Non-extended BLAS interface requires alpha/beta to be floats when input - // type is Eigen::half. However, for consistency purposes it is convenient - // for the interface to accept Eigen::half. - template - void UpcastHalfToFloat(void **alpha_ptr, void **beta_ptr, - float *alpha_storage, float *beta_storage) { - if (std::is_same::value) { - *alpha_storage = - static_cast(*reinterpret_cast(*alpha_ptr)); - *beta_storage = - static_cast(*reinterpret_cast(*beta_ptr)); - *alpha_ptr = alpha_storage; - *beta_ptr = beta_storage; - } else if (std::is_same::value) { - *alpha_storage = - static_cast(*reinterpret_cast(*alpha_ptr)); - *beta_storage = - static_cast(*reinterpret_cast(*beta_ptr)); - *alpha_ptr = alpha_storage; - *beta_ptr = beta_storage; - } - } - Stream(const Stream &) = delete; void operator=(const Stream &) = delete; }; @@ -1398,7 +379,7 @@ inline absl::Status Stream::ThenLaunch(ThreadDim thread_dims, Args... args) { auto kernel_args = PackKernelArgs(kernel, args...); TF_RETURN_IF_ERROR( - parent_->Launch(this, thread_dims, block_dims, kernel, *kernel_args)); + parent_->Launch(this, thread_dims, block_dims, *kernel, *kernel_args)); return absl::OkStatus(); } @@ -1409,7 +390,7 @@ inline absl::Status Stream::ThenLaunch(ThreadDim thread_dims, Args... args) { auto kernel_args = PackKernelArgs(shmem_bytes, args...); TF_RETURN_IF_ERROR( - parent_->Launch(this, thread_dims, block_dims, kernel, *kernel_args)); + parent_->Launch(this, thread_dims, block_dims, *kernel, *kernel_args)); return absl::OkStatus(); } @@ -1421,7 +402,7 @@ inline absl::Status Stream::ThenLaunch(ThreadDim thread_dims, Args... args) { auto kernel_args = PackKernelArgs(kernel, args...); TF_RETURN_IF_ERROR(parent_->Launch(this, thread_dims, block_dims, - cluster_dims, kernel, *kernel_args)); + cluster_dims, *kernel, *kernel_args)); return absl::OkStatus(); } @@ -1431,39 +412,10 @@ inline absl::Status Stream::ThenLaunch( int32_t shmem_bytes, const TypedKernel &kernel, Args... args) { auto kernel_args = PackKernelArgs(shmem_bytes, args...); TF_RETURN_IF_ERROR(parent_->Launch(this, thread_dims, block_dims, - cluster_dims, kernel, *kernel_args)); + cluster_dims, *kernel, *kernel_args)); return absl::OkStatus(); } -template -inline absl::StatusOr>> -Stream::AllocateTemporaryArray(uint64_t element_count) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr temporary_memory, - AllocateArrayBase(element_count, sizeof(T))); - - return std::unique_ptr>( - reinterpret_cast *>(temporary_memory.release())); -} - -template <> -struct Quantization { - static constexpr dnn::QuantizedActivationMode kModeId = - dnn::QuantizedActivationMode::k8Bit; -}; - -template <> -struct Quantization { - static constexpr dnn::QuantizedActivationMode kModeId = - dnn::QuantizedActivationMode::k16Bit; -}; - -template <> -struct Quantization { - static constexpr dnn::QuantizedActivationMode kModeId = - dnn::QuantizedActivationMode::k32Bit; -}; - } // namespace stream_executor #endif // XLA_STREAM_EXECUTOR_STREAM_H_ diff --git a/third_party/xla/xla/stream_executor/stream_executor.h b/third_party/xla/xla/stream_executor/stream_executor.h index 9be77b79d14ed8..84a9ff8674b41e 100644 --- a/third_party/xla/xla/stream_executor/stream_executor.h +++ b/third_party/xla/xla/stream_executor/stream_executor.h @@ -22,16 +22,8 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_H_ #define XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_H_ -#include "xla/stream_executor/device_description.h" // IWYU pragma: export -#include "xla/stream_executor/device_memory.h" // IWYU pragma: export -#include "xla/stream_executor/device_options.h" // IWYU pragma: export -#include "xla/stream_executor/event.h" // IWYU pragma: export -#include "xla/stream_executor/kernel.h" // IWYU pragma: export -#include "xla/stream_executor/kernel_spec.h" // IWYU pragma: export -#include "xla/stream_executor/launch_dim.h" // IWYU pragma: export -#include "xla/stream_executor/multi_platform_manager.h" // IWYU pragma: export -#include "xla/stream_executor/platform.h" // IWYU pragma: export -#include "xla/stream_executor/stream.h" // IWYU pragma: export +#include "xla/stream_executor/multi_platform_manager.h" +#include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor_pimpl.h" // IWYU pragma: export #endif // XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_H_ diff --git a/third_party/xla/xla/stream_executor/stream_executor_internal.h b/third_party/xla/xla/stream_executor/stream_executor_internal.h index 105ad827d01547..3f2722b20a01e5 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_internal.h +++ b/third_party/xla/xla/stream_executor/stream_executor_internal.h @@ -72,164 +72,6 @@ class EventInterface { void operator=(const EventInterface&) = delete; }; -//===----------------------------------------------------------------------===// -// KernelInterface -//===----------------------------------------------------------------------===// - -// Pointer-to-implementation object type (i.e. the Kernel class delegates to -// this interface) with virtual destruction. This class exists for the -// platform-dependent code to hang any kernel data/resource info/functionality -// off of. -class KernelInterface { - public: - // Default constructor for the abstract interface. - KernelInterface() = default; - - // Default destructor for the abstract interface. - virtual ~KernelInterface() = default; - - // Returns the number of formal parameters that this kernel accepts. - virtual unsigned Arity() const = 0; - - // Sets the preferred cache configuration. - virtual void SetPreferredCacheConfig(KernelCacheConfig config) = 0; - - // Gets the preferred cache configuration. - virtual KernelCacheConfig GetPreferredCacheConfig() const = 0; - - // Returns the maximum number of blocks (per multiprocessor) occupied by the - // kernel given the number of threads per block and shared memory size. - virtual absl::StatusOr GetMaxOccupiedBlocksPerCore( - ThreadDim threads, size_t dynamic_shared_memory_bytes) const { - return absl::UnimplementedError("Not Implemented"); - } - - private: - KernelInterface(const KernelInterface&) = delete; - void operator=(const KernelInterface&) = delete; -}; - -//===----------------------------------------------------------------------===// -// CommandBufferInterface -//===----------------------------------------------------------------------===// - -// Platform-dependent interface class for implementing generic CommandBuffer. -// -// TODO(ezhulenev): Currently we assume that all operations between barriers -// can execute concurrently, and it's up to the caller to insert barriers to -// guarantee correctness. Consider adding finer grained synchronization -// mechanism between different commands. -// -// TODO(ezhulenev): Currently command buffers do no support updates, and once -// finalized can be executed as recorded. We need to support cheap command -// buffer updates that in GPU backend will be mapped to CUDA/HIP graph node -// updates. -class CommandBufferInterface { - public: - CommandBufferInterface() = default; - virtual ~CommandBufferInterface() = default; - - // Traces `function` invocation by recording all operations on the `stream` - // into the command buffer. Command buffer must be empty. - virtual absl::Status Trace(Stream* stream, - absl::AnyInvocable function) = 0; - - // Adds an execution barrier to a command buffer: all commands added before a - // barrier will complete before any of the commands added after a barrier. - virtual absl::Status Barrier(StreamExecutor* executor) = 0; - - // Adds a kernel launch command to the command buffer. - virtual absl::Status Launch(const ThreadDim& threads, const BlockDim& blocks, - const Kernel& kernel, const KernelArgs& args) = 0; - - // Adds a nested command buffer to the command buffer. - virtual absl::Status AddNestedCommandBuffer(const CommandBuffer& nested) = 0; - - // Adds a device-to-device memory copy to the command buffer. - virtual absl::Status MemcpyDeviceToDevice(DeviceMemoryBase* dst, - const DeviceMemoryBase& src, - uint64_t size) = 0; - - // Adds a memset node to the command buffer. - virtual absl::Status Memset(DeviceMemoryBase* dst, - CommandBuffer::BitPattern bit_pattern, - size_t num_elements) = 0; - - // Adds a device memory allocation node to the command buffer. - virtual absl::StatusOr Allocate(size_t bytes) = 0; - - // Adds a device memory free command to the command buffer, buffer is - // allocated in other command buffer, free through real address. - virtual absl::Status Free(DeviceMemoryBase dst) = 0; - - // For all conditional command APIs defined below, nested command buffers - // constructed for conditional branches owned by *this and should never be - // finalized or updated inside builders. - - // Adds a conditional operation that will run a command buffer constructed by - // `then_builder` if `predicate` value is `true`. - virtual absl::Status If(StreamExecutor* executor, - DeviceMemory predicate, - CommandBuffer::Builder then_builder) = 0; - - // Adds a conditional operation that will run a command buffer constructed by - // `then_builder` if `predicate` value is `true`, or a command buffer - // constructed by `else_builder` if `predicate` is `false`. - virtual absl::Status IfElse(StreamExecutor* executor, - DeviceMemory predicate, - CommandBuffer::Builder then_builder, - CommandBuffer::Builder else_builder) = 0; - - // Adds a conditional operation that will run a command buffer constructed by - // the `branches` builder at `index`. If `index` is out of range, then it will - // run a conditional command buffer constructed by the last builder. - // - // See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#case - virtual absl::Status Case(StreamExecutor* executor, - DeviceMemory index, - std::vector branches) = 0; - - // Adds a conditional operation that will run a command buffer constructed by - // the `body_builder` exactly `num_iteration` times. - virtual absl::Status For(StreamExecutor* executor, int32_t num_iteration, - DeviceMemory loop_index, - CommandBuffer::Builder body_builder) = 0; - - // Adds a conditional operation that will execute a command buffer constructed - // by the `cond_builder` that must update `pred` value, and then depending on - // the value might execute command buffer constructed by `body_builder` and - // `cond_builder`. Will continue while `pred` value is `true`. - // - // In pseudocode: - // - // cond_builder() - // while(pred): - // body_builder() - // cond_builder() - // - virtual absl::Status While(StreamExecutor* executor, DeviceMemory pred, - CommandBuffer::Builder cond_builder, - CommandBuffer::Builder body_builder) = 0; - - // Finalizes command buffer and makes it executable. Once command buffer is - // finalized no commands can be added to it. - virtual absl::Status Finalize() = 0; - - // Begins command buffer update. Command buffer update should be finalized - // before it can be executed. - virtual absl::Status Update() = 0; - - // Returns command buffer execution mode. - virtual CommandBuffer::Mode mode() const = 0; - - // Returns command buffer state. - virtual CommandBuffer::State state() const = 0; - - private: - CommandBufferInterface(const CommandBufferInterface&) = delete; - void operator=(const CommandBufferInterface&) = delete; -}; - //===----------------------------------------------------------------------===// // StreamInterface //===----------------------------------------------------------------------===// @@ -455,11 +297,14 @@ class StreamExecutorInterface { // Each call creates a new instance of the platform-specific implementation of // the corresponding interface type. virtual std::unique_ptr CreateEventImplementation() = 0; - virtual std::unique_ptr CreateKernelImplementation() = 0; virtual std::unique_ptr GetStreamImplementation() = 0; - virtual absl::StatusOr> - GetCommandBufferImplementation(CommandBuffer::Mode mode) { + virtual absl::StatusOr> CreateKernel() { + return absl::UnimplementedError("Kernels are not implemented"); + } + + virtual absl::StatusOr> CreateCommandBuffer( + CommandBuffer::Mode mode) { return absl::UnimplementedError("Command buffers are not implemented"); } diff --git a/third_party/xla/xla/stream_executor/stream_executor_pimpl.cc b/third_party/xla/xla/stream_executor/stream_executor_pimpl.cc index 217116e7628a5f..f555855b0854ae 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_pimpl.cc +++ b/third_party/xla/xla/stream_executor/stream_executor_pimpl.cc @@ -21,15 +21,13 @@ limitations under the License. #include #include +#include -#include "absl/base/const_init.h" #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "absl/synchronization/notification.h" #include "absl/types/span.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/command_buffer.h" @@ -38,6 +36,7 @@ limitations under the License. #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/fft.h" +#include "xla/stream_executor/host_memory_allocation.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" @@ -62,14 +61,6 @@ std::string StackTraceIfVLOG10() { } } -// Make sure the executor is done with its work; we know (because this isn't -// publicly visible) that all enqueued work is quick. -void BlockOnThreadExecutor(tsl::thread::ThreadPool* executor) { - absl::Notification n; - executor->Schedule([&n]() { n.Notify(); }); - n.WaitForNotification(); -} - } // namespace // Get per-device memory limit in bytes. Returns 0 if @@ -88,15 +79,11 @@ StreamExecutor::StreamExecutor( : platform_(platform), implementation_(std::move(implementation)), device_ordinal_(device_ordinal), - background_threads_(new tsl::thread::ThreadPool( - tsl::Env::Default(), "stream_executor", kNumBackgroundThreads)), live_stream_count_(0), memory_limit_bytes_(GetMemoryLimitBytes()), allocator_(this) {} StreamExecutor::~StreamExecutor() { - BlockOnThreadExecutor(background_threads_.get()); - if (live_stream_count_.load() != 0) { LOG(WARNING) << "Not all streams were deallocated at executor destruction " << "time. This may lead to unexpected/bad behavior - " @@ -173,118 +160,6 @@ int64_t StreamExecutor::GetDeviceLoad() const { return implementation_->GetDeviceLoad(); } -absl::Status StreamExecutor::GetFusedMatmulRunners( - bool use_cudnn_frontend, dnn::DataType input_type, dnn::DataType bias_type, - dnn::DataType output_type, Stream* stream, bool trans_a, bool trans_b, - uint64_t m, uint64_t n, uint64_t k, int64_t lda, int64_t ldb, int64_t ldc, - dnn::ActivationMode activation_mode, bool use_fallback, - const NumericOptions& numeric_options, - std::vector>* - out_exec_plans) { - dnn::DnnSupport* dnn_support = AsDnn(); - if (!dnn_support) { - return absl::UnimplementedError("DNN library is not found."); - } - - return dnn_support->GetFusedMatmulRunners( - use_cudnn_frontend, input_type, bias_type, output_type, stream, trans_a, - trans_b, m, n, k, lda, ldb, ldc, activation_mode, use_fallback, - numeric_options, out_exec_plans); -} - -bool StreamExecutor::GetMIOpenConvolveAlgorithms( - dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream, - const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, - const dnn::FilterDescriptor& filter_descriptor, - DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor, - DeviceMemoryBase output_data, - const dnn::ConvolutionDescriptor& convolution_descriptor, - ScratchAllocator* scratch_allocator, - std::vector* out_algorithms) { - dnn::DnnSupport* dnn_support = AsDnn(); - if (!dnn_support) { - return false; - } - return dnn_support->GetMIOpenConvolveAlgorithms( - kind, element_type, stream, input_descriptor, input_data, - filter_descriptor, filter_data, output_descriptor, output_data, - convolution_descriptor, scratch_allocator, out_algorithms); -} - -bool StreamExecutor::GetRnnAlgorithms( - std::vector* out_algorithms) { - dnn::DnnSupport* dnn_support = AsDnn(); - if (!dnn_support) { - return false; - } - return dnn_support->GetRnnAlgorithms(out_algorithms); -} - -bool StreamExecutor::GetBlasGemmAlgorithms( - Stream* stream, std::vector* out_algorithms) { - blas::BlasSupport* blas_support = AsBlas(); - if (!blas_support) { - return false; - } - return blas_support->GetBlasGemmAlgorithms(stream, out_algorithms); -} - -absl::StatusOr> -StreamExecutor::createRnnDescriptor( - int num_layers, int hidden_size, int input_size, int cell_size, - int batch_size, dnn::RnnInputMode input_mode, - dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode, - dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config, - const NumericOptions& numeric_options, float dropout, uint64_t seed, - ScratchAllocator* state_allocator, bool use_padded_io) { - dnn::DnnSupport* dnn_support = AsDnn(); - if (!dnn_support) { - return absl::UnknownError("Fail to find the dnn implementation."); - } - return dnn_support->createRnnDescriptor( - num_layers, hidden_size, input_size, cell_size, batch_size, input_mode, - direction_mode, rnn_mode, data_type, algorithm_config, numeric_options, - dropout, seed, state_allocator, use_padded_io); -} - -absl::StatusOr> -StreamExecutor::createRnnSequenceTensorDescriptor(int max_seq_length, - int batch_size, int data_size, - dnn::DataType data_type) { - dnn::DnnSupport* dnn_support = AsDnn(); - if (!dnn_support) { - return absl::UnknownError("Fail to find the dnn implementation."); - } - return dnn_support->createRnnSequenceTensorDescriptor( - max_seq_length, batch_size, data_size, data_type); -} - -absl::StatusOr> -StreamExecutor::createRnnSequenceTensorDescriptor( - int max_seq_length, int batch_size, int data_size, - const absl::Span& seq_lengths, bool time_major, - dnn::DataType data_type) { - dnn::DnnSupport* dnn_support = AsDnn(); - if (!dnn_support) { - return absl::UnknownError("Fail to find the dnn implementation."); - } - return dnn_support->createRnnSequenceTensorDescriptor( - max_seq_length, batch_size, data_size, seq_lengths, time_major, - data_type); -} - -absl::StatusOr> -StreamExecutor::createRnnStateTensorDescriptor(int num_layer, int batch_size, - int data_size, - dnn::DataType data_type) { - dnn::DnnSupport* dnn_support = AsDnn(); - if (!dnn_support) { - return absl::UnknownError("Fail to find the dnn implementation."); - } - return dnn_support->createRnnStateTensorDescriptor(num_layer, batch_size, - data_size, data_type); -} - dnn::DnnSupport* StreamExecutor::AsDnn() { absl::MutexLock lock(&mu_); if (dnn_ != nullptr) { @@ -416,18 +291,23 @@ absl::Status StreamExecutor::CollectiveMemoryDeallocate(void* location) { return implementation_->CollectiveMemoryDeallocate(location); } -void* StreamExecutor::HostMemoryAllocate(uint64_t size) { +absl::StatusOr> +StreamExecutor::HostMemoryAllocate(uint64_t size) { void* buffer = implementation_->HostMemoryAllocate(size); VLOG(1) << "Called StreamExecutor::HostMemoryAllocate(size=" << size << ") returns " << buffer << StackTraceIfVLOG10(); - return buffer; + if (buffer == nullptr && size > 0) { + return absl::InternalError( + absl::StrFormat("Failed to allocate HostMemory of size %d", size)); + } + return std::make_unique(buffer, size, implementation()); } -void StreamExecutor::HostMemoryDeallocate(void* location) { - VLOG(1) << "Called StreamExecutor::HostMemoryDeallocate(location=" << location - << ")" << StackTraceIfVLOG10(); +void StreamExecutor::HostMemoryDeallocate(void* data, uint64_t size) { + VLOG(1) << "Called StreamExecutor::HostMemoryDeallocate(data=" << data << ")" + << StackTraceIfVLOG10(); - return implementation_->HostMemoryDeallocate(location); + return implementation_->HostMemoryDeallocate(data); } bool StreamExecutor::SynchronizeAllActivity() { @@ -435,10 +315,6 @@ bool StreamExecutor::SynchronizeAllActivity() { << StackTraceIfVLOG10(); bool ok = implementation_->SynchronizeAllActivity(); - // This should all be quick and infallible work, so we can perform the - // synchronization even in the case of failure. - BlockOnThreadExecutor(background_threads_.get()); - return ok; } @@ -562,6 +438,20 @@ Event::Status StreamExecutor::PollForEventStatus(Event* event) { return implementation_->PollForEventStatus(event); } +absl::StatusOr> StreamExecutor::CreateStream( + std::optional> priority) { + auto stream = std::make_unique(this); + if (priority.has_value()) { + if (std::holds_alternative(*priority)) { + stream->SetPriority(std::get(*priority)); + } else { + stream->SetPriority(std::get(*priority)); + } + } + TF_RETURN_IF_ERROR(stream->Initialize()); + return std::move(stream); +} + bool StreamExecutor::AllocateStream(Stream* stream) { live_stream_count_.fetch_add(1, std::memory_order_relaxed); if (!implementation_->AllocateStream(stream)) { @@ -601,10 +491,6 @@ bool StreamExecutor::DeviceMemoryUsage(int64_t* free, int64_t* total) const { return implementation_->DeviceMemoryUsage(free, total); } -void StreamExecutor::EnqueueOnBackgroundThread(std::function task) { - background_threads_->Schedule(std::move(task)); -} - std::optional StreamExecutor::GetAllocatorStats() { return implementation_->GetAllocatorStats(); } @@ -689,18 +575,15 @@ absl::StatusOr StreamExecutorMemoryAllocator::GetStream( << "The logic below only works for synchronous allocators"; TF_ASSIGN_OR_RETURN(StreamExecutor * executor, GetStreamExecutor(device_ordinal)); - Stream* out = [&] { - absl::MutexLock lock(&mutex_); - if (!streams_.count(device_ordinal)) { - auto p = streams_.emplace(std::piecewise_construct, - std::forward_as_tuple(device_ordinal), - std::forward_as_tuple(executor)); - p.first->second.Init(); - return &p.first->second; - } - return &streams_.at(device_ordinal); - }(); - return out; + absl::MutexLock lock(&mutex_); + if (!streams_.count(device_ordinal)) { + auto p = streams_.emplace(std::piecewise_construct, + std::forward_as_tuple(device_ordinal), + std::forward_as_tuple(executor)); + TF_RETURN_IF_ERROR(p.first->second.Initialize()); + return &p.first->second; + } + return &streams_.at(device_ordinal); } } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/stream_executor_pimpl.h b/third_party/xla/xla/stream_executor/stream_executor_pimpl.h index 8984df7415c202..cc1b3e62dbba6c 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_pimpl.h +++ b/third_party/xla/xla/stream_executor/stream_executor_pimpl.h @@ -17,6 +17,7 @@ limitations under the License. #define XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_ #include +#include #include #include #include @@ -40,14 +41,12 @@ limitations under the License. #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/fft.h" +#include "xla/stream_executor/host_memory_allocation.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/module_spec.h" -#include "xla/stream_executor/numeric_options.h" #include "xla/stream_executor/platform.h" -#include "xla/stream_executor/platform/port.h" -#include "tsl/platform/threadpool.h" -#include "tsl/protobuf/dnn.pb.h" +#include "tsl/platform/logging.h" namespace stream_executor { @@ -57,11 +56,6 @@ namespace internal { class StreamExecutorInterface; } // namespace internal -// Forward declaration of private friend class. -template -class ScopedTracer; - // A StreamExecutor manages a single device, in terms of executing work (kernel // launches) and memory management (allocation/deallocation, memory copies to // and from the device). It is conceptually the "handle" for a device -- Stream @@ -192,10 +186,8 @@ class StreamExecutor { // Memory allocated in this manner (or allocated and registered with // HostMemoryRegister() is required for use in asynchronous memcpy operations, // such as Stream::ThenMemcpy. - void* HostMemoryAllocate(uint64_t size); - - // Deallocates a region of host memory allocated by HostMemoryAllocate(). - void HostMemoryDeallocate(void* location); + absl::StatusOr> HostMemoryAllocate( + uint64_t size); // Synchronizes all activity occurring in the StreamExecutor's context (most // likely a whole device). @@ -279,64 +271,6 @@ class StreamExecutor { // will be reflected in "free". bool DeviceMemoryUsage(int64_t* free, int64_t* total) const; - absl::Status GetFusedMatmulRunners( - bool use_cudnn_frontend, dnn::DataType input_type, - dnn::DataType bias_type, dnn::DataType output_type, Stream* stream, - bool trans_a, bool trans_b, uint64_t m, uint64_t n, uint64_t k, - int64_t lda, int64_t ldb, int64_t ldc, - dnn::ActivationMode activation_mode, bool use_fallback, - const NumericOptions& numeric_options, - std::vector>* - out_exec_plans); - - // Returns the list of supported algorithms for the forward convolution - // operation. - bool GetMIOpenConvolveAlgorithms( - dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream, - const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, - const dnn::FilterDescriptor& filter_descriptor, - DeviceMemoryBase filter_data, - const dnn::BatchDescriptor& output_descriptor, - DeviceMemoryBase output_data, - const dnn::ConvolutionDescriptor& convolution_descriptor, - ScratchAllocator* scratch_allocator, - std::vector* out_algorithms); - - // Returns the list of supported algorithms for rnn operation. - bool GetRnnAlgorithms(std::vector* out_algorithms); - - // Get the list of supported algorithms for BLAS gemm. - bool GetBlasGemmAlgorithms(Stream* stream, - std::vector* out_algorithms); - - // Create an RNN descriptor based on model shapes and configurations. - // The caller retains the ownership of the descriptor. - absl::StatusOr> createRnnDescriptor( - int num_layers, int hidden_size, int input_size, int cell_size, - int batch_size, dnn::RnnInputMode input_mode, - dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode, - dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config, - const NumericOptions& numeric_options, float dropout, uint64_t seed, - ScratchAllocator* state_allocator, bool use_padded_io); - - // Create a RNN sequence descriptor that specifies either the input or output - // sequence. The caller retains the ownership of the returned descriptor. - absl::StatusOr> - createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size, - int data_size, dnn::DataType data_type); - - absl::StatusOr> - createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size, - int data_size, - const absl::Span& seq_lengths, - bool time_major, dnn::DataType data_type); - - // Create an RNN state descriptor that specifies the input or hidden state. - // The caller retains the ownership of the returned descriptor. - absl::StatusOr> - createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size, - dnn::DataType data_type); - // Returns the device ordinal that this StreamExecutor was initialized with. // Meaningless before initialization. int device_ordinal() const { return device_ordinal_; } @@ -344,22 +278,6 @@ class StreamExecutor { // Returns a borrowed pointer to the underlying StreamExecutor implementation. internal::StreamExecutorInterface* implementation(); - // Creates a kernel which can be launched with `stream.ThenLaunch(...)` from a - // PTX (and optional CUBIN), such that the types of the arguments provided for - // launch would have to match types of the arguments provided at creation - // time. The canonical storage for both ptx and cubin_data should outlive the - // lifetime of the kernel. - template - absl::StatusOr>> CreateTypedKernel( - absl::string_view kernel_name, absl::string_view ptx, - absl::Span cubin_data); - - // Creates a kernel which can be launched with `stream.ThenLaunch(...)` from - // an in-process symbol pointer. - template - absl::StatusOr>> CreateTypedKernel( - absl::string_view kernel_name, void* symbol); - // Warning: use Stream::ThenLaunch instead, this method is not for general // consumption. However, this is the only way to launch a kernel for which // the type signature is only known at runtime; say, if an application @@ -431,16 +349,17 @@ class StreamExecutor { // Performs linear search over alive GPU streams. Stream* FindAllocatedStream(void* gpu_stream); + // Creates and initializes a Stream. + absl::StatusOr> CreateStream( + std::optional> priority = std::nullopt); + private: - template - friend class ScopedTracer; friend class Event; friend class Stream; - template - friend class TypedKernel; - template - friend struct ThenBlasImpl; + friend class HostMemoryAllocation; + + // Deallocates a region of host memory allocated by HostMemoryAllocate(). + void HostMemoryDeallocate(void* data, uint64_t size); // Synchronously allocates size bytes on the underlying platform and returns // a DeviceMemoryBase representing that allocation. In the case of failure, @@ -518,13 +437,6 @@ class StreamExecutor { // ownership transfer to caller. std::unique_ptr CreateDeviceDescription() const; - // Adds a task to the tsl::thread::ThreadPool work queue. These tasks must be - // fire-and-forget and have no external data or timing dependencies; their - // execution order and completion time have no guarantees. - // For an example of an appropriate task, see HostBlas::DoBlasGemmInternal; - // there, temporary internal buffers are freed using this method. - void EnqueueOnBackgroundThread(std::function task); - // Reader/writer lock for mutable data structures on this StreamExecutor. // // Mutable so that caching functions (like DeviceDescription, AsBlas, etc.) @@ -561,16 +473,6 @@ class StreamExecutor { // Immutable post-initialization. int device_ordinal_; - // Executor for handling host callback work that cannot be performed - // by a host callback thread - for example, cleanup after a host BLAS routine - // (which may make device API calls). This work cannot block the host - // callback thread, will be completed asynchronously, and should be treated - // as fire-and-forget. Assume no ordering guarantees WRT the tasks enqueued - // here. - // - // Immutable post-initialization. Object is thread-safe. - std::unique_ptr background_threads_; - // Counter for the current number of live streams. This is used to check // for accidentally-outstanding streams at StreamExecutor teardown time, as // well @@ -631,35 +533,6 @@ class ScopedModuleHandle { //////////// // Inlines -template -inline absl::StatusOr>> -StreamExecutor::CreateTypedKernel(absl::string_view kernel_name, - absl::string_view ptx, - absl::Span cubin_data) { - auto kernel_base = std::make_unique>(this); - MultiKernelLoaderSpec loader_spec(kernel_base->kNumberOfParameters); - loader_spec.AddCudaPtxInMemory(ptx, kernel_name); - - if (!cubin_data.empty()) { - loader_spec.AddCudaCubinInMemory( - reinterpret_cast(cubin_data.data()), kernel_name); - } - - TF_RETURN_IF_ERROR(GetKernel(loader_spec, kernel_base.get())); - return std::move(kernel_base); -} - -template -inline absl::StatusOr>> -StreamExecutor::CreateTypedKernel(absl::string_view kernel_name, void* symbol) { - auto kernel_base = std::make_unique>(this); - MultiKernelLoaderSpec loader_spec(kernel_base->kNumberOfParameters); - loader_spec.AddInProcessSymbol(symbol, kernel_name); - - TF_RETURN_IF_ERROR(GetKernel(loader_spec, kernel_base.get())); - return std::move(kernel_base); -} - template inline DeviceMemory StreamExecutor::AllocateArray(uint64_t element_count, int64_t memory_space) { diff --git a/third_party/xla/xla/stream_executor/stream_executor_test.cc b/third_party/xla/xla/stream_executor/stream_executor_test.cc new file mode 100644 index 00000000000000..8dd6c8b36e4183 --- /dev/null +++ b/third_party/xla/xla/stream_executor/stream_executor_test.cc @@ -0,0 +1,41 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/stream_executor.h" + +#include + +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace stream_executor { + +static std::unique_ptr NewStreamExecutor() { + Platform* platform = PlatformManager::PlatformWithName("Host").value(); + StreamExecutorConfig config(/*ordinal=*/0); + return platform->GetUncachedExecutor(config).value(); +} + +TEST(StreamExecutorTest, HostMemoryAllocate) { + auto executor = NewStreamExecutor(); + + TF_ASSERT_OK_AND_ASSIGN(auto allocation, executor->HostMemoryAllocate(1024)); + EXPECT_NE(allocation->opaque(), nullptr); + EXPECT_EQ(allocation->size(), 1024); +} + +} // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/stream_test.cc b/third_party/xla/xla/stream_executor/stream_test.cc index e726a771c53d2c..2c6f9f7506e250 100644 --- a/third_party/xla/xla/stream_executor/stream_test.cc +++ b/third_party/xla/xla/stream_executor/stream_test.cc @@ -13,7 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include "absl/log/check.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace stream_executor { @@ -22,7 +26,7 @@ namespace { class StreamTest : public ::testing::Test { protected: std::unique_ptr NewStreamExecutor() { - Platform* platform = MultiPlatformManager::PlatformWithName("Host").value(); + Platform* platform = PlatformManager::PlatformWithName("Host").value(); StreamExecutorConfig config(/*ordinal=*/0); return platform->GetUncachedExecutor(config).value(); } @@ -36,26 +40,33 @@ TEST_F(StreamTest, NoInitNotOk) { TEST_F(StreamTest, InitOk) { std::unique_ptr executor = NewStreamExecutor(); - Stream stream(executor.get()); - stream.Init(); - EXPECT_TRUE(stream.ok()); + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); +} + +TEST_F(StreamTest, InitWithIntPriorityOk) { + std::unique_ptr executor = NewStreamExecutor(); + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream(1)); +} + +TEST_F(StreamTest, InitWithStreamPriorityOk) { + std::unique_ptr executor = NewStreamExecutor(); + TF_ASSERT_OK_AND_ASSIGN(auto stream, + executor->CreateStream(StreamPriority::Highest)); } TEST_F(StreamTest, OneSubStream) { std::unique_ptr executor = NewStreamExecutor(); - Stream stream(executor.get()); - stream.Init(); - EXPECT_TRUE(stream.ok()); + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); // Get and return a sub-stream. Sub-streams are always initialized. - Stream* sub_stream1 = stream.GetOrCreateSubStream(); + TF_ASSERT_OK_AND_ASSIGN(Stream * sub_stream1, stream->GetOrCreateSubStream()); EXPECT_TRUE(sub_stream1->ok()); - stream.ReturnSubStream(sub_stream1); + stream->ReturnSubStream(sub_stream1); // Get and return another sub-stream. - Stream* sub_stream2 = stream.GetOrCreateSubStream(); + TF_ASSERT_OK_AND_ASSIGN(Stream * sub_stream2, stream->GetOrCreateSubStream()); EXPECT_TRUE(sub_stream2->ok()); - stream.ReturnSubStream(sub_stream1); + stream->ReturnSubStream(sub_stream1); // The underlying sub-streams should be the same, since sub_stream1 // was returned before we tried to get sub_stream2. @@ -64,14 +75,12 @@ TEST_F(StreamTest, OneSubStream) { TEST_F(StreamTest, TwoSubStreams) { std::unique_ptr executor = NewStreamExecutor(); - Stream stream(executor.get()); - stream.Init(); - EXPECT_TRUE(stream.ok()); + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); // Get two sub-streams. - Stream* sub_stream1 = stream.GetOrCreateSubStream(); + TF_ASSERT_OK_AND_ASSIGN(Stream * sub_stream1, stream->GetOrCreateSubStream()); EXPECT_TRUE(sub_stream1->ok()); - Stream* sub_stream2 = stream.GetOrCreateSubStream(); + TF_ASSERT_OK_AND_ASSIGN(Stream * sub_stream2, stream->GetOrCreateSubStream()); EXPECT_TRUE(sub_stream2->ok()); // The underlying sub-streams should be different, since neither @@ -79,123 +88,19 @@ TEST_F(StreamTest, TwoSubStreams) { EXPECT_NE(sub_stream1, sub_stream2); // Return sub_stream1 and get sub_stream3, which should be the same. - stream.ReturnSubStream(sub_stream1); - Stream* sub_stream3 = stream.GetOrCreateSubStream(); + stream->ReturnSubStream(sub_stream1); + TF_ASSERT_OK_AND_ASSIGN(Stream * sub_stream3, stream->GetOrCreateSubStream()); EXPECT_TRUE(sub_stream3->ok()); EXPECT_EQ(sub_stream1, sub_stream3); EXPECT_NE(sub_stream2, sub_stream3); // Return sub_stream2 and get sub_stream4, which should be the same. - stream.ReturnSubStream(sub_stream2); - Stream* sub_stream4 = stream.GetOrCreateSubStream(); + stream->ReturnSubStream(sub_stream2); + TF_ASSERT_OK_AND_ASSIGN(Stream * sub_stream4, stream->GetOrCreateSubStream()); EXPECT_TRUE(sub_stream4->ok()); EXPECT_EQ(sub_stream2, sub_stream4); EXPECT_NE(sub_stream3, sub_stream4); } -TEST_F(StreamTest, FailedSubStreamBeforeReturnNotReused) { - std::unique_ptr executor = NewStreamExecutor(); - Stream stream(executor.get()); - stream.Init(); - EXPECT_TRUE(stream.ok()); - - // Get sub_stream1. - Stream* sub_stream1 = stream.GetOrCreateSubStream(); - EXPECT_TRUE(sub_stream1->ok()); - - // Force an error on sub_stream1; here we call a method that requires DNN - // support, which we know the Host platform doesn't support. - sub_stream1->ThenDepthConcatenate({}, {}, nullptr); - EXPECT_FALSE(sub_stream1->ok()); - - // Return sub_stream1 and get sub_stream2. - stream.ReturnSubStream(sub_stream1); - Stream* sub_stream2 = stream.GetOrCreateSubStream(); - EXPECT_TRUE(sub_stream2->ok()); - - // The underlying sub_streams should be different. They would have been the - // same, but since we forced an error on sub_stream1, it will not be - // re-used. Sadly we can't just check: - // EXPECT_NE(sub_stream1, sub_stream2); - // - // The above should hold logically, but it may fail if the new Stream instance - // allocated for sub_stream2 happens to reside in the same memory address as - // sub_stream1. - // - // The check that sub_stream2->ok() serves as a good-enough check. - - // Return sub_stream2 and get sub_stream3. The previous error on sub_stream1 - // has no effect on these streams, and they are the same. - stream.ReturnSubStream(sub_stream2); - Stream* sub_stream3 = stream.GetOrCreateSubStream(); - EXPECT_TRUE(sub_stream3->ok()); - EXPECT_EQ(sub_stream2, sub_stream3); -} - -TEST_F(StreamTest, FailedSubStreamAfterReturnNotReused) { - std::unique_ptr executor = NewStreamExecutor(); - Stream stream(executor.get()); - stream.Init(); - EXPECT_TRUE(stream.ok()); - - // Get and return sub_stream1. - Stream* sub_stream1 = stream.GetOrCreateSubStream(); - EXPECT_TRUE(sub_stream1->ok()); - stream.ReturnSubStream(sub_stream1); - - // Force an error on sub_stream1; here we call a method that requires DNN - // support, which we know the Host platform doesn't support. - // - // It is a bit weird to use sub_stream1 after it has already been returned. By - // doing this, we're simulating an asynchronous error that occurs during - // execution of the sub_stream, that occurs after the sub_stream is returned. - // - // E.g. the following is a common pattern of usage, where the execution of the - // operations enqueued onto the sub streams may occur after the streams have - // already been returned. - // - // void EnqueueOnSubStreams(Stream* stream) { - // Stream* sub_stream1 = stream.GetOrCreateSubStream(); - // Stream* sub_stream2 = stream.GetOrCreateSubStream(); - // // ... enqueue some operations on the sub streams ... - // stream.ThenWaitFor(sub_stream1).ThenWaitFor(sub_stream2); - // stream.ReturnSubStream(sub_stream1); - // stream.ReturnSubStream(sub_stream2); - // } - // - // Stream* main_stream = ...; - // EnqueueOnSubStreams(main_stream); - // main_stream.BlockHostUntilDone(); - // - // TODO(b/112196569): The semantics of failed sub-streams is error-prone; - // GetOrCreateSubStream can still return a sub-stream that has not encountered - // an error yet, but will encounter one in the future, based on previously - // enqueued operations. - sub_stream1->ThenDepthConcatenate({}, {}, nullptr); - EXPECT_FALSE(sub_stream1->ok()); - - // Get and return sub_stream2. - Stream* sub_stream2 = stream.GetOrCreateSubStream(); - EXPECT_TRUE(sub_stream2->ok()); - - // The underlying streams should be different. They would have been the same, - // but since we forced an error on sub_stream1, it will not be re-used. Sadly - // we can't just check: - // EXPECT_NE(sub_stream1, sub_stream2); - // - // The above should hold logically, but it may fail if the new stream instance - // allocated for sub_stream2 happens to reside in the same memory address as - // sub_stream1. - // - // The check that sub_stream2->ok() serves as a good-enough check. - - // Return sub_stream2 and get sub_stream3. The previous error on sub_stream1 - // has no effect on these streams, and they are the same. - stream.ReturnSubStream(sub_stream2); - Stream* sub_stream3 = stream.GetOrCreateSubStream(); - EXPECT_TRUE(sub_stream3->ok()); - EXPECT_EQ(sub_stream2, sub_stream3); -} - } // namespace } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/temporary_device_memory.cc b/third_party/xla/xla/stream_executor/temporary_device_memory.cc deleted file mode 100644 index f3816860d037a6..00000000000000 --- a/third_party/xla/xla/stream_executor/temporary_device_memory.cc +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright 2015 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/stream_executor/temporary_device_memory.h" - -#include "xla/stream_executor/stream.h" - -namespace stream_executor { - -TemporaryDeviceMemoryBase::~TemporaryDeviceMemoryBase() { - parent_->parent()->Deallocate(&device_memory_); -} - -DeviceMemoryBase* TemporaryDeviceMemoryBase::mutable_device_memory() { - return &device_memory_; -} - -const DeviceMemoryBase& TemporaryDeviceMemoryBase::device_memory() const { - return device_memory_; -} - -TemporaryDeviceMemoryBase::TemporaryDeviceMemoryBase( - Stream* parent, DeviceMemoryBase device_memory) - : device_memory_(device_memory), parent_(parent) {} - -} // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/temporary_device_memory.h b/third_party/xla/xla/stream_executor/temporary_device_memory.h deleted file mode 100644 index 7456f825365ed6..00000000000000 --- a/third_party/xla/xla/stream_executor/temporary_device_memory.h +++ /dev/null @@ -1,94 +0,0 @@ -/* Copyright 2015 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Temporary memories are used to allocate scratch space required by an -// operation about to be enqueued onto a stream. -// -// std::unique_ptr> temporary_memory = -// stream.AllocateTemporaryArray(1024).value(); -// // ... enqueue stuff onto the stream using the temporary memory ... -// // Note that the memory is accessible via -// // temporary_memory->device_memory() and similar. -// -// Note that standard usage takes advantage of the type-safe wrapper, -// TemporaryDeviceMemory, defined below. -// -// Also see tests for executable sample usage. - -#ifndef XLA_STREAM_EXECUTOR_TEMPORARY_DEVICE_MEMORY_H_ -#define XLA_STREAM_EXECUTOR_TEMPORARY_DEVICE_MEMORY_H_ - -#include "xla/stream_executor/device_memory.h" - -namespace stream_executor { - -class Stream; - -// Untyped base class (analogous to a void*) for temporary device memory -// allocations associated with a stream. -class TemporaryDeviceMemoryBase { - public: - // Marks the temporary memory as finalized if it is not already marked as - // such. - ~TemporaryDeviceMemoryBase(); - - // Precondition: !IsFinalized() - DeviceMemoryBase* mutable_device_memory(); - - // Precondition: !IsFinalized() - const DeviceMemoryBase& device_memory() const; - - // Note: construction DCHECKs that the memory is known-allocated in the - // stream's temporary-allocation-manager. - TemporaryDeviceMemoryBase(Stream* parent, DeviceMemoryBase device_memory); - - private: - // The device memory region that has allocated. - DeviceMemoryBase device_memory_; - - // The stream that this temporary memory was allocated for. - Stream* parent_; -}; - -// Type-safe wrapper around the base type (which is analogous to a void*). -template -class TemporaryDeviceMemory : public TemporaryDeviceMemoryBase { - public: - // Type-safe wrapper around TemporaryDeviceMemoryBase::mutable_device_memory. - DeviceMemory* mutable_device_memory() { - StaticSlicingAssertionDummy(); - return reinterpret_cast*>( - TemporaryDeviceMemoryBase::mutable_device_memory()); - } - - // Type-safe wrapper around TemporaryDeviceMemoryBase::device_memory. - const DeviceMemory& device_memory() const { - StaticSlicingAssertionDummy(); - return reinterpret_cast&>( - TemporaryDeviceMemoryBase::device_memory()); - } - - private: - static void StaticSlicingAssertionDummy() { - static_assert( - sizeof(TemporaryDeviceMemory) == sizeof(TemporaryDeviceMemoryBase), - "derived class is simply a wrapper, no members may be added due to " - "slicing"); - } -}; - -} // namespace stream_executor - -#endif // XLA_STREAM_EXECUTOR_TEMPORARY_DEVICE_MEMORY_H_ diff --git a/third_party/xla/xla/stream_executor/tpu/BUILD b/third_party/xla/xla/stream_executor/tpu/BUILD index a758bb89c1344b..15c4f6856151a7 100644 --- a/third_party/xla/xla/stream_executor/tpu/BUILD +++ b/third_party/xla/xla/stream_executor/tpu/BUILD @@ -1,11 +1,25 @@ # Description: StreamExecutor Interface for TPUs load("//xla:xla.bzl", "xla_cc_test") -load("@local_tsl//tsl:tsl.bzl", "set_external_visibility") +load("@local_tsl//tsl:tsl.bzl", "internal_visibility") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//learning/brain/experimental/dtensor:__subpackages__", + "//learning/brain/google/xla/kernels:__subpackages__", + "//learning/brain/research/pjrt:__subpackages__", + "//learning/brain/tfrc/executor:__subpackages__", + "//learning/brain/tfrt/tpu_plugin:__subpackages__", + "//tensorflow/compiler/jit:__subpackages__", + "//tensorflow/compiler/mlir:__subpackages__", + "//xla:__subpackages__", + "//xla/backends/profiler/tpu:__subpackages__", + "//tensorflow/core/common_runtime/next_pluggable_device:__subpackages__", + "//tensorflow/core/tpu:__subpackages__", + "//tensorflow/dtensor:__subpackages__", + ]), licenses = ["notice"], ) @@ -36,7 +50,6 @@ cc_library( name = "c_api_conversions", srcs = ["c_api_conversions.cc"], hdrs = ["c_api_conversions.h"], - visibility = ["//visibility:public"], deps = [ ":c_api_decl", ":proto_helper", @@ -114,7 +127,6 @@ cc_library( cc_library( name = "status_helper", hdrs = ["status_helper.h"], - visibility = ["//visibility:public"], deps = [ ":c_api_decl", ":tpu_executor_api", @@ -127,7 +139,6 @@ cc_library( cc_library( name = "tsl_status_helper", hdrs = ["tsl_status_helper.h"], - visibility = ["//visibility:public"], deps = [ ":c_api_decl", "@com_google_absl//absl/status", @@ -141,7 +152,6 @@ cc_library( name = "proto_helper", srcs = ["proto_helper.cc"], hdrs = ["proto_helper.h"], - visibility = ["//visibility:public"], deps = [ ":c_api_decl", "@local_tsl//tsl/platform:logging", @@ -216,7 +226,6 @@ cc_library( cc_library( name = "tpu_platform_hdr", hdrs = ["tpu_platform.h"], - visibility = ["//visibility:public"], deps = [ ":c_api_decl", ":tpu_executor_c_api_hdrs", @@ -243,7 +252,6 @@ cc_library( "tpu_stream.h", "tpu_stream_interface.h", ], - visibility = ["//visibility:public"], deps = [ ":c_api_conversions", ":c_api_decl", @@ -270,7 +278,6 @@ cc_library( name = "tpu_platform_id", srcs = ["tpu_platform_id.cc"], hdrs = ["tpu_platform_id.h"], - visibility = ["//visibility:public"], deps = ["//xla/stream_executor:platform"], ) @@ -286,7 +293,6 @@ cc_library( "tpu_platform.h", "tpu_stream.h", ], - visibility = ["//visibility:public"], deps = [ ":c_api_conversions", ":c_api_decl", @@ -342,7 +348,6 @@ cc_library( name = "tpu_node_context", srcs = ["tpu_node_context.cc"], hdrs = ["tpu_node_context.h"], - visibility = ["//visibility:public"], deps = [ ":status_helper", ":tpu_api", @@ -402,7 +407,6 @@ cc_library( name = "tpu_transfer_manager_base", srcs = ["tpu_transfer_manager.cc"], hdrs = ["tpu_transfer_manager.h"], - visibility = ["//visibility:public"], deps = [ ":c_api_conversions", ":c_api_decl", @@ -433,7 +437,6 @@ cc_library( name = "tpu_op_executable", srcs = ["tpu_op_executable.cc"], hdrs = ["tpu_op_executable.h"], - visibility = ["//visibility:public"], deps = [ ":c_api_conversions", ":c_api_decl", @@ -608,7 +611,6 @@ cc_library( name = "tpu_executor_api", srcs = ["tpu_executor_api.cc"], hdrs = ["tpu_executor_api.h"], - visibility = ["//visibility:public"], deps = [ ":libtftpu_header", ":tpu_executor_c_api_hdrs", @@ -618,7 +620,6 @@ cc_library( cc_library( name = "tpu_profiler_c_api_hdrs", hdrs = ["tpu_profiler_c_api.h"], - visibility = ["//visibility:public"], deps = [ ":c_api_decl", ":libtftpu_header", diff --git a/third_party/xla/xla/stream_executor/tpu/c_api_conversions.cc b/third_party/xla/xla/stream_executor/tpu/c_api_conversions.cc index 8b0aa4d4d11c0c..07f68baa2b093d 100644 --- a/third_party/xla/xla/stream_executor/tpu/c_api_conversions.cc +++ b/third_party/xla/xla/stream_executor/tpu/c_api_conversions.cc @@ -534,6 +534,8 @@ XLA_HloModuleConfig ToC(const xla::HloModuleConfig& config) { hlo_config.num_partitions = config.num_partitions(); hlo_config.use_spmd_partitioning = config.use_spmd_partitioning(); hlo_config.use_auto_spmd_partitioning = config.use_auto_spmd_partitioning(); + CreateVector(config.allow_spmd_sharding_propagation_to_parameters(), + &hlo_config.allow_spmd_sharding_propagation_to_parameters); CreateVector(config.allow_spmd_sharding_propagation_to_output(), &hlo_config.allow_spmd_sharding_propagation_to_output); CreateVector(config.auto_spmd_partitioning_mesh_shape(), @@ -582,6 +584,8 @@ xla::HloModuleConfig FromC(const XLA_HloModuleConfig& c_config) { config.set_num_partitions(c_config.num_partitions); config.set_use_spmd_partitioning(c_config.use_spmd_partitioning); config.set_use_auto_spmd_partitioning(c_config.use_auto_spmd_partitioning); + config.set_allow_spmd_sharding_propagation_to_parameters( + MakeSpan(c_config.allow_spmd_sharding_propagation_to_parameters)); config.set_allow_spmd_sharding_propagation_to_output( MakeSpan(c_config.allow_spmd_sharding_propagation_to_output)); absl::Span mesh_shape_span = diff --git a/third_party/xla/xla/stream_executor/tpu/c_api_decl.h b/third_party/xla/xla/stream_executor/tpu/c_api_decl.h index 5681eb63086f94..0e9019f2c0ad2c 100644 --- a/third_party/xla/xla/stream_executor/tpu/c_api_decl.h +++ b/third_party/xla/xla/stream_executor/tpu/c_api_decl.h @@ -322,6 +322,7 @@ typedef struct XLA_HloModuleConfig { TpuSerializedProto static_device_assignment; bool has_entry_computation_layout; XLA_ComputationLayout entry_computation_layout; + BoolList allow_spmd_sharding_propagation_to_parameters; BoolList allow_spmd_sharding_propagation_to_output; } XLA_HloModuleConfig; diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executable_interface.cc b/third_party/xla/xla/stream_executor/tpu/tpu_executable_interface.cc index bb85e4189afacf..5890e4e5720436 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executable_interface.cc +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executable_interface.cc @@ -60,7 +60,7 @@ static Status PopulateResultTupleBuffers(const ShapedBuffer& result, TF_RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync( transfer_stream ? transfer_stream : stream, result)); if (transfer_stream && transfer_stream != stream) { - stream->ThenWaitFor(transfer_stream); + TF_RETURN_IF_ERROR(stream->WaitFor(transfer_stream)); } return absl::OkStatus(); } else { diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executor.h b/third_party/xla/xla/stream_executor/tpu/tpu_executor.h index 289233e6e3c562..c9e694844b8e06 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executor.h +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executor.h @@ -167,10 +167,6 @@ class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface { } // -- Unimplemented (stubbed out) methods. - std::unique_ptr - CreateKernelImplementation() override { - LOG(FATAL) << "Not yet implemented"; - } absl::Status MemZero(Stream* stream, DeviceMemoryBase* location, uint64_t size) override { diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_platform.cc b/third_party/xla/xla/stream_executor/tpu/tpu_platform.cc index d739bb25a847ba..d346c510d90069 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_platform.cc +++ b/third_party/xla/xla/stream_executor/tpu/tpu_platform.cc @@ -26,8 +26,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_internal.h" #include "xla/stream_executor/tpu/c_api_decl.h" @@ -211,7 +211,7 @@ bool RegisterTpuPlatform() { tpu_registered_platform = new TpuPlatform(); std::unique_ptr platform( tpu_registered_platform); - TF_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform( + TF_CHECK_OK(stream_executor::PlatformManager::RegisterPlatform( std::move(platform))); tpu_platform_registered = true; } diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_platform_interface.cc b/third_party/xla/xla/stream_executor/tpu/tpu_platform_interface.cc index 1036ad6cb956d1..c7df6619342830 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_platform_interface.cc +++ b/third_party/xla/xla/stream_executor/tpu/tpu_platform_interface.cc @@ -16,8 +16,8 @@ limitations under the License. #include "xla/stream_executor/tpu/tpu_platform_interface.h" #include "absl/synchronization/mutex.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "tsl/platform/env.h" #include "tsl/platform/logging.h" // IWYU pragma: keep #include "tsl/protobuf/error_codes.pb.h" @@ -32,8 +32,8 @@ TpuPlatformInterface* GetRegisteredPlatformStatic(bool initialize_platform, DCHECK_GT(tries_left, 0); // Prefer TpuPlatform if it's registered. auto status_or_tpu_platform = - stream_executor::MultiPlatformManager::PlatformWithName( - "TPU", initialize_platform); + stream_executor::PlatformManager::PlatformWithName("TPU", + initialize_platform); if (status_or_tpu_platform.ok()) { return static_cast(status_or_tpu_platform.value()); } @@ -45,7 +45,7 @@ TpuPlatformInterface* GetRegisteredPlatformStatic(bool initialize_platform, // Use any other registered TPU platform. auto status_or_other_tpu_platforms = - stream_executor::MultiPlatformManager::PlatformsWithFilter( + stream_executor::PlatformManager::PlatformsWithFilter( [](const stream_executor::Platform* platform) { return dynamic_cast(platform) != nullptr; diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_platform_registration.cc b/third_party/xla/xla/stream_executor/tpu/tpu_platform_registration.cc index 28450675347562..c0c15a98f6e6a9 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_platform_registration.cc +++ b/third_party/xla/xla/stream_executor/tpu/tpu_platform_registration.cc @@ -17,7 +17,7 @@ limitations under the License. #include "xla/stream_executor/tpu/tpu_platform.h" #if defined(PLATFORM_GOOGLE) -REGISTER_MODULE_INITIALIZER(tpu_platform, - tensorflow::tpu::RegisterTpuPlatform()); +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER( + tpu_platform, tensorflow::tpu::RegisterTpuPlatform()); #endif diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 9bc35c5420abf9..7395d6c62940e1 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -11,7 +11,7 @@ load( "@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured", ) -load("@local_tsl//tsl:tsl.bzl", "tsl_copts") +load("@local_tsl//tsl:tsl.bzl", "internal_visibility", "tsl_copts") load("@local_tsl//tsl:tsl.default.bzl", "filegroup") load( "@local_tsl//tsl/platform:build_config_root.bzl", @@ -24,7 +24,8 @@ load( ) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) @@ -43,7 +44,6 @@ filegroup( "**/*.cc", "**/*.h", ]), - visibility = ["//visibility:public"], ) # Generate test_suites for all backends, named "${backend}_tests". @@ -55,7 +55,6 @@ cc_library( name = "xla_internal_test_main", testonly = True, srcs = ["xla_internal_test_main.cc"], - visibility = ["//visibility:public"], deps = [ "//xla:debug_options_flags", "@com_google_absl//absl/flags:flag", @@ -71,7 +70,6 @@ cc_library( name = "test_macros_header", testonly = True, hdrs = ["test_macros.h"], - visibility = ["//visibility:public"], ) # Generate a test_macros_${BACKEND} library per backend with the proper copts. @@ -82,7 +80,6 @@ cc_library( testonly = True, srcs = ["manifest_checking_test.cc"], hdrs = ["manifest_checking_test.h"], - visibility = ["//visibility:public"], deps = [ ":test_macros_header", "@com_google_absl//absl/container:flat_hash_map", @@ -97,7 +94,6 @@ cc_library( name = "test_utils", srcs = ["test_utils.cc"], hdrs = ["test_utils.h"], - visibility = ["//visibility:public"], deps = [ "//xla:literal", "//xla:literal_util", @@ -116,7 +112,6 @@ cc_library( testonly = True, srcs = ["literal_test_util.cc"], hdrs = ["literal_test_util.h"], - visibility = ["//visibility:public"], deps = [ "//xla:array2d", "//xla:array3d", @@ -144,7 +139,6 @@ cc_library( testonly = True, srcs = ["verified_hlo_module.cc"], hdrs = ["verified_hlo_module.h"], - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla:status_macros", @@ -166,7 +160,6 @@ cc_library( name = "pjrt_client_registry", srcs = ["pjrt_client_registry.cc"], hdrs = ["pjrt_client_registry.h"], - visibility = ["//visibility:public"], deps = [ "//xla/pjrt:pjrt_client", ], @@ -178,7 +171,6 @@ cc_library( srcs = [ "pjrt_cpu_client_registry.cc", ], - visibility = ["//visibility:public"], deps = [ ":pjrt_client_registry", "//xla/pjrt/cpu:cpu_client", @@ -191,7 +183,6 @@ cc_library( srcs = [ "pjrt_gpu_client_registry.cc", ], - visibility = ["//visibility:public"], deps = [ ":pjrt_client_registry", "//xla/pjrt/gpu:gpu_helpers", @@ -204,7 +195,6 @@ cc_library( testonly = True, srcs = ["hlo_test_base.cc"], hdrs = ["hlo_test_base.h"], - visibility = ["//visibility:public"], deps = [ ":filecheck", ":literal_test_util", @@ -272,7 +262,6 @@ cc_library( testonly = True, srcs = ["client_library_test_base.cc"], hdrs = ["client_library_test_base.h"], - visibility = ["//visibility:public"], deps = [ ":literal_test_util", ":manifest_checking_test", @@ -287,6 +276,7 @@ cc_library( "//xla:status_macros", "//xla:statusor", "//xla:test_helpers", + "//xla:types", "//xla:xla_data_proto_cc", "//xla/client:client_library", "//xla/client:global_data", @@ -310,7 +300,6 @@ cc_library( testonly = True, srcs = ["llvm_irgen_test_base.cc"], hdrs = ["llvm_irgen_test_base.h"], - visibility = ["//visibility:public"], deps = [ ":codegen_test_base", ":filecheck", @@ -326,7 +315,6 @@ cc_library( testonly = True, srcs = ["codegen_test_base.cc"], hdrs = ["codegen_test_base.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_test_base", "//xla/hlo/ir:hlo", @@ -343,7 +331,6 @@ cc_library( data = [ "@llvm-project//llvm:FileCheck", ], - visibility = ["//visibility:public"], deps = [ "//xla:statusor", "//xla:types", @@ -362,7 +349,6 @@ cc_library( testonly = True, srcs = ["local_client_test_base.cc"], hdrs = ["local_client_test_base.h"], - visibility = ["//visibility:public"], deps = [ ":client_library_test_base", ":manifest_checking_test", @@ -564,7 +550,7 @@ xla_test( # Hlo profiles are not supported on the interpreter backend. "interpreter", ], - tags = ["no_arm64"], + tags = ["no_aarch64"], deps = [ ":client_library_test_base", ":test_macros_header", @@ -793,6 +779,7 @@ xla_test( "//xla:array2d", "//xla:array3d", "//xla:array4d", + "//xla:comparison_util", "//xla:literal", "//xla:shape_util", "//xla:statusor", @@ -803,6 +790,8 @@ xla_test( "//xla/client:xla_builder", "@com_google_absl//absl/base", "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:ml_dtypes", + "@ml_dtypes//:float8", ], ) @@ -811,7 +800,6 @@ cc_library( testonly = True, srcs = ["conv_depthwise_common.cc"], hdrs = ["conv_depthwise_common.h"], - visibility = ["//visibility:public"], deps = [ ":client_library_test_base", ":hlo_test_base", @@ -1796,6 +1784,7 @@ xla_test( "//xla:test", "//xla/client:local_client", "//xla/client:xla_builder", + "@com_google_absl//absl/strings:string_view", ], ) @@ -2530,6 +2519,7 @@ xla_test( "//xla/service:transfer_manager", "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:platform_manager", "//xla/stream_executor/host:host_platform", "//xla/stream_executor/host:host_platform_id", "@local_tsl//tsl/platform:env", @@ -2747,6 +2737,7 @@ xla_cc_test( "//xla/client:client_library", "//xla/client:xla_builder", "//xla/service:cpu_plugin", + "//xla/stream_executor:platform_manager", "@com_google_absl//absl/synchronization", "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", diff --git a/third_party/xla/xla/tests/array_elementwise_ops_test.cc b/third_party/xla/xla/tests/array_elementwise_ops_test.cc index a19d01afbc336c..d4fc025dc175a2 100644 --- a/third_party/xla/xla/tests/array_elementwise_ops_test.cc +++ b/third_party/xla/xla/tests/array_elementwise_ops_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -27,20 +28,24 @@ limitations under the License. #include "absl/base/casts.h" #include "absl/types/span.h" +#include "ml_dtypes/include/float8.h" // from @ml_dtypes #include "xla/array2d.h" #include "xla/array3d.h" #include "xla/array4d.h" #include "xla/client/global_data.h" #include "xla/client/local_client.h" #include "xla/client/xla_builder.h" +#include "xla/comparison_util.h" #include "xla/layout_util.h" #include "xla/literal.h" +#include "xla/primitive_util.h" #include "xla/statusor.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/types.h" +#include "tsl/platform/ml_dtypes.h" #if TENSORFLOW_USE_ROCM #include "rocm/rocm_config.h" @@ -1293,14 +1298,140 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) { ComputeAndCompareR1(&builder, {false, false, true, false, false}, {}); } -XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32sTO) { - SetFastMathDisabled(true); - XlaBuilder builder(TestName()); - auto lhs = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f}); - auto rhs = ConstantR1(&builder, {10.0f, 5.0f, 2.25f, NAN, NAN}); - EqTotalOrder(lhs, rhs); +template +class TotalOrderTest : public ClientLibraryTestBase { + public: + void DoIt(ComparisonDirection direction) { + this->SetFastMathDisabled(true); + XlaBuilder builder(this->TestName()); + std::vector values = { + static_cast(0.0f), + std::numeric_limits::min(), + static_cast(1.0f), + std::numeric_limits::max(), + }; + if constexpr (std::numeric_limits::has_denorm) { + auto denorm = static_cast(std::numeric_limits::denorm_min()); + if (denorm >= std::numeric_limits::min()) { + values.push_back(std::numeric_limits::denorm_min()); + } + } + if constexpr (std::is_same_v || std::is_same_v) { + values.push_back(std::fabs(std::numeric_limits::quiet_NaN())); + } + if constexpr (std::numeric_limits::has_infinity) { + values.push_back(std::numeric_limits::infinity()); + } +#if defined(XLA_TEST_BACKEND_CPU) || defined(XLA_TEST_BACKEND_GPU) || \ + defined(XLA_TEST_BACKEND_INTERPRETER) + if constexpr (std::numeric_limits::has_quiet_NaN) { + values.push_back(Eigen::numext::abs(std::numeric_limits::quiet_NaN())); + } +#endif + values.reserve(values.size() * 2); + for (size_t i = 0, n = values.size(); i < n; ++i) { + auto value = values[i]; + auto neg = -value; + if (Eigen::numext::signbit(neg) != Eigen::numext::signbit(value)) { + values.push_back(neg); + } + } + std::vector lhs_data; + std::vector rhs_data; + lhs_data.reserve(values.size() * values.size()); + rhs_data.reserve(values.size() * values.size()); + for (T lhs_value : values) { + for (T rhs_value : values) { + lhs_data.push_back(lhs_value); + rhs_data.push_back(rhs_value); + } + } + absl::InlinedVector results; + results.reserve(lhs_data.size()); + Comparison comparison(direction, primitive_util::NativeToPrimitiveType(), + Comparison::Order::kTotal); + for (size_t i = 0; i < lhs_data.size(); ++i) { + results.push_back(comparison.Compare(lhs_data[i], rhs_data[i])); + } + auto lhs = ConstantR1(&builder, lhs_data); + auto rhs = ConstantR1(&builder, rhs_data); + switch (direction) { + case ComparisonDirection::kEq: + EqTotalOrder(lhs, rhs); + break; + case ComparisonDirection::kNe: + NeTotalOrder(lhs, rhs); + break; + case ComparisonDirection::kGt: + GtTotalOrder(lhs, rhs); + break; + case ComparisonDirection::kGe: + GeTotalOrder(lhs, rhs); + break; + case ComparisonDirection::kLt: + LtTotalOrder(lhs, rhs); + break; + case ComparisonDirection::kLe: + LeTotalOrder(lhs, rhs); + break; + } + + this->ComputeAndCompareR1(&builder, results, {}); + } +}; - ComputeAndCompareR1(&builder, {false, false, true, true, false}, {}); +using Types = ::testing::Types; + +TYPED_TEST_SUITE(TotalOrderTest, Types); + +TYPED_TEST(TotalOrderTest, Eq) { this->DoIt(ComparisonDirection::kEq); } +TYPED_TEST(TotalOrderTest, Ne) { this->DoIt(ComparisonDirection::kNe); } +TYPED_TEST(TotalOrderTest, Le) { this->DoIt(ComparisonDirection::kLe); } +TYPED_TEST(TotalOrderTest, Lt) { this->DoIt(ComparisonDirection::kLt); } +TYPED_TEST(TotalOrderTest, Ge) { this->DoIt(ComparisonDirection::kGe); } +TYPED_TEST(TotalOrderTest, Gt) { this->DoIt(ComparisonDirection::kGt); } +TYPED_TEST(TotalOrderTest, LargeMagnitudeVsNaN) { + using T = TypeParam; + if constexpr (!std::numeric_limits::has_quiet_NaN) { + GTEST_SKIP(); + } + this->SetFastMathDisabled(true); + + XlaBuilder builder(this->TestName()); + std::vector values = { + static_cast(0.0f), + std::numeric_limits::min(), + static_cast(1.0f), + std::numeric_limits::max(), + }; + if constexpr (std::numeric_limits::has_infinity) { + values.push_back(std::numeric_limits::infinity()); + } + for (size_t i = 0, n = values.size(); i < n; ++i) { + auto value = values[i]; + auto neg = -value; + if (Eigen::numext::signbit(neg) != Eigen::numext::signbit(value)) { + values.push_back(neg); + } + } + auto lhs = ConstantR1(&builder, values); + auto rhs = ConstantR1( + &builder, + std::vector(values.size(), std::numeric_limits::quiet_NaN())); + LtTotalOrder(lhs, rhs); + TF_ASSERT_OK_AND_ASSIGN(auto result, this->ComputeAndTransfer(&builder, {})); + EXPECT_TRUE(result.IsAll(0) || result.IsAll(1)) << result.ToString(); } XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) { @@ -1322,23 +1453,6 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) { ComputeAndCompareR1(&builder, {false, true, true, false, false}, {}); } -XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32sTO) { - SetFastMathDisabled(true); - XlaBuilder builder(TestName()); - // For portability, need to represent NAN using the following call. - // The C++ standard does not specify if quiet_NaN() sets the sign bit of - // its result. The call to std::fabs will ensure that it is not set. - auto kNaN = std::fabs(std::numeric_limits::quiet_NaN()); - auto lhs = - ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, kNaN, 6.0f, 6.0f}); - auto rhs = - ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, kNaN, -kNaN}); - GeTotalOrder(lhs, rhs); - - ComputeAndCompareR1(&builder, {false, true, true, true, false, true}, - {}); -} - XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); @@ -2716,6 +2830,19 @@ XLA_TEST_F(ArrayElementwiseOpTest, Atan2C64s) { ComputeAndCompare(&builder, {}, error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, ErfF32s) { + XlaBuilder builder(TestName()); + auto kInf = std::numeric_limits::infinity(); + auto kNaN = std::numeric_limits::quiet_NaN(); + auto a = ConstantR1( + &builder, {-kInf, -2.5f, 3.14f, -0.0f, 0.0f, 2.25f, kInf, kNaN}); + + Erf(a); + + ErrorSpec error_spec{1e-5f, 1e-5f}; + ComputeAndCompare(&builder, {}, error_spec); +} + XLA_TEST_F(ArrayElementwiseOpTest, TanhF32s) { XlaBuilder builder(TestName()); auto kInf = std::numeric_limits::infinity(); diff --git a/third_party/xla/xla/tests/broadcast_simple_test.cc b/third_party/xla/xla/tests/broadcast_simple_test.cc index 9d2a6734c1aa81..7e4154ac01b4d7 100644 --- a/third_party/xla/xla/tests/broadcast_simple_test.cc +++ b/third_party/xla/xla/tests/broadcast_simple_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "xla/array2d.h" #include "xla/array4d.h" #include "xla/client/local_client.h" @@ -34,6 +35,9 @@ namespace { class BroadcastSimpleTest : public ClientLibraryTestBase { public: + static constexpr absl::string_view kIncompatibleBinaryOpShapeErrorMessage = + "Binary op with incompatible shapes"; + XlaOp BuildBinOp(HloOpcode op, const XlaOp lhs, const XlaOp rhs, XlaBuilder* builder) { switch (op) { @@ -753,7 +757,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) { auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); EXPECT_THAT(result_status.status().message(), - HasSubstr("op add with incompatible shapes")); + HasSubstr(kIncompatibleBinaryOpShapeErrorMessage)); } XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { @@ -766,7 +770,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); EXPECT_THAT(result_status.status().message(), - HasSubstr("op add with incompatible shapes")); + HasSubstr(kIncompatibleBinaryOpShapeErrorMessage)); } } // namespace diff --git a/third_party/xla/xla/tests/client_library_test_base.h b/third_party/xla/xla/tests/client_library_test_base.h index 475071de8631c0..6ed2cac8a238c4 100644 --- a/third_party/xla/xla/tests/client_library_test_base.h +++ b/third_party/xla/xla/tests/client_library_test_base.h @@ -37,6 +37,7 @@ limitations under the License. #include "xla/tests/literal_test_util.h" #include "xla/tests/manifest_checking_test.h" #include "xla/tests/test_utils.h" +#include "xla/types.h" #include "xla/xla_data.pb.h" #include "tsl/lib/core/bitmap.h" #include "tsl/platform/ml_dtypes.h" @@ -444,6 +445,10 @@ class ClientLibraryTestBase : public ManifestCheckingTest { // Arguments to be passed to the computation when it runs. std::vector arguments_; + + template + static constexpr inline bool is_floating_or_complex_v = + std::disjunction_v, is_complex>; }; template @@ -459,17 +464,7 @@ template void ClientLibraryTestBase::ComputeAndCompareR0( XlaBuilder* builder, NativeT expected, absl::Span arguments, ErrorSpec error) { - static_assert(std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value, + static_assert(is_floating_or_complex_v, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateR0(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, @@ -489,17 +484,7 @@ template void ClientLibraryTestBase::ComputeAndCompareR1( XlaBuilder* builder, absl::Span expected, absl::Span arguments, ErrorSpec error) { - static_assert(std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value, + static_assert(is_floating_or_complex_v, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, @@ -520,17 +505,7 @@ template void ClientLibraryTestBase::ComputeAndCompareR2( XlaBuilder* builder, const Array2D& expected, absl::Span arguments, ErrorSpec error) { - static_assert(std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value, + static_assert(is_floating_or_complex_v, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateR2FromArray2D(expected); @@ -552,17 +527,7 @@ template void ClientLibraryTestBase::ComputeAndCompareR3( XlaBuilder* builder, const Array3D& expected, absl::Span arguments, ErrorSpec error) { - static_assert(std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value, + static_assert(is_floating_or_complex_v, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateR3FromArray3D(expected); @@ -584,17 +549,7 @@ template void ClientLibraryTestBase::ComputeAndCompareR4( XlaBuilder* builder, const Array4D& expected, absl::Span arguments, ErrorSpec error) { - static_assert(std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value, + static_assert(is_floating_or_complex_v, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateR4FromArray4D(expected); @@ -615,17 +570,7 @@ template void ClientLibraryTestBase::ComputeAndCompare( XlaBuilder* builder, const Array& expected, absl::Span arguments, ErrorSpec error) { - static_assert(std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value, + static_assert(is_floating_or_complex_v, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateFromArray(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, diff --git a/third_party/xla/xla/tests/convert_test.cc b/third_party/xla/xla/tests/convert_test.cc index ab9f9d3eaefece..91d545023c6374 100644 --- a/third_party/xla/xla/tests/convert_test.cc +++ b/third_party/xla/xla/tests/convert_test.cc @@ -43,7 +43,9 @@ class ConvertTest : public ClientLibraryTestBase { : ClientLibraryTestBase(platform) { mutable_debug_options()->add_xla_disable_hlo_passes("algsimp"); mutable_debug_options()->add_xla_disable_hlo_passes("inline"); - mutable_debug_options()->set_xla_gpu_simplify_all_fp_conversions(false); + mutable_debug_options()->add_xla_disable_hlo_passes( + "simplify-fp-conversions"); + mutable_debug_options()->set_xla_allow_excess_precision(false); } }; diff --git a/third_party/xla/xla/tests/dot_operation_test.cc b/third_party/xla/xla/tests/dot_operation_test.cc index 89a9b5987e5bb9..3430881bd8eeee 100644 --- a/third_party/xla/xla/tests/dot_operation_test.cc +++ b/third_party/xla/xla/tests/dot_operation_test.cc @@ -289,6 +289,34 @@ std::string PrintDotTestParam( class ParametricDotTest : public DotOperationTest, public ::testing::WithParamInterface { protected: + // This method runs before each test runs. + void SetUp() override { + // Several F16 tests are subject to denormal issues on MI210 architecture. + // For that matter, we set propagate_grad_xy_ flag for these tests, which + // activates adapted GEMM algorithm on ROCM. Besides, the adapted algorithm + // does not work well with ROCBLAS autotuning, hence we also disable it. + // This also serves as a test that grad_x/y attributes are correctly + // propagated down to a GEMM routine. + const auto& gpu_comp = client_->backend() + .default_stream_executor() + ->GetDeviceDescription() + .gpu_compute_capability(); + if (std::holds_alternative(gpu_comp)) { + std::string_view name( + ::testing::UnitTest::GetInstance()->current_test_info()->name()); + if (name.find("TestF16/270x270x520_MajorToMinor") != std::string::npos) { + execution_options_.mutable_debug_options()->set_xla_gpu_autotune_level( + 0); + DotTestParam param = GetParam(); + // In order to test both grad_x and grad_y attributes, we set + // propagate_grad_xy_ to 1 or 2 based on some alternating parameter + // to set it deterministically. + propagate_grad_xy_ = param.dot_lhs_row_major ? 1 : 2; + } + } + ManifestCheckingTest::SetUp(); + } + template void TestImpl(); @@ -296,6 +324,8 @@ class ParametricDotTest : public DotOperationTest, void ComputeAndCompareR2WithError(XlaBuilder* builder, const Array2D& expected, absl::Span arguments); + + int32_t propagate_grad_xy_ = 0; }; template @@ -356,6 +386,15 @@ void ParametricDotTest::TestImpl() { XlaBuilder builder(TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); + + if (propagate_grad_xy_ != 0) { + FrontendAttributes attributes; + if (propagate_grad_xy_ == 1) + (*attributes.mutable_map())["grad_x"] = "true"; + else + (*attributes.mutable_map())["grad_y"] = "true"; + builder.SetFrontendAttributes(attributes); + } auto result = Dot(Parameter(&builder, 0, ShapeUtil::MakeShapeWithDenseLayout( @@ -367,6 +406,9 @@ void ParametricDotTest::TestImpl() { prim_type, {param.k, param.n}, MinorToMajorForIsRowMajor(param.dot_rhs_row_major)), "dot_rhs")); + if (propagate_grad_xy_ != 0) { + builder.ClearFrontendAttributes(); + } if (param.has_addend) { result = diff --git a/third_party/xla/xla/tests/exhaustive/BUILD b/third_party/xla/xla/tests/exhaustive/BUILD index 875bceabc7837b..1b3ee7e074d852 100644 --- a/third_party/xla/xla/tests/exhaustive/BUILD +++ b/third_party/xla/xla/tests/exhaustive/BUILD @@ -5,7 +5,8 @@ load("//xla/tests:build_defs.bzl", "xla_test") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], licenses = ["notice"], ) @@ -25,7 +26,6 @@ cc_library( srcs = ["exhaustive_op_test_utils.cc"], hdrs = ["exhaustive_op_test_utils.h"], tags = ["no_pip"], - visibility = ["//visibility:public"], deps = [ "//xla:bit_cast", "//xla:shape_util", diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h index afe8b10f5e4db9..363ddb357e1458 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h @@ -204,12 +204,12 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { check_valid_range); } - StatusOr RunComputationHelper(const XlaComputation& comp, - const Literal& literal) { + absl::StatusOr RunComputationHelper(const XlaComputation& comp, + const Literal& literal) { return RunComputation(comp, {&literal}); } - StatusOr RunComputationHelper( + absl::StatusOr RunComputationHelper( const XlaComputation& comp, const std::array& literals) { std::array lit_ptrs; for (int i = 0; i < N; ++i) { @@ -237,7 +237,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { // plain Client API, which is used by ClientLibraryTestBase. This is because // the plain Client API results does more memcpys to/from Literals, and that's // slow given that we're touching a lot of data here. - StatusOr RunComputation( + absl::StatusOr RunComputation( const XlaComputation& computation, absl::Span input_literals) { // Copy debug options from ClientLibraryTestBase. In particular, we're diff --git a/third_party/xla/xla/tests/filecheck.cc b/third_party/xla/xla/tests/filecheck.cc index 9375a608895f7a..9b8188a166520f 100644 --- a/third_party/xla/xla/tests/filecheck.cc +++ b/third_party/xla/xla/tests/filecheck.cc @@ -28,8 +28,8 @@ limitations under the License. namespace xla { -StatusOr RunFileCheck(const std::string& input, - absl::string_view pattern) { +absl::StatusOr RunFileCheck(const std::string& input, + absl::string_view pattern) { // Generate an input file for the FileCheck pattern. std::string pattern_path; auto env = tsl::Env::Default(); @@ -41,8 +41,8 @@ StatusOr RunFileCheck(const std::string& input, return RunFileCheckWithPatternFile(input, pattern_path); } -StatusOr RunFileCheckWithPatternFile(const std::string& input, - const std::string& pattern_file) { +absl::StatusOr RunFileCheckWithPatternFile( + const std::string& input, const std::string& pattern_file) { // Invoke FileCheck to check whether input matches `pattern`. std::string file_check_path = tsl::GetDataDependencyFilepath( tsl::testing::kIsOpenSource diff --git a/third_party/xla/xla/tests/filecheck.h b/third_party/xla/xla/tests/filecheck.h index 4088abb041ebb8..f03609f8bea4c4 100644 --- a/third_party/xla/xla/tests/filecheck.h +++ b/third_party/xla/xla/tests/filecheck.h @@ -26,14 +26,14 @@ namespace xla { // Runs FileCheck with the given pattern over given input string. Provided that // FileCheck can execute, returns true if and only if FileCheck succeeded in // matching the input. -StatusOr RunFileCheck(const std::string& input, - absl::string_view pattern); +absl::StatusOr RunFileCheck(const std::string& input, + absl::string_view pattern); // Runs FileCheck with the given pattern file over given input string. Provided // that FileCheck can execute, returns true if and only if FileCheck succeeded // in matching the input. -StatusOr RunFileCheckWithPatternFile(const std::string& input, - const std::string& pattern_file); +absl::StatusOr RunFileCheckWithPatternFile( + const std::string& input, const std::string& pattern_file); } // namespace xla diff --git a/third_party/xla/xla/tests/fuzz/BUILD b/third_party/xla/xla/tests/fuzz/BUILD index c3275e1bf4304e..fa5dde0ff1c3df 100644 --- a/third_party/xla/xla/tests/fuzz/BUILD +++ b/third_party/xla/xla/tests/fuzz/BUILD @@ -6,7 +6,6 @@ cc_library( name = "hlo_test_lib", testonly = True, srcs = ["hlo_test_lib.cc"], - visibility = ["//visibility:public"], deps = [ "//xla:error_spec", "//xla/tests:hlo_test_base", diff --git a/third_party/xla/xla/tests/hlo_metadata_test.cc b/third_party/xla/xla/tests/hlo_metadata_test.cc index 77c86d5d783c23..ed5260426044eb 100644 --- a/third_party/xla/xla/tests/hlo_metadata_test.cc +++ b/third_party/xla/xla/tests/hlo_metadata_test.cc @@ -60,7 +60,6 @@ TEST_F(HloMetadataTest, MetadataPropagation) { ->root_instruction(); EXPECT_THAT(instruction->metadata().op_type(), StrEq("add")); EXPECT_THAT(instruction->metadata().op_name(), StrEq("my_sum_op")); - EXPECT_NE(instruction->metadata().logical_creation_pass_id(), 0); } TEST_F(HloMetadataTest, MetadataClearing) { diff --git a/third_party/xla/xla/tests/hlo_test_base.cc b/third_party/xla/xla/tests/hlo_test_base.cc index 064d0a943ceb1f..5d8f54e43d16d1 100644 --- a/third_party/xla/xla/tests/hlo_test_base.cc +++ b/third_party/xla/xla/tests/hlo_test_base.cc @@ -131,7 +131,7 @@ std::unique_ptr HloTestBase::CreateNewVerifiedModule( backend().compiler()->ShapeSizeBytesFunction()); } -StatusOr> +absl::StatusOr> HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text, int64_t replica_count, int64_t num_partitions) { @@ -143,7 +143,7 @@ HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text, return module; } -StatusOr> +absl::StatusOr> HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text, const HloModuleConfig& config) { auto module = std::make_unique( @@ -168,8 +168,8 @@ void HloTestBase::UpdateEntryComputationLayout(HloModule* module) { } /* static */ -StatusOr HloTestBase::RunHloPass(HloPassInterface* hlo_pass, - HloModule* module) { +absl::StatusOr HloTestBase::RunHloPass(HloPassInterface* hlo_pass, + HloModule* module) { const std::string module_str_before_run = module->ToProto().ShortDebugString(); const auto status_or = hlo_pass->Run(module); @@ -189,8 +189,8 @@ StatusOr HloTestBase::RunHloPass(HloPassInterface* hlo_pass, } /* static */ -StatusOr HloTestBase::RunHloPass(HloPassInterface&& hlo_pass, - HloModuleGroup* module_group) { +absl::StatusOr HloTestBase::RunHloPass(HloPassInterface&& hlo_pass, + HloModuleGroup* module_group) { const std::string module_group_str_before_run = module_group->ToProto().ShortDebugString(); const auto status_or = hlo_pass.RunOnModuleGroup(module_group); @@ -292,8 +292,8 @@ void HloTestBase::RunAndFilecheckHloModuleGroupRewrite( } } -StatusOr HloTestBase::Execute(std::unique_ptr module, - absl::Span arguments) { +absl::StatusOr HloTestBase::Execute( + std::unique_ptr module, absl::Span arguments) { return runner_->Execute(std::move(module), arguments); } @@ -305,11 +305,12 @@ Literal HloTestBase::ExecuteNoHloPasses(std::unique_ptr module, .value(); } -StatusOr> HloTestBase::GetHloRunner() { +absl::StatusOr> +HloTestBase::GetHloRunner() { if (runner_ != nullptr) { return std::move(runner_); } - StatusOr> status_or_runner = + absl::StatusOr> status_or_runner = GetHloRunnerForTest(test_platform_); // Test for successful creation of PjRt based Hlo Runner. @@ -323,7 +324,7 @@ Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr module, return runner_->Execute(std::move(module), arguments, true, nullptr).value(); } -StatusOr> HloTestBase::ExecuteReplicated( +absl::StatusOr> HloTestBase::ExecuteReplicated( std::unique_ptr module, absl::Span arguments, int64_t num_replicas, bool use_threads, bool run_hlo_passes) { HloRunner::ReplicatedExecuteOptions options; @@ -337,7 +338,7 @@ StatusOr> HloTestBase::ExecuteReplicated( return runner_->ExecuteReplicated(std::move(module), options); } -StatusOr> HloTestBase::ExecuteReplicated( +absl::StatusOr> HloTestBase::ExecuteReplicated( std::unique_ptr module, absl::Span arguments, int64_t num_replicas, DeviceAssignment* device_assignment, bool run_hlo_passes, bool use_threads) { @@ -352,7 +353,7 @@ StatusOr> HloTestBase::ExecuteReplicated( device_assignment); } -StatusOr> HloTestBase::ExecuteReplicated( +absl::StatusOr> HloTestBase::ExecuteReplicated( std::function executable_provider, std::function argument_count_provider, std::function argument_provider, @@ -367,7 +368,7 @@ StatusOr> HloTestBase::ExecuteReplicated( options, device_assignment); } -StatusOr> HloTestBase::MakeReferenceModule( +absl::StatusOr> HloTestBase::MakeReferenceModule( const HloModule& test_module, const std::function& reference_preprocessor) { std::unique_ptr reference_module = test_module.Clone(); @@ -385,7 +386,7 @@ StatusOr> HloTestBase::MakeReferenceModule( return std::move(reference_module); } -StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( +absl::StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( std::unique_ptr module, const absl::Span arguments, const optional& error, bool run_hlo_passes, @@ -491,7 +492,7 @@ ::testing::AssertionResult HloTestBase::RunAndCompare( reference_preprocessor); } -StatusOr<::testing::AssertionResult> +absl::StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareTwoModulesInternal( std::unique_ptr module_0, std::unique_ptr module_1, const absl::Span arguments, @@ -597,6 +598,31 @@ ::testing::AssertionResult HloTestBase::RunAndCompareTwoModules( run_hlo_passes, args_max_bits_of_precision); } +::testing::AssertionResult HloTestBase::RunAndCompareTwoModules( + string_view hlo_string_module_0, string_view hlo_string_module_1, + const HloModuleConfig& config_0, const HloModuleConfig& config_1, + const std::optional& error, bool run_hlo_passes, + std::optional args_max_bits_of_precision) { + auto module_0_or_status = + ParseAndReturnVerifiedModule(hlo_string_module_0, config_0); + if (!module_0_or_status.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_0_or_status.status().ToString(); + } + + auto module_1_or_status = + ParseAndReturnVerifiedModule(hlo_string_module_1, config_1); + if (!module_1_or_status.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_1_or_status.status().ToString(); + } + return RunAndCompareTwoModules(std::move(module_0_or_status).value(), + std::move(module_1_or_status).value(), error, + run_hlo_passes, args_max_bits_of_precision); +} + ::testing::AssertionResult HloTestBase::RunAndCompareTwoModules( absl::string_view hlo_string_module_0, absl::string_view hlo_string_module_1, @@ -761,7 +787,7 @@ ::testing::AssertionResult HloTestBase::RunMultipleTimes( std::optional canonical_output; for (int i = 0; i < n; ++i) { - StatusOr output = + absl::StatusOr output = runner_->ExecuteWithExecutable(executables[i].get(), fake_arguments[i], /*profile=*/&((*profiles)[i])); if (!output.ok()) { @@ -874,13 +900,13 @@ void HloTestBase::MatchOptimizedHlo(absl::string_view hlo, GetOptimizedModule(hlo)); HloPrintOptions print_opts; print_opts.set_print_operand_shape(print_operand_shape); - StatusOr filecheck_result = + absl::StatusOr filecheck_result = RunFileCheck(optimized_module->ToString(print_opts), pattern); TF_ASSERT_OK(filecheck_result.status()); EXPECT_TRUE(filecheck_result.value()); } -StatusOr> HloTestBase::GetOptimizedModule( +absl::StatusOr> HloTestBase::GetOptimizedModule( absl::string_view hlo) { TF_ASSIGN_OR_RETURN( std::unique_ptr module, @@ -890,15 +916,15 @@ StatusOr> HloTestBase::GetOptimizedModule( backend().default_stream_executor()->GetAllocator()); } -StatusOr> HloTestBase::GetOptimizedModule( +absl::StatusOr> HloTestBase::GetOptimizedModule( std::unique_ptr hlo_module) { return backend().compiler()->RunHloPasses( std::move(hlo_module), backend().default_stream_executor(), backend().default_stream_executor()->GetAllocator()); } -StatusOr> HloTestBase::GetHloRunnerForTest( - se::Platform* test_platform) { +absl::StatusOr> +HloTestBase::GetHloRunnerForTest(se::Platform* test_platform) { if (ShouldUsePjRt()) { PjRtClientTestFactoryRegistry& pjrt_registry = GetGlobalPjRtClientTestFactory(); diff --git a/third_party/xla/xla/tests/hlo_test_base.h b/third_party/xla/xla/tests/hlo_test_base.h index 596096cb51832a..c1dd7ce2fad71b 100644 --- a/third_party/xla/xla/tests/hlo_test_base.h +++ b/third_party/xla/xla/tests/hlo_test_base.h @@ -90,11 +90,13 @@ class HloTestBase : public ManifestCheckingTest { const std::string& name = TestName(), int64_t replica_count = 1); // Parses the given string and returns module as a VerifiedHloModule. - StatusOr> ParseAndReturnVerifiedModule( - absl::string_view hlo_text, int64_t replica_count = 1, - int64_t num_partitions = 1); - StatusOr> ParseAndReturnVerifiedModule( - absl::string_view hlo_text, const HloModuleConfig& config); + absl::StatusOr> + ParseAndReturnVerifiedModule(absl::string_view hlo_text, + int64_t replica_count = 1, + int64_t num_partitions = 1); + absl::StatusOr> + ParseAndReturnVerifiedModule(absl::string_view hlo_text, + const HloModuleConfig& config); // Runs the hlo_pass with the provided module and returns the result. This // function also verifies that the module remains unchanged when hlo_pass @@ -104,14 +106,14 @@ class HloTestBase : public ManifestCheckingTest { // `RunHloPass(MyPass(), module)` all in one line. The reason for the // overload that takes a pointer is that, at one point in the past, non-const // lvalue references were banned in Google code. - static StatusOr RunHloPass(HloPassInterface* hlo_pass, - HloModule* module); - static StatusOr RunHloPass(HloPassInterface& hlo_pass, - HloModule* module) { + static absl::StatusOr RunHloPass(HloPassInterface* hlo_pass, + HloModule* module); + static absl::StatusOr RunHloPass(HloPassInterface& hlo_pass, + HloModule* module) { return RunHloPass(&hlo_pass, module); } - static StatusOr RunHloPass(HloPassInterface&& hlo_pass, - HloModule* module) { + static absl::StatusOr RunHloPass(HloPassInterface&& hlo_pass, + HloModule* module) { return RunHloPass(&hlo_pass, module); } @@ -119,8 +121,8 @@ class HloTestBase : public ManifestCheckingTest { // This method runs the input HLO module group pass for a `HloModuleGroup` and // it also verifies the module group remains unchanged when hlo_pass returns // false as the StatusOr value. - static StatusOr RunHloPass(HloPassInterface&& hlo_pass, - HloModuleGroup* module_group); + static absl::StatusOr RunHloPass(HloPassInterface&& hlo_pass, + HloModuleGroup* module_group); static PrecisionConfig DefaultPrecisionConfig(int operands); @@ -140,10 +142,10 @@ class HloTestBase : public ManifestCheckingTest { } // Compiles and returns module with optimizations from a given HLO. - StatusOr> GetOptimizedModule( + absl::StatusOr> GetOptimizedModule( absl::string_view hlo); - StatusOr> GetOptimizedModule( + absl::StatusOr> GetOptimizedModule( std::unique_ptr hlo_module); protected: @@ -202,8 +204,8 @@ class HloTestBase : public ManifestCheckingTest { } // Executes the given module and return the result as a Literal. - StatusOr Execute(std::unique_ptr module, - absl::Span arguments); + absl::StatusOr Execute(std::unique_ptr module, + absl::Span arguments); // Same as above, except the module will be executed without running any HLO // passes on it. @@ -214,7 +216,7 @@ class HloTestBase : public ManifestCheckingTest { absl::Span arguments); // Compile the given module to an executable. - StatusOr> CreateExecutable( + absl::StatusOr> CreateExecutable( std::unique_ptr module, bool run_hlo_passes) { return runner_->CreateExecutable(std::move(module), run_hlo_passes); } @@ -224,18 +226,18 @@ class HloTestBase : public ManifestCheckingTest { // use_threads indicates whether this replicated computation will be executed // with a thread-per-replica, vs using an implicitly async call such as // Executable::ExecuteOnStreams. - StatusOr> ExecuteReplicated( + absl::StatusOr> ExecuteReplicated( std::unique_ptr module, absl::Span arguments, int64_t num_replicas, bool use_threads, bool run_hlo_passes = false); // Same as above, but uses specified device assignment. - StatusOr> ExecuteReplicated( + absl::StatusOr> ExecuteReplicated( std::unique_ptr module, absl::Span arguments, int64_t num_replicas, DeviceAssignment* device_assignment, bool run_hlo_passes, bool use_threads); // Same as above, but allows passing different programs for replicas. - StatusOr> ExecuteReplicated( + absl::StatusOr> ExecuteReplicated( std::function executable_provider, std::function argument_count_provider, std::function argument_provider, @@ -315,6 +317,14 @@ class HloTestBase : public ManifestCheckingTest { const std::optional& error, bool run_hlo_passes = true, std::optional args_max_bits_of_precision = std::nullopt); + // Same as above but allows running with different configs. + ::testing::AssertionResult RunAndCompareTwoModules( + absl::string_view hlo_string_module_0, + absl::string_view hlo_string_module_1, const HloModuleConfig& config_0, + const HloModuleConfig& config_1, const std::optional& error, + bool run_hlo_passes = true, + std::optional args_max_bits_of_precision = std::nullopt); + // Same as above but requires explicit arguments. ::testing::AssertionResult RunAndCompareTwoModules( absl::string_view hlo_string_module_0, @@ -410,7 +420,7 @@ class HloTestBase : public ManifestCheckingTest { HloModule*, std::unique_ptr computation); void UpdateEntryComputationLayout(HloModule* module); - StatusOr> GetHloRunner(); + absl::StatusOr> GetHloRunner(); protected: // Helper functions to get test and reference platforms. @@ -425,14 +435,14 @@ class HloTestBase : public ManifestCheckingTest { // Given the test module, makes a reference module that is ready to run on the // reference platform. This assumes that the given module is ready to run on // the test platform. - StatusOr> MakeReferenceModule( + absl::StatusOr> MakeReferenceModule( const HloModule& test_module, const std::function& reference_preprocessor); // Runs the module on two platforms with or without running hlo passes and // compares the results. Returns whether the results are near or equal. If any // error happens before the results are computed, returns the error status. - StatusOr<::testing::AssertionResult> RunAndCompareInternal( + absl::StatusOr<::testing::AssertionResult> RunAndCompareInternal( std::unique_ptr module, const absl::Span arguments, const std::optional& error, bool run_hlo_passes, @@ -441,14 +451,14 @@ class HloTestBase : public ManifestCheckingTest { // Runs the two module on with or without running hlo passes and // compares the results. Returns whether the results are near or equal. If any // error happens before the results are computed, returns the error status. - StatusOr<::testing::AssertionResult> RunAndCompareTwoModulesInternal( + absl::StatusOr<::testing::AssertionResult> RunAndCompareTwoModulesInternal( std::unique_ptr module_0, std::unique_ptr module_1, const absl::Span arguments, const std::optional& error, bool run_hlo_passes); // Returns either an HloRunner or HloRunnerPjRt implementation depending if // there exists a registered PjRtClientFactory. - StatusOr> GetHloRunnerForTest( + absl::StatusOr> GetHloRunnerForTest( se::Platform* test_platform); }; diff --git a/third_party/xla/xla/tests/llvm_compiler_test.cc b/third_party/xla/xla/tests/llvm_compiler_test.cc index a0864a757043b8..5f826254f2524b 100644 --- a/third_party/xla/xla/tests/llvm_compiler_test.cc +++ b/third_party/xla/xla/tests/llvm_compiler_test.cc @@ -69,7 +69,7 @@ class GpuDummyCompiler : public GpuCompiler { return OkStatus(); } - StatusOr CompileTargetBinary( + absl::StatusOr CompileTargetBinary( const HloModuleConfig& module_config, llvm::Module* llvm_module, se::GpuComputeCapability gpu_version, bool relocatable, const HloModule* debug_module, const CompileOptions& options) override { @@ -88,7 +88,7 @@ class LLVMCompilerTest : public ::testing::Test { BackendOptions backend_options; backend_options.set_platform(platform); - StatusOr> backend_or_status = + absl::StatusOr> backend_or_status = Backend::CreateBackend(backend_options); ASSERT_IS_OK(backend_or_status.status()); backend_ = std::move(backend_or_status).value(); diff --git a/third_party/xla/xla/tests/llvm_irgen_test_base.cc b/third_party/xla/xla/tests/llvm_irgen_test_base.cc index 23841bb11a003c..59b16529b4df1f 100644 --- a/third_party/xla/xla/tests/llvm_irgen_test_base.cc +++ b/third_party/xla/xla/tests/llvm_irgen_test_base.cc @@ -56,7 +56,7 @@ void LlvmIrGenTestBase::CompileAndVerifyIr( ResetIrHook(); TF_ASSERT_OK(status); - StatusOr filecheck_result = RunFileCheck(ir_, pattern); + absl::StatusOr filecheck_result = RunFileCheck(ir_, pattern); TF_ASSERT_OK(filecheck_result.status()); EXPECT_TRUE(filecheck_result.value()) << "Full IR: " << ir_; } @@ -82,7 +82,7 @@ void LlvmIrGenTestBase::CompileAheadOfTimeAndVerifyIr( ResetIrHook(); TF_ASSERT_OK(status); - StatusOr filecheck_result = RunFileCheck(ir_, pattern); + absl::StatusOr filecheck_result = RunFileCheck(ir_, pattern); ASSERT_TRUE(filecheck_result.ok()); EXPECT_TRUE(filecheck_result.value()) << "Full IR: " << ir_; } diff --git a/third_party/xla/xla/tests/local_client_execute_test.cc b/third_party/xla/xla/tests/local_client_execute_test.cc index 0483835c35f359..79ff96013223e5 100644 --- a/third_party/xla/xla/tests/local_client_execute_test.cc +++ b/third_party/xla/xla/tests/local_client_execute_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "xla/statusor.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/host/host_platform_id.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/test_helpers.h" #include "xla/tests/literal_test_util.h" @@ -673,8 +674,7 @@ XLA_TEST_F(LocalClientExecuteTest, // Try to run a computation on a stream for a platform (CPU) which does not // match the platform of the service (!= CPU). se::Platform* wrong_platform = - se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId) - .value(); + se::PlatformManager::PlatformWithId(se::host::kHostPlatformId).value(); se::Stream wrong_stream(wrong_platform->ExecutorForDevice(0).value()); wrong_stream.Init(); @@ -691,8 +691,7 @@ XLA_TEST_F(LocalClientExecuteTest, XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_CPU(AllocatorDoesNotMatchPlatform)) { se::Platform* wrong_platform = - se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId) - .value(); + se::PlatformManager::PlatformWithId(se::host::kHostPlatformId).value(); TestAllocator allocator(wrong_platform); XlaBuilder builder(TestName()); diff --git a/third_party/xla/xla/tests/local_client_test_base.cc b/third_party/xla/xla/tests/local_client_test_base.cc index 6cae6294c51c39..8df3e5eb11aca0 100644 --- a/third_party/xla/xla/tests/local_client_test_base.cc +++ b/third_party/xla/xla/tests/local_client_test_base.cc @@ -39,10 +39,9 @@ namespace xla { /* static */ TestAllocator* LocalClientTestBase::allocator_; -StatusOr TestAllocator::Allocate(int device_ordinal, - uint64_t size, - bool retry_on_failure, - int64_t memory_space) { +absl::StatusOr TestAllocator::Allocate( + int device_ordinal, uint64_t size, bool retry_on_failure, + int64_t memory_space) { VLOG(2) << "Allocate(" << device_ordinal << ", " << size << ")"; { absl::MutexLock lock(&count_mutex_); @@ -173,14 +172,14 @@ ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie( .value(); } -StatusOr LocalClientTestBase::ExecuteLocally( +absl::StatusOr LocalClientTestBase::ExecuteLocally( const XlaComputation& computation, absl::Span arguments) { return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(), DefaultExecutableRunOptions()); } -StatusOr LocalClientTestBase::ExecuteLocally( +absl::StatusOr LocalClientTestBase::ExecuteLocally( const XlaComputation& computation, absl::Span arguments, const ExecutableBuildOptions& build_options, @@ -208,12 +207,12 @@ StatusOr LocalClientTestBase::ExecuteLocally( return std::move(ret); } -StatusOr> +absl::StatusOr> LocalClientTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text) { return ParseAndReturnVerifiedModule(hlo_text, HloModuleConfig()); } -StatusOr> +absl::StatusOr> LocalClientTestBase::ParseAndReturnVerifiedModule( absl::string_view hlo_text, const HloModuleConfig& config) { auto module = std::make_unique( diff --git a/third_party/xla/xla/tests/local_client_test_base.h b/third_party/xla/xla/tests/local_client_test_base.h index 8be0c62d02c588..7f99d8471442f6 100644 --- a/third_party/xla/xla/tests/local_client_test_base.h +++ b/third_party/xla/xla/tests/local_client_test_base.h @@ -47,9 +47,9 @@ class TestAllocator : public se::StreamExecutorMemoryAllocator { : se::StreamExecutorMemoryAllocator( platform, PlatformUtil::GetStreamExecutors(platform).value()) {} - StatusOr Allocate(int device_ordinal, uint64_t size, - bool retry_on_failure, - int64_t memory_space) override; + absl::StatusOr Allocate( + int device_ordinal, uint64_t size, bool retry_on_failure, + int64_t memory_space) override; Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override; // Return the number of allocations that have been performed. @@ -93,10 +93,10 @@ class LocalClientTestBase : public ManifestCheckingTest { // Execute the given computation on the local client. With and without // options. - StatusOr ExecuteLocally( + absl::StatusOr ExecuteLocally( const XlaComputation& computation, absl::Span arguments); - StatusOr ExecuteLocally( + absl::StatusOr ExecuteLocally( const XlaComputation& computation, absl::Span arguments, const ExecutableBuildOptions& build_options, @@ -112,10 +112,11 @@ class LocalClientTestBase : public ManifestCheckingTest { const ExecutableRunOptions& run_options); // Parses the given string and returns module as a VerifiedHloModule. - StatusOr> ParseAndReturnVerifiedModule( - absl::string_view hlo_text); - StatusOr> ParseAndReturnVerifiedModule( - absl::string_view hlo_text, const HloModuleConfig& config); + absl::StatusOr> + ParseAndReturnVerifiedModule(absl::string_view hlo_text); + absl::StatusOr> + ParseAndReturnVerifiedModule(absl::string_view hlo_text, + const HloModuleConfig& config); // Returns a default set of execute options. ExecutableBuildOptions DefaultExecutableBuildOptions() const; diff --git a/third_party/xla/xla/tests/map_test.cc b/third_party/xla/xla/tests/map_test.cc index 39cb7f33f2321a..e72605bc4249ab 100644 --- a/third_party/xla/xla/tests/map_test.cc +++ b/third_party/xla/xla/tests/map_test.cc @@ -489,7 +489,7 @@ TEST_F(MapTest, MapOperationWithBuildError) { auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, error_add, {0}); - StatusOr computation_status = builder.Build(); + absl::StatusOr computation_status = builder.Build(); ASSERT_TRUE(!computation_status.ok()); EXPECT_THAT(computation_status.status().ToString(), ::testing::HasSubstr("error from: ErrorAdd: Binary op add with " diff --git a/third_party/xla/xla/tests/multioutput_fusion_test.cc b/third_party/xla/xla/tests/multioutput_fusion_test.cc index b2ee3f364933de..c2220217b12de2 100644 --- a/third_party/xla/xla/tests/multioutput_fusion_test.cc +++ b/third_party/xla/xla/tests/multioutput_fusion_test.cc @@ -494,6 +494,26 @@ ENTRY main { EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5))); } +XLA_TEST_F(MultiOutputFusionTest, MultiOutputReduceGeneralBitcastCompatible) { + const std::string testcase = absl::StrCat(kScalarOps, R"( +fused_computation { + param_0 = f32[64,128]{1,0} parameter(0) + neg = f32[64,128]{1,0} negate(param_0) + bitcast = f32[8,8,128]{2,1,0} bitcast(neg) + bitcast2 = f32[128,64]{0,1} bitcast(neg) + constant_0 = f32[] constant(0) + reduce.1 = f32[128]{0} reduce(bitcast, constant_0), dimensions={0,1}, to_apply=Add + ROOT tuple.12 = (f32[128]{0}, f32[64,128]{1,0}, f32[128,64]{0,1}) tuple(reduce.1, neg, bitcast2) +} + +ENTRY main { + Arg_2.1 = f32[64,128]{1,0} parameter(0) + ROOT fusion = (f32[128]{0}, f32[64,128]{1,0}, f32[128,64]{0,1}) fusion(Arg_2.1), kind=kInput, calls=fused_computation +})"); + auto module = ParseAndReturnVerifiedModule(testcase).value(); + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5))); +} + XLA_TEST_F(MultiOutputFusionTest, MultiOutputReduceWithEpilogue) { const std::string testcase = absl::StrCat(kScalarOps, R"( fused_computation { diff --git a/third_party/xla/xla/tests/multiple_devices_on_host_test.cc b/third_party/xla/xla/tests/multiple_devices_on_host_test.cc index 2604f4d10961d4..ef69ea068d7fee 100644 --- a/third_party/xla/xla/tests/multiple_devices_on_host_test.cc +++ b/third_party/xla/xla/tests/multiple_devices_on_host_test.cc @@ -17,13 +17,14 @@ limitations under the License. #include "xla/client/client_library.h" #include "xla/client/xla_builder.h" #include "xla/shape_util.h" +#include "xla/stream_executor/platform_manager.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/test.h" namespace xla { namespace { -StatusOr BuildComputation() { +absl::StatusOr BuildComputation() { XlaBuilder b("computation"); Shape scalar_s32 = ShapeUtil::MakeShape(S32, {}); XlaOp infeed = InfeedWithToken(CreateToken(&b), scalar_s32); @@ -45,7 +46,7 @@ void CompileAndExecute( xla::ClientLibrary::GetXlaService(client->platform()) ->backend() .memory_allocator()); - StatusOr result = + absl::StatusOr result = executable->Run(absl::Span(), execute_options); { absl::MutexLock lock(results_mutex); @@ -57,7 +58,7 @@ void TestWithDeviceCount(const int device_count) { // Run `device_count` copies of the XLA program built by BuildComputation. TF_ASSERT_OK_AND_ASSIGN( se::Platform* const platform, - stream_executor::MultiPlatformManager::PlatformWithName("Host")); + stream_executor::PlatformManager::PlatformWithName("Host")); xla::LocalClientOptions client_options; client_options.set_platform(platform); TF_ASSERT_OK_AND_ASSIGN( diff --git a/third_party/xla/xla/tests/multithreaded_compilation_test.cc b/third_party/xla/xla/tests/multithreaded_compilation_test.cc index d34dd83cc90bff..530384d16e894d 100644 --- a/third_party/xla/xla/tests/multithreaded_compilation_test.cc +++ b/third_party/xla/xla/tests/multithreaded_compilation_test.cc @@ -76,7 +76,7 @@ XLA_TEST_F(MultithreadedCompilation, EightModuleCompilation) { absl::MutexLock lock(&mu); executables.push_back(std::move(executable)); VLOG(2) << "Adding executable obtained from thread: " << iteration; - return tsl::OkStatus(); + return absl::OkStatus(); }; { diff --git a/third_party/xla/xla/tests/onednn_matmul_test.cc b/third_party/xla/xla/tests/onednn_matmul_test.cc index d4a96b544a3eea..8d6a307b029a90 100644 --- a/third_party/xla/xla/tests/onednn_matmul_test.cc +++ b/third_party/xla/xla/tests/onednn_matmul_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "xla/literal.h" +#include "xla/service/cpu/onednn_util.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/test_helpers.h" @@ -76,9 +77,7 @@ TEST_F(MatmulTest, SimpleTestF32) { TEST_F(MatmulTest, SimpleTestBF16) { // TODO(penporn): Refactor IsBF16SupportedByOneDNNOnThisCPU() from // tensorflow/core/graph/mkl_graph_util.h and call the function instead. - using tsl::port::TestCPUFeature; - if (!TestCPUFeature(tsl::port::CPUFeature::AVX512_BF16) && - !TestCPUFeature(tsl::port::CPUFeature::AMX_BF16)) { + if (!IsSupportedType(PrimitiveType::BF16)) { GTEST_SKIP() << "CPU does not support BF16."; } @@ -457,6 +456,70 @@ TEST_F(MatmulTest, SimpleBiasTestBF16_PARAM_BF16) { MatchOptimizedHlo(matmul_module_str, fused_matmul_bias_); } +TEST_F(MatmulTest, DivisionByConstantWithEltwiseLinearF32) { + const char* matmul_module_str = R"( + HloModule matmul.divide.test.1, entry_computation_layout={(f32[16,128,768]{2,1,0}, f32[768,12,64]{2,1,0})->f32[16,128,12,64]{3,2,1,0}} + ENTRY matmul.divide.test.f32 { + Arg_4.5 = f32[16,128,768]{2,1,0} parameter(0), sharding={replicated} + Arg_2.3 = f32[768,12,64]{2,1,0} parameter(1), sharding={replicated} + onednn.matmul.0 = f32[16,128,12,64]{3,2,1,0} dot(Arg_4.5, Arg_2.3), lhs_contracting_dims={2}, rhs_contracting_dims={0} + constant.8 = f32[] constant(8) + broadcast.9 = f32[16,128,12,64]{3,2,1,0} broadcast(constant.8), dimensions={} + ROOT divide.16 = f32[16,128,12,64]{3,2,1,0} divide(onednn.matmul.0, broadcast.9) + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec(1e-4, 1e-4))); + MatchOptimizedHlo(matmul_module_str, + R"( + ; CHECK: custom_call_target="__onednn$matmul", + ; CHECK: backend_config={ + ; CHECK-DAG: "outer_dimension_partitions":[], + ; CHECK-DAG: "onednn_matmul_config":{ + ; CHECK-DAG: "fused_ops":["LINEAR"] + ; CHECK-DAG: } + ; CHECK: } + )"); +} + +TEST_F(MatmulTest, TestF32NonConstantWeights) { + const char* matmul_module_str = R"( + HloModule matmul.test.f32, entry_computation_layout={(f32[64,256,16]{2,1,0},f32[16,32]{1,0})->f32[64,256,32]{2,1,0}} + + ENTRY matmul.test.f32 { + arg.0 = f32[64,256,16]{2,1,0} parameter(0), parameter_replication={false} + arg.1 = f32[16,32]{1,0} parameter(1), parameter_replication={false} + ROOT onednn.matmul.0 = f32[64,256,32]{2,1,0} dot(arg.0, arg.1), lhs_contracting_dims={2}, rhs_contracting_dims={0} + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); + MatchOptimizedHlo(matmul_module_str, + R"( + ; CHECK: %matmul.test.f32 + ; CHECK-NOT: custom_call_target="__onednn$matmul_reorder", + ; CHECK: custom-call(%{{[a-z,A-Z,0-9,\.]*}}, %arg.1), custom_call_target="__onednn$matmul", + )"); +} + +TEST_F(MatmulTest, TestF32ConstantWeights) { + const char* matmul_module_str = R"( + HloModule matmul.test.f32, entry_computation_layout={(f32[64,256,16]{2,1,0})->f32[64,256,32]{2,1,0}} + + ENTRY matmul.test.f32 { + arg.0 = f32[64,256,16]{2,1,0} parameter(0), parameter_replication={false} + constant = f32[] constant(1) + arg.1 = f32[16,32]{1,0} broadcast(constant), dimensions={} + ROOT onednn.matmul.0 = f32[64,256,32]{2,1,0} dot(arg.0, arg.1), lhs_contracting_dims={2}, rhs_contracting_dims={0} + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); + MatchOptimizedHlo(matmul_module_str, + R"( + ; CHECK: %matmul.test.f32 + ; CHECK-NOT: custom_call_target="__onednn$matmul_reorder", + ; CHECK: custom-call(%{{[a-z,A-Z,0-9,\.]*}}, %constant{{[a-z,A-Z,0-9,\.]*}}), custom_call_target="__onednn$matmul", + )"); +} + } // namespace cpu } // namespace xla diff --git a/third_party/xla/xla/tests/pjrt_client_registry.h b/third_party/xla/xla/tests/pjrt_client_registry.h index f152a2460d686c..7d82fddc058c30 100644 --- a/third_party/xla/xla/tests/pjrt_client_registry.h +++ b/third_party/xla/xla/tests/pjrt_client_registry.h @@ -32,11 +32,11 @@ class PjRtClientTestFactoryRegistry { typedef std::function DeviceShapeRepresentationFn; typedef std::function DeviceShapeRepresentationFnFactory; - typedef std::function>()> + typedef std::function>()> PjRtClientFactory; static DeviceShapeRepresentationFn DefaultShapeRepresentationRegisteredFn( - StatusOr client) { + absl::StatusOr client) { return [](const Shape& host_shape) { return host_shape; }; } @@ -66,14 +66,14 @@ class PjRtClientTestFactoryRegistry { return factory_ != nullptr; } - std::function>()> Get() const { + std::function>()> Get() const { absl::MutexLock lock(&mu_); return factory_; } private: mutable absl::Mutex mu_; - std::function>()> factory_ + std::function>()> factory_ ABSL_GUARDED_BY(mu_); DeviceShapeRepresentationFnFactory registered_device_shape_representation_fn_; }; diff --git a/third_party/xla/xla/tests/query_inferred_shape_test.cc b/third_party/xla/xla/tests/query_inferred_shape_test.cc index ca1f274d6f0f2b..673b736ca87aa6 100644 --- a/third_party/xla/xla/tests/query_inferred_shape_test.cc +++ b/third_party/xla/xla/tests/query_inferred_shape_test.cc @@ -33,7 +33,7 @@ TEST_F(QueryInferredShapeTest, OnePlusOneShape) { XlaBuilder builder("one_plus_one"); auto one = ConstantR0(&builder, 1.0); auto result = Add(one, one); - StatusOr shape_status = builder.GetShape(result); + absl::StatusOr shape_status = builder.GetShape(result); ASSERT_IS_OK(shape_status.status()); auto shape = shape_status.value(); ASSERT_TRUE(ShapeUtil::Equal(shape, ShapeUtil::MakeShape(F32, {}))); diff --git a/third_party/xla/xla/tests/reduce_hlo_test.cc b/third_party/xla/xla/tests/reduce_hlo_test.cc index ae2c74fdec8fec..60378a80daafaa 100644 --- a/third_party/xla/xla/tests/reduce_hlo_test.cc +++ b/third_party/xla/xla/tests/reduce_hlo_test.cc @@ -51,7 +51,7 @@ class ReduceWithLayoutTest : public HloTestBase, public ::testing::WithParamInterface { public: - StatusOr> GetParsedModule() { + absl::StatusOr> GetParsedModule() { const char* const hlo_string = R"( HloModule BadReduce @@ -125,9 +125,7 @@ INSTANTIATE_TEST_CASE_P(ReduceWithLayoutTest_Instantiation, ReduceLayout{{3, 2, 1, 0}, {1, 0, 2}}, // ReduceLayout{{3, 2, 1, 0}, {2, 0, 1}}, // ReduceLayout{{3, 2, 1, 0}, {2, 1, 0}}, // - ReduceLayout{{3, 1, 2, 0}, {1, 2, 0}}, // - ReduceLayout{{1, 2, 3, 0}, {1, 0, 2}}, // - ReduceLayout{{0, 2, 1, 3}, {2, 0, 1}}), // + ReduceLayout{{3, 1, 2, 0}, {1, 2, 0}}), // PrintReduceLayout); } // namespace diff --git a/third_party/xla/xla/tests/select_and_scatter_test.cc b/third_party/xla/xla/tests/select_and_scatter_test.cc index 3addef18a25def..c4c9bb7b0a3fe0 100644 --- a/third_party/xla/xla/tests/select_and_scatter_test.cc +++ b/third_party/xla/xla/tests/select_and_scatter_test.cc @@ -86,7 +86,7 @@ XLA_TEST_P(SelectAndScatterTest, ParamTest) { GetParam().window_strides, GetParam().padding_type, source, ConstantR0(&builder_, 0.0f), add_f32_); - ComputeAndCompare(&builder_, {}, ErrorSpec(1e-5, 1e-5)); + ComputeAndCompare(&builder_, {}, ErrorSpec(3e-5, 3e-5)); } INSTANTIATE_TEST_CASE_P( diff --git a/third_party/xla/xla/tests/test_utils.cc b/third_party/xla/xla/tests/test_utils.cc index 608dd63a9ac156..46fdb80093ba44 100644 --- a/third_party/xla/xla/tests/test_utils.cc +++ b/third_party/xla/xla/tests/test_utils.cc @@ -230,7 +230,7 @@ void PopulateWithRandomIntegralDataWithBounds(Literal* literal, // floating point format. (floating point format only) // 'max_bits_of_precision' sets the data to have the given number of bits or // less (integer or floating point formats only). -StatusOr MakeFakeLiteralInternal( +absl::StatusOr MakeFakeLiteralInternal( const Shape& shape, std::minstd_rand0* engine, std::optional> limit, bool is_sorted, bool no_duplicates, bool use_large_range, @@ -469,7 +469,7 @@ std::vector FindConstrainedUses( // no constrained uses in the dataflow graph. If such constraints exist, // generate a constrained literal (either bounded in the case of indices, or // zero in the case of init_values for reductions). -StatusOr CreateLiteralForConstrainedUses( +absl::StatusOr CreateLiteralForConstrainedUses( const absl::Span constrained_uses, const HloInstruction& param, const Shape& param_shape, std::minstd_rand0* engine, bool use_large_range, @@ -579,7 +579,7 @@ StatusOr CreateLiteralForConstrainedUses( // Given a module entry parameter, use the dataflow analysis to see if a // special case literal must be created, or if we can generate fake data. -StatusOr MakeConstrainedArgument( +absl::StatusOr MakeConstrainedArgument( const HloDataflowAnalysis& dataflow, const HloInstruction& param, const Shape& param_shape, std::minstd_rand0* engine, bool use_large_range, bool treat_gte_as_data_formatting, @@ -593,8 +593,8 @@ StatusOr MakeConstrainedArgument( } // namespace -StatusOr MakeFakeLiteral(const Shape& shape, bool pseudo_random, - bool use_large_range) { +absl::StatusOr MakeFakeLiteral(const Shape& shape, bool pseudo_random, + bool use_large_range) { auto engine = pseudo_random ? std::make_unique() : nullptr; return MakeFakeLiteralInternal(shape, engine.get(), /*limit=*/std::nullopt, /*is_sorted=*/false, @@ -602,7 +602,7 @@ StatusOr MakeFakeLiteral(const Shape& shape, bool pseudo_random, /*max_bits_of_precision=*/std::nullopt); } -StatusOr> MakeFakeArguments( +absl::StatusOr> MakeFakeArguments( const HloModule* module, bool pseudo_random, bool use_large_range, bool treat_gte_as_data_formatting, std::optional max_bits_of_precision) { @@ -611,7 +611,7 @@ StatusOr> MakeFakeArguments( treat_gte_as_data_formatting, max_bits_of_precision); } -StatusOr> MakeFakeArguments( +absl::StatusOr> MakeFakeArguments( const HloModule* module, std::minstd_rand0* engine, bool use_large_range, bool treat_gte_as_data_formatting, std::optional max_bits_of_precision) { diff --git a/third_party/xla/xla/tests/test_utils.h b/third_party/xla/xla/tests/test_utils.h index 2d20a4c600371f..fb101102d6381d 100644 --- a/third_party/xla/xla/tests/test_utils.h +++ b/third_party/xla/xla/tests/test_utils.h @@ -56,8 +56,9 @@ class PseudorandomGenerator { // Generates fake data in a literal of the given shape, or returns an error // status if the element type is currently unhandled for fake data // generation. See below for documentation of pseudo_random and use_large_range. -StatusOr MakeFakeLiteral(const Shape& shape, bool pseudo_random = true, - bool use_large_range = false); +absl::StatusOr MakeFakeLiteral(const Shape& shape, + bool pseudo_random = true, + bool use_large_range = false); // Generates a vector of arguments containing fake data. The number, shape and // layout of the arguments is appropriate for given HLO module. @@ -93,7 +94,7 @@ StatusOr MakeFakeLiteral(const Shape& shape, bool pseudo_random = true, // TODO(b/79942829): Make interesting argument generation fast enough that using // pseudo_random does not save any noticeable amount of time so that the // parameter can be removed. -StatusOr> MakeFakeArguments( +absl::StatusOr> MakeFakeArguments( const HloModule* module, bool pseudo_random = true, bool use_large_range = false, bool treat_gte_as_data_formatting = false, std::optional max_bits_of_precision = std::nullopt); @@ -101,7 +102,7 @@ StatusOr> MakeFakeArguments( // Overload which accepts a random number generator. This enables generation of // different random values with sequential calls to MakeFakeArguments by reusing // the same generator. -StatusOr> MakeFakeArguments( +absl::StatusOr> MakeFakeArguments( const HloModule* module, std::minstd_rand0* engine, bool use_large_range = false, bool treat_gte_as_data_formatting = false, std::optional max_bits_of_precision = std::nullopt); diff --git a/third_party/xla/xla/tests/transfer_manager_test.cc b/third_party/xla/xla/tests/transfer_manager_test.cc index 3d401be36ae373..eb7ac0fb4e032e 100644 --- a/third_party/xla/xla/tests/transfer_manager_test.cc +++ b/third_party/xla/xla/tests/transfer_manager_test.cc @@ -328,7 +328,7 @@ XLA_TEST_F(TransferManagerTest, MultiStreamRoundTripSoak) { auto device_buffer2 = AllocateDeviceBuffer(literal2.shape()); auto stream1 = stream_; - auto stream2 = stream_->GetOrCreateSubStream(); + auto stream2 = stream_->GetOrCreateSubStream().value(); Literal result1, result2; diff --git a/third_party/xla/xla/tests/value_inference_test.cc b/third_party/xla/xla/tests/value_inference_test.cc index 23dcc1acf667b7..819e7f87a90c92 100644 --- a/third_party/xla/xla/tests/value_inference_test.cc +++ b/third_party/xla/xla/tests/value_inference_test.cc @@ -56,8 +56,8 @@ class DynamismInferenceTest : public ValueInferenceTest { explicit DynamismInferenceTest(se::Platform* platform = nullptr) : platform_(platform) {} - StatusOr ComputeDynamismLiteral(XlaOp operand, XlaBuilder* builder, - Layout* output_layout = nullptr) { + absl::StatusOr ComputeDynamismLiteral( + XlaOp operand, XlaBuilder* builder, Layout* output_layout = nullptr) { TF_RETURN_IF_ERROR(builder->first_error()); ValueInference value_inference(builder); TF_ASSIGN_OR_RETURN(auto literal_slice, @@ -65,8 +65,8 @@ class DynamismInferenceTest : public ValueInferenceTest { return literal_slice.Clone(); } - StatusOr ComputeDynamismScalar(XlaOp operand, XlaBuilder* builder, - ShapeIndex index = {}) { + absl::StatusOr ComputeDynamismScalar(XlaOp operand, XlaBuilder* builder, + ShapeIndex index = {}) { TF_ASSIGN_OR_RETURN(auto literal, ComputeDynamismLiteral(operand, builder, nullptr)); return literal.Get({}, index); @@ -558,7 +558,7 @@ class UpperBoundInferenceTest : public ValueInferenceTest { explicit UpperBoundInferenceTest(se::Platform* platform = nullptr) : platform_(platform) {} - StatusOr ComputeUpperBoundLiteral( + absl::StatusOr ComputeUpperBoundLiteral( XlaOp operand, XlaBuilder* builder, Layout* output_layout = nullptr) { ValueInference value_inference(builder); TF_ASSIGN_OR_RETURN(auto literal, @@ -715,7 +715,7 @@ class ConstValueInferenceTest : public ValueInferenceTest { explicit ConstValueInferenceTest(se::Platform* platform = nullptr) : platform_(platform) {} - StatusOr ComputeConstantValueLiteral( + absl::StatusOr ComputeConstantValueLiteral( XlaOp operand, XlaBuilder* builder, Layout* output_layout = nullptr) { ValueInference value_inference(builder); TF_ASSIGN_OR_RETURN(auto literal, value_inference.AnalyzeConstant( diff --git a/third_party/xla/xla/text_literal_reader.cc b/third_party/xla/xla/text_literal_reader.cc index a3b6195be3b875..06773cc229f03f 100644 --- a/third_party/xla/xla/text_literal_reader.cc +++ b/third_party/xla/xla/text_literal_reader.cc @@ -39,7 +39,7 @@ limitations under the License. namespace xla { -StatusOr TextLiteralReader::ReadPath(absl::string_view path) { +absl::StatusOr TextLiteralReader::ReadPath(absl::string_view path) { CHECK(!absl::EndsWith(path, ".gz")) << "TextLiteralReader no longer supports reading .gz files"; std::unique_ptr file; @@ -55,7 +55,7 @@ StatusOr TextLiteralReader::ReadPath(absl::string_view path) { TextLiteralReader::TextLiteralReader(tsl::RandomAccessFile* file) : file_(file) {} -StatusOr TextLiteralReader::ReadAllLines() { +absl::StatusOr TextLiteralReader::ReadAllLines() { tsl::io::RandomAccessInputStream stream(file_.get()); tsl::io::BufferedInputStream buf(&stream, 65536); std::string shape_string; diff --git a/third_party/xla/xla/text_literal_reader.h b/third_party/xla/xla/text_literal_reader.h index b008447c69223b..a0d56611bac030 100644 --- a/third_party/xla/xla/text_literal_reader.h +++ b/third_party/xla/xla/text_literal_reader.h @@ -40,7 +40,7 @@ class TextLiteralReader { public: // See class comment -- reads a file in its entirety (there must be only one // literal in the text file path provided). - static StatusOr ReadPath(absl::string_view path); + static absl::StatusOr ReadPath(absl::string_view path); private: // Ownership of file is transferred. @@ -48,7 +48,7 @@ class TextLiteralReader { // Parses a shape string on the first line, followed by lines of values to the // end of the file. - StatusOr ReadAllLines(); + absl::StatusOr ReadAllLines(); // Owns the file being read std::unique_ptr file_; diff --git a/third_party/xla/xla/tools/BUILD b/third_party/xla/xla/tools/BUILD index e790d6657010cf..1d136b8935abff 100644 --- a/third_party/xla/xla/tools/BUILD +++ b/third_party/xla/xla/tools/BUILD @@ -28,7 +28,8 @@ load( ) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//xla:internal"], licenses = ["notice"], ) @@ -39,7 +40,7 @@ filegroup( "**/*.cc", "**/*.h", ]), - visibility = ["//visibility:public"], + visibility = ["//xla:internal"], ) build_test( @@ -253,7 +254,6 @@ cc_library( name = "hlo_extractor", srcs = ["hlo_extractor.cc"], hdrs = ["hlo_extractor.h"], - visibility = ["//visibility:public"], deps = [ "//xla:literal", "//xla:literal_util", @@ -291,7 +291,6 @@ xla_cc_binary( cc_library( name = "hlo_expand_main", srcs = ["hlo_expand_main.cc"], - visibility = ["//visibility:public"], deps = [ ":hlo_expand_lib", ":hlo_module_loader", @@ -312,7 +311,6 @@ cc_library( name = "hlo_expand_lib", srcs = ["hlo_expand.cc"], hdrs = ["hlo_expand.h"], - visibility = ["//visibility:public"], deps = [ "//xla:xla_data_proto_cc", "//xla/service:batchnorm_expander", @@ -370,7 +368,6 @@ cc_library( name = "hlo_slicer", srcs = ["hlo_slicer.cc"], hdrs = ["hlo_slicer.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_extractor", "//xla:shape_util", @@ -436,7 +433,7 @@ cc_library( name = "hlo_module_loader", srcs = ["hlo_module_loader.cc"], hdrs = ["hlo_module_loader.h"], - visibility = ["//visibility:public"], + visibility = ["//xla:friends"], deps = [ ":run_hlo_module_proto_cc", "//xla:debug_options_flags", @@ -468,7 +465,6 @@ cc_library( name = "prepare_reference_module", srcs = ["prepare_reference_module.cc"], hdrs = ["prepare_reference_module.h"], - visibility = ["//visibility:public"], deps = [ "//xla:debug_options_flags", "//xla:statusor", @@ -500,29 +496,51 @@ xla_py_proto_library( deps = [":run_hlo_module_proto"], ) +cc_library( + name = "hlo_decomposer_lib", + srcs = ["hlo_decomposer.cc"], + hdrs = ["hlo_decomposer.h"], + deps = [ + "//xla:status", + "//xla/hlo/ir:hlo", + "//xla/service:call_graph", + "//xla/service:compilation_environments", + "@com_google_absl//absl/container:flat_hash_set", + "@local_tsl//tsl/platform:statusor", + ], +) + cc_library( name = "run_hlo_module_lib", srcs = ["run_hlo_module.cc"], hdrs = ["run_hlo_module.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_control_flow_flattening", + ":hlo_decomposer_lib", ":hlo_module_loader", ":prepare_reference_module", ":run_hlo_module_proto_cc", "//xla:error_spec", "//xla:literal", "//xla:literal_comparison", + "//xla:status", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/service:hlo_module_config", "//xla/service:hlo_proto_cc", "//xla/service:hlo_runner", "//xla/service:hlo_verifier", "//xla/tests:test_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", ], ) @@ -575,7 +593,6 @@ cc_library( name = "hlo_control_flow_flattening", srcs = ["hlo_control_flow_flattening.cc"], hdrs = ["hlo_control_flow_flattening.h"], - visibility = ["//visibility:public"], deps = [ "//xla:literal_util", "//xla/hlo/ir:hlo", @@ -612,7 +629,6 @@ cc_library( name = "driver", srcs = ["driver.cc"], tags = ["nofixdeps"], - visibility = ["//visibility:public"], deps = [], ) @@ -640,7 +656,6 @@ tsl_gpu_library( cuda_deps = [ ], defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), - visibility = ["//visibility:public"], deps = [ "//xla:util", "//xla/hlo/ir:hlo", @@ -715,3 +730,18 @@ xla_test( "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", ], ) + +xla_test( + name = "hlo_decomposer_test", + srcs = ["hlo_decomposer_test.cc"], + deps = [ + ":hlo_decomposer_lib", + "//xla/hlo/ir:hlo", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + ], +) diff --git a/third_party/xla/xla/tools/hlo_bisect/BUILD b/third_party/xla/xla/tools/hlo_bisect/BUILD index cc965be000854b..f0dce33a38533a 100644 --- a/third_party/xla/xla/tools/hlo_bisect/BUILD +++ b/third_party/xla/xla/tools/hlo_bisect/BUILD @@ -10,7 +10,8 @@ load( ) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//xla:internal"], ) build_test( @@ -40,7 +41,6 @@ cc_library( name = "hlo_bisect_state", srcs = ["hlo_bisect_state.cc"], hdrs = ["hlo_bisect_state.h"], - visibility = ["//visibility:public"], deps = [ "//xla:literal", "//xla:status", @@ -77,7 +77,6 @@ cc_library( testonly = True, srcs = ["hlo_bisect_utils.cc"], hdrs = ["hlo_bisect_utils.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_bisect_state", "//xla:error_spec", diff --git a/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state.cc b/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state.cc index bbae56edb7c820..19f1b29fd1b77f 100644 --- a/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state.cc +++ b/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state.cc @@ -68,7 +68,7 @@ Status MorphModuleWithOutputs(HloModule* module, module->compute_computation_layout(); HloDCE dce; - StatusOr dce_result = dce.Run(module); + absl::StatusOr dce_result = dce.Run(module); return dce_result.status(); } @@ -127,7 +127,7 @@ Status MorphModuleWithLiterals( } xla::HloDCE dce; - StatusOr dce_status = dce.Run(module); + absl::StatusOr dce_status = dce.Run(module); return dce_status.status(); } @@ -144,12 +144,12 @@ bool InstructionNotReplaceableWithConstant(HloInstruction* instruction) { } // namespace -StatusOr HloBisectState::ShouldProcess() { +absl::StatusOr HloBisectState::ShouldProcess() { // Running the unmodified module should trigger the bug checker. return RunModule(*module_); } -StatusOr HloBisectState::TrimEntryComputation() { +absl::StatusOr HloBisectState::TrimEntryComputation() { bool changed_in_loop = false; bool changed = false; for (int iter = 0; changed || iter < 2; iter++) { @@ -172,11 +172,11 @@ std::unique_ptr&& HloBisectState::GetResult() { return std::move(module_); } -StatusOr HloBisectState::RunModule(const HloModule& module) { +absl::StatusOr HloBisectState::RunModule(const HloModule& module) { VLOG(3) << "Modified module: " << module.ToString(); // Run the modified module with the bug checker. - StatusOr bug_result = bug_checker_->Run(module); + absl::StatusOr bug_result = bug_checker_->Run(module); TF_RETURN_IF_ERROR(bug_result.status()); VLOG(3) << "Bug checker result: " << bug_result.value(); @@ -192,7 +192,7 @@ StatusOr HloBisectState::RunModule(const HloModule& module) { return bug_result; } -StatusOr HloBisectState::TrimByOutputs() { +absl::StatusOr HloBisectState::TrimByOutputs() { // Only available if the root instruction is a tuple. HloInstruction* root_instruction = module_->entry_computation()->root_instruction(); @@ -202,7 +202,7 @@ StatusOr HloBisectState::TrimByOutputs() { } // Run the modified module and return the error state. - auto run_modified = [&](int64_t start, int64_t end) -> StatusOr { + auto run_modified = [&](int64_t start, int64_t end) -> absl::StatusOr { std::unique_ptr new_module = module_->Clone(/*suffix=*/""); HloInstruction* const* new_operands = new_module->entry_computation()->root_instruction()->operands().begin(); @@ -245,7 +245,7 @@ StatusOr HloBisectState::TrimByOutputs() { return changed; } -StatusOr HloBisectState::TrimByInstructions() { +absl::StatusOr HloBisectState::TrimByInstructions() { HloComputation* computation = module_->entry_computation(); // If the root instruction is a tuple, exclude it from the bisect range. @@ -285,7 +285,7 @@ StatusOr HloBisectState::TrimByInstructions() { return changed; } -StatusOr HloBisectState::TrimByUsingConstants() { +absl::StatusOr HloBisectState::TrimByUsingConstants() { // Use random literals for the instructions which do not trigger the bug // checker and also didn't get a definitive value from it. absl::flat_hash_map literal_map; @@ -298,7 +298,7 @@ StatusOr HloBisectState::TrimByUsingConstants() { auto it = foldable_instructions_values_.extract(instr->name()); literal_map.insert(std::move(it)); } else if (foldable_instructions_.contains(instr->name())) { - StatusOr literal_status = MakeFakeLiteral(instr->shape()); + absl::StatusOr literal_status = MakeFakeLiteral(instr->shape()); TF_RETURN_IF_ERROR(literal_status.status()); literal_map[instr->name()] = std::move(literal_status).value(); ++random_literals_count; diff --git a/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state.h b/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state.h index 8df2713c4df98a..6cc19dd7bcbee9 100644 --- a/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state.h +++ b/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state.h @@ -38,7 +38,7 @@ class BugCheckerInterface { virtual ~BugCheckerInterface() {} // Returns true if `module` has a bug we're interested in. - virtual StatusOr Run(const HloModule& module) = 0; + virtual absl::StatusOr Run(const HloModule& module) = 0; // Returns mapping of instruction names to their results after the run // (empty if this information is unavailable). @@ -54,11 +54,11 @@ class HloBisectState { : module_(std::move(module)), bug_checker_(bug_checker) {} // Returns true if the current module has a bug and should be processed. - StatusOr ShouldProcess(); + absl::StatusOr ShouldProcess(); // Trims the entry computation until no more reductions are possible. Returns // a boolean to indicate whether the computation has been reduced. - StatusOr TrimEntryComputation(); + absl::StatusOr TrimEntryComputation(); // Returns the resulting module. std::unique_ptr&& GetResult(); @@ -66,19 +66,19 @@ class HloBisectState { private: // Runs a modified module and updates the foldable instructions data, if // available. Returns true if `module` has a bug. - StatusOr RunModule(const HloModule& module); + absl::StatusOr RunModule(const HloModule& module); // Trims the entry computation by reducing the total number of outputs. // Returns a boolean to indicate whether the computation has been reduced. - StatusOr TrimByOutputs(); + absl::StatusOr TrimByOutputs(); // Trims the entry computation by reducing the total number of instructions. // Returns a boolean to indicate whether the computation has been reduced. - StatusOr TrimByInstructions(); + absl::StatusOr TrimByInstructions(); // Trims the given computation by replacing instructions with constant values. // Returns a boolean to indicate whether the computation has been reduced. - StatusOr TrimByUsingConstants(); + absl::StatusOr TrimByUsingConstants(); // Asserts that the module still has the bug. If negative, runs the bug // checker repeatedly to verify that it's deterministic. diff --git a/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state_test.cc b/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state_test.cc index fc82db2b0465e7..3e608c82ddb16b 100644 --- a/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state_test.cc +++ b/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state_test.cc @@ -43,7 +43,7 @@ class TestBugSearch : public BugCheckerInterface { public: TestBugSearch(std::initializer_list opcodes) : opcodes_(opcodes) {} - StatusOr Run(const HloModule& module) override { + absl::StatusOr Run(const HloModule& module) override { auto has_opcode = [&](HloOpcode opcode) { return absl::c_any_of(module.entry_computation()->instructions(), [opcode](const HloInstruction* instr) { @@ -173,7 +173,7 @@ TEST_F(HloBisectStateTest, TrimByOutputsLostBug) { class CustomBugSearch : public TestBugSearch { public: CustomBugSearch() : TestBugSearch({HloOpcode::kConstant}) {} - StatusOr Run(const HloModule& module) override { + absl::StatusOr Run(const HloModule& module) override { TF_ASSIGN_OR_RETURN(bool has_constants, TestBugSearch::Run(module)); int program_size = module.entry_computation()->instruction_count(); return program_size == 5 && !has_constants; diff --git a/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_utils.cc b/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_utils.cc index 59645adf312317..338d4755ce47c9 100644 --- a/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_utils.cc +++ b/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_utils.cc @@ -56,7 +56,7 @@ Literal ExecuteWithRunnerAndRetrieveResult(std::unique_ptr module, } // Loads the given HloProto as HloModule. -StatusOr> LoadModuleFromHloProto( +absl::StatusOr> LoadModuleFromHloProto( const HloProto& proto) { const HloModuleProto& module_proto = proto.hlo_module(); TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config, @@ -65,8 +65,9 @@ StatusOr> LoadModuleFromHloProto( return CreateModuleFromProto(module_proto, module_config); } -StatusOr> LoadModuleAndInputDataFromHloSnapshot( - const HloSnapshot& snapshot, std::vector* input_data) { +absl::StatusOr> +LoadModuleAndInputDataFromHloSnapshot(const HloSnapshot& snapshot, + std::vector* input_data) { for (int64_t i = 0; i < snapshot.arguments_size(); ++i) { TF_ASSIGN_OR_RETURN(Literal literal, Literal::CreateFromProto(snapshot.arguments(i))); @@ -79,7 +80,7 @@ StatusOr> LoadModuleAndInputDataFromHloSnapshot( return HloModule::CreateFromProto(snapshot.hlo().hlo_module(), config); } -StatusOr GetModuleAndInputData( +absl::StatusOr GetModuleAndInputData( absl::string_view input_filename) { const std::string input_file(input_filename); tsl::Env* env = tsl::Env::Default(); @@ -95,7 +96,7 @@ StatusOr GetModuleAndInputData( } LOG(INFO) << input_file << " is not HloSnapshot. Trying HLO binary proto.\n"; HloProto hlo_proto; - StatusOr> module_or_status; + absl::StatusOr> module_or_status; if (tsl::ReadBinaryProto(env, input_file, &hlo_proto).ok()) { module_or_status = LoadModuleFromHloProto(hlo_proto); if (!module_or_status.ok()) { @@ -168,7 +169,7 @@ MiscompareChecker::MiscompareChecker(HloModule* module, // Generate input data and store the data for all the execution. std::minstd_rand0 rng_engine; if (input_data.empty()) { - StatusOr> input_status = + absl::StatusOr> input_status = MakeFakeArguments(module, &rng_engine); CHECK(input_status.ok()); input_data_ = std::move(input_status).value(); @@ -178,14 +179,14 @@ MiscompareChecker::MiscompareChecker(HloModule* module, } // Set up the reference platform. - StatusOr reference_platform_status = + absl::StatusOr reference_platform_status = PlatformUtil::GetPlatform(std::string(reference_platform)); CHECK(reference_platform_status.ok()); reference_runner_ = std::make_unique(reference_platform_status.value()); // Set up the test platform. - StatusOr test_platform_status = + absl::StatusOr test_platform_status = PlatformUtil::GetPlatform(std::string(test_platform)); CHECK(test_platform_status.ok()); test_runner_ = @@ -195,7 +196,7 @@ MiscompareChecker::MiscompareChecker(HloModule* module, // Executes the module with the test_runner and the reference_runner and // compares the results from the two runs. Returns true if the two results are // not near to indicate a bug exists. -StatusOr MiscompareChecker::Run(const HloModule& module) { +absl::StatusOr MiscompareChecker::Run(const HloModule& module) { std::unique_ptr test_module = module.Clone(/*suffix=*/""); // Make sure that the module config has a non-zero seed, which the CPU and GPU @@ -224,7 +225,7 @@ StatusOr MiscompareChecker::Run(const HloModule& module) { /*run_hlo_passes=*/true); // Compare the results. - StatusOr<::testing::AssertionResult> status_or_result = + absl::StatusOr<::testing::AssertionResult> status_or_result = LiteralTestUtil::Near(/*expected=*/reference_result, /*actual=*/test_result, /*error_spec=*/error_spec_, @@ -240,13 +241,14 @@ absl::flat_hash_map MiscompareChecker::GetResults() { return {}; } -StatusOr> MiscompareChecker::PrepareReferenceModule( +absl::StatusOr> +MiscompareChecker::PrepareReferenceModule( const HloModule& hlo_module, HloRunnerInterface* hlo_runner) const { // By default clone the test module (could be overridden). return xla::PrepareReferenceModule(hlo_module, hlo_runner); } -StatusOr ScriptChecker::Run(const HloModule& module) { +absl::StatusOr ScriptChecker::Run(const HloModule& module) { tsl::Env* env = tsl::Env::Default(); // Write hlo into a temporary file. std::string hlo_path; @@ -292,7 +294,7 @@ absl::flat_hash_map ScriptChecker::GetResults() { return {}; } -StatusOr> BisectRunner::RunEntry() { +absl::StatusOr> BisectRunner::RunEntry() { HloBisectState hlo_bisect(std::move(module_), bug_checker_.get()); TF_ASSIGN_OR_RETURN(bool has_bug, hlo_bisect.ShouldProcess()); if (!has_bug) { @@ -305,13 +307,13 @@ StatusOr> BisectRunner::RunEntry() { return hlo_bisect.GetResult(); } -StatusOr> BisectRunner::RunAll() { +absl::StatusOr> BisectRunner::RunAll() { std::unique_ptr original_module = std::move(module_); std::unique_ptr result; for (HloComputation* c : original_module->computations()) { LOG(INFO) << "Bisecting computation: " << c->name(); module_ = original_module->Clone(/*suffix=*/""); - StatusOr> new_result; + absl::StatusOr> new_result; if (c->IsEntryComputation()) { // Run on the entry computation with input data. new_result = RunEntry(); @@ -340,7 +342,7 @@ StatusOr> BisectRunner::RunAll() { void RunBisect(std::unique_ptr runner, bool all_computations, absl::string_view dump_path, absl::string_view output_format) { - StatusOr> bisect_status = + absl::StatusOr> bisect_status = all_computations ? runner->RunAll() : runner->RunEntry(); CHECK(bisect_status.ok()) << bisect_status.status().message(); @@ -351,7 +353,7 @@ void RunBisect(std::unique_ptr runner, bool all_computations, CHECK(dump_status.ok()) << dump_status.message(); } -StatusOr GetVerifiedModuleAndInputData( +absl::StatusOr GetVerifiedModuleAndInputData( absl::string_view input_filename) { std::unique_ptr module; std::vector input_data; diff --git a/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_utils.h b/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_utils.h index 69d058c43263f2..b0649a78579aec 100644 --- a/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_utils.h +++ b/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_utils.h @@ -41,10 +41,10 @@ class MiscompareChecker : public BugCheckerInterface { absl::string_view test_platform, absl::string_view reference_platform, ErrorSpec error_spec); - StatusOr Run(const HloModule& module) override; + absl::StatusOr Run(const HloModule& module) override; absl::flat_hash_map GetResults() override; - virtual StatusOr> PrepareReferenceModule( + virtual absl::StatusOr> PrepareReferenceModule( const HloModule& hlo_module, HloRunnerInterface* hlo_runner) const; private: @@ -61,7 +61,7 @@ class ScriptChecker : public BugCheckerInterface { public: explicit ScriptChecker(std::string path_to_script) : path_to_script_(std::move(path_to_script)) {} - StatusOr Run(const HloModule& module) override; + absl::StatusOr Run(const HloModule& module) override; absl::flat_hash_map GetResults() override; private: @@ -75,8 +75,8 @@ class BisectRunner { std::unique_ptr bug_checker) : module_(std::move(module)), bug_checker_(std::move(bug_checker)) {} - StatusOr> RunEntry(); - StatusOr> RunAll(); + absl::StatusOr> RunEntry(); + absl::StatusOr> RunAll(); protected: std::unique_ptr module_; @@ -90,7 +90,7 @@ void RunBisect(std::unique_ptr runner, bool all_computations, // Utility function for getting the verified module and optional inputs. using ModuleWithInputs = std::pair, std::vector>; -xla::StatusOr GetVerifiedModuleAndInputData( +absl::StatusOr GetVerifiedModuleAndInputData( absl::string_view input_filename); } // namespace bisect diff --git a/third_party/xla/xla/tools/hlo_control_flow_flattening.cc b/third_party/xla/xla/tools/hlo_control_flow_flattening.cc index dcbec2bd759d5b..f1ff7706c7d3a6 100644 --- a/third_party/xla/xla/tools/hlo_control_flow_flattening.cc +++ b/third_party/xla/xla/tools/hlo_control_flow_flattening.cc @@ -155,7 +155,7 @@ Status HloControlFlowFlattening::FlattenWhileLoop( // non-get-tuple-element users with a new tuple instruction which has the // first N - 1 elements. auto replace_non_gte_users = - [](HloInstruction* new_tuple) -> StatusOr { + [](HloInstruction* new_tuple) -> absl::StatusOr { CHECK(new_tuple->shape().IsTuple()); HloInstruction* prefix = nullptr; std::vector users(new_tuple->users()); @@ -399,13 +399,18 @@ Status HloControlFlowFlattening::RemoveId(HloInstruction* hlo) const { return OkStatus(); } -StatusOr HloControlFlowFlattening::Run( +absl::StatusOr HloControlFlowFlattening::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { auto call_graph = CallGraph::Build(module); bool changed = false; absl::flat_hash_set removed; for (HloComputation* computation : module->computations(execution_threads)) { + // Do not change computations that are wrapped by async calls. Instead we + // remove the async callers if needed. + if (computation->IsAsyncComputation()) { + continue; + } for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { if (removed.contains(instruction)) { @@ -447,9 +452,28 @@ StatusOr HloControlFlowFlattening::Run( changed = true; } } else if (remove_comm_ && IsCollective(instruction) && - !instruction->parent()->IsFusionComputation()) { - VLOG(1) << "Remove " << instruction->name(); - TF_RETURN_IF_ERROR(RemoveCollective(instruction)); + !instruction->parent()->IsFusionComputation() && + (instruction->opcode() != HloOpcode::kAsyncStart && + instruction->opcode() != HloOpcode::kAsyncUpdate)) { + // We do not remove kAsyncStart or kAsyncUpdate here since we expect + // them to be removed as a part of the async chain above. + // We should remove the async chain all together because the async + // wrapped computation is only associated with the AsyncStart. So we + // need to refer to the AsyncStart in order to determine whether + // the Done or the Update wraps a collective. + if (instruction->opcode() == HloOpcode::kAsyncDone) { + while (instruction->opcode() == HloOpcode::kAsyncDone || + instruction->opcode() == HloOpcode::kAsyncUpdate || + instruction->opcode() == HloOpcode::kAsyncStart) { + HloInstruction* operand = instruction->mutable_operand(0); + VLOG(1) << "Remove " << instruction->name(); + TF_RETURN_IF_ERROR(RemoveCollective(instruction)); + instruction = operand; + } + } else { + VLOG(1) << "Remove " << instruction->name(); + TF_RETURN_IF_ERROR(RemoveCollective(instruction)); + } changed = true; } else if (remove_comm_ && (instruction->opcode() == HloOpcode::kPartitionId || diff --git a/third_party/xla/xla/tools/hlo_control_flow_flattening.h b/third_party/xla/xla/tools/hlo_control_flow_flattening.h index d4b6519b3349d1..30c05c30a73b0f 100644 --- a/third_party/xla/xla/tools/hlo_control_flow_flattening.h +++ b/third_party/xla/xla/tools/hlo_control_flow_flattening.h @@ -56,7 +56,7 @@ class HloControlFlowFlattening : public HloModulePass { ~HloControlFlowFlattening() override = default; absl::string_view name() const override { return "control-flow-flattening"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/third_party/xla/xla/tools/hlo_control_flow_flattening_test.cc b/third_party/xla/xla/tools/hlo_control_flow_flattening_test.cc index 83df2ec65dbad8..a391a59ebdad34 100644 --- a/third_party/xla/xla/tools/hlo_control_flow_flattening_test.cc +++ b/third_party/xla/xla/tools/hlo_control_flow_flattening_test.cc @@ -34,7 +34,7 @@ namespace op = xla::testing::opcode_matchers; class HloControlFlowFlatteningTest : public HloTestBase { public: - StatusOr> PartitionComputation( + absl::StatusOr> PartitionComputation( std::unique_ptr hlo_module, int64_t num_devices = 2) { spmd::SpmdPartitionerOptions options; auto collective_ops_creator = @@ -52,7 +52,7 @@ class HloControlFlowFlatteningTest : public HloTestBase { pass.AddPass(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); TF_RETURN_IF_ERROR(pass.Run(hlo_module.get()).status()); - return StatusOr>(std::move(hlo_module)); + return absl::StatusOr>(std::move(hlo_module)); } }; @@ -792,6 +792,27 @@ ENTRY main { EXPECT_EQ(module->entry_computation()->root_instruction()->name(), "fusion"); } +TEST_F(HloControlFlowFlatteningTest, AsyncAllToAll) { + absl::string_view hlo = R"( + + ENTRY main { + param = f32[4,8,128]{2,1,0} parameter(0) + all-to-all-start = ((f32[4,8,128]{2,1,0}), f32[4,8,128]{2,1,0}, u32[], u32[]) all-to-all-start(param), channel_id=1, replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={1} + ROOT all-to-all-done = f32[4,8,128]{2,1,0} all-to-all-done(all-to-all-start) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + EXPECT_TRUE(IsCollective(module->entry_computation()->root_instruction())); + HloControlFlowFlattening flattening({}); + EXPECT_TRUE(flattening.Run(module.get()).value()); + TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/true) + .Run(module.get()) + .status()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::CustomCall(op::CustomCall(op::Parameter(0)))); +} + void CheckWhileBound(HloInstruction* while_op, int expected_bound) { auto* cond = while_op->while_condition(); ASSERT_NE(cond, nullptr); diff --git a/third_party/xla/xla/tools/hlo_decomposer.cc b/third_party/xla/xla/tools/hlo_decomposer.cc new file mode 100644 index 00000000000000..a741549b801707 --- /dev/null +++ b/third_party/xla/xla/tools/hlo_decomposer.cc @@ -0,0 +1,152 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/tools/hlo_decomposer.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "xla/hlo/ir/hlo_clone_context.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/call_graph.h" +#include "xla/service/compilation_environments.h" +#include "xla/status.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { + +// Returns whether it makes sense to run the given instruction in isolation +// (e.g. whether it can run without dependent instructions). +bool ShouldIsolateOpcode(HloOpcode opcode) { + switch (opcode) { + case HloOpcode::kConstant: + case HloOpcode::kGetTupleElement: + case HloOpcode::kParameter: + case HloOpcode::kTuple: + return false; + default: + return true; + } +} + +absl::StatusOr>> Decompose( + const HloModule& module) { + std::vector> modules; + + absl::flat_hash_set computations_to_visit{ + module.entry_computation()}; + absl::flat_hash_set visited_computations; + + // Traverse the computation tree, starting from the entry computation, and + // recursing into the called computations. + while (!computations_to_visit.empty()) { + const HloComputation* computation = *computations_to_visit.begin(); + computations_to_visit.erase(computations_to_visit.begin()); + visited_computations.insert(computation); + + for (const HloInstruction* instruction : computation->instructions()) { + // Skip called computations in the embedded context (fusion, reduce, map, + // etc), as within these computations instructions are not lowered + // individually and it doesn't make sense to test them in isolation. + if (GetInstructionCallContext(instruction->opcode()) != + CallContext::kEmbedded) { + for (const HloComputation* called_computation : + instruction->called_computations()) { + if (!visited_computations.contains(called_computation)) { + computations_to_visit.insert(called_computation); + } + } + } + if (ShouldIsolateOpcode(instruction->opcode())) { + modules.push_back(ExtractInstructionIntoNewModule(*instruction)); + } + } + } + + return modules; +} + +} // namespace + +absl::StatusOr>> DecomposeHloModule( + const HloModule& module, bool deduplicate_modules) { + std::vector> modules; + absl::flat_hash_set module_fingerprints; + + auto should_add_module = [&](const HloModule* module) { + if (!deduplicate_modules) { + return true; + } + const std::string fingerprint = module->GetFingerprint128(); + if (module_fingerprints.contains(fingerprint)) { + return false; + } + module_fingerprints.insert(fingerprint); + return true; + }; + + TF_ASSIGN_OR_RETURN(std::vector> isolated_modules, + Decompose(module)); + for (auto& module : isolated_modules) { + if (should_add_module(module.get())) { + modules.push_back(std::move(module)); + } + } + return modules; +} + +std::unique_ptr ExtractInstructionIntoNewModule( + const HloInstruction& hlo) { + auto new_hlo_module = std::make_unique( + std::string(hlo.name()), HloModuleConfig{}, + std::make_unique(hlo.GetModule()->comp_envs())); + int parameter_number = 0; + HloComputation::Builder builder("entry_computation"); + HloCloneContext clone_context(new_hlo_module.get()); + std::vector new_operands; + for (const HloInstruction* operand : hlo.operands()) { + std::unique_ptr new_parameter = + HloInstruction::CreateParameter(parameter_number, operand->shape(), + operand->name()); + ++parameter_number; + new_operands.push_back(builder.AddInstruction(std::move(new_parameter))); + } + std::unique_ptr new_instruction = + hlo.CloneWithNewOperands(hlo.shape(), new_operands, &clone_context); + builder.AddInstruction(std::move(new_instruction)); + new_hlo_module->AddEntryComputationWithLayouts(builder.Build()); + return new_hlo_module; +} + +std::unique_ptr ExtractComputationIntoNewModule( + const HloComputation& computation) { + auto new_hlo_module = + std::make_unique("extracted", HloModuleConfig{}, + std::make_unique( + computation.parent()->comp_envs())); + HloCloneContext clone_context(new_hlo_module.get()); + new_hlo_module->AddEntryComputationWithLayouts( + computation.CloneInContext(clone_context)); + return new_hlo_module; +} + +} // namespace xla diff --git a/third_party/xla/xla/tools/hlo_decomposer.h b/third_party/xla/xla/tools/hlo_decomposer.h new file mode 100644 index 00000000000000..d16cee484c7f97 --- /dev/null +++ b/third_party/xla/xla/tools/hlo_decomposer.h @@ -0,0 +1,45 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TOOLS_HLO_DECOMPOSER_H_ +#define XLA_TOOLS_HLO_DECOMPOSER_H_ + +#include +#include + +#include "xla/hlo/ir/hlo_module.h" + +namespace xla { + +// Decomposes the `module` into individual ops and de-duplicates the decomposed +// op if `deduplicate_modules` is true. The modules are considered duplicate if +// if their computation graphs are isomorphic (i.e. computations and +// instructions are sorted, names are ignored etc). +absl::StatusOr>> DecomposeHloModule( + const HloModule& module, bool deduplicate_modules); + +// Extracts an HLO instruction into a new HLO module replacing its operands +// with parameter instructions. +std::unique_ptr ExtractInstructionIntoNewModule( + const HloInstruction& hlo); + +// Extracts an HLO computation into a new HLO module, using its clone as the +// root computation. +std::unique_ptr ExtractComputationIntoNewModule( + const HloComputation& computation); + +} // namespace xla + +#endif // XLA_TOOLS_HLO_DECOMPOSER_H_ diff --git a/third_party/xla/xla/tools/hlo_decomposer_test.cc b/third_party/xla/xla/tools/hlo_decomposer_test.cc new file mode 100644 index 00000000000000..d60f94fdd26aa6 --- /dev/null +++ b/third_party/xla/xla/tools/hlo_decomposer_test.cc @@ -0,0 +1,161 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/tools/hlo_decomposer.h" + +#include +#include + +#include +#include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/tests/filecheck.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { + +class HloDecomposerTest : public HloTestBase { + protected: + std::unique_ptr GetModule() { + absl::string_view kHlo = R"( +HloModule test_module, entry_computation_layout={(bf16[1024,8192]{1,0}, f32[8192]{0}, f32[16384]{0})->(bf16[1024]{0}, bf16[1024]{0}, f32[16384]{0}, f32[16384]{0})} + +add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add.1 = f32[] add(p0, p1) +} + +fused_computation.1 { + param_1.3 = f32[8192]{0} parameter(1) + broadcast.2 = f32[1024,8192]{1,0} broadcast(param_1.3), dimensions={1} + param_0.3 = bf16[1024,8192]{1,0} parameter(0) + convert.5 = f32[1024,8192]{1,0} convert(param_0.3) + multiply.2 = f32[1024,8192]{1,0} multiply(broadcast.2, convert.5) + c0_1 = f32[] constant(0) + reduce.1 = f32[1024]{0} reduce(multiply.2, c0_1), dimensions={1}, to_apply=add + ROOT convert.4 = bf16[1024]{0} convert(reduce.1) +} + +fused_computation.2 { + p0.0 = bf16[1024,8192]{1,0} parameter(0) + c.0 = f32[1024,8192]{1,0} convert(p0.0) + co0_1.1 = f32[] constant(0) + p.0 = f32[8192]{0} parameter(1) + b.0 = f32[1024,8192]{1,0} broadcast(p.0), dimensions={1} + m.0 = f32[1024,8192]{1,0} multiply(b.0, c.0) + r.0 = f32[1024]{0} reduce(m.0, co0_1.1), dimensions={1}, to_apply=add + ROOT c.1 = bf16[1024]{0} convert(r.0) +} + +exp { + param_0.5 = f32[16384]{0} parameter(0) + m.4 = f32[16384]{0} multiply(param_0.5, param_0.5) + e = f32[16384]{0} exponential(m.4) + l.clone.1 = f32[16384]{0} log(m.4) + ROOT tuple = (f32[16384]{0}, f32[16384]{0}) tuple(e, l.clone.1) +} + +ENTRY main { + p0.1 = bf16[1024,8192]{1,0} parameter(0) + p1.1 = f32[8192]{0} parameter(1) + fusion.1 = bf16[1024]{0} fusion(p0.1, p1.1), kind=kInput, calls=fused_computation.1 + fusion.2 = bf16[1024]{0} fusion(p0.1, p1.1), kind=kInput, calls=fused_computation.2 + p2 = f32[16384]{0} parameter(2) + e.1 = (f32[16384]{0}, f32[16384]{0}) fusion(p2), kind=kInput, calls=exp + get-tuple-element.1 = f32[16384]{0} get-tuple-element(e.1), index=1 + get-tuple-element = f32[16384]{0} get-tuple-element(e.1), index=0 + ROOT result = (bf16[1024]{0}, bf16[1024]{0}, f32[16384]{0}, f32[16384]{0}) tuple(fusion.1, fusion.2, get-tuple-element.1, get-tuple-element) +})"; + return ParseAndReturnVerifiedModule(kHlo).value(); + } + + void FindAndCompare(const std::vector>& modules, + absl::string_view module_name, + absl::string_view pattern) { + auto iter = + absl::c_find_if(modules, [&](const std::unique_ptr& module) { + return module->name() == module_name; + }); + EXPECT_NE(iter, modules.end()) << "No module named " << module_name; + if (iter == modules.end()) { + return; + } + EXPECT_TRUE(*RunFileCheck((*iter)->ToString(), pattern)); + } +}; + +TEST_F(HloDecomposerTest, DecomposeNoDedup) { + auto module = GetModule(); + TF_ASSERT_OK_AND_ASSIGN( + auto decomposed, + DecomposeHloModule(*module, /*deduplicate_modules=*/false)); + EXPECT_EQ(decomposed.size(), 3); + + FindAndCompare(decomposed, "fusion.1", R"( +CHECK: %add{{.*}} { +CHECK: %fused_computation.1 +CHECK: ENTRY +CHECK-THEN: %parameter.0 = bf16[1024,8192]{1,0} parameter(0) +CHECK-THEN: %parameter.1 = f32[8192]{0} parameter(1) +CHECK-THEN: ROOT %fusion.1 +)"); + + FindAndCompare(decomposed, "fusion.2", R"( +CHECK: %add{{.*}} { +CHECK: %fused_computation.2 +CHECK: ENTRY +CHECK-THEN: %parameter.0 = bf16[1024,8192]{1,0} parameter(0) +CHECK-THEN: %parameter.1 = f32[8192]{0} parameter(1) +CHECK-THEN: ROOT %fusion.2 +)"); + + FindAndCompare(decomposed, "e.1", R"( +CHECK: %exp{{.*}} { +CHECK: ENTRY +CHECK-THEN: %parameter.0 = f32[16384]{0} parameter(0) +CHECK-THEN: ROOT %e.1 +)"); +} + +TEST_F(HloDecomposerTest, DecomposeDedup) { + auto module = GetModule(); + TF_ASSERT_OK_AND_ASSIGN( + auto decomposed, + DecomposeHloModule(*module, /*deduplicate_modules=*/true)); + EXPECT_EQ(decomposed.size(), 2); + + FindAndCompare(decomposed, "fusion.1", R"( +CHECK: %add{{.*}} { +CHECK: %fused_computation.1 +CHECK: ENTRY +CHECK-THEN: %parameter.0 = bf16[1024,8192]{1,0} parameter(0) +CHECK-THEN: %parameter.1 = f32[8192]{0} parameter(1) +CHECK-THEN: ROOT %fusion.1 +)"); + + FindAndCompare(decomposed, "e.1", R"( +CHECK: %exp{{.*}} { +CHECK: ENTRY +CHECK-THEN: %parameter.0 = f32[16384]{0} parameter(0) +CHECK-THEN: ROOT %e.1 +)"); +} + +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/tools/hlo_expand.cc b/third_party/xla/xla/tools/hlo_expand.cc index dbefb0c46db680..70aed16de04941 100644 --- a/third_party/xla/xla/tools/hlo_expand.cc +++ b/third_party/xla/xla/tools/hlo_expand.cc @@ -61,7 +61,8 @@ void AddPassesToPipeline(HloExpandConfig& config, HloPassPipeline& pipeline, if (config.spmd_expander) { pipeline.AddPass( /*is_spmd=*/true, /*propagate_metadata=*/false, - hlo_module_config.allow_spmd_sharding_propagation_to_output()); + hlo_module_config.allow_spmd_sharding_propagation_to_output(), + hlo_module_config.allow_spmd_sharding_propagation_to_parameters()); pipeline.AddPass( hlo_module_config.num_partitions(), hlo_module_config.replica_count(), hlo_module_config.debug_options() diff --git a/third_party/xla/xla/tools/hlo_extractor.cc b/third_party/xla/xla/tools/hlo_extractor.cc index fad58e8b1cfc36..f9f07153cd9c9e 100644 --- a/third_party/xla/xla/tools/hlo_extractor.cc +++ b/third_party/xla/xla/tools/hlo_extractor.cc @@ -189,7 +189,7 @@ class ExtractionVisitor : public ConstDfsHloVisitorWithDefault { private: // Replace the `hlo` with Constant of the same shape. Status ReplaceWithConstant(const HloInstruction* hlo) { - StatusOr literal_status = MakeFakeLiteral(hlo->shape()); + absl::StatusOr literal_status = MakeFakeLiteral(hlo->shape()); TF_CHECK_OK(literal_status.status()); auto new_const = HloInstruction::CreateConstant(std::move(literal_status.value())); @@ -248,7 +248,8 @@ class ExtractionVisitor : public ConstDfsHloVisitorWithDefault { builder->AddInstruction(HloInstruction::CreateConstant( LiteralUtil::Zero(constant_shape.element_type()))); } else { - StatusOr literal_status = MakeFakeLiteral(constant_shape); + absl::StatusOr literal_status = + MakeFakeLiteral(constant_shape); TF_CHECK_OK(literal_status.status()); constant_instruction = builder->AddInstruction( HloInstruction::CreateConstant(std::move(literal_status.value()))); diff --git a/third_party/xla/xla/tools/hlo_module_loader.h b/third_party/xla/xla/tools/hlo_module_loader.h index b723e47c26d70a..29a4c7e8e51628 100644 --- a/third_party/xla/xla/tools/hlo_module_loader.h +++ b/third_party/xla/xla/tools/hlo_module_loader.h @@ -55,7 +55,7 @@ std::string StripLogHeaders(std::string_view hlo_string); // modifications before use. If the buffer assignment proto pointer is not null // and the hlo module format is proto, it loads buffer assignment from the // proto. -StatusOr> LoadModuleFromData( +absl::StatusOr> LoadModuleFromData( const std::string& data, std::string_view format, const hlo_module_loader_details::Config& ovr_config = hlo_module_loader_details::Config(), @@ -77,7 +77,7 @@ StatusOr> LoadModuleFromData( // modifications before use. If the buffer assignment proto pointer is not null // and the hlo module format is proto, it loads buffer assignment from the // proto. -StatusOr> LoadModuleFromFile( +absl::StatusOr> LoadModuleFromFile( const std::string& path, std::string format = "", const hlo_module_loader_details::Config& ovr_config = hlo_module_loader_details::Config(), @@ -88,8 +88,8 @@ StatusOr> LoadModuleFromFile( // The data format must be one of the following: // 1) A binary proto (format "pb") // 2) A text proto (format "pbtxt") -StatusOr> LoadInputFromData( - const std::string& data, std::string_view format); +absl::StatusOr> +LoadInputFromData(const std::string& data, std::string_view format); // Loads an HLO snapshot from file, only for its inputs // The file must be one of the following: @@ -97,8 +97,8 @@ StatusOr> LoadInputFromData( // 2) A text proto (with a .pbtxt extension) // If the format is specified (not empty), it overrides the one guessed from the // file extension. -StatusOr> LoadInputFromFile( - const std::string& path, std::string format = ""); +absl::StatusOr> +LoadInputFromFile(const std::string& path, std::string format = ""); } // namespace xla diff --git a/third_party/xla/xla/tools/hlo_opt/BUILD b/third_party/xla/xla/tools/hlo_opt/BUILD index fd6a85accd7bb4..3bd009b55481d8 100644 --- a/third_party/xla/xla/tools/hlo_opt/BUILD +++ b/third_party/xla/xla/tools/hlo_opt/BUILD @@ -20,7 +20,8 @@ load( ) package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//xla:internal"], licenses = ["notice"], ) @@ -29,7 +30,6 @@ cc_library( name = "opt_lib", srcs = ["opt_lib.cc"], hdrs = ["opt_lib.h"], - visibility = ["//visibility:public"], deps = [ "//xla:debug_options_flags", "//xla:statusor", @@ -52,7 +52,6 @@ cc_library( name = "gpu_opt", testonly = True, srcs = if_gpu_is_configured(["gpu_opt.cc"]), - visibility = ["//visibility:public"], deps = [ ":opt_lib", "//xla:debug_options_flags", @@ -93,7 +92,6 @@ cc_library( name = "cpu_opt", testonly = True, srcs = ["cpu_opt.cc"], - visibility = ["//visibility:public"], deps = [ ":opt_lib", "//xla/service:cpu_plugin", @@ -108,7 +106,6 @@ cc_library( name = "opt_main", testonly = True, srcs = ["opt_main.cc"], - visibility = ["//visibility:public"], deps = [ "cpu_opt", ":opt_lib", @@ -158,8 +155,10 @@ lit_test_suite( ], ), args = if_cuda_is_configured([ + "--param=PTX=PTX", "--param=GPU=a100_80", ]) + if_rocm_is_configured([ + "--param=PTX=GCN", "--param=GPU=mi200", ]), cfg = "//xla:lit.cfg.py", @@ -184,18 +183,13 @@ filegroup( "//xla/tools:hlo-opt", "@llvm-project//llvm:FileCheck", ], - visibility = ["//visibility:public"], ) filegroup( name = "all_gpu_specs", data = glob(["gpu_specs/*.txtpb"]), - visibility = ["//visibility:public"], ) -exports_files( - glob([ - "gpu_specs/*.txtpb", - ]), - visibility = ["//visibility:public"], -) +exports_files(glob([ + "gpu_specs/*.txtpb", +])) diff --git a/third_party/xla/xla/tools/hlo_opt/cpu_opt.cc b/third_party/xla/xla/tools/hlo_opt/cpu_opt.cc index 16dd19cc24e656..9e7d5c2b72ace7 100644 --- a/third_party/xla/xla/tools/hlo_opt/cpu_opt.cc +++ b/third_party/xla/xla/tools/hlo_opt/cpu_opt.cc @@ -31,7 +31,7 @@ class CpuOptProvider : public OptProvider { } // namespace } // namespace xla -REGISTER_MODULE_INITIALIZER(cpu_opt_provider, { +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(cpu_opt_provider, { xla::OptProvider::RegisterForPlatform( "cpu", std::make_unique()); }); diff --git a/third_party/xla/xla/tools/hlo_opt/gpu_hlo_llvm.hlo b/third_party/xla/xla/tools/hlo_opt/gpu_hlo_llvm.hlo index 8c65771b52f490..59800a9d170560 100644 --- a/third_party/xla/xla/tools/hlo_opt/gpu_hlo_llvm.hlo +++ b/third_party/xla/xla/tools/hlo_opt/gpu_hlo_llvm.hlo @@ -1,4 +1,4 @@ -// RUN: hlo-opt %s --platform=gpu --stage=llvm --xla_gpu_target_config_filename=%S/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck %s +// RUN: hlo-opt %s --platform=gpu --stage=llvm --xla_gpu_target_config_filename=%S/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s HloModule m @@ -10,7 +10,7 @@ add { // CHECK-LABEL: fusion -// CHECK: load <2 x half> +// CHECK: 2 x half ENTRY e { p1 = f16[1048576] parameter(0) i = f16[] constant(0) @@ -23,7 +23,8 @@ HloModule Test, is_scheduled=true // CHECK-LABEL: fusion -// CHECK: call void @llvm.nvvm.barrier0 +// CHECK-PTX: call void @llvm.nvvm.barrier0 +// CHECK-GCN: call void @llvm.amdgcn.s.barrier fused_computation { param_0 = f32[100,200]{1,0} parameter(0) ROOT b.1 = f32[100,200]{0,1} copy(f32[100,200]{1,0} param_0) diff --git a/third_party/xla/xla/tools/hlo_opt/gpu_hlo_unoptimized_llvm.hlo b/third_party/xla/xla/tools/hlo_opt/gpu_hlo_unoptimized_llvm.hlo index 179dd9bbfd717d..6c8bc8bd54fe6a 100644 --- a/third_party/xla/xla/tools/hlo_opt/gpu_hlo_unoptimized_llvm.hlo +++ b/third_party/xla/xla/tools/hlo_opt/gpu_hlo_unoptimized_llvm.hlo @@ -1,4 +1,4 @@ -// RUN: hlo-opt %s --platform=CUDA --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/gpu_specs/a100_80.txtpb | FileCheck %s +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/gpu_specs/%{GPU}.txtpb | FileCheck %s // CHECK: fusion.in_bounds-true: // CHECK: br label diff --git a/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc b/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc index 01e14c9da835d0..3521ccc51664c2 100644 --- a/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc +++ b/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc @@ -131,7 +131,7 @@ class GpuOptProvider : public OptProvider { } // namespace } // namespace xla -REGISTER_MODULE_INITIALIZER(gpu_opt_provider, { +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(gpu_opt_provider, { xla::OptProvider::RegisterForPlatform( "gpu", std::make_unique()); }); diff --git a/third_party/xla/xla/tools/hlo_opt/gpu_specs/a100_80.txtpb b/third_party/xla/xla/tools/hlo_opt/gpu_specs/a100_80.txtpb index 97ae2a95a9b9dc..cf29fa306fdb94 100644 --- a/third_party/xla/xla/tools/hlo_opt/gpu_specs/a100_80.txtpb +++ b/third_party/xla/xla/tools/hlo_opt/gpu_specs/a100_80.txtpb @@ -13,25 +13,24 @@ # limitations under the License. gpu_device_info { - cuda_compute_capability { - major: 8 - minor: 0 - } threads_per_block_limit: 1024 threads_per_warp: 32 - shared_memory_per_block: 65536 - shared_memory_per_block_optin: 65536 - shared_memory_per_core: 65536 + shared_memory_per_block: 49152 + shared_memory_per_core: 167936 threads_per_core_limit: 2048 - core_count: 6192 + core_count: 108 fpus_per_core: 64 block_dim_limit_x: 2147483647 block_dim_limit_y: 65535 block_dim_limit_z: 65535 memory_bandwidth: 2039000000000 - l2_cache_size: 4194304 + l2_cache_size: 41943040 clock_rate_ghz: 1.1105 device_memory_size: 79050250240 + shared_memory_per_block_optin: 166912 + cuda_compute_capability { + major: 8 + } } platform_name: "CUDA" dnn_version_info { diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/BUILD b/third_party/xla/xla/tools/multihost_hlo_runner/BUILD index ce9994d6169108..25c1777f3ba528 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/BUILD +++ b/third_party/xla/xla/tools/multihost_hlo_runner/BUILD @@ -7,7 +7,8 @@ load("@local_tsl//tsl:tsl.bzl", "if_cuda_or_rocm") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//xla:internal"], licenses = ["notice"], ) @@ -57,7 +58,6 @@ cc_library( name = "functional_hlo_runner", srcs = ["functional_hlo_runner.cc"], hdrs = ["functional_hlo_runner.h"], - visibility = ["//visibility:public"], deps = [ "//xla:literal", "//xla:shape_util", @@ -92,7 +92,6 @@ cc_library( name = "hlo_runner_flags", srcs = ["hlo_runner_flags.cc"], hdrs = ["hlo_runner_flags.h"], - visibility = ["//visibility:public"], deps = [ ":functional_hlo_runner", "//xla/pjrt:pjrt_executable", diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc index c52e398193608b..53522cfffba0a2 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc +++ b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc @@ -53,13 +53,13 @@ namespace xla { namespace { // Creates an HloModule from the given proto. -StatusOr> HloTextToModule( +absl::StatusOr> HloTextToModule( absl::string_view hlo_text) { return ParseAndReturnUnverifiedModule(hlo_text); } // Creates an HloModule from the given proto. -StatusOr> HloProtoToModule( +absl::StatusOr> HloProtoToModule( const HloModuleProto& proto) { TF_ASSIGN_OR_RETURN( HloModuleConfig config, @@ -76,7 +76,8 @@ void PopulateWithSameValue(Literal* literal, ElementType val) { } } -StatusOr MakeFakeLiteralWithSameValue(const Shape& shape, int value) { +absl::StatusOr MakeFakeLiteralWithSameValue(const Shape& shape, + int value) { if (!shape.IsArray()) { return InvalidArgument( "MakeFakeLiteralWithSameValue does not support non-array type"); @@ -84,7 +85,7 @@ StatusOr MakeFakeLiteralWithSameValue(const Shape& shape, int value) { Shape new_shape = shape; new_shape.mutable_layout()->clear_tiles(); return primitive_util::PrimitiveTypeSwitch>( - [&](auto type) -> StatusOr { + [&](auto type) -> absl::StatusOr { if constexpr (primitive_util::IsArrayType(type)) { using NativeT = primitive_util::NativeTypeOf; @@ -243,19 +244,21 @@ void AddShardingAnnotationsToSpmdPartitionedModule(HloModule* hlo_module) { set_manual_sharding(entry_root); } -StatusOr> FunctionalHloRunner::CreateGpuClient() { +absl::StatusOr> +FunctionalHloRunner::CreateGpuClient() { return GetStreamExecutorGpuClient(GpuClientOptions()); } -StatusOr> FunctionalHloRunner::CreateMockGpuClient( - int num_nodes) { +absl::StatusOr> +FunctionalHloRunner::CreateMockGpuClient(int num_nodes) { GpuClientOptions options; options.num_nodes = num_nodes; options.enable_mock_nccl = true; return GetStreamExecutorGpuClient(options); } -StatusOr> FunctionalHloRunner::CreateGpuClient( +absl::StatusOr> +FunctionalHloRunner::CreateGpuClient( std::shared_ptr distributed_client, int node_id, int num_nodes) { if (node_id < 0 || node_id >= num_nodes) { @@ -273,7 +276,7 @@ StatusOr> FunctionalHloRunner::CreateGpuClient( return GetStreamExecutorGpuClient(options); } -StatusOr FunctionalHloRunner::LoadExecutionOptions( +absl::StatusOr FunctionalHloRunner::LoadExecutionOptions( absl::string_view path) { ExecutionOptions execution_options; TF_RETURN_IF_ERROR(tsl::ReadTextOrBinaryProto( @@ -281,7 +284,7 @@ StatusOr FunctionalHloRunner::LoadExecutionOptions( return execution_options; } -StatusOr FunctionalHloRunner::CreateCompileOptions( +absl::StatusOr FunctionalHloRunner::CreateCompileOptions( const PjRtClient& client, const FunctionalHloRunner::RawCompileOptions& raw_options, int task_id) { CompileOptions compile_options; @@ -397,10 +400,12 @@ FunctionalHloRunner::CreateExecutableBuildOptionsFromExecutionOptions( build_options.set_use_auto_spmd_partitioning( execution_options.use_auto_spmd_partitioning()); build_options.set_deduplicate_hlo(execution_options.deduplicate_hlo()); + build_options.set_allow_spmd_sharding_propagation_to_parameters( + execution_options.allow_spmd_sharding_propagation_to_parameters()); build_options.set_allow_spmd_sharding_propagation_to_output( execution_options.allow_spmd_sharding_propagation_to_output()); if (execution_options.has_device_assignment()) { - StatusOr> device_assignment = + absl::StatusOr> device_assignment = DeviceAssignment::Deserialize(execution_options.device_assignment()); TF_CHECK_OK(device_assignment.status()); build_options.set_device_assignment(**device_assignment); @@ -445,7 +450,7 @@ absl::Span FunctionalHloRunner::GetLocalDevices( return client.addressable_devices(); } -StatusOr +absl::StatusOr FunctionalHloRunner::LoadHloModuleAndArguments(absl::string_view hlo_file, InputFormat input_format) { HloModuleAndArguments hlo_module_and_arguments; @@ -494,7 +499,7 @@ Status FunctionalHloRunner::LoadAndRunAndDump( : FunctionalHloRunner::DumpOutput(output, dump_output_to, task_id); } -StatusOr +absl::StatusOr FunctionalHloRunner::LoadAndRun(PjRtClient& client, const DebugOptions& debug_options, const PreprocessingOptions& preproc_options, @@ -529,7 +534,7 @@ FunctionalHloRunner::LoadAndRun(PjRtClient& client, hlo_module_and_arguments.hlo_module.get(), loaded_arguments); } -StatusOr +absl::StatusOr FunctionalHloRunner::LoadAndRun( PjRtClient& client, const DebugOptions& debug_options, const PreprocessingOptions& preproc_options, @@ -585,7 +590,7 @@ Status FunctionalHloRunner::LoadAndCompile( return OkStatus(); } -StatusOr> +absl::StatusOr> FunctionalHloRunner::ReadModuleFromHloTextFile(absl::string_view hlo_file) { std::string hlo_string; TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), @@ -593,7 +598,7 @@ FunctionalHloRunner::ReadModuleFromHloTextFile(absl::string_view hlo_file) { return ParseAndReturnUnverifiedModule(hlo_string); } -StatusOr> +absl::StatusOr> FunctionalHloRunner::ReadModuleFromBinaryProtoFile(absl::string_view hlo_file) { HloProto proto; TF_RETURN_IF_ERROR( @@ -601,7 +606,7 @@ FunctionalHloRunner::ReadModuleFromBinaryProtoFile(absl::string_view hlo_file) { return HloProtoToModule(proto.hlo_module()); } -StatusOr> +absl::StatusOr> FunctionalHloRunner::ReadModuleFromTextProtoFile(absl::string_view hlo_file) { HloProto proto; TF_RETURN_IF_ERROR( @@ -609,7 +614,7 @@ FunctionalHloRunner::ReadModuleFromTextProtoFile(absl::string_view hlo_file) { return HloProtoToModule(proto.hlo_module()); } -StatusOr +absl::StatusOr FunctionalHloRunner::ReadModuleFromSnapshotBinaryProtoFile( absl::string_view hlo_file) { HloSnapshot proto; @@ -626,17 +631,17 @@ FunctionalHloRunner::ReadModuleFromSnapshotBinaryProtoFile( return hlo_module_and_arguments; } -StatusOr> FunctionalHloRunner::ReadModuleFromString( - absl::string_view hlo_text) { +absl::StatusOr> +FunctionalHloRunner::ReadModuleFromString(absl::string_view hlo_text) { return HloTextToModule(hlo_text); } -StatusOr> FunctionalHloRunner::ReadModuleFromProto( - const HloModuleProto& proto) { +absl::StatusOr> +FunctionalHloRunner::ReadModuleFromProto(const HloModuleProto& proto) { return HloProtoToModule(proto); } -StatusOr +absl::StatusOr FunctionalHloRunner::CompileAndRun(PjRtClient& client, const DebugOptions& debug_options, const PreprocessingOptions& preproc_options, @@ -651,7 +656,7 @@ FunctionalHloRunner::CompileAndRun(PjRtClient& client, return Run(client, executable.get(), arguments, running_options); } -StatusOr +absl::StatusOr FunctionalHloRunner::CompileAndRun( PjRtClient& client, const DebugOptions& debug_options, const PreprocessingOptions& preproc_options, @@ -779,11 +784,11 @@ CompileOptions FunctionalHloRunner::CompleteCompileOptions( return compile_options; } -StatusOr> FunctionalHloRunner::Compile( - PjRtClient& client, HloModule* hlo_module, - const DebugOptions& debug_options, - const PreprocessingOptions& preproc_options, - const CompileOptions& compile_options) { +absl::StatusOr> +FunctionalHloRunner::Compile(PjRtClient& client, HloModule* hlo_module, + const DebugOptions& debug_options, + const PreprocessingOptions& preproc_options, + const CompileOptions& compile_options) { TF_RETURN_IF_ERROR(PrepareHloModuleForCompilation(hlo_module, debug_options, preproc_options)); CompileOptions modified_compile_options = @@ -796,7 +801,7 @@ StatusOr> FunctionalHloRunner::Compile( return executable; } -StatusOr> FunctionalHloRunner::Compile( +absl::StatusOr> FunctionalHloRunner::Compile( PjRtClient& client, HloModule* hlo_module, const DebugOptions& debug_options, const PreprocessingOptions& preproc_options, @@ -818,11 +823,11 @@ StatusOr> FunctionalHloRunner::Compile( // Runs the executable and may repeat for multiple times. // Since the input buffers may be donated by the PjrtClient, we re-create the // input PjrtBuffers for each repetition. -StatusOr FunctionalHloRunner::Run( - PjRtClient& client, PjRtLoadedExecutable* executable, +absl::StatusOr +FunctionalHloRunner::Run(PjRtClient& client, PjRtLoadedExecutable* executable, - const PerDeviceLiteralVecType& arguments, - const RunningOptions& running_options) { + const PerDeviceLiteralVecType& arguments, + const RunningOptions& running_options) { auto create_argument_buffers_on_device = [&client, &executable, &arguments, &running_options]( bool flatten_tupled_arguments) { @@ -858,7 +863,8 @@ StatusOr FunctionalHloRunner::Run( // Runs the executable and may repeat for multiple times. // Since the input buffers may be donated by the PjrtClient, we re-create the // input PjrtBuffers for each repetition. -StatusOr FunctionalHloRunner::Run( +absl::StatusOr +FunctionalHloRunner::Run( PjRtClient& client, PjRtLoadedExecutable* executable, const LiteralVec& argument_literals, @@ -998,11 +1004,11 @@ Status EnsureSingleTupleForFlattening(const HloModule& module) { } // namespace -StatusOr +absl::StatusOr FunctionalHloRunner::RunInternal( PjRtClient& client, PjRtLoadedExecutable* executable, - std::function< - StatusOr>>>(bool)> + std::function>>>(bool)> create_argument_buffers_on_device, const RunningOptions& running_options) { ExecuteOptions execute_options; @@ -1122,7 +1128,7 @@ FunctionalHloRunner::RunInternal( return results; } -StatusOr>>> +absl::StatusOr>>> FunctionalHloRunner::CreateArgumentsOnDevice( PjRtClient& client, const PjRtLoadedExecutable* executable, const RunningOptions& running_options, bool flatten_arguments) { @@ -1228,7 +1234,7 @@ FunctionalHloRunner::CreateArgumentsOnDevice( running_options.log_input_output()); } -StatusOr>>> +absl::StatusOr>>> FunctionalHloRunner::CreateUninitializedArgumentsOnDevice( PjRtClient& client, const PjRtLoadedExecutable* executable, const RunningOptions& running_options, bool flatten_arguments) { @@ -1312,7 +1318,7 @@ FunctionalHloRunner::CreateUninitializedArgumentsOnDevice( return argument_buffers_per_device; } -StatusOr>>> +absl::StatusOr>>> FunctionalHloRunner::CopyArgumentsToDevice( PjRtClient& client, absl::Span addressable_devices, const PerDeviceLiteralVecType& arguments, bool log_input) { @@ -1357,7 +1363,7 @@ FunctionalHloRunner::CopyArgumentsToDevice( return argument_buffers; } -StatusOr>>> +absl::StatusOr>>> FunctionalHloRunner::CopyArgumentsToDevice( PjRtClient& client, absl::Span addressable_devices, const LiteralVec& argument_literals, @@ -1405,7 +1411,7 @@ FunctionalHloRunner::CopyArgumentsToDevice( return argument_buffers; } -StatusOr +absl::StatusOr FunctionalHloRunner::FetchAndLogOutput( PjRtClient& client, const std::vector>>& output_buffers, diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.h b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.h index d8a34d6fbe10e9..0f541043947c40 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.h +++ b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.h @@ -208,29 +208,29 @@ class FunctionalHloRunner { }; // Create a PjRtClient which can run HLOs on GPU. - static StatusOr> CreateGpuClient(); + static absl::StatusOr> CreateGpuClient(); // Create a PjRtClient which mocks multi-hosts GPU run - static StatusOr> CreateMockGpuClient( + static absl::StatusOr> CreateMockGpuClient( int num_nodes = 1); // Create a PjRtClient which can run HLOs on GPUs distributed across several // nodes. // The distributed client pointer passed as a parameter is expected to be // non-null, and 0 <= node_id < num_nodes must hold. - static StatusOr> CreateGpuClient( + static absl::StatusOr> CreateGpuClient( std::shared_ptr distributed_client, int node_id, int num_nodes); // Loads an ExecutionOptions proto (which can be used in RawCompileOptions). - static StatusOr LoadExecutionOptions( + static absl::StatusOr LoadExecutionOptions( absl::string_view path); // Creates the compilation options. // // If RawCompileOptions::num_slices is set, the // CompileOptions::device_assignment has to be set manually. - static StatusOr CreateCompileOptions( + static absl::StatusOr CreateCompileOptions( const PjRtClient& client, const FunctionalHloRunner::RawCompileOptions& raw_options, int task_id = 0); @@ -251,7 +251,7 @@ class FunctionalHloRunner { // not empty. Otherwise, use arguments from the HLO file or fake arguments. // The hlo file might be a HLO snapshot and thus contain arguments, otherwise // it is run with fake arguments. - static StatusOr LoadAndRun( + static absl::StatusOr LoadAndRun( PjRtClient& client, const DebugOptions& debug_options, const PreprocessingOptions& preproc_options, const CompileOptions& compile_options, @@ -265,7 +265,7 @@ class FunctionalHloRunner { // vector of indices for each local device. This means different device may // use the same argument literals. This is essential to run HLO modules with // large arguments (e.g., models with large weights). - static StatusOr LoadAndRun( + static absl::StatusOr LoadAndRun( PjRtClient& client, const DebugOptions& debug_options, const PreprocessingOptions& preproc_options, const CompileOptions& compile_options, @@ -288,7 +288,7 @@ class FunctionalHloRunner { // Compiles and runs the given HLO module with the given arguments for each // device. The given arguments is a map from device ID to a list of arguments. // If the arguments map is empty, the HLO module is run with fake arguments. - static StatusOr CompileAndRun( + static absl::StatusOr CompileAndRun( PjRtClient& client, const DebugOptions& debug_options, const PreprocessingOptions& preproc_options, const CompileOptions& compile_options, @@ -301,7 +301,7 @@ class FunctionalHloRunner { // contain a vector of indices for each local device. This means different // devices may use the same argument literals. This is essential to run HLO // modules with large arguments (e.g., models with large weights). - static StatusOr CompileAndRun( + static absl::StatusOr CompileAndRun( PjRtClient& client, const DebugOptions& debug_options, const PreprocessingOptions& preproc_options, const CompileOptions& compile_options, @@ -310,7 +310,7 @@ class FunctionalHloRunner { const PerDeviceIndexVecType& argument_indices); // Compiles the HLO module. - static StatusOr> Compile( + static absl::StatusOr> Compile( PjRtClient& client, HloModule* hlo_module, const DebugOptions& debug_options, const PreprocessingOptions& preproc_options, @@ -319,7 +319,7 @@ class FunctionalHloRunner { // Ahead-of-time compilation using the PjRtTopologyDescription that's passed // instead of using the registered topology. This enables reproduction of // compilation based on captured information. - static StatusOr> Compile( + static absl::StatusOr> Compile( PjRtClient& client, HloModule* hlo_module, const DebugOptions& debug_options, const PreprocessingOptions& preproc_options, @@ -327,35 +327,35 @@ class FunctionalHloRunner { const PjRtTopologyDescription& topology); // Runs the executable. - static StatusOr Run( + static absl::StatusOr Run( PjRtClient& client, PjRtLoadedExecutable* executable, const PerDeviceLiteralVecType& arguments, const RunningOptions& running_options); // Runs the executable, where the module arguments are provided through // a shared literal vector and per-device indices. - static StatusOr Run( + static absl::StatusOr Run( PjRtClient& client, PjRtLoadedExecutable* executable, const LiteralVec& argument_literals, const PerDeviceIndexVecType& argument_indices, const RunningOptions& running_options); - static StatusOr> ReadModuleFromHloTextFile( + static absl::StatusOr> ReadModuleFromHloTextFile( absl::string_view hlo_file); - static StatusOr> ReadModuleFromBinaryProtoFile( - absl::string_view hlo_file); - static StatusOr> ReadModuleFromTextProtoFile( + static absl::StatusOr> + ReadModuleFromBinaryProtoFile(absl::string_view hlo_file); + static absl::StatusOr> ReadModuleFromTextProtoFile( absl::string_view hlo_file); - static StatusOr ReadModuleFromSnapshotBinaryProtoFile( - absl::string_view hlo_file); - static StatusOr LoadHloModuleAndArguments( + static absl::StatusOr + ReadModuleFromSnapshotBinaryProtoFile(absl::string_view hlo_file); + static absl::StatusOr LoadHloModuleAndArguments( absl::string_view hlo_file, InputFormat input_format); - static StatusOr> ReadModuleFromString( + static absl::StatusOr> ReadModuleFromString( absl::string_view hlo_text); - static StatusOr> ReadModuleFromProto( + static absl::StatusOr> ReadModuleFromProto( const HloModuleProto& proto); // This would ideally be private, but we need it for the implementation of @@ -394,14 +394,14 @@ class FunctionalHloRunner { const PjRtClient& client); // Creates fake arguments to run the given executable. - static StatusOr>>> + static absl::StatusOr>>> CreateArgumentsOnDevice(PjRtClient& client, const PjRtLoadedExecutable* executable, const RunningOptions& running_options, bool flatten_arguments = false); // Creates uninitialized arguments to run the given executable. - static StatusOr>>> + static absl::StatusOr>>> CreateUninitializedArgumentsOnDevice(PjRtClient& client, const PjRtLoadedExecutable* executable, const RunningOptions& running_options, @@ -409,27 +409,27 @@ class FunctionalHloRunner { // Creates argument buffers based on the given arguments map. Note that the // arguments might be invalid when arguments are destructed. - static StatusOr>>> + static absl::StatusOr>>> CopyArgumentsToDevice(PjRtClient& client, absl::Span addressable_devices, const PerDeviceLiteralVecType& arguments, bool log_input); - static StatusOr>>> + static absl::StatusOr>>> CopyArgumentsToDevice(PjRtClient& client, absl::Span addressable_devices, const LiteralVec& argument_literals, const PerDeviceIndexVecType& argument_indices, bool log_input); - static StatusOr RunInternal( + static absl::StatusOr RunInternal( PjRtClient& client, PjRtLoadedExecutable* executable, - std::function< - StatusOr>>>(bool)> + std::function>>>(bool)> create_argument_buffers_on_device, const RunningOptions& running_options); - static StatusOr FetchAndLogOutput( + static absl::StatusOr FetchAndLogOutput( PjRtClient& client, const std::vector>>& output_buffers, diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc index dbc208039d26ff..662bbdb7e46dcb 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc +++ b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc @@ -211,7 +211,7 @@ TEST_F(FunctionalHloRunnerTest, CanCompileWithoutHavingEnoughGpus) { std::string after_opt_hlo; TF_ASSERT_OK( tsl::ReadFileToString(env, after_opt_hlo_paths[0], &after_opt_hlo)); - StatusOr file_check_result = RunFileCheck(after_opt_hlo, R"( + absl::StatusOr file_check_result = RunFileCheck(after_opt_hlo, R"( // CHECK: param = f32[16,1]{1,0} // CHECK: add = f32[16,1]{1,0} )"); diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/hlo_runner_flags.cc b/third_party/xla/xla/tools/multihost_hlo_runner/hlo_runner_flags.cc index 6bae928f051447..600fffaf023b89 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/hlo_runner_flags.cc +++ b/third_party/xla/xla/tools/multihost_hlo_runner/hlo_runner_flags.cc @@ -129,7 +129,7 @@ bool MultiHostHloRunnerFlags::CreateOptionsFromFlags( ? FunctionalHloRunner::SpmdMode::kUseSpmdPartitioning : FunctionalHloRunner::SpmdMode::kNotUseSpmdPartitioning; if (!flag_values_.execution_options_path.empty()) { - StatusOr execution_options = + absl::StatusOr execution_options = FunctionalHloRunner::LoadExecutionOptions( flag_values_.execution_options_path); if (!execution_options.ok()) { diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/hlo_runner_main.cc b/third_party/xla/xla/tools/multihost_hlo_runner/hlo_runner_main.cc index e1b67bd5dfb724..e492b08bddbcee 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/hlo_runner_main.cc +++ b/third_party/xla/xla/tools/multihost_hlo_runner/hlo_runner_main.cc @@ -131,7 +131,7 @@ int main(int argc, char** argv) { } // The main logic: - xla::StatusOr> client; + absl::StatusOr> client; if (enable_mock_nccl) { CHECK_GT(num_nodes, 1); client = xla::FunctionalHloRunner::CreateMockGpuClient(num_nodes); diff --git a/third_party/xla/xla/tools/run_hlo_module.cc b/third_party/xla/xla/tools/run_hlo_module.cc index 85576d23daa7e8..4eab3038c0b546 100644 --- a/third_party/xla/xla/tools/run_hlo_module.cc +++ b/third_party/xla/xla/tools/run_hlo_module.cc @@ -16,32 +16,80 @@ limitations under the License. #include "xla/tools/run_hlo_module.h" #include +#include #include +#include #include +#include #include +#include #include #include #include +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/error_spec.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal.h" #include "xla/literal_comparison.h" #include "xla/service/hlo.pb.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/hlo_verifier.h" +#include "xla/status.h" #include "xla/tests/test_utils.h" #include "xla/tools/hlo_control_flow_flattening.h" +#include "xla/tools/hlo_decomposer.h" #include "xla/tools/hlo_module_loader.h" #include "xla/tools/prepare_reference_module.h" #include "xla/tools/run_hlo_module.pb.h" #include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" #include "tsl/platform/path.h" #include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { +enum class ModuleResult { + kMatched, + kRan, + kSkipped, + kDidntRun, + kOtherError, + kCompilationError, + kRuntimeError, + kMismatch, +}; + +constexpr absl::string_view ModuleResultToString(ModuleResult result) { + switch (result) { + case ModuleResult::kMatched: + return "MATCHED"; + case ModuleResult::kRan: + return "RAN"; + case ModuleResult::kSkipped: + return "SKIPPED"; + case ModuleResult::kDidntRun: + return "DIDN'T RUN"; + case ModuleResult::kOtherError: + return "OTHER ERROR"; + case ModuleResult::kCompilationError: + return "COMPILATION ERROR"; + case ModuleResult::kRuntimeError: + return "RUNTIME ERROR"; + case ModuleResult::kMismatch: + return "MISMATCH"; + } +} // Writes the given literal to a file in the test temporary directory. void WriteLiteralToTempFile(const LiteralSlice& literal, @@ -118,9 +166,8 @@ StatusOr ExecuteWithRunner( return std::move(result_status).value(); } -} // namespace -Status RunAndCompare( +Status RunAndCompareInternal( std::unique_ptr test_module, const BufferAssignmentProto* buffer_assignment_proto, HloRunnerInterface* test_runner, HloRunnerInterface* reference_runner, @@ -128,7 +175,16 @@ Status RunAndCompare( xla::RunHloModuleIterationLiterals* iteration_literals_proto, std::function reference_module_modifier_hook, - std::function config_modifier_hook) { + std::function config_modifier_hook, + ModuleResult* test_run_result, ModuleResult* reference_run_result) { + auto copy_result_on_failure = [](auto status, ModuleResult result, + ModuleResult* out_result) { + if (!status.ok() && out_result != nullptr) { + *out_result = result; + } + return status; + }; + if (!config_modifier_hook) { config_modifier_hook = [](HloModuleConfig* config) { config->set_seed(42); @@ -138,19 +194,27 @@ Status RunAndCompare( if (options.flatten_control_flow) { HloControlFlowFlattening control_flow_flattening( HloControlFlowFlattening::Options{/*while_execution_count=*/1}); - TF_RETURN_IF_ERROR(control_flow_flattening.Run(test_module.get()).status()); + TF_RETURN_IF_ERROR( + copy_result_on_failure(control_flow_flattening.Run(test_module.get()), + ModuleResult::kCompilationError, test_run_result) + .status()); } const HloModuleProto test_module_proto = test_module->ToProto(); - TF_ASSIGN_OR_RETURN(auto args, - MakeFakeArguments(test_module.get(), engine, - options.use_large_float_range, - options.treat_gte_as_data_formatting)); + TF_ASSIGN_OR_RETURN( + auto args, copy_result_on_failure( + MakeFakeArguments(test_module.get(), engine, + options.use_large_float_range, + options.treat_gte_as_data_formatting), + ModuleResult::kOtherError, test_run_result)); // Use provided input literals as arguments, if any. if (iteration_literals_proto != nullptr && iteration_literals_proto->arguments_size() != 0) { if (iteration_literals_proto->arguments_size() != args.size()) { + if (test_run_result != nullptr) { + *test_run_result = ModuleResult::kOtherError; + } return xla::InvalidArgument( "Failed to use input literals as arguments; mismatched " "number of expected arguments."); @@ -160,14 +224,19 @@ Status RunAndCompare( xla::Shape(args[i].shape()), xla::Shape(iteration_literals_proto->arguments(i).shape())) .ok()) { + if (test_run_result != nullptr) { + *test_run_result = ModuleResult::kOtherError; + } return xla::InvalidArgument( "Failed to use input literals for argument %d " "because of a shape mismatch.", i); } - TF_ASSIGN_OR_RETURN(args[i], - xla::Literal::CreateFromProto( - iteration_literals_proto->arguments(i))); + TF_ASSIGN_OR_RETURN( + args[i], + copy_result_on_failure(xla::Literal::CreateFromProto( + iteration_literals_proto->arguments(i)), + ModuleResult::kOtherError, test_run_result)); } } } @@ -190,14 +259,22 @@ Status RunAndCompare( // properly match the test runner's numerics. TF_ASSIGN_OR_RETURN( reference_module, - PrepareReferenceModule(*test_module, test_runner, config_modifier_hook, - reference_module_modifier_hook)); + copy_result_on_failure( + PrepareReferenceModule(*test_module, test_runner, + config_modifier_hook, + reference_module_modifier_hook), + ModuleResult::kCompilationError, reference_run_result)); } TF_ASSIGN_OR_RETURN( auto test_result, - ExecuteWithRunner(std::move(test_module), buffer_assignment_proto, args, - test_runner, options.run_test_hlo_passes)); + copy_result_on_failure( + ExecuteWithRunner(std::move(test_module), buffer_assignment_proto, + args, test_runner, options.run_test_hlo_passes), + ModuleResult::kRuntimeError, test_run_result)); + if (test_run_result != nullptr) { + *test_run_result = ModuleResult::kRan; + } if (options.print_literals) { std::cout << "\n** Result with test runner " << test_runner->Name() << " **\n" @@ -209,15 +286,31 @@ Status RunAndCompare( } if (reference_module == nullptr) { - std::cerr << "Skipping reference runner\n"; + std::cerr << "Skipping reference runner"; + return OkStatus(); + } + if (const HloInstruction* root_instruction = + reference_module->entry_computation()->root_instruction(); + root_instruction->opcode() == HloOpcode::kCustomCall) { + // TODO(b/323849999) Use original computation for the reference platform. + std::cerr << "Skipping reference runner for a custom call " + << root_instruction->custom_call_target() << "\n"; + if (reference_run_result != nullptr) { + *reference_run_result = ModuleResult::kSkipped; + } return OkStatus(); } TF_ASSIGN_OR_RETURN( auto reference_result, - ExecuteWithRunner(std::move(reference_module), - /*buffer_assignment_proto=*/nullptr, args, - reference_runner, options.run_reference_hlo_passes)); + copy_result_on_failure( + ExecuteWithRunner(std::move(reference_module), + /*buffer_assignment_proto=*/nullptr, args, + reference_runner, options.run_reference_hlo_passes), + ModuleResult::kRuntimeError, reference_run_result)); + if (reference_run_result != nullptr) { + *reference_run_result = ModuleResult::kRan; + } if (options.print_literals) { std::cout << "\n** Result with reference runner " @@ -231,10 +324,164 @@ Status RunAndCompare( } ErrorSpec error_spec(static_cast(options.abs_error_bound), static_cast(options.rel_error_bound)); - return literal_comparison::Near(/*expected=*/reference_result, - /*actual=*/test_result, - /*error=*/error_spec, - /*detailed_message=*/true, &OnMiscompare); + + Status comparison_status = + literal_comparison::Near(/*expected=*/reference_result, + /*actual=*/test_result, + /*error=*/error_spec, + /*detailed_message=*/true, &OnMiscompare); + const ModuleResult comparison_result = + comparison_status.ok() ? ModuleResult::kMatched : ModuleResult::kMismatch; + if (test_run_result != nullptr) { + *test_run_result = comparison_result; + } + if (reference_run_result != nullptr) { + *reference_run_result = comparison_result; + } + return comparison_status; +} + +struct ChunkResult { + std::string module_name; + ModuleResult test_result = ModuleResult::kDidntRun; + ModuleResult reference_result = ModuleResult::kDidntRun; + Status status; + + bool operator<(const ChunkResult& other) const { + if (test_result != other.test_result) { + return test_result < other.test_result; + } + return reference_result < other.reference_result; + } +}; + +std::string BuildResultsTable(absl::Span chunk_results, + size_t num_modules) { + constexpr int kStatusWidth = 21; + constexpr int kNameWidth = 30; + constexpr int kThreeColumnsWidth = 5 + 2 * kStatusWidth + kNameWidth; + constexpr int kTableWidth = kThreeColumnsWidth + 30; + + std::ostringstream strstr; + auto print_row = [&](absl::string_view reference, absl::string_view test, + absl::string_view module_name, absl::string_view error) { + std::string formatted_error = absl::StrReplaceAll( + error, {{"\n", absl::StrCat("\n", std::string(kThreeColumnsWidth, ' '), + "|")}}); + strstr << " " << std::left << std::setw(kStatusWidth) << reference << "| " + << std::setw(kStatusWidth) << test << "| " << std::setw(kNameWidth) + << module_name << "| " << formatted_error << "\n"; + }; + auto print_line = [&](int line_width) { + strstr << std::string(line_width, '-') << "\n"; + }; + + print_row("Reference", "Test", "Module", "Status"); + print_line(kTableWidth); + + std::map, int> result_counts; + + for (const ChunkResult& chunk_result : chunk_results) { + const std::pair result_pair( + chunk_result.reference_result, chunk_result.test_result); + + ++result_counts[result_pair]; + print_row(ModuleResultToString(chunk_result.reference_result), + ModuleResultToString(chunk_result.test_result), + chunk_result.module_name, chunk_result.status.ToString()); + } + print_line(kTableWidth); + print_row("Reference", "Test", "Module", "Status"); + print_line(kTableWidth); + + strstr << "\n\n"; + + // Summary table. + print_line(kThreeColumnsWidth); + print_row("Reference", "Test", "Total count", ""); + print_line(kThreeColumnsWidth); + for (const auto& [result, count] : result_counts) { + print_row(ModuleResultToString(result.first), + ModuleResultToString(result.second), absl::StrCat(count), ""); + } + print_line(kThreeColumnsWidth); + if (chunk_results.size() < num_modules) { + strstr << "\n(did not " << (num_modules - chunk_results.size()) + << " modules due to earlier failures)\n\n"; + } + return strstr.str(); +} + +Status RunIsolatedAndCompare( + std::unique_ptr test_module, + const BufferAssignmentProto* buffer_assignment_proto, + HloRunnerInterface* test_runner, HloRunnerInterface* reference_runner, + std::minstd_rand0* engine, const RunHloModuleOptions& options, + xla::RunHloModuleIterationLiterals* iteration_literals_proto, + std::function + reference_module_modifier_hook, + std::function config_modifier_hook) { + CHECK(test_module); + CHECK(iteration_literals_proto == nullptr) + << "Cannot run decomposed module if input literals are provided."; + if (options.run_test_hlo_passes || (options.run_reference_hlo_passes && + !options.reference_platform.empty())) { + LOG(WARNING) + << "!!! Warning !!! When running decomposed module, running HLO " + "passes is likely not what you want. If you have unoptimized " + "HLO, first convert it to the optimized e.g. using the " + "hlo-opt tool, and then isolate without HLO passes."; + } + + std::vector chunk_results; + + TF_ASSIGN_OR_RETURN( + std::vector> modules, + DecomposeHloModule(*test_module, /*deduplicate_modules=*/true)); + + Status status = OkStatus(); + for (std::unique_ptr& module : modules) { + const std::string module_name = module->name(); + ModuleResult test_module_result = ModuleResult::kDidntRun; + ModuleResult reference_module_result = ModuleResult::kDidntRun; + Status chunk_status = RunAndCompareInternal( + std::move(module), buffer_assignment_proto, test_runner, + reference_runner, engine, options, iteration_literals_proto, + reference_module_modifier_hook, config_modifier_hook, + &test_module_result, &reference_module_result); + chunk_results.push_back({std::move(module_name), test_module_result, + reference_module_result, chunk_status}); + status.Update(chunk_status); + if (!chunk_status.ok() && test_module_result != ModuleResult::kMismatch) { + break; + } + } + absl::c_sort(chunk_results); + std::cout << BuildResultsTable(chunk_results, modules.size()); + return status; +} + +} // namespace + +Status RunAndCompare( + std::unique_ptr test_module, + const BufferAssignmentProto* buffer_assignment_proto, + HloRunnerInterface* test_runner, HloRunnerInterface* reference_runner, + std::minstd_rand0* engine, const RunHloModuleOptions& options, + xla::RunHloModuleIterationLiterals* iteration_literals_proto, + std::function + reference_module_modifier_hook, + std::function config_modifier_hook) { + if (options.isolate_instructions) { + return RunIsolatedAndCompare( + std::move(test_module), buffer_assignment_proto, test_runner, + reference_runner, engine, options, iteration_literals_proto, + reference_module_modifier_hook, config_modifier_hook); + } + return RunAndCompareInternal( + std::move(test_module), buffer_assignment_proto, test_runner, + reference_runner, engine, options, iteration_literals_proto, + reference_module_modifier_hook, config_modifier_hook, nullptr, nullptr); } Status RunAndCompare( @@ -272,7 +519,7 @@ Status RunAndCompare( std::unique_ptr iteration_literals_proto_local; if (iteration_literals_proto == nullptr) { // User did not explicitly give input - if (!options.force_fake_data && + if (!options.force_fake_data && !options.isolate_instructions && (options.input_format == "pb" || options.input_format == "pbtxt")) { // User is giving a snapshot (which contains inputs) LOG(INFO) << "Using input data from the user-provided snapshot."; diff --git a/third_party/xla/xla/tools/run_hlo_module.h b/third_party/xla/xla/tools/run_hlo_module.h index 0f9cc805b4e1d7..80b16b277c7f08 100644 --- a/third_party/xla/xla/tools/run_hlo_module.h +++ b/third_party/xla/xla/tools/run_hlo_module.h @@ -23,6 +23,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_runner.h" +#include "xla/status.h" #include "xla/tools/run_hlo_module.pb.h" #include "tsl/platform/status.h" @@ -55,6 +56,7 @@ struct RunHloModuleOptions { std::string input_literals_file; bool random_init_input_literals{true}; bool force_fake_data{false}; + bool isolate_instructions{false}; }; // Runs test_module on the platform with the name diff --git a/third_party/xla/xla/tools/run_hlo_module_main.cc b/third_party/xla/xla/tools/run_hlo_module_main.cc index 6e5dc66f42c0b5..6f4005a2da5d9c 100644 --- a/third_party/xla/xla/tools/run_hlo_module_main.cc +++ b/third_party/xla/xla/tools/run_hlo_module_main.cc @@ -127,6 +127,12 @@ int main(int argc, char** argv) { "iterations", &opts.iterations, "The number of times to run the module. Each iteration will be run " "with different input data."), + tsl::Flag( + "isolate_instructions", &opts.isolate_instructions, + "Rather than executing the entire module at once, run every " + "instruction individually, including the top-level and control-flow " + "dependent computations (e.g. inside conditions, calls). Skip " + "instructions inside fused computations etc."), tsl::Flag("different_random_seeds", &different_random_seeds, "Whether each iteration should use a different random seed for " "the HloModuleConfig."), @@ -194,8 +200,7 @@ int main(int argc, char** argv) { } if (!reference_platform_name.empty()) { - std::cerr << failure_count << "/" << iteration_count - << " runs miscompared.\n"; + std::cerr << failure_count << "/" << iteration_count << " runs failed.\n"; } return failure_count == 0 ? 0 : -1; diff --git a/third_party/xla/xla/tools/xla_compile_lib_test.cc b/third_party/xla/xla/tools/xla_compile_lib_test.cc index 953f6ad54499ab..380e6466a01145 100644 --- a/third_party/xla/xla/tools/xla_compile_lib_test.cc +++ b/third_party/xla/xla/tools/xla_compile_lib_test.cc @@ -89,6 +89,7 @@ TEST_F(XlaCompileLibTest, DISABLED_ON_CPU(CompilesForGpuWithDevice)) { EXPECT_THAT( CompileExecutable(std::move(module_), "gpu", std::nullopt, result), IsOkAndHolds(Not(IsEmpty()))); + EXPECT_TRUE(result.has_hlo_module()) << result.DebugString(); } TEST_F(XlaCompileLibTest, DISABLED_ON_CPU(CompilesForGpuWithoutDevice)) { @@ -102,14 +103,6 @@ TEST_F(XlaCompileLibTest, DISABLED_ON_CPU(CompilesForGpuWithoutDevice)) { EXPECT_THAT( CompileExecutable(std::move(module_), "gpu", std::nullopt, result), IsOkAndHolds(Not(IsEmpty()))); -} - -TEST_F(XlaCompileLibTest, - DISABLED_ON_CPU(ReturnsOptimizedModuleWhenRequested)) { - CompilationResult result; - EXPECT_THAT( - CompileExecutable(std::move(module_), "gpu", std::nullopt, result), - IsOkAndHolds(Not(IsEmpty()))); EXPECT_TRUE(result.has_hlo_module()) << result.DebugString(); } diff --git a/third_party/xla/xla/translate/BUILD b/third_party/xla/xla/translate/BUILD index 302ac9c5432700..aad084173418de 100644 --- a/third_party/xla/xla/translate/BUILD +++ b/third_party/xla/xla/translate/BUILD @@ -1,8 +1,13 @@ load("@bazel_skylib//rules:build_test.bzl", "build_test") load("//xla:xla.bzl", "xla_cc_binary") +load("@local_tsl//tsl:tsl.bzl", "internal_visibility") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//learning/brain/mlir:tensorflow_friends", + "//learning/brain/mlir:xla_friends", + ]), licenses = ["notice"], ) @@ -23,7 +28,6 @@ xla_cc_binary( "//xla/stream_executor/host:host_platform", "//xla/translate/hlo_to_mhlo:translate_registration", "//xla/translate/mhlo_to_hlo:translate_registration", - "//xla/translate/mhlo_to_lhlo_with_xla:translate_registration", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", @@ -31,3 +35,28 @@ xla_cc_binary( "@local_tsl//tsl/platform:platform_port", ], ) + +build_test( + name = "xla-translate-opt_build_test", + targets = [ + ":xla-translate-opt", + ], +) + +xla_cc_binary( + name = "xla-translate-opt", + testonly = True, + srcs = ["xla_translate_opt_main.cc"], + deps = [ + "//xla/mlir/framework/ir:xla_framework", + "//xla/mlir/framework/transforms:passes", + "//xla/mlir_hlo:hlo_dialect_registration", + "//xla/service:cpu_plugin", + "//xla/service/cpu:hlo_xla_runtime_pipeline", # buildcleaner: keep + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:MlirOptLib", + "@local_tsl//tsl/platform:platform_port", + "@stablehlo//:register", + ], +) diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/BUILD b/third_party/xla/xla/translate/hlo_to_mhlo/BUILD index 19dc3a498b2e4d..c3092da10f93ed 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/BUILD +++ b/third_party/xla/xla/translate/hlo_to_mhlo/BUILD @@ -1,8 +1,13 @@ load("//xla:xla.bzl", "xla_cc_test") +load("@local_tsl//tsl:tsl.bzl", "internal_visibility") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//learning/brain/mlir:tensorflow_friends", + "//learning/brain/mlir:xla_friends", + ]), licenses = ["notice"], ) @@ -10,7 +15,6 @@ cc_library( name = "attribute_importer", srcs = ["attribute_importer.cc"], hdrs = ["attribute_importer.h"], - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla:statusor", @@ -22,11 +26,28 @@ cc_library( ], ) +cc_library( + name = "custom_call_importer", + srcs = ["custom_call_importer.cc"], + hdrs = ["custom_call_importer.h"], + deps = [ + "//xla:status", + "//xla:statusor", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/mlir_hlo", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "stack_location_utils", srcs = ["stack_location_utils.cc"], hdrs = ["stack_location_utils.h"], - visibility = ["//visibility:public"], deps = [ "//xla/hlo/ir:hlo", "//xla/service:hlo_proto_cc", @@ -44,9 +65,9 @@ cc_library( "hlo_function_importer.h", "hlo_module_importer.h", ], - visibility = ["//visibility:public"], deps = [ ":attribute_importer", + ":custom_call_importer", ":hlo_utils", ":location_importer", "//xla:comparison_util", @@ -66,11 +87,13 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:AsmParser", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:SparseTensorDialect", "@local_tsl//tsl/platform:statusor", ], @@ -80,7 +103,6 @@ cc_library( name = "hlo_to_mlir_hlo", srcs = ["hlo_to_mlir_hlo.cc"], hdrs = ["hlo_to_mlir_hlo.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_module_importer", "//xla:status", @@ -94,7 +116,6 @@ cc_library( srcs = ["hlo_utils.cc"], hdrs = ["hlo_utils.h"], includes = ["include"], - visibility = ["//visibility:public"], deps = [ "//xla:literal", "//xla:shape_util", @@ -134,7 +155,6 @@ cc_library( name = "location_importer", srcs = ["location_importer.cc"], hdrs = ["location_importer.h"], - visibility = ["//visibility:public"], deps = [ "stack_location_utils", "//xla/hlo/ir:hlo", @@ -146,7 +166,6 @@ cc_library( name = "translate", srcs = ["translate.cc"], hdrs = ["translate.h"], - visibility = ["//visibility:public"], deps = [ ":hlo_to_mlir_hlo", "//xla:status", @@ -163,7 +182,6 @@ cc_library( name = "translate_registration", testonly = True, srcs = ["translate_registration.cc"], - visibility = ["//visibility:public"], deps = [ ":translate", "@llvm-project//mlir:IR", diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/custom_call_importer.cc b/third_party/xla/xla/translate/hlo_to_mhlo/custom_call_importer.cc new file mode 100644 index 00000000000000..472207133b2390 --- /dev/null +++ b/third_party/xla/xla/translate/hlo_to_mhlo/custom_call_importer.cc @@ -0,0 +1,123 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/translate/hlo_to_mhlo/custom_call_importer.h" + +#include +#include + +#include "absl/strings/match.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/AsmParser/AsmParser.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/status.h" +#include "xla/util.h" + +namespace xla { +namespace { + +StatusOr ImportDynamicBroadcastInDimOp( + mlir::StringRef backend_config, mlir::Location loc, mlir::Type result_type, + mlir::ValueRange operands, mlir::OpBuilder* builder) { + if (backend_config.empty()) { + return Internal("backend_config attribute cannot be empty."); + } + + auto attr = mlir::parseAttribute(backend_config, builder->getContext()) + .dyn_cast(); + if (!attr) { + return Internal( + "Couldn't parse backend config into a dictionary attribute"); + } + + auto broadcast_dimensions_attr = + attr.get("broadcast_dimensions").dyn_cast_or_null(); + if (!broadcast_dimensions_attr) { + return Internal("broadcast_dimensions attribute is required."); + } + + std::vector broadcast_dimensions(broadcast_dimensions_attr.size()); + for (auto [i, broadcast_dimension] : + llvm::enumerate(broadcast_dimensions_attr)) { + broadcast_dimensions[i] = + broadcast_dimension.cast().getInt(); + } + + return builder + ->create( + loc, result_type, operands[0], operands[1], + builder->getI64TensorAttr(broadcast_dimensions)) + .getOperation(); +} + +StatusOr ImportDynamicReshapeOp( + mlir::StringRef backend_config, mlir::Location loc, mlir::Type result_type, + mlir::ValueRange operands, mlir::OpBuilder* builder) { + if (!backend_config.empty()) { + return Internal("backend_config attribute must be empty."); + } + return builder + ->create(loc, result_type, operands) + .getOperation(); +} + +StatusOr ImportRealDynamicSliceOp( + mlir::StringRef backend_config, mlir::Location loc, mlir::Type result_type, + mlir::ValueRange operands, mlir::OpBuilder* builder) { + if (!backend_config.empty()) { + return Internal("backend_config attribute must be empty."); + } + return builder + ->create(loc, result_type, operands) + .getOperation(); +} + +} // namespace + +absl::StatusOr ImportCustomCallAsOp( + const HloCustomCallInstruction* instruction, mlir::Location loc, + mlir::Type result_type, mlir::ValueRange operands, + mlir::OpBuilder* builder) { + const std::string& custom_call_target = instruction->custom_call_target(); + const std::string& backend_config_str = + instruction->raw_backend_config_string(); + if (custom_call_target == "mhlo.dynamic_broadcast_in_dim") { + return ImportDynamicBroadcastInDimOp(backend_config_str, loc, result_type, + operands, builder); + } + if (custom_call_target == "mhlo.dynamic_reshape") { + return ImportDynamicReshapeOp(backend_config_str, loc, result_type, + operands, builder); + } + if (custom_call_target == "mhlo.real_dynamic_slice") { + return ImportRealDynamicSliceOp(backend_config_str, loc, result_type, + operands, builder); + } + return InvalidArgument("Unsupported MHLO op custom_call %s", + custom_call_target); +} + +bool IsOpEncodedCustomCall(const HloCustomCallInstruction* instruction) { + return absl::StartsWith(instruction->custom_call_target(), "mhlo."); +} + +} // namespace xla diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/custom_call_importer.h b/third_party/xla/xla/translate/hlo_to_mhlo/custom_call_importer.h new file mode 100644 index 00000000000000..c3fde619c23cd9 --- /dev/null +++ b/third_party/xla/xla/translate/hlo_to_mhlo/custom_call_importer.h @@ -0,0 +1,45 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TRANSLATE_HLO_TO_MHLO_CUSTOM_CALL_IMPORTER_H_ +#define XLA_TRANSLATE_HLO_TO_MHLO_CUSTOM_CALL_IMPORTER_H_ + +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/statusor.h" + +namespace xla { + +// Imports custom_calls prefixed with `mhlo.` from HLO to MHLO. +// This is used for ops in MHLO / StableHLO that don't exist in HLO. Many of +// these ops are needed for XlaBuilder clients that need to raise HLO to +// StableHLO. +StatusOr ImportCustomCallAsOp( + const HloCustomCallInstruction* instruction, mlir::Location loc, + mlir::Type result_type, mlir::ValueRange operands, + mlir::OpBuilder* builder); + +// Indicates whether a custom call is an encoded MHLO op. +// Currently returns true for `mhlo.` prefixed custom calls. +bool IsOpEncodedCustomCall(const HloCustomCallInstruction* instruction); + +} // namespace xla + +#endif // XLA_TRANSLATE_HLO_TO_MHLO_CUSTOM_CALL_IMPORTER_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc index a2d74adf8c29aa..06526c6bba5061 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc +++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/strings/match.h" #include "absl/types/optional.h" +#include "absl/types/span.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -48,13 +49,18 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/hlo_sharding_metadata.h" +#include "xla/layout_util.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/printer.h" #include "xla/protobuf_util.h" #include "xla/service/hlo.pb.h" +#include "xla/shape_util.h" +#include "xla/status.h" #include "xla/status_macros.h" #include "xla/translate/hlo_to_mhlo/attribute_importer.h" +#include "xla/translate/hlo_to_mhlo/custom_call_importer.h" #include "xla/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/translate/hlo_to_mhlo/location_importer.h" #include "xla/util.h" @@ -105,7 +111,7 @@ bool DotIsDefault(const HloInstruction* instruction) { default_dimension_numbers.add_lhs_contracting_dimensions( instruction->operand(0)->shape().dimensions_size() == 1 ? 0 : 1); default_dimension_numbers.add_rhs_contracting_dimensions(0); - return xla::protobuf_util::ProtobufEquals(dnums, default_dimension_numbers); + return protobuf_util::ProtobufEquals(dnums, default_dimension_numbers); } // Clean up the GetTupleElementOp, created during the flattening of @@ -295,16 +301,16 @@ static bool IsNestedTupleInData(Type type) { return false; } -static bool HasCustomLayout(const xla::Shape& shape) { +static bool HasCustomLayout(const Shape& shape) { if (shape.IsTuple()) { return llvm::any_of(shape.tuple_shapes(), HasCustomLayout); } return shape.has_layout() && !shape.layout().minor_to_major().empty() && - shape.layout() != xla::LayoutUtil::GetDefaultLayoutForShape(shape); + shape.layout() != LayoutUtil::GetDefaultLayoutForShape(shape); } static mlir::Attribute GetLayoutAttribute(mlir::Builder& b, - const xla::Shape& shape) { + const Shape& shape) { if (shape.IsTuple()) { llvm::SmallVector element_attrs; for (const auto& tuple_shape : shape.tuple_shapes()) { @@ -325,8 +331,8 @@ static mlir::Attribute GetLayoutAttribute(mlir::Builder& b, return b.getIndexTensorAttr(layout); } -mlir::Attribute GetFrontendAttributes( - mlir::Builder& b, const xla::FrontendAttributes& attributes) { +mlir::Attribute GetFrontendAttributes(mlir::Builder& b, + const FrontendAttributes& attributes) { llvm::SmallVector attrs; attrs.reserve(attributes.map_size()); for (const auto& [k, v] : attributes.map()) { @@ -393,10 +399,11 @@ StatusOr HloFunctionImporter::ImportAsFunc( return importer.ImportAsFunc(computation, is_main); } -Status HloFunctionImporter::ImportAsRegion( - const xla::HloComputation& computation, mlir::SymbolTable& symbol_table, - mlir::Region* region, mlir::Builder* builder, - bool flatten_region_arg_tuple) { +Status HloFunctionImporter::ImportAsRegion(const HloComputation& computation, + mlir::SymbolTable& symbol_table, + mlir::Region* region, + mlir::Builder* builder, + bool flatten_region_arg_tuple) { HloFunctionImporter importer(symbol_table, {}, builder); return importer.ImportAsRegion(computation, region, flatten_region_arg_tuple); } @@ -482,7 +489,7 @@ StatusOr HloFunctionImporter::ImportAsFunc( computation_layout.result_layout().shape())); } if (llvm::any_of(computation_layout.parameter_layouts(), - [](const xla::ShapeLayout& shape) { + [](const ShapeLayout& shape) { return HasCustomLayout(shape.shape()); })) { llvm::SmallVector parameter_layouts; @@ -539,7 +546,7 @@ Status HloFunctionImporter::ImportAsRegion(const HloComputation& computation, } StatusOr HloFunctionImporter::ImportInstructionsImpl( - const xla::HloComputation& computation, + const HloComputation& computation, const llvm::SmallVectorImpl& arguments, mlir::OpBuilder* builder) { // Setup the input parameters. const int num_parameters = computation.num_parameters(); @@ -639,11 +646,11 @@ Status HloFunctionImporter::ImportInstructions( CleanUpTupleOps(block, &builder); - return ::tsl::OkStatus(); + return absl::OkStatus(); } StatusOr HloFunctionImporter::ImportInstructions( - const xla::HloComputation& computation, + const HloComputation& computation, const llvm::SmallVectorImpl& arguments, mlir::SymbolTable& symbol_table, mlir::OpBuilder* builder) { mlir::Block* block = builder->getBlock(); @@ -656,7 +663,7 @@ StatusOr HloFunctionImporter::ImportInstructions( } StatusOr HloFunctionImporter::ImportInstruction( - const xla::HloInstruction* instr, + const HloInstruction* instr, const llvm::SmallVectorImpl& operands, mlir::SymbolTable& symbol_table, mlir::OpBuilder* builder, DynamicShapeHandlingMode mode) { @@ -669,63 +676,13 @@ StatusOr HloFunctionImporter::ImportInstruction( return importer.ImportInstructionWithLayout(instr, operands, builder, mode); } -StatusOr HloFunctionImporter::ImportCustomCallAsOp( - const HloInstruction* instruction, mlir::Location loc, - const Type result_type, mlir::ValueRange operands, - mlir::OpBuilder* func_builder) { - auto custom_call = Cast(instruction); - if (custom_call->custom_call_target() == "mhlo.dynamic_broadcast_in_dim") { - auto raw_backend_config = custom_call->raw_backend_config_string(); - if (raw_backend_config.empty()) { - return Internal("backend_config attribute cannot be empty."); - } - - auto attr = mlir::parseAttribute(raw_backend_config, builder_->getContext()) - .dyn_cast(); - if (!attr) { - return Internal( - "Couldn't parse backend config into a dictionary attribute"); - } - - auto broadcast_dimensions_attr = - attr.get("broadcast_dimensions").dyn_cast_or_null(); - if (!broadcast_dimensions_attr) { - return Internal("broadcast_dimensions attribute is required."); - } - - std::vector broadcast_dimensions(broadcast_dimensions_attr.size()); - for (auto [i, broadcast_dimension] : - llvm::enumerate(broadcast_dimensions_attr)) { - broadcast_dimensions[i] = - broadcast_dimension.cast().getInt(); - } - - return func_builder - ->create( - loc, result_type, operands[0], operands[1], - builder_->getI64TensorAttr(broadcast_dimensions)) - .getOperation(); - } - if (custom_call->custom_call_target() == "mhlo.dynamic_reshape") { - auto raw_backend_config = custom_call->raw_backend_config_string(); - if (!raw_backend_config.empty()) { - return Internal("backend_config attribute should be empty."); - } - return func_builder - ->create(loc, result_type, operands) - .getOperation(); - } - return InvalidArgument("Unsupported MHLO op custom_call %s", - custom_call->custom_call_target()); -} - StatusOr HloFunctionImporter::ImportInstructionImpl( const HloInstruction* instruction, const llvm::SmallVectorImpl& operands, mlir::OpBuilder* func_builder, DynamicShapeHandlingMode mode) { const Shape& instruction_shape = instruction->shape(); const Shape& shape = mode == DynamicShapeHandlingMode::kConvertToStatic - ? xla::ShapeUtil::MakeStaticShape(instruction_shape) + ? ShapeUtil::MakeStaticShape(instruction_shape) : instruction_shape; TF_ASSIGN_OR_RETURN(auto result_type, ConvertShapeToType(shape, *builder_)); @@ -911,7 +868,7 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( instruction->source_target_pairs(), builder_)); return ImportOldStyleAsyncStart( attributes, operands, loc, result_type, func_builder, - "collective_permute_", [&](auto) { return ::tsl::OkStatus(); }); + "collective_permute_", [&](auto) { return absl::OkStatus(); }); } case HloOpcode::kCollectivePermuteDone: { return ImportOldStyleAsyncDone(attributes, operands, loc, result_type, @@ -919,8 +876,8 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( } case HloOpcode::kCustomCall: { auto custom_call = Cast(instruction); - if (absl::StrContains(custom_call->custom_call_target(), "mhlo.")) { - return ImportCustomCallAsOp(instruction, loc, result_type, operands, + if (IsOpEncodedCustomCall(custom_call)) { + return ImportCustomCallAsOp(custom_call, loc, result_type, operands, func_builder); } const auto& called_computations = custom_call->called_computations(); @@ -1262,7 +1219,7 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( } return ImportOldStyleAsyncStart( attributes, operands, loc, result_type, func_builder, "copy_", - [](auto) { return ::tsl::OkStatus(); }); + [](auto) { return absl::OkStatus(); }); } case HloOpcode::kCopyDone: { return ImportOldStyleAsyncDone(attributes, operands, loc, result_type, @@ -1287,11 +1244,11 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( "is_host_transfer", builder_->getBoolAttr(send_op->is_host_transfer()))); if (send_op->channel_id().has_value()) { - xla::ChannelHandle channel_handle; + ChannelHandle channel_handle; channel_handle.set_handle(send_op->channel_id().value()); channel_handle.set_type(send_op->is_host_transfer() - ? xla::ChannelHandle::DEVICE_TO_HOST - : xla::ChannelHandle::DEVICE_TO_DEVICE); + ? ChannelHandle::DEVICE_TO_HOST + : ChannelHandle::DEVICE_TO_DEVICE); attributes.push_back(ConvertChannelHandle(channel_handle)); } return ImportOldStyleAsyncStart( @@ -1321,11 +1278,11 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( "is_host_transfer", builder_->getBoolAttr(recv_op->is_host_transfer()))); if (recv_op->channel_id().has_value()) { - xla::ChannelHandle channel_handle; + ChannelHandle channel_handle; channel_handle.set_handle(recv_op->channel_id().value()); channel_handle.set_type(recv_op->is_host_transfer() - ? xla::ChannelHandle::HOST_TO_DEVICE - : xla::ChannelHandle::DEVICE_TO_DEVICE); + ? ChannelHandle::HOST_TO_DEVICE + : ChannelHandle::DEVICE_TO_DEVICE); attributes.push_back(ConvertChannelHandle(channel_handle)); } return ImportOldStyleAsyncStart( @@ -1466,7 +1423,7 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( return ImportOldStyleAsyncStart( attributes, operands, loc, result_type, func_builder, "all_gather_", - [](auto) { return ::tsl::OkStatus(); }); + [](auto) { return absl::OkStatus(); }); } case HloOpcode::kAllGatherDone: { return ImportOldStyleAsyncDone(attributes, operands, loc, result_type, @@ -1520,7 +1477,7 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( TF_RETURN_IF_ERROR(ImportAsRegion( *instruction->to_apply(), &all_reduce_sync.getComputation(), /*flatten_region_arg_tuple=*/true)); - return ::tsl::OkStatus(); + return absl::OkStatus(); }); } case HloOpcode::kAllReduceDone: { @@ -1624,14 +1581,14 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( auto shape = func_builder->create( loc, Convert(result_type.cast().getShape())); switch (instruction->random_distribution()) { - case xla::RNG_UNIFORM: + case RNG_UNIFORM: return func_builder ->create( loc, result_type, operands[0], operands[1], shape, ::mlir::mhlo::RngDistribution::UNIFORM) .getOperation(); - case xla::RNG_NORMAL: + case RNG_NORMAL: return func_builder ->create(loc, result_type, operands[0], operands[1], shape, @@ -1912,7 +1869,7 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( } // Return type is boolean, let's use `operand != 0` instead of Convert. - xla::Shape input_shape = instruction->operand(0)->shape(); + Shape input_shape = instruction->operand(0)->shape(); TF_ASSIGN_OR_RETURN(mlir::Type type, ConvertTensorShapeToType( input_shape, *func_builder)); @@ -2012,6 +1969,7 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( NO_ATTRIBUTE_CASE(kReplicaId, ReplicaIdOp); NO_ATTRIBUTE_CASE(kStochasticConvert, StochasticConvertOp); NO_ATTRIBUTE_CASE(kLogistic, LogisticOp); + NO_ATTRIBUTE_CASE(kErf, ErfOp); // The dimensions attribute is not present on the HLO Reshape // instruction. If dimensions are non-default, the XLA builder // implements it as a separate transpose. @@ -2049,8 +2007,8 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( llvm::SmallVector flattened_ret_types; FlattenTupleType(result_type, flattened_ret_types); - auto fusion_kind = mlir::mhlo::symbolizeFusionKind( - xla::ToString(instruction->fusion_kind())); + auto fusion_kind = + mlir::mhlo::symbolizeFusionKind(ToString(instruction->fusion_kind())); attributes.push_back(builder_->getNamedAttr( "fusion_kind", mlir::mhlo::FusionKindAttr::get( func_builder->getContext(), fusion_kind.value()))); @@ -2152,7 +2110,7 @@ Status HloFunctionImporter::GetMlirTypes( instruction->shape(), *builder_)); types->push_back(ret_type); } - return ::tsl::OkStatus(); + return absl::OkStatus(); } StatusOr HloFunctionImporter::GetMlirValue( @@ -2279,13 +2237,13 @@ mlir::NamedAttribute HloFunctionImporter::ConvertReplicaGroups( mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle( std::optional channel_id) { - xla::ChannelHandle channel_handle; + ChannelHandle channel_handle; if (channel_id) channel_handle.set_handle(*channel_id); return ConvertChannelHandle(channel_handle); } mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle( - const xla::ChannelHandle& channel) { + const ChannelHandle& channel) { return builder_->getNamedAttr( "channel_handle", mlir::mhlo::ChannelHandleAttr::get( context_, channel.handle(), channel.type())); @@ -2304,10 +2262,10 @@ void HloFunctionImporter::SetLayoutForMlir(mlir::Operation* op, } Status HloFunctionImporter::ConvertShapeToMlirLayout( - const xla::Shape& shape, + const Shape& shape, llvm::SmallVectorImpl& flattened_attr) { if (shape.IsToken()) { - return ::tsl::OkStatus(); + return absl::OkStatus(); } if (shape.IsTuple()) { std::vector tuple_layouts; @@ -2315,29 +2273,29 @@ Status HloFunctionImporter::ConvertShapeToMlirLayout( TF_RETURN_IF_ERROR( ConvertShapeToMlirLayout(shape.tuple_shapes(i), flattened_attr)); } - return ::tsl::OkStatus(); + return absl::OkStatus(); } if (shape.IsArray()) { - const xla::Layout l = shape.layout(); + const Layout l = shape.layout(); std::vector minor_to_major; for (int64_t i : l.minor_to_major()) { minor_to_major.push_back(builder_->getI64IntegerAttr(i)); } llvm::ArrayRef array_ref(minor_to_major); flattened_attr.push_back(builder_->getArrayAttr(array_ref)); - return ::tsl::OkStatus(); + return absl::OkStatus(); } return Internal("Couldn't convert layout."); } -mlir::Attribute ConvertSharding(const xla::HloSharding& sharding, +mlir::Attribute ConvertSharding(const HloSharding& sharding, mlir::Builder* builder) { return builder->getStringAttr(sharding.ToString(/*include_metadata=*/true)); } -mlir::Attribute ConvertSharding(const xla::OpSharding& sharding, +mlir::Attribute ConvertSharding(const OpSharding& sharding, mlir::Builder* builder) { - auto hlo_sharding = xla::HloSharding::FromProto(sharding); + auto hlo_sharding = HloSharding::FromProto(sharding); if (!hlo_sharding.ok()) return {}; return ConvertSharding(hlo_sharding.value(), builder); } diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h index 0d8da56ad38fa3..fbf7077dab98a5 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h +++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h @@ -57,15 +57,15 @@ class HloFunctionImporter { // returns the FuncOp. This also imports any computations referred by // instructions in this computation. static StatusOr ImportAsFunc( - const xla::HloComputation& computation, mlir::SymbolTable& symbol_table, - std::unordered_map* + const HloComputation& computation, mlir::SymbolTable& symbol_table, + std::unordered_map* function_map, mlir::Builder* builder, bool is_main); // Imports the given hlo computation to the specified region. If // 'flatten_region_arg_tuple' is true, then flatten the tuple-typed region // argument(s) and return value(s). - static Status ImportAsRegion(const xla::HloComputation& computation, + static Status ImportAsRegion(const HloComputation& computation, mlir::SymbolTable& symbol_table, mlir::Region* region, mlir::Builder* builder, bool flatten_region_arg_tuple = false); @@ -73,12 +73,12 @@ class HloFunctionImporter { // Imports the given computation to the given place specified by `builder`. // `arguments` contains values for all parameters. static StatusOr ImportInstructions( - const xla::HloComputation& computation, + const HloComputation& computation, const llvm::SmallVectorImpl& arguments, mlir::SymbolTable& symbol_table, mlir::OpBuilder* builder); static StatusOr ImportInstruction( - const xla::HloInstruction* instr, + const HloInstruction* instr, const llvm::SmallVectorImpl& operands, mlir::SymbolTable& symbol_table, mlir::OpBuilder* builder, DynamicShapeHandlingMode mode = DynamicShapeHandlingMode::kDynamic); @@ -139,7 +139,7 @@ class HloFunctionImporter { private: HloFunctionImporter(mlir::SymbolTable& symbol_table, - std::unordered_map* function_map, mlir::Builder* builder) : context_(symbol_table.getOp()->getContext()), @@ -154,8 +154,8 @@ class HloFunctionImporter { // Imports the given computation as a new function, if it hasn't been already // imported. - StatusOr ImportAsFunc( - const xla::HloComputation& computation, bool is_main); + StatusOr ImportAsFunc(const HloComputation& computation, + bool is_main); // Imports the given computation in the specified region. Status ImportAsRegion(const HloComputation& computation, mlir::Region* region, @@ -166,19 +166,13 @@ class HloFunctionImporter { Status ImportInstructions(const HloComputation& computation, mlir::Block* block, bool flatten_region_arg_tuple); StatusOr ImportInstructionsImpl( - const xla::HloComputation& computation, + const HloComputation& computation, const llvm::SmallVectorImpl& arguments, mlir::OpBuilder* builder); - // Imports custom_calls prefixed with `stablehlo.`. - StatusOr ImportCustomCallAsOp( - const HloInstruction* instruction, mlir::Location loc, - mlir::Type result_type, mlir::ValueRange operands, - mlir::OpBuilder* func_builder); - // Imports an instruction. StatusOr ImportInstructionWithLayout( - const xla::HloInstruction* instruction, + const HloInstruction* instruction, const llvm::SmallVectorImpl& operands, mlir::OpBuilder* func_builder, DynamicShapeHandlingMode mode = DynamicShapeHandlingMode::kDynamic); @@ -191,19 +185,19 @@ class HloFunctionImporter { // Gets the MLIR operand values from an HLO Instruction. StatusOr> GetOperands( - const xla::HloInstruction* instruction); + const HloInstruction* instruction); // Converts xla Tensor type to the corresponding MLIR type. - StatusOr ConvertTensorType(const xla::Shape& shape); + StatusOr ConvertTensorType(const Shape& shape); // Converts an XLA shape/layout to the corresponding MLIR layout, in // flattened_attr, while flattening the tuple layout. Status ConvertShapeToMlirLayout( - const xla::Shape& shape, + const Shape& shape, llvm::SmallVectorImpl& flattened_attr); // Returns the output type of an HloInstruction. - StatusOr GetReturnType(const xla::HloInstruction* instruction); + StatusOr GetReturnType(const HloInstruction* instruction); // Takes a list of HloInstructions and generates the list of types used for // input, bypassing tuples to subsets. @@ -211,7 +205,7 @@ class HloFunctionImporter { llvm::SmallVectorImpl* types); // Returns the Mlir Value for the corresponding HloInstruction. - StatusOr GetMlirValue(const xla::HloInstruction* instruction); + StatusOr GetMlirValue(const HloInstruction* instruction); // Converts an XLA ComparisonDirection to the corresponding MLIR attribute. mlir::NamedAttribute ConvertComparisonDirection( @@ -221,8 +215,7 @@ class HloFunctionImporter { mlir::NamedAttribute ConvertComparisonType(Comparison::Type type); // Converts an XLA CustomCallSchedule to the corresponding MLIR attribute. - mlir::NamedAttribute ConvertCustomCallSchedule( - xla::CustomCallSchedule schedule); + mlir::NamedAttribute ConvertCustomCallSchedule(CustomCallSchedule schedule); // Converts the dimensions of an HLO instruction into an MLIR attribute. mlir::DenseIntElementsAttr ConvertDimensions( @@ -245,7 +238,7 @@ class HloFunctionImporter { mlir::NamedAttribute ConvertUseGlobalDeviceIds(); // Converts channel handle to attribute - mlir::NamedAttribute ConvertChannelHandle(const xla::ChannelHandle& channel); + mlir::NamedAttribute ConvertChannelHandle(const ChannelHandle& channel); // ============ // Imports an old-style async start op. E.g. an HLO all-gather-start @@ -283,25 +276,23 @@ class HloFunctionImporter { mlir::Builder* builder_; // Mapping from HloComputation to the created MLIR function. - std::unordered_map* - function_map_; + std::unordered_map* function_map_; // Mapping from HloInstructions to the associative MLIR values. - std::unordered_map - instruction_value_map_; + std::unordered_map instruction_value_map_; }; // Returns a StringAttr that carries a prettyprinted representation of the // given HLO C++ sharding. // Always succeeds and returns a non-empty attribute. -mlir::Attribute ConvertSharding(const xla::HloSharding& sharding, +mlir::Attribute ConvertSharding(const HloSharding& sharding, mlir::Builder* builder); // Returns a StringAttr that carries a prettyprinted representation of the // given HLO proto sharding. // Will fail and return an empty attribute if the proto sharding cannot be // converted to the C++ sharding. -mlir::Attribute ConvertSharding(const xla::OpSharding& sharding, +mlir::Attribute ConvertSharding(const OpSharding& sharding, mlir::Builder* builder); } // namespace xla diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_module_importer.cc b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_module_importer.cc index f9edd12727b7e2..75bd2745d06d76 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_module_importer.cc +++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_module_importer.cc @@ -15,20 +15,25 @@ limitations under the License. #include "xla/translate/hlo_to_mhlo/hlo_module_importer.h" -#include +#include #include -#include - +#include + +#include "absl/types/span.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/layout_util.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/permutation_util.h" -#include "xla/translate/hlo_to_mhlo/attribute_importer.h" +#include "xla/status.h" #include "xla/translate/hlo_to_mhlo/hlo_function_importer.h" -#include "xla/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/xla.pb.h" namespace xla { @@ -41,6 +46,7 @@ HloModuleImporter::HloModuleImporter(mlir::ModuleOp module, module.getContext()->loadDialect(); module.getContext()->loadDialect(); module.getContext()->loadDialect(); + module.getContext()->loadDialect(); } namespace { @@ -48,7 +54,7 @@ namespace { constexpr char kFrontendAttributesAttr[] = "mhlo.frontend_attributes"; mlir::ArrayAttr ConvertCrossProgramPrefetches( - const absl::Span prefetches, + const absl::Span prefetches, mlir::Builder* builder) { llvm::SmallVector shapes; for (auto [parameter, index, alt_memory_offset] : prefetches) { @@ -65,7 +71,7 @@ mlir::ArrayAttr ConvertCrossProgramPrefetches( } } // namespace -Status HloModuleImporter::Import(const xla::HloModule& hlo_module) { +Status HloModuleImporter::Import(const HloModule& hlo_module) { auto module = llvm::cast(symbol_table_.getOp()); module.setName(hlo_module.name()); module->setAttr("mhlo.cross_program_prefetches", @@ -113,19 +119,19 @@ Status HloModuleImporter::Import(const xla::HloModule& hlo_module) { return OkStatus(); } -Status HloModuleImporter::Import(const xla::HloModuleProto& module_proto) { - xla::DebugOptions debug_options; +Status HloModuleImporter::Import(const HloModuleProto& module_proto) { + DebugOptions debug_options; TF_ASSIGN_OR_RETURN( auto module_config, - xla::HloModule::CreateModuleConfigFromProto(module_proto, debug_options)); - TF_ASSIGN_OR_RETURN(auto module, xla::HloModule::CreateFromProto( - module_proto, module_config)); + HloModule::CreateModuleConfigFromProto(module_proto, debug_options)); + TF_ASSIGN_OR_RETURN(auto module, + HloModule::CreateFromProto(module_proto, module_config)); return Import(*module); } -void HloModuleImporter::ImportFrontendAttributes( - const xla::HloModule& hlo_module, mlir::ModuleOp module) { +void HloModuleImporter::ImportFrontendAttributes(const HloModule& hlo_module, + mlir::ModuleOp module) { if (!hlo_module.frontend_attributes().map().empty()) { llvm::SmallVector frontend_attributes; for (const auto& [k, v] : hlo_module.frontend_attributes().map()) { diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc index 548011353daf4f..d60dee5dabda11 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc +++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc @@ -309,6 +309,8 @@ StatusOr<::xla::HloOpcode> MhloToHloOpcode(mlir::Operation* op) { return xla::HloOpcode::kClz; } else if (isa(op)) { return xla::HloOpcode::kCos; + } else if (isa(op)) { + return xla::HloOpcode::kErf; } else if (isa(op)) { return xla::HloOpcode::kExp; } else if (isa(op)) { diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/BUILD b/third_party/xla/xla/translate/hlo_to_mhlo/tests/BUILD index 1e348e25ed8acc..e2961242bb96b8 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/tests/BUILD +++ b/third_party/xla/xla/translate/hlo_to_mhlo/tests/BUILD @@ -1,7 +1,6 @@ load("//xla:lit.bzl", "enforce_glob", "lit_test_suite") package( - default_visibility = ["//visibility:public"], # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) @@ -12,6 +11,7 @@ lit_test_suite( [ "bool_compare.hlotxt", "case_conditional.hlotxt", + "custom_call.hlotxt", "dynamic_param.hlo", "entry_computation_layout.hlotxt", "frontend_attributes.hlotxt", diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/custom_call.hlotxt b/third_party/xla/xla/translate/hlo_to_mhlo/tests/custom_call.hlotxt new file mode 100644 index 00000000000000..f91b2f98028bce --- /dev/null +++ b/third_party/xla/xla/translate/hlo_to_mhlo/tests/custom_call.hlotxt @@ -0,0 +1,41 @@ +// RUN: xla-translate --print-sugar=false -hlo-text-to-mlir-hlo -hlo-import-all-computations %s -o - | FileCheck %s + +// CHECK: module @foobar +HloModule foobar + +// CHECK-LABEL: func @main(%arg0: tensor) -> tensor { +ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { + ROOT %Arg_0.1 = f32[] parameter(0) +} + +// CHECK-LABEL: func private @test_custom_call_dynamic_broadcast_in_dim +// CHECK-SAME: [[ARG_0:%.*]]: tensor<1x?xf32>, [[ARG_1:%.*]]: tensor<3xi64>) -> tensor<2x?x2xf32> +%test_custom_call_dynamic_broadcast_in_dim (arg1: f32[1,?], arg2: s64[3]) -> f32[2,?,2] { + %arg1 = f32[1,?] parameter(0) + %arg2 = s64[3] parameter(1) + // CHECK: "mhlo.dynamic_broadcast_in_dim"([[ARG_0]], [[ARG_1]]) { + // CHECK-SAME: broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK-SAME: (tensor<1x?xf32>, tensor<3xi64>) -> tensor<2x?x2xf32> + ROOT %custom-call = f32[2,?,2] custom-call(f32[1,?] %arg1, s64[3] %arg2), custom_call_target="mhlo.dynamic_broadcast_in_dim", backend_config={broadcast_dimensions=[0,1]} +} + +// CHECK-LABEL: func private @test_custom_call_dynamic_reshape +// CHECK-SAME: [[ARG_0:%.*]]: tensor, [[ARG_1:%.*]]: tensor<2xi64>) -> tensor +%test_custom_call_dynamic_reshape (arg1: f32[?], arg2: s64[2]) -> f32[?,?] { + %arg1 = f32[?] parameter(0) + %arg2 = s64[2] parameter(1) + // CHECK: mhlo.dynamic_reshape [[ARG_0]], [[ARG_1]] : (tensor, tensor<2xi64>) -> tensor + ROOT %custom-call = f32[?,?] custom-call(f32[?] %arg1, s64[2] %arg2), custom_call_target="mhlo.dynamic_reshape" +} + +// CHECK-LABEL: func private @test_custom_call_real_dynamic_slice +// CHECK-SAME: ([[ARG_0:%.*]]: tensor, [[ARG_1:%.*]]: tensor<4xi32>, [[ARG_2:%.*]]: tensor<4xi32>, [[ARG_3:%.*]]: tensor<4xi32>) -> tensor +%test_custom_call_real_dynamic_slice(arg1: f32[?,3,224,224], arg2: s32[4], arg3: s32[4], arg4: s32[4]) -> f32[?,3,224,224] { + %Arg_0.1 = f32[?,3,224,224] parameter(0) + %Arg_1.2 = s32[4] parameter(1) + %Arg_2.3 = s32[4] parameter(2) + %Arg_3.4 = s32[4] parameter(3) + + // CHECK: mhlo.real_dynamic_slice [[ARG_0]], [[ARG_1]], [[ARG_2]], [[ARG_3]] : (tensor, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor + ROOT %custom-call.12 = f32[?,3,224,224] custom-call(f32[?,3,224,224] %Arg_0.1, s32[4] %Arg_1.2, s32[4] %Arg_2.3, s32[4] %Arg_3.4), custom_call_target="mhlo.real_dynamic_slice" +} diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/import.hlotxt b/third_party/xla/xla/translate/hlo_to_mhlo/tests/import.hlotxt index 5df4e887cd26cd..19783cf9e08974 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/tests/import.hlotxt +++ b/third_party/xla/xla/translate/hlo_to_mhlo/tests/import.hlotxt @@ -608,26 +608,6 @@ add { ROOT %custom-call = (f32[2,3]) custom-call((f32[1,1], f32[2,3]) %arg1, f32[5,5] %arg2), custom_call_target="foo", output_to_operand_aliasing={{0}: (0, {1})} } -// CHECK-LABEL: func private @test_custom_call_dynamic_broadcast_in_dim -// CHECK-SAME: [[ARG_0:%.*]]: tensor<1x?xf32>, [[ARG_1:%.*]]: tensor<3xi64>) -> tensor<2x?x2xf32> -%test_custom_call_dynamic_broadcast_in_dim (arg1: f32[1,?], arg2: s64[3]) -> f32[2,?,2] { - %arg1 = f32[1,?] parameter(0) - %arg2 = s64[3] parameter(1) - // CHECK: "mhlo.dynamic_broadcast_in_dim"([[ARG_0]], [[ARG_1]]) { - // CHECK-SAME: broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} - // CHECK-SAME: (tensor<1x?xf32>, tensor<3xi64>) -> tensor<2x?x2xf32> - ROOT %custom-call = f32[2,?,2] custom-call(f32[1,?] %arg1, s64[3] %arg2), custom_call_target="mhlo.dynamic_broadcast_in_dim", backend_config={broadcast_dimensions=[0,1]} -} - -// CHECK-LABEL: func private @test_custom_call_dynamic_reshape -// CHECK-SAME: [[ARG_0:%.*]]: tensor, [[ARG_1:%.*]]: tensor<2xi64>) -> tensor -%test_custom_call_dynamic_reshape (arg1: f32[?], arg2: s64[2]) -> f32[?,?] { - %arg1 = f32[?] parameter(0) - %arg2 = s64[2] parameter(1) - // CHECK: mhlo.dynamic_reshape [[ARG_0]], [[ARG_1]] : (tensor, tensor<2xi64>) -> tensor - ROOT %custom-call = f32[?,?] custom-call(f32[?] %arg1, s64[2] %arg2), custom_call_target="mhlo.dynamic_reshape" -} - // CHECK-LABEL: func private @test_div(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> %test_div (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] { %Arg_0.1 = f32[4] parameter(0) @@ -1162,16 +1142,12 @@ add { // CHECK: mhlo.tuple %0#0, %0#1 {xla_shape = {{.*}}} : tuple, tensor> %reduce.1 = (f32[], f32[]) reduce(%Arg_0.1, %Arg_0.1, %Arg_2.3, %Arg_2.3), dimensions={0, 1}, to_apply=%reduce_helper.1 - // CHECK: [[VAL2:%.*]] = mhlo.reduce([[ARG0]] init: [[ARG2]]) - // CHECK: mhlo.add{{.*}} : tensor + // CHECK: [[VAL2:%.*]] = mhlo.reduce([[ARG0]] init: [[ARG2]]) applies mhlo.add across dimensions = [0, 1] : (tensor<4x4xf32>, tensor) -> tensor %reduce.3 = f32[] reduce(%Arg_0.1, %Arg_2.3), dimensions={0, 1}, to_apply=%reduce_helper.3 // CHECK: [[VAL3:%.*]] = mhlo.reduce([[ARG0]] init: [[ARG1]]) - // CHECK- // CHECK: mhlo.add{{.*}} : tensor<4xf32> %reduce.2 = f32[4] reduce(%Arg_0.1, %Arg_1.2), dimensions={0}, to_apply=%reduce_helper.2 - // CHECK: [[VAL4:%.*]] = mhlo.reduce([[VAL3]] init: [[ARG2]]) - // CHECK-SAME: dimensions = [0] - // CHECK: mhlo.add{{.*}} : tensor + // CHECK: [[VAL4:%.*]] = mhlo.reduce([[VAL3]] init: [[ARG2]]) applies mhlo.add across dimensions = [0] : (tensor<4xf32>, tensor) -> tensor %reduce.4 = f32[] reduce(%reduce.2, %Arg_2.3), dimensions={0}, to_apply=%reduce_helper.3 // CHECK: %5 = mhlo.subtract [[VAL2]], [[VAL4]] : tensor diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/BUILD b/third_party/xla/xla/translate/mhlo_to_hlo/BUILD index e6a9493e9f592f..eeaf6fbc3adb6a 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/BUILD +++ b/third_party/xla/xla/translate/mhlo_to_hlo/BUILD @@ -1,11 +1,16 @@ load("@bazel_skylib//rules:build_test.bzl", "build_test") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") load("//xla:xla.bzl", "xla_cc_test") +load("@local_tsl//tsl:tsl.bzl", "internal_visibility") load("@local_tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_binary", "cc_library") package( - default_visibility = ["//visibility:public"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//learning/brain/mlir:tensorflow_friends", + "//learning/brain/mlir:xla_friends", + ]), licenses = ["notice"], ) @@ -13,7 +18,6 @@ cc_library( name = "attribute_exporter", srcs = ["attribute_exporter.cc"], hdrs = ["attribute_exporter.h"], - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla:statusor", @@ -33,7 +37,6 @@ cc_library( name = "layout_util", srcs = ["layout_util.cc"], hdrs = ["layout_util.h"], - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla:statusor", @@ -47,7 +50,6 @@ cc_library( name = "location_exporter", srcs = ["location_exporter.cc"], hdrs = ["location_exporter.h"], - visibility = ["//visibility:public"], deps = [ ":stack_frame_index_builder", "//xla:xla_data_proto_cc", @@ -60,7 +62,6 @@ cc_library( name = "stack_frame_index_builder", srcs = ["stack_frame_index_builder.cc"], hdrs = ["stack_frame_index_builder.h"], - visibility = ["//visibility:public"], deps = [ "//xla/service:hlo_proto_cc", "@llvm-project//mlir:IR", @@ -74,7 +75,6 @@ cc_library( "operator_writers.inc", ], hdrs = ["mlir_hlo_to_hlo.h"], - visibility = ["//visibility:public"], deps = [ ":attribute_exporter", ":layout_util", @@ -156,7 +156,6 @@ cc_library( name = "translate", srcs = ["translate.cc"], hdrs = ["translate.h"], - visibility = ["//visibility:public"], deps = [ ":mlir_hlo_to_hlo", ":type_to_shape", @@ -171,7 +170,6 @@ cc_library( name = "translate_registration", testonly = True, srcs = ["translate_registration.cc"], - visibility = ["//visibility:public"], deps = [ ":translate", "//xla/mlir_hlo:hlo_dialect_registration", @@ -188,7 +186,6 @@ cc_library( name = "type_to_shape", srcs = ["type_to_shape.cc"], hdrs = ["type_to_shape.h"], - visibility = ["//visibility:public"], deps = [ "//xla:shape_util", "//xla:statusor", diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index 68ece3eb7dab81..f6f13091dadf57 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ b/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -3526,7 +3526,7 @@ xla::Status PrepareForExport(mlir::ModuleOp module) { } if (failed(pm.run(module))) return tsl::errors::Internal("Unable to prepare for XLA export"); - return ::tsl::OkStatus(); + return absl::OkStatus(); } } // namespace @@ -3603,7 +3603,7 @@ xla::Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto, converter.BuildStackFramesIndexProto(); hlo_module.mutable_stack_frame_index()->Swap(&stack_frame_index); hlo_proto->mutable_hlo_module()->Swap(&hlo_module); - return ::tsl::OkStatus(); + return absl::OkStatus(); } xla::Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder, @@ -3648,7 +3648,7 @@ xla::Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder, } } - return ::tsl::OkStatus(); + return absl::OkStatus(); } } // namespace mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/BUILD b/third_party/xla/xla/translate/mhlo_to_hlo/tests/BUILD index ad944460e69158..46548c0f647d91 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/tests/BUILD +++ b/third_party/xla/xla/translate/mhlo_to_hlo/tests/BUILD @@ -2,7 +2,6 @@ load("//xla:lit.bzl", "enforce_glob", "lit_test_suite") load("@local_tsl//tsl:tsl.default.bzl", "filegroup") package( - default_visibility = ["//visibility:public"], # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) @@ -54,5 +53,4 @@ lit_test_suite( filegroup( name = "test_utilities", testonly = True, - visibility = ["//visibility:public"], ) diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export.mlir b/third_party/xla/xla/translate/mhlo_to_hlo/tests/export.mlir index 63144f253e5933..9f6925c2b8066f 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export.mlir +++ b/third_party/xla/xla/translate/mhlo_to_hlo/tests/export.mlir @@ -3108,3 +3108,166 @@ func.func @main(%operand: tensor) -> tensor { // CHECK-NEXT: [[ARG0]] = f32[?,784] parameter(0) // CHECK-NEXT: ROOT {{.*}} = f32[?,784] abs(f32[?,784] %Arg_0.1), {{.*}} // CHECK-NEXT: } + +// ----- + +// reduce multiple implicit captures test +// CHECK: HloModule +// CHECK: [[REG0:%region.*]] ({{.*}} { +// CHECK: f32[] constant(0) +// CHECK: f32[] constant(1) +// CHECK: ROOT +// CHECK: ENTRY +// CHECK: {{.*}} reduce{{.*}} to_apply=[[REG0]] +// CEHCK: ROOT +func.func @main(%arg0: tensor<2x2xf32>) -> tuple> { + %0 = mhlo.constant dense<1.000000e+00> : tensor + %1 = mhlo.constant dense<0.000000e+00> : tensor + %2 = mhlo.reduce(%arg0 init: %1) across dimensions = [0, 1] : (tensor<2x2xf32>, tensor) -> tensor + reducer(%arg1: tensor, %arg2: tensor) { + %5 = mhlo.compare NE, %arg1, %1 : (tensor, tensor) -> tensor + %6 = mhlo.compare NE, %arg2, %1 : (tensor, tensor) -> tensor + %7 = mhlo.or %5, %6 : tensor + %8 = mhlo.select %7, %0, %1 : tensor, tensor + mhlo.return %8 : tensor + } + %3 = mhlo.compare NE, %2, %1 : (tensor, tensor) -> tensor + %4 = mhlo.tuple %3 {xla_shape = "(pred[])"} : tuple> + return %4 : tuple> +} + +// ----- + +// all_reduce implicit capture test +// CHECK: HloModule +// CHECK: [[REG0:%region.*]] ({{.*}} { +// CHECK: f32[] constant(0) +// CHECK: ROOT +// CHECK: ENTRY +// CHECK: ROOT {{.*}} all-reduce{{.*}} to_apply=[[REG0]] +func.func @main(%arg0: tensor) -> tensor { + %c = mhlo.constant dense<0.0> : tensor + %0 = "mhlo.all_reduce"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = mhlo.add %arg1, %c : tensor + mhlo.return %1 : tensor + }) {replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>} : (tensor) -> tensor + return %0 : tensor +} + +// ----- + +// reduce_scatter implicit capture test +// CHECK: HloModule +// CHECK: [[REG0:%region.*]] ({{.*}} { +// CHECK: f32[] constant(0) +// CHECK: ROOT +// CHECK: ENTRY +// CHECK: ROOT {{.*}} reduce-scatter{{.*}} to_apply=[[REG0]] +func.func @main(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { + %c = mhlo.constant dense<0.0> : tensor + %0 = "mhlo.reduce_scatter"(%data) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = mhlo.add %arg2, %c : tensor + "mhlo.return"(%1) : (tensor) -> () + }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64, + channel_handle = #mhlo.channel_handle, + use_global_device_ids} : (tensor<4x16xf32>) -> tensor<4x4xf32> + func.return %0 : tensor<4x4xf32> +} + +// ----- + +// reduce_window implicit capture test +// CHECK: HloModule +// CHECK: [[REG0:%region.*]] ({{.*}} { +// CHECK: f32[] constant(0) +// CHECK: ROOT +// CHECK: ENTRY +// DCHECK: ROOT {{.*}} reduce-window{{.*}} to_apply=[[REG0]] +func.func @main(%arg0: tensor<2x17x31x7xf32>, %arg1: tensor) -> tensor<2x16x30x7xf32> { + %c = mhlo.constant dense<0.0> : tensor + %0 = "mhlo.reduce_window"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = mhlo.maximum %arg2, %c : tensor + mhlo.return %1 : tensor + }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<2x17x31x7xf32>, tensor) -> tensor<2x16x30x7xf32> + return %0 : tensor<2x16x30x7xf32> + } + +// ----- + +// Scatter implicit capture test +// CHECK: HloModule +// CHECK: [[REG0:%region.*]] ({{.*}} { +// CHECK: s32[] constant(0) +// CHECK: ROOT +// CHECK: ENTRY +// CHECK: ROOT {{.*}} scatter{{.*}} to_apply=[[REG0]] +func.func @main(%arg0: tensor<3xi32>, %arg1: tensor<1x1xi32>, + %arg2: tensor<1xi32>) -> tensor<3xi32> { + %c = mhlo.constant dense<0> : tensor + %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %x = mhlo.add %arg4, %c : tensor + "mhlo.return"(%x) : (tensor) -> () + }) { + indices_are_sorted = false, + scatter_dimension_numbers = #mhlo.scatter< + update_window_dims = [], + inserted_window_dims = [0], + scatter_dims_to_operand_dims = [0], + index_vector_dim = 1, + >, + unique_indices = false + } : (tensor<3xi32>, tensor<1x1xi32>, tensor<1xi32>) -> tensor<3xi32> + func.return %0 : tensor<3xi32> +} + +// ----- + +// select_and_scatter implicit capture test +// CHECK: HloModule +// CHECK: [[SEL_REG:%region.*]] ({{.*}} { +// CHECK: f32[] constant(0) +// CHECK: ROOT +// CHECK: [[SCAT_REG:%region.*]] ({{.*}} { +// CHECK: f32[] constant(0) +// CHECK: ROOT +// CHECK: ENTRY +// CHECK: ROOT {{.*}} select-and-scatter{{.*}} select=[[SEL_REG]], scatter=[[SCAT_REG]] +func.func @main(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x23x23x64xf32>, %arg2: tensor) -> tensor<10x24x24x64xf32> { + %c1 = mhlo.constant dense<0.0> : tensor + %c2 = mhlo.constant dense<0.0> : tensor + %0 = "mhlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = mhlo.compare GE, %arg3, %c1, TOTALORDER : (tensor, tensor) -> tensor + mhlo.return %1 : tensor + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = mhlo.add %arg4, %c2 : tensor + mhlo.return %1 : tensor + }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<10x24x24x64xf32>, tensor<10x23x23x64xf32>, tensor) -> tensor<10x24x24x64xf32> + return %0 : tensor<10x24x24x64xf32> + } + +// ----- + +// sort implicit capture test +// CHECK: HloModule +// CHECK: [[REG0:%region.*]] ({{.*}} { +// CHECK: f32[] constant(0) +// CHECK: ROOT +// CHECK: ENTRY +// CHECK: {{.*}} sort{{.*}} to_apply=[[REG0]] +// CHECK: ROOT +func.func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { + %c = mhlo.constant dense<0.0> : tensor + %0:2 = "mhlo.sort"(%input0, %input1) ({ + ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): + %7 = "mhlo.compare"(%arg0, %c) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + "mhlo.return"(%7) : (tensor) -> () + }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) + func.return +} diff --git a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/BUILD b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/BUILD deleted file mode 100644 index 0d7bd289090705..00000000000000 --- a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/BUILD +++ /dev/null @@ -1,130 +0,0 @@ -load("@bazel_skylib//rules:build_test.bzl", "build_test") -load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") -load("//xla:xla.bzl", "xla_cc_binary") -load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") -load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], -) - -cc_library( - name = "mhlo_to_lhlo_with_xla", - srcs = ["mhlo_to_lhlo_with_xla.cc"], - hdrs = ["mhlo_to_lhlo_with_xla.h"], - visibility = ["//visibility:public"], - deps = [ - "//xla:debug_options_flags", - "//xla:shape_util", - "//xla:statusor", - "//xla:util", - "//xla:window_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/mlir/utils:error_util", - "//xla/mlir_hlo:lhlo", - "//xla/mlir_hlo:lhlo_gpu", - "//xla/service:backend", - "//xla/service:buffer_assignment", - "//xla/service:collective_ops_utils", - "//xla/service:hlo_parser", - "//xla/service/gpu:backend_configs_cc", - "//xla/service/gpu:cublas_cudnn", - "//xla/service/gpu:ir_emission_utils", - "//xla/service/gpu:matmul_utils", - "//xla/service/llvm_ir:buffer_assignment_util", - "//xla/service/llvm_ir:llvm_util", - "//xla/translate/hlo_to_mhlo:attribute_importer", - "//xla/translate/hlo_to_mhlo:hlo_module_importer", - "//xla/translate/hlo_to_mhlo:hlo_utils", - "//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo", - "//xla/translate/mhlo_to_hlo:type_to_shape", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/cleanup", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:AsmParser", - "@llvm-project//mlir:BufferizationDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TranslateLib", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "translate_registration", - testonly = True, - srcs = ["translate_registration.cc"], - visibility = ["//visibility:public"], - deps = [ - ":mhlo_to_lhlo_with_xla", - "@llvm-project//mlir:TranslateLib", - ], - alwayslink = 1, -) - -build_test( - name = "xla-translate-opt_build_test", - targets = [ - ":xla-translate-opt", - ], -) - -xla_cc_binary( - name = "xla-translate-opt", - testonly = True, - srcs = ["xla_translate_opt_main.cc"], - deps = [ - ":mhlo_to_lhlo_with_xla", # buildcleaner: keep - "//xla/mlir/framework/ir:xla_framework", - "//xla/mlir/framework/transforms:passes", - "//xla/mlir_hlo:hlo_dialect_registration", - "//xla/service:cpu_plugin", - "//xla/service/cpu:hlo_xla_runtime_pipeline", # buildcleaner: keep - "@llvm-project//llvm:Support", - "@llvm-project//mlir:AllPassesAndDialects", - "@llvm-project//mlir:MlirOptLib", - "@local_tsl//tsl/platform:platform_port", - "@stablehlo//:register", - ], -) - -build_test( - name = "xla-translate-gpu-opt_build_test", - targets = [ - ":xla-translate-gpu-opt", - ], -) - -xla_cc_binary( - name = "xla-translate-gpu-opt", - testonly = True, - srcs = ["xla_translate_opt_main.cc"], - deps = [ - ":mhlo_to_lhlo_with_xla", # buildcleaner: keep - "//xla/mlir/framework/ir:xla_framework", - "//xla/mlir/framework/transforms:passes", - "//xla/mlir_hlo:all_passes", - "//xla/mlir_hlo:hlo_dialect_registration", - "//xla/service:gpu_plugin", - "//xla/stream_executor", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:AllPassesAndDialects", - "@llvm-project//mlir:MlirOptLib", - "@local_tsl//tsl/platform:platform_port", - "@stablehlo//:register", - ] + if_cuda(["//xla/stream_executor/cuda:cublas_plugin"]) + if_rocm([ - "//xla/stream_executor/rocm:rocblas_plugin", - ]), -) diff --git a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc deleted file mode 100644 index 0e3adb005afd7c..00000000000000 --- a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc +++ /dev/null @@ -1,2595 +0,0 @@ -/* Copyright 2020 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/cleanup/cleanup.h" -#include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" -#include "absl/strings/match.h" -#include "absl/types/optional.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/AsmParser/AsmParser.h" // from @llvm-project -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/OpDefinition.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/IR/ValueRange.h" // from @llvm-project -#include "mlir/IR/Verifier.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project -#include "xla/debug_options_flags.h" -#include "xla/hlo/ir/hlo_casting_utils.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/mlir/utils/error_util.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" -#include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" -#include "xla/service/backend.h" -#include "xla/service/buffer_assignment.h" -#include "xla/service/collective_ops_utils.h" -#include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/cublas_cudnn.h" -#include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/matmul_utils.h" -#include "xla/service/hlo_parser.h" -#include "xla/service/llvm_ir/buffer_assignment_util.h" -#include "xla/service/llvm_ir/llvm_util.h" -#include "xla/shape_util.h" -#include "xla/statusor.h" -#include "xla/translate/hlo_to_mhlo/attribute_importer.h" -#include "xla/translate/hlo_to_mhlo/hlo_function_importer.h" -#include "xla/translate/hlo_to_mhlo/hlo_utils.h" -#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" -#include "xla/util.h" -#include "xla/window_util.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" - -using xla::BufferAllocation; -using xla::BufferAssignment; -using xla::HloComputation; -using xla::HloCustomCallInstruction; -using xla::HloInfeedInstruction; -using xla::HloInstruction; -using xla::HloModule; -using xla::HloModuleProto; -using xla::HloOutfeedInstruction; -using xla::HloProto; -using xla::Shape; - -namespace mlir { -namespace { - -absl::string_view StringRefToView(llvm::StringRef ref) { - return {ref.data(), ref.size()}; -} - -tsl::StatusOr> HloModuleFromProto( - const HloProto& hlo_proto) { - const HloModuleProto& module_proto = hlo_proto.hlo_module(); - TF_ASSIGN_OR_RETURN(const xla::HloModuleConfig module_config, - HloModule::CreateModuleConfigFromProto( - module_proto, xla::GetDebugOptionsFromFlags())); - return HloModule::CreateFromProto(module_proto, module_config); -} - -bool NoParallelCustomCallCollective(const HloInstruction* instr) { - auto backend_config = instr->backend_config() - .value() - .collective_backend_config(); - return backend_config.no_parallel_custom_call(); -} - -// Convert the MLIR `module` from HLO dialect to LHLO dialect using XLA for the -// given platform. -tsl::Status ConvertHloToLmhlo(std::unique_ptr hlo_module, - ModuleOp module, StringRef platform_name) { - auto platform = xla::se::MultiPlatformManager::PlatformWithName( - StringRefToView(platform_name)); - if (!platform.ok()) { - std::string error_msg; - llvm::raw_string_ostream os(error_msg); - os << "failed to get platform: " << platform.status().ToString() - << " (available Platform: "; - std::vector available_platforms; - (void)xla::se::MultiPlatformManager::PlatformsWithFilter( - [&](const stream_executor::Platform* p) { - available_platforms.push_back(p->Name()); - return false; - }); - llvm::interleaveComma(available_platforms, os); - os << ")"; - return tsl::errors::InvalidArgument("%s", os.str().c_str()); - } - - xla::BackendOptions backend_options; - backend_options.set_platform(platform.value()); - auto backend_or_err = xla::Backend::CreateBackend(backend_options); - TF_RETURN_WITH_CONTEXT_IF_ERROR(backend_or_err.status(), - "failed to create XLA Backend "); - auto backend = std::move(backend_or_err.value()); - - tsl::StatusOr> assignment = - backend->compiler()->AssignBuffers(hlo_module.get(), - backend->default_stream_executor()); - TF_RETURN_WITH_CONTEXT_IF_ERROR(assignment.status(), - "running XLA buffer assigment"); - - // Clear the module before populating it back with the result of the - // conversion. - module.getBody()->clear(); - OpBuilder builder(module); - - std::vector ordered_allocations; - TF_RETURN_WITH_CONTEXT_IF_ERROR( - HloToLhloModule(**assignment, *hlo_module, module, &ordered_allocations), - "converting HLO to LHLO"); - - return ::tsl::OkStatus(); -} - -} // namespace - -// Creates MLIR operands corresponding to operands and results of the XLA HLO -// instruction. If `num_operands` is valid, then only the first `num_operands` -// operands of the HLO instruction will be considered. -tsl::Status LhloDialectEmitter::CreateOperands( - const HloInstruction* instr, std::optional num_operands, - TokenLoweringMode token_mode, llvm::SmallVectorImpl& operands, - size_t& num_arguments, size_t& num_results) { - if (num_operands.value_or(0) > instr->operand_count()) - return tsl::errors::InvalidArgument( - "num_operands must be <= operand count"); - for (int64_t i = 0; i < num_operands.value_or(instr->operand_count()); ++i) { - TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(i), &operands, - /*result_subset=*/{}, token_mode)); - } - num_arguments = operands.size(); - TF_RETURN_IF_ERROR( - GetOrCreateView(instr, &operands, /*result_subset=*/{}, token_mode)); - num_results = operands.size() - num_arguments; - return ::tsl::OkStatus(); -} - -template -OpType LhloDialectEmitter::CreateOpWithoutAttrs(const HloInstruction* instr, - ValueRange operands) { - Location loc = getLocation(instr); - return builder_.create(loc, std::nullopt, operands, - llvm::ArrayRef{}); -} - -template -tsl::StatusOr LhloDialectEmitter::CreateOpWithoutAttrs( - const HloInstruction* instr, size_t& num_arguments, size_t& num_results, - std::optional num_operands) { - llvm::SmallVector operands; - TF_RETURN_IF_ERROR(CreateOperands(instr, num_operands, - TokenLoweringMode::kFailToLower, operands, - num_arguments, num_results)); - return CreateOpWithoutAttrs(instr, operands); -} - -tsl::StatusOr LhloDialectEmitter::EmitOp( - const HloInstruction* instr) { - using xla::HloOpcode; - switch (instr->opcode()) { - case HloOpcode::kAddDependency: - return nullptr; - case HloOpcode::kAfterAll: - // LMHLO is already ordered. This assumption may be broken after - // introducing async regions and partial orders. - return nullptr; - case HloOpcode::kAllGatherStart: - return EmitAllGatherStartOp(instr); - case HloOpcode::kAllGatherDone: - return EmitAllGatherDoneOp(instr); - case HloOpcode::kAllReduceStart: - return EmitAllReduceStartOp(instr); - case HloOpcode::kAllReduceDone: - return EmitAllReduceDoneOp(instr); - case HloOpcode::kAsyncStart: - return EmitAsyncStartOp(instr); - case HloOpcode::kAsyncDone: - return EmitAsyncDoneOp(instr); - case HloOpcode::kBitcast: - return EmitBitcast(instr); - case HloOpcode::kCollectivePermuteStart: - return EmitCollectivePermuteStartOp(instr); - case HloOpcode::kCollectivePermuteDone: - return EmitCollectivePermuteDoneOp(instr); - case HloOpcode::kConditional: - return EmitCaseOp(instr); - case HloOpcode::kFft: - return EmitFftOp(instr); - case HloOpcode::kGetTupleElement: - return nullptr; - case HloOpcode::kInfeed: - return EmitInfeedOp(instr); - case HloOpcode::kOutfeed: - return EmitOutfeedOp(instr); - case HloOpcode::kPartitionId: - return CreateOpWithoutAttrs(instr); - case HloOpcode::kReplicaId: - return CreateOpWithoutAttrs(instr); - case HloOpcode::kTriangularSolve: - return EmitTriangularSolveOp(instr); - case HloOpcode::kTuple: - return nullptr; - case HloOpcode::kSort: - return EmitSortOp(instr); - case HloOpcode::kFusion: - return EmitFusionOp(instr); - case HloOpcode::kScatter: - return EmitScatterOp(instr); - case HloOpcode::kSelectAndScatter: - return EmitSelectAndScatterOp(instr); - case HloOpcode::kCustomCall: - return EmitCustomCallOp(instr); - case HloOpcode::kConstant: - return EmitConstant(instr); - case HloOpcode::kRngGetAndUpdateState: - return EmitRngGetAndUpdateStateOp(instr); - case HloOpcode::kWhile: - return EmitWhileOp(instr); - case HloOpcode::kSend: - return EmitSendOp(instr); - case HloOpcode::kSendDone: - return EmitSendDoneOp(instr); - case HloOpcode::kRecv: - return EmitRecvOp(instr); - case HloOpcode::kRecvDone: - return EmitRecvDoneOp(instr); - // TODO(b/302038092): Currently the command buffer call is represented by - // a kCall. We need to be able to differentiate it from a regular kCall. - case HloOpcode::kCall: - return EmitCommandBufferOp(instr); - default: - llvm::errs() << instr->ToString(); - llvm::errs() << "\n\nModule:\n" - << instr->GetModule()->ToString() << "\n\n"; - return tsl::errors::Internal( - absl::StrCat("LHLO opcode ", xla::HloOpcodeString(instr->opcode()), - " is not supported.")); - } -} - -tsl::Status LhloDialectEmitter::DefaultAction(const HloInstruction* instr) { - TF_ASSIGN_OR_RETURN(auto* op, EmitOp(instr)); - if (op) { - lhlo_to_hlo_[op] = instr; - } - return tsl::OkStatus(); -} - -tsl::StatusOr LhloDialectEmitter::EmitSortOp( - const HloInstruction* instr) { - TF_ASSIGN_OR_RETURN(auto sort, CreateOpWithoutAttrs(instr)); - auto* sort_instr = xla::Cast(instr); - sort.setDimensionAttr( - builder_.getI64IntegerAttr(sort_instr->sort_dimension())); - sort.setIsStableAttr(builder_.getBoolAttr(sort_instr->is_stable())); - TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion( - *sort_instr->called_computations()[0], symbol_table_, - &sort.getComparator(), &builder_)); - return sort; -} - -// Walks MHLO::TupleOp recursively. -tsl::Status WalkTuplePostOrder( - Value v, const std::function& visitor) { - if (auto* op = v.getDefiningOp()) { - if (auto tuple = dyn_cast(op)) { - for (Value sub_v : tuple.getVal()) { - TF_RETURN_IF_ERROR(WalkTuplePostOrder(sub_v, visitor)); - } - return ::tsl::OkStatus(); - } - } - return visitor(v); -} - -tsl::StatusOr LhloDialectEmitter::RewriteFusionOperand( - const HloInstruction* root, const Shape& shape, - xla::ShapeIndex* shape_index, OpBuilder* b, Location loc) { - if (shape.IsTuple()) { - llvm::SmallVector values; - for (int i = 0; i < shape.tuple_shapes_size(); ++i) { - shape_index->push_back(i); - TF_ASSIGN_OR_RETURN( - auto v, RewriteFusionOperand(root, shape.tuple_shapes(i), shape_index, - b, loc)); - values.push_back(v); - shape_index->pop_back(); - } - return Value(b->create(loc, values)); - } - TF_ASSIGN_OR_RETURN(Value memref, - GetOrCreateArrayView(root, shape, *shape_index)); - auto load = b->create(loc, memref); - if (shape.layout() != - xla::LayoutUtil::MakeDescendingLayout(shape.dimensions().size())) { - llvm::SmallVector minor_to_major( - shape.layout().minor_to_major().begin(), - shape.layout().minor_to_major().end()); - load->setAttr("xla_shape", - b->getStringAttr(shape.ToString(/*print_layout=*/true))); - } - return load.getResult(); -} - -// Emit a lmhlo.fusion based on XLA HLO fusion. Structurally they are not neatly -// equivalent. Specifically, XLA HLO fusion: -// fused_computation { -// %p0 = parameter(0) -// %p1 = parameter(1) -// ... -// ROOT %ret = ... -// } -// will be converted to -// lmhlo.fusion() { // no explicit operands -// // capturing outside buffers -// %p0 = bufferization.to_tensor(%arg0) : memref<...> -> tensor<...> -// %p1 = bufferization.to_tensor(%arg1) : memref<...> -> tensor<...> -// ... -// tensor_store ..., %ret // store a tensor to a memref -// } -tsl::StatusOr LhloDialectEmitter::EmitFusionOp( - const HloInstruction* instr) { - Location loc = getLocation(instr); - - auto* fusion_instr = xla::Cast(instr); - - auto fusion = builder_.create(getLocation(instr)); - auto after_fusion = builder_.saveInsertionPoint(); - auto reverter = absl::MakeCleanup( - [this, after_fusion] { builder_.restoreInsertionPoint(after_fusion); }); - builder_ = mlir::OpBuilder(fusion); - - auto region_builder = OpBuilder::atBlockBegin(&fusion.getRegion().front()); - - llvm::SmallVector arguments; - for (int i = 0; i < instr->operands().size(); ++i) { - const HloInstruction* operand = instr->operand(i); - xla::ShapeIndex shape_index; - TF_ASSIGN_OR_RETURN( - auto arg, RewriteFusionOperand(operand, operand->shape(), &shape_index, - ®ion_builder, loc)); - arguments.push_back(arg); - } - - TF_ASSIGN_OR_RETURN(Value result, - xla::HloFunctionImporter::ImportInstructions( - *fusion_instr->fused_instructions_computation(), - arguments, symbol_table_, ®ion_builder)); - { - int i = 0; - llvm::SmallVector output; - TF_RETURN_IF_ERROR(GetOrCreateView(instr, &output)); - TF_RETURN_IF_ERROR(WalkTuplePostOrder(result, [&](Value v) mutable { - auto materialize_op = - region_builder.create( - loc, v, output[i++]); - materialize_op.setWritable(true); - return ::tsl::OkStatus(); - })); - if (i != output.size()) { - return xla::Internal("output sizes don't match"); - } - } - - // The fusion op might not have a backend-config. But we at least want to set - // the fusion kind, because LMHLO doesn't have this concept. - TF_ASSIGN_OR_RETURN(auto gpu_config, - instr->backend_config()); - xla::gpu::FusionBackendConfig& backend_config = - *gpu_config.mutable_fusion_backend_config(); - if (backend_config.kind().empty() && - instr->opcode() == xla::HloOpcode::kFusion) { - backend_config.set_kind(std::string(ToString(instr->fusion_kind()))); - } - - TF_ASSIGN_OR_RETURN(std::string backend_config_str, - HloInstruction::BackendConfigToRawString(gpu_config)); - fusion.setBackendConfigAttr(builder_.getStringAttr(backend_config_str)); - - // For custom fusion backend config we also attach serialized version of the - // attached HLO computation. - if (backend_config.kind() == "__custom_fusion") { - std::string computation_str; - fusion_instr->fused_instructions_computation()->ToProto().SerializeToString( - &computation_str); - fusion->setAttr("__custom_fusion_computation", - builder_.getStringAttr(computation_str)); - } - - // Fold GTE/Tuple pairs. - // - // Since the fused region refers to values in its parent region, we can't - // call applyPatternAndFoldGreedily. We optimize it manually. - // - // Only walk once, because post-ordering is exactly what we need for GTE - // optimizations. - fusion.getRegion().walk([](mhlo::GetTupleElementOp gte) { - SmallVector folded_values; - if (succeeded(OpBuilder(gte).tryFold(gte, folded_values))) { - gte.replaceAllUsesWith(folded_values[0]); - } - }); - - // Effectively a DCE on the region. - { - llvm::SmallVector ops; - fusion.getRegion().walk([&](mlir::Operation* op) { ops.push_back(op); }); - // Visit the user first. - std::reverse(ops.begin(), ops.end()); - for (auto op : ops) { - if (isOpTriviallyDead(op)) op->erase(); - } - } - - return fusion; -} - -tsl::StatusOr -LhloDialectEmitter::GetScatterDimensionNumbers(const HloInstruction* instr, - mlir::MLIRContext* context) { - auto* scatter_instr = xla::Cast(instr); - - const xla::ScatterDimensionNumbers& xla_scatter_dim = - scatter_instr->scatter_dimension_numbers(); - - auto get_i64_array = [](absl::Span container) { - return ArrayRef{container.data(), - static_cast(container.size())}; - }; - auto scatter_dimension_numbers = mhlo::ScatterDimensionNumbersAttr::get( - context, get_i64_array(xla_scatter_dim.update_window_dims()), - get_i64_array(xla_scatter_dim.inserted_window_dims()), - get_i64_array(xla_scatter_dim.scatter_dims_to_operand_dims()), - xla_scatter_dim.index_vector_dim()); - return scatter_dimension_numbers; -} - -tsl::StatusOr LhloDialectEmitter::EmitScatterOp( - const HloInstruction* instr) { - TF_ASSIGN_OR_RETURN(auto scatter, - CreateOpWithoutAttrs(instr)); - - // copy attributes - auto* scatter_instr = xla::Cast(instr); - - TF_ASSIGN_OR_RETURN(auto scatter_dimension_numbers, - GetScatterDimensionNumbers(instr, builder_.getContext())); - scatter.setScatterDimensionNumbersAttr(scatter_dimension_numbers); - scatter.setIndicesAreSortedAttr( - builder_.getBoolAttr(scatter_instr->indices_are_sorted())); - scatter.setUniqueIndicesAttr( - builder_.getBoolAttr(scatter_instr->unique_indices())); - - // import update computation as region - TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion( - *scatter_instr->called_computations()[0], symbol_table_, - &scatter.getUpdateComputation(), &builder_)); - - return scatter; -} - -tsl::StatusOr -LhloDialectEmitter::EmitSelectAndScatterOp(const HloInstruction* instr) { - TF_ASSIGN_OR_RETURN(auto select_and_scatter, - CreateOpWithoutAttrs(instr)); - - // copy attributes - auto* select_and_scatter_instr = - xla::Cast(instr); - const xla::Window& window = select_and_scatter_instr->window(); - - if (xla::window_util::HasDilation(window)) { - return tsl::errors::Unimplemented( - "Dilation for SelectAndScatter is not supported"); - } - - select_and_scatter.setWindowDimensionsAttr( - GetWindowElements(window, [](const xla::WindowDimension& dim) { - return static_cast(dim.size()); - })); - select_and_scatter.setWindowStridesAttr( - GetWindowElements(window, [](const xla::WindowDimension& dim) { - return static_cast(dim.stride()); - })); - select_and_scatter.setPaddingAttr( - GetWindowElements(window, [](const xla::WindowDimension& dim) { - return static_cast(dim.padding_low()); - })); - - // import select and scatter computation as region - TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion( - *select_and_scatter_instr->select(), symbol_table_, - &select_and_scatter.getSelect(), &builder_)); - TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion( - *select_and_scatter_instr->scatter(), symbol_table_, - &select_and_scatter.getScatter(), &builder_)); - return select_and_scatter; -} - -tsl::StatusOr LhloDialectEmitter::EmitCustomCallOp( - const HloInstruction* instr) { - auto* custom_call_instr = xla::Cast(instr); - - if (xla::gpu::IsCustomCallToCusolver(*instr)) { - return EmitCholesky(custom_call_instr); - } - - if (xla::gpu::IsLegacyCublasMatmul(*instr)) { - return EmitGemm(custom_call_instr); - } - - if (xla::gpu::IsCublasLtMatmul(*instr)) { - return EmitCublasLtMatmul(custom_call_instr); - } - - if (xla::gpu::IsCublasLtMatmulF8(*instr)) { - return EmitCublasLtMatmulF8(custom_call_instr); - } - - if (xla::gpu::IsCustomCallToDnnConvolution(*instr)) { - return EmitDnnConvolution(custom_call_instr); - } - - if (xla::gpu::IsCudnnConvolutionReorder(*instr)) { - return EmitDnnConvolutionReorderVectorized(custom_call_instr); - } - - if (xla::gpu::IsCustomCallToDnnNorm(*instr)) { - return EmitDnnNorm(custom_call_instr); - } - - if (xla::gpu::IsFwdCustomCallTofMHA(*instr)) { - return EmitDnnfMHA(custom_call_instr); - } - if (xla::gpu::IsBwdCustomCallTofMHA(*instr)) { - return EmitDnnfMHABackward(custom_call_instr); - } - if (xla::gpu::IsCubDeviceRadixSort(*instr)) { - return EmitCubDeviceRadixSort(custom_call_instr); - } - - // For custom call, if there are any token operands or results, they will not - // be represented in LHLO so we need to remember the mapping. First create - // operands where each token is replaced with a null Value. - llvm::SmallVector operands; - size_t num_arguments, num_results; - TF_RETURN_IF_ERROR(CreateOperands(instr, /*num_operands=*/std::nullopt, - TokenLoweringMode::kUseNull, operands, - num_arguments, num_results)); - - // Now check if any of the operands is Null, which would indicate the presence - // of a token in the input or output. - bool has_token = llvm::any_of(operands, [](Value v) { return !v; }); - - lmhlo::CustomCallTargetArgMappingAttr target_mapping; - if (has_token) { - // If there was a token, squeeze all the non-token arguments and results - // (in-place) and remember the mapping. - int next_index = 0; - llvm::SmallVector arg_to_target_arg_mapping; - for (int i = 0; i < num_arguments; ++i) { - if (operands[i]) { - arg_to_target_arg_mapping.push_back(i); - operands[next_index++] = operands[i]; - } - } - // Size of arg_to_target_arg_mapping is the number of arguments in LHLO. - llvm::SmallVector result_to_target_result_mapping; - for (int i = num_arguments; i < operands.size(); ++i) { - if (operands[i]) { - result_to_target_result_mapping.push_back(i - num_arguments); - operands[next_index++] = operands[i]; - } - } - - // Build the mapping attribute. - target_mapping = lmhlo::CustomCallTargetArgMappingAttr::get( - builder_.getContext(), num_arguments, num_results, - arg_to_target_arg_mapping, result_to_target_result_mapping); - - // Drop the remaining operands and adjust num_arguments and num_results - // for LMHLO creation. - operands.resize(next_index); - num_arguments = arg_to_target_arg_mapping.size(); - num_results = result_to_target_result_mapping.size(); - } - - auto custom_call = CreateOpWithoutAttrs(instr, operands); - TF_ASSIGN_OR_RETURN( - auto mlir_api_version, - ConvertCustomCallApiVersion(custom_call_instr->api_version())); - custom_call.setCallTargetNameAttr( - builder_.getStringAttr(custom_call_instr->custom_call_target())); - custom_call.setApiVersionAttr(mhlo::CustomCallApiVersionAttr::get( - builder_.getContext(), mlir_api_version)); - - // For typed custom calls we need to parse user-defined attributes back to the - // dictionary attribute, and then add them back to the custom call op. - if (mlir_api_version == mhlo::CustomCallApiVersion::API_VERSION_TYPED_FFI) { - if (custom_call_instr->opaque().empty()) { - auto empty = mlir::DictionaryAttr::get(builder_.getContext()); - custom_call.setBackendConfigAttr(empty); - } else { - mlir::Attribute attr = mlir::parseAttribute(custom_call_instr->opaque(), - builder_.getContext()); - TF_RET_CHECK(attr.isa()) - << "Couldn't parse backend config into a dictionary attribute"; - custom_call.setBackendConfigAttr(attr); - } - } else { - custom_call.setBackendConfigAttr( - builder_.getStringAttr(custom_call_instr->opaque())); - } - - const int32_t segments[2] = {static_cast(num_arguments), - static_cast(num_results)}; - custom_call->setAttr(lmhlo::CustomCallOp::getOperandSegmentSizeAttr(), - builder_.getDenseI32ArrayAttr(segments)); - if (target_mapping) custom_call.setTargetArgMappingAttr(target_mapping); - - for (int i = 0; i < custom_call_instr->called_computations().size(); ++i) { - auto& region = custom_call->getRegion(i); - TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion( - *custom_call_instr->called_computation(), symbol_table_, ®ion, - &builder_)); - } - - return custom_call.getOperation(); -} - -tsl::StatusOr LhloDialectEmitter::EmitCholesky( - const HloCustomCallInstruction* custom_call) { - TF_ASSIGN_OR_RETURN(auto cholesky_op, - CreateOpWithoutAttrs(custom_call)); - TF_ASSIGN_OR_RETURN(xla::CholeskyOptions options, - custom_call->backend_config()); - cholesky_op.setIsLowerAttr(builder_.getBoolAttr(options.lower())); - return cholesky_op; -} - -namespace { - -mhlo::DotDimensionNumbersAttr GetDotDimensionNumbersAttr( - const OpBuilder& builder, const xla::DotDimensionNumbers& hlo_dims) { - auto arrayref = [](absl::Span array) { - return llvm::ArrayRef{array.data(), array.size()}; - }; - return mhlo::DotDimensionNumbersAttr::get( - builder.getContext(), arrayref(hlo_dims.lhs_batch_dimensions()), - arrayref(hlo_dims.rhs_batch_dimensions()), - arrayref(hlo_dims.lhs_contracting_dimensions()), - arrayref(hlo_dims.rhs_contracting_dimensions())); -} - -template -void SetMatmulAttributes(OpT op, const xla::gpu::GemmBackendConfig& config, - OpBuilder& builder) { - op.setDotDimensionNumbersAttr( - GetDotDimensionNumbersAttr(builder, config.dot_dimension_numbers())); - op.setAlphaRealAttr(builder.getF64FloatAttr(config.alpha_real())); - op.setAlphaImagAttr(builder.getF64FloatAttr(config.alpha_imag())); - op.setBetaAttr(builder.getF64FloatAttr(config.beta())); - if (config.algorithm_case() == - xla::gpu::GemmBackendConfig::kSelectedAlgorithm) { - op.setAlgorithmAttr(builder.getI64IntegerAttr(config.selected_algorithm())); - } - op.setPrecisionConfigAttr( - xla::ConvertPrecisionConfig(&config.precision_config(), &builder)); - op.setGradXAttr(builder.getBoolAttr(config.grad_x())); - op.setGradYAttr(builder.getBoolAttr(config.grad_y())); -} - -tsl::StatusOr AsLhloEpilogue( - xla::gpu::GemmBackendConfig_Epilogue epilogue) { - switch (epilogue) { - case xla::gpu::GemmBackendConfig::DEFAULT: - return lmhlo_gpu::CublasLtMatmulEpilogue::Default; - case xla::gpu::GemmBackendConfig::RELU: - return lmhlo_gpu::CublasLtMatmulEpilogue::Relu; - case xla::gpu::GemmBackendConfig::GELU: - return lmhlo_gpu::CublasLtMatmulEpilogue::Gelu; - case xla::gpu::GemmBackendConfig::GELU_AUX: - return lmhlo_gpu::CublasLtMatmulEpilogue::GeluAux; - case xla::gpu::GemmBackendConfig::BIAS: - return lmhlo_gpu::CublasLtMatmulEpilogue::Bias; - case xla::gpu::GemmBackendConfig::BIAS_RELU: - return lmhlo_gpu::CublasLtMatmulEpilogue::BiasRelu; - case xla::gpu::GemmBackendConfig::BIAS_GELU: - return lmhlo_gpu::CublasLtMatmulEpilogue::BiasGelu; - case xla::gpu::GemmBackendConfig::BIAS_GELU_AUX: - return lmhlo_gpu::CublasLtMatmulEpilogue::BiasGeluAux; - default: - return xla::Internal("unknown epilogue"); - } -} - -tsl::StatusOr AsLhloFusedMhaDagSignature( - xla::gpu::CudnnfMHAKind kind) { - switch (kind) { - case xla::gpu::CudnnfMHAKind::kBmmBmm: - return lmhlo_gpu::FusedMhaDagSignature::Default; - case xla::gpu::CudnnfMHAKind::kScaleBiasMaskSoftmax: - return lmhlo_gpu::FusedMhaDagSignature::ScaleBiasMaskSoftmax; - case xla::gpu::CudnnfMHAKind::kScaleBiasMaskSoftmaxDropout: - return lmhlo_gpu::FusedMhaDagSignature::ScaleBiasMaskSoftmaxDropout; - case xla::gpu::CudnnfMHAKind::kScaleMaskSoftmax: - return lmhlo_gpu::FusedMhaDagSignature::ScaleMaskSoftmax; - case xla::gpu::CudnnfMHAKind::kScaleMaskSoftmaxDropout: - return lmhlo_gpu::FusedMhaDagSignature::ScaleMaskSoftmaxDropout; - case xla::gpu::CudnnfMHAKind::kSoftmaxDropout: - return lmhlo_gpu::FusedMhaDagSignature::SoftmaxDropout; - case xla::gpu::CudnnfMHAKind::kSoftmax: - return lmhlo_gpu::FusedMhaDagSignature::Softmax; - case xla::gpu::CudnnfMHAKind::kScaleBiasSoftmax: - return lmhlo_gpu::FusedMhaDagSignature::ScaleBiasSoftmax; - case xla::gpu::CudnnfMHAKind::kScaleBiasSoftmaxDropout: - return lmhlo_gpu::FusedMhaDagSignature::ScaleBiasSoftmaxDropout; - default: - return xla::Internal("unknown cudnn fmha fwd kind"); - } -} -tsl::StatusOr -AsLhloFusedMhaBackwardDagSignature(xla::gpu::CudnnfMHAKind kind) { - switch (kind) { - case xla::gpu::CudnnfMHAKind::kBackwardScaleBiasSoftmax: - return lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardScaleBiasSoftmax; - break; - case xla::gpu::CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout: - return lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleBiasSoftmaxDropout; - break; - case xla::gpu::CudnnfMHAKind::kBackwardScaleBiasMaskSoftmax: - return lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleBiasMaskSoftmax; - break; - case xla::gpu::CudnnfMHAKind::kBackwardScaleBiasMaskSoftmaxDropout: - return lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleBiasMaskSoftmaxDropout; - break; - case xla::gpu::CudnnfMHAKind::kBackwardSoftmax: - return lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardSoftmax; - break; - case xla::gpu::CudnnfMHAKind::kBackwardSoftmaxDropout: - return lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardSoftmaxDropout; - break; - case xla::gpu::CudnnfMHAKind::kBackwardScaleMaskSoftmax: - return lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardScaleMaskSoftmax; - break; - case xla::gpu::CudnnfMHAKind::kBackwardScaleMaskSoftmaxDropout: - return lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleMaskSoftmaxDropout; - break; - default: - return xla::Internal("unknown cudnn fmha bwd kind"); - } -} -} // namespace - -tsl::StatusOr LhloDialectEmitter::EmitGemm( - const HloCustomCallInstruction* custom_call) { - TF_ASSIGN_OR_RETURN( - auto const gpu_config, - custom_call->backend_config()); - const xla::gpu::GemmBackendConfig& config = gpu_config.gemm_backend_config(); - if (custom_call->operand_count() == 2) { - TF_RET_CHECK(config.beta() == 0.); - } else if (custom_call->operand_count() != 3) { - return xla::InvalidArgument("GEMM custom call should have 2 or 3 operands"); - } - - // GEMM may have two or three operands. However, in the three operand case, - // the third operand is updated in-place, so we treat that as an output here. - TF_ASSIGN_OR_RETURN( - lmhlo_gpu::GEMMOp op, - CreateOpWithoutAttrs(custom_call, - /*num_operands=*/2)); - - SetMatmulAttributes(op, config, builder_); - return op.getOperation(); -} - -tsl::StatusOr LhloDialectEmitter::EmitCublasLtMatmul( - const HloCustomCallInstruction* custom_call) { - TF_ASSIGN_OR_RETURN( - auto const gpu_config, - custom_call->backend_config()); - const xla::gpu::GemmBackendConfig& config = gpu_config.gemm_backend_config(); - bool has_matrix_bias = config.beta() != 0.; - - TF_ASSIGN_OR_RETURN( - bool has_vector_bias, - xla::gpu::gpublas_lt::EpilogueAddsVectorBias(config.epilogue())); - - TF_ASSIGN_OR_RETURN( - bool has_aux_output, - xla::gpu::gpublas_lt::EpilogueHasAuxiliaryOutput(config.epilogue())); - - TF_RET_CHECK(custom_call->operand_count() == - 2 + int{has_matrix_bias} + int{has_vector_bias}); - - xla::ShapeIndex output_index = - has_aux_output ? xla::ShapeIndex{0} : xla::ShapeIndex{}; - - llvm::SmallVector operands; - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(0), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(1), &operands)); - if (has_matrix_bias) { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(2), &operands)); - } else { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands, output_index)); - } - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands, output_index)); - - if (has_vector_bias) { - TF_RETURN_IF_ERROR(GetOrCreateView( - custom_call->operand(has_matrix_bias ? 3 : 2), &operands)); - } - - if (has_aux_output) { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands, {1})); - } - - auto op = - CreateOpWithoutAttrs(custom_call, operands); - SetMatmulAttributes(op, config, builder_); - - int32_t operand_sizes[] = { - 1, 1, 1, 1, has_vector_bias ? 1 : 0, has_aux_output ? 1 : 0}; - op->setAttr(op.getOperandSegmentSizeAttr(), - builder_.getDenseI32ArrayAttr(operand_sizes)); - - TF_ASSIGN_OR_RETURN(lmhlo_gpu::CublasLtMatmulEpilogue epilogue, - AsLhloEpilogue(config.epilogue())); - op.setEpilogueAttr(lmhlo_gpu::CublasLtMatmulEpilogueAttr::get( - builder_.getContext(), epilogue)); - - // Use the first algorithm by default (i.e. fastest according to heuristics). - if (config.algorithm_case() != - xla::gpu::GemmBackendConfig::kSelectedAlgorithm) { - op.setAlgorithmAttr(builder_.getI64IntegerAttr(0)); - } - - return op.getOperation(); -} - -tsl::StatusOr LhloDialectEmitter::EmitCublasLtMatmulF8( - const HloCustomCallInstruction* custom_call) { - TF_ASSIGN_OR_RETURN( - auto const gpu_config, - custom_call->backend_config()); - const xla::gpu::GemmBackendConfig& config = gpu_config.gemm_backend_config(); - int ops_num = custom_call->operand_count(); - TF_RET_CHECK(ops_num == 6 || ops_num == 7 || ops_num == 8); - TF_ASSIGN_OR_RETURN( - bool has_vector_bias, - xla::gpu::gpublas_lt::EpilogueAddsVectorBias(config.epilogue())); - - bool has_damax = custom_call->shape().IsTuple(); - bool has_matrix_bias = config.beta() != 0.; - xla::ShapeIndex output_index = - has_damax ? xla::ShapeIndex{0} : xla::ShapeIndex{}; - - llvm::SmallVector operands; - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(0), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(1), &operands)); - - int a_scale_index = has_matrix_bias ? 3 : 2; - if (has_matrix_bias) { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(2), &operands)); - } else { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands, output_index)); - } - - TF_RETURN_IF_ERROR( - GetOrCreateView(custom_call->operand(a_scale_index), &operands)); - TF_RETURN_IF_ERROR( - GetOrCreateView(custom_call->operand(a_scale_index + 1), &operands)); - TF_RETURN_IF_ERROR( - GetOrCreateView(custom_call->operand(a_scale_index + 2), &operands)); - TF_RETURN_IF_ERROR( - GetOrCreateView(custom_call->operand(a_scale_index + 3), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands, output_index)); - - if (has_vector_bias) { - TF_RETURN_IF_ERROR( - GetOrCreateView(custom_call->operand(a_scale_index + 4), &operands)); - } - if (has_damax) { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands, {1})); - } - auto op = CreateOpWithoutAttrs(custom_call, - operands); - - SetMatmulAttributes(op, config, builder_); - int32_t operand_sizes[] = { - 1, 1, 1, 1, 1, 1, 1, 1, has_vector_bias ? 1 : 0, has_damax ? 1 : 0}; - op->setAttr(op.getOperandSegmentSizeAttr(), - builder_.getDenseI32ArrayAttr(operand_sizes)); - TF_ASSIGN_OR_RETURN(lmhlo_gpu::CublasLtMatmulEpilogue epilogue, - AsLhloEpilogue(config.epilogue())); - op.setEpilogueAttr(lmhlo_gpu::CublasLtMatmulEpilogueAttr::get( - builder_.getContext(), epilogue)); - - // Use the first algorithm by default (i.e. fastest according to heuristics). - if (config.algorithm_case() != - xla::gpu::GemmBackendConfig::kSelectedAlgorithm) { - op.setAlgorithmAttr(builder_.getI64IntegerAttr(0)); - } - - return op.getOperation(); -} - -static tsl::StatusOr GetLHLOActivation( - stream_executor::dnn::ActivationMode activation) { - switch (activation) { - case stream_executor::dnn::kNone: - return mlir::lmhlo_gpu::Activation::None; - case stream_executor::dnn::kSigmoid: - return mlir::lmhlo_gpu::Activation::Sigmoid; - case stream_executor::dnn::kRelu: - return mlir::lmhlo_gpu::Activation::Relu; - case stream_executor::dnn::kRelu6: - return mlir::lmhlo_gpu::Activation::Relu6; - case stream_executor::dnn::kReluX: - return mlir::lmhlo_gpu::Activation::ReluX; - case stream_executor::dnn::kTanh: - return mlir::lmhlo_gpu::Activation::Tanh; - case stream_executor::dnn::kBandPass: - return mlir::lmhlo_gpu::Activation::BandPass; - case stream_executor::dnn::kElu: - return mlir::lmhlo_gpu::Activation::Elu; - case stream_executor::dnn::kLeakyRelu: - return mlir::lmhlo_gpu::Activation::LeakyRelu; - default: - return xla::Internal("Unknown activation"); - } -} - -tsl::StatusOr LhloDialectEmitter::EmitDnnConvolution( - const HloCustomCallInstruction* custom_call) { - TF_ASSIGN_OR_RETURN( - auto const gpu_backend_config, - custom_call->backend_config()); - const xla::gpu::CudnnConvBackendConfig& backend_config = - gpu_backend_config.cudnn_conv_backend_config(); - TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnConvKind kind, - xla::gpu::GetCudnnConvKind(custom_call)); - - auto get_layout_attribute = [&](const xla::Layout& layout) { - std::vector minor_to_major(layout.minor_to_major_size()); - absl::c_transform(layout.minor_to_major(), minor_to_major.begin(), - [](int64_t x) { return static_cast(x); }); - return minor_to_major; - }; - - auto set_common_conv_attributes = [&, this](auto op) -> Operation* { - const xla::Window& window = custom_call->window(); - // Window size for Cudnn Conv is same as the kernel size. - NamedAttrList attrs(op->getAttrDictionary()); - DenseIntElementsAttr window_strides; - attrs.set(op.getWindowStridesAttrName(), - window_strides = GetWindowElements( - window, [](const xla::WindowDimension& dim) { - return static_cast(dim.stride()); - })); - // Cudnn Conv requires low and high padding to be equal. - attrs.set(op.getPaddingAttrName(), - GetWindowElements(window, [](const xla::WindowDimension& dim) { - return static_cast(dim.padding_low()); - })); - // LHS dilation is encoded in base_dilation of the backend config. - // RHS dilation is encoded in window_dilation of the backend config. - attrs.set(op.getLhsDilationAttrName(), - GetWindowElements(window, [](const xla::WindowDimension& dim) { - return static_cast(dim.base_dilation()); - })); - attrs.set(op.getRhsDilationAttrName(), - GetWindowElements(window, [](const xla::WindowDimension& dim) { - return static_cast(dim.window_dilation()); - })); - // Setup window reversal. - auto window_reversal = llvm::to_vector<4>(llvm::map_range( - window.dimensions(), - [](const xla::WindowDimension& dim) { return dim.window_reversal(); })); - auto type = RankedTensorType::get(window_strides.getType().getShape(), - builder_.getIntegerType(/*width=*/1)); - attrs.set(op.getWindowReversalAttrName(), - DenseElementsAttr::get(type, window_reversal)); - - attrs.set(op.getDimensionNumbersAttrName(), - xla::ConvertConvDimensionNumbers( - custom_call->convolution_dimension_numbers(), &builder_)); - attrs.set(op.getFeatureGroupCountAttrName(), - builder_.getI64IntegerAttr(custom_call->feature_group_count())); - attrs.set(op.getBatchGroupCountAttrName(), - builder_.getI64IntegerAttr(custom_call->batch_group_count())); - attrs.set(op.getPrecisionConfigAttrName(), - xla::ConvertPrecisionConfig(&custom_call->precision_config(), - &builder_)); - attrs.set(op.getResultScaleAttrName(), - builder_.getF64FloatAttr(backend_config.conv_result_scale())); - - const auto& algorithm = backend_config.algorithm(); - std::vector knob_ids; - std::vector knob_values; - for (const auto& entry : algorithm.tuning_knobs()) { - knob_ids.push_back(entry.first); - knob_values.push_back(entry.second); - } - - auto config = mlir::lmhlo_gpu::ConvolutionBackendConfigAttr::get( - builder_.getContext(), algorithm.algo_id(), - - algorithm.math_type() == - stream_executor::dnn::AlgorithmProto::TENSOR_OP_MATH, - knob_ids, knob_values, algorithm.is_cudnn_frontend(), - backend_config.reordered_int8_nchw_vect(), - algorithm.has_workspace_size() ? algorithm.workspace_size().value() - : -1, - get_layout_attribute(custom_call->operand(0)->shape().layout()), - get_layout_attribute(custom_call->operand(1)->shape().layout()), - get_layout_attribute(custom_call->shape().tuple_shapes(0).layout())); - attrs.set(op.getBackendConfigAttrName(), config); - op->setAttrs(attrs.getDictionary(op->getContext())); - - return op.getOperation(); - }; - - auto set_activation = [&, this](auto op) -> tsl::Status { - auto se_activation = static_cast( - backend_config.activation_mode()); - TF_ASSIGN_OR_RETURN(mlir::lmhlo_gpu::Activation activation, - GetLHLOActivation(se_activation)); - auto activation_attr = ::mlir::lmhlo_gpu::ActivationAttr::get( - getLocation(custom_call).getContext(), activation); - op.setActivationModeAttr(activation_attr); - return ::tsl::OkStatus(); - }; - - switch (kind) { - case xla::gpu::CudnnConvKind::kForward: { - TF_ASSIGN_OR_RETURN( - auto cnn_forward, - CreateOpWithoutAttrs(custom_call)); - return set_common_conv_attributes(cnn_forward); - } - case xla::gpu::CudnnConvKind::kBackwardInput: { - TF_ASSIGN_OR_RETURN( - auto cnn_backward, - CreateOpWithoutAttrs(custom_call)); - return set_common_conv_attributes(cnn_backward); - } - case xla::gpu::CudnnConvKind::kBackwardFilter: { - TF_ASSIGN_OR_RETURN( - auto cnn_backward, - CreateOpWithoutAttrs(custom_call)); - return set_common_conv_attributes(cnn_backward); - } - case xla::gpu::CudnnConvKind::kForwardActivation: { - // Fused conv can be either with side input or without. - if (custom_call->operand_count() == 3) { - TF_ASSIGN_OR_RETURN( - auto cnn_fused, - CreateOpWithoutAttrs(custom_call)); - TF_RETURN_IF_ERROR(set_activation(cnn_fused)); - cnn_fused.setLeakyreluAlphaAttr( - builder_.getF64FloatAttr(backend_config.leakyrelu_alpha())); - return set_common_conv_attributes(cnn_fused); - } - - TF_RET_CHECK(custom_call->operand_count() == 4); - TF_ASSIGN_OR_RETURN( - auto cnn_fused_side_input, - CreateOpWithoutAttrs( - custom_call)); - cnn_fused_side_input.setSideInputScaleAttr( - builder_.getF64FloatAttr(backend_config.side_input_scale())); - TF_RETURN_IF_ERROR(set_activation(cnn_fused_side_input)); - return set_common_conv_attributes(cnn_fused_side_input); - } - case xla::gpu::CudnnConvKind::kForwardGraph: { - const int32_t n_binary_operands = custom_call->operand_count() - 2; - const int32_t n_aux_outputs = - custom_call->shape().tuple_shapes_size() - 2; - TF_ASSIGN_OR_RETURN( - auto cnn_graph, - CreateOpWithoutAttrs(custom_call)); - cnn_graph.setSerializedGraph(backend_config.serialized_graph()); - cnn_graph.setNAuxOutputs(n_aux_outputs); - int32_t operand_sizes[] = {1, 1, n_binary_operands, 1, n_aux_outputs, 1}; - cnn_graph->setAttr(cnn_graph.getOperandSegmentSizeAttr(), - builder_.getDenseI32ArrayAttr(operand_sizes)); - return set_common_conv_attributes(cnn_graph); - } - } -} - -tsl::StatusOr -LhloDialectEmitter::EmitDnnConvolutionReorderVectorized( - const HloCustomCallInstruction* custom_call) { - auto set_common_attributes = [&, this](auto op) -> Operation* { - // Output shape defines the filter, it must have NCHW_VECT_C layout. - Shape shape = custom_call->shape(); - if (shape.IsTuple()) { - // We explicitly create a copy here to avoid self-assignment issues - shape = Shape{shape.tuple_shapes(0)}; - } - - CHECK_EQ(shape.rank(), 5); - CHECK_EQ(shape.dimensions(4), 32); - llvm::SmallVector nchw = { - shape.dimensions(0), shape.dimensions(1) * 32, shape.dimensions(2), - shape.dimensions(3)}; - op->setAttr("filter_dims", GetI64DenseElementsAttr(nchw)); - - return op.getOperation(); - }; - - if (custom_call->operand_count() > 1) { - TF_ASSIGN_OR_RETURN( - auto reorder_filter_and_bias, - CreateOpWithoutAttrs( - custom_call)); - return set_common_attributes(reorder_filter_and_bias); - } else { - TF_ASSIGN_OR_RETURN( - auto reorder_filter, - CreateOpWithoutAttrs(custom_call)); - return set_common_attributes(reorder_filter); - } -} - -tsl::StatusOr LhloDialectEmitter::EmitDnnNorm( - const HloCustomCallInstruction* custom_call) { - TF_ASSIGN_OR_RETURN( - auto const gpu_config, - custom_call->backend_config()); - const xla::gpu::CudnnNormBackendConfig& backend_config = - gpu_config.cudnn_norm_backend_config(); - llvm::SmallVector operands; - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(0), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(1), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(2), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - - auto norm = - CreateOpWithoutAttrs(custom_call, operands); - norm.setEpsilonAttr(builder_.getF64FloatAttr(backend_config.epsilon())); - - const auto& algorithm = backend_config.algorithm(); - auto norm_algo_config = mlir::lmhlo_gpu::NormAlgorithmConfigAttr::get( - builder_.getContext(), algorithm.algo_id(), - algorithm.has_workspace_size() ? algorithm.workspace_size().value() : -1); - norm.setAlgorithmConfigAttr(norm_algo_config); - - std::vector operand_minor_to_major; - - auto get_minor_to_major = - [&operand_minor_to_major](const xla::Layout& layout) -> void { - std::vector minor_to_major(layout.minor_to_major_size()); - absl::c_transform(layout.minor_to_major(), minor_to_major.begin(), - [](int64_t x) { return static_cast(x); }); - operand_minor_to_major.insert(operand_minor_to_major.end(), - minor_to_major.begin(), minor_to_major.end()); - }; - - // Store the layout information of all operands and outputs. - for (HloInstruction* operand : custom_call->operands()) { - get_minor_to_major(operand->shape().layout()); - } - for (int i = 0; i < custom_call->shape().tuple_shapes_size() - 1; ++i) { - get_minor_to_major(custom_call->shape().tuple_shapes(i).layout()); - } - - norm.setOperandLayoutsAttr(builder_.getI64ArrayAttr(llvm::ArrayRef{ - operand_minor_to_major.data(), operand_minor_to_major.size()})); - - bool has_aux_outputs = custom_call->shape().tuple_shapes_size() == 4; - int32_t operand_sizes[] = {1, 1, 1, 1, has_aux_outputs, has_aux_outputs, 1}; - norm->setAttr(norm.getOperandSegmentSizeAttr(), - builder_.getDenseI32ArrayAttr(operand_sizes)); - - return norm.getOperation(); -} - -tsl::StatusOr LhloDialectEmitter::EmitDnnfMHA( - const HloCustomCallInstruction* custom_call) { - TF_ASSIGN_OR_RETURN( - auto const gpu_config, - custom_call->backend_config()); - const xla::gpu::CudnnfMHABackendConfig& config = - gpu_config.cudnn_fmha_backend_config(); - TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnfMHAKind kind, - xla::gpu::GetCudnnfMHAKind(custom_call)); - - bool has_activation = - xla::ShapeUtil::TupleElementCount(custom_call->shape()) == 3; - bool has_mask = false; - bool has_bias = false; - - auto set_common_fmha_attributes = - [&, this](auto op) -> tsl::StatusOr { - TF_ASSIGN_OR_RETURN(lmhlo_gpu::FusedMhaDagSignature fused_mha_dag_signature, - AsLhloFusedMhaDagSignature(kind)); - op.setFusedMhaDagAttr(lmhlo_gpu::FusedMhaDagSignatureAttr::get( - builder_.getContext(), fused_mha_dag_signature)); - op.setBmm1DotDimensionNumbersAttr(GetDotDimensionNumbersAttr( - builder_, config.bmm1_dot_dimension_numbers())); - op.setBmm2DotDimensionNumbersAttr(GetDotDimensionNumbersAttr( - builder_, config.bmm2_dot_dimension_numbers())); - - const auto& algorithm = config.algorithm(); - std::vector knob_ids; - std::vector knob_values; - for (const auto& entry : algorithm.tuning_knobs()) { - knob_ids.push_back(entry.first); - knob_values.push_back(entry.second); - } - auto fmha_algo_config = mlir::lmhlo_gpu::FusedMHAAlgorithmConfigAttr::get( - builder_.getContext(), algorithm.algo_id(), knob_ids, knob_values, - algorithm.has_workspace_size() ? algorithm.workspace_size().value() - : -1); - op.setAlgorithmConfigAttr(fmha_algo_config); - - auto intermediate_tensor_shape = Shape(config.intermediate_tensor_shape()); - auto arrayref = [](absl::Span array) { - return llvm::ArrayRef{array.data(), array.size()}; - }; - auto intermediate_tensor_dims = builder_.getI64ArrayAttr( - arrayref(intermediate_tensor_shape.dimensions())); - op.setIntermediateTensorDimensionsAttr(intermediate_tensor_dims); - - auto intermediate_tensor_layout = builder_.getI64ArrayAttr( - arrayref(intermediate_tensor_shape.layout().minor_to_major())); - op.setIntermediateTensorLayoutAttr(intermediate_tensor_layout); - op.setFmhaScaleAttr(builder_.getF64FloatAttr(config.fmha_scale())); - int32_t operand_sizes[] = {1, - 1, - 1, - has_mask ? 1 : 0, - has_bias ? 1 : 0, - 1, - 1, - has_activation ? 1 : 0}; - op->setAttr(op.getOperandSegmentSizeAttr(), - builder_.getDenseI32ArrayAttr(operand_sizes)); - // set is flash attention here - op.setIsFlashAttentionAttr( - builder_.getBoolAttr(config.is_flash_attention())); - // set is causal mask here - op.setIsCausalMaskAttr(builder_.getBoolAttr(config.is_causal_mask())); - return op.getOperation(); - }; - llvm::SmallVector operands; - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(0), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(1), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(2), &operands)); - switch (kind) { - case xla::gpu::CudnnfMHAKind::kBmmBmm: - case xla::gpu::CudnnfMHAKind::kSoftmax: { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - auto fmha = - CreateOpWithoutAttrs(custom_call, operands); - return set_common_fmha_attributes(fmha); - } - case xla::gpu::CudnnfMHAKind::kSoftmaxDropout: { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - auto fmha = - CreateOpWithoutAttrs(custom_call, operands); - fmha.setDropoutRateAttr(builder_.getF64FloatAttr(config.dropout_rate())); - fmha.setSeedAttr(builder_.getI64IntegerAttr(config.seed())); - return set_common_fmha_attributes(fmha); - } - case xla::gpu::CudnnfMHAKind::kScaleMaskSoftmax: { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(3), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - auto fmha = - CreateOpWithoutAttrs(custom_call, operands); - has_mask = true; - return set_common_fmha_attributes(fmha); - } - case xla::gpu::CudnnfMHAKind::kScaleMaskSoftmaxDropout: { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(3), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - auto fmha = - CreateOpWithoutAttrs(custom_call, operands); - fmha.setDropoutRateAttr(builder_.getF64FloatAttr(config.dropout_rate())); - fmha.setSeedAttr(builder_.getI64IntegerAttr(config.seed())); - has_mask = true; - return set_common_fmha_attributes(fmha); - } - case xla::gpu::CudnnfMHAKind::kScaleBiasMaskSoftmax: { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(3), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(4), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - auto fmha = - CreateOpWithoutAttrs(custom_call, operands); - has_mask = true; - has_bias = true; - return set_common_fmha_attributes(fmha); - } - case xla::gpu::CudnnfMHAKind::kScaleBiasMaskSoftmaxDropout: { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(3), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(4), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - auto fmha = - CreateOpWithoutAttrs(custom_call, operands); - fmha.setDropoutRateAttr(builder_.getF64FloatAttr(config.dropout_rate())); - fmha.setSeedAttr(builder_.getI64IntegerAttr(config.seed())); - has_mask = true; - has_bias = true; - return set_common_fmha_attributes(fmha); - } - case xla::gpu::CudnnfMHAKind::kScaleBiasSoftmax: { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(3), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - auto fmha = - CreateOpWithoutAttrs(custom_call, operands); - has_bias = true; - return set_common_fmha_attributes(fmha); - } - case xla::gpu::CudnnfMHAKind::kScaleBiasSoftmaxDropout: { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(3), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - auto fmha = - CreateOpWithoutAttrs(custom_call, operands); - fmha.setDropoutRateAttr(builder_.getF64FloatAttr(config.dropout_rate())); - fmha.setSeedAttr(builder_.getI64IntegerAttr(config.seed())); - has_bias = true; - return set_common_fmha_attributes(fmha); - } - default: - return xla::Internal("Unknown forward fused MHA call."); - } -} - -tsl::StatusOr LhloDialectEmitter::EmitDnnfMHABackward( - const HloCustomCallInstruction* custom_call) { - TF_ASSIGN_OR_RETURN( - auto const gpu_config, - custom_call->backend_config()); - const xla::gpu::CudnnfMHABackendConfig& config = - gpu_config.cudnn_fmha_backend_config(); - TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnfMHAKind kind, - xla::gpu::GetCudnnfMHAKind(custom_call)); - - bool is_flash_attention = config.is_flash_attention(); - bool has_dbias = - custom_call->shape().tuple_shapes().size() == 6 && !is_flash_attention; - bool has_mask = false; - bool has_bias = false; - - auto set_common_fmha_backward_attributes = - [&, this](auto op) -> tsl::StatusOr { - TF_ASSIGN_OR_RETURN(lmhlo_gpu::FusedMhaBackwardDagSignature - fused_mha_backward_dag_signature, - AsLhloFusedMhaBackwardDagSignature(kind)); - op.setFusedMhaDagAttr(lmhlo_gpu::FusedMhaBackwardDagSignatureAttr::get( - builder_.getContext(), fused_mha_backward_dag_signature)); - op.setBmm1GradGemm1DotDimensionNumbersAttr(GetDotDimensionNumbersAttr( - builder_, config.bmm1_grad_gemm1_dot_dimension_numbers())); - op.setBmm1GradGemm2DotDimensionNumbersAttr(GetDotDimensionNumbersAttr( - builder_, config.bmm1_grad_gemm2_dot_dimension_numbers())); - op.setBmm2GradGemm1DotDimensionNumbersAttr(GetDotDimensionNumbersAttr( - builder_, config.bmm2_grad_gemm1_dot_dimension_numbers())); - op.setBmm2GradGemm2DotDimensionNumbersAttr(GetDotDimensionNumbersAttr( - builder_, config.bmm2_grad_gemm2_dot_dimension_numbers())); - - auto intermediate_tensor_shape = Shape(config.intermediate_tensor_shape()); - auto arrayref = [](absl::Span array) { - return llvm::ArrayRef{array.data(), array.size()}; - }; - auto intermediate_tensor_dims = builder_.getI64ArrayAttr( - arrayref(intermediate_tensor_shape.dimensions())); - op.setIntermediateTensorDimensionsAttr(intermediate_tensor_dims); - - auto intermediate_tensor_layout = builder_.getI64ArrayAttr( - arrayref(intermediate_tensor_shape.layout().minor_to_major())); - op.setIntermediateTensorLayoutAttr(intermediate_tensor_layout); - - op.setFmhaScaleAttr(builder_.getF64FloatAttr(config.fmha_scale())); - - int32_t operand_sizes[] = {1, - 1, - 1, - 1, - 1, - has_mask ? 1 : 0, - has_bias ? 1 : 0, - is_flash_attention ? 1 : 0, // fwd_output - 1, - 1, - 1, - is_flash_attention ? 0 : 1, // d_S - is_flash_attention ? 1 : 0, // softmax_sum - is_flash_attention ? 1 : 0, // d_Q_accum - 1, - has_dbias ? 1 : 0}; - op->setAttr(op.getOperandSegmentSizeAttr(), - builder_.getDenseI32ArrayAttr(operand_sizes)); - - // set is flash attention here - op.setIsFlashAttentionAttr( - builder_.getBoolAttr(config.is_flash_attention())); - // set is causal mask here - op.setIsCausalMaskAttr(builder_.getBoolAttr(config.is_causal_mask())); - const auto& algorithm = config.algorithm(); - std::vector knob_ids; - std::vector knob_values; - for (const auto& entry : algorithm.tuning_knobs()) { - knob_ids.push_back(entry.first); - knob_values.push_back(entry.second); - } - auto fmha_algo_config = mlir::lmhlo_gpu::FusedMHAAlgorithmConfigAttr::get( - builder_.getContext(), algorithm.algo_id(), knob_ids, knob_values, - algorithm.has_workspace_size() ? algorithm.workspace_size().value() - : -1); - op.setAlgorithmConfigAttr(fmha_algo_config); - return op.getOperation(); - }; - - llvm::SmallVector operands; - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(0), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(1), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(2), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(3), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(4), &operands)); - - switch (kind) { - case xla::gpu::CudnnfMHAKind::kBackwardBmmBmm: - case xla::gpu::CudnnfMHAKind::kBackwardSoftmax: { - // push fwd output for bwd here if it is flash attention - if (config.is_flash_attention()) { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(5), &operands)); - } - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - auto fmha_backward = CreateOpWithoutAttrs( - custom_call, operands); - return set_common_fmha_backward_attributes(fmha_backward); - } - case xla::gpu::CudnnfMHAKind::kBackwardSoftmaxDropout: { - // push fwd output for bwd here if it is flash attention - if (config.is_flash_attention()) { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(5), &operands)); - } - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - auto fmha_backward = CreateOpWithoutAttrs( - custom_call, operands); - fmha_backward.setDropoutRateAttr( - builder_.getF64FloatAttr(config.dropout_rate())); - fmha_backward.setSeedAttr(builder_.getI64IntegerAttr(config.seed())); - return set_common_fmha_backward_attributes(fmha_backward); - } - case xla::gpu::CudnnfMHAKind::kBackwardScaleBiasSoftmax: { - // push fwd output for bwd here if it is flash attention - if (config.is_flash_attention()) { - has_bias = true; - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(5), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(6), &operands)); - } - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - auto fmha_backward = CreateOpWithoutAttrs( - custom_call, operands); - return set_common_fmha_backward_attributes(fmha_backward); - } - case xla::gpu::CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout: { - // push fwd output for bwd here if it is flash attention - if (config.is_flash_attention()) { - has_bias = true; - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(5), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(6), &operands)); - } - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - auto fmha_backward = CreateOpWithoutAttrs( - custom_call, operands); - fmha_backward.setDropoutRateAttr( - builder_.getF64FloatAttr(config.dropout_rate())); - fmha_backward.setSeedAttr(builder_.getI64IntegerAttr(config.seed())); - return set_common_fmha_backward_attributes(fmha_backward); - } - - case xla::gpu::CudnnfMHAKind::kBackwardScaleMaskSoftmax: { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(5), &operands)); - // push fwd output for bwd here if it is flash attention - if (config.is_flash_attention()) { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(6), &operands)); - } - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - has_mask = true; - auto fmha_backward = CreateOpWithoutAttrs( - custom_call, operands); - return set_common_fmha_backward_attributes(fmha_backward); - } - case xla::gpu::CudnnfMHAKind::kBackwardScaleBiasMaskSoftmax: { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(5), &operands)); - // push fwd output for bwd here if it is flash attention - if (config.is_flash_attention()) { - has_bias = true; - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(6), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(7), &operands)); - } - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - has_mask = true; - auto fmha_backward = CreateOpWithoutAttrs( - custom_call, operands); - return set_common_fmha_backward_attributes(fmha_backward); - } - - case xla::gpu::CudnnfMHAKind::kBackwardScaleMaskSoftmaxDropout: { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(5), &operands)); - // push fwd output for bwd here if it is flash attention - if (config.is_flash_attention()) { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(6), &operands)); - } - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - has_mask = true; - auto fmha_backward = CreateOpWithoutAttrs( - custom_call, operands); - fmha_backward.setDropoutRateAttr( - builder_.getF64FloatAttr(config.dropout_rate())); - fmha_backward.setSeedAttr(builder_.getI64IntegerAttr(config.seed())); - return set_common_fmha_backward_attributes(fmha_backward); - } - case xla::gpu::CudnnfMHAKind::kBackwardScaleBiasMaskSoftmaxDropout: { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(5), &operands)); - // push fwd output for bwd here if it is flash attention - if (config.is_flash_attention()) { - has_bias = true; - TF_RETURN_IF_ERROR( - GetOrCreateView(custom_call->operand(6), &operands)); // bias - TF_RETURN_IF_ERROR( - GetOrCreateView(custom_call->operand(7), &operands)); // fwd_output - } - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - has_mask = true; - auto fmha_backward = CreateOpWithoutAttrs( - custom_call, operands); - fmha_backward.setDropoutRateAttr( - builder_.getF64FloatAttr(config.dropout_rate())); - fmha_backward.setSeedAttr(builder_.getI64IntegerAttr(config.seed())); - return set_common_fmha_backward_attributes(fmha_backward); - } - - default: - return xla::Internal("Unknown backward fused MHA call."); - } -} - -xla::StatusOr LhloDialectEmitter::EmitCubDeviceRadixSort( - const xla::HloCustomCallInstruction* custom_call) { - TF_ASSIGN_OR_RETURN( - auto radix_sort_op, - CreateOpWithoutAttrs(custom_call)); - TF_ASSIGN_OR_RETURN(xla::SortOptions options, - custom_call->backend_config()); - radix_sort_op.setDescendingAttr(builder_.getBoolAttr(options.descending())); - return radix_sort_op.getOperation(); -} - -// Convert an XLA HLO constant to a global_memref + get_global_memref pair. -tsl::StatusOr LhloDialectEmitter::EmitConstant( - const HloInstruction* instr) { - auto& instr_slice = instr_slices_[std::make_pair(instr, xla::ShapeIndex())]; - if (instr_slice) { - return dyn_cast(instr_slice.getDefiningOp()); - } - - // Insert a global_memref in the module. - Location loc = getLocation(instr); - - auto const_instr = xla::Cast(instr); - TF_RET_CHECK(const_instr->shape().IsArray() && - const_instr->shape().is_static()); - TF_ASSIGN_OR_RETURN(Type type, xla::ConvertShapeToType( - const_instr->shape(), builder_)); - auto memref_type = type.dyn_cast(); - TF_RET_CHECK(memref_type != nullptr); - - TF_ASSIGN_OR_RETURN( - DenseElementsAttr initial_value, - CreateDenseElementsAttrFromLiteral(const_instr->literal(), builder_)); - - std::string constant_name = xla::llvm_ir::ConstantNameToGlobalName( - xla::llvm_ir::SanitizeConstantName(instr->name())); - - // Insert the global memref at the top level. - { - OpBuilder::InsertionGuard guard(builder_); - builder_.clearInsertionPoint(); - auto global_var = builder_.create( - loc, constant_name, builder_.getStringAttr("private"), memref_type, - initial_value, true, /*alignment=*/IntegerAttr()); - symbol_table_.insert(global_var); - global_var.getOperation()->moveBefore(&module_.front()); - - // For operations that do not fold this constant value in their codegen, we - // still need to materialize it into a buffer. Since buffer allocation is - // already done, annotate the global_memref with the information to get to - // the allocated buffer slice for this constant if need be. - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, - assignment_.GetUniqueTopLevelSlice(instr)); - global_var->setAttr( - "lmhlo.alloc", - builder_.getIndexAttr(allocations_.find(slice.allocation()) - ->second.cast() - .getArgNumber())); - TF_RET_CHECK(slice.offset() == 0) - << "Each constant should have its own allocation from BufferAssignment"; - TF_RET_CHECK(slice.allocation()->size() == slice.size()) - << "Each constant should have its own allocation from BufferAssignment"; - } - - auto get_global_memref = - builder_.create(loc, memref_type, constant_name); - - // Update the cache to remember this value. - instr_slice = get_global_memref; - return get_global_memref; -} - -namespace { -template -void SetupChannelIdAttribute(OpT op, const xla::HloChannelInstruction* instr, - mlir::Builder builder) { - if (instr->channel_id().has_value()) { - op.setChannelIdAttr(mlir::mhlo::ChannelHandleAttr::get( - builder.getContext(), *instr->channel_id(), 0)); - } -} - -template -tsl::Status SetupCommonCollectiveOpAttributes(OpT op, - const HloInstruction* instr, - mlir::OpBuilder& builder) { - auto* collective = xla::Cast(instr); - auto replica_groups_attr = xla::HloFunctionImporter::ConvertReplicaGroups( - collective->replica_groups(), &builder); - op->setAttr(replica_groups_attr.getName(), replica_groups_attr.getValue()); - op.setConstrainLayoutAttr( - builder.getBoolAttr(collective->constrain_layout())); - SetupChannelIdAttribute(op, collective, builder); - return ::tsl::OkStatus(); -} -} // namespace - -template -tsl::StatusOr LhloDialectEmitter::EmitDoneOp( - const xla::HloInstruction* instr) { - auto token = ret_tokens_.extract(instr->operand(0)); - TF_RET_CHECK(token) << "didn't find " << OpT::getOperationName().str() - << " token"; - return builder_.create(getLocation(instr), /*resultTypes=*/std::nullopt, - token.mapped()); -} - -tsl::StatusOr -LhloDialectEmitter::EmitAllToAllStartOp(const xla::HloInstruction* instr) { - // All the input of async-done (which wraps the all-to-all) are also - // listed as outputs, so we just create operands for the outputs. - llvm::SmallVector operands; - TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, /*result_subset=*/{})); - - mlir::Location loc = getLocation(instr); - mlir::Type token_type = mlir::mhlo::TokenType::get(builder_.getContext()); - std::array result_types = {token_type}; - auto all_to_all_start_op = - builder_.create(loc, result_types, operands); - - auto* all_to_all = xla::Cast( - instr->async_wrapped_instruction()); - TF_RETURN_IF_ERROR(SetupCommonCollectiveOpAttributes(all_to_all_start_op, - all_to_all, builder_)); - if (all_to_all->split_dimension().has_value()) { - all_to_all_start_op.setSplitDimensionAttr( - builder_.getI64IntegerAttr(*all_to_all->split_dimension())); - } - all_to_all_start_op.setIsSync(IsSyncCollective(instr)); - all_to_all_start_op.setNoParallelCustomCall( - NoParallelCustomCallCollective(instr)); - - auto [_, was_inserted] = - ret_tokens_.insert({instr, all_to_all_start_op.getToken()}); - TF_RET_CHECK(was_inserted) << "all-to-all-start already lowered"; - return all_to_all_start_op; -} - -tsl::StatusOr LhloDialectEmitter::EmitAllToAllDoneOp( - const HloInstruction* instr) { - return EmitDoneOp(instr); -} - -tsl::StatusOr -LhloDialectEmitter::EmitAllGatherStartOp(const HloInstruction* instr) { - llvm::SmallVector operands; - // In all-gather-start HLO, all inputs are also outputs of the HLO. In LMHLO - // though, we list the inputs and outputs just once. In the HLO result, - // the inputs are listed first, followed by outputs, which matches the order - // of operands we need for LMHLO AllGatherOp. - TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, /*result_subset=*/{})); - - mlir::Location loc = getLocation(instr); - mlir::Type token_type = mlir::mhlo::TokenType::get(builder_.getContext()); - std::array result_types = {token_type}; - auto all_gather_start_op = - builder_.create(loc, result_types, operands); - - auto* all_gather = xla::Cast(instr); - TF_RETURN_IF_ERROR( - SetupCommonCollectiveOpAttributes(all_gather_start_op, instr, builder_)); - all_gather_start_op.setUseGlobalDeviceIdsAttr( - builder_.getBoolAttr(all_gather->use_global_device_ids())); - all_gather_start_op.setAllGatherDimensionAttr( - builder_.getI64IntegerAttr(all_gather->all_gather_dimension())); - all_gather_start_op.setIsSync(IsSyncCollective(instr)); - all_gather_start_op.setNoParallelCustomCall( - NoParallelCustomCallCollective(instr)); - auto [_, was_inserted] = - ret_tokens_.insert({instr, all_gather_start_op.getToken()}); - TF_RET_CHECK(was_inserted) << "all-gather-start already lowered"; - return all_gather_start_op; -} - -tsl::StatusOr -LhloDialectEmitter::EmitAllGatherDoneOp(const HloInstruction* instr) { - return EmitDoneOp(instr); -} - -tsl::StatusOr -LhloDialectEmitter::EmitAllReduceStartOp(const HloInstruction* instr) { - llvm::SmallVector operands; - for (const HloInstruction* operand : instr->operands()) { - TF_RETURN_IF_ERROR(GetOrCreateView(operand, &operands)); - } - TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, /*result_subset=*/{})); - - mlir::Location loc = getLocation(instr); - mlir::Type token_type = mlir::mhlo::TokenType::get(builder_.getContext()); - std::array result_types = {token_type}; - auto all_reduce_start_op = - builder_.create(loc, result_types, operands); - - auto* all_reduce = xla::Cast(instr); - TF_RETURN_IF_ERROR( - SetupCommonCollectiveOpAttributes(all_reduce_start_op, instr, builder_)); - all_reduce_start_op.setUseGlobalDeviceIdsAttr( - builder_.getBoolAttr(all_reduce->use_global_device_ids())); - all_reduce_start_op.setIsSync(IsSyncCollective(instr)); - all_reduce_start_op.setNoParallelCustomCall( - NoParallelCustomCallCollective(instr)); - - TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion( - *instr->called_computations()[0], symbol_table_, - &all_reduce_start_op.getComputation(), &builder_)); - - auto [_, was_inserted] = - ret_tokens_.insert({instr, all_reduce_start_op.getToken()}); - TF_RET_CHECK(was_inserted) << "all-reduce-start already lowered"; - return all_reduce_start_op; -} - -tsl::StatusOr -LhloDialectEmitter::EmitAllReduceDoneOp(const HloInstruction* instr) { - return EmitDoneOp(instr); -} - -tsl::StatusOr LhloDialectEmitter::EmitAsyncStartOp( - const xla::HloInstruction* instr) { - const xla::HloAsyncInstruction* async = - xla::Cast(instr); - - switch (async->async_wrapped_opcode()) { - case xla::HloOpcode::kReduceScatter: - return EmitReduceScatterStartOp(instr); - case xla::HloOpcode::kAllToAll: - return EmitAllToAllStartOp(instr); - default: - return tsl::errors::InvalidArgument( - "Unexpected instruction %s wrapped in %s", - xla::HloOpcodeString(async->async_wrapped_opcode()), - HloOpcodeString(instr->opcode())); - } -} - -tsl::StatusOr LhloDialectEmitter::EmitAsyncDoneOp( - const xla::HloInstruction* instr) { - const xla::HloAsyncInstruction* async = - xla::Cast(instr); - switch (async->async_wrapped_opcode()) { - case xla::HloOpcode::kReduceScatter: - return EmitReduceScatterDoneOp(instr); - case xla::HloOpcode::kAllToAll: - return EmitAllToAllDoneOp(instr); - default: - return tsl::errors::InvalidArgument( - "Unexpected instruction %s wrapped in %s", - xla::HloOpcodeString(async->async_wrapped_opcode()), - HloOpcodeString(instr->opcode())); - } -} - -tsl::StatusOr -LhloDialectEmitter::EmitReduceScatterStartOp(const xla::HloInstruction* instr) { - // All the input of async-done (which wraps the reduce-scatter) are also - // listed as outputs, so we just create operands for the outputs. - llvm::SmallVector operands; - TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, /*result_subset=*/{})); - - mlir::Location loc = getLocation(instr); - mlir::Type token_type = mlir::mhlo::TokenType::get(builder_.getContext()); - std::array result_types = {token_type}; - auto reduce_scatter_start_op = - builder_.create(loc, result_types, - operands); - - auto* reduce_scatter = xla::Cast( - instr->async_wrapped_instruction()); - TF_RETURN_IF_ERROR(SetupCommonCollectiveOpAttributes( - reduce_scatter_start_op, reduce_scatter, builder_)); - reduce_scatter_start_op.setUseGlobalDeviceIdsAttr( - builder_.getBoolAttr(reduce_scatter->use_global_device_ids())); - reduce_scatter_start_op.setScatterDimensionAttr( - builder_.getI64IntegerAttr(reduce_scatter->scatter_dimension())); - reduce_scatter_start_op.setIsSync(IsSyncCollective(instr)); - reduce_scatter_start_op.setNoParallelCustomCall( - NoParallelCustomCallCollective(instr)); - TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion( - *reduce_scatter->to_apply(), symbol_table_, - &reduce_scatter_start_op.getComputation(), &builder_)); - - auto [_, was_inserted] = - ret_tokens_.insert({instr, reduce_scatter_start_op.getToken()}); - TF_RET_CHECK(was_inserted) << "reduce-scatter-start already lowered"; - return reduce_scatter_start_op; -} - -tsl::StatusOr -LhloDialectEmitter::EmitReduceScatterDoneOp(const xla::HloInstruction* instr) { - return EmitDoneOp(instr); -} - -tsl::StatusOr -LhloDialectEmitter::EmitCollectivePermuteStartOp(const HloInstruction* instr) { - llvm::SmallVector operands; - for (const HloInstruction* operand : instr->operands()) { - TF_RETURN_IF_ERROR(GetOrCreateView(operand, &operands)); - } - // Ignore the aliased first output and TPU-specific outputs. - TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, /*result_subset=*/{1})); - - mlir::Location loc = getLocation(instr); - mlir::Type token_type = mlir::mhlo::TokenType::get(builder_.getContext()); - std::array result_types = {token_type}; - auto permute_start_op = builder_.create( - loc, result_types, operands); - - auto* permute = xla::Cast(instr); - SetupChannelIdAttribute(permute_start_op, permute, builder_); - mlir::NamedAttribute source_target_pairs_attr = - xla::HloFunctionImporter::ConvertSourceTargetPairs( - permute->source_target_pairs(), &builder_); - permute_start_op->setAttr(source_target_pairs_attr.getName(), - source_target_pairs_attr.getValue()); - permute_start_op.setIsSync(IsSyncCollective(instr)); - permute_start_op.setNoParallelCustomCall( - NoParallelCustomCallCollective(instr)); - - auto [_, was_inserted] = - ret_tokens_.insert({instr, permute_start_op.getToken()}); - TF_RET_CHECK(was_inserted) << "collective-permute-start already lowered"; - return permute_start_op; -} - -tsl::StatusOr -LhloDialectEmitter::EmitCollectivePermuteDoneOp(const HloInstruction* instr) { - return EmitDoneOp(instr); -} - -tsl::StatusOr LhloDialectEmitter::EmitInfeedOp( - const HloInstruction* instr) { - const HloInfeedInstruction* infeed = xla::Cast(instr); - // HLO Infeed instruction has a single operand of token type and a tuple - // with buffers and a token as its output. LMHLO Infeed operation does not - // need the token operand or result, so drop it. - SmallVector operands; - TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, /*result_subset=*/{0})); - auto infeed_op = CreateOpWithoutAttrs(instr, operands); - infeed_op.setConfigAttr(builder_.getStringAttr(infeed->infeed_config())); - return infeed_op; -} - -tsl::StatusOr LhloDialectEmitter::EmitOutfeedOp( - const HloInstruction* instr) { - const HloOutfeedInstruction* outfeed = - xla::Cast(instr); - // HLO outfeed instruction has 2 operands, the source and a token, and a - // single token output. LMHLO Outfeed does not need the token operand and - // result, do drop it. - SmallVector operands; - TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(0), &operands)); - auto outfeed_op = CreateOpWithoutAttrs(instr, operands); - outfeed_op.setConfigAttr(builder_.getStringAttr(outfeed->outfeed_config())); - return outfeed_op; -} - -tsl::StatusOr -LhloDialectEmitter::EmitRngGetAndUpdateStateOp( - const xla::HloInstruction* instr) { - TF_ASSIGN_OR_RETURN( - auto rng, CreateOpWithoutAttrs(instr)); - auto hlo_rng = xla::Cast(instr); - rng.setDeltaAttr(builder_.getI64IntegerAttr(hlo_rng->delta())); - return rng; -} - -tsl::StatusOr LhloDialectEmitter::EmitFftOp( - const HloInstruction* instr) { - auto hlo_fft = xla::Cast(instr); - TF_ASSIGN_OR_RETURN(auto fft, CreateOpWithoutAttrs(instr)); - TF_ASSIGN_OR_RETURN(mlir::mhlo::FftType fft_type, - xla::ConvertFftType(hlo_fft->fft_type())); - fft.setFftTypeAttr( - mlir::mhlo::FftTypeAttr::get(builder_.getContext(), fft_type)); - fft.setFftLengthAttr(GetI64DenseElementsAttr(instr->fft_length())); - return fft; -} - -tsl::StatusOr -LhloDialectEmitter::EmitTriangularSolveOp(const xla::HloInstruction* instr) { - auto hlo_triangular_solve = - xla::Cast(instr); - TF_ASSIGN_OR_RETURN(auto triangular_solve, - CreateOpWithoutAttrs(instr)); - const xla::TriangularSolveOptions& options = - hlo_triangular_solve->triangular_solve_options(); - triangular_solve.setLeftSideAttr(builder_.getBoolAttr(options.left_side())); - triangular_solve.setLowerAttr(builder_.getBoolAttr(options.lower())); - triangular_solve.setUnitDiagonalAttr( - builder_.getBoolAttr(options.unit_diagonal())); - TF_ASSIGN_OR_RETURN(mlir::mhlo::Transpose transpose, - xla::ConvertTranspose(options.transpose_a())); - triangular_solve.setTransposeAAttr( - mlir::mhlo::TransposeAttr::get(builder_.getContext(), transpose)); - triangular_solve.setLayoutAAttr( - GetLayoutAttribute(instr->operand(0)->shape().layout(), &builder_)); - triangular_solve.setLayoutBAttr( - GetLayoutAttribute(instr->operand(1)->shape().layout(), &builder_)); - triangular_solve.setLayoutOutputAttr( - GetLayoutAttribute(instr->shape().layout(), &builder_)); - return triangular_solve; -} - -tsl::StatusOr LhloDialectEmitter::EmitBitcast( - const xla::HloInstruction* instr) { - // XLA buffer assignment should assign the same slice to a bitcast input and - // output. - const xla::ShapeIndex top_index; - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice, - assignment_.GetUniqueSlice(instr, top_index)); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice, - assignment_.GetUniqueSlice(instr->operand(0), top_index)); - - if (input_slice != result_slice) { - return tsl::errors::InvalidArgument( - "Bitcast input and result slice should be same"); - } - return nullptr; -} - -mlir::DenseIntElementsAttr LhloDialectEmitter::GetLayoutAttribute( - const xla::Layout& layout, Builder* builder) { - llvm::SmallVector minor_to_major(layout.minor_to_major().begin(), - layout.minor_to_major().end()); - return builder->getIndexTensorAttr(minor_to_major); -} - -tsl::Status LhloDialectEmitter::ImportAsLmhloRegion( - xla::HloComputation* computation, mlir::Region* region) { - auto after = builder_.saveInsertionPoint(); - auto reverter = absl::MakeCleanup( - [this, after] { builder_.restoreInsertionPoint(after); }); - - builder_ = OpBuilder(region); - xla::HloModule* hlo_module = computation->parent(); - if (!hlo_module->has_schedule()) { - return tsl::errors::Unimplemented( - "Missing sequential order for the computation"); - } - const xla::HloInstructionSequence* schedule = - &hlo_module->schedule().sequence(computation); - TF_RETURN_IF_ERROR( - computation->AcceptOrdered(this, schedule->instructions())); - builder_.create(builder_.getUnknownLoc()); - return ::tsl::OkStatus(); -} - -tsl::StatusOr LhloDialectEmitter::EmitCaseOp( - const HloInstruction* instr) { - Location loc = getLocation(instr); - llvm::SmallVector operands; - size_t num_arguments, num_results; - TF_RETURN_IF_ERROR(CreateOperands(instr, 1, TokenLoweringMode::kUseNull, - operands, num_arguments, num_results)); - - auto case_op = - builder_.create(loc, operands[0], instr->branch_count()); - - for (int i = 0; i < instr->branch_count(); i++) { - case_op.getBranches()[i].push_back(new mlir::Block()); - TF_RETURN_IF_ERROR(ImportAsLmhloRegion(instr->called_computations()[i], - &case_op.getBranches()[i])); - } - - return case_op; -} - -tsl::StatusOr LhloDialectEmitter::EmitWhileOp( - const xla::HloInstruction* instr) { - Location loc = getLocation(instr); - SmallVector operands; - TF_RETURN_IF_ERROR(GetOrCreateView( - instr->called_computations()[1]->root_instruction(), &operands)); - TF_RET_CHECK(operands.size() == 1); - - TF_ASSIGN_OR_RETURN(auto config, - instr->backend_config()); - mlir::IntegerAttr trip_count; - if (config.has_known_trip_count()) { - trip_count = builder_.getI64IntegerAttr(config.known_trip_count().n()); - } - lmhlo::WhileOp while_op = - builder_.create(loc, operands[0], trip_count); - - while_op.getCond().push_back(new mlir::Block()); - while_op.getBody().push_back(new mlir::Block()); - TF_RETURN_IF_ERROR(ImportAsLmhloRegion(instr->called_computations()[1], - &while_op.getCond())); - - TF_RETURN_IF_ERROR(ImportAsLmhloRegion(instr->called_computations()[0], - &while_op.getBody())); - - return while_op; -} - -// TODO(b/264291989): Use enum to define the host transfer type (channel type). -template -static void CopyChannelAttrs(OpBuilder& b, Instr* instr, OpTy op, - int host_transfer_type) { - op.setIsHostTransferAttr(b.getBoolAttr(instr->is_host_transfer())); - op.setChannelHandleAttr(mlir::mhlo::ChannelHandleAttr::get( - b.getContext(), *instr->channel_id(), - instr->is_host_transfer() ? host_transfer_type : /*DEVICE_TO_DEVICE*/ 1)); -} - -template -static void CopyFrontendAttrs(OpBuilder& b, Instr* instr, OpTy op) { - llvm::SmallVector frontend_attrs; - for (auto& [name, value] : instr->frontend_attributes().map()) { - frontend_attrs.push_back(b.getNamedAttr(name, b.getStringAttr(value))); - } - op->setAttr(b.getStringAttr("frontend_attributes"), - b.getDictionaryAttr(frontend_attrs)); -} - -tsl::StatusOr LhloDialectEmitter::EmitSendOp( - const xla::HloInstruction* instr) { - llvm::SmallVector operands; - TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(0), &operands)); - - auto token = mhlo::TokenType::get(builder_.getContext()); - auto send_op = builder_.create(getLocation(instr), - TypeRange(token), operands); - - // Set point-to-point op communication attributes. - auto* send = xla::Cast(instr); - CopyChannelAttrs(builder_, send, send_op, /*host_transfer_type=*/2); - CopyFrontendAttrs(builder_, send, send_op); - - auto [_, emplaced] = ret_tokens_.try_emplace(instr, send_op.getToken()); - TF_RET_CHECK(emplaced) << "send already lowered"; - return send_op; -} - -tsl::StatusOr LhloDialectEmitter::EmitSendDoneOp( - const xla::HloInstruction* instr) { - TF_ASSIGN_OR_RETURN(auto send_done_op, EmitDoneOp(instr)); - // Copy send-done attributes. - auto* send_done = xla::Cast(instr); - CopyChannelAttrs(builder_, send_done, send_done_op, - /*host_transfer_type=*/2); - - return send_done_op; -} - -tsl::StatusOr LhloDialectEmitter::EmitRecvOp( - const xla::HloInstruction* instr) { - llvm::SmallVector operands; - TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, {0})); - - auto token = mhlo::TokenType::get(builder_.getContext()); - auto recv_op = builder_.create(getLocation(instr), - TypeRange(token), operands); - - // Set point-to-point op communication attributes. - auto* recv = xla::Cast(instr); - CopyChannelAttrs(builder_, recv, recv_op, /*host_transfer_type=*/3); - CopyFrontendAttrs(builder_, recv, recv_op); - - auto [_, emplaced] = ret_tokens_.try_emplace(instr, recv_op.getToken()); - TF_RET_CHECK(emplaced) << "recv already lowered"; - return recv_op; -} - -tsl::StatusOr LhloDialectEmitter::EmitRecvDoneOp( - const xla::HloInstruction* instr) { - TF_ASSIGN_OR_RETURN(auto recv_done_op, EmitDoneOp(instr)); - // Copy recv-done attributes. - auto* recv_done = xla::Cast(instr); - CopyChannelAttrs(builder_, recv_done, recv_done_op, - /*host_transfer_type=*/3); - - return recv_done_op; -} - -tsl::StatusOr LhloDialectEmitter::EmitCommandBufferOp( - const xla::HloInstruction* instr) { - const std::vector called_computations = - instr->called_computations(); - if (called_computations.size() != 1) { - return absl::InternalError( - "Command buffer calls must have one called computation"); - } - - if (!absl::StartsWith(called_computations[0]->name(), "command_buffer")) { - return absl::InternalError("Called computation must be a command buffer"); - } - return builder_.create(getLocation(instr)); -} - -// Sets builder insertion point for a new `memref.view` operation in the parent -// function. We create just one `memref.view` operation for every unique -// subspan of allocation, and because first use of the slice can be inside a -// block nested in a control flow operation, we have to find an insertion point -// in the parent function. Returns insertion guard for the original insertion -// point. -static tsl::StatusOr SetArrayViewInsertionPoint( - OpBuilder& builder) { - OpBuilder::InsertionGuard guard(builder); - - Operation* parent = builder.getInsertionBlock()->getParentOp(); - while (!isa(parent)) { - builder.setInsertionPoint(parent); - if ((parent = parent->getParentOp()) == nullptr) - return absl::InternalError( - "Can't find an insertion point for memref.view operation"); - } - - return guard; -} - -tsl::StatusOr LhloDialectEmitter::GetOrCreateArrayView( - const xla::HloInstruction* instr, const xla::Shape& current_shape, - const xla::ShapeIndex& shape_index) { - // For constants, the cache is managed inside EmitConstant since it can - // be called either from here or when we see a top-level HloConstant instr. - if (instr->IsConstant() && shape_index.empty()) { - TF_ASSIGN_OR_RETURN(Value constant_memref, EmitConstant(instr)); - return constant_memref; - } - - // Cache generated ViewOp and StaticMemRefCastOp by (instruction, - // shape_index). - auto& instr_slice = instr_slices_[std::make_pair(instr, shape_index)]; - if (instr_slice) { - return instr_slice; - } - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, - assignment_.GetUniqueSlice(instr, shape_index)); - - // If the shape happens to have dynamic dimensions, create the memref using - // the underlying static shape. - // TODO(jurahul): Revisit this when we can model memrefs with dynamic shape - // but static bounds in MLIR. - xla::Shape static_shape = xla::ShapeUtil::MakeStaticShape(current_shape); - - // Try to find allocation slice with the same physical shape so that we always - // have only one memref.view operation covering the same buffer subspan. All - // reinterpret casts into different layouts will use the same source memref. - xla::Shape physical_shape = - xla::ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( - static_shape); - - // Initialize values in `allocation_slices_` before taking references, - // otherwise we can invalidate them and trigger asan errors below. - auto static_shape_key = std::make_pair(slice, static_shape); - auto physical_shape_key = std::make_pair(slice, physical_shape); - allocation_slices_[static_shape_key]; - allocation_slices_[physical_shape_key]; - - // Check if we already have a memref.view for a given slice and shape. - auto& allocation_slice = allocation_slices_[static_shape_key]; - if (allocation_slice) { - return instr_slice = allocation_slice; - } - - TF_ASSIGN_OR_RETURN(Type out_type, xla::ConvertShapeToType( - static_shape, builder_)); - TF_ASSIGN_OR_RETURN( - Type physical_out_type, - xla::ConvertShapeToType(physical_shape, builder_)); - - // Try to find an insertion point for a new memref.view operation. - TF_ASSIGN_OR_RETURN(auto guard, SetArrayViewInsertionPoint(builder_)); - - // TODO(timshen): revisit location handling. - Location loc = builder_.getUnknownLoc(); - - // Creates new memref.view operation with a `physical_shape`. - auto create_physical_slice = [&]() -> Value { - Value alloc = allocations_[slice.allocation()]; - Value byte_shift = - builder_.create(alloc.getLoc(), slice.offset()); - - // ViewOp only takes memrefs without affine maps (layouts). Let ViewOp - // produce the physical shape (where dimensions are ordered in major to - // minor) first, then follow up with a MemRefReinterpretCast to cast the - // resulting memref to the original layout. - return builder_.create(loc, physical_out_type, alloc, - byte_shift, - /*sizes=*/ValueRange()); - }; - - // Reuse existing physical slice if it exists, otherwise build a new - // memref.view operation and cache it. - auto& physical_slice = allocation_slices_[physical_shape_key]; - if (!physical_slice) { - physical_slice = create_physical_slice(); - } - - // Start from a physical slice as a result, and maybe reinterpret cast it into - // logical shape. - Value result = physical_slice; - - if (result.getType() != out_type) { - int64_t out_offset; - SmallVector out_strides; - auto out_memref_type = out_type.dyn_cast(); - if (!out_memref_type) - return tsl::errors::Internal( - "Expected memref type when creating a view for leaf type of a " - "tuple."); - if (failed(getStridesAndOffset(out_memref_type, out_strides, out_offset))) - return tsl::errors::Internal( - "Failed to get strides and offset from the output type."); - result = builder_.create( - loc, out_memref_type, result, out_offset, out_memref_type.getShape(), - out_strides); - } - - return instr_slice = allocation_slice = result; -} - -tsl::Status LhloDialectEmitter::GetOrCreateViewImpl( - const HloInstruction* instr, const Shape& current_shape, - xla::ShapeIndex* current_shape_index, SmallVectorImpl* values, - TokenLoweringMode token_mode) { - if (current_shape.IsTuple()) { - for (int i = 0; i < current_shape.tuple_shapes().size(); ++i) { - current_shape_index->push_back(i); - TF_RETURN_IF_ERROR( - GetOrCreateViewImpl(instr, current_shape.tuple_shapes(i), - current_shape_index, values, token_mode)); - current_shape_index->pop_back(); - } - return ::tsl::OkStatus(); - } - if (current_shape.IsArray()) { - TF_ASSIGN_OR_RETURN(auto v, GetOrCreateArrayView(instr, current_shape, - *current_shape_index)); - values->push_back(v); - return ::tsl::OkStatus(); - } - if (current_shape.IsToken()) { - switch (token_mode) { - case TokenLoweringMode::kFailToLower: - return tsl::errors::Internal( - "Unexpected token kind for %s and shape index %s", - instr->ToString(), current_shape_index->ToString()); - - case TokenLoweringMode::kUseNull: - values->push_back(Value{}); - return ::tsl::OkStatus(); - } - } - return tsl::errors::Internal( - "Unexpected shape kind for %s and shape index %s", instr->ToString(), - current_shape_index->ToString()); -} - -// Returns a view for the result of an instruction. -// We first get a view for the slice in the allocation, and then may need to -// create another view to adjust the slice for the shape of the instruction. -tsl::Status LhloDialectEmitter::GetOrCreateView( - const HloInstruction* instr, SmallVectorImpl* values, - const xla::ShapeIndex& result_subset, TokenLoweringMode token_mode) { - xla::ShapeIndex shape_index = result_subset; - const Shape& sub_shape = - xla::ShapeUtil::GetSubshape(instr->shape(), shape_index); - return GetOrCreateViewImpl(instr, sub_shape, &shape_index, values, - token_mode); -} - -tsl::Status LhloDialectEmitter::Initialize( - std::vector* ordered_allocations) { - TF_RET_CHECK(computation_.IsEntryComputation()); - - mlir::IntegerAttr unique_id = - builder_.getI32IntegerAttr(computation_.parent()->unique_id()); - module_->setAttr("hlo.unique_id", unique_id); - llvm::StringRef function_name = - computation_.name().empty() ? "__compute" - : llvm::StringRef(computation_.name().data(), - computation_.name().size()); - - // Create the function as () -> (), we'll compute the arguments from the - // buffer allocation and update the type then. - auto func_op = func::FuncOp::create(builder_.getUnknownLoc(), function_name, - builder_.getFunctionType({}, {})); - - { - // This is an optional attribute used by the XLA backend. If the resulting - // LMHLO doesn't go through XLA, this is not needed. - const Shape& shape = computation_.root_instruction()->shape(); - func_op->setAttr( - "result_xla_shape", - builder_.getStringAttr(shape.ToString(/*print_layout=*/true))); - } - Block* block = func_op.addEntryBlock(); - - for (const BufferAllocation& alloc : assignment_.Allocations()) { - if (!alloc.is_thread_local()) { - ordered_allocations->push_back(&alloc); - } - } - - if (computation_.IsEntryComputation()) { - // Sort the rather arbitrarily ordered allocations to match the input/output - // parameters. Specifically we want to sort buffer allocations in the - // following order: - // * Parameters always order before non-parameters. - // * Different parameters order by parameter number. - // * Different allocations for the same parameter order by the shape index. - // - // TODO(timshen): there should be only one non-parameter buffer, the temp - // buffer. Check on that. - const auto allocation_comparator = [](const BufferAllocation* lhs, - const BufferAllocation* rhs) { - if (lhs->is_entry_computation_parameter() != - rhs->is_entry_computation_parameter()) { - return lhs->is_entry_computation_parameter() > - rhs->is_entry_computation_parameter(); - } - if (lhs->is_entry_computation_parameter()) { - return std::tuple( - lhs->parameter_number(), lhs->param_shape_index()) < - std::tuple( - rhs->parameter_number(), rhs->param_shape_index()); - } - return false; - }; - - std::stable_sort(ordered_allocations->begin(), ordered_allocations->end(), - allocation_comparator); - } - - absl::flat_hash_map> - allocation_to_output_info; - TF_RETURN_IF_ERROR(xla::ShapeUtil::ForEachSubshapeWithStatus( - computation_.root_instruction()->shape(), - [&](const Shape& sub_shape, xla::ShapeIndex index) -> tsl::Status { - TF_ASSIGN_OR_RETURN( - auto slice, - assignment_.GetUniqueSlice(computation_.root_instruction(), index)); - const BufferAllocation* alloc = slice.allocation(); - TF_RET_CHECK(slice.offset() == 0); - TF_RET_CHECK(slice.size() == alloc->size()); - allocation_to_output_info[alloc] = std::make_pair(&sub_shape, index); - return ::tsl::OkStatus(); - })); - - // The function signature will be composed of: - // - one memref for each of the parameters. - // - one memref for each other buffer allocation. - llvm::SmallVector args_attrs; - auto it = ordered_allocations->begin(); - while (it != ordered_allocations->end()) { - const BufferAllocation* alloc = *it; - // There are optional attributes to help the program run through XLA. XLA - // defines ExecutionInput and ExecutionOutput structures to carry - // input-output type and buffer information, therefore any information they - // need (mainly the type structure, potentially containing tuples) to be - // preserved. They are not needed if the generated LMHLO is not sent to XLA. - NamedAttrList arg_attr_list; - mlir::Type arg_type = MemRefType::get({alloc->size()}, i8_type_); - - // Propagate source location information for every HLOInstruction that - // uses this allocation. - std::vector buf_locs; - buf_locs.reserve(alloc->assigned_buffers().size()); - for (const auto& entry : alloc->assigned_buffers()) { - const xla::HloValue* hlo_value = entry.first; - buf_locs.push_back(getLocation(hlo_value->instruction())); - } - mlir::Location loc = builder_.getFusedLoc(buf_locs); - - if (alloc->is_entry_computation_parameter()) { - arg_attr_list.set("lmhlo.params", - builder_.getIndexAttr(alloc->parameter_number())); - if (!alloc->param_shape_index().empty()) { - arg_attr_list.set("lmhlo.param_shape_index", - builder_.getI64TensorAttr(llvm::ArrayRef( - alloc->param_shape_index().begin(), - alloc->param_shape_index().end()))); - } - } - // Optional: an attribute for optimization. If a kernel uses this - // allocation, but the allocation has lmhlo.constant_name, then the kernel - // will instead use the global value indicated by the name for potentially - // more optimizations (e.g. constant propagation). - if (alloc->is_constant()) { - arg_attr_list.set( - "lmhlo.constant_name", - builder_.getStringAttr( - xla::llvm_ir::ConstantBufferAllocationToGlobalName(*alloc))); - } - auto iter = allocation_to_output_info.find(alloc); - if (iter != allocation_to_output_info.end()) { - const Shape* sub_shape = iter->second.first; - const xla::ShapeIndex& shape_index = iter->second.second; - if (!sub_shape->IsArray()) { - it = ordered_allocations->erase(it); - continue; - } - arg_attr_list.set("lmhlo.output_index", - builder_.getI64TensorAttr(llvm::ArrayRef( - shape_index.begin(), shape_index.end()))); - if (auto alias = computation_.parent() - ->input_output_alias_config() - .GetAliasedParameter(shape_index)) { - if (alias->must_alias()) { - arg_attr_list.set("lmhlo.must_alias", builder_.getUnitAttr()); - } - } - } - block->addArgument(arg_type, loc); - allocations_[alloc] = block->getArguments().back(); - args_attrs.push_back(arg_attr_list.getDictionary(builder_.getContext())); - it++; - } - - FunctionType function_type = - builder_.getFunctionType(block->getArgumentTypes(), {}); - func_op.setType(function_type); - func_op.setAllArgAttrs(args_attrs); - - symbol_table_.insert(func_op); - builder_.setInsertionPointToEnd(block); - - auto return_op = - builder_.create(builder_.getUnknownLoc()); - builder_ = OpBuilder(return_op); - - return ::tsl::OkStatus(); -} - -tsl::Status HloToLhloModule( - const BufferAssignment& assignment, const HloModule& hlo_module, - ModuleOp module, std::vector* ordered_allocations, - absl::flat_hash_map* - lhlo_to_hlo_map) { - module.getContext() - ->loadDialect(); - - module->setLoc(mlir::NameLoc::get( - mlir::StringAttr::get(module.getContext(), hlo_module.name()))); - - // Store the HloModule's unique_id in the MLIR module. - Builder builder(module.getContext()); - module->setAttr("mhlo.unique_id", - builder.getI64IntegerAttr(hlo_module.unique_id())); - - const HloComputation* computation = hlo_module.entry_computation(); - - LhloDialectEmitter emitter(assignment, *computation, module); - TF_RETURN_IF_ERROR(emitter.Initialize(ordered_allocations)); - - const xla::HloInstructionSequence* schedule = - &hlo_module.schedule().sequence(computation); - - if (!schedule) { - return tsl::errors::Unimplemented( - "Missing sequential order for the computation"); - } - BaseScopedDiagnosticHandler status_handler(module.getContext()); - - const std::vector& ordering = schedule->instructions(); - TF_RETURN_IF_ERROR(computation->AcceptOrdered(&emitter, ordering)); - TF_RETURN_IF_ERROR(status_handler.ConsumeStatus()); - - (void)mlir::verify(module); - - if (lhlo_to_hlo_map) { - auto map = emitter.ConsumeLhloToHloMap(); - std::swap(*lhlo_to_hlo_map, map); - } - return status_handler.ConsumeStatus(); -} - -OwningOpRef HloTextToLhloTranslateFunction( - llvm::StringRef input, MLIRContext* context) { - tsl::StatusOr> maybe_module = - xla::ParseAndReturnUnverifiedModule( - absl::string_view(input.data(), input.size())); - TF_CHECK_OK(maybe_module.status()); - - OwningOpRef module = - xla::llvm_ir::CreateMlirModuleOp(UnknownLoc::get(context)); - - TF_CHECK_OK( - ConvertHloToLmhlo(std::move(maybe_module).value(), module.get(), "Host")); - - return module; -} - -} // namespace mlir diff --git a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h deleted file mode 100644 index e945ea2d2c5807..00000000000000 --- a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h +++ /dev/null @@ -1,348 +0,0 @@ -/* Copyright 2020 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_TRANSLATE_MHLO_TO_LHLO_WITH_XLA_MHLO_TO_LHLO_WITH_XLA_H_ -#define XLA_TRANSLATE_MHLO_TO_LHLO_WITH_XLA_MHLO_TO_LHLO_WITH_XLA_H_ - -#include -#include -#include -#include - -#include "absl/types/optional.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" -#include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" -#include "xla/service/buffer_assignment.h" -#include "xla/shape_util.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" - -namespace mlir { - -// This class will process an HloModule with the supplied BufferAssignment and -// populate the MLIR ModuleOp with the computation converted in the LHLO -// dialect. -class LhloDialectEmitter : public xla::ConstDfsHloVisitorWithDefault { - public: - // Initializes internal data structures. It must be called before calling any - // of the visitors. - tsl::Status Initialize( - std::vector* ordered_allocations); - - LhloDialectEmitter(const xla::BufferAssignment& assignment, - const xla::HloComputation& computation, ModuleOp module) - : assignment_(assignment), - computation_(computation), - module_(module), - symbol_table_(module), - builder_(module.getContext()), - i8_type_(builder_.getIntegerType(8)) {} - - tsl::StatusOr EmitOp(const xla::HloInstruction* instr); - - static tsl::StatusOr - GetScatterDimensionNumbers(const xla::HloInstruction* instr, - mlir::MLIRContext* context); - - absl::flat_hash_map - ConsumeLhloToHloMap() { - return std::move(lhlo_to_hlo_); - } - - private: - tsl::StatusOr EmitSortOp(const xla::HloInstruction* instr); - tsl::StatusOr EmitFusionOp(const xla::HloInstruction* instr); - tsl::StatusOr EmitScatterOp( - const xla::HloInstruction* instr); - tsl::StatusOr EmitSelectAndScatterOp( - const xla::HloInstruction* instr); - - tsl::StatusOr EmitCustomCallOp(const xla::HloInstruction* instr); - tsl::StatusOr EmitCholesky( - const xla::HloCustomCallInstruction* custom_call); - tsl::StatusOr EmitGemm( - const xla::HloCustomCallInstruction* custom_call); - tsl::StatusOr EmitCublasLtMatmul( - const xla::HloCustomCallInstruction* custom_call); - tsl::StatusOr EmitCublasLtMatmulF8( - const xla::HloCustomCallInstruction* custom_call); - tsl::StatusOr EmitDnnConvolution( - const xla::HloCustomCallInstruction* custom_call); - tsl::StatusOr EmitDnnConvolutionReorderVectorized( - const xla::HloCustomCallInstruction* custom_call); - tsl::StatusOr EmitDnnBatchNorm( - const xla::HloCustomCallInstruction* custom_call); - xla::StatusOr EmitDnnfMHA( - const xla::HloCustomCallInstruction* custom_call); - xla::StatusOr EmitDnnfMHABackward( - const xla::HloCustomCallInstruction* custom_call); - tsl::StatusOr EmitDnnNorm( - const xla::HloCustomCallInstruction* custom_call); - xla::StatusOr EmitCubDeviceRadixSort( - const xla::HloCustomCallInstruction* custom_call); - tsl::StatusOr EmitConstant( - const xla::HloInstruction* instr); - - tsl::StatusOr EmitInfeedOp(const xla::HloInstruction* instr); - tsl::StatusOr EmitOutfeedOp( - const xla::HloInstruction* instr); - - template - tsl::StatusOr EmitDoneOp(const xla::HloInstruction* instr); - - tsl::StatusOr EmitAllToAllStartOp( - const xla::HloInstruction* instr); - tsl::StatusOr EmitAllToAllDoneOp( - const xla::HloInstruction* instr); - tsl::StatusOr EmitAllGatherStartOp( - const xla::HloInstruction* instr); - tsl::StatusOr EmitAllGatherDoneOp( - const xla::HloInstruction* instr); - tsl::StatusOr EmitAllReduceStartOp( - const xla::HloInstruction* instr); - tsl::StatusOr EmitAllReduceDoneOp( - const xla::HloInstruction* instr); - tsl::StatusOr EmitAsyncStartOp( - const xla::HloInstruction* instr); - tsl::StatusOr EmitAsyncDoneOp( - const xla::HloInstruction* instr); - tsl::StatusOr EmitReduceScatterStartOp( - const xla::HloInstruction* instr); - tsl::StatusOr EmitReduceScatterDoneOp( - const xla::HloInstruction* instr); - tsl::StatusOr - EmitCollectivePermuteStartOp(const xla::HloInstruction* instr); - tsl::StatusOr EmitCollectivePermuteDoneOp( - const xla::HloInstruction* instr); - - tsl::StatusOr EmitRngGetAndUpdateStateOp( - const xla::HloInstruction* instr); - tsl::StatusOr EmitFftOp(const xla::HloInstruction* instr); - tsl::StatusOr EmitTriangularSolveOp( - const xla::HloInstruction* instr); - tsl::StatusOr EmitBitcast(const xla::HloInstruction* instr); - - tsl::StatusOr EmitCaseOp(const xla::HloInstruction* instr); - - tsl::StatusOr EmitWhileOp(const xla::HloInstruction* instr); - - tsl::StatusOr EmitSendOp(const xla::HloInstruction* instr); - tsl::StatusOr EmitSendDoneOp( - const xla::HloInstruction* instr); - - tsl::StatusOr EmitRecvOp(const xla::HloInstruction* instr); - tsl::StatusOr EmitRecvDoneOp( - const xla::HloInstruction* instr); - - tsl::StatusOr EmitCommandBufferOp( - const xla::HloInstruction* instr); - - tsl::Status ImportAsLmhloRegion(xla::HloComputation* computation, - mlir::Region* region); - - // Since LMHLO dialect does not define token types, this enum controls how - // token operand/results from XLA:HLO are lowered to MLIR. - enum class TokenLoweringMode { - kFailToLower, // Fail lowering if token inputs are encountered. - kUseNull, // Use a null Value in the operand list for each token. - // kSkip, // Skip any token inputs or outputs (not yet needed) - }; - - // Create LHLO operation operands given an XLA HLO instruction. By default, - // all XLA HLO operands and results are converted to MLIR and appended to - // `operands`. If `num_operands` is specified, only the first `num_operand` - // operands of the instruction are converted to MLIR. The function returns the - // actual number of operands and results generated for MLIR in `num_arguments` - // and `num_results`. - tsl::Status CreateOperands(const xla::HloInstruction* instr, - std::optional num_operands, - TokenLoweringMode token_mode, - SmallVectorImpl& operands, - size_t& num_arguments, size_t& num_results); - - template - tsl::StatusOr CreateOpWithoutAttrs( - const xla::HloInstruction* instr, - std::optional num_operands = std::nullopt) { - size_t unused; - return CreateOpWithoutAttrs(instr, unused, unused, num_operands); - } - - template - tsl::StatusOr CreateOpWithoutAttrs( - const xla::HloInstruction* instr, size_t& num_arguments, - size_t& num_results, std::optional num_operands = std::nullopt); - - template - OpType CreateOpWithoutAttrs(const xla::HloInstruction* instr, - ValueRange operands); - - template - DenseIntElementsAttr GetI64DenseElementsAttr(const T& container) { - return builder_.getI64TensorAttr( - {container.data(), static_cast(container.size())}); - } - - DenseIntElementsAttr GetWindowElements( - const xla::Window& window, - std::function getter) { - llvm::SmallVector elements; - elements.reserve(window.dimensions_size()); - for (const xla::WindowDimension& dim : window.dimensions()) { - elements.push_back(getter(dim)); - } - return GetI64DenseElementsAttr(elements); - } - - static mlir::DenseIntElementsAttr GetLayoutAttribute( - const xla::Layout& layout, Builder* builder); - - tsl::Status DefaultAction(const xla::HloInstruction* instr) final; - - // Computation parameters don't need any specific handling when they are - // visited, they are already processed when we enter a new computation. - tsl::Status HandleParameter(const xla::HloInstruction* instr) final { - return ::tsl::OkStatus(); - } - - // Helper function that recursively visits the tuple structure in - // `current_shape`, and reconstruct a matching lmhlo::TupleOp. - // Each leaf node is converted to an std.view op with corresponding offsets. - // If no tuple presents, it simply returns a view of the buffer. - tsl::Status GetOrCreateViewImpl(const xla::HloInstruction* instr, - const xla::Shape& current_shape, - xla::ShapeIndex* current_shape_index, - SmallVectorImpl* values, - TokenLoweringMode token_mode); - - // Helper function to create view/tuple of views to a buffer for a given - // instruction result. `result_subset` can be used to for instructions that - // have a tuple result and MLIR conversion needs to convert only one of the - // tuple elements. Note that if needed, this can be extended to take a list of - // ShapeIndex values in case we need finer control on what elements of the - // output tuple to be converted to MLIR. - tsl::Status GetOrCreateView( - const xla::HloInstruction* instr, SmallVectorImpl* values, - const xla::ShapeIndex& result_subset = {}, - TokenLoweringMode token_mode = TokenLoweringMode::kFailToLower); - - tsl::StatusOr GetOrCreateArrayView( - const xla::HloInstruction* instr, const xla::Shape& current_shape, - const xla::ShapeIndex& current_shape_index); - - tsl::StatusOr RewriteFusionOperand(const xla::HloInstruction* root, - const xla::Shape& shape, - xla::ShapeIndex* shape_index, - OpBuilder* b, Location loc); - - // Return an MLIR location for an HLO instruction. - Location getLocation(const xla::HloInstruction* inst) { - return NameLoc::get(builder_.getStringAttr(inst->name())); - } - - // This map provides access to MLIR buffers for each HLO buffer allocation. - // The MLIR buffers are all `memref<{size}xi8>` and correspond to function - // parameters. It is populated at the beginning of the processing with all - // the buffer allocations and is unchanged afterward. Every HLOInstruction - // is using a "slice" of the buffer allocation and providing shape, layout, - // and Dtype. An MLIR view is used separately to model slices into the - // allocations (see below). - llvm::DenseMap allocations_; - - // This map provides access to MLIR buffers constructed from memref arguments - // (allocations) using memref.view operation at the given offset (defined by - // slice) and result type (defined by shape). By using this cache we guarantee - // that we have a unique memref.view operation corresponding to each - // allocation slice. - absl::flat_hash_map, - Value> - allocation_slices_; - - // This map provides access to MLIR buffers for each HLO instruction, keyed - // instruction identity. A slice is contained in a BufferAllocation, and has - // an offset and a size. - // - // As for why we don't use HloInstruction*, see GetOrCreateView(), but - // mostly we want to leverage better of the aliased buffers. - // - // If the HloInstruction is a tuple, all leaf nodes are stored flattened. - // Otherwise, there will be a single buffer. - // - // An MLIR buffer is either an input parameter, or a ViewOp in the case - // where the slice is only part of its allocation. - // - // `instr_slices_` is populated lazily in the `GetOrCreateView()` helper as we - // process every instruction. - absl::flat_hash_map, - Value> - instr_slices_; - - // The BufferAssignment computed by XLA ahead of time. - const xla::BufferAssignment& assignment_; - - // The HLO module that will be converted. - const xla::HloComputation& computation_; - - // This is the MLIR module in which a function will be created for every HLO - // computation. - ModuleOp module_; - - // SymbolTable associated with the module. New functions should be added using - // this to avoid name conflicts. - mlir::SymbolTable symbol_table_; - - // The builder keeps track of the current insertion point in the MLIR - // module. - OpBuilder builder_; - // Convenient "cached" access to this widely used MLIR type (i8). - Type i8_type_; - - // Map ops returning tokens to their output (async collectives start ops, and - // point-to-point communication ops), to connect the correct done op. - absl::flat_hash_map ret_tokens_; - - // Maps each LHLO op created directly by this emitter to the corresponding HLO - // instruction. - // Note: this does not contain ops that are inside the bodies of fusions. - absl::flat_hash_map - lhlo_to_hlo_; -}; - -// Populate the MLIR `module` with the computation from the `hlo_module` using -// the provided buffer `assignment`. The returned `Status` indicates success -// or failure in the conversion. -// `lhlo_to_hlo_map`, if non-null, is populated with a mapping from generated -// top-level MLIR operations to the original HLO instructions. "top-level" means -// that ops inside the bodies of fusions are not included (but all fusions are). -// Store buffer allocations from buffer assignment in the order of inputs to the -// LMHLO entry function. -tsl::Status HloToLhloModule( - const xla::BufferAssignment& assignment, const xla::HloModule& hlo_module, - ModuleOp module, - std::vector* ordered_allocation, - absl::flat_hash_map* - lhlo_to_hlo_map = nullptr); - -OwningOpRef HloTextToLhloTranslateFunction( - llvm::StringRef input, MLIRContext* context); - -} // namespace mlir - -#endif // XLA_TRANSLATE_MHLO_TO_LHLO_WITH_XLA_MHLO_TO_LHLO_WITH_XLA_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/tests/BUILD b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/tests/BUILD deleted file mode 100644 index 2196d02175d1a3..00000000000000 --- a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/tests/BUILD +++ /dev/null @@ -1,29 +0,0 @@ -load("//xla:lit.bzl", "enforce_glob", "lit_test_suite") - -package( - default_visibility = ["//visibility:public"], - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - licenses = ["notice"], -) - -lit_test_suite( - name = "all_tests", - srcs = enforce_glob( - [ - "hlo_text_to_lhlo_no_opt.hlotxt", - "non_identity_layouts.hlotxt", - "no_opt_ops.hlotxt", - ], - include = [ - "*.hlotxt", - ], - ), - cfg = "//xla:lit.cfg.py", - tools = [ - "//xla/translate:xla-translate", - "//xla/translate/mhlo_to_lhlo_with_xla:xla-translate-gpu-opt", - "//xla/translate/mhlo_to_lhlo_with_xla:xla-translate-opt", - "@llvm-project//llvm:FileCheck", - "@llvm-project//llvm:not", - ], -) diff --git a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/tests/hlo_text_to_lhlo_no_opt.hlotxt b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/tests/hlo_text_to_lhlo_no_opt.hlotxt deleted file mode 100644 index ed6c3f71e95d35..00000000000000 --- a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/tests/hlo_text_to_lhlo_no_opt.hlotxt +++ /dev/null @@ -1,736 +0,0 @@ -// RUN: xla-translate -split-input-file -hlo-text-to-lhlo %s | FileCheck %s - -HloModule TestModule - -// CHECK-LABEL: func @TestComputation - -FusedComputation { - // CHECK: to_tensor {{.*}} {xla_shape = "f32[3,2]{0,1}"} - x = f32[3, 2]{0,1} parameter(0) - ROOT y = f32[3, 2]{0,1} add(x, x) -} - -ENTRY TestComputation { - x = f32[3, 2]{0,1} parameter(0) - ROOT y = f32[3, 2]{0,1} fusion(x), kind=kLoop, calls=FusedComputation -} - -// ----- - -HloModule ScatterModule - -update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { - lhs = s32[] parameter(0) - ROOT rhs = s32[] parameter(1) -} - -// CHECK-LABEL: func @main -// CHECK: "lmhlo.scatter" -// CHECK: indices_are_sorted = false -// CHECK: update_window_dims = [1] -// CHECK: inserted_window_dims = [0] -// CHECK: scatter_dims_to_operand_dims = [0] -// CHECK: index_vector_dim = 1 -// CHECK: unique_indices = false -// CHECK: ^bb0(%[[ARG5:.*]]: tensor, %[[ARG6:.*]]: tensor): -// CHECK: mhlo.return %[[ARG6]] -// CHECK: (memref<3x3xi32>, memref<2xi32>, memref<2x3xi32>, memref<3x3xi32>) -> () -ENTRY main { - operand = s32[3,3] parameter(0) - indices = s32[2] parameter(1) - updates = s32[2,3] parameter(2) - ROOT scatter_op = s32[3,3] scatter(operand, indices, updates), - to_apply=update_s32, - update_window_dims={1}, - inserted_window_dims={0}, - scatter_dims_to_operand_dims={0}, - index_vector_dim=1 -} - -// ----- - -HloModule SelectAndScatter - -%ge_F32 (lhs.5: f32[], rhs.6: f32[]) -> pred[] { - %lhs.5 = f32[] parameter(0) - %rhs.6 = f32[] parameter(1) - ROOT %compare.7 = pred[] compare(f32[] %lhs.5, f32[] %rhs.6), direction=GE -} - -%add_F32 (lhs.9: f32[], rhs.10: f32[]) -> f32[] { - %lhs.9 = f32[] parameter(0) - %rhs.10 = f32[] parameter(1) - ROOT %add.11 = f32[] add(f32[] %lhs.9, f32[] %rhs.10) -} - -// CHECK-LABEL: module -// CHECK: memref.global "private" constant @[[$GLOBAL:.*]] : memref = dense<0.000000e+00> -// CHECK-LABEL: func @main -// CHECK: %[[GLOBAL_MEMREF:.*]] = memref.get_global @[[$GLOBAL]] : memref -// CHECK: "lmhlo.select_and_scatter"(%{{.*}}, %{{.*}}, %[[GLOBAL_MEMREF]], %{{.*}}) -// CHECK: padding = dense<0> : tensor<1xi64> -// CHECK: window_dimensions = dense<3> : tensor<1xi64> -// CHECK: window_strides = dense<3> : tensor<1xi64> -// CHECK: ^bb0(%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor): -// CHECK: %[[COMPARE:.*]] = mhlo.compare GE, %[[ARG0]], %[[ARG1]] -// CHECK: mhlo.return %[[COMPARE]] : tensor -// CHECK: ^bb0(%[[ARG2:.*]]: tensor, %[[ARG3:.*]]: tensor): -// CHECK: %[[ADD:.*]] = mhlo.add %[[ARG2]], %[[ARG3]] -// CHECK: mhlo.return %[[ADD]] : tensor -// CHECK: (memref<6xf32>, memref<2xf32>, memref, memref<6xf32>) -> () -ENTRY main () -> f32[6] { - %operand = f32[6]{0} parameter(0) - %source = f32[2]{0} parameter(1) - %init = f32[] constant(0) - ROOT %select-and-scatter.12 = f32[6]{0} select-and-scatter(f32[6]{0} %operand, f32[2]{0} %source, f32[] %init), window={size=3 stride=3}, select=%ge_F32, scatter=%add_F32 -} - -// ----- - -HloModule SliceToDynamic - -// CHECK-LABEL: func @main -// CHECK: "lmhlo.custom_call" -// CHECK: backend_config = "", call_target_name = "SliceToDynamic" -// CHECK-SAME: operandSegmentSizes = array -// CHECK-NOT: target_arg_mapping -// CHECK: (memref<2x2x2xi32>, memref, memref, memref, memref<2x2x2xi32>) -> () -ENTRY main { - %param = s32[2,2,2] parameter(0) - %static = s32[] parameter(1) - %dynamic = s32[] parameter(2) - ROOT %custom-call = s32[2,<=2, 2] custom-call(s32[2,2,2] %param, - s32[] %static, - s32[] %dynamic, - s32[] %static), - custom_call_target="SliceToDynamic", - backend_config="" -} - -// ----- - -HloModule Cholesky - -// CHECK-LABEL: func @main -// CHECK: "lmhlo_gpu.cholesky" -// CHECK-SAME: is_lower = true -ENTRY main { - %param = f32[3,3] parameter(0) - ROOT %custom-call = (f32[3,3], f32[3], s32[]) custom-call(f32[3,3] %param), - custom_call_target="__cusolver$cholesky", - operand_layout_constraints={f32[3,3]}, - backend_config="{\"lower\":true}" -} - -// ----- - -HloModule Gemm - -// CHECK-LABEL: func @main -// CHECK: "lmhlo_gpu.gemm" -// CHECK-SAME: algorithm = 7 : i64 -// CHECK-SAME: alpha_imag = 0.000000e+00 : f64 -// CHECK-SAME: alpha_real = 1.000000e+00 : f64 -// CHECK-SAME: beta = 0.000000e+00 : f64 -// CHECK-NOT: lhs_batching_dimensions -// CHECK-NOT: rhs_batching_dimensions -// CHECK-SAME: lhs_contracting_dimensions = [1] -// CHECK-SAME: rhs_contracting_dimensions = [0] -// CHECK-SAME: precision_config = [#mhlo, #mhlo] -// CHECK: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () -ENTRY main { - %A = f32[2,2]{1,0} parameter(0) - %B = f32[2,2]{1,0} parameter(1) - ROOT %sgemm = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %A, f32[2,2]{1,0} %B), - custom_call_target="__cublas$gemm", - backend_config="{ \"gemm_backend_config\": {\"alpha_real\":1,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"precision_config\":{\"operand_precision\":[\"HIGH\",\"HIGHEST\"]},\"selected_algorithm\":\"7\"}}" -} - -// ----- - -HloModule CublasLtMatmul - -// CHECK-LABEL: func @main -// CHECK: "lmhlo_gpu.cublas.lt.matmul" -// CHECK-SAME: alpha_imag = 0.000000e+00 : f64 -// CHECK-SAME: alpha_real = 1.000000e+00 : f64 -// CHECK-SAME: beta = 0.000000e+00 : f64 -// CHECK-NOT: lhs_batching_dimensions -// CHECK-NOT: rhs_batching_dimensions -// CHECK-SAME: lhs_contracting_dimensions = [1] -// CHECK-SAME: rhs_contracting_dimensions = [0] -// CHECK-SAME: precision_config = [#mhlo, #mhlo] -// CHECK: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () - -ENTRY main { - %A = f32[2,2]{1,0} parameter(0) - %B = f32[2,2]{1,0} parameter(1) - ROOT %custom-call = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %A, f32[2,2]{1,0} %B), custom_call_target="__cublas$lt$matmul", - backend_config="{ \"gemm_backend_config\": {\"alpha_real\":1,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}}}" -} - -// ----- - -HloModule CublasLtMatmulF8 - -// CHECK-LABEL: func @main -// CHECK: "lmhlo_gpu.cublas.lt.matmul.f8" -// CHECK-SAME: alpha_imag = 0.000000e+00 : f64 -// CHECK-SAME: alpha_real = 1.000000e+00 : f64 -// CHECK-SAME: beta = 1.000000e+00 : f64 -// CHECK-NOT: lhs_batching_dimensions -// CHECK-NOT: rhs_batching_dimensions -// CHECK-SAME: lhs_contracting_dimensions = [1] -// CHECK-SAME: rhs_contracting_dimensions = [0] -// CHECK-SAME: precision_config = [#mhlo, #mhlo] -// CHECK: (memref<16x16xf8E4M3FN>, memref<16x16xf8E4M3FN>, memref<16x16xf16>, memref, memref, memref, memref, memref<16x16xf8E4M3FN>, memref) -> () - -ENTRY main { - %A = f8e4m3fn[16,16]{1,0} parameter(0) - %B = f8e4m3fn[16,16]{1,0} parameter(1) - %C = f16[16,16]{1,0} parameter(2) - %A_SCALE = f32[] parameter(3) - %B_SCALE = f32[] parameter(4) - %C_SCALE = f32[] parameter(5) - %D_SCALE = f32[] parameter(6) - ROOT %custom-call = (f8e4m3fn[16,16]{1,0}, f32[]) custom-call(f8e4m3fn[16,16]{1,0} %A, f8e4m3fn[16,16]{1,0} %B, f16[16,16]{1,0} %C, f32[] %A_SCALE, f32[] %B_SCALE, f32[] %C_SCALE, f32[] %D_SCALE), custom_call_target="__cublas$lt$matmul$f8", - backend_config="{ \"gemm_backend_config\": {\"alpha_real\":1,\"alpha_imag\":0,\"beta\":1.0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}}}" -} - - -// ----- - -HloModule AsyncAllReduce - -// Test all-reduce-async -add { - lhs = f32[] parameter(0) - rhs = f32[] parameter(1) - ROOT add = f32[] add(lhs, rhs) -} - -// CHECK-LABEL: func @test_async_all_reduce -// CHECK-SAME: [[BUFFER:%.*]]: memref<32xi8> -%test_async_all_reduce { - param0 = f32[8] parameter(0) - // CHECK: [[VIEW:%.*]] = memref.view [[BUFFER]]{{.*}} : memref<32xi8> to memref<8xf32> - // CHECK: [[TOKEN:%.*]] = "lmhlo_gpu.all_reduce_start"([[VIEW]], [[VIEW]]) - // CHECK-SAME: channel_id = #mhlo.channel_handle - // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64> - // CHECK: ^bb0([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor): - // CHECK: [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]] - // CHECK: mhlo.return [[ADD]] : tensor - // CHECK: }) - // CHECK: "lmhlo_gpu.all_reduce_done"([[TOKEN]]) - start = f32[8] all-reduce-start(param0), - channel_id=1, replica_groups={{0,1,2,3}, {4,5,6,7}}, to_apply=add - ROOT done = f32[8] all-reduce-done(start) -} - -// ----- - -HloModule AsyncAllReduceTwoOperands - -// Test all-reduce-async -add { - lhs = f32[] parameter(0) - rhs = f32[] parameter(1) - ROOT add = f32[] add(lhs, rhs) -} - -// CHECK-LABEL: func @test_async_all_reduce_two_operands -// CHECK-SAME: [[BUFFER0:%.*]]: memref<32xi8> -// CHECK-SAME: [[BUFFER1:%.*]]: memref<36xi8> -%test_async_all_reduce_two_operands { - param0 = f32[8] parameter(0) - param1 = f32[9] parameter(1) - // CHECK: [[VIEW0:%.*]] = memref.view [[BUFFER0]]{{.*}} : memref<32xi8> to memref<8xf32> - // CHECK: [[VIEW1:%.*]] = memref.view [[BUFFER1]]{{.*}} : memref<36xi8> to memref<9xf32> - // CHECK: [[TOKEN:%.*]] = "lmhlo_gpu.all_reduce_start"([[VIEW0]], [[VIEW1]], [[VIEW0]], [[VIEW1]]) - // CHECK-SAME: channel_id = #mhlo.channel_handle - // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64> - // CHECK: ^bb0([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor): - // CHECK: [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]] - // CHECK: mhlo.return [[ADD]] : tensor - // CHECK: }) - // CHECK: "lmhlo_gpu.all_reduce_done"([[TOKEN]]) - start = (f32[8], f32[9]) all-reduce-start(param0, param1), - channel_id=1, replica_groups={{0,1,2,3}, {4,5,6,7}}, to_apply=add - ROOT done = (f32[8], f32[9]) all-reduce-done(start) -} - -// ----- - -HloModule ConvForward - -// CHECK-LABEL: func @main -// CHECK: lmhlo_gpu.conv_forward -// CHECK-SAME: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1] -// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [0, 0], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [1, 1]} -// CHECK-SAME: algorithm = 2 -// CHECK-SAME: tensor_ops_enabled = false -// CHECK-SAME: operand_0_layout = [3, 2, 1, 0] -// CKECK-SAME: operand_1_layout = [3, 2, 1, 0] -// CHECK-SAME: result_layout = [3, 2, 1, 0] -// CHECK-SAME: batch_group_count = 1 : i64 -// CHECK-SAME: feature_group_count = 1 : i64 -// CHECK-SAME: result_scale = 1.000000e+00 : f64 -// CHECK: (memref<4x256x3x3xf32>, memref<256x256x2x2xf32>, memref<4x256x2x2xf32>, memref<65536xui8>) -ENTRY main { - %input = f32[4,256,3,3]{3,2,1,0} parameter(0) - %filter = f32[256,256,2,2]{3,2,1,0} parameter(1) - ROOT %custom-call.1 = (f32[4,256,2,2]{3,2, 1,0}, u8[65536]{0}) custom-call(f32[4,256,3,3]{3,2,1,0} %input, f32[256,256,2,2]{3,2,1,0} %filter), - window={size=2x2 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, - custom_call_target="__cudnn$convForward", - backend_config="{ \"cudnn_conv_backend_config\": {\"algorithm\": {\"algo_id\":\"2\",\"math_type\":\"DEFAULT_MATH\"},\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}}" -} - -// ----- - -// CHECK: func @main -// CHECK: lmhlo_gpu.conv_forward_fused -// CHECK-SAME: dim_numbers = [b, f, 0, 1]x[0, 1, i, o]->[b, f, 0, 1] -// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [1, 1], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [0, 0]} -// CHECK-SAME: activation_mode = #lmhlo_gpu -// CHECK-SAME: algorithm = 0 -// CHECK-SAME: tensor_ops_enabled = false -// CHECK-SAME: operand_0_layout = [1, 3, 2, 0] -// CHECK-SAME: operand_1_layout = [2, 1, 0, 3] -// CHECK-SAME: result_layout = [1, 3, 2, 0] -// CHECK-SAME: batch_group_count = 1 : i64 -// CHECK-SAME: feature_group_count = 1 : i64 -// CHECK-SAME: precision_config = [#mhlo, #mhlo, #mhlo] -// CHECK-SAME: result_scale = 1.000000e+00 : f64 -// CHECK-SAME: (memref<1x17x9x9xf16, #map{{.*}}>, memref<3x3x17x32xf16, #map{{.*}}>, memref<32xf16>, memref<1x32x9x9xf16, #{{.*}}>, memref<0xui8>) -> () - -HloModule FusedConvForward - -ENTRY main { - %input = f16[1,17,9,9]{1,3,2,0} parameter(0) - %filter = f16[3,3,17,32]{2,1,0,3} parameter(1) - %bias = f16[32]{0} parameter(2) - ROOT %custom-call.2 = (f16[1,32,9,9]{1,3,2,0}, u8[0]{0}) custom-call(f16[1,17,9,9]{1,3,2,0} %input, f16[3,3,17,32]{2,1,0,3} %filter, f16[32]{0} %bias), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convBiasActivationForward", backend_config="{ \"cudnn_conv_backend_config\": {\"algorithm\": {\"algo_id\":\"0\",\"math_type\":\"DEFAULT_MATH\"},\"conv_result_scale\":1,\"activation_mode\":\"2\",\"side_input_scale\":0}}" -} - -// ----- - -// CHECK: func @main -// CHECK: lmhlo_gpu.conv_forward_fused_with_side_input -// CHECK-SAME: dim_numbers = [b, f, 0, 1]x[0, 1, i, o]->[b, f, 0, 1] -// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [1, 1], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [0, 0]} -// CHECK-SAME: activation_mode = #lmhlo_gpu -// CHECK-SAME: algorithm = 0 -// CHECK-SAME: tensor_ops_enabled = false -// CHECK-SAME: operand_0_layout = [1, 3, 2, 0] -// CHECK-SAME: operand_1_layout = [2, 1, 0, 3] -// CHECK-SAME: result_layout = [1, 3, 2, 0] -// CHECK-SAME: batch_group_count = 1 : i64 -// CHECK-SAME: feature_group_count = 1 : i64 -// CHECK-SAME: precision_config = [#mhlo, #mhlo, #mhlo, #mhlo] -// CHECK-SAME: result_scale = 1.000000e+00 : f64 -// CHECK-SAME: side_input_scale = 1.000000e+00 -// CHECK-SAME: (memref<1x17x9x9xf16, #map{{.*}}>, memref<3x3x17x32xf16, #map{{.*}}>, memref<32xf16>, memref<1x32x9x9xf16, #{{.*}}>, memref<0xui8>) -> () - -HloModule FusedConvForwardSideInput - -ENTRY main { - %input = f16[1,17,9,9]{1,3,2,0} parameter(0) - %filter = f16[3,3,17,32]{2,1,0,3} parameter(1) - %bias = f16[32]{0} parameter(2) - %side = f16[32]{0} parameter(3) - ROOT %custom-call.2 = (f16[1,32,9,9]{1,3,2,0}, u8[0]{0}) custom-call(f16[1,17,9,9]{1,3,2,0} %input, f16[3,3,17,32]{2,1,0,3} %filter, f16[32]{0} %bias, f16[32]{0} %side), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convBiasActivationForward", backend_config="{ \"cudnn_conv_backend_config\": {\"algorithm\":{\"algo_id\":\"0\",\"math_type\":\"DEFAULT_MATH\"},\"conv_result_scale\":1,\"activation_mode\":\"2\",\"side_input_scale\":1}}" -} - -// ----- - -HloModule Infeed - -// CHECK: func @main -// CHECK: "lmhlo.infeed" -// CHECK-SAME: (memref<3xf32>) -> () -ENTRY main { - %tok = token[] parameter(0) - ROOT %infeed = (f32[3]{0}, token[]) infeed(token[] %tok) -} - -// ----- - -HloModule Outfeed - -// CHECK: func @main -// CHECK: "lmhlo.outfeed" -// CHECK-SAME: config = "" -// CHECK-SAME: (memref<3xf32>) -> () -ENTRY main { - %source = f32[3] parameter(0) - %tok = token[] parameter(1) - ROOT %o = token[] outfeed(f32[3] %source, token[] %tok) -} - -// ----- - -HloModule Outfeed - -// CHECK: func @main -// CHECK: "lmhlo.custom_call" -// CHECK: call_target_name = "foo" -// CHECK: "lmhlo.outfeed" -// CHECK-SAME: config = "" -// CHECK-SAME: (memref<3xf32>, memref<5xf16>) -> () -ENTRY main { - %tok = token[] parameter(0) - %tuple = (f32[3], f16[5]) custom-call(),custom_call_target="foo" - ROOT %o = token[] outfeed((f32[3], f16[5]) %tuple, token[] %tok) -} - -// ----- - -HloModule TestModule - -// CHECK: func @main -// CHECK: "lmhlo.rng_get_and_update_state"(%{{.*}}) <{delta = 131072 : i64}> : (memref<2xui64>) -> () -ENTRY main { - ROOT %rng-get-and-update-state = u64[2]{0} rng-get-and-update-state(), delta=131072 -} - -// ----- - -HloModule TestReplicaId - -// CHECK: func @main -// CHECK: "lmhlo.replica_id" -// CHECK-SAME: (memref) -> () -ENTRY main { - ROOT %replica_id = u32[] replica-id() -} - -// ----- - -HloModule fft - -// CHECK: func @main -// CHECK: "lmhlo.fft" -// CHECK-SAME: fft_length = dense<[8, 32]> : tensor<2xi64> -// CHECK-SAME: fft_type = #mhlo -ENTRY main { - %input = c64[5,8,32] parameter(0) - ROOT %fft = c64[5,8,32] fft(c64[5,8,32] %input), fft_type=IFFT, fft_length={8,32} -} - -// ----- - -HloModule TriangularSolve_module - -// CHECK: func @main -// CHECK: "lmhlo.triangular_solve" -// CHECK-SAME: layout_a = dense<[1, 0]> : tensor<2xindex> -// CHECK-SAME: layout_b = dense<[1, 0]> : tensor<2xindex> -// CHECK-SAME: layout_output = dense<[1, 0]> : tensor<2xindex> -// CHECK-SAME: left_side = false -// CHECK-SAME: lower = true -// CHECK-SAME: transpose_a = #mhlo -// CHECK-SAME: unit_diagonal = false -ENTRY main { - %a = f32[4,4]{1,0} parameter(0) - %b = f32[3,4]{1,0} parameter(1) - ROOT %triangular-solve = f32[3,4]{1,0} triangular-solve(f32[4,4]{1,0} %a, f32[3,4]{1,0} %b), lower=true, transpose_a=NO_TRANSPOSE -} - -// ----- - -HloModule CustomCallWithTypedFFIBackendConfig - -// CHECK: func @main -// CHECK: "lmhlo.custom_call" -// CHECK: api_version = 4 : i32 -// CHECK-SAME: backend_config = { -// CHECK-SAME: user_attr0 = 123 : i32 -// CHECK-SAME: user_attr1 = dense<42> : tensor -// CHECK-SAME: } -// CHECK-SAME: num_args = 1 -// CHECK-SAME: num_results = 2 -// CHECK-SAME: args_to_target_args = [] -// CHECK-SAME: results_to_target_results = [0] -ENTRY main { - %tok = token[] parameter(0) - ROOT %call = (f32[3], token[]) custom-call (%tok), custom_call_target="foo", - api_version=API_VERSION_TYPED_FFI, - backend_config="{user_attr0 = 123 : i32, user_attr1 = dense<42> : tensor}" -} - -// ----- - -HloModule CustomCallWithTokens - -// CHECK: func @main -// CHECK: "lmhlo.custom_call" -// CHECK: num_args = 1 -// CHECK-SAME: num_results = 2 -// CHECK-SAME: args_to_target_args = [] -// CHECK-SAME: results_to_target_results = [0] -ENTRY main { - %tok = token[] parameter(0) - ROOT %call = (f32[3], token[]) custom-call (%tok), custom_call_target="foo", - backend_config="" -} - -// ----- - -HloModule CustomCallWithTokens - -// CHECK: func @main -// CHECK: "lmhlo.custom_call" -// CHECK: num_args = 3 -// CHECK-SAME: num_results = 3 -// CHECK-SAME: args_to_target_args = [1] -// CHECK-SAME: results_to_target_results = [0, 2] -ENTRY main { - %tok = token[] parameter(0) - %input = f32[5,8,32] parameter(1) - ROOT %call = (f32[3]{0}, token[], f32[3]) custom-call (%tok, %input, %tok), - custom_call_target="foo", - backend_config="" -} - -// ----- - -HloModule CustomCallWithTokens - -// CHECK: func @main -// CHECK: "lmhlo.custom_call" -// CHECK: num_args = 3 -// CHECK-SAME: num_results = 1 -// CHECK-SAME: args_to_target_args = [1] -// CHECK-SAME: results_to_target_results = [0] -ENTRY main { - %tok = token[] parameter(0) - %input = f32[5,8,32] parameter(1) - ROOT %call = f32[3] custom-call (%tok, %input, %tok), - custom_call_target="foo", - backend_config="" -} - -// ----- - -HloModule CustomCallWithTokens - -// CHECK: func @main -// CHECK: "lmhlo.custom_call" -// CHECK: num_args = 1 -// CHECK-SAME: num_results = 4 -// CHECK-SAME: args_to_target_args = [0] -// CHECK-SAME: results_to_target_results = [1] -ENTRY main { - %input = f32[5,8,32] parameter(0) - ROOT %call = (token[], f32[3]{0}, token[], token[]) custom-call (%input), - custom_call_target="foo", - backend_config="" -} - -// ----- -// CHECK: func @main -// CHECK: "lmhlo.while"(%{{.*}}) ({ -HloModule WhileConstantCondition - -%body { - ROOT %parameter.5 = (f32[5]{0}) parameter(0) -} - -%cond { - %parameter.12 = (f32[5]{0}) parameter(0) - ROOT %constant_1 = pred[] constant(false) -} - -ENTRY %main (parameter.1: f32[5]) -> (f32[5]) { - %parameter.1 = f32[5]{0} parameter(0) - %tuple = (f32[5]{0}) tuple(f32[5]{0} %parameter.1) - ROOT %while.19 = (f32[5]{0}) while((f32[5]{0}) %tuple), condition=%cond, body=%body -} - -// ----- - -HloModule CustomCallNoComputation - -// CHECK: "lmhlo.custom_call" -// CHECK: call_target_name = "__custom" - -ENTRY main { - param = f32[] parameter(0) - ROOT cr = f32[] custom-call(param), custom_call_target="__custom" -} - -// ----- - -HloModule CustomCallWithComputation - -// CHECK: "lmhlo.custom_call" -// CHECK: call_target_name = "__custom" -// CHECK: %0 = mhlo.add -// CHECK: mhlo.return %0 - -computation1 { - param_0 = f32[] parameter(0) - ROOT r = f32[] add(param_0, param_0) -} - -ENTRY main { - param = f32[] parameter(0) - ROOT cr = f32[] custom-call(param), custom_call_target="__custom", - to_apply=computation1 -} - -// ----- - -HloModule Send - -// CHECK: func @main -// CHECK: %[[ARG1:arg[0-9]+]]: memref<16xi8> {lmhlo.params = 1 : index} -// CHECK: %[[VIEW:.*]] = memref.view %[[ARG1]][%c0][] -// CHECK: %[[TOKEN:.*]] = "lmhlo.send"(%[[VIEW]]) -// CHECK: channel_handle = #mhlo.channel_handle, -// CHECK: frontend_attributes = {_xla_dcn_recv_channel = "2", -// CHECK: _xla_host_transfer_handler_name = "undef", -// CHECK: _xla_host_transfer_rendezvous = "undef"} -// CHECK: is_host_transfer = true -// CHECK: : (memref<4xf32>) -> !mhlo.token -// CHECK: "lmhlo.send_done"(%0) -// CHECK: channel_handle = #mhlo.channel_handle, -// CHECK is_host_transfer = true -// CHECK: : (!mhlo.token) -> () -ENTRY main { - %tok = token[] parameter(0) - %buf = f32[4]{0} parameter(1) - %send = (f32[4]{0}, u32[], token[]) send(f32[4]{0} %buf, token[] %tok), channel_id=1, is_host_transfer=true, frontend_attributes={_xla_dcn_recv_channel="2",_xla_host_transfer_handler_name="undef",_xla_host_transfer_rendezvous="undef"} - ROOT %send-done = token[] send-done((f32[4]{0}, u32[], token[]) %send), channel_id=1, is_host_transfer=true -} - -// ----- - -HloModule Recv - -// CHECK: func @main -// CHECK: %[[ARG1:arg[0-9]+]]: memref<16xi8> {lmhlo.output_index = dense<0> : tensor<1xi64>} -// CHECK: %[[VIEW:.*]] = memref.view %[[ARG1]][%c0][] -// CHECK: %[[TOKEN:.*]] = "lmhlo.recv"(%[[VIEW]]) -// CHECK: channel_handle = #mhlo.channel_handle, -// CHECK: frontend_attributes = {_xla_host_transfer_handler_name = "undef", -// CHECK: _xla_host_transfer_rendezvous = "undef"} -// CHECK: is_host_transfer = true -// CHECK: : (memref<4xf32>) -> !mhlo.token -// CHECK: "lmhlo.recv_done"(%0) -// CHECK: channel_handle = #mhlo.channel_handle, -// CHECK is_host_transfer = true -// CHECK: : (!mhlo.token) -> () -ENTRY main { - %tok = token[] parameter(0) - %recv = (f32[4]{0}, u32[], token[]) recv(token[] %tok), channel_id=1, is_host_transfer=true, frontend_attributes={_xla_host_transfer_handler_name="undef",_xla_host_transfer_rendezvous="undef"} - ROOT %recv-done = (f32[4]{0}, token[]) recv-done((f32[4]{0}, u32[], token[]) %recv), channel_id=1, is_host_transfer=true -} - -// ----- - -HloModule TestAllGatherAsync - -// CHECK: func @main -// CHECK: %[[TOKEN:.*]] = "lmhlo_gpu.all_gather_start"(%{{.*}}, %{{.*}}) < -// CHECK-SAME: all_gather_dimension = 1 : i64 -// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> -// CHECK-SAME: use_global_device_ids = false -// CHECK ""lmhlo_gpu.all_gather_done"(%[[TOKEN]]) -ENTRY main { - param0 = f32[10,20] parameter(0) - ags = (f32[10,20], f32[10,80]) all-gather-start(param0), replica_groups={{0,1,2,3}}, - dimensions={1} - ROOT ag = f32[10,80] all-gather-done(ags) -} - -// ----- - -HloModule AsyncReduceScatter - -// CHECK: func @main -// CHECK: %[[TOKEN:.*]] = "lmhlo_gpu.reduce_scatter_start"(%{{.*}}, %{{.*}}) < -// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> -// CHECK-SAME: scatter_dimension = 0 -// CHECK-SAME: use_global_device_ids = false -// CHECK: ^bb0([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor): -// CHECK: [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]] -// CHECK: mhlo.return [[ADD]] : tensor -// CHECK: }) : -// CHECK ""lmhlo_gpu.reduce_scatter_done"(%[[TOKEN]]) - -add { - lhs = f32[] parameter(0) - rhs = f32[] parameter(1) - ROOT add = f32[] add(lhs, rhs) -} - -reduce_scatter { - p0 = f32[8] parameter(0) - ROOT result = f32[4] reduce-scatter(p0), replica_groups={{0,1}}, - dimensions={0}, to_apply=add -} - -ENTRY main { - input = f32[8] parameter(0) - rs-start = ((f32[8]), f32[4]) async-start(input), calls=reduce_scatter - ROOT rs-done = f32[4] async-done(rs-start), calls=reduce_scatter -} - -// ----- - -HloModule AsyncAllToAll - -// CHECK: func @main -// CHECK: %[[TOKEN:.*]] = "lmhlo_gpu.all_to_all_start"(%{{.*}}, %{{.*}}) < -// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> -// CHECK ""lmhlo_gpu.all_to_all_done"(%[[TOKEN]]) - -all_to_all { - p0 = f32[128,4] parameter(0) - ROOT a2a = f32[128,4] all-to-all(p0), replica_groups={{0,1}} -} - -ENTRY main { - p0 = f32[128,4] parameter(0) - a2a-start = ((f32[128,4]), f32[128,4]) async-start(p0), calls=all_to_all - ROOT a2a-done = f32[128,4] async-done(a2a-start), calls=all_to_all -} - -// ----- - -HloModule TestAllGatherAsyncWithSyncFlagFalse - -// CHECK: func @main -// CHECK: %[[TOKEN:.*]] = "lmhlo_gpu.all_gather_start"(%{{.*}}, %{{.*}}) < -// CHECK-SAME: all_gather_dimension = 1 : i64 -// CHECK-SAME: is_sync = false -// CHECK-SAME: no_parallel_custom_call = false -// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> -// CHECK-SAME: use_global_device_ids = false -// CHECK ""lmhlo_gpu.all_gather_done"(%[[TOKEN]]) -ENTRY main { - param0 = f32[10,20] parameter(0) - ags = (f32[10,20], f32[10,80]) all-gather-start(param0), replica_groups={{0,1,2,3}}, - dimensions={1} - ROOT ag = f32[10,80] all-gather-done(ags) -} - -// ----- - -HloModule TestAllGatherAsyncWithSyncFlagTrue - -// CHECK: func @main -// CHECK: %[[TOKEN:.*]] = "lmhlo_gpu.all_gather_start"(%{{.*}}, %{{.*}}) < -// CHECK-SAME: all_gather_dimension = 1 : i64 -// CHECK-SAME: is_sync = true -// CHECK-SAME: no_parallel_custom_call = true -// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> -// CHECK-SAME: use_global_device_ids = false -// CHECK ""lmhlo_gpu.all_gather_done"(%[[TOKEN]]) -ENTRY main { - param0 = f32[10,20] parameter(0) - ags = (f32[10,20], f32[10,80]) all-gather-start(param0), replica_groups={{0,1,2,3}}, - dimensions={1}, backend_config="{ \"collective_backend_config\": {\"is_sync\":true, \"no_parallel_custom_call\":true}}" - ROOT ag = f32[10,80] all-gather-done(ags) -} diff --git a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/tests/no_opt_ops.hlotxt b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/tests/no_opt_ops.hlotxt deleted file mode 100644 index 0c7fc220f73afc..00000000000000 --- a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/tests/no_opt_ops.hlotxt +++ /dev/null @@ -1,95 +0,0 @@ -// RUN: xla-translate -split-input-file -hlo-text-to-lhlo %s | FileCheck %s - -HloModule indexed_conditional - -%Negate (x: f32[]) -> f32[] { - %x = f32[] parameter(0) - ROOT %negate = f32[] negate(f32[] %x) -} - -%NegateCond (x: f32[]) -> f32[] { - %x = f32[] parameter(0) - ROOT %negate = f32[] fusion(f32[] %x), kind=kLoop, calls=%Negate -} - -%Identity (y: f32[]) -> f32[] { - %y = f32[] parameter(0) - ROOT %copy = f32[] copy(f32[] %y) -} - -%IdentityCond (x: f32[]) -> f32[] { - %y = f32[] parameter(0) - ROOT %copy = f32[] fusion(f32[] %y), kind=kLoop, calls=%Identity -} - -%Floor (z: f32[]) -> f32[] { - %z = f32[] parameter(0) - ROOT %floor = f32[] floor(f32[] %z) -} - -%FloorCond (x: f32[]) -> f32[] { - %z = f32[] parameter(0) - ROOT %floor = f32[] fusion(f32[] %z), kind=kLoop, calls=%Floor -} - -// CHECK: %{{.*}} = memref.view -// CHECK: "lmhlo.case"(%{{.*}}) ({ -// CHECK: mhlo.negate -// CHECK: "lmhlo.terminator"() : () -> () -// CHECK: }, { -// CHECK: mhlo.copy -// CHECK: "lmhlo.terminator"() : () -> () -// CHECK: }, { -// CHECK: mhlo.floor -// CHECK: "lmhlo.terminator"() : () -> () -// CHECK: }) : (memref) -> () - -ENTRY %Parameters1.v4 () -> (f32[]) { - %constant = s32[] parameter(0) - %constant.1 = f32[] parameter(1) - %constant.2 = f32[] parameter(2) - %constant.3 = f32[] parameter(3) - %conditional = f32[] conditional(s32[] %constant, f32[] %constant.1, f32[] %constant.2, f32[] %constant.3), branch_computations={%NegateCond, %IdentityCond, %FloorCond} - ROOT %t = (f32[]) tuple(%conditional) -} - -// ----- - -HloModule WhileWithScalarS32Result_module - -%Add (a: s32[], b: s32[]) -> s32[] { - %a = s32[] parameter(0) - %b = s32[] parameter(1) - ROOT %add = s32[] add(s32[] %a, s32[] %b) -} - -%body.v3 (prev.1: s32[]) -> s32[] { - %constant = s32[] constant(1) - %prev.1 = s32[] parameter(0) - ROOT %add = s32[] fusion(s32[] %constant, s32[] %prev.1), kind=kLoop, calls=%Add -} - -%Compare (a: s32[], b: s32[]) -> pred[] { - %a = s32[] parameter(0) - %b = s32[] parameter(1) - ROOT %greater-than = pred[] compare(s32[] %a, s32[] %b), direction=GT -} - -%condition.v3 (prev.2: s32[]) -> pred[] { - %constant.1 = s32[] constant(5) - %prev.2 = s32[] parameter(0) - ROOT %greater-than = pred[] fusion(s32[] %constant.1, s32[] %prev.2), kind=kLoop, calls=%Compare -} - -// CHECK: %{{.*}} = memref.view -// CHECK: "lmhlo.while"(%{{.*}}) ({ -// CHECK: mhlo.compare -// CHECK: "lmhlo.terminator"() : () -> () -// CHECK: }, { -// CHECK: mhlo.add -// CHECK: "lmhlo.terminator"() : () -> () -// CHECK: }) : (memref) -> () -ENTRY %WhileWithScalarS32Result.v2 () -> s32[] { - %constant.2 = s32[] constant(0) - ROOT %while = s32[] while(s32[] %constant.2), condition=%condition.v3, body=%body.v3 -} diff --git a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/tests/non_identity_layouts.hlotxt b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/tests/non_identity_layouts.hlotxt deleted file mode 100644 index d74ec4e3434c93..00000000000000 --- a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/tests/non_identity_layouts.hlotxt +++ /dev/null @@ -1,28 +0,0 @@ -// RUN: xla-translate -hlo-text-to-lhlo %s | FileCheck %s - -HloModule TestModule - -// CHECK: #[[MAP:.*]] = affine_map<(d0, d1) -> (d0 + d1 * 3)> - -Fusion { - x = f32[3, 2]{1,0} parameter(0) - ROOT x.copy = f32[3, 2]{0,1} copy(x) -} - -// CHECK: func @TestComputation -ENTRY TestComputation { - x = f32[3, 2]{1,0} parameter(0) - - // CHECK: %[[VIEW:.*]] = memref.view {{.*}} : memref<24xi8> to memref<3x2xf32> - // CHECK: "lmhlo.fusion"() <{backend_config = "{{.*}}"}> ({ - // CHECK: %[[VAL2:.*]] = bufferization.to_tensor %[[VIEW]] : memref<3x2xf32> - // CHECK: %[[VAL3:.*]] = mhlo.copy %[[VAL2]] { - // CHECK-SAME: result_layout = dense<[0, 1]> - // CHECK-SAME: xla_shape = "f32[3,2]{0,1}" - // CHECK-SAME: } : tensor<3x2xf32> - // CHECK: bufferization.materialize_in_destination %[[VAL3:.*]] in - // CHECK-SAME: writable %{{.*}} : (tensor<3x2xf32>, memref<3x2xf32, #[[MAP]]>) - // CHECK: "lmhlo.terminator"() : () -> () - // CHECK: }) : () -> () - ROOT fusion = f32[3, 2]{0,1} fusion(f32[3, 2]{1,0} x), kind=kLoop, calls=Fusion -} diff --git a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/translate_registration.cc b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/translate_registration.cc deleted file mode 100644 index 4de3c85777234e..00000000000000 --- a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/translate_registration.cc +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2020 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project -#include "xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h" - -//----------------------------------------------------------------------------// -// Hooks for tf-mlir-translate -//----------------------------------------------------------------------------/ - -// MHLO doesn't support explicit layouts, while XLA service does. -// TODO(timshen): remove it once MHLO supports explicit layouts. -static mlir::TranslateToMLIRRegistration HloTextToLhloMlirTranslate( - "hlo-text-to-lhlo", "hlo-text-to-lhlo", - [](llvm::StringRef input, mlir::MLIRContext* context) { - return mlir::HloTextToLhloTranslateFunction(input, context); - }); diff --git a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/xla_translate_opt_main.cc b/third_party/xla/xla/translate/xla_translate_opt_main.cc similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/xla_translate_opt_main.cc rename to third_party/xla/xla/translate/xla_translate_opt_main.cc diff --git a/third_party/xla/xla/types.h b/third_party/xla/xla/types.h index ef9dd11da6893a..f1891b9a8c076e 100644 --- a/third_party/xla/xla/types.h +++ b/third_party/xla/xla/types.h @@ -19,11 +19,11 @@ limitations under the License. #include #include #include +#include #include -#include "absl/strings/str_format.h" -#include "Eigen/Core" // from @eigen_archive // IWYU pragma: export -#include "ml_dtypes/include/int4.h" // from @ml_dtypes // IWYU pragma: export +#include "Eigen/Core" // from @eigen_archive // IWYU pragma: export +#include "tsl/platform/ml_dtypes.h" // IWYU pragma: export namespace xla { @@ -41,18 +41,26 @@ struct is_complex> : std::true_type {}; template inline constexpr bool is_complex_v = is_complex::value; +template +struct is_specialized_floating_point + : std::bool_constant::is_specialized && + !std::numeric_limits::is_integer> {}; + template inline constexpr bool is_specialized_floating_point_v = - std::numeric_limits::is_specialized && - !std::numeric_limits::is_integer; + is_specialized_floating_point::value; + +template +struct is_specialized_integral + : std::bool_constant::is_specialized && + std::numeric_limits::is_integer> {}; template inline constexpr bool is_specialized_integral_v = - std::numeric_limits::is_specialized && - std::numeric_limits::is_integer; + is_specialized_integral::value; -using u4 = ml_dtypes::uint4; -using s4 = ml_dtypes::int4; +using u4 = tsl::uint4; +using s4 = tsl::int4; } // namespace xla @@ -60,12 +68,12 @@ using s4 = ml_dtypes::int4; namespace ml_dtypes { template void AbslStringify(Sink& sink, const xla::s4& i) { - absl::Format(&sink, "%d", static_cast(i)); + sink.Append(std::to_string(static_cast(i))); } template void AbslStringify(Sink& sink, const xla::u4& i) { - absl::Format(&sink, "%d", static_cast(i)); + sink.Append(std::to_string(static_cast(i))); } } // namespace ml_dtypes @@ -112,6 +120,16 @@ struct make_specialized_signed { template using make_specialized_signed_t = typename make_specialized_signed::type; +template +struct has_negative_zero + : std::bool_constant::is_iec559> {}; + +template <> +struct has_negative_zero : std::bool_constant {}; + +template +inline constexpr bool has_negative_zero_v = has_negative_zero::value; + } // namespace xla #endif // XLA_TYPES_H_ diff --git a/third_party/xla/xla/util.h b/third_party/xla/xla/util.h index aae0a84bbaf575..23df713bea51a7 100644 --- a/third_party/xla/xla/util.h +++ b/third_party/xla/xla/util.h @@ -48,6 +48,7 @@ limitations under the License. #include "Eigen/Core" // from @eigen_archive #include "xla/status.h" #include "xla/status_macros.h" +#include "xla/types.h" #include "xla/xla_data.pb.h" #include "tsl/lib/math/math_util.h" #include "tsl/platform/bfloat16.h" @@ -573,9 +574,7 @@ auto SignAndMagnitude(T x) { BitType x_abs_bits = Eigen::numext::bit_cast(Eigen::numext::abs(x)); const BitType x_bits = Eigen::numext::bit_cast(x); const BitType x_sign = x_bits ^ x_abs_bits; - if constexpr (std::is_same_v || - std::is_same_v || - std::is_same_v) { + if constexpr (!has_negative_zero_v) { // f8e4m3b11, f8e4m3fnuz, and f8e5m2fnuz don't support -0, adjust negative // numbers to fill in the gap. if (x_sign) { diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index 8efa57455f948b..23719bdcad18cf 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -172,7 +172,12 @@ message DebugOptions { // useful when accelerating structured sparsity. int32 xla_cpu_sparse_cuda_threads = 207; - // Allows xla to increase the output precision of floating point operations. + // Allows xla to increase the output precision of floating point operations + // and all floating-point conversions to be simplified, including those + // that affect the numerics. The `FloatNormalization` pass inserts many + // `f32 -> bf16 -> f32` conversion pairs. These are not removed by the + // `AlgebraicSimplifier`, as that will only simplify conversions that are + // no-ops, e.g. `bf16 -> f32 -> bf16`. Removing these improves accuracy. bool xla_allow_excess_precision = 122; // Crashes the program when any kind of verification fails, instead of just @@ -446,6 +451,9 @@ message DebugOptions { // if `xla_gpu_enable_custom_fusion` set to true. string xla_gpu_enable_custom_fusions_re = 264; + // If true, use XLA runtime for XLA:GPU backend. + bool xla_gpu_enable_address_computation_fusion = 105; + reserved 233; // was xla_gpu_enable_gpu2_runtime reserved 234; // was xla_gpu_enable_gpu2_hal @@ -508,12 +516,8 @@ message DebugOptions { // scratch), so this can be multiplied by quite a lot. int64 xla_gpu_redzone_padding_bytes = 228; - // Allows all floating-point conversions to be simplified, including those - // that affect the numerics. The `FloatNormalization` pass inserts many - // `f32 -> bf16 -> f32` conversion pairs. These are not removed by the - // `AlgebraicSimplifier`, as that will only simplify conversions that are - // no-ops, e.g. `bf16 -> f32 -> bf16`. Removing these improves accuracy. - bool xla_gpu_simplify_all_fp_conversions = 168; + // Deprecated. Use xla_allow_excess_precision instead. + bool xla_gpu_simplify_all_fp_conversions = 168 [deprecated = true]; // An experimental option to force all layouts present in the // after-optimizations HLO to be descending, e.g. @@ -672,10 +676,28 @@ message DebugOptions { // Enable NCCL user buffers. bool xla_gpu_enable_nccl_user_buffers = 267; + // Enable NCCL communicator splitting. + bool xla_gpu_enable_nccl_comm_splitting = 272; + // If enabled, uses the libnvptxcompiler library to compile PTX to cuBIN. bool xla_gpu_enable_libnvptxcompiler = 269; - // Next id: 270 + bool xla_gpu_enable_dot_strength_reduction = 270; + + // If enabled, uses bf16_6way gemm to compute F32 gemm. + bool xla_gpu_enable_bf16_6way_gemm = 271; + + // Specify the maximum number of channels(SMs) NCCL + // will use for collective operations. + int64 xla_gpu_nccl_collective_max_nchannels = 273; + + // Specify the maximum number of channels(SMs) NCCL + // will use for p2p operations. + int64 xla_gpu_nccl_p2p_max_nchannels = 274; + + bool xla_gpu_enable_mlir_emitters = 275; + + // Next id: 276 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. @@ -713,7 +735,7 @@ message ShardableValueUpdatePairProto { // will have an effect on every platform. // // When adding new fields, keep in mind that boolean fields default to false. -// Next id: 23. +// Next id: 24. message ExecutionOptions { // This optional field's layout is used as a hint when storing the output of // this computation. Subsequent transfers of this output array to the client @@ -775,6 +797,18 @@ message ExecutionOptions { reserved 13; // Was broadcast_replicated_parameters_via_collectives + // Allows sharding propagation to propagate to the parameters. This changes + // the input shape of the computation (which is undesirable), but it can be + // used to allow to run partial compilation to determine what would be the + // input sharding of a computation if XLA would be allowed to propagate the + // sharding which can be used by higher level framework as a way to query + // intermediate sharding of operations when multiple computation would be + // chained and merged together. + // This is a vector of bool, because the user can control which parameters can + // have the sharding substituted. If only one boolean value is passed in the + // vector that is interpreted as the value to be applied for every parameter. + repeated bool allow_spmd_sharding_propagation_to_parameters = 23; + // Allows sharding propagation to propagate to the outputs. This changes the // output shape of the computation (which is undesirable), but it can be used // to allow to run partial compilation to determine what would be the output @@ -813,7 +847,7 @@ message ExecutionOptions { // Serialization of HloModuleConfig. See the C++ class definition for // descriptions of each field. // There are no guarantees of backwards or forwards compatibility. -// Next id: 33. +// Next id: 34. message HloModuleConfigProto { enum FusionConfigCollection { OFF = 0; // Do not collect configuration. @@ -861,6 +895,7 @@ message HloModuleConfigProto { repeated BoolList phase_ordering_config = 24; int32 phase_index = 25; reserved 26; // Was flag_config + repeated bool allow_spmd_sharding_propagation_to_parameters = 33; repeated bool allow_spmd_sharding_propagation_to_output = 27; map analysis_allowance_map = 28; xla.PrecisionConfig.Precision matrix_unit_operand_precision = 29; diff --git a/third_party/xla/xla/xla_data.proto b/third_party/xla/xla/xla_data.proto index 05588e3d2b7c7f..72458ff52edd24 100644 --- a/third_party/xla/xla/xla_data.proto +++ b/third_party/xla/xla/xla_data.proto @@ -380,16 +380,11 @@ message OpMetadata { // Deprecated, use [ProfileInfo][profile_type] instead. repeated ProfileType profile_type = 5 [deprecated = true]; - // HloPassMetadata.pass_id of the pass that created this HLO instruction - // object. Should never be copied between HLO instructions. Zero if unset and - // -1 if the instruction was created before HLO passes began. - int64 creation_pass_id = 6; + reserved 6; + reserved "creation_pass_id"; - // HloPassMetadata.pass_id of the pass that created the logical functionality - // that this HLO instruction represents. Should be copied between HLO - // instructions that correspond across compilation passes. Zero if unset and - // -1 if the instruction was created before HLO passes began. - int64 logical_creation_pass_id = 7; + reserved 7; + reserved "logical_creation_pass_id"; // The footprint of the generated code for the instruction. int64 size_of_generated_code_in_bytes = 8; @@ -719,6 +714,36 @@ message DotDimensionNumbers { repeated int64 rhs_batch_dimensions = 4; } +enum SparsityType { + SPARSITY_INVALID = 0; + + // Structured N:M sparsity. + SPARSITY_STRUCTURED_N_M = 1; + + // Next: 2 +} + +// Contains sparsity metadata for a sparse dot operation. +// The only supported type atm is structured 2:4 sparsity, which is natively +// supported on NVidia GPUs. +// Restrictions: +// - only one operand of the dot operation may be sparse; +// - only the contracting dimension may be sparse. +message SparsityDescriptor { + SparsityType type = 1; + + // Sparse operand index (0 or 1). + int32 index = 2; + // Sparse dimension number. + int32 dimension = 3; + + // Structured N:M sparsity (N < M). + int32 n = 4; + int32 m = 5; + + // Next: 6 +} + enum RandomDistribution { RNG_INVALID = 0;